diff --git a/net.py b/net.py index fa384a3..0b9fea7 100644 --- a/net.py +++ b/net.py @@ -71,6 +71,48 @@ class PoolMlp(nn.Module): x = self.drop(x) return x + +class BaseFeatureFusion(nn.Module): + def __init__(self, dim, pool_size=3, mlp_ratio=4., + act_layer=nn.GELU, + # norm_layer=nn.LayerNorm, + drop=0., drop_path=0., + use_layer_scale=True, layer_scale_init_value=1e-5): + + super().__init__() + + self.norm1 = LayerNorm(dim, 'WithBias') + self.token_mixer = Pooling(kernel_size=pool_size) # vits是msa,MLPs是mlp,这个用pool来替代 + self.norm2 = LayerNorm(dim, 'WithBias') + mlp_hidden_dim = int(dim * mlp_ratio) + self.poolmlp = PoolMlp(in_features=dim, hidden_features=mlp_hidden_dim, + act_layer=act_layer, drop=drop) + + # The following two techniques are useful to train deep PoolFormers. + self.drop_path = DropPath(drop_path) if drop_path > 0. \ + else nn.Identity() + self.use_layer_scale = use_layer_scale + + if use_layer_scale: + self.layer_scale_1 = nn.Parameter( + torch.ones(dim, dtype=torch.float32) * layer_scale_init_value) + + self.layer_scale_2 = nn.Parameter( + torch.ones(dim, dtype=torch.float32) * layer_scale_init_value) + + def forward(self, x): + if self.use_layer_scale: + x = x + self.drop_path( + self.layer_scale_1.unsqueeze(-1).unsqueeze(-1) + * self.token_mixer(self.norm1(x))) + x = x + self.drop_path( + self.layer_scale_2.unsqueeze(-1).unsqueeze(-1) + * self.poolmlp(self.norm2(x))) + else: + x = x + self.drop_path(self.token_mixer(self.norm1(x))) + x = x + self.drop_path(self.poolmlp(self.norm2(x))) + return x + class BaseFeatureExtraction(nn.Module): def __init__(self, dim, pool_size=3, mlp_ratio=4., act_layer=nn.GELU, @@ -199,6 +241,17 @@ class DetailNode(nn.Module): z1 = z1 * torch.exp(self.theta_rho(z2)) + self.theta_eta(z2) return z1, z2 +class DetailFeatureFusion(nn.Module): + def __init__(self, num_layers=3): + super(DetailFeatureFusion, self).__init__() + INNmodules = [DetailNode() for _ in range(num_layers)] + self.net = nn.Sequential(*INNmodules) + + def forward(self, x): + z1, z2 = x[:, :x.shape[1] // 2], x[:, x.shape[1] // 2:x.shape[1]] + for layer in self.net: + z1, z2 = layer(z1, z2) + return torch.cat((z1, z2), dim=1) class DetailFeatureExtraction(nn.Module): def __init__(self, num_layers=3): diff --git a/train.py b/train.py index 9c4e990..4c409fd 100644 --- a/train.py +++ b/train.py @@ -6,7 +6,8 @@ Import packages ------------------------------------------------------------------------------ ''' -from net import Restormer_Encoder, Restormer_Decoder, BaseFeatureExtraction, DetailFeatureExtraction +from net import Restormer_Encoder, Restormer_Decoder, BaseFeatureExtraction, DetailFeatureExtraction, BaseFeatureFusion, \ + DetailFeatureFusion from utils.dataset import H5Dataset import os os.environ['KMP_DUPLICATE_LIB_OK'] = 'True' @@ -85,8 +86,8 @@ device = 'cuda' if torch.cuda.is_available() else 'cpu' DIDF_Encoder = nn.DataParallel(Restormer_Encoder()).to(device) DIDF_Decoder = nn.DataParallel(Restormer_Decoder()).to(device) # BaseFuseLayer = nn.DataParallel(BaseFeatureExtraction(dim=64, num_heads=8)).to(device) -BaseFuseLayer = nn.DataParallel(BaseFeatureExtraction(dim=64)).to(device) -DetailFuseLayer = nn.DataParallel(DetailFeatureExtraction(num_layers=1)).to(device) +BaseFuseLayer = nn.DataParallel(BaseFeatureFusion(dim=64)).to(device) +DetailFuseLayer = nn.DataParallel(DetailFeatureFusion(num_layers=1)).to(device) # optimizer, scheduler and loss function optimizer1 = torch.optim.Adam(