from typing import Any, Dict, Tuple
from typing import Union, List

from transformers import Pipeline
from transformers.pipelines.base import GenericTensor
from transformers.utils import ModelOutput


class SimilarPipeline(Pipeline):
    """Pipeline that scores a pair of texts with a sentence-pair model.

    Each input is a ``(text_a, text_b)`` pair. The pair is tokenized as a
    two-segment sequence, fed to ``self.model``, and the resulting logits are
    returned as nested Python lists.
    """

    def __init__(self, max_length=512, *args, **kwargs):
        """Create the pipeline.

        Args:
            max_length: Default token length every pair is padded/truncated to.
            *args, **kwargs: Forwarded unchanged to ``Pipeline.__init__``.

        NOTE: ``max_length`` is the first positional parameter, so any other
        positional argument would be shifted by one. Pass ``model``,
        ``tokenizer`` and friends by keyword.
        """
        super().__init__(*args, **kwargs)
        self.max_length = max_length

    def _sanitize_parameters(self, **pipeline_parameters):
        """Split call-time keyword arguments into per-stage parameter dicts.

        Only ``max_length`` is recognised; it is routed to ``preprocess`` so a
        single call can override the instance default. All other keyword
        arguments are ignored, as before.

        Returns:
            A ``(preprocess, forward, postprocess)`` triple of dicts.
        """
        preprocess_kwargs = {}
        if pipeline_parameters.get("max_length") is not None:
            preprocess_kwargs["max_length"] = pipeline_parameters["max_length"]
        return preprocess_kwargs, {}, {}

    def preprocess(
        self,
        input: Union[Tuple[str, str], List[Tuple[str, str]]],
        **preprocess_parameters: Dict
    ) -> Dict[str, GenericTensor]:
        """Tokenize one text pair, or a list of text pairs.

        Args:
            input: Either a single ``(text_a, text_b)`` pair or a list of such
                pairs. A list is split into two parallel lists so the
                tokenizer encodes it as one batch.
            **preprocess_parameters: Optional ``max_length`` overriding
                ``self.max_length`` for this call.

        Returns:
            The tokenizer output as PyTorch tensors, padded and truncated to
            exactly ``max_length`` tokens.
        """
        max_length = preprocess_parameters.get("max_length", self.max_length)

        if isinstance(input, list):
            first_texts = [pair[0] for pair in input]
            second_texts = [pair[1] for pair in input]
        else:
            first_texts, second_texts = input[0], input[1]

        return self.tokenizer(
            first_texts,
            second_texts,
            max_length=max_length,
            padding="max_length",
            truncation=True,
            return_tensors="pt",
        )

    def _forward(
        self, input_tensors: Dict[str, GenericTensor], **forward_parameters: Dict
    ) -> GenericTensor:
        """Run the model and return its logits tensor.

        Models that return a dict-like output (``ModelOutput`` is dict-like)
        are read by the ``"logits"`` key, because unpacking such an object
        yields its keys rather than its tensors. Any other output is expected
        to be a two-element sequence whose second element is the logits.
        """
        outputs = self.model(**input_tensors)
        if isinstance(outputs, dict):
            logits = outputs["logits"]
        else:
            _, logits = outputs
        # Return the tensor itself: the conversion to Python lists happens in
        # ``postprocess`` so the base class still receives a tensor here.
        return logits

    def postprocess(
        self, model_outputs: ModelOutput, **postprocess_parameters: Dict
    ) -> Any:
        """Convert the logits into plain nested Python lists.

        Values without a ``tolist`` method (for example an already converted
        list) are passed through untouched.
        """
        if hasattr(model_outputs, "tolist"):
            return model_outputs.tolist()
        return model_outputs