crystina-z commited on
Commit
934f74d
·
1 Parent(s): a5881f5

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +85 -0
app.py ADDED
@@ -0,0 +1,85 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import time
2
+ import json
3
+
4
+ import streamlit as st
5
+ from pathlib import Path
6
+
7
+ import sys
8
+ path_root = Path("./")
9
+ sys.path.append(str(path_root))
10
+
11
+
12
+ st.set_page_config(page_title="PSC Runtime",
13
+ page_icon='🌸', layout="centered")
14
+
15
+ # cola, colb, colc = st.columns([5, 4, 5])
16
+
17
+ # colaa, colbb, colcc = st.columns([1, 8, 1])
18
+ # with colbb:
19
+ # runtime = st.select_slider(
20
+ # 'Select a runtime type',
21
+ # options=['PyTorch', 'ONNX Runtime'])
22
+ # st.write('Now using: ', runtime)
23
+
24
+
25
+ # colaa, colbb, colcc = st.columns([1, 8, 1])
26
+ # with colbb:
27
+ # encoder = st.select_slider(
28
+ # 'Select a query encoder',
29
+ # options=['uniCOIL', 'SPLADE++ Ensemble Distil', 'SPLADE++ Self Distil'])
30
+ # st.write('Now Running Encoder: ', encoder)
31
+
32
+ # if runtime == 'PyTorch':
33
+ # runtime = 'pytorch'
34
+ # runtime_index = 1
35
+ # else:
36
+ # runtime = 'onnx'
37
+ # runtime_index = 0
38
+
39
+
40
+ col1, col2 = st.columns([9, 1])
41
+ with col1:
42
+ search_query = st.text_input(label="search query", placeholder="Search")
43
+
44
+ with col2:
45
+ st.write('#')
46
+ button_clicked = st.button("🔎")
47
+
48
+
49
+ import torch
50
+ fn = ""
51
+ object = torch.load(fn)
52
+ outputs = [x[2] for x in object]
53
+ query2outputs = {}
54
+ for output in outputs:
55
+ all_queries = {x['query'] for x in output}
56
+ assert len(all_queries) == 1
57
+ query = list(all_queries)[0]
58
+ query2outputs[query] = [x['hits'] for x in output]
59
+
60
+ search_query = sorted(query2outputs)[0]
61
+
62
+ if search_query or button_clicked:
63
+
64
+ num_results = None
65
+ t_0 = time.time()
66
+ search_results = query2outputs[search_query]
67
+
68
+ st.write(
69
+ f'<p align=\"right\" style=\"color:grey;\"> Before aggregation for query [{search_query}] ms</p>', unsafe_allow_html=True)
70
+
71
+ for i, result in enumerate(search_results):
72
+ result_id = result["docid"]
73
+ contents = result["content"]
74
+
75
+ # output = f'<div class="row"> <b>Rank</b>: {i+1} | <b>Document ID</b>: {result_id} | <b>Score</b>:{result_score:.2f}</div>'
76
+ output = f'<div class="row"> <b>Rank</b>: {i+1} | <b>Document ID</b>: {result_id}'
77
+
78
+ try:
79
+ st.write(output, unsafe_allow_html=True)
80
+ st.write(
81
+ f'<div class="row">{contents}</div>', unsafe_allow_html=True)
82
+
83
+ except:
84
+ pass
85
+ st.write('---')