davidberenstein1957's picture
feat: simplify steps
a442ffd
raw
history blame
3.29 kB
from urllib.parse import quote

import duckdb
import gradio as gr
import polars as pl
from datasets import load_dataset
from gradio_huggingfacehub_search import HuggingfaceHubSearch
from model2vec import StaticModel
# Module-level state shared by the Gradio callbacks. Initialised to None so
# callbacks can detect "nothing loaded yet" instead of hitting a NameError.
ds = None  # result of load_dataset(...) set by load_dataset_from_hub; indexed by split name
df = None  # polars DataFrame with an "embeddings" column, set by vectorize_dataset

# Load the static embedding model from the Hugging Face Hub
# (minishlab/M2V_multilingual_output).
model_name = "minishlab/M2V_multilingual_output"
model = StaticModel.from_pretrained(model_name)
def get_iframe(hub_repo_id):
    """Return HTML embedding the Hub dataset viewer for ``hub_repo_id``.

    Args:
        hub_repo_id: Dataset repository id, e.g. ``"org/name"``.

    Returns:
        An ``<iframe>`` HTML snippet pointing at the dataset's embedded viewer.

    Raises:
        ValueError: If ``hub_repo_id`` is empty or None.
    """
    if not hub_repo_id:
        raise ValueError("Hub repo id is required")
    # The id comes from user input. Percent-encode it (keeping the "org/name"
    # slash) so characters like '"' or '<' cannot break out of src="...".
    safe_repo_id = quote(hub_repo_id, safe="/")
    url = f"https://huggingface.co/datasets/{safe_repo_id}/embed/viewer"
    iframe = f"""
<iframe
    src="{url}"
    frameborder="0"
    width="100%"
    height="600px"
></iframe>
"""
    return iframe
def load_dataset_from_hub(hub_repo_id):
    """Download a Hub dataset and store it in the module-level ``ds``.

    Any previously computed embeddings in ``df`` are discarded so a later
    query cannot silently search the previously loaded dataset.

    Args:
        hub_repo_id: Dataset repository id, e.g. ``"org/name"``.

    Raises:
        ValueError: If ``hub_repo_id`` is empty or None.
    """
    global ds
    global df
    if not hub_repo_id:
        raise ValueError("Hub repo id is required")
    ds = load_dataset(hub_repo_id)
    # Embeddings belong to the old dataset; force re-vectorization.
    df = None
def get_columns(split: str):
    """Build a column picker for one split of the loaded dataset.

    Args:
        split: Split name used to index the module-level ``ds``.

    Returns:
        A ``gr.Dropdown`` listing that split's columns, with the first column
        preselected.
    """
    column_names = ds[split].column_names
    first_column = column_names[0]
    return gr.Dropdown(
        choices=column_names,
        value=first_column,
        label="Select a column",
    )
def get_splits():
    """Build a split picker for the loaded dataset.

    Reads the split names from the module-level ``ds`` and preselects the
    first one.

    Returns:
        A ``gr.Dropdown`` of split names.
    """
    split_names = [*ds.keys()]
    return gr.Dropdown(
        label="Select a split",
        choices=split_names,
        value=split_names[0],
    )
def vectorize_dataset(split: str, column: str):
    """Embed one column of a split and store the result in the global ``df``.

    The split is converted to a polars DataFrame, the chosen column is encoded
    with ``model``, and the vectors are appended as an ``"embeddings"`` column.

    Args:
        split: Split name used to index the module-level ``ds``.
        column: Name of the column whose values are embedded.

    Raises:
        ValueError: If ``split`` or ``column`` is empty or None.
    """
    global df
    if not split or not column:
        raise ValueError("Select a split and a column first")
    frame = ds[split].to_polars()
    # model.encode works on text: stringify non-string values and replace
    # nulls so one missing cell doesn't abort embedding the whole column.
    texts = frame[column].cast(pl.Utf8).fill_null("").to_list()
    embeddings = model.encode(texts)
    # Publish only after encoding succeeded, so ``df`` never ends up holding
    # a frame without an "embeddings" column.
    df = frame.with_columns(pl.Series(embeddings).alias("embeddings"))
def run_query(query: str, k: int = 5):
    """Return the ``k`` rows of the vectorized dataset closest to ``query``.

    Args:
        query: Free text; it is embedded with ``model`` and compared against
            the ``"embeddings"`` column using DuckDB's ``array_distance``.
        k: Number of nearest rows to return (default 5).

    Returns:
        A pandas DataFrame (via DuckDB ``.to_df()``) ordered by ascending
        distance.

    Raises:
        ValueError: If no dataset has been vectorized yet.
    """
    # globals().get also copes with ``df`` never having been assigned.
    frame = globals().get("df")
    if frame is None or "embeddings" not in frame.columns:
        raise ValueError("Load and vectorize a dataset before running a query")
    vector = model.encode(query)
    # Derive the width from the model output instead of hard-coding it;
    # casting the stored column too accepts both list and fixed-size arrays.
    dim = len(vector)
    # DuckDB resolves ``frame`` from this function's local variables.
    return duckdb.sql(
        query=f"""
        SELECT *
        FROM frame
        ORDER BY array_distance(embeddings::FLOAT[{dim}], {vector.tolist()}::FLOAT[{dim}])
        LIMIT {int(k)}
        """
    ).to_df()
with gr.Blocks() as demo:
    gr.HTML(
        """
    <h1>Vector Search any Hugging Face Dataset</h1>
    <p>
        This app allows you to vector search any Hugging Face dataset.
        You can search for the nearest neighbors of a query vector, or
        perform a similarity search on a dataframe.
    </p>
    """
    )
    with gr.Row():
        with gr.Column():
            search_in = HuggingfaceHubSearch(
                label="Search Huggingface Hub",
                placeholder="Search for datasets on Huggingface",
                search_type="dataset",
                # NOTE(review): "sumbit" (sic) appears to be the keyword
                # spelling used by gradio_huggingfacehub_search itself;
                # verify against the package before renaming.
                sumbit_on_select=True,
            )
    with gr.Row():
        search_out = gr.HTML(label="Search Results")
    search_in.submit(get_iframe, inputs=search_in, outputs=search_out)

    btn_load_dataset = gr.Button("Load Dataset")
    with gr.Row(variant="panel"):
        split_dropdown = gr.Dropdown(label="Select a split")
        column_dropdown = gr.Dropdown(label="Select a column")
    with gr.Row(variant="panel"):
        query_input = gr.Textbox(label="Query")

    # Loading always ends by embedding the default split/column. A
    # split_dropdown change event may not fire when the new dataset's first
    # split has the same name as before (e.g. "train"), so don't rely on it.
    btn_load_dataset.click(
        load_dataset_from_hub, inputs=search_in, show_progress=True
    ).then(fn=get_splits, outputs=split_dropdown).then(
        fn=get_columns, inputs=split_dropdown, outputs=column_dropdown
    ).then(fn=vectorize_dataset, inputs=[split_dropdown, column_dropdown])

    # .input fires only on user interaction, so the programmatic updates made
    # by the load chain above don't trigger a duplicate vectorization.
    split_dropdown.input(
        fn=get_columns, inputs=split_dropdown, outputs=column_dropdown
    ).then(fn=vectorize_dataset, inputs=[split_dropdown, column_dropdown])
    # Re-embed whenever the user picks a different column.
    column_dropdown.input(
        fn=vectorize_dataset, inputs=[split_dropdown, column_dropdown]
    )

    btn_run = gr.Button("Run")
    results_output = gr.Dataframe(label="Results")
    btn_run.click(fn=run_query, inputs=query_input, outputs=results_output)

if __name__ == "__main__":
    demo.launch()