feat(net): refactor the feature fusion module and add new components
- Add BaseFeatureFusion and DetailFeatureFusioin classes for feature fusion
- Update the ProjectRootManager configuration to use the local Python 3.8 environment
- Change the training dataset path
- Improve the training log output format
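For context, a minimal usage sketch of how the two new fusion modules are wired in after this change. It mirrors the train.py hunks further below; the dummy input shape follows the "1 64 128 128" comments in net.py, and running it requires the repo's own net.py together with its WTConv2d, Pooling and DetailNode dependencies.

import torch
import torch.nn as nn
from net import BaseFeatureFusion, DetailFeatureFusioin  # new classes added in this commit

device = 'cuda' if torch.cuda.is_available() else 'cpu'
# Same wiring as the updated train.py: the fusion classes replace the extraction classes here.
BaseFuseLayer = nn.DataParallel(BaseFeatureFusion(dim=64)).to(device)
DetailFuseLayer = nn.DataParallel(DetailFeatureFusioin(num_layers=1)).to(device)

feat = torch.randn(1, 64, 128, 128, device=device)  # dummy fused feature map
print(BaseFuseLayer(feat).shape)    # expected: torch.Size([1, 64, 128, 128])
print(DetailFuseLayer(feat).shape)  # expected: torch.Size([1, 64, 128, 128])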
This commit is contained in:
parent b6486dbaf4
commit c1eed72f24
net.py (73 lines changed)
@@ -198,6 +198,58 @@ class BaseFeatureExtraction(nn.Module):
        x = x + self.drop_path(self.poolmlp(self.norm2(x)))
        return x


class BaseFeatureFusion(nn.Module):
    def __init__(self, dim, pool_size=3, mlp_ratio=4.,
                 act_layer=nn.GELU,
                 # norm_layer=nn.LayerNorm,
                 drop=0., drop_path=0.,
                 use_layer_scale=True, layer_scale_init_value=1e-5):

        super().__init__()

        self.WTConv2d = WTConv2d(dim, dim)
        self.norm1 = LayerNorm(dim, 'WithBias')
        self.token_mixer = Pooling(kernel_size=pool_size)  # ViTs use MSA and MLP-Mixers use an MLP here; pooling stands in as the token mixer
        self.norm2 = LayerNorm(dim, 'WithBias')
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.poolmlp = PoolMlp(in_features=dim, hidden_features=mlp_hidden_dim,
                               act_layer=act_layer, drop=drop)

        # The following two techniques are useful to train deep PoolFormers.
        self.drop_path = DropPath(drop_path) if drop_path > 0. \
            else nn.Identity()
        self.use_layer_scale = use_layer_scale

        if use_layer_scale:
            self.layer_scale_1 = nn.Parameter(
                torch.ones(dim, dtype=torch.float32) * layer_scale_init_value)

            self.layer_scale_2 = nn.Parameter(
                torch.ones(dim, dtype=torch.float32) * layer_scale_init_value)

    def forward(self, x):  # 1 64 128 128
        if self.use_layer_scale:
            # self.layer_scale_1 has shape (64,)
            tmp1 = self.layer_scale_1.unsqueeze(-1).unsqueeze(-1)  # 64 1 1
            normal = self.norm1(x)  # 1 64 128 128
            token_mix = self.token_mixer(normal)  # 1 64 128 128

            x = self.WTConv2d(x)

            x = (x +
                 self.drop_path(
                     tmp1 * token_mix
                 )
                 # unsqueeze(-1).unsqueeze(-1) appends two trailing dimensions to the 1-D
                 # layer_scale_1 tensor, turning (C,) into (C, 1, 1) so that it broadcasts in the
                 # element-wise product with the (B, C, H, W) feature map, e.g. when scaling
                 # convolution or attention-branch outputs.
                 )
            x = x + self.drop_path(
                self.layer_scale_2.unsqueeze(-1).unsqueeze(-1)
                * self.poolmlp(self.norm2(x)))
        else:
            x = x + self.drop_path(self.token_mixer(self.norm1(x)))  # matches CDDFuse
            x = x + self.drop_path(self.poolmlp(self.norm2(x)))
        return x


class InvertedResidualBlock(nn.Module):
    def __init__(self, inp, oup, expand_ratio):
        super(InvertedResidualBlock, self).__init__()
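BaseFeatureFusion.forward follows the PoolFormer-style residual update: the learnable per-channel layer-scale vector is reshaped from (C,) to (C, 1, 1) so it can gate the token-mixer and PoolMlp branches by broadcasting over a (B, C, H, W) feature map, with WTConv2d applied to the identity path before the first residual add. Below is a self-contained sketch of just that layer-scale residual step; the class and variable names are illustrative and not part of net.py, and the branch output is a stand-in tensor.

import torch
import torch.nn as nn

class LayerScaleResidual(nn.Module):
    """Illustrative sketch of the layer-scale residual used in BaseFeatureFusion."""
    def __init__(self, dim, init_value=1e-5):
        super().__init__()
        # one learnable scale per channel, initialised small as in the class above
        self.scale = nn.Parameter(torch.ones(dim, dtype=torch.float32) * init_value)

    def forward(self, x, branch_out):
        # (C,) -> (C, 1, 1): broadcasts over the H and W dimensions of (B, C, H, W)
        return x + self.scale.unsqueeze(-1).unsqueeze(-1) * branch_out

x = torch.randn(1, 64, 128, 128)      # shape noted in the forward() comments above
block = LayerScaleResidual(dim=64)
out = block(x, torch.randn_like(x))   # branch_out stands in for the token_mixer/poolmlp output
assert out.shape == x.shape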
@@ -262,6 +314,27 @@ class DetailFeatureExtraction(nn.Module):
            z1, z2 = layer(z1, z2)
        return torch.cat((z1, z2), dim=1)


class DetailFeatureFusioin(nn.Module):
    def __init__(self, num_layers=3):
        super(DetailFeatureFusioin, self).__init__()
        INNmodules = [DetailNode() for _ in range(num_layers)]
        self.net = nn.Sequential(*INNmodules)
        self.enhancement_module = WTConv2d(32, 32)

    def forward(self, x):  # 1 64 128 128
        z1, z2 = x[:, :x.shape[1] // 2], x[:, x.shape[1] // 2:x.shape[1]]  # 1 32 128 128
        # enhance and add a residual connection
        enhanced_z1 = self.enhancement_module(z1)
        enhanced_z2 = self.enhancement_module(z2)
        # residual connection
        z1 = z1 + enhanced_z1
        z2 = z2 + enhanced_z2
        for layer in self.net:
            z1, z2 = layer(z1, z2)
        return torch.cat((z1, z2), dim=1)



# =============================================================================

# =============================================================================
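DetailFeatureFusioin splits the input along the channel axis, passes each half through the shared WTConv2d enhancement module with a residual connection, and then runs the two streams through the stacked DetailNode layers before concatenating them back. A minimal sketch of the split-enhance-residual step, with WTConv2d stubbed out by a depthwise convolution purely for illustration (an assumption, not the module used in net.py):

import torch
import torch.nn as nn

# stand-in for WTConv2d(32, 32); the real module is imported in net.py
enhancement_module = nn.Conv2d(32, 32, kernel_size=3, padding=1, groups=32)

x = torch.randn(1, 64, 128, 128)                          # 1 64 128 128, as in forward()
z1, z2 = x[:, :x.shape[1] // 2], x[:, x.shape[1] // 2:]   # two 32-channel halves
z1 = z1 + enhancement_module(z1)                          # enhance + residual connection
z2 = z2 + enhancement_module(z2)
out = torch.cat((z1, z2), dim=1)                          # back to 64 channels
assert out.shape == x.shape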
train.py (11 lines changed)
@@ -6,7 +6,8 @@ Import packages
------------------------------------------------------------------------------
'''

from net import Restormer_Encoder, Restormer_Decoder, BaseFeatureExtraction, DetailFeatureExtraction
from net import Restormer_Encoder, Restormer_Decoder, BaseFeatureExtraction, DetailFeatureExtraction, BaseFeatureFusion, \
    DetailFeatureFusioin
from utils.dataset import H5Dataset
import os
os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'
@@ -85,8 +86,8 @@ device = 'cuda' if torch.cuda.is_available() else 'cpu'
DIDF_Encoder = nn.DataParallel(Restormer_Encoder()).to(device)
DIDF_Decoder = nn.DataParallel(Restormer_Decoder()).to(device)
# BaseFuseLayer = nn.DataParallel(BaseFeatureExtraction(dim=64, num_heads=8)).to(device)
BaseFuseLayer = nn.DataParallel(BaseFeatureExtraction(dim=64)).to(device)
DetailFuseLayer = nn.DataParallel(DetailFeatureExtraction(num_layers=1)).to(device)
BaseFuseLayer = nn.DataParallel(BaseFeatureFusion(dim=64)).to(device)
DetailFuseLayer = nn.DataParallel(DetailFeatureFusioin(num_layers=1)).to(device)

# optimizer, scheduler and loss function
optimizer1 = torch.optim.Adam(
@@ -109,7 +110,7 @@ Loss_ssim = kornia.losses.SSIM(11, reduction='mean')
HuberLoss = nn.HuberLoss()

# data loader
trainloader = DataLoader(H5Dataset(r"/home/star/whaiDir/CDDFuse/data/MSRS_train_imgsize_128_stride_200.h5"),
trainloader = DataLoader(H5Dataset(r"/home/star/whaiDir/PFCFuse/data/YYX_sar_opr_data_imgsize_256_stride_200.h5"),
                         batch_size=batch_size,
                         shuffle=True,
                         num_workers=0)
@@ -222,7 +223,7 @@ for epoch in range(num_epochs):
        time_left = datetime.timedelta(seconds=batches_left * (time.time() - prev_time))
        epoch_time = time.time() - prev_time
        prev_time = time.time()
        if i % 100 == 0:

            sys.stdout.write(
                "\r[Epoch %d/%d] [Batch %d/%d] [loss: %f] ETA: %.10s"
                % (
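The last train.py hunk gates the progress line behind "if i % 100 == 0", so the ETA string is written only every 100 batches rather than on every iteration. A standalone sketch of that logging pattern follows; the loop bounds and loss value are placeholder dummies, not values from the repo.

import sys
import time
import datetime

num_epochs, steps_per_epoch = 2, 500      # dummy values for illustration
prev_time = time.time()
for epoch in range(num_epochs):
    for i in range(steps_per_epoch):
        loss = 0.123                      # placeholder for the real training loss
        batches_left = num_epochs * steps_per_epoch - (epoch * steps_per_epoch + i)
        time_left = datetime.timedelta(seconds=batches_left * (time.time() - prev_time))
        prev_time = time.time()
        if i % 100 == 0:                  # matches the guard shown in this commit
            sys.stdout.write(
                "\r[Epoch %d/%d] [Batch %d/%d] [loss: %f] ETA: %.10s"
                % (epoch, num_epochs, i, steps_per_epoch, loss, time_left)
            )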