from typing import Any, Dict, List
import os
import re
import sys

from tqdm import tqdm
import torch
from transformers import AutoModelForSeq2SeqLM, BitsAndBytesConfig
from IndicTransTokenizer.utils import preprocess_batch, postprocess_batch
from IndicTransTokenizer.tokenizer import IndicTransTokenizer


class EndpointHandler():
    """Translate the text of an SRT subtitle file with an IndicTrans2 model.

    The handler loads the ``ai4bharat/indictrans2-en-indic-1B`` checkpoint
    once at construction time. Calling the instance (or ``test``) reads an
    SRT file, translates every cue text in batches and returns the cues with
    their ``text`` replaced by the translation.
    """

    def __init__(self, direction="en-indic", quantization=""):
        """Load the tokenizer and the translation model.

        Args:
            direction: Passed to ``IndicTransTokenizer(direction=...)``.
            quantization: ``"4-bit"`` or ``"8-bit"`` to load the model with a
                bitsandbytes quantization config; any other value loads the
                model unquantized and casts it to half precision.
        """
        # NOTE(review): the checkpoint is fixed to the en-indic model even
        # when `direction` is something else -- confirm this is intended.
        self.model_name = "ai4bharat/indictrans2-en-indic-1B"
        # An SRT cue index line consists of digits only.
        self.utterance_pattern = re.compile(r"^\d+$")
        # "HH:MM:SS,mmm --> HH:MM:SS,mmm"
        self.timestamp_pattern = re.compile(
            r"(\d+:\d+:\d+,\d+)\s*-->\s*(\d+:\d+:\d+,\d+)"
        )
        self.BATCH_SIZE = 16
        self.DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
        self.model = None
        self.tokenizer = None

        if quantization == "4-bit":
            qconfig = BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_use_double_quant=True,
                bnb_4bit_compute_dtype=torch.bfloat16,
            )
        elif quantization == "8-bit":
            # Double quantization and compute dtype are 4-bit-only options;
            # BitsAndBytesConfig has no `bnb_8bit_*` parameters.
            qconfig = BitsAndBytesConfig(load_in_8bit=True)
        else:
            qconfig = None

        self.tokenizer = IndicTransTokenizer(direction=direction)
        self.model = AutoModelForSeq2SeqLM.from_pretrained(
            self.model_name,
            trust_remote_code=True,
            low_cpu_mem_usage=True,
            quantization_config=qconfig,
        )

        if qconfig is None:
            # The tokenized inputs are moved to self.DEVICE in
            # batch_translate, so the unquantized model has to live on the
            # same device or generate() fails with a device mismatch.
            self.model = self.model.to(self.DEVICE)
            self.model = self.model.half()

        self.model.eval()

    def batch_translate(self, input_sentences, src_lang, tgt_lang):
        """Translate sentences in chunks of ``self.BATCH_SIZE``.

        Args:
            input_sentences: List of source sentences.
            src_lang: Source language code handed to ``preprocess_batch``.
            tgt_lang: Target language code handed to ``preprocess_batch`` and
                ``postprocess_batch``.

        Returns:
            The post-processed translations of all batches, concatenated in
            input order. An empty input yields an empty list.
        """
        translations = []
        for start in range(0, len(input_sentences), self.BATCH_SIZE):
            batch = input_sentences[start : start + self.BATCH_SIZE]

            # Preprocess the batch and extract entity mappings
            batch, entity_map = preprocess_batch(
                batch, src_lang=src_lang, tgt_lang=tgt_lang
            )

            # Tokenize the batch and generate input encodings
            inputs = self.tokenizer(
                batch,
                src=True,
                truncation=True,
                padding="longest",
                return_tensors="pt",
                return_attention_mask=True,
            ).to(self.DEVICE)

            # Generate translations using the model
            with torch.no_grad():
                generated_tokens = self.model.generate(
                    **inputs,
                    use_cache=True,
                    min_length=0,
                    max_length=256,
                    num_beams=5,
                    num_return_sequences=1,
                )

            # Decode the generated tokens into text
            decoded = self.tokenizer.batch_decode(
                generated_tokens.detach().cpu().tolist(), src=False
            )

            # Postprocess the translations, including entity replacement
            translations += postprocess_batch(
                decoded, lang=tgt_lang, placeholder_entity_map=entity_map
            )

            # Drop the batch tensors before the next iteration to keep peak
            # GPU memory down.
            del inputs
            if torch.cuda.is_available():
                torch.cuda.empty_cache()

        return translations

    @staticmethod
    def _append_cue(data, cue):
        """Append ``cue`` to ``data`` if it has a timestamp and some text."""
        if cue is None or not cue["start_end"] or not cue["lines"]:
            return
        data.append({
            'utterance_ind': cue["utterance_ind"],
            'start_end': cue["start_end"],
            'text': " ".join(cue["lines"]),
        })

    def read_srt(self, srt_path):
        """Parse an SRT file into a list of cue dictionaries.

        Args:
            srt_path: Path of the SRT file to read.

        Returns:
            A list of dicts with the keys ``utterance_ind`` (index line),
            ``start_end`` (timestamp line) and ``text``. The text lines of a
            multi-line cue are joined with a single space. Cues that lack an
            index, a timestamp or any text are skipped.
        """
        # "utf-8-sig" also accepts plain UTF-8 but drops a leading BOM, which
        # would otherwise keep the first index line from matching r"^\d+$".
        with open(srt_path, 'r', encoding='utf-8-sig') as fp:
            lines = [
                stripped
                for stripped in (raw.strip() for raw in fp)
                if stripped
            ]

        data = []
        cue = None
        for pos, line in enumerate(lines):
            following = lines[pos + 1] if pos + 1 < len(lines) else ""
            if (
                self.utterance_pattern.match(line) is not None
                and self.timestamp_pattern.search(following) is not None
            ):
                # A digit-only line opens a new cue only when a timestamp
                # follows; otherwise it is ordinary subtitle text (e.g. "42").
                self._append_cue(data, cue)
                cue = {"utterance_ind": line, "start_end": "", "lines": []}
            elif self.timestamp_pattern.search(line) is not None:
                if cue is not None and not cue["start_end"]:
                    cue["start_end"] = line
                else:
                    # Timestamp without a preceding index line: finish the
                    # current cue and ignore the index-less one.
                    self._append_cue(data, cue)
                    cue = None
            elif cue is not None and cue["start_end"]:
                cue["lines"].append(line)

        self._append_cue(data, cue)
        return data

    def _translate_transcript(self, inputs) -> List[Dict[str, Any]]:
        """Read, translate and return the cues described by ``inputs``.

        Shared implementation of ``test`` and ``__call__``.
        """
        src_lang = inputs["src_lang"]
        tgt_lang = inputs["tgt_lang"]
        transcript_path = inputs["transcript_path"]

        if self.model is None:
            return []

        transcriptions = self.read_srt(transcript_path)
        source_texts = [entry['text'] for entry in transcriptions]
        indic_translations = self.batch_translate(source_texts, src_lang, tgt_lang)

        output_translations = []
        for entry, translated in tqdm(
            zip(transcriptions, indic_translations), total=len(transcriptions)
        ):
            entry['text'] = translated
            output_translations.append(entry)
        return output_translations

    def test(self, inputs) -> List[Dict[str, Any]]:
        """Translate a transcript from an already unwrapped ``inputs`` mapping.

        Args:
            inputs: Mapping with the keys ``transcript_path``, ``src_lang``
                and ``tgt_lang``.

        Returns:
            The parsed cues with ``text`` replaced by its translation, or an
            empty list when no model is loaded.

        Raises:
            KeyError: If one of the required keys is missing.
        """
        return self._translate_transcript(inputs)

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """Translate the transcript described by a request payload.

        Args:
            data: Either a mapping with an ``inputs`` entry holding
                ``transcript_path``, ``src_lang`` and ``tgt_lang``, or a
                mapping that carries those three keys directly.

        Returns:
            The parsed cues with ``text`` replaced by its translation, or an
            empty list when no model is loaded.

        Raises:
            KeyError: If one of the required keys is missing.
        """
        # `get` rather than `pop`, so the caller's payload is left untouched.
        inputs = data.get("inputs", data)
        return self._translate_transcript(inputs)