refactor(net): restructure the network and remove unused classes

- Removed unused import statements
- Deleted several unused classes, including BaseFeatureExtraction1, BaseFeatureExtractionSAR, etc.
- Renamed some classes to better reflect their purpose, e.g. BaseFeatureFusion to BaseFeatureExtraction
- Simplified parts of the code structure to improve readability and maintainability
zjut 2024-11-15 11:07:25 +08:00
parent ece5f30c2d
commit 4f805c2449
2 changed files with 70 additions and 269 deletions

net.py

@@ -6,11 +6,7 @@ import torch.utils.checkpoint as checkpoint
 from timm.models.layers import DropPath, to_2tuple, trunc_normal_
 from einops import rearrange
-from componets.WTConvCV2 import WTConv2d
-from componets.SCSA import SCSA
-
-# Randomly drop paths of the input tensor with some probability (regularizes the model)
 def drop_path(x, drop_prob: float = 0., training: bool = False):
     if drop_prob == 0. or not training:
         return x
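Note: drop_path implements stochastic depth: it is the identity at inference time (or when drop_prob is 0), and during training it zeroes entire residual paths per sample and rescales the survivors by 1/keep_prob. A minimal self-contained sketch of the standard helper, since the hunk above only shows its first lines:

import torch

def drop_path_sketch(x, drop_prob: float = 0., training: bool = False):
    # identity at inference time or when no dropping is requested
    if drop_prob == 0. or not training:
        return x
    keep_prob = 1 - drop_prob
    # one Bernoulli draw per sample, broadcast over all remaining dims
    shape = (x.shape[0],) + (1,) * (x.ndim - 1)
    mask = x.new_empty(shape).bernoulli_(keep_prob)
    # rescale so the expected value of the output matches the input
    return x.div(keep_prob) * mask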
@@ -35,9 +31,6 @@ class DropPath(nn.Module):
     def forward(self, x):
         return drop_path(x, self.drop_prob, self.training)
-
-
-
 # Change point: use Pooling to replace AttentionBase
 class Pooling(nn.Module):
     def __init__(self, kernel_size=3):
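Note: the Pooling body itself is unchanged by this commit. In PoolFormer, the token mixer is average pooling minus the identity, so the residual x + mixer(norm(x)) only adds each position's deviation from its neighborhood average. A sketch under that assumption (the exact body in net.py is not visible in this diff):

import torch
import torch.nn as nn

class PoolingSketch(nn.Module):
    def __init__(self, kernel_size=3):
        super().__init__()
        # stride 1 + symmetric padding keeps the spatial size unchanged
        self.pool = nn.AvgPool2d(kernel_size, stride=1,
                                 padding=kernel_size // 2,
                                 count_include_pad=False)

    def forward(self, x):
        # subtract the identity: only the "mixing" component is returned
        return self.pool(x) - x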
@@ -98,54 +91,46 @@ class PoolMlp(nn.Module):
         x = self.drop(x)
         return x
 
-# class BaseFeatureExtraction1(nn.Module):
-#     def __init__(self, dim, pool_size=3, mlp_ratio=4.,
-#                  act_layer=nn.GELU,
-#                  # norm_layer=nn.LayerNorm,
-#                  drop=0., drop_path=0.,
-#                  use_layer_scale=True, layer_scale_init_value=1e-5):
-#
-#         super().__init__()
-#
-#         self.norm1 = LayerNorm(dim, 'WithBias')
-#         self.token_mixer = Pooling(kernel_size=pool_size)  # ViTs use MSA, MLP models use MLP; Pooling is used here instead
-#         self.norm2 = LayerNorm(dim, 'WithBias')
-#         mlp_hidden_dim = int(dim * mlp_ratio)
-#         self.poolmlp = PoolMlp(in_features=dim, hidden_features=mlp_hidden_dim,
-#                                act_layer=act_layer, drop=drop)
-#
-#         # The following two techniques are useful to train deep PoolFormers.
-#         self.drop_path = DropPath(drop_path) if drop_path > 0. \
-#             else nn.Identity()
-#         self.use_layer_scale = use_layer_scale
-#
-#         if use_layer_scale:
-#             self.layer_scale_1 = nn.Parameter(
-#                 torch.ones(dim, dtype=torch.float32) * layer_scale_init_value)
-#
-#             self.layer_scale_2 = nn.Parameter(
-#                 torch.ones(dim, dtype=torch.float32) * layer_scale_init_value)
-#
-#     def forward(self, x):  # 1 64 128 128
-#         if self.use_layer_scale:
-#             # self.layer_scale_1: (64,)
-#             tmp1 = self.layer_scale_1.unsqueeze(-1).unsqueeze(-1)  # 64 1 1
-#             normal = self.norm1(x)  # 1 64 128 128
-#             token_mix = self.token_mixer(normal)  # 1 64 128 128
-#             x = (x +
-#                  self.drop_path(
-#                      tmp1 * token_mix
-#                  )
-#                  # unsqueeze adds two trailing dims to the 1-D layer_scale tensor, turning it 3-D so it can broadcast (e.g. elementwise multiply) against the feature map
-#                  )
-#             x = x + self.drop_path(
-#                 self.layer_scale_2.unsqueeze(-1).unsqueeze(-1)
-#                 * self.poolmlp(self.norm2(x)))
-#         else:
-#             x = x + self.drop_path(self.token_mixer(self.norm1(x)))  # matches CDDFuse
-#             x = x + self.drop_path(self.poolmlp(self.norm2(x)))
-#         return x
+class BaseFeatureExtractionSAR(nn.Module):
+    def __init__(self, dim, pool_size=3, mlp_ratio=4.,
+                 act_layer=nn.GELU,
+                 # norm_layer=nn.LayerNorm,
+                 drop=0., drop_path=0.,
+                 use_layer_scale=True, layer_scale_init_value=1e-5):
+        super().__init__()
+
+        self.norm1 = LayerNorm(dim, 'WithBias')
+        self.token_mixer = Pooling(kernel_size=pool_size)  # ViTs use MSA, MLP models use MLP; Pooling is used here instead
+        self.norm2 = LayerNorm(dim, 'WithBias')
+        mlp_hidden_dim = int(dim * mlp_ratio)
+        self.poolmlp = PoolMlp(in_features=dim, hidden_features=mlp_hidden_dim,
+                               act_layer=act_layer, drop=drop)
+
+        # The following two techniques are useful to train deep PoolFormers.
+        self.drop_path = DropPath(drop_path) if drop_path > 0. \
+            else nn.Identity()
+        self.use_layer_scale = use_layer_scale
+
+        if use_layer_scale:
+            self.layer_scale_1 = nn.Parameter(
+                torch.ones(dim, dtype=torch.float32) * layer_scale_init_value)
+
+            self.layer_scale_2 = nn.Parameter(
+                torch.ones(dim, dtype=torch.float32) * layer_scale_init_value)
+
+    def forward(self, x):
+        if self.use_layer_scale:
+            x = x + self.drop_path(
+                self.layer_scale_1.unsqueeze(-1).unsqueeze(-1)
+                * self.token_mixer(self.norm1(x)))
+            x = x + self.drop_path(
+                self.layer_scale_2.unsqueeze(-1).unsqueeze(-1)
+                * self.poolmlp(self.norm2(x)))
+        else:
+            x = x + self.drop_path(self.token_mixer(self.norm1(x)))  # matches CDDFuse
+            x = x + self.drop_path(self.poolmlp(self.norm2(x)))
+        return x
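Note: like the other PoolFormer-style blocks in this file, the new BaseFeatureExtractionSAR is shape-preserving, so it can sit anywhere in the encoder. A quick smoke test; the 1×64×128×128 size follows the shape comments used elsewhere in the file:

import torch
from net import BaseFeatureExtractionSAR

block = BaseFeatureExtractionSAR(dim=64)
x = torch.randn(1, 64, 128, 128)
assert block(x).shape == x.shape  # residual design preserves shape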
 class BaseFeatureExtraction(nn.Module):
     def __init__(self, dim, pool_size=3, mlp_ratio=4.,
@@ -156,7 +141,6 @@ class BaseFeatureExtraction(nn.Module):
         super().__init__()
 
         self.norm1 = LayerNorm(dim, 'WithBias')
-
         self.token_mixer = Pooling(kernel_size=pool_size)  # ViTs use MSA, MLP models use MLP; Pooling is used here instead
         self.norm2 = LayerNorm(dim, 'WithBias')
@@ -176,21 +160,11 @@ class BaseFeatureExtraction(nn.Module):
             self.layer_scale_2 = nn.Parameter(
                 torch.ones(dim, dtype=torch.float32) * layer_scale_init_value)
 
-    def forward(self, x):  # 1 64 128 128
+    def forward(self, x):
         if self.use_layer_scale:
-            # self.layer_scale_1: (64,)
-            tmp1 = self.layer_scale_1.unsqueeze(-1).unsqueeze(-1)  # 64 1 1
-            normal = self.norm1(x)  # 1 64 128 128
-            token_mix = self.token_mixer(normal)  # 1 64 128 128
-            x = (x +
-                 self.drop_path(
-                     tmp1 * token_mix
-                 )
-                 # unsqueeze adds two trailing dims to the 1-D layer_scale tensor so it can broadcast against the 3-D feature map
-                 )
+            x = x + self.drop_path(
+                self.layer_scale_1.unsqueeze(-1).unsqueeze(-1)
+                * self.token_mixer(self.norm1(x)))
             x = x + self.drop_path(
                 self.layer_scale_2.unsqueeze(-1).unsqueeze(-1)
                 * self.poolmlp(self.norm2(x)))
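Note: the rewritten forward inlines the layer-scale multiply; the two unsqueeze(-1) calls reshape the per-channel scale from (C,) to (C, 1, 1) so it broadcasts against a (B, C, H, W) feature map, which is what the deleted inline comment explained:

import torch

C = 64
layer_scale = torch.ones(C) * 1e-5          # per-channel scale, shape (C,)
feat = torch.randn(1, C, 128, 128)          # (B, C, H, W) feature map

scaled = layer_scale.unsqueeze(-1).unsqueeze(-1) * feat   # (C,1,1) * (B,C,H,W)
assert scaled.shape == feat.shape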
@@ -199,58 +173,6 @@ class BaseFeatureExtraction(nn.Module):
             x = x + self.drop_path(self.poolmlp(self.norm2(x)))
         return x
 
-
-class BaseFeatureFusion(nn.Module):
-    def __init__(self, dim, pool_size=3, mlp_ratio=4.,
-                 act_layer=nn.GELU,
-                 # norm_layer=nn.LayerNorm,
-                 drop=0., drop_path=0.,
-                 use_layer_scale=True, layer_scale_init_value=1e-5):
-
-        super().__init__()
-
-        self.norm1 = LayerNorm(dim, 'WithBias')
-        # self.token_mixer = Pooling(kernel_size=pool_size)
-        self.token_mixer = Pooling(kernel_size=pool_size)  # ViTs use MSA, MLP models use MLP; Pooling is used here instead
-        self.norm2 = LayerNorm(dim, 'WithBias')
-        mlp_hidden_dim = int(dim * mlp_ratio)
-        self.poolmlp = PoolMlp(in_features=dim, hidden_features=mlp_hidden_dim,
-                               act_layer=act_layer, drop=drop)
-
-        # The following two techniques are useful to train deep PoolFormers.
-        self.drop_path = DropPath(drop_path) if drop_path > 0. \
-            else nn.Identity()
-        self.use_layer_scale = use_layer_scale
-
-        if use_layer_scale:
-            self.layer_scale_1 = nn.Parameter(
-                torch.ones(dim, dtype=torch.float32) * layer_scale_init_value)
-
-            self.layer_scale_2 = nn.Parameter(
-                torch.ones(dim, dtype=torch.float32) * layer_scale_init_value)
-
-    def forward(self, x):  # 1 64 128 128
-        if self.use_layer_scale:
-            tmp1 = self.layer_scale_1.unsqueeze(-1).unsqueeze(-1)  # 64 1 1
-            normal = self.norm1(x)  # 1 64 128 128
-            token_mix = self.token_mixer(normal)  # 1 64 128 128
-            x = (x +
-                 self.drop_path(
-                     tmp1 * token_mix
-                 )
-                 # unsqueeze adds two trailing dims so the 1-D layer_scale tensor broadcasts against the feature map
-                 )
-            x = x + self.drop_path(
-                self.layer_scale_2.unsqueeze(-1).unsqueeze(-1)
-                * self.poolmlp(self.norm2(x)))
-        else:
-            x = x + self.drop_path(self.token_mixer(self.norm1(x)))  # matches CDDFuse
-            x = x + self.drop_path(self.poolmlp(self.norm2(x)))
-        return x
 class InvertedResidualBlock(nn.Module):
     def __init__(self, inp, oup, expand_ratio):
@@ -273,38 +195,10 @@ class InvertedResidualBlock(nn.Module):
     def forward(self, x):
         return self.bottleneckBlock(x)
 
-
-class DepthwiseSeparableConvBlock(nn.Module):
-    def __init__(self, inp, oup, kernel_size=3, stride=1, padding=1):
-        super(DepthwiseSeparableConvBlock, self).__init__()
-        self.depthwise = nn.Conv2d(inp, inp, kernel_size, stride, padding, groups=inp, bias=False)
-        self.pointwise = nn.Conv2d(inp, oup, 1, bias=False)
-        self.bn = nn.BatchNorm2d(oup)
-        self.relu = nn.ReLU(inplace=True)
-
-    def forward(self, x):
-        x = self.depthwise(x)
-        x = self.pointwise(x)
-        x = self.bn(x)
-        x = self.relu(x)
-        return x
 
 class DetailNode(nn.Module):
-
-    # <img src = "http://42.192.130.83:9000/picgo/imgs/小绿鲸英文文献阅读器_ELTITYqm5G.png" / >
-    def __init__(self, useBlock=0):
+    def __init__(self):
         super(DetailNode, self).__init__()
-        if useBlock == 0:
-            self.theta_phi = DepthwiseSeparableConvBlock(inp=32, oup=32)
-            self.theta_rho = DepthwiseSeparableConvBlock(inp=32, oup=32)
-            self.theta_eta = DepthwiseSeparableConvBlock(inp=32, oup=32)
-        elif useBlock == 1:
-            self.theta_phi = InvertedResidualBlock(inp=32, oup=32, expand_ratio=2)
-            self.theta_rho = InvertedResidualBlock(inp=32, oup=32, expand_ratio=2)
-            self.theta_eta = InvertedResidualBlock(inp=32, oup=32, expand_ratio=2)
-        else:
-            self.theta_phi = InvertedResidualBlock(inp=32, oup=32, expand_ratio=2)
-            self.theta_rho = InvertedResidualBlock(inp=32, oup=32, expand_ratio=2)
-            self.theta_eta = InvertedResidualBlock(inp=32, oup=32, expand_ratio=2)
+        self.theta_phi = InvertedResidualBlock(inp=32, oup=32, expand_ratio=2)
+        self.theta_rho = InvertedResidualBlock(inp=32, oup=32, expand_ratio=2)
+        self.theta_eta = InvertedResidualBlock(inp=32, oup=32, expand_ratio=2)
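Note: with useBlock gone, DetailNode always builds its three sub-networks from InvertedResidualBlock. Only that block's forward is visible in this diff; a hedged sketch of the usual MobileNetV2-style layout it likely wraps (pointwise expand, depthwise conv, pointwise project):

import torch
import torch.nn as nn

class InvertedResidualSketch(nn.Module):
    # Assumed layout; net.py's exact InvertedResidualBlock body is not shown here.
    def __init__(self, inp, oup, expand_ratio):
        super().__init__()
        hidden = int(inp * expand_ratio)
        self.bottleneckBlock = nn.Sequential(
            nn.Conv2d(inp, hidden, 1, bias=False),    # pointwise expand
            nn.ReLU6(inplace=True),
            nn.Conv2d(hidden, hidden, 3, padding=1,
                      groups=hidden, bias=False),      # depthwise
            nn.ReLU6(inplace=True),
            nn.Conv2d(hidden, oup, 1, bias=False),     # pointwise project
        )

    def forward(self, x):
        return self.bottleneckBlock(x)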
@@ -323,46 +217,28 @@ class DetailNode(nn.Module):
         return z1, z2
 
-
-class DetailFeatureExtractionSAR(nn.Module):
-    def __init__(self, num_layers=3):
-        super(DetailFeatureExtractionSAR, self).__init__()
-        INNmodules = [DetailNode() for _ in range(num_layers)]
-        self.net = nn.Sequential(*INNmodules)
-
-    def forward(self, x):
-        z1, z2 = x[:, :x.shape[1] // 2], x[:, x.shape[1] // 2:x.shape[1]]
-        for layer in self.net:
-            z1, z2 = layer(z1, z2)
-        return torch.cat((z1, z2), dim=1)
-
 class DetailFeatureExtraction(nn.Module):
     def __init__(self, num_layers=3):
         super(DetailFeatureExtraction, self).__init__()
         INNmodules = [DetailNode() for _ in range(num_layers)]
         self.net = nn.Sequential(*INNmodules)
-        # self.enhancement_module = WTConv2d(32, 32)
 
-    def forward(self, x):  # 1 64 128 128
-        z1, z2 = x[:, :x.shape[1] // 2], x[:, x.shape[1] // 2:x.shape[1]]  # 1 32 128 128
-        # # Enhance, then add residual connections
-        # enhanced_z1 = self.enhancement_module(z1)
-        # enhanced_z2 = self.enhancement_module(z2)
-        # # residual connections
-        # z1 = z1 + enhanced_z1
-        # z2 = z2 + enhanced_z2
+    def forward(self, x):
+        z1, z2 = x[:, :x.shape[1] // 2], x[:, x.shape[1] // 2:x.shape[1]]
         for layer in self.net:
             z1, z2 = layer(z1, z2)
         return torch.cat((z1, z2), dim=1)
 
-
-class DetailFeatureFusion(nn.Module):
-    def __init__(self, num_layers=3):
-        super(DetailFeatureFusion, self).__init__()
-        INNmodules = [DetailNode() for _ in range(num_layers)]
-        self.net = nn.Sequential(*INNmodules)
-        # self.enhancement_module = WTConv2d(32, 32)
-
-    def forward(self, x):  # 1 64 128 128
-        z1, z2 = x[:, :x.shape[1] // 2], x[:, x.shape[1] // 2:x.shape[1]]  # 1 32 128 128
-        # # Enhance, then add residual connections
-        # enhanced_z1 = self.enhancement_module(z1)
-        # enhanced_z2 = self.enhancement_module(z2)
-        # # residual connections
-        # z1 = z1 + enhanced_z1
-        # z2 = z2 + enhanced_z2
-        for layer in self.net:
-            z1, z2 = layer(z1, z2)
-        return torch.cat((z1, z2), dim=1)
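Note: each DetailNode consumes and returns the pair (z1, z2) produced by the channel split above, in the invertible lifting style used by CDDFuse. The full DetailNode.forward is outside this hunk, but the coupling update is conventionally of this form (a sketch, not the file's exact code):

import torch
import torch.nn as nn

def coupling_step(z1, z2, theta_phi, theta_rho, theta_eta):
    z2 = z2 + theta_phi(z1)                             # additive coupling
    z1 = z1 * torch.exp(theta_rho(z2)) + theta_eta(z2)  # affine coupling
    return z1, z2

conv = lambda: nn.Conv2d(32, 32, 3, padding=1)
z1, z2 = torch.randn(1, 32, 16, 16), torch.randn(1, 32, 16, 16)
z1, z2 = coupling_step(z1, z2, conv(), conv(), conv())
print(torch.cat((z1, z2), dim=1).shape)  # torch.Size([1, 64, 16, 16])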
 # =============================================================================
 # =============================================================================
@@ -525,80 +401,6 @@ class OverlapPatchEmbed(nn.Module):
         return x
 
-
-class BaseFeatureExtractionSAR(nn.Module):
-    def __init__(self, dim, pool_size=3, mlp_ratio=4.,
-                 act_layer=nn.GELU,
-                 # norm_layer=nn.LayerNorm,
-                 drop=0., drop_path=0.,
-                 use_layer_scale=True, layer_scale_init_value=1e-5):
-
-        super().__init__()
-        self.WTConv2d = WTConv2d(dim, dim)
-        self.norm1 = LayerNorm(dim, 'WithBias')
-        self.token_mixer = SCSA(dim=dim, head_num=8)
-        # self.token_mixer = Pooling(kernel_size=pool_size)  # ViTs use MSA, MLP models use MLP; Pooling is used here instead
-        self.norm2 = LayerNorm(dim, 'WithBias')
-        mlp_hidden_dim = int(dim * mlp_ratio)
-        self.poolmlp = PoolMlp(in_features=dim, hidden_features=mlp_hidden_dim,
-                               act_layer=act_layer, drop=drop)
-
-        # The following two techniques are useful to train deep PoolFormers.
-        self.drop_path = DropPath(drop_path) if drop_path > 0. \
-            else nn.Identity()
-        self.use_layer_scale = use_layer_scale
-
-        if use_layer_scale:
-            self.layer_scale_1 = nn.Parameter(
-                torch.ones(dim, dtype=torch.float32) * layer_scale_init_value)
-
-            self.layer_scale_2 = nn.Parameter(
-                torch.ones(dim, dtype=torch.float32) * layer_scale_init_value)
-
-    def forward(self, x):  # 1 64 128 128
-        if self.use_layer_scale:
-            tmp1 = self.layer_scale_1.unsqueeze(-1).unsqueeze(-1)  # 64 1 1
-            normal = self.norm1(x)  # 1 64 128 128
-            token_mix = self.token_mixer(normal)  # 1 64 128 128
-            x = (x +
-                 self.drop_path(
-                     tmp1 * token_mix
-                 )
-                 # unsqueeze adds two trailing dims so the 1-D layer_scale tensor broadcasts against the feature map
-                 )
-            x = x + self.drop_path(
-                self.layer_scale_2.unsqueeze(-1).unsqueeze(-1)
-                * self.poolmlp(self.norm2(x)))
-        else:
-            x = x + self.drop_path(self.token_mixer(self.norm1(x)))  # matches CDDFuse
-            x = x + self.drop_path(self.poolmlp(self.norm2(x)))
-        return x
-
-
-class DetailFeatureExtractionSAR(nn.Module):
-    def __init__(self, num_layers=3):
-        super(DetailFeatureExtractionSAR, self).__init__()
-        INNmodules = [DetailNode(useBlock=1) for _ in range(num_layers)]
-        self.net = nn.Sequential(*INNmodules)
-        # self.enhancement_module = WTConv2d(32, 32)
-
-    def forward(self, x):  # 1 64 128 128
-        z1, z2 = x[:, :x.shape[1] // 2], x[:, x.shape[1] // 2:x.shape[1]]  # 1 32 128 128
-        # # Enhance, then add residual connections
-        # enhanced_z1 = self.enhancement_module(z1)
-        # enhanced_z2 = self.enhancement_module(z2)
-        # # residual connections
-        # z1 = z1 + enhanced_z1
-        # z2 = z2 + enhanced_z2
-        for layer in self.net:
-            z1, z2 = layer(z1, z2)
-        return torch.cat((z1, z2), dim=1)
 
 class Restormer_Encoder(nn.Module):
     def __init__(self,
                  inp_channels=1,
@@ -612,19 +414,19 @@ class Restormer_Encoder(nn.Module):
                  ):
         super(Restormer_Encoder, self).__init__()
-        # distinguish (between the two modalities)
+
         self.patch_embed = OverlapPatchEmbed(inp_channels, dim)
         self.encoder_level1 = nn.Sequential(
             *[TransformerBlock(dim=dim, num_heads=heads[0], ffn_expansion_factor=ffn_expansion_factor,
                                bias=bias, LayerNorm_type=LayerNorm_type) for i in range(num_blocks[0])])
         self.baseFeature = BaseFeatureExtraction(dim=dim)
         self.detailFeature = DetailFeatureExtraction()
 
         self.baseFeature_sar = BaseFeatureExtractionSAR(dim=dim)
         self.detailFeature_sar = DetailFeatureExtractionSAR()
 
-    def forward(self, inp_img, is_sar = False):
+    def forward(self, inp_img, is_sar=False):
         inp_enc_level1 = self.patch_embed(inp_img)
         out_enc_level1 = self.encoder_level1(inp_enc_level1)
         if is_sar:
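Note: after this hunk the encoder owns both optical and SAR branches and routes between them with is_sar; calling code only flips the flag. The return tuple is outside the hunk and unchanged by this commit, so it is left implicit here:

import torch
from net import Restormer_Encoder

encoder = Restormer_Encoder().eval()
optical = torch.randn(1, 1, 128, 128)   # inp_channels=1; spatial size assumed
sar = torch.randn(1, 1, 128, 128)

with torch.no_grad():
    feats_optical = encoder(optical)          # baseFeature / detailFeature path
    feats_sar = encoder(sar, is_sar=True)     # baseFeature_sar / detailFeature_sar path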
@@ -652,7 +454,8 @@ class Restormer_Decoder(nn.Module):
         super(Restormer_Decoder, self).__init__()
         self.reduce_channel = nn.Conv2d(int(dim * 2), int(dim), kernel_size=1, bias=bias)
-        self.encoder_level2 = nn.Sequential(*[TransformerBlock(dim=dim, num_heads=heads[1], ffn_expansion_factor=ffn_expansion_factor,
+        self.encoder_level2 = nn.Sequential(
+            *[TransformerBlock(dim=dim, num_heads=heads[1], ffn_expansion_factor=ffn_expansion_factor,
                                bias=bias, LayerNorm_type=LayerNorm_type) for i in range(num_blocks[1])])
         self.output = nn.Sequential(
             nn.Conv2d(int(dim), int(dim) // 2, kernel_size=3,
@@ -679,5 +482,3 @@ if __name__ == '__main__':
     window_size = 8
     modelE = Restormer_Encoder().cuda()
     modelD = Restormer_Decoder().cuda()
-    print(modelE)
-    print(modelD)

train.py

@@ -6,7 +6,7 @@ Import packages
 ------------------------------------------------------------------------------
 '''
-from net import Restormer_Encoder, Restormer_Decoder, BaseFeatureExtraction, DetailFeatureExtraction, BaseFeatureFusion,DetailFeatureFusion
+from net import Restormer_Encoder, Restormer_Decoder, BaseFeatureExtraction, DetailFeatureExtraction
 from utils.dataset import H5Dataset
 import os
 os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'
@@ -85,8 +85,8 @@ device = 'cuda' if torch.cuda.is_available() else 'cpu'
 DIDF_Encoder = nn.DataParallel(Restormer_Encoder()).to(device)
 DIDF_Decoder = nn.DataParallel(Restormer_Decoder()).to(device)
 # BaseFuseLayer = nn.DataParallel(BaseFeatureExtraction(dim=64, num_heads=8)).to(device)
-BaseFuseLayer = nn.DataParallel(BaseFeatureFusion(dim=64)).to(device)
-DetailFuseLayer = nn.DataParallel(DetailFeatureFusion(num_layers=1)).to(device)
+BaseFuseLayer = nn.DataParallel(BaseFeatureExtraction(dim=64)).to(device)
+DetailFuseLayer = nn.DataParallel(DetailFeatureExtraction(num_layers=1)).to(device)
 
 # optimizer, scheduler and loss function
 optimizer1 = torch.optim.Adam(
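Note: with the rename, the fusion stages are plain instances of the extraction classes. Since both are shape-preserving over 64-channel maps, a typical use is to refine combined per-modality features; the feat_* names below are illustrative, not from the training script:

import torch
import torch.nn as nn
from net import BaseFeatureExtraction, DetailFeatureExtraction

device = 'cuda' if torch.cuda.is_available() else 'cpu'
BaseFuseLayer = nn.DataParallel(BaseFeatureExtraction(dim=64)).to(device)
DetailFuseLayer = nn.DataParallel(DetailFeatureExtraction(num_layers=1)).to(device)

# illustrative fusion: sum per-modality features, then refine them
feat_base_vi = torch.randn(1, 64, 128, 128, device=device)
feat_base_ir = torch.randn(1, 64, 128, 128, device=device)
fused_base = BaseFuseLayer(feat_base_vi + feat_base_ir)      # (1, 64, 128, 128)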