import os
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Dict, List, Optional, Union

import hydra
import numpy
import torch
from omegaconf import OmegaConf
from rich.pretty import pprint

from relik.common import upload
from relik.common.log import get_console_logger, get_logger
from relik.common.utils import (
    from_cache,
    is_remote_url,
    is_str_a_path,
    relative_to_absolute_path,
    sapienzanlp_model_urls,
)
from relik.retriever.data.labels import Labels

logger = get_logger(__name__)
console_logger = get_console_logger()


@dataclass
class IndexerOutput:
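    """Search output: indices of the retrieved documents and their distances."""
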
    indices: Union[torch.Tensor, numpy.ndarray]
    distances: Union[torch.Tensor, numpy.ndarray]


class BaseDocumentIndex:
    CONFIG_NAME = "config.yaml"
    DOCUMENTS_FILE_NAME = "documents.json"
    EMBEDDINGS_FILE_NAME = "embeddings.pt"

    def __init__(
        self,
        documents: Optional[
            Union[str, List[str], Labels, os.PathLike, List[os.PathLike]]
        ] = None,
        embeddings: Optional[torch.Tensor] = None,
        name_or_dir: Optional[Union[str, os.PathLike]] = None,
    ) -> None:
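        """
        Args:
            documents (`str`, `List[str]`, `Labels` or path-like, `optional`):
                The documents to index: raw strings, paths to files containing one
                document per line, or an already-built `Labels` object. Defaults to `None`.
            embeddings (`Optional[torch.Tensor]`, `optional`):
                Precomputed embeddings for the documents. Defaults to `None`.
            name_or_dir (`Optional[Union[str, os.PathLike]]`, `optional`):
                The name of (or path to) the index. Defaults to `None`.
        """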
        if documents is not None:
            if isinstance(documents, Labels):
                self.documents = documents
            else:
                documents_are_paths = False

                # normalize a single document (or path) to a list
                if not isinstance(documents, list):
                    documents = [documents]

                # check whether we were given paths rather than raw documents
                if isinstance(documents[0], str) or isinstance(
                    documents[0], os.PathLike
                ):
                    documents_are_paths = is_str_a_path(documents[0])

                # read one document per line, deduplicating across files
                if documents_are_paths:
                    logger.info("Loading documents from paths")
                    _documents = []
                    for doc in documents:
                        with open(relative_to_absolute_path(doc)) as f:
                            _documents += [line.strip() for line in f.readlines()]
                    documents = list(set(_documents))

                self.documents = Labels()
                self.documents.add_labels(documents)
        else:
            self.documents = Labels()

        self.embeddings = embeddings
        self.name_or_dir = name_or_dir

    def __iter__(self):
        for i in range(len(self)):
            yield self[i]

    def __len__(self):
        return self.documents.get_label_size()

    def __getitem__(self, index):
        return self.get_passage_from_index(index)

    @property
    def config(self) -> Dict[str, Any]:
        """
        The configuration of the document index.

        Returns:
            `Dict[str, Any]`: The configuration of the document index.
        """

        def obj_to_dict(obj):
            match obj:
                case dict():
                    data = {}
                    for k, v in obj.items():
                        data[k] = obj_to_dict(v)
                    return data

                case list() | tuple():
                    return [obj_to_dict(x) for x in obj]

                case object(__dict__=_):
                    # generic object: record its class as a Hydra `_target_` and
                    # recurse over its public attributes
                    data = {
                        "_target_": f"{obj.__class__.__module__}.{obj.__class__.__name__}",
                    }
                    for k, v in obj.__dict__.items():
                        if not k.startswith("_"):
                            data[k] = obj_to_dict(v)
                    return data

                case _:
                    return obj

        return obj_to_dict(self)

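    # Illustrative shape of `config` for a hypothetical subclass (the names are
    # examples, not part of the library):
    #   {
    #       "_target_": "my_pkg.MyDocumentIndex",
    #       "name_or_dir": "path/to/index",
    #       ...
    #   }
    # This is the dict that `save_pretrained` writes to `config.yaml` and that
    # `from_pretrained` later feeds to `hydra.utils.instantiate`.
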
    def index(
        self,
        retriever,
        *args,
        **kwargs,
    ) -> "BaseDocumentIndex":
        raise NotImplementedError

    def search(self, query: Any, k: int = 1, *args, **kwargs) -> List:
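        """Retrieve the top-`k` documents for the given query. Subclasses must implement this."""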
        raise NotImplementedError

    def get_index_from_passage(self, document: str) -> int:
        """
        Get the index of the given passage.

        Args:
            document (`str`):
                The passage to get the index for.

        Returns:
            `int`: The index of the passage.
        """
        return self.documents.get_index_from_label(document)

    def get_passage_from_index(self, index: int) -> str:
        """
        Get the passage at the given index.

        Args:
            index (`int`):
                The index of the passage.

        Returns:
            `str`: The passage.
        """
        return self.documents.get_label_from_index(index)

    def get_embeddings_from_index(self, index: int) -> torch.Tensor:
        """
        Get the embedding vector at the given index.

        Args:
            index (`int`):
                The index of the passage.

        Returns:
            `torch.Tensor`: The embedding vector.
        """
        if self.embeddings is None:
            raise ValueError(
                "The documents must be indexed before they can be retrieved."
            )
        if index >= self.embeddings.shape[0]:
            raise ValueError(
                f"The index {index} is out of bounds. The maximum index is {self.embeddings.shape[0] - 1}."
            )
        return self.embeddings[index]

    def get_embeddings_from_passage(self, document: str) -> torch.Tensor:
        """
        Get the embedding vector of the given passage.

        Args:
            document (`str`):
                The passage to get the vector for.

        Returns:
            `torch.Tensor`: The embedding vector.
        """
        if self.embeddings is None:
            raise ValueError(
                "The documents must be indexed before they can be retrieved."
            )
        return self.get_embeddings_from_index(self.get_index_from_passage(document))

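    # Illustrative lookup flow (the passage text is hypothetical): on an
    # indexed corpus,
    #   idx = index.get_index_from_passage("some passage")
    #   vec = index.get_embeddings_from_index(idx)
    # is equivalent to index.get_embeddings_from_passage("some passage").
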
    def save_pretrained(
        self,
        output_dir: Union[str, os.PathLike],
        config: Optional[Dict[str, Any]] = None,
        config_file_name: Optional[str] = None,
        document_file_name: Optional[str] = None,
        embedding_file_name: Optional[str] = None,
        push_to_hub: bool = False,
        **kwargs,
    ):
        """
        Save the document index to a directory.

        Args:
            output_dir (`str`):
                The directory to save the index to.
            config (`Optional[Dict[str, Any]]`, `optional`):
                The configuration to save. If `None`, the current configuration of the
                index will be saved. Defaults to `None`.
            config_file_name (`Optional[str]`, `optional`):
                The name of the configuration file. Defaults to `config.yaml`.
            document_file_name (`Optional[str]`, `optional`):
                The name of the document file. Defaults to `documents.json`.
            embedding_file_name (`Optional[str]`, `optional`):
                The name of the embedding file. Defaults to `embeddings.pt`.
            push_to_hub (`bool`, `optional`):
                Whether to push the saved index to the hub. A `model_id` can be passed
                via `kwargs`; it defaults to the name of `output_dir`. Defaults to `False`.
        """
        if config is None:
            config = self.config

        config_file_name = config_file_name or self.CONFIG_NAME
        document_file_name = document_file_name or self.DOCUMENTS_FILE_NAME
        embedding_file_name = embedding_file_name or self.EMBEDDINGS_FILE_NAME

        output_dir = Path(output_dir)
        output_dir.mkdir(parents=True, exist_ok=True)

        logger.info(f"Saving document index to {output_dir}")
        logger.info(f"Saving config to {output_dir / config_file_name}")
        pprint(config, console=console_logger, expand_all=True)
        OmegaConf.save(config, output_dir / config_file_name)

        # save the embeddings
        embedding_path = output_dir / embedding_file_name
        logger.info(f"Saving embeddings to {embedding_path}")
        torch.save(self.embeddings, embedding_path)

        # save the passage index
        documents_path = output_dir / document_file_name
        logger.info(f"Saving passage index to {documents_path}")
        self.documents.save(documents_path)

        logger.info("Saving document index to disk done.")

        if push_to_hub:
            logger.info("Pushing to hub")
            # the model id defaults to the name of the output directory
            model_id = kwargs.pop("model_id", None) or output_dir.name
            upload(output_dir, model_id, **kwargs)

    @classmethod
    def from_pretrained(
        cls,
        name_or_dir: Union[str, os.PathLike],
        device: str = "cpu",
        precision: Optional[str] = None,
        config_file_name: Optional[str] = None,
        document_file_name: Optional[str] = None,
        embedding_file_name: Optional[str] = None,
        config_kwargs: Optional[Dict[str, Any]] = None,
        *args,
        **kwargs,
    ) -> "BaseDocumentIndex":
        cache_dir = kwargs.pop("cache_dir", None)
        force_download = kwargs.pop("force_download", False)

        config_file_name = config_file_name or cls.CONFIG_NAME
        document_file_name = document_file_name or cls.DOCUMENTS_FILE_NAME
        embedding_file_name = embedding_file_name or cls.EMBEDDINGS_FILE_NAME

        model_dir = from_cache(
            name_or_dir,
            filenames=[config_file_name, document_file_name, embedding_file_name],
            cache_dir=cache_dir,
            force_download=force_download,
        )

        config_path = model_dir / config_file_name
        if not config_path.exists():
            raise FileNotFoundError(
                f"Model configuration file not found at {config_path}."
            )

        config = OmegaConf.load(config_path)
        if config_kwargs is not None:
            # override the loaded config with user-provided values
            config = OmegaConf.merge(config, OmegaConf.create(config_kwargs))
        pprint(OmegaConf.to_container(config), console=console_logger, expand_all=True)

        # load the passage index
        documents_path = model_dir / document_file_name
        if not documents_path.exists():
            raise ValueError(f"Document file `{documents_path}` does not exist.")
        logger.info(f"Loading documents from {documents_path}")
        documents = Labels.from_file(documents_path)

        # load the embeddings, if they were saved
        embedding_path = model_dir / embedding_file_name
        embeddings = None
        if embedding_path.exists():
            logger.info(f"Loading embeddings from {embedding_path}")
            embeddings = torch.load(embedding_path, map_location="cpu")
        else:
            logger.warning(f"Embedding file `{embedding_path}` does not exist.")

        # let Hydra instantiate the concrete index class recorded in the config
        document_index = hydra.utils.instantiate(
            config,
            *args,
            documents=documents,
            embeddings=embeddings,
            device=device,
            precision=precision,
            name_or_dir=name_or_dir,
            **kwargs,
        )

        return document_index


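# A minimal usage sketch (the subclass, retriever and paths are hypothetical):
#
#   index = MyDocumentIndex(documents=["first passage", "second passage"])
#   index.index(retriever)                 # embed the documents
#   index.save_pretrained("my-index")      # config.yaml, documents.json, embeddings.pt
#   index = BaseDocumentIndex.from_pretrained("my-index")
#   results = index.search("a query", k=10)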