# ------------------------------------------------------------------------------
# Adapted from https://github.com/akanazawa/hmr
# Original licence: Copyright (c) 2018 akanazawa, under the MIT License.
# ------------------------------------------------------------------------------
from abc import abstractmethod

import torch
import torch.nn as nn
from mmcv.cnn import normal_init, xavier_init

from mmpose.models.utils.geometry import batch_rodrigues


class BaseDiscriminator(nn.Module):
"""Base linear module for SMPL parameter discriminator.
Args:
fc_layers (Tuple): Tuple of neuron count,
such as (9, 32, 32, 1)
use_dropout (Tuple): Tuple of bool define use dropout or not
for each layer, such as (True, True, False)
drop_prob (Tuple): Tuple of float defined the drop prob,
such as (0.5, 0.5, 0)
use_activation(Tuple): Tuple of bool define use active function
or not, such as (True, True, False)
"""
def __init__(self, fc_layers, use_dropout, drop_prob, use_activation):
super().__init__()
self.fc_layers = fc_layers
self.use_dropout = use_dropout
self.drop_prob = drop_prob
self.use_activation = use_activation
self._check()
self.create_layers()
    def _check(self):
        """Check input arguments and raise errors for invalid configurations."""
        if not isinstance(self.fc_layers, tuple):
            raise TypeError(f'fc_layers requires a tuple, '
                            f'but got {type(self.fc_layers)}')

        if not isinstance(self.use_dropout, tuple):
            raise TypeError(f'use_dropout requires a tuple, '
                            f'but got {type(self.use_dropout)}')

        if not isinstance(self.drop_prob, tuple):
            raise TypeError(f'drop_prob requires a tuple, '
                            f'but got {type(self.drop_prob)}')

        if not isinstance(self.use_activation, tuple):
            raise TypeError(f'use_activation requires a tuple, '
                            f'but got {type(self.use_activation)}')

        l_fc_layer = len(self.fc_layers)
        l_use_drop = len(self.use_dropout)
        l_drop_prob = len(self.drop_prob)
        l_use_activation = len(self.use_activation)

        pass_check = (
            l_fc_layer >= 2 and l_use_drop < l_fc_layer
            and l_drop_prob < l_fc_layer and l_use_activation < l_fc_layer
            and l_drop_prob == l_use_drop)

        if not pass_check:
            msg = 'Wrong BaseDiscriminator parameters!'
            raise ValueError(msg)
    def create_layers(self):
        """Create layers."""
        l_fc_layer = len(self.fc_layers)
        l_use_drop = len(self.use_dropout)
        l_use_activation = len(self.use_activation)

        self.fc_blocks = nn.Sequential()

        for i in range(l_fc_layer - 1):
            self.fc_blocks.add_module(
                name=f'regressor_fc_{i}',
                module=nn.Linear(
                    in_features=self.fc_layers[i],
                    out_features=self.fc_layers[i + 1]))

            if i < l_use_activation and self.use_activation[i]:
                self.fc_blocks.add_module(
                    name=f'regressor_af_{i}', module=nn.ReLU())

            if i < l_use_drop and self.use_dropout[i]:
                self.fc_blocks.add_module(
                    name=f'regressor_fc_dropout_{i}',
                    module=nn.Dropout(p=self.drop_prob[i]))
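
    # For reference, with the docstring example (fc_layers=(9, 32, 32, 1),
    # use_activation=(True, True, False), use_dropout=(True, True, False),
    # drop_prob=(0.5, 0.5, 0)) create_layers() builds the stack:
    #   Linear(9, 32) -> ReLU -> Dropout(0.5)
    #   -> Linear(32, 32) -> ReLU -> Dropout(0.5)
    #   -> Linear(32, 1)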
    @abstractmethod
    def forward(self, inputs):
        """Forward function."""
        msg = 'the base class [BaseDiscriminator] is not callable!'
        raise NotImplementedError(msg)

    def init_weights(self):
        """Initialize model weights."""
        for m in self.fc_blocks.modules():
            if isinstance(m, nn.Linear):
                xavier_init(m, gain=0.01)


class ShapeDiscriminator(BaseDiscriminator):
"""Discriminator for SMPL shape parameters, the inputs is (batch_size x 10)
Args:
fc_layers (Tuple): Tuple of neuron count, such as (10, 5, 1)
use_dropout (Tuple): Tuple of bool define use dropout or
not for each layer, such as (True, True, False)
drop_prob (Tuple): Tuple of float defined the drop prob,
such as (0.5, 0)
use_activation(Tuple): Tuple of bool define use active
function or not, such as (True, False)
"""
def __init__(self, fc_layers, use_dropout, drop_prob, use_activation):
if fc_layers[-1] != 1:
msg = f'the neuron count of the last layer ' \
f'must be 1, but got {fc_layers[-1]}'
raise ValueError(msg)
super().__init__(fc_layers, use_dropout, drop_prob, use_activation)
def forward(self, inputs):
"""Forward function."""
return self.fc_blocks(inputs)
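
# A minimal usage sketch for ShapeDiscriminator (illustrative values only):
# the arguments follow the docstring example and map a (batch_size, 10)
# shape (beta) tensor to one realism score per sample.
#   shape_disc = ShapeDiscriminator(
#       fc_layers=(10, 5, 1), use_dropout=(True, False),
#       drop_prob=(0.5, 0), use_activation=(True, False))
#   scores = shape_disc(torch.randn(4, 10))  # -> shape (4, 1)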


class PoseDiscriminator(nn.Module):
    """Discriminator for the SMPL pose parameters of each joint.

    It is composed of one discriminator per joint. The input is a tensor of
    shape (batch_size, joint_count, 9).

    Args:
        channels (Tuple): Tuple of channel numbers,
            such as (9, 32, 32, 1).
        joint_count (int): Number of joints, such as 23.
    """

    def __init__(self, channels, joint_count):
        super().__init__()
        if channels[-1] != 1:
            msg = f'the neuron count of the last layer ' \
                  f'must be 1, but got {channels[-1]}'
            raise ValueError(msg)
        self.joint_count = joint_count

        self.conv_blocks = nn.Sequential()
        len_channels = len(channels)
        for idx in range(len_channels - 2):
            self.conv_blocks.add_module(
                name=f'conv_{idx}',
                module=nn.Conv2d(
                    in_channels=channels[idx],
                    out_channels=channels[idx + 1],
                    kernel_size=1,
                    stride=1))

        self.fc_layer = nn.ModuleList()
        for idx in range(joint_count):
            self.fc_layer.append(
                nn.Linear(
                    in_features=channels[len_channels - 2], out_features=1))
    def forward(self, inputs):
        """Forward function.

        The input is (batch_size x joint_count x 9).
        """
        # shape: batch_size x 9 x 1 x joint_count
        inputs = inputs.transpose(1, 2).unsqueeze(2).contiguous()
        # shape: batch_size x c x 1 x joint_count
        internal_outputs = self.conv_blocks(inputs)
        outputs = []
        for idx in range(self.joint_count):
            outputs.append(self.fc_layer[idx](internal_outputs[:, :, 0, idx]))

        return torch.cat(outputs, 1), internal_outputs
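
    # Shape walk-through for the default configuration (channels=(9, 32, 32, 1),
    # joint_count=23); the batch size 4 below is only illustrative:
    #   inputs:            (4, 23, 9)   per-joint flattened 3x3 rotation matrices
    #   after transpose:   (4, 9, 1, 23)
    #   internal_outputs:  (4, 32, 1, 23)  shared 1x1-conv features
    #   returned scores:   (4, 23)      one realism score per joint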
    def init_weights(self):
        """Initialize model weights."""
        for m in self.conv_blocks:
            if isinstance(m, nn.Conv2d):
                normal_init(m, std=0.001, bias=0)

        for m in self.fc_layer:
            if isinstance(m, nn.Linear):
                xavier_init(m, gain=0.01)


class FullPoseDiscriminator(BaseDiscriminator):
    """Discriminator for the SMPL pose parameters of all joints.

    Args:
        fc_layers (Tuple): Tuple of neuron counts,
            such as (736, 1024, 1024, 1).
        use_dropout (Tuple): Tuple of bools defining whether to use dropout
            after each layer, such as (True, True, False).
        drop_prob (Tuple): Tuple of floats defining the dropout probability
            of each layer, such as (0.5, 0.5, 0).
        use_activation (Tuple): Tuple of bools defining whether to use an
            activation function after each layer, such as (True, True, False).
    """

    def __init__(self, fc_layers, use_dropout, drop_prob, use_activation):
        if fc_layers[-1] != 1:
            msg = f'the neuron count of the last layer must be 1,' \
                  f' but got {fc_layers[-1]}'
            raise ValueError(msg)

        super().__init__(fc_layers, use_dropout, drop_prob, use_activation)

    def forward(self, inputs):
        """Forward function."""
        return self.fc_blocks(inputs)
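
# A minimal usage sketch for FullPoseDiscriminator (illustrative values only):
# the input width 736 equals 23 joints x 32 per-joint feature channels, as
# produced by PoseDiscriminator with its default channels.
#   full_pose_disc = FullPoseDiscriminator(
#       fc_layers=(736, 1024, 1024, 1), use_dropout=(False, False, False),
#       drop_prob=(0.5, 0.5, 0.5), use_activation=(True, True, False))
#   scores = full_pose_disc(torch.randn(4, 736))  # -> shape (4, 1)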


class SMPLDiscriminator(nn.Module):
    """Discriminator for SMPL pose and shape parameters.

    It is composed of a discriminator for SMPL shape parameters, a
    discriminator for the pose parameters of all joints and a discriminator
    for the pose parameters of each joint.

    Args:
        beta_channel (tuple of int): Tuple of neuron counts of the
            discriminator of shape parameters. Defaults to (10, 5, 1).
        per_joint_channel (tuple of int): Tuple of neuron counts of the
            discriminator of each joint. Defaults to (9, 32, 32, 1).
        full_pose_channel (tuple of int): Tuple of neuron counts of the
            discriminator of the full pose. Defaults to (23*32, 1024, 1024, 1).
    """

    def __init__(self,
                 beta_channel=(10, 5, 1),
                 per_joint_channel=(9, 32, 32, 1),
                 full_pose_channel=(23 * 32, 1024, 1024, 1)):
        super().__init__()
        self.joint_count = 23
        # The count of SMPL shape parameters is 10.
        assert beta_channel[0] == 10
        # Use a 3 x 3 rotation matrix as the pose parameters
        # of each joint, so the input channel is 9.
        assert per_joint_channel[0] == 9
        # The full-pose discriminator consumes the concatenated per-joint
        # features, so its input width must be joint_count times the width
        # of the last hidden layer of the per-joint discriminator.
        assert self.joint_count * per_joint_channel[-2] \
            == full_pose_channel[0]

        self.beta_channel = beta_channel
        self.per_joint_channel = per_joint_channel
        self.full_pose_channel = full_pose_channel
        self._create_sub_modules()
    def _create_sub_modules(self):
        """Create sub discriminators."""
        # create theta discriminator for each joint
        self.pose_discriminator = PoseDiscriminator(self.per_joint_channel,
                                                    self.joint_count)

        # create full pose discriminator for total joints
        fc_layers = self.full_pose_channel
        use_dropout = tuple([False] * (len(fc_layers) - 1))
        drop_prob = tuple([0.5] * (len(fc_layers) - 1))
        use_activation = tuple([True] * (len(fc_layers) - 2) + [False])

        self.full_pose_discriminator = FullPoseDiscriminator(
            fc_layers, use_dropout, drop_prob, use_activation)

        # create shape discriminator for betas
        fc_layers = self.beta_channel
        use_dropout = tuple([False] * (len(fc_layers) - 1))
        drop_prob = tuple([0.5] * (len(fc_layers) - 1))
        use_activation = tuple([True] * (len(fc_layers) - 2) + [False])
        self.shape_discriminator = ShapeDiscriminator(fc_layers, use_dropout,
                                                      drop_prob,
                                                      use_activation)
    def forward(self, thetas):
        """Forward function."""
        _, poses, shapes = thetas

        batch_size = poses.shape[0]
        shape_disc_value = self.shape_discriminator(shapes)

        # The first rotation matrix is global rotation
        # and is NOT used in discriminator.
        if poses.dim() == 2:
            rotate_matrixs = \
                batch_rodrigues(poses.contiguous().view(-1, 3)
                                ).view(batch_size, 24, 9)[:, 1:, :]
        else:
            rotate_matrixs = poses.contiguous().view(batch_size, 24,
                                                     9)[:, 1:, :].contiguous()
        pose_disc_value, pose_inter_disc_value \
            = self.pose_discriminator(rotate_matrixs)
        full_pose_disc_value = self.full_pose_discriminator(
            pose_inter_disc_value.contiguous().view(batch_size, -1))
        return torch.cat(
            (pose_disc_value, full_pose_disc_value, shape_disc_value), 1)
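
    # A minimal usage sketch (illustrative tensors only): `thetas` is a
    # (cam, pose, shape) triple; the first entry is discarded by forward().
    #   disc = SMPLDiscriminator()
    #   cams = torch.randn(4, 3)           # ignored by the discriminator
    #   poses = torch.randn(4, 72)         # 24 joints x 3 axis-angle values
    #   betas = torch.randn(4, 10)
    #   out = disc((cams, poses, betas))   # -> shape (4, 25):
    #   # 23 per-joint scores + 1 full-pose score + 1 shape score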
    def init_weights(self):
        """Initialize model weights."""
        self.full_pose_discriminator.init_weights()
        self.pose_discriminator.init_weights()
        self.shape_discriminator.init_weights()