Mehdi Cherti committed on
Commit
be61cf2
1 Parent(s): ae26d48
Files changed (10)
  1. EMA.py +0 -1
  2. clip_encoder.py +64 -0
  3. encoder.py +9 -0
  4. run.py +103 -3
  5. scripts/init.sh +15 -0
  6. scripts/run_hdfml.sh +25 -0
  7. scripts/run_jurecadc_ddp.sh +4 -1
  8. test_ddgan.py +280 -64
  9. train_ddgan.py +158 -60
  10. utils.py +2 -1
EMA.py CHANGED
@@ -39,7 +39,6 @@ class EMA(Optimizer):
                # State initialization
                if 'ema' not in state:
                    state['ema'] = p.data.clone()
-
                if p.shape not in params:
                    params[p.shape] = {'idx': 0, 'data': []}
                    ema[p.shape] = []
clip_encoder.py ADDED
@@ -0,0 +1,64 @@
+import torch
+import torch.nn as nn
+import open_clip
+from einops import rearrange
+
+
+def exists(val):
+    return val is not None
+
+class CLIPEncoder(nn.Module):
+
+    def __init__(self, model, pretrained):
+        super().__init__()
+        self.model = model
+        self.pretrained = pretrained
+        self.model, _, _ = open_clip.create_model_and_transforms(model, pretrained=pretrained)
+        self.output_size = self.model.transformer.width
+
+    def forward(self, texts, return_only_pooled=True):
+        device = next(self.parameters()).device
+        toks = open_clip.tokenize(texts).to(device)
+        x = self.model.token_embedding(toks)  # [batch_size, n_ctx, d_model]
+        x = x + self.model.positional_embedding
+        x = x.permute(1, 0, 2)  # NLD -> LND
+        x = self.model.transformer(x, attn_mask=self.model.attn_mask)
+        x = x.permute(1, 0, 2)  # LND -> NLD
+        x = self.model.ln_final(x)
+        mask = (toks != 0)
+        pooled = x[torch.arange(x.shape[0]), toks.argmax(dim=-1)] @ self.model.text_projection
+        if return_only_pooled:
+            return pooled
+        else:
+            return pooled, x, mask
+
+
+
+
+class CLIPImageEncoder(nn.Module):
+
+    def __init__(self, model_type="ViT-B/32"):
+        super().__init__()
+        import clip
+        self.model, preprocess = clip.load(model_type, device="cpu", jit=False)
+        CLIP_MEAN = [0.48145466, 0.4578275, 0.40821073]
+        CLIP_STD = [0.26862954, 0.26130258, 0.27577711]
+        mean = torch.tensor(CLIP_MEAN).view(1, 3, 1, 1)
+        std = torch.tensor(CLIP_STD).view(1, 3, 1, 1)
+        self.register_buffer("mean", mean)
+        self.register_buffer("std", std)
+        self.output_size = 512
+
+    def forward_image(self, x):
+        x = torch.nn.functional.interpolate(x, mode='bicubic', size=(224, 224))
+        x = (x - self.mean) / self.std
+        return self.model.encode_image(x)
+
+    def forward_text(self, texts):
+        import clip
+        toks = clip.tokenize(texts, truncate=True).to(self.mean.device)
+        return self.model.encode_text(toks)
+
+
+
+
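Usage sketch (editor's illustration, not part of the commit): how CLIPEncoder is meant to be called. The model/pretrained pair below is an assumption; any valid open_clip combination should work, and the shapes depend on the chosen model.

import torch
from clip_encoder import CLIPEncoder

enc = CLIPEncoder("ViT-H-14", "laion2b_s32b_b79k")  # assumed open_clip model/pretrained pair
with torch.no_grad():
    pooled, tokens, mask = enc(["a photo of a cat"], return_only_pooled=False)
# pooled: one projected embedding per caption; tokens: per-token features (batch, n_ctx, width);
# mask: True at non-padding token positions.
print(pooled.shape, tokens.shape, mask.shape)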
encoder.py ADDED
@@ -0,0 +1,9 @@
+import t5
+import clip_encoder
+
+def build_encoder(name, **kwargs):
+    if name.startswith("google"):
+        return t5.T5Encoder(name=name, **kwargs)
+    elif name.startswith("openclip"):
+        _, model, pretrained = name.split("/")
+        return clip_encoder.CLIPEncoder(model, pretrained)
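Usage sketch (editor's illustration, not part of the commit): build_encoder dispatches purely on the name prefix, so the "openclip/&lt;model&gt;/&lt;pretrained&gt;" convention used by the run.py configs resolves as follows. This assumes t5.py and clip_encoder.py are importable.

from encoder import build_encoder

# "google/..." names go to t5.T5Encoder; "openclip/<model>/<pretrained>" is split on "/"
# and forwarded to clip_encoder.CLIPEncoder.
enc = build_encoder("openclip/ViT-H-14/laion2b_s32b_b79k")
print(enc.output_size)  # transformer width of the chosen CLIP text tower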
run.py CHANGED
@@ -132,6 +132,8 @@ def ddgan_laion_aesthetic_v2():
 def ddgan_laion_aesthetic_v3():
     cfg = ddgan_laion_aesthetic_v1()
     cfg['model']['text_encoder'] = "google/t5-v1_1-xl"
+    cfg['model']['mismatch_loss'] = ''
+    cfg['model']['grad_penalty_cond'] = ''
     return cfg
 
 def ddgan_laion_aesthetic_v4():
@@ -146,6 +148,85 @@ def ddgan_laion_aesthetic_v5():
     cfg['model']['grad_penalty_cond'] = ''
     return cfg
 
+
+
+def ddgan_laion2b_v1():
+    cfg = ddgan_laion_aesthetic_v3()
+    cfg['model']['mismatch_loss'] = ''
+    cfg['model']['grad_penalty_cond'] = ''
+    cfg['model']['num_channels_dae'] = 224
+    cfg['model']['batch_size'] = 2
+    cfg['model']['discr_type'] = "large_cond_attn"
+    cfg['model']['preprocessing'] = 'random_resized_crop_v1'
+    return cfg
+
+def ddgan_laion_aesthetic_v6():
+    cfg = ddgan_laion_aesthetic_v3()
+    cfg['model']['no_lr_decay'] = ''
+    return cfg
+
+
+
+def ddgan_laion_aesthetic_v7():
+    cfg = ddgan_laion_aesthetic_v6()
+    cfg['model']['r1_gamma'] = 5
+    return cfg
+
+
+def ddgan_laion_aesthetic_v8():
+    cfg = ddgan_laion_aesthetic_v6()
+    cfg['model']['num_timesteps'] = 8
+    return cfg
+
+def ddgan_laion_aesthetic_v9():
+    cfg = ddgan_laion_aesthetic_v3()
+    cfg['model']['num_channels_dae'] = 384
+    return cfg
+
+def ddgan_sd_v1():
+    cfg = ddgan_laion_aesthetic_v3()
+    return cfg
+def ddgan_sd_v2():
+    cfg = ddgan_laion_aesthetic_v3()
+    return cfg
+def ddgan_sd_v3():
+    cfg = ddgan_laion_aesthetic_v3()
+    return cfg
+def ddgan_sd_v4():
+    cfg = ddgan_laion_aesthetic_v3()
+    return cfg
+def ddgan_sd_v5():
+    cfg = ddgan_laion_aesthetic_v3()
+    cfg['model']['num_timesteps'] = 8
+    return cfg
+def ddgan_sd_v6():
+    cfg = ddgan_laion_aesthetic_v3()
+    cfg['model']['num_channels_dae'] = 192
+    return cfg
+def ddgan_sd_v7():
+    cfg = ddgan_laion_aesthetic_v3()
+    return cfg
+def ddgan_sd_v8():
+    cfg = ddgan_laion_aesthetic_v3()
+    cfg['model']['image_size'] = 512
+    return cfg
+def ddgan_laion_aesthetic_v12():
+    cfg = ddgan_laion_aesthetic_v3()
+    return cfg
+def ddgan_laion_aesthetic_v13():
+    cfg = ddgan_laion_aesthetic_v3()
+    cfg['model']['text_encoder'] = "openclip/ViT-H-14/laion2b_s32b_b79k"
+    return cfg
+
+def ddgan_laion_aesthetic_v14():
+    cfg = ddgan_laion_aesthetic_v3()
+    cfg['model']['text_encoder'] = "openclip/ViT-H-14/laion2b_s32b_b79k"
+    return cfg
+def ddgan_sd_v9():
+    cfg = ddgan_laion_aesthetic_v3()
+    cfg['model']['text_encoder'] = "openclip/ViT-H-14/laion2b_s32b_b79k"
+    return cfg
+
 models = [
     ddgan_cifar10_cond17, # cifar10, cross attn for discr
     ddgan_cifar10_cond18, # cifar10, xl encoder
@@ -166,6 +247,23 @@ models = [
     ddgan_laion_aesthetic_v3, # like ddgan_laion_aesthetic_v1 but trained from scratch with T5-XL (continue from 23aug with mismatch and grad penalty and random_resized_crop_v1)
     ddgan_laion_aesthetic_v4, # like ddgan_laion_aesthetic_v1 but trained from scratch with OpenAI's ClipEncoder
     ddgan_laion_aesthetic_v5, # fine-tune ddgan_laion_aesthetic_v1 with mismatch and cond grad penalty losses
+    ddgan_laion_aesthetic_v6, # like v3 but without lr decay
+    ddgan_laion_aesthetic_v7, # like v6 but with r1 gamma of 5 instead of 1, trying to constrain the discr more.
+    ddgan_laion_aesthetic_v8, # like v6 but with 8 timesteps
+    ddgan_laion_aesthetic_v9,
+    ddgan_laion_aesthetic_v12,
+    ddgan_laion_aesthetic_v13,
+    ddgan_laion_aesthetic_v14,
+    ddgan_laion2b_v1,
+    ddgan_sd_v1,
+    ddgan_sd_v2,
+    ddgan_sd_v3,
+    ddgan_sd_v4,
+    ddgan_sd_v5,
+    ddgan_sd_v6,
+    ddgan_sd_v7,
+    ddgan_sd_v8,
+    ddgan_sd_v9,
 ]
 
 def get_model(model_name):
@@ -174,7 +272,7 @@ def get_model(model_name):
     return model()
 
 
-def test(model_name, *, cond_text="", batch_size:int=None, epoch:int=None, guidance_scale:float=0, fid=False, real_img_dir="", q=0.0, seed=0, nb_images_for_fid=0, scale_factor_h=1, scale_factor_w=1, compute_clip_score=False):
+def test(model_name, *, cond_text="", batch_size:int=None, epoch:int=None, guidance_scale:float=0, fid=False, real_img_dir="", q=0.0, seed=0, nb_images_for_fid=0, scale_factor_h=1, scale_factor_w=1, compute_clip_score=False, eval_name="", scale_method="convolutional"):
 
     cfg = get_model(model_name)
     model = cfg['model']
@@ -204,13 +302,15 @@ def test(model_name, *, cond_text="", batch_size:int=None, epoch:int=None, guida
     args['scale_factor_h'] = scale_factor_h
     args['scale_factor_w'] = scale_factor_w
     args['n_mlp'] = model.get("n_mlp")
+    args['scale_method'] = scale_method
     if fid:
         args['compute_fid'] = ''
         args['real_img_dir'] = real_img_dir
         args['nb_images_for_fid'] = nb_images_for_fid
     if compute_clip_score:
         args['compute_clip_score'] = ""
-
+    if eval_name:
+        args["eval_name"] = eval_name
     cmd = "python -u test_ddgan.py " + " ".join(f"--{k} {v}" for k, v in args.items() if v is not None)
     print(cmd)
     call(cmd, shell=True)
@@ -234,4 +334,4 @@ def eval_results(model_name):
 
 if __name__ == "__main__":
     from clize import run
-    run([test, eval_results])
+    run([test, eval_results])
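Editor's sketch of the command construction in test() above (behavior only, nothing new): options are collected into a dict and joined into a test_ddgan.py command line; flag-style options are stored as empty strings so only the flag is emitted, and None values are dropped. The file name below is hypothetical.

args = {"compute_fid": "", "real_img_dir": "stats.npz", "scale_method": "convolutional", "n_mlp": None}
cmd = "python -u test_ddgan.py " + " ".join(f"--{k} {v}" for k, v in args.items() if v is not None)
print(cmd)
# python -u test_ddgan.py --compute_fid  --real_img_dir stats.npz --scale_method convolutional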
scripts/init.sh CHANGED
@@ -32,6 +32,21 @@ if [[ "$machine" == juwelsbooster ]]; then
    ml torchvision/0.12.0
    source /p/project/covidnetx/environments/juwels_booster_2022/bin/activate
 fi
+if [[ "$machine" == hdfml ]]; then
+    export CUDA_VISIBLE_DEVICES=0,1,2,3
+    ml purge
+    ml use $OTHERSTAGES
+    ml Stages/2022
+    ml GCC/11.2.0
+    ml OpenMPI/4.1.2
+    ml CUDA/11.5
+    ml cuDNN/8.3.1.22-CUDA-11.5
+    ml NCCL/2.12.7-1-CUDA-11.5
+    ml PyTorch/1.11-CUDA-11.5
+    ml Horovod/0.24
+    ml torchvision/0.12.0
+    source envs/hdfml/bin/activate
+fi
 if [[ "$machine" == jusuf ]]; then
    echo not supported
 fi
scripts/run_hdfml.sh ADDED
@@ -0,0 +1,25 @@
+#!/bin/bash -x
+#SBATCH --account=cstdl
+#SBATCH --nodes=8
+#SBATCH --ntasks-per-node=4
+#SBATCH --cpus-per-task=8
+#SBATCH --time=06:00:00
+#SBATCH --gres=gpu
+#SBATCH --partition=batch
+ml purge
+ml use $OTHERSTAGES
+ml Stages/2022
+ml GCC/11.2.0
+ml OpenMPI/4.1.2
+ml CUDA/11.5
+ml cuDNN/8.3.1.22-CUDA-11.5
+ml NCCL/2.12.7-1-CUDA-11.5
+ml PyTorch/1.11-CUDA-11.5
+ml Horovod/0.24
+ml torchvision/0.12.0
+source envs/hdfml/bin/activate
+export CUDA_VISIBLE_DEVICES=0,1,2,3
+echo "Job id: $SLURM_JOB_ID"
+export TOKENIZERS_PARALLELISM=false
+export NCCL_ASYNC_ERROR_HANDLING=1
+srun python -u $*
scripts/run_jurecadc_ddp.sh CHANGED
@@ -13,5 +13,8 @@ source scripts/init.sh
 export CUDA_VISIBLE_DEVICES=0,1,2,3
 echo "Job id: $SLURM_JOB_ID"
 export TOKENIZERS_PARALLELISM=false
-export NCCL_ASYNC_ERROR_HANDLING=1
+#export NCCL_ASYNC_ERROR_HANDLING=1
+export NCCL_IB_TIMEOUT=50
+export UCX_RC_TIMEOUT=4s
+export NCCL_IB_RETRY_CNT=10
 srun python -u $*
test_ddgan.py CHANGED
@@ -86,7 +86,18 @@ class Posterior_Coefficients():
        self.posterior_mean_coef2 = ((1 - self.alphas_cumprod_prev) * torch.sqrt(self.alphas) / (1 - self.alphas_cumprod))
 
        self.posterior_log_variance_clipped = torch.log(self.posterior_variance.clamp(min=1e-20))
-
+
+def predict_q_posterior(coefficients, x_0, x_t, t):
+    mean = (
+        extract(coefficients.posterior_mean_coef1, t, x_t.shape) * x_0
+        + extract(coefficients.posterior_mean_coef2, t, x_t.shape) * x_t
+    )
+    var = extract(coefficients.posterior_variance, t, x_t.shape)
+    log_var_clipped = extract(coefficients.posterior_log_variance_clipped, t, x_t.shape)
+    return mean, var, log_var_clipped
+
+
+
 def sample_posterior(coefficients, x_0, x_t, t):
 
    def q_posterior(x_0, x_t, t):
@@ -150,10 +161,10 @@ def sample_from_model_classifier_free_guidance(coefficients, generator, n_time,
            # eps = eps_uncond + guidance_scale * (eps_cond - eps_uncond)
            eps = eps_uncond * (1 - guidance_scale) + eps_cond * guidance_scale
            x_0 = (1/torch.sqrt(coefficients.alphas_cumprod[i])) * (x - torch.sqrt(1 - coefficients.alphas_cumprod[i]) * eps)
-
+            #x_0 = x_0_uncond * (1 - guidance_scale) + x_0_cond * guidance_scale
 
            # Dynamic thresholding
-            q = args.dynamic_thresholding_quantile
+            q = opt.dynamic_thresholding_quantile
            #print("Before", x_0.min(), x_0.max())
            if q:
                shape = x_0.shape
@@ -180,9 +191,174 @@
    return x
 
 
+def sample_from_model_classifier_free_guidance_convolutional(coefficients, generator, n_time, x_init, T, opt, text_encoder, cond=None, guidance_scale=0, split_input_params=None):
+    x = x_init
+    null = text_encoder([""] * len(x_init), return_only_pooled=False)
+    #latent_z = torch.randn(x.size(0), opt.nz, device=x.device)
+    ks = split_input_params["ks"]  # eg. (128, 128)
+    stride = split_input_params["stride"]  # eg. (64, 64)
+    uf = split_input_params["vqf"]
+    with torch.no_grad():
+        for i in reversed(range(n_time)):
+            t = torch.full((x.size(0),), i, dtype=torch.int64).to(x.device)
+            t_time = t
+            latent_z = torch.randn(x.size(0), opt.nz, device=x.device)
+
+            fold, unfold, normalization, weighting = get_fold_unfold(x, ks, stride, split_input_params, uf=uf)
+            x = unfold(x)
+            x = x.view((x.shape[0], -1, ks[0], ks[1], x.shape[-1]))
+            x_new_list = []
+            for j in range(x.shape[-1]):
+                x_0_uncond = generator(x[:,:,:,:,j], t_time, latent_z, cond=null)
+                x_0_cond = generator(x[:,:,:,:,j], t_time, latent_z, cond=cond)
+
+                eps_uncond = (x[:,:,:,:,j] - torch.sqrt(coefficients.alphas_cumprod[i]) * x_0_uncond) / torch.sqrt(1 - coefficients.alphas_cumprod[i])
+                eps_cond = (x[:,:,:,:,j] - torch.sqrt(coefficients.alphas_cumprod[i]) * x_0_cond) / torch.sqrt(1 - coefficients.alphas_cumprod[i])
+
+                eps = eps_uncond * (1 - guidance_scale) + eps_cond * guidance_scale
+                x_0 = (1/torch.sqrt(coefficients.alphas_cumprod[i])) * (x[:,:,:,:,j] - torch.sqrt(1 - coefficients.alphas_cumprod[i]) * eps)
+                q = args.dynamic_thresholding_quantile
+                if q:
+                    shape = x_0.shape
+                    x_0_v = x_0.view(shape[0], -1)
+                    d = torch.quantile(torch.abs(x_0_v), q, dim=1, keepdim=True)
+                    d.clamp_(min=1)
+                    x_0_v = x_0_v.clamp(-d, d) / d
+                    x_0 = x_0_v.view(shape)
+                x_new = sample_posterior(coefficients, x_0, x[:,:,:,:,j], t)
+                x_new_list.append(x_new)
+
+            o = torch.stack(x_new_list, axis=-1)
+            #o = o * weighting
+            o = o.view((o.shape[0], -1, o.shape[-1]))
+            decoded = fold(o)
+            decoded = decoded / normalization
+            x = decoded.detach()
+
+    return x
+
+def sample_from_model_clip_guidance(coefficients, generator, clip_model, n_time, x_init, T, opt, texts, cond=None, guidance_scale=0):
+    x = x_init
+    text_features = torch.nn.functional.normalize(clip_model.forward_text(texts), dim=1)
+    n_time = 16
+    for i in reversed(range(n_time)):
+        t = torch.full((x.size(0),), i%4, dtype=torch.int64).to(x.device)
+        t_time = t
+        latent_z = torch.randn(x.size(0), opt.nz, device=x.device)
+        x.requires_grad = True
+        x_0 = generator(x, t_time, latent_z, cond=cond)
+        x_new = sample_posterior(coefficients, x_0, x, t)
+        x_new_n = (x_new + 1) / 2
+        image_features = torch.nn.functional.normalize(clip_model.forward_image(x_new_n), dim=1)
+        loss = (image_features*text_features).sum(dim=1).mean()
+        x_grad, = torch.autograd.grad(loss, x)
+        lr = 3000
+        x = x.detach()
+        print(x.min(), x.max(), lr*x_grad.min(), lr*x_grad.max())
+        x += x_grad * lr
+
+        with torch.no_grad():
+            x_0 = generator(x, t_time, latent_z, cond=cond)
+            x_new = sample_posterior(coefficients, x_0, x, t)
+
+        x = x_new.detach()
+        print(i)
+    return x
+
+def meshgrid(h, w):
+    y = torch.arange(0, h).view(h, 1, 1).repeat(1, w, 1)
+    x = torch.arange(0, w).view(1, w, 1).repeat(h, 1, 1)
+
+    arr = torch.cat([y, x], dim=-1)
+    return arr
+def delta_border(h, w):
+    """
+    :param h: height
+    :param w: width
+    :return: normalized distance to image border,
+     with min distance = 0 at border and max dist = 0.5 at image center
+    """
+    lower_right_corner = torch.tensor([h - 1, w - 1]).view(1, 1, 2)
+    arr = meshgrid(h, w) / lower_right_corner
+    dist_left_up = torch.min(arr, dim=-1, keepdims=True)[0]
+    dist_right_down = torch.min(1 - arr, dim=-1, keepdims=True)[0]
+    edge_dist = torch.min(torch.cat([dist_left_up, dist_right_down], dim=-1), dim=-1)[0]
+    return edge_dist
+
+def get_weighting(h, w, Ly, Lx, device, split_input_params):
+    weighting = delta_border(h, w)
+    weighting = torch.clip(weighting, split_input_params["clip_min_weight"],
+                           split_input_params["clip_max_weight"], )
+    weighting = weighting.view(1, h * w, 1).repeat(1, 1, Ly * Lx).to(device)
+
+    if split_input_params["tie_braker"]:
+        L_weighting = delta_border(Ly, Lx)
+        L_weighting = torch.clip(L_weighting,
+                                 split_input_params["clip_min_tie_weight"],
+                                 split_input_params["clip_max_tie_weight"])
+
+        L_weighting = L_weighting.view(1, 1, Ly * Lx).to(device)
+        weighting = weighting * L_weighting
+    return weighting
+
+def get_fold_unfold(x, kernel_size, stride, split_input_params, uf=1, df=1):  # todo load once not every time, shorten code
+    """
+    :param x: img of size (bs, c, h, w)
+    :return: n img crops of size (n, bs, c, kernel_size[0], kernel_size[1])
+    """
+    bs, nc, h, w = x.shape
+
+    # number of crops in image
+    Ly = (h - kernel_size[0]) // stride[0] + 1
+    Lx = (w - kernel_size[1]) // stride[1] + 1
+
+    if uf == 1 and df == 1:
+        fold_params = dict(kernel_size=kernel_size, dilation=1, padding=0, stride=stride)
+        unfold = torch.nn.Unfold(**fold_params)
+
+        fold = torch.nn.Fold(output_size=x.shape[2:], **fold_params)
+
+        weighting = get_weighting(kernel_size[0], kernel_size[1], Ly, Lx, x.device, split_input_params).to(x.dtype)
+        normalization = fold(weighting).view(1, 1, h, w)  # normalizes the overlap
+        weighting = weighting.view((1, 1, kernel_size[0], kernel_size[1], Ly * Lx))
+
+    elif uf > 1 and df == 1:
+        fold_params = dict(kernel_size=kernel_size, dilation=1, padding=0, stride=stride)
+        unfold = torch.nn.Unfold(**fold_params)
+
+        fold_params2 = dict(kernel_size=(kernel_size[0] * uf, kernel_size[0] * uf),
+                            dilation=1, padding=0,
+                            stride=(stride[0] * uf, stride[1] * uf))
+        fold = torch.nn.Fold(output_size=(x.shape[2] * uf, x.shape[3] * uf), **fold_params2)
+
+        weighting = get_weighting(kernel_size[0] * uf, kernel_size[1] * uf, Ly, Lx, x.device, split_input_params).to(x.dtype)
+        normalization = fold(weighting).view(1, 1, h * uf, w * uf)  # normalizes the overlap
+        weighting = weighting.view((1, 1, kernel_size[0] * uf, kernel_size[1] * uf, Ly * Lx))
+
+    elif df > 1 and uf == 1:
+        fold_params = dict(kernel_size=kernel_size, dilation=1, padding=0, stride=stride)
+        unfold = torch.nn.Unfold(**fold_params)
+
+        fold_params2 = dict(kernel_size=(kernel_size[0] // df, kernel_size[0] // df),
+                            dilation=1, padding=0,
+                            stride=(stride[0] // df, stride[1] // df))
+        fold = torch.nn.Fold(output_size=(x.shape[2] // df, x.shape[3] // df), **fold_params2)
+
+        weighting = get_weighting(kernel_size[0] // df, kernel_size[1] // df, Ly, Lx, x.device, split_input_params).to(x.dtype)
+        normalization = fold(weighting).view(1, 1, h // df, w // df)  # normalizes the overlap
+        weighting = weighting.view((1, 1, kernel_size[0] // df, kernel_size[1] // df, Ly * Lx))
+
+    else:
+        raise NotImplementedError
+
+    return fold, unfold, normalization, weighting
+
+
+
 #%%
 def sample_and_test(args):
    torch.manual_seed(args.seed)
+
    device = 'cuda:0'
    text_encoder =build_encoder(name=args.text_encoder, masked_mean=args.masked_mean).to(device)
    args.cond_size = text_encoder.output_size
@@ -197,10 +373,9 @@ def sample_and_test(args):
 
    to_range_0_1 = lambda x: (x + 1.) / 2.
 
-
+    print(vars(args))
    netG = NCSNpp(args).to(device)
-    netG.attn_resolutions = [r * args.scale_factor_w for r in netG.attn_resolutions]
-
+
    if args.epoch_id == -1:
        epochs = range(1000)
    else:
@@ -209,17 +384,27 @@
    for epoch in epochs:
        args.epoch_id = epoch
        path = './saved_info/dd_gan/{}/{}/netG_{}.pth'.format(args.dataset, args.exp, args.epoch_id)
+        next_path = './saved_info/dd_gan/{}/{}/netG_{}.pth'.format(args.dataset, args.exp, args.epoch_id+1)
        if not os.path.exists(path):
            continue
+        print(path)
+
+        #if not os.path.exists(next_path):
+        #    print(f"STOP at {epoch}")
+        #    break
        ckpt = torch.load(path, map_location=device)
-        dest = './saved_info/dd_gan/{}/{}/eval_{}.json'.format(args.dataset, args.exp, args.epoch_id)
+        suffix = '_' + args.eval_name if args.eval_name else ""
+        dest = './saved_info/dd_gan/{}/{}/eval_{}{}.json'.format(args.dataset, args.exp, args.epoch_id, suffix)
+        next_dest = './saved_info/dd_gan/{}/{}/eval_{}{}.json'.format(args.dataset, args.exp, args.epoch_id+1, suffix)
 
-        if args.compute_fid and os.path.exists(dest):
+        if (args.compute_fid or args.compute_clip_score) and os.path.exists(dest):
            continue
        print("Eval Epoch", args.epoch_id)
        #loading weights from ddp in single gpu
+        #print(ckpt.keys())
        for key in list(ckpt.keys()):
-            ckpt[key[7:]] = ckpt.pop(key)
+            if key.startswith("module"):
+                ckpt[key[7:]] = ckpt.pop(key)
        netG.load_state_dict(ckpt)
        netG.eval()
 
@@ -234,7 +419,7 @@ def sample_and_test(args):
        if not os.path.exists(save_dir):
            os.makedirs(save_dir)
 
-        if args.compute_fid:
+        if args.compute_fid or args.compute_clip_score:
            from torch.nn.functional import adaptive_avg_pool2d
            from pytorch_fid.fid_score import calculate_activation_statistics, calculate_fid_given_paths, ImagePathDataset, compute_statistics_of_path, calculate_frechet_distance
            from pytorch_fid.inception import InceptionV3
@@ -252,9 +437,11 @@ def sample_and_test(args):
            print("Text size:", len(texts))
            #print("Iters:", iters_needed)
            i = 0
-            dims = 2048
-            block_idx = InceptionV3.BLOCK_INDEX_BY_DIM[dims]
-            inceptionv3 = InceptionV3([block_idx]).to(device)
+
+            if args.compute_fid:
+                dims = 2048
+                block_idx = InceptionV3.BLOCK_INDEX_BY_DIM[dims]
+                inceptionv3 = InceptionV3([block_idx]).to(device)
 
            if args.compute_clip_score:
                import clip
@@ -264,19 +451,20 @@ def sample_and_test(args):
                clip_mean = torch.Tensor(CLIP_MEAN).view(1,-1,1,1).to(device)
                clip_std = torch.Tensor(CLIP_STD).view(1,-1,1,1).to(device)
 
-
-            if not args.real_img_dir.endswith("npz"):
-                real_mu, real_sigma = compute_statistics_of_path(
-                    args.real_img_dir, inceptionv3, args.batch_size, dims, device,
-                    resize=args.image_size,
-                )
-                np.savez("inception_statistics.npz", mu=real_mu, sigma=real_sigma)
-            else:
-                stats = np.load(args.real_img_dir)
-                real_mu = stats['mu']
-                real_sigma = stats['sigma']
-
-            fake_features = []
+            if args.compute_fid:
+                if not args.real_img_dir.endswith("npz"):
+                    real_mu, real_sigma = compute_statistics_of_path(
+                        args.real_img_dir, inceptionv3, args.batch_size, dims, device,
+                        resize=args.image_size,
+                    )
+                    np.savez("inception_statistics.npz", mu=real_mu, sigma=real_sigma)
+                else:
+                    stats = np.load(args.real_img_dir)
+                    real_mu = stats['mu']
+                    real_sigma = stats['sigma']
+
+                fake_features = []
+
            if args.compute_clip_score:
                clip_scores = []
 
@@ -287,7 +475,6 @@ def sample_and_test(args):
                bs = len(text)
                t0 = time.time()
                x_t_1 = torch.randn(bs, args.num_channels,args.image_size, args.image_size).to(device)
-                #print(x_t_1.shape)
                if args.guidance_scale:
                    fake_sample = sample_from_model_classifier_free_guidance(pos_coeff, netG, args.num_timesteps, x_t_1,T, args, text_encoder, cond=cond, guidance_scale=args.guidance_scale)
                else:
@@ -298,45 +485,39 @@ def sample_and_test(args):
                    index = i * args.batch_size + j
                    torchvision.utils.save_image(x, './generated_samples/{}/{}.jpg'.format(args.dataset, index))
                """
-                with torch.no_grad():
-                    pred = inceptionv3(fake_sample)[0]
-                    # If model output is not scalar, apply global spatial average pooling.
-                    # This happens if you choose a dimensionality not equal 2048.
-                    if pred.size(2) != 1 or pred.size(3) != 1:
-                        pred = adaptive_avg_pool2d(pred, output_size=(1, 1))
-                    pred = pred.squeeze(3).squeeze(2).cpu().numpy()
-                fake_features.append(pred)
+
+                if args.compute_fid:
+                    with torch.no_grad():
+                        pred = inceptionv3(fake_sample)[0]
+                        # If model output is not scalar, apply global spatial average pooling.
+                        # This happens if you choose a dimensionality not equal 2048.
+                        if pred.size(2) != 1 or pred.size(3) != 1:
+                            pred = adaptive_avg_pool2d(pred, output_size=(1, 1))
+                        pred = pred.squeeze(3).squeeze(2).cpu().numpy()
+                    fake_features.append(pred)
 
                if args.compute_clip_score:
                    with torch.no_grad():
                        clip_ims = torch.nn.functional.interpolate(fake_sample, (224, 224), mode="bicubic")
-                        clip_txt = clip.tokenize(text).to(device)
+                        clip_ims = (clip_ims - clip_mean) / clip_std
+                        clip_txt = clip.tokenize(text, truncate=True).to(device)
                        imf = clip_model.encode_image(clip_ims)
                        txtf = clip_model.encode_text(clip_txt)
                        imf = torch.nn.functional.normalize(imf, dim=1)
                        txtf = torch.nn.functional.normalize(txtf, dim=1)
                        clip_scores.append(((imf * txtf).sum(dim=1)).cpu())
-                break
+
                if i % 10 == 0:
-                    print('generating batch ', i, time.time() - t0)
-                """
-                if i % 10 == 0:
-                    ff = np.concatenate(fake_features)
-                    fake_mu = np.mean(ff, axis=0)
-                    fake_sigma = np.cov(ff, rowvar=False)
-                    fid = calculate_frechet_distance(real_mu, real_sigma, fake_mu, fake_sigma)
-                    print("FID", fid)
-                """
+                    print('evaluating batch ', i, time.time() - t0)
                i += 1
 
-            fake_features = np.concatenate(fake_features)
-            fake_mu = np.mean(fake_features, axis=0)
-            fake_sigma = np.cov(fake_features, rowvar=False)
-            fid = calculate_frechet_distance(real_mu, real_sigma, fake_mu, fake_sigma)
-            dest = './saved_info/dd_gan/{}/{}/eval_{}.json'.format(args.dataset, args.exp, args.epoch_id)
-            results = {
-                "fid": fid,
-            }
+            results = {}
+            if args.compute_fid:
+                fake_features = np.concatenate(fake_features)
+                fake_mu = np.mean(fake_features, axis=0)
+                fake_sigma = np.cov(fake_features, rowvar=False)
+                fid = calculate_frechet_distance(real_mu, real_sigma, fake_mu, fake_sigma)
+                results['fid'] = fid
            if args.compute_clip_score:
                clip_score = torch.cat(clip_scores).mean().item()
                results['clip_score'] = clip_score
@@ -344,22 +525,54 @@ def sample_and_test(args):
            with open(dest, "w") as fd:
                json.dump(results, fd)
            print(results)
-    else:
+        else:
            if args.cond_text.endswith(".txt"):
                texts = open(args.cond_text).readlines()
                texts = [t.strip() for t in texts]
            else:
                texts = [args.cond_text] * args.batch_size
-            cond = text_encoder(texts, return_only_pooled=False)
-            x_t_1 = torch.randn(len(texts), args.num_channels,args.image_size*args.scale_factor_h, args.image_size*args.scale_factor_w).to(device)
-            t0 = time.time()
-            if args.guidance_scale:
-                fake_sample = sample_from_model_classifier_free_guidance(pos_coeff, netG, args.num_timesteps, x_t_1,T, args, text_encoder, cond=cond, guidance_scale=args.guidance_scale)
+            clip_guidance = False
+            if clip_guidance:
+                from clip_encoder import CLIPImageEncoder
+                cond = text_encoder(texts, return_only_pooled=False)
+                clip_image_model = CLIPImageEncoder().to(device)
+                x_t_1 = torch.randn(len(texts), args.num_channels,args.image_size*args.scale_factor_h, args.image_size*args.scale_factor_w).to(device)
+                fake_sample = sample_from_model_clip_guidance(pos_coeff, netG, clip_image_model, args.num_timesteps, x_t_1,T, args, texts, cond=cond, guidance_scale=args.guidance_scale)
+                fake_sample = to_range_0_1(fake_sample)
+                torchvision.utils.save_image(fake_sample, './samples_{}.jpg'.format(args.dataset))
+
            else:
-                fake_sample = sample_from_model(pos_coeff, netG, args.num_timesteps, x_t_1,T, args, cond=cond)
-            print(time.time() - t0)
-            fake_sample = to_range_0_1(fake_sample)
-            torchvision.utils.save_image(fake_sample, './samples_{}.jpg'.format(args.dataset))
+                cond = text_encoder(texts, return_only_pooled=False)
+                x_t_1 = torch.randn(len(texts), args.num_channels,args.image_size*args.scale_factor_h, args.image_size*args.scale_factor_w).to(device)
+                t0 = time.time()
+                if args.guidance_scale:
+                    if args.scale_factor_h > 1 or args.scale_factor_w > 1:
+                        if args.scale_method == "convolutional":
+                            split_input_params = {
+                                "ks": (args.image_size, args.image_size),
+                                "stride": (150, 150),
+                                "clip_max_tie_weight": 0.5,
+                                "clip_min_tie_weight": 0.01,
+                                "clip_max_weight": 0.5,
+                                "clip_min_weight": 0.01,
+
+                                "tie_braker": True,
+                                'vqf': 1,
+                            }
+                            fake_sample = sample_from_model_classifier_free_guidance_convolutional(pos_coeff, netG, args.num_timesteps, x_t_1,T, args, text_encoder, cond=cond, guidance_scale=args.guidance_scale, split_input_params=split_input_params)
+                        elif args.scale_method == "larger_input":
+                            netG.attn_resolutions = [r * args.scale_factor_w for r in netG.attn_resolutions]
+                            fake_sample = sample_from_model_classifier_free_guidance(pos_coeff, netG, args.num_timesteps, x_t_1,T, args, text_encoder, cond=cond, guidance_scale=args.guidance_scale)
+                    else:
+                        fake_sample = sample_from_model_classifier_free_guidance(pos_coeff, netG, args.num_timesteps, x_t_1,T, args, text_encoder, cond=cond, guidance_scale=args.guidance_scale)
+                else:
+                    fake_sample = sample_from_model(pos_coeff, netG, args.num_timesteps, x_t_1,T, args, cond=cond)
+
+                print(time.time() - t0)
+                fake_sample = to_range_0_1(fake_sample)
+                torchvision.utils.save_image(fake_sample, './samples_{}.jpg'.format(args.dataset))
+
+
 
 
 
@@ -374,6 +587,7 @@ if __name__ == '__main__':
    parser.add_argument('--compute_clip_score', action='store_true', default=False,
                        help='whether or not compute CLIP score')
    parser.add_argument('--clip_model', type=str,default="ViT-L/14")
+    parser.add_argument('--eval_name', type=str,default="")
 
    parser.add_argument('--epoch_id', type=int,default=1000)
    parser.add_argument('--guidance_scale', type=float,default=0)
@@ -381,6 +595,8 @@
    parser.add_argument('--cond_text', type=str,default="0")
    parser.add_argument('--scale_factor_h', type=int,default=1)
    parser.add_argument('--scale_factor_w', type=int,default=1)
+    parser.add_argument('--scale_method', type=str,default="convolutional")
+
    parser.add_argument('--cross_attention', action='store_true',default=False)
 
 
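Editor's sketch (not part of the commit) of the fold/unfold mechanics behind get_fold_unfold and sample_from_model_classifier_free_guidance_convolutional: the image is split into overlapping crops, each crop is denoised independently, and the crops are folded back and divided by a normalization map so overlapping pixels are averaged. The sketch uses uniform weighting instead of the delta_border weighting above.

import torch

x = torch.randn(1, 3, 256, 256)
ks, stride = (128, 128), (64, 64)
unfold = torch.nn.Unfold(kernel_size=ks, stride=stride)
fold = torch.nn.Fold(output_size=x.shape[2:], kernel_size=ks, stride=stride)

patches = unfold(x)                          # (1, 3*128*128, L): one column per crop
patches = patches.view(1, 3, ks[0], ks[1], -1)
# ... per-crop denoising would happen here ...
cols = patches.view(1, -1, patches.shape[-1])
recon = fold(cols)                           # overlapping regions are summed
norm = fold(unfold(torch.ones_like(x)))      # how many crops cover each pixel
recon = recon / norm                         # average instead of sum
print(torch.allclose(recon, x, atol=1e-5))   # True: the round-trip recovers the input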
train_ddgan.py CHANGED
@@ -5,7 +5,7 @@
 # for Denoising Diffusion GAN. To view a copy of this license, see the LICENSE file.
 # ---------------------------------------------------------------
 
-
+from glob import glob
 import argparse
 import torch
 import numpy as np
@@ -30,6 +30,7 @@ import shutil
 import logging
 from encoder import build_encoder
 from utils import ResampledShards2
+from torch.utils.tensorboard import SummaryWriter
 
 
 def log_and_continue(exn):
@@ -194,23 +195,29 @@ def sample_from_model(coefficients, generator, n_time, x_init, T, opt, cond=None
 
    return x
 
-
+from contextlib import suppress
 
 def filter_no_caption(sample):
    return 'txt' in sample
 
-
+def get_autocast(precision):
+    if precision == 'amp':
+        return torch.cuda.amp.autocast
+    elif precision == 'amp_bfloat16':
+        return lambda: torch.cuda.amp.autocast(dtype=torch.bfloat16)
+    else:
+        return suppress
 
 def train(rank, gpu, args):
    from score_sde.models.discriminator import Discriminator_small, Discriminator_large, CondAttnDiscriminator, SmallCondAttnDiscriminator
    from score_sde.models.ncsnpp_generator_adagn import NCSNpp
    from EMA import EMA
 
-    torch.manual_seed(args.seed + rank)
-    torch.cuda.manual_seed(args.seed + rank)
-    torch.cuda.manual_seed_all(args.seed + rank)
+    #torch.manual_seed(args.seed + rank)
+    #torch.cuda.manual_seed(args.seed + rank)
+    #torch.cuda.manual_seed_all(args.seed + rank)
    device = "cuda"
-
+    autocast = get_autocast(args.precision)
    batch_size = args.batch_size
 
    nz = args.nz #latent dimension
@@ -270,11 +277,12 @@ def train(rank, gpu, args):
        ])
    elif args.preprocessing == "random_resized_crop_v1":
        train_transform = transforms.Compose([
-            transforms.RandomResizedCrop(256, scale=(0.95, 1.0), interpolation=3),
+            transforms.RandomResizedCrop(args.image_size, scale=(0.95, 1.0), interpolation=3),
            transforms.ToTensor(),
            transforms.Normalize((0.5,0.5,0.5), (0.5,0.5,0.5))
        ])
-    pipeline = [ResampledShards2(args.dataset_root)]
+    shards = glob(os.path.join(args.dataset_root, "*.tar")) if os.path.isdir(args.dataset_root) else args.dataset_root
+    pipeline = [ResampledShards2(shards)]
    pipeline.extend([
        wds.split_by_node,
        wds.split_by_worker,
@@ -339,6 +347,13 @@ def train(rank, gpu, args):
                    t_emb_dim = args.t_emb_dim,
                    cond_size=text_encoder.output_size,
                    act=nn.LeakyReLU(0.2)).to(device)
+    elif args.discr_type == "large_attn_pool":
+        netD = Discriminator_large(nc = 2*args.num_channels, ngf = args.ngf,
+                    t_emb_dim = args.t_emb_dim,
+                    cond_size=text_encoder.output_size,
+                    attn_pool=True,
+                    act=nn.LeakyReLU(0.2)).to(device)
+
    elif args.discr_type == "large_cond_attn":
        netD = CondAttnDiscriminator(
            nc = 2*args.num_channels,
@@ -350,6 +365,15 @@ def train(rank, gpu, args):
    broadcast_params(netG.parameters())
    broadcast_params(netD.parameters())
 
+    if args.fsdp:
+        from fairscale.nn.checkpoint.checkpoint_activations import checkpoint_wrapper
+        from fairscale.nn.data_parallel import FullyShardedDataParallel as FSDP
+        netG = FSDP(
+            netG,
+            flatten_parameters=True,
+            verbose=True,
+        )
+
    optimizerD = optim.Adam(netD.parameters(), lr=args.lr_d, betas = (args.beta1, args.beta2))
    optimizerG = optim.Adam(netG.parameters(), lr=args.lr_g, betas = (args.beta1, args.beta2))
 
@@ -358,9 +382,16 @@ def train(rank, gpu, args):
 
    schedulerG = torch.optim.lr_scheduler.CosineAnnealingLR(optimizerG, args.num_epoch, eta_min=1e-5)
    schedulerD = torch.optim.lr_scheduler.CosineAnnealingLR(optimizerD, args.num_epoch, eta_min=1e-5)
+
+    if args.fsdp:
+        netD = nn.parallel.DistributedDataParallel(netD, device_ids=[gpu])
+    else:
+        netG = nn.parallel.DistributedDataParallel(netG, device_ids=[gpu])
+        netD = nn.parallel.DistributedDataParallel(netD, device_ids=[gpu])
 
-    netG = nn.parallel.DistributedDataParallel(netG, device_ids=[gpu])
-    netD = nn.parallel.DistributedDataParallel(netD, device_ids=[gpu])
+    if args.grad_checkpointing:
+        from fairscale.nn.checkpoint.checkpoint_activations import checkpoint_wrapper
+        netG = checkpoint_wrapper(netG)
 
    exp = args.exp
    parent_dir = "./saved_info/dd_gan/{}".format(args.dataset)
@@ -377,6 +408,10 @@ def train(rank, gpu, args):
    T = get_time_schedule(args, device)
 
    checkpoint_file = os.path.join(exp_path, 'content.pth')
+
+    if rank == 0:
+        log_writer = SummaryWriter(exp_path)
+
    if args.resume and os.path.exists(checkpoint_file):
        checkpoint = torch.load(checkpoint_file, map_location="cpu")
        init_epoch = checkpoint['epoch']
@@ -395,7 +430,7 @@ def train(rank, gpu, args):
              .format(checkpoint['epoch']))
    else:
        global_step, epoch, init_epoch = 0, 0, 0
-    use_cond_attn_discr = args.discr_type in ("large_cond_attn", "small_cond_attn")
+    use_cond_attn_discr = args.discr_type in ("large_cond_attn", "small_cond_attn", "large_attn_pool")
    for epoch in range(init_epoch, args.num_epoch+1):
        if args.dataset == "wds":
            os.environ["WDS_EPOCH"] = str(epoch)
@@ -403,6 +438,7 @@ def train(rank, gpu, args):
            train_sampler.set_epoch(epoch)
 
        for iteration, (x, y) in enumerate(data_loader):
+            #print(x.shape)
            if args.dataset != "wds":
                y = [str(yi) for yi in y.tolist()]
 
@@ -437,15 +473,15 @@ def train(rank, gpu, args):
                cond_for_discr.requires_grad = True
 
            # train with real
-            D_real = netD(x_t, t, x_tp1.detach(), cond=cond_for_discr).view(-1)
-
-            errD_real = F.softplus(-D_real)
-            errD_real = errD_real.mean()
+            with autocast():
+                D_real = netD(x_t, t, x_tp1.detach(), cond=cond_for_discr).view(-1)
+                errD_real = F.softplus(-D_real)
+                errD_real = errD_real.mean()
 
 
            errD_real.backward(retain_graph=True)
 
-
+            grad_penalty = None
            if args.lazy_reg is None:
                if args.grad_penalty_cond:
                    inputs = (x_t,) + (cond,) if use_cond_attn_discr else (cond_for_discr,)
@@ -491,26 +527,36 @@ def train(rank, gpu, args):
 
            # train with fake
            latent_z = torch.randn(batch_size, nz, device=device)
-
-            x_0_predict = netG(x_tp1.detach(), t, latent_z, cond=(cond_pooled, cond, cond_mask))
-            x_pos_sample = sample_posterior(pos_coeff, x_0_predict, x_tp1, t)
-
-            output = netD(x_pos_sample, t, x_tp1.detach(), cond=cond_for_discr).view(-1)
+            with autocast():
+                if args.grad_checkpointing:
+                    ginp = x_tp1.detach()
+                    ginp.requires_grad = True
+                    latent_z.requires_grad = True
+                    cond_pooled.requires_grad = True
+                    cond.requires_grad = True
+                    #cond_mask.requires_grad = True
+                    x_0_predict = netG(ginp, t, latent_z, cond=(cond_pooled, cond, cond_mask))
+                else:
+                    x_0_predict = netG(x_tp1.detach(), t, latent_z, cond=(cond_pooled, cond, cond_mask))
+                x_pos_sample = sample_posterior(pos_coeff, x_0_predict, x_tp1, t)
 
-
-            errD_fake = F.softplus(output)
-            errD_fake = errD_fake.mean()
+                output = netD(x_pos_sample, t, x_tp1.detach(), cond=cond_for_discr).view(-1)
+
+
+                errD_fake = F.softplus(output)
+                errD_fake = errD_fake.mean()
 
            if args.mismatch_loss:
                # following https://github.com/tobran/DF-GAN/blob/bc38a4f795c294b09b4ef5579cd4ff78807e5b96/code/lib/modules.py,
                # we add a discr loss for (real image, non matching text)
                #inds = torch.flip(torch.arange(len(x_t)), dims=(0,))
-                inds = torch.cat([torch.arange(1,len(x_t)),torch.arange(1)])
-                cond_for_discr_mis = (cond_pooled[inds], cond[inds], cond_mask[inds]) if use_cond_attn_discr else cond_pooled[inds]
-                D_real_mis = netD(x_t, t, x_tp1.detach(), cond=cond_for_discr_mis).view(-1)
-                errD_real_mis = F.softplus(D_real_mis)
-                errD_real_mis = errD_real_mis.mean()
-                errD_fake = errD_fake * 0.5 + errD_real_mis * 0.5
+                with autocast():
+                    inds = torch.cat([torch.arange(1,len(x_t)),torch.arange(1)])
+                    cond_for_discr_mis = (cond_pooled[inds], cond[inds], cond_mask[inds]) if use_cond_attn_discr else cond_pooled[inds]
+                    D_real_mis = netD(x_t, t, x_tp1.detach(), cond=cond_for_discr_mis).view(-1)
+                    errD_real_mis = F.softplus(D_real_mis)
+                    errD_real_mis = errD_real_mis.mean()
+                    errD_fake = errD_fake * 0.5 + errD_real_mis * 0.5
 
            errD_fake.backward()
 
@@ -534,58 +580,106 @@ def train(rank, gpu, args):
 
            latent_z = torch.randn(batch_size, nz,device=device)
 
-
-            x_0_predict = netG(x_tp1.detach(), t, latent_z, cond=(cond_pooled, cond, cond_mask))
-            x_pos_sample = sample_posterior(pos_coeff, x_0_predict, x_tp1, t)
-
-            output = netD(x_pos_sample, t, x_tp1.detach(), cond=cond_for_discr).view(-1)
-
-
-            errG = F.softplus(-output)
-            errG = errG.mean()
+            with autocast():
+                if args.grad_checkpointing:
+                    ginp = x_tp1.detach()
+                    ginp.requires_grad = True
+                    latent_z.requires_grad = True
+                    cond_pooled.requires_grad = True
+                    cond.requires_grad = True
+                    #cond_mask.requires_grad = True
+                    x_0_predict = netG(ginp, t, latent_z, cond=(cond_pooled, cond, cond_mask))
+                else:
+                    x_0_predict = netG(x_tp1.detach(), t, latent_z, cond=(cond_pooled, cond, cond_mask))
+                x_pos_sample = sample_posterior(pos_coeff, x_0_predict, x_tp1, t)
 
+                output = netD(x_pos_sample, t, x_tp1.detach(), cond=cond_for_discr).view(-1)
+
+
+                errG = F.softplus(-output)
+                errG = errG.mean()
 
            errG.backward()
            optimizerG.step()
 
-
+            if (iteration % 10 == 0) and (rank == 0):
+                log_writer.add_scalar('g_loss', errG.item(), global_step)
+                log_writer.add_scalar('d_loss', errD.item(), global_step)
+                if grad_penalty is not None:
+                    log_writer.add_scalar('grad_penalty', grad_penalty.item(), global_step)
 
            global_step += 1
+
+
            if iteration % 100 == 0:
                if rank == 0:
                    print('epoch {} iteration{}, G Loss: {}, D Loss: {}'.format(epoch,iteration, errG.item(), errD.item()))
+                    print('Global step:', global_step)
            if iteration % 1000 == 0:
                x_t_1 = torch.randn_like(real_data)
-                fake_sample = sample_from_model(pos_coeff, netG, args.num_timesteps, x_t_1, T, args, cond=(cond_pooled, cond, cond_mask))
+                with autocast():
+                    fake_sample = sample_from_model(pos_coeff, netG, args.num_timesteps, x_t_1, T, args, cond=(cond_pooled, cond, cond_mask))
                if rank == 0:
                    torchvision.utils.save_image(fake_sample, os.path.join(exp_path, 'sample_discrete_epoch_{}_iteration_{}.png'.format(epoch, iteration)), normalize=True)
-                if args.save_content:
-                    print('Saving content.')
+
+            if args.save_content:
+                dist.barrier()
+                print('Saving content.')
+                def to_cpu(d):
+                    for k, v in d.items():
+                        d[k] = v.cpu()
+                    return d
+
+                if args.fsdp:
+                    netG_state_dict = to_cpu(netG.state_dict())
+                    netD_state_dict = to_cpu(netD.state_dict())
+                    #netG_optim_state_dict = (netG.gather_full_optim_state_dict(optimizerG))
+                    netG_optim_state_dict = optimizerG.state_dict()
+                    #print(netG_optim_state_dict)
+                    netD_optim_state_dict = (optimizerD.state_dict())
                    content = {'epoch': epoch + 1, 'global_step': global_step, 'args': args,
-                               'netG_dict': netG.state_dict(), 'optimizerG': optimizerG.state_dict(),
-                               'schedulerG': schedulerG.state_dict(), 'netD_dict': netD.state_dict(),
-                               'optimizerD': optimizerD.state_dict(), 'schedulerD': schedulerD.state_dict()}
-
-                    torch.save(content, os.path.join(exp_path, 'content.pth'))
-                    torch.save(content, os.path.join(exp_path, 'content_backup.pth'))
-                    if args.use_ema:
-                        optimizerG.swap_parameters_with_ema(store_params_in_ema=True)
-
-                    torch.save(netG.state_dict(), os.path.join(exp_path, 'netG_{}.pth'.format(epoch)))
-                    if args.use_ema:
-                        optimizerG.swap_parameters_with_ema(store_params_in_ema=True)
+                               'netG_dict': netG_state_dict, 'optimizerG': netG_optim_state_dict,
+                               'schedulerG': schedulerG.state_dict(), 'netD_dict': netD_state_dict,
+                               'optimizerD': netD_optim_state_dict, 'schedulerD': schedulerD.state_dict()}
+                    if rank == 0:
+                        torch.save(content, os.path.join(exp_path, 'content.pth'))
+                        torch.save(content, os.path.join(exp_path, 'content_backup.pth'))
+                        if args.use_ema:
+                            optimizerG.swap_parameters_with_ema(store_params_in_ema=True)
+                    if args.use_ema and rank == 0:
+                        torch.save(netG.state_dict(), os.path.join(exp_path, 'netG_{}.pth'.format(epoch)))
+                    if args.use_ema:
+                        optimizerG.swap_parameters_with_ema(store_params_in_ema=True)
+                    #if args.use_ema:
+                    #    dist.barrier()
+                    print("Saved content")
+                else:
+                    if rank == 0:
+                        content = {'epoch': epoch + 1, 'global_step': global_step, 'args': args,
+                                   'netG_dict': netG.state_dict(), 'optimizerG': optimizerG.state_dict(),
+                                   'schedulerG': schedulerG.state_dict(), 'netD_dict': netD.state_dict(),
+                                   'optimizerD': optimizerD.state_dict(), 'schedulerD': schedulerD.state_dict()}
+                        torch.save(content, os.path.join(exp_path, 'content.pth'))
+                        torch.save(content, os.path.join(exp_path, 'content_backup.pth'))
+                        if args.use_ema:
+                            optimizerG.swap_parameters_with_ema(store_params_in_ema=True)
+                        torch.save(netG.state_dict(), os.path.join(exp_path, 'netG_{}.pth'.format(epoch)))
+                        if args.use_ema:
+                            optimizerG.swap_parameters_with_ema(store_params_in_ema=True)
+
 
        if not args.no_lr_decay:
 
            schedulerG.step()
            schedulerD.step()
-
+        """
        if rank == 0:
            if epoch % 10 == 0:
                torchvision.utils.save_image(x_pos_sample, os.path.join(exp_path, 'xpos_epoch_{}.png'.format(epoch)), normalize=True)
 
            x_t_1 = torch.randn_like(real_data)
-            fake_sample = sample_from_model(pos_coeff, netG, args.num_timesteps, x_t_1, T, args, cond=(cond_pooled, cond, cond_mask))
+            with autocast():
+                fake_sample = sample_from_model(pos_coeff, netG, args.num_timesteps, x_t_1, T, args, cond=(cond_pooled, cond, cond_mask))
            torchvision.utils.save_image(fake_sample, os.path.join(exp_path, 'sample_discrete_epoch_{}.png'.format(epoch)), normalize=True)
 
        if args.save_content:
@@ -606,7 +700,8 @@ def train(rank, gpu, args):
                torch.save(netG.state_dict(), os.path.join(exp_path, 'netG_{}.pth'.format(epoch)))
                if args.use_ema:
                    optimizerG.swap_parameters_with_ema(store_params_in_ema=True)
-
+            dist.barrier()
+        """
 
 
 def init_processes(rank, size, fn, args):
@@ -641,6 +736,8 @@ if __name__ == '__main__':
    parser.add_argument('--mismatch_loss', action='store_true',default=False)
    parser.add_argument('--text_encoder', type=str, default="google/t5-v1_1-base")
    parser.add_argument('--cross_attention', action='store_true',default=False)
+    parser.add_argument('--fsdp', action='store_true',default=False)
+    parser.add_argument('--grad_checkpointing', action='store_true',default=False)
 
    parser.add_argument('--image_size', type=int, default=32,
                        help='size of image')
@@ -728,6 +825,7 @@ if __name__ == '__main__':
    parser.add_argument('--save_ckpt_every', type=int, default=25, help='save ckpt every x epochs')
    parser.add_argument('--discr_type', type=str, default="large")
    parser.add_argument('--preprocessing', type=str, default="resize")
+    parser.add_argument('--precision', type=str, default="fp32")
 
    ###ddp
    parser.add_argument('--num_proc_node', type=int, default=1,
@@ -746,4 +844,4 @@ if __name__ == '__main__':
    args.world_size = int(os.getenv("SLURM_NTASKS"))
    args.rank = int(os.environ['SLURM_PROCID'])
    # size = args.num_process_per_node
-    init_processes(args.rank, args.world_size, train, args)
+    init_processes(args.rank, args.world_size, train, args)
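Editor's sketch (not part of the commit) of how the new --precision flag is consumed: get_autocast returns torch.cuda.amp.autocast for 'amp', a bfloat16 autocast for 'amp_bfloat16', and contextlib.suppress as a no-op context for the default 'fp32', so the training step can always be written as `with autocast():`.

import torch
from contextlib import suppress

def get_autocast(precision):  # mirrors the helper added in train_ddgan.py
    if precision == 'amp':
        return torch.cuda.amp.autocast
    elif precision == 'amp_bfloat16':
        return lambda: torch.cuda.amp.autocast(dtype=torch.bfloat16)
    else:
        return suppress

autocast = get_autocast("fp32")  # default --precision value: plain fp32, no autocast region
with autocast():
    y = torch.ones(2, 2) @ torch.ones(2, 2)  # runs exactly as it would outside the context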
utils.py CHANGED
@@ -41,7 +41,8 @@ class ResampledShards2(IterableDataset):
        """
        super().__init__()
        #urls = wds.shardlists.expand_urls(urls)
-        urls = list(braceexpand.braceexpand(urls))
+        if type(urls) != list:
+            urls = list(braceexpand.braceexpand(urls))
        self.urls = urls
        assert isinstance(self.urls[0], str)
        self.nshards = nshards
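Editor's sketch (not part of the commit): the added type check lets ResampledShards2 accept either a ready-made list of shard paths (as now produced by the glob call in train_ddgan.py) or a brace pattern, which braceexpand turns into a list. The shard names below are hypothetical.

import braceexpand

pattern = "/data/laion/shard-{00000..00002}.tar"  # hypothetical shard pattern
print(list(braceexpand.braceexpand(pattern)))
# ['/data/laion/shard-00000.tar', '/data/laion/shard-00001.tar', '/data/laion/shard-00002.tar']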