"""Registry of named DDGAN experiment configurations.

Each configuration is a zero-argument factory decorated with ``@register``.
A factory returns a fresh dict with two sections, ``"slurm"`` and ``"model"``.
Variants are built by calling a parent factory and overriding individual keys,
so every call yields an independent dict that is safe to mutate.
"""
import os
from glob import glob
from subprocess import call
import json

# Maps a configuration name (the factory's ``__name__``) to its factory.
models = {}


def register(func):
    """Record ``func`` in ``models`` under its own name and return it unchanged.

    Returning the function itself keeps it callable by name, which is how
    variant configurations derive from their parents.
    """
    models[func.__name__] = func
    return func


def get_model_config(model_name):
    """Return the ``"model"`` section of the configuration named ``model_name``.

    Args:
        model_name: Name of a registered configuration factory.

    Returns:
        A freshly built dict of model settings (the factory is re-run on every
        call, so the result can be mutated without affecting later calls).

    Raises:
        KeyError: If no configuration is registered under ``model_name``.
    """
    factory = models.get(model_name)
    if factory is None:
        # Still a KeyError, as with a plain dict lookup, but with a message
        # that lists the valid names.
        raise KeyError(
            "Unknown model config {!r}; available: {}".format(
                model_name, ", ".join(sorted(models))
            )
        )
    return factory()["model"]


def base():
    """Return the default configuration that all registered variants start from.

    The ``"slurm"`` section holds job settings under the keys ``t``, ``N`` and
    ``n`` (presumably time limit, node count and task count -- confirm against
    the launcher). The ``"model"`` section holds the model settings.
    """
    return {
        "slurm": {
            "t": 360,
            "N": 2,
            "n": 8,
        },
        "model": {
            "dataset": "wds",
            "seed": 0,
            "cross_attention": False,
            "num_channels": 3,
            "centered": True,
            "use_geometric": False,
            "beta_min": 0.1,
            "beta_max": 20.0,
            "num_channels_dae": 128,
            "n_mlp": 3,
            "ch_mult": [1, 1, 2, 2, 4, 4],
            "num_res_blocks": 2,
            "attn_resolutions": (16,),
            "dropout": 0.0,
            "resamp_with_conv": True,
            "conditional": True,
            "fir": True,
            "fir_kernel": [1, 3, 3, 1],
            "skip_rescale": True,
            "resblock_type": "biggan",
            "progressive": "none",
            "progressive_input": "residual",
            "progressive_combine": "sum",
            "embedding_type": "positional",
            "fourier_scale": 16.0,
            "not_use_tanh": False,
            "image_size": 256,
            "nz": 100,
            "num_timesteps": 4,
            "z_emb_dim": 256,
            "t_emb_dim": 256,
            "text_encoder": "google/t5-v1_1-base",
            "masked_mean": True,
            "cross_attention_block": "basic",
        },
    }


@register
def ddgan_cc12m_v2():
    """Base configuration with the slurm ``N``/``n`` values set explicitly."""
    cfg = base()
    cfg['slurm']['N'] = 2
    cfg['slurm']['n'] = 8
    return cfg


@register
def ddgan_cc12m_v6():
    """Base configuration with the ``google/t5-v1_1-large`` text encoder."""
    cfg = base()
    cfg['model']['text_encoder'] = "google/t5-v1_1-large"
    return cfg


@register
def ddgan_cc12m_v7():
    """Base configuration with ``classifier_free_guidance_proba`` of 0.2."""
    cfg = base()
    cfg['model']['classifier_free_guidance_proba'] = 0.2
    cfg['slurm']['N'] = 2
    cfg['slurm']['n'] = 8
    return cfg


@register
def ddgan_cc12m_v8():
    """Base with the T5 large encoder and ``classifier_free_guidance_proba`` 0.2."""
    cfg = base()
    cfg['model']['text_encoder'] = "google/t5-v1_1-large"
    cfg['model']['classifier_free_guidance_proba'] = 0.2
    return cfg


@register
def ddgan_cc12m_v9():
    """Like v8, with 320 DAE channels, image size 64 and batch size 1."""
    cfg = base()
    cfg['model']['text_encoder'] = "google/t5-v1_1-large"
    cfg['model']['classifier_free_guidance_proba'] = 0.2
    cfg['model']['num_channels_dae'] = 320
    cfg['model']['image_size'] = 64
    cfg['model']['batch_size'] = 1
    return cfg


@register
def ddgan_cc12m_v11():
    """Like v8, with ``cross_attention`` enabled."""
    cfg = base()
    cfg['model']['text_encoder'] = "google/t5-v1_1-large"
    cfg['model']['classifier_free_guidance_proba'] = 0.2
    cfg['model']['cross_attention'] = True
    return cfg


@register
def ddgan_cc12m_v12():
    """cc12m v11 with the T5 xl encoder and ``random_resized_crop_v1`` preprocessing."""
    cfg = ddgan_cc12m_v11()
    cfg['model']['text_encoder'] = "google/t5-v1_1-xl"
    cfg['model']['preprocessing'] = 'random_resized_crop_v1'
    return cfg


@register
def ddgan_cc12m_v13():
    """cc12m v12 with ``discr_type`` set to ``large_cond_attn``."""
    cfg = ddgan_cc12m_v12()
    cfg['model']['discr_type'] = "large_cond_attn"
    return cfg


@register
def ddgan_cc12m_v14():
    """cc12m v12 with 192 DAE channels."""
    cfg = ddgan_cc12m_v12()
    cfg['model']['num_channels_dae'] = 192
    return cfg


@register
def ddgan_cc12m_v15():
    """cc12m v11 with ``mismatch_loss`` and ``grad_penalty_cond`` enabled."""
    cfg = ddgan_cc12m_v11()
    cfg['model']['mismatch_loss'] = True
    cfg['model']['grad_penalty_cond'] = True
    return cfg


@register
def ddgan_cifar10_cond17():
    """Base configuration adapted to the ``cifar10`` dataset at image size 32."""
    cfg = base()
    cfg['model']['image_size'] = 32
    cfg['model']['classifier_free_guidance_proba'] = 0.2
    # A list of ints, matching the type of base()'s ``ch_mult``. This was
    # previously the string "1 2 2 2", whose length and items differ from the
    # four multipliers it spells out.
    cfg['model']['ch_mult'] = [1, 2, 2, 2]
    cfg['model']['cross_attention'] = True
    cfg['model']['dataset'] = "cifar10"
    cfg['model']['n_mlp'] = 4
    return cfg


@register
def ddgan_cifar10_cond18():
    """cifar10 cond17 with the ``google/t5-v1_1-xl`` text encoder."""
    cfg = ddgan_cifar10_cond17()
    cfg['model']['text_encoder'] = "google/t5-v1_1-xl"
    return cfg


@register
def ddgan_cifar10_cond19():
    """cifar10 cond17 with ``small_cond_attn`` discriminator and extra losses."""
    cfg = ddgan_cifar10_cond17()
    cfg['model']['discr_type'] = 'small_cond_attn'
    cfg['model']['mismatch_loss'] = True
    cfg['model']['grad_penalty_cond'] = True
    return cfg


@register
def ddgan_laion_aesthetic_v1():
    """cc12m v11 pointed at the LAION-aesthetic shards via ``dataset_root``."""
    cfg = ddgan_cc12m_v11()
    # The value deliberately contains literal double quotes inside the string.
    cfg['model']['dataset_root'] = '"/p/scratch/ccstdl/cherti1/LAION-aesthetic/output/{00000..05038}.tar"'
    return cfg


@register
def ddgan_laion_aesthetic_v2():
    """LAION-aesthetic v1 with ``discr_type`` set to ``large_cond_attn``."""
    cfg = ddgan_laion_aesthetic_v1()
    cfg['model']['discr_type'] = "large_cond_attn"
    return cfg


@register
def ddgan_laion_aesthetic_v3():
    """LAION-aesthetic v1 with the T5 xl encoder, mismatch loss and cond. grad penalty."""
    cfg = ddgan_laion_aesthetic_v1()
    cfg['model']['text_encoder'] = "google/t5-v1_1-xl"
    cfg['model']['mismatch_loss'] = True
    cfg['model']['grad_penalty_cond'] = True
    return cfg


@register
def ddgan_laion_aesthetic_v4():
    """LAION-aesthetic v1 with the ``openclip/ViT-L-14-336/openai`` text encoder."""
    cfg = ddgan_laion_aesthetic_v1()
    cfg['model']['text_encoder'] = "openclip/ViT-L-14-336/openai"
    return cfg


@register
def ddgan_laion_aesthetic_v5():
    """LAION-aesthetic v1 with ``mismatch_loss`` and ``grad_penalty_cond`` enabled."""
    cfg = ddgan_laion_aesthetic_v1()
    cfg['model']['mismatch_loss'] = True
    cfg['model']['grad_penalty_cond'] = True
    return cfg


@register
def ddgan_laion2b_v1():
    """LAION-aesthetic v3 with 224 DAE channels, batch size 2 and a large discriminator."""
    cfg = ddgan_laion_aesthetic_v3()
    cfg['model']['mismatch_loss'] = True
    cfg['model']['grad_penalty_cond'] = True
    cfg['model']['num_channels_dae'] = 224
    cfg['model']['batch_size'] = 2
    cfg['model']['discr_type'] = "large_cond_attn"
    cfg['model']['preprocessing'] = 'random_resized_crop_v1'
    return cfg


@register
def ddgan_laion_aesthetic_v6():
    """LAION-aesthetic v3 with the ``no_lr_decay`` key present (empty-string value)."""
    cfg = ddgan_laion_aesthetic_v3()
    cfg['model']['no_lr_decay'] = ''
    return cfg


@register
def ddgan_laion_aesthetic_v7():
    """LAION-aesthetic v6 with ``r1_gamma`` set to 5."""
    cfg = ddgan_laion_aesthetic_v6()
    cfg['model']['r1_gamma'] = 5
    return cfg


@register
def ddgan_laion_aesthetic_v8():
    """LAION-aesthetic v6 with 8 timesteps."""
    cfg = ddgan_laion_aesthetic_v6()
    cfg['model']['num_timesteps'] = 8
    return cfg


@register
def ddgan_laion_aesthetic_v9():
    """LAION-aesthetic v3 with 384 DAE channels."""
    cfg = ddgan_laion_aesthetic_v3()
    cfg['model']['num_channels_dae'] = 384
    return cfg


@register
def ddgan_sd_v1():
    """Same settings as LAION-aesthetic v3, registered under its own name."""
    cfg = ddgan_laion_aesthetic_v3()
    return cfg


@register
def ddgan_sd_v2():
    """Same settings as LAION-aesthetic v3, registered under its own name."""
    cfg = ddgan_laion_aesthetic_v3()
    return cfg


@register
def ddgan_sd_v3():
    """Same settings as LAION-aesthetic v3, registered under its own name."""
    cfg = ddgan_laion_aesthetic_v3()
    return cfg


@register
def ddgan_sd_v4():
    """Same settings as LAION-aesthetic v3, registered under its own name."""
    cfg = ddgan_laion_aesthetic_v3()
    return cfg


@register
def ddgan_sd_v5():
    """LAION-aesthetic v3 with 8 timesteps."""
    cfg = ddgan_laion_aesthetic_v3()
    cfg['model']['num_timesteps'] = 8
    return cfg


@register
def ddgan_sd_v6():
    """LAION-aesthetic v3 with 192 DAE channels."""
    cfg = ddgan_laion_aesthetic_v3()
    cfg['model']['num_channels_dae'] = 192
    return cfg


@register
def ddgan_sd_v7():
    """Same settings as LAION-aesthetic v3, registered under its own name."""
    cfg = ddgan_laion_aesthetic_v3()
    return cfg


@register
def ddgan_sd_v8():
    """LAION-aesthetic v3 at image size 512."""
    cfg = ddgan_laion_aesthetic_v3()
    cfg['model']['image_size'] = 512
    return cfg


@register
def ddgan_laion_aesthetic_v12():
    """Same settings as LAION-aesthetic v3, registered under its own name."""
    cfg = ddgan_laion_aesthetic_v3()
    return cfg


@register
def ddgan_laion_aesthetic_v13():
    """LAION-aesthetic v3 with the ``openclip/ViT-H-14/laion2b_s32b_b79k`` encoder."""
    cfg = ddgan_laion_aesthetic_v3()
    cfg['model']['text_encoder'] = "openclip/ViT-H-14/laion2b_s32b_b79k"
    return cfg


@register
def ddgan_laion_aesthetic_v14():
    """LAION-aesthetic v3 with the ``openclip/ViT-H-14/laion2b_s32b_b79k`` encoder."""
    cfg = ddgan_laion_aesthetic_v3()
    cfg['model']['text_encoder'] = "openclip/ViT-H-14/laion2b_s32b_b79k"
    return cfg


@register
def ddgan_sd_v9():
    """LAION-aesthetic v3 with the OpenCLIP ViT-H-14 encoder and guidance proba 0.0."""
    cfg = ddgan_laion_aesthetic_v3()
    cfg['model']['text_encoder'] = "openclip/ViT-H-14/laion2b_s32b_b79k"
    cfg['model']['classifier_free_guidance_proba'] = 0.0
    return cfg


@register
def ddgan_sd_v10():
    """sd v9 with 2 timesteps."""
    cfg = ddgan_sd_v9()
    cfg['model']['num_timesteps'] = 2
    return cfg


@register
def ddgan_laion2b_v2():
    """Same settings as sd v9, registered under its own name."""
    cfg = ddgan_sd_v9()
    return cfg


@register
def ddgan_ddb_v1():
    """Same settings as sd v10, registered under its own name."""
    cfg = ddgan_sd_v10()
    return cfg


@register
def ddgan_sd_v11():
    """sd v10 at image size 512."""
    cfg = ddgan_sd_v10()
    cfg['model']['image_size'] = 512
    return cfg


@register
def ddgan_ddb_v2():
    """ddb v1 with a single timestep."""
    cfg = ddgan_ddb_v1()
    cfg['model']['num_timesteps'] = 1
    return cfg


@register
def ddgan_ddb_v3():
    """ddb v1 with 192 DAE channels and 2 timesteps."""
    cfg = ddgan_ddb_v1()
    cfg['model']['num_channels_dae'] = 192
    cfg['model']['num_timesteps'] = 2
    return cfg


@register
def ddgan_ddb_v4():
    """ddb v1 with 256 DAE channels and 2 timesteps."""
    cfg = ddgan_ddb_v1()
    cfg['model']['num_channels_dae'] = 256
    cfg['model']['num_timesteps'] = 2
    return cfg


@register
def ddgan_ddb_v5():
    """Same settings as ddb v2, registered under its own name."""
    cfg = ddgan_ddb_v2()
    return cfg


@register
def ddgan_ddb_v6():
    """Same settings as ddb v3, registered under its own name."""
    cfg = ddgan_ddb_v3()
    return cfg


@register
def ddgan_ddb_v7():
    """Same settings as ddb v1, registered under its own name."""
    cfg = ddgan_ddb_v1()
    return cfg


@register
def ddgan_ddb_v9():
    """ddb v3 with ``attn_resolutions`` of ``[4, 8, 16, 32]``."""
    cfg = ddgan_ddb_v3()
    cfg['model']['attn_resolutions'] = [4, 8, 16, 32]
    return cfg


@register
def ddgan_laion_aesthetic_v15():
    """Same settings as ddb v3, registered under its own name."""
    cfg = ddgan_ddb_v3()
    return cfg


@register
def ddgan_ddb_v10():
    """Same settings as ddb v9, registered under its own name."""
    cfg = ddgan_ddb_v9()
    return cfg


@register
def ddgan_ddb_v11():
    """ddb v3 with the ``openclip/ViT-g-14/laion2B-s12B-b42K`` text encoder."""
    cfg = ddgan_ddb_v3()
    cfg['model']['text_encoder'] = "openclip/ViT-g-14/laion2B-s12B-b42K"
    return cfg


@register
def ddgan_ddb_v12():
    """ddb v3 with the ``openclip/ViT-bigG-14/laion2b_s39b_b160k`` text encoder."""
    cfg = ddgan_ddb_v3()
    cfg['model']['text_encoder'] = "openclip/ViT-bigG-14/laion2b_s39b_b160k"
    return cfg


@register
def ddgan_ddb_v13():
    """ddb v3 with 320 DAE channels."""
    cfg = ddgan_ddb_v3()
    cfg['model']['num_channels_dae'] = 320  # 1B model
    return cfg


@register
def ddgan_ddb_v14():
    """ddb v3 with the ``cross_and_global_attention`` cross-attention block."""
    cfg = ddgan_ddb_v3()
    cfg['model']['cross_attention_block'] = "cross_and_global_attention"
    return cfg