feat(net): add base feature fusion and detail feature fusion modules

- Added two new classes: BaseFeatureFusion and DetailFeatureFusion
- Updated the import and instantiation statements in train.py accordingly
zjut 2024-11-16 21:39:34 +08:00
parent 1bd418f0e4
commit 0cf1726eeb
2 changed files with 57 additions and 3 deletions

net.py

@@ -71,6 +71,48 @@ class PoolMlp(nn.Module):
        x = self.drop(x)
        return x

class BaseFeatureFusion(nn.Module):
    def __init__(self, dim, pool_size=3, mlp_ratio=4.,
                 act_layer=nn.GELU,
                 # norm_layer=nn.LayerNorm,
                 drop=0., drop_path=0.,
                 use_layer_scale=True, layer_scale_init_value=1e-5):
        super().__init__()
        self.norm1 = LayerNorm(dim, 'WithBias')
        # ViTs use MSA and MLP-based models use MLP as the token mixer;
        # here pooling is used in its place.
        self.token_mixer = Pooling(kernel_size=pool_size)
        self.norm2 = LayerNorm(dim, 'WithBias')
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.poolmlp = PoolMlp(in_features=dim, hidden_features=mlp_hidden_dim,
                               act_layer=act_layer, drop=drop)

        # The following two techniques are useful to train deep PoolFormers.
        self.drop_path = DropPath(drop_path) if drop_path > 0. \
            else nn.Identity()
        self.use_layer_scale = use_layer_scale
        if use_layer_scale:
            self.layer_scale_1 = nn.Parameter(
                torch.ones(dim, dtype=torch.float32) * layer_scale_init_value)
            self.layer_scale_2 = nn.Parameter(
                torch.ones(dim, dtype=torch.float32) * layer_scale_init_value)

    def forward(self, x):
        if self.use_layer_scale:
            x = x + self.drop_path(
                self.layer_scale_1.unsqueeze(-1).unsqueeze(-1)
                * self.token_mixer(self.norm1(x)))
            x = x + self.drop_path(
                self.layer_scale_2.unsqueeze(-1).unsqueeze(-1)
                * self.poolmlp(self.norm2(x)))
        else:
            x = x + self.drop_path(self.token_mixer(self.norm1(x)))
            x = x + self.drop_path(self.poolmlp(self.norm2(x)))
        return x
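
The new block is a PoolFormer-style residual unit: pooling-based token mixing followed by a channel MLP, each branch optionally scaled by a learnable per-channel factor. A minimal smoke test (illustrative, not part of the commit; it assumes the LayerNorm, Pooling, PoolMlp, and DropPath helpers already defined in net.py for the existing BaseFeatureExtraction class):

import torch
from net import BaseFeatureFusion

x = torch.randn(2, 64, 32, 32)      # (B, C, H, W) feature map; C must equal dim
block = BaseFeatureFusion(dim=64)
assert block(x).shape == x.shape    # both residual branches preserve the shape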

class BaseFeatureExtraction(nn.Module):
    def __init__(self, dim, pool_size=3, mlp_ratio=4.,
                 act_layer=nn.GELU,
@@ -199,6 +241,17 @@ class DetailNode(nn.Module):
        z1 = z1 * torch.exp(self.theta_rho(z2)) + self.theta_eta(z2)
        return z1, z2

class DetailFeatureFusion(nn.Module):
    def __init__(self, num_layers=3):
        super(DetailFeatureFusion, self).__init__()
        INNmodules = [DetailNode() for _ in range(num_layers)]
        self.net = nn.Sequential(*INNmodules)

    def forward(self, x):
        z1, z2 = x[:, :x.shape[1] // 2], x[:, x.shape[1] // 2:x.shape[1]]
        for layer in self.net:
            z1, z2 = layer(z1, z2)
        return torch.cat((z1, z2), dim=1)
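
The forward pass splits the input along the channel axis into two halves (z1, z2), threads them through the stacked invertible DetailNode blocks, and concatenates the results, so the output shape matches the input. A hedged usage sketch (illustrative only; it assumes DetailNode operates on 32-channel halves, as it does for the existing DetailFeatureExtraction class, hence the 64-channel input):

import torch
from net import DetailFeatureFusion

x = torch.randn(2, 64, 32, 32)      # channel count must be even so it splits cleanly
fuse = DetailFeatureFusion(num_layers=1)
assert fuse(x).shape == x.shape     # concatenating z1 and z2 restores the shape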

class DetailFeatureExtraction(nn.Module):
    def __init__(self, num_layers=3):

train.py

@@ -6,7 +6,8 @@ Import packages
------------------------------------------------------------------------------
'''
-from net import Restormer_Encoder, Restormer_Decoder, BaseFeatureExtraction, DetailFeatureExtraction
+from net import Restormer_Encoder, Restormer_Decoder, BaseFeatureExtraction, DetailFeatureExtraction, BaseFeatureFusion, \
+    DetailFeatureFusion
from utils.dataset import H5Dataset
import os
os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'
@@ -85,8 +86,8 @@ device = 'cuda' if torch.cuda.is_available() else 'cpu'
DIDF_Encoder = nn.DataParallel(Restormer_Encoder()).to(device)
DIDF_Decoder = nn.DataParallel(Restormer_Decoder()).to(device)
# BaseFuseLayer = nn.DataParallel(BaseFeatureExtraction(dim=64, num_heads=8)).to(device)
-BaseFuseLayer = nn.DataParallel(BaseFeatureExtraction(dim=64)).to(device)
+BaseFuseLayer = nn.DataParallel(BaseFeatureFusion(dim=64)).to(device)
-DetailFuseLayer = nn.DataParallel(DetailFeatureExtraction(num_layers=1)).to(device)
+DetailFuseLayer = nn.DataParallel(DetailFeatureFusion(num_layers=1)).to(device)
# optimizer, scheduler and loss function
optimizer1 = torch.optim.Adam(
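
For context, a hedged sketch of how the two fusion layers would sit between encoder and decoder in this kind of two-branch pipeline; every variable name below is an assumption for illustration, not taken from this diff:

import torch

# Hypothetical encoder outputs for the visible (V) and infrared (I) branches;
# move these to `device` first when running on GPU:
feature_V_B = torch.randn(2, 64, 32, 32)   # base features, visible branch
feature_I_B = torch.randn(2, 64, 32, 32)   # base features, infrared branch
feature_V_D = torch.randn(2, 64, 32, 32)   # detail features, visible branch
feature_I_D = torch.randn(2, 64, 32, 32)   # detail features, infrared branch

feature_F_B = BaseFuseLayer(feature_V_B + feature_I_B)     # fused base features
feature_F_D = DetailFuseLayer(feature_V_D + feature_I_D)   # fused detail features
# the fused features would then go to DIDF_Decoder to reconstruct the fused image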