import contextlib
import logging
import os
from typing import Callable, List, Optional, Tuple, Union

import torch
from torch.utils.data import DataLoader
from tqdm import tqdm

from relik.common.log import get_logger
from relik.retriever.common.model_inputs import ModelInputs
from relik.retriever.data.base.datasets import BaseDataset
from relik.retriever.data.labels import Labels
from relik.retriever.indexers.base import BaseDocumentIndex
from relik.retriever.pytorch_modules import PRECISION_MAP, RetrievedSample

logger = get_logger(__name__, level=logging.INFO)


class InMemoryDocumentIndex(BaseDocumentIndex):
    DOCUMENTS_FILE_NAME = "documents.json"
    EMBEDDINGS_FILE_NAME = "embeddings.pt"

    def __init__(
        self,
        documents: Optional[
            Union[str, List[str], Labels, os.PathLike, List[os.PathLike]]
        ] = None,
        embeddings: Optional[torch.Tensor] = None,
        device: str = "cpu",
        precision: Optional[str] = None,
        name_or_dir: Optional[Union[str, os.PathLike]] = None,
        *args,
        **kwargs,
    ) -> None:
""" |
|
An in-memory indexer. |
|
|
|
Args: |
|
documents (:obj:`Union[List[str], PassageManager]`): |
|
The documents to be indexed. |
|
embeddings (:obj:`Optional[torch.Tensor]`, `optional`, defaults to :obj:`None`): |
|
The embeddings of the documents. |
|
device (:obj:`str`, `optional`, defaults to "cpu"): |
|
The device to be used for storing the embeddings. |
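
        Example:
            A minimal construction sketch (the passage texts and the embedding
            size are illustrative, not part of the library)::

                import torch

                from relik.retriever.data.labels import Labels

                docs = Labels(["first passage", "second passage"])
                # One embedding row per document; 32 is an arbitrary hidden size.
                embeddings = torch.randn(docs.get_label_size(), 32)
                index = InMemoryDocumentIndex(
                    documents=docs, embeddings=embeddings, device="cpu"
                )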
        """
        super().__init__(documents, embeddings, name_or_dir)

        if embeddings is not None and documents is not None:
            logger.info("Both documents and embeddings are provided.")
            if documents.get_label_size() != embeddings.shape[0]:
                raise ValueError(
                    "The number of documents and embeddings must be the same."
                )

        self.embeddings = embeddings
        del embeddings

        if precision is not None:
            if (
                self.embeddings is not None
                and self.embeddings.dtype != PRECISION_MAP[precision]
            ):
                logger.info(
                    f"Index vectors are of type {self.embeddings.dtype}. "
                    f"Converting to {PRECISION_MAP[precision]}."
                )
                self.embeddings = self.embeddings.to(PRECISION_MAP[precision])
        else:
            # No precision requested: on CPU, make sure the embeddings are float32.
            if (
                device == "cpu"
                and self.embeddings is not None
                and self.embeddings.dtype != torch.float32
            ):
                logger.info(
                    f"Index vectors are of type {self.embeddings.dtype}. "
                    "Converting to float32."
                )
                self.embeddings = self.embeddings.to(PRECISION_MAP[32])

        # Move the embeddings to the requested device.
        if self.embeddings is not None and self.embeddings.device != device:
            self.embeddings = self.embeddings.to(device)

        self.device = device
        self.precision = precision

    @torch.inference_mode()
    def index(
        self,
        retriever,
        documents: Optional[List[str]] = None,
        batch_size: int = 32,
        num_workers: int = 4,
        max_length: Optional[int] = None,
        collate_fn: Optional[Callable] = None,
        encoder_precision: Optional[Union[str, int]] = None,
        compute_on_cpu: bool = False,
        force_reindex: bool = False,
        add_to_existing_index: bool = False,
    ) -> "InMemoryDocumentIndex":
""" |
|
Index the documents using the encoder. |
|
|
|
Args: |
|
retriever (:obj:`torch.nn.Module`): |
|
The encoder to be used for indexing. |
|
documents (:obj:`List[str]`, `optional`, defaults to :obj:`None`): |
|
The documents to be indexed. |
|
batch_size (:obj:`int`, `optional`, defaults to 32): |
|
The batch size to be used for indexing. |
|
num_workers (:obj:`int`, `optional`, defaults to 4): |
|
The number of workers to be used for indexing. |
|
max_length (:obj:`int`, `optional`, defaults to None): |
|
The maximum length of the input to the encoder. |
|
collate_fn (:obj:`Callable`, `optional`, defaults to None): |
|
The collate function to be used for batching. |
|
encoder_precision (:obj:`Union[str, int]`, `optional`, defaults to None): |
|
The precision to be used for the encoder. |
|
compute_on_cpu (:obj:`bool`, `optional`, defaults to False): |
|
Whether to compute the embeddings on CPU. |
|
force_reindex (:obj:`bool`, `optional`, defaults to False): |
|
Whether to force reindexing. |
|
add_to_existing_index (:obj:`bool`, `optional`, defaults to False): |
|
Whether to add the new documents to the existing index. |
|
|
|
Returns: |
|
:obj:`InMemoryIndexer`: The indexer object. |
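
        Example:
            A minimal sketch (``my_retriever`` is a placeholder for a retriever
            exposing ``passage_encoder`` and ``passage_tokenizer``, as this
            method expects)::

                index = index.index(
                    my_retriever,
                    documents=["a brand new passage"],
                    batch_size=32,
                    force_reindex=True,
                )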
        """
        if documents is None and self.documents is None:
            raise ValueError("Documents must be provided.")

        if self.embeddings is not None and not force_reindex:
            logger.info(
                "Embeddings are already present and `force_reindex` is `False`. Skipping indexing."
            )
            if documents is None:
                return self

        if collate_fn is None:
            # Default collate function: tokenize the batch with the passage tokenizer.
            tokenizer = retriever.passage_tokenizer

            def collate_fn(x):
                return ModelInputs(
                    tokenizer(
                        x,
                        padding=True,
                        return_tensors="pt",
                        truncation=True,
                        max_length=max_length or tokenizer.model_max_length,
                    )
                )

        # Collect the texts to encode: every stored document on a forced reindex,
        # otherwise only the newly provided ones.
        if force_reindex:
            if documents is not None:
                self.documents.add_labels(documents)
            data = list(self.documents.get_labels())
        else:
            if documents is not None:
                data = list(Labels(documents).get_labels())
            else:
                return self

        dataloader = DataLoader(
            BaseDataset(name="passage", data=data),
            batch_size=batch_size,
            shuffle=False,
            num_workers=num_workers,
            pin_memory=False,
            collate_fn=collate_fn,
        )

        encoder = retriever.passage_encoder

        passage_embeddings: List[torch.Tensor] = []

        # Encode on CPU if requested, otherwise on the index device.
        encoder_device = "cpu" if compute_on_cpu else self.device

        # Only use autocast on accelerator devices; on CPU, run in full precision.
        device_type_for_autocast = str(encoder_device).split(":")[0]
        autocast_pssg_mngr = (
            contextlib.nullcontext()
            if device_type_for_autocast == "cpu"
            else torch.autocast(
                device_type=device_type_for_autocast,
                dtype=PRECISION_MAP[encoder_precision],
            )
        )
        with autocast_pssg_mngr:
            for batch in tqdm(dataloader, desc="Indexing"):
                batch: ModelInputs = batch.to(encoder_device)
                passage_outs = encoder(**batch).pooler_output
                if self.device == "cpu":
                    # If the index lives on CPU, move embeddings there as they are produced.
                    passage_embeddings.extend([c.detach().cpu() for c in passage_outs])
                else:
                    passage_embeddings.extend([c for c in passage_outs])

        if self.device != "cpu":
            # Stack on CPU first, then move the full tensor back to the index device.
            passage_embeddings = [c.detach().cpu() for c in passage_embeddings]

        # Stack the per-passage embeddings into a single (num_passages, hidden_size) tensor.
        passage_embeddings: torch.Tensor = torch.stack(passage_embeddings, dim=0)
        if self.device != "cpu":
            passage_embeddings = passage_embeddings.to(PRECISION_MAP[self.precision])
            passage_embeddings = passage_embeddings.to(self.device)
        self.embeddings = passage_embeddings
        del passage_embeddings

        return self

    @torch.inference_mode()
    def search(self, query: torch.Tensor, k: int = 1) -> List[List[RetrievedSample]]:
""" |
|
Search the documents using the query. |
|
|
|
Args: |
|
query (:obj:`torch.Tensor`): |
|
The query to be used for searching. |
|
k (:obj:`int`, `optional`, defaults to 1): |
|
The number of documents to be retrieved. |
|
|
|
Returns: |
|
:obj:`List[RetrievedSample]`: The retrieved documents. |
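
        Example:
            A minimal sketch (``query_embeddings`` stands in for the output of
            the retriever's question encoder; the shape is illustrative)::

                query_embeddings = torch.randn(1, index.embeddings.shape[1])
                results = index.search(query_embeddings, k=5)
                for sample in results[0]:
                    print(sample.label, sample.score)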
        """
        # Skip autocast on CPU; elsewhere, compute the similarity in the index dtype.
        device_type_for_autocast = str(self.device).split(":")[0]
        autocast_pssg_mngr = (
            contextlib.nullcontext()
            if device_type_for_autocast == "cpu"
            else torch.autocast(
                device_type=device_type_for_autocast,
                dtype=self.embeddings.dtype,
            )
        )
        with autocast_pssg_mngr:
            # Inner-product similarity between the queries and all indexed passages.
            similarity = torch.matmul(query, self.embeddings.T)

        # Take the top-k passages per query, capping k at the index size.
        retriever_out: Tuple = torch.topk(
            similarity, k=min(k, similarity.shape[-1]), dim=1
        )

        batch_top_k: List[List[int]] = retriever_out.indices.detach().cpu().tolist()
        batch_scores: List[List[float]] = retriever_out.values.detach().cpu().tolist()

        # Map the retrieved indices back to their passage texts.
        batch_passages = [
            [self.documents.get_label_from_index(i) for i in indices]
            for indices in batch_top_k
        ]

        batch_retrieved_samples = [
            [
                RetrievedSample(label=passage, index=index, score=score)
                for passage, index, score in zip(passages, indices, scores)
            ]
            for passages, indices, scores in zip(
                batch_passages, batch_top_k, batch_scores
            )
        ]
        return batch_retrieved_samples