File size: 2,207 Bytes
5385de7
3dc5ca1
a2253fb
577f46a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9f65500
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
import json
import streamlit as st
import requests as req

# TODO: improve layout (columns, sidebar, forms)
# st.set_page_config(layout='wide')


st.title('Question answering help desk application')


##########################################################
st.subheader('1. A simple question')
##########################################################


# MediaWiki API request used to seed the QA passage with the plain-text
# extract of the "BERT (language model)" Wikipedia article.
WIKI_URL = 'https://en.wikipedia.org/w/api.php'
WIKI_QUERY = "?format=json&action=query&prop=extracts&explaintext=1"
WIKI_BERT = "&titles=BERT_(language_model)"
WIKI_METHOD = 'GET'
# Seconds to wait for Wikipedia; without a timeout a stalled request would
# block this Streamlit run indefinitely.
WIKI_TIMEOUT = 10


def _fetch_wiki_extract(url, timeout=WIKI_TIMEOUT):
    """Return the plain-text extract of the first page in a MediaWiki query.

    The page is taken from ``query.pages`` by position rather than by a
    hard-coded page id, so the lookup does not depend on Wikipedia's
    internal id for the article.

    Args:
        url: Full MediaWiki API URL (endpoint plus query string).
        timeout: Seconds to wait for the HTTP response.

    Raises:
        requests.RequestException: on network errors, timeouts or non-2xx
            responses.
        ValueError: if the body is not valid UTF-8 JSON.
        KeyError / IndexError: if the reply contains no page extract
            (e.g. an API error object or a missing page).
    """
    response = req.request(WIKI_METHOD, url, timeout=timeout)
    response.raise_for_status()
    resp_json = json.loads(response.content.decode("utf-8"))
    pages = list(resp_json['query']['pages'].values())
    return pages[0]['extract']


try:
    wiki_bert = _fetch_wiki_extract(f'{WIKI_URL}{WIKI_QUERY}{WIKI_BERT}')
except (req.RequestException, ValueError, KeyError, IndexError) as err:
    # Degrade gracefully: the passage below is editable, so the app stays
    # usable if the default article cannot be loaded.
    st.warning(
        f'Could not load the default Wikipedia passage ({err}); '
        'paste a paragraph below instead.'
    )
    wiki_bert = ''
paragraph = wiki_bert

# Editable QA inputs. Each widget is pre-filled with a default and returns its
# current contents. That value is used directly, with whitespace-only input
# treated as empty. This way, clearing a field really leaves it empty instead
# of silently reverting to the default, and the run-QA step can prompt the
# user for input.
written_passage = st.text_area(
    'Paragraph used for QA (you can also edit, or copy/paste new content)',
    paragraph,
    height=250,
)
paragraph = written_passage if (written_passage or '').strip() else ''

question = 'How many languages does bert understand?'
written_question = st.text_input(
    'Question used for QA (you can also edit, and experiment with the answers)',
    question,
)
question = written_question if (written_question or '').strip() else ''

QA_URL = "https://api-inference.huggingface.co/models/deepset/roberta-base-squad2"
QA_METHOD = 'POST'
# Seconds to wait for the hosted model before giving up.
QA_TIMEOUT = 60
# Number of passage characters shown before / after the predicted answer span.
CONTEXT_BEFORE = 86
CONTEXT_AFTER = 90
# Keys the answer display below reads from the model's JSON reply.
_ANSWER_KEYS = ('answer', 'score', 'start', 'end')


def _query_qa_model(question_text, context_text):
    """POST a question/context pair to the QA model and return its JSON reply.

    Raises:
        requests.RequestException: on network failure or timeout.
        ValueError: if the reply is not valid UTF-8 JSON.
        RuntimeError: if the reply is not a successful answer object (non-2xx
            status, an ``error`` key, or missing answer fields).
    """
    payload = json.dumps({'question': question_text, 'context': context_text})
    prediction = req.request(QA_METHOD, QA_URL, data=payload, timeout=QA_TIMEOUT)
    answer = json.loads(prediction.content.decode("utf-8"))
    # The Inference API reports failures as a JSON object with an "error" key
    # instead of an answer; reject anything that lacks the fields used below.
    if (not prediction.ok or not isinstance(answer, dict) or 'error' in answer
            or any(key not in answer for key in _ANSWER_KEYS)):
        raise RuntimeError(
            f'QA inference failed (HTTP {prediction.status_code}): {answer}'
        )
    return answer


def _highlight_answer(text, start, end, before=CONTEXT_BEFORE, after=CONTEXT_AFTER):
    """Return a window of *text* around ``text[start:end]`` with that span in bold.

    Bolding by offsets (instead of ``str.replace`` on the answer string)
    marks only the predicted occurrence and leaves the window intact when the
    span is empty.
    """
    start = max(0, min(start, len(text)))
    end = max(start, min(end, len(text)))
    window_start = max(0, start - before)
    window_stop = min(end + after, len(text))
    span = text[start:end]
    highlighted = f'**{span}**' if span else ''
    return f'{text[window_start:start]}{highlighted}{text[end:window_stop]}'


if st.button('Run QA inference (get answer prediction)'):
    if paragraph and question:
        try:
            answer = _query_qa_model(question, paragraph)
        except (req.RequestException, ValueError, RuntimeError) as err:
            st.error(str(err))
        else:
            answer_span = answer["answer"]
            answer_score = answer["score"]
            st.write(f'Answer: **{answer_span}**')
            # start/end are offsets into the context that was sent, i.e. `paragraph`.
            answer_context = _highlight_answer(paragraph, answer["start"], answer["end"])
            st.write(f'Answer context (and score): ... _{answer_context}_ ... (score: {answer_score:.3f})')
            st.write('Answer JSON: ')
            st.write(answer)
    else:
        st.write('Write some passage of text and a question')
        st.stop()