import torch
from torch import nn

from .modules import Conv2dBlock, Concat


class SkipEncoderDecoder(nn.Module):
    """U-Net-style encoder-decoder with skip connections, built by nesting:
    each loop iteration wraps the next (deeper) level inside the previous one."""

    def __init__(self, input_depth, num_channels_down = [128] * 5,
                 num_channels_up = [128] * 5, num_channels_skip = [128] * 5):
        super(SkipEncoderDecoder, self).__init__()

        self.model = nn.Sequential()
        model_tmp = self.model

        for i in range(len(num_channels_down)):
            deeper = nn.Sequential()
            skip = nn.Sequential()

            # Concatenate the skip branch and the deeper branch along the channel
            # dimension; if this level has no skip channels, use the deeper branch alone.
            if num_channels_skip[i] != 0:
                model_tmp.add_module(str(len(model_tmp) + 1), Concat(1, skip, deeper))
            else:
                model_tmp.add_module(str(len(model_tmp) + 1), deeper)

            # Normalize the concatenated features: skip channels plus either the next
            # level's upsampled channels or, at the deepest level, its own channels.
            model_tmp.add_module(str(len(model_tmp) + 1), nn.BatchNorm2d(
                num_channels_skip[i] + (num_channels_up[i + 1]
                                        if i < (len(num_channels_down) - 1)
                                        else num_channels_down[i])))

            # Skip branch: a 1x1 convolution on the current input.
            if num_channels_skip[i] != 0:
                skip.add_module(str(len(skip) + 1), Conv2dBlock(input_depth, num_channels_skip[i], 1, bias = False))

            # Deeper branch: downsample with a stride-2 conv, then refine.
            deeper.add_module(str(len(deeper) + 1), Conv2dBlock(input_depth, num_channels_down[i], 3, 2, bias = False))
            deeper.add_module(str(len(deeper) + 1), Conv2dBlock(num_channels_down[i], num_channels_down[i], 3, bias = False))

            deeper_main = nn.Sequential()

            if i == len(num_channels_down) - 1:
                # Deepest level: nothing below, so the deeper branch outputs its own channels.
                k = num_channels_down[i]
            else:
                # Nest the next level inside the deeper branch.
                deeper.add_module(str(len(deeper) + 1), deeper_main)
                k = num_channels_up[i + 1]

            # Bring the deeper branch back up to this level's resolution.
            deeper.add_module(str(len(deeper) + 1), nn.Upsample(scale_factor = 2, mode = 'nearest'))

            # Decoder convolutions that fuse the skip and deeper features.
            model_tmp.add_module(str(len(model_tmp) + 1), Conv2dBlock(num_channels_skip[i] + k, num_channels_up[i], 3, 1, bias = False))
            model_tmp.add_module(str(len(model_tmp) + 1), Conv2dBlock(num_channels_up[i], num_channels_up[i], 1, bias = False))

            input_depth = num_channels_down[i]
            model_tmp = deeper_main

        # Map to a 3-channel image with values in (0, 1).
        self.model.add_module(str(len(self.model) + 1), nn.Conv2d(num_channels_up[0], 3, 1, bias = True))
        self.model.add_module(str(len(self.model) + 1), nn.Sigmoid())

    def forward(self, x):
        return self.model(x)


def input_noise(INPUT_DEPTH, spatial_size, scale = 1./10):
    # Uniform noise tensor of shape [1, INPUT_DEPTH, H, W], scaled down by `scale`.
    shape = [1, INPUT_DEPTH, spatial_size[0], spatial_size[1]]
    return torch.rand(*shape) * scale
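

# Usage sketch (not part of the original file): instantiate the network, sample a
# noise input, and run one forward pass. The input depth 32 and resolution 256 x 256
# are illustrative assumptions; this also assumes Conv2dBlock pads its convolutions so
# that each stride-2 conv exactly halves the spatial size, i.e. H and W should be
# divisible by 2 ** 5 for the five levels to line up. Because of the relative import
# above, run this file as a module (python -m <package>.<module>), not as a script.
if __name__ == '__main__':
    input_depth = 32
    net = SkipEncoderDecoder(input_depth)
    z = input_noise(input_depth, (256, 256))   # [1, 32, 256, 256], uniform noise * 0.1
    out = net(z)                               # [1, 3, 256, 256], values in (0, 1)
    print(out.shape)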