# Provenance: author riccorl, commit 626eca0 ("first commit"), 1.5 kB.
from pathlib import Path

from relik.retriever.trainer import RetrieverTrainer
from relik import GoldenRetriever
from relik.retriever.indexers.inmemory import InMemoryDocumentIndex
from relik.retriever.data.datasets import AidaInBatchNegativesDataset


def main(
    *,
    data_dir: str = "/root/golden-retriever-v2/data/dpr-like/el",
    question_encoder: str = "intfloat/e5-small-v2",
    device: str = "cuda",
    max_steps: int = 25_000,
) -> None:
    """Train a GoldenRetriever on the AIDA entity-linking data.

    Steps:

    1. Build an ``InMemoryDocumentIndex`` over ``<data_dir>/definitions.txt``.
    2. Wrap the index in a ``GoldenRetriever`` using ``question_encoder``.
    3. Load the train and validation splits from
       ``<data_dir>/aida_32_tokens_topic/{train,val}.jsonl`` as
       ``AidaInBatchNegativesDataset`` objects.
    4. Run ``RetrieverTrainer`` for ``max_steps`` steps with
       ``wandb_offline_mode=True``.

    Every default reproduces the original hard-coded configuration. Running
    the script unchanged therefore behaves exactly as before.

    Args:
        data_dir: Directory containing ``definitions.txt`` and the
            ``aida_32_tokens_topic`` split files.
        question_encoder: Encoder identifier passed to ``GoldenRetriever``
            (presumably a Hugging Face hub id; confirm).
        device: Device passed to the document index.
        max_steps: Number of training steps for ``RetrieverTrainer``.
    """
    data_path = Path(data_dir)
    split_path = data_path / "aida_32_tokens_topic"

    # instantiate retriever
    document_index = InMemoryDocumentIndex(
        # str() keeps the argument type identical to the original literal.
        documents=str(data_path / "definitions.txt"),
        device=device,
        # NOTE(review): "16" presumably selects half precision; confirm
        # against InMemoryDocumentIndex.
        precision="16",
    )
    retriever = GoldenRetriever(
        question_encoder=question_encoder, document_index=document_index
    )

    # Settings shared by both splits. Keeping them in one place stops the
    # train and validation datasets from silently drifting apart.
    dataset_kwargs = {
        "tokenizer": retriever.question_tokenizer,
        "question_batch_size": 64,
        "passage_batch_size": 400,
        "max_passage_length": 64,
        "use_topics": True,
    }
    train_dataset = AidaInBatchNegativesDataset(
        name="aida_train",
        path=str(split_path / "train.jsonl"),
        shuffle=True,
        **dataset_kwargs,
    )
    # `shuffle` is deliberately not passed here, matching the original. The
    # validation split uses the dataset's default (presumably unshuffled;
    # confirm).
    val_dataset = AidaInBatchNegativesDataset(
        name="aida_val",
        path=str(split_path / "val.jsonl"),
        **dataset_kwargs,
    )

    trainer = RetrieverTrainer(
        retriever=retriever,
        train_dataset=train_dataset,
        val_dataset=val_dataset,
        max_steps=max_steps,
        wandb_offline_mode=True,
    )
    trainer.train()


if __name__ == "__main__":
    main()