import os
import gradio as gr
import numpy as np
import wikipediaapi as wk
import wikipedia
from transformers import (
    TokenClassificationPipeline,
    AutoModelForTokenClassification,
    AutoTokenizer,
    BertForQuestionAnswering,
    BertTokenizer
)
from transformers.pipelines import AggregationStrategy
import torch
print("hello")
# =====[ DEFINE PIPELINE ]===== #
class KeyphraseExtractionPipeline(TokenClassificationPipeline):
    def __init__(self, model, *args, **kwargs):
        super().__init__(
            model=AutoModelForTokenClassification.from_pretrained(model),
            tokenizer=AutoTokenizer.from_pretrained(model),
            *args,
            **kwargs
        )

    def postprocess(self, model_outputs):
        results = super().postprocess(
            model_outputs=model_outputs,
            aggregation_strategy=AggregationStrategy.SIMPLE,
        )
        return np.unique([result.get("word").strip() for result in results])
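
# Illustrative usage (exact phrases depend on the checkpoint):
#   extractor("BERT is a language model developed by Google.")
#   might return something like ['BERT', 'Google', 'language model'].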

# =====[ LOAD PIPELINE ]===== #
keyPhraseExtractionModel = "ml6team/keyphrase-extraction-kbir-inspec"
extractor = KeyphraseExtractionPipeline(model=keyPhraseExtractionModel)
model = BertForQuestionAnswering.from_pretrained('bert-large-uncased-whole-word-masking-finetuned-squad')
tokenizer = BertTokenizer.from_pretrained('bert-large-uncased-whole-word-masking-finetuned-squad')
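# The SQuAD-finetuned BERT-large scores every token of a (question, context)
# pair as a potential start or end of the answer span; we pick the best span.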

#TODO: add further preprocessing
def keyphrases_extraction(text: str) -> np.ndarray:
    keyphrases = extractor(text)
    return keyphrases

def wikipedia_search(query: str) -> str:
    query = query.replace("\n", " ")
    keyphrases = keyphrases_extraction(query)

    # NOTE: recent wikipedia-api releases also require a user_agent argument,
    # e.g. wk.Wikipedia(user_agent="azza-agent/0.1", language="en").
    wiki = wk.Wikipedia('en')
    
    try:
        #TODO: add better extraction and search
        if len(keyphrases) == 0:
            return "Can you add more details to your question?"
    
        # Search with Wikipedia's suggested query when one exists, otherwise
        # fall back to the raw keyphrase.
        query_suggestion = wikipedia.suggest(keyphrases[0])
        if query_suggestion is not None:
            results = wikipedia.search(query_suggestion)
        else:
            results = wikipedia.search(keyphrases[0])

        # Walk through the search results until we find an existing page
        # with a non-trivial summary (at least one full sentence).
        index = 0
        page = wiki.page(results[index])
        while '.' not in page.summary or not page.exists():
            index += 1
            if index == len(results):
                raise LookupError("No suitable Wikipedia page found")
            page = wiki.page(results[index])
        return page.summary
    
    except Exception:
        return "I cannot answer this question"
    
def answer_question(question):
    context = wikipedia_search(question)
    if (context == "I cannot answer this question") or (context == "Can you add more details to your question?"):
        return context

    # ======== Tokenize ========
    # Apply the tokenizer to question and context, treating them as a text pair:
    # the encoded sequence is [CLS] question [SEP] context [SEP].
    input_ids = tokenizer.encode(question, context)

    # Everything up to and including the first [SEP] is the question part.
    question_ids = input_ids[:input_ids.index(tokenizer.sep_token_id) + 1]

    # If the sequence exceeds BERT's 512-token limit, split the context into
    # chunks that each fit alongside the question.
    length_of_group = 512 - len(question_ids)
    input_ids_without_question = input_ids[input_ids.index(tokenizer.sep_token_id) + 1:]
    num_chunks = max(1, -(-len(input_ids_without_question) // length_of_group))  # ceiling division
    print(f"Query has {len(input_ids)} tokens, divided into {num_chunks} chunk(s).\n")

    input_ids_split = []
    for group in range(num_chunks):
        context_chunk = input_ids_without_question[length_of_group * group:length_of_group * (group + 1)]
        if context_chunk:
            input_ids_split.append(question_ids + context_chunk)
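
    # Illustrative arithmetic: with a 20-token question part, length_of_group
    # is 492, so a 1,000-token context yields chunks of 492, 492 and 16 tokens.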
    
    scores = []
    for chunk in input_ids_split:
    # ======== Set Segment IDs ========
    # Search the input_ids for the first instance of the `[SEP]` token.
        sep_index = chunk.index(tokenizer.sep_token_id)

    # The number of segment A tokens includes the [SEP] token istelf.
        num_seg_a = sep_index + 1

    # The remainder are segment B.
        num_seg_b = len(chunk) - num_seg_a

    # Construct the list of 0s and 1s.
        segment_ids = [0]*num_seg_a + [1]*num_seg_b
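    # Illustrative layout: for [CLS] where ? [SEP] in france . [SEP] the
    # segment_ids are [0, 0, 0, 0, 1, 1, 1, 1].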

    # There should be a segment_id for every input token.
        assert len(segment_ids) == len(chunk)

    # ======== Evaluate ========
    # Run our example through the model.
        outputs = model(torch.tensor([chunk]),  # The tokens representing our input text.
                        token_type_ids=torch.tensor([segment_ids]),  # Segment IDs differentiate question from context.
                        return_dict=True)

        start_scores = outputs.start_logits
        end_scores = outputs.end_logits

    # Keep the best start/end logits so the chunks can be compared later.
        max_start_score = torch.max(start_scores)
        max_end_score = torch.max(end_scores)

    # ======== Reconstruct Answer ========
    # Find the tokens with the highest `start` and `end` scores.
        answer_start = torch.argmax(start_scores)
        answer_end = torch.argmax(end_scores)

    # Get the string versions of the current chunk's tokens; answer_start and
    # answer_end index into this chunk, not the full input_ids sequence.
        tokens = tokenizer.convert_ids_to_tokens(chunk)

    # Start with the first token.
        answer = tokens[answer_start]

    # Select the remaining answer tokens and join them with whitespace.
        for i in range(answer_start + 1, answer_end + 1):
        
        # If it's a subword token, then recombine it with the previous token.
            if tokens[i][0:2] == '##':
                answer += tokens[i][2:]
        
        # Otherwise, add a space then the token.
            else:
                answer += ' ' + tokens[i]
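
        # For instance, WordPiece pieces "play" + "##ing" recombine to
        # "playing", while whole-word tokens are joined with spaces.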

        scores.append((max_start_score, max_end_score, answer))

    # Compare the scores across chunks and pick the answer with the highest
    # combined start + end logit.

    final_answer = max(scores, key=lambda x: x[0] + x[1])[2]

    return final_answer

# =====[ DEFINE INTERFACE ]===== #
title = "Azza Conversational Agent"
examples = [
    ["Where is the Eiffel Tower?"],
    ["What is the population of France?"]
]
print("hello")
demo = gr.Interface(
    fn=answer_question,
    inputs="text",
    outputs="text",
    title=title,
    examples=examples,
)

if __name__ == "__main__":
    demo.launch()
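    # demo.launch(share=True) would additionally create a temporary public
    # link, a standard Gradio option that is handy for quick demos.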