File size: 4,930 Bytes
75507f4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np


class make_dense(nn.Module):
    def __init__(self,nChannels,GrowthRate,kernel_size=3):
        super(make_dense,self).__init__()
        self.conv = nn.Conv3d(nChannels,GrowthRate,kernel_size=kernel_size,padding=(kernel_size-1)//2,bias=True)
        # self.norm = nn.BatchNorm3d(nChannels)
    def forward(self,x):
        # out = self.norm(x)
        out = F.relu(self.conv(x))
        out = torch.cat([x,out],dim=1)
        return out

# class RDB(nn.Module):
#     def __init__(self,inChannels,outChannels,nDenseLayer,GrowthRate,KernelSize = 3):
#         super(RDB,self).__init__()
#         nChannels_ = inChannels
#         modules = []
#         for i in range (nDenseLayer):
#             modules.append(make_dense(nChannels_,GrowthRate,kernel_size=KernelSize))
#             nChannels_ += GrowthRate
#         self.dense_layers = nn.Sequential(*modules)
#         self.conv_1x1 = nn.Conv3d(nChannels_,outChannels,kernel_size=1,padding=0,bias = False)
#     def forward(self,x):
#         out = self.dense_layers(x)
#         out = self.conv_1x1(out)
#         # out = out + x
#         return out

class RDB(nn.Module):
    def __init__(self,inChannels,outChannels,nDenseLayer,GrowthRate,KernelSize = 3,

                 block_dropout = True, block_dropout_rate = 0.2):
        super(RDB,self).__init__()
        nChannels_ = inChannels
        modules = []
        for i in range (nDenseLayer):
            modules.append(make_dense(nChannels_,GrowthRate,kernel_size=KernelSize))
            nChannels_ += GrowthRate
        if block_dropout:
            modules.append(nn.Dropout3d(block_dropout_rate))
        self.dense_layers = nn.Sequential(*modules)
        self.conv_1x1 = nn.Conv3d(nChannels_,outChannels,kernel_size=1,padding=0,bias = False)
    def forward(self,x):
        out = self.dense_layers(x)
        out = self.conv_1x1(out)
        # out = out + x
        return out


class RRDB(nn.Module):
    def __init__(self,nChannels,nDenseLayers,nInitFeat,GrowthRate,featureFusion=True,kernel_config = [3,3,3,3]):
        super(RRDB,self).__init__()
        nChannels_ = nChannels
        nDenseLayers_ = nDenseLayers
        nInitFeat_ = nInitFeat
        GrowthRate_ = GrowthRate
        self.featureFusion = featureFusion

        #First Convolution
        self.C1 = nn.Conv3d(nChannels_,nInitFeat_,kernel_size=kernel_config[0],padding=(kernel_config[0]-1)//2,bias=True)
        # Initialize RDB
        if self.featureFusion:
            self.RDB1 = RDB(nInitFeat_,nInitFeat_,nDenseLayers_,GrowthRate_,kernel_config[1])
            # print(f"RDB1 =========================================== \n {self.RDB1}")
            self.RDB2 = RDB(nInitFeat_*2,nInitFeat_, nDenseLayers_, GrowthRate_,kernel_config[2])
            # print(f"RDB2 =========================================== \n {self.RDB2}")
            self.RDB3 = RDB(nInitFeat_*3,nInitFeat_, nDenseLayers_, GrowthRate_,kernel_config[3])
            # print(f"RDB3 =========================================== \n {self.RDB3}")
            self.FF_1X1 = nn.Conv3d(nInitFeat_*4, 1, kernel_size=1, padding=0, bias=True)
            # print(f"FF1x1 =========================================== \n {self.FF_1X1}")
        else:
            self.RDB1 = RDB(nInitFeat_, nDenseLayers_, GrowthRate_, kernel_config[1])
            self.RDB2 = RDB(nInitFeat_, nDenseLayers_, GrowthRate_, kernel_config[2])
            self.RDB3 = RDB(nInitFeat_, nDenseLayers_, GrowthRate_, kernel_config[3])
            self.FF_1X1 = nn.Conv3d(nInitFeat_, 1, kernel_size=1, padding=0, bias=True)


        # Feature Fusion


        # self.FF_3X3 = nn.Conv3d(nInitFeat_,nInitFeat_,kernel_size=3,padding=1,bias=True)

        # self.final_layer = nn.Conv3d(nInitFeat_,nChannels_,kernel_size=1,padding=0,bias=False)

    def forward(self,x):
        First = F.relu(self.C1(x))
        R_1 = self.RDB1(First)

        if self.featureFusion:
            FF0 = torch.cat([First,R_1],dim = 1)
            R_2 = self.RDB2(FF0)
            FF1 = torch.cat([First,R_1,R_2],dim=1)
            R_3 = self.RDB3(FF1)
            FF2 = torch.cat([First,R_1, R_2, R_3], dim=1)
            FF1X1 = self.FF_1X1(FF2)
        else:
            R_2 = self.RDB2(R_1)
            R_3 = self.RDB3(R_2)
            FF1X1 = self.FF_1X1(R_3)

        # FF2 = torch.cat([R_1,R_2,R_3],dim=1)
        # FF1X1 = self.FF_1X1(FF2)
        # FF3X3 = self.FF_3X3(FF1X1)
        # output = self.final_layer(FF3X3)

        return FF1X1

if __name__ == '__main__':
    model = RRDB(nChannels=1,nDenseLayers=6,nInitFeat=6,GrowthRate=12,featureFusion=True,kernel_config = [3,3,3,3]).cuda()
    dimensions = 1, 1, 64, 64, 64
    x = torch.rand(dimensions)
    x = x.cuda()
    out = model(x)
    print(out.shape)