File size: 1,580 Bytes
8ca3a29 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 |
# python3.7
"""Collects all models."""
from .pggan_generator import PGGANGenerator
from .pggan_discriminator import PGGANDiscriminator
from .stylegan_generator import StyleGANGenerator
from .stylegan_discriminator import StyleGANDiscriminator
from .stylegan2_generator import StyleGAN2Generator
from .stylegan2_discriminator import StyleGAN2Discriminator
from .stylegan3_generator import StyleGAN3Generator
from .ghfeat_encoder import GHFeatEncoder
from .perceptual_model import PerceptualModel
from .inception_model import InceptionModel
__all__ = ['build_model']

# Registry mapping a case-sensitive type name to the callable that builds it.
# Insertion order matters: `build_model` lists these keys, in this order, in
# its error message for unknown types.
_MODELS = {
    # PGGAN family.
    'PGGANGenerator': PGGANGenerator,
    'PGGANDiscriminator': PGGANDiscriminator,
    # StyleGAN family.
    'StyleGANGenerator': StyleGANGenerator,
    'StyleGANDiscriminator': StyleGANDiscriminator,
    'StyleGAN2Generator': StyleGAN2Generator,
    'StyleGAN2Discriminator': StyleGAN2Discriminator,
    'StyleGAN3Generator': StyleGAN3Generator,
    # Encoder.
    'GHFeatEncoder': GHFeatEncoder,
    # These two are registered through their `build_model` callables rather
    # than the class itself.
    'PerceptualModel': PerceptualModel.build_model,
    'InceptionModel': InceptionModel.build_model,
}
def build_model(model_type, **kwargs):
    """Instantiates a registered model from its type name.

    Looks `model_type` up in the module-level `_MODELS` registry and calls the
    registered constructor (or factory) with the given keyword arguments.

    Args:
        model_type: Registry key naming the model, e.g. `'StyleGAN2Generator'`.
            Matching is case sensitive.
        **kwargs: Keyword arguments forwarded unchanged to the registered
            callable.

    Returns:
        Whatever the registered callable returns.

    Raises:
        ValueError: If `model_type` is not a key of `_MODELS`.
    """
    # No registry entry maps to None, so a None result means "unknown type".
    builder = _MODELS.get(model_type)
    if builder is None:
        raise ValueError(f'Invalid model type: `{model_type}`!\n'
                         f'Types allowed: {list(_MODELS)}.')
    return builder(**kwargs)
|