"""Gradio app for semantic search and recommendations over dataset cards.

Dataset-card sections are stored in a Qdrant collection. A free-text query is
embedded with a sentence-transformers model and matched against the
collection; alternatively, a Hub dataset id is resolved to its Qdrant point
and used as the positive example for a recommendation request.
"""
import os
from functools import lru_cache
from typing import Optional

import gradio as gr
from dotenv import load_dotenv
from qdrant_client import QdrantClient, models
from sentence_transformers import SentenceTransformer

load_dotenv()

URL = os.getenv("QDRANT_URL")
QDRANT_API_KEY = os.getenv("QDRANT_API_KEY")
sentence_embedding_model = SentenceTransformer("BAAI/bge-large-en")
print(URL)
# Never echo the key itself: stdout usually ends up in logs. Only report
# whether a key was found so misconfiguration is still easy to spot.
print(f"QDRANT_API_KEY set: {bool(QDRANT_API_KEY)}")
collection_name = "dataset_cards"
client = QdrantClient(
    url=URL,
    api_key=QDRANT_API_KEY,
)


def format_results(results):
    """Render search/recommendation hits as a single Markdown string.

    Args:
        results: Iterable of hits. Each hit must expose a ``payload`` mapping
            with the keys ``"id"``, ``"downloads"`` and ``"section_text"``.

    Returns:
        The results title followed by one ``##`` section per hit containing a
        link to the dataset on the Hugging Face Hub, its download count and
        the matched card text.
    """
    # The line breaks around the title are written as escapes; a raw newline
    # inside a single-quoted literal is a syntax error.
    parts = ["\n\n✨ Dataset Search Results ✨" "\n\n\n\n"]
    for result in results:
        payload = result.payload
        hub_id = payload["id"]
        url = f"https://huggingface.co/datasets/{hub_id}"
        parts.append(
            f"## [{hub_id}]({url})\n"
            f"**Downloads:** {payload['downloads']}\n\n"
            f"{payload['section_text']} \n"
        )
    return "".join(parts)


@lru_cache(maxsize=100_000)
def search(query: str, limit: Optional[int] = 10):
    """Embed ``query`` and return the closest dataset cards as Markdown.

    Results are memoised per ``(query, limit)`` pair.

    Args:
        query: Free-text search query.
        limit: Maximum number of hits to request from Qdrant.

    Returns:
        Markdown produced by :func:`format_results`.
    """
    # The instruction prefix is kept byte-for-byte: altering it would change
    # the query embedding and therefore the ranking.
    query_ = sentence_embedding_model.encode(
        f"Represent this sentence for searching relevant passages:{query}"
    )
    results = client.search(
        collection_name=collection_name,
        query_vector=query_,
        limit=limit,
    )
    return format_results(results)


@lru_cache(maxsize=100_000)
def hub_id_qdrant_id(hub_id):
    """Return the Qdrant point id whose payload ``id`` equals ``hub_id``.

    Args:
        hub_id: Dataset id as stored in the payload ``"id"`` field.

    Returns:
        The id of the first matching point.

    Raises:
        gr.Error: If no point in the collection matches ``hub_id``.
    """
    matches = client.scroll(
        collection_name=collection_name,
        scroll_filter=models.Filter(
            must=[
                models.FieldCondition(key="id", match=models.MatchValue(value=hub_id)),
            ]
        ),
        limit=1,
        with_payload=True,
        with_vectors=False,
    )
    # ``scroll`` returns ``(points, next_offset)``; an empty ``points`` list
    # means the id is unknown, which surfaces as IndexError below.
    try:
        point = matches[0][0]
    except IndexError as e:
        raise gr.Error(
            f"Hub id {hub_id} not in our database. This could be because it is very new"
            " or because it doesn't have much documentation."
        ) from e
    return point.id


@lru_cache()
def recommend(hub_id, limit: Optional[int] = 10):
    """Return datasets similar to ``hub_id`` as Markdown.

    Args:
        hub_id: Dataset id used as the positive example.
        limit: Maximum number of recommendations to request.

    Returns:
        Markdown produced by :func:`format_results`.

    Raises:
        gr.Error: Propagated from :func:`hub_id_qdrant_id` when the id is
            not present in the collection.
    """
    positive_id = hub_id_qdrant_id(hub_id)
    results = client.recommend(
        collection_name=collection_name, positive=[positive_id], limit=limit
    )
    return format_results(results)


def query(search_term, search_type, limit: Optional[int] = 10):
    """Dispatch a UI request to :func:`recommend` or :func:`search`.

    Args:
        search_term: Hub dataset id (recommendation mode) or free-text query.
        search_type: ``"Recommend similar datasets"`` selects recommendation;
            any other value selects semantic search.
        limit: Maximum number of results to return.

    Returns:
        Markdown string to display in the results pane.
    """
    # Normalise the inputs so stray whitespace or a float-valued slider
    # reading does not create distinct cache keys or fail the id lookup.
    if isinstance(search_term, str):
        search_term = search_term.strip()
    if limit is not None:
        limit = int(limit)
    if search_type == "Recommend similar datasets":
        return recommend(search_term, limit)
    return search(search_term, limit)


# NOTE(review): the row/column nesting below was reconstructed from a source
# whose indentation had been lost - confirm it matches the intended layout.
with gr.Blocks() as demo:
    gr.Markdown("## 🤗 Semantic Dataset Search")
    with gr.Row():
        gr.Markdown(
            "This Gradio app allows you to search for datasets based on their"
            " descriptions. You can either search for similar datasets to a given"
            " dataset or search for datasets based on a query."
        )
    with gr.Row():
        search_term = gr.Textbox(
            value="movie review sentiment",
            # Line break written as an escape; a raw newline inside the
            # literal is a syntax error.
            label="hub id i.e. IMDB or query i.e. \nmovie review sentiment",
        )
    with gr.Row():
        with gr.Row():
            find_similar_btn = gr.Button("Search")
            search_type = gr.Radio(
                ["Recommend similar datasets", "Semantic Search"],
                label="Search type",
                value="Semantic Search",
                interactive=True,
            )
        with gr.Column():
            max_results = gr.Slider(
                minimum=1,
                maximum=50,
                step=1,
                value=10,
                label="Maximum number of results",
                # ``info`` is Gradio's keyword for helper text under the
                # label; ``help`` is not a Slider parameter.
                info="This is the maximum number of results that will be returned",
            )
    results = gr.Markdown()
    find_similar_btn.click(query, [search_term, search_type, max_results], results)

demo.launch()