|
import os |
|
import time |
|
import random |
|
import functools |
|
from typing import List, Optional, Tuple, Union |
|
|
|
from pathlib import Path |
|
from loguru import logger |
|
|
|
import torch |
|
import torch.distributed as dist |
|
from hyvideo.constants import PROMPT_TEMPLATE, NEGATIVE_PROMPT, PRECISION_TO_TYPE |
|
from hyvideo.vae import load_vae |
|
from hyvideo.modules import load_model |
|
from hyvideo.text_encoder import TextEncoder |
|
from hyvideo.utils.data_utils import align_to |
|
from hyvideo.modules.posemb_layers import get_nd_rotary_pos_embed |
|
from hyvideo.diffusion.schedulers import FlowMatchDiscreteScheduler |
|
from hyvideo.diffusion.pipelines import HunyuanVideoPipeline |
|
|
|
# xfuser (xDiT) is optional: it is only needed for Ulysses / Ring sequence
# parallelism. When it cannot be imported, every name below is set to None and
# `Inference.from_pretrained` asserts on `xfuser is not None` before use.
try:
    import xfuser
    from xfuser.core.distributed import (
        get_sequence_parallel_world_size,
        get_sequence_parallel_rank,
        get_sp_group,
        initialize_model_parallel,
        init_distributed_environment,
    )
except Exception:
    # Deliberately broader than ImportError: a broken xfuser install can fail
    # at import time with other errors, and single-GPU inference should still
    # work. Unlike a bare `except:`, this does not swallow KeyboardInterrupt
    # or SystemExit.
    xfuser = None
    get_sequence_parallel_world_size = None
    get_sequence_parallel_rank = None
    get_sp_group = None
    initialize_model_parallel = None
    init_distributed_environment = None
|
|
|
|
|
def parallelize_transformer(pipe):
    """Patch ``pipe.transformer.forward`` to run with xfuser sequence parallelism.

    The patched forward does the following:

    * shards ``x`` across the sequence-parallel group along its height (-2) or
      width (-1) axis;
    * shards the rotary tables ``freqs_cos`` / ``freqs_sin`` the same way;
    * installs an ``xFuserLongContextAttention`` on every double/single block;
    * runs the original forward on the local shard;
    * all-gathers the result back along the split axis.

    The patch is applied at most once per transformer. Wrapping an
    already-wrapped forward would shard the input twice and corrupt the output,
    so repeated calls are no-ops.

    Args:
        pipe: Object exposing ``.transformer``. The transformer must have
            ``forward``, ``double_blocks`` and ``single_blocks``.

    The sequence-parallel group (``initialize_model_parallel``) must be set up
    before the patched forward is invoked.
    """
    transformer = pipe.transformer
    if getattr(transformer, "_xfuser_parallelized", False):
        # Already wrapped; wrapping again would nest the sharding logic.
        return

    original_forward = transformer.forward

    @functools.wraps(transformer.__class__.forward)
    def new_forward(
        self,
        x: torch.Tensor,
        t: torch.Tensor,
        text_states: torch.Tensor = None,
        text_mask: torch.Tensor = None,
        text_states_2: Optional[torch.Tensor] = None,
        freqs_cos: Optional[torch.Tensor] = None,
        freqs_sin: Optional[torch.Tensor] = None,
        guidance: torch.Tensor = None,
        return_dict: bool = True,
    ):
        world_size = get_sequence_parallel_world_size()
        rank = get_sequence_parallel_rank()

        # The `// 2` mirrors the `// 2` used for h and w below, so each rank
        # receives a whole number of rope positions along the split axis.
        # NOTE(review): presumably this is the 2x spatial patch size of the
        # model — confirm against the model config.
        if x.shape[-2] // 2 % world_size == 0:
            split_dim = -2
        elif x.shape[-1] // 2 % world_size == 0:
            split_dim = -1
        else:
            raise ValueError(f"Cannot split video sequence into ulysses_degree x ring_degree ({world_size}) parts evenly")

        # Rope grid size, computed before x is sharded.
        temporal_size, h, w = x.shape[2], x.shape[3] // 2, x.shape[4] // 2

        x = torch.chunk(x, world_size, dim=split_dim)[rank]

        def _shard_freqs(freqs):
            # Lay the flat (t*h*w, dim) table out as a (t, h, w, dim) grid.
            # x's axis `split_dim` maps to grid axis `split_dim - 1`. Take this
            # rank's shard and flatten it back.
            dim_thw = freqs.shape[-1]
            freqs = freqs.reshape(temporal_size, h, w, dim_thw)
            freqs = torch.chunk(freqs, world_size, dim=split_dim - 1)[rank]
            return freqs.reshape(-1, dim_thw)

        freqs_cos = _shard_freqs(freqs_cos)
        freqs_sin = _shard_freqs(freqs_sin)

        from xfuser.core.long_ctx_attention import xFuserLongContextAttention

        for block in transformer.double_blocks + transformer.single_blocks:
            block.hybrid_seq_parallel_attn = xFuserLongContextAttention()

        output = original_forward(
            x,
            t,
            text_states,
            text_mask,
            text_states_2,
            freqs_cos,
            freqs_sin,
            guidance,
            return_dict,
        )

        # Gather the local shards back into the full sample. Handle both a
        # bare tensor (e.g. return_dict=False) and a dict with the sample
        # under "x".
        if isinstance(output, torch.Tensor):
            return get_sp_group().all_gather(output, dim=split_dim)
        output["x"] = get_sp_group().all_gather(output["x"], dim=split_dim)
        return output

    transformer.forward = new_forward.__get__(transformer)
    transformer._xfuser_parallelized = True
|
|
|
|
|
class Inference(object):
    """Container for the components needed for HunyuanVideo inference.

    Holds the VAE, the text encoder(s), the diffusion transformer and
    runtime settings (device, CPU offload, sequence-parallel degrees).
    Subclasses build the actual sampling pipeline on top of these.
    """

    def __init__(
        self,
        args,
        vae,
        vae_kwargs,
        text_encoder,
        model,
        text_encoder_2=None,
        pipeline=None,
        use_cpu_offload=False,
        device=None,
        logger=None,
        parallel_args=None,
    ):
        """Store the loaded components.

        Args:
            args: Parsed CLI/config namespace.
            vae: Loaded VAE model.
            vae_kwargs (dict): Extra VAE info, e.g. ``s_ratio`` / ``t_ratio``
                as built by ``from_pretrained``.
            text_encoder: Primary text encoder.
            model: Diffusion transformer.
            text_encoder_2: Optional secondary text encoder.
            pipeline: Optional prebuilt pipeline.
            use_cpu_offload (bool): Whether CPU offload is enabled.
            device: Target device. When None, "cuda" is used if available,
                otherwise "cpu".
            logger: Logger stored on the instance.
            parallel_args (dict, optional): ``ulysses_degree`` / ``ring_degree``.
        """
        self.vae = vae
        self.vae_kwargs = vae_kwargs

        self.text_encoder = text_encoder
        self.text_encoder_2 = text_encoder_2

        self.model = model
        self.pipeline = pipeline
        self.use_cpu_offload = use_cpu_offload

        self.args = args
        self.device = (
            device
            if device is not None
            else "cuda"
            if torch.cuda.is_available()
            else "cpu"
        )
        self.logger = logger
        self.parallel_args = parallel_args

    @classmethod
    def from_pretrained(cls, pretrained_model_path, args, device=None, **kwargs):
        """
        Initialize the Inference pipeline.

        Args:
            pretrained_model_path (str or pathlib.Path): The model path, including t2v, text encoder and vae checkpoints.
            args (argparse.Namespace): The arguments for the pipeline.
            device (str, int or torch.device, optional): The device for inference.
                When None, "cuda" is used if available, otherwise "cpu". In the
                sequence-parallel setting it is replaced by ``cuda:$LOCAL_RANK``.
            **kwargs: Accepted for call compatibility; not used here.
        """
        logger.info(f"Got text-to-video model root path: {pretrained_model_path}")

        # ---- Distributed environment (Ulysses / Ring sequence parallelism) ----
        if args.ulysses_degree > 1 or args.ring_degree > 1:
            assert xfuser is not None, \
                "Ulysses Attention and Ring Attention requires xfuser package."

            assert args.use_cpu_offload is False, \
                "Cannot enable use_cpu_offload in the distributed environment."

            dist.init_process_group("nccl")

            assert dist.get_world_size() == args.ring_degree * args.ulysses_degree, \
                "number of GPUs should be equal to ring_degree * ulysses_degree."

            init_distributed_environment(rank=dist.get_rank(), world_size=dist.get_world_size())

            initialize_model_parallel(
                sequence_parallel_degree=dist.get_world_size(),
                ring_degree=args.ring_degree,
                ulysses_degree=args.ulysses_degree,
            )
            # One GPU per process; LOCAL_RANK is expected to be set by the launcher.
            device = torch.device(f"cuda:{os.environ['LOCAL_RANK']}")
        else:
            if device is None:
                device = "cuda" if torch.cuda.is_available() else "cpu"

        parallel_args = {"ulysses_degree": args.ulysses_degree, "ring_degree": args.ring_degree}

        # Inference only: turn off autograd.
        torch.set_grad_enabled(False)

        # ---- Diffusion transformer ----
        logger.info("Building model...")
        factor_kwargs = {"device": device, "dtype": PRECISION_TO_TYPE[args.precision]}
        in_channels = args.latent_channels
        out_channels = args.latent_channels

        model = load_model(
            args,
            in_channels=in_channels,
            out_channels=out_channels,
            factor_kwargs=factor_kwargs,
        )
        model = model.to(device)
        model = Inference.load_state_dict(args, model, pretrained_model_path)
        model.eval()

        # ---- VAE ----
        vae, _, s_ratio, t_ratio = load_vae(
            args.vae,
            args.vae_precision,
            logger=logger,
            device=device if not args.use_cpu_offload else "cpu",
        )
        vae_kwargs = {"s_ratio": s_ratio, "t_ratio": t_ratio}

        # ---- Text encoders ----
        # The prompt template prepends tokens that are cropped off afterwards,
        # so the tokenizer budget is extended by `crop_start`.
        if args.prompt_template_video is not None:
            crop_start = PROMPT_TEMPLATE[args.prompt_template_video].get(
                "crop_start", 0
            )
        elif args.prompt_template is not None:
            crop_start = PROMPT_TEMPLATE[args.prompt_template].get("crop_start", 0)
        else:
            crop_start = 0
        max_length = args.text_len + crop_start

        prompt_template = (
            PROMPT_TEMPLATE[args.prompt_template]
            if args.prompt_template is not None
            else None
        )

        prompt_template_video = (
            PROMPT_TEMPLATE[args.prompt_template_video]
            if args.prompt_template_video is not None
            else None
        )

        text_encoder = TextEncoder(
            text_encoder_type=args.text_encoder,
            max_length=max_length,
            text_encoder_precision=args.text_encoder_precision,
            tokenizer_type=args.tokenizer,
            prompt_template=prompt_template,
            prompt_template_video=prompt_template_video,
            hidden_state_skip_layer=args.hidden_state_skip_layer,
            apply_final_norm=args.apply_final_norm,
            reproduce=args.reproduce,
            logger=logger,
            device=device if not args.use_cpu_offload else "cpu",
        )
        text_encoder_2 = None
        if args.text_encoder_2 is not None:
            text_encoder_2 = TextEncoder(
                text_encoder_type=args.text_encoder_2,
                max_length=args.text_len_2,
                text_encoder_precision=args.text_encoder_precision_2,
                tokenizer_type=args.tokenizer_2,
                reproduce=args.reproduce,
                logger=logger,
                device=device if not args.use_cpu_offload else "cpu",
            )

        return cls(
            args=args,
            vae=vae,
            vae_kwargs=vae_kwargs,
            text_encoder=text_encoder,
            text_encoder_2=text_encoder_2,
            model=model,
            use_cpu_offload=args.use_cpu_offload,
            device=device,
            logger=logger,
            parallel_args=parallel_args
        )

    @staticmethod
    def _resolve_weight_dir(weight_dir, load_key):
        """Pick the checkpoint file to load from a directory of ``*.pt`` files.

        Two layouts are recognized:

        * ``pytorch_model_*.pt`` (official release): the file
          ``pytorch_model_{load_key}.pt`` is selected and treated as a bare
          state dict.
        * ``*_model_states.pt`` (DeepSpeed): the first match in sorted order
          is selected, and the state dict is nested under ``load_key``.

        Args:
            weight_dir (pathlib.Path): Directory to search.
            load_key (str): Checkpoint key, e.g. the ``module``/``ema`` variant.

        Returns:
            tuple: ``(model_path, bare_model)``.

        Raises:
            ValueError: If no ``*.pt`` files exist or the layout is unrecognized.
        """
        # Sort so the choice does not depend on filesystem listing order.
        files = sorted(weight_dir.glob("*.pt"))
        if not files:
            raise ValueError(f"No model weights found in {weight_dir}")
        # Match on the file name, not the full path: str(path) starts with the
        # directory, so a prefix test on it would never succeed.
        if any(f.name.startswith("pytorch_model_") for f in files):
            return weight_dir / f"pytorch_model_{load_key}.pt", True
        deepspeed_files = [f for f in files if f.name.endswith("_model_states.pt")]
        if deepspeed_files:
            model_path = deepspeed_files[0]
            if len(deepspeed_files) > 1:
                logger.warning(
                    f"Multiple model weights found in {weight_dir}, using {model_path}"
                )
            return model_path, False
        raise ValueError(
            f"Invalid model path: {weight_dir} with unrecognized weight format: "
            f"{list(map(str, files))}. When given a directory as --dit-weight, only "
            f"`pytorch_model_*.pt`(provided by HunyuanDiT official) and "
            f"`*_model_states.pt`(saved by deepspeed) can be parsed. If you want to load a "
            f"specific weight file, please provide the full path to the file."
        )

    @staticmethod
    def load_state_dict(args, model, pretrained_model_path):
        """Load transformer weights into ``model`` (strict) and return it.

        The weight location is resolved as follows:

        * ``args.dit_weight`` is None: search ``<pretrained_model_path>/t2v_<model_resolution>``.
        * ``args.dit_weight`` is a directory: search it.
        * ``args.dit_weight`` is a file: load it directly. Whether it is bare or
          nested under ``args.load_key`` is detected from its top-level keys.

        Raises:
            ValueError: If no usable weight file can be found.
            KeyError: If a nested checkpoint lacks ``args.load_key``.
        """
        load_key = args.load_key
        # Path(None) raises TypeError, so only wrap a provided value.
        dit_weight = Path(args.dit_weight) if args.dit_weight is not None else None

        if dit_weight is None:
            model_dir = Path(pretrained_model_path) / f"t2v_{args.model_resolution}"
            model_path, bare_model = Inference._resolve_weight_dir(model_dir, load_key)
        elif dit_weight.is_dir():
            model_path, bare_model = Inference._resolve_weight_dir(dit_weight, load_key)
        elif dit_weight.is_file():
            model_path = dit_weight
            bare_model = "unknown"  # decided after inspecting the checkpoint keys
        else:
            raise ValueError(f"Invalid model path: {dit_weight}")

        if not model_path.exists():
            raise ValueError(f"model_path not exists: {model_path}")
        logger.info(f"Loading torch model {model_path}...")
        # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
        state_dict = torch.load(model_path, map_location=lambda storage, loc: storage)

        if bare_model == "unknown" and ("ema" in state_dict or "module" in state_dict):
            bare_model = False
        if bare_model is False:
            if load_key in state_dict:
                state_dict = state_dict[load_key]
            else:
                raise KeyError(
                    f"Missing key: `{load_key}` in the checkpoint: {model_path}. The keys in the checkpoint "
                    f"are: {list(state_dict.keys())}."
                )
        model.load_state_dict(state_dict, strict=True)
        return model

    @staticmethod
    def parse_size(size):
        """Normalize ``size`` to a 2-element (height, width) sequence.

        An int ``n`` or a 1-element sequence ``[n]`` becomes ``[n, n]``. A
        2-element list or tuple is returned unchanged.

        Raises:
            ValueError: For any other type or length.
        """
        if isinstance(size, int):
            size = [size]
        if not isinstance(size, (list, tuple)):
            raise ValueError(f"Size must be an integer or (height, width), got {size}.")
        if len(size) == 1:
            size = [size[0], size[0]]
        if len(size) != 2:
            raise ValueError(f"Size must be an integer or (height, width), got {size}.")
        return size
|
|
|
|
|
class HunyuanVideoSampler(Inference):
    """Text-to-video sampler: wraps the loaded components in a diffusion pipeline."""

    def __init__(
        self,
        args,
        vae,
        vae_kwargs,
        text_encoder,
        model,
        text_encoder_2=None,
        pipeline=None,
        use_cpu_offload=False,
        device=0,
        logger=None,
        parallel_args=None
    ):
        super().__init__(
            args,
            vae,
            vae_kwargs,
            text_encoder,
            model,
            text_encoder_2=text_encoder_2,
            pipeline=pipeline,
            use_cpu_offload=use_cpu_offload,
            device=device,
            logger=logger,
            parallel_args=parallel_args
        )

        self.pipeline = self.load_diffusion_pipeline(
            args=args,
            vae=self.vae,
            text_encoder=self.text_encoder,
            text_encoder_2=self.text_encoder_2,
            model=self.model,
            device=self.device,
        )

        self.default_negative_prompt = NEGATIVE_PROMPT

        # Patch the transformer for sequence parallelism exactly once. Doing it
        # per `predict` call would stack wrappers, so the input would be sharded
        # repeatedly on the second and later calls.
        if self._is_sequence_parallel():
            parallelize_transformer(self.pipeline)

    def _is_sequence_parallel(self):
        """Return True when Ulysses or Ring degree > 1 (missing args count as 1)."""
        parallel_args = self.parallel_args or {}
        return (
            parallel_args.get("ulysses_degree", 1) > 1
            or parallel_args.get("ring_degree", 1) > 1
        )

    def load_diffusion_pipeline(
        self,
        args,
        vae,
        text_encoder,
        text_encoder_2,
        model,
        scheduler=None,
        device=None,
        progress_bar_config=None,
        data_type="video",
    ):
        """Build the HunyuanVideo diffusion pipeline.

        When ``scheduler`` is None, a ``FlowMatchDiscreteScheduler`` is created
        from ``args``. Only ``args.denoise_type == "flow"`` is supported. The
        pipeline is moved to ``device``, or set up for sequential CPU offload
        when ``self.use_cpu_offload`` is set.

        Raises:
            ValueError: If ``args.denoise_type`` is not "flow".
        """
        if scheduler is None:
            if args.denoise_type == "flow":
                scheduler = FlowMatchDiscreteScheduler(
                    shift=args.flow_shift,
                    reverse=args.flow_reverse,
                    solver=args.flow_solver,
                )
            else:
                raise ValueError(f"Invalid denoise type {args.denoise_type}")

        pipeline = HunyuanVideoPipeline(
            vae=vae,
            text_encoder=text_encoder,
            text_encoder_2=text_encoder_2,
            transformer=model,
            scheduler=scheduler,
            progress_bar_config=progress_bar_config,
            args=args,
        )
        if self.use_cpu_offload:
            pipeline.enable_sequential_cpu_offload()
        else:
            pipeline = pipeline.to(device)

        return pipeline

    def get_rotary_pos_embed(self, video_length, height, width):
        """Compute the rotary position embedding tables for a given output size.

        The latent size is derived from the VAE name in ``self.args.vae``.
        Time is compressed 4x for "884", 8x for "888", and not at all
        otherwise; height and width are compressed 8x. The latent size is then
        divided by the model's patch size to get the rope grid.

        Returns:
            tuple: ``(freqs_cos, freqs_sin)`` from ``get_nd_rotary_pos_embed``.

        Raises:
            AssertionError: If the latent size is not divisible by the patch
                size, or ``rope_dim_list`` does not sum to the head dim.
            TypeError: If ``self.model.patch_size`` is not an int, list or tuple.
        """
        target_ndim = 3
        ndim = 5 - 2

        if "884" in self.args.vae:
            latents_size = [(video_length - 1) // 4 + 1, height // 8, width // 8]
        elif "888" in self.args.vae:
            latents_size = [(video_length - 1) // 8 + 1, height // 8, width // 8]
        else:
            latents_size = [video_length, height // 8, width // 8]

        patch_size = self.model.patch_size
        if isinstance(patch_size, int):
            assert all(s % patch_size == 0 for s in latents_size), (
                f"Latent size(last {ndim} dimensions) should be divisible by patch size({patch_size}), "
                f"but got {latents_size}."
            )
            rope_sizes = [s // patch_size for s in latents_size]
        elif isinstance(patch_size, (list, tuple)):
            assert all(
                s % patch_size[idx] == 0
                for idx, s in enumerate(latents_size)
            ), (
                f"Latent size(last {ndim} dimensions) should be divisible by patch size({patch_size}), "
                f"but got {latents_size}."
            )
            rope_sizes = [
                s // patch_size[idx] for idx, s in enumerate(latents_size)
            ]
        else:
            raise TypeError(
                f"patch_size must be an int, list or tuple, got {type(patch_size)}"
            )

        if len(rope_sizes) != target_ndim:
            # Left-pad with singleton axes so the grid always has target_ndim axes.
            rope_sizes = [1] * (target_ndim - len(rope_sizes)) + rope_sizes
        head_dim = self.model.hidden_size // self.model.heads_num
        rope_dim_list = self.model.rope_dim_list
        if rope_dim_list is None:
            rope_dim_list = [head_dim // target_ndim for _ in range(target_ndim)]
        assert (
            sum(rope_dim_list) == head_dim
        ), "sum(rope_dim_list) should equal to head_dim of attention layer"
        freqs_cos, freqs_sin = get_nd_rotary_pos_embed(
            rope_dim_list,
            rope_sizes,
            theta=self.args.rope_theta,
            use_real=True,
            theta_rescale_factor=1,
        )
        return freqs_cos, freqs_sin

    @torch.no_grad()
    def predict(
        self,
        prompt,
        height=192,
        width=336,
        video_length=129,
        seed=None,
        negative_prompt=None,
        infer_steps=50,
        guidance_scale=6,
        flow_shift=5.0,
        embedded_guidance_scale=None,
        batch_size=1,
        num_videos_per_prompt=1,
        **kwargs,
    ):
        """
        Predict the image/video from the given text.

        Args:
            prompt (str): The input text.
            height (int): The height of the output video. Default is 192.
            width (int): The width of the output video. Default is 336.
            video_length (int): The frame number of the output video;
                ``video_length - 1`` must be a multiple of 4. Default is 129.
            seed (int, List[int], torch.Tensor or None): The random seed(s).
                Random when None; required in the distributed setting.
            negative_prompt (str): The negative text prompt. When None or empty,
                ``self.default_negative_prompt`` is used.
            infer_steps (int): The number of inference steps. Default is 50.
            guidance_scale (float): The guidance scale for the generation. Default is 6.
            flow_shift (float): Shift of the flow-matching scheduler. Default is 5.0.
            embedded_guidance_scale (float, optional): Passed through to the pipeline.
            batch_size (int): Number of prompts, used for seed expansion. Default is 1.
            num_videos_per_prompt (int): The number of videos per prompt. Default is 1.

        Returns:
            dict: keys ``seeds``, ``size``, ``samples`` and ``prompts``.
        """
        if self._is_sequence_parallel():
            # Every rank must sample identical noise, so a fixed seed is required.
            assert seed is not None, \
                "You have to set a seed in the distributed environment, please rerun with --seed <your-seed>."

        out_dict = {}

        # ---- Seeds / generators ----
        if isinstance(seed, torch.Tensor):
            seed = seed.tolist()
        if seed is None:
            seeds = [
                random.randint(0, 1_000_000)
                for _ in range(batch_size * num_videos_per_prompt)
            ]
        elif isinstance(seed, int):
            seeds = [
                seed + i
                for _ in range(batch_size)
                for i in range(num_videos_per_prompt)
            ]
        elif isinstance(seed, (list, tuple)):
            if len(seed) == batch_size:
                seeds = [
                    int(seed[i]) + j
                    for i in range(batch_size)
                    for j in range(num_videos_per_prompt)
                ]
            elif len(seed) == batch_size * num_videos_per_prompt:
                seeds = [int(s) for s in seed]
            else:
                raise ValueError(
                    f"Length of seed must be equal to number of prompt(batch_size) or "
                    f"batch_size * num_videos_per_prompt ({batch_size} * {num_videos_per_prompt}), got {seed}."
                )
        else:
            raise ValueError(
                f"Seed must be an integer, a list of integers, or None, got {seed}."
            )
        generator = [torch.Generator(self.device).manual_seed(s) for s in seeds]
        out_dict["seeds"] = seeds

        # ---- Output size ----
        if width <= 0 or height <= 0 or video_length <= 0:
            raise ValueError(
                f"`height` and `width` and `video_length` must be positive integers, got height={height}, width={width}, video_length={video_length}"
            )
        if (video_length - 1) % 4 != 0:
            raise ValueError(
                f"`video_length-1` must be a multiple of 4, got {video_length}"
            )

        logger.info(
            f"Input (height, width, video_length) = ({height}, {width}, {video_length})"
        )

        target_height = align_to(height, 16)
        target_width = align_to(width, 16)
        target_video_length = video_length

        out_dict["size"] = (target_height, target_width, target_video_length)

        # ---- Prompts ----
        if not isinstance(prompt, str):
            raise TypeError(f"`prompt` must be a string, but got {type(prompt)}")
        prompt = [prompt.strip()]

        if negative_prompt is None or negative_prompt == "":
            negative_prompt = self.default_negative_prompt
        if not isinstance(negative_prompt, str):
            raise TypeError(
                f"`negative_prompt` must be a string, but got {type(negative_prompt)}"
            )
        negative_prompt = [negative_prompt.strip()]

        # ---- Scheduler (rebuilt per call so flow_shift takes effect) ----
        scheduler = FlowMatchDiscreteScheduler(
            shift=flow_shift,
            reverse=self.args.flow_reverse,
            solver=self.args.flow_solver
        )
        self.pipeline.scheduler = scheduler

        # ---- Rotary embeddings ----
        freqs_cos, freqs_sin = self.get_rotary_pos_embed(
            target_video_length, target_height, target_width
        )
        n_tokens = freqs_cos.shape[0]

        debug_str = f"""
                        height: {target_height}
                         width: {target_width}
                  video_length: {target_video_length}
                        prompt: {prompt}
                    neg_prompt: {negative_prompt}
                          seed: {seed}
                   infer_steps: {infer_steps}
         num_videos_per_prompt: {num_videos_per_prompt}
                guidance_scale: {guidance_scale}
                      n_tokens: {n_tokens}
                    flow_shift: {flow_shift}
       embedded_guidance_scale: {embedded_guidance_scale}"""
        logger.debug(debug_str)

        # ---- Sampling ----
        start_time = time.time()
        samples = self.pipeline(
            prompt=prompt,
            height=target_height,
            width=target_width,
            video_length=target_video_length,
            num_inference_steps=infer_steps,
            guidance_scale=guidance_scale,
            negative_prompt=negative_prompt,
            num_videos_per_prompt=num_videos_per_prompt,
            generator=generator,
            output_type="pil",
            freqs_cis=(freqs_cos, freqs_sin),
            n_tokens=n_tokens,
            embedded_guidance_scale=embedded_guidance_scale,
            data_type="video" if target_video_length > 1 else "image",
            is_progress_bar=True,
            vae_ver=self.args.vae,
            enable_tiling=self.args.vae_tiling,
        )[0]
        out_dict["samples"] = samples
        out_dict["prompts"] = prompt

        gen_time = time.time() - start_time
        logger.info(f"Success, time: {gen_time}")

        return out_dict
|
|