# Hugging Face Space application file by potsawee.
# Commit ada39d6: "Add application file" (3.2 kB).
import gradio as gr
import random
import torch
from transformers import MT5Tokenizer, MT5ForConditionalGeneration
# Prefer the GPU whenever PyTorch can see one; otherwise run on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# One tokenizer, taken from the translation checkpoint, is used for both models.
# NOTE(review): assumes the summarization checkpoint shares this vocabulary — confirm.
tokenizer = MT5Tokenizer.from_pretrained("potsawee/mt5-english-thai-large-translation")


def _load_seq2seq(repo_id):
    """Load an MT5 checkpoint, put it in evaluation mode and place it on `device`."""
    model = MT5ForConditionalGeneration.from_pretrained(repo_id)
    model.eval()
    model.to(device)
    return model


translator = _load_seq2seq("potsawee/mt5-english-thai-large-translation")
summarizer = _load_seq2seq("potsawee/mt5-english-thai-large-summarization")
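# NOTE: disabled leftover from the MQAG multiple-choice demo; the names it uses
# (g1_model, g1_tokenizer, g2_model, g2_tokenizer, question_generation_sampling)
# are not defined in this file.
# def generate_multiple_choice_question(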
# context
# ):
# num_questions = 1
# question_item = question_generation_sampling(
# g1_model, g1_tokenizer,
# g2_model, g2_tokenizer,
# context, num_questions, device
# )[0]
# question = question_item['question']
# options = question_item['options']
# options[0] = f"{options[0]} [ANSWER]"
# random.shuffle(options)
# output_string = f"Question: {question}\n[A] {options[0]}\n[B] {options[1]}\n[C] {options[2]}\n[D] {options[3]}"
# return output_string
#
# demo = gr.Interface(
# fn=generate_multiple_choice_question,
# inputs=gr.Textbox(lines=8, placeholder="Context Here..."),
# outputs=gr.Textbox(lines=5, placeholder="Question: \n[A] \n[B] \n[C] \n[D] "),
# title="Multiple-choice Question Generator",
# description="Provide some context (e.g. news article or any passage) in the context box and click **Submit**. The models currently support English only. This demo is a part of MQAG - https://github.com/potsawee/mqag0.",
# allow_flagging='never'
# )
def generate_output(
    task: str,
    text: str,
    *,
    max_new_tokens: int = 256,
    max_input_length: int = 1024,
) -> str:
    """Run the selected English-to-Thai task on ``text`` and return the generated text.

    Args:
        task: ``"Translation"`` selects the translation model, ``"Summarization"``
            selects the summarization model.
        text: English input; tokens beyond ``max_input_length`` are truncated.
        max_new_tokens: Maximum number of tokens to generate (default 256).
        max_input_length: Tokenizer truncation length in tokens (default 1024).

    Returns:
        The decoded first generated sequence with special tokens removed.

    Raises:
        ValueError: If ``task`` is not one of the supported task names.
    """
    # Resolve the model first so an unknown task is rejected before any
    # tokenization work or host-to-device copy happens.
    if task == "Translation":
        model = translator
    elif task == "Summarization":
        model = summarizer
    else:
        raise ValueError(
            "task undefined! Expected 'Translation' or 'Summarization', "
            f"got {task!r}"
        )

    inputs = tokenizer(
        [text],
        padding="longest",
        max_length=max_input_length,
        truncation=True,
        return_tensors="pt",
    ).to(device)
    outputs = model.generate(**inputs, max_new_tokens=max_new_tokens)
    # A batch of exactly one text was encoded, so only outputs[0] is relevant.
    return tokenizer.decode(outputs[0], skip_special_tokens=True)
# Task names offered in the UI; they must match the branches in generate_output.
TASKS = ["Translation", "Summarization"]

# One input widget per positional parameter of generate_output, i.e. (task, text).
_task_input = gr.components.Radio(label="Task", choices=TASKS, value="Translation")
_text_input = gr.components.Textbox(label="Text (in English)", lines=10)
_thai_output = gr.Textbox(label="Text (in Thai)", lines=4)

demo = gr.Interface(
    fn=generate_output,
    inputs=[_task_input, _text_input],
    outputs=_thai_output,
    cache_examples=False,  # no examples are configured
    title="English🇬🇧 to Thai🇹🇭 | Translation or Summarization",
    description="Provide some text (in English) & select one of the tasks (Translation or Summarization). Note that currently the model only supports text up to 1024 tokens. The base architecture is mt5-large with the embeddings filtered to only English and Thai tokens and fine-tuned to XSum (Eng2Thai) Dataset (https://huggingface.co/datasets/potsawee/xsum_eng2thai).",
    allow_flagging='never',
)

# Start the Gradio server.
demo.launch()