venkatesh-thiru committed
Commit 75507f4
Parent(s): af06282

Upload model

Browse files:
- RRDB.py +118 -0
- SRMRIModels.py +41 -0
- SRMRIModelsConfigs.py +52 -0
- config.json +23 -0
- model.safetensors +3 -0
- unet3DMSS.py +189 -0
RRDB.py
ADDED
@@ -0,0 +1,118 @@
import torch
import torch.nn as nn
import torch.nn.functional as F


# Dense layer: 3D convolution + ReLU whose output is concatenated to the input
# along the channel dimension.
class make_dense(nn.Module):
    def __init__(self, nChannels, GrowthRate, kernel_size=3):
        super(make_dense, self).__init__()
        self.conv = nn.Conv3d(nChannels, GrowthRate, kernel_size=kernel_size,
                              padding=(kernel_size - 1) // 2, bias=True)

    def forward(self, x):
        out = F.relu(self.conv(x))
        out = torch.cat([x, out], dim=1)
        return out


# Dense block: a stack of make_dense layers (optionally followed by Dropout3d),
# fused back to outChannels by a 1x1x1 convolution.
class RDB(nn.Module):
    def __init__(self, inChannels, outChannels, nDenseLayer, GrowthRate, KernelSize=3,
                 block_dropout=True, block_dropout_rate=0.2):
        super(RDB, self).__init__()
        nChannels_ = inChannels
        modules = []
        for i in range(nDenseLayer):
            modules.append(make_dense(nChannels_, GrowthRate, kernel_size=KernelSize))
            nChannels_ += GrowthRate
        if block_dropout:
            modules.append(nn.Dropout3d(block_dropout_rate))
        self.dense_layers = nn.Sequential(*modules)
        self.conv_1x1 = nn.Conv3d(nChannels_, outChannels, kernel_size=1, padding=0, bias=False)

    def forward(self, x):
        out = self.dense_layers(x)
        out = self.conv_1x1(out)
        return out


# RRDB network: an initial convolution followed by three dense blocks, with
# optional feature fusion (concatenation of all intermediate outputs) before
# a final 1x1x1 convolution that produces a single-channel volume.
class RRDB(nn.Module):
    def __init__(self, nChannels, nDenseLayers, nInitFeat, GrowthRate, featureFusion=True,
                 kernel_config=[3, 3, 3, 3]):
        super(RRDB, self).__init__()
        nChannels_ = nChannels
        nDenseLayers_ = nDenseLayers
        nInitFeat_ = nInitFeat
        GrowthRate_ = GrowthRate
        self.featureFusion = featureFusion

        # First convolution
        self.C1 = nn.Conv3d(nChannels_, nInitFeat_, kernel_size=kernel_config[0],
                            padding=(kernel_config[0] - 1) // 2, bias=True)
        # Initialise the dense blocks
        if self.featureFusion:
            # Each block receives the concatenation of all previous outputs.
            self.RDB1 = RDB(nInitFeat_, nInitFeat_, nDenseLayers_, GrowthRate_, kernel_config[1])
            self.RDB2 = RDB(nInitFeat_ * 2, nInitFeat_, nDenseLayers_, GrowthRate_, kernel_config[2])
            self.RDB3 = RDB(nInitFeat_ * 3, nInitFeat_, nDenseLayers_, GrowthRate_, kernel_config[3])
            self.FF_1X1 = nn.Conv3d(nInitFeat_ * 4, 1, kernel_size=1, padding=0, bias=True)
        else:
            # Without feature fusion, the dense blocks are chained sequentially.
            self.RDB1 = RDB(nInitFeat_, nInitFeat_, nDenseLayers_, GrowthRate_, kernel_config[1])
            self.RDB2 = RDB(nInitFeat_, nInitFeat_, nDenseLayers_, GrowthRate_, kernel_config[2])
            self.RDB3 = RDB(nInitFeat_, nInitFeat_, nDenseLayers_, GrowthRate_, kernel_config[3])
            self.FF_1X1 = nn.Conv3d(nInitFeat_, 1, kernel_size=1, padding=0, bias=True)

    def forward(self, x):
        First = F.relu(self.C1(x))
        R_1 = self.RDB1(First)

        if self.featureFusion:
            FF0 = torch.cat([First, R_1], dim=1)
            R_2 = self.RDB2(FF0)
            FF1 = torch.cat([First, R_1, R_2], dim=1)
            R_3 = self.RDB3(FF1)
            FF2 = torch.cat([First, R_1, R_2, R_3], dim=1)
            FF1X1 = self.FF_1X1(FF2)
        else:
            R_2 = self.RDB2(R_1)
            R_3 = self.RDB3(R_2)
            FF1X1 = self.FF_1X1(R_3)

        return FF1X1


if __name__ == '__main__':
    model = RRDB(nChannels=1, nDenseLayers=6, nInitFeat=6, GrowthRate=12,
                 featureFusion=True, kernel_config=[3, 3, 3, 3]).cuda()
    dimensions = 1, 1, 64, 64, 64
    x = torch.rand(dimensions).cuda()
    out = model(x)
    print(out.shape)
SRMRIModels.py
ADDED
@@ -0,0 +1,41 @@
from transformers import PreTrainedModel

from .RRDB import RRDB
from .unet3DMSS import UNetMSS
from .SRMRIModelsConfigs import RRDBConfiguration, UNetMSSConfiguration


# PreTrainedModel wrapper around the UNetMSS network defined in unet3DMSS.py.
class SRMRIModelUNetMSS(PreTrainedModel):
    config_class = UNetMSSConfiguration

    def __init__(self, config):
        super().__init__(config)
        self.model = UNetMSS(
            in_channels=config.in_channels,
            n_classes=config.n_classes,
            depth=config.depth,
            wf=config.wf,
            padding=config.padding,
            batch_norm=config.batch_norm,
            up_mode=config.up_mode,
            dropout=config.dropout,
            mss_level=config.mss_level,
            mss_fromlatent=config.mss_fromlatent,
            mss_up=config.mss_up,
            mss_interpb4=config.mss_interpb4)

    def forward(self, x):
        return self.model.forward(x)


# PreTrainedModel wrapper around the RRDB network defined in RRDB.py.
class SRMRIModelRRDB(PreTrainedModel):
    config_class = RRDBConfiguration

    def __init__(self, config):
        super().__init__(config)
        self.model = RRDB(
            nChannels=config.nChannels,
            nDenseLayers=config.nDenseLayers,
            nInitFeat=config.nInitFeat,
            GrowthRate=config.GrowthRate,
            featureFusion=config.featureFusion,
            kernel_config=config.kernel_config,
        )

    def forward(self, x):
        return self.model.forward(x)
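For reference, a minimal local-usage sketch (not part of this commit), assuming the three Python files are placed in a package so that the relative imports resolve; the package name srmri_models below is purely hypothetical:

import torch
# 'srmri_models' is a hypothetical package name; the relative imports in
# SRMRIModels.py require the modules to be importable as a package.
from srmri_models.SRMRIModelsConfigs import RRDBConfiguration
from srmri_models.SRMRIModels import SRMRIModelRRDB

config = RRDBConfiguration()           # defaults mirror the values in config.json
model = SRMRIModelRRDB(config).eval()

with torch.no_grad():
    out = model(torch.rand(1, 1, 64, 64, 64))   # (batch, channel, D, H, W)
print(out.shape)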
SRMRIModelsConfigs.py
ADDED
@@ -0,0 +1,52 @@
from transformers import PretrainedConfig


class RRDBConfiguration(PretrainedConfig):
    model_type = "SRMRIModelRRDB"

    def __init__(
            self,
            nChannels=1,
            nDenseLayers=6,
            nInitFeat=6,
            GrowthRate=12,
            featureFusion=True,
            kernel_config=[3, 3, 3, 3],
            **kwargs):
        self.nChannels = nChannels
        self.nDenseLayers = nDenseLayers
        self.nInitFeat = nInitFeat
        self.GrowthRate = GrowthRate
        self.featureFusion = featureFusion
        self.kernel_config = kernel_config
        super().__init__(**kwargs)


class UNetMSSConfiguration(PretrainedConfig):
    model_type = "SRMRIModelUNetMSS"

    def __init__(
            self,
            in_channels=1,
            n_classes=1,
            depth=3,
            wf=6,
            padding=True,
            batch_norm=False,
            up_mode='upconv',
            dropout=False,
            mss_level=2,
            mss_fromlatent=True,
            mss_up="trilinear",
            mss_interpb4=True,
            **kwargs):
        self.in_channels = in_channels
        self.n_classes = n_classes
        self.depth = depth
        self.wf = wf
        self.padding = padding
        self.batch_norm = batch_norm
        self.up_mode = up_mode
        self.dropout = dropout
        self.mss_level = mss_level
        self.mss_fromlatent = mss_fromlatent
        self.mss_up = mss_up
        self.mss_interpb4 = mss_interpb4
        super().__init__(**kwargs)
config.json
ADDED
@@ -0,0 +1,23 @@
{
  "GrowthRate": 12,
  "architectures": [
    "SRMRIModelRRDB"
  ],
  "auto_map": {
    "AutoConfig": "SRMRIModelsConfigs.RRDBConfiguration",
    "AutoModel": "SRMRIModels.SRMRIModelRRDB"
  },
  "featureFusion": true,
  "kernel_config": [
    3,
    3,
    3,
    3
  ],
  "model_type": "SRMRIModelRRDB",
  "nChannels": 1,
  "nDenseLayers": 6,
  "nInitFeat": 6,
  "torch_dtype": "float32",
  "transformers_version": "4.44.0"
}
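Because config.json registers the custom classes via auto_map, the model can in principle be loaded straight from the Hub with transformers' remote-code mechanism. A minimal sketch, with the repository id left as a placeholder:

import torch
from transformers import AutoModel

# "<user>/<repo>" is a placeholder for this repository's id on the Hugging Face Hub.
model = AutoModel.from_pretrained("<user>/<repo>", trust_remote_code=True)
model.eval()

with torch.no_grad():
    volume = torch.rand(1, 1, 64, 64, 64)    # (batch, channel, D, H, W)
    output = model(volume)
print(output.shape)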
model.safetensors
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:cd89edd8de99fe3fede6955832ef58fa64678e27d83174def7d5b101070c0836
size 991796
unet3DMSS.py
ADDED
@@ -0,0 +1,189 @@
# Adapted from https://discuss.pytorch.org/t/unet-implementation/426

import torch
from torch import nn
import torch.nn.functional as F

__author__ = "Soumick Chatterjee, Chompunuch Sarasaen"
__copyright__ = "Copyright 2020, Faculty of Computer Science, Otto von Guericke University Magdeburg, Germany"
__credits__ = ["Soumick Chatterjee", "Chompunuch Sarasaen"]
__license__ = "GPL"
__version__ = "1.0.0"
__maintainer__ = "Soumick Chatterjee"
__email__ = "soumick.chatterjee@ovgu.de"
__status__ = "Production"


class UNetMSS(nn.Module):
    """
    3D U-Net with multi-scale supervision (MSS), based on
    U-Net: Convolutional Networks for Biomedical Image Segmentation
    (Ronneberger et al., 2015)
    https://arxiv.org/abs/1505.04597

    Using the default arguments yields a 3D variant of the architecture
    described in the original paper.

    Args:
        in_channels (int): number of input channels
        n_classes (int): number of output channels
        depth (int): depth of the network
        wf (int): number of filters in the first layer is 2**wf
        padding (bool): if True, apply padding such that the input shape
                        is the same as the output. This may introduce artifacts
        batch_norm (bool): use BatchNorm after layers with an activation function
        up_mode (str): one of 'upconv' or 'upsample'.
                       'upconv' uses transposed convolutions for learned upsampling.
                       'upsample' uses trilinear upsampling.
        dropout (bool): if True, apply Dropout3d in the contracting path
        mss_level (int): number of intermediate outputs used for multi-scale supervision
        mss_fromlatent (bool): if True, the latent (bottleneck) features are among the
                               supervised outputs
        mss_up (str): interpolation mode used to upsample the supervised outputs
        mss_interpb4 (bool): if True, interpolate before applying the 1x1 supervision convolution
    """
    def __init__(self, in_channels=1, n_classes=1, depth=3, wf=6, padding=True,
                 batch_norm=False, up_mode='upconv', dropout=False, mss_level=2, mss_fromlatent=True,
                 mss_up="trilinear", mss_interpb4=False):
        super(UNetMSS, self).__init__()
        assert up_mode in ('upconv', 'upsample')
        self.padding = padding
        self.depth = depth
        self.dropout = nn.Dropout3d() if dropout else nn.Sequential()
        prev_channels = in_channels
        self.down_path = nn.ModuleList()
        up_out_features = []
        for i in range(depth):
            self.down_path.append(UNetConvBlock(prev_channels, 2**(wf+i),
                                                padding, batch_norm))
            prev_channels = 2**(wf+i)

        if mss_fromlatent:
            mss_features = [prev_channels]
        else:
            mss_features = []

        self.up_path = nn.ModuleList()
        for i in reversed(range(depth - 1)):
            self.up_path.append(UNetUpBlock(prev_channels, 2**(wf+i), up_mode,
                                            padding, batch_norm))
            prev_channels = 2**(wf+i)
            up_out_features.append(prev_channels)

        self.last = nn.Conv3d(prev_channels, n_classes, kernel_size=1)

        mss_features += up_out_features[len(up_out_features)-1-mss_level if not mss_fromlatent
                                        else len(up_out_features)-1-mss_level+1:-1]

        self.mss_level = mss_level
        self.mss_up = mss_up
        self.mss_fromlatent = mss_fromlatent
        self.mss_interpb4 = mss_interpb4
        self.mss_convs = nn.ModuleList()
        for i in range(self.mss_level):
            self.mss_convs.append(nn.Conv3d(mss_features[i], n_classes, kernel_size=1))
        # Coefficients for weighting the MSS outputs (e.g., in a deep-supervision
        # loss); deeper (coarser) scales receive smaller weights.
        if self.mss_level == 1:
            self.mss_coeff = [0.5]
        else:
            lmbda = []
            for i in range(self.mss_level-1, -1, -1):
                lmbda.append(2**i)
            self.mss_coeff = []
            fact = 1.0 / sum(lmbda)
            for i in range(self.mss_level-1):
                self.mss_coeff.append(fact*lmbda[i])
            self.mss_coeff.append(1.0 - sum(self.mss_coeff))
            self.mss_coeff.reverse()

    def forward(self, x):
        blocks = []
        for i, down in enumerate(self.down_path):
            x = down(x)
            if i != len(self.down_path)-1:
                blocks.append(x)
                x = F.avg_pool3d(x, 2)
                x = self.dropout(x)

        if self.mss_fromlatent:
            mss = [x]
        else:
            mss = []

        for i, up in enumerate(self.up_path):
            x = up(x, blocks[-i-1])
            if self.training and ((len(self.up_path)-1-i <= self.mss_level) and not(i+1 == len(self.up_path))):
                mss.append(x)

        if self.training:
            # Bring every MSS feature map to the full output resolution, either
            # before or after its 1x1 convolution.
            for i in range(len(mss)):
                if not self.mss_interpb4:
                    mss[i] = F.interpolate(self.mss_convs[i](mss[i]), size=x.shape[2:], mode=self.mss_up)
                else:
                    mss[i] = self.mss_convs[i](F.interpolate(mss[i], size=x.shape[2:], mode=self.mss_up))

            return self.last(x), mss
        else:
            return self.last(x)


class UNetConvBlock(nn.Module):
    def __init__(self, in_size, out_size, padding, batch_norm):
        super(UNetConvBlock, self).__init__()
        block = []

        block.append(nn.Conv3d(in_size, out_size, kernel_size=3,
                               padding=int(padding)))
        block.append(nn.ReLU())
        if batch_norm:
            block.append(nn.BatchNorm3d(out_size))

        block.append(nn.Conv3d(out_size, out_size, kernel_size=3,
                               padding=int(padding)))
        block.append(nn.ReLU())
        if batch_norm:
            block.append(nn.BatchNorm3d(out_size))

        self.block = nn.Sequential(*block)

    def forward(self, x):
        out = self.block(x)
        return out


class UNetUpBlock(nn.Module):
    def __init__(self, in_size, out_size, up_mode, padding, batch_norm):
        super(UNetUpBlock, self).__init__()
        if up_mode == 'upconv':
            self.up = nn.ConvTranspose3d(in_size, out_size, kernel_size=2,
                                         stride=2)
        elif up_mode == 'upsample':
            self.up = nn.Sequential(nn.Upsample(mode='trilinear', scale_factor=2),
                                    nn.Conv3d(in_size, out_size, kernel_size=1))

        self.conv_block = UNetConvBlock(in_size, out_size, padding, batch_norm)

    def center_crop(self, layer, target_size):
        _, _, layer_depth, layer_height, layer_width = layer.size()
        diff_z = (layer_depth - target_size[0]) // 2
        diff_y = (layer_height - target_size[1]) // 2
        diff_x = (layer_width - target_size[2]) // 2
        return layer[:, :, diff_z:(diff_z + target_size[0]), diff_y:(diff_y + target_size[1]), diff_x:(diff_x + target_size[2])]

    def forward(self, x, bridge):
        up = self.up(x)
        # bridge = self.center_crop(bridge, up.shape[2:])  # alternative to the interpolation below
        up = F.interpolate(up, size=bridge.shape[2:], mode='trilinear')
        out = torch.cat([up, bridge], 1)
        out = self.conv_block(out)

        return out


if __name__ == "__main__":
    model = UNetMSS(in_channels=1, n_classes=1, depth=4, wf=6, padding=True,
                    batch_norm=False, up_mode='upconv', dropout=True, mss_level=3,
                    mss_fromlatent=True, mss_up="trilinear", mss_interpb4=True).cuda()

    print(model)
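The mss_coeff list suggests how the extra training-time outputs are meant to be weighted against the main prediction, but the training loop itself is not part of this upload. The following is only a sketch of one plausible deep-supervision loss, using L1 as a stand-in criterion and random tensors for illustration:

import torch
import torch.nn.functional as F
from unet3DMSS import UNetMSS   # assumes this file is on the import path

model = UNetMSS(in_channels=1, n_classes=1, depth=3, mss_level=2, mss_fromlatent=True)
model.train()

x = torch.rand(1, 1, 32, 32, 32)        # illustrative input volume
target = torch.rand(1, 1, 32, 32, 32)   # illustrative ground-truth volume

out, mss = model(x)                     # main output plus multi-scale outputs (training mode)
loss = F.l1_loss(out, target)
for coeff, aux in zip(model.mss_coeff, mss):
    loss = loss + coeff * F.l1_loss(aux, target)
loss.backward()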