|
import torch |
|
from torch import nn |
|
from .modules import Conv2dBlock, Concat |
|
|
|
class SkipEnrgcoderDecoder(nn.Module): |
|
def __init__(self, input_depth, num_channels_down = [128] * 5, num_channels_up = [128] * 5, num_channels_skip = [128] * 5): |
|
super(SkipEncoderDecoder, self).__init__() |
|
|
|
self.model = nn.Sequential() |
|
model_tmp = self.model |
|
|
|
for i in range(len(num_channels_down)): |
|
|
|
deeper = nn.Sequential() |
|
skip = nn.Sequential() |
|
|
|
if num_channels_skip[i] != 0: |
|
model_tmp.add_module(str(len(model_tmp) + 1), Concat(1, skip, deeper)) |
|
else: |
|
model_tmp.add_module(str(len(model_tmp) + 1), deeper) |
|
|
|
model_tmp.add_module(str(len(model_tmp) + 1), nn.BatchNorm2d(num_channels_skip[i] + (num_channels_up[i + 1] if i < (len(num_channels_down) - 1) else num_channels_down[i]))) |
|
|
|
if num_channels_skip[i] != 0: |
|
skip.add_module(str(len(skip) + 1), Conv2dBlock(input_depth, num_channels_skip[i], 1, bias = False)) |
|
|
|
deeper.add_module(str(len(deeper) + 1), Conv2dBlock(input_depth, num_channels_down[i], 3, 2, bias = False)) |
|
deeper.add_module(str(len(deeper) + 1), Conv2dBlock(num_channels_down[i], num_channels_down[i], 3, bias = False)) |
|
|
|
deeper_main = nn.Sequential() |
|
|
|
if i == len(num_channels_down) - 1: |
|
k = num_channels_down[i] |
|
else: |
|
deeper.add_module(str(len(deeper) + 1), deeper_main) |
|
k = num_channels_up[i + 1] |
|
|
|
deeper.add_module(str(len(deeper) + 1), nn.Upsample(scale_factor = 2, mode = 'nearest')) |
|
|
|
model_tmp.add_module(str(len(model_tmp) + 1), Conv2dBlock(num_channels_skip[i] + k, num_channels_up[i], 3, 1, bias = False)) |
|
model_tmp.add_module(str(len(model_tmp) + 1), Conv2dBlock(num_channels_up[i], num_channels_up[i], 1, bias = False)) |
|
|
|
input_depth = num_channels_down[i] |
|
model_tmp = deeper_main |
|
|
|
self.model.add_module(str(len(self.model) + 1), nn.Conv2d(num_channels_up[0], 3, 1, bias = True)) |
|
self.model.add_module(str(len(self.model) + 1), nn.Sigmoid()) |
|
|
|
def forward(self, x): |
|
return self.model(x) |
|
|
|
|
|
def input_noise(INPUT_DEPTH, spatial_size, scale=1. / 10):
    """Return a single-sample tensor of scaled uniform noise.

    Samples come from ``torch.rand``, so for a positive *scale* they lie in
    ``[0, scale)``.

    Args:
        INPUT_DEPTH: Number of channels (dimension 1) of the noise tensor.
        spatial_size: Indexable whose items ``[0]`` and ``[1]`` become the
            last two tensor dimensions; further items are ignored.
        scale: Multiplier applied to the raw samples (default ``0.1``).

    Returns:
        Tensor of shape ``(1, INPUT_DEPTH, spatial_size[0], spatial_size[1])``.
    """
    height = spatial_size[0]
    width = spatial_size[1]
    noise = torch.rand(1, INPUT_DEPTH, height, width)
    return noise * scale