# resrer-demo/app.py
import os

import streamlit as st
import torch
from pymilvus import MilvusClient

from model import encode_dpr_question, get_dpr_encoder
from model import summarize_text, get_summarizer
from model import ask_reader, get_reader

TITLE = 'ReSRer: Retriever-Summarizer-Reader'
INITIAL = "What is the population of NYC"
st.set_page_config(page_title=TITLE)
st.header(TITLE)
st.markdown('''
<h5>Ask a short-answer question that can be found in Wikipedia data.</h5>
''', unsafe_allow_html=True)
st.markdown(
    'This demo searches through 21,000,000 Wikipedia passages in real time under the hood.')
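
# Cache the DPR question encoder, summarizer, and reader across Streamlit reruns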
@st.cache_resource
def load_models():
models = {}
models['encoder'] = get_dpr_encoder()
models['summarizer'] = get_summarizer()
models['reader'] = get_reader()
return models
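
# Milvus client for the Wikipedia passage index (database: psgs_w100)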
@st.cache_resource
def load_client():
client = MilvusClient(user='resrer', password=os.environ['MILVUS_PW'],
uri=f"http://{os.environ['MILVUS_HOST']}:19530", db_name='psgs_w100')
return client
client = load_client()
models = load_models()
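
# CSS that pins the Streamlit status widget to the center of the page and hides its stop button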
styl = """
<style>
.StatusWidget-enter-done{
position: fixed;
left: 50%;
top: 50%;
transform: translate(-50%, -50%);
}
.StatusWidget-enter-done button{
display: none;
}
</style>
"""
st.markdown(styl, unsafe_allow_html=True)
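
# Question input plus three example-question shortcuts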
question = st.text_input("Question", INITIAL)
col1, col2, col3 = st.columns(3)
if col1.button("What is the capital of South Korea"):
question = "What is the capital of South Korea"
if col2.button("What is the most famous building in Paris"):
question = "What is the most famous building in Paris"
if col3.button("Who is the actor of Harry Potter"):
question = "Who is the actor of Harry Potter"
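
# Full pipeline: encode the question, retrieve passages from Milvus, summarize them,
# then read the answer; results are cached per question in st.session_state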
@torch.inference_mode()
def main(question: str):
if question in st.session_state:
print("Cache hit!")
ctx, summary, answer = st.session_state[question]
else:
print(f"Input: {question}")
    # Embedding: encode the question with the DPR question encoder
question_vectors = encode_dpr_question(
models['encoder'][0], models['encoder'][1], [question])
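    # Convert the query embedding to a plain Python list for the Milvus search call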
query_vector = question_vectors.detach().cpu().numpy().tolist()[0]
    # Retriever: top-10 vector search over the dpr_nq passage collection
results = client.search(collection_name='dpr_nq', data=[
query_vector], limit=10, output_fields=['title', 'text'])
texts = [result['entity']['text'] for result in results[0]]
ctx = '\n'.join(texts)
    # Summarizer: condense the retrieved passages into a short context
[summary] = summarize_text(models['summarizer'][0],
models['summarizer'][1], [ctx])
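    # Reader: extract the answer span from the summarized context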
answers = ask_reader(models['reader'][0],
models['reader'][1], [question], [summary])
answer = answers[0]['answer']
print(f"\nAnswer: {answer}")
st.session_state[question] = (ctx, summary, answer)
  # Render the answer, the summarized context, and the original passages
st.write(f"### Answer: {answer}")
st.markdown('<h5>Summarized Context</h5>', unsafe_allow_html=True)
st.markdown(
f"<h6 style='padding: 0'>{summary}</h6><hr style='margin: 1em 0px'>", unsafe_allow_html=True)
st.markdown('<h5>Original Context</h5>', unsafe_allow_html=True)
st.markdown(ctx)
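
# Run the pipeline whenever a question is present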
if question:
main(question)