From c1eed72f2489614681940bdcf5b6cb256ae82bcc Mon Sep 17 00:00:00 2001
From: zjut
Date: Thu, 14 Nov 2024 16:02:05 +0800
Subject: [PATCH] feat(net): refactor the feature fusion module and add new
 components

- Add BaseFeatureFusion and DetailFeatureFusioin classes for feature fusion
- Update the ProjectRootManager configuration to use the local Python 3.8
  environment
- Change the training dataset path
- Improve the training log output format
---
 net.py   | 73 ++++++++++++++++++++++++++++++++++++++++++++++++++++++++
 train.py | 15 ++++++------
 2 files changed, 81 insertions(+), 7 deletions(-)
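Reviewer note (not part of the commit): BaseFeatureFusion.forward leans on a
small broadcasting trick for its layer-scale step. The sketch below is a
minimal standalone illustration, assuming only the shapes given in the diff's
inline comments; it is not code from this repository.

    import torch

    dim = 64
    layer_scale = torch.ones(dim) * 1e-5   # mirrors layer_scale_init_value
    feat = torch.randn(1, dim, 128, 128)   # the (1, 64, 128, 128) feature map

    # (64,) -> (64, 1, 1): two trailing singleton dimensions let the
    # per-channel scale broadcast across the batch and spatial dimensions
    # in the element-wise product.
    scaled = layer_scale.unsqueeze(-1).unsqueeze(-1) * feat
    assert scaled.shape == feat.shape

This is the same per-channel LayerScale pattern used to stabilize deep
PoolFormer-style blocks, which is why the init value is as small as 1e-5.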
diff --git a/net.py b/net.py
index 75df2bf..f23aac3 100644
--- a/net.py
+++ b/net.py
@@ -198,6 +198,58 @@ class BaseFeatureExtraction(nn.Module):
         x = x + self.drop_path(self.poolmlp(self.norm2(x)))
         return x
 
+class BaseFeatureFusion(nn.Module):
+    def __init__(self, dim, pool_size=3, mlp_ratio=4.,
+                 act_layer=nn.GELU,
+                 # norm_layer=nn.LayerNorm,
+                 drop=0., drop_path=0.,
+                 use_layer_scale=True, layer_scale_init_value=1e-5):
+
+        super().__init__()
+
+        self.WTConv2d = WTConv2d(dim, dim)
+        self.norm1 = LayerNorm(dim, 'WithBias')
+        self.token_mixer = Pooling(kernel_size=pool_size)  # ViTs mix tokens with MSA, MLP-like models with an MLP; pooling stands in here
+        self.norm2 = LayerNorm(dim, 'WithBias')
+        mlp_hidden_dim = int(dim * mlp_ratio)
+        self.poolmlp = PoolMlp(in_features=dim, hidden_features=mlp_hidden_dim,
+                               act_layer=act_layer, drop=drop)
+
+        # The following two techniques are useful to train deep PoolFormers.
+        self.drop_path = DropPath(drop_path) if drop_path > 0. \
+            else nn.Identity()
+        self.use_layer_scale = use_layer_scale
+
+        if use_layer_scale:
+            self.layer_scale_1 = nn.Parameter(
+                torch.ones(dim, dtype=torch.float32) * layer_scale_init_value)
+
+            self.layer_scale_2 = nn.Parameter(
+                torch.ones(dim, dtype=torch.float32) * layer_scale_init_value)
+
+    def forward(self, x):  # x: (1, 64, 128, 128)
+        if self.use_layer_scale:
+            # self.layer_scale_1 has shape (64,)
+            tmp1 = self.layer_scale_1.unsqueeze(-1).unsqueeze(-1)  # (64, 1, 1)
+            normal = self.norm1(x)  # (1, 64, 128, 128)
+            token_mix = self.token_mixer(normal)  # (1, 64, 128, 128)
+
+            x = self.WTConv2d(x)
+
+            # unsqueeze(-1).unsqueeze(-1) expands the 1-D layer-scale
+            # parameter from (64,) to (64, 1, 1) so that it broadcasts over
+            # the (B, C, H, W) feature map, scaling the token-mixer output
+            # channel-wise in the element-wise product.
+            x = x + self.drop_path(tmp1 * token_mix)
+
+            x = x + self.drop_path(
+                self.layer_scale_2.unsqueeze(-1).unsqueeze(-1)
+                * self.poolmlp(self.norm2(x)))
+        else:
+            x = x + self.drop_path(self.token_mixer(self.norm1(x)))  # matches CDDFuse
+            x = x + self.drop_path(self.poolmlp(self.norm2(x)))
+        return x
+
 class InvertedResidualBlock(nn.Module):
     def __init__(self, inp, oup, expand_ratio):
         super(InvertedResidualBlock, self).__init__()
@@ -262,6 +314,27 @@ class DetailFeatureExtraction(nn.Module):
             z1, z2 = layer(z1, z2)
         return torch.cat((z1, z2), dim=1)
 
+class DetailFeatureFusioin(nn.Module):
+    def __init__(self, num_layers=3):
+        super(DetailFeatureFusioin, self).__init__()
+        INNmodules = [DetailNode() for _ in range(num_layers)]
+        self.net = nn.Sequential(*INNmodules)
+        self.enhancement_module = WTConv2d(32, 32)
+
+    def forward(self, x):  # x: (1, 64, 128, 128)
+        z1, z2 = x[:, :x.shape[1] // 2], x[:, x.shape[1] // 2:x.shape[1]]  # each half: (1, 32, 128, 128)
+        # Enhance each half with the shared WTConv2d module.
+        enhanced_z1 = self.enhancement_module(z1)
+        enhanced_z2 = self.enhancement_module(z2)
+        # Residual connections around the enhancement.
+        z1 = z1 + enhanced_z1
+        z2 = z2 + enhanced_z2
+        for layer in self.net:
+            z1, z2 = layer(z1, z2)
+        return torch.cat((z1, z2), dim=1)
+
+
+
 # =============================================================================
 
 # =============================================================================
diff --git a/train.py b/train.py
index 501045e..a5d8c86 100644
--- a/train.py
+++ b/train.py
@@ -6,7 +6,8 @@ Import packages
 ------------------------------------------------------------------------------
 '''
 
-from net import Restormer_Encoder, Restormer_Decoder, BaseFeatureExtraction, DetailFeatureExtraction
+from net import Restormer_Encoder, Restormer_Decoder, BaseFeatureExtraction, DetailFeatureExtraction, BaseFeatureFusion, \
+    DetailFeatureFusioin
 from utils.dataset import H5Dataset
 import os
 os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'
@@ -85,8 +86,8 @@ device = 'cuda' if torch.cuda.is_available() else 'cpu'
 DIDF_Encoder = nn.DataParallel(Restormer_Encoder()).to(device)
 DIDF_Decoder = nn.DataParallel(Restormer_Decoder()).to(device)
 # BaseFuseLayer = nn.DataParallel(BaseFeatureExtraction(dim=64, num_heads=8)).to(device)
-BaseFuseLayer = nn.DataParallel(BaseFeatureExtraction(dim=64)).to(device)
-DetailFuseLayer = nn.DataParallel(DetailFeatureExtraction(num_layers=1)).to(device)
+BaseFuseLayer = nn.DataParallel(BaseFeatureFusion(dim=64)).to(device)
+DetailFuseLayer = nn.DataParallel(DetailFeatureFusioin(num_layers=1)).to(device)
 
 # optimizer, scheduler and loss function
 optimizer1 = torch.optim.Adam(
@@ -109,7 +110,7 @@ Loss_ssim = kornia.losses.SSIM(11, reduction='mean')
 HuberLoss = nn.HuberLoss()
 
 # data loader
-trainloader = DataLoader(H5Dataset(r"/home/star/whaiDir/CDDFuse/data/MSRS_train_imgsize_128_stride_200.h5"),
+trainloader = DataLoader(H5Dataset(r"/home/star/whaiDir/PFCFuse/data/YYX_sar_opr_data_imgsize_256_stride_200.h5"),
                          batch_size=batch_size,
                          shuffle=True,
                          num_workers=0)
@@ -222,8 +223,8 @@ for epoch in range(num_epochs):
         time_left = datetime.timedelta(seconds=batches_left * (time.time() - prev_time))
         epoch_time = time.time() - prev_time
         prev_time = time.time()
-        if i % 100 == 0:
-            sys.stdout.write(
+
+        sys.stdout.write(
             "\r[Epoch %d/%d] [Batch %d/%d] [loss: %f] ETA: %.10s"
             % (
                 epoch,
@@ -233,7 +234,7 @@ for epoch in range(num_epochs):
                 loss.item(),
                 time_left,
             )
-            )
+        )
 
     # adjust the learning rate
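Appendix (review aid, not part of the patch): a hypothetical smoke test for
the two new modules. It assumes the patched net.py imports cleanly (WTConv2d,
Pooling, PoolMlp, LayerNorm, and DetailNode are defined there, as the diff
context suggests) and that WTConv2d(dim, dim) preserves spatial size.

    import torch

    from net import BaseFeatureFusion, DetailFeatureFusioin

    # Instantiate the modules the same way train.py does.
    base = BaseFeatureFusion(dim=64)
    detail = DetailFeatureFusioin(num_layers=1)

    x = torch.randn(1, 64, 128, 128)  # shape noted in the forward() comments

    # Both modules are residual-style and should preserve the input layout:
    # BaseFeatureFusion scales the token-mixer output per channel, while
    # DetailFeatureFusioin splits the 64 channels into two halves of 32,
    # enhances and couples them, then re-concatenates.
    assert base(x).shape == x.shape
    assert detail(x).shape == x.shape
    print("fusion modules preserve the (B, 64, H, W) shape")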