feat(components): add DEConv and SEBlock components

- Add a DEConv component for detail-enhanced convolution
- Add an SEBlock component for channel attention
- Update the DetailNode structure in net.py
- Adjust model initialization in train.py
This commit is contained in:
parent c1eed72f24
commit 30bbfdf86e
@@ -1,2 +0,0 @@
-.idea/
-status.md
37 componets/SEBlock.py Normal file
@@ -0,0 +1,37 @@
+'''------------- 1. SE module -----------------------------'''
+import torch
+from torch import nn
+
+
+# Global average pooling + 1x1 conv + ReLU + 1x1 conv + Sigmoid
+class SE_Block(nn.Module):
+    def __init__(self, inchannel, ratio=16):
+        super(SE_Block, self).__init__()
+        # Global average pooling (the Fsq squeeze operation)
+        self.gap = nn.AdaptiveAvgPool2d((1, 1))
+        # Two fully connected layers (the Fex excitation operation)
+        self.fc = nn.Sequential(
+            nn.Linear(inchannel, inchannel // ratio, bias=False),  # c -> c/r
+            nn.ReLU(),
+            nn.Linear(inchannel // ratio, inchannel, bias=False),  # c/r -> c
+            nn.Sigmoid()
+        )
+
+    def forward(self, x):
+        # Read the batch size and number of channels
+        b, c, h, w = x.size()
+        # Fsq: pooling produces a (b, c) matrix
+        y = self.gap(x).view(b, c)
+        # Fex: the fully connected layers produce a (b, c, 1, 1) tensor
+        y = self.fc(y).view(b, c, 1, 1)
+        # Fscale: multiply the learned channel weights back onto the feature map x
+        return x * y.expand_as(x)
+
+
+if __name__ == '__main__':
+    input = torch.randn(1, 64, 32, 32)
+    seblock = SE_Block(64)
+    print(seblock)
+    output = seblock(input)
+    print(input.shape)
+    print(output.shape)
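The commit message introduces SEBlock as a channel-attention component, but this diff does not show where it is attached in the network. As a purely illustrative sketch (the ConvWithSE wrapper, its channel sizes, and the import path are assumptions, not code from this repository), SE_Block is typically appended after a convolutional block so it reweights that block's output channels:

import torch
from torch import nn
from componets.SEBlock import SE_Block  # assumes componets/ is importable as a package


class ConvWithSE(nn.Module):
    # Hypothetical wrapper, for illustration only: conv -> BN -> ReLU -> channel reweighting
    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True),
        )
        self.se = SE_Block(out_ch)  # squeeze-and-excitation over the conv output channels

    def forward(self, x):
        return self.se(self.conv(x))  # shape is preserved: (B, out_ch, H, W)


# ConvWithSE(64, 64)(torch.randn(1, 64, 128, 128)).shape -> torch.Size([1, 64, 128, 128])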
78 net.py
@@ -182,7 +182,7 @@ class BaseFeatureExtraction(nn.Module):
         normal = self.norm1(x)  # 1 64 128 128
         token_mix = self.token_mixer(normal)  # 1 64 128 128

-        x = self.WTConv2d(x)

         x = (x +
              self.drop_path(
@@ -209,6 +209,7 @@ class BaseFeatureFusion(nn.Module):

         self.WTConv2d = WTConv2d(dim, dim)
         self.norm1 = LayerNorm(dim, 'WithBias')
+        # self.token_mixer = Pooling(kernel_size=pool_size)  # ViTs use MSA, MLP models use MLPs; here pooling serves as the token mixer
         self.token_mixer = Pooling(kernel_size=pool_size)  # ViTs use MSA, MLP models use MLPs; here pooling serves as the token mixer
         self.norm2 = LayerNorm(dim, 'WithBias')
         mlp_hidden_dim = int(dim * mlp_ratio)
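The Pooling token mixer referenced above is defined elsewhere in net.py and is not part of this diff. For readers unfamiliar with the idea behind the comment, a common PoolFormer-style formulation (an assumption about the general technique, not necessarily this repository's implementation) replaces self-attention with a parameter-free local average minus the identity:

from torch import nn


class PoolingSketch(nn.Module):
    # Assumed PoolFormer-style token mixer; not the repository's actual Pooling class
    def __init__(self, kernel_size=3):
        super().__init__()
        self.pool = nn.AvgPool2d(kernel_size, stride=1,
                                 padding=kernel_size // 2, count_include_pad=False)

    def forward(self, x):
        # mix each token with its spatial neighbourhood, then subtract the token itself
        return self.pool(x) - x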
@@ -271,12 +272,38 @@ class InvertedResidualBlock(nn.Module):
     def forward(self, x):
         return self.bottleneckBlock(x)


+class DepthwiseSeparableConvBlock(nn.Module):
+    def __init__(self, inp, oup, kernel_size=3, stride=1, padding=1):
+        super(DepthwiseSeparableConvBlock, self).__init__()
+        self.depthwise = nn.Conv2d(inp, inp, kernel_size, stride, padding, groups=inp, bias=False)
+        self.pointwise = nn.Conv2d(inp, oup, 1, bias=False)
+        self.bn = nn.BatchNorm2d(oup)
+        self.relu = nn.ReLU(inplace=True)
+
+    def forward(self, x):
+        x = self.depthwise(x)
+        x = self.pointwise(x)
+        x = self.bn(x)
+        x = self.relu(x)
+        return x
+
+
 class DetailNode(nn.Module):

     # <img src = "http://42.192.130.83:9000/picgo/imgs/小绿鲸英文文献阅读器_ELTITYqm5G.png" / > '
-    def __init__(self):
+    def __init__(self, useBlock=0):
         super(DetailNode, self).__init__()

-        self.theta_phi = InvertedResidualBlock(inp=32, oup=32, expand_ratio=2)
-        self.theta_rho = InvertedResidualBlock(inp=32, oup=32, expand_ratio=2)
-        self.theta_eta = InvertedResidualBlock(inp=32, oup=32, expand_ratio=2)
+        if useBlock == 0:
+            self.theta_phi = DepthwiseSeparableConvBlock(inp=32, oup=32)
+            self.theta_rho = DepthwiseSeparableConvBlock(inp=32, oup=32)
+            self.theta_eta = DepthwiseSeparableConvBlock(inp=32, oup=32)
+        elif useBlock == 1:
+            self.theta_phi = InvertedResidualBlock(inp=32, oup=32, expand_ratio=2)
+            self.theta_rho = InvertedResidualBlock(inp=32, oup=32, expand_ratio=2)
+            self.theta_eta = InvertedResidualBlock(inp=32, oup=32, expand_ratio=2)
+        else:
+            self.theta_phi = InvertedResidualBlock(inp=32, oup=32, expand_ratio=2)
+            self.theta_rho = InvertedResidualBlock(inp=32, oup=32, expand_ratio=2)
+            self.theta_eta = InvertedResidualBlock(inp=32, oup=32, expand_ratio=2)
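For context on why DetailNode gains a useBlock switch: a depthwise separable convolution factors a dense 3x3 convolution into a per-channel 3x3 (depthwise) followed by a 1x1 projection (pointwise), which sharply reduces parameters at the 32-channel width used here. A quick comparison, sketched with the block exactly as added above (the count_parameters helper exists only for this illustration):

import torch
from torch import nn


def count_parameters(module):
    # helper for this sketch only
    return sum(p.numel() for p in module.parameters())


plain = nn.Conv2d(32, 32, kernel_size=3, padding=1, bias=False)  # 32 * 32 * 3 * 3 = 9216 weights
dsc = DepthwiseSeparableConvBlock(inp=32, oup=32)                # 288 depthwise + 1024 pointwise + 64 BN = 1376

print(count_parameters(plain))                  # 9216
print(count_parameters(dsc))                    # 1376
print(dsc(torch.randn(1, 32, 128, 128)).shape)  # torch.Size([1, 32, 128, 128])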
@@ -298,37 +325,37 @@ class DetailNode(nn.Module):
 class DetailFeatureExtraction(nn.Module):
     def __init__(self, num_layers=3):
         super(DetailFeatureExtraction, self).__init__()
-        INNmodules = [DetailNode() for _ in range(num_layers)]
+        INNmodules = [DetailNode(useBlock=0) for _ in range(num_layers)]
         self.net = nn.Sequential(*INNmodules)
         self.enhancement_module = WTConv2d(32, 32)

     def forward(self, x):  # 1 64 128 128
         z1, z2 = x[:, :x.shape[1] // 2], x[:, x.shape[1] // 2:x.shape[1]]  # 1 32 128 128
-        # Enhance, then add a residual connection
-        enhanced_z1 = self.enhancement_module(z1)
-        enhanced_z2 = self.enhancement_module(z2)
-        # Residual connection
-        z1 = z1 + enhanced_z1
-        z2 = z2 + enhanced_z2
+        # # Enhance, then add a residual connection
+        # enhanced_z1 = self.enhancement_module(z1)
+        # enhanced_z2 = self.enhancement_module(z2)
+        # # Residual connection
+        # z1 = z1 + enhanced_z1
+        # z2 = z2 + enhanced_z2
         for layer in self.net:
             z1, z2 = layer(z1, z2)
         return torch.cat((z1, z2), dim=1)

-class DetailFeatureFusioin(nn.Module):
+class DetailFeatureFusion(nn.Module):
     def __init__(self, num_layers=3):
-        super(DetailFeatureFusioin, self).__init__()
+        super(DetailFeatureFusion, self).__init__()
         INNmodules = [DetailNode() for _ in range(num_layers)]
         self.net = nn.Sequential(*INNmodules)
         self.enhancement_module = WTConv2d(32, 32)

     def forward(self, x):  # 1 64 128 128
         z1, z2 = x[:, :x.shape[1] // 2], x[:, x.shape[1] // 2:x.shape[1]]  # 1 32 128 128
-        # Enhance, then add a residual connection
-        enhanced_z1 = self.enhancement_module(z1)
-        enhanced_z2 = self.enhancement_module(z2)
-        # Residual connection
-        z1 = z1 + enhanced_z1
-        z2 = z2 + enhanced_z2
+        # # Enhance, then add a residual connection
+        # enhanced_z1 = self.enhancement_module(z1)
+        # enhanced_z2 = self.enhancement_module(z2)
+        # # Residual connection
+        # z1 = z1 + enhanced_z1
+        # z2 = z2 + enhanced_z2
         for layer in self.net:
             z1, z2 = layer(z1, z2)
         return torch.cat((z1, z2), dim=1)
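Both detail modules above follow the same pattern: split the 64-channel input into two 32-channel halves, pass them through num_layers coupled DetailNode stages, and concatenate the halves again, so the output shape matches the input. A minimal shape check, assuming net.py is importable after this commit and the constructors behave as shown in the diff:

import torch
from net import DetailFeatureExtraction, DetailFeatureFusion  # names as they exist after this commit

x = torch.randn(1, 64, 128, 128)
extract = DetailFeatureExtraction(num_layers=3)  # detail-extraction branch
fuse = DetailFeatureFusion(num_layers=1)         # fusion variant, as instantiated in train.py

print(extract(x).shape)  # expected: torch.Size([1, 64, 128, 128])
print(fuse(x).shape)     # expected: torch.Size([1, 64, 128, 128])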
@@ -508,7 +535,8 @@ class BaseFeatureExtractionSAR(nn.Module):

         self.WTConv2d = WTConv2d(dim, dim)
         self.norm1 = LayerNorm(dim, 'WithBias')
-        self.token_mixer = Pooling(kernel_size=pool_size)  # ViTs use MSA, MLP models use MLPs; here pooling serves as the token mixer
+        self.token_mixer = SCSA(dim=dim, head_num=8)
+        # self.token_mixer = Pooling(kernel_size=pool_size)  # ViTs use MSA, MLP models use MLPs; here pooling serves as the token mixer
         self.norm2 = LayerNorm(dim, 'WithBias')
         mlp_hidden_dim = int(dim * mlp_ratio)
         self.poolmlp = PoolMlp(in_features=dim, hidden_features=mlp_hidden_dim,
@@ -554,18 +582,18 @@ class BaseFeatureExtractionSAR(nn.Module):
 class DetailFeatureExtractionSAR(nn.Module):
     def __init__(self, num_layers=3):
         super(DetailFeatureExtractionSAR, self).__init__()
-        INNmodules = [DetailNode() for _ in range(num_layers)]
+        INNmodules = [DetailNode(useBlock=1) for _ in range(num_layers)]
         self.net = nn.Sequential(*INNmodules)
         self.enhancement_module = WTConv2d(32, 32)

     def forward(self, x):  # 1 64 128 128
         z1, z2 = x[:, :x.shape[1] // 2], x[:, x.shape[1] // 2:x.shape[1]]  # 1 32 128 128
-        # Enhance, then add a residual connection
-        enhanced_z1 = self.enhancement_module(z1)
-        enhanced_z2 = self.enhancement_module(z2)
-        # Residual connection
-        z1 = z1 + enhanced_z1
-        z2 = z2 + enhanced_z2
+        # # Enhance, then add a residual connection
+        # enhanced_z1 = self.enhancement_module(z1)
+        # enhanced_z2 = self.enhancement_module(z2)
+        # # Residual connection
+        # z1 = z1 + enhanced_z1
+        # z2 = z2 + enhanced_z2
         for layer in self.net:
             z1, z2 = layer(z1, z2)
         return torch.cat((z1, z2), dim=1)
2 train.py
@@ -87,7 +87,7 @@ DIDF_Encoder = nn.DataParallel(Restormer_Encoder()).to(device)
 DIDF_Decoder = nn.DataParallel(Restormer_Decoder()).to(device)
 # BaseFuseLayer = nn.DataParallel(BaseFeatureExtraction(dim=64, num_heads=8)).to(device)
 BaseFuseLayer = nn.DataParallel(BaseFeatureFusion(dim=64)).to(device)
-DetailFuseLayer = nn.DataParallel(DetailFeatureFusioin(num_layers=1)).to(device)
+DetailFuseLayer = nn.DataParallel(DetailFeatureFusion(num_layers=1)).to(device)

 # optimizer, scheduler and loss function
 optimizer1 = torch.optim.Adam(