|
import duckdb |
|
import gradio as gr |
|
import polars as pl |
|
from datasets import load_dataset |
|
from gradio_huggingfacehub_search import HuggingfaceHubSearch |
|
from model2vec import StaticModel |
|
|
|
# Shared state mutated by the Gradio callbacks: `ds` receives the object
# returned by `load_dataset`, and `df` the polars frame of the selected split
# with an extra `embeddings` column. Both are initialised explicitly because a
# module-level `global` statement is a no-op that leaves the names undefined.
ds = None
df = None

# Static (model2vec) embedding model. The same instance embeds the dataset
# column and the search queries, so both end up in one vector space.
model_name = "minishlab/M2V_multilingual_output"
model = StaticModel.from_pretrained(model_name)
|
|
|
|
|
def get_iframe(hub_repo_id):
    """Return HTML embedding the Hugging Face dataset viewer for a repo.

    Args:
        hub_repo_id: Dataset repository id on the Hub, e.g. ``"user/dataset"``.

    Returns:
        An HTML string with an ``<iframe>`` pointing at the repo's
        ``/embed/viewer`` page.

    Raises:
        ValueError: If ``hub_repo_id`` is empty or ``None``.
    """
    from urllib.parse import quote

    if not hub_repo_id:
        raise ValueError("Hub repo id is required")

    # The id comes from a user-facing search box. Percent-encode it (keeping
    # the owner/name slash) so a stray quote or angle bracket cannot close the
    # src="..." attribute and inject markup into the page. Valid repo ids
    # (letters, digits, "-", "_", ".", "/") pass through unchanged.
    safe_repo_id = quote(hub_repo_id, safe="/")
    url = f"https://huggingface.co/datasets/{safe_repo_id}/embed/viewer"

    iframe = f"""
    <iframe
        src="{url}"
        frameborder="0"
        width="100%"
        height="600px"
    ></iframe>
    """

    return iframe
|
|
|
|
|
def load_dataset_from_hub(hub_repo_id):
    """Download a dataset from the Hugging Face Hub into the module-level ``ds``.

    Args:
        hub_repo_id: Dataset repository id on the Hub, e.g. ``"user/dataset"``.

    Raises:
        ValueError: If ``hub_repo_id`` is empty or ``None``.
    """
    global ds

    # Fail fast with an explicit message when nothing has been selected yet
    # (e.g. the load button is clicked before any search), instead of passing
    # an empty path to `load_dataset`.
    if not hub_repo_id:
        raise ValueError("Hub repo id is required")

    ds = load_dataset(hub_repo_id)
|
|
|
|
|
def get_columns(split: str):
    """Build the column picker for one split of the currently loaded dataset.

    Args:
        split: Name of a split in the module-level ``ds``.

    Returns:
        A ``gr.Dropdown`` offering every column of that split, with the first
        column preselected.
    """
    global ds
    column_names = ds[split].column_names
    first_column = column_names[0]
    return gr.Dropdown(
        label="Select a column",
        choices=column_names,
        value=first_column,
    )
|
|
|
|
|
def get_splits():
    """Build the split picker for the currently loaded dataset.

    Returns:
        A ``gr.Dropdown`` listing every key of the module-level ``ds``, with
        the first one preselected.
    """
    global ds
    split_names = [*ds.keys()]
    default_split = split_names[0]
    return gr.Dropdown(
        label="Select a split",
        choices=split_names,
        value=default_split,
    )
|
|
|
|
|
def vectorize_dataset(split: str, column: str):
    """Embed one column of a split and store the result in the global ``df``.

    The split is converted to a polars DataFrame, and an ``embeddings`` column
    holding one ``model.encode`` vector per row is appended.

    Args:
        split: Name of a split in the module-level ``ds``.
        column: Column whose values are embedded.
    """
    global df

    global ds

    df = ds[split].to_polars()

    # `model.encode` takes a list of strings. Null cells or non-string values
    # (ints, lists, ...) would otherwise reach the tokenizer and fail, so
    # normalise every cell to text: nulls become "", everything else str().
    texts = [
        "" if value is None else str(value) for value in df[column].to_list()
    ]
    embeddings = model.encode(texts)

    df = df.with_columns(pl.Series(embeddings).alias("embeddings"))
|
|
|
|
|
def run_query(query: str, limit: int = 5):
    """Return the rows of the global ``df`` nearest to a text query.

    Args:
        query: Free-text search query, embedded with the same ``model`` that
            embedded the dataset column.
        limit: Maximum number of rows to return (default 5).

    Returns:
        The result of DuckDB's ``.to_df()``: at most ``limit`` rows of ``df``,
        ordered by ascending ``array_distance`` between their ``embeddings``
        and the query vector.
    """
    global df

    vector = model.encode(query)

    # Take the fixed array size from the query embedding instead of
    # hard-coding 256. Swapping `model_name` for a model with a different
    # output dimension then keeps working without touching the SQL.
    dim = len(vector)

    # Only numbers are interpolated (the vector values, its length and an
    # int-coerced limit); the user's text never reaches the SQL string.
    # `df` is resolved by DuckDB from the Python scope by name.
    return duckdb.sql(
        query=f"""
        SELECT *
        FROM df
        ORDER BY array_distance(embeddings, {vector.tolist()}::FLOAT[{dim}])
        LIMIT {int(limit)}
        """
    ).to_df()
|
|
|
|
|
with gr.Blocks() as demo:
    gr.HTML(
        """
        <h1>Vector Search any Hugging Face Dataset</h1>
        <p>
        This app allows you to vector search any Hugging Face dataset.
        Pick a dataset, a split and a column; the column is embedded with a
        static embedding model, and you can then retrieve the rows closest
        to a free-text query.
        </p>
        """
    )

    with gr.Row():
        with gr.Column():
            search_in = HuggingfaceHubSearch(
                label="Search Huggingface Hub",
                placeholder="Search for datasets on Huggingface",
                search_type="dataset",
                # NOTE(review): `sumbit_on_select` appears to be the spelling
                # used by gradio_huggingfacehub_search itself; confirm against
                # the installed version before "correcting" it.
                sumbit_on_select=True,
            )

    with gr.Row():
        search_out = gr.HTML(label="Search Results")

    search_in.submit(get_iframe, inputs=search_in, outputs=search_out)

    btn_load_dataset = gr.Button("Load Dataset")

    with gr.Row(variant="panel"):
        split_dropdown = gr.Dropdown(label="Select a split")
        column_dropdown = gr.Dropdown(label="Select a column")

    with gr.Row(variant="panel"):
        query_input = gr.Textbox(label="Query")

    # Loading always ends by re-embedding: load -> splits -> columns ->
    # vectors. The vectorize step must live here because a dropdown's
    # `.change` does not fire when a new dataset reuses the same first split
    # name (e.g. "train"), which used to leave stale embeddings in `df`.
    # `.success` stops the chain if a step raised (e.g. a bad repo id)
    # instead of building dropdowns from a stale `ds`.
    btn_load_dataset.click(
        load_dataset_from_hub, inputs=search_in, show_progress=True
    ).success(fn=get_splits, outputs=split_dropdown).success(
        fn=get_columns, inputs=split_dropdown, outputs=column_dropdown
    ).success(
        fn=vectorize_dataset, inputs=[split_dropdown, column_dropdown]
    )

    # `.input` fires only on user interaction, so these do not duplicate the
    # embedding pass that the load chain above already performs.
    split_dropdown.input(
        fn=get_columns, inputs=split_dropdown, outputs=column_dropdown
    ).success(
        fn=vectorize_dataset, inputs=[split_dropdown, column_dropdown]
    )
    # Picking a different column must re-embed as well; previously only a
    # split change did.
    column_dropdown.input(
        fn=vectorize_dataset, inputs=[split_dropdown, column_dropdown]
    )

    btn_run = gr.Button("Run")

    results_output = gr.Dataframe(label="Results")

    btn_run.click(fn=run_query, inputs=query_input, outputs=results_output)
    # Pressing Enter in the query box runs the same search as the button.
    query_input.submit(
        fn=run_query, inputs=query_input, outputs=results_output
    )

demo.launch()
|
|