# Hydra configuration for training the ReLiK retriever.
#
# Layout:
#   hydra     - where Hydra writes the outputs of each run
#   defaults  - config groups composed into this file
#   train     - seed, Lightning trainer arguments, checkpointing and callbacks

# Required to make the "experiments" dir the default one for the output of
# the models
hydra:
  run:
    dir: ./experiments/${model_name}/${now:%Y-%m-%d}/${now:%H-%M-%S}

# used to name the model in wandb
model_name: ${model.language_model}
# used to name the project in wandb
project_name: relik-retriever

# `_self_` comes first, so the config groups listed after it are merged on top
# of the values defined in this file.
defaults:
  - _self_
  - model: golden_retriever
  - index: inmemory
  - loss: nce_loss
  - optimizer: radamw
  - scheduler: linear_scheduler
  - data: dataset_v2
  - logging: wandb_logging
  - override hydra/job_logging: colorlog
  - override hydra/hydra_logging: colorlog

train:
  # reproducibility
  seed: 42
  set_determinism_the_old_way: False
  # torch parameters
  float32_matmul_precision: "medium"
  # if true, only test the model
  only_test: False
  # if provided, initialize the model with the weights from the checkpoint
  pretrain_ckpt_path: null
  # if provided, start training from the checkpoint
  checkpoint_path: null

  # task specific parameter
  # (interpolated below as ${train.top_k} for the monitored metric and the
  # `k` of the callbacks)
  top_k: 100

  # pl_trainer
  pl_trainer:
    _target_: lightning.Trainer
    accelerator: gpu
    devices: 1
    num_nodes: 1
    strategy: auto
    accumulate_grad_batches: 1
    gradient_clip_val: 1.0
    # you can specify an int "n" here => validation every "n" steps
    val_check_interval: 1.0
    check_val_every_n_epoch: 1
    # NOTE(review): Lightning documents -1 as the value that disables the
    # epoch limit; confirm how the training script treats 0 together with
    # max_steps.
    max_epochs: 0
    max_steps: 220_000
    deterministic: True
    fast_dev_run: False
    precision: 16
    reload_dataloaders_every_n_epochs: 1

  # Early stopping is disabled; the commented keys below are the
  # configuration that would enable it.
  early_stopping_callback: null
  # _target_: lightning.callbacks.EarlyStopping
  # monitor: validate_recall@${train.top_k}
  # mode: max
  # patience: 15

  # Keeps the single best checkpoint according to the monitored metric
  # (mode: max, save_top_k: 1) and also the last one (save_last: True).
  model_checkpoint_callback:
    _target_: lightning.callbacks.ModelCheckpoint
    monitor: validate_recall@${train.top_k}
    mode: max
    verbose: True
    save_top_k: 1
    save_last: True
    filename: "checkpoint-validate_recall@${train.top_k}_{validate_recall@${train.top_k}:.4f}-epoch_{epoch:02d}"
    auto_insert_metric_name: False

  callbacks:
    # Prediction callback; the entries in `other_callbacks` are nested
    # evaluation / utility callbacks attached to it.
    prediction_callback:
      _target_: relik.retriever.callbacks.training_callbacks.GoldenRetrieverPredictionCallback
      k: ${train.top_k}
      batch_size: 128
      precision: 16
      index_precision: 16
      other_callbacks:
        - _target_: relik.retriever.callbacks.evaluation_callbacks.RecallAtKEvaluationCallback
          k: ${train.top_k}
          verbose: True
        - _target_: relik.retriever.callbacks.evaluation_callbacks.RecallAtKEvaluationCallback
          k: 50
          verbose: True
          prog_bar: False
        - _target_: relik.retriever.callbacks.evaluation_callbacks.AvgRankingEvaluationCallback
          k: ${train.top_k}
          verbose: True
        - _target_: relik.retriever.callbacks.utils_callbacks.SavePredictionsCallback

    # Negative-augmentation callback; see NegativeAugmentationCallback for the
    # exact meaning of threshold / max_negatives / add_with_probability.
    hard_negatives_callback:
      _target_: relik.retriever.callbacks.prediction_callbacks.NegativeAugmentationCallback
      k: ${train.top_k}
      batch_size: 128
      precision: 16
      index_precision: 16
      # alternative: [validate, sanity_check]
      stages: [validate]
      metrics_to_monitor: validate_recall@${train.top_k}
      # - sanity_check_recall@${train.top_k}
      threshold: 0.0
      max_negatives: 15
      add_with_probability: 0.2
      refresh_every_n_epochs: 1
      other_callbacks:
        - _target_: relik.retriever.callbacks.evaluation_callbacks.AvgRankingEvaluationCallback
          k: ${train.top_k}
          verbose: True
          prefix: "train"

    utils_callbacks:
      - _target_: relik.retriever.callbacks.utils_callbacks.SaveRetrieverCallback
      - _target_: relik.retriever.callbacks.utils_callbacks.FreeUpIndexerVRAMCallback