jwyang committed
Commit 7f59780
1 Parent(s): 590ef1e

add application

Files changed (2)
  1. app.py +123 -0
  2. focalnet.py +634 -0
app.py ADDED
@@ -0,0 +1,123 @@
import requests
import gradio as gr
import numpy as np
import cv2
import torch
import torch.nn as nn
from PIL import Image
from torchvision import transforms
from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from timm.data import create_transform
from timm.data.transforms import _pil_interp
from focalnet import FocalNet, build_transforms, build_transforms4display

# Download human-readable labels for ImageNet.
response = requests.get("https://git.io/JJkYN")
labels = response.text.split("\n")

'''
build model
'''
model = FocalNet(depths=[12], patch_size=16, embed_dim=768, focal_levels=[3], use_layerscale=True, use_postln=True)
url = 'https://projects4jw.blob.core.windows.net/focalnet/release/classification/focalnet_base_iso_16.pth'
checkpoint = torch.hub.load_state_dict_from_url(url=url, map_location="cpu", check_hash=True)
model.load_state_dict(checkpoint["model"])
model = model.cuda(); model.eval()

'''
build data transform
'''
eval_transforms = build_transforms(224, center_crop=False)
display_transforms = build_transforms4display(224, center_crop=False)

'''
build upsampler
'''
# upsampler = nn.Upsample(scale_factor=16, mode='bilinear')

'''
borrowed from: https://github.com/jacobgil/pytorch-grad-cam/blob/master/pytorch_grad_cam/utils/image.py
'''
def show_cam_on_image(img: np.ndarray,
                      mask: np.ndarray,
                      use_rgb: bool = False,
                      colormap: int = cv2.COLORMAP_JET) -> np.ndarray:
    """Overlay the cam mask on the image as a heatmap.
    By default the heatmap is in BGR format.
    :param img: The base image in RGB or BGR format.
    :param mask: The cam mask.
    :param use_rgb: Whether to use an RGB or BGR heatmap; set this to True if 'img' is in RGB format.
    :param colormap: The OpenCV colormap to be used.
    :returns: The image with the cam overlay.
    """
    heatmap = cv2.applyColorMap(np.uint8(255 * mask), colormap)
    if use_rgb:
        heatmap = cv2.cvtColor(heatmap, cv2.COLOR_BGR2RGB)
    heatmap = np.float32(heatmap) / 255

    if np.max(img) > 1:
        raise Exception(
            "The input image should be np.float32 in the range [0, 1]")

    cam = 0.7*heatmap + 0.3*img
    # cam = cam / np.max(cam)
    return np.uint8(255 * cam)

def classify_image(inp):

    img_t = eval_transforms(inp)
    img_d = display_transforms(inp).permute(1, 2, 0).cpu().numpy()
    print(img_d.min(), img_d.max())

    prediction = model(img_t.unsqueeze(0).cuda()).softmax(-1).flatten()

    # visualize the modulator magnitude (L2 norm over channels) at blocks 3, 6, 9, 12
    modulator = model.layers[0].blocks[2].modulation.modulator.norm(2, 1, keepdim=True)
    modulator = nn.Upsample(size=img_t.shape[1:], mode='bilinear')(modulator)
    modulator = modulator.squeeze(1).detach().permute(1, 2, 0).cpu().numpy()
    modulator = (modulator - modulator.min()) / (modulator.max() - modulator.min())
    cam0 = show_cam_on_image(img_d, modulator, use_rgb=True)

    modulator = model.layers[0].blocks[5].modulation.modulator.norm(2, 1, keepdim=True)
    modulator = nn.Upsample(size=img_t.shape[1:], mode='bilinear')(modulator)
    modulator = modulator.squeeze(1).detach().permute(1, 2, 0).cpu().numpy()
    modulator = (modulator - modulator.min()) / (modulator.max() - modulator.min())
    cam1 = show_cam_on_image(img_d, modulator, use_rgb=True)

    modulator = model.layers[0].blocks[8].modulation.modulator.norm(2, 1, keepdim=True)
    modulator = nn.Upsample(size=img_t.shape[1:], mode='bilinear')(modulator)
    modulator = modulator.squeeze(1).detach().permute(1, 2, 0).cpu().numpy()
    modulator = (modulator - modulator.min()) / (modulator.max() - modulator.min())
    cam2 = show_cam_on_image(img_d, modulator, use_rgb=True)

    modulator = model.layers[0].blocks[11].modulation.modulator.norm(2, 1, keepdim=True)
    modulator = nn.Upsample(size=img_t.shape[1:], mode='bilinear')(modulator)
    modulator = modulator.squeeze(1).detach().permute(1, 2, 0).cpu().numpy()
    modulator = (modulator - modulator.min()) / (modulator.max() - modulator.min())
    cam3 = show_cam_on_image(img_d, modulator, use_rgb=True)

    return Image.fromarray(cam0), Image.fromarray(cam1), Image.fromarray(cam2), Image.fromarray(cam3), {labels[i]: float(prediction[i]) for i in range(1000)}


image = gr.inputs.Image()
label = gr.outputs.Label(num_top_classes=3)

gr.Interface(
    fn=classify_image,
    inputs=image,
    outputs=[
        gr.outputs.Image(
            type="pil",
            label="Modulator at layer 3"),
        gr.outputs.Image(
            type="pil",
            label="Modulator at layer 6"),
        gr.outputs.Image(
            type="pil",
            label="Modulator at layer 9"),
        gr.outputs.Image(
            type="pil",
            label="Modulator at layer 12"),
        label,
    ],
    # examples=[["images/aiko.jpg"], ["images/pencils.jpg"], ["images/donut.png"]],
).launch()
focalnet.py ADDED
@@ -0,0 +1,634 @@
# --------------------------------------------------------
# FocalNets -- Focal Modulation Networks
# Copyright (c) 2022 Microsoft
# Licensed under The MIT License [see LICENSE for details]
# Written by Jianwei Yang (jianwyan@microsoft.com)
# --------------------------------------------------------

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint as checkpoint
from timm.models.layers import DropPath, to_2tuple, trunc_normal_
from timm.models.registry import register_model

from torchvision import transforms
from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from timm.data import create_transform
from timm.data.transforms import _pil_interp

class Mlp(nn.Module):
    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop(x)
        x = self.fc2(x)
        x = self.drop(x)
        return x

class FocalModulation(nn.Module):
    def __init__(self, dim, focal_window, focal_level, focal_factor=2, bias=True, proj_drop=0., use_postln=False):
        super().__init__()

        self.dim = dim
        self.focal_window = focal_window
        self.focal_level = focal_level
        self.focal_factor = focal_factor
        self.use_postln = use_postln

        self.f = nn.Linear(dim, 2*dim + (self.focal_level+1), bias=bias)
        self.h = nn.Conv2d(dim, dim, kernel_size=1, stride=1, bias=bias)

        self.act = nn.GELU()
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)
        self.focal_layers = nn.ModuleList()

        self.kernel_sizes = []
        for k in range(self.focal_level):
            kernel_size = self.focal_factor*k + self.focal_window
            self.focal_layers.append(
                nn.Sequential(
                    nn.Conv2d(dim, dim, kernel_size=kernel_size, stride=1,
                              groups=dim, padding=kernel_size//2, bias=False),
                    nn.GELU(),
                )
            )
            self.kernel_sizes.append(kernel_size)
        if self.use_postln:
            self.ln = nn.LayerNorm(dim)

    def forward(self, x):
        """
        Args:
            x: input features with shape of (B, H, W, C)
        """
        C = x.shape[-1]

        # pre linear projection
        x = self.f(x).permute(0, 3, 1, 2).contiguous()
        q, ctx, self.gates = torch.split(x, (C, C, self.focal_level+1), 1)

        # context aggregation
        ctx_all = 0
        for l in range(self.focal_level):
            ctx = self.focal_layers[l](ctx)
            ctx_all = ctx_all + ctx*self.gates[:, l:l+1]
        ctx_global = self.act(ctx.mean(2, keepdim=True).mean(3, keepdim=True))
        ctx_all = ctx_all + ctx_global*self.gates[:,self.focal_level:]

        # focal modulation
        self.modulator = self.h(ctx_all)
        x_out = q*self.modulator
        x_out = x_out.permute(0, 2, 3, 1).contiguous()
        if self.use_postln:
            x_out = self.ln(x_out)

        # post linear projection
        x_out = self.proj(x_out)
        x_out = self.proj_drop(x_out)
        return x_out

    def extra_repr(self) -> str:
        return f'dim={self.dim}'

    def flops(self, N):
        # calculate flops for an input with token length of N
        flops = 0

        flops += N * self.dim * (self.dim * 2 + (self.focal_level+1))

        # focal convolution
        for k in range(self.focal_level):
            flops += N * (self.kernel_sizes[k]**2+1) * self.dim

        # global gating
        flops += N * 1 * self.dim

        # self.linear
        flops += N * self.dim * (self.dim + 1)

        # x = self.proj(x)
        flops += N * self.dim * self.dim
        return flops

class FocalNetBlock(nn.Module):
    r""" Focal Modulation Network Block.

    Args:
        dim (int): Number of input channels.
        input_resolution (tuple[int]): Input resolution.
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
        drop (float, optional): Dropout rate. Default: 0.0
        drop_path (float, optional): Stochastic depth rate. Default: 0.0
        act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
        norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
        focal_level (int): Number of focal levels.
        focal_window (int): Focal window size at the first focal level.
        use_layerscale (bool): Whether to use layerscale.
        layerscale_value (float): Initial layerscale value.
        use_postln (bool): Whether to use layernorm after modulation.
    """

    def __init__(self, dim, input_resolution, mlp_ratio=4., drop=0., drop_path=0.,
                 act_layer=nn.GELU, norm_layer=nn.LayerNorm,
                 focal_level=1, focal_window=3,
                 use_layerscale=False, layerscale_value=1e-4,
                 use_postln=False):
        super().__init__()
        self.dim = dim
        self.input_resolution = input_resolution
        self.mlp_ratio = mlp_ratio

        self.focal_window = focal_window
        self.focal_level = focal_level

        self.norm1 = norm_layer(dim)
        self.modulation = FocalModulation(dim, proj_drop=drop, focal_window=focal_window, focal_level=self.focal_level, use_postln=use_postln)

        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)

        self.gamma_1 = 1.0
        self.gamma_2 = 1.0
        if use_layerscale:
            self.gamma_1 = nn.Parameter(layerscale_value * torch.ones((dim)), requires_grad=True)
            self.gamma_2 = nn.Parameter(layerscale_value * torch.ones((dim)), requires_grad=True)

        self.H = None
        self.W = None

    def forward(self, x):
        H, W = self.H, self.W
        B, L, C = x.shape
        shortcut = x

        # Focal Modulation
        x = self.norm1(x)
        x = x.view(B, H, W, C)
        x = self.modulation(x).view(B, H * W, C)

        # FFN
        x = shortcut + self.drop_path(self.gamma_1 * x)
        x = x + self.drop_path(self.gamma_2 * self.mlp(self.norm2(x)))

        return x

    def extra_repr(self) -> str:
        return f"dim={self.dim}, input_resolution={self.input_resolution}, " \
               f"mlp_ratio={self.mlp_ratio}"

    def flops(self):
        flops = 0
        H, W = self.input_resolution
        # norm1
        flops += self.dim * H * W

        # focal modulation
        flops += self.modulation.flops(H*W)

        # mlp
        flops += 2 * H * W * self.dim * self.dim * self.mlp_ratio
        # norm2
        flops += self.dim * H * W
        return flops

class BasicLayer(nn.Module):
    """ A basic FocalNet layer for one stage.

    Args:
        dim (int): Number of input channels.
        input_resolution (tuple[int]): Input resolution.
        depth (int): Number of blocks.
        window_size (int): Local window size.
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
        drop (float, optional): Dropout rate. Default: 0.0
        drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
        norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
        downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
        use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
        focal_level (int): Number of focal levels.
        focal_window (int): Focal window size at the first focal level.
        use_layerscale (bool): Whether to use layerscale.
        layerscale_value (float): Initial layerscale value.
        use_postln (bool): Whether to use layernorm after modulation.
    """

    def __init__(self, dim, out_dim, input_resolution, depth,
                 mlp_ratio=4., drop=0., drop_path=0., norm_layer=nn.LayerNorm,
                 downsample=None, use_checkpoint=False,
                 focal_level=1, focal_window=1,
                 use_conv_embed=False,
                 use_layerscale=False, layerscale_value=1e-4, use_postln=False):

        super().__init__()
        self.dim = dim
        self.input_resolution = input_resolution
        self.depth = depth
        self.use_checkpoint = use_checkpoint

        # build blocks
        self.blocks = nn.ModuleList([
            FocalNetBlock(
                dim=dim,
                input_resolution=input_resolution,
                mlp_ratio=mlp_ratio,
                drop=drop,
                drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
                norm_layer=norm_layer,
                focal_level=focal_level,
                focal_window=focal_window,
                use_layerscale=use_layerscale,
                layerscale_value=layerscale_value,
                use_postln=use_postln,
            )
            for i in range(depth)])

        if downsample is not None:
            self.downsample = downsample(
                img_size=input_resolution,
                patch_size=2,
                in_chans=dim,
                embed_dim=out_dim,
                use_conv_embed=use_conv_embed,
                norm_layer=norm_layer,
                is_stem=False
            )
        else:
            self.downsample = None

    def forward(self, x, H, W):
        for blk in self.blocks:
            blk.H, blk.W = H, W
            if self.use_checkpoint:
                x = checkpoint.checkpoint(blk, x)
            else:
                x = blk(x)

        if self.downsample is not None:
            x = x.transpose(1, 2).reshape(x.shape[0], -1, H, W)
            x, Ho, Wo = self.downsample(x)
        else:
            Ho, Wo = H, W
        return x, Ho, Wo

    def extra_repr(self) -> str:
        return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}"

    def flops(self):
        flops = 0
        for blk in self.blocks:
            flops += blk.flops()
        if self.downsample is not None:
            flops += self.downsample.flops()
        return flops

class PatchEmbed(nn.Module):
    r""" Image to Patch Embedding

    Args:
        img_size (int): Image size. Default: 224.
        patch_size (int): Patch token size. Default: 4.
        in_chans (int): Number of input image channels. Default: 3.
        embed_dim (int): Number of linear projection output channels. Default: 96.
        norm_layer (nn.Module, optional): Normalization layer. Default: None
    """

    def __init__(self, img_size=(224, 224), patch_size=4, in_chans=3, embed_dim=96, use_conv_embed=False, norm_layer=None, is_stem=False):
        super().__init__()
        patch_size = to_2tuple(patch_size)
        patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]]
        self.img_size = img_size
        self.patch_size = patch_size
        self.patches_resolution = patches_resolution
        self.num_patches = patches_resolution[0] * patches_resolution[1]

        self.in_chans = in_chans
        self.embed_dim = embed_dim

        if use_conv_embed:
            # if we choose to use conv embedding, then we treat the stem and non-stem differently
            if is_stem:
                kernel_size = 7; padding = 2; stride = 4
            else:
                kernel_size = 3; padding = 1; stride = 2
            self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=kernel_size, stride=stride, padding=padding)
        else:
            self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)

        if norm_layer is not None:
            self.norm = norm_layer(embed_dim)
        else:
            self.norm = None

    def forward(self, x):
        B, C, H, W = x.shape

        x = self.proj(x)
        H, W = x.shape[2:]
        x = x.flatten(2).transpose(1, 2)  # B Ph*Pw C
        if self.norm is not None:
            x = self.norm(x)
        return x, H, W

    def flops(self):
        Ho, Wo = self.patches_resolution
        flops = Ho * Wo * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1])
        if self.norm is not None:
            flops += Ho * Wo * self.embed_dim
        return flops

class FocalNet(nn.Module):
    r""" Focal Modulation Networks (FocalNets)

    Args:
        img_size (int | tuple(int)): Input image size. Default 224
        patch_size (int | tuple(int)): Patch size. Default: 4
        in_chans (int): Number of input image channels. Default: 3
        num_classes (int): Number of classes for classification head. Default: 1000
        embed_dim (int): Patch embedding dimension. Default: 96
        depths (tuple(int)): Depth of each FocalNet layer.
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4
        drop_rate (float): Dropout rate. Default: 0
        drop_path_rate (float): Stochastic depth rate. Default: 0.1
        norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm.
        patch_norm (bool): If True, add normalization after patch embedding. Default: True
        use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False
        focal_levels (list): How many focal levels at all stages. Note that this excludes the finest-grain level. Default: [1, 1, 1, 1]
        focal_windows (list): The focal window size at all stages. Default: [7, 5, 3, 1]
        use_conv_embed (bool): Whether to use convolutional embedding. We noted that using convolutional embedding usually improves performance, but we do not use it by default. Default: False
        use_layerscale (bool): Whether to use layerscale proposed in CaiT. Default: False
        layerscale_value (float): Value for layer scale. Default: 1e-4
        use_postln (bool): Whether to use layernorm after modulation (it helps stabilize training of large models).
    """
    def __init__(self,
                 img_size=224,
                 patch_size=4,
                 in_chans=3,
                 num_classes=1000,
                 embed_dim=96,
                 depths=[2, 2, 6, 2],
                 mlp_ratio=4.,
                 drop_rate=0.,
                 drop_path_rate=0.1,
                 norm_layer=nn.LayerNorm,
                 patch_norm=True,
                 use_checkpoint=False,
                 focal_levels=[2, 2, 2, 2],
                 focal_windows=[3, 3, 3, 3],
                 use_conv_embed=False,
                 use_layerscale=False,
                 layerscale_value=1e-4,
                 use_postln=False,
                 **kwargs):
        super().__init__()

        self.num_layers = len(depths)
        embed_dim = [embed_dim * (2 ** i) for i in range(self.num_layers)]

        self.num_classes = num_classes
        self.embed_dim = embed_dim
        self.patch_norm = patch_norm
        self.num_features = embed_dim[-1]
        self.mlp_ratio = mlp_ratio

        # split image into patches using either non-overlapped embedding or overlapped embedding
        self.patch_embed = PatchEmbed(
            img_size=to_2tuple(img_size),
            patch_size=patch_size,
            in_chans=in_chans,
            embed_dim=embed_dim[0],
            use_conv_embed=use_conv_embed,
            norm_layer=norm_layer if self.patch_norm else None,
            is_stem=True)

        num_patches = self.patch_embed.num_patches
        patches_resolution = self.patch_embed.patches_resolution
        self.patches_resolution = patches_resolution
        self.pos_drop = nn.Dropout(p=drop_rate)

        # stochastic depth
        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]  # stochastic depth decay rule

        # build layers
        self.layers = nn.ModuleList()
        for i_layer in range(self.num_layers):
            layer = BasicLayer(dim=embed_dim[i_layer],
                               out_dim=embed_dim[i_layer+1] if (i_layer < self.num_layers - 1) else None,
                               input_resolution=(patches_resolution[0] // (2 ** i_layer),
                                                 patches_resolution[1] // (2 ** i_layer)),
                               depth=depths[i_layer],
                               mlp_ratio=self.mlp_ratio,
                               drop=drop_rate,
                               drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])],
                               norm_layer=norm_layer,
                               downsample=PatchEmbed if (i_layer < self.num_layers - 1) else None,
                               focal_level=focal_levels[i_layer],
                               focal_window=focal_windows[i_layer],
                               use_conv_embed=use_conv_embed,
                               use_checkpoint=use_checkpoint,
                               use_layerscale=use_layerscale,
                               layerscale_value=layerscale_value,
                               use_postln=use_postln,
                               )
            self.layers.append(layer)

        self.norm = norm_layer(self.num_features)
        self.avgpool = nn.AdaptiveAvgPool1d(1)
        self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()

        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    @torch.jit.ignore
    def no_weight_decay(self):
        return {''}

    @torch.jit.ignore
    def no_weight_decay_keywords(self):
        return {''}

    def forward_features(self, x):
        x, H, W = self.patch_embed(x)
        x = self.pos_drop(x)

        for layer in self.layers:
            x, H, W = layer(x, H, W)
        x = self.norm(x)  # B L C
        x = self.avgpool(x.transpose(1, 2))  # B C 1
        x = torch.flatten(x, 1)
        return x

    def forward(self, x):
        x = self.forward_features(x)
        x = self.head(x)
        return x

    def flops(self):
        flops = 0
        flops += self.patch_embed.flops()
        for i, layer in enumerate(self.layers):
            flops += layer.flops()
        flops += self.num_features * self.patches_resolution[0] * self.patches_resolution[1] // (2 ** self.num_layers)
        flops += self.num_features * self.num_classes
        return flops

def build_transforms(img_size, center_crop=False):
    t = [transforms.ToPILImage()]
    if center_crop:
        size = int((256 / 224) * img_size)
        t.append(
            transforms.Resize(size, interpolation=_pil_interp('bicubic'))
        )
        t.append(
            transforms.CenterCrop(img_size)
        )
    else:
        t.append(
            transforms.Resize(img_size, interpolation=_pil_interp('bicubic'))
        )
    t.append(transforms.ToTensor())
    t.append(transforms.Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD))
    return transforms.Compose(t)

def build_transforms4display(img_size, center_crop=False):
    t = [transforms.ToPILImage()]
    if center_crop:
        size = int((256 / 224) * img_size)
        t.append(
            transforms.Resize(size, interpolation=_pil_interp('bicubic'))
        )
        t.append(
            transforms.CenterCrop(img_size)
        )
    else:
        t.append(
            transforms.Resize(img_size, interpolation=_pil_interp('bicubic'))
        )
    t.append(transforms.ToTensor())
    return transforms.Compose(t)

model_urls = {
    "focalnet_tiny_srf": "",
    "focalnet_small_srf": "",
    "focalnet_base_srf": "",
    "focalnet_tiny_lrf": "",
    "focalnet_small_lrf": "",
    "focalnet_base_lrf": "",
}

@register_model
def focalnet_tiny_srf(pretrained=False, **kwargs):
    model = FocalNet(depths=[2, 2, 6, 2], embed_dim=96, **kwargs)
    if pretrained:
        url = model_urls['focalnet_tiny_srf']
        checkpoint = torch.hub.load_state_dict_from_url(url=url, map_location="cpu", check_hash=True)
        model.load_state_dict(checkpoint["model"])
    return model

@register_model
def focalnet_small_srf(pretrained=False, **kwargs):
    model = FocalNet(depths=[2, 2, 18, 2], embed_dim=96, **kwargs)
    if pretrained:
        url = model_urls['focalnet_small_srf']
        checkpoint = torch.hub.load_state_dict_from_url(url=url, map_location="cpu")
        model.load_state_dict(checkpoint["model"])
    return model

@register_model
def focalnet_base_srf(pretrained=False, **kwargs):
    model = FocalNet(depths=[2, 2, 18, 2], embed_dim=128, **kwargs)
    if pretrained:
        url = model_urls['focalnet_base_srf']
        checkpoint = torch.hub.load_state_dict_from_url(url=url, map_location="cpu")
        model.load_state_dict(checkpoint["model"])
    return model

@register_model
def focalnet_tiny_lrf(pretrained=False, **kwargs):
    model = FocalNet(depths=[2, 2, 6, 2], embed_dim=96, focal_levels=[3, 3, 3, 3], **kwargs)
    if pretrained:
        url = model_urls['focalnet_tiny_lrf']
        checkpoint = torch.hub.load_state_dict_from_url(url=url, map_location="cpu", check_hash=True)
        model.load_state_dict(checkpoint["model"])
    return model

@register_model
def focalnet_small_lrf(pretrained=False, **kwargs):
    model = FocalNet(depths=[2, 2, 18, 2], embed_dim=96, focal_levels=[3, 3, 3, 3], **kwargs)
    if pretrained:
        url = model_urls['focalnet_small_lrf']
        checkpoint = torch.hub.load_state_dict_from_url(url=url, map_location="cpu")
        model.load_state_dict(checkpoint["model"])
    return model

@register_model
def focalnet_base_lrf(pretrained=False, **kwargs):
    model = FocalNet(depths=[2, 2, 18, 2], embed_dim=128, focal_levels=[3, 3, 3, 3], **kwargs)
    if pretrained:
        url = model_urls['focalnet_base_lrf']
        checkpoint = torch.hub.load_state_dict_from_url(url=url, map_location="cpu")
        model.load_state_dict(checkpoint["model"])
    return model

@register_model
def focalnet_tiny_iso_16(pretrained=False, **kwargs):
    model = FocalNet(depths=[12], patch_size=16, embed_dim=192, focal_levels=[3], focal_windows=[3], **kwargs)
    if pretrained:
        url = model_urls['focalnet_tiny_iso_16']
        checkpoint = torch.hub.load_state_dict_from_url(url=url, map_location="cpu", check_hash=True)
        model.load_state_dict(checkpoint["model"])
    return model

@register_model
def focalnet_small_iso_16(pretrained=False, **kwargs):
    model = FocalNet(depths=[12], patch_size=16, embed_dim=384, focal_levels=[3], focal_windows=[3], **kwargs)
    if pretrained:
        url = model_urls['focalnet_small_iso_16']
        checkpoint = torch.hub.load_state_dict_from_url(url=url, map_location="cpu")
        model.load_state_dict(checkpoint["model"])
    return model

@register_model
def focalnet_base_iso_16(pretrained=False, **kwargs):
    model = FocalNet(depths=[12], patch_size=16, embed_dim=768, focal_levels=[3], focal_windows=[3], use_layerscale=True, use_postln=True, **kwargs)
    if pretrained:
        url = model_urls['focalnet_base_iso_16']
        checkpoint = torch.hub.load_state_dict_from_url(url=url, map_location="cpu")
        model.load_state_dict(checkpoint["model"])
    return model

if __name__ == '__main__':
    img_size = 224
    x = torch.rand(16, 3, img_size, img_size).cuda()
    # model = FocalNet(depths=[2, 2, 6, 2], embed_dim=96)
    # model = FocalNet(depths=[12], patch_size=16, embed_dim=768, focal_levels=[3], focal_windows=[3], focal_factors=[2])
    model = FocalNet(depths=[2, 2, 6, 2], embed_dim=96, focal_levels=[3, 3, 3, 3]).cuda()
    print(model); model(x)

    flops = model.flops()
    print(f"number of GFLOPs: {flops / 1e9}")

    n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"number of params: {n_parameters}")