File size: 3,468 Bytes
dfabef6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
63f4262
 
 
dfabef6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
63f4262
 
dfabef6
 
 
 
42f0c65
dfabef6
 
 
 
 
 
 
 
 
 
 
63f4262
dfabef6
 
 
 
 
 
 
 
6f63374
dfabef6
95f5785
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
import gradio as gr
import numpy as np
from keybert import KeyBERT
import random
from transformers import (
    T5ForConditionalGeneration,
    T5Tokenizer,
)
import re
import transformers
import torch


# Run on the GPU when one is available, otherwise fall back to the CPU.
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Maximum number of tokens fed to the T5 encoder; the question-generation
# prompt is padded to this length.
MAX_LEN = 512

# NOTE: the loads below run at import time and may download weights from the
# Hugging Face hub the first time the app starts.
# Tokenizer from the stock 't5-base' checkpoint. NOTE(review): the fine-tuned
# model below is assumed to share t5-base's vocabulary — confirm.
tokenizer = T5Tokenizer.from_pretrained('t5-base')
# Fine-tuned T5 checkpoint that turns a "context: ... keyword: ..." prompt
# into a question (see func()).
model = T5ForConditionalGeneration.from_pretrained(
    'Vaibhavbrkn/question-gen')
# KeyBERT keyword extractor backed by the 'distilbert-base-nli-mean-tokens'
# embedding model; used by filter_keyword().
mod = KeyBERT('distilbert-base-nli-mean-tokens')
# Move the model weights to DEVICE; func() moves each encoded input to the
# same device before calling generate().
model.to(DEVICE)

# Sample passage. Not referenced by the visible code: func() receives its own
# `context` argument, which shadows this name.
context = "The Transgender Persons Bill, 2016 was hurriedly passed in the Lok Sabha, amid much outcry from the very community it claims to protect."


def filter_keyword(data, ran=5):
    """Extract up to ``ran`` distinct KeyBERT keyphrases from ``data``.

    The text is normalised first: hyphens become spaces, and every character
    that is not a word character, whitespace, ``.`` or ``,`` is dropped.
    KeyBERT is then queried three times, with maximum n-gram lengths of 1, 2
    and 3 words and ``ran * 2`` results per query. Only phrases that occur
    verbatim (case-insensitively) in the normalised text are kept.

    Candidates are ranked by score, highest first. A candidate is skipped
    when it is a substring of the accumulated string of already-chosen
    phrases (each followed by a space). Selection stops once ``ran`` phrases
    have been chosen.

    Args:
        data: Input passage.
        ran: Maximum number of keyphrases to return.

    Returns:
        List of KeyBERT ``(phrase, score)`` tuples, best first.
    """
    cleaned = re.sub(r'[^\w\s\.\,]', '', re.sub(r'-', ' ', data))
    haystack = cleaned.lower()

    # One KeyBERT pass per maximum n-gram length (1, 2 and 3 words), keeping
    # only phrases that literally appear in the cleaned text.
    # NOTE(review): KeyBERT only honours `diversity` when use_mmr=True, so the
    # 0.7 below has no effect as written — confirm whether MMR was intended.
    candidates = [
        kw
        for max_ngram in (1, 2, 3)
        for kw in mod.extract_keywords(
            cleaned, keyphrase_ngram_range=(1, max_ngram), diversity=0.7,
            top_n=ran*2)
        if kw[0].lower() in haystack
    ]
    # Stable sort: equal scores keep their extraction order.
    candidates.sort(key=lambda kw: kw[1], reverse=True)

    chosen = []
    covered = ""
    for kw in candidates:
        # Skip phrases already contained in what has been picked so far.
        if kw[0] not in covered:
            chosen.append(kw)
            covered += kw[0] + " "
            if len(chosen) == ran:
                break
    return chosen


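# Labels: a generated question is marked "Bad" when its keyword's relevance
# score is not positive, and "Good" otherwise. TODO: the original note also
# proposed marking the bottom 3 questions as "Bad"; that is not implemented.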

def _generate_question(context, keyword):
    """Generate one question about ``context`` focused on ``keyword``.

    Returns the decoded question text with special tokens removed.
    """
    prompt = "context: " + context + " keyword: " + keyword
    # padding="max_length" + truncation=True is the non-deprecated spelling of
    # the old pad_to_max_length=True call (which also truncated to max_length).
    encoded = tokenizer.encode_plus(
        prompt, max_length=MAX_LEN, padding="max_length", truncation=True,
        return_tensors="pt")
    outs = model.generate(
        input_ids=encoded['input_ids'].to(DEVICE),
        attention_mask=encoded['attention_mask'].to(DEVICE),
        max_length=50)
    # Only one sequence is generated; drop <pad>/</s> via the tokenizer rather
    # than string-replacing them by hand.
    return tokenizer.decode(outs[0], skip_special_tokens=True).strip()


def _quality_label(score):
    """Map a keyword relevance score to the "Good"/"Bad" label shown in the UI."""
    return "Good" if score > 0.0 else "Bad"


def func(context, slide):
    """Generate ``slide`` labelled questions from ``context`` for the Gradio UI.

    Keywords are extracted with ``filter_keyword`` (up to ``slide * 2`` of
    them, best first). Questions are built in two phases:

    * The top ``ceil(0.4 * slide)`` keywords are used in rank order.
    * The remaining questions use keywords drawn at random, with replacement,
      from the lower-ranked keywords.

    Each question is labelled "Good" when its keyword's score is positive,
    "Bad" otherwise. If the passage yields too few keywords, fewer than
    ``slide`` questions are returned instead of raising ``IndexError``.

    Args:
        context: Passage to generate questions about.
        slide: Requested number of questions; converted with ``int()``.

    Returns:
        List of ``(question, label)`` tuples.
    """
    slide = int(slide)
    randomness = 0.4
    # ceil() guarantees at least one top-ranked keyword for any slide >= 1.
    n_top = int(np.ceil(randomness * slide))
    n_random = slide - n_top
    keywords = filter_keyword(context, ran=slide*2)
    # Short passages can yield fewer keywords than requested; never index
    # past the end of the list.
    n_top = min(n_top, len(keywords))

    outputs = []
    for kw in keywords[:n_top]:
        outputs.append((_generate_question(context, kw[0]), _quality_label(kw[1])))

    remaining = keywords[n_top:]
    # random.choice() raises on an empty sequence, so only draw when there are
    # lower-ranked keywords left.
    if remaining:
        for _ in range(n_random):
            kw = random.choice(remaining)
            outputs.append((_generate_question(context, kw[0]), _quality_label(kw[1])))

    return outputs


# Gradio UI: a multi-line passage box plus a 1-5 slider for how many
# questions to generate; results are shown as (question, label) pairs.
# NOTE: launch() runs at import time, exactly as before.
question_inputs = [
    gr.inputs.Textbox(lines=10, label="context"),
    gr.inputs.Slider(minimum=1, maximum=5, default=1, label="No of Question"),
]
demo = gr.Interface(func, question_inputs, gr.outputs.KeyValues())
demo.launch()