# Source snapshot: commit 2252f3d (9,366 bytes).
import torch
import torch.nn as nn
import torch.nn.functional as F
from .BasePIFuNet import BasePIFuNet
import functools
from .net_util import *
from lib.dataset.PointFeat import PointFeat
from lib.dataset.mesh_util import feat_select
# class ResBlkPIFuNet(BasePIFuNet):
# def __init__(self, opt,
# projection_mode='orthogonal'):
# if opt.color_loss_type == 'l1':
# error_term = nn.L1Loss()
# elif opt.color_loss_type == 'mse':
# error_term = nn.MSELoss()
# super(ResBlkPIFuNet, self).__init__(
# projection_mode=projection_mode,
# error_term=error_term)
# self.name = 'respifu'
# self.opt = opt
# self.smpl_feats = self.opt.smpl_feats
# norm_type = get_norm_layer(norm_type=opt.norm_color)
# self.image_filter = ResnetFilter(opt, norm_layer=norm_type)
# self.smpl_feat_dict=None
# self.surface_classifier = SurfaceClassifier(
# filter_channels=self.opt.mlp_dim_color,
# num_views=self.opt.num_views,
# no_residual=self.opt.no_residual,
# last_op=nn.Tanh())
# self.normalizer = DepthNormalizer(opt)
# init_net(self)
# def filter(self, images):
# '''
# Filter the input images
# store all intermediate features.
# :param images: [B, C, H, W] input images
# '''
# self.im_feat = self.image_filter(images)
# def attach(self, im_feat):
# #self.im_feat = torch.cat([im_feat, self.im_feat], 1)
# self.geo_feat=im_feat
# def query(self, points, calibs, transforms=None, labels=None):
# '''
# Given 3D points, query the network predictions for each point.
# Image features should be pre-computed before this call.
# store all intermediate features.
# query() function may behave differently during training/testing.
# :param points: [B, 3, N] world space coordinates of points
# :param calibs: [B, 3, 4] calibration matrices for each image
# :param transforms: Optional [B, 2, 3] image space coordinate transforms
# :param labels: Optional [B, Res, N] gt labeling
# :return: [B, Res, N] predictions for each point
# '''
# if labels is not None:
# self.labels = labels
# xyz = self.projection(points, calibs, transforms)
# xy = xyz[:, :2, :]
# z = xyz[:, 2:3, :]
# z_feat = self.normalizer(z)
# if self.smpl_feat_dict==None:
# # This is a list of [B, Feat_i, N] features
# point_local_feat_list = [self.index(self.im_feat, xy), z_feat]
# # [B, Feat_all, N]
# point_local_feat = torch.cat(point_local_feat_list, 1)
# self.preds = self.surface_classifier(point_local_feat)
# else:
# point_feat_extractor = PointFeat(self.smpl_feat_dict["smpl_verts"],
# self.smpl_feat_dict["smpl_faces"])
# point_feat_out = point_feat_extractor.query(
# xyz.permute(0, 2, 1).contiguous(), self.smpl_feat_dict)
# feat_lst = [
# point_feat_out[key] for key in self.smpl_feats
# if key in point_feat_out.keys()
# ]
# smpl_feat = torch.cat(feat_lst, dim=2).permute(0, 2, 1)
# point_normal_feat = feat_select(self.index(self.geo_feat, xy), # select front or back normal feature
# smpl_feat[:, [-1], :])
# point_color_feat = torch.cat([self.index(self.im_feat, xy), z_feat],1)
# point_feat_list = [point_normal_feat, point_color_feat, smpl_feat[:, :-1, :]]
# point_feat = torch.cat(point_feat_list, 1)
# self.preds = self.surface_classifier(point_feat)
# def forward(self, images, im_feat, points, calibs, transforms=None, labels=None):
# self.filter(images)
# self.attach(im_feat)
# self.query(points, calibs, transforms, labels)
# error = self.get_error(self.preds,self.labels)
# return self.preds, error
class ResnetBlock(nn.Module):
    """Residual block computing ``x + conv_block(x)``.

    The residual branch holds two 3x3 convolutions of width ``dim`` and is
    assembled by ``build_conv_block``. The identity skip connection is added
    in ``forward``, so input and output shapes match.
    Original ResNet paper: https://arxiv.org/pdf/1512.03385.pdf
    """

    def __init__(self, dim, padding_type, norm_layer, use_dropout, use_bias, last=False):
        super(ResnetBlock, self).__init__()
        self.conv_block = self.build_conv_block(dim, padding_type, norm_layer, use_dropout, use_bias, last)

    @staticmethod
    def _padding(padding_type):
        """Return ``(pad_layers, conv_padding)`` for a single 3x3 convolution.

        'reflect' and 'replicate' use an explicit 1-pixel padding layer and
        leave the conv unpadded. 'zero' uses no padding layer and lets the
        conv pad by 1 itself.

        Raises:
            NotImplementedError: for any other ``padding_type``.
        """
        if padding_type == 'reflect':
            return [nn.ReflectionPad2d(1)], 0
        if padding_type == 'replicate':
            return [nn.ReplicationPad2d(1)], 0
        if padding_type == 'zero':
            return [], 1
        raise NotImplementedError('padding [%s] is not implemented' % padding_type)

    def build_conv_block(self, dim, padding_type, norm_layer, use_dropout, use_bias, last=False):
        """Assemble the residual branch as an ``nn.Sequential``.

        Parameters:
            dim (int)          -- number of input/output channels of both convs
            padding_type (str) -- padding mode: reflect | replicate | zero
            norm_layer         -- callable building a normalization layer from a channel count
            use_dropout (bool) -- insert ``nn.Dropout(0.5)`` between the two convs
            use_bias (bool)    -- whether the convs carry a bias term
            last (bool)        -- if True, omit the normalization after the second conv

        Layout: [pad] conv-norm-ReLU [dropout] [pad] conv [norm]
        """
        # Layers are created in the same order as the original code, so
        # parameter initialisation (global RNG) and Sequential indices
        # (state_dict keys) are unchanged.
        pad_layers, conv_pad = self._padding(padding_type)
        layers = pad_layers + [
            nn.Conv2d(dim, dim, kernel_size=3, padding=conv_pad, bias=use_bias),
            norm_layer(dim),
            nn.ReLU(True),
        ]
        if use_dropout:
            layers.append(nn.Dropout(0.5))

        pad_layers, conv_pad = self._padding(padding_type)
        layers.extend(pad_layers)
        layers.append(nn.Conv2d(dim, dim, kernel_size=3, padding=conv_pad, bias=use_bias))
        if not last:
            layers.append(norm_layer(dim))
        return nn.Sequential(*layers)

    def forward(self, x):
        """Return ``x`` plus the residual branch applied to ``x``."""
        residual = self.conv_block(x)
        return x + residual
class ResnetFilter(nn.Module):
    """ResNet-style image encoder (stem + 2 downsampling convs + residual blocks).

    Adapted from the generator in Justin Johnson's neural style transfer
    project (https://github.com/jcjohnson/fast-neural-style). Only the
    encoder half is kept; there are no upsampling layers. The output feature
    map therefore has ``ngf * 4`` channels. Its spatial size is the input's
    reduced by two stride-2 convolutions, i.e. roughly 1/4 per side.
    """

    def __init__(self, opt, input_nc=3, output_nc=256, ngf=64, norm_layer=nn.BatchNorm2d, use_dropout=False,
                 n_blocks=6, padding_type='reflect'):
        """Construct the encoder.

        Parameters:
            opt                -- options object; only ``opt.use_tanh`` is read here.
                                  When truthy, a final ``nn.Tanh`` is appended.
            input_nc (int)     -- number of channels in the input images
            output_nc (int)    -- currently UNUSED; kept for API compatibility.
                                  The real output channel count is ``ngf * 4``
                                  (256 with the defaults).
            ngf (int)          -- number of filters produced by the stem conv
            norm_layer         -- normalization layer class (or functools.partial of one)
            use_dropout (bool) -- use dropout inside the residual blocks
            n_blocks (int)     -- number of ResNet blocks (must be >= 0)
            padding_type (str) -- padding mode inside the residual blocks: reflect | replicate | zero

        Raises:
            ValueError: if ``n_blocks`` is negative.
        """
        # Explicit check instead of `assert`, which is stripped under `python -O`.
        if n_blocks < 0:
            raise ValueError('n_blocks must be >= 0, got %r' % (n_blocks,))
        super(ResnetFilter, self).__init__()

        # Conv bias is enabled only when the norm is InstanceNorm2d. A
        # BatchNorm2d subtracts the per-channel mean, which cancels any
        # constant conv bias.
        base_norm = norm_layer.func if isinstance(norm_layer, functools.partial) else norm_layer
        use_bias = base_norm == nn.InstanceNorm2d

        # Stem: 7x7 conv on a reflection-padded input keeps the spatial size.
        model = [nn.ReflectionPad2d(3),
                 nn.Conv2d(input_nc, ngf, kernel_size=7, padding=0, bias=use_bias),
                 norm_layer(ngf),
                 nn.ReLU(True)]

        # Downsampling: each stride-2 conv halves H/W and doubles the channels.
        n_downsampling = 2
        for i in range(n_downsampling):
            mult = 2 ** i
            model += [nn.Conv2d(ngf * mult, ngf * mult * 2, kernel_size=3, stride=2, padding=1, bias=use_bias),
                      norm_layer(ngf * mult * 2),
                      nn.ReLU(True)]

        # Residual blocks at constant width. The last one is built with
        # last=True (no normalization after its final conv).
        mult = 2 ** n_downsampling
        for i in range(n_blocks):
            model.append(ResnetBlock(ngf * mult, padding_type=padding_type, norm_layer=norm_layer,
                                     use_dropout=use_dropout, use_bias=use_bias,
                                     last=(i == n_blocks - 1)))

        if opt.use_tanh:
            model.append(nn.Tanh())
        self.model = nn.Sequential(*model)

    def forward(self, input):
        """Apply the encoder to an image batch."""
        return self.model(input)