import torch
from torch import nn


class DepthwiseSeparableConv2d(nn.Module):
    """A depthwise separable convolution: a per-channel (depthwise) conv
    followed by a 1x1 (pointwise) conv that mixes channels. This factorization
    uses far fewer parameters and FLOPs than a standard convolution."""

    def __init__(self, input_channels, output_channels, **kwargs):
        super().__init__()

        # Depthwise: groups == input_channels gives one filter per channel.
        # Extra kwargs (kernel_size, stride, padding, bias, ...) apply here.
        self.depthwise = nn.Conv2d(input_channels, input_channels,
                                   groups=input_channels, **kwargs)
        # Pointwise: 1x1 conv to output_channels. Note it keeps its default
        # bias even when bias=False is passed for the depthwise conv.
        self.pointwise = nn.Conv2d(input_channels, output_channels, kernel_size=1)

    def forward(self, x):
        x = self.depthwise(x)
        x = self.pointwise(x)
        return x
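# For reference (illustrative arithmetic, not from the original source): with
# input_channels=64, output_channels=128, kernel_size=3, a standard conv has
# 64 * 128 * 3 * 3 = 73,728 weights, while DepthwiseSeparableConv2d above has
# 64 * 3 * 3 + 64 * 128 = 8,768, roughly an 8.4x reduction.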
class Conv2dBlock(nn.Module):
    """Reflection pad -> depthwise separable conv -> BatchNorm -> LeakyReLU.

    Reflection padding of (kernel_size - 1) // 2 replaces the conv's own
    zero padding, so spatial size is preserved when stride=1."""

    def __init__(self, in_channels, out_channels, kernel_size, stride=1, bias=False):
        super().__init__()

        self.model = nn.Sequential(
            nn.ReflectionPad2d((kernel_size - 1) // 2),
            DepthwiseSeparableConv2d(in_channels, out_channels, kernel_size=kernel_size,
                                     stride=stride, padding=0, bias=bias),
            nn.BatchNorm2d(out_channels),
            nn.LeakyReLU(0.2),
        )

    def forward(self, x):
        return self.model(x)
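# Usage sketch (shapes illustrative): Conv2dBlock(3, 64, kernel_size=3) maps
# (N, 3, H, W) -> (N, 64, H, W); with stride=2 the spatial size is roughly
# halved, since the padded conv computes floor((H - 1) / 2) + 1.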
class Concat(nn.Module):
    """Applies each child module to the same input, center-crops the outputs
    to the smallest spatial size among them, and concatenates along `dim`."""

    def __init__(self, dim, *args):
        super().__init__()
        self.dim = dim

        for idx, module in enumerate(args):
            self.add_module(str(idx), module)

    def forward(self, x):
        inputs = [module(x) for module in self._modules.values()]

        heights = [t.shape[2] for t in inputs]
        widths = [t.shape[3] for t in inputs]
        target_h, target_w = min(heights), min(widths)

        if all(h == target_h for h in heights) and all(w == target_w for w in widths):
            inputs_ = inputs
        else:
            # Outputs differ in size (e.g. different receptive fields), so
            # center-crop each one to the smallest height and width.
            inputs_ = []
            for inp in inputs:
                diff_h = (inp.size(2) - target_h) // 2
                diff_w = (inp.size(3) - target_w) // 2
                inputs_.append(inp[:, :, diff_h:diff_h + target_h,
                                   diff_w:diff_w + target_w])

        return torch.cat(inputs_, dim=self.dim)

    def __len__(self):
        return len(self._modules)
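# Minimal smoke test (illustrative shapes only, not part of the original API):
if __name__ == "__main__":
    x = torch.randn(2, 3, 16, 16)

    block = Conv2dBlock(3, 16, kernel_size=3, stride=2)
    print(block(x).shape)  # torch.Size([2, 16, 8, 8])

    # Two branches with different receptive fields: Concat center-crops the
    # 14x14 output of the 3x3 branch to match the 12x12 of the 5x5 branch.
    branches = Concat(1, nn.Conv2d(3, 8, kernel_size=3), nn.Conv2d(3, 4, kernel_size=5))
    print(branches(x).shape)  # torch.Size([2, 12, 12, 12])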