ahatamiz committed on
Commit
e3785c4
1 Parent(s): ede60d3

Create mamba_vision.py

Files changed (1): mamba_vision.py (+865, -0)
mamba_vision.py ADDED
@@ -0,0 +1,865 @@
#!/usr/bin/env python3

# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
#
# NVIDIA CORPORATION and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto. Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION is strictly prohibited.


import torch
import torch.nn as nn
from timm.models.registry import register_model
import math
from timm.models.layers import trunc_normal_, DropPath, LayerNorm2d
from timm.models._builder import resolve_pretrained_cfg
try:
    from timm.models._builder import _update_default_kwargs as update_args
except ImportError:
    from timm.models._builder import _update_default_model_kwargs as update_args
from timm.models.vision_transformer import Mlp, PatchEmbed
import torch.nn.functional as F
from mamba_ssm.ops.selective_scan_interface import selective_scan_fn
from einops import rearrange, repeat
from pathlib import Path
from huggingface_hub import PyTorchModelHubMixin


def _cfg(url='', **kwargs):
    return {'url': url,
            'num_classes': 1000,
            'input_size': (3, 224, 224),
            'pool_size': None,
            'crop_pct': 0.875,
            'interpolation': 'bicubic',
            'fixed_input_size': True,
            'mean': (0.485, 0.456, 0.406),
            'std': (0.229, 0.224, 0.225),
            **kwargs
            }


default_cfgs = {
    'mamba_vision_T': _cfg(url='https://huggingface.co/nvidia/MambaVision-T-1K/resolve/main/mambavision_tiny_1k.pth.tar',
                           crop_pct=1.0,
                           input_size=(3, 224, 224),
                           crop_mode='center'),
    'mamba_vision_T2': _cfg(url='https://huggingface.co/nvidia/MambaVision-T2-1K/resolve/main/mambavision_tiny2_1k.pth.tar',
                            crop_pct=0.98,
                            input_size=(3, 224, 224),
                            crop_mode='center'),
    'mamba_vision_S': _cfg(url='https://huggingface.co/nvidia/MambaVision-S-1K/resolve/main/mambavision_small_1k.pth.tar',
                           crop_pct=0.93,
                           input_size=(3, 224, 224),
                           crop_mode='center'),
    'mamba_vision_B': _cfg(url='https://huggingface.co/nvidia/MambaVision-B-1K/resolve/main/mambavision_base_1k.pth.tar',
                           crop_pct=1.0,
                           input_size=(3, 224, 224),
                           crop_mode='center'),
    'mamba_vision_L': _cfg(url='https://huggingface.co/nvidia/MambaVision-L-1K/resolve/main/mambavision_large_1k.pth.tar',
                           crop_pct=1.0,
                           input_size=(3, 224, 224),
                           crop_mode='center'),
    'mamba_vision_L2': _cfg(url='https://huggingface.co/nvidia/MambaVision-L2-1K/resolve/main/mambavision_large2_1k.pth.tar',
                            crop_pct=1.0,
                            input_size=(3, 224, 224),
                            crop_mode='center')
}


def window_partition(x, window_size):
    """
    Args:
        x: (B, C, H, W)
        window_size: window size
    Returns:
        local window features (num_windows*B, window_size*window_size, C)
    """
    B, C, H, W = x.shape
    x = x.view(B, C, H // window_size, window_size, W // window_size, window_size)
    windows = x.permute(0, 2, 4, 3, 5, 1).reshape(-1, window_size*window_size, C)
    return windows


def window_reverse(windows, window_size, H, W):
    """
    Args:
        windows: local window features (num_windows*B, window_size*window_size, C)
        window_size: Window size
        H: Height of image
        W: Width of image
    Returns:
        x: (B, C, H, W)
    """
    B = int(windows.shape[0] / (H * W / window_size / window_size))
    x = windows.reshape(B, H // window_size, W // window_size, window_size, window_size, -1)
    x = x.permute(0, 5, 1, 3, 2, 4).reshape(B, windows.shape[2], H, W)
    return x
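

# Illustrative sketch (not part of the original commit): window_partition and
# window_reverse are exact inverses whenever H and W are multiples of window_size,
# which MambaVisionLayer.forward guarantees by padding before partitioning. The
# hypothetical helper below only demonstrates the expected shape round trip.
def _window_roundtrip_example():
    B, C, H, W, window_size = 2, 96, 56, 56, 7
    x = torch.randn(B, C, H, W)
    windows = window_partition(x, window_size)      # (B*num_windows, window_size*window_size, C) -> (128, 49, 96)
    y = window_reverse(windows, window_size, H, W)  # back to (2, 96, 56, 56)
    assert y.shape == x.shape and torch.allclose(x, y)
    return y.shape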


def _load_state_dict(module, state_dict, strict=False, logger=None):
    """Load state_dict to a module.

    This method is modified from :meth:`torch.nn.Module.load_state_dict`.
    Default value for ``strict`` is set to ``False`` and the message for
    param mismatch will be shown even if strict is False.

    Args:
        module (Module): Module that receives the state_dict.
        state_dict (OrderedDict): Weights.
        strict (bool): whether to strictly enforce that the keys
            in :attr:`state_dict` match the keys returned by this module's
            :meth:`~torch.nn.Module.state_dict` function. Default: ``False``.
        logger (:obj:`logging.Logger`, optional): Logger to log the error
            message. If not specified, print function will be used.
    """
    unexpected_keys = []
    all_missing_keys = []
    err_msg = []

    metadata = getattr(state_dict, '_metadata', None)
    state_dict = state_dict.copy()
    if metadata is not None:
        state_dict._metadata = metadata

    def load(module, prefix=''):
        local_metadata = {} if metadata is None else metadata.get(
            prefix[:-1], {})
        module._load_from_state_dict(state_dict, prefix, local_metadata, True,
                                     all_missing_keys, unexpected_keys,
                                     err_msg)
        for name, child in module._modules.items():
            if child is not None:
                load(child, prefix + name + '.')

    load(module)
    load = None  # break the load->load reference cycle
    missing_keys = [
        key for key in all_missing_keys if 'num_batches_tracked' not in key
    ]

    if unexpected_keys:
        err_msg.append('unexpected key in source '
                       f'state_dict: {", ".join(unexpected_keys)}\n')
    if missing_keys:
        err_msg.append(
            f'missing keys in source state_dict: {", ".join(missing_keys)}\n')

    if len(err_msg) > 0:
        err_msg.insert(
            0, 'The model and loaded state dict do not match exactly\n')
        err_msg = '\n'.join(err_msg)
        if strict:
            raise RuntimeError(err_msg)
        elif logger is not None:
            logger.warning(err_msg)
        else:
            print(err_msg)


def _load_checkpoint(model,
                     filename,
                     map_location='cpu',
                     strict=False,
                     logger=None):
    """Load checkpoint from a file or URI.

    Args:
        model (Module): Module to load checkpoint.
        filename (str): Accept local filepath, URL, ``torchvision://xxx``,
            ``open-mmlab://xxx``. Please refer to ``docs/model_zoo.md`` for
            details.
        map_location (str): Same as :func:`torch.load`.
        strict (bool): Whether to strictly enforce that the keys in the
            checkpoint match the keys returned by the model's state dict.
        logger (:mod:`logging.Logger` or None): The logger for error message.

    Returns:
        dict or OrderedDict: The loaded checkpoint.
    """
    checkpoint = torch.load(filename, map_location=map_location)
    if not isinstance(checkpoint, dict):
        raise RuntimeError(
            f'No state_dict found in checkpoint file {filename}')
    if 'state_dict' in checkpoint:
        state_dict = checkpoint['state_dict']
    elif 'model' in checkpoint:
        state_dict = checkpoint['model']
    else:
        state_dict = checkpoint
    if list(state_dict.keys())[0].startswith('module.'):
        state_dict = {k[7:]: v for k, v in state_dict.items()}

    if sorted(list(state_dict.keys()))[0].startswith('encoder'):
        state_dict = {k.replace('encoder.', ''): v for k, v in state_dict.items() if k.startswith('encoder.')}

    _load_state_dict(model, state_dict, strict, logger)
    return checkpoint


class Downsample(nn.Module):
    """
    Down-sampling block.
    """

    def __init__(self,
                 dim,
                 keep_dim=False,
                 ):
        """
        Args:
            dim: feature size dimension.
            keep_dim: bool argument for maintaining the resolution.
        """

        super().__init__()
        if keep_dim:
            dim_out = dim
        else:
            dim_out = 2 * dim
        self.reduction = nn.Sequential(
            nn.Conv2d(dim, dim_out, 3, 2, 1, bias=False),
        )

    def forward(self, x):
        x = self.reduction(x)
        return x


class PatchEmbed(nn.Module):
    """
    Patch embedding block.
    """

    def __init__(self, in_chans=3, in_dim=64, dim=96):
        """
        Args:
            in_chans: number of input channels.
            in_dim: intermediate feature size dimension.
            dim: feature size dimension.
        """
        super().__init__()
        self.proj = nn.Identity()
        self.conv_down = nn.Sequential(
            nn.Conv2d(in_chans, in_dim, 3, 2, 1, bias=False),
            nn.BatchNorm2d(in_dim, eps=1e-4),
            nn.ReLU(),
            nn.Conv2d(in_dim, dim, 3, 2, 1, bias=False),
            nn.BatchNorm2d(dim, eps=1e-4),
            nn.ReLU()
        )

    def forward(self, x):
        x = self.proj(x)
        x = self.conv_down(x)
        return x
265
+
266
+ class ConvBlock(nn.Module):
267
+
268
+ def __init__(self, dim,
269
+ drop_path=0.,
270
+ layer_scale=None,
271
+ kernel_size=3):
272
+ super().__init__()
273
+
274
+ self.conv1 = nn.Conv2d(dim, dim, kernel_size=kernel_size, stride=1, padding=1)
275
+ self.norm1 = nn.BatchNorm2d(dim, eps=1e-5)
276
+ self.act1 = nn.GELU(approximate= 'tanh')
277
+ self.conv2 = nn.Conv2d(dim, dim, kernel_size=kernel_size, stride=1, padding=1)
278
+ self.norm2 = nn.BatchNorm2d(dim, eps=1e-5)
279
+ self.layer_scale = layer_scale
280
+ if layer_scale is not None and type(layer_scale) in [int, float]:
281
+ self.gamma = nn.Parameter(layer_scale * torch.ones(dim))
282
+ self.layer_scale = True
283
+ else:
284
+ self.layer_scale = False
285
+ self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
286
+
287
+ def forward(self, x):
288
+ input = x
289
+ x = self.conv1(x)
290
+ x = self.norm1(x)
291
+ x = self.act1(x)
292
+ x = self.conv2(x)
293
+ x = self.norm2(x)
294
+ if self.layer_scale:
295
+ x = x * self.gamma.view(1, -1, 1, 1)
296
+ x = input + self.drop_path(x)
297
+ return x
298
+
299
+
300
+ class MambaVisionMixer(nn.Module):
301
+ def __init__(
302
+ self,
303
+ d_model,
304
+ d_state=16,
305
+ d_conv=4,
306
+ expand=2,
307
+ dt_rank="auto",
308
+ dt_min=0.001,
309
+ dt_max=0.1,
310
+ dt_init="random",
311
+ dt_scale=1.0,
312
+ dt_init_floor=1e-4,
313
+ conv_bias=True,
314
+ bias=False,
315
+ use_fast_path=True,
316
+ layer_idx=None,
317
+ device=None,
318
+ dtype=None,
319
+ ):
320
+ factory_kwargs = {"device": device, "dtype": dtype}
321
+ super().__init__()
322
+ self.d_model = d_model
323
+ self.d_state = d_state
324
+ self.d_conv = d_conv
325
+ self.expand = expand
326
+ self.d_inner = int(self.expand * self.d_model)
327
+ self.dt_rank = math.ceil(self.d_model / 16) if dt_rank == "auto" else dt_rank
328
+ self.use_fast_path = use_fast_path
329
+ self.layer_idx = layer_idx
330
+ self.in_proj = nn.Linear(self.d_model, self.d_inner, bias=bias, **factory_kwargs)
331
+ self.x_proj = nn.Linear(
332
+ self.d_inner//2, self.dt_rank + self.d_state * 2, bias=False, **factory_kwargs
333
+ )
334
+ self.dt_proj = nn.Linear(self.dt_rank, self.d_inner//2, bias=True, **factory_kwargs)
335
+ dt_init_std = self.dt_rank**-0.5 * dt_scale
336
+ if dt_init == "constant":
337
+ nn.init.constant_(self.dt_proj.weight, dt_init_std)
338
+ elif dt_init == "random":
339
+ nn.init.uniform_(self.dt_proj.weight, -dt_init_std, dt_init_std)
340
+ else:
341
+ raise NotImplementedError
342
+ dt = torch.exp(
343
+ torch.rand(self.d_inner//2, **factory_kwargs) * (math.log(dt_max) - math.log(dt_min))
344
+ + math.log(dt_min)
345
+ ).clamp(min=dt_init_floor)
346
+ inv_dt = dt + torch.log(-torch.expm1(-dt))
347
+ with torch.no_grad():
348
+ self.dt_proj.bias.copy_(inv_dt)
349
+ self.dt_proj.bias._no_reinit = True
350
+ A = repeat(
351
+ torch.arange(1, self.d_state + 1, dtype=torch.float32, device=device),
352
+ "n -> d n",
353
+ d=self.d_inner//2,
354
+ ).contiguous()
355
+ A_log = torch.log(A)
356
+ self.A_log = nn.Parameter(A_log)
357
+ self.A_log._no_weight_decay = True
358
+ self.D = nn.Parameter(torch.ones(self.d_inner//2, device=device))
359
+ self.D._no_weight_decay = True
360
+ self.out_proj = nn.Linear(self.d_inner, self.d_model, bias=bias, **factory_kwargs)
361
+ self.conv1d_x = nn.Conv1d(
362
+ in_channels=self.d_inner//2,
363
+ out_channels=self.d_inner//2,
364
+ bias=conv_bias//2,
365
+ kernel_size=d_conv,
366
+ groups=self.d_inner//2,
367
+ **factory_kwargs,
368
+ )
369
+ self.conv1d_z = nn.Conv1d(
370
+ in_channels=self.d_inner//2,
371
+ out_channels=self.d_inner//2,
372
+ bias=conv_bias//2,
373
+ kernel_size=d_conv,
374
+ groups=self.d_inner//2,
375
+ **factory_kwargs,
376
+ )
377
+
378
+ def forward(self, hidden_states):
379
+ """
380
+ hidden_states: (B, L, D)
381
+ Returns: same shape as hidden_states
382
+ """
383
+ _, seqlen, _ = hidden_states.shape
384
+ xz = self.in_proj(hidden_states)
385
+ xz = rearrange(xz, "b l d -> b d l")
386
+ x, z = xz.chunk(2, dim=1)
387
+ A = -torch.exp(self.A_log.float())
388
+ x = F.silu(F.conv1d(input=x, weight=self.conv1d_x.weight, bias=self.conv1d_x.bias, padding='same', groups=self.d_inner//2))
389
+ z = F.silu(F.conv1d(input=z, weight=self.conv1d_z.weight, bias=self.conv1d_z.bias, padding='same', groups=self.d_inner//2))
390
+ x_dbl = self.x_proj(rearrange(x, "b d l -> (b l) d"))
391
+ dt, B, C = torch.split(x_dbl, [self.dt_rank, self.d_state, self.d_state], dim=-1)
392
+ dt = rearrange(self.dt_proj(dt), "(b l) d -> b d l", l=seqlen)
393
+ B = rearrange(B, "(b l) dstate -> b dstate l", l=seqlen).contiguous()
394
+ C = rearrange(C, "(b l) dstate -> b dstate l", l=seqlen).contiguous()
395
+ y = selective_scan_fn(x,
396
+ dt,
397
+ A,
398
+ B,
399
+ C,
400
+ self.D.float(),
401
+ z=None,
402
+ delta_bias=self.dt_proj.bias.float(),
403
+ delta_softplus=True,
404
+ return_last_state=None)
405
+
406
+ y = torch.cat([y, z], dim=1)
407
+ y = rearrange(y, "b d l -> b l d")
408
+ out = self.out_proj(y)
409
+ return out
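

# Illustrative sketch (not part of the original commit): the mixer is token-wise,
# mapping (B, L, d_model) to (B, L, d_model). in_proj produces d_inner channels that
# are split in half between the scanned branch and the gate branch, so with
# d_model=320 and expand=1 each branch carries 160 channels. Running this requires
# a CUDA device and a working mamba_ssm build for selective_scan_fn.
def _mamba_mixer_shape_example():
    mixer = MambaVisionMixer(d_model=320, d_state=8, d_conv=3, expand=1).cuda()
    tokens = torch.randn(2, 49, 320, device="cuda")  # (B, L, d_model)
    out = mixer(tokens)
    assert out.shape == tokens.shape
    return out.shape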
410
+
411
+
412
+ class Attention(nn.Module):
413
+
414
+ def __init__(
415
+ self,
416
+ dim,
417
+ num_heads=8,
418
+ qkv_bias=False,
419
+ qk_norm=False,
420
+ attn_drop=0.,
421
+ proj_drop=0.,
422
+ norm_layer=nn.LayerNorm,
423
+ ):
424
+ super().__init__()
425
+ assert dim % num_heads == 0
426
+ self.num_heads = num_heads
427
+ self.head_dim = dim // num_heads
428
+ self.scale = self.head_dim ** -0.5
429
+ self.fused_attn = True
430
+
431
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
432
+ self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
433
+ self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
434
+ self.attn_drop = nn.Dropout(attn_drop)
435
+ self.proj = nn.Linear(dim, dim)
436
+ self.proj_drop = nn.Dropout(proj_drop)
437
+
438
+ def forward(self, x):
439
+ B, N, C = x.shape
440
+ qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
441
+ q, k, v = qkv.unbind(0)
442
+ q, k = self.q_norm(q), self.k_norm(k)
443
+
444
+ if self.fused_attn:
445
+ x = F.scaled_dot_product_attention(
446
+ q, k, v,
447
+ dropout_p=self.attn_drop.p,
448
+ )
449
+ else:
450
+ q = q * self.scale
451
+ attn = q @ k.transpose(-2, -1)
452
+ attn = attn.softmax(dim=-1)
453
+ attn = self.attn_drop(attn)
454
+ x = attn @ v
455
+
456
+ x = x.transpose(1, 2).reshape(B, N, C)
457
+ x = self.proj(x)
458
+ x = self.proj_drop(x)
459
+ return x
460
+
461
+
462
+ class Block(nn.Module):
463
+ def __init__(self,
464
+ dim,
465
+ num_heads,
466
+ counter,
467
+ transformer_blocks,
468
+ mlp_ratio=4.,
469
+ qkv_bias=False,
470
+ qk_scale=False,
471
+ drop=0.,
472
+ attn_drop=0.,
473
+ drop_path=0.,
474
+ act_layer=nn.GELU,
475
+ norm_layer=nn.LayerNorm,
476
+ Mlp_block=Mlp,
477
+ layer_scale=None,
478
+ ):
479
+ super().__init__()
480
+ self.norm1 = norm_layer(dim)
481
+ if counter in transformer_blocks:
482
+ self.mixer = Attention(
483
+ dim,
484
+ num_heads=num_heads,
485
+ qkv_bias=qkv_bias,
486
+ qk_norm=qk_scale,
487
+ attn_drop=attn_drop,
488
+ proj_drop=drop,
489
+ norm_layer=norm_layer,
490
+ )
491
+ else:
492
+ self.mixer = MambaVisionMixer(d_model=dim,
493
+ d_state=8,
494
+ d_conv=3,
495
+ expand=1
496
+ )
497
+
498
+ self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
499
+ self.norm2 = norm_layer(dim)
500
+ mlp_hidden_dim = int(dim * mlp_ratio)
501
+ self.mlp = Mlp_block(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
502
+ use_layer_scale = layer_scale is not None and type(layer_scale) in [int, float]
503
+ self.gamma_1 = nn.Parameter(layer_scale * torch.ones(dim)) if use_layer_scale else 1
504
+ self.gamma_2 = nn.Parameter(layer_scale * torch.ones(dim)) if use_layer_scale else 1
505
+
506
+ def forward(self, x):
507
+ x = x + self.drop_path(self.gamma_1 * self.mixer(self.norm1(x)))
508
+ x = x + self.drop_path(self.gamma_2 * self.mlp(self.norm2(x)))
509
+ return x
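

# Illustrative sketch (not part of the original commit): `counter` and
# `transformer_blocks` decide the mixer type per block index. MambaVision builds
# each non-conv stage so that the first half of the blocks use MambaVisionMixer
# and the second half use Attention, e.g. for a stage of depth 8 the attention
# indices are list(range(8 // 2, 8)) == [4, 5, 6, 7].
def _mixer_schedule_example(depth=8):
    if depth % 2 != 0:
        attention_indices = list(range(depth // 2 + 1, depth))
    else:
        attention_indices = list(range(depth // 2, depth))
    return ["attention" if i in attention_indices else "mamba" for i in range(depth)]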
510
+
511
+
512
+ class MambaVisionLayer(nn.Module):
513
+ """
514
+ MambaVision layer"
515
+ """
516
+
517
+ def __init__(self,
518
+ dim,
519
+ depth,
520
+ num_heads,
521
+ window_size,
522
+ conv=False,
523
+ downsample=True,
524
+ mlp_ratio=4.,
525
+ qkv_bias=True,
526
+ qk_scale=None,
527
+ drop=0.,
528
+ attn_drop=0.,
529
+ drop_path=0.,
530
+ layer_scale=None,
531
+ layer_scale_conv=None,
532
+ transformer_blocks = [],
533
+ ):
534
+ """
535
+ Args:
536
+ dim: feature size dimension.
537
+ depth: number of layers in each stage.
538
+ window_size: window size in each stage.
539
+ conv: bool argument for conv stage flag.
540
+ downsample: bool argument for down-sampling.
541
+ mlp_ratio: MLP ratio.
542
+ num_heads: number of heads in each stage.
543
+ qkv_bias: bool argument for query, key, value learnable bias.
544
+ qk_scale: bool argument to scaling query, key.
545
+ drop: dropout rate.
546
+ attn_drop: attention dropout rate.
547
+ drop_path: drop path rate.
548
+ norm_layer: normalization layer.
549
+ layer_scale: layer scaling coefficient.
550
+ layer_scale_conv: conv layer scaling coefficient.
551
+ transformer_blocks: list of transformer blocks.
552
+ """
553
+
554
+ super().__init__()
555
+ self.conv = conv
556
+ self.transformer_block = False
557
+ if conv:
558
+ self.blocks = nn.ModuleList([ConvBlock(dim=dim,
559
+ drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
560
+ layer_scale=layer_scale_conv)
561
+ for i in range(depth)])
562
+ self.transformer_block = False
563
+ else:
564
+ self.transformer_block = True
565
+ self.blocks = nn.ModuleList([Block(dim=dim,
566
+ counter=i,
567
+ transformer_blocks=transformer_blocks,
568
+ num_heads=num_heads,
569
+ mlp_ratio=mlp_ratio,
570
+ qkv_bias=qkv_bias,
571
+ qk_scale=qk_scale,
572
+ drop=drop,
573
+ attn_drop=attn_drop,
574
+ drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
575
+ layer_scale=layer_scale)
576
+ for i in range(depth)])
577
+ self.transformer_block = True
578
+
579
+ self.downsample = None if not downsample else Downsample(dim=dim)
580
+ self.do_gt = False
581
+ self.window_size = window_size
582
+
583
+ def forward(self, x):
584
+ _, _, H, W = x.shape
585
+
586
+ if self.transformer_block:
587
+ pad_r = (self.window_size - W % self.window_size) % self.window_size
588
+ pad_b = (self.window_size - H % self.window_size) % self.window_size
589
+ if pad_r > 0 or pad_b > 0:
590
+ x = torch.nn.functional.pad(x, (0,pad_r,0,pad_b))
591
+ _, _, Hp, Wp = x.shape
592
+ else:
593
+ Hp, Wp = H, W
594
+ x = window_partition(x, self.window_size)
595
+
596
+ for _, blk in enumerate(self.blocks):
597
+ x = blk(x)
598
+ if self.transformer_block:
599
+ x = window_reverse(x, self.window_size, Hp, Wp)
600
+ if pad_r > 0 or pad_b > 0:
601
+ x = x[:, :, :H, :W].contiguous()
602
+ if self.downsample is None:
603
+ return x
604
+ return self.downsample(x)
605
+
606
+
607
+ class MambaVision(nn.Module, PyTorchModelHubMixin):
608
+ """
609
+ MambaVision,
610
+ """
611
+
612
+ def __init__(self,
613
+ dim,
614
+ in_dim,
615
+ depths,
616
+ window_size,
617
+ mlp_ratio,
618
+ num_heads,
619
+ drop_path_rate=0.2,
620
+ in_chans=3,
621
+ num_classes=1000,
622
+ qkv_bias=True,
623
+ qk_scale=None,
624
+ drop_rate=0.,
625
+ attn_drop_rate=0.,
626
+ layer_scale=None,
627
+ layer_scale_conv=None,
628
+ **kwargs):
629
+ """
630
+ Args:
631
+ dim: feature size dimension.
632
+ depths: number of layers in each stage.
633
+ window_size: window size in each stage.
634
+ mlp_ratio: MLP ratio.
635
+ num_heads: number of heads in each stage.
636
+ drop_path_rate: drop path rate.
637
+ in_chans: number of input channels.
638
+ num_classes: number of classes.
639
+ qkv_bias: bool argument for query, key, value learnable bias.
640
+ qk_scale: bool argument to scaling query, key.
641
+ drop_rate: dropout rate.
642
+ attn_drop_rate: attention dropout rate.
643
+ norm_layer: normalization layer.
644
+ layer_scale: layer scaling coefficient.
645
+ layer_scale_conv: conv layer scaling coefficient.
646
+ """
647
+ super().__init__()
648
+ num_features = int(dim * 2 ** (len(depths) - 1))
649
+ self.num_classes = num_classes
650
+ self.patch_embed = PatchEmbed(in_chans=in_chans, in_dim=in_dim, dim=dim)
651
+ dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]
652
+ self.levels = nn.ModuleList()
653
+ for i in range(len(depths)):
654
+ conv = True if (i == 0 or i == 1) else False
655
+ level = MambaVisionLayer(dim=int(dim * 2 ** i),
656
+ depth=depths[i],
657
+ num_heads=num_heads[i],
658
+ window_size=window_size[i],
659
+ mlp_ratio=mlp_ratio,
660
+ qkv_bias=qkv_bias,
661
+ qk_scale=qk_scale,
662
+ conv=conv,
663
+ drop=drop_rate,
664
+ attn_drop=attn_drop_rate,
665
+ drop_path=dpr[sum(depths[:i]):sum(depths[:i + 1])],
666
+ downsample=(i < 3),
667
+ layer_scale=layer_scale,
668
+ layer_scale_conv=layer_scale_conv,
669
+ transformer_blocks=list(range(depths[i]//2+1, depths[i])) if depths[i]%2!=0 else list(range(depths[i]//2, depths[i])),
670
+ )
671
+ self.levels.append(level)
672
+ self.norm = nn.BatchNorm2d(num_features)
673
+ self.avgpool = nn.AdaptiveAvgPool2d(1)
674
+ self.head = nn.Linear(num_features, num_classes) if num_classes > 0 else nn.Identity()
675
+ self.apply(self._init_weights)
676
+
677
+ def _init_weights(self, m):
678
+ if isinstance(m, nn.Linear):
679
+ trunc_normal_(m.weight, std=.02)
680
+ if isinstance(m, nn.Linear) and m.bias is not None:
681
+ nn.init.constant_(m.bias, 0)
682
+ elif isinstance(m, nn.LayerNorm):
683
+ nn.init.constant_(m.bias, 0)
684
+ nn.init.constant_(m.weight, 1.0)
685
+ elif isinstance(m, LayerNorm2d):
686
+ nn.init.constant_(m.bias, 0)
687
+ nn.init.constant_(m.weight, 1.0)
688
+ elif isinstance(m, nn.BatchNorm2d):
689
+ nn.init.ones_(m.weight)
690
+ nn.init.zeros_(m.bias)
691
+
692
+ @torch.jit.ignore
693
+ def no_weight_decay_keywords(self):
694
+ return {'rpb'}
695
+
696
+ def forward_features(self, x):
697
+ x = self.patch_embed(x)
698
+ for level in self.levels:
699
+ x = level(x)
700
+ x = self.norm(x)
701
+ x = self.avgpool(x)
702
+ x = torch.flatten(x, 1)
703
+ return x
704
+
705
+ def forward(self, x):
706
+ x = self.forward_features(x)
707
+ x = self.head(x)
708
+ return x
709
+
710
+ def _load_state_dict(self,
711
+ pretrained,
712
+ strict: bool = False):
713
+ _load_checkpoint(self,
714
+ pretrained,
715
+ strict=strict)
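

# Illustrative sketch (not part of the original commit): channels double after each
# of the first three stages, so the final feature width is dim * 2**(len(depths)-1),
# which is also the input width of the classification head. For the Tiny
# configuration below (dim=80, four stages) that is 80 * 8 = 640.
def _num_features_example(dim=80, depths=(1, 3, 8, 4)):
    return int(dim * 2 ** (len(depths) - 1))  # 640 for the Tiny configuration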


@register_model
def mamba_vision_T(pretrained=False, **kwargs):
    model_path = kwargs.pop("model_path", "/tmp/mamba_vision_T.pth.tar")
    pretrained_cfg = resolve_pretrained_cfg('mamba_vision_T').to_dict()
    update_args(pretrained_cfg, kwargs, kwargs_filter=None)
    model = MambaVision(depths=[1, 3, 8, 4],
                        num_heads=[2, 4, 8, 16],
                        window_size=[8, 8, 14, 7],
                        dim=80,
                        in_dim=32,
                        mlp_ratio=4,
                        resolution=224,
                        drop_path_rate=0.2,
                        **kwargs)
    model.pretrained_cfg = pretrained_cfg
    model.default_cfg = model.pretrained_cfg
    if pretrained:
        if not Path(model_path).is_file():
            url = model.default_cfg['url']
            torch.hub.download_url_to_file(url=url, dst=model_path)
        model._load_state_dict(model_path)
    return model


@register_model
def mamba_vision_T2(pretrained=False, **kwargs):
    model_path = kwargs.pop("model_path", "/tmp/mamba_vision_T2.pth.tar")
    pretrained_cfg = resolve_pretrained_cfg('mamba_vision_T2').to_dict()
    update_args(pretrained_cfg, kwargs, kwargs_filter=None)
    model = MambaVision(depths=[1, 3, 11, 4],
                        num_heads=[2, 4, 8, 16],
                        window_size=[8, 8, 14, 7],
                        dim=80,
                        in_dim=32,
                        mlp_ratio=4,
                        resolution=224,
                        drop_path_rate=0.2,
                        **kwargs)
    model.pretrained_cfg = pretrained_cfg
    model.default_cfg = model.pretrained_cfg
    if pretrained:
        if not Path(model_path).is_file():
            url = model.default_cfg['url']
            torch.hub.download_url_to_file(url=url, dst=model_path)
        model._load_state_dict(model_path)
    return model


@register_model
def mamba_vision_S(pretrained=False, **kwargs):
    model_path = kwargs.pop("model_path", "/tmp/mamba_vision_S.pth.tar")
    pretrained_cfg = resolve_pretrained_cfg('mamba_vision_S').to_dict()
    update_args(pretrained_cfg, kwargs, kwargs_filter=None)
    model = MambaVision(depths=[3, 3, 7, 5],
                        num_heads=[2, 4, 8, 16],
                        window_size=[8, 8, 14, 7],
                        dim=96,
                        in_dim=64,
                        mlp_ratio=4,
                        resolution=224,
                        drop_path_rate=0.2,
                        **kwargs)
    model.pretrained_cfg = pretrained_cfg
    model.default_cfg = model.pretrained_cfg
    if pretrained:
        if not Path(model_path).is_file():
            url = model.default_cfg['url']
            torch.hub.download_url_to_file(url=url, dst=model_path)
        model._load_state_dict(model_path)
    return model


@register_model
def mamba_vision_B(pretrained=False, **kwargs):
    model_path = kwargs.pop("model_path", "/tmp/mamba_vision_B.pth.tar")
    pretrained_cfg = resolve_pretrained_cfg('mamba_vision_B').to_dict()
    update_args(pretrained_cfg, kwargs, kwargs_filter=None)
    model = MambaVision(depths=[3, 3, 10, 5],
                        num_heads=[2, 4, 8, 16],
                        window_size=[8, 8, 14, 7],
                        dim=128,
                        in_dim=64,
                        mlp_ratio=4,
                        resolution=224,
                        drop_path_rate=0.3,
                        layer_scale=1e-5,
                        layer_scale_conv=None,
                        **kwargs)
    model.pretrained_cfg = pretrained_cfg
    model.default_cfg = model.pretrained_cfg
    if pretrained:
        if not Path(model_path).is_file():
            url = model.default_cfg['url']
            torch.hub.download_url_to_file(url=url, dst=model_path)
        model._load_state_dict(model_path)
    return model


@register_model
def mamba_vision_L(pretrained=False, **kwargs):
    model_path = kwargs.pop("model_path", "/tmp/mamba_vision_L.pth.tar")
    pretrained_cfg = resolve_pretrained_cfg('mamba_vision_L').to_dict()
    update_args(pretrained_cfg, kwargs, kwargs_filter=None)
    model = MambaVision(depths=[3, 3, 10, 5],
                        num_heads=[4, 8, 16, 32],
                        window_size=[8, 8, 14, 7],
                        dim=196,
                        in_dim=64,
                        mlp_ratio=4,
                        resolution=224,
                        drop_path_rate=0.3,
                        layer_scale=1e-5,
                        layer_scale_conv=None,
                        **kwargs)
    model.pretrained_cfg = pretrained_cfg
    model.default_cfg = model.pretrained_cfg
    if pretrained:
        if not Path(model_path).is_file():
            url = model.default_cfg['url']
            torch.hub.download_url_to_file(url=url, dst=model_path)
        model._load_state_dict(model_path)
    return model


@register_model
def mamba_vision_L2(pretrained=False, **kwargs):
    model_path = kwargs.pop("model_path", "/tmp/mamba_vision_L2.pth.tar")
    pretrained_cfg = resolve_pretrained_cfg('mamba_vision_L2').to_dict()
    update_args(pretrained_cfg, kwargs, kwargs_filter=None)
    model = MambaVision(depths=[3, 3, 12, 5],
                        num_heads=[4, 8, 16, 32],
                        window_size=[8, 8, 14, 7],
                        dim=196,
                        in_dim=64,
                        mlp_ratio=4,
                        resolution=224,
                        drop_path_rate=0.3,
                        layer_scale=1e-5,
                        layer_scale_conv=None,
                        **kwargs)
    model.pretrained_cfg = pretrained_cfg
    model.default_cfg = model.pretrained_cfg
    if pretrained:
        if not Path(model_path).is_file():
            url = model.default_cfg['url']
            torch.hub.download_url_to_file(url=url, dst=model_path)
        model._load_state_dict(model_path)
    return model
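

# Illustrative usage sketch (not part of the original commit): the factory
# functions above are registered with timm, so a model can be built either by
# calling them directly or via timm.create_model. The forward pass needs a CUDA
# device with mamba_ssm installed; passing pretrained=True first downloads the
# checkpoint listed in default_cfgs to model_path.
if __name__ == "__main__":
    model = mamba_vision_T(pretrained=False).cuda().eval()
    images = torch.randn(1, 3, 224, 224, device="cuda")
    with torch.no_grad():
        logits = model(images)
    print(logits.shape)  # torch.Size([1, 1000])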