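"""Gradio Space for LegalKit retrieval: binary search over French legal codes with scalar (int8) rescoring."""
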
import gradio as gr
import polars as pl
import spaces
import torch

from typing import Tuple, List, Union

from dataset import Dataset
from similarity_search import SimilaritySearch


def setup(
    description: str,
    model_name: str,
    device: str,
    ndim: int,
    metric: str,
    dtype: str
) -> Tuple:
    """
    Set up the similarity search engine, its indexes, and the dataset backing the app.

    Parameters
    ----------
    description : str
        A string containing additional description information.
    model_name : str
        Name of the pre-trained model to load.
    device : str
        Device to run the model on, e.g., 'cuda' or 'cpu'.
    ndim : int
        Dimensionality of the embeddings produced by the model.
    metric : str
        Metric used for similarity search, e.g., 'ip' for inner product.
    dtype : str
        Data type of the quantized embeddings, e.g., 'i8'.

    Returns
    -------
    instance : SimilaritySearch
        A class dedicated to encoding text data, quantizing embeddings, and
        managing indices for efficient similarity search.
    dataset : datasets.Dataset
        The loaded dataset.
    dataframe : pl.DataFrame
        A Polars DataFrame representation of the dataset.
    description : str
        The description string, passed through unchanged.
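
    Examples
    --------
    A minimal sketch mirroring the module-level call further below:

    >>> instance, dataset, dataframe, description = setup(
    ...     description="LegalKit retrieval demo",
    ...     model_name="louisbrulenaudet/tsdae-lemone-mbert-base",
    ...     device="cpu",
    ...     ndim=768,
    ...     metric="ip",
    ...     dtype="i8"
    ... )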
    """
    try:
        instance = SimilaritySearch(
            model_name=model_name,
            device=device,
            ndim=ndim,
            metric=metric,
            dtype=dtype
        )

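        # Load the scalar (int8) USearch index as a view so it is not fully loaded
        # into memory; it backs the rescoring step.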
        instance.load_usearch_index_view(
            index_path="./usearch_int8.index",
        )

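        # Load the binary (ubinary) FAISS index used for the fast initial search.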
        instance.load_faiss_index(
            index_path="./faiss_ubinary.index",
        )

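        # Load the LegalKit dataset and mirror it as a Polars DataFrame for fast
        # filtering and joining of search results.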
        dataset = Dataset.load(
            dataset_path="./legalkit.hf"
        )

        dataframe = Dataset.convert_to_polars(
            dataset=dataset
        )

        return instance, dataset, dataframe, description

    except Exception as e:
        error_message = f"An error occurred during setup: {str(e)}"
        raise RuntimeError(error_message) from e


DESCRIPTION = """\
# LegalKit Retrieval: binary search with scalar (int8) rescoring through French legal codes

This Space showcases the [tsdae-lemone-mbert-base](https://huggingface.co/louisbrulenaudet/tsdae-lemone-mbert-base)
model by Louis Brulé Naudet, a BERT-based sentence embedding model fitted with a Transformer-based Sequential Denoising Auto-Encoder (TSDAE) for unsupervised sentence embedding learning, with one objective: French legal domain adaptation.

This process is designed to be memory-efficient and fast: the binary index is small enough to fit in memory, and the int8 index is loaded as a view to save memory.
Additionally, the binary index is much faster (up to 32x) to search than the float32 index, and the rescoring step is also extremely efficient.
"""

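# Initialise the search engine, indexes, and dataset once at import time; the
# resulting globals are shared by every search request.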
instance, dataset, dataframe, DESCRIPTION = setup(
    model_name="louisbrulenaudet/tsdae-lemone-mbert-base",
    description=DESCRIPTION,
    device="cpu",
    ndim=768,
    metric="ip",
    dtype="i8"
)


@spaces.GPU
def search(
    query: str,
    top_k: int,
    rescore_multiplier: int
) -> gr.Dataframe:
    """
    Perform a quantized similarity search for a query and return the results.

    Parameters
    ----------
    query : str
        The query for which similarity search is performed.
    top_k : int
        The number of top results to return.
    rescore_multiplier : int
        A multiplier for the rescoring operation: 'rescore_multiplier' times as
        many documents are retrieved and rescored before keeping the top 'top_k'.

    Returns
    -------
    gr.Dataframe
        A Gradio Dataframe component holding the search results, made visible.

    Notes
    -----
    This function runs inside the GPU context provided by the `@spaces.GPU`
    decorator and relies on the module-level `instance`, `dataset`, and
    `dataframe` globals initialised by `setup`.

    Examples
    --------
    >>> results = search(query="example query", top_k=10, rescore_multiplier=2)
    """
    global instance
    global dataset
    global dataframe

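    # Two-stage retrieval: the binary index returns 'top_k * rescore_multiplier'
    # candidates, which are rescored against the int8 embeddings to yield the
    # final 'top_k' scores and row indices (see SimilaritySearch.search).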
    top_k_scores, top_k_indices = instance.search(
        query=query,
        top_k=top_k,
        rescore_multiplier=rescore_multiplier
    )

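    # Pair each retrieved row index with its score; cast the index to UInt32 so the
    # join key matches the "index" column of the corpus DataFrame.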
    scores_df = pl.DataFrame(
        {
            "index": top_k_indices,
            "score": top_k_scores
        }
    ).with_columns(
        pl.col("index").cast(pl.UInt32)
    )

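    # Keep only the retrieved rows, attach their scores, sort best-first, and select
    # the columns displayed in the results table.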
    results_df = dataframe.filter(
        pl.col("index").is_in(top_k_indices)
    ).join(
        scores_df,
        how="inner",
        on="index"
    ).sort(
        by="score",
        descending=True
    ).select(
        [
            "score",
            "input",
            "output",
            "start",
            "expiration"
        ]
    )

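    # Return an updated Dataframe component so the initially hidden results table
    # becomes visible with the new values.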
    return gr.Dataframe(
        value=results_df,
        visible=True
    )


with gr.Blocks(title="Quantized Retrieval") as demo:
    gr.Markdown(
        value=DESCRIPTION
    )
    gr.DuplicateButton()

    with gr.Row():
        with gr.Column():
            query = gr.Textbox(
                label="Query for articles of the French legal codes",
                placeholder="Enter a query to search for relevant texts from French law.",
            )

    with gr.Row():
        with gr.Column(scale=2):
            top_k = gr.Slider(
                minimum=1,
                maximum=100,
                step=1,
                value=20,
                label="Number of documents to retrieve",
                info="Number of documents to retrieve from the binary search.",
            )
        with gr.Column(scale=2):
            rescore_multiplier = gr.Slider(
                minimum=1,
                maximum=10,
                step=1,
                value=4,
                label="Rescore multiplier",
                info="Retrieve 'rescore_multiplier' times as many documents and rescore them.",
            )

    search_button = gr.Button(value="Search")

    with gr.Row():
        with gr.Column():
            output = gr.Dataframe(
                visible=False,
                type="polars"
            )

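    # Pressing Enter in the textbox and clicking the button both trigger the same search.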
    query.submit(
        search,
        inputs=[
            query,
            top_k,
            rescore_multiplier
        ],
        outputs=output
    )

    search_button.click(
        search,
        inputs=[
            query,
            top_k,
            rescore_multiplier
        ],
        outputs=output
    )


if __name__ == "__main__":
    demo.queue().launch(
        show_api=False
    )