52Hz commited on
Commit
6e75a05
1 Parent(s): d649af9

Create SUNet.py

Browse files
Files changed (1) hide show
  1. model/SUNet.py +30 -0
model/SUNet.py ADDED
@@ -0,0 +1,30 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch.nn as nn
2
+ from model.SUNet_detail import SUNet
3
+
4
+
5
+ class SUNet_model(nn.Module):
6
+ def __init__(self, config):
7
+ super(SUNet_model, self).__init__()
8
+ self.config = config
9
+ self.swin_unet = SUNet(img_size=config['SWINUNET']['IMG_SIZE'],
10
+ patch_size=config['SWINUNET']['PATCH_SIZE'],
11
+ in_chans=3,
12
+ out_chans=3,
13
+ embed_dim=config['SWINUNET']['EMB_DIM'],
14
+ depths=config['SWINUNET']['DEPTH_EN'],
15
+ num_heads=config['SWINUNET']['HEAD_NUM'],
16
+ window_size=config['SWINUNET']['WIN_SIZE'],
17
+ mlp_ratio=config['SWINUNET']['MLP_RATIO'],
18
+ qkv_bias=config['SWINUNET']['QKV_BIAS'],
19
+ qk_scale=config['SWINUNET']['QK_SCALE'],
20
+ drop_rate=config['SWINUNET']['DROP_RATE'],
21
+ drop_path_rate=config['SWINUNET']['DROP_PATH_RATE'],
22
+ ape=config['SWINUNET']['APE'],
23
+ patch_norm=config['SWINUNET']['PATCH_NORM'],
24
+ use_checkpoint=config['SWINUNET']['USE_CHECKPOINTS'])
25
+
26
+ def forward(self, x):
27
+ if x.size()[1] == 1:
28
+ x = x.repeat(1, 3, 1, 1)
29
+ logits = self.swin_unet(x)
30
+ return logits