|
import contextlib
import logging
import math
import os
from dataclasses import dataclass
from typing import Callable, List, Optional, Union

import numpy
import torch
from torch.utils.data import DataLoader
from tqdm import tqdm

from relik.common.log import get_logger
from relik.common.utils import is_package_available
from relik.retriever.common.model_inputs import ModelInputs
from relik.retriever.data.base.datasets import BaseDataset
from relik.retriever.data.labels import Labels
from relik.retriever.indexers.base import BaseDocumentIndex
from relik.retriever.pytorch_modules import PRECISION_MAP, RetrievedSample
from relik.retriever.pytorch_modules.model import GoldenRetriever

# faiss is an optional dependency: import it (and its torch bridge, which lets
# faiss indexes accept torch tensors) only when it is available
if is_package_available("faiss"):
    import faiss
    import faiss.contrib.torch_utils

logger = get_logger(__name__, level=logging.INFO)


@dataclass
class FaissOutput:
    indices: Union[torch.Tensor, numpy.ndarray]
    distances: Union[torch.Tensor, numpy.ndarray]


class FaissDocumentIndex(BaseDocumentIndex):
    DOCUMENTS_FILE_NAME = "documents.json"
    EMBEDDINGS_FILE_NAME = "embeddings.pt"
    INDEX_FILE_NAME = "index.faiss"

    def __init__(
        self,
        documents: Union[List[str], Labels],
        embeddings: Optional[Union[torch.Tensor, numpy.ndarray]] = None,
        index=None,
        index_type: str = "Flat",
        metric: int = faiss.METRIC_INNER_PRODUCT,
        normalize: bool = False,
        device: str = "cpu",
        name_or_dir: Optional[Union[str, os.PathLike]] = None,
        *args,
        **kwargs,
    ) -> None:
        super().__init__(documents, embeddings, name_or_dir)

        if embeddings is not None and documents is not None:
            logger.info("Both documents and embeddings are provided.")
            if documents.get_label_size() != embeddings.shape[0]:
                raise ValueError(
                    "The number of documents and embeddings must be the same."
                )

        self.device = device
        self.index_type = index_type
        self.metric = metric
        self.normalize = normalize

        if index is not None:
            # a pre-built faiss index was passed in: use it as-is
            self.embeddings = index
            if self.device == "cuda":
                # move the index to the GPU
                faiss_resource = faiss.StandardGpuResources()
                self.embeddings = faiss.index_cpu_to_gpu(
                    faiss_resource, 0, self.embeddings
                )
        else:
            if embeddings is not None:
                # no index provided: build one from the given embeddings
                logger.info("Building the index from the embeddings.")
                self.embeddings = self._build_faiss_index(
                    embeddings=embeddings,
                    index_type=index_type,
                    normalize=normalize,
                    metric=metric,
                )

    def _build_faiss_index(
        self,
        embeddings: Optional[Union[torch.Tensor, numpy.ndarray]],
        index_type: str,
        normalize: bool,
        metric: int,
    ):
        # delegate L2-normalization to faiss (via an "L2norm" pre-transform) only
        # for inner-product search on numpy inputs
        self.normalize = (
            normalize
            and metric == faiss.METRIC_INNER_PRODUCT
            and not isinstance(embeddings, torch.Tensor)
        )
        if self.normalize:
            index_type = f"L2norm,{index_type}"
        faiss_vector_size = embeddings.shape[1]
        if self.device == "cpu":
            # on CPU, use an HNSW-based coarse quantizer and scale the number of
            # centroids with the embedding size
            index_type = index_type.replace("x,", "x_HNSW32,")
            index_type = index_type.replace(
                "x", str(math.ceil(math.sqrt(faiss_vector_size)) * 4)
            )
        self.embeddings = faiss.index_factory(faiss_vector_size, index_type, metric)

        if self.device == "cuda":
            # move the index to the GPU
            faiss_resource = faiss.StandardGpuResources()
            self.embeddings = faiss.index_cpu_to_gpu(faiss_resource, 0, self.embeddings)
        else:
            # faiss CPU indexes expect the vectors on the host
            embeddings = (
                embeddings.cpu() if isinstance(embeddings, torch.Tensor) else embeddings
            )

        # faiss expects float32 vectors: upcast half-precision embeddings before adding
        if isinstance(embeddings, torch.Tensor) and embeddings.dtype == torch.float16:
            embeddings = embeddings.float()

        self.embeddings.add(embeddings)

        # save the parameters actually used to build the index
        self.index_type = index_type
        self.metric = metric

        # free up memory from the original embeddings
        embeddings = None

        return self.embeddings
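
    # A worked example of the factory-string rewriting above, assuming a
    # hypothetical 768-dimensional embedding matrix with index_type="IVFx,Flat"
    # on CPU: "IVFx,Flat" -> "IVFx_HNSW32,Flat" -> "IVF112_HNSW32,Flat",
    # since ceil(sqrt(768)) * 4 == 112.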
|
|
|
    @torch.no_grad()
    @torch.inference_mode()
    def index(
        self,
        retriever: GoldenRetriever,
        documents: Optional[List[str]] = None,
        batch_size: int = 32,
        num_workers: int = 4,
        max_length: Optional[int] = None,
        collate_fn: Optional[Callable] = None,
        encoder_precision: Optional[Union[str, int]] = None,
        compute_on_cpu: bool = False,
        force_reindex: bool = False,
        *args,
        **kwargs,
    ) -> "FaissDocumentIndex":
        """
        Index the documents using the encoder.

        Args:
            retriever (:obj:`GoldenRetriever`):
                The retriever whose passage encoder is used for indexing.
            documents (:obj:`List[str]`, `optional`, defaults to None):
                The documents to be indexed.
            batch_size (:obj:`int`, `optional`, defaults to 32):
                The batch size to be used for indexing.
            num_workers (:obj:`int`, `optional`, defaults to 4):
                The number of workers to be used for indexing.
            max_length (:obj:`int`, `optional`, defaults to None):
                The maximum length of the input to the encoder.
            collate_fn (:obj:`Callable`, `optional`, defaults to None):
                The collate function to be used for batching.
            encoder_precision (:obj:`Union[str, int]`, `optional`, defaults to None):
                The precision to be used for the encoder.
            compute_on_cpu (:obj:`bool`, `optional`, defaults to False):
                Whether to compute the embeddings on CPU.
            force_reindex (:obj:`bool`, `optional`, defaults to False):
                Whether to force reindexing.

        Returns:
            :obj:`FaissDocumentIndex`: The index itself.
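
        Example:
            A minimal sketch, assuming ``retriever`` is an already-loaded
            :obj:`GoldenRetriever`::

                index = FaissDocumentIndex(documents=["first passage", "second passage"])
                index.index(retriever, batch_size=8, force_reindex=True)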
|
""" |
|
|
|
if self.embeddings is not None and not force_reindex: |
|
logger.log( |
|
"Embeddings are already present and `force_reindex` is `False`. Skipping indexing." |
|
) |
|
if documents is None: |
|
return self |
|
|
|
|
|
if collate_fn is None: |
|
tokenizer = retriever.passage_tokenizer |
|
|
|
def collate_fn(x): |
|
return ModelInputs( |
|
tokenizer( |
|
x, |
|
padding=True, |
|
return_tensors="pt", |
|
truncation=True, |
|
max_length=max_length or tokenizer.model_max_length, |
|
) |
|
) |
|
|
|
if force_reindex: |
|
if documents is not None: |
|
self.documents.add_labels(documents) |
|
data = [k for k in self.documents.get_labels()] |
|
|
|
else: |
|
if documents is not None: |
|
data = [k for k in Labels(documents).get_labels()] |
|
else: |
|
return self |
|
|
|
dataloader = DataLoader( |
|
BaseDataset(name="passage", data=data), |
|
batch_size=batch_size, |
|
shuffle=False, |
|
num_workers=num_workers, |
|
pin_memory=False, |
|
collate_fn=collate_fn, |
|
) |
|
|
|
encoder = retriever.passage_encoder |
|
|
|
|
|
passage_embeddings: List[torch.Tensor] = [] |
|
|
|
encoder_device = "cpu" if compute_on_cpu else self.device |
|
|
|
|
|
|
|
device_type_for_autocast = str(encoder_device).split(":")[0] |
|
|
|
autocast_pssg_mngr = ( |
|
contextlib.nullcontext() |
|
if device_type_for_autocast == "cpu" |
|
else ( |
|
torch.autocast( |
|
device_type=device_type_for_autocast, |
|
dtype=PRECISION_MAP[encoder_precision], |
|
) |
|
) |
|
) |
|
with autocast_pssg_mngr: |
|
|
|
for batch in tqdm(dataloader, desc="Indexing"): |
|
|
|
batch: ModelInputs = batch.to(encoder_device) |
|
|
|
passage_outs = encoder(**batch) |
|
|
|
if self.device == "cpu": |
|
passage_embeddings.extend([c.detach().cpu() for c in passage_outs]) |
|
else: |
|
passage_embeddings.extend([c for c in passage_outs]) |
|
|
|
|
|
passage_embeddings = [c.detach().cpu() for c in passage_embeddings] |
|
|
|
passage_embeddings: torch.Tensor = torch.stack(passage_embeddings, dim=0) |
|
|
|
passage_embeddings.to(PRECISION_MAP["float32"]) |
|
|
|
|
|
self.embeddings = self._build_faiss_index( |
|
embeddings=passage_embeddings, |
|
index_type=self.index_type, |
|
normalize=self.normalize, |
|
metric=self.metric, |
|
) |
|
|
|
del passage_embeddings |
|
|
|
return self |
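
    # `index` returns the index itself, so calls can be chained, e.g. (illustrative,
    # assuming `retriever` is a loaded GoldenRetriever and `queries` are query
    # embeddings): FaissDocumentIndex(docs).index(retriever).search(queries, k=5)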
|
|
|
    @torch.no_grad()
    @torch.inference_mode()
    def search(self, query: torch.Tensor, k: int = 1) -> List[List[RetrievedSample]]:
        """
        Search the top ``k`` most similar documents for each query embedding.

        Args:
            query (:obj:`torch.Tensor`):
                The query embeddings.
            k (:obj:`int`, `optional`, defaults to 1):
                The number of documents to retrieve per query.

        Returns:
            :obj:`List[List[RetrievedSample]]`: The retrieved documents.
        """
        # cannot retrieve more documents than are stored in the index
        k = min(k, self.embeddings.ntotal)

        if self.normalize:
            faiss.normalize_L2(query)
        if isinstance(query, torch.Tensor) and self.device == "cpu":
            query = query.detach().cpu()

        # faiss returns (distances, indices)
        retriever_out = self.embeddings.search(query, k)

        batch_top_k: List[List[int]] = retriever_out[1].detach().cpu().tolist()
        batch_scores: List[List[float]] = retriever_out[0].detach().cpu().tolist()

        # map the retrieved indices back to their documents
        batch_passages = [
            [self.documents.get_label_from_index(i) for i in indices]
            for indices in batch_top_k
        ]
        batch_retrieved_samples = [
            [
                RetrievedSample(label=passage, index=index, score=score)
                for passage, index, score in zip(passages, indices, scores)
            ]
            for passages, indices, scores in zip(
                batch_passages, batch_top_k, batch_scores
            )
        ]
        return batch_retrieved_samples
|
    def get_embeddings_from_index(
        self, index: int
    ) -> Union[torch.Tensor, numpy.ndarray]:
        """
        Get the document vector from the index.

        Args:
            index (`int`):
                The index of the document.

        Returns:
            `torch.Tensor`: The document vector.
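
        Example:
            A minimal sketch, assuming the index has already been built::

                vector = document_index.get_embeddings_from_index(0)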
|
""" |
|
if self.embeddings is None: |
|
raise ValueError( |
|
"The documents must be indexed before they can be retrieved." |
|
) |
|
if index >= self.embeddings.ntotal: |
|
raise ValueError( |
|
f"The index {index} is out of bounds. The maximum index is {self.embeddings.ntotal}." |
|
) |
|
return self.embeddings.reconstruct(index) |
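

# A minimal end-to-end sketch of the intended usage, kept as a comment because it
# needs a trained retriever; `retriever` and `query_embeddings` are hypothetical:
#
#     document_index = FaissDocumentIndex(
#         documents=["first passage", "second passage", "third passage"],
#         index_type="Flat",
#     )
#     document_index.index(retriever, batch_size=2, force_reindex=True)
#     results = document_index.search(query_embeddings, k=2)
#     for sample in results[0]:
#         print(sample.label, sample.score)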
|
|