File size: 3,097 Bytes
3bbb319 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 |
# Copyright (c) OpenMMLab. All rights reserved.
import numpy as np
import torch
from mmpose.models import build_posenet
def test_interhand3d_forward():
    """Smoke-test the Interhand3D detector's forward entry points.

    Builds the detector from an inline config, then runs it on a synthetic
    batch in three ways: the loss-returning forward pass (which must yield
    a ``dict``), the ``return_loss=False`` forward pass, and
    ``forward_dummy``.
    """
    # Per-branch settings of the Interhand3D head, kept as separate locals
    # so the top-level model config below stays short.
    keypoint_branch = {
        'in_channels': 2048,
        'out_channels': 21 * 64,
        'depth_size': 64,
        'num_deconv_layers': 3,
        'num_deconv_filters': (256, 256, 256),
        'num_deconv_kernels': (4, 4, 4),
    }
    root_branch = {
        'in_channels': 2048,
        'heatmap_size': 64,
        'hidden_dims': (512, ),
    }
    hand_type_branch = {
        'in_channels': 2048,
        'num_labels': 2,
        'hidden_dims': (512, ),
    }
    head_settings = {
        'type': 'Interhand3DHead',
        'keypoint_head_cfg': keypoint_branch,
        'root_head_cfg': root_branch,
        'hand_type_head_cfg': hand_type_branch,
        'loss_keypoint': {
            'type': 'JointsMSELoss',
            'use_target_weight': True,
        },
        'loss_root_depth': {'type': 'L1Loss'},
        'loss_hand_type': {'type': 'BCELoss', 'use_target_weight': True},
    }
    model_cfg = {
        'type': 'Interhand3D',
        'pretrained': 'torchvision://resnet50',
        'backbone': {'type': 'ResNet', 'depth': 50},
        'keypoint_head': head_settings,
        'train_cfg': {},
        'test_cfg': {'flip_test': True, 'shift_heatmap': True},
    }

    detector = build_posenet(model_cfg)
    detector.init_weights()

    # Synthetic batch of two 256x256 RGB images plus matching targets.
    batch = _demo_mm_inputs((2, 3, 256, 256))
    imgs = batch.pop('imgs')
    target = batch.pop('target')
    target_weight = batch.pop('target_weight')
    img_metas = batch.pop('img_metas')

    # Training-mode forward: the detector is expected to return its losses
    # as a dict.
    losses = detector.forward(
        imgs, target, target_weight, img_metas, return_loss=True)
    assert isinstance(losses, dict)

    # Evaluation-mode forwards only need to run without raising, so their
    # outputs are discarded and no gradients are tracked.
    with torch.no_grad():
        detector.forward(imgs, img_metas=img_metas, return_loss=False)
        detector.forward_dummy(imgs)
def _demo_mm_inputs(input_shape=(1, 3, 256, 256),
                    num_outputs=None,
                    *,
                    num_joints=42,
                    depth_size=64):
    """Create a superset of inputs needed to run test or train batches.

    Args:
        input_shape (tuple): Input batch dimensions as ``(N, C, H, W)``.
        num_outputs: Not used by this helper; kept in the signature so
            existing callers keep working.
        num_joints (int): Size of the joint axis of the first target and
            target-weight tensors, and length of ``inference_channel``.
            Default: 42.
        depth_size (int): Size of the depth axis of the first target
            tensor. Default: 64.

    Returns:
        dict: A dict with the keys

        - ``imgs``: float tensor of shape ``input_shape`` filled with
          seeded random values, with ``requires_grad`` enabled.
        - ``target``: list of three all-zero tensors of shapes
          ``(N, num_joints, depth_size, H // 4, W // 4)``, ``(N, 1)`` and
          ``(N, 2)``.
        - ``target_weight``: list of three all-one tensors of shapes
          ``(N, num_joints, 1)``, ``(N, 1)`` and ``(N,)``.
        - ``img_metas``: list of ``N`` per-sample meta dicts.
    """
    (N, C, H, W) = input_shape

    # Fixed seed so every call yields the same image values.
    rng = np.random.RandomState(0)

    imgs = rng.rand(*input_shape)
    imgs = torch.FloatTensor(imgs)

    # One entry per head output; presumably the 3D keypoint heatmap, the
    # root depth and the hand type, in that order -- confirm against the
    # head under test.
    target = [
        imgs.new_zeros(N, num_joints, depth_size, H // 4, W // 4),
        imgs.new_zeros(N, 1),
        imgs.new_zeros(N, 2),
    ]
    target_weight = [
        imgs.new_ones(N, num_joints, 1),
        imgs.new_ones(N, 1),
        imgs.new_ones(N),
    ]

    img_metas = [{
        'img_shape': (H, W, C),
        'center': np.array([W / 2, H / 2]),
        'scale': np.array([0.5, 0.5]),
        'bbox_score': 1.0,
        'bbox_id': 0,
        'flip_pairs': [],
        # Must list exactly the channels present on the target joint axis.
        'inference_channel': np.arange(num_joints),
        'image_file': '<demo>.png',
        'heatmap3d_depth_bound': 400.0,
        'root_depth_bound': 400.0,
    } for _ in range(N)]

    mm_inputs = {
        'imgs': imgs.requires_grad_(True),
        'target': target,
        'target_weight': target_weight,
        'img_metas': img_metas
    }
    return mm_inputs
|