vobecant commited on
Commit
dd78229
·
1 Parent(s): 83a95c0

Initial commit

Browse files
.idea/DaS.iml ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ <?xml version="1.0" encoding="UTF-8"?>
2
+ <module type="PYTHON_MODULE" version="4">
3
+ <component name="NewModuleRootManager">
4
+ <content url="file://$MODULE_DIR$" />
5
+ <orderEntry type="jdk" jdkName="Python 3.8 (pytorch) (2)" jdkType="Python SDK" />
6
+ <orderEntry type="sourceFolder" forTests="false" />
7
+ </component>
8
+ </module>
.idea/inspectionProfiles/Project_Default.xml ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ <component name="InspectionProjectProfileManager">
2
+ <profile version="1.0">
3
+ <option name="myName" value="Project Default" />
4
+ <inspection_tool class="PyPackageRequirementsInspection" enabled="true" level="WARNING" enabled_by_default="true">
5
+ <option name="ignoredPackages">
6
+ <value>
7
+ <list size="13">
8
+ <item index="0" class="java.lang.String" itemvalue="yacs" />
9
+ <item index="1" class="java.lang.String" itemvalue="termcolor" />
10
+ <item index="2" class="java.lang.String" itemvalue="pydot" />
11
+ <item index="3" class="java.lang.String" itemvalue="fvcore" />
12
+ <item index="4" class="java.lang.String" itemvalue="tabulate" />
13
+ <item index="5" class="java.lang.String" itemvalue="mock" />
14
+ <item index="6" class="java.lang.String" itemvalue="pycocotools" />
15
+ <item index="7" class="java.lang.String" itemvalue="prettytable" />
16
+ <item index="8" class="java.lang.String" itemvalue="interrogate" />
17
+ <item index="9" class="java.lang.String" itemvalue="cityscapesscripts" />
18
+ <item index="10" class="java.lang.String" itemvalue="isort" />
19
+ <item index="11" class="java.lang.String" itemvalue="xdoctest" />
20
+ <item index="12" class="java.lang.String" itemvalue="codecov" />
21
+ </list>
22
+ </value>
23
+ </option>
24
+ </inspection_tool>
25
+ </profile>
26
+ </component>
.idea/inspectionProfiles/profiles_settings.xml ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ <component name="InspectionProjectProfileManager">
2
+ <settings>
3
+ <option name="USE_PROJECT_PROFILE" value="false" />
4
+ <version value="1.0" />
5
+ </settings>
6
+ </component>
.idea/misc.xml ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ <?xml version="1.0" encoding="UTF-8"?>
2
+ <project version="4">
3
+ <component name="ProjectRootManager" version="2" project-jdk-name="Python 3.8 (pytorch) (2)" project-jdk-type="Python SDK" />
4
+ </project>
.idea/modules.xml ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ <?xml version="1.0" encoding="UTF-8"?>
2
+ <project version="4">
3
+ <component name="ProjectModuleManager">
4
+ <modules>
5
+ <module fileurl="file://$PROJECT_DIR$/.idea/DaS.iml" filepath="$PROJECT_DIR$/.idea/DaS.iml" />
6
+ </modules>
7
+ </component>
8
+ </project>
.idea/vcs.xml ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ <?xml version="1.0" encoding="UTF-8"?>
2
+ <project version="4">
3
+ <component name="VcsDirectoryMappings">
4
+ <mapping directory="$PROJECT_DIR$" vcs="Git" />
5
+ </component>
6
+ </project>
.idea/workspace.xml ADDED
@@ -0,0 +1,133 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ <?xml version="1.0" encoding="UTF-8"?>
2
+ <project version="4">
3
+ <component name="ChangeListManager">
4
+ <list default="true" id="5dd22f22-8223-4d55-99f9-57d1e00622d7" name="Default Changelist" comment="Initial commit.">
5
+ <change afterPath="$PROJECT_DIR$/app.py" afterDir="false" />
6
+ <change afterPath="$PROJECT_DIR$/examples/img1.jpg" afterDir="false" />
7
+ <change afterPath="$PROJECT_DIR$/requirements.txt" afterDir="false" />
8
+ <change afterPath="$PROJECT_DIR$/segmenter_model/backbone_picie.py" afterDir="false" />
9
+ <change afterPath="$PROJECT_DIR$/segmenter_model/blocks.py" afterDir="false" />
10
+ <change afterPath="$PROJECT_DIR$/segmenter_model/decoder.py" afterDir="false" />
11
+ <change afterPath="$PROJECT_DIR$/segmenter_model/factory.py" afterDir="false" />
12
+ <change afterPath="$PROJECT_DIR$/segmenter_model/fpn_picie.py" afterDir="false" />
13
+ <change afterPath="$PROJECT_DIR$/segmenter_model/picie_model.py" afterDir="false" />
14
+ <change afterPath="$PROJECT_DIR$/segmenter_model/resnet_dilated.py" afterDir="false" />
15
+ <change afterPath="$PROJECT_DIR$/segmenter_model/segmenter.py" afterDir="false" />
16
+ <change afterPath="$PROJECT_DIR$/segmenter_model/torch.py" afterDir="false" />
17
+ <change afterPath="$PROJECT_DIR$/segmenter_model/utils.py" afterDir="false" />
18
+ <change afterPath="$PROJECT_DIR$/segmenter_model/vit_dino.py" afterDir="false" />
19
+ <change beforePath="$PROJECT_DIR$/README.md" beforeDir="false" afterPath="$PROJECT_DIR$/README.md" afterDir="false" />
20
+ </list>
21
+ <option name="SHOW_DIALOG" value="false" />
22
+ <option name="HIGHLIGHT_CONFLICTS" value="true" />
23
+ <option name="HIGHLIGHT_NON_ACTIVE_CHANGELIST" value="false" />
24
+ <option name="LAST_RESOLUTION" value="IGNORE" />
25
+ </component>
26
+ <component name="FileTemplateManagerImpl">
27
+ <option name="RECENT_TEMPLATES">
28
+ <list>
29
+ <option value="Python Script" />
30
+ </list>
31
+ </option>
32
+ </component>
33
+ <component name="Git.Settings">
34
+ <option name="RECENT_GIT_ROOT_PATH" value="$PROJECT_DIR$" />
35
+ <option name="UPDATE_TYPE" value="REBASE" />
36
+ </component>
37
+ <component name="ProjectId" id="26QLDSf8iYKDlLRah6kIg09oqIa" />
38
+ <component name="ProjectLevelVcsManager" settingsEditedManually="true" />
39
+ <component name="ProjectViewState">
40
+ <option name="hideEmptyMiddlePackages" value="true" />
41
+ <option name="showLibraryContents" value="true" />
42
+ </component>
43
+ <component name="PropertiesComponent">
44
+ <property name="RunOnceActivity.OpenProjectViewOnStart" value="true" />
45
+ <property name="RunOnceActivity.ShowReadmeOnStart" value="true" />
46
+ <property name="WebServerToolWindowFactoryState" value="true" />
47
+ <property name="last_opened_file_path" value="$PROJECT_DIR$" />
48
+ <property name="settings.editor.selected.configurable" value="com.jetbrains.python.configuration.PyActiveSdkModuleConfigurable" />
49
+ </component>
50
+ <component name="RecentsManager">
51
+ <key name="CopyFile.RECENT_KEYS">
52
+ <recent name="$PROJECT_DIR$" />
53
+ <recent name="$PROJECT_DIR$/examples" />
54
+ </key>
55
+ <key name="MoveFile.RECENT_KEYS">
56
+ <recent name="$PROJECT_DIR$/examples" />
57
+ </key>
58
+ </component>
59
+ <component name="SpellCheckerSettings" RuntimeDictionaries="0" Folders="0" CustomDictionaries="0" DefaultDictionary="application-level" UseSingleDictionary="true" transferred="true" />
60
+ <component name="TaskManager">
61
+ <task active="true" id="Default" summary="Default task">
62
+ <changelist id="5dd22f22-8223-4d55-99f9-57d1e00622d7" name="Default Changelist" comment="" />
63
+ <created>1647350746642</created>
64
+ <option name="number" value="Default" />
65
+ <option name="presentableId" value="Default" />
66
+ <updated>1647350746642</updated>
67
+ <workItem from="1647350750956" duration="4327000" />
68
+ </task>
69
+ <task id="LOCAL-00001" summary="Initial commit.">
70
+ <created>1647352693910</created>
71
+ <option name="number" value="00001" />
72
+ <option name="presentableId" value="LOCAL-00001" />
73
+ <option name="project" value="LOCAL" />
74
+ <updated>1647352693910</updated>
75
+ </task>
76
+ <task id="LOCAL-00002" summary="Initial commit.">
77
+ <created>1647353059401</created>
78
+ <option name="number" value="00002" />
79
+ <option name="presentableId" value="LOCAL-00002" />
80
+ <option name="project" value="LOCAL" />
81
+ <updated>1647353059401</updated>
82
+ </task>
83
+ <task id="LOCAL-00003" summary="Added gitignore.">
84
+ <created>1647353514970</created>
85
+ <option name="number" value="00003" />
86
+ <option name="presentableId" value="LOCAL-00003" />
87
+ <option name="project" value="LOCAL" />
88
+ <updated>1647353514970</updated>
89
+ </task>
90
+ <task id="LOCAL-00004" summary="Added gitignore.">
91
+ <created>1647353622389</created>
92
+ <option name="number" value="00004" />
93
+ <option name="presentableId" value="LOCAL-00004" />
94
+ <option name="project" value="LOCAL" />
95
+ <updated>1647353622389</updated>
96
+ </task>
97
+ <task id="LOCAL-00005" summary="Added gitignore.">
98
+ <created>1647353674966</created>
99
+ <option name="number" value="00005" />
100
+ <option name="presentableId" value="LOCAL-00005" />
101
+ <option name="project" value="LOCAL" />
102
+ <updated>1647353674966</updated>
103
+ </task>
104
+ <task id="LOCAL-00006" summary="Initial commit.">
105
+ <created>1647354226094</created>
106
+ <option name="number" value="00006" />
107
+ <option name="presentableId" value="LOCAL-00006" />
108
+ <option name="project" value="LOCAL" />
109
+ <updated>1647354226094</updated>
110
+ </task>
111
+ <option name="localTasksCounter" value="7" />
112
+ <servers />
113
+ </component>
114
+ <component name="TypeScriptGeneratedFilesManager">
115
+ <option name="version" value="3" />
116
+ </component>
117
+ <component name="Vcs.Log.Tabs.Properties">
118
+ <option name="TAB_STATES">
119
+ <map>
120
+ <entry key="MAIN">
121
+ <value>
122
+ <State />
123
+ </value>
124
+ </entry>
125
+ </map>
126
+ </option>
127
+ </component>
128
+ <component name="VcsManagerConfiguration">
129
+ <MESSAGE value="Added gitignore." />
130
+ <MESSAGE value="Initial commit." />
131
+ <option name="LAST_COMMIT_MESSAGE" value="Initial commit." />
132
+ </component>
133
+ </project>
README.md CHANGED
@@ -1,12 +1 @@
1
- ---
2
- title: DaS
3
- emoji: 💻
4
- colorFrom: pink
5
- colorTo: pink
6
- sdk: gradio
7
- sdk_version: 2.8.10
8
- app_file: app.py
9
- pinned: false
10
- ---
11
-
12
  Check out the configuration reference at https://huggingface.co/docs/hub/spaces#reference
 
 
 
 
 
 
 
 
 
 
 
 
1
  Check out the configuration reference at https://huggingface.co/docs/hub/spaces#reference
app.py ADDED
@@ -0,0 +1,151 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import numpy as np
3
+ import requests
4
+ import torch
5
+ import yaml
6
+ from PIL import Image
7
+ from torchvision import transforms
8
+
9
+ from segmenter_model import utils
10
+ from segmenter_model.factory import create_segmenter
11
+ from segmenter_model.fpn_picie import PanopticFPN
12
+ from segmenter_model.utils import colorize_one, map2cs
13
+
14
+ WEIGHTS = './weights/segmenter.pth'
15
+
16
+
17
+ def download_file_from_google_drive(id, destination):
18
+ def get_confirm_token(response):
19
+ for key, value in response.cookies.items():
20
+ if key.startswith('download_warning'):
21
+ return value
22
+
23
+ return None
24
+
25
+ def save_response_content(response, destination):
26
+ CHUNK_SIZE = 32768
27
+
28
+ with open(destination, "wb") as f:
29
+ for chunk in response.iter_content(CHUNK_SIZE):
30
+ if chunk: # filter out keep-alive new chunks
31
+ f.write(chunk)
32
+
33
+ URL = "https://docs.google.com/uc?export=download"
34
+
35
+ session = requests.Session()
36
+
37
+ response = session.get(URL, params={'id': id}, stream=True)
38
+ token = get_confirm_token(response)
39
+
40
+ if token:
41
+ params = {'id': id, 'confirm': token}
42
+ response = session.get(URL, params=params, stream=True)
43
+
44
+ save_response_content(response, destination)
45
+
46
+
47
+ def segment_segmenter(image, model, window_size, window_stride, encoder_features=False, decoder_features=False,
48
+ no_upsample=False, batch_size=2):
49
+ seg_pred = utils.inference(
50
+ model,
51
+ image,
52
+ image.shape[-2:],
53
+ window_size,
54
+ window_stride,
55
+ batch_size=batch_size,
56
+ no_upsample=no_upsample,
57
+ encoder_features=encoder_features,
58
+ decoder_features=decoder_features
59
+ )
60
+ if not (encoder_features or decoder_features):
61
+ seg_pred = seg_pred.argmax(1).unsqueeze(1)
62
+ return seg_pred
63
+
64
+
65
+ def remap(seg_pred, ignore=255):
66
+ mapping = {0: 0, 12: 1, 15: 2, 23: 3, 10: 4, 14: 5, 18: 6, 2: 7, 17: 8, 13: 9, 8: 10, 3: 11, 27: 12, 4: 13, 25: 14,
67
+ 24: 15, 6: 16, 22: 17, 28: 18}
68
+ h, w = seg_pred.shape[-2:]
69
+ seg_pred_remap = np.ones((h, w), dtype=np.uint8) * ignore
70
+ for pseudo, gt in mapping.items():
71
+ whr = seg_pred == pseudo
72
+ seg_pred_remap[whr] = gt
73
+ return seg_pred_remap
74
+
75
+
76
+ def create_model(resnet=False):
77
+ weights_path = WEIGHTS
78
+ variant_path = '{}_variant.yml'.format(weights_path)
79
+
80
+ print('Use weights {}'.format(weights_path))
81
+ print('Load variant from {}'.format(variant_path))
82
+ variant = yaml.load(
83
+ open(variant_path, "r"), Loader=yaml.FullLoader
84
+ )
85
+
86
+ # TODO: parse hyperparameters
87
+ window_size = variant['inference_kwargs']["window_size"]
88
+ window_stride = variant['inference_kwargs']["window_stride"]
89
+ dataset_kwargs = variant['dataset_kwargs']
90
+ net_kwargs = variant["net_kwargs"]
91
+ net_kwargs['n_cls'] = dataset_kwargs['nlabels']
92
+
93
+ dataset_kwargs = variant['dataset_kwargs']
94
+
95
+ net_kwargs = variant["net_kwargs"]
96
+ net_kwargs['n_cls'] = dataset_kwargs['nlabels']
97
+ if not resnet:
98
+ net_kwargs['decoder']['dropout'] = 0.
99
+
100
+ # TODO: create model
101
+ if resnet:
102
+ model = PanopticFPN(arch=net_kwargs['backbone'], pretrain=net_kwargs['pretrain'], n_cls=net_kwargs['n_cls'])
103
+ else:
104
+ model = create_segmenter(net_kwargs)
105
+
106
+ # TODO: load weights
107
+ print('Load weights from {}'.format(weights_path))
108
+ weights = torch.load(weights_path)['model']
109
+ model.load_state_dict(weights, strict=True)
110
+
111
+ model.eval()
112
+
113
+ return model, window_size, window_stride
114
+
115
+
116
+ def get_transformations():
117
+ return transforms.Compose([
118
+ transforms.ToTensor(),
119
+ transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])])
120
+
121
+
122
+ model, window_size, window_stride = create_model()
123
+
124
+
125
+ def predict(input_img):
126
+ input_img = Image.open(input_img)
127
+ transform = transforms.Compose([transforms.Resize(256, Image.BICUBIC), transforms.ToTensor()])
128
+ input_img = transform(input_img)
129
+ input_img = torch.unsqueeze(input_img, 0)
130
+
131
+ with torch.no_grad():
132
+ segmentation = segment_segmenter(input_img, model, window_size, window_stride).squeeze().detach()
133
+ segmentation_remap = remap(segmentation)
134
+
135
+ drawing_pseudo = colorize_one(segmentation_remap)
136
+ drawing_cs = map2cs(segmentation_remap)
137
+
138
+ drawing_pseudo = transforms.ToPILImage()(drawing_pseudo)
139
+ drawing_cs = transforms.ToPILImage()(drawing_cs)
140
+ return drawing_pseudo, drawing_cs
141
+
142
+
143
+ title = "Drive&Segment"
144
+ description = 'Gradio Demo accompanying paper "Drive&Segment: Unsupervised Semantic Segmentation of Urban Scenes via Cross-modal Distillation"'
145
+ # article = "<p style='text-align: center'><a href='TODO' target='_blank'>Project Page</a> | <a href='codelink' target='_blank'>Github</a></p>"
146
+ examples = [['examples/img1.jpg']]
147
+
148
+ iface = gr.Interface(predict, gr.inputs.Image(type='filepath'), "image", title=title, description=description,
149
+ examples=examples)
150
+
151
+ iface.launch()
examples/img1.jpg ADDED
requirements.txt ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ torch
2
+ torchvision
3
+ PIL
4
+ timm
5
+ yaml
6
+ einops
segmenter_model/backbone_picie.py ADDED
@@ -0,0 +1,348 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch.nn as nn
2
+
3
+ try:
4
+ from torchvision.models.utils import load_state_dict_from_url
5
+ except:
6
+ from torch.utils.model_zoo import load_url as load_state_dict_from_url
7
+
8
+ __all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101',
9
+ 'resnet152', 'resnext50_32x4d', 'resnext101_32x8d',
10
+ 'wide_resnet50_2', 'wide_resnet101_2']
11
+
12
+ model_urls = {
13
+ 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
14
+ 'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth',
15
+ 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
16
+ 'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',
17
+ 'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth',
18
+ 'resnext50_32x4d': 'https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth',
19
+ 'resnext101_32x8d': 'https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth',
20
+ 'wide_resnet50_2': 'https://download.pytorch.org/models/wide_resnet50_2-95faca4d.pth',
21
+ 'wide_resnet101_2': 'https://download.pytorch.org/models/wide_resnet101_2-32ee1156.pth',
22
+ }
23
+
24
+
25
+ def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
26
+ """3x3 convolution with padding"""
27
+ return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
28
+ padding=dilation, groups=groups, bias=False, dilation=dilation)
29
+
30
+
31
+ def conv1x1(in_planes, out_planes, stride=1):
32
+ """1x1 convolution"""
33
+ return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
34
+
35
+
36
+ class BasicBlock(nn.Module):
37
+ expansion = 1
38
+
39
+ def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
40
+ base_width=64, dilation=1, norm_layer=None):
41
+ super(BasicBlock, self).__init__()
42
+ if norm_layer is None:
43
+ norm_layer = nn.BatchNorm2d
44
+ if groups != 1 or base_width != 64:
45
+ raise ValueError('BasicBlock only supports groups=1 and base_width=64')
46
+ if dilation > 1:
47
+ raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
48
+ # Both self.conv1 and self.downsample layers downsample the input when stride != 1
49
+ self.conv1 = conv3x3(inplanes, planes, stride)
50
+ self.bn1 = norm_layer(planes)
51
+ self.relu = nn.ReLU(inplace=True)
52
+ self.conv2 = conv3x3(planes, planes)
53
+ self.bn2 = norm_layer(planes)
54
+ self.downsample = downsample
55
+ self.stride = stride
56
+
57
+ def forward(self, x):
58
+ identity = x
59
+
60
+ out = self.conv1(x)
61
+ out = self.bn1(out)
62
+ out = self.relu(out)
63
+
64
+ out = self.conv2(out)
65
+ out = self.bn2(out)
66
+
67
+ if self.downsample is not None:
68
+ identity = self.downsample(x)
69
+
70
+ out += identity
71
+ out = self.relu(out)
72
+
73
+ return out
74
+
75
+
76
+ class Bottleneck(nn.Module):
77
+ # Bottleneck in torchvision places the stride for downsampling at 3x3 convolution(self.conv2)
78
+ # while original implementation places the stride at the first 1x1 convolution(self.conv1)
79
+ # according to "Deep residual learning for image recognition"https://arxiv.org/abs/1512.03385.
80
+ # This variant is also known as ResNet V1.5 and improves accuracy according to
81
+ # https://ngc.nvidia.com/catalog/model-scripts/nvidia:resnet_50_v1_5_for_pytorch.
82
+
83
+ expansion = 4
84
+
85
+ def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
86
+ base_width=64, dilation=1, norm_layer=None):
87
+ super(Bottleneck, self).__init__()
88
+ if norm_layer is None:
89
+ norm_layer = nn.BatchNorm2d
90
+ width = int(planes * (base_width / 64.)) * groups
91
+ # Both self.conv2 and self.downsample layers downsample the input when stride != 1
92
+ self.conv1 = conv1x1(inplanes, width)
93
+ self.bn1 = norm_layer(width)
94
+ self.conv2 = conv3x3(width, width, stride, groups, dilation)
95
+ self.bn2 = norm_layer(width)
96
+ self.conv3 = conv1x1(width, planes * self.expansion)
97
+ self.bn3 = norm_layer(planes * self.expansion)
98
+ self.relu = nn.ReLU(inplace=True)
99
+ self.downsample = downsample
100
+ self.stride = stride
101
+
102
+ def forward(self, x):
103
+ identity = x
104
+
105
+ out = self.conv1(x)
106
+ out = self.bn1(out)
107
+ out = self.relu(out)
108
+
109
+ out = self.conv2(out)
110
+ out = self.bn2(out)
111
+ out = self.relu(out)
112
+
113
+ out = self.conv3(out)
114
+ out = self.bn3(out)
115
+
116
+ if self.downsample is not None:
117
+ identity = self.downsample(x)
118
+
119
+ out += identity
120
+ out = self.relu(out)
121
+
122
+ return out
123
+
124
+
125
+ class ResNet(nn.Module):
126
+
127
+ def __init__(self, block, layers, zero_init_residual=False,
128
+ groups=1, width_per_group=64, replace_stride_with_dilation=None,
129
+ norm_layer=None):
130
+ super(ResNet, self).__init__()
131
+ if norm_layer is None:
132
+ norm_layer = nn.BatchNorm2d
133
+ self._norm_layer = norm_layer
134
+
135
+ self.inplanes = 64
136
+ self.dilation = 1
137
+ if replace_stride_with_dilation is None:
138
+ # each element in the tuple indicates if we should replace
139
+ # the 2x2 stride with a dilated convolution instead
140
+ replace_stride_with_dilation = [False, False, False]
141
+ if len(replace_stride_with_dilation) != 3:
142
+ raise ValueError("replace_stride_with_dilation should be None "
143
+ "or a 3-element tuple, got {}".format(replace_stride_with_dilation))
144
+ self.groups = groups
145
+ self.base_width = width_per_group
146
+ self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3,
147
+ bias=False)
148
+ self.bn1 = norm_layer(self.inplanes)
149
+ self.relu = nn.ReLU(inplace=True)
150
+ self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
151
+ self.layer1 = self._make_layer(block, 64, layers[0])
152
+ self.layer2 = self._make_layer(block, 128, layers[1], stride=2,
153
+ dilate=replace_stride_with_dilation[0])
154
+ self.layer3 = self._make_layer(block, 256, layers[2], stride=2,
155
+ dilate=replace_stride_with_dilation[1])
156
+ self.layer4 = self._make_layer(block, 512, layers[3], stride=2,
157
+ dilate=replace_stride_with_dilation[2])
158
+ # self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
159
+ # self.fc = nn.Linear(512 * block.expansion, num_classes)
160
+
161
+ for m in self.modules():
162
+ if isinstance(m, nn.Conv2d):
163
+ nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
164
+ elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
165
+ nn.init.constant_(m.weight, 1)
166
+ nn.init.constant_(m.bias, 0)
167
+
168
+ # Zero-initialize the last BN in each residual branch,
169
+ # so that the residual branch starts with zeros, and each residual block behaves like an identity.
170
+ # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
171
+ if zero_init_residual:
172
+ for m in self.modules():
173
+ if isinstance(m, Bottleneck):
174
+ nn.init.constant_(m.bn3.weight, 0)
175
+ elif isinstance(m, BasicBlock):
176
+ nn.init.constant_(m.bn2.weight, 0)
177
+
178
+ def _make_layer(self, block, planes, blocks, stride=1, dilate=False):
179
+ norm_layer = self._norm_layer
180
+ downsample = None
181
+ previous_dilation = self.dilation
182
+ if dilate:
183
+ self.dilation *= stride
184
+ stride = 1
185
+ if stride != 1 or self.inplanes != planes * block.expansion:
186
+ downsample = nn.Sequential(
187
+ conv1x1(self.inplanes, planes * block.expansion, stride),
188
+ norm_layer(planes * block.expansion),
189
+ )
190
+
191
+ layers = []
192
+ layers.append(block(self.inplanes, planes, stride, downsample, self.groups,
193
+ self.base_width, previous_dilation, norm_layer))
194
+ self.inplanes = planes * block.expansion
195
+ for _ in range(1, blocks):
196
+ layers.append(block(self.inplanes, planes, groups=self.groups,
197
+ base_width=self.base_width, dilation=self.dilation,
198
+ norm_layer=norm_layer))
199
+
200
+ return nn.Sequential(*layers)
201
+
202
+ def _forward_impl(self, x):
203
+ outputs = {}
204
+ # See note [TorchScript super()]
205
+ x = self.conv1(x)
206
+ x = self.bn1(x)
207
+ x = self.relu(x)
208
+ x = self.maxpool(x)
209
+ # outputs['stem'] = x
210
+
211
+ x = self.layer1(x) # 1/4
212
+ outputs['res2'] = x
213
+
214
+ x = self.layer2(x) # 1/8
215
+ outputs['res3'] = x
216
+
217
+ x = self.layer3(x) # 1/16
218
+ outputs['res4'] = x
219
+
220
+ x = self.layer4(x) # 1/32
221
+ outputs['res5'] = x
222
+
223
+ return outputs
224
+
225
+ def forward(self, x):
226
+ return self._forward_impl(x)
227
+
228
+
229
+ def _resnet(arch, block, layers, pretrained, progress, **kwargs):
230
+ model = ResNet(block, layers, **kwargs)
231
+ if pretrained:
232
+ state_dict = load_state_dict_from_url(model_urls[arch],
233
+ progress=progress)
234
+ model.load_state_dict(state_dict, strict=False)
235
+ return model
236
+
237
+
238
+ def resnet18(pretrained=False, progress=True, **kwargs):
239
+ r"""ResNet-18 model from
240
+ `"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_
241
+ Args:
242
+ pretrained (bool): If True, returns a model pre-trained on ImageNet
243
+ progress (bool): If True, displays a progress bar of the download to stderr
244
+ """
245
+ return _resnet('resnet18', BasicBlock, [2, 2, 2, 2], pretrained, progress,
246
+ **kwargs)
247
+
248
+
249
+ def resnet34(pretrained=False, progress=True, **kwargs):
250
+ r"""ResNet-34 model from
251
+ `"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_
252
+ Args:
253
+ pretrained (bool): If True, returns a model pre-trained on ImageNet
254
+ progress (bool): If True, displays a progress bar of the download to stderr
255
+ """
256
+ return _resnet('resnet34', BasicBlock, [3, 4, 6, 3], pretrained, progress,
257
+ **kwargs)
258
+
259
+
260
+ def resnet50(pretrained=False, progress=True, **kwargs):
261
+ r"""ResNet-50 model from
262
+ `"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_
263
+ Args:
264
+ pretrained (bool): If True, returns a model pre-trained on ImageNet
265
+ progress (bool): If True, displays a progress bar of the download to stderr
266
+ """
267
+ return _resnet('resnet50', Bottleneck, [3, 4, 6, 3], pretrained, progress,
268
+ **kwargs)
269
+
270
+
271
+ def resnet101(pretrained=False, progress=True, **kwargs):
272
+ r"""ResNet-101 model from
273
+ `"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_
274
+ Args:
275
+ pretrained (bool): If True, returns a model pre-trained on ImageNet
276
+ progress (bool): If True, displays a progress bar of the download to stderr
277
+ """
278
+ return _resnet('resnet101', Bottleneck, [3, 4, 23, 3], pretrained, progress,
279
+ **kwargs)
280
+
281
+
282
+ def resnet152(pretrained=False, progress=True, **kwargs):
283
+ r"""ResNet-152 model from
284
+ `"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_
285
+ Args:
286
+ pretrained (bool): If True, returns a model pre-trained on ImageNet
287
+ progress (bool): If True, displays a progress bar of the download to stderr
288
+ """
289
+ return _resnet('resnet152', Bottleneck, [3, 8, 36, 3], pretrained, progress,
290
+ **kwargs)
291
+
292
+
293
+ def resnext50_32x4d(pretrained=False, progress=True, **kwargs):
294
+ r"""ResNeXt-50 32x4d model from
295
+ `"Aggregated Residual Transformation for Deep Neural Networks" <https://arxiv.org/pdf/1611.05431.pdf>`_
296
+ Args:
297
+ pretrained (bool): If True, returns a model pre-trained on ImageNet
298
+ progress (bool): If True, displays a progress bar of the download to stderr
299
+ """
300
+ kwargs['groups'] = 32
301
+ kwargs['width_per_group'] = 4
302
+ return _resnet('resnext50_32x4d', Bottleneck, [3, 4, 6, 3],
303
+ pretrained, progress, **kwargs)
304
+
305
+
306
+ def resnext101_32x8d(pretrained=False, progress=True, **kwargs):
307
+ r"""ResNeXt-101 32x8d model from
308
+ `"Aggregated Residual Transformation for Deep Neural Networks" <https://arxiv.org/pdf/1611.05431.pdf>`_
309
+ Args:
310
+ pretrained (bool): If True, returns a model pre-trained on ImageNet
311
+ progress (bool): If True, displays a progress bar of the download to stderr
312
+ """
313
+ kwargs['groups'] = 32
314
+ kwargs['width_per_group'] = 8
315
+ return _resnet('resnext101_32x8d', Bottleneck, [3, 4, 23, 3],
316
+ pretrained, progress, **kwargs)
317
+
318
+
319
+ def wide_resnet50_2(pretrained=False, progress=True, **kwargs):
320
+ r"""Wide ResNet-50-2 model from
321
+ `"Wide Residual Networks" <https://arxiv.org/pdf/1605.07146.pdf>`_
322
+ The model is the same as ResNet except for the bottleneck number of channels
323
+ which is twice larger in every block. The number of channels in outer 1x1
324
+ convolutions is the same, e.g. last block in ResNet-50 has 2048-512-2048
325
+ channels, and in Wide ResNet-50-2 has 2048-1024-2048.
326
+ Args:
327
+ pretrained (bool): If True, returns a model pre-trained on ImageNet
328
+ progress (bool): If True, displays a progress bar of the download to stderr
329
+ """
330
+ kwargs['width_per_group'] = 64 * 2
331
+ return _resnet('wide_resnet50_2', Bottleneck, [3, 4, 6, 3],
332
+ pretrained, progress, **kwargs)
333
+
334
+
335
+ def wide_resnet101_2(pretrained=False, progress=True, **kwargs):
336
+ r"""Wide ResNet-101-2 model from
337
+ `"Wide Residual Networks" <https://arxiv.org/pdf/1605.07146.pdf>`_
338
+ The model is the same as ResNet except for the bottleneck number of channels
339
+ which is twice larger in every block. The number of channels in outer 1x1
340
+ convolutions is the same, e.g. last block in ResNet-50 has 2048-512-2048
341
+ channels, and in Wide ResNet-50-2 has 2048-1024-2048.
342
+ Args:
343
+ pretrained (bool): If True, returns a model pre-trained on ImageNet
344
+ progress (bool): If True, displays a progress bar of the download to stderr
345
+ """
346
+ kwargs['width_per_group'] = 64 * 2
347
+ return _resnet('wide_resnet101_2', Bottleneck, [3, 4, 23, 3],
348
+ pretrained, progress, **kwargs)
segmenter_model/blocks.py ADDED
@@ -0,0 +1,129 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Adapted from 2020 Ross Wightman
3
+ https://github.com/rwightman/pytorch-image-models
4
+ """
5
+
6
+ import torch
7
+ import torch.nn as nn
8
+ from einops import rearrange
9
+ from pathlib import Path
10
+
11
+ import torch.nn.functional as F
12
+
13
+ from timm.models.layers import DropPath
14
+
15
+
16
+ class FeedForward(nn.Module):
17
+ def __init__(self, dim, hidden_dim, dropout, out_dim=None):
18
+ super().__init__()
19
+ self.fc1 = nn.Linear(dim, hidden_dim)
20
+ self.act = nn.GELU()
21
+ if out_dim is None:
22
+ out_dim = dim
23
+ self.fc2 = nn.Linear(hidden_dim, out_dim)
24
+ self.drop = nn.Dropout(dropout)
25
+
26
+ @property
27
+ def unwrapped(self):
28
+ return self
29
+
30
+ def forward(self, x):
31
+ x = self.fc1(x)
32
+ x = self.act(x)
33
+ x = self.drop(x)
34
+ x = self.fc2(x)
35
+ x = self.drop(x)
36
+ return x
37
+
38
+
39
+ class Attention(nn.Module):
40
+ def __init__(self, dim, heads, dropout):
41
+ super().__init__()
42
+ self.heads = heads
43
+ head_dim = dim // heads
44
+ self.scale = head_dim ** -0.5
45
+ self.attn = None
46
+
47
+ self.qkv = nn.Linear(dim, dim * 3)
48
+ self.attn_drop = nn.Dropout(dropout)
49
+ self.proj = nn.Linear(dim, dim)
50
+ self.proj_drop = nn.Dropout(dropout)
51
+
52
+ @property
53
+ def unwrapped(self):
54
+ return self
55
+
56
+ def forward(self, x, mask=None):
57
+ B, N, C = x.shape
58
+ qkv = (
59
+ self.qkv(x)
60
+ .reshape(B, N, 3, self.heads, C // self.heads)
61
+ .permute(2, 0, 3, 1, 4)
62
+ )
63
+ q, k, v = (
64
+ qkv[0],
65
+ qkv[1],
66
+ qkv[2],
67
+ )
68
+
69
+ attn = (q @ k.transpose(-2, -1)) * self.scale
70
+ attn = attn.softmax(dim=-1)
71
+ attn = self.attn_drop(attn)
72
+
73
+ x = (attn @ v).transpose(1, 2).reshape(B, N, C)
74
+ x = self.proj(x)
75
+ x = self.proj_drop(x)
76
+
77
+ return x, attn
78
+
79
+
80
+ class AttentionQK(nn.Module):
81
+ def __init__(self, dim, heads=1, dropout=0.):
82
+ super().__init__()
83
+ self.heads = heads
84
+ head_dim = dim // heads
85
+ self.scale = head_dim ** -0.5
86
+ self.attn = None
87
+
88
+ self.qk = nn.Linear(dim, dim * 2)
89
+ self.attn_drop = nn.Dropout(dropout)
90
+
91
+ @property
92
+ def unwrapped(self):
93
+ return self
94
+
95
+ def forward(self, x):
96
+ B, N, C = x.shape
97
+ qkv = (
98
+ self.qk(x)
99
+ .reshape(B, N, 2, self.heads, C // self.heads)
100
+ .permute(2, 0, 3, 1, 4)
101
+ )
102
+ q, k = (
103
+ qkv[0],
104
+ qkv[1]
105
+ )
106
+
107
+ attn = (q @ k.transpose(-2, -1)) * self.scale
108
+ # attn = attn.sigmoid()
109
+ attn = attn.softmax(dim=-1)
110
+
111
+ return attn
112
+
113
+
114
+ class Block(nn.Module):
115
+ def __init__(self, dim, heads, mlp_dim, dropout, drop_path):
116
+ super().__init__()
117
+ self.norm1 = nn.LayerNorm(dim)
118
+ self.norm2 = nn.LayerNorm(dim)
119
+ self.attn = Attention(dim, heads, dropout)
120
+ self.mlp = FeedForward(dim, mlp_dim, dropout)
121
+ self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
122
+
123
+ def forward(self, x, mask=None, return_attention=False):
124
+ y, attn = self.attn(self.norm1(x), mask)
125
+ if return_attention:
126
+ return attn
127
+ x = x + self.drop_path(y)
128
+ x = x + self.drop_path(self.mlp(self.norm2(x)))
129
+ return x
segmenter_model/decoder.py ADDED
@@ -0,0 +1,214 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ import torch
3
+ import torch.nn as nn
4
+ import torch.nn.functional as F
5
+ from einops import rearrange
6
+
7
+ from timm.models.layers import trunc_normal_
8
+
9
+ from segmenter_model.blocks import Block, FeedForward
10
+ from segmenter_model.utils import init_weights
11
+
12
+
13
+ class DecoderLinear(nn.Module):
14
+ def __init__(self, n_cls, patch_size, d_encoder):
15
+ super().__init__()
16
+
17
+ self.d_encoder = d_encoder
18
+ self.patch_size = patch_size
19
+ self.n_cls = n_cls
20
+
21
+ self.head = nn.Linear(self.d_encoder, n_cls)
22
+ self.apply(init_weights)
23
+
24
+ @torch.jit.ignore
25
+ def no_weight_decay(self):
26
+ return set()
27
+
28
+ def forward(self, x, im_size):
29
+ H, W = im_size
30
+ GS = H // self.patch_size
31
+ x = self.head(x)
32
+ x = rearrange(x, "b (h w) c -> b c h w", h=GS)
33
+
34
+ return x
35
+
36
+
37
+ class MaskTransformer(nn.Module):
38
+ def __init__(
39
+ self,
40
+ n_cls,
41
+ patch_size,
42
+ d_encoder,
43
+ n_layers,
44
+ n_heads,
45
+ d_model,
46
+ d_ff,
47
+ drop_path_rate,
48
+ dropout,
49
+ ):
50
+ super().__init__()
51
+ self.d_encoder = d_encoder
52
+ self.patch_size = patch_size
53
+ self.n_layers = n_layers
54
+ self.n_cls = n_cls
55
+ self.d_model = d_model
56
+ self.d_ff = d_ff
57
+ self.scale = d_model ** -0.5
58
+
59
+ dpr = [x.item() for x in torch.linspace(0, drop_path_rate, n_layers)]
60
+ self.blocks = nn.ModuleList(
61
+ [Block(d_model, n_heads, d_ff, dropout, dpr[i]) for i in range(n_layers)]
62
+ )
63
+
64
+ self.cls_emb = nn.Parameter(torch.randn(1, n_cls, d_model))
65
+ self.proj_dec = nn.Linear(d_encoder, d_model)
66
+
67
+ self.proj_patch = nn.Parameter(self.scale * torch.randn(d_model, d_model))
68
+ self.proj_classes = nn.Parameter(self.scale * torch.randn(d_model, d_model))
69
+
70
+ self.decoder_norm = nn.LayerNorm(d_model)
71
+ self.mask_norm = nn.LayerNorm(n_cls)
72
+
73
+ self.apply(init_weights)
74
+ trunc_normal_(self.cls_emb, std=0.02)
75
+
76
+ @torch.jit.ignore
77
+ def no_weight_decay(self):
78
+ return {"cls_emb"}
79
+
80
+ def forward(self, x, im_size, features_only=False, no_rearrange=False):
81
+ H, W = im_size
82
+ GS = H // self.patch_size
83
+
84
+ # project from the encoder dimensionality to the decoder dimensionality (usually the same)
85
+ x = self.proj_dec(x)
86
+ # reshape the class embedding token
87
+ cls_emb = self.cls_emb.expand(x.size(0), -1, -1)
88
+ # concatenate the class embedding token to the input
89
+ x = torch.cat((x, cls_emb), 1)
90
+ # forward the concatenated tokens through decoder blocks
91
+ for blk in self.blocks:
92
+ x = blk(x)
93
+ # perform normalization
94
+ x = self.decoder_norm(x)
95
+
96
+ # split to patch features and class-segmentation features
97
+ patches, cls_seg_feat = x[:, : -self.n_cls], x[:, -self.n_cls:]
98
+
99
+ # project the patch features
100
+ patches = patches @ self.proj_patch
101
+
102
+ if features_only:
103
+ if not no_rearrange:
104
+ features = rearrange(patches, "b (h w) n -> b n h w", h=int(GS))
105
+ else:
106
+ features = patches
107
+ return features
108
+
109
+ # project the class-segmentation features
110
+ cls_seg_feat = cls_seg_feat @ self.proj_classes
111
+
112
+ # scalar product between L2-normalized patch embeddings and class embeddings -> masks
113
+ patches = patches / patches.norm(dim=-1, keepdim=True)
114
+ cls_seg_feat = cls_seg_feat / cls_seg_feat.norm(dim=-1, keepdim=True)
115
+ masks = patches @ cls_seg_feat.transpose(1, 2)
116
+
117
+ masks = self.mask_norm(masks)
118
+ if not no_rearrange:
119
+ masks = rearrange(masks, "b (h w) n -> b n h w", h=int(GS))
120
+
121
+ return masks
122
+
123
+ def get_attention_map(self, x, layer_id):
124
+ if layer_id >= self.n_layers or layer_id < 0:
125
+ raise ValueError(
126
+ f"Provided layer_id: {layer_id} is not valid. 0 <= {layer_id} < {self.n_layers}."
127
+ )
128
+ x = self.proj_dec(x)
129
+ cls_emb = self.cls_emb.expand(x.size(0), -1, -1)
130
+ x = torch.cat((x, cls_emb), 1)
131
+ for i, blk in enumerate(self.blocks):
132
+ if i < layer_id:
133
+ x = blk(x)
134
+ else:
135
+ return blk(x, return_attention=True)
136
+
137
+
138
+ class DeepLabHead(nn.Sequential):
139
+ def __init__(self, in_channels, num_classes, patch_size=None):
140
+ super(DeepLabHead, self).__init__(
141
+ ASPP(in_channels, [12, 24, 36]),
142
+ nn.Conv2d(256, 256, 3, padding=1, bias=False),
143
+ nn.BatchNorm2d(256),
144
+ nn.ReLU(),
145
+ nn.Conv2d(256, num_classes, 1)
146
+ )
147
+ self.patch_size = patch_size
148
+
149
+ def forward(self, x, im_size=None):
150
+ if len(x.shape) == 3:
151
+ # features from ViT
152
+ assert im_size is not None and self.patch_size is not None
153
+ H, W = im_size
154
+ GS = H // self.patch_size
155
+ x = rearrange(x, "b (h w) n -> b n h w", h=int(GS)).contiguous()
156
+ for module in self:
157
+ x = module(x)
158
+ return x
159
+
160
+
161
+ class ASPPConv(nn.Sequential):
162
+ def __init__(self, in_channels, out_channels, dilation):
163
+ modules = [
164
+ nn.Conv2d(in_channels, out_channels, 3, padding=dilation, dilation=dilation, bias=False),
165
+ nn.BatchNorm2d(out_channels),
166
+ nn.ReLU()
167
+ ]
168
+ super(ASPPConv, self).__init__(*modules)
169
+
170
+
171
+ class ASPPPooling(nn.Sequential):
172
+ def __init__(self, in_channels, out_channels):
173
+ super(ASPPPooling, self).__init__(
174
+ nn.AdaptiveAvgPool2d(1),
175
+ nn.Conv2d(in_channels, out_channels, 1, bias=False),
176
+ nn.BatchNorm2d(out_channels),
177
+ nn.ReLU())
178
+
179
+ def forward(self, x):
180
+ size = x.shape[-2:]
181
+ for mod in self:
182
+ x = mod(x)
183
+ return F.interpolate(x, size=size, mode='bilinear', align_corners=False)
184
+
185
+
186
+ class ASPP(nn.Module):
187
+ def __init__(self, in_channels, atrous_rates, out_channels=256):
188
+ super(ASPP, self).__init__()
189
+ modules = []
190
+ modules.append(nn.Sequential(
191
+ nn.Conv2d(in_channels, out_channels, 1, bias=False),
192
+ nn.BatchNorm2d(out_channels),
193
+ nn.ReLU()))
194
+
195
+ rates = tuple(atrous_rates)
196
+ for rate in rates:
197
+ modules.append(ASPPConv(in_channels, out_channels, rate))
198
+
199
+ modules.append(ASPPPooling(in_channels, out_channels))
200
+
201
+ self.convs = nn.ModuleList(modules)
202
+
203
+ self.project = nn.Sequential(
204
+ nn.Conv2d(5 * out_channels, out_channels, 1, bias=False),
205
+ nn.BatchNorm2d(out_channels),
206
+ nn.ReLU(),
207
+ nn.Dropout(0.5))
208
+
209
+ def forward(self, x):
210
+ res = []
211
+ for conv in self.convs:
212
+ res.append(conv(x))
213
+ res = torch.cat(res, dim=1)
214
+ return self.project(res)
segmenter_model/factory.py ADDED
@@ -0,0 +1,165 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from pathlib import Path
2
+ import yaml
3
+ import torch
4
+ import math
5
+ import os
6
+ import torch.nn as nn
7
+
8
+ from timm.models.helpers import load_pretrained, load_custom_pretrained
9
+ from timm.models.vision_transformer import default_cfgs, checkpoint_filter_fn
10
+ from timm.models.registry import register_model
11
+ from timm.models.vision_transformer import _create_vision_transformer
12
+ from segmenter_model.decoder import MaskTransformer
13
+ from segmenter_model.segmenter import Segmenter
14
+ import segmenter_model.torch as ptu
15
+
16
+ from segmenter_model.vit_dino import vit_small, VisionTransformer
17
+
18
+
19
+ @register_model
20
+ def vit_base_patch8_384(pretrained=False, **kwargs):
21
+ """ViT-Base model (ViT-B/16) from original paper (https://arxiv.org/abs/2010.11929).
22
+ ImageNet-1k weights fine-tuned from in21k @ 384x384, source https://github.com/google-research/vision_transformer.
23
+ """
24
+ model_kwargs = dict(patch_size=8, embed_dim=768, depth=12, num_heads=12, **kwargs)
25
+ model = _create_vision_transformer(
26
+ "vit_base_patch8_384",
27
+ pretrained=pretrained,
28
+ default_cfg=dict(
29
+ url="",
30
+ input_size=(3, 384, 384),
31
+ mean=(0.5, 0.5, 0.5),
32
+ std=(0.5, 0.5, 0.5),
33
+ num_classes=1000,
34
+ ),
35
+ **model_kwargs,
36
+ )
37
+ return model
38
+
39
+
40
+ def create_vit(model_cfg):
41
+ model_cfg = model_cfg.copy()
42
+ backbone = model_cfg.pop("backbone")
43
+ if 'pretrained_weights' in model_cfg:
44
+ pretrained_weights = model_cfg.pop('pretrained_weights')
45
+
46
+ if 'dino' in backbone:
47
+ if backbone.lower() == 'dino_vits16':
48
+ model_cfg['drop_rate'] = model_cfg['dropout']
49
+ model = vit_small(**model_cfg)
50
+ # hard-coded for now, too lazy
51
+ ciirc_path = '/home/vobecant/PhD/weights/dino/dino_deitsmall16_pretrain.pth'
52
+ karolina_path = '/scratch/project/dd-21-20/pretrained_weights/dino/dino_deitsmall16_pretrain.pth'
53
+ if os.path.exists(ciirc_path):
54
+ pretrained_weights = ciirc_path
55
+ elif os.path.exists(karolina_path):
56
+ pretrained_weights = karolina_path
57
+ else:
58
+ raise Exception('DINO weights not found!')
59
+ model.load_state_dict(torch.load(pretrained_weights), strict=True)
60
+ else:
61
+ model = torch.hub.load('facebookresearch/dino:main', backbone)
62
+ setattr(model, 'd_model', model.num_features)
63
+ setattr(model, 'patch_size', model.patch_embed.patch_size)
64
+ setattr(model, 'distilled', False)
65
+ model.forward = lambda x, return_features: model.get_intermediate_layers(x, n=1)[0]
66
+ else:
67
+ normalization = model_cfg.pop("normalization")
68
+ model_cfg["n_cls"] = 1000
69
+ mlp_expansion_ratio = 4
70
+ model_cfg["d_ff"] = mlp_expansion_ratio * model_cfg["d_model"]
71
+
72
+ if backbone in default_cfgs:
73
+ default_cfg = default_cfgs[backbone]
74
+ else:
75
+ default_cfg = dict(
76
+ pretrained=False,
77
+ num_classes=1000,
78
+ drop_rate=0.0,
79
+ drop_path_rate=0.0,
80
+ drop_block_rate=None,
81
+ )
82
+
83
+ default_cfg["input_size"] = (
84
+ 3,
85
+ model_cfg["image_size"][0],
86
+ model_cfg["image_size"][1],
87
+ )
88
+ model = VisionTransformer(**model_cfg)
89
+ if backbone == "vit_base_patch8_384":
90
+ path = os.path.expandvars("/home/vobecant/PhD/weights/vit_base_patch8_384.pth")
91
+ state_dict = torch.load(path, map_location="cpu")
92
+ filtered_dict = checkpoint_filter_fn(state_dict, model)
93
+ model.load_state_dict(filtered_dict, strict=True)
94
+ elif "deit" in backbone:
95
+ load_pretrained(model, default_cfg, filter_fn=checkpoint_filter_fn)
96
+ else:
97
+ load_custom_pretrained(model, default_cfg)
98
+
99
+ return model
100
+
101
+
102
+ def create_decoder(encoder, decoder_cfg):
103
+ decoder_cfg = decoder_cfg.copy()
104
+ name = decoder_cfg.pop("name")
105
+ decoder_cfg["d_encoder"] = encoder.d_model
106
+ decoder_cfg["patch_size"] = encoder.patch_size
107
+
108
+ if "linear" in name:
109
+ decoder = DecoderLinear(**decoder_cfg)
110
+ elif name == "mask_transformer":
111
+ dim = encoder.d_model
112
+ n_heads = dim // 64
113
+ decoder_cfg["n_heads"] = n_heads
114
+ decoder_cfg["d_model"] = dim
115
+ decoder_cfg["d_ff"] = 4 * dim
116
+ decoder = MaskTransformer(**decoder_cfg)
117
+ elif 'deeplab' in name:
118
+ decoder = DeepLabHead(in_channels=encoder.d_model, num_classes=decoder_cfg["n_cls"],
119
+ patch_size=decoder_cfg["patch_size"])
120
+ else:
121
+ raise ValueError(f"Unknown decoder: {name}")
122
+ return decoder
123
+
124
+
125
+ def create_segmenter(model_cfg):
126
+ model_cfg = model_cfg.copy()
127
+ decoder_cfg = model_cfg.pop("decoder")
128
+ decoder_cfg["n_cls"] = model_cfg["n_cls"]
129
+
130
+ if 'weights_path' in model_cfg.keys():
131
+ weights_path = model_cfg.pop('weights_path')
132
+ else:
133
+ weights_path = None
134
+
135
+ encoder = create_vit(model_cfg)
136
+ decoder = create_decoder(encoder, decoder_cfg)
137
+ model = Segmenter(encoder, decoder, n_cls=model_cfg["n_cls"])
138
+
139
+ if weights_path is not None:
140
+ raise Exception('Wants to load weights to the complete segmenter insice create_segmenter method!')
141
+ state_dict = torch.load(weights_path, map_location="cpu")
142
+ if 'model' in state_dict:
143
+ state_dict = state_dict['model']
144
+ msg = model.load_state_dict(state_dict, strict=False)
145
+ print(msg)
146
+
147
+ return model
148
+
149
+
150
+ def load_model(model_path, decoder_only=False, variant_path=None):
151
+ variant_path = Path(model_path).parent / "variant.yml" if variant_path is None else variant_path
152
+ with open(variant_path, "r") as f:
153
+ variant = yaml.load(f, Loader=yaml.FullLoader)
154
+ net_kwargs = variant["net_kwargs"]
155
+
156
+ model = create_segmenter(net_kwargs)
157
+ data = torch.load(model_path, map_location=ptu.device)
158
+ checkpoint = data["model"]
159
+
160
+ if decoder_only:
161
+ model.decoder.load_state_dict(checkpoint, strict=True)
162
+ else:
163
+ model.load_state_dict(checkpoint, strict=True)
164
+
165
+ return model, variant
segmenter_model/fpn_picie.py ADDED
@@ -0,0 +1,66 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # taken from https://raw.githubusercontent.com/janghyuncho/PiCIE/1d7b034f57e98670b0d6a244b2eea11fa0dde73e/modules/fpn.py
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+ import torch.nn.functional as F
6
+ from . import backbone_picie as backbone
7
+
8
+
9
+ class PanopticFPN(nn.Module):
10
+ def __init__(self, arch, pretrain, n_cls):
11
+ super(PanopticFPN, self).__init__()
12
+ self.n_cls = n_cls
13
+ self.backbone = backbone.__dict__[arch](pretrained=pretrain)
14
+ self.decoder = FPNDecoder(arch, n_cls)
15
+
16
+ def forward(self, x, encoder_features=False, decoder_features=False):
17
+ feats = self.backbone(x)
18
+ if decoder_features:
19
+ dec, outs = self.decoder(feats, get_features=decoder_features)
20
+ else:
21
+ outs = self.decoder(feats)
22
+
23
+ if encoder_features:
24
+ if decoder_features:
25
+ return feats['res5'], dec, outs
26
+ else:
27
+ return feats['res5'], outs
28
+ else:
29
+ return outs
30
+
31
+
32
+ class FPNDecoder(nn.Module):
33
+ def __init__(self, arch, n_cls):
34
+ super(FPNDecoder, self).__init__()
35
+ self.n_cls = n_cls
36
+ if arch == 'resnet18':
37
+ mfactor = 1
38
+ out_dim = 128
39
+ else:
40
+ mfactor = 4
41
+ out_dim = 256
42
+
43
+ self.layer4 = nn.Conv2d(512 * mfactor // 8, out_dim, kernel_size=1, stride=1, padding=0)
44
+ self.layer3 = nn.Conv2d(512 * mfactor // 4, out_dim, kernel_size=1, stride=1, padding=0)
45
+ self.layer2 = nn.Conv2d(512 * mfactor // 2, out_dim, kernel_size=1, stride=1, padding=0)
46
+ self.layer1 = nn.Conv2d(512 * mfactor, out_dim, kernel_size=1, stride=1, padding=0)
47
+
48
+ self.pred = nn.Conv2d(out_dim, self.n_cls, 1, 1)
49
+
50
+ def forward(self, x, get_features=False):
51
+ o1 = self.layer1(x['res5'])
52
+ o2 = self.upsample_add(o1, self.layer2(x['res4']))
53
+ o3 = self.upsample_add(o2, self.layer3(x['res3']))
54
+ o4 = self.upsample_add(o3, self.layer4(x['res2']))
55
+
56
+ pred = self.pred(o4)
57
+
58
+ if get_features:
59
+ return o4, pred
60
+ else:
61
+ return pred
62
+
63
+ def upsample_add(self, x, y):
64
+ _, _, H, W = y.size()
65
+
66
+ return F.interpolate(x, size=(H, W), mode='bilinear', align_corners=False) + y
segmenter_model/picie_model.py ADDED
@@ -0,0 +1,82 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ from . import backbone_picie as backbone
5
+
6
+
7
+ class PanopticFPN(nn.Module):
8
+ def __init__(self, args):
9
+ super(PanopticFPN, self).__init__()
10
+ self.backbone = backbone.__dict__[args.arch](pretrained=args.pretrain)
11
+ if args.arch == 'vit_small':
12
+ self.decoder = FPNDecoderViT(args)
13
+ else:
14
+ self.decoder = FPNDecoder(args)
15
+
16
+ def forward(self, x, encoder_features=False, decoder_features=False):
17
+ feats = self.backbone(x)
18
+ dec_outs = self.decoder(feats)
19
+
20
+ if encoder_features:
21
+ return feats['res5'], dec_outs
22
+ else:
23
+ return dec_outs
24
+
25
+
26
+ class FPNDecoder(nn.Module):
27
+ def __init__(self, args):
28
+ super(FPNDecoder, self).__init__()
29
+ if args.arch == 'resnet18':
30
+ mfactor = 1
31
+ out_dim = 128
32
+ else:
33
+ mfactor = 4
34
+ out_dim = 256
35
+
36
+ self.layer4 = nn.Conv2d(512 * mfactor // 8, out_dim, kernel_size=1, stride=1, padding=0)
37
+ self.layer3 = nn.Conv2d(512 * mfactor // 4, out_dim, kernel_size=1, stride=1, padding=0)
38
+ self.layer2 = nn.Conv2d(512 * mfactor // 2, out_dim, kernel_size=1, stride=1, padding=0)
39
+ self.layer1 = nn.Conv2d(512 * mfactor, out_dim, kernel_size=1, stride=1, padding=0)
40
+
41
+ def forward(self, x):
42
+ o1 = self.layer1(x['res5'])
43
+ o2 = self.upsample_add(o1, self.layer2(x['res4']))
44
+ o3 = self.upsample_add(o2, self.layer3(x['res3']))
45
+ o4 = self.upsample_add(o3, self.layer4(x['res2']))
46
+
47
+ return o4
48
+
49
+ def upsample_add(self, x, y):
50
+ _, _, H, W = y.size()
51
+
52
+ return F.interpolate(x, size=(H, W), mode='bilinear', align_corners=False) + y
53
+
54
+
55
+ class FPNDecoderViT(nn.Module):
56
+ def __init__(self, args):
57
+ super(FPNDecoderViT, self).__init__()
58
+ if args.arch == 'resnet18' or args.arch == 'vit_small':
59
+ mfactor = 1
60
+ out_dim = 128
61
+ else:
62
+ mfactor = 4
63
+ out_dim = 256
64
+
65
+ self.upsample_rate = 4
66
+
67
+ self.layer4 = nn.Conv2d(384, out_dim, kernel_size=1, stride=1, padding=0)
68
+ self.layer3 = nn.Conv2d(384, out_dim, kernel_size=1, stride=1, padding=0)
69
+ self.layer2 = nn.Conv2d(384, out_dim, kernel_size=1, stride=1, padding=0)
70
+ self.layer1 = nn.Conv2d(384, out_dim, kernel_size=1, stride=1, padding=0)
71
+
72
+ def forward(self, x):
73
+ o1 = self.layer1(x[3])
74
+ o1 = F.interpolate(o1, scale_factor=4, mode='bilinear', align_corners=False)
75
+ o2 = self.upsample_add(o1, self.layer2(x[2]))
76
+ o3 = self.upsample_add(o2, self.layer3(x[1]))
77
+ o4 = self.upsample_add(o3, self.layer4(x[0]))
78
+
79
+ return o4
80
+
81
+ def upsample_add(self, x, y):
82
+ return F.interpolate(y, scale_factor=self.upsample_rate, mode='bilinear', align_corners=False) + x
segmenter_model/resnet_dilated.py ADDED
@@ -0,0 +1,55 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #
2
+ # Authors: Wouter Van Gansbeke & Simon Vandenhende
3
+ # Licensed under the CC BY-NC 4.0 license (https://creativecommons.org/licenses/by-nc/4.0/)
4
+
5
+ import torch.nn as nn
6
+
7
+ class ResnetDilated(nn.Module):
8
+ def __init__(self, orig_resnet, dilate_scale=8):
9
+ super(ResnetDilated, self).__init__()
10
+ from functools import partial
11
+
12
+ if dilate_scale == 8:
13
+ orig_resnet.layer3.apply(
14
+ partial(self._nostride_dilate, dilate=2))
15
+ orig_resnet.layer4.apply(
16
+ partial(self._nostride_dilate, dilate=4))
17
+ elif dilate_scale == 16:
18
+ orig_resnet.layer4.apply(
19
+ partial(self._nostride_dilate, dilate=2))
20
+
21
+ self.conv1 = orig_resnet.conv1
22
+ self.bn1 = orig_resnet.bn1
23
+ self.relu = orig_resnet.relu
24
+
25
+ self.maxpool = orig_resnet.maxpool
26
+ self.layer1 = orig_resnet.layer1
27
+ self.layer2 = orig_resnet.layer2
28
+ self.layer3 = orig_resnet.layer3
29
+ self.layer4 = orig_resnet.layer4
30
+
31
+ def _nostride_dilate(self, m, dilate):
32
+ classname = m.__class__.__name__
33
+ if classname.find('Conv') != -1:
34
+ # the convolution with stride
35
+ if m.stride == (2, 2):
36
+ m.stride = (1, 1)
37
+ if m.kernel_size == (3, 3):
38
+ m.dilation = (dilate//2, dilate//2)
39
+ m.padding = (dilate//2, dilate//2)
40
+ # other convoluions
41
+ else:
42
+ if m.kernel_size == (3, 3):
43
+ m.dilation = (dilate, dilate)
44
+ m.padding = (dilate, dilate)
45
+
46
+ def forward(self, x):
47
+ x = self.relu(self.bn1(self.conv1(x)))
48
+ x = self.maxpool(x)
49
+
50
+ x = self.layer1(x)
51
+ x = self.layer2(x)
52
+ x = self.layer3(x)
53
+ x = self.layer4(x)
54
+
55
+ return x
segmenter_model/segmenter.py ADDED
@@ -0,0 +1,86 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from einops import rearrange
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+ import torch.nn.functional as F
6
+
7
+ # from timm.models.layers import trunc_normal_
8
+
9
+ from segmenter_model.utils import padding, unpadding
10
+
11
+
12
+ class Segmenter(nn.Module):
13
+ def __init__(
14
+ self,
15
+ encoder,
16
+ decoder,
17
+ n_cls,
18
+ ):
19
+ super().__init__()
20
+ self.n_cls = n_cls
21
+ self.patch_size = encoder.patch_size
22
+ self.encoder = encoder
23
+ self.decoder = decoder
24
+
25
+ @torch.jit.ignore
26
+ def no_weight_decay(self):
27
+ def append_prefix_no_weight_decay(prefix, module):
28
+ return set(map(lambda x: prefix + x, module.no_weight_decay()))
29
+
30
+ nwd_params = append_prefix_no_weight_decay("encoder.", self.encoder).union(
31
+ append_prefix_no_weight_decay("decoder.", self.decoder)
32
+ )
33
+ return nwd_params
34
+
35
+ def forward(self, im, decoder_features=False, no_upsample=False, encoder_features=False, no_rearrange=False,
36
+ cls_only=False, encoder_only=False):
37
+ H_ori, W_ori = im.size(2), im.size(3)
38
+ if not no_upsample:
39
+ im = padding(im, self.patch_size)
40
+ H, W = im.size(2), im.size(3)
41
+
42
+ x = self.encoder(im, return_features=True) # self.patch_size times smaller than im
43
+
44
+ # remove CLS/DIST tokens for decoding
45
+ num_extra_tokens = 1 + self.encoder.distilled
46
+
47
+ if cls_only:
48
+ return x[:, 0]
49
+ x = x[:, num_extra_tokens:]
50
+
51
+ if encoder_features:
52
+ enc_fts = x.clone()
53
+ if not no_rearrange:
54
+ GS = H // self.patch_size
55
+ enc_fts = rearrange(enc_fts, "b (h w) c -> b c h w", h=GS)
56
+ if encoder_only:
57
+ return enc_fts
58
+
59
+ if decoder_features:
60
+ output = self.decoder(x, (H, W), features_only=True, no_rearrange=no_rearrange)
61
+ if no_rearrange:
62
+ if encoder_features:
63
+ output = (enc_fts, output)
64
+ return output
65
+ else:
66
+ output = self.decoder(x, (H, W)) # shape (BS, NCLS, H/self.patch_size, W/self.patch_size)
67
+
68
+ if not no_upsample:
69
+ output = F.interpolate(output, size=(H, W), mode="bilinear") # upsample self.patch_size times
70
+ output = unpadding(output, (H_ori, W_ori))
71
+
72
+ if encoder_features:
73
+ output = (enc_fts, output)
74
+ return output
75
+
76
+ def get_attention_map_enc(self, im, layer_id):
77
+ return self.encoder.get_attention_map(im, layer_id)
78
+
79
+ def get_attention_map_dec(self, im, layer_id):
80
+ x = self.encoder(im, return_features=True)
81
+
82
+ # remove CLS/DIST tokens for decoding
83
+ num_extra_tokens = 1 + self.encoder.distilled
84
+ x = x[:, num_extra_tokens:]
85
+
86
+ return self.decoder.get_attention_map(x, layer_id)
segmenter_model/torch.py ADDED
@@ -0,0 +1,38 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import torch
3
+
4
+ """
5
+ GPU wrappers
6
+ """
7
+
8
+ use_gpu = False
9
+ gpu_id = 0
10
+ device = None
11
+
12
+ distributed = False
13
+ dist_rank = 0
14
+ world_size = 1
15
+
16
+
17
+ def set_gpu_mode(mode, pbs=False):
18
+ global use_gpu
19
+ global device
20
+ global gpu_id
21
+ global distributed
22
+ global dist_rank
23
+ global world_size
24
+ if pbs:
25
+ gpu_id = int(os.environ.get("MPI_LOCALRANKID", 0))
26
+ dist_rank = int(os.environ.get("PMI_RANK", 0))
27
+ world_size = int(os.environ.get("PMI_SIZE", 1))
28
+ else:
29
+ gpu_id = int(os.environ.get("SLURM_LOCALID", 0))
30
+ dist_rank = int(os.environ.get("SLURM_PROCID", 0))
31
+ world_size = int(os.environ.get("SLURM_NTASKS", 1))
32
+
33
+ distributed = world_size > 1
34
+ use_gpu = mode
35
+ print('gpu_id: {}, dist_rank: {}, world_size: {}, distributed: {}'.format(gpu_id, dist_rank, world_size,
36
+ distributed))
37
+ device = torch.device(f"cuda:{gpu_id}" if use_gpu else "cpu")
38
+ torch.backends.cudnn.benchmark = True
segmenter_model/utils.py ADDED
@@ -0,0 +1,582 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ # import segm.utils.torch as ptu
3
+ # from segm.engine import seg2rgb
4
+ from collections import namedtuple
5
+
6
+ import cv2
7
+ import numpy as np
8
+ import torch.nn as nn
9
+ import torch.nn.functional as F
10
+ from PIL import Image
11
+ from timm.models.layers import trunc_normal_
12
+
13
+ import torch
14
+
15
+ CityscapesClass = namedtuple('CityscapesClass', ['name', 'id', 'train_id', 'category', 'category_id',
16
+ 'has_instances', 'ignore_in_eval', 'color'])
17
+
18
+ classes = [
19
+ CityscapesClass('unlabeled', 0, 255, 'void', 0, False, True, (0, 0, 0)),
20
+ CityscapesClass('ego vehicle', 1, 255, 'void', 0, False, True, (0, 0, 0)),
21
+ CityscapesClass('rectification border', 2, 255, 'void', 0, False, True, (0, 0, 0)),
22
+ CityscapesClass('out of roi', 3, 255, 'void', 0, False, True, (0, 0, 0)),
23
+ CityscapesClass('static', 4, 255, 'void', 0, False, True, (0, 0, 0)),
24
+ CityscapesClass('dynamic', 5, 255, 'void', 0, False, True, (111, 74, 0)),
25
+ CityscapesClass('ground', 6, 255, 'void', 0, False, True, (81, 0, 81)),
26
+ CityscapesClass('road', 7, 0, 'flat', 1, False, False, (128, 64, 128)),
27
+ CityscapesClass('sidewalk', 8, 1, 'flat', 1, False, False, (244, 35, 232)),
28
+ CityscapesClass('parking', 9, 255, 'flat', 1, False, True, (250, 170, 160)),
29
+ CityscapesClass('rail track', 10, 255, 'flat', 1, False, True, (230, 150, 140)),
30
+ CityscapesClass('building', 11, 2, 'construction', 2, False, False, (70, 70, 70)),
31
+ CityscapesClass('wall', 12, 3, 'construction', 2, False, False, (102, 102, 156)),
32
+ CityscapesClass('fence', 13, 4, 'construction', 2, False, False, (190, 153, 153)),
33
+ CityscapesClass('guard rail', 14, 255, 'construction', 2, False, True, (180, 165, 180)),
34
+ CityscapesClass('bridge', 15, 255, 'construction', 2, False, True, (150, 100, 100)),
35
+ CityscapesClass('tunnel', 16, 255, 'construction', 2, False, True, (150, 120, 90)),
36
+ CityscapesClass('pole', 17, 5, 'object', 3, False, False, (153, 153, 153)),
37
+ CityscapesClass('polegroup', 18, 255, 'object', 3, False, True, (153, 153, 153)),
38
+ CityscapesClass('traffic light', 19, 6, 'object', 3, False, False, (250, 170, 30)),
39
+ CityscapesClass('traffic sign', 20, 7, 'object', 3, False, False, (220, 220, 0)),
40
+ CityscapesClass('vegetation', 21, 8, 'nature', 4, False, False, (107, 142, 35)),
41
+ CityscapesClass('terrain', 22, 9, 'nature', 4, False, False, (152, 251, 152)),
42
+ CityscapesClass('sky', 23, 10, 'sky', 5, False, False, (70, 130, 180)),
43
+ CityscapesClass('person', 24, 11, 'human', 6, True, False, (220, 20, 60)),
44
+ CityscapesClass('rider', 25, 12, 'human', 6, True, False, (255, 0, 0)),
45
+ CityscapesClass('car', 26, 13, 'vehicle', 7, True, False, (0, 0, 142)),
46
+ CityscapesClass('truck', 27, 14, 'vehicle', 7, True, False, (0, 0, 70)),
47
+ CityscapesClass('bus', 28, 15, 'vehicle', 7, True, False, (0, 60, 100)),
48
+ CityscapesClass('caravan', 29, 255, 'vehicle', 7, True, True, (0, 0, 90)),
49
+ CityscapesClass('trailer', 30, 255, 'vehicle', 7, True, True, (0, 0, 110)),
50
+ CityscapesClass('train', 31, 16, 'vehicle', 7, True, False, (0, 80, 100)),
51
+ CityscapesClass('motorcycle', 32, 17, 'vehicle', 7, True, False, (0, 0, 230)),
52
+ CityscapesClass('bicycle', 33, 18, 'vehicle', 7, True, False, (119, 11, 32)),
53
+ CityscapesClass('license plate', -1, -1, 'vehicle', 7, False, True, (0, 0, 142)),
54
+ ]
55
+
56
+ cityscapes_id_to_trainID = {cls.id: cls.train_id for cls in classes}
57
+ cityscapes_trainID_to_testID = {cls.train_id: cls.id for cls in classes}
58
+ cityscapes_trainID_to_color = {cls.train_id: cls.color for cls in classes}
59
+ cityscapes_trainID_to_name = {cls.train_id: cls.name for cls in classes}
60
+ cityscapes_trainID_to_color[255] = (0, 0, 0)
61
+ cityscapes_trainID_to_name = {cls.train_id: cls.name for cls in classes}
62
+ cityscapes_trainID_to_name[255] = 'ignore'
63
+ cityscapes_trainID_to_name[19] = 'ignore'
64
+
65
+
66
+ def map2cs(seg):
67
+ while len(seg.shape) > 2:
68
+ seg = seg[0]
69
+ colors = cityscapes_trainID_to_color
70
+ # assert False, 'set ignore_idx color to black, make sure that it is not in colors'
71
+ rgb = np.zeros((seg.shape[0], seg.shape[1], 3), dtype=np.uint8)
72
+ for l in np.unique(seg):
73
+ rgb[seg == l, :] = colors[l]
74
+ return rgb
75
+
76
+
77
+ def get_colors(num_colors):
78
+ from PIL import ImageColor
79
+ import matplotlib
80
+ hex_colors = [
81
+ # "#000000", # keep the black reserved
82
+ "#FFFF00", "#1CE6FF", "#FF34FF", "#FF4A46", "#008941", "#006FA6", "#A30059",
83
+ "#FFDBE5", "#7A4900", "#0000A6", "#63FFAC", "#B79762", "#004D43", "#8FB0FF", "#997D87",
84
+ "#5A0007", "#809693", "#FEFFE6", "#1B4400", "#4FC601", "#3B5DFF", "#4A3B53", "#FF2F80",
85
+ "#61615A", "#BA0900", "#6B7900", "#00C2A0", "#FFAA92", "#FF90C9", "#B903AA", "#D16100",
86
+ "#DDEFFF", "#000035", "#7B4F4B", "#A1C299", "#300018", "#0AA6D8", "#013349", "#00846F",
87
+ "#372101", "#FFB500", "#C2FFED", "#A079BF", "#CC0744", "#C0B9B2", "#C2FF99", "#001E09",
88
+ "#00489C", "#6F0062", "#0CBD66", "#EEC3FF", "#456D75", "#B77B68", "#7A87A1", "#788D66",
89
+ "#885578", "#FAD09F", "#FF8A9A", "#D157A0", "#BEC459", "#456648", "#0086ED", "#886F4C",
90
+ "#34362D", "#B4A8BD", "#00A6AA", "#452C2C", "#636375", "#A3C8C9", "#FF913F", "#938A81",
91
+ "#575329", "#00FECF", "#B05B6F", "#8CD0FF", "#3B9700", "#04F757", "#C8A1A1", "#1E6E00",
92
+ "#7900D7", "#A77500", "#6367A9", "#A05837", "#6B002C", "#772600", "#D790FF", "#9B9700",
93
+ "#549E79", "#FFF69F", "#201625", "#72418F", "#BC23FF", "#99ADC0", "#3A2465", "#922329",
94
+ "#5B4534", "#FDE8DC", "#404E55", "#0089A3", "#CB7E98", "#A4E804", "#324E72", "#6A3A4C",
95
+ "#83AB58", "#001C1E", "#D1F7CE", "#004B28", "#C8D0F6", "#A3A489", "#806C66", "#222800",
96
+ "#BF5650", "#E83000", "#66796D", "#DA007C", "#FF1A59", "#8ADBB4", "#1E0200", "#5B4E51",
97
+ "#C895C5", "#320033", "#FF6832", "#66E1D3", "#CFCDAC", "#D0AC94", "#7ED379", "#012C58",
98
+ ]
99
+ hex_colors_mlib = list(matplotlib.colors.cnames.values())
100
+ for hcm in hex_colors_mlib:
101
+ if hcm not in hex_colors:
102
+ hex_colors.append(hcm)
103
+ colors = [ImageColor.getrgb(hex) for hex in hex_colors]
104
+ return colors[:num_colors]
105
+
106
+
107
+ def colorize_one(seg, ignore=None, colors=None, ncolors=32):
108
+ unq = np.unique(seg)
109
+ if ncolors is not None:
110
+ ncolors = max(ncolors, max(unq))
111
+ else:
112
+ ncolors = max(unq)
113
+ colors = get_colors(ncolors) if colors is None else colors
114
+ h, w = seg.shape
115
+ c = 3
116
+ rgb = np.zeros((h, w, c), dtype=np.uint8)
117
+ for l in unq:
118
+ if ignore is not None and l == ignore:
119
+ continue
120
+ try:
121
+ rgb[seg == l, :] = colors[l]
122
+ except:
123
+ raise Exception(l)
124
+ return rgb
125
+
126
+
127
+ def init_weights(m):
128
+ if isinstance(m, nn.Linear):
129
+ trunc_normal_(m.weight, std=0.02)
130
+ if isinstance(m, nn.Linear) and m.bias is not None:
131
+ nn.init.constant_(m.bias, 0)
132
+ elif isinstance(m, nn.LayerNorm):
133
+ nn.init.constant_(m.bias, 0)
134
+ nn.init.constant_(m.weight, 1.0)
135
+
136
+
137
+ def resize_pos_embed(posemb, grid_old_shape, grid_new_shape, num_extra_tokens):
138
+ # Rescale the grid of position embeddings when loading from state_dict. Adapted from
139
+ # https://github.com/google-research/vision_transformer/blob/00883dd691c63a6830751563748663526e811cee/vit_jax/checkpoint.py#L224
140
+ posemb_tok, posemb_grid = (
141
+ posemb[:, :num_extra_tokens],
142
+ posemb[0, num_extra_tokens:],
143
+ )
144
+ if grid_old_shape is None:
145
+ gs_old_h = int(math.sqrt(len(posemb_grid)))
146
+ gs_old_w = gs_old_h
147
+ else:
148
+ gs_old_h, gs_old_w = grid_old_shape
149
+
150
+ gs_h, gs_w = grid_new_shape
151
+ posemb_grid = posemb_grid.reshape(1, gs_old_h, gs_old_w, -1).permute(0, 3, 1, 2)
152
+ posemb_grid = F.interpolate(posemb_grid, size=(gs_h, gs_w), mode="bilinear")
153
+ posemb_grid = posemb_grid.permute(0, 2, 3, 1).reshape(1, gs_h * gs_w, -1)
154
+ posemb = torch.cat([posemb_tok, posemb_grid], dim=1)
155
+ return posemb
156
+
157
+
158
+ def checkpoint_filter_fn(state_dict, model):
159
+ """ convert patch embedding weight from manual patchify + linear proj to conv"""
160
+ out_dict = {}
161
+ if "model" in state_dict:
162
+ # For deit models
163
+ state_dict = state_dict["model"]
164
+ num_extra_tokens = 1 + ("dist_token" in state_dict.keys())
165
+ patch_size = model.patch_size
166
+ image_size = model.patch_embed.image_size
167
+ for k, v in state_dict.items():
168
+ if k == "pos_embed" and v.shape != model.pos_embed.shape:
169
+ # To resize pos embedding when using model at different size from pretrained weights
170
+ v = resize_pos_embed(
171
+ v,
172
+ None,
173
+ (image_size[0] // patch_size, image_size[1] // patch_size),
174
+ num_extra_tokens,
175
+ )
176
+ out_dict[k] = v
177
+ return out_dict
178
+
179
+
180
+ def padding(im, patch_size, fill_value=0):
181
+ # make the image sizes divisible by patch_size
182
+ H, W = im.size(2), im.size(3)
183
+ pad_h, pad_w = 0, 0
184
+ if H % patch_size > 0:
185
+ pad_h = patch_size - (H % patch_size)
186
+ if W % patch_size > 0:
187
+ pad_w = patch_size - (W % patch_size)
188
+ im_padded = im
189
+ if pad_h > 0 or pad_w > 0:
190
+ im_padded = F.pad(im, (0, pad_w, 0, pad_h), value=fill_value)
191
+ return im_padded
192
+
193
+
194
+ def unpadding(y, target_size):
195
+ H, W = target_size
196
+ H_pad, W_pad = y.size(2), y.size(3)
197
+ # crop predictions on extra pixels coming from padding
198
+ extra_h = H_pad - H
199
+ extra_w = W_pad - W
200
+ if extra_h > 0:
201
+ y = y[:, :, :-extra_h]
202
+ if extra_w > 0:
203
+ y = y[:, :, :, :-extra_w]
204
+ return y
205
+
206
+
207
+ def resize(im, smaller_size):
208
+ h, w = im.shape[2:]
209
+ if h < w:
210
+ ratio = w / h
211
+ h_res, w_res = smaller_size, ratio * smaller_size
212
+ else:
213
+ ratio = h / w
214
+ h_res, w_res = ratio * smaller_size, smaller_size
215
+ if min(h, w) < smaller_size:
216
+ im_res = F.interpolate(im, (int(h_res), int(w_res)), mode="bilinear")
217
+ else:
218
+ im_res = im
219
+ return im_res
220
+
221
+
222
+ def sliding_window(im, flip, window_size, window_stride, channels_first=True):
223
+ if channels_first:
224
+ B, C, H, W = im.shape
225
+ else:
226
+ B, H, W, C = im.shape
227
+ ws = window_size
228
+
229
+ windows = {"crop": [], "anchors": []}
230
+ h_anchors = torch.arange(0, H, window_stride)
231
+ w_anchors = torch.arange(0, W, window_stride)
232
+ h_anchors = [h.item() for h in h_anchors if h < H - ws] + [H - ws]
233
+ w_anchors = [w.item() for w in w_anchors if w < W - ws] + [W - ws]
234
+ for ha in h_anchors:
235
+ for wa in w_anchors:
236
+ if channels_first:
237
+ window = im[:, :, ha: ha + ws, wa: wa + ws]
238
+ else:
239
+ window = im[:, ha: ha + ws, wa: wa + ws]
240
+ windows["crop"].append(window)
241
+ windows["anchors"].append((ha, wa))
242
+ windows["flip"] = flip
243
+ windows["shape"] = (H, W)
244
+ return windows
245
+
246
+
247
+ def merge_windows(windows, window_size, ori_shape, no_softmax=False, no_upsample=False, patch_size=None):
248
+ ws = window_size
249
+ im_windows = windows["seg_maps"]
250
+ anchors = windows["anchors"]
251
+ C = im_windows[0].shape[0]
252
+ H, W = windows["shape"]
253
+ flip = windows["flip"]
254
+
255
+ if no_upsample:
256
+ H, W = H // patch_size, W // patch_size
257
+
258
+ logit = torch.zeros((C, H, W), device=im_windows.device)
259
+ count = torch.zeros((1, H, W), device=im_windows.device)
260
+ for window, (ha, wa) in zip(im_windows, anchors):
261
+ if no_upsample:
262
+ ha = ha // patch_size
263
+ wa = wa // patch_size
264
+ logit[:, ha: ha + ws, wa: wa + ws] += window
265
+ count[:, ha: ha + ws, wa: wa + ws] += 1
266
+ logit /= count
267
+ # print('Interpolate {} -> {}'.format(logit.shape, ori_shape))
268
+ if not no_upsample:
269
+ logit = F.interpolate(
270
+ logit.unsqueeze(0),
271
+ ori_shape,
272
+ mode="bilinear",
273
+ )[0]
274
+ if flip:
275
+ logit = torch.flip(logit, (2,))
276
+ if not no_softmax:
277
+ # print('Softmax in merge_windows')
278
+ result = F.softmax(logit, 0)
279
+ else:
280
+ # print('No softmax in merge_windows')
281
+ result = logit
282
+ return result
283
+
284
+
285
+ def debug_windows(windows, debug_file):
286
+ pass
287
+
288
+
289
+ def inference_picie(
290
+ model,
291
+ classifier,
292
+ metric_test,
293
+ ims,
294
+ ori_shape,
295
+ window_size,
296
+ window_stride,
297
+ batch_size,
298
+ decoder_features=False,
299
+ no_upsample=False,
300
+ debug_file=None,
301
+ im_rgb=None,
302
+ channel_first=False
303
+ ):
304
+ try:
305
+ C = model.n_cls
306
+ except:
307
+ C = classifier.module.bias.shape[0]
308
+
309
+ # seg_maps = []
310
+
311
+ # for im, im_metas in zip(ims, ims_metas):
312
+ for im in ims:
313
+ im = im.to('cuda')
314
+ if len(im.shape) == 3:
315
+ im = im.unsqueeze(0)
316
+ flip = False # im_metas["flip"]
317
+ windows = sliding_window(im, flip, window_size, window_stride)
318
+ crops = torch.stack(windows.pop("crop"))[:, 0]
319
+ num_crops = len(crops)
320
+
321
+ WB = batch_size if batch_size > 0 else num_crops
322
+ if no_upsample:
323
+ window_size = window_size // model.patch_size
324
+ seg_maps = torch.zeros((num_crops, C, window_size, window_size), device=im.device)
325
+ with torch.no_grad():
326
+ for i in range(0, num_crops, WB):
327
+ # try:
328
+ feats = model.forward(crops[i: i + WB])
329
+ if metric_test == 'cosine':
330
+ feats = F.normalize(feats, dim=1, p=2)
331
+ probs = classifier(feats)
332
+ probs = F.interpolate(probs, crops[i: i + WB].shape[-2:], mode='bilinear', align_corners=False)
333
+ seg_maps[i: i + WB] = probs
334
+ windows["seg_maps"] = seg_maps
335
+
336
+ if debug_file is not None:
337
+ if isinstance(im_rgb, torch.Tensor):
338
+ im_rgb = im_rgb.detach().cpu().numpy()
339
+ if len(im_rgb.shape) == 4:
340
+ im_rgb = im_rgb[0]
341
+ h, w = im.shape[-2:]
342
+ im_rgb = cv2.resize(im_rgb, (w, h), interpolation=cv2.INTER_LINEAR)
343
+
344
+ crops_rgb = np.stack(
345
+ sliding_window(im_rgb[None, :], flip, window_size, window_stride, channels_first=channel_first).pop(
346
+ "crop"))[:, 0]
347
+
348
+ im_seg_map = merge_windows(windows, window_size, ori_shape, no_softmax=decoder_features,
349
+ no_upsample=no_upsample, patch_size=None)
350
+
351
+ seg_map = im_seg_map
352
+ if no_upsample and not decoder_features:
353
+ pass
354
+ else:
355
+ seg_map = F.interpolate(
356
+ seg_map.unsqueeze(0),
357
+ ori_shape,
358
+ mode="bilinear",
359
+ )
360
+
361
+ return seg_map
362
+
363
+
364
+ def inference(
365
+ model,
366
+ ims,
367
+ ori_shape,
368
+ window_size,
369
+ window_stride,
370
+ batch_size,
371
+ decoder_features=False,
372
+ encoder_features=False,
373
+ save2cpu=False,
374
+ no_upsample=False,
375
+ debug_file=None,
376
+ im_rgb=None,
377
+ channel_first=False
378
+ ):
379
+ C = model.n_cls
380
+ patch_size = model.patch_size
381
+
382
+ # seg_maps = []
383
+
384
+ # for im, im_metas in zip(ims, ims_metas):
385
+ for im in ims:
386
+ im = im.to('cuda')
387
+ if len(im.shape) == 3:
388
+ im = im.unsqueeze(0)
389
+ # im = resize(im, window_size)
390
+ flip = False # im_metas["flip"]
391
+ # print(im)
392
+ windows = sliding_window(im, flip, window_size, window_stride)
393
+ # print(windows)
394
+ crops = torch.stack(windows.pop("crop"))[:, 0]
395
+ num_crops = len(crops)
396
+
397
+ WB = batch_size if batch_size > 0 else num_crops
398
+ if no_upsample:
399
+ window_size = window_size // model.patch_size
400
+ # print('Change variable window_size to {}'.format(window_size))
401
+ seg_maps = torch.zeros((num_crops, C, window_size, window_size), device=im.device)
402
+ # print('Allocated segm_maps: {}, device: {}'.format(seg_maps.shape, seg_maps.device))
403
+ with torch.no_grad():
404
+ for i in range(0, num_crops, WB):
405
+ # try:
406
+ seg_maps[i: i + WB] = model.forward(crops[i: i + WB], decoder_features=decoder_features,
407
+ encoder_features=encoder_features,
408
+ no_upsample=no_upsample)
409
+ # except:
410
+ # print('Input of shape: {}'.format(crops[i:i + WB].shape))
411
+ # assert False, "End after error."
412
+ # torch.cuda.empty_cache()
413
+ windows["seg_maps"] = seg_maps
414
+
415
+ if debug_file is not None:
416
+ if isinstance(im_rgb, torch.Tensor):
417
+ im_rgb = im_rgb.detach().cpu().numpy()
418
+ if len(im_rgb.shape) == 4:
419
+ im_rgb = im_rgb[0]
420
+ h, w = im.shape[-2:]
421
+ im_rgb = cv2.resize(im_rgb, (w, h), interpolation=cv2.INTER_LINEAR)
422
+
423
+ crops_rgb = np.stack(
424
+ sliding_window(im_rgb[None, :], flip, window_size, window_stride, channels_first=channel_first).pop(
425
+ "crop"))[:, 0]
426
+
427
+ windows_row = np.concatenate([w for w in crops_rgb], axis=1)
428
+ # print(windows_row)
429
+ try:
430
+ Image.fromarray(windows_row).save(debug_file)
431
+ except:
432
+ pass
433
+
434
+ suffix = debug_file[-4:]
435
+ debug_file = debug_file.replace(suffix, '_preds{}'.format(suffix))
436
+ windows_preds = seg_maps.argmax(dim=1).cpu().numpy()
437
+ windows_preds_row = np.concatenate([seg2rgb(wp, C, 255) for wp in windows_preds], axis=1)
438
+ windows_row_plus_preds = np.concatenate((windows_row, windows_preds_row), axis=0)
439
+ try:
440
+ Image.fromarray(windows_preds_row).save(debug_file)
441
+ except:
442
+ pass
443
+
444
+ debug_file = debug_file.replace(suffix, '_wImg{}'.format(suffix))
445
+ try:
446
+ Image.fromarray(windows_row_plus_preds).save(debug_file)
447
+ except:
448
+ pass
449
+
450
+ im_seg_map = merge_windows(windows, window_size, ori_shape, no_softmax=decoder_features,
451
+ no_upsample=no_upsample, patch_size=model.patch_size)
452
+
453
+ seg_map = im_seg_map
454
+ if no_upsample and not decoder_features:
455
+ pass
456
+ else:
457
+ seg_map = F.interpolate(
458
+ seg_map.unsqueeze(0),
459
+ ori_shape,
460
+ mode="bilinear",
461
+ )
462
+ # seg_maps.append(seg_map)
463
+
464
+ # print('Done one inference.')
465
+ # seg_maps = torch.cat(seg_maps, dim=0)
466
+ return seg_map
467
+
468
+
469
+ def inference_features(
470
+ model,
471
+ ims,
472
+ ori_shape,
473
+ window_size,
474
+ window_stride,
475
+ batch_size,
476
+ decoder_features=False,
477
+ encoder_features=False,
478
+ save2cpu=False,
479
+ no_upsample=True,
480
+ encoder_only=False
481
+ ):
482
+ C = model.n_cls if decoder_features else model.encoder.d_model
483
+ patch_size = model.patch_size
484
+
485
+ # seg_maps = []
486
+
487
+ # for im, im_metas in zip(ims, ims_metas):
488
+ for im in ims:
489
+ im = im.to('cuda')
490
+ if len(im.shape) == 3:
491
+ im = im.unsqueeze(0)
492
+ # im = resize(im, window_size)
493
+ flip = False # im_metas["flip"]
494
+ # print(im)
495
+ windows = sliding_window(im, flip, window_size, window_stride)
496
+ # print(windows)
497
+ crops = torch.stack(windows.pop("crop"))[:, 0]
498
+ num_crops = len(crops)
499
+
500
+ WB = batch_size if batch_size > 0 else num_crops
501
+ if no_upsample:
502
+ window_size = window_size // model.patch_size
503
+ # print('Change variable window_size to {}'.format(window_size))
504
+ enc_maps = torch.zeros((num_crops, C, window_size, window_size), device=im.device)
505
+ if decoder_features:
506
+ dec_maps = torch.zeros((num_crops, C, window_size, window_size), device=im.device)
507
+ # print('Allocated segm_maps: {}, device: {}'.format(seg_maps.shape, seg_maps.device))
508
+ with torch.no_grad():
509
+ for i in range(0, num_crops, WB):
510
+ enc_fts = model.forward(crops[i: i + WB], decoder_features=decoder_features,
511
+ encoder_features=True,
512
+ no_upsample=no_upsample, encoder_only=encoder_only)
513
+ if decoder_features:
514
+ enc_fts, dec_fts = enc_fts
515
+ dec_maps[i: i + WB] = dec_fts
516
+ elif isinstance(enc_fts, tuple):
517
+ enc_fts = enc_fts[0]
518
+ enc_maps[i: i + WB] = enc_fts
519
+
520
+ windows["seg_maps"] = enc_maps
521
+ im_enc_map = merge_windows(windows, window_size, ori_shape, no_softmax=decoder_features,
522
+ no_upsample=no_upsample, patch_size=model.patch_size)
523
+
524
+ if decoder_features:
525
+ windows["seg_maps"] = dec_maps
526
+ im_dec_map = merge_windows(windows, window_size, ori_shape, no_softmax=decoder_features,
527
+ no_upsample=no_upsample, patch_size=model.patch_size)
528
+
529
+ if no_upsample:
530
+ pass
531
+ else:
532
+ im_enc_map = F.interpolate(
533
+ im_enc_map.unsqueeze(0),
534
+ ori_shape,
535
+ mode="bilinear",
536
+ )
537
+ if decoder_features:
538
+ im_dec_map = F.interpolate(
539
+ im_dec_map.unsqueeze(0),
540
+ ori_shape,
541
+ mode="bilinear",
542
+ )
543
+
544
+ im_enc_map = im_enc_map.cpu().numpy()
545
+ if decoder_features:
546
+ im_dec_map = im_dec_map.cpu().numpy()
547
+ return im_enc_map, im_dec_map
548
+
549
+ return im_enc_map
550
+
551
+
552
+ def inference_conv(
553
+ model,
554
+ ims,
555
+ ims_metas,
556
+ ori_shape
557
+ ):
558
+ assert len(ims) == 1
559
+ for im, im_metas in zip(ims, ims_metas):
560
+ im = im.to(ptu.device)
561
+ if len(im.shape) < 4:
562
+ im = im.unsqueeze(0)
563
+ logits = model(im)
564
+ if ori_shape[:2] != logits.shape[-2:]:
565
+ # resize
566
+ logits = F.interpolate(
567
+ logits,
568
+ ori_shape[-2:],
569
+ mode="bilinear",
570
+ )
571
+ # 3) applies softmax
572
+ result = F.softmax(logits.squeeze(), 0)
573
+ # print(result.shape)
574
+ return result
575
+
576
+
577
+ def num_params(model):
578
+ model_parameters = filter(lambda p: p.requires_grad, model.parameters())
579
+ n_params = sum([torch.prod(torch.tensor(p.size())) for p in model_parameters])
580
+ if not type(n_params) == int:
581
+ n_params = n_params.item()
582
+ return n_params
segmenter_model/vit_dino.py ADDED
@@ -0,0 +1,348 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copied from DINO
2
+ # Copyright (c) Facebook, Inc. and its affiliates.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """
16
+ Mostly copy-paste from timm library.
17
+ https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
18
+ """
19
+ import math
20
+ import warnings
21
+ from functools import partial
22
+
23
+ import torch
24
+ import torch.nn as nn
25
+
26
+
27
+ def _no_grad_trunc_normal_(tensor, mean, std, a, b):
28
+ # Cut & paste from PyTorch official master until it's in a few official releases - RW
29
+ # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
30
+ def norm_cdf(x):
31
+ # Computes standard normal cumulative distribution function
32
+ return (1. + math.erf(x / math.sqrt(2.))) / 2.
33
+
34
+ if (mean < a - 2 * std) or (mean > b + 2 * std):
35
+ warnings.warn("mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
36
+ "The distribution of values may be incorrect.",
37
+ stacklevel=2)
38
+
39
+ with torch.no_grad():
40
+ # Values are generated by using a truncated uniform distribution and
41
+ # then using the inverse CDF for the normal distribution.
42
+ # Get upper and lower cdf values
43
+ l = norm_cdf((a - mean) / std)
44
+ u = norm_cdf((b - mean) / std)
45
+
46
+ # Uniformly fill tensor with values from [l, u], then translate to
47
+ # [2l-1, 2u-1].
48
+ tensor.uniform_(2 * l - 1, 2 * u - 1)
49
+
50
+ # Use inverse cdf transform for normal distribution to get truncated
51
+ # standard normal
52
+ tensor.erfinv_()
53
+
54
+ # Transform to proper mean, std
55
+ tensor.mul_(std * math.sqrt(2.))
56
+ tensor.add_(mean)
57
+
58
+ # Clamp to ensure it's in the proper range
59
+ tensor.clamp_(min=a, max=b)
60
+ return tensor
61
+
62
+
63
+ def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.):
64
+ # type: (Tensor, float, float, float, float) -> Tensor
65
+ return _no_grad_trunc_normal_(tensor, mean, std, a, b)
66
+
67
+
68
+ def drop_path(x, drop_prob: float = 0., training: bool = False):
69
+ if drop_prob == 0. or not training:
70
+ return x
71
+ keep_prob = 1 - drop_prob
72
+ shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets
73
+ random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
74
+ random_tensor.floor_() # binarize
75
+ output = x.div(keep_prob) * random_tensor
76
+ return output
77
+
78
+
79
+ class DropPath(nn.Module):
80
+ """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
81
+ """
82
+
83
+ def __init__(self, drop_prob=None):
84
+ super(DropPath, self).__init__()
85
+ self.drop_prob = drop_prob
86
+
87
+ def forward(self, x):
88
+ return drop_path(x, self.drop_prob, self.training)
89
+
90
+
91
+ class Mlp(nn.Module):
92
+ def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
93
+ super().__init__()
94
+ out_features = out_features or in_features
95
+ hidden_features = hidden_features or in_features
96
+ self.fc1 = nn.Linear(in_features, hidden_features)
97
+ self.act = act_layer()
98
+ self.fc2 = nn.Linear(hidden_features, out_features)
99
+ self.drop = nn.Dropout(drop)
100
+
101
+ def forward(self, x):
102
+ x = self.fc1(x)
103
+ x = self.act(x)
104
+ x = self.drop(x)
105
+ x = self.fc2(x)
106
+ x = self.drop(x)
107
+ return x
108
+
109
+
110
+ class Attention(nn.Module):
111
+ def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):
112
+ super().__init__()
113
+ self.num_heads = num_heads
114
+ head_dim = dim // num_heads
115
+ self.scale = qk_scale or head_dim ** -0.5
116
+
117
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
118
+ self.attn_drop = nn.Dropout(attn_drop)
119
+ self.proj = nn.Linear(dim, dim)
120
+ self.proj_drop = nn.Dropout(proj_drop)
121
+
122
+ def forward(self, x):
123
+ B, N, C = x.shape
124
+ qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
125
+ q, k, v = qkv[0], qkv[1], qkv[2]
126
+
127
+ attn = (q @ k.transpose(-2, -1)) * self.scale
128
+ attn = attn.softmax(dim=-1)
129
+ attn = self.attn_drop(attn)
130
+
131
+ x = (attn @ v).transpose(1, 2).reshape(B, N, C)
132
+ x = self.proj(x)
133
+ x = self.proj_drop(x)
134
+ return x, attn
135
+
136
+
137
+ class Block(nn.Module):
138
+ def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
139
+ drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm):
140
+ super().__init__()
141
+ self.norm1 = norm_layer(dim)
142
+ self.attn = Attention(
143
+ dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
144
+ self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
145
+ self.norm2 = norm_layer(dim)
146
+ mlp_hidden_dim = int(dim * mlp_ratio)
147
+ self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
148
+
149
+ def forward(self, x, return_attention=False):
150
+ y, attn = self.attn(self.norm1(x))
151
+ if return_attention:
152
+ return attn
153
+ x = x + self.drop_path(y)
154
+ x = x + self.drop_path(self.mlp(self.norm2(x)))
155
+ return x
156
+
157
+
158
+ class PatchEmbed(nn.Module):
159
+ """ Image to Patch Embedding
160
+ """
161
+
162
+ def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):
163
+ super().__init__()
164
+ num_patches = (img_size // patch_size) * (img_size // patch_size)
165
+ self.img_size = img_size
166
+ self.patch_size = patch_size
167
+ self.num_patches = num_patches
168
+
169
+ self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
170
+
171
+ def forward(self, x):
172
+ B, C, H, W = x.shape
173
+ x = self.proj(x).flatten(2).transpose(1, 2)
174
+ return x
175
+
176
+
177
+ class VisionTransformer(nn.Module):
178
+ """ Vision Transformer """
179
+
180
+ def __init__(self, img_size=[224], patch_size=16, in_chans=3, num_classes=0, embed_dim=768, depth=12,
181
+ num_heads=12, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop_rate=0., attn_drop_rate=0.,
182
+ drop_path_rate=0., norm_layer=nn.LayerNorm, **kwargs):
183
+ super().__init__()
184
+ self.num_features = self.embed_dim = embed_dim
185
+
186
+ self.patch_embed = PatchEmbed(
187
+ img_size=img_size[0], patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
188
+ num_patches = self.patch_embed.num_patches
189
+
190
+ self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
191
+ self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
192
+ self.pos_drop = nn.Dropout(p=drop_rate)
193
+
194
+ dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule
195
+ self.blocks = nn.ModuleList([
196
+ Block(
197
+ dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
198
+ drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer)
199
+ for i in range(depth)])
200
+ self.norm = norm_layer(embed_dim)
201
+
202
+ # Classifier head
203
+ self.head = nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity()
204
+
205
+ trunc_normal_(self.pos_embed, std=.02)
206
+ trunc_normal_(self.cls_token, std=.02)
207
+ self.apply(self._init_weights)
208
+
209
+ def _init_weights(self, m):
210
+ if isinstance(m, nn.Linear):
211
+ trunc_normal_(m.weight, std=.02)
212
+ if isinstance(m, nn.Linear) and m.bias is not None:
213
+ nn.init.constant_(m.bias, 0)
214
+ elif isinstance(m, nn.LayerNorm):
215
+ nn.init.constant_(m.bias, 0)
216
+ nn.init.constant_(m.weight, 1.0)
217
+
218
+ def interpolate_pos_encoding(self, x, w, h):
219
+ npatch = x.shape[1] - 1
220
+ N = self.pos_embed.shape[1] - 1
221
+ if npatch == N and w == h:
222
+ return self.pos_embed
223
+ class_pos_embed = self.pos_embed[:, 0]
224
+ patch_pos_embed = self.pos_embed[:, 1:]
225
+ dim = x.shape[-1]
226
+ w0 = w // self.patch_embed.patch_size
227
+ h0 = h // self.patch_embed.patch_size
228
+ # we add a small number to avoid floating point error in the interpolation
229
+ # see discussion at https://github.com/facebookresearch/dino/issues/8
230
+ w0, h0 = w0 + 0.1, h0 + 0.1
231
+ patch_pos_embed = nn.functional.interpolate(
232
+ patch_pos_embed.reshape(1, int(math.sqrt(N)), int(math.sqrt(N)), dim).permute(0, 3, 1, 2),
233
+ scale_factor=(w0 / math.sqrt(N), h0 / math.sqrt(N)),
234
+ mode='bicubic',
235
+ )
236
+ assert int(w0) == patch_pos_embed.shape[-2] and int(h0) == patch_pos_embed.shape[-1]
237
+ patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
238
+ return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1)
239
+
240
+ def prepare_tokens(self, x):
241
+ B, nc, w, h = x.shape
242
+ x = self.patch_embed(x) # patch linear embedding
243
+
244
+ # add the [CLS] token to the embed patch tokens
245
+ cls_tokens = self.cls_token.expand(B, -1, -1)
246
+ x = torch.cat((cls_tokens, x), dim=1)
247
+
248
+ # add positional encoding to each token
249
+ x = x + self.interpolate_pos_encoding(x, w, h)
250
+
251
+ return self.pos_drop(x)
252
+
253
+ def forward(self, x):
254
+ x = self.prepare_tokens(x)
255
+ for blk in self.blocks:
256
+ x = blk(x)
257
+ x = self.norm(x)
258
+ return x[:, 0]
259
+
260
+ def get_last_selfattention(self, x):
261
+ x = self.prepare_tokens(x)
262
+ for i, blk in enumerate(self.blocks):
263
+ if i < len(self.blocks) - 1:
264
+ x = blk(x)
265
+ else:
266
+ # return attention of the last block
267
+ return blk(x, return_attention=True)
268
+
269
+ def get_n_last_selfattentions(self, x, layers_from_end=(1)):
270
+ x = self.prepare_tokens(x)
271
+ attentions = []
272
+ for i, blk in enumerate(self.blocks):
273
+ num_from_end = len(self.blocks) - i
274
+ if num_from_end in layers_from_end:
275
+ # get attention of the block
276
+ attn = blk(x, return_attention=True)
277
+ attentions.append(attn)
278
+ x = blk(x)
279
+ return attentions
280
+
281
+ def get_intermediate_layers(self, x, n=1):
282
+ x = self.prepare_tokens(x)
283
+ # we return the output tokens from the `n` last blocks
284
+ output = []
285
+ for i, blk in enumerate(self.blocks):
286
+ x = blk(x)
287
+ if len(self.blocks) - i <= n:
288
+ output.append(self.norm(x))
289
+ return output
290
+
291
+
292
+ def vit_tiny(patch_size=16, **kwargs):
293
+ model = VisionTransformer(
294
+ patch_size=patch_size, embed_dim=192, depth=12, num_heads=3, mlp_ratio=4,
295
+ qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
296
+ return model
297
+
298
+
299
+ def vit_small(patch_size=16, **kwargs):
300
+ model = VisionTransformer(
301
+ patch_size=patch_size, embed_dim=384, depth=12, num_heads=6, mlp_ratio=4,
302
+ qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
303
+ return model
304
+
305
+
306
+ def vit_base(patch_size=16, **kwargs):
307
+ model = VisionTransformer(
308
+ patch_size=patch_size, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4,
309
+ qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
310
+ return model
311
+
312
+
313
+ class DINOHead(nn.Module):
314
+ def __init__(self, in_dim, out_dim, use_bn=False, norm_last_layer=True, nlayers=3, hidden_dim=2048,
315
+ bottleneck_dim=256):
316
+ super().__init__()
317
+ nlayers = max(nlayers, 1)
318
+ if nlayers == 1:
319
+ self.mlp = nn.Linear(in_dim, bottleneck_dim)
320
+ else:
321
+ layers = [nn.Linear(in_dim, hidden_dim)]
322
+ if use_bn:
323
+ layers.append(nn.BatchNorm1d(hidden_dim))
324
+ layers.append(nn.GELU())
325
+ for _ in range(nlayers - 2):
326
+ layers.append(nn.Linear(hidden_dim, hidden_dim))
327
+ if use_bn:
328
+ layers.append(nn.BatchNorm1d(hidden_dim))
329
+ layers.append(nn.GELU())
330
+ layers.append(nn.Linear(hidden_dim, bottleneck_dim))
331
+ self.mlp = nn.Sequential(*layers)
332
+ self.apply(self._init_weights)
333
+ self.last_layer = nn.utils.weight_norm(nn.Linear(bottleneck_dim, out_dim, bias=False))
334
+ self.last_layer.weight_g.data.fill_(1)
335
+ if norm_last_layer:
336
+ self.last_layer.weight_g.requires_grad = False
337
+
338
+ def _init_weights(self, m):
339
+ if isinstance(m, nn.Linear):
340
+ trunc_normal_(m.weight, std=.02)
341
+ if isinstance(m, nn.Linear) and m.bias is not None:
342
+ nn.init.constant_(m.bias, 0)
343
+
344
+ def forward(self, x):
345
+ x = self.mlp(x)
346
+ x = nn.functional.normalize(x, dim=-1, p=2)
347
+ x = self.last_layer(x)
348
+ return x