diff --git a/componets/DynamicFilter(频域模块动态滤波器用于CV2维图像).py b/componets/DynamicFilter(频域模块动态滤波器用于CV2维图像).py new file mode 100644 index 0000000..da32281 --- /dev/null +++ b/componets/DynamicFilter(频域模块动态滤波器用于CV2维图像).py @@ -0,0 +1,116 @@ +import torch +import torch.nn as nn +from timm.layers.helpers import to_2tuple + +""" +配备多头自注意力 (MHSA) 的模型在计算机视觉方面取得了显着的性能。它们的计算复杂度与输入特征图中的二次像素数成正比,导致处理速度缓慢,尤其是在处理高分辨率图像时。 +为了规避这个问题,提出了一种新型的代币混合器作为MHSA的替代方案:基于FFT的代币混合器涉及类似于MHSA的全局操作,但计算复杂度较低。 +在这里,我们提出了一种名为动态过滤器的新型令牌混合器以缩小上述差距。 +DynamicFilter 模块通过频域滤波和动态调整滤波器权重,能够对图像进行复杂的增强和处理。 +""" + +class StarReLU(nn.Module): + """ + StarReLU: s * relu(x) ** 2 + b + """ + + def __init__(self, scale_value=1.0, bias_value=0.0, + scale_learnable=True, bias_learnable=True, + mode=None, inplace=False): + super().__init__() + self.inplace = inplace + self.relu = nn.ReLU(inplace=inplace) + self.scale = nn.Parameter(scale_value * torch.ones(1), + requires_grad=scale_learnable) + self.bias = nn.Parameter(bias_value * torch.ones(1), + requires_grad=bias_learnable) + + def forward(self, x): + return self.scale * self.relu(x) ** 2 + self.bias + +class Mlp(nn.Module): + """ MLP as used in MetaFormer models, eg Transformer, MLP-Mixer, PoolFormer, MetaFormer baslines and related networks. + Mostly copied from timm. + """ + + def __init__(self, dim, mlp_ratio=4, out_features=None, act_layer=StarReLU, drop=0., + bias=False, **kwargs): + super().__init__() + in_features = dim + out_features = out_features or in_features + hidden_features = int(mlp_ratio * in_features) + drop_probs = to_2tuple(drop) + + self.fc1 = nn.Linear(in_features, hidden_features, bias=bias) + self.act = act_layer() + self.drop1 = nn.Dropout(drop_probs[0]) + self.fc2 = nn.Linear(hidden_features, out_features, bias=bias) + self.drop2 = nn.Dropout(drop_probs[1]) + + def forward(self, x): + x = self.fc1(x) + x = self.act(x) + x = self.drop1(x) + x = self.fc2(x) + x = self.drop2(x) + return x + + +class DynamicFilter(nn.Module): + def __init__(self, dim, expansion_ratio=2, reweight_expansion_ratio=.25, + act1_layer=StarReLU, act2_layer=nn.Identity, + bias=False, num_filters=4, size=14, weight_resize=False, + **kwargs): + super().__init__() + size = to_2tuple(size) + self.size = size[0] + self.filter_size = size[1] // 2 + 1 + self.num_filters = num_filters + self.dim = dim + self.med_channels = int(expansion_ratio * dim) + self.weight_resize = weight_resize + self.pwconv1 = nn.Linear(dim, self.med_channels, bias=bias) + self.act1 = act1_layer() + self.reweight = Mlp(dim, reweight_expansion_ratio, num_filters * self.med_channels) + self.complex_weights = nn.Parameter( + torch.randn(self.size, self.filter_size, num_filters, 2, + dtype=torch.float32) * 0.02) + self.act2 = act2_layer() + self.pwconv2 = nn.Linear(self.med_channels, dim, bias=bias) + + def forward(self, x): + B, H, W, _ = x.shape + + routeing = self.reweight(x.mean(dim=(1, 2))).view(B, self.num_filters, + -1).softmax(dim=1) + x = self.pwconv1(x) + x = self.act1(x) + x = x.to(torch.float32) + x = torch.fft.rfft2(x, dim=(1, 2), norm='ortho') + + if self.weight_resize: + complex_weights = resize_complex_weight(self.complex_weights, x.shape[1], + x.shape[2]) + complex_weights = torch.view_as_complex(complex_weights.contiguous()) + else: + complex_weights = torch.view_as_complex(self.complex_weights) + routeing = routeing.to(torch.complex64) + weight = torch.einsum('bfc,hwf->bhwc', routeing, complex_weights) + if self.weight_resize: + weight = weight.view(-1, x.shape[1], x.shape[2], self.med_channels) + else: + weight = weight.view(-1, self.size, self.filter_size, self.med_channels) + x = x * weight + x = torch.fft.irfft2(x, s=(H, W), dim=(1, 2), norm='ortho') + + x = self.act2(x) + x = self.pwconv2(x) + return x + + +if __name__ == '__main__': + block = DynamicFilter(32, size=64) # size==H,W + input = torch.rand(3, 64, 64, 32) + output = block(input) + print(input.size()) + print(output.size()) \ No newline at end of file diff --git a/componets/SMFA.py b/componets/SMFA.py new file mode 100644 index 0000000..e86d364 --- /dev/null +++ b/componets/SMFA.py @@ -0,0 +1,65 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + + +"""ECCV2024(https://github.com/Zheng-MJ/SMFANet) +基于Transformer的恢复方法取得了显著的效果,因为Transformer的自注意力机制(SA)可以探索非局部信息,从而实现更好的高分辨率图像重建。然而,关键的点积自注意力需要大量的计算资源,这限制了其在低功耗设备上的应用。 +此外,自注意力机制的低通滤波特性限制了其捕捉局部细节的能力,从而导致重建结果过于平滑。为了解决这些问题,我们提出了一种自调制特征聚合(SMFA)模块,协同利用局部和非局部特征交互,以实现更精确的重建。 +具体而言,SMFA模块采用了高效的自注意力近似(EASA)分支来建模非局部信息,并使用局部细节估计(LDE)分支来捕捉局部细节。此外,我们还引入了基于部分卷积的前馈网络(PCFN),以进一步优化从SMFA提取的代表性特征。 +大量实验表明,所提出的SMFANet系列在公共基准数据集上实现了更好的重建性能与计算效率的平衡。 +特别是,与SwinIR-light的×4放大相比,SMFANet+在五个公共测试集上的平均性能提高了0.14dB,运行速度提升了约10倍,且模型复杂度(如FLOPs)仅为其约43%。 +""" + +class DMlp(nn.Module): + def __init__(self, dim, growth_rate=2.0): + super().__init__() + hidden_dim = int(dim * growth_rate) + self.conv_0 = nn.Sequential( + nn.Conv2d(dim, hidden_dim, 3, 1, 1, groups=dim), + nn.Conv2d(hidden_dim, hidden_dim, 1, 1, 0) + ) + self.act = nn.GELU() + self.conv_1 = nn.Conv2d(hidden_dim, dim, 1, 1, 0) + + def forward(self, x): + x = self.conv_0(x) + x = self.act(x) + x = self.conv_1(x) + return x + + +class SMFA(nn.Module): + def __init__(self, dim=36): + super(SMFA, self).__init__() + self.linear_0 = nn.Conv2d(dim, dim * 2, 1, 1, 0) + self.linear_1 = nn.Conv2d(dim, dim, 1, 1, 0) + self.linear_2 = nn.Conv2d(dim, dim, 1, 1, 0) + + self.lde = DMlp(dim, 2) + + self.dw_conv = nn.Conv2d(dim, dim, 3, 1, 1, groups=dim) + + self.gelu = nn.GELU() + self.down_scale = 8 + + self.alpha = nn.Parameter(torch.ones((1, dim, 1, 1))) + self.belt = nn.Parameter(torch.zeros((1, dim, 1, 1))) + + def forward(self, f): + _, _, h, w = f.shape + y, x = self.linear_0(f).chunk(2, dim=1) + x_s = self.dw_conv(F.adaptive_max_pool2d(x, (h // self.down_scale, w // self.down_scale))) + x_v = torch.var(x, dim=(-2, -1), keepdim=True) + x_l = x * F.interpolate(self.gelu(self.linear_1(x_s * self.alpha + x_v * self.belt)), size=(h, w), + mode='nearest') + y_d = self.lde(y) + return self.linear_2(x_l + y_d) + + +if __name__ == '__main__': + block = SMFA(dim=36) + input = torch.randn(3, 36, 64, 64) + output = block(input) + print(input.size()) + print(output.size()) \ No newline at end of file diff --git a/componets/UFFC(CV2维任务).pdf b/componets/UFFC(CV2维任务).pdf new file mode 100644 index 0000000..b06eb51 Binary files /dev/null and b/componets/UFFC(CV2维任务).pdf differ diff --git a/componets/UFFC(CV2维任务).py b/componets/UFFC(CV2维任务).py new file mode 100644 index 0000000..6e03611 --- /dev/null +++ b/componets/UFFC(CV2维任务).py @@ -0,0 +1,123 @@ +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F + +"""ICCV2023 +最近提出的图像修复方法 LaMa 以快速傅里叶卷积 (FFC) 为基础构建了其网络,该网络最初是为图像分类等高级视觉任务而提出的。 +FFC 使全卷积网络在其早期层中拥有全局感受野。得益于 FFC 模块的独特特性,LaMa 能够生成稳健的重复纹理, +这是以前的修复方法无法实现的。但是,原始 FFC 模块是否适合图像修复等低级视觉任务? +在本文中,我们分析了在图像修复中使用 FFC 的基本缺陷,即 1) 频谱偏移、2) 意外的空间激活和 3) 频率感受野有限。 +这些缺陷使得基于 FFC 的修复框架难以生成复杂纹理并执行完美重建。 +基于以上分析,我们提出了一种新颖的无偏快速傅里叶卷积 (UFFC) 模块,该模块通过 + 1) 范围变换和逆变换、2) 绝对位置嵌入、3) 动态跳过连接和 4) 自适应剪辑对原始 FFC 模块进行了修改,以克服这些缺陷, +实现更好的修复效果。在多个基准数据集上进行的大量实验证明了我们方法的有效性,在纹理捕捉能力和表现力方面均优于最先进的方法。 +""" + +class FourierUnit_modified(nn.Module): + + def __init__(self, in_channels, out_channels, groups=1, spatial_scale_factor=None, spatial_scale_mode='bilinear', + spectral_pos_encoding=False, use_se=False, ffc3d=False, fft_norm='ortho'): + # bn_layer not used + super(FourierUnit_modified, self).__init__() + self.groups = groups + + self.input_shape = 32 # change!!!!!it!!!!!!manually!!!!!! + self.in_channels = in_channels + + self.locMap = nn.Parameter(torch.rand(self.input_shape, self.input_shape // 2 + 1)) + + self.lambda_base = nn.Parameter(torch.tensor(0.), requires_grad=True) + + self.conv_layer_down55 = torch.nn.Conv2d(in_channels=in_channels * 2 + 1, # +1 for locmap + out_channels=out_channels * 2, + kernel_size=1, stride=1, padding=0, dilation=1, groups=self.groups, + bias=False, padding_mode='reflect') + self.conv_layer_down55_shift = torch.nn.Conv2d(in_channels=in_channels * 2 + 1, # +1 for locmap + out_channels=out_channels * 2, + kernel_size=3, stride=1, padding=2, dilation=2, + groups=self.groups, bias=False, padding_mode='reflect') + + self.norm = nn.BatchNorm2d(out_channels) + + self.relu = nn.ReLU(inplace=True) + + self.spatial_scale_factor = spatial_scale_factor + self.spatial_scale_mode = spatial_scale_mode + self.spectral_pos_encoding = spectral_pos_encoding + self.ffc3d = ffc3d + self.fft_norm = fft_norm + + self.img_freq = None + self.distill = None + + def forward(self, x): + batch = x.shape[0] + + if self.spatial_scale_factor is not None: + orig_size = x.shape[-2:] + x = F.interpolate(x, scale_factor=self.spatial_scale_factor, mode=self.spatial_scale_mode, + align_corners=False) + + fft_dim = (-3, -2, -1) if self.ffc3d else (-2, -1) + ffted = torch.fft.rfftn(x, dim=fft_dim, norm=self.fft_norm) + ffted = torch.stack((ffted.real, ffted.imag), dim=-1) + ffted = ffted.permute(0, 1, 4, 2, 3).contiguous() # (batch, c, 2, h, w/2+1) + ffted = ffted.view((batch, -1,) + ffted.size()[3:]) + + locMap = self.locMap.expand_as(ffted[:, :1, :, :]) # B 1 H' W' + ffted_copy = ffted.clone() + + cat_img_mask_freq = torch.cat((ffted[:, :self.in_channels, :, :], + ffted[:, self.in_channels:, :, :], + locMap), dim=1) + + ffted = self.conv_layer_down55(cat_img_mask_freq) + ffted = torch.fft.fftshift(ffted, dim=-2) + + ffted = self.relu(ffted) + + locMap_shift = torch.fft.fftshift(locMap, dim=-2) ## ONLY IF NOT SHIFT BACK + + # REPEAT CONV + cat_img_mask_freq1 = torch.cat((ffted[:, :self.in_channels, :, :], + ffted[:, self.in_channels:, :, :], + locMap_shift), dim=1) + + ffted = self.conv_layer_down55_shift(cat_img_mask_freq1) + ffted = torch.fft.fftshift(ffted, dim=-2) + + lambda_base = torch.sigmoid(self.lambda_base) + + ffted = ffted_copy * lambda_base + ffted * (1 - lambda_base) + + # irfft + ffted = ffted.view((batch, -1, 2,) + ffted.size()[2:]).permute( + 0, 1, 3, 4, 2).contiguous() # (batch,c, t, h, w/2+1, 2) + ffted = torch.complex(ffted[..., 0], ffted[..., 1]) + + ifft_shape_slice = x.shape[-3:] if self.ffc3d else x.shape[-2:] + output = torch.fft.irfftn(ffted, s=ifft_shape_slice, dim=fft_dim, norm=self.fft_norm) + + if self.spatial_scale_factor is not None: + output = F.interpolate(output, size=orig_size, mode=self.spatial_scale_mode, align_corners=False) + + epsilon = 0.5 + output = output - torch.mean(output) + torch.mean(x) + output = torch.clip(output, float(x.min() - epsilon), float(x.max() + epsilon)) + + self.distill = output # for self perc + return output + +if __name__ == '__main__': + in_channels = 16 + out_channels = 16 + + block = FourierUnit_modified(in_channels=in_channels, out_channels=out_channels) + + input_tensor = torch.rand(8, in_channels, 32, 32) + + output = block(input_tensor) + + print("Input size:", input_tensor.size()) + print("Output size:", output.size()) \ No newline at end of file diff --git a/net.py b/net.py index 9e623f1..c47aa5b 100644 --- a/net.py +++ b/net.py @@ -6,6 +6,7 @@ import torch.utils.checkpoint as checkpoint from timm.models.layers import DropPath, to_2tuple, trunc_normal_ from einops import rearrange +from componets.SMFA import SMFA from componets.TIAM import SpatiotemporalAttentionFullNotWeightShared from componets.WTConvCV2 import WTConv2d @@ -158,7 +159,9 @@ class BaseFeatureExtraction(nn.Module): self.WTConv2d = WTConv2d(dim, dim) self.norm1 = LayerNorm(dim, 'WithBias') - self.token_mixer = Pooling(kernel_size=pool_size) # vits是msa,MLPs是mlp,这个用pool来替代 + self.token_mixer = SMFA(dim=dim) + + # self.token_mixer = Pooling(kernel_size=pool_size) # vits是msa,MLPs是mlp,这个用pool来替代 self.norm2 = LayerNorm(dim, 'WithBias') mlp_hidden_dim = int(dim * mlp_ratio) self.poolmlp = PoolMlp(in_features=dim, hidden_features=mlp_hidden_dim, @@ -179,11 +182,12 @@ class BaseFeatureExtraction(nn.Module): def forward(self, x): # 1 64 128 128 if self.use_layer_scale: # self.layer_scale_1(64,) + wtConvX = self.WTConv2d(x) + tmp1 = self.layer_scale_1.unsqueeze(-1).unsqueeze(-1) # 64 1 1 normal = self.norm1(x) # 1 64 128 128 token_mix = self.token_mixer(normal) # 1 64 128 128 - x = self.WTConv2d(x) x = (x + self.drop_path( @@ -191,7 +195,7 @@ class BaseFeatureExtraction(nn.Module): ) # 该表达式将 self.layer_scale_1 这个一维张量(或变量)在维度末尾添加两个新的维度,使其从一维变为三维。这通常用于使其能够与三维的特征图进行广播操作,如元素相乘。具体用途可能包括调整卷积层或注意力机制中的权重。 ) - x = x + self.drop_path( + x = wtConvX + self.drop_path( self.layer_scale_2.unsqueeze(-1).unsqueeze(-1) * self.poolmlp(self.norm2(x))) else: @@ -256,11 +260,13 @@ class DetailFeatureExtraction(nn.Module): # 增强并添加残差连接 enhanced_z1 = self.enhancement_module(z1) enhanced_z2 = self.enhancement_module(z2) + + for layer in self.net: + z1, z2 = layer(z1, z2) + # 残差连接 z1 = z1 + enhanced_z1 z2 = z2 + enhanced_z2 - for layer in self.net: - z1, z2 = layer(z1, z2) return torch.cat((z1, z2), dim=1) # =============================================================================