# Hugging Face Spaces page metadata captured with this file:
# status: Runtime error; file size: 1,614 bytes; commit: 9e7f361.
import gradio as gr
import datasets
import faiss
import os
from transformers import pipeline
# Hub access token read from the CLARIN_KNEXT environment variable
# (None when the variable is not set); forwarded to the dataset/model loaders.
auth_token = os.getenv("CLARIN_KNEXT")

# Default demo query. The [unused0] / [unused1] markers appear to bracket the
# target word to be linked -- NOTE(review): confirm against the encoder's
# expected input format.
sample_text = (
    "Wydarzenia te miały miejsce na początku mojej przyjaźni z Holmesem, "
    "kiedy, jeszcze jako [unused0] kawalerowie [unused1], "
    "mieszkaliśmy razem przy Baker Street."
)
# Multi-line input box, pre-filled with the demo query.
textbox = gr.Textbox(
    lines=10,
    value=sample_text,
    label="Type your query here.",
)
def load_index(
    index_data: str = "clarin-knext/wsd-linking-index",
    faiss_path: str = "./encoder.faissindex",
):
    """Load the entry lookup table and the FAISS index used for linking.

    Args:
        index_data: Hub dataset id whose ``train`` split provides the
            ``entities`` and ``texts`` columns.
        faiss_path: Path of the serialized FAISS index on local disk.

    Returns:
        A ``(lookup, faiss_index)`` tuple. ``lookup`` maps a row position to
        its ``(entity, text)`` pair; ``faiss_index`` is the FAISS index read
        with ``IO_FLAG_MMAP``.
    """
    # NOTE(review): `use_auth_token` is deprecated in recent `datasets`
    # releases in favour of `token`; confirm the pinned version accepts it.
    ds = datasets.load_dataset(index_data, use_auth_token=auth_token)["train"]
    # Row position is used as the FAISS id -- presumably the index was built
    # in dataset order; TODO confirm.
    lookup = dict(enumerate(zip(ds["entities"], ds["texts"])))
    # IO_FLAG_MMAP memory-maps the file instead of reading it fully into RAM.
    faiss_index = faiss.read_index(faiss_path, faiss.IO_FLAG_MMAP)
    return lookup, faiss_index
def load_model(model_name: str = "clarin-knext/wsd-encoder"):
    """Build a Hugging Face ``feature-extraction`` pipeline for ``model_name``.

    The module-level ``auth_token`` is forwarded so gated Hub repos can be
    loaded. Returns the pipeline object.
    """
    return pipeline(
        "feature-extraction",
        model=model_name,
        use_auth_token=auth_token,
    )
# Load the encoder, then the index, once at import time; `predict` reads
# these module-level globals on every request.
model, index = load_model(), load_index()
def predict(text: str = sample_text, top_k: int = 3):
    """Return the ``top_k`` index entries nearest to ``text``, one per line.

    Args:
        text: Query text to embed with the encoder pipeline.
        top_k: Number of nearest neighbours to request from FAISS.

    Returns:
        Newline-separated ``"<entry>: <score>"`` lines, where ``<entry>`` is
        the ``(entity, text)`` tuple from the lookup table. Neighbour slots
        that FAISS could not fill are omitted, so the result may have fewer
        than ``top_k`` lines (or be empty).
    """
    index_data, faiss_index = index
    # Takes only the [CLS] embedding (for now): [0][0] picks the first token
    # of the first sequence -- assumes a (batch, seq, hidden) output; confirm.
    query = model(text, return_tensors='pt')[0][0].numpy().reshape(1, -1)
    # faiss requires an integer k; UI widgets may deliver numbers as floats.
    scores, indices = faiss_index.search(query, int(top_k))
    return "\n".join(
        f"{index_data[idx]}: {score}"
        for row_ids, row_scores in zip(indices.tolist(), scores.tolist())
        for idx, score in zip(row_ids, row_scores)
        # FAISS pads with id -1 when fewer than k vectors are available;
        # such ids have no entry in the lookup table.
        if idx != -1
    )
# Build the UI, then start the server. `launch()` returns server info rather
# than the Interface, so keep `demo` bound to the Interface object itself.
demo = gr.Interface(fn=predict, inputs=textbox, outputs="text")
demo.launch()