import os
from pathlib import Path
from typing import Any, Callable, Dict, Optional, Union

import hydra
from omegaconf import OmegaConf
from relik.retriever.pytorch_modules.hf import GoldenRetrieverModel
from rich.pretty import pprint

from relik.common.log import get_console_logger, get_logger
from relik.common.upload import upload
from relik.common.utils import CONFIG_NAME, from_cache, get_callable_from_string
from relik.inference.data.objects import EntitySpan, RelikOutput
from relik.inference.data.tokenizers.spacy_tokenizer import SpacyTokenizer
from relik.inference.data.window.manager import WindowManager
from relik.reader.pytorch_modules.span import RelikReaderForSpanExtraction
from relik.reader.relik_reader import RelikReader
from relik.retriever.data.utils import batch_generator
from relik.retriever.indexers.base import BaseDocumentIndex
from relik.retriever.pytorch_modules.model import GoldenRetriever

logger = get_logger(__name__)
console_logger = get_console_logger()


def _identity_preprocessing(candidate: Any) -> Any:
    """Default `candidates_preprocessing_fn`: return the candidate unchanged."""
    return candidate


class Relik:
    """
    Relik main class. It is a wrapper around a retriever and a reader.

    Args:
        retriever (`GoldenRetriever`, `optional`):
            The retriever to use. If `None`, a retriever will be instantiated from the
            provided `question_encoder`, `passage_encoder` and `document_index`.
            Defaults to `None`.
        question_encoder (`str` or `GoldenRetrieverModel`, `optional`):
            The question encoder to use. If `retriever` is `None`, a retriever will be
            instantiated from this parameter (in that case it is required).
            Defaults to `None`.
        passage_encoder (`str` or `GoldenRetrieverModel`, `optional`):
            The passage encoder to use. Only used when `retriever` is `None`.
            Defaults to `None`.
        document_index (`str` or `BaseDocumentIndex`, `optional`):
            The document index to use. Only used when `retriever` is `None`.
            Defaults to `None`.
        reader (`str` or `RelikReader`, `optional`):
            The reader to use. If a string is provided, the reader is loaded with
            `RelikReaderForSpanExtraction(reader, **reader_kwargs)`. Defaults to `None`.
        device (`str`, `optional`, defaults to `cpu`):
            Default device, used for every component whose specific device is `None`.
        retriever_device (`str`, `optional`):
            Device of the retriever. Defaults to `device`.
        document_index_device (`str`, `optional`):
            Device of the document index. Defaults to `device`.
        reader_device (`str`, `optional`):
            Device of the reader. Defaults to `device`. It is stored on the instance
            but not forwarded when the reader is built from a string here.
        precision (`int`, `optional`, defaults to `32`):
            Default precision, used for every component whose specific precision is `None`.
        retriever_precision (`int`, `optional`):
            Precision of the retriever. Defaults to `precision`.
        document_index_precision (`int`, `optional`):
            Precision of the document index. Defaults to `precision`.
        reader_precision (`int`, `optional`):
            Precision of the reader. Defaults to `precision`. Stored on the instance
            only.
        reader_kwargs (`dict`, `optional`):
            Extra keyword arguments used when the reader is built from a string.
        retriever_kwargs (`dict`, `optional`):
            Extra keyword arguments for `GoldenRetriever`; they override the ones built
            from the parameters above. Only used when `retriever` is `None`.
        candidates_preprocessing_fn (`str` or `Callable`, `optional`):
            Function applied to each retrieved candidate label before it is given to
            the reader. A string is resolved with `get_callable_from_string`.
            Defaults to the identity function.
        top_k (`int`, `optional`):
            Default number of candidates retrieved per window in `__call__`.
        window_size (`int`, `optional`):
            Default window size used in `__call__`.
        window_stride (`int`, `optional`):
            Default window stride used in `__call__`.
        **kwargs:
            Ignored.
    """

    def __init__(
        self,
        retriever: GoldenRetriever | None = None,
        question_encoder: str | GoldenRetrieverModel | None = None,
        passage_encoder: str | GoldenRetrieverModel | None = None,
        document_index: str | BaseDocumentIndex | None = None,
        reader: str | RelikReader | None = None,
        device: str = "cpu",
        retriever_device: str | None = None,
        document_index_device: str | None = None,
        reader_device: str | None = None,
        precision: int = 32,
        retriever_precision: int | None = None,
        document_index_precision: int | None = None,
        reader_precision: int | None = None,
        reader_kwargs: dict | None = None,
        retriever_kwargs: dict | None = None,
        candidates_preprocessing_fn: str | Callable | None = None,
        top_k: int | None = None,
        window_size: int | None = None,
        window_stride: int | None = None,
        **kwargs,
    ) -> None:
        # resolve per-component device/precision, falling back to the global ones
        retriever_device = retriever_device or device
        document_index_device = document_index_device or device
        retriever_precision = retriever_precision or precision
        document_index_precision = document_index_precision or precision

        # retriever
        if retriever is None and question_encoder is None:
            raise ValueError(
                "Either `retriever` or `question_encoder` must be provided"
            )
        if retriever is None:
            self.retriever_kwargs = dict(
                question_encoder=question_encoder,
                passage_encoder=passage_encoder,
                document_index=document_index,
                device=retriever_device,
                precision=retriever_precision,
                index_device=document_index_device,
                index_precision=document_index_precision,
            )
            # overwrite default_retriever_kwargs with retriever_kwargs
            self.retriever_kwargs.update(retriever_kwargs or {})
            retriever = GoldenRetriever(**self.retriever_kwargs)
            retriever.training = False
            retriever.eval()
        else:
            # keep the attribute defined so that `save_pretrained` also works
            # when a pre-built retriever is passed in
            self.retriever_kwargs = retriever_kwargs
        self.retriever = retriever

        # reader
        self.reader_device = reader_device or device
        self.reader_precision = reader_precision or precision
        self.reader_kwargs = reader_kwargs
        if isinstance(reader, str):
            reader_kwargs = reader_kwargs or {}
            reader = RelikReaderForSpanExtraction(reader, **reader_kwargs)
        self.reader = reader

        # windowization stuff (the window manager is created lazily in `__call__`)
        self.tokenizer = SpacyTokenizer(language="en")
        self.window_manager: WindowManager | None = None

        # candidates preprocessing
        # TODO: maybe move this logic somewhere else
        # a named default (instead of a lambda) lets `save_pretrained` recognise it
        candidates_preprocessing_fn = (
            candidates_preprocessing_fn or _identity_preprocessing
        )
        if isinstance(candidates_preprocessing_fn, str):
            candidates_preprocessing_fn = get_callable_from_string(
                candidates_preprocessing_fn
            )
        self.candidates_preprocessing_fn = candidates_preprocessing_fn

        # inference params
        self.top_k = top_k
        self.window_size = window_size
        self.window_stride = window_stride

    def __call__(
        self,
        text: Union[str, list],
        top_k: Optional[int] = None,
        window_size: Optional[Union[int, str]] = None,
        window_stride: Optional[int] = None,
        retriever_batch_size: Optional[int] = 32,
        reader_batch_size: Optional[int] = 32,
        return_also_windows: bool = False,
        **kwargs,
    ) -> Union[RelikOutput, list[RelikOutput]]:
        """
        Annotate a text with entities.

        Args:
            text (`str` or `list`):
                The text to annotate. If a list is provided, each element of the list
                will be annotated separately.
            top_k (`int`, `optional`, defaults to `None`):
                The number of candidates to retrieve for each window. If `None`,
                `self.top_k` is used, or 100 when that is also unset.
            window_size (`int`, `optional`, defaults to `None`):
                The size of the window. If `None`, `self.window_size` is used; when
                both are `None` the value is passed as-is to the window manager,
                which is expected to treat the whole text as a single window.
                `"sentence"` is not supported yet.
            window_stride (`int`, `optional`, defaults to `None`):
                The stride of the window. If `None`, `self.window_stride` is used.
            retriever_batch_size (`int`, `optional`, defaults to `32`):
                The batch size to use for the retriever.
            reader_batch_size (`int`, `optional`, defaults to `32`):
                The maximum batch size to use for the reader.
            return_also_windows (`bool`, `optional`, defaults to `False`):
                Whether to attach the (merged) windows of each document to its output.
            **kwargs:
                Currently ignored.

        Returns:
            `RelikOutput` or `list[RelikOutput]`:
                The annotated text. If a string was provided, a single `RelikOutput`
                is returned; if a list was provided, a list of `RelikOutput` objects
                is returned (even when the list has one element).

        Raises:
            NotImplementedError: if `window_size` is `"sentence"`.
        """
        if top_k is None:
            top_k = self.top_k or 100
        if window_size is None:
            window_size = self.window_size
        if window_stride is None:
            window_stride = self.window_stride

        # remember the input form so that list inputs always get a list back
        single_input = isinstance(text, str)
        if single_input:
            text = [text]

        if window_size == "sentence":
            # TODO: implement sentence windowizer
            raise NotImplementedError("Sentence windowizer not implemented yet")
        # TODO: consider validating that window_size >= window_stride

        # the window manager is needed both to split and to merge windows,
        # so it must exist even when `window_size` is None
        if self.window_manager is None:
            self.window_manager = WindowManager(self.tokenizer)

        # window generator
        windows = [
            window
            for doc_id, t in enumerate(text)
            for window in self.window_manager.create_windows(
                t,
                window_size=window_size,
                stride=window_stride,
                doc_id=doc_id,
            )
        ]

        # retrieve candidates first
        windows_candidates = []
        # TODO: Move batching inside retriever
        for batch in batch_generator(windows, batch_size=retriever_batch_size):
            retriever_out = self.retriever.retrieve([b.text for b in batch], k=top_k)
            windows_candidates.extend(
                [[p.label for p in predictions] for predictions in retriever_out]
            )

        # add passage to the windows
        for window, candidates in zip(windows, windows_candidates):
            window.window_candidates = [
                self.candidates_preprocessing_fn(c) for c in candidates
            ]

        read_windows = self.reader.read(
            samples=windows, max_batch_size=reader_batch_size
        )
        merged_windows = self.window_manager.merge_windows(read_windows)

        # transform predictions into RelikOutput objects
        output = []
        for merged in merged_windows:
            doc_text = text[merged.doc_id]
            sample_output = RelikOutput(
                text=doc_text,
                labels=sorted(
                    [
                        EntitySpan(start=ss, end=se, label=sl, text=doc_text[ss:se])
                        for ss, se, sl in merged.predicted_window_labels_chars
                    ],
                    key=lambda span: span.start,
                ),
            )
            if return_also_windows:
                # match by doc_id rather than by position in the output list
                sample_output.windows = [
                    w for w in merged_windows if w.doc_id == merged.doc_id
                ]
            output.append(sample_output)

        if single_input:
            # a single string yields a single RelikOutput; guard against a
            # document that produced no windows at all
            return output[0] if output else RelikOutput(text=text[0], labels=[])
        return output

    @classmethod
    def from_pretrained(
        cls,
        model_name_or_dir: Union[str, os.PathLike],
        config_kwargs: Optional[Dict] = None,
        config_file_name: str = CONFIG_NAME,
        *args,
        **kwargs,
    ) -> "Relik":
        """
        Instantiate Relik from a saved configuration.

        Args:
            model_name_or_dir (`str` or `os.PathLike`):
                A local directory or a name resolvable by `from_cache`.
            config_kwargs (`dict`, `optional`):
                Values merged on top of the loaded configuration.
            config_file_name (`str`, `optional`, defaults to `CONFIG_NAME`):
                Name of the configuration file inside the model directory.
            *args, **kwargs:
                Forwarded to `hydra.utils.instantiate`, except `cache_dir` and
                `force_download`, which are consumed and passed to `from_cache`.

        Returns:
            `Relik`: the instantiated object.

        Raises:
            FileNotFoundError: if the configuration file does not exist.
        """
        cache_dir = kwargs.pop("cache_dir", None)
        force_download = kwargs.pop("force_download", False)

        model_dir = from_cache(
            model_name_or_dir,
            filenames=[config_file_name],
            cache_dir=cache_dir,
            force_download=force_download,
        )

        config_path = model_dir / config_file_name
        if not config_path.exists():
            raise FileNotFoundError(
                f"Model configuration file not found at {config_path}."
            )

        # overwrite config with config_kwargs
        config = OmegaConf.load(config_path)
        if config_kwargs is not None:
            # TODO: check merging behavior
            config = OmegaConf.merge(config, OmegaConf.create(config_kwargs))
        # do we want to print the config? I like it
        pprint(OmegaConf.to_container(config), console=console_logger, expand_all=True)

        # load relik from config
        relik = hydra.utils.instantiate(config, *args, **kwargs)

        return relik

    def save_pretrained(
        self,
        output_dir: Union[str, os.PathLike],
        config: Optional[Dict[str, Any]] = None,
        config_file_name: Optional[str] = None,
        save_weights: bool = False,
        push_to_hub: bool = False,
        model_id: Optional[str] = None,
        organization: Optional[str] = None,
        repo_name: Optional[str] = None,
        **kwargs,
    ):
        """
        Save the configuration of Relik to the specified directory as a YAML file.

        Args:
            output_dir (`str`):
                The directory to save the configuration file to.
            config (`Optional[Dict[str, Any]]`, `optional`):
                The configuration to save. If `None`, the current configuration will be
                saved. Defaults to `None`.
            config_file_name (`Optional[str]`, `optional`):
                The name of the configuration file. Defaults to `CONFIG_NAME`.
            save_weights (`bool`, `optional`):
                Whether to save the weights of the model. Defaults to `False`.
            push_to_hub (`bool`, `optional`):
                Whether to push the saved model to the hub. Defaults to `False`.
            model_id (`Optional[str]`, `optional`):
                The id of the model to push to the hub. If `None`, the name of the
                directory will be used. Defaults to `None`.
            organization (`Optional[str]`, `optional`):
                The organization to push the model to. Defaults to `None`.
            repo_name (`Optional[str]`, `optional`):
                The name of the repository to push the model to. Defaults to `None`.
            **kwargs:
                Additional keyword arguments forwarded to the retriever's and the
                reader's `save_pretrained`.
        """
        if config is None:
            # create a default config
            config = {
                "_target_": f"{self.__class__.__module__}.{self.__class__.__name__}"
            }
            if self.retriever is not None:
                if self.retriever.question_encoder is not None:
                    config[
                        "question_encoder"
                    ] = self.retriever.question_encoder.name_or_path
                if self.retriever.passage_encoder is not None:
                    config[
                        "passage_encoder"
                    ] = self.retriever.passage_encoder.name_or_path
                if self.retriever.document_index is not None:
                    config["document_index"] = self.retriever.document_index.name_or_dir
            if self.reader is not None:
                config["reader"] = self.reader.model_path

            config["retriever_kwargs"] = self.retriever_kwargs
            config["reader_kwargs"] = self.reader_kwargs

            # expand the fn as to be able to save it and load it later.
            # `None` is mapped back to the identity function by `__init__`.
            fn = self.candidates_preprocessing_fn
            if fn is _identity_preprocessing:
                fn_path = None
            else:
                fn_name = getattr(fn, "__name__", None)
                fn_module = getattr(fn, "__module__", None)
                if fn_name is None or fn_module is None or fn_name.startswith("<"):
                    # lambdas / partials have no importable path
                    logger.warning(
                        f"`candidates_preprocessing_fn` ({fn!r}) has no importable "
                        "name; it will be saved as `None` and replaced by the "
                        "identity function when the config is reloaded."
                    )
                    fn_path = None
                else:
                    fn_path = f"{fn_module}.{fn_name}"
            config["candidates_preprocessing_fn"] = fn_path

            # these are model-specific and should be saved
            config["top_k"] = self.top_k
            config["window_size"] = self.window_size
            config["window_stride"] = self.window_stride

        config_file_name = config_file_name or CONFIG_NAME

        # create the output directory
        output_dir = Path(output_dir)
        output_dir.mkdir(parents=True, exist_ok=True)

        logger.info(f"Saving relik config to {output_dir / config_file_name}")
        # pretty print the config
        pprint(config, console=console_logger, expand_all=True)
        OmegaConf.save(config, output_dir / config_file_name)

        if save_weights:
            model_id = model_id or output_dir.name
            retriever_model_id = model_id + "-retriever"
            # save weights
            logger.info(f"Saving retriever to {output_dir / retriever_model_id}")
            self.retriever.save_pretrained(
                output_dir / retriever_model_id,
                question_encoder_name=retriever_model_id + "-question-encoder",
                passage_encoder_name=retriever_model_id + "-passage-encoder",
                document_index_name=retriever_model_id + "-index",
                push_to_hub=push_to_hub,
                organization=organization,
                repo_name=repo_name,
                **kwargs,
            )
            reader_model_id = model_id + "-reader"
            logger.info(f"Saving reader to {output_dir / reader_model_id}")
            self.reader.save_pretrained(
                output_dir / reader_model_id,
                push_to_hub=push_to_hub,
                organization=organization,
                repo_name=repo_name,
                **kwargs,
            )

        if push_to_hub:
            # push to hub
            logger.info("Pushing to hub")
            model_id = model_id or output_dir.name
            upload(output_dir, model_id, organization=organization, repo_name=repo_name)


def main():
    """Small manual demo: annotate a hard-coded English text and print the result."""
    from pprint import pprint

    relik = Relik(
        question_encoder="riccorl/relik-retriever-aida-blink-pretrain-omniencoder",
        document_index="riccorl/index-relik-retriever-aida-blink-pretrain-omniencoder",
        reader="riccorl/relik-reader-aida-deberta-small",
        device="cuda",
        precision=16,
        top_k=100,
        window_size=32,
        window_stride=16,
        candidates_preprocessing_fn="relik.inference.preprocessing.wikipedia_title_and_openings_preprocessing",
    )

    input_text = """ Bernie Ecclestone, the former boss of Formula One, has admitted fraud after failing to declare more than £400m held in a trust in Singapore. The 92-year-old billionaire did not disclose the trust to the government in July 2015. Appearing at Southwark Crown Court on Thursday, he told the judge "I plead guilty" after having previously pleaded not guilty. Ecclestone had been due to go on trial next month. """

    preds = relik(input_text)
    pprint(preds)


if __name__ == "__main__":
    main()