|
import json |
|
import requests |
|
from relbert import RelBERT |
|
import gradio as gr |
|
|
|
# Module-level RelBERT instance, created once at import and reused by every request.
model = RelBERT(model="relbert/relbert-roberta-large")
|
|
|
|
|
def get_example(timeout: float = 30.0):
    """Download the analogy-question test split and parse it as JSON Lines.

    Args:
        timeout: Seconds to wait for the server before giving up, so a
            stalled connection cannot block start-up indefinitely.

    Returns:
        A list with one decoded JSON object per non-blank line of the file.

    Raises:
        requests.HTTPError: If the server answers with an error status.
        requests.RequestException: On connection problems or timeout.
        json.JSONDecodeError: If a non-blank line is not valid JSON.
    """
    url = "https://huggingface.co/datasets/relbert/analogy_questions/raw/main/dataset/sat/test.jsonl"
    response = requests.get(url, timeout=timeout)
    # Fail loudly on 4xx/5xx instead of trying to parse an error page as JSON.
    response.raise_for_status()
    lines = response.content.decode().split('\n')
    # Skip blank and whitespace-only lines (e.g. a trailing newline or stray '\r').
    return [json.loads(line) for line in lines if line.strip()]
|
|
|
|
|
def cosine_similarity(a, b, zero_vector_mask: float = -100):
    """Return the cosine similarity between two numeric sequences.

    Args:
        a: First vector (re-iterable sequence of numbers).
        b: Second vector (re-iterable sequence of numbers).
        zero_vector_mask: Value returned when either vector has zero norm,
            since the similarity is undefined in that case.

    Returns:
        The dot product of ``a`` and ``b`` divided by the product of their
        Euclidean norms, or ``zero_vector_mask`` for a zero-norm input.
    """
    norm_a = sum(x * x for x in a) ** 0.5
    norm_b = sum(y * y for y in b) ** 0.5
    denominator = norm_a * norm_b
    # Guard against division by zero for all-zero (or empty) vectors.
    if denominator == 0:
        return zero_vector_mask
    dot_product = sum(x * y for x, y in zip(a, b))
    return dot_product / denominator
|
|
|
|
|
def _parse_pair(text, name):
    """Split a comma-separated word pair and check that it has exactly two parts.

    Args:
        text: Raw textbox content such as ``"word1, word2"``.
        name: How to refer to this input in error messages
            (``"query"`` or ``"candidate 3"``).

    Returns:
        A two-element list of words with surrounding whitespace removed.

    Raises:
        ValueError: If the text holds one word or more than two words.
    """
    # Strip each word so "a, b" and "a,b" produce the same pair.
    words = [word.strip() for word in text.split(',')]
    if len(words) == 1:
        raise ValueError(f'ERROR: {name} contains single word {words}')
    if len(words) > 2:
        raise ValueError(f'ERROR: {name} contains more than two word {words}')
    return words


def greet(
        query,
        candidate_1,
        candidate_2,
        candidate_3,
        candidate_4,
        candidate_5,
        candidate_6,
        candidate_7,
        candidate_8,
        candidate_9,
        candidate_10):
    """Rank candidate word pairs by relational similarity to the query pair.

    Each argument is the text of one input box holding two comma-separated
    words. Blank candidate boxes are skipped; at least two candidates must
    be filled in.

    Args:
        query: The query word pair, e.g. ``"word1,word2"``.
        candidate_1 .. candidate_10: Candidate word pairs in the same format,
            or an empty string for an unused box.

    Returns:
        A dict ordered by descending similarity that maps a label of the form
        ``"candidate <box number>: [<word>, <word>]"`` to the cosine
        similarity between that candidate's embedding and the query's.

    Raises:
        ValueError: If the query is blank, any filled-in input is not exactly
            two comma-separated parts, or fewer than two candidates are given.
    """
    # str.split never returns an empty list, so blank input has to be
    # detected before splitting for this error to be reachable.
    if query is None or not query.strip():
        raise ValueError(f'ERROR: query is empty {query!r}')
    query = _parse_pair(query, 'query')

    pairs = []
    pairs_id = []
    candidates = (
        candidate_1,
        candidate_2,
        candidate_3,
        candidate_4,
        candidate_5,
        candidate_6,
        candidate_7,
        candidate_8,
        candidate_9,
        candidate_10,
    )
    for n, text in enumerate(candidates, start=1):
        # Unused boxes (empty or whitespace-only) are simply ignored.
        if text is None or not text.strip():
            continue
        pairs.append(_parse_pair(text, f'candidate {n}'))
        pairs_id.append(n)
    if len(pairs_id) < 2:
        raise ValueError(f'ERROR: please specify at least two candidates: {pairs}')

    # Embed everything in one call; the query is appended last so it can be
    # separated from the candidate vectors afterwards.
    vectors = model.get_embedding(pairs + [query])
    *candidate_vectors, vector_q = vectors
    sims = [cosine_similarity(v, vector_q) for v in candidate_vectors]
    ranked = sorted(zip(pairs_id, sims, pairs), key=lambda item: item[1], reverse=True)
    # Label with the original box number (not the rank) so the user can tell
    # which input each score belongs to.
    return {f'candidate {i}: [{p[0]}, {p[1]}]': s for i, s, p in ranked}
|
|
|
|
|
# Build example rows from the first 15 downloaded questions: the query pair
# followed by its choices, padded with empty strings up to 10 candidate slots.
examples = [
    [','.join(question['stem'])]
    + [','.join(choice) for choice in question['choice'] + [''] * (10 - len(question['choice']))]
    for question in get_example()[:15]
]

# One textbox for the query pair followed by ten textboxes for the candidates;
# the order must match the positional parameters of `greet`.
demo = gr.Interface(
    fn=greet,
    inputs=[
        gr.Textbox(lines=1, placeholder="Query Word Pair (separate by comma)"),
        *[
            gr.Textbox(lines=1, placeholder=f"Candidate Word Pair {slot} (separate by comma)")
            for slot in range(1, 11)
        ],
    ],
    outputs="label",
    examples=examples,
)

# show_error=True surfaces the ValueError messages raised by `greet` in the UI.
demo.launch(show_error=True)
|
|
|
|
|
|