|
import logging |
|
import time |
|
from pathlib import Path |
|
from typing import List, Optional, Set |
|
|
|
import lightning as pl |
|
import torch |
|
from lightning.pytorch.trainer.states import RunningStage |
|
from omegaconf import DictConfig |
|
from torch.utils.data import DataLoader |
|
from tqdm import tqdm |
|
|
|
from relik.common.log import get_logger |
|
from relik.retriever.callbacks.base import NLPTemplateCallback, PredictionCallback |
|
from relik.retriever.common.model_inputs import ModelInputs |
|
from relik.retriever.data.base.datasets import BaseDataset |
|
from relik.retriever.data.datasets import GoldenRetrieverDataset |
|
from relik.retriever.indexers.base import BaseDocumentIndex |
|
from relik.retriever.pytorch_modules.model import GoldenRetriever |
|
|
|
logger = get_logger(__name__, level=logging.INFO) |
|
|
|
|
|
class GoldenRetrieverPredictionCallback(PredictionCallback):
    """Prediction callback for a :class:`GoldenRetriever` model.

    For every configured dataloader it (re)indexes the passage collection with
    the current retriever and retrieves the top-``k`` passages for each
    question, producing per-sample prediction records (gold passages,
    retrieved passages, scores, correct/wrong retrievals) that downstream
    ``NLPTemplateCallback``s can consume.

    Args:
        k: number of passages to retrieve per question (``None`` lets the
            retriever use its own default).
        batch_size: batch size used both for indexing and for retrieval.
        num_workers: number of dataloader workers.
        document_index: optional document index to use — NOTE(review): stored
            but not read in this class; presumably consumed elsewhere.
        precision: numeric precision for indexing/retrieval (e.g. ``32``, ``16``).
        force_reindex: whether to recompute passage embeddings on every call.
        retriever_dir: optional directory holding a saved retriever; when set,
            reindexing is skipped on epoch 0 and at test time, and the saved
            retriever is loaded for testing.
        stages: run stages in which this callback is active.
        other_callbacks: callbacks to invoke on the produced predictions.
        dataset: dataset(s) to predict on when not taken from the trainer.
        dataloader: dataloader(s) to predict on when not taken from the trainer.
    """

    def __init__(
        self,
        k: int | None = None,
        batch_size: int = 32,
        num_workers: int = 8,
        document_index: BaseDocumentIndex | None = None,
        precision: str | int = 32,
        force_reindex: bool = True,
        retriever_dir: Optional[Path] = None,
        stages: Set[str | RunningStage] | None = None,
        other_callbacks: List[DictConfig] | List[NLPTemplateCallback] | None = None,
        dataset: DictConfig | BaseDataset | None = None,
        dataloader: DataLoader | None = None,
        *args,
        **kwargs,
    ):
        super().__init__(batch_size, stages, other_callbacks, dataset, dataloader)
        self.k = k
        self.num_workers = num_workers
        self.document_index = document_index
        self.precision = precision
        self.force_reindex = force_reindex
        self.retriever_dir = retriever_dir

    @torch.no_grad()
    def __call__(
        self,
        trainer: pl.Trainer,
        pl_module: pl.LightningModule,
        datasets: DictConfig
        | BaseDataset
        | List[DictConfig]
        | List[BaseDataset]
        | None = None,
        dataloaders: DataLoader | List[DataLoader] | None = None,
        *args,
        **kwargs,
    ) -> dict:
        """Compute retriever predictions for the current stage.

        Args:
            trainer: the Lightning trainer (used for stage, epoch and
                dataloader resolution).
            pl_module: the Lightning module wrapping the retriever
                (``pl_module.model``).
            datasets: optional dataset(s) overriding the configured ones.
            dataloaders: optional dataloader(s) overriding the configured ones.

        Returns:
            dict mapping each dataloader index to the list of per-sample
            prediction dicts for that dataloader.

        Raises:
            ValueError: if the trainer's current stage is not among
                ``self.stages``.
        """
        stage = trainer.state.stage
        logger.info(f"Computing predictions for stage {stage.value}")
        if stage not in self.stages:
            raise ValueError(
                f"Stage `{stage}` not supported, only {self.stages} are supported"
            )

        self.datasets, self.dataloaders = self._get_datasets_and_dataloaders(
            datasets,
            dataloaders,
            trainer,
            dataloader_kwargs=dict(
                batch_size=self.batch_size,
                num_workers=self.num_workers,
                pin_memory=True,
                shuffle=False,
            ),
        )

        # inference mode for the whole prediction pass
        pl_module.eval()

        retriever: GoldenRetriever = pl_module.model

        dataloader_predictions = {}

        for dataloader_idx, dataloader in enumerate(self.dataloaders):
            current_dataset: GoldenRetrieverDataset = self.datasets[dataloader_idx]
            logger.info(
                f"Computing passage embeddings for dataset {current_dataset.name}"
            )

            tokenizer = current_dataset.tokenizer

            # collate raw passages into model-ready tensors; rebound each
            # iteration so it tokenizes with the current dataset's settings
            def collate_fn(x):
                return ModelInputs(
                    tokenizer(
                        x,
                        truncation=True,
                        padding=True,
                        max_length=current_dataset.max_passage_length,
                        return_tensors="pt",
                    )
                )

            # Skip the costly reindex when a saved retriever directory exists
            # and we are either on the first epoch or testing: in both cases
            # the stored passage embeddings are assumed still valid.
            if self.retriever_dir is not None and (
                trainer.current_epoch == 0 or stage == RunningStage.TESTING
            ):
                force_reindex = False
            else:
                force_reindex = self.force_reindex

            if (
                not force_reindex
                and self.retriever_dir is not None
                and stage == RunningStage.TESTING
            ):
                # at test time, evaluate the retriever saved during training
                retriever = retriever.from_pretrained(self.retriever_dir)

            retriever.eval()

            retriever.index(
                batch_size=self.batch_size,
                num_workers=self.num_workers,
                max_length=current_dataset.max_passage_length,
                collate_fn=collate_fn,
                precision=self.precision,
                compute_on_cpu=False,
                force_reindex=force_reindex,
            )

            predictions = []
            start = time.time()
            for batch in tqdm(
                dataloader,
                desc=f"Computing predictions for dataset {current_dataset.name}",
            ):
                batch = batch.to(pl_module.device)
                # top-k retrieval for every question in the batch
                retriever_output = retriever.retrieve(
                    **batch.questions, k=self.k, precision=self.precision
                )

                for batch_idx, retrieved_samples in enumerate(retriever_output):
                    # map gold passages to their index ids, skipping any that
                    # are missing from the index (logged, results may be off)
                    gold_passages = batch["positives"][batch_idx]
                    gold_passage_indices = []
                    for passage in gold_passages:
                        try:
                            gold_passage_indices.append(
                                retriever.get_index_from_passage(passage)
                            )
                        except ValueError:
                            logger.warning(
                                f"Passage `{passage}` not found in the index. "
                                "We will skip it, but the results might not reflect the "
                                "actual performance."
                            )

                    # Bug fix: filter indices and scores on the SAME condition
                    # so the two lists stay aligned; previously scores kept
                    # entries whose samples were dropped from the indices list.
                    valid_samples = [r for r in retrieved_samples if r]
                    retrieved_indices = [r.document.id for r in valid_samples]
                    retrieved_passages = [
                        retriever.get_passage_from_index(i) for i in retrieved_indices
                    ]
                    retrieved_scores = [r.score for r in valid_samples]

                    # build the sets once and reuse them for both comparisons
                    gold_index_set = set(gold_passage_indices)
                    retrieved_index_set = set(retrieved_indices)
                    correct_indices = gold_index_set & retrieved_index_set
                    wrong_indices = retrieved_index_set - gold_index_set

                    prediction_output = dict(
                        sample_idx=batch.sample_idx[batch_idx],
                        gold=gold_passages,
                        predictions=retrieved_passages,
                        scores=retrieved_scores,
                        correct=[
                            retriever.get_passage_from_index(i) for i in correct_indices
                        ],
                        wrong=[
                            retriever.get_passage_from_index(i) for i in wrong_indices
                        ],
                    )
                    predictions.append(prediction_output)
            end = time.time()
            logger.info(f"Time to retrieve: {end - start}")

            dataloader_predictions[dataloader_idx] = predictions

        return dataloader_predictions
|
|