from typing import Any, List, Optional

from pydantic import PrivateAttr
from transformers import pipeline, Pipeline, AutoTokenizer, AutoModelForSeq2SeqLM

from obsei.analyzer.base_analyzer import (
    BaseAnalyzer,
    BaseAnalyzerConfig,
    MAX_LENGTH,
)
from obsei.payload import TextPayload

class TranslationAnalyzer(BaseAnalyzer):
    """Translates incoming text payloads with a Hugging Face seq2seq model."""

    _pipeline: Pipeline = PrivateAttr()
    _max_length: int = PrivateAttr()
    TYPE: str = "Translation"
    model_name_or_path: str

    def __init__(self, **data: Any):
        super().__init__(**data)
        # Build a translation pipeline from the configured checkpoint.
        tokenizer = AutoTokenizer.from_pretrained(self.model_name_or_path)
        model = AutoModelForSeq2SeqLM.from_pretrained(self.model_name_or_path)
        self._pipeline = pipeline(
            "translation", model=model, tokenizer=tokenizer, device=self._device_id
        )
        # Cap input length at the model's positional limit, falling back to
        # the library-wide MAX_LENGTH when the config does not expose one.
        if hasattr(self._pipeline.model.config, "max_position_embeddings"):
            self._max_length = self._pipeline.model.config.max_position_embeddings
        else:
            self._max_length = MAX_LENGTH

    def analyze_input(
        self,
        source_response_list: List[TextPayload],
        analyzer_config: Optional[BaseAnalyzerConfig] = None,
        **kwargs: Any,
    ) -> List[TextPayload]:
        analyzer_output = []
        for batch_responses in self.batchify(source_response_list, self.batch_size):
            # Trim each input to the model's maximum length (a character-level cut).
            texts = [
                source_response.processed_text[: self._max_length]
                for source_response in batch_responses
            ]
            batch_predictions = self._pipeline(texts)
            for prediction, source_response in zip(batch_predictions, batch_responses):
                # Keep the original text alongside the translation, merging any
                # segmented data already attached to the payload.
                segmented_data = {
                    "translation_data": {
                        "original_text": source_response.processed_text
                    }
                }
                if source_response.segmented_data:
                    segmented_data = {
                        **segmented_data,
                        **source_response.segmented_data,
                    }
                analyzer_output.append(
                    TextPayload(
                        processed_text=prediction["translation_text"],
                        meta=source_response.meta,
                        segmented_data=segmented_data,
                        source_name=source_response.source_name,
                    )
                )
        return analyzer_output
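

# ---------------------------------------------------------------------------
# Usage sketch (illustrative, not part of the analyzer itself): translates a
# single payload end to end. "Helsinki-NLP/opus-mt-en-de" is an assumed example
# checkpoint; substitute any seq2seq translation model you have available.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    analyzer = TranslationAnalyzer(model_name_or_path="Helsinki-NLP/opus-mt-en-de")
    samples = [TextPayload(processed_text="Hello, how are you?", source_name="demo")]
    for translated in analyzer.analyze_input(samples):
        # The original text is preserved under segmented_data["translation_data"].
        print(translated.segmented_data["translation_data"]["original_text"])
        print(translated.processed_text)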