# IndicTrans translation API — Flask service wrapping indic-en / en-indic / m2m models.
import time | |
import re | |
from math import floor, ceil | |
from fairseq import checkpoint_utils, distributed_utils, options, tasks, utils | |
# from nltk.tokenize import sent_tokenize | |
from flask import Flask, request, jsonify | |
from flask_cors import CORS, cross_origin | |
import webvtt | |
from io import StringIO | |
from mosestokenizer import MosesSentenceSplitter | |
from indicTrans.inference.engine import Model | |
from punctuate import RestorePuncts | |
from indicnlp.tokenize.sentence_tokenize import sentence_split | |
# Flask application with CORS enabled so browser clients on other
# origins can call the API.
app = Flask(__name__)
cors = CORS(app)
app.config['CORS_HEADERS'] = 'Content-Type'

# Load all three translation models once at startup (expensive disk I/O);
# requests then pick one of these via get_inference_params().
indic2en_model = Model(expdir='models/v3/indic-en')
en2indic_model = Model(expdir='models/v3/en-indic')
m2m_model = Model(expdir='models/m2m')

# Punctuation-restoration model for raw (unpunctuated) transcript text.
rpunct = RestorePuncts()
# Human-readable language names (as sent by clients) mapped to the
# ISO 639-1 codes the translation models expect.  Insertion order is
# kept unchanged in case callers rely on it when serializing.
indic_language_dict = {
    'Assamese': 'as',
    'Hindi': 'hi',
    'Marathi': 'mr',
    'Tamil': 'ta',
    'Bengali': 'bn',
    'Kannada': 'kn',
    'Oriya': 'or',
    'Telugu': 'te',
    'Gujarati': 'gu',
    'Malayalam': 'ml',
    'Punjabi': 'pa',
}
splitter = MosesSentenceSplitter('en') | |
def get_inference_params():
    """Resolve the model and language codes for the requested direction.

    Reads ``source_language`` and ``target_language`` (human-readable
    names, e.g. 'Hindi' / 'English') from the Flask request form and
    selects the matching model: indic->en, en->indic, or indic->indic.

    Returns:
        Tuple ``(model, source_lang, target_lang)`` where the language
        values are ISO 639-1 codes.

    Raises:
        ValueError: if the language pair is not supported.  (The original
        code had no fallback branch and crashed with UnboundLocalError
        at the return statement for e.g. English -> English.)
    """
    source_language = request.form['source_language']
    target_language = request.form['target_language']

    if source_language in indic_language_dict and target_language == 'English':
        return indic2en_model, indic_language_dict[source_language], 'en'

    if source_language == 'English' and target_language in indic_language_dict:
        return en2indic_model, 'en', indic_language_dict[target_language]

    if source_language in indic_language_dict and target_language in indic_language_dict:
        return (
            m2m_model,
            indic_language_dict[source_language],
            indic_language_dict[target_language],
        )

    raise ValueError(
        'Unsupported language pair: {} -> {}'.format(source_language, target_language)
    )
def main():
    """Landing / health-check endpoint: identifies the service by name."""
    return "IndicTrans API"
def supported_languages():
    """Return the supported Indic languages as JSON (name -> ISO code)."""
    return jsonify(indic_language_dict)
def infer_indic_en():
    """Translate a plain-text paragraph from the request form.

    Returns a dict with the translated ``text`` and the wall-clock
    ``duration`` of the translation in seconds (rounded to 2 places).
    """
    model, source_lang, target_lang = get_inference_params()
    input_text = request.form['text']

    t0 = time.time()
    translated = model.translate_paragraph(input_text, source_lang, target_lang)
    elapsed = round(time.time() - t0, 2)

    return {'text': translated, 'duration': elapsed}
def infer_vtt_indic_en():
    """Translate a WebVTT subtitle payload while preserving cue timestamps.

    Pipeline:
      1. Parse the VTT document from the form's ``text`` field.
      2. Join all cue texts, lowercase, strip punctuation, then restore
         punctuation with the rpunct model (ASR output arrives unpunctuated).
      3. Split the punctuated paragraph into sentences and batch-translate.
      4. Redistribute the translated words back across the original cues in
         proportion to each cue's share of the source word count.

    Returns:
        dict with key 'text' holding the rewritten VTT content.
    """
    start_time = time.time()
    model, source_lang, target_lang = get_inference_params()
    source_text = request.form['text']

    vad = webvtt.read_buffer(StringIO(source_text))
    source_sentences = [v.text.replace('\r', '').replace('\n', ' ') for v in vad]

    # Normalise to one big lowercase, punctuation-free string so the
    # punctuation model sees uniform input regardless of ASR formatting.
    large_sentence = ' '.join(source_sentences).lower()
    large_sentence = re.sub(r'[^\w\s]', '', large_sentence)
    punctuated = rpunct.punctuate(large_sentence, batch_size=32)

    end_time = time.time()
    print("Time Taken for punctuation: {} s".format(end_time - start_time))

    start_time = time.time()
    split_sents = splitter([punctuated])
    output_sents = model.batch_translate(split_sents, source_lang, target_lang)

    # BUG FIX: the original built a {source_sentence: translation} dict and
    # joined its keys/values.  Duplicate source sentences were silently
    # collapsed by the dict, shrinking the word counts and misaligning the
    # proportional redistribution below.  Join the parallel lists directly.
    punct_para = ' '.join(split_sents)
    nmt_para = ' '.join(output_sents)
    nmt_words = nmt_para.split(' ')
    len_punct = len(punct_para.split(' '))
    len_nmt = len(nmt_para.split(' '))

    # Hand each cue a slice of the translated words sized by the cue's
    # share of the source word count; timestamps are left untouched.
    # NOTE(review): floor() means a few trailing words may go unassigned —
    # preserved from the original behavior.
    start = 0
    for caption in vad:
        if caption.text == '':
            continue
        caption_words = len(caption.text.split(' '))
        frac = caption_words / len_punct
        req_nmt_size = floor(frac * len_nmt)
        caption.text = ' '.join(nmt_words[start:start + req_nmt_size])
        start += req_nmt_size

    end_time = time.time()
    print("Time Taken for translation: {} s".format(end_time - start_time))

    return {
        'text': vad.content,
    }