gheinrich committed on
Commit 427e2df
1 Parent(s): 32045a2

Upload model

Files changed (9)
  1. README.md +3 -0
  2. adaptor_base.py +35 -0
  3. common.py +51 -0
  4. config.json +146 -39
  5. eradio_model.py +876 -499
  6. hf_model.py +55 -68
  7. input_conditioner.py +4 -4
  8. model.safetensors +3 -0
  9. radio_model.py +95 -28
README.md CHANGED
@@ -1,3 +1,6 @@
+ ---
+ {}
+ ---
  # AM-RADIO: Reduce All Domains Into One

  Mike Ranzinger, Greg Heinrich, Jan Kautz, Pavlo Molchanov
adaptor_base.py ADDED
@@ -0,0 +1,35 @@
+ # Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
+ #
+ # NVIDIA CORPORATION and its licensors retain all intellectual property
+ # and proprietary rights in and to this software, related documentation
+ # and any modifications thereto. Any use, reproduction, disclosure or
+ # distribution of this software and related documentation without an express
+ # license agreement from NVIDIA CORPORATION is strictly prohibited.
+ from argparse import Namespace
+ from typing import NamedTuple
+
+ import torch
+ from torch import nn
+ import torch.nn.functional as F
+
+
+ class AdaptorInput(NamedTuple):
+     images: torch.Tensor
+     summary: torch.Tensor
+     features: torch.Tensor
+
+
+ class RadioOutput(NamedTuple):
+     summary: torch.Tensor
+     features: torch.Tensor
+
+     def to(self, *args, **kwargs):
+         return RadioOutput(
+             self.summary.to(*args, **kwargs) if self.summary is not None else None,
+             self.features.to(*args, **kwargs) if self.features is not None else None,
+         )
+
+
+ class AdaptorBase(nn.Module):
+     def forward(self, input: AdaptorInput) -> RadioOutput:
+         raise NotImplementedError("Subclasses must implement this!")
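For orientation, a minimal usage sketch (not part of this commit; the IdentityAdaptor name and the tensor shapes are made up for illustration, and it assumes adaptor_base.py is importable):

    import torch
    from adaptor_base import AdaptorBase, AdaptorInput, RadioOutput

    class IdentityAdaptor(AdaptorBase):
        # Pass the backbone outputs through unchanged.
        def forward(self, input: AdaptorInput) -> RadioOutput:
            return RadioOutput(summary=input.summary, features=input.features)

    adaptor = IdentityAdaptor()
    out = adaptor(AdaptorInput(
        images=torch.zeros(1, 3, 432, 432),    # input batch (shape is illustrative)
        summary=torch.zeros(1, 1280),          # global summary embedding
        features=torch.zeros(1, 729, 1280),    # per-patch features
    ))
    out = out.to(torch.float16)                # RadioOutput.to() maps over both tensors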
common.py ADDED
@@ -0,0 +1,51 @@
+ # Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
+ #
+ # NVIDIA CORPORATION and its licensors retain all intellectual property
+ # and proprietary rights in and to this software, related documentation
+ # and any modifications thereto. Any use, reproduction, disclosure or
+ # distribution of this software and related documentation without an express
+ # license agreement from NVIDIA CORPORATION is strictly prohibited.
+
+ from dataclasses import dataclass
+
+ from .radio_model import Resolution
+
+
+ @dataclass
+ class RadioResource:
+     url: str
+     patch_size: int
+     max_resolution: int
+     preferred_resolution: Resolution
+
+
+ RESOURCE_MAP = {
+     # RADIO
+     "radio_v2.1": RadioResource(
+         "https://huggingface.co/nvidia/RADIO/resolve/main/radio_v2.1_bf16.pth.tar?download=true",
+         patch_size=16,
+         max_resolution=2048,
+         preferred_resolution=Resolution(432, 432),
+     ),
+     "radio_v2": RadioResource(
+         "https://huggingface.co/nvidia/RADIO/resolve/main/radio_v2.pth.tar?download=true",
+         patch_size=16,
+         max_resolution=2048,
+         preferred_resolution=Resolution(432, 432),
+     ),
+     "radio_v1": RadioResource(
+         "https://huggingface.co/nvidia/RADIO/resolve/main/radio_v1.pth.tar?download=true",
+         patch_size=14,
+         max_resolution=1050,
+         preferred_resolution=Resolution(378, 378),
+     ),
+     # E-RADIO
+     "e-radio_v2": RadioResource(
+         "https://huggingface.co/nvidia/RADIO/resolve/main/eradio_v2.pth.tar?download=true",
+         patch_size=16,
+         max_resolution=2048,
+         preferred_resolution=Resolution(512, 512),
+     ),
+ }
+
+ DEFAULT_VERSION = "radio_v2.1"
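A quick sketch of how this table might be consumed (assumptions: the files are loaded as a package so the relative import of Resolution resolves, Resolution unpacks as (height, width), and torch.hub is an acceptable way to fetch the checkpoint):

    import torch
    from radio.common import RESOURCE_MAP, DEFAULT_VERSION  # hypothetical package path

    resource = RESOURCE_MAP[DEFAULT_VERSION]
    print(resource.url)                   # checkpoint download URL
    print(resource.patch_size)            # 16 for radio_v2.1
    print(resource.preferred_resolution)  # Resolution(432, 432)

    # Input sides should be multiples of patch_size and at most max_resolution.
    h, w = resource.preferred_resolution
    assert h % resource.patch_size == 0 and h <= resource.max_resolution

    # Download the weights; the structure of the loaded checkpoint dict is not
    # shown in this commit, so treat anything beyond the download as an assumption.
    chk = torch.hub.load_state_dict_from_url(resource.url, progress=True)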
config.json CHANGED
@@ -1,6 +1,7 @@
1
  {
 
2
  "architectures": [
3
- "ERADIOModel"
4
  ],
5
  "args": {
6
  "aa": null,
@@ -9,16 +10,18 @@
9
  "amp_impl": "native",
10
  "aug_repeats": 0,
11
  "aug_splits": 0,
 
12
  "batch_size": 32,
13
  "bn_eps": null,
14
  "bn_momentum": null,
15
- "cache": "/lustre/fs3/portfolios/llmservice/users/gheinrich/cache/",
16
  "cache_dir": null,
17
  "channels_last": false,
18
- "checkpoint_hist": 3,
 
19
  "class_map": "",
20
  "clip_grad": null,
21
  "clip_mode": "norm",
 
22
  "coco_annotations_file": "/datasets/coco2017-adlsa/annotations/captions_val2017.json",
23
  "coco_image_dir": "/datasets/coco2017-adlsa/val2017",
24
  "color_jitter": 0.4,
@@ -29,9 +32,20 @@
29
  "crop_pct": null,
30
  "cutmix": 0.0,
31
  "cutmix_minmax": null,
32
- "data_dir": "/lustre/fsw/portfolios/llmservice/projects/llmservice_nlp_fm/datasets/classification/imagenet-21k/webdataset",
 
 
 
 
 
 
 
 
 
33
  "dataset": "nvgpt4",
34
  "dataset_download": false,
 
 
35
  "debug_full_knn": false,
36
  "decay_epochs": 90,
37
  "decay_milestones": [
@@ -41,14 +55,16 @@
41
  ],
42
  "decay_rate": 0.1,
43
  "device": "cuda:0",
44
- "dist_bn": "reduce",
45
  "distributed": true,
46
  "drop": 0.0,
47
  "drop_block": null,
48
  "drop_connect": null,
49
  "drop_path": null,
 
50
  "epoch_repeats": 0.0,
51
- "epochs": 300,
 
52
  "eval_metric": "knn_top1",
53
  "eval_teacher": false,
54
  "eval_teacher_only": false,
@@ -57,15 +73,19 @@
57
  "fast_norm": false,
58
  "feature_summarizer": "cls_token",
59
  "feature_upscale_factor": null,
 
 
 
60
  "fuser": "",
61
  "gp": null,
62
  "grad_accum_steps": 1,
63
  "grad_checkpointing": false,
64
  "head_init_bias": null,
65
  "head_init_scale": null,
 
66
  "hflip": 0.5,
67
  "img_size": null,
68
- "in_chans": 3,
69
  "initial_checkpoint": "",
70
  "input_size": null,
71
  "interpolation": "",
@@ -75,6 +95,7 @@
75
  "log_mlflow": false,
76
  "log_wandb": true,
77
  "loss": "cosine",
 
78
  "lr": 0.001,
79
  "lr_base": 0.1,
80
  "lr_base_scale": "",
@@ -87,34 +108,45 @@
87
  "lr_noise_pct": 0.67,
88
  "lr_noise_std": 1.0,
89
  "mean": null,
 
90
  "min_lr": 0,
91
  "mixup": 0.0,
92
  "mixup_mode": "batch",
93
  "mixup_off_epoch": 0,
94
  "mixup_prob": 1.0,
95
  "mixup_switch_prob": 0.5,
96
- "mlp_hidden_size": 1024,
97
- "model": "fastervit2_large_fullres",
98
- "model_ema": false,
99
- "model_ema_decay": 0.9998,
100
- "model_ema_force_cpu": false,
 
 
 
 
 
 
 
 
 
101
  "model_kwargs": {
102
  "return_full_features": true
103
  },
 
104
  "momentum": 0.9,
105
  "no_aug": false,
106
  "no_ddp_bb": false,
107
  "no_prefetcher": false,
108
  "no_resume_opt": false,
109
- "num_classes": 0,
110
  "opt": "fusedlamb",
111
  "opt_betas": null,
112
  "opt_eps": null,
113
  "opt_kwargs": {},
114
- "output": "/lustre/fs3/portfolios/llmservice/users/gheinrich/results/evfm/19-11-23-fastervit2-l-fullres",
115
  "patience_epochs": 10,
116
  "pin_mem": false,
117
- "prefetcher": false,
118
  "pretrained": false,
119
  "rank": 0,
120
  "ratio": [
@@ -123,11 +155,11 @@
123
  ],
124
  "recount": 1,
125
  "recovery_interval": 0,
 
126
  "remode": "pixel",
127
  "reprob": 0.0,
128
  "resplit": false,
129
- "resume": "/lustre/fs3/portfolios/llmservice/users/gheinrich/results/evfm/19-11-23-fastervit2-l-fullres/checkpoints/last.pth.tar",
130
- "return_full_features": true,
131
  "save_images": false,
132
  "scale": [
133
  0.5,
@@ -137,28 +169,96 @@
137
  "sched_on_updates": true,
138
  "seed": 42,
139
  "smoothing": 0.1,
 
140
  "split_bn": false,
141
  "start_epoch": null,
142
  "std": null,
143
  "steps_per_epoch": 2000,
144
- "sync_bn": false,
145
- "synchronize_step": false,
146
  "teachers": [
147
  {
 
 
148
  "batch_size": 32,
149
- "config": "open_clip_vit-h-14_res224.yaml",
 
 
 
 
 
 
 
 
 
150
  "fd_loss_weight": 1.0,
 
151
  "feature_distillation": true,
152
- "sample_rate": 8,
153
- "summary_loss_weight": 1.0
 
 
 
 
 
 
 
154
  },
155
  {
 
 
156
  "batch_size": 32,
157
- "config": "dinov2_vit-g-14_res224.yaml",
158
- "fd_loss_weight": 4.0,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
159
  "feature_distillation": true,
160
- "sample_rate": 8,
161
- "summary_loss_weight": 1.0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
162
  }
163
  ],
164
  "torchcompile": null,
@@ -169,30 +269,37 @@
169
  "use_coco": false,
170
  "use_multi_epochs_loader": false,
171
  "val_data_dir": "/lustre/fsw/portfolios/llmservice/projects/llmservice_nlp_fm/datasets/classification/imagenet-1k/webdataset",
172
- "val_img_size": 224,
 
 
173
  "val_split": "val",
174
- "validation_batch_size": 32,
175
  "vflip": 0.0,
176
  "wandb_entity": "",
177
- "wandb_group": "backbones",
178
  "wandb_job_type": "",
179
  "wandb_name": "",
180
  "wandb_project": "",
181
- "warmup_epochs": 2.5,
182
  "warmup_lr": 1e-05,
183
  "warmup_prefix": false,
184
- "weight_decay": 2e-05,
185
  "worker_seeding": "all",
186
- "workers": 4,
187
- "world_size": 32
188
  },
189
  "auto_map": {
190
- "AutoConfig": "hf_model.ERADIOConfig",
191
- "AutoModel": "hf_model.ERADIOModel"
192
  },
193
- "return_spatial_features": true,
194
- "return_summary": true,
 
 
 
 
195
  "torch_dtype": "float32",
196
- "transformers_version": "4.29.0",
197
- "version": "v1"
 
198
  }
 
1
  {
2
+ "adaptor_names": null,
3
  "architectures": [
4
+ "RADIOModel"
5
  ],
6
  "args": {
7
  "aa": null,
 
10
  "amp_impl": "native",
11
  "aug_repeats": 0,
12
  "aug_splits": 0,
13
+ "auto_loss_balance_mode": "manual",
14
  "batch_size": 32,
15
  "bn_eps": null,
16
  "bn_momentum": null,
 
17
  "cache_dir": null,
18
  "channels_last": false,
19
+ "checkpoint_hist": 10,
20
+ "chk_keep_forever": 10,
21
  "class_map": "",
22
  "clip_grad": null,
23
  "clip_mode": "norm",
24
+ "cls_token_per_teacher": false,
25
  "coco_annotations_file": "/datasets/coco2017-adlsa/annotations/captions_val2017.json",
26
  "coco_image_dir": "/datasets/coco2017-adlsa/val2017",
27
  "color_jitter": 0.4,
 
32
  "crop_pct": null,
33
  "cutmix": 0.0,
34
  "cutmix_minmax": null,
35
+ "data_dir": [
36
+ [
37
+ "/lustre/fsw/portfolios/llmservice/projects/llmservice_nlp_fm/datasets/captioning/datacomp/dc1b/stage2",
38
+ 0.95
39
+ ],
40
+ [
41
+ "/lustre/fsw/portfolios/llmservice/projects/llmservice_nlp_fm/datasets/segmentation/sam/stage1",
42
+ 0.05
43
+ ]
44
+ ],
45
  "dataset": "nvgpt4",
46
  "dataset_download": false,
47
+ "ddp_comm_fp16": false,
48
+ "ddp_comm_power_sgd": false,
49
  "debug_full_knn": false,
50
  "decay_epochs": 90,
51
  "decay_milestones": [
 
55
  ],
56
  "decay_rate": 0.1,
57
  "device": "cuda:0",
58
+ "dist_bn": "",
59
  "distributed": true,
60
  "drop": 0.0,
61
  "drop_block": null,
62
  "drop_connect": null,
63
  "drop_path": null,
64
+ "dtype": "float32",
65
  "epoch_repeats": 0.0,
66
+ "epochs": 50,
67
+ "eval": false,
68
  "eval_metric": "knn_top1",
69
  "eval_teacher": false,
70
  "eval_teacher_only": false,
 
73
  "fast_norm": false,
74
  "feature_summarizer": "cls_token",
75
  "feature_upscale_factor": null,
76
+ "force_new_wandb_id": false,
77
+ "force_spectral_reparam": false,
78
+ "freeze_bn": false,
79
  "fuser": "",
80
  "gp": null,
81
  "grad_accum_steps": 1,
82
  "grad_checkpointing": false,
83
  "head_init_bias": null,
84
  "head_init_scale": null,
85
+ "head_warmup": 10,
86
  "hflip": 0.5,
87
  "img_size": null,
88
+ "in_chans": null,
89
  "initial_checkpoint": "",
90
  "input_size": null,
91
  "interpolation": "",
 
95
  "log_mlflow": false,
96
  "log_wandb": true,
97
  "loss": "cosine",
98
+ "loss_auto_balance": false,
99
  "lr": 0.001,
100
  "lr_base": 0.1,
101
  "lr_base_scale": "",
 
108
  "lr_noise_pct": 0.67,
109
  "lr_noise_std": 1.0,
110
  "mean": null,
111
+ "mesa": false,
112
  "min_lr": 0,
113
  "mixup": 0.0,
114
  "mixup_mode": "batch",
115
  "mixup_off_epoch": 0,
116
  "mixup_prob": 1.0,
117
  "mixup_switch_prob": 0.5,
118
+ "mlp_hidden_size": 1520,
119
+ "mlp_num_inner": 1,
120
+ "mlp_version": "v2",
121
+ "model": "eradio",
122
+ "model_ema": {
123
+ "decay": 0.9998,
124
+ "force_cpu": false,
125
+ "power": false,
126
+ "power_stds": [
127
+ 0.05,
128
+ 0.1
129
+ ],
130
+ "start_epoch": 2
131
+ },
132
  "model_kwargs": {
133
  "return_full_features": true
134
  },
135
+ "model_norm": false,
136
  "momentum": 0.9,
137
  "no_aug": false,
138
  "no_ddp_bb": false,
139
  "no_prefetcher": false,
140
  "no_resume_opt": false,
141
+ "num_classes": null,
142
  "opt": "fusedlamb",
143
  "opt_betas": null,
144
  "opt_eps": null,
145
  "opt_kwargs": {},
146
+ "output": "/lustre/fs6/portfolios/llmservice/users/mranzinger/output/evfm/eradio/n8_3-25-24_eradio_stage3-alt_s2ep77",
147
  "patience_epochs": 10,
148
  "pin_mem": false,
149
+ "prefetcher": true,
150
  "pretrained": false,
151
  "rank": 0,
152
  "ratio": [
 
155
  ],
156
  "recount": 1,
157
  "recovery_interval": 0,
158
+ "register_multiple": 0,
159
  "remode": "pixel",
160
  "reprob": 0.0,
161
  "resplit": false,
162
+ "resume": "/lustre/fs6/portfolios/llmservice/users/mranzinger/output/evfm/eradio/n8_3-25-24_eradio_stage3-alt_s2ep77/checkpoints/last.pth.tar",
 
163
  "save_images": false,
164
  "scale": [
165
  0.5,
 
169
  "sched_on_updates": true,
170
  "seed": 42,
171
  "smoothing": 0.1,
172
+ "spectral_reparam": false,
173
  "split_bn": false,
174
  "start_epoch": null,
175
  "std": null,
176
  "steps_per_epoch": 2000,
177
+ "sync_bn": true,
178
+ "synchronize_step": true,
179
  "teachers": [
180
  {
181
+ "amp": true,
182
+ "amp_dtype": "bfloat16",
183
  "batch_size": 32,
184
+ "data_dir": [
185
+ [
186
+ "/lustre/fsw/portfolios/llmservice/projects/llmservice_nlp_fm/datasets/captioning/datacomp/dc1b/stage2",
187
+ 0.95
188
+ ],
189
+ [
190
+ "/lustre/fsw/portfolios/llmservice/projects/llmservice_nlp_fm/datasets/segmentation/sam/stage1",
191
+ 0.05
192
+ ]
193
+ ],
194
  "fd_loss_weight": 1.0,
195
+ "fd_normalize": false,
196
  "feature_distillation": true,
197
+ "input_size": 378,
198
+ "model": "ViT-H-14-378-quickgelu",
199
+ "name": "clip",
200
+ "pretrained": "dfn5b",
201
+ "sample_rate": 32,
202
+ "student_resolution": 512,
203
+ "summary_loss_weight": 1.0,
204
+ "torchcompile": true,
205
+ "type": "open_clip"
206
  },
207
  {
208
+ "amp": true,
209
+ "amp_dtype": "bfloat16",
210
  "batch_size": 32,
211
+ "fd_loss_weight": 1.5,
212
+ "fd_normalize": false,
213
+ "feature_distillation": true,
214
+ "input_size": 224,
215
+ "model": "dinov2_vitg14_reg",
216
+ "name": "dino_v2",
217
+ "sample_rate": 32,
218
+ "summary_loss_weight": 1.0,
219
+ "torchcompile": true,
220
+ "type": "dino_v2"
221
+ },
222
+ {
223
+ "amp": true,
224
+ "amp_dtype": "bfloat16",
225
+ "batch_size": 2,
226
+ "fd_loss_fn": "MSE",
227
+ "fd_loss_weight": 0.13,
228
+ "fd_normalize": false,
229
  "feature_distillation": true,
230
+ "input_size": 448,
231
+ "model": "dinov2_vitl14_reg",
232
+ "name": "dino_v2_large",
233
+ "sample_rate": 2,
234
+ "student_resolution": 1024,
235
+ "summary_loss_weight": 1e-05,
236
+ "torchcompile": true,
237
+ "type": "dino_v2",
238
+ "use_summary": true
239
+ },
240
+ {
241
+ "amp": false,
242
+ "batch_size": 2,
243
+ "data_dir": [
244
+ [
245
+ "/lustre/fsw/portfolios/llmservice/projects/llmservice_nlp_fm/datasets/segmentation/sam/stage1",
246
+ 0.4
247
+ ]
248
+ ],
249
+ "fd_loss_fn": "MSE",
250
+ "fd_loss_weight": 0.13,
251
+ "fd_normalize": false,
252
+ "fd_ohem": false,
253
+ "feature_distillation": true,
254
+ "input_size": 1024,
255
+ "model": "vit-h",
256
+ "name": "sam",
257
+ "sample_rate": 2,
258
+ "student_resolution": 1024,
259
+ "summary_loss_weight": 1e-05,
260
+ "type": "sam",
261
+ "use_summary": false
262
  }
263
  ],
264
  "torchcompile": null,
 
269
  "use_coco": false,
270
  "use_multi_epochs_loader": false,
271
  "val_data_dir": "/lustre/fsw/portfolios/llmservice/projects/llmservice_nlp_fm/datasets/classification/imagenet-1k/webdataset",
272
+ "val_ema_only": false,
273
+ "val_img_size": 512,
274
+ "val_jobs_script": "run_validation_jobs_eradio.sh",
275
  "val_split": "val",
276
+ "validation_batch_size": 128,
277
  "vflip": 0.0,
278
  "wandb_entity": "",
279
+ "wandb_group": "eradio",
280
  "wandb_job_type": "",
281
  "wandb_name": "",
282
  "wandb_project": "",
283
+ "warmup_epochs": 0.001,
284
  "warmup_lr": 1e-05,
285
  "warmup_prefix": false,
286
+ "weight_decay": 0.0002,
287
  "worker_seeding": "all",
288
+ "workers": 10,
289
+ "world_size": 64
290
  },
291
  "auto_map": {
292
+ "AutoConfig": "hf_model.RADIOConfig",
293
+ "AutoModel": "hf_model.RADIOModel"
294
  },
295
+ "max_resolution": 2048,
296
+ "patch_size": 16,
297
+ "preferred_resolution": [
298
+ 512,
299
+ 512
300
+ ],
301
  "torch_dtype": "float32",
302
+ "transformers_version": "4.37.2",
303
+ "version": "e-radio_v2",
304
+ "vitdet_window_size": null
305
  }
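With the updated auto_map (hf_model.RADIOConfig / hf_model.RADIOModel), the model can be pulled in through transformers. A hedged sketch, adapted from the usage notes removed from eradio_model.py below; the repo id "nvidia/E-RADIO" and the exact output unpacking are assumptions:

    import torch
    from transformers import AutoModel

    model = AutoModel.from_pretrained("nvidia/E-RADIO", trust_remote_code=True)
    model.eval().cuda()

    # preferred_resolution is 512x512 and patch_size is 16 for version "e-radio_v2"
    x = torch.rand(1, 3, 512, 512, device="cuda")
    summary, features = model(x)   # summary embedding and spatial features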
eradio_model.py CHANGED
@@ -12,9 +12,11 @@
12
  # Mike Ranzinger, Greg Heinrich, Jan Kautz, and Pavlo Molchanov. "AM-RADIO: Agglomerative Model--Reduce All Domains Into One." arXiv preprint arXiv:2312.06709 (2023).
13
 
14
  # based on FasterViT, Swin Transformer, YOLOv8
 
15
  # FasterViT:
16
  # Ali Hatamizadeh, Greg Heinrich, Hongxu Yin, Andrew Tao, Jose M. Alvarez, Jan Kautz, and Pavlo Molchanov. "FasterViT: Fast Vision Transformers with Hierarchical Attention." arXiv preprint arXiv:2306.06189 (2023).
17
 
 
18
  import torch
19
  import torch.nn as nn
20
  from timm.models.registry import register_model
@@ -22,10 +24,9 @@ from timm.models.registry import register_model
22
  from timm.models.layers import trunc_normal_, DropPath, LayerNorm2d
23
  import numpy as np
24
  import torch.nn.functional as F
 
25
  import warnings
26
 
27
- SIMPLER_UP_TOWER = False
28
-
29
  #######################
30
  ## Codebase from YOLOv8
31
  ## BEGINNING
@@ -96,16 +97,17 @@ class Conv(nn.Module):
96
  @torch.no_grad()
97
  def switch_to_deploy(self):
98
  # return 1
99
- c, bn = self.conv, self.bn
100
- w = bn.weight / (bn.running_var + bn.eps) ** 0.5
101
- w = c.weight * w[:, None, None, None]
102
- b = bn.bias - bn.running_mean * bn.weight / \
103
- (bn.running_var + bn.eps)**0.5
 
104
 
105
- self.conv.weight.data.copy_(w)
106
- self.conv.bias = nn.Parameter(b)
107
 
108
- self.bn = nn.Identity()
109
 
110
  def autopad(k, p=None, d=1): # kernel, padding, dilation
111
  """Pad to 'same' shape outputs."""
@@ -121,16 +123,10 @@ def autopad(k, p=None, d=1): # kernel, padding, dilation
121
  ## END
122
  #######################
123
 
124
-
125
  def pixel_unshuffle(data, factor=2):
126
  # performs nn.PixelShuffle(factor) in reverse, torch has some bug for ONNX and TRT, so doing it manually
127
  B, C, H, W = data.shape
128
- return (
129
- data.view(B, C, factor, H // factor, factor, W // factor)
130
- .permute(0, 1, 2, 4, 3, 5)
131
- .reshape(B, -1, H // factor, W // factor)
132
- )
133
-
134
 
135
  class SwiGLU(nn.Module):
136
  # should be more advanced, but doesnt improve results so far
@@ -141,6 +137,7 @@ class SwiGLU(nn.Module):
141
 
142
  def window_partition(x, window_size):
143
  """
 
144
  Args:
145
  x: (B, C, H, W)
146
  window_size: window size
@@ -150,50 +147,35 @@ def window_partition(x, window_size):
150
  """
151
  B, C, H, W = x.shape
152
 
153
- if window_size == 0 or (window_size == H and window_size == W):
154
  windows = x.flatten(2).transpose(1, 2)
155
  Hp, Wp = H, W
156
  else:
157
  pad_h = (window_size - H % window_size) % window_size
158
  pad_w = (window_size - W % window_size) % window_size
159
- #interpolate features
160
  if pad_h > 0 or pad_w > 0:
161
- x = F.pad(x, (0, pad_w, 0, pad_h, 0, 0, 0, 0))
162
  Hp, Wp = H + pad_h, W + pad_w
163
 
164
  x = x.view(B, C, Hp // window_size, window_size, Wp // window_size, window_size)
165
- windows = x.permute(0, 2, 4, 3, 5, 1).reshape(-1, window_size * window_size, C)
166
 
167
  return windows, (Hp, Wp)
168
 
169
-
170
  class Conv2d_BN(nn.Module):
171
- """
172
  Conv2d + BN layer with folding capability to speed up inference
173
- """
174
-
175
- def __init__(
176
- self,
177
- a,
178
- b,
179
- kernel_size=1,
180
- stride=1,
181
- padding=0,
182
- dilation=1,
183
- groups=1,
184
- bn_weight_init=1,
185
- bias=False,
186
- ):
187
  super().__init__()
188
- self.conv = torch.nn.Conv2d(
189
- a, b, kernel_size, stride, padding, dilation, groups, bias=False
190
- )
191
  if 1:
192
  self.bn = torch.nn.BatchNorm2d(b)
193
  torch.nn.init.constant_(self.bn.weight, bn_weight_init)
194
  torch.nn.init.constant_(self.bn.bias, 0)
195
 
196
- def forward(self, x):
197
  x = self.conv(x)
198
  x = self.bn(x)
199
  return x
@@ -204,14 +186,17 @@ class Conv2d_BN(nn.Module):
204
  c, bn = self.conv, self.bn
205
  w = bn.weight / (bn.running_var + bn.eps) ** 0.5
206
  w = c.weight * w[:, None, None, None]
207
- b = bn.bias - bn.running_mean * bn.weight / (bn.running_var + bn.eps) ** 0.5
 
208
  self.conv.weight.data.copy_(w)
209
  self.conv.bias = nn.Parameter(b)
210
  self.bn = nn.Identity()
211
 
212
 
 
213
  def window_reverse(windows, window_size, H, W, pad_hw):
214
  """
 
215
  Args:
216
  windows: local window features (num_windows*B, window_size, window_size, C)
217
  window_size: Window size
@@ -224,22 +209,21 @@ def window_reverse(windows, window_size, H, W, pad_hw):
224
  """
225
  # print(f"window_reverse, windows.shape {windows.shape}")
226
  Hp, Wp = pad_hw
227
- if window_size == 0 or (window_size == H and window_size == W):
228
  B = int(windows.shape[0] / (Hp * Wp / window_size / window_size))
229
  x = windows.transpose(1, 2).view(B, -1, H, W)
230
  else:
231
  B = int(windows.shape[0] / (Hp * Wp / window_size / window_size))
232
- x = windows.view(
233
- B, Hp // window_size, Wp // window_size, window_size, window_size, -1
234
- )
235
- x = x.permute(0, 5, 1, 3, 2, 4).reshape(B, windows.shape[2], Hp, Wp)
236
 
237
  if Hp > H or Wp > W:
238
- x = x[:, :, :H, :W,].contiguous()
239
 
240
  return x
241
 
242
 
 
243
  class PosEmbMLPSwinv2D(nn.Module):
244
  """
245
  2D positional embedding from Swin Transformer v2
@@ -276,7 +260,6 @@ class PosEmbMLPSwinv2D(nn.Module):
276
 
277
  def relative_bias_initialization(self, window_size, num_heads, pretrained_window_size, seq_length, no_log):
278
  # as in separate function to support window size chage after model weights loading
279
-
280
  relative_coords_h = torch.arange(
281
  -(window_size[0] - 1), window_size[0], dtype=torch.float32
282
  )
@@ -349,7 +332,7 @@ class PosEmbMLPSwinv2D(nn.Module):
349
  self.relative_bias = relative_bias.to(self.relative_bias.device)
350
 
351
  if self.deploy and self.grid_exists:
352
- input_tensor += self.relative_bias
353
  return input_tensor
354
 
355
  if 1:
@@ -373,38 +356,39 @@ class PosEmbMLPSwinv2D(nn.Module):
373
 
374
  self.relative_bias = relative_position_bias.unsqueeze(0)
375
 
376
- input_tensor += self.relative_bias
377
  return input_tensor
378
 
379
 
380
  class GRAAttentionBlock(nn.Module):
381
- def __init__(
382
- self,
383
- window_size,
384
- dim_in,
385
- dim_out,
386
- num_heads,
387
- drop_path=0.0,
388
- qk_scale=None,
389
- qkv_bias=False,
390
- norm_layer=nn.LayerNorm,
391
- layer_scale=None,
392
- use_swiglu=True,
393
- subsample_ratio=1,
394
- dim_ratio=1,
395
- conv_base=False,
396
- do_windowing=True,
397
- multi_query=False,
398
- cpb_mlp_hidden=512,
399
- ) -> None:
400
  super().__init__()
401
 
 
402
 
403
  self.do_windowing = do_windowing
 
 
 
404
 
405
  if do_windowing:
406
  if conv_base:
407
  self.downsample_op = nn.Conv2d(dim_in, dim_out, kernel_size=subsample_ratio, stride=subsample_ratio) if subsample_ratio > 1 else nn.Identity()
 
 
408
  self.downsample_mixer = nn.Identity()
409
  self.upsample_mixer = nn.Identity()
410
  self.upsample_op = nn.ConvTranspose2d(dim_in, dim_out, kernel_size=subsample_ratio, stride=subsample_ratio) if subsample_ratio > 1 else nn.Identity()
@@ -414,6 +398,20 @@ class GRAAttentionBlock(nn.Module):
414
  self.upsample_mixer = nn.Upsample(scale_factor=subsample_ratio, mode='nearest') if subsample_ratio > 1 else nn.Identity()
415
  self.upsample_op = Conv2d_BN(dim_in, dim_out, kernel_size=1, stride=1, padding=0, bias=False) if subsample_ratio > 1 else nn.Identity()
416
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
417
  self.window_size = window_size
418
 
419
  self.norm1 = norm_layer(dim_in)
@@ -423,7 +421,7 @@ class GRAAttentionBlock(nn.Module):
423
  num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale,
424
  resolution=window_size,
425
  seq_length=window_size**2, dim_out=dim_in, multi_query=multi_query,
426
- cpb_mlp_hidden=cpb_mlp_hidden)
427
 
428
  self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
429
 
@@ -446,83 +444,103 @@ class GRAAttentionBlock(nn.Module):
446
 
447
  def forward(self, x):
448
  skip_connection = x
 
 
 
 
 
 
449
 
450
  if self.do_windowing:
451
  # performing windowing if required
452
  x = self.downsample_op(x)
453
  x = self.downsample_mixer(x)
454
 
455
- if self.window_size > 0:
456
  H, W = x.shape[2], x.shape[3]
457
 
 
 
 
 
458
  x, pad_hw = window_partition(x, self.window_size)
 
 
 
460
  # window attention
461
- x = x + self.drop_path1(self.gamma1 * self.attn(self.norm1(x)))
462
  # mlp layer
463
- x = x + self.drop_path2(self.gamma2 * self.mlp(self.norm2(x)))
464
 
465
  if self.do_windowing:
466
  if self.window_size > 0:
467
  x = window_reverse(x, self.window_size, H, W, pad_hw)
468
 
 
 
 
 
 
469
  x = self.upsample_mixer(x)
470
  x = self.upsample_op(x)
471
 
472
- if (
473
- x.shape[2] != skip_connection.shape[2]
474
- or x.shape[3] != skip_connection.shape[3]
475
- ):
476
- x = torch.nn.functional.pad(
477
- x,
478
- (
479
- 0,
480
- -x.shape[3] + skip_connection.shape[3],
481
- 0,
482
- -x.shape[2] + skip_connection.shape[2],
483
- ),
484
- )
485
  # need to add skip connection because downsampling and upsampling will break residual connection
486
  # 0.5 is needed to make sure that the skip connection is not too strong
487
  # in case of no downsample / upsample we can show that 0.5 compensates for the residual connection
488
  x = 0.5 * x + 0.5 * skip_connection
489
-
490
  return x
491
 
492
 
 
 
493
  class MultiResolutionAttention(nn.Module):
494
  """
495
  MultiResolutionAttention (MRA) module
496
  The idea is to use multiple attention blocks with different resolution
497
  Feature maps are downsampled / upsampled for each attention block on different blocks
498
- Every attention block supports
499
-
500
  """
501
 
502
- def __init__(
503
- self,
504
- window_size,
505
- sr_ratio,
506
- dim,
507
- dim_ratio,
508
- num_heads,
509
- do_windowing=True,
510
- layer_scale=1e-5,
511
- norm_layer=nn.LayerNorm,
512
- drop_path=0,
513
- qkv_bias=False,
514
- qk_scale=1.0,
515
- use_swiglu=True,
516
- multi_query=False,
517
- conv_base=False,
518
- cpb_mlp_hidden=512
519
- ) -> None:
520
  """
521
  Args:
522
  input_resolution: input image resolution
523
  window_size: window size
524
  compression_ratio: compression ratio
525
  max_depth: maximum depth of the GRA module
 
526
  """
527
  super().__init__()
528
 
@@ -530,6 +548,7 @@ class MultiResolutionAttention(nn.Module):
530
 
531
  self.attention_blocks = nn.ModuleList()
532
 
 
533
  for i in range(depth):
534
  subsample_ratio = sr_ratio[i]
535
  if len(window_size) > i:
@@ -537,26 +556,14 @@ class MultiResolutionAttention(nn.Module):
537
  else:
538
  window_size_local = window_size[0]
539
 
540
- self.attention_blocks.append(
541
- GRAAttentionBlock(
542
- window_size=window_size_local,
543
- dim_in=dim,
544
- dim_out=dim,
545
- num_heads=num_heads,
546
- qkv_bias=qkv_bias,
547
- qk_scale=qk_scale,
548
- norm_layer=norm_layer,
549
- layer_scale=layer_scale,
550
- drop_path=drop_path,
551
- use_swiglu=use_swiglu,
552
- subsample_ratio=subsample_ratio,
553
- dim_ratio=dim_ratio,
554
- do_windowing=do_windowing,
555
- multi_query=multi_query,
556
- conv_base=conv_base,
557
- cpb_mlp_hidden=cpb_mlp_hidden
558
- ),
559
- )
560
 
561
  def forward(self, x):
562
 
@@ -566,20 +573,19 @@ class MultiResolutionAttention(nn.Module):
566
  return x
567
 
568
 
 
569
  class Mlp(nn.Module):
570
  """
571
  Multi-Layer Perceptron (MLP) block
572
  """
573
 
574
- def __init__(
575
- self,
576
- in_features,
577
- hidden_features=None,
578
- out_features=None,
579
- act_layer=nn.GELU,
580
- use_swiglu=True,
581
- drop=0.0,
582
- ):
583
  """
584
  Args:
585
  in_features: input features dimension.
@@ -592,9 +598,7 @@ class Mlp(nn.Module):
592
  super().__init__()
593
  out_features = out_features or in_features
594
  hidden_features = hidden_features or in_features
595
- self.fc1 = nn.Linear(
596
- in_features, hidden_features * (2 if use_swiglu else 1), bias=False
597
- )
598
  self.act = act_layer()
599
  self.fc2 = nn.Linear(hidden_features, out_features, bias=False)
600
 
@@ -607,21 +611,20 @@ class Mlp(nn.Module):
607
  x = x.view(x_size)
608
  return x
609
 
610
-
611
  class Downsample(nn.Module):
612
  """
613
  Down-sampling block
614
-
615
  Pixel Unshuffle is used for down-sampling, works great accuracy - wise but takes 10% more TRT time
616
  """
617
 
618
- def __init__(
619
- self, dim, shuffle=False,
620
- ):
 
621
  """
622
  Args:
623
  dim: feature size dimension.
624
- shuffle: idea with pixel unshuffling instead for resizing
625
  keep_dim: bool argument for maintaining the resolution.
626
  """
627
 
@@ -630,11 +633,16 @@ class Downsample(nn.Module):
630
 
631
  if shuffle:
632
  self.norm = lambda x: pixel_unshuffle(x, factor=2)
633
- self.reduction = Conv2d_BN(dim * 4, dim_out, 1, 1, 0, bias=False)
 
634
  else:
 
 
 
635
  self.norm = nn.Identity()
636
  self.reduction = Conv2d_BN(dim, dim_out, 3, 2, 1, bias=False)
637
 
 
638
  def forward(self, x):
639
  x = self.norm(x)
640
  x = self.reduction(x)
@@ -645,7 +653,6 @@ class PatchEmbed(nn.Module):
645
  """
646
  Patch embedding block
647
  Used to convert image into an initial set of feature maps with lower resolution
648
-
649
  """
650
 
651
  def __init__(self, in_chans=3, in_dim=64, dim=96, shuffle_down=False):
@@ -665,13 +672,13 @@ class PatchEmbed(nn.Module):
665
  Conv2d_BN(in_chans, in_dim, 3, 2, 1, bias=False),
666
  nn.ReLU(),
667
  Conv2d_BN(in_dim, dim, 3, 2, 1, bias=False),
668
- nn.ReLU(),
669
- )
670
  else:
671
  self.proj = lambda x: pixel_unshuffle(x, factor=4)
672
- self.conv_down = nn.Sequential(
673
- Conv2d_BN(in_chans * 16, dim, 3, 1, 1), nn.ReLU(),
674
- )
675
 
676
  def forward(self, x):
677
  x = self.proj(x)
@@ -679,6 +686,7 @@ class PatchEmbed(nn.Module):
679
  return x
680
 
681
 
 
682
  class ConvBlock(nn.Module):
683
  """
684
  Convolutional block, used in first couple of stages
@@ -722,22 +730,12 @@ class WindowAttention(nn.Module):
722
  # Windowed Attention from SwinV2
723
  # use a MLP trick to deal with various input image resolutions, then fold it to improve speed
724
 
725
- def __init__(
726
- self,
727
- dim,
728
- num_heads=8,
729
- qkv_bias=False,
730
- qk_scale=None,
731
- resolution=0,
732
- seq_length=0,
733
- dim_out=None,
734
- multi_query=False,
735
- cpb_mlp_hidden=512,
736
- ):
737
  # taken from EdgeViT and tweaked with attention bias.
738
  super().__init__()
739
- if not dim_out:
740
- dim_out = dim
741
  self.multi_query = multi_query
742
  self.num_heads = num_heads
743
  head_dim = dim // num_heads
@@ -749,39 +747,29 @@ class WindowAttention(nn.Module):
749
  if not multi_query:
750
  self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
751
  else:
752
- self.qkv = nn.Linear(dim, dim + 2 * self.head_dim, bias=qkv_bias)
753
 
754
  self.proj = nn.Linear(dim, dim_out, bias=False)
755
  # attention positional bias
756
- self.pos_emb_funct = PosEmbMLPSwinv2D(
757
- window_size=[resolution, resolution],
758
- pretrained_window_size=[resolution, resolution],
759
- num_heads=num_heads,
760
- seq_length=seq_length,
761
- cpb_mlp_hidden=cpb_mlp_hidden,
762
- )
763
 
764
  self.resolution = resolution
765
 
766
- def forward(self, x):
767
  B, N, C = x.shape
768
 
769
  if not self.multi_query:
770
- qkv = (
771
- self.qkv(x)
772
- .reshape(B, -1, 3, self.num_heads, C // self.num_heads)
773
- .permute(2, 0, 3, 1, 4)
774
- )
775
  q, k, v = qkv[0], qkv[1], qkv[2]
776
  else:
777
  qkv = self.qkv(x)
778
- (q, k, v) = qkv.split(
779
- [self.dim_internal, self.head_dim, self.head_dim], dim=2
780
- )
781
 
782
- q = q.reshape(B, -1, self.num_heads, C // self.num_heads).permute(
783
- 0, 2, 1, 3
784
- )
785
  k = k.reshape(B, -1, 1, C // self.num_heads).permute(0, 2, 1, 3)
786
  v = v.reshape(B, -1, 1, C // self.num_heads).permute(0, 2, 1, 3)
787
 
@@ -789,40 +777,50 @@ class WindowAttention(nn.Module):
789
 
790
  attn = self.pos_emb_funct(attn)
791
 
 
 
 
 
 
 
792
  attn = attn.softmax(dim=-1)
793
  x = (attn @ v).transpose(1, 2).reshape(B, -1, C)
794
  x = self.proj(x)
795
  return x
796
 
797
 
 
798
  class FasterViTLayer(nn.Module):
799
  """
800
  fastervitlayer
801
  """
802
 
803
- def __init__(
804
- self,
805
- dim,
806
- depth,
807
- num_heads,
808
- window_size,
809
- conv=False,
810
- downsample=True,
811
- mlp_ratio=4.0,
812
- qkv_bias=False,
813
- qk_scale=None,
814
- norm_layer=nn.LayerNorm,
815
- drop_path=0.0,
816
- layer_scale=None,
817
- layer_scale_conv=None,
818
- sr_dim_ratio=1,
819
- sr_ratio=1,
820
- multi_query=False,
821
- use_swiglu=True,
822
- yolo_arch=False,
823
- downsample_shuffle=False,
824
- conv_base=False,
825
- cpb_mlp_hidden=512,
 
 
 
826
  ):
827
  """
828
  Args:
@@ -840,75 +838,68 @@ class FasterViTLayer(nn.Module):
840
  drop_path: drop path rate.
841
  norm_layer: normalization layer.
842
  layer_scale: layer scaling coefficient.
 
 
843
  """
844
 
845
  super().__init__()
846
  self.conv = conv
847
- self.yolo_arch = False
 
848
  if conv:
849
  if not yolo_arch:
850
- self.blocks = nn.ModuleList(
851
- [
852
- ConvBlock(
853
- dim=dim,
854
- drop_path=drop_path[i]
855
- if isinstance(drop_path, list)
856
- else drop_path,
857
- layer_scale=layer_scale_conv )
858
- for i in range(depth)
859
- ]
860
- )
861
  self.blocks = nn.Sequential(*self.blocks)
862
  else:
863
- self.blocks = C2f(dim, dim, n=depth, shortcut=True, e=0.5)
864
- self.yolo_arch = True
865
  else:
866
- if not isinstance(window_size, list):
867
- window_size = [window_size]
868
  self.window_size = window_size[0]
869
  self.do_single_windowing = True
870
- if not isinstance(sr_ratio, list):
871
- sr_ratio = [sr_ratio]
872
  self.sr_ratio = sr_ratio
873
- if any([sr != 1 for sr in sr_ratio]) or len(set(window_size)) > 1:
874
  self.do_single_windowing = False
875
  do_windowing = True
876
  else:
877
  self.do_single_windowing = True
878
  do_windowing = False
879
 
 
 
 
 
 
880
  self.blocks = nn.ModuleList()
881
  for i in range(depth):
882
-
883
  self.blocks.append(
884
- MultiResolutionAttention(
885
- window_size=window_size,
886
- sr_ratio=sr_ratio,
887
- dim=dim,
888
- dim_ratio=sr_dim_ratio,
889
- num_heads=num_heads,
890
- norm_layer=norm_layer,
891
- drop_path=drop_path[i]
892
- if isinstance(drop_path, list)
893
- else drop_path,
894
- layer_scale=layer_scale,
895
- qkv_bias=qkv_bias,
896
- qk_scale=qk_scale,
897
- use_swiglu=use_swiglu,
898
- do_windowing=do_windowing,
899
- multi_query=multi_query,
900
- conv_base=conv_base,
901
- cpb_mlp_hidden=cpb_mlp_hidden,
902
- )
903
- )
904
-
905
  self.blocks = nn.Sequential(*self.blocks)
906
 
907
  self.transformer = not conv
908
-
909
- self.downsample = (
910
- None if not downsample else Downsample(dim=dim, shuffle=downsample_shuffle)
911
- )
912
 
913
 
914
  def forward(self, x):
@@ -931,19 +922,16 @@ class FasterViTLayer(nn.Module):
931
  new_h = int(np.ceil(H/max_window_size)*max_window_size)
932
  new_w = int(np.ceil(W/max_window_size)*max_window_size)
933
  x = F.interpolate(x, size=(new_h, new_w), mode='nearest')
934
- warnings.warn(f"Choosen window size is not optimal for given resolution. Interpolation of features maps will be done and it can affect the performance. Max window size is {max_window_size}, feature map size is {H}x{W}, interpolated feature map size is {new_h}x{new_w}.")
 
935
 
936
 
937
  if self.transformer and self.do_single_windowing:
938
  H, W = x.shape[2], x.shape[3]
939
  x, pad_hw = window_partition(x, self.window_size)
940
 
 
941
  x = self.blocks(x)
942
- # if not self.yolo_arch:
943
- # for bn, blk in enumerate(self.blocks):
944
- # x = blk(x)
945
- # else:
946
- # x = self.blocks(x)
947
 
948
  if self.transformer and self.do_single_windowing:
949
  x = window_reverse(x, self.window_size, H, W, pad_hw)
@@ -958,12 +946,23 @@ class FasterViTLayer(nn.Module):
958
  return self.downsample(x), x # changing to output pre downsampled features
959
 
960
 
 
 
 
 
 
 
 
 
 
 
 
961
  class HiResNeck(nn.Module):
962
  """
963
  The block is used to output dense features from all stages
964
  Otherwise, by default, only the last stage features are returned with FasterViTv2
965
  """
966
- def __init__(self, dim, depths, neck_start_stage, full_features_head_dim):
967
 
968
  '''
969
  Hi Resolution neck to support output of high res features that are useful for dense tasks.
@@ -972,6 +971,7 @@ class HiResNeck(nn.Module):
972
  earlier layers result in higher resolution features at the cost of compute
973
  full_features_head_dim - number of channels in the dense features head
974
  '''
 
975
  # create feature projection layers for segmentation output
976
  self.neck_features_proj = nn.ModuleList()
977
  self.neck_start_stage = neck_start_stage
@@ -983,16 +983,24 @@ class HiResNeck(nn.Module):
983
 
984
  if (upsample_ratio > 1) or full_features_head_dim!=level_n_features_output:
985
  feature_projection = nn.Sequential()
986
- feature_projection.add_module("norm",nn.BatchNorm2d(level_n_features_output)) #fast, but worse
987
-
988
- feature_projection.add_module("dconv", nn.ConvTranspose2d(level_n_features_output,
989
- full_features_head_dim, kernel_size=upsample_ratio, stride=upsample_ratio))
 
 
 
 
 
 
 
 
990
  else:
991
  feature_projection = nn.Sequential()
992
 
993
  self.neck_features_proj.append(feature_projection)
994
 
995
- if i>0 and self.levels[i-1].downsample is not None:
996
  upsample_ratio *= 2
997
 
998
  def forward(self, x, il_level=-1, full_features=None):
@@ -1006,49 +1014,48 @@ class HiResNeck(nn.Module):
1006
  feature_projection = self.neck_features_proj[il_level - self.neck_start_stage](x)
1007
  if feature_projection.shape[2] != full_features.shape[2] or feature_projection.shape[3] != full_features.shape[3]:
1008
  feature_projection = torch.nn.functional.pad(feature_projection, ( 0, -feature_projection.shape[3] + full_features.shape[3], 0, -feature_projection.shape[2] + full_features.shape[2]))
1009
- full_features += feature_projection
1010
  return full_features
1011
 
1012
-
1013
-
1014
  class FasterViT(nn.Module):
1015
  """
1016
  FasterViT
1017
  """
1018
 
1019
- def __init__(
1020
- self,
1021
- dim,
1022
- in_dim,
1023
- depths,
1024
- window_size,
1025
- mlp_ratio,
1026
- num_heads,
1027
- drop_path_rate=0.2,
1028
- in_chans=3,
1029
- num_classes=1000,
1030
- qkv_bias=False,
1031
- qk_scale=None,
1032
- layer_scale=None,
1033
- layer_scale_conv=None,
1034
- layer_norm_last=False,
1035
- sr_ratio=[1, 1, 1, 1],
1036
- max_depth=-1,
1037
- conv_base=False,
1038
- use_swiglu=False,
1039
- multi_query=False,
1040
- norm_layer=nn.LayerNorm,
1041
- drop_uniform=False,
1042
- yolo_arch=False,
1043
- shuffle_down=False,
1044
- downsample_shuffle=False,
1045
- return_full_features=False,
1046
- full_features_head_dim=128,
1047
- neck_start_stage=1,
1048
- use_neck=False,
1049
- cpb_mlp_hidden=512,
1050
- **kwargs,
1051
- ):
 
1052
  """
1053
  Args:
1054
  dim: feature size dimension.
@@ -1071,14 +1078,18 @@ class FasterViT(nn.Module):
1071
  for 224 resolution, the output of the stage before downsample:
1072
  stage 0: 56x56, stage 1: 28x28, stage 2: 14x14, stage 3: 7x7
1073
  use_neck: even for summarization embedding use neck
 
 
 
 
 
 
1074
  """
1075
  super().__init__()
1076
 
1077
  num_features = int(dim * 2 ** (len(depths) - 1))
1078
  self.num_classes = num_classes
1079
- self.patch_embed = PatchEmbed(
1080
- in_chans=in_chans, in_dim=in_dim, dim=dim, shuffle_down=shuffle_down
1081
- )
1082
  # set return_full_features true if we want to return full features from all stages
1083
  self.return_full_features = return_full_features
1084
  self.use_neck = use_neck
@@ -1087,110 +1098,53 @@ class FasterViT(nn.Module):
1087
  if drop_uniform:
1088
  dpr = [drop_path_rate for x in range(sum(depths))]
1089
 
1090
- if not isinstance(max_depth, list):
1091
- max_depth = [max_depth] * len(depths)
1092
 
1093
  self.levels = nn.ModuleList()
1094
  for i in range(len(depths)):
1095
  conv = True if (i == 0 or i == 1) else False
1096
 
1097
- level = FasterViTLayer(
1098
- dim=int(dim * 2 ** i),
1099
- depth=depths[i],
1100
- num_heads=num_heads[i],
1101
- window_size=window_size[i],
1102
- mlp_ratio=mlp_ratio,
1103
- qkv_bias=qkv_bias,
1104
- qk_scale=qk_scale,
1105
- conv=conv,
1106
- drop_path=dpr[sum(depths[:i]) : sum(depths[: i + 1])],
1107
- downsample=(i < 3),
1108
- layer_scale=layer_scale,
1109
- layer_scale_conv=layer_scale_conv,
1110
- sr_ratio=sr_ratio[i],
1111
- use_swiglu=use_swiglu,
1112
- multi_query=multi_query,
1113
- norm_layer=norm_layer,
1114
- yolo_arch=yolo_arch,
1115
- downsample_shuffle=downsample_shuffle,
1116
- conv_base=conv_base,
1117
- cpb_mlp_hidden=cpb_mlp_hidden,
1118
-
1119
- )
1120
 
1121
  self.levels.append(level)
1122
 
1123
- if not SIMPLER_UP_TOWER:
1124
- if self.return_full_features or self.use_neck:
1125
- # create feature projection layers for segmentation output
1126
- self.neck_features_proj = nn.ModuleList()
1127
- self.neck_start_stage = neck_start_stage
1128
- upsample_ratio = 1
1129
- for i in range(len(depths)):
1130
- level_n_features_output = int(dim * 2 ** i)
1131
-
1132
- if self.neck_start_stage > i:
1133
- continue
1134
-
1135
- if (
1136
- upsample_ratio > 1
1137
- ) or full_features_head_dim != level_n_features_output:
1138
- feature_projection = nn.Sequential()
1139
- # pixel shuffle based upsampling
1140
- feature_projection.add_module(
1141
- "norm", nn.BatchNorm2d(level_n_features_output)
1142
- ) # fast, but worse
1143
- feature_projection.add_module(
1144
- "conv",
1145
- nn.Conv2d(
1146
- level_n_features_output,
1147
- full_features_head_dim
1148
- * upsample_ratio
1149
- * upsample_ratio,
1150
- kernel_size=1,
1151
- stride=1,
1152
- ),
1153
- )
1154
- feature_projection.add_module(
1155
- "upsample_pixelshuffle", nn.PixelShuffle(upsample_ratio)
1156
- )
1157
- else:
1158
- feature_projection = nn.Sequential()
1159
- feature_projection.add_module(
1160
- "norm", nn.BatchNorm2d(level_n_features_output)
1161
- )
1162
-
1163
- self.neck_features_proj.append(feature_projection)
1164
-
1165
- if i > 0 and self.levels[i - 1].downsample is not None:
1166
- upsample_ratio *= 2
1167
- else:
1168
- if self.return_full_features or self.use_neck:
1169
- self.high_res_neck = HiResNeck(dim, num_heads, depths, neck_start_stage, full_features_head_dim)
1170
 
1171
- num_features = (
1172
- full_features_head_dim
1173
- if (self.return_full_features or self.use_neck)
1174
- else num_features
1175
- )
1176
-
1177
- self.num_features = num_features
1178
 
1179
- self.norm = (
1180
- LayerNorm2d(num_features)
1181
- if layer_norm_last
1182
- else nn.BatchNorm2d(num_features)
1183
- )
1184
  self.avgpool = nn.AdaptiveAvgPool2d(1)
1185
- self.head = (
1186
- nn.Linear(num_features, num_classes) if num_classes > 0 else nn.Identity()
1187
- )
1188
  self.apply(self._init_weights)
1189
- # pass
1190
 
1191
  def _init_weights(self, m):
1192
  if isinstance(m, nn.Linear):
1193
- trunc_normal_(m.weight, std=0.02)
1194
  if isinstance(m, nn.Linear) and m.bias is not None:
1195
  nn.init.constant_(m.bias, 0)
1196
  elif isinstance(m, nn.LayerNorm):
@@ -1203,23 +1157,72 @@ class FasterViT(nn.Module):
1203
  nn.init.ones_(m.weight)
1204
  nn.init.zeros_(m.bias)
 
 
 
1206
  def change_window_size(self, new_window_size):
1207
  """
1208
- FasterViT uses windowed attention, it might be sensative to the choiuce of this parameter
1209
- especially in case of eneven partitioning of the feature maps.
1210
- FasterViT allows changing the window size post training.
1211
- Therefore it should be changed with different input image resolution.
1212
- Recommended values:
1213
- input res | window_size
1214
- 224 | 7
1215
- 256 | 8
1216
- 386 | 12
1217
- 512 | 16
1218
- Ideally, window_size should be a factor of the input resolution. In the third stage we divide the resolution by 16, so window_size should be img_res/16/2 for the third stage and img_res/32 for the last stage.
1219
- Applying in the brute force way, can be done smarter
 
 
 
1220
  """
1221
  window_size = new_window_size
1222
-
1223
  for module in self.modules():
1224
  if hasattr(module, "window_size"):
1225
  # check if tuple or a number
@@ -1232,100 +1235,292 @@ class FasterViT(nn.Module):
1232
  else:
1233
  module.window_size = window_size
1234
 
1235
- def set_optimal_window_size(self, image_dim):
 
1236
  """
1237
  Using hand picked window size for various resolutions.
 
 
 
1238
  """
 
 
 
 
 
 
 
 
 
 
 
 
1239
  if isinstance(image_dim, list) or isinstance(image_dim, tuple):
1240
  image_dim = min(image_dim)
1241
 
1242
- if image_dim == 224:
1243
- new_window_size = 7
1244
- elif image_dim == 256:
1245
- new_window_size = 8
1246
- elif image_dim == 384:
1247
- new_window_size = 12
1248
- elif image_dim == 512:
1249
- new_window_size = 16
1250
- else:
1251
- if image_dim < 512:
1252
- new_window_size = np.ceil(image_dim / 32)
1253
- else:
1254
- new_window_size = 16
1255
 
1256
- print(f"Changing window size to {new_window_size}")
1257
  self.change_window_size(new_window_size = new_window_size)
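A short usage note (a sketch, not from the commit): because E-RADIO uses windowed attention, the window size should be matched to the input resolution before inference, e.g.

    # `model` is an instance of the FasterViT/E-RADIO backbone defined in this file
    model.set_optimal_window_size(image_dim=(512, 512))   # picks a hand-tuned value
    # or set it explicitly; window_size should evenly divide the stage feature maps
    model.change_window_size(16)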
1258
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1259
 
1260
- @torch.jit.ignore
1261
- def no_weight_decay_keywords(self):
1262
- return {"rpb"}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1263
 
1264
- def forward_features(self, x):
1265
- x = self.patch_embed(x)
1266
- full_features = None
1267
- for il, level in enumerate(self.levels):
1268
- x, pre_downsample_x = level(x)
 
 
1269
 
1270
- if self.return_full_features or self.use_neck:
1271
- if not SIMPLER_UP_TOWER:
1272
- if self.neck_start_stage > il:
1273
- continue
1274
- if full_features is None:
1275
- full_features = self.neck_features_proj[il - self.neck_start_stage](
1276
- pre_downsample_x
1277
- )
1278
- else:
1279
- # upsample torch tensor x to match full_features size, and add to full_features
1280
- feature_projection = self.neck_features_proj[
1281
- il - self.neck_start_stage
1282
- ](pre_downsample_x)
1283
- if (
1284
- feature_projection.shape[2] != full_features.shape[2]
1285
- or feature_projection.shape[3] != full_features.shape[3]
1286
- ):
1287
- feature_projection = torch.nn.functional.pad(
1288
- feature_projection,
1289
- (
1290
- 0,
1291
- -feature_projection.shape[3] + full_features.shape[3],
1292
- 0,
1293
- -feature_projection.shape[2] + full_features.shape[2],
1294
- ),
1295
- )
1296
- full_features += feature_projection
1297
- else:
1298
- full_features = self.high_res_neck(pre_downsample_x, il, full_features)
1299
 
1300
- x = self.norm(x) # new version for
 
 
1301
 
1302
- x = self.avgpool(x)
1303
- x = torch.flatten(x, 1)
 
 
 
1304
 
1305
- if not self.return_full_features:
1306
- return x, None
 
 
 
1307
 
1308
- return x, full_features
 
 
 
1309
 
1310
- def forward(self, x):
 
 
 
1311
 
1312
- x, full_features = self.forward_features(x)
 
 
 
1313
 
1314
- x = self.head(x)
1315
- if full_features is not None:
1316
- return x, full_features
1317
- return x
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1318
 
1319
- def switch_to_deploy(self):
1320
- """
1321
- A method to perform model self-compression
1322
- merges BN into conv layers
1323
- converts MLP relative positional bias into precomputed buffers
1324
- """
1325
- for level in [self.patch_embed, self.levels, self.head]:
1326
- for module in level.modules():
1327
- if hasattr(module, "switch_to_deploy"):
1328
- module.switch_to_deploy()
1329
 
1330
  @register_model
1331
  def fastervit2_large_fullres_ws8(pretrained=False, **kwargs):
@@ -1348,7 +1543,7 @@ def fastervit2_large_fullres_ws8(pretrained=False, **kwargs):
1348
  **kwargs,
1349
  )
1350
  if pretrained:
1351
- model.load_state_dict(torch.load(pretrained))
1352
  return model
1353
 
1354
 
@@ -1373,7 +1568,7 @@ def fastervit2_large_fullres_ws16(pretrained=False, **kwargs):
1373
  **kwargs,
1374
  )
1375
  if pretrained:
1376
- model.load_state_dict(torch.load(pretrained))
1377
  return model
1378
 
1379
 
@@ -1398,28 +1593,210 @@ def fastervit2_large_fullres_ws32(pretrained=False, **kwargs):
1398
  **kwargs,
1399
  )
1400
  if pretrained:
1401
- model.load_state_dict(torch.load(pretrained))
1402
  return model
1403
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1404
 
1405
  @register_model
1406
- def eradio(pretrained=False, **kwargs):
1407
- return fastervit2_large_fullres_ws16(pretrained=pretrained, **kwargs)
 
 
 
1408
 
1409
- '''
1410
- Suggested way to use:
1411
- from transformers import AutoModel
1412
- model = AutoModel.from_pretrained("nvidia/E-RADIO", trust_remote_code=True)
1413
 
1414
- model.model.set_optimal_window_size(image_dim = data["image"][0].shape[:2])
1415
- imgs = [torch.tensor(img).permute(2,0,1)/255.0 for img in data["image"]] #res is 224
1416
- input_images = torch.stack(imgs).cuda()
 
 
 
1417
 
1418
- model.eval()
1419
- model.cuda()
 
 
 
1420
 
1421
- cls_token, features = model(input_images)
1422
- cls_token = features.mean([2, 3])
1423
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1424
 
1425
- '''
 
 
 
12
  # Mike Ranzinger, Greg Heinrich, Jan Kautz, and Pavlo Molchanov. "AM-RADIO: Agglomerative Model--Reduce All Domains Into One." arXiv preprint arXiv:2312.06709 (2023).
13
 
14
  # based on FasterViT, Swin Transformer, YOLOv8
15
+
16
  # FasterViT:
17
  # Ali Hatamizadeh, Greg Heinrich, Hongxu Yin, Andrew Tao, Jose M. Alvarez, Jan Kautz, and Pavlo Molchanov. "FasterViT: Fast Vision Transformers with Hierarchical Attention." arXiv preprint arXiv:2306.06189 (2023).
18
 
19
+ import timm
20
  import torch
21
  import torch.nn as nn
22
  from timm.models.registry import register_model
 
24
  from timm.models.layers import trunc_normal_, DropPath, LayerNorm2d
25
  import numpy as np
26
  import torch.nn.functional as F
27
+ import math
28
  import warnings
29
 
 
 
30
  #######################
31
  ## Codebase from YOLOv8
32
  ## BEGINNING
 
97
  @torch.no_grad()
98
  def switch_to_deploy(self):
99
  # return 1
100
+ if not isinstance(self.bn, nn.Identity):
101
+ c, bn = self.conv, self.bn
102
+ w = bn.weight / (bn.running_var + bn.eps) ** 0.5
103
+ w = c.weight * w[:, None, None, None]
104
+ b = bn.bias - bn.running_mean * bn.weight / \
105
+ (bn.running_var + bn.eps)**0.5
106
 
107
+ self.conv.weight.data.copy_(w)
108
+ self.conv.bias = nn.Parameter(b)
109
 
110
+ self.bn = nn.Identity()
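The folding above merges the BatchNorm statistics into the convolution so inference runs a single fused conv. A self-contained sketch of the same arithmetic (plain nn.Conv2d/nn.BatchNorm2d, independent of this file) showing that eval-mode outputs are preserved:

    import torch
    import torch.nn as nn

    conv = nn.Conv2d(8, 16, 3, padding=1, bias=False)
    bn = nn.BatchNorm2d(16).eval()
    # give BN non-trivial statistics and affine parameters
    bn.running_mean.uniform_(-1, 1); bn.running_var.uniform_(0.5, 2.0)
    bn.weight.data.uniform_(0.5, 1.5); bn.bias.data.uniform_(-0.5, 0.5)

    x = torch.randn(2, 8, 32, 32)
    y_ref = bn(conv(x))

    # fold exactly as switch_to_deploy does
    w = bn.weight / (bn.running_var + bn.eps) ** 0.5
    fused = nn.Conv2d(8, 16, 3, padding=1, bias=True)
    fused.weight.data.copy_(conv.weight * w[:, None, None, None])
    fused.bias.data.copy_(bn.bias - bn.running_mean * bn.weight / (bn.running_var + bn.eps) ** 0.5)

    assert torch.allclose(y_ref, fused(x), atol=1e-5)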
111
 
112
  def autopad(k, p=None, d=1): # kernel, padding, dilation
113
  """Pad to 'same' shape outputs."""
 
123
  ## END
124
  #######################
125
 
 
126
  def pixel_unshuffle(data, factor=2):
127
  # performs nn.PixelShuffle(factor) in reverse, torch has some bug for ONNX and TRT, so doing it manually
128
  B, C, H, W = data.shape
129
+ return data.view(B, C, factor, H//factor, factor, W//factor).permute(0,1,2,4,3,5).reshape(B, -1, H//factor, W//factor)
 
 
 
 
 
130
 
131
  class SwiGLU(nn.Module):
132
  # should be more advanced, but doesnt improve results so far
 
137
 
138
  def window_partition(x, window_size):
139
  """
140
+ Function for partitioning image into windows and later do windowed attention
141
  Args:
142
  x: (B, C, H, W)
143
  window_size: window size
 
147
  """
148
  B, C, H, W = x.shape
149
 
150
+ if window_size == 0 or (window_size==H and window_size==W):
151
  windows = x.flatten(2).transpose(1, 2)
152
  Hp, Wp = H, W
153
  else:
154
  pad_h = (window_size - H % window_size) % window_size
155
  pad_w = (window_size - W % window_size) % window_size
 
156
  if pad_h > 0 or pad_w > 0:
157
+ x = F.pad(x, (0, pad_w, 0, pad_h), mode="reflect")
158
  Hp, Wp = H + pad_h, W + pad_w
159
 
160
  x = x.view(B, C, Hp // window_size, window_size, Wp // window_size, window_size)
161
+ windows = x.permute(0, 2, 4, 3, 5, 1).reshape(-1, window_size*window_size, C)
162
 
163
  return windows, (Hp, Wp)
164
 
 
165
  class Conv2d_BN(nn.Module):
166
+ '''
167
  Conv2d + BN layer with folding capability to speed up inference
168
+ Can be merged with Conv() function with additional arguments
169
+ '''
170
+ def __init__(self, a, b, kernel_size=1, stride=1, padding=0, dilation=1, groups=1, bn_weight_init=1, bias=False):
 
 
 
 
 
 
 
 
 
 
 
171
  super().__init__()
172
+ self.conv = torch.nn.Conv2d(a, b, kernel_size, stride, padding, dilation, groups, bias=False)
 
 
173
  if 1:
174
  self.bn = torch.nn.BatchNorm2d(b)
175
  torch.nn.init.constant_(self.bn.weight, bn_weight_init)
176
  torch.nn.init.constant_(self.bn.bias, 0)
177
 
178
+ def forward(self,x):
179
  x = self.conv(x)
180
  x = self.bn(x)
181
  return x
 
186
  c, bn = self.conv, self.bn
187
  w = bn.weight / (bn.running_var + bn.eps) ** 0.5
188
  w = c.weight * w[:, None, None, None]
189
+ b = bn.bias - bn.running_mean * bn.weight / \
190
+ (bn.running_var + bn.eps)**0.5
191
  self.conv.weight.data.copy_(w)
192
  self.conv.bias = nn.Parameter(b)
193
  self.bn = nn.Identity()
194
 
195
 
196
+
197
  def window_reverse(windows, window_size, H, W, pad_hw):
198
  """
199
+ Windows to the full feature map
200
  Args:
201
  windows: local window features (num_windows*B, window_size, window_size, C)
202
  window_size: Window size
 
209
  """
210
  # print(f"window_reverse, windows.shape {windows.shape}")
211
  Hp, Wp = pad_hw
212
+ if window_size == 0 or (window_size==H and window_size==W):
213
  B = int(windows.shape[0] / (Hp * Wp / window_size / window_size))
214
  x = windows.transpose(1, 2).view(B, -1, H, W)
215
  else:
216
  B = int(windows.shape[0] / (Hp * Wp / window_size / window_size))
217
+ x = windows.view(B, Hp // window_size, Wp // window_size, window_size, window_size, -1)
218
+ x = x.permute(0, 5, 1, 3, 2, 4).reshape(B,windows.shape[2], Hp, Wp)
 
 
219
 
220
  if Hp > H or Wp > W:
221
+ x = x[:, :, :H, :W, ].contiguous()
222
 
223
  return x
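window_partition and window_reverse are intended to be inverses (up to the padding that window_reverse crops away). A hedged round-trip check, assuming eradio_model.py and its timm/numpy dependencies are importable:

    import torch
    from eradio_model import window_partition, window_reverse

    x = torch.randn(2, 32, 28, 28)                  # (B, C, H, W); H, W not multiples of 8
    windows, pad_hw = window_partition(x, 8)        # -> (num_windows*B, 8*8, C), padded size
    y = window_reverse(windows, 8, 28, 28, pad_hw)  # reassemble and crop back to H, W
    assert torch.allclose(x, y)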
224
 
225
 
226
+
227
  class PosEmbMLPSwinv2D(nn.Module):
228
  """
229
  2D positional embedding from Swin Transformer v2
 
260
 
261
  def relative_bias_initialization(self, window_size, num_heads, pretrained_window_size, seq_length, no_log):
262
  # as in separate function to support window size chage after model weights loading
 
263
  relative_coords_h = torch.arange(
264
  -(window_size[0] - 1), window_size[0], dtype=torch.float32
265
  )
 
332
  self.relative_bias = relative_bias.to(self.relative_bias.device)
333
 
334
  if self.deploy and self.grid_exists:
335
+ input_tensor = input_tensor + self.relative_bias
336
  return input_tensor
337
 
338
  if 1:
 
356
 
357
  self.relative_bias = relative_position_bias.unsqueeze(0)
358
 
359
+ input_tensor = input_tensor + self.relative_bias
360
  return input_tensor
361
 
362
 
363
  class GRAAttentionBlock(nn.Module):
364
+ def __init__(self, window_size, dim_in, dim_out,
365
+ num_heads, drop_path=0., qk_scale=None, qkv_bias=False,
366
+ norm_layer=nn.LayerNorm, layer_scale=None,
367
+ use_swiglu=True,
368
+ subsample_ratio=1, dim_ratio=1, conv_base=False,
369
+ do_windowing=True, multi_query=False, use_shift=0,
370
+ cpb_mlp_hidden=512, conv_groups_ratio=0):
371
+ '''
372
+ Global Resolution Attention Block , see README for details
373
+ Attention with subsampling to get a bigger receptive field for attention
374
+ conv_base - use conv2d instead of avgpool2d for downsample / upsample
375
+
376
+
377
+ '''
 
 
 
 
 
378
  super().__init__()
379
 
380
+ self.shift_size=window_size//2 if use_shift else 0
381
 
382
  self.do_windowing = do_windowing
383
+ self.subsample_ratio = subsample_ratio
384
+
385
+
386
 
387
  if do_windowing:
388
  if conv_base:
389
  self.downsample_op = nn.Conv2d(dim_in, dim_out, kernel_size=subsample_ratio, stride=subsample_ratio) if subsample_ratio > 1 else nn.Identity()
390
+
391
+
392
  self.downsample_mixer = nn.Identity()
393
  self.upsample_mixer = nn.Identity()
394
  self.upsample_op = nn.ConvTranspose2d(dim_in, dim_out, kernel_size=subsample_ratio, stride=subsample_ratio) if subsample_ratio > 1 else nn.Identity()
 
398
  self.upsample_mixer = nn.Upsample(scale_factor=subsample_ratio, mode='nearest') if subsample_ratio > 1 else nn.Identity()
399
  self.upsample_op = Conv2d_BN(dim_in, dim_out, kernel_size=1, stride=1, padding=0, bias=False) if subsample_ratio > 1 else nn.Identity()
400
 
401
+
402
+ # in case there is no downsampling conv we want to have it separately
403
+ # will help with information propagation between windows
404
+ if subsample_ratio == 1:
405
+ # conv_groups_ratio=0
406
+ self.pre_conv = Conv2d_BN(dim_in, dim_in, kernel_size=3, stride=1, padding=1, groups=max(1,int(conv_groups_ratio*dim_in)), bias=False)
407
+ # self.pre_conv = nn.Conv2d(dim_in, dim_in, kernel_size=3, stride=1, padding=1, groups=max(1,int(conv_groups_ratio*dim_in)), bias=False)
408
+ # self.pre_conv_act = nn.ReLU6()
409
+ #for simplicity:
410
+ self.pre_conv_act = nn.Identity()
411
+ if conv_groups_ratio == -1:
412
+ self.pre_conv = nn.Identity()
413
+ self.pre_conv_act = nn.Identity()
414
+
415
  self.window_size = window_size
416
 
417
  self.norm1 = norm_layer(dim_in)
 
421
  num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale,
422
  resolution=window_size,
423
  seq_length=window_size**2, dim_out=dim_in, multi_query=multi_query,
424
+ shift_size=self.shift_size, cpb_mlp_hidden=cpb_mlp_hidden)
425
 
426
  self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
427
 
 
444
 
445
  def forward(self, x):
446
  skip_connection = x
447
+ attn_mask = None
448
+
449
+ # in case there is no downsampling conv we want to have it separately
450
+ # will help with information propagation
451
+ if self.subsample_ratio == 1:
452
+ x = self.pre_conv_act(self.pre_conv(x)) + skip_connection
453
 
454
  if self.do_windowing:
455
  # performing windowing if required
456
  x = self.downsample_op(x)
457
  x = self.downsample_mixer(x)
458
 
459
+ if self.window_size>0:
460
  H, W = x.shape[2], x.shape[3]
461
 
462
+ if self.shift_size > 0 and H>self.window_size and W>self.window_size:
463
+ # Swin-like cyclic shift; does not show better performance
464
+ x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(2, 3))
465
+
466
  x, pad_hw = window_partition(x, self.window_size)
467
 
468
+ if self.shift_size > 0 and H>self.window_size and W>self.window_size:
469
+ # set the attention matrix entries in the top-right square to -100
470
+ # attn[:, :, :-self.shift_size, -self.shift_size:] = -100.0
471
+ # calculate attention mask for SW-MSA
472
+ # not used in the final version; can be useful in some cases, especially for high resolutions
473
+ H, W = pad_hw
474
+ img_mask = torch.zeros((1, H, W, 1), device=x.device) # 1 H W 1
475
+ h_slices = (slice(0, -self.window_size),
476
+ slice(-self.window_size, -self.shift_size),
477
+ slice(-self.shift_size, None))
478
+ w_slices = (slice(0, -self.window_size),
479
+ slice(-self.window_size, -self.shift_size),
480
+ slice(-self.shift_size, None))
481
+ cnt = 0
482
+ for h in h_slices:
483
+ for w in w_slices:
484
+ img_mask[:, h, w, :] = cnt
485
+ cnt += 1
486
+ img_mask = img_mask.transpose(1,2).transpose(1,3)
487
+ mask_windows = window_partition(img_mask, self.window_size) # nW, window_size, window_size, 1
488
+
489
+ mask_windows = mask_windows[0].view(-1, self.window_size * self.window_size)
490
+ attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
491
+ attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
492
+
493
  # window attention
494
+ x = x + self.drop_path1(self.gamma1*self.attn(self.norm1(x), attn_mask=attn_mask)) # or pass H,W
495
  # mlp layer
496
+ x = x + self.drop_path2(self.gamma2*self.mlp(self.norm2(x)))
497
 
498
  if self.do_windowing:
499
  if self.window_size > 0:
500
  x = window_reverse(x, self.window_size, H, W, pad_hw)
501
 
502
+ # reverse cyclic shift
503
+ if self.shift_size > 0 and H>self.window_size and W>self.window_size:
504
+ # Swin-like cyclic shift, not tested
505
+ x = torch.roll(x, shifts=(self.shift_size, self.shift_size), dims=(2, 3))
506
+
507
  x = self.upsample_mixer(x)
508
  x = self.upsample_op(x)
509
 
510
+
511
+ if x.shape[2] != skip_connection.shape[2] or x.shape[3] != skip_connection.shape[3]:
512
+ x = torch.nn.functional.pad(x, ( 0, -x.shape[3] + skip_connection.shape[3], 0, -x.shape[2] + skip_connection.shape[2]), mode="reflect")
 
 
 
 
 
 
 
 
 
 
513
  # need to add skip connection because downsampling and upsampling will break residual connection
514
  # 0.5 is needed to make sure that the skip connection is not too strong
515
  # in case of no downsample / upsample we can show that 0.5 compensates for the residual connection
516
  x = 0.5 * x + 0.5 * skip_connection
 
517
  return x
518
 
519
 
520
+
521
+
522
  class MultiResolutionAttention(nn.Module):
523
  """
524
  MultiResolutionAttention (MRA) module
525
  The idea is to use multiple attention blocks with different resolution
526
  Feature maps are downsampled / upsampled for each attention block on different blocks
527
+ Every attention block supports windowing
 
528
  """
529
 
530
+ def __init__(self, window_size, sr_ratio,
531
+ dim, dim_ratio, num_heads,
532
+ do_windowing=True,
533
+ layer_scale=1e-5, norm_layer=nn.LayerNorm,
534
+ drop_path = 0, qkv_bias=False, qk_scale=1.0,
535
+ use_swiglu=True, multi_query=False, conv_base=False,
536
+ use_shift=0, cpb_mlp_hidden=512, conv_groups_ratio=0) -> None:
 
 
 
 
 
 
 
 
 
 
 
537
  """
538
  Args:
539
  input_resolution: input image resolution
540
  window_size: window size
541
  compression_ratio: compression ratio
542
  max_depth: maximum depth of the GRA module
543
+ use_shift: do window shifting
544
  """
545
  super().__init__()
546
 
 
548
 
549
  self.attention_blocks = nn.ModuleList()
550
 
551
+
552
  for i in range(depth):
553
  subsample_ratio = sr_ratio[i]
554
  if len(window_size) > i:
 
556
  else:
557
  window_size_local = window_size[0]
558
 
559
+ self.attention_blocks.append(GRAAttentionBlock(window_size=window_size_local,
560
+ dim_in=dim, dim_out=dim, num_heads=num_heads,
561
+ qkv_bias=qkv_bias, qk_scale=qk_scale, norm_layer=norm_layer,
562
+ layer_scale=layer_scale, drop_path=drop_path,
563
+ use_swiglu=use_swiglu, subsample_ratio=subsample_ratio, dim_ratio=dim_ratio,
564
+ do_windowing=do_windowing, multi_query=multi_query, conv_base=conv_base,
565
+ use_shift=use_shift, cpb_mlp_hidden=cpb_mlp_hidden, conv_groups_ratio=conv_groups_ratio),
566
+ )
 
 
 
 
 
 
 
 
 
 
 
 
567
 
568
  def forward(self, x):
569
 
 
573
  return x
574
 
575
 
576
+
577
  class Mlp(nn.Module):
578
  """
579
  Multi-Layer Perceptron (MLP) block
580
  """
581
 
582
+ def __init__(self,
583
+ in_features,
584
+ hidden_features=None,
585
+ out_features=None,
586
+ act_layer=nn.GELU,
587
+ use_swiglu=True,
588
+ drop=0.):
 
 
589
  """
590
  Args:
591
  in_features: input features dimension.
 
598
  super().__init__()
599
  out_features = out_features or in_features
600
  hidden_features = hidden_features or in_features
601
+ self.fc1 = nn.Linear(in_features, hidden_features * (2 if use_swiglu else 1), bias=False)
 
 
602
  self.act = act_layer()
603
  self.fc2 = nn.Linear(hidden_features, out_features, bias=False)
604
 
 
611
  x = x.view(x_size)
612
  return x
613
 
 
614
  class Downsample(nn.Module):
615
  """
616
  Down-sampling block
 
617
  Pixel unshuffle is used for down-sampling; it works great accuracy-wise but takes 10% more TRT time
618
  """
619
 
620
+ def __init__(self,
621
+ dim,
622
+ shuffle = False,
623
+ ):
624
  """
625
  Args:
626
  dim: feature size dimension.
627
+ shuffle: use pixel unshuffle for down-sampling instead of a strided conv
628
  keep_dim: bool argument for maintaining the resolution.
629
  """
630
 
 
633
 
634
  if shuffle:
635
  self.norm = lambda x: pixel_unshuffle(x, factor=2)
636
+ self.reduction = Conv2d_BN(dim*4, dim_out, 1, 1, 0, bias=False)
637
+ # pixel unshuffling works well but does not provide any speedup
638
  else:
639
+ # removed the layer norm; in this formulation we get 10% better speed
640
+ # LayerNorm for high resolution inputs will be a pain as it pools over the entire spatial dimension
641
+ # therefore we remove it compared to the original implementation in FasterViTv1
642
  self.norm = nn.Identity()
643
  self.reduction = Conv2d_BN(dim, dim_out, 3, 2, 1, bias=False)
644
 
645
+
646
  def forward(self, x):
647
  x = self.norm(x)
648
  x = self.reduction(x)
 
653
  """
654
  Patch embedding block
655
  Used to convert image into an initial set of feature maps with lower resolution
 
656
  """
657
 
658
  def __init__(self, in_chans=3, in_dim=64, dim=96, shuffle_down=False):
 
672
  Conv2d_BN(in_chans, in_dim, 3, 2, 1, bias=False),
673
  nn.ReLU(),
674
  Conv2d_BN(in_dim, dim, 3, 2, 1, bias=False),
675
+ nn.ReLU()
676
+ )
677
  else:
678
  self.proj = lambda x: pixel_unshuffle(x, factor=4)
679
+ self.conv_down = nn.Sequential(Conv2d_BN(in_chans*16, dim, 3, 1, 1),
680
+ nn.ReLU(),
681
+ )
682
 
683
  def forward(self, x):
684
  x = self.proj(x)
 
686
  return x
687
 
688
 
689
+
690
  class ConvBlock(nn.Module):
691
  """
692
  Convolutional block, used in first couple of stages
 
730
  # Windowed Attention from SwinV2
731
  # use a MLP trick to deal with various input image resolutions, then fold it to improve speed
732
 
733
+ def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, resolution=0,
734
+ seq_length=0, dim_out=None, multi_query=False, shift_size=0, cpb_mlp_hidden=512):
 
 
 
 
 
 
 
 
 
 
735
  # taken from EdgeViT and tweaked with attention bias.
736
  super().__init__()
737
+ if not dim_out: dim_out = dim
738
+ self.shift_size = shift_size
739
  self.multi_query = multi_query
740
  self.num_heads = num_heads
741
  head_dim = dim // num_heads
 
747
  if not multi_query:
748
  self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
749
  else:
750
+ self.qkv = nn.Linear(dim, dim + 2*self.head_dim, bias=qkv_bias)
751
 
752
  self.proj = nn.Linear(dim, dim_out, bias=False)
753
  # attention positional bias
754
+ self.pos_emb_funct = PosEmbMLPSwinv2D(window_size=[resolution, resolution],
755
+ pretrained_window_size=[resolution, resolution],
756
+ num_heads=num_heads,
757
+ seq_length=seq_length,
758
+ cpb_mlp_hidden=cpb_mlp_hidden)
 
 
759
 
760
  self.resolution = resolution
761
 
762
+ def forward(self, x, attn_mask = None):
763
  B, N, C = x.shape
764
 
765
  if not self.multi_query:
766
+ qkv = self.qkv(x).reshape(B, -1, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
 
 
 
 
767
  q, k, v = qkv[0], qkv[1], qkv[2]
768
  else:
769
  qkv = self.qkv(x)
770
+ (q, k, v) = qkv.split([self.dim_internal, self.head_dim, self.head_dim], dim=2)
 
 
771
 
772
+ q = q.reshape(B, -1, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)
 
 
773
  k = k.reshape(B, -1, 1, C // self.num_heads).permute(0, 2, 1, 3)
774
  v = v.reshape(B, -1, 1, C // self.num_heads).permute(0, 2, 1, 3)
775
 
 
777
 
778
  attn = self.pos_emb_funct(attn)
779
 
780
+ #add window shift
781
+ if attn_mask is not None:
782
+ nW = attn_mask.shape[0]
783
+ attn = attn.view(B // nW, nW, self.num_heads, N, N) + attn_mask.unsqueeze(1).unsqueeze(0)
784
+ attn = attn.view(-1, self.num_heads, N, N)
785
+
786
  attn = attn.softmax(dim=-1)
787
  x = (attn @ v).transpose(1, 2).reshape(B, -1, C)
788
  x = self.proj(x)
789
  return x
790
 
791
 
792
+
793
  class FasterViTLayer(nn.Module):
794
  """
795
  FasterViT layer
796
  """
797
 
798
+ def __init__(self,
799
+ dim,
800
+ depth,
801
+ num_heads,
802
+ window_size,
803
+ conv=False,
804
+ downsample=True,
805
+ mlp_ratio=4.,
806
+ qkv_bias=False,
807
+ qk_scale=None,
808
+ norm_layer=nn.LayerNorm,
809
+ drop_path=0.,
810
+ layer_scale=None,
811
+ layer_scale_conv=None,
812
+ sr_dim_ratio=1,
813
+ sr_ratio=1,
814
+ multi_query=False,
815
+ use_swiglu=True,
816
+ yolo_arch=False,
817
+ downsample_shuffle=False,
818
+ conv_base=False,
819
+ use_shift=False,
820
+ cpb_mlp_hidden=512,
821
+ conv_groups_ratio=0,
822
+ verbose: bool = True,
823
+
824
  ):
825
  """
826
  Args:
 
838
  drop_path: drop path rate.
839
  norm_layer: normalization layer.
840
  layer_scale: layer scaling coefficient.
841
+ use_shift: Swin-like window shifting by half the window size for every alternating layer (considering multi-resolution)
842
+ conv_groups_ratio: group ratio for the conv used when there is no subsampling in multi-res attention
843
  """
844
 
845
  super().__init__()
846
  self.conv = conv
847
+ self.yolo_arch=False
848
+ self.verbose = verbose
849
  if conv:
850
  if not yolo_arch:
851
+ self.blocks = nn.ModuleList([
852
+ ConvBlock(dim=dim,
853
+ drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
854
+ layer_scale=layer_scale_conv)
855
+ for i in range(depth)])
 
 
 
 
 
 
856
  self.blocks = nn.Sequential(*self.blocks)
857
  else:
858
+ self.blocks = C2f(dim,dim,n=depth,shortcut=True,e=0.5)
859
+ self.yolo_arch=True
860
  else:
861
+ if not isinstance(window_size, list): window_size = [window_size]
 
862
  self.window_size = window_size[0]
863
  self.do_single_windowing = True
864
+ if not isinstance(sr_ratio, list): sr_ratio = [sr_ratio]
 
865
  self.sr_ratio = sr_ratio
866
+ if any([sr!=1 for sr in sr_ratio]) or len(set(window_size))>1:
867
  self.do_single_windowing = False
868
  do_windowing = True
869
  else:
870
  self.do_single_windowing = True
871
  do_windowing = False
872
 
873
+ #for v2_2
874
+ if conv_groups_ratio != -1:
875
+ self.do_single_windowing = False
876
+ do_windowing = True
877
+
878
  self.blocks = nn.ModuleList()
879
  for i in range(depth):
 
880
  self.blocks.append(
881
+ MultiResolutionAttention(window_size=window_size,
882
+ sr_ratio=sr_ratio,
883
+ dim=dim,
884
+ dim_ratio = sr_dim_ratio,
885
+ num_heads=num_heads,
886
+ norm_layer=norm_layer,
887
+ drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
888
+ layer_scale=layer_scale,
889
+ qkv_bias=qkv_bias,
890
+ qk_scale=qk_scale,
891
+ use_swiglu=use_swiglu,
892
+ do_windowing=do_windowing,
893
+ multi_query=multi_query,
894
+ conv_base=conv_base,
895
+ cpb_mlp_hidden=cpb_mlp_hidden,
896
+ use_shift =0 if ((not use_shift) or ((i) % 2 == 0)) else True ,
897
+ conv_groups_ratio=conv_groups_ratio,
898
+ ))
 
 
 
899
  self.blocks = nn.Sequential(*self.blocks)
900
 
901
  self.transformer = not conv
902
+ self.downsample = None if not downsample else Downsample(dim=dim, shuffle=downsample_shuffle)
 
 
 
903
 
904
 
905
  def forward(self, x):
 
922
  new_h = int(np.ceil(H/max_window_size)*max_window_size)
923
  new_w = int(np.ceil(W/max_window_size)*max_window_size)
924
  x = F.interpolate(x, size=(new_h, new_w), mode='nearest')
925
+ if self.verbose:
926
+ warnings.warn(f"Choosen window size is not optimal for given resolution. Interpolation of features maps will be done and it can affect the performance. Max window size is {max_window_size}, feature map size is {H}x{W}, interpolated feature map size is {new_h}x{new_w}.")
927
 
928
 
929
  if self.transformer and self.do_single_windowing:
930
  H, W = x.shape[2], x.shape[3]
931
  x, pad_hw = window_partition(x, self.window_size)
932
 
933
+ #run main blocks
934
  x = self.blocks(x)
 
 
 
 
 
935
 
936
  if self.transformer and self.do_single_windowing:
937
  x = window_reverse(x, self.window_size, H, W, pad_hw)
 
946
  return self.downsample(x), x # changing to output pre downsampled features
947
 
948
 
949
+ class InterpolateLayer(nn.Module):
950
+ def __init__(self, size=None, scale_factor=None, mode='nearest'):
951
+ super(InterpolateLayer, self).__init__()
952
+ self.size = size
953
+ self.scale_factor = scale_factor
954
+ self.mode = mode
955
+
956
+ def forward(self, x):
957
+ return F.interpolate(x, size=self.size, scale_factor=self.scale_factor, mode=self.mode)
958
+
959
+
960
  class HiResNeck(nn.Module):
961
  """
962
  The block is used to output dense features from all stages
963
  Otherwise, by default, only the last stage features are returned with FasterViTv2
964
  """
965
+ def __init__(self, dim, depths, neck_start_stage, full_features_head_dim, downsample_enabled):
966
 
967
  '''
968
  Hi Resolution neck to support output of high res features that are useful for dense tasks.
 
971
  earlier layers result in higher resolution features at the cost of compute
972
  full_features_head_dim - number of channels in the dense features head
973
  '''
974
+ super().__init__()
975
  # create feature projection layers for segmentation output
976
  self.neck_features_proj = nn.ModuleList()
977
  self.neck_start_stage = neck_start_stage
 
983
 
984
  if (upsample_ratio > 1) or full_features_head_dim!=level_n_features_output:
985
  feature_projection = nn.Sequential()
986
+ if False:
987
+ feature_projection.add_module("norm",nn.BatchNorm2d(level_n_features_output)) #fast, but worse
988
+ feature_projection.add_module("dconv", nn.ConvTranspose2d(level_n_features_output,
989
+ full_features_head_dim, kernel_size=upsample_ratio, stride=upsample_ratio))
990
+ else:
991
+ # B, in_channels, H, W -> B, in_channels, H*upsample_ratio, W*upsample_ratio
992
+ # print("upsample ratio", upsample_ratio, level_n_features_output, level_n_features_output)
993
+ feature_projection.add_module("upsample", InterpolateLayer(scale_factor=upsample_ratio, mode='nearest'))
994
+ feature_projection.add_module("conv1", nn.Conv2d(level_n_features_output, level_n_features_output, kernel_size=3, stride=1, padding=1, groups=level_n_features_output))
995
+ feature_projection.add_module("norm",nn.BatchNorm2d(level_n_features_output))
996
+ # B, in_channels, H*upsample_ratio, W*upsample_ratio -> B, full_features_head_dim, H*upsample_ratio, W*upsample_ratio
997
+ feature_projection.add_module("conv2", nn.Conv2d(level_n_features_output, full_features_head_dim, kernel_size=1, stride=1, padding=0))
998
  else:
999
  feature_projection = nn.Sequential()
1000
 
1001
  self.neck_features_proj.append(feature_projection)
1002
 
1003
+ if i>0 and downsample_enabled[i]:
1004
  upsample_ratio *= 2
1005
 
1006
  def forward(self, x, il_level=-1, full_features=None):
 
1014
  feature_projection = self.neck_features_proj[il_level - self.neck_start_stage](x)
1015
  if feature_projection.shape[2] != full_features.shape[2] or feature_projection.shape[3] != full_features.shape[3]:
1016
  feature_projection = torch.nn.functional.pad(feature_projection, ( 0, -feature_projection.shape[3] + full_features.shape[3], 0, -feature_projection.shape[2] + full_features.shape[2]))
1017
+ full_features = full_features + feature_projection
1018
  return full_features
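A minimal shape trace of the projection pattern built above (standalone sketch; the channel counts are made up for illustration): nearest-neighbor upsample, a depthwise 3x3 conv, BatchNorm, then a 1x1 conv to the shared dense-feature width.

import torch
from torch import nn

level_channels, head_dim, upsample_ratio = 256, 128, 4

proj = nn.Sequential(
    nn.Upsample(scale_factor=upsample_ratio, mode='nearest'),
    nn.Conv2d(level_channels, level_channels, 3, 1, 1, groups=level_channels),  # depthwise
    nn.BatchNorm2d(level_channels),
    nn.Conv2d(level_channels, head_dim, 1, 1, 0),  # project to the dense head width
)

x = torch.randn(1, level_channels, 14, 14)  # e.g. stage-3 features for a 224 input
print(proj(x).shape)  # torch.Size([1, 128, 56, 56])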
1019
 
 
 
1020
  class FasterViT(nn.Module):
1021
  """
1022
  FasterViT
1023
  """
1024
 
1025
+ def __init__(self,
1026
+ dim,
1027
+ in_dim,
1028
+ depths,
1029
+ window_size,
1030
+ mlp_ratio,
1031
+ num_heads,
1032
+ drop_path_rate=0.2,
1033
+ in_chans=3,
1034
+ num_classes=1000,
1035
+ qkv_bias=False,
1036
+ qk_scale=None,
1037
+ layer_scale=None,
1038
+ layer_scale_conv=None,
1039
+ layer_norm_last=False,
1040
+ sr_ratio = [1, 1, 1, 1],
1041
+ max_depth = -1,
1042
+ conv_base=False,
1043
+ use_swiglu=False,
1044
+ multi_query=False,
1045
+ norm_layer=nn.LayerNorm,
1046
+ drop_uniform=False,
1047
+ yolo_arch=False,
1048
+ shuffle_down=False,
1049
+ downsample_shuffle=False,
1050
+ return_full_features=False,
1051
+ full_features_head_dim=128,
1052
+ neck_start_stage=1,
1053
+ use_neck=False,
1054
+ use_shift=False,
1055
+ cpb_mlp_hidden=512,
1056
+ conv_groups_ratio=0,
1057
+ verbose: bool = False,
1058
+ **kwargs):
1059
  """
1060
  Args:
1061
  dim: feature size dimension.
 
1078
  for 224 resolution, the output of the stage before downsample:
1079
  stage 0: 56x56, stage 1: 28x28, stage 2: 14x14, stage 3: 7x7
1080
  use_neck: even for summarization embedding use neck
1081
+ use_shift: Swin-like window shifting but without masking attention
1082
+ conv_groups_ratio: will be used for conv blocks where there is no multires attention,
1083
+ if 0 then normal conv,
1084
+ if 1 then channels are independent,
1085
+ if -1 then no conv at all
1086
+
1087
  """
1088
  super().__init__()
1089
 
1090
  num_features = int(dim * 2 ** (len(depths) - 1))
1091
  self.num_classes = num_classes
1092
+ self.patch_embed = PatchEmbed(in_chans=in_chans, in_dim=in_dim, dim=dim, shuffle_down=shuffle_down)
 
 
1093
  # set return_full_features true if we want to return full features from all stages
1094
  self.return_full_features = return_full_features
1095
  self.use_neck = use_neck
 
1098
  if drop_uniform:
1099
  dpr = [drop_path_rate for x in range(sum(depths))]
1100
 
1101
+ if not isinstance(max_depth, list): max_depth = [max_depth] * len(depths)
 
1102
 
1103
  self.levels = nn.ModuleList()
1104
  for i in range(len(depths)):
1105
  conv = True if (i == 0 or i == 1) else False
1106
 
1107
+ level = FasterViTLayer(dim=int(dim * 2 ** i),
1108
+ depth=depths[i],
1109
+ num_heads=num_heads[i],
1110
+ window_size=window_size[i],
1111
+ mlp_ratio=mlp_ratio,
1112
+ qkv_bias=qkv_bias,
1113
+ qk_scale=qk_scale,
1114
+ conv=conv,
1115
+ drop_path=dpr[sum(depths[:i]):sum(depths[:i + 1])],
1116
+ downsample=(i < len(depths) - 1),
1117
+ layer_scale=layer_scale,
1118
+ layer_scale_conv=layer_scale_conv,
1119
+ sr_ratio=sr_ratio[i],
1120
+ use_swiglu=use_swiglu,
1121
+ multi_query=multi_query,
1122
+ norm_layer=norm_layer,
1123
+ yolo_arch=yolo_arch,
1124
+ downsample_shuffle=downsample_shuffle,
1125
+ conv_base=conv_base,
1126
+ cpb_mlp_hidden=cpb_mlp_hidden,
1127
+ use_shift=use_shift,
1128
+ conv_groups_ratio=conv_groups_ratio,
1129
+ verbose=verbose)
1130
 
1131
  self.levels.append(level)
1132
 
1133
+ if self.return_full_features or self.use_neck:
1134
+ #num_heads
1135
+ downsample_enabled = [self.levels[i-1].downsample is not None for i in range(len(self.levels))]
1136
+ self.high_res_neck = HiResNeck(dim, depths, neck_start_stage, full_features_head_dim, downsample_enabled)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1137
 
1138
+ self.switched_to_deploy = False
 
 
 
 
 
 
1139
 
1140
+ self.norm = LayerNorm2d(num_features) if layer_norm_last else nn.BatchNorm2d(num_features)
 
 
 
 
1141
  self.avgpool = nn.AdaptiveAvgPool2d(1)
1142
+ self.head = nn.Linear(num_features, num_classes) if num_classes > 0 else nn.Identity()
 
 
1143
  self.apply(self._init_weights)
 
1144
 
1145
  def _init_weights(self, m):
1146
  if isinstance(m, nn.Linear):
1147
+ trunc_normal_(m.weight, std=.02)
1148
  if isinstance(m, nn.Linear) and m.bias is not None:
1149
  nn.init.constant_(m.bias, 0)
1150
  elif isinstance(m, nn.LayerNorm):
 
1157
  nn.init.ones_(m.weight)
1158
  nn.init.zeros_(m.bias)
1159
 
1160
+ @torch.jit.ignore
1161
+ def no_weight_decay_keywords(self):
1162
+ return {'rpb'}
1163
+
1164
+ def forward_features(self, x):
1165
+ x = self.patch_embed(x)
1166
+ full_features = None
1167
+ for il, level in enumerate(self.levels):
1168
+ x, pre_downsample_x = level(x)
1169
+
1170
+ if self.return_full_features or self.use_neck:
1171
+ full_features = self.high_res_neck(pre_downsample_x, il, full_features)
1172
+
1173
+ # x = self.norm(full_features if (self.return_full_features or self.use_neck) else x)
1174
+ x = self.norm(x) # new version for
1175
+
1176
+ if not self.return_full_features:
1177
+ return x, None
1178
+
1179
+ return x, full_features
1180
+
1181
+ def forward(self, x):
1182
+ x, full_features = self.forward_features(x)
1183
+
1184
+ x = self.avgpool(x)
1185
+ x = torch.flatten(x, 1)
1186
+
1187
+ x = self.head(x)
1188
+ if full_features is not None:
1189
+ return x, full_features
1190
+ return x
1191
+
1192
+ def switch_to_deploy(self):
1193
+ '''
1194
+ A method to perform model self-compression
1195
+ merges BN into conv layers
1196
+ converts MLP relative positional bias into precomputed buffers
1197
+ '''
1198
+ if not self.switched_to_deploy:
1199
+ for level in [self.patch_embed, self.levels, self.head]:
1200
+ for module in level.modules():
1201
+ if hasattr(module, 'switch_to_deploy'):
1202
+ module.switch_to_deploy()
1203
+ self.switched_to_deploy = True
1204
+
1205
+
1206
  def change_window_size(self, new_window_size):
1207
  """
1208
+ FasterViT employs windowed attention, which may be sensitive to the choice of this parameter,
1209
+ especially in cases of uneven partitioning of the feature maps.
1210
+ FasterViT allows for the adjustment of the window size after training,
1211
+ making it adaptable to different input image resolutions.
1212
+ The recommended values for window size based on input resolution are as follows:
1213
+
1214
+ Input Resolution | Window Size
1215
+ 224 | 7
1216
+ 256 | 8
1217
+ 384 | 12
1218
+ 512 | 16
1219
+ Ideally, the window size should be a factor of the input resolution. In the third stage, we divide the resolution by 16, so the window size should be
1220
+ img_res/16/2
1221
+ for the third stage and img_res/32 for the last stage. While this can be applied in a brute-force manner, a better way is to call model.change_window_size.
1222
+ Manual way to change the window size -> model.change_window_size(new_window_size)
1223
  """
1224
  window_size = new_window_size
1225
+ print(f"Setting window size to {window_size}")
1226
  for module in self.modules():
1227
  if hasattr(module, "window_size"):
1228
  # check if tuple or a number
 
1235
  else:
1236
  module.window_size = window_size
1237
 
1238
+
1239
+ def set_optimal_window_size(self, image_dim, max_window_size = 16):
1240
  """
1241
  Using hand picked window size for various resolutions.
1242
+
1243
+ FasterViT employs windowed attention, which may be sensitive to the choice of this parameter,
1244
+ especially in cases of uneven partitioning of the feature maps.
1245
+ FasterViT allows for the adjustment of the window size after training,
1246
+ making it adaptable to different input image resolutions.
1247
+ The recommended values for window size based on input resolution are as follows:
1248
+
1249
+ Input Resolution | Window Size
1250
+ 224 | 7
1251
+ 256 | 8
1252
+ 384 | 12
1253
+ 512 | 16
1254
+ Ideally, the window size should be a factor of the input resolution. In the third stage, we divide the resolution by 16, so the window size should be
1255
+ img_res/16/2
1256
+ for the third stage and img_res/32 for the last stage. While this can be applied in a brute-force manner, a better way is to call model.change_window_size.
1257
+ Manual way to change the window size -> model.change_window_size(new_window_size)
1258
+
1259
  """
1260
+ # import math
1261
+
1262
+ def divisorGenerator(n):
1263
+ large_divisors = []
1264
+ for i in range(1, int(math.sqrt(n) + 1)):
1265
+ if n % i == 0:
1266
+ yield i
1267
+ if i*i != n:
1268
+ large_divisors.append(n / i)
1269
+ for divisor in reversed(large_divisors):
1270
+ yield divisor
1271
+
1272
  if isinstance(image_dim, list) or isinstance(image_dim, tuple):
1273
  image_dim = min(image_dim)
1274
 
1275
+ # windowed attention first appears in the 3rd stage, hence //16;
1276
+ # with subsampled attention (downsample by 2) this effectively becomes //32
1277
+ # ideally this should be derived from the model structure (e.g. whether subsampling is removed)
1278
+ all_divisors = np.array(list(divisorGenerator(image_dim//32)))
1279
+ new_window_size = int(min(all_divisors[all_divisors <= max_window_size][-1], max_window_size))
1280
+
1281
+ # for image_dim in [128, 224, 256, 384, 512, 768, 1024]:
1282
+ # all_divisors = np.array(list(divisorGenerator(image_dim//32)))
1283
+ # new_window_size = int(min(all_divisors[all_divisors <= max_window_size][-1], max_window_size))
1284
+ # print(f"Setting window size to {new_window_size} for image resolution {image_dim}")
 
 
 
1285
 
 
1286
  self.change_window_size(new_window_size = new_window_size)
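As a concrete check of the divisor-based selection above, a plain-Python sketch that mirrors the logic (rather than calling the model) and reproduces the recommended table:

import math

def divisors(n):
    out = set()
    for i in range(1, int(math.sqrt(n)) + 1):
        if n % i == 0:
            out.update((i, n // i))
    return sorted(out)

def optimal_window_size(image_dim, max_window_size=16):
    # largest divisor of image_dim // 32 that does not exceed max_window_size
    return [d for d in divisors(image_dim // 32) if d <= max_window_size][-1]

for res in (224, 256, 384, 512, 1024):
    print(res, optimal_window_size(res))  # 7, 8, 12, 16, 16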
1287
 
1288
+ # 83.44200001953125
1289
+ @register_model
1290
+ def fastervit2_small(pretrained=False, **kwargs): #,
1291
+ model = FasterViT(depths=[3, 3, 5, 5],
1292
+ num_heads=[2, 4, 8, 16],
1293
+ window_size=[8, 8, [7, 7], 7],
1294
+ dim=96,
1295
+ in_dim=64,
1296
+ mlp_ratio=4,
1297
+ drop_path_rate=0.2,
1298
+ sr_ratio=[1, 1, [1, 2], 1],
1299
+ use_swiglu=False,
1300
+ downsample_shuffle=False,
1301
+ yolo_arch=True,
1302
+ shuffle_down=False,
1303
+ **kwargs)
1304
+ if pretrained:
1305
+ model.load_state_dict(torch.load(pretrained)["state_dict"])
1306
+ return model
1307
 
1308
+ # 82.61
1309
+ @register_model
1310
+ def fastervit2_tiny(pretrained=False, **kwargs): #,
1311
+ model = FasterViT(depths=[1, 3, 4, 5],
1312
+ num_heads=[2, 4, 8, 16],
1313
+ window_size=[8, 8, [7, 7], 7],
1314
+ dim=80,
1315
+ in_dim=64,
1316
+ mlp_ratio=4,
1317
+ drop_path_rate=0.2,
1318
+ sr_ratio=[1, 1, [2, 1], 1],
1319
+ use_swiglu=False,
1320
+ downsample_shuffle=False,
1321
+ yolo_arch=True,
1322
+ shuffle_down=False,
1323
+ **kwargs)
1324
+ if pretrained:
1325
+ model.load_state_dict(torch.load(pretrained)["state_dict"])
1326
+ return model
1327
 
1328
+ #'top1', 84.31800001220704
1329
+ @register_model
1330
+ def fastervit2_base(pretrained=False, **kwargs):
1331
+ model = FasterViT(depths=[3, 3, 5, 5],
1332
+ num_heads=[2, 4, 8, 16],
1333
+ window_size=[8, 8, [7, 7], 7],
1334
+ dim=128,
1335
+ in_dim=64,
1336
+ mlp_ratio=4,
1337
+ drop_path_rate=0.2,
1338
+ sr_ratio=[1, 1, [2, 1], 1],
1339
+ use_swiglu=False,
1340
+ yolo_arch=True,
1341
+ shuffle_down=False,
1342
+ conv_base=True,
1343
+ **kwargs)
1344
+ if pretrained:
1345
+ model.load_state_dict(torch.load(pretrained)["state_dict"])
1346
+ return model
1347
 
1348
+ #84.39999999267579
1349
+ @register_model
1350
+ def fastervit2_base_v1(pretrained=False, **kwargs):
1351
+ model = FasterViT(depths=[4, 4, 5, 5],
1352
+ num_heads=[2, 4, 8, 16],
1353
+ window_size=[8, 8, [7, 7], 7],
1354
+ dim=128,
1355
+ in_dim=64,
1356
+ mlp_ratio=4,
1357
+ drop_path_rate=0.2,
1358
+ sr_ratio=[1, 1, [2, 1], 1],
1359
+ use_swiglu=False,
1360
+ yolo_arch=True,
1361
+ shuffle_down=False,
1362
+ conv_base=True,
1363
+ downsample_shuffle=False,
1364
+ **kwargs)
1365
+ if pretrained:
1366
+ model.load_state_dict(torch.load(pretrained)["state_dict"])
1367
+ return model
 
 
 
 
 
 
 
 
 
1368
 
1369
+ @register_model
1370
+ def fastervit2_base_fullres1(pretrained=False, **kwargs):
1371
+ model = FasterViT(depths=[3, 3, 5, 5],
1372
+ num_heads=[2, 4, 8, 16],
1373
+ window_size=[8, 8, [7, 7], 7],
1374
+ dim=128,
1375
+ in_dim=64,
1376
+ mlp_ratio=4,
1377
+ drop_path_rate=0.2,
1378
+ sr_ratio=[1, 1, [2, 1], 1],
1379
+ use_swiglu=False,
1380
+ yolo_arch=True,
1381
+ shuffle_down=False,
1382
+ conv_base=True,
1383
+ use_neck=True,
1384
+ full_features_head_dim=1024,
1385
+ neck_start_stage=2,
1386
+ **kwargs)
1387
+ if pretrained:
1388
+ model.load_state_dict(torch.load(pretrained)["state_dict"])
1389
+ return model
1390
 
1391
+ @register_model
1392
+ def fastervit2_base_fullres2(pretrained=False, **kwargs):
1393
+ model = FasterViT(depths=[3, 3, 5, 5],
1394
+ num_heads=[2, 4, 8, 16],
1395
+ window_size=[8, 8, [7, 7], 7],
1396
+ dim=128,
1397
+ in_dim=64,
1398
+ mlp_ratio=4,
1399
+ drop_path_rate=0.2,
1400
+ sr_ratio=[1, 1, [2, 1], 1],
1401
+ use_swiglu=False,
1402
+ yolo_arch=True,
1403
+ shuffle_down=False,
1404
+ conv_base=True,
1405
+ use_neck=True,
1406
+ full_features_head_dim=512,
1407
+ neck_start_stage=1,
1408
+ **kwargs)
1409
+ if pretrained:
1410
+ model.load_state_dict(torch.load(pretrained)["state_dict"])
1411
+ return model
1412
 
1413
+ @register_model
1414
+ def fastervit2_base_fullres3(pretrained=False, **kwargs):
1415
+ model = FasterViT(depths=[3, 3, 5, 5],
1416
+ num_heads=[2, 4, 8, 16],
1417
+ window_size=[8, 8, [7, 7], 7],
1418
+ dim=128,
1419
+ in_dim=64,
1420
+ mlp_ratio=4,
1421
+ drop_path_rate=0.2,
1422
+ sr_ratio=[1, 1, [2, 1], 1],
1423
+ use_swiglu=False,
1424
+ yolo_arch=True,
1425
+ shuffle_down=False,
1426
+ conv_base=True,
1427
+ use_neck=True,
1428
+ full_features_head_dim=256,
1429
+ neck_start_stage=1,
1430
+ **kwargs)
1431
+ if pretrained:
1432
+ model.load_state_dict(torch.load(pretrained)["state_dict"])
1433
+ return model
1434
 
1435
+ @register_model
1436
+ def fastervit2_base_fullres4(pretrained=False, **kwargs):
1437
+ model = FasterViT(depths=[3, 3, 5, 5],
1438
+ num_heads=[2, 4, 8, 16],
1439
+ window_size=[8, 8, [7, 7], 7],
1440
+ dim=128,
1441
+ in_dim=64,
1442
+ mlp_ratio=4,
1443
+ drop_path_rate=0.2,
1444
+ sr_ratio=[1, 1, [2, 1], 1],
1445
+ use_swiglu=False,
1446
+ yolo_arch=True,
1447
+ shuffle_down=False,
1448
+ conv_base=True,
1449
+ use_neck=True,
1450
+ full_features_head_dim=256,
1451
+ neck_start_stage=2,
1452
+ **kwargs)
1453
+ if pretrained:
1454
+ model.load_state_dict(torch.load(pretrained)["state_dict"])
1455
+ return model
1456
 
1457
+ @register_model
1458
+ def fastervit2_base_fullres5(pretrained=False, **kwargs):
1459
+ model = FasterViT(depths=[3, 3, 5, 5],
1460
+ num_heads=[2, 4, 8, 16],
1461
+ window_size=[8, 8, [7, 7], 7],
1462
+ dim=128,
1463
+ in_dim=64,
1464
+ mlp_ratio=4,
1465
+ drop_path_rate=0.2,
1466
+ sr_ratio=[1, 1, [2, 1], 1],
1467
+ use_swiglu=False,
1468
+ yolo_arch=True,
1469
+ shuffle_down=False,
1470
+ conv_base=True,
1471
+ use_neck=True,
1472
+ full_features_head_dim=512,
1473
+ neck_start_stage=2,
1474
+ **kwargs)
1475
+ if pretrained:
1476
+ model.load_state_dict(torch.load(pretrained)["state_dict"])
1477
+ return model
1478
 
1479
+ #84.87
1480
+ @register_model
1481
+ def fastervit2_large(pretrained=False, **kwargs):
1482
+ model = FasterViT(depths=[3, 3, 5, 5],
1483
+ num_heads=[2, 4, 8, 16],
1484
+ window_size=[8, 8, [7, 7], 7],
1485
+ dim=128+64,
1486
+ in_dim=64,
1487
+ mlp_ratio=4,
1488
+ drop_path_rate=0.3,
1489
+ sr_ratio=[1, 1, [2, 1], 1],
1490
+ use_swiglu=False,
1491
+ yolo_arch=False,
1492
+ shuffle_down=False,
1493
+ cpb_mlp_hidden=64,
1494
+ conv_base=True,
1495
+ **kwargs)
1496
+ if pretrained:
1497
+ model.load_state_dict(torch.load(pretrained)["state_dict"])
1498
+ return model
1499
 
1500
+ @register_model
1501
+ def fastervit2_large_fullres(pretrained=False, **kwargs):
1502
+ model = FasterViT(
1503
+ depths=[3, 3, 5, 5],
1504
+ num_heads=[2, 4, 8, 16],
1505
+ window_size=[None, None, [7, 7], 7],
1506
+ dim=192,
1507
+ in_dim=64,
1508
+ mlp_ratio=4,
1509
+ drop_path_rate=0.0,
1510
+ sr_ratio=[1, 1, [2, 1], 1],
1511
+ use_swiglu=False,
1512
+ yolo_arch=True,
1513
+ shuffle_down=False,
1514
+ conv_base=True,
1515
+ use_neck=True,
1516
+ full_features_head_dim=1536,
1517
+ neck_start_stage=2,
1518
+ **kwargs,
1519
+ )
1520
+ if pretrained:
1521
+ model.load_state_dict(torch.load(pretrained)["state_dict"])
1522
+ return model
1523
 
 
 
 
 
 
 
 
 
 
 
1524
 
1525
  @register_model
1526
  def fastervit2_large_fullres_ws8(pretrained=False, **kwargs):
 
1543
  **kwargs,
1544
  )
1545
  if pretrained:
1546
+ model.load_state_dict(torch.load(pretrained)["state_dict"])
1547
  return model
1548
 
1549
 
 
1568
  **kwargs,
1569
  )
1570
  if pretrained:
1571
+ model.load_state_dict(torch.load(pretrained)["state_dict"])
1572
  return model
1573
 
1574
 
 
1593
  **kwargs,
1594
  )
1595
  if pretrained:
1596
+ model.load_state_dict(torch.load(pretrained)["state_dict"])
1597
  return model
1598
 
1599
+ #85.23% top1
1600
+ @register_model
1601
+ def fastervit2_xlarge(pretrained=False, **kwargs):
1602
+ model = FasterViT(depths=[3, 3, 5, 5],
1603
+ num_heads=[2, 4, 8, 16],
1604
+ window_size=[8, 8, [7, 7], 7],
1605
+ dim=128+128+64,
1606
+ in_dim=64,
1607
+ mlp_ratio=4,
1608
+ drop_path_rate=0.4,
1609
+ sr_ratio=[1, 1, [2, 1], 1],
1610
+ use_swiglu=False,
1611
+ yolo_arch=False,
1612
+ shuffle_down=False,
1613
+ cpb_mlp_hidden=64,
1614
+ **kwargs)
1615
+ if pretrained:
1616
+ model.load_state_dict(torch.load(pretrained)["state_dict"])
1617
+ return model
1618
 
1619
  @register_model
1620
+ def fastervit2_huge(pretrained=False, **kwargs):
1621
+ model = FasterViT(depths=[3, 3, 5, 5],
1622
+ num_heads=[2, 4, 8, 16],
1623
+ window_size=[8, 8, [7, 7], 7],
1624
+ dim=128+128+128+64,
1625
+ in_dim=64,
1626
+ mlp_ratio=4,
1627
+ drop_path_rate=0.2,
1628
+ sr_ratio=[1, 1, [2, 1], 1],
1629
+ use_swiglu=False,
1630
+ yolo_arch=True,
1631
+ shuffle_down=False,
1632
+ **kwargs)
1633
+ if pretrained:
1634
+ model.load_state_dict(torch.load(pretrained)["state_dict"])
1635
+ return model
1636
+
1637
+
1638
+ # 81.61
1639
+ @register_model
1640
+ def fastervit2_xtiny(pretrained=False, **kwargs): #,
1641
+ model = FasterViT(depths=[1, 3, 4, 5],
1642
+ num_heads=[2, 4, 8, 16],
1643
+ window_size=[8, 8, [7, 7], 7],
1644
+ dim=64,
1645
+ in_dim=64,
1646
+ mlp_ratio=4,
1647
+ drop_path_rate=0.1,
1648
+ sr_ratio=[1, 1, [2, 1], 1],
1649
+ use_swiglu=False,
1650
+ downsample_shuffle=False,
1651
+ yolo_arch=True,
1652
+ shuffle_down=False,
1653
+ cpb_mlp_hidden=64,
1654
+ **kwargs)
1655
+ if pretrained:
1656
+ model.load_state_dict(torch.load(pretrained)["state_dict"])
1657
+ return model
1658
 
 
 
 
 
1659
 
1660
+ # 80.19
1661
+ @register_model
1662
+ def fastervit2_xxtiny(pretrained=False, **kwargs): #,
1663
+ model = FasterViT(depths=[1, 3, 4, 5],
1664
+ num_heads=[2, 4, 8, 16],
1665
+ window_size=[8, 8, [7, 7], 7],
1666
+ dim=48,
1667
+ in_dim=64,
1668
+ mlp_ratio=4,
1669
+ drop_path_rate=0.05,
1670
+ sr_ratio=[1, 1, [2, 1], 1],
1671
+ use_swiglu=False,
1672
+ downsample_shuffle=False,
1673
+ yolo_arch=True,
1674
+ shuffle_down=False,
1675
+ cpb_mlp_hidden=64,
1676
+ **kwargs)
1677
+ if pretrained:
1678
+ model.load_state_dict(torch.load(pretrained)["state_dict"])
1679
+ return model
1680
+
1681
+ @register_model
1682
+ # 77.0
1683
+ def fastervit2_xxxtiny(pretrained=False, **kwargs): #,
1684
+ model = FasterViT(depths=[1, 3, 4, 5],
1685
+ num_heads=[2, 4, 8, 16],
1686
+ window_size=[8, 8, [7, 7], 7],
1687
+ dim=32,
1688
+ in_dim=32,
1689
+ mlp_ratio=4,
1690
+ drop_path_rate=0.0,
1691
+ sr_ratio=[1, 1, [2, 1], 1],
1692
+ use_swiglu=False,
1693
+ downsample_shuffle=False,
1694
+ yolo_arch=True,
1695
+ shuffle_down=False,
1696
+ cpb_mlp_hidden=64,
1697
+ **kwargs)
1698
+ if pretrained:
1699
+ model.load_state_dict(torch.load(pretrained)["state_dict"])
1700
+ return model
1701
+
1702
+
1703
+ @register_model
1704
+ def fastervit2_xxxtiny_fullres(pretrained=False, **kwargs):
1705
+ model = FasterViT(depths=[1, 3, 4, 5],
1706
+ num_heads=[2, 4, 8, 16],
1707
+ window_size=[8, 8, [7, 7], 7],
1708
+ dim=32,
1709
+ in_dim=32,
1710
+ mlp_ratio=4,
1711
+ drop_path_rate=0.0,
1712
+ sr_ratio=[1, 1, [2, 1], 1],
1713
+ use_swiglu=False,
1714
+ downsample_shuffle=False,
1715
+ yolo_arch=True,
1716
+ shuffle_down=False,
1717
+ cpb_mlp_hidden=64,
1718
+ use_neck=True,
1719
+ full_features_head_dim=128,
1720
+ neck_start_stage=1,
1721
+ conv_groups_ratio = 1,
1722
+ **kwargs)
1723
+ if pretrained:
1724
+ model.load_state_dict(torch.load(pretrained)["state_dict"])
1725
+ return model
1726
+
1727
+ @register_model
1728
+ def eradio_xxxtiny(pretrained=False, **kwargs): # ,
1729
+ model = FasterViT(
1730
+ depths=[1, 3, 4, 5],
1731
+ num_heads=[2, 4, 8, 16],
1732
+ window_size=[None, None, [16, 16], 16],
1733
+ dim=32,
1734
+ in_dim=32,
1735
+ mlp_ratio=4,
1736
+ drop_path_rate=0.0,
1737
+ sr_ratio=[1, 1, [2, 1], 1],
1738
+ use_swiglu=False,
1739
+ yolo_arch=True,
1740
+ shuffle_down=False,
1741
+ conv_base=True,
1742
+ use_neck=True,
1743
+ full_features_head_dim=256,
1744
+ neck_start_stage=2,
1745
+ **kwargs,
1746
+ )
1747
+ if pretrained:
1748
+ model.load_state_dict(torch.load(pretrained))
1749
+ return model
1750
 
1751
+ @register_model
1752
+ def eradio_xxxtiny_8x_ws12(pretrained=False, **kwargs):
1753
+ model = FasterViT(depths=[1, 3, 4, 5],
1754
+ num_heads=[2, 4, 8, 16],
1755
+ window_size=[None, None, [12, 12], 12],
1756
+ dim=32,
1757
+ in_dim=32,
1758
+ mlp_ratio=4,
1759
+ drop_path_rate=0.0,
1760
+ sr_ratio=[1, 1, [2, 1], 1],
1761
+ use_swiglu=False,
1762
+ downsample_shuffle=False,
1763
+ yolo_arch=True,
1764
+ shuffle_down=False,
1765
+ cpb_mlp_hidden=64,
1766
+ use_neck=True,
1767
+ full_features_head_dim=256,
1768
+ neck_start_stage=2,
1769
+ conv_groups_ratio = 1,
1770
+ **kwargs)
1771
+ if pretrained:
1772
+ model.load_state_dict(torch.load(pretrained)["state_dict"])
1773
+ return model
1774
 
 
 
1775
 
1776
+ @register_model
1777
+ def eradio_xxxtiny_8x_ws16(pretrained=False, **kwargs):
1778
+ model = FasterViT(depths=[1, 3, 4, 5],
1779
+ num_heads=[2, 4, 8, 16],
1780
+ window_size=[None, None, [16, 16], 16],
1781
+ dim=32,
1782
+ in_dim=32,
1783
+ mlp_ratio=4,
1784
+ drop_path_rate=0.0,
1785
+ sr_ratio=[1, 1, [2, 1], 1],
1786
+ use_swiglu=False,
1787
+ downsample_shuffle=False,
1788
+ yolo_arch=True,
1789
+ shuffle_down=False,
1790
+ cpb_mlp_hidden=64,
1791
+ use_neck=True,
1792
+ full_features_head_dim=256,
1793
+ neck_start_stage=1,
1794
+ conv_groups_ratio = 1,
1795
+ **kwargs)
1796
+ if pretrained:
1797
+ model.load_state_dict(torch.load(pretrained)["state_dict"])
1798
+ return model
1799
 
1800
+ @register_model
1801
+ def eradio(pretrained=False, **kwargs):
1802
+ return fastervit2_large_fullres_ws16(pretrained=pretrained, **kwargs)
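For orientation, a hedged end-to-end sketch of using the constructors registered in this file (assuming eradio_model is importable; the arguments and input size are illustrative, and with return_full_features=True the forward returns both the pooled summary and the dense features, per the forward above):

import torch
from eradio_model import eradio  # assumes this module is on the path

model = eradio(pretrained=False, num_classes=0, return_full_features=True)
model.eval()
model.switch_to_deploy()                    # fold BN and precompute the positional bias
model.set_optimal_window_size((512, 512))   # largest divisor of 512 // 32 that is <= 16

with torch.no_grad():
    summary, full_features = model(torch.randn(1, 3, 512, 512))
print(summary.shape, full_features.shape)   # pooled embedding and dense feature map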
hf_model.py CHANGED
@@ -12,35 +12,56 @@
12
  # See the License for the specific language governing permissions and
13
  # limitations under the License.
14
  from collections import namedtuple
15
- from typing import Optional
16
 
17
  from timm.models import VisionTransformer
18
  import torch
19
  from transformers import PretrainedConfig, PreTrainedModel
20
 
21
 
 
 
 
22
  from .eradio_model import eradio
23
  from .radio_model import create_model_from_args
24
- from .radio_model import RADIOModel as RADIOModelBase
25
  from .input_conditioner import get_default_conditioner, InputConditioner
 
 
26
  # Register extra models
27
  from .extra_timm_models import *
28
 
 
29
  class RADIOConfig(PretrainedConfig):
30
  """Pretrained Hugging Face configuration for RADIO models."""
31
 
32
  def __init__(
33
  self,
34
  args: Optional[dict] = None,
35
- version: Optional[str] = "v1",
36
- return_summary: Optional[bool] = True,
37
- return_spatial_features: Optional[bool] = True,
 
 
 
38
  **kwargs,
39
  ):
40
  self.args = args
 
 
 
 
 
 
41
  self.version = version
42
- self.return_summary = return_summary
43
- self.return_spatial_features = return_spatial_features
 
 
 
 
 
 
44
  super().__init__(**kwargs)
45
 
46
 
@@ -59,14 +80,39 @@ class RADIOModel(PreTrainedModel):
59
  RADIOArgs = namedtuple("RADIOArgs", config.args.keys())
60
  args = RADIOArgs(**config.args)
61
  self.config = config
 
62
  model = create_model_from_args(args)
63
  input_conditioner: InputConditioner = get_default_conditioner()
64
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
65
  self.radio_model = RADIOModelBase(
66
  model,
67
  input_conditioner,
68
- config.return_summary,
69
- config.return_spatial_features,
 
 
 
 
70
  )
71
 
72
  @property
@@ -79,62 +125,3 @@ class RADIOModel(PreTrainedModel):
79
 
80
  def forward(self, x: torch.Tensor):
81
  return self.radio_model.forward(x)
82
-
83
-
84
- class ERADIOConfig(PretrainedConfig):
85
- """Pretrained Hugging Face configuration for ERADIO models."""
86
-
87
- def __init__(
88
- self,
89
- args: Optional[dict] = None,
90
- version: Optional[str] = "v1",
91
- return_summary: Optional[bool] = True,
92
- return_spatial_features: Optional[bool] = True,
93
- **kwargs,
94
- ):
95
- self.args = args
96
- self.version = version
97
- self.return_summary = return_summary
98
- self.return_spatial_features = return_spatial_features
99
- super().__init__(**kwargs)
100
-
101
-
102
- class ERADIOModel(PreTrainedModel):
103
- """Pretrained Hugging Face model for ERADIO.
104
-
105
- This class inherits from PreTrainedModel, which provides
106
- HuggingFace's functionality for loading and saving models.
107
- """
108
-
109
- config_class = ERADIOConfig
110
-
111
- def __init__(self, config):
112
- super().__init__(config)
113
-
114
- config.args["in_chans"] = 3
115
- config.args["num_classes"] = 0
116
- config.args["return_full_features"] = config.return_spatial_features
117
-
118
- self.config = config
119
- model = eradio(**config.args)
120
- self.input_conditioner: InputConditioner = get_default_conditioner()
121
- self.return_summary = config.return_summary
122
- self.return_spatial_features = config.return_spatial_features
123
- self.model = model
124
-
125
- def forward(self, x: torch.Tensor):
126
- x = self.input_conditioner(x)
127
- y = self.model.forward_features(x)
128
- summary, features = self.model.forward_features(x)
129
-
130
- if isinstance(y, tuple):
131
- summary, features = y
132
- else:
133
- summary = y
134
- features = None
135
-
136
- if self.return_summary and self.return_spatial_features:
137
- return summary, features
138
- elif self.return_summary:
139
- return summary
140
- return features
 
12
  # See the License for the specific language governing permissions and
13
  # limitations under the License.
14
  from collections import namedtuple
15
+ from typing import Optional, List, Union
16
 
17
  from timm.models import VisionTransformer
18
  import torch
19
  from transformers import PretrainedConfig, PreTrainedModel
20
 
21
 
22
+ from .common import RESOURCE_MAP, DEFAULT_VERSION
23
+
24
+ # Force import of eradio_model in order to register it.
25
  from .eradio_model import eradio
26
  from .radio_model import create_model_from_args
27
+ from .radio_model import RADIOModel as RADIOModelBase, Resolution
28
  from .input_conditioner import get_default_conditioner, InputConditioner
29
+
30
+
31
  # Register extra models
32
  from .extra_timm_models import *
33
 
34
+
35
  class RADIOConfig(PretrainedConfig):
36
  """Pretrained Hugging Face configuration for RADIO models."""
37
 
38
  def __init__(
39
  self,
40
  args: Optional[dict] = None,
41
+ version: Optional[str] = DEFAULT_VERSION,
42
+ patch_size: Optional[int] = None,
43
+ max_resolution: Optional[int] = None,
44
+ preferred_resolution: Optional[Resolution] = None,
45
+ adaptor_names: Union[str, List[str]] = None,
46
+ vitdet_window_size: Optional[int] = None,
47
  **kwargs,
48
  ):
49
  self.args = args
50
+ for field in ["dtype", "amp_dtype"]:
51
+ if self.args is not None and field in self.args:
52
+ # Convert to a string in order to make it serializable.
53
+ # For example for torch.float32 we will store "float32",
54
+ # for "bfloat16" we will store "bfloat16".
55
+ self.args[field] = str(args[field]).split(".")[-1]
56
  self.version = version
57
+ resource = RESOURCE_MAP[version]
58
+ self.patch_size = patch_size or resource.patch_size
59
+ self.max_resolution = max_resolution or resource.max_resolution
60
+ self.preferred_resolution = (
61
+ preferred_resolution or resource.preferred_resolution
62
+ )
63
+ self.adaptor_names = adaptor_names
64
+ self.vitdet_window_size = vitdet_window_size
65
  super().__init__(**kwargs)
66
 
67
 
 
80
  RADIOArgs = namedtuple("RADIOArgs", config.args.keys())
81
  args = RADIOArgs(**config.args)
82
  self.config = config
83
+
84
  model = create_model_from_args(args)
85
  input_conditioner: InputConditioner = get_default_conditioner()
86
 
87
+ dtype = getattr(args, "dtype", torch.float32)
88
+ if isinstance(dtype, str):
89
+ # Convert the dtype's string representation back to a dtype.
90
+ dtype = getattr(torch, dtype)
91
+ model.to(dtype=dtype)
92
+ input_conditioner.dtype = dtype
93
+
94
+ summary_idxs = torch.tensor(
95
+ [i for i, t in enumerate(args.teachers) if t.get("use_summary", True)],
96
+ dtype=torch.int64,
97
+ )
98
+
99
+ adaptor_names = config.adaptor_names
100
+ if adaptor_names is not None:
101
+ raise NotImplementedError(
102
+ f"Adaptors are not yet supported in Hugging Face models. Adaptor names: {adaptor_names}"
103
+ )
104
+
105
+ adaptors = dict()
106
+
107
  self.radio_model = RADIOModelBase(
108
  model,
109
  input_conditioner,
110
+ summary_idxs=summary_idxs,
111
+ patch_size=config.patch_size,
112
+ max_resolution=config.max_resolution,
113
+ window_size=config.vitdet_window_size,
114
+ preferred_resolution=config.preferred_resolution,
115
+ adaptors=adaptors,
116
  )
117
 
118
  @property
 
125
 
126
  def forward(self, x: torch.Tensor):
127
  return self.radio_model.forward(x)
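A hedged usage sketch for this wrapper via transformers (the repo id and the expectation of inputs in [0, 1] are assumptions; the model card is authoritative):

import torch
from transformers import AutoModel

# trust_remote_code pulls in hf_model.py / radio_model.py from the hub repo
model = AutoModel.from_pretrained("nvidia/RADIO", trust_remote_code=True)  # repo id assumed
model.eval()

res = model.radio_model.preferred_resolution        # Resolution(height, width)
x = torch.rand(1, 3, res.height, res.width)          # values in [0, 1]; conditioning happens inside
with torch.no_grad():
    summary, spatial_features = model(x)              # RadioOutput unpacks into (summary, features)
print(summary.shape, spatial_features.shape)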
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
input_conditioner.py CHANGED
@@ -19,20 +19,20 @@ class InputConditioner(nn.Module):
19
  input_scale: float,
20
  norm_mean: norm_t,
21
  norm_std: norm_t,
22
- dtype: torch.dtype = torch.float32,
23
  ):
24
  super().__init__()
25
 
26
  self.dtype = dtype
27
 
28
- # self.input_scale = input_scale
29
  self.register_buffer("norm_mean", _to_tensor(norm_mean) / input_scale)
30
  self.register_buffer("norm_std", _to_tensor(norm_std) / input_scale)
31
 
32
  def forward(self, x: torch.Tensor):
33
- # x = x * self.input_scale
34
  y = (x - self.norm_mean) / self.norm_std
35
- return y.to(self.dtype)
 
 
36
 
37
 
38
  def get_default_conditioner():
 
19
  input_scale: float,
20
  norm_mean: norm_t,
21
  norm_std: norm_t,
22
+ dtype: torch.dtype = None,
23
  ):
24
  super().__init__()
25
 
26
  self.dtype = dtype
27
 
 
28
  self.register_buffer("norm_mean", _to_tensor(norm_mean) / input_scale)
29
  self.register_buffer("norm_std", _to_tensor(norm_std) / input_scale)
30
 
31
  def forward(self, x: torch.Tensor):
 
32
  y = (x - self.norm_mean) / self.norm_std
33
+ if self.dtype is not None:
34
+ y = y.to(self.dtype)
35
+ return y
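A small numeric check of the scaling trick used by this conditioner: because norm_mean and norm_std are divided by input_scale at construction time, (x - mean/s) / (std/s) equals (s*x - mean) / std, so an image kept in [0, 1] is normalized exactly as if it had first been scaled up by input_scale (the ImageNet statistics below are illustrative):

import torch

mean = torch.tensor([123.675, 116.28, 103.53]).view(1, 3, 1, 1)  # ImageNet mean in 0-255 units
std = torch.tensor([58.395, 57.12, 57.375]).view(1, 3, 1, 1)
scale = 255.0

x = torch.rand(1, 3, 4, 4)  # image already in [0, 1]

folded = (x - mean / scale) / (std / scale)   # what InputConditioner computes
explicit = (x * scale - mean) / std           # scale first, then normalize

assert torch.allclose(folded, explicit, atol=1e-5)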
36
 
37
 
38
  def get_default_conditioner():
model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:9079d79a8948849416e84a25d9318e020e719dbe6f8c16a13d674f8e1f5e6b88
3
+ size 1614710336
radio_model.py CHANGED
@@ -5,7 +5,7 @@
5
  # and any modifications thereto. Any use, reproduction, disclosure or
6
  # distribution of this software and related documentation without an express
7
  # license agreement from NVIDIA CORPORATION is strictly prohibited.
8
- from typing import Optional
9
 
10
  import torch
11
  from torch import nn
@@ -16,6 +16,13 @@ from .enable_cpe_support import enable_cpe
16
  from .input_conditioner import InputConditioner
17
  # Register extra models
18
  from . import extra_timm_models
 
 
 
 
 
 
 
19
 
20
 
21
  class RADIOModel(nn.Module):
@@ -23,28 +30,32 @@ class RADIOModel(nn.Module):
23
  self,
24
  model: nn.Module,
25
  input_conditioner: InputConditioner,
26
- return_summary: bool,
27
- return_spatial_features: bool,
 
28
  summary_idxs: Optional[torch.Tensor] = None,
 
 
29
  ):
30
  super().__init__()
31
 
32
  self.model = model
33
  self.input_conditioner = input_conditioner
34
- self.return_summary = return_summary
35
- self.return_spatial_features = return_spatial_features
36
- self.summary_select_idx = -1
37
  if summary_idxs is not None:
38
  self.register_buffer('summary_idxs', summary_idxs)
39
  else:
40
  self.summary_idxs = None
41
 
42
- @property
43
- def return_both(self):
44
- return self.return_summary and self.return_spatial_features
 
 
 
 
45
 
46
  @property
47
- def num_summary_tokens(self):
48
  patch_gen = getattr(self.model, "patch_generator", None)
49
  if patch_gen is not None:
50
  return patch_gen.num_skip
@@ -52,38 +63,94 @@ class RADIOModel(nn.Module):
52
  return 0
53
  return 1
54
 
55
- def forward(self, x: torch.Tensor):
56
- x = self.input_conditioner(x)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
57
 
 
 
 
 
 
 
 
 
 
58
  y = self.model.forward_features(x)
59
 
60
- if isinstance(y, (list, tuple)):
61
- summary, all_feat = y
62
- elif isinstance(self.model, VisionTransformer):
63
  patch_gen = getattr(self.model, "patch_generator", None)
64
  if patch_gen is not None:
65
- summary = y[:, : patch_gen.num_cls_tokens]
66
- if self.summary_select_idx >= 0:
67
- summary = summary[:, self.summary_select_idx]
68
- elif self.summary_idxs is not None:
69
- summary = summary[:, self.summary_idxs].flatten(1)
70
  else:
71
- summary = summary.flatten(1)
72
  all_feat = y[:, patch_gen.num_skip :]
73
  elif self.model.global_pool == "avg":
74
- summary = y[:, self.model.num_prefix_tokens :].mean(dim=1)
 
75
  all_feat = y
76
  else:
77
- summary = y[:, 0]
 
78
  all_feat = y[:, 1:]
 
 
 
 
 
 
 
 
79
  else:
80
  raise ValueError("Unsupported model type")
81
 
82
- if self.return_both:
83
- return summary, all_feat
84
- elif self.return_summary:
85
- return summary
86
- return all_feat
 
 
 
 
 
 
 
 
 
87
 
88
 
89
  def create_model_from_args(args) -> nn.Module:
 
5
  # and any modifications thereto. Any use, reproduction, disclosure or
6
  # distribution of this software and related documentation without an express
7
  # license agreement from NVIDIA CORPORATION is strictly prohibited.
8
+ from typing import Optional, Callable, Union, Tuple, Any, Dict, NamedTuple
9
 
10
  import torch
11
  from torch import nn
 
16
  from .input_conditioner import InputConditioner
17
  # Register extra models
18
  from . import extra_timm_models
19
+ from .adaptor_base import AdaptorBase, RadioOutput, AdaptorInput
20
+ from . import eradio_model
21
+
22
+
23
+ class Resolution(NamedTuple):
24
+ height: int
25
+ width: int
26
 
27
 
28
  class RADIOModel(nn.Module):
 
30
  self,
31
  model: nn.Module,
32
  input_conditioner: InputConditioner,
33
+ patch_size: int,
34
+ max_resolution: int,
35
+ preferred_resolution: Resolution,
36
  summary_idxs: Optional[torch.Tensor] = None,
37
+ window_size: int = None,
38
+ adaptors: Dict[str, AdaptorBase] = None,
39
  ):
40
  super().__init__()
41
 
42
  self.model = model
43
  self.input_conditioner = input_conditioner
 
 
 
44
  if summary_idxs is not None:
45
  self.register_buffer('summary_idxs', summary_idxs)
46
  else:
47
  self.summary_idxs = None
48
 
49
+ self._preferred_resolution = preferred_resolution
50
+ self._patch_size = patch_size
51
+ self._max_resolution = max_resolution
52
+ self._window_size = window_size
53
+
54
+ adaptors = adaptors or dict()
55
+ self.adaptors = nn.ModuleDict(adaptors)
56
 
57
  @property
58
+ def num_summary_tokens(self) -> int:
59
  patch_gen = getattr(self.model, "patch_generator", None)
60
  if patch_gen is not None:
61
  return patch_gen.num_skip
 
63
  return 0
64
  return 1
65
 
66
+ @property
67
+ def patch_size(self) -> int:
68
+ return self._patch_size
69
+
70
+ @property
71
+ def max_resolution(self) -> int:
72
+ return self._max_resolution
73
+
74
+ @property
75
+ def preferred_resolution(self) -> Resolution:
76
+ return self._preferred_resolution
77
+
78
+ @property
79
+ def window_size(self) -> int:
80
+ return self._window_size
81
+
82
+ @property
83
+ def min_resolution_step(self) -> int:
84
+ res = self.patch_size
85
+ if self.window_size is not None:
86
+ res *= self.window_size
87
+ return res
88
+
89
+ def make_preprocessor_external(self) -> Callable[[torch.Tensor], torch.Tensor]:
90
+ ret = self.input_conditioner
91
+ self.input_conditioner = nn.Identity()
92
+ return ret
93
+
94
+ def get_nearest_supported_resolution(self, height: int, width: int) -> Resolution:
95
+ height = int(round(height / self.min_resolution_step) * self.min_resolution_step)
96
+ width = int(round(width / self.min_resolution_step) * self.min_resolution_step)
97
+
98
+ height = max(height, self.min_resolution_step)
99
+ width = max(width, self.min_resolution_step)
100
 
101
+ return Resolution(height=height, width=width)
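A worked example of the rounding above in plain Python (assuming patch_size 16 and no ViTDet window, so min_resolution_step is 16):

def nearest_supported(height, width, step=16):
    # round each side to the nearest multiple of step, with a floor of one step
    h = max(int(round(height / step) * step), step)
    w = max(int(round(width / step) * step), step)
    return h, w

print(nearest_supported(500, 700))    # (496, 704)
print(nearest_supported(1080, 1920))  # (1088, 1920)
print(nearest_supported(7, 7))        # (16, 16)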
102
+
103
+ def switch_to_deploy(self):
104
+ fn = getattr(self.model, 'switch_to_deploy', None)
105
+ if fn is not None:
106
+ fn()
107
+
108
+ def forward(self, x: torch.Tensor) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
109
+ x = self.input_conditioner(x)
110
  y = self.model.forward_features(x)
111
 
112
+ if isinstance(self.model, VisionTransformer):
 
 
113
  patch_gen = getattr(self.model, "patch_generator", None)
114
  if patch_gen is not None:
115
+ all_summary = y[:, : patch_gen.num_cls_tokens]
116
+ if self.summary_idxs is not None:
117
+ bb_summary = all_summary[:, self.summary_idxs]
 
 
118
  else:
119
+ bb_summary = all_summary
120
  all_feat = y[:, patch_gen.num_skip :]
121
  elif self.model.global_pool == "avg":
122
+ all_summary = y[:, self.model.num_prefix_tokens :].mean(dim=1)
123
+ bb_summary = all_summary
124
  all_feat = y
125
  else:
126
+ all_summary = y[:, 0]
127
+ bb_summary = all_summary
128
  all_feat = y[:, 1:]
129
+ elif isinstance(self.model, eradio_model.FasterViT):
130
+ _, f = y
131
+ all_feat = f.flatten(2).transpose(1, 2)
132
+ all_summary = all_feat.mean(dim=1)
133
+ bb_summary = all_summary
134
+ elif isinstance(y, (list, tuple)):
135
+ all_summary, all_feat = y
136
+ bb_summary = all_summary
137
  else:
138
  raise ValueError("Unsupported model type")
139
 
140
+ all_feat = all_feat.float()
141
+ ret = RadioOutput(bb_summary.flatten(1), all_feat).to(torch.float32)
142
+ if self.adaptors:
143
+ ret = dict(backbone=ret)
144
+ for name, adaptor in self.adaptors.items():
145
+ if all_summary.ndim == 3:
146
+ summary = all_summary[:, adaptor.head_idx]
147
+ else:
148
+ summary = all_summary
149
+ ada_input = AdaptorInput(images=x, summary=summary.float(), features=all_feat)
150
+ v = adaptor(ada_input).to(torch.float32)
151
+ ret[name] = v
152
+
153
+ return ret
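A hedged sketch of consuming this output for dense tasks (assumes a ViT backbone whose token grid is input_size / patch_size; with adaptors configured, the return value is instead a dict keyed by 'backbone' and the adaptor names):

import torch

def dense_feature_map(radio, x: torch.Tensor) -> torch.Tensor:
    # radio: a RADIOModel instance; x: (B, 3, H, W) at a supported resolution
    output = radio(x)            # RadioOutput(summary, features)
    features = output.features   # (B, num_patches, D)
    B, _, H, W = x.shape
    h, w = H // radio.patch_size, W // radio.patch_size
    return features.permute(0, 2, 1).reshape(B, -1, h, w)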
154
 
155
 
156
  def create_model_from_args(args) -> nn.Module: