pfcfuse/train.py

266 lines
9.6 KiB
Python
Raw Permalink Normal View History

2024-06-03 19:36:29 +08:00
# -*- coding: utf-8 -*-
'''
------------------------------------------------------------------------------
Import packages
------------------------------------------------------------------------------
'''
from net import Restormer_Encoder, Restormer_Decoder, BaseFeatureExtraction, DetailFeatureExtraction, BaseFeatureFusion, \
DetailFeatureFusion
2024-06-03 19:36:29 +08:00
from utils.dataset import H5Dataset
import os
os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'
import sys
import time
import datetime
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from utils.loss import Fusionloss, cc, relative_diff_loss
import kornia
print(torch.__version__)
print(torch.cuda.is_available())
'''
------------------------------------------------------------------------------
Configure our network
------------------------------------------------------------------------------
'''
# Pin the run to GPU 0, then build the fusion criterion.
# NOTE(review): this assignment executes after `import torch` and after the
# `torch.cuda.is_available()` probe near the imports -- confirm the device
# mask is still honoured in the target environment.
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
criteria_fusion = Fusionloss()
model_str = 'PFCFuse'

# Hyper-parameters for training.
num_epochs = 60   # total number of epochs
epoch_gap = 40    # number of epochs spent in Phase I
lr = 1e-4
weight_decay = 0
batch_size = 1
GPU_number = os.environ['CUDA_VISIBLE_DEVICES']

# Coefficients of the loss function.
coeff_mse_loss_VF = coeff_mse_loss_IF = 1.   # alpha1
coeff_rmi_loss_VF = coeff_rmi_loss_IF = 1.
coeff_cos_loss_VF = coeff_cos_loss_IF = 1.
coeff_decomp = 2.                            # alpha2 and alpha4
coeff_tv = 5.

# Gradient clipping threshold and StepLR schedule settings.
clip_grad_norm_value = 0.01
optim_step = 20
optim_gamma = 0.5

# Print every setting, one "label: value" line each.
print("\n".join(
    f"{label}: {value}"
    for label, value in (
        ("Model", model_str),
        ("Number of epochs", num_epochs),
        ("Epoch gap", epoch_gap),
        ("Learning rate", lr),
        ("Weight decay", weight_decay),
        ("Batch size", batch_size),
        ("GPU number", GPU_number),
        ("Coefficient of MSE loss VF", coeff_mse_loss_VF),
        ("Coefficient of MSE loss IF", coeff_mse_loss_IF),
        ("Coefficient of RMI loss VF", coeff_rmi_loss_VF),
        ("Coefficient of RMI loss IF", coeff_rmi_loss_IF),
        ("Coefficient of Cosine loss VF", coeff_cos_loss_VF),
        ("Coefficient of Cosine loss IF", coeff_cos_loss_IF),
        ("Coefficient of Decomposition loss", coeff_decomp),
        ("Coefficient of Total Variation loss", coeff_tv),
        ("Clip gradient norm value", clip_grad_norm_value),
        ("Optimization step", optim_step),
        ("Optimization gamma", optim_gamma),
    )
))
2024-06-03 19:36:29 +08:00
# Networks: pick the device once, then wrap each sub-network in DataParallel
# and move it there. Construction order is encoder, decoder, base fusion,
# detail fusion.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

DIDF_Encoder = nn.DataParallel(Restormer_Encoder())
DIDF_Encoder = DIDF_Encoder.to(device)

DIDF_Decoder = nn.DataParallel(Restormer_Decoder())
DIDF_Decoder = DIDF_Decoder.to(device)

# An earlier variant used BaseFeatureExtraction(dim=64, num_heads=8) as the
# base fusion layer; BaseFeatureFusion(dim=64) is used now.
BaseFuseLayer = nn.DataParallel(BaseFeatureFusion(dim=64))
BaseFuseLayer = BaseFuseLayer.to(device)

DetailFuseLayer = nn.DataParallel(DetailFeatureFusion(num_layers=1))
DetailFuseLayer = DetailFuseLayer.to(device)
2024-06-03 19:36:29 +08:00
# Optimizers: one Adam instance per sub-network, all sharing the same
# learning rate and weight decay.
optimizer1, optimizer2, optimizer3, optimizer4 = (
    torch.optim.Adam(net.parameters(), lr=lr, weight_decay=weight_decay)
    for net in (DIDF_Encoder, DIDF_Decoder, BaseFuseLayer, DetailFuseLayer)
)

# Schedulers: one StepLR per optimizer, configured with optim_step as the
# step size and optim_gamma as the decay factor.
scheduler1, scheduler2, scheduler3, scheduler4 = (
    torch.optim.lr_scheduler.StepLR(opt, step_size=optim_step, gamma=optim_gamma)
    for opt in (optimizer1, optimizer2, optimizer3, optimizer4)
)

# Loss modules used by the training loop.
# NOTE(review): `kornia.losses.SSIM` may be exposed under a different name
# (e.g. SSIMLoss) in newer kornia releases -- confirm against the pinned version.
L1Loss = nn.L1Loss()
MSELoss = nn.MSELoss()
HuberLoss = nn.HuberLoss()
Loss_ssim = kornia.losses.SSIM(11, reduction='mean')
# Data loader: shuffled batches read from the H5 training file, loaded in the
# main process (num_workers=0).
trainloader = DataLoader(
    dataset=H5Dataset(r"/home/star/whaiDir/PFCFuse/data/YYX_sar_opr_data_imgsize_256_stride_200.h5"),
    batch_size=batch_size,
    shuffle=True,
    num_workers=0,
)
loader = dict(train=trainloader)

# Run identifier (month-day-hour-minute); it becomes part of the checkpoint
# file name.
timestamp = f"{datetime.datetime.now():%m-%d-%H-%M}"
'''
------------------------------------------------------------------------------
Train
------------------------------------------------------------------------------
'''
# Two-phase training loop followed by a single checkpoint save.
#
#   Phase I  (epoch <  epoch_gap): encoder + decoder are trained to
#            reconstruct each input from its own base/detail features.
#   Phase II (epoch >= epoch_gap): the fusion layers are trained as well and
#            the decoder output is scored with `criteria_fusion`.
#
# NOTE(review): the nesting below was reconstructed from statement order
# (the original indentation was not available): progress logging runs once
# per batch, scheduler updates once per epoch, and the checkpoint is written
# once after the final epoch -- confirm against the intended layout.
step = 0
torch.backends.cudnn.benchmark = True
prev_time = time.time()
start = time.time()

# Loop invariants, computed once instead of on every batch.
batches_per_epoch = len(loader['train'])
total_batches = num_epochs * batches_per_epoch
spatial_gradient = kornia.filters.SpatialGradient()

all_networks = (DIDF_Encoder, DIDF_Decoder, BaseFuseLayer, DetailFuseLayer)
all_optimizers = (optimizer1, optimizer2, optimizer3, optimizer4)

for epoch in range(num_epochs):
    phase_one = epoch < epoch_gap

    # Networks / optimizers that are clipped and stepped in this phase.
    if phase_one:
        active_networks = (DIDF_Encoder, DIDF_Decoder)
        active_optimizers = (optimizer1, optimizer2)
    else:
        active_networks = all_networks
        active_optimizers = all_optimizers

    # Nothing switches the networks to eval mode inside this loop, so setting
    # train mode once per epoch is equivalent to doing it on every batch.
    for net in all_networks:
        net.train()

    ''' train '''
    for i, (data_VIS, data_IR) in enumerate(loader['train']):
        # Follow the `device` chosen when the models were built rather than
        # hard-coding CUDA, so the CPU fallback does not crash here.
        data_VIS, data_IR = data_VIS.to(device), data_IR.to(device)

        # Each optimizer owns exactly one network's parameters, so clearing
        # gradients through the optimizers covers every network.
        for opt in all_optimizers:
            opt.zero_grad()

        if phase_one:  # Phase I
            feature_V_B, feature_V_D, _ = DIDF_Encoder(data_VIS)
            feature_I_B, feature_I_D, _ = DIDF_Encoder(data_IR, sar_img=True)
            data_VIS_hat, _ = DIDF_Decoder(data_VIS, feature_V_B, feature_V_D)
            data_IR_hat, _ = DIDF_Decoder(data_IR, feature_I_B, feature_I_D)

            cc_loss_B = cc(feature_V_B, feature_I_B)
            cc_loss_D = cc(feature_V_D, feature_I_D)

            # Reconstruction terms: weighted SSIM loss plus Huber loss.
            mse_loss_V = 5 * Loss_ssim(data_VIS, data_VIS_hat) + HuberLoss(data_VIS, data_VIS_hat)
            mse_loss_I = 5 * Loss_ssim(data_IR, data_IR_hat) + HuberLoss(data_IR, data_IR_hat)

            # L1 distance between the spatial gradients of the first input
            # and of its reconstruction.
            Gradient_loss = L1Loss(spatial_gradient(data_VIS),
                                   spatial_gradient(data_VIS_hat))

            # Decomposition term; the 1.01 offset is part of the original
            # formula and is kept as-is.
            loss_decomp = (cc_loss_D) ** 2 / (1.01 + cc_loss_B)

            loss_rmi_v = relative_diff_loss(data_VIS, data_VIS_hat)
            loss_rmi_i = relative_diff_loss(data_IR, data_IR_hat)

            loss = coeff_mse_loss_VF * mse_loss_V + coeff_mse_loss_IF * \
                mse_loss_I + coeff_decomp * loss_decomp + coeff_tv * Gradient_loss + \
                coeff_rmi_loss_IF * loss_rmi_i + coeff_rmi_loss_VF * loss_rmi_v
        else:  # Phase II
            feature_V_B, feature_V_D, feature_V = DIDF_Encoder(data_VIS)
            # NOTE(review): Phase I encodes data_IR with sar_img=True but this
            # call does not -- confirm whether the difference is intentional.
            feature_I_B, feature_I_D, feature_I = DIDF_Encoder(data_IR)

            feature_F_B = BaseFuseLayer(feature_I_B + feature_V_B)
            feature_F_D = DetailFuseLayer(feature_I_D + feature_V_D)
            data_Fuse, feature_F = DIDF_Decoder(data_VIS, feature_F_B, feature_F_D)

            # NOTE(review): SSIM + Huber terms of data_Fuse against each input
            # used to be computed here but were never added to the loss; they
            # are omitted because they only cost forward time.
            cc_loss_B = cc(feature_V_B, feature_I_B)
            cc_loss_D = cc(feature_V_D, feature_I_D)
            loss_decomp = (cc_loss_D) ** 2 / (1.01 + cc_loss_B)

            fusionloss, _, _ = criteria_fusion(data_VIS, data_IR, data_Fuse)
            loss = fusionloss + coeff_decomp * loss_decomp

        loss.backward()

        # Clip every active network first, then step its optimizer.
        for net in active_networks:
            nn.utils.clip_grad_norm_(
                net.parameters(), max_norm=clip_grad_norm_value, norm_type=2)
        for opt in active_optimizers:
            opt.step()

        # Determine approximate time left, extrapolating from the duration of
        # the batch that just finished.
        batches_done = epoch * batches_per_epoch + i
        batches_left = total_batches - batches_done
        now = time.time()
        time_left = datetime.timedelta(seconds=batches_left * (now - prev_time))
        prev_time = now
        sys.stdout.write(
            "\r[Epoch %d/%d] [Batch %d/%d] [loss: %f] ETA: %.10s"
            % (
                epoch,
                num_epochs,
                i,
                batches_per_epoch,
                loss.item(),
                time_left,
            )
        )

    # Adjust the learning rate; the fusion-layer schedulers only advance once
    # Phase II has started.
    scheduler1.step()
    scheduler2.step()
    if not phase_one:
        scheduler3.step()
        scheduler4.step()

    # Keep the learning rate of each optimizer's first parameter group from
    # dropping below 1e-6.
    for opt in all_optimizers:
        first_group = opt.param_groups[0]
        if first_group['lr'] <= 1e-6:
            first_group['lr'] = 1e-6

# Save the weights of all four networks in one checkpoint file.
checkpoint = {
    'DIDF_Encoder': DIDF_Encoder.state_dict(),
    'DIDF_Decoder': DIDF_Decoder.state_dict(),
    'BaseFuseLayer': BaseFuseLayer.state_dict(),
    'DetailFuseLayer': DetailFuseLayer.state_dict(),
}
savepth = "models/whaiFusion" + timestamp + '.pth'
# Create the output directory first so a missing "models/" folder cannot
# throw away the finished training run at save time.
os.makedirs(os.path.dirname(savepth), exist_ok=True)
torch.save(checkpoint, savepth)
print("save model:{}".format(savepth))
2024-06-03 19:36:29 +08:00