minskiter's picture
feat(similar.py): update pipeline
fdb0b54
raw
history blame contribute delete
No virus
1.38 kB
from typing import Any, Dict, Tuple
from transformers import Pipeline
from transformers.pipelines.base import GenericTensor
from transformers.utils import ModelOutput
from typing import Union,List
class SimilarPipeline(Pipeline):
    """Run a sentence-pair model over ``(text_a, text_b)`` inputs.

    Each pair is tokenized jointly (``text_a`` as the first sequence and
    ``text_b`` as the second), passed through ``self.model``, and the second
    value the model returns (``logits``) is converted to nested Python lists.

    Note: ``max_length`` is the FIRST positional parameter of ``__init__``,
    so the ``Pipeline`` arguments (model, tokenizer, ...) should be passed by
    keyword; a model passed positionally would be bound to ``max_length``.
    """

    def __init__(self, max_length=512, *args, **kwargs):
        """Create the pipeline.

        Args:
            max_length: Default number of tokens every pair is padded /
                truncated to in ``preprocess``. It can be overridden per call
                with ``max_length=...``.
            *args, **kwargs: Forwarded unchanged to ``Pipeline.__init__``.
        """
        super().__init__(*args, **kwargs)
        self.max_length = max_length

    def _sanitize_parameters(self, max_length=None, **pipeline_parameters):
        """Split call-time keyword arguments into per-stage parameter dicts.

        Only ``max_length`` is recognised; it is routed to ``preprocess``.
        Any other keyword is ignored.

        Returns:
            ``(preprocess_params, forward_params, postprocess_params)``.
        """
        preprocess_params = {}
        if max_length is not None:
            preprocess_params["max_length"] = max_length
        return preprocess_params, {}, {}

    def preprocess(
        self,
        input: Union[Tuple[str, str], List[Tuple[str, str]]],
        **preprocess_parameters: Any,
    ) -> Dict[str, GenericTensor]:
        """Tokenize one pair, or a list of pairs, into PyTorch tensors.

        Args:
            input: A ``(text_a, text_b)`` pair, or a list of such pairs that
                is tokenized as one batch. Only the first two elements of each
                pair are used.
            max_length: Optional keyword overriding ``self.max_length``.

        Returns:
            The tokenizer output (``return_tensors="pt"``), truncated and
            padded to exactly ``max_length`` tokens.
        """
        max_length = preprocess_parameters.get("max_length")
        if max_length is None:
            max_length = self.max_length

        if isinstance(input, list):
            # Batch path: split the pairs into parallel lists of first/second texts.
            a = [pair[0] for pair in input]
            b = [pair[1] for pair in input]
        else:
            a, b = input[0], input[1]

        return self.tokenizer(
            a,
            b,
            max_length=max_length,
            padding="max_length",
            truncation=True,
            return_tensors="pt",
        )

    def _forward(
        self,
        input_tensors: Dict[str, GenericTensor],
        **forward_parameters: Any,
    ) -> GenericTensor:
        """Run the model on the tokenized inputs and return the logits.

        The logits are returned as the model produced them (not as a Python
        list); ``postprocess`` performs the list conversion.
        """
        # NOTE(review): keeping the output a tensor lets the base Pipeline's
        # post-forward handling (device move / batch unpacking) treat it as a
        # tensor — verify against the installed transformers version.
        # NOTE(review): assumes the model returns exactly two values and the
        # second one is the logits — TODO confirm against the model code.
        _, logits = self.model(**input_tensors)
        return logits

    def postprocess(
        self,
        model_outputs: GenericTensor,
        **postprocess_parameters: Any,
    ) -> Any:
        """Convert the logits from ``_forward`` into nested Python lists.

        Values without a ``tolist`` method (e.g. already-plain lists) are
        returned unchanged.
        """
        if hasattr(model_outputs, "tolist"):
            return model_outputs.tolist()
        return model_outputs