# legalkit-retrieval / similarity_search.py
# louisbrulenaudet's picture
# Upload 11 files
# 6b2dcd4 verified
# -*- coding: utf-8 -*-
# Copyright (c) Louis Brulé Naudet. All Rights Reserved.
# This software may be used and distributed according to the terms of the License Agreement.
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import faiss
import numpy as np
import torch
from usearch.index import Index
from sentence_transformers import SentenceTransformer
from sentence_transformers.quantization import quantize_embeddings
from typing import Tuple, List, Union
class SimilaritySearch:
    """
    Encode text, quantize embeddings and manage indices for similarity search.

    Retrieval is done in two stages: a FAISS binary index proposes
    candidates from ubinary-quantized embeddings, then the candidates are
    re-scored with the int8 embeddings stored in a USEARCH index.

    Operational failures raised by the underlying libraries are reported on
    standard output and the affected method returns ``None`` (best-effort
    behavior). Only an unsupported quantization type raises.

    Attributes
    ----------
    model_name : str
        Name or identifier of the embedding model.
    device : str
        Computation device ('cpu' or 'cuda').
    ndim : int
        Dimension of the embeddings.
    metric : str
        Metric used for the index ('ip' for inner product, etc.).
    dtype : str
        Data type for the index ('i8' for int8, etc.).
    model : SentenceTransformer
        The embedding model built from ``model_name`` on ``device``.
    binary_index : object or None
        FAISS binary index, set by ``create_faiss_index`` or
        ``load_faiss_index``.
    int8_index : object or None
        USEARCH index, set by ``create_usearch_index`` or
        ``load_usearch_index_view``.

    Methods
    -------
    encode(corpus, normalize_embeddings=True)
        Encodes a list of text data into embeddings.
    quantize_embeddings(embeddings, quantization_type)
        Quantizes the embeddings for efficient storage and search.
    create_faiss_index(ubinary_embeddings, index_path=None, save=False)
        Creates and optionally saves a FAISS binary index.
    create_usearch_index(int8_embeddings, index_path=None, save=False)
        Creates and optionally saves a USEARCH integer index.
    load_usearch_index_view(index_path)
        Loads a USEARCH index as a view for memory-efficient operations.
    load_faiss_index(index_path)
        Loads a FAISS binary index for searching.
    search(query, top_k=10, rescore_multiplier=4)
        Performs a search operation against the indexed embeddings.
    """

    def __init__(
        self,
        model_name: str,
        device: str = "cuda",
        ndim: int = 1024,
        metric: str = "ip",
        dtype: str = "i8"
    ):
        """
        Initializes the SimilaritySearch with the specified model, device, and index configurations.

        Parameters
        ----------
        model_name : str
            The name or identifier of the SentenceTransformer model to use for embedding.
        device : str, optional
            The computation device to use ('cpu' or 'cuda'). Default is 'cuda'.
        ndim : int, optional
            The dimensionality of the embeddings. Default is 1024.
        metric : str, optional
            The metric used for the index ('ip' for inner product). Default is 'ip'.
        dtype : str, optional
            The data type for the USEARCH index ('i8' for 8-bit integer). Default is 'i8'.
        """
        self.model_name = model_name
        self.device = device
        self.ndim = ndim
        self.metric = metric
        self.dtype = dtype

        self.model = SentenceTransformer(
            self.model_name,
            device=self.device
        )

        # Both indices start empty; `search` refuses to run until they are set.
        self.binary_index = None
        self.int8_index = None

    def encode(
        self,
        corpus: list,
        normalize_embeddings: bool = True
    ) -> Union[np.ndarray, None]:
        """
        Encodes the given corpus into full-precision embeddings.

        Parameters
        ----------
        corpus : list
            A list of sentences to be encoded.
        normalize_embeddings : bool, optional
            Whether to normalize returned vectors to have length 1. In that case,
            the faster dot-product (util.dot_score) instead of cosine similarity can be used. Default is True.

        Returns
        -------
        np.ndarray or None
            The full-precision embeddings of the corpus, or None when the
            model raised an error (the error is printed).

        Notes
        -----
        The progress bar is always shown during encoding. Normalization is
        applied only when `normalize_embeddings` is True.
        """
        try:
            return self.model.encode(
                corpus,
                normalize_embeddings=normalize_embeddings,
                show_progress_bar=True
            )

        except Exception as e:
            print(f"An error occurred during encoding: {e}")

    def quantize_embeddings(
        self,
        embeddings: np.ndarray,
        quantization_type: str
    ) -> Union[np.ndarray, bytearray]:
        """
        Quantizes the given embeddings based on the specified quantization type ('ubinary' or 'int8').

        Parameters
        ----------
        embeddings : np.ndarray
            The full-precision embeddings to be quantized.
        quantization_type : str
            The type of quantization ('ubinary' for unsigned binary, 'int8' for 8-bit integers).

        Returns
        -------
        Union[np.ndarray, bytearray]
            The quantized embeddings. The private helpers report library
            failures by printing and returning None.

        Raises
        ------
        ValueError
            If an unsupported quantization type is provided.
        """
        if quantization_type == "ubinary":
            return self._quantize_to_ubinary(
                embeddings=embeddings
            )

        if quantization_type == "int8":
            return self._quantize_to_int8(
                embeddings=embeddings
            )

        # Raised outside any try/except so the documented error actually
        # reaches the caller instead of being printed and swallowed.
        raise ValueError(f"Unsupported quantization type: {quantization_type}")

    def create_faiss_index(
        self,
        ubinary_embeddings: np.ndarray,
        index_path: Union[str, None] = None,
        save: bool = False
    ) -> None:
        """
        Creates and optionally saves a FAISS binary index from ubinary embeddings.

        Parameters
        ----------
        ubinary_embeddings : np.ndarray
            The ubinary-quantized embeddings.
        index_path : str, optional
            The file path to save the FAISS binary index. Default is None.
        save : bool, optional
            Indicator for saving the index. Default is False.

        Returns
        -------
        None

        Notes
        -----
        The dimensionality of the index is specified during the class initialization (default is 1024).
        The index is written to disk only when `save` is true and `index_path` is given.
        Errors are printed, not raised.
        """
        try:
            self.binary_index = faiss.IndexBinaryFlat(
                self.ndim
            )
            self.binary_index.add(
                ubinary_embeddings
            )

            if save and index_path:
                self._save_faiss_index_binary(
                    index_path=index_path
                )

        except Exception as e:
            print(f"An error occurred during index creation: {e}")

    def create_usearch_index(
        self,
        int8_embeddings: np.ndarray,
        index_path: Union[str, None] = None,
        save: bool = False
    ) -> Union["Index", None]:
        """
        Creates and optionally saves a USEARCH integer index from int8 embeddings.

        Parameters
        ----------
        int8_embeddings : np.ndarray
            The int8-quantized embeddings.
        index_path : str, optional
            The file path to save the USEARCH integer index. Default is None.
        save : bool, optional
            Indicator for saving the index. Default is False.

        Returns
        -------
        Index or None
            The created index (also stored in `int8_index`), or None when
            an error occurred (the error is printed).

        Notes
        -----
        The dimensionality and metric of the index are specified during class initialization.
        Vectors are keyed by their position in `int8_embeddings`.
        """
        try:
            self.int8_index = Index(
                ndim=self.ndim,
                metric=self.metric,
                dtype=self.dtype
            )
            # Keys are 0..n-1, i.e. the row position of each embedding.
            self.int8_index.add(
                np.arange(
                    len(int8_embeddings)
                ),
                int8_embeddings
            )

            if save and index_path:
                self._save_int8_index(
                    index_path=index_path
                )

            return self.int8_index

        except Exception as e:
            print(f"An error occurred during USEARCH index creation: {e}")

    def load_usearch_index_view(
        self,
        index_path: str
    ) -> Union["Index", None]:
        """
        Loads a USEARCH index as a view for memory-efficient operations.

        Parameters
        ----------
        index_path : str
            The file path to the USEARCH index to be loaded as a view.

        Returns
        -------
        Index or None
            The restored index (also stored in `int8_index`), or None when
            loading failed (the error is printed).
        """
        try:
            self.int8_index = Index.restore(
                index_path,
                view=True
            )

            return self.int8_index

        except Exception as e:
            print(f"An error occurred while loading USEARCH index: {e}")

    def load_faiss_index(
        self,
        index_path: str
    ) -> None:
        """
        Loads a FAISS binary index from a specified file path.

        This method loads a binary index created by FAISS into the class
        attribute `binary_index`, ready for performing similarity searches.

        Parameters
        ----------
        index_path : str
            The file path to the saved FAISS binary index.

        Returns
        -------
        None

        Notes
        -----
        The loaded index is stored in the `binary_index` attribute of the class.
        Ensure that the index at `index_path` is compatible with the configurations
        (e.g., dimensions) used for this class instance. Errors are printed, not raised.
        """
        try:
            self.binary_index = faiss.read_index_binary(
                index_path
            )

        except Exception as e:
            print(f"An error occurred while loading the FAISS index: {e}")

    def search(
        self,
        query: str,
        top_k: int = 10,
        rescore_multiplier: int = 4
    ) -> Union[Tuple[List[float], List[int]], None]:
        """
        Performs a search operation against the indexed embeddings.

        Parameters
        ----------
        query : str
            The query sentence/string to be searched.
        top_k : int, optional
            The number of top results to return.
        rescore_multiplier : int, optional
            The multiplier used to increase the initial retrieval size for re-scoring.
            Higher values can increase precision at the cost of performance.

        Returns
        -------
        Tuple[List[float], List[int]] or None
            A tuple containing the scores and the indices of the top k results,
            best first. Fewer than `top_k` results are returned when the index
            holds fewer vectors. None is returned when an error occurred
            (the error is printed).

        Notes
        -----
        This method assumes that `binary_index` and `int8_index` are already loaded or created.
        """
        try:
            if self.binary_index is None or self.int8_index is None:
                raise ValueError("Indices must be loaded or created before searching.")

            if top_k < 1 or rescore_multiplier < 1:
                raise ValueError("top_k and rescore_multiplier must be positive integers.")

            query_embedding = self.encode(
                corpus=query,
                normalize_embeddings=False
            )

            if query_embedding is None:
                # `encode` already printed the underlying cause.
                raise RuntimeError("The query could not be encoded.")

            query_embedding_ubinary = self.quantize_embeddings(
                embeddings=query_embedding.reshape(1, -1),
                quantization_type="ubinary"
            )

            # Never request more neighbours than the index holds: FAISS pads
            # the missing slots with id -1, which is not a valid key.
            candidate_count = min(
                top_k * rescore_multiplier,
                self.binary_index.ntotal
            )

            if candidate_count == 0:
                return [], []

            _scores, binary_ids = self.binary_index.search(
                query_embedding_ubinary,
                candidate_count
            )

            binary_ids = self._valid_ids(binary_ids[0])

            if binary_ids.size == 0:
                return [], []

            # Cast up from int8 so the dot product below cannot overflow.
            int8_embeddings = self.int8_index[binary_ids].astype(int)

            return self._rescore(
                query_embedding,
                binary_ids,
                int8_embeddings,
                top_k
            )

        except Exception as e:
            print(f"An error occurred while searching semantic similar sentences: {e}")

    @staticmethod
    def _valid_ids(
        candidate_ids: np.ndarray
    ) -> np.ndarray:
        """
        Drops the negative placeholder ids from a candidate id array.

        Parameters
        ----------
        candidate_ids : np.ndarray
            One-dimensional array of candidate ids.

        Returns
        -------
        np.ndarray
            The ids that are greater than or equal to zero, in their original order.
        """
        return candidate_ids[candidate_ids >= 0]

    @staticmethod
    def _rescore(
        query_embedding: np.ndarray,
        candidate_ids: np.ndarray,
        candidate_embeddings: np.ndarray,
        top_k: int
    ) -> Tuple[List[float], List[int]]:
        """
        Ranks candidates by dot product with the query and keeps the best `top_k`.

        Parameters
        ----------
        query_embedding : np.ndarray
            The embedding of a single query; it is flattened to one dimension.
        candidate_ids : np.ndarray
            The ids of the candidates, aligned with `candidate_embeddings` rows.
        candidate_embeddings : np.ndarray
            One row per candidate.
        top_k : int
            The maximum number of results to keep.

        Returns
        -------
        Tuple[List[float], List[int]]
            The scores and ids of the kept candidates, highest score first.
        """
        query_vector = np.asarray(query_embedding).reshape(-1)
        scores = query_vector @ candidate_embeddings.T

        # Sorting the negated scores ascending yields a descending ranking.
        ranking = (-scores).argsort()[:top_k]

        return scores[ranking].tolist(), candidate_ids[ranking].tolist()

    def _quantize_to_ubinary(
        self,
        embeddings: np.ndarray
    ) -> Union[np.ndarray, None]:
        """
        Quantizes embeddings to the 'ubinary' precision.

        Parameters
        ----------
        embeddings : np.ndarray
            The embeddings to quantize.

        Returns
        -------
        np.ndarray or None
            The quantized embeddings, or None when quantization failed
            (the error is printed).
        """
        try:
            return quantize_embeddings(
                embeddings,
                "ubinary"
            )

        except Exception as e:
            print(f"An error occurred during ubinary quantization: {e}")

    def _quantize_to_int8(
        self,
        embeddings: np.ndarray
    ) -> Union[np.ndarray, None]:
        """
        Quantizes embeddings to the 'int8' precision.

        Parameters
        ----------
        embeddings : np.ndarray
            The embeddings to quantize.

        Returns
        -------
        np.ndarray or None
            The quantized embeddings, or None when quantization failed
            (the error is printed).
        """
        try:
            return quantize_embeddings(
                embeddings,
                "int8"
            )

        except Exception as e:
            print(f"An error occurred during int8 quantization: {e}")

    def _save_faiss_index_binary(
        self,
        index_path: str
    ) -> None:
        """
        Saves the FAISS binary index to disk.

        This private method is called internally to save the constructed FAISS binary index to the specified file path.

        Parameters
        ----------
        index_path : str
            The path to the file where the binary index should be saved. This value is checked in the public method
            `create_faiss_index`.

        Returns
        -------
        None

        Notes
        -----
        This method should not be called directly. It is intended to be used internally by the `create_faiss_index` method.
        Errors are printed, not raised.
        """
        try:
            faiss.write_index_binary(
                self.binary_index,
                index_path
            )

        except Exception as e:
            print(f"An error occurred during FAISS binary index saving: {e}")

    def _save_int8_index(
        self,
        index_path: str
    ) -> None:
        """
        Saves the int8_index to disk.

        This private method is called internally to save the constructed int8_index to the specified file path.

        Parameters
        ----------
        index_path : str
            The path to the file where the int8_index should be saved. This value is checked in the public method
            `create_usearch_index`.

        Returns
        -------
        None

        Notes
        -----
        This method should not be called directly. It is intended to be used internally by the `create_usearch_index` method.
        Errors are printed, not raised.
        """
        try:
            self.int8_index.save(
                index_path
            )

        except Exception as e:
            print(f"An error occurred during int8_index saving: {e}")