radames's picture
new base
cb92d2b
raw
history blame
496 Bytes
from importlib import import_module
from types import ModuleType
def get_pipeline_class(pipeline_name: str) -> ModuleType:
try:
module = import_module(f"pipelines.{pipeline_name}")
except ModuleNotFoundError:
raise ValueError(f"Pipeline {pipeline_name} module not found")
pipeline_class = getattr(module, "Pipeline", None)
if pipeline_class is None:
raise ValueError(f"'Pipeline' class not found in module '{module_name}'.")
return pipeline_class