feat(net): add new component imports to net.py and refine the forward-pass logic

- Import the SMFA component in net.py
- Refine the forward-pass logic of the BasicLayer class
- Add implementations of the SMFA, DynamicFilter, and UFFC components
- Use SMFA in place of Pooling as the token mixer:

      self.WTConv2d = WTConv2d(dim, dim)
      self.norm1 = LayerNorm(dim, 'WithBias')
      self.token_mixer = SMFA(dim=dim)
      # self.token_mixer = Pooling(kernel_size=pool_size)  # ViTs use MSA, MLP models use an MLP; pooling stood in for them here
      self.norm2 = LayerNorm(dim, 'WithBias')
      mlp_hidden_dim = int(dim * mlp_ratio)
      self.poolmlp = PoolMlp(in_features=dim, hidden_features=mlp_hidden_dim,
                             act_layer=act_layer, drop=drop)
parent 85dc7a92ed
commit 41b2ea1ff9

componets/DynamicFilter(频域模块动态滤波器用于CV2维图像).py  (new file, 116 lines)
@@ -0,0 +1,116 @@
import torch
import torch.nn as nn
from timm.layers.helpers import to_2tuple

"""
Models equipped with multi-head self-attention (MHSA) have achieved remarkable performance in computer vision.
Their computational complexity is quadratic in the number of pixels of the input feature map, which makes
processing slow, especially for high-resolution images.
To circumvent this problem, a new type of token mixer has been proposed as a substitute for MHSA: an FFT-based
token mixer performs global operations similar to MHSA, but at lower computational complexity.
Here, we propose a novel token mixer named DynamicFilter to narrow the gap described above.
The DynamicFilter module filters in the frequency domain with dynamically reweighted filter weights, enabling
complex enhancement and processing of images.
"""


class StarReLU(nn.Module):
    """
    StarReLU: s * relu(x) ** 2 + b
    """

    def __init__(self, scale_value=1.0, bias_value=0.0,
                 scale_learnable=True, bias_learnable=True,
                 mode=None, inplace=False):
        super().__init__()
        self.inplace = inplace
        self.relu = nn.ReLU(inplace=inplace)
        self.scale = nn.Parameter(scale_value * torch.ones(1),
                                  requires_grad=scale_learnable)
        self.bias = nn.Parameter(bias_value * torch.ones(1),
                                 requires_grad=bias_learnable)

    def forward(self, x):
        return self.scale * self.relu(x) ** 2 + self.bias


class Mlp(nn.Module):
    """ MLP as used in MetaFormer models, e.g. Transformer, MLP-Mixer, PoolFormer, MetaFormer baselines and related networks.
    Mostly copied from timm.
    """

    def __init__(self, dim, mlp_ratio=4, out_features=None, act_layer=StarReLU, drop=0.,
                 bias=False, **kwargs):
        super().__init__()
        in_features = dim
        out_features = out_features or in_features
        hidden_features = int(mlp_ratio * in_features)
        drop_probs = to_2tuple(drop)

        self.fc1 = nn.Linear(in_features, hidden_features, bias=bias)
        self.act = act_layer()
        self.drop1 = nn.Dropout(drop_probs[0])
        self.fc2 = nn.Linear(hidden_features, out_features, bias=bias)
        self.drop2 = nn.Dropout(drop_probs[1])

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop1(x)
        x = self.fc2(x)
        x = self.drop2(x)
        return x


class DynamicFilter(nn.Module):
    def __init__(self, dim, expansion_ratio=2, reweight_expansion_ratio=.25,
                 act1_layer=StarReLU, act2_layer=nn.Identity,
                 bias=False, num_filters=4, size=14, weight_resize=False,
                 **kwargs):
        super().__init__()
        size = to_2tuple(size)
        self.size = size[0]
        self.filter_size = size[1] // 2 + 1
        self.num_filters = num_filters
        self.dim = dim
        self.med_channels = int(expansion_ratio * dim)
        self.weight_resize = weight_resize
        self.pwconv1 = nn.Linear(dim, self.med_channels, bias=bias)
        self.act1 = act1_layer()
        self.reweight = Mlp(dim, reweight_expansion_ratio, num_filters * self.med_channels)
        self.complex_weights = nn.Parameter(
            torch.randn(self.size, self.filter_size, num_filters, 2,
                        dtype=torch.float32) * 0.02)
        self.act2 = act2_layer()
        self.pwconv2 = nn.Linear(self.med_channels, dim, bias=bias)

    def forward(self, x):
        B, H, W, _ = x.shape

        routeing = self.reweight(x.mean(dim=(1, 2))).view(B, self.num_filters,
                                                          -1).softmax(dim=1)
        x = self.pwconv1(x)
        x = self.act1(x)
        x = x.to(torch.float32)
        x = torch.fft.rfft2(x, dim=(1, 2), norm='ortho')

        if self.weight_resize:
            # NOTE: resize_complex_weight is not defined in this file; see the sketch below.
            complex_weights = resize_complex_weight(self.complex_weights, x.shape[1],
                                                    x.shape[2])
            complex_weights = torch.view_as_complex(complex_weights.contiguous())
        else:
            complex_weights = torch.view_as_complex(self.complex_weights)
        routeing = routeing.to(torch.complex64)
        weight = torch.einsum('bfc,hwf->bhwc', routeing, complex_weights)
        if self.weight_resize:
            weight = weight.view(-1, x.shape[1], x.shape[2], self.med_channels)
        else:
            weight = weight.view(-1, self.size, self.filter_size, self.med_channels)
        x = x * weight
        x = torch.fft.irfft2(x, s=(H, W), dim=(1, 2), norm='ortho')

        x = self.act2(x)
        x = self.pwconv2(x)
        return x


if __name__ == '__main__':
    block = DynamicFilter(32, size=64)  # size == (H, W) of the input feature map
    input = torch.rand(3, 64, 64, 32)   # channels-last: (B, H, W, C)
    output = block(input)
    print(input.size())
    print(output.size())
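Note that `forward` calls `resize_complex_weight` when `weight_resize=True`, but the helper is not included in this file; as shipped, only the default `weight_resize=False` path is runnable. A minimal sketch of the missing helper, assuming the stored `(size, filter_size, num_filters, 2)` weight grid is simply interpolated to the new spectral resolution (a hypothetical reconstruction, not the verbatim upstream helper):

import torch
import torch.nn.functional as F

def resize_complex_weight(origin_weight, new_h, new_w):
    # origin_weight: (h, w, num_filters, 2) -- real/imag parts in the last dim
    h, w, num_filters = origin_weight.shape[0:3]
    # fold (num_filters, 2) into channels so interpolate can resize the h x w grid
    weight = origin_weight.reshape(1, h, w, num_filters * 2).permute(0, 3, 1, 2)
    weight = F.interpolate(weight, size=(new_h, new_w),
                           mode='bicubic', align_corners=True)
    # restore the (new_h, new_w, num_filters, 2) layout expected by view_as_complex
    return weight.permute(0, 2, 3, 1).reshape(new_h, new_w, num_filters, 2)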
componets/SMFA.py  (new file, 65 lines)
@@ -0,0 +1,65 @@
import torch
import torch.nn as nn
import torch.nn.functional as F


"""ECCV 2024 (https://github.com/Zheng-MJ/SMFANet)
Transformer-based restoration methods achieve remarkable results because the self-attention (SA) mechanism can
explore non-local information, enabling better high-resolution image reconstruction. However, the key dot-product
self-attention requires substantial computational resources, which limits its application on low-power devices.
Moreover, the low-pass filtering nature of self-attention limits its ability to capture local detail, leading to
over-smoothed reconstructions. To address these issues, we propose a self-modulation feature aggregation (SMFA)
module that cooperatively exploits local and non-local feature interactions for more accurate reconstruction.
Specifically, the SMFA module employs an efficient self-attention approximation (EASA) branch to model non-local
information and a local detail estimation (LDE) branch to capture local details. We further introduce a
partial-convolution-based feed-forward network (PCFN) to refine the representative features extracted by SMFA.
Extensive experiments show that the proposed SMFANet family achieves a better trade-off between reconstruction
performance and computational efficiency on public benchmark datasets. In particular, compared with the x4
SwinIR-light, SMFANet+ improves average performance by 0.14 dB over five public test sets, runs about 10 times
faster, and has only about 43% of the model complexity (e.g., FLOPs).
"""


class DMlp(nn.Module):
    def __init__(self, dim, growth_rate=2.0):
        super().__init__()
        hidden_dim = int(dim * growth_rate)
        self.conv_0 = nn.Sequential(
            nn.Conv2d(dim, hidden_dim, 3, 1, 1, groups=dim),
            nn.Conv2d(hidden_dim, hidden_dim, 1, 1, 0)
        )
        self.act = nn.GELU()
        self.conv_1 = nn.Conv2d(hidden_dim, dim, 1, 1, 0)

    def forward(self, x):
        x = self.conv_0(x)
        x = self.act(x)
        x = self.conv_1(x)
        return x


class SMFA(nn.Module):
    def __init__(self, dim=36):
        super(SMFA, self).__init__()
        self.linear_0 = nn.Conv2d(dim, dim * 2, 1, 1, 0)
        self.linear_1 = nn.Conv2d(dim, dim, 1, 1, 0)
        self.linear_2 = nn.Conv2d(dim, dim, 1, 1, 0)

        self.lde = DMlp(dim, 2)

        self.dw_conv = nn.Conv2d(dim, dim, 3, 1, 1, groups=dim)

        self.gelu = nn.GELU()
        self.down_scale = 8

        self.alpha = nn.Parameter(torch.ones((1, dim, 1, 1)))
        self.belt = nn.Parameter(torch.zeros((1, dim, 1, 1)))

    def forward(self, f):
        _, _, h, w = f.shape
        y, x = self.linear_0(f).chunk(2, dim=1)
        x_s = self.dw_conv(F.adaptive_max_pool2d(x, (h // self.down_scale, w // self.down_scale)))
        x_v = torch.var(x, dim=(-2, -1), keepdim=True)
        x_l = x * F.interpolate(self.gelu(self.linear_1(x_s * self.alpha + x_v * self.belt)), size=(h, w),
                                mode='nearest')
        y_d = self.lde(y)
        return self.linear_2(x_l + y_d)


if __name__ == '__main__':
    block = SMFA(dim=36)
    input = torch.randn(3, 36, 64, 64)
    output = block(input)
    print(input.size())
    print(output.size())
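To make the two branches concrete, here is a shape trace of a forward pass under the defaults above. It re-uses the module's own layers and assumes the `SMFA` class just defined is in scope; it is illustrative only and mirrors what `forward` computes:

import torch
import torch.nn.functional as F

smfa = SMFA(dim=36)
f = torch.randn(1, 36, 64, 64)
y, x = smfa.linear_0(f).chunk(2, dim=1)               # two (1, 36, 64, 64) streams
x_s = smfa.dw_conv(F.adaptive_max_pool2d(x, (8, 8)))  # EASA: non-local summary at 1/8 resolution
x_v = torch.var(x, dim=(-2, -1), keepdim=True)        # per-channel variance, (1, 36, 1, 1)
gate = F.interpolate(smfa.gelu(smfa.linear_1(x_s * smfa.alpha + x_v * smfa.belt)),
                     size=(64, 64), mode='nearest')   # modulation map back at full resolution
out = smfa.linear_2(x * gate + smfa.lde(y))           # EASA-modulated stream + LDE local detail
print(out.shape)                                      # torch.Size([1, 36, 64, 64])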
componets/UFFC(CV2维任务).pdf  (new binary file, not shown)

componets/UFFC(CV2维任务).py  (new file, 123 lines)
@@ -0,0 +1,123 @@
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

"""ICCV 2023
The recently proposed image-inpainting method LaMa builds its network on Fast Fourier Convolution (FFC), which
was originally proposed for high-level vision tasks such as image classification.
FFC gives a fully convolutional network a global receptive field in its early layers. Thanks to the unique
character of the FFC module, LaMa can produce robust repeating textures, which previous inpainting methods could
not achieve. But is the original FFC module suitable for low-level vision tasks such as image inpainting?
In this paper, we analyze the fundamental flaws of using FFC for image inpainting, namely 1) spectrum shifting,
2) unexpected spatial activation, and 3) a limited frequency receptive field.
These flaws make FFC-based inpainting frameworks struggle to generate complex textures and to perform faithful
reconstruction.
Based on the above analysis, we propose a novel Unbiased Fast Fourier Convolution (UFFC) module, which modifies
the original FFC module with 1) a range transform and its inverse, 2) absolute position embedding, 3) a dynamic
skip connection, and 4) adaptive clipping to overcome these flaws and achieve better inpainting.
Extensive experiments on several benchmark datasets demonstrate the effectiveness of our method, outperforming
state-of-the-art methods in both texture-capturing ability and expressiveness.
"""


class FourierUnit_modified(nn.Module):

    def __init__(self, in_channels, out_channels, groups=1, spatial_scale_factor=None, spatial_scale_mode='bilinear',
                 spectral_pos_encoding=False, use_se=False, ffc3d=False, fft_norm='ortho'):
        # bn_layer not used
        super(FourierUnit_modified, self).__init__()
        self.groups = groups

        self.input_shape = 32  # NOTE: set this manually to match the spatial size of the input
        self.in_channels = in_channels

        self.locMap = nn.Parameter(torch.rand(self.input_shape, self.input_shape // 2 + 1))

        self.lambda_base = nn.Parameter(torch.tensor(0.), requires_grad=True)

        self.conv_layer_down55 = torch.nn.Conv2d(in_channels=in_channels * 2 + 1,  # +1 for locMap
                                                 out_channels=out_channels * 2,
                                                 kernel_size=1, stride=1, padding=0, dilation=1, groups=self.groups,
                                                 bias=False, padding_mode='reflect')
        self.conv_layer_down55_shift = torch.nn.Conv2d(in_channels=in_channels * 2 + 1,  # +1 for locMap
                                                       out_channels=out_channels * 2,
                                                       kernel_size=3, stride=1, padding=2, dilation=2,
                                                       groups=self.groups, bias=False, padding_mode='reflect')

        self.norm = nn.BatchNorm2d(out_channels)

        self.relu = nn.ReLU(inplace=True)

        self.spatial_scale_factor = spatial_scale_factor
        self.spatial_scale_mode = spatial_scale_mode
        self.spectral_pos_encoding = spectral_pos_encoding
        self.ffc3d = ffc3d
        self.fft_norm = fft_norm

        self.img_freq = None
        self.distill = None

    def forward(self, x):
        batch = x.shape[0]

        if self.spatial_scale_factor is not None:
            orig_size = x.shape[-2:]
            x = F.interpolate(x, scale_factor=self.spatial_scale_factor, mode=self.spatial_scale_mode,
                              align_corners=False)

        fft_dim = (-3, -2, -1) if self.ffc3d else (-2, -1)
        ffted = torch.fft.rfftn(x, dim=fft_dim, norm=self.fft_norm)
        ffted = torch.stack((ffted.real, ffted.imag), dim=-1)
        ffted = ffted.permute(0, 1, 4, 2, 3).contiguous()  # (batch, c, 2, h, w/2+1)
        ffted = ffted.view((batch, -1,) + ffted.size()[3:])

        locMap = self.locMap.expand_as(ffted[:, :1, :, :])  # B 1 H' W'
        ffted_copy = ffted.clone()

        cat_img_mask_freq = torch.cat((ffted[:, :self.in_channels, :, :],
                                       ffted[:, self.in_channels:, :, :],
                                       locMap), dim=1)

        ffted = self.conv_layer_down55(cat_img_mask_freq)
        ffted = torch.fft.fftshift(ffted, dim=-2)

        ffted = self.relu(ffted)

        locMap_shift = torch.fft.fftshift(locMap, dim=-2)  # only valid because the spectrum is not shifted back

        # repeat the convolution on the shifted spectrum
        cat_img_mask_freq1 = torch.cat((ffted[:, :self.in_channels, :, :],
                                        ffted[:, self.in_channels:, :, :],
                                        locMap_shift), dim=1)

        ffted = self.conv_layer_down55_shift(cat_img_mask_freq1)
        ffted = torch.fft.fftshift(ffted, dim=-2)

        # dynamic skip connection: blend the untouched spectrum with the convolved one
        lambda_base = torch.sigmoid(self.lambda_base)

        ffted = ffted_copy * lambda_base + ffted * (1 - lambda_base)

        # irfft
        ffted = ffted.view((batch, -1, 2,) + ffted.size()[2:]).permute(
            0, 1, 3, 4, 2).contiguous()  # (batch, c, h, w/2+1, 2)
        ffted = torch.complex(ffted[..., 0], ffted[..., 1])

        ifft_shape_slice = x.shape[-3:] if self.ffc3d else x.shape[-2:]
        output = torch.fft.irfftn(ffted, s=ifft_shape_slice, dim=fft_dim, norm=self.fft_norm)

        if self.spatial_scale_factor is not None:
            output = F.interpolate(output, size=orig_size, mode=self.spatial_scale_mode, align_corners=False)

        # adaptive clipping: re-center the output and clip it to the input's value range
        epsilon = 0.5
        output = output - torch.mean(output) + torch.mean(x)
        output = torch.clip(output, float(x.min() - epsilon), float(x.max() + epsilon))

        self.distill = output  # kept for the self-perceptual loss

        return output


if __name__ == '__main__':
    in_channels = 16
    out_channels = 16

    block = FourierUnit_modified(in_channels=in_channels, out_channels=out_channels)

    input_tensor = torch.rand(8, in_channels, 32, 32)

    output = block(input_tensor)

    print("Input size:", input_tensor.size())
    print("Output size:", output.size())
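The module's core layout trick is packing the complex `rfftn` output as interleaved real/imaginary channels so that plain `Conv2d` layers can mix both parts jointly, then unpacking before the inverse transform. A minimal round-trip sketch of that convention (illustrative only, no convolution applied):

import torch

x = torch.rand(1, 4, 8, 8)
spec = torch.fft.rfftn(x, dim=(-2, -1), norm='ortho')        # (1, 4, 8, 5) complex
packed = torch.stack((spec.real, spec.imag), dim=-1)         # (1, 4, 8, 5, 2)
packed = packed.permute(0, 1, 4, 2, 3).reshape(1, -1, 8, 5)  # (1, 8, 8, 5) real channels
# ... the module's Conv2d mixing would happen here, in packed form ...
unpacked = packed.reshape(1, 4, 2, 8, 5).permute(0, 1, 3, 4, 2)
spec2 = torch.complex(unpacked[..., 0], unpacked[..., 1])
y = torch.fft.irfftn(spec2, s=(8, 8), dim=(-2, -1), norm='ortho')
assert torch.allclose(x, y, atol=1e-6)  # the pack/unpack round trip is lossless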
net.py  (16 lines changed)
@@ -6,6 +6,7 @@ import torch.utils.checkpoint as checkpoint
 from timm.models.layers import DropPath, to_2tuple, trunc_normal_
 from einops import rearrange
 
+from componets.SMFA import SMFA
 from componets.TIAM import SpatiotemporalAttentionFullNotWeightShared
 from componets.WTConvCV2 import WTConv2d
 
@@ -158,7 +159,9 @@ class BaseFeatureExtraction(nn.Module):
 
         self.WTConv2d = WTConv2d(dim, dim)
         self.norm1 = LayerNorm(dim, 'WithBias')
-        self.token_mixer = Pooling(kernel_size=pool_size)  # ViTs use MSA, MLP models use an MLP; pooling stands in for them here
+        self.token_mixer = SMFA(dim=dim)
+
+        # self.token_mixer = Pooling(kernel_size=pool_size)  # ViTs use MSA, MLP models use an MLP; pooling stood in for them here
         self.norm2 = LayerNorm(dim, 'WithBias')
         mlp_hidden_dim = int(dim * mlp_ratio)
         self.poolmlp = PoolMlp(in_features=dim, hidden_features=mlp_hidden_dim,
@@ -179,11 +182,12 @@ class BaseFeatureExtraction(nn.Module):
     def forward(self, x):  # 1 64 128 128
         if self.use_layer_scale:
             # self.layer_scale_1 has shape (64,)
+            wtConvX = self.WTConv2d(x)
+
             tmp1 = self.layer_scale_1.unsqueeze(-1).unsqueeze(-1)  # 64 1 1
             normal = self.norm1(x)  # 1 64 128 128
             token_mix = self.token_mixer(normal)  # 1 64 128 128
 
-            x = self.WTConv2d(x)
-
             x = (x +
                 self.drop_path(
@@ -191,7 +195,7 @@ class BaseFeatureExtraction(nn.Module):
                 )
                 # unsqueeze(-1).unsqueeze(-1) appends two dimensions to the 1-D tensor self.layer_scale_1, turning it 3-D so it can broadcast against the 3-D feature map in element-wise multiplication, e.g. to scale conv or attention outputs
             )
-            x = x + self.drop_path(
+            x = wtConvX + self.drop_path(
                 self.layer_scale_2.unsqueeze(-1).unsqueeze(-1)
                 * self.poolmlp(self.norm2(x)))
         else:
@@ -256,11 +260,13 @@ class DetailFeatureExtraction(nn.Module):
         # enhance, then add residual connections
         enhanced_z1 = self.enhancement_module(z1)
         enhanced_z2 = self.enhancement_module(z2)
+
+        for layer in self.net:
+            z1, z2 = layer(z1, z2)
+
         # residual connections
         z1 = z1 + enhanced_z1
         z2 = z2 + enhanced_z2
-        for layer in self.net:
-            z1, z2 = layer(z1, z2)
         return torch.cat((z1, z2), dim=1)
 
 # =============================================================================
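Taken together, the hunks above make the `use_layer_scale` branch of `BaseFeatureExtraction.forward` behave roughly as follows. This is an illustrative sketch assembled from the diff, not the verbatim file; `LayerNorm`, `PoolMlp`, `DropPath`, and `WTConv2d` are defined in net.py and componets:

# Sketch of the new data flow for x of shape (1, 64, 128, 128):
def forward(self, x):
    wtConvX = self.WTConv2d(x)                           # wavelet-conv branch, computed up front
    x = x + self.drop_path(
        self.layer_scale_1.unsqueeze(-1).unsqueeze(-1)   # (64,) -> (64, 1, 1) for broadcasting
        * self.token_mixer(self.norm1(x)))               # SMFA token mixing replaces Pooling
    x = wtConvX + self.drop_path(                        # second residual now starts from the
        self.layer_scale_2.unsqueeze(-1).unsqueeze(-1)   # WTConv branch instead of x itself
        * self.poolmlp(self.norm2(x)))
    return x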