From 1bd418f0e495732e9250370fb82720aa7b45ddcd Mon Sep 17 00:00:00 2001 From: zjut Date: Sat, 16 Nov 2024 12:59:07 +0800 Subject: [PATCH] =?UTF-8?q?feat(net):=20=E5=A2=9E=E5=8A=A0=20SAR=20?= =?UTF-8?q?=E5=9B=BE=E5=83=8F=E5=A4=84=E7=90=86=E6=94=AF=E6=8C=81?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - 新增 BaseFeatureExtractionSAR 和 DetailFeatureExtractionSAR 模块 - 修改 DIDF_Encoder 类,支持 SAR 图像输入 - 更新测试和训练脚本,增加 SAR 图像处理相关逻辑 --- net.py | 72 ++++++++++++++++++++++++++++++++++++++++++++++++++++---- train.py | 2 +- 2 files changed, 69 insertions(+), 5 deletions(-) diff --git a/net.py b/net.py index e404ed0..fa384a3 100644 --- a/net.py +++ b/net.py @@ -112,6 +112,48 @@ class BaseFeatureExtraction(nn.Module): x = x + self.drop_path(self.poolmlp(self.norm2(x))) return x +class BaseFeatureExtractionSAR(nn.Module): + def __init__(self, dim, pool_size=3, mlp_ratio=4., + act_layer=nn.GELU, + # norm_layer=nn.LayerNorm, + drop=0., drop_path=0., + use_layer_scale=True, layer_scale_init_value=1e-5): + + super().__init__() + + self.norm1 = LayerNorm(dim, 'WithBias') + self.token_mixer = Pooling(kernel_size=pool_size) # vits是msa,MLPs是mlp,这个用pool来替代 + self.norm2 = LayerNorm(dim, 'WithBias') + mlp_hidden_dim = int(dim * mlp_ratio) + self.poolmlp = PoolMlp(in_features=dim, hidden_features=mlp_hidden_dim, + act_layer=act_layer, drop=drop) + + # The following two techniques are useful to train deep PoolFormers. + self.drop_path = DropPath(drop_path) if drop_path > 0. \ + else nn.Identity() + self.use_layer_scale = use_layer_scale + + if use_layer_scale: + self.layer_scale_1 = nn.Parameter( + torch.ones(dim, dtype=torch.float32) * layer_scale_init_value) + + self.layer_scale_2 = nn.Parameter( + torch.ones(dim, dtype=torch.float32) * layer_scale_init_value) + + def forward(self, x): + if self.use_layer_scale: + x = x + self.drop_path( + self.layer_scale_1.unsqueeze(-1).unsqueeze(-1) + * self.token_mixer(self.norm1(x))) + x = x + self.drop_path( + self.layer_scale_2.unsqueeze(-1).unsqueeze(-1) + * self.poolmlp(self.norm2(x))) + else: + x = x + self.drop_path(self.token_mixer(self.norm1(x))) + x = x + self.drop_path(self.poolmlp(self.norm2(x))) + return x + + class InvertedResidualBlock(nn.Module): def __init__(self, inp, oup, expand_ratio): @@ -171,6 +213,19 @@ class DetailFeatureExtraction(nn.Module): return torch.cat((z1, z2), dim=1) +class DetailFeatureExtractionSAR(nn.Module): + def __init__(self, num_layers=3): + super(DetailFeatureExtractionSAR, self).__init__() + INNmodules = [DetailNode() for _ in range(num_layers)] + self.net = nn.Sequential(*INNmodules) + + def forward(self, x): + z1, z2 = x[:, :x.shape[1] // 2], x[:, x.shape[1] // 2:x.shape[1]] + for layer in self.net: + z1, z2 = layer(z1, z2) + return torch.cat((z1, z2), dim=1) + + # ============================================================================= # ============================================================================= @@ -352,14 +407,23 @@ class Restormer_Encoder(nn.Module): *[TransformerBlock(dim=dim, num_heads=heads[0], ffn_expansion_factor=ffn_expansion_factor, bias=bias, LayerNorm_type=LayerNorm_type) for i in range(num_blocks[0])]) self.baseFeature = BaseFeatureExtraction(dim=dim) - self.detailFeature = DetailFeatureExtraction() - def forward(self, inp_img): + self.baseFeatureSar= BaseFeatureExtractionSAR(dim=dim) + self.detailFeatureSar = DetailFeatureExtractionSAR() + + + + def forward(self, inp_img, sar_img=False): inp_enc_level1 = self.patch_embed(inp_img) out_enc_level1 = self.encoder_level1(inp_enc_level1) - base_feature = self.baseFeature(out_enc_level1) - detail_feature = self.detailFeature(out_enc_level1) + + if sar_img: + base_feature = self.baseFeature(out_enc_level1) + detail_feature = self.detailFeature(out_enc_level1) + else: + base_feature= self.baseFeature(out_enc_level1) + detail_feature = self.detailFeature(out_enc_level1) return base_feature, detail_feature, out_enc_level1 diff --git a/train.py b/train.py index dfb5ec2..9c4e990 100644 --- a/train.py +++ b/train.py @@ -149,7 +149,7 @@ for epoch in range(num_epochs): if epoch < epoch_gap: #Phase I feature_V_B, feature_V_D, _ = DIDF_Encoder(data_VIS) - feature_I_B, feature_I_D, _ = DIDF_Encoder(data_IR) + feature_I_B, feature_I_D, _ = DIDF_Encoder(data_IR,sar_img=True) data_VIS_hat, _ = DIDF_Decoder(data_VIS, feature_V_B, feature_V_D) data_IR_hat, _ = DIDF_Decoder(data_IR, feature_I_B, feature_I_D)