|
import argparse |
|
import json |
|
import logging |
|
import os |
|
from pathlib import Path |
|
import time |
|
from typing import Union |
|
|
|
import torch |
|
import tqdm |
|
|
|
from relik.retriever import GoldenRetriever |
|
from relik.common.log import get_logger |
|
from relik.retriever.common.model_inputs import ModelInputs |
|
from relik.retriever.data.base.datasets import BaseDataset |
|
from relik.retriever.indexers.base import BaseDocumentIndex |
|
from relik.retriever.indexers.faiss import FaissDocumentIndex |
|
|
|
# Module-level logger obtained from relik's ``get_logger`` helper, requested at
# INFO level so the progress messages emitted by ``add_candidates`` are shown.
logger = get_logger(level=logging.INFO)
|
|
|
|
|
def compute_retriever_stats(dataset) -> float:
    """Print and return the candidate recall of the retriever over ``dataset``.

    A gold label counts as recalled when, after normalization (underscores
    replaced by spaces, then lower-cased), it appears among the sample's
    normalized ``window_candidates``. Labels equal to ``"--NME--"`` are
    skipped and do not count toward the total.

    Args:
        dataset: Iterable of samples; each is a mapping with a
            ``"window_candidates"`` list of candidate titles and a
            ``"window_labels"`` list of ``(start, end, label)`` triples.

    Returns:
        The recall as a float in ``[0, 1]``. When there is no linkable gold
        label at all (empty dataset, or only ``"--NME--"`` labels) a notice is
        printed and ``0.0`` is returned instead of dividing by zero.
    """

    def normalize(title: str) -> str:
        return title.replace("_", " ").lower()

    correct, total = 0, 0
    for sample in dataset:
        # A set gives O(1) membership tests instead of scanning a list per label.
        candidates = {normalize(c) for c in sample["window_candidates"]}
        for _start, _end, label in sample["window_labels"]:
            if label == "--NME--":
                continue
            if normalize(label) in candidates:
                correct += 1
            total += 1

    if total == 0:
        print("Recall: n/a (no linkable gold labels)")
        return 0.0

    recall = correct / total
    print("Recall:", recall)
    return recall
|
|
|
|
|
def _collate_questions(batch, tokenizer, use_topics: bool) -> ModelInputs:
    """Tokenize a batch of samples into ``ModelInputs`` for the retriever.

    Each sample's ``"text"`` is encoded, optionally paired with its
    ``"doc_topic"``. This is a module-level function, bound with
    ``functools.partial``, rather than a lambda. That way the DataLoader can
    pickle it for worker processes when ``num_workers > 0`` under the
    ``spawn`` start method.
    """
    return ModelInputs(
        tokenizer(
            [b["text"] for b in batch],
            text_pair=[b["doc_topic"] for b in batch] if use_topics else None,
            padding=True,
            return_tensors="pt",
            truncation=True,
        )
    )


def _read_jsonl(path) -> list:
    """Read a JSON-Lines file (UTF-8) into a list of objects, skipping blank lines."""
    with open(path, encoding="utf-8") as f:
        return [json.loads(line) for line in f if line.strip()]


@torch.no_grad()
def add_candidates(
    retriever_name_or_path: Union[str, os.PathLike],
    document_index_name_or_path: Union[str, os.PathLike],
    input_path: Union[str, os.PathLike],
    batch_size: int = 128,
    num_workers: int = 4,
    index_type: str = "Flat",
    nprobe: int = 1,
    device: str = "cpu",
    precision: str = "fp32",
    topics: bool = False,
    top_k: int = 100,
    output_path: Union[str, os.PathLike, None] = None,
):
    """Retrieve the top-k candidates for every sample of a JSONL file.

    Loads a document index and a ``GoldenRetriever`` question encoder. It
    encodes each sample's ``"text"``, optionally paired with ``"doc_topic"``,
    and stores the retrieved candidates on each sample:
    ``"window_candidates"`` (labels truncated before ``" <def>"``) and
    ``"window_candidates_scores"``. Candidate recall and the retrieval
    wall-clock time are then printed.

    Args:
        retriever_name_or_path: Question-encoder name or path.
        document_index_name_or_path: Document-index name or path.
        input_path: JSONL file, one sample per line.
        batch_size: DataLoader batch size.
        num_workers: DataLoader worker processes.
        index_type: Currently unused by this function.
        nprobe: Currently unused by this function.
        device: Device for the retriever and the index.
        precision: Precision for the retriever, the index and retrieval.
        topics: Pair each text with ``"doc_topic"``. This only takes effect
            if the first sample has that key.
        top_k: Number of candidates retrieved per sample (passed as ``k``).
        output_path: If given, the enriched samples are written there as JSONL.

    Returns:
        The enriched samples. These are the loaded dicts, mutated in place.
        An empty list is returned if the input contains no samples.
    """
    import functools  # local: only this function needs it

    # NOTE(review): ``index_type`` and ``nprobe`` are accepted but not used;
    # they look like leftovers of a Faiss-specific loading path — confirm.
    document_index = BaseDocumentIndex.from_pretrained(
        document_index_name_or_path,
        device=device,
        precision=precision,
    )

    retriever = GoldenRetriever(
        question_encoder=retriever_name_or_path,
        document_index=document_index,
        device=device,
        precision=precision,
        index_device=device,
        index_precision=precision,
    )
    retriever.eval()

    logger.info(f"Loading from {input_path}")
    samples = _read_jsonl(input_path)
    if not samples:
        logger.warning(f"No samples found in {input_path}; nothing to retrieve")
        return []

    # Topic pairing is used only when requested AND the first sample has a topic.
    use_topics = topics and "doc_topic" in samples[0]

    collate_fn = functools.partial(
        _collate_questions,
        tokenizer=retriever.question_tokenizer,
        use_topics=use_topics,
    )

    logger.info(f"Creating dataloader with batch size {batch_size}")
    dataloader = torch.utils.data.DataLoader(
        BaseDataset(name="passage", data=samples),
        batch_size=batch_size,
        shuffle=False,  # keeps retrieval results aligned with ``samples``
        num_workers=num_workers,
        pin_memory=False,
        collate_fn=collate_fn,
    )

    retrieved = []
    with torch.inference_mode():
        start = time.perf_counter()
        for documents_batch in tqdm.tqdm(dataloader):
            # Dict merge (not keyword args) so "k"/"precision" override any
            # same-named key in the batch, as in the original call.
            retrieve_kwargs = {**documents_batch, "k": top_k, "precision": precision}
            retrieved.extend(retriever.retrieve(**retrieve_kwargs))
        elapsed = time.perf_counter() - start

    output_data = []
    for sample, candidates in zip(samples, retrieved):
        # Keep only the part of each label before " <def>".
        sample["window_candidates"] = [
            c.label.split(" <def>", 1)[0] for c in candidates
        ]
        sample["window_candidates_scores"] = [c.score for c in candidates]
        output_data.append(sample)

    if output_path is not None:
        logger.info(f"Writing {len(output_data)} samples to {output_path}")
        with open(output_path, "w", encoding="utf-8") as f:
            for sample in output_data:
                # NOTE(review): scores may be numpy/torch scalars, so they are
                # coerced with float() for JSON — confirm the score type.
                f.write(json.dumps(sample, ensure_ascii=False, default=float) + "\n")

    compute_retriever_stats(output_data)
    print(f"Retrieval took {elapsed:.2f} seconds")
    return output_data
|
|
|
|
|
if __name__ == "__main__":
    # Command-line entry point. Every default reproduces the previously
    # hard-coded AIDA testa run, so running the script with no arguments
    # behaves exactly as before.
    _MODEL_DIR = (
        "/root/relik-spaces/models/relik-retriever-small-aida-blink-pretrain-omniencoder"
    )

    parser = argparse.ArgumentParser(
        description=(
            "Retrieve candidate entities for each sample of a JSONL file "
            "and report candidate recall."
        )
    )
    parser.add_argument(
        "--retriever",
        default=f"{_MODEL_DIR}/question_encoder",
        help="question encoder name or path",
    )
    parser.add_argument(
        "--document-index",
        default=f"{_MODEL_DIR}/document_index_filtered",
        help="document index name or path",
    )
    parser.add_argument(
        "--input",
        default="/root/relik-spaces/data/reader/aida/testa_windowed.jsonl",
        help="input JSONL file",
    )
    parser.add_argument("--batch-size", type=int, default=128)
    parser.add_argument("--num-workers", type=int, default=4)
    parser.add_argument("--index-type", default="Flat")
    parser.add_argument("--nprobe", type=int, default=1)
    parser.add_argument("--device", default="cuda")
    parser.add_argument("--precision", default="fp32")
    parser.add_argument(
        "--no-topics",
        dest="topics",
        action="store_false",
        help="do not pair each text with its doc_topic",
    )
    args = parser.parse_args()

    add_candidates(
        args.retriever,
        args.document_index,
        args.input,
        batch_size=args.batch_size,
        num_workers=args.num_workers,
        index_type=args.index_type,
        nprobe=args.nprobe,
        device=args.device,
        precision=args.precision,
        topics=args.topics,
    )
|
|