import torch
import torch.nn as nn
import torch.nn.functional as F
from .BasePIFuNet import BasePIFuNet
import functools

from .net_util import *
from lib.dataset.PointFeat import PointFeat
from lib.dataset.mesh_util import feat_select


class ResnetBlock(nn.Module):
    """Define a Resnet block"""

    def __init__(self, dim, padding_type, norm_layer, use_dropout, use_bias, last=False):
        """Initialize the Resnet block

        A resnet block is a conv block with skip connections.
        We construct a conv block with the build_conv_block function,
        and implement skip connections in the <forward> function.
        Original Resnet paper: https://arxiv.org/pdf/1512.03385.pdf
        """
        super(ResnetBlock, self).__init__()
        self.conv_block = self.build_conv_block(dim, padding_type, norm_layer, use_dropout, use_bias, last)

    def build_conv_block(self, dim, padding_type, norm_layer, use_dropout, use_bias, last=False):
        """Construct a convolutional block.

        Parameters:
            dim (int)          -- the number of channels in the conv layer
            padding_type (str) -- the name of the padding layer: reflect | replicate | zero
            norm_layer         -- normalization layer
            use_dropout (bool) -- whether to use dropout layers
            use_bias (bool)    -- whether the conv layer uses a bias
            last (bool)        -- if True, omit the normalization layer after the second conv

        Returns a conv block (a conv layer, a normalization layer, and a non-linearity layer (ReLU))
        """
        conv_block = []

        # First conv: padding + conv + norm + ReLU (+ optional dropout).
        p = 0
        if padding_type == 'reflect':
            conv_block += [nn.ReflectionPad2d(1)]
        elif padding_type == 'replicate':
            conv_block += [nn.ReplicationPad2d(1)]
        elif padding_type == 'zero':
            p = 1
        else:
            raise NotImplementedError('padding [%s] is not implemented' % padding_type)

        conv_block += [nn.Conv2d(dim, dim, kernel_size=3, padding=p, bias=use_bias), norm_layer(dim), nn.ReLU(True)]
        if use_dropout:
            conv_block += [nn.Dropout(0.5)]

        # Second conv: padding + conv (+ norm, unless this is the last block).
        p = 0
        if padding_type == 'reflect':
            conv_block += [nn.ReflectionPad2d(1)]
        elif padding_type == 'replicate':
            conv_block += [nn.ReplicationPad2d(1)]
        elif padding_type == 'zero':
            p = 1
        else:
            raise NotImplementedError('padding [%s] is not implemented' % padding_type)

        if last:
            conv_block += [nn.Conv2d(dim, dim, kernel_size=3, padding=p, bias=use_bias)]
        else:
            conv_block += [nn.Conv2d(dim, dim, kernel_size=3, padding=p, bias=use_bias), norm_layer(dim)]

        return nn.Sequential(*conv_block)

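    # For reference (derived from the code above), with padding_type='reflect',
    # use_dropout=False, and last=False, the assembled block is:
    #   ReflectionPad2d(1) -> Conv2d(dim, dim, 3) -> norm_layer(dim) -> ReLU ->
    #   ReflectionPad2d(1) -> Conv2d(dim, dim, 3) -> norm_layer(dim)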
    def forward(self, x):
        """Forward function (with skip connections)"""
        out = x + self.conv_block(x)
        return out

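# Usage sketch (illustrative; not part of the original module): a ResnetBlock
# preserves both spatial size and channel count, so blocks can be stacked freely.
# Assuming an InstanceNorm-style setup, where conv layers keep their bias:
#
#   block = ResnetBlock(256, padding_type='reflect',
#                       norm_layer=functools.partial(nn.InstanceNorm2d),
#                       use_dropout=False, use_bias=True)
#   y = block(torch.randn(1, 256, 128, 128))  # shape unchanged: (1, 256, 128, 128)
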
class ResnetFilter(nn.Module):
    """Resnet-based filter that applies a few downsampling convolutions followed by Resnet blocks.

    We adapt Torch code and ideas from Justin Johnson's neural style transfer
    project (https://github.com/jcjohnson/fast-neural-style).
    """

    def __init__(self, opt, input_nc=3, output_nc=256, ngf=64, norm_layer=nn.BatchNorm2d,
                 use_dropout=False, n_blocks=6, padding_type='reflect'):
        """Construct a Resnet-based generator

        Parameters:
            input_nc (int)     -- the number of channels in input images
            output_nc (int)    -- the number of channels in output feature maps
            ngf (int)          -- the number of filters in the first conv layer
            norm_layer         -- normalization layer
            use_dropout (bool) -- whether to use dropout layers
            n_blocks (int)     -- the number of ResNet blocks
            padding_type (str) -- the name of the padding layer in conv layers: reflect | replicate | zero
        """
        assert n_blocks >= 0
        super(ResnetFilter, self).__init__()
        # InstanceNorm has no affine shift by default, so convs keep their bias;
        # BatchNorm learns a shift, which would make the conv bias redundant.
        if type(norm_layer) == functools.partial:
            use_bias = norm_layer.func == nn.InstanceNorm2d
        else:
            use_bias = norm_layer == nn.InstanceNorm2d

        model = [nn.ReflectionPad2d(3),
                 nn.Conv2d(input_nc, ngf, kernel_size=7, padding=0, bias=use_bias),
                 norm_layer(ngf),
                 nn.ReLU(True)]

        # Two stride-2 convolutions halve the resolution and double the channels each time.
        n_downsampling = 2
        for i in range(n_downsampling):
            mult = 2 ** i
            model += [nn.Conv2d(ngf * mult, ngf * mult * 2, kernel_size=3, stride=2, padding=1, bias=use_bias),
                      norm_layer(ngf * mult * 2),
                      nn.ReLU(True)]

        # Residual blocks at the bottleneck resolution; the final block skips its
        # trailing normalization (last=True). With the defaults, the feature width
        # stays at ngf * 4 = 256, matching output_nc.
        mult = 2 ** n_downsampling
        for i in range(n_blocks):
            if i == n_blocks - 1:
                model += [ResnetBlock(ngf * mult, padding_type=padding_type, norm_layer=norm_layer,
                                      use_dropout=use_dropout, use_bias=use_bias, last=True)]
            else:
                model += [ResnetBlock(ngf * mult, padding_type=padding_type, norm_layer=norm_layer,
                                      use_dropout=use_dropout, use_bias=use_bias)]

        if opt.use_tanh:
            model += [nn.Tanh()]
        self.model = nn.Sequential(*model)

    def forward(self, input):
        """Standard forward"""
        return self.model(input)

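
if __name__ == "__main__":
    # Minimal smoke test (a sketch; `opt` normally comes from the project's
    # config system, so a stub exposing only `use_tanh` is assumed here).
    from types import SimpleNamespace

    net = ResnetFilter(SimpleNamespace(use_tanh=False))
    net.eval()
    with torch.no_grad():
        feat = net(torch.randn(1, 3, 512, 512))
    # Two stride-2 downsamplings: 512 -> 128 spatially; channels ngf * 4 = 256.
    print(feat.shape)  # torch.Size([1, 256, 128, 128])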