import logging
from typing import Iterable, Iterator, List, Optional

import hydra
import torch
from lightning.pytorch.utilities import move_data_to_device
from torch.utils.data import DataLoader
from tqdm import tqdm

from relik.reader.data.patches import merge_patches_predictions
from relik.reader.data.relik_reader_sample import (
    RelikReaderSample,
    load_relik_reader_samples,
)
from relik.reader.relik_reader_core import RelikReaderCoreModel
from relik.reader.utils.special_symbols import NME_SYMBOL

logger = logging.getLogger(__name__)


def convert_tokens_to_char_annotations(
    sample: RelikReaderSample, remove_nmes: bool = False
):
    """Project a sample's token-level predictions onto character offsets.

    Reads ``sample.predicted_window_labels`` (label -> iterable of
    ``(token_start, token_end)`` spans) and ``sample.span_title_probabilities``
    (``(token_start, token_end)`` -> iterable of ``(title, score)`` pairs), and
    writes two new attributes on ``sample``:

    * ``predicted_window_labels_chars``: a set of
      ``(char_start, char_end, label)`` triples.
    * ``probs_window_labels_chars``: a dict mapping ``(char_start, char_end)``
      to the set of candidate titles for that span (scores are discarded).

    Token offsets are translated through ``sample.token2char_start`` /
    ``sample.token2char_end``, which are indexed by the *string* form of the
    token index.

    Args:
        sample: the annotated sample; mutated in place.
        remove_nmes: if True, spans labelled ``NME_SYMBOL`` are left out of
            ``predicted_window_labels_chars``. The candidate-title map is not
            filtered.
    """
    char_annotations = set()
    for (
        predicted_entity,
        predicted_spans,
    ) in sample.predicted_window_labels.items():
        if predicted_entity == NME_SYMBOL and remove_nmes:
            continue

        for span_start, span_end in predicted_spans:
            span_start = sample.token2char_start[str(span_start)]
            span_end = sample.token2char_end[str(span_end)]
            char_annotations.add((span_start, span_end, predicted_entity))

    char_probs_annotations = dict()
    for (
        span_start,
        span_end,
    ), candidates_probs in sample.span_title_probabilities.items():
        span_start = sample.token2char_start[str(span_start)]
        span_end = sample.token2char_end[str(span_end)]
        char_probs_annotations[(span_start, span_end)] = {
            title for title, _ in candidates_probs
        }

    sample.predicted_window_labels_chars = char_annotations
    sample.probs_window_labels_chars = char_probs_annotations


class RelikReaderPredictor:
    """Run a ``RelikReaderCoreModel`` over samples and collect its predictions.

    If ``dataset_conf`` is given at construction time, the dataset is
    instantiated once through hydra and reused by every ``predict`` call (its
    ``samples`` and ``tokens_per_batch`` are overwritten per call). Otherwise
    a ``dataset_conf`` must be supplied to ``predict``.
    """

    def __init__(
        self,
        relik_reader_core: RelikReaderCoreModel,
        dataset_conf: Optional[dict] = None,
        predict_nmes: bool = False,
    ) -> None:
        """
        Args:
            relik_reader_core: model whose ``batch_predict`` produces the
                annotated samples.
            dataset_conf: optional hydra config for the dataset. When given,
                the dataset is built eagerly with ``dataset_path=None`` and
                ``samples=None``.
            predict_nmes: keep ``NME_SYMBOL`` predictions in the
                character-level output of ``predict``.
        """
        self.relik_reader_core = relik_reader_core
        self.dataset_conf = dataset_conf
        self.predict_nmes = predict_nmes

        if self.dataset_conf is not None:
            # instantiate dataset
            self.dataset = hydra.utils.instantiate(
                dataset_conf,
                dataset_path=None,
                samples=None,
            )

    def predict(
        self,
        path: Optional[str],
        samples: Optional[Iterable[RelikReaderSample]],
        dataset_conf: Optional[dict] = None,
        token_batch_size: int = 1024,
        progress_bar: bool = False,
        **kwargs,
    ) -> List[RelikReaderSample]:
        """Annotate ``samples`` (or the samples loaded from ``path``).

        Predictions are collected in input order. Then, for each sample, patch
        predictions are merged and token spans are converted to character
        spans.

        Args:
            path: file to load samples from; used only when ``samples`` is None.
            samples: samples to annotate.
            dataset_conf: hydra dataset config. Ignored when this predictor
                was constructed with its own ``dataset_conf``.
            token_batch_size: forwarded to the dataset as ``tokens_per_batch``.
            progress_bar: show a tqdm progress bar over batches.
            **kwargs: accepted for interface compatibility; unused.

        Returns:
            The annotated samples.

        Raises:
            ValueError: if no dataset can be built (see ``_predict``).
        """
        annotated_samples = list(
            self._predict(path, samples, dataset_conf, token_batch_size, progress_bar)
        )
        for sample in annotated_samples:
            merge_patches_predictions(sample)
            convert_tokens_to_char_annotations(
                sample, remove_nmes=not self.predict_nmes
            )
        return annotated_samples

    def _predict(
        self,
        path: Optional[str],
        samples: Optional[Iterable[RelikReaderSample]],
        dataset_conf: Optional[dict] = None,
        token_batch_size: int = 1024,
        progress_bar: bool = False,
        **kwargs,
    ) -> Iterator[RelikReaderSample]:
        """Lazily yield model predictions in the same order as the input.

        Each input sample is tagged with its input position
        (``_mixin_prediction_position``) before being handed to the dataset.
        Predicted samples are buffered by that position and yielded as soon
        as the next expected position is available. If some positions never
        come back (the dataset discarded those samples), a warning is logged
        and the remaining buffered predictions are yielded at the end, sorted
        by position.

        Raises:
            AssertionError: if neither ``path`` nor ``samples`` is given, or
                if a sample already carries a prediction position.
            ValueError: if this predictor has no dataset of its own and
                ``dataset_conf`` is None.
        """
        assert (
            path is not None or samples is not None
        ), "Either predict on a path or on an iterable of samples"

        has_own_dataset = getattr(self, "dataset", None) is not None
        if not has_own_dataset and dataset_conf is None:
            # Nothing to build batches from: fail here with a clear message
            # rather than deep inside the DataLoader.
            raise ValueError(
                "No dataset available: pass `dataset_conf` to predict() or to "
                "the RelikReaderPredictor constructor"
            )

        samples = load_relik_reader_samples(path) if samples is None else samples

        # setup infrastructure to re-yield in order: tag each sample with its
        # input position so out-of-order batch outputs can be re-sequenced.
        def samples_it():
            for i, sample in enumerate(samples):
                assert sample._mixin_prediction_position is None
                sample._mixin_prediction_position = i
                yield sample

        next_prediction_position = 0
        position2predicted_sample = {}

        # instantiate dataset
        if has_own_dataset:
            dataset = self.dataset
            dataset.samples = samples_it()
            dataset.tokens_per_batch = token_batch_size
        else:
            dataset = hydra.utils.instantiate(
                dataset_conf,
                dataset_path=None,
                samples=samples_it(),
                tokens_per_batch=token_batch_size,
            )

        # batch_size=None disables the DataLoader's automatic batching; the
        # dataset is expected to yield ready-made batches (it receives
        # tokens_per_batch above).
        iterator = DataLoader(dataset, batch_size=None, num_workers=0, shuffle=False)
        if progress_bar:
            iterator = tqdm(iterator, desc="Predicting")

        model_device = next(self.relik_reader_core.parameters()).device

        try:
            with torch.inference_mode():
                for batch in iterator:
                    # do batch predict
                    with torch.autocast(
                        "cpu" if model_device == torch.device("cpu") else "cuda"
                    ):
                        batch = move_data_to_device(batch, model_device)
                        batch_out = self.relik_reader_core.batch_predict(**batch)

                    # buffer predictions; positions already yielded are ignored
                    for sample in batch_out:
                        if sample._mixin_prediction_position >= next_prediction_position:
                            position2predicted_sample[
                                sample._mixin_prediction_position
                            ] = sample

                    # yield every buffered prediction that is now contiguous
                    # with what has already been emitted
                    while next_prediction_position in position2predicted_sample:
                        yield position2predicted_sample[next_prediction_position]
                        del position2predicted_sample[next_prediction_position]
                        next_prediction_position += 1

            if len(position2predicted_sample) > 0:
                logger.warning(
                    "It seems samples have been discarded in your dataset. "
                    "This means that you WON'T have a prediction for each input sample. "
                    "Prediction order will also be partially disrupted"
                )
                for position in sorted(position2predicted_sample):
                    yield position2predicted_sample[position]
        finally:
            # Close the bar even if prediction raises or the caller stops
            # consuming this generator early.
            if progress_bar:
                iterator.close()