2024-06-03 19:36:29 +08:00
|
|
|
|
import torch
|
|
|
|
|
import torch.nn as nn
|
|
|
|
|
import math
|
|
|
|
|
import torch.nn.functional as F
|
|
|
|
|
import torch.utils.checkpoint as checkpoint
|
|
|
|
|
from timm.models.layers import DropPath, to_2tuple, trunc_normal_
|
|
|
|
|
from einops import rearrange
|
|
|
|
|
|
2024-11-16 23:07:51 +08:00
|
|
|
|
from componets.SCSA import SCSA
|
2024-11-17 15:57:41 +08:00
|
|
|
|
from componets.WTConvCV2 import WTConv2d
|
2024-11-16 23:07:51 +08:00
|
|
|
|
|
2024-06-03 19:36:29 +08:00
|
|
|
|
|
|
|
|
|
def drop_path(x, drop_prob: float = 0., training: bool = False):
    """Randomly zero whole samples of ``x`` (stochastic depth).

    Each slice along dimension 0 is kept with probability ``1 - drop_prob``
    and, when kept, rescaled by ``1 / (1 - drop_prob)`` so the expected value
    is unchanged. When not training, or when ``drop_prob`` is zero, ``x`` is
    returned unchanged (the same object).

    Args:
        x: Input tensor; dimension 0 is treated as the per-sample dimension.
        drop_prob: Probability of dropping a sample.
        training: Whether the caller is in training mode.

    Returns:
        Tensor with the same shape as ``x``.
    """
    if not training or drop_prob == 0.:
        return x
    keep_prob = 1 - drop_prob
    # One random draw per sample, broadcast over every remaining dimension,
    # so this works for tensors of any rank (not only 4-D conv maps).
    mask_shape = (x.shape[0],) + (1,) * (x.ndim - 1)
    noise = torch.rand(mask_shape, dtype=x.dtype, device=x.device)
    # floor(keep_prob + U[0, 1)) is 1 with probability keep_prob, else 0.
    mask = (keep_prob + noise).floor_()
    return x.div(keep_prob) * mask
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class DropPath(nn.Module):
    """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).

    Thin module wrapper around ``drop_path``; the module's ``training`` flag
    decides whether any dropping happens. Note: this definition shadows the
    ``DropPath`` imported from ``timm.models.layers`` at the top of the file.

    Args:
        drop_prob: Probability of dropping each sample's residual branch.
            ``None`` (the default) or ``0`` disables dropping.
    """

    def __init__(self, drop_prob=None):
        super(DropPath, self).__init__()
        self.drop_prob = drop_prob

    def forward(self, x):
        # With the default ``drop_prob=None`` the helper would evaluate
        # ``1 - None`` in training mode and raise TypeError; a falsy
        # probability (or eval mode) simply means "pass x through".
        if not self.drop_prob or not self.training:
            return x
        return drop_path(x, self.drop_prob, self.training)
|
2024-10-06 16:42:18 +08:00
|
|
|
|
|
2024-06-03 19:36:29 +08:00
|
|
|
|
class Pooling(nn.Module):
    """PoolFormer-style token mixer: local average minus the input itself.

    Applies a stride-1 average pool of size ``kernel_size`` with padding
    ``kernel_size // 2`` and subtracts ``x``. With an odd ``kernel_size``
    the spatial size is preserved. ``count_include_pad`` is left at the
    PyTorch default (True), so border averages include the zero padding.
    """

    def __init__(self, kernel_size=3):
        super().__init__()
        half_window = kernel_size // 2
        self.pool = nn.AvgPool2d(kernel_size, stride=1, padding=half_window)

    def forward(self, x):
        neighbourhood_mean = self.pool(x)
        return neighbourhood_mean - x
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class PoolMlp(nn.Module):
    """Two-layer MLP built from 1x1 convolutions.

    Input: tensor with shape [B, C, H, W]. The first 1x1 conv maps
    ``in_features -> hidden_features`` channels and the second maps
    ``hidden_features -> out_features``; the spatial size is unchanged.
    The same dropout module is applied after the activation and after the
    second conv.

    Args:
        in_features: Number of input channels.
        hidden_features: Hidden channel count (falsy -> ``in_features``).
        out_features: Output channel count (falsy -> ``in_features``).
        act_layer: Activation class, instantiated with no arguments.
        bias: Whether the 1x1 convs carry a bias.
        drop: Dropout probability.
    """

    def __init__(self,
                 in_features,
                 hidden_features=None,
                 out_features=None,
                 act_layer=nn.GELU,
                 bias=False,
                 drop=0.):
        super().__init__()
        width_out = out_features or in_features
        width_hidden = hidden_features or in_features
        self.fc1 = nn.Conv2d(in_features, width_hidden, 1, bias=bias)
        self.act = act_layer()
        self.fc2 = nn.Conv2d(width_hidden, width_out, 1, bias=bias)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        # (B, in_features, H, W) -> (B, hidden_features, H, W)
        hidden = self.drop(self.act(self.fc1(x)))
        # (B, hidden_features, H, W) -> (B, out_features, H, W)
        return self.drop(self.fc2(hidden))
|
|
|
|
|
|
2024-11-16 21:39:34 +08:00
|
|
|
|
|
|
|
|
|
class BaseFeatureFusion(nn.Module):
    """PoolFormer-style block: pooling token mixer + 1x1-conv MLP, both pre-norm residual.

    Each of the two residual sub-blocks is
    ``x + drop_path(layer_scale * branch(norm(x)))``, where the per-channel
    ``layer_scale`` factor is only applied when ``use_layer_scale`` is set.
    Normalisation uses this file's ``LayerNorm(dim, 'WithBias')``, which
    expects a 4-D ``(B, C, H, W)`` input.

    Args:
        dim: Channel count.
        pool_size: Kernel size of the pooling token mixer.
        mlp_ratio: Hidden-width multiplier of the MLP.
        act_layer: Activation class used inside the MLP.
        drop: Dropout probability inside the MLP.
        drop_path: Stochastic-depth probability (0 -> identity).
        use_layer_scale: Whether to learn per-channel residual scales.
        layer_scale_init_value: Initial value of those scales.
    """

    def __init__(self, dim, pool_size=3, mlp_ratio=4.,
                 act_layer=nn.GELU,
                 # norm_layer=nn.LayerNorm,
                 drop=0., drop_path=0.,
                 use_layer_scale=True, layer_scale_init_value=1e-5):
        super().__init__()

        self.norm1 = LayerNorm(dim, 'WithBias')
        # ViTs mix tokens with MSA and MLP-Mixers with an MLP; here average
        # pooling takes that role.
        self.token_mixer = Pooling(kernel_size=pool_size)
        self.norm2 = LayerNorm(dim, 'WithBias')
        self.poolmlp = PoolMlp(in_features=dim,
                               hidden_features=int(dim * mlp_ratio),
                               act_layer=act_layer, drop=drop)

        # Stochastic depth and layer scale help when training deep PoolFormers.
        if drop_path > 0.:
            self.drop_path = DropPath(drop_path)
        else:
            self.drop_path = nn.Identity()
        self.use_layer_scale = use_layer_scale
        if use_layer_scale:
            self.layer_scale_1 = nn.Parameter(
                torch.full((dim,), layer_scale_init_value, dtype=torch.float32))
            self.layer_scale_2 = nn.Parameter(
                torch.full((dim,), layer_scale_init_value, dtype=torch.float32))

    def forward(self, x):
        # Sub-block 1: pooling token mixer.
        mixed = self.token_mixer(self.norm1(x))
        if self.use_layer_scale:
            # (dim,) -> (dim, 1, 1) so the scale broadcasts over H and W.
            mixed = self.layer_scale_1[:, None, None] * mixed
        x = x + self.drop_path(mixed)

        # Sub-block 2: channel MLP.
        mlp_out = self.poolmlp(self.norm2(x))
        if self.use_layer_scale:
            mlp_out = self.layer_scale_2[:, None, None] * mlp_out
        return x + self.drop_path(mlp_out)
|
|
|
|
|
|
2024-06-03 19:36:29 +08:00
|
|
|
|
class BaseFeatureExtraction(nn.Module):
    """Base-feature extractor: PoolFormer-style pooling mixer + 1x1-conv MLP.

    Two pre-norm residual sub-blocks, each computing
    ``x + drop_path(scale * branch(norm(x)))``. The per-channel ``scale`` is
    used only when ``use_layer_scale`` is set. Normalisation uses this
    file's ``LayerNorm(dim, 'WithBias')``, which expects a 4-D
    ``(B, C, H, W)`` input.

    Args:
        dim: Channel count.
        pool_size: Kernel size of the pooling token mixer.
        mlp_ratio: Hidden-width multiplier of the MLP.
        act_layer: Activation class used inside the MLP.
        drop: Dropout probability inside the MLP.
        drop_path: Stochastic-depth probability (0 -> identity).
        use_layer_scale: Whether to learn per-channel residual scales.
        layer_scale_init_value: Initial value of those scales.
    """

    def __init__(self, dim, pool_size=3, mlp_ratio=4.,
                 act_layer=nn.GELU,
                 # norm_layer=nn.LayerNorm,
                 drop=0., drop_path=0.,
                 use_layer_scale=True, layer_scale_init_value=1e-5):
        super().__init__()

        self.norm1 = LayerNorm(dim, 'WithBias')
        # ViTs mix tokens with MSA and MLP-Mixers with an MLP; here average
        # pooling takes that role.
        self.token_mixer = Pooling(kernel_size=pool_size)
        self.norm2 = LayerNorm(dim, 'WithBias')
        hidden = int(dim * mlp_ratio)
        self.poolmlp = PoolMlp(in_features=dim, hidden_features=hidden,
                               act_layer=act_layer, drop=drop)

        # Stochastic depth and layer scale help when training deep PoolFormers.
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.use_layer_scale = use_layer_scale
        if use_layer_scale:
            init = torch.ones(dim, dtype=torch.float32) * layer_scale_init_value
            self.layer_scale_1 = nn.Parameter(init.clone())
            self.layer_scale_2 = nn.Parameter(init.clone())

    def _residual(self, x, norm, branch, scale):
        """Return ``x + drop_path(scale * branch(norm(x)))``; ``scale=None`` skips scaling."""
        update = branch(norm(x))
        if scale is not None:
            # (dim,) -> (dim, 1, 1) so the scale broadcasts over H and W.
            update = scale.view(-1, 1, 1) * update
        return x + self.drop_path(update)

    def forward(self, x):
        scale_1 = self.layer_scale_1 if self.use_layer_scale else None
        x = self._residual(x, self.norm1, self.token_mixer, scale_1)
        scale_2 = self.layer_scale_2 if self.use_layer_scale else None
        return self._residual(x, self.norm2, self.poolmlp, scale_2)
|
|
|
|
|
|
2024-11-16 12:59:07 +08:00
|
|
|
|
class BaseFeatureExtractionSAR(nn.Module):
    """SAR base-feature extractor: SCSA token mixer + 1x1-conv MLP, both pre-norm residual.

    Same layout as the pooling-based base extractor, but the token mixer is
    ``SCSA(dim=dim, head_num=8)``. Each sub-block computes
    ``x + drop_path(scale * branch(norm(x)))``; the per-channel ``scale`` is
    used only when ``use_layer_scale`` is set.

    Args:
        dim: Channel count.
        pool_size: Accepted for signature compatibility with the pooling
            variants; unused here because the mixer is SCSA.
        mlp_ratio: Hidden-width multiplier of the MLP.
        act_layer: Activation class used inside the MLP.
        drop: Dropout probability inside the MLP.
        drop_path: Stochastic-depth probability (0 -> identity).
        use_layer_scale: Whether to learn per-channel residual scales.
        layer_scale_init_value: Initial value of those scales.
    """

    def __init__(self, dim, pool_size=3, mlp_ratio=4.,
                 act_layer=nn.GELU,
                 # norm_layer=nn.LayerNorm,
                 drop=0., drop_path=0.,
                 use_layer_scale=True, layer_scale_init_value=1e-5):
        super().__init__()

        self.norm1 = LayerNorm(dim, 'WithBias')
        self.token_mixer = SCSA(dim=dim, head_num=8)
        self.norm2 = LayerNorm(dim, 'WithBias')
        self.poolmlp = PoolMlp(in_features=dim,
                               hidden_features=int(dim * mlp_ratio),
                               act_layer=act_layer,
                               drop=drop)

        # Stochastic depth and layer scale help when training deep PoolFormers.
        self.drop_path = nn.Identity() if not drop_path > 0. else DropPath(drop_path)
        self.use_layer_scale = use_layer_scale
        if use_layer_scale:
            self.layer_scale_1 = nn.Parameter(
                layer_scale_init_value * torch.ones(dim, dtype=torch.float32))
            self.layer_scale_2 = nn.Parameter(
                layer_scale_init_value * torch.ones(dim, dtype=torch.float32))

    def forward(self, x):
        stages = (
            (self.norm1, self.token_mixer, 'layer_scale_1'),
            (self.norm2, self.poolmlp, 'layer_scale_2'),
        )
        for norm, branch, scale_name in stages:
            update = branch(norm(x))
            if self.use_layer_scale:
                scale = getattr(self, scale_name)
                update = scale.unsqueeze(-1).unsqueeze(-1) * update
            x = x + self.drop_path(update)
        return x
|
|
|
|
|
|
|
|
|
|
|
2024-11-15 17:48:42 +08:00
|
|
|
|
|
2024-06-03 19:36:29 +08:00
|
|
|
|
class InvertedResidualBlock(nn.Module):
    """MobileNetV2-style bottleneck: 1x1 expand -> 3x3 depthwise -> 1x1 linear projection.

    Layers: a pointwise conv ``inp -> inp*expand_ratio`` with ReLU6, then a
    reflection-padded 3x3 depthwise conv with ReLU6, then a linear pointwise
    conv to ``oup``. There is no BatchNorm. Despite the name, there is also
    no skip connection: ``forward`` returns only the bottleneck output.
    """

    def __init__(self, inp, oup, expand_ratio):
        super(InvertedResidualBlock, self).__init__()
        hidden_dim = int(inp * expand_ratio)
        layers = [
            nn.Conv2d(inp, hidden_dim, 1, bias=False),      # pw (expand)
            nn.ReLU6(inplace=True),
            nn.ReflectionPad2d(1),                          # keeps H, W for the 3x3
            nn.Conv2d(hidden_dim, hidden_dim, 3,
                      groups=hidden_dim, bias=False),       # dw
            nn.ReLU6(inplace=True),
            nn.Conv2d(hidden_dim, oup, 1, bias=False),      # pw-linear
        ]
        # Sequential indices 0..5 appear in state_dict keys; keep this order.
        self.bottleneckBlock = nn.Sequential(*layers)

    def forward(self, x):
        return self.bottleneckBlock(x)
|
2024-11-14 16:59:11 +08:00
|
|
|
|
|
2024-11-17 10:33:24 +08:00
|
|
|
|
class DepthwiseSeparableConvBlock(nn.Module):
    """Depthwise conv -> pointwise (1x1) conv -> BatchNorm -> ReLU.

    Args:
        inp: Input channels (also the depthwise group count).
        oup: Output channels.
        kernel_size: Depthwise kernel size.
        stride: Depthwise stride.
        padding: Depthwise padding; with the defaults the spatial size is kept.
    """

    def __init__(self, inp, oup, kernel_size=3, stride=1, padding=1):
        super(DepthwiseSeparableConvBlock, self).__init__()
        self.depthwise = nn.Conv2d(inp, inp, kernel_size, stride, padding,
                                   groups=inp, bias=False)
        self.pointwise = nn.Conv2d(inp, oup, 1, bias=False)
        self.bn = nn.BatchNorm2d(oup)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        return self.relu(self.bn(self.pointwise(self.depthwise(x))))
|
2024-11-14 16:59:11 +08:00
|
|
|
|
|
2024-06-03 19:36:29 +08:00
|
|
|
|
class DetailNode(nn.Module):
    """One affine-coupling step applied to a pair of half-feature tensors.

    The two halves are concatenated, mixed by a 1x1 conv (``shffleconv``)
    and split again. Then:

        z2 <- z2 + phi(z1)
        z1 <- z1 * exp(rho(z2)) + eta(z2)

    where phi / rho / eta are small conv sub-networks chosen by ``useBlock``.

    Args:
        useBlock: Sub-network type.
            0 = DepthwiseSeparableConvBlock,
            1 = InvertedResidualBlock(expand_ratio=2),
            2 = WTConv2d.
            Any other value falls back to InvertedResidualBlock.
        channels: Channels per half (``z1`` and ``z2`` each). The default of
            32 matches the previously hard-coded width; the shuffle conv
            operates on ``2 * channels``.
    """

    def __init__(self, useBlock=0, channels=32):
        super(DetailNode, self).__init__()
        if useBlock == 0:
            block_cls = DepthwiseSeparableConvBlock
            kwargs = dict(inp=channels, oup=channels)
        elif useBlock == 2:
            block_cls = WTConv2d
            kwargs = dict(in_channels=channels, out_channels=channels)
        else:
            # useBlock == 1 and any unrecognised value share this choice.
            block_cls = InvertedResidualBlock
            kwargs = dict(inp=channels, oup=channels, expand_ratio=2)
        # Construction order (phi, rho, eta, shuffle) matches the original,
        # so seeded initialisation is unchanged.
        self.theta_phi = block_cls(**kwargs)
        self.theta_rho = block_cls(**kwargs)
        self.theta_eta = block_cls(**kwargs)
        self.shffleconv = nn.Conv2d(2 * channels, 2 * channels, kernel_size=1,
                                    stride=1, padding=0, bias=True)

    def separateFeature(self, x):
        """Split ``x`` along dim 1 into its first ``C // 2`` channels and the rest."""
        half = x.shape[1] // 2
        return x[:, :half], x[:, half:]

    def forward(self, z1, z2):
        """Apply one coupling step to ``(z1, z2)`` and return the updated pair."""
        z1, z2 = self.separateFeature(
            self.shffleconv(torch.cat((z1, z2), dim=1)))
        z2 = z2 + self.theta_phi(z1)
        # exp() keeps the multiplicative factor positive (it is unbounded
        # for large rho outputs).
        z1 = z1 * torch.exp(self.theta_rho(z2)) + self.theta_eta(z2)
        return z1, z2
|
|
|
|
|
|
2024-11-16 21:39:34 +08:00
|
|
|
|
class DetailFeatureFusion(nn.Module):
    """Stack of ``num_layers`` DetailNode coupling steps (``useBlock=1``).

    The input's channels are split in half (first half ``z1``, second half
    ``z2``). Both halves pass through every node in order and are then
    concatenated back along channels. With DetailNode's default width of 32
    per half, ``x`` is expected to have 64 channels.
    """

    def __init__(self, num_layers=3):
        super(DetailFeatureFusion, self).__init__()
        self.net = nn.Sequential(*(DetailNode(useBlock=1) for _ in range(num_layers)))

    def forward(self, x):
        mid = x.shape[1] // 2
        z1, z2 = x[:, :mid], x[:, mid:]
        for node in self.net:
            z1, z2 = node(z1, z2)
        return torch.cat((z1, z2), dim=1)
|
2024-06-03 19:36:29 +08:00
|
|
|
|
|
|
|
|
|
class DetailFeatureExtraction(nn.Module):
    """Detail branch: ``num_layers`` DetailNode coupling steps with ``useBlock=1``.

    The input channels are split into two halves, refined jointly by each
    node in sequence, and concatenated back. With DetailNode's default width
    of 32 per half, the input is expected to have 64 channels.
    """

    def __init__(self, num_layers=3):
        super(DetailFeatureExtraction, self).__init__()
        nodes = [DetailNode(useBlock=1) for _ in range(num_layers)]
        self.net = nn.Sequential(*nodes)

    def forward(self, x):
        channels = x.shape[1]
        first = channels // 2
        z1, z2 = torch.split(x, [first, channels - first], dim=1)
        for node in self.net:
            z1, z2 = node(z1, z2)
        return torch.cat((z1, z2), dim=1)
|
|
|
|
|
|
|
|
|
|
|
2024-11-16 12:59:07 +08:00
|
|
|
|
class DetailFeatureExtractionSAR(nn.Module):
    """SAR detail branch: ``num_layers`` DetailNode steps with ``useBlock=2`` (WTConv2d sub-networks).

    The input channels are split into two halves, refined jointly by each
    node in sequence, and concatenated back.
    """

    def __init__(self, num_layers=3):
        super(DetailFeatureExtractionSAR, self).__init__()
        self.net = nn.Sequential(*[DetailNode(useBlock=2) for _ in range(num_layers)])

    def forward(self, x):
        half = x.shape[1] // 2
        z1 = x[:, :half]
        z2 = x[:, half:]
        for layer in self.net:
            z1, z2 = layer(z1, z2)
        return torch.cat([z1, z2], dim=1)
|
|
|
|
|
|
|
|
|
|
|
2024-06-03 19:36:29 +08:00
|
|
|
|
# =============================================================================
|
|
|
|
|
|
|
|
|
|
# =============================================================================
|
|
|
|
|
import numbers
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
##########################################################################
|
|
|
|
|
## Layer Norm
|
|
|
|
|
def to_3d(x):
    """Flatten a ``(B, C, H, W)`` feature map into a ``(B, H*W, C)`` token sequence."""
    pattern = 'b c h w -> b (h w) c'
    return rearrange(x, pattern)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def to_4d(x, h, w):
    """Restore a ``(B, H*W, C)`` token sequence to a ``(B, C, H, W)`` feature map."""
    spatial = dict(h=h, w=w)
    return rearrange(x, 'b (h w) c -> b c h w', **spatial)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class BiasFree_LayerNorm(nn.Module):
    """Layer norm over the last dimension with a learnable scale and no bias.

    Computes ``x / sqrt(var(x) + 1e-5) * weight``, using the biased
    (population) variance over the last dimension. The mean is *not*
    subtracted.

    Args:
        normalized_shape: Size of the last dimension, as an int or a
            one-element sequence.
    """

    def __init__(self, normalized_shape):
        super(BiasFree_LayerNorm, self).__init__()
        if isinstance(normalized_shape, numbers.Integral):
            shape = torch.Size((normalized_shape,))
        else:
            shape = torch.Size(normalized_shape)
        # Only normalisation over a single trailing dimension is supported.
        assert len(shape) == 1
        self.weight = nn.Parameter(torch.ones(shape))
        self.normalized_shape = shape

    def forward(self, x):
        variance = x.var(dim=-1, keepdim=True, unbiased=False)
        return x / torch.sqrt(variance + 1e-5) * self.weight
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class WithBias_LayerNorm(nn.Module):
    """Standard layer norm over the last dimension with learnable scale and shift.

    Computes ``(x - mean) / sqrt(var + 1e-5) * weight + bias``, using the
    biased (population) variance over the last dimension.

    Args:
        normalized_shape: Size of the last dimension, as an int or a
            one-element sequence.
    """

    def __init__(self, normalized_shape):
        super(WithBias_LayerNorm, self).__init__()
        if isinstance(normalized_shape, numbers.Integral):
            normalized_shape = (normalized_shape,)
        shape = torch.Size(normalized_shape)
        # Only normalisation over a single trailing dimension is supported.
        assert len(shape) == 1
        self.weight = nn.Parameter(torch.ones(shape))
        self.bias = nn.Parameter(torch.zeros(shape))
        self.normalized_shape = shape

    def forward(self, x):
        mean = x.mean(dim=-1, keepdim=True)
        variance = x.var(dim=-1, keepdim=True, unbiased=False)
        centred = x - mean
        return centred / torch.sqrt(variance + 1e-5) * self.weight + self.bias
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class LayerNorm(nn.Module):
    """Per-pixel channel layer norm for ``(B, C, H, W)`` feature maps.

    Flattens the spatial dims into a token axis, normalises each pixel's
    C-vector with the selected implementation, and restores the 4-D layout.

    Args:
        dim: Number of channels C.
        LayerNorm_type: ``'BiasFree'`` selects BiasFree_LayerNorm; any other
            value selects WithBias_LayerNorm.
    """

    def __init__(self, dim, LayerNorm_type):
        super(LayerNorm, self).__init__()
        norm_cls = BiasFree_LayerNorm if LayerNorm_type == 'BiasFree' else WithBias_LayerNorm
        self.body = norm_cls(dim)

    def forward(self, x):
        height, width = x.shape[-2:]
        tokens = self.body(to_3d(x))
        return to_4d(tokens, height, width)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
##########################################################################
|
|
|
|
|
## Gated-Dconv Feed-Forward Network (GDFN)
|
|
|
|
|
class FeedForward(nn.Module):
    """Gated-Dconv Feed-Forward Network (GDFN).

    A 1x1 conv expands ``dim`` to ``2 * hidden`` channels and a 3x3
    depthwise conv follows. The result is split into two halves; the first
    is passed through GELU and gates the second. A 1x1 conv then projects
    ``hidden`` back to ``dim``.

    Args:
        dim: Input/output channel count.
        ffn_expansion_factor: ``hidden = int(dim * ffn_expansion_factor)``.
        bias: Whether the convolutions carry a bias.
    """

    def __init__(self, dim, ffn_expansion_factor, bias):
        super(FeedForward, self).__init__()
        hidden_features = int(dim * ffn_expansion_factor)
        doubled = hidden_features * 2
        self.project_in = nn.Conv2d(dim, doubled, kernel_size=1, bias=bias)
        self.dwconv = nn.Conv2d(doubled, doubled, kernel_size=3, stride=1,
                                padding=1, groups=doubled, bias=bias)
        self.project_out = nn.Conv2d(hidden_features, dim, kernel_size=1, bias=bias)

    def forward(self, x):
        gate, value = self.dwconv(self.project_in(x)).chunk(2, dim=1)
        return self.project_out(F.gelu(gate) * value)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
##########################################################################
|
|
|
|
|
## Multi-DConv Head Transposed Self-Attention (MDTA)
|
|
|
|
|
class Attention(nn.Module):
    """Multi-DConv Head Transposed Self-Attention (MDTA).

    Q, K and V come from a 1x1 conv followed by a 3x3 depthwise conv. After
    the split into ``num_heads`` heads, attention is computed between
    *channels*: per head, the attention matrix is ``(C/heads) x (C/heads)``,
    built from L2-normalised q/k rows over the flattened spatial axis and
    scaled by a learnable per-head temperature.

    Args:
        dim: Channel count; presumably must be divisible by ``num_heads``.
        num_heads: Number of attention heads.
        bias: Whether the convolutions carry a bias.
    """

    def __init__(self, dim, num_heads, bias):
        super(Attention, self).__init__()
        self.num_heads = num_heads
        self.temperature = nn.Parameter(torch.ones(num_heads, 1, 1))
        self.qkv = nn.Conv2d(dim, dim * 3, kernel_size=1, bias=bias)
        self.qkv_dwconv = nn.Conv2d(dim * 3, dim * 3, kernel_size=3, stride=1,
                                    padding=1, groups=dim * 3, bias=bias)
        self.project_out = nn.Conv2d(dim, dim, kernel_size=1, bias=bias)

    def forward(self, x):
        _, _, h, w = x.shape
        qkv = self.qkv_dwconv(self.qkv(x))
        to_heads = 'b (head c) h w -> b head c (h w)'
        q, k, v = (rearrange(t, to_heads, head=self.num_heads)
                   for t in qkv.chunk(3, dim=1))

        # Normalise along the spatial axis so q @ k^T compares channels by cosine similarity.
        q = F.normalize(q, dim=-1)
        k = F.normalize(k, dim=-1)
        attn = torch.softmax((q @ k.transpose(-2, -1)) * self.temperature, dim=-1)

        out = rearrange(attn @ v, 'b head c (h w) -> b (head c) h w',
                        head=self.num_heads, h=h, w=w)
        return self.project_out(out)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
##########################################################################
|
|
|
|
|
class TransformerBlock(nn.Module):
    """Restormer block: pre-norm MDTA attention, then pre-norm GDFN, each with a residual.

    Args:
        dim: Channel count.
        num_heads: Attention heads.
        ffn_expansion_factor: Hidden-width multiplier of the feed-forward.
        bias: Whether the convolutions carry a bias.
        LayerNorm_type: ``'BiasFree'`` or anything else for the with-bias norm.
    """

    def __init__(self, dim, num_heads, ffn_expansion_factor, bias, LayerNorm_type):
        super(TransformerBlock, self).__init__()
        self.norm1 = LayerNorm(dim, LayerNorm_type)
        self.attn = Attention(dim, num_heads, bias)
        self.norm2 = LayerNorm(dim, LayerNorm_type)
        self.ffn = FeedForward(dim, ffn_expansion_factor, bias)

    def forward(self, x):
        attended = x + self.attn(self.norm1(x))
        return attended + self.ffn(self.norm2(attended))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
##########################################################################
|
|
|
|
|
## Overlapped image patch embedding with 3x3 Conv
|
|
|
|
|
class OverlapPatchEmbed(nn.Module):
    """Embed an image into ``embed_dim`` channels with one 3x3, stride-1 conv.

    Padding 1 keeps the spatial size, so neighbouring "patches" overlap and
    every pixel gets its own embedding vector.
    """

    def __init__(self, in_c=3, embed_dim=48, bias=False):
        super(OverlapPatchEmbed, self).__init__()
        self.proj = nn.Conv2d(in_c, embed_dim, kernel_size=3,
                              stride=1, padding=1, bias=bias)

    def forward(self, x):
        return self.proj(x)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class Restormer_Encoder(nn.Module):
    """Shared shallow encoder followed by base / detail feature extractors.

    The input is embedded to ``dim`` channels and refined by
    ``num_blocks[0]`` TransformerBlocks. The resulting shallow features then
    go through either the default base/detail extractors or, when
    ``sar_img`` is truthy, the SAR-specific ones.

    Args:
        inp_channels: Channels of the input image.
        out_channels: Unused; kept for signature symmetry with the decoder.
        dim: Feature width. The detail extractors split it into two halves,
            and DetailNode's default width is 32 per half, so presumably
            ``dim`` must be 64 — confirm before changing it.
        num_blocks: Block counts; only ``num_blocks[0]`` is used here.
        heads: Attention heads per stage; only ``heads[0]`` is used here.
        ffn_expansion_factor: Hidden-width multiplier of each FeedForward.
        bias: Whether the convolutions carry a bias.
        LayerNorm_type: ``'BiasFree'`` or anything else for the with-bias norm.

    ``forward`` returns ``(base_feature, detail_feature, shallow_features)``.
    """

    def __init__(self,
                 inp_channels=1,
                 out_channels=1,
                 dim=64,
                 num_blocks=(4, 4),
                 heads=(8, 8, 8),
                 ffn_expansion_factor=2,
                 bias=False,
                 LayerNorm_type='WithBias',
                 ):
        super(Restormer_Encoder, self).__init__()
        self.patch_embed = OverlapPatchEmbed(inp_channels, dim)
        self.encoder_level1 = nn.Sequential(
            *[TransformerBlock(dim=dim, num_heads=heads[0],
                               ffn_expansion_factor=ffn_expansion_factor,
                               bias=bias, LayerNorm_type=LayerNorm_type)
              for _ in range(num_blocks[0])])
        self.baseFeature = BaseFeatureExtraction(dim=dim)
        self.detailFeature = DetailFeatureExtraction()
        self.baseFeatureSar = BaseFeatureExtractionSAR(dim=dim)
        self.detailFeatureSar = DetailFeatureExtractionSAR()

    def forward(self, inp_img, sar_img=False):
        inp_enc_level1 = self.patch_embed(inp_img)
        out_enc_level1 = self.encoder_level1(inp_enc_level1)
        if sar_img:
            # SAR inputs use the dedicated SAR extractors built in __init__;
            # otherwise those modules would never receive input or gradients.
            base_feature = self.baseFeatureSar(out_enc_level1)
            detail_feature = self.detailFeatureSar(out_enc_level1)
        else:
            base_feature = self.baseFeature(out_enc_level1)
            detail_feature = self.detailFeature(out_enc_level1)
        return base_feature, detail_feature, out_enc_level1
|
2024-06-03 19:36:29 +08:00
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class Restormer_Decoder(nn.Module):
    """Fuse base and detail features and reconstruct an image.

    ``base_feature`` and ``detail_feature`` are concatenated along channels
    (together they must have ``2 * dim`` channels) and reduced to ``dim``
    channels by a 1x1 conv. They are then refined by ``num_blocks[1]``
    TransformerBlocks and projected to ``out_channels`` by two 3x3 convs
    with a LeakyReLU between them. If ``inp_img`` is not None it is added
    as a global residual before the final sigmoid.

    Args:
        inp_channels: Unused; kept for signature symmetry with the encoder.
        out_channels: Channels of the reconstructed image.
        dim: Feature width.
        num_blocks: Block counts; only ``num_blocks[1]`` is used here.
        heads: Attention heads per stage; only ``heads[1]`` is used here.
        ffn_expansion_factor: Hidden-width multiplier of each FeedForward.
        bias: Whether the convolutions carry a bias.
        LayerNorm_type: ``'BiasFree'`` or anything else for the with-bias norm.

    ``forward`` returns ``(sigmoid(reconstruction), fused_features)``, where
    ``fused_features`` is the ``dim``-channel output of ``reduce_channel``.
    """

    def __init__(self,
                 inp_channels=1,
                 out_channels=1,
                 dim=64,
                 num_blocks=(4, 4),
                 heads=(8, 8, 8),
                 ffn_expansion_factor=2,
                 bias=False,
                 LayerNorm_type='WithBias',
                 ):
        super(Restormer_Decoder, self).__init__()
        self.reduce_channel = nn.Conv2d(int(dim * 2), int(dim), kernel_size=1, bias=bias)
        self.encoder_level2 = nn.Sequential(
            *[TransformerBlock(dim=dim, num_heads=heads[1],
                               ffn_expansion_factor=ffn_expansion_factor,
                               bias=bias, LayerNorm_type=LayerNorm_type)
              for _ in range(num_blocks[1])])
        self.output = nn.Sequential(
            nn.Conv2d(int(dim), int(dim) // 2, kernel_size=3,
                      stride=1, padding=1, bias=bias),
            nn.LeakyReLU(),
            nn.Conv2d(int(dim) // 2, out_channels, kernel_size=3,
                      stride=1, padding=1, bias=bias), )
        self.sigmoid = nn.Sigmoid()

    def forward(self, inp_img, base_feature, detail_feature):
        out_enc_level0 = torch.cat((base_feature, detail_feature), dim=1)
        out_enc_level0 = self.reduce_channel(out_enc_level0)
        out_enc_level1 = self.encoder_level2(out_enc_level0)
        if inp_img is not None:
            # Global residual connection to the input image.
            out_enc_level1 = self.output(out_enc_level1) + inp_img
        else:
            out_enc_level1 = self.output(out_enc_level1)
        return self.sigmoid(out_enc_level1), out_enc_level0
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if __name__ == '__main__':
    # Smoke construction of the encoder/decoder pair.
    height = 128
    width = 128
    window_size = 8
    # Fall back to CPU so this also runs on machines without CUDA; an
    # unconditional .cuda() raises there.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    modelE = Restormer_Encoder().to(device)
    modelD = Restormer_Decoder().to(device)
|