refactor(net): 注释掉 DetailFeatureExtraction、DetailFeatureFusion 和 DetailFeatureExtractionSAR 类中的 enhancement_module

- 在三个类中注释掉了 enhancement_module 的定义
- 该改动可能是为了暂时禁用增强模块的功能或进行调试
This commit is contained in:
whai 2024-11-15 09:14:15 +08:00
parent 555515c2dc
commit c023c0801d

6
net.py
View File

@ -327,7 +327,7 @@ class DetailFeatureExtraction(nn.Module):
super(DetailFeatureExtraction, self).__init__() super(DetailFeatureExtraction, self).__init__()
INNmodules = [DetailNode(use) for _ in range(num_layers)] INNmodules = [DetailNode(use) for _ in range(num_layers)]
self.net = nn.Sequential(*INNmodules) self.net = nn.Sequential(*INNmodules)
self.enhancement_module = WTConv2d(32, 32) # self.enhancement_module = WTConv2d(32, 32)
def forward(self, x): # 1 64 128 128 def forward(self, x): # 1 64 128 128
z1, z2 = x[:, :x.shape[1] // 2], x[:, x.shape[1] // 2:x.shape[1]] # 1 32 128 128 z1, z2 = x[:, :x.shape[1] // 2], x[:, x.shape[1] // 2:x.shape[1]] # 1 32 128 128
@ -346,7 +346,7 @@ class DetailFeatureFusion(nn.Module):
super(DetailFeatureFusion, self).__init__() super(DetailFeatureFusion, self).__init__()
INNmodules = [DetailNode() for _ in range(num_layers)] INNmodules = [DetailNode() for _ in range(num_layers)]
self.net = nn.Sequential(*INNmodules) self.net = nn.Sequential(*INNmodules)
self.enhancement_module = WTConv2d(32, 32) # self.enhancement_module = WTConv2d(32, 32)
def forward(self, x): # 1 64 128 128 def forward(self, x): # 1 64 128 128
z1, z2 = x[:, :x.shape[1] // 2], x[:, x.shape[1] // 2:x.shape[1]] # 1 32 128 128 z1, z2 = x[:, :x.shape[1] // 2], x[:, x.shape[1] // 2:x.shape[1]] # 1 32 128 128
@ -582,7 +582,7 @@ class DetailFeatureExtractionSAR(nn.Module):
super(DetailFeatureExtractionSAR, self).__init__() super(DetailFeatureExtractionSAR, self).__init__()
INNmodules = [DetailNode(useBlock=1) for _ in range(num_layers)] INNmodules = [DetailNode(useBlock=1) for _ in range(num_layers)]
self.net = nn.Sequential(*INNmodules) self.net = nn.Sequential(*INNmodules)
self.enhancement_module = WTConv2d(32, 32) # self.enhancement_module = WTConv2d(32, 32)
def forward(self, x): # 1 64 128 128 def forward(self, x): # 1 64 128 128
z1, z2 = x[:, :x.shape[1] // 2], x[:, x.shape[1] // 2:x.shape[1]] # 1 32 128 128 z1, z2 = x[:, :x.shape[1] // 2], x[:, x.shape[1] // 2:x.shape[1]] # 1 32 128 128