Compare commits
11 Commits
775cbdf20f
...
f28bb255f4
Author | SHA1 | Date | |
---|---|---|---|
|
f28bb255f4 | ||
|
315c723399 | ||
|
ef66a0321d | ||
|
125a6bdf6f | ||
|
be28d553fc | ||
|
ac4225c966 | ||
|
8d99c2c4f8 | ||
|
0cf1726eeb | ||
|
1bd418f0e4 | ||
|
f87a65e68e | ||
|
6738c9057d |
@ -2,7 +2,7 @@
|
|||||||
<module type="PYTHON_MODULE" version="4">
|
<module type="PYTHON_MODULE" version="4">
|
||||||
<component name="NewModuleRootManager">
|
<component name="NewModuleRootManager">
|
||||||
<content url="file://$MODULE_DIR$" />
|
<content url="file://$MODULE_DIR$" />
|
||||||
<orderEntry type="jdk" jdkName="Remote Python 3.8.10 (sftp://star@192.168.50.108:22/home/star/anaconda3/envs/pfcfuse/bin/python)" jdkType="Python SDK" />
|
<orderEntry type="inheritedJdk" />
|
||||||
<orderEntry type="sourceFolder" forTests="false" />
|
<orderEntry type="sourceFolder" forTests="false" />
|
||||||
</component>
|
</component>
|
||||||
<component name="PyDocumentationSettings">
|
<component name="PyDocumentationSettings">
|
||||||
|
@ -12,5 +12,5 @@
|
|||||||
</MavenGeneralSettings>
|
</MavenGeneralSettings>
|
||||||
</option>
|
</option>
|
||||||
</component>
|
</component>
|
||||||
<component name="ProjectRootManager" version="2" project-jdk-name="Remote Python 3.8.10 (sftp://star@192.168.50.108:22/home/star/anaconda3/envs/pfcfuse/bin/python)" project-jdk-type="Python SDK" />
|
<component name="ProjectRootManager" version="2" project-jdk-name="Python 3.8 (pfcfuse)" project-jdk-type="Python SDK" />
|
||||||
</project>
|
</project>
|
@ -17,14 +17,16 @@ current_time = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
|
|||||||
|
|
||||||
|
|
||||||
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
|
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
|
||||||
ckpt_path= r"/home/star/whaiDir/PFCFuse/models/whaiFusion11-17-10-34.pth"
|
pth_path = "whaiFusion11-17-20-18"
|
||||||
|
ckpt_path= r"/home/star/whaiDir/PFCFuse/models/"+pth_path+".pth"
|
||||||
|
print("path_pth:{}".format(ckpt_path))
|
||||||
|
|
||||||
for dataset_name in ["sar"]:
|
for dataset_name in ["sar"]:
|
||||||
print("\n"*2+"="*80)
|
print("\n"*2+"="*80)
|
||||||
model_name="PFCFuse Enhance 增加widthblock"
|
model_name=pth_path
|
||||||
print("The test result of "+dataset_name+' :')
|
print("The test result of "+dataset_name+' :')
|
||||||
test_folder = os.path.join('test_img', dataset_name)
|
test_folder = os.path.join('test_img', dataset_name)
|
||||||
test_out_folder=os.path.join('test_result',current_time,dataset_name)
|
test_out_folder=os.path.join('test_result',pth_path,dataset_name)
|
||||||
|
|
||||||
device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
||||||
Encoder = nn.DataParallel(Restormer_Encoder()).to(device)
|
Encoder = nn.DataParallel(Restormer_Encoder()).to(device)
|
||||||
|
4
train.py
4
train.py
@ -259,5 +259,7 @@ if True:
|
|||||||
'BaseFuseLayer': BaseFuseLayer.state_dict(),
|
'BaseFuseLayer': BaseFuseLayer.state_dict(),
|
||||||
'DetailFuseLayer': DetailFuseLayer.state_dict(),
|
'DetailFuseLayer': DetailFuseLayer.state_dict(),
|
||||||
}
|
}
|
||||||
torch.save(checkpoint, os.path.join("models/whaiFusion"+timestamp+'.pth'))
|
savepth = os.path.join("models/whaiFusion" + timestamp + '.pth');
|
||||||
|
torch.save(checkpoint, savepth)
|
||||||
|
print("save model:{}".format(savepth))
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user