chaojiemao commited on
Commit
b8d6d4a
·
verified ·
1 Parent(s): 20a320d

Update model/flux.py

Browse files
Files changed (1) hide show
  1. model/flux.py +280 -0
model/flux.py CHANGED
@@ -10,6 +10,7 @@ from scepter.modules.utils.config import dict_to_yaml
10
  from scepter.modules.utils.distribute import we
11
  from scepter.modules.utils.file_system import FS
12
  from torch import Tensor, nn
 
13
  from torch.utils.checkpoint import checkpoint_sequential
14
 
15
  from .layers import (DoubleStreamBlock, EmbedND, LastLayer,
@@ -404,4 +405,283 @@ class Flux(BaseModel):
404
  return dict_to_yaml('MODEL',
405
  __class__.__name__,
406
  Flux.para_dict,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
407
  set_name=True)
 
10
  from scepter.modules.utils.distribute import we
11
  from scepter.modules.utils.file_system import FS
12
  from torch import Tensor, nn
13
+ from torch.nn.utils.rnn import pad_sequence
14
  from torch.utils.checkpoint import checkpoint_sequential
15
 
16
  from .layers import (DoubleStreamBlock, EmbedND, LastLayer,
 
405
  return dict_to_yaml('MODEL',
406
  __class__.__name__,
407
  Flux.para_dict,
408
+ set_name=True)
409
+
410
+ @BACKBONES.register_class()
411
+ class FluxMR(Flux):
412
+ def prepare_input(self, x, cond):
413
+ if isinstance(cond['context'], list):
414
+ context, y = torch.cat(cond["context"], dim=0).to(x), torch.cat(cond["y"], dim=0).to(x)
415
+ else:
416
+ context, y = cond['context'].to(x), cond['y'].to(x)
417
+ batch_frames, batch_frames_ids = [], []
418
+ for ix, shape in zip(x, cond["x_shapes"]):
419
+ # unpack image from sequence
420
+ ix = ix[:, :shape[0] * shape[1]].view(-1, shape[0], shape[1])
421
+ c, h, w = ix.shape
422
+ ix = rearrange(ix, "c (h ph) (w pw) -> (h w) (c ph pw)", ph=2, pw=2)
423
+ ix_id = torch.zeros(h // 2, w // 2, 3)
424
+ ix_id[..., 1] = ix_id[..., 1] + torch.arange(h // 2)[:, None]
425
+ ix_id[..., 2] = ix_id[..., 2] + torch.arange(w // 2)[None, :]
426
+ ix_id = rearrange(ix_id, "h w c -> (h w) c")
427
+ batch_frames.append([ix])
428
+ batch_frames_ids.append([ix_id])
429
+
430
+ x_list, x_id_list, mask_x_list, x_seq_length = [], [], [], []
431
+ for frames, frame_ids in zip(batch_frames, batch_frames_ids):
432
+ proj_frames = []
433
+ for idx, one_frame in enumerate(frames):
434
+ one_frame = self.img_in(one_frame)
435
+ proj_frames.append(one_frame)
436
+ ix = torch.cat(proj_frames, dim=0)
437
+ if_id = torch.cat(frame_ids, dim=0)
438
+ x_list.append(ix)
439
+ x_id_list.append(if_id)
440
+ mask_x_list.append(torch.ones(ix.shape[0]).to(ix.device, non_blocking=True).bool())
441
+ x_seq_length.append(ix.shape[0])
442
+ x = pad_sequence(tuple(x_list), batch_first=True)
443
+ x_ids = pad_sequence(tuple(x_id_list), batch_first=True).to(x) # [b,pad_seq,2] pad (0.,0.) at dim2
444
+ mask_x = pad_sequence(tuple(mask_x_list), batch_first=True)
445
+
446
+ txt = self.txt_in(context)
447
+ txt_ids = torch.zeros(context.shape[0], context.shape[1], 3).to(x)
448
+ mask_txt = torch.ones(context.shape[0], context.shape[1]).to(x.device, non_blocking=True).bool()
449
+
450
+ return x, x_ids, txt, txt_ids, y, mask_x, mask_txt, x_seq_length
451
+
452
+ def unpack(self, x: Tensor, cond: dict = None, x_seq_length: list = None) -> Tensor:
453
+ x_list = []
454
+ image_shapes = cond["x_shapes"]
455
+ for u, shape, seq_length in zip(x, image_shapes, x_seq_length):
456
+ height, width = shape
457
+ h, w = math.ceil(height / 2), math.ceil(width / 2)
458
+ u = rearrange(
459
+ u[seq_length-h*w:seq_length, ...],
460
+ "(h w) (c ph pw) -> (h ph w pw) c",
461
+ h=h,
462
+ w=w,
463
+ ph=2,
464
+ pw=2,
465
+ )
466
+ x_list.append(u)
467
+ x = pad_sequence(tuple(x_list), batch_first=True).permute(0, 2, 1)
468
+ return x
469
+
470
+ def forward(
471
+ self,
472
+ x: Tensor,
473
+ t: Tensor,
474
+ cond: dict = {},
475
+ guidance: Tensor | None = None,
476
+ gc_seg: int = 0,
477
+ **kwargs
478
+ ) -> Tensor:
479
+ x, x_ids, txt, txt_ids, y, mask_x, mask_txt, seq_length_list = self.prepare_input(x, cond)
480
+ # running on sequences img
481
+ vec = self.time_in(timestep_embedding(t, 256))
482
+ if self.guidance_embed:
483
+ if guidance is None:
484
+ raise ValueError("Didn't get guidance strength for guidance distilled model.")
485
+ vec = vec + self.guidance_in(timestep_embedding(guidance, 256))
486
+ vec = vec + self.vector_in(y)
487
+ ids = torch.cat((txt_ids, x_ids), dim=1)
488
+ pe = self.pe_embedder(ids)
489
+
490
+ mask_aside = torch.cat((mask_txt, mask_x), dim=1)
491
+ mask = mask_aside[:, None, :] * mask_aside[:, :, None]
492
+
493
+ kwargs = dict(
494
+ vec=vec,
495
+ pe=pe,
496
+ mask=mask,
497
+ txt_length = txt.shape[1],
498
+ )
499
+ x = torch.cat((txt, x), 1)
500
+ if self.use_grad_checkpoint and gc_seg >= 0:
501
+ x = checkpoint_sequential(
502
+ functions=[partial(block, **kwargs) for block in self.double_blocks],
503
+ segments=gc_seg if gc_seg > 0 else len(self.double_blocks),
504
+ input=x,
505
+ use_reentrant=False
506
+ )
507
+ else:
508
+ for block in self.double_blocks:
509
+ x = block(x, **kwargs)
510
+
511
+ kwargs = dict(
512
+ vec=vec,
513
+ pe=pe,
514
+ mask=mask,
515
+ )
516
+
517
+ if self.use_grad_checkpoint and gc_seg >= 0:
518
+ x = checkpoint_sequential(
519
+ functions=[partial(block, **kwargs) for block in self.single_blocks],
520
+ segments=gc_seg if gc_seg > 0 else len(self.single_blocks),
521
+ input=x,
522
+ use_reentrant=False
523
+ )
524
+ else:
525
+ for block in self.single_blocks:
526
+ x = block(x, **kwargs)
527
+ x = x[:, txt.shape[1]:, ...]
528
+ x = self.final_layer(x, vec) # (N, T, patch_size ** 2 * out_channels) 6 64 64
529
+ x = self.unpack(x, cond, seq_length_list)
530
+ return x
531
+
532
+ @staticmethod
533
+ def get_config_template():
534
+ return dict_to_yaml('MODEL',
535
+ __class__.__name__,
536
+ FluxEdit.para_dict,
537
+ set_name=True)
538
+ @BACKBONES.register_class()
539
+ class FluxEdit(FluxMR):
540
+ def prepare_input(self, x, cond, *args, **kwargs):
541
+ context, y = cond["context"], cond["y"]
542
+ batch_frames, batch_frames_ids, batch_shift = [], [], []
543
+
544
+ for ix, shape, is_align in zip(x, cond["x_shapes"], cond['align']):
545
+ # unpack image from sequence
546
+ ix = ix[:, :shape[0] * shape[1]].view(-1, shape[0], shape[1])
547
+ c, h, w = ix.shape
548
+ ix = rearrange(ix, "c (h ph) (w pw) -> (h w) (c ph pw)", ph=2, pw=2)
549
+ ix_id = torch.zeros(h // 2, w // 2, 3)
550
+ ix_id[..., 1] = ix_id[..., 1] + torch.arange(h // 2)[:, None]
551
+ ix_id[..., 2] = ix_id[..., 2] + torch.arange(w // 2)[None, :]
552
+ batch_shift.append(h // 2) #if is_align < 1 else batch_shift.append(0)
553
+ ix_id = rearrange(ix_id, "h w c -> (h w) c")
554
+ batch_frames.append([ix])
555
+ batch_frames_ids.append([ix_id])
556
+ if 'edit_x' in cond:
557
+ for i, edit in enumerate(cond['edit_x']):
558
+ if edit is None:
559
+ continue
560
+ for ie in edit:
561
+ ie = ie.squeeze(0)
562
+ c, h, w = ie.shape
563
+ ie = rearrange(ie, "c (h ph) (w pw) -> (h w) (c ph pw)", ph=2, pw=2)
564
+ ie_id = torch.zeros(h // 2, w // 2, 3)
565
+ ie_id[..., 1] = ie_id[..., 1] + torch.arange(batch_shift[i], h // 2 + batch_shift[i])[:, None]
566
+ ie_id[..., 2] = ie_id[..., 2] + torch.arange(w // 2)[None, :]
567
+ ie_id = rearrange(ie_id, "h w c -> (h w) c")
568
+ batch_frames[i].append(ie)
569
+ batch_frames_ids[i].append(ie_id)
570
+
571
+ x_list, x_id_list, mask_x_list, x_seq_length = [], [], [], []
572
+ for frames, frame_ids in zip(batch_frames, batch_frames_ids):
573
+ proj_frames = []
574
+ for idx, one_frame in enumerate(frames):
575
+ one_frame = self.img_in(one_frame)
576
+ proj_frames.append(one_frame)
577
+ ix = torch.cat(proj_frames, dim=0)
578
+ if_id = torch.cat(frame_ids, dim=0)
579
+ x_list.append(ix)
580
+ x_id_list.append(if_id)
581
+ mask_x_list.append(torch.ones(ix.shape[0]).to(ix.device, non_blocking=True).bool())
582
+ x_seq_length.append(ix.shape[0])
583
+ x = pad_sequence(tuple(x_list), batch_first=True)
584
+ x_ids = pad_sequence(tuple(x_id_list), batch_first=True).to(x) # [b,pad_seq,2] pad (0.,0.) at dim2
585
+ mask_x = pad_sequence(tuple(mask_x_list), batch_first=True)
586
+
587
+ txt_list, mask_txt_list, y_list = [], [], []
588
+ for sample_id, (ctx, yy) in enumerate(zip(context, y)):
589
+ ctx_batch = []
590
+ for frame_id, one_ctx in enumerate(ctx):
591
+ one_ctx = self.txt_in(one_ctx.to(x))
592
+ ctx_batch.append(one_ctx)
593
+ txt_list.append(torch.cat(ctx_batch, dim=0))
594
+ mask_txt_list.append(torch.ones(txt_list[-1].shape[0]).to(ctx.device, non_blocking=True).bool())
595
+ y_list.append(yy.mean(dim = 0, keepdim=True))
596
+ txt = pad_sequence(tuple(txt_list), batch_first=True)
597
+ txt_ids = torch.zeros(txt.shape[0], txt.shape[1], 3).to(x)
598
+ mask_txt = pad_sequence(tuple(mask_txt_list), batch_first=True)
599
+ y = torch.cat(y_list, dim=0)
600
+ return x, x_ids, txt, txt_ids, y, mask_x, mask_txt, x_seq_length
601
+
602
+ def unpack(self, x: Tensor, cond: dict = None, x_seq_length: list = None) -> Tensor:
603
+ x_list = []
604
+ image_shapes = cond["x_shapes"]
605
+ for u, shape, seq_length in zip(x, image_shapes, x_seq_length):
606
+ height, width = shape
607
+ h, w = math.ceil(height / 2), math.ceil(width / 2)
608
+ u = rearrange(
609
+ u[:h*w, ...],
610
+ "(h w) (c ph pw) -> (h ph w pw) c",
611
+ h=h,
612
+ w=w,
613
+ ph=2,
614
+ pw=2,
615
+ )
616
+ x_list.append(u)
617
+ x = pad_sequence(tuple(x_list), batch_first=True).permute(0, 2, 1)
618
+ return x
619
+
620
+ def forward(
621
+ self,
622
+ x: Tensor,
623
+ t: Tensor,
624
+ cond: dict = {},
625
+ guidance: Tensor | None = None,
626
+ gc_seg: int = 0,
627
+ text_position_embeddings = None
628
+ ) -> Tensor:
629
+ x, x_ids, txt, txt_ids, y, mask_x, mask_txt, seq_length_list = self.prepare_input(x, cond, text_position_embeddings)
630
+ # running on sequences img
631
+ vec = self.time_in(timestep_embedding(t, 256))
632
+ if self.guidance_embed:
633
+ if guidance is None:
634
+ raise ValueError("Didn't get guidance strength for guidance distilled model.")
635
+ vec = vec + self.guidance_in(timestep_embedding(guidance, 256))
636
+ vec = vec + self.vector_in(y)
637
+ ids = torch.cat((txt_ids, x_ids), dim=1)
638
+ pe = self.pe_embedder(ids)
639
+
640
+ mask_aside = torch.cat((mask_txt, mask_x), dim=1)
641
+ mask = mask_aside[:, None, :] * mask_aside[:, :, None]
642
+
643
+ kwargs = dict(
644
+ vec=vec,
645
+ pe=pe,
646
+ mask=mask,
647
+ txt_length = txt.shape[1],
648
+ )
649
+ x = torch.cat((txt, x), 1)
650
+
651
+ if self.use_grad_checkpoint and gc_seg >= 0:
652
+ x = checkpoint_sequential(
653
+ functions=[partial(block, **kwargs) for block in self.double_blocks],
654
+ segments=gc_seg if gc_seg > 0 else len(self.double_blocks),
655
+ input=x,
656
+ use_reentrant=False
657
+ )
658
+ else:
659
+ for block in self.double_blocks:
660
+ x = block(x, **kwargs)
661
+
662
+ kwargs = dict(
663
+ vec=vec,
664
+ pe=pe,
665
+ mask=mask,
666
+ )
667
+
668
+ if self.use_grad_checkpoint and gc_seg >= 0:
669
+ x = checkpoint_sequential(
670
+ functions=[partial(block, **kwargs) for block in self.single_blocks],
671
+ segments=gc_seg if gc_seg > 0 else len(self.single_blocks),
672
+ input=x,
673
+ use_reentrant=False
674
+ )
675
+ else:
676
+ for block in self.single_blocks:
677
+ x = block(x, **kwargs)
678
+ x = x[:, txt.shape[1]:, ...]
679
+ x = self.final_layer(x, vec) # (N, T, patch_size ** 2 * out_channels) 6 64 64
680
+ x = self.unpack(x, cond, seq_length_list)
681
+ return x
682
+ @staticmethod
683
+ def get_config_template():
684
+ return dict_to_yaml('MODEL',
685
+ __class__.__name__,
686
+ FluxEdit.para_dict,
687
  set_name=True)