From 260e3aa760bd6ce94011e5b637a0bff7ab15fcd9 Mon Sep 17 00:00:00 2001 From: zjut Date: Sun, 17 Nov 2024 15:57:41 +0800 Subject: [PATCH] =?UTF-8?q?feat(net):=20=E6=B7=BB=E5=8A=A0=20WTConv2d=20?= =?UTF-8?q?=E5=B1=82=E5=B9=B6=E4=BF=AE=E6=94=B9=20DetailNode=20=E4=BD=BF?= =?UTF-8?q?=E7=94=A8-=20=E5=9C=A8=20net.py=20=E4=B8=AD=E6=B7=BB=E5=8A=A0?= =?UTF-8?q?=E4=BA=86=20WTConv2d=20=E5=B1=82=E7=9A=84=E5=AF=BC=E5=85=A5-=20?= =?UTF-8?q?=E4=BF=AE=E6=94=B9=E4=BA=86=20DetailNode=20=E7=B1=BB=E7=9A=84?= =?UTF-8?q?=E6=9E=84=E9=80=A0=E5=87=BD=E6=95=B0=EF=BC=8C=E5=A2=9E=E5=8A=A0?= =?UTF-8?q?=E4=BA=86=20useBlock=20=E5=8F=82=E6=95=B0=20-=20=E6=A0=B9?= =?UTF-8?q?=E6=8D=AE=20useBlock=20=E5=8F=82=E6=95=B0=E7=9A=84=E5=80=BC?= =?UTF-8?q?=EF=BC=8C=E9=80=89=E6=8B=A9=E4=BD=BF=E7=94=A8=20WTConv2d?= =?UTF-8?q?=E5=B1=82=E6=88=96=20InvertedResidualBlock-=20=E6=9B=B4?= =?UTF-8?q?=E6=96=B0=E4=BA=86=20DetailFeatureFusion=20=E5=92=8C=20DetailFe?= =?UTF-8?q?atureExtraction=20=E7=B1=BB=EF=BC=8C=E6=8C=87=E5=AE=9A=E4=BA=86?= =?UTF-8?q?=20DetailNode=20=E7=9A=84=20useBlock=20=E5=8F=82=E6=95=B0?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- net.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/net.py b/net.py index 8a55253..e5d2614 100644 --- a/net.py +++ b/net.py @@ -7,6 +7,7 @@ from timm.models.layers import DropPath, to_2tuple, trunc_normal_ from einops import rearrange from componets.SCSA import SCSA +from componets.WTConvCV2 import WTConv2d def drop_path(x, drop_prob: float = 0., training: bool = False): @@ -248,6 +249,10 @@ class DetailNode(nn.Module): self.theta_phi = InvertedResidualBlock(inp=32, oup=32, expand_ratio=2) self.theta_rho = InvertedResidualBlock(inp=32, oup=32, expand_ratio=2) self.theta_eta = InvertedResidualBlock(inp=32, oup=32, expand_ratio=2) + elif useBlock == 2: + self.theta_phi = WTConv2d(in_channels=32, out_channels=32) + self.theta_rho = WTConv2d(in_channels=32, out_channels=32) + self.theta_eta = WTConv2d(in_channels=32, out_channels=32) else: self.theta_phi = InvertedResidualBlock(inp=32, oup=32, expand_ratio=2) self.theta_rho = InvertedResidualBlock(inp=32, oup=32, expand_ratio=2) @@ -269,7 +274,7 @@ class DetailNode(nn.Module): class DetailFeatureFusion(nn.Module): def __init__(self, num_layers=3): super(DetailFeatureFusion, self).__init__() - INNmodules = [DetailNode() for _ in range(num_layers)] + INNmodules = [DetailNode(useBlock=1) for _ in range(num_layers)] self.net = nn.Sequential(*INNmodules) def forward(self, x): @@ -281,7 +286,7 @@ class DetailFeatureFusion(nn.Module): class DetailFeatureExtraction(nn.Module): def __init__(self, num_layers=3): super(DetailFeatureExtraction, self).__init__() - INNmodules = [DetailNode() for _ in range(num_layers)] + INNmodules = [DetailNode(useBlock=0) for _ in range(num_layers)] self.net = nn.Sequential(*INNmodules) def forward(self, x):