jbilcke-hf HF staff commited on
Commit
de858d1
1 Parent(s): 3f51080

Update handler.py

Browse files
Files changed (1) hide show
  1. handler.py +93 -56
handler.py CHANGED
@@ -4,10 +4,14 @@ from pathlib import Path
4
  import time
5
  from datetime import datetime
6
  import argparse
 
7
  from hyvideo.utils.file_utils import save_videos_grid
8
  from hyvideo.inference import HunyuanVideoSampler
9
  from hyvideo.constants import NEGATIVE_PROMPT
10
 
 
 
 
11
  def get_default_args():
12
  """Create default arguments instead of parsing from command line"""
13
  parser = argparse.ArgumentParser()
@@ -95,38 +99,60 @@ def get_default_args():
95
  class EndpointHandler:
96
  def __init__(self, path: str = ""):
97
  """Initialize the handler with model path and default config."""
 
 
 
98
  # Use default args instead of parsing from command line
99
  self.args = get_default_args()
100
 
 
 
 
 
101
  # Set up model paths
102
  self.args.model_base = path
103
- self.args.dit_weight = str(Path(path) / "hunyuan-video-t2v-720p/transformers/mp_rank_00_model_states.pt")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
104
 
105
  # Initialize model
106
  models_root_path = Path(path)
107
  if not models_root_path.exists():
108
- raise ValueError(f"`models_root` not exists: {models_root_path}")
109
 
110
- self.model = HunyuanVideoSampler.from_pretrained(models_root_path, args=self.args)
111
-
 
 
 
 
 
 
 
 
 
 
112
 
113
  def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
114
- """Process a single request
 
 
115
 
116
- Args:
117
- data: Dictionary containing:
118
- - inputs (str): The prompt text
119
- - resolution (str, optional): Video resolution like "1280x720"
120
- - video_length (int, optional): Number of frames
121
- - num_inference_steps (int, optional): Number of inference steps
122
- - seed (int, optional): Random seed (-1 for random)
123
- - guidance_scale (float, optional): Guidance scale value
124
- - flow_shift (float, optional): Flow shift value
125
- - embedded_guidance_scale (float, optional): Embedded guidance scale
126
-
127
- Returns:
128
- Dictionary containing the generated video as base64 string
129
- """
130
  # Get inputs from request data
131
  prompt = data.pop("inputs", None)
132
  if prompt is None:
@@ -145,41 +171,52 @@ class EndpointHandler:
145
  flow_shift = float(data.pop("flow_shift", 7.0))
146
  embedded_guidance_scale = float(data.pop("embedded_guidance_scale", 6.0))
147
 
148
- # Run inference
149
- outputs = self.model.predict(
150
- prompt=prompt,
151
- height=height,
152
- width=width,
153
- video_length=video_length,
154
- seed=seed,
155
- negative_prompt="",
156
- infer_steps=num_inference_steps,
157
- guidance_scale=guidance_scale,
158
- num_videos_per_prompt=1,
159
- flow_shift=flow_shift,
160
- batch_size=1,
161
- embedded_guidance_scale=embedded_guidance_scale
162
- )
163
-
164
- # Get the video tensor
165
- samples = outputs['samples']
166
- sample = samples[0].unsqueeze(0)
167
 
168
- # Save to temporary file
169
- temp_path = "/tmp/temp_video.mp4"
170
- save_videos_grid(sample, temp_path, fps=24)
171
-
172
- # Read video file and convert to base64
173
- with open(temp_path, "rb") as f:
174
- video_bytes = f.read()
175
- import base64
176
- video_base64 = base64.b64encode(video_bytes).decode()
177
-
178
- # Cleanup
179
- os.remove(temp_path)
180
-
181
- return {
182
- "video_base64": video_base64,
183
- "seed": outputs['seeds'][0],
184
- "prompt": outputs['prompts'][0]
185
- }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4
  import time
5
  from datetime import datetime
6
  import argparse
7
+ from loguru import logger
8
  from hyvideo.utils.file_utils import save_videos_grid
9
  from hyvideo.inference import HunyuanVideoSampler
10
  from hyvideo.constants import NEGATIVE_PROMPT
11
 
12
+ # Configure logger
13
+ logger.add("handler_debug.log", rotation="500 MB")
14
+
15
  def get_default_args():
16
  """Create default arguments instead of parsing from command line"""
17
  parser = argparse.ArgumentParser()
 
99
  class EndpointHandler:
100
  def __init__(self, path: str = ""):
101
  """Initialize the handler with model path and default config."""
102
+ # Log the initial path
103
+ logger.info(f"Initializing EndpointHandler with path: {path}")
104
+
105
  # Use default args instead of parsing from command line
106
  self.args = get_default_args()
107
 
108
+ # Convert path to absolute path if not already
109
+ path = str(Path(path).absolute())
110
+ logger.info(f"Absolute path: {path}")
111
+
112
  # Set up model paths
113
  self.args.model_base = path
114
+
115
+ # Set paths for model components
116
+ dit_weight_path = Path(path) / "hunyuan-video-t2v-720p/transformers/mp_rank_00_model_states.pt"
117
+ vae_path = Path(path) / "hunyuan-video-t2v-720p/vae"
118
+
119
+ # Log all critical paths
120
+ logger.info(f"Model base path: {self.args.model_base}")
121
+ logger.info(f"DiT weight path: {dit_weight_path}")
122
+ logger.info(f"VAE path: {vae_path}")
123
+
124
+ # Verify paths exist
125
+ logger.info("Checking if paths exist:")
126
+ logger.info(f"DiT weight exists: {dit_weight_path.exists()}")
127
+ logger.info(f"VAE path exists: {vae_path.exists()}")
128
+ if vae_path.exists():
129
+ logger.info(f"VAE path contents: {list(vae_path.glob('*'))}")
130
+
131
+ self.args.dit_weight = str(dit_weight_path)
132
 
133
  # Initialize model
134
  models_root_path = Path(path)
135
  if not models_root_path.exists():
136
+ raise ValueError(f"models_root_path does not exist: {models_root_path}")
137
 
138
+ # Log directory contents for debugging
139
+ logger.info("Directory contents:")
140
+ for item in models_root_path.glob("**/*"):
141
+ logger.info(f" {item}")
142
+
143
+ try:
144
+ logger.info("Attempting to initialize HunyuanVideoSampler...")
145
+ self.model = HunyuanVideoSampler.from_pretrained(models_root_path, args=self.args)
146
+ logger.info("Successfully initialized HunyuanVideoSampler")
147
+ except Exception as e:
148
+ logger.error(f"Error initializing model: {str(e)}")
149
+ raise
150
 
151
  def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
152
+ """Process a single request"""
153
+ # Log incoming request
154
+ logger.info(f"Processing request with data: {data}")
155
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
156
  # Get inputs from request data
157
  prompt = data.pop("inputs", None)
158
  if prompt is None:
 
171
  flow_shift = float(data.pop("flow_shift", 7.0))
172
  embedded_guidance_scale = float(data.pop("embedded_guidance_scale", 6.0))
173
 
174
+ logger.info(f"Processing with parameters: width={width}, height={height}, "
175
+ f"video_length={video_length}, seed={seed}, "
176
+ f"num_inference_steps={num_inference_steps}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
177
 
178
+ try:
179
+ # Run inference
180
+ outputs = self.model.predict(
181
+ prompt=prompt,
182
+ height=height,
183
+ width=width,
184
+ video_length=video_length,
185
+ seed=seed,
186
+ negative_prompt="",
187
+ infer_steps=num_inference_steps,
188
+ guidance_scale=guidance_scale,
189
+ num_videos_per_prompt=1,
190
+ flow_shift=flow_shift,
191
+ batch_size=1,
192
+ embedded_guidance_scale=embedded_guidance_scale
193
+ )
194
+
195
+ # Get the video tensor
196
+ samples = outputs['samples']
197
+ sample = samples[0].unsqueeze(0)
198
+
199
+ # Save to temporary file
200
+ temp_path = "/tmp/temp_video.mp4"
201
+ save_videos_grid(sample, temp_path, fps=24)
202
+
203
+ # Read video file and convert to base64
204
+ with open(temp_path, "rb") as f:
205
+ video_bytes = f.read()
206
+ import base64
207
+ video_base64 = base64.b64encode(video_bytes).decode()
208
+
209
+ # Cleanup
210
+ os.remove(temp_path)
211
+
212
+ logger.info("Successfully generated and encoded video")
213
+
214
+ return {
215
+ "video_base64": video_base64,
216
+ "seed": outputs['seeds'][0],
217
+ "prompt": outputs['prompts'][0]
218
+ }
219
+
220
+ except Exception as e:
221
+ logger.error(f"Error during video generation: {str(e)}")
222
+ raise