|
import logging |
|
import random |
|
import time |
|
from copy import deepcopy |
|
from pathlib import Path |
|
from typing import List, Optional, Set, Union |
|
|
|
import lightning as pl |
|
import torch |
|
from lightning.pytorch.trainer.states import RunningStage |
|
from omegaconf import DictConfig |
|
from torch.utils.data import DataLoader |
|
from tqdm import tqdm |
|
|
|
from relik.common.log import get_console_logger, get_logger |
|
from relik.retriever.callbacks.base import PredictionCallback |
|
from relik.retriever.common.model_inputs import ModelInputs |
|
from relik.retriever.data.base.datasets import BaseDataset |
|
from relik.retriever.data.datasets import GoldenRetrieverDataset |
|
from relik.retriever.data.utils import HardNegativesManager |
|
from relik.retriever.indexers.base import BaseDocumentIndex |
|
from relik.retriever.pytorch_modules.model import GoldenRetriever |
|
|
|
# Console logger obtained from the project's `get_console_logger` helper.
console_logger = get_console_logger()

# Logger named after this module, requested at INFO level; the callbacks in this
# module report their progress through it.
logger = get_logger(__name__, level=logging.INFO)
|
|
|
|
|
class GoldenRetrieverPredictionCallback(PredictionCallback):
    """
    Prediction callback that (re)indexes passages with the retriever held by the
    lightning module and collects the retrieved passages for every sample of the
    datasets it is given.

    Args:
        k (:obj:`int`, `optional`):
            Forwarded as ``k`` to ``retriever.retrieve``.
        batch_size (:obj:`int`, `optional`, defaults to 32):
            Passed to the parent callback, to the prediction dataloaders and to
            ``retriever.index``.
        num_workers (:obj:`int`, `optional`, defaults to 8):
            Passed to the prediction dataloaders and to ``retriever.index``.
        document_index (:obj:`BaseDocumentIndex`, `optional`):
            Stored on the instance; not read by the methods defined in this class.
        precision (:obj:`Union[str, int]`, `optional`, defaults to 32):
            Forwarded to ``retriever.index`` and ``retriever.retrieve``.
        force_reindex (:obj:`bool`, `optional`, defaults to :obj:`True`):
            Value handed to ``retriever.index`` unless ``retriever_dir`` overrides it
            (see :meth:`__call__`).
        retriever_dir (:obj:`Path`, `optional`):
            When set, reindexing is not forced at epoch 0 or while testing, and while
            testing the retriever is loaded from this directory.
        stages, other_callbacks, dataset, dataloader:
            Forwarded positionally to :class:`PredictionCallback`.
        *args, **kwargs:
            Accepted but neither used nor forwarded to the parent.
    """

    def __init__(
        self,
        k: Optional[int] = None,
        batch_size: int = 32,
        num_workers: int = 8,
        document_index: Optional[BaseDocumentIndex] = None,
        precision: Union[str, int] = 32,
        force_reindex: bool = True,
        retriever_dir: Optional[Path] = None,
        stages: Optional[Set[Union[str, RunningStage]]] = None,
        other_callbacks: Optional[List[DictConfig]] = None,
        dataset: Optional[Union[DictConfig, BaseDataset]] = None,
        dataloader: Optional[DataLoader] = None,
        *args,
        **kwargs,
    ):
        super().__init__(batch_size, stages, other_callbacks, dataset, dataloader)
        # retrieval settings
        self.k = k
        self.precision = precision
        self.num_workers = num_workers
        # indexing settings
        self.document_index = document_index
        self.force_reindex = force_reindex
        self.retriever_dir = retriever_dir

    @staticmethod
    def _build_prediction(retriever, batch, batch_idx, retrieved_samples) -> dict:
        """Assemble the prediction record of the ``batch_idx``-th sample of ``batch``."""
        gold_passages = batch["positives"][batch_idx]

        # positions of the gold passages in the retriever index
        gold_passage_indices = [
            retriever.get_index_from_passage(gold) for gold in gold_passages
        ]
        retrieved_indices = [hit.index for hit in retrieved_samples]
        retrieved_passages = [hit.label for hit in retrieved_samples]
        retrieved_scores = [hit.score for hit in retrieved_samples]

        gold_set = set(gold_passage_indices)
        retrieved_set = set(retrieved_indices)
        # retrieved and gold
        correct_indices = gold_set & retrieved_set
        # retrieved but not gold
        wrong_indices = retrieved_set - gold_set

        return {
            "sample_idx": batch.sample_idx[batch_idx],
            "gold": gold_passages,
            "predictions": retrieved_passages,
            "scores": retrieved_scores,
            "correct": [
                retriever.get_passage_from_index(idx) for idx in correct_indices
            ],
            "wrong": [retriever.get_passage_from_index(idx) for idx in wrong_indices],
        }

    @torch.no_grad()
    def __call__(
        self,
        trainer: pl.Trainer,
        pl_module: pl.LightningModule,
        datasets: Optional[
            Union[DictConfig, BaseDataset, List[DictConfig], List[BaseDataset]]
        ] = None,
        dataloaders: Optional[Union[DataLoader, List[DataLoader]]] = None,
        *args,
        **kwargs,
    ) -> dict:
        """
        Index the passages and retrieve the top-k passages for every sample.

        For each dataloader the retriever index is built through ``retriever.index``.
        ``force_reindex=False`` is used when ``retriever_dir`` is set and the trainer is
        at epoch 0 or in the testing stage; otherwise the configured ``force_reindex``
        applies. In the testing stage with ``retriever_dir`` set, the retriever is
        replaced by ``retriever.from_pretrained(self.retriever_dir)``.

        Args:
            trainer (:obj:`pl.Trainer`):
                The trainer object; its stage and current epoch are read.
            pl_module (:obj:`pl.LightningModule`):
                The lightning module; put in eval mode, ``pl_module.model`` is the
                retriever.
            datasets, dataloaders:
                Handed to ``self._get_datasets_and_dataloaders`` together with the
                dataloader settings of this callback.
            *args, **kwargs:
                Accepted but unused.

        Returns:
            A dictionary mapping each dataloader index to a list of per-sample
            dictionaries with the keys ``sample_idx``, ``gold``, ``predictions``,
            ``scores``, ``correct`` and ``wrong``.

        Raises:
            ValueError: if the current stage is not in ``self.stages``.
        """
        stage = trainer.state.stage
        logger.info(f"Computing predictions for stage {stage.value}")
        if stage not in self.stages:
            raise ValueError(
                f"Stage `{stage}` not supported, only {self.stages} are supported"
            )

        loader_options = {
            "batch_size": self.batch_size,
            "num_workers": self.num_workers,
            "pin_memory": True,
            "shuffle": False,
        }
        self.datasets, self.dataloaders = self._get_datasets_and_dataloaders(
            datasets,
            dataloaders,
            trainer,
            dataloader_kwargs=loader_options,
        )

        # inference only
        pl_module.eval()
        retriever: GoldenRetriever = pl_module.model

        dataloader_predictions = {}
        for dataloader_idx, dataloader in enumerate(self.dataloaders):
            current_dataset: GoldenRetrieverDataset = self.datasets[dataloader_idx]
            logger.info(
                f"Computing passage embeddings for dataset {current_dataset.name}"
            )

            tokenizer = current_dataset.tokenizer

            def collate_fn(passages):
                # tokenize a batch of passages into padded, truncated tensors
                encoded = tokenizer(
                    passages,
                    truncation=True,
                    padding=True,
                    max_length=current_dataset.max_passage_length,
                    return_tensors="pt",
                )
                return ModelInputs(encoded)

            # with a retriever directory, do not force reindexing at epoch 0 or at test time
            keep_existing_index = self.retriever_dir is not None and (
                trainer.current_epoch == 0 or stage == RunningStage.TESTING
            )
            force_reindex = False if keep_existing_index else self.force_reindex

            if (
                stage == RunningStage.TESTING
                and self.retriever_dir is not None
                and not force_reindex
            ):
                retriever = retriever.from_pretrained(self.retriever_dir)

            retriever.eval()
            retriever.index(
                batch_size=self.batch_size,
                num_workers=self.num_workers,
                max_length=current_dataset.max_passage_length,
                collate_fn=collate_fn,
                precision=self.precision,
                compute_on_cpu=False,
                force_reindex=force_reindex,
            )

            predictions = []
            start = time.time()
            progress = tqdm(
                dataloader,
                desc=f"Computing predictions for dataset {current_dataset.name}",
            )
            for batch in progress:
                batch = batch.to(pl_module.device)
                retriever_output = retriever.retrieve(
                    **batch.questions, k=self.k, precision=self.precision
                )
                predictions.extend(
                    self._build_prediction(
                        retriever, batch, batch_idx, retrieved_samples
                    )
                    for batch_idx, retrieved_samples in enumerate(retriever_output)
                )
            elapsed = time.time() - start
            logger.info(f"Time to retrieve: {elapsed}")

            dataloader_predictions[dataloader_idx] = predictions

        return dataloader_predictions
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class NegativeAugmentationCallback(GoldenRetrieverPredictionCallback):
    """
    Callback that computes the predictions of a retriever model on the training set and
    turns the retrieved non-gold passages into hard negatives for the training set.

    Args:
        k (:obj:`int`, `optional`, defaults to 100):
            The number of top-k retrieved passages to consider.
        batch_size (:obj:`int`, `optional`, defaults to 32):
            The batch size forwarded to the parent callback.
        num_workers (:obj:`int`, `optional`, defaults to 0):
            The number of workers used by the parent callback and by the dataloader
            built over the training set.
        force_reindex (:obj:`bool`, `optional`, defaults to :obj:`False`):
            Whether to force the reindexing of the passages.
        retriever_dir (:obj:`Path`, `optional`):
            The path to the retriever directory, forwarded to the parent callback.
        stages (:obj:`Set[str]`, `optional`):
            The stages to run the callback on, forwarded to the parent callback. On any
            other stage :meth:`__call__` returns an empty dictionary.
        other_callbacks (:obj:`List[DictConfig]`, `optional`):
            A list of other callbacks, forwarded to the parent callback.
        dataset (:obj:`Union[DictConfig, BaseDataset]`, `optional`):
            Forwarded to the parent callback. :meth:`__call__` always predicts over a
            copy of ``trainer.datamodule.train_dataset``.
        metrics_to_monitor (:obj:`Union[str, List[str]]`, `optional`, defaults to ``["val_loss"]``):
            Name, or list of names, of the logged metrics to monitor.
        threshold (:obj:`float`, `optional`, defaults to 0.8):
            The threshold to consider. Hard negatives are computed only when at least
            one monitored metric found in ``trainer.logged_metrics`` is not below it.
        max_negatives (:obj:`int`, `optional`, defaults to 5):
            The maximum number of hard negatives kept per sample.
        add_with_probability (:obj:`float`, `optional`, defaults to 1.0):
            The per-sample probability with which hard negatives are collected.
        refresh_every_n_epochs (:obj:`int`, `optional`, defaults to 1):
            Hard negatives are recomputed only on epochs that are a multiple of this
            value.
    """

    def __init__(
        self,
        k: int = 100,
        batch_size: int = 32,
        num_workers: int = 0,
        force_reindex: bool = False,
        retriever_dir: Optional[Path] = None,
        stages: Optional[Set[Union[str, RunningStage]]] = None,
        other_callbacks: Optional[List[DictConfig]] = None,
        dataset: Optional[Union[DictConfig, BaseDataset]] = None,
        metrics_to_monitor: Optional[Union[str, List[str]]] = None,
        threshold: float = 0.8,
        max_negatives: int = 5,
        add_with_probability: float = 1.0,
        refresh_every_n_epochs: int = 1,
        *args,
        **kwargs,
    ):
        super().__init__(
            *args,
            k=k,
            batch_size=batch_size,
            num_workers=num_workers,
            force_reindex=force_reindex,
            retriever_dir=retriever_dir,
            stages=stages,
            other_callbacks=other_callbacks,
            dataset=dataset,
            **kwargs,
        )
        if metrics_to_monitor is None:
            metrics_to_monitor = ["val_loss"]
        # kept exactly as provided (string or list); normalised when it is read
        self.metrics_to_monitor = metrics_to_monitor
        self.threshold = threshold
        self.max_negatives = max_negatives
        self.add_with_probability = add_with_probability
        self.refresh_every_n_epochs = refresh_every_n_epochs

    def _monitored_metric_names(self) -> List[str]:
        """Return ``metrics_to_monitor`` as a list, accepting a single name as well."""
        monitored = self.metrics_to_monitor
        if isinstance(monitored, str):
            return [monitored]
        return list(monitored)

    @torch.no_grad()
    def __call__(
        self,
        trainer: pl.Trainer,
        pl_module: pl.LightningModule,
        *args,
        **kwargs,
    ) -> dict:
        """
        Computes the predictions of a retriever model on the training set and stores the
        resulting hard negatives on the training dataset.

        Args:
            trainer (:obj:`pl.Trainer`):
                The trainer object.
            pl_module (:obj:`pl.LightningModule`):
                The lightning module.

        Returns:
            ``{0: predictions}`` with the per-sample predictions of the parent callback,
            or an empty dictionary when the callback is skipped (unsupported stage, all
            monitored metrics below the threshold, or an epoch that is not a multiple
            of ``refresh_every_n_epochs``).

        Raises:
            ValueError: if none of the monitored metrics is in ``trainer.logged_metrics``.
        """
        stage = trainer.state.stage
        if stage not in self.stages:
            return {}

        # Look the metrics up one name at a time: a list is unhashable and cannot be
        # used as a key of `logged_metrics` directly.
        logged_metrics = trainer.logged_metrics
        monitored_values = [
            logged_metrics[name]
            for name in self._monitored_metric_names()
            if name in logged_metrics
        ]
        if not monitored_values:
            raise ValueError(
                f"Metric `{self.metrics_to_monitor}` not found in trainer.logged_metrics. "
                f"Available metrics: {list(logged_metrics.keys())}"
            )
        if all(value < self.threshold for value in monitored_values):
            return {}

        if trainer.current_epoch % self.refresh_every_n_epochs != 0:
            return {}

        logger.info(
            f"At least one metric from {self.metrics_to_monitor} is above threshold "
            f"{self.threshold}. Computing hard negatives."
        )

        # drop the current hard negatives before copying the dataset, so that the copy
        # used for prediction does not carry them
        datamodule = trainer.datamodule
        datamodule.train_dataset.hn_manager = None
        dataset_copy = deepcopy(datamodule.train_dataset)
        predictions = super().__call__(
            trainer,
            pl_module,
            datasets=dataset_copy,
            dataloaders=DataLoader(
                dataset_copy.to_torch_dataset(),
                shuffle=False,
                batch_size=None,
                num_workers=self.num_workers,
                pin_memory=True,
                collate_fn=lambda x: x,
            ),
            *args,
            **kwargs,
        )
        logger.info(f"Computing hard negatives for epoch {trainer.current_epoch}")

        # a single dataloader was passed: keep only its predictions
        predictions = list(predictions.values())[0]

        hard_negatives = {}
        for prediction in tqdm(predictions, desc="Collecting hard negatives"):
            # skip the sample with probability 1 - add_with_probability
            if random.random() < 1 - self.add_with_probability:
                continue
            gold_passages = prediction["gold"]
            # retrieved passages that are not gold, in retrieval order
            wrong_passages = [
                passage_id
                for passage_id in prediction["predictions"]
                if passage_id not in gold_passages
            ][: self.max_negatives]
            hard_negatives[prediction["sample_idx"]] = wrong_passages

        train_dataset = datamodule.train_dataset
        train_dataset.hn_manager = HardNegativesManager(
            tokenizer=train_dataset.tokenizer,
            max_length=train_dataset.max_passage_length,
            data=hard_negatives,
        )

        return {0: predictions}
|
|