Spaces:
Runtime error
Runtime error
from fis.feature_extraction.detection.base import BaseDetector | |
from fis.feature_extraction.embedding.base import BaseEncoder | |
from fis.feature_extraction.pipeline.base import EncodingPipeline | |
class PipelineFactory:
    """Factory method for encoding pipelines.

    Pipelines are built from a detection model and an embedding model and
    stored under a caller-chosen name, so they can be retrieved later with
    :meth:`get`.

    Example use:
        >>> from fis.feature_extraction.pipeline.factory import PipelineFactory
        >>> factory = PipelineFactory()
        >>> factory.register_pipeline(
        ...     name="example_pipeline",
        ...     detection_model=BaseDetector(),
        ...     embedding_model=BaseEncoder()
        ... )
        >>> pipeline = factory.get('example_pipeline')
    """

    def __init__(self):
        """Instantiate an empty factory with no registered pipelines."""
        # Registry mapping pipeline name -> EncodingPipeline instance.
        self._pipelines: dict[str, EncodingPipeline] = {}

    def register_pipeline(self, name: str, detection_model: BaseDetector, embedding_model: BaseEncoder) -> None:
        """Register a new pipeline to the factory.

        An ``EncodingPipeline`` is constructed from the given models and
        stored under ``name``. Registering a name that is already present
        replaces the previously registered pipeline.

        Args:
            name: Name of the pipeline to create.
            detection_model: Instance of a BaseDetector object.
            embedding_model: Instance of a BaseEncoder object.
        """
        pipeline = EncodingPipeline(
            name=name,
            detection_model=detection_model,
            embedding_model=embedding_model,
        )
        self._pipelines[name] = pipeline

    def get(self, name: str) -> EncodingPipeline:
        """Get a pipeline from its name.

        Args:
            name: Name of the pipeline to get.

        Raises:
            ValueError: If no pipeline has been registered with the given name.

        Returns:
            Encoding pipeline.
        """
        # Look the key up directly instead of testing the value's truthiness:
        # a registered pipeline object that evaluates falsy (e.g. via
        # __len__/__bool__) must still be returned, not reported as missing.
        try:
            return self._pipelines[name]
        except KeyError:
            # Keep ValueError as the public error type; drop the internal
            # KeyError context so the traceback points at the real problem.
            raise ValueError(f"No pipeline registered under the name {name!r}.") from None