import torch
import torch.nn as nn
import torch.nn.functional as F
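
# Dense layer: a padded 3D convolution + ReLU whose output is concatenated with
# its input along the channel axis, so each application adds GrowthRate channels.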
class make_dense(nn.Module):
def __init__(self,nChannels,GrowthRate,kernel_size=3):
super(make_dense,self).__init__()
self.conv = nn.Conv3d(nChannels,GrowthRate,kernel_size=kernel_size,padding=(kernel_size-1)//2,bias=True)
# self.norm = nn.BatchNorm3d(nChannels)
def forward(self,x):
# out = self.norm(x)
out = F.relu(self.conv(x))
out = torch.cat([x,out],dim=1)
return out
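
# Residual Dense Block: a stack of dense layers (optionally followed by
# Dropout3d) whose accumulated channels are fused back to outChannels by a
# 1x1x1 convolution. The residual skip (out = out + x) is currently disabled.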
class RDB(nn.Module):
def __init__(self,inChannels,outChannels,nDenseLayer,GrowthRate,KernelSize = 3,
block_dropout = True, block_dropout_rate = 0.2):
super(RDB,self).__init__()
nChannels_ = inChannels
modules = []
for i in range (nDenseLayer):
modules.append(make_dense(nChannels_,GrowthRate,kernel_size=KernelSize))
nChannels_ += GrowthRate
if block_dropout:
modules.append(nn.Dropout3d(block_dropout_rate))
self.dense_layers = nn.Sequential(*modules)
self.conv_1x1 = nn.Conv3d(nChannels_,outChannels,kernel_size=1,padding=0,bias = False)
def forward(self,x):
out = self.dense_layers(x)
out = self.conv_1x1(out)
# out = out + x
return out
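
# Residual-in-Residual Dense Block network: an initial convolution followed by
# three RDBs. With featureFusion=True, each RDB receives the concatenation of
# the initial features and all preceding RDB outputs, and a final 1x1x1
# convolution fuses everything down to a single output channel.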
class RRDB(nn.Module):
def __init__(self,nChannels,nDenseLayers,nInitFeat,GrowthRate,featureFusion=True,kernel_config = [3,3,3,3]):
super(RRDB,self).__init__()
nChannels_ = nChannels
nDenseLayers_ = nDenseLayers
nInitFeat_ = nInitFeat
GrowthRate_ = GrowthRate
self.featureFusion = featureFusion
#First Convolution
self.C1 = nn.Conv3d(nChannels_,nInitFeat_,kernel_size=kernel_config[0],padding=(kernel_config[0]-1)//2,bias=True)
# Initialize RDB
if self.featureFusion:
self.RDB1 = RDB(nInitFeat_,nInitFeat_,nDenseLayers_,GrowthRate_,kernel_config[1])
self.RDB2 = RDB(nInitFeat_*2,nInitFeat_, nDenseLayers_, GrowthRate_,kernel_config[2])
self.RDB3 = RDB(nInitFeat_*3,nInitFeat_, nDenseLayers_, GrowthRate_,kernel_config[3])
self.FF_1X1 = nn.Conv3d(nInitFeat_*4, 1, kernel_size=1, padding=0, bias=True)
else:
            self.RDB1 = RDB(nInitFeat_, nInitFeat_, nDenseLayers_, GrowthRate_, kernel_config[1])
            self.RDB2 = RDB(nInitFeat_, nInitFeat_, nDenseLayers_, GrowthRate_, kernel_config[2])
            self.RDB3 = RDB(nInitFeat_, nInitFeat_, nDenseLayers_, GrowthRate_, kernel_config[3])
            self.FF_1X1 = nn.Conv3d(nInitFeat_, 1, kernel_size=1, padding=0, bias=True)
def forward(self,x):
First = F.relu(self.C1(x))
R_1 = self.RDB1(First)
if self.featureFusion:
FF0 = torch.cat([First,R_1],dim = 1)
R_2 = self.RDB2(FF0)
FF1 = torch.cat([First,R_1,R_2],dim=1)
R_3 = self.RDB3(FF1)
FF2 = torch.cat([First,R_1, R_2, R_3], dim=1)
FF1X1 = self.FF_1X1(FF2)
else:
R_2 = self.RDB2(R_1)
R_3 = self.RDB3(R_2)
FF1X1 = self.FF_1X1(R_3)
return FF1X1
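
# Minimal smoke test: runs a random single-channel 64x64x64 volume through the
# model and prints the output shape.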
if __name__ == '__main__':
    # Use the GPU when available, otherwise fall back to CPU.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = RRDB(nChannels=1, nDenseLayers=6, nInitFeat=6, GrowthRate=12,
                 featureFusion=True, kernel_config=[3, 3, 3, 3]).to(device)
    x = torch.rand(1, 1, 64, 64, 64, device=device)
    out = model(x)
    print(out.shape)
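    # The padded 3x3x3 convolutions preserve spatial size and FF_1X1 reduces to
    # one channel, so this should print torch.Size([1, 1, 64, 64, 64]).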