Compare commits

...

4 Commits

Author SHA1 Message Date
zjut
73b398b811 test: 更新测试脚本以使用新的模型路径
- 修改了 pth_path 变量,使用新的模型路径 "whaiFusion11-17-16-16"
- 更新了模型名称显示,使用 pth_path 变量替代固定字符串
- 优化了测试结果输出格式,增加了换行和分隔线
2024-11-18 09:42:51 +08:00
zjut
8383c68ff7 Merge remote-tracking branch 'origin/base' into base_增加WTconv
# Conflicts:
#	test_IVF.py
2024-11-18 09:39:41 +08:00
zjut
f87a65e68e refactor(test_IVF): 重构测试代码以提高灵活性和可维护性
- 引入变量 pth_path 以动态构建模型权重路径
- 使用 pth_path替代直接使用时间戳创建输出文件夹
- 优化代码结构,提高可读性和可维护性
2024-11-18 09:31:58 +08:00
zjut
6738c9057d feat(train): 添加模型保存路径打印功能
- 优化模型保存逻辑,将保存路径存储在变量中
- 在保存模型后,打印模型的保存路径
- 这个改动可以帮助用户更容易地找到和管理训练好的模型文件
2024-11-18 09:29:10 +08:00
2 changed files with 8 additions and 4 deletions

View File

@ -17,14 +17,16 @@ current_time = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
os.environ["CUDA_VISIBLE_DEVICES"] = "0" os.environ["CUDA_VISIBLE_DEVICES"] = "0"
ckpt_path= r"/home/star/whaiDir/PFCFuse/models/whaiFusion11-17-10-34.pth" pth_path = "whaiFusion11-17-16-16"
ckpt_path= r"/home/star/whaiDir/PFCFuse/models/"+pth_path+".pth"
print("path_pth:{}".format(ckpt_path))
for dataset_name in ["sar"]: for dataset_name in ["sar"]:
print("\n"*2+"="*80) print("\n"*2+"="*80)
model_name="PFCFuse Enhance 增加widthblock" model_name=pth_path
print("The test result of "+dataset_name+' :') print("The test result of "+dataset_name+' :')
test_folder = os.path.join('test_img', dataset_name) test_folder = os.path.join('test_img', dataset_name)
test_out_folder=os.path.join('test_result',current_time,dataset_name) test_out_folder=os.path.join('test_result',pth_path,dataset_name)
device = 'cuda' if torch.cuda.is_available() else 'cpu' device = 'cuda' if torch.cuda.is_available() else 'cpu'
Encoder = nn.DataParallel(Restormer_Encoder()).to(device) Encoder = nn.DataParallel(Restormer_Encoder()).to(device)

View File

@ -262,5 +262,7 @@ if True:
'BaseFuseLayer': BaseFuseLayer.state_dict(), 'BaseFuseLayer': BaseFuseLayer.state_dict(),
'DetailFuseLayer': DetailFuseLayer.state_dict(), 'DetailFuseLayer': DetailFuseLayer.state_dict(),
} }
torch.save(checkpoint, os.path.join("models/whaiFusion"+timestamp+'.pth')) savepth = os.path.join("models/whaiFusion" + timestamp + '.pth');
torch.save(checkpoint, savepth)
print("save model:{}".format(savepth))