giantmonkeyTC
mm2
c2ca15f
raw
history blame contribute delete
751 Bytes
import torch
from mmdet3d.registry import MODELS
def test_dla_net():
    """Build the DLA-34 backbone used in SMOKE and check its output shapes.

    The backbone is built through the ``MODELS`` registry from a dict config
    with ``'GN'`` (GroupNorm) normalization using 32 groups. It is run on a
    random batch of four 3x32x32 images. It must return six feature maps.
    Going from one level to the next, channels double (16 -> 512) and the
    spatial size halves (32x32 -> 1x1).
    """
    # Test the DLANet backbone used in SMOKE, built from a dict config.
    cfg = dict(
        type='DLANet',
        depth=34,
        in_channels=3,
        norm_cfg=dict(type='GN', num_groups=32))

    img = torch.randn((4, 3, 32, 32))

    # Named `model` rather than `self`: this is a plain function, not a
    # method, so `self` would misleadingly suggest a method receiver.
    model = MODELS.build(cfg)
    model.init_weights()

    # Only output shapes are checked, so there is no need to record an
    # autograd graph for the forward pass.
    with torch.no_grad():
        results = model(img)

    expected_shapes = [
        torch.Size([4, 16, 32, 32]),
        torch.Size([4, 32, 16, 16]),
        torch.Size([4, 64, 8, 8]),
        torch.Size([4, 128, 4, 4]),
        torch.Size([4, 256, 2, 2]),
        torch.Size([4, 512, 1, 1]),
    ]
    assert len(results) == len(expected_shapes)
    for level, (feat, expected) in enumerate(zip(results, expected_shapes)):
        assert feat.shape == expected, (
            f'level {level}: expected {tuple(expected)}, '
            f'got {tuple(feat.shape)}')