import pytest
import torch
from mmcv.utils.parrots_wrapper import _BatchNorm

from mmpose.models.backbones import VGG


def check_norm_state(modules, train_state):
    """Check if norm layer is in correct train state."""
    for mod in modules:
        if isinstance(mod, _BatchNorm):
            if mod.training != train_state:
                return False
    return True


def test_vgg():
    """Test VGG backbone."""
    with pytest.raises(KeyError):
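        # VGG depth must be one of (11, 13, 16, 19)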
        VGG(18)

    with pytest.raises(AssertionError):
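        # num_stages must be within range [1, 5]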
        VGG(11, num_stages=0)

    with pytest.raises(AssertionError):
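        # num_stages must be within range [1, 5]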
        VGG(11, num_stages=6)

    with pytest.raises(AssertionError):
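        # len(dilations) must equal num_stages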
        VGG(11, dilations=(1, 1), num_stages=3)

    with pytest.raises(TypeError):
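        # pretrained must be a str path, not an int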
        model = VGG(11)
        model.init_weights(pretrained=0)
|
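    # With norm_eval=True, norm layers stay in eval mode after .train()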
    model = VGG(11, norm_eval=True)
    model.init_weights()
    model.train()
    assert check_norm_state(model.modules(), False)
|
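    # Test VGG11 forward with all five stage outputs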
    model = VGG(11, out_indices=(0, 1, 2, 3, 4))
    model.init_weights()
    model.train()

    imgs = torch.randn(1, 3, 224, 224)
    feat = model(imgs)
    assert len(feat) == 5
    assert feat[0].shape == (1, 64, 112, 112)
    assert feat[1].shape == (1, 128, 56, 56)
    assert feat[2].shape == (1, 256, 28, 28)
    assert feat[3].shape == (1, 512, 14, 14)
    assert feat[4].shape == (1, 512, 7, 7)
|
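    # Test VGG11 forward with a classification head (num_classes=10)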
    model = VGG(11, num_classes=10, out_indices=(0, 1, 2, 3, 4, 5))
    model.init_weights()
    model.train()

    imgs = torch.randn(1, 3, 224, 224)
    feat = model(imgs)
    assert len(feat) == 6
    assert feat[0].shape == (1, 64, 112, 112)
    assert feat[1].shape == (1, 128, 56, 56)
    assert feat[2].shape == (1, 256, 28, 28)
    assert feat[3].shape == (1, 512, 14, 14)
    assert feat[4].shape == (1, 512, 7, 7)
    assert feat[5].shape == (1, 10)
|
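    # Test VGG11 with batch normalization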
    model = VGG(11, norm_cfg=dict(type='BN'), out_indices=(0, 1, 2, 3, 4))
    model.init_weights()
    model.train()

    imgs = torch.randn(1, 3, 224, 224)
    feat = model(imgs)
    assert len(feat) == 5
    assert feat[0].shape == (1, 64, 112, 112)
    assert feat[1].shape == (1, 128, 56, 56)
    assert feat[2].shape == (1, 256, 28, 28)
    assert feat[3].shape == (1, 512, 14, 14)
    assert feat[4].shape == (1, 512, 7, 7)
|
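    # Test VGG11-BN forward with a classification head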
    model = VGG(
        11,
        num_classes=10,
        norm_cfg=dict(type='BN'),
        out_indices=(0, 1, 2, 3, 4, 5))
    model.init_weights()
    model.train()

    imgs = torch.randn(1, 3, 224, 224)
    feat = model(imgs)
    assert len(feat) == 6
    assert feat[0].shape == (1, 64, 112, 112)
    assert feat[1].shape == (1, 128, 56, 56)
    assert feat[2].shape == (1, 256, 28, 28)
    assert feat[3].shape == (1, 512, 14, 14)
    assert feat[4].shape == (1, 512, 7, 7)
    assert feat[5].shape == (1, 10)
|
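    # Test VGG13 with the first three stages as outputs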
    model = VGG(13, out_indices=(0, 1, 2))
    model.init_weights()
    model.train()

    imgs = torch.randn(1, 3, 224, 224)
    feat = model(imgs)
    assert len(feat) == 3
    assert feat[0].shape == (1, 64, 112, 112)
    assert feat[1].shape == (1, 128, 56, 56)
    assert feat[2].shape == (1, 256, 28, 28)
|
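    # Test VGG16 with the default single (last-stage) output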
    model = VGG(16)
    model.init_weights()
    model.train()

    imgs = torch.randn(1, 3, 224, 224)
    feat = model(imgs)
    assert feat.shape == (1, 512, 7, 7)
|
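    # Test VGG19 with a classification head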
    model = VGG(19, num_classes=10)
    model.init_weights()
    model.train()

    imgs = torch.randn(1, 3, 224, 224)
    feat = model(imgs)
    assert feat.shape == (1, 10)