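"""First-block-cache adapter for ``HunyuanVideoTransformer3DModel``.

Patches a HunyuanVideo transformer (and, via ``apply_cache_on_pipe``, a whole
``DiffusionPipeline``) so that its double- and single-stream blocks run through
``para_attn.first_block_cache.utils.CachedTransformerBlocks``, which can skip
block computation when the first block's residual changes by less than
``residual_diff_threshold`` between denoising steps.
"""
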
import functools
import unittest.mock
from typing import Any, Dict, Optional, Tuple, Union

import torch
from diffusers import DiffusionPipeline, HunyuanVideoTransformer3DModel
from diffusers.models.modeling_outputs import Transformer2DModelOutput
from diffusers.utils import logging, scale_lora_layers, unscale_lora_layers, USE_PEFT_BACKEND

from para_attn.first_block_cache import utils

logger = logging.get_logger(__name__)  # pylint: disable=invalid-name


def apply_cache_on_transformer(
    transformer: HunyuanVideoTransformer3DModel,
    *,
    residual_diff_threshold=0.06,
):
    cached_transformer_blocks = torch.nn.ModuleList(
        [
            utils.CachedTransformerBlocks(
                transformer.transformer_blocks + transformer.single_transformer_blocks,
                transformer=transformer,
                residual_diff_threshold=residual_diff_threshold,
            )
        ]
    )
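    # Stand-in for `single_transformer_blocks` during the patched forward; the
    # cached wrapper above already chains the single blocks after the
    # double-stream blocks.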
    dummy_single_transformer_blocks = torch.nn.ModuleList()

    original_forward = transformer.forward

    @functools.wraps(transformer.__class__.forward)
    def new_forward(
        self,
        hidden_states: torch.Tensor,
        timestep: torch.LongTensor,
        encoder_hidden_states: torch.Tensor,
        encoder_attention_mask: torch.Tensor,
        pooled_projections: torch.Tensor,
        guidance: Optional[torch.Tensor] = None,
        attention_kwargs: Optional[Dict[str, Any]] = None,
        return_dict: bool = True,
        **kwargs,
    ) -> Union[Tuple[torch.Tensor], Transformer2DModelOutput]:
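        # Temporarily swap the real block lists for the cached wrapper while
        # this forward call runs.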
        with unittest.mock.patch.object(
            self,
            "transformer_blocks",
            cached_transformer_blocks,
        ), unittest.mock.patch.object(
            self,
            "single_transformer_blocks",
            dummy_single_transformer_blocks,
        ):
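            # A parallelized transformer keeps its own forward logic; call it
            # directly, with the cached blocks patched in.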
            if getattr(self, "_is_parallelized", False):
                return original_forward(
                    hidden_states,
                    timestep,
                    encoder_hidden_states,
                    encoder_attention_mask,
                    pooled_projections,
                    guidance=guidance,
                    attention_kwargs=attention_kwargs,
                    return_dict=return_dict,
                    **kwargs,
                )
            else:
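                # Pop the LoRA `scale` out of `attention_kwargs`; only the
                # PEFT backend consumes it.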
                if attention_kwargs is not None:
                    attention_kwargs = attention_kwargs.copy()
                    lora_scale = attention_kwargs.pop("scale", 1.0)
                else:
                    lora_scale = 1.0

                if USE_PEFT_BACKEND:
                    # weight the lora layers by setting `lora_scale` for each PEFT layer
                    scale_lora_layers(self, lora_scale)
                else:
                    if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None:
                        logger.warning(
                            "Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective."
                        )

                batch_size, num_channels, num_frames, height, width = hidden_states.shape
                p, p_t = self.config.patch_size, self.config.patch_size_t
                post_patch_num_frames = num_frames // p_t
                post_patch_height = height // p
                post_patch_width = width // p

                # 1. RoPE
                image_rotary_emb = self.rope(hidden_states)

                # 2. Conditional embeddings
                temb = self.time_text_embed(timestep, guidance, pooled_projections)
                hidden_states = self.x_embedder(hidden_states)
                encoder_hidden_states = self.context_embedder(encoder_hidden_states, timestep, encoder_attention_mask)
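                # 3. Drop padded text tokens; every sample in the batch is
                # assumed to share the first sample's attention mask.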
                encoder_hidden_states = encoder_hidden_states[:, encoder_attention_mask[0].bool()]

                # 4. Transformer blocks
                hidden_states, encoder_hidden_states = self.call_transformer_blocks(
                    hidden_states, encoder_hidden_states, temb, None, image_rotary_emb
                )

                # 5. Output projection
                hidden_states = self.norm_out(hidden_states, temb)
                hidden_states = self.proj_out(hidden_states)
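
                # Un-patchify: fold the (p_t, p, p) patch axes back into the
                # frame, height, and width dimensions.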
                hidden_states = hidden_states.reshape(
                    batch_size, post_patch_num_frames, post_patch_height, post_patch_width, -1, p_t, p, p
                )
                hidden_states = hidden_states.permute(0, 4, 1, 5, 2, 6, 3, 7)
                hidden_states = hidden_states.flatten(6, 7).flatten(4, 5).flatten(2, 3)
                hidden_states = hidden_states.to(timestep.dtype)

                if USE_PEFT_BACKEND:
                    # remove `lora_scale` from each PEFT layer
                    unscale_lora_layers(self, lora_scale)

                if not return_dict:
                    return (hidden_states,)
                return Transformer2DModelOutput(sample=hidden_states)
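
    # Bind the patched forward to this transformer instance only, leaving the
    # class untouched.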
    transformer.forward = new_forward.__get__(transformer)
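
    # Mirrors the block loop of `HunyuanVideoTransformer3DModel.forward`, with
    # optional non-reentrant gradient checkpointing.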
    def call_transformer_blocks(self, hidden_states, encoder_hidden_states, *args, **kwargs):
        if torch.is_grad_enabled() and self.gradient_checkpointing:

            def create_custom_forward(module, return_dict=None):
                def custom_forward(*inputs):
                    if return_dict is not None:
                        return module(*inputs, return_dict=return_dict)
                    else:
                        return module(*inputs)

                return custom_forward

            ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False}
            for block in self.transformer_blocks:
                hidden_states, encoder_hidden_states = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(block),
                    hidden_states,
                    encoder_hidden_states,
                    *args,
                    **kwargs,
                    **ckpt_kwargs,
                )

            for block in self.single_transformer_blocks:
                hidden_states, encoder_hidden_states = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(block),
                    hidden_states,
                    encoder_hidden_states,
                    *args,
                    **kwargs,
                    **ckpt_kwargs,
                )
        else:
            for block in self.transformer_blocks:
                hidden_states, encoder_hidden_states = block(hidden_states, encoder_hidden_states, *args, **kwargs)

            for block in self.single_transformer_blocks:
                hidden_states, encoder_hidden_states = block(hidden_states, encoder_hidden_states, *args, **kwargs)

        return hidden_states, encoder_hidden_states

    transformer.call_transformer_blocks = call_transformer_blocks.__get__(transformer)

    return transformer


def apply_cache_on_pipe(
    pipe: DiffusionPipeline,
    *,
    shallow_patch: bool = False,
    **kwargs,
):
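    # Wrap the pipeline's `__call__` once per class so that every invocation
    # runs inside a fresh cache context.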
    original_call = pipe.__class__.__call__

    if not getattr(original_call, "_is_cached", False):

        @functools.wraps(original_call)
        def new_call(self, *args, **kwargs):
            with utils.cache_context(utils.create_cache_context()):
                return original_call(self, *args, **kwargs)

        pipe.__class__.__call__ = new_call
        new_call._is_cached = True
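
    # `shallow_patch=True` only wraps the pipeline call, for cases where the
    # transformer is (or will be) patched separately.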
    if not shallow_patch:
        apply_cache_on_transformer(pipe.transformer, **kwargs)

    pipe._is_cached = True

    return pipe
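

# A minimal usage sketch. The checkpoint id, prompt, and generation arguments
# below are illustrative assumptions, not part of this module:
#
#     import torch
#     from diffusers import HunyuanVideoPipeline
#
#     pipe = HunyuanVideoPipeline.from_pretrained(
#         "hunyuanvideo-community/HunyuanVideo",  # assumed checkpoint id
#         torch_dtype=torch.bfloat16,
#     ).to("cuda")
#     apply_cache_on_pipe(pipe, residual_diff_threshold=0.06)
#     video = pipe(prompt="A cat walks on the grass", num_frames=61).frames[0]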