Spaces:
Runtime error
Runtime error
Mehdi Cherti
commited on
Commit
•
1a02524
1
Parent(s):
572f947
cond grad penalty: use only cond embedding to compute grad
Browse files- score_sde/models/ncsnpp_generator_adagn.py +3 -0
- train_ddgan.py +75 -26
score_sde/models/ncsnpp_generator_adagn.py
CHANGED
@@ -325,9 +325,12 @@ class NCSNpp(nn.Module):
|
|
325 |
|
326 |
hs = [modules[m_idx](x)]
|
327 |
m_idx += 1
|
|
|
|
|
328 |
for i_level in range(self.num_resolutions):
|
329 |
# Residual blocks for this resolution
|
330 |
for i_block in range(self.num_res_blocks):
|
|
|
331 |
h = modules[m_idx](hs[-1], temb, zemb)
|
332 |
m_idx += 1
|
333 |
if h.shape[-1] in self.attn_resolutions:
|
|
|
325 |
|
326 |
hs = [modules[m_idx](x)]
|
327 |
m_idx += 1
|
328 |
+
#print(self.attn_resolutions)
|
329 |
+
#self.attn_resolutions = (32,)
|
330 |
for i_level in range(self.num_resolutions):
|
331 |
# Residual blocks for this resolution
|
332 |
for i_block in range(self.num_res_blocks):
|
333 |
+
#print(hs[-1].shape, temb.shape, zemb.shape, type(modules[m_idx]))
|
334 |
h = modules[m_idx](hs[-1], temb, zemb)
|
335 |
m_idx += 1
|
336 |
if h.shape[-1] in self.attn_resolutions:
|
train_ddgan.py
CHANGED
@@ -28,7 +28,10 @@ from torch.multiprocessing import Process
|
|
28 |
import torch.distributed as dist
|
29 |
import shutil
|
30 |
import logging
|
31 |
-
import
|
|
|
|
|
|
|
32 |
def log_and_continue(exn):
|
33 |
logging.warning(f'Handling webdataset error ({repr(exn)}). Ignoring.')
|
34 |
return True
|
@@ -192,7 +195,11 @@ def sample_from_model(coefficients, generator, n_time, x_init, T, opt, cond=None
|
|
192 |
return x
|
193 |
|
194 |
|
195 |
-
|
|
|
|
|
|
|
|
|
196 |
|
197 |
def train(rank, gpu, args):
|
198 |
from score_sde.models.discriminator import Discriminator_small, Discriminator_large, CondAttnDiscriminator, SmallCondAttnDiscriminator
|
@@ -278,6 +285,7 @@ def train(rank, gpu, args):
|
|
278 |
),
|
279 |
])
|
280 |
pipeline.extend([
|
|
|
281 |
wds.decode("pilrgb", handler=log_and_continue),
|
282 |
wds.rename(image="jpg;png"),
|
283 |
wds.map_dict(image=train_transform),
|
@@ -307,7 +315,7 @@ def train(rank, gpu, args):
|
|
307 |
pin_memory=True,
|
308 |
sampler=train_sampler,
|
309 |
)
|
310 |
-
text_encoder =
|
311 |
args.cond_size = text_encoder.output_size
|
312 |
netG = NCSNpp(args).to(device)
|
313 |
nb_params = 0
|
@@ -387,7 +395,7 @@ def train(rank, gpu, args):
|
|
387 |
.format(checkpoint['epoch']))
|
388 |
else:
|
389 |
global_step, epoch, init_epoch = 0, 0, 0
|
390 |
-
|
391 |
for epoch in range(init_epoch, args.num_epoch+1):
|
392 |
if args.dataset == "wds":
|
393 |
os.environ["WDS_EPOCH"] = str(epoch)
|
@@ -419,45 +427,71 @@ def train(rank, gpu, args):
|
|
419 |
x_t, x_tp1 = q_sample_pairs(coeff, real_data, t)
|
420 |
x_t.requires_grad = True
|
421 |
|
422 |
-
cond_for_discr = (cond_pooled, cond, cond_mask) if
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
423 |
|
424 |
# train with real
|
425 |
D_real = netD(x_t, t, x_tp1.detach(), cond=cond_for_discr).view(-1)
|
426 |
|
427 |
errD_real = F.softplus(-D_real)
|
428 |
errD_real = errD_real.mean()
|
|
|
429 |
|
430 |
errD_real.backward(retain_graph=True)
|
431 |
|
432 |
|
433 |
if args.lazy_reg is None:
|
434 |
-
|
435 |
-
|
436 |
-
|
437 |
-
|
438 |
-
|
439 |
-
|
440 |
-
|
441 |
-
|
442 |
-
|
443 |
-
|
444 |
-
else:
|
445 |
-
if global_step % args.lazy_reg == 0:
|
446 |
grad_real = torch.autograd.grad(
|
447 |
-
|
448 |
-
|
449 |
grad_penalty = (
|
450 |
-
|
451 |
-
|
452 |
-
|
453 |
-
|
454 |
grad_penalty = args.r1_gamma / 2 * grad_penalty
|
455 |
grad_penalty.backward()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
456 |
|
457 |
# train with fake
|
458 |
latent_z = torch.randn(batch_size, nz, device=device)
|
459 |
|
460 |
-
|
461 |
x_0_predict = netG(x_tp1.detach(), t, latent_z, cond=(cond_pooled, cond, cond_mask))
|
462 |
x_pos_sample = sample_posterior(pos_coeff, x_0_predict, x_tp1, t)
|
463 |
|
@@ -466,6 +500,18 @@ def train(rank, gpu, args):
|
|
466 |
|
467 |
errD_fake = F.softplus(output)
|
468 |
errD_fake = errD_fake.mean()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
469 |
errD_fake.backward()
|
470 |
|
471 |
|
@@ -592,6 +638,7 @@ if __name__ == '__main__':
|
|
592 |
|
593 |
parser.add_argument('--resume', action='store_true',default=False)
|
594 |
parser.add_argument('--masked_mean', action='store_true',default=False)
|
|
|
595 |
parser.add_argument('--text_encoder', type=str, default="google/t5-v1_1-base")
|
596 |
parser.add_argument('--cross_attention', action='store_true',default=False)
|
597 |
|
@@ -616,7 +663,7 @@ if __name__ == '__main__':
|
|
616 |
help='channel multiplier')
|
617 |
parser.add_argument('--num_res_blocks', type=int, default=2,
|
618 |
help='number of resnet blocks per scale')
|
619 |
-
parser.add_argument('--attn_resolutions', default=(16,),
|
620 |
help='resolution of applying attention')
|
621 |
parser.add_argument('--dropout', type=float, default=0.,
|
622 |
help='drop-out rate')
|
@@ -665,12 +712,14 @@ if __name__ == '__main__':
|
|
665 |
parser.add_argument('--beta2', type=float, default=0.9,
|
666 |
help='beta2 for adam')
|
667 |
parser.add_argument('--no_lr_decay',action='store_true', default=False)
|
668 |
-
|
|
|
669 |
parser.add_argument('--use_ema', action='store_true', default=False,
|
670 |
help='use EMA or not')
|
671 |
parser.add_argument('--ema_decay', type=float, default=0.9999, help='decay rate for EMA')
|
672 |
|
673 |
parser.add_argument('--r1_gamma', type=float, default=0.05, help='coef for r1 reg')
|
|
|
674 |
parser.add_argument('--lazy_reg', type=int, default=None,
|
675 |
help='lazy regulariation.')
|
676 |
|
|
|
28 |
import torch.distributed as dist
|
29 |
import shutil
|
30 |
import logging
|
31 |
+
from encoder import build_encoder
|
32 |
+
from utils import ResampledShards2
|
33 |
+
|
34 |
+
|
35 |
def log_and_continue(exn):
|
36 |
logging.warning(f'Handling webdataset error ({repr(exn)}). Ignoring.')
|
37 |
return True
|
|
|
195 |
return x
|
196 |
|
197 |
|
198 |
+
|
199 |
+
def filter_no_caption(sample):
|
200 |
+
return 'txt' in sample
|
201 |
+
|
202 |
+
|
203 |
|
204 |
def train(rank, gpu, args):
|
205 |
from score_sde.models.discriminator import Discriminator_small, Discriminator_large, CondAttnDiscriminator, SmallCondAttnDiscriminator
|
|
|
285 |
),
|
286 |
])
|
287 |
pipeline.extend([
|
288 |
+
wds.select(filter_no_caption),
|
289 |
wds.decode("pilrgb", handler=log_and_continue),
|
290 |
wds.rename(image="jpg;png"),
|
291 |
wds.map_dict(image=train_transform),
|
|
|
315 |
pin_memory=True,
|
316 |
sampler=train_sampler,
|
317 |
)
|
318 |
+
text_encoder = build_encoder(name=args.text_encoder, masked_mean=args.masked_mean).to(device)
|
319 |
args.cond_size = text_encoder.output_size
|
320 |
netG = NCSNpp(args).to(device)
|
321 |
nb_params = 0
|
|
|
395 |
.format(checkpoint['epoch']))
|
396 |
else:
|
397 |
global_step, epoch, init_epoch = 0, 0, 0
|
398 |
+
use_cond_attn_discr = args.discr_type in ("large_cond_attn", "small_cond_attn")
|
399 |
for epoch in range(init_epoch, args.num_epoch+1):
|
400 |
if args.dataset == "wds":
|
401 |
os.environ["WDS_EPOCH"] = str(epoch)
|
|
|
427 |
x_t, x_tp1 = q_sample_pairs(coeff, real_data, t)
|
428 |
x_t.requires_grad = True
|
429 |
|
430 |
+
cond_for_discr = (cond_pooled, cond, cond_mask) if use_cond_attn_discr else cond_pooled
|
431 |
+
if args.grad_penalty_cond:
|
432 |
+
if use_cond_attn_discr:
|
433 |
+
#cond_pooled.requires_grad = True
|
434 |
+
cond.requires_grad = True
|
435 |
+
#cond_mask.requires_grad = True
|
436 |
+
else:
|
437 |
+
cond_for_discr.requires_grad = True
|
438 |
|
439 |
# train with real
|
440 |
D_real = netD(x_t, t, x_tp1.detach(), cond=cond_for_discr).view(-1)
|
441 |
|
442 |
errD_real = F.softplus(-D_real)
|
443 |
errD_real = errD_real.mean()
|
444 |
+
|
445 |
|
446 |
errD_real.backward(retain_graph=True)
|
447 |
|
448 |
|
449 |
if args.lazy_reg is None:
|
450 |
+
if args.grad_penalty_cond:
|
451 |
+
inputs = (x_t,) + (cond,) if use_cond_attn_discr else (cond_for_discr,)
|
452 |
+
grad_real = torch.autograd.grad(
|
453 |
+
outputs=D_real.sum(), inputs=inputs, create_graph=True
|
454 |
+
)[0]
|
455 |
+
grad_real = torch.cat([g.view(g.size(0), -1) for g in grad_real])
|
456 |
+
grad_penalty = (grad_real.norm(2, dim=1) ** 2).mean()
|
457 |
+
grad_penalty = args.r1_gamma / 2 * grad_penalty
|
458 |
+
grad_penalty.backward()
|
459 |
+
else:
|
|
|
|
|
460 |
grad_real = torch.autograd.grad(
|
461 |
+
outputs=D_real.sum(), inputs=x_t, create_graph=True
|
462 |
+
)[0]
|
463 |
grad_penalty = (
|
464 |
+
grad_real.view(grad_real.size(0), -1).norm(2, dim=1) ** 2
|
465 |
+
).mean()
|
466 |
+
|
467 |
+
|
468 |
grad_penalty = args.r1_gamma / 2 * grad_penalty
|
469 |
grad_penalty.backward()
|
470 |
+
else:
|
471 |
+
if global_step % args.lazy_reg == 0:
|
472 |
+
if args.grad_penalty_cond:
|
473 |
+
inputs = (x_t,) + (cond,) if use_cond_attn_discr else (cond_for_discr,)
|
474 |
+
grad_real = torch.autograd.grad(
|
475 |
+
outputs=D_real.sum(), inputs=inputs, create_graph=True
|
476 |
+
)[0]
|
477 |
+
grad_real = torch.cat([g.view(g.size(0), -1) for g in grad_real])
|
478 |
+
grad_penalty = (grad_real.norm(2, dim=1) ** 2).mean()
|
479 |
+
grad_penalty = args.r1_gamma / 2 * grad_penalty
|
480 |
+
grad_penalty.backward()
|
481 |
+
else:
|
482 |
+
grad_real = torch.autograd.grad(
|
483 |
+
outputs=D_real.sum(), inputs=x_t, create_graph=True
|
484 |
+
)[0]
|
485 |
+
grad_penalty = (
|
486 |
+
grad_real.view(grad_real.size(0), -1).norm(2, dim=1) ** 2
|
487 |
+
).mean()
|
488 |
+
|
489 |
+
grad_penalty = args.r1_gamma / 2 * grad_penalty
|
490 |
+
grad_penalty.backward()
|
491 |
|
492 |
# train with fake
|
493 |
latent_z = torch.randn(batch_size, nz, device=device)
|
494 |
|
|
|
495 |
x_0_predict = netG(x_tp1.detach(), t, latent_z, cond=(cond_pooled, cond, cond_mask))
|
496 |
x_pos_sample = sample_posterior(pos_coeff, x_0_predict, x_tp1, t)
|
497 |
|
|
|
500 |
|
501 |
errD_fake = F.softplus(output)
|
502 |
errD_fake = errD_fake.mean()
|
503 |
+
|
504 |
+
if args.mismatch_loss:
|
505 |
+
# following https://github.com/tobran/DF-GAN/blob/bc38a4f795c294b09b4ef5579cd4ff78807e5b96/code/lib/modules.py,
|
506 |
+
# we add a discr loss for (real image, non matching text)
|
507 |
+
#inds = torch.flip(torch.arange(len(x_t)), dims=(0,))
|
508 |
+
inds = torch.cat([torch.arange(1,len(x_t)),torch.arange(1)])
|
509 |
+
cond_for_discr_mis = (cond_pooled[inds], cond[inds], cond_mask[inds]) if use_cond_attn_discr else cond_pooled[inds]
|
510 |
+
D_real_mis = netD(x_t, t, x_tp1.detach(), cond=cond_for_discr_mis).view(-1)
|
511 |
+
errD_real_mis = F.softplus(D_real_mis)
|
512 |
+
errD_real_mis = errD_real_mis.mean()
|
513 |
+
errD_fake = errD_fake * 0.5 + errD_real_mis * 0.5
|
514 |
+
|
515 |
errD_fake.backward()
|
516 |
|
517 |
|
|
|
638 |
|
639 |
parser.add_argument('--resume', action='store_true',default=False)
|
640 |
parser.add_argument('--masked_mean', action='store_true',default=False)
|
641 |
+
parser.add_argument('--mismatch_loss', action='store_true',default=False)
|
642 |
parser.add_argument('--text_encoder', type=str, default="google/t5-v1_1-base")
|
643 |
parser.add_argument('--cross_attention', action='store_true',default=False)
|
644 |
|
|
|
663 |
help='channel multiplier')
|
664 |
parser.add_argument('--num_res_blocks', type=int, default=2,
|
665 |
help='number of resnet blocks per scale')
|
666 |
+
parser.add_argument('--attn_resolutions', default=(16,), nargs='+', type=int,
|
667 |
help='resolution of applying attention')
|
668 |
parser.add_argument('--dropout', type=float, default=0.,
|
669 |
help='drop-out rate')
|
|
|
712 |
parser.add_argument('--beta2', type=float, default=0.9,
|
713 |
help='beta2 for adam')
|
714 |
parser.add_argument('--no_lr_decay',action='store_true', default=False)
|
715 |
+
parser.add_argument('--grad_penalty_cond', action='store_true',default=False)
|
716 |
+
|
717 |
parser.add_argument('--use_ema', action='store_true', default=False,
|
718 |
help='use EMA or not')
|
719 |
parser.add_argument('--ema_decay', type=float, default=0.9999, help='decay rate for EMA')
|
720 |
|
721 |
parser.add_argument('--r1_gamma', type=float, default=0.05, help='coef for r1 reg')
|
722 |
+
|
723 |
parser.add_argument('--lazy_reg', type=int, default=None,
|
724 |
help='lazy regulariation.')
|
725 |
|