File size: 12,537 Bytes
2557c6e
 
2745124
2557c6e
 
 
afceeed
de858d1
2557c6e
 
afceeed
 
de858d1
 
 
c55eec4
 
 
 
 
 
 
2745124
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
afceeed
 
 
 
829c2a5
3f51080
c55eec4
3f51080
afceeed
3f51080
 
4648c2c
 
829c2a5
 
3f51080
 
829c2a5
 
3f51080
 
 
 
 
829c2a5
 
606d9c1
 
829c2a5
 
3f51080
829c2a5
606d9c1
3f51080
 
606d9c1
3f51080
 
 
 
 
 
 
 
829c2a5
 
3f51080
829c2a5
3f51080
829c2a5
 
3f51080
829c2a5
 
606d9c1
afceeed
c55eec4
829c2a5
 
 
afceeed
829c2a5
 
 
 
 
c55eec4
 
829c2a5
606d9c1
 
829c2a5
 
 
 
 
606d9c1
 
 
afceeed
 
 
606d9c1
afceeed
2557c6e
 
 
afceeed
de858d1
 
afceeed
 
606d9c1
de858d1
 
 
 
606d9c1
 
de858d1
 
 
2745124
4648c2c
 
 
 
 
de858d1
 
 
4648c2c
2745124
de858d1
 
 
 
2745124
 
 
 
 
 
 
 
 
53c0486
2745124
 
53c0486
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
de858d1
 
2557c6e
afceeed
2557c6e
 
de858d1
2557c6e
de858d1
 
 
 
 
 
 
53c0486
2557c6e
de858d1
 
 
2557c6e
afceeed
2557c6e
 
 
afceeed
606d9c1
c55eec4
afceeed
 
 
c55eec4
afceeed
 
c55eec4
afceeed
 
 
 
de858d1
 
 
2557c6e
de858d1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c55eec4
de858d1
 
 
 
 
 
 
c55eec4
 
 
de858d1
 
 
 
 
c55eec4
de858d1
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
from typing import Dict, Any
import os
import shutil
from pathlib import Path
import time
from datetime import datetime
import argparse
from loguru import logger
from hyvideo.utils.file_utils import save_videos_grid
from hyvideo.inference import HunyuanVideoSampler
from hyvideo.constants import NEGATIVE_PROMPT

# Configure logger: in addition to loguru's default sink, write log records to
# a local "handler_debug.log" file (configured with rotation="500 MB").
logger.add("handler_debug.log", rotation="500 MB")

# Generation defaults shared by get_default_args() and EndpointHandler.
# Default for the "--model-resolution" argument.
DEFAULT_RESOLUTION = "720p"
# Default output size in pixels; also used to build the default
# "WIDTHxHEIGHT" resolution string of a request.
DEFAULT_WIDTH = 1280
DEFAULT_HEIGHT = 720
DEFAULT_NB_FRAMES = (4 * 30) + 1 # = 121; 129 is an alternative (note: hunyuan requires an extra +1 frame)
DEFAULT_NB_STEPS = 22 # default number of inference steps; 50 is an alternative
# Frame rate passed to save_videos_grid() when the generated video is written.
DEFAULT_FPS = 24

def setup_vae_path(vae_path: Path) -> Path:
    """Stage the VAE files in a fresh ``/tmp/vae`` directory.

    Any existing ``/tmp/vae`` directory is removed first. The VAE config is
    copied under the name ``config.json`` and the model weights keep their
    ``pytorch_model.pt`` name.

    Args:
        vae_path: Directory holding ``hunyuan-video-t2v-720p_vae_config.json``
            and ``pytorch_model.pt``.

    Returns:
        The staging directory that now contains both files.
    """
    staging_dir = Path("/tmp/vae")

    # Always start from an empty directory so stale files never leak through.
    if staging_dir.exists():
        shutil.rmtree(staging_dir)
    staging_dir.mkdir(parents=True)

    logger.info(f"Setting up VAE in temporary directory: {staging_dir}")

    # (log label, file name inside vae_path, file name inside staging_dir)
    transfers = (
        ("config", "hunyuan-video-t2v-720p_vae_config.json", "config.json"),
        ("model", "pytorch_model.pt", "pytorch_model.pt"),
    )
    for label, source_name, target_name in transfers:
        source = vae_path / source_name
        target = staging_dir / target_name
        shutil.copy2(source, target)
        logger.info(f"Copied VAE {label} from {source} to {target}")

    return staging_dir

def get_default_args():
    """Build the default argument namespace without touching ``sys.argv``.

    Every option is registered on an ``argparse`` parser and the parser is
    then run on an empty argument list, so the returned namespace holds
    exactly the declared defaults.

    Returns:
        ``argparse.Namespace`` populated with the default value of each option.
    """
    parser = argparse.ArgumentParser()

    def option(flag, kind, default, **extra):
        # Value-carrying option: ``--flag VALUE`` converted with ``kind``.
        parser.add_argument(flag, type=kind, default=default, **extra)

    def switch(flag, default=False):
        # Boolean ``store_true`` flag with an explicit default.
        parser.add_argument(flag, action="store_true", default=default)

    # --- Model configuration ---
    option("--model", str, "HYVideo-T/2-cfgdistill")
    option("--model-resolution", str, DEFAULT_RESOLUTION, choices=["540p", "720p"])
    option("--latent-channels", int, 16)
    option("--precision", str, "bf16", choices=["bf16", "fp32", "fp16"])
    option("--rope-theta", int, 256)
    option("--load-key", str, "module")
    switch("--use-fp8", default=False)

    # --- VAE ---
    option("--vae", str, "884-16c-hy")
    option("--vae-precision", str, "fp16")
    switch("--vae-tiling", default=True)

    # --- Primary text encoder ---
    option("--text-encoder", str, "llm")
    option("--text-encoder-precision", str, "fp16")
    option("--text-states-dim", int, 4096)
    option("--text-len", int, 256)
    option("--tokenizer", str, "llm")

    # --- Prompt templates ---
    option("--prompt-template", str, "dit-llm-encode")
    option("--prompt-template-video", str, "dit-llm-encode-video")

    # --- Additional / secondary text encoder ---
    option("--hidden-state-skip-layer", int, 2)
    switch("--apply-final-norm")
    option("--text-encoder-2", str, "clipL")
    option("--text-encoder-precision-2", str, "fp16")
    option("--text-states-dim-2", int, 768)
    option("--tokenizer-2", str, "clipL")
    option("--text-len-2", int, 77)

    # --- Model architecture ---
    option("--hidden-size", int, 1024)
    option("--heads-num", int, 16)
    option("--layers-num", int, 24)
    option("--mlp-ratio", float, 4.0)
    switch("--use-guidance-net", default=True)

    # --- Inference ---
    option("--denoise-type", str, "flow")
    option("--flow-shift", float, 7.0)
    switch("--flow-reverse", default=True)
    option("--flow-solver", str, "euler")
    switch("--use-linear-quadratic-schedule")
    option("--linear-schedule-end", int, 25)

    # --- Hardware ---
    switch("--use-cpu-offload", default=False)
    option("--batch-size", int, 1)
    option("--infer-steps", int, DEFAULT_NB_STEPS)
    switch("--disable-autocast")

    # --- Output ---
    option("--save-path", str, "outputs")
    option("--save-path-suffix", str, "")
    option("--name-suffix", str, "")

    # --- Generation ---
    option("--num-videos", int, 1)
    option("--video-size", int, [DEFAULT_HEIGHT, DEFAULT_WIDTH], nargs="+")
    option("--video-length", int, DEFAULT_NB_FRAMES)
    option("--prompt", str, None)
    option("--seed-type", str, "auto", choices=["file", "random", "fixed", "auto"])
    option("--seed", int, None)
    option("--neg-prompt", str, "")
    option("--cfg-scale", float, 1.0)
    option("--embedded-cfg-scale", float, 6.0)
    switch("--reproduce")

    # --- Parallelism ---
    option("--ulysses-degree", int, 1)
    option("--ring-degree", int, 1)

    # An explicit empty list keeps argparse from reading the process arguments.
    return parser.parse_args([])

class EndpointHandler:
    """Text-to-video request handler backed by ``HunyuanVideoSampler``.

    The constructor resolves the model component paths below a root directory
    and loads the sampler. Calling the instance with a request dict runs one
    generation and returns the resulting MP4 as a base64 data URI string.
    """

    def __init__(self, path: str = ""):
        """Initialize the handler with model path and default config.

        Args:
            path: Root directory holding the model files. A relative path
                (including the default ``""``) is resolved against the
                current working directory.

        Raises:
            ValueError: If the resolved root directory does not exist.
            Exception: Whatever ``HunyuanVideoSampler.from_pretrained`` raises
                is logged and re-raised unchanged.
        """
        logger.info(f"Initializing EndpointHandler with path: {path}")

        # Use default args instead of parsing from command line
        self.args = get_default_args()

        # Convert path to absolute path if not already
        models_root_path = Path(path).absolute()
        path = str(models_root_path)
        logger.info(f"Absolute path: {path}")

        # Set up model paths
        self.args.model_base = path

        # To save on memory the fp8 weights are always used, so only the fp8
        # checkpoint path is computed.
        self.args.use_fp8 = True
        dit_weight_path = (
            models_root_path
            / "hunyuan-video-t2v-720p/transformers/mp_rank_00_model_states_fp8.pt"
        )
        original_vae_path = models_root_path / "hunyuan-video-t2v-720p/vae"

        # Log all critical paths
        logger.info(f"Model base path: {self.args.model_base}")
        logger.info(f"DiT weight path: {dit_weight_path}")
        logger.info(f"Use fp8: {self.args.use_fp8}")
        logger.info(f"Original VAE path: {original_vae_path}")

        # Verify paths exist
        logger.info("Checking if paths exist:")
        logger.info(f"DiT weight exists: {dit_weight_path.exists()}")
        logger.info(f"VAE path exists: {original_vae_path.exists()}")

        # NOTE(review): the text encoder / tokenizer overrides below are only
        # applied when the VAE directory exists; this coupling is kept as-is.
        if original_vae_path.exists():
            logger.info(f"VAE path contents: {list(original_vae_path.glob('*'))}")

            # Set up VAE in temporary directory with correct file names
            tmp_vae_path = setup_vae_path(original_vae_path)

            # Override the VAE path in constants to use our temporary directory
            from hyvideo.constants import VAE_PATH, TEXT_ENCODER_PATH, TOKENIZER_PATH
            VAE_PATH["884-16c-hy"] = str(tmp_vae_path)
            logger.info(f"Updated VAE_PATH to: {VAE_PATH['884-16c-hy']}")

            # Update text encoder paths to use absolute paths
            text_encoder_path = str(models_root_path / "text_encoder")
            text_encoder_2_path = str(models_root_path / "text_encoder_2")

            # Update both text encoder and tokenizer paths
            TEXT_ENCODER_PATH.update({
                "llm": text_encoder_path,
                "clipL": text_encoder_2_path,
            })

            TOKENIZER_PATH.update({
                "llm": text_encoder_path,
                "clipL": text_encoder_2_path,
            })

            logger.info("Updated text encoder paths:")
            logger.info(f"TEXT_ENCODER_PATH['llm']: {TEXT_ENCODER_PATH['llm']}")
            logger.info(f"TEXT_ENCODER_PATH['clipL']: {TEXT_ENCODER_PATH['clipL']}")
            logger.info(f"TOKENIZER_PATH['llm']: {TOKENIZER_PATH['llm']}")
            logger.info(f"TOKENIZER_PATH['clipL']: {TOKENIZER_PATH['clipL']}")

        self.args.dit_weight = str(dit_weight_path)

        # Initialize model
        if not models_root_path.exists():
            raise ValueError(f"models_root_path does not exist: {models_root_path}")

        try:
            logger.info("Attempting to initialize HunyuanVideoSampler...")
            self.model = HunyuanVideoSampler.from_pretrained(models_root_path, args=self.args)
            logger.info("Successfully initialized HunyuanVideoSampler")
        except Exception as e:
            logger.error(f"Error initializing model: {e}")
            raise

    @staticmethod
    def _parse_resolution(resolution: Any):
        """Turn a ``"WIDTHxHEIGHT"`` value into a ``(width, height)`` int pair.

        The separator is matched case-insensitively ("1280X720" is accepted).

        Raises:
            ValueError: If the value is not two positive integers joined by "x".
        """
        parts = str(resolution).lower().split("x")
        try:
            if len(parts) != 2:
                raise ValueError
            width, height = int(parts[0]), int(parts[1])
            if width <= 0 or height <= 0:
                raise ValueError
        except ValueError:
            raise ValueError(
                f"Invalid resolution {resolution!r}: expected 'WIDTHxHEIGHT', e.g. "
                f"'{DEFAULT_WIDTH}x{DEFAULT_HEIGHT}'"
            ) from None
        return width, height

    @staticmethod
    def _parse_seed(seed: Any):
        """Normalize the request seed to an ``int`` or ``None``.

        ``None`` and ``-1`` (given as a number or a numeric string) both map
        to ``None``; any other value is converted with ``int()``.
        """
        if seed is None:
            return None
        # Convert before comparing so that "-1" is treated like -1.
        seed = int(seed)
        return None if seed == -1 else seed

    @staticmethod
    def _encode_video(sample) -> str:
        """Write ``sample`` to an MP4 file and return it as a base64 data URI.

        The file lives in a per-call temporary directory, so concurrent
        requests never share a path and the file is removed even when saving
        or reading fails.
        """
        import base64
        import tempfile

        with tempfile.TemporaryDirectory() as tmp_dir:
            temp_path = os.path.join(tmp_dir, "temp_video.mp4")
            save_videos_grid(sample, temp_path, fps=DEFAULT_FPS)
            with open(temp_path, "rb") as f:
                video_bytes = f.read()

        video_base64 = base64.b64encode(video_bytes).decode()
        # Add MP4 data URI prefix
        return f"data:video/mp4;base64,{video_base64}"

    def __call__(self, data: Dict[str, Any]) -> str:
        """Process a single request.

        Args:
            data: Request payload; it is read but not modified. Keys:
                ``inputs`` (required prompt), ``resolution``
                (``"WIDTHxHEIGHT"``), ``video_length``, ``seed`` (``-1`` or
                ``None`` is forwarded to the sampler as ``None``),
                ``num_inference_steps``, ``guidance_scale``, ``flow_shift``
                and ``embedded_guidance_scale``.

        Returns:
            The generated video as a ``data:video/mp4;base64,...`` string.

        Raises:
            ValueError: If ``inputs`` is missing or a parameter cannot be parsed.
            Exception: Errors raised during generation or encoding are logged
                and re-raised unchanged.
        """
        # Log incoming request
        logger.info(f"Processing request with data: {data}")

        # Get inputs from request data
        prompt = data.get("inputs")
        if prompt is None:
            raise ValueError("No prompt provided in the 'inputs' field")

        # Parse resolution
        width, height = self._parse_resolution(
            data.get("resolution", f"{DEFAULT_WIDTH}x{DEFAULT_HEIGHT}")
        )

        # Get other parameters with defaults
        video_length = int(data.get("video_length", DEFAULT_NB_FRAMES))
        seed = self._parse_seed(data.get("seed", -1))
        num_inference_steps = int(data.get("num_inference_steps", DEFAULT_NB_STEPS))
        guidance_scale = float(data.get("guidance_scale", 1.0))
        flow_shift = float(data.get("flow_shift", 7.0))
        embedded_guidance_scale = float(data.get("embedded_guidance_scale", 6.0))

        logger.info(f"Processing with parameters: width={width}, height={height}, "
                    f"video_length={video_length}, seed={seed}, "
                    f"num_inference_steps={num_inference_steps}")

        try:
            # Run inference
            outputs = self.model.predict(
                prompt=prompt,
                height=height,
                width=width,
                video_length=video_length,
                seed=seed,
                negative_prompt="",
                infer_steps=num_inference_steps,
                guidance_scale=guidance_scale,
                num_videos_per_prompt=1,
                flow_shift=flow_shift,
                batch_size=1,
                embedded_guidance_scale=embedded_guidance_scale,
            )

            # Keep only the first generated sample, with its leading
            # dimension restored via unsqueeze(0) before saving.
            sample = outputs['samples'][0].unsqueeze(0)

            video_data_uri = self._encode_video(sample)

            logger.info("Successfully generated and encoded video")

            return video_data_uri

        except Exception as e:
            logger.error(f"Error during video generation: {e}")
            raise