From 6738c9057daa71631f4c59ba89e0c1f73973e992 Mon Sep 17 00:00:00 2001 From: zjut Date: Mon, 18 Nov 2024 09:29:10 +0800 Subject: [PATCH 1/2] =?UTF-8?q?feat(train):=20=E6=B7=BB=E5=8A=A0=E6=A8=A1?= =?UTF-8?q?=E5=9E=8B=E4=BF=9D=E5=AD=98=E8=B7=AF=E5=BE=84=E6=89=93=E5=8D=B0?= =?UTF-8?q?=E5=8A=9F=E8=83=BD?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - 优化模型保存逻辑,将保存路径存储在变量中 - 在保存模型后,打印模型的保存路径 - 这个改动可以帮助用户更容易地找到和管理训练好的模型文件 --- train.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/train.py b/train.py index 5bf94a8..dfb5ec2 100644 --- a/train.py +++ b/train.py @@ -258,5 +258,7 @@ if True: 'BaseFuseLayer': BaseFuseLayer.state_dict(), 'DetailFuseLayer': DetailFuseLayer.state_dict(), } - torch.save(checkpoint, os.path.join("models/whaiFusion"+timestamp+'.pth')) + savepth = os.path.join("models/whaiFusion" + timestamp + '.pth'); + torch.save(checkpoint, savepth) + print("save model:{}".format(savepth)) From f87a65e68e44c3f2d998cbcb7b830aad629d807b Mon Sep 17 00:00:00 2001 From: zjut Date: Mon, 18 Nov 2024 09:31:58 +0800 Subject: [PATCH 2/2] =?UTF-8?q?refactor(test=5FIVF):=20=E9=87=8D=E6=9E=84?= =?UTF-8?q?=E6=B5=8B=E8=AF=95=E4=BB=A3=E7=A0=81=E4=BB=A5=E6=8F=90=E9=AB=98?= =?UTF-8?q?=E7=81=B5=E6=B4=BB=E6=80=A7=E5=92=8C=E5=8F=AF=E7=BB=B4=E6=8A=A4?= =?UTF-8?q?=E6=80=A7?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - 引入变量 pth_path 以动态构建模型权重路径 - 使用 pth_path替代直接使用时间戳创建输出文件夹 - 优化代码结构,提高可读性和可维护性 --- test_IVF.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/test_IVF.py b/test_IVF.py index 04e10ce..09ad867 100644 --- a/test_IVF.py +++ b/test_IVF.py @@ -17,14 +17,16 @@ current_time = datetime.datetime.now().strftime("%Y%m%d_%H%M%S") os.environ["CUDA_VISIBLE_DEVICES"] = "0" -ckpt_path= r"/home/star/whaiDir/PFCFuse/models/whaiFusion11-15-17-48.pth" +pth_path = "whaiFusion11-15-17-48" +ckpt_path= r"/home/star/whaiDir/PFCFuse/models/"+pth_path+".pth" +print("path_pth:{}".format(ckpt_path)) for dataset_name in ["sar"]: print("\n"*2+"="*80) model_name="PFCFuse 最基本版本 " print("The test result of "+dataset_name+' :') test_folder = os.path.join('test_img', dataset_name) - test_out_folder=os.path.join('test_result',current_time,dataset_name) + test_out_folder=os.path.join('test_result',pth_path,dataset_name) device = 'cuda' if torch.cuda.is_available() else 'cpu' Encoder = nn.DataParallel(Restormer_Encoder()).to(device)