import gradio as gr
import torch
import pickle
import gzip

from torch.nn.functional import cosine_similarity
from model import create_semantic_ranking_model
from timeit import default_timer as timer
from typing import List, Optional, Tuple

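# Question/answer pairs from the dataset; used to look up the ground-truth
# answer when the input matches a known question.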
questions_texts = []
with open("questions_texts.txt", "r") as file:
    questions_texts = [line.strip() for line in file]

answers_texts = []
with open("answers_texts.txt", "r") as file:
    answers_texts = [line.strip() for line in file]

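# Build the model and tokenizer, then load the fine-tuned weights on CPU so
# the demo also runs on machines without a GPU.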
model, tokenizer = create_semantic_ranking_model()

model.load_state_dict(
    torch.load(f="all-MiniLM-L6-v2.pth",
               map_location=torch.device("cpu"))
)

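# Pre-computed embeddings for every candidate response, along with the
# response strings they index into.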
with gzip.open("response_embeddings.pkl.gz", "rb") as f:
    response_embeddings = pickle.load(f)

with gzip.open("response_list.pkl.gz", "rb") as f:
    response_list = pickle.load(f)


def predict(text: str) -> Tuple[List[str], Optional[str], float]:
    """Rank all candidate responses against the input text and return the
    top 5, the ground-truth answer (if the input is a known question),
    and the prediction time in seconds."""
    start_time = timer()

    model.eval()

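    # Tokenize the query to the same fixed length used at training time.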
    tokenized_inputs = tokenizer(text, return_tensors="pt", max_length=128,
                                 truncation=True, padding="max_length")

    with torch.inference_mode():
        input_embeddings = model(**tokenized_inputs)

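    # Broadcast (1, 1, dim) against (1, N, dim) to score the query against
    # every candidate response, then keep the five best matches.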
    similarity_scores = cosine_similarity(input_embeddings.unsqueeze(1),
                                          response_embeddings.unsqueeze(0), dim=2)
    top_responses_indices = torch.topk(similarity_scores, k=5, dim=1).indices.squeeze()

    top_responses = [response_list[idx] for idx in top_responses_indices]

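    # If the input exactly matches a known question, surface its ground-truth answer.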
    actual_response = None
    for question, answer in zip(questions_texts, answers_texts):
        if text.strip() == question.strip():
            actual_response = answer
            break

    end_time = timer()
    pred_time = round(end_time - start_time, 4)

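    # Return one value per output component: JSON, Textbox, Number.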
    return top_responses, actual_response, pred_time


title = "Semantic Ranking with MiniLM-L6-v2"
description = "An [all-MiniLM-L6-v2](https://huggingface.co/sentence-transformers/all-MiniLM-L6-v2) sentence-embedding model (based on MiniLM-L6-H384-uncased), trained to rank responses from [Hugging Face 🤗 Hello-SimpleAI/HC3](https://huggingface.co/datasets/Hello-SimpleAI/HC3)."
article = "Built with [Gradio](https://github.com/gradio-app/gradio) and [PyTorch](https://pytorch.org/). [Source code found here](https://colab.research.google.com/drive/1o5a9zH1TxzaxLKV5AFUhZE8L8yMnO9Jw?usp=sharing)."

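# Example inputs for the demo. Seeding `example_texts` with a few known
# questions is an assumption; swap in any inputs you prefer.
example_texts = questions_texts[:5]
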
demo = gr.Interface(fn=predict,
                    inputs=gr.Textbox(lines=2, placeholder="Type your text here..."),
                    outputs=[gr.JSON(label="Top Responses"),
                             gr.Textbox(label="Actual Response", interactive=False),
                             gr.Number(label="Prediction time (s)")],
                    examples=example_texts,
                    title=title,
                    description=description,
                    article=article)

demo.launch()