import html

import duckdb
import gradio as gr
import polars as pl
from datasets import load_dataset
from gradio_huggingfacehub_search import HuggingfaceHubSearch
from model2vec import StaticModel

# Process-wide state: `ds` is the DatasetDict returned by load_dataset() and
# `df` is the selected split as a polars frame plus an "embeddings" column.
# Both start as None so callers can tell "nothing loaded yet" apart from a
# loaded dataset (a module-level `global` statement does not define a name).
# NOTE(review): these globals are shared by every browser session of the
# app, so concurrent users overwrite each other's selection.
ds = None
df = None

# Load a model from the HuggingFace hub (in this case the potion-base-8M model)
model_name = "minishlab/potion-base-8M"
model = StaticModel.from_pretrained(model_name)


def get_iframe(hub_repo_id):
    """Return HTML that embeds the Hub dataset viewer for *hub_repo_id*.

    Raises:
        ValueError: if *hub_repo_id* is empty.
    """
    if not hub_repo_id:
        raise ValueError("Hub repo id is required")
    url = f"https://huggingface.co/datasets/{hub_repo_id}/embed/viewer"
    # The repo id comes from the search box, so escape it before placing it
    # inside an HTML attribute.
    iframe = f"""
    <iframe
      src="{html.escape(url, quote=True)}"
      frameborder="0"
      width="100%"
      height="560px"
    ></iframe>
    """
    return iframe


def load_dataset_from_hub(hub_repo_id):
    """Download all splits of *hub_repo_id* into the global ``ds``."""
    global ds, df
    ds = load_dataset(hub_repo_id)
    # Embeddings computed for a previously loaded dataset no longer apply.
    df = None


def get_columns(split: str):
    """Return a dropdown listing the columns of *split*, first one selected."""
    column_names = ds[split].column_names
    return gr.Dropdown(
        choices=column_names,
        value=column_names[0] if column_names else None,
        label="Select a column",
    )


def get_splits():
    """Return a dropdown listing the splits of ``ds``, first one selected."""
    splits = list(ds.keys())
    return gr.Dropdown(
        choices=splits,
        value=splits[0] if splits else None,
        label="Select a split",
    )


def vectorize_dataset(split: str, column: str):
    """Embed *column* of *split* and publish the result as the global ``df``.

    The split is converted to a polars DataFrame and an ``embeddings`` column
    with one model2vec vector per row is appended.
    """
    global df
    if not split or not column:
        return
    table = ds[split].to_polars()
    # model.encode expects a list of strings; nulls and non-string cells are
    # coerced so the tokenizer does not reject them.
    texts = ["" if value is None else str(value) for value in table[column].to_list()]
    embeddings = model.encode(texts, max_length=512 * 4)
    # Assign only once the frame is complete, so a failure above never leaves
    # run_query with a frame that lacks the embeddings column.
    df = table.with_columns(pl.Series(embeddings).alias("embeddings"))


def run_query(query: str):
    """Return the 5 rows of ``df`` whose embeddings are closest to *query*.

    Rows are ranked by DuckDB's ``array_cosine_distance`` between the query
    embedding and each row's ``embeddings`` vector.

    Raises:
        gr.Error: if no dataset has been embedded yet or *query* is blank.
    """
    if df is None:
        raise gr.Error("Select a dataset, split and column before searching.")
    if not query or not query.strip():
        raise gr.Error("Please enter a query.")
    vector = model.encode(query)
    # Take the array size from the model instead of hard-coding 256. Only the
    # model's numeric output is interpolated into the SQL, never the query text.
    dim = vector.shape[-1]
    # `FROM df` is resolved by DuckDB's replacement scan of Python variables.
    return duckdb.sql(
        query=f"""
        SELECT *
        FROM df
        ORDER BY array_cosine_distance(embeddings, {vector.tolist()}::FLOAT[{dim}])
        LIMIT 5
        """
    ).to_df()


with gr.Blocks() as demo:
    gr.HTML(
        """
This app allows you to vector search any Hugging Face dataset. You can search for the nearest neighbors of a query vector, or perform a similarity search on a dataframe.
"""
    )
    with gr.Row():
        with gr.Column():
            search_in = HuggingfaceHubSearch(
                label="Search Huggingface Hub",
                placeholder="Search for datasets on Huggingface",
                search_type="dataset",
                # NOTE(review): "sumbit" appears to be the library's own
                # spelling of this keyword — verify before renaming it.
                sumbit_on_select=True,
            )
    with gr.Row():
        search_out = gr.HTML(label="Search Results")
    with gr.Row(variant="panel"):
        split_dropdown = gr.Dropdown(label="Select a split")
        column_dropdown = gr.Dropdown(label="Select a column")
    with gr.Row(variant="panel"):
        query_input = gr.Textbox(label="Query")

    # Loading a dataset always finishes by embedding the default split/column.
    # This is chained explicitly instead of relying on split_dropdown.change,
    # which may not fire when the new split name equals the old one (e.g. two
    # datasets that both start with "train").
    search_in.submit(get_iframe, inputs=search_in, outputs=search_out).then(
        fn=load_dataset_from_hub,
        inputs=search_in,
        show_progress=True,
    ).then(fn=get_splits, outputs=split_dropdown).then(
        fn=get_columns, inputs=split_dropdown, outputs=column_dropdown
    ).then(fn=vectorize_dataset, inputs=[split_dropdown, column_dropdown])

    # `input` fires only on user interaction, so the programmatic dropdown
    # updates in the chain above do not trigger a second embedding pass.
    split_dropdown.input(
        fn=get_columns, inputs=split_dropdown, outputs=column_dropdown
    ).then(fn=vectorize_dataset, inputs=[split_dropdown, column_dropdown])
    # Picking another column must re-embed; it was previously ignored.
    column_dropdown.input(
        fn=vectorize_dataset, inputs=[split_dropdown, column_dropdown]
    )

    btn_run = gr.Button("Run")
    results_output = gr.Dataframe(label="Results")
    btn_run.click(fn=run_query, inputs=query_input, outputs=results_output)

demo.launch()