feat(net): add base feature fusion and detail feature fusion modules
- Added two new classes: BaseFeatureFusion and DetailFeatureFusion
- Updated the corresponding imports and instantiations in train.py
This commit is contained in:
parent
1bd418f0e4
commit
0cf1726eeb
net.py (53 lines changed)
@@ -71,6 +71,48 @@ class PoolMlp(nn.Module):
         x = self.drop(x)
         return x
 
 
+class BaseFeatureFusion(nn.Module):
+    def __init__(self, dim, pool_size=3, mlp_ratio=4.,
+                 act_layer=nn.GELU,
+                 # norm_layer=nn.LayerNorm,
+                 drop=0., drop_path=0.,
+                 use_layer_scale=True, layer_scale_init_value=1e-5):
+
+        super().__init__()
+
+        self.norm1 = LayerNorm(dim, 'WithBias')
+        self.token_mixer = Pooling(kernel_size=pool_size)  # ViTs use MSA, MLP-like models use MLPs; pooling stands in for both here
+        self.norm2 = LayerNorm(dim, 'WithBias')
+        mlp_hidden_dim = int(dim * mlp_ratio)
+        self.poolmlp = PoolMlp(in_features=dim, hidden_features=mlp_hidden_dim,
+                               act_layer=act_layer, drop=drop)
+
+        # The following two techniques are useful to train deep PoolFormers.
+        self.drop_path = DropPath(drop_path) if drop_path > 0. \
+            else nn.Identity()
+        self.use_layer_scale = use_layer_scale
+
+        if use_layer_scale:
+            self.layer_scale_1 = nn.Parameter(
+                torch.ones(dim, dtype=torch.float32) * layer_scale_init_value)
+
+            self.layer_scale_2 = nn.Parameter(
+                torch.ones(dim, dtype=torch.float32) * layer_scale_init_value)
+
+    def forward(self, x):
+        if self.use_layer_scale:
+            x = x + self.drop_path(
+                self.layer_scale_1.unsqueeze(-1).unsqueeze(-1)
+                * self.token_mixer(self.norm1(x)))
+            x = x + self.drop_path(
+                self.layer_scale_2.unsqueeze(-1).unsqueeze(-1)
+                * self.poolmlp(self.norm2(x)))
+        else:
+            x = x + self.drop_path(self.token_mixer(self.norm1(x)))
+            x = x + self.drop_path(self.poolmlp(self.norm2(x)))
+        return x
+
+
 class BaseFeatureExtraction(nn.Module):
     def __init__(self, dim, pool_size=3, mlp_ratio=4.,
                  act_layer=nn.GELU,
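The BaseFeatureFusion block added above is structurally a PoolFormer block: pooling replaces attention as the token mixer, with optional layer scale and stochastic depth around the two residual branches. A minimal smoke test, assuming net.py (with its LayerNorm, Pooling, PoolMlp, and DropPath helpers) is importable; the tensor shape is an illustrative assumption:

import torch
from net import BaseFeatureFusion

block = BaseFeatureFusion(dim=64)    # same signature as BaseFeatureExtraction
feats = torch.randn(2, 64, 32, 32)   # (batch, channels, H, W); shape assumed for illustration
out = block(feats)                   # two residual branches: pooling mixer, then PoolMlp
assert out.shape == feats.shape      # the block is shape-preserving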
@@ -199,6 +241,17 @@ class DetailNode(nn.Module):
         z1 = z1 * torch.exp(self.theta_rho(z2)) + self.theta_eta(z2)
         return z1, z2
 
 
+class DetailFeatureFusion(nn.Module):
+    def __init__(self, num_layers=3):
+        super(DetailFeatureFusion, self).__init__()
+        INNmodules = [DetailNode() for _ in range(num_layers)]
+        self.net = nn.Sequential(*INNmodules)
+
+    def forward(self, x):
+        z1, z2 = x[:, :x.shape[1] // 2], x[:, x.shape[1] // 2:x.shape[1]]
+        for layer in self.net:
+            z1, z2 = layer(z1, z2)
+        return torch.cat((z1, z2), dim=1)
+
+
 class DetailFeatureExtraction(nn.Module):
     def __init__(self, num_layers=3):
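DetailFeatureFusion keeps DetailFeatureExtraction's coupling scheme: the input is split along the channel dimension into z1 and z2, passed through a chain of invertible-style DetailNode blocks, and re-concatenated, so the channel count must be even and is preserved. A sketch of that contract (tensor shape assumed for illustration):

import torch
from net import DetailFeatureFusion

fuse = DetailFeatureFusion(num_layers=1)
x = torch.randn(2, 64, 32, 32)   # 64 channels are split into two halves of 32
y = fuse(x)                      # each DetailNode affinely couples z1 and z2
assert y.shape == x.shape        # torch.cat((z1, z2), dim=1) restores the channels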
train.py (7 lines changed)
@@ -6,7 +6,8 @@ Import packages
 ------------------------------------------------------------------------------
 '''
 
-from net import Restormer_Encoder, Restormer_Decoder, BaseFeatureExtraction, DetailFeatureExtraction
+from net import Restormer_Encoder, Restormer_Decoder, BaseFeatureExtraction, DetailFeatureExtraction, BaseFeatureFusion, \
+    DetailFeatureFusion
 from utils.dataset import H5Dataset
 import os
 os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'
@@ -85,8 +86,8 @@ device = 'cuda' if torch.cuda.is_available() else 'cpu'
 DIDF_Encoder = nn.DataParallel(Restormer_Encoder()).to(device)
 DIDF_Decoder = nn.DataParallel(Restormer_Decoder()).to(device)
 # BaseFuseLayer = nn.DataParallel(BaseFeatureExtraction(dim=64, num_heads=8)).to(device)
-BaseFuseLayer = nn.DataParallel(BaseFeatureExtraction(dim=64)).to(device)
-DetailFuseLayer = nn.DataParallel(DetailFeatureExtraction(num_layers=1)).to(device)
+BaseFuseLayer = nn.DataParallel(BaseFeatureFusion(dim=64)).to(device)
+DetailFuseLayer = nn.DataParallel(DetailFeatureFusion(num_layers=1)).to(device)
 
 # optimizer, scheduler and loss function
 optimizer1 = torch.optim.Adam(
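Because the fusion classes keep the constructor signatures of the extraction classes they replace, the train.py change is a drop-in swap. A hypothetical sketch of the fusion step these layers serve, with placeholder names standing in for the per-modality encoder features (not part of this diff):

import torch
import torch.nn as nn
from net import BaseFeatureFusion, DetailFeatureFusion

device = 'cuda' if torch.cuda.is_available() else 'cpu'
BaseFuseLayer = nn.DataParallel(BaseFeatureFusion(dim=64)).to(device)
DetailFuseLayer = nn.DataParallel(DetailFeatureFusion(num_layers=1)).to(device)

# base_vis / base_ir etc. are hypothetical stand-ins for encoder outputs
base_vis = torch.randn(2, 64, 32, 32, device=device)
base_ir = torch.randn(2, 64, 32, 32, device=device)
detail_vis = torch.randn(2, 64, 32, 32, device=device)
detail_ir = torch.randn(2, 64, 32, 32, device=device)

fused_base = BaseFuseLayer(base_vis + base_ir)           # sum, then refine with the new block
fused_detail = DetailFuseLayer(detail_vis + detail_ir)   # same drop-in pattern for detail features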