diff --git a/train.py b/train.py index bf43c31..583fdb2 100644 --- a/train.py +++ b/train.py @@ -80,6 +80,9 @@ print(f"Clip gradient norm value: {clip_grad_norm_value}") print(f"Optimization step: {optim_step}") print(f"Optimization gamma: {optim_gamma}") +# 控制台输入 +model_str = input("Model: ") +print(f"Model: {model_str}") # Model device = 'cuda' if torch.cuda.is_available() else 'cpu'