# HunyuanVideo-HFIE / handler.py
from typing import Dict, Any
import os
from pathlib import Path
import time
from datetime import datetime
import torch
import base64
from io import BytesIO
from hyvideo.utils.file_utils import save_videos_grid
from hyvideo.config import parse_args
from hyvideo.inference import HunyuanVideoSampler


class EndpointHandler:
    def __init__(self, path: str = ""):
        """Initialize the handler with the model path.

        Args:
            path: Path to the model weights directory
        """
        self.args = parse_args()
        models_root_path = Path(path)
        if not models_root_path.exists():
            raise ValueError(f"`models_root` does not exist: {models_root_path}")

        # Initialize the model
        self.model = HunyuanVideoSampler.from_pretrained(models_root_path, args=self.args)

        # Default generation parameters
        self.default_params = {
            "num_inference_steps": 50,
            "guidance_scale": 1.0,
            "flow_shift": 7.0,
            "embedded_guidance_scale": 6.0,
            "video_length": 129,  # ~5 s at 24 fps
            "resolution": "1280x720"
        }

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Process the input data and generate a video.

        Args:
            data: Dictionary containing the input parameters.
                Required:
                    - inputs (str): The prompt text
                Optional:
                    - resolution (str): Video resolution such as "1280x720"
                    - video_length (int): Number of frames
                    - seed (int): Random seed (-1 for a random seed)
                    - num_inference_steps (int): Number of inference steps
                    - guidance_scale (float): Guidance scale value
                    - flow_shift (float): Flow shift value
                    - embedded_guidance_scale (float): Embedded guidance scale value

        Returns:
            Dictionary containing the base64-encoded video, the seed used, and the prompt.
        """
        # Get the prompt
        prompt = data.pop("inputs", None)
        if prompt is None:
            raise ValueError("No prompt provided in the 'inputs' field")

        # Get optional parameters, falling back to the defaults
        resolution = data.pop("resolution", self.default_params["resolution"])
        video_length = int(data.pop("video_length", self.default_params["video_length"]))
        seed = int(data.pop("seed", -1))
        num_inference_steps = int(data.pop("num_inference_steps", self.default_params["num_inference_steps"]))
        guidance_scale = float(data.pop("guidance_scale", self.default_params["guidance_scale"]))
        flow_shift = float(data.pop("flow_shift", self.default_params["flow_shift"]))
        embedded_guidance_scale = float(data.pop("embedded_guidance_scale", self.default_params["embedded_guidance_scale"]))

        # Parse the "WIDTHxHEIGHT" resolution string
        width, height = resolution.split("x")
        width, height = int(width), int(height)

        # A seed of -1 means "use a random seed"
        seed = None if seed == -1 else seed

        # Generate the video
        outputs = self.model.predict(
            prompt=prompt,
            height=height,
            width=width,
            video_length=video_length,
            seed=seed,
            negative_prompt="",  # not applicable in inference
            infer_steps=num_inference_steps,
            guidance_scale=guidance_scale,
            num_videos_per_prompt=1,
            flow_shift=flow_shift,
            batch_size=1,
            embedded_guidance_scale=embedded_guidance_scale
        )

        # Process the output video
        samples = outputs['samples']
        sample = samples[0].unsqueeze(0)

        # Save the video to a temporary file
        temp_dir = "/tmp/video_output"
        os.makedirs(temp_dir, exist_ok=True)
        time_flag = datetime.fromtimestamp(time.time()).strftime("%Y-%m-%d-%H:%M:%S")
        video_path = f"{temp_dir}/{time_flag}_seed{outputs['seeds'][0]}.mp4"
        save_videos_grid(sample, video_path, fps=24)

        # Read the video file and encode it as base64
        with open(video_path, "rb") as f:
            video_bytes = f.read()
        video_base64 = base64.b64encode(video_bytes).decode()

        # Clean up the temporary file
        os.remove(video_path)

        return {
            "video_base64": video_base64,
            "seed": outputs['seeds'][0],
            "prompt": outputs['prompts'][0]
        }
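

# --- Example usage -----------------------------------------------------------
# A minimal local smoke-test sketch, shown here for illustration only. It is
# not part of the Inference Endpoints runtime, which imports and calls
# EndpointHandler itself. The "./weights" path and the payload values are
# assumptions; point the path at a real HunyuanVideo weights directory and
# make sure enough GPU memory is available before running.
if __name__ == "__main__":
    handler = EndpointHandler(path="./weights")  # hypothetical weights directory

    payload = {
        "inputs": "A cat walks on the grass, realistic style.",
        "resolution": "1280x720",
        "video_length": 129,
        "seed": 42,
    }
    result = handler(payload)

    # The handler returns the video as a base64 string; decode it back to MP4.
    with open("sample.mp4", "wb") as f:
        f.write(base64.b64decode(result["video_base64"]))
    print(f"Saved sample.mp4 (seed={result['seed']})")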