Allow overriding config values from load_pipeline_from_config_path
Browse files- flux_pipeline.py +7 -1
flux_pipeline.py
CHANGED
@@ -608,12 +608,18 @@ class FluxPipeline:
|
|
608 |
|
609 |
@classmethod
|
610 |
def load_pipeline_from_config_path(
|
611 |
-
cls, path: str, flow_model_path: str = None, debug: bool = False
|
612 |
) -> "FluxPipeline":
|
613 |
with torch.inference_mode():
|
614 |
config = load_config_from_path(path)
|
615 |
if flow_model_path:
|
616 |
config.ckpt_path = flow_model_path
|
|
|
|
|
|
|
|
|
|
|
|
|
617 |
return cls.load_pipeline_from_config(config, debug=debug)
|
618 |
|
619 |
@classmethod
|
|
|
608 |
|
609 |
@classmethod
|
610 |
def load_pipeline_from_config_path(
|
611 |
+
cls, path: str, flow_model_path: str = None, debug: bool = False, **kwargs
|
612 |
) -> "FluxPipeline":
|
613 |
with torch.inference_mode():
|
614 |
config = load_config_from_path(path)
|
615 |
if flow_model_path:
|
616 |
config.ckpt_path = flow_model_path
|
617 |
+
for k, v in kwargs.items():
|
618 |
+
if hasattr(config, k):
|
619 |
+
logger.info(
|
620 |
+
f"Overriding config {k}:{getattr(config, k)} with value {v}"
|
621 |
+
)
|
622 |
+
setattr(config, k, v)
|
623 |
return cls.load_pipeline_from_config(config, debug=debug)
|
624 |
|
625 |
@classmethod
|