Spaces:
Build error
Build error
# Copyright (c) OpenMMLab. All rights reserved. | |
import pytest | |
import torch | |
from mmpose.models.backbones import RegNet | |
# Parametrization table for the RegNet tests. Each entry is
# (preset arch name, equivalent explicit arch-setting dict,
#  expected output channels of stages 0..3).
regnet_test_data = [
    ('regnetx_400mf',
     dict(w0=24, wa=24.48, wm=2.54, group_w=16, depth=22,
          bot_mul=1.0), [32, 64, 160, 384]),
    ('regnetx_800mf',
     dict(w0=56, wa=35.73, wm=2.28, group_w=16, depth=16,
          bot_mul=1.0), [64, 128, 288, 672]),
    ('regnetx_1.6gf',
     dict(w0=80, wa=34.01, wm=2.25, group_w=24, depth=18,
          bot_mul=1.0), [72, 168, 408, 912]),
    ('regnetx_3.2gf',
     dict(w0=88, wa=26.31, wm=2.25, group_w=48, depth=25,
          bot_mul=1.0), [96, 192, 432, 1008]),
    ('regnetx_4.0gf',
     dict(w0=96, wa=38.65, wm=2.43, group_w=40, depth=23,
          bot_mul=1.0), [80, 240, 560, 1360]),
    ('regnetx_6.4gf',
     dict(w0=184, wa=60.83, wm=2.07, group_w=56, depth=17,
          bot_mul=1.0), [168, 392, 784, 1624]),
    ('regnetx_8.0gf',
     dict(w0=80, wa=49.56, wm=2.88, group_w=120, depth=23,
          bot_mul=1.0), [80, 240, 720, 1920]),
    ('regnetx_12gf',
     dict(w0=168, wa=73.36, wm=2.37, group_w=112, depth=19,
          bot_mul=1.0), [224, 448, 896, 2240]),
]
@pytest.mark.parametrize('arch_name,arch,out_channels', regnet_test_data)
def test_regnet_backbone(arch_name, arch, out_channels):
    """Build RegNet from a preset name and check its output feature shapes.

    ``arch_name``, ``arch`` and ``out_channels`` come from the
    ``regnet_test_data`` parametrization; ``arch`` is supplied by the table
    but not used in this test.
    """
    # An unknown architecture name must be rejected.
    with pytest.raises(AssertionError):
        RegNet(arch_name + '233')

    # Default construction: a single tensor from the last stage is expected.
    model = RegNet(arch_name)
    model.init_weights()
    model.train()
    imgs = torch.randn(1, 3, 224, 224)
    feat = model(imgs)
    assert isinstance(feat, torch.Tensor)
    assert feat.shape == (1, out_channels[-1], 7, 7)

    # Requesting all four stages yields one feature map per stage; for a
    # 224x224 input these are expected at 56, 28, 14 and 7 pixels per side.
    model = RegNet(arch_name, out_indices=(0, 1, 2, 3))
    model.init_weights()
    model.train()
    imgs = torch.randn(1, 3, 224, 224)
    feat = model(imgs)
    assert len(feat) == 4
    for stage, side in enumerate((56, 28, 14, 7)):
        assert feat[stage].shape == (1, out_channels[stage], side, side)
@pytest.mark.parametrize('arch_name,arch,out_channels', regnet_test_data)
def test_custom_arch(arch_name, arch, out_channels):
    """Build RegNet from an explicit arch-setting dict and check its outputs.

    ``arch_name``, ``arch`` and ``out_channels`` come from the
    ``regnet_test_data`` parametrization; ``arch_name`` is supplied by the
    table but not used in this test.
    """
    # Default construction: a single tensor from the last stage is expected.
    model = RegNet(arch)
    model.init_weights()
    imgs = torch.randn(1, 3, 224, 224)
    feat = model(imgs)
    assert isinstance(feat, torch.Tensor)
    assert feat.shape == (1, out_channels[-1], 7, 7)

    # Requesting all four stages yields one feature map per stage; for a
    # 224x224 input these are expected at 56, 28, 14 and 7 pixels per side.
    model = RegNet(arch, out_indices=(0, 1, 2, 3))
    model.init_weights()
    imgs = torch.randn(1, 3, 224, 224)
    feat = model(imgs)
    assert len(feat) == 4
    for stage, side in enumerate((56, 28, 14, 7)):
        assert feat[stage].shape == (1, out_channels[stage], side, side)
def test_exception():
    """RegNet must reject an ``arch`` that is neither a str nor a dict."""
    # An int is neither a preset name (str) nor a setting dict.
    with pytest.raises(TypeError):
        RegNet(50)