from relik.retriever.trainer import RetrieverTrainer
from relik import GoldenRetriever
from relik.retriever.indexers.inmemory import InMemoryDocumentIndex
from relik.retriever.data.datasets import AidaInBatchNegativesDataset

if __name__ == "__main__":
    # Script entry point: build a retriever over an in-memory document index,
    # wrap the AIDA train/val splits as in-batch-negatives datasets, and run
    # the trainer for a fixed number of steps.

    # Input locations (absolute paths on the training machine).
    definitions_file = "/root/golden-retriever-v2/data/dpr-like/el/definitions.txt"
    train_file = (
        "/root/golden-retriever-v2/data/dpr-like/el/aida_32_tokens_topic/train.jsonl"
    )
    val_file = (
        "/root/golden-retriever-v2/data/dpr-like/el/aida_32_tokens_topic/val.jsonl"
    )

    # Settings that are identical for the training and validation datasets.
    common_dataset_options = {
        "question_batch_size": 64,
        "passage_batch_size": 400,
        "max_passage_length": 64,
        "use_topics": True,
    }

    # The index is created first because the retriever is constructed on top
    # of it.
    index = InMemoryDocumentIndex(
        documents=definitions_file,
        device="cuda",
        precision="16",
    )
    model = GoldenRetriever(
        question_encoder="intfloat/e5-small-v2",
        document_index=index,
    )

    # Only the training split is shuffled; validation keeps the dataset's
    # default for `shuffle`.
    aida_train = AidaInBatchNegativesDataset(
        name="aida_train",
        path=train_file,
        tokenizer=model.question_tokenizer,
        **common_dataset_options,
        shuffle=True,
    )
    aida_val = AidaInBatchNegativesDataset(
        name="aida_val",
        path=val_file,
        tokenizer=model.question_tokenizer,
        **common_dataset_options,
    )

    # Train for at most 25k steps, with Weights & Biases in offline mode.
    runner = RetrieverTrainer(
        retriever=model,
        train_dataset=aida_train,
        val_dataset=aida_val,
        max_steps=25_000,
        wandb_offline_mode=True,
    )
    runner.train()