From 7068b627c42b72a30adc658e8804ed63b6720bf9 Mon Sep 17 00:00:00 2001
From: whaifree
Date: Sun, 6 Oct 2024 16:42:42 +0800
Subject: [PATCH] Add a residual connection (enhancement) module to the INN
 part, and modify the training and test code to improve readability and
 maintainability. - Add code in train.py that prints all parameters, for
 easier inspection and logging.
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 net.py      |  6 +++++-
 test_IVF.py |  2 +-
 train.py    | 24 +++++++++++++++++++++++-
 3 files changed, 29 insertions(+), 3 deletions(-)

diff --git a/net.py b/net.py
index 18b2e17..d9fb694 100644
--- a/net.py
+++ b/net.py
@@ -248,7 +248,11 @@ class DetailFeatureExtraction(nn.Module):
         super(DetailFeatureExtraction, self).__init__()
         INNmodules = [DetailNode() for _ in range(num_layers)]
         self.net = nn.Sequential(*INNmodules)
-
+        self.enhancement_module = nn.Sequential(
+            nn.Conv2d(32, 32, kernel_size=3, padding=1, bias=True),
+            nn.ReLU(inplace=True),
+            nn.Conv2d(32, 32, kernel_size=3, padding=1, bias=True),
+        )
     def forward(self, x):  # 1 64 128 128
         z1, z2 = x[:, :x.shape[1] // 2], x[:, x.shape[1] // 2:x.shape[1]]  # 1 32 128 128

diff --git a/test_IVF.py b/test_IVF.py
index 831df93..cc0012d 100644
--- a/test_IVF.py
+++ b/test_IVF.py
@@ -13,7 +13,7 @@ logging.basicConfig(level=logging.CRITICAL)
 os.environ["CUDA_VISIBLE_DEVICES"] = "0"
 
-ckpt_path= r"/home/star/whaiDir/PFCFuse/PFCFuse_IVF.pth"
+ckpt_path= r"/home/star/whaiDir/PFCFuse/models/PFCFusion10-05-20-46.pth"
 
 for dataset_name in ["TNO"]:
     print("\n"*2+"="*80)

diff --git a/train.py b/train.py
index 0126d8f..205195e 100644
--- a/train.py
+++ b/train.py
@@ -34,7 +34,7 @@ criteria_fusion = Fusionloss()
 model_str = 'PFCFuse'
 
 # . Set the hyper-parameters for training
-num_epochs = 120  # total epoch
+num_epochs = 60  # total epoch
 epoch_gap = 40  # epoches of Phase I
 
 lr = 1e-4
@@ -57,6 +57,28 @@ clip_grad_norm_value = 0.01
 optim_step = 20
 optim_gamma = 0.5
 
+# Print all parameters
+print(f"Model: {model_str}")
+print(f"Number of epochs: {num_epochs}")
+print(f"Epoch gap: {epoch_gap}")
+print(f"Learning rate: {lr}")
+print(f"Weight decay: {weight_decay}")
+print(f"Batch size: {batch_size}")
+print(f"GPU number: {GPU_number}")
+
+print(f"Coefficient of MSE loss VF: {coeff_mse_loss_VF}")
+print(f"Coefficient of MSE loss IF: {coeff_mse_loss_IF}")
+print(f"Coefficient of RMI loss VF: {coeff_rmi_loss_VF}")
+print(f"Coefficient of RMI loss IF: {coeff_rmi_loss_IF}")
+print(f"Coefficient of Cosine loss VF: {coeff_cos_loss_VF}")
+print(f"Coefficient of Cosine loss IF: {coeff_cos_loss_IF}")
+print(f"Coefficient of Decomposition loss: {coeff_decomp}")
+print(f"Coefficient of Total Variation loss: {coeff_tv}")
+
+print(f"Clip gradient norm value: {clip_grad_norm_value}")
+print(f"Optimization step: {optim_step}")
+print(f"Optimization gamma: {optim_gamma}")
+
 # Model
 device = 'cuda' if torch.cuda.is_available() else 'cpu'
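
Note on the net.py change: the hunk above only defines self.enhancement_module inside
DetailFeatureExtraction.__init__; its use in forward() is not visible in this hunk. The
following is a minimal, self-contained sketch of how such a module is typically wired as
a residual (skip) connection on one 32-channel detail branch. The wiring, tensor shape,
and variable names are illustrative assumptions, not the repository's actual forward pass.

    import torch
    import torch.nn as nn

    # Same layer stack as the enhancement_module added in net.py:
    # two 3x3 convolutions with a ReLU in between, channel count preserved (32 -> 32).
    enhancement_module = nn.Sequential(
        nn.Conv2d(32, 32, kernel_size=3, padding=1, bias=True),
        nn.ReLU(inplace=True),
        nn.Conv2d(32, 32, kernel_size=3, padding=1, bias=True),
    )

    # Hypothetical 32-channel detail branch, matching the "1 32 128 128" shape comment
    # in DetailFeatureExtraction.forward after the channel split.
    z = torch.randn(1, 32, 128, 128)

    # Residual connection: refine the features and add them back to the input,
    # so the module learns a correction on top of the identity mapping.
    z_refined = z + enhancement_module(z)
    print(z_refined.shape)  # torch.Size([1, 32, 128, 128])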