File size: 3,436 Bytes
87187b9
cd474b3
 
dba7731
cd474b3
87187b9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7c630e0
 
 
 
 
 
 
 
 
87187b9
 
 
 
 
 
 
 
 
 
 
 
 
 
7c630e0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
87187b9
 
 
 
 
ac00d10
 
 
 
 
 
 
 
 
 
 
87187b9
 
ac00d10
 
 
 
87187b9
 
ac00d10
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
from relbert import RelBERT
import gradio as gr

# Checkpoint identifier handed to RelBERT; this single model instance is
# shared by every request that `greet` serves.
_MODEL_NAME = 'relbert/relbert-roberta-large'
model = RelBERT(model=_MODEL_NAME)


def cosine_similarity(a, b, zero_vector_mask: float = -100):
    """Return the cosine similarity between vectors ``a`` and ``b``.

    Args:
        a: First vector, a sequence of numbers (iterated twice).
        b: Second vector; paired element-wise with ``a`` via ``zip``.
        zero_vector_mask: Value returned instead of dividing by zero when
            the product of the two Euclidean norms is zero.

    Returns:
        ``dot(a, b) / (|a| * |b|)``, or ``zero_vector_mask`` when that
        denominator is zero.
    """
    length_a = sum(value * value for value in a) ** 0.5
    length_b = sum(value * value for value in b) ** 0.5
    denominator = length_a * length_b
    if denominator == 0:
        # At least one vector is all zeros: the angle is undefined.
        return zero_vector_mask
    dot = sum(left * right for left, right in zip(a, b))
    return dot / denominator


def _split_pair(text, name):
    """Split comma-separated ``text`` into a validated two-word pair.

    Args:
        text: Raw textbox contents, e.g. ``"beauty, aesthete"``.
        name: Label used in error messages (``'query'`` or ``'candidate N'``).

    Returns:
        A list of exactly two whitespace-stripped, non-empty words.

    Raises:
        ValueError: If ``text`` does not hold exactly two non-empty words.
    """
    # Strip so "beauty, aesthete" yields "aesthete", not " aesthete".
    words = [word.strip() for word in text.split(',')]
    if len(words) == 1:
        raise ValueError(f'ERROR: {name} contains single word {words}')
    if len(words) > 2:
        raise ValueError(f'ERROR: {name} contains more than two word {words}')
    if not all(words):
        raise ValueError(f'ERROR: {name} contains an empty word {words}')
    return words


def greet(
        query,
        candidate_1,
        candidate_2,
        candidate_3,
        candidate_4,
        candidate_5,
        candidate_6,
        candidate_7,
        candidate_8,
        candidate_9,
        candidate_10):
    """Rank candidate word pairs by similarity to a query word pair.

    Every argument is the raw text of one textbox holding a comma-separated
    word pair such as ``"beauty,aesthete"``. Blank (empty or whitespace-only)
    candidate boxes are ignored.

    Returns:
        A dict mapping ``'candidate N: [head, tail]'`` to the cosine similarity
        between that pair's embedding and the query's embedding, inserted from
        most to least similar. ``N`` is the candidate's 1-based textbox number.

    Raises:
        ValueError: If the query is blank or not exactly two non-empty words,
            if any non-blank candidate is not exactly two non-empty words, or
            if fewer than two candidates are supplied.
    """
    # ''.split(',') == [''], so emptiness must be checked before splitting.
    if not query or not query.strip():
        raise ValueError(f'ERROR: query is empty {query}')
    query = _split_pair(query, 'query')

    candidates = [
        candidate_1,
        candidate_2,
        candidate_3,
        candidate_4,
        candidate_5,
        candidate_6,
        candidate_7,
        candidate_8,
        candidate_9,
        candidate_10,
    ]
    pairs = []
    pairs_id = []
    for n, text in enumerate(candidates, start=1):
        if not text or not text.strip():
            continue  # blank textbox: this candidate was not supplied
        pairs.append(_split_pair(text, f'candidate {n}'))
        pairs_id.append(n)
    if len(pairs_id) < 2:
        raise ValueError(f'ERROR: please specify at least two candidates: {pairs}')

    # Embed all candidates plus the query in one call; the query comes last.
    *vectors, vector_q = model.get_embedding(pairs + [query])
    sims = [cosine_similarity(v, vector_q) for v in vectors]
    ranked = sorted(zip(pairs_id, sims, pairs), key=lambda item: item[1], reverse=True)
    # Label by the candidate's own textbox number (not its rank), so the
    # output lines up with the "Candidate Word Pair N" input boxes.
    return {f'candidate {n}: [{head}, {tail}]': sim for n, sim, (head, tail) in ranked}


# One query box followed by ten candidate boxes, matching the eleven
# positional parameters of `greet`.
demo = gr.Interface(
    fn=greet,
    inputs=(
        [gr.Textbox(lines=1, placeholder="Query Word Pair (separate by comma)")]
        + [
            gr.Textbox(lines=1, placeholder=f"Candidate Word Pair {k} (separate by comma)")
            for k in range(1, 11)
        ]
    ),
    outputs="label",
    examples=[
        # Each example row fills all eleven boxes; unused candidates are ''.
        [
            "beauty,aesthete", "pleasure,hedonist", "emotion,demagogue",
            "opinion,sympathizer", "seance,medium", "luxury,ascetic",
        ] + [''] * 5,
        [
            "classroom,desk", "bank,dollar", "church,pew",
            "studio,paintbrush", "museum,artifact",
        ] + [''] * 6,
    ],
)
demo.launch(show_error=True)