Update test_IVF.py
This commit is contained in:
parent
656c8ba0a1
commit
d65393b3c4
10
test_IVF.py
10
test_IVF.py
@ -8,7 +8,6 @@ import torch.nn as nn
|
|||||||
from utils.img_read_save import img_save,image_read_cv2
|
from utils.img_read_save import img_save,image_read_cv2
|
||||||
import warnings
|
import warnings
|
||||||
import logging
|
import logging
|
||||||
# 增加
|
|
||||||
warnings.filterwarnings("ignore")
|
warnings.filterwarnings("ignore")
|
||||||
logging.basicConfig(level=logging.CRITICAL)
|
logging.basicConfig(level=logging.CRITICAL)
|
||||||
|
|
||||||
@ -26,7 +25,6 @@ for dataset_name in ["MSRS","TNO","RoadScene"]:
|
|||||||
device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
||||||
Encoder = nn.DataParallel(Restormer_Encoder()).to(device)
|
Encoder = nn.DataParallel(Restormer_Encoder()).to(device)
|
||||||
Decoder = nn.DataParallel(Restormer_Decoder()).to(device)
|
Decoder = nn.DataParallel(Restormer_Decoder()).to(device)
|
||||||
# BaseFuseLayer = nn.DataParallel(BaseFeatureExtraction(dim=64, num_heads=8)).to(device)
|
|
||||||
BaseFuseLayer = nn.DataParallel(BaseFeatureExtraction(dim=64)).to(device)
|
BaseFuseLayer = nn.DataParallel(BaseFeatureExtraction(dim=64)).to(device)
|
||||||
DetailFuseLayer = nn.DataParallel(DetailFeatureExtraction(num_layers=1)).to(device)
|
DetailFuseLayer = nn.DataParallel(DetailFeatureExtraction(num_layers=1)).to(device)
|
||||||
|
|
||||||
@ -43,12 +41,9 @@ for dataset_name in ["MSRS","TNO","RoadScene"]:
|
|||||||
for img_name in os.listdir(os.path.join(test_folder,"ir")):
|
for img_name in os.listdir(os.path.join(test_folder,"ir")):
|
||||||
|
|
||||||
data_IR=image_read_cv2(os.path.join(test_folder,"ir",img_name),mode='GRAY')[np.newaxis,np.newaxis, ...]/255.0
|
data_IR=image_read_cv2(os.path.join(test_folder,"ir",img_name),mode='GRAY')[np.newaxis,np.newaxis, ...]/255.0
|
||||||
# 改
|
|
||||||
data_VIS = cv2.split(image_read_cv2(os.path.join(test_folder, "vi", img_name), mode='YCrCb'))[0][np.newaxis, np.newaxis, ...] / 255.0
|
data_VIS = cv2.split(image_read_cv2(os.path.join(test_folder, "vi", img_name), mode='YCrCb'))[0][np.newaxis, np.newaxis, ...] / 255.0
|
||||||
# ycrcb, uint8
|
|
||||||
data_VIS_BGR = cv2.imread(os.path.join(test_folder, "vi", img_name))
|
data_VIS_BGR = cv2.imread(os.path.join(test_folder, "vi", img_name))
|
||||||
_, data_VIS_Cr, data_VIS_Cb = cv2.split(cv2.cvtColor(data_VIS_BGR, cv2.COLOR_BGR2YCrCb))
|
_, data_VIS_Cr, data_VIS_Cb = cv2.split(cv2.cvtColor(data_VIS_BGR, cv2.COLOR_BGR2YCrCb))
|
||||||
# 改
|
|
||||||
|
|
||||||
data_IR,data_VIS = torch.FloatTensor(data_IR),torch.FloatTensor(data_VIS)
|
data_IR,data_VIS = torch.FloatTensor(data_IR),torch.FloatTensor(data_VIS)
|
||||||
data_VIS, data_IR = data_VIS.cuda(), data_IR.cuda()
|
data_VIS, data_IR = data_VIS.cuda(), data_IR.cuda()
|
||||||
@ -60,13 +55,10 @@ for dataset_name in ["MSRS","TNO","RoadScene"]:
|
|||||||
data_Fuse, _ = Decoder(data_VIS, feature_F_B, feature_F_D)
|
data_Fuse, _ = Decoder(data_VIS, feature_F_B, feature_F_D)
|
||||||
data_Fuse=(data_Fuse-torch.min(data_Fuse))/(torch.max(data_Fuse)-torch.min(data_Fuse))
|
data_Fuse=(data_Fuse-torch.min(data_Fuse))/(torch.max(data_Fuse)-torch.min(data_Fuse))
|
||||||
fi = np.squeeze((data_Fuse * 255).cpu().numpy())
|
fi = np.squeeze((data_Fuse * 255).cpu().numpy())
|
||||||
# 改
|
|
||||||
# float32 to uint8
|
|
||||||
fi = fi.astype(np.uint8)
|
fi = fi.astype(np.uint8)
|
||||||
ycrcb_fi = np.dstack((fi, data_VIS_Cr, data_VIS_Cb))
|
ycrcb_fi = np.dstack((fi, data_VIS_Cr, data_VIS_Cb))
|
||||||
rgb_fi = cv2.cvtColor(ycrcb_fi, cv2.COLOR_YCrCb2RGB)
|
rgb_fi = cv2.cvtColor(ycrcb_fi, cv2.COLOR_YCrCb2RGB)
|
||||||
img_save(rgb_fi, img_name.split(sep='.')[0], test_out_folder)
|
img_save(rgb_fi, img_name.split(sep='.')[0], test_out_folder)
|
||||||
# 改
|
|
||||||
|
|
||||||
eval_folder=test_out_folder
|
eval_folder=test_out_folder
|
||||||
ori_img_folder=test_folder
|
ori_img_folder=test_folder
|
||||||
@ -92,4 +84,4 @@ for dataset_name in ["MSRS","TNO","RoadScene"]:
|
|||||||
+str(np.round(metric_result[6], 2))+'\t'
|
+str(np.round(metric_result[6], 2))+'\t'
|
||||||
+str(np.round(metric_result[7], 2))
|
+str(np.round(metric_result[7], 2))
|
||||||
)
|
)
|
||||||
print("="*80)
|
print("="*80)
|
||||||
|
Loading…
Reference in New Issue
Block a user