Compare commits

...

10 Commits

Author SHA1 Message Date
whaifree
d32424f54a d 2024-04-10 12:22:12 +08:00
Zhaozixiang1228
936e5d4ef1 Update README.md 2023-07-18 09:46:15 +02:00
Zhaozixiang1228
2759603d09 Update README.md 2023-07-18 09:45:18 +02:00
Zhaozixiang1228
655a8ef9f0 Update README.md 2023-07-18 09:43:42 +02:00
Zhaozixiang1228
2471e747c4 Create readme.md 2023-06-18 01:57:33 +02:00
Zhaozixiang1228
3fe9a38165 main 2023-06-17 01:46:05 +02:00
Zhaozixiang1228
0c2350bc2c Update README.md 2023-06-17 00:57:51 +02:00
Zhaozixiang1228
fe68253b98 Update README.md 2023-06-17 00:54:36 +02:00
Zhaozixiang1228
9ea7476527 Update README.md 2023-06-17 00:51:09 +02:00
Zhaozixiang1228
b4a336be9b Update README.md 2023-06-17 00:48:57 +02:00
13 changed files with 535 additions and 34 deletions

1
MSRS_train/readme.md Normal file
View File

@ -0,0 +1 @@
Download the MSRS dataset from [this link](https://github.com/Linfeng-Tang/MSRS) and place it here.

View File

@ -9,7 +9,7 @@ Codes for ***CDDFuse: Correlation-Driven Dual-Branch Feature Decomposition for M
## Update ## Update
- [2023/5] Training codes and config files will be public available before June. - [2023/6] Training codes and config files are public available.
- [2023/4] Release inference code for infrared-visible image fusion and medical image fusion. - [2023/4] Release inference code for infrared-visible image fusion and medical image fusion.
@ -30,21 +30,66 @@ Codes for ***CDDFuse: Correlation-Driven Dual-Branch Feature Decomposition for M
Multi-modality (MM) image fusion aims to render fused images that maintain the merits of different modalities, e.g., functional highlight and detailed textures. To tackle the challenge in modeling cross-modality features and decomposing desirable modality-specific and modality-shared features, we propose a novel Correlation-Driven feature Decomposition Fusion (CDDFuse) network. Firstly, CDDFuse uses Restormer blocks to extract cross-modality shallow features. We then introduce a dual-branch Transformer-CNN feature extractor with Lite Transformer (LT) blocks leveraging long-range attention to handle low-frequency global features and Invertible Neural Networks (INN) blocks focusing on extracting high-frequency local information. A correlation-driven loss is further proposed to make the low-frequency features correlated while the high-frequency features uncorrelated based on the embedded information. Then, the LT-based global fusion and INN-based local fusion layers output the fused image. Extensive experiments demonstrate that our CDDFuse achieves promising results in multiple fusion tasks, including infrared-visible image fusion and medical image fusion. We also show that CDDFuse can boost the performance in downstream infrared-visible semantic segmentation and object detection in a unified benchmark. Multi-modality (MM) image fusion aims to render fused images that maintain the merits of different modalities, e.g., functional highlight and detailed textures. To tackle the challenge in modeling cross-modality features and decomposing desirable modality-specific and modality-shared features, we propose a novel Correlation-Driven feature Decomposition Fusion (CDDFuse) network. Firstly, CDDFuse uses Restormer blocks to extract cross-modality shallow features. We then introduce a dual-branch Transformer-CNN feature extractor with Lite Transformer (LT) blocks leveraging long-range attention to handle low-frequency global features and Invertible Neural Networks (INN) blocks focusing on extracting high-frequency local information. A correlation-driven loss is further proposed to make the low-frequency features correlated while the high-frequency features uncorrelated based on the embedded information. Then, the LT-based global fusion and INN-based local fusion layers output the fused image. Extensive experiments demonstrate that our CDDFuse achieves promising results in multiple fusion tasks, including infrared-visible image fusion and medical image fusion. We also show that CDDFuse can boost the performance in downstream infrared-visible semantic segmentation and object detection in a unified benchmark.
## Usage ## 🌐 Usage
### Network Architecture ### Network Architecture
Our CDDFuse is implemented in ``net.py``. Our CDDFuse is implemented in ``net.py``.
### Testing ### 🏊 Training
**1. Virtual Environment**
```
# create virtual environment
conda create -n cddfuse python=3.8.10
conda activate cddfuse
# select pytorch version yourself
# install cddfuse requirements
pip install -r requirements.txt
```
**2. Data Preparation**
Download the MSRS dataset from [this link](https://github.com/Linfeng-Tang/MSRS) and place it in the folder ``'./MSRS_train/'``.
**3. Pre-Processing**
Run
```
python dataprocessing.py
```
and the processed training dataset is in ``'./data/MSRS_train_imgsize_128_stride_200.h5'``.
**4. CDDFuse Training**
Run
```
python train.py
```
and the trained model is available in ``'./models/'``.
### 🏄 Testing
**1. Pretrained models**
Pretrained models are available in ``'./models/CDDFuse_IVF.pth'`` and ``'./models/CDDFuse_MIF.pth'``, which are responsible for the Infrared-Visible Fusion (IVF) and Medical Image Fusion (MIF) tasks, respectively. Pretrained models are available in ``'./models/CDDFuse_IVF.pth'`` and ``'./models/CDDFuse_MIF.pth'``, which are responsible for the Infrared-Visible Fusion (IVF) and Medical Image Fusion (MIF) tasks, respectively.
**2. Test datasets**
The test datasets used in the paper have been stored in ``'./test_img/RoadScene'``, ``'./test_img/TNO'`` for IVF, ``'./test_img/MRI_CT'``, ``'./test_img/MRI_PET'`` and ``'./test_img/MRI_SPECT'`` for MIF. The test datasets used in the paper have been stored in ``'./test_img/RoadScene'``, ``'./test_img/TNO'`` for IVF, ``'./test_img/MRI_CT'``, ``'./test_img/MRI_PET'`` and ``'./test_img/MRI_SPECT'`` for MIF.
Unfortunately, since the size of **MSRS dataset** for IVF is 500+MB, we can not upload it for exhibition. It can be downloaded via [this link](https://github.com/Linfeng-Tang/MSRS). The other datasets contain all the test images. Unfortunately, since the size of **MSRS dataset** for IVF is 500+MB, we can not upload it for exhibition. It can be downloaded via [this link](https://github.com/Linfeng-Tang/MSRS). The other datasets contain all the test images.
If you want to infer with our CDDFuse and obtain the fusion results in our paper, please run ``'test_IVF.py'`` for IVF and ``'test_MIF.py'`` for MIF. **3. Results in Our Paper**
If you want to infer with our CDDFuse and obtain the fusion results in our paper, please run
```
python test_IVF.py
```
for Infrared-Visible Fusion and
```
python test_MIF.py
```
for Medical Image Fusion.
The testing results will be printed in the terminal. The testing results will be printed in the terminal.
@ -91,32 +136,7 @@ CDDFuse_MIF 3.9 58.31 20.87 2.49 1.35 0.97 0.78 1.48
``` ```
which can match the results in Table 5 in our original paper. which can match the results in Table 5 in our original paper.
### Training ## 🙌 CDDFuse
**1. Virtual Environment**
```
# create virtual environment
conda create -n cddfuse python=3.8.10
conda activate cddfuse
# select pytorch version yourself
# install cddfuse requirements
pip install -r requirements.txt
```
**2. Data Preparation**
Download the MSRS dataset from [this link](https://github.com/Linfeng-Tang/MSRS) and place it in the folder ``'./MSRS_train``.
**3. Pre-Processing**
Run
```python dataprocessing.py``` and the processed training dataset is in ``'./data/MSRS_train_imgsize_128_stride_200.h5``.
**4. CDDFuse Training**
Run ```python train.py``` and the trained model is available in ``'./models/'``.
## CDDFuse
### Illustration of our CDDFuse model. ### Illustration of our CDDFuse model.
@ -149,12 +169,12 @@ MM segmentation
<img src="image//MMSeg.png" width="60%" align=center /> <img src="image//MMSeg.png" width="60%" align=center />
## Related Work ## 📖 Related Work
- Zixiang Zhao, Haowen Bai, Jiangshe Zhang, Yulun Zhang, Kai Zhang, Shuang Xu, Dongdong Chen, Radu Timofte, Luc Van Gool. *Equivariant Multi-Modality Image Fusion.* **arXiv:2305.11443**, https://arxiv.org/abs/2305.11443 - Zixiang Zhao, Haowen Bai, Jiangshe Zhang, Yulun Zhang, Kai Zhang, Shuang Xu, Dongdong Chen, Radu Timofte, Luc Van Gool. *Equivariant Multi-Modality Image Fusion.* **arXiv:2305.11443**, https://arxiv.org/abs/2305.11443
- Zixiang Zhao, Haowen Bai, Yuanzhi Zhu, Jiangshe Zhang, Shuang Xu, Yulun Zhang, Kai Zhang, Deyu Meng, Radu Timofte, Luc Van Gool. - Zixiang Zhao, Haowen Bai, Yuanzhi Zhu, Jiangshe Zhang, Shuang Xu, Yulun Zhang, Kai Zhang, Deyu Meng, Radu Timofte, Luc Van Gool.
*DDFM: Denoising Diffusion Model for Multi-Modality Image Fusion.* **arXiv:2303.06840**, https://arxiv.org/abs/2303.06840 *DDFM: Denoising Diffusion Model for Multi-Modality Image Fusion.* **ICCV 2023**, https://arxiv.org/abs/2303.06840
- Zixiang Zhao, Shuang Xu, Chunxia Zhang, Junmin Liu, Jiangshe Zhang and Pengfei Li. *DIDFuse: Deep Image Decomposition for Infrared and Visible Image Fusion.* **IJCAI 2020**, https://www.ijcai.org/Proceedings/2020/135. - Zixiang Zhao, Shuang Xu, Chunxia Zhang, Junmin Liu, Jiangshe Zhang and Pengfei Li. *DIDFuse: Deep Image Decomposition for Infrared and Visible Image Fusion.* **IJCAI 2020**, https://www.ijcai.org/Proceedings/2020/135.

93
dataprocessing.py Normal file
View File

@ -0,0 +1,93 @@
import os
import h5py
import numpy as np
from tqdm import tqdm
from skimage.io import imread
def get_img_file(file_name):
imagelist = []
for parent, dirnames, filenames in os.walk(file_name):
for filename in filenames:
if filename.lower().endswith(('.bmp', '.dib', '.png', '.jpg', '.jpeg', '.pbm', '.pgm', '.ppm', '.tif', '.tiff', '.npy')):
imagelist.append(os.path.join(parent, filename))
return imagelist
def rgb2y(img):
y = img[0:1, :, :] * 0.299000 + img[1:2, :, :] * 0.587000 + img[2:3, :, :] * 0.114000
return y
def Im2Patch(img, win, stride=1):
k = 0
endc = img.shape[0]
endw = img.shape[1]
endh = img.shape[2]
patch = img[:, 0:endw-win+0+1:stride, 0:endh-win+0+1:stride]
TotalPatNum = patch.shape[1] * patch.shape[2]
Y = np.zeros([endc, win*win,TotalPatNum], np.float32)
for i in range(win):
for j in range(win):
patch = img[:,i:endw-win+i+1:stride,j:endh-win+j+1:stride]
Y[:,k,:] = np.array(patch[:]).reshape(endc, TotalPatNum)
k = k + 1
return Y.reshape([endc, win, win, TotalPatNum])
def is_low_contrast(image, fraction_threshold=0.1, lower_percentile=10,
upper_percentile=90):
"""Determine if an image is low contrast."""
limits = np.percentile(image, [lower_percentile, upper_percentile])
ratio = (limits[1] - limits[0]) / limits[1]
return ratio < fraction_threshold
data_name="MSRS_train"
img_size=128 #patch size
stride=200 #patch stride
IR_files = sorted(get_img_file(r"MSRS_train/ir"))
VIS_files = sorted(get_img_file(r"MSRS_train/vi"))
assert len(IR_files) == len(VIS_files)
h5f = h5py.File(os.path.join('.\\data',
data_name+'_imgsize_'+str(img_size)+"_stride_"+str(stride)+'.h5'),
'w')
h5_ir = h5f.create_group('ir_patchs')
h5_vis = h5f.create_group('vis_patchs')
train_num=0
for i in tqdm(range(len(IR_files))):
I_VIS = imread(VIS_files[i]).astype(np.float32).transpose(2,0,1)/255. # [3, H, W] Uint8->float32
I_VIS = rgb2y(I_VIS) # [1, H, W] Float32
I_IR = imread(IR_files[i]).astype(np.float32)[None, :, :]/255. # [1, H, W] Float32
# crop
I_IR_Patch_Group = Im2Patch(I_IR,img_size,stride)
I_VIS_Patch_Group = Im2Patch(I_VIS, img_size, stride) # (3, 256, 256, 12)
for ii in range(I_IR_Patch_Group.shape[-1]):
bad_IR = is_low_contrast(I_IR_Patch_Group[0,:,:,ii])
bad_VIS = is_low_contrast(I_VIS_Patch_Group[0,:,:,ii])
# Determine if the contrast is low
if not (bad_IR or bad_VIS):
avl_IR= I_IR_Patch_Group[0,:,:,ii] # available IR
avl_VIS= I_VIS_Patch_Group[0,:,:,ii]
avl_IR=avl_IR[None,...]
avl_VIS=avl_VIS[None,...]
h5_ir.create_dataset(str(train_num), data=avl_IR,
dtype=avl_IR.dtype, shape=avl_IR.shape)
h5_vis.create_dataset(str(train_num), data=avl_VIS,
dtype=avl_VIS.dtype, shape=avl_VIS.shape)
train_num += 1
h5f.close()
with h5py.File(os.path.join('data',
data_name+'_imgsize_'+str(img_size)+"_stride_"+str(stride)+'.h5'),"r") as f:
for key in f.keys():
print(f[key], key, f[key].name)

Binary file not shown.

Binary file not shown.

132
test.py Normal file
View File

@ -0,0 +1,132 @@
import argparse
import sys
import uuid
from matplotlib import image as mpimg, pyplot as plt
from net import Sar_Restormer_Encoder,Vi_Restormer_Encoder, Restormer_Decoder, BaseFeatureExtraction, DetailFeatureExtraction
import os
import numpy as np
from utils.Evaluator import Evaluator
import torch
import torch.nn as nn
from utils.img_read_save import img_save, image_read_cv2
import warnings
import logging
warnings.filterwarnings("ignore")
logging.basicConfig(level=logging.CRITICAL)
path = os.path.dirname(sys.argv[0]) + "\\"
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
# ckpt_path=r"models/CDDFuse_IVF.pth"
ckpt_path = r"" + path + "models/CDDFuse_04-10-11-56.pth"
print(torch.cuda.is_available())
def main(opt):
# --viPath D:\PythonProject\MMIF-CDDFuse\test_img\Test\vi\ir_2.png --irPath D:\PythonProject\MMIF-CDDFuse\test_img\Test\ir\ir_2.png --outputPath D:\PythonProject\MMIF-CDDFuse\test_img\Test\
ir_path = opt.irPath
vi_path = opt.viPath
output_path = opt.outputPath
print("\n" * 2 + "=" * 80)
model_name = "CDDFuse "
print("The ir_path of " + ir_path + ' :')
print("The vi_path of " + vi_path + ' :')
device = 'cuda' if torch.cuda.is_available() else 'cpu'
SAR_Encoder = nn.DataParallel(Sar_Restormer_Encoder()).to(device)
VI_Encoder = nn.DataParallel(Vi_Restormer_Encoder()).to(device)
Decoder = nn.DataParallel(Restormer_Decoder()).to(device)
BaseFuseLayer = nn.DataParallel(BaseFeatureExtraction(dim=64, num_heads=8)).to(device)
DetailFuseLayer = nn.DataParallel(DetailFeatureExtraction(num_layers=1)).to(device)
SAR_Encoder.load_state_dict(torch.load(ckpt_path)['SAR_DIDF_Encoder'])
VI_Encoder.load_state_dict(torch.load(ckpt_path)['VI_DIDF_Encoder'])
Decoder.load_state_dict(torch.load(ckpt_path)['DIDF_Decoder'])
BaseFuseLayer.load_state_dict(torch.load(ckpt_path)['BaseFuseLayer'])
DetailFuseLayer.load_state_dict(torch.load(ckpt_path)['DetailFuseLayer'])
SAR_Encoder.eval()
VI_Encoder.eval()
Decoder.eval()
BaseFuseLayer.eval()
DetailFuseLayer.eval()
with torch.no_grad():
data_IR = image_read_cv2(ir_path, mode='GRAY')[np.newaxis, np.newaxis, ...] / 255.0
data_VIS = image_read_cv2(vi_path, mode='GRAY')[np.newaxis, np.newaxis, ...] / 255.0
data_IR, data_VIS = torch.FloatTensor(data_IR), torch.FloatTensor(data_VIS)
data_VIS, data_IR = data_VIS.cuda(), data_IR.cuda()
feature_V_B, feature_V_D, feature_V = VI_Encoder(data_VIS)
feature_I_B, feature_I_D, feature_I = SAR_Encoder(data_IR)
feature_F_B = BaseFuseLayer(feature_V_B + feature_I_B)
feature_F_D = DetailFuseLayer(feature_V_D + feature_I_D)
data_Fuse, _ = Decoder(data_VIS, feature_F_B, feature_F_D)
data_Fuse = (data_Fuse - torch.min(data_Fuse)) / (torch.max(data_Fuse) - torch.min(data_Fuse))
fi = np.squeeze((data_Fuse * 255).cpu().numpy())
# 获取文件名(包含后缀)
file_name_with_extension = os.path.basename(ir_path)
# 分离文件名和文件后缀
file_name, file_extension = os.path.splitext(file_name_with_extension)
img_save(fi, "fusion_" + file_name, output_path)
print("输出文件路径:" + output_path + "fusion_" + file_name + ".png")
metric_result = np.zeros((8))
irImagePath = ir_path
ir = image_read_cv2(irImagePath, 'GRAY')
viImagePath = vi_path
vi = image_read_cv2(viImagePath, 'GRAY')
fusionImagePath = os.path.join(output_path, "fusion_{}.png".format(file_name))
fi = image_read_cv2(fusionImagePath, 'GRAY')
# 统计
metric_result += np.array([Evaluator.EN(fi), Evaluator.SD(fi)
, Evaluator.SF(fi), Evaluator.MI(fi, ir, vi)
, Evaluator.SCD(fi, ir, vi), Evaluator.VIFF(fi, ir, vi)
, Evaluator.Qabf(fi, ir, vi), Evaluator.SSIM(fi, ir, vi)])
metric_result /= len(os.listdir(output_path))
print("\t\t EN\t SD\t SF\t MI\tSCD\tVIF\tQabf\tSSIM")
print("对比结果:" + model_name + '\t' + str(np.round(metric_result[0], 2)) + '\t'
+ str(np.round(metric_result[1], 2)) + '\t'
+ str(np.round(metric_result[2], 2)) + '\t'
+ str(np.round(metric_result[3], 2)) + '\t'
+ str(np.round(metric_result[4], 2)) + '\t'
+ str(np.round(metric_result[5], 2)) + '\t'
+ str(np.round(metric_result[6], 2)) + '\t'
+ str(np.round(metric_result[7], 2))
)
print("=" * 80)
def parse_opt():
parser = argparse.ArgumentParser(
description='python.exe --irPath "红外绝对路径" --viPath "可见光路径" --outputPath "输出文件路径"')
parser.add_argument('--irPath', type=str, default="D:\\PythonProject\\MMIF-CDDFuse\\test_img\\cus\\sar\\NH49E001013_10.tif", required=False,
help="是否为多路径") # 这里全部都是使用图片的名字,默认是 项目路径 + 2.jpg
parser.add_argument('--viPath', type=str, default="D:\\PythonProject\\MMIF-CDDFuse\\test_img\\cus\\opr\\NH49E001013_10.tif", required=False,
help="完整目录路径!,可以为数组") # 这里全部都是使用图片的名字,默认是 项目路径 + 2.jpg
parser.add_argument('--outputPath', type=str, default='results_detect', required=False,
help="输出路径!") # 使用的也是图片的目标地址
opt = parser.parse_args()
return opt
if __name__ == '__main__':
opt = parse_opt()
main(opt)

219
train.py Normal file
View File

@ -0,0 +1,219 @@
# -*- coding: utf-8 -*-
'''
------------------------------------------------------------------------------
Import packages
------------------------------------------------------------------------------
'''
from net import Restormer_Encoder, Restormer_Decoder, BaseFeatureExtraction, DetailFeatureExtraction
from utils.dataset import H5Dataset
import os
os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'
import sys
import time
import datetime
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from utils.loss import Fusionloss, cc
import kornia
'''
------------------------------------------------------------------------------
Configure our network
------------------------------------------------------------------------------
'''
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
criteria_fusion = Fusionloss()
model_str = 'CDDFuse'
# . Set the hyper-parameters for training
num_epochs = 120 # total epoch
epoch_gap = 40 # epoches of Phase I
lr = 1e-4
weight_decay = 0
batch_size = 8
GPU_number = os.environ['CUDA_VISIBLE_DEVICES']
# Coefficients of the loss function
coeff_mse_loss_VF = 1. # alpha1
coeff_mse_loss_IF = 1.
coeff_decomp = 2. # alpha2 and alpha4
coeff_tv = 5.
clip_grad_norm_value = 0.01
optim_step = 20
optim_gamma = 0.5
# Model
device = 'cuda' if torch.cuda.is_available() else 'cpu'
DIDF_Encoder = nn.DataParallel(Restormer_Encoder()).to(device)
DIDF_Decoder = nn.DataParallel(Restormer_Decoder()).to(device)
BaseFuseLayer = nn.DataParallel(BaseFeatureExtraction(dim=64, num_heads=8)).to(device)
DetailFuseLayer = nn.DataParallel(DetailFeatureExtraction(num_layers=1)).to(device)
# optimizer, scheduler and loss function
optimizer1 = torch.optim.Adam(
DIDF_Encoder.parameters(), lr=lr, weight_decay=weight_decay)
optimizer2 = torch.optim.Adam(
DIDF_Decoder.parameters(), lr=lr, weight_decay=weight_decay)
optimizer3 = torch.optim.Adam(
BaseFuseLayer.parameters(), lr=lr, weight_decay=weight_decay)
optimizer4 = torch.optim.Adam(
DetailFuseLayer.parameters(), lr=lr, weight_decay=weight_decay)
scheduler1 = torch.optim.lr_scheduler.StepLR(optimizer1, step_size=optim_step, gamma=optim_gamma)
scheduler2 = torch.optim.lr_scheduler.StepLR(optimizer2, step_size=optim_step, gamma=optim_gamma)
scheduler3 = torch.optim.lr_scheduler.StepLR(optimizer3, step_size=optim_step, gamma=optim_gamma)
scheduler4 = torch.optim.lr_scheduler.StepLR(optimizer4, step_size=optim_step, gamma=optim_gamma)
MSELoss = nn.MSELoss()
L1Loss = nn.L1Loss()
Loss_ssim = kornia.losses.SSIM(11, reduction='mean')
# data loader
trainloader = DataLoader(H5Dataset(r"data/MSRS_train_imgsize_128_stride_200.h5"),
batch_size=batch_size,
shuffle=True,
num_workers=0)
loader = {'train': trainloader, }
timestamp = datetime.datetime.now().strftime("%m-%d-%H-%M")
'''
------------------------------------------------------------------------------
Train
------------------------------------------------------------------------------
'''
step = 0
torch.backends.cudnn.benchmark = True
prev_time = time.time()
for epoch in range(num_epochs):
''' train '''
for i, (data_VIS, data_IR) in enumerate(loader['train']):
data_VIS, data_IR = data_VIS.cuda(), data_IR.cuda()
DIDF_Encoder.train()
DIDF_Decoder.train()
BaseFuseLayer.train()
DetailFuseLayer.train()
DIDF_Encoder.zero_grad()
DIDF_Decoder.zero_grad()
BaseFuseLayer.zero_grad()
DetailFuseLayer.zero_grad()
optimizer1.zero_grad()
optimizer2.zero_grad()
optimizer3.zero_grad()
optimizer4.zero_grad()
if epoch < epoch_gap: #Phase I
feature_V_B, feature_V_D, _ = DIDF_Encoder(data_VIS)
feature_I_B, feature_I_D, _ = DIDF_Encoder(data_IR)
data_VIS_hat, _ = DIDF_Decoder(data_VIS, feature_V_B, feature_V_D)
data_IR_hat, _ = DIDF_Decoder(data_IR, feature_I_B, feature_I_D)
cc_loss_B = cc(feature_V_B, feature_I_B)
cc_loss_D = cc(feature_V_D, feature_I_D)
mse_loss_V = 5 * Loss_ssim(data_VIS, data_VIS_hat) + MSELoss(data_VIS, data_VIS_hat)
mse_loss_I = 5 * Loss_ssim(data_IR, data_IR_hat) + MSELoss(data_IR, data_IR_hat)
Gradient_loss = L1Loss(kornia.filters.SpatialGradient()(data_VIS),
kornia.filters.SpatialGradient()(data_VIS_hat))
loss_decomp = (cc_loss_D) ** 2/ (1.01 + cc_loss_B)
loss = coeff_mse_loss_VF * mse_loss_V + coeff_mse_loss_IF * \
mse_loss_I + coeff_decomp * loss_decomp + coeff_tv * Gradient_loss
loss.backward()
nn.utils.clip_grad_norm_(
DIDF_Encoder.parameters(), max_norm=clip_grad_norm_value, norm_type=2)
nn.utils.clip_grad_norm_(
DIDF_Decoder.parameters(), max_norm=clip_grad_norm_value, norm_type=2)
optimizer1.step()
optimizer2.step()
else: #Phase II
feature_V_B, feature_V_D, feature_V = DIDF_Encoder(data_VIS)
feature_I_B, feature_I_D, feature_I = DIDF_Encoder(data_IR)
feature_F_B = BaseFuseLayer(feature_I_B+feature_V_B)
feature_F_D = DetailFuseLayer(feature_I_D+feature_V_D)
data_Fuse, feature_F = DIDF_Decoder(data_VIS, feature_F_B, feature_F_D)
mse_loss_V = 5*Loss_ssim(data_VIS, data_Fuse) + MSELoss(data_VIS, data_Fuse)
mse_loss_I = 5*Loss_ssim(data_IR, data_Fuse) + MSELoss(data_IR, data_Fuse)
cc_loss_B = cc(feature_V_B, feature_I_B)
cc_loss_D = cc(feature_V_D, feature_I_D)
loss_decomp = (cc_loss_D) ** 2 / (1.01 + cc_loss_B)
fusionloss, _,_ = criteria_fusion(data_VIS, data_IR, data_Fuse)
loss = fusionloss + coeff_decomp * loss_decomp
loss.backward()
nn.utils.clip_grad_norm_(
DIDF_Encoder.parameters(), max_norm=clip_grad_norm_value, norm_type=2)
nn.utils.clip_grad_norm_(
DIDF_Decoder.parameters(), max_norm=clip_grad_norm_value, norm_type=2)
nn.utils.clip_grad_norm_(
BaseFuseLayer.parameters(), max_norm=clip_grad_norm_value, norm_type=2)
nn.utils.clip_grad_norm_(
DetailFuseLayer.parameters(), max_norm=clip_grad_norm_value, norm_type=2)
optimizer1.step()
optimizer2.step()
optimizer3.step()
optimizer4.step()
# Determine approximate time left
batches_done = epoch * len(loader['train']) + i
batches_left = num_epochs * len(loader['train']) - batches_done
time_left = datetime.timedelta(seconds=batches_left * (time.time() - prev_time))
prev_time = time.time()
sys.stdout.write(
"\r[Epoch %d/%d] [Batch %d/%d] [loss: %f] ETA: %.10s"
% (
epoch,
num_epochs,
i,
len(loader['train']),
loss.item(),
time_left,
)
)
# adjust the learning rate
scheduler1.step()
scheduler2.step()
if not epoch < epoch_gap:
scheduler3.step()
scheduler4.step()
if optimizer1.param_groups[0]['lr'] <= 1e-6:
optimizer1.param_groups[0]['lr'] = 1e-6
if optimizer2.param_groups[0]['lr'] <= 1e-6:
optimizer2.param_groups[0]['lr'] = 1e-6
if optimizer3.param_groups[0]['lr'] <= 1e-6:
optimizer3.param_groups[0]['lr'] = 1e-6
if optimizer4.param_groups[0]['lr'] <= 1e-6:
optimizer4.param_groups[0]['lr'] = 1e-6
if True:
checkpoint = {
'DIDF_Encoder': DIDF_Encoder.state_dict(),
'DIDF_Decoder': DIDF_Decoder.state_dict(),
'BaseFuseLayer': BaseFuseLayer.state_dict(),
'DetailFuseLayer': DetailFuseLayer.state_dict(),
}
torch.save(checkpoint, os.path.join("models/CDDFuse_"+timestamp+'.pth'))

Binary file not shown.

Binary file not shown.

22
utils/dataset.py Normal file
View File

@ -0,0 +1,22 @@
import torch.utils.data as Data
import h5py
import numpy as np
import torch
class H5Dataset(Data.Dataset):
def __init__(self, h5file_path):
self.h5file_path = h5file_path
h5f = h5py.File(h5file_path, 'r')
self.keys = list(h5f['ir_patchs'].keys())
h5f.close()
def __len__(self):
return len(self.keys)
def __getitem__(self, index):
h5f = h5py.File(self.h5file_path, 'r')
key = self.keys[index]
IR = np.array(h5f['ir_patchs'][key])
VIS = np.array(h5f['vis_patchs'][key])
h5f.close()
return torch.Tensor(VIS), torch.Tensor(IR)

View File

@ -2,7 +2,7 @@ import torch
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
import numpy as np
class Fusionloss(nn.Module): class Fusionloss(nn.Module):
def __init__(self): def __init__(self):
@ -38,3 +38,17 @@ class Sobelxy(nn.Module):
sobelx=F.conv2d(x, self.weightx, padding=1) sobelx=F.conv2d(x, self.weightx, padding=1)
sobely=F.conv2d(x, self.weighty, padding=1) sobely=F.conv2d(x, self.weighty, padding=1)
return torch.abs(sobelx)+torch.abs(sobely) return torch.abs(sobelx)+torch.abs(sobely)
def cc(img1, img2):
eps = torch.finfo(torch.float32).eps
"""Correlation coefficient for (N, C, H, W) image; torch.float32 [0.,1.]."""
N, C, _, _ = img1.shape
img1 = img1.reshape(N, C, -1)
img2 = img2.reshape(N, C, -1)
img1 = img1 - img1.mean(dim=-1, keepdim=True)
img2 = img2 - img2.mean(dim=-1, keepdim=True)
cc = torch.sum(img1 * img2, dim=-1) / (eps + torch.sqrt(torch.sum(img1 **
2, dim=-1)) * torch.sqrt(torch.sum(img2**2, dim=-1)))
cc = torch.clamp(cc, -1., 1.)
return cc.mean()