Analogy / app.py
asahi417's picture
update
ac00d10
raw
history blame
No virus
3.49 kB
from relbert import RelBERT
import gradio as gr
# Instantiate RelBERT once at import time; greet() calls `model.get_embedding`
# on this shared module-level instance for every request.
model = RelBERT(model='relbert/relbert-roberta-large')
def cosine_similarity(a, b, zero_vector_mask: float = -100):
    """Return the cosine similarity of two numeric vectors.

    `a` and `b` are each iterated more than once, so they must be
    re-iterable sequences (not one-shot iterators). The dot product pairs
    elements with `zip`, so trailing elements of a longer input are ignored.

    When the product of the two Euclidean norms is zero (e.g. either vector
    is empty or all zeros), `zero_vector_mask` is returned instead of
    dividing by zero.
    """
    length_a = sum(x * x for x in a) ** 0.5
    length_b = sum(y * y for y in b) ** 0.5
    scale = length_a * length_b
    if scale == 0:
        return zero_vector_mask
    dot = sum(x * y for x, y in zip(a, b))
    return dot / scale
def greet(
        query,
        candidate_1,
        candidate_2,
        candidate_3,
        candidate_4,
        candidate_5,
        candidate_6,
        candidate_7,
        candidate_8,
        candidate_9,
        candidate_10):
    """Rank candidate word pairs by similarity to the query word pair.

    Every argument is a string of the form ``"word_a,word_b"``. Candidate
    slots left as the empty string are skipped; at least two candidates must
    be filled in.

    Returns:
        A dict mapping ``'candidate <slot>: [<word_a>, <word_b>]'`` to the
        cosine similarity between that candidate's embedding and the query's
        embedding. ``<slot>`` is the 1-based position of the candidate among
        the ten inputs. Entries are inserted in descending similarity order.

    Raises:
        ValueError: if the query is empty, if the query or any non-empty
            candidate does not split into exactly two comma-separated parts,
            or if fewer than two candidates are provided.
    """
    # Parse and validate the query exactly once, before looking at the
    # candidates: splitting it inside the candidate loop would call `.split`
    # on an already-split list from the second iteration onwards.
    query_pair = query.split(',')
    # `str.split` never returns an empty list, so emptiness has to be checked
    # on the raw string for this error to be reachable.
    if query.strip() == '':
        raise ValueError(f'ERROR: query is empty {query_pair}')
    if len(query_pair) == 1:
        raise ValueError(f'ERROR: query contains single word {query_pair}')
    if len(query_pair) > 2:
        raise ValueError(f'ERROR: query contains more than two word {query_pair}')

    candidates = [
        candidate_1,
        candidate_2,
        candidate_3,
        candidate_4,
        candidate_5,
        candidate_6,
        candidate_7,
        candidate_8,
        candidate_9,
        candidate_10,
    ]
    pairs = []
    pairs_id = []  # 1-based input slot of each entry in `pairs`
    for slot, candidate in enumerate(candidates, start=1):
        if candidate == '':
            continue
        words = candidate.split(',')
        if len(words) == 1:
            raise ValueError(f'ERROR: candidate {slot} contains single word {words}')
        if len(words) > 2:
            raise ValueError(f'ERROR: candidate {slot} contains more than two word {words}')
        pairs.append(words)
        pairs_id.append(slot)
    if len(pairs_id) < 2:
        raise ValueError(f'ERROR: please specify at least two candidates: {pairs}')

    # Embed candidates and query in one call; the query is appended last so it
    # can be popped off again. This relies on `get_embedding` returning a list
    # with one vector per input pair, in input order.
    vectors = model.get_embedding(pairs + [query_pair])
    vector_q = vectors.pop(-1)
    sims = [cosine_similarity(v, vector_q) for v in vectors]

    ranked = sorted(zip(pairs_id, sims, pairs), key=lambda item: item[1], reverse=True)
    # Label each entry with its original input slot (not its rank) so the user
    # can tell which textbox a result came from.
    return {f'candidate {i}: [{p[0]}, {p[1]}]': s for i, s, p in ranked}
# Build the web UI around `greet`: one textbox for the query pair followed by
# ten candidate textboxes, in the same order as greet's positional parameters.
# Results are rendered with Gradio's "label" output component.
demo = gr.Interface(
    fn=greet,
    inputs=[
        gr.Textbox(lines=1, placeholder="Query Word Pair (separate by comma)"),
        *(
            gr.Textbox(lines=1, placeholder=f"Candidate Word Pair {slot} (separate by comma)")
            for slot in range(1, 11)
        ),
    ],
    outputs="label",
    # Each example row has 11 entries (query + 10 candidates); unused
    # candidate slots are padded with empty strings.
    examples=[
        [
            "beauty,aesthete",
            "pleasure,hedonist",
            "emotion,demagogue",
            "opinion,sympathizer",
            "seance,medium",
            "luxury,ascetic",
            *[''] * 5,
        ],
        [
            "classroom,desk",
            "bank,dollar",
            "church,pew",
            "studio,paintbrush",
            "museum,artifact",
            *[''] * 6,
        ],
    ],
)
# Start serving the interface, with Gradio's show_error option enabled.
demo.launch(show_error=True)