# Hosting-page residue (not part of the script): riccorl's picture,
# commit "Upload models" (8197b11), raw / history / blame links, 5.53 kB.
import argparse
import json
import logging
import os
from pathlib import Path
import time
from typing import Union
import torch
import tqdm
from relik.retriever import GoldenRetriever
from relik.common.log import get_logger
from relik.retriever.common.model_inputs import ModelInputs
from relik.retriever.data.base.datasets import BaseDataset
from relik.retriever.indexers.base import BaseDocumentIndex
from relik.retriever.indexers.faiss import FaissDocumentIndex
# Module-level logger obtained from the project's get_logger helper,
# requested at INFO level; used by add_candidates for progress messages.
logger = get_logger(level=logging.INFO)
def compute_retriever_stats(dataset) -> None:
    """Print the recall of the retrieved candidates over the gold labels.

    Args:
        dataset: Iterable of samples. Each sample is a mapping with a
            ``"window_candidates"`` entry (iterable of candidate title
            strings) and a ``"window_labels"`` entry (iterable of 3-element
            items whose last element is the gold label string; the first two
            elements are not used here).

    Labels equal to ``"--NME--"`` are skipped and do not count towards the
    total. Candidates and labels are compared after replacing underscores
    with spaces and lower-casing. When there is no label to score (empty
    dataset, or only ``"--NME--"`` labels) the recall is reported as ``0.0``
    instead of raising ``ZeroDivisionError``.
    """
    correct, total = 0, 0
    for sample in dataset:
        # Normalise once per sample and keep a set: membership is then O(1)
        # for every label of the window.
        normalized_candidates = {
            candidate.replace("_", " ").lower()
            for candidate in sample["window_candidates"]
        }
        for _, _, label in sample["window_labels"]:
            if label == "--NME--":
                continue
            if label.replace("_", " ").lower() in normalized_candidates:
                correct += 1
            total += 1

    # Guard against a dataset without any scorable label.
    recall = correct / total if total else 0.0
    print("Recall:", recall)
@torch.no_grad()
def add_candidates(
    retriever_name_or_path: Union[str, os.PathLike],
    document_index_name_or_path: Union[str, os.PathLike],
    input_path: Union[str, os.PathLike],
    batch_size: int = 128,
    num_workers: int = 4,
    index_type: str = "Flat",
    nprobe: int = 1,
    device: str = "cpu",
    precision: str = "fp32",
    topics: bool = False,
    top_k: int = 100,
) -> None:
    """Retrieve candidates for every sample of a JSONL file and print stats.

    The document index and the retriever are loaded, every line of
    ``input_path`` is parsed as one JSON sample, the samples are encoded in
    batches and passed to ``retriever.retrieve``. Each sample dict is then
    updated in place with two keys:

    * ``"window_candidates"``: the retrieved labels, cut at the first
      ``" <def>"`` marker;
    * ``"window_candidates_scores"``: the matching retrieval scores.

    Finally the recall is printed through ``compute_retriever_stats`` together
    with the time spent in the retrieval loop.

    Args:
        retriever_name_or_path: Passed to ``GoldenRetriever`` as
            ``question_encoder``.
        document_index_name_or_path: Passed to
            ``BaseDocumentIndex.from_pretrained``.
        input_path: JSONL file, one sample per line. Every sample needs a
            ``"text"`` entry; ``"doc_topic"`` is read when topics are enabled.
        batch_size: Batch size of the dataloader.
        num_workers: Number of dataloader workers.
        index_type: Currently unused (see the note in the body); kept for
            interface compatibility.
        nprobe: Currently unused; kept for interface compatibility.
        device: Device used for both the retriever and the index.
        precision: Precision used for the retriever, the index and retrieval.
        topics: When true, and the first sample has a ``"doc_topic"`` entry,
            the topic is given to the tokenizer as ``text_pair``.
        top_k: Number of candidates requested per sample (previously the
            hard-coded value 100).
    """
    document_index = BaseDocumentIndex.from_pretrained(
        document_index_name_or_path,
        # NOTE: ``index_type`` and ``nprobe`` are not forwarded at the moment.
        # The disabled configuration below shows how they were meant to be
        # used with the FAISS index.
        # config_kwargs={
        #     "_target_": "relik.retriever.indexers.faiss.FaissDocumentIndex",
        #     "index_type": index_type,
        #     "nprobe": nprobe,
        # },
        device=device,
        precision=precision,
    )
    retriever = GoldenRetriever(
        question_encoder=retriever_name_or_path,
        document_index=document_index,
        device=device,
        precision=precision,
        index_device=device,
        index_precision=precision,
    )
    retriever.eval()

    logger.info(f"Loading from {input_path}")
    # Iterate the file lazily and skip blank lines (e.g. a trailing newline),
    # which would otherwise make ``json.loads`` fail.
    with open(input_path, encoding="utf-8") as f:
        samples = [json.loads(line) for line in f if line.strip()]

    if not samples:
        # Nothing to retrieve; avoids an IndexError on ``samples[0]`` below.
        logger.warning(f"No samples found in {input_path}, nothing to retrieve")
        return

    # Topics are only used when requested AND available in the data
    # (checked on the first sample).
    use_topics = topics and "doc_topic" in samples[0]

    tokenizer = retriever.question_tokenizer

    def collate_fn(batch):
        """Tokenize a list of samples into a single padded ``ModelInputs``."""
        return ModelInputs(
            tokenizer(
                [b["text"] for b in batch],
                text_pair=[b["doc_topic"] for b in batch] if use_topics else None,
                padding=True,
                return_tensors="pt",
                truncation=True,
            )
        )

    logger.info(f"Creating dataloader with batch size {batch_size}")
    dataloader = torch.utils.data.DataLoader(
        BaseDataset(name="passage", data=samples),
        batch_size=batch_size,
        shuffle=False,
        num_workers=num_workers,
        pin_memory=False,
        collate_fn=collate_fn,
    )

    retrieved_per_sample = []
    with torch.inference_mode():
        # perf_counter is monotonic, so the measured duration cannot be
        # affected by system clock adjustments.
        start = time.perf_counter()
        for documents_batch in tqdm.tqdm(dataloader):
            retrieve_kwargs = {
                **documents_batch,
                "k": top_k,
                "precision": precision,
            }
            retrieved_per_sample.extend(retriever.retrieve(**retrieve_kwargs))
        end = time.perf_counter()

    # The dataloader is not shuffled, so the i-th retrieval result belongs to
    # the i-th sample and the two sequences can simply be zipped.
    output_data = []
    for sample, retrieved in zip(samples, retrieved_per_sample):
        sample["window_candidates"] = [
            c.label.split(" <def>", 1)[0] for c in retrieved
        ]
        sample["window_candidates_scores"] = [c.score for c in retrieved]
        output_data.append(sample)

    compute_retriever_stats(output_data)

    print(f"Retrieval took {end - start:.2f} seconds")
if __name__ == "__main__":
    # Command-line entry point. Every option defaults to the value that used
    # to be hard-coded here, so running the script without arguments behaves
    # exactly as before.
    arg_parser = argparse.ArgumentParser(
        description="Retrieve candidates for a windowed JSONL file and print recall."
    )
    arg_parser.add_argument(
        "--retriever_name_or_path",
        type=str,
        default="/root/relik-spaces/models/relik-retriever-small-aida-blink-pretrain-omniencoder/question_encoder",
    )
    arg_parser.add_argument(
        "--document_index_name_or_path",
        type=str,
        default="/root/relik-spaces/models/relik-retriever-small-aida-blink-pretrain-omniencoder/document_index_filtered",
    )
    arg_parser.add_argument(
        "--input_path",
        type=str,
        default="/root/relik-spaces/data/reader/aida/testa_windowed.jsonl",
    )
    arg_parser.add_argument("--batch_size", type=int, default=128)
    arg_parser.add_argument("--num_workers", type=int, default=4)
    arg_parser.add_argument("--device", type=str, default="cuda")
    arg_parser.add_argument("--precision", type=str, default="fp32")
    # Topics are enabled by default; this flag turns them off.
    arg_parser.add_argument(
        "--no_topics",
        dest="topics",
        action="store_false",
        help="do not pass the document topic to the tokenizer",
    )
    args = arg_parser.parse_args()

    add_candidates(
        args.retriever_name_or_path,
        args.document_index_name_or_path,
        args.input_path,
        batch_size=args.batch_size,
        num_workers=args.num_workers,
        device=args.device,
        precision=args.precision,
        topics=args.topics,
    )