|
import argparse |
|
from .constants import * |
|
import re |
|
from .modules.models import HUNYUAN_VIDEO_CONFIG |
|
|
|
|
|
def parse_args(namespace=None):
    """Build the HunyuanVideo inference CLI, parse it and validate the result.

    Every ``add_*_args`` function in this module contributes one argument
    group to the parser; the parsed values are then passed through
    ``sanity_check_args`` before being returned.

    Args:
        namespace: Optional object forwarded to
            ``ArgumentParser.parse_args(namespace=...)``.

    Returns:
        The parsed arguments as returned by ``sanity_check_args``.
    """
    parser = argparse.ArgumentParser(description="HunyuanVideo inference script")

    # Each registrar hands the parser back, so keep rebinding it exactly as a
    # sequence of ``parser = add_x(parser)`` statements would.
    registrars = (
        add_network_args,
        add_extra_models_args,
        add_denoise_schedule_args,
        add_inference_args,
        add_parallel_args,
    )
    for register in registrars:
        parser = register(parser)

    parsed = parser.parse_args(namespace=namespace)
    return sanity_check_args(parsed)
|
|
|
|
|
def add_network_args(parser: argparse.ArgumentParser):
    """Register the HunyuanVideo network options on ``parser``.

    Adds ``--model``, ``--latent-channels``, ``--precision`` and
    ``--rope-theta`` inside a dedicated argument group.

    Args:
        parser: Parser that receives the new argument group.

    Returns:
        The same ``parser`` instance, so calls can be chained.
    """
    group = parser.add_argument_group(title="HunyuanVideo network args")

    group.add_argument(
        "--model",
        type=str,
        choices=list(HUNYUAN_VIDEO_CONFIG.keys()),
        default="HYVideo-T/2-cfgdistill",
    )
    group.add_argument(
        "--latent-channels",
        # Must be parsed as an int: sanity_check_args compares this value with
        # the integer channel count extracted from the VAE name, and a string
        # such as "16" would never compare equal to 16.
        type=int,
        default=16,
        help="Number of latent channels of DiT. If None, it will be determined by `vae`. If provided, "
        "it still needs to match the latent channels of the VAE model.",
    )
    group.add_argument(
        "--precision",
        type=str,
        default="bf16",
        choices=PRECISIONS,
        help="Precision mode. Options: fp32, fp16, bf16. Applied to the backbone model and optimizer.",
    )

    group.add_argument(
        "--rope-theta", type=int, default=256, help="Theta used in RoPE."
    )
    return parser
|
|
|
|
|
def add_extra_models_args(parser: argparse.ArgumentParser):
    """Register options for the auxiliary models on ``parser``.

    Covers the VAE, the two text encoders and their tokenizers, and the
    prompt templates.

    Args:
        parser: Parser that receives the new argument group.

    Returns:
        The same ``parser`` instance, so calls can be chained.
    """
    group = parser.add_argument_group(
        title="Extra models args (including vae, text encoders and tokenizers)"
    )

    # ---- VAE ----
    group.add_argument(
        "--vae",
        type=str,
        default="884-16c-hy",
        choices=list(VAE_PATH),
        help="Name of the VAE model.",
    )
    group.add_argument(
        "--vae-precision",
        type=str,
        default="fp16",
        choices=PRECISIONS,
        help="Precision mode for the VAE model.",
    )
    group.add_argument(
        "--vae-tiling",
        action="store_true",
        help="Enable tiling for the VAE model to save GPU memory.",
    )
    # Tiling defaults to True (see set_defaults below), so `--vae-tiling` on
    # its own can never change the value. The negative flag writes to the same
    # destination and gives callers a way to turn tiling off.
    group.add_argument(
        "--no-vae-tiling",
        dest="vae_tiling",
        action="store_false",
        help="Disable tiling for the VAE model.",
    )
    # Called after both flags are registered so both share the True default.
    group.set_defaults(vae_tiling=True)

    # ---- First text encoder / tokenizer ----
    group.add_argument(
        "--text-encoder",
        type=str,
        default="llm",
        choices=list(TEXT_ENCODER_PATH),
        help="Name of the text encoder model.",
    )
    group.add_argument(
        "--text-encoder-precision",
        type=str,
        default="fp16",
        choices=PRECISIONS,
        help="Precision mode for the text encoder model.",
    )
    group.add_argument(
        "--text-states-dim",
        type=int,
        default=4096,
        help="Dimension of the text encoder hidden states.",
    )
    group.add_argument(
        "--text-len", type=int, default=256, help="Maximum length of the text input."
    )
    group.add_argument(
        "--tokenizer",
        type=str,
        default="llm",
        choices=list(TOKENIZER_PATH),
        help="Name of the tokenizer model.",
    )
    group.add_argument(
        "--prompt-template",
        type=str,
        default="dit-llm-encode",
        choices=PROMPT_TEMPLATE,
        help="Image prompt template for the decoder-only text encoder model.",
    )
    group.add_argument(
        "--prompt-template-video",
        type=str,
        default="dit-llm-encode-video",
        choices=PROMPT_TEMPLATE,
        help="Video prompt template for the decoder-only text encoder model.",
    )
    group.add_argument(
        "--hidden-state-skip-layer",
        type=int,
        default=2,
        help="Skip layer for hidden states.",
    )
    group.add_argument(
        "--apply-final-norm",
        action="store_true",
        help="Apply final normalization to the used text encoder hidden states.",
    )

    # ---- Second text encoder / tokenizer ----
    group.add_argument(
        "--text-encoder-2",
        type=str,
        default="clipL",
        choices=list(TEXT_ENCODER_PATH),
        help="Name of the second text encoder model.",
    )
    group.add_argument(
        "--text-encoder-precision-2",
        type=str,
        default="fp16",
        choices=PRECISIONS,
        help="Precision mode for the second text encoder model.",
    )
    group.add_argument(
        "--text-states-dim-2",
        type=int,
        default=768,
        help="Dimension of the second text encoder hidden states.",
    )
    group.add_argument(
        "--tokenizer-2",
        type=str,
        default="clipL",
        choices=list(TOKENIZER_PATH),
        help="Name of the second tokenizer model.",
    )
    group.add_argument(
        "--text-len-2",
        type=int,
        default=77,
        help="Maximum length of the second text input.",
    )

    return parser
|
|
|
|
|
def add_denoise_schedule_args(parser: argparse.ArgumentParser):
    """Register the denoise-schedule options on ``parser``.

    Adds ``--denoise-type`` and the flow-matching options (shift, reverse,
    solver, linear-quadratic schedule) inside a dedicated argument group.

    Args:
        parser: Parser that receives the new argument group.

    Returns:
        The same ``parser`` instance, so calls can be chained.
    """
    group = parser.add_argument_group(title="Denoise schedule args")

    group.add_argument(
        "--denoise-type",
        type=str,
        default="flow",
        help="Denoise type for noised inputs.",
    )

    # ---- Flow matching ----
    group.add_argument(
        "--flow-shift",
        type=float,
        default=7.0,
        help="Shift factor for flow matching schedulers.",
    )
    group.add_argument(
        "--flow-reverse",
        action="store_true",
        help="If reverse, learning/sampling from t=1 -> t=0.",
    )
    group.add_argument(
        "--flow-solver",
        type=str,
        default="euler",
        help="Solver for flow matching.",
    )
    group.add_argument(
        "--use-linear-quadratic-schedule",
        action="store_true",
        # Adjacent literals are joined verbatim; the trailing space keeps the
        # two sentences from running together in --help output.
        help="Use linear quadratic schedule for flow matching. "
        "Following MovieGen (https://ai.meta.com/static-resource/movie-gen-research-paper)",
    )
    group.add_argument(
        "--linear-schedule-end",
        type=int,
        default=25,
        help="End step for linear quadratic schedule for flow matching.",
    )

    return parser
|
|
|
|
|
def add_inference_args(parser: argparse.ArgumentParser):
    """Register the inference options on ``parser``.

    Covers model loading (paths, resolution, load key, CPU offload), run-time
    settings (batch size, steps, autocast, output paths), video geometry,
    prompts and seeds, classifier-free guidance scales and reproducibility.

    Args:
        parser: Parser that receives the new argument group.

    Returns:
        The same ``parser`` instance, so calls can be chained.
    """
    group = parser.add_argument_group(title="Inference args")

    # ---- Model loading ----
    group.add_argument(
        "--model-base",
        type=str,
        default=".",
        help="Root path of all the models, including t2v models and extra models.",
    )
    group.add_argument(
        "--dit-weight",
        type=str,
        default="./hunyuan-video-t2v-720p/transformers/mp_rank_00_model_states.pt",
        # Each fragment ends with a separator: adjacent literals are joined
        # verbatim, so without it the sentences run together in --help.
        help="Path to the HunyuanVideo model. If None, search the model in the args.model_root. "
        "1. If it is a file, load the model directly. "
        "2. If it is a directory, search the model in the directory. Support two types of models: "
        "1) named `pytorch_model_*.pt`; "
        "2) named `*_model_states.pt`, where * can be `mp_rank_00`.",
    )
    group.add_argument(
        "--model-resolution",
        type=str,
        default="540p",
        choices=["540p", "720p"],
        help="Resolution of the model to use (540p or 720p).",
    )
    group.add_argument(
        "--load-key",
        type=str,
        default="module",
        help="Key to load the model states. 'module' for the main model, 'ema' for the EMA model.",
    )
    group.add_argument(
        "--use-cpu-offload",
        action="store_true",
        help="Use CPU offload for the model load.",
    )

    # ---- Sampling / output ----
    group.add_argument(
        "--batch-size",
        type=int,
        default=1,
        help="Batch size for inference and evaluation.",
    )
    group.add_argument(
        "--infer-steps",
        type=int,
        default=50,
        help="Number of denoising steps for inference.",
    )
    group.add_argument(
        "--disable-autocast",
        action="store_true",
        help="Disable autocast for denoising loop and vae decoding in pipeline sampling.",
    )
    group.add_argument(
        "--save-path",
        type=str,
        default="./results",
        help="Path to save the generated samples.",
    )
    group.add_argument(
        "--save-path-suffix",
        type=str,
        default="",
        help="Suffix for the directory of saved samples.",
    )
    group.add_argument(
        "--name-suffix",
        type=str,
        default="",
        help="Suffix for the names of saved samples.",
    )
    group.add_argument(
        "--num-videos",
        type=int,
        default=1,
        help="Number of videos to generate for each prompt.",
    )

    # ---- Video geometry ----
    group.add_argument(
        "--video-size",
        type=int,
        nargs="+",
        default=(720, 1280),
        help="Video size for training. If a single value is provided, it will be used for both height "
        "and width. If two values are provided, they will be used for height and width "
        "respectively.",
    )
    group.add_argument(
        "--video-length",
        type=int,
        default=129,
        help="How many frames to sample from a video. if using 3d vae, the number should be 4n+1",
    )

    # ---- Prompt and seed ----
    group.add_argument(
        "--prompt",
        type=str,
        default=None,
        help="Prompt for sampling during evaluation.",
    )
    group.add_argument(
        "--seed-type",
        type=str,
        default="auto",
        choices=["file", "random", "fixed", "auto"],
        help="Seed type for evaluation. If file, use the seed from the CSV file. If random, generate a "
        "random seed. If fixed, use the fixed seed given by `--seed`. If auto, `csv` will use the "
        "seed column if available, otherwise use the fixed `seed` value. `prompt` will use the "
        "fixed `seed` value.",
    )
    group.add_argument("--seed", type=int, default=None, help="Seed for evaluation.")

    # ---- Classifier-free guidance ----
    group.add_argument(
        "--neg-prompt", type=str, default=None, help="Negative prompt for sampling."
    )
    group.add_argument(
        "--cfg-scale", type=float, default=1.0, help="Classifier free guidance scale."
    )
    group.add_argument(
        "--embedded-cfg-scale",
        type=float,
        default=6.0,
        help="Embedded classifier free guidance scale.",
    )

    group.add_argument(
        "--reproduce",
        action="store_true",
        help="Enable reproducibility by setting random seeds and deterministic algorithms.",
    )

    return parser
|
|
|
|
|
def add_parallel_args(parser: argparse.ArgumentParser):
    """Register the parallelism options on ``parser``.

    Adds ``--ulysses-degree`` and ``--ring-degree`` (both integers defaulting
    to 1) inside a dedicated argument group.

    Args:
        parser: Parser that receives the new argument group.

    Returns:
        The same ``parser`` instance, so calls can be chained.
    """
    group = parser.add_argument_group(title="Parallel args")

    group.add_argument(
        "--ulysses-degree",
        type=int,
        default=1,
        help="Ulysses degree.",
    )
    group.add_argument(
        "--ring-degree",
        type=int,
        default=1,
        # Previously a copy of the Ulysses help text.
        help="Ring degree.",
    )

    return parser
|
|
|
|
|
def sanity_check_args(args):
    """Validate the VAE name and reconcile ``latent_channels`` with it.

    The VAE name must start with ``<2-3 digits>-<1-2 digits>c-<word chars>``;
    the number in front of the ``c`` is taken as the VAE channel count.
    ``args.latent_channels`` is filled in from that count when it is ``None``,
    and otherwise normalised to ``int`` and required to equal it.

    Args:
        args: Parsed arguments exposing ``vae`` and ``latent_channels``.
            Modified in place.

    Returns:
        The same ``args`` object.

    Raises:
        ValueError: If the VAE name does not match the expected format, if
            ``latent_channels`` is not an integer value, or if it differs
            from the VAE channel count.
    """
    vae_pattern = r"\d{2,3}-\d{1,2}c-\w+"
    if not re.match(vae_pattern, args.vae):
        raise ValueError(
            f"Invalid VAE model: {args.vae}. Must be in the format of '{vae_pattern}'."
        )

    # Second dash-separated field looks like "16c"; drop the trailing "c".
    vae_channels = int(args.vae.split("-")[1][:-1])

    if args.latent_channels is None:
        args.latent_channels = vae_channels
    else:
        # The CLI may deliver this value as a string (e.g. "16"); comparing a
        # string with an int would reject a perfectly matching value.
        try:
            args.latent_channels = int(args.latent_channels)
        except (TypeError, ValueError) as err:
            raise ValueError(
                f"Latent channels ({args.latent_channels!r}) must be an integer."
            ) from err

    if vae_channels != args.latent_channels:
        raise ValueError(
            f"Latent channels ({args.latent_channels}) must match the VAE channels ({vae_channels})."
        )
    return args
|
|