feat(components): 添加 DEConv 和 SEBlock 组件

- 新增 DEConv 组件,用于细节增强卷积
- 新增 SEBlock组件,用于通道注意力机制
- 更新 net.py 中的 DetailNode 结构
- 调整 train.py 中的模型初始化
This commit is contained in:
whai 2024-11-14 16:59:11 +08:00
parent c1eed72f24
commit 30bbfdf86e
4 changed files with 94 additions and 31 deletions

View File

@ -1,2 +0,0 @@
.idea/
status.md

37
componets/SEBlock.py Normal file
View File

@ -0,0 +1,37 @@
'''-------------一、SE模块-----------------------------'''
import torch
from torch import nn
# 全局平均池化+1*1卷积核+ReLu+1*1卷积核+Sigmoid
class SE_Block(nn.Module):
def __init__(self, inchannel, ratio=16):
super(SE_Block, self).__init__()
# 全局平均池化(Fsq操作)
self.gap = nn.AdaptiveAvgPool2d((1, 1))
# 两个全连接层(Fex操作)
self.fc = nn.Sequential(
nn.Linear(inchannel, inchannel // ratio, bias=False), # 从 c -> c/r
nn.ReLU(),
nn.Linear(inchannel // ratio, inchannel, bias=False), # 从 c/r -> c
nn.Sigmoid()
)
def forward(self, x):
# 读取批数据图片数量及通道数
b, c, h, w = x.size()
# Fsq操作经池化后输出b*c的矩阵
y = self.gap(x).view(b, c)
# Fex操作经全连接层输出bc11矩阵
y = self.fc(y).view(b, c, 1, 1)
# Fscale操作将得到的权重乘以原来的特征图x
return x * y.expand_as(x)
if __name__ == '__main__':
input = torch.randn(1, 64, 32, 32)
seblock = SE_Block(64)
print(seblock)
output = seblock(input)
print(input.shape)
print(output.shape)

78
net.py
View File

@ -182,7 +182,7 @@ class BaseFeatureExtraction(nn.Module):
normal = self.norm1(x) # 1 64 128 128 normal = self.norm1(x) # 1 64 128 128
token_mix = self.token_mixer(normal) # 1 64 128 128 token_mix = self.token_mixer(normal) # 1 64 128 128
x = self.WTConv2d(x)
x = (x + x = (x +
self.drop_path( self.drop_path(
@ -209,6 +209,7 @@ class BaseFeatureFusion(nn.Module):
self.WTConv2d = WTConv2d(dim, dim) self.WTConv2d = WTConv2d(dim, dim)
self.norm1 = LayerNorm(dim, 'WithBias') self.norm1 = LayerNorm(dim, 'WithBias')
# self.token_mixer = Pooling(kernel_size=pool_size) # vits是msaMLPs是mlp这个用pool来替代
self.token_mixer = Pooling(kernel_size=pool_size) # vits是msaMLPs是mlp这个用pool来替代 self.token_mixer = Pooling(kernel_size=pool_size) # vits是msaMLPs是mlp这个用pool来替代
self.norm2 = LayerNorm(dim, 'WithBias') self.norm2 = LayerNorm(dim, 'WithBias')
mlp_hidden_dim = int(dim * mlp_ratio) mlp_hidden_dim = int(dim * mlp_ratio)
@ -271,12 +272,38 @@ class InvertedResidualBlock(nn.Module):
def forward(self, x): def forward(self, x):
return self.bottleneckBlock(x) return self.bottleneckBlock(x)
class DepthwiseSeparableConvBlock(nn.Module):
def __init__(self, inp, oup, kernel_size=3, stride=1, padding=1):
super(DepthwiseSeparableConvBlock, self).__init__()
self.depthwise = nn.Conv2d(inp, inp, kernel_size, stride, padding, groups=inp, bias=False)
self.pointwise = nn.Conv2d(inp, oup, 1, bias=False)
self.bn = nn.BatchNorm2d(oup)
self.relu = nn.ReLU(inplace=True)
def forward(self, x):
x = self.depthwise(x)
x = self.pointwise(x)
x = self.bn(x)
x = self.relu(x)
return x
class DetailNode(nn.Module): class DetailNode(nn.Module):
# <img src = "http://42.192.130.83:9000/picgo/imgs/小绿鲸英文文献阅读器_ELTITYqm5G.png" / > ' # <img src = "http://42.192.130.83:9000/picgo/imgs/小绿鲸英文文献阅读器_ELTITYqm5G.png" / > '
def __init__(self): def __init__(self,useBlock=0):
super(DetailNode, self).__init__() super(DetailNode, self).__init__()
if useBlock == 0:
self.theta_phi = DepthwiseSeparableConvBlock(inp=32, oup=32)
self.theta_rho = DepthwiseSeparableConvBlock(inp=32, oup=32)
self.theta_eta = DepthwiseSeparableConvBlock(inp=32, oup=32)
elif useBlock == 1:
self.theta_phi = InvertedResidualBlock(inp=32, oup=32, expand_ratio=2)
self.theta_rho = InvertedResidualBlock(inp=32, oup=32, expand_ratio=2)
self.theta_eta = InvertedResidualBlock(inp=32, oup=32, expand_ratio=2)
else:
self.theta_phi = InvertedResidualBlock(inp=32, oup=32, expand_ratio=2) self.theta_phi = InvertedResidualBlock(inp=32, oup=32, expand_ratio=2)
self.theta_rho = InvertedResidualBlock(inp=32, oup=32, expand_ratio=2) self.theta_rho = InvertedResidualBlock(inp=32, oup=32, expand_ratio=2)
self.theta_eta = InvertedResidualBlock(inp=32, oup=32, expand_ratio=2) self.theta_eta = InvertedResidualBlock(inp=32, oup=32, expand_ratio=2)
@ -298,37 +325,37 @@ class DetailNode(nn.Module):
class DetailFeatureExtraction(nn.Module): class DetailFeatureExtraction(nn.Module):
def __init__(self, num_layers=3): def __init__(self, num_layers=3):
super(DetailFeatureExtraction, self).__init__() super(DetailFeatureExtraction, self).__init__()
INNmodules = [DetailNode() for _ in range(num_layers)] INNmodules = [DetailNode(use) for _ in range(num_layers)]
self.net = nn.Sequential(*INNmodules) self.net = nn.Sequential(*INNmodules)
self.enhancement_module = WTConv2d(32, 32) self.enhancement_module = WTConv2d(32, 32)
def forward(self, x): # 1 64 128 128 def forward(self, x): # 1 64 128 128
z1, z2 = x[:, :x.shape[1] // 2], x[:, x.shape[1] // 2:x.shape[1]] # 1 32 128 128 z1, z2 = x[:, :x.shape[1] // 2], x[:, x.shape[1] // 2:x.shape[1]] # 1 32 128 128
# 增强并添加残差连接 # # 增强并添加残差连接
enhanced_z1 = self.enhancement_module(z1) # enhanced_z1 = self.enhancement_module(z1)
enhanced_z2 = self.enhancement_module(z2) # enhanced_z2 = self.enhancement_module(z2)
# 残差连接 # # 残差连接
z1 = z1 + enhanced_z1 # z1 = z1 + enhanced_z1
z2 = z2 + enhanced_z2 # z2 = z2 + enhanced_z2
for layer in self.net: for layer in self.net:
z1, z2 = layer(z1, z2) z1, z2 = layer(z1, z2)
return torch.cat((z1, z2), dim=1) return torch.cat((z1, z2), dim=1)
class DetailFeatureFusioin(nn.Module): class DetailFeatureFusion(nn.Module):
def __init__(self, num_layers=3): def __init__(self, num_layers=3):
super(DetailFeatureFusioin, self).__init__() super(DetailFeatureFusion, self).__init__()
INNmodules = [DetailNode() for _ in range(num_layers)] INNmodules = [DetailNode() for _ in range(num_layers)]
self.net = nn.Sequential(*INNmodules) self.net = nn.Sequential(*INNmodules)
self.enhancement_module = WTConv2d(32, 32) self.enhancement_module = WTConv2d(32, 32)
def forward(self, x): # 1 64 128 128 def forward(self, x): # 1 64 128 128
z1, z2 = x[:, :x.shape[1] // 2], x[:, x.shape[1] // 2:x.shape[1]] # 1 32 128 128 z1, z2 = x[:, :x.shape[1] // 2], x[:, x.shape[1] // 2:x.shape[1]] # 1 32 128 128
# 增强并添加残差连接 # # 增强并添加残差连接
enhanced_z1 = self.enhancement_module(z1) # enhanced_z1 = self.enhancement_module(z1)
enhanced_z2 = self.enhancement_module(z2) # enhanced_z2 = self.enhancement_module(z2)
# 残差连接 # # 残差连接
z1 = z1 + enhanced_z1 # z1 = z1 + enhanced_z1
z2 = z2 + enhanced_z2 # z2 = z2 + enhanced_z2
for layer in self.net: for layer in self.net:
z1, z2 = layer(z1, z2) z1, z2 = layer(z1, z2)
return torch.cat((z1, z2), dim=1) return torch.cat((z1, z2), dim=1)
@ -508,7 +535,8 @@ class BaseFeatureExtractionSAR(nn.Module):
self.WTConv2d = WTConv2d(dim, dim) self.WTConv2d = WTConv2d(dim, dim)
self.norm1 = LayerNorm(dim, 'WithBias') self.norm1 = LayerNorm(dim, 'WithBias')
self.token_mixer = Pooling(kernel_size=pool_size) # vits是msaMLPs是mlp这个用pool来替代 self.token_mixer = SCSA(dim=dim, head_num=8)
# self.token_mixer = Pooling(kernel_size=pool_size) # vits是msaMLPs是mlp这个用pool来替代
self.norm2 = LayerNorm(dim, 'WithBias') self.norm2 = LayerNorm(dim, 'WithBias')
mlp_hidden_dim = int(dim * mlp_ratio) mlp_hidden_dim = int(dim * mlp_ratio)
self.poolmlp = PoolMlp(in_features=dim, hidden_features=mlp_hidden_dim, self.poolmlp = PoolMlp(in_features=dim, hidden_features=mlp_hidden_dim,
@ -554,18 +582,18 @@ class BaseFeatureExtractionSAR(nn.Module):
class DetailFeatureExtractionSAR(nn.Module): class DetailFeatureExtractionSAR(nn.Module):
def __init__(self, num_layers=3): def __init__(self, num_layers=3):
super(DetailFeatureExtractionSAR, self).__init__() super(DetailFeatureExtractionSAR, self).__init__()
INNmodules = [DetailNode() for _ in range(num_layers)] INNmodules = [DetailNode(useBlock=1) for _ in range(num_layers)]
self.net = nn.Sequential(*INNmodules) self.net = nn.Sequential(*INNmodules)
self.enhancement_module = WTConv2d(32, 32) self.enhancement_module = WTConv2d(32, 32)
def forward(self, x): # 1 64 128 128 def forward(self, x): # 1 64 128 128
z1, z2 = x[:, :x.shape[1] // 2], x[:, x.shape[1] // 2:x.shape[1]] # 1 32 128 128 z1, z2 = x[:, :x.shape[1] // 2], x[:, x.shape[1] // 2:x.shape[1]] # 1 32 128 128
# 增强并添加残差连接 # # 增强并添加残差连接
enhanced_z1 = self.enhancement_module(z1) # enhanced_z1 = self.enhancement_module(z1)
enhanced_z2 = self.enhancement_module(z2) # enhanced_z2 = self.enhancement_module(z2)
# 残差连接 # # 残差连接
z1 = z1 + enhanced_z1 # z1 = z1 + enhanced_z1
z2 = z2 + enhanced_z2 # z2 = z2 + enhanced_z2
for layer in self.net: for layer in self.net:
z1, z2 = layer(z1, z2) z1, z2 = layer(z1, z2)
return torch.cat((z1, z2), dim=1) return torch.cat((z1, z2), dim=1)

View File

@ -87,7 +87,7 @@ DIDF_Encoder = nn.DataParallel(Restormer_Encoder()).to(device)
DIDF_Decoder = nn.DataParallel(Restormer_Decoder()).to(device) DIDF_Decoder = nn.DataParallel(Restormer_Decoder()).to(device)
# BaseFuseLayer = nn.DataParallel(BaseFeatureExtraction(dim=64, num_heads=8)).to(device) # BaseFuseLayer = nn.DataParallel(BaseFeatureExtraction(dim=64, num_heads=8)).to(device)
BaseFuseLayer = nn.DataParallel(BaseFeatureFusion(dim=64)).to(device) BaseFuseLayer = nn.DataParallel(BaseFeatureFusion(dim=64)).to(device)
DetailFuseLayer = nn.DataParallel(DetailFeatureFusioin(num_layers=1)).to(device) DetailFuseLayer = nn.DataParallel(DetailFeatureFusion(num_layers=1)).to(device)
# optimizer, scheduler and loss function # optimizer, scheduler and loss function
optimizer1 = torch.optim.Adam( optimizer1 = torch.optim.Adam(