Add files via upload
This commit is contained in:
parent e6852193de
commit da5da74611
93 dataprocessing.py Normal file
@@ -0,0 +1,93 @@
import os
import h5py
import numpy as np
from tqdm import tqdm
from skimage.io import imread


def get_img_file(file_name):
    imagelist = []
    for parent, dirnames, filenames in os.walk(file_name):
        for filename in filenames:
            if filename.lower().endswith(('.bmp', '.dib', '.png', '.jpg', '.jpeg', '.pbm', '.pgm', '.ppm', '.tif', '.tiff', '.npy')):
                imagelist.append(os.path.join(parent, filename))
    return imagelist


def rgb2y(img):
    # ITU-R BT.601 luma transform; expects a [3, H, W] float array
    y = img[0:1, :, :] * 0.299000 + img[1:2, :, :] * 0.587000 + img[2:3, :, :] * 0.114000
    return y


def Im2Patch(img, win, stride=1):
    k = 0
    endc = img.shape[0]
    endw = img.shape[1]
    endh = img.shape[2]
    patch = img[:, 0:endw - win + 1:stride, 0:endh - win + 1:stride]
    TotalPatNum = patch.shape[1] * patch.shape[2]
    Y = np.zeros([endc, win * win, TotalPatNum], np.float32)
    for i in range(win):
        for j in range(win):
            patch = img[:, i:endw - win + i + 1:stride, j:endh - win + j + 1:stride]
            Y[:, k, :] = np.array(patch[:]).reshape(endc, TotalPatNum)
            k = k + 1
    return Y.reshape([endc, win, win, TotalPatNum])


def is_low_contrast(image, fraction_threshold=0.1, lower_percentile=10,
                    upper_percentile=90):
    """Determine if an image is low contrast."""
    limits = np.percentile(image, [lower_percentile, upper_percentile])
    ratio = (limits[1] - limits[0]) / limits[1]
    return ratio < fraction_threshold


data_name = "MSRS_train"
img_size = 128  # patch size
stride = 200    # patch stride

IR_files = sorted(get_img_file(r"MSRS_train/MSRS-main/train/ir"))
VIS_files = sorted(get_img_file(r"MSRS_train/MSRS-main/train/vi"))

assert len(IR_files) == len(VIS_files)
h5f = h5py.File(os.path.join('data',  # was '.\\data'; os.path.join keeps the path portable
                             data_name + '_imgsize_' + str(img_size) + "_stride_" + str(stride) + '.h5'),
                'w')
h5_ir = h5f.create_group('ir_patchs')
h5_vis = h5f.create_group('vis_patchs')
train_num = 0
for i in tqdm(range(len(IR_files))):
    I_VIS = imread(VIS_files[i]).astype(np.float32).transpose(2, 0, 1) / 255.  # [3, H, W] uint8 -> float32
    I_VIS = rgb2y(I_VIS)  # [1, H, W] float32
    I_IR = imread(IR_files[i]).astype(np.float32)[None, :, :] / 255.  # [1, H, W] float32

    # crop into patches
    I_IR_Patch_Group = Im2Patch(I_IR, img_size, stride)
    I_VIS_Patch_Group = Im2Patch(I_VIS, img_size, stride)  # [1, img_size, img_size, N]

    for ii in range(I_IR_Patch_Group.shape[-1]):
        bad_IR = is_low_contrast(I_IR_Patch_Group[0, :, :, ii])
        bad_VIS = is_low_contrast(I_VIS_Patch_Group[0, :, :, ii])
        # keep the pair only if neither patch is low contrast
        if not (bad_IR or bad_VIS):
            avl_IR = I_IR_Patch_Group[0, :, :, ii]  # available IR
            avl_VIS = I_VIS_Patch_Group[0, :, :, ii]
            avl_IR = avl_IR[None, ...]
            avl_VIS = avl_VIS[None, ...]

            h5_ir.create_dataset(str(train_num), data=avl_IR,
                                 dtype=avl_IR.dtype, shape=avl_IR.shape)
            h5_vis.create_dataset(str(train_num), data=avl_VIS,
                                  dtype=avl_VIS.dtype, shape=avl_VIS.shape)
            train_num += 1

h5f.close()

with h5py.File(os.path.join('data',
                            data_name + '_imgsize_' + str(img_size) + "_stride_" + str(stride) + '.h5'), "r") as f:
    for key in f.keys():
        print(f[key], key, f[key].name)
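A quick shape check on Im2Patch can catch window/stride mistakes before building the full H5 file. This is a minimal sketch under assumed dummy dimensions; the array and names here are illustrative, not part of the repository:

# illustrative sanity check for Im2Patch (not part of the repo)
import numpy as np

dummy = np.random.rand(1, 480, 640).astype(np.float32)  # [C, H, W]
patches = Im2Patch(dummy, win=128, stride=200)
# number of window positions along each axis
n_h = (480 - 128) // 200 + 1
n_w = (640 - 128) // 200 + 1
assert patches.shape == (1, 128, 128, n_h * n_w)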
427 net.py Normal file
@@ -0,0 +1,427 @@
# poolformer
import torch
import torch.nn as nn
import math
import torch.nn.functional as F
import torch.utils.checkpoint as checkpoint
from timm.models.layers import DropPath, to_2tuple, trunc_normal_
from einops import rearrange


def drop_path(x, drop_prob: float = 0., training: bool = False):
    """
    Drop paths (Stochastic Depth) per sample (when applied in the main path of residual blocks).
    This is the same as the DropConnect impl I created for EfficientNet, etc. networks; however,
    the original name is misleading, as 'Drop Connect' is a different form of dropout in a separate paper...
    See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for
    changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use
    'survival rate' as the argument.
    """
    if drop_prob == 0. or not training:
        return x
    keep_prob = 1 - drop_prob
    # work with tensors of any dimensionality, not just 2D ConvNets
    shape = (x.shape[0],) + (1,) * (x.ndim - 1)
    random_tensor = keep_prob + \
        torch.rand(shape, dtype=x.dtype, device=x.device)
    random_tensor.floor_()  # binarize
    output = x.div(keep_prob) * random_tensor
    return output


class DropPath(nn.Module):  # note: intentionally shadows the timm DropPath imported above
    """
    Drop paths (Stochastic Depth) per sample (when applied in the main path of residual blocks).
    """

    def __init__(self, drop_prob=None):
        super(DropPath, self).__init__()
        self.drop_prob = drop_prob

    def forward(self, x):
        return drop_path(x, self.drop_prob, self.training)


class Pooling(nn.Module):
    def __init__(self, kernel_size=3):
        super().__init__()
        self.pool = nn.AvgPool2d(
            kernel_size, stride=1, padding=kernel_size // 2)

    def forward(self, x):
        # subtracting the input makes this a residual-style token mixer
        return self.pool(x) - x


class PoolMlp(nn.Module):
    """
    Implementation of an MLP with 1x1 convolutions.
    Input: tensor with shape [B, C, H, W]
    """

    def __init__(self,
                 in_features,
                 hidden_features=None,
                 out_features=None,
                 act_layer=nn.GELU,
                 bias=False,
                 drop=0.):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = nn.Conv2d(in_features, hidden_features, 1, bias=bias)
        self.act = act_layer()
        self.fc2 = nn.Conv2d(hidden_features, out_features, 1, bias=bias)
        self.drop = nn.Dropout(drop)
        # self.apply(self._init_weights)

    # def _init_weights(self, m):
    #     if isinstance(m, nn.Conv2d):
    #         trunc_normal_(m.weight)
    #         if m.bias is not None:
    #             zeros_(m.bias)

    def forward(self, x):
        x = self.fc1(x)  # (B, C, H, W) --> (B, C, H, W)
        x = self.act(x)
        x = self.drop(x)
        x = self.fc2(x)  # (B, C, H, W) --> (B, C, H, W)
        x = self.drop(x)
        return x


class BaseFeatureExtraction(nn.Module):
    def __init__(self, dim, pool_size=3, mlp_ratio=4.,
                 act_layer=nn.GELU,
                 # norm_layer=nn.LayerNorm,
                 drop=0., drop_path=0.,
                 use_layer_scale=True, layer_scale_init_value=1e-5):

        super().__init__()

        self.norm1 = LayerNorm(dim, 'WithBias')
        # ViTs use MSA and MLP-Mixers use MLPs as the token mixer; here pooling replaces it
        self.token_mixer = Pooling(kernel_size=pool_size)
        self.norm2 = LayerNorm(dim, 'WithBias')
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.poolmlp = PoolMlp(in_features=dim, hidden_features=mlp_hidden_dim,
                               act_layer=act_layer, drop=drop)

        # The following two techniques are useful for training deep PoolFormers.
        self.drop_path = DropPath(drop_path) if drop_path > 0. \
            else nn.Identity()
        self.use_layer_scale = use_layer_scale

        if use_layer_scale:
            self.layer_scale_1 = nn.Parameter(
                torch.ones(dim, dtype=torch.float32) * layer_scale_init_value)

            self.layer_scale_2 = nn.Parameter(
                torch.ones(dim, dtype=torch.float32) * layer_scale_init_value)

    def forward(self, x):
        if self.use_layer_scale:
            x = x + self.drop_path(
                self.layer_scale_1.unsqueeze(-1).unsqueeze(-1)
                * self.token_mixer(self.norm1(x)))
            x = x + self.drop_path(
                self.layer_scale_2.unsqueeze(-1).unsqueeze(-1)
                * self.poolmlp(self.norm2(x)))
        else:
            x = x + self.drop_path(self.token_mixer(self.norm1(x)))
            x = x + self.drop_path(self.poolmlp(self.norm2(x)))
        return x


class InvertedResidualBlock(nn.Module):
    def __init__(self, inp, oup, expand_ratio):
        super(InvertedResidualBlock, self).__init__()
        hidden_dim = int(inp * expand_ratio)
        self.bottleneckBlock = nn.Sequential(
            # pw
            nn.Conv2d(inp, hidden_dim, 1, bias=False),
            # nn.BatchNorm2d(hidden_dim),
            nn.ReLU6(inplace=True),
            # dw
            nn.ReflectionPad2d(1),
            nn.Conv2d(hidden_dim, hidden_dim, 3, groups=hidden_dim, bias=False),
            # nn.BatchNorm2d(hidden_dim),
            nn.ReLU6(inplace=True),
            # pw-linear
            nn.Conv2d(hidden_dim, oup, 1, bias=False),
            # nn.BatchNorm2d(oup),
        )

    def forward(self, x):
        return self.bottleneckBlock(x)


class DetailNode(nn.Module):
    def __init__(self):
        super(DetailNode, self).__init__()

        self.theta_phi = InvertedResidualBlock(inp=32, oup=32, expand_ratio=2)
        self.theta_rho = InvertedResidualBlock(inp=32, oup=32, expand_ratio=2)
        self.theta_eta = InvertedResidualBlock(inp=32, oup=32, expand_ratio=2)
        self.shffleconv = nn.Conv2d(64, 64, kernel_size=1,
                                    stride=1, padding=0, bias=True)

    def separateFeature(self, x):
        z1, z2 = x[:, :x.shape[1] // 2], x[:, x.shape[1] // 2:x.shape[1]]
        return z1, z2

    def forward(self, z1, z2):
        z1, z2 = self.separateFeature(
            self.shffleconv(torch.cat((z1, z2), dim=1)))
        # affine coupling step, as in invertible neural networks (INNs)
        z2 = z2 + self.theta_phi(z1)
        z1 = z1 * torch.exp(self.theta_rho(z2)) + self.theta_eta(z2)
        return z1, z2


class DetailFeatureExtraction(nn.Module):
    def __init__(self, num_layers=3):
        super(DetailFeatureExtraction, self).__init__()
        INNmodules = [DetailNode() for _ in range(num_layers)]
        self.net = nn.Sequential(*INNmodules)

    def forward(self, x):
        z1, z2 = x[:, :x.shape[1] // 2], x[:, x.shape[1] // 2:x.shape[1]]
        for layer in self.net:
            z1, z2 = layer(z1, z2)
        return torch.cat((z1, z2), dim=1)


# =============================================================================

# =============================================================================
import numbers


##########################################################################
## Layer Norm
def to_3d(x):
    return rearrange(x, 'b c h w -> b (h w) c')


def to_4d(x, h, w):
    return rearrange(x, 'b (h w) c -> b c h w', h=h, w=w)


class BiasFree_LayerNorm(nn.Module):
    def __init__(self, normalized_shape):
        super(BiasFree_LayerNorm, self).__init__()
        if isinstance(normalized_shape, numbers.Integral):
            normalized_shape = (normalized_shape,)
        normalized_shape = torch.Size(normalized_shape)

        assert len(normalized_shape) == 1

        self.weight = nn.Parameter(torch.ones(normalized_shape))
        self.normalized_shape = normalized_shape

    def forward(self, x):
        sigma = x.var(-1, keepdim=True, unbiased=False)
        return x / torch.sqrt(sigma + 1e-5) * self.weight


class WithBias_LayerNorm(nn.Module):
    def __init__(self, normalized_shape):
        super(WithBias_LayerNorm, self).__init__()
        if isinstance(normalized_shape, numbers.Integral):
            normalized_shape = (normalized_shape,)
        normalized_shape = torch.Size(normalized_shape)

        assert len(normalized_shape) == 1

        self.weight = nn.Parameter(torch.ones(normalized_shape))
        self.bias = nn.Parameter(torch.zeros(normalized_shape))
        self.normalized_shape = normalized_shape

    def forward(self, x):
        mu = x.mean(-1, keepdim=True)
        sigma = x.var(-1, keepdim=True, unbiased=False)
        return (x - mu) / torch.sqrt(sigma + 1e-5) * self.weight + self.bias


class LayerNorm(nn.Module):
    def __init__(self, dim, LayerNorm_type):
        super(LayerNorm, self).__init__()
        if LayerNorm_type == 'BiasFree':
            self.body = BiasFree_LayerNorm(dim)
        else:
            self.body = WithBias_LayerNorm(dim)

    def forward(self, x):
        h, w = x.shape[-2:]
        return to_4d(self.body(to_3d(x)), h, w)


##########################################################################
## Gated-DConv Feed-Forward Network (GDFN)
class FeedForward(nn.Module):
    def __init__(self, dim, ffn_expansion_factor, bias):
        super(FeedForward, self).__init__()

        hidden_features = int(dim * ffn_expansion_factor)

        self.project_in = nn.Conv2d(
            dim, hidden_features * 2, kernel_size=1, bias=bias)

        self.dwconv = nn.Conv2d(hidden_features * 2, hidden_features * 2, kernel_size=3,
                                stride=1, padding=1, groups=hidden_features * 2, bias=bias)

        self.project_out = nn.Conv2d(
            hidden_features, dim, kernel_size=1, bias=bias)

    def forward(self, x):
        x = self.project_in(x)
        x1, x2 = self.dwconv(x).chunk(2, dim=1)
        x = F.gelu(x1) * x2  # gating: one branch modulates the other
        x = self.project_out(x)
        return x


##########################################################################
## Multi-DConv Head Transposed Self-Attention (MDTA)
class Attention(nn.Module):
    def __init__(self, dim, num_heads, bias):
        super(Attention, self).__init__()
        self.num_heads = num_heads
        self.temperature = nn.Parameter(torch.ones(num_heads, 1, 1))

        self.qkv = nn.Conv2d(dim, dim * 3, kernel_size=1, bias=bias)
        self.qkv_dwconv = nn.Conv2d(
            dim * 3, dim * 3, kernel_size=3, stride=1, padding=1, groups=dim * 3, bias=bias)
        self.project_out = nn.Conv2d(dim, dim, kernel_size=1, bias=bias)

    def forward(self, x):
        b, c, h, w = x.shape

        qkv = self.qkv_dwconv(self.qkv(x))
        q, k, v = qkv.chunk(3, dim=1)

        q = rearrange(q, 'b (head c) h w -> b head c (h w)',
                      head=self.num_heads)
        k = rearrange(k, 'b (head c) h w -> b head c (h w)',
                      head=self.num_heads)
        v = rearrange(v, 'b (head c) h w -> b head c (h w)',
                      head=self.num_heads)

        q = torch.nn.functional.normalize(q, dim=-1)
        k = torch.nn.functional.normalize(k, dim=-1)

        # attention over channels rather than spatial positions ("transposed" attention)
        attn = (q @ k.transpose(-2, -1)) * self.temperature
        attn = attn.softmax(dim=-1)

        out = (attn @ v)

        out = rearrange(out, 'b head c (h w) -> b (head c) h w',
                        head=self.num_heads, h=h, w=w)

        out = self.project_out(out)
        return out


##########################################################################
class TransformerBlock(nn.Module):
    def __init__(self, dim, num_heads, ffn_expansion_factor, bias, LayerNorm_type):
        super(TransformerBlock, self).__init__()

        self.norm1 = LayerNorm(dim, LayerNorm_type)
        self.attn = Attention(dim, num_heads, bias)
        self.norm2 = LayerNorm(dim, LayerNorm_type)
        self.ffn = FeedForward(dim, ffn_expansion_factor, bias)

    def forward(self, x):
        x = x + self.attn(self.norm1(x))
        x = x + self.ffn(self.norm2(x))

        return x


##########################################################################
## Overlapped image patch embedding with 3x3 Conv
class OverlapPatchEmbed(nn.Module):
    def __init__(self, in_c=3, embed_dim=48, bias=False):
        super(OverlapPatchEmbed, self).__init__()

        self.proj = nn.Conv2d(in_c, embed_dim, kernel_size=3,
                              stride=1, padding=1, bias=bias)

    def forward(self, x):
        x = self.proj(x)
        return x


class Restormer_Encoder(nn.Module):
    def __init__(self,
                 inp_channels=1,
                 out_channels=1,
                 dim=64,
                 num_blocks=[4, 4],
                 heads=[8, 8, 8],
                 ffn_expansion_factor=2,
                 bias=False,
                 LayerNorm_type='WithBias',
                 ):
        super(Restormer_Encoder, self).__init__()

        self.patch_embed = OverlapPatchEmbed(inp_channels, dim)

        self.encoder_level1 = nn.Sequential(
            *[TransformerBlock(dim=dim, num_heads=heads[0], ffn_expansion_factor=ffn_expansion_factor,
                               bias=bias, LayerNorm_type=LayerNorm_type) for i in range(num_blocks[0])])

        self.baseFeature = BaseFeatureExtraction(dim=dim)

        self.detailFeature = DetailFeatureExtraction()

    def forward(self, inp_img):
        inp_enc_level1 = self.patch_embed(inp_img)
        out_enc_level1 = self.encoder_level1(inp_enc_level1)
        base_feature = self.baseFeature(out_enc_level1)
        detail_feature = self.detailFeature(out_enc_level1)
        return base_feature, detail_feature, out_enc_level1


class Restormer_Decoder(nn.Module):
    def __init__(self,
                 inp_channels=1,
                 out_channels=1,
                 dim=64,
                 num_blocks=[4, 4],
                 heads=[8, 8, 8],
                 ffn_expansion_factor=2,
                 bias=False,
                 LayerNorm_type='WithBias',
                 ):

        super(Restormer_Decoder, self).__init__()
        self.reduce_channel = nn.Conv2d(int(dim * 2), int(dim), kernel_size=1, bias=bias)
        self.encoder_level2 = nn.Sequential(
            *[TransformerBlock(dim=dim, num_heads=heads[1], ffn_expansion_factor=ffn_expansion_factor,
                               bias=bias, LayerNorm_type=LayerNorm_type) for i in range(num_blocks[1])])
        self.output = nn.Sequential(
            nn.Conv2d(int(dim), int(dim) // 2, kernel_size=3,
                      stride=1, padding=1, bias=bias),
            nn.LeakyReLU(),
            nn.Conv2d(int(dim) // 2, out_channels, kernel_size=3,
                      stride=1, padding=1, bias=bias), )
        self.sigmoid = nn.Sigmoid()

    def forward(self, inp_img, base_feature, detail_feature):
        out_enc_level0 = torch.cat((base_feature, detail_feature), dim=1)
        out_enc_level0 = self.reduce_channel(out_enc_level0)
        out_enc_level1 = self.encoder_level2(out_enc_level0)
        if inp_img is not None:
            out_enc_level1 = self.output(out_enc_level1) + inp_img
        else:
            out_enc_level1 = self.output(out_enc_level1)
        return self.sigmoid(out_enc_level1), out_enc_level0


if __name__ == '__main__':
    height = 128
    width = 128
    window_size = 8
    modelE = Restormer_Encoder().cuda()
    modelD = Restormer_Decoder().cuda()
    # smoke test: run a dummy single-channel image through encoder and decoder
    x = torch.randn(1, 1, height, width).cuda()
    base, detail, _ = modelE(x)
    fused, _ = modelD(x, base, detail)
    print(fused.shape)  # expected: torch.Size([1, 1, 128, 128])
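The coupling step inside DetailNode is invertible in the INN sense: given its outputs, the inputs to the coupling can be recovered exactly (the 1x1 shuffle conv before it is a separate, learned mixing and is not inverted here). A minimal sketch of the inverse, assuming access to the same theta_phi/theta_rho/theta_eta sub-networks; this helper is illustrative and not part of the repository:

# illustrative inverse of DetailNode's coupling (not part of the repo)
def detail_node_coupling_inverse(node, z1_out, z2_out):
    # forward was: z2' = z2 + phi(z1); z1' = z1 * exp(rho(z2')) + eta(z2')
    z1 = (z1_out - node.theta_eta(z2_out)) * torch.exp(-node.theta_rho(z2_out))
    z2 = z2_out - node.theta_phi(z1)
    return z1, z2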
95 test_IVF.py Normal file
@@ -0,0 +1,95 @@
import cv2
from net import Restormer_Encoder, Restormer_Decoder, BaseFeatureExtraction, DetailFeatureExtraction
import os
import numpy as np
from utils.Evaluator import Evaluator
import torch
import torch.nn as nn
from utils.img_read_save import img_save, image_read_cv2
import warnings
import logging
# added: silence warnings and logging output during evaluation
warnings.filterwarnings("ignore")
logging.basicConfig(level=logging.CRITICAL)


os.environ["CUDA_VISIBLE_DEVICES"] = "0"
ckpt_path = r"models/PFCFuse.pth"

for dataset_name in ["MSRS", "TNO", "RoadScene"]:
    print("\n"*2+"="*80)
    model_name = "PFCFuse "
    print("The test result of "+dataset_name+' :')
    test_folder = os.path.join('test_img', dataset_name)
    test_out_folder = os.path.join('test_result', dataset_name)

    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    Encoder = nn.DataParallel(Restormer_Encoder()).to(device)
    Decoder = nn.DataParallel(Restormer_Decoder()).to(device)
    # BaseFuseLayer = nn.DataParallel(BaseFeatureExtraction(dim=64, num_heads=8)).to(device)
    BaseFuseLayer = nn.DataParallel(BaseFeatureExtraction(dim=64)).to(device)
    DetailFuseLayer = nn.DataParallel(DetailFeatureExtraction(num_layers=1)).to(device)

    Encoder.load_state_dict(torch.load(ckpt_path)['DIDF_Encoder'], strict=False)
    Decoder.load_state_dict(torch.load(ckpt_path)['DIDF_Decoder'])
    BaseFuseLayer.load_state_dict(torch.load(ckpt_path)['BaseFuseLayer'])
    DetailFuseLayer.load_state_dict(torch.load(ckpt_path)['DetailFuseLayer'])
    Encoder.eval()
    Decoder.eval()
    BaseFuseLayer.eval()
    DetailFuseLayer.eval()

    with torch.no_grad():
        for img_name in os.listdir(os.path.join(test_folder, "ir")):

            data_IR = image_read_cv2(os.path.join(test_folder, "ir", img_name), mode='GRAY')[np.newaxis, np.newaxis, ...] / 255.0
            # changed: fuse only the Y channel of the visible image
            data_VIS = cv2.split(image_read_cv2(os.path.join(test_folder, "vi", img_name), mode='YCrCb'))[0][np.newaxis, np.newaxis, ...] / 255.0
            # keep the Cr/Cb chroma channels (uint8) to recolor the fused result later
            data_VIS_BGR = cv2.imread(os.path.join(test_folder, "vi", img_name))
            _, data_VIS_Cr, data_VIS_Cb = cv2.split(cv2.cvtColor(data_VIS_BGR, cv2.COLOR_BGR2YCrCb))

            data_IR, data_VIS = torch.FloatTensor(data_IR), torch.FloatTensor(data_VIS)
            data_VIS, data_IR = data_VIS.cuda(), data_IR.cuda()

            feature_V_B, feature_V_D, feature_V = Encoder(data_VIS)
            feature_I_B, feature_I_D, feature_I = Encoder(data_IR)
            feature_F_B = BaseFuseLayer(feature_V_B + feature_I_B)
            feature_F_D = DetailFuseLayer(feature_V_D + feature_I_D)
            data_Fuse, _ = Decoder(data_VIS, feature_F_B, feature_F_D)
            # min-max normalize the fused result to [0, 1]
            data_Fuse = (data_Fuse - torch.min(data_Fuse)) / (torch.max(data_Fuse) - torch.min(data_Fuse))
            fi = np.squeeze((data_Fuse * 255).cpu().numpy())
            # changed: convert float32 to uint8 and recombine with the chroma channels
            fi = fi.astype(np.uint8)
            ycrcb_fi = np.dstack((fi, data_VIS_Cr, data_VIS_Cb))
            rgb_fi = cv2.cvtColor(ycrcb_fi, cv2.COLOR_YCrCb2RGB)
            img_save(rgb_fi, img_name.split(sep='.')[0], test_out_folder)

    eval_folder = test_out_folder
    ori_img_folder = test_folder

    metric_result = np.zeros((8))
    for img_name in os.listdir(os.path.join(ori_img_folder, "ir")):
        ir = image_read_cv2(os.path.join(ori_img_folder, "ir", img_name), 'GRAY')
        vi = image_read_cv2(os.path.join(ori_img_folder, "vi", img_name), 'GRAY')
        fi = image_read_cv2(os.path.join(eval_folder, img_name.split('.')[0]+".png"), 'GRAY')
        metric_result += np.array([Evaluator.EN(fi), Evaluator.SD(fi)
                                   , Evaluator.SF(fi), Evaluator.MI(fi, ir, vi)
                                   , Evaluator.SCD(fi, ir, vi), Evaluator.VIFF(fi, ir, vi)
                                   , Evaluator.Qabf(fi, ir, vi), Evaluator.SSIM(fi, ir, vi)])

    metric_result /= len(os.listdir(eval_folder))
    print("\t\t EN\t SD\t SF\t MI\tSCD\tVIF\tQabf\tSSIM")
    print(model_name+'\t'+str(np.round(metric_result[0], 2))+'\t'
          + str(np.round(metric_result[1], 2))+'\t'
          + str(np.round(metric_result[2], 2))+'\t'
          + str(np.round(metric_result[3], 2))+'\t'
          + str(np.round(metric_result[4], 2))+'\t'
          + str(np.round(metric_result[5], 2))+'\t'
          + str(np.round(metric_result[6], 2))+'\t'
          + str(np.round(metric_result[7], 2))
          )
    print("="*80)
239 train.py Normal file
@@ -0,0 +1,239 @@
# -*- coding: utf-8 -*-

'''
------------------------------------------------------------------------------
Import packages
------------------------------------------------------------------------------
'''

from net import Restormer_Encoder, Restormer_Decoder, BaseFeatureExtraction, DetailFeatureExtraction
from utils.dataset import H5Dataset
import os
os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'
import sys
import time
import datetime
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from utils.loss import Fusionloss, cc, relative_diff_loss
import kornia

print(torch.__version__)
print(torch.cuda.is_available())

'''
------------------------------------------------------------------------------
Configure our network
------------------------------------------------------------------------------
'''


os.environ['CUDA_VISIBLE_DEVICES'] = '0'
criteria_fusion = Fusionloss()
model_str = 'PFCFuse'

# Set the hyper-parameters for training
num_epochs = 120  # total epochs
epoch_gap = 40    # epochs of Phase I

lr = 1e-4
weight_decay = 0
batch_size = 1
GPU_number = os.environ['CUDA_VISIBLE_DEVICES']
# Coefficients of the loss function
coeff_mse_loss_VF = 1.  # alpha1
coeff_mse_loss_IF = 1.

coeff_rmi_loss_VF = 1.
coeff_rmi_loss_IF = 1.

coeff_cos_loss_VF = 1.
coeff_cos_loss_IF = 1.
coeff_decomp = 2.       # alpha2 and alpha4
coeff_tv = 5.

clip_grad_norm_value = 0.01
optim_step = 20
optim_gamma = 0.5


# Model
device = 'cuda' if torch.cuda.is_available() else 'cpu'
DIDF_Encoder = nn.DataParallel(Restormer_Encoder()).to(device)
DIDF_Decoder = nn.DataParallel(Restormer_Decoder()).to(device)
# BaseFuseLayer = nn.DataParallel(BaseFeatureExtraction(dim=64, num_heads=8)).to(device)
BaseFuseLayer = nn.DataParallel(BaseFeatureExtraction(dim=64)).to(device)
DetailFuseLayer = nn.DataParallel(DetailFeatureExtraction(num_layers=1)).to(device)

# optimizer, scheduler and loss function
optimizer1 = torch.optim.Adam(
    DIDF_Encoder.parameters(), lr=lr, weight_decay=weight_decay)
optimizer2 = torch.optim.Adam(
    DIDF_Decoder.parameters(), lr=lr, weight_decay=weight_decay)
optimizer3 = torch.optim.Adam(
    BaseFuseLayer.parameters(), lr=lr, weight_decay=weight_decay)
optimizer4 = torch.optim.Adam(
    DetailFuseLayer.parameters(), lr=lr, weight_decay=weight_decay)

scheduler1 = torch.optim.lr_scheduler.StepLR(optimizer1, step_size=optim_step, gamma=optim_gamma)
scheduler2 = torch.optim.lr_scheduler.StepLR(optimizer2, step_size=optim_step, gamma=optim_gamma)
scheduler3 = torch.optim.lr_scheduler.StepLR(optimizer3, step_size=optim_step, gamma=optim_gamma)
scheduler4 = torch.optim.lr_scheduler.StepLR(optimizer4, step_size=optim_step, gamma=optim_gamma)

MSELoss = nn.MSELoss()
L1Loss = nn.L1Loss()
Loss_ssim = kornia.losses.SSIM(11, reduction='mean')  # note: on recent kornia versions this class is kornia.losses.SSIMLoss
HuberLoss = nn.HuberLoss()

# data loader
trainloader = DataLoader(H5Dataset(r"data/MSRS_train_imgsize_128_stride_200.h5"),
                         batch_size=batch_size,
                         shuffle=True,
                         num_workers=0)

loader = {'train': trainloader, }
timestamp = datetime.datetime.now().strftime("%m-%d-%H-%M")

'''
------------------------------------------------------------------------------
Train
------------------------------------------------------------------------------
'''

step = 0
torch.backends.cudnn.benchmark = True
prev_time = time.time()

start = time.time()
for epoch in range(num_epochs):
    ''' train '''
    for i, (data_VIS, data_IR) in enumerate(loader['train']):
        data_VIS, data_IR = data_VIS.cuda(), data_IR.cuda()
        DIDF_Encoder.train()
        DIDF_Decoder.train()
        BaseFuseLayer.train()
        DetailFuseLayer.train()

        DIDF_Encoder.zero_grad()
        DIDF_Decoder.zero_grad()
        BaseFuseLayer.zero_grad()
        DetailFuseLayer.zero_grad()

        optimizer1.zero_grad()
        optimizer2.zero_grad()
        optimizer3.zero_grad()
        optimizer4.zero_grad()

        if epoch < epoch_gap:  # Phase I: train encoder/decoder for reconstruction
            feature_V_B, feature_V_D, _ = DIDF_Encoder(data_VIS)
            feature_I_B, feature_I_D, _ = DIDF_Encoder(data_IR)
            data_VIS_hat, _ = DIDF_Decoder(data_VIS, feature_V_B, feature_V_D)
            data_IR_hat, _ = DIDF_Decoder(data_IR, feature_I_B, feature_I_D)

            cc_loss_B = cc(feature_V_B, feature_I_B)
            cc_loss_D = cc(feature_V_D, feature_I_D)

            mse_loss_V = 5 * Loss_ssim(data_VIS, data_VIS_hat) + HuberLoss(data_VIS, data_VIS_hat)
            mse_loss_I = 5 * Loss_ssim(data_IR, data_IR_hat) + HuberLoss(data_IR, data_IR_hat)

            # print("mse_loss_V", mse_loss_V)
            # print("mse_loss_I", mse_loss_I)

            Gradient_loss = L1Loss(kornia.filters.SpatialGradient()(data_VIS),
                                   kornia.filters.SpatialGradient()(data_VIS_hat))
            # print("Gradient_loss", Gradient_loss)

            # decomposition loss: decorrelate detail features, keep base features correlated
            loss_decomp = (cc_loss_D) ** 2 / (1.01 + cc_loss_B)
            # print("loss_decomp", loss_decomp)

            loss_rmi_v = relative_diff_loss(data_VIS, data_VIS_hat)
            loss_rmi_i = relative_diff_loss(data_IR, data_IR_hat)

            loss = coeff_mse_loss_VF * mse_loss_V + coeff_mse_loss_IF * \
                mse_loss_I + coeff_decomp * loss_decomp + coeff_tv * Gradient_loss + \
                coeff_rmi_loss_IF * loss_rmi_i + coeff_rmi_loss_VF * loss_rmi_v

            loss.backward()
            nn.utils.clip_grad_norm_(
                DIDF_Encoder.parameters(), max_norm=clip_grad_norm_value, norm_type=2)
            nn.utils.clip_grad_norm_(
                DIDF_Decoder.parameters(), max_norm=clip_grad_norm_value, norm_type=2)
            optimizer1.step()
            optimizer2.step()
        else:  # Phase II: train the fusion layers end to end
            feature_V_B, feature_V_D, feature_V = DIDF_Encoder(data_VIS)
            feature_I_B, feature_I_D, feature_I = DIDF_Encoder(data_IR)
            feature_F_B = BaseFuseLayer(feature_I_B + feature_V_B)
            feature_F_D = DetailFuseLayer(feature_I_D + feature_V_D)
            data_Fuse, feature_F = DIDF_Decoder(data_VIS, feature_F_B, feature_F_D)

            mse_loss_V = 5 * Loss_ssim(data_VIS, data_Fuse) + HuberLoss(data_VIS, data_Fuse)
            mse_loss_I = 5 * Loss_ssim(data_IR, data_Fuse) + HuberLoss(data_IR, data_Fuse)

            cc_loss_B = cc(feature_V_B, feature_I_B)
            cc_loss_D = cc(feature_V_D, feature_I_D)
            loss_decomp = (cc_loss_D) ** 2 / (1.01 + cc_loss_B)

            fusionloss, _, _ = criteria_fusion(data_VIS, data_IR, data_Fuse)

            loss = fusionloss + coeff_decomp * loss_decomp
            loss.backward()
            nn.utils.clip_grad_norm_(
                DIDF_Encoder.parameters(), max_norm=clip_grad_norm_value, norm_type=2)
            nn.utils.clip_grad_norm_(
                DIDF_Decoder.parameters(), max_norm=clip_grad_norm_value, norm_type=2)
            nn.utils.clip_grad_norm_(
                BaseFuseLayer.parameters(), max_norm=clip_grad_norm_value, norm_type=2)
            nn.utils.clip_grad_norm_(
                DetailFuseLayer.parameters(), max_norm=clip_grad_norm_value, norm_type=2)
            optimizer1.step()
            optimizer2.step()
            optimizer3.step()
            optimizer4.step()

        # Determine approximate time left
        batches_done = epoch * len(loader['train']) + i
        batches_left = num_epochs * len(loader['train']) - batches_done
        time_left = datetime.timedelta(seconds=batches_left * (time.time() - prev_time))
        epoch_time = time.time() - prev_time
        prev_time = time.time()
        sys.stdout.write(
            "\r[Epoch %d/%d] [Batch %d/%d] [loss: %f]"
            % (
                epoch,
                num_epochs,
                i,
                len(loader['train']),
                loss.item(),
            )
        )

    # adjust the learning rate
    scheduler1.step()
    scheduler2.step()
    if not epoch < epoch_gap:
        scheduler3.step()
        scheduler4.step()

    # floor the learning rate at 1e-6
    if optimizer1.param_groups[0]['lr'] <= 1e-6:
        optimizer1.param_groups[0]['lr'] = 1e-6
    if optimizer2.param_groups[0]['lr'] <= 1e-6:
        optimizer2.param_groups[0]['lr'] = 1e-6
    if optimizer3.param_groups[0]['lr'] <= 1e-6:
        optimizer3.param_groups[0]['lr'] = 1e-6
    if optimizer4.param_groups[0]['lr'] <= 1e-6:
        optimizer4.param_groups[0]['lr'] = 1e-6

    # save a checkpoint every epoch
    if True:
        checkpoint = {
            'DIDF_Encoder': DIDF_Encoder.state_dict(),
            'DIDF_Decoder': DIDF_Decoder.state_dict(),
            'BaseFuseLayer': BaseFuseLayer.state_dict(),
            'DetailFuseLayer': DetailFuseLayer.state_dict(),
        }
        torch.save(checkpoint, os.path.join("models/PFCFusion" + timestamp + '.pth'))
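The checkpoint saved above is a plain dict of four state_dicts, so resuming (or evaluating, as test_IVF.py does) just loads each entry back into its module. A minimal sketch, assuming a saved file exists at models/PFCFuse.pth and the same DataParallel wrapping used above:

# illustrative resume sketch (the checkpoint path is an assumption)
ckpt = torch.load("models/PFCFuse.pth", map_location=device)
DIDF_Encoder.load_state_dict(ckpt['DIDF_Encoder'])
DIDF_Decoder.load_state_dict(ckpt['DIDF_Decoder'])
BaseFuseLayer.load_state_dict(ckpt['BaseFuseLayer'])
DetailFuseLayer.load_state_dict(ckpt['DetailFuseLayer'])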
305 utils/Evaluator.py Normal file
@@ -0,0 +1,305 @@
import numpy as np
import cv2
import sklearn.metrics as skm
from scipy.signal import convolve2d
import math
from skimage.metrics import structural_similarity as ssim


def image_read_cv2(path, mode='RGB'):
    img_BGR = cv2.imread(path).astype('float32')
    assert mode == 'RGB' or mode == 'GRAY' or mode == 'YCrCb', 'mode error'
    if mode == 'RGB':
        img = cv2.cvtColor(img_BGR, cv2.COLOR_BGR2RGB)
    elif mode == 'GRAY':
        img = np.round(cv2.cvtColor(img_BGR, cv2.COLOR_BGR2GRAY))
    elif mode == 'YCrCb':
        img = cv2.cvtColor(img_BGR, cv2.COLOR_BGR2YCrCb)
    return img


class Evaluator():
    @classmethod
    def input_check(cls, imgF, imgA=None, imgB=None):
        if imgA is None:
            assert type(imgF) == np.ndarray, 'type error'
            assert len(imgF.shape) == 2, 'dimension error'
        else:
            assert type(imgF) == type(imgA) == type(imgB) == np.ndarray, 'type error'
            assert imgF.shape == imgA.shape == imgB.shape, 'shape error'
            assert len(imgF.shape) == 2, 'dimension error'

    @classmethod
    def EN(cls, img):  # entropy
        cls.input_check(img)
        a = np.uint8(np.round(img)).flatten()
        h = np.bincount(a) / a.shape[0]
        return -sum(h * np.log2(h + (h == 0)))

    @classmethod
    def SD(cls, img):  # standard deviation
        cls.input_check(img)
        return np.std(img)

    @classmethod
    def SF(cls, img):  # spatial frequency
        cls.input_check(img)
        return np.sqrt(np.mean((img[:, 1:] - img[:, :-1]) ** 2) + np.mean((img[1:, :] - img[:-1, :]) ** 2))

    @classmethod
    def AG(cls, img):  # average gradient
        cls.input_check(img)
        Gx, Gy = np.zeros_like(img), np.zeros_like(img)

        Gx[:, 0] = img[:, 1] - img[:, 0]
        Gx[:, -1] = img[:, -1] - img[:, -2]
        Gx[:, 1:-1] = (img[:, 2:] - img[:, :-2]) / 2

        Gy[0, :] = img[1, :] - img[0, :]
        Gy[-1, :] = img[-1, :] - img[-2, :]
        Gy[1:-1, :] = (img[2:, :] - img[:-2, :]) / 2
        return np.mean(np.sqrt((Gx ** 2 + Gy ** 2) / 2))

    @classmethod
    def MI(cls, image_F, image_A, image_B):  # mutual information
        cls.input_check(image_F, image_A, image_B)
        return skm.mutual_info_score(image_F.flatten(), image_A.flatten()) + skm.mutual_info_score(image_F.flatten(),
                                                                                                   image_B.flatten())

    @classmethod
    def MSE(cls, image_F, image_A, image_B):  # MSE
        cls.input_check(image_F, image_A, image_B)
        return (np.mean((image_A - image_F) ** 2) + np.mean((image_B - image_F) ** 2)) / 2

    @classmethod
    def CC(cls, image_F, image_A, image_B):  # correlation coefficient
        cls.input_check(image_F, image_A, image_B)
        rAF = np.sum((image_A - np.mean(image_A)) * (image_F - np.mean(image_F))) / np.sqrt(
            (np.sum((image_A - np.mean(image_A)) ** 2)) * (np.sum((image_F - np.mean(image_F)) ** 2)))
        rBF = np.sum((image_B - np.mean(image_B)) * (image_F - np.mean(image_F))) / np.sqrt(
            (np.sum((image_B - np.mean(image_B)) ** 2)) * (np.sum((image_F - np.mean(image_F)) ** 2)))
        return (rAF + rBF) / 2

    @classmethod
    def PSNR(cls, image_F, image_A, image_B):
        cls.input_check(image_F, image_A, image_B)
        return 10 * np.log10(np.max(image_F) ** 2 / cls.MSE(image_F, image_A, image_B))

    @classmethod
    def SCD(cls, image_F, image_A, image_B):  # the sum of the correlations of differences
        cls.input_check(image_F, image_A, image_B)
        imgF_A = image_F - image_A
        imgF_B = image_F - image_B
        corr1 = np.sum((image_A - np.mean(image_A)) * (imgF_B - np.mean(imgF_B))) / np.sqrt(
            (np.sum((image_A - np.mean(image_A)) ** 2)) * (np.sum((imgF_B - np.mean(imgF_B)) ** 2)))
        corr2 = np.sum((image_B - np.mean(image_B)) * (imgF_A - np.mean(imgF_A))) / np.sqrt(
            (np.sum((image_B - np.mean(image_B)) ** 2)) * (np.sum((imgF_A - np.mean(imgF_A)) ** 2)))
        return corr1 + corr2

    @classmethod
    def VIFF(cls, image_F, image_A, image_B):
        cls.input_check(image_F, image_A, image_B)
        return cls.compare_viff(image_A, image_F) + cls.compare_viff(image_B, image_F)

    @classmethod
    def compare_viff(cls, ref, dist):  # VIFF of a pair of pictures
        sigma_nsq = 2
        eps = 1e-10

        num = 0.0
        den = 0.0
        for scale in range(1, 5):

            N = 2 ** (4 - scale + 1) + 1
            sd = N / 5.0

            # Create a Gaussian kernel, as in MATLAB
            m, n = [(ss - 1.) / 2. for ss in (N, N)]
            y, x = np.ogrid[-m:m + 1, -n:n + 1]
            h = np.exp(-(x * x + y * y) / (2. * sd * sd))
            h[h < np.finfo(h.dtype).eps * h.max()] = 0
            sumh = h.sum()
            if sumh != 0:
                win = h / sumh

            if scale > 1:
                ref = convolve2d(ref, np.rot90(win, 2), mode='valid')
                dist = convolve2d(dist, np.rot90(win, 2), mode='valid')
                ref = ref[::2, ::2]
                dist = dist[::2, ::2]

            mu1 = convolve2d(ref, np.rot90(win, 2), mode='valid')
            mu2 = convolve2d(dist, np.rot90(win, 2), mode='valid')
            mu1_sq = mu1 * mu1
            mu2_sq = mu2 * mu2
            mu1_mu2 = mu1 * mu2
            sigma1_sq = convolve2d(ref * ref, np.rot90(win, 2), mode='valid') - mu1_sq
            sigma2_sq = convolve2d(dist * dist, np.rot90(win, 2), mode='valid') - mu2_sq
            sigma12 = convolve2d(ref * dist, np.rot90(win, 2), mode='valid') - mu1_mu2

            sigma1_sq[sigma1_sq < 0] = 0
            sigma2_sq[sigma2_sq < 0] = 0

            g = sigma12 / (sigma1_sq + eps)
            sv_sq = sigma2_sq - g * sigma12

            g[sigma1_sq < eps] = 0
            sv_sq[sigma1_sq < eps] = sigma2_sq[sigma1_sq < eps]
            sigma1_sq[sigma1_sq < eps] = 0

            g[sigma2_sq < eps] = 0
            sv_sq[sigma2_sq < eps] = 0

            sv_sq[g < 0] = sigma2_sq[g < 0]
            g[g < 0] = 0
            sv_sq[sv_sq <= eps] = eps

            num += np.sum(np.log10(1 + g * g * sigma1_sq / (sv_sq + sigma_nsq)))
            den += np.sum(np.log10(1 + sigma1_sq / sigma_nsq))

        vifp = num / den

        if np.isnan(vifp):
            return 1.0
        else:
            return vifp

    @classmethod
    def Qabf(cls, image_F, image_A, image_B):
        cls.input_check(image_F, image_A, image_B)
        gA, aA = cls.Qabf_getArray(image_A)
        gB, aB = cls.Qabf_getArray(image_B)
        gF, aF = cls.Qabf_getArray(image_F)
        QAF = cls.Qabf_getQabf(aA, gA, aF, gF)
        QBF = cls.Qabf_getQabf(aB, gB, aF, gF)

        # compute the overall Qabf score
        deno = np.sum(gA + gB)
        nume = np.sum(np.multiply(QAF, gA) + np.multiply(QBF, gB))
        return nume / deno

    @classmethod
    def Qabf_getArray(cls, img):
        # Sobel operator kernels (h2 is defined but unused)
        h1 = np.array([[1, 2, 1], [0, 0, 0], [-1, -2, -1]]).astype(np.float32)
        h2 = np.array([[0, 1, 2], [-1, 0, 1], [-2, -1, 0]]).astype(np.float32)
        h3 = np.array([[-1, 0, 1], [-2, 0, 2], [-1, 0, 1]]).astype(np.float32)

        SAx = convolve2d(img, h3, mode='same')
        SAy = convolve2d(img, h1, mode='same')
        gA = np.sqrt(np.multiply(SAx, SAx) + np.multiply(SAy, SAy))
        aA = np.zeros_like(img)
        aA[SAx == 0] = math.pi / 2
        aA[SAx != 0] = np.arctan(SAy[SAx != 0] / SAx[SAx != 0])
        return gA, aA

    @classmethod
    def Qabf_getQabf(cls, aA, gA, aF, gF):
        L = 1
        Tg = 0.9994
        kg = -15
        Dg = 0.5
        Ta = 0.9879
        ka = -22
        Da = 0.8
        GAF, AAF, QgAF, QaAF, QAF = np.zeros_like(aA), np.zeros_like(aA), np.zeros_like(aA), np.zeros_like(aA), np.zeros_like(aA)
        GAF[gA > gF] = gF[gA > gF] / gA[gA > gF]
        GAF[gA == gF] = gF[gA == gF]
        GAF[gA < gF] = gA[gA < gF] / gF[gA < gF]
        AAF = 1 - np.abs(aA - aF) / (math.pi / 2)
        QgAF = Tg / (1 + np.exp(kg * (GAF - Dg)))
        QaAF = Ta / (1 + np.exp(ka * (AAF - Da)))
        QAF = QgAF * QaAF
        return QAF

    @classmethod
    def SSIM(cls, image_F, image_A, image_B):
        cls.input_check(image_F, image_A, image_B)
        return ssim(image_F, image_A) + ssim(image_F, image_B)


# standalone VIFF variant that evaluates both source images jointly
def VIFF(image_F, image_A, image_B):
    refA = image_A
    refB = image_B
    dist = image_F

    sigma_nsq = 2
    eps = 1e-10
    numA = 0.0
    denA = 0.0
    numB = 0.0
    denB = 0.0
    for scale in range(1, 5):
        N = 2 ** (4 - scale + 1) + 1
        sd = N / 5.0
        # Create a Gaussian kernel, as in MATLAB
        m, n = [(ss - 1.) / 2. for ss in (N, N)]
        y, x = np.ogrid[-m:m + 1, -n:n + 1]
        h = np.exp(-(x * x + y * y) / (2. * sd * sd))
        h[h < np.finfo(h.dtype).eps * h.max()] = 0
        sumh = h.sum()
        if sumh != 0:
            win = h / sumh

        if scale > 1:
            refA = convolve2d(refA, np.rot90(win, 2), mode='valid')
            refB = convolve2d(refB, np.rot90(win, 2), mode='valid')
            dist = convolve2d(dist, np.rot90(win, 2), mode='valid')
            refA = refA[::2, ::2]
            refB = refB[::2, ::2]
            dist = dist[::2, ::2]

        mu1A = convolve2d(refA, np.rot90(win, 2), mode='valid')
        mu1B = convolve2d(refB, np.rot90(win, 2), mode='valid')
        mu2 = convolve2d(dist, np.rot90(win, 2), mode='valid')
        mu1_sq_A = mu1A * mu1A
        mu1_sq_B = mu1B * mu1B
        mu2_sq = mu2 * mu2
        mu1A_mu2 = mu1A * mu2
        mu1B_mu2 = mu1B * mu2
        sigma1A_sq = convolve2d(refA * refA, np.rot90(win, 2), mode='valid') - mu1_sq_A
        sigma1B_sq = convolve2d(refB * refB, np.rot90(win, 2), mode='valid') - mu1_sq_B
        sigma2_sq = convolve2d(dist * dist, np.rot90(win, 2), mode='valid') - mu2_sq
        sigma12_A = convolve2d(refA * dist, np.rot90(win, 2), mode='valid') - mu1A_mu2
        sigma12_B = convolve2d(refB * dist, np.rot90(win, 2), mode='valid') - mu1B_mu2

        sigma1A_sq[sigma1A_sq < 0] = 0
        sigma1B_sq[sigma1B_sq < 0] = 0
        sigma2_sq[sigma2_sq < 0] = 0

        gA = sigma12_A / (sigma1A_sq + eps)
        gB = sigma12_B / (sigma1B_sq + eps)
        sv_sq_A = sigma2_sq - gA * sigma12_A
        sv_sq_B = sigma2_sq - gB * sigma12_B

        gA[sigma1A_sq < eps] = 0
        gB[sigma1B_sq < eps] = 0
        sv_sq_A[sigma1A_sq < eps] = sigma2_sq[sigma1A_sq < eps]
        sv_sq_B[sigma1B_sq < eps] = sigma2_sq[sigma1B_sq < eps]
        sigma1A_sq[sigma1A_sq < eps] = 0
        sigma1B_sq[sigma1B_sq < eps] = 0

        gA[sigma2_sq < eps] = 0
        gB[sigma2_sq < eps] = 0
        sv_sq_A[sigma2_sq < eps] = 0
        sv_sq_B[sigma2_sq < eps] = 0

        sv_sq_A[gA < 0] = sigma2_sq[gA < 0]
        sv_sq_B[gB < 0] = sigma2_sq[gB < 0]
        gA[gA < 0] = 0
        gB[gB < 0] = 0
        sv_sq_A[sv_sq_A <= eps] = eps
        sv_sq_B[sv_sq_B <= eps] = eps

        numA += np.sum(np.log10(1 + gA * gA * sigma1A_sq / (sv_sq_A + sigma_nsq)))
        numB += np.sum(np.log10(1 + gB * gB * sigma1B_sq / (sv_sq_B + sigma_nsq)))
        denA += np.sum(np.log10(1 + sigma1A_sq / sigma_nsq))
        denB += np.sum(np.log10(1 + sigma1B_sq / sigma_nsq))

    vifpA = numA / denA
    vifpB = numB / denB

    if np.isnan(vifpA):
        vifpA = 1
    if np.isnan(vifpB):
        vifpB = 1
    return vifpA + vifpB
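All metrics operate on 2-D grayscale arrays in the uint8 value range, so they can be exercised on synthetic data before running a full evaluation. A minimal sketch; the array contents are illustrative only:

# illustrative usage of the Evaluator metrics on synthetic grayscale images
import numpy as np

ir = np.random.randint(0, 256, (128, 128)).astype(np.float64)
vi = np.random.randint(0, 256, (128, 128)).astype(np.float64)
fused = (ir + vi) / 2
print("EN:", Evaluator.EN(fused))
print("MI:", Evaluator.MI(fused, ir, vi))
print("Qabf:", Evaluator.Qabf(fused, ir, vi))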
BIN utils/__pycache__/Evaluator.cpython-38.pyc Normal file (binary file not shown)
BIN utils/__pycache__/dataset.cpython-38.pyc Normal file (binary file not shown)
BIN utils/__pycache__/dataset_MIF1.cpython-38.pyc Normal file (binary file not shown)
BIN utils/__pycache__/dataset_MIF4.cpython-38.pyc Normal file (binary file not shown)
BIN utils/__pycache__/img_read_save.cpython-38.pyc Normal file (binary file not shown)
BIN utils/__pycache__/loss.cpython-38.pyc Normal file (binary file not shown)
22 utils/dataset.py Normal file
@@ -0,0 +1,22 @@
import torch.utils.data as Data
import h5py
import numpy as np
import torch


class H5Dataset(Data.Dataset):
    def __init__(self, h5file_path):
        self.h5file_path = h5file_path
        h5f = h5py.File(h5file_path, 'r')
        self.keys = list(h5f['ir_patchs'].keys())
        h5f.close()

    def __len__(self):
        return len(self.keys)

    def __getitem__(self, index):
        # reopen the file per access so the dataset holds no open handle between reads
        h5f = h5py.File(self.h5file_path, 'r')
        key = self.keys[index]
        IR = np.array(h5f['ir_patchs'][key])
        VIS = np.array(h5f['vis_patchs'][key])
        h5f.close()
        return torch.Tensor(VIS), torch.Tensor(IR)
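A quick way to verify the H5 file and its patch shapes is to index the dataset directly. A minimal sketch, assuming the file produced by dataprocessing.py exists at the path used in train.py:

# illustrative check of H5Dataset (assumes the H5 file from dataprocessing.py exists)
ds = H5Dataset("data/MSRS_train_imgsize_128_stride_200.h5")
vis, ir = ds[0]
print(len(ds), vis.shape, ir.shape)  # expected shapes: [1, 128, 128] each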
28 utils/img_read_save.py Normal file
@@ -0,0 +1,28 @@
import numpy as np
import cv2
import os
from skimage.io import imsave


def image_read_cv2(path, mode='RGB'):
    img_BGR = cv2.imread(path).astype('float32')
    # img_BGR = cv2.imread(path)
    # print(img_BGR)
    # if img_BGR is not None:
    #     img_BGR = img_BGR.astype('float32')
    # else:
    #     print("handle the case where the image failed to load")
    assert mode == 'RGB' or mode == 'GRAY' or mode == 'YCrCb', 'mode error'
    if mode == 'RGB':
        img = cv2.cvtColor(img_BGR, cv2.COLOR_BGR2RGB)
    elif mode == 'GRAY':
        img = np.round(cv2.cvtColor(img_BGR, cv2.COLOR_BGR2GRAY))
    elif mode == 'YCrCb':
        img = cv2.cvtColor(img_BGR, cv2.COLOR_BGR2YCrCb)
    return img


def img_save(image, imagename, savepath):
    if not os.path.exists(savepath):
        os.makedirs(savepath)
    # saved as an 8-bit PNG (grayscale or color)
    imsave(os.path.join(savepath, "{}.png".format(imagename)), image.astype(np.uint8))
113 utils/loss.py Normal file
@@ -0,0 +1,113 @@
import torch
import torch.nn as nn
import torch.nn.functional as F


class Fusionloss(nn.Module):
    def __init__(self):
        super(Fusionloss, self).__init__()
        self.sobelconv = Sobelxy()

    def forward(self, image_vis, image_ir, generate_img):
        # added: relative mean-intensity term
        loss_rmi_v = relative_diff_loss(image_vis, generate_img)
        loss_rmi_i = relative_diff_loss(image_ir, generate_img)
        x_rmi_max = torch.max(loss_rmi_v, loss_rmi_i)
        # note: x_rmi_max is a scalar that broadcasts against generate_img here
        loss_rmi = F.l1_loss(x_rmi_max, generate_img)
        image_y = image_vis[:, :1, :, :]
        x_in_max = torch.max(image_y, image_ir)
        loss_in = F.l1_loss(x_in_max, generate_img)
        y_grad = self.sobelconv(image_y)
        ir_grad = self.sobelconv(image_ir)
        generate_img_grad = self.sobelconv(generate_img)
        x_grad_joint = torch.max(y_grad, ir_grad)
        loss_grad = F.l1_loss(x_grad_joint, generate_img_grad)
        # loss_total = loss_in + 10 * loss_grad
        # changed: include the relative-intensity term
        loss_total = loss_in + 10 * loss_grad + loss_rmi
        return loss_total, loss_in, loss_grad


class Sobelxy(nn.Module):
    def __init__(self):
        super(Sobelxy, self).__init__()
        kernelx = [[-1, 0, 1],
                   [-2, 0, 2],
                   [-1, 0, 1]]
        kernely = [[1, 2, 1],
                   [0, 0, 0],
                   [-1, -2, -1]]
        kernelx = torch.FloatTensor(kernelx).unsqueeze(0).unsqueeze(0)
        kernely = torch.FloatTensor(kernely).unsqueeze(0).unsqueeze(0)
        self.weightx = nn.Parameter(data=kernelx, requires_grad=False).cuda()
        self.weighty = nn.Parameter(data=kernely, requires_grad=False).cuda()

    def forward(self, x):
        sobelx = F.conv2d(x, self.weightx, padding=1)
        sobely = F.conv2d(x, self.weighty, padding=1)
        return torch.abs(sobelx) + torch.abs(sobely)


def cc(img1, img2):
    """Correlation coefficient for (N, C, H, W) images; torch.float32 in [0., 1.]."""
    eps = torch.finfo(torch.float32).eps
    N, C, _, _ = img1.shape
    img1 = img1.reshape(N, C, -1)
    img2 = img2.reshape(N, C, -1)
    img1 = img1 - img1.mean(dim=-1, keepdim=True)
    img2 = img2 - img2.mean(dim=-1, keepdim=True)
    cc = torch.sum(img1 * img2, dim=-1) / (eps + torch.sqrt(torch.sum(img1 **
                                           2, dim=-1)) * torch.sqrt(torch.sum(img2**2, dim=-1)))
    cc = torch.clamp(cc, -1., 1.)
    return cc.mean()


# def dice_coeff(img1, img2):
#     smooth = 1.
#     num = img1.size(0)
#     m1 = img1.view(num, -1)  # Flatten
#     m2 = img2.view(num, -1)  # Flatten
#     intersection = (m1 * m2).sum()
#
#     return 1 - (2. * intersection + smooth) / (m1.sum() + m2.sum() + smooth)


# measures the difference in mean intensity between two images
def relative_diff_loss(img1, img2):
    # compute the mean intensity of each image
    mean_intensity_img1 = torch.mean(img1)
    mean_intensity_img2 = torch.mean(img2)
    # print("mean_intensity_img1")
    # print(mean_intensity_img1)
    # print("mean_intensity_img2")
    # print(mean_intensity_img2)

    # compute the relative difference
    epsilon = 1e-10  # guard against division by zero
    relative_diff = abs((mean_intensity_img1 - mean_intensity_img2) / (mean_intensity_img1 + epsilon))

    return relative_diff


# mutual information (MI) loss
# def mutual_information_loss(img1, img2):
#     # compute the entropies of X and Y
#     entropy_img1 = -torch.mean(torch.sum(F.softmax(img1, dim=-1) * F.log_softmax(img1, dim=-1), dim=-1))
#     entropy_img2 = -torch.mean(torch.sum(F.softmax(img2, dim=-1) * F.log_softmax(img2, dim=-1), dim=-1))
#
#     # compute the joint entropy of X and Y
#     joint_entropy = -torch.mean(torch.sum(F.softmax(img1, dim=-1) * F.log_softmax(img2, dim=-1), dim=-1))
#
#     # compute the mutual information loss
#     mutual_information = entropy_img1 + entropy_img2 - joint_entropy
#
#     return mutual_information


# # cosine similarity
# def cosine_similarity(img1, img2):
#     # Flatten the tensors
#     img1_flat = img1.view(-1)
#     img2_flat = img2.view(-1)
#
#     # Calculate cosine similarity (the original comment had a typo, "Fine_similarity")
#     similarity = F.cosine_similarity(img1_flat, img2_flat, dim=0)
#
#     loss = torch.abs(similarity - 1)
#     return loss.item()  # Convert to Python float
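The cc helper should return a value near 1 for identical inputs and near 0 for independent noise, which makes it easy to sanity-check in isolation. A minimal, illustrative check:

# illustrative sanity check for cc (not part of the repo)
import torch
a = torch.rand(2, 4, 32, 32)
print(float(cc(a, a)))                         # ~1.0 for identical inputs
print(float(cc(a, torch.rand(2, 4, 32, 32))))  # near 0 for independent noise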