|
import functools

import gradio as gr
import torch
from transformers import T5ForConditionalGeneration, T5Model, T5Tokenizer
|
|
|
# Level names used as task prefixes; they match the revised Bloom's taxonomy.
BT_LEVELS = ('Remember', 'Understand', 'Apply', 'Analyse', 'Evaluate', 'Create')


def get_questions(paragraph, tokenizer, model, device, bt_levels=BT_LEVELS):
    """Generate one question per taxonomy level for *paragraph*.

    Args:
        paragraph: Source text the questions should be about.
        tokenizer: Tokenizer matching *model*. It must expose ``eos_token``,
            be callable (returning ``input_ids`` and ``attention_mask``)
            and provide ``decode``.
        model: Seq2seq model providing ``eval()`` and ``generate()``.
        device: torch device the encoded inputs are moved to; presumably
            the device *model* lives on (TODO confirm at call sites).
        bt_levels: Level names used as the prompt prefix. One question is
            generated for each, in this order.

    Returns:
        dict mapping each level name to its generated question text.
    """
    # Inference only: switch off training-mode layers once, not per level.
    model.eval()

    questions = {}
    for level in bt_levels:
        # Prompt format "<Level>: <paragraph> <eos>" is kept byte-for-byte,
        # presumably because it matches the fine-tuning data (TODO confirm).
        # The tokenizer may also append its own EOS token.
        prompt = f'{level}: {paragraph} {tokenizer.eos_token}'
        encoding = tokenizer(prompt, max_length=512, padding='max_length',
                             truncation=True, return_tensors='pt')
        input_ids = encoding['input_ids'].to(device)
        # Pass the mask explicitly so the max_length padding is ignored by
        # the encoder instead of relying on generate() to infer it.
        attention_mask = encoding['attention_mask'].to(device)

        generated_ids = model.generate(input_ids=input_ids,
                                       attention_mask=attention_mask,
                                       max_length=128, num_beams=4,
                                       early_stopping=True)

        # Batch of one: decode the single returned sequence.
        text = tokenizer.decode(generated_ids[0],
                                skip_special_tokens=True).lstrip('\n')
        # Drop the leading word of the output (presumably a label prefix
        # emitted by the model). If there is no space at all, keep the whole
        # text rather than raising IndexError as split(' ', 1)[1] did.
        _, sep, rest = text.partition(' ')
        questions[level] = rest if sep else text

    return questions
|
|
|
|
|
@functools.lru_cache(maxsize=1)
def _load_model(model_dir='./save_model'):
    """Load tokenizer and model from *model_dir* once and reuse them.

    *model_dir* is presumably a directory written by ``save_pretrained``
    (TODO confirm). The model is moved to CUDA when available, otherwise
    CPU, and put in eval mode.

    Returns:
        (tokenizer, model, device) tuple.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = T5ForConditionalGeneration.from_pretrained(model_dir)
    tokenizer = T5Tokenizer.from_pretrained(model_dir)
    model.to(device)
    model.eval()
    return tokenizer, model, device


def main(paragraph):
    """Gradio callback: generate one question per level for *paragraph*.

    The model is loaded on the first non-blank request and cached for
    all later calls, instead of being re-read from disk on every call.

    Args:
        paragraph: Text typed by the user.

    Returns:
        dict mapping level name to question (see ``get_questions``), or an
        empty dict when *paragraph* is None or only whitespace.
    """
    # Nothing to ask about: skip model loading and generation entirely.
    if paragraph is None or not paragraph.strip():
        return {}

    tokenizer, model, device = _load_model()
    return get_questions(paragraph, tokenizer, model, device)
|
|
|
# Gradio UI: a paragraph goes in, and the {level: question} dict returned by
# main is shown with the JSON component instead of as a raw dict repr.
# live=False: each request runs several beam-search generations, so run them
# on submit rather than on every keystroke.
gr.Interface(
    fn=main,
    inputs="textbox",
    outputs="json",
    live=False,
).launch()