# HunyuanVideo-HFIE / handler.py
from dataclasses import dataclass
from typing import Dict, Any, Optional
import base64
import asyncio
import logging
import random
import traceback
import torch
import os
import gc
# note: there is no HunyuanImageToVideoPipeline yet in Diffusers
from diffusers import HunyuanVideoPipeline, HunyuanVideoTransformer3DModel, FasterCacheConfig
from diffusers.hooks import apply_enhance_a_video, EnhanceAVideoConfig
from varnish import Varnish
from varnish.utils import is_truthy, process_input_image
# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# Check environment variable for pipeline support
support_image_prompt = is_truthy(os.getenv("SUPPORT_INPUT_IMAGE_PROMPT"))

# Hugging Face token used when downloading gated or private LoRA weights
# (referenced by load_lora_weights() below; the env var name is an assumption)
hf_token = os.getenv("HF_API_TOKEN")
@dataclass
class GenerationConfig:
"""Configuration for video generation"""
    # Content settings
    prompt: str
    negative_prompt: str = ""
    # Input image settings (used by the image-to-video path below; the field name and default are assumptions)
    input_image_quality: int = 100
# Model settings
num_frames: int = 49 # Should be 4k + 1 format
height: int = 320
width: int = 576
num_inference_steps: int = 50
guidance_scale: float = 7.0
# Reproducibility
seed: int = -1
# Varnish post-processing settings
fps: int = 30
double_num_frames: bool = False
super_resolution: bool = False
grain_amount: float = 0.0
quality: int = 18 # CRF scale (0-51, lower is better)
# Audio settings
enable_audio: bool = False
audio_prompt: str = ""
audio_negative_prompt: str = "voices, voice, talking, speaking, speech"
# TeaCache settings
enable_teacache: bool = False
teacache_threshold: float = 0.15 # values: 0 (original), 0.1 (1.6x speedup), 0.15 (2.1x speedup)
# Enhance-A-Video settings
enable_enhance_a_video: bool = False
enhance_a_video_weight: float = 5.0
# LoRA settings
lora_model_name: str = "" # HuggingFace repo ID or path to LoRA model
lora_model_weight_file: str = "" # Specific weight file to load from the LoRA model
lora_model_trigger: str = "" # Optional trigger word to prepend to the prompt
def validate_and_adjust(self) -> 'GenerationConfig':
"""Validate and adjust parameters"""
# Ensure num_frames follows 4k + 1 format
k = (self.num_frames - 1) // 4
self.num_frames = (k * 4) + 1
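        # Example: num_frames=50 gives k=12 and is rounded down to 49, while 49 is kept as-is.
        # (HunyuanVideo's VAE compresses frames temporally in groups of 4, which is presumably
        # why the frame count must have the 4k + 1 form.)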
# Set random seed if not specified
if self.seed == -1:
self.seed = random.randint(0, 2**32 - 1)
return self
class EndpointHandler:
"""Handles video generation requests using HunyuanVideo and Varnish"""
def __init__(self, path: str = ""):
"""Initialize handler with models
Args:
path: Path to model weights
"""
self.device = "cuda" if torch.cuda.is_available() else "cpu"
# Initialize transformer with Enhance-A-Video injection first
transformer = HunyuanVideoTransformer3DModel.from_pretrained(
path,
subfolder="transformer",
torch_dtype=torch.bfloat16
)
        if support_image_prompt:
            raise NotImplementedError(
                "Image-to-video is not supported yet: this version of Diffusers has no HunyuanImageToVideoPipeline."
            )
# # Initialize image-to-video pipeline
# self.image_to_video = HunyuanImageToVideoPipeline.from_pretrained(
# path,
# transformer=transformer,
# torch_dtype=torch.float16,
# ).to(self.device)
#
# # Initialize components in appropriate precision
# self.image_to_video.text_encoder = self.image_to_video.text_encoder.half()
# self.image_to_video.text_encoder_2 = self.image_to_video.text_encoder_2.half()
# self.image_to_video.transformer = self.image_to_video.transformer.to(torch.bfloat16)
# self.image_to_video.vae = self.image_to_video.vae.half()
# apply_enhance_a_video(self.image_to_video.transformer, EnhanceAVideoConfig(
# weight=config.enhance_a_video_weight if config.enable_enhance_a_video else 0.0,
# num_frames_callback=lambda: (config.num_frames - 1),
# _attention_type=1
# ))
else:
# Initialize text-to-video pipeline
self.text_to_video = HunyuanVideoPipeline.from_pretrained(
path,
transformer=transformer,
torch_dtype=torch.float16,
).to(self.device)
# Initialize components in appropriate precision
self.text_to_video.text_encoder = self.text_to_video.text_encoder.half()
self.text_to_video.text_encoder_2 = self.text_to_video.text_encoder_2.half()
self.text_to_video.transformer = self.text_to_video.transformer.to(torch.bfloat16)
self.text_to_video.vae = self.text_to_video.vae.half()
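            # (float16 for the text encoders and VAE, bfloat16 for the transformer;
            # bf16 is presumably chosen for the transformer to avoid fp16 overflow.)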
# apply_enhance_a_video(self.text_to_video.transformer, EnhanceAVideoConfig(
# # weight=config.enhance_a_video_weight if config.enable_enhance_a_video else 0.0,
# weight=config.enhance_a_video_weight,
# num_frames_callback=lambda: (config.num_frames - 1),
# _attention_type=1
# ))
        # FasterCache configuration. Note that it is currently *not* applied:
        # the enable_cache() call below is commented out. The values come from:
        # https://github.com/huggingface/diffusers/pull/10163/files#diff-777f4ee62cb325371233a450e0f6cc0ba357a3fade2ec2dea912260b4f8d08ceR67-R74
faster_cache_config = FasterCacheConfig(
current_timestep_callback=lambda: self.text_to_video.current_timestep,
spatial_attention_block_skip_range=2,
spatial_attention_timestep_skip_range=(-1, 901),
unconditional_batch_skip_range=2,
attention_weight_callback=lambda _: 0.5,
is_guidance_distilled=True,
# do we need to uncomment those?
#unconditional_batch_timestep_skip_range=(-1, 901),
#tensor_format="BFCHW",
)
#self.text_to_video.transformer.enable_cache(faster_cache_config)
# Initialize LoRA tracking
self._current_lora_model = None
# Initialize Varnish for post-processing
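        # (Varnish performs the post-processing steps configured above: optional frame
        # interpolation, super-resolution, film grain, audio generation, and MP4 encoding.)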
self.varnish = Varnish(
device=self.device,
model_base_dir="/repository/varnish"
)
async def process_frames(
self,
frames: torch.Tensor,
config: GenerationConfig
) -> tuple[str, dict]:
"""Post-process generated frames using Varnish
Args:
frames: Generated video frames tensor
config: Generation configuration
Returns:
Tuple of (video data URI, metadata dictionary)
"""
try:
# Process video with Varnish
result = await self.varnish(
input_data=frames,
fps=config.fps,
double_num_frames=config.double_num_frames,
super_resolution=config.super_resolution,
grain_amount=config.grain_amount,
enable_audio=config.enable_audio,
audio_prompt=config.audio_prompt,
audio_negative_prompt=config.audio_negative_prompt
)
# Convert to data URI
video_uri = await result.write(type="data-uri", quality=config.quality)
# Collect metadata
metadata = {
"width": result.metadata.width,
"height": result.metadata.height,
"num_frames": result.metadata.frame_count,
"fps": result.metadata.fps,
"duration": result.metadata.duration,
"seed": config.seed,
"enable_teacache": config.enable_teacache,
"teacache_threshold": config.teacache_threshold if config.enable_teacache else 0,
"enable_enhance_a_video": config.enable_enhance_a_video,
"enhance_a_video_weight": config.enhance_a_video_weight if config.enable_enhance_a_video else 0,
}
return video_uri, metadata
        except Exception as e:
            logger.error(f"Error in process_frames: {e}")
            raise RuntimeError(f"Failed to process frames: {e}") from e
def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
"""Process video generation requests
Args:
data: Request data containing:
- inputs (str): Prompt for video generation
- parameters (dict): Generation parameters
Returns:
Dictionary containing:
- video: Base64 encoded MP4 data URI
- content-type: MIME type
- metadata: Generation metadata
"""
        # Extract inputs: the prompt may be passed directly as a string or inside a dict,
        # optionally alongside an input image for image-to-video (the "image" key is an assumption)
        inputs = data.pop("inputs", data)
        if isinstance(inputs, dict):
            prompt = inputs.get("prompt", "")
            input_image = inputs.get("image")
        else:
            prompt = inputs
            input_image = None
params = data.get("parameters", {})
# Create and validate config
config = GenerationConfig(
prompt=prompt,
negative_prompt=params.get("negative_prompt", ""),
num_frames=params.get("num_frames", 49),
            height=params.get("height", 320),
            width=params.get("width", 576),
            input_image_quality=params.get("input_image_quality", 100),
num_inference_steps=params.get("num_inference_steps", 50),
guidance_scale=params.get("guidance_scale", 7.0),
seed=params.get("seed", -1),
fps=params.get("fps", 30),
double_num_frames=params.get("double_num_frames", False),
super_resolution=params.get("super_resolution", False),
grain_amount=params.get("grain_amount", 0.0),
quality=params.get("quality", 18),
enable_audio=params.get("enable_audio", False),
audio_prompt=params.get("audio_prompt", ""),
audio_negative_prompt=params.get("audio_negative_prompt", "voices, voice, talking, speaking, speech"),
enable_teacache=params.get("enable_teacache", False),
# values: 0 (original), 0.1 (1.6x speedup), 0.15 (2.1x speedup).
teacache_threshold=params.get("teacache_threshold", 0.15),
enable_enhance_a_video=params.get("enable_enhance_a_video", False),
enhance_a_video_weight=params.get("enhance_a_video_weight", 5.0),
lora_model_name=params.get("lora_model_name", ""),
lora_model_weight_file=params.get("lora_model_weight_file", ""),
lora_model_trigger=params.get("lora_model_trigger", ""),
).validate_and_adjust()
try:
# Set random seeds
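            # Note: validate_and_adjust() already replaces seed == -1 with a random seed,
            # so the fallback branch below is effectively never taken.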
if config.seed != -1:
torch.manual_seed(config.seed)
random.seed(config.seed)
generator = torch.Generator(device=self.device).manual_seed(config.seed)
else:
generator = None
# Configure TeaCache
#if config.enable_teacache:
# enable_teacache(
# self.pipeline.transformer,
# num_inference_steps=config.num_inference_steps,
# rel_l1_thresh=config.teacache_threshold
# )
#else:
# disable_teacache(self.pipeline.transformer)
            with torch.autocast("cuda", dtype=torch.bfloat16), torch.inference_mode():
# Prepare generation parameters
generation_kwargs = {
"prompt": config.prompt,
                # The installed HunyuanVideoPipeline.__call__() raises
                # "unexpected keyword argument 'negative_prompt'", so it is not passed:
                # "negative_prompt": config.negative_prompt,
"num_frames": config.num_frames,
"height": config.height,
"width": config.width,
"num_inference_steps": config.num_inference_steps,
"guidance_scale": config.guidance_scale,
"generator": generator,
"output_type": "pt",
}
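                # output_type="pt" returns the frames as a torch tensor, which is what
                # process_frames()/Varnish expect further down.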
# Handle LoRA loading/unloading
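                # self._current_lora_model caches the (repo, weight file) pair of the last LoRA
                # that was loaded, so identical follow-up requests skip the unload/reload round-trip.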
if hasattr(self, '_current_lora_model'):
if self._current_lora_model != (config.lora_model_name, config.lora_model_weight_file):
# Unload previous LoRA if it exists and is different
if support_image_prompt and hasattr(self.image_to_video, 'unload_lora_weights'):
self.image_to_video.unload_lora_weights()
else:
if hasattr(self.text_to_video, 'unload_lora_weights'):
self.text_to_video.unload_lora_weights()
if config.lora_model_name:
# Load new LoRA
if support_image_prompt and hasattr(self.image_to_video, 'load_lora_weights'):
self.image_to_video.load_lora_weights(
config.lora_model_name,
weight_name=config.lora_model_weight_file if config.lora_model_weight_file else None,
token=hf_token,
)
else:
if hasattr(self.text_to_video, 'load_lora_weights'):
self.text_to_video.load_lora_weights(
config.lora_model_name,
weight_name=config.lora_model_weight_file if config.lora_model_weight_file else None,
token=hf_token,
)
self._current_lora_model = (config.lora_model_name, config.lora_model_weight_file)
# Modify prompt if trigger word is provided
if config.lora_model_trigger:
generation_kwargs["prompt"] = f"{config.lora_model_trigger} {generation_kwargs['prompt']}"
# Check if image-to-video generation is requested
if support_image_prompt and input_image:
processed_image = process_input_image(
input_image,
config.width,
config.height,
config.input_image_quality,
)
generation_kwargs["image"] = processed_image
frames = self.image_to_video(**generation_kwargs).frames
else:
frames = self.text_to_video(**generation_kwargs).frames
try:
loop = asyncio.get_event_loop()
except RuntimeError:
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
video_uri, metadata = loop.run_until_complete(self.process_frames(frames, config))
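            # Free GPU memory between requests so repeated generations don't accumulate allocations.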
torch.cuda.empty_cache()
torch.cuda.reset_peak_memory_stats()
gc.collect()
return {
"video": video_uri,
"content-type": "video/mp4",
"metadata": metadata
}
        except Exception as e:
            message = f"Error generating video ({e})\n{traceback.format_exc()}"
            logger.error(message)
            raise RuntimeError(message) from e
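
# ---------------------------------------------------------------------------
# Example usage (illustrative sketch only): the handler is normally driven by
# the Inference Endpoints runtime, which instantiates EndpointHandler with the
# repository path and calls it with the request payload. The path and the
# parameter values below are assumptions for local smoke-testing, not values
# taken from the deployment itself.
# ---------------------------------------------------------------------------
# if __name__ == "__main__":
#     handler = EndpointHandler("/repository")
#     result = handler({
#         "inputs": "A cat walks on the grass, realistic style.",
#         "parameters": {
#             "num_frames": 49,
#             "width": 576,
#             "height": 320,
#             "num_inference_steps": 30,
#             "seed": 42,
#         },
#     })
#     print(result["metadata"])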