Mehdi Cherti commited on
Commit
ae26d48
1 Parent(s): 1a02524

update run.py

Browse files
Files changed (1) hide show
  1. run.py +35 -3
run.py CHANGED
@@ -91,6 +91,11 @@ def ddgan_cc12m_v14():
91
  cfg['model']['num_channels_dae'] = 192
92
  return cfg
93
 
 
 
 
 
 
94
 
95
  def ddgan_cifar10_cond17():
96
  cfg = base()
@@ -107,6 +112,13 @@ def ddgan_cifar10_cond18():
107
  cfg['model']['text_encoder'] = "google/t5-v1_1-xl"
108
  return cfg
109
 
 
 
 
 
 
 
 
110
  def ddgan_laion_aesthetic_v1():
111
  cfg = ddgan_cc12m_v11()
112
  cfg['model']['dataset_root'] = '"/p/scratch/ccstdl/cherti1/LAION-aesthetic/output/{00000..05038}.tar"'
@@ -122,10 +134,23 @@ def ddgan_laion_aesthetic_v3():
122
  cfg['model']['text_encoder'] = "google/t5-v1_1-xl"
123
  return cfg
124
 
 
 
 
 
 
 
 
 
 
 
 
125
 
126
  models = [
127
  ddgan_cifar10_cond17, # cifar10, cross attn for discr
128
  ddgan_cifar10_cond18, # cifar10, xl encoder
 
 
129
  ddgan_cc12m_v2, # baseline (no large text encoder, no classifier guidance)
130
  ddgan_cc12m_v6, # like v2 but using large T5 text encoder
131
  ddgan_cc12m_v7, # like v2 but with classifier guidance
@@ -135,9 +160,12 @@ models = [
135
  ddgan_cc12m_v12, # T5-XL + cross attention + classifier free guidance + random_resized_crop_v1
136
  ddgan_cc12m_v13, # T5-XL + cross attention + classifier free guidance + random_resized_crop_v1 + cond attn
137
  ddgan_cc12m_v14, # T5-XL + cross attention + classifier free guidance + random_resized_crop_v1 + 300M model
 
138
  ddgan_laion_aesthetic_v1, # like ddgan_cc12m_v11 but fine-tuned on laion aesthetic
139
  ddgan_laion_aesthetic_v2, # like ddgan_laion_aesthetic_v1 but trained from scratch with the new cross attn discr
140
- ddgan_laion_aesthetic_v3, # like ddgan_laion_aesthetic_v1 but trained from scratch with T5-XL
 
 
141
  ]
142
 
143
  def get_model(model_name):
@@ -146,7 +174,7 @@ def get_model(model_name):
146
  return model()
147
 
148
 
149
- def test(model_name, *, cond_text="", batch_size:int=None, epoch:int=None, guidance_scale:float=0, fid=False, real_img_dir="", q=0.0, seed=0, nb_images_for_fid=0):
150
 
151
  cfg = get_model(model_name)
152
  model = cfg['model']
@@ -173,12 +201,16 @@ def test(model_name, *, cond_text="", batch_size:int=None, epoch:int=None, guida
173
  args['guidance_scale'] = guidance_scale
174
  args['masked_mean'] = model.get("masked_mean")
175
  args['dynamic_thresholding_quantile'] = q
 
 
176
  args['n_mlp'] = model.get("n_mlp")
177
-
178
  if fid:
179
  args['compute_fid'] = ''
180
  args['real_img_dir'] = real_img_dir
181
  args['nb_images_for_fid'] = nb_images_for_fid
 
 
 
182
  cmd = "python -u test_ddgan.py " + " ".join(f"--{k} {v}" for k, v in args.items() if v is not None)
183
  print(cmd)
184
  call(cmd, shell=True)
 
91
  cfg['model']['num_channels_dae'] = 192
92
  return cfg
93
 
94
+ def ddgan_cc12m_v15():
95
+ cfg = ddgan_cc12m_v11()
96
+ cfg['model']['mismatch_loss'] = ''
97
+ cfg['model']['grad_penalty_cond'] = ''
98
+ return cfg
99
 
100
  def ddgan_cifar10_cond17():
101
  cfg = base()
 
112
  cfg['model']['text_encoder'] = "google/t5-v1_1-xl"
113
  return cfg
114
 
115
+ def ddgan_cifar10_cond19():
116
+ cfg = ddgan_cifar10_cond17()
117
+ cfg['model']['discr_type'] = 'small_cond_attn'
118
+ cfg['model']['mismatch_loss'] = ''
119
+ cfg['model']['grad_penalty_cond'] = ''
120
+ return cfg
121
+
122
  def ddgan_laion_aesthetic_v1():
123
  cfg = ddgan_cc12m_v11()
124
  cfg['model']['dataset_root'] = '"/p/scratch/ccstdl/cherti1/LAION-aesthetic/output/{00000..05038}.tar"'
 
134
  cfg['model']['text_encoder'] = "google/t5-v1_1-xl"
135
  return cfg
136
 
137
+ def ddgan_laion_aesthetic_v4():
138
+ cfg = ddgan_laion_aesthetic_v1()
139
+ cfg['model']['text_encoder'] = "openclip/ViT-L-14-336/openai"
140
+ return cfg
141
+
142
+
143
+ def ddgan_laion_aesthetic_v5():
144
+ cfg = ddgan_laion_aesthetic_v1()
145
+ cfg['model']['mismatch_loss'] = ''
146
+ cfg['model']['grad_penalty_cond'] = ''
147
+ return cfg
148
 
149
  models = [
150
  ddgan_cifar10_cond17, # cifar10, cross attn for discr
151
  ddgan_cifar10_cond18, # cifar10, xl encoder
152
+ ddgan_cifar10_cond19, # cifar10, xl encoder
153
+
154
  ddgan_cc12m_v2, # baseline (no large text encoder, no classifier guidance)
155
  ddgan_cc12m_v6, # like v2 but using large T5 text encoder
156
  ddgan_cc12m_v7, # like v2 but with classifier guidance
 
160
  ddgan_cc12m_v12, # T5-XL + cross attention + classifier free guidance + random_resized_crop_v1
161
  ddgan_cc12m_v13, # T5-XL + cross attention + classifier free guidance + random_resized_crop_v1 + cond attn
162
  ddgan_cc12m_v14, # T5-XL + cross attention + classifier free guidance + random_resized_crop_v1 + 300M model
163
+ ddgan_cc12m_v15, # fine-tune v11 with --mismatch_loss and --grad_penalty_cond
164
  ddgan_laion_aesthetic_v1, # like ddgan_cc12m_v11 but fine-tuned on laion aesthetic
165
  ddgan_laion_aesthetic_v2, # like ddgan_laion_aesthetic_v1 but trained from scratch with the new cross attn discr
166
+ ddgan_laion_aesthetic_v3, # like ddgan_laion_aesthetic_v1 but trained from scratch with T5-XL (continue from 23aug with mismatch and grad penalty and random_resized_crop_v1)
167
+ ddgan_laion_aesthetic_v4, # like ddgan_laion_aesthetic_v1 but trained from scratch with OpenAI's ClipEncoder
168
+ ddgan_laion_aesthetic_v5, # fine-tune ddgan_laion_aesthetic_v1 with mismatch and cond grad penalty losses
169
  ]
170
 
171
  def get_model(model_name):
 
174
  return model()
175
 
176
 
177
+ def test(model_name, *, cond_text="", batch_size:int=None, epoch:int=None, guidance_scale:float=0, fid=False, real_img_dir="", q=0.0, seed=0, nb_images_for_fid=0, scale_factor_h=1, scale_factor_w=1, compute_clip_score=False):
178
 
179
  cfg = get_model(model_name)
180
  model = cfg['model']
 
201
  args['guidance_scale'] = guidance_scale
202
  args['masked_mean'] = model.get("masked_mean")
203
  args['dynamic_thresholding_quantile'] = q
204
+ args['scale_factor_h'] = scale_factor_h
205
+ args['scale_factor_w'] = scale_factor_w
206
  args['n_mlp'] = model.get("n_mlp")
 
207
  if fid:
208
  args['compute_fid'] = ''
209
  args['real_img_dir'] = real_img_dir
210
  args['nb_images_for_fid'] = nb_images_for_fid
211
+ if compute_clip_score:
212
+ args['compute_clip_score'] = ""
213
+
214
  cmd = "python -u test_ddgan.py " + " ".join(f"--{k} {v}" for k, v in args.items() if v is not None)
215
  print(cmd)
216
  call(cmd, shell=True)