"""Find the dataset prompt closest to a text query using DuckDB.

The script embeds every prompt of the ``fka/awesome-chatgpt-prompts``
dataset with a Model2Vec static model, stores the vectors next to the
prompts in a polars DataFrame, and asks DuckDB for the single row whose
embedding has the smallest ``array_distance`` to the embedded query.
"""
import duckdb
import polars as pl
from datasets import load_dataset
from model2vec import StaticModel

# The same model must embed both the corpus and the query so the vectors
# are comparable.
model_name = "minishlab/M2V_multilingual_output"
model = StaticModel.from_pretrained(model_name)

# Load the prompts and attach one embedding per row.
ds = load_dataset("fka/awesome-chatgpt-prompts")
df = ds["train"].to_polars()
# Hand the encoder a plain list of strings rather than a polars Series.
embeddings = model.encode(df["prompt"].to_list())
df = df.with_columns(pl.Series(embeddings).alias("embeddings"))

# Embed the search text and derive the vector width from the result, so the
# SQL cast below keeps working if ``model_name`` is changed to a model with
# a different embedding size.
vector = model.encode("vector search", show_progress_bar=True)
query_values = vector.tolist()
dim = len(query_values)

# ``df`` is referenced by name inside the SQL; the query vector is inlined
# as a list literal and cast to a fixed-size FLOAT array for array_distance.
duckdb.sql(
    query=f"""
    SELECT *
    FROM df
    ORDER BY array_distance(embeddings, {query_values}::FLOAT[{dim}])
    LIMIT 1
    """
).show()
|