Spaces:
Runtime error
Runtime error
Vaibhavbrkn
committed on
Commit
·
dfabef6
1
Parent(s):
bb64a16
Initial Commit
Browse files- app.py +106 -0
- requirements.txt +5 -0
app.py
ADDED
@@ -0,0 +1,106 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import gradio as gr
|
2 |
+
import numpy as np
|
3 |
+
from keybert import KeyBERT
|
4 |
+
import random
|
5 |
+
from transformers import (
|
6 |
+
T5ForConditionalGeneration,
|
7 |
+
T5Tokenizer,
|
8 |
+
)
|
9 |
+
import re
|
10 |
+
import transformers
|
11 |
+
import torch
|
12 |
+
|
13 |
+
|
14 |
+
# Run inference on the GPU when one is available, otherwise on the CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Presumably the token-length cap for model inputs (prompts are encoded to
# 512 tokens in func) -- confirm before changing.
MAX_LEN = 512

# The tokenizer comes from the stock t5-base checkpoint, while the generation
# weights are loaded from the 'Vaibhavbrkn/question-gen' checkpoint.
tokenizer = T5Tokenizer.from_pretrained("t5-base")
model = T5ForConditionalGeneration.from_pretrained("Vaibhavbrkn/question-gen")
model.to(DEVICE)

# KeyBERT instance used by filter_keyword to extract keyword candidates.
mod = KeyBERT("distilbert-base-nli-mean-tokens")

# Sample passage kept at module level; func receives its own `context`
# argument, which shadows this name inside the function.
context = "The Transgender Persons Bill, 2016 was hurriedly passed in the Lok Sabha, amid much outcry from the very community it claims to protect."
|
24 |
+
|
25 |
+
|
26 |
+
def filter_keyword(data, ran=5):
    """Pick up to `ran` distinct, highest-scoring keyword phrases from `data`.

    The text is cleaned (hyphens become spaces; every character other than
    word characters, whitespace, '.' and ',' is dropped) and handed to the
    module-level KeyBERT instance three times, with n-gram ranges (1, 1),
    (1, 2) and (1, 3). Candidates that literally occur in the cleaned text
    (case-insensitively) are sorted by score, highest first, and collected
    while skipping any phrase that is already a substring of the phrases
    chosen so far.

    Args:
        data: Source text to extract keywords from.
        ran: Number of keywords to collect before stopping.

    Returns:
        A list of the entries produced by ``mod.extract_keywords`` (indexed
        as ``[0]`` = phrase, ``[1]`` = score), ordered by descending score.
    """
    cleaned = re.sub(r'[^\w\s\.\,]', '', re.sub(r'-', ' ', data))
    cleaned_lower = cleaned.lower()

    candidates = []
    for max_ngram in (1, 2, 3):
        # NOTE(review): KeyBERT documents `diversity` as taking effect only
        # together with `use_mmr=True` -- confirm whether MMR was intended.
        extracted = mod.extract_keywords(
            cleaned, keyphrase_ngram_range=(1, max_ngram), diversity=0.7,
            top_n=ran * 2)
        # Keep only phrases that appear verbatim in the cleaned text.
        candidates.extend(
            entry for entry in extracted
            if entry[0].lower() in cleaned_lower)

    candidates.sort(key=lambda entry: entry[1], reverse=True)

    selected = []
    # Space-joined phrases picked so far; used for the substring
    # de-duplication check below.
    picked_text = ""
    for entry in candidates:
        phrase = entry[0]
        if phrase in picked_text:
            continue
        selected.append(entry)
        picked_text += phrase + " "
        if len(selected) == ran:
            break

    return selected
|
52 |
+
|
53 |
+
|
54 |
+
# "Bad" label: intended for keywords with a negative score or in the bottom 3.
|
55 |
+
|
56 |
+
def _label_for(score):
    """Return "Good" for a strictly positive keyword score, else "Bad"."""
    return "Good" if score > 0.0 else "Bad"


def _generate_question(context, keyword):
    """Generate one question about `keyword` from `context` with the T5 model."""
    prompt = "context: " + context + " keyword: " + keyword
    # `padding`/`truncation` replace the deprecated `pad_to_max_length=True`;
    # the prompt is padded to MAX_LEN tokens and cut off beyond it.
    encoded = tokenizer.encode_plus(
        prompt, max_length=MAX_LEN, padding="max_length", truncation=True,
        return_tensors="pt")
    outs = model.generate(
        input_ids=encoded['input_ids'].to(DEVICE),
        attention_mask=encoded['attention_mask'].to(DEVICE),
        max_length=50)
    # A single prompt is encoded, so only the first sequence is decoded;
    # the special tokens are then stripped from the decoded text.
    decoded = tokenizer.decode(outs[0])
    return decoded.replace("<pad> ", "").replace("</s>", "")


def func(context, slide):
    """Generate up to `slide` labelled questions from `context`.

    ``ceil(0.4 * slide)`` questions are built from the top-scoring keywords
    returned by ``filter_keyword``; the rest are built from keywords drawn
    at random (with replacement) from the remaining candidates.

    When the text yields fewer keywords than needed (for example a very
    short or empty context), fewer questions are returned instead of
    raising ``IndexError``.

    Args:
        context: Passage to generate questions from.
        slide: Requested number of questions; converted with ``int()``.

    Returns:
        A list of ``(question, label)`` tuples where ``label`` is "Good"
        when the keyword score is greater than 0.0 and "Bad" otherwise.
    """
    slide = int(slide)
    randomness = 0.4
    top_count = int(np.ceil(randomness * slide))
    random_count = slide - top_count

    keywords = filter_keyword(context, ran=slide * 2)
    # Slicing tolerates a keyword list shorter than `top_count`.
    top_keywords = keywords[:top_count]
    remaining = keywords[top_count:]

    outputs = [
        (_generate_question(context, entry[0]), _label_for(entry[1]))
        for entry in top_keywords
    ]

    # random.choice raises on an empty sequence, so skip the random part
    # when no candidates are left.
    if remaining:
        for _ in range(random_count):
            entry = random.choice(remaining)
            outputs.append(
                (_generate_question(context, entry[0]), _label_for(entry[1])))

    return outputs
|
98 |
+
|
99 |
+
|
100 |
+
# Web UI: a multi-line text box for the passage and a slider (1-5, default 3)
# for the number of questions; results are rendered as key/value pairs.
interface_inputs = [
    gr.inputs.Textbox(lines=10, label="context"),
    gr.inputs.Slider(minimum=1, maximum=5, default=3, label="No of Question"),
]
interface = gr.Interface(
    func,
    interface_inputs,
    gr.outputs.KeyValues(),
    capture_session=True,
    server_name="0.0.0.0",
)
interface.launch()
|
requirements.txt
ADDED
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
numpy
|
2 |
+
torch
|
3 |
+
transformers
|
4 |
+
keybert
|
5 |
+
gradio
|