diff --git a/dataprocessing.py b/dataprocessing.py
new file mode 100644
index 0000000..83bb220
--- /dev/null
+++ b/dataprocessing.py
@@ -0,0 +1,93 @@
+import os
+import h5py
+import numpy as np
+from tqdm import tqdm
+from skimage.io import imread
+
+
+def get_img_file(file_name):
+    imagelist = []
+    for parent, dirnames, filenames in os.walk(file_name):
+        for filename in filenames:
+            if filename.lower().endswith(('.bmp', '.dib', '.png', '.jpg', '.jpeg', '.pbm', '.pgm', '.ppm', '.tif', '.tiff', '.npy')):
+                imagelist.append(os.path.join(parent, filename))
+    return imagelist
+
+def rgb2y(img):
+    y = img[0:1, :, :] * 0.299000 + img[1:2, :, :] * 0.587000 + img[2:3, :, :] * 0.114000
+    return y
+
+def Im2Patch(img, win, stride=1):
+    k = 0
+    endc = img.shape[0]
+    endw = img.shape[1]
+    endh = img.shape[2]
+    patch = img[:, 0:endw-win+0+1:stride, 0:endh-win+0+1:stride]
+    TotalPatNum = patch.shape[1] * patch.shape[2]
+    Y = np.zeros([endc, win*win, TotalPatNum], np.float32)
+    for i in range(win):
+        for j in range(win):
+            patch = img[:, i:endw-win+i+1:stride, j:endh-win+j+1:stride]
+            Y[:, k, :] = np.array(patch[:]).reshape(endc, TotalPatNum)
+            k = k + 1
+    return Y.reshape([endc, win, win, TotalPatNum])
+
+def is_low_contrast(image, fraction_threshold=0.1, lower_percentile=10,
+                    upper_percentile=90):
+    """Determine if an image is low contrast."""
+    limits = np.percentile(image, [lower_percentile, upper_percentile])
+    ratio = (limits[1] - limits[0]) / limits[1]
+    return ratio < fraction_threshold
+
+data_name = "MSRS_train"
+img_size = 128  # patch size
+stride = 200    # patch stride
+
+IR_files = sorted(get_img_file(r"MSRS_train/MSRS-main/train/ir"))
+VIS_files = sorted(get_img_file(r"MSRS_train/MSRS-main/train/vi"))
+
+assert len(IR_files) == len(VIS_files)
+h5f = h5py.File(os.path.join('data',
+                             data_name+'_imgsize_'+str(img_size)+"_stride_"+str(stride)+'.h5'),
+                'w')
+h5_ir = h5f.create_group('ir_patchs')
+h5_vis = h5f.create_group('vis_patchs')
+train_num = 0
+for i in tqdm(range(len(IR_files))):
+    I_VIS = imread(VIS_files[i]).astype(np.float32).transpose(2, 0, 1)/255.  # [3, H, W] uint8 -> float32
+    I_VIS = rgb2y(I_VIS)  # [1, H, W] float32
+    I_IR = imread(IR_files[i]).astype(np.float32)[None, :, :]/255.  # [1, H, W] float32
+
+    # crop into patches
+    I_IR_Patch_Group = Im2Patch(I_IR, img_size, stride)
+    I_VIS_Patch_Group = Im2Patch(I_VIS, img_size, stride)  # [1, img_size, img_size, N]
+
+    for ii in range(I_IR_Patch_Group.shape[-1]):
+        bad_IR = is_low_contrast(I_IR_Patch_Group[0, :, :, ii])
+        bad_VIS = is_low_contrast(I_VIS_Patch_Group[0, :, :, ii])
+        # keep the patch pair only if neither modality is low contrast
+        if not (bad_IR or bad_VIS):
+            avl_IR = I_IR_Patch_Group[0, :, :, ii]   # available IR patch
+            avl_VIS = I_VIS_Patch_Group[0, :, :, ii]
+            avl_IR = avl_IR[None, ...]
+            avl_VIS = avl_VIS[None, ...]
+ + h5_ir.create_dataset(str(train_num), data=avl_IR, + dtype=avl_IR.dtype, shape=avl_IR.shape) + h5_vis.create_dataset(str(train_num), data=avl_VIS, + dtype=avl_VIS.dtype, shape=avl_VIS.shape) + train_num += 1 + +h5f.close() + +with h5py.File(os.path.join('data', + data_name+'_imgsize_'+str(img_size)+"_stride_"+str(stride)+'.h5'),"r") as f: + for key in f.keys(): + print(f[key], key, f[key].name) + + + + + + + diff --git a/net.py b/net.py new file mode 100644 index 0000000..a359de9 --- /dev/null +++ b/net.py @@ -0,0 +1,427 @@ +# poolformer +import torch +import torch.nn as nn +import math +import torch.nn.functional as F +import torch.utils.checkpoint as checkpoint +from timm.models.layers import DropPath, to_2tuple, trunc_normal_ +from einops import rearrange + + +def drop_path(x, drop_prob: float = 0., training: bool = False): + """ + Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). + This is the same as the DropConnect impl I created for EfficientNet, etc networks, however, + the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper... + See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for + changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use + 'survival rate' as the argument. + """ + if drop_prob == 0. or not training: + return x + keep_prob = 1 - drop_prob + # work with diff dim tensors, not just 2D ConvNets + shape = (x.shape[0],) + (1,) * (x.ndim - 1) + random_tensor = keep_prob + \ + torch.rand(shape, dtype=x.dtype, device=x.device) + random_tensor.floor_() # binarize + output = x.div(keep_prob) * random_tensor + return output + + +class DropPath(nn.Module): + """ + Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). + """ + + def __init__(self, drop_prob=None): + super(DropPath, self).__init__() + self.drop_prob = drop_prob + + def forward(self, x): + return drop_path(x, self.drop_prob, self.training) + + +class Pooling(nn.Module): + def __init__(self, kernel_size=3): + super().__init__() + self.pool = nn.AvgPool2d( + kernel_size, stride=1, padding=kernel_size // 2) + + def forward(self, x): + return self.pool(x) - x + + +class PoolMlp(nn.Module): + """ + Implementation of MLP with 1*1 convolutions. 
+ Input: tensor with shape [B, C, H, W] + """ + + def __init__(self, + in_features, + hidden_features=None, + out_features=None, + act_layer=nn.GELU, + bias=False, + drop=0.): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Conv2d(in_features, hidden_features, 1, bias=bias) + self.act = act_layer() + self.fc2 = nn.Conv2d(hidden_features, out_features, 1, bias=bias) + self.drop = nn.Dropout(drop) + # self.apply(self._init_weights) + + # def _init_weights(self, m): + # if isinstance(m, nn.Conv2D): + # trunc_normal_(m.weight) + # if m.bias is not None: + # zeros_(m.bias) + + def forward(self, x): + x = self.fc1(x) # (B, C, H, W) --> (B, C, H, W) + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) # (B, C, H, W) --> (B, C, H, W) + x = self.drop(x) + return x + + +class BaseFeatureExtraction(nn.Module): + def __init__(self, dim, pool_size=3, mlp_ratio=4., + act_layer=nn.GELU, + # norm_layer=nn.LayerNorm, + drop=0., drop_path=0., + use_layer_scale=True, layer_scale_init_value=1e-5): + + super().__init__() + + self.norm1 = LayerNorm(dim, 'WithBias') + self.token_mixer = Pooling(kernel_size=pool_size) # vits是msa,MLPs是mlp,这个用pool来替代 + self.norm2 = LayerNorm(dim, 'WithBias') + mlp_hidden_dim = int(dim * mlp_ratio) + self.poolmlp = PoolMlp(in_features=dim, hidden_features=mlp_hidden_dim, + act_layer=act_layer, drop=drop) + + # The following two techniques are useful to train deep PoolFormers. + self.drop_path = DropPath(drop_path) if drop_path > 0. \ + else nn.Identity() + self.use_layer_scale = use_layer_scale + + if use_layer_scale: + self.layer_scale_1 = nn.Parameter( + torch.ones(dim, dtype=torch.float32) * layer_scale_init_value) + + self.layer_scale_2 = nn.Parameter( + torch.ones(dim, dtype=torch.float32) * layer_scale_init_value) + + def forward(self, x): + if self.use_layer_scale: + x = x + self.drop_path( + self.layer_scale_1.unsqueeze(-1).unsqueeze(-1) + * self.token_mixer(self.norm1(x))) + x = x + self.drop_path( + self.layer_scale_2.unsqueeze(-1).unsqueeze(-1) + * self.poolmlp(self.norm2(x))) + else: + x = x + self.drop_path(self.token_mixer(self.norm1(x))) + x = x + self.drop_path(self.poolmlp(self.norm2(x))) + return x + + +class InvertedResidualBlock(nn.Module): + def __init__(self, inp, oup, expand_ratio): + super(InvertedResidualBlock, self).__init__() + hidden_dim = int(inp * expand_ratio) + self.bottleneckBlock = nn.Sequential( + # pw + nn.Conv2d(inp, hidden_dim, 1, bias=False), + # nn.BatchNorm2d(hidden_dim), + nn.ReLU6(inplace=True), + # dw + nn.ReflectionPad2d(1), + nn.Conv2d(hidden_dim, hidden_dim, 3, groups=hidden_dim, bias=False), + # nn.BatchNorm2d(hidden_dim), + nn.ReLU6(inplace=True), + # pw-linear + nn.Conv2d(hidden_dim, oup, 1, bias=False), + # nn.BatchNorm2d(oup), + ) + + def forward(self, x): + return self.bottleneckBlock(x) + + +class DetailNode(nn.Module): + def __init__(self): + super(DetailNode, self).__init__() + + self.theta_phi = InvertedResidualBlock(inp=32, oup=32, expand_ratio=2) + self.theta_rho = InvertedResidualBlock(inp=32, oup=32, expand_ratio=2) + self.theta_eta = InvertedResidualBlock(inp=32, oup=32, expand_ratio=2) + self.shffleconv = nn.Conv2d(64, 64, kernel_size=1, + stride=1, padding=0, bias=True) + + def separateFeature(self, x): + z1, z2 = x[:, :x.shape[1] // 2], x[:, x.shape[1] // 2:x.shape[1]] + return z1, z2 + + def forward(self, z1, z2): + z1, z2 = self.separateFeature( + self.shffleconv(torch.cat((z1, z2), dim=1))) + z2 = z2 + self.theta_phi(z1) + z1 = z1 
* torch.exp(self.theta_rho(z2)) + self.theta_eta(z2) + return z1, z2 + + +class DetailFeatureExtraction(nn.Module): + def __init__(self, num_layers=3): + super(DetailFeatureExtraction, self).__init__() + INNmodules = [DetailNode() for _ in range(num_layers)] + self.net = nn.Sequential(*INNmodules) + + def forward(self, x): + z1, z2 = x[:, :x.shape[1] // 2], x[:, x.shape[1] // 2:x.shape[1]] + for layer in self.net: + z1, z2 = layer(z1, z2) + return torch.cat((z1, z2), dim=1) + + +# ============================================================================= + +# ============================================================================= +import numbers + + +########################################################################## +## Layer Norm +def to_3d(x): + return rearrange(x, 'b c h w -> b (h w) c') + + +def to_4d(x, h, w): + return rearrange(x, 'b (h w) c -> b c h w', h=h, w=w) + + +class BiasFree_LayerNorm(nn.Module): + def __init__(self, normalized_shape): + super(BiasFree_LayerNorm, self).__init__() + if isinstance(normalized_shape, numbers.Integral): + normalized_shape = (normalized_shape,) + normalized_shape = torch.Size(normalized_shape) + + assert len(normalized_shape) == 1 + + self.weight = nn.Parameter(torch.ones(normalized_shape)) + self.normalized_shape = normalized_shape + + def forward(self, x): + sigma = x.var(-1, keepdim=True, unbiased=False) + return x / torch.sqrt(sigma + 1e-5) * self.weight + + +class WithBias_LayerNorm(nn.Module): + def __init__(self, normalized_shape): + super(WithBias_LayerNorm, self).__init__() + if isinstance(normalized_shape, numbers.Integral): + normalized_shape = (normalized_shape,) + normalized_shape = torch.Size(normalized_shape) + + assert len(normalized_shape) == 1 + + self.weight = nn.Parameter(torch.ones(normalized_shape)) + self.bias = nn.Parameter(torch.zeros(normalized_shape)) + self.normalized_shape = normalized_shape + + def forward(self, x): + mu = x.mean(-1, keepdim=True) + sigma = x.var(-1, keepdim=True, unbiased=False) + return (x - mu) / torch.sqrt(sigma + 1e-5) * self.weight + self.bias + + +class LayerNorm(nn.Module): + def __init__(self, dim, LayerNorm_type): + super(LayerNorm, self).__init__() + if LayerNorm_type == 'BiasFree': + self.body = BiasFree_LayerNorm(dim) + else: + self.body = WithBias_LayerNorm(dim) + + def forward(self, x): + h, w = x.shape[-2:] + return to_4d(self.body(to_3d(x)), h, w) + + +########################################################################## +## Gated-Dconv Feed-Forward Network (GDFN) +class FeedForward(nn.Module): + def __init__(self, dim, ffn_expansion_factor, bias): + super(FeedForward, self).__init__() + + hidden_features = int(dim * ffn_expansion_factor) + + self.project_in = nn.Conv2d( + dim, hidden_features * 2, kernel_size=1, bias=bias) + + self.dwconv = nn.Conv2d(hidden_features * 2, hidden_features * 2, kernel_size=3, + stride=1, padding=1, groups=hidden_features * 2, bias=bias) + + self.project_out = nn.Conv2d( + hidden_features, dim, kernel_size=1, bias=bias) + + def forward(self, x): + x = self.project_in(x) + x1, x2 = self.dwconv(x).chunk(2, dim=1) + x = F.gelu(x1) * x2 + x = self.project_out(x) + return x + + +########################################################################## +## Multi-DConv Head Transposed Self-Attention (MDTA) +class Attention(nn.Module): + def __init__(self, dim, num_heads, bias): + super(Attention, self).__init__() + self.num_heads = num_heads + self.temperature = nn.Parameter(torch.ones(num_heads, 1, 1)) + + self.qkv = nn.Conv2d(dim, dim * 
3, kernel_size=1, bias=bias) + self.qkv_dwconv = nn.Conv2d( + dim * 3, dim * 3, kernel_size=3, stride=1, padding=1, groups=dim * 3, bias=bias) + self.project_out = nn.Conv2d(dim, dim, kernel_size=1, bias=bias) + + def forward(self, x): + b, c, h, w = x.shape + + qkv = self.qkv_dwconv(self.qkv(x)) + q, k, v = qkv.chunk(3, dim=1) + + q = rearrange(q, 'b (head c) h w -> b head c (h w)', + head=self.num_heads) + k = rearrange(k, 'b (head c) h w -> b head c (h w)', + head=self.num_heads) + v = rearrange(v, 'b (head c) h w -> b head c (h w)', + head=self.num_heads) + + q = torch.nn.functional.normalize(q, dim=-1) + k = torch.nn.functional.normalize(k, dim=-1) + + attn = (q @ k.transpose(-2, -1)) * self.temperature + attn = attn.softmax(dim=-1) + + out = (attn @ v) + + out = rearrange(out, 'b head c (h w) -> b (head c) h w', + head=self.num_heads, h=h, w=w) + + out = self.project_out(out) + return out + + +########################################################################## +class TransformerBlock(nn.Module): + def __init__(self, dim, num_heads, ffn_expansion_factor, bias, LayerNorm_type): + super(TransformerBlock, self).__init__() + + self.norm1 = LayerNorm(dim, LayerNorm_type) + self.attn = Attention(dim, num_heads, bias) + self.norm2 = LayerNorm(dim, LayerNorm_type) + self.ffn = FeedForward(dim, ffn_expansion_factor, bias) + + def forward(self, x): + x = x + self.attn(self.norm1(x)) + x = x + self.ffn(self.norm2(x)) + + return x + + +########################################################################## +## Overlapped image patch embedding with 3x3 Conv +class OverlapPatchEmbed(nn.Module): + def __init__(self, in_c=3, embed_dim=48, bias=False): + super(OverlapPatchEmbed, self).__init__() + + self.proj = nn.Conv2d(in_c, embed_dim, kernel_size=3, + stride=1, padding=1, bias=bias) + + def forward(self, x): + x = self.proj(x) + return x + + +class Restormer_Encoder(nn.Module): + def __init__(self, + inp_channels=1, + out_channels=1, + dim=64, + num_blocks=[4, 4], + heads=[8, 8, 8], + ffn_expansion_factor=2, + bias=False, + LayerNorm_type='WithBias', + ): + super(Restormer_Encoder, self).__init__() + + self.patch_embed = OverlapPatchEmbed(inp_channels, dim) + + self.encoder_level1 = nn.Sequential( + *[TransformerBlock(dim=dim, num_heads=heads[0], ffn_expansion_factor=ffn_expansion_factor, + bias=bias, LayerNorm_type=LayerNorm_type) for i in range(num_blocks[0])]) + + self.baseFeature = BaseFeatureExtraction(dim=dim) + + self.detailFeature = DetailFeatureExtraction() + + def forward(self, inp_img): + inp_enc_level1 = self.patch_embed(inp_img) + out_enc_level1 = self.encoder_level1(inp_enc_level1) + base_feature = self.baseFeature(out_enc_level1) + detail_feature = self.detailFeature(out_enc_level1) + return base_feature, detail_feature, out_enc_level1 + + +class Restormer_Decoder(nn.Module): + def __init__(self, + inp_channels=1, + out_channels=1, + dim=64, + num_blocks=[4, 4], + heads=[8, 8, 8], + ffn_expansion_factor=2, + bias=False, + LayerNorm_type='WithBias', + ): + + super(Restormer_Decoder, self).__init__() + self.reduce_channel = nn.Conv2d(int(dim * 2), int(dim), kernel_size=1, bias=bias) + self.encoder_level2 = nn.Sequential( + *[TransformerBlock(dim=dim, num_heads=heads[1], ffn_expansion_factor=ffn_expansion_factor, + bias=bias, LayerNorm_type=LayerNorm_type) for i in range(num_blocks[1])]) + self.output = nn.Sequential( + nn.Conv2d(int(dim), int(dim) // 2, kernel_size=3, + stride=1, padding=1, bias=bias), + nn.LeakyReLU(), + nn.Conv2d(int(dim) // 2, out_channels, kernel_size=3, 
+ stride=1, padding=1, bias=bias), ) + self.sigmoid = nn.Sigmoid() + + def forward(self, inp_img, base_feature, detail_feature): + out_enc_level0 = torch.cat((base_feature, detail_feature), dim=1) + out_enc_level0 = self.reduce_channel(out_enc_level0) + out_enc_level1 = self.encoder_level2(out_enc_level0) + if inp_img is not None: + out_enc_level1 = self.output(out_enc_level1) + inp_img + else: + out_enc_level1 = self.output(out_enc_level1) + return self.sigmoid(out_enc_level1), out_enc_level0 + + +if __name__ == '__main__': + height = 128 + width = 128 + window_size = 8 + modelE = Restormer_Encoder().cuda() + modelD = Restormer_Decoder().cuda() + diff --git a/test_IVF.py b/test_IVF.py new file mode 100644 index 0000000..2bb4b7a --- /dev/null +++ b/test_IVF.py @@ -0,0 +1,95 @@ +import cv2 +from net import Restormer_Encoder, Restormer_Decoder, BaseFeatureExtraction, DetailFeatureExtraction +import os +import numpy as np +from utils.Evaluator import Evaluator +import torch +import torch.nn as nn +from utils.img_read_save import img_save,image_read_cv2 +import warnings +import logging +# 增加 +warnings.filterwarnings("ignore") +logging.basicConfig(level=logging.CRITICAL) + + +os.environ["CUDA_VISIBLE_DEVICES"] = "0" +ckpt_path= r"models/PFCFuse.pth" + +for dataset_name in ["MSRS","TNO","RoadScene"]: + print("\n"*2+"="*80) + model_name="PFCFuse " + print("The test result of "+dataset_name+' :') + test_folder=os.path.join('test_img',dataset_name) + test_out_folder=os.path.join('test_result',dataset_name) + + device = 'cuda' if torch.cuda.is_available() else 'cpu' + Encoder = nn.DataParallel(Restormer_Encoder()).to(device) + Decoder = nn.DataParallel(Restormer_Decoder()).to(device) + # BaseFuseLayer = nn.DataParallel(BaseFeatureExtraction(dim=64, num_heads=8)).to(device) + BaseFuseLayer = nn.DataParallel(BaseFeatureExtraction(dim=64)).to(device) + DetailFuseLayer = nn.DataParallel(DetailFeatureExtraction(num_layers=1)).to(device) + + Encoder.load_state_dict(torch.load(ckpt_path)['DIDF_Encoder'],strict=False) + Decoder.load_state_dict(torch.load(ckpt_path)['DIDF_Decoder']) + BaseFuseLayer.load_state_dict(torch.load(ckpt_path)['BaseFuseLayer']) + DetailFuseLayer.load_state_dict(torch.load(ckpt_path)['DetailFuseLayer']) + Encoder.eval() + Decoder.eval() + BaseFuseLayer.eval() + DetailFuseLayer.eval() + + with torch.no_grad(): + for img_name in os.listdir(os.path.join(test_folder,"ir")): + + data_IR=image_read_cv2(os.path.join(test_folder,"ir",img_name),mode='GRAY')[np.newaxis,np.newaxis, ...]/255.0 + # 改 + data_VIS = cv2.split(image_read_cv2(os.path.join(test_folder, "vi", img_name), mode='YCrCb'))[0][np.newaxis, np.newaxis, ...] 
/ 255.0 + # ycrcb, uint8 + data_VIS_BGR = cv2.imread(os.path.join(test_folder, "vi", img_name)) + _, data_VIS_Cr, data_VIS_Cb = cv2.split(cv2.cvtColor(data_VIS_BGR, cv2.COLOR_BGR2YCrCb)) + # 改 + + data_IR,data_VIS = torch.FloatTensor(data_IR),torch.FloatTensor(data_VIS) + data_VIS, data_IR = data_VIS.cuda(), data_IR.cuda() + + feature_V_B, feature_V_D, feature_V = Encoder(data_VIS) + feature_I_B, feature_I_D, feature_I = Encoder(data_IR) + feature_F_B = BaseFuseLayer(feature_V_B + feature_I_B) + feature_F_D = DetailFuseLayer(feature_V_D + feature_I_D) + data_Fuse, _ = Decoder(data_VIS, feature_F_B, feature_F_D) + data_Fuse=(data_Fuse-torch.min(data_Fuse))/(torch.max(data_Fuse)-torch.min(data_Fuse)) + fi = np.squeeze((data_Fuse * 255).cpu().numpy()) + # 改 + # float32 to uint8 + fi = fi.astype(np.uint8) + ycrcb_fi = np.dstack((fi, data_VIS_Cr, data_VIS_Cb)) + rgb_fi = cv2.cvtColor(ycrcb_fi, cv2.COLOR_YCrCb2RGB) + img_save(rgb_fi, img_name.split(sep='.')[0], test_out_folder) + # 改 + + eval_folder=test_out_folder + ori_img_folder=test_folder + + metric_result = np.zeros((8)) + for img_name in os.listdir(os.path.join(ori_img_folder,"ir")): + ir = image_read_cv2(os.path.join(ori_img_folder,"ir", img_name), 'GRAY') + vi = image_read_cv2(os.path.join(ori_img_folder,"vi", img_name), 'GRAY') + fi = image_read_cv2(os.path.join(eval_folder, img_name.split('.')[0]+".png"), 'GRAY') + metric_result += np.array([Evaluator.EN(fi), Evaluator.SD(fi) + , Evaluator.SF(fi), Evaluator.MI(fi, ir, vi) + , Evaluator.SCD(fi, ir, vi), Evaluator.VIFF(fi, ir, vi) + , Evaluator.Qabf(fi, ir, vi), Evaluator.SSIM(fi, ir, vi)]) + + metric_result /= len(os.listdir(eval_folder)) + print("\t\t EN\t SD\t SF\t MI\tSCD\tVIF\tQabf\tSSIM") + print(model_name+'\t'+str(np.round(metric_result[0], 2))+'\t' + +str(np.round(metric_result[1], 2))+'\t' + +str(np.round(metric_result[2], 2))+'\t' + +str(np.round(metric_result[3], 2))+'\t' + +str(np.round(metric_result[4], 2))+'\t' + +str(np.round(metric_result[5], 2))+'\t' + +str(np.round(metric_result[6], 2))+'\t' + +str(np.round(metric_result[7], 2)) + ) + print("="*80) \ No newline at end of file diff --git a/train.py b/train.py new file mode 100644 index 0000000..bfc9c50 --- /dev/null +++ b/train.py @@ -0,0 +1,239 @@ +# -*- coding: utf-8 -*- + +''' +------------------------------------------------------------------------------ +Import packages +------------------------------------------------------------------------------ +''' + +from net import Restormer_Encoder, Restormer_Decoder, BaseFeatureExtraction, DetailFeatureExtraction +from utils.dataset import H5Dataset +import os +os.environ['KMP_DUPLICATE_LIB_OK'] = 'True' +import sys +import time +import datetime +import torch +import torch.nn as nn +from torch.utils.data import DataLoader +from utils.loss import Fusionloss, cc, relative_diff_loss +import kornia + +print(torch.__version__) +print(torch.cuda.is_available()) + +''' +------------------------------------------------------------------------------ +Configure our network +------------------------------------------------------------------------------ +''' + + +os.environ['CUDA_VISIBLE_DEVICES'] = '0' +criteria_fusion = Fusionloss() +model_str = 'PFCFuse' + +# . Set the hyper-parameters for training +num_epochs = 120 # total epoch +epoch_gap = 40 # epoches of Phase I + +lr = 1e-4 +weight_decay = 0 +batch_size = 1 +GPU_number = os.environ['CUDA_VISIBLE_DEVICES'] +# Coefficients of the loss function +coeff_mse_loss_VF = 1. # alpha1 +coeff_mse_loss_IF = 1. + +coeff_rmi_loss_VF = 1. 
+coeff_rmi_loss_IF = 1. + +coeff_cos_loss_VF = 1. +coeff_cos_loss_IF = 1. +coeff_decomp = 2. # alpha2 and alpha4 +coeff_tv = 5. + +clip_grad_norm_value = 0.01 +optim_step = 20 +optim_gamma = 0.5 + + +# Model +device = 'cuda' if torch.cuda.is_available() else 'cpu' +DIDF_Encoder = nn.DataParallel(Restormer_Encoder()).to(device) +DIDF_Decoder = nn.DataParallel(Restormer_Decoder()).to(device) +# BaseFuseLayer = nn.DataParallel(BaseFeatureExtraction(dim=64, num_heads=8)).to(device) +BaseFuseLayer = nn.DataParallel(BaseFeatureExtraction(dim=64)).to(device) +DetailFuseLayer = nn.DataParallel(DetailFeatureExtraction(num_layers=1)).to(device) + +# optimizer, scheduler and loss function +optimizer1 = torch.optim.Adam( + DIDF_Encoder.parameters(), lr=lr, weight_decay=weight_decay) +optimizer2 = torch.optim.Adam( + DIDF_Decoder.parameters(), lr=lr, weight_decay=weight_decay) +optimizer3 = torch.optim.Adam( + BaseFuseLayer.parameters(), lr=lr, weight_decay=weight_decay) +optimizer4 = torch.optim.Adam( + DetailFuseLayer.parameters(), lr=lr, weight_decay=weight_decay) + +scheduler1 = torch.optim.lr_scheduler.StepLR(optimizer1, step_size=optim_step, gamma=optim_gamma) +scheduler2 = torch.optim.lr_scheduler.StepLR(optimizer2, step_size=optim_step, gamma=optim_gamma) +scheduler3 = torch.optim.lr_scheduler.StepLR(optimizer3, step_size=optim_step, gamma=optim_gamma) +scheduler4 = torch.optim.lr_scheduler.StepLR(optimizer4, step_size=optim_step, gamma=optim_gamma) + +MSELoss = nn.MSELoss() +L1Loss = nn.L1Loss() +Loss_ssim = kornia.losses.SSIM(11, reduction='mean') +HuberLoss = nn.HuberLoss() + +# data loader +trainloader = DataLoader(H5Dataset(r"data/MSRS_train_imgsize_128_stride_200.h5"), + batch_size=batch_size, + shuffle=True, + num_workers=0) + +loader = {'train': trainloader, } +timestamp = datetime.datetime.now().strftime("%m-%d-%H-%M") + +''' +------------------------------------------------------------------------------ +Train +------------------------------------------------------------------------------ +''' + +step = 0 +torch.backends.cudnn.benchmark = True +prev_time = time.time() + +start = time.time() +for epoch in range(num_epochs): + ''' train ''' + for i, (data_VIS, data_IR) in enumerate(loader['train']): + data_VIS, data_IR = data_VIS.cuda(), data_IR.cuda() + DIDF_Encoder.train() + DIDF_Decoder.train() + BaseFuseLayer.train() + DetailFuseLayer.train() + + DIDF_Encoder.zero_grad() + DIDF_Decoder.zero_grad() + BaseFuseLayer.zero_grad() + DetailFuseLayer.zero_grad() + + optimizer1.zero_grad() + optimizer2.zero_grad() + optimizer3.zero_grad() + optimizer4.zero_grad() + + if epoch < epoch_gap: #Phase I + feature_V_B, feature_V_D, _ = DIDF_Encoder(data_VIS) + feature_I_B, feature_I_D, _ = DIDF_Encoder(data_IR) + data_VIS_hat, _ = DIDF_Decoder(data_VIS, feature_V_B, feature_V_D) + data_IR_hat, _ = DIDF_Decoder(data_IR, feature_I_B, feature_I_D) + + cc_loss_B = cc(feature_V_B, feature_I_B) + cc_loss_D = cc(feature_V_D, feature_I_D) + + mse_loss_V = 5 * Loss_ssim(data_VIS, data_VIS_hat) + HuberLoss(data_VIS, data_VIS_hat) + mse_loss_I = 5 * Loss_ssim(data_IR, data_IR_hat) + HuberLoss(data_IR, data_IR_hat) + + # print("mse_loss_V", mse_loss_V) + # print("mse_loss_I", mse_loss_I) + + Gradient_loss = L1Loss(kornia.filters.SpatialGradient()(data_VIS), + kornia.filters.SpatialGradient()(data_VIS_hat)) + # print("Gradient_loss", Gradient_loss) + + loss_decomp = (cc_loss_D) ** 2/ (1.01 + cc_loss_B) + # print("loss_decomp", loss_decomp) + + + loss_rmi_v = relative_diff_loss(data_VIS, data_VIS_hat) + 
loss_rmi_i = relative_diff_loss(data_IR, data_IR_hat) + + loss = coeff_mse_loss_VF * mse_loss_V + coeff_mse_loss_IF * \ + mse_loss_I + coeff_decomp * loss_decomp + coeff_tv * Gradient_loss + \ + coeff_rmi_loss_IF * loss_rmi_i + coeff_rmi_loss_VF * loss_rmi_v + + loss.backward() + nn.utils.clip_grad_norm_( + DIDF_Encoder.parameters(), max_norm=clip_grad_norm_value, norm_type=2) + nn.utils.clip_grad_norm_( + DIDF_Decoder.parameters(), max_norm=clip_grad_norm_value, norm_type=2) + optimizer1.step() + optimizer2.step() + else: #Phase II + feature_V_B, feature_V_D, feature_V = DIDF_Encoder(data_VIS) + feature_I_B, feature_I_D, feature_I = DIDF_Encoder(data_IR) + feature_F_B = BaseFuseLayer(feature_I_B+feature_V_B) + feature_F_D = DetailFuseLayer(feature_I_D+feature_V_D) + data_Fuse, feature_F = DIDF_Decoder(data_VIS, feature_F_B, feature_F_D) + + + mse_loss_V = 5 * Loss_ssim(data_VIS, data_Fuse) + HuberLoss(data_VIS, data_Fuse) + mse_loss_I = 5 * Loss_ssim(data_IR, data_Fuse) + HuberLoss(data_IR, data_Fuse) + + cc_loss_B = cc(feature_V_B, feature_I_B) + cc_loss_D = cc(feature_V_D, feature_I_D) + loss_decomp = (cc_loss_D) ** 2 / (1.01 + cc_loss_B) + + fusionloss, _,_ = criteria_fusion(data_VIS, data_IR, data_Fuse) + + loss = fusionloss + coeff_decomp * loss_decomp + loss.backward() + nn.utils.clip_grad_norm_( + DIDF_Encoder.parameters(), max_norm=clip_grad_norm_value, norm_type=2) + nn.utils.clip_grad_norm_( + DIDF_Decoder.parameters(), max_norm=clip_grad_norm_value, norm_type=2) + nn.utils.clip_grad_norm_( + BaseFuseLayer.parameters(), max_norm=clip_grad_norm_value, norm_type=2) + nn.utils.clip_grad_norm_( + DetailFuseLayer.parameters(), max_norm=clip_grad_norm_value, norm_type=2) + optimizer1.step() + optimizer2.step() + optimizer3.step() + optimizer4.step() + + # Determine approximate time left + batches_done = epoch * len(loader['train']) + i + batches_left = num_epochs * len(loader['train']) - batches_done + time_left = datetime.timedelta(seconds=batches_left * (time.time() - prev_time)) + epoch_time = time.time() - prev_time + prev_time = time.time() + sys.stdout.write( + "\r[Epoch %d/%d] [Batch %d/%d] [loss: %f]" + % ( + epoch, + num_epochs, + i, + len(loader['train']), + loss.item(), + ) + ) + + # adjust the learning rate + + scheduler1.step() + scheduler2.step() + if not epoch < epoch_gap: + scheduler3.step() + scheduler4.step() + + if optimizer1.param_groups[0]['lr'] <= 1e-6: + optimizer1.param_groups[0]['lr'] = 1e-6 + if optimizer2.param_groups[0]['lr'] <= 1e-6: + optimizer2.param_groups[0]['lr'] = 1e-6 + if optimizer3.param_groups[0]['lr'] <= 1e-6: + optimizer3.param_groups[0]['lr'] = 1e-6 + if optimizer4.param_groups[0]['lr'] <= 1e-6: + optimizer4.param_groups[0]['lr'] = 1e-6 + +if True: + checkpoint = { + 'DIDF_Encoder': DIDF_Encoder.state_dict(), + 'DIDF_Decoder': DIDF_Decoder.state_dict(), + 'BaseFuseLayer': BaseFuseLayer.state_dict(), + 'DetailFuseLayer': DetailFuseLayer.state_dict(), + } + torch.save(checkpoint, os.path.join("models/PFCFusion"+timestamp+'.pth')) + diff --git a/utils/Evaluator.py b/utils/Evaluator.py new file mode 100644 index 0000000..dc729ee --- /dev/null +++ b/utils/Evaluator.py @@ -0,0 +1,305 @@ + +import numpy as np +import cv2 +import sklearn.metrics as skm +from scipy.signal import convolve2d +import math +from skimage.metrics import structural_similarity as ssim + +def image_read_cv2(path, mode='RGB'): + img_BGR = cv2.imread(path).astype('float32') + assert mode == 'RGB' or mode == 'GRAY' or mode == 'YCrCb', 'mode error' + if mode == 'RGB': + img = 
cv2.cvtColor(img_BGR, cv2.COLOR_BGR2RGB) + elif mode == 'GRAY': + img = np.round(cv2.cvtColor(img_BGR, cv2.COLOR_BGR2GRAY)) + elif mode == 'YCrCb': + img = cv2.cvtColor(img_BGR, cv2.COLOR_BGR2YCrCb) + return img + +class Evaluator(): + @classmethod + def input_check(cls, imgF, imgA=None, imgB=None): + if imgA is None: + assert type(imgF) == np.ndarray, 'type error' + assert len(imgF.shape) == 2, 'dimension error' + else: + assert type(imgF) == type(imgA) == type(imgB) == np.ndarray, 'type error' + assert imgF.shape == imgA.shape == imgB.shape, 'shape error' + assert len(imgF.shape) == 2, 'dimension error' + + @classmethod + def EN(cls, img): # entropy + cls.input_check(img) + a = np.uint8(np.round(img)).flatten() + h = np.bincount(a) / a.shape[0] + return -sum(h * np.log2(h + (h == 0))) + + @classmethod + def SD(cls, img): + cls.input_check(img) + return np.std(img) + + @classmethod + def SF(cls, img): + cls.input_check(img) + return np.sqrt(np.mean((img[:, 1:] - img[:, :-1]) ** 2) + np.mean((img[1:, :] - img[:-1, :]) ** 2)) + + @classmethod + def AG(cls, img): # Average gradient + cls.input_check(img) + Gx, Gy = np.zeros_like(img), np.zeros_like(img) + + Gx[:, 0] = img[:, 1] - img[:, 0] + Gx[:, -1] = img[:, -1] - img[:, -2] + Gx[:, 1:-1] = (img[:, 2:] - img[:, :-2]) / 2 + + Gy[0, :] = img[1, :] - img[0, :] + Gy[-1, :] = img[-1, :] - img[-2, :] + Gy[1:-1, :] = (img[2:, :] - img[:-2, :]) / 2 + return np.mean(np.sqrt((Gx ** 2 + Gy ** 2) / 2)) + + @classmethod + def MI(cls, image_F, image_A, image_B): + cls.input_check(image_F, image_A, image_B) + return skm.mutual_info_score(image_F.flatten(), image_A.flatten()) + skm.mutual_info_score(image_F.flatten(), + image_B.flatten()) + + @classmethod + def MSE(cls, image_F, image_A, image_B): # MSE + cls.input_check(image_F, image_A, image_B) + return (np.mean((image_A - image_F) ** 2) + np.mean((image_B - image_F) ** 2)) / 2 + + @classmethod + def CC(cls, image_F, image_A, image_B): + cls.input_check(image_F, image_A, image_B) + rAF = np.sum((image_A - np.mean(image_A)) * (image_F - np.mean(image_F))) / np.sqrt( + (np.sum((image_A - np.mean(image_A)) ** 2)) * (np.sum((image_F - np.mean(image_F)) ** 2))) + rBF = np.sum((image_B - np.mean(image_B)) * (image_F - np.mean(image_F))) / np.sqrt( + (np.sum((image_B - np.mean(image_B)) ** 2)) * (np.sum((image_F - np.mean(image_F)) ** 2))) + return (rAF + rBF) / 2 + + @classmethod + def PSNR(cls, image_F, image_A, image_B): + cls.input_check(image_F, image_A, image_B) + return 10 * np.log10(np.max(image_F) ** 2 / cls.MSE(image_F, image_A, image_B)) + + @classmethod + def SCD(cls, image_F, image_A, image_B): # The sum of the correlations of differences + cls.input_check(image_F, image_A, image_B) + imgF_A = image_F - image_A + imgF_B = image_F - image_B + corr1 = np.sum((image_A - np.mean(image_A)) * (imgF_B - np.mean(imgF_B))) / np.sqrt( + (np.sum((image_A - np.mean(image_A)) ** 2)) * (np.sum((imgF_B - np.mean(imgF_B)) ** 2))) + corr2 = np.sum((image_B - np.mean(image_B)) * (imgF_A - np.mean(imgF_A))) / np.sqrt( + (np.sum((image_B - np.mean(image_B)) ** 2)) * (np.sum((imgF_A - np.mean(imgF_A)) ** 2))) + return corr1 + corr2 + + @classmethod + def VIFF(cls, image_F, image_A, image_B): + cls.input_check(image_F, image_A, image_B) + return cls.compare_viff(image_A, image_F)+cls.compare_viff(image_B, image_F) + + @classmethod + def compare_viff(cls,ref, dist): # viff of a pair of pictures + sigma_nsq = 2 + eps = 1e-10 + + num = 0.0 + den = 0.0 + for scale in range(1, 5): + + N = 2 ** (4 - scale + 1) + 1 + sd = N 
/ 5.0 + + # Create a Gaussian kernel as MATLAB's + m, n = [(ss - 1.) / 2. for ss in (N, N)] + y, x = np.ogrid[-m:m + 1, -n:n + 1] + h = np.exp(-(x * x + y * y) / (2. * sd * sd)) + h[h < np.finfo(h.dtype).eps * h.max()] = 0 + sumh = h.sum() + if sumh != 0: + win = h / sumh + + if scale > 1: + ref = convolve2d(ref, np.rot90(win, 2), mode='valid') + dist = convolve2d(dist, np.rot90(win, 2), mode='valid') + ref = ref[::2, ::2] + dist = dist[::2, ::2] + + mu1 = convolve2d(ref, np.rot90(win, 2), mode='valid') + mu2 = convolve2d(dist, np.rot90(win, 2), mode='valid') + mu1_sq = mu1 * mu1 + mu2_sq = mu2 * mu2 + mu1_mu2 = mu1 * mu2 + sigma1_sq = convolve2d(ref * ref, np.rot90(win, 2), mode='valid') - mu1_sq + sigma2_sq = convolve2d(dist * dist, np.rot90(win, 2), mode='valid') - mu2_sq + sigma12 = convolve2d(ref * dist, np.rot90(win, 2), mode='valid') - mu1_mu2 + + sigma1_sq[sigma1_sq < 0] = 0 + sigma2_sq[sigma2_sq < 0] = 0 + + g = sigma12 / (sigma1_sq + eps) + sv_sq = sigma2_sq - g * sigma12 + + g[sigma1_sq < eps] = 0 + sv_sq[sigma1_sq < eps] = sigma2_sq[sigma1_sq < eps] + sigma1_sq[sigma1_sq < eps] = 0 + + g[sigma2_sq < eps] = 0 + sv_sq[sigma2_sq < eps] = 0 + + sv_sq[g < 0] = sigma2_sq[g < 0] + g[g < 0] = 0 + sv_sq[sv_sq <= eps] = eps + + num += np.sum(np.log10(1 + g * g * sigma1_sq / (sv_sq + sigma_nsq))) + den += np.sum(np.log10(1 + sigma1_sq / sigma_nsq)) + + vifp = num / den + + if np.isnan(vifp): + return 1.0 + else: + return vifp + + @classmethod + def Qabf(cls, image_F, image_A, image_B): + cls.input_check(image_F, image_A, image_B) + gA, aA = cls.Qabf_getArray(image_A) + gB, aB = cls.Qabf_getArray(image_B) + gF, aF = cls.Qabf_getArray(image_F) + QAF = cls.Qabf_getQabf(aA, gA, aF, gF) + QBF = cls.Qabf_getQabf(aB, gB, aF, gF) + + # 计算QABF + deno = np.sum(gA + gB) + nume = np.sum(np.multiply(QAF, gA) + np.multiply(QBF, gB)) + return nume / deno + + @classmethod + def Qabf_getArray(cls,img): + # Sobel Operator Sobel + h1 = np.array([[1, 2, 1], [0, 0, 0], [-1, -2, -1]]).astype(np.float32) + h2 = np.array([[0, 1, 2], [-1, 0, 1], [-2, -1, 0]]).astype(np.float32) + h3 = np.array([[-1, 0, 1], [-2, 0, 2], [-1, 0, 1]]).astype(np.float32) + + SAx = convolve2d(img, h3, mode='same') + SAy = convolve2d(img, h1, mode='same') + gA = np.sqrt(np.multiply(SAx, SAx) + np.multiply(SAy, SAy)) + aA = np.zeros_like(img) + aA[SAx == 0] = math.pi / 2 + aA[SAx != 0]=np.arctan(SAy[SAx != 0] / SAx[SAx != 0]) + return gA, aA + + @classmethod + def Qabf_getQabf(cls,aA, gA, aF, gF): + L = 1 + Tg = 0.9994 + kg = -15 + Dg = 0.5 + Ta = 0.9879 + ka = -22 + Da = 0.8 + GAF,AAF,QgAF,QaAF,QAF = np.zeros_like(aA),np.zeros_like(aA),np.zeros_like(aA),np.zeros_like(aA),np.zeros_like(aA) + GAF[gA>gF]=gF[gA>gF]/gA[gA>gF] + GAF[gA == gF] = gF[gA == gF] + GAF[gA 1: + refA = convolve2d(refA, np.rot90(win, 2), mode='valid') + refB = convolve2d(refB, np.rot90(win, 2), mode='valid') + dist = convolve2d(dist, np.rot90(win, 2), mode='valid') + refA = refA[::2, ::2] + refB = refB[::2, ::2] + dist = dist[::2, ::2] + + mu1A = convolve2d(refA, np.rot90(win, 2), mode='valid') + mu1B = convolve2d(refB, np.rot90(win, 2), mode='valid') + mu2 = convolve2d(dist, np.rot90(win, 2), mode='valid') + mu1_sq_A = mu1A * mu1A + mu1_sq_B = mu1B * mu1B + mu2_sq = mu2 * mu2 + mu1A_mu2 = mu1A * mu2 + mu1B_mu2 = mu1B * mu2 + sigma1A_sq = convolve2d(refA * refA, np.rot90(win, 2), mode='valid') - mu1_sq_A + sigma1B_sq = convolve2d(refB * refB, np.rot90(win, 2), mode='valid') - mu1_sq_B + sigma2_sq = convolve2d(dist * dist, np.rot90(win, 2), mode='valid') - mu2_sq + 
sigma12_A = convolve2d(refA * dist, np.rot90(win, 2), mode='valid') - mu1A_mu2 + sigma12_B = convolve2d(refB * dist, np.rot90(win, 2), mode='valid') - mu1B_mu2 + + sigma1A_sq[sigma1A_sq < 0] = 0 + sigma1B_sq[sigma1B_sq < 0] = 0 + sigma2_sq[sigma2_sq < 0] = 0 + + gA = sigma12_A / (sigma1A_sq + eps) + gB = sigma12_B / (sigma1B_sq + eps) + sv_sq_A = sigma2_sq - gA * sigma12_A + sv_sq_B = sigma2_sq - gB * sigma12_B + + gA[sigma1A_sq < eps] = 0 + gB[sigma1B_sq < eps] = 0 + sv_sq_A[sigma1A_sq < eps] = sigma2_sq[sigma1A_sq < eps] + sv_sq_B[sigma1B_sq < eps] = sigma2_sq[sigma1B_sq < eps] + sigma1A_sq[sigma1A_sq < eps] = 0 + sigma1B_sq[sigma1B_sq < eps] = 0 + + gA[sigma2_sq < eps] = 0 + gB[sigma2_sq < eps] = 0 + sv_sq_A[sigma2_sq < eps] = 0 + sv_sq_B[sigma2_sq < eps] = 0 + + sv_sq_A[gA < 0] = sigma2_sq[gA < 0] + sv_sq_B[gB < 0] = sigma2_sq[gB < 0] + gA[gA < 0] = 0 + gB[gB < 0] = 0 + sv_sq_A[sv_sq_A <= eps] = eps + sv_sq_B[sv_sq_B <= eps] = eps + + numA += np.sum(np.log10(1 + gA * gA * sigma1A_sq / (sv_sq_A + sigma_nsq))) + numB += np.sum(np.log10(1 + gB * gB * sigma1B_sq / (sv_sq_B + sigma_nsq))) + denA += np.sum(np.log10(1 + sigma1A_sq / sigma_nsq)) + denB += np.sum(np.log10(1 + sigma1B_sq / sigma_nsq)) + + vifpA = numA / denA + vifpB =numB / denB + + if np.isnan(vifpA): + vifpA=1 + if np.isnan(vifpB): + vifpB = 1 + return vifpA+vifpB diff --git a/utils/__pycache__/Evaluator.cpython-38.pyc b/utils/__pycache__/Evaluator.cpython-38.pyc new file mode 100644 index 0000000..a61cd27 Binary files /dev/null and b/utils/__pycache__/Evaluator.cpython-38.pyc differ diff --git a/utils/__pycache__/dataset.cpython-38.pyc b/utils/__pycache__/dataset.cpython-38.pyc new file mode 100644 index 0000000..e927550 Binary files /dev/null and b/utils/__pycache__/dataset.cpython-38.pyc differ diff --git a/utils/__pycache__/dataset_MIF1.cpython-38.pyc b/utils/__pycache__/dataset_MIF1.cpython-38.pyc new file mode 100644 index 0000000..6804b69 Binary files /dev/null and b/utils/__pycache__/dataset_MIF1.cpython-38.pyc differ diff --git a/utils/__pycache__/dataset_MIF4.cpython-38.pyc b/utils/__pycache__/dataset_MIF4.cpython-38.pyc new file mode 100644 index 0000000..f518693 Binary files /dev/null and b/utils/__pycache__/dataset_MIF4.cpython-38.pyc differ diff --git a/utils/__pycache__/img_read_save.cpython-38.pyc b/utils/__pycache__/img_read_save.cpython-38.pyc new file mode 100644 index 0000000..1fbcaee Binary files /dev/null and b/utils/__pycache__/img_read_save.cpython-38.pyc differ diff --git a/utils/__pycache__/loss.cpython-38.pyc b/utils/__pycache__/loss.cpython-38.pyc new file mode 100644 index 0000000..716f1da Binary files /dev/null and b/utils/__pycache__/loss.cpython-38.pyc differ diff --git a/utils/dataset.py b/utils/dataset.py new file mode 100644 index 0000000..666175f --- /dev/null +++ b/utils/dataset.py @@ -0,0 +1,22 @@ +import torch.utils.data as Data +import h5py +import numpy as np +import torch + +class H5Dataset(Data.Dataset): + def __init__(self, h5file_path): + self.h5file_path = h5file_path + h5f = h5py.File(h5file_path, 'r') + self.keys = list(h5f['ir_patchs'].keys()) + h5f.close() + + def __len__(self): + return len(self.keys) + + def __getitem__(self, index): + h5f = h5py.File(self.h5file_path, 'r') + key = self.keys[index] + IR = np.array(h5f['ir_patchs'][key]) + VIS = np.array(h5f['vis_patchs'][key]) + h5f.close() + return torch.Tensor(VIS), torch.Tensor(IR) \ No newline at end of file diff --git a/utils/img_read_save.py b/utils/img_read_save.py new file mode 100644 index 0000000..a68a15f --- 
/dev/null +++ b/utils/img_read_save.py @@ -0,0 +1,28 @@ +import numpy as np +import cv2 +import os +from skimage.io import imsave + +def image_read_cv2(path, mode='RGB'): + img_BGR = cv2.imread(path).astype('float32') + # img_BGR = cv2.imread(path) + # print(img_BGR) + # if img_BGR is not None: + # img_BGR = img_BGR.astype('float32') + # else: + # print("处理图像加载失败的情况") + assert mode == 'RGB' or mode == 'GRAY' or mode == 'YCrCb', 'mode error' + if mode == 'RGB': + img = cv2.cvtColor(img_BGR, cv2.COLOR_BGR2RGB) + elif mode == 'GRAY': + img = np.round(cv2.cvtColor(img_BGR, cv2.COLOR_BGR2GRAY)) + elif mode == 'YCrCb': + img = cv2.cvtColor(img_BGR, cv2.COLOR_BGR2YCrCb) + return img + +def img_save(image,imagename,savepath): + if not os.path.exists(savepath): + os.makedirs(savepath) + # Gray_pic + imsave(os.path.join(savepath, "{}.png".format(imagename)),image.astype(np.uint8)) + diff --git a/utils/loss.py b/utils/loss.py new file mode 100644 index 0000000..1776e45 --- /dev/null +++ b/utils/loss.py @@ -0,0 +1,113 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + +class Fusionloss(nn.Module): + def __init__(self): + super(Fusionloss, self).__init__() + self.sobelconv=Sobelxy() + + def forward(self,image_vis,image_ir,generate_img): + # 加 + loss_rmi_v=relative_diff_loss(image_vis, generate_img) + loss_rmi_i=relative_diff_loss(image_ir, generate_img) + x_rmi_max=torch.max(loss_rmi_v, loss_rmi_i) + loss_rmi=F.l1_loss(x_rmi_max, generate_img) + # 加 + image_y=image_vis[:,:1,:,:] + x_in_max=torch.max(image_y,image_ir) + loss_in=F.l1_loss(x_in_max,generate_img) + y_grad=self.sobelconv(image_y) + ir_grad=self.sobelconv(image_ir) + generate_img_grad=self.sobelconv(generate_img) + x_grad_joint=torch.max(y_grad,ir_grad) + loss_grad=F.l1_loss(x_grad_joint,generate_img_grad) + # loss_total=loss_in+10*loss_grad + #改 + loss_total = loss_in + 10 * loss_grad + loss_rmi + return loss_total,loss_in,loss_grad + +class Sobelxy(nn.Module): + def __init__(self): + super(Sobelxy, self).__init__() + kernelx = [[-1, 0, 1], + [-2,0 , 2], + [-1, 0, 1]] + kernely = [[1, 2, 1], + [0,0 , 0], + [-1, -2, -1]] + kernelx = torch.FloatTensor(kernelx).unsqueeze(0).unsqueeze(0) + kernely = torch.FloatTensor(kernely).unsqueeze(0).unsqueeze(0) + self.weightx = nn.Parameter(data=kernelx, requires_grad=False).cuda() + self.weighty = nn.Parameter(data=kernely, requires_grad=False).cuda() + def forward(self,x): + sobelx=F.conv2d(x, self.weightx, padding=1) + sobely=F.conv2d(x, self.weighty, padding=1) + return torch.abs(sobelx)+torch.abs(sobely) + + +def cc(img1, img2): + eps = torch.finfo(torch.float32).eps + """Correlation coefficient for (N, C, H, W) image; torch.float32 [0.,1.].""" + N, C, _, _ = img1.shape + img1 = img1.reshape(N, C, -1) + img2 = img2.reshape(N, C, -1) + img1 = img1 - img1.mean(dim=-1, keepdim=True) + img2 = img2 - img2.mean(dim=-1, keepdim=True) + cc = torch.sum(img1 * img2, dim=-1) / (eps + torch.sqrt(torch.sum(img1 ** + 2, dim=-1)) * torch.sqrt(torch.sum(img2**2, dim=-1))) + cc = torch.clamp(cc, -1., 1.) + return cc.mean() + + +# def dice_coeff(img1, img2): +# smooth = 1. +# num = img1.size(0) +# m1 = img1.view(num, -1) # Flatten +# m2 = img2.view(num, -1) # Flatten +# intersection = (m1 * m2).sum() +# +# return 1 - (2. 
* intersection + smooth) / (m1.sum() + m2.sum() + smooth)
+
+# Measures the average grey-level (mean intensity) difference between two images
+def relative_diff_loss(img1, img2):
+    # mean intensity of each image
+    mean_intensity_img1 = torch.mean(img1)
+    mean_intensity_img2 = torch.mean(img2)
+    # print("mean_intensity_img1")
+    # print(mean_intensity_img1)
+    # print("mean_intensity_img2")
+    # print(mean_intensity_img2)
+
+    # relative difference of the mean intensities
+    epsilon = 1e-10  # avoid division by zero
+    relative_diff = abs((mean_intensity_img1 - mean_intensity_img2) / (mean_intensity_img1 + epsilon))
+
+    return relative_diff
+
+# Mutual information (MI) loss
+# def mutual_information_loss(img1, img2):
+#     # entropy of X and of Y
+#     entropy_img1 = -torch.mean(torch.sum(F.softmax(img1, dim=-1) * F.log_softmax(img1, dim=-1), dim=-1))
+#     entropy_img2 = -torch.mean(torch.sum(F.softmax(img2, dim=-1) * F.log_softmax(img2, dim=-1), dim=-1))
+#
+#     # joint entropy of X and Y
+#     joint_entropy = -torch.mean(torch.sum(F.softmax(img1, dim=-1) * F.log_softmax(img2, dim=-1), dim=-1))
+#
+#     # mutual information
+#     mutual_information = entropy_img1 + entropy_img2 - joint_entropy
+#
+#     return mutual_information
+
+
+# # Cosine similarity loss
+# def cosine_similarity(img1, img2):
+#     # Flatten the tensors
+#     img1_flat = img1.view(-1)
+#     img2_flat = img2.view(-1)
+#
+#     # Calculate cosine similarity
+#     similarity = F.cosine_similarity(img1_flat, img2_flat, dim=0)
+#
+#     loss = torch.abs(similarity - 1)
+#     return loss.item()  # Convert to Python float
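+
+# Illustrative sketch (not used in training): relative_diff_loss only compares the
+# global mean intensities of the two inputs, ignoring spatial structure. For example:
+# a = torch.full((1, 1, 4, 4), 0.50)
+# b = torch.full((1, 1, 4, 4), 0.25)
+# relative_diff_loss(a, b)  # -> tensor(0.5000), since |0.50 - 0.25| / 0.50 = 0.5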