Spaces:
Running
on
Zero
Running
on
Zero
Update model/flux.py
Browse files- model/flux.py +280 -0
model/flux.py
CHANGED
@@ -10,6 +10,7 @@ from scepter.modules.utils.config import dict_to_yaml
|
|
10 |
from scepter.modules.utils.distribute import we
|
11 |
from scepter.modules.utils.file_system import FS
|
12 |
from torch import Tensor, nn
|
|
|
13 |
from torch.utils.checkpoint import checkpoint_sequential
|
14 |
|
15 |
from .layers import (DoubleStreamBlock, EmbedND, LastLayer,
|
@@ -404,4 +405,283 @@ class Flux(BaseModel):
|
|
404 |
return dict_to_yaml('MODEL',
|
405 |
__class__.__name__,
|
406 |
Flux.para_dict,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
407 |
set_name=True)
|
|
|
10 |
from scepter.modules.utils.distribute import we
|
11 |
from scepter.modules.utils.file_system import FS
|
12 |
from torch import Tensor, nn
|
13 |
+
from torch.nn.utils.rnn import pad_sequence
|
14 |
from torch.utils.checkpoint import checkpoint_sequential
|
15 |
|
16 |
from .layers import (DoubleStreamBlock, EmbedND, LastLayer,
|
|
|
405 |
return dict_to_yaml('MODEL',
|
406 |
__class__.__name__,
|
407 |
Flux.para_dict,
|
408 |
+
set_name=True)
|
409 |
+
|
410 |
+
@BACKBONES.register_class()
|
411 |
+
class FluxMR(Flux):
|
412 |
+
def prepare_input(self, x, cond):
|
413 |
+
if isinstance(cond['context'], list):
|
414 |
+
context, y = torch.cat(cond["context"], dim=0).to(x), torch.cat(cond["y"], dim=0).to(x)
|
415 |
+
else:
|
416 |
+
context, y = cond['context'].to(x), cond['y'].to(x)
|
417 |
+
batch_frames, batch_frames_ids = [], []
|
418 |
+
for ix, shape in zip(x, cond["x_shapes"]):
|
419 |
+
# unpack image from sequence
|
420 |
+
ix = ix[:, :shape[0] * shape[1]].view(-1, shape[0], shape[1])
|
421 |
+
c, h, w = ix.shape
|
422 |
+
ix = rearrange(ix, "c (h ph) (w pw) -> (h w) (c ph pw)", ph=2, pw=2)
|
423 |
+
ix_id = torch.zeros(h // 2, w // 2, 3)
|
424 |
+
ix_id[..., 1] = ix_id[..., 1] + torch.arange(h // 2)[:, None]
|
425 |
+
ix_id[..., 2] = ix_id[..., 2] + torch.arange(w // 2)[None, :]
|
426 |
+
ix_id = rearrange(ix_id, "h w c -> (h w) c")
|
427 |
+
batch_frames.append([ix])
|
428 |
+
batch_frames_ids.append([ix_id])
|
429 |
+
|
430 |
+
x_list, x_id_list, mask_x_list, x_seq_length = [], [], [], []
|
431 |
+
for frames, frame_ids in zip(batch_frames, batch_frames_ids):
|
432 |
+
proj_frames = []
|
433 |
+
for idx, one_frame in enumerate(frames):
|
434 |
+
one_frame = self.img_in(one_frame)
|
435 |
+
proj_frames.append(one_frame)
|
436 |
+
ix = torch.cat(proj_frames, dim=0)
|
437 |
+
if_id = torch.cat(frame_ids, dim=0)
|
438 |
+
x_list.append(ix)
|
439 |
+
x_id_list.append(if_id)
|
440 |
+
mask_x_list.append(torch.ones(ix.shape[0]).to(ix.device, non_blocking=True).bool())
|
441 |
+
x_seq_length.append(ix.shape[0])
|
442 |
+
x = pad_sequence(tuple(x_list), batch_first=True)
|
443 |
+
x_ids = pad_sequence(tuple(x_id_list), batch_first=True).to(x) # [b,pad_seq,2] pad (0.,0.) at dim2
|
444 |
+
mask_x = pad_sequence(tuple(mask_x_list), batch_first=True)
|
445 |
+
|
446 |
+
txt = self.txt_in(context)
|
447 |
+
txt_ids = torch.zeros(context.shape[0], context.shape[1], 3).to(x)
|
448 |
+
mask_txt = torch.ones(context.shape[0], context.shape[1]).to(x.device, non_blocking=True).bool()
|
449 |
+
|
450 |
+
return x, x_ids, txt, txt_ids, y, mask_x, mask_txt, x_seq_length
|
451 |
+
|
452 |
+
def unpack(self, x: Tensor, cond: dict = None, x_seq_length: list = None) -> Tensor:
|
453 |
+
x_list = []
|
454 |
+
image_shapes = cond["x_shapes"]
|
455 |
+
for u, shape, seq_length in zip(x, image_shapes, x_seq_length):
|
456 |
+
height, width = shape
|
457 |
+
h, w = math.ceil(height / 2), math.ceil(width / 2)
|
458 |
+
u = rearrange(
|
459 |
+
u[seq_length-h*w:seq_length, ...],
|
460 |
+
"(h w) (c ph pw) -> (h ph w pw) c",
|
461 |
+
h=h,
|
462 |
+
w=w,
|
463 |
+
ph=2,
|
464 |
+
pw=2,
|
465 |
+
)
|
466 |
+
x_list.append(u)
|
467 |
+
x = pad_sequence(tuple(x_list), batch_first=True).permute(0, 2, 1)
|
468 |
+
return x
|
469 |
+
|
470 |
+
def forward(
|
471 |
+
self,
|
472 |
+
x: Tensor,
|
473 |
+
t: Tensor,
|
474 |
+
cond: dict = {},
|
475 |
+
guidance: Tensor | None = None,
|
476 |
+
gc_seg: int = 0,
|
477 |
+
**kwargs
|
478 |
+
) -> Tensor:
|
479 |
+
x, x_ids, txt, txt_ids, y, mask_x, mask_txt, seq_length_list = self.prepare_input(x, cond)
|
480 |
+
# running on sequences img
|
481 |
+
vec = self.time_in(timestep_embedding(t, 256))
|
482 |
+
if self.guidance_embed:
|
483 |
+
if guidance is None:
|
484 |
+
raise ValueError("Didn't get guidance strength for guidance distilled model.")
|
485 |
+
vec = vec + self.guidance_in(timestep_embedding(guidance, 256))
|
486 |
+
vec = vec + self.vector_in(y)
|
487 |
+
ids = torch.cat((txt_ids, x_ids), dim=1)
|
488 |
+
pe = self.pe_embedder(ids)
|
489 |
+
|
490 |
+
mask_aside = torch.cat((mask_txt, mask_x), dim=1)
|
491 |
+
mask = mask_aside[:, None, :] * mask_aside[:, :, None]
|
492 |
+
|
493 |
+
kwargs = dict(
|
494 |
+
vec=vec,
|
495 |
+
pe=pe,
|
496 |
+
mask=mask,
|
497 |
+
txt_length = txt.shape[1],
|
498 |
+
)
|
499 |
+
x = torch.cat((txt, x), 1)
|
500 |
+
if self.use_grad_checkpoint and gc_seg >= 0:
|
501 |
+
x = checkpoint_sequential(
|
502 |
+
functions=[partial(block, **kwargs) for block in self.double_blocks],
|
503 |
+
segments=gc_seg if gc_seg > 0 else len(self.double_blocks),
|
504 |
+
input=x,
|
505 |
+
use_reentrant=False
|
506 |
+
)
|
507 |
+
else:
|
508 |
+
for block in self.double_blocks:
|
509 |
+
x = block(x, **kwargs)
|
510 |
+
|
511 |
+
kwargs = dict(
|
512 |
+
vec=vec,
|
513 |
+
pe=pe,
|
514 |
+
mask=mask,
|
515 |
+
)
|
516 |
+
|
517 |
+
if self.use_grad_checkpoint and gc_seg >= 0:
|
518 |
+
x = checkpoint_sequential(
|
519 |
+
functions=[partial(block, **kwargs) for block in self.single_blocks],
|
520 |
+
segments=gc_seg if gc_seg > 0 else len(self.single_blocks),
|
521 |
+
input=x,
|
522 |
+
use_reentrant=False
|
523 |
+
)
|
524 |
+
else:
|
525 |
+
for block in self.single_blocks:
|
526 |
+
x = block(x, **kwargs)
|
527 |
+
x = x[:, txt.shape[1]:, ...]
|
528 |
+
x = self.final_layer(x, vec) # (N, T, patch_size ** 2 * out_channels) 6 64 64
|
529 |
+
x = self.unpack(x, cond, seq_length_list)
|
530 |
+
return x
|
531 |
+
|
532 |
+
@staticmethod
|
533 |
+
def get_config_template():
|
534 |
+
return dict_to_yaml('MODEL',
|
535 |
+
__class__.__name__,
|
536 |
+
FluxEdit.para_dict,
|
537 |
+
set_name=True)
|
538 |
+
@BACKBONES.register_class()
|
539 |
+
class FluxEdit(FluxMR):
|
540 |
+
def prepare_input(self, x, cond, *args, **kwargs):
|
541 |
+
context, y = cond["context"], cond["y"]
|
542 |
+
batch_frames, batch_frames_ids, batch_shift = [], [], []
|
543 |
+
|
544 |
+
for ix, shape, is_align in zip(x, cond["x_shapes"], cond['align']):
|
545 |
+
# unpack image from sequence
|
546 |
+
ix = ix[:, :shape[0] * shape[1]].view(-1, shape[0], shape[1])
|
547 |
+
c, h, w = ix.shape
|
548 |
+
ix = rearrange(ix, "c (h ph) (w pw) -> (h w) (c ph pw)", ph=2, pw=2)
|
549 |
+
ix_id = torch.zeros(h // 2, w // 2, 3)
|
550 |
+
ix_id[..., 1] = ix_id[..., 1] + torch.arange(h // 2)[:, None]
|
551 |
+
ix_id[..., 2] = ix_id[..., 2] + torch.arange(w // 2)[None, :]
|
552 |
+
batch_shift.append(h // 2) #if is_align < 1 else batch_shift.append(0)
|
553 |
+
ix_id = rearrange(ix_id, "h w c -> (h w) c")
|
554 |
+
batch_frames.append([ix])
|
555 |
+
batch_frames_ids.append([ix_id])
|
556 |
+
if 'edit_x' in cond:
|
557 |
+
for i, edit in enumerate(cond['edit_x']):
|
558 |
+
if edit is None:
|
559 |
+
continue
|
560 |
+
for ie in edit:
|
561 |
+
ie = ie.squeeze(0)
|
562 |
+
c, h, w = ie.shape
|
563 |
+
ie = rearrange(ie, "c (h ph) (w pw) -> (h w) (c ph pw)", ph=2, pw=2)
|
564 |
+
ie_id = torch.zeros(h // 2, w // 2, 3)
|
565 |
+
ie_id[..., 1] = ie_id[..., 1] + torch.arange(batch_shift[i], h // 2 + batch_shift[i])[:, None]
|
566 |
+
ie_id[..., 2] = ie_id[..., 2] + torch.arange(w // 2)[None, :]
|
567 |
+
ie_id = rearrange(ie_id, "h w c -> (h w) c")
|
568 |
+
batch_frames[i].append(ie)
|
569 |
+
batch_frames_ids[i].append(ie_id)
|
570 |
+
|
571 |
+
x_list, x_id_list, mask_x_list, x_seq_length = [], [], [], []
|
572 |
+
for frames, frame_ids in zip(batch_frames, batch_frames_ids):
|
573 |
+
proj_frames = []
|
574 |
+
for idx, one_frame in enumerate(frames):
|
575 |
+
one_frame = self.img_in(one_frame)
|
576 |
+
proj_frames.append(one_frame)
|
577 |
+
ix = torch.cat(proj_frames, dim=0)
|
578 |
+
if_id = torch.cat(frame_ids, dim=0)
|
579 |
+
x_list.append(ix)
|
580 |
+
x_id_list.append(if_id)
|
581 |
+
mask_x_list.append(torch.ones(ix.shape[0]).to(ix.device, non_blocking=True).bool())
|
582 |
+
x_seq_length.append(ix.shape[0])
|
583 |
+
x = pad_sequence(tuple(x_list), batch_first=True)
|
584 |
+
x_ids = pad_sequence(tuple(x_id_list), batch_first=True).to(x) # [b,pad_seq,2] pad (0.,0.) at dim2
|
585 |
+
mask_x = pad_sequence(tuple(mask_x_list), batch_first=True)
|
586 |
+
|
587 |
+
txt_list, mask_txt_list, y_list = [], [], []
|
588 |
+
for sample_id, (ctx, yy) in enumerate(zip(context, y)):
|
589 |
+
ctx_batch = []
|
590 |
+
for frame_id, one_ctx in enumerate(ctx):
|
591 |
+
one_ctx = self.txt_in(one_ctx.to(x))
|
592 |
+
ctx_batch.append(one_ctx)
|
593 |
+
txt_list.append(torch.cat(ctx_batch, dim=0))
|
594 |
+
mask_txt_list.append(torch.ones(txt_list[-1].shape[0]).to(ctx.device, non_blocking=True).bool())
|
595 |
+
y_list.append(yy.mean(dim = 0, keepdim=True))
|
596 |
+
txt = pad_sequence(tuple(txt_list), batch_first=True)
|
597 |
+
txt_ids = torch.zeros(txt.shape[0], txt.shape[1], 3).to(x)
|
598 |
+
mask_txt = pad_sequence(tuple(mask_txt_list), batch_first=True)
|
599 |
+
y = torch.cat(y_list, dim=0)
|
600 |
+
return x, x_ids, txt, txt_ids, y, mask_x, mask_txt, x_seq_length
|
601 |
+
|
602 |
+
def unpack(self, x: Tensor, cond: dict = None, x_seq_length: list = None) -> Tensor:
|
603 |
+
x_list = []
|
604 |
+
image_shapes = cond["x_shapes"]
|
605 |
+
for u, shape, seq_length in zip(x, image_shapes, x_seq_length):
|
606 |
+
height, width = shape
|
607 |
+
h, w = math.ceil(height / 2), math.ceil(width / 2)
|
608 |
+
u = rearrange(
|
609 |
+
u[:h*w, ...],
|
610 |
+
"(h w) (c ph pw) -> (h ph w pw) c",
|
611 |
+
h=h,
|
612 |
+
w=w,
|
613 |
+
ph=2,
|
614 |
+
pw=2,
|
615 |
+
)
|
616 |
+
x_list.append(u)
|
617 |
+
x = pad_sequence(tuple(x_list), batch_first=True).permute(0, 2, 1)
|
618 |
+
return x
|
619 |
+
|
620 |
+
def forward(
|
621 |
+
self,
|
622 |
+
x: Tensor,
|
623 |
+
t: Tensor,
|
624 |
+
cond: dict = {},
|
625 |
+
guidance: Tensor | None = None,
|
626 |
+
gc_seg: int = 0,
|
627 |
+
text_position_embeddings = None
|
628 |
+
) -> Tensor:
|
629 |
+
x, x_ids, txt, txt_ids, y, mask_x, mask_txt, seq_length_list = self.prepare_input(x, cond, text_position_embeddings)
|
630 |
+
# running on sequences img
|
631 |
+
vec = self.time_in(timestep_embedding(t, 256))
|
632 |
+
if self.guidance_embed:
|
633 |
+
if guidance is None:
|
634 |
+
raise ValueError("Didn't get guidance strength for guidance distilled model.")
|
635 |
+
vec = vec + self.guidance_in(timestep_embedding(guidance, 256))
|
636 |
+
vec = vec + self.vector_in(y)
|
637 |
+
ids = torch.cat((txt_ids, x_ids), dim=1)
|
638 |
+
pe = self.pe_embedder(ids)
|
639 |
+
|
640 |
+
mask_aside = torch.cat((mask_txt, mask_x), dim=1)
|
641 |
+
mask = mask_aside[:, None, :] * mask_aside[:, :, None]
|
642 |
+
|
643 |
+
kwargs = dict(
|
644 |
+
vec=vec,
|
645 |
+
pe=pe,
|
646 |
+
mask=mask,
|
647 |
+
txt_length = txt.shape[1],
|
648 |
+
)
|
649 |
+
x = torch.cat((txt, x), 1)
|
650 |
+
|
651 |
+
if self.use_grad_checkpoint and gc_seg >= 0:
|
652 |
+
x = checkpoint_sequential(
|
653 |
+
functions=[partial(block, **kwargs) for block in self.double_blocks],
|
654 |
+
segments=gc_seg if gc_seg > 0 else len(self.double_blocks),
|
655 |
+
input=x,
|
656 |
+
use_reentrant=False
|
657 |
+
)
|
658 |
+
else:
|
659 |
+
for block in self.double_blocks:
|
660 |
+
x = block(x, **kwargs)
|
661 |
+
|
662 |
+
kwargs = dict(
|
663 |
+
vec=vec,
|
664 |
+
pe=pe,
|
665 |
+
mask=mask,
|
666 |
+
)
|
667 |
+
|
668 |
+
if self.use_grad_checkpoint and gc_seg >= 0:
|
669 |
+
x = checkpoint_sequential(
|
670 |
+
functions=[partial(block, **kwargs) for block in self.single_blocks],
|
671 |
+
segments=gc_seg if gc_seg > 0 else len(self.single_blocks),
|
672 |
+
input=x,
|
673 |
+
use_reentrant=False
|
674 |
+
)
|
675 |
+
else:
|
676 |
+
for block in self.single_blocks:
|
677 |
+
x = block(x, **kwargs)
|
678 |
+
x = x[:, txt.shape[1]:, ...]
|
679 |
+
x = self.final_layer(x, vec) # (N, T, patch_size ** 2 * out_channels) 6 64 64
|
680 |
+
x = self.unpack(x, cond, seq_length_list)
|
681 |
+
return x
|
682 |
+
@staticmethod
|
683 |
+
def get_config_template():
|
684 |
+
return dict_to_yaml('MODEL',
|
685 |
+
__class__.__name__,
|
686 |
+
FluxEdit.para_dict,
|
687 |
set_name=True)
|