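# Scraped from the Hugging Face Hub file viewer:
#   author: davanstrien (HF staff)
#   commit: 84bfe38 — "add results number slider"
#   viewer controls: raw · history · blame
#   scan status: No virus
#   file size: 3.86 kB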
import os
from functools import lru_cache
from typing import Optional
import gradio as gr
from dotenv import load_dotenv
from qdrant_client import QdrantClient, models
from sentence_transformers import SentenceTransformer
# Read connection settings from the environment (optionally from a .env file).
load_dotenv()
URL = os.getenv("QDRANT_URL")
QDRANT_API_KEY = os.getenv("QDRANT_API_KEY")

# Embedding model used to encode free-text queries for semantic search.
sentence_embedding_model = SentenceTransformer("BAAI/bge-large-en")

print(URL)
# Do not echo the API key itself: stdout ends up in the app's logs, which may be
# readable by others. Report only whether it was provided.
print("QDRANT_API_KEY set:", QDRANT_API_KEY is not None)

# Qdrant collection holding the embedded dataset-card sections.
collection_name = "dataset_cards"
client = QdrantClient(
    url=URL,
    api_key=QDRANT_API_KEY,
)
def format_results(results):
    """Render Qdrant hits as a Markdown results page.

    Each hit must expose a ``payload`` mapping with the keys ``"id"`` (the
    Hub dataset id), ``"downloads"`` and ``"section_text"``.

    Args:
        results: Iterable of points returned by ``client.search`` or
            ``client.recommend``.

    Returns:
        A Markdown string: a centred title followed by one section per hit,
        or a short notice when there are no hits.
    """
    title = (
        "<h1 style='text-align: center;'> &#x2728; Dataset Search Results &#x2728;"
        " </h1> \n\n"
    )
    sections = []
    for result in results:
        payload = result.payload
        hub_id = payload["id"]
        url = f"https://huggingface.co/datasets/{hub_id}"
        sections.append(
            f"## [{hub_id}]({url})\n"
            f"**Downloads:** {payload['downloads']}\n\n"
            f"{payload['section_text']} \n"
        )
    if not sections:
        # Without this the user would see only the title and no explanation.
        sections.append("No matching datasets were found.\n")
    # Collect the pieces and join once rather than growing a string in a loop.
    return title + "".join(sections)
@lru_cache(maxsize=100_000)
def search(query: str, limit: Optional[int] = 10):
    """Semantic search over the dataset-card collection.

    Args:
        query: Free-text description of the kind of dataset wanted.
        limit: Maximum number of hits. ``None`` falls back to 10. The value is
            coerced to ``int`` because the UI slider may deliver a float.

    Returns:
        The hits rendered as Markdown by ``format_results``.
    """
    limit = 10 if limit is None else int(limit)
    # BGE retrieval models expect this instruction prefix on queries (not on
    # documents); the model card spells it with a space before the query.
    query_vector = sentence_embedding_model.encode(
        f"Represent this sentence for searching relevant passages: {query}"
    )
    results = client.search(
        collection_name=collection_name,
        query_vector=query_vector,
        limit=limit,
    )
    return format_results(results)
@lru_cache(maxsize=100_000)
def hub_id_qdrant_id(hub_id):
    """Look up the Qdrant point id stored for a Hub dataset id.

    Args:
        hub_id: Hub dataset id, matched exactly against the ``id`` payload
            field.

    Returns:
        The id of the first matching point in the collection.

    Raises:
        gr.Error: If no point carries this hub id. ``lru_cache`` does not
            cache exceptions, so a later call queries Qdrant again.
    """
    # scroll returns a (points, next_page_offset) pair; only the points matter.
    points, _next_offset = client.scroll(
        collection_name=collection_name,
        scroll_filter=models.Filter(
            must=[
                models.FieldCondition(key="id", match=models.MatchValue(value=hub_id)),
            ]
        ),
        limit=1,
        # Only the point id is used, so skip transferring payload and vectors.
        with_payload=False,
        with_vectors=False,
    )
    if not points:
        raise gr.Error(
            f"Hub id {hub_id} not in our database. This could be because it is very new"
            " or because it doesn't have much documentation."
        )
    return points[0].id
@lru_cache(maxsize=100_000)
def recommend(hub_id, limit: Optional[int] = 10):
    """Recommend datasets similar to an existing Hub dataset.

    Uses the dataset's stored point as the single positive example for Qdrant's
    recommendation API. The cache size matches the other cached lookups.

    Args:
        hub_id: Hub dataset id to find neighbours for.
        limit: Maximum number of hits. ``None`` falls back to 10. The value is
            coerced to ``int`` because the UI slider may deliver a float.

    Returns:
        The hits rendered as Markdown by ``format_results``.

    Raises:
        gr.Error: Propagated from ``hub_id_qdrant_id`` when the id is unknown.
    """
    limit = 10 if limit is None else int(limit)
    positive_id = hub_id_qdrant_id(hub_id)
    results = client.recommend(
        collection_name=collection_name,
        positive=[positive_id],
        limit=limit,
    )
    return format_results(results)
def query(search_term, search_type, limit: Optional[int] = 10):
    """Dispatch a UI request to recommendation or semantic search.

    Args:
        search_term: A Hub dataset id (recommend mode) or a free-text query
            (semantic mode). Surrounding whitespace is stripped so that exact
            hub-id lookup is not defeated by a stray space.
        search_type: Radio selection; ``"Recommend similar datasets"`` selects
            recommendation, anything else selects semantic search.
        limit: Maximum number of results to return.

    Returns:
        Markdown produced by ``recommend`` or ``search``.

    Raises:
        gr.Error: If the search term is empty or only whitespace.
    """
    search_term = (search_term or "").strip()
    if not search_term:
        raise gr.Error("Please enter a Hub dataset id or a search query.")
    if search_type == "Recommend similar datasets":
        return recommend(search_term, limit)
    return search(search_term, limit)
with gr.Blocks() as demo:
    gr.Markdown("## &#129303; Semantic Dataset Search")
    with gr.Row():
        gr.Markdown(
            "This Gradio app allows you to search for datasets based on their"
            " descriptions. You can either search for similar datasets to a given"
            " dataset or search for datasets based on a query."
        )
    with gr.Row():
        search_term = gr.Textbox(
            value="movie review sentiment",
            label="hub id i.e. IMDB or query i.e. movie review sentiment",
        )
    with gr.Row():
        with gr.Row():
            find_similar_btn = gr.Button("Search")
            search_type = gr.Radio(
                ["Recommend similar datasets", "Semantic Search"],
                label="Search type",
                value="Semantic Search",
                interactive=True,
            )
        with gr.Column():
            max_results = gr.Slider(
                minimum=1,
                maximum=50,
                step=1,
                value=10,
                label="Maximum number of results",
                # Gradio's helper-text parameter is `info`; `help` is not a
                # Slider argument and would not be shown.
                info="This is the maximum number of results that will be returned",
            )
    results = gr.Markdown()
    find_similar_btn.click(query, [search_term, search_type, max_results], results)

demo.launch()