Wootang01 committed on
Commit
4662318
1 Parent(s): 6fe4b43

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +64 -0
app.py ADDED
@@ -0,0 +1,64 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import numpy as np
3
+ import random
4
+ import re
5
+ import torch
6
+ import transformers
7
+
8
+ from keybert import KeyBERT
9
+ from transformers import (T5ForConditionalGeneration, T5Tokenizer)
10
+
11
# Run inference on the GPU when torch can see one, otherwise fall back to CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Token length budget; matches the max_length used when encoding prompts in func.
MAX_LEN = 512

# The tokenizer is loaded from the stock 't5-base' checkpoint while the weights
# come from 'Vaibhavbrkn/question-gen'.
# NOTE(review): presumably the fine-tuned model shares t5-base's vocabulary - confirm.
tokenizer = T5Tokenizer.from_pretrained("t5-base")
model = T5ForConditionalGeneration.from_pretrained("Vaibhavbrkn/question-gen")

# Keyword extractor. It is not referenced by the visible code; presumably it is
# consumed by filter_keyword, which is not shown here - confirm.
mod = KeyBERT("distilbert-base-nli-mean-tokens")

# Move the model weights onto the selected device once, at start-up.
model.to(DEVICE)

# Sample passage. func() takes its own `context` parameter, which shadows this
# module-level name inside the function; nothing else visible reads it.
context = (
    "The Transgender Persons Bill, 2016 was hurriedly passed in the Lok Sabha, "
    "amid much outcry from the very community it claims to protect."
)
20
+
21
def _generate_question(context, keyword):
    """Return the question the T5 model generates for *keyword* within *context*."""
    prompt = "context: " + context + " keyword: " + keyword
    # `padding="max_length"` is the supported spelling of the deprecated
    # `pad_to_max_length=True`. Truncation is deliberately left off, as in the
    # original call: the keyword sits at the END of the prompt and would be the
    # first thing cut.
    encoded = tokenizer.encode_plus(
        prompt,
        max_length=MAX_LEN,
        padding="max_length",
        return_tensors="pt",
    )
    generated = model.generate(
        input_ids=encoded["input_ids"].to(DEVICE),
        attention_mask=encoded["attention_mask"].to(DEVICE),
        max_length=50,
    )
    # Let the tokenizer drop "<pad>" / "</s>" instead of string-replacing them.
    return tokenizer.decode(generated[0], skip_special_tokens=True)


def _label(score):
    """Map a keyword relevance score to the "Good"/"Bad" tag shown in the UI."""
    return "Good" if score > 0.0 else "Bad"


def func(context, slide):
    """Generate up to ``slide`` questions about ``context``.

    Roughly 40% of the questions (rounded up) are built from the first keywords
    returned by ``filter_keyword``; the rest use keywords drawn at random,
    without repetition, from the remaining candidates.

    Args:
        context: Passage of text to ask questions about.
        slide: Requested number of questions; coerced with ``int()`` because
            the slider widget may deliver a non-integer value.

    Returns:
        A list of ``(question, tag)`` tuples where ``tag`` is ``"Good"`` when
        the keyword's score is above 0.0 and ``"Bad"`` otherwise. The list is
        shorter than ``slide`` (possibly empty) when too few keywords are
        available, and empty for blank input or a non-positive ``slide``.
    """
    requested = int(slide)
    if requested <= 0 or not context or not context.strip():
        return []

    # NOTE(review): filter_keyword is not defined in the visible source. From
    # its use here it must return an indexable sequence of (keyword, score)
    # entries - confirm it is provided before deploying.
    keywords = list(filter_keyword(context, ran=requested * 2))
    if not keywords:
        return []

    randomness = 0.4
    # Clamp to what is actually available: the extractor may return fewer
    # candidates than requested, which previously raised IndexError.
    top_count = min(int(np.ceil(randomness * requested)), len(keywords))
    random_count = requested - top_count

    chosen = keywords[:top_count]
    remaining = keywords[top_count:]
    # Sampling without replacement avoids asking about the same keyword twice
    # and is a no-op (instead of an IndexError) when nothing is left to pick.
    chosen += random.sample(remaining, min(random_count, len(remaining)))

    return [
        (_generate_question(context, entry[0]), _label(entry[1]))
        for entry in chosen
    ]
63
+
64
# Build the web UI - a multi-line text box for the passage and a 1-5 slider for
# the number of questions, rendered as key/value pairs - and start serving it
# as soon as this module is executed.
gr.Interface(
    fn=func,
    inputs=[
        gr.inputs.Textbox(lines=10, label="context"),
        gr.inputs.Slider(minimum=1, maximum=5, default=1, label="No of Question"),
    ],
    outputs=gr.outputs.KeyValues(),
).launch()