52Hz commited on
Commit
31fd11c
·
1 Parent(s): 15c9ff5

Update model/SUNet_detail.py

Browse files
Files changed (1) hide show
  1. model/SUNet_detail.py +1 -24
model/SUNet_detail.py CHANGED
@@ -3,7 +3,7 @@ import torch.nn as nn
3
  import torch.utils.checkpoint as checkpoint
4
  from einops import rearrange
5
  from timm.models.layers import DropPath, to_2tuple, trunc_normal_
6
- from thop import profile
7
 
8
  class Mlp(nn.Module):
9
  def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
@@ -763,26 +763,3 @@ class SUNet(nn.Module):
763
  flops += self.num_features * self.out_chans
764
  return flops
765
 
766
-
767
- if __name__ == '__main__':
768
- from utils.model_utils import network_parameters
769
-
770
- height = 256
771
- width = 256
772
- x = torch.randn((1, 3, height, width)) # .cuda()
773
- model = SUNet(img_size=256, patch_size=4, in_chans=3, out_chans=3,
774
- embed_dim=96, depths=[8, 8, 8, 8],
775
- num_heads=[8, 8, 8, 8],
776
- window_size=8, mlp_ratio=4., qkv_bias=True, qk_scale=2,
777
- drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1,
778
- norm_layer=nn.LayerNorm, ape=False, patch_norm=True,
779
- use_checkpoint=False, final_upsample="Dual up-sample") # .cuda()
780
- # print(model)
781
- print('input image size: (%d, %d)' % (height, width))
782
- print('FLOPs: %.4f G' % (model.flops() / 1e9))
783
- print('model parameters: ', network_parameters(model))
784
- # x = model(x)
785
- print('output image size: ', x.shape)
786
- flops, params = profile(model, (x,))
787
- print(flops)
788
- print(params)
 
3
  import torch.utils.checkpoint as checkpoint
4
  from einops import rearrange
5
  from timm.models.layers import DropPath, to_2tuple, trunc_normal_
6
+
7
 
8
  class Mlp(nn.Module):
9
  def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
 
763
  flops += self.num_features * self.out_chans
764
  return flops
765