Mehdi Cherti committed on
Commit
06c5f0c
1 Parent(s): 8d2bdec

update available models

Browse files
Files changed (1) hide show
  1. run.py +92 -14
run.py CHANGED
@@ -1,8 +1,7 @@
1
  import os
2
- from clize import run
3
  from glob import glob
4
  from subprocess import call
5
-
6
  def base():
7
  return {
8
  "slurm":{
@@ -34,7 +33,7 @@ def base():
34
  "save_ckpt_every": 1,
35
  "masked_mean": "",
36
  "resume": "",
37
- }
38
  }
39
  def ddgan_cc12m_v2():
40
  cfg = base()
@@ -69,7 +68,6 @@ def ddgan_cc12m_v9():
69
  cfg['model']['batch_size'] = 1
70
  return cfg
71
 
72
-
73
  def ddgan_cc12m_v11():
74
  cfg = base()
75
  cfg['model']['text_encoder'] = "google/t5-v1_1-large"
@@ -77,22 +75,78 @@ def ddgan_cc12m_v11():
77
  cfg['model']['cross_attention'] = ""
78
  return cfg
79
 
80
- models = [
81
- ddgan_cc12m_v2,
82
- ddgan_cc12m_v6,
83
- ddgan_cc12m_v7,
84
- ddgan_cc12m_v8,
85
- ddgan_cc12m_v9,
86
- ddgan_cc12m_v11,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
87
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
88
  ]
 
89
  def get_model(model_name):
90
  for model in models:
91
  if model.__name__ == model_name:
92
  return model()
93
 
94
 
95
- def test(model_name, *, cond_text="", batch_size:int=None, epoch:int=None, guidance_scale:float=0, fid=False, real_img_dir=""):
96
 
97
  cfg = get_model(model_name)
98
  model = cfg['model']
@@ -104,6 +158,7 @@ def test(model_name, *, cond_text="", batch_size:int=None, epoch:int=None, guida
104
  args = {}
105
  args['exp'] = model_name
106
  args['image_size'] = model['image_size']
 
107
  args['num_channels'] = model['num_channels']
108
  args['dataset'] = model['dataset']
109
  args['num_channels_dae'] = model['num_channels_dae']
@@ -116,12 +171,35 @@ def test(model_name, *, cond_text="", batch_size:int=None, epoch:int=None, guida
116
  args['text_encoder'] = model.get("text_encoder")
117
  args['cross_attention'] = model.get("cross_attention")
118
  args['guidance_scale'] = guidance_scale
 
 
 
119
 
120
  if fid:
121
  args['compute_fid'] = ''
122
  args['real_img_dir'] = real_img_dir
123
- cmd = "python test_ddgan.py " + " ".join(f"--{k} {v}" for k, v in args.items() if v is not None)
 
124
  print(cmd)
125
  call(cmd, shell=True)
126
 
127
- run([test])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import os
 
2
  from glob import glob
3
  from subprocess import call
4
+ import json
5
  def base():
6
  return {
7
  "slurm":{
 
33
  "save_ckpt_every": 1,
34
  "masked_mean": "",
35
  "resume": "",
36
+ },
37
  }
38
  def ddgan_cc12m_v2():
39
  cfg = base()
 
68
  cfg['model']['batch_size'] = 1
69
  return cfg
70
 
 
71
  def ddgan_cc12m_v11():
72
  cfg = base()
73
  cfg['model']['text_encoder'] = "google/t5-v1_1-large"
 
75
  cfg['model']['cross_attention'] = ""
76
  return cfg
77
 
78
def ddgan_cc12m_v12():
    """Return the ``ddgan_cc12m_v11`` config with two model overrides.

    The text encoder is switched to ``google/t5-v1_1-xl`` and the
    preprocessing is set to ``random_resized_crop_v1``.
    """
    cfg = ddgan_cc12m_v11()
    cfg['model'].update(
        text_encoder="google/t5-v1_1-xl",
        preprocessing='random_resized_crop_v1',
    )
    return cfg
83
+
84
def ddgan_cc12m_v13():
    """Return the ``ddgan_cc12m_v12`` config with ``discr_type`` set to
    ``large_cond_attn``."""
    cfg = ddgan_cc12m_v12()
    model_cfg = cfg['model']
    model_cfg['discr_type'] = "large_cond_attn"
    return cfg
88
+
89
def ddgan_cc12m_v14():
    """Return the ``ddgan_cc12m_v12`` config with ``num_channels_dae`` set
    to 192."""
    cfg = ddgan_cc12m_v12()
    model_cfg = cfg['model']
    model_cfg['num_channels_dae'] = 192
    return cfg
93
+
94
+
95
def ddgan_cifar10_cond17():
    """Return the base config adapted to the ``cifar10`` dataset.

    Overrides, applied on top of ``base()``: 32-pixel image size,
    classifier-free guidance probability 0.2, channel multipliers
    ``"1 2 2 2"``, the ``cross_attention`` entry set to an empty string,
    and ``n_mlp`` set to 4.
    """
    cfg = base()
    # Same keys, in the same order, as assigning them one by one.
    overrides = {
        'image_size': 32,
        'classifier_free_guidance_proba': 0.2,
        'ch_mult': "1 2 2 2",
        'cross_attention': "",
        'dataset': "cifar10",
        'n_mlp': 4,
    }
    cfg['model'].update(overrides)
    return cfg
104
 
105
def ddgan_cifar10_cond18():
    """Return the ``ddgan_cifar10_cond17`` config with the text encoder set
    to ``google/t5-v1_1-xl``."""
    cfg = ddgan_cifar10_cond17()
    model_cfg = cfg['model']
    model_cfg['text_encoder'] = "google/t5-v1_1-xl"
    return cfg
109
+
110
def ddgan_laion_aesthetic_v1():
    """Return the ``ddgan_cc12m_v11`` config pointed at the LAION-aesthetic
    tar shards via ``dataset_root``."""
    cfg = ddgan_cc12m_v11()
    model_cfg = cfg['model']
    # The value deliberately contains literal double quotes around the brace
    # pattern. NOTE(review): presumably so the pattern survives being passed
    # through a shell command line unexpanded -- confirm against the consumer.
    model_cfg['dataset_root'] = '"/p/scratch/ccstdl/cherti1/LAION-aesthetic/output/{00000..05038}.tar"'
    return cfg
114
+
115
def ddgan_laion_aesthetic_v2():
    """Return the ``ddgan_laion_aesthetic_v1`` config with ``discr_type``
    set to ``large_cond_attn``."""
    cfg = ddgan_laion_aesthetic_v1()
    model_cfg = cfg['model']
    model_cfg['discr_type'] = "large_cond_attn"
    return cfg
119
+
120
def ddgan_laion_aesthetic_v3():
    """Return the ``ddgan_laion_aesthetic_v1`` config with the text encoder
    set to ``google/t5-v1_1-xl``."""
    cfg = ddgan_laion_aesthetic_v1()
    model_cfg = cfg['model']
    model_cfg['text_encoder'] = "google/t5-v1_1-xl"
    return cfg
124
+
125
+
126
# Registry of config factory functions. `get_model` selects an entry by
# comparing the requested name with each function's `__name__`.
models = [
    # cifar10, cross attn for discr
    ddgan_cifar10_cond17,
    # cifar10, xl encoder
    ddgan_cifar10_cond18,
    # baseline (no large text encoder, no classifier guidance)
    ddgan_cc12m_v2,
    # like v2 but using large T5 text encoder
    ddgan_cc12m_v6,
    # like v2 but with classifier guidance
    ddgan_cc12m_v7,
    # like v6 but classifier guidance
    ddgan_cc12m_v8,
    # ~1B model but 64x64 resolution
    ddgan_cc12m_v9,
    # large text encoder + cross attention + classifier free guidance
    ddgan_cc12m_v11,
    # T5-XL + cross attention + classifier free guidance + random_resized_crop_v1
    ddgan_cc12m_v12,
    # T5-XL + cross attention + classifier free guidance + random_resized_crop_v1 + cond attn
    ddgan_cc12m_v13,
    # T5-XL + cross attention + classifier free guidance + random_resized_crop_v1 + 300M model
    ddgan_cc12m_v14,
    # like ddgan_cc12m_v11 but fine-tuned on laion aesthetic
    ddgan_laion_aesthetic_v1,
    # like ddgan_laion_aesthetic_v1 but trained from scratch with the new cross attn discr
    ddgan_laion_aesthetic_v2,
    # like ddgan_laion_aesthetic_v1 but trained from scratch with T5-XL
    ddgan_laion_aesthetic_v3,
]
142
+
143
def get_model(model_name):
    """Build and return the config of the registered model ``model_name``.

    Scans the module-level ``models`` list for a factory function whose
    ``__name__`` equals ``model_name`` and returns the result of calling it.

    Args:
        model_name: name of one of the factory functions listed in ``models``.

    Returns:
        Whatever the matching factory returns (the model's config).

    Raises:
        ValueError: if no registered factory has that name. Without this the
            function would fall through and return ``None``, and callers
            would fail later with an unhelpful ``TypeError`` when indexing
            the result.
    """
    for factory in models:
        if factory.__name__ == model_name:
            return factory()
    available = ", ".join(factory.__name__ for factory in models)
    raise ValueError(
        f"Unknown model {model_name!r}. Available models: {available}"
    )
147
 
148
 
149
+ def test(model_name, *, cond_text="", batch_size:int=None, epoch:int=None, guidance_scale:float=0, fid=False, real_img_dir="", q=0.0, seed=0, nb_images_for_fid=0):
150
 
151
  cfg = get_model(model_name)
152
  model = cfg['model']
 
158
  args = {}
159
  args['exp'] = model_name
160
  args['image_size'] = model['image_size']
161
+ args['seed'] = seed
162
  args['num_channels'] = model['num_channels']
163
  args['dataset'] = model['dataset']
164
  args['num_channels_dae'] = model['num_channels_dae']
 
171
  args['text_encoder'] = model.get("text_encoder")
172
  args['cross_attention'] = model.get("cross_attention")
173
  args['guidance_scale'] = guidance_scale
174
+ args['masked_mean'] = model.get("masked_mean")
175
+ args['dynamic_thresholding_quantile'] = q
176
+ args['n_mlp'] = model.get("n_mlp")
177
 
178
  if fid:
179
  args['compute_fid'] = ''
180
  args['real_img_dir'] = real_img_dir
181
+ args['nb_images_for_fid'] = nb_images_for_fid
182
+ cmd = "python -u test_ddgan.py " + " ".join(f"--{k} {v}" for k, v in args.items() if v is not None)
183
  print(cmd)
184
  call(cmd, shell=True)
185
 
186
def eval_results(model_name):
    """Gather the saved FID results of ``model_name`` into one CSV file.

    Reads every ``fid*.json`` file in
    ``./saved_info/dd_gan/<dataset>/<model_name>/`` (``<dataset>`` comes from
    the model config), takes the ``fid`` and ``epoch_id`` entries of each
    file, and writes them as the ``fid`` and ``epoch`` columns of ``fid.csv``
    in that same directory.

    Rows are written sorted by epoch so the output does not depend on the
    order in which ``glob`` lists the files. The two columns are always
    present, so the CSV still has its header when no result file exists.

    Args:
        model_name: name of a model known to ``get_model``.
    """
    # pandas is only needed by this command, so it is imported locally.
    import pandas as pd

    model = get_model(model_name)['model']
    result_dir = './saved_info/dd_gan/{}/{}'.format(model["dataset"], model_name)

    rows = []
    # glob() gives no ordering guarantee; sorting the paths keeps the relative
    # order of rows with equal epochs reproducible.
    for path in sorted(glob(result_dir + '/fid*.json')):
        with open(path, "r") as fd:
            data = json.load(fd)
        rows.append({'fid': data['fid'], 'epoch': data['epoch_id']})

    df = pd.DataFrame(rows, columns=['fid', 'epoch'])
    df = df.sort_values('epoch', kind='stable')
    df.to_csv(result_dir + '/fid.csv', index=False)
202
+
203
if __name__ == "__main__":
    # Command-line entry point: expose `test` and `eval_results` as commands
    # through clize. clize is imported only when the file is run as a script.
    import clize

    clize.run([test, eval_results])