from typing import Any, Dict, List, Tuple, Union

import torch
from transformers import Pipeline
from transformers.pipelines.base import GenericTensor
from transformers.utils import ModelOutput


class EncodePipeline(Pipeline):
    """Pipeline that turns text into sentence embeddings via the model's `encode` method."""

    def __init__(self, max_length=256, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.max_length = max_length

    def _sanitize_parameters(self, **pipeline_parameters):
        # No runtime parameters are forwarded to preprocess/_forward/postprocess.
        return {}, {}, {}

    def preprocess(
        self, input: Union[str, List[str]], **preprocess_parameters: Dict
    ) -> Dict[str, GenericTensor]:
        tensors = self.tokenizer(
            input,
            max_length=self.max_length,
            padding="max_length",
            truncation=True,
            return_tensors="pt",
        )
        return tensors

    def _forward(
        self, input_tensors: Dict[str, GenericTensor], **forward_parameters: Dict
    ) -> ModelOutput:
        logits = self.model.encode(**input_tensors)
        return logits.tolist()

    def postprocess(
        self, model_outputs: ModelOutput, **postprocess_parameters: Dict
    ) -> Any:
        return model_outputs


class SimilarPipeline(Pipeline):
    """Pipeline that scores the cosine similarity of a sentence pair (or a list of pairs)."""

    def __init__(self, max_length=256, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.max_length = max_length

    def _sanitize_parameters(self, **pipeline_parameters):
        return {}, {}, {}

    def preprocess(
        self,
        input: Union[Tuple[str, str], List[Tuple[str, str]]],
        **preprocess_parameters: Dict,
    ) -> Dict[str, GenericTensor]:
        # Split the pair(s) into left-hand and right-hand sentences.
        if isinstance(input, list):
            a = [pair[0] for pair in input]
            b = [pair[1] for pair in input]
        else:
            a = input[0]
            b = input[1]
        tensors = self.tokenizer(
            a,
            max_length=self.max_length,
            padding="max_length",
            truncation=True,
            return_tensors="pt",
        )
        tensors_b = self.tokenizer(
            b,
            max_length=self.max_length,
            padding="max_length",
            truncation=True,
            return_tensors="pt",
        )
        # Concatenate both sides into one batch so the model is called only once.
        for key in tensors:
            tensors[key] = torch.cat((tensors[key], tensors_b[key]), dim=0)
        return tensors

    def _forward(
        self, input_tensors: Dict[str, GenericTensor], **forward_parameters: Dict
    ) -> ModelOutput:
        _, logits = self.model(**input_tensors)
        # First half of the batch holds the left-hand sentences, second half the right-hand ones.
        logits_a = logits[: logits.size(0) // 2]
        logits_b = logits[logits.size(0) // 2 :]
        logits = torch.nn.functional.cosine_similarity(logits_a, logits_b)
        return logits.tolist()

    def postprocess(
        self, model_outputs: ModelOutput, **postprocess_parameters: Dict
    ) -> Any:
        return model_outputs
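
# --- Usage sketch (illustrative only, not part of the pipelines above) -------------
# Both pipelines assume a project-specific model: EncodePipeline calls
# model.encode(**tokenized), and SimilarPipeline unpacks model(**tokenized) as a
# (_, embeddings) tuple. `SentenceModel`, the import path, and the checkpoint path
# below are hypothetical placeholders for whatever model provides that interface.
#
#   from transformers import AutoTokenizer
#   from my_project.modeling import SentenceModel  # hypothetical import
#
#   tokenizer = AutoTokenizer.from_pretrained("path/to/checkpoint")
#   model = SentenceModel.from_pretrained("path/to/checkpoint")
#
#   encode = EncodePipeline(model=model, tokenizer=tokenizer, max_length=256)
#   vectors = encode("an example sentence")                  # nested list of floats
#
#   similar = SimilarPipeline(model=model, tokenizer=tokenizer, max_length=256)
#   score = similar(("first sentence", "second sentence"))   # [cosine similarity]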