demo_prj3 / app.py
import gradio as gr
import torch
from sentence_transformers import util

from semantic import load_corpus_and_model

# Queries are prefixed before encoding (the "query: " convention used by
# E5-style retrieval models).
query_prefix = "query: "

# Load the pre-encoded answer embeddings from disk.
answers_emb = torch.load('encoded_answers.pt')

# Load the test queries, the answer corpus, and the encoder model.
test_queries, test_doc, model = load_corpus_and_model()
def query(q):
    # Encode the user query with the same prefix used at indexing time.
    query_emb = model.encode([query_prefix + q], convert_to_tensor=True,
                             show_progress_bar=False)
    # Return the answer whose embedding has the highest cosine similarity
    # to the query embedding.
    best_answer_index = util.cos_sim(query_emb, answers_emb).argmax().item()
    best_answer_key = list(test_doc.keys())[best_answer_index]
    return test_doc[best_answer_key]
# A minimal Gradio text-in / text-out interface around the retrieval function.
iface = gr.Interface(fn=query, inputs="text", outputs="text")
iface.launch()
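For context, `encoded_answers.pt` must be produced offline before this app can run. A minimal sketch of such an indexing script is below; it assumes `load_corpus_and_model` returns the answer corpus as a dict and a SentenceTransformer model, and the hypothetical `passage: ` prefix mirrors the `query: ` prefix above (an assumption about the encoder's convention, not something the source confirms):

```python
# encode_answers.py -- hypothetical offline indexing script (not part of the repo)
import torch
from semantic import load_corpus_and_model

# Assumed counterpart to "query: "; whether a prefix is needed depends on the model.
passage_prefix = "passage: "

_, test_doc, model = load_corpus_and_model()

# Encode every answer in corpus insertion order; query() relies on this
# order matching list(test_doc.keys()).
answers = [passage_prefix + text for text in test_doc.values()]
answers_emb = model.encode(answers, convert_to_tensor=True, show_progress_bar=True)

torch.save(answers_emb, 'encoded_answers.pt')
```

Saving the embeddings once and reloading them with `torch.load` at startup keeps the app's cold-start fast, since only the query (not the whole corpus) is encoded per request.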