Spaces:
Sleeping
Sleeping
from typing import Any, List, Optional | |
from pydantic import PrivateAttr | |
from transformers import pipeline, Pipeline, AutoTokenizer, AutoModelForSeq2SeqLM | |
from obsei.analyzer.base_analyzer import ( | |
BaseAnalyzer, | |
BaseAnalyzerConfig, | |
MAX_LENGTH, | |
) | |
from obsei.payload import TextPayload | |
class TranslationAnalyzer(BaseAnalyzer):
    """Translate payload text with a Hugging Face sequence-to-sequence model.

    On construction, the tokenizer and model named by ``model_name_or_path``
    are loaded and wrapped in a ``"translation"`` pipeline on the device
    given by ``self._device_id``.
    """

    _pipeline: Pipeline = PrivateAttr()
    _max_length: int = PrivateAttr()
    TYPE: str = "Translation"
    model_name_or_path: str

    def __init__(self, **data: Any):
        super().__init__(**data)

        seq2seq_tokenizer = AutoTokenizer.from_pretrained(self.model_name_or_path)
        seq2seq_model = AutoModelForSeq2SeqLM.from_pretrained(self.model_name_or_path)
        self._pipeline = pipeline(
            "translation",
            model=seq2seq_model,
            tokenizer=seq2seq_tokenizer,
            device=self._device_id,
        )

        # Prefer the limit declared in the model config. Fall back to the
        # shared MAX_LENGTH default when the config has no such attribute.
        self._max_length = getattr(
            self._pipeline.model.config, "max_position_embeddings", MAX_LENGTH
        )

    def analyze_input(
        self,
        source_response_list: List[TextPayload],
        analyzer_config: Optional[BaseAnalyzerConfig] = None,
        **kwargs: Any,
    ) -> List[TextPayload]:
        """Translate each payload and return new payloads carrying the result.

        Args:
            source_response_list: payloads whose ``processed_text`` is translated.
            analyzer_config: accepted for interface compatibility; not read here.
            **kwargs: accepted for interface compatibility; not read here.

        Returns:
            A new ``TextPayload`` for every (input, prediction) pair produced by
            the pipeline. Each one holds the translation as ``processed_text``,
            keeps the input's ``meta`` and ``source_name``, and records the
            untranslated text under ``segmented_data["translation_data"]``.
        """
        translated_payloads: List[TextPayload] = []
        max_chars = self._max_length

        for batch in self.batchify(source_response_list, self.batch_size):
            # NOTE: the limit comes from the model config (a position count),
            # but it is applied here as a plain character slice.
            inputs = [payload.processed_text[:max_chars] for payload in batch]
            outputs = self._pipeline(inputs)

            for payload, output in zip(batch, outputs):
                # The full, untruncated input text is what gets recorded.
                segments = {
                    "translation_data": {"original_text": payload.processed_text}
                }
                # Keys already present on the incoming payload take precedence,
                # including any existing "translation_data" entry.
                if payload.segmented_data:
                    segments.update(payload.segmented_data)

                translated_payloads.append(
                    TextPayload(
                        processed_text=output["translation_text"],
                        meta=payload.meta,
                        segmented_data=segments,
                        source_name=payload.source_name,
                    )
                )

        return translated_payloads