feat(net): add SAR image processing support

- Add BaseFeatureExtractionSAR and DetailFeatureExtractionSAR modules
- Modify the DIDF_Encoder class to accept SAR image input
- Update the test and training scripts with the SAR processing logic
parent f87a65e68e
commit 1bd418f0e4
net.py | 68
@@ -112,6 +112,48 @@ class BaseFeatureExtraction(nn.Module):
             x = x + self.drop_path(self.poolmlp(self.norm2(x)))
         return x
 
 
+class BaseFeatureExtractionSAR(nn.Module):
+    def __init__(self, dim, pool_size=3, mlp_ratio=4.,
+                 act_layer=nn.GELU,
+                 # norm_layer=nn.LayerNorm,
+                 drop=0., drop_path=0.,
+                 use_layer_scale=True, layer_scale_init_value=1e-5):
+
+        super().__init__()
+
+        self.norm1 = LayerNorm(dim, 'WithBias')
+        self.token_mixer = Pooling(kernel_size=pool_size)  # ViTs use MSA, MLP models use MLPs; pooling replaces them here
+        self.norm2 = LayerNorm(dim, 'WithBias')
+        mlp_hidden_dim = int(dim * mlp_ratio)
+        self.poolmlp = PoolMlp(in_features=dim, hidden_features=mlp_hidden_dim,
+                               act_layer=act_layer, drop=drop)
+
+        # The following two techniques are useful to train deep PoolFormers.
+        self.drop_path = DropPath(drop_path) if drop_path > 0. \
+            else nn.Identity()
+        self.use_layer_scale = use_layer_scale
+
+        if use_layer_scale:
+            self.layer_scale_1 = nn.Parameter(
+                torch.ones(dim, dtype=torch.float32) * layer_scale_init_value)
+            self.layer_scale_2 = nn.Parameter(
+                torch.ones(dim, dtype=torch.float32) * layer_scale_init_value)
+
+    def forward(self, x):
+        if self.use_layer_scale:
+            x = x + self.drop_path(
+                self.layer_scale_1.unsqueeze(-1).unsqueeze(-1)
+                * self.token_mixer(self.norm1(x)))
+            x = x + self.drop_path(
+                self.layer_scale_2.unsqueeze(-1).unsqueeze(-1)
+                * self.poolmlp(self.norm2(x)))
+        else:
+            x = x + self.drop_path(self.token_mixer(self.norm1(x)))
+            x = x + self.drop_path(self.poolmlp(self.norm2(x)))
+        return x
+
+
 class InvertedResidualBlock(nn.Module):
     def __init__(self, inp, oup, expand_ratio):
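The new block mirrors the existing BaseFeatureExtraction: a PoolFormer-style unit in which average pooling acts as the token mixer and each residual branch is gated by a learnable per-channel layer scale. A minimal, self-contained sketch of that pattern follows; PoolMixerSketch is a hypothetical stand-in for the repo's Pooling module, which (as in PoolFormer) subtracts the identity from an average pool so that the caller's residual connection restores it.

import torch
import torch.nn as nn

class PoolMixerSketch(nn.Module):
    # Hypothetical stand-in for net.py's Pooling: average pool minus
    # identity, as in PoolFormer; spatial size is preserved.
    def __init__(self, pool_size=3):
        super().__init__()
        self.pool = nn.AvgPool2d(pool_size, stride=1,
                                 padding=pool_size // 2,
                                 count_include_pad=False)

    def forward(self, x):  # x: (B, C, H, W)
        return self.pool(x) - x

dim = 64
mixer = PoolMixerSketch()
# Layer scale: a learnable per-channel gain on the residual branch,
# initialized near zero so the block starts close to the identity.
scale = nn.Parameter(torch.ones(dim) * 1e-5)
x = torch.randn(2, dim, 32, 32)
out = x + scale.unsqueeze(-1).unsqueeze(-1) * mixer(x)
print(out.shape)  # torch.Size([2, 64, 32, 32])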
@@ -171,6 +213,19 @@ class DetailFeatureExtraction(nn.Module):
         return torch.cat((z1, z2), dim=1)
 
 
+class DetailFeatureExtractionSAR(nn.Module):
+    def __init__(self, num_layers=3):
+        super(DetailFeatureExtractionSAR, self).__init__()
+        INNmodules = [DetailNode() for _ in range(num_layers)]
+        self.net = nn.Sequential(*INNmodules)
+
+    def forward(self, x):
+        z1, z2 = x[:, :x.shape[1] // 2], x[:, x.shape[1] // 2:x.shape[1]]
+        for layer in self.net:
+            z1, z2 = layer(z1, z2)
+        return torch.cat((z1, z2), dim=1)
+
+
 # =============================================================================
 
 # =============================================================================
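As with DetailFeatureExtraction, the SAR variant splits the feature map into two channel halves, threads the pair through a stack of DetailNode layers, and concatenates the halves back together. The sketch below illustrates that split/recombine flow; CouplingSketch is a hypothetical additive-coupling stand-in for DetailNode, whose actual definition lives elsewhere in net.py.

import torch
import torch.nn as nn

class CouplingSketch(nn.Module):
    # Hypothetical stand-in for DetailNode: a simple additive coupling
    # in which each channel half is updated from the other.
    def __init__(self, half_dim):
        super().__init__()
        self.f = nn.Conv2d(half_dim, half_dim, kernel_size=3, padding=1)
        self.g = nn.Conv2d(half_dim, half_dim, kernel_size=3, padding=1)

    def forward(self, z1, z2):
        z1 = z1 + self.f(z2)
        z2 = z2 + self.g(z1)
        return z1, z2

layers = nn.Sequential(*[CouplingSketch(32) for _ in range(3)])
x = torch.randn(2, 64, 32, 32)
z1, z2 = x[:, :x.shape[1] // 2], x[:, x.shape[1] // 2:]
for layer in layers:
    z1, z2 = layer(z1, z2)
out = torch.cat((z1, z2), dim=1)
print(out.shape)  # torch.Size([2, 64, 32, 32]), same shape as the input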
@@ -352,12 +407,21 @@ class Restormer_Encoder(nn.Module):
             *[TransformerBlock(dim=dim, num_heads=heads[0], ffn_expansion_factor=ffn_expansion_factor,
                                bias=bias, LayerNorm_type=LayerNorm_type) for i in range(num_blocks[0])])
         self.baseFeature = BaseFeatureExtraction(dim=dim)
         self.detailFeature = DetailFeatureExtraction()
+        self.baseFeatureSar = BaseFeatureExtractionSAR(dim=dim)
+        self.detailFeatureSar = DetailFeatureExtractionSAR()
 
-    def forward(self, inp_img):
+    def forward(self, inp_img, sar_img=False):
         inp_enc_level1 = self.patch_embed(inp_img)
         out_enc_level1 = self.encoder_level1(inp_enc_level1)
 
-        base_feature = self.baseFeature(out_enc_level1)
-        detail_feature = self.detailFeature(out_enc_level1)
+        if sar_img:
+            base_feature = self.baseFeatureSar(out_enc_level1)
+            detail_feature = self.detailFeatureSar(out_enc_level1)
+        else:
+            base_feature = self.baseFeature(out_enc_level1)
+            detail_feature = self.detailFeature(out_enc_level1)
         return base_feature, detail_feature, out_enc_level1
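With the new sar_img flag, a single encoder instance serves both modalities: the default path keeps the original extractors, while sar_img=True routes the shallow features through the SAR-specific ones. A usage sketch follows, assuming the class is importable from net and that the constructor defaults and single-channel 128x128 inputs (both assumptions) match the training setup.

import torch
from net import Restormer_Encoder  # import path assumed

encoder = Restormer_Encoder()       # constructor defaults assumed
vis = torch.randn(1, 1, 128, 128)   # input shape assumed
sar = torch.randn(1, 1, 128, 128)

# Default branch: BaseFeatureExtraction / DetailFeatureExtraction.
base_v, detail_v, shallow_v = encoder(vis)
# SAR branch: BaseFeatureExtractionSAR / DetailFeatureExtractionSAR.
base_s, detail_s, shallow_s = encoder(sar, sar_img=True)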
train.py | 2
@@ -149,7 +149,7 @@ for epoch in range(num_epochs):
 
         if epoch < epoch_gap:  # Phase I
             feature_V_B, feature_V_D, _ = DIDF_Encoder(data_VIS)
-            feature_I_B, feature_I_D, _ = DIDF_Encoder(data_IR)
+            feature_I_B, feature_I_D, _ = DIDF_Encoder(data_IR, sar_img=True)
             data_VIS_hat, _ = DIDF_Decoder(data_VIS, feature_V_B, feature_V_D)
             data_IR_hat, _ = DIDF_Decoder(data_IR, feature_I_B, feature_I_D)