Mehdi Cherti committed on
Commit 572f947
1 Parent(s): 06c5f0c

support clip score and higher resolution at test time

Files changed (1)
  1. test_ddgan.py +45 -13
test_ddgan.py CHANGED
@@ -12,7 +12,7 @@ import os
 import json
 import torchvision
 from score_sde.models.ncsnpp_generator_adagn import NCSNpp
-import t5
+from encoder import build_encoder
 
 #%% Diffusion coefficients
 def var_func_vp(t, beta_min, beta_max):
@@ -130,13 +130,13 @@ def sample_from_model(coefficients, generator, n_time, x_init, T, opt, cond=None
 def sample_from_model_classifier_free_guidance(coefficients, generator, n_time, x_init, T, opt, text_encoder, cond=None, guidance_scale=0):
     x = x_init
     null = text_encoder([""] * len(x_init), return_only_pooled=False)
-    latent_z = torch.randn(x.size(0), opt.nz, device=x.device)
+    #latent_z = torch.randn(x.size(0), opt.nz, device=x.device)
     with torch.no_grad():
         for i in reversed(range(n_time)):
             t = torch.full((x.size(0),), i, dtype=torch.int64).to(x.device)
             t_time = t
 
-            #latent_z = torch.randn(x.size(0), opt.nz, device=x.device)
+            latent_z = torch.randn(x.size(0), opt.nz, device=x.device)
 
             x_0_uncond = generator(x, t_time, latent_z, cond=null)
 
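Note on this hunk: the latent draw moves inside the denoising loop, so a fresh latent_z is sampled at every step instead of once per image. For context, the guidance step then blends the unconditional and conditional predictions; that blend lies outside this hunk, so the following is a sketch of the standard classifier-free guidance form, not necessarily this file's exact code:

    # Sketch (assumed): one CFG step; x_0_uncond comes from the hunk above,
    # x_0_cond and the blend follow the standard CFG formula.
    latent_z = torch.randn(x.size(0), opt.nz, device=x.device)  # re-drawn each step (this commit)
    x_0_uncond = generator(x, t_time, latent_z, cond=null)
    x_0_cond = generator(x, t_time, latent_z, cond=cond)
    x_0 = x_0_uncond + guidance_scale * (x_0_cond - x_0_uncond)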
@@ -184,10 +184,8 @@ def sample_from_model_classifier_free_guidance(coefficients, generator, n_time,
 def sample_and_test(args):
     torch.manual_seed(args.seed)
     device = 'cuda:0'
-    text_encoder = t5.T5Encoder(name=args.text_encoder, masked_mean=args.masked_mean).to(device)
+    text_encoder = build_encoder(name=args.text_encoder, masked_mean=args.masked_mean).to(device)
     args.cond_size = text_encoder.output_size
-    # cond = text_encoder([str(yi%10) for yi in range(args.batch_size)])
-
     if args.dataset == 'cifar10':
         real_img_dir = 'pytorch_fid/cifar10_train_stat.npy'
     elif args.dataset == 'celeba_256':
@@ -201,7 +199,7 @@ def sample_and_test(args):
 
 
     netG = NCSNpp(args).to(device)
-
+    netG.attn_resolutions = [r * args.scale_factor_w for r in netG.attn_resolutions]
 
     if args.epoch_id == -1:
         epochs = range(1000)
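Why attn_resolutions is scaled: NCSN++ inserts self-attention wherever a feature map reaches one of these resolutions. Sampling wider inputs grows every feature map by the same factor, so the trigger resolutions are multiplied by scale_factor_w to keep attention at the same network depths. A worked example with assumed values (not from this config):

    # Hypothetical: a model trained at 64px with attn_resolutions = [16],
    # sampled with --scale_factor_w 2, should attend on 32-px feature maps:
    # [r * 2 for r in [16]]  ->  [32]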
@@ -214,7 +212,7 @@ def sample_and_test(args):
         if not os.path.exists(path):
             continue
         ckpt = torch.load(path, map_location=device)
-        dest = './saved_info/dd_gan/{}/{}/fid_{}.json'.format(args.dataset, args.exp, args.epoch_id)
+        dest = './saved_info/dd_gan/{}/{}/eval_{}.json'.format(args.dataset, args.exp, args.epoch_id)
 
         if args.compute_fid and os.path.exists(dest):
             continue
@@ -258,6 +256,15 @@ def sample_and_test(args):
         block_idx = InceptionV3.BLOCK_INDEX_BY_DIM[dims]
         inceptionv3 = InceptionV3([block_idx]).to(device)
 
+        if args.compute_clip_score:
+            import clip
+            CLIP_MEAN = [0.48145466, 0.4578275, 0.40821073]
+            CLIP_STD = [0.26862954, 0.26130258, 0.27577711]
+            clip_model, preprocess = clip.load(args.clip_model, device)
+            clip_mean = torch.Tensor(CLIP_MEAN).view(1,-1,1,1).to(device)
+            clip_std = torch.Tensor(CLIP_STD).view(1,-1,1,1).to(device)
+
+
         if not args.real_img_dir.endswith("npz"):
             real_mu, real_sigma = compute_statistics_of_path(
                 args.real_img_dir, inceptionv3, args.batch_size, dims, device,
@@ -270,6 +277,9 @@ def sample_and_test(args):
             real_sigma = stats['sigma']
 
         fake_features = []
+        if args.compute_clip_score:
+            clip_scores = []
+
         for b in range(0, len(texts), args.batch_size):
             text = texts[b:b+args.batch_size]
             with torch.no_grad():
@@ -277,6 +287,7 @@ def sample_and_test(args):
                 bs = len(text)
                 t0 = time.time()
                 x_t_1 = torch.randn(bs, args.num_channels,args.image_size, args.image_size).to(device)
+                #print(x_t_1.shape)
                 if args.guidance_scale:
                     fake_sample = sample_from_model_classifier_free_guidance(pos_coeff, netG, args.num_timesteps, x_t_1,T, args, text_encoder, cond=cond, guidance_scale=args.guidance_scale)
                 else:
@@ -295,6 +306,17 @@ def sample_and_test(args):
                 pred = adaptive_avg_pool2d(pred, output_size=(1, 1))
                 pred = pred.squeeze(3).squeeze(2).cpu().numpy()
                 fake_features.append(pred)
+
+                if args.compute_clip_score:
+                    with torch.no_grad():
+                        clip_ims = torch.nn.functional.interpolate(fake_sample, (224, 224), mode="bicubic")
+                        clip_txt = clip.tokenize(text).to(device)
+                        imf = clip_model.encode_image(clip_ims)
+                        txtf = clip_model.encode_text(clip_txt)
+                        imf = torch.nn.functional.normalize(imf, dim=1)
+                        txtf = torch.nn.functional.normalize(txtf, dim=1)
+                        clip_scores.append(((imf * txtf).sum(dim=1)).cpu())
+                break
                 if i % 10 == 0:
                     print('generating batch ', i, time.time() - t0)
                 """
@@ -311,14 +333,17 @@ def sample_and_test(args):
         fake_mu = np.mean(fake_features, axis=0)
         fake_sigma = np.cov(fake_features, rowvar=False)
         fid = calculate_frechet_distance(real_mu, real_sigma, fake_mu, fake_sigma)
-        dest = './saved_info/dd_gan/{}/{}/fid_{}.json'.format(args.dataset, args.exp, args.epoch_id)
+        dest = './saved_info/dd_gan/{}/{}/eval_{}.json'.format(args.dataset, args.exp, args.epoch_id)
         results = {
             "fid": fid,
         }
+        if args.compute_clip_score:
+            clip_score = torch.cat(clip_scores).mean().item()
+            results['clip_score'] = clip_score
         results.update(vars(args))
         with open(dest, "w") as fd:
             json.dump(results, fd)
-        print('FID = {}'.format(fid))
+        print(results)
     else:
         if args.cond_text.endswith(".txt"):
             texts = open(args.cond_text).readlines()
@@ -326,11 +351,13 @@ def sample_and_test(args):
         else:
             texts = [args.cond_text] * args.batch_size
         cond = text_encoder(texts, return_only_pooled=False)
-        x_t_1 = torch.randn(len(texts), args.num_channels,args.image_size, args.image_size).to(device)
+        x_t_1 = torch.randn(len(texts), args.num_channels,args.image_size*args.scale_factor_h, args.image_size*args.scale_factor_w).to(device)
+        t0 = time.time()
         if args.guidance_scale:
             fake_sample = sample_from_model_classifier_free_guidance(pos_coeff, netG, args.num_timesteps, x_t_1,T, args, text_encoder, cond=cond, guidance_scale=args.guidance_scale)
         else:
             fake_sample = sample_from_model(pos_coeff, netG, args.num_timesteps, x_t_1,T, args, cond=cond)
+        print(time.time() - t0)
         fake_sample = to_range_0_1(fake_sample)
         torchvision.utils.save_image(fake_sample, './samples_{}.jpg'.format(args.dataset))
 
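Higher resolution at test time works by drawing the initial noise larger than the training resolution; apart from the attention placement fixed earlier, the NCSN++ generator is fully convolutional, so it can sample at sizes it was never trained on. The effect, with an assumed 64px model:

    # Hypothetical: args.image_size = 64, --scale_factor_h 2 --scale_factor_w 2
    h = args.image_size * args.scale_factor_h  # 128
    w = args.image_size * args.scale_factor_w  # 128
    x_t_1 = torch.randn(len(texts), args.num_channels, h, w).to(device)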
@@ -344,11 +371,16 @@ if __name__ == '__main__':
                         help='seed used for initialization')
     parser.add_argument('--compute_fid', action='store_true', default=False,
                             help='whether or not compute FID')
+    parser.add_argument('--compute_clip_score', action='store_true', default=False,
+                            help='whether or not compute CLIP score')
+    parser.add_argument('--clip_model', type=str,default="ViT-L/14")
+
     parser.add_argument('--epoch_id', type=int,default=1000)
     parser.add_argument('--guidance_scale', type=float,default=0)
     parser.add_argument('--dynamic_thresholding_quantile', type=float,default=0)
     parser.add_argument('--cond_text', type=str,default="0")
-
+    parser.add_argument('--scale_factor_h', type=int,default=1)
+    parser.add_argument('--scale_factor_w', type=int,default=1)
     parser.add_argument('--cross_attention', action='store_true',default=False)
 
 
@@ -419,7 +451,7 @@ if __name__ == '__main__':
     parser.add_argument('--text_encoder', type=str, default="google/t5-v1_1-base")
     parser.add_argument('--masked_mean', action='store_true',default=False)
     parser.add_argument('--nb_images_for_fid', type=int, default=0)
-
+
 
 
 
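Hypothetical invocations combining the new flags (names from the argparse hunks above; --dataset, --exp and the checkpoint layout are assumed from the rest of the script):

    # evaluation: FID plus CLIP score, written to eval_<epoch>.json
    python test_ddgan.py --dataset my_dataset --exp my_exp --epoch_id 50 \
        --compute_fid --compute_clip_score --clip_model "ViT-L/14"
    # sampling at 2x the training resolution from a prompt file
    python test_ddgan.py --dataset my_dataset --exp my_exp --epoch_id 50 \
        --cond_text prompts.txt --scale_factor_h 2 --scale_factor_w 2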