from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
import gradio as gr
# Load the fine-tuned seq2seq checkpoints for each task (BanglaT5 for translation and paraphrasing, mT5 for summarization)
translation_model_en_bn = AutoModelForSeq2SeqLM.from_pretrained("csebuetnlp/banglat5_nmt_en_bn")
translation_tokenizer_en_bn = AutoTokenizer.from_pretrained("csebuetnlp/banglat5_nmt_en_bn")
translation_model_bn_en = AutoModelForSeq2SeqLM.from_pretrained("csebuetnlp/banglat5_nmt_bn_en")
translation_tokenizer_bn_en = AutoTokenizer.from_pretrained("csebuetnlp/banglat5_nmt_bn_en")
summarization_model = AutoModelForSeq2SeqLM.from_pretrained("csebuetnlp/mT5_multilingual_XLSum")
summarization_tokenizer = AutoTokenizer.from_pretrained("csebuetnlp/mT5_multilingual_XLSum")
paraphrase_model = AutoModelForSeq2SeqLM.from_pretrained("csebuetnlp/banglat5_banglaparaphrase")
paraphrase_tokenizer = AutoTokenizer.from_pretrained("csebuetnlp/banglat5_banglaparaphrase")
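
# Note (assumption, not part of the original code): the csebuetnlp model cards recommend
# normalizing input text with the `normalizer` package (pip install normalizer) before
# tokenization; the functions below pass raw text straight to the tokenizers. A minimal
# sketch of that extra step would be:
#
#     from normalizer import normalize
#     input_text = normalize(input_text)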
# Function to perform machine translation
def translate_text_en_bn(input_text):
    inputs = translation_tokenizer_en_bn(input_text, return_tensors="pt")
    outputs = translation_model_en_bn.generate(**inputs)
    translated_text = translation_tokenizer_en_bn.decode(outputs[0], skip_special_tokens=True)
    return translated_text

def translate_text_bn_en(input_text):
    inputs = translation_tokenizer_bn_en(input_text, return_tensors="pt")
    outputs = translation_model_bn_en.generate(**inputs)
    translated_text = translation_tokenizer_bn_en.decode(outputs[0], skip_special_tokens=True)
    return translated_text
# Function to perform summarization
def summarize_text(input_text):
    inputs = summarization_tokenizer(input_text, return_tensors="pt")
    outputs = summarization_model.generate(**inputs)
    summarized_text = summarization_tokenizer.decode(outputs[0], skip_special_tokens=True)
    return summarized_text
# Function to perform paraphrasing
def paraphrase_text(input_text):
    inputs = paraphrase_tokenizer(input_text, return_tensors="pt")
    outputs = paraphrase_model.generate(**inputs)
    paraphrased_text = paraphrase_tokenizer.decode(outputs[0], skip_special_tokens=True)
    return paraphrased_text
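
# Note (assumption): generate() is called with its default settings above, which can
# truncate longer outputs. Standard transformers generation arguments can be passed to
# tune output length and quality per task, e.g.:
#
#     outputs = summarization_model.generate(**inputs, max_new_tokens=128, num_beams=4)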
# Dispatch the input text to the handler for the selected task
def process_text(text, task):
    if task == "Translate_English_to_Bengali":
        return translate_text_en_bn(text)
    elif task == "Translate_Bengali_to_English":
        return translate_text_bn_en(text)
    elif task == "Summarize":
        return summarize_text(text)
    elif task == "Paraphrase":
        return paraphrase_text(text)
    else:
        return "Please select a task."
# Define the Gradio interface
iface = gr.Interface(
    fn=process_text,
    inputs=[
        "text",
        gr.Dropdown(["Translate_English_to_Bengali", "Translate_Bengali_to_English", "Summarize", "Paraphrase"]),
    ],
    outputs="text",
    live=False,
    title="Usage of BanglaT5 Model",
)
# Launch the Gradio app
iface.launch(inline=False)
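
# Note (assumption): launch() accepts further options such as share=True (temporary public
# link) or server_name="0.0.0.0" (binding inside a container); the defaults used above are
# typically sufficient on Hugging Face Spaces.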