Commit
•
4648c2c
1
Parent(s):
be29b01
Update handler.py
Browse files- handler.py +8 -2
handler.py
CHANGED
@@ -48,7 +48,8 @@ def get_default_args():
|
|
48 |
parser.add_argument("--precision", type=str, default="bf16", choices=["bf16", "fp32", "fp16"])
|
49 |
parser.add_argument("--rope-theta", type=int, default=256)
|
50 |
parser.add_argument("--load-key", type=str, default="module")
|
51 |
-
|
|
|
52 |
# VAE settings
|
53 |
parser.add_argument("--vae", type=str, default="884-16c-hy")
|
54 |
parser.add_argument("--vae-precision", type=str, default="fp16")
|
@@ -139,10 +140,15 @@ class EndpointHandler:
|
|
139 |
# Set paths for model components
|
140 |
dit_weight_path = Path(path) / "hunyuan-video-t2v-720p/transformers/mp_rank_00_model_states.pt"
|
141 |
original_vae_path = Path(path) / "hunyuan-video-t2v-720p/vae"
|
142 |
-
|
|
|
|
|
|
|
|
|
143 |
# Log all critical paths
|
144 |
logger.info(f"Model base path: {self.args.model_base}")
|
145 |
logger.info(f"DiT weight path: {dit_weight_path}")
|
|
|
146 |
logger.info(f"Original VAE path: {original_vae_path}")
|
147 |
|
148 |
# Verify paths exist
|
|
|
48 |
parser.add_argument("--precision", type=str, default="bf16", choices=["bf16", "fp32", "fp16"])
|
49 |
parser.add_argument("--rope-theta", type=int, default=256)
|
50 |
parser.add_argument("--load-key", type=str, default="module")
|
51 |
+
parser.add_argument("--use-fp8", action="store_true", default=False)
|
52 |
+
|
53 |
# VAE settings
|
54 |
parser.add_argument("--vae", type=str, default="884-16c-hy")
|
55 |
parser.add_argument("--vae-precision", type=str, default="fp16")
|
|
|
140 |
# Set paths for model components
|
141 |
dit_weight_path = Path(path) / "hunyuan-video-t2v-720p/transformers/mp_rank_00_model_states.pt"
|
142 |
original_vae_path = Path(path) / "hunyuan-video-t2v-720p/vae"
|
143 |
+
|
144 |
+
# to save on memory, we activate fp8 weights and we override the previous dit_weight_path setting
|
145 |
+
self.args.use_fp8 = True
|
146 |
+
dit_weight_path = Path(path) / "hunyuan-video-t2v-720p/transformers/mp_rank_00_model_states_fp8.pt"
|
147 |
+
|
148 |
# Log all critical paths
|
149 |
logger.info(f"Model base path: {self.args.model_base}")
|
150 |
logger.info(f"DiT weight path: {dit_weight_path}")
|
151 |
+
logger.info(f"Use fp8: {self.args.use_fp8}")
|
152 |
logger.info(f"Original VAE path: {original_vae_path}")
|
153 |
|
154 |
# Verify paths exist
|