gheinrich committed
Commit f134378
Parent: 23de0da

Upload model

Files changed (2)
  1. eradio_model.py +889 -555
  2. pytorch_model.bin +2 -2
eradio_model.py CHANGED
@@ -19,19 +19,25 @@ from timm.models.layers import trunc_normal_, DropPath, LayerNorm2d
 import numpy as np
 import torch.nn.functional as F
 from .block import C2f
-TRT = False # should help for TRT
 
 import pickle
 global bias_indx
 bias_indx = -1
 DEBUG = False
 
 
-
 def pixel_unshuffle(data, factor=2):
     # performs nn.PixelShuffle(factor) in reverse, torch has some bug for ONNX and TRT, so doing it manually
     B, C, H, W = data.shape
-    return data.view(B, C, factor, H//factor, factor, W//factor).permute(0,1,2,4,3,5).reshape(B, -1, H//factor, W//factor)
 
 class SwiGLU(nn.Module):
     # should be more advanced, but doesnt improve results so far
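For reference, a self-contained sketch of the `pixel_unshuffle` helper edited above, with the shape bookkeeping spelled out. This is the model's own space-to-depth reshuffle: it produces the same shapes as `nn.PixelUnshuffle`, but the channel/spatial interleaving is its own and is not guaranteed to match torch's built-in element-for-element.

```python
import torch

def pixel_unshuffle(data, factor=2):
    # Fold spatial resolution into channels: (B, C, H, W) -> (B, C*factor**2, H/factor, W/factor).
    B, C, H, W = data.shape
    return (
        data.view(B, C, factor, H // factor, factor, W // factor)
        .permute(0, 1, 2, 4, 3, 5)
        .reshape(B, -1, H // factor, W // factor)
    )

x = torch.randn(2, 8, 16, 16)
print(pixel_unshuffle(x, factor=2).shape)  # torch.Size([2, 32, 8, 8])
```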
@@ -51,7 +57,7 @@ def window_partition(x, window_size):
     """
     B, C, H, W = x.shape
 
-    if window_size == 0 or (window_size==H and window_size==W):
         windows = x.flatten(2).transpose(1, 2)
         Hp, Wp = H, W
     else:
@@ -62,23 +68,38 @@ def window_partition(x, window_size):
         Hp, Wp = H + pad_h, W + pad_w
 
         x = x.view(B, C, Hp // window_size, window_size, Wp // window_size, window_size)
-        windows = x.permute(0, 2, 4, 3, 5, 1).reshape(-1, window_size*window_size, C)
 
     return windows, (Hp, Wp)
 
 class Conv2d_BN(nn.Module):
-    '''
     Conv2d + BN layer with folding capability to speed up inference
-    '''
-    def __init__(self, a, b, kernel_size=1, stride=1, padding=0, dilation=1, groups=1, bn_weight_init=1, bias=False):
         super().__init__()
-        self.conv = torch.nn.Conv2d(a, b, kernel_size, stride, padding, dilation, groups, bias=False)
         if 1:
             self.bn = torch.nn.BatchNorm2d(b)
             torch.nn.init.constant_(self.bn.weight, bn_weight_init)
             torch.nn.init.constant_(self.bn.bias, 0)
 
-    def forward(self,x):
         x = self.conv(x)
         x = self.bn(x)
         return x
@@ -91,14 +112,12 @@ class Conv2d_BN(nn.Module):
         c, bn = self.conv, self.bn
         w = bn.weight / (bn.running_var + bn.eps) ** 0.5
         w = c.weight * w[:, None, None, None]
-        b = bn.bias - bn.running_mean * bn.weight / \
-            (bn.running_var + bn.eps)**0.5
         self.conv.weight.data.copy_(w)
         self.conv.bias = nn.Parameter(b)
         self.bn = nn.Identity()
 
 
-
 def window_reverse(windows, window_size, H, W, pad_hw):
     """
     Args:
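The `fuse()` body in the hunk above folds the BatchNorm into the preceding bias-free convolution: for y = gamma * (conv(x) - mean) / sqrt(var + eps) + beta, the fused layer uses w' = w * gamma / sqrt(var + eps) and b' = beta - mean * gamma / sqrt(var + eps). A standalone sketch of that algebra (the `conv`/`bn`/`fused` names are illustrative, not from the file):

```python
import torch
import torch.nn as nn

conv = nn.Conv2d(8, 16, 3, padding=1, bias=False)
bn = nn.BatchNorm2d(16)
# Give the BN non-trivial statistics so the equivalence check is meaningful.
bn.running_mean.uniform_(-1.0, 1.0)
bn.running_var.uniform_(0.5, 2.0)
bn.weight.data.uniform_(0.5, 1.5)
bn.bias.data.uniform_(-1.0, 1.0)
bn.eval()  # folding assumes frozen (inference-time) statistics

scale = bn.weight / (bn.running_var + bn.eps) ** 0.5
fused = nn.Conv2d(8, 16, 3, padding=1, bias=True)
fused.weight.data.copy_(conv.weight * scale[:, None, None, None])
fused.bias.data.copy_(bn.bias - bn.running_mean * scale)

x = torch.randn(1, 8, 32, 32)
with torch.no_grad():
    assert torch.allclose(fused(x), bn(conv(x)), atol=1e-5)
```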
@@ -113,48 +132,63 @@ def window_reverse(windows, window_size, H, W, pad_hw):
     """
     # print(f"window_reverse, windows.shape {windows.shape}")
     Hp, Wp = pad_hw
-    if window_size == 0 or (window_size==H and window_size==W):
         B = int(windows.shape[0] / (Hp * Wp / window_size / window_size))
         x = windows.transpose(1, 2).view(B, -1, H, W)
     else:
         B = int(windows.shape[0] / (Hp * Wp / window_size / window_size))
-        x = windows.view(B, Hp // window_size, Wp // window_size, window_size, window_size, -1)
-        x = x.permute(0, 5, 1, 3, 2, 4).reshape(B,windows.shape[2], Hp, Wp)
 
     if Hp > H or Wp > W:
-        x = x[:, :, :H, :W, ].contiguous()
 
     return x
 
 
-
 class PosEmbMLPSwinv2D(nn.Module):
-    def __init__(self, window_size, pretrained_window_size, num_heads, seq_length, no_log=False):
         super().__init__()
         self.window_size = window_size
         self.num_heads = num_heads
         # mlp to generate continuous relative position bias
-        self.cpb_mlp = nn.Sequential(nn.Linear(2, 512, bias=True),
-                                     nn.ReLU(inplace=True),
-                                     nn.Linear(512, num_heads, bias=False))
 
         # get relative_coords_table
-        relative_coords_h = torch.arange(-(self.window_size[0] - 1), self.window_size[0], dtype=torch.float32)
-        relative_coords_w = torch.arange(-(self.window_size[1] - 1), self.window_size[1], dtype=torch.float32)
-        relative_coords_table = torch.stack(
-            torch.meshgrid([relative_coords_h,
-                            relative_coords_w])).permute(1, 2, 0).contiguous().unsqueeze(0)  # 1, 2*Wh-1, 2*Ww-1, 2
         if pretrained_window_size[0] > 0:
-            relative_coords_table[:, :, :, 0] /= (pretrained_window_size[0] - 1)
-            relative_coords_table[:, :, :, 1] /= (pretrained_window_size[1] - 1)
         else:
-            relative_coords_table[:, :, :, 0] /= (self.window_size[0] - 1)
-            relative_coords_table[:, :, :, 1] /= (self.window_size[1] - 1)
 
         if not no_log:
             relative_coords_table *= 8 # normalize to -8, 8
-            relative_coords_table = torch.sign(relative_coords_table) * torch.log2(
-                torch.abs(relative_coords_table) + 1.0) / np.log2(8)
 
         self.register_buffer("relative_coords_table", relative_coords_table)
 
@@ -163,8 +197,12 @@ class PosEmbMLPSwinv2D(nn.Module):
         coords_w = torch.arange(self.window_size[1])
         coords = torch.stack(torch.meshgrid([coords_h, coords_w]))  # 2, Wh, Ww
         coords_flatten = torch.flatten(coords, 1)  # 2, Wh*Ww
-        relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # 2, Wh*Ww, Wh*Ww
-        relative_coords = relative_coords.permute(1, 2, 0).contiguous()  # Wh*Ww, Wh*Ww, 2
         relative_coords[:, :, 0] += self.window_size[0] - 1  # shift to start from 0
         relative_coords[:, :, 1] += self.window_size[1] - 1
         relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
@@ -177,7 +215,7 @@ class PosEmbMLPSwinv2D(nn.Module):
 
         relative_bias = torch.zeros(1, num_heads, seq_length, seq_length)
         self.seq_length = seq_length
-        self.register_buffer("relative_bias", relative_bias) #for EMA
 
     def switch_to_deploy(self):
         self.deploy = True
@@ -191,7 +229,8 @@ class PosEmbMLPSwinv2D(nn.Module):
         # if not (input_tensor.shape[1:] == self.relative_bias.shape[1:]):
         #     self.grid_exists = False
 
-        if self.training: self.grid_exists = False
 
         if self.deploy and self.grid_exists:
             input_tensor += self.relative_bias
@@ -200,12 +239,20 @@ class PosEmbMLPSwinv2D(nn.Module):
         if not self.grid_exists:
             self.grid_exists = True
 
-            relative_position_bias_table = self.cpb_mlp(self.relative_coords_table).view(-1, self.num_heads)
-            relative_position_bias = relative_position_bias_table[self.relative_position_index.view(-1)].view(
-                self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1],
-                -1)  # Wh*Ww,Wh*Ww,nH
-
-            relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()  # nH, Wh*Ww, Wh*Ww
             relative_position_bias = 16 * torch.sigmoid(relative_position_bias)
 
             self.relative_bias = relative_position_bias.unsqueeze(0)
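For orientation, a small sketch of the Swin-v2-style log-spaced relative coordinate table that `PosEmbMLPSwinv2D` builds above and then feeds through `cpb_mlp`; the 7x7 window size here is purely illustrative, and `indexing="ij"` is passed explicitly where the original relies on the older meshgrid default.

```python
import torch

Wh = Ww = 7
h = torch.arange(-(Wh - 1), Wh, dtype=torch.float32)
w = torch.arange(-(Ww - 1), Ww, dtype=torch.float32)
table = torch.stack(torch.meshgrid([h, w], indexing="ij")).permute(1, 2, 0).unsqueeze(0)
table[..., 0] /= Wh - 1
table[..., 1] /= Ww - 1
table *= 8  # coordinates now span [-8, 8]
table = torch.sign(table) * torch.log2(table.abs() + 1.0) / torch.log2(torch.tensor(8.0))
print(table.shape)        # torch.Size([1, 13, 13, 2])
print(table.abs().max())  # ~1.06: inputs to the bias MLP are squashed to roughly [-1, 1]
```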
@@ -214,14 +261,25 @@ class PosEmbMLPSwinv2D(nn.Module):
214
  return input_tensor
215
 
216
 
217
-
218
  class GRAAttentionBlock(nn.Module):
219
- def __init__(self, window_size, dim_in, dim_out,
220
- num_heads, drop_path=0., qk_scale=None, qkv_bias=False,
221
- norm_layer=nn.LayerNorm, layer_scale=None,
222
- use_swiglu=True,
223
- subsample_ratio=1, dim_ratio=1, conv_base=False,
224
- do_windowing=True, multi_query=False) -> None:
 
 
 
 
 
 
 
 
 
 
 
 
225
  super().__init__()
226
 
227
  dim = dim_in
@@ -232,49 +290,131 @@ class GRAAttentionBlock(nn.Module):
232
 
233
  if do_windowing:
234
  if SHUFFLE:
235
- self.downsample_op = torch.nn.PixelUnshuffle(subsample_ratio) if subsample_ratio>1 else torch.nn.Identity()
236
- self.downsample_mixer = nn.Conv2d(dim_in * (subsample_ratio * subsample_ratio), dim_in * (dim_ratio), kernel_size=1, stride=1, padding=0, bias=False) if dim*dim_ratio != dim * subsample_ratio * subsample_ratio else torch.nn.Identity()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
237
  else:
238
  if conv_base:
239
- self.downsample_op = nn.Conv2d(dim_in, dim_out, kernel_size=subsample_ratio, stride=subsample_ratio) if subsample_ratio > 1 else nn.Identity()
 
 
 
 
 
 
 
 
 
240
  self.downsample_mixer = nn.Identity()
241
  else:
242
- self.downsample_op = nn.AvgPool2d(kernel_size=subsample_ratio, stride=subsample_ratio) if subsample_ratio > 1 else nn.Identity()
243
- self.downsample_mixer = Conv2d_BN(dim_in, dim_out, kernel_size=1, stride=1) if subsample_ratio > 1 else nn.Identity()
244
-
 
 
 
 
 
 
 
 
 
245
 
246
  if do_windowing:
247
  if SHUFFLE:
248
- self.upsample_mixer =nn.Conv2d(dim_in * dim_ratio, dim_in * (subsample_ratio * subsample_ratio), kernel_size=1, stride=1, padding=0, bias=False) if dim*dim_ratio != dim * subsample_ratio * subsample_ratio else torch.nn.Identity()
249
- self.upsample_op = torch.nn.PixelShuffle(subsample_ratio) if subsample_ratio>1 else torch.nn.Identity()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
250
  else:
251
  if conv_base:
252
  self.upsample_mixer = nn.Identity()
253
- self.upsample_op = nn.ConvTranspose2d(dim_in, dim_out, kernel_size=subsample_ratio, stride=subsample_ratio) if subsample_ratio > 1 else nn.Identity()
 
 
 
 
 
 
 
 
 
254
  else:
255
- self.upsample_mixer = nn.Upsample(scale_factor=subsample_ratio, mode='nearest') if subsample_ratio > 1 else nn.Identity()
256
- self.upsample_op = Conv2d_BN(dim_in, dim_out, kernel_size=1, stride=1, padding=0, bias=False) if subsample_ratio > 1 else nn.Identity()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
257
 
258
  self.window_size = window_size
259
 
260
  self.norm1 = norm_layer(dim_in)
261
  if DEBUG:
262
- print(f"GRAAttentionBlock: input_resolution: , window_size: {window_size}, dim_in: {dim_in}, dim_out: {dim_out}, num_heads: {num_heads}, drop_path: {drop_path}, qk_scale: {qk_scale}, qkv_bias: {qkv_bias}, layer_scale: {layer_scale}")
263
-
 
264
 
265
  self.attn = WindowAttention(
266
  dim_in,
267
- num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale,
 
 
268
  resolution=window_size,
269
- seq_length=window_size**2, dim_out=dim_in, multi_query=multi_query)
 
 
 
270
  if DEBUG:
271
- print(f"Attention: dim_in: {dim_in}, num_heads: {num_heads}, qkv_bias: {qkv_bias}, qk_scale: {qk_scale}, resolution: {window_size}, seq_length: {window_size**2}, dim_out: {dim_in}")
 
 
272
  print(f"drop_path: {drop_path}, layer_scale: {layer_scale}")
273
 
274
- self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
275
 
276
  use_layer_scale = layer_scale is not None and type(layer_scale) in [int, float]
277
- self.gamma1 = nn.Parameter(layer_scale * torch.ones(dim_in)) if use_layer_scale else 1
 
 
278
 
279
  ### mlp layer
280
  mlp_ratio = 4
@@ -282,17 +422,27 @@ class GRAAttentionBlock(nn.Module):
282
  mlp_hidden_dim = int(dim_in * mlp_ratio)
283
 
284
  activation = nn.GELU if not use_swiglu else SwiGLU
285
- mlp_hidden_dim = int((4 * dim_in * 1 / 2) / 64) * 64 if use_swiglu else mlp_hidden_dim
286
-
287
- self.mlp = Mlp(in_features=dim_in, hidden_features=mlp_hidden_dim, act_layer=activation, use_swiglu=use_swiglu)
288
-
289
- self.gamma2 = nn.Parameter(layer_scale * torch.ones(dim_in)) if layer_scale else 1
290
- self.drop_path2=DropPath(drop_path) if drop_path > 0. else nn.Identity()
 
 
 
 
 
 
 
 
 
291
  if DEBUG:
292
- print(f"MLP layer: dim_in: {dim_in}, dim_out: {dim_in}, mlp_hidden_dim: {mlp_hidden_dim}")
 
 
293
  print(f"drop_path: {drop_path}, layer_scale: {layer_scale}")
294
 
295
-
296
  def forward(self, x):
297
  skip_connection = x
298
 
@@ -301,15 +451,15 @@ class GRAAttentionBlock(nn.Module):
301
  x = self.downsample_op(x)
302
  x = self.downsample_mixer(x)
303
 
304
- if self.window_size>0:
305
  H, W = x.shape[2], x.shape[3]
306
 
307
  x, pad_hw = window_partition(x, self.window_size)
308
 
309
  # window attention
310
- x = x + self.drop_path1(self.gamma1*self.attn(self.norm1(x)))
311
  # mlp layer
312
- x = x + self.drop_path2(self.gamma2*self.mlp(self.norm2(x)))
313
 
314
  if self.do_windowing:
315
  if self.window_size > 0:
@@ -318,9 +468,19 @@ class GRAAttentionBlock(nn.Module):
318
  x = self.upsample_mixer(x)
319
  x = self.upsample_op(x)
320
 
321
-
322
- if x.shape[2] != skip_connection.shape[2] or x.shape[3] != skip_connection.shape[3]:
323
- x = torch.nn.functional.pad(x, ( 0, -x.shape[3] + skip_connection.shape[3], 0, -x.shape[2] + skip_connection.shape[2]))
 
 
 
 
 
 
 
 
 
 
324
  # need to add skip connection because downsampling and upsampling will break residual connection
325
  # 0.5 is needed to make sure that the skip connection is not too strong
326
  # in case of no downsample / upsample we can show that 0.5 compensates for the residual connection
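One way to read the 0.5 mentioned in the comments above (a sketch of the reasoning; the exact combination line is not shown in this hunk, but presumably has the form `x = 0.5 * x + 0.5 * skip_connection`): when the down/upsampling ops are identities, the windowed branch already returns `skip + f(skip)` because of its internal residual adds, so adding the outer skip at full strength would yield `2 * skip + f(skip)` and double-count the identity path. Averaging the two terms instead gives `skip + 0.5 * f(skip)`, keeping the identity contribution at unit scale.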
@@ -329,8 +489,6 @@ class GRAAttentionBlock(nn.Module):
329
  return x
330
 
331
 
332
-
333
-
334
  class MultiResolutionAttention(nn.Module):
335
  """
336
  MultiResolutionAttention (MRA) module
@@ -340,12 +498,23 @@ class MultiResolutionAttention(nn.Module):
340
 
341
  """
342
 
343
- def __init__(self, window_size, sr_ratio,
344
- dim, dim_ratio, num_heads,
345
- do_windowing=True,
346
- layer_scale=1e-5, norm_layer=nn.LayerNorm,
347
- drop_path = 0, qkv_bias=False, qk_scale=1.0,
348
- use_swiglu=True, multi_query=False, conv_base=False) -> None:
 
 
 
 
 
 
 
 
 
 
 
349
  """
350
  Args:
351
  input_resolution: input image resolution
@@ -357,10 +526,8 @@ class MultiResolutionAttention(nn.Module):
357
 
358
  depth = len(sr_ratio)
359
 
360
-
361
  self.attention_blocks = nn.ModuleList()
362
 
363
-
364
  for i in range(depth):
365
  subsample_ratio = sr_ratio[i]
366
  if len(window_size) > i:
@@ -368,15 +535,25 @@ class MultiResolutionAttention(nn.Module):
368
  else:
369
  window_size_local = window_size[0]
370
 
371
- self.attention_blocks.append(GRAAttentionBlock(window_size=window_size_local,
372
- dim_in=dim, dim_out=dim, num_heads=num_heads,
373
- qkv_bias=qkv_bias, qk_scale=qk_scale, norm_layer=norm_layer,
374
- layer_scale=layer_scale, drop_path=drop_path,
375
- use_swiglu=use_swiglu, subsample_ratio=subsample_ratio, dim_ratio=dim_ratio,
376
- do_windowing=do_windowing, multi_query=multi_query, conv_base=conv_base),
377
- )
378
-
379
-
 
 
 
 
 
 
 
 
 
 
380
 
381
  def forward(self, x):
382
 
@@ -386,19 +563,20 @@ class MultiResolutionAttention(nn.Module):
386
  return x
387
 
388
 
389
-
390
  class Mlp(nn.Module):
391
  """
392
  Multi-Layer Perceptron (MLP) block
393
  """
394
 
395
- def __init__(self,
396
- in_features,
397
- hidden_features=None,
398
- out_features=None,
399
- act_layer=nn.GELU,
400
- use_swiglu=True,
401
- drop=0.):
 
 
402
  """
403
  Args:
404
  in_features: input features dimension.
@@ -411,7 +589,9 @@ class Mlp(nn.Module):
411
  super().__init__()
412
  out_features = out_features or in_features
413
  hidden_features = hidden_features or in_features
414
- self.fc1 = nn.Linear(in_features, hidden_features * (2 if use_swiglu else 1), bias=False)
 
 
415
  self.act = act_layer()
416
  self.fc2 = nn.Linear(hidden_features, out_features, bias=False)
417
  # self.drop = GaussianDropout(drop)
@@ -427,6 +607,7 @@ class Mlp(nn.Module):
427
  x = x.view(x_size)
428
  return x
429
 
 
430
  class Downsample(nn.Module):
431
  """
432
  Down-sampling block
@@ -434,10 +615,9 @@ class Downsample(nn.Module):
434
  Pixel Unshuffle is used for down-sampling, works great accuracy - wise but takes 10% more TRT time
435
  """
436
 
437
- def __init__(self,
438
- dim,
439
- shuffle = False,
440
- ):
441
  """
442
  Args:
443
  dim: feature size dimension.
@@ -450,14 +630,13 @@ class Downsample(nn.Module):
450
 
451
  if shuffle:
452
  self.norm = lambda x: pixel_unshuffle(x, factor=2)
453
- self.reduction = Conv2d_BN(dim*4, dim_out, 1, 1, 0, bias=False)
454
  else:
455
- #removed layer norm for better, in this formulation we are getting 10% better speed
456
  # LayerNorm for high resolution inputs will be a pain as it pools over the entire spatial dimension
457
  self.norm = nn.Identity()
458
  self.reduction = Conv2d_BN(dim, dim_out, 3, 2, 1, bias=False)
459
 
460
-
461
  def forward(self, x):
462
  x = self.norm(x)
463
  x = self.reduction(x)
@@ -486,8 +665,8 @@ class PatchEmbed(nn.Module):
486
  Conv2d_BN(in_chans, in_dim, 3, 2, 1, bias=False),
487
  nn.ReLU(),
488
  Conv2d_BN(in_dim, dim, 3, 2, 1, bias=False),
489
- nn.ReLU()
490
- )
491
  else:
492
  self.proj = lambda x: pixel_unshuffle(x, factor=4)
493
 
@@ -496,9 +675,9 @@ class PatchEmbed(nn.Module):
496
  # Conv2d_BN(in_dim, dim, 3, 1, 1),
497
  # nn.SiLU(),
498
  # )
499
- self.conv_down = nn.Sequential(Conv2d_BN(in_chans*16, dim, 3, 1, 1),
500
- nn.ReLU(),
501
- )
502
 
503
  def forward(self, x):
504
  x = self.proj(x)
@@ -506,7 +685,6 @@ class PatchEmbed(nn.Module):
506
  return x
507
 
508
 
509
-
510
  class ConvBlock(nn.Module):
511
  """
512
  Convolutional block, used in first couple of stages
@@ -514,24 +692,30 @@ class ConvBlock(nn.Module):
514
  Experimented with RepVGG, dont see significant improvement in accuracy
515
  Finally, YOLOv8 idea seem to work fine (resnet-18 like block with squeezed feature dimension, and feature concatendation at the end)
516
  """
517
- def __init__(self, dim,
518
- drop_path=0.,
519
- layer_scale=None,
520
- kernel_size=3,
521
- rep_vgg=False):
522
  super().__init__()
523
  self.rep_vgg = rep_vgg
524
  if not rep_vgg:
525
- self.conv1 = Conv2d_BN(dim, dim, kernel_size=kernel_size, stride=1, padding=1)
 
 
526
  self.act1 = nn.GELU()
527
  else:
528
- self.conv1 = RepVGGBlock(dim, dim, kernel_size=kernel_size, stride=1, padding=1, groups=1)
529
-
 
530
 
531
  if not rep_vgg:
532
- self.conv2 = Conv2d_BN(dim, dim, kernel_size=kernel_size, stride=1, padding=1)
 
 
533
  else:
534
- self.conv2 = RepVGGBlock(dim, dim, kernel_size=kernel_size, stride=1, padding=1, groups=1)
 
 
535
 
536
  self.layer_scale = layer_scale
537
  if layer_scale is not None and type(layer_scale) in [int, float]:
@@ -539,7 +723,7 @@ class ConvBlock(nn.Module):
539
  self.layer_scale = True
540
  else:
541
  self.layer_scale = False
542
- self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
543
 
544
  def forward(self, x):
545
  input = x
@@ -563,11 +747,21 @@ class WindowAttention(nn.Module):
563
  # look into palm: https://github.com/lucidrains/PaLM-pytorch/blob/main/palm_pytorch/palm_pytorch.py
564
  # single kv attention, mlp in parallel (didnt improve speed)
565
 
566
- def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, resolution=0,
567
- seq_length=0, dim_out=None, multi_query=False):
 
 
 
 
 
 
 
 
 
568
  # taken from EdgeViT and tweaked with attention bias.
569
  super().__init__()
570
- if not dim_out: dim_out = dim
 
571
  self.multi_query = multi_query
572
  self.num_heads = num_heads
573
  head_dim = dim // num_heads
@@ -584,14 +778,16 @@ class WindowAttention(nn.Module):
584
  else:
585
  self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
586
  else:
587
- self.qkv = nn.Linear(dim, dim + 2*self.head_dim, bias=qkv_bias)
588
 
589
  self.proj = nn.Linear(dim, dim_out, bias=False)
590
  # attention positional bias
591
- self.pos_emb_funct = PosEmbMLPSwinv2D(window_size=[resolution, resolution],
592
- pretrained_window_size=[resolution, resolution],
593
- num_heads=num_heads,
594
- seq_length=seq_length)
 
 
595
 
596
  self.resolution = resolution
597
 
@@ -600,17 +796,37 @@ class WindowAttention(nn.Module):
600
 
601
  if not self.multi_query:
602
  if TRT:
603
- q = self.q(x).reshape(B, -1, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)
604
- k = self.k(x).reshape(B, -1, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)
605
- v = self.v(x).reshape(B, -1, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)
 
 
 
 
 
 
 
 
 
 
 
 
606
  else:
607
- qkv = self.qkv(x).reshape(B, -1, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
 
 
 
 
608
  q, k, v = qkv[0], qkv[1], qkv[2]
609
  else:
610
  qkv = self.qkv(x)
611
- (q, k, v) = qkv.split([self.dim_internal, self.head_dim, self.head_dim], dim=2)
 
 
612
 
613
- q = q.reshape(B, -1, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)
 
 
614
  k = k.reshape(B, -1, 1, C // self.num_heads).permute(0, 2, 1, 3)
615
  v = v.reshape(B, -1, 1, C // self.num_heads).permute(0, 2, 1, 3)
616
 
@@ -624,35 +840,34 @@ class WindowAttention(nn.Module):
624
  return x
625
 
626
 
627
-
628
  class FasterViTLayer(nn.Module):
629
  """
630
  fastervitlayer
631
  """
632
 
633
- def __init__(self,
634
- dim,
635
- depth,
636
- num_heads,
637
- window_size,
638
- conv=False,
639
- downsample=True,
640
- mlp_ratio=4.,
641
- qkv_bias=False,
642
- qk_scale=None,
643
- norm_layer=nn.LayerNorm,
644
- drop_path=0.,
645
- layer_scale=None,
646
- layer_scale_conv=None,
647
- sr_dim_ratio=1,
648
- sr_ratio=1,
649
- multi_query=False,
650
- use_swiglu=True,
651
- rep_vgg=False,
652
- yolo_arch=False,
653
- downsample_shuffle=False,
654
- conv_base=False,
655
-
656
  ):
657
  """
658
  Args:
@@ -674,23 +889,33 @@ class FasterViTLayer(nn.Module):
674
 
675
  super().__init__()
676
  self.conv = conv
677
- self.yolo_arch=False
678
  if conv:
679
  if not yolo_arch:
680
- self.blocks = nn.ModuleList([
681
- ConvBlock(dim=dim,
682
- drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
683
- layer_scale=layer_scale_conv, rep_vgg=rep_vgg)
684
- for i in range(depth)])
 
 
 
 
 
 
 
 
685
  else:
686
- self.blocks = C2f(dim,dim,n=depth,shortcut=True,e=0.5)
687
- self.yolo_arch=True
688
  else:
689
- if not isinstance(window_size, list): window_size = [window_size]
 
690
  self.window_size = window_size[0]
691
  self.do_single_windowing = True
692
- if not isinstance(sr_ratio, list): sr_ratio = [sr_ratio]
693
- if any([sr!=1 for sr in sr_ratio]) or len(set(window_size))>1:
 
694
  self.do_single_windowing = False
695
  do_windowing = True
696
  else:
@@ -701,29 +926,31 @@ class FasterViTLayer(nn.Module):
701
  for i in range(depth):
702
 
703
  self.blocks.append(
704
- MultiResolutionAttention(window_size=window_size,
705
- sr_ratio=sr_ratio,
706
- dim=dim,
707
- dim_ratio = sr_dim_ratio,
708
- num_heads=num_heads,
709
- norm_layer=norm_layer,
710
- drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
711
- layer_scale=layer_scale,
712
- qkv_bias=qkv_bias,
713
- qk_scale=qk_scale,
714
- use_swiglu=use_swiglu,
715
- do_windowing=do_windowing,
716
- multi_query=multi_query,
717
- conv_base=conv_base,
718
- ))
 
 
 
 
719
 
720
  self.transformer = not conv
721
 
722
-
723
- self.downsample = None if not downsample else Downsample(dim=dim, shuffle=downsample_shuffle)
724
-
725
-
726
-
727
 
728
  def forward(self, x):
729
  B, C, H, W = x.shape
@@ -741,11 +968,10 @@ class FasterViTLayer(nn.Module):
741
  if self.transformer and self.do_single_windowing:
742
  x = window_reverse(x, self.window_size, H, W, pad_hw)
743
 
744
-
745
  if self.downsample is None:
746
  return x, x
747
 
748
- return self.downsample(x), x #changing to output pre downsampled features
749
 
750
 
751
  class FasterViT(nn.Module):
@@ -753,37 +979,39 @@ class FasterViT(nn.Module):
753
  FasterViT
754
  """
755
 
756
- def __init__(self,
757
- dim,
758
- in_dim,
759
- depths,
760
- window_size,
761
- mlp_ratio,
762
- num_heads,
763
- drop_path_rate=0.2,
764
- in_chans=3,
765
- num_classes=1000,
766
- qkv_bias=False,
767
- qk_scale=None,
768
- layer_scale=None,
769
- layer_scale_conv=None,
770
- layer_norm_last=False,
771
- sr_ratio = [1, 1, 1, 1],
772
- max_depth = -1,
773
- conv_base=False,
774
- use_swiglu=False,
775
- multi_query=False,
776
- norm_layer=nn.LayerNorm,
777
- rep_vgg=False,
778
- drop_uniform=False,
779
- yolo_arch=False,
780
- shuffle_down=False,
781
- downsample_shuffle=False,
782
- return_full_features=False,
783
- full_features_head_dim=128,
784
- neck_start_stage=1,
785
- use_neck=False,
786
- **kwargs):
 
 
787
  """
788
  Args:
789
  dim: feature size dimension.
@@ -811,7 +1039,9 @@ class FasterViT(nn.Module):
811
 
812
  num_features = int(dim * 2 ** (len(depths) - 1))
813
  self.num_classes = num_classes
814
- self.patch_embed = PatchEmbed(in_chans=in_chans, in_dim=in_dim, dim=dim, shuffle_down=shuffle_down)
 
 
815
  # set return_full_features true if we want to return full features from all stages
816
  self.return_full_features = return_full_features
817
  self.use_neck = use_neck
@@ -820,32 +1050,35 @@ class FasterViT(nn.Module):
820
  if drop_uniform:
821
  dpr = [drop_path_rate for x in range(sum(depths))]
822
 
823
- if not isinstance(max_depth, list): max_depth = [max_depth] * len(depths)
 
824
 
825
  self.levels = nn.ModuleList()
826
  for i in range(len(depths)):
827
  conv = True if (i == 0 or i == 1) else False
828
 
829
- level = FasterViTLayer(dim=int(dim * 2 ** i),
830
- depth=depths[i],
831
- num_heads=num_heads[i],
832
- window_size=window_size[i],
833
- mlp_ratio=mlp_ratio,
834
- qkv_bias=qkv_bias,
835
- qk_scale=qk_scale,
836
- conv=conv,
837
- drop_path=dpr[sum(depths[:i]):sum(depths[:i + 1])],
838
- downsample=(i < 3),
839
- layer_scale=layer_scale,
840
- layer_scale_conv=layer_scale_conv,
841
- sr_ratio=sr_ratio[i],
842
- use_swiglu=use_swiglu,
843
- multi_query=multi_query,
844
- norm_layer=norm_layer,
845
- rep_vgg=rep_vgg,
846
- yolo_arch=yolo_arch,
847
- downsample_shuffle=downsample_shuffle,
848
- conv_base=conv_base)
 
 
849
 
850
  self.levels.append(level)
851
 
@@ -857,50 +1090,84 @@ class FasterViT(nn.Module):
857
  for i in range(len(depths)):
858
  level_n_features_output = int(dim * 2 ** i)
859
 
860
- if self.neck_start_stage > i: continue
 
861
 
862
- if (upsample_ratio > 1) or full_features_head_dim!=level_n_features_output:
 
 
863
  feature_projection = nn.Sequential()
864
  # feature_projection.add_module("norm",LayerNorm2d(level_n_features_output)) #slow, but better
865
 
866
-
867
- if 0 :
868
  # Train: 0 [1900/10009 ( 19%)] Loss: 6.113 (6.57) Time: 0.548s, 233.40/s (0.549s, 233.04/s) LR: 1.000e-05 Data: 0.015 (0.013)
869
- feature_projection.add_module("norm",nn.BatchNorm2d(level_n_features_output)) #fast, but worse
870
- feature_projection.add_module("dconv", nn.ConvTranspose2d(level_n_features_output,
871
- full_features_head_dim, kernel_size=upsample_ratio, stride=upsample_ratio))
 
 
 
 
 
 
 
 
 
872
  else:
873
  # pixel shuffle based upsampling
874
  # Train: 0 [1950/10009 ( 19%)] Loss: 6.190 (6.55) Time: 0.540s, 236.85/s (0.548s, 233.38/s) LR: 1.000e-05 Data: 0.015 (0.013)
875
- feature_projection.add_module("norm",nn.BatchNorm2d(level_n_features_output)) #fast, but worse
876
- feature_projection.add_module("conv", nn.Conv2d(level_n_features_output,
877
- full_features_head_dim*upsample_ratio*upsample_ratio, kernel_size=1, stride=1))
878
- feature_projection.add_module("upsample_pixelshuffle", nn.PixelShuffle(upsample_ratio))
 
 
 
 
 
 
 
 
 
 
 
 
 
879
 
880
  else:
881
  feature_projection = nn.Sequential()
882
- feature_projection.add_module("norm",nn.BatchNorm2d(level_n_features_output))
883
-
 
884
 
885
  self.neck_features_proj.append(feature_projection)
886
 
887
- if i>0 and self.levels[i-1].downsample is not None:
888
  upsample_ratio *= 2
889
 
890
-
891
- num_features = full_features_head_dim if (self.return_full_features or self.use_neck) else num_features
 
 
 
892
 
893
  self.num_features = num_features
894
 
895
- self.norm = LayerNorm2d(num_features) if layer_norm_last else nn.BatchNorm2d(num_features)
 
 
 
 
896
  self.avgpool = nn.AdaptiveAvgPool2d(1)
897
- self.head = nn.Linear(num_features, num_classes) if num_classes > 0 else nn.Identity()
 
 
898
  self.apply(self._init_weights)
899
  # pass
900
 
901
  def _init_weights(self, m):
902
  if isinstance(m, nn.Linear):
903
- trunc_normal_(m.weight, std=.02)
904
  if isinstance(m, nn.Linear) and m.bias is not None:
905
  nn.init.constant_(m.bias, 0)
906
  elif isinstance(m, nn.LayerNorm):
@@ -915,7 +1182,7 @@ class FasterViT(nn.Module):
915
 
916
  @torch.jit.ignore
917
  def no_weight_decay_keywords(self):
918
- return {'rpb'}
919
 
920
  def forward_features(self, x):
921
  x = self.patch_embed(x)
@@ -924,18 +1191,34 @@ class FasterViT(nn.Module):
924
  x, pre_downsample_x = level(x)
925
 
926
  if self.return_full_features or self.use_neck:
927
- if self.neck_start_stage > il: continue
 
928
  if full_features is None:
929
- full_features = self.neck_features_proj[il - self.neck_start_stage](pre_downsample_x)
 
 
930
  else:
931
- #upsample torch tensor x to match full_features size, and add to full_features
932
- feature_projection = self.neck_features_proj[il - self.neck_start_stage](pre_downsample_x)
933
- if feature_projection.shape[2] != full_features.shape[2] or feature_projection.shape[3] != full_features.shape[3]:
934
- feature_projection = torch.nn.functional.pad(feature_projection, ( 0, -feature_projection.shape[3] + full_features.shape[3], 0, -feature_projection.shape[2] + full_features.shape[2]))
 
 
 
 
 
 
 
 
 
 
 
 
 
935
  full_features += feature_projection
936
 
937
  # x = self.norm(full_features if (self.return_full_features or self.use_neck) else x)
938
- x = self.norm(x) # new version for
939
  x = self.avgpool(x)
940
  x = torch.flatten(x, 1)
941
 
@@ -952,384 +1235,435 @@ class FasterViT(nn.Module):
952
  return x
953
 
954
  def switch_to_deploy(self):
955
- '''
956
  A method to perform model self-compression
957
  merges BN into conv layers
958
  converts MLP relative positional bias into precomputed buffers
959
- '''
960
  for level in [self.patch_embed, self.levels, self.head]:
961
  for module in level.modules():
962
- if hasattr(module, 'switch_to_deploy'):
963
  module.switch_to_deploy()
964
 
 
965
  @register_model
966
- def fastervit2_small(pretrained=False, **kwargs): #,
967
- model = FasterViT(depths=[3, 3, 5, 5],
968
- num_heads=[2, 4, 8, 16],
969
- window_size=[8, 8, [7, 7], 7],
970
- dim=96,
971
- in_dim=64,
972
- mlp_ratio=4,
973
- drop_path_rate=0.2,
974
- sr_ratio=[1, 1, [1, 2], 1],
975
- use_swiglu=False,
976
- downsample_shuffle=False,
977
- yolo_arch=True,
978
- shuffle_down=False,
979
- **kwargs)
 
 
980
  if pretrained:
981
  model.load_state_dict(torch.load(pretrained))
982
  return model
983
 
 
984
  @register_model
985
- def fastervit2_tiny(pretrained=False, **kwargs): #,
986
- model = FasterViT(depths=[1, 3, 4, 5],
987
- num_heads=[2, 4, 8, 16],
988
- window_size=[8, 8, [7, 7], 7],
989
- dim=80,
990
- in_dim=64,
991
- mlp_ratio=4,
992
- drop_path_rate=0.2,
993
- sr_ratio=[1, 1, [2, 1], 1],
994
- use_swiglu=False,
995
- downsample_shuffle=False,
996
- yolo_arch=True,
997
- shuffle_down=False,
998
- **kwargs)
 
 
999
  if pretrained:
1000
  model.load_state_dict(torch.load(pretrained))
1001
  return model
1002
 
 
1003
  @register_model
1004
  def fastervit2_base(pretrained=False, **kwargs):
1005
- model = FasterViT(depths=[3, 3, 5, 5],
1006
- num_heads=[2, 4, 8, 16],
1007
- window_size=[8, 8, [7, 7], 7],
1008
- dim=128,
1009
- in_dim=64,
1010
- mlp_ratio=4,
1011
- drop_path_rate=0.2,
1012
- sr_ratio=[1, 1, [2, 1], 1],
1013
- use_swiglu=False,
1014
- yolo_arch=True,
1015
- shuffle_down=False,
1016
- conv_base=True,
1017
- **kwargs)
 
 
1018
  if pretrained:
1019
  model.load_state_dict(torch.load(pretrained))
1020
  return model
1021
 
 
1022
  @register_model
1023
  def fastervit2_base_fullres1(pretrained=False, **kwargs):
1024
- model = FasterViT(depths=[3, 3, 5, 5],
1025
- num_heads=[2, 4, 8, 16],
1026
- window_size=[8, 8, [7, 7], 7],
1027
- dim=128,
1028
- in_dim=64,
1029
- mlp_ratio=4,
1030
- drop_path_rate=0.2,
1031
- sr_ratio=[1, 1, [2, 1], 1],
1032
- use_swiglu=False,
1033
- yolo_arch=True,
1034
- shuffle_down=False,
1035
- conv_base=True,
1036
- use_neck=True,
1037
- full_features_head_dim=1024,
1038
- neck_start_stage=2,
1039
- **kwargs)
 
 
1040
  if pretrained:
1041
  model.load_state_dict(torch.load(pretrained))
1042
  return model
1043
 
 
1044
  @register_model
1045
  def fastervit2_base_fullres2(pretrained=False, **kwargs):
1046
- model = FasterViT(depths=[3, 3, 5, 5],
1047
- num_heads=[2, 4, 8, 16],
1048
- window_size=[8, 8, [7, 7], 7],
1049
- dim=128,
1050
- in_dim=64,
1051
- mlp_ratio=4,
1052
- drop_path_rate=0.2,
1053
- sr_ratio=[1, 1, [2, 1], 1],
1054
- use_swiglu=False,
1055
- yolo_arch=True,
1056
- shuffle_down=False,
1057
- conv_base=True,
1058
- use_neck=True,
1059
- full_features_head_dim=512,
1060
- neck_start_stage=1,
1061
- **kwargs)
 
 
1062
  if pretrained:
1063
  model.load_state_dict(torch.load(pretrained))
1064
  return model
1065
 
 
1066
  @register_model
1067
  def fastervit2_base_fullres3(pretrained=False, **kwargs):
1068
- model = FasterViT(depths=[3, 3, 5, 5],
1069
- num_heads=[2, 4, 8, 16],
1070
- window_size=[8, 8, [7, 7], 7],
1071
- dim=128,
1072
- in_dim=64,
1073
- mlp_ratio=4,
1074
- drop_path_rate=0.2,
1075
- sr_ratio=[1, 1, [2, 1], 1],
1076
- use_swiglu=False,
1077
- yolo_arch=True,
1078
- shuffle_down=False,
1079
- conv_base=True,
1080
- use_neck=True,
1081
- full_features_head_dim=256,
1082
- neck_start_stage=1,
1083
- **kwargs)
 
 
1084
  if pretrained:
1085
  model.load_state_dict(torch.load(pretrained))
1086
  return model
1087
 
 
1088
  @register_model
1089
  def fastervit2_base_fullres4(pretrained=False, **kwargs):
1090
- model = FasterViT(depths=[3, 3, 5, 5],
1091
- num_heads=[2, 4, 8, 16],
1092
- window_size=[8, 8, [7, 7], 7],
1093
- dim=128,
1094
- in_dim=64,
1095
- mlp_ratio=4,
1096
- drop_path_rate=0.2,
1097
- sr_ratio=[1, 1, [2, 1], 1],
1098
- use_swiglu=False,
1099
- yolo_arch=True,
1100
- shuffle_down=False,
1101
- conv_base=True,
1102
- use_neck=True,
1103
- full_features_head_dim=256,
1104
- neck_start_stage=2,
1105
- **kwargs)
 
 
1106
  if pretrained:
1107
  model.load_state_dict(torch.load(pretrained))
1108
  return model
1109
 
 
1110
  @register_model
1111
  def fastervit2_base_fullres5(pretrained=False, **kwargs):
1112
- model = FasterViT(depths=[3, 3, 5, 5],
1113
- num_heads=[2, 4, 8, 16],
1114
- window_size=[8, 8, [7, 7], 7],
1115
- dim=128,
1116
- in_dim=64,
1117
- mlp_ratio=4,
1118
- drop_path_rate=0.2,
1119
- sr_ratio=[1, 1, [2, 1], 1],
1120
- use_swiglu=False,
1121
- yolo_arch=True,
1122
- shuffle_down=False,
1123
- conv_base=True,
1124
- use_neck=True,
1125
- full_features_head_dim=512,
1126
- neck_start_stage=2,
1127
- **kwargs)
 
 
1128
  if pretrained:
1129
  model.load_state_dict(torch.load(pretrained))
1130
  return model
1131
 
1132
- #pyt: 1934, 4202 TRT
 
1133
  @register_model
1134
  def fastervit2_large(pretrained=False, **kwargs):
1135
- model = FasterViT(depths=[3, 3, 5, 5],
1136
- num_heads=[2, 4, 8, 16],
1137
- window_size=[8, 8, [7, 7], 7],
1138
- dim=128+64,
1139
- in_dim=64,
1140
- mlp_ratio=4,
1141
- drop_path_rate=0.2,
1142
- sr_ratio=[1, 1, [2, 1], 1],
1143
- use_swiglu=False,
1144
- yolo_arch=True,
1145
- shuffle_down=False,
1146
- **kwargs)
 
 
1147
  if pretrained:
1148
  model.load_state_dict(torch.load(pretrained))
1149
  return model
1150
 
 
1151
  @register_model
1152
  def fastervit2_large_fullres(pretrained=False, **kwargs):
1153
- model = FasterViT(depths=[3, 3, 5, 5],
1154
- num_heads=[2, 4, 8, 16],
1155
- window_size=[None, None, [7, 7], 7],
1156
- dim=192,
1157
- in_dim=64,
1158
- mlp_ratio=4,
1159
- drop_path_rate=0.,
1160
- sr_ratio=[1, 1, [2, 1], 1],
1161
- use_swiglu=False,
1162
- yolo_arch=True,
1163
- shuffle_down=False,
1164
- conv_base=True,
1165
- use_neck=True,
1166
- full_features_head_dim=1536,
1167
- neck_start_stage=2,
1168
- **kwargs)
 
 
1169
  if pretrained:
1170
  model.load_state_dict(torch.load(pretrained))
1171
  return model
1172
 
 
1173
  @register_model
1174
  def fastervit2_large_fullres_ws8(pretrained=False, **kwargs):
1175
- model = FasterViT(depths=[3, 3, 5, 5],
1176
- num_heads=[2, 4, 8, 16],
1177
- window_size=[None, None, [8, 8], 8],
1178
- dim=192,
1179
- in_dim=64,
1180
- mlp_ratio=4,
1181
- drop_path_rate=0.,
1182
- sr_ratio=[1, 1, [2, 1], 1],
1183
- use_swiglu=False,
1184
- yolo_arch=True,
1185
- shuffle_down=False,
1186
- conv_base=True,
1187
- use_neck=True,
1188
- full_features_head_dim=1536,
1189
- neck_start_stage=2,
1190
- **kwargs)
 
 
1191
  if pretrained:
1192
  model.load_state_dict(torch.load(pretrained))
1193
  return model
1194
 
 
1195
  @register_model
1196
  def fastervit2_large_fullres_ws16(pretrained=False, **kwargs):
1197
- model = FasterViT(depths=[3, 3, 5, 5],
1198
- num_heads=[2, 4, 8, 16],
1199
- window_size=[None, None, [16, 16], 16],
1200
- dim=192,
1201
- in_dim=64,
1202
- mlp_ratio=4,
1203
- drop_path_rate=0.,
1204
- sr_ratio=[1, 1, [2, 1], 1],
1205
- use_swiglu=False,
1206
- yolo_arch=True,
1207
- shuffle_down=False,
1208
- conv_base=True,
1209
- use_neck=True,
1210
- full_features_head_dim=1536,
1211
- neck_start_stage=2,
1212
- **kwargs)
 
 
1213
  if pretrained:
1214
  model.load_state_dict(torch.load(pretrained))
1215
  return model
1216
 
 
1217
  @register_model
1218
  def fastervit2_large_fullres_ws32(pretrained=False, **kwargs):
1219
- model = FasterViT(depths=[3, 3, 5, 5],
1220
- num_heads=[2, 4, 8, 16],
1221
- window_size=[None, None, [32, 32], 32],
1222
- dim=192,
1223
- in_dim=64,
1224
- mlp_ratio=4,
1225
- drop_path_rate=0.,
1226
- sr_ratio=[1, 1, [2, 1], 1],
1227
- use_swiglu=False,
1228
- yolo_arch=True,
1229
- shuffle_down=False,
1230
- conv_base=True,
1231
- use_neck=True,
1232
- full_features_head_dim=1536,
1233
- neck_start_stage=2,
1234
- **kwargs)
 
 
1235
  if pretrained:
1236
  model.load_state_dict(torch.load(pretrained))
1237
  return model
1238
 
1239
- #pyt: 897
 
1240
  @register_model
1241
  def fastervit2_xlarge(pretrained=False, **kwargs):
1242
- model = FasterViT(depths=[3, 3, 5, 5],
1243
- num_heads=[2, 4, 8, 16],
1244
- window_size=[8, 8, [7, 7], 7],
1245
- dim=128+128+64,
1246
- in_dim=64,
1247
- mlp_ratio=4,
1248
- drop_path_rate=0.2,
1249
- sr_ratio=[1, 1, [2, 1], 1],
1250
- use_swiglu=False,
1251
- yolo_arch=True,
1252
- shuffle_down=False,
1253
- **kwargs)
 
 
1254
  if pretrained:
1255
  model.load_state_dict(torch.load(pretrained))
1256
  return model
1257
 
1258
 
1259
- #pyt:
1260
  @register_model
1261
  def fastervit2_huge(pretrained=False, **kwargs):
1262
- model = FasterViT(depths=[3, 3, 5, 5],
1263
- num_heads=[2, 4, 8, 16],
1264
- window_size=[8, 8, [7, 7], 7],
1265
- dim=128+128+128+64,
1266
- in_dim=64,
1267
- mlp_ratio=4,
1268
- drop_path_rate=0.2,
1269
- sr_ratio=[1, 1, [2, 1], 1],
1270
- use_swiglu=False,
1271
- yolo_arch=True,
1272
- shuffle_down=False,
1273
- **kwargs)
 
 
1274
  if pretrained:
1275
  model.load_state_dict(torch.load(pretrained))
1276
  return model
1277
 
1278
 
1279
  @register_model
1280
- def fastervit2_xtiny(pretrained=False, **kwargs): #,
1281
- model = FasterViT(depths=[1, 3, 4, 5],
1282
- num_heads=[2, 4, 8, 16],
1283
- window_size=[8, 8, [7, 7], 7],
1284
- dim=64,
1285
- in_dim=64,
1286
- mlp_ratio=4,
1287
- drop_path_rate=0.1,
1288
- sr_ratio=[1, 1, [2, 1], 1],
1289
- use_swiglu=False,
1290
- downsample_shuffle=False,
1291
- yolo_arch=True,
1292
- shuffle_down=False,
1293
- **kwargs)
 
 
1294
  if pretrained:
1295
  model.load_state_dict(torch.load(pretrained))
1296
  return model
1297
 
1298
 
1299
  @register_model
1300
- def fastervit2_xxtiny_5(pretrained=False, **kwargs): #,
1301
- model = FasterViT(depths=[1, 3, 4, 5],
1302
- num_heads=[2, 4, 8, 16],
1303
- window_size=[8, 8, [7, 7], 7],
1304
- dim=48,
1305
- in_dim=64,
1306
- mlp_ratio=4,
1307
- drop_path_rate=0.05,
1308
- sr_ratio=[1, 1, [2, 1], 1],
1309
- use_swiglu=False,
1310
- downsample_shuffle=False,
1311
- yolo_arch=True,
1312
- shuffle_down=False,
1313
- **kwargs)
 
 
1314
  if pretrained:
1315
  model.load_state_dict(torch.load(pretrained))
1316
  return model
1317
 
 
1318
  @register_model
1319
- def fastervit2_xxxtiny(pretrained=False, **kwargs): #,
1320
- model = FasterViT(depths=[1, 3, 4, 5],
1321
- num_heads=[2, 4, 8, 16],
1322
- window_size=[8, 8, [7, 7], 7],
1323
- dim=32,
1324
- in_dim=32,
1325
- mlp_ratio=4,
1326
- drop_path_rate=0.0,
1327
- sr_ratio=[1, 1, [2, 1], 1],
1328
- use_swiglu=False,
1329
- downsample_shuffle=False,
1330
- yolo_arch=True,
1331
- shuffle_down=False,
1332
- **kwargs)
 
 
1333
  if pretrained:
1334
  model.load_state_dict(torch.load(pretrained))
1335
  return model
@@ -1337,4 +1671,4 @@ def fastervit2_xxxtiny(pretrained=False, **kwargs): #,
1337
 
1338
  @register_model
1339
  def eradio(pretrained=False, **kwargs):
1340
- return fastervit2_large_fullres(pretrained=pretrained, **kwargs)
 
19
  import numpy as np
20
  import torch.nn.functional as F
21
  from .block import C2f
22
+
23
+ TRT = False # should help for TRT
24
 
25
  import pickle
26
+
27
  global bias_indx
28
  bias_indx = -1
29
  DEBUG = False
30
 
31
 
 
32
  def pixel_unshuffle(data, factor=2):
33
  # performs nn.PixelShuffle(factor) in reverse, torch has some bug for ONNX and TRT, so doing it manually
34
  B, C, H, W = data.shape
35
+ return (
36
+ data.view(B, C, factor, H // factor, factor, W // factor)
37
+ .permute(0, 1, 2, 4, 3, 5)
38
+ .reshape(B, -1, H // factor, W // factor)
39
+ )
40
+
41
 
42
  class SwiGLU(nn.Module):
43
  # should be more advanced, but doesnt improve results so far
 
57
  """
58
  B, C, H, W = x.shape
59
 
60
+ if window_size == 0 or (window_size == H and window_size == W):
61
  windows = x.flatten(2).transpose(1, 2)
62
  Hp, Wp = H, W
63
  else:
 
68
  Hp, Wp = H + pad_h, W + pad_w
69
 
70
  x = x.view(B, C, Hp // window_size, window_size, Wp // window_size, window_size)
71
+ windows = x.permute(0, 2, 4, 3, 5, 1).reshape(-1, window_size * window_size, C)
72
 
73
  return windows, (Hp, Wp)
74
 
75
+
76
  class Conv2d_BN(nn.Module):
77
+ """
78
  Conv2d + BN layer with folding capability to speed up inference
79
+ """
80
+
81
+ def __init__(
82
+ self,
83
+ a,
84
+ b,
85
+ kernel_size=1,
86
+ stride=1,
87
+ padding=0,
88
+ dilation=1,
89
+ groups=1,
90
+ bn_weight_init=1,
91
+ bias=False,
92
+ ):
93
  super().__init__()
94
+ self.conv = torch.nn.Conv2d(
95
+ a, b, kernel_size, stride, padding, dilation, groups, bias=False
96
+ )
97
  if 1:
98
  self.bn = torch.nn.BatchNorm2d(b)
99
  torch.nn.init.constant_(self.bn.weight, bn_weight_init)
100
  torch.nn.init.constant_(self.bn.bias, 0)
101
 
102
+ def forward(self, x):
103
  x = self.conv(x)
104
  x = self.bn(x)
105
  return x
 
112
  c, bn = self.conv, self.bn
113
  w = bn.weight / (bn.running_var + bn.eps) ** 0.5
114
  w = c.weight * w[:, None, None, None]
115
+ b = bn.bias - bn.running_mean * bn.weight / (bn.running_var + bn.eps) ** 0.5
 
116
  self.conv.weight.data.copy_(w)
117
  self.conv.bias = nn.Parameter(b)
118
  self.bn = nn.Identity()
119
 
120
 
 
121
  def window_reverse(windows, window_size, H, W, pad_hw):
122
  """
123
  Args:
 
132
  """
133
  # print(f"window_reverse, windows.shape {windows.shape}")
134
  Hp, Wp = pad_hw
135
+ if window_size == 0 or (window_size == H and window_size == W):
136
  B = int(windows.shape[0] / (Hp * Wp / window_size / window_size))
137
  x = windows.transpose(1, 2).view(B, -1, H, W)
138
  else:
139
  B = int(windows.shape[0] / (Hp * Wp / window_size / window_size))
140
+ x = windows.view(
141
+ B, Hp // window_size, Wp // window_size, window_size, window_size, -1
142
+ )
143
+ x = x.permute(0, 5, 1, 3, 2, 4).reshape(B, windows.shape[2], Hp, Wp)
144
 
145
  if Hp > H or Wp > W:
146
+ x = x[:, :, :H, :W,].contiguous()
147
 
148
  return x
149
 
150
 
 
151
  class PosEmbMLPSwinv2D(nn.Module):
152
+ def __init__(
153
+ self, window_size, pretrained_window_size, num_heads, seq_length, no_log=False
154
+ ):
155
  super().__init__()
156
  self.window_size = window_size
157
  self.num_heads = num_heads
158
  # mlp to generate continuous relative position bias
159
+ self.cpb_mlp = nn.Sequential(
160
+ nn.Linear(2, 512, bias=True),
161
+ nn.ReLU(inplace=True),
162
+ nn.Linear(512, num_heads, bias=False),
163
+ )
164
 
165
  # get relative_coords_table
166
+ relative_coords_h = torch.arange(
167
+ -(self.window_size[0] - 1), self.window_size[0], dtype=torch.float32
168
+ )
169
+ relative_coords_w = torch.arange(
170
+ -(self.window_size[1] - 1), self.window_size[1], dtype=torch.float32
171
+ )
172
+ relative_coords_table = (
173
+ torch.stack(torch.meshgrid([relative_coords_h, relative_coords_w]))
174
+ .permute(1, 2, 0)
175
+ .contiguous()
176
+ .unsqueeze(0)
177
+ ) # 1, 2*Wh-1, 2*Ww-1, 2
178
  if pretrained_window_size[0] > 0:
179
+ relative_coords_table[:, :, :, 0] /= pretrained_window_size[0] - 1
180
+ relative_coords_table[:, :, :, 1] /= pretrained_window_size[1] - 1
181
  else:
182
+ relative_coords_table[:, :, :, 0] /= self.window_size[0] - 1
183
+ relative_coords_table[:, :, :, 1] /= self.window_size[1] - 1
184
 
185
  if not no_log:
186
  relative_coords_table *= 8 # normalize to -8, 8
187
+ relative_coords_table = (
188
+ torch.sign(relative_coords_table)
189
+ * torch.log2(torch.abs(relative_coords_table) + 1.0)
190
+ / np.log2(8)
191
+ )
192
 
193
  self.register_buffer("relative_coords_table", relative_coords_table)
194
 
 
197
  coords_w = torch.arange(self.window_size[1])
198
  coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww
199
  coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
200
+ relative_coords = (
201
+ coords_flatten[:, :, None] - coords_flatten[:, None, :]
202
+ ) # 2, Wh*Ww, Wh*Ww
203
+ relative_coords = relative_coords.permute(
204
+ 1, 2, 0
205
+ ).contiguous() # Wh*Ww, Wh*Ww, 2
206
  relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0
207
  relative_coords[:, :, 1] += self.window_size[1] - 1
208
  relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
 
215
 
216
  relative_bias = torch.zeros(1, num_heads, seq_length, seq_length)
217
  self.seq_length = seq_length
218
+ self.register_buffer("relative_bias", relative_bias) # for EMA
219
 
220
  def switch_to_deploy(self):
221
  self.deploy = True
 
229
  # if not (input_tensor.shape[1:] == self.relative_bias.shape[1:]):
230
  # self.grid_exists = False
231
 
232
+ if self.training:
233
+ self.grid_exists = False
234
 
235
  if self.deploy and self.grid_exists:
236
  input_tensor += self.relative_bias
 
239
  if not self.grid_exists:
240
  self.grid_exists = True
241
 
242
+ relative_position_bias_table = self.cpb_mlp(
243
+ self.relative_coords_table
244
+ ).view(-1, self.num_heads)
245
+ relative_position_bias = relative_position_bias_table[
246
+ self.relative_position_index.view(-1)
247
+ ].view(
248
+ self.window_size[0] * self.window_size[1],
249
+ self.window_size[0] * self.window_size[1],
250
+ -1,
251
+ ) # Wh*Ww,Wh*Ww,nH
252
+
253
+ relative_position_bias = relative_position_bias.permute(
254
+ 2, 0, 1
255
+ ).contiguous() # nH, Wh*Ww, Wh*Ww
256
  relative_position_bias = 16 * torch.sigmoid(relative_position_bias)
257
 
258
  self.relative_bias = relative_position_bias.unsqueeze(0)
 
261
  return input_tensor
262
 
263
 
 
264
  class GRAAttentionBlock(nn.Module):
265
+ def __init__(
266
+ self,
267
+ window_size,
268
+ dim_in,
269
+ dim_out,
270
+ num_heads,
271
+ drop_path=0.0,
272
+ qk_scale=None,
273
+ qkv_bias=False,
274
+ norm_layer=nn.LayerNorm,
275
+ layer_scale=None,
276
+ use_swiglu=True,
277
+ subsample_ratio=1,
278
+ dim_ratio=1,
279
+ conv_base=False,
280
+ do_windowing=True,
281
+ multi_query=False,
282
+ ) -> None:
283
  super().__init__()
284
 
285
  dim = dim_in
 
290
 
291
  if do_windowing:
292
  if SHUFFLE:
293
+ self.downsample_op = (
294
+ torch.nn.PixelUnshuffle(subsample_ratio)
295
+ if subsample_ratio > 1
296
+ else torch.nn.Identity()
297
+ )
298
+ self.downsample_mixer = (
299
+ nn.Conv2d(
300
+ dim_in * (subsample_ratio * subsample_ratio),
301
+ dim_in * (dim_ratio),
302
+ kernel_size=1,
303
+ stride=1,
304
+ padding=0,
305
+ bias=False,
306
+ )
307
+ if dim * dim_ratio != dim * subsample_ratio * subsample_ratio
308
+ else torch.nn.Identity()
309
+ )
310
  else:
311
  if conv_base:
312
+ self.downsample_op = (
313
+ nn.Conv2d(
314
+ dim_in,
315
+ dim_out,
316
+ kernel_size=subsample_ratio,
317
+ stride=subsample_ratio,
318
+ )
319
+ if subsample_ratio > 1
320
+ else nn.Identity()
321
+ )
322
  self.downsample_mixer = nn.Identity()
323
  else:
324
+ self.downsample_op = (
325
+ nn.AvgPool2d(
326
+ kernel_size=subsample_ratio, stride=subsample_ratio
327
+ )
328
+ if subsample_ratio > 1
329
+ else nn.Identity()
330
+ )
331
+ self.downsample_mixer = (
332
+ Conv2d_BN(dim_in, dim_out, kernel_size=1, stride=1)
333
+ if subsample_ratio > 1
334
+ else nn.Identity()
335
+ )
336
 
337
  if do_windowing:
338
  if SHUFFLE:
339
+ self.upsample_mixer = (
340
+ nn.Conv2d(
341
+ dim_in * dim_ratio,
342
+ dim_in * (subsample_ratio * subsample_ratio),
343
+ kernel_size=1,
344
+ stride=1,
345
+ padding=0,
346
+ bias=False,
347
+ )
348
+ if dim * dim_ratio != dim * subsample_ratio * subsample_ratio
349
+ else torch.nn.Identity()
350
+ )
351
+ self.upsample_op = (
352
+ torch.nn.PixelShuffle(subsample_ratio)
353
+ if subsample_ratio > 1
354
+ else torch.nn.Identity()
355
+ )
356
  else:
357
  if conv_base:
358
  self.upsample_mixer = nn.Identity()
359
+ self.upsample_op = (
360
+ nn.ConvTranspose2d(
361
+ dim_in,
362
+ dim_out,
363
+ kernel_size=subsample_ratio,
364
+ stride=subsample_ratio,
365
+ )
366
+ if subsample_ratio > 1
367
+ else nn.Identity()
368
+ )
369
  else:
370
+ self.upsample_mixer = (
371
+ nn.Upsample(scale_factor=subsample_ratio, mode="nearest")
372
+ if subsample_ratio > 1
373
+ else nn.Identity()
374
+ )
375
+ self.upsample_op = (
376
+ Conv2d_BN(
377
+ dim_in,
378
+ dim_out,
379
+ kernel_size=1,
380
+ stride=1,
381
+ padding=0,
382
+ bias=False,
383
+ )
384
+ if subsample_ratio > 1
385
+ else nn.Identity()
386
+ )
387
 
388
  self.window_size = window_size
389
 
390
  self.norm1 = norm_layer(dim_in)
391
  if DEBUG:
392
+ print(
393
+ f"GRAAttentionBlock: input_resolution: , window_size: {window_size}, dim_in: {dim_in}, dim_out: {dim_out}, num_heads: {num_heads}, drop_path: {drop_path}, qk_scale: {qk_scale}, qkv_bias: {qkv_bias}, layer_scale: {layer_scale}"
394
+ )
395
 
396
  self.attn = WindowAttention(
397
  dim_in,
398
+ num_heads=num_heads,
399
+ qkv_bias=qkv_bias,
400
+ qk_scale=qk_scale,
401
  resolution=window_size,
402
+ seq_length=window_size ** 2,
403
+ dim_out=dim_in,
404
+ multi_query=multi_query,
405
+ )
406
  if DEBUG:
407
+ print(
408
+ f"Attention: dim_in: {dim_in}, num_heads: {num_heads}, qkv_bias: {qkv_bias}, qk_scale: {qk_scale}, resolution: {window_size}, seq_length: {window_size**2}, dim_out: {dim_in}"
409
+ )
410
  print(f"drop_path: {drop_path}, layer_scale: {layer_scale}")
411
 
412
+ self.drop_path1 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
413
 
414
  use_layer_scale = layer_scale is not None and type(layer_scale) in [int, float]
415
+ self.gamma1 = (
416
+ nn.Parameter(layer_scale * torch.ones(dim_in)) if use_layer_scale else 1
417
+ )
418
 
419
  ### mlp layer
420
  mlp_ratio = 4
 
422
  mlp_hidden_dim = int(dim_in * mlp_ratio)
423
 
424
  activation = nn.GELU if not use_swiglu else SwiGLU
425
+ mlp_hidden_dim = (
426
+ int((4 * dim_in * 1 / 2) / 64) * 64 if use_swiglu else mlp_hidden_dim
427
+ )
428
+
429
+ self.mlp = Mlp(
430
+ in_features=dim_in,
431
+ hidden_features=mlp_hidden_dim,
432
+ act_layer=activation,
433
+ use_swiglu=use_swiglu,
434
+ )
435
+
436
+ self.gamma2 = (
437
+ nn.Parameter(layer_scale * torch.ones(dim_in)) if layer_scale else 1
438
+ )
439
+ self.drop_path2 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
440
  if DEBUG:
441
+ print(
442
+ f"MLP layer: dim_in: {dim_in}, dim_out: {dim_in}, mlp_hidden_dim: {mlp_hidden_dim}"
443
+ )
444
  print(f"drop_path: {drop_path}, layer_scale: {layer_scale}")
445
 
 
446
  def forward(self, x):
447
  skip_connection = x
448
 
 
451
  x = self.downsample_op(x)
452
  x = self.downsample_mixer(x)
453
 
454
+ if self.window_size > 0:
455
  H, W = x.shape[2], x.shape[3]
456
 
457
  x, pad_hw = window_partition(x, self.window_size)
458
 
459
  # window attention
460
+ x = x + self.drop_path1(self.gamma1 * self.attn(self.norm1(x)))
461
  # mlp layer
462
+ x = x + self.drop_path2(self.gamma2 * self.mlp(self.norm2(x)))
463
 
464
  if self.do_windowing:
465
  if self.window_size > 0:
 
468
  x = self.upsample_mixer(x)
469
  x = self.upsample_op(x)
470
 
471
+ if (
472
+ x.shape[2] != skip_connection.shape[2]
473
+ or x.shape[3] != skip_connection.shape[3]
474
+ ):
475
+ x = torch.nn.functional.pad(
476
+ x,
477
+ (
478
+ 0,
479
+ -x.shape[3] + skip_connection.shape[3],
480
+ 0,
481
+ -x.shape[2] + skip_connection.shape[2],
482
+ ),
483
+ )
484
  # need to add skip connection because downsampling and upsampling will break residual connection
485
  # 0.5 is needed to make sure that the skip connection is not too strong
486
  # in case of no downsample / upsample we can show that 0.5 compensates for the residual connection
 
489
  return x
490
 
491
 
 
 
492
  class MultiResolutionAttention(nn.Module):
493
  """
494
  MultiResolutionAttention (MRA) module
 
498
 
499
  """
500
 
501
+ def __init__(
502
+ self,
503
+ window_size,
504
+ sr_ratio,
505
+ dim,
506
+ dim_ratio,
507
+ num_heads,
508
+ do_windowing=True,
509
+ layer_scale=1e-5,
510
+ norm_layer=nn.LayerNorm,
511
+ drop_path=0,
512
+ qkv_bias=False,
513
+ qk_scale=1.0,
514
+ use_swiglu=True,
515
+ multi_query=False,
516
+ conv_base=False,
517
+ ) -> None:
518
  """
519
  Args:
520
  input_resolution: input image resolution
 
526
 
527
  depth = len(sr_ratio)
528
 
 
529
  self.attention_blocks = nn.ModuleList()
530
 
 
531
  for i in range(depth):
532
  subsample_ratio = sr_ratio[i]
533
  if len(window_size) > i:
 
535
  else:
536
  window_size_local = window_size[0]
537
 
538
+ self.attention_blocks.append(
539
+ GRAAttentionBlock(
540
+ window_size=window_size_local,
541
+ dim_in=dim,
542
+ dim_out=dim,
543
+ num_heads=num_heads,
544
+ qkv_bias=qkv_bias,
545
+ qk_scale=qk_scale,
546
+ norm_layer=norm_layer,
547
+ layer_scale=layer_scale,
548
+ drop_path=drop_path,
549
+ use_swiglu=use_swiglu,
550
+ subsample_ratio=subsample_ratio,
551
+ dim_ratio=dim_ratio,
552
+ do_windowing=do_windowing,
553
+ multi_query=multi_query,
554
+ conv_base=conv_base,
555
+ ),
556
+ )
557
 
558
  def forward(self, x):
559
 
 
563
  return x
564
 
565
 
 
566
  class Mlp(nn.Module):
567
  """
568
  Multi-Layer Perceptron (MLP) block
569
  """
570
 
571
+ def __init__(
572
+ self,
573
+ in_features,
574
+ hidden_features=None,
575
+ out_features=None,
576
+ act_layer=nn.GELU,
577
+ use_swiglu=True,
578
+ drop=0.0,
579
+ ):
580
  """
581
  Args:
582
  in_features: input features dimension.
 
589
  super().__init__()
590
  out_features = out_features or in_features
591
  hidden_features = hidden_features or in_features
592
+ self.fc1 = nn.Linear(
593
+ in_features, hidden_features * (2 if use_swiglu else 1), bias=False
594
+ )
595
  self.act = act_layer()
596
  self.fc2 = nn.Linear(hidden_features, out_features, bias=False)
597
  # self.drop = GaussianDropout(drop)
 
607
  x = x.view(x_size)
608
  return x
609
 
610
+
611
  class Downsample(nn.Module):
612
  """
613
  Down-sampling block
 
615
  Pixel Unshuffle is used for down-sampling, works great accuracy - wise but takes 10% more TRT time
616
  """
617
 
618
+ def __init__(
619
+ self, dim, shuffle=False,
620
+ ):
 
621
  """
622
  Args:
623
  dim: feature size dimension.
 
630
 
631
  if shuffle:
632
  self.norm = lambda x: pixel_unshuffle(x, factor=2)
633
+ self.reduction = Conv2d_BN(dim * 4, dim_out, 1, 1, 0, bias=False)
634
  else:
635
+ # removed layer norm for better, in this formulation we are getting 10% better speed
636
  # LayerNorm for high resolution inputs will be a pain as it pools over the entire spatial dimension
637
  self.norm = nn.Identity()
638
  self.reduction = Conv2d_BN(dim, dim_out, 3, 2, 1, bias=False)
639
 
 
640
  def forward(self, x):
641
  x = self.norm(x)
642
  x = self.reduction(x)
 
665
  Conv2d_BN(in_chans, in_dim, 3, 2, 1, bias=False),
666
  nn.ReLU(),
667
  Conv2d_BN(in_dim, dim, 3, 2, 1, bias=False),
668
+ nn.ReLU(),
669
+ )
670
  else:
671
  self.proj = lambda x: pixel_unshuffle(x, factor=4)
672
 
 
675
  # Conv2d_BN(in_dim, dim, 3, 1, 1),
676
  # nn.SiLU(),
677
  # )
678
+ self.conv_down = nn.Sequential(
679
+ Conv2d_BN(in_chans * 16, dim, 3, 1, 1), nn.ReLU(),
680
+ )
681
 
682
  def forward(self, x):
683
  x = self.proj(x)
 
685
  return x
686
 
687
 
 
688
  class ConvBlock(nn.Module):
689
  """
690
  Convolutional block, used in first couple of stages
 
692
  Experimented with RepVGG, dont see significant improvement in accuracy
693
  Finally, YOLOv8 idea seem to work fine (resnet-18 like block with squeezed feature dimension, and feature concatendation at the end)
694
  """
695
+
696
+ def __init__(
697
+ self, dim, drop_path=0.0, layer_scale=None, kernel_size=3, rep_vgg=False
698
+ ):
 
699
  super().__init__()
700
  self.rep_vgg = rep_vgg
701
  if not rep_vgg:
702
+ self.conv1 = Conv2d_BN(
703
+ dim, dim, kernel_size=kernel_size, stride=1, padding=1
704
+ )
705
  self.act1 = nn.GELU()
706
  else:
707
+ self.conv1 = RepVGGBlock(
708
+ dim, dim, kernel_size=kernel_size, stride=1, padding=1, groups=1
709
+ )
710
 
711
  if not rep_vgg:
712
+ self.conv2 = Conv2d_BN(
713
+ dim, dim, kernel_size=kernel_size, stride=1, padding=1
714
+ )
715
  else:
716
+ self.conv2 = RepVGGBlock(
717
+ dim, dim, kernel_size=kernel_size, stride=1, padding=1, groups=1
718
+ )
719
 
720
  self.layer_scale = layer_scale
721
  if layer_scale is not None and type(layer_scale) in [int, float]:
 
723
  self.layer_scale = True
724
  else:
725
  self.layer_scale = False
726
+ self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
727
 
728
  def forward(self, x):
729
  input = x
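# The remainder of this forward is elided in the diff; a plausible residual sketch, assuming
# the usual conv-act-conv pattern with optional layer scale and stochastic depth (the attribute
# name `gamma` is hypothetical):
#
#     x = self.act1(self.conv1(x))
#     x = self.conv2(x)
#     if self.layer_scale:
#         x = x * self.gamma.view(1, -1, 1, 1)
#     return input + self.drop_path(x)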
 
747
  # look into PaLM: https://github.com/lucidrains/PaLM-pytorch/blob/main/palm_pytorch/palm_pytorch.py
748
  # single kv attention, MLP in parallel (didn't improve speed)
749
 
750
+ def __init__(
751
+ self,
752
+ dim,
753
+ num_heads=8,
754
+ qkv_bias=False,
755
+ qk_scale=None,
756
+ resolution=0,
757
+ seq_length=0,
758
+ dim_out=None,
759
+ multi_query=False,
760
+ ):
761
  # taken from EdgeViT and tweaked with attention bias.
762
  super().__init__()
763
+ if not dim_out:
764
+ dim_out = dim
765
  self.multi_query = multi_query
766
  self.num_heads = num_heads
767
  head_dim = dim // num_heads
 
778
  else:
779
  self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
780
  else:
781
+ self.qkv = nn.Linear(dim, dim + 2 * self.head_dim, bias=qkv_bias)
782
 
783
  self.proj = nn.Linear(dim, dim_out, bias=False)
784
  # attention positional bias
785
+ self.pos_emb_funct = PosEmbMLPSwinv2D(
786
+ window_size=[resolution, resolution],
787
+ pretrained_window_size=[resolution, resolution],
788
+ num_heads=num_heads,
789
+ seq_length=seq_length,
790
+ )
791
 
792
  self.resolution = resolution
793
 
 
796
 
797
  if not self.multi_query:
798
  if TRT:
799
+ q = (
800
+ self.q(x)
801
+ .reshape(B, -1, self.num_heads, C // self.num_heads)
802
+ .permute(0, 2, 1, 3)
803
+ )
804
+ k = (
805
+ self.k(x)
806
+ .reshape(B, -1, self.num_heads, C // self.num_heads)
807
+ .permute(0, 2, 1, 3)
808
+ )
809
+ v = (
810
+ self.v(x)
811
+ .reshape(B, -1, self.num_heads, C // self.num_heads)
812
+ .permute(0, 2, 1, 3)
813
+ )
814
  else:
815
+ qkv = (
816
+ self.qkv(x)
817
+ .reshape(B, -1, 3, self.num_heads, C // self.num_heads)
818
+ .permute(2, 0, 3, 1, 4)
819
+ )
820
  q, k, v = qkv[0], qkv[1], qkv[2]
821
  else:
822
  qkv = self.qkv(x)
823
+ (q, k, v) = qkv.split(
824
+ [self.dim_internal, self.head_dim, self.head_dim], dim=2
825
+ )
826
 
827
+ q = q.reshape(B, -1, self.num_heads, C // self.num_heads).permute(
828
+ 0, 2, 1, 3
829
+ )
830
  k = k.reshape(B, -1, 1, C // self.num_heads).permute(0, 2, 1, 3)
831
  v = v.reshape(B, -1, 1, C // self.num_heads).permute(0, 2, 1, 3)
832
 
 
840
  return x
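# For reference, a shape sketch of the multi_query=True split above (dim_internal is assumed
# equal to dim, i.e. dim_out left at its default): q keeps all heads, while the single k/v head
# is broadcast across the query heads by the batched matmul.
import torch

B, N, C, num_heads = 2, 49, 128, 8
head_dim = C // num_heads

qkv = torch.randn(B, N, C + 2 * head_dim)
q, k, v = qkv.split([C, head_dim, head_dim], dim=2)

q = q.reshape(B, N, num_heads, head_dim).permute(0, 2, 1, 3)   # (B, 8, N, 16)
k = k.reshape(B, N, 1, head_dim).permute(0, 2, 1, 3)           # (B, 1, N, 16)
v = v.reshape(B, N, 1, head_dim).permute(0, 2, 1, 3)           # (B, 1, N, 16)

attn = (q @ k.transpose(-2, -1)).softmax(dim=-1)               # (B, 8, N, N)
out = attn @ v                                                 # (B, 8, N, 16)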
841
 
842
 
 
843
  class FasterViTLayer(nn.Module):
844
  """
845
  FasterViT layer: a stage of convolutional or windowed-attention blocks with optional down-sampling
846
  """
847
 
848
+ def __init__(
849
+ self,
850
+ dim,
851
+ depth,
852
+ num_heads,
853
+ window_size,
854
+ conv=False,
855
+ downsample=True,
856
+ mlp_ratio=4.0,
857
+ qkv_bias=False,
858
+ qk_scale=None,
859
+ norm_layer=nn.LayerNorm,
860
+ drop_path=0.0,
861
+ layer_scale=None,
862
+ layer_scale_conv=None,
863
+ sr_dim_ratio=1,
864
+ sr_ratio=1,
865
+ multi_query=False,
866
+ use_swiglu=True,
867
+ rep_vgg=False,
868
+ yolo_arch=False,
869
+ downsample_shuffle=False,
870
+ conv_base=False,
871
  ):
872
  """
873
  Args:
 
889
 
890
  super().__init__()
891
  self.conv = conv
892
+ self.yolo_arch = False
893
  if conv:
894
  if not yolo_arch:
895
+ self.blocks = nn.ModuleList(
896
+ [
897
+ ConvBlock(
898
+ dim=dim,
899
+ drop_path=drop_path[i]
900
+ if isinstance(drop_path, list)
901
+ else drop_path,
902
+ layer_scale=layer_scale_conv,
903
+ rep_vgg=rep_vgg,
904
+ )
905
+ for i in range(depth)
906
+ ]
907
+ )
908
  else:
909
+ self.blocks = C2f(dim, dim, n=depth, shortcut=True, e=0.5)
910
+ self.yolo_arch = True
911
  else:
912
+ if not isinstance(window_size, list):
913
+ window_size = [window_size]
914
  self.window_size = window_size[0]
915
  self.do_single_windowing = True
916
+ if not isinstance(sr_ratio, list):
917
+ sr_ratio = [sr_ratio]
918
+ if any([sr != 1 for sr in sr_ratio]) or len(set(window_size)) > 1:
919
  self.do_single_windowing = False
920
  do_windowing = True
921
  else:
 
926
  for i in range(depth):
927
 
928
  self.blocks.append(
929
+ MultiResolutionAttention(
930
+ window_size=window_size,
931
+ sr_ratio=sr_ratio,
932
+ dim=dim,
933
+ dim_ratio=sr_dim_ratio,
934
+ num_heads=num_heads,
935
+ norm_layer=norm_layer,
936
+ drop_path=drop_path[i]
937
+ if isinstance(drop_path, list)
938
+ else drop_path,
939
+ layer_scale=layer_scale,
940
+ qkv_bias=qkv_bias,
941
+ qk_scale=qk_scale,
942
+ use_swiglu=use_swiglu,
943
+ do_windowing=do_windowing,
944
+ multi_query=multi_query,
945
+ conv_base=conv_base,
946
+ )
947
+ )
948
 
949
  self.transformer = not conv
950
 
951
+ self.downsample = (
952
+ None if not downsample else Downsample(dim=dim, shuffle=downsample_shuffle)
953
+ )
 
 
954
 
955
  def forward(self, x):
956
  B, C, H, W = x.shape
 
968
  if self.transformer and self.do_single_windowing:
969
  x = window_reverse(x, self.window_size, H, W, pad_hw)
970
 
 
971
  if self.downsample is None:
972
  return x, x
973
 
974
+ return self.downsample(x), x  # also return the pre-downsampled features
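# For reference, a runnable round-trip of the single-windowing path used in this forward:
# window_partition splits the feature map into non-overlapping windows (padding if needed)
# and window_reverse restores the original layout; the stage then returns both the
# downsampled tensor for the next stage and the pre-downsampled features for the neck.
import torch

x = torch.randn(1, 256, 28, 28)
windows, pad_hw = window_partition(x, 7)          # (16, 49, 256): 4x4 windows of 7x7 tokens
x_back = window_reverse(windows, 7, 28, 28, pad_hw)
assert torch.allclose(x, x_back)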
975
 
976
 
977
  class FasterViT(nn.Module):
 
979
  FasterViT
980
  """
981
 
982
+ def __init__(
983
+ self,
984
+ dim,
985
+ in_dim,
986
+ depths,
987
+ window_size,
988
+ mlp_ratio,
989
+ num_heads,
990
+ drop_path_rate=0.2,
991
+ in_chans=3,
992
+ num_classes=1000,
993
+ qkv_bias=False,
994
+ qk_scale=None,
995
+ layer_scale=None,
996
+ layer_scale_conv=None,
997
+ layer_norm_last=False,
998
+ sr_ratio=[1, 1, 1, 1],
999
+ max_depth=-1,
1000
+ conv_base=False,
1001
+ use_swiglu=False,
1002
+ multi_query=False,
1003
+ norm_layer=nn.LayerNorm,
1004
+ rep_vgg=False,
1005
+ drop_uniform=False,
1006
+ yolo_arch=False,
1007
+ shuffle_down=False,
1008
+ downsample_shuffle=False,
1009
+ return_full_features=False,
1010
+ full_features_head_dim=128,
1011
+ neck_start_stage=1,
1012
+ use_neck=False,
1013
+ **kwargs,
1014
+ ):
1015
  """
1016
  Args:
1017
  dim: feature size dimension.
 
1039
 
1040
  num_features = int(dim * 2 ** (len(depths) - 1))
1041
  self.num_classes = num_classes
1042
+ self.patch_embed = PatchEmbed(
1043
+ in_chans=in_chans, in_dim=in_dim, dim=dim, shuffle_down=shuffle_down
1044
+ )
1045
  # set return_full_features true if we want to return full features from all stages
1046
  self.return_full_features = return_full_features
1047
  self.use_neck = use_neck
 
1050
  if drop_uniform:
1051
  dpr = [drop_path_rate for x in range(sum(depths))]
1052
 
1053
+ if not isinstance(max_depth, list):
1054
+ max_depth = [max_depth] * len(depths)
1055
 
1056
  self.levels = nn.ModuleList()
1057
  for i in range(len(depths)):
1058
  conv = True if (i == 0 or i == 1) else False
1059
 
1060
+ level = FasterViTLayer(
1061
+ dim=int(dim * 2 ** i),
1062
+ depth=depths[i],
1063
+ num_heads=num_heads[i],
1064
+ window_size=window_size[i],
1065
+ mlp_ratio=mlp_ratio,
1066
+ qkv_bias=qkv_bias,
1067
+ qk_scale=qk_scale,
1068
+ conv=conv,
1069
+ drop_path=dpr[sum(depths[:i]) : sum(depths[: i + 1])],
1070
+ downsample=(i < 3),
1071
+ layer_scale=layer_scale,
1072
+ layer_scale_conv=layer_scale_conv,
1073
+ sr_ratio=sr_ratio[i],
1074
+ use_swiglu=use_swiglu,
1075
+ multi_query=multi_query,
1076
+ norm_layer=norm_layer,
1077
+ rep_vgg=rep_vgg,
1078
+ yolo_arch=yolo_arch,
1079
+ downsample_shuffle=downsample_shuffle,
1080
+ conv_base=conv_base,
1081
+ )
1082
 
1083
  self.levels.append(level)
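# The per-block drop-path rates dpr are computed in lines elided from this diff; the slicing
# dpr[sum(depths[:i]) : sum(depths[:i + 1])] above assumes the usual timm-style schedule,
# sketched here (a linearly increasing rate over all blocks, split per stage):
import torch

depths, drop_path_rate = [3, 3, 5, 5], 0.2
dpr = [r.item() for r in torch.linspace(0, drop_path_rate, sum(depths))]
per_stage = [dpr[sum(depths[:i]): sum(depths[:i + 1])] for i in range(len(depths))]
# per_stage[0] holds 3 rates near 0.0, per_stage[3] holds 5 rates approaching 0.2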
1084
 
 
1090
  for i in range(len(depths)):
1091
  level_n_features_output = int(dim * 2 ** i)
1092
 
1093
+ if self.neck_start_stage > i:
1094
+ continue
1095
 
1096
+ if (
1097
+ upsample_ratio > 1
1098
+ ) or full_features_head_dim != level_n_features_output:
1099
  feature_projection = nn.Sequential()
1100
  # feature_projection.add_module("norm",LayerNorm2d(level_n_features_output)) #slow, but better
1101
 
1102
+ if 0:  # transposed-convolution upsampling, disabled in favor of the pixel-shuffle path below
 
1103
  # Train: 0 [1900/10009 ( 19%)] Loss: 6.113 (6.57) Time: 0.548s, 233.40/s (0.549s, 233.04/s) LR: 1.000e-05 Data: 0.015 (0.013)
1104
+ feature_projection.add_module(
1105
+ "norm", nn.BatchNorm2d(level_n_features_output)
1106
+ ) # fast, but worse
1107
+ feature_projection.add_module(
1108
+ "dconv",
1109
+ nn.ConvTranspose2d(
1110
+ level_n_features_output,
1111
+ full_features_head_dim,
1112
+ kernel_size=upsample_ratio,
1113
+ stride=upsample_ratio,
1114
+ ),
1115
+ )
1116
  else:
1117
  # pixel shuffle based upsampling
1118
  # Train: 0 [1950/10009 ( 19%)] Loss: 6.190 (6.55) Time: 0.540s, 236.85/s (0.548s, 233.38/s) LR: 1.000e-05 Data: 0.015 (0.013)
1119
+ feature_projection.add_module(
1120
+ "norm", nn.BatchNorm2d(level_n_features_output)
1121
+ ) # fast, but worse
1122
+ feature_projection.add_module(
1123
+ "conv",
1124
+ nn.Conv2d(
1125
+ level_n_features_output,
1126
+ full_features_head_dim
1127
+ * upsample_ratio
1128
+ * upsample_ratio,
1129
+ kernel_size=1,
1130
+ stride=1,
1131
+ ),
1132
+ )
1133
+ feature_projection.add_module(
1134
+ "upsample_pixelshuffle", nn.PixelShuffle(upsample_ratio)
1135
+ )
1136
 
1137
  else:
1138
  feature_projection = nn.Sequential()
1139
+ feature_projection.add_module(
1140
+ "norm", nn.BatchNorm2d(level_n_features_output)
1141
+ )
1142
 
1143
  self.neck_features_proj.append(feature_projection)
1144
 
1145
+ if i > 0 and self.levels[i - 1].downsample is not None:
1146
  upsample_ratio *= 2
1147
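# For reference, the two neck up-sampling heads above reach the same output shape; a runnable
# shape comparison (channel counts and the upsample ratio are illustrative):
import torch
import torch.nn as nn

level_channels, head_dim, r = 512, 128, 4
x = torch.randn(1, level_channels, 14, 14)

deconv = nn.ConvTranspose2d(level_channels, head_dim, kernel_size=r, stride=r)
pixel_shuffle_head = nn.Sequential(
    nn.Conv2d(level_channels, head_dim * r * r, kernel_size=1, stride=1),
    nn.PixelShuffle(r),
)
print(deconv(x).shape, pixel_shuffle_head(x).shape)   # both torch.Size([1, 128, 56, 56])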
 
1148
+ num_features = (
1149
+ full_features_head_dim
1150
+ if (self.return_full_features or self.use_neck)
1151
+ else num_features
1152
+ )
1153
 
1154
  self.num_features = num_features
1155
 
1156
+ self.norm = (
1157
+ LayerNorm2d(num_features)
1158
+ if layer_norm_last
1159
+ else nn.BatchNorm2d(num_features)
1160
+ )
1161
  self.avgpool = nn.AdaptiveAvgPool2d(1)
1162
+ self.head = (
1163
+ nn.Linear(num_features, num_classes) if num_classes > 0 else nn.Identity()
1164
+ )
1165
  self.apply(self._init_weights)
1166
  # pass
1167
 
1168
  def _init_weights(self, m):
1169
  if isinstance(m, nn.Linear):
1170
+ trunc_normal_(m.weight, std=0.02)
1171
  if isinstance(m, nn.Linear) and m.bias is not None:
1172
  nn.init.constant_(m.bias, 0)
1173
  elif isinstance(m, nn.LayerNorm):
 
1182
 
1183
  @torch.jit.ignore
1184
  def no_weight_decay_keywords(self):
1185
+ return {"rpb"}
1186
 
1187
  def forward_features(self, x):
1188
  x = self.patch_embed(x)
 
1191
  x, pre_downsample_x = level(x)
1192
 
1193
  if self.return_full_features or self.use_neck:
1194
+ if self.neck_start_stage > il:
1195
+ continue
1196
  if full_features is None:
1197
+ full_features = self.neck_features_proj[il - self.neck_start_stage](
1198
+ pre_downsample_x
1199
+ )
1200
  else:
1201
+ # project pre_downsample_x, match the spatial size of full_features (pad/crop if needed), and add
1202
+ feature_projection = self.neck_features_proj[
1203
+ il - self.neck_start_stage
1204
+ ](pre_downsample_x)
1205
+ if (
1206
+ feature_projection.shape[2] != full_features.shape[2]
1207
+ or feature_projection.shape[3] != full_features.shape[3]
1208
+ ):
1209
+ feature_projection = torch.nn.functional.pad(
1210
+ feature_projection,
1211
+ (
1212
+ 0,
1213
+ -feature_projection.shape[3] + full_features.shape[3],
1214
+ 0,
1215
+ -feature_projection.shape[2] + full_features.shape[2],
1216
+ ),
1217
+ )
1218
  full_features += feature_projection
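# Note that the torch.nn.functional.pad call above uses negative amounts whenever the projected
# map is larger than full_features, which crops the right/bottom edges instead of padding; a
# quick runnable check of that behaviour:
import torch
import torch.nn.functional as F

a = torch.randn(1, 8, 15, 15)                          # projected features (slightly larger)
target_h, target_w = 14, 14                            # full_features spatial size
a = F.pad(a, (0, target_w - a.shape[3], 0, target_h - a.shape[2]))
print(a.shape)                                         # torch.Size([1, 8, 14, 14])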
1219
 
1220
  # x = self.norm(full_features if (self.return_full_features or self.use_neck) else x)
1221
+ x = self.norm(x)  # new version: normalize only the classifier path
1222
  x = self.avgpool(x)
1223
  x = torch.flatten(x, 1)
1224
 
 
1235
  return x
1236
 
1237
  def switch_to_deploy(self):
1238
+ """
1239
  A method to perform model self-compression
1240
  merges BN into conv layers
1241
  converts MLP relative positional bias into precomputed buffers
1242
+ """
1243
  for level in [self.patch_embed, self.levels, self.head]:
1244
  for module in level.modules():
1245
+ if hasattr(module, "switch_to_deploy"):
1246
  module.switch_to_deploy()
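# A typical deployment flow (sketch; the chosen variant and input size are illustrative):
# fold the Conv2d_BN pairs and freeze the positional-bias buffers before export or benchmarking.
import torch

model = fastervit2_small(num_classes=1000)
model.eval()
model.switch_to_deploy()
with torch.no_grad():
    logits = model(torch.randn(1, 3, 224, 224))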
1247
 
1248
+
1249
  @register_model
1250
+ def fastervit2_small(pretrained=False, **kwargs):
1251
+ model = FasterViT(
1252
+ depths=[3, 3, 5, 5],
1253
+ num_heads=[2, 4, 8, 16],
1254
+ window_size=[8, 8, [7, 7], 7],
1255
+ dim=96,
1256
+ in_dim=64,
1257
+ mlp_ratio=4,
1258
+ drop_path_rate=0.2,
1259
+ sr_ratio=[1, 1, [1, 2], 1],
1260
+ use_swiglu=False,
1261
+ downsample_shuffle=False,
1262
+ yolo_arch=True,
1263
+ shuffle_down=False,
1264
+ **kwargs,
1265
+ )
1266
  if pretrained:
1267
  model.load_state_dict(torch.load(pretrained))
1268
  return model
1269
 
1270
+
1271
  @register_model
1272
+ def fastervit2_tiny(pretrained=False, **kwargs):
1273
+ model = FasterViT(
1274
+ depths=[1, 3, 4, 5],
1275
+ num_heads=[2, 4, 8, 16],
1276
+ window_size=[8, 8, [7, 7], 7],
1277
+ dim=80,
1278
+ in_dim=64,
1279
+ mlp_ratio=4,
1280
+ drop_path_rate=0.2,
1281
+ sr_ratio=[1, 1, [2, 1], 1],
1282
+ use_swiglu=False,
1283
+ downsample_shuffle=False,
1284
+ yolo_arch=True,
1285
+ shuffle_down=False,
1286
+ **kwargs,
1287
+ )
1288
  if pretrained:
1289
  model.load_state_dict(torch.load(pretrained))
1290
  return model
1291
 
1292
+
1293
  @register_model
1294
  def fastervit2_base(pretrained=False, **kwargs):
1295
+ model = FasterViT(
1296
+ depths=[3, 3, 5, 5],
1297
+ num_heads=[2, 4, 8, 16],
1298
+ window_size=[8, 8, [7, 7], 7],
1299
+ dim=128,
1300
+ in_dim=64,
1301
+ mlp_ratio=4,
1302
+ drop_path_rate=0.2,
1303
+ sr_ratio=[1, 1, [2, 1], 1],
1304
+ use_swiglu=False,
1305
+ yolo_arch=True,
1306
+ shuffle_down=False,
1307
+ conv_base=True,
1308
+ **kwargs,
1309
+ )
1310
  if pretrained:
1311
  model.load_state_dict(torch.load(pretrained))
1312
  return model
1313
 
1314
+
1315
  @register_model
1316
  def fastervit2_base_fullres1(pretrained=False, **kwargs):
1317
+ model = FasterViT(
1318
+ depths=[3, 3, 5, 5],
1319
+ num_heads=[2, 4, 8, 16],
1320
+ window_size=[8, 8, [7, 7], 7],
1321
+ dim=128,
1322
+ in_dim=64,
1323
+ mlp_ratio=4,
1324
+ drop_path_rate=0.2,
1325
+ sr_ratio=[1, 1, [2, 1], 1],
1326
+ use_swiglu=False,
1327
+ yolo_arch=True,
1328
+ shuffle_down=False,
1329
+ conv_base=True,
1330
+ use_neck=True,
1331
+ full_features_head_dim=1024,
1332
+ neck_start_stage=2,
1333
+ **kwargs,
1334
+ )
1335
  if pretrained:
1336
  model.load_state_dict(torch.load(pretrained))
1337
  return model
1338
 
1339
+
1340
  @register_model
1341
  def fastervit2_base_fullres2(pretrained=False, **kwargs):
1342
+ model = FasterViT(
1343
+ depths=[3, 3, 5, 5],
1344
+ num_heads=[2, 4, 8, 16],
1345
+ window_size=[8, 8, [7, 7], 7],
1346
+ dim=128,
1347
+ in_dim=64,
1348
+ mlp_ratio=4,
1349
+ drop_path_rate=0.2,
1350
+ sr_ratio=[1, 1, [2, 1], 1],
1351
+ use_swiglu=False,
1352
+ yolo_arch=True,
1353
+ shuffle_down=False,
1354
+ conv_base=True,
1355
+ use_neck=True,
1356
+ full_features_head_dim=512,
1357
+ neck_start_stage=1,
1358
+ **kwargs,
1359
+ )
1360
  if pretrained:
1361
  model.load_state_dict(torch.load(pretrained))
1362
  return model
1363
 
1364
+
1365
  @register_model
1366
  def fastervit2_base_fullres3(pretrained=False, **kwargs):
1367
+ model = FasterViT(
1368
+ depths=[3, 3, 5, 5],
1369
+ num_heads=[2, 4, 8, 16],
1370
+ window_size=[8, 8, [7, 7], 7],
1371
+ dim=128,
1372
+ in_dim=64,
1373
+ mlp_ratio=4,
1374
+ drop_path_rate=0.2,
1375
+ sr_ratio=[1, 1, [2, 1], 1],
1376
+ use_swiglu=False,
1377
+ yolo_arch=True,
1378
+ shuffle_down=False,
1379
+ conv_base=True,
1380
+ use_neck=True,
1381
+ full_features_head_dim=256,
1382
+ neck_start_stage=1,
1383
+ **kwargs,
1384
+ )
1385
  if pretrained:
1386
  model.load_state_dict(torch.load(pretrained))
1387
  return model
1388
 
1389
+
1390
  @register_model
1391
  def fastervit2_base_fullres4(pretrained=False, **kwargs):
1392
+ model = FasterViT(
1393
+ depths=[3, 3, 5, 5],
1394
+ num_heads=[2, 4, 8, 16],
1395
+ window_size=[8, 8, [7, 7], 7],
1396
+ dim=128,
1397
+ in_dim=64,
1398
+ mlp_ratio=4,
1399
+ drop_path_rate=0.2,
1400
+ sr_ratio=[1, 1, [2, 1], 1],
1401
+ use_swiglu=False,
1402
+ yolo_arch=True,
1403
+ shuffle_down=False,
1404
+ conv_base=True,
1405
+ use_neck=True,
1406
+ full_features_head_dim=256,
1407
+ neck_start_stage=2,
1408
+ **kwargs,
1409
+ )
1410
  if pretrained:
1411
  model.load_state_dict(torch.load(pretrained))
1412
  return model
1413
 
1414
+
1415
  @register_model
1416
  def fastervit2_base_fullres5(pretrained=False, **kwargs):
1417
+ model = FasterViT(
1418
+ depths=[3, 3, 5, 5],
1419
+ num_heads=[2, 4, 8, 16],
1420
+ window_size=[8, 8, [7, 7], 7],
1421
+ dim=128,
1422
+ in_dim=64,
1423
+ mlp_ratio=4,
1424
+ drop_path_rate=0.2,
1425
+ sr_ratio=[1, 1, [2, 1], 1],
1426
+ use_swiglu=False,
1427
+ yolo_arch=True,
1428
+ shuffle_down=False,
1429
+ conv_base=True,
1430
+ use_neck=True,
1431
+ full_features_head_dim=512,
1432
+ neck_start_stage=2,
1433
+ **kwargs,
1434
+ )
1435
  if pretrained:
1436
  model.load_state_dict(torch.load(pretrained))
1437
  return model
1438
 
1439
+
1440
+ # pyt: 1934, TRT: 4202
1441
  @register_model
1442
  def fastervit2_large(pretrained=False, **kwargs):
1443
+ model = FasterViT(
1444
+ depths=[3, 3, 5, 5],
1445
+ num_heads=[2, 4, 8, 16],
1446
+ window_size=[8, 8, [7, 7], 7],
1447
+ dim=128 + 64,
1448
+ in_dim=64,
1449
+ mlp_ratio=4,
1450
+ drop_path_rate=0.2,
1451
+ sr_ratio=[1, 1, [2, 1], 1],
1452
+ use_swiglu=False,
1453
+ yolo_arch=True,
1454
+ shuffle_down=False,
1455
+ **kwargs,
1456
+ )
1457
  if pretrained:
1458
  model.load_state_dict(torch.load(pretrained))
1459
  return model
1460
 
1461
+
1462
  @register_model
1463
  def fastervit2_large_fullres(pretrained=False, **kwargs):
1464
+ model = FasterViT(
1465
+ depths=[3, 3, 5, 5],
1466
+ num_heads=[2, 4, 8, 16],
1467
+ window_size=[None, None, [7, 7], 7],
1468
+ dim=192,
1469
+ in_dim=64,
1470
+ mlp_ratio=4,
1471
+ drop_path_rate=0.0,
1472
+ sr_ratio=[1, 1, [2, 1], 1],
1473
+ use_swiglu=False,
1474
+ yolo_arch=True,
1475
+ shuffle_down=False,
1476
+ conv_base=True,
1477
+ use_neck=True,
1478
+ full_features_head_dim=1536,
1479
+ neck_start_stage=2,
1480
+ **kwargs,
1481
+ )
1482
  if pretrained:
1483
  model.load_state_dict(torch.load(pretrained))
1484
  return model
1485
 
1486
+
1487
  @register_model
1488
  def fastervit2_large_fullres_ws8(pretrained=False, **kwargs):
1489
+ model = FasterViT(
1490
+ depths=[3, 3, 5, 5],
1491
+ num_heads=[2, 4, 8, 16],
1492
+ window_size=[None, None, [8, 8], 8],
1493
+ dim=192,
1494
+ in_dim=64,
1495
+ mlp_ratio=4,
1496
+ drop_path_rate=0.0,
1497
+ sr_ratio=[1, 1, [2, 1], 1],
1498
+ use_swiglu=False,
1499
+ yolo_arch=True,
1500
+ shuffle_down=False,
1501
+ conv_base=True,
1502
+ use_neck=True,
1503
+ full_features_head_dim=1536,
1504
+ neck_start_stage=2,
1505
+ **kwargs,
1506
+ )
1507
  if pretrained:
1508
  model.load_state_dict(torch.load(pretrained))
1509
  return model
1510
 
1511
+
1512
  @register_model
1513
  def fastervit2_large_fullres_ws16(pretrained=False, **kwargs):
1514
+ model = FasterViT(
1515
+ depths=[3, 3, 5, 5],
1516
+ num_heads=[2, 4, 8, 16],
1517
+ window_size=[None, None, [16, 16], 16],
1518
+ dim=192,
1519
+ in_dim=64,
1520
+ mlp_ratio=4,
1521
+ drop_path_rate=0.0,
1522
+ sr_ratio=[1, 1, [2, 1], 1],
1523
+ use_swiglu=False,
1524
+ yolo_arch=True,
1525
+ shuffle_down=False,
1526
+ conv_base=True,
1527
+ use_neck=True,
1528
+ full_features_head_dim=1536,
1529
+ neck_start_stage=2,
1530
+ **kwargs,
1531
+ )
1532
  if pretrained:
1533
  model.load_state_dict(torch.load(pretrained))
1534
  return model
1535
 
1536
+
1537
  @register_model
1538
  def fastervit2_large_fullres_ws32(pretrained=False, **kwargs):
1539
+ model = FasterViT(
1540
+ depths=[3, 3, 5, 5],
1541
+ num_heads=[2, 4, 8, 16],
1542
+ window_size=[None, None, [32, 32], 32],
1543
+ dim=192,
1544
+ in_dim=64,
1545
+ mlp_ratio=4,
1546
+ drop_path_rate=0.0,
1547
+ sr_ratio=[1, 1, [2, 1], 1],
1548
+ use_swiglu=False,
1549
+ yolo_arch=True,
1550
+ shuffle_down=False,
1551
+ conv_base=True,
1552
+ use_neck=True,
1553
+ full_features_head_dim=1536,
1554
+ neck_start_stage=2,
1555
+ **kwargs,
1556
+ )
1557
  if pretrained:
1558
  model.load_state_dict(torch.load(pretrained))
1559
  return model
1560
 
1561
+
1562
+ # pyt: 897
1563
  @register_model
1564
  def fastervit2_xlarge(pretrained=False, **kwargs):
1565
+ model = FasterViT(
1566
+ depths=[3, 3, 5, 5],
1567
+ num_heads=[2, 4, 8, 16],
1568
+ window_size=[8, 8, [7, 7], 7],
1569
+ dim=128 + 128 + 64,
1570
+ in_dim=64,
1571
+ mlp_ratio=4,
1572
+ drop_path_rate=0.2,
1573
+ sr_ratio=[1, 1, [2, 1], 1],
1574
+ use_swiglu=False,
1575
+ yolo_arch=True,
1576
+ shuffle_down=False,
1577
+ **kwargs,
1578
+ )
1579
  if pretrained:
1580
  model.load_state_dict(torch.load(pretrained))
1581
  return model
1582
 
1583
 
1584
+ # pyt:
1585
  @register_model
1586
  def fastervit2_huge(pretrained=False, **kwargs):
1587
+ model = FasterViT(
1588
+ depths=[3, 3, 5, 5],
1589
+ num_heads=[2, 4, 8, 16],
1590
+ window_size=[8, 8, [7, 7], 7],
1591
+ dim=128 + 128 + 128 + 64,
1592
+ in_dim=64,
1593
+ mlp_ratio=4,
1594
+ drop_path_rate=0.2,
1595
+ sr_ratio=[1, 1, [2, 1], 1],
1596
+ use_swiglu=False,
1597
+ yolo_arch=True,
1598
+ shuffle_down=False,
1599
+ **kwargs,
1600
+ )
1601
  if pretrained:
1602
  model.load_state_dict(torch.load(pretrained))
1603
  return model
1604
 
1605
 
1606
  @register_model
1607
+ def fastervit2_xtiny(pretrained=False, **kwargs):
1608
+ model = FasterViT(
1609
+ depths=[1, 3, 4, 5],
1610
+ num_heads=[2, 4, 8, 16],
1611
+ window_size=[8, 8, [7, 7], 7],
1612
+ dim=64,
1613
+ in_dim=64,
1614
+ mlp_ratio=4,
1615
+ drop_path_rate=0.1,
1616
+ sr_ratio=[1, 1, [2, 1], 1],
1617
+ use_swiglu=False,
1618
+ downsample_shuffle=False,
1619
+ yolo_arch=True,
1620
+ shuffle_down=False,
1621
+ **kwargs,
1622
+ )
1623
  if pretrained:
1624
  model.load_state_dict(torch.load(pretrained))
1625
  return model
1626
 
1627
 
1628
  @register_model
1629
+ def fastervit2_xxtiny_5(pretrained=False, **kwargs):
1630
+ model = FasterViT(
1631
+ depths=[1, 3, 4, 5],
1632
+ num_heads=[2, 4, 8, 16],
1633
+ window_size=[8, 8, [7, 7], 7],
1634
+ dim=48,
1635
+ in_dim=64,
1636
+ mlp_ratio=4,
1637
+ drop_path_rate=0.05,
1638
+ sr_ratio=[1, 1, [2, 1], 1],
1639
+ use_swiglu=False,
1640
+ downsample_shuffle=False,
1641
+ yolo_arch=True,
1642
+ shuffle_down=False,
1643
+ **kwargs,
1644
+ )
1645
  if pretrained:
1646
  model.load_state_dict(torch.load(pretrained))
1647
  return model
1648
 
1649
+
1650
  @register_model
1651
+ def fastervit2_xxxtiny(pretrained=False, **kwargs):
1652
+ model = FasterViT(
1653
+ depths=[1, 3, 4, 5],
1654
+ num_heads=[2, 4, 8, 16],
1655
+ window_size=[8, 8, [7, 7], 7],
1656
+ dim=32,
1657
+ in_dim=32,
1658
+ mlp_ratio=4,
1659
+ drop_path_rate=0.0,
1660
+ sr_ratio=[1, 1, [2, 1], 1],
1661
+ use_swiglu=False,
1662
+ downsample_shuffle=False,
1663
+ yolo_arch=True,
1664
+ shuffle_down=False,
1665
+ **kwargs,
1666
+ )
1667
  if pretrained:
1668
  model.load_state_dict(torch.load(pretrained))
1669
  return model
 
1671
 
1672
  @register_model
1673
  def eradio(pretrained=False, **kwargs):
1674
+ return fastervit2_large_fullres_ws16(pretrained=pretrained, **kwargs)
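# Because every variant above is registered through timm's @register_model, the backbones can
# also be built by name once this module has been imported (a sketch; pretrained weights must be
# supplied separately, since the registrations here only load a state dict from a local path):
import timm
import torch

model = timm.create_model("eradio", pretrained=False)
model.eval()
with torch.no_grad():
    out = model(torch.randn(1, 3, 512, 512))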
pytorch_model.bin CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:115b8f54d0d4999c180718ce138f8078127af9815b6cb507b253e5db10a5723c
3
- size 1057766065
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:3840092575224b5ff90adf3b7970a5a5e379f8988241ee8145969e27c32c17e7
3
+ size 1105844337