diff --git a/test_IVF.py b/test_IVF.py index 2bb4b7a..4b84b1f 100644 --- a/test_IVF.py +++ b/test_IVF.py @@ -8,7 +8,6 @@ import torch.nn as nn from utils.img_read_save import img_save,image_read_cv2 import warnings import logging -# 增加 warnings.filterwarnings("ignore") logging.basicConfig(level=logging.CRITICAL) @@ -26,7 +25,6 @@ for dataset_name in ["MSRS","TNO","RoadScene"]: device = 'cuda' if torch.cuda.is_available() else 'cpu' Encoder = nn.DataParallel(Restormer_Encoder()).to(device) Decoder = nn.DataParallel(Restormer_Decoder()).to(device) - # BaseFuseLayer = nn.DataParallel(BaseFeatureExtraction(dim=64, num_heads=8)).to(device) BaseFuseLayer = nn.DataParallel(BaseFeatureExtraction(dim=64)).to(device) DetailFuseLayer = nn.DataParallel(DetailFeatureExtraction(num_layers=1)).to(device) @@ -43,12 +41,9 @@ for dataset_name in ["MSRS","TNO","RoadScene"]: for img_name in os.listdir(os.path.join(test_folder,"ir")): data_IR=image_read_cv2(os.path.join(test_folder,"ir",img_name),mode='GRAY')[np.newaxis,np.newaxis, ...]/255.0 - # 改 data_VIS = cv2.split(image_read_cv2(os.path.join(test_folder, "vi", img_name), mode='YCrCb'))[0][np.newaxis, np.newaxis, ...] / 255.0 - # ycrcb, uint8 data_VIS_BGR = cv2.imread(os.path.join(test_folder, "vi", img_name)) _, data_VIS_Cr, data_VIS_Cb = cv2.split(cv2.cvtColor(data_VIS_BGR, cv2.COLOR_BGR2YCrCb)) - # 改 data_IR,data_VIS = torch.FloatTensor(data_IR),torch.FloatTensor(data_VIS) data_VIS, data_IR = data_VIS.cuda(), data_IR.cuda() @@ -60,13 +55,10 @@ for dataset_name in ["MSRS","TNO","RoadScene"]: data_Fuse, _ = Decoder(data_VIS, feature_F_B, feature_F_D) data_Fuse=(data_Fuse-torch.min(data_Fuse))/(torch.max(data_Fuse)-torch.min(data_Fuse)) fi = np.squeeze((data_Fuse * 255).cpu().numpy()) - # 改 - # float32 to uint8 fi = fi.astype(np.uint8) ycrcb_fi = np.dstack((fi, data_VIS_Cr, data_VIS_Cb)) rgb_fi = cv2.cvtColor(ycrcb_fi, cv2.COLOR_YCrCb2RGB) img_save(rgb_fi, img_name.split(sep='.')[0], test_out_folder) - # 改 eval_folder=test_out_folder ori_img_folder=test_folder @@ -92,4 +84,4 @@ for dataset_name in ["MSRS","TNO","RoadScene"]: +str(np.round(metric_result[6], 2))+'\t' +str(np.round(metric_result[7], 2)) ) - print("="*80) \ No newline at end of file + print("="*80)