Wootang01's picture
Update app.py
61b4399
import gradio as gr
import numpy as np
import random
import re
import torch
import transformers
from keybert import KeyBERT
from transformers import (T5ForConditionalGeneration, T5Tokenizer)
# Inference device: func() moves the encoded inputs here before generation.
DEVICE = torch.device('cpu')
# Token budget. NOTE(review): not referenced by the code shown here; func()
# hard-codes max_length=512 instead.
MAX_LEN = 512
# The tokenizer comes from the stock 't5-base' checkpoint, while the weights
# come from a question-generation fine-tune.
tokenizer = T5Tokenizer.from_pretrained('t5-base')
model = T5ForConditionalGeneration.from_pretrained('ZhangCheng/T5-Base-Fine-Tuned-for-Question-Generation')
# KeyBERT instance. NOTE(review): not referenced by the code shown here;
# presumably used by filter_keyword(), which func() calls — confirm.
mod = KeyBERT('distilbert-base-nli-mean-tokens')
model.to(DEVICE)
# Sample passage. NOTE(review): func() takes its own `context` parameter, so
# this module-level value is not read by the code shown here.
context = "The Transgender Persons Bill, 2016 was hurriedly passed in the Lok Sabha, amid much outcry from the very community it claims to protect."
def _generate_question(context, keyword):
    """Return the question the T5 model generates for *keyword* in *context*.

    The model is prompted with ``"context: <context> keyword: <keyword>"``.
    """
    prompt = "context: " + context + " keyword: " + keyword
    # `padding`/`truncation` replace the deprecated `pad_to_max_length` flag
    # and make the truncation of over-long prompts explicit.
    encoded = tokenizer.encode_plus(
        prompt,
        max_length=MAX_LEN,
        padding="max_length",
        truncation=True,
        return_tensors="pt",
    )
    generated = model.generate(
        input_ids=encoded["input_ids"].to(DEVICE),
        attention_mask=encoded["attention_mask"].to(DEVICE),
        max_length=50,
    )
    # Only one prompt is submitted, so only the first sequence is decoded.
    decoded = tokenizer.decode(generated[0])
    # Strip the special-token markers left in the decoded text.
    return decoded.replace("<pad> ", "").replace("</s>", "")


def _labelled_question(context, keyword):
    """Return ``(question, label)`` for one keyword entry.

    *keyword* is indexed as ``keyword[0]`` (the keyword text) and
    ``keyword[1]`` (a number compared against zero to choose the label).
    """
    question = _generate_question(context, keyword[0])
    return (question, "Good" if keyword[1] > 0.0 else "Bad")


def func(context, slide):
    """Generate labelled questions about *context* for the Gradio interface.

    ``ceil(0.4 * slide)`` questions are built from the leading entries
    returned by ``filter_keyword``; the remaining ones are built from
    keywords sampled at random, without repetition, from the other entries.

    Args:
        context: Passage of text to generate questions from.
        slide: Requested number of questions; converted with ``int()``.

    Returns:
        A list of ``(question, label)`` tuples where ``label`` is ``"Good"``
        when the keyword's second element is positive and ``"Bad"``
        otherwise. The list is shorter than ``slide`` when ``filter_keyword``
        yields too few keywords, instead of raising ``IndexError``.
    """
    slide = int(slide)
    randomness = 0.4
    orig = int(np.ceil(randomness * slide))
    temp = slide - orig
    ap = filter_keyword(context, ran=slide * 2)
    print(slide)
    print(orig)
    print(ap)

    # Slicing (rather than indexing and `del`) tolerates a short keyword list
    # and leaves the list returned by filter_keyword untouched.
    leading = ap[:orig]
    leftovers = list(ap[orig:])

    outputs = [_labelled_question(context, keyword) for keyword in leading]
    print("first", outputs)
    print(temp)

    if temp > 0:
        # Sample without replacement so the same keyword cannot yield the
        # same question twice; clamp to the number of keywords left.
        picks = random.sample(leftovers, min(temp, len(leftovers)))
        outputs.extend(_labelled_question(context, keyword) for keyword in picks)
        print("second", outputs)

    return outputs
# Build the web UI and start serving it: a multi-line textbox for the passage
# and a slider for the number of questions, with results shown as key/value
# pairs.
gr.Interface(
    func,
    [
        gr.inputs.Textbox(lines=10, label="context"),
        # step=1 keeps the question count on whole numbers; func() would
        # otherwise silently truncate a fractional slider value with int().
        gr.inputs.Slider(minimum=1, maximum=5, step=1, default=1, label="No of Question"),
    ],
    gr.outputs.KeyValues(),
).launch()