From d65f6ec3ce526e458decb33cf852cd5b32721803 Mon Sep 17 00:00:00 2001
From: whaifree
Date: Sun, 6 Oct 2024 22:16:42 +0800
Subject: [PATCH] Reduce the progress-output frequency during training

Modified the training script so that progress is printed once every 100
batches instead of on every batch, reducing the volume of output and
improving training efficiency.
---
 train.py | 22 ++++++++++++----------
 1 file changed, 12 insertions(+), 10 deletions(-)

diff --git a/train.py b/train.py
index 205195e..f5a6abd 100644
--- a/train.py
+++ b/train.py
@@ -222,17 +222,19 @@ for epoch in range(num_epochs):
         time_left = datetime.timedelta(seconds=batches_left * (time.time() - prev_time))
         epoch_time = time.time() - prev_time
         prev_time = time.time()
-        sys.stdout.write(
-            "\r[Epoch %d/%d] [Batch %d/%d] [loss: %f] ETA: %.10s"
-            % (
-                epoch,
-                num_epochs,
-                i,
-                len(loader['train']),
-                loss.item(),
-                time_left,
+        if step % 100 == 0:
+            sys.stdout.write(
+                "\r[Epoch %d/%d] [Batch %d/%d] [loss: %f] ETA: %.10s"
+                % (
+                    epoch,
+                    num_epochs,
+                    i,
+                    len(loader['train']),
+                    loss.item(),
+                    time_left,
+                )
             )
-        )
+    # adjust the learning rate
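
Note: the new guard tests `step` rather than the loop index `i`; presumably
`step` is a running batch counter maintained elsewhere in train.py. For
reference, below is a minimal, self-contained sketch of the throttled-progress
pattern this patch applies. The `loader` dict, the `step` counter, and the
placeholder loss value are assumptions standing in for the real training
objects in train.py, not the project's actual code:

    import datetime
    import sys
    import time

    num_epochs = 2
    loader = {'train': range(1000)}   # stand-in for the real DataLoader
    step = 0                          # running batch counter, incremented per batch
    prev_time = time.time()

    for epoch in range(num_epochs):
        for i, batch in enumerate(loader['train']):
            loss_value = 1.0 / (step + 1)   # placeholder for loss.item()
            batches_left = num_epochs * len(loader['train']) - (epoch * len(loader['train']) + i)
            time_left = datetime.timedelta(seconds=batches_left * (time.time() - prev_time))
            prev_time = time.time()
            if step % 100 == 0:             # print once every 100 batches
                sys.stdout.write(
                    "\r[Epoch %d/%d] [Batch %d/%d] [loss: %f] ETA: %.10s"
                    % (epoch, num_epochs, i, len(loader['train']), loss_value, time_left)
                )
            step += 1

Because the `\r` carriage return overwrites the same console line, printing
every batch buys nothing visually; sampling every 100 batches keeps the ETA
readout current while cutting stdout traffic, which matters when stdout is
redirected to a log file.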