feat(net): 增加 SAR 图像处理支持
- 新增 BaseFeatureExtractionSAR 和 DetailFeatureExtractionSAR 模块 - 修改 DIDF_Encoder 类,支持 SAR 图像输入 - 更新测试和训练脚本,增加 SAR 图像处理相关逻辑
This commit is contained in:
parent
0ef5760d76
commit
e8a0212bbb
72
net.py
72
net.py
@ -112,6 +112,48 @@ class BaseFeatureExtraction(nn.Module):
|
||||
x = x + self.drop_path(self.poolmlp(self.norm2(x)))
|
||||
return x
|
||||
|
||||
class BaseFeatureExtractionSAR(nn.Module):
|
||||
def __init__(self, dim, pool_size=3, mlp_ratio=4.,
|
||||
act_layer=nn.GELU,
|
||||
# norm_layer=nn.LayerNorm,
|
||||
drop=0., drop_path=0.,
|
||||
use_layer_scale=True, layer_scale_init_value=1e-5):
|
||||
|
||||
super().__init__()
|
||||
|
||||
self.norm1 = LayerNorm(dim, 'WithBias')
|
||||
self.token_mixer = Pooling(kernel_size=pool_size) # vits是msa,MLPs是mlp,这个用pool来替代
|
||||
self.norm2 = LayerNorm(dim, 'WithBias')
|
||||
mlp_hidden_dim = int(dim * mlp_ratio)
|
||||
self.poolmlp = PoolMlp(in_features=dim, hidden_features=mlp_hidden_dim,
|
||||
act_layer=act_layer, drop=drop)
|
||||
|
||||
# The following two techniques are useful to train deep PoolFormers.
|
||||
self.drop_path = DropPath(drop_path) if drop_path > 0. \
|
||||
else nn.Identity()
|
||||
self.use_layer_scale = use_layer_scale
|
||||
|
||||
if use_layer_scale:
|
||||
self.layer_scale_1 = nn.Parameter(
|
||||
torch.ones(dim, dtype=torch.float32) * layer_scale_init_value)
|
||||
|
||||
self.layer_scale_2 = nn.Parameter(
|
||||
torch.ones(dim, dtype=torch.float32) * layer_scale_init_value)
|
||||
|
||||
def forward(self, x):
|
||||
if self.use_layer_scale:
|
||||
x = x + self.drop_path(
|
||||
self.layer_scale_1.unsqueeze(-1).unsqueeze(-1)
|
||||
* self.token_mixer(self.norm1(x)))
|
||||
x = x + self.drop_path(
|
||||
self.layer_scale_2.unsqueeze(-1).unsqueeze(-1)
|
||||
* self.poolmlp(self.norm2(x)))
|
||||
else:
|
||||
x = x + self.drop_path(self.token_mixer(self.norm1(x)))
|
||||
x = x + self.drop_path(self.poolmlp(self.norm2(x)))
|
||||
return x
|
||||
|
||||
|
||||
|
||||
class InvertedResidualBlock(nn.Module):
|
||||
def __init__(self, inp, oup, expand_ratio):
|
||||
@ -171,6 +213,19 @@ class DetailFeatureExtraction(nn.Module):
|
||||
return torch.cat((z1, z2), dim=1)
|
||||
|
||||
|
||||
class DetailFeatureExtractionSAR(nn.Module):
|
||||
def __init__(self, num_layers=3):
|
||||
super(DetailFeatureExtractionSAR, self).__init__()
|
||||
INNmodules = [DetailNode() for _ in range(num_layers)]
|
||||
self.net = nn.Sequential(*INNmodules)
|
||||
|
||||
def forward(self, x):
|
||||
z1, z2 = x[:, :x.shape[1] // 2], x[:, x.shape[1] // 2:x.shape[1]]
|
||||
for layer in self.net:
|
||||
z1, z2 = layer(z1, z2)
|
||||
return torch.cat((z1, z2), dim=1)
|
||||
|
||||
|
||||
# =============================================================================
|
||||
|
||||
# =============================================================================
|
||||
@ -352,14 +407,23 @@ class Restormer_Encoder(nn.Module):
|
||||
*[TransformerBlock(dim=dim, num_heads=heads[0], ffn_expansion_factor=ffn_expansion_factor,
|
||||
bias=bias, LayerNorm_type=LayerNorm_type) for i in range(num_blocks[0])])
|
||||
self.baseFeature = BaseFeatureExtraction(dim=dim)
|
||||
|
||||
self.detailFeature = DetailFeatureExtraction()
|
||||
|
||||
def forward(self, inp_img):
|
||||
self.baseFeatureSar= BaseFeatureExtractionSAR(dim=dim)
|
||||
self.detailFeatureSar = DetailFeatureExtractionSAR()
|
||||
|
||||
|
||||
|
||||
def forward(self, inp_img, sar_img=False):
|
||||
inp_enc_level1 = self.patch_embed(inp_img)
|
||||
out_enc_level1 = self.encoder_level1(inp_enc_level1)
|
||||
base_feature = self.baseFeature(out_enc_level1)
|
||||
detail_feature = self.detailFeature(out_enc_level1)
|
||||
|
||||
if sar_img:
|
||||
base_feature = self.baseFeature(out_enc_level1)
|
||||
detail_feature = self.detailFeature(out_enc_level1)
|
||||
else:
|
||||
base_feature= self.baseFeature(out_enc_level1)
|
||||
detail_feature = self.detailFeature(out_enc_level1)
|
||||
return base_feature, detail_feature, out_enc_level1
|
||||
|
||||
|
||||
|
@ -17,11 +17,11 @@ current_time = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
|
||||
|
||||
|
||||
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
|
||||
ckpt_path= r"/home/star/whaiDir/PFCFuse/models/whaiFusion11-15-17-48.pth"
|
||||
ckpt_path= r"/home/star/whaiDir/PFCFuse/models/whaiFusion11-15-22-09.pth"
|
||||
|
||||
for dataset_name in ["sar"]:
|
||||
print("\n"*2+"="*80)
|
||||
model_name="PFCFuse 最基本版本 "
|
||||
model_name="PFCFuse Enhance "
|
||||
print("The test result of "+dataset_name+' :')
|
||||
test_folder = os.path.join('test_img', dataset_name)
|
||||
test_out_folder=os.path.join('test_result',current_time,dataset_name)
|
||||
|
2
train.py
2
train.py
@ -149,7 +149,7 @@ for epoch in range(num_epochs):
|
||||
|
||||
if epoch < epoch_gap: #Phase I
|
||||
feature_V_B, feature_V_D, _ = DIDF_Encoder(data_VIS)
|
||||
feature_I_B, feature_I_D, _ = DIDF_Encoder(data_IR)
|
||||
feature_I_B, feature_I_D, _ = DIDF_Encoder(data_IR,sar_img=True)
|
||||
data_VIS_hat, _ = DIDF_Decoder(data_VIS, feature_V_B, feature_V_D)
|
||||
data_IR_hat, _ = DIDF_Decoder(data_IR, feature_I_B, feature_I_D)
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user