File size: 3,179 Bytes
f58d9a0 c31d68b 87187b9 cd474b3 dba7731 cd474b3 87187b9 f58d9a0 87187b9 c31d68b 87187b9 e2905f0 c31d68b 7c630e0 87187b9 e2905f0 87187b9 7c630e0 c31d68b 7c630e0 217ce82 7c630e0 87187b9 f58d9a0 e2905f0 87187b9 ac00d10 87187b9 f58d9a0 87187b9 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 |
import json
import requests
import re
from relbert import RelBERT
import gradio as gr
# Instantiate the RelBERT encoder once at module import; greet() calls
# model.get_embedding() on this shared instance for every request.
model = RelBERT(model='relbert/relbert-roberta-large')
def get_example(timeout: float = 30.0):
    """Download the SAT analogy test split and return its decoded records.

    The remote file is JSON Lines: one JSON document per non-blank line.

    Args:
        timeout: Seconds to wait for the server before giving up.

    Returns:
        A list with one decoded JSON value per non-blank line of the file.
        Code elsewhere in this file reads the 'stem' and 'choice' keys of
        each record.

    Raises:
        requests.RequestException: If the request fails, times out, or the
            server answers with an HTTP error status.
        json.JSONDecodeError: If a non-blank line is not valid JSON.
    """
    url = "https://huggingface.co/datasets/relbert/analogy_questions/raw/main/dataset/sat/test.jsonl"
    # requests waits forever by default; bound the wait so a stalled
    # connection cannot hang application start-up.
    r = requests.get(url, timeout=timeout)
    # Surface 4xx/5xx responses here instead of handing an error page to
    # json.loads and failing with a misleading decode error.
    r.raise_for_status()
    # Split on '\n' only: str.splitlines() would also break on characters
    # such as U+2028 that may legally appear inside a JSON string.
    return [
        json.loads(line)
        for line in r.content.decode().split('\n')
        if line.strip()  # skip blank and whitespace-only lines (e.g. a stray '\r')
    ]
def cosine_similarity(a, b, zero_vector_mask: float = -100):
    """Return the cosine of the angle between vectors *a* and *b*.

    Args:
        a: First vector, an iterable of numbers that can be traversed twice.
        b: Second vector, same requirements as *a*.
        zero_vector_mask: Value returned when the product of the two norms
            is zero, i.e. when the cosine is undefined.

    Returns:
        dot(a, b) / (|a| * |b|), or *zero_vector_mask* when that denominator
        is zero.
    """
    length_a = sum(x * x for x in a) ** 0.5
    length_b = sum(y * y for y in b) ** 0.5
    denominator = length_a * length_b
    # Guard the division: a zero-length vector has no direction.
    if denominator == 0:
        return zero_vector_mask
    dot_product = sum(x * y for x, y in zip(a, b))
    return dot_product / denominator
def clean(text):
    """Return *text* with leading and trailing whitespace removed.

    Interior whitespace is left untouched.

    Args:
        text: The string to trim.

    Returns:
        The trimmed string; an all-whitespace input yields ''.
    """
    # str.strip() with no argument trims the same Unicode whitespace class
    # that the regex \s matches, in a single pass and without the re module.
    return text.strip()
def greet(
        query,
        candidate_1,
        candidate_2,
        candidate_3,
        candidate_4,
        candidate_5,
        candidate_6):
    """Score candidate word pairs by relational similarity to the query pair.

    Every argument is a string of the form "word_a, word_b". Candidate
    arguments that are empty or whitespace-only are skipped.

    Args:
        query: The query word pair, comma separated.
        candidate_1 .. candidate_6: Candidate word pairs, comma separated;
            at least two of them must be filled in.

    Returns:
        A dict mapping 'candidate N: [word_a, word_b]' to the cosine
        similarity between that candidate's embedding and the query's.
        N is the 1-based position of the candidate argument.

    Raises:
        ValueError: If the query is empty, if the query or any filled-in
            candidate does not contain exactly two comma-separated parts,
            or if fewer than two candidates are given.
    """
    query = [clean(i) for i in query.split(',')]
    # validate query
    # str.split(',') never returns an empty list, so an empty query shows up
    # as a single empty string rather than as a zero-length list.
    if query == ['']:
        raise ValueError(f'ERROR: query is empty {query}')
    if len(query) == 1:
        raise ValueError(f'ERROR: query contains single word {query}')
    if len(query) > 2:
        raise ValueError(f'ERROR: query contains more than two word {query}')

    pairs = []
    pairs_id = []
    candidates = [
        candidate_1,
        candidate_2,
        candidate_3,
        candidate_4,
        candidate_5,
        candidate_6,
    ]
    for n, raw in enumerate(candidates, start=1):
        # Treat whitespace-only boxes like empty ones instead of reporting
        # them as a malformed "single word" candidate.
        if clean(raw) == '':
            continue
        candidate = [clean(x) for x in raw.split(',')]
        if len(candidate) == 1:
            raise ValueError(f'ERROR: candidate {n} contains single word {candidate}')
        if len(candidate) > 2:
            raise ValueError(f'ERROR: candidate {n} contains more than two word {candidate}')
        pairs.append(candidate)
        # Remember the original slot so output labels match the input boxes
        # even when some boxes in between were left blank.
        pairs_id.append(n)
    if len(pairs_id) < 2:
        raise ValueError(f'ERROR: please specify at least two candidates: {pairs}')

    # Embed candidates and query in one call; the query is appended last so
    # it can be popped off the end of the result.
    # NOTE(review): assumes get_embedding returns a list with one vector per
    # input pair, in input order -- confirm against the RelBERT API.
    vectors = model.get_embedding(pairs + [query])
    vector_q = vectors.pop(-1)
    sims = [cosine_similarity(v, vector_q) for v in vectors]
    return {f'candidate {i}: [{p[0]}, {p[1]}]': s for i, s, p in zip(pairs_id, sims, pairs)}
# Build the example rows shown under the form: the first 15 dataset records,
# each flattened to [query, choice_1, ..., choice_n] and padded with empty
# strings up to the six candidate boxes.
examples = get_example()[:15]
examples = [
    [','.join(record['stem'])]
    + [','.join(choice) for choice in record['choice']]
    + [''] * (6 - len(record['choice']))
    for record in examples
]

# One single-line text box per greet() argument, in argument order.
demo = gr.Interface(
    fn=greet,
    inputs=[
        gr.Textbox(lines=1, placeholder=hint)
        for hint in (
            "Query Word Pair (separate by comma)",
            "Candidate Word Pair 1 (separate by comma)",
            "Candidate Word Pair 2 (separate by comma)",
            "Candidate Word Pair 3 (separate by comma)",
            "Candidate Word Pair 4 (separate by comma)",
            "Candidate Word Pair 5 (separate by comma)",
            "Candidate Word Pair 6 (separate by comma)",
        )
    ],
    outputs="label",
    examples=examples,
)
demo.launch(show_error=True)
|