caixiaoshun committed on
Commit
184ac8d
·
1 Parent(s): 55edecd

First commit

Files changed (31)
  1. app.py +141 -0
  2. examples_png/barren_11.png +0 -0
  3. examples_png/snow_10.png +0 -0
  4. examples_png/vegetation_21.png +0 -0
  5. examples_png/water_22.png +0 -0
  6. logs/train/runs/hrcwhu_cdnetv1/2024-08-03_18-18-09/checkpoints/epoch_063.ckpt +3 -0
  7. logs/train/runs/hrcwhu_cdnetv1/2024-08-03_18-18-09/checkpoints/last.ckpt +3 -0
  8. logs/train/runs/hrcwhu_cdnetv2/2024-08-03_18-26-19/checkpoints/epoch_150.ckpt +3 -0
  9. logs/train/runs/hrcwhu_cdnetv2/2024-08-03_18-26-19/checkpoints/last.ckpt +3 -0
  10. logs/train/runs/hrcwhu_dual_branch/2024-08-03_18-40-06/checkpoints/epoch_009.ckpt +3 -0
  11. logs/train/runs/hrcwhu_dual_branch/2024-08-03_18-40-06/checkpoints/last.ckpt +3 -0
  12. logs/train/runs/hrcwhu_hrcloud/2024-08-03_18-55-52/checkpoints/epoch_024.ckpt +3 -0
  13. logs/train/runs/hrcwhu_hrcloud/2024-08-03_18-55-52/checkpoints/last.ckpt +3 -0
  14. logs/train/runs/hrcwhu_mcdnet/2024-08-04_14-30-56/checkpoints/epoch_032.ckpt +3 -0
  15. logs/train/runs/hrcwhu_mcdnet/2024-08-04_14-30-56/checkpoints/last.ckpt +3 -0
  16. logs/train/runs/hrcwhu_scnn/2024-08-03_19-29-26/checkpoints/epoch_196.ckpt +3 -0
  17. logs/train/runs/hrcwhu_scnn/2024-08-03_19-29-26/checkpoints/last.ckpt +3 -0
  18. requirements.txt +289 -0
  19. src/__init__.py +0 -0
  20. src/models/__init__.py +0 -0
  21. src/models/components/__init__.py +0 -0
  22. src/models/components/cdnetv1.py +396 -0
  23. src/models/components/cdnetv2.py +699 -0
  24. src/models/components/cnn.py +26 -0
  25. src/models/components/dual_branch.py +680 -0
  26. src/models/components/hrcloud.py +751 -0
  27. src/models/components/lnn.py +23 -0
  28. src/models/components/mcdnet.py +448 -0
  29. src/models/components/scnn.py +36 -0
  30. src/models/components/unet.py +63 -0
  31. src/models/components/vae.py +152 -0
app.py ADDED
@@ -0,0 +1,141 @@
+ # -*- coding: utf-8 -*-
+ # @Time : 2024/8/4 2:38 PM
+ # @Author : xiaoshun
+ # @Email : 3038523973@qq.com
+ # @File : app.py
+ # @Software: PyCharm
+
+ from glob import glob
+
+ import albumentations as albu
+ import gradio as gr
+ import numpy as np
+ import torch
+ from albumentations.pytorch.transforms import ToTensorV2
+ from PIL import Image
+
+ from src.models.components.cdnetv1 import CDnetV1
+ from src.models.components.cdnetv2 import CDnetV2
+ from src.models.components.dual_branch import Dual_Branch
+ from src.models.components.hrcloud import HRcloudNet
+ from src.models.components.mcdnet import MCDNet
+ from src.models.components.scnn import SCNNNet
+
+
+ class Application:
+     def __init__(self):
+         self.device = "cuda" if torch.cuda.is_available() else "cpu"
+         self.models = {
+             "cdnetv1": CDnetV1(num_classes=2).to(self.device),
+             "cdnetv2": CDnetV2(num_classes=2).to(self.device),
+             "hrcloud": HRcloudNet(num_classes=2).to(self.device),
+             "mcdnet": MCDNet(in_channels=3, num_classes=2).to(self.device),
+             "scnn": SCNNNet(num_classes=2).to(self.device),
+             "dual_branch": Dual_Branch(img_size=256, in_channels=3, num_classes=2).to(
+                 self.device
+             ),
+         }
+         self.__load_weight()
+         self.transform = albu.Compose(
+             [
+                 albu.Resize(256, 256, always_apply=True),
+                 ToTensorV2(),
+             ]
+         )
+
+     def __load_weight(self):
+         """
+         Load the trained checkpoint weights into each model.
+         """
+         for model_name, model in self.models.items():
+             weight_path = glob(
+                 f"logs/train/runs/*{model_name}*/*/checkpoints/*epoch*.ckpt"
+             )[0]
+             weight = torch.load(weight_path, map_location=self.device)
+             state_dict = {}
+             for key, value in weight["state_dict"].items():
+                 # Strip the 4-character prefix (e.g. "net.") that the Lightning wrapper adds.
+                 new_key = key[4:]
+                 state_dict[new_key] = value
+             model.load_state_dict(state_dict)
+             model.eval()
+             print(f"{model_name} weight loaded!")
+
+     @torch.no_grad()
+     def inference(self, image: torch.Tensor, model_name: str):
+         x = image.float()
+         x = x.unsqueeze(0)
+         x = x.to(self.device)
+         logits = self.models[model_name](x)
+         if isinstance(logits, tuple):
+             logits = logits[0]
+         fake_mask = torch.argmax(logits, 1).detach().cpu().squeeze(0).numpy()
+         return fake_mask
+
+     def give_colors_to_mask(self, mask: np.ndarray):
+         """
+         Colorize the predicted mask.
+         """
+         assert len(mask.shape) == 2, "The mask must have shape (height, width)"
+         colors_mask = np.zeros((mask.shape[0], mask.shape[1], 3)).astype(np.float32)
+         colors = ((255, 255, 255), (128, 192, 128))
+         for color in range(2):
+             segc = mask == color
+             colors_mask[:, :, 0] += segc * (colors[color][0])
+             colors_mask[:, :, 1] += segc * (colors[color][1])
+             colors_mask[:, :, 2] += segc * (colors[color][2])
+         return colors_mask
+
+     def to_pil(self, image: np.ndarray, width=None, height=None):
+         colors_np = self.give_colors_to_mask(image)
+         pil_np = Image.fromarray(np.uint8(colors_np))
+         if width and height:
+             pil_np = pil_np.resize((width, height))
+         return pil_np
+
+     def flip(self, image_pil: Image.Image, model_name: str):
+         if image_pil is None:
+             return Image.fromarray(np.uint8(np.random.random((32, 32, 3)) * 255)), "Please upload an image"
+         if model_name is None:
+             return Image.fromarray(np.uint8(np.random.random((32, 32, 3)) * 255)), "Please select a model"
+         image = np.array(image_pil)
+         raw_height, raw_width = image.shape[0], image.shape[1]
+         transform = self.transform(image=image)
+         image = transform["image"]
+         image = image / 255.0
+         fake_image = self.inference(image, model_name)
+         fake_image = self.to_pil(fake_image, raw_width, raw_height)
+         return fake_image, "success"
+
+     def tiff_to_png(self, image: Image.Image):
+         if image.format == "TIFF":
+             image = image.convert("RGB")
+         return np.array(image)
+
+     def run(self):
+         app = gr.Interface(
+             self.flip,
+             [
+                 gr.Image(sources=["clipboard", "upload"], type="pil"),
+                 gr.Radio(
+                     ["cdnetv1", "cdnetv2", "hrcloud", "mcdnet", "scnn", "dual_branch"],
+                     label="model_name",
+                     info="Select the model to use",
+                 ),
+             ],
+             [gr.Image(), gr.Textbox(label="Status")],
+             examples=[
+                 ["examples_png/barren_11.png", "dual_branch"],
+                 ["examples_png/snow_10.png", "scnn"],
+                 ["examples_png/vegetation_21.png", "cdnetv2"],
+                 ["examples_png/water_22.png", "hrcloud"],
+             ],
+             title="Cloud Detection Models Online Demo",
+             submit_btn=gr.Button("Submit", variant="primary"),
+         )
+         app.launch(share=True)
+
+
+ if __name__ == "__main__":
+     app = Application()
+     app.run()
examples_png/barren_11.png ADDED
examples_png/snow_10.png ADDED
examples_png/vegetation_21.png ADDED
examples_png/water_22.png ADDED
logs/train/runs/hrcwhu_cdnetv1/2024-08-03_18-18-09/checkpoints/epoch_063.ckpt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:289e19ba0c05d81e4c350cef50e189e27078cbf291c818c4c700147f6dd04eb5
+ size 186417158
logs/train/runs/hrcwhu_cdnetv1/2024-08-03_18-18-09/checkpoints/last.ckpt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:fcede25335ecffb78a0eb1ebd399d7c3d6dba361f5a7f6a888d260a830c3b1fc
+ size 186417158
logs/train/runs/hrcwhu_cdnetv2/2024-08-03_18-26-19/checkpoints/epoch_150.ckpt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:15d74b3d0f92edfba47077984f125a15b6e576b707913899a497ccdc7b5bc3b6
+ size 271036010
logs/train/runs/hrcwhu_cdnetv2/2024-08-03_18-26-19/checkpoints/last.ckpt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:7e36ddc6ced1b7092459cc2b2a9563b7e75199ca7436ba447b677537716d9b2e
+ size 271036010
logs/train/runs/hrcwhu_dual_branch/2024-08-03_18-40-06/checkpoints/epoch_009.ckpt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:36e0e174e7e61a91b895cb81cdd9e584ee1a8c5ce9536ef6d20508c8cae8b05a
+ size 1143163971
logs/train/runs/hrcwhu_dual_branch/2024-08-03_18-40-06/checkpoints/last.ckpt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:05c1245fd9693a1eafdadd41aa7d60dcffef9b5dcbec5aceae06ebc66a62bf86
+ size 1143164035
logs/train/runs/hrcwhu_hrcloud/2024-08-03_18-55-52/checkpoints/epoch_024.ckpt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:3554e396cddb6e56b89062d84c1a6538d980cdc3adb9e27e53bc0d241a8cffbf
+ size 893819887
logs/train/runs/hrcwhu_hrcloud/2024-08-03_18-55-52/checkpoints/last.ckpt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:ef0c18c6eef14e194ec38fd287f07a8532f753130289a729c9101ad28f76d8e9
+ size 893819887
logs/train/runs/hrcwhu_mcdnet/2024-08-04_14-30-56/checkpoints/epoch_032.ckpt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:1422b109a63c4a3f612539c6b6b04fddfe151f7f35f756dc406319b523179d68
+ size 52581962
logs/train/runs/hrcwhu_mcdnet/2024-08-04_14-30-56/checkpoints/last.ckpt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:8261111d7ef82879c3c8e7dd241103340fbc85a75c1924b9b350d5d0c7dd6751
+ size 52581962
logs/train/runs/hrcwhu_scnn/2024-08-03_19-29-26/checkpoints/epoch_196.ckpt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:9769c01aeeb86f0b03b45ca2b3730f1798be0a72d83f3b5b8e6fefa81018ef63
+ size 26732
logs/train/runs/hrcwhu_scnn/2024-08-03_19-29-26/checkpoints/last.ckpt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:d54c8f0700e6e1fe8b8f83acc7178dac7f17445a7fed79e084e50a036a50f3cf
+ size 26732
requirements.txt ADDED
@@ -0,0 +1,289 @@
+ absl-py==2.1.0
+ addict==2.4.0
+ aiofiles==23.2.1
+ aiohttp==3.9.5
+ aiosignal==1.3.1
+ albucore==0.0.12
+ albumentations==1.4.12
+ alembic==1.13.2
+ aliyun-python-sdk-core==2.15.1
+ aliyun-python-sdk-kms==2.16.3
+ annotated-types==0.7.0
+ antlr4-python3-runtime==4.9.3
+ anyio==4.4.0
+ argon2-cffi==23.1.0
+ argon2-cffi-bindings==21.2.0
+ arrow==1.3.0
+ asttokens==2.4.1
+ astunparse==1.6.3
+ async-lru==2.0.4
+ async-timeout==4.0.3
+ attrs==23.2.0
+ autopage==0.5.2
+ Babel==2.15.0
+ beautifulsoup4==4.12.3
+ bleach==6.1.0
+ blessed==1.20.0
+ Brotli @ file:///croot/brotli-split_1714483155106/work
+ certifi @ file:///croot/certifi_1717618050233/work/certifi
+ cffi==1.16.0
+ cfgv==3.4.0
+ chardet==5.2.0
+ charset-normalizer @ file:///tmp/build/80754af9/charset-normalizer_1630003229654/work
+ click==8.1.7
+ cliff==4.7.0
+ cmaes==0.10.0
+ cmd2==2.4.3
+ colorama==0.4.6
+ colorlog==6.8.2
+ comm==0.2.2
+ contourpy==1.2.1
+ crcmod==1.7
+ cryptography==42.0.8
+ cycler==0.12.1
+ debugpy==1.8.1
+ decorator==5.1.1
+ defusedxml==0.7.1
+ dill==0.3.8
+ distlib==0.3.8
+ docker-pycreds==0.4.0
+ efficientnet_pytorch==0.7.1
+ einops @ file:///home/conda/feedstock_root/build_artifacts/einops_1714285159399/work
+ eval_type_backport==0.2.0
+ exceptiongroup==1.2.1
+ executing==2.0.1
+ fastapi==0.112.0
+ fastjsonschema==2.20.0
+ ffmpy==0.4.0
+ filelock==3.14.0
+ flatbuffers==24.3.25
+ fonttools==4.53.0
+ fqdn==1.5.1
+ frozenlist==1.4.1
+ fsspec==2024.5.0
+ ftfy==6.2.0
+ future==1.0.0
+ gast==0.6.0
+ gitdb==4.0.11
+ GitPython==3.1.43
+ gmpy2 @ file:///tmp/build/80754af9/gmpy2_1645455533097/work
+ google-pasta==0.2.0
+ gpustat==1.1.1
+ gradio==4.40.0
+ gradio_client==1.2.0
+ greenlet==3.0.3
+ grpcio==1.64.1
+ h11==0.14.0
+ h5py==3.11.0
+ hf_transfer==0.1.8
+ httpcore==1.0.5
+ httpx==0.27.0
+ huggingface-hub==0.23.4
+ hydra-colorlog==1.2.0
+ hydra-core==1.3.2
+ hydra-optuna-sweeper==1.2.0
+ identify==2.6.0
+ idna @ file:///croot/idna_1714398848350/work
+ image-dehazer==0.0.9
+ imageio==2.34.2
+ imgaug==0.4.0
+ importlib_metadata==8.0.0
+ importlib_resources==6.4.0
+ iniconfig==2.0.0
+ ipykernel==6.29.4
+ ipython==8.25.0
+ isoduration==20.11.0
+ jedi==0.19.1
+ Jinja2 @ file:///croot/jinja2_1716993405101/work
+ jmespath==0.10.0
+ joblib==1.4.2
+ json5==0.9.25
+ jsonpointer==3.0.0
+ jsonschema==4.22.0
+ jsonschema-specifications==2023.12.1
+ jupyter-events==0.10.0
+ jupyter-lsp==2.2.5
+ jupyter_client==8.6.2
+ jupyter_core==5.7.2
+ jupyter_server==2.14.1
+ jupyter_server_terminals==0.5.3
+ jupyterlab==4.2.2
+ jupyterlab_pygments==0.3.0
+ jupyterlab_server==2.27.2
+ keras==3.4.1
+ kiwisolver==1.4.5
+ kornia==0.7.3
+ kornia_rs==0.1.5
+ lazy_loader==0.4
+ libclang==18.1.1
+ lightning==2.3.3
+ lightning-utilities==0.11.3.post0
+ Mako==1.3.5
+ mamba-ssm==2.2.2
+ Markdown==3.6
+ markdown-it-py==3.0.0
+ MarkupSafe @ file:///croot/markupsafe_1704205993651/work
+ matplotlib==3.9.0
+ matplotlib-inline==0.1.7
+ mdurl==0.1.2
+ mistune==3.0.2
+ mkl-fft @ file:///croot/mkl_fft_1695058164594/work
+ mkl-random @ file:///croot/mkl_random_1695059800811/work
+ mkl-service==2.4.0
+ ml-dtypes==0.4.0
+ mmcv==2.1.0
+ mmengine==0.10.4
+ mmsegmentation==1.2.2
+ model-index==0.1.11
+ mpmath @ file:///croot/mpmath_1690848262763/work
+ multidict==6.0.5
+ multiprocess==0.70.16
+ munch==4.0.0
+ namex==0.0.8
+ natsort==8.4.0
+ nbclient==0.10.0
+ nbconvert==7.16.4
+ nbformat==5.10.4
+ nest-asyncio==1.6.0
+ networkx @ file:///croot/networkx_1717597493534/work
+ ninja==1.11.1.1
+ nodeenv==1.9.1
+ notebook_shim==0.2.4
+ numpy @ file:///croot/numpy_and_numpy_base_1708638617955/work/dist/numpy-1.26.4-cp310-cp310-linux_x86_64.whl#sha256=d8cd837ed43e87f77e6efaa08e8de927ca030a1c9c5d04624432d6fb9a74a5ee
+ nvidia-ml-py==12.555.43
+ omegaconf==2.3.0
+ opencv-python==4.10.0.84
+ opencv-python-headless==4.10.0.84
+ opendatalab==0.0.10
+ openmim==0.3.9
+ openxlab==0.1.1
+ opt-einsum==3.3.0
+ optree==0.12.1
+ optuna==2.10.1
+ ordered-set==4.1.0
+ orjson==3.10.6
+ oss2==2.17.0
+ overrides==7.7.0
+ packaging==24.1
+ pandas==2.2.2
+ pandocfilters==1.5.1
+ parso==0.8.4
+ patsy==0.5.6
+ pbr==6.0.0
+ pexpect==4.9.0
+ pillow @ file:///croot/pillow_1714398848491/work
+ platformdirs==4.2.2
+ plotly==5.22.0
+ pluggy==1.5.0
+ pre-commit==3.7.1
+ pretrainedmodels==0.7.4
+ prettytable==3.10.0
+ prometheus_client==0.20.0
+ prompt_toolkit==3.0.47
+ protobuf==4.25.3
+ psutil==6.0.0
+ ptyprocess==0.7.0
+ pure-eval==0.2.2
+ pyarrow==16.1.0
+ pyarrow-hotfix==0.6
+ pycparser==2.22
+ pycryptodome==3.20.0
+ pydantic==2.8.0
+ pydantic_core==2.20.0
+ pydub==0.25.1
+ Pygments==2.18.0
+ pyparsing==3.1.2
+ pyperclip==1.9.0
+ pyproject==1.3.1
+ PySocks @ file:///home/builder/ci_310/pysocks_1640793678128/work
+ pytest==8.2.2
+ python-dateutil==2.9.0.post0
+ python-dotenv==1.0.1
+ python-json-logger==2.0.7
+ python-multipart==0.0.9
+ pytorch-lightning==2.3.3
+ pytz==2023.4
+ pyvips==2.2.3
+ PyYAML @ file:///croot/pyyaml_1698096049011/work
+ pyzmq==26.0.3
+ referencing==0.35.1
+ regex==2024.5.15
+ requests==2.32.3
+ rfc3339-validator==0.1.4
+ rfc3986-validator==0.1.1
+ rich==13.4.2
+ rootutils==1.0.7
+ rpds-py==0.18.1
+ ruff==0.5.6
+ safetensors==0.4.3
+ scikit-image==0.24.0
+ scikit-learn==1.5.0
+ scipy==1.13.1
+ seaborn==0.13.2
+ segmentation-models-pytorch==0.3.3
+ semantic-version==2.10.0
+ Send2Trash==1.8.3
+ sentry-sdk==2.9.0
+ setproctitle==1.3.3
+ shapely==2.0.4
+ shellingham==1.5.4
+ six==1.16.0
+ smmap==5.0.1
+ sniffio==1.3.1
+ soupsieve==2.5
+ SQLAlchemy==2.0.31
+ stack-data==0.6.3
+ starlette==0.37.2
+ statsmodels==0.14.2
+ stevedore==5.2.0
+ sympy @ file:///croot/sympy_1701397643339/work
+ tabulate==0.9.0
+ tenacity==8.4.1
+ tensorboard==2.17.0
+ tensorboard-data-server==0.7.2
+ tensorboardX==2.6.2.2
+ tensorflow==2.17.0
+ tensorflow-io-gcs-filesystem==0.37.1
+ termcolor==2.4.0
+ terminado==0.18.1
+ thop==0.1.1.post2209072238
+ threadpoolctl==3.5.0
+ tifffile==2024.6.18
+ timm==0.9.2
+ tinycss2==1.3.0
+ tokenizers==0.19.1
+ tomli==2.0.1
+ tomlkit==0.12.0
+ torch==2.3.1
+ torchaudio==2.3.1
+ torchinfo==1.8.0
+ torchmetrics==1.4.0.post0
+ torchsummary==1.5.1
+ torchvision==0.18.1
+ tornado==6.4.1
+ tqdm==4.65.2
+ traitlets==5.14.3
+ transformers==4.42.4
+ triton==2.3.1
+ typer==0.12.3
+ types-python-dateutil==2.9.0.20240316
+ typing_extensions @ file:///croot/typing_extensions_1715268824938/work
+ tzdata==2024.1
+ uri-template==1.3.0
+ urllib3==2.2.2
+ uvicorn==0.30.5
+ virtualenv==20.26.3
+ wandb==0.17.4
+ wcwidth==0.2.13
+ webcolors==24.6.0
+ webencodings==0.5.1
+ websocket-client==1.8.0
+ websockets==12.0
+ Werkzeug==3.0.3
+ wrapt==1.16.0
+ xlrd==2.0.1
+ xxhash==3.4.1
+ yacs==0.1.8
+ yapf==0.40.2
+ yarl==1.9.4
+ zipp==3.19.2
src/__init__.py ADDED
File without changes
src/models/__init__.py ADDED
File without changes
src/models/components/__init__.py ADDED
File without changes
src/models/components/cdnetv1.py ADDED
@@ -0,0 +1,396 @@
+ # -*- coding: utf-8 -*-
+ # @Time : 2024/7/24 11:36 AM
+ # @Author : xiaoshun
+ # @Email : 3038523973@qq.com
+ # @File : cdnetv1.py
+ # @Software: PyCharm
+
+ """Cloud detection Network"""
+
+ """
+ This is the implementation of CDnetV1 without multi-scale inputs. This implementation uses ResNet by default.
+ """
+
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+
+ affine_par = True
+
+
+ def conv3x3(in_planes, out_planes, stride=1):
+     """3x3 convolution with padding"""
+     return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
+                      padding=1, bias=False)
+
+
+ class BasicBlock(nn.Module):
+     expansion = 1
+
+     def __init__(self, inplanes, planes, stride=1, downsample=None):
+         super(BasicBlock, self).__init__()
+         self.conv1 = conv3x3(inplanes, planes, stride)
+         self.bn1 = nn.BatchNorm2d(planes, affine=affine_par)
+         self.relu = nn.ReLU(inplace=True)
+         self.conv2 = conv3x3(planes, planes)
+         self.bn2 = nn.BatchNorm2d(planes, affine=affine_par)
+         self.downsample = downsample
+         self.stride = stride
+
+     def forward(self, x):
+         residual = x
+
+         out = self.conv1(x)
+         out = self.bn1(out)
+         out = self.relu(out)
+
+         out = self.conv2(out)
+         out = self.bn2(out)
+
+         if self.downsample is not None:
+             residual = self.downsample(x)
+
+         out += residual
+         out = self.relu(out)
+
+         return out
+
+
+ class Bottleneck(nn.Module):
+     expansion = 4
+
+     def __init__(self, inplanes, planes, stride=1, dilation=1, downsample=None):
+         super(Bottleneck, self).__init__()
+         self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, stride=stride, bias=False)  # change
+         self.bn1 = nn.BatchNorm2d(planes, affine=affine_par)
+         for i in self.bn1.parameters():
+             i.requires_grad = False
+
+         padding = dilation
+         self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1,  # change
+                                padding=padding, bias=False, dilation=dilation)
+         self.bn2 = nn.BatchNorm2d(planes, affine=affine_par)
+         for i in self.bn2.parameters():
+             i.requires_grad = False
+         self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False)
+         self.bn3 = nn.BatchNorm2d(planes * 4, affine=affine_par)
+         for i in self.bn3.parameters():
+             i.requires_grad = False
+         self.relu = nn.ReLU(inplace=True)
+         self.downsample = downsample
+         self.stride = stride
+
+     def forward(self, x):
+         residual = x
+
+         out = self.conv1(x)
+         out = self.bn1(out)
+         out = self.relu(out)
+
+         out = self.conv2(out)
+         out = self.bn2(out)
+         out = self.relu(out)
+
+         out = self.conv3(out)
+         out = self.bn3(out)
+
+         if self.downsample is not None:
+             residual = self.downsample(x)
+
+         out += residual
+         out = self.relu(out)
+
+         return out
+
+
+ class Classifier_Module(nn.Module):
+
+     def __init__(self, dilation_series, padding_series, num_classes):
+         super(Classifier_Module, self).__init__()
+         self.conv2d_list = nn.ModuleList()
+         for dilation, padding in zip(dilation_series, padding_series):
+             self.conv2d_list.append(
+                 nn.Conv2d(2048, num_classes, kernel_size=3, stride=1, padding=padding, dilation=dilation, bias=True))
+
+         for m in self.conv2d_list:
+             m.weight.data.normal_(0, 0.01)
+
+     def forward(self, x):
+         out = self.conv2d_list[0](x)
+         for i in range(len(self.conv2d_list) - 1):
+             out += self.conv2d_list[i + 1](x)
+         return out
+
+
+ class _ConvBNReLU(nn.Module):
+     def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0,
+                  dilation=1, groups=1, norm_layer=nn.BatchNorm2d):
+         super(_ConvBNReLU, self).__init__()
+         self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, dilation, groups, bias=False)
+         self.bn = norm_layer(out_channels)
+         self.relu = nn.ReLU(True)
+
+     def forward(self, x):
+         x = self.conv(x)
+         x = self.bn(x)
+         x = self.relu(x)
+         return x
+
+
+ class _ASPPConv(nn.Module):
+     def __init__(self, in_channels, out_channels, atrous_rate, norm_layer):
+         super(_ASPPConv, self).__init__()
+         self.block = nn.Sequential(
+             nn.Conv2d(in_channels, out_channels, 3, padding=atrous_rate, dilation=atrous_rate, bias=False),
+             norm_layer(out_channels),
+             nn.ReLU(True)
+         )
+
+     def forward(self, x):
+         return self.block(x)
+
+
+ class _AsppPooling(nn.Module):
+     def __init__(self, in_channels, out_channels, norm_layer):
+         super(_AsppPooling, self).__init__()
+         self.gap = nn.Sequential(
+             nn.AdaptiveAvgPool2d(1),
+             nn.Conv2d(in_channels, out_channels, 1, bias=False),
+             norm_layer(out_channels),
+             nn.ReLU(True)
+         )
+
+     def forward(self, x):
+         size = x.size()[2:]
+         pool = self.gap(x)
+         out = F.interpolate(pool, size, mode='bilinear', align_corners=True)
+         return out
+
+
+ class _ASPP(nn.Module):
+     def __init__(self, in_channels, atrous_rates, norm_layer):
+         super(_ASPP, self).__init__()
+         out_channels = 512  # changed from 256
+         self.b0 = nn.Sequential(
+             nn.Conv2d(in_channels, out_channels, 1, bias=False),
+             norm_layer(out_channels),
+             nn.ReLU(True)
+         )
+
+         rate1, rate2, rate3 = tuple(atrous_rates)
+         self.b1 = _ASPPConv(in_channels, out_channels, rate1, norm_layer)
+         self.b2 = _ASPPConv(in_channels, out_channels, rate2, norm_layer)
+         self.b3 = _ASPPConv(in_channels, out_channels, rate3, norm_layer)
+         self.b4 = _AsppPooling(in_channels, out_channels, norm_layer=norm_layer)
+
+         # self.project = nn.Sequential(
+         #     nn.Conv2d(5 * out_channels, out_channels, 1, bias=False),
+         #     norm_layer(out_channels),
+         #     nn.ReLU(True),
+         #     nn.Dropout(0.5))
+         self.dropout2d = nn.Dropout2d(0.3)
+
+     def forward(self, x):
+         feat1 = self.dropout2d(self.b0(x))
+         feat2 = self.dropout2d(self.b1(x))
+         feat3 = self.dropout2d(self.b2(x))
+         feat4 = self.dropout2d(self.b3(x))
+         feat5 = self.dropout2d(self.b4(x))
+         x = torch.cat((feat1, feat2, feat3, feat4, feat5), dim=1)
+         # x = self.project(x)
+         return x
+
+
+ class _FPM(nn.Module):
+     def __init__(self, in_channels, num_classes, norm_layer=nn.BatchNorm2d):
+         super(_FPM, self).__init__()
+         self.aspp = _ASPP(in_channels, [6, 12, 18], norm_layer=norm_layer)
+         # self.dropout2d = nn.Dropout2d(0.5)
+
+     def forward(self, x):
+         x = torch.cat((x, self.aspp(x)), dim=1)
+         # x = self.dropout2d(x)  # added
+         return x
+
+
+ class BR(nn.Module):
+     def __init__(self, num_classes, stride=1, downsample=None):
+         super(BR, self).__init__()
+         self.conv1 = conv3x3(num_classes, num_classes * 16, stride)
+         self.relu = nn.ReLU(inplace=True)
+         self.conv2 = conv3x3(num_classes * 16, num_classes)
+         self.stride = stride
+
+     def forward(self, x):
+         residual = x
+
+         out = self.conv1(x)
+         out = self.relu(out)
+
+         out = self.conv2(out)
+         out += residual
+
+         return out
+
+
+ class CDnetV1(nn.Module):
+     def __init__(self, block=Bottleneck, layers=[3, 4, 6, 3], num_classes=21, aux=True):
+         self.inplanes = 64
+         self.aux = aux
+         super().__init__()
+         # self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
+         # self.bn1 = nn.BatchNorm2d(64, affine=affine_par)
+
+         self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=2, padding=1, bias=False)
+         self.bn1 = nn.BatchNorm2d(64, affine=affine_par)
+         self.conv2 = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1, bias=False)
+         self.bn2 = nn.BatchNorm2d(64, affine=affine_par)
+         self.conv3 = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1, bias=False)
+         self.bn3 = nn.BatchNorm2d(64, affine=affine_par)
+
+         for i in self.bn1.parameters():
+             i.requires_grad = False
+         self.relu = nn.ReLU(inplace=True)
+         self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1, ceil_mode=True)  # change
+         self.layer1 = self._make_layer(block, 64, layers[0])
+         self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
+         self.layer3 = self._make_layer(block, 256, layers[2], stride=1, dilation=2)
+         self.layer4 = self._make_layer(block, 512, layers[3], stride=1, dilation=4)
+         # self.layer5 = self._make_pred_layer(Classifier_Module, [6, 12, 18, 24], [6, 12, 18, 24], num_classes)
+
+         self.res5_con1x1 = nn.Sequential(
+             nn.Conv2d(1024 + 2048, 512, kernel_size=1, stride=1, padding=0),
+             nn.BatchNorm2d(512),
+             nn.ReLU(True)
+         )
+
+         self.fpm1 = _FPM(512, num_classes)
+         self.fpm2 = _FPM(512, num_classes)
+         self.fpm3 = _FPM(256, num_classes)
+
+         self.br1 = BR(num_classes)
+         self.br2 = BR(num_classes)
+         self.br3 = BR(num_classes)
+         self.br4 = BR(num_classes)
+         self.br5 = BR(num_classes)
+         self.br6 = BR(num_classes)
+         self.br7 = BR(num_classes)
+
+         self.predict1 = self._predict_layer(512 * 6, num_classes)
+         self.predict2 = self._predict_layer(512 * 6, num_classes)
+         self.predict3 = self._predict_layer(512 * 5 + 256, num_classes)
+
+         for m in self.modules():
+             if isinstance(m, nn.Conv2d):
+                 m.weight.data.normal_(0, 0.01)
+             elif isinstance(m, nn.BatchNorm2d):
+                 m.weight.data.fill_(1)
+                 m.bias.data.zero_()
+                 # for i in m.parameters():
+                 #     i.requires_grad = False
+
+     def _predict_layer(self, in_channels, num_classes):
+         return nn.Sequential(nn.Conv2d(in_channels, 256, kernel_size=1, stride=1, padding=0),
+                              nn.BatchNorm2d(256),
+                              nn.ReLU(True),
+                              nn.Dropout2d(0.1),
+                              nn.Conv2d(256, num_classes, kernel_size=3, stride=1, padding=1, bias=True))
+
+     def _make_layer(self, block, planes, blocks, stride=1, dilation=1):
+         downsample = None
+         if stride != 1 or self.inplanes != planes * block.expansion or dilation == 2 or dilation == 4:
+             downsample = nn.Sequential(
+                 nn.Conv2d(self.inplanes, planes * block.expansion,
+                           kernel_size=1, stride=stride, bias=False),
+                 nn.BatchNorm2d(planes * block.expansion, affine=affine_par))
+             for i in downsample._modules['1'].parameters():
+                 i.requires_grad = False
+         layers = []
+         layers.append(block(self.inplanes, planes, stride, dilation=dilation, downsample=downsample))
+         self.inplanes = planes * block.expansion
+         for i in range(1, blocks):
+             layers.append(block(self.inplanes, planes, dilation=dilation))
+
+         return nn.Sequential(*layers)
+
+     # def _make_pred_layer(self, block, dilation_series, padding_series, num_classes):
+     #     return block(dilation_series, padding_series, num_classes)
+
+     def base_forward(self, x):
+         x = self.relu(self.bn1(self.conv1(x)))
+         size_conv1 = x.size()[2:]
+         x = self.relu(self.bn2(self.conv2(x)))
+         x = self.relu(self.bn3(self.conv3(x)))
+         x = self.maxpool(x)
+         x = self.layer1(x)
+         res2 = x
+         x = self.layer2(x)
+         res3 = x
+         x = self.layer3(x)
+         res4 = x
+         x = self.layer4(x)
+         x = self.res5_con1x1(torch.cat([x, res4], dim=1))
+
+         return x, res3, res2, size_conv1
+
+     def forward(self, x):
+         size = x.size()[2:]
+         score1, score2, score3, size_conv1 = self.base_forward(x)
+         score1 = self.fpm1(score1)
+         score1 = self.predict1(score1)  # 1/8
+         predict1 = score1
+         score1 = self.br1(score1)
+
+         score2 = self.fpm2(score2)
+         score2 = self.predict2(score2)  # 1/8
+         predict2 = score2
+
+         # first fusion
+         score2 = self.br2(score2) + score1
+         score2 = self.br3(score2)
+
+         score3 = self.fpm3(score3)
+         score3 = self.predict3(score3)  # 1/4
+         predict3 = score3
+         score3 = self.br4(score3)
+
+         # second fusion
+         size_score3 = score3.size()[2:]
+         score3 = score3 + F.interpolate(score2, size_score3, mode='bilinear', align_corners=True)
+         score3 = self.br5(score3)
+
+         # upsampling + BR
+         score3 = F.interpolate(score3, size_conv1, mode='bilinear', align_corners=True)
+         score3 = self.br6(score3)
+         score3 = F.interpolate(score3, size, mode='bilinear', align_corners=True)
+         score3 = self.br7(score3)
+
+         return score3
+         # return score3, predict1, predict2, predict3
+
+
+ if __name__ == '__main__':
+     model = CDnetV1(num_classes=21)
+     fake_image = torch.randn(2, 3, 224, 224)
+     output = model(fake_image)
+     print(output.shape)
+     # torch.Size([2, 21, 224, 224])
src/models/components/cdnetv2.py ADDED
@@ -0,0 +1,699 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ # @Time : 2024/7/24 下午3:41
3
+ # @Author : xiaoshun
4
+ # @Email : 3038523973@qq.com
5
+ # @File : cdnetv2.py
6
+ # @Software: PyCharm
7
+
8
+ """Cloud detection Network"""
9
+
10
+ """
11
+ This is the implementation of CDnetV2 without multi-scale inputs. This implementation uses ResNet by default.
12
+ """
13
+ # nn.GroupNorm
14
+
15
+ import torch
16
+ from torch import nn
17
+ # import torch.nn as nn
18
+ import torch.optim as optim
19
+ import torch.nn.functional as F
20
+ import torch.backends.cudnn as cudnn
21
+ from torch.utils import data, model_zoo
22
+ from torch.autograd import Variable
23
+ import math
24
+ import numpy as np
25
+
26
+ affine_par = True
27
+ from torch.autograd import Function
28
+
29
+
30
+ def conv3x3(in_planes, out_planes, stride=1):
31
+ "3x3 convolution with padding"
32
+ return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
33
+ padding=1, bias=False)
34
+
35
+
36
+ class BasicBlock(nn.Module):
37
+ expansion = 1
38
+
39
+ def __init__(self, inplanes, planes, stride=1, downsample=None):
40
+ super(BasicBlock, self).__init__()
41
+ self.conv1 = conv3x3(inplanes, planes, stride)
42
+ self.bn1 = nn.BatchNorm2d(planes, affine=affine_par)
43
+ self.relu = nn.ReLU(inplace=True)
44
+ self.conv2 = conv3x3(planes, planes)
45
+ self.bn2 = nn.BatchNorm2d(planes, affine=affine_par)
46
+ self.downsample = downsample
47
+ self.stride = stride
48
+
49
+ def forward(self, x):
50
+ residual = x
51
+
52
+ out = self.conv1(x)
53
+ out = self.bn1(out)
54
+ out = self.relu(out)
55
+
56
+ out = self.conv2(out)
57
+ out = self.bn2(out)
58
+
59
+ if self.downsample is not None:
60
+ residual = self.downsample(x)
61
+
62
+ out += residual
63
+ out = self.relu(out)
64
+
65
+ return out
66
+
67
+
68
+ class Bottleneck(nn.Module):
69
+ expansion = 4
70
+
71
+ def __init__(self, inplanes, planes, stride=1, dilation=1, downsample=None):
72
+ super(Bottleneck, self).__init__()
73
+ self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, stride=stride, bias=False) # change
74
+ self.bn1 = nn.BatchNorm2d(planes, affine=affine_par)
75
+ for i in self.bn1.parameters():
76
+ i.requires_grad = False
77
+
78
+ padding = dilation
79
+ self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, # change
80
+ padding=padding, bias=False, dilation=dilation)
81
+ self.bn2 = nn.BatchNorm2d(planes, affine=affine_par)
82
+ for i in self.bn2.parameters():
83
+ i.requires_grad = False
84
+ self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False)
85
+ self.bn3 = nn.BatchNorm2d(planes * 4, affine=affine_par)
86
+ for i in self.bn3.parameters():
87
+ i.requires_grad = False
88
+ self.relu = nn.ReLU(inplace=True)
89
+ self.downsample = downsample
90
+ self.stride = stride
91
+
92
+ def forward(self, x):
93
+ residual = x
94
+
95
+ out = self.conv1(x)
96
+ out = self.bn1(out)
97
+ out = self.relu(out)
98
+
99
+ out = self.conv2(out)
100
+ out = self.bn2(out)
101
+ out = self.relu(out)
102
+
103
+ out = self.conv3(out)
104
+ out = self.bn3(out)
105
+
106
+ if self.downsample is not None:
107
+ residual = self.downsample(x)
108
+
109
+ out += residual
110
+ out = self.relu(out)
111
+
112
+ return out
113
+
114
+ # self.layerx_1 = Bottleneck_nosample(64, 64, stride=1, dilation=1)
115
+ # self.layerx_2 = Bottleneck(256, 64, stride=1, dilation=1, downsample=None)
116
+ # self.layerx_3 = Bottleneck_downsample(256, 64, stride=2, dilation=1)
117
+
118
+
119
+ class Res_block_1(nn.Module):
120
+ expansion = 4
121
+
122
+ def __init__(self, inplanes=64, planes=64, stride=1, dilation=1):
123
+ super(Res_block_1, self).__init__()
124
+
125
+ self.conv1 = nn.Sequential(
126
+ nn.Conv2d(inplanes, planes, kernel_size=1, stride=1, bias=False),
127
+ nn.GroupNorm(8, planes),
128
+ nn.ReLU(inplace=True))
129
+
130
+ self.conv2 = nn.Sequential(
131
+ nn.Conv2d(planes, planes, kernel_size=3, stride=1,
132
+ padding=1, bias=False, dilation=1),
133
+ nn.GroupNorm(8, planes),
134
+ nn.ReLU(inplace=True))
135
+
136
+ self.conv3 = nn.Sequential(
137
+ nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False),
138
+ nn.GroupNorm(8, planes * 4))
139
+
140
+ self.relu = nn.ReLU(inplace=True)
141
+
142
+ self.down_sample = nn.Sequential(
143
+ nn.Conv2d(inplanes, planes * 4,
144
+ kernel_size=1, stride=1, bias=False),
145
+ nn.GroupNorm(8, planes * 4))
146
+
147
+ def forward(self, x):
148
+ # residual = x
149
+
150
+ out = self.conv1(x)
151
+ out = self.conv2(out)
152
+ out = self.conv3(out)
153
+ residual = self.down_sample(x)
154
+ out += residual
155
+ out = self.relu(out)
156
+
157
+ return out
158
+
159
+
160
+ class Res_block_2(nn.Module):
161
+ expansion = 4
162
+
163
+ def __init__(self, inplanes=256, planes=64, stride=1, dilation=1):
164
+ super(Res_block_2, self).__init__()
165
+
166
+ self.conv1 = nn.Sequential(
167
+ nn.Conv2d(inplanes, planes, kernel_size=1, stride=1, bias=False),
168
+ nn.GroupNorm(8, planes),
169
+ nn.ReLU(inplace=True))
170
+
171
+ self.conv2 = nn.Sequential(
172
+ nn.Conv2d(planes, planes, kernel_size=3, stride=1,
173
+ padding=1, bias=False, dilation=1),
174
+ nn.GroupNorm(8, planes),
175
+ nn.ReLU(inplace=True))
176
+
177
+ self.conv3 = nn.Sequential(
178
+ nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False),
179
+ nn.GroupNorm(8, planes * 4))
180
+
181
+ self.relu = nn.ReLU(inplace=True)
182
+
183
+ def forward(self, x):
184
+ residual = x
185
+
186
+ out = self.conv1(x)
187
+ out = self.conv2(out)
188
+ out = self.conv3(out)
189
+
190
+ out += residual
191
+ out = self.relu(out)
192
+
193
+ return out
194
+
195
+
196
+ class Res_block_3(nn.Module):
197
+ expansion = 4
198
+
199
+ def __init__(self, inplanes=256, planes=64, stride=1, dilation=1):
200
+ super(Res_block_3, self).__init__()
201
+
202
+ self.conv1 = nn.Sequential(
203
+ nn.Conv2d(inplanes, planes, kernel_size=1, stride=stride, bias=False),
204
+ nn.GroupNorm(8, planes),
205
+ nn.ReLU(inplace=True))
206
+
207
+ self.conv2 = nn.Sequential(
208
+ nn.Conv2d(planes, planes, kernel_size=3, stride=1,
209
+ padding=1, bias=False, dilation=1),
210
+ nn.GroupNorm(8, planes),
211
+ nn.ReLU(inplace=True))
212
+
213
+ self.conv3 = nn.Sequential(
214
+ nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False),
215
+ nn.GroupNorm(8, planes * 4))
216
+
217
+ self.relu = nn.ReLU(inplace=True)
218
+
219
+ self.downsample = nn.Sequential(
220
+ nn.Conv2d(inplanes, planes * 4,
221
+ kernel_size=1, stride=stride, bias=False),
222
+ nn.GroupNorm(8, planes * 4))
223
+
224
+ def forward(self, x):
225
+ # residual = x
226
+
227
+ out = self.conv1(x)
228
+ out = self.conv2(out)
229
+ out = self.conv3(out)
230
+ # residual = self.downsample(x)
231
+ out += self.downsample(x)
232
+ out = self.relu(out)
233
+
234
+ return out
235
+
236
+
237
+ class Classifier_Module(nn.Module):
238
+
239
+ def __init__(self, dilation_series, padding_series, num_classes):
240
+ super(Classifier_Module, self).__init__()
241
+ self.conv2d_list = nn.ModuleList()
242
+ for dilation, padding in zip(dilation_series, padding_series):
243
+ self.conv2d_list.append(
244
+ nn.Conv2d(2048, num_classes, kernel_size=3, stride=1, padding=padding, dilation=dilation, bias=True))
245
+
246
+ for m in self.conv2d_list:
247
+ m.weight.data.normal_(0, 0.01)
248
+
249
+ def forward(self, x):
250
+ out = self.conv2d_list[0](x)
251
+ for i in range(len(self.conv2d_list) - 1):
252
+ out += self.conv2d_list[i + 1](x)
253
+ return out
254
+
255
+
256
+ class _ConvBNReLU(nn.Module):
257
+ def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0,
258
+ dilation=1, groups=1, relu6=False, norm_layer=nn.BatchNorm2d):
259
+ super(_ConvBNReLU, self).__init__()
260
+ self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, dilation, groups, bias=False)
261
+ self.bn = norm_layer(out_channels)
262
+ self.relu = nn.ReLU6(True) if relu6 else nn.ReLU(True)
263
+
264
+ def forward(self, x):
265
+ x = self.conv(x)
266
+ x = self.bn(x)
267
+ x = self.relu(x)
268
+ return x
269
+
270
+
271
+ class _ASPPConv(nn.Module):
272
+ def __init__(self, in_channels, out_channels, atrous_rate, norm_layer):
273
+ super(_ASPPConv, self).__init__()
274
+ self.block = nn.Sequential(
275
+ nn.Conv2d(in_channels, out_channels, 3, padding=atrous_rate, dilation=atrous_rate, bias=False),
276
+ norm_layer(out_channels),
277
+ nn.ReLU(True)
278
+ )
279
+
280
+ def forward(self, x):
281
+ return self.block(x)
282
+
283
+
284
+ class _AsppPooling(nn.Module):
285
+ def __init__(self, in_channels, out_channels, norm_layer):
286
+ super(_AsppPooling, self).__init__()
287
+ self.gap = nn.Sequential(
288
+ nn.AdaptiveAvgPool2d(1),
289
+ nn.Conv2d(in_channels, out_channels, 1, bias=False),
290
+ norm_layer(out_channels),
291
+ nn.ReLU(True)
292
+ )
293
+
294
+ def forward(self, x):
295
+ size = x.size()[2:]
296
+ pool = self.gap(x)
297
+ out = F.interpolate(pool, size, mode='bilinear', align_corners=True)
298
+ return out
299
+
300
+
301
+ class _ASPP(nn.Module):
302
+ def __init__(self, in_channels, atrous_rates, norm_layer):
303
+ super(_ASPP, self).__init__()
304
+ out_channels = 256
305
+ self.b0 = nn.Sequential(
306
+ nn.Conv2d(in_channels, out_channels, 1, bias=False),
307
+ norm_layer(out_channels),
308
+ nn.ReLU(True)
309
+ )
310
+
311
+ rate1, rate2, rate3 = tuple(atrous_rates)
312
+ self.b1 = _ASPPConv(in_channels, out_channels, rate1, norm_layer)
313
+ self.b2 = _ASPPConv(in_channels, out_channels, rate2, norm_layer)
314
+ self.b3 = _ASPPConv(in_channels, out_channels, rate3, norm_layer)
315
+ self.b4 = _AsppPooling(in_channels, out_channels, norm_layer=norm_layer)
316
+
317
+ self.project = nn.Sequential(
318
+ nn.Conv2d(5 * out_channels, out_channels, 1, bias=False),
319
+ norm_layer(out_channels),
320
+ nn.ReLU(True),
321
+ nn.Dropout(0.5)
322
+ )
323
+
324
+ def forward(self, x):
325
+ feat1 = self.b0(x)
326
+ feat2 = self.b1(x)
327
+ feat3 = self.b2(x)
328
+ feat4 = self.b3(x)
329
+ feat5 = self.b4(x)
330
+ x = torch.cat((feat1, feat2, feat3, feat4, feat5), dim=1)
331
+ x = self.project(x)
332
+ return x
333
+
334
+
335
+ class _DeepLabHead(nn.Module):
336
+ def __init__(self, num_classes, c1_channels=256, norm_layer=nn.BatchNorm2d):
337
+ super(_DeepLabHead, self).__init__()
338
+ self.aspp = _ASPP(2048, [12, 24, 36], norm_layer=norm_layer)
339
+ self.c1_block = _ConvBNReLU(c1_channels, 48, 3, padding=1, norm_layer=norm_layer)
340
+ self.block = nn.Sequential(
341
+ _ConvBNReLU(304, 256, 3, padding=1, norm_layer=norm_layer),
342
+ nn.Dropout(0.5),
343
+ _ConvBNReLU(256, 256, 3, padding=1, norm_layer=norm_layer),
344
+ nn.Dropout(0.1),
345
+ nn.Conv2d(256, num_classes, 1))
346
+
347
+ def forward(self, x, c1):
348
+ size = c1.size()[2:]
349
+ c1 = self.c1_block(c1)
350
+ x = self.aspp(x)
351
+ x = F.interpolate(x, size, mode='bilinear', align_corners=True)
352
+ return self.block(torch.cat([x, c1], dim=1))
353
+
354
+
355
+ class _CARM(nn.Module):
356
+ def __init__(self, in_planes, ratio=8):
357
+ super(_CARM, self).__init__()
358
+ self.avg_pool = nn.AdaptiveAvgPool2d(1)
359
+ self.max_pool = nn.AdaptiveMaxPool2d(1)
360
+
361
+ self.fc1_1 = nn.Linear(in_planes, in_planes // ratio)
362
+ self.fc1_2 = nn.Linear(in_planes // ratio, in_planes)
363
+
364
+ self.fc2_1 = nn.Linear(in_planes, in_planes // ratio)
365
+ self.fc2_2 = nn.Linear(in_planes // ratio, in_planes)
366
+ self.relu = nn.ReLU(True)
367
+
368
+ self.sigmoid = nn.Sigmoid()
369
+
370
+ def forward(self, x):
371
+ avg_out = self.avg_pool(x)
372
+ avg_out = avg_out.view(avg_out.size(0), -1)
373
+ avg_out = self.fc1_2(self.relu(self.fc1_1(avg_out)))
374
+
375
+ max_out = self.max_pool(x)
376
+ max_out = max_out.view(max_out.size(0), -1)
377
+ max_out = self.fc2_2(self.relu(self.fc2_1(max_out)))
378
+
379
+ max_out_size = max_out.size()[1]
380
+ avg_out = torch.reshape(avg_out, (-1, max_out_size, 1, 1))
381
+ max_out = torch.reshape(max_out, (-1, max_out_size, 1, 1))
382
+
383
+ out = self.sigmoid(avg_out + max_out)
384
+
385
+ x = out * x
386
+ return x
387
+
388
+
389
+ class FSFB_CH(nn.Module):
390
+ def __init__(self, in_planes, num, ratio=8):
391
+ super(FSFB_CH, self).__init__()
392
+ self.avg_pool = nn.AdaptiveAvgPool2d(1)
393
+ self.max_pool = nn.AdaptiveMaxPool2d(1)
394
+
395
+ self.fc1_1 = nn.Linear(in_planes, in_planes // ratio)
396
+ self.fc1_2 = nn.Linear(in_planes // ratio, num * in_planes)
397
+
398
+ self.fc2_1 = nn.Linear(in_planes, in_planes // ratio)
399
+ self.fc2_2 = nn.Linear(in_planes // ratio, num * in_planes)
400
+ self.relu = nn.ReLU(True)
401
+
402
+ self.fc3 = nn.Linear(num * in_planes, 2 * num * in_planes)
403
+ self.fc4 = nn.Linear(2 * num * in_planes, 2 * num * in_planes)
404
+ self.fc5 = nn.Linear(2 * num * in_planes, num * in_planes)
405
+
406
+ self.softmax = nn.Softmax(dim=3)
407
+
408
+ def forward(self, x, num):
409
+ avg_out = self.avg_pool(x)
410
+ avg_out = avg_out.view(avg_out.size(0), -1)
411
+ avg_out = self.fc1_2(self.relu(self.fc1_1(avg_out)))
412
+
413
+ max_out = self.max_pool(x)
414
+ max_out = max_out.view(max_out.size(0), -1)
415
+ max_out = self.fc2_2(self.relu(self.fc2_1(max_out)))
416
+
417
+ out = avg_out + max_out
418
+ out = self.relu(self.fc3(out))
419
+ out = self.relu(self.fc4(out))
420
+ out = self.relu(self.fc5(out)) # (N, num*in_planes)
421
+
422
+ out_size = out.size()[1]
423
+ out = torch.reshape(out, (-1, out_size // num, 1, num)) # (N, in_planes, 1, num )
424
+ out = self.softmax(out)
425
+
426
+ channel_scale = torch.chunk(out, num, dim=3) # (N, in_planes, 1, 1 )
427
+
428
+ return channel_scale
429
+
430
+
431
+ class FSFB_SP(nn.Module):
432
+ def __init__(self, num, norm_layer=nn.BatchNorm2d):
433
+ super(FSFB_SP, self).__init__()
434
+ self.conv = nn.Sequential(
435
+ nn.Conv2d(2, 2 * num, kernel_size=3, padding=1, bias=False),
436
+ norm_layer(2 * num),
437
+ nn.ReLU(True),
438
+ nn.Conv2d(2 * num, 4 * num, kernel_size=3, padding=1, bias=False),
439
+ norm_layer(4 * num),
440
+ nn.ReLU(True),
441
+ nn.Conv2d(4 * num, 4 * num, kernel_size=3, padding=1, bias=False),
442
+ norm_layer(4 * num),
443
+ nn.ReLU(True),
444
+ nn.Conv2d(4 * num, 2 * num, kernel_size=3, padding=1, bias=False),
445
+ norm_layer(2 * num),
446
+ nn.ReLU(True),
447
+ nn.Conv2d(2 * num, num, kernel_size=3, padding=1, bias=False)
448
+ )
449
+ self.softmax = nn.Softmax(dim=1)
450
+
451
+ def forward(self, x, num):
452
+ avg_out = torch.mean(x, dim=1, keepdim=True)
453
+ max_out, _ = torch.max(x, dim=1, keepdim=True)
454
+ x = torch.cat([avg_out, max_out], dim=1)
455
+ x = self.conv(x)
456
+ x = self.softmax(x)
457
+ spatial_scale = torch.chunk(x, num, dim=1)
458
+ return spatial_scale
459
+
460
+
461
+ ##################################################################################################################
462
+
463
+
464
+ class _HFFM(nn.Module):
465
+ def __init__(self, in_channels, atrous_rates, norm_layer=nn.BatchNorm2d):
466
+ super(_HFFM, self).__init__()
467
+ out_channels = 256
468
+ self.b0 = nn.Sequential(
469
+ nn.Conv2d(in_channels, out_channels, 1, bias=False),
470
+ norm_layer(out_channels),
471
+ nn.ReLU(True)
472
+ )
473
+
474
+ rate1, rate2, rate3 = tuple(atrous_rates)
475
+ self.b1 = _ASPPConv(in_channels, out_channels, rate1, norm_layer)
476
+ self.b2 = _ASPPConv(in_channels, out_channels, rate2, norm_layer)
477
+ self.b3 = _ASPPConv(in_channels, out_channels, rate3, norm_layer)
478
+ self.b4 = _AsppPooling(in_channels, out_channels, norm_layer=norm_layer)
479
+ self.carm = _CARM(in_channels)
480
+ self.sa = FSFB_SP(4, norm_layer)
481
+ self.ca = FSFB_CH(out_channels, 4, 8)
482
+
483
+ def forward(self, x, num):
484
+ x = self.carm(x)
485
+ # feat1 = self.b0(x)
486
+ feat1 = self.b1(x)
487
+ feat2 = self.b2(x)
488
+ feat3 = self.b3(x)
489
+ feat4 = self.b4(x)
490
+ feat = feat1 + feat2 + feat3 + feat4
491
+ spatial_atten = self.sa(feat, num)
492
+ channel_atten = self.ca(feat, num)
493
+
494
+ feat_ca = channel_atten[0] * feat1 + channel_atten[1] * feat2 + channel_atten[2] * feat3 + channel_atten[
495
+ 3] * feat4
496
+ feat_sa = spatial_atten[0] * feat1 + spatial_atten[1] * feat2 + spatial_atten[2] * feat3 + spatial_atten[
497
+ 3] * feat4
498
+ feat_sa = feat_sa + feat_ca
499
+
500
+ return feat_sa
501
+
502
+
503
+ class _AFFM(nn.Module):
504
+ def __init__(self, in_channels=256, norm_layer=nn.BatchNorm2d):
505
+ super(_AFFM, self).__init__()
506
+
507
+ self.sa = FSFB_SP(2, norm_layer)
508
+ self.ca = FSFB_CH(in_channels, 2, 8)
509
+ self.carm = _CARM(in_channels)
510
+
511
+ def forward(self, feat1, feat2, hffm, num):
512
+ feat = feat1 + feat2
513
+ spatial_atten = self.sa(feat, num)
514
+ channel_atten = self.ca(feat, num)
515
+
516
+ feat_ca = channel_atten[0] * feat1 + channel_atten[1] * feat2
517
+ feat_sa = spatial_atten[0] * feat1 + spatial_atten[1] * feat2
518
+ output = self.carm(feat_sa + feat_ca + hffm)
519
+ # output = self.carm (feat_sa + hffm)
520
+
521
+ return output, channel_atten, spatial_atten
522
+
523
+
524
+ class block_Conv3x3(nn.Module):
525
+ def __init__(self, in_channels):
526
+ super(block_Conv3x3, self).__init__()
527
+ self.block = nn.Sequential(
528
+ nn.Conv2d(in_channels, 256, kernel_size=3, stride=1, padding=1, bias=False),
529
+ nn.BatchNorm2d(256),
530
+ nn.ReLU(True)
531
+ )
532
+
533
+ def forward(self, x):
534
+ return self.block(x)
535
+
536
+
537
+ class CDnetV2(nn.Module):
538
+ def __init__(self, block=Bottleneck, layers=[3, 4, 6, 3], num_classes=21, aux=True):
539
+ self.inplanes = 256 # change
540
+ self.aux = aux
541
+ super().__init__()
542
+ # self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
543
+ # self.bn1 = nn.BatchNorm2d(64, affine = affine_par)
544
+
545
+ self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=2, padding=1, bias=False)
546
+ self.bn1 = nn.BatchNorm2d(64, affine=affine_par)
547
+
548
+ self.conv2 = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1, bias=False)
549
+ self.bn2 = nn.BatchNorm2d(64, affine=affine_par)
550
+
551
+ self.conv3 = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1, bias=False)
552
+ self.bn3 = nn.BatchNorm2d(64, affine=affine_par)
553
+
554
+ self.relu = nn.ReLU(inplace=True)
555
+
556
+ self.dropout = nn.Dropout(0.3)
557
+ for i in self.bn1.parameters():
558
+ i.requires_grad = False
559
+
560
+ self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1, ceil_mode=True) # change
561
+
562
+ # self.layer1 = self._make_layer(block, 64, layers[0])
563
+
564
+ self.layerx_1 = Res_block_1(64, 64, stride=1, dilation=1)
565
+ self.layerx_2 = Res_block_2(256, 64, stride=1, dilation=1)
566
+ self.layerx_3 = Res_block_3(256, 64, stride=2, dilation=1)
567
+
568
+ self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
569
+ self.layer3 = self._make_layer(block, 256, layers[2], stride=1, dilation=2)
570
+ self.layer4 = self._make_layer(block, 512, layers[3], stride=1, dilation=4)
571
+ # self.layer5 = self._make_pred_layer(Classifier_Module, [6,12,18,24],[6,12,18,24],num_classes)
572
+
573
+ self.hffm = _HFFM(2048, [6, 12, 18])
574
+ self.affm_1 = _AFFM()
575
+         self.affm_2 = _AFFM()
+         self.affm_3 = _AFFM()
+         self.affm_4 = _AFFM()
+         self.carm = _CARM(256)
+ 
+         self.con_layer1_1 = block_Conv3x3(256)
+         self.con_res2 = block_Conv3x3(256)
+         self.con_res3 = block_Conv3x3(512)
+         self.con_res4 = block_Conv3x3(1024)
+         self.con_res5 = block_Conv3x3(2048)
+ 
+         self.dsn1 = nn.Sequential(
+             nn.Conv2d(256, num_classes, kernel_size=1, stride=1, padding=0)
+         )
+ 
+         self.dsn2 = nn.Sequential(
+             nn.Conv2d(256, num_classes, kernel_size=1, stride=1, padding=0)
+         )
+ 
+         for m in self.modules():
+             if isinstance(m, nn.Conv2d):
+                 m.weight.data.normal_(0, 0.01)
+             elif isinstance(m, nn.BatchNorm2d):
+                 m.weight.data.fill_(1)
+                 m.bias.data.zero_()
+ 
+     def _make_layer(self, block, planes, blocks, stride=1, dilation=1):
+         downsample = None
+         if stride != 1 or self.inplanes != planes * block.expansion or dilation == 2 or dilation == 4:
+             downsample = nn.Sequential(
+                 nn.Conv2d(self.inplanes, planes * block.expansion,
+                           kernel_size=1, stride=stride, bias=False),
+                 nn.BatchNorm2d(planes * block.expansion, affine=affine_par))
+             for i in downsample._modules['1'].parameters():
+                 i.requires_grad = False
+         layers = []
+         layers.append(block(self.inplanes, planes, stride, dilation=dilation, downsample=downsample))
+         self.inplanes = planes * block.expansion
+         for i in range(1, blocks):
+             layers.append(block(self.inplanes, planes, dilation=dilation))
+ 
+         return nn.Sequential(*layers)
+ 
+     def base_forward(self, x):
+         x = self.relu(self.bn1(self.conv1(x)))  # 1/2
+         x = self.relu(self.bn2(self.conv2(x)))
+         x = self.relu(self.bn3(self.conv3(x)))
+         x = self.maxpool(x)  # 1/4
+ 
+         # layer1
+         x = self.layerx_1(x)  # 1/4
+         layer1_0 = x
+ 
+         x = self.layerx_2(x)  # 1/4
+         layer1_0 = self.con_layer1_1(x + layer1_0)  # 256
+         size_layer1_0 = layer1_0.size()[2:]
+ 
+         x = self.layerx_3(x)  # 1/8
+         res2 = self.con_res2(x)  # 256
+         size_res2 = res2.size()[2:]
+ 
+         # layer2-4
+         x = self.layer2(x)  # 1/16
+         res3 = self.con_res3(x)  # 256
+         x = self.layer3(x)  # 1/16
+         res4 = self.con_res4(x)  # 256
+         x = self.layer4(x)  # 1/16
+         res5 = self.con_res5(x)  # 256
+ 
+         return layer1_0, res2, res3, res4, res5, x, size_layer1_0, size_res2
+ 
+     def forward(self, x):
+         layer1_0, res2, res3, res4, res5, layer4, size_layer1_0, size_res2 = self.base_forward(x)
+ 
+         hffm = self.hffm(layer4, 4)  # 256, HFFM
+         res5 = res5 + hffm
+         aux_feature = res5  # feature map used for the auxiliary loss
+         res5, _, _ = self.affm_1(res4, res5, hffm, 2)  # 1/16
+         res5, _, _ = self.affm_2(res3, res5, hffm, 2)  # 1/16
+ 
+         res5 = F.interpolate(res5, size_res2, mode='bilinear', align_corners=True)
+         res5, _, _ = self.affm_3(res2, res5, F.interpolate(hffm, size_res2, mode='bilinear', align_corners=True), 2)
+ 
+         res5 = F.interpolate(res5, size_layer1_0, mode='bilinear', align_corners=True)
+         res5, _, _ = self.affm_4(layer1_0, res5,
+                                  F.interpolate(hffm, size_layer1_0, mode='bilinear', align_corners=True), 2)
+ 
+         output = self.dsn1(res5)
+         size = x.size()[2:]
+ 
+         if self.aux:
+             auxout = self.dsn2(aux_feature)
+             pred = F.interpolate(output, size, mode='bilinear', align_corners=True)
+             pred_aux = F.interpolate(auxout, size, mode='bilinear', align_corners=True)
+             return pred, pred_aux
+         # fall back to the main head alone when deep supervision is disabled
+         return F.interpolate(output, size, mode='bilinear', align_corners=True)
+ 
+ 
+ if __name__ == '__main__':
+     model = CDnetV2(num_classes=3)
+     fake_image = torch.rand(2, 3, 256, 256)
+     output = model(fake_image)
+     for out in output:
+         print(out.shape)
+     # torch.Size([2, 3, 256, 256])
+     # torch.Size([2, 3, 256, 256])
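
Because `forward` returns both a main and an auxiliary prediction when `self.aux` is set, a training step has to combine two losses. A minimal sketch of that pairing (the 0.4 auxiliary weight and the `training_step` helper are illustrative assumptions, not part of this commit):

import torch.nn.functional as F

def training_step(model, images, masks, aux_weight=0.4):  # aux_weight is assumed
    pred, pred_aux = model(images)  # both heads are already at input resolution
    loss = F.cross_entropy(pred, masks)  # main head; masks are class indices
    loss = loss + aux_weight * F.cross_entropy(pred_aux, masks)  # deep supervision
    return loss
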
src/models/components/cnn.py ADDED
@@ -0,0 +1,26 @@
+ import torch
+ from torch import nn
+ 
+ 
+ class CNN(nn.Module):
+     def __init__(self, dim=32):
+         super(CNN, self).__init__()
+         self.conv1 = nn.Conv2d(1, dim, 5)
+         self.conv2 = nn.Conv2d(dim, dim * 2, 5)
+         self.fc1 = nn.Linear(dim * 2 * 4 * 4, 10)
+ 
+     def forward(self, x):
+         x = torch.relu(self.conv1(x))
+         x = torch.max_pool2d(x, 2)
+         x = torch.relu(self.conv2(x))
+         x = torch.max_pool2d(x, 2)
+         x = x.view(-1, x.shape[1] * x.shape[2] * x.shape[3])
+         x = self.fc1(x)
+         return x
+ 
+ 
+ if __name__ == "__main__":
+     input = torch.randn(2, 1, 28, 28)
+     model = CNN()
+     output = model(input)
+     assert output.shape == (2, 10)
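
The `dim * 2 * 4 * 4` input width of `fc1` follows from the spatial arithmetic on the default 28x28 input: each unpadded 5x5 convolution trims 4 pixels and each 2x2 pooling halves the map, so 28 -> 24 -> 12 -> 8 -> 4. A quick check (editorial sketch, not part of the commit):

size = 28
for _ in range(2):           # two conv + pool stages
    size = (size - 4) // 2   # unpadded 5x5 conv, then 2x2 max-pool
assert size == 4             # hence fc1 expects dim * 2 * 4 * 4 inputs
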
src/models/components/dual_branch.py ADDED
@@ -0,0 +1,680 @@
+ # -*- coding: utf-8 -*-
+ # @Time    : 2024/7/26 11:19 AM
+ # @Author  : xiaoshun
+ # @Email   : 3038523973@qq.com
+ # @File    : dual_branch.py
+ # @Software: PyCharm
+ 
+ from einops import rearrange
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ 
+ 
+ # Refinement convolution module for the decoder
+ class SBR(nn.Module):
+     def __init__(self, in_ch):
+         super(SBR, self).__init__()
+         self.conv1x3 = nn.Sequential(
+             nn.Conv2d(in_ch, in_ch, kernel_size=(1, 3), stride=1, padding=(0, 1)),
+             nn.BatchNorm2d(in_ch),
+             nn.ReLU(True)
+         )
+         self.conv3x1 = nn.Sequential(
+             nn.Conv2d(in_ch, in_ch, kernel_size=(3, 1), stride=1, padding=(1, 0)),
+             nn.BatchNorm2d(in_ch),
+             nn.ReLU(True)
+         )
+ 
+     def forward(self, x):
+         out = self.conv3x1(self.conv1x3(x))  # a 1x3 convolution whose result feeds a 3x1 convolution
+         return out + x
+ 
+ 
+ # Downsampling convolution module for stages 1-3
+ class c_stage123(nn.Module):
+     def __init__(self, in_chans, out_chans):
+         super().__init__()
+         self.stage123 = nn.Sequential(
+             nn.Conv2d(in_channels=in_chans, out_channels=out_chans, kernel_size=3, stride=2, padding=1),
+             nn.BatchNorm2d(out_chans),
+             nn.ReLU(),
+             nn.Conv2d(in_channels=out_chans, out_channels=out_chans, kernel_size=3, stride=1, padding=1),
+             nn.BatchNorm2d(out_chans),
+             nn.ReLU(),
+         )
+         self.conv1x1_123 = nn.Conv2d(in_channels=in_chans, out_channels=out_chans, kernel_size=1)
+         self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
+ 
+     def forward(self, x):
+         stage123 = self.stage123(x)  # 3x3 convolutions, 2x downsampling: 3*224*224 --> 64*112*112
+         max = self.maxpool(x)  # max pooling, 2x downsampling: 3*224*224 --> 3*112*112
+         max = self.conv1x1_123(max)  # 1x1 convolution: 3*112*112 --> 64*112*112
+         stage123 = stage123 + max  # residual connection (broadcast add)
+         return stage123
+ 
+ 
+ # Downsampling convolution module for stages 4-5
+ class c_stage45(nn.Module):
+     def __init__(self, in_chans, out_chans):
+         super().__init__()
+         self.stage45 = nn.Sequential(
+             nn.Conv2d(in_channels=in_chans, out_channels=out_chans, kernel_size=3, stride=2, padding=1),
+             nn.BatchNorm2d(out_chans),
+             nn.ReLU(),
+             nn.Conv2d(in_channels=out_chans, out_channels=out_chans, kernel_size=3, stride=1, padding=1),
+             nn.BatchNorm2d(out_chans),
+             nn.ReLU(),
+             nn.Conv2d(in_channels=out_chans, out_channels=out_chans, kernel_size=3, stride=1, padding=1),
+             nn.BatchNorm2d(out_chans),
+             nn.ReLU(),
+         )
+         self.conv1x1_45 = nn.Conv2d(in_channels=in_chans, out_channels=out_chans, kernel_size=1)
+         self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
+ 
+     def forward(self, x):
+         stage45 = self.stage45(x)  # 3x3 convolution block, 2x downsampling
+         max = self.maxpool(x)  # max pooling, 2x downsampling
+         max = self.conv1x1_45(max)  # 1x1 convolution to adjust the channel count
+         stage45 = stage45 + max  # residual connection
+         return stage45
+ 
+ 
+ class Identity(nn.Module):  # identity mapping
+     def __init__(self):
+         super().__init__()
+ 
+     def forward(self, x):
+         return x
+ 
+ 
+ # Lightweight convolution module used by the self-attention blocks
+ class DepthwiseConv2d(nn.Module):
+     def __init__(self, in_chans, out_chans, kernel_size=1, stride=1, padding=0, dilation=1):
+         super().__init__()
+         # depthwise conv
+         self.depthwise = nn.Conv2d(
+             in_channels=in_chans,
+             out_channels=in_chans,
+             kernel_size=kernel_size,
+             stride=stride,
+             padding=padding,
+             dilation=dilation,  # dilation rate of the depthwise convolution
+             groups=in_chans  # one group per input channel
+         )
+         # batch norm
+         self.bn = nn.BatchNorm2d(num_features=in_chans)
+ 
+         # pointwise (1x1) conv
+         self.pointwise = nn.Conv2d(
+             in_channels=in_chans,
+             out_channels=out_chans,
+             kernel_size=1
+         )
+ 
+     def forward(self, x):
+         x = self.depthwise(x)
+         x = self.bn(x)
+         x = self.pointwise(x)
+         return x
+ 
+ 
+ # residual skip connection
+ class Residual(nn.Module):
+     def __init__(self, fn):
+         super().__init__()
+         self.fn = fn
+ 
+     def forward(self, input, **kwargs):
+         x = self.fn(input, **kwargs)
+         return x + input
+ 
+ 
+ # layer normalization applied before the wrapped layer
+ class PreNorm(nn.Module):
+     def __init__(self, dim, fn):
+         super().__init__()
+         self.norm = nn.LayerNorm(dim)
+         self.fn = fn
+ 
+     def forward(self, input, **kwargs):
+         return self.fn(self.norm(input), **kwargs)
+ 
+ 
+ # The FeedForward layer strengthens the representation capacity
+ class FeedForward(nn.Module):
+     def __init__(self, dim, hidden_dim, dropout=0.):
+         super().__init__()
+         self.net = nn.Sequential(
+             nn.Linear(in_features=dim, out_features=hidden_dim),
+             nn.GELU(),
+             nn.Dropout(dropout),
+             nn.Linear(in_features=hidden_dim, out_features=dim),
+             nn.Dropout(dropout)
+         )
+ 
+     def forward(self, input):
+         return self.net(input)
+ 
+ 
+ class ConvAttnetion(nn.Module):
+     """
+     Uses depthwise-separable Conv2d to produce q, k, v instead of the linear projection used in ViT.
+     """
+ 
+     def __init__(self, dim, img_size, heads=8, dim_head=64, kernel_size=3, q_stride=1, k_stride=1, v_stride=1,
+                  dropout=0., last_stage=False):
+         super().__init__()
+         self.last_stage = last_stage
+         self.img_size = img_size
+         inner_dim = dim_head * heads  # e.g. 512
+         project_out = not (heads == 1 and dim_head == dim)
+ 
+         self.heads = heads
+         self.scale = dim_head ** (-0.5)
+ 
+         pad = (kernel_size - q_stride) // 2
+ 
+         self.to_q = DepthwiseConv2d(in_chans=dim, out_chans=inner_dim, kernel_size=kernel_size, stride=q_stride,
+                                     padding=pad)
+         self.to_k = DepthwiseConv2d(in_chans=dim, out_chans=inner_dim, kernel_size=kernel_size, stride=k_stride,
+                                     padding=pad)
+         self.to_v = DepthwiseConv2d(in_chans=dim, out_chans=inner_dim, kernel_size=kernel_size, stride=v_stride,
+                                     padding=pad)
+ 
+         self.to_out = nn.Sequential(
+             nn.Linear(
+                 in_features=inner_dim,
+                 out_features=dim
+             ),
+             nn.Dropout(dropout)
+         ) if project_out else Identity()
+ 
+     def forward(self, x):
+         b, n, c, h = *x.shape, self.heads  # the * unpacks the shape tuple
+ 
+         # this branch is unused in Dual_Branch (last_stage is never set)
+         if self.last_stage:
+             cls_token = x[:, 0]
+             x = x[:, 1:]  # drop the class token from the sequence
+             cls_token = rearrange(torch.unsqueeze(cls_token, dim=1), 'b n (h d) -> b h n d', h=h)
+ 
+         # rearrange reshapes the token sequence back into a feature map
+         x = rearrange(x, 'b (l w) n -> b n l w', l=self.img_size, w=self.img_size)  # [1, 3136, 64] --> 1*64*56*56
+ 
+         q = self.to_q(x)  # 1*64*56*56 --> 1*64*56*56
+         q = rearrange(q, 'b (h d) l w -> b h (l w) d', h=h)  # 1*64*56*56 --> 1*1*3136*64
+ 
+         k = self.to_k(x)  # same treatment as q
+         k = rearrange(k, 'b (h d) l w -> b h (l w) d', h=h)
+ 
+         v = self.to_v(x)  # same treatment as q
+         v = rearrange(v, 'b (h d) l w -> b h (l w) d', h=h)
+ 
+         if self.last_stage:
+             q = torch.cat([cls_token, q], dim=2)
+             v = torch.cat([cls_token, v], dim=2)
+             k = torch.cat([cls_token, k], dim=2)
+ 
+         # attention = scaled dot product of q and k
+         k = k.permute(0, 1, 3, 2)  # 1*1*3136*64 --> 1*1*64*3136
+         attention = q.matmul(k)  # 1*1*3136*3136
+         attention = attention * self.scale  # scaling keeps the logits well-conditioned
+         attention = F.softmax(attention, dim=-1)
+ 
+         # weight the values, then reshape to (batch, sequence length, embed dim)
+         out = (attention.matmul(v)).permute(0, 2, 1, 3).reshape(b, n, c)  # 1*3136*64
+ 
+         # linear projection
+         out = self.to_out(out)
+         return out
+ 
+ 
+ # Reshape layers
+ class Rearrange(nn.Module):
+     def __init__(self, string, h, w):
+         super().__init__()
+         self.string = string
+         self.h = h
+         self.w = w
+ 
+     def forward(self, input):
+         if self.string == 'b c h w -> b (h w) c':
+             N, C, H, W = input.shape
+             x = torch.reshape(input, shape=(N, -1, self.h * self.w)).permute(0, 2, 1)
+         if self.string == 'b (h w) c -> b c h w':
+             N, _, C = input.shape
+             x = torch.reshape(input, shape=(N, self.h, self.w, -1)).permute(0, 3, 1, 2)
+         return x
+ 
+ 
+ # Transformer layers
+ class Transformer(nn.Module):
+     def __init__(self, dim, img_size, depth, heads, dim_head, mlp_dim, dropout=0., last_stage=False):
+         super().__init__()
+         self.layers = nn.ModuleList([  # registers the sub-modules and their parameters
+             nn.ModuleList([
+                 PreNorm(dim=dim, fn=ConvAttnetion(dim, img_size, heads=heads, dim_head=dim_head, dropout=dropout,
+                                                   last_stage=last_stage)),  # pre-norm, then attention
+                 PreNorm(dim=dim, fn=FeedForward(dim=dim, hidden_dim=mlp_dim, dropout=dropout))
+             ]) for _ in range(depth)
+         ])
+ 
+     def forward(self, x):
+         for attn, ff in self.layers:
+             x = x + attn(x)
+             x = x + ff(x)
+         return x
+ 
+ 
+ class Dual_Branch(nn.Module):  # the top-level model
+     def __init__(self, img_size, in_channels, num_classes, dim=64, kernels=[7, 3, 3, 3], strides=[4, 2, 2, 2],
+                  heads=[1, 3, 6, 6],
+                  depth=[1, 2, 10, 10], pool='cls', dropout=0., emb_dropout=0., scale_dim=4):
+         super().__init__()
+ 
+         assert pool in ['cls', 'mean'], 'pool type must be either cls or mean pooling'
+         self.pool = pool
+         self.dim = dim
+ 
+         # stage1
+         # k:7 s:4 in: 1, 64, 56, 56 out: 1, 3136, 64
+         self.stage1_conv_embed = nn.Sequential(
+             nn.Conv2d(  # 1*3*224*224 --> [1, 64, 56, 56]
+                 in_channels=in_channels,
+                 out_channels=dim,
+                 kernel_size=kernels[0],
+                 stride=strides[0],
+                 padding=2
+             ),
+             Rearrange('b c h w -> b (h w) c', h=img_size // 4, w=img_size // 4),  # [1, 64, 56, 56] --> [1, 3136, 64]
+             nn.LayerNorm(dim)  # normalize over the embedding dimension
+         )
+ 
+         self.stage1_transformer = nn.Sequential(
+             Transformer(
+                 dim=dim,
+                 img_size=img_size // 4,
+                 depth=depth[0],  # number of transformer blocks in this stage
+                 heads=heads[0],
+                 dim_head=self.dim,  # dimension of each attention head
+                 mlp_dim=dim * scale_dim,  # hidden width of the feed-forward network (embedding dim times a scale factor)
+                 dropout=dropout,
+             ),
+             Rearrange('b (h w) c -> b c h w', h=img_size // 4, w=img_size // 4)
+         )
+ 
+         # stage2
+         # k:3 s:2 in: 1, 192, 28, 28 out: 1, 784, 192
+         in_channels = dim
+         scale = heads[1] // heads[0]
+         dim = scale * dim
+ 
+         self.stage2_conv_embed = nn.Sequential(
+             nn.Conv2d(
+                 in_channels=in_channels,
+                 out_channels=dim,
+                 kernel_size=kernels[1],
+                 stride=strides[1],
+                 padding=1
+             ),
+             Rearrange('b c h w -> b (h w) c', h=img_size // 8, w=img_size // 8),
+             nn.LayerNorm(dim)
+         )
+ 
+         self.stage2_transformer = nn.Sequential(
+             Transformer(
+                 dim=dim,
+                 img_size=img_size // 8,
+                 depth=depth[1],
+                 heads=heads[1],
+                 dim_head=self.dim,
+                 mlp_dim=dim * scale_dim,
+                 dropout=dropout
+             ),
+             Rearrange('b (h w) c -> b c h w', h=img_size // 8, w=img_size // 8)
+         )
+ 
+         # stage3
+         in_channels = dim
+         scale = heads[2] // heads[1]
+         dim = scale * dim
+ 
+         self.stage3_conv_embed = nn.Sequential(
+             nn.Conv2d(
+                 in_channels=in_channels,
+                 out_channels=dim,
+                 kernel_size=kernels[2],
+                 stride=strides[2],
+                 padding=1
+             ),
+             Rearrange('b c h w -> b (h w) c', h=img_size // 16, w=img_size // 16),
+             nn.LayerNorm(dim)
+         )
+ 
+         self.stage3_transformer = nn.Sequential(
+             Transformer(
+                 dim=dim,
+                 img_size=img_size // 16,
+                 depth=depth[2],
+                 heads=heads[2],
+                 dim_head=self.dim,
+                 mlp_dim=dim * scale_dim,
+                 dropout=dropout
+             ),
+             Rearrange('b (h w) c -> b c h w', h=img_size // 16, w=img_size // 16)
+         )
+ 
+         # stage4
+         in_channels = dim
+         scale = heads[3] // heads[2]
+         dim = scale * dim
+ 
+         self.stage4_conv_embed = nn.Sequential(
+             nn.Conv2d(
+                 in_channels=in_channels,
+                 out_channels=dim,
+                 kernel_size=kernels[3],
+                 stride=strides[3],
+                 padding=1
+             ),
+             Rearrange('b c h w -> b (h w) c', h=img_size // 32, w=img_size // 32),
+             nn.LayerNorm(dim)
+         )
+ 
+         self.stage4_transformer = nn.Sequential(
+             Transformer(
+                 dim=dim, img_size=img_size // 32,
+                 depth=depth[3],
+                 heads=heads[3],
+                 dim_head=self.dim,
+                 mlp_dim=dim * scale_dim,
+                 dropout=dropout,
+             ),
+             Rearrange('b (h w) c -> b c h w', h=img_size // 32, w=img_size // 32)
+         )
+ 
+         ### CNN Branch ###
+         self.c_stage1 = c_stage123(in_chans=3, out_chans=64)
+         self.c_stage2 = c_stage123(in_chans=64, out_chans=128)
+         self.c_stage3 = c_stage123(in_chans=128, out_chans=384)
+         self.c_stage4 = c_stage45(in_chans=384, out_chans=512)
+         self.c_stage5 = c_stage45(in_chans=512, out_chans=1024)
+         self.c_max = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
+         self.up_conv1 = nn.Conv2d(in_channels=192, out_channels=128, kernel_size=1)
+         self.up_conv2 = nn.Conv2d(in_channels=384, out_channels=512, kernel_size=1)
+ 
+         ### CTmerge ###
+         self.CTmerge1 = nn.Sequential(
+             nn.Conv2d(in_channels=128, out_channels=64, kernel_size=3, stride=1, padding=1),
+             nn.BatchNorm2d(64),
+             nn.ReLU(),
+             nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1, padding=1),
+             nn.BatchNorm2d(64),
+             nn.ReLU(),
+         )
+         self.CTmerge2 = nn.Sequential(
+             nn.Conv2d(in_channels=320, out_channels=128, kernel_size=3, stride=1, padding=1),
+             nn.BatchNorm2d(128),
+             nn.ReLU(),
+             nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, stride=1, padding=1),
+             nn.BatchNorm2d(128),
+             nn.ReLU(),
+         )
+         self.CTmerge3 = nn.Sequential(
+             nn.Conv2d(in_channels=768, out_channels=512, kernel_size=3, stride=1, padding=1),
+             nn.BatchNorm2d(512),
+             nn.ReLU(),
+             nn.Conv2d(in_channels=512, out_channels=384, kernel_size=3, stride=1, padding=1),
+             nn.BatchNorm2d(384),
+             nn.ReLU(),
+             nn.Conv2d(in_channels=384, out_channels=384, kernel_size=3, stride=1, padding=1),
+             nn.BatchNorm2d(384),
+             nn.ReLU(),
+         )
+ 
+         self.CTmerge4 = nn.Sequential(
+             nn.Conv2d(in_channels=896, out_channels=640, kernel_size=3, stride=1, padding=1),
+             nn.BatchNorm2d(640),
+             nn.ReLU(),
+             nn.Conv2d(in_channels=640, out_channels=512, kernel_size=3, stride=1, padding=1),
+             nn.BatchNorm2d(512),
+             nn.ReLU(),
+             nn.Conv2d(in_channels=512, out_channels=512, kernel_size=3, stride=1, padding=1),
+             nn.BatchNorm2d(512),
+             nn.ReLU(),
+         )
+ 
+         # decoder
+         self.decoder4 = nn.Sequential(
+             DepthwiseConv2d(
+                 in_chans=1408,
+                 out_chans=1024,
+                 kernel_size=3,
+                 stride=1,
+                 padding=1
+             ),
+             DepthwiseConv2d(
+                 in_chans=1024,
+                 out_chans=512,
+                 kernel_size=3,
+                 stride=1,
+                 padding=1
+             ),
+             nn.GELU()
+         )
+         self.decoder3 = nn.Sequential(
+             DepthwiseConv2d(
+                 in_chans=896,
+                 out_chans=512,
+                 kernel_size=3,
+                 stride=1,
+                 padding=1
+             ),
+             DepthwiseConv2d(
+                 in_chans=512,
+                 out_chans=384,
+                 kernel_size=3,
+                 stride=1,
+                 padding=1
+             ),
+             nn.GELU()
+         )
+ 
+         self.decoder2 = nn.Sequential(
+             DepthwiseConv2d(
+                 in_chans=576,
+                 out_chans=256,
+                 kernel_size=3,
+                 stride=1,
+                 padding=1
+             ),
+             DepthwiseConv2d(
+                 in_chans=256,
+                 out_chans=192,
+                 kernel_size=3,
+                 stride=1,
+                 padding=1
+             ),
+             nn.GELU()
+         )
+ 
+         self.decoder1 = nn.Sequential(
+             DepthwiseConv2d(
+                 in_chans=256,
+                 out_chans=64,
+                 kernel_size=3,
+                 stride=1,
+                 padding=1
+             ),
+             DepthwiseConv2d(
+                 in_chans=64,
+                 out_chans=16,
+                 kernel_size=3,
+                 stride=1,
+                 padding=1
+             ),
+             nn.GELU()
+         )
+         self.sbr4 = SBR(512)
+         self.sbr3 = SBR(384)
+         self.sbr2 = SBR(192)
+         self.sbr1 = SBR(16)
+ 
+         self.head = nn.Conv2d(in_channels=16, out_channels=num_classes, kernel_size=1)
+ 
+     def forward(self, input):
+         ### encoder ###
+         # stage1 = ts1 cat cs1
+         t_s1 = self.stage1_conv_embed(input)  # 1*3*224*224 --> 1*3136*64
+         t_s1 = self.stage1_transformer(t_s1)  # 1*3136*64 --> 1*64*56*56
+         c_s1 = self.c_stage1(input)  # 1*3*224*224 --> 1*64*112*112
+         stage1 = self.CTmerge1(torch.cat([t_s1, self.c_max(c_s1)], dim=1))  # 1*64*56*56, fuse the two branches
+ 
+         # stage2 = ts2 up cs2
+         t_s2 = self.stage2_conv_embed(stage1)  # 1*64*56*56 --> 1*784*192, tokenize the feature map
+         t_s2 = self.stage2_transformer(t_s2)  # 1*784*192 --> 1*192*28*28
+         c_s2 = self.c_stage2(c_s1)  # 1*64*112*112 --> 1*128*56*56
+         stage2 = self.CTmerge2(
+             torch.cat([c_s2, F.interpolate(t_s2, size=c_s2.size()[2:], mode='bilinear', align_corners=True)],
+                       dim=1))  # bilinear upsampling, 1*128*56*56
+ 
+         # stage3 = ts3 cat cs3
+         t_s3 = self.stage3_conv_embed(t_s2)  # 1*192*28*28 --> 1*196*384
+         t_s3 = self.stage3_transformer(t_s3)  # 1*196*384 --> 1*384*14*14
+         c_s3 = self.c_stage3(stage2)  # 1*128*56*56 --> 1*384*28*28
+         stage3 = self.CTmerge3(torch.cat([t_s3, self.c_max(c_s3)], dim=1))  # 1*384*14*14
+ 
+         # stage4 = ts4 up cs4
+         t_s4 = self.stage4_conv_embed(stage3)  # 1*384*14*14 --> 1*49*384
+         t_s4 = self.stage4_transformer(t_s4)  # 1*49*384 --> 1*384*7*7
+         c_s4 = self.c_stage4(c_s3)  # 1*384*28*28 --> 1*512*14*14
+         stage4 = self.CTmerge4(
+             torch.cat([c_s4, F.interpolate(t_s4, size=c_s4.size()[2:], mode='bilinear', align_corners=True)],
+                       dim=1))  # 1*512*14*14
+ 
+         # cs5
+         c_s5 = self.c_stage5(stage4)  # 1*512*14*14 --> 1*1024*7*7
+ 
+         ### decoder ###
+         decoder4 = torch.cat([c_s5, t_s4], dim=1)  # 1*1408*7*7
+         decoder4 = self.decoder4(decoder4)  # 1*1408*7*7 --> 1*512*7*7
+         decoder4 = F.interpolate(decoder4, size=c_s3.size()[2:], mode='bilinear',
+                                  align_corners=True)  # 1*512*7*7 --> 1*512*28*28
+         decoder4 = self.sbr4(decoder4)  # 1*512*28*28
+ 
+         decoder3 = torch.cat([decoder4, c_s3], dim=1)  # 1*896*28*28
+         decoder3 = self.decoder3(decoder3)  # 1*384*28*28
+         decoder3 = F.interpolate(decoder3, size=t_s2.size()[2:], mode='bilinear', align_corners=True)  # 1*384*28*28
+         decoder3 = self.sbr3(decoder3)  # 1*384*28*28
+ 
+         decoder2 = torch.cat([decoder3, t_s2], dim=1)  # 1*576*28*28
+         decoder2 = self.decoder2(decoder2)  # 1*192*28*28
+         decoder2 = F.interpolate(decoder2, size=c_s1.size()[2:], mode='bilinear', align_corners=True)  # 1*192*112*112
+         decoder2 = self.sbr2(decoder2)  # 1*192*112*112
+ 
+         decoder1 = torch.cat([decoder2, c_s1], dim=1)  # 1*256*112*112
+         decoder1 = self.decoder1(decoder1)  # 1*16*112*112
+         final = F.interpolate(decoder1, size=input.size()[2:], mode='bilinear', align_corners=True)  # 1*16*224*224
+         final = self.head(final)  # 1*num_classes*224*224
+ 
+         return final
+ 
+ 
+ if __name__ == '__main__':
+     x = torch.rand(1, 3, 224, 224).cuda()
+     model = Dual_Branch(img_size=224, in_channels=3, num_classes=7).cuda()
+     y = model(x)
+     print(y.shape)
+     # torch.Size([1, 7, 224, 224])
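
The stage widths of the transformer branch are never set directly: each stage multiplies the previous `dim` by the ratio of consecutive head counts. A quick check of the defaults (editorial sketch, not part of the commit):

heads = [1, 3, 6, 6]
dim = 64
dims = [dim]
for prev, cur in zip(heads, heads[1:]):
    dim *= cur // prev  # per-stage scale = ratio of head counts
    dims.append(dim)
print(dims)  # [64, 192, 384, 384], matching the shape comments above
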
src/models/components/hrcloud.py ADDED
@@ -0,0 +1,751 @@
+ # Paper: https://arxiv.org/abs/2407.07365
+ from __future__ import absolute_import
+ from __future__ import division
+ from __future__ import print_function
+ 
+ import logging
+ import math
+ import os
+ 
+ import numpy as np
+ import torch
+ import torch._utils
+ import torch.nn as nn
+ import torch.nn.functional as F
+ from einops import rearrange
+ from yacs.config import CfgNode as CN
+ 
+ BatchNorm2d = nn.BatchNorm2d
+ relu_inplace = True
+ BN_MOMENTUM = 0.1
+ ALIGN_CORNERS = True
+ 
+ logger = logging.getLogger(__name__)
+ 
+ 
+ def conv3x3(in_planes, out_planes, stride=1):
+     """3x3 convolution with padding"""
+     return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
+                      padding=1, bias=False)
+ 
+ 
+ # configs for HRNet48
+ HRNET_48 = CN()
+ HRNET_48.FINAL_CONV_KERNEL = 1
+ 
+ HRNET_48.STAGE1 = CN()
+ HRNET_48.STAGE1.NUM_MODULES = 1
+ HRNET_48.STAGE1.NUM_BRANCHES = 1
+ HRNET_48.STAGE1.NUM_BLOCKS = [4]
+ HRNET_48.STAGE1.NUM_CHANNELS = [64]
+ HRNET_48.STAGE1.BLOCK = 'BOTTLENECK'
+ HRNET_48.STAGE1.FUSE_METHOD = 'SUM'
+ 
+ HRNET_48.STAGE2 = CN()
+ HRNET_48.STAGE2.NUM_MODULES = 1
+ HRNET_48.STAGE2.NUM_BRANCHES = 2
+ HRNET_48.STAGE2.NUM_BLOCKS = [4, 4]
+ HRNET_48.STAGE2.NUM_CHANNELS = [48, 96]
+ HRNET_48.STAGE2.BLOCK = 'BASIC'
+ HRNET_48.STAGE2.FUSE_METHOD = 'SUM'
+ 
+ HRNET_48.STAGE3 = CN()
+ HRNET_48.STAGE3.NUM_MODULES = 4
+ HRNET_48.STAGE3.NUM_BRANCHES = 3
+ HRNET_48.STAGE3.NUM_BLOCKS = [4, 4, 4]
+ HRNET_48.STAGE3.NUM_CHANNELS = [48, 96, 192]
+ HRNET_48.STAGE3.BLOCK = 'BASIC'
+ HRNET_48.STAGE3.FUSE_METHOD = 'SUM'
+ 
+ HRNET_48.STAGE4 = CN()
+ HRNET_48.STAGE4.NUM_MODULES = 3
+ HRNET_48.STAGE4.NUM_BRANCHES = 4
+ HRNET_48.STAGE4.NUM_BLOCKS = [4, 4, 4, 4]
+ HRNET_48.STAGE4.NUM_CHANNELS = [48, 96, 192, 384]
+ HRNET_48.STAGE4.BLOCK = 'BASIC'
+ HRNET_48.STAGE4.FUSE_METHOD = 'SUM'
+ 
+ # configs for HRNet32 and HRNet18 (defined for completeness; HRcloudNet below uses HRNET_48)
+ HRNET_32 = CN()
+ HRNET_32.FINAL_CONV_KERNEL = 1
+ 
+ HRNET_32.STAGE1 = CN()
+ HRNET_32.STAGE1.NUM_MODULES = 1
+ HRNET_32.STAGE1.NUM_BRANCHES = 1
+ HRNET_32.STAGE1.NUM_BLOCKS = [4]
+ HRNET_32.STAGE1.NUM_CHANNELS = [64]
+ HRNET_32.STAGE1.BLOCK = 'BOTTLENECK'
+ HRNET_32.STAGE1.FUSE_METHOD = 'SUM'
+ 
+ HRNET_32.STAGE2 = CN()
+ HRNET_32.STAGE2.NUM_MODULES = 1
+ HRNET_32.STAGE2.NUM_BRANCHES = 2
+ HRNET_32.STAGE2.NUM_BLOCKS = [4, 4]
+ HRNET_32.STAGE2.NUM_CHANNELS = [32, 64]
+ HRNET_32.STAGE2.BLOCK = 'BASIC'
+ HRNET_32.STAGE2.FUSE_METHOD = 'SUM'
+ 
+ HRNET_32.STAGE3 = CN()
+ HRNET_32.STAGE3.NUM_MODULES = 4
+ HRNET_32.STAGE3.NUM_BRANCHES = 3
+ HRNET_32.STAGE3.NUM_BLOCKS = [4, 4, 4]
+ HRNET_32.STAGE3.NUM_CHANNELS = [32, 64, 128]
+ HRNET_32.STAGE3.BLOCK = 'BASIC'
+ HRNET_32.STAGE3.FUSE_METHOD = 'SUM'
+ 
+ HRNET_32.STAGE4 = CN()
+ HRNET_32.STAGE4.NUM_MODULES = 3
+ HRNET_32.STAGE4.NUM_BRANCHES = 4
+ HRNET_32.STAGE4.NUM_BLOCKS = [4, 4, 4, 4]
+ HRNET_32.STAGE4.NUM_CHANNELS = [32, 64, 128, 256]
+ HRNET_32.STAGE4.BLOCK = 'BASIC'
+ HRNET_32.STAGE4.FUSE_METHOD = 'SUM'
+ 
+ HRNET_18 = CN()
+ HRNET_18.FINAL_CONV_KERNEL = 1
+ 
+ HRNET_18.STAGE1 = CN()
+ HRNET_18.STAGE1.NUM_MODULES = 1
+ HRNET_18.STAGE1.NUM_BRANCHES = 1
+ HRNET_18.STAGE1.NUM_BLOCKS = [4]
+ HRNET_18.STAGE1.NUM_CHANNELS = [64]
+ HRNET_18.STAGE1.BLOCK = 'BOTTLENECK'
+ HRNET_18.STAGE1.FUSE_METHOD = 'SUM'
+ 
+ HRNET_18.STAGE2 = CN()
+ HRNET_18.STAGE2.NUM_MODULES = 1
+ HRNET_18.STAGE2.NUM_BRANCHES = 2
+ HRNET_18.STAGE2.NUM_BLOCKS = [4, 4]
+ HRNET_18.STAGE2.NUM_CHANNELS = [18, 36]
+ HRNET_18.STAGE2.BLOCK = 'BASIC'
+ HRNET_18.STAGE2.FUSE_METHOD = 'SUM'
+ 
+ HRNET_18.STAGE3 = CN()
+ HRNET_18.STAGE3.NUM_MODULES = 4
+ HRNET_18.STAGE3.NUM_BRANCHES = 3
+ HRNET_18.STAGE3.NUM_BLOCKS = [4, 4, 4]
+ HRNET_18.STAGE3.NUM_CHANNELS = [18, 36, 72]
+ HRNET_18.STAGE3.BLOCK = 'BASIC'
+ HRNET_18.STAGE3.FUSE_METHOD = 'SUM'
+ 
+ HRNET_18.STAGE4 = CN()
+ HRNET_18.STAGE4.NUM_MODULES = 3
+ HRNET_18.STAGE4.NUM_BRANCHES = 4
+ HRNET_18.STAGE4.NUM_BLOCKS = [4, 4, 4, 4]
+ HRNET_18.STAGE4.NUM_CHANNELS = [18, 36, 72, 144]
+ HRNET_18.STAGE4.BLOCK = 'BASIC'
+ HRNET_18.STAGE4.FUSE_METHOD = 'SUM'
+ 
+ 
+ class PPM(nn.Module):
+     def __init__(self, in_dim, reduction_dim, bins):
+         super(PPM, self).__init__()
+         self.features = []
+         for bin in bins:
+             self.features.append(nn.Sequential(
+                 nn.AdaptiveAvgPool2d(bin),
+                 nn.Conv2d(in_dim, reduction_dim, kernel_size=1, bias=False),
+                 nn.BatchNorm2d(reduction_dim),
+                 nn.ReLU(inplace=True)
+             ))
+         self.features = nn.ModuleList(self.features)
+ 
+     def forward(self, x):
+         x_size = x.size()
+         out = [x]
+         for f in self.features:
+             out.append(F.interpolate(f(x), x_size[2:], mode='bilinear', align_corners=True))
+         return torch.cat(out, 1)
+ 
+ 
+ class BasicBlock(nn.Module):
+     expansion = 1
+ 
+     def __init__(self, inplanes, planes, stride=1, downsample=None):
+         super(BasicBlock, self).__init__()
+         self.conv1 = conv3x3(inplanes, planes, stride)
+         self.bn1 = BatchNorm2d(planes, momentum=BN_MOMENTUM)
+         self.relu = nn.ReLU(inplace=relu_inplace)
+         self.conv2 = conv3x3(planes, planes)
+         self.bn2 = BatchNorm2d(planes, momentum=BN_MOMENTUM)
+         self.downsample = downsample
+         self.stride = stride
+ 
+     def forward(self, x):
+         residual = x
+ 
+         out = self.conv1(x)
+         out = self.bn1(out)
+         out = self.relu(out)
+ 
+         out = self.conv2(out)
+         out = self.bn2(out)
+ 
+         if self.downsample is not None:
+             residual = self.downsample(x)
+         out = out + residual
+         out = self.relu(out)
+ 
+         return out
+ 
+ 
+ class Bottleneck(nn.Module):
+     expansion = 4
+ 
+     def __init__(self, inplanes, planes, stride=1, downsample=None):
+         super(Bottleneck, self).__init__()
+         self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
+         self.bn1 = BatchNorm2d(planes, momentum=BN_MOMENTUM)
+         self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
+                                padding=1, bias=False)
+         self.bn2 = BatchNorm2d(planes, momentum=BN_MOMENTUM)
+         self.conv3 = nn.Conv2d(planes, planes * self.expansion, kernel_size=1,
+                                bias=False)
+         self.bn3 = BatchNorm2d(planes * self.expansion,
+                                momentum=BN_MOMENTUM)
+         self.relu = nn.ReLU(inplace=relu_inplace)
+         self.downsample = downsample
+         self.stride = stride
+ 
+     def forward(self, x):
+         residual = x
+ 
+         out = self.conv1(x)
+         out = self.bn1(out)
+         out = self.relu(out)
+ 
+         out = self.conv2(out)
+         out = self.bn2(out)
+         out = self.relu(out)
+ 
+         out = self.conv3(out)
+         out = self.bn3(out)
+ 
+         if self.downsample is not None:
+             residual = self.downsample(x)
+         out = out + residual
+         out = self.relu(out)
+ 
+         return out
+ 
+ 
+ class HighResolutionModule(nn.Module):
+     def __init__(self, num_branches, blocks, num_blocks, num_inchannels,
+                  num_channels, fuse_method, multi_scale_output=True):
+         super(HighResolutionModule, self).__init__()
+         self._check_branches(
+             num_branches, blocks, num_blocks, num_inchannels, num_channels)
+ 
+         self.num_inchannels = num_inchannels
+         self.fuse_method = fuse_method
+         self.num_branches = num_branches
+ 
+         self.multi_scale_output = multi_scale_output
+ 
+         self.branches = self._make_branches(
+             num_branches, blocks, num_blocks, num_channels)
+         self.fuse_layers = self._make_fuse_layers()
+         self.relu = nn.ReLU(inplace=relu_inplace)
+ 
+     def _check_branches(self, num_branches, blocks, num_blocks,
+                         num_inchannels, num_channels):
+         if num_branches != len(num_blocks):
+             error_msg = 'NUM_BRANCHES({}) <> NUM_BLOCKS({})'.format(
+                 num_branches, len(num_blocks))
+             logger.error(error_msg)
+             raise ValueError(error_msg)
+ 
+         if num_branches != len(num_channels):
+             error_msg = 'NUM_BRANCHES({}) <> NUM_CHANNELS({})'.format(
+                 num_branches, len(num_channels))
+             logger.error(error_msg)
+             raise ValueError(error_msg)
+ 
+         if num_branches != len(num_inchannels):
+             error_msg = 'NUM_BRANCHES({}) <> NUM_INCHANNELS({})'.format(
+                 num_branches, len(num_inchannels))
+             logger.error(error_msg)
+             raise ValueError(error_msg)
+ 
+     def _make_one_branch(self, branch_index, block, num_blocks, num_channels,
+                          stride=1):
+         downsample = None
+         if stride != 1 or \
+                 self.num_inchannels[branch_index] != num_channels[branch_index] * block.expansion:
+             downsample = nn.Sequential(
+                 nn.Conv2d(self.num_inchannels[branch_index],
+                           num_channels[branch_index] * block.expansion,
+                           kernel_size=1, stride=stride, bias=False),
+                 BatchNorm2d(num_channels[branch_index] * block.expansion,
+                             momentum=BN_MOMENTUM),
+             )
+ 
+         layers = []
+         layers.append(block(self.num_inchannels[branch_index],
+                             num_channels[branch_index], stride, downsample))
+         self.num_inchannels[branch_index] = \
+             num_channels[branch_index] * block.expansion
+         for i in range(1, num_blocks[branch_index]):
+             layers.append(block(self.num_inchannels[branch_index],
+                                 num_channels[branch_index]))
+ 
+         return nn.Sequential(*layers)
+ 
+     # build the parallel branches
+     def _make_branches(self, num_branches, block, num_blocks, num_channels):
+         branches = []
+ 
+         for i in range(num_branches):
+             branches.append(
+                 self._make_one_branch(i, block, num_blocks, num_channels))
+ 
+         return nn.ModuleList(branches)
+ 
+     def _make_fuse_layers(self):
+         if self.num_branches == 1:
+             return None
+         num_branches = self.num_branches  # e.g. 3
+         num_inchannels = self.num_inchannels  # e.g. [48, 96, 192]
+         fuse_layers = []
+         for i in range(num_branches if self.multi_scale_output else 1):
+             fuse_layer = []
+             for j in range(num_branches):
+                 if j > i:
+                     fuse_layer.append(nn.Sequential(
+                         nn.Conv2d(num_inchannels[j],
+                                   num_inchannels[i],
+                                   1,
+                                   1,
+                                   0,
+                                   bias=False),
+                         BatchNorm2d(num_inchannels[i], momentum=BN_MOMENTUM)))
+                 elif j == i:
+                     fuse_layer.append(None)
+                 else:
+                     conv3x3s = []
+                     for k in range(i - j):
+                         if k == i - j - 1:
+                             num_outchannels_conv3x3 = num_inchannels[i]
+                             conv3x3s.append(nn.Sequential(
+                                 nn.Conv2d(num_inchannels[j],
+                                           num_outchannels_conv3x3,
+                                           3, 2, 1, bias=False),
+                                 BatchNorm2d(num_outchannels_conv3x3,
+                                             momentum=BN_MOMENTUM)))
+                         else:
+                             num_outchannels_conv3x3 = num_inchannels[j]
+                             conv3x3s.append(nn.Sequential(
+                                 nn.Conv2d(num_inchannels[j],
+                                           num_outchannels_conv3x3,
+                                           3, 2, 1, bias=False),
+                                 BatchNorm2d(num_outchannels_conv3x3,
+                                             momentum=BN_MOMENTUM),
+                                 nn.ReLU(inplace=relu_inplace)))
+                     fuse_layer.append(nn.Sequential(*conv3x3s))
+             fuse_layers.append(nn.ModuleList(fuse_layer))
+ 
+         return nn.ModuleList(fuse_layers)
+ 
+     def get_num_inchannels(self):
+         return self.num_inchannels
+ 
+     def forward(self, x):
+         if self.num_branches == 1:
+             return [self.branches[0](x[0])]
+ 
+         for i in range(self.num_branches):
+             x[i] = self.branches[i](x[i])
+ 
+         x_fuse = []
+         for i in range(len(self.fuse_layers)):
+             y = x[0] if i == 0 else self.fuse_layers[i][0](x[0])
+             for j in range(1, self.num_branches):
+                 if i == j:
+                     y = y + x[j]
+                 elif j > i:
+                     width_output = x[i].shape[-1]
+                     height_output = x[i].shape[-2]
+                     y = y + F.interpolate(
+                         self.fuse_layers[i][j](x[j]),
+                         size=[height_output, width_output],
+                         mode='bilinear', align_corners=ALIGN_CORNERS)
+                 else:
+                     y = y + self.fuse_layers[i][j](x[j])
+             x_fuse.append(self.relu(y))
+ 
+         return x_fuse
+ 
+ 
+ blocks_dict = {
+     'BASIC': BasicBlock,
+     'BOTTLENECK': Bottleneck
+ }
+ 
+ 
+ class HRcloudNet(nn.Module):
+ 
+     def __init__(self, num_classes=2, base_c=48, **kwargs):
+         global ALIGN_CORNERS
+         extra = HRNET_48
+         super(HRcloudNet, self).__init__()
+         ALIGN_CORNERS = True
+         self.num_classes = num_classes
+         # stem net
+         self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=2, padding=1,
+                                bias=False)
+         self.bn1 = BatchNorm2d(64, momentum=BN_MOMENTUM)
+         self.conv2 = nn.Conv2d(64, 64, kernel_size=3, stride=2, padding=1,
+                                bias=False)
+         self.bn2 = BatchNorm2d(64, momentum=BN_MOMENTUM)
+         self.relu = nn.ReLU(inplace=relu_inplace)
+ 
+         self.stage1_cfg = extra['STAGE1']
+         num_channels = self.stage1_cfg['NUM_CHANNELS'][0]
+         block = blocks_dict[self.stage1_cfg['BLOCK']]
+         num_blocks = self.stage1_cfg['NUM_BLOCKS'][0]
+         self.layer1 = self._make_layer(block, 64, num_channels, num_blocks)
+         stage1_out_channel = block.expansion * num_channels
+ 
+         self.stage2_cfg = extra['STAGE2']
+         num_channels = self.stage2_cfg['NUM_CHANNELS']
+         block = blocks_dict[self.stage2_cfg['BLOCK']]
+         num_channels = [
+             num_channels[i] * block.expansion for i in range(len(num_channels))]
+         self.transition1 = self._make_transition_layer(
+             [stage1_out_channel], num_channels)
+         self.stage2, pre_stage_channels = self._make_stage(
+             self.stage2_cfg, num_channels)
+ 
+         self.stage3_cfg = extra['STAGE3']
+         num_channels = self.stage3_cfg['NUM_CHANNELS']
+         block = blocks_dict[self.stage3_cfg['BLOCK']]
+         num_channels = [
+             num_channels[i] * block.expansion for i in range(len(num_channels))]
+         self.transition2 = self._make_transition_layer(
+             pre_stage_channels, num_channels)  # downsampling only happens from pre[-1] to the new branch
+         self.stage3, pre_stage_channels = self._make_stage(
+             self.stage3_cfg, num_channels)
+ 
+         self.stage4_cfg = extra['STAGE4']
+         num_channels = self.stage4_cfg['NUM_CHANNELS']
+         block = blocks_dict[self.stage4_cfg['BLOCK']]
+         num_channels = [
+             num_channels[i] * block.expansion for i in range(len(num_channels))]
+         self.transition3 = self._make_transition_layer(
+             pre_stage_channels, num_channels)
+         self.stage4, pre_stage_channels = self._make_stage(
+             self.stage4_cfg, num_channels, multi_scale_output=True)
+         self.out_conv = OutConv(base_c, num_classes)
+         last_inp_channels = int(np.sum(pre_stage_channels))
+ 
+         self.corr = Corr(nclass=2)
+         self.proj = nn.Sequential(
+             nn.Conv2d(720, 48, kernel_size=3, stride=1, padding=1, bias=True),
+             nn.BatchNorm2d(48),
+             nn.ReLU(inplace=True),
+             nn.Dropout2d(0.1),
+         )
+         self.up2 = Up(base_c * 8, base_c * 4, True)
+         self.up3 = Up(base_c * 4, base_c * 2, True)
+         self.up4 = Up(base_c * 2, base_c, True)
+         fea_dim = 720
+         bins = (1, 2, 3, 6)
+         self.ppm = PPM(fea_dim, int(fea_dim / len(bins)), bins)
+         fea_dim *= 2
+         self.cls = nn.Sequential(
+             nn.Conv2d(fea_dim, 512, kernel_size=3, padding=1, bias=False),
+             nn.BatchNorm2d(512),
+             nn.ReLU(inplace=True),
+             nn.Dropout2d(p=0.1),
+             nn.Conv2d(512, 2, kernel_size=1)
+         )
+ 
+     def _make_transition_layer(
+             self, num_channels_pre_layer, num_channels_cur_layer):
+         """Builds the transition layers between stages.
+ 
+         Two cases arise: for branches that already exist, the channel count is
+         adjusted only when it differs from the previous stage; for newly added
+         branches, extra layers downsample the last branch and change its
+         channel count for the following connections. Collecting them in an
+         nn.ModuleList keeps the branch channel counts matched between stages
+         so the features can be fused.
+         """
+         num_branches_cur = len(num_channels_cur_layer)  # branches after the transition
+         num_branches_pre = len(num_channels_pre_layer)  # branches before the transition
+ 
+         transition_layers = []
+         for i in range(num_branches_cur):
+             if i < num_branches_pre:
+                 # existing branch: convert only when the channel counts differ
+                 if num_channels_cur_layer[i] != num_channels_pre_layer[i]:
+                     transition_layers.append(nn.Sequential(
+                         nn.Conv2d(num_channels_pre_layer[i],
+                                   num_channels_cur_layer[i],
+                                   3,
+                                   1,
+                                   1,
+                                   bias=False),
+                         BatchNorm2d(
+                             num_channels_cur_layer[i], momentum=BN_MOMENTUM),
+                         nn.ReLU(inplace=relu_inplace)))
+                 else:
+                     transition_layers.append(None)
+             else:
+                 # new branch: downsample and change the channel count
+                 conv3x3s = []
+                 for j in range(i + 1 - num_branches_pre):
+                     inchannels = num_channels_pre_layer[-1]
+                     outchannels = num_channels_cur_layer[i] \
+                         if j == i - num_branches_pre else inchannels
+                     conv3x3s.append(nn.Sequential(
+                         nn.Conv2d(
+                             inchannels, outchannels, 3, 2, 1, bias=False),
+                         BatchNorm2d(outchannels, momentum=BN_MOMENTUM),
+                         nn.ReLU(inplace=relu_inplace)))
+                 transition_layers.append(nn.Sequential(*conv3x3s))
+ 
+         return nn.ModuleList(transition_layers)
+ 
+     def _make_layer(self, block, inplanes, planes, blocks, stride=1):
+         """Builds a layer out of several residual blocks of the same type."""
+         downsample = None
+         if stride != 1 or inplanes != planes * block.expansion:
+             downsample = nn.Sequential(
+                 nn.Conv2d(inplanes, planes * block.expansion,
+                           kernel_size=1, stride=stride, bias=False),
+                 BatchNorm2d(planes * block.expansion, momentum=BN_MOMENTUM),
+             )
+ 
+         layers = []
+         layers.append(block(inplanes, planes, stride, downsample))
+         inplanes = planes * block.expansion
+         for i in range(1, blocks):
+             layers.append(block(inplanes, planes))
+ 
+         return nn.Sequential(*layers)
+ 
+     # multi-scale fusion
+     def _make_stage(self, layer_config, num_inchannels,
+                     multi_scale_output=True):
+         num_modules = layer_config['NUM_MODULES']
+         num_branches = layer_config['NUM_BRANCHES']
+         num_blocks = layer_config['NUM_BLOCKS']
+         num_channels = layer_config['NUM_CHANNELS']
+         block = blocks_dict[layer_config['BLOCK']]
+         fuse_method = layer_config['FUSE_METHOD']
+ 
+         modules = []
+         for i in range(num_modules):  # repeated num_modules times
+             # multi_scale_output is only used by the last module
+             if not multi_scale_output and i == num_modules - 1:
+                 reset_multi_scale_output = False
+             else:
+                 reset_multi_scale_output = True
+             modules.append(
+                 HighResolutionModule(num_branches,
+                                      block,
+                                      num_blocks,
+                                      num_inchannels,
+                                      num_channels,
+                                      fuse_method,
+                                      reset_multi_scale_output)
+             )
+             num_inchannels = modules[-1].get_num_inchannels()
+ 
+         return nn.Sequential(*modules), num_inchannels
+ 
+     def forward(self, input, need_fp=True, use_corr=True):
+         x = self.conv1(input)
+         x = self.bn1(x)
+         x = self.relu(x)
+         x = self.conv2(x)
+         x = self.bn2(x)
+         x = self.relu(x)
+         x = self.layer1(x)
+ 
+         x_list = []
+         for i in range(self.stage2_cfg['NUM_BRANCHES']):  # 2
+             if self.transition1[i] is not None:
+                 x_list.append(self.transition1[i](x))
+             else:
+                 x_list.append(x)
+         y_list = self.stage2(x_list)
+ 
+         x_list = []
+         for i in range(self.stage3_cfg['NUM_BRANCHES']):
+             if self.transition2[i] is not None:
+                 if i < self.stage2_cfg['NUM_BRANCHES']:
+                     x_list.append(self.transition2[i](y_list[i]))
+                 else:
+                     x_list.append(self.transition2[i](y_list[-1]))
+             else:
+                 x_list.append(y_list[i])
+         y_list = self.stage3(x_list)
+ 
+         x_list = []
+         for i in range(self.stage4_cfg['NUM_BRANCHES']):
+             if self.transition3[i] is not None:
+                 if i < self.stage3_cfg['NUM_BRANCHES']:
+                     x_list.append(self.transition3[i](y_list[i]))
+                 else:
+                     x_list.append(self.transition3[i](y_list[-1]))
+             else:
+                 x_list.append(y_list[i])
+         x = self.stage4(x_list)
+         dict_return = {}
+ 
+         # upsample every branch to the resolution of the highest-resolution branch
+         x0_h, x0_w = x[0].size(2), x[0].size(3)
+         x3 = F.interpolate(x[3], size=(x0_h, x0_w), mode='bilinear', align_corners=ALIGN_CORNERS)
+         x[2] = self.up2(x[3], x[2])
+         x2 = F.interpolate(x[2], size=(x0_h, x0_w), mode='bilinear', align_corners=ALIGN_CORNERS)
+         x[1] = self.up3(x[2], x[1])
+         x1 = F.interpolate(x[1], size=(x0_h, x0_w), mode='bilinear', align_corners=ALIGN_CORNERS)
+         x[0] = self.up4(x[1], x[0])
+         xk = torch.cat([x[0], x1, x2, x3], 1)
+         # PPM
+         feat = self.ppm(xk)
+         x = self.cls(feat)
+         # feature-perturbation (fp) branch
+         if need_fp:
+             logits = F.interpolate(x, size=input.size()[2:], mode='bilinear', align_corners=True)
+             out = logits
+             out_fp = logits
+             if use_corr:
+                 proj_feats = self.proj(xk)
+                 corr_out = self.corr(proj_feats, out)
+                 corr_out = F.interpolate(corr_out, size=(352, 352), mode="bilinear",
+                                          align_corners=True)  # hardcoded to the 352x352 crop size
+                 dict_return['corr_out'] = corr_out
+             dict_return['out'] = out
+             dict_return['out_fp'] = out_fp
+             return dict_return['out']
+ 
+         out = F.interpolate(x, size=input.size()[2:], mode='bilinear', align_corners=True)
+         if use_corr:
+             proj_feats = self.proj(xk)
+             # propagate the coarse prediction through the pixel-affinity map
+             corr_out = self.corr(proj_feats, out)
+             corr_out = F.interpolate(corr_out, size=(352, 352), mode="bilinear",
+                                      align_corners=True)  # hardcoded to the 352x352 crop size
+             dict_return['corr_out'] = corr_out
+         dict_return['out'] = out
+         return dict_return['out']
+ 
+     def init_weights(self, pretrained=''):
+         logger.info('=> init weights from normal distribution')
+         for m in self.modules():
+             if isinstance(m, nn.Conv2d):
+                 nn.init.normal_(m.weight, std=0.001)
+             elif isinstance(m, nn.BatchNorm2d):
+                 nn.init.constant_(m.weight, 1)
+                 nn.init.constant_(m.bias, 0)
+         if os.path.isfile(pretrained):
+             pretrained_dict = torch.load(pretrained)
+             logger.info('=> loading pretrained model {}'.format(pretrained))
+             model_dict = self.state_dict()
+             pretrained_dict = {k: v for k, v in pretrained_dict.items()
+                                if k in model_dict.keys()}
+             for k, _ in pretrained_dict.items():
+                 logger.info(
+                     '=> loading {} pretrained model {}'.format(k, pretrained))
+             model_dict.update(pretrained_dict)
+             self.load_state_dict(model_dict)
+ 
+ 
+ class OutConv(nn.Sequential):
+     def __init__(self, in_channels, num_classes):
+         # note: hardcoded to the 720-channel concatenated feature map; in_channels is unused
+         super(OutConv, self).__init__(
+             nn.Conv2d(720, num_classes, kernel_size=1)
+         )
+ 
+ 
+ class DoubleConv(nn.Sequential):
+     def __init__(self, in_channels, out_channels, mid_channels=None):
+         if mid_channels is None:
+             mid_channels = out_channels
+         super(DoubleConv, self).__init__(
+             nn.Conv2d(in_channels + out_channels, mid_channels, kernel_size=3, padding=1, bias=False),
+             nn.BatchNorm2d(mid_channels),
+             nn.ReLU(inplace=True),
+             nn.Conv2d(mid_channels, out_channels, kernel_size=3, padding=1, bias=False),
+             nn.BatchNorm2d(out_channels),
+             nn.ReLU(inplace=True)
+         )
+ 
+ 
+ class Up(nn.Module):
+     def __init__(self, in_channels, out_channels, bilinear=True):
+         super(Up, self).__init__()
+         if bilinear:
+             self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
+             self.conv = DoubleConv(in_channels, out_channels, in_channels // 2)
+         else:
+             self.up = nn.ConvTranspose2d(in_channels, in_channels // 2, kernel_size=2, stride=2)
+             self.conv = DoubleConv(in_channels, out_channels)
+ 
+     def forward(self, x1: torch.Tensor, x2: torch.Tensor) -> torch.Tensor:
+         x1 = self.up(x1)
+         # [N, C, H, W]
+         diff_y = x2.size()[2] - x1.size()[2]
+         diff_x = x2.size()[3] - x1.size()[3]
+ 
+         # padding_left, padding_right, padding_top, padding_bottom
+         x1 = F.pad(x1, [diff_x // 2, diff_x - diff_x // 2,
+                         diff_y // 2, diff_y - diff_y // 2])
+ 
+         x = torch.cat([x2, x1], dim=1)
+         x = self.conv(x)
+         return x
+ 
+ 
+ class Corr(nn.Module):
+     def __init__(self, nclass=2):
+         super(Corr, self).__init__()
+         self.nclass = nclass
+         self.conv1 = nn.Conv2d(48, self.nclass, kernel_size=1, stride=1, padding=0, bias=True)
+         self.conv2 = nn.Conv2d(48, self.nclass, kernel_size=1, stride=1, padding=0, bias=True)
+ 
+     def forward(self, feature_in, out):
+         # feature_in: (n, 48, h, w); out: coarse logits at input resolution
+         h_in, w_in = feature_in.shape[2], feature_in.shape[3]  # stride 1, i.e. full feature resolution
+         out = F.interpolate(out.detach(), (h_in, w_in), mode='bilinear', align_corners=True)
+         feature = F.interpolate(feature_in, (h_in, w_in), mode='bilinear', align_corners=True)
+         f1 = rearrange(self.conv1(feature), 'n c h w -> n c (h w)')
+         f2 = rearrange(self.conv2(feature), 'n c h w -> n c (h w)')
+         out_temp = rearrange(out, 'n c h w -> n c (h w)')
+         corr_map = torch.matmul(f1.transpose(1, 2), f2) / torch.sqrt(torch.tensor(f1.shape[1]).float())
+         corr_map = F.softmax(corr_map, dim=-1)
+         # redistribute the coarse logits according to the (hw x hw) affinity map
+         out = rearrange(torch.matmul(out_temp, corr_map), 'n c (h w) -> n c h w', h=h_in, w=w_in)
+         return out
+ 
+ 
+ if __name__ == '__main__':
+     input = torch.randn(4, 3, 352, 352)
+     cloud = HRcloudNet(num_classes=2)
+     output = cloud(input)
+     print(output.shape)
+     # torch.Size([4, 2, 352, 352])
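
The `Corr` head builds an (hw x hw) pixel-affinity matrix from two 1x1 projections of the fused features and pushes the detached coarse logits through it, i.e. one step of spatial attention over positions. A toy-shape walkthrough (editorial sketch, not part of the commit):

import torch

n, c, h, w = 2, 2, 22, 22                   # batch, classes, spatial size
f1 = torch.randn(n, c, h * w)               # first 1x1 projection, flattened
f2 = torch.randn(n, c, h * w)               # second 1x1 projection, flattened
out = torch.randn(n, c, h * w)              # coarse logits, flattened
corr = torch.softmax(f1.transpose(1, 2) @ f2 / c ** 0.5, dim=-1)  # (n, hw, hw) affinity
refined = (out @ corr).reshape(n, c, h, w)  # logits redistributed by affinity
print(refined.shape)                        # torch.Size([2, 2, 22, 22])
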
src/models/components/lnn.py ADDED
@@ -0,0 +1,23 @@
+ import torch
+ from torch import nn
+ 
+ 
+ class LNN(nn.Module):
+     # A fully connected network for handwritten-digit recognition;
+     # dim controls the width of the hidden layer.
+     def __init__(self, dim=32):
+         super(LNN, self).__init__()
+         self.fc1 = nn.Linear(28 * 28, dim)
+         self.fc2 = nn.Linear(dim, 10)
+ 
+     def forward(self, x):
+         x = x.view(-1, x.shape[1] * x.shape[2] * x.shape[3])
+         x = torch.relu(self.fc1(x))
+         x = self.fc2(x)
+         return x
+ 
+ 
+ if __name__ == "__main__":
+     input = torch.randn(2, 1, 28, 28)
+     model = LNN()
+     output = model(input)
+     assert output.shape == (2, 10)
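
With the default dim=32 the network is tiny: 784*32 + 32 parameters in fc1 plus 32*10 + 10 in fc2, i.e. 25,450 in total. A one-line check (editorial sketch; assumes the file above is importable):

model = LNN()
print(sum(p.numel() for p in model.parameters()))  # 25450 = 784*32 + 32 + 32*10 + 10
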
src/models/components/mcdnet.py ADDED
@@ -0,0 +1,448 @@
+ # -*- coding: utf-8 -*-
+ # @Time : 2024/7/21 3:51 PM
+ # @Author : xiaoshun
+ # @Email : 3038523973@qq.com
+ # @File : mcdnet.py
+ # @Software: PyCharm
+ import cv2
+ import image_dehazer
+ import numpy as np
+ # Paper: https://www.sciencedirect.com/science/article/pii/S1569843224001742?via%3Dihub
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+
+
+ class _DPFF(nn.Module):
+     def __init__(self, in_channels) -> None:
+         super(_DPFF, self).__init__()
+         self.cbr1 = nn.Conv2d(in_channels * 2, in_channels, 1, 1, bias=False)
+         self.cbr2 = nn.Conv2d(in_channels * 2, in_channels, 1, 1, bias=False)
+         # self.sigmoid = nn.Sigmoid()
+         self.cbr3 = nn.Conv2d(in_channels, in_channels, 1, 1, bias=False)
+         self.cbr4 = nn.Conv2d(in_channels * 2, in_channels, 1, 1, bias=False)
+
+     def forward(self, feature1, feature2):
+         d1 = torch.abs(feature1 - feature2)
+         d2 = self.cbr1(torch.cat([feature1, feature2], dim=1))
+         d = torch.cat([d1, d2], dim=1)
+         d = self.cbr2(d)
+         # d = self.sigmoid(d)
+
+         v1, v2 = self.cbr3(feature1), self.cbr3(feature2)
+         v1, v2 = v1 * d, v2 * d
+         features = torch.cat([v1, v2], dim=1)
+         features = self.cbr4(features)
+
+         return features
+
+
+ class DPFF(nn.Module):
+     def __init__(self, layer_channels) -> None:
+         super(DPFF, self).__init__()
+         self.cfes = nn.ModuleList()
+         for layer_channel in layer_channels:
+             self.cfes.append(_DPFF(layer_channel))
+
+     def forward(self, features1, features2):
+         outputs = []
+         for feature1, feature2, cfe in zip(features1, features2, self.cfes):
+             outputs.append(cfe(feature1, feature2))
+         return outputs
+
+
+ class DirectDPFF(nn.Module):
+     def __init__(self, layer_channels) -> None:
+         super(DirectDPFF, self).__init__()
+         self.fusions = nn.ModuleList(
+             [nn.Conv2d(layer_channel * 2, layer_channel, 1, 1) for layer_channel in layer_channels]
+         )
+
+     def forward(self, features1, features2):
+         outputs = []
+         for feature1, feature2, fusion in zip(features1, features2, self.fusions):
+             feature = torch.cat([feature1, feature2], dim=1)
+             outputs.append(fusion(feature))
+         return outputs
+
+
+ class ConvBlock(nn.Module):
+     def __init__(self, input_size, output_size, kernel_size=4, stride=2, padding=1, bias=True,
+                  bn=False, activation=True, maxpool=True):
+         super(ConvBlock, self).__init__()
+         self.module = []
+         if maxpool:
+             down = nn.Sequential(
+                 *[
+                     nn.MaxPool2d(2),
+                     nn.Conv2d(input_size, output_size, 1, 1, 0, bias=bias)
+                 ]
+             )
+         else:
+             down = nn.Conv2d(input_size, output_size, kernel_size, stride, padding, bias=bias)
+         self.module.append(down)
+         if bn:
+             self.module.append(nn.BatchNorm2d(output_size))
+         if activation:
+             self.module.append(nn.PReLU())
+         self.module = nn.Sequential(*self.module)
+
+     def forward(self, x):
+         out = self.module(x)
+
+         return out
+
+
+ class DeconvBlock(nn.Module):
+     def __init__(self, input_size, output_size, kernel_size=4, stride=2, padding=1, bias=True,
+                  bn=False, activation=True, bilinear=True):
+         super(DeconvBlock, self).__init__()
+         self.module = []
+         if bilinear:
+             deconv = nn.Sequential(
+                 *[
+                     nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True),
+                     nn.Conv2d(input_size, output_size, 1, 1, 0, bias=bias)
+                 ]
+             )
+         else:
+             deconv = nn.ConvTranspose2d(input_size, output_size, kernel_size, stride, padding, bias=bias)
+         self.module.append(deconv)
+         if bn:
+             self.module.append(nn.BatchNorm2d(output_size))
+         if activation:
+             self.module.append(nn.PReLU())
+         self.module = nn.Sequential(*self.module)
+
+     def forward(self, x):
+         out = self.module(x)
+
+         return out
+
+
+ class FusionBlock(torch.nn.Module):
+     def __init__(self, num_filter, num_ft, kernel_size=4, stride=2, padding=1, bias=True, maxpool=False,
+                  bilinear=False):
+         super(FusionBlock, self).__init__()
+         self.num_ft = num_ft
+         self.up_convs = nn.ModuleList()
+         self.down_convs = nn.ModuleList()
+         for i in range(self.num_ft):
+             self.up_convs.append(
+                 DeconvBlock(num_filter // (2 ** i), num_filter // (2 ** (i + 1)), kernel_size, stride, padding,
+                             bias=bias, bilinear=bilinear)
+             )
+             self.down_convs.append(
+                 ConvBlock(num_filter // (2 ** (i + 1)), num_filter // (2 ** i), kernel_size, stride, padding,
+                           bias=bias, maxpool=maxpool)
+             )
+
+     def forward(self, ft_l, ft_h_list):
+         # Iteratively correct the low-resolution feature against each higher-resolution feature.
+         ft_fusion = ft_l
+         for i in range(len(ft_h_list)):
+             ft = ft_fusion
+             for j in range(self.num_ft - i):
+                 ft = self.up_convs[j](ft)
+             ft = ft - ft_h_list[i]
+             for j in range(self.num_ft - i):
+                 ft = self.down_convs[self.num_ft - i - j - 1](ft)
+             ft_fusion = ft_fusion + ft
+
+         return ft_fusion
+
+
+ class ConvLayer(nn.Module):
+     def __init__(self, in_channels, out_channels, kernel_size, stride, bias=True):
+         super(ConvLayer, self).__init__()
+         reflection_padding = kernel_size // 2
+         self.reflection_pad = nn.ReflectionPad2d(reflection_padding)
+         self.conv2d = nn.Conv2d(in_channels, out_channels, kernel_size, stride, bias=bias)
+
+     def forward(self, x):
+         out = self.reflection_pad(x)
+         out = self.conv2d(out)
+         return out
+
+
+ class UpsampleConvLayer(torch.nn.Module):
+     def __init__(self, in_channels, out_channels, kernel_size, stride):
+         super(UpsampleConvLayer, self).__init__()
+         self.conv2d = nn.ConvTranspose2d(in_channels, out_channels, kernel_size, stride=stride)
+
+     def forward(self, x):
+         out = self.conv2d(x)
+         return out
+
+
+ class AddRelu(nn.Module):
+     """Sums two skip tensors with the decoder output in the expanding path, then applies PReLU.
+     """
+
+     def __init__(self) -> None:
+         super(AddRelu, self).__init__()
+         self.relu = nn.PReLU()
+
+     def forward(self, input_tensor1, input_tensor2, input_tensor3):
+         x = input_tensor1 + input_tensor2 + input_tensor3
+         return self.relu(x)
+
+
+ class BasicBlock(nn.Module):
+     def __init__(self, in_channels, out_channels, mid_channels=None):
+         super(BasicBlock, self).__init__()
+         if not mid_channels:
+             mid_channels = out_channels
+         self.conv1 = ConvLayer(in_channels, mid_channels, kernel_size=3, stride=1)
+         self.bn1 = nn.BatchNorm2d(mid_channels, momentum=0.1)
+         self.relu = nn.PReLU()
+
+         self.conv2 = ConvLayer(mid_channels, out_channels, kernel_size=3, stride=1)
+         self.bn2 = nn.BatchNorm2d(out_channels, momentum=0.1)
+
+         self.conv3 = ConvLayer(in_channels, out_channels, kernel_size=1, stride=1)
+
+     def forward(self, x):
+         out = self.conv1(x)
+         out = self.bn1(out)
+         out = self.relu(out)
+
+         out = self.conv2(out)
+         out = self.bn2(out)
+
+         residual = self.conv3(x)
+
+         out = out + residual
+         out = self.relu(out)
+
+         return out
+
+
+ class Bottleneck(nn.Module):
+     def __init__(self, in_channels, out_channels):
+         super(Bottleneck, self).__init__()
+         self.conv1 = ConvLayer(in_channels, out_channels, kernel_size=3, stride=1)
+         self.bn1 = nn.BatchNorm2d(out_channels, momentum=0.1)
+
+         self.conv2 = ConvLayer(out_channels, out_channels, kernel_size=3, stride=1)
+         self.bn2 = nn.BatchNorm2d(out_channels, momentum=0.1)
+
+         self.conv3 = ConvLayer(out_channels, out_channels, kernel_size=3, stride=1)
+         self.bn3 = nn.BatchNorm2d(out_channels, momentum=0.1)
+
+         self.conv4 = ConvLayer(in_channels, out_channels, kernel_size=1, stride=1)
+
+         self.relu = nn.PReLU()
+
+     def forward(self, x):
+         out = self.conv1(x)
+         out = self.bn1(out)
+         out = self.relu(out)
+
+         out = self.conv2(out)
+         out = self.bn2(out)
+         out = self.relu(out)
+
+         out = self.conv3(out)
+         out = self.bn3(out)
+
+         residual = self.conv4(x)
+
+         out = out + residual
+         out = self.relu(out)
+
+         return out
+
+
+ class PPM(nn.Module):
+     def __init__(self, in_channels, out_channels):
+         super(PPM, self).__init__()
+
+         self.pool_sizes = [1, 2, 3, 6]  # subregion size in each level
+         self.num_levels = len(self.pool_sizes)  # number of pyramid levels
+
+         self.conv_layers = nn.ModuleList()
+         for i in range(self.num_levels):
+             self.conv_layers.append(nn.Sequential(
+                 nn.AdaptiveAvgPool2d(output_size=self.pool_sizes[i]),
+                 nn.Conv2d(in_channels, in_channels // self.num_levels, kernel_size=1),
+                 nn.BatchNorm2d(in_channels // self.num_levels),
+                 nn.ReLU(inplace=True)
+             ))
+         self.out_conv = nn.Conv2d(in_channels * 2, out_channels, kernel_size=1, stride=1)
+
+     def forward(self, x):
+         input_size = x.size()[2:]  # get input size
+         output = [x]
+
+         # pyramid pooling
+         for i in range(self.num_levels):
+             out = self.conv_layers[i](x)
+             out = F.interpolate(out, size=input_size, mode='bilinear', align_corners=True)
+             output.append(out)
+
+         # concatenate features from different levels
+         output = torch.cat(output, dim=1)
+         output = self.out_conv(output)
+
+         return output
+
+
+ class MCDNet(nn.Module):
+     def __init__(self, in_channels=4, num_classes=4, maxpool=False, bilinear=False) -> None:
+         super(MCDNet, self).__init__()
+         level = 1
+         # encoder
+         self.conv_input = ConvLayer(in_channels, 32 * level, kernel_size=3, stride=2)
+
+         self.dense0 = BasicBlock(32 * level, 32 * level)
+         self.conv2x = ConvLayer(32 * level, 64 * level, kernel_size=3, stride=2)
+
+         self.dense1 = BasicBlock(64 * level, 64 * level)
+         self.conv4x = ConvLayer(64 * level, 128 * level, kernel_size=3, stride=2)
+
+         self.dense2 = BasicBlock(128 * level, 128 * level)
+         self.conv8x = ConvLayer(128 * level, 256 * level, kernel_size=3, stride=2)
+
+         self.dense3 = BasicBlock(256 * level, 256 * level)
+         self.conv16x = ConvLayer(256 * level, 512 * level, kernel_size=3, stride=2)
+
+         self.dense4 = PPM(512 * level, 512 * level)
+
+         # dpff
+         self.dpffm = DPFF([32, 64, 128, 256, 512])
+
+         # decoder
+         self.convd16x = UpsampleConvLayer(512 * level, 256 * level, kernel_size=3, stride=2)
+         self.fusion4 = FusionBlock(256 * level, 3, maxpool=maxpool, bilinear=bilinear)
+         self.dense_4 = Bottleneck(512 * level, 256 * level)
+         self.add_block4 = AddRelu()
+
+         self.convd8x = UpsampleConvLayer(256 * level, 128 * level, kernel_size=3, stride=2)
+         self.fusion3 = FusionBlock(128 * level, 2, maxpool=maxpool, bilinear=bilinear)
+         self.dense_3 = Bottleneck(256 * level, 128 * level)
+         self.add_block3 = AddRelu()
+
+         self.convd4x = UpsampleConvLayer(128 * level, 64 * level, kernel_size=3, stride=2)
+         self.fusion2 = FusionBlock(64 * level, 1, maxpool=maxpool, bilinear=bilinear)
+         self.dense_2 = Bottleneck(128 * level, 64 * level)
+         self.add_block2 = AddRelu()
+
+         self.convd2x = UpsampleConvLayer(64 * level, 32 * level, kernel_size=3, stride=2)
+         self.dense_1 = Bottleneck(64 * level, 32 * level)
+         self.add_block1 = AddRelu()
+
+         self.head = UpsampleConvLayer(32 * level, num_classes, kernel_size=3, stride=2)
+         self.apply(self._weights_init)
+
+     def _weights_init(self, m):
+         if isinstance(m, nn.Linear):
+             nn.init.xavier_normal_(m.weight)
+             nn.init.constant_(m.bias, 0)
+         elif isinstance(m, nn.Conv2d):
+             nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
+         elif isinstance(m, nn.BatchNorm2d):
+             nn.init.constant_(m.weight, 1)
+             nn.init.constant_(m.bias, 0)
+
+     def get_lr_data(self, x: torch.Tensor) -> torch.Tensor:
+         # Build the dehazed companion view on the CPU, then min-max normalize it to [0, 1].
+         images = x.cpu().permute(0, 2, 3, 1).numpy()
+         batch_size = images.shape[0]
+         lr = []
+         for i in range(batch_size):
+             lr_image = cv2.cvtColor(images[i], cv2.COLOR_RGB2BGR)
+             # image_dehazer.remove_haze returns (dehazed image, transmission map)
+             lr_image = image_dehazer.remove_haze(lr_image, showHazeTransmissionMap=False)[0]
+             lr_image = cv2.cvtColor(lr_image, cv2.COLOR_BGR2RGB)
+             max_pix = np.max(lr_image)
+             min_pix = np.min(lr_image)
+             lr_image = (lr_image - min_pix) / (max_pix - min_pix)
+             lr_image = np.clip(lr_image, 0, 1)
+             lr_tensor = torch.from_numpy(lr_image).permute(2, 0, 1).float()
+             lr.append(lr_tensor)
+         return torch.stack(lr, dim=0).to(x.device)
+
+     def forward(self, x1):
+         x2 = self.get_lr_data(x1)
+         # encoder1
+         res1x_1 = self.conv_input(x1)
+         res1x_1 = self.dense0(res1x_1)
+
+         res2x_1 = self.conv2x(res1x_1)
+         res2x_1 = self.dense1(res2x_1)
+
+         res4x_1 = self.conv4x(res2x_1)
+         res4x_1 = self.dense2(res4x_1)
+
+         res8x_1 = self.conv8x(res4x_1)
+         res8x_1 = self.dense3(res8x_1)
+
+         res16x_1 = self.conv16x(res8x_1)
+         res16x_1 = self.dense4(res16x_1)
+
+         # encoder2
+         res1x_2 = self.conv_input(x2)
+         res1x_2 = self.dense0(res1x_2)
+
+         res2x_2 = self.conv2x(res1x_2)
+         res2x_2 = self.dense1(res2x_2)
+
+         res4x_2 = self.conv4x(res2x_2)
+         res4x_2 = self.dense2(res4x_2)
+
+         res8x_2 = self.conv8x(res4x_2)
+         res8x_2 = self.dense3(res8x_2)
+
+         res16x_2 = self.conv16x(res8x_2)
+         res16x_2 = self.dense4(res16x_2)
+
+         # dual-perspective feature fusion
+         res1x, res2x, res4x, res8x, res16x = self.dpffm(
+             [res1x_1, res2x_1, res4x_1, res8x_1, res16x_1],
+             [res1x_2, res2x_2, res4x_2, res8x_2, res16x_2]
+         )
+
+         # decoder
+         res8x1 = self.convd16x(res16x)
+         res8x1 = F.interpolate(res8x1, res8x.size()[2:], mode='bilinear')
+         res8x2 = self.fusion4(res8x, [res1x, res2x, res4x])
+         res8x2 = torch.cat([res8x1, res8x2], dim=1)
+         res8x2 = self.dense_4(res8x2)
+         res8x2 = self.add_block4(res8x1, res8x, res8x2)
+
+         res4x1 = self.convd8x(res8x2)
+         res4x1 = F.interpolate(res4x1, res4x.size()[2:], mode='bilinear')
+         res4x2 = self.fusion3(res4x, [res1x, res2x])
+         res4x2 = torch.cat([res4x1, res4x2], dim=1)
+         res4x2 = self.dense_3(res4x2)
+         res4x2 = self.add_block3(res4x1, res4x, res4x2)
+
+         res2x1 = self.convd4x(res4x2)
+         res2x1 = F.interpolate(res2x1, res2x.size()[2:], mode='bilinear')
+         res2x2 = self.fusion2(res2x, [res1x])
+         res2x2 = torch.cat([res2x1, res2x2], dim=1)
+         res2x2 = self.dense_2(res2x2)
+         res2x2 = self.add_block2(res2x1, res2x, res2x2)
+
+         res1x1 = self.convd2x(res2x2)
+         res1x1 = F.interpolate(res1x1, res1x.size()[2:], mode='bilinear')
+         res1x2 = torch.cat([res1x1, res1x], dim=1)
+         res1x2 = self.dense_1(res1x2)
+         res1x2 = self.add_block1(res1x1, res1x, res1x2)
+
+         out = self.head(res1x2)
+         out = F.interpolate(out, x1.size()[2:], mode='bilinear')
+
+         return out
+
+
+ def lr_lambda(epoch):
+     # Polynomial decay factor applied to the base learning rate (see the note below).
+     return (1 - epoch / 50) ** 0.9
+
+
+ if __name__ == "__main__":
+     device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+     # device = 'cpu'
+     model = MCDNet(in_channels=3, num_classes=7).to(device)
+     fake_img = torch.randn(size=(2, 3, 256, 256)).to(device)
+     out = model(fake_img).detach().cpu()
+     print(out.shape)
+     # torch.Size([2, 7, 256, 256])
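Note (not part of the commit): a minimal sketch of wiring the polynomial decay defined by lr_lambda above into training via torch.optim.lr_scheduler.LambdaLR, which multiplies the base learning rate by the returned factor each epoch. The Adam optimizer and the 50-epoch horizon are assumptions.

    import torch

    model = MCDNet(in_channels=3, num_classes=2)
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)  # optimizer choice is assumed
    scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lr_lambda)

    for epoch in range(50):  # lr_lambda assumes a 50-epoch schedule
        ...  # one training epoch
        scheduler.step()  # lr becomes base_lr * (1 - epoch / 50) ** 0.9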
src/models/components/scnn.py ADDED
@@ -0,0 +1,36 @@
+ # -*- coding: utf-8 -*-
+ # @Time : 2024/7/21 5:11 PM
+ # @Author : xiaoshun
+ # @Email : 3038523973@qq.com
+ # @File : scnn.py
+ # @Software: PyCharm
+
+ # Paper: https://www.sciencedirect.com/science/article/abs/pii/S0924271624000352?via%3Dihub#fn1
+
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+
+
+ class SCNNNet(nn.Module):
+     def __init__(self, in_channels=3, num_classes=2, dropout_p=0.5):
+         super().__init__()
+         self.conv1 = nn.Conv2d(in_channels, 64, kernel_size=1)
+         self.conv2 = nn.Conv2d(64, num_classes, kernel_size=1)
+         self.conv3 = nn.Conv2d(num_classes, num_classes, kernel_size=3, padding=1)
+         self.dropout = nn.Dropout2d(p=dropout_p)
+
+     def forward(self, x):
+         x = F.relu(self.conv1(x))
+         x = self.dropout(x)
+         x = self.conv2(x)
+         x = self.conv3(x)
+         return x
+
+
+ if __name__ == '__main__':
+     model = SCNNNet(num_classes=7)
+     fake_img = torch.randn((2, 3, 224, 224))
+     out = model(fake_img)
+     print(out.shape)
+     # torch.Size([2, 7, 224, 224])
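Note (not part of the commit): SCNNNet returns raw per-class logits, so a hard mask comes from an argmax over the channel dimension. A hedged inference sketch, assuming class index 1 denotes cloud:

    import torch

    model = SCNNNet(in_channels=3, num_classes=2)
    model.eval()  # disables the Dropout2d layer
    with torch.no_grad():
        logits = model(torch.randn(1, 3, 224, 224))  # (1, 2, 224, 224)
        mask = logits.argmax(dim=1)  # (1, 224, 224); 1 where the cloud class scores highest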
src/models/components/unet.py ADDED
@@ -0,0 +1,63 @@
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+
+
+ class UNet(nn.Module):
+     def __init__(self, in_channels, out_channels):
+         super(UNet, self).__init__()
+
+         def conv_block(in_channels, out_channels):
+             return nn.Sequential(
+                 nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
+                 nn.ReLU(inplace=True),
+                 nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
+                 nn.ReLU(inplace=True)
+             )
+
+         self.encoder1 = conv_block(in_channels, 64)
+         self.encoder2 = conv_block(64, 128)
+         self.encoder3 = conv_block(128, 256)
+         self.encoder4 = conv_block(256, 512)
+         self.bottleneck = conv_block(512, 1024)
+
+         self.upconv4 = nn.ConvTranspose2d(1024, 512, kernel_size=2, stride=2)
+         self.decoder4 = conv_block(1024, 512)
+         self.upconv3 = nn.ConvTranspose2d(512, 256, kernel_size=2, stride=2)
+         self.decoder3 = conv_block(512, 256)
+         self.upconv2 = nn.ConvTranspose2d(256, 128, kernel_size=2, stride=2)
+         self.decoder2 = conv_block(256, 128)
+         self.upconv1 = nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2)
+         self.decoder1 = conv_block(128, 64)
+
+         self.final = nn.Conv2d(64, out_channels, kernel_size=1)
+
+     def forward(self, x):
+         enc1 = self.encoder1(x)
+         enc2 = self.encoder2(F.max_pool2d(enc1, kernel_size=2, stride=2))
+         enc3 = self.encoder3(F.max_pool2d(enc2, kernel_size=2, stride=2))
+         enc4 = self.encoder4(F.max_pool2d(enc3, kernel_size=2, stride=2))
+         bottleneck = self.bottleneck(F.max_pool2d(enc4, kernel_size=2, stride=2))
+
+         dec4 = self.upconv4(bottleneck)
+         dec4 = torch.cat((dec4, enc4), dim=1)
+         dec4 = self.decoder4(dec4)
+         dec3 = self.upconv3(dec4)
+         dec3 = torch.cat((dec3, enc3), dim=1)
+         dec3 = self.decoder3(dec3)
+         dec2 = self.upconv2(dec3)
+         dec2 = torch.cat((dec2, enc2), dim=1)
+         dec2 = self.decoder2(dec2)
+         dec1 = self.upconv1(dec2)
+         dec1 = torch.cat((dec1, enc1), dim=1)
+         dec1 = self.decoder1(dec1)
+
+         return self.final(dec1)
+
+
+ if __name__ == "__main__":
+     model = UNet(in_channels=3, out_channels=7)
+     fake_img = torch.rand(size=(2, 3, 224, 224))
+     print(fake_img.shape)
+     # torch.Size([2, 3, 224, 224])
+     out = model(fake_img)
+     print(out.shape)
+     # torch.Size([2, 7, 224, 224])
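Note (not part of the commit): the four max-pool stages halve the resolution four times, so the skip concatenations only line up when height and width are multiples of 16 (224 = 16 * 14 works; 220 does not, since floor division desynchronizes the up path). A small illustrative guard:

    import torch

    model = UNet(in_channels=3, out_channels=7)
    x = torch.rand(2, 3, 224, 224)
    # Fail fast on shapes the four down/up stages cannot reassemble.
    assert x.shape[-2] % 16 == 0 and x.shape[-1] % 16 == 0, "H and W must be multiples of 16"
    out = model(x)  # torch.Size([2, 7, 224, 224])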
src/models/components/vae.py ADDED
@@ -0,0 +1,152 @@
+ from typing import List
+
+ import matplotlib.pyplot as plt
+ import torch
+ import torch.nn as nn
+
+ from src.plugin.ldm.modules.diffusionmodules.model import Encoder, Decoder
+ from src.plugin.ldm.modules.distributions.distributions import DiagonalGaussianDistribution
+
+
+ class AutoencoderKL(nn.Module):
+     def __init__(
+         self,
+         double_z: bool = True,
+         z_channels: int = 3,
+         resolution: int = 512,
+         in_channels: int = 3,
+         out_ch: int = 3,
+         ch: int = 128,
+         ch_mult: List = [1, 2, 4, 4],
+         num_res_blocks: int = 2,
+         attn_resolutions: List = [],
+         dropout: float = 0.0,
+         embed_dim: int = 3,
+         ckpt_path: str = None,
+         ignore_keys: List = [],
+     ):
+         super(AutoencoderKL, self).__init__()
+         ddconfig = {
+             "double_z": double_z,
+             "z_channels": z_channels,
+             "resolution": resolution,
+             "in_channels": in_channels,
+             "out_ch": out_ch,
+             "ch": ch,
+             "ch_mult": ch_mult,
+             "num_res_blocks": num_res_blocks,
+             "attn_resolutions": attn_resolutions,
+             "dropout": dropout
+         }
+         self.encoder = Encoder(**ddconfig)
+         self.decoder = Decoder(**ddconfig)
+         assert ddconfig["double_z"]
+         self.quant_conv = nn.Conv2d(
+             2 * ddconfig["z_channels"], 2 * embed_dim, 1)
+         self.post_quant_conv = nn.Conv2d(embed_dim, ddconfig["z_channels"], 1)
+         self.embed_dim = embed_dim
+         if ckpt_path is not None:
+             self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
+
+     def init_from_ckpt(self, path, ignore_keys=list()):
+         sd = torch.load(path, map_location="cpu")["state_dict"]
+         keys = list(sd.keys())
+         for k in keys:
+             for ik in ignore_keys:
+                 if k.startswith(ik):
+                     print(f"Deleting key {k} from state_dict.")
+                     del sd[k]
+         self.load_state_dict(sd, strict=False)
+         print(f"Restored from {path}")
+
+     def encode(self, x):
+         h = self.encoder(x)  # B, C, h, w
+         moments = self.quant_conv(h)  # B, 6, h, w
+         posterior = DiagonalGaussianDistribution(moments)
+         return posterior  # the latent distribution
+
+     def decode(self, z):
+         z = self.post_quant_conv(z)
+         dec = self.decoder(z)
+         return dec
+
+     def forward(self, input, sample_posterior=True):
+         posterior = self.encode(input)  # diagonal Gaussian posterior
+         if sample_posterior:
+             z = posterior.sample()  # draw a latent sample
+         else:
+             z = posterior.mode()
+         dec = self.decode(z)
+         last_layer_weight = self.decoder.conv_out.weight
+         return dec, posterior, last_layer_weight
+
+
+ if __name__ == '__main__':
+     # Test the input and output shapes of the model
+     model = AutoencoderKL()
+     x = torch.randn(1, 3, 512, 512)
+     dec, posterior, last_layer_weight = model(x)
+
+     assert dec.shape == (1, 3, 512, 512)
+     assert posterior.sample().shape == posterior.mode().shape == (1, 3, 64, 64)
+     assert last_layer_weight.shape == (3, 128, 3, 3)
+
+     # Plot the latent space and the reconstruction from the pretrained model
+     model = AutoencoderKL(ckpt_path="/mnt/chongqinggeminiceph1fs/geminicephfs/wx-mm-spr-xxxx/zouxuechao/Collaborative-Diffusion/outputs/512_vae/2024-06-27T06-02-04_512_vae/checkpoints/epoch=000036.ckpt")
+     model.eval()
+     image_path = "data/celeba/image/image_512_downsampled_from_hq_1024/0.jpg"
+
+     from PIL import Image
+     import numpy as np
+     from src.data.components.celeba import DalleTransformerPreprocessor
+     image = Image.open(image_path).convert('RGB')
+     image = np.array(image).astype(np.uint8)
+     import copy
+     original = copy.deepcopy(image)
+     transform = DalleTransformerPreprocessor(size=512, phase='test')
+     image = transform(image=image)['image']
+     image = image.astype(np.float32) / 127.5 - 1.0
+     image = torch.from_numpy(image).permute(2, 0, 1).unsqueeze(0)
+
+     dec, posterior, last_layer_weight = model(image)
+
+     # original image
+     plt.subplot(1, 3, 1)
+     plt.imshow(original)
+     plt.title("Original")
+     plt.axis("off")
+
+     # sampled image from the latent space
+     plt.subplot(1, 3, 2)
+     x = model.decode(posterior.sample())
+     x = (x + 1) / 2
+     x = x.squeeze(0).permute(1, 2, 0).cpu()
+     x = x.detach().numpy()
+     x = x.clip(0, 1)
+     x = (x * 255).astype(np.uint8)
+     plt.imshow(x)
+     plt.title("Sampled")
+     plt.axis("off")
+
+     # reconstructed image
+     plt.subplot(1, 3, 3)
+     x = dec
+     x = (x + 1) / 2
+     x = x.squeeze(0).permute(1, 2, 0).cpu()
+     x = x.detach().numpy()
+     x = x.clip(0, 1)
+     x = (x * 255).astype(np.uint8)
+     plt.imshow(x)
+     plt.title("Reconstructed")
+     plt.axis("off")
+
+     plt.tight_layout()
+     plt.savefig("vae_reconstruction.png")
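Note (not part of the commit): a hedged sketch of one training step for the autoencoder above, assuming the vendored DiagonalGaussianDistribution exposes the .kl() method it has in upstream latent-diffusion; the L1 reconstruction term and the 1e-6 KL weight are illustrative assumptions, not the project's actual loss.

    import torch
    import torch.nn.functional as F

    model = AutoencoderKL()
    image = torch.randn(1, 3, 512, 512)
    dec, posterior, _ = model(image)
    rec_loss = F.l1_loss(dec, image)
    kl_loss = posterior.kl().mean()  # KL(q(z|x) || N(0, I)); assumed API
    loss = rec_loss + 1e-6 * kl_loss
    loss.backward()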