增加了INN部分的残差连接模块,修改了训练和测试代码以提高代码的可读性和可维护性。- 在train.py中添加了打印所有参数的代码,以方便检查和记录

This commit is contained in:
whaifree 2024-10-06 16:42:42 +08:00
parent faacea007c
commit 7068b627c4
3 changed files with 29 additions and 3 deletions

6
net.py
View File

@ -248,7 +248,11 @@ class DetailFeatureExtraction(nn.Module):
super(DetailFeatureExtraction, self).__init__() super(DetailFeatureExtraction, self).__init__()
INNmodules = [DetailNode() for _ in range(num_layers)] INNmodules = [DetailNode() for _ in range(num_layers)]
self.net = nn.Sequential(*INNmodules) self.net = nn.Sequential(*INNmodules)
self.enhancement_module = nn.Sequential(
nn.Conv2d(32, 32, kernel_size=3, padding=1, bias=True),
nn.ReLU(inplace=True),
nn.Conv2d(32, 32, kernel_size=3, padding=1, bias=True),
)
def forward(self, x): # 1 64 128 128 def forward(self, x): # 1 64 128 128
z1, z2 = x[:, :x.shape[1] // 2], x[:, x.shape[1] // 2:x.shape[1]] # 1 32 128 128 z1, z2 = x[:, :x.shape[1] // 2], x[:, x.shape[1] // 2:x.shape[1]] # 1 32 128 128

View File

@ -13,7 +13,7 @@ logging.basicConfig(level=logging.CRITICAL)
os.environ["CUDA_VISIBLE_DEVICES"] = "0" os.environ["CUDA_VISIBLE_DEVICES"] = "0"
ckpt_path= r"/home/star/whaiDir/PFCFuse/PFCFuse_IVF.pth" ckpt_path= r"/home/star/whaiDir/PFCFuse/models/PFCFusion10-05-20-46.pth"
for dataset_name in ["TNO"]: for dataset_name in ["TNO"]:
print("\n"*2+"="*80) print("\n"*2+"="*80)

View File

@ -34,7 +34,7 @@ criteria_fusion = Fusionloss()
model_str = 'PFCFuse' model_str = 'PFCFuse'
# . Set the hyper-parameters for training # . Set the hyper-parameters for training
num_epochs = 120 # total epoch num_epochs = 60 # total epoch
epoch_gap = 40 # epoches of Phase I epoch_gap = 40 # epoches of Phase I
lr = 1e-4 lr = 1e-4
@ -57,6 +57,28 @@ clip_grad_norm_value = 0.01
optim_step = 20 optim_step = 20
optim_gamma = 0.5 optim_gamma = 0.5
# 打印所有参数
print(f"Model: {model_str}")
print(f"Number of epochs: {num_epochs}")
print(f"Epoch gap: {epoch_gap}")
print(f"Learning rate: {lr}")
print(f"Weight decay: {weight_decay}")
print(f"Batch size: {batch_size}")
print(f"GPU number: {GPU_number}")
print(f"Coefficient of MSE loss VF: {coeff_mse_loss_VF}")
print(f"Coefficient of MSE loss IF: {coeff_mse_loss_IF}")
print(f"Coefficient of RMI loss VF: {coeff_rmi_loss_VF}")
print(f"Coefficient of RMI loss IF: {coeff_rmi_loss_IF}")
print(f"Coefficient of Cosine loss VF: {coeff_cos_loss_VF}")
print(f"Coefficient of Cosine loss IF: {coeff_cos_loss_IF}")
print(f"Coefficient of Decomposition loss: {coeff_decomp}")
print(f"Coefficient of Total Variation loss: {coeff_tv}")
print(f"Clip gradient norm value: {clip_grad_norm_value}")
print(f"Optimization step: {optim_step}")
print(f"Optimization gamma: {optim_gamma}")
# Model # Model
device = 'cuda' if torch.cuda.is_available() else 'cpu' device = 'cuda' if torch.cuda.is_available() else 'cpu'