|
import gradio as gr |
|
import torch |
|
import torch.nn.functional as F |
|
from sentence_transformers import SentenceTransformer |
|
|
|
|
|
# Hub revision (branch, tag or commit id) of the embedding model. Setting a
# concrete commit id would make the embeddings reproducible; None leaves the
# choice to the library's default.
revision = None

# Sentence-embedding model shared by the reference texts and incoming queries.
model = SentenceTransformer(
    "avsolatorio/GIST-small-Embedding-v0",
    revision=revision,
)
|
|
|
|
|
# Labelled reference utterances used for nearest-neighbour classification.
# Each entry has the form "<category>: <example phrasings>", and
# find_query_type returns the text before the first ": " of the closest entry.
# NOTE: every entry must end with a comma. Adjacent string literals without
# one are silently concatenated by Python into a single (wrong) entry.
ref_texts = [
    "Theatro App: Hello John. Hey John. Hi John. Call John",
    "Theatro App: Message John. Message for John. Leave a message for John",
    "Theatro App: Play messages. Listen to messages",
    "Theatro App: What time is it?",
    "Theatro App: Cashier Backup. Backup Cashier. Register backup. Register assistance.",
    "Theatro App: repeat",
    "Theatro App: Check inventory",
    "Theatro App: Check Sales",
    "Theatro App: Curbside Pickup",
    "Theatro App: Replay last message.",
    "Theatro App: Post it. Post it for group",
    "Theatro App: Announcement. Announcement for the group",
    "Open question: This is about products sold in TractorSupply.",
    "Open question: This is about pet care.",
    "Open question: What is the weather like?",
    "Open question: What's 15% off from $79.99?",
    "Open question: Can you look up the skew for 1091784?",
]
|
|
|
# Embed every reference text once at startup so each query needs only one
# encode call. find_query_type maps row i of this tensor back to ref_texts[i].
ref_embeddings = model.encode(
    ref_texts,
    convert_to_tensor=True,
)
|
|
|
def find_query_type(query):
    """Return the category label of the reference text closest to *query*.

    The query is embedded with the module-level ``model`` and compared by
    cosine similarity against every row of ``ref_embeddings``. The label is
    the part of the best-matching ``ref_texts`` entry before its first ": "
    (for the current list, "Theatro App" or "Open question").
    """
    # Encode as a one-element batch so it is compared row-wise against
    # ref_embeddings below.
    query_vec = model.encode([query], convert_to_tensor=True)
    similarities = F.cosine_similarity(query_vec, ref_embeddings, dim=-1)

    best_row = int(torch.argmax(similarities))
    label, _sep, _rest = ref_texts[best_row].partition(": ")
    return label
|
|
|
import gradio as gr |
|
def predict(query):
    """Gradio callback: return the predicted category label for *query*."""
    return find_query_type(query)
|
|
|
# Single text-in / text-out UI wrapping the classifier.
iface = gr.Interface(
    fn=predict,
    inputs=gr.Textbox(lines=2, placeholder="Enter your query here..."),
    outputs="text",
    title="Query Type Classifier",
    description="This model classifies the type of your query. Just input your query and get the predicted category.",
)

# Start the web server. This is a module-level call, so it also runs if the
# module is imported rather than executed as a script.
iface.launch()
|
|