Compare commits

...

11 Commits

Author SHA1 Message Date
zjut
fa8106838e feat(model): refactor the model and add new features
- Reorganized the model structure and added a new feature-fusion module
- Added a depthwise separable convolution block and a new detail-feature extraction module
- Updated the data-processing pipeline to use a new dataset path
- Adjusted the training parameters, increasing the number of epochs and the learning rate
- Optimized the loss function, replacing MSE loss with Huber loss
2024-11-12 10:37:56 +08:00
zjut
e1a339e04b feat(net): replace SMFA with SCSA and adjust the related configuration
- Replaced the SMFA module with the SCSA module
- Updated the project configuration to use a local Python 3.8 environment
- Adjusted SCSA module parameters such as the dimension and the number of heads
- Optimized the attention mechanism to improve model performance
2024-11-08 12:04:52 +08:00
zjut
41b2ea1ff9 feat(net): add new component imports to net.py and optimize the forward-pass logic
- Imported the SMFA component in net.py
- Optimized the forward-pass logic of the BasicLayer class
- Added implementations of the SMFA, DynamicFilter and UFFC components

- Use SMFA in place of Pooling:
        self.WTConv2d = WTConv2d(dim, dim)
        self.norm1 = LayerNorm(dim, 'WithBias')
        self.token_mixer = SMFA(dim=dim)

        # self.token_mixer = Pooling(kernel_size=pool_size)  # in ViTs this is MSA, in MLP models an MLP; Pooling is used as the substitute here
        self.norm2 = LayerNorm(dim, 'WithBias')
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.poolmlp = PoolMlp(in_features=dim, hidden_features=mlp_hidden_dim,
                               act_layer=act_layer, drop=drop)
2024-11-05 14:09:59 +08:00
85dc7a92ed feat(net): add a spatiotemporal attention mechanism
- Imported the SpatiotemporalAttentionFullNotWeightShared module in net.py
- Added spatiotemporal attention in the Restormer_Decoder class to process base and detail features
2024-10-26 19:18:01 +08:00
zjut
f4b3a933bf train: adjust the training-log output frequency
- Changed the training-log output frequency from every 100 batches to every batch
- This provides more frequent feedback on training progress, making it easier to monitor training in real time
2024-10-26 19:04:12 +08:00
b6486dbaf4 Add .idea/ and status.md to .gitignore so that personal configuration and status files are not tracked. Removed unnecessary print statements from the test script. Added test logs and logs of successful runs. 2024-10-26 18:37:15 +08:00
7d6d629786 Add .idea/ and status.md to .gitignore so that personal configuration and status files are not tracked. Removed unnecessary print statements from the test script. Added test logs and logs of successful runs. 2024-10-09 12:04:46 +08:00
15eb10b512 Add .idea/ and status.md to .gitignore so that personal configuration and status files are not tracked. Removed unnecessary print statements from the test script. Added test logs and logs of successful runs. 2024-10-09 11:57:57 +08:00
5e561ab6f7 Restructure the code for readability and maintainability; adjust the training output frequency.
Changed self.enhancement_module to
        self.enhancement_module = WTConv2d(32, 32)
2024-10-09 11:35:06 +08:00
96ce7d5fda Restructure the code for readability and maintainability; adjust the training output frequency.
Changed self.enhancement_module to
        self.enhancement_module = WTConv2d(32, 32)
2024-10-08 16:50:11 +08:00
afd55abe9e Model structure:
DetailFeatureExtraction gains an enhancement residual connection
BaseFeatureExtraction gains
x = self.WTConv2d(x)
2024-10-07 15:24:33 +08:00
594 changed files with 2067 additions and 71 deletions

.gitignore Normal file (+5)

@@ -0,0 +1,5 @@
.idea/
status.md
data/
test_img/
test_result/

.idea/.gitignore vendored Normal file (+8)

@@ -0,0 +1,8 @@
# Default ignored files
/shelf/
/workspace.xml
# Editor-based HTTP Client requests
/httpRequests/
# Datasource local storage ignored files
/dataSources/
/dataSources.local.xml

.idea/PFCFuse.iml Normal file (+12)

@@ -0,0 +1,12 @@
<?xml version="1.0" encoding="UTF-8"?>
<module type="PYTHON_MODULE" version="4">
<component name="NewModuleRootManager">
<content url="file://$MODULE_DIR$" />
<orderEntry type="inheritedJdk" />
<orderEntry type="sourceFolder" forTests="false" />
</component>
<component name="PyDocumentationSettings">
<option name="format" value="PLAIN" />
<option name="myDocStringFormat" value="Plain" />
</component>
</module>

.idea/deployment.xml Normal file (+78)

@@ -0,0 +1,78 @@
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
<component name="PublishConfigData" autoUpload="Always" serverName="star@192.168.50.108:22 password (9)" remoteFilesAllowedToDisappearOnAutoupload="false">
<serverData>
<paths name="star@192.168.50.108:22 password">
<serverdata>
<mappings>
<mapping local="$PROJECT_DIR$" web="/" />
</mappings>
</serverdata>
</paths>
<paths name="star@192.168.50.108:22 password (2)">
<serverdata>
<mappings>
<mapping local="$PROJECT_DIR$" web="/" />
</mappings>
</serverdata>
</paths>
<paths name="star@192.168.50.108:22 password (3)">
<serverdata>
<mappings>
<mapping local="$PROJECT_DIR$" web="/" />
</mappings>
</serverdata>
</paths>
<paths name="star@192.168.50.108:22 password (4)">
<serverdata>
<mappings>
<mapping local="$PROJECT_DIR$" web="/" />
</mappings>
</serverdata>
</paths>
<paths name="star@192.168.50.108:22 password (5)">
<serverdata>
<mappings>
<mapping local="$PROJECT_DIR$" web="/" />
</mappings>
</serverdata>
</paths>
<paths name="star@192.168.50.108:22 password (6)">
<serverdata>
<mappings>
<mapping local="$PROJECT_DIR$" web="/" />
</mappings>
</serverdata>
</paths>
<paths name="star@192.168.50.108:22 password (7)">
<serverdata>
<mappings>
<mapping local="$PROJECT_DIR$" web="/" />
</mappings>
</serverdata>
</paths>
<paths name="star@192.168.50.108:22 password (8)">
<serverdata>
<mappings>
<mapping deploy="/home/star/whaiDir/PFCFuse" local="$PROJECT_DIR$" />
</mappings>
</serverdata>
</paths>
<paths name="star@192.168.50.108:22 password (9)">
<serverdata>
<mappings>
<mapping deploy="/home/star/whaiDir/PFCFuse" local="$PROJECT_DIR$" />
</mappings>
</serverdata>
</paths>
<paths name="v100">
<serverdata>
<mappings>
<mapping local="$PROJECT_DIR$" web="/" />
</mappings>
</serverdata>
</paths>
</serverData>
<option name="myAutoUpload" value="ALWAYS" />
</component>
</project>

.idea/git_toolbox_prj.xml Normal file (+15)

@@ -0,0 +1,15 @@
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
<component name="GitToolBoxProjectSettings">
<option name="commitMessageIssueKeyValidationOverride">
<BoolValueOverride>
<option name="enabled" value="true" />
</BoolValueOverride>
</option>
<option name="commitMessageValidationEnabledOverride">
<BoolValueOverride>
<option name="enabled" value="true" />
</BoolValueOverride>
</option>
</component>
</project>

@@ -0,0 +1,264 @@
<component name="InspectionProjectProfileManager">
<profile version="1.0">
<option name="myName" value="Project Default" />
<inspection_tool class="DuplicatedCode" enabled="true" level="WEAK WARNING" enabled_by_default="true">
<Languages>
<language minSize="226" name="Python" />
</Languages>
</inspection_tool>
<inspection_tool class="Eslint" enabled="true" level="WARNING" enabled_by_default="true" />
<inspection_tool class="PyPackageRequirementsInspection" enabled="true" level="WARNING" enabled_by_default="true">
<option name="ignoredPackages">
<value>
<list size="245">
<item index="0" class="java.lang.String" itemvalue="numba" />
<item index="1" class="java.lang.String" itemvalue="tensorflow-estimator" />
<item index="2" class="java.lang.String" itemvalue="greenlet" />
<item index="3" class="java.lang.String" itemvalue="Babel" />
<item index="4" class="java.lang.String" itemvalue="scikit-learn" />
<item index="5" class="java.lang.String" itemvalue="testpath" />
<item index="6" class="java.lang.String" itemvalue="py" />
<item index="7" class="java.lang.String" itemvalue="gitdb" />
<item index="8" class="java.lang.String" itemvalue="torchvision" />
<item index="9" class="java.lang.String" itemvalue="patsy" />
<item index="10" class="java.lang.String" itemvalue="mccabe" />
<item index="11" class="java.lang.String" itemvalue="bleach" />
<item index="12" class="java.lang.String" itemvalue="lxml" />
<item index="13" class="java.lang.String" itemvalue="torchaudio" />
<item index="14" class="java.lang.String" itemvalue="jsonschema" />
<item index="15" class="java.lang.String" itemvalue="xlrd" />
<item index="16" class="java.lang.String" itemvalue="Werkzeug" />
<item index="17" class="java.lang.String" itemvalue="anaconda-project" />
<item index="18" class="java.lang.String" itemvalue="tensorboard-data-server" />
<item index="19" class="java.lang.String" itemvalue="typing-extensions" />
<item index="20" class="java.lang.String" itemvalue="click" />
<item index="21" class="java.lang.String" itemvalue="regex" />
<item index="22" class="java.lang.String" itemvalue="fastcache" />
<item index="23" class="java.lang.String" itemvalue="tensorboard" />
<item index="24" class="java.lang.String" itemvalue="imageio" />
<item index="25" class="java.lang.String" itemvalue="pytest-remotedata" />
<item index="26" class="java.lang.String" itemvalue="matplotlib" />
<item index="27" class="java.lang.String" itemvalue="idna" />
<item index="28" class="java.lang.String" itemvalue="Bottleneck" />
<item index="29" class="java.lang.String" itemvalue="rsa" />
<item index="30" class="java.lang.String" itemvalue="networkx" />
<item index="31" class="java.lang.String" itemvalue="pycurl" />
<item index="32" class="java.lang.String" itemvalue="smmap" />
<item index="33" class="java.lang.String" itemvalue="pluggy" />
<item index="34" class="java.lang.String" itemvalue="cffi" />
<item index="35" class="java.lang.String" itemvalue="pep8" />
<item index="36" class="java.lang.String" itemvalue="numpy" />
<item index="37" class="java.lang.String" itemvalue="jdcal" />
<item index="38" class="java.lang.String" itemvalue="alabaster" />
<item index="39" class="java.lang.String" itemvalue="jupyter" />
<item index="40" class="java.lang.String" itemvalue="pyOpenSSL" />
<item index="41" class="java.lang.String" itemvalue="PyWavelets" />
<item index="42" class="java.lang.String" itemvalue="prompt-toolkit" />
<item index="43" class="java.lang.String" itemvalue="QtAwesome" />
<item index="44" class="java.lang.String" itemvalue="msgpack-python" />
<item index="45" class="java.lang.String" itemvalue="Flask-Cors" />
<item index="46" class="java.lang.String" itemvalue="glob2" />
<item index="47" class="java.lang.String" itemvalue="Send2Trash" />
<item index="48" class="java.lang.String" itemvalue="imagesize" />
<item index="49" class="java.lang.String" itemvalue="et-xmlfile" />
<item index="50" class="java.lang.String" itemvalue="pathlib2" />
<item index="51" class="java.lang.String" itemvalue="docker-pycreds" />
<item index="52" class="java.lang.String" itemvalue="importlib-resources" />
<item index="53" class="java.lang.String" itemvalue="pathtools" />
<item index="54" class="java.lang.String" itemvalue="spyder" />
<item index="55" class="java.lang.String" itemvalue="pylint" />
<item index="56" class="java.lang.String" itemvalue="statsmodels" />
<item index="57" class="java.lang.String" itemvalue="tensorboardX" />
<item index="58" class="java.lang.String" itemvalue="isort" />
<item index="59" class="java.lang.String" itemvalue="ruamel_yaml" />
<item index="60" class="java.lang.String" itemvalue="pytz" />
<item index="61" class="java.lang.String" itemvalue="unicodecsv" />
<item index="62" class="java.lang.String" itemvalue="pytest-astropy" />
<item index="63" class="java.lang.String" itemvalue="traitlets" />
<item index="64" class="java.lang.String" itemvalue="absl-py" />
<item index="65" class="java.lang.String" itemvalue="protobuf" />
<item index="66" class="java.lang.String" itemvalue="nltk" />
<item index="67" class="java.lang.String" itemvalue="partd" />
<item index="68" class="java.lang.String" itemvalue="promise" />
<item index="69" class="java.lang.String" itemvalue="gast" />
<item index="70" class="java.lang.String" itemvalue="filelock" />
<item index="71" class="java.lang.String" itemvalue="numpydoc" />
<item index="72" class="java.lang.String" itemvalue="pyzmq" />
<item index="73" class="java.lang.String" itemvalue="oauthlib" />
<item index="74" class="java.lang.String" itemvalue="astropy" />
<item index="75" class="java.lang.String" itemvalue="keras" />
<item index="76" class="java.lang.String" itemvalue="entrypoints" />
<item index="77" class="java.lang.String" itemvalue="bkcharts" />
<item index="78" class="java.lang.String" itemvalue="pyparsing" />
<item index="79" class="java.lang.String" itemvalue="munch" />
<item index="80" class="java.lang.String" itemvalue="sphinxcontrib-websupport" />
<item index="81" class="java.lang.String" itemvalue="beautifulsoup4" />
<item index="82" class="java.lang.String" itemvalue="path.py" />
<item index="83" class="java.lang.String" itemvalue="clyent" />
<item index="84" class="java.lang.String" itemvalue="navigator-updater" />
<item index="85" class="java.lang.String" itemvalue="tifffile" />
<item index="86" class="java.lang.String" itemvalue="cryptography" />
<item index="87" class="java.lang.String" itemvalue="pygdal" />
<item index="88" class="java.lang.String" itemvalue="fastrlock" />
<item index="89" class="java.lang.String" itemvalue="widgetsnbextension" />
<item index="90" class="java.lang.String" itemvalue="multipledispatch" />
<item index="91" class="java.lang.String" itemvalue="numexpr" />
<item index="92" class="java.lang.String" itemvalue="jupyter-core" />
<item index="93" class="java.lang.String" itemvalue="ipython_genutils" />
<item index="94" class="java.lang.String" itemvalue="yapf" />
<item index="95" class="java.lang.String" itemvalue="rope" />
<item index="96" class="java.lang.String" itemvalue="wcwidth" />
<item index="97" class="java.lang.String" itemvalue="cupy-cuda110" />
<item index="98" class="java.lang.String" itemvalue="llvmlite" />
<item index="99" class="java.lang.String" itemvalue="Jinja2" />
<item index="100" class="java.lang.String" itemvalue="pycrypto" />
<item index="101" class="java.lang.String" itemvalue="Keras-Preprocessing" />
<item index="102" class="java.lang.String" itemvalue="ptflops" />
<item index="103" class="java.lang.String" itemvalue="cupy-cuda111" />
<item index="104" class="java.lang.String" itemvalue="cupy-cuda114" />
<item index="105" class="java.lang.String" itemvalue="some-package" />
<item index="106" class="java.lang.String" itemvalue="wandb" />
<item index="107" class="java.lang.String" itemvalue="netaddr" />
<item index="108" class="java.lang.String" itemvalue="sortedcollections" />
<item index="109" class="java.lang.String" itemvalue="six" />
<item index="110" class="java.lang.String" itemvalue="timm" />
<item index="111" class="java.lang.String" itemvalue="pyflakes" />
<item index="112" class="java.lang.String" itemvalue="asn1crypto" />
<item index="113" class="java.lang.String" itemvalue="parso" />
<item index="114" class="java.lang.String" itemvalue="pytest-doctestplus" />
<item index="115" class="java.lang.String" itemvalue="ipython" />
<item index="116" class="java.lang.String" itemvalue="xlwt" />
<item index="117" class="java.lang.String" itemvalue="packaging" />
<item index="118" class="java.lang.String" itemvalue="chardet" />
<item index="119" class="java.lang.String" itemvalue="jupyterlab-launcher" />
<item index="120" class="java.lang.String" itemvalue="click-plugins" />
<item index="121" class="java.lang.String" itemvalue="PyYAML" />
<item index="122" class="java.lang.String" itemvalue="pickleshare" />
<item index="123" class="java.lang.String" itemvalue="pycparser" />
<item index="124" class="java.lang.String" itemvalue="pyasn1-modules" />
<item index="125" class="java.lang.String" itemvalue="tables" />
<item index="126" class="java.lang.String" itemvalue="Pygments" />
<item index="127" class="java.lang.String" itemvalue="sentry-sdk" />
<item index="128" class="java.lang.String" itemvalue="docutils" />
<item index="129" class="java.lang.String" itemvalue="gevent" />
<item index="130" class="java.lang.String" itemvalue="shortuuid" />
<item index="131" class="java.lang.String" itemvalue="qtconsole" />
<item index="132" class="java.lang.String" itemvalue="terminado" />
<item index="133" class="java.lang.String" itemvalue="GitPython" />
<item index="134" class="java.lang.String" itemvalue="distributed" />
<item index="135" class="java.lang.String" itemvalue="jupyter-client" />
<item index="136" class="java.lang.String" itemvalue="pexpect" />
<item index="137" class="java.lang.String" itemvalue="ipykernel" />
<item index="138" class="java.lang.String" itemvalue="nbconvert" />
<item index="139" class="java.lang.String" itemvalue="attrs" />
<item index="140" class="java.lang.String" itemvalue="psutil" />
<item index="141" class="java.lang.String" itemvalue="simplejson" />
<item index="142" class="java.lang.String" itemvalue="jedi" />
<item index="143" class="java.lang.String" itemvalue="flatbuffers" />
<item index="144" class="java.lang.String" itemvalue="cytoolz" />
<item index="145" class="java.lang.String" itemvalue="odo" />
<item index="146" class="java.lang.String" itemvalue="decorator" />
<item index="147" class="java.lang.String" itemvalue="pandocfilters" />
<item index="148" class="java.lang.String" itemvalue="backports.shutil-get-terminal-size" />
<item index="149" class="java.lang.String" itemvalue="pycodestyle" />
<item index="150" class="java.lang.String" itemvalue="pycosat" />
<item index="151" class="java.lang.String" itemvalue="pyasn1" />
<item index="152" class="java.lang.String" itemvalue="requests" />
<item index="153" class="java.lang.String" itemvalue="bitarray" />
<item index="154" class="java.lang.String" itemvalue="kornia" />
<item index="155" class="java.lang.String" itemvalue="mkl-fft" />
<item index="156" class="java.lang.String" itemvalue="tensorflow" />
<item index="157" class="java.lang.String" itemvalue="XlsxWriter" />
<item index="158" class="java.lang.String" itemvalue="seaborn" />
<item index="159" class="java.lang.String" itemvalue="tensorboard-plugin-wit" />
<item index="160" class="java.lang.String" itemvalue="blaze" />
<item index="161" class="java.lang.String" itemvalue="zipp" />
<item index="162" class="java.lang.String" itemvalue="pkginfo" />
<item index="163" class="java.lang.String" itemvalue="cached-property" />
<item index="164" class="java.lang.String" itemvalue="torchstat" />
<item index="165" class="java.lang.String" itemvalue="datashape" />
<item index="166" class="java.lang.String" itemvalue="itsdangerous" />
<item index="167" class="java.lang.String" itemvalue="ipywidgets" />
<item index="168" class="java.lang.String" itemvalue="scipy" />
<item index="169" class="java.lang.String" itemvalue="thop" />
<item index="170" class="java.lang.String" itemvalue="tornado" />
<item index="171" class="java.lang.String" itemvalue="google-auth-oauthlib" />
<item index="172" class="java.lang.String" itemvalue="opencv-python" />
<item index="173" class="java.lang.String" itemvalue="torch" />
<item index="174" class="java.lang.String" itemvalue="singledispatch" />
<item index="175" class="java.lang.String" itemvalue="sortedcontainers" />
<item index="176" class="java.lang.String" itemvalue="mistune" />
<item index="177" class="java.lang.String" itemvalue="pandas" />
<item index="178" class="java.lang.String" itemvalue="termcolor" />
<item index="179" class="java.lang.String" itemvalue="clang" />
<item index="180" class="java.lang.String" itemvalue="toolz" />
<item index="181" class="java.lang.String" itemvalue="Sphinx" />
<item index="182" class="java.lang.String" itemvalue="mpmath" />
<item index="183" class="java.lang.String" itemvalue="jupyter-console" />
<item index="184" class="java.lang.String" itemvalue="bokeh" />
<item index="185" class="java.lang.String" itemvalue="cachetools" />
<item index="186" class="java.lang.String" itemvalue="gmpy2" />
<item index="187" class="java.lang.String" itemvalue="setproctitle" />
<item index="188" class="java.lang.String" itemvalue="webencodings" />
<item index="189" class="java.lang.String" itemvalue="html5lib" />
<item index="190" class="java.lang.String" itemvalue="colorlog" />
<item index="191" class="java.lang.String" itemvalue="python-dateutil" />
<item index="192" class="java.lang.String" itemvalue="QtPy" />
<item index="193" class="java.lang.String" itemvalue="astroid" />
<item index="194" class="java.lang.String" itemvalue="cycler" />
<item index="195" class="java.lang.String" itemvalue="mkl-random" />
<item index="196" class="java.lang.String" itemvalue="pytest-arraydiff" />
<item index="197" class="java.lang.String" itemvalue="locket" />
<item index="198" class="java.lang.String" itemvalue="heapdict" />
<item index="199" class="java.lang.String" itemvalue="snowballstemmer" />
<item index="200" class="java.lang.String" itemvalue="contextlib2" />
<item index="201" class="java.lang.String" itemvalue="certifi" />
<item index="202" class="java.lang.String" itemvalue="Markdown" />
<item index="203" class="java.lang.String" itemvalue="sympy" />
<item index="204" class="java.lang.String" itemvalue="notebook" />
<item index="205" class="java.lang.String" itemvalue="pyodbc" />
<item index="206" class="java.lang.String" itemvalue="boto" />
<item index="207" class="java.lang.String" itemvalue="cligj" />
<item index="208" class="java.lang.String" itemvalue="h5py" />
<item index="209" class="java.lang.String" itemvalue="wrapt" />
<item index="210" class="java.lang.String" itemvalue="kiwisolver" />
<item index="211" class="java.lang.String" itemvalue="pytest-openfiles" />
<item index="212" class="java.lang.String" itemvalue="anaconda-client" />
<item index="213" class="java.lang.String" itemvalue="backcall" />
<item index="214" class="java.lang.String" itemvalue="PySocks" />
<item index="215" class="java.lang.String" itemvalue="charset-normalizer" />
<item index="216" class="java.lang.String" itemvalue="typing" />
<item index="217" class="java.lang.String" itemvalue="dask" />
<item index="218" class="java.lang.String" itemvalue="enum34" />
<item index="219" class="java.lang.String" itemvalue="torchsummary" />
<item index="220" class="java.lang.String" itemvalue="scikit-image" />
<item index="221" class="java.lang.String" itemvalue="ptyprocess" />
<item index="222" class="java.lang.String" itemvalue="more-itertools" />
<item index="223" class="java.lang.String" itemvalue="SQLAlchemy" />
<item index="224" class="java.lang.String" itemvalue="tblib" />
<item index="225" class="java.lang.String" itemvalue="cloudpickle" />
<item index="226" class="java.lang.String" itemvalue="importlib-metadata" />
<item index="227" class="java.lang.String" itemvalue="simplegeneric" />
<item index="228" class="java.lang.String" itemvalue="zict" />
<item index="229" class="java.lang.String" itemvalue="urllib3" />
<item index="230" class="java.lang.String" itemvalue="jupyterlab" />
<item index="231" class="java.lang.String" itemvalue="Cython" />
<item index="232" class="java.lang.String" itemvalue="Flask" />
<item index="233" class="java.lang.String" itemvalue="nose" />
<item index="234" class="java.lang.String" itemvalue="pytorch-msssim" />
<item index="235" class="java.lang.String" itemvalue="pytest" />
<item index="236" class="java.lang.String" itemvalue="nbformat" />
<item index="237" class="java.lang.String" itemvalue="matmul" />
<item index="238" class="java.lang.String" itemvalue="tqdm" />
<item index="239" class="java.lang.String" itemvalue="lazy-object-proxy" />
<item index="240" class="java.lang.String" itemvalue="colorama" />
<item index="241" class="java.lang.String" itemvalue="grpcio" />
<item index="242" class="java.lang.String" itemvalue="ply" />
<item index="243" class="java.lang.String" itemvalue="google-auth" />
<item index="244" class="java.lang.String" itemvalue="openpyxl" />
</list>
</value>
</option>
</inspection_tool>
</profile>
</component>

@@ -0,0 +1,6 @@
<component name="InspectionProjectProfileManager">
<settings>
<option name="USE_PROJECT_PROFILE" value="false" />
<version value="1.0" />
</settings>
</component>

.idea/misc.xml Normal file (+16)

@@ -0,0 +1,16 @@
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
<component name="Black">
<option name="sdkName" value="Python 3.9 (flaskTest)" />
</component>
<component name="MavenImportPreferences">
<option name="generalSettings">
<MavenGeneralSettings>
<option name="localRepository" value="E:\maven\repository" />
<option name="showDialogWithAdvancedSettings" value="true" />
<option name="userSettingsFile" value="E:\maven\settings.xml" />
</MavenGeneralSettings>
</option>
</component>
<component name="ProjectRootManager" version="2" project-jdk-name="Python 3.8 (pfcfuse)" project-jdk-type="Python SDK" />
</project>

.idea/modules.xml Normal file (+8)

@@ -0,0 +1,8 @@
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
<component name="ProjectModuleManager">
<modules>
<module fileurl="file://$PROJECT_DIR$/.idea/PFCFuse.iml" filepath="$PROJECT_DIR$/.idea/PFCFuse.iml" />
</modules>
</component>
</project>

.idea/vcs.xml Normal file (+6)

@@ -0,0 +1,6 @@
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
<component name="VcsDirectoryMappings">
<mapping directory="" vcs="Git" />
</component>
</project>

@@ -0,0 +1,116 @@
import torch
import torch.nn as nn
from timm.layers.helpers import to_2tuple

"""
Models equipped with multi-head self-attention (MHSA) have achieved remarkable performance in computer vision. Their computational complexity is quadratic in the number of pixels of the input feature map, which causes slow processing, especially for high-resolution images.
To circumvent this problem, a new type of token mixer has been proposed as an alternative to MHSA: FFT-based token mixers involve global operations similar to MHSA, but with lower computational complexity.
Here, a novel token mixer named DynamicFilter is proposed to narrow the gap described above.
The DynamicFilter module filters in the frequency domain and dynamically adjusts the filter weights, enabling complex image enhancement and processing.
"""


class StarReLU(nn.Module):
    """
    StarReLU: s * relu(x) ** 2 + b
    """

    def __init__(self, scale_value=1.0, bias_value=0.0,
                 scale_learnable=True, bias_learnable=True,
                 mode=None, inplace=False):
        super().__init__()
        self.inplace = inplace
        self.relu = nn.ReLU(inplace=inplace)
        self.scale = nn.Parameter(scale_value * torch.ones(1),
                                  requires_grad=scale_learnable)
        self.bias = nn.Parameter(bias_value * torch.ones(1),
                                 requires_grad=bias_learnable)

    def forward(self, x):
        return self.scale * self.relu(x) ** 2 + self.bias


class Mlp(nn.Module):
    """ MLP as used in MetaFormer models, eg Transformer, MLP-Mixer, PoolFormer, MetaFormer baselines and related networks.
    Mostly copied from timm.
    """

    def __init__(self, dim, mlp_ratio=4, out_features=None, act_layer=StarReLU, drop=0.,
                 bias=False, **kwargs):
        super().__init__()
        in_features = dim
        out_features = out_features or in_features
        hidden_features = int(mlp_ratio * in_features)
        drop_probs = to_2tuple(drop)

        self.fc1 = nn.Linear(in_features, hidden_features, bias=bias)
        self.act = act_layer()
        self.drop1 = nn.Dropout(drop_probs[0])
        self.fc2 = nn.Linear(hidden_features, out_features, bias=bias)
        self.drop2 = nn.Dropout(drop_probs[1])

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop1(x)
        x = self.fc2(x)
        x = self.drop2(x)
        return x


class DynamicFilter(nn.Module):
    def __init__(self, dim, expansion_ratio=2, reweight_expansion_ratio=.25,
                 act1_layer=StarReLU, act2_layer=nn.Identity,
                 bias=False, num_filters=4, size=14, weight_resize=False,
                 **kwargs):
        super().__init__()
        size = to_2tuple(size)
        self.size = size[0]
        self.filter_size = size[1] // 2 + 1
        self.num_filters = num_filters
        self.dim = dim
        self.med_channels = int(expansion_ratio * dim)
        self.weight_resize = weight_resize

        self.pwconv1 = nn.Linear(dim, self.med_channels, bias=bias)
        self.act1 = act1_layer()
        self.reweight = Mlp(dim, reweight_expansion_ratio, num_filters * self.med_channels)
        self.complex_weights = nn.Parameter(
            torch.randn(self.size, self.filter_size, num_filters, 2,
                        dtype=torch.float32) * 0.02)
        self.act2 = act2_layer()
        self.pwconv2 = nn.Linear(self.med_channels, dim, bias=bias)

    def forward(self, x):
        B, H, W, _ = x.shape  # expects channels-last input: (B, H, W, C)
        routeing = self.reweight(x.mean(dim=(1, 2))).view(B, self.num_filters,
                                                          -1).softmax(dim=1)
        x = self.pwconv1(x)
        x = self.act1(x)
        x = x.to(torch.float32)
        x = torch.fft.rfft2(x, dim=(1, 2), norm='ortho')

        if self.weight_resize:
            # resize_complex_weight is not defined in this file; see the sketch after this listing
            complex_weights = resize_complex_weight(self.complex_weights, x.shape[1],
                                                    x.shape[2])
            complex_weights = torch.view_as_complex(complex_weights.contiguous())
        else:
            complex_weights = torch.view_as_complex(self.complex_weights)
        routeing = routeing.to(torch.complex64)
        weight = torch.einsum('bfc,hwf->bhwc', routeing, complex_weights)
        if self.weight_resize:
            weight = weight.view(-1, x.shape[1], x.shape[2], self.med_channels)
        else:
            weight = weight.view(-1, self.size, self.filter_size, self.med_channels)
        x = x * weight
        x = torch.fft.irfft2(x, s=(H, W), dim=(1, 2), norm='ortho')

        x = self.act2(x)
        x = self.pwconv2(x)
        return x


if __name__ == '__main__':
    block = DynamicFilter(32, size=64)  # size == H, W
    input = torch.rand(3, 64, 64, 32)
    output = block(input)
    print(input.size())
    print(output.size())
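The forward pass above calls resize_complex_weight when weight_resize=True, but that helper is not defined anywhere in this file. A minimal sketch of it, following the reference DynamicFilter/DFFormer implementation this module appears to be adapted from (an assumption, not part of this commit):

def resize_complex_weight(origin_weight, new_h, new_w):
    # origin_weight: (h, w, num_filters, 2) -> bicubically resized to (new_h, new_w, num_filters, 2)
    h, w, num_filters = origin_weight.shape[0:3]
    origin_weight = origin_weight.reshape(1, h, w, num_filters * 2).permute(0, 3, 1, 2)
    new_weight = torch.nn.functional.interpolate(
        origin_weight, size=(new_h, new_w), mode='bicubic', align_corners=True
    ).permute(0, 2, 3, 1).reshape(new_h, new_w, num_filters, 2)
    return new_weight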

Binary file not shown.

componets/SCSA.py Normal file (+156)

@@ -0,0 +1,156 @@
import typing as t

import torch
import torch.nn as nn
from einops.einops import rearrange
from mmengine.model import BaseModule

__all__ = ['SCSA']

"""SCSA: exploring the synergy between spatial attention and channel attention.
Channel attention and spatial attention have brought significant improvements in extracting feature dependencies and spatial-structure relations, respectively, for various downstream vision tasks.
While combining the two is more beneficial for leveraging their individual strengths, the synergy between channel and spatial attention has not been fully explored, and the synergistic potential of multi-semantic information for feature guidance and for mitigating semantic discrepancies remains untapped.
This study attempts to reveal the synergistic relationship between spatial and channel attention at multiple semantic levels and proposes a novel Spatial and Channel Synergistic Attention module (SCSA). SCSA consists of two parts: a shareable multi-semantic spatial attention (SMSA) and a progressive channel self-attention (PCSA).
SMSA integrates multi-semantic information and uses a progressive compression strategy to inject discriminative spatial priors into PCSA's channel self-attention, effectively guiding channel recalibration. Additionally, the robust feature interaction based on the self-attention mechanism in PCSA further mitigates the discrepancies in multi-semantic information among the different sub-features of SMSA.
Extensive experiments on seven benchmark datasets, including classification on ImageNet-1K, object detection on MSCOCO 2017, segmentation on ADE20K, and four other complex-scene detection datasets, demonstrate that the proposed SCSA not only surpasses current state-of-the-art attention mechanisms
but also exhibits enhanced generalization across various task scenarios.
"""


class SCSA(BaseModule):

    def __init__(
            self,
            dim: int,
            head_num: int,
            window_size: int = 7,
            group_kernel_sizes: t.List[int] = [3, 5, 7, 9],
            qkv_bias: bool = False,
            fuse_bn: bool = False,
            norm_cfg: t.Dict = dict(type='BN'),
            act_cfg: t.Dict = dict(type='ReLU'),
            down_sample_mode: str = 'avg_pool',
            attn_drop_ratio: float = 0.,
            gate_layer: str = 'sigmoid',
    ):
        super(SCSA, self).__init__()
        self.dim = dim
        self.head_num = head_num
        self.head_dim = dim // head_num
        self.scaler = self.head_dim ** -0.5
        self.group_kernel_sizes = group_kernel_sizes
        self.window_size = window_size
        self.qkv_bias = qkv_bias
        self.fuse_bn = fuse_bn
        self.down_sample_mode = down_sample_mode

        assert self.dim % 4 == 0, 'The dimension of input feature should be divisible by 4.'
        self.group_chans = group_chans = self.dim // 4

        self.local_dwc = nn.Conv1d(group_chans, group_chans, kernel_size=group_kernel_sizes[0],
                                   padding=group_kernel_sizes[0] // 2, groups=group_chans)
        self.global_dwc_s = nn.Conv1d(group_chans, group_chans, kernel_size=group_kernel_sizes[1],
                                      padding=group_kernel_sizes[1] // 2, groups=group_chans)
        self.global_dwc_m = nn.Conv1d(group_chans, group_chans, kernel_size=group_kernel_sizes[2],
                                      padding=group_kernel_sizes[2] // 2, groups=group_chans)
        self.global_dwc_l = nn.Conv1d(group_chans, group_chans, kernel_size=group_kernel_sizes[3],
                                      padding=group_kernel_sizes[3] // 2, groups=group_chans)
        self.sa_gate = nn.Softmax(dim=2) if gate_layer == 'softmax' else nn.Sigmoid()
        self.norm_h = nn.GroupNorm(4, dim)
        self.norm_w = nn.GroupNorm(4, dim)

        self.conv_d = nn.Identity()
        self.norm = nn.GroupNorm(1, dim)
        self.q = nn.Conv2d(in_channels=dim, out_channels=dim, kernel_size=1, bias=qkv_bias, groups=dim)
        self.k = nn.Conv2d(in_channels=dim, out_channels=dim, kernel_size=1, bias=qkv_bias, groups=dim)
        self.v = nn.Conv2d(in_channels=dim, out_channels=dim, kernel_size=1, bias=qkv_bias, groups=dim)
        self.attn_drop = nn.Dropout(attn_drop_ratio)
        self.ca_gate = nn.Softmax(dim=1) if gate_layer == 'softmax' else nn.Sigmoid()

        if window_size == -1:
            self.down_func = nn.AdaptiveAvgPool2d((1, 1))
        else:
            if down_sample_mode == 'recombination':
                self.down_func = self.space_to_chans
                # dimensionality reduction
                self.conv_d = nn.Conv2d(in_channels=dim * window_size ** 2, out_channels=dim, kernel_size=1, bias=False)
            elif down_sample_mode == 'avg_pool':
                self.down_func = nn.AvgPool2d(kernel_size=(window_size, window_size), stride=window_size)
            elif down_sample_mode == 'max_pool':
                self.down_func = nn.MaxPool2d(kernel_size=(window_size, window_size), stride=window_size)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        The dim of x is (B, C, H, W)
        """
        # Spatial attention priority calculation
        b, c, h_, w_ = x.size()
        # (B, C, H)
        x_h = x.mean(dim=3)
        l_x_h, g_x_h_s, g_x_h_m, g_x_h_l = torch.split(x_h, self.group_chans, dim=1)
        # (B, C, W)
        x_w = x.mean(dim=2)
        l_x_w, g_x_w_s, g_x_w_m, g_x_w_l = torch.split(x_w, self.group_chans, dim=1)

        x_h_attn = self.sa_gate(self.norm_h(torch.cat((
            self.local_dwc(l_x_h),
            self.global_dwc_s(g_x_h_s),
            self.global_dwc_m(g_x_h_m),
            self.global_dwc_l(g_x_h_l),
        ), dim=1)))
        x_h_attn = x_h_attn.view(b, c, h_, 1)

        x_w_attn = self.sa_gate(self.norm_w(torch.cat((
            self.local_dwc(l_x_w),
            self.global_dwc_s(g_x_w_s),
            self.global_dwc_m(g_x_w_m),
            self.global_dwc_l(g_x_w_l)
        ), dim=1)))
        x_w_attn = x_w_attn.view(b, c, 1, w_)

        x = x * x_h_attn * x_w_attn

        # Channel attention based on self attention
        # reduce calculations
        y = self.down_func(x)
        y = self.conv_d(y)
        _, _, h_, w_ = y.size()

        # normalization first, then reshape -> (B, H, W, C) -> (B, C, H * W) and generate q, k and v
        y = self.norm(y)
        q = self.q(y)
        k = self.k(y)
        v = self.v(y)
        # (B, C, H, W) -> (B, head_num, head_dim, N)
        q = rearrange(q, 'b (head_num head_dim) h w -> b head_num head_dim (h w)', head_num=int(self.head_num),
                      head_dim=int(self.head_dim))
        k = rearrange(k, 'b (head_num head_dim) h w -> b head_num head_dim (h w)', head_num=int(self.head_num),
                      head_dim=int(self.head_dim))
        v = rearrange(v, 'b (head_num head_dim) h w -> b head_num head_dim (h w)', head_num=int(self.head_num),
                      head_dim=int(self.head_dim))

        # (B, head_num, head_dim, head_dim)
        attn = q @ k.transpose(-2, -1) * self.scaler
        attn = self.attn_drop(attn.softmax(dim=-1))
        # (B, head_num, head_dim, N)
        attn = attn @ v
        # (B, C, H_, W_)
        attn = rearrange(attn, 'b head_num head_dim (h w) -> b (head_num head_dim) h w', h=int(h_), w=int(w_))
        # (B, C, 1, 1)
        attn = attn.mean((2, 3), keepdim=True)
        attn = self.ca_gate(attn)
        return attn * x


if __name__ == '__main__':
    block = SCSA(
        dim=256,
        head_num=8,
    )
    input_tensor = torch.rand(1, 256, 32, 32)
    # run a forward pass through the module
    output_tensor = block(input_tensor)
    # print the input and output tensor sizes
    print("Input size:", input_tensor.size())
    print("Output size:", output_tensor.size())

componets/SMFA.py Normal file (+65)

@@ -0,0 +1,65 @@
import torch
import torch.nn as nn
import torch.nn.functional as F

"""ECCV 2024 (https://github.com/Zheng-MJ/SMFANet)
Transformer-based restoration methods achieve remarkable results, because the Transformer's self-attention (SA) can exploit non-local information for better high-resolution image reconstruction. However, the key dot-product self-attention requires substantial computational resources, which limits its application on low-power devices.
In addition, the low-pass filtering characteristic of the self-attention mechanism limits its ability to capture local detail, leading to overly smooth reconstructions. To address these issues, a self-modulation feature aggregation (SMFA) module is proposed that jointly exploits local and non-local feature interactions for more accurate reconstruction.
Specifically, the SMFA module uses an efficient approximation of self-attention (EASA) branch to model non-local information and a local detail estimation (LDE) branch to capture local details. A partial-convolution-based feed-forward network (PCFN) is further introduced to refine the representative features extracted by SMFA.
Extensive experiments show that the resulting SMFANet family achieves a better trade-off between reconstruction performance and computational efficiency on public benchmark datasets.
In particular, compared with SwinIR-light at ×4 upscaling, SMFANet+ improves average performance by 0.14 dB over five public test sets, runs about 10 times faster, and has only about 43% of the model complexity (e.g., FLOPs).
"""


class DMlp(nn.Module):
    def __init__(self, dim, growth_rate=2.0):
        super().__init__()
        hidden_dim = int(dim * growth_rate)
        self.conv_0 = nn.Sequential(
            nn.Conv2d(dim, hidden_dim, 3, 1, 1, groups=dim),
            nn.Conv2d(hidden_dim, hidden_dim, 1, 1, 0)
        )
        self.act = nn.GELU()
        self.conv_1 = nn.Conv2d(hidden_dim, dim, 1, 1, 0)

    def forward(self, x):
        x = self.conv_0(x)
        x = self.act(x)
        x = self.conv_1(x)
        return x


class SMFA(nn.Module):
    def __init__(self, dim=36):
        super(SMFA, self).__init__()
        self.linear_0 = nn.Conv2d(dim, dim * 2, 1, 1, 0)
        self.linear_1 = nn.Conv2d(dim, dim, 1, 1, 0)
        self.linear_2 = nn.Conv2d(dim, dim, 1, 1, 0)

        self.lde = DMlp(dim, 2)

        self.dw_conv = nn.Conv2d(dim, dim, 3, 1, 1, groups=dim)

        self.gelu = nn.GELU()
        self.down_scale = 8

        self.alpha = nn.Parameter(torch.ones((1, dim, 1, 1)))
        self.belt = nn.Parameter(torch.zeros((1, dim, 1, 1)))

    def forward(self, f):
        _, _, h, w = f.shape
        y, x = self.linear_0(f).chunk(2, dim=1)
        x_s = self.dw_conv(F.adaptive_max_pool2d(x, (h // self.down_scale, w // self.down_scale)))
        x_v = torch.var(x, dim=(-2, -1), keepdim=True)
        x_l = x * F.interpolate(self.gelu(self.linear_1(x_s * self.alpha + x_v * self.belt)), size=(h, w),
                                mode='nearest')
        y_d = self.lde(y)
        return self.linear_2(x_l + y_d)


if __name__ == '__main__':
    block = SMFA(dim=36)
    input = torch.randn(3, 36, 64, 64)
    output = block(input)
    print(input.size())
    print(output.size())

componets/TIAM.py Normal file (+110)

@@ -0,0 +1,110 @@
import torch
import torch.nn as nn
import torch.nn.functional as F

"""Elsevier 2024
Change detection (CD) is an important monitoring method in Earth observation, especially for land-use analysis, urban management, and disaster damage assessment. However, in the era of interconnected constellations and space-air collaboration, changes in regions of interest (ROIs) suffer many false detections caused by differences in geometric perspective, rotation, and temporal style.
To address these challenges, CDNeXt is introduced, a framework that articulates a robust and effective method for combining a Siamese network with pretrained backbones and an innovative spatiotemporal interaction attention module (TIAM) for remote-sensing images.
CDNeXt can be divided into four main components: encoder, interactor, decoder, and detector. Notably, the TIAM-powered interactor queries and reconstructs spatial perspective dependencies and temporal style correlations from the bitemporal features extracted by the encoder, in order to amplify the differences of ROI changes.
Finally, the detector integrates the hierarchical features generated by the decoder and produces a binary change mask.
"""


class SpatiotemporalAttentionFullNotWeightShared(nn.Module):
    def __init__(self, in_channels, inter_channels=None, dimension=2, sub_sample=False):
        super(SpatiotemporalAttentionFullNotWeightShared, self).__init__()
        assert dimension in [2, ]
        self.dimension = dimension
        self.sub_sample = sub_sample
        self.in_channels = in_channels
        self.inter_channels = inter_channels

        if self.inter_channels is None:
            self.inter_channels = in_channels // 2
            if self.inter_channels == 0:
                self.inter_channels = 1

        self.g1 = nn.Sequential(
            nn.BatchNorm2d(self.in_channels),
            nn.Conv2d(in_channels=self.in_channels, out_channels=self.inter_channels,
                      kernel_size=1, stride=1, padding=0)
        )
        self.g2 = nn.Sequential(
            nn.BatchNorm2d(self.in_channels),
            nn.Conv2d(in_channels=self.in_channels, out_channels=self.inter_channels,
                      kernel_size=1, stride=1, padding=0),
        )

        self.W1 = nn.Sequential(
            nn.Conv2d(in_channels=self.inter_channels, out_channels=self.in_channels,
                      kernel_size=1, stride=1, padding=0),
            nn.BatchNorm2d(self.in_channels)
        )
        self.W2 = nn.Sequential(
            nn.Conv2d(in_channels=self.inter_channels, out_channels=self.in_channels,
                      kernel_size=1, stride=1, padding=0),
            nn.BatchNorm2d(self.in_channels)
        )
        self.theta = nn.Sequential(
            nn.BatchNorm2d(self.in_channels),
            nn.Conv2d(in_channels=self.in_channels, out_channels=self.inter_channels,
                      kernel_size=1, stride=1, padding=0),
        )
        self.phi = nn.Sequential(
            nn.BatchNorm2d(self.in_channels),
            nn.Conv2d(in_channels=self.in_channels, out_channels=self.inter_channels,
                      kernel_size=1, stride=1, padding=0),
        )

    def forward(self, x1, x2):
        """
        :param x1, x2: bitemporal feature maps of shape (b, c, h, w)
        :return: the two attended feature maps, each with a residual connection
        """
        batch_size = x1.size(0)
        g_x11 = self.g1(x1).reshape(batch_size, self.inter_channels, -1)
        g_x12 = g_x11.permute(0, 2, 1)
        g_x21 = self.g2(x2).reshape(batch_size, self.inter_channels, -1)
        g_x22 = g_x21.permute(0, 2, 1)

        theta_x1 = self.theta(x1).reshape(batch_size, self.inter_channels, -1)
        theta_x2 = theta_x1.permute(0, 2, 1)

        phi_x1 = self.phi(x2).reshape(batch_size, self.inter_channels, -1)
        phi_x2 = phi_x1.permute(0, 2, 1)

        energy_time_1 = torch.matmul(theta_x1, phi_x2)
        energy_time_2 = energy_time_1.permute(0, 2, 1)
        energy_space_1 = torch.matmul(theta_x2, phi_x1)
        energy_space_2 = energy_space_1.permute(0, 2, 1)

        energy_time_1s = F.softmax(energy_time_1, dim=-1)
        energy_time_2s = F.softmax(energy_time_2, dim=-1)
        energy_space_2s = F.softmax(energy_space_1, dim=-2)
        energy_space_1s = F.softmax(energy_space_2, dim=-2)
        # energy_time_2s: C1*S(C2), g_x11: C1*H1W1, energy_space_2s: S(H2W2)*H1W1 -> C1*H1W1
        y1 = torch.matmul(torch.matmul(energy_time_2s, g_x11), energy_space_2s).contiguous()  # C2*H2W2
        # energy_time_1s: C2*S(C1), g_x21: C2*H2W2, energy_space_1s: S(H1W1)*H2W2 -> C2*H2W2
        y2 = torch.matmul(torch.matmul(energy_time_1s, g_x21), energy_space_1s).contiguous()  # C1*H1W1
        y1 = y1.reshape(batch_size, self.inter_channels, *x2.size()[2:])
        y2 = y2.reshape(batch_size, self.inter_channels, *x1.size()[2:])
        return x1 + self.W1(y1), x2 + self.W2(y2)


if __name__ == '__main__':
    in_channels = 64
    batch_size = 8
    height = 32
    width = 32

    block = SpatiotemporalAttentionFullNotWeightShared(in_channels=in_channels)
    input1 = torch.rand(batch_size, in_channels, height, width)
    input2 = torch.rand(batch_size, in_channels, height, width)
    output1, output2 = block(input1, input2)
    print(f"Input1 size: {input1.size()}")
    print(f"Input2 size: {input2.size()}")
    print(f"Output1 size: {output1.size()}")
    print(f"Output2 size: {output2.size()}")

Binary file not shown.

componets/UFFC.py Normal file (+123)

@@ -0,0 +1,123 @@
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

"""ICCV 2023
The recently proposed image-inpainting method LaMa builds its network on the Fast Fourier Convolution (FFC), which was originally proposed for high-level vision tasks such as image classification.
FFC gives fully convolutional networks a global receptive field already in their early layers. Thanks to the unique character of the FFC module, LaMa can produce robust repeating textures,
which previous inpainting methods could not achieve. However, is the original FFC module suitable for low-level vision tasks such as image inpainting?
This paper analyzes the fundamental flaws of using FFC in image inpainting, namely 1) spectrum shifting, 2) unexpected spatial activations, and 3) a limited frequency receptive field.
Such flaws make FFC-based inpainting frameworks struggle to generate complex textures and to achieve faithful reconstructions.
Based on this analysis, a novel Unbiased Fast Fourier Convolution (UFFC) module is proposed, which modifies the original FFC module with
1) a range transform and its inverse, 2) absolute position embedding, 3) dynamic skip connections, and 4) adaptive clipping, overcoming the flaws above
and achieving better inpainting results. Extensive experiments on several benchmark datasets demonstrate the effectiveness of the method, outperforming state-of-the-art approaches in both texture-capturing ability and expressiveness.
"""


class FourierUnit_modified(nn.Module):
    def __init__(self, in_channels, out_channels, groups=1, spatial_scale_factor=None, spatial_scale_mode='bilinear',
                 spectral_pos_encoding=False, use_se=False, ffc3d=False, fft_norm='ortho'):
        # bn_layer not used
        super(FourierUnit_modified, self).__init__()
        self.groups = groups
        self.input_shape = 32  # NOTE: must be changed manually to match the input resolution
        self.in_channels = in_channels

        self.locMap = nn.Parameter(torch.rand(self.input_shape, self.input_shape // 2 + 1))

        self.lambda_base = nn.Parameter(torch.tensor(0.), requires_grad=True)

        self.conv_layer_down55 = torch.nn.Conv2d(in_channels=in_channels * 2 + 1,  # +1 for locMap
                                                 out_channels=out_channels * 2,
                                                 kernel_size=1, stride=1, padding=0, dilation=1, groups=self.groups,
                                                 bias=False, padding_mode='reflect')
        self.conv_layer_down55_shift = torch.nn.Conv2d(in_channels=in_channels * 2 + 1,  # +1 for locMap
                                                       out_channels=out_channels * 2,
                                                       kernel_size=3, stride=1, padding=2, dilation=2,
                                                       groups=self.groups, bias=False, padding_mode='reflect')

        self.norm = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)

        self.spatial_scale_factor = spatial_scale_factor
        self.spatial_scale_mode = spatial_scale_mode
        self.spectral_pos_encoding = spectral_pos_encoding
        self.ffc3d = ffc3d
        self.fft_norm = fft_norm

        self.img_freq = None
        self.distill = None

    def forward(self, x):
        batch = x.shape[0]

        if self.spatial_scale_factor is not None:
            orig_size = x.shape[-2:]
            x = F.interpolate(x, scale_factor=self.spatial_scale_factor, mode=self.spatial_scale_mode,
                              align_corners=False)

        fft_dim = (-3, -2, -1) if self.ffc3d else (-2, -1)
        ffted = torch.fft.rfftn(x, dim=fft_dim, norm=self.fft_norm)
        ffted = torch.stack((ffted.real, ffted.imag), dim=-1)
        ffted = ffted.permute(0, 1, 4, 2, 3).contiguous()  # (batch, c, 2, h, w/2+1)
        ffted = ffted.view((batch, -1,) + ffted.size()[3:])

        locMap = self.locMap.expand_as(ffted[:, :1, :, :])  # B 1 H' W'
        ffted_copy = ffted.clone()

        cat_img_mask_freq = torch.cat((ffted[:, :self.in_channels, :, :],
                                       ffted[:, self.in_channels:, :, :],
                                       locMap), dim=1)

        ffted = self.conv_layer_down55(cat_img_mask_freq)
        ffted = torch.fft.fftshift(ffted, dim=-2)

        ffted = self.relu(ffted)

        locMap_shift = torch.fft.fftshift(locMap, dim=-2)  ## ONLY IF NOT SHIFT BACK

        # REPEAT CONV
        cat_img_mask_freq1 = torch.cat((ffted[:, :self.in_channels, :, :],
                                        ffted[:, self.in_channels:, :, :],
                                        locMap_shift), dim=1)

        ffted = self.conv_layer_down55_shift(cat_img_mask_freq1)
        ffted = torch.fft.fftshift(ffted, dim=-2)

        lambda_base = torch.sigmoid(self.lambda_base)

        ffted = ffted_copy * lambda_base + ffted * (1 - lambda_base)

        # irfft
        ffted = ffted.view((batch, -1, 2,) + ffted.size()[2:]).permute(
            0, 1, 3, 4, 2).contiguous()  # (batch, c, t, h, w/2+1, 2)
        ffted = torch.complex(ffted[..., 0], ffted[..., 1])

        ifft_shape_slice = x.shape[-3:] if self.ffc3d else x.shape[-2:]
        output = torch.fft.irfftn(ffted, s=ifft_shape_slice, dim=fft_dim, norm=self.fft_norm)

        if self.spatial_scale_factor is not None:
            output = F.interpolate(output, size=orig_size, mode=self.spatial_scale_mode, align_corners=False)

        epsilon = 0.5
        output = output - torch.mean(output) + torch.mean(x)
        output = torch.clip(output, float(x.min() - epsilon), float(x.max() + epsilon))

        self.distill = output  # for self perc

        return output


if __name__ == '__main__':
    in_channels = 16
    out_channels = 16

    block = FourierUnit_modified(in_channels=in_channels, out_channels=out_channels)
    input_tensor = torch.rand(8, in_channels, 32, 32)
    output = block(input_tensor)
    print("Input size:", input_tensor.size())
    print("Output size:", output.size())

componets/whaiutil.py Normal file (+42)

@@ -0,0 +1,42 @@
import os

from PIL import Image


def transfer(input_path, quality=20, resize_factor=0.1):
    # Derive the output path from the input path:
    # os.path.basename -> file name including its extension,
    # os.path.splitext -> (name, extension) tuple,
    # os.path.dirname  -> directory containing the file.
    filename_with_extension = os.path.basename(input_path)
    filename, file_extension = os.path.splitext(filename_with_extension)
    output_folder = os.path.dirname(input_path)
    output_path = os.path.join(output_folder, filename + '.jpg')

    # Open the TIFF image and shrink it by resize_factor
    img = Image.open(input_path)
    new_width = int(img.width * resize_factor)
    new_height = int(img.height * resize_factor)
    resized_img = img.resize((new_width, new_height))

    # Convert to RGB (dropping any alpha channel), then save as a
    # compressed JPEG with the given quality setting
    rgb_img = resized_img.convert('RGB')
    rgb_img.save(output_path, 'JPEG', quality=quality)

    print(f'{output_path} converted')
    return output_path
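A minimal usage sketch for the helper above (the input file name is hypothetical):

if __name__ == '__main__':
    # convert a TIFF into a downscaled, compressed JPEG saved next to it
    jpg_path = transfer('example.tif', quality=20, resize_factor=0.1)
    print(jpg_path)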

@@ -39,16 +39,17 @@ def is_low_contrast(image, fraction_threshold=0.1, lower_percentile=10,
     ratio = (limits[1] - limits[0]) / limits[1]
     return ratio < fraction_threshold

-data_name="MSRS_train"
-img_size=128 #patch size
+data_name="YYX_sar_opr_data"
+img_size=256 #patch size
 stride=200 #patch stride

-IR_files = sorted(get_img_file(r"MSRS_train/MSRS-main/train/ir"))
-VIS_files = sorted(get_img_file(r"MSRS_train/MSRS-main/train/vi"))
+IR_files = sorted(get_img_file(r"/media/star/8TB/whaiDownload/YYX-OPT-SAR-main/SAR_1"))
+VIS_files = sorted(get_img_file(r"/media/star/8TB/whaiDownload/YYX-OPT-SAR-main/OPR_1"))

 assert len(IR_files) == len(VIS_files)
-h5f = h5py.File(os.path.join('.\\data',
-                             data_name+'_imgsize_'+str(img_size)+"_stride_"+str(stride)+'.h5'),
+h5path = os.path.join('/home/star/whaiDir/PFCFuse/data/',
+                      data_name+'_imgsize_'+str(img_size)+"_stride_"+str(stride)+'.h5')
+h5f = h5py.File(h5path,
                 'w')
 h5_ir = h5f.create_group('ir_patchs')
 h5_vis = h5f.create_group('vis_patchs')
@@ -80,8 +81,7 @@ for i in tqdm(range(len(IR_files))):
     h5f.close()

-with h5py.File(os.path.join('data',
-                            data_name+'_imgsize_'+str(img_size)+"_stride_"+str(stride)+'.h5'),"r") as f:
+with h5py.File(h5path,"r") as f:
     for key in f.keys():
         print(f[key], key, f[key].name)
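For reference, the patches written by this script can be read back with h5py. A minimal sketch (the file name follows the naming scheme in the diff above; the per-patch key '0' inside each group is an assumption):

import h5py

h5path = '/home/star/whaiDir/PFCFuse/data/YYX_sar_opr_data_imgsize_256_stride_200.h5'
with h5py.File(h5path, 'r') as f:
    ir_patch = f['ir_patchs']['0'][()]    # one SAR/IR patch as a numpy array
    vis_patch = f['vis_patchs']['0'][()]  # the matching optical/visible patch
    print(ir_patch.shape, vis_patch.shape)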

logs/20241005.log Normal file (+35)

@@ -0,0 +1,35 @@
/home/star/anaconda3/envs/pfcfuse/bin/python /home/star/whaiDir/PFCFuse/test_IVF.py
# base pcffuse
================================================================================
The test result of TNO :
19.png
05.png
21.png
18.png
15.png
22.png
14.png
13.png
08.png
01.png
02.png
03.png
25.png
17.png
11.png
16.png
06.png
07.png
09.png
10.png
12.png
23.png
24.png
20.png
04.png
EN SD SF MI SCD VIF Qabf SSIM
PFCFuse 2.39 33.82 11.32 0.81 0.8 0.12 0.07 0.11
================================================================================
Process finished with exit code 0

logs/20241007_whai.log Normal file (+33)

@@ -0,0 +1,33 @@
/home/star/anaconda3/envs/pfcfuse/bin/python /home/star/whaiDir/PFCFuse/test_IVF.py
================================================================================
The test result of TNO :
19.png
05.png
21.png
18.png
15.png
22.png
14.png
13.png
08.png
01.png
02.png
03.png
25.png
17.png
11.png
16.png
06.png
07.png
09.png
10.png
12.png
23.png
24.png
20.png
04.png
EN SD SF MI SCD VIF Qabf SSIM
PFCFuse 7.01 40.67 15.39 1.53 1.76 0.64 0.53 0.95
================================================================================

@@ -0,0 +1,89 @@
================================================================================
The test result of TNO :
19.png
05.png
21.png
18.png
15.png
22.png
14.png
13.png
08.png
01.png
02.png
03.png
25.png
17.png
11.png
16.png
06.png
07.png
09.png
10.png
12.png
23.png
24.png
20.png
04.png
EN SD SF MI SCD VIF Qabf SSIM
PFCFuse 7.14 46.48 13.18 2.22 1.76 0.79 0.56 1.02
================================================================================
================================================================================
The test result of RoadScene :
FLIR_07206.jpg
FLIR_08202.jpg
FLIR_05893.jpg
FLIR_06974.jpg
FLIR_04424.jpg
FLIR_08284.jpg
FLIR_07786.jpg
FLIR_08021.jpg
FLIR_07968.jpg
FLIR_01130.jpg
FLIR_06993.jpg
FLIR_07190.jpg
FLIR_06570.jpg
FLIR_07809.jpg
FLIR_06430.jpg
FLIR_08592.jpg
FLIR_00211.jpg
FLIR_08721.jpg
FLIR_05955.jpg
FLIR_04688.jpg
FLIR_07732.jpg
FLIR_06392.jpg
FLIR_00977.jpg
FLIR_05105.jpg
FLIR_04269.jpg
FLIR_07970.jpg
FLIR_05005.jpg
FLIR_07209.jpg
FLIR_07555.jpg
FLIR_06325.jpg
FLIR_04943.jpg
FLIR_video_02829.jpg
FLIR_08248.jpg
FLIR_04484.jpg
FLIR_08058.jpg
FLIR_06795.jpg
FLIR_06995.jpg
FLIR_05879.jpg
FLIR_04593.jpg
FLIR_08094.jpg
FLIR_08526.jpg
FLIR_08858.jpg
FLIR_09465.jpg
FLIR_05064.jpg
FLIR_05857.jpg
FLIR_05914.jpg
FLIR_04722.jpg
FLIR_06506.jpg
FLIR_06282.jpg
FLIR_04512.jpg
EN SD SF MI SCD VIF Qabf SSIM
PFCFuse 7.41 52.99 15.81 2.37 1.78 0.71 0.55 0.96
================================================================================

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

@@ -0,0 +1,24 @@
2.4.1+cu121
True
Model: PFCFuse
Number of epochs: 60
Epoch gap: 40
Learning rate: 0.0001
Weight decay: 0
Batch size: 1
GPU number: 0
Coefficient of MSE loss VF: 1.0
Coefficient of MSE loss IF: 1.0
Coefficient of RMI loss VF: 1.0
Coefficient of RMI loss IF: 1.0
Coefficient of Cosine loss VF: 1.0
Coefficient of Cosine loss IF: 1.0
Coefficient of Decomposition loss: 2.0
Coefficient of Total Variation loss: 5.0
Clip gradient norm value: 0.01
Optimization step: 20
Optimization gamma: 0.5
[Epoch 0/60] [Batch 0/6487] [loss: 10.193036] ETA: 9 days, 21
[Epoch 0/60] [Batch 1/6487] [loss: 4.166963] ETA: 11:49:56.5
[Epoch 0/60] [Batch 2/6487] [loss: 10.681509] ETA: 10:23:19.1
[Epoch 0/60] [Batch 3/6487] [loss: 6.257133] ETA: 10:31:48.3
[Epoch 0/60] [Batch 4/6487] [loss: 13.018341] ETA: 10:32:54.2
[Epoch 0/60] [Batch 5/6487] [loss: 11.268185] ETA: 10:27:32.2
[Epoch 0/60] [Batch 6/6487] [loss: 6.920656] ETA: 10:34:01.5
[Epoch 0/60] [Batch 7/6487] [loss: 4.666215] ETA: 10:32:45.3
[Epoch 0/60] [Batch 8/6487] [loss: 10.787085] ETA: 10:26:01.9
[Epoch 0/60] [Batch 9/6487] [loss: 5.754866] ETA: 10:34:34.2
[Epoch 0/60] [Batch 10/6487] [loss: 28.760792] ETA: 10:36:32.6
[Epoch 0/60] [Batch 11/6487] [loss: 8.672796] ETA: 10:25:11.9
[Epoch 0/60] [Batch 12/6487] [loss: 14.300608] ETA: 10:28:19.9
[Epoch 0/60] [Batch 13/6487] [loss: 11.821722] ETA: 10:34:18.6
[Epoch 0/60] [Batch 14/6487] [loss: 7.627745] ETA: 10:31:44.8
[Epoch 0/60] [Batch 15/6487] [loss: 5.722600] ETA: 10:34:17.4
[Epoch 0/60] [Batch 16/6487] [loss: 10.423873] ETA: 11:33:27.1
[Epoch 0/60] [Batch 17/6487] [loss: 4.454098] ETA: 9:37:13.67
[Epoch 0/60] [Batch 18/6487] [loss: 3.820719] ETA: 9:33:57.42
[Epoch 0/60] [Batch 19/6487] [loss: 6.564124] ETA: 9:41:22.09
[Epoch 0/60] [Batch 20/6487] [loss: 5.406681] ETA: 9:47:30.11
[Epoch 0/60] [Batch 21/6487] [loss: 25.275440] ETA: 9:39:29.91
[Epoch 0/60] [Batch 22/6487] [loss: 4.228334] ETA: 9:42:15.45
[Epoch 0/60] [Batch 23/6487] [loss: 22.508118] ETA: 9:38:18.10
[Epoch 0/60] [Batch 24/6487] [loss: 5.062001] ETA: 9:46:09.29
[Epoch 0/60] [Batch 25/6487] [loss: 3.157355] ETA: 9:41:30.09
[Epoch 0/60] [Batch 26/6487] [loss: 6.438435] ETA: 10:02:51.9
[Epoch 0/60] [Batch 27/6487] [loss: 7.430470] ETA: 9:18:12.94
[Epoch 0/60] [Batch 28/6487] [loss: 3.783903] ETA: 10:41:13.9
[Epoch 0/60] [Batch 29/6487] [loss: 2.954306] ETA: 9:44:25.10
[Epoch 0/60] [Batch 30/6487] [loss: 5.863827] ETA: 9:35:13.84
[Epoch 0/60] [Batch 31/6487] [loss: 6.467144] ETA: 9:46:19.80
[Epoch 0/60] [Batch 32/6487] [loss: 4.801052] ETA: 9:32:17.18
[Epoch 0/60] [Batch 33/6487] [loss: 5.658401] ETA: 9:31:10.28
[Epoch 0/60] [Batch 34/6487] [loss: 2.085633] ETA: 9:36:47.39
[Epoch 0/60] [Batch 35/6487] [loss: 15.402915] ETA: 9:40:51.43
[Epoch 0/60] [Batch 36/6487] [loss: 3.181264] ETA: 9:33:06.65
[Epoch 0/60] [Batch 37/6487] [loss: 3.883055] ETA: 9:42:29.60
[Epoch 0/60] [Batch 38/6487] [loss: 3.342676] ETA: 10:07:02.2
[Epoch 0/60] [Batch 39/6487] [loss: 2.589705] ETA: 9:36:43.32
[Epoch 0/60] [Batch 40/6487] [loss: 3.742121] ETA: 9:42:57.54
[Epoch 0/60] [Batch 41/6487] [loss: 2.732829] ETA: 9:36:54.65
[Epoch 0/60] [Batch 42/6487] [loss: 6.655626] ETA: 9:42:20.71
[Epoch 0/60] [Batch 43/6487] [loss: 1.822412] ETA: 9:38:02.02
[Epoch 0/60] [Batch 44/6487] [loss: 2.875143] ETA: 9:41:02.96
[Epoch 0/60] [Batch 45/6487] [loss: 2.319836] ETA: 9:38:16.23
[Epoch 0/60] [Batch 46/6487] [loss: 2.354790] ETA: 9:39:08.93
[Epoch 0/60] [Batch 47/6487] [loss: 1.986412] ETA: 9:52:11.40
[Epoch 0/60] [Batch 48/6487] [loss: 2.154071] ETA: 10:08:20.0
[Epoch 0/60] [Batch 49/6487] [loss: 1.425418] ETA: 9:54:04.42
[Epoch 0/60] [Batch 50/6487] [loss: 1.988360] ETA: 9:30:25.08
[Epoch 0/60] [Batch 51/6487] [loss: 4.090429] ETA: 9:43:53.52
[Epoch 0/60] [Batch 52/6487] [loss: 1.924778] ETA: 9:46:19.38
[Epoch 0/60] [Batch 53/6487] [loss: 2.191964] ETA: 9:46:59.93
[Epoch 0/60] [Batch 54/6487] [loss: 2.032799] ETA: 9:46:14.01
[Epoch 0/60] [Batch 55/6487] [loss: 1.923933] ETA: 9:44:21.65
[Epoch 0/60] [Batch 56/6487] [loss: 1.656838] ETA: 9:56:15.90
[Epoch 0/60] [Batch 57/6487] [loss: 1.656845] ETA: 10:21:26.1
[Epoch 0/60] [Batch 58/6487] [loss: 1.157820] ETA: 10:44:47.3
[Epoch 0/60] [Batch 59/6487] [loss: 1.652715] ETA: 10:46:39.9
[Epoch 0/60] [Batch 60/6487] [loss: 1.633865] ETA: 10:23:34.7
[Epoch 0/60] [Batch 61/6487] [loss: 1.290259] ETA: 9:24:06.12
Traceback (most recent call last):
  File "/home/star/whaiDir/PFCFuse/train.py", line 232, in <module>
    loss.item(),
KeyboardInterrupt
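The learning rate, weight decay, clip-norm, optimization step, and gamma values logged above map onto a standard PyTorch training setup. A minimal sketch under that assumption (the optimizer type, model, and loss are placeholders, not the PFCFuse ones):

import torch
import torch.nn as nn

model = nn.Conv2d(1, 1, 3, padding=1)                   # placeholder model
data = [torch.rand(1, 1, 128, 128) for _ in range(4)]   # placeholder batches

# 'Learning rate: 0.0001', 'Weight decay: 0' (Adam assumed; not shown in the log)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4, weight_decay=0)
# 'Optimization step: 20', 'Optimization gamma: 0.5' -> halve the LR every 20 epochs
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=20, gamma=0.5)

for epoch in range(2):  # 60 in the logged run
    for x in data:
        optimizer.zero_grad()
        loss = nn.functional.mse_loss(model(x), x)  # stand-in for the PFCFuse losses
        loss.backward()
        # 'Clip gradient norm value: 0.01'
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=0.01)
        optimizer.step()
    scheduler.step()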

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

@@ -0,0 +1,38 @@
/home/star/anaconda3/envs/pfcfuse/bin/python /home/star/whaiDir/PFCFuse/test_IVF.py
================================================================================
The test result of TNO :
19.png
05.png
21.png
18.png
15.png
22.png
14.png
13.png
08.png
01.png
02.png
03.png
25.png
17.png
11.png
16.png
06.png
07.png
09.png
10.png
12.png
23.png
24.png
20.png
04.png
EN SD SF MI SCD VIF Qabf SSIM
PFCFuse 7.01 40.4 15.51 1.55 1.75 0.66 0.54 0.96
================================================================================
Process finished with exit code 0

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

net.py (247)

@ -6,6 +6,9 @@ import torch.utils.checkpoint as checkpoint
from timm.models.layers import DropPath, to_2tuple, trunc_normal_
from einops import rearrange
from componets.SCSA import SCSA
from componets.TIAM import SpatiotemporalAttentionFullNotWeightShared
from componets.WTConvCV2 import WTConv2d
@ -145,6 +148,83 @@ class PoolMlp(nn.Module):
# x = x + self.drop_path(self.token_mixer(self.norm1(x))) # 匹配cddfuse
# x = x + self.drop_path(self.poolmlp(self.norm2(x)))
# return x
class DetailFeatureFusion(nn.Module):
def __init__(self, num_layers=3):
super(DetailFeatureFusion, self).__init__()
INNmodules = [DetailNode(useBlock=2) for _ in range(num_layers)]
self.net = nn.Sequential(*INNmodules)
self.enhancement_module = WTConv2d(32, 32)
def forward(self, x): # 1 64 128 128
z1, z2 = x[:, :x.shape[1] // 2], x[:, x.shape[1] // 2:x.shape[1]] # 1 32 128 128
# 增强并添加残差连接
enhanced_z1 = self.enhancement_module(z1)
enhanced_z2 = self.enhancement_module(z2)
for layer in self.net:
z1, z2 = layer(z1, z2)
# 残差连接
z1 = z1 + enhanced_z1
z2 = z2 + enhanced_z2
return torch.cat((z1, z2), dim=1)
class BaseFeatureFusion(nn.Module):
def __init__(self, dim, pool_size=3, mlp_ratio=4.,
act_layer=nn.GELU,
# norm_layer=nn.LayerNorm,
drop=0., drop_path=0.,
use_layer_scale=True, layer_scale_init_value=1e-5):
super().__init__()
self.norm1 = LayerNorm(dim, 'WithBias')
# self.token_mixer = SMFA(dim=dim)
self.token_mixer = SCSA(dim=dim,head_num=8)
# self.token_mixer = Pooling(kernel_size=pool_size) # vits是msaMLPs是mlp这个用pool来替代
self.norm2 = LayerNorm(dim, 'WithBias')
mlp_hidden_dim = int(dim * mlp_ratio)
self.poolmlp = PoolMlp(in_features=dim, hidden_features=mlp_hidden_dim,
act_layer=act_layer, drop=drop)
# The following two techniques are useful to train deep PoolFormers.
self.drop_path = DropPath(drop_path) if drop_path > 0. \
else nn.Identity()
self.use_layer_scale = use_layer_scale
if use_layer_scale:
self.layer_scale_1 = nn.Parameter(
torch.ones(dim, dtype=torch.float32) * layer_scale_init_value)
self.layer_scale_2 = nn.Parameter(
torch.ones(dim, dtype=torch.float32) * layer_scale_init_value)
def forward(self, x): # 1 64 128 128
if self.use_layer_scale:
# self.layer_scale_1(64,)
# wtConvX = self.WTConv2d(x)
tmp1 = self.layer_scale_1.unsqueeze(-1).unsqueeze(-1) # 64 1 1
normal = self.norm1(x) # 1 64 128 128
token_mix = self.token_mixer(normal) # 1 64 128 128
x = (x +
self.drop_path(tmp1 * token_mix)
# 该表达式将 self.layer_scale_1 这个一维张量(或变量)在维度末尾添加两个新的维度,使其从一维变为三维。这通常用于使其能够与三维的特征图进行广播操作,如元素相乘。具体用途可能包括调整卷积层或注意力机制中的权重。
)
# pol = self.poolmlp(self.norm2(x))
#
# x = x + self.drop_path(
# self.layer_scale_2.unsqueeze(-1).unsqueeze(-1)
# * pol)
else:
x = x + self.drop_path(self.token_mixer(self.norm1(x))) # 匹配cddfuse
x = x + self.drop_path(self.poolmlp(self.norm2(x)))
return x
class BaseFeatureExtraction(nn.Module):
def __init__(self, dim, pool_size=3, mlp_ratio=4.,
@ -155,7 +235,7 @@ class BaseFeatureExtraction(nn.Module):
super().__init__()
self.WTConv2d = WTConv2d(dim, dim)
self.norm1 = LayerNorm(dim, 'WithBias')
self.token_mixer = Pooling(kernel_size=pool_size) # vits是msaMLPs是mlp这个用pool来替代
self.norm2 = LayerNorm(dim, 'WithBias')
@@ -178,21 +258,77 @@ class BaseFeatureExtraction(nn.Module):
    def forward(self, x):  # 1 64 128 128
        if self.use_layer_scale:
            # self.layer_scale_1 has shape (64,)
            # wtConvX = self.WTConv2d(x)
            tmp1 = self.layer_scale_1.unsqueeze(-1).unsqueeze(-1)  # 64 1 1
            normal = self.norm1(x)  # 1 64 128 128
            token_mix = self.token_mixer(normal)  # 1 64 128 128
            x = (x +
                 self.drop_path(tmp1 * token_mix)
                 # see the broadcasting note above: unsqueeze twice so the 1-D scale
                 # vector can multiply the 3-D feature map element-wise
                 )
            # pol = self.poolmlp(self.norm2(x))
            #
            # x = x + self.drop_path(
            #     self.layer_scale_2.unsqueeze(-1).unsqueeze(-1)
            #     * pol)
        else:
            x = x + self.drop_path(self.token_mixer(self.norm1(x)))  # to match CDDFuse
            x = x + self.drop_path(self.poolmlp(self.norm2(x)))
        return x
class BaseFeatureExtractionSAR(nn.Module):
    def __init__(self, dim, pool_size=3, mlp_ratio=4.,
                 act_layer=nn.GELU,
                 # norm_layer=nn.LayerNorm,
                 drop=0., drop_path=0.,
                 use_layer_scale=True, layer_scale_init_value=1e-5):
        super().__init__()
        self.norm1 = LayerNorm(dim, 'WithBias')
        # self.token_mixer = SMFA(dim=dim)
        self.token_mixer = SCSA(dim=dim, head_num=8)
        # self.token_mixer = Pooling(kernel_size=pool_size)  # ViTs use MSA, MLP models use MLPs; pooling stands in as the token mixer here
        # self.norm2 = LayerNorm(dim, 'WithBias')
        mlp_hidden_dim = int(dim * mlp_ratio)
        # self.poolmlp = PoolMlp(in_features=dim, hidden_features=mlp_hidden_dim,
        #                        act_layer=act_layer, drop=drop)

        # The following two techniques are useful to train deep PoolFormers.
        self.drop_path = DropPath(drop_path) if drop_path > 0. \
            else nn.Identity()
        self.use_layer_scale = use_layer_scale
        if use_layer_scale:
            self.layer_scale_1 = nn.Parameter(
                torch.ones(dim, dtype=torch.float32) * layer_scale_init_value)
            self.layer_scale_2 = nn.Parameter(
                torch.ones(dim, dtype=torch.float32) * layer_scale_init_value)
    def forward(self, x):  # 1 64 128 128
        if self.use_layer_scale:
            # self.layer_scale_1 has shape (64,)
            # wtConvX = self.WTConv2d(x)
            tmp1 = self.layer_scale_1.unsqueeze(-1).unsqueeze(-1)  # 64 1 1
            normal = self.norm1(x)  # 1 64 128 128
            token_mix = self.token_mixer(normal)  # 1 64 128 128
-           x = self.WTConv2d(x)
            x = (x +
-                self.drop_path(
-                    tmp1 * token_mix
-                )
+                self.drop_path(tmp1 * token_mix)
                 # unsqueeze twice so the 1-D scale vector broadcasts over the 3-D feature map
                 )
-           x = x + self.drop_path(
-               self.layer_scale_2.unsqueeze(-1).unsqueeze(-1)
-               * self.poolmlp(self.norm2(x)))
+           # pol = self.poolmlp(self.norm2(x))
+           #
+           # x = x + self.drop_path(
+           #     self.layer_scale_2.unsqueeze(-1).unsqueeze(-1)
+           #     * pol)
        else:
            x = x + self.drop_path(self.token_mixer(self.norm1(x)))  # to match CDDFuse
            x = x + self.drop_path(self.poolmlp(self.norm2(x)))
@@ -219,15 +355,41 @@ class InvertedResidualBlock(nn.Module):
def forward(self, x):
return self.bottleneckBlock(x)
class DepthwiseSeparableConvBlock(nn.Module):
def __init__(self, inp, oup, kernel_size=3, stride=1, padding=1):
super(DepthwiseSeparableConvBlock, self).__init__()
self.depthwise = nn.Conv2d(inp, inp, kernel_size, stride, padding, groups=inp, bias=False)
self.pointwise = nn.Conv2d(inp, oup, 1, bias=False)
self.bn = nn.BatchNorm2d(oup)
self.relu = nn.ReLU(inplace=True)
def forward(self, x):
x = self.depthwise(x)
x = self.pointwise(x)
x = self.bn(x)
x = self.relu(x)
return x
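# Parameter note (sketch, not part of the diff): factoring a dense 3x3 convolution
# into depthwise + pointwise costs inp*3*3 + inp*oup weights instead of inp*oup*3*3;
# for inp = oup = 32 that is 288 + 1024 = 1312 parameters versus 9216, roughly 7x fewer.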
class DetailNode(nn.Module):
    # <img src="http://42.192.130.83:9000/picgo/imgs/小绿鲸英文文献阅读器_ELTITYqm5G.png" />
-   def __init__(self):
+   def __init__(self, useBlock=0):
        super(DetailNode, self).__init__()
-       self.theta_phi = InvertedResidualBlock(inp=32, oup=32, expand_ratio=2)
-       self.theta_rho = InvertedResidualBlock(inp=32, oup=32, expand_ratio=2)
-       self.theta_eta = InvertedResidualBlock(inp=32, oup=32, expand_ratio=2)
+       if useBlock == 0:
+           self.theta_phi = DepthwiseSeparableConvBlock(inp=32, oup=32)
+           self.theta_rho = DepthwiseSeparableConvBlock(inp=32, oup=32)
+           self.theta_eta = DepthwiseSeparableConvBlock(inp=32, oup=32)
+       elif useBlock == 1:
+           self.theta_phi = InvertedResidualBlock(inp=32, oup=32, expand_ratio=2)
+           self.theta_rho = InvertedResidualBlock(inp=32, oup=32, expand_ratio=2)
+           self.theta_eta = InvertedResidualBlock(inp=32, oup=32, expand_ratio=2)
+       else:
+           self.theta_phi = InvertedResidualBlock(inp=32, oup=32, expand_ratio=2)
+           self.theta_rho = InvertedResidualBlock(inp=32, oup=32, expand_ratio=2)
+           self.theta_eta = InvertedResidualBlock(inp=32, oup=32, expand_ratio=2)
        self.shffleconv = nn.Conv2d(64, 64, kernel_size=1,
                                    stride=1, padding=0, bias=True)
@@ -242,26 +404,27 @@ class DetailNode(nn.Module):
        z1 = z1 * torch.exp(self.theta_rho(z2)) + self.theta_eta(z2)
        return z1, z2
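# Coupling note (sketch, not part of the diff): the update above is an affine
# coupling step that is invertible in closed form given z2, which is what keeps
# the DetailNode / INN stack information-lossless:
#   z1_prev = (z1 - self.theta_eta(z2)) * torch.exp(-self.theta_rho(z2))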
class DetailFeatureExtractionSAR(nn.Module):
    def __init__(self, num_layers=3):
        super(DetailFeatureExtractionSAR, self).__init__()
        # useBlock == 1 selects the InvertedResidualBlock variant
        INNmodules = [DetailNode(useBlock=1) for _ in range(num_layers)]
        self.net = nn.Sequential(*INNmodules)

    def forward(self, x):
        z1, z2 = x[:, :x.shape[1] // 2], x[:, x.shape[1] // 2:x.shape[1]]
        for layer in self.net:
            z1, z2 = layer(z1, z2)
        return torch.cat((z1, z2), dim=1)
class DetailFeatureExtraction(nn.Module):
    def __init__(self, num_layers=3):
        super(DetailFeatureExtraction, self).__init__()
-       INNmodules = [DetailNode() for _ in range(num_layers)]
+       INNmodules = [DetailNode(useBlock=0) for _ in range(num_layers)]
        self.net = nn.Sequential(*INNmodules)
-       self.enhancement_module = nn.Sequential(
-           nn.Conv2d(32, 32, kernel_size=3, padding=1, bias=True),
-           nn.ReLU(inplace=True),
-           nn.Conv2d(32, 32, kernel_size=3, padding=1, bias=True),
-       )

-   def forward(self, x):  # 1 64 128 128
-       z1, z2 = x[:, :x.shape[1] // 2], x[:, x.shape[1] // 2:x.shape[1]]  # 1 32 128 128
-       # enhance and add back as a residual
-       enhanced_z1 = self.enhancement_module(z1)
-       enhanced_z2 = self.enhancement_module(z2)
-       # residual connection
-       z1 = z1 + enhanced_z1
-       z2 = z2 + enhanced_z2
+   def forward(self, x):
+       z1, z2 = x[:, :x.shape[1] // 2], x[:, x.shape[1] // 2:x.shape[1]]
        for layer in self.net:
            z1, z2 = layer(z1, z2)
        return torch.cat((z1, z2), dim=1)
@@ -428,6 +591,12 @@ class OverlapPatchEmbed(nn.Module):
return x
class Restormer_Encoder(nn.Module):
def __init__(self,
inp_channels=1,
@@ -441,21 +610,31 @@ class Restormer_Encoder(nn.Module):
):
super(Restormer_Encoder, self).__init__()
-       self.patch_embed = OverlapPatchEmbed(inp_channels, dim)
+       # distinguish (the two modalities)
+       self.patch_embed = OverlapPatchEmbed(inp_channels, dim)

        self.encoder_level1 = nn.Sequential(
            *[TransformerBlock(dim=dim, num_heads=heads[0], ffn_expansion_factor=ffn_expansion_factor,
                               bias=bias, LayerNorm_type=LayerNorm_type) for i in range(num_blocks[0])])
        self.baseFeature = BaseFeatureExtraction(dim=dim)
        self.detailFeature = DetailFeatureExtraction()
+       self.baseFeature_sar = BaseFeatureExtractionSAR(dim=dim)
+       self.detailFeature_sar = DetailFeatureExtractionSAR()

-   def forward(self, inp_img):
+   def forward(self, inp_img, is_sar=False):
        inp_enc_level1 = self.patch_embed(inp_img)
        out_enc_level1 = self.encoder_level1(inp_enc_level1)
-       base_feature = self.baseFeature(out_enc_level1)
-       detail_feature = self.detailFeature(out_enc_level1)
-       return base_feature, detail_feature, out_enc_level1
+       if is_sar:
+           base_feature = self.baseFeature_sar(out_enc_level1)      # 1 64 128 128
+           detail_feature = self.detailFeature_sar(out_enc_level1)  # 1 64 128 128
+           return base_feature, detail_feature, out_enc_level1      # 1 64 128 128
+       else:
+           base_feature = self.baseFeature(out_enc_level1)          # 1 64 128 128
+           detail_feature = self.detailFeature(out_enc_level1)      # 1 64 128 128
+           return base_feature, detail_feature, out_enc_level1      # 1 64 128 128
class Restormer_Decoder(nn.Module):
@@ -481,8 +660,12 @@ class Restormer_Decoder(nn.Module):
            nn.Conv2d(int(dim) // 2, out_channels, kernel_size=3,
                      stride=1, padding=1, bias=bias), )
        self.sigmoid = nn.Sigmoid()
+       self.spatiotemporalAttentionFullNotWeightShared = SpatiotemporalAttentionFullNotWeightShared(in_channels=dim)

    def forward(self, inp_img, base_feature, detail_feature):
+       base_feature, detail_feature = self.spatiotemporalAttentionFullNotWeightShared(base_feature, detail_feature)
        out_enc_level0 = torch.cat((base_feature, detail_feature), dim=1)
        out_enc_level0 = self.reduce_channel(out_enc_level0)
        out_enc_level1 = self.encoder_level2(out_enc_level0)
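Read together, the encoder and decoder changes imply a round trip along these lines (a minimal sketch assuming only the constructors and the is_sar flag shown in this diff; the sum fusion rule and the handling of the decoder's return value are placeholders, not the repository's actual pipeline):

```
import torch
from net import Restormer_Encoder, Restormer_Decoder

encoder = Restormer_Encoder()
decoder = Restormer_Decoder()
vis = torch.randn(1, 1, 128, 128)  # visible input, single channel
sar = torch.randn(1, 1, 128, 128)  # SAR input, single channel

base_v, detail_v, _ = encoder(vis)               # shared branch
base_s, detail_s, _ = encoder(sar, is_sar=True)  # SAR-specific branch

# placeholder fusion (sum); the decoder applies the TIAM attention pair internally
fused = decoder(vis, base_v + base_s, detail_v + detail_s)
```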

status.md  (new file, 136 lines)

@@ -0,0 +1,136 @@
PFCFuse
```angular2html
================================================================================
The test result of TNO :
19.png
05.png
21.png
18.png
15.png
22.png
14.png
13.png
08.png
01.png
02.png
03.png
25.png
17.png
11.png
16.png
06.png
07.png
09.png
10.png
12.png
23.png
24.png
20.png
04.png
         EN     SD     SF     MI     SCD    VIF    Qabf   SSIM
PFCFuse  7.14   46.48  13.18  2.22   1.76   0.79   0.56   1.02
================================================================================
================================================================================
The test result of RoadScene :
FLIR_07206.jpg
FLIR_08202.jpg
FLIR_05893.jpg
FLIR_06974.jpg
FLIR_04424.jpg
FLIR_08284.jpg
FLIR_07786.jpg
FLIR_08021.jpg
FLIR_07968.jpg
FLIR_01130.jpg
FLIR_06993.jpg
FLIR_07190.jpg
FLIR_06570.jpg
FLIR_07809.jpg
FLIR_06430.jpg
FLIR_08592.jpg
FLIR_00211.jpg
FLIR_08721.jpg
FLIR_05955.jpg
FLIR_04688.jpg
FLIR_07732.jpg
FLIR_06392.jpg
FLIR_00977.jpg
FLIR_05105.jpg
FLIR_04269.jpg
FLIR_07970.jpg
FLIR_05005.jpg
FLIR_07209.jpg
FLIR_07555.jpg
FLIR_06325.jpg
FLIR_04943.jpg
FLIR_video_02829.jpg
FLIR_08248.jpg
FLIR_04484.jpg
FLIR_08058.jpg
FLIR_06795.jpg
FLIR_06995.jpg
FLIR_05879.jpg
FLIR_04593.jpg
FLIR_08094.jpg
FLIR_08526.jpg
FLIR_08858.jpg
FLIR_09465.jpg
FLIR_05064.jpg
FLIR_05857.jpg
FLIR_05914.jpg
FLIR_04722.jpg
FLIR_06506.jpg
FLIR_06282.jpg
FLIR_04512.jpg
         EN     SD     SF     MI     SCD    VIF    Qabf   SSIM
PFCFuse  7.41   52.99  15.81  2.37   1.78   0.71   0.55   0.96
================================================================================
```
20241008
```
/home/star/anaconda3/envs/pfcfuse/bin/python /home/star/whaiDir/PFCFuse/test_IVF.py
================================================================================
The test result of TNO :
19.png
05.png
21.png
18.png
15.png
22.png
14.png
13.png
08.png
01.png
02.png
03.png
25.png
17.png
11.png
16.png
06.png
07.png
09.png
10.png
12.png
23.png
24.png
20.png
04.png
         EN     SD     SF     MI     SCD    VIF    Qabf   SSIM
PFCFuse  7.01   40.4   15.51  1.55   1.75   0.66   0.54   0.96
================================================================================
Process finished with exit code 0
```
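Of the metric columns above, EN and SD are the two that are simple to reproduce; a minimal sketch under the usual definitions (Shannon entropy of the 256-bin gray-level histogram, and intensity standard deviation), assuming an 8-bit grayscale input array; the others (MI, SCD, VIF, Qabf, SSIM) also need the source images and more machinery:

```
import numpy as np

def entropy_EN(img_u8):
    # Shannon entropy (bits) of the 256-bin gray-level histogram
    p = np.bincount(img_u8.ravel(), minlength=256) / img_u8.size
    p = p[p > 0]
    return float(-(p * np.log2(p)).sum())

def std_SD(img_u8):
    # standard deviation of pixel intensities
    return float(img_u8.astype(np.float64).std())
```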

test_IVF.py

@@ -1,3 +1,5 @@
+import datetime
+import cv2
from net import Restormer_Encoder, Restormer_Decoder, BaseFeatureExtraction, DetailFeatureExtraction
import os
@@ -11,16 +13,18 @@ import logging
warnings.filterwarnings("ignore")
logging.basicConfig(level=logging.CRITICAL)
+current_time = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

-ckpt_path = r"/home/star/whaiDir/PFCFuse/models/PFCFusion10-05-20-46.pth"
+ckpt_path = r"/home/star/whaiDir/PFCFuse/models/whaiFusion11-11-20-36.pth"

-for dataset_name in ["TNO"]:
+for dataset_name in ["sar"]:
    print("\n" * 2 + "=" * 80)
    model_name = "PFCFuse "
    print("The test result of " + dataset_name + ' :')
-   test_folder = os.path.join('/home/star/whaiDir/CDDFuse/test_img/', dataset_name)
-   test_out_folder = os.path.join('test_result', dataset_name)
+   test_folder = os.path.join('test_img', dataset_name)
+   test_out_folder = os.path.join('test_result', current_time, dataset_name)

    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    Encoder = nn.DataParallel(Restormer_Encoder()).to(device)
@@ -39,7 +43,6 @@ for dataset_name in ["TNO"]:
    with torch.no_grad():
        for img_name in os.listdir(os.path.join(test_folder, "ir")):
-           print(img_name)
            data_IR = image_read_cv2(os.path.join(test_folder, "ir", img_name), mode='GRAY')[np.newaxis, np.newaxis, ...] / 255.0
            data_VIS = cv2.split(image_read_cv2(os.path.join(test_folder, "vi", img_name), mode='YCrCb'))[0][np.newaxis, np.newaxis, ...] / 255.0
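The script reads the visible image in YCrCb and fuses only the luminance channel, the usual convention for color-preserving fusion; a standalone sketch of that convention in plain OpenCV (file names are hypothetical, and the identity step stands in for the network's fused luminance):

```
import cv2

vis_bgr = cv2.imread("vis.png")  # hypothetical input path
y, cr, cb = cv2.split(cv2.cvtColor(vis_bgr, cv2.COLOR_BGR2YCrCb))
y_fused = y  # the network's fused Y channel would replace this
out_bgr = cv2.cvtColor(cv2.merge([y_fused, cr, cb]), cv2.COLOR_YCrCb2BGR)
cv2.imwrite("fused.png", out_bgr)
```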

BIN  test_img/MRI_CT/CT/11.png through 31.png    (21 new files, 33 to 44 KiB each)
BIN  test_img/MRI_CT/MRI/11.png through 31.png   (21 new files, 46 to 66 KiB each)
BIN  test_img/MRI_PET/MRI/11.png through 31.png  (21 new files, 21 to 43 KiB each)

Some files were not shown because too many files have changed in this diff.