From ac4225c966761cbe7af66c22c1ba04786bf2e1a4 Mon Sep 17 00:00:00 2001 From: zjut Date: Sun, 17 Nov 2024 10:33:24 +0800 Subject: [PATCH] =?UTF-8?q?feat(net):=20=E4=B8=BA=20DetailNode=E6=A8=A1?= =?UTF-8?q?=E5=9D=97=E6=B7=BB=E5=8A=A0=E5=8F=AF=E9=80=89=E5=8D=B7=E7=A7=AF?= =?UTF-8?q?=E5=9D=97?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - 在 DetailNode 类中引入 useBlock 参数,用于选择不同的卷积块 - 新增 DepthwiseSeparableConvBlock 类,实现深度可分离卷积 - 根据 useBlock 的值,选择使用 DepthwiseSeparableConvBlock 或 InvertedResidualBlock - 优化了网络结构,提供了更多的灵活性和选择性 --- net.py | 32 +++++++++++++++++++++++++++----- 1 file changed, 27 insertions(+), 5 deletions(-) diff --git a/net.py b/net.py index 04160e6..8a55253 100644 --- a/net.py +++ b/net.py @@ -222,14 +222,36 @@ class InvertedResidualBlock(nn.Module): def forward(self, x): return self.bottleneckBlock(x) +class DepthwiseSeparableConvBlock(nn.Module): + def __init__(self, inp, oup, kernel_size=3, stride=1, padding=1): + super(DepthwiseSeparableConvBlock, self).__init__() + self.depthwise = nn.Conv2d(inp, inp, kernel_size, stride, padding, groups=inp, bias=False) + self.pointwise = nn.Conv2d(inp, oup, 1, bias=False) + self.bn = nn.BatchNorm2d(oup) + self.relu = nn.ReLU(inplace=True) + + def forward(self, x): + x = self.depthwise(x) + x = self.pointwise(x) + x = self.bn(x) + x = self.relu(x) + return x class DetailNode(nn.Module): - def __init__(self): + def __init__(self,useBlock=0): super(DetailNode, self).__init__() - - self.theta_phi = InvertedResidualBlock(inp=32, oup=32, expand_ratio=2) - self.theta_rho = InvertedResidualBlock(inp=32, oup=32, expand_ratio=2) - self.theta_eta = InvertedResidualBlock(inp=32, oup=32, expand_ratio=2) + if useBlock == 0: + self.theta_phi = DepthwiseSeparableConvBlock(inp=32, oup=32) + self.theta_rho = DepthwiseSeparableConvBlock(inp=32, oup=32) + self.theta_eta = DepthwiseSeparableConvBlock(inp=32, oup=32) + elif useBlock == 1: + self.theta_phi = InvertedResidualBlock(inp=32, oup=32, expand_ratio=2) + self.theta_rho = InvertedResidualBlock(inp=32, oup=32, expand_ratio=2) + self.theta_eta = InvertedResidualBlock(inp=32, oup=32, expand_ratio=2) + else: + self.theta_phi = InvertedResidualBlock(inp=32, oup=32, expand_ratio=2) + self.theta_rho = InvertedResidualBlock(inp=32, oup=32, expand_ratio=2) + self.theta_eta = InvertedResidualBlock(inp=32, oup=32, expand_ratio=2) self.shffleconv = nn.Conv2d(64, 64, kernel_size=1, stride=1, padding=0, bias=True)