from dataclasses import dataclass
from typing import Dict, Any, Optional
import base64
import asyncio
import logging
import random
import traceback
import torch
import os
import gc

from diffusers import HunyuanVideoPipeline, HunyuanVideoTransformer3DModel, FasterCacheConfig
from diffusers.hooks import apply_enhance_a_video, EnhanceAVideoConfig

from varnish import Varnish
from varnish.utils import is_truthy, process_input_image

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Whether the endpoint accepts an input image (image-to-video mode)
support_image_prompt = is_truthy(os.getenv("SUPPORT_INPUT_IMAGE_PROMPT"))

# Token used when downloading gated or private LoRA weights from the Hub.
# The HF_API_TOKEN variable name is an assumption; adjust it to match the deployment.
hf_token = os.getenv("HF_API_TOKEN")


@dataclass
class GenerationConfig:
    """Configuration for video generation"""

    # Prompts
    prompt: str
    negative_prompt: str = ""

    # Core generation settings
    num_frames: int = 49  # must be of the form 4k + 1 (see validate_and_adjust)
    height: int = 320
    width: int = 576
    num_inference_steps: int = 50
    guidance_scale: float = 7.0

    # Reproducibility (-1 means "pick a random seed")
    seed: int = -1

    # Varnish post-processing settings
    fps: int = 30
    double_num_frames: bool = False
    super_resolution: bool = False
    grain_amount: float = 0.0
    quality: int = 18  # MP4 encoding quality passed to Varnish

    # Optional audio track generation
    enable_audio: bool = False
    audio_prompt: str = ""
    audio_negative_prompt: str = "voices, voice, talking, speaking, speech"

    # TeaCache acceleration
    enable_teacache: bool = False
    teacache_threshold: float = 0.15

    # Enhance-A-Video
    enable_enhance_a_video: bool = False
    enhance_a_video_weight: float = 5.0

    # Optional LoRA adapter
    lora_model_name: str = ""
    lora_model_weight_file: str = ""
    lora_model_trigger: str = ""

    # Quality used when pre-processing the input image in image-to-video mode.
    # The field name is assumed from the config.input_image_quality reference in __call__.
    input_image_quality: int = 100

    def validate_and_adjust(self) -> 'GenerationConfig':
        """Validate and adjust parameters"""
        # HunyuanVideo expects num_frames of the form 4k + 1,
        # so snap the requested value down (e.g. 50 or 52 -> 49, while 53 stays 53)
        k = (self.num_frames - 1) // 4
        self.num_frames = (k * 4) + 1

        # Pick a random seed when none was provided
        if self.seed == -1:
            self.seed = random.randint(0, 2**32 - 1)

        return self
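
# Illustrative request payload (field names follow how EndpointHandler.__call__
# parses the request below; values are only examples):
#
#   {
#     "inputs": {"prompt": "a cinematic drone shot over a snowy mountain ridge"},
#     "parameters": {"num_frames": 49, "width": 576, "height": 320, "fps": 30, "seed": 42}
#   }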
class EndpointHandler:
    """Handles video generation requests using HunyuanVideo and Varnish"""

    def __init__(self, path: str = ""):
        """Initialize handler with models

        Args:
            path: Path to model weights
        """
        self.device = "cuda" if torch.cuda.is_available() else "cpu"

        # Load the video transformer in bfloat16 to reduce memory usage
        transformer = HunyuanVideoTransformer3DModel.from_pretrained(
            path,
            subfolder="transformer",
            torch_dtype=torch.bfloat16
        )

        if support_image_prompt:
            # Image-to-video mode is not wired up in this build of the handler
            raise Exception("Please use a version of Diffusers that supports HunyuanImageToVideoPipeline")
        else:
            self.text_to_video = HunyuanVideoPipeline.from_pretrained(
                path,
                transformer=transformer,
                torch_dtype=torch.float16,
            ).to(self.device)

            # Cast each sub-module to its preferred precision
            self.text_to_video.text_encoder = self.text_to_video.text_encoder.half()
            self.text_to_video.text_encoder_2 = self.text_to_video.text_encoder_2.half()
            self.text_to_video.transformer = self.text_to_video.transformer.to(torch.bfloat16)
            self.text_to_video.vae = self.text_to_video.vae.half()

        # FasterCache settings for skipping redundant attention computation across timesteps
        faster_cache_config = FasterCacheConfig(
            current_timestep_callback=lambda: self.text_to_video.current_timestep,
            spatial_attention_block_skip_range=2,
            spatial_attention_timestep_skip_range=(-1, 901),
            unconditional_batch_skip_range=2,
            attention_weight_callback=lambda _: 0.5,
            is_guidance_distilled=True,
        )
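
        # Note: faster_cache_config is built above but never applied to the pipeline in
        # this handler. With a recent diffusers release it could presumably be enabled
        # through the transformer's caching API (assumption, left disabled here), e.g.:
        #
        #     self.text_to_video.transformer.enable_cache(faster_cache_config)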
        # Track which LoRA adapter (if any) is currently loaded so it can be swapped per request
        self._current_lora_model = None

        # Varnish handles frame post-processing (interpolation, upscaling, audio, MP4 encoding)
        self.varnish = Varnish(
            device=self.device,
            model_base_dir="/repository/varnish"
        )

    async def process_frames(
        self,
        frames: torch.Tensor,
        config: GenerationConfig
    ) -> tuple[str, dict]:
        """Post-process generated frames using Varnish

        Args:
            frames: Generated video frames tensor
            config: Generation configuration

        Returns:
            Tuple of (video data URI, metadata dictionary)
        """
        try:
            # Run Varnish post-processing (interpolation, upscaling, grain, optional audio)
            result = await self.varnish(
                input_data=frames,
                fps=config.fps,
                double_num_frames=config.double_num_frames,
                super_resolution=config.super_resolution,
                grain_amount=config.grain_amount,
                enable_audio=config.enable_audio,
                audio_prompt=config.audio_prompt,
                audio_negative_prompt=config.audio_negative_prompt
            )

            # Encode the result as an MP4 data URI
            video_uri = await result.write(type="data-uri", quality=config.quality)

            # Collect metadata describing the final video
            metadata = {
                "width": result.metadata.width,
                "height": result.metadata.height,
                "num_frames": result.metadata.frame_count,
                "fps": result.metadata.fps,
                "duration": result.metadata.duration,
                "seed": config.seed,
                "enable_teacache": config.enable_teacache,
                "teacache_threshold": config.teacache_threshold if config.enable_teacache else 0,
                "enable_enhance_a_video": config.enable_enhance_a_video,
                "enhance_a_video_weight": config.enhance_a_video_weight if config.enable_enhance_a_video else 0,
            }

            return video_uri, metadata

        except Exception as e:
            logger.error(f"Error in process_frames: {str(e)}")
            raise RuntimeError(f"Failed to process frames: {str(e)}") from e

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Process video generation requests

        Args:
            data: Request data containing:
                - inputs (str or dict): Prompt for video generation, optionally with an input image
                - parameters (dict): Generation parameters

        Returns:
            Dictionary containing:
                - video: Base64 encoded MP4 data URI
                - content-type: MIME type
                - metadata: Generation metadata
        """
        # The prompt can be passed either as a plain string or inside a dict
        inputs = data.pop("inputs", data)
        input_image = None
        if isinstance(inputs, dict):
            prompt = inputs.get("prompt", "")
            # The "image" key is an assumed name for the optional input image (image-to-video mode)
            input_image = inputs.get("image")
        else:
            prompt = inputs

        params = data.get("parameters", {})

        # Build and sanitize the generation configuration
        config = GenerationConfig(
            prompt=prompt,
            negative_prompt=params.get("negative_prompt", ""),
            num_frames=params.get("num_frames", 49),
            height=params.get("height", 320),
            width=params.get("width", 576),
            num_inference_steps=params.get("num_inference_steps", 50),
            guidance_scale=params.get("guidance_scale", 7.0),
            seed=params.get("seed", -1),
            fps=params.get("fps", 30),
            double_num_frames=params.get("double_num_frames", False),
            super_resolution=params.get("super_resolution", False),
            grain_amount=params.get("grain_amount", 0.0),
            quality=params.get("quality", 18),
            enable_audio=params.get("enable_audio", False),
            audio_prompt=params.get("audio_prompt", ""),
            audio_negative_prompt=params.get("audio_negative_prompt", "voices, voice, talking, speaking, speech"),
            enable_teacache=params.get("enable_teacache", False),
            teacache_threshold=params.get("teacache_threshold", 0.15),
            enable_enhance_a_video=params.get("enable_enhance_a_video", False),
            enhance_a_video_weight=params.get("enhance_a_video_weight", 5.0),
            lora_model_name=params.get("lora_model_name", ""),
            lora_model_weight_file=params.get("lora_model_weight_file", ""),
            lora_model_trigger=params.get("lora_model_trigger", ""),
            # assumed parameter name, mirroring the input_image_quality config field
            input_image_quality=params.get("input_image_quality", 100),
        ).validate_and_adjust()

        try:
            # Seed all RNGs; validate_and_adjust() has already replaced seed == -1
            # with a random value, so the else branch is only a safety net
            if config.seed != -1:
                torch.manual_seed(config.seed)
                random.seed(config.seed)
                generator = torch.Generator(device=self.device).manual_seed(config.seed)
            else:
                generator = None

            with torch.autocast("cuda", dtype=torch.bfloat16), torch.no_grad(), torch.inference_mode():

                # Keyword arguments shared by both pipelines
                generation_kwargs = {
                    "prompt": config.prompt,
                    "num_frames": config.num_frames,
                    "height": config.height,
                    "width": config.width,
                    "num_inference_steps": config.num_inference_steps,
                    "guidance_scale": config.guidance_scale,
                    "generator": generator,
                    "output_type": "pt",
                }
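
                # Note: enable_teacache and enable_enhance_a_video are accepted and echoed
                # back in the response metadata, but this handler does not currently wire
                # them into the diffusion pipeline (apply_enhance_a_video / EnhanceAVideoConfig
                # are imported yet never applied).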
                # Swap LoRA adapters when the requested one differs from the currently loaded one
                if hasattr(self, '_current_lora_model'):
                    if self._current_lora_model != (config.lora_model_name, config.lora_model_weight_file):
                        # Unload any previously loaded LoRA weights
                        if support_image_prompt and hasattr(self.image_to_video, 'unload_lora_weights'):
                            self.image_to_video.unload_lora_weights()
                        else:
                            if hasattr(self.text_to_video, 'unload_lora_weights'):
                                self.text_to_video.unload_lora_weights()

                        if config.lora_model_name:
                            # Load the newly requested LoRA weights
                            if support_image_prompt and hasattr(self.image_to_video, 'load_lora_weights'):
                                self.image_to_video.load_lora_weights(
                                    config.lora_model_name,
                                    weight_name=config.lora_model_weight_file if config.lora_model_weight_file else None,
                                    token=hf_token,
                                )
                            else:
                                if hasattr(self.text_to_video, 'load_lora_weights'):
                                    self.text_to_video.load_lora_weights(
                                        config.lora_model_name,
                                        weight_name=config.lora_model_weight_file if config.lora_model_weight_file else None,
                                        token=hf_token,
                                    )
                        self._current_lora_model = (config.lora_model_name, config.lora_model_weight_file)

                # Prepend the LoRA trigger word(s) to the prompt when provided
                if config.lora_model_trigger:
                    generation_kwargs["prompt"] = f"{config.lora_model_trigger} {generation_kwargs['prompt']}"

                # Generate frames with the image-to-video pipeline when an input image is
                # provided (and supported); otherwise fall back to text-to-video
                if support_image_prompt and input_image:
                    processed_image = process_input_image(
                        input_image,
                        config.width,
                        config.height,
                        config.input_image_quality,
                    )
                    generation_kwargs["image"] = processed_image
                    frames = self.image_to_video(**generation_kwargs).frames
                else:
                    frames = self.text_to_video(**generation_kwargs).frames
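
            # Varnish post-processing is asynchronous while this handler is synchronous,
            # so the coroutine is driven explicitly on an event loop below.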
            try:
                loop = asyncio.get_event_loop()
            except RuntimeError:
                loop = asyncio.new_event_loop()
                asyncio.set_event_loop(loop)

            video_uri, metadata = loop.run_until_complete(self.process_frames(frames, config))

            # Free GPU memory before returning
            torch.cuda.empty_cache()
            torch.cuda.reset_peak_memory_stats()
            gc.collect()

            return {
                "video": video_uri,
                "content-type": "video/mp4",
                "metadata": metadata
            }

        except Exception as e:
            message = f"Error generating video ({str(e)})\n{traceback.format_exc()}"
            logger.error(message)
            raise RuntimeError(message) from e