diff --git a/net.py b/net.py index 04160e6..8a55253 100644 --- a/net.py +++ b/net.py @@ -222,14 +222,36 @@ class InvertedResidualBlock(nn.Module): def forward(self, x): return self.bottleneckBlock(x) +class DepthwiseSeparableConvBlock(nn.Module): + def __init__(self, inp, oup, kernel_size=3, stride=1, padding=1): + super(DepthwiseSeparableConvBlock, self).__init__() + self.depthwise = nn.Conv2d(inp, inp, kernel_size, stride, padding, groups=inp, bias=False) + self.pointwise = nn.Conv2d(inp, oup, 1, bias=False) + self.bn = nn.BatchNorm2d(oup) + self.relu = nn.ReLU(inplace=True) + + def forward(self, x): + x = self.depthwise(x) + x = self.pointwise(x) + x = self.bn(x) + x = self.relu(x) + return x class DetailNode(nn.Module): - def __init__(self): + def __init__(self,useBlock=0): super(DetailNode, self).__init__() - - self.theta_phi = InvertedResidualBlock(inp=32, oup=32, expand_ratio=2) - self.theta_rho = InvertedResidualBlock(inp=32, oup=32, expand_ratio=2) - self.theta_eta = InvertedResidualBlock(inp=32, oup=32, expand_ratio=2) + if useBlock == 0: + self.theta_phi = DepthwiseSeparableConvBlock(inp=32, oup=32) + self.theta_rho = DepthwiseSeparableConvBlock(inp=32, oup=32) + self.theta_eta = DepthwiseSeparableConvBlock(inp=32, oup=32) + elif useBlock == 1: + self.theta_phi = InvertedResidualBlock(inp=32, oup=32, expand_ratio=2) + self.theta_rho = InvertedResidualBlock(inp=32, oup=32, expand_ratio=2) + self.theta_eta = InvertedResidualBlock(inp=32, oup=32, expand_ratio=2) + else: + self.theta_phi = InvertedResidualBlock(inp=32, oup=32, expand_ratio=2) + self.theta_rho = InvertedResidualBlock(inp=32, oup=32, expand_ratio=2) + self.theta_eta = InvertedResidualBlock(inp=32, oup=32, expand_ratio=2) self.shffleconv = nn.Conv2d(64, 64, kernel_size=1, stride=1, padding=0, bias=True)