from dotenv import load_dotenv
import os
import pandas as pd
from httpx import Client
from huggingface_hub.utils import logging
from functools import lru_cache
from tqdm.contrib.concurrent import thread_map
from huggingface_hub import HfApi
import gradio as gr
from sentence_transformers import SentenceTransformer
import faiss
import numpy as np
from urllib.parse import quote

load_dotenv()

HF_TOKEN = os.getenv("HF_TOKEN")
# Raise explicitly instead of using `assert` so the check still runs under `python -O`.
if HF_TOKEN is None:
    raise RuntimeError("You need to set HF_TOKEN in your environment variables")

BASE_DATASETS_SERVER_URL = "https://datasets-server.huggingface.co"

logger = logging.get_logger(__name__)

headers = {
    # Plain interpolation: a "$" before the braces would be sent literally as
    # part of the token and break authentication.
    "authorization": f"Bearer {HF_TOKEN}",
}
client = Client(headers=headers)
api = HfApi(token=HF_TOKEN)


def _datasets_server_get(endpoint: str, **params: str) -> dict:
    """Call a datasets-server endpoint and return its decoded JSON body.

    Query parameters are passed through httpx's ``params`` argument, so
    dataset ids and config names are URL-encoded instead of being pasted
    into the URL verbatim.

    Raises:
        httpx.HTTPStatusError: if the server answers with a non-2xx status.
    """
    resp = client.get(f"{BASE_DATASETS_SERVER_URL}/{endpoint}", params=params)
    resp.raise_for_status()
    return resp.json()


def get_first_config_and_split_name(hub_id: str):
    """Return ``(config, split)`` for the first split datasets-server lists.

    Returns None, after logging the error, if the request fails or the
    response contains no splits.
    """
    try:
        first_split = _datasets_server_get("splits", dataset=hub_id)["splits"][0]
        return first_split["config"], first_split["split"]
    except Exception as e:
        logger.error(f"Failed to get splits for {hub_id}: {e}")
        return None


def get_first_config_name(dataset: str):
    """Return the config name of the first split of *dataset*, or None on failure."""
    first = get_first_config_and_split_name(dataset)
    # first[0] is the complete config-name string, not a sequence of names.
    return None if first is None else first[0]


def datasets_server_valid_rows(dataset: str):
    """Return the ``viewer`` flag from datasets-server's ``/is-valid`` endpoint.

    Returns None, after logging the error, if the request or the JSON lookup
    fails.
    """
    try:
        return _datasets_server_get("is-valid", dataset=dataset)["viewer"]
    except Exception as e:
        logger.error(f"Failed to get is-valid for {dataset}: {e}")
        return None


def dataset_is_valid(dataset):
    """Return *dataset* if datasets-server reports a viewer for it, else None.

    *dataset* is an item yielded by ``HfApi.list_datasets``; only ``.id`` is read.
    """
    return dataset if datasets_server_valid_rows(dataset.id) else None


def get_dataset_info(hub_id: str, config: str | None = None):
    """Fetch the datasets-server ``/info`` JSON for *hub_id* and *config*.

    When *config* is None, the first config listed by ``/splits`` is used.
    If that lookup fails, None is returned.

    Raises:
        httpx.HTTPStatusError: if the ``/info`` request returns a non-2xx status.
    """
    if config is None:
        config = get_first_config_name(hub_id)
        if config is None:
            return None
    return _datasets_server_get("info", dataset=hub_id, config=config)


def dataset_with_info(dataset):
    """Build the search record for *dataset*, or return None if unavailable.

    The record's ``text`` field ("<id>-<col1>,<col2>,...") is the string that
    gets embedded. Errors are logged and mapped to None, so a single failing
    dataset does not abort the ``thread_map`` in ``prep_data``.
    """
    try:
        info = get_dataset_info(dataset.id)
        if not info:
            return None
        columns = info.get("dataset_info", {}).get("features", {})
        if columns is None:
            return None
        # Iterating the features mapping yields its keys, i.e. the column names.
        column_names = ",".join(columns)
        return {
            "dataset": dataset.id,
            "column_names": column_names,
            "text": f"{dataset.id}-{column_names}",
            "likes": dataset.likes,
            "downloads": dataset.downloads,
            "created_at": dataset.created_at,
            "tags": dataset.tags,
        }
    except Exception as e:
        logger.error(f"Failed to get info for {dataset.id}: {e}")
        return None


@lru_cache(maxsize=100)
def prep_data():
    """List every Hub dataset and return search records for those with viewer info.

    Steps:
        1. Keep datasets that datasets-server can serve.
        2. Fetch their column info.
        3. Drop any dataset for which either step failed.
    """
    datasets = list(api.list_datasets(limit=None, sort="createdAt", direction=-1))
    print(f"Found {len(datasets)} datasets in the hub.")
    logger.info(f"Found {len(datasets)} datasets.")

    has_server = thread_map(dataset_is_valid, datasets)
    datasets_with_server = [x for x in has_server if x is not None]
    print(f"Found {len(datasets_with_server)} datasets with server.")

    dataset_infos = thread_map(dataset_with_info, datasets_with_server)
    dataset_infos = [x for x in dataset_infos if x is not None]
    print(f"Found {len(dataset_infos)} datasets with server data.")
    if dataset_infos:  # avoid IndexError when nothing survived the filters
        print(dataset_infos[0])
    return dataset_infos


all_datasets = prep_data()
if not all_datasets:
    raise RuntimeError(
        "No datasets with datasets-server info were found; nothing to index."
    )
all_datasets_df = pd.DataFrame.from_dict(all_datasets)
print(all_datasets_df.head())

text = all_datasets_df["text"]
encoder = SentenceTransformer("paraphrase-mpnet-base-v2")
# faiss expects a C-contiguous float32 matrix.
vectors = np.ascontiguousarray(encoder.encode(text.tolist()), dtype=np.float32)
vector_dimension = vectors.shape[1]

print("Start indexing")
index = faiss.IndexFlatL2(vector_dimension)
# L2 distance on unit-normalised vectors ranks neighbours the same way
# cosine similarity does.
faiss.normalize_L2(vectors)
index.add(vectors)
print("Indexing done")


def render_model_hub_link(hub_id):
    """Render *hub_id* as a link to its Hub page for the markdown DataFrame."""
    link = f"https://huggingface.co/datasets/{quote(hub_id)}"
    return f'<a target="_blank" href="{link}">{hub_id}</a>'


def search(dataset_name):
    """Return the 20 nearest indexed datasets to *dataset_name* as a DataFrame.

    Neighbours come from the faiss index over the ``text`` embeddings. The
    ``dataset`` column is rendered as a Hub link. An unknown or unsupported
    name yields a one-row DataFrame with an ``error`` column.
    """
    print(f"start search for {dataset_name}")
    matches = all_datasets_df[all_datasets_df.dataset == dataset_name]
    if matches.empty:
        return pd.DataFrame(
            [{"error": "❌ Dataset does not exist or is not supported"}]
        )
    dataset_row = matches.iloc[0]
    print(dataset_row)

    query = np.ascontiguousarray(
        encoder.encode([dataset_row["text"]]), dtype=np.float32
    )
    faiss.normalize_L2(query)
    distances, ann = index.search(query, k=20)
    results = pd.DataFrame({"distances": distances[0], "ann": ann[0]})
    print("results for distances and ann")
    print(results)

    # faiss labels are row positions. all_datasets_df was built from a list of
    # records and so has a default RangeIndex; the labels therefore join
    # directly against its index. Labels with no matching row drop out of the
    # inner join.
    merge = pd.merge(results, all_datasets_df, left_on="ann", right_index=True)
    print("resultst for merged df (distances,ann, dataset info)")
    merge["dataset"] = merge["dataset"].apply(render_model_hub_link)
    return merge


with gr.Blocks() as demo:
    gr.Markdown("# Search similar Datasets on Hugging Face")
    gr.Markdown("This space shows similar dataset based on column name and types")
    dataset_name = gr.Textbox("asoria/bolivian-population", label="Dataset Name")
    btn = gr.Button("Show similar datasets")
    # datatype="markdown" lets the rendered Hub links in the `dataset` column display as links.
    df = gr.DataFrame(datatype="markdown")
    btn.click(search, dataset_name, df)

demo.launch()