52Hz commited on
Commit
15c9ff5
·
1 Parent(s): 6e75a05

Upload model/SUNet_detail.py

Browse files
Files changed (1) hide show
  1. model/SUNet_detail.py +788 -0
model/SUNet_detail.py ADDED
@@ -0,0 +1,788 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.utils.checkpoint as checkpoint
4
+ from einops import rearrange
5
+ from timm.models.layers import DropPath, to_2tuple, trunc_normal_
6
+ from thop import profile
7
+
8
+ class Mlp(nn.Module):
9
+ def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
10
+ super().__init__()
11
+ out_features = out_features or in_features
12
+ hidden_features = hidden_features or in_features
13
+ self.fc1 = nn.Linear(in_features, hidden_features)
14
+ self.act = act_layer()
15
+ self.fc2 = nn.Linear(hidden_features, out_features)
16
+ self.drop = nn.Dropout(drop)
17
+
18
+ def forward(self, x):
19
+ x = self.fc1(x)
20
+ x = self.act(x)
21
+ x = self.drop(x)
22
+ x = self.fc2(x)
23
+ x = self.drop(x)
24
+ return x
25
+
26
+
27
+ def window_partition(x, window_size):
28
+ """
29
+ Args:
30
+ x: (B, H, W, C)
31
+ window_size (int): window size
32
+
33
+ Returns:
34
+ windows: (num_windows*B, window_size, window_size, C)
35
+ """
36
+ B, H, W, C = x.shape
37
+ x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
38
+ windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
39
+ return windows
40
+
41
+
42
+ def window_reverse(windows, window_size, H, W):
43
+ """
44
+ Args:
45
+ windows: (num_windows*B, window_size, window_size, C)
46
+ window_size (int): Window size
47
+ H (int): Height of image
48
+ W (int): Width of image
49
+
50
+ Returns:
51
+ x: (B, H, W, C)
52
+ """
53
+ B = int(windows.shape[0] / (H * W / window_size / window_size))
54
+ x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
55
+ x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
56
+ return x
57
+
58
+
59
+ class WindowAttention(nn.Module):
60
+ r""" Window based multi-head self attention (W-MSA) module with relative position bias.
61
+ It supports both of shifted and non-shifted window.
62
+
63
+ Args:
64
+ dim (int): Number of input channels.
65
+ window_size (tuple[int]): The height and width of the window.
66
+ num_heads (int): Number of attention heads.
67
+ qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
68
+ qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set
69
+ attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
70
+ proj_drop (float, optional): Dropout ratio of output. Default: 0.0
71
+ """
72
+
73
+ def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.):
74
+
75
+ super().__init__()
76
+ self.dim = dim
77
+ self.window_size = window_size # Wh, Ww
78
+ self.num_heads = num_heads
79
+ head_dim = dim // num_heads
80
+ self.scale = qk_scale or head_dim ** -0.5
81
+
82
+ # define a parameter table of relative position bias
83
+ self.relative_position_bias_table = nn.Parameter(
84
+ torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)) # 2*Wh-1 * 2*Ww-1, nH
85
+
86
+ # get pair-wise relative position index for each token inside the window
87
+ coords_h = torch.arange(self.window_size[0])
88
+ coords_w = torch.arange(self.window_size[1])
89
+ coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww
90
+ coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
91
+ relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww
92
+ relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2
93
+ relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0
94
+ relative_coords[:, :, 1] += self.window_size[1] - 1
95
+ relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
96
+ relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
97
+ self.register_buffer("relative_position_index", relative_position_index)
98
+
99
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
100
+ self.attn_drop = nn.Dropout(attn_drop)
101
+ self.proj = nn.Linear(dim, dim)
102
+ self.proj_drop = nn.Dropout(proj_drop)
103
+
104
+ trunc_normal_(self.relative_position_bias_table, std=.02)
105
+ self.softmax = nn.Softmax(dim=-1)
106
+
107
+ def forward(self, x, mask=None):
108
+ """
109
+ Args:
110
+ x: input features with shape of (num_windows*B, N, C)
111
+ mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None
112
+ """
113
+ B_, N, C = x.shape
114
+ qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
115
+ q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple)
116
+
117
+ q = q * self.scale
118
+ attn = (q @ k.transpose(-2, -1))
119
+
120
+ relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
121
+ self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1) # Wh*Ww,Wh*Ww,nH
122
+ relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww
123
+ attn = attn + relative_position_bias.unsqueeze(0)
124
+
125
+ if mask is not None:
126
+ nW = mask.shape[0]
127
+ attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
128
+ attn = attn.view(-1, self.num_heads, N, N)
129
+ attn = self.softmax(attn)
130
+ else:
131
+ attn = self.softmax(attn)
132
+
133
+ attn = self.attn_drop(attn)
134
+
135
+ x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
136
+ x = self.proj(x)
137
+ x = self.proj_drop(x)
138
+ return x
139
+
140
+ def extra_repr(self) -> str:
141
+ return f'dim={self.dim}, window_size={self.window_size}, num_heads={self.num_heads}'
142
+
143
+ def flops(self, N):
144
+ # calculate flops for 1 window with token length of N
145
+ flops = 0
146
+ # qkv = self.qkv(x)
147
+ flops += N * self.dim * 3 * self.dim
148
+ # attn = (q @ k.transpose(-2, -1))
149
+ flops += self.num_heads * N * (self.dim // self.num_heads) * N
150
+ # x = (attn @ v)
151
+ flops += self.num_heads * N * N * (self.dim // self.num_heads)
152
+ # x = self.proj(x)
153
+ flops += N * self.dim * self.dim
154
+ return flops
155
+
156
+
157
+ class SwinTransformerBlock(nn.Module):
158
+ r""" Swin Transformer Block.
159
+
160
+ Args:
161
+ dim (int): Number of input channels.
162
+ input_resolution (tuple[int]): Input resulotion.
163
+ num_heads (int): Number of attention heads.
164
+ window_size (int): Window size.
165
+ shift_size (int): Shift size for SW-MSA.
166
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
167
+ qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
168
+ qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
169
+ drop (float, optional): Dropout rate. Default: 0.0
170
+ attn_drop (float, optional): Attention dropout rate. Default: 0.0
171
+ drop_path (float, optional): Stochastic depth rate. Default: 0.0
172
+ act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
173
+ norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
174
+ """
175
+
176
+ def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0,
177
+ mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0.,
178
+ act_layer=nn.GELU, norm_layer=nn.LayerNorm):
179
+ super().__init__()
180
+ self.dim = dim
181
+ self.input_resolution = input_resolution
182
+ self.num_heads = num_heads
183
+ self.window_size = window_size
184
+ self.shift_size = shift_size
185
+ self.mlp_ratio = mlp_ratio
186
+ if min(self.input_resolution) <= self.window_size:
187
+ # if window size is larger than input resolution, we don't partition windows
188
+ self.shift_size = 0
189
+ self.window_size = min(self.input_resolution)
190
+ assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size"
191
+
192
+ self.norm1 = norm_layer(dim)
193
+ self.attn = WindowAttention(
194
+ dim, window_size=to_2tuple(self.window_size), num_heads=num_heads,
195
+ qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
196
+
197
+ self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
198
+ self.norm2 = norm_layer(dim)
199
+ mlp_hidden_dim = int(dim * mlp_ratio)
200
+ self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
201
+
202
+ if self.shift_size > 0:
203
+ # calculate attention mask for SW-MSA
204
+ H, W = self.input_resolution
205
+ img_mask = torch.zeros((1, H, W, 1)) # 1 H W 1
206
+ h_slices = (slice(0, -self.window_size),
207
+ slice(-self.window_size, -self.shift_size),
208
+ slice(-self.shift_size, None))
209
+ w_slices = (slice(0, -self.window_size),
210
+ slice(-self.window_size, -self.shift_size),
211
+ slice(-self.shift_size, None))
212
+ cnt = 0
213
+ for h in h_slices:
214
+ for w in w_slices:
215
+ img_mask[:, h, w, :] = cnt
216
+ cnt += 1
217
+
218
+ mask_windows = window_partition(img_mask, self.window_size) # nW, window_size, window_size, 1
219
+ mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
220
+ attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
221
+ attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
222
+ else:
223
+ attn_mask = None
224
+
225
+ self.register_buffer("attn_mask", attn_mask)
226
+
227
+ def forward(self, x):
228
+ H, W = self.input_resolution
229
+ B, L, C = x.shape
230
+ # assert L == H * W, "input feature has wrong size"
231
+
232
+ shortcut = x
233
+ x = self.norm1(x)
234
+ x = x.view(B, H, W, C)
235
+
236
+ # cyclic shift
237
+ if self.shift_size > 0:
238
+ shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
239
+ else:
240
+ shifted_x = x
241
+
242
+ # partition windows
243
+ x_windows = window_partition(shifted_x, self.window_size) # nW*B, window_size, window_size, C
244
+ x_windows = x_windows.view(-1, self.window_size * self.window_size, C) # nW*B, window_size*window_size, C
245
+
246
+ # W-MSA/SW-MSA
247
+ attn_windows = self.attn(x_windows, mask=self.attn_mask) # nW*B, window_size*window_size, C
248
+
249
+ # merge windows
250
+ attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)
251
+ shifted_x = window_reverse(attn_windows, self.window_size, H, W) # B H' W' C
252
+
253
+ # reverse cyclic shift
254
+ if self.shift_size > 0:
255
+ x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
256
+ else:
257
+ x = shifted_x
258
+ x = x.view(B, H * W, C)
259
+
260
+ # FFN
261
+ x = shortcut + self.drop_path(x)
262
+ x = x + self.drop_path(self.mlp(self.norm2(x)))
263
+
264
+ return x
265
+
266
+ def extra_repr(self) -> str:
267
+ return f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, " \
268
+ f"window_size={self.window_size}, shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio}"
269
+
270
+ def flops(self):
271
+ flops = 0
272
+ H, W = self.input_resolution
273
+ # norm1
274
+ flops += self.dim * H * W
275
+ # W-MSA/SW-MSA
276
+ nW = H * W / self.window_size / self.window_size
277
+ flops += nW * self.attn.flops(self.window_size * self.window_size)
278
+ # mlp
279
+ flops += 2 * H * W * self.dim * self.dim * self.mlp_ratio
280
+ # norm2
281
+ flops += self.dim * H * W
282
+ return flops
283
+
284
+
285
+ class PatchMerging(nn.Module):
286
+ r""" Patch Merging Layer.
287
+
288
+ Args:
289
+ input_resolution (tuple[int]): Resolution of input feature.
290
+ dim (int): Number of input channels.
291
+ norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
292
+ """
293
+
294
+ def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm):
295
+ super().__init__()
296
+ self.input_resolution = input_resolution
297
+ self.dim = dim
298
+ self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
299
+ self.norm = norm_layer(4 * dim)
300
+
301
+ def forward(self, x):
302
+ """
303
+ x: B, H*W, C
304
+ """
305
+ H, W = self.input_resolution
306
+ B, L, C = x.shape
307
+ assert L == H * W, "input feature has wrong size"
308
+ assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) are not even."
309
+
310
+ x = x.view(B, H, W, C)
311
+
312
+ x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C
313
+ x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C
314
+ x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C
315
+ x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C
316
+ x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C
317
+ x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C
318
+
319
+ x = self.norm(x)
320
+ x = self.reduction(x)
321
+
322
+ return x
323
+
324
+ def extra_repr(self) -> str:
325
+ return f"input_resolution={self.input_resolution}, dim={self.dim}"
326
+
327
+ def flops(self):
328
+ H, W = self.input_resolution
329
+ flops = H * W * self.dim
330
+ flops += (H // 2) * (W // 2) * 4 * self.dim * 2 * self.dim
331
+ return flops
332
+
333
+
334
+ # Dual up-sample
335
+ class UpSample(nn.Module):
336
+ def __init__(self, input_resolution, in_channels, scale_factor):
337
+ super(UpSample, self).__init__()
338
+ self.input_resolution = input_resolution
339
+ self.factor = scale_factor
340
+
341
+
342
+ if self.factor == 2:
343
+ self.conv = nn.Conv2d(in_channels, in_channels//2, 1, 1, 0, bias=False)
344
+ self.up_p = nn.Sequential(nn.Conv2d(in_channels, 2*in_channels, 1, 1, 0, bias=False),
345
+ nn.PReLU(),
346
+ nn.PixelShuffle(scale_factor),
347
+ nn.Conv2d(in_channels//2, in_channels//2, 1, stride=1, padding=0, bias=False))
348
+
349
+ self.up_b = nn.Sequential(nn.Conv2d(in_channels, in_channels, 1, 1, 0),
350
+ nn.PReLU(),
351
+ nn.Upsample(scale_factor=scale_factor, mode='bilinear', align_corners=False),
352
+ nn.Conv2d(in_channels, in_channels // 2, 1, stride=1, padding=0, bias=False))
353
+ elif self.factor == 4:
354
+ self.conv = nn.Conv2d(2*in_channels, in_channels, 1, 1, 0, bias=False)
355
+ self.up_p = nn.Sequential(nn.Conv2d(in_channels, 16 * in_channels, 1, 1, 0, bias=False),
356
+ nn.PReLU(),
357
+ nn.PixelShuffle(scale_factor),
358
+ nn.Conv2d(in_channels, in_channels, 1, stride=1, padding=0, bias=False))
359
+
360
+ self.up_b = nn.Sequential(nn.Conv2d(in_channels, in_channels, 1, 1, 0),
361
+ nn.PReLU(),
362
+ nn.Upsample(scale_factor=scale_factor, mode='bilinear', align_corners=False),
363
+ nn.Conv2d(in_channels, in_channels, 1, stride=1, padding=0, bias=False))
364
+
365
+ def forward(self, x):
366
+ """
367
+ x: B, L = H*W, C
368
+ """
369
+ if type(self.input_resolution) == int:
370
+ H = self.input_resolution
371
+ W = self.input_resolution
372
+
373
+ elif type(self.input_resolution) == tuple:
374
+ H, W = self.input_resolution
375
+
376
+ B, L, C = x.shape
377
+ x = x.view(B, H, W, C) # B, H, W, C
378
+ x = x.permute(0, 3, 1, 2) # B, C, H, W
379
+ x_p = self.up_p(x) # pixel shuffle
380
+ x_b = self.up_b(x) # bilinear
381
+ out = self.conv(torch.cat([x_p, x_b], dim=1))
382
+ out = out.permute(0, 2, 3, 1) # B, H, W, C
383
+ if self.factor == 2:
384
+ out = out.view(B, -1, C // 2)
385
+
386
+ return out
387
+
388
+
389
+ class BasicLayer(nn.Module):
390
+ """ A basic Swin Transformer layer for one stage.
391
+
392
+ Args:
393
+ dim (int): Number of input channels.
394
+ input_resolution (tuple[int]): Input resolution.
395
+ depth (int): Number of blocks.
396
+ num_heads (int): Number of attention heads.
397
+ window_size (int): Local window size.
398
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
399
+ qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
400
+ qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
401
+ drop (float, optional): Dropout rate. Default: 0.0
402
+ attn_drop (float, optional): Attention dropout rate. Default: 0.0
403
+ drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
404
+ norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
405
+ downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
406
+ use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
407
+ """
408
+
409
+ def __init__(self, dim, input_resolution, depth, num_heads, window_size,
410
+ mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0.,
411
+ drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False):
412
+
413
+ super().__init__()
414
+ self.dim = dim
415
+ self.input_resolution = input_resolution
416
+ self.depth = depth
417
+ self.use_checkpoint = use_checkpoint
418
+
419
+ # build blocks
420
+ self.blocks = nn.ModuleList([
421
+ SwinTransformerBlock(dim=dim, input_resolution=input_resolution,
422
+ num_heads=num_heads, window_size=window_size,
423
+ shift_size=0 if (i % 2 == 0) else window_size // 2,
424
+ mlp_ratio=mlp_ratio,
425
+ qkv_bias=qkv_bias, qk_scale=qk_scale,
426
+ drop=drop, attn_drop=attn_drop,
427
+ drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
428
+ norm_layer=norm_layer)
429
+ for i in range(depth)])
430
+
431
+ # patch merging layer
432
+ if downsample is not None:
433
+ self.downsample = downsample(input_resolution, dim=dim, norm_layer=norm_layer)
434
+ else:
435
+ self.downsample = None
436
+
437
+ def forward(self, x):
438
+ for blk in self.blocks:
439
+ if self.use_checkpoint:
440
+ x = checkpoint.checkpoint(blk, x)
441
+ else:
442
+ x = blk(x)
443
+ if self.downsample is not None:
444
+ x = self.downsample(x)
445
+ return x
446
+
447
+ def extra_repr(self) -> str:
448
+ return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}"
449
+
450
+ def flops(self):
451
+ flops = 0
452
+ for blk in self.blocks:
453
+ flops += blk.flops()
454
+ if self.downsample is not None:
455
+ flops += self.downsample.flops()
456
+ return flops
457
+
458
+
459
+ class BasicLayer_up(nn.Module):
460
+ """ A basic Swin Transformer layer for one stage.
461
+
462
+ Args:
463
+ dim (int): Number of input channels.
464
+ input_resolution (tuple[int]): Input resolution.
465
+ depth (int): Number of blocks.
466
+ num_heads (int): Number of attention heads.
467
+ window_size (int): Local window size.
468
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
469
+ qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
470
+ qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
471
+ drop (float, optional): Dropout rate. Default: 0.0
472
+ attn_drop (float, optional): Attention dropout rate. Default: 0.0
473
+ drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
474
+ norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
475
+ downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
476
+ use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
477
+ """
478
+
479
+ def __init__(self, dim, input_resolution, depth, num_heads, window_size,
480
+ mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0.,
481
+ drop_path=0., norm_layer=nn.LayerNorm, upsample=None, use_checkpoint=False):
482
+
483
+ super().__init__()
484
+ self.dim = dim
485
+ self.input_resolution = input_resolution
486
+ self.depth = depth
487
+ self.use_checkpoint = use_checkpoint
488
+
489
+ # build blocks
490
+ self.blocks = nn.ModuleList([
491
+ SwinTransformerBlock(dim=dim, input_resolution=input_resolution,
492
+ num_heads=num_heads, window_size=window_size,
493
+ shift_size=0 if (i % 2 == 0) else window_size // 2,
494
+ mlp_ratio=mlp_ratio,
495
+ qkv_bias=qkv_bias, qk_scale=qk_scale,
496
+ drop=drop, attn_drop=attn_drop,
497
+ drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
498
+ norm_layer=norm_layer)
499
+ for i in range(depth)])
500
+
501
+ # patch merging layer
502
+ if upsample is not None:
503
+ self.upsample = UpSample(input_resolution, in_channels=dim, scale_factor=2)
504
+ else:
505
+ self.upsample = None
506
+
507
+ def forward(self, x):
508
+ for blk in self.blocks:
509
+ if self.use_checkpoint:
510
+ x = checkpoint.checkpoint(blk, x)
511
+ else:
512
+ x = blk(x)
513
+ if self.upsample is not None:
514
+ x = self.upsample(x)
515
+ return x
516
+
517
+
518
+ class PatchEmbed(nn.Module):
519
+ r""" Image to Patch Embedding
520
+
521
+ Args:
522
+ img_size (int): Image size. Default: 224.
523
+ patch_size (int): Patch token size. Default: 4.
524
+ in_chans (int): Number of input image channels. Default: 3.
525
+ embed_dim (int): Number of linear projection output channels. Default: 96.
526
+ norm_layer (nn.Module, optional): Normalization layer. Default: None
527
+ """
528
+
529
+ def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None):
530
+ super().__init__()
531
+ img_size = to_2tuple(img_size)
532
+ patch_size = to_2tuple(patch_size)
533
+ patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]]
534
+ self.img_size = img_size
535
+ self.patch_size = patch_size
536
+ self.patches_resolution = patches_resolution
537
+ self.num_patches = patches_resolution[0] * patches_resolution[1]
538
+
539
+ self.in_chans = in_chans
540
+ self.embed_dim = embed_dim
541
+
542
+ self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
543
+ if norm_layer is not None:
544
+ self.norm = norm_layer(embed_dim)
545
+ else:
546
+ self.norm = None
547
+
548
+ def forward(self, x):
549
+ B, C, H, W = x.shape
550
+ # FIXME look at relaxing size constraints
551
+ # assert H == self.img_size[0] and W == self.img_size[1], \
552
+ # f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
553
+ x = self.proj(x).flatten(2).transpose(1, 2) # B Ph*Pw C
554
+ if self.norm is not None:
555
+ x = self.norm(x)
556
+ return x
557
+
558
+ def flops(self):
559
+ Ho, Wo = self.patches_resolution
560
+ flops = Ho * Wo * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1])
561
+ if self.norm is not None:
562
+ flops += Ho * Wo * self.embed_dim
563
+ return flops
564
+
565
+
566
+ class SUNet(nn.Module):
567
+ r""" Swin Transformer
568
+ A PyTorch impl of : `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows` -
569
+ https://arxiv.org/pdf/2103.14030
570
+
571
+ Args:
572
+ img_size (int | tuple(int)): Input image size. Default 224
573
+ patch_size (int | tuple(int)): Patch size. Default: 4
574
+ in_chans (int): Number of input image channels. Default: 3
575
+
576
+ embed_dim (int): Patch embedding dimension. Default: 96
577
+ depths (tuple(int)): Depth of each Swin Transformer layer.
578
+ num_heads (tuple(int)): Number of attention heads in different layers.
579
+ window_size (int): Window size. Default: 7
580
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4
581
+ qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True
582
+ qk_scale (float): Override default qk scale of head_dim ** -0.5 if set. Default: None
583
+ drop_rate (float): Dropout rate. Default: 0
584
+ attn_drop_rate (float): Attention dropout rate. Default: 0
585
+ drop_path_rate (float): Stochastic depth rate. Default: 0.1
586
+ norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm.
587
+ ape (bool): If True, add absolute position embedding to the patch embedding. Default: False
588
+ patch_norm (bool): If True, add normalization after patch embedding. Default: True
589
+ use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False
590
+ """
591
+
592
+ def __init__(self, img_size=224, patch_size=4, in_chans=3, out_chans=3,
593
+ embed_dim=96, depths=[2, 2, 2, 2], num_heads=[3, 6, 12, 24],
594
+ window_size=7, mlp_ratio=4., qkv_bias=True, qk_scale=None,
595
+ drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1,
596
+ norm_layer=nn.LayerNorm, ape=False, patch_norm=True,
597
+ use_checkpoint=False, final_upsample="Dual up-sample", **kwargs):
598
+ super(SUNet, self).__init__()
599
+
600
+ self.out_chans = out_chans
601
+ self.num_layers = len(depths)
602
+ self.embed_dim = embed_dim
603
+ self.ape = ape
604
+ self.patch_norm = patch_norm
605
+ self.num_features = int(embed_dim * 2 ** (self.num_layers - 1))
606
+ self.num_features_up = int(embed_dim * 2)
607
+ self.mlp_ratio = mlp_ratio
608
+ self.final_upsample = final_upsample
609
+ self.prelu = nn.PReLU()
610
+ self.conv_first = nn.Conv2d(in_chans, embed_dim, 3, 1, 1)
611
+
612
+ # split image into non-overlapping patches
613
+ self.patch_embed = PatchEmbed(
614
+ img_size=img_size, patch_size=patch_size, in_chans=embed_dim, embed_dim=embed_dim,
615
+ norm_layer=norm_layer if self.patch_norm else None)
616
+ num_patches = self.patch_embed.num_patches
617
+ patches_resolution = self.patch_embed.patches_resolution
618
+ self.patches_resolution = patches_resolution
619
+
620
+ # absolute position embedding
621
+ if self.ape:
622
+ self.absolute_pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim))
623
+ trunc_normal_(self.absolute_pos_embed, std=.02)
624
+
625
+ self.pos_drop = nn.Dropout(p=drop_rate)
626
+
627
+ # stochastic depth
628
+ dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule
629
+
630
+ # build encoder and bottleneck layers
631
+ self.layers = nn.ModuleList()
632
+ for i_layer in range(self.num_layers):
633
+ layer = BasicLayer(dim=int(embed_dim * 2 ** i_layer),
634
+ input_resolution=(patches_resolution[0] // (2 ** i_layer),
635
+ patches_resolution[1] // (2 ** i_layer)),
636
+ depth=depths[i_layer],
637
+ num_heads=num_heads[i_layer],
638
+ window_size=window_size,
639
+ mlp_ratio=self.mlp_ratio,
640
+ qkv_bias=qkv_bias, qk_scale=qk_scale,
641
+ drop=drop_rate, attn_drop=attn_drop_rate,
642
+ drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])],
643
+ norm_layer=norm_layer,
644
+ downsample=PatchMerging if (i_layer < self.num_layers - 1) else None,
645
+ use_checkpoint=use_checkpoint)
646
+ self.layers.append(layer)
647
+
648
+ # build decoder layers
649
+ self.layers_up = nn.ModuleList()
650
+ self.concat_back_dim = nn.ModuleList()
651
+ for i_layer in range(self.num_layers):
652
+ concat_linear = nn.Linear(2 * int(embed_dim * 2 ** (self.num_layers - 1 - i_layer)),
653
+ int(embed_dim * 2 ** (
654
+ self.num_layers - 1 - i_layer))) if i_layer > 0 else nn.Identity()
655
+ if i_layer == 0:
656
+ layer_up = UpSample(input_resolution=patches_resolution[0] // (2 ** (self.num_layers - 1 - i_layer)),
657
+ in_channels=int(embed_dim * 2 ** (self.num_layers - 1 - i_layer)), scale_factor=2)
658
+ else:
659
+ layer_up = BasicLayer_up(dim=int(embed_dim * 2 ** (self.num_layers - 1 - i_layer)),
660
+ input_resolution=(
661
+ patches_resolution[0] // (2 ** (self.num_layers - 1 - i_layer)),
662
+ patches_resolution[1] // (2 ** (self.num_layers - 1 - i_layer))),
663
+ depth=depths[(self.num_layers - 1 - i_layer)],
664
+ num_heads=num_heads[(self.num_layers - 1 - i_layer)],
665
+ window_size=window_size,
666
+ mlp_ratio=self.mlp_ratio,
667
+ qkv_bias=qkv_bias, qk_scale=qk_scale,
668
+ drop=drop_rate, attn_drop=attn_drop_rate,
669
+ drop_path=dpr[sum(depths[:(self.num_layers - 1 - i_layer)]):sum(
670
+ depths[:(self.num_layers - 1 - i_layer) + 1])],
671
+ norm_layer=norm_layer,
672
+ upsample=UpSample if (i_layer < self.num_layers - 1) else None,
673
+ use_checkpoint=use_checkpoint)
674
+ self.layers_up.append(layer_up)
675
+ self.concat_back_dim.append(concat_linear)
676
+
677
+ self.norm = norm_layer(self.num_features)
678
+ self.norm_up = norm_layer(self.embed_dim)
679
+
680
+ if self.final_upsample == "Dual up-sample":
681
+ self.up = UpSample(input_resolution=(img_size // patch_size, img_size // patch_size),
682
+ in_channels=embed_dim, scale_factor=4)
683
+ self.output = nn.Conv2d(in_channels=embed_dim, out_channels=self.out_chans, kernel_size=3, stride=1,
684
+ padding=1, bias=False) # kernel = 1
685
+
686
+ self.apply(self._init_weights)
687
+
688
+ def _init_weights(self, m):
689
+ if isinstance(m, nn.Linear):
690
+ trunc_normal_(m.weight, std=.02)
691
+ if isinstance(m, nn.Linear) and m.bias is not None:
692
+ nn.init.constant_(m.bias, 0)
693
+ elif isinstance(m, nn.LayerNorm):
694
+ nn.init.constant_(m.bias, 0)
695
+ nn.init.constant_(m.weight, 1.0)
696
+
697
+ @torch.jit.ignore
698
+ def no_weight_decay(self):
699
+ return {'absolute_pos_embed'}
700
+
701
+ @torch.jit.ignore
702
+ def no_weight_decay_keywords(self):
703
+ return {'relative_position_bias_table'}
704
+
705
+ # Encoder and Bottleneck
706
+ def forward_features(self, x):
707
+ residual = x
708
+ x = self.patch_embed(x)
709
+ if self.ape:
710
+ x = x + self.absolute_pos_embed
711
+ x = self.pos_drop(x)
712
+ x_downsample = []
713
+
714
+ for layer in self.layers:
715
+ x_downsample.append(x)
716
+ x = layer(x)
717
+
718
+ x = self.norm(x) # B L C
719
+
720
+ return x, residual, x_downsample
721
+
722
+ # Dencoder and Skip connection
723
+ def forward_up_features(self, x, x_downsample):
724
+ for inx, layer_up in enumerate(self.layers_up):
725
+ if inx == 0:
726
+ x = layer_up(x)
727
+ else:
728
+ x = torch.cat([x, x_downsample[3 - inx]], -1) # concat last dimension
729
+ x = self.concat_back_dim[inx](x)
730
+ x = layer_up(x)
731
+
732
+ x = self.norm_up(x) # B L C
733
+
734
+ return x
735
+
736
+ def up_x4(self, x):
737
+ H, W = self.patches_resolution
738
+ B, L, C = x.shape
739
+ assert L == H * W, "input features has wrong size"
740
+
741
+ if self.final_upsample == "Dual up-sample":
742
+ x = self.up(x)
743
+ # x = x.view(B, 4 * H, 4 * W, -1)
744
+ x = x.permute(0, 3, 1, 2) # B,C,H,W
745
+
746
+ return x
747
+
748
+ def forward(self, x):
749
+ x = self.conv_first(x)
750
+ x, residual, x_downsample = self.forward_features(x)
751
+ x = self.forward_up_features(x, x_downsample)
752
+ x = self.up_x4(x)
753
+ out = self.output(x)
754
+ # x = x + residual
755
+ return out
756
+
757
+ def flops(self):
758
+ flops = 0
759
+ flops += self.patch_embed.flops()
760
+ for i, layer in enumerate(self.layers):
761
+ flops += layer.flops()
762
+ flops += self.num_features * self.patches_resolution[0] * self.patches_resolution[1] // (2 ** self.num_layers)
763
+ flops += self.num_features * self.out_chans
764
+ return flops
765
+
766
+
767
+ if __name__ == '__main__':
768
+ from utils.model_utils import network_parameters
769
+
770
+ height = 256
771
+ width = 256
772
+ x = torch.randn((1, 3, height, width)) # .cuda()
773
+ model = SUNet(img_size=256, patch_size=4, in_chans=3, out_chans=3,
774
+ embed_dim=96, depths=[8, 8, 8, 8],
775
+ num_heads=[8, 8, 8, 8],
776
+ window_size=8, mlp_ratio=4., qkv_bias=True, qk_scale=2,
777
+ drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1,
778
+ norm_layer=nn.LayerNorm, ape=False, patch_norm=True,
779
+ use_checkpoint=False, final_upsample="Dual up-sample") # .cuda()
780
+ # print(model)
781
+ print('input image size: (%d, %d)' % (height, width))
782
+ print('FLOPs: %.4f G' % (model.flops() / 1e9))
783
+ print('model parameters: ', network_parameters(model))
784
+ # x = model(x)
785
+ print('output image size: ', x.shape)
786
+ flops, params = profile(model, (x,))
787
+ print(flops)
788
+ print(params)