import json
import torch
import torch.nn.functional as F

from torch import Tensor
from transformers import AutoTokenizer, AutoModel

import gradio as gr

# instantiate tokenizer and model
def get_model(base_name='intfloat/e5-large-v2'):
    tokenizer = AutoTokenizer.from_pretrained(base_name)
    model = AutoModel.from_pretrained(base_name)
    
    return tokenizer, model

# get similarity scores for input_texts. The first `how_many_q` entries are
# treated as queries and the rest as passages; the returned matrix has one
# row per query and one column per passage
def get_scores(model, tokenizer, input_texts, max_length=512, how_many_q=1, normalize=True):
    # Tokenize the input texts
    batch_dict = tokenizer(
        input_texts,
        max_length=max_length,
        padding=True,
        truncation=True,
        return_tensors='pt'
    )

    # no gradients are needed for inference
    with torch.no_grad():
        outputs = model(**batch_dict)
    embeddings = average_pool(
        outputs.last_hidden_state, batch_dict['attention_mask']
    )

    # (Optionally) normalize embeddings so the dot product is cosine similarity
    if normalize:
        embeddings = F.normalize(embeddings, p=2, dim=1)

    # query rows x passage columns, scaled by 100
    scores = (embeddings[:how_many_q] @ embeddings[how_many_q:].T) * 100
    return scores
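
# A minimal usage sketch (illustrative, not executed by the app): with one
# query and two passages in E5's expected "query: ..." / "passage: ..."
# format, the returned tensor has shape (1, 2), one row per query and one
# column per passage.
#
#   example_texts = [
#       "query: how tall is mount everest",
#       "passage: Mount Everest is Earth's highest mountain above sea level.",
#       "passage: The Nile is a major river in northeastern Africa.",
#   ]
#   example_scores = get_scores(model, tokenizer, example_texts)  # shape (1, 2)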

# get the top-k results out of the scores, sorted per query. This
# function returns only the sorted indices and score values
def get_top(scores, top_k=None):
    result = torch.sort(scores, descending=True, dim=1)
    top_indices = result.indices
    top_values = result.values

    if top_k is not None:
        top_indices = top_indices[:, :top_k]
        top_values = top_values[:, :top_k]

    return top_indices, top_values
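
# A quick illustration (not executed by the app): sorting one row of scores
# and keeping the top two entries.
#
#   get_top(torch.tensor([[10., 30., 20.]]), top_k=2)
#   # -> indices tensor([[1, 2]]), values tensor([[30., 20.]])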

# get the top-k results out of the scores. This function returns the
# scores and indices along with the associated passage text, keyed by
# the query each row of scores belongs to
def get_human_readable_top(scores, input_texts, top_k=None):
    queries = [text for text in input_texts if text.startswith("query:")]
    passages = [text for text in input_texts if not text.startswith("query:")]
    top_indices, top_values = get_top(scores, top_k)

    result = {}
    for query_idx, (indices, values) in enumerate(zip(top_indices, top_values)):
        q = queries[query_idx]
        a = []

        for idx, val in zip(indices.tolist(), values.tolist()):
            a.append({
                "idx": idx,
                "val": round(val, 3),
                "text": passages[idx]
            })

        result[q] = a

    return result
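
# Shape of the returned mapping (values below are illustrative, not real
# model output):
#
#   {
#       "query: how tall is mount everest": [
#           {"idx": 0, "val": 87.123, "text": "passage: Mount Everest ..."},
#           {"idx": 1, "val": 62.045, "text": "passage: The Nile ..."}
#       ]
#   }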

# mean-pool the token embeddings, ignoring padded positions: padded tokens
# (attention_mask == 0) are zeroed out before summing, and the sum is divided
# by the number of real tokens in each sequence
def average_pool(last_hidden_states: Tensor,
                 attention_mask: Tensor) -> Tensor:
    last_hidden = last_hidden_states.masked_fill(~attention_mask[..., None].bool(), 0.0)
    return last_hidden.sum(dim=1) / attention_mask.sum(dim=1)[..., None]
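
# A small worked example (sketch): for one sequence with hidden states
# [[1., 1.], [3., 3.], [5., 5.]] and attention_mask [1, 1, 0], the padded
# last position is zeroed out and the pooled embedding is
# ([1., 1.] + [3., 3.]) / 2 == [2., 2.].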

def get_result(q_txt, p_txt1, p_txt2, p_txt3, p_txt4, p_txt5):
    input_texts = [
        f"query: {q_txt}"
    ]

    # only non-empty passage boxes are sent to the model
    for p_txt in (p_txt1, p_txt2, p_txt3, p_txt4, p_txt5):
        if p_txt != '':
            input_texts.append(f"passage: {p_txt}")

    scores = get_scores(model, tokenizer, input_texts)
    result = get_human_readable_top(scores, input_texts)
    return json.dumps(result, indent=4)

tokenizer, model = get_model('intfloat/e5-large-v2')

with gr.Blocks() as demo:
    gr.Markdown("# E5 Large V2 Demo")
    
    q_txt = gr.Textbox(placeholder="Enter your query", label="Query")

    p_txt1 = gr.Textbox(placeholder="Enter passage 1", label="Passage 1")
    p_txt2 = gr.Textbox(placeholder="Enter passage 2", label="Passage 2")
    p_txt3 = gr.Textbox(placeholder="Enter passage 3", label="Passage 3")
    p_txt4 = gr.Textbox(placeholder="Enter passage 4", label="Passage 4")
    p_txt5 = gr.Textbox(placeholder="Enter passage 5", label="Passage 5")
    submit = gr.Button("Submit")
    o_txt = gr.Textbox(placeholder="Output", lines=10, interactive=False, label="Output")

    gr.Examples(
        [
            [
                "I'm searching for a planet not too far from Earth.",
                "Neptune is the eighth and farthest-known Solar planet from the Sun. In the Solar System, it is the fourth-largest planet by diameter, the third-most-massive planet, and the densest giant planet. It is 17 times the mass of Earth, slightly more massive than its near-twin Uranus.",
                "TRAPPIST-1d, also designated as 2MASS J23062928-0502285 d, is a small exoplanet (about 30% the mass of the earth), which orbits on the inner edge of the habitable zone of the ultracool dwarf star TRAPPIST-1 approximately 40 light-years (12.1 parsecs, or nearly 3.7336×1014 km) away from Earth in the constellation of Aquarius.",
                "A harsh desert world orbiting twin suns in the galaxy’s Outer Rim, Tatooine is a lawless place ruled by Hutt gangsters. Many settlers scratch out a living on moisture farms, while spaceport cities such as Mos Eisley and Mos Espa serve as home base for smugglers, criminals, and other rogues.",
                "",
                ""
            ]
        ], 
        inputs=[q_txt, p_txt1, p_txt2, p_txt3, p_txt4, p_txt5]
    )
    
    submit.click(
        get_result,
        [q_txt, p_txt1, p_txt2, p_txt3, p_txt4, p_txt5],
        o_txt
    )

demo.launch()