diff --git a/.gitignore b/.gitignore
deleted file mode 100644
index edf7231..0000000
--- a/.gitignore
+++ /dev/null
@@ -1,2 +0,0 @@
-.idea/
-status.md
diff --git a/componets/SEBlock.py b/componets/SEBlock.py
new file mode 100644
index 0000000..9cb612b
--- /dev/null
+++ b/componets/SEBlock.py
@@ -0,0 +1,37 @@
+'''------------- 1. SE (Squeeze-and-Excitation) block -----------------------------'''
+import torch
+from torch import nn
+
+
+# Global average pooling + 1x1 conv + ReLU + 1x1 conv + Sigmoid (the 1x1 convs are realized as nn.Linear on the pooled vector)
+class SE_Block(nn.Module):
+    def __init__(self, inchannel, ratio=16):
+        super(SE_Block, self).__init__()
+        # Global average pooling (Fsq, the squeeze step)
+        self.gap = nn.AdaptiveAvgPool2d((1, 1))
+        # Two fully connected layers (Fex, the excitation step)
+        self.fc = nn.Sequential(
+            nn.Linear(inchannel, inchannel // ratio, bias=False),  # c -> c/r
+            nn.ReLU(),
+            nn.Linear(inchannel // ratio, inchannel, bias=False),  # c/r -> c
+            nn.Sigmoid()
+        )
+
+    def forward(self, x):
+        # Read the batch size and number of channels
+        b, c, h, w = x.size()
+        # Fsq: pooling yields a b x c matrix
+        y = self.gap(x).view(b, c)
+        # Fex: the fully connected layers output a (b, c, 1, 1) tensor
+        y = self.fc(y).view(b, c, 1, 1)
+        # Fscale: multiply the original feature map x by the learned channel weights
+        return x * y.expand_as(x)
+
+if __name__ == '__main__':
+    input = torch.randn(1, 64, 32, 32)
+    seblock = SE_Block(64)
+    print(seblock)
+    output = seblock(input)
+    print(input.shape)
+    print(output.shape)
+
diff --git a/net.py b/net.py
index f23aac3..1f45825 100644
--- a/net.py
+++ b/net.py
@@ -182,7 +182,7 @@ class BaseFeatureExtraction(nn.Module):
         normal = self.norm1(x)  # 1 64 128 128
         token_mix = self.token_mixer(normal)  # 1 64 128 128
-        x = self.WTConv2d(x)
+
         x = (x
              + self.drop_path(
@@ -209,6 +209,7 @@ class BaseFeatureFusion(nn.Module):
         self.WTConv2d = WTConv2d(dim, dim)
         self.norm1 = LayerNorm(dim, 'WithBias')
+        # self.token_mixer = Pooling(kernel_size=pool_size)  # MSA in ViTs, MLP in MLP-based models; pooling is used here instead
         self.token_mixer = Pooling(kernel_size=pool_size)  # MSA in ViTs, MLP in MLP-based models; pooling is used here instead
         self.norm2 = LayerNorm(dim, 'WithBias')
         mlp_hidden_dim = int(dim * mlp_ratio)
@@ -271,15 +272,41 @@ class InvertedResidualBlock(nn.Module):
     def forward(self, x):
         return self.bottleneckBlock(x)
 
+
+class DepthwiseSeparableConvBlock(nn.Module):
+    def __init__(self, inp, oup, kernel_size=3, stride=1, padding=1):
+        super(DepthwiseSeparableConvBlock, self).__init__()
+        self.depthwise = nn.Conv2d(inp, inp, kernel_size, stride, padding, groups=inp, bias=False)
+        self.pointwise = nn.Conv2d(inp, oup, 1, bias=False)
+        self.bn = nn.BatchNorm2d(oup)
+        self.relu = nn.ReLU(inplace=True)
+
+    def forward(self, x):
+        x = self.depthwise(x)
+        x = self.pointwise(x)
+        x = self.bn(x)
+        x = self.relu(x)
+        return x
+
+
 class DetailNode(nn.Module):
-    def __init__(self):
+    def __init__(self, useBlock=0):
         super(DetailNode, self).__init__()
-        self.theta_phi = InvertedResidualBlock(inp=32, oup=32, expand_ratio=2)
-        self.theta_rho = InvertedResidualBlock(inp=32, oup=32, expand_ratio=2)
-        self.theta_eta = InvertedResidualBlock(inp=32, oup=32, expand_ratio=2)
+        if useBlock == 0:
+            self.theta_phi = DepthwiseSeparableConvBlock(inp=32, oup=32)
+            self.theta_rho = DepthwiseSeparableConvBlock(inp=32, oup=32)
+            self.theta_eta = DepthwiseSeparableConvBlock(inp=32, oup=32)
+        elif useBlock == 1:
+            self.theta_phi = InvertedResidualBlock(inp=32, oup=32, expand_ratio=2)
+            self.theta_rho = InvertedResidualBlock(inp=32, oup=32, expand_ratio=2)
+            self.theta_eta = InvertedResidualBlock(inp=32, oup=32, expand_ratio=2)
+        else:
+            self.theta_phi = InvertedResidualBlock(inp=32, oup=32, expand_ratio=2)
+            self.theta_rho = InvertedResidualBlock(inp=32, oup=32, expand_ratio=2)
+            self.theta_eta = InvertedResidualBlock(inp=32, oup=32, expand_ratio=2)
         self.shffleconv = nn.Conv2d(64, 64, kernel_size=1, stride=1, padding=0, bias=True)
@@ -298,37 +325,37 @@ class DetailNode(nn.Module):
 class DetailFeatureExtraction(nn.Module):
     def __init__(self, num_layers=3):
         super(DetailFeatureExtraction, self).__init__()
-        INNmodules = [DetailNode() for _ in range(num_layers)]
+        INNmodules = [DetailNode(useBlock=0) for _ in range(num_layers)]
         self.net = nn.Sequential(*INNmodules)
         self.enhancement_module = WTConv2d(32, 32)
 
     def forward(self, x):  # 1 64 128 128
         z1, z2 = x[:, :x.shape[1] // 2], x[:, x.shape[1] // 2:x.shape[1]]  # 1 32 128 128
-        # Enhance and add a residual connection
-        enhanced_z1 = self.enhancement_module(z1)
-        enhanced_z2 = self.enhancement_module(z2)
-        # Residual connection
-        z1 = z1 + enhanced_z1
-        z2 = z2 + enhanced_z2
+        # # Enhance and add a residual connection
+        # enhanced_z1 = self.enhancement_module(z1)
+        # enhanced_z2 = self.enhancement_module(z2)
+        # # Residual connection
+        # z1 = z1 + enhanced_z1
+        # z2 = z2 + enhanced_z2
         for layer in self.net:
             z1, z2 = layer(z1, z2)
         return torch.cat((z1, z2), dim=1)
 
-class DetailFeatureFusioin(nn.Module):
+class DetailFeatureFusion(nn.Module):
     def __init__(self, num_layers=3):
-        super(DetailFeatureFusioin, self).__init__()
+        super(DetailFeatureFusion, self).__init__()
         INNmodules = [DetailNode() for _ in range(num_layers)]
         self.net = nn.Sequential(*INNmodules)
         self.enhancement_module = WTConv2d(32, 32)
 
     def forward(self, x):  # 1 64 128 128
         z1, z2 = x[:, :x.shape[1] // 2], x[:, x.shape[1] // 2:x.shape[1]]  # 1 32 128 128
-        # Enhance and add a residual connection
-        enhanced_z1 = self.enhancement_module(z1)
-        enhanced_z2 = self.enhancement_module(z2)
-        # Residual connection
-        z1 = z1 + enhanced_z1
-        z2 = z2 + enhanced_z2
+        # # Enhance and add a residual connection
+        # enhanced_z1 = self.enhancement_module(z1)
+        # enhanced_z2 = self.enhancement_module(z2)
+        # # Residual connection
+        # z1 = z1 + enhanced_z1
+        # z2 = z2 + enhanced_z2
         for layer in self.net:
             z1, z2 = layer(z1, z2)
         return torch.cat((z1, z2), dim=1)
@@ -508,7 +535,8 @@ class BaseFeatureExtractionSAR(nn.Module):
         self.WTConv2d = WTConv2d(dim, dim)
         self.norm1 = LayerNorm(dim, 'WithBias')
-        self.token_mixer = Pooling(kernel_size=pool_size)  # MSA in ViTs, MLP in MLP-based models; pooling is used here instead
+        self.token_mixer = SCSA(dim=dim, head_num=8)
+        # self.token_mixer = Pooling(kernel_size=pool_size)  # MSA in ViTs, MLP in MLP-based models; pooling is used here instead
         self.norm2 = LayerNorm(dim, 'WithBias')
         mlp_hidden_dim = int(dim * mlp_ratio)
         self.poolmlp = PoolMlp(in_features=dim, hidden_features=mlp_hidden_dim,
@@ -554,18 +582,18 @@ class BaseFeatureExtractionSAR(nn.Module):
 class DetailFeatureExtractionSAR(nn.Module):
     def __init__(self, num_layers=3):
         super(DetailFeatureExtractionSAR, self).__init__()
-        INNmodules = [DetailNode() for _ in range(num_layers)]
+        INNmodules = [DetailNode(useBlock=1) for _ in range(num_layers)]
         self.net = nn.Sequential(*INNmodules)
         self.enhancement_module = WTConv2d(32, 32)
 
     def forward(self, x):  # 1 64 128 128
         z1, z2 = x[:, :x.shape[1] // 2], x[:, x.shape[1] // 2:x.shape[1]]  # 1 32 128 128
-        # Enhance and add a residual connection
-        enhanced_z1 = self.enhancement_module(z1)
-        enhanced_z2 = self.enhancement_module(z2)
-        # Residual connection
-        z1 = z1 + enhanced_z1
-        z2 = z2 + enhanced_z2
+        # # Enhance and add a residual connection
+        # enhanced_z1 = self.enhancement_module(z1)
+        # enhanced_z2 = self.enhancement_module(z2)
+        # # Residual connection
+        # z1 = z1 + enhanced_z1
+        # z2 = z2 + enhanced_z2
         for layer in self.net:
             z1, z2 = layer(z1, z2)
         return torch.cat((z1, z2), dim=1)
diff --git a/train.py b/train.py
index a5d8c86..b5f2363 100644
--- a/train.py
+++ b/train.py
@@ -87,7 +87,7 @@
 DIDF_Encoder = nn.DataParallel(Restormer_Encoder()).to(device)
 DIDF_Decoder = nn.DataParallel(Restormer_Decoder()).to(device)
 # BaseFuseLayer = nn.DataParallel(BaseFeatureExtraction(dim=64, num_heads=8)).to(device)
 BaseFuseLayer = nn.DataParallel(BaseFeatureFusion(dim=64)).to(device)
-DetailFuseLayer = nn.DataParallel(DetailFeatureFusioin(num_layers=1)).to(device)
+DetailFuseLayer = nn.DataParallel(DetailFeatureFusion(num_layers=1)).to(device)
 
 # optimizer, scheduler and loss function
 optimizer1 = torch.optim.Adam(
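
Reviewer note: the sketch below is not part of the patch. It is a minimal smoke test, assuming net.py from this branch (and its WTConv2d/SCSA dependencies) imports cleanly from the repo root and that DetailNode keeps the two-tensor forward signature used in the `z1, z2 = layer(z1, z2)` loops above. It exercises the new DepthwiseSeparableConvBlock, both DetailNode variants selected by useBlock, and the renamed DetailFeatureFusion, checking only output shapes; the script name is hypothetical.

# smoke_test_detail_blocks.py -- hypothetical helper script, not included in this diff
import torch

from net import (DepthwiseSeparableConvBlock, DetailNode,
                 DetailFeatureExtraction, DetailFeatureFusion)

if __name__ == '__main__':
    x64 = torch.randn(1, 64, 128, 128)   # matches the "1 64 128 128" shape noted in forward()
    x32 = torch.randn(1, 32, 128, 128)   # one 32-channel half

    # The depthwise-separable block maps 32 -> 32 channels and preserves spatial size.
    dsconv = DepthwiseSeparableConvBlock(inp=32, oup=32)
    assert dsconv(x32).shape == x32.shape

    # useBlock=0 -> DepthwiseSeparableConvBlock, useBlock=1 -> InvertedResidualBlock (per this patch).
    for use_block in (0, 1):
        node = DetailNode(useBlock=use_block)
        z1, z2 = node(x32, x32)
        assert z1.shape == x32.shape and z2.shape == x32.shape

    # Both stacks split a 64-channel input into two 32-channel halves and concatenate them back.
    for module in (DetailFeatureExtraction(num_layers=3), DetailFeatureFusion(num_layers=1)):
        assert module(x64).shape == x64.shape

    print('all shape checks passed')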
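
A second hypothetical sketch: the header comment in componets/SEBlock.py describes the excitation step as two 1x1 convolutions, while the code uses nn.Linear layers on the pooled vector. The snippet below (assuming the repo root is on PYTHONPATH; conv_fc is a name introduced here) copies the Linear weights into 1x1 convolutions and checks that both paths produce the same channel gate.

# Hypothetical equivalence check for SE_Block; not part of this diff.
import torch
from torch import nn

from componets.SEBlock import SE_Block

se = SE_Block(inchannel=64, ratio=16)
x = torch.randn(1, 64, 32, 32)

# Re-express the two nn.Linear layers as 1x1 convolutions with the same weights.
conv_fc = nn.Sequential(
    nn.Conv2d(64, 64 // 16, kernel_size=1, bias=False),
    nn.ReLU(),
    nn.Conv2d(64 // 16, 64, kernel_size=1, bias=False),
    nn.Sigmoid(),
)
with torch.no_grad():
    conv_fc[0].weight.copy_(se.fc[0].weight.view(64 // 16, 64, 1, 1))
    conv_fc[2].weight.copy_(se.fc[2].weight.view(64, 64 // 16, 1, 1))

pooled = se.gap(x)                                          # (1, 64, 1, 1)
gate_linear = se.fc(pooled.flatten(1)).view(1, 64, 1, 1)    # path used by SE_Block
gate_conv = conv_fc(pooled)                                 # 1x1-conv path from the header comment
print(torch.allclose(gate_linear, gate_conv, atol=1e-6))    # expected: True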