From e8a0212bbbdf8147b6669623281d57b4f0ab153c Mon Sep 17 00:00:00 2001 From: zjut Date: Sat, 16 Nov 2024 12:59:07 +0800 Subject: [PATCH 01/10] =?UTF-8?q?feat(net):=20=E5=A2=9E=E5=8A=A0=20SAR=20?= =?UTF-8?q?=E5=9B=BE=E5=83=8F=E5=A4=84=E7=90=86=E6=94=AF=E6=8C=81?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - 新增 BaseFeatureExtractionSAR 和 DetailFeatureExtractionSAR 模块 - 修改 DIDF_Encoder 类,支持 SAR 图像输入 - 更新测试和训练脚本,增加 SAR 图像处理相关逻辑 --- net.py | 72 ++++++++++++++++++++++++++++++++++++++++++++++++++--- test_IVF.py | 4 +-- train.py | 2 +- 3 files changed, 71 insertions(+), 7 deletions(-) diff --git a/net.py b/net.py index e404ed0..fa384a3 100644 --- a/net.py +++ b/net.py @@ -112,6 +112,48 @@ class BaseFeatureExtraction(nn.Module): x = x + self.drop_path(self.poolmlp(self.norm2(x))) return x +class BaseFeatureExtractionSAR(nn.Module): + def __init__(self, dim, pool_size=3, mlp_ratio=4., + act_layer=nn.GELU, + # norm_layer=nn.LayerNorm, + drop=0., drop_path=0., + use_layer_scale=True, layer_scale_init_value=1e-5): + + super().__init__() + + self.norm1 = LayerNorm(dim, 'WithBias') + self.token_mixer = Pooling(kernel_size=pool_size) # vits是msa,MLPs是mlp,这个用pool来替代 + self.norm2 = LayerNorm(dim, 'WithBias') + mlp_hidden_dim = int(dim * mlp_ratio) + self.poolmlp = PoolMlp(in_features=dim, hidden_features=mlp_hidden_dim, + act_layer=act_layer, drop=drop) + + # The following two techniques are useful to train deep PoolFormers. + self.drop_path = DropPath(drop_path) if drop_path > 0. \ + else nn.Identity() + self.use_layer_scale = use_layer_scale + + if use_layer_scale: + self.layer_scale_1 = nn.Parameter( + torch.ones(dim, dtype=torch.float32) * layer_scale_init_value) + + self.layer_scale_2 = nn.Parameter( + torch.ones(dim, dtype=torch.float32) * layer_scale_init_value) + + def forward(self, x): + if self.use_layer_scale: + x = x + self.drop_path( + self.layer_scale_1.unsqueeze(-1).unsqueeze(-1) + * self.token_mixer(self.norm1(x))) + x = x + self.drop_path( + self.layer_scale_2.unsqueeze(-1).unsqueeze(-1) + * self.poolmlp(self.norm2(x))) + else: + x = x + self.drop_path(self.token_mixer(self.norm1(x))) + x = x + self.drop_path(self.poolmlp(self.norm2(x))) + return x + + class InvertedResidualBlock(nn.Module): def __init__(self, inp, oup, expand_ratio): @@ -171,6 +213,19 @@ class DetailFeatureExtraction(nn.Module): return torch.cat((z1, z2), dim=1) +class DetailFeatureExtractionSAR(nn.Module): + def __init__(self, num_layers=3): + super(DetailFeatureExtractionSAR, self).__init__() + INNmodules = [DetailNode() for _ in range(num_layers)] + self.net = nn.Sequential(*INNmodules) + + def forward(self, x): + z1, z2 = x[:, :x.shape[1] // 2], x[:, x.shape[1] // 2:x.shape[1]] + for layer in self.net: + z1, z2 = layer(z1, z2) + return torch.cat((z1, z2), dim=1) + + # ============================================================================= # ============================================================================= @@ -352,14 +407,23 @@ class Restormer_Encoder(nn.Module): *[TransformerBlock(dim=dim, num_heads=heads[0], ffn_expansion_factor=ffn_expansion_factor, bias=bias, LayerNorm_type=LayerNorm_type) for i in range(num_blocks[0])]) self.baseFeature = BaseFeatureExtraction(dim=dim) - self.detailFeature = DetailFeatureExtraction() - def forward(self, inp_img): + self.baseFeatureSar= BaseFeatureExtractionSAR(dim=dim) + self.detailFeatureSar = DetailFeatureExtractionSAR() + + + + def forward(self, inp_img, sar_img=False): inp_enc_level1 = self.patch_embed(inp_img) out_enc_level1 = self.encoder_level1(inp_enc_level1) - base_feature = self.baseFeature(out_enc_level1) - detail_feature = self.detailFeature(out_enc_level1) + + if sar_img: + base_feature = self.baseFeature(out_enc_level1) + detail_feature = self.detailFeature(out_enc_level1) + else: + base_feature= self.baseFeature(out_enc_level1) + detail_feature = self.detailFeature(out_enc_level1) return base_feature, detail_feature, out_enc_level1 diff --git a/test_IVF.py b/test_IVF.py index 04e10ce..148ed1f 100644 --- a/test_IVF.py +++ b/test_IVF.py @@ -17,11 +17,11 @@ current_time = datetime.datetime.now().strftime("%Y%m%d_%H%M%S") os.environ["CUDA_VISIBLE_DEVICES"] = "0" -ckpt_path= r"/home/star/whaiDir/PFCFuse/models/whaiFusion11-15-17-48.pth" +ckpt_path= r"/home/star/whaiDir/PFCFuse/models/whaiFusion11-15-22-09.pth" for dataset_name in ["sar"]: print("\n"*2+"="*80) - model_name="PFCFuse 最基本版本 " + model_name="PFCFuse Enhance " print("The test result of "+dataset_name+' :') test_folder = os.path.join('test_img', dataset_name) test_out_folder=os.path.join('test_result',current_time,dataset_name) diff --git a/train.py b/train.py index 5bf94a8..9f2a1de 100644 --- a/train.py +++ b/train.py @@ -149,7 +149,7 @@ for epoch in range(num_epochs): if epoch < epoch_gap: #Phase I feature_V_B, feature_V_D, _ = DIDF_Encoder(data_VIS) - feature_I_B, feature_I_D, _ = DIDF_Encoder(data_IR) + feature_I_B, feature_I_D, _ = DIDF_Encoder(data_IR,sar_img=True) data_VIS_hat, _ = DIDF_Decoder(data_VIS, feature_V_B, feature_V_D) data_IR_hat, _ = DIDF_Decoder(data_IR, feature_I_B, feature_I_D) From 9224f9b640ddee1610eca4a172e516cb2aaad68f Mon Sep 17 00:00:00 2001 From: zjut Date: Sat, 16 Nov 2024 21:37:02 +0800 Subject: [PATCH 02/10] =?UTF-8?q?test:=20=E6=9B=B4=E6=96=B0=E6=B5=8B?= =?UTF-8?q?=E8=AF=95=E6=A8=A1=E5=9E=8B=E8=B7=AF=E5=BE=84?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - 修改了 test_IVF.py 文件中的模型路径 - 将旧路径 whaiFusion11-15-22-09.pth 更改为新路径 whaiFusion11-16-11-20.pth --- test_IVF.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test_IVF.py b/test_IVF.py index 148ed1f..467f7fb 100644 --- a/test_IVF.py +++ b/test_IVF.py @@ -17,7 +17,7 @@ current_time = datetime.datetime.now().strftime("%Y%m%d_%H%M%S") os.environ["CUDA_VISIBLE_DEVICES"] = "0" -ckpt_path= r"/home/star/whaiDir/PFCFuse/models/whaiFusion11-15-22-09.pth" +ckpt_path= r"/home/star/whaiDir/PFCFuse/models/whaiFusion11-16-11-20.pth" for dataset_name in ["sar"]: print("\n"*2+"="*80) From aae81d97fdc1221427f144711a7cd75c228903be Mon Sep 17 00:00:00 2001 From: zjut Date: Sat, 16 Nov 2024 21:39:34 +0800 Subject: [PATCH 03/10] =?UTF-8?q?feat(net):=20=E6=96=B0=E5=A2=9E=E5=9F=BA?= =?UTF-8?q?=E7=A1=80=E7=89=B9=E5=BE=81=E8=9E=8D=E5=90=88=E5=92=8C=E7=BB=86?= =?UTF-8?q?=E8=8A=82=E7=89=B9=E5=BE=81=E8=9E=8D=E5=90=88=E6=A8=A1=E5=9D=97?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - 添加了 BaseFeatureFusion 和 DetailFeatureFusion 两个新类 - 更新了 train.py 中的导入和实例化语句 --- net.py | 53 +++++++++++++++++++++++++++++++++++++++++++++++++++++ train.py | 7 ++++--- 2 files changed, 57 insertions(+), 3 deletions(-) diff --git a/net.py b/net.py index fa384a3..0b9fea7 100644 --- a/net.py +++ b/net.py @@ -71,6 +71,48 @@ class PoolMlp(nn.Module): x = self.drop(x) return x + +class BaseFeatureFusion(nn.Module): + def __init__(self, dim, pool_size=3, mlp_ratio=4., + act_layer=nn.GELU, + # norm_layer=nn.LayerNorm, + drop=0., drop_path=0., + use_layer_scale=True, layer_scale_init_value=1e-5): + + super().__init__() + + self.norm1 = LayerNorm(dim, 'WithBias') + self.token_mixer = Pooling(kernel_size=pool_size) # vits是msa,MLPs是mlp,这个用pool来替代 + self.norm2 = LayerNorm(dim, 'WithBias') + mlp_hidden_dim = int(dim * mlp_ratio) + self.poolmlp = PoolMlp(in_features=dim, hidden_features=mlp_hidden_dim, + act_layer=act_layer, drop=drop) + + # The following two techniques are useful to train deep PoolFormers. + self.drop_path = DropPath(drop_path) if drop_path > 0. \ + else nn.Identity() + self.use_layer_scale = use_layer_scale + + if use_layer_scale: + self.layer_scale_1 = nn.Parameter( + torch.ones(dim, dtype=torch.float32) * layer_scale_init_value) + + self.layer_scale_2 = nn.Parameter( + torch.ones(dim, dtype=torch.float32) * layer_scale_init_value) + + def forward(self, x): + if self.use_layer_scale: + x = x + self.drop_path( + self.layer_scale_1.unsqueeze(-1).unsqueeze(-1) + * self.token_mixer(self.norm1(x))) + x = x + self.drop_path( + self.layer_scale_2.unsqueeze(-1).unsqueeze(-1) + * self.poolmlp(self.norm2(x))) + else: + x = x + self.drop_path(self.token_mixer(self.norm1(x))) + x = x + self.drop_path(self.poolmlp(self.norm2(x))) + return x + class BaseFeatureExtraction(nn.Module): def __init__(self, dim, pool_size=3, mlp_ratio=4., act_layer=nn.GELU, @@ -199,6 +241,17 @@ class DetailNode(nn.Module): z1 = z1 * torch.exp(self.theta_rho(z2)) + self.theta_eta(z2) return z1, z2 +class DetailFeatureFusion(nn.Module): + def __init__(self, num_layers=3): + super(DetailFeatureFusion, self).__init__() + INNmodules = [DetailNode() for _ in range(num_layers)] + self.net = nn.Sequential(*INNmodules) + + def forward(self, x): + z1, z2 = x[:, :x.shape[1] // 2], x[:, x.shape[1] // 2:x.shape[1]] + for layer in self.net: + z1, z2 = layer(z1, z2) + return torch.cat((z1, z2), dim=1) class DetailFeatureExtraction(nn.Module): def __init__(self, num_layers=3): diff --git a/train.py b/train.py index 9f2a1de..bf43c31 100644 --- a/train.py +++ b/train.py @@ -6,7 +6,8 @@ Import packages ------------------------------------------------------------------------------ ''' -from net import Restormer_Encoder, Restormer_Decoder, BaseFeatureExtraction, DetailFeatureExtraction +from net import Restormer_Encoder, Restormer_Decoder, BaseFeatureExtraction, DetailFeatureExtraction, BaseFeatureFusion, \ + DetailFeatureFusion from utils.dataset import H5Dataset import os os.environ['KMP_DUPLICATE_LIB_OK'] = 'True' @@ -85,8 +86,8 @@ device = 'cuda' if torch.cuda.is_available() else 'cpu' DIDF_Encoder = nn.DataParallel(Restormer_Encoder()).to(device) DIDF_Decoder = nn.DataParallel(Restormer_Decoder()).to(device) # BaseFuseLayer = nn.DataParallel(BaseFeatureExtraction(dim=64, num_heads=8)).to(device) -BaseFuseLayer = nn.DataParallel(BaseFeatureExtraction(dim=64)).to(device) -DetailFuseLayer = nn.DataParallel(DetailFeatureExtraction(num_layers=1)).to(device) +BaseFuseLayer = nn.DataParallel(BaseFeatureFusion(dim=64)).to(device) +DetailFuseLayer = nn.DataParallel(DetailFeatureFusion(num_layers=1)).to(device) # optimizer, scheduler and loss function optimizer1 = torch.optim.Adam( From 23293e91382f643cb3e864742c60e13ec2370054 Mon Sep 17 00:00:00 2001 From: zjut Date: Sat, 16 Nov 2024 23:07:51 +0800 Subject: [PATCH 04/10] =?UTF-8?q?feat(net):=20=E6=9B=BF=E6=8D=A2=20token?= =?UTF-8?q?=5Fmixer=20=E4=B8=BA=20SCSA=20=E6=A8=A1=E5=9D=97?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - 引入新的 SCSA(空间和通道协同注意力)模块 - 用 SCSA 替换原有的 Pooling层作为 token_mixer - 删除了未使用的 SEBlock.py 文件- 移除了与当前项目无关的 TIAM(CV).py 文件 --- componets/SCSA.py | 156 ++++++++++++++++++++++++++++++++++++++++++ componets/TIAM(CV).py | 110 ----------------------------- net.py | 5 +- 3 files changed, 160 insertions(+), 111 deletions(-) create mode 100644 componets/SCSA.py delete mode 100644 componets/TIAM(CV).py diff --git a/componets/SCSA.py b/componets/SCSA.py new file mode 100644 index 0000000..d26d7ba --- /dev/null +++ b/componets/SCSA.py @@ -0,0 +1,156 @@ +import typing as t + +import torch +import torch.nn as nn +from einops.einops import rearrange +from mmengine.model import BaseModule +__all__ = ['SCSA'] + +"""SCSA:探索空间注意力和通道注意力之间的协同作用 +通道和空间注意力分别在为各种下游视觉任务提取特征依赖性和空间结构关系方面带来了显着的改进。 +虽然它们的结合更有利于发挥各自的优势,但通道和空间注意力之间的协同作用尚未得到充分探索,缺乏充分利用多语义信息的协同潜力来进行特征引导和缓解语义差异。 +我们的研究试图在多个语义层面揭示空间和通道注意力之间的协同关系,提出了一种新颖的空间和通道协同注意力模块(SCSA)。我们的SCSA由两部分组成:可共享的多语义空间注意力(SMSA)和渐进式通道自注意力(PCSA)。 +SMSA 集成多语义信息并利用渐进式压缩策略将判别性空间先验注入 PCSA 的通道自注意力中,有效地指导通道重新校准。此外,PCSA 中基于自注意力机制的稳健特征交互进一步缓解了 SMSA 中不同子特征之间多语义信息的差异。 +我们在七个基准数据集上进行了广泛的实验,包括 ImageNet-1K 上的分类、MSCOCO 2017 上的对象检测、ADE20K 上的分割以及其他四个复杂场景检测数据集。我们的结果表明,我们提出的 SCSA 不仅超越了当前最先进的注意力机制, +而且在各种任务场景中表现出增强的泛化能力。 +""" + +class SCSA(BaseModule): + + def __init__( + self, + dim: int, + head_num: int, + window_size: int = 7, + group_kernel_sizes: t.List[int] = [3, 5, 7, 9], + qkv_bias: bool = False, + fuse_bn: bool = False, + norm_cfg: t.Dict = dict(type='BN'), + act_cfg: t.Dict = dict(type='ReLU'), + down_sample_mode: str = 'avg_pool', + attn_drop_ratio: float = 0., + gate_layer: str = 'sigmoid', + ): + super(SCSA, self).__init__() + self.dim = dim + self.head_num = head_num + self.head_dim = dim // head_num + self.scaler = self.head_dim ** -0.5 + self.group_kernel_sizes = group_kernel_sizes + self.window_size = window_size + self.qkv_bias = qkv_bias + self.fuse_bn = fuse_bn + self.down_sample_mode = down_sample_mode + + assert self.dim // 4, 'The dimension of input feature should be divisible by 4.' + self.group_chans = group_chans = self.dim // 4 + + self.local_dwc = nn.Conv1d(group_chans, group_chans, kernel_size=group_kernel_sizes[0], + padding=group_kernel_sizes[0] // 2, groups=group_chans) + self.global_dwc_s = nn.Conv1d(group_chans, group_chans, kernel_size=group_kernel_sizes[1], + padding=group_kernel_sizes[1] // 2, groups=group_chans) + self.global_dwc_m = nn.Conv1d(group_chans, group_chans, kernel_size=group_kernel_sizes[2], + padding=group_kernel_sizes[2] // 2, groups=group_chans) + self.global_dwc_l = nn.Conv1d(group_chans, group_chans, kernel_size=group_kernel_sizes[3], + padding=group_kernel_sizes[3] // 2, groups=group_chans) + self.sa_gate = nn.Softmax(dim=2) if gate_layer == 'softmax' else nn.Sigmoid() + self.norm_h = nn.GroupNorm(4, dim) + self.norm_w = nn.GroupNorm(4, dim) + + self.conv_d = nn.Identity() + self.norm = nn.GroupNorm(1, dim) + self.q = nn.Conv2d(in_channels=dim, out_channels=dim, kernel_size=1, bias=qkv_bias, groups=dim) + self.k = nn.Conv2d(in_channels=dim, out_channels=dim, kernel_size=1, bias=qkv_bias, groups=dim) + self.v = nn.Conv2d(in_channels=dim, out_channels=dim, kernel_size=1, bias=qkv_bias, groups=dim) + self.attn_drop = nn.Dropout(attn_drop_ratio) + self.ca_gate = nn.Softmax(dim=1) if gate_layer == 'softmax' else nn.Sigmoid() + + if window_size == -1: + self.down_func = nn.AdaptiveAvgPool2d((1, 1)) + else: + if down_sample_mode == 'recombination': + self.down_func = self.space_to_chans + # dimensionality reduction + self.conv_d = nn.Conv2d(in_channels=dim * window_size ** 2, out_channels=dim, kernel_size=1, bias=False) + elif down_sample_mode == 'avg_pool': + self.down_func = nn.AvgPool2d(kernel_size=(window_size, window_size), stride=window_size) + elif down_sample_mode == 'max_pool': + self.down_func = nn.MaxPool2d(kernel_size=(window_size, window_size), stride=window_size) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + The dim of x is (B, C, H, W) + """ + # Spatial attention priority calculation + b, c, h_, w_ = x.size() + # (B, C, H) + x_h = x.mean(dim=3) + l_x_h, g_x_h_s, g_x_h_m, g_x_h_l = torch.split(x_h, self.group_chans, dim=1) + # (B, C, W) + x_w = x.mean(dim=2) + l_x_w, g_x_w_s, g_x_w_m, g_x_w_l = torch.split(x_w, self.group_chans, dim=1) + + x_h_attn = self.sa_gate(self.norm_h(torch.cat(( + self.local_dwc(l_x_h), + self.global_dwc_s(g_x_h_s), + self.global_dwc_m(g_x_h_m), + self.global_dwc_l(g_x_h_l), + ), dim=1))) + x_h_attn = x_h_attn.view(b, c, h_, 1) + + x_w_attn = self.sa_gate(self.norm_w(torch.cat(( + self.local_dwc(l_x_w), + self.global_dwc_s(g_x_w_s), + self.global_dwc_m(g_x_w_m), + self.global_dwc_l(g_x_w_l) + ), dim=1))) + x_w_attn = x_w_attn.view(b, c, 1, w_) + + x = x * x_h_attn * x_w_attn + + # Channel attention based on self attention + # reduce calculations + y = self.down_func(x) + y = self.conv_d(y) + _, _, h_, w_ = y.size() + + # normalization first, then reshape -> (B, H, W, C) -> (B, C, H * W) and generate q, k and v + y = self.norm(y) + q = self.q(y) + k = self.k(y) + v = self.v(y) + # (B, C, H, W) -> (B, head_num, head_dim, N) + q = rearrange(q, 'b (head_num head_dim) h w -> b head_num head_dim (h w)', head_num=int(self.head_num), + head_dim=int(self.head_dim)) + k = rearrange(k, 'b (head_num head_dim) h w -> b head_num head_dim (h w)', head_num=int(self.head_num), + head_dim=int(self.head_dim)) + v = rearrange(v, 'b (head_num head_dim) h w -> b head_num head_dim (h w)', head_num=int(self.head_num), + head_dim=int(self.head_dim)) + + # (B, head_num, head_dim, head_dim) + attn = q @ k.transpose(-2, -1) * self.scaler + attn = self.attn_drop(attn.softmax(dim=-1)) + # (B, head_num, head_dim, N) + attn = attn @ v + # (B, C, H_, W_) + attn = rearrange(attn, 'b head_num head_dim (h w) -> b (head_num head_dim) h w', h=int(h_), w=int(w_)) + # (B, C, 1, 1) + attn = attn.mean((2, 3), keepdim=True) + attn = self.ca_gate(attn) + return attn * x + +if __name__ == '__main__': + + block = SCSA( + dim=256, + head_num=8, + ) + + input_tensor = torch.rand(1, 256, 32, 32) + + # 调用模块进行前向传播 + output_tensor = block(input_tensor) + + # 打印输入和输出张量的大小 + print("Input size:", input_tensor.size()) + print("Output size:", output_tensor.size()) diff --git a/componets/TIAM(CV).py b/componets/TIAM(CV).py deleted file mode 100644 index b2595af..0000000 --- a/componets/TIAM(CV).py +++ /dev/null @@ -1,110 +0,0 @@ -import torch -import torch.nn as nn -import torch.nn.functional as F - -"""Elsevier2024 -变化检测 (CD) 是地球观测中一种重要的监测方法,尤其适用于土地利用分析、城市管理和灾害损失评估。然而,在星座互联和空天协作时代,感兴趣区域 (ROI) 的变化由于几何透视旋转和时间风格差异而导致许多错误检测。 -为了应对这些挑战,我们引入了 CDNeXt,该框架阐明了一种稳健而有效的方法,用于将基于预训练主干的 Siamese 网络与用于遥感图像的创新时空交互注意模块 (TIAM) 相结合。 -CDNeXt 可分为四个主要组件:编码器、交互器、解码器和检测器。值得注意的是,由 TIAM 提供支持的交互器从编码器提取的二进制时间特征中查询和重建空间透视依赖关系和时间风格相关性,以扩大 ROI 变化的差异。 -最后,检测器集成解码器生成的分层特征,随后生成二进制变化掩码。 -""" - -class SpatiotemporalAttentionFullNotWeightShared(nn.Module): - def __init__(self, in_channels, inter_channels=None, dimension=2, sub_sample=False): - super(SpatiotemporalAttentionFullNotWeightShared, self).__init__() - assert dimension in [2, ] - self.dimension = dimension - self.sub_sample = sub_sample - self.in_channels = in_channels - self.inter_channels = inter_channels - - if self.inter_channels is None: - self.inter_channels = in_channels // 2 - if self.inter_channels == 0: - self.inter_channels = 1 - - self.g1 = nn.Sequential( - nn.BatchNorm2d(self.in_channels), - nn.Conv2d(in_channels=self.in_channels, out_channels=self.inter_channels, - kernel_size=1, stride=1, padding=0) - ) - self.g2 = nn.Sequential( - nn.BatchNorm2d(self.in_channels), - nn.Conv2d(in_channels=self.in_channels, out_channels=self.inter_channels, - kernel_size=1, stride=1, padding=0), - ) - - self.W1 = nn.Sequential( - nn.Conv2d(in_channels=self.inter_channels, out_channels=self.in_channels, - kernel_size=1, stride=1, padding=0), - nn.BatchNorm2d(self.in_channels) - ) - self.W2 = nn.Sequential( - nn.Conv2d(in_channels=self.inter_channels, out_channels=self.in_channels, - kernel_size=1, stride=1, padding=0), - nn.BatchNorm2d(self.in_channels) - ) - self.theta = nn.Sequential( - nn.BatchNorm2d(self.in_channels), - nn.Conv2d(in_channels=self.in_channels, out_channels=self.inter_channels, - kernel_size=1, stride=1, padding=0), - ) - self.phi = nn.Sequential( - nn.BatchNorm2d(self.in_channels), - nn.Conv2d(in_channels=self.in_channels, out_channels=self.inter_channels, - kernel_size=1, stride=1, padding=0), - ) - - def forward(self, x1, x2): - """ - :param x: (b, c, h, w) - :param return_nl_map: if True return z, nl_map, else only return z. - :return: - """ - batch_size = x1.size(0) - g_x11 = self.g1(x1).reshape(batch_size, self.inter_channels, -1) - g_x12 = g_x11.permute(0, 2, 1) - g_x21 = self.g2(x2).reshape(batch_size, self.inter_channels, -1) - g_x22 = g_x21.permute(0, 2, 1) - - theta_x1 = self.theta(x1).reshape(batch_size, self.inter_channels, -1) - theta_x2 = theta_x1.permute(0, 2, 1) - - phi_x1 = self.phi(x2).reshape(batch_size, self.inter_channels, -1) - phi_x2 = phi_x1.permute(0, 2, 1) - - energy_time_1 = torch.matmul(theta_x1, phi_x2) - energy_time_2 = energy_time_1.permute(0, 2, 1) - energy_space_1 = torch.matmul(theta_x2, phi_x1) - energy_space_2 = energy_space_1.permute(0, 2, 1) - - energy_time_1s = F.softmax(energy_time_1, dim=-1) - energy_time_2s = F.softmax(energy_time_2, dim=-1) - energy_space_2s = F.softmax(energy_space_1, dim=-2) - energy_space_1s = F.softmax(energy_space_2, dim=-2) - # C1*S(C2) energy_time_1s * C1*H1W1 g_x12 * energy_space_1s S(H2W2)*H1W1 -> C1*H1W1 - y1 = torch.matmul(torch.matmul(energy_time_2s, g_x11), energy_space_2s).contiguous() # C2*H2W2 - # C2*S(C1) energy_time_2s * C2*H2W2 g_x21 * energy_space_2s S(H1W1)*H2W2 -> C2*H2W2 - y2 = torch.matmul(torch.matmul(energy_time_1s, g_x21), energy_space_1s).contiguous() # C1*H1W1 - y1 = y1.reshape(batch_size, self.inter_channels, *x2.size()[2:]) - y2 = y2.reshape(batch_size, self.inter_channels, *x1.size()[2:]) - return x1 + self.W1(y1), x2 + self.W2(y2) - - -if __name__ == '__main__': - in_channels = 64 - batch_size = 8 - height = 32 - width = 32 - - block = SpatiotemporalAttentionFullNotWeightShared(in_channels=in_channels) - - input1 = torch.rand(batch_size, in_channels, height, width) - input2 = torch.rand(batch_size, in_channels, height, width) - - output1, output2 = block(input1, input2) - - print(f"Input1 size: {input1.size()}") - print(f"Input2 size: {input2.size()}") - print(f"Output1 size: {output1.size()}") - print(f"Output2 size: {output2.size()}") diff --git a/net.py b/net.py index 0b9fea7..04160e6 100644 --- a/net.py +++ b/net.py @@ -6,6 +6,8 @@ import torch.utils.checkpoint as checkpoint from timm.models.layers import DropPath, to_2tuple, trunc_normal_ from einops import rearrange +from componets.SCSA import SCSA + def drop_path(x, drop_prob: float = 0., training: bool = False): if drop_prob == 0. or not training: @@ -164,7 +166,8 @@ class BaseFeatureExtractionSAR(nn.Module): super().__init__() self.norm1 = LayerNorm(dim, 'WithBias') - self.token_mixer = Pooling(kernel_size=pool_size) # vits是msa,MLPs是mlp,这个用pool来替代 + self.token_mixer = SCSA(dim=dim,head_num=8) + # self.token_mixer = Pooling(kernel_size=pool_size) # vits是msa,MLPs是mlp,这个用pool来替代 self.norm2 = LayerNorm(dim, 'WithBias') mlp_hidden_dim = int(dim * mlp_ratio) self.poolmlp = PoolMlp(in_features=dim, hidden_features=mlp_hidden_dim, From a435e035c000f70a42cd6c68eac890f6897dc4e9 Mon Sep 17 00:00:00 2001 From: zjut Date: Sun, 17 Nov 2024 09:53:12 +0800 Subject: [PATCH 05/10] =?UTF-8?q?test:=20=E6=9B=B4=E6=96=B0=E6=B5=8B?= =?UTF-8?q?=E8=AF=95=E6=A8=A1=E5=9E=8B=E5=B9=B6=E4=BF=AE=E6=94=B9=E6=A8=A1?= =?UTF-8?q?=E5=9E=8B=E5=90=8D=E7=A7=B0?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - 更新测试模型路径为 whaiFusion11-16-23-08.pth - 修改模型名称为 "PFCFuse Enhance SCSA" --- test_IVF.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test_IVF.py b/test_IVF.py index 467f7fb..bd3c269 100644 --- a/test_IVF.py +++ b/test_IVF.py @@ -17,11 +17,11 @@ current_time = datetime.datetime.now().strftime("%Y%m%d_%H%M%S") os.environ["CUDA_VISIBLE_DEVICES"] = "0" -ckpt_path= r"/home/star/whaiDir/PFCFuse/models/whaiFusion11-16-11-20.pth" +ckpt_path= r"/home/star/whaiDir/PFCFuse/models/whaiFusion11-16-23-08.pth" for dataset_name in ["sar"]: print("\n"*2+"="*80) - model_name="PFCFuse Enhance " + model_name="PFCFuse Enhance SCSA " print("The test result of "+dataset_name+' :') test_folder = os.path.join('test_img', dataset_name) test_out_folder=os.path.join('test_result',current_time,dataset_name) From 0336fc23ba41954a707473666f0aa79d417fc215 Mon Sep 17 00:00:00 2001 From: zjut Date: Sun, 17 Nov 2024 10:33:24 +0800 Subject: [PATCH 06/10] =?UTF-8?q?feat(net):=20=E4=B8=BA=20DetailNode?= =?UTF-8?q?=E6=A8=A1=E5=9D=97=E6=B7=BB=E5=8A=A0=E5=8F=AF=E9=80=89=E5=8D=B7?= =?UTF-8?q?=E7=A7=AF=E5=9D=97?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - 在 DetailNode 类中引入 useBlock 参数,用于选择不同的卷积块 - 新增 DepthwiseSeparableConvBlock 类,实现深度可分离卷积 - 根据 useBlock 的值,选择使用 DepthwiseSeparableConvBlock 或 InvertedResidualBlock - 优化了网络结构,提供了更多的灵活性和选择性 --- net.py | 32 +++++++++++++++++++++++++++----- 1 file changed, 27 insertions(+), 5 deletions(-) diff --git a/net.py b/net.py index 04160e6..8a55253 100644 --- a/net.py +++ b/net.py @@ -222,14 +222,36 @@ class InvertedResidualBlock(nn.Module): def forward(self, x): return self.bottleneckBlock(x) +class DepthwiseSeparableConvBlock(nn.Module): + def __init__(self, inp, oup, kernel_size=3, stride=1, padding=1): + super(DepthwiseSeparableConvBlock, self).__init__() + self.depthwise = nn.Conv2d(inp, inp, kernel_size, stride, padding, groups=inp, bias=False) + self.pointwise = nn.Conv2d(inp, oup, 1, bias=False) + self.bn = nn.BatchNorm2d(oup) + self.relu = nn.ReLU(inplace=True) + + def forward(self, x): + x = self.depthwise(x) + x = self.pointwise(x) + x = self.bn(x) + x = self.relu(x) + return x class DetailNode(nn.Module): - def __init__(self): + def __init__(self,useBlock=0): super(DetailNode, self).__init__() - - self.theta_phi = InvertedResidualBlock(inp=32, oup=32, expand_ratio=2) - self.theta_rho = InvertedResidualBlock(inp=32, oup=32, expand_ratio=2) - self.theta_eta = InvertedResidualBlock(inp=32, oup=32, expand_ratio=2) + if useBlock == 0: + self.theta_phi = DepthwiseSeparableConvBlock(inp=32, oup=32) + self.theta_rho = DepthwiseSeparableConvBlock(inp=32, oup=32) + self.theta_eta = DepthwiseSeparableConvBlock(inp=32, oup=32) + elif useBlock == 1: + self.theta_phi = InvertedResidualBlock(inp=32, oup=32, expand_ratio=2) + self.theta_rho = InvertedResidualBlock(inp=32, oup=32, expand_ratio=2) + self.theta_eta = InvertedResidualBlock(inp=32, oup=32, expand_ratio=2) + else: + self.theta_phi = InvertedResidualBlock(inp=32, oup=32, expand_ratio=2) + self.theta_rho = InvertedResidualBlock(inp=32, oup=32, expand_ratio=2) + self.theta_eta = InvertedResidualBlock(inp=32, oup=32, expand_ratio=2) self.shffleconv = nn.Conv2d(64, 64, kernel_size=1, stride=1, padding=0, bias=True) From 3b7b64c915f55f6d4790231fcc279c41f825ffee Mon Sep 17 00:00:00 2001 From: zjut Date: Sun, 17 Nov 2024 15:49:42 +0800 Subject: [PATCH 07/10] =?UTF-8?q?test(test=5FIVF.py):=20=E6=9B=B4=E6=96=B0?= =?UTF-8?q?=E6=A8=A1=E5=9E=8B=E8=B7=AF=E5=BE=84=E5=92=8C=E6=B5=8B=E8=AF=95?= =?UTF-8?q?=E5=90=8D=E7=A7=B0?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - 更新模型路径为新的权重文件 - 修改模型名称以反映新的增强方法 --- test_IVF.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test_IVF.py b/test_IVF.py index bd3c269..ef78c63 100644 --- a/test_IVF.py +++ b/test_IVF.py @@ -17,11 +17,11 @@ current_time = datetime.datetime.now().strftime("%Y%m%d_%H%M%S") os.environ["CUDA_VISIBLE_DEVICES"] = "0" -ckpt_path= r"/home/star/whaiDir/PFCFuse/models/whaiFusion11-16-23-08.pth" +ckpt_path= r"/home/star/whaiDir/PFCFuse/models/whaiFusion11-17-10-34.pth" for dataset_name in ["sar"]: print("\n"*2+"="*80) - model_name="PFCFuse Enhance SCSA " + model_name="PFCFuse Enhance 增加widthblock" print("The test result of "+dataset_name+' :') test_folder = os.path.join('test_img', dataset_name) test_out_folder=os.path.join('test_result',current_time,dataset_name) From 260e3aa760bd6ce94011e5b637a0bff7ab15fcd9 Mon Sep 17 00:00:00 2001 From: zjut Date: Sun, 17 Nov 2024 15:57:41 +0800 Subject: [PATCH 08/10] =?UTF-8?q?feat(net):=20=E6=B7=BB=E5=8A=A0=20WTConv2?= =?UTF-8?q?d=20=E5=B1=82=E5=B9=B6=E4=BF=AE=E6=94=B9=20DetailNode=20?= =?UTF-8?q?=E4=BD=BF=E7=94=A8-=20=E5=9C=A8=20net.py=20=E4=B8=AD=E6=B7=BB?= =?UTF-8?q?=E5=8A=A0=E4=BA=86=20WTConv2d=20=E5=B1=82=E7=9A=84=E5=AF=BC?= =?UTF-8?q?=E5=85=A5-=20=E4=BF=AE=E6=94=B9=E4=BA=86=20DetailNode=20?= =?UTF-8?q?=E7=B1=BB=E7=9A=84=E6=9E=84=E9=80=A0=E5=87=BD=E6=95=B0=EF=BC=8C?= =?UTF-8?q?=E5=A2=9E=E5=8A=A0=E4=BA=86=20useBlock=20=E5=8F=82=E6=95=B0=20-?= =?UTF-8?q?=20=E6=A0=B9=E6=8D=AE=20useBlock=20=E5=8F=82=E6=95=B0=E7=9A=84?= =?UTF-8?q?=E5=80=BC=EF=BC=8C=E9=80=89=E6=8B=A9=E4=BD=BF=E7=94=A8=20WTConv?= =?UTF-8?q?2d=E5=B1=82=E6=88=96=20InvertedResidualBlock-=20=E6=9B=B4?= =?UTF-8?q?=E6=96=B0=E4=BA=86=20DetailFeatureFusion=20=E5=92=8C=20DetailFe?= =?UTF-8?q?atureExtraction=20=E7=B1=BB=EF=BC=8C=E6=8C=87=E5=AE=9A=E4=BA=86?= =?UTF-8?q?=20DetailNode=20=E7=9A=84=20useBlock=20=E5=8F=82=E6=95=B0?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- net.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/net.py b/net.py index 8a55253..e5d2614 100644 --- a/net.py +++ b/net.py @@ -7,6 +7,7 @@ from timm.models.layers import DropPath, to_2tuple, trunc_normal_ from einops import rearrange from componets.SCSA import SCSA +from componets.WTConvCV2 import WTConv2d def drop_path(x, drop_prob: float = 0., training: bool = False): @@ -248,6 +249,10 @@ class DetailNode(nn.Module): self.theta_phi = InvertedResidualBlock(inp=32, oup=32, expand_ratio=2) self.theta_rho = InvertedResidualBlock(inp=32, oup=32, expand_ratio=2) self.theta_eta = InvertedResidualBlock(inp=32, oup=32, expand_ratio=2) + elif useBlock == 2: + self.theta_phi = WTConv2d(in_channels=32, out_channels=32) + self.theta_rho = WTConv2d(in_channels=32, out_channels=32) + self.theta_eta = WTConv2d(in_channels=32, out_channels=32) else: self.theta_phi = InvertedResidualBlock(inp=32, oup=32, expand_ratio=2) self.theta_rho = InvertedResidualBlock(inp=32, oup=32, expand_ratio=2) @@ -269,7 +274,7 @@ class DetailNode(nn.Module): class DetailFeatureFusion(nn.Module): def __init__(self, num_layers=3): super(DetailFeatureFusion, self).__init__() - INNmodules = [DetailNode() for _ in range(num_layers)] + INNmodules = [DetailNode(useBlock=1) for _ in range(num_layers)] self.net = nn.Sequential(*INNmodules) def forward(self, x): @@ -281,7 +286,7 @@ class DetailFeatureFusion(nn.Module): class DetailFeatureExtraction(nn.Module): def __init__(self, num_layers=3): super(DetailFeatureExtraction, self).__init__() - INNmodules = [DetailNode() for _ in range(num_layers)] + INNmodules = [DetailNode(useBlock=0) for _ in range(num_layers)] self.net = nn.Sequential(*INNmodules) def forward(self, x): From 5555d0d39c7c64977e90f7eb643dd8184f6c1011 Mon Sep 17 00:00:00 2001 From: zjut Date: Sun, 17 Nov 2024 16:12:44 +0800 Subject: [PATCH 09/10] =?UTF-8?q?feat(net):=20=E4=BF=AE=E6=94=B9=20DetailF?= =?UTF-8?q?eatureExtraction=20=E5=92=8C=20DetailFeatureExtractionSAR=20?= =?UTF-8?q?=E7=B1=BB?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - 将 DetailFeatureExtraction 类中的 DetailNode 使用参数 useBlock=2 - 将 DetailFeatureExtractionSAR 类中的 DetailNode 使用参数 useBlock=1 vi wtconv sar inn --- net.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/net.py b/net.py index e5d2614..debb154 100644 --- a/net.py +++ b/net.py @@ -286,7 +286,7 @@ class DetailFeatureFusion(nn.Module): class DetailFeatureExtraction(nn.Module): def __init__(self, num_layers=3): super(DetailFeatureExtraction, self).__init__() - INNmodules = [DetailNode(useBlock=0) for _ in range(num_layers)] + INNmodules = [DetailNode(useBlock=2) for _ in range(num_layers)] self.net = nn.Sequential(*INNmodules) def forward(self, x): @@ -299,7 +299,7 @@ class DetailFeatureExtraction(nn.Module): class DetailFeatureExtractionSAR(nn.Module): def __init__(self, num_layers=3): super(DetailFeatureExtractionSAR, self).__init__() - INNmodules = [DetailNode() for _ in range(num_layers)] + INNmodules = [DetailNode(useBlock=1) for _ in range(num_layers)] self.net = nn.Sequential(*INNmodules) def forward(self, x): From 775cbdf20f980ea653fe2125bc4735d3eeb63412 Mon Sep 17 00:00:00 2001 From: zjut Date: Mon, 18 Nov 2024 09:27:16 +0800 Subject: [PATCH 10/10] =?UTF-8?q?refactor(net):=20=E4=BF=AE=E6=94=B9=20Det?= =?UTF-8?q?ailFeatureExtraction=20=E5=92=8C=20DetailFeatureExtractionSAR?= =?UTF-8?q?=20=E7=B1=BB=E4=B8=AD=E7=9A=84=20DetailNode=20=E4=BD=BF?= =?UTF-8?q?=E7=94=A8=E6=96=B9=E5=BC=8F?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - 将 DetailFeatureExtraction 类中的 DetailNode 使用方式从 useBlock=2 改为 useBlock=1 - 将 DetailFeatureExtractionSAR 类中的 DetailNode 使用方式从 useBlock=1 改为 useBlock=2 --- net.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/net.py b/net.py index debb154..2e2bc40 100644 --- a/net.py +++ b/net.py @@ -286,7 +286,7 @@ class DetailFeatureFusion(nn.Module): class DetailFeatureExtraction(nn.Module): def __init__(self, num_layers=3): super(DetailFeatureExtraction, self).__init__() - INNmodules = [DetailNode(useBlock=2) for _ in range(num_layers)] + INNmodules = [DetailNode(useBlock=1) for _ in range(num_layers)] self.net = nn.Sequential(*INNmodules) def forward(self, x): @@ -299,7 +299,7 @@ class DetailFeatureExtraction(nn.Module): class DetailFeatureExtractionSAR(nn.Module): def __init__(self, num_layers=3): super(DetailFeatureExtractionSAR, self).__init__() - INNmodules = [DetailNode(useBlock=1) for _ in range(num_layers)] + INNmodules = [DetailNode(useBlock=2) for _ in range(num_layers)] self.net = nn.Sequential(*INNmodules) def forward(self, x):