import argparse
import sys

from relik import GoldenRetriever
from relik.retriever.data.datasets import AidaInBatchNegativesDataset
from relik.retriever.indexers.inmemory import InMemoryDocumentIndex
from relik.retriever.trainer import RetrieverTrainer
# Defaults reproduce the original hard-coded experiment configuration, so
# running the script with no arguments behaves exactly as before.
DEFAULT_DOCUMENTS = "/root/golden-retriever-v2/data/dpr-like/el/definitions.txt"
DEFAULT_TRAIN_PATH = (
    "/root/golden-retriever-v2/data/dpr-like/el/aida_32_tokens_topic/train.jsonl"
)
DEFAULT_VAL_PATH = (
    "/root/golden-retriever-v2/data/dpr-like/el/aida_32_tokens_topic/val.jsonl"
)
DEFAULT_QUESTION_ENCODER = "intfloat/e5-small-v2"


def parse_args(argv=None):
    """Parse the training configuration from the command line.

    Every option defaults to the value the script originally hard-coded.

    Args:
        argv: Argument list to parse; ``None`` means ``sys.argv[1:]``.

    Returns:
        An ``argparse.Namespace`` holding the training configuration.
    """
    parser = argparse.ArgumentParser(
        description="Train a GoldenRetriever on the AIDA dataset."
    )
    parser.add_argument(
        "--documents",
        default=DEFAULT_DOCUMENTS,
        help="Documents file passed to InMemoryDocumentIndex.",
    )
    parser.add_argument(
        "--train-path", default=DEFAULT_TRAIN_PATH, help="Training split (.jsonl)."
    )
    parser.add_argument(
        "--val-path", default=DEFAULT_VAL_PATH, help="Validation split (.jsonl)."
    )
    parser.add_argument(
        "--question-encoder",
        default=DEFAULT_QUESTION_ENCODER,
        help="Model name passed to GoldenRetriever as question_encoder.",
    )
    parser.add_argument("--device", default="cuda", help="Document-index device.")
    parser.add_argument(
        "--precision", default="16", help="Document-index precision (as a string)."
    )
    parser.add_argument("--question-batch-size", type=int, default=64)
    parser.add_argument("--passage-batch-size", type=int, default=400)
    parser.add_argument("--max-passage-length", type=int, default=64)
    parser.add_argument("--max-steps", type=int, default=25_000)
    parser.add_argument(
        "--wandb-online",
        action="store_true",
        help="Disable W&B offline mode (the script runs offline by default).",
    )
    # The original script ignored argv entirely; tolerate extra arguments
    # (e.g. injected by a launcher) instead of aborting, but make them visible.
    args, unknown = parser.parse_known_args(argv)
    if unknown:
        print(f"Ignoring unrecognized arguments: {unknown}", file=sys.stderr)
    return args


def main(argv=None):
    """Build the document index, retriever and AIDA datasets, then train.

    Args:
        argv: Command-line arguments to parse; ``None`` reads ``sys.argv[1:]``.
    """
    args = parse_args(argv)

    document_index = InMemoryDocumentIndex(
        documents=args.documents,
        device=args.device,
        precision=args.precision,
    )
    retriever = GoldenRetriever(
        question_encoder=args.question_encoder, document_index=document_index
    )

    # Both splits share every setting except name, path and shuffling; build
    # the common kwargs once so the two datasets cannot drift apart.
    dataset_kwargs = {
        "tokenizer": retriever.question_tokenizer,
        "question_batch_size": args.question_batch_size,
        "passage_batch_size": args.passage_batch_size,
        "max_passage_length": args.max_passage_length,
        "use_topics": True,
    }
    train_dataset = AidaInBatchNegativesDataset(
        name="aida_train", path=args.train_path, shuffle=True, **dataset_kwargs
    )
    # As in the original, the validation split leaves `shuffle` at the
    # dataset class's own default.
    val_dataset = AidaInBatchNegativesDataset(
        name="aida_val", path=args.val_path, **dataset_kwargs
    )

    trainer = RetrieverTrainer(
        retriever=retriever,
        train_dataset=train_dataset,
        val_dataset=val_dataset,
        max_steps=args.max_steps,
        wandb_offline_mode=not args.wandb_online,
    )
    trainer.train()


if __name__ == "__main__":
    main()
|