File size: 3,891 Bytes
34d1f8b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
import pytest
from mmengine import ConfigDict

from mmdet3d.utils.compat_cfg import (compat_imgs_per_gpu, compat_loader_args,
                                      compat_runner_args)


def test_compat_runner_args():
    """Check that a legacy ``total_epochs`` config gains a ``runner`` section.

    ``compat_runner_args`` is expected to emit exactly one warning whose
    message mentions ``runner``. It should add an ``EpochBasedRunner`` whose
    ``max_epochs`` equals the config's ``total_epochs``.
    """
    cfg = ConfigDict(dict(total_epochs=12))
    # ``pytest.warns(None)`` is deprecated since pytest 7 and rejected by
    # pytest 8. ``Warning`` is the base of every warning category, so all
    # emitted warnings are still recorded for the count check below.
    with pytest.warns(Warning) as record:
        cfg = compat_runner_args(cfg)
    assert len(record) == 1
    assert 'runner' in record.list[0].message.args[0]
    assert 'runner' in cfg
    assert cfg.runner.type == 'EpochBasedRunner'
    assert cfg.runner.max_epochs == cfg.total_epochs


def test_compat_loader_args():
    """Check migration of legacy ``data`` loader keys into ``*_dataloader``.

    Covers four things:
    - auto-creation of the train/val/test dataloader sections;
    - propagation of ``samples_per_gpu`` / ``workers_per_gpu`` /
      ``persistent_workers``;
    - a list-valued ``test`` field;
    - the ``AssertionError`` raised when ``samples_per_gpu`` is set both in
      a legacy location and in the matching ``*_dataloader`` section.
    """
    cfg = ConfigDict(dict(data=dict(val=dict(), test=dict(), train=dict())))
    cfg = compat_loader_args(cfg)
    # auto fill loader args
    assert 'val_dataloader' in cfg.data
    assert 'train_dataloader' in cfg.data
    assert 'test_dataloader' in cfg.data
    cfg = ConfigDict(
        dict(
            data=dict(
                samples_per_gpu=1,
                persistent_workers=True,
                workers_per_gpu=1,
                val=dict(samples_per_gpu=3),
                test=dict(samples_per_gpu=2),
                train=dict())))
    cfg = compat_loader_args(cfg)

    assert cfg.data.train_dataloader.workers_per_gpu == 1
    assert cfg.data.train_dataloader.samples_per_gpu == 1
    assert cfg.data.train_dataloader.persistent_workers
    assert cfg.data.val_dataloader.workers_per_gpu == 1
    assert cfg.data.val_dataloader.samples_per_gpu == 3
    assert cfg.data.test_dataloader.workers_per_gpu == 1
    assert cfg.data.test_dataloader.samples_per_gpu == 2

    # test test is a list
    cfg = ConfigDict(
        dict(
            data=dict(
                samples_per_gpu=1,
                persistent_workers=True,
                workers_per_gpu=1,
                val=dict(samples_per_gpu=3),
                test=[dict(samples_per_gpu=2),
                      dict(samples_per_gpu=3)],
                train=dict())))

    cfg = compat_loader_args(cfg)
    # Previously this scenario asserted nothing. With several test datasets
    # the merged test loader is expected to use the largest per-dataset
    # ``samples_per_gpu`` (max(2, 3) == 3). The shared worker count still
    # applies to every loader.
    assert cfg.data.test_dataloader.samples_per_gpu == 3
    assert cfg.data.test_dataloader.workers_per_gpu == 1
    assert cfg.data.val_dataloader.samples_per_gpu == 3
    assert cfg.data.train_dataloader.samples_per_gpu == 1

    # assert can not set args at the same time
    cfg = ConfigDict(
        dict(
            data=dict(
                samples_per_gpu=1,
                persistent_workers=True,
                workers_per_gpu=1,
                val=dict(samples_per_gpu=3),
                test=dict(samples_per_gpu=2),
                train=dict(),
                train_dataloader=dict(samples_per_gpu=2))))
    # samples_per_gpu can not be set in `train_dataloader`
    # and data field at the same time
    with pytest.raises(AssertionError):
        compat_loader_args(cfg)
    cfg = ConfigDict(
        dict(
            data=dict(
                samples_per_gpu=1,
                persistent_workers=True,
                workers_per_gpu=1,
                val=dict(samples_per_gpu=3),
                test=dict(samples_per_gpu=2),
                train=dict(),
                val_dataloader=dict(samples_per_gpu=2))))
    # samples_per_gpu can not be set in `val_dataloader`
    # and data field at the same time
    with pytest.raises(AssertionError):
        compat_loader_args(cfg)
    cfg = ConfigDict(
        dict(
            data=dict(
                samples_per_gpu=1,
                persistent_workers=True,
                workers_per_gpu=1,
                val=dict(samples_per_gpu=3),
                test=dict(samples_per_gpu=2),
                test_dataloader=dict(samples_per_gpu=2))))
    # samples_per_gpu can not be set in `test_dataloader`
    # and data field at the same time
    with pytest.raises(AssertionError):
        compat_loader_args(cfg)


def test_compat_imgs_per_gpu():
    """Check that the legacy ``imgs_per_gpu`` value wins over ``samples_per_gpu``.

    When both keys are present, ``compat_imgs_per_gpu`` should copy
    ``imgs_per_gpu`` into ``samples_per_gpu``.
    """
    cfg = ConfigDict(
        dict(
            data=dict(
                imgs_per_gpu=1,
                samples_per_gpu=2,
                val=dict(),
                test=dict(),
                train=dict())))
    cfg = compat_imgs_per_gpu(cfg)
    # Pin the concrete value. Comparing the two keys with each other alone
    # would also pass if the copy went the wrong direction (both becoming 2).
    assert cfg.data.imgs_per_gpu == 1
    assert cfg.data.samples_per_gpu == 1