|
|
|
import gradio as gr
import torch
from sentence_transformers import util

from semantic import load_corpus_and_model
|
|
|
|
|
# Prefix prepended to every user question before encoding. E5-style
# retrieval models expect "query: " so queries and passages embed into
# matching subspaces — TODO confirm this matches the model returned by
# load_corpus_and_model().
query_prefix = "query: "

# Precomputed answer embeddings (tensor saved by a separate encoding
# step). NOTE(review): assumes row order matches the iteration order of
# test_doc's keys — verify against the script that wrote this file.
answers_emb = torch.load('encoded_answers.pt')

# Evaluation queries, the answer corpus (dict), and the encoder model.
test_queries, test_doc, model = load_corpus_and_model()
|
|
|
def query(q):
    """Return the corpus answer most semantically similar to *q*.

    Encodes the question (with the model's query prefix), scores it
    against the precomputed answer embeddings by cosine similarity, and
    returns the text of the highest-scoring answer.

    Parameters
    ----------
    q : str
        Raw user question from the Gradio text box.

    Returns
    -------
    str
        Best-matching answer value from ``test_doc``.
    """
    query_emb = model.encode(
        [query_prefix + q], convert_to_tensor=True, show_progress_bar=False
    )
    # argmax over cosine similarities picks the best-matching row.
    # NOTE(review): relies on answers_emb rows lining up with
    # test_doc's key order — confirm against the encoding script.
    best_answer_index = util.cos_sim(query_emb, answers_emb).argmax().item()
    best_answer_key = list(test_doc.keys())[best_answer_index]
    return test_doc[best_answer_key]
|
|
|
# Minimal web UI: a single text input and text output wired to query().
iface = gr.Interface(fn=query, inputs="text", outputs="text")

# Starts a local Gradio server; blocks until the server is stopped.
iface.launch()