refactor(net): 重构网络结构并移除未使用的类
- 移除了未使用的导入语句 - 删除了多个未使用的类,包括 BaseFeatureExtraction1、BaseFeatureExtractionSAR等 - 重命名了部分类以更好地反映其功能,如将 BaseFeatureFusion 改为 BaseFeatureExtraction 等- 简化了部分代码结构,提高了代码的可读性和维护性
This commit is contained in:
parent
ece5f30c2d
commit
4f805c2449
321
net.py
321
net.py
@ -6,11 +6,7 @@ import torch.utils.checkpoint as checkpoint
|
|||||||
from timm.models.layers import DropPath, to_2tuple, trunc_normal_
|
from timm.models.layers import DropPath, to_2tuple, trunc_normal_
|
||||||
from einops import rearrange
|
from einops import rearrange
|
||||||
|
|
||||||
from componets.WTConvCV2 import WTConv2d
|
|
||||||
from componets.SCSA import SCSA
|
|
||||||
|
|
||||||
|
|
||||||
# 以一定概率随机丢弃输入张量中的路径,用于正则化模型
|
|
||||||
def drop_path(x, drop_prob: float = 0., training: bool = False):
|
def drop_path(x, drop_prob: float = 0., training: bool = False):
|
||||||
if drop_prob == 0. or not training:
|
if drop_prob == 0. or not training:
|
||||||
return x
|
return x
|
||||||
@ -35,9 +31,6 @@ class DropPath(nn.Module):
|
|||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
return drop_path(x, self.drop_prob, self.training)
|
return drop_path(x, self.drop_prob, self.training)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
# 改点,使用Pooling替换AttentionBase
|
# 改点,使用Pooling替换AttentionBase
|
||||||
class Pooling(nn.Module):
|
class Pooling(nn.Module):
|
||||||
def __init__(self, kernel_size=3):
|
def __init__(self, kernel_size=3):
|
||||||
@ -98,54 +91,46 @@ class PoolMlp(nn.Module):
|
|||||||
x = self.drop(x)
|
x = self.drop(x)
|
||||||
return x
|
return x
|
||||||
|
|
||||||
|
class BaseFeatureExtractionSAR(nn.Module):
|
||||||
|
def __init__(self, dim, pool_size=3, mlp_ratio=4.,
|
||||||
|
act_layer=nn.GELU,
|
||||||
|
# norm_layer=nn.LayerNorm,
|
||||||
|
drop=0., drop_path=0.,
|
||||||
|
use_layer_scale=True, layer_scale_init_value=1e-5):
|
||||||
|
|
||||||
# class BaseFeatureExtraction1(nn.Module):
|
super().__init__()
|
||||||
# def __init__(self, dim, pool_size=3, mlp_ratio=4.,
|
|
||||||
# act_layer=nn.GELU,
|
self.norm1 = LayerNorm(dim, 'WithBias')
|
||||||
# # norm_layer=nn.LayerNorm,
|
self.token_mixer = Pooling(kernel_size=pool_size) # vits是msa,MLPs是mlp,这个用pool来替代
|
||||||
# drop=0., drop_path=0.,
|
self.norm2 = LayerNorm(dim, 'WithBias')
|
||||||
# use_layer_scale=True, layer_scale_init_value=1e-5):
|
mlp_hidden_dim = int(dim * mlp_ratio)
|
||||||
#
|
self.poolmlp = PoolMlp(in_features=dim, hidden_features=mlp_hidden_dim,
|
||||||
# super().__init__()
|
act_layer=act_layer, drop=drop)
|
||||||
#
|
|
||||||
# self.norm1 = LayerNorm(dim, 'WithBias')
|
# The following two techniques are useful to train deep PoolFormers.
|
||||||
# self.token_mixer = Pooling(kernel_size=pool_size) # vits是msa,MLPs是mlp,这个用pool来替代
|
self.drop_path = DropPath(drop_path) if drop_path > 0. \
|
||||||
# self.norm2 = LayerNorm(dim, 'WithBias')
|
else nn.Identity()
|
||||||
# mlp_hidden_dim = int(dim * mlp_ratio)
|
self.use_layer_scale = use_layer_scale
|
||||||
# self.poolmlp = PoolMlp(in_features=dim, hidden_features=mlp_hidden_dim,
|
|
||||||
# act_layer=act_layer, drop=drop)
|
if use_layer_scale:
|
||||||
#
|
self.layer_scale_1 = nn.Parameter(
|
||||||
# # The following two techniques are useful to train deep PoolFormers.
|
torch.ones(dim, dtype=torch.float32) * layer_scale_init_value)
|
||||||
# self.drop_path = DropPath(drop_path) if drop_path > 0. \
|
|
||||||
# else nn.Identity()
|
self.layer_scale_2 = nn.Parameter(
|
||||||
# self.use_layer_scale = use_layer_scale
|
torch.ones(dim, dtype=torch.float32) * layer_scale_init_value)
|
||||||
#
|
|
||||||
# if use_layer_scale:
|
def forward(self, x):
|
||||||
# self.layer_scale_1 = nn.Parameter(
|
if self.use_layer_scale:
|
||||||
# torch.ones(dim, dtype=torch.float32) * layer_scale_init_value)
|
x = x + self.drop_path(
|
||||||
#
|
self.layer_scale_1.unsqueeze(-1).unsqueeze(-1)
|
||||||
# self.layer_scale_2 = nn.Parameter(
|
* self.token_mixer(self.norm1(x)))
|
||||||
# torch.ones(dim, dtype=torch.float32) * layer_scale_init_value)
|
x = x + self.drop_path(
|
||||||
#
|
self.layer_scale_2.unsqueeze(-1).unsqueeze(-1)
|
||||||
# def forward(self, x): # 1 64 128 128
|
* self.poolmlp(self.norm2(x)))
|
||||||
# if self.use_layer_scale:
|
else:
|
||||||
# # self.layer_scale_1(64,)
|
x = x + self.drop_path(self.token_mixer(self.norm1(x))) # 匹配cddfuse
|
||||||
# tmp1 = self.layer_scale_1.unsqueeze(-1).unsqueeze(-1) # 64 1 1
|
x = x + self.drop_path(self.poolmlp(self.norm2(x)))
|
||||||
# normal = self.norm1(x) # 1 64 128 128
|
return x
|
||||||
# token_mix = self.token_mixer(normal) # 1 64 128 128
|
|
||||||
# x = (x +
|
|
||||||
# self.drop_path(
|
|
||||||
# tmp1 * token_mix
|
|
||||||
# )
|
|
||||||
# # 该表达式将 self.layer_scale_1 这个一维张量(或变量)在维度末尾添加两个新的维度,使其从一维变为三维。这通常用于使其能够与三维的特征图进行广播操作,如元素相乘。具体用途可能包括调整卷积层或注意力机制中的权重。
|
|
||||||
# )
|
|
||||||
# x = x + self.drop_path(
|
|
||||||
# self.layer_scale_2.unsqueeze(-1).unsqueeze(-1)
|
|
||||||
# * self.poolmlp(self.norm2(x)))
|
|
||||||
# else:
|
|
||||||
# x = x + self.drop_path(self.token_mixer(self.norm1(x))) # 匹配cddfuse
|
|
||||||
# x = x + self.drop_path(self.poolmlp(self.norm2(x)))
|
|
||||||
# return x
|
|
||||||
|
|
||||||
class BaseFeatureExtraction(nn.Module):
|
class BaseFeatureExtraction(nn.Module):
|
||||||
def __init__(self, dim, pool_size=3, mlp_ratio=4.,
|
def __init__(self, dim, pool_size=3, mlp_ratio=4.,
|
||||||
@ -156,7 +141,6 @@ class BaseFeatureExtraction(nn.Module):
|
|||||||
|
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
|
|
||||||
self.norm1 = LayerNorm(dim, 'WithBias')
|
self.norm1 = LayerNorm(dim, 'WithBias')
|
||||||
self.token_mixer = Pooling(kernel_size=pool_size) # vits是msa,MLPs是mlp,这个用pool来替代
|
self.token_mixer = Pooling(kernel_size=pool_size) # vits是msa,MLPs是mlp,这个用pool来替代
|
||||||
self.norm2 = LayerNorm(dim, 'WithBias')
|
self.norm2 = LayerNorm(dim, 'WithBias')
|
||||||
@ -176,21 +160,11 @@ class BaseFeatureExtraction(nn.Module):
|
|||||||
self.layer_scale_2 = nn.Parameter(
|
self.layer_scale_2 = nn.Parameter(
|
||||||
torch.ones(dim, dtype=torch.float32) * layer_scale_init_value)
|
torch.ones(dim, dtype=torch.float32) * layer_scale_init_value)
|
||||||
|
|
||||||
def forward(self, x): # 1 64 128 128
|
def forward(self, x):
|
||||||
if self.use_layer_scale:
|
if self.use_layer_scale:
|
||||||
# self.layer_scale_1(64,)
|
x = x + self.drop_path(
|
||||||
tmp1 = self.layer_scale_1.unsqueeze(-1).unsqueeze(-1) # 64 1 1
|
self.layer_scale_1.unsqueeze(-1).unsqueeze(-1)
|
||||||
normal = self.norm1(x) # 1 64 128 128
|
* self.token_mixer(self.norm1(x)))
|
||||||
token_mix = self.token_mixer(normal) # 1 64 128 128
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
x = (x +
|
|
||||||
self.drop_path(
|
|
||||||
tmp1 * token_mix
|
|
||||||
)
|
|
||||||
# 该表达式将 self.layer_scale_1 这个一维张量(或变量)在维度末尾添加两个新的维度,使其从一维变为三维。这通常用于使其能够与三维的特征图进行广播操作,如元素相乘。具体用途可能包括调整卷积层或注意力机制中的权重。
|
|
||||||
)
|
|
||||||
x = x + self.drop_path(
|
x = x + self.drop_path(
|
||||||
self.layer_scale_2.unsqueeze(-1).unsqueeze(-1)
|
self.layer_scale_2.unsqueeze(-1).unsqueeze(-1)
|
||||||
* self.poolmlp(self.norm2(x)))
|
* self.poolmlp(self.norm2(x)))
|
||||||
@ -199,58 +173,6 @@ class BaseFeatureExtraction(nn.Module):
|
|||||||
x = x + self.drop_path(self.poolmlp(self.norm2(x)))
|
x = x + self.drop_path(self.poolmlp(self.norm2(x)))
|
||||||
return x
|
return x
|
||||||
|
|
||||||
class BaseFeatureFusion(nn.Module):
|
|
||||||
def __init__(self, dim, pool_size=3, mlp_ratio=4.,
|
|
||||||
act_layer=nn.GELU,
|
|
||||||
# norm_layer=nn.LayerNorm,
|
|
||||||
drop=0., drop_path=0.,
|
|
||||||
use_layer_scale=True, layer_scale_init_value=1e-5):
|
|
||||||
|
|
||||||
super().__init__()
|
|
||||||
|
|
||||||
|
|
||||||
self.norm1 = LayerNorm(dim, 'WithBias')
|
|
||||||
# self.token_mixer = Pooling(kernel_size=pool_size) # vits是msa,MLPs是mlp,这个用pool来替代
|
|
||||||
self.token_mixer = Pooling(kernel_size=pool_size) # vits是msa,MLPs是mlp,这个用pool来替代
|
|
||||||
self.norm2 = LayerNorm(dim, 'WithBias')
|
|
||||||
mlp_hidden_dim = int(dim * mlp_ratio)
|
|
||||||
self.poolmlp = PoolMlp(in_features=dim, hidden_features=mlp_hidden_dim,
|
|
||||||
act_layer=act_layer, drop=drop)
|
|
||||||
|
|
||||||
# The following two techniques are useful to train deep PoolFormers.
|
|
||||||
self.drop_path = DropPath(drop_path) if drop_path > 0. \
|
|
||||||
else nn.Identity()
|
|
||||||
self.use_layer_scale = use_layer_scale
|
|
||||||
|
|
||||||
if use_layer_scale:
|
|
||||||
self.layer_scale_1 = nn.Parameter(
|
|
||||||
torch.ones(dim, dtype=torch.float32) * layer_scale_init_value)
|
|
||||||
|
|
||||||
self.layer_scale_2 = nn.Parameter(
|
|
||||||
torch.ones(dim, dtype=torch.float32) * layer_scale_init_value)
|
|
||||||
|
|
||||||
def forward(self, x): # 1 64 128 128
|
|
||||||
if self.use_layer_scale:
|
|
||||||
# self.layer_scale_1(64,)
|
|
||||||
tmp1 = self.layer_scale_1.unsqueeze(-1).unsqueeze(-1) # 64 1 1
|
|
||||||
normal = self.norm1(x) # 1 64 128 128
|
|
||||||
token_mix = self.token_mixer(normal) # 1 64 128 128
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
x = (x +
|
|
||||||
self.drop_path(
|
|
||||||
tmp1 * token_mix
|
|
||||||
)
|
|
||||||
# 该表达式将 self.layer_scale_1 这个一维张量(或变量)在维度末尾添加两个新的维度,使其从一维变为三维。这通常用于使其能够与三维的特征图进行广播操作,如元素相乘。具体用途可能包括调整卷积层或注意力机制中的权重。
|
|
||||||
)
|
|
||||||
x = x + self.drop_path(
|
|
||||||
self.layer_scale_2.unsqueeze(-1).unsqueeze(-1)
|
|
||||||
* self.poolmlp(self.norm2(x)))
|
|
||||||
else:
|
|
||||||
x = x + self.drop_path(self.token_mixer(self.norm1(x))) # 匹配cddfuse
|
|
||||||
x = x + self.drop_path(self.poolmlp(self.norm2(x)))
|
|
||||||
return x
|
|
||||||
|
|
||||||
class InvertedResidualBlock(nn.Module):
|
class InvertedResidualBlock(nn.Module):
|
||||||
def __init__(self, inp, oup, expand_ratio):
|
def __init__(self, inp, oup, expand_ratio):
|
||||||
@ -273,38 +195,10 @@ class InvertedResidualBlock(nn.Module):
|
|||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
return self.bottleneckBlock(x)
|
return self.bottleneckBlock(x)
|
||||||
|
|
||||||
|
|
||||||
class DepthwiseSeparableConvBlock(nn.Module):
|
|
||||||
def __init__(self, inp, oup, kernel_size=3, stride=1, padding=1):
|
|
||||||
super(DepthwiseSeparableConvBlock, self).__init__()
|
|
||||||
self.depthwise = nn.Conv2d(inp, inp, kernel_size, stride, padding, groups=inp, bias=False)
|
|
||||||
self.pointwise = nn.Conv2d(inp, oup, 1, bias=False)
|
|
||||||
self.bn = nn.BatchNorm2d(oup)
|
|
||||||
self.relu = nn.ReLU(inplace=True)
|
|
||||||
|
|
||||||
def forward(self, x):
|
|
||||||
x = self.depthwise(x)
|
|
||||||
x = self.pointwise(x)
|
|
||||||
x = self.bn(x)
|
|
||||||
x = self.relu(x)
|
|
||||||
return x
|
|
||||||
|
|
||||||
|
|
||||||
class DetailNode(nn.Module):
|
class DetailNode(nn.Module):
|
||||||
|
def __init__(self):
|
||||||
# <img src = "http://42.192.130.83:9000/picgo/imgs/小绿鲸英文文献阅读器_ELTITYqm5G.png" / > '
|
|
||||||
def __init__(self,useBlock=0):
|
|
||||||
super(DetailNode, self).__init__()
|
super(DetailNode, self).__init__()
|
||||||
|
|
||||||
if useBlock == 0:
|
|
||||||
self.theta_phi = DepthwiseSeparableConvBlock(inp=32, oup=32)
|
|
||||||
self.theta_rho = DepthwiseSeparableConvBlock(inp=32, oup=32)
|
|
||||||
self.theta_eta = DepthwiseSeparableConvBlock(inp=32, oup=32)
|
|
||||||
elif useBlock == 1:
|
|
||||||
self.theta_phi = InvertedResidualBlock(inp=32, oup=32, expand_ratio=2)
|
|
||||||
self.theta_rho = InvertedResidualBlock(inp=32, oup=32, expand_ratio=2)
|
|
||||||
self.theta_eta = InvertedResidualBlock(inp=32, oup=32, expand_ratio=2)
|
|
||||||
else:
|
|
||||||
self.theta_phi = InvertedResidualBlock(inp=32, oup=32, expand_ratio=2)
|
self.theta_phi = InvertedResidualBlock(inp=32, oup=32, expand_ratio=2)
|
||||||
self.theta_rho = InvertedResidualBlock(inp=32, oup=32, expand_ratio=2)
|
self.theta_rho = InvertedResidualBlock(inp=32, oup=32, expand_ratio=2)
|
||||||
self.theta_eta = InvertedResidualBlock(inp=32, oup=32, expand_ratio=2)
|
self.theta_eta = InvertedResidualBlock(inp=32, oup=32, expand_ratio=2)
|
||||||
@ -323,46 +217,28 @@ class DetailNode(nn.Module):
|
|||||||
return z1, z2
|
return z1, z2
|
||||||
|
|
||||||
|
|
||||||
|
class DetailFeatureExtractionSAR(nn.Module):
|
||||||
|
def __init__(self, num_layers=3):
|
||||||
|
super(DetailFeatureExtractionSAR, self).__init__()
|
||||||
|
INNmodules = [DetailNode() for _ in range(num_layers)]
|
||||||
|
self.net = nn.Sequential(*INNmodules)
|
||||||
|
def forward(self, x):
|
||||||
|
z1, z2 = x[:, :x.shape[1] // 2], x[:, x.shape[1] // 2:x.shape[1]]
|
||||||
|
for layer in self.net:
|
||||||
|
z1, z2 = layer(z1, z2)
|
||||||
|
return torch.cat((z1, z2), dim=1)
|
||||||
|
|
||||||
class DetailFeatureExtraction(nn.Module):
|
class DetailFeatureExtraction(nn.Module):
|
||||||
def __init__(self, num_layers=3):
|
def __init__(self, num_layers=3):
|
||||||
super(DetailFeatureExtraction, self).__init__()
|
super(DetailFeatureExtraction, self).__init__()
|
||||||
INNmodules = [DetailNode() for _ in range(num_layers)]
|
INNmodules = [DetailNode() for _ in range(num_layers)]
|
||||||
self.net = nn.Sequential(*INNmodules)
|
self.net = nn.Sequential(*INNmodules)
|
||||||
# self.enhancement_module = WTConv2d(32, 32)
|
def forward(self, x):
|
||||||
|
z1, z2 = x[:, :x.shape[1] // 2], x[:, x.shape[1] // 2:x.shape[1]]
|
||||||
def forward(self, x): # 1 64 128 128
|
|
||||||
z1, z2 = x[:, :x.shape[1] // 2], x[:, x.shape[1] // 2:x.shape[1]] # 1 32 128 128
|
|
||||||
# # 增强并添加残差连接
|
|
||||||
# enhanced_z1 = self.enhancement_module(z1)
|
|
||||||
# enhanced_z2 = self.enhancement_module(z2)
|
|
||||||
# # 残差连接
|
|
||||||
# z1 = z1 + enhanced_z1
|
|
||||||
# z2 = z2 + enhanced_z2
|
|
||||||
for layer in self.net:
|
for layer in self.net:
|
||||||
z1, z2 = layer(z1, z2)
|
z1, z2 = layer(z1, z2)
|
||||||
return torch.cat((z1, z2), dim=1)
|
return torch.cat((z1, z2), dim=1)
|
||||||
|
|
||||||
class DetailFeatureFusion(nn.Module):
|
|
||||||
def __init__(self, num_layers=3):
|
|
||||||
super(DetailFeatureFusion, self).__init__()
|
|
||||||
INNmodules = [DetailNode() for _ in range(num_layers)]
|
|
||||||
self.net = nn.Sequential(*INNmodules)
|
|
||||||
# self.enhancement_module = WTConv2d(32, 32)
|
|
||||||
|
|
||||||
def forward(self, x): # 1 64 128 128
|
|
||||||
z1, z2 = x[:, :x.shape[1] // 2], x[:, x.shape[1] // 2:x.shape[1]] # 1 32 128 128
|
|
||||||
# # 增强并添加残差连接
|
|
||||||
# enhanced_z1 = self.enhancement_module(z1)
|
|
||||||
# enhanced_z2 = self.enhancement_module(z2)
|
|
||||||
# # 残差连接
|
|
||||||
# z1 = z1 + enhanced_z1
|
|
||||||
# z2 = z2 + enhanced_z2
|
|
||||||
for layer in self.net:
|
|
||||||
z1, z2 = layer(z1, z2)
|
|
||||||
return torch.cat((z1, z2), dim=1)
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
# =============================================================================
|
# =============================================================================
|
||||||
|
|
||||||
# =============================================================================
|
# =============================================================================
|
||||||
@ -525,80 +401,6 @@ class OverlapPatchEmbed(nn.Module):
|
|||||||
return x
|
return x
|
||||||
|
|
||||||
|
|
||||||
class BaseFeatureExtractionSAR(nn.Module):
|
|
||||||
def __init__(self, dim, pool_size=3, mlp_ratio=4.,
|
|
||||||
act_layer=nn.GELU,
|
|
||||||
# norm_layer=nn.LayerNorm,
|
|
||||||
drop=0., drop_path=0.,
|
|
||||||
use_layer_scale=True, layer_scale_init_value=1e-5):
|
|
||||||
|
|
||||||
super().__init__()
|
|
||||||
|
|
||||||
self.WTConv2d = WTConv2d(dim, dim)
|
|
||||||
self.norm1 = LayerNorm(dim, 'WithBias')
|
|
||||||
self.token_mixer = SCSA(dim=dim, head_num=8)
|
|
||||||
# self.token_mixer = Pooling(kernel_size=pool_size) # vits是msa,MLPs是mlp,这个用pool来替代
|
|
||||||
self.norm2 = LayerNorm(dim, 'WithBias')
|
|
||||||
mlp_hidden_dim = int(dim * mlp_ratio)
|
|
||||||
self.poolmlp = PoolMlp(in_features=dim, hidden_features=mlp_hidden_dim,
|
|
||||||
act_layer=act_layer, drop=drop)
|
|
||||||
|
|
||||||
# The following two techniques are useful to train deep PoolFormers.
|
|
||||||
self.drop_path = DropPath(drop_path) if drop_path > 0. \
|
|
||||||
else nn.Identity()
|
|
||||||
self.use_layer_scale = use_layer_scale
|
|
||||||
|
|
||||||
if use_layer_scale:
|
|
||||||
self.layer_scale_1 = nn.Parameter(
|
|
||||||
torch.ones(dim, dtype=torch.float32) * layer_scale_init_value)
|
|
||||||
|
|
||||||
self.layer_scale_2 = nn.Parameter(
|
|
||||||
torch.ones(dim, dtype=torch.float32) * layer_scale_init_value)
|
|
||||||
|
|
||||||
def forward(self, x): # 1 64 128 128
|
|
||||||
if self.use_layer_scale:
|
|
||||||
# self.layer_scale_1(64,)
|
|
||||||
tmp1 = self.layer_scale_1.unsqueeze(-1).unsqueeze(-1) # 64 1 1
|
|
||||||
normal = self.norm1(x) # 1 64 128 128
|
|
||||||
token_mix = self.token_mixer(normal) # 1 64 128 128
|
|
||||||
|
|
||||||
x = (x +
|
|
||||||
self.drop_path(
|
|
||||||
tmp1 * token_mix
|
|
||||||
)
|
|
||||||
# 该表达式将 self.layer_scale_1 这个一维张量(或变量)在维度末尾添加两个新的维度,使其从一维变为三维。这通常用于使其能够与三维的特征图进行广播操作,如元素相乘。具体用途可能包括调整卷积层或注意力机制中的权重。
|
|
||||||
)
|
|
||||||
x = x + self.drop_path(
|
|
||||||
self.layer_scale_2.unsqueeze(-1).unsqueeze(-1)
|
|
||||||
* self.poolmlp(self.norm2(x)))
|
|
||||||
else:
|
|
||||||
x = x + self.drop_path(self.token_mixer(self.norm1(x))) # 匹配cddfuse
|
|
||||||
x = x + self.drop_path(self.poolmlp(self.norm2(x)))
|
|
||||||
return x
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class DetailFeatureExtractionSAR(nn.Module):
|
|
||||||
def __init__(self, num_layers=3):
|
|
||||||
super(DetailFeatureExtractionSAR, self).__init__()
|
|
||||||
INNmodules = [DetailNode(useBlock=1) for _ in range(num_layers)]
|
|
||||||
self.net = nn.Sequential(*INNmodules)
|
|
||||||
# self.enhancement_module = WTConv2d(32, 32)
|
|
||||||
|
|
||||||
def forward(self, x): # 1 64 128 128
|
|
||||||
z1, z2 = x[:, :x.shape[1] // 2], x[:, x.shape[1] // 2:x.shape[1]] # 1 32 128 128
|
|
||||||
# # 增强并添加残差连接
|
|
||||||
# enhanced_z1 = self.enhancement_module(z1)
|
|
||||||
# enhanced_z2 = self.enhancement_module(z2)
|
|
||||||
# # 残差连接
|
|
||||||
# z1 = z1 + enhanced_z1
|
|
||||||
# z2 = z2 + enhanced_z2
|
|
||||||
for layer in self.net:
|
|
||||||
z1, z2 = layer(z1, z2)
|
|
||||||
return torch.cat((z1, z2), dim=1)
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class Restormer_Encoder(nn.Module):
|
class Restormer_Encoder(nn.Module):
|
||||||
def __init__(self,
|
def __init__(self,
|
||||||
inp_channels=1,
|
inp_channels=1,
|
||||||
@ -612,13 +414,13 @@ class Restormer_Encoder(nn.Module):
|
|||||||
):
|
):
|
||||||
super(Restormer_Encoder, self).__init__()
|
super(Restormer_Encoder, self).__init__()
|
||||||
|
|
||||||
# 区分
|
|
||||||
|
|
||||||
self.patch_embed = OverlapPatchEmbed(inp_channels, dim)
|
self.patch_embed = OverlapPatchEmbed(inp_channels, dim)
|
||||||
|
|
||||||
self.encoder_level1 = nn.Sequential(
|
self.encoder_level1 = nn.Sequential(
|
||||||
*[TransformerBlock(dim=dim, num_heads=heads[0], ffn_expansion_factor=ffn_expansion_factor,
|
*[TransformerBlock(dim=dim, num_heads=heads[0], ffn_expansion_factor=ffn_expansion_factor,
|
||||||
bias=bias, LayerNorm_type=LayerNorm_type) for i in range(num_blocks[0])])
|
bias=bias, LayerNorm_type=LayerNorm_type) for i in range(num_blocks[0])])
|
||||||
self.baseFeature = BaseFeatureExtraction(dim=dim)
|
self.baseFeature = BaseFeatureExtraction(dim=dim)
|
||||||
|
|
||||||
self.detailFeature = DetailFeatureExtraction()
|
self.detailFeature = DetailFeatureExtraction()
|
||||||
|
|
||||||
self.baseFeature_sar = BaseFeatureExtractionSAR(dim=dim)
|
self.baseFeature_sar = BaseFeatureExtractionSAR(dim=dim)
|
||||||
@ -652,7 +454,8 @@ class Restormer_Decoder(nn.Module):
|
|||||||
|
|
||||||
super(Restormer_Decoder, self).__init__()
|
super(Restormer_Decoder, self).__init__()
|
||||||
self.reduce_channel = nn.Conv2d(int(dim * 2), int(dim), kernel_size=1, bias=bias)
|
self.reduce_channel = nn.Conv2d(int(dim * 2), int(dim), kernel_size=1, bias=bias)
|
||||||
self.encoder_level2 = nn.Sequential(*[TransformerBlock(dim=dim, num_heads=heads[1], ffn_expansion_factor=ffn_expansion_factor,
|
self.encoder_level2 = nn.Sequential(
|
||||||
|
*[TransformerBlock(dim=dim, num_heads=heads[1], ffn_expansion_factor=ffn_expansion_factor,
|
||||||
bias=bias, LayerNorm_type=LayerNorm_type) for i in range(num_blocks[1])])
|
bias=bias, LayerNorm_type=LayerNorm_type) for i in range(num_blocks[1])])
|
||||||
self.output = nn.Sequential(
|
self.output = nn.Sequential(
|
||||||
nn.Conv2d(int(dim), int(dim) // 2, kernel_size=3,
|
nn.Conv2d(int(dim), int(dim) // 2, kernel_size=3,
|
||||||
@ -679,5 +482,3 @@ if __name__ == '__main__':
|
|||||||
window_size = 8
|
window_size = 8
|
||||||
modelE = Restormer_Encoder().cuda()
|
modelE = Restormer_Encoder().cuda()
|
||||||
modelD = Restormer_Decoder().cuda()
|
modelD = Restormer_Decoder().cuda()
|
||||||
print(modelE)
|
|
||||||
print(modelD)
|
|
||||||
|
6
train.py
6
train.py
@ -6,7 +6,7 @@ Import packages
|
|||||||
------------------------------------------------------------------------------
|
------------------------------------------------------------------------------
|
||||||
'''
|
'''
|
||||||
|
|
||||||
from net import Restormer_Encoder, Restormer_Decoder, BaseFeatureExtraction, DetailFeatureExtraction, BaseFeatureFusion,DetailFeatureFusion
|
from net import Restormer_Encoder, Restormer_Decoder, BaseFeatureExtraction, DetailFeatureExtraction
|
||||||
from utils.dataset import H5Dataset
|
from utils.dataset import H5Dataset
|
||||||
import os
|
import os
|
||||||
os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'
|
os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'
|
||||||
@ -85,8 +85,8 @@ device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
|||||||
DIDF_Encoder = nn.DataParallel(Restormer_Encoder()).to(device)
|
DIDF_Encoder = nn.DataParallel(Restormer_Encoder()).to(device)
|
||||||
DIDF_Decoder = nn.DataParallel(Restormer_Decoder()).to(device)
|
DIDF_Decoder = nn.DataParallel(Restormer_Decoder()).to(device)
|
||||||
# BaseFuseLayer = nn.DataParallel(BaseFeatureExtraction(dim=64, num_heads=8)).to(device)
|
# BaseFuseLayer = nn.DataParallel(BaseFeatureExtraction(dim=64, num_heads=8)).to(device)
|
||||||
BaseFuseLayer = nn.DataParallel(BaseFeatureFusion(dim=64)).to(device)
|
BaseFuseLayer = nn.DataParallel(BaseFeatureExtraction(dim=64)).to(device)
|
||||||
DetailFuseLayer = nn.DataParallel(DetailFeatureFusion(num_layers=1)).to(device)
|
DetailFuseLayer = nn.DataParallel(DetailFeatureExtraction(num_layers=1)).to(device)
|
||||||
|
|
||||||
# optimizer, scheduler and loss function
|
# optimizer, scheduler and loss function
|
||||||
optimizer1 = torch.optim.Adam(
|
optimizer1 = torch.optim.Adam(
|
||||||
|
Loading…
Reference in New Issue
Block a user