|
|
|
"""Collects all models.""" |
|
|
|
from .pggan_generator import PGGANGenerator |
|
from .pggan_discriminator import PGGANDiscriminator |
|
from .stylegan_generator import StyleGANGenerator |
|
from .stylegan_discriminator import StyleGANDiscriminator |
|
from .stylegan2_generator import StyleGAN2Generator |
|
from .stylegan2_discriminator import StyleGAN2Discriminator |
|
from .stylegan3_generator import StyleGAN3Generator |
|
from .ghfeat_encoder import GHFeatEncoder |
|
from .perceptual_model import PerceptualModel |
|
from .inception_model import InceptionModel |
|
from .eg3d_generator import EG3DGenerator |
|
from .eg3d_discriminator import DualDiscriminator |
|
from .pigan_generator import PiGANGenerator |
|
from .pigan_discriminator import PiGANDiscriminator |
|
from .volumegan_generator import VolumeGANGenerator |
|
from .volumegan_discriminator import VolumeGANDiscriminator |
|
from .eg3d_generator_fv import EG3DGeneratorFV |
|
from .bev3d_generator import BEV3DGenerator |
|
from .sgbev3d_generator import SGBEV3DGenerator |
|
|
|
__all__ = ['build_model'] |
|
|
|
_MODELS = { |
|
'PGGANGenerator': PGGANGenerator, |
|
'PGGANDiscriminator': PGGANDiscriminator, |
|
'StyleGANGenerator': StyleGANGenerator, |
|
'StyleGANDiscriminator': StyleGANDiscriminator, |
|
'StyleGAN2Generator': StyleGAN2Generator, |
|
'StyleGAN2Discriminator': StyleGAN2Discriminator, |
|
'StyleGAN3Generator': StyleGAN3Generator, |
|
'GHFeatEncoder': GHFeatEncoder, |
|
'PerceptualModel': PerceptualModel.build_model, |
|
'InceptionModel': InceptionModel.build_model, |
|
'EG3DGenerator': EG3DGenerator, |
|
'EG3DDiscriminator': DualDiscriminator, |
|
'PiGANGenerator': PiGANGenerator, |
|
'PiGANDiscriminator': PiGANDiscriminator, |
|
'VolumeGANGenerator': VolumeGANGenerator, |
|
'VolumeGANDiscriminator': VolumeGANDiscriminator, |
|
'EG3DGeneratorFV': EG3DGeneratorFV, |
|
'BEV3DGenerator': BEV3DGenerator, |
|
'SGBEV3DGenerator': SGBEV3DGenerator, |
|
} |
|
|
|
|
|
def build_model(model_type, **kwargs): |
|
"""Builds a model based on its class type. |
|
|
|
Args: |
|
model_type: Class type to which the model belongs, which is case |
|
sensitive. |
|
**kwargs: Additional arguments to build the model. |
|
|
|
Raises: |
|
ValueError: If the `model_type` is not supported. |
|
""" |
|
if model_type not in _MODELS: |
|
raise ValueError(f'Invalid model type: `{model_type}`!\n' |
|
f'Types allowed: {list(_MODELS)}.') |
|
return _MODELS[model_type](**kwargs) |
|
|