import logging
import os
import time
from pathlib import Path
from typing import Any, Dict, List, Optional, Union

import hydra
import torch
from omegaconf import DictConfig, OmegaConf

from relik.inference.data.splitters.blank_sentence_splitter import BlankSentenceSplitter
from relik.common.log import get_logger
from relik.common.upload import get_logged_in_username, upload
from relik.common.utils import CONFIG_NAME, from_cache
from relik.inference.data.objects import (
    AnnotationType,
    RelikOutput,
    Span,
    TaskType,
    Triples,
)
from relik.inference.data.splitters.base_sentence_splitter import BaseSentenceSplitter
from relik.inference.data.splitters.spacy_sentence_splitter import SpacySentenceSplitter
from relik.inference.data.splitters.window_based_splitter import WindowSentenceSplitter
from relik.inference.data.tokenizers.spacy_tokenizer import SpacyTokenizer
from relik.inference.data.window.manager import WindowManager
from relik.reader.data.relik_reader_sample import RelikReaderSample
from relik.reader.pytorch_modules.base import RelikReaderBase
from relik.reader.pytorch_modules.span import RelikReaderForSpanExtraction
from relik.reader.pytorch_modules.triplet import RelikReaderForTripletExtraction
from relik.retriever.indexers.base import BaseDocumentIndex
from relik.retriever.indexers.document import Document
from relik.retriever.pytorch_modules import PRECISION_MAP
from relik.retriever.pytorch_modules.model import GoldenRetriever

# set tokenizers parallelism to False
os.environ["TOKENIZERS_PARALLELISM"] = os.getenv("TOKENIZERS_PARALLELISM", "false")

LOG_QUERY = os.getenv("RELIK_LOG_QUERY_ON_FILE", "false").lower() == "true"

logger = get_logger(__name__, level=logging.INFO)

file_logger = None
if LOG_QUERY:
    RELIK_LOG_PATH = Path(__file__).parent.parent.parent / "relik.log"
    # create file handler which logs even debug messages
    fh = logging.FileHandler(RELIK_LOG_PATH)
    fh.setLevel(logging.INFO)
    file_logger = get_logger("relik", level=logging.INFO)
    file_logger.addHandler(fh)


class Relik:
    """
    Relik main class. It is a wrapper around a retriever and a reader.

    Args:
        retriever (:obj:`GoldenRetriever`):
            The retriever to use.
        reader (:obj:`RelikReaderBase`):
            The reader to use.
        document_index (:obj:`BaseDocumentIndex`, `optional`):
            The document index to use. If `None`, the retriever's document index will be used.
        device (`str`, `optional`, defaults to `cpu`):
            The device to use for both the retriever and the reader.
        retriever_device (`str`, `optional`, defaults to `None`):
            The device to use for the retriever. If `None`, the `device` argument will be used.
        document_index_device (`str`, `optional`, defaults to `None`):
            The device to use for the document index. If `None`, the `device` argument will be used.
        reader_device (`str`, `optional`, defaults to `None`):
            The device to use for the reader. If `None`, the `device` argument will be used.
        precision (`int`, `str` or `torch.dtype`, `optional`, defaults to `32`):
            The precision to use for both the retriever and the reader.
        retriever_precision (`int`, `str` or `torch.dtype`, `optional`, defaults to `None`):
            The precision to use for the retriever. If `None`, the `precision` argument will be used.
        document_index_precision (`int`, `str` or `torch.dtype`, `optional`, defaults to `None`):
            The precision to use for the document index. If `None`, the `precision` argument will be used.
        reader_precision (`int`, `str` or `torch.dtype`, `optional`, defaults to `None`):
            The precision to use for the reader. If `None`, the `precision` argument will be used.
        metadata_fields (`list[str]`, `optional`, defaults to `None`):
            The fields to add to the candidates for the reader.
        top_k (`int`, `optional`, defaults to `None`):
            The number of candidates to retrieve for each window.
        window_size (`int`, `optional`, defaults to `None`):
            The size of the window. If `None`, the whole text will be annotated.
        window_stride (`int`, `optional`, defaults to `None`):
            The stride of the window. If `None`, there will be no overlap between windows.
        **kwargs:
            Additional keyword arguments to pass to the retriever and the reader.
    """

    def __init__(
        self,
        retriever: GoldenRetriever | DictConfig | Dict | None = None,
        reader: RelikReaderBase | DictConfig | None = None,
        device: str | None = None,
        retriever_device: str | None = None,
        document_index_device: str | None = None,
        reader_device: str | None = None,
        precision: int | str | torch.dtype | None = None,
        retriever_precision: int | str | torch.dtype | None = None,
        document_index_precision: int | str | torch.dtype | None = None,
        reader_precision: int | str | torch.dtype | None = None,
        task: TaskType | str = TaskType.SPAN,
        metadata_fields: list[str] | None = None,
        top_k: int | None = None,
        window_size: int | str | None = None,
        window_stride: int | None = None,
        retriever_kwargs: Dict[str, Any] | None = None,
        reader_kwargs: Dict[str, Any] | None = None,
        **kwargs,
    ) -> None:
        # parse task into a TaskType
        if isinstance(task, str):
            try:
                task = TaskType(task.lower())
            except ValueError:
                raise ValueError(
                    f"Task `{task}` not recognized. "
                    f"Please choose one of {list(TaskType)}."
                )
        self.task = task

        # organize devices
        if device is not None:
            if retriever_device is None:
                retriever_device = device
            if document_index_device is None:
                document_index_device = device
            if reader_device is None:
                reader_device = device

        # organize precision
        if precision is not None:
            if retriever_precision is None:
                retriever_precision = precision
            if document_index_precision is None:
                document_index_precision = precision
            if reader_precision is None:
                reader_precision = precision

        # retriever
        self.retriever: Dict[TaskType, GoldenRetriever] = {
            TaskType.SPAN: None,
            TaskType.TRIPLET: None,
        }
        if retriever:
            # check retriever type, it can be a GoldenRetriever, a DictConfig or a Dict
            if not isinstance(retriever, (GoldenRetriever, DictConfig, Dict)):
                raise ValueError(
                    f"`retriever` must be a `GoldenRetriever`, a `DictConfig` or "
                    f"a `Dict`, got `{type(retriever)}`."
                )
            # we need to check whether the DictConfig is a DictConfig for an
            # instance of GoldenRetriever or a primitive Dict
            if isinstance(retriever, DictConfig):
                # then it is probably a primitive Dict
                if "_target_" not in retriever:
                    retriever = OmegaConf.to_container(retriever, resolve=True)
                    # convert the key to TaskType
                    try:
                        retriever = {
                            TaskType(k.lower()): v for k, v in retriever.items()
                        }
                    except ValueError as e:
                        raise ValueError(
                            f"Please choose a valid task type (one of {list(TaskType)}) for each retriever."
                        ) from e

            if isinstance(retriever, Dict):
                # convert the key to TaskType
                retriever = {TaskType(k): v for k, v in retriever.items()}
            else:
                retriever = {task: retriever}

            # instantiate each retriever
            if self.task in [TaskType.SPAN, TaskType.BOTH]:
                self.retriever[TaskType.SPAN] = self._instantiate_retriever(
                    retriever[TaskType.SPAN],
                    retriever_device,
                    retriever_precision,
                    None,
                    document_index_device,
                    document_index_precision,
                )
            if self.task in [TaskType.TRIPLET, TaskType.BOTH]:
                self.retriever[TaskType.TRIPLET] = self._instantiate_retriever(
                    retriever[TaskType.TRIPLET],
                    retriever_device,
                    retriever_precision,
                    None,
                    document_index_device,
                    document_index_precision,
                )

            # clean up None retrievers from the dictionary
            self.retriever = {
                task_type: r
                for task_type, r in self.retriever.items()
                if r is not None
            }
            # torch compile
            # self.retriever = {task_type: torch.compile(r, backend="onnxrt") for task_type, r in self.retriever.items()}

        # reader
        self.reader: RelikReaderBase | None = None
        if reader:
            reader = (
                hydra.utils.instantiate(
                    reader,
                    device=reader_device,
                    precision=reader_precision,
                )
                if isinstance(reader, DictConfig)
                else reader
            )
            reader.training = False
            reader.eval()
            if reader_device is not None:
                logger.info(f"Moving reader to `{reader_device}`.")
                reader.to(reader_device)
            if (
                reader_precision is not None
                and reader.precision != PRECISION_MAP[reader_precision]
            ):
                logger.info(
                    f"Setting precision of reader to `{PRECISION_MAP[reader_precision]}`."
                )
                reader.to(PRECISION_MAP[reader_precision])
            self.reader = reader
            # self.reader = torch.compile(self.reader, backend="tvm")

        # windowization stuff
        self.tokenizer = SpacyTokenizer(language="en")  # TODO: parametrize?
        self.sentence_splitter: BaseSentenceSplitter | None = None
        self.window_manager: WindowManager | None = None

        if metadata_fields is None:
            metadata_fields = []
        self.metadata_fields = metadata_fields

        # inference params
        self.top_k = top_k
        self.window_size = window_size
        self.window_stride = window_stride

    @staticmethod
    def _instantiate_retriever(
        retriever,
        retriever_device,
        retriever_precision,
        document_index,
        document_index_device,
        document_index_precision,
    ):
        if not isinstance(retriever, GoldenRetriever):
            # convert to DictConfig
            retriever = hydra.utils.instantiate(
                OmegaConf.create(retriever),
                device=retriever_device,
                precision=retriever_precision,
                index_device=document_index_device,
                index_precision=document_index_precision,
            )
        retriever.training = False
        retriever.eval()

        if document_index is not None:
            if retriever.document_index is not None:
                logger.info(
                    "The Retriever already has a document index, replacing it with the provided one. "
                    "If you want to keep using the old one, please do not provide a document index."
                )
            retriever.document_index = document_index

        # we override the device and the precision of the document index if provided
        if document_index_device is not None:
            logger.info(f"Moving document index to `{document_index_device}`.")
            retriever.document_index.to(document_index_device)
        if document_index_precision is not None:
            logger.info(
                f"Setting precision of document index to `{PRECISION_MAP[document_index_precision]}`."
            )
            retriever.document_index.to(PRECISION_MAP[document_index_precision])
        # retriever.document_index = document_index

        # now we can move the retriever to the right device and set the precision
        if retriever_device is not None:
            logger.info(f"Moving retriever to `{retriever_device}`.")
            retriever.to(retriever_device)
        if retriever_precision is not None:
            logger.info(
                f"Setting precision of retriever to `{PRECISION_MAP[retriever_precision]}`."
            )
            retriever.to(PRECISION_MAP[retriever_precision])

        return retriever

    def __call__(
        self,
        text: str | List[str] | None = None,
        windows: List[RelikReaderSample] | None = None,
        candidates: List[str]
        | List[Document]
        | Dict[TaskType, List[Document]]
        | None = None,
        mentions: List[List[int]] | List[List[List[int]]] | None = None,
        top_k: int | None = None,
        window_size: int | None = None,
        window_stride: int | None = None,
        is_split_into_words: bool = False,
        retriever_batch_size: int | None = 32,
        reader_batch_size: int | None = 32,
        return_also_windows: bool = False,
        annotation_type: str | AnnotationType = AnnotationType.CHAR,
        progress_bar: bool = False,
        **kwargs,
    ) -> Union[RelikOutput, list[RelikOutput]]:
        """
        Annotate a text with entities.

        Args:
            text (`str` or `list`):
                The text to annotate. If a list is provided, each element of the list
                will be annotated separately.
            candidates (`list[str]`, `list[Document]`, `optional`, defaults to `None`):
                The candidates to use for the reader. If `None`, the candidates will be
                retrieved from the retriever.
            mentions (`list[list[int]]` or `list[list[list[int]]]`, `optional`, defaults to `None`):
                The mentions to use for the reader. If `None`, the mentions will be
                predicted by the reader.
            top_k (`int`, `optional`, defaults to `None`):
                The number of candidates to retrieve for each window.
            window_size (`int`, `optional`, defaults to `None`):
                The size of the window. If `None`, the whole text will be annotated.
            window_stride (`int`, `optional`, defaults to `None`):
                The stride of the window. If `None`, there will be no overlap between windows.
            retriever_batch_size (`int`, `optional`, defaults to `32`):
                The batch size to use for the retriever. If `None`, the whole input is
                passed to the retriever as a single batch.
            reader_batch_size (`int`, `optional`, defaults to `32`):
                The batch size to use for the reader. If `None`, the whole input is
                passed to the reader as a single batch.
            return_also_windows (`bool`, `optional`, defaults to `False`):
                Whether to return the windows in the output.
            annotation_type (`str` or `AnnotationType`, `optional`, defaults to `char`):
                The type of annotation to return. If `char`, the spans will be in terms
                of character offsets. If `word`, the spans will be in terms of word offsets.
            **kwargs:
                Additional keyword arguments to pass to the retriever and the reader.

        Returns:
            `RelikOutput` or `list[RelikOutput]`:
                The annotated text. If a list was provided as input, a list of
                `RelikOutput` objects will be returned.
        """
        if text is None and windows is None:
            raise ValueError(
                "Either `text` or `windows` must be provided. Both are `None`."
            )

        if isinstance(annotation_type, str):
            try:
                annotation_type = AnnotationType(annotation_type)
            except ValueError:
                raise ValueError(
                    f"Annotation type {annotation_type} not recognized. "
                    f"Please choose one of {list(AnnotationType)}."
                )

        if top_k is None:
            top_k = self.top_k or 100

        if window_size is None:
            window_size = self.window_size

        if window_stride is None:
            window_stride = self.window_stride

        if text:
            if isinstance(text, str):
                text = [text]
                if mentions is not None:
                    mentions = [mentions]

            if file_logger is not None:
                file_logger.info("Annotating the following text:")
                for t in text:
                    file_logger.info(f" {t}")

            if self.window_manager is None:
                if window_size == "none":
                    self.sentence_splitter = BlankSentenceSplitter()
                elif window_size == "sentence":
                    self.sentence_splitter = SpacySentenceSplitter()
                else:
                    self.sentence_splitter = WindowSentenceSplitter(
                        window_size=window_size, window_stride=window_stride
                    )
                self.window_manager = WindowManager(
                    self.tokenizer, self.sentence_splitter
                )

            if (
                window_size not in ["sentence", "none"]
                and window_stride is not None
                and window_size < window_stride
            ):
                raise ValueError(
                    f"Window size ({window_size}) must be greater than or equal to "
                    f"window stride ({window_stride})"
                )

        if windows is None:
            # no windows were provided, so we create them from the input text
            windows, blank_windows = self.window_manager.create_windows(
                text,
                window_size,
                window_stride,
                is_split_into_words=is_split_into_words,
                mentions=mentions,
            )
        else:
            # windows were provided, use them as they are
            blank_windows = []
            text = {w.doc_id: w.text for w in windows}

        if candidates is not None and any(
            r is not None for r in self.retriever.values()
        ):
            logger.info(
                "Both candidates and a retriever were provided. "
                "Retriever will be ignored."
            )

        windows_candidates = {TaskType.SPAN: None, TaskType.TRIPLET: None}
        if candidates is not None:
            # again, check if candidates is a dict
            if isinstance(candidates, Dict):
                if self.task not in candidates:
                    raise ValueError(
                        f"Task `{self.task}` not found in `candidates`. "
                        f"Please choose one of {list(TaskType)}."
                    )
            else:
                candidates = {self.task: candidates}

            for task_type, _candidates in candidates.items():
                if isinstance(_candidates, list):
                    _candidates = [
                        [
                            c if isinstance(c, Document) else Document(c)
                            for c in _candidates[w.doc_id]
                        ]
                        for w in windows
                    ]
                windows_candidates[task_type] = _candidates
        else:
            # retrieve candidates first
            if self.retriever is None:
                raise ValueError(
                    "No retriever was provided, please provide a retriever or candidates."
                )

            start_retr = time.time()
            for task_type, retriever in self.retriever.items():
                retriever_out = retriever.retrieve(
                    [w.text for w in windows],
                    text_pair=[
                        w.doc_topic.text if w.doc_topic is not None else None
                        for w in windows
                    ],
                    k=top_k,
                    batch_size=retriever_batch_size,
                    progress_bar=progress_bar,
                    **kwargs,
                )
                windows_candidates[task_type] = [
                    [p.document for p in predictions] for predictions in retriever_out
                ]
            end_retr = time.time()
            logger.info(f"Retrieval took {end_retr - start_retr} seconds.")

        # clean up None's
        windows_candidates = {
            t: c for t, c in windows_candidates.items() if c is not None
        }

        # add passage to the windows
        for task_type, task_candidates in windows_candidates.items():
            for window, candidates in zip(windows, task_candidates):
                # construct the candidates for the reader
                formatted_candidates = []
                for candidate in candidates:
                    window_candidate_text = candidate.text
                    for field in self.metadata_fields:
                        window_candidate_text += f"{candidate.metadata.get(field, '')}"
                    formatted_candidates.append(window_candidate_text)
                # create a member for the windows that is named like the task
                setattr(window, f"{task_type.value}_candidates", formatted_candidates)

        for task_type, task_candidates in windows_candidates.items():
            for window in blank_windows:
                setattr(window, f"{task_type.value}_candidates", [])
                setattr(window, "predicted_spans", [])
                setattr(window, "predicted_triples", [])

        if self.reader is not None:
            start_read = time.time()
            windows = self.reader.read(
                samples=windows,
                max_batch_size=reader_batch_size,
                annotation_type=annotation_type,
                progress_bar=progress_bar,
                **kwargs,
            )
            end_read = time.time()
            logger.info(f"Reading took {end_read - start_read} seconds.")

            # TODO: check merging behavior without a reader
            # do we want to merge windows if there is no reader?
            if self.window_size is not None and self.window_size not in [
                "sentence",
                "none",
            ]:
                start_w = time.time()
                windows = windows + blank_windows
                windows.sort(key=lambda x: (x.doc_id, x.offset))
                merged_windows = self.window_manager.merge_windows(windows)
                end_w = time.time()
                logger.info(f"Merging took {end_w - start_w} seconds.")
            else:
                merged_windows = windows
        else:
            windows = windows + blank_windows
            windows.sort(key=lambda x: (x.doc_id, x.offset))
            merged_windows = windows

        # transform predictions into RelikOutput objects
        output = []
        for w in merged_windows:
            span_labels = []
            triples_labels = []
            # span extraction should always be present
            if getattr(w, "predicted_spans", None) is not None:
                span_labels = sorted(
                    [
                        Span(start=ss, end=se, label=sl, text=text[w.doc_id][ss:se])
                        if annotation_type == AnnotationType.CHAR
                        else Span(start=ss, end=se, label=sl, text=w.words[ss:se])
                        for ss, se, sl in w.predicted_spans
                    ],
                    key=lambda x: x.start,
                )
            # triple extraction is optional; add it if present
            if getattr(w, "predicted_triples", None) is not None:
                triples_labels = [
                    Triples(
                        subject=span_labels[subj],
                        label=label,
                        object=span_labels[obj],
                        confidence=conf,
                    )
                    for subj, label, obj, conf in w.predicted_triples
                ]
            # create the output
            sample_output = RelikOutput(
                text=text[w.doc_id],
                tokens=w.words,
                spans=span_labels,
                triples=triples_labels,
                candidates={
                    task_type: [
                        r.document_index.documents.get_document_from_text(c)
                        for c in getattr(w, f"{task_type.value}_candidates", [])
                        if r.document_index.documents.get_document_from_text(c)
                        is not None
                    ]
                    for task_type, r in self.retriever.items()
                },
            )
            output.append(sample_output)

        # add windows to the output if requested
        # do we want to force windows to be returned if there is no reader?
        if return_also_windows:
            for i, sample_output in enumerate(output):
                sample_output.windows = [w for w in windows if w.doc_id == i]

        # if only one text was provided, return a single RelikOutput object
        if len(output) == 1:
            return output[0]

        return output

    @classmethod
    def from_pretrained(
        cls,
        model_name_or_dir: Union[str, os.PathLike],
        config_file_name: str = CONFIG_NAME,
        *args,
        **kwargs,
    ) -> "Relik":
        """
        Instantiate a `Relik` from a pretrained model.

        Args:
            model_name_or_dir (`str` or `os.PathLike`):
                The name or path of the model to load.
            config_file_name (`str`, `optional`, defaults to `config.yaml`):
                The name of the configuration file to load.
            *args:
                Additional positional arguments to pass to `hydra.utils.instantiate`.
            **kwargs:
                Additional keyword arguments used to override the loaded configuration
                (merged into it with `OmegaConf.merge`).

        Returns:
            `Relik`: The instantiated `Relik`.
        """
        cache_dir = kwargs.pop("cache_dir", None)
        force_download = kwargs.pop("force_download", False)

        model_dir = from_cache(
            model_name_or_dir,
            filenames=[config_file_name],
            cache_dir=cache_dir,
            force_download=force_download,
        )

        config_path = model_dir / config_file_name
        if not config_path.exists():
            raise FileNotFoundError(
                f"Model configuration file not found at {config_path}."
            )

        # overwrite config with config_kwargs
        config = OmegaConf.load(config_path)
        config = OmegaConf.merge(config, OmegaConf.create(kwargs))
        # do we want to print the config? I like it
        logger.info(f"Loading Relik from {model_name_or_dir}")

        # load relik from config
        relik = hydra.utils.instantiate(config, *args, _recursive_=False)

        return relik

    def save_pretrained(
        self,
        output_dir: Union[str, os.PathLike],
        config: Optional[Dict[str, Any]] = None,
        config_file_name: Optional[str] = None,
        save_weights: bool = False,
        push_to_hub: bool = False,
        model_id: Optional[str] = None,
        organization: Optional[str] = None,
        repo_name: Optional[str] = None,
        retriever_model_id: Optional[str] = None,
        reader_model_id: Optional[str] = None,
        **kwargs,
    ):
        """
        Save the configuration of Relik to the specified directory as a YAML file.

        Args:
            output_dir (`str`):
                The directory to save the configuration file to.
            config (`Optional[Dict[str, Any]]`, `optional`):
                The configuration to save. If `None`, the current configuration will be
                saved. Defaults to `None`.
            config_file_name (`Optional[str]`, `optional`):
                The name of the configuration file. Defaults to `config.yaml`.
            save_weights (`bool`, `optional`):
                Whether to save the weights of the model. Defaults to `False`.
            push_to_hub (`bool`, `optional`):
                Whether to push the saved model to the hub. Defaults to `False`.
            model_id (`Optional[str]`, `optional`):
                The id of the model to push to the hub. If `None`, the name of the
                directory will be used. Defaults to `None`.
            organization (`Optional[str]`, `optional`):
                The organization to push the model to. Defaults to `None`.
            repo_name (`Optional[str]`, `optional`):
                The name of the repository to push the model to. Defaults to `None`.
            retriever_model_id (`Optional[str]`, `optional`):
                The id of the retriever model to push to the hub. If `None`, the name of
                the directory will be used. Defaults to `None`.
            reader_model_id (`Optional[str]`, `optional`):
                The id of the reader model to push to the hub. If `None`, the name of
                the directory will be used. Defaults to `None`.
            **kwargs:
                Additional keyword arguments to pass to `OmegaConf.save`.
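
        Example:
            A minimal, illustrative call (the output directory name below is a
            placeholder, not a required path): save only the configuration of an
            existing `relik` instance, without saving or pushing any weights.

            >>> relik.save_pretrained("relik-config", save_weights=False)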
""" # create the output directory output_dir = Path(output_dir) output_dir.mkdir(parents=True, exist_ok=True) retrievers_names: Dict[TaskType, Dict | None] = { TaskType.SPAN: { "question_encoder_name": None, "passage_encoder_name": None, "document_index_name": None, }, TaskType.TRIPLET: { "question_encoder_name": None, "passage_encoder_name": None, "document_index_name": None, }, } if save_weights: # save weights # retriever model_id = model_id or output_dir.name retriever_model_id = retriever_model_id or f"retriever-{model_id}" for task_type, retriever in self.retriever.items(): if retriever is None: continue task_retriever_model_id = f"{retriever_model_id}-{task_type.value}" question_encoder_name = f"{task_retriever_model_id}-question-encoder" passage_encoder_name = f"{task_retriever_model_id}-passage-encoder" document_index_name = f"{task_retriever_model_id}-index" logger.info( f"Saving retriever to {output_dir / task_retriever_model_id}" ) retriever.save_pretrained( output_dir / task_retriever_model_id, question_encoder_name=question_encoder_name, passage_encoder_name=passage_encoder_name, document_index_name=document_index_name, push_to_hub=push_to_hub, organization=organization, **kwargs, ) retrievers_names[task_type] = { "reader_model_id": task_retriever_model_id, "question_encoder_name": question_encoder_name, "passage_encoder_name": passage_encoder_name, "document_index_name": document_index_name, } # reader reader_model_id = reader_model_id or f"reader-{model_id}" logger.info(f"Saving reader to {output_dir / reader_model_id}") self.reader.save_pretrained( output_dir / reader_model_id, push_to_hub=push_to_hub, organization=organization, **kwargs, ) if push_to_hub: user = organization or get_logged_in_username() # we need to update the config with the model ids that will # result from the push to hub for task_type, retriever_names in retrievers_names.items(): retriever_names[ "question_encoder_name" ] = f"{user}/{retriever_names['question_encoder_name']}" retriever_names[ "passage_encoder_name" ] = f"{user}/{retriever_names['passage_encoder_name']}" retriever_names[ "document_index_name" ] = f"{user}/{retriever_names['document_index_name']}" # question_encoder_name = f"{user}/{question_encoder_name}" # passage_encoder_name = f"{user}/{passage_encoder_name}" # document_index_name = f"{user}/{document_index_name}" reader_model_id = f"{user}/{reader_model_id}" else: for task_type, retriever_names in retrievers_names.items(): retriever_names["question_encoder_name"] = ( output_dir / retriever_names["question_encoder_name"] ) retriever_names["passage_encoder_name"] = ( output_dir / retriever_names["passage_encoder_name"] ) retriever_names["document_index_name"] = ( output_dir / retriever_names["document_index_name"] ) reader_model_id = output_dir / reader_model_id else: # save config only for task_type, retriever_names in retrievers_names.items(): retriever = self.retriever.get(task_type, None) if retriever is None: continue retriever_names[ "question_encoder_name" ] = retriever.question_encoder.name_or_path retriever_names[ "passage_encoder_name" ] = retriever.passage_encoder.name_or_path retriever_names[ "document_index_name" ] = retriever.document_index.name_or_path reader_model_id = self.reader.name_or_path if config is None: # create a default config config = { "_target_": f"{self.__class__.__module__}.{self.__class__.__name__}" } if self.retriever is not None: config["retriever"] = {} for task_type, retriever in self.retriever.items(): if retriever is None: continue 
config["retriever"][task_type.value] = { "_target_": f"{retriever.__class__.__module__}.{retriever.__class__.__name__}", } if retriever.question_encoder is not None: config["retriever"][task_type.value][ "question_encoder" ] = retrievers_names[task_type]["question_encoder_name"] if ( retriever.passage_encoder is not None and not retriever.passage_encoder_is_question_encoder ): config["retriever"][task_type.value][ "passage_encoder" ] = retrievers_names[task_type]["passage_encoder_name"] if retriever.document_index is not None: config["retriever"][task_type.value][ "document_index" ] = retrievers_names[task_type]["document_index_name"] if self.reader is not None: config["reader"] = { "_target_": f"{self.reader.__class__.__module__}.{self.reader.__class__.__name__}", "transformer_model": reader_model_id, } # these are model-specific and should be saved config["task"] = self.task config["metadata_fields"] = self.metadata_fields config["top_k"] = self.top_k config["window_size"] = self.window_size config["window_stride"] = self.window_stride config_file_name = config_file_name or CONFIG_NAME logger.info(f"Saving relik config to {output_dir / config_file_name}") OmegaConf.save(config, output_dir / config_file_name) if push_to_hub: # push to hub logger.info("Pushing to hub") model_id = model_id or output_dir.name upload( output_dir, model_id, filenames=[config_file_name], organization=organization, repo_name=repo_name, )