|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
from typing import Any, List, Optional, Tuple, Union

import faiss
import numpy as np
import torch
from sentence_transformers import SentenceTransformer
from sentence_transformers.quantization import quantize_embeddings
from usearch.index import Index
|
|
|
class SimilaritySearch:
    """
    Encode text, quantize embeddings, and manage indices for two-stage similarity search.

    Search first retrieves candidates from a FAISS binary index built over
    ubinary-quantized embeddings, then rescores those candidates by dot product
    between the full-precision query embedding and int8 embeddings stored in a
    USEARCH index.

    Except where noted, public methods do not raise: errors are printed and the
    method returns ``None``.

    Attributes
    ----------
    model_name : str
        Name or identifier of the SentenceTransformer embedding model.
    device : str
        Computation device ('cpu' or 'cuda').
    ndim : int
        Dimension of the embeddings. Also passed as the dimension of the FAISS
        binary index and the USEARCH index.
    metric : str
        Metric used for the USEARCH index ('ip' for inner product, etc.).
    dtype : str
        Scalar type for the USEARCH index ('i8' for int8, etc.).
    model : SentenceTransformer
        The loaded embedding model.
    binary_index : FAISS binary index or None
        Set by ``create_faiss_index`` or ``load_faiss_index``.
    int8_index : usearch Index or None
        Set by ``create_usearch_index`` or ``load_usearch_index_view``.

    Methods
    -------
    encode(corpus, normalize_embeddings=True, show_progress_bar=True)
        Encodes text into full-precision embeddings.
    quantize_embeddings(embeddings, quantization_type)
        Quantizes embeddings to 'ubinary' or 'int8'.
    create_faiss_index(ubinary_embeddings, index_path=None, save=False)
        Builds (and optionally saves) a FAISS binary index.
    create_usearch_index(int8_embeddings, index_path=None, save=False)
        Builds (and optionally saves) a USEARCH int8 index.
    load_usearch_index_view(index_path)
        Loads a USEARCH index as a view (``view=True``).
    load_faiss_index(index_path)
        Loads a FAISS binary index from disk.
    search(query, top_k=10, rescore_multiplier=4)
        Binary retrieval followed by int8 rescoring.
    """

    def __init__(
        self,
        model_name: str,
        device: str = "cuda",
        ndim: int = 1024,
        metric: str = "ip",
        dtype: str = "i8",
    ):
        """
        Initialize the SimilaritySearch with the given model, device, and index settings.

        Parameters
        ----------
        model_name : str
            The name or identifier of the SentenceTransformer model to use for embedding.
        device : str, optional
            The computation device to use ('cpu' or 'cuda'). Default is 'cuda'.
        ndim : int, optional
            The dimensionality of the embeddings. Default is 1024.
        metric : str, optional
            The metric used for the USEARCH index ('ip' for inner product). Default is 'ip'.
        dtype : str, optional
            The data type for the USEARCH index ('i8' for 8-bit integer). Default is 'i8'.
        """
        self.model_name = model_name
        self.device = device
        self.ndim = ndim
        self.metric = metric
        self.dtype = dtype
        self.model = SentenceTransformer(self.model_name, device=self.device)

        # Populated later by the create_* / load_* methods.
        self.binary_index = None
        self.int8_index = None

    def encode(
        self,
        corpus: Union[str, List[str]],
        normalize_embeddings: bool = True,
        show_progress_bar: bool = True,
    ) -> Optional[np.ndarray]:
        """
        Encode the given corpus into full-precision embeddings.

        Parameters
        ----------
        corpus : str or list of str
            A sentence, or a list of sentences, to be encoded.
        normalize_embeddings : bool, optional
            Whether to normalize returned vectors to length 1, so that a dot
            product equals cosine similarity. Default is True.
        show_progress_bar : bool, optional
            Whether the model shows a progress bar while encoding. Default is True.

        Returns
        -------
        np.ndarray or None
            The full-precision embeddings, or None if encoding failed (the error
            is printed).
        """
        try:
            return self.model.encode(
                corpus,
                normalize_embeddings=normalize_embeddings,
                show_progress_bar=show_progress_bar,
            )
        except Exception as e:
            print(f"An error occurred during encoding: {e}")
            return None

    def quantize_embeddings(
        self,
        embeddings: np.ndarray,
        quantization_type: str,
    ) -> Optional[np.ndarray]:
        """
        Quantize embeddings to the given type ('ubinary' or 'int8').

        Parameters
        ----------
        embeddings : np.ndarray
            The full-precision embeddings to be quantized.
        quantization_type : str
            'ubinary' for packed unsigned binary, 'int8' for 8-bit integers.

        Returns
        -------
        np.ndarray or None
            The quantized embeddings, or None if quantization itself failed
            (the error is printed).

        Raises
        ------
        ValueError
            If an unsupported quantization type is provided.
        """
        quantizers = {
            "ubinary": self._quantize_to_ubinary,
            "int8": self._quantize_to_int8,
        }
        # Validate outside the try block so the documented ValueError reaches the caller
        # instead of being swallowed by the generic handler below.
        if quantization_type not in quantizers:
            raise ValueError(f"Unsupported quantization type: {quantization_type}")

        try:
            return quantizers[quantization_type](embeddings=embeddings)
        except Exception as e:
            print(f"An error occurred during quantization: {e}")
            return None

    def create_faiss_index(
        self,
        ubinary_embeddings: np.ndarray,
        index_path: Optional[str] = None,
        save: bool = False,
    ) -> Optional[Any]:
        """
        Build a FAISS binary index from ubinary embeddings, optionally saving it.

        Parameters
        ----------
        ubinary_embeddings : np.ndarray
            The ubinary-quantized embeddings.
        index_path : str, optional
            File path to save the index to. Default is None.
        save : bool, optional
            Save the index to ``index_path``. Saving happens only when this is
            True AND ``index_path`` is given. Default is False.

        Returns
        -------
        FAISS binary index or None
            The new index (also stored in ``self.binary_index``), or None on
            error (the error is printed).

        Notes
        -----
        The index dimension is ``self.ndim`` (set at initialization, default 1024).
        """
        try:
            self.binary_index = faiss.IndexBinaryFlat(self.ndim)
            self.binary_index.add(ubinary_embeddings)

            if save and index_path:
                self._save_faiss_index_binary(index_path=index_path)

            return self.binary_index
        except Exception as e:
            print(f"An error occurred during index creation: {e}")
            return None

    def create_usearch_index(
        self,
        int8_embeddings: np.ndarray,
        index_path: Optional[str] = None,
        save: bool = False,
    ) -> Optional["Index"]:
        """
        Build a USEARCH index from int8 embeddings, optionally saving it.

        Row ``i`` of ``int8_embeddings`` is stored under key ``i``, so keys line up
        with the row ids returned by the FAISS binary index built from the same corpus.

        Parameters
        ----------
        int8_embeddings : np.ndarray
            The int8-quantized embeddings.
        index_path : str, optional
            File path to save the index to. Default is None.
        save : bool, optional
            Save the index to ``index_path``. Saving happens only when this is
            True AND ``index_path`` is given. Default is False.

        Returns
        -------
        usearch Index or None
            The new index (also stored in ``self.int8_index``), or None on error
            (the error is printed).
        """
        try:
            self.int8_index = Index(
                ndim=self.ndim,
                metric=self.metric,
                dtype=self.dtype,
            )
            self.int8_index.add(np.arange(len(int8_embeddings)), int8_embeddings)

            if save and index_path:
                self._save_int8_index(index_path=index_path)

            return self.int8_index
        except Exception as e:
            print(f"An error occurred during USEARCH index creation: {e}")
            return None

    def load_usearch_index_view(self, index_path: str) -> Optional["Index"]:
        """
        Load a USEARCH index from disk as a view (``Index.restore(..., view=True)``).

        Parameters
        ----------
        index_path : str
            The file path of the USEARCH index.

        Returns
        -------
        usearch Index or None
            The loaded index (also stored in ``self.int8_index``), or None on
            error (the error is printed). On error, any previously loaded index
            is kept.
        """
        try:
            index = Index.restore(index_path, view=True)
            # Index.restore can return None instead of raising when it cannot read the
            # file; don't overwrite a working index with None in that case.
            if index is None:
                raise ValueError(f"Could not restore a USEARCH index from '{index_path}'.")
            self.int8_index = index
            return self.int8_index
        except Exception as e:
            print(f"An error occurred while loading USEARCH index: {e}")
            return None

    def load_faiss_index(self, index_path: str) -> Optional[Any]:
        """
        Load a FAISS binary index from disk into ``self.binary_index``.

        Parameters
        ----------
        index_path : str
            The file path of the saved FAISS binary index.

        Returns
        -------
        FAISS binary index or None
            The loaded index, or None on error (the error is printed).

        Notes
        -----
        Ensure the saved index is compatible with this instance's configuration
        (e.g. ``ndim``); this is not checked.
        """
        try:
            self.binary_index = faiss.read_index_binary(index_path)
            return self.binary_index
        except Exception as e:
            print(f"An error occurred while loading the FAISS index: {e}")
            return None

    def search(
        self,
        query: str,
        top_k: int = 10,
        rescore_multiplier: int = 4,
    ) -> Optional[Tuple[List[float], List[int]]]:
        """
        Search the indexed corpus: binary retrieval followed by int8 rescoring.

        Parameters
        ----------
        query : str
            The query sentence.
        top_k : int, optional
            The number of results to return. Default is 10.
        rescore_multiplier : int, optional
            The binary stage retrieves ``top_k * rescore_multiplier`` candidates
            (capped at the index size) for rescoring. Higher values can improve
            precision at the cost of speed. Default is 4.

        Returns
        -------
        tuple of (list of float, list of int) or None
            Scores and corpus row ids of the best results, best first. Scores are
            raw dot products between the unnormalized float query and the int8
            embeddings, not cosine similarities. Returns ``([], [])`` when there
            is nothing to return, and None on error (the error is printed).

        Notes
        -----
        Both ``binary_index`` and ``int8_index`` must already be created or loaded.
        """
        try:
            if self.binary_index is None or self.int8_index is None:
                raise ValueError("Indices must be loaded or created before searching.")

            # FAISS pads results beyond ntotal with id -1, so never ask for more
            # candidates than the index actually holds.
            n_candidates = min(top_k * rescore_multiplier, self.binary_index.ntotal)
            if top_k <= 0 or n_candidates <= 0:
                return [], []

            # Call the model and quantizer directly: the wrapper methods turn failures
            # into None, which would resurface below as an unrelated AttributeError.
            # Normalization is skipped; a positive rescale of the query changes
            # neither its sign pattern nor the dot-product ranking.
            query_embedding = self.model.encode(
                query,
                normalize_embeddings=False,
                show_progress_bar=False,
            )
            query_embedding_ubinary = quantize_embeddings(
                query_embedding.reshape(1, -1),
                "ubinary",
            )

            _distances, binary_ids = self.binary_index.search(
                query_embedding_ubinary,
                n_candidates,
            )
            candidate_ids = self._valid_candidate_ids(binary_ids)
            if candidate_ids.size == 0:
                return [], []

            # Keys in the int8 index are corpus row ids (see create_usearch_index).
            int8_embeddings = self.int8_index[candidate_ids].astype(int)
            return self._rescore(query_embedding, candidate_ids, int8_embeddings, top_k)
        except Exception as e:
            print(f"An error occurred while searching semantic similar sentences: {e}")
            return None

    @staticmethod
    def _valid_candidate_ids(binary_ids: np.ndarray) -> np.ndarray:
        """Return the first query row of FAISS result ids with any ``-1`` padding removed."""
        first_row = np.asarray(binary_ids)[0]
        return first_row[first_row >= 0]

    @staticmethod
    def _rescore(
        query_embedding: np.ndarray,
        candidate_ids: np.ndarray,
        candidate_embeddings: np.ndarray,
        top_k: int,
    ) -> Tuple[List[float], List[int]]:
        """
        Rank candidates by dot product with the query and keep the best ``top_k``.

        ``candidate_embeddings[i]`` must belong to ``candidate_ids[i]``. Ties keep
        their incoming order (stable sort).
        """
        scores = query_embedding @ candidate_embeddings.T
        order = np.argsort(-scores, kind="stable")[:top_k]
        return scores[order].tolist(), candidate_ids[order].tolist()

    def _quantize_to_ubinary(self, embeddings: np.ndarray) -> Optional[np.ndarray]:
        """
        Quantize embeddings with sentence-transformers' "ubinary" scheme.

        Returns the quantized array, or None on error (the error is printed).
        """
        try:
            return quantize_embeddings(embeddings, "ubinary")
        except Exception as e:
            print(f"An error occurred during ubinary quantization: {e}")
            return None

    def _quantize_to_int8(self, embeddings: np.ndarray) -> Optional[np.ndarray]:
        """
        Quantize embeddings with sentence-transformers' "int8" scheme.

        No calibration data is passed, so the quantization ranges are derived by
        the library from ``embeddings`` itself. Returns the quantized array, or
        None on error (the error is printed).
        """
        try:
            return quantize_embeddings(embeddings, "int8")
        except Exception as e:
            print(f"An error occurred during int8 quantization: {e}")
            return None

    def _save_faiss_index_binary(self, index_path: str) -> None:
        """
        Write ``self.binary_index`` to ``index_path``.

        Internal helper for ``create_faiss_index``, which checks that a path was
        given. Errors are printed, not raised.
        """
        try:
            faiss.write_index_binary(self.binary_index, index_path)
        except Exception as e:
            print(f"An error occurred during FAISS binary index saving: {e}")

    def _save_int8_index(self, index_path: str) -> None:
        """
        Write ``self.int8_index`` to ``index_path``.

        Internal helper for ``create_usearch_index``, which checks that a path was
        given. Errors are printed, not raised.
        """
        try:
            self.int8_index.save(index_path)
        except Exception as e:
            print(f"An error occurred during int8_index saving: {e}")