jbilcke-hf HF staff commited on
Commit
4648c2c
1 Parent(s): be29b01

Update handler.py

Browse files
Files changed (1) hide show
  1. handler.py +8 -2
handler.py CHANGED
@@ -48,7 +48,8 @@ def get_default_args():
48
  parser.add_argument("--precision", type=str, default="bf16", choices=["bf16", "fp32", "fp16"])
49
  parser.add_argument("--rope-theta", type=int, default=256)
50
  parser.add_argument("--load-key", type=str, default="module")
51
-
 
52
  # VAE settings
53
  parser.add_argument("--vae", type=str, default="884-16c-hy")
54
  parser.add_argument("--vae-precision", type=str, default="fp16")
@@ -139,10 +140,15 @@ class EndpointHandler:
139
  # Set paths for model components
140
  dit_weight_path = Path(path) / "hunyuan-video-t2v-720p/transformers/mp_rank_00_model_states.pt"
141
  original_vae_path = Path(path) / "hunyuan-video-t2v-720p/vae"
142
-
 
 
 
 
143
  # Log all critical paths
144
  logger.info(f"Model base path: {self.args.model_base}")
145
  logger.info(f"DiT weight path: {dit_weight_path}")
 
146
  logger.info(f"Original VAE path: {original_vae_path}")
147
 
148
  # Verify paths exist
 
48
  parser.add_argument("--precision", type=str, default="bf16", choices=["bf16", "fp32", "fp16"])
49
  parser.add_argument("--rope-theta", type=int, default=256)
50
  parser.add_argument("--load-key", type=str, default="module")
51
+ parser.add_argument("--use-fp8", action="store_true", default=False)
52
+
53
  # VAE settings
54
  parser.add_argument("--vae", type=str, default="884-16c-hy")
55
  parser.add_argument("--vae-precision", type=str, default="fp16")
 
140
  # Set paths for model components
141
  dit_weight_path = Path(path) / "hunyuan-video-t2v-720p/transformers/mp_rank_00_model_states.pt"
142
  original_vae_path = Path(path) / "hunyuan-video-t2v-720p/vae"
143
+
144
+ # to save on memory, we activate fp8 weights and we override the previous dit_weight_path setting
145
+ self.args.use_fp8 = True
146
+ dit_weight_path = Path(path) / "hunyuan-video-t2v-720p/transformers/mp_rank_00_model_states_fp8.pt"
147
+
148
  # Log all critical paths
149
  logger.info(f"Model base path: {self.args.model_base}")
150
  logger.info(f"DiT weight path: {dit_weight_path}")
151
+ logger.info(f"Use fp8: {self.args.use_fp8}")
152
  logger.info(f"Original VAE path: {original_vae_path}")
153
 
154
  # Verify paths exist