whyu commited on
Commit
b36de75
1 Parent(s): bee2676

Init commit

Browse files
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ *.DS_Store
app.py ADDED
@@ -0,0 +1,45 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import torch
3
+ import requests
4
+ from PIL import Image
5
+ from timm.data import create_transform
6
+
7
+
8
+ # Prepare the model.
9
+ import models
10
+ model = models.mambaout_femto(pretrained=True) # can change different model name
11
+ model.eval()
12
+
13
+ # Prepare the transform.
14
+ transform = create_transform(input_size=224, crop_pct=model.default_cfg['crop_pct'])
15
+
16
+ # Download human-readable labels for ImageNet.
17
+ response = requests.get("https://git.io/JJkYN")
18
+ labels = response.text.split("\n")
19
+
20
+ def predict(inp):
21
+ inp = transform(inp).unsqueeze(0)
22
+
23
+ with torch.no_grad():
24
+ prediction = torch.nn.functional.softmax(model(inp)[0], dim=0)
25
+ confidences = {labels[i]: float(prediction[i]) for i in range(1000)}
26
+ return confidences
27
+
28
+
29
+ title="MambaOut: Do We Really Need Mamba for Vision?"
30
+ description="Gradio demo for MambaOut model (Femto) proposed by [MambaOut: Do We Really Need Mamba for Vision?](https://arxiv.org/abs/2405.07992). To use it simply upload your image or click on one of the examples to load them. Read more at [arXiv](https://arxiv.org/abs/2405.07992) and [GitHub](https://github.com/yuweihao/MambaOut)."
31
+
32
+
33
+ gr.Interface(title=title,
34
+ description=description,
35
+ fn=predict,
36
+ inputs=gr.Image(type="pil"),
37
+ outputs=gr.Label(num_top_classes=3),
38
+ examples=["images/basketball.jpg", "images/Kobe_coffee.jpg"]).launch()
39
+
40
+ # Basketball image credit: https://www.sportsonline.com.au/products/kobe-bryant-hand-signed-basketball-signed-in-silver
41
+ # Kobe coffee image credit: https://aroundsaddleworth.co.uk/wp-content/uploads/2020/01/DSC_0177-scaled.jpg
42
+
43
+
44
+
45
+
images/Kobe_coffee.jpg ADDED
images/basketball.jpg ADDED
models/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ from .mambaout import *
models/mambaout.py ADDED
@@ -0,0 +1,313 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ MambaOut models for image classification.
3
+ Some implementations are modified from:
4
+ timm (https://github.com/rwightman/pytorch-image-models),
5
+ MetaFormer (https://github.com/sail-sg/metaformer),
6
+ InceptionNeXt (https://github.com/sail-sg/inceptionnext)
7
+ """
8
+ from functools import partial
9
+ import torch
10
+ import torch.nn as nn
11
+ import torch.nn.functional as F
12
+ from timm.models.layers import trunc_normal_, DropPath
13
+ from timm.models.registry import register_model
14
+ from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
15
+
16
+
17
+ def _cfg(url='', **kwargs):
18
+ return {
19
+ 'url': url,
20
+ 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None,
21
+ 'crop_pct': 1.0, 'interpolation': 'bicubic',
22
+ 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD, 'classifier': 'head',
23
+ **kwargs
24
+ }
25
+
26
+
27
+ default_cfgs = {
28
+ 'mambaout_femto': _cfg(
29
+ url='https://github.com/yuweihao/MambaOut/releases/download/model/mambaout_femto.pth'),
30
+ 'mambaout_tiny': _cfg(
31
+ url='https://github.com/yuweihao/MambaOut/releases/download/model/mambaout_tiny.pth'),
32
+ 'mambaout_small': _cfg(
33
+ url='https://github.com/yuweihao/MambaOut/releases/download/model/mambaout_small.pth'),
34
+ 'mambaout_base': _cfg(
35
+ url='https://github.com/yuweihao/MambaOut/releases/download/model/mambaout_base.pth'),
36
+ }
37
+
38
+
39
+ class StemLayer(nn.Module):
40
+ r""" Code modified from InternImage:
41
+ https://github.com/OpenGVLab/InternImage
42
+ """
43
+
44
+ def __init__(self,
45
+ in_channels=3,
46
+ out_channels=96,
47
+ act_layer=nn.GELU,
48
+ norm_layer=partial(nn.LayerNorm, eps=1e-6)):
49
+ super().__init__()
50
+ self.conv1 = nn.Conv2d(in_channels,
51
+ out_channels // 2,
52
+ kernel_size=3,
53
+ stride=2,
54
+ padding=1)
55
+ self.norm1 = norm_layer(out_channels // 2)
56
+ self.act = act_layer()
57
+ self.conv2 = nn.Conv2d(out_channels // 2,
58
+ out_channels,
59
+ kernel_size=3,
60
+ stride=2,
61
+ padding=1)
62
+ self.norm2 = norm_layer(out_channels)
63
+
64
+ def forward(self, x):
65
+ x = self.conv1(x)
66
+ x = x.permute(0, 2, 3, 1)
67
+ x = self.norm1(x)
68
+ x = x.permute(0, 3, 1, 2)
69
+ x = self.act(x)
70
+ x = self.conv2(x)
71
+ x = x.permute(0, 2, 3, 1)
72
+ x = self.norm2(x)
73
+ return x
74
+
75
+
76
+ class DownsampleLayer(nn.Module):
77
+ r""" Code modified from InternImage:
78
+ https://github.com/OpenGVLab/InternImage
79
+ """
80
+ def __init__(self, in_channels=96, out_channels=198, norm_layer=partial(nn.LayerNorm, eps=1e-6)):
81
+ super().__init__()
82
+ self.conv = nn.Conv2d(in_channels,
83
+ out_channels,
84
+ kernel_size=3,
85
+ stride=2,
86
+ padding=1)
87
+ self.norm = norm_layer(out_channels)
88
+
89
+ def forward(self, x):
90
+ x = self.conv(x.permute(0, 3, 1, 2)).permute(0, 2, 3, 1)
91
+ x = self.norm(x)
92
+ return x
93
+
94
+
95
+ class MlpHead(nn.Module):
96
+ """ MLP classification head
97
+ """
98
+ def __init__(self, dim, num_classes=1000, act_layer=nn.GELU, mlp_ratio=4,
99
+ norm_layer=partial(nn.LayerNorm, eps=1e-6), head_dropout=0., bias=True):
100
+ super().__init__()
101
+ hidden_features = int(mlp_ratio * dim)
102
+ self.fc1 = nn.Linear(dim, hidden_features, bias=bias)
103
+ self.act = act_layer()
104
+ self.norm = norm_layer(hidden_features)
105
+ self.fc2 = nn.Linear(hidden_features, num_classes, bias=bias)
106
+ self.head_dropout = nn.Dropout(head_dropout)
107
+
108
+ def forward(self, x):
109
+ x = self.fc1(x)
110
+ x = self.act(x)
111
+ x = self.norm(x)
112
+ x = self.head_dropout(x)
113
+ x = self.fc2(x)
114
+ return x
115
+
116
+
117
+ class GatedCNNBlock(nn.Module):
118
+ r""" Our implementation of Gated CNN Block: https://arxiv.org/pdf/1612.08083
119
+ Args:
120
+ conv_ratio: control the number of channels to conduct depthwise convolution.
121
+ Conduct convolution on partial channels can improve paraitcal efficiency.
122
+ The idea of partical channels is from ShuffleNet V2 (https://arxiv.org/abs/1807.11164) and
123
+ also used by InceptionNeXt (https://arxiv.org/abs/2303.16900) and FasterNet (https://arxiv.org/abs/2303.03667)
124
+ """
125
+ def __init__(self, dim, expension_ratio=8/3, kernel_size=7, conv_ratio=1.0,
126
+ norm_layer=partial(nn.LayerNorm,eps=1e-6),
127
+ act_layer=nn.GELU,
128
+ drop_path=0.,
129
+ **kwargs):
130
+ super().__init__()
131
+ self.norm = norm_layer(dim)
132
+ hidden = int(expension_ratio * dim)
133
+ self.fc1 = nn.Linear(dim, hidden * 2)
134
+ self.act = act_layer()
135
+ conv_channels = int(conv_ratio * dim)
136
+ self.split_indices = (hidden, hidden - conv_channels, conv_channels)
137
+ self.conv = nn.Conv2d(conv_channels, conv_channels, kernel_size=kernel_size, padding=kernel_size//2, groups=conv_channels)
138
+ self.fc2 = nn.Linear(hidden, dim)
139
+ self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
140
+
141
+ def forward(self, x):
142
+ shortcut = x # [B, H, W, C]
143
+ x = self.norm(x)
144
+ g, i, c = torch.split(self.fc1(x), self.split_indices, dim=-1)
145
+ c = c.permute(0, 3, 1, 2) # [B, H, W, C] -> [B, C, H, W]
146
+ c = self.conv(c)
147
+ c = c.permute(0, 2, 3, 1) # [B, C, H, W] -> [B, H, W, C]
148
+ x = self.fc2(self.act(g) * torch.cat((i, c), dim=-1))
149
+ x = self.drop_path(x)
150
+ return x + shortcut
151
+
152
+ r"""
153
+ downsampling (stem) for the first stage is two layer of conv with k3, s2 and p1
154
+ downsamplings for the last 3 stages is a layer of conv with k3, s2 and p1
155
+ DOWNSAMPLE_LAYERS_FOUR_STAGES format: [Downsampling, Downsampling, Downsampling, Downsampling]
156
+ use `partial` to specify some arguments
157
+ """
158
+ DOWNSAMPLE_LAYERS_FOUR_STAGES = [StemLayer] + [DownsampleLayer]*3
159
+
160
+
161
+ class MambaOut(nn.Module):
162
+ r""" MetaFormer
163
+ A PyTorch impl of : `MetaFormer Baselines for Vision` -
164
+ https://arxiv.org/abs/2210.13452
165
+
166
+ Args:
167
+ in_chans (int): Number of input image channels. Default: 3.
168
+ num_classes (int): Number of classes for classification head. Default: 1000.
169
+ depths (list or tuple): Number of blocks at each stage. Default: [3, 3, 9, 3].
170
+ dims (int): Feature dimension at each stage. Default: [96, 192, 384, 576].
171
+ downsample_layers: (list or tuple): Downsampling layers before each stage.
172
+ drop_path_rate (float): Stochastic depth rate. Default: 0.
173
+ output_norm: norm before classifier head. Default: partial(nn.LayerNorm, eps=1e-6).
174
+ head_fn: classification head. Default: nn.Linear.
175
+ head_dropout (float): dropout for MLP classifier. Default: 0.
176
+ """
177
+ def __init__(self, in_chans=3, num_classes=1000,
178
+ depths=[3, 3, 9, 3],
179
+ dims=[96, 192, 384, 576],
180
+ downsample_layers=DOWNSAMPLE_LAYERS_FOUR_STAGES,
181
+ norm_layer=partial(nn.LayerNorm, eps=1e-6),
182
+ act_layer=nn.GELU,
183
+ conv_ratio=1.0,
184
+ kernel_size=7,
185
+ drop_path_rate=0.,
186
+ output_norm=partial(nn.LayerNorm, eps=1e-6),
187
+ head_fn=MlpHead,
188
+ head_dropout=0.0,
189
+ **kwargs,
190
+ ):
191
+ super().__init__()
192
+ self.num_classes = num_classes
193
+
194
+ if not isinstance(depths, (list, tuple)):
195
+ depths = [depths] # it means the model has only one stage
196
+ if not isinstance(dims, (list, tuple)):
197
+ dims = [dims]
198
+
199
+ num_stage = len(depths)
200
+ self.num_stage = num_stage
201
+
202
+ if not isinstance(downsample_layers, (list, tuple)):
203
+ downsample_layers = [downsample_layers] * num_stage
204
+ down_dims = [in_chans] + dims
205
+ self.downsample_layers = nn.ModuleList(
206
+ [downsample_layers[i](down_dims[i], down_dims[i+1]) for i in range(num_stage)]
207
+ )
208
+
209
+ dp_rates=[x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]
210
+
211
+ self.stages = nn.ModuleList()
212
+ cur = 0
213
+ for i in range(num_stage):
214
+ stage = nn.Sequential(
215
+ *[GatedCNNBlock(dim=dims[i],
216
+ norm_layer=norm_layer,
217
+ act_layer=act_layer,
218
+ kernel_size=kernel_size,
219
+ conv_ratio=conv_ratio,
220
+ drop_path=dp_rates[cur + j],
221
+ ) for j in range(depths[i])]
222
+ )
223
+ self.stages.append(stage)
224
+ cur += depths[i]
225
+
226
+ self.norm = output_norm(dims[-1])
227
+
228
+ if head_dropout > 0.0:
229
+ self.head = head_fn(dims[-1], num_classes, head_dropout=head_dropout)
230
+ else:
231
+ self.head = head_fn(dims[-1], num_classes)
232
+
233
+ self.apply(self._init_weights)
234
+
235
+ def _init_weights(self, m):
236
+ if isinstance(m, (nn.Conv2d, nn.Linear)):
237
+ trunc_normal_(m.weight, std=.02)
238
+ if m.bias is not None:
239
+ nn.init.constant_(m.bias, 0)
240
+
241
+ @torch.jit.ignore
242
+ def no_weight_decay(self):
243
+ return {'norm'}
244
+
245
+ def forward_features(self, x):
246
+ for i in range(self.num_stage):
247
+ x = self.downsample_layers[i](x)
248
+ x = self.stages[i](x)
249
+ return self.norm(x.mean([1, 2])) # (B, H, W, C) -> (B, C)
250
+
251
+ def forward(self, x):
252
+ x = self.forward_features(x)
253
+ x = self.head(x)
254
+ return x
255
+
256
+
257
+
258
+ ###############################################################################
259
+ # a series of MambaOut models
260
+ @register_model
261
+ def mambaout_femto(pretrained=False, **kwargs):
262
+ model = MambaOut(
263
+ depths=[3, 3, 9, 3],
264
+ dims=[48, 96, 192, 288],
265
+ **kwargs)
266
+ model.default_cfg = default_cfgs['mambaout_femto']
267
+ if pretrained:
268
+ state_dict = torch.hub.load_state_dict_from_url(
269
+ url= model.default_cfg['url'], map_location="cpu", check_hash=True)
270
+ model.load_state_dict(state_dict)
271
+ return model
272
+
273
+
274
+ @register_model
275
+ def mambaout_tiny(pretrained=False, **kwargs):
276
+ model = MambaOut(
277
+ depths=[3, 3, 9, 3],
278
+ dims=[96, 192, 384, 576],
279
+ **kwargs)
280
+ model.default_cfg = default_cfgs['mambaout_tiny']
281
+ if pretrained:
282
+ state_dict = torch.hub.load_state_dict_from_url(
283
+ url= model.default_cfg['url'], map_location="cpu", check_hash=True)
284
+ model.load_state_dict(state_dict)
285
+ return model
286
+
287
+
288
+ @register_model
289
+ def mambaout_small(pretrained=False, **kwargs):
290
+ model = MambaOut(
291
+ depths=[3, 4, 27, 3],
292
+ dims=[96, 192, 384, 576],
293
+ **kwargs)
294
+ model.default_cfg = default_cfgs['mambaout_small']
295
+ if pretrained:
296
+ state_dict = torch.hub.load_state_dict_from_url(
297
+ url= model.default_cfg['url'], map_location="cpu", check_hash=True)
298
+ model.load_state_dict(state_dict)
299
+ return model
300
+
301
+
302
+ @register_model
303
+ def mambaout_base(pretrained=False, **kwargs):
304
+ model = MambaOut(
305
+ depths=[3, 4, 27, 3],
306
+ dims=[128, 256, 512, 768],
307
+ **kwargs)
308
+ model.default_cfg = default_cfgs['mambaout_base']
309
+ if pretrained:
310
+ state_dict = torch.hub.load_state_dict_from_url(
311
+ url= model.default_cfg['url'], map_location="cpu", check_hash=True)
312
+ model.load_state_dict(state_dict)
313
+ return model
requirements.txt ADDED
@@ -0,0 +1 @@
 
 
1
+ timm==0.6.11