Spaces:
Runtime error
Runtime error
Mehdi Cherti
committed on
Commit
•
ae26d48
1
Parent(s):
1a02524
update run.py
Browse files
run.py
CHANGED
@@ -91,6 +91,11 @@ def ddgan_cc12m_v14():
|
|
91 |
cfg['model']['num_channels_dae'] = 192
|
92 |
return cfg
|
93 |
|
|
|
|
|
|
|
|
|
|
|
94 |
|
95 |
def ddgan_cifar10_cond17():
|
96 |
cfg = base()
|
@@ -107,6 +112,13 @@ def ddgan_cifar10_cond18():
|
|
107 |
cfg['model']['text_encoder'] = "google/t5-v1_1-xl"
|
108 |
return cfg
|
109 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
110 |
def ddgan_laion_aesthetic_v1():
|
111 |
cfg = ddgan_cc12m_v11()
|
112 |
cfg['model']['dataset_root'] = '"/p/scratch/ccstdl/cherti1/LAION-aesthetic/output/{00000..05038}.tar"'
|
@@ -122,10 +134,23 @@ def ddgan_laion_aesthetic_v3():
|
|
122 |
cfg['model']['text_encoder'] = "google/t5-v1_1-xl"
|
123 |
return cfg
|
124 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
125 |
|
126 |
models = [
|
127 |
ddgan_cifar10_cond17, # cifar10, cross attn for discr
|
128 |
ddgan_cifar10_cond18, # cifar10, xl encoder
|
|
|
|
|
129 |
ddgan_cc12m_v2, # baseline (no large text encoder, no classifier guidance)
|
130 |
ddgan_cc12m_v6, # like v2 but using large T5 text encoder
|
131 |
ddgan_cc12m_v7, # like v2 but with classifier guidance
|
@@ -135,9 +160,12 @@ models = [
|
|
135 |
ddgan_cc12m_v12, # T5-XL + cross attention + classifier free guidance + random_resized_crop_v1
|
136 |
ddgan_cc12m_v13, # T5-XL + cross attention + classifier free guidance + random_resized_crop_v1 + cond attn
|
137 |
ddgan_cc12m_v14, # T5-XL + cross attention + classifier free guidance + random_resized_crop_v1 + 300M model
|
|
|
138 |
ddgan_laion_aesthetic_v1, # like ddgan_cc12m_v11 but fine-tuned on laion aesthetic
|
139 |
ddgan_laion_aesthetic_v2, # like ddgan_laion_aesthetic_v1 but trained from scratch with the new cross attn discr
|
140 |
-
ddgan_laion_aesthetic_v3, # like ddgan_laion_aesthetic_v1 but trained from scratch with T5-XL
|
|
|
|
|
141 |
]
|
142 |
|
143 |
def get_model(model_name):
|
@@ -146,7 +174,7 @@ def get_model(model_name):
|
|
146 |
return model()
|
147 |
|
148 |
|
149 |
-
def test(model_name, *, cond_text="", batch_size:int=None, epoch:int=None, guidance_scale:float=0, fid=False, real_img_dir="", q=0.0, seed=0, nb_images_for_fid=0):
|
150 |
|
151 |
cfg = get_model(model_name)
|
152 |
model = cfg['model']
|
@@ -173,12 +201,16 @@ def test(model_name, *, cond_text="", batch_size:int=None, epoch:int=None, guida
|
|
173 |
args['guidance_scale'] = guidance_scale
|
174 |
args['masked_mean'] = model.get("masked_mean")
|
175 |
args['dynamic_thresholding_quantile'] = q
|
|
|
|
|
176 |
args['n_mlp'] = model.get("n_mlp")
|
177 |
-
|
178 |
if fid:
|
179 |
args['compute_fid'] = ''
|
180 |
args['real_img_dir'] = real_img_dir
|
181 |
args['nb_images_for_fid'] = nb_images_for_fid
|
|
|
|
|
|
|
182 |
cmd = "python -u test_ddgan.py " + " ".join(f"--{k} {v}" for k, v in args.items() if v is not None)
|
183 |
print(cmd)
|
184 |
call(cmd, shell=True)
|
|
|
91 |
cfg['model']['num_channels_dae'] = 192
|
92 |
return cfg
|
93 |
|
94 |
+
def ddgan_cc12m_v15():
|
95 |
+
cfg = ddgan_cc12m_v11()
|
96 |
+
cfg['model']['mismatch_loss'] = ''
|
97 |
+
cfg['model']['grad_penalty_cond'] = ''
|
98 |
+
return cfg
|
99 |
|
100 |
def ddgan_cifar10_cond17():
|
101 |
cfg = base()
|
|
|
112 |
cfg['model']['text_encoder'] = "google/t5-v1_1-xl"
|
113 |
return cfg
|
114 |
|
115 |
+
def ddgan_cifar10_cond19():
|
116 |
+
cfg = ddgan_cifar10_cond17()
|
117 |
+
cfg['model']['discr_type'] = 'small_cond_attn'
|
118 |
+
cfg['model']['mismatch_loss'] = ''
|
119 |
+
cfg['model']['grad_penalty_cond'] = ''
|
120 |
+
return cfg
|
121 |
+
|
122 |
def ddgan_laion_aesthetic_v1():
|
123 |
cfg = ddgan_cc12m_v11()
|
124 |
cfg['model']['dataset_root'] = '"/p/scratch/ccstdl/cherti1/LAION-aesthetic/output/{00000..05038}.tar"'
|
|
|
134 |
cfg['model']['text_encoder'] = "google/t5-v1_1-xl"
|
135 |
return cfg
|
136 |
|
137 |
+
def ddgan_laion_aesthetic_v4():
|
138 |
+
cfg = ddgan_laion_aesthetic_v1()
|
139 |
+
cfg['model']['text_encoder'] = "openclip/ViT-L-14-336/openai"
|
140 |
+
return cfg
|
141 |
+
|
142 |
+
|
143 |
+
def ddgan_laion_aesthetic_v5():
|
144 |
+
cfg = ddgan_laion_aesthetic_v1()
|
145 |
+
cfg['model']['mismatch_loss'] = ''
|
146 |
+
cfg['model']['grad_penalty_cond'] = ''
|
147 |
+
return cfg
|
148 |
|
149 |
models = [
|
150 |
ddgan_cifar10_cond17, # cifar10, cross attn for discr
|
151 |
ddgan_cifar10_cond18, # cifar10, xl encoder
|
152 |
+
ddgan_cifar10_cond19, # cifar10, xl encoder
|
153 |
+
|
154 |
ddgan_cc12m_v2, # baseline (no large text encoder, no classifier guidance)
|
155 |
ddgan_cc12m_v6, # like v2 but using large T5 text encoder
|
156 |
ddgan_cc12m_v7, # like v2 but with classifier guidance
|
|
|
160 |
ddgan_cc12m_v12, # T5-XL + cross attention + classifier free guidance + random_resized_crop_v1
|
161 |
ddgan_cc12m_v13, # T5-XL + cross attention + classifier free guidance + random_resized_crop_v1 + cond attn
|
162 |
ddgan_cc12m_v14, # T5-XL + cross attention + classifier free guidance + random_resized_crop_v1 + 300M model
|
163 |
+
ddgan_cc12m_v15, # fine-tune v11 with --mismatch_loss and --grad_penalty_cond
|
164 |
ddgan_laion_aesthetic_v1, # like ddgan_cc12m_v11 but fine-tuned on laion aesthetic
|
165 |
ddgan_laion_aesthetic_v2, # like ddgan_laion_aesthetic_v1 but trained from scratch with the new cross attn discr
|
166 |
+
ddgan_laion_aesthetic_v3, # like ddgan_laion_aesthetic_v1 but trained from scratch with T5-XL (continue from 23aug with mismatch and grad penalty and random_resized_crop_v1)
|
167 |
+
ddgan_laion_aesthetic_v4, # like ddgan_laion_aesthetic_v1 but trained from scratch with OpenAI's ClipEncoder
|
168 |
+
ddgan_laion_aesthetic_v5, # fine-tune ddgan_laion_aesthetic_v1 with mismatch and cond grad penalty losses
|
169 |
]
|
170 |
|
171 |
def get_model(model_name):
|
|
|
174 |
return model()
|
175 |
|
176 |
|
177 |
+
def test(model_name, *, cond_text="", batch_size:int=None, epoch:int=None, guidance_scale:float=0, fid=False, real_img_dir="", q=0.0, seed=0, nb_images_for_fid=0, scale_factor_h=1, scale_factor_w=1, compute_clip_score=False):
|
178 |
|
179 |
cfg = get_model(model_name)
|
180 |
model = cfg['model']
|
|
|
201 |
args['guidance_scale'] = guidance_scale
|
202 |
args['masked_mean'] = model.get("masked_mean")
|
203 |
args['dynamic_thresholding_quantile'] = q
|
204 |
+
args['scale_factor_h'] = scale_factor_h
|
205 |
+
args['scale_factor_w'] = scale_factor_w
|
206 |
args['n_mlp'] = model.get("n_mlp")
|
|
|
207 |
if fid:
|
208 |
args['compute_fid'] = ''
|
209 |
args['real_img_dir'] = real_img_dir
|
210 |
args['nb_images_for_fid'] = nb_images_for_fid
|
211 |
+
if compute_clip_score:
|
212 |
+
args['compute_clip_score'] = ""
|
213 |
+
|
214 |
cmd = "python -u test_ddgan.py " + " ".join(f"--{k} {v}" for k, v in args.items() if v is not None)
|
215 |
print(cmd)
|
216 |
call(cmd, shell=True)
|