import logging

from saicinpainting.training.modules.ffc import FFCResNetGenerator
from saicinpainting.training.modules.pix2pixhd import (
    GlobalGenerator,
    MultiDilatedGlobalGenerator,
    NLayerDiscriminator,
    MultidilatedNLayerDiscriminator,
)


def make_generator(config, kind, **kwargs):
    """Build the generator selected by ``kind``.

    Args:
        config: Accepted for interface compatibility; this function does not
            read it.
        kind: One of ``'pix2pixhd_multidilated'``, ``'pix2pixhd_global'`` or
            ``'ffc_resnet'``.
        **kwargs: Forwarded unchanged to the selected generator's constructor.

    Returns:
        A new instance of the generator class matching ``kind``.

    Raises:
        ValueError: If ``kind`` is not one of the supported names.
    """
    logging.info(f'Make generator {kind}')

    if kind == 'pix2pixhd_multidilated':
        return MultiDilatedGlobalGenerator(**kwargs)

    if kind == 'pix2pixhd_global':
        return GlobalGenerator(**kwargs)

    if kind == 'ffc_resnet':
        return FFCResNetGenerator(**kwargs)

    raise ValueError(f'Unknown generator kind {kind}')


def make_discriminator(kind, **kwargs):
    """Build the discriminator selected by ``kind``.

    Args:
        kind: One of ``'pix2pixhd_nlayer_multidilated'`` or
            ``'pix2pixhd_nlayer'``.
        **kwargs: Forwarded unchanged to the selected discriminator's
            constructor.

    Returns:
        A new instance of the discriminator class matching ``kind``.

    Raises:
        ValueError: If ``kind`` is not one of the supported names.
    """
    logging.info(f'Make discriminator {kind}')

    if kind == 'pix2pixhd_nlayer_multidilated':
        return MultidilatedNLayerDiscriminator(**kwargs)

    if kind == 'pix2pixhd_nlayer':
        return NLayerDiscriminator(**kwargs)

    raise ValueError(f'Unknown discriminator kind {kind}')