diff --git a/componets/whaiutil.py b/componets/whaiutil.py new file mode 100644 index 0000000..4d877ac --- /dev/null +++ b/componets/whaiutil.py @@ -0,0 +1,42 @@ +import os +from PIL import Image + + +def transfer(input_path, quality=20, resize_factor=0.1): + # 打开TIFF图像 + # img = Image.open(input_path) + # + # # 保存为JPEG,并设置压缩质量 + # img.save(output_path, 'JPEG', quality=quality) + + # input_path = os.path.join(input_folder, filename) + # 获取input_path的文件名 + + # 使用os.path.splitext获取文件名和后缀的元组 + # 使用os.path.basename获取文件名(包含后缀) + filename_with_extension = os.path.basename(input_path) + filename, file_extension = os.path.splitext(filename_with_extension) + + # 使用os.path.dirname获取文件所在的目录路径 + output_folder = os.path.dirname(input_path) + + output_path = os.path.join(output_folder, filename + '.jpg') + + img = Image.open(input_path) + + # 将图像缩小到原来的一半 + new_width = int(img.width * resize_factor) + new_height = int(img.height * resize_factor) + resized_img = img.resize((new_width, new_height)) + + # 保存为JPEG,并设置压缩质量 + # 转换为RGB模式,丢弃透明通道 + rgb_img = resized_img.convert('RGB') + + # 保存为JPEG,并设置压缩质量 + # 压缩 + rgb_img.save(output_path, 'JPEG', quality=quality) + + print(f'{output_path} 转换完成') + + return output_path diff --git a/logs/20241007_whai.log b/logs/20241007_whai.log new file mode 100644 index 0000000..93ba044 --- /dev/null +++ b/logs/20241007_whai.log @@ -0,0 +1,33 @@ +/home/star/anaconda3/envs/pfcfuse/bin/python /home/star/whaiDir/PFCFuse/test_IVF.py + + +================================================================================ +The test result of TNO : +19.png +05.png +21.png +18.png +15.png +22.png +14.png +13.png +08.png +01.png +02.png +03.png +25.png +17.png +11.png +16.png +06.png +07.png +09.png +10.png +12.png +23.png +24.png +20.png +04.png + EN SD SF MI SCD VIF Qabf SSIM +PFCFuse 7.01 40.67 15.39 1.53 1.76 0.64 0.53 0.95 +================================================================================ diff --git a/test_IVF.py b/test_IVF.py index cc0012d..93c5f1a 100644 --- a/test_IVF.py +++ b/test_IVF.py @@ -13,7 +13,7 @@ logging.basicConfig(level=logging.CRITICAL) os.environ["CUDA_VISIBLE_DEVICES"] = "0" -ckpt_path= r"/home/star/whaiDir/PFCFuse/models/PFCFusion10-05-20-46.pth" +ckpt_path= r"/home/star/whaiDir/PFCFuse/models/whaiFusion10-06-22-17.pth" for dataset_name in ["TNO"]: print("\n"*2+"="*80) diff --git a/test_sar.py b/test_sar.py new file mode 100644 index 0000000..142c4c6 --- /dev/null +++ b/test_sar.py @@ -0,0 +1,201 @@ +import argparse +import sys +import uuid + +import cv2 +from PIL import Image + +from net import Restormer_Encoder, Restormer_Decoder, BaseFeatureExtraction, DetailFeatureExtraction +import os +import numpy as np +from utils.Evaluator import Evaluator +import torch +import torch.nn as nn +from utils.img_read_save import img_save, image_read_cv2 +import warnings +import logging + +warnings.filterwarnings("ignore") +logging.basicConfig(level=logging.CRITICAL) + +path = os.path.dirname(sys.argv[0]) + "\\" + +os.environ["CUDA_VISIBLE_DEVICES"] = "0" +# ckpt_path=r"models/CDDFuse_IVF.pth" +ckpt_path = r"" + path + "models/CDDFuse_IVF.pth" + +print(torch.cuda.is_available()) + + +def main(opt): + # --viPath D:\PythonProject\MMIF-CDDFuse\test_img\Test\vi\ir_2.png --irPath D:\PythonProject\MMIF-CDDFuse\test_img\Test\ir\ir_2.png --outputPath D:\PythonProject\MMIF-CDDFuse\test_img\Test\ + + sar_path = transfer(opt.sarPath, 100, 0.15) + + vi_path = transfer(opt.viPath, 100, 0.15) + + output_path = opt.outputPath + + print("\n" * 2 + "=" * 80) + + print("The sar_path of " + sar_path + ' :') + print("The vi_path of " + vi_path + ' :') + + device = 'cuda' if torch.cuda.is_available() else 'cpu' + Encoder = nn.DataParallel(Restormer_Encoder()).to(device) + Decoder = nn.DataParallel(Restormer_Decoder()).to(device) + BaseFuseLayer = nn.DataParallel(BaseFeatureExtraction(dim=64, num_heads=8)).to(device) + DetailFuseLayer = nn.DataParallel(DetailFeatureExtraction(num_layers=1)).to(device) + + Encoder.load_state_dict(torch.load(ckpt_path)['DIDF_Encoder']) + Decoder.load_state_dict(torch.load(ckpt_path)['DIDF_Decoder']) + BaseFuseLayer.load_state_dict(torch.load(ckpt_path)['BaseFuseLayer']) + DetailFuseLayer.load_state_dict(torch.load(ckpt_path)['DetailFuseLayer']) + Encoder.eval() + Decoder.eval() + BaseFuseLayer.eval() + DetailFuseLayer.eval() + + with torch.no_grad(): + data_IR = image_read_cv2(sar_path, mode='GRAY')[np.newaxis, np.newaxis, ...] / 255.0 + data_VIS = image_read_cv2(vi_path, mode='GRAY')[np.newaxis, np.newaxis, ...] / 255.0 + data_VIS_BGR = cv2.imread(vi_path) + _, data_VIS_Cr, data_VIS_Cb = cv2.split(cv2.cvtColor(data_VIS_BGR, cv2.COLOR_BGR2YCrCb)) + + data_IR, data_VIS = torch.FloatTensor(data_IR), torch.FloatTensor(data_VIS) + data_VIS, data_IR = data_VIS.cuda(), data_IR.cuda() + + feature_V_B, feature_V_D, feature_V = Encoder(data_VIS) + feature_I_B, feature_I_D, feature_I = Encoder(data_IR) + feature_F_B = BaseFuseLayer(feature_V_B + feature_I_B) + feature_F_D = DetailFuseLayer(feature_V_D + feature_I_D) + + data_Fuse, _ = Decoder(data_VIS, feature_F_B, feature_F_D) + data_Fuse = (data_Fuse - torch.min(data_Fuse)) / (torch.max(data_Fuse) - torch.min(data_Fuse)) + fi = np.squeeze((data_Fuse * 255).cpu().numpy()) + fi = fi.astype(np.uint8) + ycrcb_fi = np.dstack((fi, data_VIS_Cr, data_VIS_Cb)) + rgb_fi = cv2.cvtColor(ycrcb_fi, cv2.COLOR_YCrCb2RGB) + + # 获取文件名(包含后缀) + file_name_with_extension = os.path.basename(sar_path) + # 分离文件名和文件后缀 + file_name, file_extension = os.path.splitext(file_name_with_extension) + + + img_save(rgb_fi, "fusionSAR_" + file_name, output_path) + print("输出文件路径:" + output_path + "fusionSAR_" + file_name + ".jpg") + + # metric_result = np.zeros((8)) + # sarImagePath = sar_path + # ir = image_read_cv2(sarImagePath, 'GRAY') + # viImagePath = vi_path + # vi = image_read_cv2(viImagePath, 'GRAY') + # + # fusionImagePath = os.path.join(output_path, "fusionSAR_{}.jpg".format(file_name)) + # fi = image_read_cv2(fusionImagePath, 'GRAY') + # # 统计 + # metric_result += np.array([Evaluator.EN(fi), Evaluator.SD(fi) + # , Evaluator.SF(fi), Evaluator.MI(fi, ir, vi) + # , Evaluator.SCD(fi, ir, vi), Evaluator.VIFF(fi, ir, vi) + # , Evaluator.Qabf(fi, ir, vi), Evaluator.SSIM(fi, ir, vi)]) + # + # metric_result /= len(os.listdir(output_path)) + # print("\t\t EN\t SD\t SF\t MI\tSCD\tVIF\tQabf\tSSIM") + # print("对比结果:" + model_name + '\t' + str(np.round(metric_result[0], 2)) + '\t' + # + str(np.round(metric_result[1], 2)) + '\t' + # + str(np.round(metric_result[2], 2)) + '\t' + # + str(np.round(metric_result[3], 2)) + '\t' + # + str(np.round(metric_result[4], 2)) + '\t' + # + str(np.round(metric_result[5], 2)) + '\t' + # + str(np.round(metric_result[6], 2)) + '\t' + # + str(np.round(metric_result[7], 2)) + # ) + # print("=" * 80) + + +def transfer(input_path, quality=20, resize_factor=0.1): + # 打开TIFF图像 + # img = Image.open(input_path) + # + # # 保存为JPEG,并设置压缩质量 + # img.save(output_path, 'JPEG', quality=quality) + + # input_path = os.path.join(input_folder, filename) + # 获取input_path的文件名 + + # 使用os.path.splitext获取文件名和后缀的元组 + # 使用os.path.basename获取文件名(包含后缀) + filename_with_extension = os.path.basename(input_path) + filename, file_extension = os.path.splitext(filename_with_extension) + + # 使用os.path.dirname获取文件所在的目录路径 + output_folder = os.path.dirname(input_path) + + output_path = os.path.join(output_folder, filename + '.jpg') + + img = Image.open(input_path) + + # 将图像缩小到原来的一半 + new_width = int(img.width * resize_factor) + new_height = int(img.height * resize_factor) + resized_img = img.resize((new_width, new_height)) + + # 保存为JPEG,并设置压缩质量 + # 转换为RGB模式,丢弃透明通道 + rgb_img = resized_img.convert('RGB') + + # 保存为JPEG,并设置压缩质量 + # 压缩 + rgb_img.save(output_path, 'JPEG', quality=quality) + + print(f'{output_path} 转换完成') + + return output_path + + +def parse_opt(): + parser = argparse.ArgumentParser( + description='python.exe --sarPath "sar绝对路径" --viPath "可见光路径" --outputPath "输出文件路径"') + parser.add_argument('--sarPath', type=str, default="D:\\PythonProject\\MMIF-CDDFuse\\test_img\\Test\\ir\\NH49E011024.tif", required=True, + help="是否为多路径") # 这里全部都是使用图片的名字,默认是 项目路径 + 2.jpg + parser.add_argument('--viPath', type=str, default="D:\\PythonProject\\MMIF-CDDFuse\\test_img\\Test\\vi\\NH49E011024.tif", required=True, + help="完整目录路径!,可以为数组") # 这里全部都是使用图片的名字,默认是 项目路径 + 2.jpg + parser.add_argument('--outputPath', type=str, default='results_detect', required=True, + help="输出路径!") # 使用的也是图片的目标地址 + + opt = parser.parse_args() + return opt + + +if __name__ == '__main__': + + print(torch.cuda.is_available()) + opt = parse_opt() + main(opt) + + + +def add_prefix_to_files(directory_path, prefix): + # 使用os.listdir获取目录中的所有文件 + files = os.listdir(directory_path) + + for old_filename in files: + # 构建新的文件名 + new_filename = f"{prefix}_{old_filename}" + + # 构建旧文件路径和新文件路径 + old_path = os.path.join(directory_path, old_filename) + new_path = os.path.join(directory_path, new_filename) + + # 使用os.rename进行文件重命名 + os.rename(old_path, new_path) + + print(f'{old_filename} 重命名为 {new_filename}') + +# 替换为实际的目录路径和前缀 +# directory_path = '/path/to/your/directory' +# new_prefix = 'new' +# +# # 执行批量重命名 +# add_prefix_to_files(directory_path, new_prefix) diff --git a/trainExe.py b/trainExe.py index 8b107f2..7d73e8b 100644 --- a/trainExe.py +++ b/trainExe.py @@ -1,3 +1,4 @@ +import os import subprocess import datetime @@ -8,8 +9,15 @@ command = "/home/star/anaconda3/envs/pfcfuse/bin/python /home/star/whaiDir/PFCFu current_time = datetime.datetime.now().strftime("%Y%m%d_%H%M%S") output_file = f"/home/star/whaiDir/PFCFuse/logs/log_{current_time}.log" -# 运行命令并将输出重定向到文件 -with open(output_file, 'w') as file: - subprocess.run(command.split(), stdout=file, stderr=subprocess.STDOUT) +try: + # 运行命令并将输出重定向到文件 + with open(output_file, 'w') as file: + result = subprocess.run(command.split(), stdout=file, stderr=subprocess.STDOUT, check=True) -print(f"Command output has been written to {output_file}") + # 如果命令成功执行,则打印确认信息 + print(f"Command executed successfully. Output has been written to {output_file}") +except subprocess.CalledProcessError as e: + # 如果命令执行失败,则删除文件并打印错误信息 + if os.path.exists(output_file): + os.remove(output_file) + print(f"Command failed with return code {e.returncode}. No log file was created.")