File size: 2,404 Bytes
2f85de4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
# python3.7
"""Collects all models."""

from .pggan_generator import PGGANGenerator
from .pggan_discriminator import PGGANDiscriminator
from .stylegan_generator import StyleGANGenerator
from .stylegan_discriminator import StyleGANDiscriminator
from .stylegan2_generator import StyleGAN2Generator
from .stylegan2_discriminator import StyleGAN2Discriminator
from .stylegan3_generator import StyleGAN3Generator
from .ghfeat_encoder import GHFeatEncoder
from .perceptual_model import PerceptualModel
from .inception_model import InceptionModel
from .eg3d_generator import EG3DGenerator
from .eg3d_discriminator import DualDiscriminator
from .pigan_generator import PiGANGenerator
from .pigan_discriminator import PiGANDiscriminator
from .volumegan_generator import VolumeGANGenerator
from .volumegan_discriminator import VolumeGANDiscriminator
from .eg3d_generator_fv import EG3DGeneratorFV
from .bev3d_generator import BEV3DGenerator
from .sgbev3d_generator import SGBEV3DGenerator

__all__ = ['build_model']

_MODELS = {
    'PGGANGenerator': PGGANGenerator,
    'PGGANDiscriminator': PGGANDiscriminator,
    'StyleGANGenerator': StyleGANGenerator,
    'StyleGANDiscriminator': StyleGANDiscriminator,
    'StyleGAN2Generator': StyleGAN2Generator,
    'StyleGAN2Discriminator': StyleGAN2Discriminator,
    'StyleGAN3Generator': StyleGAN3Generator,
    'GHFeatEncoder': GHFeatEncoder,
    'PerceptualModel': PerceptualModel.build_model,
    'InceptionModel': InceptionModel.build_model,
    'EG3DGenerator': EG3DGenerator,
    'EG3DDiscriminator': DualDiscriminator,
    'PiGANGenerator': PiGANGenerator,
    'PiGANDiscriminator': PiGANDiscriminator,
    'VolumeGANGenerator': VolumeGANGenerator,
    'VolumeGANDiscriminator': VolumeGANDiscriminator,
    'EG3DGeneratorFV': EG3DGeneratorFV,
    'BEV3DGenerator': BEV3DGenerator,
    'SGBEV3DGenerator': SGBEV3DGenerator,
}


def build_model(model_type, **kwargs):
    """Builds a model based on its class type.

    Args:
        model_type: Class type to which the model belongs, which is case
            sensitive.
        **kwargs: Additional arguments to build the model.

    Raises:
        ValueError: If the `model_type` is not supported.
    """
    if model_type not in _MODELS:
        raise ValueError(f'Invalid model type: `{model_type}`!\n'
                         f'Types allowed: {list(_MODELS)}.')
    return _MODELS[model_type](**kwargs)