feat(net): 移除 WTConv2d,添加 DEConv 模块

- 删除了 WTConv2d 相关代码
- 新增了 DEConv 模块,包括多种卷积类型
- 更新了 net.py 中的相关调用,移除了 WTConv2d
This commit is contained in:
whai 2024-11-14 17:15:00 +08:00
parent 30bbfdf86e
commit 555515c2dc
2 changed files with 174 additions and 5 deletions

171
componets/DEConv.py Normal file
View File

@ -0,0 +1,171 @@
# https://github.com/cecret3350/DEA-Net/blob/main/code/model/modules/deconv.py
import math
import torch
from torch import nn
from einops.layers.torch import Rearrange
"""
在深度学习和图像处理领域"vanilla" "difference" 卷积是两种不同的卷积操作它们各自有不同的特性和应用场景DEConv细节增强卷积的设计思想是结合这两种卷积的特性来增强模型的性能尤其是在图像去雾等任务中
Vanilla Convolution标准卷积
"Vanilla" 卷积是最基本的卷积类型通常仅称为卷积它是卷积神经网络CNN中最常用的组件用于提取输入数据如图像的特征
标准卷积通过在输入数据上滑动小的可学习的过滤器或称为核并计算过滤器与数据的局部区域之间的点乘来工作通过这种方式它能够捕获输入数据的局部模式和特征
Difference Convolution差分卷积
差分卷积是一种特殊类型的卷积它专注于捕捉输入数据中的局部差异信息例如边缘或纹理的变化
它通过修改标准卷积核的权重或者通过特殊的操作来实现使得网络更加关注于图像的高频信息即图像中的细节和纹理变化在图像处理任务中如图像去雾图像增强边缘检测等捕获这种高频信息非常重要因为它们往往包含了关于物体边界和结构的关键信息
重参数化技术
重参数化技术是一种参数转换方法它允许模型在不增加额外参数和计算代价的情况下实现更复杂的功能或改善性能在DEConv的上下文中重参数化技术使得将vanilla卷积和difference卷积结合起来的操作可以等价地转换成一个标准的卷积操作
这意味着DEConv可以在不增加额外参数和计算成本的情况下通过巧妙地设计卷积核权重同时利用标准卷积和差分卷积的优势从而增强模型处理图像的能力
具体来说通过重参数化可以将差分卷积的效果整合到一个卷积核中使得这个卷积核既能捕获图像的基本特征通过标准卷积部分也能强调图像的细节和差异信息通过差分卷积部分
这种方法特别适用于那些需要同时考虑全局内容和局部细节信息的任务如图像去雾其中既需要理解图像的整体结构也需要恢复由于雾导致的细节丢失
重参数化技术的关键优势在于它允许模型在维持参数数量和计算复杂度不变的前提下实现更为复杂或更为精细的功能
"""
class Conv2d_cd(nn.Module):
def __init__(self, in_channels, out_channels, kernel_size=3, stride=1,
padding=1, dilation=1, groups=1, bias=False, theta=1.0):
super(Conv2d_cd, self).__init__()
self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding,
dilation=dilation, groups=groups, bias=bias)
self.theta = theta
def get_weight(self):
conv_weight = self.conv.weight
conv_shape = conv_weight.shape
conv_weight = Rearrange('c_in c_out k1 k2 -> c_in c_out (k1 k2)')(conv_weight)
# conv_weight_cd = torch.cuda.FloatTensor(conv_shape[0], conv_shape[1], 3 * 3).fill_(0)
conv_weight_cd = torch.FloatTensor(conv_shape[0], conv_shape[1], 3 * 3).fill_(0)
conv_weight_cd[:, :, :] = conv_weight[:, :, :]
conv_weight_cd[:, :, 4] = conv_weight[:, :, 4] - conv_weight[:, :, :].sum(2)
conv_weight_cd = Rearrange('c_in c_out (k1 k2) -> c_in c_out k1 k2', k1=conv_shape[2], k2=conv_shape[3])(
conv_weight_cd)
return conv_weight_cd, self.conv.bias
class Conv2d_ad(nn.Module):
def __init__(self, in_channels, out_channels, kernel_size=3, stride=1,
padding=1, dilation=1, groups=1, bias=False, theta=1.0):
super(Conv2d_ad, self).__init__()
self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding,
dilation=dilation, groups=groups, bias=bias)
self.theta = theta
def get_weight(self):
conv_weight = self.conv.weight
conv_shape = conv_weight.shape
conv_weight = Rearrange('c_in c_out k1 k2 -> c_in c_out (k1 k2)')(conv_weight)
conv_weight_ad = conv_weight - self.theta * conv_weight[:, :, [3, 0, 1, 6, 4, 2, 7, 8, 5]]
conv_weight_ad = Rearrange('c_in c_out (k1 k2) -> c_in c_out k1 k2', k1=conv_shape[2], k2=conv_shape[3])(
conv_weight_ad)
return conv_weight_ad, self.conv.bias
class Conv2d_rd(nn.Module):
def __init__(self, in_channels, out_channels, kernel_size=3, stride=1,
padding=2, dilation=1, groups=1, bias=False, theta=1.0):
super(Conv2d_rd, self).__init__()
self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding,
dilation=dilation, groups=groups, bias=bias)
self.theta = theta
def forward(self, x):
if math.fabs(self.theta - 0.0) < 1e-8:
out_normal = self.conv(x)
return out_normal
else:
conv_weight = self.conv.weight
conv_shape = conv_weight.shape
if conv_weight.is_cuda:
conv_weight_rd = torch.cuda.FloatTensor(conv_shape[0], conv_shape[1], 5 * 5).fill_(0)
else:
conv_weight_rd = torch.zeros(conv_shape[0], conv_shape[1], 5 * 5)
conv_weight = Rearrange('c_in c_out k1 k2 -> c_in c_out (k1 k2)')(conv_weight)
conv_weight_rd[:, :, [0, 2, 4, 10, 14, 20, 22, 24]] = conv_weight[:, :, 1:]
conv_weight_rd[:, :, [6, 7, 8, 11, 13, 16, 17, 18]] = -conv_weight[:, :, 1:] * self.theta
conv_weight_rd[:, :, 12] = conv_weight[:, :, 0] * (1 - self.theta)
conv_weight_rd = conv_weight_rd.view(conv_shape[0], conv_shape[1], 5, 5)
out_diff = nn.functional.conv2d(input=x, weight=conv_weight_rd, bias=self.conv.bias,
stride=self.conv.stride, padding=self.conv.padding, groups=self.conv.groups)
return out_diff
class Conv2d_hd(nn.Module):
def __init__(self, in_channels, out_channels, kernel_size=3, stride=1,
padding=1, dilation=1, groups=1, bias=False, theta=1.0):
super(Conv2d_hd, self).__init__()
self.conv = nn.Conv1d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding,
dilation=dilation, groups=groups, bias=bias)
def get_weight(self):
conv_weight = self.conv.weight
conv_shape = conv_weight.shape
# conv_weight_hd = torch.cuda.FloatTensor(conv_shape[0], conv_shape[1], 3 * 3).fill_(0)
conv_weight_hd = torch.FloatTensor(conv_shape[0], conv_shape[1], 3 * 3).fill_(0)
conv_weight_hd[:, :, [0, 3, 6]] = conv_weight[:, :, :]
conv_weight_hd[:, :, [2, 5, 8]] = -conv_weight[:, :, :]
conv_weight_hd = Rearrange('c_in c_out (k1 k2) -> c_in c_out k1 k2', k1=conv_shape[2], k2=conv_shape[2])(
conv_weight_hd)
return conv_weight_hd, self.conv.bias
class Conv2d_vd(nn.Module):
def __init__(self, in_channels, out_channels, kernel_size=3, stride=1,
padding=1, dilation=1, groups=1, bias=False):
super(Conv2d_vd, self).__init__()
self.conv = nn.Conv1d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding,
dilation=dilation, groups=groups, bias=bias)
def get_weight(self):
conv_weight = self.conv.weight
conv_shape = conv_weight.shape
# conv_weight_vd = torch.cuda.FloatTensor(conv_shape[0], conv_shape[1], 3 * 3).fill_(0)
conv_weight_vd = torch.FloatTensor(conv_shape[0], conv_shape[1], 3 * 3).fill_(0)
conv_weight_vd[:, :, [0, 1, 2]] = conv_weight[:, :, :]
conv_weight_vd[:, :, [6, 7, 8]] = -conv_weight[:, :, :]
conv_weight_vd = Rearrange('c_in c_out (k1 k2) -> c_in c_out k1 k2', k1=conv_shape[2], k2=conv_shape[2])(
conv_weight_vd)
return conv_weight_vd, self.conv.bias
class DEConv(nn.Module):
def __init__(self, dim):
super(DEConv, self).__init__()
self.conv1_1 = Conv2d_cd(dim, dim, 3, bias=True)
self.conv1_2 = Conv2d_hd(dim, dim, 3, bias=True)
self.conv1_3 = Conv2d_vd(dim, dim, 3, bias=True)
self.conv1_4 = Conv2d_ad(dim, dim, 3, bias=True)
self.conv1_5 = nn.Conv2d(dim, dim, 3, padding=1, bias=True)
def forward(self, x):
w1, b1 = self.conv1_1.get_weight()
w2, b2 = self.conv1_2.get_weight()
w3, b3 = self.conv1_3.get_weight()
w4, b4 = self.conv1_4.get_weight()
w5, b5 = self.conv1_5.weight, self.conv1_5.bias
w = w1 + w2 + w3 + w4 + w5
b = b1 + b2 + b3 + b4 + b5
res = nn.functional.conv2d(input=x, weight=w, bias=b, stride=1, padding=1, groups=1)
return res
if __name__ == '__main__':
# 初始化DEConv模块dim为输入和输出的通道数
block = DEConv(dim=16)
# 创建一个随机输入张量,假设输入尺寸为(batch_size, channels, height, width)
input_tensor = torch.rand(4, 16, 64, 64)
# 将输入传递给DEConv模块
output_tensor = block(input_tensor)
# 打印输入和输出张量的尺寸
print("输入尺寸:", input_tensor.size())
print("输出尺寸:", output_tensor.size())

8
net.py
View File

@ -155,7 +155,7 @@ class BaseFeatureExtraction(nn.Module):
super().__init__() super().__init__()
self.WTConv2d = WTConv2d(dim, dim)
self.norm1 = LayerNorm(dim, 'WithBias') self.norm1 = LayerNorm(dim, 'WithBias')
self.token_mixer = Pooling(kernel_size=pool_size) # vits是msaMLPs是mlp这个用pool来替代 self.token_mixer = Pooling(kernel_size=pool_size) # vits是msaMLPs是mlp这个用pool来替代
self.norm2 = LayerNorm(dim, 'WithBias') self.norm2 = LayerNorm(dim, 'WithBias')
@ -207,7 +207,7 @@ class BaseFeatureFusion(nn.Module):
super().__init__() super().__init__()
self.WTConv2d = WTConv2d(dim, dim)
self.norm1 = LayerNorm(dim, 'WithBias') self.norm1 = LayerNorm(dim, 'WithBias')
# self.token_mixer = Pooling(kernel_size=pool_size) # vits是msaMLPs是mlp这个用pool来替代 # self.token_mixer = Pooling(kernel_size=pool_size) # vits是msaMLPs是mlp这个用pool来替代
self.token_mixer = Pooling(kernel_size=pool_size) # vits是msaMLPs是mlp这个用pool来替代 self.token_mixer = Pooling(kernel_size=pool_size) # vits是msaMLPs是mlp这个用pool来替代
@ -235,7 +235,7 @@ class BaseFeatureFusion(nn.Module):
normal = self.norm1(x) # 1 64 128 128 normal = self.norm1(x) # 1 64 128 128
token_mix = self.token_mixer(normal) # 1 64 128 128 token_mix = self.token_mixer(normal) # 1 64 128 128
x = self.WTConv2d(x)
x = (x + x = (x +
self.drop_path( self.drop_path(
@ -561,8 +561,6 @@ class BaseFeatureExtractionSAR(nn.Module):
normal = self.norm1(x) # 1 64 128 128 normal = self.norm1(x) # 1 64 128 128
token_mix = self.token_mixer(normal) # 1 64 128 128 token_mix = self.token_mixer(normal) # 1 64 128 128
x = self.WTConv2d(x)
x = (x + x = (x +
self.drop_path( self.drop_path(
tmp1 * token_mix tmp1 * token_mix