import duckdb
import gradio as gr
import polars as pl
from datasets import load_dataset
from gradio_huggingfacehub_search import HuggingfaceHubSearch
from model2vec import StaticModel

# Dataset and dataframe shared across the Gradio callbacks below.
ds = None
df = None

# Load a model from the Hugging Face Hub (in this case the potion-base-8M model)
model_name = "minishlab/potion-base-8M"
model = StaticModel.from_pretrained(model_name)

def get_iframe(hub_repo_id):
    if not hub_repo_id:
        raise ValueError("Hub repo id is required")
    url = f"https://huggingface.co/datasets/{hub_repo_id}/embed/viewer"
    iframe = f"""
    <iframe
        src="{url}"
        frameborder="0"
        width="100%"
        height="600px"
    ></iframe>
    """
    return iframe

def load_dataset_from_hub(hub_repo_id):
    global ds
    ds = load_dataset(hub_repo_id)

def get_columns(split: str):
    global ds
    ds_split = ds[split]
    return gr.Dropdown(
        choices=ds_split.column_names,
        value=ds_split.column_names[0],
        label="Select a column",
        visible=True,
    )

def get_splits():
    global ds
    splits = list(ds.keys())
    return gr.Dropdown(
        choices=splits, value=splits[0], label="Select a split", visible=True
    )

def vectorize_dataset(split: str, column: str):
    global df
    global ds
    # Embed the selected column; max_length=512 caps the number of tokens encoded per row.
    df = ds[split].to_polars()
    embeddings = model.encode(df[column], max_length=512)
    df = df.with_columns(pl.Series(embeddings).alias("embeddings"))

def run_query(query: str):
    global df
    # Embed the query and let DuckDB rank rows by cosine distance;
    # the cast to FLOAT[256] matches the embedding dimensionality.
    vector = model.encode(query)
    df_results = duckdb.sql(
        query=f"""
        SELECT *
        FROM df
        ORDER BY array_cosine_distance(embeddings, {vector.tolist()}::FLOAT[256])
        LIMIT 5
        """
    ).to_df()
    return gr.Dataframe(df_results, visible=True)

def hide_components():
    return [
        gr.Dropdown(visible=False),
        gr.Dropdown(visible=False),
        gr.Textbox(visible=False),
        gr.Button(visible=False),
        gr.Dataframe(visible=False),
    ]

def partial_hide_components():
    return [
        gr.Textbox(visible=False),
        gr.Button(visible=False),
        gr.Dataframe(visible=False),
    ]

def show_components():
    return [
        gr.Textbox(visible=True, label="Query"),
        gr.Button(visible=True, value="Search"),
    ]

with gr.Blocks() as demo:
    gr.HTML(
        """
        <h1>Vector Search any Hugging Face Dataset</h1>
        <p>
            This app lets you run vector search over any Hugging Face dataset.
            Pick a dataset, split, and text column; the column is embedded and
            your query returns the nearest rows by cosine distance.
        </p>
        """
    )
    with gr.Row():
        with gr.Column():
            search_in = HuggingfaceHubSearch(
                label="Search Huggingface Hub",
                placeholder="Search for datasets on Huggingface",
                search_type="dataset",
                sumbit_on_select=True,
            )
    with gr.Row():
        search_out = gr.HTML(label="Search Results")
    with gr.Row():
        split_dropdown = gr.Dropdown(label="Select a split", visible=False)
        column_dropdown = gr.Dropdown(label="Select a column", visible=False)
    with gr.Row():
        query_input = gr.Textbox(label="Query", visible=False)
        btn_run = gr.Button("Search", visible=False)
        results_output = gr.Dataframe(label="Results", visible=False)
    search_in.submit(get_iframe, inputs=search_in, outputs=search_out).then(
        fn=load_dataset_from_hub,
        inputs=search_in,
        show_progress=True,
    ).then(
        fn=hide_components,
        outputs=[split_dropdown, column_dropdown, query_input, btn_run, results_output],
    ).then(fn=get_splits, outputs=split_dropdown).then(
        fn=get_columns, inputs=split_dropdown, outputs=column_dropdown
    )

    split_dropdown.change(
        fn=get_columns, inputs=split_dropdown, outputs=column_dropdown
    )
    column_dropdown.change(
        fn=partial_hide_components,
        outputs=[query_input, btn_run, results_output],
    ).then(fn=vectorize_dataset, inputs=[split_dropdown, column_dropdown]).then(
        fn=show_components, outputs=[query_input, btn_run]
    )
    btn_run.click(fn=run_query, inputs=query_input, outputs=results_output)

demo.launch()