"""Gradio demo that ranks candidate word pairs by relational similarity to a query pair.

Each pair is typed as two comma-separated words (e.g. ``scotch whisky,wheat``).
RelBERT embeds every pair and the candidates are ranked by the cosine
similarity between their embedding and the query pair's embedding.
"""
from relbert import RelBERT
import gradio as gr

# Loaded once at import time so every request reuses the same model instance.
model = RelBERT()

# Number of candidate text boxes shown in the UI; must match greet's signature.
N_CANDIDATES = 10


def cosine_similarity(a, b, zero_vector_mask: float = -100):
    """Return the cosine similarity between vectors ``a`` and ``b``.

    Args:
        a: First vector; a re-iterable sequence of numbers (it is traversed
            more than once, so a one-shot generator will not work).
        b: Second vector, same length as ``a``.
        zero_vector_mask: Value returned instead of dividing by zero when
            either vector has zero norm.

    Returns:
        The cosine similarity, or ``zero_vector_mask`` if either vector is
        all zeros.
    """
    norm_a = sum(x * x for x in a) ** 0.5
    norm_b = sum(x * x for x in b) ** 0.5
    if norm_a * norm_b == 0:
        return zero_vector_mask
    return sum(x * y for x, y in zip(a, b)) / (norm_a * norm_b)


def _parse_pair(text, label):
    """Split comma-separated ``text`` into a validated two-word list.

    Surrounding whitespace is stripped from each word, so ``"a, b"`` and
    ``"a,b"`` give the same pair. ``label`` names the field in error messages.

    Raises:
        ValueError: if the pair is empty, has one word or more than two
            words, or one of its two words is blank.
    """
    words = [word.strip() for word in (text or '').split(',')]
    if not any(words):
        raise ValueError(f'ERROR: {label} is empty {words}')
    if len(words) == 1:
        raise ValueError(f'ERROR: {label} contains single word {words}')
    if len(words) > 2:
        raise ValueError(f'ERROR: {label} contains more than two word {words}')
    if not all(words):
        # e.g. "wheat," or ",wheat": two slots but one of them is blank.
        raise ValueError(f'ERROR: {label} contains an empty word {words}')
    return words


def greet(
        query, candidate_1, candidate_2, candidate_3, candidate_4, candidate_5,
        candidate_6, candidate_7, candidate_8, candidate_9, candidate_10):
    """Rank the non-empty candidate pairs by relational similarity to ``query``.

    Args:
        query: Query word pair, formatted ``"word_a,word_b"``.
        candidate_1 ... candidate_10: Candidate word pairs in the same format.
            Empty (or whitespace-only) boxes are ignored.

    Returns:
        Dict mapping ``"candidate <k>: [word_a, word_b]"`` (``k`` is the
        1-based text-box number) to that pair's cosine similarity with the
        query, ordered from most to least similar. This is the
        ``{label: confidence}`` shape consumed by the ``"label"`` output.

    Raises:
        ValueError: if the query or any non-empty candidate is malformed, or
            fewer than two candidates are supplied.
    """
    # Validate the query exactly once, outside the candidate loop.
    query_pair = _parse_pair(query, 'query')

    candidates = [
        candidate_1, candidate_2, candidate_3, candidate_4, candidate_5,
        candidate_6, candidate_7, candidate_8, candidate_9, candidate_10,
    ]
    pairs = []
    pairs_id = []
    for n, candidate in enumerate(candidates, start=1):
        if not (candidate or '').strip():
            continue  # unused text box
        pairs.append(_parse_pair(candidate, f'candidate {n}'))
        pairs_id.append(n)
    if len(pairs_id) < 2:
        raise ValueError(f'ERROR: please specify at least two candidates: {pairs}')

    # The query is embedded last so its vector can be split off the end.
    # Assumes get_embedding returns one vector per input pair, in input order
    # (the original code relied on this too) -- confirm against relbert docs.
    vectors = model.get_embedding(pairs + [query_pair])
    vector_q = vectors[-1]
    sims = [cosine_similarity(v, vector_q) for v in vectors[:-1]]

    ranked = sorted(zip(pairs_id, sims, pairs), key=lambda item: item[1], reverse=True)
    return {f'candidate {pid}: [{p[0]}, {p[1]}]': s for pid, s, p in ranked}


demo = gr.Interface(
    fn=greet,
    inputs=[
        gr.Textbox(lines=1, placeholder="Query Word Pair (separate by comma i.e. 'scotch whisky,wheat')"),
        *[
            gr.Textbox(
                lines=1,
                placeholder=f"Candidate Word Pair {n} (separate by comma i.e. 'scotch whisky,wheat')",
            )
            for n in range(1, N_CANDIDATES + 1)
        ],
    ],
    outputs="label",
)
demo.launch(show_error=True)