gheinrich committed on
Commit
f6d64da
1 Parent(s): cc0d2ab

Upload model

cls_token.py CHANGED
@@ -1,4 +1,4 @@
-# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
+# Copyright (c) 2023-2024, NVIDIA CORPORATION. All rights reserved.
 #
 # NVIDIA CORPORATION and its licensors retain all intellectual property
 # and proprietary rights in and to this software, related documentation
enable_cpe_support.py CHANGED
@@ -1,4 +1,4 @@
-# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
+# Copyright (c) 2023-2024, NVIDIA CORPORATION. All rights reserved.
 #
 # NVIDIA CORPORATION and its licensors retain all intellectual property
 # and proprietary rights in and to this software, related documentation
eradio_model.py CHANGED
@@ -1,6 +1,6 @@
 #!/usr/bin/env python3
 
-# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
 #
 # NVIDIA CORPORATION and its licensors retain all intellectual property
 # and proprietary rights in and to this software, related documentation
@@ -8,8 +8,12 @@
 # distribution of this software and related documentation without an express
 # license agreement from NVIDIA CORPORATION is strictly prohibited.
 
-# Created by Pavlo Molchanov, LPR - DL Efficiency Research team
-# based on Fastervit1 from LPR
 
 import torch
 import torch.nn as nn
@@ -18,15 +22,105 @@ from timm.models.registry import register_model
 from timm.models.layers import trunc_normal_, DropPath, LayerNorm2d
 import numpy as np
 import torch.nn.functional as F
-from .block import C2f
 
-TRT = False # should help for TRT
 
-import pickle
 
-global bias_indx
-bias_indx = -1
-DEBUG = False
 
 
 def pixel_unshuffle(data, factor=2):
@@ -63,6 +157,7 @@ def window_partition(x, window_size):
     else:
         pad_h = (window_size - H % window_size) % window_size
         pad_w = (window_size - W % window_size) % window_size
         if pad_h > 0 or pad_w > 0:
             x = F.pad(x, (0, pad_w, 0, pad_h, 0, 0, 0, 0))
         Hp, Wp = H + pad_h, W + pad_w
@@ -106,8 +201,6 @@ class Conv2d_BN(nn.Module):
106
 
107
  @torch.no_grad()
108
  def switch_to_deploy(self):
109
-
110
- # return 1
111
  if not isinstance(self.bn, nn.Identity):
112
  c, bn = self.conv, self.bn
113
  w = bn.weight / (bn.running_var + bn.eps) ** 0.5
@@ -149,25 +242,47 @@ def window_reverse(windows, window_size, H, W, pad_hw):
 
 
 class PosEmbMLPSwinv2D(nn.Module):
     def __init__(
-        self, window_size, pretrained_window_size, num_heads, seq_length, no_log=False
     ):
         super().__init__()
         self.window_size = window_size
         self.num_heads = num_heads
         # mlp to generate continuous relative position bias
         self.cpb_mlp = nn.Sequential(
-            nn.Linear(2, 512, bias=True),
             nn.ReLU(inplace=True),
-            nn.Linear(512, num_heads, bias=False),
         )
 
-        # get relative_coords_table
         relative_coords_h = torch.arange(
-            -(self.window_size[0] - 1), self.window_size[0], dtype=torch.float32
         )
         relative_coords_w = torch.arange(
-            -(self.window_size[1] - 1), self.window_size[1], dtype=torch.float32
         )
         relative_coords_table = (
             torch.stack(torch.meshgrid([relative_coords_h, relative_coords_w]))
@@ -190,8 +305,6 @@ class PosEmbMLPSwinv2D(nn.Module):
190
  / np.log2(8)
191
  )
192
 
193
- self.register_buffer("relative_coords_table", relative_coords_table)
194
-
195
  # get pair-wise relative position index for each token inside the window
196
  coords_h = torch.arange(self.window_size[0])
197
  coords_w = torch.arange(self.window_size[1])
@@ -207,15 +320,13 @@ class PosEmbMLPSwinv2D(nn.Module):
207
  relative_coords[:, :, 1] += self.window_size[1] - 1
208
  relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
209
  relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
210
- self.register_buffer("relative_position_index", relative_position_index)
211
 
212
- self.grid_exists = False
213
 
214
- self.deploy = False
 
 
215
 
216
- relative_bias = torch.zeros(1, num_heads, seq_length, seq_length)
217
- self.seq_length = seq_length
218
- self.register_buffer("relative_bias", relative_bias) # for EMA
219
 
220
  def switch_to_deploy(self):
221
  self.deploy = True
@@ -224,19 +335,25 @@
     def forward(self, input_tensor):
         # for efficiency, we want this forward to be folded into a single operation (sum)
        # if resolution stays the same, then we dont need to recompute MLP layers
-        #
-        # to dynamically adjust patch size over the step
-        # if not (input_tensor.shape[1:] == self.relative_bias.shape[1:]):
-        #     self.grid_exists = False
 
-        if self.training:
             self.grid_exists = False
 
         if self.deploy and self.grid_exists:
             input_tensor += self.relative_bias
             return input_tensor
 
-        if not self.grid_exists:
             self.grid_exists = True
 
             relative_position_bias_table = self.cpb_mlp(
@@ -279,142 +396,40 @@ class GRAAttentionBlock(nn.Module):
279
  conv_base=False,
280
  do_windowing=True,
281
  multi_query=False,
 
282
  ) -> None:
283
  super().__init__()
284
 
285
- dim = dim_in
286
- # conv_base = True
287
- SHUFFLE = True
288
- SHUFFLE = False
289
  self.do_windowing = do_windowing
290
 
291
  if do_windowing:
292
- if SHUFFLE:
293
- self.downsample_op = (
294
- torch.nn.PixelUnshuffle(subsample_ratio)
295
- if subsample_ratio > 1
296
- else torch.nn.Identity()
297
- )
298
- self.downsample_mixer = (
299
- nn.Conv2d(
300
- dim_in * (subsample_ratio * subsample_ratio),
301
- dim_in * (dim_ratio),
302
- kernel_size=1,
303
- stride=1,
304
- padding=0,
305
- bias=False,
306
- )
307
- if dim * dim_ratio != dim * subsample_ratio * subsample_ratio
308
- else torch.nn.Identity()
309
- )
310
- else:
311
- if conv_base:
312
- self.downsample_op = (
313
- nn.Conv2d(
314
- dim_in,
315
- dim_out,
316
- kernel_size=subsample_ratio,
317
- stride=subsample_ratio,
318
- )
319
- if subsample_ratio > 1
320
- else nn.Identity()
321
- )
322
  self.downsample_mixer = nn.Identity()
323
- else:
324
- self.downsample_op = (
325
- nn.AvgPool2d(
326
- kernel_size=subsample_ratio, stride=subsample_ratio
327
- )
328
- if subsample_ratio > 1
329
- else nn.Identity()
330
- )
331
- self.downsample_mixer = (
332
- Conv2d_BN(dim_in, dim_out, kernel_size=1, stride=1)
333
- if subsample_ratio > 1
334
- else nn.Identity()
335
- )
336
-
337
- if do_windowing:
338
- if SHUFFLE:
339
- self.upsample_mixer = (
340
- nn.Conv2d(
341
- dim_in * dim_ratio,
342
- dim_in * (subsample_ratio * subsample_ratio),
343
- kernel_size=1,
344
- stride=1,
345
- padding=0,
346
- bias=False,
347
- )
348
- if dim * dim_ratio != dim * subsample_ratio * subsample_ratio
349
- else torch.nn.Identity()
350
- )
351
- self.upsample_op = (
352
- torch.nn.PixelShuffle(subsample_ratio)
353
- if subsample_ratio > 1
354
- else torch.nn.Identity()
355
- )
356
- else:
357
- if conv_base:
358
  self.upsample_mixer = nn.Identity()
359
- self.upsample_op = (
360
- nn.ConvTranspose2d(
361
- dim_in,
362
- dim_out,
363
- kernel_size=subsample_ratio,
364
- stride=subsample_ratio,
365
- )
366
- if subsample_ratio > 1
367
- else nn.Identity()
368
- )
369
- else:
370
- self.upsample_mixer = (
371
- nn.Upsample(scale_factor=subsample_ratio, mode="nearest")
372
- if subsample_ratio > 1
373
- else nn.Identity()
374
- )
375
- self.upsample_op = (
376
- Conv2d_BN(
377
- dim_in,
378
- dim_out,
379
- kernel_size=1,
380
- stride=1,
381
- padding=0,
382
- bias=False,
383
- )
384
- if subsample_ratio > 1
385
- else nn.Identity()
386
- )
387
 
388
  self.window_size = window_size
389
 
390
  self.norm1 = norm_layer(dim_in)
391
- if DEBUG:
392
- print(
393
- f"GRAAttentionBlock: input_resolution: , window_size: {window_size}, dim_in: {dim_in}, dim_out: {dim_out}, num_heads: {num_heads}, drop_path: {drop_path}, qk_scale: {qk_scale}, qkv_bias: {qkv_bias}, layer_scale: {layer_scale}"
394
- )
395
 
396
  self.attn = WindowAttention(
397
  dim_in,
398
- num_heads=num_heads,
399
- qkv_bias=qkv_bias,
400
- qk_scale=qk_scale,
401
  resolution=window_size,
402
- seq_length=window_size ** 2,
403
- dim_out=dim_in,
404
- multi_query=multi_query,
405
- )
406
- if DEBUG:
407
- print(
408
- f"Attention: dim_in: {dim_in}, num_heads: {num_heads}, qkv_bias: {qkv_bias}, qk_scale: {qk_scale}, resolution: {window_size}, seq_length: {window_size**2}, dim_out: {dim_in}"
409
- )
410
- print(f"drop_path: {drop_path}, layer_scale: {layer_scale}")
411
 
412
- self.drop_path1 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
413
 
414
  use_layer_scale = layer_scale is not None and type(layer_scale) in [int, float]
415
- self.gamma1 = (
416
- nn.Parameter(layer_scale * torch.ones(dim_in)) if use_layer_scale else 1
417
- )
418
 
419
  ### mlp layer
420
  mlp_ratio = 4
@@ -422,26 +437,13 @@ class GRAAttentionBlock(nn.Module):
422
  mlp_hidden_dim = int(dim_in * mlp_ratio)
423
 
424
  activation = nn.GELU if not use_swiglu else SwiGLU
425
- mlp_hidden_dim = (
426
- int((4 * dim_in * 1 / 2) / 64) * 64 if use_swiglu else mlp_hidden_dim
427
- )
428
 
429
- self.mlp = Mlp(
430
- in_features=dim_in,
431
- hidden_features=mlp_hidden_dim,
432
- act_layer=activation,
433
- use_swiglu=use_swiglu,
434
- )
435
 
436
- self.gamma2 = (
437
- nn.Parameter(layer_scale * torch.ones(dim_in)) if layer_scale else 1
438
- )
439
- self.drop_path2 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
440
- if DEBUG:
441
- print(
442
- f"MLP layer: dim_in: {dim_in}, dim_out: {dim_in}, mlp_hidden_dim: {mlp_hidden_dim}"
443
- )
444
- print(f"drop_path: {drop_path}, layer_scale: {layer_scale}")
445
 
446
  def forward(self, x):
447
  skip_connection = x
@@ -514,6 +516,7 @@ class MultiResolutionAttention(nn.Module):
514
  use_swiglu=True,
515
  multi_query=False,
516
  conv_base=False,
 
517
  ) -> None:
518
  """
519
  Args:
@@ -552,6 +555,7 @@ class MultiResolutionAttention(nn.Module):
552
  do_windowing=do_windowing,
553
  multi_query=multi_query,
554
  conv_base=conv_base,
 
555
  ),
556
  )
557
 
@@ -594,16 +598,13 @@ class Mlp(nn.Module):
594
  )
595
  self.act = act_layer()
596
  self.fc2 = nn.Linear(hidden_features, out_features, bias=False)
597
- # self.drop = GaussianDropout(drop)
598
 
599
  def forward(self, x):
600
  x_size = x.size()
601
  x = x.view(-1, x_size[-1])
602
  x = self.fc1(x)
603
  x = self.act(x)
604
- # x = self.drop(x)
605
  x = self.fc2(x)
606
- # x = self.drop(x)
607
  x = x.view(x_size)
608
  return x
609
 
@@ -621,7 +622,7 @@ class Downsample(nn.Module):
621
  """
622
  Args:
623
  dim: feature size dimension.
624
- shuffle: idea with
625
  keep_dim: bool argument for maintaining the resolution.
626
  """
627
 
@@ -632,8 +633,6 @@ class Downsample(nn.Module):
632
  self.norm = lambda x: pixel_unshuffle(x, factor=2)
633
  self.reduction = Conv2d_BN(dim * 4, dim_out, 1, 1, 0, bias=False)
634
  else:
635
- # removed layer norm for better, in this formulation we are getting 10% better speed
636
- # LayerNorm for high resolution inputs will be a pain as it pools over the entire spatial dimension
637
  self.norm = nn.Identity()
638
  self.reduction = Conv2d_BN(dim, dim_out, 3, 2, 1, bias=False)
639
 
@@ -646,6 +645,8 @@ class Downsample(nn.Module):
646
  class PatchEmbed(nn.Module):
647
  """
648
  Patch embedding block
 
 
649
  """
650
 
651
  def __init__(self, in_chans=3, in_dim=64, dim=96, shuffle_down=False):
@@ -669,12 +670,6 @@ class PatchEmbed(nn.Module):
669
  )
670
  else:
671
  self.proj = lambda x: pixel_unshuffle(x, factor=4)
672
-
673
- # self.conv_down = nn.Sequential(Conv2d_BN(in_chans*16, in_dim, 3, 1, 1),
674
- # nn.SiLU(),
675
- # Conv2d_BN(in_dim, dim, 3, 1, 1),
676
- # nn.SiLU(),
677
- # )
678
  self.conv_down = nn.Sequential(
679
  Conv2d_BN(in_chans * 16, dim, 3, 1, 1), nn.ReLU(),
680
  )
@@ -689,33 +684,19 @@ class ConvBlock(nn.Module):
689
  """
690
  Convolutional block, used in first couple of stages
691
  Experimented with plan resnet-18 like modules, they are the best in terms of throughput
692
- Experimented with RepVGG, dont see significant improvement in accuracy
693
  Finally, YOLOv8 idea seem to work fine (resnet-18 like block with squeezed feature dimension, and feature concatendation at the end)
694
  """
695
-
696
- def __init__(
697
- self, dim, drop_path=0.0, layer_scale=None, kernel_size=3, rep_vgg=False
698
- ):
 
699
  super().__init__()
700
- self.rep_vgg = rep_vgg
701
- if not rep_vgg:
702
- self.conv1 = Conv2d_BN(
703
- dim, dim, kernel_size=kernel_size, stride=1, padding=1
704
- )
705
- self.act1 = nn.GELU()
706
- else:
707
- self.conv1 = RepVGGBlock(
708
- dim, dim, kernel_size=kernel_size, stride=1, padding=1, groups=1
709
- )
710
 
711
- if not rep_vgg:
712
- self.conv2 = Conv2d_BN(
713
- dim, dim, kernel_size=kernel_size, stride=1, padding=1
714
- )
715
- else:
716
- self.conv2 = RepVGGBlock(
717
- dim, dim, kernel_size=kernel_size, stride=1, padding=1, groups=1
718
- )
719
 
720
  self.layer_scale = layer_scale
721
  if layer_scale is not None and type(layer_scale) in [int, float]:
@@ -723,17 +704,15 @@ class ConvBlock(nn.Module):
723
  self.layer_scale = True
724
  else:
725
  self.layer_scale = False
726
- self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
727
 
728
  def forward(self, x):
729
  input = x
730
- if not self.rep_vgg:
731
- x = self.conv1(x)
732
- x = self.act1(x)
733
- x = self.conv2(x)
734
- else:
735
- x = self.conv1(x)
736
- x = self.conv2(x)
737
  if self.layer_scale:
738
  x = x * self.gamma.view(1, -1, 1, 1)
739
  x = input + self.drop_path(x)
@@ -743,9 +722,6 @@ class ConvBlock(nn.Module):
743
  class WindowAttention(nn.Module):
744
  # Windowed Attention from SwinV2
745
  # use a MLP trick to deal with various input image resolutions, then fold it to improve speed
746
- # tested multi-querry attention, but it is not as good as full attention:
747
- # look into palm: https://github.com/lucidrains/PaLM-pytorch/blob/main/palm_pytorch/palm_pytorch.py
748
- # single kv attention, mlp in parallel (didnt improve speed)
749
 
750
  def __init__(
751
  self,
@@ -757,6 +733,7 @@ class WindowAttention(nn.Module):
757
  seq_length=0,
758
  dim_out=None,
759
  multi_query=False,
 
760
  ):
761
  # taken from EdgeViT and tweaked with attention bias.
762
  super().__init__()
@@ -771,12 +748,7 @@ class WindowAttention(nn.Module):
771
 
772
  self.scale = qk_scale or head_dim ** -0.5
773
  if not multi_query:
774
- if TRT:
775
- self.q = nn.Linear(dim, dim, bias=qkv_bias)
776
- self.k = nn.Linear(dim, dim, bias=qkv_bias)
777
- self.v = nn.Linear(dim, dim, bias=qkv_bias)
778
- else:
779
- self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
780
  else:
781
  self.qkv = nn.Linear(dim, dim + 2 * self.head_dim, bias=qkv_bias)
782
 
@@ -787,6 +759,7 @@ class WindowAttention(nn.Module):
787
  pretrained_window_size=[resolution, resolution],
788
  num_heads=num_heads,
789
  seq_length=seq_length,
 
790
  )
791
 
792
  self.resolution = resolution
@@ -795,29 +768,12 @@ class WindowAttention(nn.Module):
795
  B, N, C = x.shape
796
 
797
  if not self.multi_query:
798
- if TRT:
799
- q = (
800
- self.q(x)
801
- .reshape(B, -1, self.num_heads, C // self.num_heads)
802
- .permute(0, 2, 1, 3)
803
- )
804
- k = (
805
- self.k(x)
806
- .reshape(B, -1, self.num_heads, C // self.num_heads)
807
- .permute(0, 2, 1, 3)
808
- )
809
- v = (
810
- self.v(x)
811
- .reshape(B, -1, self.num_heads, C // self.num_heads)
812
- .permute(0, 2, 1, 3)
813
- )
814
- else:
815
- qkv = (
816
  self.qkv(x)
817
  .reshape(B, -1, 3, self.num_heads, C // self.num_heads)
818
  .permute(2, 0, 3, 1, 4)
819
  )
820
- q, k, v = qkv[0], qkv[1], qkv[2]
821
  else:
822
  qkv = self.qkv(x)
823
  (q, k, v) = qkv.split(
@@ -864,10 +820,10 @@ class FasterViTLayer(nn.Module):
864
  sr_ratio=1,
865
  multi_query=False,
866
  use_swiglu=True,
867
- rep_vgg=False,
868
  yolo_arch=False,
869
  downsample_shuffle=False,
870
  conv_base=False,
 
871
  ):
872
  """
873
  Args:
@@ -899,12 +855,11 @@ class FasterViTLayer(nn.Module):
899
  drop_path=drop_path[i]
900
  if isinstance(drop_path, list)
901
  else drop_path,
902
- layer_scale=layer_scale_conv,
903
- rep_vgg=rep_vgg,
904
- )
905
  for i in range(depth)
906
  ]
907
  )
 
908
  else:
909
  self.blocks = C2f(dim, dim, n=depth, shortcut=True, e=0.5)
910
  self.yolo_arch = True
@@ -915,6 +870,7 @@ class FasterViTLayer(nn.Module):
915
  self.do_single_windowing = True
916
  if not isinstance(sr_ratio, list):
917
  sr_ratio = [sr_ratio]
 
918
  if any([sr != 1 for sr in sr_ratio]) or len(set(window_size)) > 1:
919
  self.do_single_windowing = False
920
  do_windowing = True
@@ -943,37 +899,119 @@
                     do_windowing=do_windowing,
                     multi_query=multi_query,
                     conv_base=conv_base,
                 )
             )
 
         self.transformer = not conv
 
         self.downsample = (
             None if not downsample else Downsample(dim=dim, shuffle=downsample_shuffle)
         )
 
     def forward(self, x):
         B, C, H, W = x.shape
 
         if self.transformer and self.do_single_windowing:
             H, W = x.shape[2], x.shape[3]
             x, pad_hw = window_partition(x, self.window_size)
 
-        if not self.yolo_arch:
-            for bn, blk in enumerate(self.blocks):
-                x = blk(x)
-        else:
-            x = self.blocks(x)
 
         if self.transformer and self.do_single_windowing:
             x = window_reverse(x, self.window_size, H, W, pad_hw)
 
         if self.downsample is None:
             return x, x
 
         return self.downsample(x), x  # changing to output pre downsampled features
 
 
 class FasterViT(nn.Module):
     """
     FasterViT
@@ -1001,7 +1039,6 @@ class FasterViT(nn.Module):
1001
  use_swiglu=False,
1002
  multi_query=False,
1003
  norm_layer=nn.LayerNorm,
1004
- rep_vgg=False,
1005
  drop_uniform=False,
1006
  yolo_arch=False,
1007
  shuffle_down=False,
@@ -1010,6 +1047,7 @@ class FasterViT(nn.Module):
1010
  full_features_head_dim=128,
1011
  neck_start_stage=1,
1012
  use_neck=False,
 
1013
  **kwargs,
1014
  ):
1015
  """
@@ -1074,48 +1112,32 @@ class FasterViT(nn.Module):
1074
  use_swiglu=use_swiglu,
1075
  multi_query=multi_query,
1076
  norm_layer=norm_layer,
1077
- rep_vgg=rep_vgg,
1078
  yolo_arch=yolo_arch,
1079
  downsample_shuffle=downsample_shuffle,
1080
  conv_base=conv_base,
 
 
1081
  )
1082
 
1083
  self.levels.append(level)
1084
 
1085
- if self.return_full_features or self.use_neck:
1086
- # create feature projection layers for segmentation output
1087
- self.neck_features_proj = nn.ModuleList()
1088
- self.neck_start_stage = neck_start_stage
1089
- upsample_ratio = 1
1090
- for i in range(len(depths)):
1091
- level_n_features_output = int(dim * 2 ** i)
1092
-
1093
- if self.neck_start_stage > i:
1094
- continue
1095
-
1096
- if (
1097
- upsample_ratio > 1
1098
- ) or full_features_head_dim != level_n_features_output:
1099
- feature_projection = nn.Sequential()
1100
- # feature_projection.add_module("norm",LayerNorm2d(level_n_features_output)) #slow, but better
1101
-
1102
- if 0:
1103
- # Train: 0 [1900/10009 ( 19%)] Loss: 6.113 (6.57) Time: 0.548s, 233.40/s (0.549s, 233.04/s) LR: 1.000e-05 Data: 0.015 (0.013)
1104
- feature_projection.add_module(
1105
- "norm", nn.BatchNorm2d(level_n_features_output)
1106
- ) # fast, but worse
1107
- feature_projection.add_module(
1108
- "dconv",
1109
- nn.ConvTranspose2d(
1110
- level_n_features_output,
1111
- full_features_head_dim,
1112
- kernel_size=upsample_ratio,
1113
- stride=upsample_ratio,
1114
- ),
1115
- )
1116
- else:
1117
  # pixel shuffle based upsampling
1118
- # Train: 0 [1950/10009 ( 19%)] Loss: 6.190 (6.55) Time: 0.540s, 236.85/s (0.548s, 233.38/s) LR: 1.000e-05 Data: 0.015 (0.013)
1119
  feature_projection.add_module(
1120
  "norm", nn.BatchNorm2d(level_n_features_output)
1121
  ) # fast, but worse
@@ -1133,17 +1155,19 @@ class FasterViT(nn.Module):
1133
  feature_projection.add_module(
1134
  "upsample_pixelshuffle", nn.PixelShuffle(upsample_ratio)
1135
  )
 
 
 
 
 
1136
 
1137
- else:
1138
- feature_projection = nn.Sequential()
1139
- feature_projection.add_module(
1140
- "norm", nn.BatchNorm2d(level_n_features_output)
1141
- )
1142
-
1143
- self.neck_features_proj.append(feature_projection)
1144
 
1145
- if i > 0 and self.levels[i - 1].downsample is not None:
1146
- upsample_ratio *= 2
 
 
 
1147
 
1148
  num_features = (
1149
  full_features_head_dim
@@ -1180,6 +1204,60 @@ class FasterViT(nn.Module):
1180
  nn.init.ones_(m.weight)
1181
  nn.init.zeros_(m.bias)
1182
 
 
 
 
 
1183
  @torch.jit.ignore
1184
  def no_weight_decay_keywords(self):
1185
  return {"rpb"}
@@ -1191,34 +1269,37 @@ class FasterViT(nn.Module):
1191
  x, pre_downsample_x = level(x)
1192
 
1193
  if self.return_full_features or self.use_neck:
1194
- if self.neck_start_stage > il:
1195
- continue
1196
- if full_features is None:
1197
- full_features = self.neck_features_proj[il - self.neck_start_stage](
1198
- pre_downsample_x
1199
- )
1200
- else:
1201
- # upsample torch tensor x to match full_features size, and add to full_features
1202
- feature_projection = self.neck_features_proj[
1203
- il - self.neck_start_stage
1204
- ](pre_downsample_x)
1205
- if (
1206
- feature_projection.shape[2] != full_features.shape[2]
1207
- or feature_projection.shape[3] != full_features.shape[3]
1208
- ):
1209
- feature_projection = torch.nn.functional.pad(
1210
- feature_projection,
1211
- (
1212
- 0,
1213
- -feature_projection.shape[3] + full_features.shape[3],
1214
- 0,
1215
- -feature_projection.shape[2] + full_features.shape[2],
1216
- ),
1217
  )
1218
- full_features += feature_projection
 
 
1219
 
1220
- # x = self.norm(full_features if (self.return_full_features or self.use_neck) else x)
1221
  x = self.norm(x) # new version for
 
1222
  x = self.avgpool(x)
1223
  x = torch.flatten(x, 1)
1224
 
@@ -1228,7 +1309,9 @@ class FasterViT(nn.Module):
1228
  return x, full_features
1229
 
1230
  def forward(self, x):
 
1231
  x, full_features = self.forward_features(x)
 
1232
  x = self.head(x)
1233
  if full_features is not None:
1234
  return x, full_features
@@ -1245,245 +1328,6 @@ class FasterViT(nn.Module):
1245
  if hasattr(module, "switch_to_deploy"):
1246
  module.switch_to_deploy()
1247
 
1248
-
1249
- @register_model
1250
- def fastervit2_small(pretrained=False, **kwargs): # ,
1251
- model = FasterViT(
1252
- depths=[3, 3, 5, 5],
1253
- num_heads=[2, 4, 8, 16],
1254
- window_size=[8, 8, [7, 7], 7],
1255
- dim=96,
1256
- in_dim=64,
1257
- mlp_ratio=4,
1258
- drop_path_rate=0.2,
1259
- sr_ratio=[1, 1, [1, 2], 1],
1260
- use_swiglu=False,
1261
- downsample_shuffle=False,
1262
- yolo_arch=True,
1263
- shuffle_down=False,
1264
- **kwargs,
1265
- )
1266
- if pretrained:
1267
- model.load_state_dict(torch.load(pretrained))
1268
- return model
1269
-
1270
-
1271
- @register_model
1272
- def fastervit2_tiny(pretrained=False, **kwargs): # ,
1273
- model = FasterViT(
1274
- depths=[1, 3, 4, 5],
1275
- num_heads=[2, 4, 8, 16],
1276
- window_size=[8, 8, [7, 7], 7],
1277
- dim=80,
1278
- in_dim=64,
1279
- mlp_ratio=4,
1280
- drop_path_rate=0.2,
1281
- sr_ratio=[1, 1, [2, 1], 1],
1282
- use_swiglu=False,
1283
- downsample_shuffle=False,
1284
- yolo_arch=True,
1285
- shuffle_down=False,
1286
- **kwargs,
1287
- )
1288
- if pretrained:
1289
- model.load_state_dict(torch.load(pretrained))
1290
- return model
1291
-
1292
-
1293
- @register_model
1294
- def fastervit2_base(pretrained=False, **kwargs):
1295
- model = FasterViT(
1296
- depths=[3, 3, 5, 5],
1297
- num_heads=[2, 4, 8, 16],
1298
- window_size=[8, 8, [7, 7], 7],
1299
- dim=128,
1300
- in_dim=64,
1301
- mlp_ratio=4,
1302
- drop_path_rate=0.2,
1303
- sr_ratio=[1, 1, [2, 1], 1],
1304
- use_swiglu=False,
1305
- yolo_arch=True,
1306
- shuffle_down=False,
1307
- conv_base=True,
1308
- **kwargs,
1309
- )
1310
- if pretrained:
1311
- model.load_state_dict(torch.load(pretrained))
1312
- return model
1313
-
1314
-
1315
- @register_model
1316
- def fastervit2_base_fullres1(pretrained=False, **kwargs):
1317
- model = FasterViT(
1318
- depths=[3, 3, 5, 5],
1319
- num_heads=[2, 4, 8, 16],
1320
- window_size=[8, 8, [7, 7], 7],
1321
- dim=128,
1322
- in_dim=64,
1323
- mlp_ratio=4,
1324
- drop_path_rate=0.2,
1325
- sr_ratio=[1, 1, [2, 1], 1],
1326
- use_swiglu=False,
1327
- yolo_arch=True,
1328
- shuffle_down=False,
1329
- conv_base=True,
1330
- use_neck=True,
1331
- full_features_head_dim=1024,
1332
- neck_start_stage=2,
1333
- **kwargs,
1334
- )
1335
- if pretrained:
1336
- model.load_state_dict(torch.load(pretrained))
1337
- return model
1338
-
1339
-
1340
- @register_model
1341
- def fastervit2_base_fullres2(pretrained=False, **kwargs):
1342
- model = FasterViT(
1343
- depths=[3, 3, 5, 5],
1344
- num_heads=[2, 4, 8, 16],
1345
- window_size=[8, 8, [7, 7], 7],
1346
- dim=128,
1347
- in_dim=64,
1348
- mlp_ratio=4,
1349
- drop_path_rate=0.2,
1350
- sr_ratio=[1, 1, [2, 1], 1],
1351
- use_swiglu=False,
1352
- yolo_arch=True,
1353
- shuffle_down=False,
1354
- conv_base=True,
1355
- use_neck=True,
1356
- full_features_head_dim=512,
1357
- neck_start_stage=1,
1358
- **kwargs,
1359
- )
1360
- if pretrained:
1361
- model.load_state_dict(torch.load(pretrained))
1362
- return model
1363
-
1364
-
1365
- @register_model
1366
- def fastervit2_base_fullres3(pretrained=False, **kwargs):
1367
- model = FasterViT(
1368
- depths=[3, 3, 5, 5],
1369
- num_heads=[2, 4, 8, 16],
1370
- window_size=[8, 8, [7, 7], 7],
1371
- dim=128,
1372
- in_dim=64,
1373
- mlp_ratio=4,
1374
- drop_path_rate=0.2,
1375
- sr_ratio=[1, 1, [2, 1], 1],
1376
- use_swiglu=False,
1377
- yolo_arch=True,
1378
- shuffle_down=False,
1379
- conv_base=True,
1380
- use_neck=True,
1381
- full_features_head_dim=256,
1382
- neck_start_stage=1,
1383
- **kwargs,
1384
- )
1385
- if pretrained:
1386
- model.load_state_dict(torch.load(pretrained))
1387
- return model
1388
-
1389
-
1390
- @register_model
1391
- def fastervit2_base_fullres4(pretrained=False, **kwargs):
1392
- model = FasterViT(
1393
- depths=[3, 3, 5, 5],
1394
- num_heads=[2, 4, 8, 16],
1395
- window_size=[8, 8, [7, 7], 7],
1396
- dim=128,
1397
- in_dim=64,
1398
- mlp_ratio=4,
1399
- drop_path_rate=0.2,
1400
- sr_ratio=[1, 1, [2, 1], 1],
1401
- use_swiglu=False,
1402
- yolo_arch=True,
1403
- shuffle_down=False,
1404
- conv_base=True,
1405
- use_neck=True,
1406
- full_features_head_dim=256,
1407
- neck_start_stage=2,
1408
- **kwargs,
1409
- )
1410
- if pretrained:
1411
- model.load_state_dict(torch.load(pretrained))
1412
- return model
1413
-
1414
-
1415
- @register_model
1416
- def fastervit2_base_fullres5(pretrained=False, **kwargs):
1417
- model = FasterViT(
1418
- depths=[3, 3, 5, 5],
1419
- num_heads=[2, 4, 8, 16],
1420
- window_size=[8, 8, [7, 7], 7],
1421
- dim=128,
1422
- in_dim=64,
1423
- mlp_ratio=4,
1424
- drop_path_rate=0.2,
1425
- sr_ratio=[1, 1, [2, 1], 1],
1426
- use_swiglu=False,
1427
- yolo_arch=True,
1428
- shuffle_down=False,
1429
- conv_base=True,
1430
- use_neck=True,
1431
- full_features_head_dim=512,
1432
- neck_start_stage=2,
1433
- **kwargs,
1434
- )
1435
- if pretrained:
1436
- model.load_state_dict(torch.load(pretrained))
1437
- return model
1438
-
1439
-
1440
- # pyt: 1934, 4202 TRT
1441
- @register_model
1442
- def fastervit2_large(pretrained=False, **kwargs):
1443
- model = FasterViT(
1444
- depths=[3, 3, 5, 5],
1445
- num_heads=[2, 4, 8, 16],
1446
- window_size=[8, 8, [7, 7], 7],
1447
- dim=128 + 64,
1448
- in_dim=64,
1449
- mlp_ratio=4,
1450
- drop_path_rate=0.2,
1451
- sr_ratio=[1, 1, [2, 1], 1],
1452
- use_swiglu=False,
1453
- yolo_arch=True,
1454
- shuffle_down=False,
1455
- **kwargs,
1456
- )
1457
- if pretrained:
1458
- model.load_state_dict(torch.load(pretrained))
1459
- return model
1460
-
1461
-
1462
- @register_model
1463
- def fastervit2_large_fullres(pretrained=False, **kwargs):
1464
- model = FasterViT(
1465
- depths=[3, 3, 5, 5],
1466
- num_heads=[2, 4, 8, 16],
1467
- window_size=[None, None, [7, 7], 7],
1468
- dim=192,
1469
- in_dim=64,
1470
- mlp_ratio=4,
1471
- drop_path_rate=0.0,
1472
- sr_ratio=[1, 1, [2, 1], 1],
1473
- use_swiglu=False,
1474
- yolo_arch=True,
1475
- shuffle_down=False,
1476
- conv_base=True,
1477
- use_neck=True,
1478
- full_features_head_dim=1536,
1479
- neck_start_stage=2,
1480
- **kwargs,
1481
- )
1482
- if pretrained:
1483
- model.load_state_dict(torch.load(pretrained))
1484
- return model
1485
-
1486
-
1487
  @register_model
1488
  def fastervit2_large_fullres_ws8(pretrained=False, **kwargs):
1489
  model = FasterViT(
@@ -1559,116 +1403,24 @@ def fastervit2_large_fullres_ws32(pretrained=False, **kwargs):
1559
  return model
1560
 
1561
 
1562
- # pyt: 897
1563
  @register_model
1564
- def fastervit2_xlarge(pretrained=False, **kwargs):
1565
- model = FasterViT(
1566
- depths=[3, 3, 5, 5],
1567
- num_heads=[2, 4, 8, 16],
1568
- window_size=[8, 8, [7, 7], 7],
1569
- dim=128 + 128 + 64,
1570
- in_dim=64,
1571
- mlp_ratio=4,
1572
- drop_path_rate=0.2,
1573
- sr_ratio=[1, 1, [2, 1], 1],
1574
- use_swiglu=False,
1575
- yolo_arch=True,
1576
- shuffle_down=False,
1577
- **kwargs,
1578
- )
1579
- if pretrained:
1580
- model.load_state_dict(torch.load(pretrained))
1581
- return model
1582
-
1583
-
1584
- # pyt:
1585
- @register_model
1586
- def fastervit2_huge(pretrained=False, **kwargs):
1587
- model = FasterViT(
1588
- depths=[3, 3, 5, 5],
1589
- num_heads=[2, 4, 8, 16],
1590
- window_size=[8, 8, [7, 7], 7],
1591
- dim=128 + 128 + 128 + 64,
1592
- in_dim=64,
1593
- mlp_ratio=4,
1594
- drop_path_rate=0.2,
1595
- sr_ratio=[1, 1, [2, 1], 1],
1596
- use_swiglu=False,
1597
- yolo_arch=True,
1598
- shuffle_down=False,
1599
- **kwargs,
1600
- )
1601
- if pretrained:
1602
- model.load_state_dict(torch.load(pretrained))
1603
- return model
1604
-
1605
-
1606
- @register_model
1607
- def fastervit2_xtiny(pretrained=False, **kwargs): # ,
1608
- model = FasterViT(
1609
- depths=[1, 3, 4, 5],
1610
- num_heads=[2, 4, 8, 16],
1611
- window_size=[8, 8, [7, 7], 7],
1612
- dim=64,
1613
- in_dim=64,
1614
- mlp_ratio=4,
1615
- drop_path_rate=0.1,
1616
- sr_ratio=[1, 1, [2, 1], 1],
1617
- use_swiglu=False,
1618
- downsample_shuffle=False,
1619
- yolo_arch=True,
1620
- shuffle_down=False,
1621
- **kwargs,
1622
- )
1623
- if pretrained:
1624
- model.load_state_dict(torch.load(pretrained))
1625
- return model
1626
 
 
 
 
 
1627
 
1628
- @register_model
1629
- def fastervit2_xxtiny_5(pretrained=False, **kwargs): # ,
1630
- model = FasterViT(
1631
- depths=[1, 3, 4, 5],
1632
- num_heads=[2, 4, 8, 16],
1633
- window_size=[8, 8, [7, 7], 7],
1634
- dim=48,
1635
- in_dim=64,
1636
- mlp_ratio=4,
1637
- drop_path_rate=0.05,
1638
- sr_ratio=[1, 1, [2, 1], 1],
1639
- use_swiglu=False,
1640
- downsample_shuffle=False,
1641
- yolo_arch=True,
1642
- shuffle_down=False,
1643
- **kwargs,
1644
- )
1645
- if pretrained:
1646
- model.load_state_dict(torch.load(pretrained))
1647
- return model
1648
 
 
 
1649
 
1650
- @register_model
1651
- def fastervit2_xxxtiny(pretrained=False, **kwargs): # ,
1652
- model = FasterViT(
1653
- depths=[1, 3, 4, 5],
1654
- num_heads=[2, 4, 8, 16],
1655
- window_size=[8, 8, [7, 7], 7],
1656
- dim=32,
1657
- in_dim=32,
1658
- mlp_ratio=4,
1659
- drop_path_rate=0.0,
1660
- sr_ratio=[1, 1, [2, 1], 1],
1661
- use_swiglu=False,
1662
- downsample_shuffle=False,
1663
- yolo_arch=True,
1664
- shuffle_down=False,
1665
- **kwargs,
1666
- )
1667
- if pretrained:
1668
- model.load_state_dict(torch.load(pretrained))
1669
- return model
1670
 
1671
 
1672
- @register_model
1673
- def eradio(pretrained=False, **kwargs):
1674
- return fastervit2_large_fullres_ws16(pretrained=pretrained, **kwargs)
 
 #!/usr/bin/env python3
 
+# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
 #
 # NVIDIA CORPORATION and its licensors retain all intellectual property
 # and proprietary rights in and to this software, related documentation
 
 # distribution of this software and related documentation without an express
 # license agreement from NVIDIA CORPORATION is strictly prohibited.
 
+# E-RADIO (FasterViTv2) model from
+# Mike Ranzinger, Greg Heinrich, Jan Kautz, and Pavlo Molchanov. "AM-RADIO: Agglomerative Model--Reduce All Domains Into One." arXiv preprint arXiv:2312.06709 (2023).
+
+# based on FasterViT, Swin Transformer, YOLOv8
+# FasterViT:
+# Ali Hatamizadeh, Greg Heinrich, Hongxu Yin, Andrew Tao, Jose M. Alvarez, Jan Kautz, and Pavlo Molchanov. "FasterViT: Fast Vision Transformers with Hierarchical Attention." arXiv preprint arXiv:2306.06189 (2023).
 
 import torch
 import torch.nn as nn
 
22
  from timm.models.layers import trunc_normal_, DropPath, LayerNorm2d
23
  import numpy as np
24
  import torch.nn.functional as F
25
+ import warnings
26
+
27
+
28
+ SIMPLER_UP_TOWER = False
29
+
30
+ #######################
31
+ ## Codebase from YOLOv8
32
+ ## BEGINNING
33
+ #######################
34
+
35
+ class C2f(nn.Module):
36
+ """Faster Implementation of CSP Bottleneck with 2 convolutions."""
37
+ """From YOLOv8 codebase"""
38
+ def __init__(self, c1, c2, n=1, shortcut=False, g=1, e=0.5, drop_path=None): # ch_in, ch_out, number, shortcut, groups, expansion
39
+ super().__init__()
40
+ if drop_path is None:
41
+ drop_path = [0.0] * n
42
+
43
+ self.c = int(c2 * e) # hidden channels
44
+ self.cv1 = Conv(c1, 2 * self.c, 1, 1)
45
+ self.cv2 = Conv((2 + n) * self.c, c2, 1) # optional act=FReLU(c2)
46
+ self.m = nn.ModuleList(Bottleneck(self.c, self.c, shortcut, g, k=((3, 3), (3, 3)), e=1.0, drop_path=drop_path[i]) for i in range(n))
47
+
48
+ def forward(self, x):
49
+ """Forward pass through C2f layer."""
50
+ y = list(self.cv1(x).chunk(2, 1))
51
+ y.extend(m(y[-1]) for m in self.m)
52
+ return self.cv2(torch.cat(y, 1))
53
+
54
+ def forward_split(self, x):
55
+ """Forward pass using split() instead of chunk()."""
56
+ y = list(self.cv1(x).split((self.c, self.c), 1))
57
+ y.extend(m(y[-1]) for m in self.m)
58
+ return self.cv2(torch.cat(y, 1))
59
+
60
+ class Bottleneck(nn.Module):
61
+ """Standard bottleneck."""
62
+
63
+ def __init__(self, c1, c2, shortcut=True, g=1, k=(3, 3), e=0.5, drop_path=0.0): # ch_in, ch_out, shortcut, groups, kernels, expand
64
+ super().__init__()
65
+ c_ = int(c2 * e) # hidden channels
66
+ self.cv1 = Conv(c1, c_, k[0], 1)
67
+ self.cv2 = Conv(c_, c2, k[1], 1, g=g)
68
+ self.add = shortcut and c1 == c2
69
+ self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
70
+
71
+ def forward(self, x):
72
+ """'forward()' applies the YOLOv5 FPN to input data."""
73
+ return x + self.drop_path1(self.cv2(self.cv1(x))) if self.add else self.cv2(self.cv1(x))
74
 
 
75
 
76
+ class Conv(nn.Module):
77
+ """Modified to support layer fusion"""
78
+ default_act = nn.SiLU() # default activation
79
 
80
+ def __init__(self, a, b, kernel_size=1, stride=1, padding=None, g=1, dilation=1, bn_weight_init=1, bias=False, act=True):
81
+ super().__init__()
82
+
83
+ self.conv = torch.nn.Conv2d(a, b, kernel_size, stride, autopad(kernel_size, padding, dilation), dilation, g, bias=False)
84
+ if 1:
85
+ self.bn = torch.nn.BatchNorm2d(b)
86
+ torch.nn.init.constant_(self.bn.weight, bn_weight_init)
87
+ torch.nn.init.constant_(self.bn.bias, 0)
88
+ self.act = self.default_act if act is True else act if isinstance(act, nn.Module) else nn.Identity()
89
+
90
+
91
+ def forward(self,x):
92
+ x = self.conv(x)
93
+ x = self.bn(x)
94
+ x = self.act(x)
95
+ return x
96
+
97
+ @torch.no_grad()
98
+ def switch_to_deploy(self):
99
+ # return 1
100
+ c, bn = self.conv, self.bn
101
+ w = bn.weight / (bn.running_var + bn.eps) ** 0.5
102
+ w = c.weight * w[:, None, None, None]
103
+ b = bn.bias - bn.running_mean * bn.weight / \
104
+ (bn.running_var + bn.eps)**0.5
105
+
106
+ self.conv.weight.data.copy_(w)
107
+ self.conv.bias = nn.Parameter(b)
108
+
109
+ self.bn = nn.Identity()
110
+
111
+ def autopad(k, p=None, d=1): # kernel, padding, dilation
112
+ """Pad to 'same' shape outputs."""
113
+ if d > 1:
114
+ k = d * (k - 1) + 1 if isinstance(k, int) else [d * (x - 1) + 1 for x in k] # actual kernel-size
115
+ if p is None:
116
+ p = k // 2 if isinstance(k, int) else [x // 2 for x in k] # auto-pad
117
+ return p
118
+
119
+
120
+ #######################
121
+ ## Codebase from YOLOv8
122
+ ## END
123
+ #######################
124
 
125
 
126
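Note on the fused Conv above: switch_to_deploy folds the BatchNorm into the preceding convolution. A minimal standalone sketch of that folding identity, using plain torch modules rather than the repo's Conv class:

import torch
import torch.nn as nn

conv = nn.Conv2d(8, 16, 3, padding=1, bias=False)
bn = nn.BatchNorm2d(16)
conv.eval(); bn.eval()

# give the BN running stats something non-trivial
with torch.no_grad():
    bn.running_mean.uniform_(-0.5, 0.5)
    bn.running_var.uniform_(0.5, 1.5)

# fold: w' = w * gamma / sqrt(var + eps), b' = beta - mean * gamma / sqrt(var + eps)
w = bn.weight / (bn.running_var + bn.eps) ** 0.5
fused = nn.Conv2d(8, 16, 3, padding=1, bias=True)
with torch.no_grad():
    fused.weight.copy_(conv.weight * w[:, None, None, None])
    fused.bias.copy_(bn.bias - bn.running_mean * w)

x = torch.randn(1, 8, 32, 32)
with torch.no_grad():
    assert torch.allclose(bn(conv(x)), fused(x), atol=1e-5)  # same output, one op fewer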
  def pixel_unshuffle(data, factor=2):
 
157
  else:
158
  pad_h = (window_size - H % window_size) % window_size
159
  pad_w = (window_size - W % window_size) % window_size
160
+ #interpolate features
161
  if pad_h > 0 or pad_w > 0:
162
  x = F.pad(x, (0, pad_w, 0, pad_h, 0, 0, 0, 0))
163
  Hp, Wp = H + pad_h, W + pad_w
 
201
 
202
  @torch.no_grad()
203
  def switch_to_deploy(self):
 
 
204
  if not isinstance(self.bn, nn.Identity):
205
  c, bn = self.conv, self.bn
206
  w = bn.weight / (bn.running_var + bn.eps) ** 0.5
 
242
 
243
 
244
  class PosEmbMLPSwinv2D(nn.Module):
245
+ """
246
+ 2D positional embedding from Swin Transformer v2
247
+ Added functionality to store the positional embedding in the model and not recompute it every time
248
+ """
249
  def __init__(
250
+ self, window_size, pretrained_window_size, num_heads, seq_length, no_log=False, cpb_mlp_hidden=512,
251
  ):
252
  super().__init__()
253
  self.window_size = window_size
254
  self.num_heads = num_heads
255
  # mlp to generate continuous relative position bias
256
  self.cpb_mlp = nn.Sequential(
257
+ nn.Linear(2, cpb_mlp_hidden, bias=True),
258
  nn.ReLU(inplace=True),
259
+ nn.Linear(cpb_mlp_hidden, num_heads, bias=False),
260
  )
261
 
262
+ self.grid_exists = False
263
+ self.seq_length = seq_length
264
+ self.deploy = False
265
+ self.num_heads = num_heads
266
+ self.no_log = no_log
267
+ self.pretrained_window_size = pretrained_window_size
268
+ self.relative_bias_window_size = window_size
269
+
270
+ relative_coords_table, relative_position_index, relative_bias = self.relative_bias_initialization(window_size, num_heads,
271
+ pretrained_window_size, seq_length,
272
+ no_log)
273
+
274
+ self.register_buffer("relative_coords_table", relative_coords_table)
275
+ self.register_buffer("relative_position_index", relative_position_index)
276
+ self.register_buffer("relative_bias", relative_bias) # for EMA
277
+
278
+ def relative_bias_initialization(self, window_size, num_heads, pretrained_window_size, seq_length, no_log):
279
+ # kept as a separate function to support window size changes after the model weights are loaded
280
+
281
  relative_coords_h = torch.arange(
282
+ -(window_size[0] - 1), window_size[0], dtype=torch.float32
283
  )
284
  relative_coords_w = torch.arange(
285
+ -(window_size[1] - 1), window_size[1], dtype=torch.float32
286
  )
287
  relative_coords_table = (
288
  torch.stack(torch.meshgrid([relative_coords_h, relative_coords_w]))
 
305
  / np.log2(8)
306
  )
307
 
 
 
308
  # get pair-wise relative position index for each token inside the window
309
  coords_h = torch.arange(self.window_size[0])
310
  coords_w = torch.arange(self.window_size[1])
 
320
  relative_coords[:, :, 1] += self.window_size[1] - 1
321
  relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
322
  relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
 
323
 
324
+ relative_bias = torch.zeros(1, num_heads, seq_length, seq_length)
325
 
326
+ self.relative_bias_window_size = window_size
327
+
328
+ return relative_coords_table, relative_position_index, relative_bias
329
 
 
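For reference, a standalone sketch (plain torch/numpy, independent of the repo) of the Swin-v2-style relative-coordinates table and position index that relative_bias_initialization rebuilds when the window size changes; the window size here is an arbitrary example:

import torch
import numpy as np

window_size = (4, 4)  # hypothetical window
# continuous relative coordinates, log-spaced as in Swin Transformer v2
rch = torch.arange(-(window_size[0] - 1), window_size[0], dtype=torch.float32)
rcw = torch.arange(-(window_size[1] - 1), window_size[1], dtype=torch.float32)
table = torch.stack(torch.meshgrid([rch, rcw], indexing="ij")).permute(1, 2, 0).unsqueeze(0)
table[:, :, :, 0] /= window_size[0] - 1
table[:, :, :, 1] /= window_size[1] - 1
table = torch.sign(table) * torch.log2(torch.abs(table) * 8 + 1.0) / np.log2(8)

# pair-wise relative position index for tokens inside the window
coords = torch.stack(torch.meshgrid([torch.arange(window_size[0]),
                                     torch.arange(window_size[1])], indexing="ij"))
coords_flat = torch.flatten(coords, 1)
rel = (coords_flat[:, :, None] - coords_flat[:, None, :]).permute(1, 2, 0).contiguous()
rel[:, :, 0] += window_size[0] - 1
rel[:, :, 1] += window_size[1] - 1
rel[:, :, 0] *= 2 * window_size[1] - 1
relative_position_index = rel.sum(-1)              # (Wh*Ww, Wh*Ww)
print(table.shape, relative_position_index.shape)  # torch.Size([1, 7, 7, 2]) torch.Size([16, 16])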
 
 
330
 
331
  def switch_to_deploy(self):
332
  self.deploy = True
 
335
  def forward(self, input_tensor):
336
  # for efficiency, we want this forward to be folded into a single operation (sum)
337
  # if resolution stays the same, then we dont need to recompute MLP layers
 
 
 
 
338
 
339
+ if not self.deploy or self.training:
340
  self.grid_exists = False
341
 
342
+ #compare if all elements in self.window_size list match those in self.relative_bias_window_size
343
+ if not all([self.window_size[i] == self.relative_bias_window_size[i] for i in range(len(self.window_size))]):
344
+ relative_coords_table, relative_position_index, relative_bias = self.relative_bias_initialization(self.window_size, self.num_heads,
345
+ self.pretrained_window_size, self.seq_length,
346
+ self.no_log)
347
+
348
+ self.relative_coords_table = relative_coords_table.to(self.relative_coords_table.device)
349
+ self.relative_position_index = relative_position_index.to(self.relative_position_index.device)
350
+ self.relative_bias = relative_bias.to(self.relative_bias.device)
351
+
352
  if self.deploy and self.grid_exists:
353
  input_tensor += self.relative_bias
354
  return input_tensor
355
 
356
+ if 1:
357
  self.grid_exists = True
358
 
359
  relative_position_bias_table = self.cpb_mlp(
 
396
  conv_base=False,
397
  do_windowing=True,
398
  multi_query=False,
399
+ cpb_mlp_hidden=512,
400
  ) -> None:
401
  super().__init__()
402
 
403
+
 
 
 
404
  self.do_windowing = do_windowing
405
 
406
  if do_windowing:
407
+ if conv_base:
408
+ self.downsample_op = nn.Conv2d(dim_in, dim_out, kernel_size=subsample_ratio, stride=subsample_ratio) if subsample_ratio > 1 else nn.Identity()
 
 
409
  self.downsample_mixer = nn.Identity()
 
 
410
  self.upsample_mixer = nn.Identity()
411
+ self.upsample_op = nn.ConvTranspose2d(dim_in, dim_out, kernel_size=subsample_ratio, stride=subsample_ratio) if subsample_ratio > 1 else nn.Identity()
412
+ else:
413
+ self.downsample_op = nn.AvgPool2d(kernel_size=subsample_ratio, stride=subsample_ratio) if subsample_ratio > 1 else nn.Identity()
414
+ self.downsample_mixer = Conv2d_BN(dim_in, dim_out, kernel_size=1, stride=1) if subsample_ratio > 1 else nn.Identity()
415
+ self.upsample_mixer = nn.Upsample(scale_factor=subsample_ratio, mode='nearest') if subsample_ratio > 1 else nn.Identity()
416
+ self.upsample_op = Conv2d_BN(dim_in, dim_out, kernel_size=1, stride=1, padding=0, bias=False) if subsample_ratio > 1 else nn.Identity()
 
 
417
 
418
  self.window_size = window_size
419
 
420
  self.norm1 = norm_layer(dim_in)
 
 
 
 
421
 
422
  self.attn = WindowAttention(
423
  dim_in,
424
+ num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale,
 
 
425
  resolution=window_size,
426
+ seq_length=window_size**2, dim_out=dim_in, multi_query=multi_query,
427
+ cpb_mlp_hidden=cpb_mlp_hidden)
 
 
 
 
 
 
 
428
 
429
+ self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
430
 
431
  use_layer_scale = layer_scale is not None and type(layer_scale) in [int, float]
432
+ self.gamma1 = nn.Parameter(layer_scale * torch.ones(dim_in)) if use_layer_scale else 1
 
 
433
 
434
  ### mlp layer
435
  mlp_ratio = 4
 
437
  mlp_hidden_dim = int(dim_in * mlp_ratio)
438
 
439
  activation = nn.GELU if not use_swiglu else SwiGLU
440
+ mlp_hidden_dim = int((4 * dim_in * 1 / 2) / 64) * 64 if use_swiglu else mlp_hidden_dim
 
 
441
 
442
+ self.mlp = Mlp(in_features=dim_in, hidden_features=mlp_hidden_dim, act_layer=activation, use_swiglu=use_swiglu)
443
+
444
+ self.gamma2 = nn.Parameter(layer_scale * torch.ones(dim_in)) if layer_scale else 1
445
+ self.drop_path2=DropPath(drop_path) if drop_path > 0. else nn.Identity()
 
 
446
 
 
447
 
448
  def forward(self, x):
449
  skip_connection = x
 
516
  use_swiglu=True,
517
  multi_query=False,
518
  conv_base=False,
519
+ cpb_mlp_hidden=512
520
  ) -> None:
521
  """
522
  Args:
 
555
  do_windowing=do_windowing,
556
  multi_query=multi_query,
557
  conv_base=conv_base,
558
+ cpb_mlp_hidden=cpb_mlp_hidden
559
  ),
560
  )
561
 
 
598
  )
599
  self.act = act_layer()
600
  self.fc2 = nn.Linear(hidden_features, out_features, bias=False)
 
601
 
602
  def forward(self, x):
603
  x_size = x.size()
604
  x = x.view(-1, x_size[-1])
605
  x = self.fc1(x)
606
  x = self.act(x)
 
607
  x = self.fc2(x)
 
608
  x = x.view(x_size)
609
  return x
610
 
 
622
  """
623
  Args:
624
  dim: feature size dimension.
625
+ shuffle: idea with pixel unshuffling instead for resizing
626
  keep_dim: bool argument for maintaining the resolution.
627
  """
628
 
 
633
  self.norm = lambda x: pixel_unshuffle(x, factor=2)
634
  self.reduction = Conv2d_BN(dim * 4, dim_out, 1, 1, 0, bias=False)
635
  else:
 
 
636
  self.norm = nn.Identity()
637
  self.reduction = Conv2d_BN(dim, dim_out, 3, 2, 1, bias=False)
638
 
 
645
  class PatchEmbed(nn.Module):
646
  """
647
  Patch embedding block
648
+ Used to convert image into an initial set of feature maps with lower resolution
649
+
650
  """
651
 
652
  def __init__(self, in_chans=3, in_dim=64, dim=96, shuffle_down=False):
 
670
  )
671
  else:
672
  self.proj = lambda x: pixel_unshuffle(x, factor=4)
 
 
 
 
 
 
673
  self.conv_down = nn.Sequential(
674
  Conv2d_BN(in_chans * 16, dim, 3, 1, 1), nn.ReLU(),
675
  )
 
684
  """
685
  Convolutional block, used in first couple of stages
686
  Experimented with plain ResNet-18-like modules; they are the best in terms of throughput
 
687
  Finally, the YOLOv8 idea seems to work fine (a ResNet-18-like block with a squeezed feature dimension and feature concatenation at the end)
688
  """
689
+ def __init__(self, dim,
690
+ drop_path=0.,
691
+ layer_scale=None,
692
+ kernel_size=3,
693
+ ):
694
  super().__init__()
 
 
695
 
696
+ self.conv1 = Conv2d_BN(dim, dim, kernel_size=kernel_size, stride=1, padding=1)
697
+ self.act1 = nn.GELU()
698
+
699
+ self.conv2 = Conv2d_BN(dim, dim, kernel_size=kernel_size, stride=1, padding=1)
 
 
 
 
700
 
701
  self.layer_scale = layer_scale
702
  if layer_scale is not None and type(layer_scale) in [int, float]:
 
704
  self.layer_scale = True
705
  else:
706
  self.layer_scale = False
707
+ self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
708
 
709
  def forward(self, x):
710
  input = x
711
+
712
+ x = self.conv1(x)
713
+ x = self.act1(x)
714
+ x = self.conv2(x)
715
+
 
 
716
  if self.layer_scale:
717
  x = x * self.gamma.view(1, -1, 1, 1)
718
  x = input + self.drop_path(x)
 
722
  class WindowAttention(nn.Module):
723
  # Windowed Attention from SwinV2
724
  # use a MLP trick to deal with various input image resolutions, then fold it to improve speed
 
 
 
725
 
726
  def __init__(
727
  self,
 
733
  seq_length=0,
734
  dim_out=None,
735
  multi_query=False,
736
+ cpb_mlp_hidden=512,
737
  ):
738
  # taken from EdgeViT and tweaked with attention bias.
739
  super().__init__()
 
748
 
749
  self.scale = qk_scale or head_dim ** -0.5
750
  if not multi_query:
751
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
 
 
 
 
 
752
  else:
753
  self.qkv = nn.Linear(dim, dim + 2 * self.head_dim, bias=qkv_bias)
754
 
 
759
  pretrained_window_size=[resolution, resolution],
760
  num_heads=num_heads,
761
  seq_length=seq_length,
762
+ cpb_mlp_hidden=cpb_mlp_hidden,
763
  )
764
 
765
  self.resolution = resolution
 
768
  B, N, C = x.shape
769
 
770
  if not self.multi_query:
771
+ qkv = (
 
 
 
772
  self.qkv(x)
773
  .reshape(B, -1, 3, self.num_heads, C // self.num_heads)
774
  .permute(2, 0, 3, 1, 4)
775
  )
776
+ q, k, v = qkv[0], qkv[1], qkv[2]
777
  else:
778
  qkv = self.qkv(x)
779
  (q, k, v) = qkv.split(
 
820
  sr_ratio=1,
821
  multi_query=False,
822
  use_swiglu=True,
 
823
  yolo_arch=False,
824
  downsample_shuffle=False,
825
  conv_base=False,
826
+ cpb_mlp_hidden=512,
827
  ):
828
  """
829
  Args:
 
855
  drop_path=drop_path[i]
856
  if isinstance(drop_path, list)
857
  else drop_path,
858
+ layer_scale=layer_scale_conv )
 
 
859
  for i in range(depth)
860
  ]
861
  )
862
+ self.blocks = nn.Sequential(*self.blocks)
863
  else:
864
  self.blocks = C2f(dim, dim, n=depth, shortcut=True, e=0.5)
865
  self.yolo_arch = True
 
870
  self.do_single_windowing = True
871
  if not isinstance(sr_ratio, list):
872
  sr_ratio = [sr_ratio]
873
+ self.sr_ratio = sr_ratio
874
  if any([sr != 1 for sr in sr_ratio]) or len(set(window_size)) > 1:
875
  self.do_single_windowing = False
876
  do_windowing = True
 
899
  do_windowing=do_windowing,
900
  multi_query=multi_query,
901
  conv_base=conv_base,
902
+ cpb_mlp_hidden=cpb_mlp_hidden,
903
  )
904
  )
905
 
906
+ self.blocks = nn.Sequential(*self.blocks)
907
+
908
  self.transformer = not conv
909
 
910
  self.downsample = (
911
  None if not downsample else Downsample(dim=dim, shuffle=downsample_shuffle)
912
  )
913
 
914
+
915
  def forward(self, x):
916
  B, C, H, W = x.shape
917
 
918
+ # make the feature map size compatible with the transformer's windowed attention
919
+ interpolate = True
920
+ if self.transformer and interpolate:
921
+ # Windowed Attention will split feature map into windows with the size of window_size x window_size
922
+ # if the resolution is not divisible by window_size, we need to interpolate the feature map
923
+ # can be done via padding, but doing so after training hurts the model performance.
924
+ # interpolation affects the performance as well, but not as much as padding
925
+ if isinstance(self.window_size, list) or isinstance(self.window_size, tuple):
926
+ current_max_window_size = max(self.window_size)
927
+ else:
928
+ current_max_window_size = self.window_size
929
+
930
+ max_window_size = max([res_upsample*current_max_window_size for res_upsample in self.sr_ratio])
931
+ if H % max_window_size != 0 or W % max_window_size != 0:
932
+ new_h = int(np.ceil(H/max_window_size)*max_window_size)
933
+ new_w = int(np.ceil(W/max_window_size)*max_window_size)
934
+ x = F.interpolate(x, size=(new_h, new_w), mode='nearest')
935
+ warnings.warn(f"Chosen window size is not optimal for the given resolution. Feature maps will be interpolated, which can affect performance. Max window size is {max_window_size}, feature map size is {H}x{W}, interpolated feature map size is {new_h}x{new_w}.")
936
+
937
+
938
  if self.transformer and self.do_single_windowing:
939
  H, W = x.shape[2], x.shape[3]
940
  x, pad_hw = window_partition(x, self.window_size)
941
 
942
+ x = self.blocks(x)
943
+ # if not self.yolo_arch:
944
+ # for bn, blk in enumerate(self.blocks):
945
+ # x = blk(x)
946
+ # else:
947
+ # x = self.blocks(x)
948
 
949
  if self.transformer and self.do_single_windowing:
950
  x = window_reverse(x, self.window_size, H, W, pad_hw)
951
 
952
+ if self.transformer and interpolate:
953
+ # restore the original resolution; not ideal, but the upsampling tower expects this resolution
954
+ x = F.interpolate(x, size=(H, W), mode='nearest')
955
+
956
  if self.downsample is None:
957
  return x, x
958
 
959
  return self.downsample(x), x # changing to output pre downsampled features
960
 
961
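A worked example of the resizing arithmetic used in the forward pass above, with hypothetical numbers:

import numpy as np

H, W = 37, 50            # incoming feature map size (assumed)
window_size = 7          # current stage window size (assumed)
sr_ratio = [1, 2]        # per-branch subsample ratios (assumed)

max_window_size = max(r * window_size for r in sr_ratio)     # 14
new_h = int(np.ceil(H / max_window_size) * max_window_size)  # 42
new_w = int(np.ceil(W / max_window_size) * max_window_size)  # 56
print(max_window_size, new_h, new_w)                          # 14 42 56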
 
962
+ class HiResNeck(nn.Module):
963
+ """
964
+ The block is used to output dense features from all stages
965
+ Otherwise, by default, only the last stage features are returned with FasterViTv2
966
+ """
967
+ def __init__(self, dim, depths, neck_start_stage, full_features_head_dim):
968
+
969
+ '''
970
+ Hi Resolution neck to support output of high res features that are useful for dense tasks.
971
+ depths - total number of layers in the base model
972
+ neck_start_stage - when to start the neck, 0 - start from the first stage, 1 - start from the second stage etc.
973
+ earlier layers result in higher resolution features at the cost of compute
974
+ full_features_head_dim - number of channels in the dense features head
975
+ '''
976
+ # create feature projection layers for segmentation output
977
+ self.neck_features_proj = nn.ModuleList()
978
+ self.neck_start_stage = neck_start_stage
979
+ upsample_ratio = 1
980
+ for i in range(len(depths)):
981
+ level_n_features_output = int(dim * 2 ** i)
982
+
983
+ if self.neck_start_stage > i: continue
984
+
985
+ if (upsample_ratio > 1) or full_features_head_dim!=level_n_features_output:
986
+ feature_projection = nn.Sequential()
987
+ feature_projection.add_module("norm",nn.BatchNorm2d(level_n_features_output)) #fast, but worse
988
+
989
+ feature_projection.add_module("dconv", nn.ConvTranspose2d(level_n_features_output,
990
+ full_features_head_dim, kernel_size=upsample_ratio, stride=upsample_ratio))
991
+ else:
992
+ feature_projection = nn.Sequential()
993
+
994
+ self.neck_features_proj.append(feature_projection)
995
+
996
+ if i>0 and self.levels[i-1].downsample is not None:
997
+ upsample_ratio *= 2
998
+
999
+ def forward(self, x, il_level=-1, full_features=None):
1000
+ if self.neck_start_stage > il_level:
1001
+ return full_features
1002
+
1003
+ if full_features is None:
1004
+ full_features = self.neck_features_proj[il_level - self.neck_start_stage](x)
1005
+ else:
1006
+ #upsample torch tensor x to match full_features size, and add to full_features
1007
+ feature_projection = self.neck_features_proj[il_level - self.neck_start_stage](x)
1008
+ if feature_projection.shape[2] != full_features.shape[2] or feature_projection.shape[3] != full_features.shape[3]:
1009
+ feature_projection = torch.nn.functional.pad(feature_projection, ( 0, -feature_projection.shape[3] + full_features.shape[3], 0, -feature_projection.shape[2] + full_features.shape[2]))
1010
+ full_features += feature_projection
1011
+ return full_features
1012
+
1013
+
1014
+
1015
  class FasterViT(nn.Module):
1016
  """
1017
  FasterViT
 
1039
  use_swiglu=False,
1040
  multi_query=False,
1041
  norm_layer=nn.LayerNorm,
 
1042
  drop_uniform=False,
1043
  yolo_arch=False,
1044
  shuffle_down=False,
 
1047
  full_features_head_dim=128,
1048
  neck_start_stage=1,
1049
  use_neck=False,
1050
+ cpb_mlp_hidden=512,
1051
  **kwargs,
1052
  ):
1053
  """
 
1112
  use_swiglu=use_swiglu,
1113
  multi_query=multi_query,
1114
  norm_layer=norm_layer,
 
1115
  yolo_arch=yolo_arch,
1116
  downsample_shuffle=downsample_shuffle,
1117
  conv_base=conv_base,
1118
+ cpb_mlp_hidden=cpb_mlp_hidden,
1119
+
1120
  )
1121
 
1122
  self.levels.append(level)
1123
 
1124
+ if not SIMPLER_UP_TOWER:
1125
+ if self.return_full_features or self.use_neck:
1126
+ # create feature projection layers for segmentation output
1127
+ self.neck_features_proj = nn.ModuleList()
1128
+ self.neck_start_stage = neck_start_stage
1129
+ upsample_ratio = 1
1130
+ for i in range(len(depths)):
1131
+ level_n_features_output = int(dim * 2 ** i)
1132
+
1133
+ if self.neck_start_stage > i:
1134
+ continue
1135
+
1136
+ if (
1137
+ upsample_ratio > 1
1138
+ ) or full_features_head_dim != level_n_features_output:
1139
+ feature_projection = nn.Sequential()
 
 
1140
  # pixel shuffle based upsampling
 
1141
  feature_projection.add_module(
1142
  "norm", nn.BatchNorm2d(level_n_features_output)
1143
  ) # fast, but worse
 
1155
  feature_projection.add_module(
1156
  "upsample_pixelshuffle", nn.PixelShuffle(upsample_ratio)
1157
  )
1158
+ else:
1159
+ feature_projection = nn.Sequential()
1160
+ feature_projection.add_module(
1161
+ "norm", nn.BatchNorm2d(level_n_features_output)
1162
+ )
1163
 
1164
+ self.neck_features_proj.append(feature_projection)
 
 
 
 
 
 
1165
 
1166
+ if i > 0 and self.levels[i - 1].downsample is not None:
1167
+ upsample_ratio *= 2
1168
+ else:
1169
+ if self.return_full_features or self.use_neck:
1170
+ self.high_res_neck = HiResNeck(dim, num_heads, depths, neck_start_stage, full_features_head_dim)
1171
 
1172
  num_features = (
1173
  full_features_head_dim
 
1204
  nn.init.ones_(m.weight)
1205
  nn.init.zeros_(m.bias)
1206
 
1207
+ def change_window_size(self, new_window_size):
1208
+ """
1209
+ FasterViT uses windowed attention, so it may be sensitive to the choice of this parameter,
1210
+ especially in the case of uneven partitioning of the feature maps.
1211
+ FasterViT allows changing the window size post training.
1212
+ Therefore it should be adjusted when the input image resolution changes.
1213
+ Recommended values:
1214
+ input res | window_size
1215
+ 224 | 7
1216
+ 256 | 8
1217
+ 384 | 12
1218
+ 512 | 16
1219
+ Ideally, window_size should be a factor of the input resolution. In the third stage we divide the resolution by 16, so window_size should be img_res/16/2 for the third stage and img_res/32 for the last stage.
1220
+ Applying in the brute force way, can be done smarter
1221
+ """
1222
+ window_size = new_window_size
1223
+
1224
+ for module in self.modules():
1225
+ if hasattr(module, "window_size"):
1226
+ # check if tuple or a number
1227
+ if isinstance(module.window_size, tuple):
1228
+ if module.window_size[0] != window_size:
1229
+ module.window_size = (window_size, window_size)
1230
+ elif isinstance(module.window_size, list):
1231
+ if module.window_size[0] != window_size:
1232
+ module.window_size = [window_size, window_size]
1233
+ else:
1234
+ module.window_size = window_size
1235
+
1236
+ def set_optimal_window_size(self, image_dim):
1237
+ """
1238
+ Using hand picked window size for various resolutions.
1239
+ """
1240
+ if isinstance(image_dim, list) or isinstance(image_dim, tuple):
1241
+ image_dim = min(image_dim)
1242
+
1243
+ if image_dim == 224:
1244
+ new_window_size = 7
1245
+ elif image_dim == 256:
1246
+ new_window_size = 8
1247
+ elif image_dim == 384:
1248
+ new_window_size = 12
1249
+ elif image_dim == 512:
1250
+ new_window_size = 16
1251
+ else:
1252
+ if image_dim < 512:
1253
+ new_window_size = int(np.ceil(image_dim / 32))
1254
+ else:
1255
+ new_window_size = 16
1256
+
1257
+ print(f"Changing window size to {new_window_size}")
1258
+ self.change_window_size(new_window_size = new_window_size)
1259
+
1260
+
1261
  @torch.jit.ignore
1262
  def no_weight_decay_keywords(self):
1263
  return {"rpb"}
 
1269
  x, pre_downsample_x = level(x)
1270
 
1271
  if self.return_full_features or self.use_neck:
1272
+ if not SIMPLER_UP_TOWER:
1273
+ if self.neck_start_stage > il:
1274
+ continue
1275
+ if full_features is None:
1276
+ full_features = self.neck_features_proj[il - self.neck_start_stage](
1277
+ pre_downsample_x
 
 
 
1278
  )
1279
+ else:
1280
+ # upsample torch tensor x to match full_features size, and add to full_features
1281
+ feature_projection = self.neck_features_proj[
1282
+ il - self.neck_start_stage
1283
+ ](pre_downsample_x)
1284
+ if (
1285
+ feature_projection.shape[2] != full_features.shape[2]
1286
+ or feature_projection.shape[3] != full_features.shape[3]
1287
+ ):
1288
+ feature_projection = torch.nn.functional.pad(
1289
+ feature_projection,
1290
+ (
1291
+ 0,
1292
+ -feature_projection.shape[3] + full_features.shape[3],
1293
+ 0,
1294
+ -feature_projection.shape[2] + full_features.shape[2],
1295
+ ),
1296
+ )
1297
+ full_features += feature_projection
1298
+ else:
1299
+ full_features = self.high_res_neck(pre_downsample_x, il, full_features)
1300
 
 
1301
  x = self.norm(x) # new version for
1302
+
1303
  x = self.avgpool(x)
1304
  x = torch.flatten(x, 1)
1305
 
 
1309
  return x, full_features
1310
 
1311
  def forward(self, x):
1312
+
1313
  x, full_features = self.forward_features(x)
1314
+
1315
  x = self.head(x)
1316
  if full_features is not None:
1317
  return x, full_features
 
1328
  if hasattr(module, "switch_to_deploy"):
1329
  module.switch_to_deploy()
1330
 
 
 
1331
  @register_model
1332
  def fastervit2_large_fullres_ws8(pretrained=False, **kwargs):
1333
  model = FasterViT(
 
1403
  return model
1404
 
1405
 
 
1406
  @register_model
1407
+ def eradio(pretrained=False, **kwargs):
1408
+ return fastervit2_large_fullres_ws16(pretrained=pretrained, **kwargs)
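Because eradio is registered through timm's @register_model decorator, the backbone can also be built via the timm registry once this module has been imported. A minimal sketch; the import path below is a placeholder and depends on how the repository is laid out:

from timm.models import create_model

import eradio_model  # placeholder import path; importing the module runs the @register_model registration
model = create_model("eradio", pretrained=False)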
 
 
 
1409
 
1410
+ '''
1411
+ Suggested way to use:
1412
+ from transformers import AutoModel
1413
+ model = AutoModel.from_pretrained("nvidia/E-RADIO", trust_remote_code=True)
1414
 
1415
+ model.model.set_optimal_window_size(image_dim = data["image"][0].shape[:2])
1416
+ imgs = [torch.tensor(img).permute(2,0,1)/255.0 for img in data["image"]] #res is 224
1417
+ input_images = torch.stack(imgs).cuda()
 
 
 
1418
 
1419
+ model.eval()
1420
+ model.cuda()
1421
 
1422
+ cls_token, features = model(input_images)
1423
+ cls_token = features.mean([2, 3])  # overrides the returned cls_token with an average-pooled summary of the spatial features
 
 
 
1424
 
1425
 
1426
+ '''
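For resolutions that set_optimal_window_size does not special-case, the img_res/32 rule from the change_window_size docstring can be applied directly. A minimal sketch continuing the suggested-usage snippet above; 448 is just an example resolution:

import numpy as np

image_dim = 448                                  # example input resolution
new_window_size = int(np.ceil(image_dim / 32))   # -> 14, following the img_res/32 rule
model.model.change_window_size(new_window_size)  # same inner model as in the snippet above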
 
 
hf_model.py CHANGED
@@ -1,4 +1,4 @@
1
- # Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
2
  #
3
  # Licensed under the Apache License, Version 2.0 (the "License");
4
  # you may not use this file except in compliance with the License.
@@ -14,7 +14,6 @@
14
  from collections import namedtuple
15
  from typing import Optional
16
 
17
- from einops import rearrange
18
  from timm.models import VisionTransformer
19
  import torch
20
  from transformers import PretrainedConfig, PreTrainedModel
 
1
+ # Copyright (c) 2023-2024, NVIDIA CORPORATION. All rights reserved.
2
  #
3
  # Licensed under the Apache License, Version 2.0 (the "License");
4
  # you may not use this file except in compliance with the License.
 
14
  from collections import namedtuple
15
  from typing import Optional
16
 
 
17
  from timm.models import VisionTransformer
18
  import torch
19
  from transformers import PretrainedConfig, PreTrainedModel
input_conditioner.py CHANGED
@@ -1,4 +1,4 @@
1
- # Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
2
  #
3
  # NVIDIA CORPORATION and its licensors retain all intellectual property
4
  # and proprietary rights in and to this software, related documentation
 
1
+ # Copyright (c) 2023-2024, NVIDIA CORPORATION. All rights reserved.
2
  #
3
  # NVIDIA CORPORATION and its licensors retain all intellectual property
4
  # and proprietary rights in and to this software, related documentation
pytorch_model.bin CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:3840092575224b5ff90adf3b7970a5a5e379f8988241ee8145969e27c32c17e7
3
  size 1105844337
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:fc6244a274d1479e33f4779949f98cefeb3108b77fdc9b79c33b92295c5141d4
3
  size 1105844337
radio_model.py CHANGED
@@ -1,10 +1,11 @@
1
- # Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
2
  #
3
  # NVIDIA CORPORATION and its licensors retain all intellectual property
4
  # and proprietary rights in and to this software, related documentation
5
  # and any modifications thereto. Any use, reproduction, disclosure or
6
  # distribution of this software and related documentation without an express
7
  # license agreement from NVIDIA CORPORATION is strictly prohibited.
 
8
 
9
  import torch
10
  from torch import nn
@@ -13,6 +14,8 @@ from timm.models import create_model, VisionTransformer
13
 
14
  from .enable_cpe_support import enable_cpe
15
  from .input_conditioner import InputConditioner
 
 
16
 
17
 
18
  class RADIOModel(nn.Module):
@@ -22,6 +25,7 @@ class RADIOModel(nn.Module):
22
  input_conditioner: InputConditioner,
23
  return_summary: bool,
24
  return_spatial_features: bool,
 
25
  ):
26
  super().__init__()
27
 
@@ -29,6 +33,24 @@ class RADIOModel(nn.Module):
29
  self.input_conditioner = input_conditioner
30
  self.return_summary = return_summary
31
  self.return_spatial_features = return_spatial_features
 
 
 
32
 
33
  def forward(self, x: torch.Tensor):
34
  x = self.input_conditioner(x)
@@ -40,7 +62,13 @@ class RADIOModel(nn.Module):
40
  elif isinstance(self.model, VisionTransformer):
41
  patch_gen = getattr(self.model, "patch_generator", None)
42
  if patch_gen is not None:
43
- summary = y[:, : patch_gen.num_cls_tokens].flatten(1)
 
 
44
  all_feat = y[:, patch_gen.num_skip :]
45
  elif self.model.global_pool == "avg":
46
  summary = y[:, self.model.num_prefix_tokens :].mean(dim=1)
@@ -51,7 +79,7 @@ class RADIOModel(nn.Module):
51
  else:
52
  raise ValueError("Unsupported model type")
53
 
54
- if self.return_summary and self.return_spatial_features:
55
  return summary, all_feat
56
  elif self.return_summary:
57
  return summary
 
1
+ # Copyright (c) 2023-2024, NVIDIA CORPORATION. All rights reserved.
2
  #
3
  # NVIDIA CORPORATION and its licensors retain all intellectual property
4
  # and proprietary rights in and to this software, related documentation
5
  # and any modifications thereto. Any use, reproduction, disclosure or
6
  # distribution of this software and related documentation without an express
7
  # license agreement from NVIDIA CORPORATION is strictly prohibited.
8
+ from typing import Optional
9
 
10
  import torch
11
  from torch import nn
 
14
 
15
  from .enable_cpe_support import enable_cpe
16
  from .input_conditioner import InputConditioner
17
+ # Register extra models
18
+ from . import extra_timm_models
19
 
20
 
21
  class RADIOModel(nn.Module):
 
25
  input_conditioner: InputConditioner,
26
  return_summary: bool,
27
  return_spatial_features: bool,
28
+ summary_idxs: Optional[torch.Tensor] = None,
29
  ):
30
  super().__init__()
31
 
 
33
  self.input_conditioner = input_conditioner
34
  self.return_summary = return_summary
35
  self.return_spatial_features = return_spatial_features
36
+ self.summary_select_idx = -1
37
+ if summary_idxs is not None:
38
+ self.register_buffer('summary_idxs', summary_idxs)
39
+ else:
40
+ self.summary_idxs = None
41
+
42
+ @property
43
+ def return_both(self):
44
+ return self.return_summary and self.return_spatial_features
45
+
46
+ @property
47
+ def num_summary_tokens(self):
48
+ patch_gen = getattr(self.model, "patch_generator", None)
49
+ if patch_gen is not None:
50
+ return patch_gen.num_skip
51
+ elif self.model.global_pool == 'avg':
52
+ return 0
53
+ return 1
54
 
55
  def forward(self, x: torch.Tensor):
56
  x = self.input_conditioner(x)
 
62
  elif isinstance(self.model, VisionTransformer):
63
  patch_gen = getattr(self.model, "patch_generator", None)
64
  if patch_gen is not None:
65
+ summary = y[:, : patch_gen.num_cls_tokens]
66
+ if self.summary_select_idx >= 0:
67
+ summary = summary[:, self.summary_select_idx]
68
+ elif self.summary_idxs is not None:
69
+ summary = summary[:, self.summary_idxs].flatten(1)
70
+ else:
71
+ summary = summary.flatten(1)
72
  all_feat = y[:, patch_gen.num_skip :]
73
  elif self.model.global_pool == "avg":
74
  summary = y[:, self.model.num_prefix_tokens :].mean(dim=1)
 
79
  else:
80
  raise ValueError("Unsupported model type")
81
 
82
+ if self.return_both:
83
  return summary, all_feat
84
  elif self.return_summary:
85
  return summary
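The new summary_idxs buffer selects a subset of the summary (CLS) tokens before they are flattened into the summary vector. A standalone sketch of just that indexing step, with invented shapes:

import torch

batch, num_cls_tokens, num_patches, dim = 2, 4, 16, 8
y = torch.randn(batch, num_cls_tokens + num_patches, dim)  # stand-in for the backbone output
summary_idxs = torch.tensor([0, 2])                         # keep only the 1st and 3rd summary tokens

summary = y[:, :num_cls_tokens]                 # [2, 4, 8]
summary = summary[:, summary_idxs].flatten(1)   # [2, 16]: selected tokens, concatenated per sample
all_feat = y[:, num_cls_tokens:]                # [2, 16, 8] (the model slices with patch_gen.num_skip here)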
vit_patch_generator.py CHANGED
@@ -1,4 +1,4 @@
1
- # Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
2
  #
3
  # NVIDIA CORPORATION and its licensors retain all intellectual property
4
  # and proprietary rights in and to this software, related documentation
@@ -224,12 +224,12 @@ class ViTPatchGenerator(nn.Module):
224
  grid_xy.mul_(2).sub_(1)
225
 
226
  pos_embed = F.grid_sample(
227
- pos_embed.expand(batch_size, -1, -1, -1),
228
  grid=grid_xy,
229
  mode='bilinear',
230
  padding_mode='zeros',
231
  align_corners=True,
232
- )
233
  else:
234
  # i_rows, i_cols = input_dims
235
  # p_rows, p_cols = pos_embed.shape[2:]
 
1
+ # Copyright (c) 2023-2024, NVIDIA CORPORATION. All rights reserved.
2
  #
3
  # NVIDIA CORPORATION and its licensors retain all intellectual property
4
  # and proprietary rights in and to this software, related documentation
 
224
  grid_xy.mul_(2).sub_(1)
225
 
226
  pos_embed = F.grid_sample(
227
+ pos_embed.float().expand(batch_size, -1, -1, -1),
228
  grid=grid_xy,
229
  mode='bilinear',
230
  padding_mode='zeros',
231
  align_corners=True,
232
+ ).to(pos_embed.dtype)
233
  else:
234
  # i_rows, i_cols = input_dims
235
  # p_rows, p_cols = pos_embed.shape[2:]
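The change above computes the positional-embedding interpolation in float32 and casts the result back, presumably to avoid dtype mismatches and precision loss when the table (or the autocast context) is in half precision. A self-contained sketch of the same up-cast/down-cast pattern, with illustrative shapes only:

import torch
import torch.nn.functional as F

pos_embed = torch.randn(1, 768, 16, 16, dtype=torch.bfloat16)  # low-precision embedding table
grid_xy = torch.rand(4, 32, 32, 2) * 2 - 1                     # sampling grid in [-1, 1], float32

sampled = F.grid_sample(
    pos_embed.float().expand(4, -1, -1, -1),  # interpolate in float32
    grid=grid_xy,
    mode='bilinear',
    padding_mode='zeros',
    align_corners=True,
).to(pos_embed.dtype)                         # cast back to the original dtype

print(sampled.shape, sampled.dtype)           # torch.Size([4, 768, 32, 32]) torch.bfloat16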