Compare commits
10 Commits
3c78c4d873
...
d32424f54a
Author | SHA1 | Date | |
---|---|---|---|
|
d32424f54a | ||
|
936e5d4ef1 | ||
|
2759603d09 | ||
|
655a8ef9f0 | ||
|
2471e747c4 | ||
|
3fe9a38165 | ||
|
0c2350bc2c | ||
|
fe68253b98 | ||
|
9ea7476527 | ||
|
b4a336be9b |
1
MSRS_train/readme.md
Normal file
1
MSRS_train/readme.md
Normal file
@ -0,0 +1 @@
|
||||
Download the MSRS dataset from [this link](https://github.com/Linfeng-Tang/MSRS) and place it here.
|
86
README.md
86
README.md
@ -9,7 +9,7 @@ Codes for ***CDDFuse: Correlation-Driven Dual-Branch Feature Decomposition for M
|
||||
|
||||
|
||||
## Update
|
||||
- [2023/5] Training codes and config files will be public available before June.
|
||||
- [2023/6] Training codes and config files are public available.
|
||||
- [2023/4] Release inference code for infrared-visible image fusion and medical image fusion.
|
||||
|
||||
|
||||
@ -30,21 +30,66 @@ Codes for ***CDDFuse: Correlation-Driven Dual-Branch Feature Decomposition for M
|
||||
|
||||
Multi-modality (MM) image fusion aims to render fused images that maintain the merits of different modalities, e.g., functional highlight and detailed textures. To tackle the challenge in modeling cross-modality features and decomposing desirable modality-specific and modality-shared features, we propose a novel Correlation-Driven feature Decomposition Fusion (CDDFuse) network. Firstly, CDDFuse uses Restormer blocks to extract cross-modality shallow features. We then introduce a dual-branch Transformer-CNN feature extractor with Lite Transformer (LT) blocks leveraging long-range attention to handle low-frequency global features and Invertible Neural Networks (INN) blocks focusing on extracting high-frequency local information. A correlation-driven loss is further proposed to make the low-frequency features correlated while the high-frequency features uncorrelated based on the embedded information. Then, the LT-based global fusion and INN-based local fusion layers output the fused image. Extensive experiments demonstrate that our CDDFuse achieves promising results in multiple fusion tasks, including infrared-visible image fusion and medical image fusion. We also show that CDDFuse can boost the performance in downstream infrared-visible semantic segmentation and object detection in a unified benchmark.
|
||||
|
||||
## Usage
|
||||
## 🌐 Usage
|
||||
|
||||
### Network Architecture
|
||||
### ⚙ Network Architecture
|
||||
|
||||
Our CDDFuse is implemented in ``net.py``.
|
||||
|
||||
### Testing
|
||||
### 🏊 Training
|
||||
**1. Virtual Environment**
|
||||
```
|
||||
# create virtual environment
|
||||
conda create -n cddfuse python=3.8.10
|
||||
conda activate cddfuse
|
||||
# select pytorch version yourself
|
||||
# install cddfuse requirements
|
||||
pip install -r requirements.txt
|
||||
```
|
||||
|
||||
**2. Data Preparation**
|
||||
|
||||
Download the MSRS dataset from [this link](https://github.com/Linfeng-Tang/MSRS) and place it in the folder ``'./MSRS_train/'``.
|
||||
|
||||
**3. Pre-Processing**
|
||||
|
||||
Run
|
||||
```
|
||||
python dataprocessing.py
|
||||
```
|
||||
and the processed training dataset is in ``'./data/MSRS_train_imgsize_128_stride_200.h5'``.
|
||||
|
||||
**4. CDDFuse Training**
|
||||
|
||||
Run
|
||||
```
|
||||
python train.py
|
||||
```
|
||||
and the trained model is available in ``'./models/'``.
|
||||
|
||||
### 🏄 Testing
|
||||
|
||||
**1. Pretrained models**
|
||||
|
||||
Pretrained models are available in ``'./models/CDDFuse_IVF.pth'`` and ``'./models/CDDFuse_MIF.pth'``, which are responsible for the Infrared-Visible Fusion (IVF) and Medical Image Fusion (MIF) tasks, respectively.
|
||||
|
||||
**2. Test datasets**
|
||||
|
||||
The test datasets used in the paper have been stored in ``'./test_img/RoadScene'``, ``'./test_img/TNO'`` for IVF, ``'./test_img/MRI_CT'``, ``'./test_img/MRI_PET'`` and ``'./test_img/MRI_SPECT'`` for MIF.
|
||||
|
||||
Unfortunately, since the size of **MSRS dataset** for IVF is 500+MB, we can not upload it for exhibition. It can be downloaded via [this link](https://github.com/Linfeng-Tang/MSRS). The other datasets contain all the test images.
|
||||
|
||||
If you want to infer with our CDDFuse and obtain the fusion results in our paper, please run ``'test_IVF.py'`` for IVF and ``'test_MIF.py'`` for MIF.
|
||||
**3. Results in Our Paper**
|
||||
|
||||
If you want to infer with our CDDFuse and obtain the fusion results in our paper, please run
|
||||
```
|
||||
python test_IVF.py
|
||||
```
|
||||
for Infrared-Visible Fusion and
|
||||
```
|
||||
python test_MIF.py
|
||||
```
|
||||
for Medical Image Fusion.
|
||||
|
||||
The testing results will be printed in the terminal.
|
||||
|
||||
@ -91,32 +136,7 @@ CDDFuse_MIF 3.9 58.31 20.87 2.49 1.35 0.97 0.78 1.48
|
||||
```
|
||||
which can match the results in Table 5 in our original paper.
|
||||
|
||||
### Training
|
||||
**1. Virtual Environment**
|
||||
```
|
||||
# create virtual environment
|
||||
conda create -n cddfuse python=3.8.10
|
||||
conda activate cddfuse
|
||||
# select pytorch version yourself
|
||||
# install cddfuse requirements
|
||||
pip install -r requirements.txt
|
||||
```
|
||||
|
||||
**2. Data Preparation**
|
||||
|
||||
Download the MSRS dataset from [this link](https://github.com/Linfeng-Tang/MSRS) and place it in the folder ``'./MSRS_train``.
|
||||
|
||||
**3. Pre-Processing**
|
||||
|
||||
Run
|
||||
```python dataprocessing.py``` and the processed training dataset is in ``'./data/MSRS_train_imgsize_128_stride_200.h5``.
|
||||
|
||||
**4. CDDFuse Training**
|
||||
|
||||
Run ```python train.py``` and the trained model is available in ``'./models/'``.
|
||||
|
||||
|
||||
## CDDFuse
|
||||
## 🙌 CDDFuse
|
||||
|
||||
### Illustration of our CDDFuse model.
|
||||
|
||||
@ -149,12 +169,12 @@ MM segmentation
|
||||
<img src="image//MMSeg.png" width="60%" align=center />
|
||||
|
||||
|
||||
## Related Work
|
||||
## 📖 Related Work
|
||||
|
||||
- Zixiang Zhao, Haowen Bai, Jiangshe Zhang, Yulun Zhang, Kai Zhang, Shuang Xu, Dongdong Chen, Radu Timofte, Luc Van Gool. *Equivariant Multi-Modality Image Fusion.* **arXiv:2305.11443**, https://arxiv.org/abs/2305.11443
|
||||
|
||||
- Zixiang Zhao, Haowen Bai, Yuanzhi Zhu, Jiangshe Zhang, Shuang Xu, Yulun Zhang, Kai Zhang, Deyu Meng, Radu Timofte, Luc Van Gool.
|
||||
*DDFM: Denoising Diffusion Model for Multi-Modality Image Fusion.* **arXiv:2303.06840**, https://arxiv.org/abs/2303.06840
|
||||
*DDFM: Denoising Diffusion Model for Multi-Modality Image Fusion.* **ICCV 2023**, https://arxiv.org/abs/2303.06840
|
||||
|
||||
- Zixiang Zhao, Shuang Xu, Chunxia Zhang, Junmin Liu, Jiangshe Zhang and Pengfei Li. *DIDFuse: Deep Image Decomposition for Infrared and Visible Image Fusion.* **IJCAI 2020**, https://www.ijcai.org/Proceedings/2020/135.
|
||||
|
||||
|
93
dataprocessing.py
Normal file
93
dataprocessing.py
Normal file
@ -0,0 +1,93 @@
|
||||
import os
|
||||
import h5py
|
||||
import numpy as np
|
||||
from tqdm import tqdm
|
||||
from skimage.io import imread
|
||||
|
||||
|
||||
def get_img_file(file_name):
|
||||
imagelist = []
|
||||
for parent, dirnames, filenames in os.walk(file_name):
|
||||
for filename in filenames:
|
||||
if filename.lower().endswith(('.bmp', '.dib', '.png', '.jpg', '.jpeg', '.pbm', '.pgm', '.ppm', '.tif', '.tiff', '.npy')):
|
||||
imagelist.append(os.path.join(parent, filename))
|
||||
return imagelist
|
||||
|
||||
def rgb2y(img):
|
||||
y = img[0:1, :, :] * 0.299000 + img[1:2, :, :] * 0.587000 + img[2:3, :, :] * 0.114000
|
||||
return y
|
||||
|
||||
def Im2Patch(img, win, stride=1):
|
||||
k = 0
|
||||
endc = img.shape[0]
|
||||
endw = img.shape[1]
|
||||
endh = img.shape[2]
|
||||
patch = img[:, 0:endw-win+0+1:stride, 0:endh-win+0+1:stride]
|
||||
TotalPatNum = patch.shape[1] * patch.shape[2]
|
||||
Y = np.zeros([endc, win*win,TotalPatNum], np.float32)
|
||||
for i in range(win):
|
||||
for j in range(win):
|
||||
patch = img[:,i:endw-win+i+1:stride,j:endh-win+j+1:stride]
|
||||
Y[:,k,:] = np.array(patch[:]).reshape(endc, TotalPatNum)
|
||||
k = k + 1
|
||||
return Y.reshape([endc, win, win, TotalPatNum])
|
||||
|
||||
def is_low_contrast(image, fraction_threshold=0.1, lower_percentile=10,
|
||||
upper_percentile=90):
|
||||
"""Determine if an image is low contrast."""
|
||||
limits = np.percentile(image, [lower_percentile, upper_percentile])
|
||||
ratio = (limits[1] - limits[0]) / limits[1]
|
||||
return ratio < fraction_threshold
|
||||
|
||||
data_name="MSRS_train"
|
||||
img_size=128 #patch size
|
||||
stride=200 #patch stride
|
||||
|
||||
IR_files = sorted(get_img_file(r"MSRS_train/ir"))
|
||||
VIS_files = sorted(get_img_file(r"MSRS_train/vi"))
|
||||
|
||||
assert len(IR_files) == len(VIS_files)
|
||||
h5f = h5py.File(os.path.join('.\\data',
|
||||
data_name+'_imgsize_'+str(img_size)+"_stride_"+str(stride)+'.h5'),
|
||||
'w')
|
||||
h5_ir = h5f.create_group('ir_patchs')
|
||||
h5_vis = h5f.create_group('vis_patchs')
|
||||
train_num=0
|
||||
for i in tqdm(range(len(IR_files))):
|
||||
I_VIS = imread(VIS_files[i]).astype(np.float32).transpose(2,0,1)/255. # [3, H, W] Uint8->float32
|
||||
I_VIS = rgb2y(I_VIS) # [1, H, W] Float32
|
||||
I_IR = imread(IR_files[i]).astype(np.float32)[None, :, :]/255. # [1, H, W] Float32
|
||||
|
||||
# crop
|
||||
I_IR_Patch_Group = Im2Patch(I_IR,img_size,stride)
|
||||
I_VIS_Patch_Group = Im2Patch(I_VIS, img_size, stride) # (3, 256, 256, 12)
|
||||
|
||||
for ii in range(I_IR_Patch_Group.shape[-1]):
|
||||
bad_IR = is_low_contrast(I_IR_Patch_Group[0,:,:,ii])
|
||||
bad_VIS = is_low_contrast(I_VIS_Patch_Group[0,:,:,ii])
|
||||
# Determine if the contrast is low
|
||||
if not (bad_IR or bad_VIS):
|
||||
avl_IR= I_IR_Patch_Group[0,:,:,ii] # available IR
|
||||
avl_VIS= I_VIS_Patch_Group[0,:,:,ii]
|
||||
avl_IR=avl_IR[None,...]
|
||||
avl_VIS=avl_VIS[None,...]
|
||||
|
||||
h5_ir.create_dataset(str(train_num), data=avl_IR,
|
||||
dtype=avl_IR.dtype, shape=avl_IR.shape)
|
||||
h5_vis.create_dataset(str(train_num), data=avl_VIS,
|
||||
dtype=avl_VIS.dtype, shape=avl_VIS.shape)
|
||||
train_num += 1
|
||||
|
||||
h5f.close()
|
||||
|
||||
with h5py.File(os.path.join('data',
|
||||
data_name+'_imgsize_'+str(img_size)+"_stride_"+str(stride)+'.h5'),"r") as f:
|
||||
for key in f.keys():
|
||||
print(f[key], key, f[key].name)
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
Binary file not shown.
Binary file not shown.
132
test.py
Normal file
132
test.py
Normal file
@ -0,0 +1,132 @@
|
||||
import argparse
|
||||
import sys
|
||||
import uuid
|
||||
|
||||
from matplotlib import image as mpimg, pyplot as plt
|
||||
|
||||
from net import Sar_Restormer_Encoder,Vi_Restormer_Encoder, Restormer_Decoder, BaseFeatureExtraction, DetailFeatureExtraction
|
||||
import os
|
||||
import numpy as np
|
||||
from utils.Evaluator import Evaluator
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from utils.img_read_save import img_save, image_read_cv2
|
||||
import warnings
|
||||
import logging
|
||||
|
||||
warnings.filterwarnings("ignore")
|
||||
logging.basicConfig(level=logging.CRITICAL)
|
||||
|
||||
path = os.path.dirname(sys.argv[0]) + "\\"
|
||||
|
||||
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
|
||||
# ckpt_path=r"models/CDDFuse_IVF.pth"
|
||||
ckpt_path = r"" + path + "models/CDDFuse_04-10-11-56.pth"
|
||||
|
||||
print(torch.cuda.is_available())
|
||||
|
||||
|
||||
def main(opt):
|
||||
# --viPath D:\PythonProject\MMIF-CDDFuse\test_img\Test\vi\ir_2.png --irPath D:\PythonProject\MMIF-CDDFuse\test_img\Test\ir\ir_2.png --outputPath D:\PythonProject\MMIF-CDDFuse\test_img\Test\
|
||||
|
||||
ir_path = opt.irPath
|
||||
vi_path = opt.viPath
|
||||
output_path = opt.outputPath
|
||||
|
||||
print("\n" * 2 + "=" * 80)
|
||||
|
||||
model_name = "CDDFuse "
|
||||
print("The ir_path of " + ir_path + ' :')
|
||||
print("The vi_path of " + vi_path + ' :')
|
||||
|
||||
device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
||||
|
||||
|
||||
SAR_Encoder = nn.DataParallel(Sar_Restormer_Encoder()).to(device)
|
||||
VI_Encoder = nn.DataParallel(Vi_Restormer_Encoder()).to(device)
|
||||
|
||||
Decoder = nn.DataParallel(Restormer_Decoder()).to(device)
|
||||
BaseFuseLayer = nn.DataParallel(BaseFeatureExtraction(dim=64, num_heads=8)).to(device)
|
||||
DetailFuseLayer = nn.DataParallel(DetailFeatureExtraction(num_layers=1)).to(device)
|
||||
|
||||
SAR_Encoder.load_state_dict(torch.load(ckpt_path)['SAR_DIDF_Encoder'])
|
||||
VI_Encoder.load_state_dict(torch.load(ckpt_path)['VI_DIDF_Encoder'])
|
||||
|
||||
Decoder.load_state_dict(torch.load(ckpt_path)['DIDF_Decoder'])
|
||||
BaseFuseLayer.load_state_dict(torch.load(ckpt_path)['BaseFuseLayer'])
|
||||
DetailFuseLayer.load_state_dict(torch.load(ckpt_path)['DetailFuseLayer'])
|
||||
SAR_Encoder.eval()
|
||||
VI_Encoder.eval()
|
||||
|
||||
Decoder.eval()
|
||||
BaseFuseLayer.eval()
|
||||
DetailFuseLayer.eval()
|
||||
|
||||
with torch.no_grad():
|
||||
data_IR = image_read_cv2(ir_path, mode='GRAY')[np.newaxis, np.newaxis, ...] / 255.0
|
||||
data_VIS = image_read_cv2(vi_path, mode='GRAY')[np.newaxis, np.newaxis, ...] / 255.0
|
||||
|
||||
data_IR, data_VIS = torch.FloatTensor(data_IR), torch.FloatTensor(data_VIS)
|
||||
data_VIS, data_IR = data_VIS.cuda(), data_IR.cuda()
|
||||
|
||||
feature_V_B, feature_V_D, feature_V = VI_Encoder(data_VIS)
|
||||
feature_I_B, feature_I_D, feature_I = SAR_Encoder(data_IR)
|
||||
feature_F_B = BaseFuseLayer(feature_V_B + feature_I_B)
|
||||
feature_F_D = DetailFuseLayer(feature_V_D + feature_I_D)
|
||||
data_Fuse, _ = Decoder(data_VIS, feature_F_B, feature_F_D)
|
||||
data_Fuse = (data_Fuse - torch.min(data_Fuse)) / (torch.max(data_Fuse) - torch.min(data_Fuse))
|
||||
fi = np.squeeze((data_Fuse * 255).cpu().numpy())
|
||||
|
||||
# 获取文件名(包含后缀)
|
||||
file_name_with_extension = os.path.basename(ir_path)
|
||||
# 分离文件名和文件后缀
|
||||
file_name, file_extension = os.path.splitext(file_name_with_extension)
|
||||
|
||||
img_save(fi, "fusion_" + file_name, output_path)
|
||||
print("输出文件路径:" + output_path + "fusion_" + file_name + ".png")
|
||||
|
||||
metric_result = np.zeros((8))
|
||||
irImagePath = ir_path
|
||||
ir = image_read_cv2(irImagePath, 'GRAY')
|
||||
viImagePath = vi_path
|
||||
vi = image_read_cv2(viImagePath, 'GRAY')
|
||||
|
||||
fusionImagePath = os.path.join(output_path, "fusion_{}.png".format(file_name))
|
||||
fi = image_read_cv2(fusionImagePath, 'GRAY')
|
||||
# 统计
|
||||
metric_result += np.array([Evaluator.EN(fi), Evaluator.SD(fi)
|
||||
, Evaluator.SF(fi), Evaluator.MI(fi, ir, vi)
|
||||
, Evaluator.SCD(fi, ir, vi), Evaluator.VIFF(fi, ir, vi)
|
||||
, Evaluator.Qabf(fi, ir, vi), Evaluator.SSIM(fi, ir, vi)])
|
||||
|
||||
metric_result /= len(os.listdir(output_path))
|
||||
print("\t\t EN\t SD\t SF\t MI\tSCD\tVIF\tQabf\tSSIM")
|
||||
print("对比结果:" + model_name + '\t' + str(np.round(metric_result[0], 2)) + '\t'
|
||||
+ str(np.round(metric_result[1], 2)) + '\t'
|
||||
+ str(np.round(metric_result[2], 2)) + '\t'
|
||||
+ str(np.round(metric_result[3], 2)) + '\t'
|
||||
+ str(np.round(metric_result[4], 2)) + '\t'
|
||||
+ str(np.round(metric_result[5], 2)) + '\t'
|
||||
+ str(np.round(metric_result[6], 2)) + '\t'
|
||||
+ str(np.round(metric_result[7], 2))
|
||||
)
|
||||
print("=" * 80)
|
||||
|
||||
|
||||
def parse_opt():
|
||||
parser = argparse.ArgumentParser(
|
||||
description='python.exe --irPath "红外绝对路径" --viPath "可见光路径" --outputPath "输出文件路径"')
|
||||
parser.add_argument('--irPath', type=str, default="D:\\PythonProject\\MMIF-CDDFuse\\test_img\\cus\\sar\\NH49E001013_10.tif", required=False,
|
||||
help="是否为多路径") # 这里全部都是使用图片的名字,默认是 项目路径 + 2.jpg
|
||||
parser.add_argument('--viPath', type=str, default="D:\\PythonProject\\MMIF-CDDFuse\\test_img\\cus\\opr\\NH49E001013_10.tif", required=False,
|
||||
help="完整目录路径!,可以为数组") # 这里全部都是使用图片的名字,默认是 项目路径 + 2.jpg
|
||||
parser.add_argument('--outputPath', type=str, default='results_detect', required=False,
|
||||
help="输出路径!") # 使用的也是图片的目标地址
|
||||
|
||||
opt = parser.parse_args()
|
||||
return opt
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
opt = parse_opt()
|
||||
main(opt)
|
219
train.py
Normal file
219
train.py
Normal file
@ -0,0 +1,219 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
'''
|
||||
------------------------------------------------------------------------------
|
||||
Import packages
|
||||
------------------------------------------------------------------------------
|
||||
'''
|
||||
|
||||
from net import Restormer_Encoder, Restormer_Decoder, BaseFeatureExtraction, DetailFeatureExtraction
|
||||
from utils.dataset import H5Dataset
|
||||
import os
|
||||
os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'
|
||||
import sys
|
||||
import time
|
||||
import datetime
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torch.utils.data import DataLoader
|
||||
from utils.loss import Fusionloss, cc
|
||||
import kornia
|
||||
|
||||
|
||||
|
||||
'''
|
||||
------------------------------------------------------------------------------
|
||||
Configure our network
|
||||
------------------------------------------------------------------------------
|
||||
'''
|
||||
|
||||
|
||||
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
|
||||
criteria_fusion = Fusionloss()
|
||||
model_str = 'CDDFuse'
|
||||
|
||||
# . Set the hyper-parameters for training
|
||||
num_epochs = 120 # total epoch
|
||||
epoch_gap = 40 # epoches of Phase I
|
||||
|
||||
lr = 1e-4
|
||||
weight_decay = 0
|
||||
batch_size = 8
|
||||
GPU_number = os.environ['CUDA_VISIBLE_DEVICES']
|
||||
# Coefficients of the loss function
|
||||
coeff_mse_loss_VF = 1. # alpha1
|
||||
coeff_mse_loss_IF = 1.
|
||||
coeff_decomp = 2. # alpha2 and alpha4
|
||||
coeff_tv = 5.
|
||||
|
||||
clip_grad_norm_value = 0.01
|
||||
optim_step = 20
|
||||
optim_gamma = 0.5
|
||||
|
||||
|
||||
# Model
|
||||
device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
||||
DIDF_Encoder = nn.DataParallel(Restormer_Encoder()).to(device)
|
||||
DIDF_Decoder = nn.DataParallel(Restormer_Decoder()).to(device)
|
||||
BaseFuseLayer = nn.DataParallel(BaseFeatureExtraction(dim=64, num_heads=8)).to(device)
|
||||
DetailFuseLayer = nn.DataParallel(DetailFeatureExtraction(num_layers=1)).to(device)
|
||||
|
||||
# optimizer, scheduler and loss function
|
||||
optimizer1 = torch.optim.Adam(
|
||||
DIDF_Encoder.parameters(), lr=lr, weight_decay=weight_decay)
|
||||
optimizer2 = torch.optim.Adam(
|
||||
DIDF_Decoder.parameters(), lr=lr, weight_decay=weight_decay)
|
||||
optimizer3 = torch.optim.Adam(
|
||||
BaseFuseLayer.parameters(), lr=lr, weight_decay=weight_decay)
|
||||
optimizer4 = torch.optim.Adam(
|
||||
DetailFuseLayer.parameters(), lr=lr, weight_decay=weight_decay)
|
||||
|
||||
scheduler1 = torch.optim.lr_scheduler.StepLR(optimizer1, step_size=optim_step, gamma=optim_gamma)
|
||||
scheduler2 = torch.optim.lr_scheduler.StepLR(optimizer2, step_size=optim_step, gamma=optim_gamma)
|
||||
scheduler3 = torch.optim.lr_scheduler.StepLR(optimizer3, step_size=optim_step, gamma=optim_gamma)
|
||||
scheduler4 = torch.optim.lr_scheduler.StepLR(optimizer4, step_size=optim_step, gamma=optim_gamma)
|
||||
|
||||
MSELoss = nn.MSELoss()
|
||||
L1Loss = nn.L1Loss()
|
||||
Loss_ssim = kornia.losses.SSIM(11, reduction='mean')
|
||||
|
||||
|
||||
# data loader
|
||||
trainloader = DataLoader(H5Dataset(r"data/MSRS_train_imgsize_128_stride_200.h5"),
|
||||
batch_size=batch_size,
|
||||
shuffle=True,
|
||||
num_workers=0)
|
||||
|
||||
loader = {'train': trainloader, }
|
||||
timestamp = datetime.datetime.now().strftime("%m-%d-%H-%M")
|
||||
|
||||
'''
|
||||
------------------------------------------------------------------------------
|
||||
Train
|
||||
------------------------------------------------------------------------------
|
||||
'''
|
||||
|
||||
step = 0
|
||||
torch.backends.cudnn.benchmark = True
|
||||
prev_time = time.time()
|
||||
|
||||
for epoch in range(num_epochs):
|
||||
''' train '''
|
||||
for i, (data_VIS, data_IR) in enumerate(loader['train']):
|
||||
data_VIS, data_IR = data_VIS.cuda(), data_IR.cuda()
|
||||
DIDF_Encoder.train()
|
||||
DIDF_Decoder.train()
|
||||
BaseFuseLayer.train()
|
||||
DetailFuseLayer.train()
|
||||
|
||||
DIDF_Encoder.zero_grad()
|
||||
DIDF_Decoder.zero_grad()
|
||||
BaseFuseLayer.zero_grad()
|
||||
DetailFuseLayer.zero_grad()
|
||||
|
||||
optimizer1.zero_grad()
|
||||
optimizer2.zero_grad()
|
||||
optimizer3.zero_grad()
|
||||
optimizer4.zero_grad()
|
||||
|
||||
if epoch < epoch_gap: #Phase I
|
||||
feature_V_B, feature_V_D, _ = DIDF_Encoder(data_VIS)
|
||||
feature_I_B, feature_I_D, _ = DIDF_Encoder(data_IR)
|
||||
data_VIS_hat, _ = DIDF_Decoder(data_VIS, feature_V_B, feature_V_D)
|
||||
data_IR_hat, _ = DIDF_Decoder(data_IR, feature_I_B, feature_I_D)
|
||||
|
||||
cc_loss_B = cc(feature_V_B, feature_I_B)
|
||||
cc_loss_D = cc(feature_V_D, feature_I_D)
|
||||
mse_loss_V = 5 * Loss_ssim(data_VIS, data_VIS_hat) + MSELoss(data_VIS, data_VIS_hat)
|
||||
mse_loss_I = 5 * Loss_ssim(data_IR, data_IR_hat) + MSELoss(data_IR, data_IR_hat)
|
||||
|
||||
Gradient_loss = L1Loss(kornia.filters.SpatialGradient()(data_VIS),
|
||||
kornia.filters.SpatialGradient()(data_VIS_hat))
|
||||
|
||||
loss_decomp = (cc_loss_D) ** 2/ (1.01 + cc_loss_B)
|
||||
|
||||
loss = coeff_mse_loss_VF * mse_loss_V + coeff_mse_loss_IF * \
|
||||
mse_loss_I + coeff_decomp * loss_decomp + coeff_tv * Gradient_loss
|
||||
|
||||
loss.backward()
|
||||
nn.utils.clip_grad_norm_(
|
||||
DIDF_Encoder.parameters(), max_norm=clip_grad_norm_value, norm_type=2)
|
||||
nn.utils.clip_grad_norm_(
|
||||
DIDF_Decoder.parameters(), max_norm=clip_grad_norm_value, norm_type=2)
|
||||
optimizer1.step()
|
||||
optimizer2.step()
|
||||
else: #Phase II
|
||||
feature_V_B, feature_V_D, feature_V = DIDF_Encoder(data_VIS)
|
||||
feature_I_B, feature_I_D, feature_I = DIDF_Encoder(data_IR)
|
||||
feature_F_B = BaseFuseLayer(feature_I_B+feature_V_B)
|
||||
feature_F_D = DetailFuseLayer(feature_I_D+feature_V_D)
|
||||
data_Fuse, feature_F = DIDF_Decoder(data_VIS, feature_F_B, feature_F_D)
|
||||
|
||||
|
||||
mse_loss_V = 5*Loss_ssim(data_VIS, data_Fuse) + MSELoss(data_VIS, data_Fuse)
|
||||
mse_loss_I = 5*Loss_ssim(data_IR, data_Fuse) + MSELoss(data_IR, data_Fuse)
|
||||
|
||||
cc_loss_B = cc(feature_V_B, feature_I_B)
|
||||
cc_loss_D = cc(feature_V_D, feature_I_D)
|
||||
loss_decomp = (cc_loss_D) ** 2 / (1.01 + cc_loss_B)
|
||||
fusionloss, _,_ = criteria_fusion(data_VIS, data_IR, data_Fuse)
|
||||
|
||||
loss = fusionloss + coeff_decomp * loss_decomp
|
||||
loss.backward()
|
||||
nn.utils.clip_grad_norm_(
|
||||
DIDF_Encoder.parameters(), max_norm=clip_grad_norm_value, norm_type=2)
|
||||
nn.utils.clip_grad_norm_(
|
||||
DIDF_Decoder.parameters(), max_norm=clip_grad_norm_value, norm_type=2)
|
||||
nn.utils.clip_grad_norm_(
|
||||
BaseFuseLayer.parameters(), max_norm=clip_grad_norm_value, norm_type=2)
|
||||
nn.utils.clip_grad_norm_(
|
||||
DetailFuseLayer.parameters(), max_norm=clip_grad_norm_value, norm_type=2)
|
||||
optimizer1.step()
|
||||
optimizer2.step()
|
||||
optimizer3.step()
|
||||
optimizer4.step()
|
||||
|
||||
# Determine approximate time left
|
||||
batches_done = epoch * len(loader['train']) + i
|
||||
batches_left = num_epochs * len(loader['train']) - batches_done
|
||||
time_left = datetime.timedelta(seconds=batches_left * (time.time() - prev_time))
|
||||
prev_time = time.time()
|
||||
sys.stdout.write(
|
||||
"\r[Epoch %d/%d] [Batch %d/%d] [loss: %f] ETA: %.10s"
|
||||
% (
|
||||
epoch,
|
||||
num_epochs,
|
||||
i,
|
||||
len(loader['train']),
|
||||
loss.item(),
|
||||
time_left,
|
||||
)
|
||||
)
|
||||
|
||||
# adjust the learning rate
|
||||
|
||||
scheduler1.step()
|
||||
scheduler2.step()
|
||||
if not epoch < epoch_gap:
|
||||
scheduler3.step()
|
||||
scheduler4.step()
|
||||
|
||||
if optimizer1.param_groups[0]['lr'] <= 1e-6:
|
||||
optimizer1.param_groups[0]['lr'] = 1e-6
|
||||
if optimizer2.param_groups[0]['lr'] <= 1e-6:
|
||||
optimizer2.param_groups[0]['lr'] = 1e-6
|
||||
if optimizer3.param_groups[0]['lr'] <= 1e-6:
|
||||
optimizer3.param_groups[0]['lr'] = 1e-6
|
||||
if optimizer4.param_groups[0]['lr'] <= 1e-6:
|
||||
optimizer4.param_groups[0]['lr'] = 1e-6
|
||||
|
||||
if True:
|
||||
checkpoint = {
|
||||
'DIDF_Encoder': DIDF_Encoder.state_dict(),
|
||||
'DIDF_Decoder': DIDF_Decoder.state_dict(),
|
||||
'BaseFuseLayer': BaseFuseLayer.state_dict(),
|
||||
'DetailFuseLayer': DetailFuseLayer.state_dict(),
|
||||
}
|
||||
torch.save(checkpoint, os.path.join("models/CDDFuse_"+timestamp+'.pth'))
|
||||
|
||||
|
Binary file not shown.
BIN
utils/__pycache__/dataset.cpython-38.pyc
Normal file
BIN
utils/__pycache__/dataset.cpython-38.pyc
Normal file
Binary file not shown.
Binary file not shown.
BIN
utils/__pycache__/loss.cpython-38.pyc
Normal file
BIN
utils/__pycache__/loss.cpython-38.pyc
Normal file
Binary file not shown.
22
utils/dataset.py
Normal file
22
utils/dataset.py
Normal file
@ -0,0 +1,22 @@
|
||||
import torch.utils.data as Data
|
||||
import h5py
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
class H5Dataset(Data.Dataset):
|
||||
def __init__(self, h5file_path):
|
||||
self.h5file_path = h5file_path
|
||||
h5f = h5py.File(h5file_path, 'r')
|
||||
self.keys = list(h5f['ir_patchs'].keys())
|
||||
h5f.close()
|
||||
|
||||
def __len__(self):
|
||||
return len(self.keys)
|
||||
|
||||
def __getitem__(self, index):
|
||||
h5f = h5py.File(self.h5file_path, 'r')
|
||||
key = self.keys[index]
|
||||
IR = np.array(h5f['ir_patchs'][key])
|
||||
VIS = np.array(h5f['vis_patchs'][key])
|
||||
h5f.close()
|
||||
return torch.Tensor(VIS), torch.Tensor(IR)
|
@ -2,7 +2,7 @@ import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
import numpy as np
|
||||
|
||||
|
||||
class Fusionloss(nn.Module):
|
||||
def __init__(self):
|
||||
@ -38,3 +38,17 @@ class Sobelxy(nn.Module):
|
||||
sobelx=F.conv2d(x, self.weightx, padding=1)
|
||||
sobely=F.conv2d(x, self.weighty, padding=1)
|
||||
return torch.abs(sobelx)+torch.abs(sobely)
|
||||
|
||||
|
||||
def cc(img1, img2):
|
||||
eps = torch.finfo(torch.float32).eps
|
||||
"""Correlation coefficient for (N, C, H, W) image; torch.float32 [0.,1.]."""
|
||||
N, C, _, _ = img1.shape
|
||||
img1 = img1.reshape(N, C, -1)
|
||||
img2 = img2.reshape(N, C, -1)
|
||||
img1 = img1 - img1.mean(dim=-1, keepdim=True)
|
||||
img2 = img2 - img2.mean(dim=-1, keepdim=True)
|
||||
cc = torch.sum(img1 * img2, dim=-1) / (eps + torch.sqrt(torch.sum(img1 **
|
||||
2, dim=-1)) * torch.sqrt(torch.sum(img2**2, dim=-1)))
|
||||
cc = torch.clamp(cc, -1., 1.)
|
||||
return cc.mean()
|
Loading…
Reference in New Issue
Block a user