diff --git a/componets/DEConv.py b/componets/DEConv.py new file mode 100644 index 0000000..2adf41c --- /dev/null +++ b/componets/DEConv.py @@ -0,0 +1,171 @@ +# https://github.com/cecret3350/DEA-Net/blob/main/code/model/modules/deconv.py + +import math +import torch +from torch import nn +from einops.layers.torch import Rearrange + +""" +在深度学习和图像处理领域,"vanilla" 和 "difference" 卷积是两种不同的卷积操作,它们各自有不同的特性和应用场景。DEConv(细节增强卷积)的设计思想是结合这两种卷积的特性来增强模型的性能,尤其是在图像去雾等任务中。 + +Vanilla Convolution(标准卷积) +"Vanilla" 卷积是最基本的卷积类型,通常仅称为卷积。它是卷积神经网络(CNN)中最常用的组件,用于提取输入数据(如图像)的特征。 +标准卷积通过在输入数据上滑动小的、可学习的过滤器(或称为核),并计算过滤器与数据的局部区域之间的点乘,来工作。通过这种方式,它能够捕获输入数据的局部模式和特征。 + +Difference Convolution(差分卷积) +差分卷积是一种特殊类型的卷积,它专注于捕捉输入数据中的局部差异信息,例如边缘或纹理的变化。 +它通过修改标准卷积核的权重或者通过特殊的操作来实现,使得网络更加关注于图像的高频信息,即图像中的细节和纹理变化。在图像处理任务中,如图像去雾、图像增强、边缘检测等,捕获这种高频信息非常重要,因为它们往往包含了关于物体边界和结构的关键信息。 + +重参数化技术 +重参数化技术是一种参数转换方法,它允许模型在不增加额外参数和计算代价的情况下,实现更复杂的功能或改善性能。在DEConv的上下文中,重参数化技术使得将vanilla卷积和difference卷积结合起来的操作,可以等价地转换成一个标准的卷积操作。 +这意味着DEConv可以在不增加额外参数和计算成本的情况下,通过巧妙地设计卷积核权重,同时利用标准卷积和差分卷积的优势,从而增强模型处理图像的能力。 +具体来说,通过重参数化,可以将差分卷积的效果整合到一个卷积核中,使得这个卷积核既能捕获图像的基本特征(通过标准卷积部分),也能强调图像的细节和差异信息(通过差分卷积部分)。 +这种方法特别适用于那些需要同时考虑全局内容和局部细节信息的任务,如图像去雾,其中既需要理解图像的整体结构,也需要恢复由于雾导致的细节丢失。 +重参数化技术的关键优势在于,它允许模型在维持参数数量和计算复杂度不变的前提下,实现更为复杂或更为精细的功能。 +""" + + +class Conv2d_cd(nn.Module): + def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, + padding=1, dilation=1, groups=1, bias=False, theta=1.0): + super(Conv2d_cd, self).__init__() + self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding, + dilation=dilation, groups=groups, bias=bias) + self.theta = theta + + def get_weight(self): + conv_weight = self.conv.weight + conv_shape = conv_weight.shape + conv_weight = Rearrange('c_in c_out k1 k2 -> c_in c_out (k1 k2)')(conv_weight) + # conv_weight_cd = torch.cuda.FloatTensor(conv_shape[0], conv_shape[1], 3 * 3).fill_(0) + conv_weight_cd = torch.FloatTensor(conv_shape[0], conv_shape[1], 3 * 3).fill_(0) + conv_weight_cd[:, :, :] = conv_weight[:, :, :] + conv_weight_cd[:, :, 4] = conv_weight[:, :, 4] - conv_weight[:, :, :].sum(2) + conv_weight_cd = Rearrange('c_in c_out (k1 k2) -> c_in c_out k1 k2', k1=conv_shape[2], k2=conv_shape[3])( + conv_weight_cd) + return conv_weight_cd, self.conv.bias + + +class Conv2d_ad(nn.Module): + def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, + padding=1, dilation=1, groups=1, bias=False, theta=1.0): + super(Conv2d_ad, self).__init__() + self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding, + dilation=dilation, groups=groups, bias=bias) + self.theta = theta + + def get_weight(self): + conv_weight = self.conv.weight + conv_shape = conv_weight.shape + conv_weight = Rearrange('c_in c_out k1 k2 -> c_in c_out (k1 k2)')(conv_weight) + conv_weight_ad = conv_weight - self.theta * conv_weight[:, :, [3, 0, 1, 6, 4, 2, 7, 8, 5]] + conv_weight_ad = Rearrange('c_in c_out (k1 k2) -> c_in c_out k1 k2', k1=conv_shape[2], k2=conv_shape[3])( + conv_weight_ad) + return conv_weight_ad, self.conv.bias + + +class Conv2d_rd(nn.Module): + def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, + padding=2, dilation=1, groups=1, bias=False, theta=1.0): + + super(Conv2d_rd, self).__init__() + self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding, + dilation=dilation, groups=groups, bias=bias) + self.theta = theta + + def forward(self, x): + + if math.fabs(self.theta - 0.0) < 1e-8: + out_normal = self.conv(x) + return out_normal + else: + conv_weight = self.conv.weight + conv_shape = conv_weight.shape + if conv_weight.is_cuda: + conv_weight_rd = torch.cuda.FloatTensor(conv_shape[0], conv_shape[1], 5 * 5).fill_(0) + else: + conv_weight_rd = torch.zeros(conv_shape[0], conv_shape[1], 5 * 5) + conv_weight = Rearrange('c_in c_out k1 k2 -> c_in c_out (k1 k2)')(conv_weight) + conv_weight_rd[:, :, [0, 2, 4, 10, 14, 20, 22, 24]] = conv_weight[:, :, 1:] + conv_weight_rd[:, :, [6, 7, 8, 11, 13, 16, 17, 18]] = -conv_weight[:, :, 1:] * self.theta + conv_weight_rd[:, :, 12] = conv_weight[:, :, 0] * (1 - self.theta) + conv_weight_rd = conv_weight_rd.view(conv_shape[0], conv_shape[1], 5, 5) + out_diff = nn.functional.conv2d(input=x, weight=conv_weight_rd, bias=self.conv.bias, + stride=self.conv.stride, padding=self.conv.padding, groups=self.conv.groups) + + return out_diff + + +class Conv2d_hd(nn.Module): + def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, + padding=1, dilation=1, groups=1, bias=False, theta=1.0): + super(Conv2d_hd, self).__init__() + self.conv = nn.Conv1d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding, + dilation=dilation, groups=groups, bias=bias) + + def get_weight(self): + conv_weight = self.conv.weight + conv_shape = conv_weight.shape + # conv_weight_hd = torch.cuda.FloatTensor(conv_shape[0], conv_shape[1], 3 * 3).fill_(0) + conv_weight_hd = torch.FloatTensor(conv_shape[0], conv_shape[1], 3 * 3).fill_(0) + conv_weight_hd[:, :, [0, 3, 6]] = conv_weight[:, :, :] + conv_weight_hd[:, :, [2, 5, 8]] = -conv_weight[:, :, :] + conv_weight_hd = Rearrange('c_in c_out (k1 k2) -> c_in c_out k1 k2', k1=conv_shape[2], k2=conv_shape[2])( + conv_weight_hd) + return conv_weight_hd, self.conv.bias + + +class Conv2d_vd(nn.Module): + def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, + padding=1, dilation=1, groups=1, bias=False): + super(Conv2d_vd, self).__init__() + self.conv = nn.Conv1d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding, + dilation=dilation, groups=groups, bias=bias) + + def get_weight(self): + conv_weight = self.conv.weight + conv_shape = conv_weight.shape + # conv_weight_vd = torch.cuda.FloatTensor(conv_shape[0], conv_shape[1], 3 * 3).fill_(0) + conv_weight_vd = torch.FloatTensor(conv_shape[0], conv_shape[1], 3 * 3).fill_(0) + conv_weight_vd[:, :, [0, 1, 2]] = conv_weight[:, :, :] + conv_weight_vd[:, :, [6, 7, 8]] = -conv_weight[:, :, :] + conv_weight_vd = Rearrange('c_in c_out (k1 k2) -> c_in c_out k1 k2', k1=conv_shape[2], k2=conv_shape[2])( + conv_weight_vd) + return conv_weight_vd, self.conv.bias + + +class DEConv(nn.Module): + def __init__(self, dim): + super(DEConv, self).__init__() + self.conv1_1 = Conv2d_cd(dim, dim, 3, bias=True) + self.conv1_2 = Conv2d_hd(dim, dim, 3, bias=True) + self.conv1_3 = Conv2d_vd(dim, dim, 3, bias=True) + self.conv1_4 = Conv2d_ad(dim, dim, 3, bias=True) + self.conv1_5 = nn.Conv2d(dim, dim, 3, padding=1, bias=True) + + def forward(self, x): + w1, b1 = self.conv1_1.get_weight() + w2, b2 = self.conv1_2.get_weight() + w3, b3 = self.conv1_3.get_weight() + w4, b4 = self.conv1_4.get_weight() + w5, b5 = self.conv1_5.weight, self.conv1_5.bias + + w = w1 + w2 + w3 + w4 + w5 + b = b1 + b2 + b3 + b4 + b5 + res = nn.functional.conv2d(input=x, weight=w, bias=b, stride=1, padding=1, groups=1) + + return res + + +if __name__ == '__main__': + # 初始化DEConv模块,dim为输入和输出的通道数 + block = DEConv(dim=16) + # 创建一个随机输入张量,假设输入尺寸为(batch_size, channels, height, width) + input_tensor = torch.rand(4, 16, 64, 64) + # 将输入传递给DEConv模块 + output_tensor = block(input_tensor) + # 打印输入和输出张量的尺寸 + print("输入尺寸:", input_tensor.size()) + print("输出尺寸:", output_tensor.size()) + + diff --git a/net.py b/net.py index 1f45825..4841dad 100644 --- a/net.py +++ b/net.py @@ -155,7 +155,7 @@ class BaseFeatureExtraction(nn.Module): super().__init__() - self.WTConv2d = WTConv2d(dim, dim) + self.norm1 = LayerNorm(dim, 'WithBias') self.token_mixer = Pooling(kernel_size=pool_size) # vits是msa,MLPs是mlp,这个用pool来替代 self.norm2 = LayerNorm(dim, 'WithBias') @@ -207,7 +207,7 @@ class BaseFeatureFusion(nn.Module): super().__init__() - self.WTConv2d = WTConv2d(dim, dim) + self.norm1 = LayerNorm(dim, 'WithBias') # self.token_mixer = Pooling(kernel_size=pool_size) # vits是msa,MLPs是mlp,这个用pool来替代 self.token_mixer = Pooling(kernel_size=pool_size) # vits是msa,MLPs是mlp,这个用pool来替代 @@ -235,7 +235,7 @@ class BaseFeatureFusion(nn.Module): normal = self.norm1(x) # 1 64 128 128 token_mix = self.token_mixer(normal) # 1 64 128 128 - x = self.WTConv2d(x) + x = (x + self.drop_path( @@ -561,8 +561,6 @@ class BaseFeatureExtractionSAR(nn.Module): normal = self.norm1(x) # 1 64 128 128 token_mix = self.token_mixer(normal) # 1 64 128 128 - x = self.WTConv2d(x) - x = (x + self.drop_path( tmp1 * token_mix