KKMS-KSSW-HF / src /translator.py
Chintan Donda
Moving kkms_kssw.py to src/
04e306a
raw
history blame
No virus
2.4 kB
import src.constants as constants_utils
import requests
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
from mosestokenizer import *
from indicnlp.tokenize import sentence_tokenize
from googletrans import Translator, constants
class TRANSLATOR:
    """Sentence splitting and English -> Indic translation helpers.

    Translation is done either with the NLLB-200 distilled model
    ("facebook/nllb-200-distilled-600M") via Hugging Face transformers, or
    with the Google Translate client from ``googletrans``.
    """

    # Hugging Face model used by the NLLB-based translation methods.
    _NLLB_MODEL_NAME = "facebook/nllb-200-distilled-600M"
    # Display language name -> NLLB target language code.
    _NLLB_LANG_CODES = {
        'Hindi': "hin_Deva",
        'Marathi': "mar_Deva",
    }
    # Target code used when the requested language is not in the table above.
    _NLLB_DEFAULT_CODE = "eng_Latn"
    # Lazily-populated (tokenizer, model) pair, shared by all instances so the
    # model is read from disk once instead of on every translation call.
    _nllb_cache = None

    def __init__(self):
        """Create a translator; models are loaded lazily on first use."""

    def split_sentences(self, paragraph, language):
        """Split *paragraph* into a list of sentences.

        Args:
            paragraph: Text to split.
            language: ``"en"`` uses the Moses sentence splitter; any value
                contained in ``constants_utils.INDIC_LANGUAGE`` uses the
                indicnlp splitter with ``lang=language``.

        Returns:
            The list of sentences, or ``None`` when *language* is neither
            ``"en"`` nor in ``constants_utils.INDIC_LANGUAGE``.
        """
        if language == "en":
            with MosesSentenceSplitter(language) as splitter:
                return splitter([paragraph])
        # NOTE(review): get_indic_google_translate treats INDIC_LANGUAGE as a
        # name -> code mapping (.get(language, 'en')), so this membership test
        # checks its keys while indicnlp is passed the same value as `lang`.
        # Confirm the keys are the codes indicnlp expects.
        if language in constants_utils.INDIC_LANGUAGE:
            return sentence_tokenize.sentence_split(paragraph, lang=language)
        return None

    @classmethod
    def _load_nllb(cls):
        """Return the cached NLLB (tokenizer, model), loading them on first use."""
        if cls._nllb_cache is None:
            tokenizer = AutoTokenizer.from_pretrained(cls._NLLB_MODEL_NAME)
            model = AutoModelForSeq2SeqLM.from_pretrained(cls._NLLB_MODEL_NAME)
            cls._nllb_cache = (tokenizer, model)
        return cls._nllb_cache

    @classmethod
    def _nllb_lang_code(cls, language):
        """Map a display language name to its NLLB code ('eng_Latn' if unknown)."""
        return cls._NLLB_LANG_CODES.get(language, cls._NLLB_DEFAULT_CODE)

    @staticmethod
    def _join_sentences(sentences):
        """Join translated sentences with single spaces, skipping empty ones."""
        return " ".join(s.strip() for s in sentences if s and s.strip())

    def _nllb_generate(self, text, target_code, max_length):
        """Translate *text* with NLLB into the language given by *target_code*.

        Args:
            text: Source text (tokenized with the tokenizer's default source
                language).
            target_code: NLLB language code forced as the first generated token.
            max_length: Maximum length passed to ``model.generate``.

        Returns:
            The first decoded sequence with special tokens removed.
        """
        tokenizer, model = self._load_nllb()
        inputs = tokenizer(text, return_tensors="pt")
        translated_tokens = model.generate(
            **inputs,
            # convert_tokens_to_ids works on both old and new transformers;
            # the tokenizer's lang_code_to_id mapping is deprecated/removed.
            forced_bos_token_id=tokenizer.convert_tokens_to_ids(target_code),
            max_length=max_length,
        )
        return tokenizer.batch_decode(translated_tokens, skip_special_tokens=True)[0]

    def get_in_hindi(self, payload):
        """Translate ``payload['inputs']`` (English) to Hindi sentence by sentence.

        Args:
            payload: Mapping with the English text under the ``'inputs'`` key.

        Returns:
            The Hindi translations of each sentence, joined with single spaces.
        """
        sentences = self.split_sentences(payload['inputs'], 'en')
        translated = [
            self._nllb_generate(sentence, "hin_Deva", max_length=100)
            for sentence in sentences
        ]
        return self._join_sentences(translated)

    def get_in_indic(self, text, language='Hindi'):
        """Translate *text* with NLLB into *language* in a single pass.

        Args:
            text: Source text.
            language: ``'Hindi'`` or ``'Marathi'``; any other value falls back
                to the ``eng_Latn`` target code.

        Returns:
            The translated text.
        """
        code = self._nllb_lang_code(language)
        return self._nllb_generate(text, code, max_length=1000)

    def get_indic_google_translate(self, text, language='Hindi'):
        """Translate *text* with the googletrans client.

        The destination code is looked up as
        ``constants_utils.INDIC_LANGUAGE.get(language, 'en')``.

        Returns:
            The translated text as ``str``.
        """
        translator = Translator()
        translations = translator.translate(text, dest=constants_utils.INDIC_LANGUAGE.get(language, 'en'))
        return str(translations.text)