diff --git a/net.py b/net.py index 1791517..e404ed0 100644 --- a/net.py +++ b/net.py @@ -6,10 +6,7 @@ import torch.utils.checkpoint as checkpoint from timm.models.layers import DropPath, to_2tuple, trunc_normal_ from einops import rearrange -from componets.WTConvCV2 import WTConv2d - -# 以一定概率随机丢弃输入张量中的路径,用于正则化模型 def drop_path(x, drop_prob: float = 0., training: bool = False): if drop_prob == 0. or not training: return x @@ -35,9 +32,6 @@ class DropPath(nn.Module): def forward(self, x): return drop_path(x, self.drop_prob, self.training) - - -# 改点,使用Pooling替换AttentionBase class Pooling(nn.Module): def __init__(self, kernel_size=3): super().__init__() @@ -50,8 +44,8 @@ class Pooling(nn.Module): class PoolMlp(nn.Module): """ - 实现基于1x1卷积的MLP模块。 - 输入:形状为[B, C, H, W]的张量。 + Implementation of MLP with 1*1 convolutions. + Input: tensor with shape [B, C, H, W] """ def __init__(self, @@ -61,17 +55,6 @@ class PoolMlp(nn.Module): act_layer=nn.GELU, bias=False, drop=0.): - """ - 初始化PoolMlp模块。 - - 参数: - in_features (int): 输入特征的数量。 - hidden_features (int, 可选): 隐藏层特征的数量。默认为None,设置为与in_features相同。 - out_features (int, 可选): 输出特征的数量。默认为None,设置为与in_features相同。 - act_layer (nn.Module, 可选): 使用的激活层。默认为nn.GELU。 - bias (bool, 可选): 是否在卷积层中包含偏置项。默认为False。 - drop (float, 可选): Dropout比率。默认为0。 - """ super().__init__() out_features = out_features or in_features hidden_features = hidden_features or in_features @@ -81,15 +64,6 @@ class PoolMlp(nn.Module): self.drop = nn.Dropout(drop) def forward(self, x): - """ - 通过PoolMlp模块的前向传播。 - - 参数: - x (torch.Tensor): 形状为[B, C, H, W]的输入张量。 - - 返回: - torch.Tensor: 形状为[B, C, H, W]的输出张量。 - """ x = self.fc1(x) # (B, C, H, W) --> (B, C, H, W) x = self.act(x) x = self.drop(x) @@ -97,55 +71,6 @@ class PoolMlp(nn.Module): x = self.drop(x) return x - -# class BaseFeatureExtraction1(nn.Module): -# def __init__(self, dim, pool_size=3, mlp_ratio=4., -# act_layer=nn.GELU, -# # norm_layer=nn.LayerNorm, -# drop=0., drop_path=0., -# use_layer_scale=True, layer_scale_init_value=1e-5): -# -# super().__init__() -# -# self.norm1 = LayerNorm(dim, 'WithBias') -# self.token_mixer = Pooling(kernel_size=pool_size) # vits是msa,MLPs是mlp,这个用pool来替代 -# self.norm2 = LayerNorm(dim, 'WithBias') -# mlp_hidden_dim = int(dim * mlp_ratio) -# self.poolmlp = PoolMlp(in_features=dim, hidden_features=mlp_hidden_dim, -# act_layer=act_layer, drop=drop) -# -# # The following two techniques are useful to train deep PoolFormers. -# self.drop_path = DropPath(drop_path) if drop_path > 0. \ -# else nn.Identity() -# self.use_layer_scale = use_layer_scale -# -# if use_layer_scale: -# self.layer_scale_1 = nn.Parameter( -# torch.ones(dim, dtype=torch.float32) * layer_scale_init_value) -# -# self.layer_scale_2 = nn.Parameter( -# torch.ones(dim, dtype=torch.float32) * layer_scale_init_value) -# -# def forward(self, x): # 1 64 128 128 -# if self.use_layer_scale: -# # self.layer_scale_1(64,) -# tmp1 = self.layer_scale_1.unsqueeze(-1).unsqueeze(-1) # 64 1 1 -# normal = self.norm1(x) # 1 64 128 128 -# token_mix = self.token_mixer(normal) # 1 64 128 128 -# x = (x + -# self.drop_path( -# tmp1 * token_mix -# ) -# # 该表达式将 self.layer_scale_1 这个一维张量(或变量)在维度末尾添加两个新的维度,使其从一维变为三维。这通常用于使其能够与三维的特征图进行广播操作,如元素相乘。具体用途可能包括调整卷积层或注意力机制中的权重。 -# ) -# x = x + self.drop_path( -# self.layer_scale_2.unsqueeze(-1).unsqueeze(-1) -# * self.poolmlp(self.norm2(x))) -# else: -# x = x + self.drop_path(self.token_mixer(self.norm1(x))) # 匹配cddfuse -# x = x + self.drop_path(self.poolmlp(self.norm2(x))) -# return x - class BaseFeatureExtraction(nn.Module): def __init__(self, dim, pool_size=3, mlp_ratio=4., act_layer=nn.GELU, @@ -155,7 +80,6 @@ class BaseFeatureExtraction(nn.Module): super().__init__() - self.norm1 = LayerNorm(dim, 'WithBias') self.token_mixer = Pooling(kernel_size=pool_size) # vits是msa,MLPs是mlp,这个用pool来替代 self.norm2 = LayerNorm(dim, 'WithBias') @@ -175,81 +99,19 @@ class BaseFeatureExtraction(nn.Module): self.layer_scale_2 = nn.Parameter( torch.ones(dim, dtype=torch.float32) * layer_scale_init_value) - def forward(self, x): # 1 64 128 128 + def forward(self, x): if self.use_layer_scale: - # self.layer_scale_1(64,) - tmp1 = self.layer_scale_1.unsqueeze(-1).unsqueeze(-1) # 64 1 1 - normal = self.norm1(x) # 1 64 128 128 - token_mix = self.token_mixer(normal) # 1 64 128 128 - - - - x = (x + - self.drop_path( - tmp1 * token_mix - ) - # 该表达式将 self.layer_scale_1 这个一维张量(或变量)在维度末尾添加两个新的维度,使其从一维变为三维。这通常用于使其能够与三维的特征图进行广播操作,如元素相乘。具体用途可能包括调整卷积层或注意力机制中的权重。 - ) + x = x + self.drop_path( + self.layer_scale_1.unsqueeze(-1).unsqueeze(-1) + * self.token_mixer(self.norm1(x))) x = x + self.drop_path( self.layer_scale_2.unsqueeze(-1).unsqueeze(-1) * self.poolmlp(self.norm2(x))) else: - x = x + self.drop_path(self.token_mixer(self.norm1(x))) # 匹配cddfuse + x = x + self.drop_path(self.token_mixer(self.norm1(x))) x = x + self.drop_path(self.poolmlp(self.norm2(x))) return x -class BaseFeatureFusion(nn.Module): - def __init__(self, dim, pool_size=3, mlp_ratio=4., - act_layer=nn.GELU, - # norm_layer=nn.LayerNorm, - drop=0., drop_path=0., - use_layer_scale=True, layer_scale_init_value=1e-5): - - super().__init__() - - - self.norm1 = LayerNorm(dim, 'WithBias') - # self.token_mixer = Pooling(kernel_size=pool_size) # vits是msa,MLPs是mlp,这个用pool来替代 - self.token_mixer = Pooling(kernel_size=pool_size) # vits是msa,MLPs是mlp,这个用pool来替代 - self.norm2 = LayerNorm(dim, 'WithBias') - mlp_hidden_dim = int(dim * mlp_ratio) - self.poolmlp = PoolMlp(in_features=dim, hidden_features=mlp_hidden_dim, - act_layer=act_layer, drop=drop) - - # The following two techniques are useful to train deep PoolFormers. - self.drop_path = DropPath(drop_path) if drop_path > 0. \ - else nn.Identity() - self.use_layer_scale = use_layer_scale - - if use_layer_scale: - self.layer_scale_1 = nn.Parameter( - torch.ones(dim, dtype=torch.float32) * layer_scale_init_value) - - self.layer_scale_2 = nn.Parameter( - torch.ones(dim, dtype=torch.float32) * layer_scale_init_value) - - def forward(self, x): # 1 64 128 128 - if self.use_layer_scale: - # self.layer_scale_1(64,) - tmp1 = self.layer_scale_1.unsqueeze(-1).unsqueeze(-1) # 64 1 1 - normal = self.norm1(x) # 1 64 128 128 - token_mix = self.token_mixer(normal) # 1 64 128 128 - - - - x = (x + - self.drop_path( - tmp1 * token_mix - ) - # 该表达式将 self.layer_scale_1 这个一维张量(或变量)在维度末尾添加两个新的维度,使其从一维变为三维。这通常用于使其能够与三维的特征图进行广播操作,如元素相乘。具体用途可能包括调整卷积层或注意力机制中的权重。 - ) - x = x + self.drop_path( - self.layer_scale_2.unsqueeze(-1).unsqueeze(-1) - * self.poolmlp(self.norm2(x))) - else: - x = x + self.drop_path(self.token_mixer(self.norm1(x))) # 匹配cddfuse - x = x + self.drop_path(self.poolmlp(self.norm2(x))) - return x class InvertedResidualBlock(nn.Module): def __init__(self, inp, oup, expand_ratio): @@ -269,44 +131,18 @@ class InvertedResidualBlock(nn.Module): nn.Conv2d(hidden_dim, oup, 1, bias=False), # nn.BatchNorm2d(oup), ) + def forward(self, x): return self.bottleneckBlock(x) -class DepthwiseSeparableConvBlock(nn.Module): - def __init__(self, inp, oup, kernel_size=3, stride=1, padding=1): - super(DepthwiseSeparableConvBlock, self).__init__() - self.depthwise = nn.Conv2d(inp, inp, kernel_size, stride, padding, groups=inp, bias=False) - self.pointwise = nn.Conv2d(inp, oup, 1, bias=False) - self.bn = nn.BatchNorm2d(oup) - self.relu = nn.ReLU(inplace=True) - - def forward(self, x): - x = self.depthwise(x) - x = self.pointwise(x) - x = self.bn(x) - x = self.relu(x) - return x - - class DetailNode(nn.Module): - - # ' - def __init__(self,useBlock=0): + def __init__(self): super(DetailNode, self).__init__() - if useBlock == 0: - self.theta_phi = DepthwiseSeparableConvBlock(inp=32, oup=32) - self.theta_rho = DepthwiseSeparableConvBlock(inp=32, oup=32) - self.theta_eta = DepthwiseSeparableConvBlock(inp=32, oup=32) - elif useBlock == 1: - self.theta_phi = InvertedResidualBlock(inp=32, oup=32, expand_ratio=2) - self.theta_rho = InvertedResidualBlock(inp=32, oup=32, expand_ratio=2) - self.theta_eta = InvertedResidualBlock(inp=32, oup=32, expand_ratio=2) - else: - self.theta_phi = InvertedResidualBlock(inp=32, oup=32, expand_ratio=2) - self.theta_rho = InvertedResidualBlock(inp=32, oup=32, expand_ratio=2) - self.theta_eta = InvertedResidualBlock(inp=32, oup=32, expand_ratio=2) + self.theta_phi = InvertedResidualBlock(inp=32, oup=32, expand_ratio=2) + self.theta_rho = InvertedResidualBlock(inp=32, oup=32, expand_ratio=2) + self.theta_eta = InvertedResidualBlock(inp=32, oup=32, expand_ratio=2) self.shffleconv = nn.Conv2d(64, 64, kernel_size=1, stride=1, padding=0, bias=True) @@ -325,43 +161,16 @@ class DetailNode(nn.Module): class DetailFeatureExtraction(nn.Module): def __init__(self, num_layers=3): super(DetailFeatureExtraction, self).__init__() - INNmodules = [DetailNode(use) for _ in range(num_layers)] - self.net = nn.Sequential(*INNmodules) - # self.enhancement_module = WTConv2d(32, 32) - - def forward(self, x): # 1 64 128 128 - z1, z2 = x[:, :x.shape[1] // 2], x[:, x.shape[1] // 2:x.shape[1]] # 1 32 128 128 - # # 增强并添加残差连接 - # enhanced_z1 = self.enhancement_module(z1) - # enhanced_z2 = self.enhancement_module(z2) - # # 残差连接 - # z1 = z1 + enhanced_z1 - # z2 = z2 + enhanced_z2 - for layer in self.net: - z1, z2 = layer(z1, z2) - return torch.cat((z1, z2), dim=1) - -class DetailFeatureFusion(nn.Module): - def __init__(self, num_layers=3): - super(DetailFeatureFusion, self).__init__() INNmodules = [DetailNode() for _ in range(num_layers)] self.net = nn.Sequential(*INNmodules) - # self.enhancement_module = WTConv2d(32, 32) - def forward(self, x): # 1 64 128 128 - z1, z2 = x[:, :x.shape[1] // 2], x[:, x.shape[1] // 2:x.shape[1]] # 1 32 128 128 - # # 增强并添加残差连接 - # enhanced_z1 = self.enhancement_module(z1) - # enhanced_z2 = self.enhancement_module(z2) - # # 残差连接 - # z1 = z1 + enhanced_z1 - # z2 = z2 + enhanced_z2 + def forward(self, x): + z1, z2 = x[:, :x.shape[1] // 2], x[:, x.shape[1] // 2:x.shape[1]] for layer in self.net: z1, z2 = layer(z1, z2) return torch.cat((z1, z2), dim=1) - # ============================================================================= # ============================================================================= @@ -524,80 +333,6 @@ class OverlapPatchEmbed(nn.Module): return x -class BaseFeatureExtractionSAR(nn.Module): - def __init__(self, dim, pool_size=3, mlp_ratio=4., - act_layer=nn.GELU, - # norm_layer=nn.LayerNorm, - drop=0., drop_path=0., - use_layer_scale=True, layer_scale_init_value=1e-5): - - super().__init__() - - self.WTConv2d = WTConv2d(dim, dim) - self.norm1 = LayerNorm(dim, 'WithBias') - self.token_mixer = SCSA(dim=dim, head_num=8) - # self.token_mixer = Pooling(kernel_size=pool_size) # vits是msa,MLPs是mlp,这个用pool来替代 - self.norm2 = LayerNorm(dim, 'WithBias') - mlp_hidden_dim = int(dim * mlp_ratio) - self.poolmlp = PoolMlp(in_features=dim, hidden_features=mlp_hidden_dim, - act_layer=act_layer, drop=drop) - - # The following two techniques are useful to train deep PoolFormers. - self.drop_path = DropPath(drop_path) if drop_path > 0. \ - else nn.Identity() - self.use_layer_scale = use_layer_scale - - if use_layer_scale: - self.layer_scale_1 = nn.Parameter( - torch.ones(dim, dtype=torch.float32) * layer_scale_init_value) - - self.layer_scale_2 = nn.Parameter( - torch.ones(dim, dtype=torch.float32) * layer_scale_init_value) - - def forward(self, x): # 1 64 128 128 - if self.use_layer_scale: - # self.layer_scale_1(64,) - tmp1 = self.layer_scale_1.unsqueeze(-1).unsqueeze(-1) # 64 1 1 - normal = self.norm1(x) # 1 64 128 128 - token_mix = self.token_mixer(normal) # 1 64 128 128 - - x = (x + - self.drop_path( - tmp1 * token_mix - ) - # 该表达式将 self.layer_scale_1 这个一维张量(或变量)在维度末尾添加两个新的维度,使其从一维变为三维。这通常用于使其能够与三维的特征图进行广播操作,如元素相乘。具体用途可能包括调整卷积层或注意力机制中的权重。 - ) - x = x + self.drop_path( - self.layer_scale_2.unsqueeze(-1).unsqueeze(-1) - * self.poolmlp(self.norm2(x))) - else: - x = x + self.drop_path(self.token_mixer(self.norm1(x))) # 匹配cddfuse - x = x + self.drop_path(self.poolmlp(self.norm2(x))) - return x - - - -class DetailFeatureExtractionSAR(nn.Module): - def __init__(self, num_layers=3): - super(DetailFeatureExtractionSAR, self).__init__() - INNmodules = [DetailNode(useBlock=1) for _ in range(num_layers)] - self.net = nn.Sequential(*INNmodules) - # self.enhancement_module = WTConv2d(32, 32) - - def forward(self, x): # 1 64 128 128 - z1, z2 = x[:, :x.shape[1] // 2], x[:, x.shape[1] // 2:x.shape[1]] # 1 32 128 128 - # # 增强并添加残差连接 - # enhanced_z1 = self.enhancement_module(z1) - # enhanced_z2 = self.enhancement_module(z2) - # # 残差连接 - # z1 = z1 + enhanced_z1 - # z2 = z2 + enhanced_z2 - for layer in self.net: - z1, z2 = layer(z1, z2) - return torch.cat((z1, z2), dim=1) - - - class Restormer_Encoder(nn.Module): def __init__(self, inp_channels=1, @@ -611,30 +346,21 @@ class Restormer_Encoder(nn.Module): ): super(Restormer_Encoder, self).__init__() - # 区分 - self.patch_embed = OverlapPatchEmbed(inp_channels, dim) + self.encoder_level1 = nn.Sequential( *[TransformerBlock(dim=dim, num_heads=heads[0], ffn_expansion_factor=ffn_expansion_factor, bias=bias, LayerNorm_type=LayerNorm_type) for i in range(num_blocks[0])]) self.baseFeature = BaseFeatureExtraction(dim=dim) + self.detailFeature = DetailFeatureExtraction() - self.baseFeature_sar = BaseFeatureExtractionSAR(dim=dim) - self.detailFeature_sar = DetailFeatureExtractionSAR() - - def forward(self, inp_img,is_sar = False): + def forward(self, inp_img): inp_enc_level1 = self.patch_embed(inp_img) out_enc_level1 = self.encoder_level1(inp_enc_level1) - if is_sar: - base_feature = self.baseFeature_sar(out_enc_level1) # 1 64 128 128 - detail_feature = self.detailFeature_sar(out_enc_level1) # 1 64 128 128 - return base_feature, detail_feature, out_enc_level1 # 1 64 128 128 - - else: - base_feature = self.baseFeature(out_enc_level1) # 1 64 128 128 - detail_feature = self.detailFeature(out_enc_level1) # 1 64 128 128 - return base_feature, detail_feature, out_enc_level1 # 1 64 128 128 + base_feature = self.baseFeature(out_enc_level1) + detail_feature = self.detailFeature(out_enc_level1) + return base_feature, detail_feature, out_enc_level1 class Restormer_Decoder(nn.Module): @@ -651,7 +377,8 @@ class Restormer_Decoder(nn.Module): super(Restormer_Decoder, self).__init__() self.reduce_channel = nn.Conv2d(int(dim * 2), int(dim), kernel_size=1, bias=bias) - self.encoder_level2 = nn.Sequential(*[TransformerBlock(dim=dim, num_heads=heads[1], ffn_expansion_factor=ffn_expansion_factor, + self.encoder_level2 = nn.Sequential( + *[TransformerBlock(dim=dim, num_heads=heads[1], ffn_expansion_factor=ffn_expansion_factor, bias=bias, LayerNorm_type=LayerNorm_type) for i in range(num_blocks[1])]) self.output = nn.Sequential( nn.Conv2d(int(dim), int(dim) // 2, kernel_size=3, @@ -678,5 +405,3 @@ if __name__ == '__main__': window_size = 8 modelE = Restormer_Encoder().cuda() modelD = Restormer_Decoder().cuda() - print(modelE) - print(modelD) diff --git a/train.py b/train.py index b5f2363..5bf94a8 100644 --- a/train.py +++ b/train.py @@ -6,8 +6,7 @@ Import packages ------------------------------------------------------------------------------ ''' -from net import Restormer_Encoder, Restormer_Decoder, BaseFeatureExtraction, DetailFeatureExtraction, BaseFeatureFusion, \ - DetailFeatureFusioin +from net import Restormer_Encoder, Restormer_Decoder, BaseFeatureExtraction, DetailFeatureExtraction from utils.dataset import H5Dataset import os os.environ['KMP_DUPLICATE_LIB_OK'] = 'True' @@ -86,8 +85,8 @@ device = 'cuda' if torch.cuda.is_available() else 'cpu' DIDF_Encoder = nn.DataParallel(Restormer_Encoder()).to(device) DIDF_Decoder = nn.DataParallel(Restormer_Decoder()).to(device) # BaseFuseLayer = nn.DataParallel(BaseFeatureExtraction(dim=64, num_heads=8)).to(device) -BaseFuseLayer = nn.DataParallel(BaseFeatureFusion(dim=64)).to(device) -DetailFuseLayer = nn.DataParallel(DetailFeatureFusion(num_layers=1)).to(device) +BaseFuseLayer = nn.DataParallel(BaseFeatureExtraction(dim=64)).to(device) +DetailFuseLayer = nn.DataParallel(DetailFeatureExtraction(num_layers=1)).to(device) # optimizer, scheduler and loss function optimizer1 = torch.optim.Adam( @@ -150,7 +149,7 @@ for epoch in range(num_epochs): if epoch < epoch_gap: #Phase I feature_V_B, feature_V_D, _ = DIDF_Encoder(data_VIS) - feature_I_B, feature_I_D, _ = DIDF_Encoder(data_IR,is_sar = True) + feature_I_B, feature_I_D, _ = DIDF_Encoder(data_IR) data_VIS_hat, _ = DIDF_Decoder(data_VIS, feature_V_B, feature_V_D) data_IR_hat, _ = DIDF_Decoder(data_IR, feature_I_B, feature_I_D) @@ -187,7 +186,7 @@ for epoch in range(num_epochs): optimizer2.step() else: #Phase II feature_V_B, feature_V_D, feature_V = DIDF_Encoder(data_VIS) - feature_I_B, feature_I_D, feature_I = DIDF_Encoder(data_IR,is_sar = True) + feature_I_B, feature_I_D, feature_I = DIDF_Encoder(data_IR) feature_F_B = BaseFuseLayer(feature_I_B+feature_V_B) feature_F_D = DetailFuseLayer(feature_I_D+feature_V_D) data_Fuse, feature_F = DIDF_Decoder(data_VIS, feature_F_B, feature_F_D) @@ -223,20 +222,18 @@ for epoch in range(num_epochs): time_left = datetime.timedelta(seconds=batches_left * (time.time() - prev_time)) epoch_time = time.time() - prev_time prev_time = time.time() - sys.stdout.write( - "\r[Epoch %d/%d] [Batch %d/%d] [loss: %f] ETA: %.10s" - % ( - epoch, - num_epochs, - i, - len(loader['train']), - loss.item(), - time_left, - ) + "\r[Epoch %d/%d] [Batch %d/%d] [loss: %f] ETA: %.10s" + % ( + epoch, + num_epochs, + i, + len(loader['train']), + loss.item(), + time_left, + ) ) - # adjust the learning rate scheduler1.step()