diff --git a/net.py b/net.py index debb154..2e2bc40 100644 --- a/net.py +++ b/net.py @@ -286,7 +286,7 @@ class DetailFeatureFusion(nn.Module): class DetailFeatureExtraction(nn.Module): def __init__(self, num_layers=3): super(DetailFeatureExtraction, self).__init__() - INNmodules = [DetailNode(useBlock=2) for _ in range(num_layers)] + INNmodules = [DetailNode(useBlock=1) for _ in range(num_layers)] self.net = nn.Sequential(*INNmodules) def forward(self, x): @@ -299,7 +299,7 @@ class DetailFeatureExtraction(nn.Module): class DetailFeatureExtractionSAR(nn.Module): def __init__(self, num_layers=3): super(DetailFeatureExtractionSAR, self).__init__() - INNmodules = [DetailNode(useBlock=1) for _ in range(num_layers)] + INNmodules = [DetailNode(useBlock=2) for _ in range(num_layers)] self.net = nn.Sequential(*INNmodules) def forward(self, x):