feat(net): 重构特征融合模块并添加新组件

- 新增 BaseFeatureFusion 和 DetailFeatureFusioin 类,用于特征融合
- 更新 ProjectRootManager 配置,使用本地 Python 3.8 环境
- 修改训练数据集路径
- 优化训练日志输出格式
This commit is contained in:
zjut 2024-11-14 16:02:05 +08:00
parent b6486dbaf4
commit c1eed72f24
2 changed files with 81 additions and 7 deletions

73
net.py
View File

@ -198,6 +198,58 @@ class BaseFeatureExtraction(nn.Module):
x = x + self.drop_path(self.poolmlp(self.norm2(x))) x = x + self.drop_path(self.poolmlp(self.norm2(x)))
return x return x
class BaseFeatureFusion(nn.Module):
def __init__(self, dim, pool_size=3, mlp_ratio=4.,
act_layer=nn.GELU,
# norm_layer=nn.LayerNorm,
drop=0., drop_path=0.,
use_layer_scale=True, layer_scale_init_value=1e-5):
super().__init__()
self.WTConv2d = WTConv2d(dim, dim)
self.norm1 = LayerNorm(dim, 'WithBias')
self.token_mixer = Pooling(kernel_size=pool_size) # vits是msaMLPs是mlp这个用pool来替代
self.norm2 = LayerNorm(dim, 'WithBias')
mlp_hidden_dim = int(dim * mlp_ratio)
self.poolmlp = PoolMlp(in_features=dim, hidden_features=mlp_hidden_dim,
act_layer=act_layer, drop=drop)
# The following two techniques are useful to train deep PoolFormers.
self.drop_path = DropPath(drop_path) if drop_path > 0. \
else nn.Identity()
self.use_layer_scale = use_layer_scale
if use_layer_scale:
self.layer_scale_1 = nn.Parameter(
torch.ones(dim, dtype=torch.float32) * layer_scale_init_value)
self.layer_scale_2 = nn.Parameter(
torch.ones(dim, dtype=torch.float32) * layer_scale_init_value)
def forward(self, x): # 1 64 128 128
if self.use_layer_scale:
# self.layer_scale_1(64,)
tmp1 = self.layer_scale_1.unsqueeze(-1).unsqueeze(-1) # 64 1 1
normal = self.norm1(x) # 1 64 128 128
token_mix = self.token_mixer(normal) # 1 64 128 128
x = self.WTConv2d(x)
x = (x +
self.drop_path(
tmp1 * token_mix
)
# 该表达式将 self.layer_scale_1 这个一维张量(或变量)在维度末尾添加两个新的维度,使其从一维变为三维。这通常用于使其能够与三维的特征图进行广播操作,如元素相乘。具体用途可能包括调整卷积层或注意力机制中的权重。
)
x = x + self.drop_path(
self.layer_scale_2.unsqueeze(-1).unsqueeze(-1)
* self.poolmlp(self.norm2(x)))
else:
x = x + self.drop_path(self.token_mixer(self.norm1(x))) # 匹配cddfuse
x = x + self.drop_path(self.poolmlp(self.norm2(x)))
return x
class InvertedResidualBlock(nn.Module): class InvertedResidualBlock(nn.Module):
def __init__(self, inp, oup, expand_ratio): def __init__(self, inp, oup, expand_ratio):
super(InvertedResidualBlock, self).__init__() super(InvertedResidualBlock, self).__init__()
@ -262,6 +314,27 @@ class DetailFeatureExtraction(nn.Module):
z1, z2 = layer(z1, z2) z1, z2 = layer(z1, z2)
return torch.cat((z1, z2), dim=1) return torch.cat((z1, z2), dim=1)
class DetailFeatureFusioin(nn.Module):
def __init__(self, num_layers=3):
super(DetailFeatureFusioin, self).__init__()
INNmodules = [DetailNode() for _ in range(num_layers)]
self.net = nn.Sequential(*INNmodules)
self.enhancement_module = WTConv2d(32, 32)
def forward(self, x): # 1 64 128 128
z1, z2 = x[:, :x.shape[1] // 2], x[:, x.shape[1] // 2:x.shape[1]] # 1 32 128 128
# 增强并添加残差连接
enhanced_z1 = self.enhancement_module(z1)
enhanced_z2 = self.enhancement_module(z2)
# 残差连接
z1 = z1 + enhanced_z1
z2 = z2 + enhanced_z2
for layer in self.net:
z1, z2 = layer(z1, z2)
return torch.cat((z1, z2), dim=1)
# ============================================================================= # =============================================================================
# ============================================================================= # =============================================================================

View File

@ -6,7 +6,8 @@ Import packages
------------------------------------------------------------------------------ ------------------------------------------------------------------------------
''' '''
from net import Restormer_Encoder, Restormer_Decoder, BaseFeatureExtraction, DetailFeatureExtraction from net import Restormer_Encoder, Restormer_Decoder, BaseFeatureExtraction, DetailFeatureExtraction, BaseFeatureFusion, \
DetailFeatureFusioin
from utils.dataset import H5Dataset from utils.dataset import H5Dataset
import os import os
os.environ['KMP_DUPLICATE_LIB_OK'] = 'True' os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'
@ -85,8 +86,8 @@ device = 'cuda' if torch.cuda.is_available() else 'cpu'
DIDF_Encoder = nn.DataParallel(Restormer_Encoder()).to(device) DIDF_Encoder = nn.DataParallel(Restormer_Encoder()).to(device)
DIDF_Decoder = nn.DataParallel(Restormer_Decoder()).to(device) DIDF_Decoder = nn.DataParallel(Restormer_Decoder()).to(device)
# BaseFuseLayer = nn.DataParallel(BaseFeatureExtraction(dim=64, num_heads=8)).to(device) # BaseFuseLayer = nn.DataParallel(BaseFeatureExtraction(dim=64, num_heads=8)).to(device)
BaseFuseLayer = nn.DataParallel(BaseFeatureExtraction(dim=64)).to(device) BaseFuseLayer = nn.DataParallel(BaseFeatureFusion(dim=64)).to(device)
DetailFuseLayer = nn.DataParallel(DetailFeatureExtraction(num_layers=1)).to(device) DetailFuseLayer = nn.DataParallel(DetailFeatureFusioin(num_layers=1)).to(device)
# optimizer, scheduler and loss function # optimizer, scheduler and loss function
optimizer1 = torch.optim.Adam( optimizer1 = torch.optim.Adam(
@ -109,7 +110,7 @@ Loss_ssim = kornia.losses.SSIM(11, reduction='mean')
HuberLoss = nn.HuberLoss() HuberLoss = nn.HuberLoss()
# data loader # data loader
trainloader = DataLoader(H5Dataset(r"/home/star/whaiDir/CDDFuse/data/MSRS_train_imgsize_128_stride_200.h5"), trainloader = DataLoader(H5Dataset(r"/home/star/whaiDir/PFCFuse/data/YYX_sar_opr_data_imgsize_256_stride_200.h5"),
batch_size=batch_size, batch_size=batch_size,
shuffle=True, shuffle=True,
num_workers=0) num_workers=0)
@ -222,8 +223,8 @@ for epoch in range(num_epochs):
time_left = datetime.timedelta(seconds=batches_left * (time.time() - prev_time)) time_left = datetime.timedelta(seconds=batches_left * (time.time() - prev_time))
epoch_time = time.time() - prev_time epoch_time = time.time() - prev_time
prev_time = time.time() prev_time = time.time()
if i % 100 == 0:
sys.stdout.write( sys.stdout.write(
"\r[Epoch %d/%d] [Batch %d/%d] [loss: %f] ETA: %.10s" "\r[Epoch %d/%d] [Batch %d/%d] [loss: %f] ETA: %.10s"
% ( % (
epoch, epoch,
@ -233,7 +234,7 @@ for epoch in range(num_epochs):
loss.item(), loss.item(),
time_left, time_left,
) )
) )
# adjust the learning rate # adjust the learning rate