import math
import numbers

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint as checkpoint
from einops import rearrange
from timm.models.layers import DropPath, to_2tuple, trunc_normal_

from componets.SCSA import SCSA
from componets.WTConvCV2 import WTConv2d


def drop_path(x, drop_prob: float = 0., training: bool = False):
    """Randomly zero whole samples of ``x`` (stochastic depth).

    Each sample along dim 0 is kept with probability ``1 - drop_prob``.
    Kept samples are rescaled by ``1 / (1 - drop_prob)`` so the expected
    value of the output equals ``x``.

    Args:
        x: Input tensor of any rank; dim 0 is treated as the batch dimension.
        drop_prob: Probability of dropping a sample. ``None`` or ``0``
            disables dropping.
        training: Samples are only dropped when this is ``True``.

    Returns:
        A tensor with the same shape as ``x``.
    """
    # ``not drop_prob`` also covers ``None`` (the DropPath default), which
    # would otherwise fail on ``1 - None`` in training mode.
    if not drop_prob or not training:
        return x
    keep_prob = 1 - drop_prob
    # One random value per sample, broadcast over every remaining dim, so
    # this works for tensors of any rank, not just 2D ConvNets.
    shape = (x.shape[0],) + (1,) * (x.ndim - 1)
    random_tensor = keep_prob + \
        torch.rand(shape, dtype=x.dtype, device=x.device)
    random_tensor.floor_()  # binarize: 1 with probability keep_prob, else 0
    if keep_prob == 0:
        # drop_prob == 1: every sample is dropped; avoid 0 / 0 -> NaN.
        return x * random_tensor
    return x.div(keep_prob) * random_tensor


class DropPath(nn.Module):
    """Drop paths (Stochastic Depth) per sample, for the main path of residual blocks.

    NOTE: this local definition shadows the ``DropPath`` imported from timm.
    """

    def __init__(self, drop_prob=None):
        super(DropPath, self).__init__()
        self.drop_prob = drop_prob

    def forward(self, x):
        return drop_path(x, self.drop_prob, self.training)


class Pooling(nn.Module):
    """PoolFormer token mixer: local average pooling minus the input.

    The pooling uses stride 1 and padding ``kernel_size // 2``, so spatial
    size is preserved for odd ``kernel_size``. Subtracting ``x`` means the
    caller's residual connection adds back only the pooled neighbourhood.
    """

    def __init__(self, kernel_size=3):
        super().__init__()
        self.pool = nn.AvgPool2d(
            kernel_size, stride=1, padding=kernel_size // 2)

    def forward(self, x):
        return self.pool(x) - x


class PoolMlp(nn.Module):
    """
    Implementation of MLP with 1*1 convolutions.
    Input: tensor with shape [B, C, H, W]
    """

    def __init__(self, in_features, hidden_features=None,
                 out_features=None, act_layer=nn.GELU,
                 bias=False, drop=0.):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = nn.Conv2d(in_features, hidden_features, 1, bias=bias)
        self.act = act_layer()
        self.fc2 = nn.Conv2d(hidden_features, out_features, 1, bias=bias)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        x = self.fc1(x)  # (B, C_in, H, W) --> (B, C_hidden, H, W)
        x = self.act(x)
        x = self.drop(x)
        x = self.fc2(x)  # (B, C_hidden, H, W) --> (B, C_out, H, W)
        x = self.drop(x)
        return x


class _PoolFormerBlock(nn.Module):
    """Shared MetaFormer-style block behind the ``BaseFeature*`` classes.

    With optional per-channel layer scales ``ls1`` / ``ls2`` it computes::

        x = x + drop_path(ls1 * token_mixer(norm1(x)))
        x = x + drop_path(ls2 * poolmlp(norm2(x)))

    Subclasses pick the token mixer by overriding ``_build_token_mixer``.
    Submodule names and construction order match the original per-class
    implementations, so state-dict keys and init order are unchanged.
    """

    def __init__(self, dim, pool_size=3, mlp_ratio=4.,
                 act_layer=nn.GELU,
                 # norm_layer=nn.LayerNorm,
                 drop=0., drop_path=0.,
                 use_layer_scale=True, layer_scale_init_value=1e-5):
        super().__init__()

        self.norm1 = LayerNorm(dim, 'WithBias')
        # ViTs use multi-head self-attention as the token mixer and MLP-Mixers
        # use an MLP; PoolFormer-style blocks substitute pooling here.
        self.token_mixer = self._build_token_mixer(dim, pool_size)
        self.norm2 = LayerNorm(dim, 'WithBias')
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.poolmlp = PoolMlp(in_features=dim, hidden_features=mlp_hidden_dim,
                               act_layer=act_layer, drop=drop)

        # The following two techniques are useful to train deep PoolFormers.
        self.drop_path = DropPath(drop_path) if drop_path > 0. \
            else nn.Identity()
        self.use_layer_scale = use_layer_scale
        if use_layer_scale:
            self.layer_scale_1 = nn.Parameter(
                torch.ones(dim, dtype=torch.float32) * layer_scale_init_value)
            self.layer_scale_2 = nn.Parameter(
                torch.ones(dim, dtype=torch.float32) * layer_scale_init_value)

    def _build_token_mixer(self, dim, pool_size):
        """Return the spatial token-mixing module (average pooling by default)."""
        return Pooling(kernel_size=pool_size)

    def forward(self, x):
        mixed = self.token_mixer(self.norm1(x))
        if self.use_layer_scale:
            # (C,) -> (C, 1, 1) so the per-channel scale broadcasts over H, W.
            mixed = self.layer_scale_1.unsqueeze(-1).unsqueeze(-1) * mixed
        x = x + self.drop_path(mixed)

        mlp_out = self.poolmlp(self.norm2(x))
        if self.use_layer_scale:
            mlp_out = self.layer_scale_2.unsqueeze(-1).unsqueeze(-1) * mlp_out
        return x + self.drop_path(mlp_out)


class BaseFeatureFusion(_PoolFormerBlock):
    """PoolFormer-style block with an average-pooling token mixer."""


class BaseFeatureExtraction(_PoolFormerBlock):
    """PoolFormer-style block with an average-pooling token mixer."""


class BaseFeatureExtractionSAR(_PoolFormerBlock):
    """PoolFormer-style block whose token mixer is ``SCSA(dim, head_num=8)``.

    ``pool_size`` is accepted for signature compatibility but is unused.
    """

    def _build_token_mixer(self, dim, pool_size):
        return SCSA(dim=dim, head_num=8)


class InvertedResidualBlock(nn.Module):
    """MobileNetV2-style bottleneck: 1x1 expand -> 3x3 depthwise -> 1x1 project.

    Despite the name, no skip connection is added here; ``forward`` returns
    only the bottleneck output. No normalization layers are used, and the
    depthwise conv uses reflection padding to keep the spatial size.
    """

    def __init__(self, inp, oup, expand_ratio):
        super(InvertedResidualBlock, self).__init__()
        hidden_dim = int(inp * expand_ratio)
        self.bottleneckBlock = nn.Sequential(
            # pw
            nn.Conv2d(inp, hidden_dim, 1, bias=False),
            nn.ReLU6(inplace=True),
            # dw
            nn.ReflectionPad2d(1),
            nn.Conv2d(hidden_dim, hidden_dim, 3, groups=hidden_dim, bias=False),
            nn.ReLU6(inplace=True),
            # pw-linear
            nn.Conv2d(hidden_dim, oup, 1, bias=False),
        )

    def forward(self, x):
        return self.bottleneckBlock(x)


class DepthwiseSeparableConvBlock(nn.Module):
    """Depthwise conv -> pointwise 1x1 conv -> BatchNorm -> ReLU."""

    def __init__(self, inp, oup, kernel_size=3, stride=1, padding=1):
        super(DepthwiseSeparableConvBlock, self).__init__()
        self.depthwise = nn.Conv2d(inp, inp, kernel_size, stride, padding, groups=inp, bias=False)
        self.pointwise = nn.Conv2d(inp, oup, 1, bias=False)
        self.bn = nn.BatchNorm2d(oup)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        x = self.depthwise(x)
        x = self.pointwise(x)
        x = self.bn(x)
        x = self.relu(x)
        return x


class DetailNode(nn.Module):
    """One affine-coupling step of the invertible detail branch.

    ``forward`` first mixes the two halves with a 1x1 conv (``shffleconv``),
    re-splits them, then applies::

        z2 = z2 + phi(z1)
        z1 = z1 * exp(rho(z2)) + eta(z2)

    Args:
        useBlock: Sub-network used for phi / rho / eta:
            0 -> ``DepthwiseSeparableConvBlock``; 2 -> ``WTConv2d``;
            any other value (including the default 1) ->
            ``InvertedResidualBlock`` with ``expand_ratio=2``.
        dim: Channels of each half (``z1`` and ``z2``). The default 32
            matches the previously hard-coded width.
    """

    def __init__(self, useBlock=1, dim=32):
        super(DetailNode, self).__init__()
        # Construction order phi, rho, eta, shffleconv is kept for
        # reproducible parameter initialisation.
        self.theta_phi = self._make_block(useBlock, dim)
        self.theta_rho = self._make_block(useBlock, dim)
        self.theta_eta = self._make_block(useBlock, dim)
        self.shffleconv = nn.Conv2d(2 * dim, 2 * dim, kernel_size=1,
                                    stride=1, padding=0, bias=True)

    @staticmethod
    def _make_block(useBlock, dim):
        """Build one ``dim -> dim`` sub-network of the type selected by ``useBlock``."""
        if useBlock == 0:
            return DepthwiseSeparableConvBlock(inp=dim, oup=dim)
        if useBlock == 2:
            return WTConv2d(in_channels=dim, out_channels=dim)
        return InvertedResidualBlock(inp=dim, oup=dim, expand_ratio=2)

    def separateFeature(self, x):
        half = x.shape[1] // 2
        z1, z2 = x[:, :half], x[:, half:]
        return z1, z2

    def forward(self, z1, z2):
        z1, z2 = self.separateFeature(
            self.shffleconv(torch.cat((z1, z2), dim=1)))
        z2 = z2 + self.theta_phi(z1)
        z1 = z1 * torch.exp(self.theta_rho(z2)) + self.theta_eta(z2)
        return z1, z2


class _DetailINN(nn.Module):
    """Stack of ``DetailNode`` coupling layers over a channel-split input.

    The input's channels are split into two halves, passed through every
    node in ``self.net``, and concatenated back, so the output has the same
    shape as the input.

    Args:
        num_layers: Number of ``DetailNode`` layers.
        use_block: Forwarded to ``DetailNode(useBlock=...)``.
        dim: Total input channels; must be even (each half has ``dim // 2``).

    Raises:
        ValueError: If ``dim`` is odd.
    """

    def __init__(self, num_layers=3, use_block=1, dim=64):
        super().__init__()
        if dim % 2:
            raise ValueError(f"dim must be even to split into two halves, got {dim}")
        INNmodules = [DetailNode(useBlock=use_block, dim=dim // 2)
                      for _ in range(num_layers)]
        self.net = nn.Sequential(*INNmodules)

    def forward(self, x):
        half = x.shape[1] // 2
        z1, z2 = x[:, :half], x[:, half:]
        for layer in self.net:
            z1, z2 = layer(z1, z2)
        return torch.cat((z1, z2), dim=1)


class DetailFeatureFusion(_DetailINN):
    """Detail INN whose coupling sub-networks are ``WTConv2d`` (``useBlock=2``)."""

    def __init__(self, num_layers=3, dim=64):
        super().__init__(num_layers=num_layers, use_block=2, dim=dim)


class DetailFeatureExtraction(_DetailINN):
    """Detail INN whose coupling sub-networks are ``WTConv2d`` (``useBlock=2``)."""

    def __init__(self, num_layers=3, dim=64):
        super().__init__(num_layers=num_layers, use_block=2, dim=dim)


class DetailFeatureExtractionSAR(_DetailINN):
    """Detail INN whose coupling sub-networks are ``InvertedResidualBlock`` (``useBlock=1``)."""

    def __init__(self, num_layers=3, dim=64):
        super().__init__(num_layers=num_layers, use_block=1, dim=dim)


# =============================================================================

# =============================================================================

##########################################################################
## Layer Norm

def to_3d(x):
    """(B, C, H, W) -> (B, H*W, C)."""
    return rearrange(x, 'b c h w -> b (h w) c')


def to_4d(x, h, w):
    """(B, H*W, C) -> (B, C, H, W)."""
    return rearrange(x, 'b (h w) c -> b c h w', h=h, w=w)


class BiasFree_LayerNorm(nn.Module):
    """Layer norm over the last dim that scales by 1/std without subtracting the mean."""

    def __init__(self, normalized_shape):
        super(BiasFree_LayerNorm, self).__init__()
        if isinstance(normalized_shape, numbers.Integral):
            normalized_shape = (normalized_shape,)
        normalized_shape = torch.Size(normalized_shape)

        assert len(normalized_shape) == 1

        self.weight = nn.Parameter(torch.ones(normalized_shape))
        self.normalized_shape = normalized_shape

    def forward(self, x):
        sigma = x.var(-1, keepdim=True, unbiased=False)
        return x / torch.sqrt(sigma + 1e-5) * self.weight


class WithBias_LayerNorm(nn.Module):
    """Standard layer norm over the last dim with learnable weight and bias."""

    def __init__(self, normalized_shape):
        super(WithBias_LayerNorm, self).__init__()
        if isinstance(normalized_shape, numbers.Integral):
            normalized_shape = (normalized_shape,)
        normalized_shape = torch.Size(normalized_shape)

        assert len(normalized_shape) == 1

        self.weight = nn.Parameter(torch.ones(normalized_shape))
        self.bias = nn.Parameter(torch.zeros(normalized_shape))
        self.normalized_shape = normalized_shape

    def forward(self, x):
        mu = x.mean(-1, keepdim=True)
        sigma = x.var(-1, keepdim=True, unbiased=False)
        return (x - mu) / torch.sqrt(sigma + 1e-5) * self.weight + self.bias


class LayerNorm(nn.Module):
    """Channel-wise layer norm for (B, C, H, W) tensors.

    ``LayerNorm_type == 'BiasFree'`` selects ``BiasFree_LayerNorm``; any
    other value selects ``WithBias_LayerNorm``.
    """

    def __init__(self, dim, LayerNorm_type):
        super(LayerNorm, self).__init__()
        if LayerNorm_type == 'BiasFree':
            self.body = BiasFree_LayerNorm(dim)
        else:
            self.body = WithBias_LayerNorm(dim)

    def forward(self, x):
        h, w = x.shape[-2:]
        # Move channels last, normalise per pixel, then restore layout.
        return to_4d(self.body(to_3d(x)), h, w)


##########################################################################
## Gated-Dconv Feed-Forward Network (GDFN)
class FeedForward(nn.Module):
    """Gated feed-forward: 1x1 expand to 2*hidden, depthwise 3x3, GELU gate, 1x1 project."""

    def __init__(self, dim, ffn_expansion_factor, bias):
        super(FeedForward, self).__init__()

        hidden_features = int(dim * ffn_expansion_factor)

        self.project_in = nn.Conv2d(
            dim, hidden_features * 2, kernel_size=1, bias=bias)

        self.dwconv = nn.Conv2d(hidden_features * 2, hidden_features * 2, kernel_size=3,
                                stride=1, padding=1, groups=hidden_features * 2, bias=bias)

        self.project_out = nn.Conv2d(
            hidden_features, dim, kernel_size=1, bias=bias)

    def forward(self, x):
        x = self.project_in(x)
        x1, x2 = self.dwconv(x).chunk(2, dim=1)
        # One half gates the other.
        x = F.gelu(x1) * x2
        x = self.project_out(x)
        return x


##########################################################################
## Multi-DConv Head Transposed Self-Attention (MDTA)
class Attention(nn.Module):
    """Transposed (channel-wise) self-attention with a learnable per-head temperature.

    Attention maps are (C/heads x C/heads) per head, computed from
    L2-normalised queries and keys over the flattened spatial axis.
    ``dim`` must be divisible by ``num_heads``.
    """

    def __init__(self, dim, num_heads, bias):
        super(Attention, self).__init__()
        self.num_heads = num_heads
        self.temperature = nn.Parameter(torch.ones(num_heads, 1, 1))

        self.qkv = nn.Conv2d(dim, dim * 3, kernel_size=1, bias=bias)
        self.qkv_dwconv = nn.Conv2d(
            dim * 3, dim * 3, kernel_size=3, stride=1, padding=1, groups=dim * 3, bias=bias)
        self.project_out = nn.Conv2d(dim, dim, kernel_size=1, bias=bias)

    def forward(self, x):
        b, c, h, w = x.shape

        qkv = self.qkv_dwconv(self.qkv(x))
        q, k, v = qkv.chunk(3, dim=1)

        q = rearrange(q, 'b (head c) h w -> b head c (h w)',
                      head=self.num_heads)
        k = rearrange(k, 'b (head c) h w -> b head c (h w)',
                      head=self.num_heads)
        v = rearrange(v, 'b (head c) h w -> b head c (h w)',
                      head=self.num_heads)

        q = F.normalize(q, dim=-1)
        k = F.normalize(k, dim=-1)

        attn = (q @ k.transpose(-2, -1)) * self.temperature
        attn = attn.softmax(dim=-1)

        out = (attn @ v)

        out = rearrange(out, 'b head c (h w) -> b (head c) h w',
                        head=self.num_heads, h=h, w=w)

        out = self.project_out(out)
        return out


##########################################################################
class TransformerBlock(nn.Module):
    """Pre-norm block: ``x + attn(norm1(x))`` followed by ``x + ffn(norm2(x))``."""

    def __init__(self, dim, num_heads, ffn_expansion_factor, bias, LayerNorm_type):
        super(TransformerBlock, self).__init__()

        self.norm1 = LayerNorm(dim, LayerNorm_type)
        self.attn = Attention(dim, num_heads, bias)
        self.norm2 = LayerNorm(dim, LayerNorm_type)
        self.ffn = FeedForward(dim, ffn_expansion_factor, bias)

    def forward(self, x):
        x = x + self.attn(self.norm1(x))
        x = x + self.ffn(self.norm2(x))

        return x


##########################################################################
## Overlapped image patch embedding with 3x3 Conv
class OverlapPatchEmbed(nn.Module):
    """3x3, stride-1, padding-1 conv from ``in_c`` to ``embed_dim`` channels (keeps H, W)."""

    def __init__(self, in_c=3, embed_dim=48, bias=False):
        super(OverlapPatchEmbed, self).__init__()

        self.proj = nn.Conv2d(in_c, embed_dim, kernel_size=3,
                              stride=1, padding=1, bias=bias)

    def forward(self, x):
        x = self.proj(x)
        return x


class Restormer_Encoder(nn.Module):
    """Shared shallow encoder followed by separate base and detail branches.

    Two branch pairs are built: ``baseFeature`` / ``detailFeature`` for
    regular inputs and ``baseFeatureSar`` / ``detailFeatureSar`` for inputs
    flagged with ``sar_img=True``.

    Args:
        inp_channels: Channels of the input image.
        out_channels: Unused; kept for signature compatibility.
        dim: Feature width. Must be even (the detail branch splits channels
            in half) and divisible by ``heads[0]``.
        num_blocks: Only ``num_blocks[0]`` (number of TransformerBlocks) is used.
        heads: Only ``heads[0]`` is used.
        ffn_expansion_factor, bias, LayerNorm_type: Forwarded to TransformerBlock.
    """

    def __init__(self,
                 inp_channels=1,
                 out_channels=1,
                 dim=64,
                 num_blocks=(4, 4),
                 heads=(8, 8, 8),
                 ffn_expansion_factor=2,
                 bias=False,
                 LayerNorm_type='WithBias',
                 ):
        super(Restormer_Encoder, self).__init__()

        self.patch_embed = OverlapPatchEmbed(inp_channels, dim)

        self.encoder_level1 = nn.Sequential(
            *[TransformerBlock(dim=dim, num_heads=heads[0], ffn_expansion_factor=ffn_expansion_factor,
                               bias=bias, LayerNorm_type=LayerNorm_type) for i in range(num_blocks[0])])
        self.baseFeature = BaseFeatureExtraction(dim=dim)
        self.detailFeature = DetailFeatureExtraction(dim=dim)
        self.baseFeatureSar = BaseFeatureExtractionSAR(dim=dim)
        self.detailFeatureSar = DetailFeatureExtractionSAR(dim=dim)

    def forward(self, inp_img, sar_img=False):
        """Encode ``inp_img`` into base and detail features.

        Args:
            inp_img: Tensor of shape (B, inp_channels, H, W).
            sar_img: If true, use the SAR-specific branches
                (``baseFeatureSar`` / ``detailFeatureSar``). Previously both
                branches of this flag ran the non-SAR modules, so the SAR
                modules were never used.

        Returns:
            ``(base_feature, detail_feature, out_enc_level1)``, each of shape
            (B, dim, H, W); ``out_enc_level1`` is the shared shallow feature.
        """
        inp_enc_level1 = self.patch_embed(inp_img)
        out_enc_level1 = self.encoder_level1(inp_enc_level1)
        if sar_img:
            base_feature = self.baseFeatureSar(out_enc_level1)
            detail_feature = self.detailFeatureSar(out_enc_level1)
        else:
            base_feature = self.baseFeature(out_enc_level1)
            detail_feature = self.detailFeature(out_enc_level1)
        return base_feature, detail_feature, out_enc_level1


class Restormer_Decoder(nn.Module):
    """Fuse base + detail features and reconstruct an image in [0, 1].

    Args:
        inp_channels: Unused; kept for signature compatibility.
        out_channels: Channels of the reconstructed image.
        dim: Channels of each of ``base_feature`` / ``detail_feature``;
            must be divisible by ``heads[1]``.
        num_blocks: Only ``num_blocks[1]`` (number of TransformerBlocks) is used.
        heads: Only ``heads[1]`` is used.
        ffn_expansion_factor, bias, LayerNorm_type: Forwarded to TransformerBlock.
    """

    def __init__(self,
                 inp_channels=1,
                 out_channels=1,
                 dim=64,
                 num_blocks=(4, 4),
                 heads=(8, 8, 8),
                 ffn_expansion_factor=2,
                 bias=False,
                 LayerNorm_type='WithBias',
                 ):
        super(Restormer_Decoder, self).__init__()
        self.reduce_channel = nn.Conv2d(int(dim * 2), int(dim), kernel_size=1, bias=bias)
        self.encoder_level2 = nn.Sequential(
            *[TransformerBlock(dim=dim, num_heads=heads[1], ffn_expansion_factor=ffn_expansion_factor,
                               bias=bias, LayerNorm_type=LayerNorm_type) for i in range(num_blocks[1])])
        self.output = nn.Sequential(
            nn.Conv2d(int(dim), int(dim) // 2, kernel_size=3,
                      stride=1, padding=1, bias=bias),
            nn.LeakyReLU(),
            nn.Conv2d(int(dim) // 2, out_channels, kernel_size=3,
                      stride=1, padding=1, bias=bias),
        )
        self.sigmoid = nn.Sigmoid()

    def forward(self, inp_img, base_feature, detail_feature):
        """Decode concatenated base/detail features.

        Args:
            inp_img: Optional tensor added to the output head's result before
                the sigmoid (must match its shape); pass ``None`` to skip.
            base_feature, detail_feature: Tensors of shape (B, dim, H, W).

        Returns:
            ``(image, fused)``: ``image`` is the sigmoid output of shape
            (B, out_channels, H, W); ``fused`` is the (B, dim, H, W) feature
            after ``reduce_channel`` and before the TransformerBlocks.
        """
        out_enc_level0 = torch.cat((base_feature, detail_feature), dim=1)
        out_enc_level0 = self.reduce_channel(out_enc_level0)
        out_enc_level1 = self.encoder_level2(out_enc_level0)
        if inp_img is not None:
            out_enc_level1 = self.output(out_enc_level1) + inp_img
        else:
            out_enc_level1 = self.output(out_enc_level1)
        return self.sigmoid(out_enc_level1), out_enc_level0


if __name__ == '__main__':
    height = 128
    width = 128
    # Fall back to CPU so this smoke test also runs without a GPU.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    modelE = Restormer_Encoder().to(device)
    modelD = Restormer_Decoder().to(device)

    # One forward pass to check that encoder and decoder shapes line up.
    dummy = torch.rand(1, 1, height, width, device=device)
    with torch.no_grad():
        base_feature, detail_feature, _ = modelE(dummy)
        fused_img, _ = modelD(dummy, base_feature, detail_feature)
    print(fused_img.shape)