show / mmpose-0.29.0 /tests /test_models /test_mesh_head.py
camenduru's picture
thanks to show ❤
3bbb319
# Copyright (c) OpenMMLab. All rights reserved.
import numpy as np
import pytest
import torch
from mmpose.models import HMRMeshHead
from mmpose.models.misc.discriminator import SMPLDiscriminator
def test_mesh_hmr_head():
    """Exercise HMRMeshHead outputs and SMPLDiscriminator scoring.

    Covers the head built with only ``in_channels`` and with explicit mean
    parameters plus ``n_iter``, the discriminator fed with the two pose
    inputs constructed below, and the constructor errors expected for
    invalid ``beta_channel`` arguments.
    """
    feat_shape = (1, 512, 8, 8)
    # Channel layouts reused by every discriminator built in this test.
    joint_channels = (9, 32, 32, 16, 1)
    pose_channels = (23 * 16, 256, 1)

    # Head constructed with only ``in_channels`` specified.
    mesh_head = HMRMeshHead(in_channels=512)
    mesh_head.init_weights()
    smpl_rotmat, smpl_shape, camera = mesh_head(_demo_inputs(feat_shape))
    assert smpl_rotmat.shape == torch.Size([1, 24, 3, 3])
    assert smpl_shape.shape == torch.Size([1, 10])
    assert camera.shape == torch.Size([1, 3])

    """Test hmr mesh head with assigned mean parameters and n_iter """
    mesh_head = HMRMeshHead(
        in_channels=512,
        smpl_mean_params='tests/data/smpl/smpl_mean_params.npz',
        n_iter=3)
    mesh_head.init_weights()
    smpl_rotmat, smpl_shape, camera = mesh_head(_demo_inputs(feat_shape))
    assert smpl_rotmat.shape == torch.Size([1, 24, 3, 3])
    assert smpl_shape.shape == torch.Size([1, 10])
    assert camera.shape == torch.Size([1, 3])

    # Discriminator fed with the rotation-matrix pose predicted by the
    # second head above.
    discriminator = SMPLDiscriminator(
        beta_channel=(10, 10, 5, 1),
        per_joint_channel=joint_channels,
        full_pose_channel=pose_channels)
    scores = discriminator((camera, smpl_rotmat, smpl_shape))
    assert scores.shape[1] == 25

    # Discriminator fed with an all-zero (1, 72) axis-angle pose tensor.
    axis_angle_pose = camera.new_zeros([1, 72])
    scores = discriminator((camera, axis_angle_pose, smpl_shape))
    assert scores.shape[1] == 25

    # A list (rather than a tuple) for ``beta_channel`` must be rejected.
    with pytest.raises(TypeError):
        SMPLDiscriminator(
            beta_channel=[10, 10, 5, 1],
            per_joint_channel=joint_channels,
            full_pose_channel=pose_channels)

    # A single-element ``beta_channel`` must be rejected.
    with pytest.raises(ValueError):
        SMPLDiscriminator(
            beta_channel=(10, ),
            per_joint_channel=joint_channels,
            full_pose_channel=pose_channels)
def _demo_inputs(input_shape=(1, 3, 64, 64)):
    """Build a random dummy input tensor for the mesh head tests.

    Args:
        input_shape (tuple): Shape of the tensor to generate.
            Default: (1, 3, 64, 64).

    Returns:
        torch.FloatTensor: Tensor of shape ``input_shape`` filled with
        samples drawn by ``np.random.random``.
    """
    # Sample with NumPy first, then convert to a float tensor.
    return torch.FloatTensor(np.random.random(input_shape))