feat(net): 重构特征融合模块并添加新组件
- 新增 BaseFeatureFusion 和 DetailFeatureFusioin 类,用于特征融合 - 更新 ProjectRootManager 配置,使用本地 Python 3.8 环境 - 修改训练数据集路径 - 优化训练日志输出格式
This commit is contained in:
parent
b6486dbaf4
commit
c1eed72f24
73
net.py
73
net.py
@ -198,6 +198,58 @@ class BaseFeatureExtraction(nn.Module):
|
|||||||
x = x + self.drop_path(self.poolmlp(self.norm2(x)))
|
x = x + self.drop_path(self.poolmlp(self.norm2(x)))
|
||||||
return x
|
return x
|
||||||
|
|
||||||
|
class BaseFeatureFusion(nn.Module):
|
||||||
|
def __init__(self, dim, pool_size=3, mlp_ratio=4.,
|
||||||
|
act_layer=nn.GELU,
|
||||||
|
# norm_layer=nn.LayerNorm,
|
||||||
|
drop=0., drop_path=0.,
|
||||||
|
use_layer_scale=True, layer_scale_init_value=1e-5):
|
||||||
|
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
self.WTConv2d = WTConv2d(dim, dim)
|
||||||
|
self.norm1 = LayerNorm(dim, 'WithBias')
|
||||||
|
self.token_mixer = Pooling(kernel_size=pool_size) # vits是msa,MLPs是mlp,这个用pool来替代
|
||||||
|
self.norm2 = LayerNorm(dim, 'WithBias')
|
||||||
|
mlp_hidden_dim = int(dim * mlp_ratio)
|
||||||
|
self.poolmlp = PoolMlp(in_features=dim, hidden_features=mlp_hidden_dim,
|
||||||
|
act_layer=act_layer, drop=drop)
|
||||||
|
|
||||||
|
# The following two techniques are useful to train deep PoolFormers.
|
||||||
|
self.drop_path = DropPath(drop_path) if drop_path > 0. \
|
||||||
|
else nn.Identity()
|
||||||
|
self.use_layer_scale = use_layer_scale
|
||||||
|
|
||||||
|
if use_layer_scale:
|
||||||
|
self.layer_scale_1 = nn.Parameter(
|
||||||
|
torch.ones(dim, dtype=torch.float32) * layer_scale_init_value)
|
||||||
|
|
||||||
|
self.layer_scale_2 = nn.Parameter(
|
||||||
|
torch.ones(dim, dtype=torch.float32) * layer_scale_init_value)
|
||||||
|
|
||||||
|
def forward(self, x): # 1 64 128 128
|
||||||
|
if self.use_layer_scale:
|
||||||
|
# self.layer_scale_1(64,)
|
||||||
|
tmp1 = self.layer_scale_1.unsqueeze(-1).unsqueeze(-1) # 64 1 1
|
||||||
|
normal = self.norm1(x) # 1 64 128 128
|
||||||
|
token_mix = self.token_mixer(normal) # 1 64 128 128
|
||||||
|
|
||||||
|
x = self.WTConv2d(x)
|
||||||
|
|
||||||
|
x = (x +
|
||||||
|
self.drop_path(
|
||||||
|
tmp1 * token_mix
|
||||||
|
)
|
||||||
|
# 该表达式将 self.layer_scale_1 这个一维张量(或变量)在维度末尾添加两个新的维度,使其从一维变为三维。这通常用于使其能够与三维的特征图进行广播操作,如元素相乘。具体用途可能包括调整卷积层或注意力机制中的权重。
|
||||||
|
)
|
||||||
|
x = x + self.drop_path(
|
||||||
|
self.layer_scale_2.unsqueeze(-1).unsqueeze(-1)
|
||||||
|
* self.poolmlp(self.norm2(x)))
|
||||||
|
else:
|
||||||
|
x = x + self.drop_path(self.token_mixer(self.norm1(x))) # 匹配cddfuse
|
||||||
|
x = x + self.drop_path(self.poolmlp(self.norm2(x)))
|
||||||
|
return x
|
||||||
|
|
||||||
class InvertedResidualBlock(nn.Module):
|
class InvertedResidualBlock(nn.Module):
|
||||||
def __init__(self, inp, oup, expand_ratio):
|
def __init__(self, inp, oup, expand_ratio):
|
||||||
super(InvertedResidualBlock, self).__init__()
|
super(InvertedResidualBlock, self).__init__()
|
||||||
@ -262,6 +314,27 @@ class DetailFeatureExtraction(nn.Module):
|
|||||||
z1, z2 = layer(z1, z2)
|
z1, z2 = layer(z1, z2)
|
||||||
return torch.cat((z1, z2), dim=1)
|
return torch.cat((z1, z2), dim=1)
|
||||||
|
|
||||||
|
class DetailFeatureFusioin(nn.Module):
|
||||||
|
def __init__(self, num_layers=3):
|
||||||
|
super(DetailFeatureFusioin, self).__init__()
|
||||||
|
INNmodules = [DetailNode() for _ in range(num_layers)]
|
||||||
|
self.net = nn.Sequential(*INNmodules)
|
||||||
|
self.enhancement_module = WTConv2d(32, 32)
|
||||||
|
|
||||||
|
def forward(self, x): # 1 64 128 128
|
||||||
|
z1, z2 = x[:, :x.shape[1] // 2], x[:, x.shape[1] // 2:x.shape[1]] # 1 32 128 128
|
||||||
|
# 增强并添加残差连接
|
||||||
|
enhanced_z1 = self.enhancement_module(z1)
|
||||||
|
enhanced_z2 = self.enhancement_module(z2)
|
||||||
|
# 残差连接
|
||||||
|
z1 = z1 + enhanced_z1
|
||||||
|
z2 = z2 + enhanced_z2
|
||||||
|
for layer in self.net:
|
||||||
|
z1, z2 = layer(z1, z2)
|
||||||
|
return torch.cat((z1, z2), dim=1)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
# =============================================================================
|
# =============================================================================
|
||||||
|
|
||||||
# =============================================================================
|
# =============================================================================
|
||||||
|
11
train.py
11
train.py
@ -6,7 +6,8 @@ Import packages
|
|||||||
------------------------------------------------------------------------------
|
------------------------------------------------------------------------------
|
||||||
'''
|
'''
|
||||||
|
|
||||||
from net import Restormer_Encoder, Restormer_Decoder, BaseFeatureExtraction, DetailFeatureExtraction
|
from net import Restormer_Encoder, Restormer_Decoder, BaseFeatureExtraction, DetailFeatureExtraction, BaseFeatureFusion, \
|
||||||
|
DetailFeatureFusioin
|
||||||
from utils.dataset import H5Dataset
|
from utils.dataset import H5Dataset
|
||||||
import os
|
import os
|
||||||
os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'
|
os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'
|
||||||
@ -85,8 +86,8 @@ device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
|||||||
DIDF_Encoder = nn.DataParallel(Restormer_Encoder()).to(device)
|
DIDF_Encoder = nn.DataParallel(Restormer_Encoder()).to(device)
|
||||||
DIDF_Decoder = nn.DataParallel(Restormer_Decoder()).to(device)
|
DIDF_Decoder = nn.DataParallel(Restormer_Decoder()).to(device)
|
||||||
# BaseFuseLayer = nn.DataParallel(BaseFeatureExtraction(dim=64, num_heads=8)).to(device)
|
# BaseFuseLayer = nn.DataParallel(BaseFeatureExtraction(dim=64, num_heads=8)).to(device)
|
||||||
BaseFuseLayer = nn.DataParallel(BaseFeatureExtraction(dim=64)).to(device)
|
BaseFuseLayer = nn.DataParallel(BaseFeatureFusion(dim=64)).to(device)
|
||||||
DetailFuseLayer = nn.DataParallel(DetailFeatureExtraction(num_layers=1)).to(device)
|
DetailFuseLayer = nn.DataParallel(DetailFeatureFusioin(num_layers=1)).to(device)
|
||||||
|
|
||||||
# optimizer, scheduler and loss function
|
# optimizer, scheduler and loss function
|
||||||
optimizer1 = torch.optim.Adam(
|
optimizer1 = torch.optim.Adam(
|
||||||
@ -109,7 +110,7 @@ Loss_ssim = kornia.losses.SSIM(11, reduction='mean')
|
|||||||
HuberLoss = nn.HuberLoss()
|
HuberLoss = nn.HuberLoss()
|
||||||
|
|
||||||
# data loader
|
# data loader
|
||||||
trainloader = DataLoader(H5Dataset(r"/home/star/whaiDir/CDDFuse/data/MSRS_train_imgsize_128_stride_200.h5"),
|
trainloader = DataLoader(H5Dataset(r"/home/star/whaiDir/PFCFuse/data/YYX_sar_opr_data_imgsize_256_stride_200.h5"),
|
||||||
batch_size=batch_size,
|
batch_size=batch_size,
|
||||||
shuffle=True,
|
shuffle=True,
|
||||||
num_workers=0)
|
num_workers=0)
|
||||||
@ -222,7 +223,7 @@ for epoch in range(num_epochs):
|
|||||||
time_left = datetime.timedelta(seconds=batches_left * (time.time() - prev_time))
|
time_left = datetime.timedelta(seconds=batches_left * (time.time() - prev_time))
|
||||||
epoch_time = time.time() - prev_time
|
epoch_time = time.time() - prev_time
|
||||||
prev_time = time.time()
|
prev_time = time.time()
|
||||||
if i % 100 == 0:
|
|
||||||
sys.stdout.write(
|
sys.stdout.write(
|
||||||
"\r[Epoch %d/%d] [Batch %d/%d] [loss: %f] ETA: %.10s"
|
"\r[Epoch %d/%d] [Batch %d/%d] [loss: %f] ETA: %.10s"
|
||||||
% (
|
% (
|
||||||
|
Loading…
Reference in New Issue
Block a user