update
Browse files- app.py +83 -4
- flagged/log.csv +3 -0
- requirements.txt +1 -0
app.py
CHANGED
@@ -1,7 +1,86 @@
|
|
|
|
1 |
import gradio as gr
|
2 |
|
3 |
-
|
4 |
-
return "Hello " + name + "!!"
|
5 |
|
6 |
-
|
7 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from relbert import RelBERT
|
2 |
import gradio as gr
|
3 |
|
4 |
+
# Single RelBERT instance built with default constructor arguments when the
# module is imported; greet() calls model.get_embedding() on it per request.
model = RelBERT()
|
|
|
5 |
|
6 |
+
|
7 |
+
def cosine_similarity(a, b, zero_vector_mask: float = -100):
    """Return the cosine similarity between two numeric sequences.

    Args:
        a: First vector, a re-iterable sequence of numbers.
        b: Second vector, a re-iterable sequence of numbers.
        zero_vector_mask: Value returned instead of a similarity when the
            product of the two Euclidean norms is zero.

    Returns:
        The dot product of ``a`` and ``b`` (taken over ``zip(a, b)``) divided
        by the product of their Euclidean norms, or ``zero_vector_mask`` when
        that product is zero.
    """
    length_a = sum(value * value for value in a) ** 0.5
    length_b = sum(value * value for value in b) ** 0.5
    denominator = length_a * length_b
    # Guard against dividing by zero when either vector has zero norm.
    if denominator == 0:
        return zero_vector_mask
    dot_product = sum(left * right for left, right in zip(a, b))
    return dot_product / denominator
|
13 |
+
|
14 |
+
|
15 |
+
def _split_word_pair(text, label):
    """Split a comma-separated ``"word_a,word_b"`` string into two words.

    Surrounding whitespace is stripped from each word. ``label`` names the
    field (``"query"`` or ``"candidate N"``) in the error message.

    Raises:
        ValueError: If ``text`` does not contain exactly two comma-separated
            parts.
    """
    words = [word.strip() for word in text.split(',')]
    if len(words) == 1:
        raise ValueError(f'ERROR: {label} contains single word {words}')
    if len(words) > 2:
        raise ValueError(f'ERROR: {label} contains more than two word {words}')
    return words


def greet(
        query,
        candidate_1,
        candidate_2,
        candidate_3,
        candidate_4,
        candidate_5,
        candidate_6,
        candidate_7,
        candidate_8,
        candidate_9,
        candidate_10):
    """Rank candidate word pairs by relational similarity to the query pair.

    Every argument is a string of the form ``"word_a,word_b"``. Blank
    candidate fields are skipped. The query and the remaining candidates are
    embedded with the module-level ``model`` and each candidate is scored by
    cosine similarity against the query embedding.

    Args:
        query: The query word pair, comma separated.
        candidate_1 .. candidate_10: Candidate word pairs, comma separated;
            leave a field empty to omit that candidate.

    Returns:
        A dict mapping ``"candidate N: [word_a, word_b]"`` (``N`` is the
        1-based input slot of the candidate) to its similarity score, with
        entries inserted from most to least similar.

    Raises:
        ValueError: If the query is empty or is not exactly two words, if a
            non-blank candidate is not exactly two words, or if fewer than
            two candidates are given.
    """
    # Validate the query exactly once, up front. Re-splitting it inside the
    # candidate loop fails on the second pass because it is already a list.
    if query.strip() == '':
        raise ValueError(f'ERROR: query is empty {query.split(",")}')
    query_pair = _split_word_pair(query, 'query')

    candidates = [
        candidate_1,
        candidate_2,
        candidate_3,
        candidate_4,
        candidate_5,
        candidate_6,
        candidate_7,
        candidate_8,
        candidate_9,
        candidate_10,
    ]
    pairs = []
    pairs_id = []
    for slot, text in enumerate(candidates, start=1):
        if text.strip() == '':
            # Unused input box: not a candidate.
            continue
        pairs.append(_split_word_pair(text, f'candidate {slot}'))
        pairs_id.append(slot)
    if len(pairs_id) < 2:
        raise ValueError(f'ERROR: please specify at least two candidates: {pairs}')

    # Embed everything in one call; the query is appended last so its vector
    # is the final element of the result.
    vectors = model.get_embedding(pairs + [query_pair])
    vector_q = vectors[-1]
    sims = [cosine_similarity(v, vector_q) for v in vectors[:-1]]

    ranked = sorted(zip(pairs_id, sims, pairs), key=lambda _x: _x[1], reverse=True)
    # Label each entry with the candidate's input slot, not its sort rank,
    # so the user can map a score back to the box it was typed in.
    return {f'candidate {i}: [{p[0]}, {p[1]}]': s for i, s, p in ranked}
|
67 |
+
|
68 |
+
|
69 |
+
# Gradio UI: the first textbox is the query pair, followed by ten candidate
# pair textboxes, passed positionally to greet(); its returned dict is shown
# with the "label" output component.
demo = gr.Interface(
    fn=greet,
    inputs=[
        gr.Textbox(lines=1, placeholder=f"{title} (separate by comma i.e. 'scotch whisky,wheat')")
        for title in ["Query Word Pair"] + [f"Candidate Word Pair {k}" for k in range(1, 11)]
    ],
    outputs="label",
)
# show_error=True is passed so that errors raised by greet() are displayed.
demo.launch(show_error=True)
|
flagged/log.csv
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
'query','candidate_1','candidate_2','candidate_3','candidate_4','candidate_5','candidate_6','candidate_7','candidate_8','candidate_9','candidate_10','output','flag','username','timestamp'
|
2 |
+
'','','','','','','','','','','','','','','2022-08-12 18:23:10.741127'
|
3 |
+
'','{"dog": 0.7, "cat": 0.3}','','','2022-08-12 18:59:42.932316'
|
requirements.txt
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
relbert
|