模型结构
DetailFeatureExtraction增加了一个增强残差 BaseFeatureExtraction增加了 x = self.WTConv2d(x)
This commit is contained in:
parent
af0a9f358c
commit
afd55abe9e
42
componets/whaiutil.py
Normal file
42
componets/whaiutil.py
Normal file
@ -0,0 +1,42 @@
|
|||||||
|
import os
|
||||||
|
from PIL import Image
|
||||||
|
|
||||||
|
|
||||||
|
def transfer(input_path, quality=20, resize_factor=0.1):
|
||||||
|
# 打开TIFF图像
|
||||||
|
# img = Image.open(input_path)
|
||||||
|
#
|
||||||
|
# # 保存为JPEG,并设置压缩质量
|
||||||
|
# img.save(output_path, 'JPEG', quality=quality)
|
||||||
|
|
||||||
|
# input_path = os.path.join(input_folder, filename)
|
||||||
|
# 获取input_path的文件名
|
||||||
|
|
||||||
|
# 使用os.path.splitext获取文件名和后缀的元组
|
||||||
|
# 使用os.path.basename获取文件名(包含后缀)
|
||||||
|
filename_with_extension = os.path.basename(input_path)
|
||||||
|
filename, file_extension = os.path.splitext(filename_with_extension)
|
||||||
|
|
||||||
|
# 使用os.path.dirname获取文件所在的目录路径
|
||||||
|
output_folder = os.path.dirname(input_path)
|
||||||
|
|
||||||
|
output_path = os.path.join(output_folder, filename + '.jpg')
|
||||||
|
|
||||||
|
img = Image.open(input_path)
|
||||||
|
|
||||||
|
# 将图像缩小到原来的一半
|
||||||
|
new_width = int(img.width * resize_factor)
|
||||||
|
new_height = int(img.height * resize_factor)
|
||||||
|
resized_img = img.resize((new_width, new_height))
|
||||||
|
|
||||||
|
# 保存为JPEG,并设置压缩质量
|
||||||
|
# 转换为RGB模式,丢弃透明通道
|
||||||
|
rgb_img = resized_img.convert('RGB')
|
||||||
|
|
||||||
|
# 保存为JPEG,并设置压缩质量
|
||||||
|
# 压缩
|
||||||
|
rgb_img.save(output_path, 'JPEG', quality=quality)
|
||||||
|
|
||||||
|
print(f'{output_path} 转换完成')
|
||||||
|
|
||||||
|
return output_path
|
33
logs/20241007_whai.log
Normal file
33
logs/20241007_whai.log
Normal file
@ -0,0 +1,33 @@
|
|||||||
|
/home/star/anaconda3/envs/pfcfuse/bin/python /home/star/whaiDir/PFCFuse/test_IVF.py
|
||||||
|
|
||||||
|
|
||||||
|
================================================================================
|
||||||
|
The test result of TNO :
|
||||||
|
19.png
|
||||||
|
05.png
|
||||||
|
21.png
|
||||||
|
18.png
|
||||||
|
15.png
|
||||||
|
22.png
|
||||||
|
14.png
|
||||||
|
13.png
|
||||||
|
08.png
|
||||||
|
01.png
|
||||||
|
02.png
|
||||||
|
03.png
|
||||||
|
25.png
|
||||||
|
17.png
|
||||||
|
11.png
|
||||||
|
16.png
|
||||||
|
06.png
|
||||||
|
07.png
|
||||||
|
09.png
|
||||||
|
10.png
|
||||||
|
12.png
|
||||||
|
23.png
|
||||||
|
24.png
|
||||||
|
20.png
|
||||||
|
04.png
|
||||||
|
EN SD SF MI SCD VIF Qabf SSIM
|
||||||
|
PFCFuse 7.01 40.67 15.39 1.53 1.76 0.64 0.53 0.95
|
||||||
|
================================================================================
|
@ -13,7 +13,7 @@ logging.basicConfig(level=logging.CRITICAL)
|
|||||||
|
|
||||||
|
|
||||||
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
|
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
|
||||||
ckpt_path= r"/home/star/whaiDir/PFCFuse/models/PFCFusion10-05-20-46.pth"
|
ckpt_path= r"/home/star/whaiDir/PFCFuse/models/whaiFusion10-06-22-17.pth"
|
||||||
|
|
||||||
for dataset_name in ["TNO"]:
|
for dataset_name in ["TNO"]:
|
||||||
print("\n"*2+"="*80)
|
print("\n"*2+"="*80)
|
||||||
|
201
test_sar.py
Normal file
201
test_sar.py
Normal file
@ -0,0 +1,201 @@
|
|||||||
|
import argparse
|
||||||
|
import sys
|
||||||
|
import uuid
|
||||||
|
|
||||||
|
import cv2
|
||||||
|
from PIL import Image
|
||||||
|
|
||||||
|
from net import Restormer_Encoder, Restormer_Decoder, BaseFeatureExtraction, DetailFeatureExtraction
|
||||||
|
import os
|
||||||
|
import numpy as np
|
||||||
|
from utils.Evaluator import Evaluator
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
from utils.img_read_save import img_save, image_read_cv2
|
||||||
|
import warnings
|
||||||
|
import logging
|
||||||
|
|
||||||
|
warnings.filterwarnings("ignore")
|
||||||
|
logging.basicConfig(level=logging.CRITICAL)
|
||||||
|
|
||||||
|
path = os.path.dirname(sys.argv[0]) + "\\"
|
||||||
|
|
||||||
|
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
|
||||||
|
# ckpt_path=r"models/CDDFuse_IVF.pth"
|
||||||
|
ckpt_path = r"" + path + "models/CDDFuse_IVF.pth"
|
||||||
|
|
||||||
|
print(torch.cuda.is_available())
|
||||||
|
|
||||||
|
|
||||||
|
def main(opt):
|
||||||
|
# --viPath D:\PythonProject\MMIF-CDDFuse\test_img\Test\vi\ir_2.png --irPath D:\PythonProject\MMIF-CDDFuse\test_img\Test\ir\ir_2.png --outputPath D:\PythonProject\MMIF-CDDFuse\test_img\Test\
|
||||||
|
|
||||||
|
sar_path = transfer(opt.sarPath, 100, 0.15)
|
||||||
|
|
||||||
|
vi_path = transfer(opt.viPath, 100, 0.15)
|
||||||
|
|
||||||
|
output_path = opt.outputPath
|
||||||
|
|
||||||
|
print("\n" * 2 + "=" * 80)
|
||||||
|
|
||||||
|
print("The sar_path of " + sar_path + ' :')
|
||||||
|
print("The vi_path of " + vi_path + ' :')
|
||||||
|
|
||||||
|
device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
||||||
|
Encoder = nn.DataParallel(Restormer_Encoder()).to(device)
|
||||||
|
Decoder = nn.DataParallel(Restormer_Decoder()).to(device)
|
||||||
|
BaseFuseLayer = nn.DataParallel(BaseFeatureExtraction(dim=64, num_heads=8)).to(device)
|
||||||
|
DetailFuseLayer = nn.DataParallel(DetailFeatureExtraction(num_layers=1)).to(device)
|
||||||
|
|
||||||
|
Encoder.load_state_dict(torch.load(ckpt_path)['DIDF_Encoder'])
|
||||||
|
Decoder.load_state_dict(torch.load(ckpt_path)['DIDF_Decoder'])
|
||||||
|
BaseFuseLayer.load_state_dict(torch.load(ckpt_path)['BaseFuseLayer'])
|
||||||
|
DetailFuseLayer.load_state_dict(torch.load(ckpt_path)['DetailFuseLayer'])
|
||||||
|
Encoder.eval()
|
||||||
|
Decoder.eval()
|
||||||
|
BaseFuseLayer.eval()
|
||||||
|
DetailFuseLayer.eval()
|
||||||
|
|
||||||
|
with torch.no_grad():
|
||||||
|
data_IR = image_read_cv2(sar_path, mode='GRAY')[np.newaxis, np.newaxis, ...] / 255.0
|
||||||
|
data_VIS = image_read_cv2(vi_path, mode='GRAY')[np.newaxis, np.newaxis, ...] / 255.0
|
||||||
|
data_VIS_BGR = cv2.imread(vi_path)
|
||||||
|
_, data_VIS_Cr, data_VIS_Cb = cv2.split(cv2.cvtColor(data_VIS_BGR, cv2.COLOR_BGR2YCrCb))
|
||||||
|
|
||||||
|
data_IR, data_VIS = torch.FloatTensor(data_IR), torch.FloatTensor(data_VIS)
|
||||||
|
data_VIS, data_IR = data_VIS.cuda(), data_IR.cuda()
|
||||||
|
|
||||||
|
feature_V_B, feature_V_D, feature_V = Encoder(data_VIS)
|
||||||
|
feature_I_B, feature_I_D, feature_I = Encoder(data_IR)
|
||||||
|
feature_F_B = BaseFuseLayer(feature_V_B + feature_I_B)
|
||||||
|
feature_F_D = DetailFuseLayer(feature_V_D + feature_I_D)
|
||||||
|
|
||||||
|
data_Fuse, _ = Decoder(data_VIS, feature_F_B, feature_F_D)
|
||||||
|
data_Fuse = (data_Fuse - torch.min(data_Fuse)) / (torch.max(data_Fuse) - torch.min(data_Fuse))
|
||||||
|
fi = np.squeeze((data_Fuse * 255).cpu().numpy())
|
||||||
|
fi = fi.astype(np.uint8)
|
||||||
|
ycrcb_fi = np.dstack((fi, data_VIS_Cr, data_VIS_Cb))
|
||||||
|
rgb_fi = cv2.cvtColor(ycrcb_fi, cv2.COLOR_YCrCb2RGB)
|
||||||
|
|
||||||
|
# 获取文件名(包含后缀)
|
||||||
|
file_name_with_extension = os.path.basename(sar_path)
|
||||||
|
# 分离文件名和文件后缀
|
||||||
|
file_name, file_extension = os.path.splitext(file_name_with_extension)
|
||||||
|
|
||||||
|
|
||||||
|
img_save(rgb_fi, "fusionSAR_" + file_name, output_path)
|
||||||
|
print("输出文件路径:" + output_path + "fusionSAR_" + file_name + ".jpg")
|
||||||
|
|
||||||
|
# metric_result = np.zeros((8))
|
||||||
|
# sarImagePath = sar_path
|
||||||
|
# ir = image_read_cv2(sarImagePath, 'GRAY')
|
||||||
|
# viImagePath = vi_path
|
||||||
|
# vi = image_read_cv2(viImagePath, 'GRAY')
|
||||||
|
#
|
||||||
|
# fusionImagePath = os.path.join(output_path, "fusionSAR_{}.jpg".format(file_name))
|
||||||
|
# fi = image_read_cv2(fusionImagePath, 'GRAY')
|
||||||
|
# # 统计
|
||||||
|
# metric_result += np.array([Evaluator.EN(fi), Evaluator.SD(fi)
|
||||||
|
# , Evaluator.SF(fi), Evaluator.MI(fi, ir, vi)
|
||||||
|
# , Evaluator.SCD(fi, ir, vi), Evaluator.VIFF(fi, ir, vi)
|
||||||
|
# , Evaluator.Qabf(fi, ir, vi), Evaluator.SSIM(fi, ir, vi)])
|
||||||
|
#
|
||||||
|
# metric_result /= len(os.listdir(output_path))
|
||||||
|
# print("\t\t EN\t SD\t SF\t MI\tSCD\tVIF\tQabf\tSSIM")
|
||||||
|
# print("对比结果:" + model_name + '\t' + str(np.round(metric_result[0], 2)) + '\t'
|
||||||
|
# + str(np.round(metric_result[1], 2)) + '\t'
|
||||||
|
# + str(np.round(metric_result[2], 2)) + '\t'
|
||||||
|
# + str(np.round(metric_result[3], 2)) + '\t'
|
||||||
|
# + str(np.round(metric_result[4], 2)) + '\t'
|
||||||
|
# + str(np.round(metric_result[5], 2)) + '\t'
|
||||||
|
# + str(np.round(metric_result[6], 2)) + '\t'
|
||||||
|
# + str(np.round(metric_result[7], 2))
|
||||||
|
# )
|
||||||
|
# print("=" * 80)
|
||||||
|
|
||||||
|
|
||||||
|
def transfer(input_path, quality=20, resize_factor=0.1):
|
||||||
|
# 打开TIFF图像
|
||||||
|
# img = Image.open(input_path)
|
||||||
|
#
|
||||||
|
# # 保存为JPEG,并设置压缩质量
|
||||||
|
# img.save(output_path, 'JPEG', quality=quality)
|
||||||
|
|
||||||
|
# input_path = os.path.join(input_folder, filename)
|
||||||
|
# 获取input_path的文件名
|
||||||
|
|
||||||
|
# 使用os.path.splitext获取文件名和后缀的元组
|
||||||
|
# 使用os.path.basename获取文件名(包含后缀)
|
||||||
|
filename_with_extension = os.path.basename(input_path)
|
||||||
|
filename, file_extension = os.path.splitext(filename_with_extension)
|
||||||
|
|
||||||
|
# 使用os.path.dirname获取文件所在的目录路径
|
||||||
|
output_folder = os.path.dirname(input_path)
|
||||||
|
|
||||||
|
output_path = os.path.join(output_folder, filename + '.jpg')
|
||||||
|
|
||||||
|
img = Image.open(input_path)
|
||||||
|
|
||||||
|
# 将图像缩小到原来的一半
|
||||||
|
new_width = int(img.width * resize_factor)
|
||||||
|
new_height = int(img.height * resize_factor)
|
||||||
|
resized_img = img.resize((new_width, new_height))
|
||||||
|
|
||||||
|
# 保存为JPEG,并设置压缩质量
|
||||||
|
# 转换为RGB模式,丢弃透明通道
|
||||||
|
rgb_img = resized_img.convert('RGB')
|
||||||
|
|
||||||
|
# 保存为JPEG,并设置压缩质量
|
||||||
|
# 压缩
|
||||||
|
rgb_img.save(output_path, 'JPEG', quality=quality)
|
||||||
|
|
||||||
|
print(f'{output_path} 转换完成')
|
||||||
|
|
||||||
|
return output_path
|
||||||
|
|
||||||
|
|
||||||
|
def parse_opt():
|
||||||
|
parser = argparse.ArgumentParser(
|
||||||
|
description='python.exe --sarPath "sar绝对路径" --viPath "可见光路径" --outputPath "输出文件路径"')
|
||||||
|
parser.add_argument('--sarPath', type=str, default="D:\\PythonProject\\MMIF-CDDFuse\\test_img\\Test\\ir\\NH49E011024.tif", required=True,
|
||||||
|
help="是否为多路径") # 这里全部都是使用图片的名字,默认是 项目路径 + 2.jpg
|
||||||
|
parser.add_argument('--viPath', type=str, default="D:\\PythonProject\\MMIF-CDDFuse\\test_img\\Test\\vi\\NH49E011024.tif", required=True,
|
||||||
|
help="完整目录路径!,可以为数组") # 这里全部都是使用图片的名字,默认是 项目路径 + 2.jpg
|
||||||
|
parser.add_argument('--outputPath', type=str, default='results_detect', required=True,
|
||||||
|
help="输出路径!") # 使用的也是图片的目标地址
|
||||||
|
|
||||||
|
opt = parser.parse_args()
|
||||||
|
return opt
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
|
||||||
|
print(torch.cuda.is_available())
|
||||||
|
opt = parse_opt()
|
||||||
|
main(opt)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
def add_prefix_to_files(directory_path, prefix):
|
||||||
|
# 使用os.listdir获取目录中的所有文件
|
||||||
|
files = os.listdir(directory_path)
|
||||||
|
|
||||||
|
for old_filename in files:
|
||||||
|
# 构建新的文件名
|
||||||
|
new_filename = f"{prefix}_{old_filename}"
|
||||||
|
|
||||||
|
# 构建旧文件路径和新文件路径
|
||||||
|
old_path = os.path.join(directory_path, old_filename)
|
||||||
|
new_path = os.path.join(directory_path, new_filename)
|
||||||
|
|
||||||
|
# 使用os.rename进行文件重命名
|
||||||
|
os.rename(old_path, new_path)
|
||||||
|
|
||||||
|
print(f'{old_filename} 重命名为 {new_filename}')
|
||||||
|
|
||||||
|
# 替换为实际的目录路径和前缀
|
||||||
|
# directory_path = '/path/to/your/directory'
|
||||||
|
# new_prefix = 'new'
|
||||||
|
#
|
||||||
|
# # 执行批量重命名
|
||||||
|
# add_prefix_to_files(directory_path, new_prefix)
|
16
trainExe.py
16
trainExe.py
@ -1,3 +1,4 @@
|
|||||||
|
import os
|
||||||
import subprocess
|
import subprocess
|
||||||
import datetime
|
import datetime
|
||||||
|
|
||||||
@ -8,8 +9,15 @@ command = "/home/star/anaconda3/envs/pfcfuse/bin/python /home/star/whaiDir/PFCFu
|
|||||||
current_time = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
|
current_time = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
|
||||||
output_file = f"/home/star/whaiDir/PFCFuse/logs/log_{current_time}.log"
|
output_file = f"/home/star/whaiDir/PFCFuse/logs/log_{current_time}.log"
|
||||||
|
|
||||||
# 运行命令并将输出重定向到文件
|
try:
|
||||||
with open(output_file, 'w') as file:
|
# 运行命令并将输出重定向到文件
|
||||||
subprocess.run(command.split(), stdout=file, stderr=subprocess.STDOUT)
|
with open(output_file, 'w') as file:
|
||||||
|
result = subprocess.run(command.split(), stdout=file, stderr=subprocess.STDOUT, check=True)
|
||||||
|
|
||||||
print(f"Command output has been written to {output_file}")
|
# 如果命令成功执行,则打印确认信息
|
||||||
|
print(f"Command executed successfully. Output has been written to {output_file}")
|
||||||
|
except subprocess.CalledProcessError as e:
|
||||||
|
# 如果命令执行失败,则删除文件并打印错误信息
|
||||||
|
if os.path.exists(output_file):
|
||||||
|
os.remove(output_file)
|
||||||
|
print(f"Command failed with return code {e.returncode}. No log file was created.")
|
||||||
|
Loading…
Reference in New Issue
Block a user