from typing import *

import torch
import torch.nn as nn

from .. import models


class Pipeline:
    """
    A base class for pipelines.

    A pipeline owns a mapping from names to ``nn.Module`` instances and
    offers helpers to load that mapping from a ``pipeline.json`` description
    and to move every model to a device at once.
    """

    def __init__(
        self,
        models: Optional[dict[str, nn.Module]] = None,
    ):
        """
        Store the given models and switch each of them to evaluation mode.

        Args:
            models: Mapping from model name to module. When ``None``, the
                pipeline is created with an empty model mapping.
        """
        if models is None:
            # Always define the attribute so that `device`, `to`, `cuda` and
            # `cpu` do not fail with AttributeError on an empty pipeline.
            self.models = {}
            return
        self.models = models
        for model in self.models.values():
            model.eval()

    @staticmethod
    def from_pretrained(path: str) -> "Pipeline":
        """
        Load a pretrained model.

        The configuration is read from ``{path}/pipeline.json`` when that
        file exists locally; otherwise ``pipeline.json`` is fetched with
        ``huggingface_hub.hf_hub_download(path, "pipeline.json")``.

        Every entry ``name -> value`` under ``args["models"]`` in the
        configuration is loaded through
        ``models.from_pretrained(f"{path}/{value}")``.

        Args:
            path: Local directory or hub identifier holding ``pipeline.json``.

        Returns:
            A ``Pipeline`` holding the loaded models. The parsed ``args``
            section is kept on the instance as ``_pretrained_args``.
        """
        import os
        import json

        local_config = f"{path}/pipeline.json"
        if os.path.exists(local_config):
            config_file = local_config
        else:
            # Imported lazily so the dependency is only needed for hub loads.
            from huggingface_hub import hf_hub_download
            config_file = hf_hub_download(path, "pipeline.json")

        # JSON is UTF-8 by specification; do not depend on the locale default.
        with open(config_file, 'r', encoding="utf-8") as f:
            args = json.load(f)['args']

        _models = {
            k: models.from_pretrained(f"{path}/{v}")
            for k, v in args['models'].items()
        }

        new_pipeline = Pipeline(_models)
        new_pipeline._pretrained_args = args
        return new_pipeline

    @property
    def device(self) -> torch.device:
        """
        Device of the pipeline, derived from its models.

        The first model exposing a ``device`` attribute wins. Otherwise the
        device of the first parameter of the first model that actually has
        parameters is returned.

        Raises:
            RuntimeError: If no model provides a device.
        """
        for model in self.models.values():
            if hasattr(model, 'device'):
                return model.device
        for model in self.models.values():
            if hasattr(model, 'parameters'):
                # A module without parameters yields nothing; skip it rather
                # than letting a bare StopIteration escape from the property.
                first_param = next(model.parameters(), None)
                if first_param is not None:
                    return first_param.device
        raise RuntimeError("No device found.")

    def to(self, device: torch.device) -> None:
        """Move every model of the pipeline to ``device``."""
        for model in self.models.values():
            model.to(device)

    def cuda(self) -> None:
        """Move every model of the pipeline to the ``cuda`` device."""
        self.to(torch.device("cuda"))

    def cpu(self) -> None:
        """Move every model of the pipeline to the ``cpu`` device."""
        self.to(torch.device("cpu"))