# one / generator.py
# Last change: "Update generator.py" (commit 02133df) by handsomeboyMMk
import torch
from torch import nn
from .modules import Conv2dBlock, Concat
class SkipEncoderDecoder(nn.Module):
    """Encoder/decoder network with an optional skip branch at every level.

    The network is built as nested levels: level ``i`` wraps level ``i + 1``.
    Each level is laid out as::

        [Concat(skip_i, deeper_i) | deeper_i] -> BatchNorm2d
            -> Conv2dBlock(3x3) -> Conv2dBlock(1x1)

    ``deeper_i`` contains, in order:

    * a 3x3 ``Conv2dBlock`` (extra positional arg ``2``, presumably the
      stride, i.e. a downsampling step — TODO confirm in ``.modules``);
    * a second 3x3 ``Conv2dBlock``;
    * level ``i + 1`` (omitted on the last level);
    * a 2x nearest-neighbour ``nn.Upsample``.

    When ``num_channels_skip[i] == 0`` the skip branch and the ``Concat`` are
    omitted and ``deeper_i`` is used directly.

    A final 1x1 ``nn.Conv2d`` followed by ``nn.Sigmoid`` produces the output.

    Args:
        input_depth: Number of channels of the network input.
        num_channels_down: Encoder widths, one entry per level. Its length
            sets the number of levels.
        num_channels_up: Decoder widths. Needs at least
            ``len(num_channels_down)`` entries, and at least one entry even
            with zero levels, because ``num_channels_up[0]`` feeds the final
            conv.
        num_channels_skip: Width of the 1x1 skip branch per level (0
            disables it). Needs at least ``len(num_channels_down)`` entries.
        num_output_channels: Channels produced by the final 1x1 conv.

    NOTE(review): input height/width presumably need to be compatible with
    ``2 ** len(num_channels_down)`` so both Concat inputs line up — depends
    on ``Concat``'s behaviour, confirm in ``.modules``.
    """

    def __init__(self, input_depth, num_channels_down=(128,) * 5,
                 num_channels_up=(128,) * 5, num_channels_skip=(128,) * 5,
                 num_output_channels=3):
        super().__init__()

        num_levels = len(num_channels_down)
        self.model = nn.Sequential()
        # Container that receives the current level's modules. After each
        # level it is redirected to the empty Sequential nested inside that
        # level's deeper branch, which is what nests the levels.
        model_tmp = self.model

        # NOTE: modules are created and named ("1", "2", ...) in exactly the
        # same order as before, so state_dict keys and seeded initialisation
        # stay unchanged.
        for i in range(num_levels):
            is_last = i == num_levels - 1
            # Channels leaving the deeper branch. This is the next level's
            # decoder width, or this level's encoder width on the last level.
            k = num_channels_down[i] if is_last else num_channels_up[i + 1]

            deeper = nn.Sequential()
            skip = nn.Sequential()

            if num_channels_skip[i] != 0:
                # Presumably concatenates skip and deeper along dim 1
                # (channels) — see .modules.Concat.
                model_tmp.add_module(str(len(model_tmp) + 1), Concat(1, skip, deeper))
            else:
                model_tmp.add_module(str(len(model_tmp) + 1), deeper)

            model_tmp.add_module(str(len(model_tmp) + 1), nn.BatchNorm2d(num_channels_skip[i] + k))

            if num_channels_skip[i] != 0:
                skip.add_module(str(len(skip) + 1), Conv2dBlock(input_depth, num_channels_skip[i], 1, bias=False))

            deeper.add_module(str(len(deeper) + 1), Conv2dBlock(input_depth, num_channels_down[i], 3, 2, bias=False))
            deeper.add_module(str(len(deeper) + 1), Conv2dBlock(num_channels_down[i], num_channels_down[i], 3, bias=False))

            # Placeholder filled by the next iteration (the next level). It is
            # only attached when a next level exists.
            deeper_main = nn.Sequential()
            if not is_last:
                deeper.add_module(str(len(deeper) + 1), deeper_main)

            deeper.add_module(str(len(deeper) + 1), nn.Upsample(scale_factor=2, mode='nearest'))

            model_tmp.add_module(str(len(model_tmp) + 1), Conv2dBlock(num_channels_skip[i] + k, num_channels_up[i], 3, 1, bias=False))
            model_tmp.add_module(str(len(model_tmp) + 1), Conv2dBlock(num_channels_up[i], num_channels_up[i], 1, bias=False))

            input_depth = num_channels_down[i]
            model_tmp = deeper_main

        self.model.add_module(str(len(self.model) + 1), nn.Conv2d(num_channels_up[0], num_output_channels, 1, bias=True))
        self.model.add_module(str(len(self.model) + 1), nn.Sigmoid())

    def forward(self, x):
        """Run the network on ``x``; the output is sigmoid-activated."""
        return self.model(x)


# Backward-compatible alias for the class's previous (misspelled) name.
SkipEnrgcoderDecoder = SkipEncoderDecoder
def input_noise(INPUT_DEPTH, spatial_size, scale = 1./10):
    """Sample a uniform random input tensor for the generator.

    Args:
        INPUT_DEPTH: Number of channels of the returned tensor.
        spatial_size: Indexable whose first two items give the height and
            width; any further items are ignored.
        scale: Multiplier applied to the random samples.

    Returns:
        A tensor of shape ``(1, INPUT_DEPTH, height, width)`` drawn with
        ``torch.rand`` (uniform on [0, 1)) and multiplied by ``scale``.
    """
    height = spatial_size[0]
    width = spatial_size[1]
    noise = torch.rand(1, INPUT_DEPTH, height, width)
    return noise * scale