vobecant committed
Commit dd78229 · 1 Parent(s): 83a95c0
Initial commit
Browse files
- .idea/DaS.iml +8 -0
- .idea/inspectionProfiles/Project_Default.xml +26 -0
- .idea/inspectionProfiles/profiles_settings.xml +6 -0
- .idea/misc.xml +4 -0
- .idea/modules.xml +8 -0
- .idea/vcs.xml +6 -0
- .idea/workspace.xml +133 -0
- README.md +0 -11
- app.py +151 -0
- examples/img1.jpg +0 -0
- requirements.txt +6 -0
- segmenter_model/backbone_picie.py +348 -0
- segmenter_model/blocks.py +129 -0
- segmenter_model/decoder.py +214 -0
- segmenter_model/factory.py +165 -0
- segmenter_model/fpn_picie.py +66 -0
- segmenter_model/picie_model.py +82 -0
- segmenter_model/resnet_dilated.py +55 -0
- segmenter_model/segmenter.py +86 -0
- segmenter_model/torch.py +38 -0
- segmenter_model/utils.py +582 -0
- segmenter_model/vit_dino.py +348 -0
.idea/DaS.iml
ADDED
@@ -0,0 +1,8 @@
+<?xml version="1.0" encoding="UTF-8"?>
+<module type="PYTHON_MODULE" version="4">
+  <component name="NewModuleRootManager">
+    <content url="file://$MODULE_DIR$" />
+    <orderEntry type="jdk" jdkName="Python 3.8 (pytorch) (2)" jdkType="Python SDK" />
+    <orderEntry type="sourceFolder" forTests="false" />
+  </component>
+</module>
.idea/inspectionProfiles/Project_Default.xml
ADDED
@@ -0,0 +1,26 @@
+<component name="InspectionProjectProfileManager">
+  <profile version="1.0">
+    <option name="myName" value="Project Default" />
+    <inspection_tool class="PyPackageRequirementsInspection" enabled="true" level="WARNING" enabled_by_default="true">
+      <option name="ignoredPackages">
+        <value>
+          <list size="13">
+            <item index="0" class="java.lang.String" itemvalue="yacs" />
+            <item index="1" class="java.lang.String" itemvalue="termcolor" />
+            <item index="2" class="java.lang.String" itemvalue="pydot" />
+            <item index="3" class="java.lang.String" itemvalue="fvcore" />
+            <item index="4" class="java.lang.String" itemvalue="tabulate" />
+            <item index="5" class="java.lang.String" itemvalue="mock" />
+            <item index="6" class="java.lang.String" itemvalue="pycocotools" />
+            <item index="7" class="java.lang.String" itemvalue="prettytable" />
+            <item index="8" class="java.lang.String" itemvalue="interrogate" />
+            <item index="9" class="java.lang.String" itemvalue="cityscapesscripts" />
+            <item index="10" class="java.lang.String" itemvalue="isort" />
+            <item index="11" class="java.lang.String" itemvalue="xdoctest" />
+            <item index="12" class="java.lang.String" itemvalue="codecov" />
+          </list>
+        </value>
+      </option>
+    </inspection_tool>
+  </profile>
+</component>
.idea/inspectionProfiles/profiles_settings.xml
ADDED
@@ -0,0 +1,6 @@
+<component name="InspectionProjectProfileManager">
+  <settings>
+    <option name="USE_PROJECT_PROFILE" value="false" />
+    <version value="1.0" />
+  </settings>
+</component>
.idea/misc.xml
ADDED
@@ -0,0 +1,4 @@
+<?xml version="1.0" encoding="UTF-8"?>
+<project version="4">
+  <component name="ProjectRootManager" version="2" project-jdk-name="Python 3.8 (pytorch) (2)" project-jdk-type="Python SDK" />
+</project>
.idea/modules.xml
ADDED
@@ -0,0 +1,8 @@
+<?xml version="1.0" encoding="UTF-8"?>
+<project version="4">
+  <component name="ProjectModuleManager">
+    <modules>
+      <module fileurl="file://$PROJECT_DIR$/.idea/DaS.iml" filepath="$PROJECT_DIR$/.idea/DaS.iml" />
+    </modules>
+  </component>
+</project>
.idea/vcs.xml
ADDED
@@ -0,0 +1,6 @@
+<?xml version="1.0" encoding="UTF-8"?>
+<project version="4">
+  <component name="VcsDirectoryMappings">
+    <mapping directory="$PROJECT_DIR$" vcs="Git" />
+  </component>
+</project>
.idea/workspace.xml
ADDED
@@ -0,0 +1,133 @@
+<?xml version="1.0" encoding="UTF-8"?>
+<project version="4">
+  <component name="ChangeListManager">
+    <list default="true" id="5dd22f22-8223-4d55-99f9-57d1e00622d7" name="Default Changelist" comment="Initial commit.">
+      <change afterPath="$PROJECT_DIR$/app.py" afterDir="false" />
+      <change afterPath="$PROJECT_DIR$/examples/img1.jpg" afterDir="false" />
+      <change afterPath="$PROJECT_DIR$/requirements.txt" afterDir="false" />
+      <change afterPath="$PROJECT_DIR$/segmenter_model/backbone_picie.py" afterDir="false" />
+      <change afterPath="$PROJECT_DIR$/segmenter_model/blocks.py" afterDir="false" />
+      <change afterPath="$PROJECT_DIR$/segmenter_model/decoder.py" afterDir="false" />
+      <change afterPath="$PROJECT_DIR$/segmenter_model/factory.py" afterDir="false" />
+      <change afterPath="$PROJECT_DIR$/segmenter_model/fpn_picie.py" afterDir="false" />
+      <change afterPath="$PROJECT_DIR$/segmenter_model/picie_model.py" afterDir="false" />
+      <change afterPath="$PROJECT_DIR$/segmenter_model/resnet_dilated.py" afterDir="false" />
+      <change afterPath="$PROJECT_DIR$/segmenter_model/segmenter.py" afterDir="false" />
+      <change afterPath="$PROJECT_DIR$/segmenter_model/torch.py" afterDir="false" />
+      <change afterPath="$PROJECT_DIR$/segmenter_model/utils.py" afterDir="false" />
+      <change afterPath="$PROJECT_DIR$/segmenter_model/vit_dino.py" afterDir="false" />
+      <change beforePath="$PROJECT_DIR$/README.md" beforeDir="false" afterPath="$PROJECT_DIR$/README.md" afterDir="false" />
+    </list>
+    <option name="SHOW_DIALOG" value="false" />
+    <option name="HIGHLIGHT_CONFLICTS" value="true" />
+    <option name="HIGHLIGHT_NON_ACTIVE_CHANGELIST" value="false" />
+    <option name="LAST_RESOLUTION" value="IGNORE" />
+  </component>
+  <component name="FileTemplateManagerImpl">
+    <option name="RECENT_TEMPLATES">
+      <list>
+        <option value="Python Script" />
+      </list>
+    </option>
+  </component>
+  <component name="Git.Settings">
+    <option name="RECENT_GIT_ROOT_PATH" value="$PROJECT_DIR$" />
+    <option name="UPDATE_TYPE" value="REBASE" />
+  </component>
+  <component name="ProjectId" id="26QLDSf8iYKDlLRah6kIg09oqIa" />
+  <component name="ProjectLevelVcsManager" settingsEditedManually="true" />
+  <component name="ProjectViewState">
+    <option name="hideEmptyMiddlePackages" value="true" />
+    <option name="showLibraryContents" value="true" />
+  </component>
+  <component name="PropertiesComponent">
+    <property name="RunOnceActivity.OpenProjectViewOnStart" value="true" />
+    <property name="RunOnceActivity.ShowReadmeOnStart" value="true" />
+    <property name="WebServerToolWindowFactoryState" value="true" />
+    <property name="last_opened_file_path" value="$PROJECT_DIR$" />
+    <property name="settings.editor.selected.configurable" value="com.jetbrains.python.configuration.PyActiveSdkModuleConfigurable" />
+  </component>
+  <component name="RecentsManager">
+    <key name="CopyFile.RECENT_KEYS">
+      <recent name="$PROJECT_DIR$" />
+      <recent name="$PROJECT_DIR$/examples" />
+    </key>
+    <key name="MoveFile.RECENT_KEYS">
+      <recent name="$PROJECT_DIR$/examples" />
+    </key>
+  </component>
+  <component name="SpellCheckerSettings" RuntimeDictionaries="0" Folders="0" CustomDictionaries="0" DefaultDictionary="application-level" UseSingleDictionary="true" transferred="true" />
+  <component name="TaskManager">
+    <task active="true" id="Default" summary="Default task">
+      <changelist id="5dd22f22-8223-4d55-99f9-57d1e00622d7" name="Default Changelist" comment="" />
+      <created>1647350746642</created>
+      <option name="number" value="Default" />
+      <option name="presentableId" value="Default" />
+      <updated>1647350746642</updated>
+      <workItem from="1647350750956" duration="4327000" />
+    </task>
+    <task id="LOCAL-00001" summary="Initial commit.">
+      <created>1647352693910</created>
+      <option name="number" value="00001" />
+      <option name="presentableId" value="LOCAL-00001" />
+      <option name="project" value="LOCAL" />
+      <updated>1647352693910</updated>
+    </task>
+    <task id="LOCAL-00002" summary="Initial commit.">
+      <created>1647353059401</created>
+      <option name="number" value="00002" />
+      <option name="presentableId" value="LOCAL-00002" />
+      <option name="project" value="LOCAL" />
+      <updated>1647353059401</updated>
+    </task>
+    <task id="LOCAL-00003" summary="Added gitignore.">
+      <created>1647353514970</created>
+      <option name="number" value="00003" />
+      <option name="presentableId" value="LOCAL-00003" />
+      <option name="project" value="LOCAL" />
+      <updated>1647353514970</updated>
+    </task>
+    <task id="LOCAL-00004" summary="Added gitignore.">
+      <created>1647353622389</created>
+      <option name="number" value="00004" />
+      <option name="presentableId" value="LOCAL-00004" />
+      <option name="project" value="LOCAL" />
+      <updated>1647353622389</updated>
+    </task>
+    <task id="LOCAL-00005" summary="Added gitignore.">
+      <created>1647353674966</created>
+      <option name="number" value="00005" />
+      <option name="presentableId" value="LOCAL-00005" />
+      <option name="project" value="LOCAL" />
+      <updated>1647353674966</updated>
+    </task>
+    <task id="LOCAL-00006" summary="Initial commit.">
+      <created>1647354226094</created>
+      <option name="number" value="00006" />
+      <option name="presentableId" value="LOCAL-00006" />
+      <option name="project" value="LOCAL" />
+      <updated>1647354226094</updated>
+    </task>
+    <option name="localTasksCounter" value="7" />
+    <servers />
+  </component>
+  <component name="TypeScriptGeneratedFilesManager">
+    <option name="version" value="3" />
+  </component>
+  <component name="Vcs.Log.Tabs.Properties">
+    <option name="TAB_STATES">
+      <map>
+        <entry key="MAIN">
+          <value>
+            <State />
+          </value>
+        </entry>
+      </map>
+    </option>
+  </component>
+  <component name="VcsManagerConfiguration">
+    <MESSAGE value="Added gitignore." />
+    <MESSAGE value="Initial commit." />
+    <option name="LAST_COMMIT_MESSAGE" value="Initial commit." />
+  </component>
+</project>
README.md
CHANGED
@@ -1,12 +1 @@
----
-title: DaS
-emoji: 💻
-colorFrom: pink
-colorTo: pink
-sdk: gradio
-sdk_version: 2.8.10
-app_file: app.py
-pinned: false
----
-
 Check out the configuration reference at https://huggingface.co/docs/hub/spaces#reference
app.py
ADDED
@@ -0,0 +1,151 @@
+import gradio as gr
+import numpy as np
+import requests
+import torch
+import yaml
+from PIL import Image
+from torchvision import transforms
+
+from segmenter_model import utils
+from segmenter_model.factory import create_segmenter
+from segmenter_model.fpn_picie import PanopticFPN
+from segmenter_model.utils import colorize_one, map2cs
+
+WEIGHTS = './weights/segmenter.pth'
+
+
+def download_file_from_google_drive(id, destination):
+    def get_confirm_token(response):
+        for key, value in response.cookies.items():
+            if key.startswith('download_warning'):
+                return value
+
+        return None
+
+    def save_response_content(response, destination):
+        CHUNK_SIZE = 32768
+
+        with open(destination, "wb") as f:
+            for chunk in response.iter_content(CHUNK_SIZE):
+                if chunk:  # filter out keep-alive new chunks
+                    f.write(chunk)
+
+    URL = "https://docs.google.com/uc?export=download"
+
+    session = requests.Session()
+
+    response = session.get(URL, params={'id': id}, stream=True)
+    token = get_confirm_token(response)
+
+    if token:
+        params = {'id': id, 'confirm': token}
+        response = session.get(URL, params=params, stream=True)
+
+    save_response_content(response, destination)
+
+
+def segment_segmenter(image, model, window_size, window_stride, encoder_features=False, decoder_features=False,
+                      no_upsample=False, batch_size=2):
+    seg_pred = utils.inference(
+        model,
+        image,
+        image.shape[-2:],
+        window_size,
+        window_stride,
+        batch_size=batch_size,
+        no_upsample=no_upsample,
+        encoder_features=encoder_features,
+        decoder_features=decoder_features
+    )
+    if not (encoder_features or decoder_features):
+        seg_pred = seg_pred.argmax(1).unsqueeze(1)
+    return seg_pred
+
+
+def remap(seg_pred, ignore=255):
+    mapping = {0: 0, 12: 1, 15: 2, 23: 3, 10: 4, 14: 5, 18: 6, 2: 7, 17: 8, 13: 9, 8: 10, 3: 11, 27: 12, 4: 13, 25: 14,
+               24: 15, 6: 16, 22: 17, 28: 18}
+    h, w = seg_pred.shape[-2:]
+    seg_pred_remap = np.ones((h, w), dtype=np.uint8) * ignore
+    for pseudo, gt in mapping.items():
+        whr = seg_pred == pseudo
+        seg_pred_remap[whr] = gt
+    return seg_pred_remap
+
+
+def create_model(resnet=False):
+    weights_path = WEIGHTS
+    variant_path = '{}_variant.yml'.format(weights_path)
+
+    print('Use weights {}'.format(weights_path))
+    print('Load variant from {}'.format(variant_path))
+    variant = yaml.load(
+        open(variant_path, "r"), Loader=yaml.FullLoader
+    )
+
+    # TODO: parse hyperparameters
+    window_size = variant['inference_kwargs']["window_size"]
+    window_stride = variant['inference_kwargs']["window_stride"]
+    dataset_kwargs = variant['dataset_kwargs']
+    net_kwargs = variant["net_kwargs"]
+    net_kwargs['n_cls'] = dataset_kwargs['nlabels']
+
+    dataset_kwargs = variant['dataset_kwargs']
+
+    net_kwargs = variant["net_kwargs"]
+    net_kwargs['n_cls'] = dataset_kwargs['nlabels']
+    if not resnet:
+        net_kwargs['decoder']['dropout'] = 0.
+
+    # TODO: create model
+    if resnet:
+        model = PanopticFPN(arch=net_kwargs['backbone'], pretrain=net_kwargs['pretrain'], n_cls=net_kwargs['n_cls'])
+    else:
+        model = create_segmenter(net_kwargs)
+
+    # TODO: load weights
+    print('Load weights from {}'.format(weights_path))
+    weights = torch.load(weights_path)['model']
+    model.load_state_dict(weights, strict=True)
+
+    model.eval()
+
+    return model, window_size, window_stride
+
+
+def get_transformations():
+    return transforms.Compose([
+        transforms.ToTensor(),
+        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])])
+
+
+model, window_size, window_stride = create_model()
+
+
+def predict(input_img):
+    input_img = Image.open(input_img)
+    transform = transforms.Compose([transforms.Resize(256, Image.BICUBIC), transforms.ToTensor()])
+    input_img = transform(input_img)
+    input_img = torch.unsqueeze(input_img, 0)
+
+    with torch.no_grad():
+        segmentation = segment_segmenter(input_img, model, window_size, window_stride).squeeze().detach()
+        segmentation_remap = remap(segmentation)
+
+    drawing_pseudo = colorize_one(segmentation_remap)
+    drawing_cs = map2cs(segmentation_remap)
+
+    drawing_pseudo = transforms.ToPILImage()(drawing_pseudo)
+    drawing_cs = transforms.ToPILImage()(drawing_cs)
+    return drawing_pseudo, drawing_cs
+
+
+title = "Drive&Segment"
+description = 'Gradio Demo accompanying paper "Drive&Segment: Unsupervised Semantic Segmentation of Urban Scenes via Cross-modal Distillation"'
+# article = "<p style='text-align: center'><a href='TODO' target='_blank'>Project Page</a> | <a href='codelink' target='_blank'>Github</a></p>"
+examples = [['examples/img1.jpg']]
+
+iface = gr.Interface(predict, gr.inputs.Image(type='filepath'), "image", title=title, description=description,
+                     examples=examples)
+
+iface.launch()
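Reviewer note (not part of the commit): app.py loads ./weights/segmenter.pth and ./weights/segmenter.pth_variant.yml at import time, but neither file is added by this commit and the download_file_from_google_drive helper defined above is never called, so create_model() will fail on a fresh Space. Below is a minimal sketch of one way to wire the existing helper in before the model is built; the Google Drive file ids are placeholders, not values taken from this commit.

# Hypothetical snippet: would go right before `model, window_size, window_stride = create_model()`.
import os

if not os.path.exists(WEIGHTS):
    # Create ./weights and fetch the checkpoint plus its variant file with the
    # helper already defined in app.py. The ids below are placeholders.
    os.makedirs(os.path.dirname(WEIGHTS), exist_ok=True)
    download_file_from_google_drive('<DRIVE_ID_OF_CHECKPOINT>', WEIGHTS)
    download_file_from_google_drive('<DRIVE_ID_OF_VARIANT_YML>', '{}_variant.yml'.format(WEIGHTS))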
examples/img1.jpg
ADDED
requirements.txt
ADDED
@@ -0,0 +1,6 @@
+torch
+torchvision
+PIL
+timm
+yaml
+einops
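Reviewer note (not part of the commit): PIL and yaml are import names rather than pip distribution names, so this file will not install as written; the PyPI packages are Pillow and PyYAML. A corrected sketch of the same six dependencies:

torch
torchvision
Pillow
timm
PyYAML
einops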
segmenter_model/backbone_picie.py
ADDED
@@ -0,0 +1,348 @@
+import torch.nn as nn
+
+try:
+    from torchvision.models.utils import load_state_dict_from_url
+except:
+    from torch.utils.model_zoo import load_url as load_state_dict_from_url
+
+__all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101',
+           'resnet152', 'resnext50_32x4d', 'resnext101_32x8d',
+           'wide_resnet50_2', 'wide_resnet101_2']
+
+model_urls = {
+    'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
+    'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth',
+    'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
+    'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',
+    'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth',
+    'resnext50_32x4d': 'https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth',
+    'resnext101_32x8d': 'https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth',
+    'wide_resnet50_2': 'https://download.pytorch.org/models/wide_resnet50_2-95faca4d.pth',
+    'wide_resnet101_2': 'https://download.pytorch.org/models/wide_resnet101_2-32ee1156.pth',
+}
+
+
+def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
+    """3x3 convolution with padding"""
+    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
+                     padding=dilation, groups=groups, bias=False, dilation=dilation)
+
+
+def conv1x1(in_planes, out_planes, stride=1):
+    """1x1 convolution"""
+    return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
+
+
+class BasicBlock(nn.Module):
+    expansion = 1
+
+    def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
+                 base_width=64, dilation=1, norm_layer=None):
+        super(BasicBlock, self).__init__()
+        if norm_layer is None:
+            norm_layer = nn.BatchNorm2d
+        if groups != 1 or base_width != 64:
+            raise ValueError('BasicBlock only supports groups=1 and base_width=64')
+        if dilation > 1:
+            raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
+        # Both self.conv1 and self.downsample layers downsample the input when stride != 1
+        self.conv1 = conv3x3(inplanes, planes, stride)
+        self.bn1 = norm_layer(planes)
+        self.relu = nn.ReLU(inplace=True)
+        self.conv2 = conv3x3(planes, planes)
+        self.bn2 = norm_layer(planes)
+        self.downsample = downsample
+        self.stride = stride
+
+    def forward(self, x):
+        identity = x
+
+        out = self.conv1(x)
+        out = self.bn1(out)
+        out = self.relu(out)
+
+        out = self.conv2(out)
+        out = self.bn2(out)
+
+        if self.downsample is not None:
+            identity = self.downsample(x)
+
+        out += identity
+        out = self.relu(out)
+
+        return out
+
+
+class Bottleneck(nn.Module):
+    # Bottleneck in torchvision places the stride for downsampling at 3x3 convolution(self.conv2)
+    # while original implementation places the stride at the first 1x1 convolution(self.conv1)
+    # according to "Deep residual learning for image recognition"https://arxiv.org/abs/1512.03385.
+    # This variant is also known as ResNet V1.5 and improves accuracy according to
+    # https://ngc.nvidia.com/catalog/model-scripts/nvidia:resnet_50_v1_5_for_pytorch.
+
+    expansion = 4
+
+    def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
+                 base_width=64, dilation=1, norm_layer=None):
+        super(Bottleneck, self).__init__()
+        if norm_layer is None:
+            norm_layer = nn.BatchNorm2d
+        width = int(planes * (base_width / 64.)) * groups
+        # Both self.conv2 and self.downsample layers downsample the input when stride != 1
+        self.conv1 = conv1x1(inplanes, width)
+        self.bn1 = norm_layer(width)
+        self.conv2 = conv3x3(width, width, stride, groups, dilation)
+        self.bn2 = norm_layer(width)
+        self.conv3 = conv1x1(width, planes * self.expansion)
+        self.bn3 = norm_layer(planes * self.expansion)
+        self.relu = nn.ReLU(inplace=True)
+        self.downsample = downsample
+        self.stride = stride
+
+    def forward(self, x):
+        identity = x
+
+        out = self.conv1(x)
+        out = self.bn1(out)
+        out = self.relu(out)
+
+        out = self.conv2(out)
+        out = self.bn2(out)
+        out = self.relu(out)
+
+        out = self.conv3(out)
+        out = self.bn3(out)
+
+        if self.downsample is not None:
+            identity = self.downsample(x)
+
+        out += identity
+        out = self.relu(out)
+
+        return out
+
+
+class ResNet(nn.Module):
+
+    def __init__(self, block, layers, zero_init_residual=False,
+                 groups=1, width_per_group=64, replace_stride_with_dilation=None,
+                 norm_layer=None):
+        super(ResNet, self).__init__()
+        if norm_layer is None:
+            norm_layer = nn.BatchNorm2d
+        self._norm_layer = norm_layer
+
+        self.inplanes = 64
+        self.dilation = 1
+        if replace_stride_with_dilation is None:
+            # each element in the tuple indicates if we should replace
+            # the 2x2 stride with a dilated convolution instead
+            replace_stride_with_dilation = [False, False, False]
+        if len(replace_stride_with_dilation) != 3:
+            raise ValueError("replace_stride_with_dilation should be None "
+                             "or a 3-element tuple, got {}".format(replace_stride_with_dilation))
+        self.groups = groups
+        self.base_width = width_per_group
+        self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3,
+                               bias=False)
+        self.bn1 = norm_layer(self.inplanes)
+        self.relu = nn.ReLU(inplace=True)
+        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
+        self.layer1 = self._make_layer(block, 64, layers[0])
+        self.layer2 = self._make_layer(block, 128, layers[1], stride=2,
+                                       dilate=replace_stride_with_dilation[0])
+        self.layer3 = self._make_layer(block, 256, layers[2], stride=2,
+                                       dilate=replace_stride_with_dilation[1])
+        self.layer4 = self._make_layer(block, 512, layers[3], stride=2,
+                                       dilate=replace_stride_with_dilation[2])
+        # self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
+        # self.fc = nn.Linear(512 * block.expansion, num_classes)
+
+        for m in self.modules():
+            if isinstance(m, nn.Conv2d):
+                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
+            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
+                nn.init.constant_(m.weight, 1)
+                nn.init.constant_(m.bias, 0)
+
+        # Zero-initialize the last BN in each residual branch,
+        # so that the residual branch starts with zeros, and each residual block behaves like an identity.
+        # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
+        if zero_init_residual:
+            for m in self.modules():
+                if isinstance(m, Bottleneck):
+                    nn.init.constant_(m.bn3.weight, 0)
+                elif isinstance(m, BasicBlock):
+                    nn.init.constant_(m.bn2.weight, 0)
+
+    def _make_layer(self, block, planes, blocks, stride=1, dilate=False):
+        norm_layer = self._norm_layer
+        downsample = None
+        previous_dilation = self.dilation
+        if dilate:
+            self.dilation *= stride
+            stride = 1
+        if stride != 1 or self.inplanes != planes * block.expansion:
+            downsample = nn.Sequential(
+                conv1x1(self.inplanes, planes * block.expansion, stride),
+                norm_layer(planes * block.expansion),
+            )
+
+        layers = []
+        layers.append(block(self.inplanes, planes, stride, downsample, self.groups,
+                            self.base_width, previous_dilation, norm_layer))
+        self.inplanes = planes * block.expansion
+        for _ in range(1, blocks):
+            layers.append(block(self.inplanes, planes, groups=self.groups,
+                                base_width=self.base_width, dilation=self.dilation,
+                                norm_layer=norm_layer))
+
+        return nn.Sequential(*layers)
+
+    def _forward_impl(self, x):
+        outputs = {}
+        # See note [TorchScript super()]
+        x = self.conv1(x)
+        x = self.bn1(x)
+        x = self.relu(x)
+        x = self.maxpool(x)
+        # outputs['stem'] = x
+
+        x = self.layer1(x)  # 1/4
+        outputs['res2'] = x
+
+        x = self.layer2(x)  # 1/8
+        outputs['res3'] = x
+
+        x = self.layer3(x)  # 1/16
+        outputs['res4'] = x
+
+        x = self.layer4(x)  # 1/32
+        outputs['res5'] = x
+
+        return outputs
+
+    def forward(self, x):
+        return self._forward_impl(x)
+
+
+def _resnet(arch, block, layers, pretrained, progress, **kwargs):
+    model = ResNet(block, layers, **kwargs)
+    if pretrained:
+        state_dict = load_state_dict_from_url(model_urls[arch],
+                                              progress=progress)
+        model.load_state_dict(state_dict, strict=False)
+    return model
+
+
+def resnet18(pretrained=False, progress=True, **kwargs):
+    r"""ResNet-18 model from
+    `"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_
+    Args:
+        pretrained (bool): If True, returns a model pre-trained on ImageNet
+        progress (bool): If True, displays a progress bar of the download to stderr
+    """
+    return _resnet('resnet18', BasicBlock, [2, 2, 2, 2], pretrained, progress,
+                   **kwargs)
+
+
+def resnet34(pretrained=False, progress=True, **kwargs):
+    r"""ResNet-34 model from
+    `"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_
+    Args:
+        pretrained (bool): If True, returns a model pre-trained on ImageNet
+        progress (bool): If True, displays a progress bar of the download to stderr
+    """
+    return _resnet('resnet34', BasicBlock, [3, 4, 6, 3], pretrained, progress,
+                   **kwargs)
+
+
+def resnet50(pretrained=False, progress=True, **kwargs):
+    r"""ResNet-50 model from
+    `"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_
+    Args:
+        pretrained (bool): If True, returns a model pre-trained on ImageNet
+        progress (bool): If True, displays a progress bar of the download to stderr
+    """
+    return _resnet('resnet50', Bottleneck, [3, 4, 6, 3], pretrained, progress,
+                   **kwargs)
+
+
+def resnet101(pretrained=False, progress=True, **kwargs):
+    r"""ResNet-101 model from
+    `"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_
+    Args:
+        pretrained (bool): If True, returns a model pre-trained on ImageNet
+        progress (bool): If True, displays a progress bar of the download to stderr
+    """
+    return _resnet('resnet101', Bottleneck, [3, 4, 23, 3], pretrained, progress,
+                   **kwargs)
+
+
+def resnet152(pretrained=False, progress=True, **kwargs):
+    r"""ResNet-152 model from
+    `"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_
+    Args:
+        pretrained (bool): If True, returns a model pre-trained on ImageNet
+        progress (bool): If True, displays a progress bar of the download to stderr
+    """
+    return _resnet('resnet152', Bottleneck, [3, 8, 36, 3], pretrained, progress,
+                   **kwargs)
+
+
+def resnext50_32x4d(pretrained=False, progress=True, **kwargs):
+    r"""ResNeXt-50 32x4d model from
+    `"Aggregated Residual Transformation for Deep Neural Networks" <https://arxiv.org/pdf/1611.05431.pdf>`_
+    Args:
+        pretrained (bool): If True, returns a model pre-trained on ImageNet
+        progress (bool): If True, displays a progress bar of the download to stderr
+    """
+    kwargs['groups'] = 32
+    kwargs['width_per_group'] = 4
+    return _resnet('resnext50_32x4d', Bottleneck, [3, 4, 6, 3],
+                   pretrained, progress, **kwargs)
+
+
+def resnext101_32x8d(pretrained=False, progress=True, **kwargs):
+    r"""ResNeXt-101 32x8d model from
+    `"Aggregated Residual Transformation for Deep Neural Networks" <https://arxiv.org/pdf/1611.05431.pdf>`_
+    Args:
+        pretrained (bool): If True, returns a model pre-trained on ImageNet
+        progress (bool): If True, displays a progress bar of the download to stderr
+    """
+    kwargs['groups'] = 32
+    kwargs['width_per_group'] = 8
+    return _resnet('resnext101_32x8d', Bottleneck, [3, 4, 23, 3],
+                   pretrained, progress, **kwargs)
+
+
+def wide_resnet50_2(pretrained=False, progress=True, **kwargs):
+    r"""Wide ResNet-50-2 model from
+    `"Wide Residual Networks" <https://arxiv.org/pdf/1605.07146.pdf>`_
+    The model is the same as ResNet except for the bottleneck number of channels
+    which is twice larger in every block. The number of channels in outer 1x1
+    convolutions is the same, e.g. last block in ResNet-50 has 2048-512-2048
+    channels, and in Wide ResNet-50-2 has 2048-1024-2048.
+    Args:
+        pretrained (bool): If True, returns a model pre-trained on ImageNet
+        progress (bool): If True, displays a progress bar of the download to stderr
+    """
+    kwargs['width_per_group'] = 64 * 2
+    return _resnet('wide_resnet50_2', Bottleneck, [3, 4, 6, 3],
+                   pretrained, progress, **kwargs)
+
+
+def wide_resnet101_2(pretrained=False, progress=True, **kwargs):
+    r"""Wide ResNet-101-2 model from
+    `"Wide Residual Networks" <https://arxiv.org/pdf/1605.07146.pdf>`_
+    The model is the same as ResNet except for the bottleneck number of channels
+    which is twice larger in every block. The number of channels in outer 1x1
+    convolutions is the same, e.g. last block in ResNet-50 has 2048-512-2048
+    channels, and in Wide ResNet-50-2 has 2048-1024-2048.
+    Args:
+        pretrained (bool): If True, returns a model pre-trained on ImageNet
+        progress (bool): If True, displays a progress bar of the download to stderr
+    """
+    kwargs['width_per_group'] = 64 * 2
+    return _resnet('wide_resnet101_2', Bottleneck, [3, 4, 23, 3],
+                   pretrained, progress, **kwargs)
segmenter_model/blocks.py
ADDED
@@ -0,0 +1,129 @@
+"""
+Adapted from 2020 Ross Wightman
+https://github.com/rwightman/pytorch-image-models
+"""
+
+import torch
+import torch.nn as nn
+from einops import rearrange
+from pathlib import Path
+
+import torch.nn.functional as F
+
+from timm.models.layers import DropPath
+
+
+class FeedForward(nn.Module):
+    def __init__(self, dim, hidden_dim, dropout, out_dim=None):
+        super().__init__()
+        self.fc1 = nn.Linear(dim, hidden_dim)
+        self.act = nn.GELU()
+        if out_dim is None:
+            out_dim = dim
+        self.fc2 = nn.Linear(hidden_dim, out_dim)
+        self.drop = nn.Dropout(dropout)
+
+    @property
+    def unwrapped(self):
+        return self
+
+    def forward(self, x):
+        x = self.fc1(x)
+        x = self.act(x)
+        x = self.drop(x)
+        x = self.fc2(x)
+        x = self.drop(x)
+        return x
+
+
+class Attention(nn.Module):
+    def __init__(self, dim, heads, dropout):
+        super().__init__()
+        self.heads = heads
+        head_dim = dim // heads
+        self.scale = head_dim ** -0.5
+        self.attn = None
+
+        self.qkv = nn.Linear(dim, dim * 3)
+        self.attn_drop = nn.Dropout(dropout)
+        self.proj = nn.Linear(dim, dim)
+        self.proj_drop = nn.Dropout(dropout)
+
+    @property
+    def unwrapped(self):
+        return self
+
+    def forward(self, x, mask=None):
+        B, N, C = x.shape
+        qkv = (
+            self.qkv(x)
+            .reshape(B, N, 3, self.heads, C // self.heads)
+            .permute(2, 0, 3, 1, 4)
+        )
+        q, k, v = (
+            qkv[0],
+            qkv[1],
+            qkv[2],
+        )
+
+        attn = (q @ k.transpose(-2, -1)) * self.scale
+        attn = attn.softmax(dim=-1)
+        attn = self.attn_drop(attn)
+
+        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
+        x = self.proj(x)
+        x = self.proj_drop(x)
+
+        return x, attn
+
+
+class AttentionQK(nn.Module):
+    def __init__(self, dim, heads=1, dropout=0.):
+        super().__init__()
+        self.heads = heads
+        head_dim = dim // heads
+        self.scale = head_dim ** -0.5
+        self.attn = None
+
+        self.qk = nn.Linear(dim, dim * 2)
+        self.attn_drop = nn.Dropout(dropout)
+
+    @property
+    def unwrapped(self):
+        return self
+
+    def forward(self, x):
+        B, N, C = x.shape
+        qkv = (
+            self.qk(x)
+            .reshape(B, N, 2, self.heads, C // self.heads)
+            .permute(2, 0, 3, 1, 4)
+        )
+        q, k = (
+            qkv[0],
+            qkv[1]
+        )
+
+        attn = (q @ k.transpose(-2, -1)) * self.scale
+        # attn = attn.sigmoid()
+        attn = attn.softmax(dim=-1)
+
+        return attn
+
+
+class Block(nn.Module):
+    def __init__(self, dim, heads, mlp_dim, dropout, drop_path):
+        super().__init__()
+        self.norm1 = nn.LayerNorm(dim)
+        self.norm2 = nn.LayerNorm(dim)
+        self.attn = Attention(dim, heads, dropout)
+        self.mlp = FeedForward(dim, mlp_dim, dropout)
+        self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
+
+    def forward(self, x, mask=None, return_attention=False):
+        y, attn = self.attn(self.norm1(x), mask)
+        if return_attention:
+            return attn
+        x = x + self.drop_path(y)
+        x = x + self.drop_path(self.mlp(self.norm2(x)))
+        return x
segmenter_model/decoder.py
ADDED
@@ -0,0 +1,214 @@
+import math
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from einops import rearrange
+
+from timm.models.layers import trunc_normal_
+
+from segmenter_model.blocks import Block, FeedForward
+from segmenter_model.utils import init_weights
+
+
+class DecoderLinear(nn.Module):
+    def __init__(self, n_cls, patch_size, d_encoder):
+        super().__init__()
+
+        self.d_encoder = d_encoder
+        self.patch_size = patch_size
+        self.n_cls = n_cls
+
+        self.head = nn.Linear(self.d_encoder, n_cls)
+        self.apply(init_weights)
+
+    @torch.jit.ignore
+    def no_weight_decay(self):
+        return set()
+
+    def forward(self, x, im_size):
+        H, W = im_size
+        GS = H // self.patch_size
+        x = self.head(x)
+        x = rearrange(x, "b (h w) c -> b c h w", h=GS)
+
+        return x
+
+
+class MaskTransformer(nn.Module):
+    def __init__(
+            self,
+            n_cls,
+            patch_size,
+            d_encoder,
+            n_layers,
+            n_heads,
+            d_model,
+            d_ff,
+            drop_path_rate,
+            dropout,
+    ):
+        super().__init__()
+        self.d_encoder = d_encoder
+        self.patch_size = patch_size
+        self.n_layers = n_layers
+        self.n_cls = n_cls
+        self.d_model = d_model
+        self.d_ff = d_ff
+        self.scale = d_model ** -0.5
+
+        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, n_layers)]
+        self.blocks = nn.ModuleList(
+            [Block(d_model, n_heads, d_ff, dropout, dpr[i]) for i in range(n_layers)]
+        )
+
+        self.cls_emb = nn.Parameter(torch.randn(1, n_cls, d_model))
+        self.proj_dec = nn.Linear(d_encoder, d_model)
+
+        self.proj_patch = nn.Parameter(self.scale * torch.randn(d_model, d_model))
+        self.proj_classes = nn.Parameter(self.scale * torch.randn(d_model, d_model))
+
+        self.decoder_norm = nn.LayerNorm(d_model)
+        self.mask_norm = nn.LayerNorm(n_cls)
+
+        self.apply(init_weights)
+        trunc_normal_(self.cls_emb, std=0.02)
+
+    @torch.jit.ignore
+    def no_weight_decay(self):
+        return {"cls_emb"}
+
+    def forward(self, x, im_size, features_only=False, no_rearrange=False):
+        H, W = im_size
+        GS = H // self.patch_size
+
+        # project from the encoder dimensionality to the decoder dimensionality (usually the same)
+        x = self.proj_dec(x)
+        # reshape the class embedding token
+        cls_emb = self.cls_emb.expand(x.size(0), -1, -1)
+        # concatenate the class embedding token to the input
+        x = torch.cat((x, cls_emb), 1)
+        # forward the concatenated tokens through decoder blocks
+        for blk in self.blocks:
+            x = blk(x)
+        # perform normalization
+        x = self.decoder_norm(x)
+
+        # split to patch features and class-segmentation features
+        patches, cls_seg_feat = x[:, : -self.n_cls], x[:, -self.n_cls:]
+
+        # project the patch features
+        patches = patches @ self.proj_patch
+
+        if features_only:
+            if not no_rearrange:
+                features = rearrange(patches, "b (h w) n -> b n h w", h=int(GS))
+            else:
+                features = patches
+            return features
+
+        # project the class-segmentation features
+        cls_seg_feat = cls_seg_feat @ self.proj_classes
+
+        # scalar product between L2-normalized patch embeddings and class embeddings -> masks
+        patches = patches / patches.norm(dim=-1, keepdim=True)
+        cls_seg_feat = cls_seg_feat / cls_seg_feat.norm(dim=-1, keepdim=True)
+        masks = patches @ cls_seg_feat.transpose(1, 2)
+
+        masks = self.mask_norm(masks)
+        if not no_rearrange:
+            masks = rearrange(masks, "b (h w) n -> b n h w", h=int(GS))
+
+        return masks
+
+    def get_attention_map(self, x, layer_id):
+        if layer_id >= self.n_layers or layer_id < 0:
+            raise ValueError(
+                f"Provided layer_id: {layer_id} is not valid. 0 <= {layer_id} < {self.n_layers}."
+            )
+        x = self.proj_dec(x)
+        cls_emb = self.cls_emb.expand(x.size(0), -1, -1)
+        x = torch.cat((x, cls_emb), 1)
+        for i, blk in enumerate(self.blocks):
+            if i < layer_id:
+                x = blk(x)
+            else:
+                return blk(x, return_attention=True)
+
+
+class DeepLabHead(nn.Sequential):
+    def __init__(self, in_channels, num_classes, patch_size=None):
+        super(DeepLabHead, self).__init__(
+            ASPP(in_channels, [12, 24, 36]),
+            nn.Conv2d(256, 256, 3, padding=1, bias=False),
+            nn.BatchNorm2d(256),
+            nn.ReLU(),
+            nn.Conv2d(256, num_classes, 1)
+        )
+        self.patch_size = patch_size
+
+    def forward(self, x, im_size=None):
+        if len(x.shape) == 3:
+            # features from ViT
+            assert im_size is not None and self.patch_size is not None
+            H, W = im_size
+            GS = H // self.patch_size
+            x = rearrange(x, "b (h w) n -> b n h w", h=int(GS)).contiguous()
+        for module in self:
+            x = module(x)
+        return x
+
+
+class ASPPConv(nn.Sequential):
+    def __init__(self, in_channels, out_channels, dilation):
+        modules = [
+            nn.Conv2d(in_channels, out_channels, 3, padding=dilation, dilation=dilation, bias=False),
+            nn.BatchNorm2d(out_channels),
+            nn.ReLU()
+        ]
+        super(ASPPConv, self).__init__(*modules)
+
+
+class ASPPPooling(nn.Sequential):
+    def __init__(self, in_channels, out_channels):
+        super(ASPPPooling, self).__init__(
+            nn.AdaptiveAvgPool2d(1),
+            nn.Conv2d(in_channels, out_channels, 1, bias=False),
+            nn.BatchNorm2d(out_channels),
+            nn.ReLU())
+
+    def forward(self, x):
+        size = x.shape[-2:]
+        for mod in self:
+            x = mod(x)
+        return F.interpolate(x, size=size, mode='bilinear', align_corners=False)
+
+
+class ASPP(nn.Module):
+    def __init__(self, in_channels, atrous_rates, out_channels=256):
+        super(ASPP, self).__init__()
+        modules = []
+        modules.append(nn.Sequential(
+            nn.Conv2d(in_channels, out_channels, 1, bias=False),
+            nn.BatchNorm2d(out_channels),
+            nn.ReLU()))
+
+        rates = tuple(atrous_rates)
+        for rate in rates:
+            modules.append(ASPPConv(in_channels, out_channels, rate))
+
+        modules.append(ASPPPooling(in_channels, out_channels))
+
+        self.convs = nn.ModuleList(modules)
+
+        self.project = nn.Sequential(
+            nn.Conv2d(5 * out_channels, out_channels, 1, bias=False),
+            nn.BatchNorm2d(out_channels),
+            nn.ReLU(),
+            nn.Dropout(0.5))
+
+    def forward(self, x):
+        res = []
+        for conv in self.convs:
+            res.append(conv(x))
+        res = torch.cat(res, dim=1)
+        return self.project(res)
segmenter_model/factory.py
ADDED
@@ -0,0 +1,165 @@
+from pathlib import Path
+import yaml
+import torch
+import math
+import os
+import torch.nn as nn
+
+from timm.models.helpers import load_pretrained, load_custom_pretrained
+from timm.models.vision_transformer import default_cfgs, checkpoint_filter_fn
+from timm.models.registry import register_model
+from timm.models.vision_transformer import _create_vision_transformer
+from segmenter_model.decoder import MaskTransformer
+from segmenter_model.segmenter import Segmenter
+import segmenter_model.torch as ptu
+
+from segmenter_model.vit_dino import vit_small, VisionTransformer
+
+
+@register_model
+def vit_base_patch8_384(pretrained=False, **kwargs):
+    """ViT-Base model (ViT-B/16) from original paper (https://arxiv.org/abs/2010.11929).
+    ImageNet-1k weights fine-tuned from in21k @ 384x384, source https://github.com/google-research/vision_transformer.
+    """
+    model_kwargs = dict(patch_size=8, embed_dim=768, depth=12, num_heads=12, **kwargs)
+    model = _create_vision_transformer(
+        "vit_base_patch8_384",
+        pretrained=pretrained,
+        default_cfg=dict(
+            url="",
+            input_size=(3, 384, 384),
+            mean=(0.5, 0.5, 0.5),
+            std=(0.5, 0.5, 0.5),
+            num_classes=1000,
+        ),
+        **model_kwargs,
+    )
+    return model
+
+
+def create_vit(model_cfg):
+    model_cfg = model_cfg.copy()
+    backbone = model_cfg.pop("backbone")
+    if 'pretrained_weights' in model_cfg:
+        pretrained_weights = model_cfg.pop('pretrained_weights')
+
+    if 'dino' in backbone:
+        if backbone.lower() == 'dino_vits16':
+            model_cfg['drop_rate'] = model_cfg['dropout']
+            model = vit_small(**model_cfg)
+            # hard-coded for now, too lazy
+            ciirc_path = '/home/vobecant/PhD/weights/dino/dino_deitsmall16_pretrain.pth'
+            karolina_path = '/scratch/project/dd-21-20/pretrained_weights/dino/dino_deitsmall16_pretrain.pth'
+            if os.path.exists(ciirc_path):
+                pretrained_weights = ciirc_path
+            elif os.path.exists(karolina_path):
+                pretrained_weights = karolina_path
+            else:
+                raise Exception('DINO weights not found!')
+            model.load_state_dict(torch.load(pretrained_weights), strict=True)
+        else:
+            model = torch.hub.load('facebookresearch/dino:main', backbone)
+        setattr(model, 'd_model', model.num_features)
+        setattr(model, 'patch_size', model.patch_embed.patch_size)
+        setattr(model, 'distilled', False)
+        model.forward = lambda x, return_features: model.get_intermediate_layers(x, n=1)[0]
+    else:
+        normalization = model_cfg.pop("normalization")
+        model_cfg["n_cls"] = 1000
+        mlp_expansion_ratio = 4
+        model_cfg["d_ff"] = mlp_expansion_ratio * model_cfg["d_model"]
+
+        if backbone in default_cfgs:
+            default_cfg = default_cfgs[backbone]
+        else:
+            default_cfg = dict(
+                pretrained=False,
+                num_classes=1000,
+                drop_rate=0.0,
+                drop_path_rate=0.0,
+                drop_block_rate=None,
+            )
+
+        default_cfg["input_size"] = (
+            3,
+            model_cfg["image_size"][0],
+            model_cfg["image_size"][1],
+        )
+        model = VisionTransformer(**model_cfg)
+        if backbone == "vit_base_patch8_384":
+            path = os.path.expandvars("/home/vobecant/PhD/weights/vit_base_patch8_384.pth")
+            state_dict = torch.load(path, map_location="cpu")
+            filtered_dict = checkpoint_filter_fn(state_dict, model)
+            model.load_state_dict(filtered_dict, strict=True)
+        elif "deit" in backbone:
+            load_pretrained(model, default_cfg, filter_fn=checkpoint_filter_fn)
+        else:
+            load_custom_pretrained(model, default_cfg)
+
+    return model
+
+
+def create_decoder(encoder, decoder_cfg):
+    decoder_cfg = decoder_cfg.copy()
+    name = decoder_cfg.pop("name")
+    decoder_cfg["d_encoder"] = encoder.d_model
+    decoder_cfg["patch_size"] = encoder.patch_size
+
+    if "linear" in name:
+        decoder = DecoderLinear(**decoder_cfg)
+    elif name == "mask_transformer":
+        dim = encoder.d_model
+        n_heads = dim // 64
+        decoder_cfg["n_heads"] = n_heads
+        decoder_cfg["d_model"] = dim
+        decoder_cfg["d_ff"] = 4 * dim
+        decoder = MaskTransformer(**decoder_cfg)
+    elif 'deeplab' in name:
+        decoder = DeepLabHead(in_channels=encoder.d_model, num_classes=decoder_cfg["n_cls"],
+                              patch_size=decoder_cfg["patch_size"])
+    else:
+        raise ValueError(f"Unknown decoder: {name}")
+    return decoder
+
+
+def create_segmenter(model_cfg):
+    model_cfg = model_cfg.copy()
+    decoder_cfg = model_cfg.pop("decoder")
+    decoder_cfg["n_cls"] = model_cfg["n_cls"]
+
+    if 'weights_path' in model_cfg.keys():
+        weights_path = model_cfg.pop('weights_path')
+    else:
+        weights_path = None
+
+    encoder = create_vit(model_cfg)
+    decoder = create_decoder(encoder, decoder_cfg)
+    model = Segmenter(encoder, decoder, n_cls=model_cfg["n_cls"])
+
+    if weights_path is not None:
+        raise Exception('Wants to load weights to the complete segmenter insice create_segmenter method!')
+        state_dict = torch.load(weights_path, map_location="cpu")
+        if 'model' in state_dict:
+            state_dict = state_dict['model']
+        msg = model.load_state_dict(state_dict, strict=False)
+        print(msg)
+
+    return model
+
+
+def load_model(model_path, decoder_only=False, variant_path=None):
+    variant_path = Path(model_path).parent / "variant.yml" if variant_path is None else variant_path
+    with open(variant_path, "r") as f:
+        variant = yaml.load(f, Loader=yaml.FullLoader)
+    net_kwargs = variant["net_kwargs"]
+
+    model = create_segmenter(net_kwargs)
+    data = torch.load(model_path, map_location=ptu.device)
+    checkpoint = data["model"]
+
+    if decoder_only:
+        model.decoder.load_state_dict(checkpoint, strict=True)
+    else:
+        model.load_state_dict(checkpoint, strict=True)
+
+    return model, variant
segmenter_model/fpn_picie.py
ADDED
@@ -0,0 +1,66 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# taken from https://raw.githubusercontent.com/janghyuncho/PiCIE/1d7b034f57e98670b0d6a244b2eea11fa0dde73e/modules/fpn.py

import torch
import torch.nn as nn
import torch.nn.functional as F
from . import backbone_picie as backbone


class PanopticFPN(nn.Module):
    def __init__(self, arch, pretrain, n_cls):
        super(PanopticFPN, self).__init__()
        self.n_cls = n_cls
        self.backbone = backbone.__dict__[arch](pretrained=pretrain)
        self.decoder = FPNDecoder(arch, n_cls)

    def forward(self, x, encoder_features=False, decoder_features=False):
        feats = self.backbone(x)
        if decoder_features:
            dec, outs = self.decoder(feats, get_features=decoder_features)
        else:
            outs = self.decoder(feats)

        if encoder_features:
            if decoder_features:
                return feats['res5'], dec, outs
            else:
                return feats['res5'], outs
        else:
            return outs


class FPNDecoder(nn.Module):
    def __init__(self, arch, n_cls):
        super(FPNDecoder, self).__init__()
        self.n_cls = n_cls
        if arch == 'resnet18':
            mfactor = 1
            out_dim = 128
        else:
            mfactor = 4
            out_dim = 256

        self.layer4 = nn.Conv2d(512 * mfactor // 8, out_dim, kernel_size=1, stride=1, padding=0)
        self.layer3 = nn.Conv2d(512 * mfactor // 4, out_dim, kernel_size=1, stride=1, padding=0)
        self.layer2 = nn.Conv2d(512 * mfactor // 2, out_dim, kernel_size=1, stride=1, padding=0)
        self.layer1 = nn.Conv2d(512 * mfactor, out_dim, kernel_size=1, stride=1, padding=0)

        self.pred = nn.Conv2d(out_dim, self.n_cls, 1, 1)

    def forward(self, x, get_features=False):
        o1 = self.layer1(x['res5'])
        o2 = self.upsample_add(o1, self.layer2(x['res4']))
        o3 = self.upsample_add(o2, self.layer3(x['res3']))
        o4 = self.upsample_add(o3, self.layer4(x['res2']))

        pred = self.pred(o4)

        if get_features:
            return o4, pred
        else:
            return pred

    def upsample_add(self, x, y):
        _, _, H, W = y.size()

        return F.interpolate(x, size=(H, W), mode='bilinear', align_corners=False) + y
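Minimal shape-check sketch (not part of the committed file): it assumes backbone_picie exposes a 'resnet18' constructor that returns the res2–res5 feature dict the decoder indexes, as suggested by the arch check above; the expected output size is an assumption based on the FPN merging down to the res2 (stride-4) resolution.

    import torch
    from segmenter_model.fpn_picie import PanopticFPN

    net = PanopticFPN(arch='resnet18', pretrain=False, n_cls=19)
    x = torch.randn(1, 3, 320, 320)
    logits = net(x)          # per-pixel class scores at the res2 resolution
    print(logits.shape)      # expected: (1, 19, 80, 80)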
segmenter_model/picie_model.py
ADDED
@@ -0,0 +1,82 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from . import backbone_picie as backbone


class PanopticFPN(nn.Module):
    def __init__(self, args):
        super(PanopticFPN, self).__init__()
        self.backbone = backbone.__dict__[args.arch](pretrained=args.pretrain)
        if args.arch == 'vit_small':
            self.decoder = FPNDecoderViT(args)
        else:
            self.decoder = FPNDecoder(args)

    def forward(self, x, encoder_features=False, decoder_features=False):
        feats = self.backbone(x)
        dec_outs = self.decoder(feats)

        if encoder_features:
            return feats['res5'], dec_outs
        else:
            return dec_outs


class FPNDecoder(nn.Module):
    def __init__(self, args):
        super(FPNDecoder, self).__init__()
        if args.arch == 'resnet18':
            mfactor = 1
            out_dim = 128
        else:
            mfactor = 4
            out_dim = 256

        self.layer4 = nn.Conv2d(512 * mfactor // 8, out_dim, kernel_size=1, stride=1, padding=0)
        self.layer3 = nn.Conv2d(512 * mfactor // 4, out_dim, kernel_size=1, stride=1, padding=0)
        self.layer2 = nn.Conv2d(512 * mfactor // 2, out_dim, kernel_size=1, stride=1, padding=0)
        self.layer1 = nn.Conv2d(512 * mfactor, out_dim, kernel_size=1, stride=1, padding=0)

    def forward(self, x):
        o1 = self.layer1(x['res5'])
        o2 = self.upsample_add(o1, self.layer2(x['res4']))
        o3 = self.upsample_add(o2, self.layer3(x['res3']))
        o4 = self.upsample_add(o3, self.layer4(x['res2']))

        return o4

    def upsample_add(self, x, y):
        _, _, H, W = y.size()

        return F.interpolate(x, size=(H, W), mode='bilinear', align_corners=False) + y


class FPNDecoderViT(nn.Module):
    def __init__(self, args):
        super(FPNDecoderViT, self).__init__()
        if args.arch == 'resnet18' or args.arch == 'vit_small':
            mfactor = 1
            out_dim = 128
        else:
            mfactor = 4
            out_dim = 256

        self.upsample_rate = 4

        self.layer4 = nn.Conv2d(384, out_dim, kernel_size=1, stride=1, padding=0)
        self.layer3 = nn.Conv2d(384, out_dim, kernel_size=1, stride=1, padding=0)
        self.layer2 = nn.Conv2d(384, out_dim, kernel_size=1, stride=1, padding=0)
        self.layer1 = nn.Conv2d(384, out_dim, kernel_size=1, stride=1, padding=0)

    def forward(self, x):
        o1 = self.layer1(x[3])
        o1 = F.interpolate(o1, scale_factor=4, mode='bilinear', align_corners=False)
        o2 = self.upsample_add(o1, self.layer2(x[2]))
        o3 = self.upsample_add(o2, self.layer3(x[1]))
        o4 = self.upsample_add(o3, self.layer4(x[0]))

        return o4

    def upsample_add(self, x, y):
        return F.interpolate(y, scale_factor=self.upsample_rate, mode='bilinear', align_corners=False) + x
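Sketch only (not part of the commit): unlike fpn_picie.PanopticFPN, this variant is configured through an args namespace and its decoder returns FPN features rather than class logits. The backbone key and the expected output size are assumptions.

    from types import SimpleNamespace
    import torch
    from segmenter_model.picie_model import PanopticFPN

    # 'resnet18' is assumed to be a valid backbone_picie constructor
    args = SimpleNamespace(arch='resnet18', pretrain=False)
    net = PanopticFPN(args)
    feats = net(torch.randn(1, 3, 320, 320))
    print(feats.shape)   # expected: (1, 128, 80, 80) for resnet18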
segmenter_model/resnet_dilated.py
ADDED
@@ -0,0 +1,55 @@
#
# Authors: Wouter Van Gansbeke & Simon Vandenhende
# Licensed under the CC BY-NC 4.0 license (https://creativecommons.org/licenses/by-nc/4.0/)

import torch.nn as nn


class ResnetDilated(nn.Module):
    def __init__(self, orig_resnet, dilate_scale=8):
        super(ResnetDilated, self).__init__()
        from functools import partial

        if dilate_scale == 8:
            orig_resnet.layer3.apply(
                partial(self._nostride_dilate, dilate=2))
            orig_resnet.layer4.apply(
                partial(self._nostride_dilate, dilate=4))
        elif dilate_scale == 16:
            orig_resnet.layer4.apply(
                partial(self._nostride_dilate, dilate=2))

        self.conv1 = orig_resnet.conv1
        self.bn1 = orig_resnet.bn1
        self.relu = orig_resnet.relu

        self.maxpool = orig_resnet.maxpool
        self.layer1 = orig_resnet.layer1
        self.layer2 = orig_resnet.layer2
        self.layer3 = orig_resnet.layer3
        self.layer4 = orig_resnet.layer4

    def _nostride_dilate(self, m, dilate):
        classname = m.__class__.__name__
        if classname.find('Conv') != -1:
            # the convolution with stride
            if m.stride == (2, 2):
                m.stride = (1, 1)
                if m.kernel_size == (3, 3):
                    m.dilation = (dilate // 2, dilate // 2)
                    m.padding = (dilate // 2, dilate // 2)
            # other convolutions
            else:
                if m.kernel_size == (3, 3):
                    m.dilation = (dilate, dilate)
                    m.padding = (dilate, dilate)

    def forward(self, x):
        x = self.relu(self.bn1(self.conv1(x)))
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        return x
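Sketch of how the wrapper is meant to be used (not part of the commit; assumes torchvision is installed): it converts the strided stages of a standard ResNet into dilated ones so the final feature map has output stride 8 instead of 32.

    import torch
    import torchvision
    from segmenter_model.resnet_dilated import ResnetDilated

    backbone = ResnetDilated(torchvision.models.resnet18(), dilate_scale=8)
    feats = backbone(torch.randn(1, 3, 224, 224))
    print(feats.shape)   # expected: (1, 512, 28, 28), i.e. 224 / 8 = 28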
segmenter_model/segmenter.py
ADDED
@@ -0,0 +1,86 @@
from einops import rearrange

import torch
import torch.nn as nn
import torch.nn.functional as F

# from timm.models.layers import trunc_normal_

from segmenter_model.utils import padding, unpadding


class Segmenter(nn.Module):
    def __init__(
            self,
            encoder,
            decoder,
            n_cls,
    ):
        super().__init__()
        self.n_cls = n_cls
        self.patch_size = encoder.patch_size
        self.encoder = encoder
        self.decoder = decoder

    @torch.jit.ignore
    def no_weight_decay(self):
        def append_prefix_no_weight_decay(prefix, module):
            return set(map(lambda x: prefix + x, module.no_weight_decay()))

        nwd_params = append_prefix_no_weight_decay("encoder.", self.encoder).union(
            append_prefix_no_weight_decay("decoder.", self.decoder)
        )
        return nwd_params

    def forward(self, im, decoder_features=False, no_upsample=False, encoder_features=False, no_rearrange=False,
                cls_only=False, encoder_only=False):
        H_ori, W_ori = im.size(2), im.size(3)
        if not no_upsample:
            im = padding(im, self.patch_size)
        H, W = im.size(2), im.size(3)

        x = self.encoder(im, return_features=True)  # self.patch_size times smaller than im

        # remove CLS/DIST tokens for decoding
        num_extra_tokens = 1 + self.encoder.distilled

        if cls_only:
            return x[:, 0]
        x = x[:, num_extra_tokens:]

        if encoder_features:
            enc_fts = x.clone()
            if not no_rearrange:
                GS = H // self.patch_size
                enc_fts = rearrange(enc_fts, "b (h w) c -> b c h w", h=GS)
            if encoder_only:
                return enc_fts

        if decoder_features:
            output = self.decoder(x, (H, W), features_only=True, no_rearrange=no_rearrange)
            if no_rearrange:
                if encoder_features:
                    output = (enc_fts, output)
                return output
        else:
            output = self.decoder(x, (H, W))  # shape (BS, NCLS, H/self.patch_size, W/self.patch_size)

        if not no_upsample:
            output = F.interpolate(output, size=(H, W), mode="bilinear")  # upsample self.patch_size times
            output = unpadding(output, (H_ori, W_ori))

        if encoder_features:
            output = (enc_fts, output)
        return output

    def get_attention_map_enc(self, im, layer_id):
        return self.encoder.get_attention_map(im, layer_id)

    def get_attention_map_dec(self, im, layer_id):
        x = self.encoder(im, return_features=True)

        # remove CLS/DIST tokens for decoding
        num_extra_tokens = 1 + self.encoder.distilled
        x = x[:, num_extra_tokens:]

        return self.decoder.get_attention_map(x, layer_id)
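Toy wiring sketch (not part of the commit; uses dummy encoder/decoder modules, not the real ViT and MaskTransformer): it only illustrates how Segmenter pads the input, strips the CLS token, decodes at patch resolution, then upsamples and crops back to the original size. It assumes the package's dependencies (einops, timm, cv2) are installed.

    import torch
    import torch.nn as nn
    from segmenter_model.segmenter import Segmenter

    class DummyEncoder(nn.Module):
        # attributes Segmenter.forward expects on the encoder
        patch_size, distilled, d_model = 16, 0, 32

        def forward(self, im, return_features=True):
            B, _, H, W = im.shape
            n = (H // 16) * (W // 16)
            return torch.randn(B, n + 1, 32)   # one CLS token + patch tokens

    class DummyDecoder(nn.Module):
        def forward(self, x, im_size):
            H, W = im_size
            return torch.randn(x.shape[0], 19, H // 16, W // 16)

    seg = Segmenter(DummyEncoder(), DummyDecoder(), n_cls=19)
    out = seg(torch.randn(1, 3, 250, 500))   # sizes not divisible by 16
    print(out.shape)                          # expected: (1, 19, 250, 500)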
segmenter_model/torch.py
ADDED
@@ -0,0 +1,38 @@
import os
import torch

"""
GPU wrappers
"""

use_gpu = False
gpu_id = 0
device = None

distributed = False
dist_rank = 0
world_size = 1


def set_gpu_mode(mode, pbs=False):
    global use_gpu
    global device
    global gpu_id
    global distributed
    global dist_rank
    global world_size
    if pbs:
        gpu_id = int(os.environ.get("MPI_LOCALRANKID", 0))
        dist_rank = int(os.environ.get("PMI_RANK", 0))
        world_size = int(os.environ.get("PMI_SIZE", 1))
    else:
        gpu_id = int(os.environ.get("SLURM_LOCALID", 0))
        dist_rank = int(os.environ.get("SLURM_PROCID", 0))
        world_size = int(os.environ.get("SLURM_NTASKS", 1))

    distributed = world_size > 1
    use_gpu = mode
    print('gpu_id: {}, dist_rank: {}, world_size: {}, distributed: {}'.format(gpu_id, dist_rank, world_size,
                                                                              distributed))
    device = torch.device(f"cuda:{gpu_id}" if use_gpu else "cpu")
    torch.backends.cudnn.benchmark = True
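Usage sketch (not part of the committed file): initialise the module-level device once, before any model is built or loaded; the rest of the code reads it as ptu.device.

    import torch
    import segmenter_model.torch as ptu

    ptu.set_gpu_mode(torch.cuda.is_available())
    print(ptu.device)   # cuda:<local rank> when a GPU is available, otherwise cpu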
segmenter_model/utils.py
ADDED
@@ -0,0 +1,582 @@
import math
# import segm.utils.torch as ptu
# from segm.engine import seg2rgb
from collections import namedtuple

import cv2
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
from PIL import Image
from timm.models.layers import trunc_normal_

import torch

CityscapesClass = namedtuple('CityscapesClass', ['name', 'id', 'train_id', 'category', 'category_id',
                                                 'has_instances', 'ignore_in_eval', 'color'])

classes = [
    CityscapesClass('unlabeled', 0, 255, 'void', 0, False, True, (0, 0, 0)),
    CityscapesClass('ego vehicle', 1, 255, 'void', 0, False, True, (0, 0, 0)),
    CityscapesClass('rectification border', 2, 255, 'void', 0, False, True, (0, 0, 0)),
    CityscapesClass('out of roi', 3, 255, 'void', 0, False, True, (0, 0, 0)),
    CityscapesClass('static', 4, 255, 'void', 0, False, True, (0, 0, 0)),
    CityscapesClass('dynamic', 5, 255, 'void', 0, False, True, (111, 74, 0)),
    CityscapesClass('ground', 6, 255, 'void', 0, False, True, (81, 0, 81)),
    CityscapesClass('road', 7, 0, 'flat', 1, False, False, (128, 64, 128)),
    CityscapesClass('sidewalk', 8, 1, 'flat', 1, False, False, (244, 35, 232)),
    CityscapesClass('parking', 9, 255, 'flat', 1, False, True, (250, 170, 160)),
    CityscapesClass('rail track', 10, 255, 'flat', 1, False, True, (230, 150, 140)),
    CityscapesClass('building', 11, 2, 'construction', 2, False, False, (70, 70, 70)),
    CityscapesClass('wall', 12, 3, 'construction', 2, False, False, (102, 102, 156)),
    CityscapesClass('fence', 13, 4, 'construction', 2, False, False, (190, 153, 153)),
    CityscapesClass('guard rail', 14, 255, 'construction', 2, False, True, (180, 165, 180)),
    CityscapesClass('bridge', 15, 255, 'construction', 2, False, True, (150, 100, 100)),
    CityscapesClass('tunnel', 16, 255, 'construction', 2, False, True, (150, 120, 90)),
    CityscapesClass('pole', 17, 5, 'object', 3, False, False, (153, 153, 153)),
    CityscapesClass('polegroup', 18, 255, 'object', 3, False, True, (153, 153, 153)),
    CityscapesClass('traffic light', 19, 6, 'object', 3, False, False, (250, 170, 30)),
    CityscapesClass('traffic sign', 20, 7, 'object', 3, False, False, (220, 220, 0)),
    CityscapesClass('vegetation', 21, 8, 'nature', 4, False, False, (107, 142, 35)),
    CityscapesClass('terrain', 22, 9, 'nature', 4, False, False, (152, 251, 152)),
    CityscapesClass('sky', 23, 10, 'sky', 5, False, False, (70, 130, 180)),
    CityscapesClass('person', 24, 11, 'human', 6, True, False, (220, 20, 60)),
    CityscapesClass('rider', 25, 12, 'human', 6, True, False, (255, 0, 0)),
    CityscapesClass('car', 26, 13, 'vehicle', 7, True, False, (0, 0, 142)),
    CityscapesClass('truck', 27, 14, 'vehicle', 7, True, False, (0, 0, 70)),
    CityscapesClass('bus', 28, 15, 'vehicle', 7, True, False, (0, 60, 100)),
    CityscapesClass('caravan', 29, 255, 'vehicle', 7, True, True, (0, 0, 90)),
    CityscapesClass('trailer', 30, 255, 'vehicle', 7, True, True, (0, 0, 110)),
    CityscapesClass('train', 31, 16, 'vehicle', 7, True, False, (0, 80, 100)),
    CityscapesClass('motorcycle', 32, 17, 'vehicle', 7, True, False, (0, 0, 230)),
    CityscapesClass('bicycle', 33, 18, 'vehicle', 7, True, False, (119, 11, 32)),
    CityscapesClass('license plate', -1, -1, 'vehicle', 7, False, True, (0, 0, 142)),
]

cityscapes_id_to_trainID = {cls.id: cls.train_id for cls in classes}
cityscapes_trainID_to_testID = {cls.train_id: cls.id for cls in classes}
cityscapes_trainID_to_color = {cls.train_id: cls.color for cls in classes}
cityscapes_trainID_to_name = {cls.train_id: cls.name for cls in classes}
cityscapes_trainID_to_color[255] = (0, 0, 0)
cityscapes_trainID_to_name = {cls.train_id: cls.name for cls in classes}
cityscapes_trainID_to_name[255] = 'ignore'
cityscapes_trainID_to_name[19] = 'ignore'


def map2cs(seg):
    while len(seg.shape) > 2:
        seg = seg[0]
    colors = cityscapes_trainID_to_color
    # assert False, 'set ignore_idx color to black, make sure that it is not in colors'
    rgb = np.zeros((seg.shape[0], seg.shape[1], 3), dtype=np.uint8)
    for l in np.unique(seg):
        rgb[seg == l, :] = colors[l]
    return rgb


def get_colors(num_colors):
    from PIL import ImageColor
    import matplotlib
    hex_colors = [
        # "#000000",  # keep the black reserved
        "#FFFF00", "#1CE6FF", "#FF34FF", "#FF4A46", "#008941", "#006FA6", "#A30059",
        "#FFDBE5", "#7A4900", "#0000A6", "#63FFAC", "#B79762", "#004D43", "#8FB0FF", "#997D87",
        "#5A0007", "#809693", "#FEFFE6", "#1B4400", "#4FC601", "#3B5DFF", "#4A3B53", "#FF2F80",
        "#61615A", "#BA0900", "#6B7900", "#00C2A0", "#FFAA92", "#FF90C9", "#B903AA", "#D16100",
        "#DDEFFF", "#000035", "#7B4F4B", "#A1C299", "#300018", "#0AA6D8", "#013349", "#00846F",
        "#372101", "#FFB500", "#C2FFED", "#A079BF", "#CC0744", "#C0B9B2", "#C2FF99", "#001E09",
        "#00489C", "#6F0062", "#0CBD66", "#EEC3FF", "#456D75", "#B77B68", "#7A87A1", "#788D66",
        "#885578", "#FAD09F", "#FF8A9A", "#D157A0", "#BEC459", "#456648", "#0086ED", "#886F4C",
        "#34362D", "#B4A8BD", "#00A6AA", "#452C2C", "#636375", "#A3C8C9", "#FF913F", "#938A81",
        "#575329", "#00FECF", "#B05B6F", "#8CD0FF", "#3B9700", "#04F757", "#C8A1A1", "#1E6E00",
        "#7900D7", "#A77500", "#6367A9", "#A05837", "#6B002C", "#772600", "#D790FF", "#9B9700",
        "#549E79", "#FFF69F", "#201625", "#72418F", "#BC23FF", "#99ADC0", "#3A2465", "#922329",
        "#5B4534", "#FDE8DC", "#404E55", "#0089A3", "#CB7E98", "#A4E804", "#324E72", "#6A3A4C",
        "#83AB58", "#001C1E", "#D1F7CE", "#004B28", "#C8D0F6", "#A3A489", "#806C66", "#222800",
        "#BF5650", "#E83000", "#66796D", "#DA007C", "#FF1A59", "#8ADBB4", "#1E0200", "#5B4E51",
        "#C895C5", "#320033", "#FF6832", "#66E1D3", "#CFCDAC", "#D0AC94", "#7ED379", "#012C58",
    ]
    hex_colors_mlib = list(matplotlib.colors.cnames.values())
    for hcm in hex_colors_mlib:
        if hcm not in hex_colors:
            hex_colors.append(hcm)
    colors = [ImageColor.getrgb(hex) for hex in hex_colors]
    return colors[:num_colors]


def colorize_one(seg, ignore=None, colors=None, ncolors=32):
    unq = np.unique(seg)
    if ncolors is not None:
        ncolors = max(ncolors, max(unq))
    else:
        ncolors = max(unq)
    colors = get_colors(ncolors) if colors is None else colors
    h, w = seg.shape
    c = 3
    rgb = np.zeros((h, w, c), dtype=np.uint8)
    for l in unq:
        if ignore is not None and l == ignore:
            continue
        try:
            rgb[seg == l, :] = colors[l]
        except:
            raise Exception(l)
    return rgb


def init_weights(m):
    if isinstance(m, nn.Linear):
        trunc_normal_(m.weight, std=0.02)
        if isinstance(m, nn.Linear) and m.bias is not None:
            nn.init.constant_(m.bias, 0)
    elif isinstance(m, nn.LayerNorm):
        nn.init.constant_(m.bias, 0)
        nn.init.constant_(m.weight, 1.0)


def resize_pos_embed(posemb, grid_old_shape, grid_new_shape, num_extra_tokens):
    # Rescale the grid of position embeddings when loading from state_dict. Adapted from
    # https://github.com/google-research/vision_transformer/blob/00883dd691c63a6830751563748663526e811cee/vit_jax/checkpoint.py#L224
    posemb_tok, posemb_grid = (
        posemb[:, :num_extra_tokens],
        posemb[0, num_extra_tokens:],
    )
    if grid_old_shape is None:
        gs_old_h = int(math.sqrt(len(posemb_grid)))
        gs_old_w = gs_old_h
    else:
        gs_old_h, gs_old_w = grid_old_shape

    gs_h, gs_w = grid_new_shape
    posemb_grid = posemb_grid.reshape(1, gs_old_h, gs_old_w, -1).permute(0, 3, 1, 2)
    posemb_grid = F.interpolate(posemb_grid, size=(gs_h, gs_w), mode="bilinear")
    posemb_grid = posemb_grid.permute(0, 2, 3, 1).reshape(1, gs_h * gs_w, -1)
    posemb = torch.cat([posemb_tok, posemb_grid], dim=1)
    return posemb


def checkpoint_filter_fn(state_dict, model):
    """ convert patch embedding weight from manual patchify + linear proj to conv"""
    out_dict = {}
    if "model" in state_dict:
        # For deit models
        state_dict = state_dict["model"]
    num_extra_tokens = 1 + ("dist_token" in state_dict.keys())
    patch_size = model.patch_size
    image_size = model.patch_embed.image_size
    for k, v in state_dict.items():
        if k == "pos_embed" and v.shape != model.pos_embed.shape:
            # To resize pos embedding when using model at different size from pretrained weights
            v = resize_pos_embed(
                v,
                None,
                (image_size[0] // patch_size, image_size[1] // patch_size),
                num_extra_tokens,
            )
        out_dict[k] = v
    return out_dict


def padding(im, patch_size, fill_value=0):
    # make the image sizes divisible by patch_size
    H, W = im.size(2), im.size(3)
    pad_h, pad_w = 0, 0
    if H % patch_size > 0:
        pad_h = patch_size - (H % patch_size)
    if W % patch_size > 0:
        pad_w = patch_size - (W % patch_size)
    im_padded = im
    if pad_h > 0 or pad_w > 0:
        im_padded = F.pad(im, (0, pad_w, 0, pad_h), value=fill_value)
    return im_padded


def unpadding(y, target_size):
    H, W = target_size
    H_pad, W_pad = y.size(2), y.size(3)
    # crop predictions on extra pixels coming from padding
    extra_h = H_pad - H
    extra_w = W_pad - W
    if extra_h > 0:
        y = y[:, :, :-extra_h]
    if extra_w > 0:
        y = y[:, :, :, :-extra_w]
    return y


def resize(im, smaller_size):
    h, w = im.shape[2:]
    if h < w:
        ratio = w / h
        h_res, w_res = smaller_size, ratio * smaller_size
    else:
        ratio = h / w
        h_res, w_res = ratio * smaller_size, smaller_size
    if min(h, w) < smaller_size:
        im_res = F.interpolate(im, (int(h_res), int(w_res)), mode="bilinear")
    else:
        im_res = im
    return im_res


def sliding_window(im, flip, window_size, window_stride, channels_first=True):
    if channels_first:
        B, C, H, W = im.shape
    else:
        B, H, W, C = im.shape
    ws = window_size

    windows = {"crop": [], "anchors": []}
    h_anchors = torch.arange(0, H, window_stride)
    w_anchors = torch.arange(0, W, window_stride)
    h_anchors = [h.item() for h in h_anchors if h < H - ws] + [H - ws]
    w_anchors = [w.item() for w in w_anchors if w < W - ws] + [W - ws]
    for ha in h_anchors:
        for wa in w_anchors:
            if channels_first:
                window = im[:, :, ha: ha + ws, wa: wa + ws]
            else:
                window = im[:, ha: ha + ws, wa: wa + ws]
            windows["crop"].append(window)
            windows["anchors"].append((ha, wa))
    windows["flip"] = flip
    windows["shape"] = (H, W)
    return windows


def merge_windows(windows, window_size, ori_shape, no_softmax=False, no_upsample=False, patch_size=None):
    ws = window_size
    im_windows = windows["seg_maps"]
    anchors = windows["anchors"]
    C = im_windows[0].shape[0]
    H, W = windows["shape"]
    flip = windows["flip"]

    if no_upsample:
        H, W = H // patch_size, W // patch_size

    logit = torch.zeros((C, H, W), device=im_windows.device)
    count = torch.zeros((1, H, W), device=im_windows.device)
    for window, (ha, wa) in zip(im_windows, anchors):
        if no_upsample:
            ha = ha // patch_size
            wa = wa // patch_size
        logit[:, ha: ha + ws, wa: wa + ws] += window
        count[:, ha: ha + ws, wa: wa + ws] += 1
    logit /= count
    # print('Interpolate {} -> {}'.format(logit.shape, ori_shape))
    if not no_upsample:
        logit = F.interpolate(
            logit.unsqueeze(0),
            ori_shape,
            mode="bilinear",
        )[0]
    if flip:
        logit = torch.flip(logit, (2,))
    if not no_softmax:
        # print('Softmax in merge_windows')
        result = F.softmax(logit, 0)
    else:
        # print('No softmax in merge_windows')
        result = logit
    return result


def debug_windows(windows, debug_file):
    pass


def inference_picie(
        model,
        classifier,
        metric_test,
        ims,
        ori_shape,
        window_size,
        window_stride,
        batch_size,
        decoder_features=False,
        no_upsample=False,
        debug_file=None,
        im_rgb=None,
        channel_first=False
):
    try:
        C = model.n_cls
    except:
        C = classifier.module.bias.shape[0]

    # seg_maps = []

    # for im, im_metas in zip(ims, ims_metas):
    for im in ims:
        im = im.to('cuda')
        if len(im.shape) == 3:
            im = im.unsqueeze(0)
        flip = False  # im_metas["flip"]
        windows = sliding_window(im, flip, window_size, window_stride)
        crops = torch.stack(windows.pop("crop"))[:, 0]
        num_crops = len(crops)

        WB = batch_size if batch_size > 0 else num_crops
        if no_upsample:
            window_size = window_size // model.patch_size
        seg_maps = torch.zeros((num_crops, C, window_size, window_size), device=im.device)
        with torch.no_grad():
            for i in range(0, num_crops, WB):
                # try:
                feats = model.forward(crops[i: i + WB])
                if metric_test == 'cosine':
                    feats = F.normalize(feats, dim=1, p=2)
                probs = classifier(feats)
                probs = F.interpolate(probs, crops[i: i + WB].shape[-2:], mode='bilinear', align_corners=False)
                seg_maps[i: i + WB] = probs
        windows["seg_maps"] = seg_maps

        if debug_file is not None:
            if isinstance(im_rgb, torch.Tensor):
                im_rgb = im_rgb.detach().cpu().numpy()
            if len(im_rgb.shape) == 4:
                im_rgb = im_rgb[0]
            h, w = im.shape[-2:]
            im_rgb = cv2.resize(im_rgb, (w, h), interpolation=cv2.INTER_LINEAR)

            crops_rgb = np.stack(
                sliding_window(im_rgb[None, :], flip, window_size, window_stride, channels_first=channel_first).pop(
                    "crop"))[:, 0]

        im_seg_map = merge_windows(windows, window_size, ori_shape, no_softmax=decoder_features,
                                   no_upsample=no_upsample, patch_size=None)

        seg_map = im_seg_map
        if no_upsample and not decoder_features:
            pass
        else:
            seg_map = F.interpolate(
                seg_map.unsqueeze(0),
                ori_shape,
                mode="bilinear",
            )

    return seg_map


def inference(
        model,
        ims,
        ori_shape,
        window_size,
        window_stride,
        batch_size,
        decoder_features=False,
        encoder_features=False,
        save2cpu=False,
        no_upsample=False,
        debug_file=None,
        im_rgb=None,
        channel_first=False
):
    C = model.n_cls
    patch_size = model.patch_size

    # seg_maps = []

    # for im, im_metas in zip(ims, ims_metas):
    for im in ims:
        im = im.to('cuda')
        if len(im.shape) == 3:
            im = im.unsqueeze(0)
        # im = resize(im, window_size)
        flip = False  # im_metas["flip"]
        # print(im)
        windows = sliding_window(im, flip, window_size, window_stride)
        # print(windows)
        crops = torch.stack(windows.pop("crop"))[:, 0]
        num_crops = len(crops)

        WB = batch_size if batch_size > 0 else num_crops
        if no_upsample:
            window_size = window_size // model.patch_size
            # print('Change variable window_size to {}'.format(window_size))
        seg_maps = torch.zeros((num_crops, C, window_size, window_size), device=im.device)
        # print('Allocated segm_maps: {}, device: {}'.format(seg_maps.shape, seg_maps.device))
        with torch.no_grad():
            for i in range(0, num_crops, WB):
                # try:
                seg_maps[i: i + WB] = model.forward(crops[i: i + WB], decoder_features=decoder_features,
                                                    encoder_features=encoder_features,
                                                    no_upsample=no_upsample)
                # except:
                #     print('Input of shape: {}'.format(crops[i:i + WB].shape))
                #     assert False, "End after error."
                # torch.cuda.empty_cache()
        windows["seg_maps"] = seg_maps

        if debug_file is not None:
            if isinstance(im_rgb, torch.Tensor):
                im_rgb = im_rgb.detach().cpu().numpy()
            if len(im_rgb.shape) == 4:
                im_rgb = im_rgb[0]
            h, w = im.shape[-2:]
            im_rgb = cv2.resize(im_rgb, (w, h), interpolation=cv2.INTER_LINEAR)

            crops_rgb = np.stack(
                sliding_window(im_rgb[None, :], flip, window_size, window_stride, channels_first=channel_first).pop(
                    "crop"))[:, 0]

            windows_row = np.concatenate([w for w in crops_rgb], axis=1)
            # print(windows_row)
            try:
                Image.fromarray(windows_row).save(debug_file)
            except:
                pass

            suffix = debug_file[-4:]
            debug_file = debug_file.replace(suffix, '_preds{}'.format(suffix))
            windows_preds = seg_maps.argmax(dim=1).cpu().numpy()
            windows_preds_row = np.concatenate([seg2rgb(wp, C, 255) for wp in windows_preds], axis=1)
            windows_row_plus_preds = np.concatenate((windows_row, windows_preds_row), axis=0)
            try:
                Image.fromarray(windows_preds_row).save(debug_file)
            except:
                pass

            debug_file = debug_file.replace(suffix, '_wImg{}'.format(suffix))
            try:
                Image.fromarray(windows_row_plus_preds).save(debug_file)
            except:
                pass

        im_seg_map = merge_windows(windows, window_size, ori_shape, no_softmax=decoder_features,
                                   no_upsample=no_upsample, patch_size=model.patch_size)

        seg_map = im_seg_map
        if no_upsample and not decoder_features:
            pass
        else:
            seg_map = F.interpolate(
                seg_map.unsqueeze(0),
                ori_shape,
                mode="bilinear",
            )
        # seg_maps.append(seg_map)

    # print('Done one inference.')
    # seg_maps = torch.cat(seg_maps, dim=0)
    return seg_map


def inference_features(
        model,
        ims,
        ori_shape,
        window_size,
        window_stride,
        batch_size,
        decoder_features=False,
        encoder_features=False,
        save2cpu=False,
        no_upsample=True,
        encoder_only=False
):
    C = model.n_cls if decoder_features else model.encoder.d_model
    patch_size = model.patch_size

    # seg_maps = []

    # for im, im_metas in zip(ims, ims_metas):
    for im in ims:
        im = im.to('cuda')
        if len(im.shape) == 3:
            im = im.unsqueeze(0)
        # im = resize(im, window_size)
        flip = False  # im_metas["flip"]
        # print(im)
        windows = sliding_window(im, flip, window_size, window_stride)
        # print(windows)
        crops = torch.stack(windows.pop("crop"))[:, 0]
        num_crops = len(crops)

        WB = batch_size if batch_size > 0 else num_crops
        if no_upsample:
            window_size = window_size // model.patch_size
            # print('Change variable window_size to {}'.format(window_size))
        enc_maps = torch.zeros((num_crops, C, window_size, window_size), device=im.device)
        if decoder_features:
            dec_maps = torch.zeros((num_crops, C, window_size, window_size), device=im.device)
        # print('Allocated segm_maps: {}, device: {}'.format(seg_maps.shape, seg_maps.device))
        with torch.no_grad():
            for i in range(0, num_crops, WB):
                enc_fts = model.forward(crops[i: i + WB], decoder_features=decoder_features,
                                        encoder_features=True,
                                        no_upsample=no_upsample, encoder_only=encoder_only)
                if decoder_features:
                    enc_fts, dec_fts = enc_fts
                    dec_maps[i: i + WB] = dec_fts
                elif isinstance(enc_fts, tuple):
                    enc_fts = enc_fts[0]
                enc_maps[i: i + WB] = enc_fts

        windows["seg_maps"] = enc_maps
        im_enc_map = merge_windows(windows, window_size, ori_shape, no_softmax=decoder_features,
                                   no_upsample=no_upsample, patch_size=model.patch_size)

        if decoder_features:
            windows["seg_maps"] = dec_maps
            im_dec_map = merge_windows(windows, window_size, ori_shape, no_softmax=decoder_features,
                                       no_upsample=no_upsample, patch_size=model.patch_size)

        if no_upsample:
            pass
        else:
            im_enc_map = F.interpolate(
                im_enc_map.unsqueeze(0),
                ori_shape,
                mode="bilinear",
            )
            if decoder_features:
                im_dec_map = F.interpolate(
                    im_dec_map.unsqueeze(0),
                    ori_shape,
                    mode="bilinear",
                )

    im_enc_map = im_enc_map.cpu().numpy()
    if decoder_features:
        im_dec_map = im_dec_map.cpu().numpy()
        return im_enc_map, im_dec_map

    return im_enc_map


def inference_conv(
        model,
        ims,
        ims_metas,
        ori_shape
):
    assert len(ims) == 1
    for im, im_metas in zip(ims, ims_metas):
        im = im.to(ptu.device)
        if len(im.shape) < 4:
            im = im.unsqueeze(0)
        logits = model(im)
        if ori_shape[:2] != logits.shape[-2:]:
            # resize
            logits = F.interpolate(
                logits,
                ori_shape[-2:],
                mode="bilinear",
            )
        # 3) applies softmax
        result = F.softmax(logits.squeeze(), 0)
        # print(result.shape)
    return result


def num_params(model):
    model_parameters = filter(lambda p: p.requires_grad, model.parameters())
    n_params = sum([torch.prod(torch.tensor(p.size())) for p in model_parameters])
    if not type(n_params) == int:
        n_params = n_params.item()
    return n_params
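Sketch of the sliding-window plumbing used by inference() above (not part of the committed file): crop an image into overlapping windows, stand in random tensors for the per-window model output, then merge everything back to the original resolution.

    import torch
    from segmenter_model.utils import sliding_window, merge_windows

    im = torch.randn(1, 3, 512, 1024)
    windows = sliding_window(im, flip=False, window_size=256, window_stride=128)
    crops = torch.stack(windows.pop("crop"))[:, 0]               # (num_windows, 3, 256, 256)
    # stand-in for the segmenter's per-window class scores
    windows["seg_maps"] = torch.randn(len(crops), 19, 256, 256)
    probs = merge_windows(windows, window_size=256, ori_shape=(512, 1024))
    print(probs.shape)                                            # expected: (19, 512, 1024)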
segmenter_model/vit_dino.py
ADDED
@@ -0,0 +1,348 @@
# Copied from DINO
# Copyright (c) Facebook, Inc. and its affiliates.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Mostly copy-paste from timm library.
https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
"""
import math
import warnings
from functools import partial

import torch
import torch.nn as nn


def _no_grad_trunc_normal_(tensor, mean, std, a, b):
    # Cut & paste from PyTorch official master until it's in a few official releases - RW
    # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
    def norm_cdf(x):
        # Computes standard normal cumulative distribution function
        return (1. + math.erf(x / math.sqrt(2.))) / 2.

    if (mean < a - 2 * std) or (mean > b + 2 * std):
        warnings.warn("mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
                      "The distribution of values may be incorrect.",
                      stacklevel=2)

    with torch.no_grad():
        # Values are generated by using a truncated uniform distribution and
        # then using the inverse CDF for the normal distribution.
        # Get upper and lower cdf values
        l = norm_cdf((a - mean) / std)
        u = norm_cdf((b - mean) / std)

        # Uniformly fill tensor with values from [l, u], then translate to
        # [2l-1, 2u-1].
        tensor.uniform_(2 * l - 1, 2 * u - 1)

        # Use inverse cdf transform for normal distribution to get truncated
        # standard normal
        tensor.erfinv_()

        # Transform to proper mean, std
        tensor.mul_(std * math.sqrt(2.))
        tensor.add_(mean)

        # Clamp to ensure it's in the proper range
        tensor.clamp_(min=a, max=b)
        return tensor


def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.):
    # type: (Tensor, float, float, float, float) -> Tensor
    return _no_grad_trunc_normal_(tensor, mean, std, a, b)


def drop_path(x, drop_prob: float = 0., training: bool = False):
    if drop_prob == 0. or not training:
        return x
    keep_prob = 1 - drop_prob
    shape = (x.shape[0],) + (1,) * (x.ndim - 1)  # work with diff dim tensors, not just 2D ConvNets
    random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
    random_tensor.floor_()  # binarize
    output = x.div(keep_prob) * random_tensor
    return output


class DropPath(nn.Module):
    """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
    """

    def __init__(self, drop_prob=None):
        super(DropPath, self).__init__()
        self.drop_prob = drop_prob

    def forward(self, x):
        return drop_path(x, self.drop_prob, self.training)


class Mlp(nn.Module):
    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop(x)
        x = self.fc2(x)
        x = self.drop(x)
        return x


class Attention(nn.Module):
    def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = qk_scale or head_dim ** -0.5

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x):
        B, N, C = x.shape
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]

        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)

        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x, attn


class Block(nn.Module):
    def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
                 drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm):
        super().__init__()
        self.norm1 = norm_layer(dim)
        self.attn = Attention(
            dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)

    def forward(self, x, return_attention=False):
        y, attn = self.attn(self.norm1(x))
        if return_attention:
            return attn
        x = x + self.drop_path(y)
        x = x + self.drop_path(self.mlp(self.norm2(x)))
        return x


class PatchEmbed(nn.Module):
    """ Image to Patch Embedding
    """

    def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):
        super().__init__()
        num_patches = (img_size // patch_size) * (img_size // patch_size)
        self.img_size = img_size
        self.patch_size = patch_size
        self.num_patches = num_patches

        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)

    def forward(self, x):
        B, C, H, W = x.shape
        x = self.proj(x).flatten(2).transpose(1, 2)
        return x


class VisionTransformer(nn.Module):
    """ Vision Transformer """

    def __init__(self, img_size=[224], patch_size=16, in_chans=3, num_classes=0, embed_dim=768, depth=12,
                 num_heads=12, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop_rate=0., attn_drop_rate=0.,
                 drop_path_rate=0., norm_layer=nn.LayerNorm, **kwargs):
        super().__init__()
        self.num_features = self.embed_dim = embed_dim

        self.patch_embed = PatchEmbed(
            img_size=img_size[0], patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
        num_patches = self.patch_embed.num_patches

        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
        self.pos_drop = nn.Dropout(p=drop_rate)

        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]  # stochastic depth decay rule
        self.blocks = nn.ModuleList([
            Block(
                dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
                drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer)
            for i in range(depth)])
        self.norm = norm_layer(embed_dim)

        # Classifier head
        self.head = nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity()

        trunc_normal_(self.pos_embed, std=.02)
        trunc_normal_(self.cls_token, std=.02)
        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def interpolate_pos_encoding(self, x, w, h):
        npatch = x.shape[1] - 1
        N = self.pos_embed.shape[1] - 1
        if npatch == N and w == h:
            return self.pos_embed
        class_pos_embed = self.pos_embed[:, 0]
        patch_pos_embed = self.pos_embed[:, 1:]
        dim = x.shape[-1]
        w0 = w // self.patch_embed.patch_size
        h0 = h // self.patch_embed.patch_size
        # we add a small number to avoid floating point error in the interpolation
        # see discussion at https://github.com/facebookresearch/dino/issues/8
        w0, h0 = w0 + 0.1, h0 + 0.1
        patch_pos_embed = nn.functional.interpolate(
            patch_pos_embed.reshape(1, int(math.sqrt(N)), int(math.sqrt(N)), dim).permute(0, 3, 1, 2),
            scale_factor=(w0 / math.sqrt(N), h0 / math.sqrt(N)),
            mode='bicubic',
        )
        assert int(w0) == patch_pos_embed.shape[-2] and int(h0) == patch_pos_embed.shape[-1]
        patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
        return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1)

    def prepare_tokens(self, x):
        B, nc, w, h = x.shape
        x = self.patch_embed(x)  # patch linear embedding

        # add the [CLS] token to the embed patch tokens
        cls_tokens = self.cls_token.expand(B, -1, -1)
        x = torch.cat((cls_tokens, x), dim=1)

        # add positional encoding to each token
        x = x + self.interpolate_pos_encoding(x, w, h)

        return self.pos_drop(x)

    def forward(self, x):
        x = self.prepare_tokens(x)
        for blk in self.blocks:
            x = blk(x)
        x = self.norm(x)
        return x[:, 0]

    def get_last_selfattention(self, x):
        x = self.prepare_tokens(x)
        for i, blk in enumerate(self.blocks):
            if i < len(self.blocks) - 1:
                x = blk(x)
            else:
                # return attention of the last block
                return blk(x, return_attention=True)

    def get_n_last_selfattentions(self, x, layers_from_end=(1,)):
        x = self.prepare_tokens(x)
        attentions = []
        for i, blk in enumerate(self.blocks):
            num_from_end = len(self.blocks) - i
            if num_from_end in layers_from_end:
                # get attention of the block
                attn = blk(x, return_attention=True)
                attentions.append(attn)
            x = blk(x)
        return attentions

    def get_intermediate_layers(self, x, n=1):
        x = self.prepare_tokens(x)
        # we return the output tokens from the `n` last blocks
        output = []
        for i, blk in enumerate(self.blocks):
            x = blk(x)
            if len(self.blocks) - i <= n:
                output.append(self.norm(x))
        return output


def vit_tiny(patch_size=16, **kwargs):
    model = VisionTransformer(
        patch_size=patch_size, embed_dim=192, depth=12, num_heads=3, mlp_ratio=4,
        qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
    return model


def vit_small(patch_size=16, **kwargs):
    model = VisionTransformer(
        patch_size=patch_size, embed_dim=384, depth=12, num_heads=6, mlp_ratio=4,
        qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
    return model


def vit_base(patch_size=16, **kwargs):
    model = VisionTransformer(
        patch_size=patch_size, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4,
        qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
    return model


class DINOHead(nn.Module):
    def __init__(self, in_dim, out_dim, use_bn=False, norm_last_layer=True, nlayers=3, hidden_dim=2048,
                 bottleneck_dim=256):
        super().__init__()
        nlayers = max(nlayers, 1)
        if nlayers == 1:
            self.mlp = nn.Linear(in_dim, bottleneck_dim)
        else:
            layers = [nn.Linear(in_dim, hidden_dim)]
            if use_bn:
                layers.append(nn.BatchNorm1d(hidden_dim))
            layers.append(nn.GELU())
            for _ in range(nlayers - 2):
                layers.append(nn.Linear(hidden_dim, hidden_dim))
                if use_bn:
                    layers.append(nn.BatchNorm1d(hidden_dim))
                layers.append(nn.GELU())
            layers.append(nn.Linear(hidden_dim, bottleneck_dim))
            self.mlp = nn.Sequential(*layers)
        self.apply(self._init_weights)
        self.last_layer = nn.utils.weight_norm(nn.Linear(bottleneck_dim, out_dim, bias=False))
        self.last_layer.weight_g.data.fill_(1)
        if norm_last_layer:
            self.last_layer.weight_g.requires_grad = False

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)

    def forward(self, x):
        x = self.mlp(x)
        x = nn.functional.normalize(x, dim=-1, p=2)
        x = self.last_layer(x)
        return x
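Sketch only (not part of the committed file): instantiate the randomly initialised DINO ViT-S/16 defined above and inspect its two main outputs; the shapes are what the code above implies for a 224x224 input.

    import torch
    from segmenter_model.vit_dino import vit_small

    vit = vit_small(patch_size=16)
    x = torch.randn(1, 3, 224, 224)
    cls = vit(x)                           # CLS embedding
    attn = vit.get_last_selfattention(x)   # attention of the last block
    print(cls.shape, attn.shape)           # expected: (1, 384) and (1, 6, 197, 197)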