模型结构

DetailFeatureExtraction增加了一个增强残差
BaseFeatureExtraction增加了
x = self.WTConv2d(x)
This commit is contained in:
whaifree 2024-10-07 15:24:33 +08:00
parent af0a9f358c
commit afd55abe9e
5 changed files with 289 additions and 5 deletions

42
componets/whaiutil.py Normal file
View File

@ -0,0 +1,42 @@
import os
from PIL import Image
def transfer(input_path, quality=20, resize_factor=0.1):
# 打开TIFF图像
# img = Image.open(input_path)
#
# # 保存为JPEG并设置压缩质量
# img.save(output_path, 'JPEG', quality=quality)
# input_path = os.path.join(input_folder, filename)
# 获取input_path的文件名
# 使用os.path.splitext获取文件名和后缀的元组
# 使用os.path.basename获取文件名包含后缀
filename_with_extension = os.path.basename(input_path)
filename, file_extension = os.path.splitext(filename_with_extension)
# 使用os.path.dirname获取文件所在的目录路径
output_folder = os.path.dirname(input_path)
output_path = os.path.join(output_folder, filename + '.jpg')
img = Image.open(input_path)
# 将图像缩小到原来的一半
new_width = int(img.width * resize_factor)
new_height = int(img.height * resize_factor)
resized_img = img.resize((new_width, new_height))
# 保存为JPEG并设置压缩质量
# 转换为RGB模式丢弃透明通道
rgb_img = resized_img.convert('RGB')
# 保存为JPEG并设置压缩质量
# 压缩
rgb_img.save(output_path, 'JPEG', quality=quality)
print(f'{output_path} 转换完成')
return output_path

33
logs/20241007_whai.log Normal file
View File

@ -0,0 +1,33 @@
/home/star/anaconda3/envs/pfcfuse/bin/python /home/star/whaiDir/PFCFuse/test_IVF.py
================================================================================
The test result of TNO :
19.png
05.png
21.png
18.png
15.png
22.png
14.png
13.png
08.png
01.png
02.png
03.png
25.png
17.png
11.png
16.png
06.png
07.png
09.png
10.png
12.png
23.png
24.png
20.png
04.png
EN SD SF MI SCD VIF Qabf SSIM
PFCFuse 7.01 40.67 15.39 1.53 1.76 0.64 0.53 0.95
================================================================================

View File

@ -13,7 +13,7 @@ logging.basicConfig(level=logging.CRITICAL)
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
ckpt_path= r"/home/star/whaiDir/PFCFuse/models/PFCFusion10-05-20-46.pth"
ckpt_path= r"/home/star/whaiDir/PFCFuse/models/whaiFusion10-06-22-17.pth"
for dataset_name in ["TNO"]:
print("\n"*2+"="*80)

201
test_sar.py Normal file
View File

@ -0,0 +1,201 @@
import argparse
import sys
import uuid
import cv2
from PIL import Image
from net import Restormer_Encoder, Restormer_Decoder, BaseFeatureExtraction, DetailFeatureExtraction
import os
import numpy as np
from utils.Evaluator import Evaluator
import torch
import torch.nn as nn
from utils.img_read_save import img_save, image_read_cv2
import warnings
import logging
warnings.filterwarnings("ignore")
logging.basicConfig(level=logging.CRITICAL)
path = os.path.dirname(sys.argv[0]) + "\\"
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
# ckpt_path=r"models/CDDFuse_IVF.pth"
ckpt_path = r"" + path + "models/CDDFuse_IVF.pth"
print(torch.cuda.is_available())
def main(opt):
# --viPath D:\PythonProject\MMIF-CDDFuse\test_img\Test\vi\ir_2.png --irPath D:\PythonProject\MMIF-CDDFuse\test_img\Test\ir\ir_2.png --outputPath D:\PythonProject\MMIF-CDDFuse\test_img\Test\
sar_path = transfer(opt.sarPath, 100, 0.15)
vi_path = transfer(opt.viPath, 100, 0.15)
output_path = opt.outputPath
print("\n" * 2 + "=" * 80)
print("The sar_path of " + sar_path + ' :')
print("The vi_path of " + vi_path + ' :')
device = 'cuda' if torch.cuda.is_available() else 'cpu'
Encoder = nn.DataParallel(Restormer_Encoder()).to(device)
Decoder = nn.DataParallel(Restormer_Decoder()).to(device)
BaseFuseLayer = nn.DataParallel(BaseFeatureExtraction(dim=64, num_heads=8)).to(device)
DetailFuseLayer = nn.DataParallel(DetailFeatureExtraction(num_layers=1)).to(device)
Encoder.load_state_dict(torch.load(ckpt_path)['DIDF_Encoder'])
Decoder.load_state_dict(torch.load(ckpt_path)['DIDF_Decoder'])
BaseFuseLayer.load_state_dict(torch.load(ckpt_path)['BaseFuseLayer'])
DetailFuseLayer.load_state_dict(torch.load(ckpt_path)['DetailFuseLayer'])
Encoder.eval()
Decoder.eval()
BaseFuseLayer.eval()
DetailFuseLayer.eval()
with torch.no_grad():
data_IR = image_read_cv2(sar_path, mode='GRAY')[np.newaxis, np.newaxis, ...] / 255.0
data_VIS = image_read_cv2(vi_path, mode='GRAY')[np.newaxis, np.newaxis, ...] / 255.0
data_VIS_BGR = cv2.imread(vi_path)
_, data_VIS_Cr, data_VIS_Cb = cv2.split(cv2.cvtColor(data_VIS_BGR, cv2.COLOR_BGR2YCrCb))
data_IR, data_VIS = torch.FloatTensor(data_IR), torch.FloatTensor(data_VIS)
data_VIS, data_IR = data_VIS.cuda(), data_IR.cuda()
feature_V_B, feature_V_D, feature_V = Encoder(data_VIS)
feature_I_B, feature_I_D, feature_I = Encoder(data_IR)
feature_F_B = BaseFuseLayer(feature_V_B + feature_I_B)
feature_F_D = DetailFuseLayer(feature_V_D + feature_I_D)
data_Fuse, _ = Decoder(data_VIS, feature_F_B, feature_F_D)
data_Fuse = (data_Fuse - torch.min(data_Fuse)) / (torch.max(data_Fuse) - torch.min(data_Fuse))
fi = np.squeeze((data_Fuse * 255).cpu().numpy())
fi = fi.astype(np.uint8)
ycrcb_fi = np.dstack((fi, data_VIS_Cr, data_VIS_Cb))
rgb_fi = cv2.cvtColor(ycrcb_fi, cv2.COLOR_YCrCb2RGB)
# 获取文件名(包含后缀)
file_name_with_extension = os.path.basename(sar_path)
# 分离文件名和文件后缀
file_name, file_extension = os.path.splitext(file_name_with_extension)
img_save(rgb_fi, "fusionSAR_" + file_name, output_path)
print("输出文件路径:" + output_path + "fusionSAR_" + file_name + ".jpg")
# metric_result = np.zeros((8))
# sarImagePath = sar_path
# ir = image_read_cv2(sarImagePath, 'GRAY')
# viImagePath = vi_path
# vi = image_read_cv2(viImagePath, 'GRAY')
#
# fusionImagePath = os.path.join(output_path, "fusionSAR_{}.jpg".format(file_name))
# fi = image_read_cv2(fusionImagePath, 'GRAY')
# # 统计
# metric_result += np.array([Evaluator.EN(fi), Evaluator.SD(fi)
# , Evaluator.SF(fi), Evaluator.MI(fi, ir, vi)
# , Evaluator.SCD(fi, ir, vi), Evaluator.VIFF(fi, ir, vi)
# , Evaluator.Qabf(fi, ir, vi), Evaluator.SSIM(fi, ir, vi)])
#
# metric_result /= len(os.listdir(output_path))
# print("\t\t EN\t SD\t SF\t MI\tSCD\tVIF\tQabf\tSSIM")
# print("对比结果:" + model_name + '\t' + str(np.round(metric_result[0], 2)) + '\t'
# + str(np.round(metric_result[1], 2)) + '\t'
# + str(np.round(metric_result[2], 2)) + '\t'
# + str(np.round(metric_result[3], 2)) + '\t'
# + str(np.round(metric_result[4], 2)) + '\t'
# + str(np.round(metric_result[5], 2)) + '\t'
# + str(np.round(metric_result[6], 2)) + '\t'
# + str(np.round(metric_result[7], 2))
# )
# print("=" * 80)
def transfer(input_path, quality=20, resize_factor=0.1):
# 打开TIFF图像
# img = Image.open(input_path)
#
# # 保存为JPEG并设置压缩质量
# img.save(output_path, 'JPEG', quality=quality)
# input_path = os.path.join(input_folder, filename)
# 获取input_path的文件名
# 使用os.path.splitext获取文件名和后缀的元组
# 使用os.path.basename获取文件名包含后缀
filename_with_extension = os.path.basename(input_path)
filename, file_extension = os.path.splitext(filename_with_extension)
# 使用os.path.dirname获取文件所在的目录路径
output_folder = os.path.dirname(input_path)
output_path = os.path.join(output_folder, filename + '.jpg')
img = Image.open(input_path)
# 将图像缩小到原来的一半
new_width = int(img.width * resize_factor)
new_height = int(img.height * resize_factor)
resized_img = img.resize((new_width, new_height))
# 保存为JPEG并设置压缩质量
# 转换为RGB模式丢弃透明通道
rgb_img = resized_img.convert('RGB')
# 保存为JPEG并设置压缩质量
# 压缩
rgb_img.save(output_path, 'JPEG', quality=quality)
print(f'{output_path} 转换完成')
return output_path
def parse_opt():
parser = argparse.ArgumentParser(
description='python.exe --sarPath "sar绝对路径" --viPath "可见光路径" --outputPath "输出文件路径"')
parser.add_argument('--sarPath', type=str, default="D:\\PythonProject\\MMIF-CDDFuse\\test_img\\Test\\ir\\NH49E011024.tif", required=True,
help="是否为多路径") # 这里全部都是使用图片的名字,默认是 项目路径 + 2.jpg
parser.add_argument('--viPath', type=str, default="D:\\PythonProject\\MMIF-CDDFuse\\test_img\\Test\\vi\\NH49E011024.tif", required=True,
help="完整目录路径!,可以为数组") # 这里全部都是使用图片的名字,默认是 项目路径 + 2.jpg
parser.add_argument('--outputPath', type=str, default='results_detect', required=True,
help="输出路径!") # 使用的也是图片的目标地址
opt = parser.parse_args()
return opt
if __name__ == '__main__':
print(torch.cuda.is_available())
opt = parse_opt()
main(opt)
def add_prefix_to_files(directory_path, prefix):
# 使用os.listdir获取目录中的所有文件
files = os.listdir(directory_path)
for old_filename in files:
# 构建新的文件名
new_filename = f"{prefix}_{old_filename}"
# 构建旧文件路径和新文件路径
old_path = os.path.join(directory_path, old_filename)
new_path = os.path.join(directory_path, new_filename)
# 使用os.rename进行文件重命名
os.rename(old_path, new_path)
print(f'{old_filename} 重命名为 {new_filename}')
# 替换为实际的目录路径和前缀
# directory_path = '/path/to/your/directory'
# new_prefix = 'new'
#
# # 执行批量重命名
# add_prefix_to_files(directory_path, new_prefix)

View File

@ -1,3 +1,4 @@
import os
import subprocess
import datetime
@ -8,8 +9,15 @@ command = "/home/star/anaconda3/envs/pfcfuse/bin/python /home/star/whaiDir/PFCFu
current_time = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
output_file = f"/home/star/whaiDir/PFCFuse/logs/log_{current_time}.log"
# 运行命令并将输出重定向到文件
with open(output_file, 'w') as file:
subprocess.run(command.split(), stdout=file, stderr=subprocess.STDOUT)
try:
# 运行命令并将输出重定向到文件
with open(output_file, 'w') as file:
result = subprocess.run(command.split(), stdout=file, stderr=subprocess.STDOUT, check=True)
print(f"Command output has been written to {output_file}")
# 如果命令成功执行,则打印确认信息
print(f"Command executed successfully. Output has been written to {output_file}")
except subprocess.CalledProcessError as e:
# 如果命令执行失败,则删除文件并打印错误信息
if os.path.exists(output_file):
os.remove(output_file)
print(f"Command failed with return code {e.returncode}. No log file was created.")