feat(net): 替换 token_mixer 为 SCSA 模块

- 引入新的 SCSA(空间和通道协同注意力)模块
- 用 SCSA 替换原有的 Pooling层作为 token_mixer
- 删除了未使用的 SEBlock.py 文件
- 移除了与当前项目无关的 TIAM(CV).py 文件
This commit is contained in:
zjut 2024-11-16 23:07:51 +08:00
parent 0cf1726eeb
commit 8d99c2c4f8
3 changed files with 160 additions and 111 deletions

156
componets/SCSA.py Normal file
View File

@ -0,0 +1,156 @@
import typing as t
import torch
import torch.nn as nn
from einops.einops import rearrange
from mmengine.model import BaseModule
__all__ = ['SCSA']
"""SCSA：探索空间注意力和通道注意力之间的协同作用。
通道和空间注意力分别在为各种下游视觉任务提取特征依赖性和空间结构关系方面带来了显著的改进。
虽然它们的结合更有利于发挥各自的优势，但通道和空间注意力之间的协同作用尚未得到充分探索，缺乏充分利用多语义信息的协同潜力来进行特征引导和缓解语义差异。
我们的研究试图在多个语义层面揭示空间和通道注意力之间的协同关系，提出了一种新颖的空间和通道协同注意力模块（SCSA）。我们的 SCSA 由两部分组成：可共享的多语义空间注意力（SMSA）和渐进式通道自注意力（PCSA）。
SMSA 集成多语义信息，并利用渐进式压缩策略将判别性空间先验注入 PCSA 的通道自注意力中，有效地指导通道重新校准。此外，PCSA 中基于自注意力机制的稳健特征交互进一步缓解了 SMSA 中不同子特征之间多语义信息的差异。
我们在七个基准数据集上进行了广泛的实验，包括 ImageNet-1K 上的分类、MSCOCO 2017 上的对象检测、ADE20K 上的分割以及其他四个复杂场景检测数据集。我们的结果表明，我们提出的 SCSA 不仅超越了当前最先进的注意力机制，
而且在各种任务场景中表现出增强的泛化能力。
"""
class SCSA(BaseModule):
    """Spatial and Channel Synergistic Attention block.

    The forward pass has two stages:

    1. Spatial gating: the input is averaged along W and along H, each
       pooled map is split into four equal channel groups, every group is
       passed through its own depth-wise ``Conv1d`` (one kernel size per
       group), the groups are concatenated, group-normalised and gated
       (sigmoid or softmax). The two gates are multiplied onto the input.
    2. Channel gating: the spatially gated map is down-sampled, normalised,
       projected to q/k/v by depth-wise 1x1 convolutions and fed to a
       multi-head attention computed across channels (the attention matrix
       is ``head_dim x head_dim``). The result is averaged over space,
       gated and multiplied onto the spatially gated map.

    Args:
        dim: Number of input channels. Must be divisible by 4 and by
            ``head_num``.
        head_num: Number of attention heads used by the channel stage.
        window_size: Pooling window/stride of the down-sampling step.
            ``-1`` selects global average pooling.
        group_kernel_sizes: Four odd kernel sizes, one per channel group.
            Defaults to ``[3, 5, 7, 9]``.
        qkv_bias: Whether the q/k/v projections carry a bias.
        fuse_bn: Stored on the instance; not used by this implementation.
        norm_cfg: Accepted for interface compatibility; not used.
        act_cfg: Accepted for interface compatibility; not used.
        down_sample_mode: ``'avg_pool'``, ``'max_pool'`` or
            ``'recombination'``. Ignored when ``window_size == -1``.
        attn_drop_ratio: Dropout probability applied to the attention map.
        gate_layer: ``'softmax'`` selects softmax gates; any other value
            selects sigmoid gates.

    Raises:
        ValueError: If ``dim`` is not divisible by 4 or by ``head_num``, if
            ``group_kernel_sizes`` does not hold four odd sizes, or if
            ``down_sample_mode`` is unknown while ``window_size != -1``.
    """

    def __init__(
            self,
            dim: int,
            head_num: int,
            window_size: int = 7,
            group_kernel_sizes: t.Optional[t.List[int]] = None,
            qkv_bias: bool = False,
            fuse_bn: bool = False,
            norm_cfg: t.Optional[t.Dict] = None,
            act_cfg: t.Optional[t.Dict] = None,
            down_sample_mode: str = 'avg_pool',
            attn_drop_ratio: float = 0.,
            gate_layer: str = 'sigmoid',
    ):
        super().__init__()

        # Copy so the instance never aliases a caller-owned (or shared
        # default) list.
        if group_kernel_sizes is None:
            group_kernel_sizes = [3, 5, 7, 9]
        else:
            group_kernel_sizes = list(group_kernel_sizes)
        if len(group_kernel_sizes) != 4:
            raise ValueError(
                'group_kernel_sizes must contain exactly 4 kernel sizes, '
                f'got {len(group_kernel_sizes)}.')
        if any(k % 2 == 0 for k in group_kernel_sizes):
            # With padding = k // 2 an even kernel lengthens the sequence by
            # one, which breaks the reshape of the gates in forward().
            raise ValueError(
                f'group_kernel_sizes must all be odd, got {group_kernel_sizes}.')
        if dim % 4 != 0:
            raise ValueError('The dimension of input feature should be divisible by 4.')
        if head_num <= 0 or dim % head_num != 0:
            raise ValueError(
                f'dim ({dim}) must be divisible by head_num ({head_num}).')

        self.dim = dim
        self.head_num = head_num
        self.head_dim = dim // head_num
        self.scaler = self.head_dim ** -0.5
        self.group_kernel_sizes = group_kernel_sizes
        self.window_size = window_size
        self.qkv_bias = qkv_bias
        self.fuse_bn = fuse_bn
        self.down_sample_mode = down_sample_mode

        self.group_chans = group_chans = self.dim // 4

        # One depth-wise 1D convolution per channel group. The attribute
        # names are kept unchanged so existing checkpoints still load.
        self.local_dwc = nn.Conv1d(group_chans, group_chans, kernel_size=group_kernel_sizes[0],
                                   padding=group_kernel_sizes[0] // 2, groups=group_chans)
        self.global_dwc_s = nn.Conv1d(group_chans, group_chans, kernel_size=group_kernel_sizes[1],
                                      padding=group_kernel_sizes[1] // 2, groups=group_chans)
        self.global_dwc_m = nn.Conv1d(group_chans, group_chans, kernel_size=group_kernel_sizes[2],
                                      padding=group_kernel_sizes[2] // 2, groups=group_chans)
        self.global_dwc_l = nn.Conv1d(group_chans, group_chans, kernel_size=group_kernel_sizes[3],
                                      padding=group_kernel_sizes[3] // 2, groups=group_chans)
        self.sa_gate = nn.Softmax(dim=2) if gate_layer == 'softmax' else nn.Sigmoid()
        self.norm_h = nn.GroupNorm(4, dim)
        self.norm_w = nn.GroupNorm(4, dim)

        # Replaced by a 1x1 convolution only in 'recombination' mode.
        self.conv_d = nn.Identity()
        self.norm = nn.GroupNorm(1, dim)
        self.q = nn.Conv2d(in_channels=dim, out_channels=dim, kernel_size=1, bias=qkv_bias, groups=dim)
        self.k = nn.Conv2d(in_channels=dim, out_channels=dim, kernel_size=1, bias=qkv_bias, groups=dim)
        self.v = nn.Conv2d(in_channels=dim, out_channels=dim, kernel_size=1, bias=qkv_bias, groups=dim)
        self.attn_drop = nn.Dropout(attn_drop_ratio)
        self.ca_gate = nn.Softmax(dim=1) if gate_layer == 'softmax' else nn.Sigmoid()

        if window_size == -1:
            self.down_func = nn.AdaptiveAvgPool2d((1, 1))
        elif down_sample_mode == 'recombination':
            self.down_func = self.space_to_chans
            # Dimensionality reduction back to `dim` channels.
            self.conv_d = nn.Conv2d(in_channels=dim * window_size ** 2, out_channels=dim,
                                    kernel_size=1, bias=False)
        elif down_sample_mode == 'avg_pool':
            self.down_func = nn.AvgPool2d(kernel_size=(window_size, window_size), stride=window_size)
        elif down_sample_mode == 'max_pool':
            self.down_func = nn.MaxPool2d(kernel_size=(window_size, window_size), stride=window_size)
        else:
            # Fail here instead of with an AttributeError inside forward().
            raise ValueError(
                f"Unsupported down_sample_mode {down_sample_mode!r}; expected "
                "'recombination', 'avg_pool' or 'max_pool'.")

    def space_to_chans(self, x: torch.Tensor) -> torch.Tensor:
        """Fold each ``window_size x window_size`` patch into the channel axis.

        Maps ``(B, C, H, W)`` to ``(B, C * ws**2, H // ws, W // ws)`` so that
        the 1x1 ``conv_d`` can reduce the result back to ``dim`` channels.

        Raises:
            ValueError: If H or W is not a multiple of ``window_size``.
        """
        b, c, h, w = x.size()
        ws = self.window_size
        if h % ws != 0 or w % ws != 0:
            raise ValueError(
                f'Spatial size ({h}, {w}) must be divisible by window_size ({ws}) '
                "in 'recombination' mode.")
        x = x.reshape(b, c, h // ws, ws, w // ws, ws)
        # Bring the two in-window axes next to the channel axis.
        x = x.permute(0, 1, 3, 5, 2, 4).contiguous()
        return x.reshape(b, c * ws * ws, h // ws, w // ws)

    def _axis_gate(self, pooled: torch.Tensor, norm: nn.Module) -> torch.Tensor:
        """Compute the gate for one spatial axis from a ``(B, C, L)`` tensor."""
        local, small, medium, large = torch.split(pooled, self.group_chans, dim=1)
        mixed = torch.cat((
            self.local_dwc(local),
            self.global_dwc_s(small),
            self.global_dwc_m(medium),
            self.global_dwc_l(large),
        ), dim=1)
        return self.sa_gate(norm(mixed))

    def _to_heads(self, feat: torch.Tensor) -> torch.Tensor:
        """Reshape ``(B, C, H, W)`` to ``(B, head_num, head_dim, H * W)``."""
        return rearrange(feat, 'b (head_num head_dim) h w -> b head_num head_dim (h w)',
                         head_num=int(self.head_num), head_dim=int(self.head_dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply spatial gating followed by channel gating.

        Args:
            x: Input of shape ``(B, C, H, W)``.

        Returns:
            The gated feature map.
        """
        b, c, h_, w_ = x.size()

        # Spatial attention: mean over W gives (B, C, H); mean over H gives
        # (B, C, W). The depth-wise convs are shared by both axes, the norms
        # are not.
        x_h_attn = self._axis_gate(x.mean(dim=3), self.norm_h).view(b, c, h_, 1)
        x_w_attn = self._axis_gate(x.mean(dim=2), self.norm_w).view(b, c, 1, w_)
        x = x * x_h_attn * x_w_attn

        # Channel attention based on self-attention; down-sample first to
        # reduce the amount of computation.
        y = self.conv_d(self.down_func(x))
        _, _, h_, w_ = y.size()

        # Normalise first, then generate q, k and v.
        y = self.norm(y)
        q = self._to_heads(self.q(y))
        k = self._to_heads(self.k(y))
        v = self._to_heads(self.v(y))

        # (B, head_num, head_dim, head_dim): attention across channels.
        attn = q @ k.transpose(-2, -1) * self.scaler
        attn = self.attn_drop(attn.softmax(dim=-1))
        # (B, head_num, head_dim, N)
        attn = attn @ v
        # (B, C, H_, W_)
        attn = rearrange(attn, 'b head_num head_dim (h w) -> b (head_num head_dim) h w',
                         h=int(h_), w=int(w_))
        # (B, C, 1, 1)
        attn = attn.mean((2, 3), keepdim=True)
        attn = self.ca_gate(attn)
        return attn * x
if __name__ == '__main__':
    # Smoke test: build an SCSA block and confirm the output size.
    scsa = SCSA(dim=256, head_num=8)
    sample = torch.rand(1, 256, 32, 32)
    # Run one forward pass through the module.
    result = scsa(sample)
    # Report the sizes of the input and output tensors.
    print("Input size:", sample.size())
    print("Output size:", result.size())

View File

@ -1,110 +0,0 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
"""Elsevier2024
变化检测 (CD) 是地球观测中一种重要的监测方法，尤其适用于土地利用分析、城市管理和灾害损失评估。然而，在星座互联和空天协作时代，感兴趣区域 (ROI) 的变化由于几何透视旋转和时间风格差异而导致许多错误检测。
为了应对这些挑战，我们引入了 CDNeXt，该框架阐明了一种稳健而有效的方法，用于将基于预训练主干的 Siamese 网络与用于遥感图像的创新时空交互注意模块 (TIAM) 相结合。
CDNeXt 可分为四个主要组件：编码器、交互器、解码器和检测器。值得注意的是，由 TIAM 提供支持的交互器从编码器提取的二进制时间特征中查询和重建空间透视依赖关系和时间风格相关性，以扩大 ROI 变化的差异。
最后，检测器集成解码器生成的分层特征，随后生成二进制变化掩码。
"""
class SpatiotemporalAttentionFullNotWeightShared(nn.Module):
    """Two-input interaction attention with separate weights per input.

    Both inputs are projected to ``inter_channels`` channels. A
    channel-by-channel energy and a position-by-position energy are built
    from ``theta(x1)`` and ``phi(x2)``; their softmaxes re-weight the value
    projections ``g1(x1)`` / ``g2(x2)``, and the results are added back to
    the inputs through ``W1`` / ``W2``.

    Args:
        in_channels: Channel count of both inputs.
        inter_channels: Channel count of the projections. ``None`` selects
            ``in_channels // 2``, or 1 when that would be 0.
        dimension: Only ``2`` is accepted.
        sub_sample: Stored on the instance; not used by this implementation.
    """

    def __init__(self, in_channels, inter_channels=None, dimension=2, sub_sample=False):
        super().__init__()

        assert dimension in [2, ]
        self.dimension = dimension
        self.sub_sample = sub_sample
        self.in_channels = in_channels

        if inter_channels is None:
            # Half the input width, but never zero channels.
            inter_channels = in_channels // 2 or 1
        self.inter_channels = inter_channels

        # Sub-modules are created in a fixed order so seeded initialisation
        # stays reproducible.
        self.g1 = self._norm_then_reduce()
        self.g2 = self._norm_then_reduce()
        self.W1 = self._expand_then_norm()
        self.W2 = self._expand_then_norm()
        self.theta = self._norm_then_reduce()
        self.phi = self._norm_then_reduce()

    def _norm_then_reduce(self):
        """BatchNorm followed by a 1x1 conv: in_channels -> inter_channels."""
        return nn.Sequential(
            nn.BatchNorm2d(self.in_channels),
            nn.Conv2d(in_channels=self.in_channels, out_channels=self.inter_channels,
                      kernel_size=1, stride=1, padding=0),
        )

    def _expand_then_norm(self):
        """1x1 conv (inter_channels -> in_channels) followed by BatchNorm."""
        return nn.Sequential(
            nn.Conv2d(in_channels=self.inter_channels, out_channels=self.in_channels,
                      kernel_size=1, stride=1, padding=0),
            nn.BatchNorm2d(self.in_channels),
        )

    def forward(self, x1, x2):
        """Exchange information between the two feature maps.

        :param x1: first input, (b, c, h, w)
        :param x2: second input, (b, c, h, w)
        :return: tuple ``(x1 + W1(y1), x2 + W2(y2))``
        """
        n = x1.size(0)
        c = self.inter_channels

        # Flatten the spatial axes: every projection becomes (B, C', N).
        value_a = self.g1(x1).reshape(n, c, -1)
        value_b = self.g2(x2).reshape(n, c, -1)
        query = self.theta(x1).reshape(n, c, -1)
        key = self.phi(x2).reshape(n, c, -1)

        # (B, C', N) @ (B, N, C') -> (B, C', C'): channel-wise energy.
        channel_energy = torch.matmul(query, key.permute(0, 2, 1))
        # (B, N, C') @ (B, C', N) -> (B, N, N): position-wise energy.
        position_energy = torch.matmul(query.permute(0, 2, 1), key)

        channel_attn = F.softmax(channel_energy, dim=-1)
        channel_attn_t = F.softmax(channel_energy.permute(0, 2, 1), dim=-1)
        position_attn = F.softmax(position_energy, dim=-2)
        position_attn_t = F.softmax(position_energy.permute(0, 2, 1), dim=-2)

        # Re-weight the values over channels first, then over positions.
        y1 = torch.matmul(torch.matmul(channel_attn_t, value_a), position_attn).contiguous()
        y2 = torch.matmul(torch.matmul(channel_attn, value_b), position_attn_t).contiguous()

        # Restore the spatial layout (y1 takes x2's, y2 takes x1's).
        y1 = y1.reshape(n, c, *x2.size()[2:])
        y2 = y2.reshape(n, c, *x1.size()[2:])
        return x1 + self.W1(y1), x2 + self.W2(y2)
if __name__ == '__main__':
    # Smoke test: feed two random feature maps of identical size through
    # the block and print the sizes on both sides.
    channels, batch, rows, cols = 64, 8, 32, 32
    tiam = SpatiotemporalAttentionFullNotWeightShared(in_channels=channels)
    first = torch.rand(batch, channels, rows, cols)
    second = torch.rand(batch, channels, rows, cols)
    out_first, out_second = tiam(first, second)
    print(f"Input1 size: {first.size()}")
    print(f"Input2 size: {second.size()}")
    print(f"Output1 size: {out_first.size()}")
    print(f"Output2 size: {out_second.size()}")

5
net.py
View File

@ -6,6 +6,8 @@ import torch.utils.checkpoint as checkpoint
from timm.models.layers import DropPath, to_2tuple, trunc_normal_
from einops import rearrange
from componets.SCSA import SCSA
def drop_path(x, drop_prob: float = 0., training: bool = False):
if drop_prob == 0. or not training:
@ -164,7 +166,8 @@ class BaseFeatureExtractionSAR(nn.Module):
super().__init__()
self.norm1 = LayerNorm(dim, 'WithBias')
self.token_mixer = Pooling(kernel_size=pool_size) # vits是msaMLPs是mlp这个用pool来替代
self.token_mixer = SCSA(dim=dim,head_num=8)
# self.token_mixer = Pooling(kernel_size=pool_size) # vits是msaMLPs是mlp这个用pool来替代
self.norm2 = LayerNorm(dim, 'WithBias')
mlp_hidden_dim = int(dim * mlp_ratio)
self.poolmlp = PoolMlp(in_features=dim, hidden_features=mlp_hidden_dim,