diff --git a/net.py b/net.py index 8a55253..e5d2614 100644 --- a/net.py +++ b/net.py @@ -7,6 +7,7 @@ from timm.models.layers import DropPath, to_2tuple, trunc_normal_ from einops import rearrange from componets.SCSA import SCSA +from componets.WTConvCV2 import WTConv2d def drop_path(x, drop_prob: float = 0., training: bool = False): @@ -248,6 +249,10 @@ class DetailNode(nn.Module): self.theta_phi = InvertedResidualBlock(inp=32, oup=32, expand_ratio=2) self.theta_rho = InvertedResidualBlock(inp=32, oup=32, expand_ratio=2) self.theta_eta = InvertedResidualBlock(inp=32, oup=32, expand_ratio=2) + elif useBlock == 2: + self.theta_phi = WTConv2d(in_channels=32, out_channels=32) + self.theta_rho = WTConv2d(in_channels=32, out_channels=32) + self.theta_eta = WTConv2d(in_channels=32, out_channels=32) else: self.theta_phi = InvertedResidualBlock(inp=32, oup=32, expand_ratio=2) self.theta_rho = InvertedResidualBlock(inp=32, oup=32, expand_ratio=2) @@ -269,7 +274,7 @@ class DetailNode(nn.Module): class DetailFeatureFusion(nn.Module): def __init__(self, num_layers=3): super(DetailFeatureFusion, self).__init__() - INNmodules = [DetailNode() for _ in range(num_layers)] + INNmodules = [DetailNode(useBlock=1) for _ in range(num_layers)] self.net = nn.Sequential(*INNmodules) def forward(self, x): @@ -281,7 +286,7 @@ class DetailFeatureFusion(nn.Module): class DetailFeatureExtraction(nn.Module): def __init__(self, num_layers=3): super(DetailFeatureExtraction, self).__init__() - INNmodules = [DetailNode() for _ in range(num_layers)] + INNmodules = [DetailNode(useBlock=0) for _ in range(num_layers)] self.net = nn.Sequential(*INNmodules) def forward(self, x):