davidberenstein1957 HF staff committed on
Commit
7e27e2f
1 Parent(s): 3f7d824

feat: initial version

Browse files
Files changed (3) hide show
  1. app.py +128 -0
  2. demo.py +23 -0
  3. requirements.txt +6 -0
app.py ADDED
@@ -0,0 +1,128 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import duckdb
2
+ import gradio as gr
3
+ import polars as pl
4
+ from datasets import load_dataset
5
+ from gradio_huggingfacehub_search import HuggingfaceHubSearch
6
+ from model2vec import StaticModel
7
+
8
# Module-level state shared by the Gradio callbacks in this file:
#   ds - the object returned by load_dataset(), assigned in load_dataset_from_hub
#   df - the polars frame with an "embeddings" column, assigned in vectorize_dataset
# A `global` statement at module scope is a no-op, so none is declared here;
# the functions that assign these names declare them global themselves.

# Load the static embedding model from the Hugging Face Hub.
model_name = "minishlab/M2V_multilingual_output"
model = StaticModel.from_pretrained(model_name)
14
+
15
+
16
def get_iframe(hub_repo_id):
    """Build the HTML snippet that embeds the Hub dataset viewer for a repo.

    Args:
        hub_repo_id: Dataset repository id on the Hugging Face Hub, e.g.
            ``"fka/awesome-chatgpt-prompts"``.

    Returns:
        A string holding an ``<iframe>`` element whose ``src`` points at the
        dataset's ``/embed/viewer`` page.

    Raises:
        ValueError: If ``hub_repo_id`` is empty or ``None``.
    """
    # Function-scope imports keep this block self-contained.
    import html
    from urllib.parse import quote

    if not hub_repo_id:
        raise ValueError("Hub repo id is required")

    # The id arrives from a user-editable search component, so it must not be
    # able to break out of the ``src`` attribute: percent-encode it for the
    # URL path (keeping the owner/name slash) and HTML-escape the final URL.
    safe_repo_id = quote(str(hub_repo_id), safe="/")
    url = html.escape(
        f"https://huggingface.co/datasets/{safe_repo_id}/embed/viewer",
        quote=True,
    )
    iframe = f"""
    <iframe
        src="{url}"
        frameborder="0"
        width="100%"
        height="600px"
    ></iframe>
    """
    return iframe
29
+
30
+
31
def load_dataset_from_hub(hub_repo_id):
    """Load a dataset from the Hub into the module-level ``ds``.

    Args:
        hub_repo_id: Dataset repository id passed to ``load_dataset``.

    Raises:
        ValueError: If ``hub_repo_id`` is empty or ``None`` (same check and
            message as ``get_iframe``), so an empty selection fails with a
            clear message instead of reaching ``load_dataset``.
    """
    global ds
    if not hub_repo_id:
        raise ValueError("Hub repo id is required")
    ds = load_dataset(hub_repo_id)
34
+
35
+
36
def get_columns(split: str):
    """Return a column dropdown populated from one split of the loaded ``ds``.

    Args:
        split: Key of the split to read from the module-level ``ds``.

    Returns:
        A ``gr.Dropdown`` listing the split's column names, with the first
        column preselected.
    """
    column_names = ds[split].column_names
    first_column = column_names[0]
    return gr.Dropdown(
        label="Select a column",
        choices=column_names,
        value=first_column,
    )
44
+
45
+
46
def get_splits():
    """Return a split dropdown populated from the keys of the loaded ``ds``.

    Returns:
        A ``gr.Dropdown`` listing every key of the module-level ``ds``, with
        the first one preselected.
    """
    split_names = [*ds.keys()]
    return gr.Dropdown(
        label="Select a split",
        choices=split_names,
        value=split_names[0],
    )
50
+
51
+
52
def vectorize_dataset(split: str, column: str):
    """Embed one column of a split and keep the result in the global ``df``.

    Args:
        split: Key of the split to read from the module-level ``ds``.
        column: Name of the column whose values are passed to ``model.encode``.

    Returns:
        None. The polars frame, extended with an ``"embeddings"`` column, is
        stored in the module-level ``df``.
    """
    global df
    # Assign the raw frame first, then replace it with the extended one.
    df = ds[split].to_polars()
    vectors = model.encode(df[column])
    embedding_column = pl.Series(vectors).alias("embeddings")
    df = df.with_columns(embedding_column)
58
+
59
+
60
def run_query(query: str):
    """Return the five rows of ``df`` whose embeddings are closest to a query.

    Args:
        query: Free-text query; it is embedded with the module-level ``model``.

    Returns:
        The result of the DuckDB query converted with ``.to_df()``: all
        columns of the five rows with the smallest ``array_distance`` to the
        query embedding.

    Raises:
        gr.Error: If no vectorized frame is available yet, i.e. the global
            ``df`` has not been assigned.
    """
    global df
    # Running a query before "Vectorize" would otherwise fail inside DuckDB
    # with an opaque missing-table error; surface an actionable message.
    if globals().get("df") is None:
        raise gr.Error("Vectorize a dataset column before running a query")

    vector = model.encode(query)
    # Derive the array width from the query embedding instead of hard-coding
    # 256, so the cast keeps working if the model's output size differs.
    dimension = len(vector)
    # DuckDB resolves ``df`` in the SQL text from the Python variable of the
    # same name.
    return duckdb.sql(
        query=f"""
        SELECT *
        FROM df
        ORDER BY array_distance(embeddings, {vector.tolist()}::FLOAT[{dimension}])
        LIMIT 5
        """
    ).to_df()
71
+
72
+
73
# Assemble the UI: dataset search with preview, load/vectorize controls, and
# a query box whose results are shown in a dataframe.
with gr.Blocks() as demo:
    gr.HTML(
        """
        <h1>Vector Search any Hugging Face Dataset</h1>
        <p>
            This app allows you to vector search any Hugging Face dataset.
            You can search for the nearest neighbors of a query vector, or
            perform a similarity search on a dataframe.
        </p>
        <p>
            This app uses the <a href="https://huggingface.co/minishlab/M2V_multilingual_output">M2V_multilingual_output</a> model from the Hugging Face Hub.
        </p>
        """
    )
    with gr.Row():
        with gr.Column():
            # NOTE(review): `sumbit_on_select` is left as written; presumably
            # it matches the component's own keyword spelling - confirm
            # before renaming.
            search_in = HuggingfaceHubSearch(
                label="Search Huggingface Hub",
                # The component searches datasets (search_type="dataset"),
                # so the placeholder says datasets rather than models.
                placeholder="Search for datasets on Huggingface",
                search_type="dataset",
                sumbit_on_select=True,
            )
    with gr.Row():
        search_out = gr.HTML(label="Search Results")
        # Show the embedded dataset viewer for the selected repository.
        search_in.submit(get_iframe, inputs=search_in, outputs=search_out)

    btn_load_dataset = gr.Button("Load Dataset")

    with gr.Row(variant="panel"):
        split_dropdown = gr.Dropdown(label="Select a split")
        column_dropdown = gr.Dropdown(label="Select a column")
        btn_vectorize_dataset = gr.Button("Vectorize")

    # After loading, refresh the split choices and then the column choices
    # for the selected split.
    btn_load_dataset.click(
        load_dataset_from_hub, inputs=search_in, show_progress=True
    ).then(fn=get_splits, outputs=split_dropdown).then(
        fn=get_columns, inputs=split_dropdown, outputs=column_dropdown
    )
    # Keep the column choices in sync when another split is picked.
    split_dropdown.change(
        fn=get_columns, inputs=split_dropdown, outputs=column_dropdown
    )

    btn_vectorize_dataset.click(
        fn=vectorize_dataset,
        inputs=[split_dropdown, column_dropdown],
        show_progress=True,
    )

    with gr.Row(variant="panel"):
        query_input = gr.Textbox(label="Query")

    btn_run = gr.Button("Run")
    results_output = gr.Dataframe(label="Results")

    btn_run.click(fn=run_query, inputs=query_input, outputs=results_output)

demo.launch()
demo.py ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import duckdb
2
+ import polars as pl
3
+ from datasets import load_dataset
4
+ from model2vec import StaticModel
5
+
6
# Load the static embedding model from the Hugging Face Hub.
model_name = "minishlab/M2V_multilingual_output"
model = StaticModel.from_pretrained(model_name)

# Embed the "prompt" column and store the vectors next to the original rows.
ds = load_dataset("fka/awesome-chatgpt-prompts")
df = ds["train"].to_polars()
embeddings = model.encode(df["prompt"])
df = df.with_columns(pl.Series(embeddings).alias("embeddings"))

# Embed the query text and print the single nearest row. The array width in
# the cast is taken from the query embedding rather than hard-coded, so the
# script keeps working if the model's output size differs from 256.
vector = model.encode("vector search", show_progress_bar=True)
duckdb.sql(
    query=f"""
    SELECT *
    FROM df
    ORDER BY array_distance(embeddings, {vector.tolist()}::FLOAT[{len(vector)}])
    LIMIT 1
    """
).show()
requirements.txt ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ polars
2
+ datasets
3
+ model2vec
4
+ duckdb
5
+ gradio
6
+ gradio-huggingfacehub-search