Compare commits
No commits in common. "d32424f54aac32a1d696889839ac90572e305bca" and "3c78c4d8739a0c61445080ab0dd2e877d45f0f29" have entirely different histories.
d32424f54a
...
3c78c4d873
```diff
@@ -1 +0,0 @@
-Download the MSRS dataset from [this link](https://github.com/Linfeng-Tang/MSRS) and place it here.
```

README.md
```diff
@@ -9,7 +9,7 @@ Codes for ***CDDFuse: Correlation-Driven Dual-Branch Feature Decomposition for M
 ## Update
-- [2023/6] Training codes and config files are publicly available.
+- [2023/5] Training codes and config files will be publicly available before June.
 - [2023/4] Release inference code for infrared-visible image fusion and medical image fusion.
```
```diff
@@ -30,66 +30,21 @@ Codes for ***CDDFuse: Correlation-Driven Dual-Branch Feature Decomposition for M
 Multi-modality (MM) image fusion aims to render fused images that maintain the merits of different modalities, e.g., functional highlight and detailed textures. To tackle the challenge in modeling cross-modality features and decomposing desirable modality-specific and modality-shared features, we propose a novel Correlation-Driven feature Decomposition Fusion (CDDFuse) network. Firstly, CDDFuse uses Restormer blocks to extract cross-modality shallow features. We then introduce a dual-branch Transformer-CNN feature extractor with Lite Transformer (LT) blocks leveraging long-range attention to handle low-frequency global features and Invertible Neural Networks (INN) blocks focusing on extracting high-frequency local information. A correlation-driven loss is further proposed to make the low-frequency features correlated while the high-frequency features uncorrelated based on the embedded information. Then, the LT-based global fusion and INN-based local fusion layers output the fused image. Extensive experiments demonstrate that our CDDFuse achieves promising results in multiple fusion tasks, including infrared-visible image fusion and medical image fusion. We also show that CDDFuse can boost the performance in downstream infrared-visible semantic segmentation and object detection in a unified benchmark.
 
-## 🌐 Usage
+## Usage
 
-### ⚙ Network Architecture
+### Network Architecture
 
 Our CDDFuse is implemented in ``net.py``.
 
-### 🏊 Training
+### Testing
```
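The correlation-driven loss described in the abstract above appears concretely in ``train.py`` further down this diff, which combines the base and detail correlations as ``loss_decomp = cc_D**2 / (1.01 + cc_B)``. A minimal sketch of that decomposition term, assuming this repo is importable; the feature shapes are illustrative stand-ins for the encoder's base and detail branch outputs:

```python
import torch
from utils.loss import cc  # correlation helper, shown at the bottom of this diff

# Illustrative base (low-frequency) and detail (high-frequency) feature maps.
feature_V_B, feature_I_B = torch.randn(2, 8, 64, 32, 32)  # visible / infrared base
feature_V_D, feature_I_D = torch.randn(2, 8, 64, 32, 32)  # visible / infrared detail

cc_loss_B = cc(feature_V_B, feature_I_B)  # driven up: base features should correlate
cc_loss_D = cc(feature_V_D, feature_I_D)  # driven down: detail features should not
loss_decomp = cc_loss_D ** 2 / (1.01 + cc_loss_B)  # as in train.py, Phases I and II
```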
````diff
-**1. Virtual Environment**
-
-```
-# create virtual environment
-conda create -n cddfuse python=3.8.10
-conda activate cddfuse
-# select pytorch version yourself
-# install cddfuse requirements
-pip install -r requirements.txt
-```
-
-**2. Data Preparation**
-
-Download the MSRS dataset from [this link](https://github.com/Linfeng-Tang/MSRS) and place it in the folder ``'./MSRS_train/'``.
-
-**3. Pre-Processing**
-
-Run
-```
-python dataprocessing.py
-```
-and the processed training dataset is in ``'./data/MSRS_train_imgsize_128_stride_200.h5'``.
-
-**4. CDDFuse Training**
-
-Run
-```
-python train.py
-```
-and the trained model is available in ``'./models/'``.
````
````diff
-### 🏄 Testing
-
-**1. Pretrained models**
-
 Pretrained models are available in ``'./models/CDDFuse_IVF.pth'`` and ``'./models/CDDFuse_MIF.pth'``, which are responsible for the Infrared-Visible Fusion (IVF) and Medical Image Fusion (MIF) tasks, respectively.
 
-**2. Test datasets**
-
 The test datasets used in the paper have been stored in ``'./test_img/RoadScene'``, ``'./test_img/TNO'`` for IVF, ``'./test_img/MRI_CT'``, ``'./test_img/MRI_PET'`` and ``'./test_img/MRI_SPECT'`` for MIF.
 
 Unfortunately, since the size of **MSRS dataset** for IVF is 500+MB, we can not upload it for exhibition. It can be downloaded via [this link](https://github.com/Linfeng-Tang/MSRS). The other datasets contain all the test images.
 
-**3. Results in Our Paper**
-
-If you want to infer with our CDDFuse and obtain the fusion results in our paper, please run
-```
-python test_IVF.py
-```
-for Infrared-Visible Fusion and
-```
-python test_MIF.py
-```
-for Medical Image Fusion.
+If you want to infer with our CDDFuse and obtain the fusion results in our paper, please run ``'test_IVF.py'`` for IVF and ``'test_MIF.py'`` for MIF.
 
 The testing results will be printed in the terminal.
````
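As an aside on the checkpoint layout: ``train.py`` later in this diff saves a dict with keys ``DIDF_Encoder``, ``DIDF_Decoder``, ``BaseFuseLayer``, and ``DetailFuseLayer``, with every module wrapped in ``nn.DataParallel``. A minimal sketch of restoring the pretrained weights named above, assuming ``CDDFuse_IVF.pth`` follows that same layout and ``net.py`` from this repo is importable:

```python
import torch
import torch.nn as nn
from net import Restormer_Encoder, Restormer_Decoder, BaseFeatureExtraction, DetailFeatureExtraction

device = 'cuda' if torch.cuda.is_available() else 'cpu'
ckpt = torch.load('models/CDDFuse_IVF.pth', map_location=device)

# Wrap in DataParallel before loading so state-dict keys ('module.' prefix) match.
encoder = nn.DataParallel(Restormer_Encoder()).to(device)
decoder = nn.DataParallel(Restormer_Decoder()).to(device)
base_fuse = nn.DataParallel(BaseFeatureExtraction(dim=64, num_heads=8)).to(device)
detail_fuse = nn.DataParallel(DetailFeatureExtraction(num_layers=1)).to(device)

encoder.load_state_dict(ckpt['DIDF_Encoder'])
decoder.load_state_dict(ckpt['DIDF_Decoder'])
base_fuse.load_state_dict(ckpt['BaseFuseLayer'])
detail_fuse.load_state_dict(ckpt['DetailFuseLayer'])
for m in (encoder, decoder, base_fuse, detail_fuse):
    m.eval()
```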
````diff
@@ -136,7 +91,32 @@ CDDFuse_MIF 3.9 58.31 20.87 2.49 1.35 0.97 0.78 1.48
 ```
 which can match the results in Table 5 in our original paper.
 
-## 🙌 CDDFuse
+### Training
+
+**1. Virtual Environment**
+
+```
+# create virtual environment
+conda create -n cddfuse python=3.8.10
+conda activate cddfuse
+# select pytorch version yourself
+# install cddfuse requirements
+pip install -r requirements.txt
+```
+
+**2. Data Preparation**
+
+Download the MSRS dataset from [this link](https://github.com/Linfeng-Tang/MSRS) and place it in the folder ``'./MSRS_train/'``.
+
+**3. Pre-Processing**
+
+Run ```python dataprocessing.py``` and the processed training dataset is in ``'./data/MSRS_train_imgsize_128_stride_200.h5'``.
+
+**4. CDDFuse Training**
+
+Run ```python train.py``` and the trained model is available in ``'./models/'``.
+
+## CDDFuse
 
 ### Illustration of our CDDFuse model.
````
```diff
@@ -169,12 +149,12 @@ MM segmentation
 <img src="image//MMSeg.png" width="60%" align=center />
 
-## 📖 Related Work
+## Related Work
 
 - Zixiang Zhao, Haowen Bai, Jiangshe Zhang, Yulun Zhang, Kai Zhang, Shuang Xu, Dongdong Chen, Radu Timofte, Luc Van Gool. *Equivariant Multi-Modality Image Fusion.* **arXiv:2305.11443**, https://arxiv.org/abs/2305.11443
 
 - Zixiang Zhao, Haowen Bai, Yuanzhi Zhu, Jiangshe Zhang, Shuang Xu, Yulun Zhang, Kai Zhang, Deyu Meng, Radu Timofte, Luc Van Gool.
-  *DDFM: Denoising Diffusion Model for Multi-Modality Image Fusion.* **ICCV 2023**, https://arxiv.org/abs/2303.06840
+  *DDFM: Denoising Diffusion Model for Multi-Modality Image Fusion.* **arXiv:2303.06840**, https://arxiv.org/abs/2303.06840
 
 - Zixiang Zhao, Shuang Xu, Chunxia Zhang, Junmin Liu, Jiangshe Zhang and Pengfei Li. *DIDFuse: Deep Image Decomposition for Infrared and Visible Image Fusion.* **IJCAI 2020**, https://www.ijcai.org/Proceedings/2020/135.
```
dataprocessing.py (deleted)
@@ -1,93 +0,0 @@
```python
import os
import h5py
import numpy as np
from tqdm import tqdm
from skimage.io import imread


def get_img_file(file_name):
    # Recursively collect all image (and .npy) files under a directory.
    imagelist = []
    for parent, dirnames, filenames in os.walk(file_name):
        for filename in filenames:
            if filename.lower().endswith(('.bmp', '.dib', '.png', '.jpg', '.jpeg', '.pbm',
                                          '.pgm', '.ppm', '.tif', '.tiff', '.npy')):
                imagelist.append(os.path.join(parent, filename))
    return imagelist


def rgb2y(img):
    # BT.601 luma transform: Y = 0.299 R + 0.587 G + 0.114 B, on a (3, H, W) array.
    y = img[0:1, :, :] * 0.299000 + img[1:2, :, :] * 0.587000 + img[2:3, :, :] * 0.114000
    return y


def Im2Patch(img, win, stride=1):
    # Slice a (C, H, W) image into a (C, win, win, TotalPatNum) patch stack.
    k = 0
    endc = img.shape[0]
    endw = img.shape[1]
    endh = img.shape[2]
    patch = img[:, 0:endw - win + 1:stride, 0:endh - win + 1:stride]
    TotalPatNum = patch.shape[1] * patch.shape[2]
    Y = np.zeros([endc, win * win, TotalPatNum], np.float32)
    for i in range(win):
        for j in range(win):
            patch = img[:, i:endw - win + i + 1:stride, j:endh - win + j + 1:stride]
            Y[:, k, :] = np.array(patch[:]).reshape(endc, TotalPatNum)
            k = k + 1
    return Y.reshape([endc, win, win, TotalPatNum])


def is_low_contrast(image, fraction_threshold=0.1, lower_percentile=10,
                    upper_percentile=90):
    """Determine if an image is low contrast."""
    limits = np.percentile(image, [lower_percentile, upper_percentile])
    ratio = (limits[1] - limits[0]) / limits[1]
    return ratio < fraction_threshold


data_name = "MSRS_train"
img_size = 128  # patch size
stride = 200    # patch stride

IR_files = sorted(get_img_file(r"MSRS_train/ir"))
VIS_files = sorted(get_img_file(r"MSRS_train/vi"))

assert len(IR_files) == len(VIS_files)
h5f = h5py.File(os.path.join('data',
                             data_name + '_imgsize_' + str(img_size) + "_stride_" + str(stride) + '.h5'),
                'w')
h5_ir = h5f.create_group('ir_patchs')
h5_vis = h5f.create_group('vis_patchs')
train_num = 0
for i in tqdm(range(len(IR_files))):
    I_VIS = imread(VIS_files[i]).astype(np.float32).transpose(2, 0, 1) / 255.  # [3, H, W] uint8 -> float32
    I_VIS = rgb2y(I_VIS)  # [1, H, W] float32
    I_IR = imread(IR_files[i]).astype(np.float32)[None, :, :] / 255.  # [1, H, W] float32

    # Crop both modalities into aligned patches.
    I_IR_Patch_Group = Im2Patch(I_IR, img_size, stride)
    I_VIS_Patch_Group = Im2Patch(I_VIS, img_size, stride)

    for ii in range(I_IR_Patch_Group.shape[-1]):
        # Keep only patch pairs where both modalities have sufficient contrast.
        bad_IR = is_low_contrast(I_IR_Patch_Group[0, :, :, ii])
        bad_VIS = is_low_contrast(I_VIS_Patch_Group[0, :, :, ii])
        if not (bad_IR or bad_VIS):
            avl_IR = I_IR_Patch_Group[0, :, :, ii]  # available IR
            avl_VIS = I_VIS_Patch_Group[0, :, :, ii]
            avl_IR = avl_IR[None, ...]
            avl_VIS = avl_VIS[None, ...]

            h5_ir.create_dataset(str(train_num), data=avl_IR,
                                 dtype=avl_IR.dtype, shape=avl_IR.shape)
            h5_vis.create_dataset(str(train_num), data=avl_VIS,
                                  dtype=avl_VIS.dtype, shape=avl_VIS.shape)
            train_num += 1

h5f.close()

# Sanity check: list the groups written above.
with h5py.File(os.path.join('data',
                            data_name + '_imgsize_' + str(img_size) + "_stride_" + str(stride) + '.h5'), "r") as f:
    for key in f.keys():
        print(f[key], key, f[key].name)
```
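As a quick sanity check on ``Im2Patch``, assuming the function above is in scope: for a hypothetical 480×640 single-channel image with ``win=128`` and ``stride=200``, there are 2 valid row offsets and 3 valid column offsets, so 6 patches:

```python
import numpy as np

# Hypothetical grayscale image in (C, H, W) layout; values are arbitrary.
img = np.random.rand(1, 480, 640).astype(np.float32)

patches = Im2Patch(img, win=128, stride=200)
# Rows:    starts 0, 200        -> (480 - 128) // 200 + 1 = 2
# Columns: starts 0, 200, 400   -> (640 - 128) // 200 + 1 = 3
print(patches.shape)  # (1, 128, 128, 6)
```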
Binary file not shown.
test.py (deleted)
@@ -1,132 +0,0 @@
```python
import argparse
import os
import sys
import warnings
import logging

import numpy as np
import torch
import torch.nn as nn

from net import Sar_Restormer_Encoder, Vi_Restormer_Encoder, Restormer_Decoder, BaseFeatureExtraction, DetailFeatureExtraction
from utils.Evaluator import Evaluator
from utils.img_read_save import img_save, image_read_cv2

warnings.filterwarnings("ignore")
logging.basicConfig(level=logging.CRITICAL)

# Resolve the checkpoint relative to the script's own directory.
path = os.path.dirname(os.path.abspath(sys.argv[0]))

os.environ["CUDA_VISIBLE_DEVICES"] = "0"
# ckpt_path = r"models/CDDFuse_IVF.pth"
ckpt_path = os.path.join(path, "models", "CDDFuse_04-10-11-56.pth")

print(torch.cuda.is_available())


def main(opt):
    # Example invocation:
    # --viPath D:\PythonProject\MMIF-CDDFuse\test_img\Test\vi\ir_2.png --irPath D:\PythonProject\MMIF-CDDFuse\test_img\Test\ir\ir_2.png --outputPath D:\PythonProject\MMIF-CDDFuse\test_img\Test\
    ir_path = opt.irPath
    vi_path = opt.viPath
    output_path = opt.outputPath

    print("\n" * 2 + "=" * 80)
    model_name = "CDDFuse "
    print("IR input:  " + ir_path)
    print("VIS input: " + vi_path)

    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    SAR_Encoder = nn.DataParallel(Sar_Restormer_Encoder()).to(device)
    VI_Encoder = nn.DataParallel(Vi_Restormer_Encoder()).to(device)
    Decoder = nn.DataParallel(Restormer_Decoder()).to(device)
    BaseFuseLayer = nn.DataParallel(BaseFeatureExtraction(dim=64, num_heads=8)).to(device)
    DetailFuseLayer = nn.DataParallel(DetailFeatureExtraction(num_layers=1)).to(device)

    ckpt = torch.load(ckpt_path)
    SAR_Encoder.load_state_dict(ckpt['SAR_DIDF_Encoder'])
    VI_Encoder.load_state_dict(ckpt['VI_DIDF_Encoder'])
    Decoder.load_state_dict(ckpt['DIDF_Decoder'])
    BaseFuseLayer.load_state_dict(ckpt['BaseFuseLayer'])
    DetailFuseLayer.load_state_dict(ckpt['DetailFuseLayer'])

    SAR_Encoder.eval()
    VI_Encoder.eval()
    Decoder.eval()
    BaseFuseLayer.eval()
    DetailFuseLayer.eval()

    with torch.no_grad():
        data_IR = image_read_cv2(ir_path, mode='GRAY')[np.newaxis, np.newaxis, ...] / 255.0
        data_VIS = image_read_cv2(vi_path, mode='GRAY')[np.newaxis, np.newaxis, ...] / 255.0

        data_IR, data_VIS = torch.FloatTensor(data_IR), torch.FloatTensor(data_VIS)
        data_VIS, data_IR = data_VIS.to(device), data_IR.to(device)

        # Encode each modality, fuse the base/detail features, then decode.
        feature_V_B, feature_V_D, feature_V = VI_Encoder(data_VIS)
        feature_I_B, feature_I_D, feature_I = SAR_Encoder(data_IR)
        feature_F_B = BaseFuseLayer(feature_V_B + feature_I_B)
        feature_F_D = DetailFuseLayer(feature_V_D + feature_I_D)
        data_Fuse, _ = Decoder(data_VIS, feature_F_B, feature_F_D)
        # Min-max normalize the fused image to [0, 1].
        data_Fuse = (data_Fuse - torch.min(data_Fuse)) / (torch.max(data_Fuse) - torch.min(data_Fuse))
        fi = np.squeeze((data_Fuse * 255).cpu().numpy())

    # Get the file name (with extension), then split off the extension.
    file_name_with_extension = os.path.basename(ir_path)
    file_name, file_extension = os.path.splitext(file_name_with_extension)

    img_save(fi, "fusion_" + file_name, output_path)
    print("Output file path: " + output_path + "fusion_" + file_name + ".png")

    metric_result = np.zeros((8))
    ir = image_read_cv2(ir_path, 'GRAY')
    vi = image_read_cv2(vi_path, 'GRAY')

    fusionImagePath = os.path.join(output_path, "fusion_{}.png".format(file_name))
    fi = image_read_cv2(fusionImagePath, 'GRAY')
    # Accumulate the eight fusion metrics, then average over the output folder.
    metric_result += np.array([Evaluator.EN(fi), Evaluator.SD(fi),
                               Evaluator.SF(fi), Evaluator.MI(fi, ir, vi),
                               Evaluator.SCD(fi, ir, vi), Evaluator.VIFF(fi, ir, vi),
                               Evaluator.Qabf(fi, ir, vi), Evaluator.SSIM(fi, ir, vi)])

    metric_result /= len(os.listdir(output_path))
    print("\t\t EN\t SD\t SF\t MI\tSCD\tVIF\tQabf\tSSIM")
    print("Evaluation results: " + model_name + '\t'
          + '\t'.join(str(np.round(metric_result[k], 2)) for k in range(8)))
    print("=" * 80)


def parse_opt():
    parser = argparse.ArgumentParser(
        description='python test.py --irPath "<absolute IR/SAR image path>" --viPath "<visible image path>" --outputPath "<output directory>"')
    parser.add_argument('--irPath', type=str, required=False,
                        default="D:\\PythonProject\\MMIF-CDDFuse\\test_img\\cus\\sar\\NH49E001013_10.tif",
                        help="absolute path to the infrared/SAR input image")
    parser.add_argument('--viPath', type=str, required=False,
                        default="D:\\PythonProject\\MMIF-CDDFuse\\test_img\\cus\\opr\\NH49E001013_10.tif",
                        help="absolute path to the visible/optical input image")
    parser.add_argument('--outputPath', type=str, default='results_detect', required=False,
                        help="output directory for the fused image")
    opt = parser.parse_args()
    return opt


if __name__ == '__main__':
    opt = parse_opt()
    main(opt)
```
train.py (deleted)
@@ -1,219 +0,0 @@
```python
# -*- coding: utf-8 -*-

'''
------------------------------------------------------------------------------
Import packages
------------------------------------------------------------------------------
'''

from net import Restormer_Encoder, Restormer_Decoder, BaseFeatureExtraction, DetailFeatureExtraction
from utils.dataset import H5Dataset
import os
os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'
import sys
import time
import datetime
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from utils.loss import Fusionloss, cc
import kornia

'''
------------------------------------------------------------------------------
Configure our network
------------------------------------------------------------------------------
'''

os.environ['CUDA_VISIBLE_DEVICES'] = '0'
criteria_fusion = Fusionloss()
model_str = 'CDDFuse'

# Set the hyperparameters for training
num_epochs = 120  # total epochs
epoch_gap = 40    # epochs of Phase I (reconstruction pre-training)

lr = 1e-4
weight_decay = 0
batch_size = 8
GPU_number = os.environ['CUDA_VISIBLE_DEVICES']
# Coefficients of the loss function
coeff_mse_loss_VF = 1.  # alpha1
coeff_mse_loss_IF = 1.
coeff_decomp = 2.       # alpha2 and alpha4
coeff_tv = 5.

clip_grad_norm_value = 0.01
optim_step = 20
optim_gamma = 0.5

# Model
device = 'cuda' if torch.cuda.is_available() else 'cpu'
DIDF_Encoder = nn.DataParallel(Restormer_Encoder()).to(device)
DIDF_Decoder = nn.DataParallel(Restormer_Decoder()).to(device)
BaseFuseLayer = nn.DataParallel(BaseFeatureExtraction(dim=64, num_heads=8)).to(device)
DetailFuseLayer = nn.DataParallel(DetailFeatureExtraction(num_layers=1)).to(device)

# Optimizers, schedulers and loss functions
optimizer1 = torch.optim.Adam(DIDF_Encoder.parameters(), lr=lr, weight_decay=weight_decay)
optimizer2 = torch.optim.Adam(DIDF_Decoder.parameters(), lr=lr, weight_decay=weight_decay)
optimizer3 = torch.optim.Adam(BaseFuseLayer.parameters(), lr=lr, weight_decay=weight_decay)
optimizer4 = torch.optim.Adam(DetailFuseLayer.parameters(), lr=lr, weight_decay=weight_decay)

scheduler1 = torch.optim.lr_scheduler.StepLR(optimizer1, step_size=optim_step, gamma=optim_gamma)
scheduler2 = torch.optim.lr_scheduler.StepLR(optimizer2, step_size=optim_step, gamma=optim_gamma)
scheduler3 = torch.optim.lr_scheduler.StepLR(optimizer3, step_size=optim_step, gamma=optim_gamma)
scheduler4 = torch.optim.lr_scheduler.StepLR(optimizer4, step_size=optim_step, gamma=optim_gamma)

MSELoss = nn.MSELoss()
L1Loss = nn.L1Loss()
Loss_ssim = kornia.losses.SSIM(11, reduction='mean')

# Data loader
trainloader = DataLoader(H5Dataset(r"data/MSRS_train_imgsize_128_stride_200.h5"),
                         batch_size=batch_size,
                         shuffle=True,
                         num_workers=0)

loader = {'train': trainloader, }
timestamp = datetime.datetime.now().strftime("%m-%d-%H-%M")

'''
------------------------------------------------------------------------------
Train
------------------------------------------------------------------------------
'''

step = 0
torch.backends.cudnn.benchmark = True
prev_time = time.time()

for epoch in range(num_epochs):
    ''' train '''
    for i, (data_VIS, data_IR) in enumerate(loader['train']):
        data_VIS, data_IR = data_VIS.to(device), data_IR.to(device)
        DIDF_Encoder.train()
        DIDF_Decoder.train()
        BaseFuseLayer.train()
        DetailFuseLayer.train()

        DIDF_Encoder.zero_grad()
        DIDF_Decoder.zero_grad()
        BaseFuseLayer.zero_grad()
        DetailFuseLayer.zero_grad()

        optimizer1.zero_grad()
        optimizer2.zero_grad()
        optimizer3.zero_grad()
        optimizer4.zero_grad()

        if epoch < epoch_gap:  # Phase I: train encoder/decoder to reconstruct each modality
            feature_V_B, feature_V_D, _ = DIDF_Encoder(data_VIS)
            feature_I_B, feature_I_D, _ = DIDF_Encoder(data_IR)
            data_VIS_hat, _ = DIDF_Decoder(data_VIS, feature_V_B, feature_V_D)
            data_IR_hat, _ = DIDF_Decoder(data_IR, feature_I_B, feature_I_D)

            cc_loss_B = cc(feature_V_B, feature_I_B)
            cc_loss_D = cc(feature_V_D, feature_I_D)
            mse_loss_V = 5 * Loss_ssim(data_VIS, data_VIS_hat) + MSELoss(data_VIS, data_VIS_hat)
            mse_loss_I = 5 * Loss_ssim(data_IR, data_IR_hat) + MSELoss(data_IR, data_IR_hat)

            Gradient_loss = L1Loss(kornia.filters.SpatialGradient()(data_VIS),
                                   kornia.filters.SpatialGradient()(data_VIS_hat))

            # Correlation-driven decomposition loss: decorrelate the detail
            # (high-frequency) features, keep the base (low-frequency) features correlated.
            loss_decomp = (cc_loss_D) ** 2 / (1.01 + cc_loss_B)

            loss = coeff_mse_loss_VF * mse_loss_V + coeff_mse_loss_IF * \
                mse_loss_I + coeff_decomp * loss_decomp + coeff_tv * Gradient_loss

            loss.backward()
            nn.utils.clip_grad_norm_(
                DIDF_Encoder.parameters(), max_norm=clip_grad_norm_value, norm_type=2)
            nn.utils.clip_grad_norm_(
                DIDF_Decoder.parameters(), max_norm=clip_grad_norm_value, norm_type=2)
            optimizer1.step()
            optimizer2.step()
        else:  # Phase II: train the full fusion pipeline
            feature_V_B, feature_V_D, feature_V = DIDF_Encoder(data_VIS)
            feature_I_B, feature_I_D, feature_I = DIDF_Encoder(data_IR)
            feature_F_B = BaseFuseLayer(feature_I_B + feature_V_B)
            feature_F_D = DetailFuseLayer(feature_I_D + feature_V_D)
            data_Fuse, feature_F = DIDF_Decoder(data_VIS, feature_F_B, feature_F_D)

            mse_loss_V = 5 * Loss_ssim(data_VIS, data_Fuse) + MSELoss(data_VIS, data_Fuse)
            mse_loss_I = 5 * Loss_ssim(data_IR, data_Fuse) + MSELoss(data_IR, data_Fuse)

            cc_loss_B = cc(feature_V_B, feature_I_B)
            cc_loss_D = cc(feature_V_D, feature_I_D)
            loss_decomp = (cc_loss_D) ** 2 / (1.01 + cc_loss_B)
            fusionloss, _, _ = criteria_fusion(data_VIS, data_IR, data_Fuse)

            loss = fusionloss + coeff_decomp * loss_decomp
            loss.backward()
            nn.utils.clip_grad_norm_(
                DIDF_Encoder.parameters(), max_norm=clip_grad_norm_value, norm_type=2)
            nn.utils.clip_grad_norm_(
                DIDF_Decoder.parameters(), max_norm=clip_grad_norm_value, norm_type=2)
            nn.utils.clip_grad_norm_(
                BaseFuseLayer.parameters(), max_norm=clip_grad_norm_value, norm_type=2)
            nn.utils.clip_grad_norm_(
                DetailFuseLayer.parameters(), max_norm=clip_grad_norm_value, norm_type=2)
            optimizer1.step()
            optimizer2.step()
            optimizer3.step()
            optimizer4.step()

        # Determine approximate time left
        batches_done = epoch * len(loader['train']) + i
        batches_left = num_epochs * len(loader['train']) - batches_done
        time_left = datetime.timedelta(seconds=batches_left * (time.time() - prev_time))
        prev_time = time.time()
        sys.stdout.write(
            "\r[Epoch %d/%d] [Batch %d/%d] [loss: %f] ETA: %.10s"
            % (epoch, num_epochs, i, len(loader['train']), loss.item(), time_left))

    # Adjust the learning rates; the fusion layers only train in Phase II
    scheduler1.step()
    scheduler2.step()
    if not epoch < epoch_gap:
        scheduler3.step()
        scheduler4.step()

    # Floor all learning rates at 1e-6
    for optim in (optimizer1, optimizer2, optimizer3, optimizer4):
        if optim.param_groups[0]['lr'] <= 1e-6:
            optim.param_groups[0]['lr'] = 1e-6

    # Save a checkpoint every epoch
    checkpoint = {
        'DIDF_Encoder': DIDF_Encoder.state_dict(),
        'DIDF_Decoder': DIDF_Decoder.state_dict(),
        'BaseFuseLayer': BaseFuseLayer.state_dict(),
        'DetailFuseLayer': DetailFuseLayer.state_dict(),
    }
    torch.save(checkpoint, os.path.join("models", "CDDFuse_" + timestamp + ".pth"))
```
Binary files not shown.
utils/dataset.py (deleted)
@@ -1,22 +0,0 @@
```python
import torch.utils.data as Data
import h5py
import numpy as np
import torch


class H5Dataset(Data.Dataset):
    """Serves the (VIS, IR) patch pairs written by dataprocessing.py."""

    def __init__(self, h5file_path):
        self.h5file_path = h5file_path
        # Read the dataset keys once; the file is reopened on every access.
        h5f = h5py.File(h5file_path, 'r')
        self.keys = list(h5f['ir_patchs'].keys())
        h5f.close()

    def __len__(self):
        return len(self.keys)

    def __getitem__(self, index):
        h5f = h5py.File(self.h5file_path, 'r')
        key = self.keys[index]
        IR = np.array(h5f['ir_patchs'][key])
        VIS = np.array(h5f['vis_patchs'][key])
        h5f.close()
        return torch.Tensor(VIS), torch.Tensor(IR)
```
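A minimal sketch of consuming this dataset, mirroring the ``DataLoader`` setup in ``train.py`` above; the ``.h5`` path is the one produced by ``dataprocessing.py``:

```python
from torch.utils.data import DataLoader
from utils.dataset import H5Dataset

# Each batch is a (VIS, IR) pair of shape (batch, 1, 128, 128).
loader = DataLoader(H5Dataset("data/MSRS_train_imgsize_128_stride_200.h5"),
                    batch_size=8, shuffle=True, num_workers=0)
data_VIS, data_IR = next(iter(loader))
print(data_VIS.shape, data_IR.shape)
```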
utils/loss.py

```diff
@@ -2,7 +2,7 @@ import torch
 import torch.nn as nn
 import torch.nn.functional as F
-
+import numpy as np
 
 class Fusionloss(nn.Module):
     def __init__(self):
@@ -38,17 +38,3 @@ class Sobelxy(nn.Module):
         sobelx=F.conv2d(x, self.weightx, padding=1)
         sobely=F.conv2d(x, self.weighty, padding=1)
         return torch.abs(sobelx)+torch.abs(sobely)
-
-
-def cc(img1, img2):
-    """Correlation coefficient for (N, C, H, W) images; torch.float32 in [0., 1.]."""
-    eps = torch.finfo(torch.float32).eps
-    N, C, _, _ = img1.shape
-    img1 = img1.reshape(N, C, -1)
-    img2 = img2.reshape(N, C, -1)
-    img1 = img1 - img1.mean(dim=-1, keepdim=True)
-    img2 = img2 - img2.mean(dim=-1, keepdim=True)
-    cc = torch.sum(img1 * img2, dim=-1) / (eps + torch.sqrt(torch.sum(img1 ** 2, dim=-1))
-                                           * torch.sqrt(torch.sum(img2 ** 2, dim=-1)))
-    cc = torch.clamp(cc, -1., 1.)
-    return cc.mean()
```
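A quick sanity check of the deleted ``cc`` helper, assuming the function as defined above: identical feature maps correlate at 1, negated copies at −1, and independent noise lands near 0.

```python
import torch

x = torch.randn(4, 16, 32, 32)      # arbitrary (N, C, H, W) features
print(cc(x, x))                     # ~1.0: perfectly correlated
print(cc(x, -x))                    # ~-1.0: perfectly anti-correlated
print(cc(x, torch.randn_like(x)))   # near 0 for unrelated features
```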