Spaces:
Runtime error
Runtime error
Update app.py
Browse files
app.py
CHANGED
@@ -4,11 +4,13 @@ from sklearn.metrics.pairwise import cosine_similarity
|
|
4 |
import streamlit as st
|
5 |
import torch
|
6 |
import pickle
|
|
|
7 |
|
8 |
model_checkpoint = "vives/distilbert-base-uncased-finetuned-cvent-2019_2022"
|
9 |
model = AutoModelForMaskedLM.from_pretrained(model_checkpoint, output_hidden_states=True)
|
10 |
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)
|
11 |
text = st.text_input("Enter word or key-phrase")
|
|
|
12 |
exclude_words = st.radio("exclude_words",[True,False], help="Exclude results that contain any words in the query (i.e exclude 'hot coffee' if the query is 'cold coffee')")
|
13 |
|
14 |
exclude_text = st.radio("exclude_text",[True,False], help="Exclude results that contain the query (i.e exclude 'tomato soup recipe' if the query is 'tomato soup')")
|
@@ -17,12 +19,17 @@ k = st.number_input("Top k nearest key-phrases",1,10,5)
|
|
17 |
with st.sidebar:
|
18 |
diversify_box = st.checkbox("Diversify results",True)
|
19 |
if diversify_box:
|
20 |
-
|
21 |
|
|
|
22 |
with open("kp_dict_merged.pickle",'rb') as handle:
|
23 |
kp_dict = pickle.load(handle)
|
24 |
for key in kp_dict.keys():
|
25 |
kp_dict[key] = kp_dict[key].detach().numpy()
|
|
|
|
|
|
|
|
|
26 |
|
27 |
def calculate_top_k(out, tokens,text,exclude_text=False,exclude_words=False, k=5):
|
28 |
sim_dict = {}
|
@@ -65,10 +72,30 @@ def pool_embeddings(out, tok):
|
|
65 |
mean_pooled = summed / summed_mask
|
66 |
return mean_pooled
|
67 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
68 |
if text:
|
69 |
new_tokens = concat_tokens([text])
|
70 |
new_tokens.pop("KPS")
|
71 |
with torch.no_grad():
|
72 |
outputs = model(**new_tokens)
|
73 |
-
|
74 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
4 |
import streamlit as st
|
5 |
import torch
|
6 |
import pickle
|
7 |
+
import itertools
|
8 |
|
9 |
model_checkpoint = "vives/distilbert-base-uncased-finetuned-cvent-2019_2022"
|
10 |
model = AutoModelForMaskedLM.from_pretrained(model_checkpoint, output_hidden_states=True)
|
11 |
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)
|
12 |
text = st.text_input("Enter word or key-phrase")
|
13 |
+
|
14 |
exclude_words = st.radio("exclude_words",[True,False], help="Exclude results that contain any words in the query (i.e exclude 'hot coffee' if the query is 'cold coffee')")
|
15 |
|
16 |
exclude_text = st.radio("exclude_text",[True,False], help="Exclude results that contain the query (i.e exclude 'tomato soup recipe' if the query is 'tomato soup')")
|
|
|
# Sidebar controls: optionally diversify the returned key-phrases.
with st.sidebar:
    diversify_box = st.checkbox("Diversify results", True)
    if diversify_box:
        k_diversify = st.number_input("Set of key-phrases to diversify from", 10, 30, 20)

# load kp dict
# NOTE(review): pickle.load is only safe on trusted, bundled files — confirm
# these pickles ship with the app and are not user-supplied.
with open("kp_dict_merged.pickle", "rb") as handle:
    kp_dict = pickle.load(handle)
# Detach each embedding tensor from the autograd graph and keep it as numpy.
kp_dict = {phrase: emb.detach().numpy() for phrase, emb in kp_dict.items()}

# load cosine distances of kp dict
with open("cosine_kp.pickle", "rb") as handle:
    cosine_kp = pickle.load(handle)
|
33 |
|
34 |
def calculate_top_k(out, tokens,text,exclude_text=False,exclude_words=False, k=5):
|
35 |
sim_dict = {}
|
|
|
72 |
mean_pooled = summed / summed_mask
|
73 |
return mean_pooled
|
74 |
|
75 |
+
def extract_idxs(top_dict, kp_dict):
|
76 |
+
idxs = []
|
77 |
+
c = 0
|
78 |
+
for i in list(kp_dict.keys()):
|
79 |
+
if i in top_dict.keys():
|
80 |
+
idxs.append(c)
|
81 |
+
c+=1
|
82 |
+
return idxs
|
83 |
+
|
84 |
if text:
|
85 |
new_tokens = concat_tokens([text])
|
86 |
new_tokens.pop("KPS")
|
87 |
with torch.no_grad():
|
88 |
outputs = model(**new_tokens)
|
89 |
+
if not diversify_box:
|
90 |
+
sim_dict = calculate_top_k(outputs, new_tokens, text, exclude_text=exclude_text,exclude_words=exclude_words,k=k)
|
91 |
+
st.json(sim_dict)
|
92 |
+
else:
|
93 |
+
sim_dict = calculate_top_k(outputs, new_tokens, text, exclude_text=exclude_text,exclude_words=exclude_words,k=k_diversify)
|
94 |
+
idxs = extract_idxs(sim_dict, kp_dict)
|
95 |
+
distances_candidates = cosine_kp[np.ix_(idxs, idxs)]
|
96 |
+
min_sim = np.inf
|
97 |
+
candidate = None
|
98 |
+
for combination in itertools.combinations(range(len(idxs)), k):
|
99 |
+
sim = sum([distances_candidates[i][j] for i in combination for j in combination if i != j])
|
100 |
+
|
101 |
+
|