Spaces:
Runtime error
Runtime error
import gradio as gr | |
import numpy as np | |
import random | |
import re | |
import torch | |
import transformers | |
from keybert import KeyBERT | |
from transformers import (T5ForConditionalGeneration, T5Tokenizer) | |
# All inference runs on the CPU.
DEVICE = torch.device('cpu')
# Maximum encoder input length in tokens (matches the 512 used when tokenizing).
MAX_LEN = 512

# Question generator: base T5 tokenizer paired with weights fine-tuned for
# question generation.
tokenizer = T5Tokenizer.from_pretrained('t5-base')
model = T5ForConditionalGeneration.from_pretrained(
    'ZhangCheng/T5-Base-Fine-Tuned-for-Question-Generation'
)
model.to(DEVICE)

# KeyBERT keyword extractor built on the 'distilbert-base-nli-mean-tokens'
# embedding model; presumably backs the keyword helper `filter_keyword`
# (TODO confirm).
mod = KeyBERT('distilbert-base-nli-mean-tokens')

# Sample passage kept at module level.
context = (
    "The Transgender Persons Bill, 2016 was hurriedly passed in the Lok Sabha, "
    "amid much outcry from the very community it claims to protect."
)
def filter_keyword(context, ran):
    """Return up to ``ran`` ``(keyword, score)`` pairs extracted from ``context``.

    NOTE(review): ``func`` calls this helper but the file never defined it, so
    every request failed with NameError. This reconstruction delegates to the
    module-level KeyBERT model ``mod`` (otherwise unused in this file); confirm
    it matches the originally intended keyword filtering.
    """
    if not context or not context.strip():
        # Nothing to extract from; skip the embedding model entirely.
        return []
    return list(mod.extract_keywords(context, top_n=ran))


def _generate_question(context, keyword):
    """Run the T5 question generator on ``context`` focused on ``keyword``."""
    prompt = "context: " + context + " keyword: " + keyword
    # A single prompt needs no padding (padding to 512 only wasted compute).
    # truncation=True makes the length cap explicit; the original passed
    # max_length with the deprecated pad_to_max_length=True instead.
    # NOTE(review): very long contexts are cut from the right, which can drop
    # the trailing "keyword: ..." part of the prompt.
    encoded = tokenizer(
        prompt, max_length=MAX_LEN, truncation=True, return_tensors="pt"
    )
    generated = model.generate(
        input_ids=encoded["input_ids"].to(DEVICE),
        attention_mask=encoded["attention_mask"].to(DEVICE),
        max_length=50,
    )
    # skip_special_tokens removes <pad>/</s> rather than string-replacing them.
    return tokenizer.decode(generated[0], skip_special_tokens=True).strip()


def _quality_label(score):
    """Map a keyword score to the "Good"/"Bad" label shown in the UI."""
    return "Good" if score > 0.0 else "Bad"


def func(context, slide, top_fraction=0.4):
    """Generate up to ``slide`` questions about ``context``, each labelled.

    ``filter_keyword`` is asked for ``2 * slide`` candidate keywords.
    The first ``ceil(top_fraction * slide)`` questions use the leading
    candidates in the order they are returned. The remaining questions use
    keywords drawn at random, without repetition, from the other candidates.

    Args:
        context: Passage to ask questions about.
        slide: Requested number of questions (the Gradio slider value;
            coerced with ``int()``).
        top_fraction: Share of questions taken from the leading keywords.
            The default of 0.4 reproduces the previously hard-coded split.

    Returns:
        A list of ``(question, label)`` tuples, where ``label`` is ``"Good"``
        when the keyword's score is positive and ``"Bad"`` otherwise. The list
        may hold fewer than ``slide`` entries when the context yields too few
        keywords, instead of raising IndexError as before.
    """
    n_questions = int(slide)
    if n_questions <= 0:
        return []
    # Clamp so an out-of-range fraction can never over- or under-shoot.
    n_top = min(n_questions, max(0, int(np.ceil(top_fraction * n_questions))))
    n_random = n_questions - n_top

    keywords = filter_keyword(context, ran=n_questions * 2)
    # Slicing (instead of indexing) tolerates short keyword lists.
    chosen = list(keywords[:n_top])
    remaining = keywords[n_top:]
    if n_random > 0 and remaining:
        # Sample without replacement so the same keyword (and hence the same
        # question) is not produced twice.
        chosen += random.sample(remaining, min(n_random, len(remaining)))

    return [
        (_generate_question(context, keyword), _quality_label(score))
        for keyword, score in chosen
    ]
# Gradio UI: a context textbox and a 1-5 question-count slider feed `func`,
# whose (question, "Good"/"Bad") pairs are shown as a two-column table.
# gr.inputs.* / gr.outputs.* (including the KeyValues output) were deprecated
# and later removed from Gradio; the top-level components below replace them.
demo = gr.Interface(
    func,
    [
        gr.Textbox(lines=10, label="context"),
        gr.Slider(minimum=1, maximum=5, value=1, step=1, label="No of Question"),
    ],
    gr.Dataframe(headers=["Question", "Quality"], label="Questions"),
    # Offer the module-level sample passage as a clickable example, without
    # running the models at startup just to cache its output.
    examples=[[context, 3]],
    cache_examples=False,
)
demo.launch()