# --- File-viewer residue captured together with this file (not configuration) ---
# riccorl's picture
# first commit
# 626eca0
# raw
# history blame
# 3.88 kB
# Required to make the "experiments" dir the default one for the output of the models.
# The run directory is composed from `model_name` and two `now` interpolations
# (a date component followed by a time component).
hydra:
  run:
    dir: ./experiments/${model_name}/${now:%Y-%m-%d}/${now:%H-%M-%S}
# Used to name the model in wandb; taken from the `language_model` value of
# the `model` config.
model_name: ${model.language_model}
# Used to name the project in wandb.
project_name: relik-retriever
# Hydra defaults list: one config-group selection per entry.
# `_self_` is listed first, i.e. this file's own values are composed before
# the groups selected below.
defaults:
  - _self_
  - model: golden_retriever
  - index: inmemory
  - loss: nce_loss
  - optimizer: radamw
  - scheduler: linear_scheduler
  # options noted by the original author: iterable_in_batch_negatives, dataset_v2
  - data: dataset_v2
  - logging: wandb_logging
  # use the colorlog variants for Hydra's job and internal logging
  - override hydra/job_logging: colorlog
  - override hydra/hydra_logging: colorlog
# Training configuration. Several values below refer back to this section
# through `${train.top_k}` interpolations, so `top_k` must stay nested here.
train:
  # reproducibility
  seed: 42
  set_determinism_the_old_way: false

  # torch parameters
  float32_matmul_precision: "medium"

  # if true, only test the model
  only_test: false
  # if provided, initialize the model with the weights from the checkpoint
  pretrain_ckpt_path: null
  # if provided, start training from the checkpoint
  checkpoint_path: null

  # task specific parameter: the "k" reused by the monitored metric names
  # and by the callbacks configured below
  top_k: 100

  # pl_trainer: settings for the class named in `_target_`
  pl_trainer:
    _target_: lightning.Trainer
    accelerator: gpu
    devices: 1
    num_nodes: 1
    strategy: auto
    accumulate_grad_batches: 1
    gradient_clip_val: 1.0
    # you can specify an int "n" here => validation every "n" steps
    val_check_interval: 1.0
    check_val_every_n_epoch: 1
    # NOTE(review): 0 next to a positive `max_steps` presumably means
    # "train by step count" - confirm how the training script treats it
    max_epochs: 0
    max_steps: 25_000
    deterministic: true
    fast_dev_run: false
    precision: 16
    reload_dataloaders_every_n_epochs: 1

  # Early stopping on the validation recall@top_k (higher is better).
  # NOTE(review): the original carried a bare `# null` marker here, presumably
  # meaning the whole key can be set to null to disable it - confirm.
  early_stopping_callback:
    _target_: lightning.callbacks.EarlyStopping
    monitor: validate_recall@${train.top_k}
    mode: max
    patience: 3

  # Checkpointing driven by the same metric as early stopping.
  model_checkpoint_callback:
    _target_: lightning.callbacks.ModelCheckpoint
    monitor: validate_recall@${train.top_k}
    mode: max
    verbose: true
    save_top_k: 1
    save_last: false
    filename: "checkpoint-validate_recall@${train.top_k}_{validate_recall@${train.top_k}:.4f}-epoch_{epoch:02d}"
    auto_insert_metric_name: false

  callbacks:
    # Prediction callback plus the evaluation callbacks attached to it.
    prediction_callback:
      _target_: relik.retriever.callbacks.prediction_callbacks.GoldenRetrieverPredictionCallback
      k: ${train.top_k}
      batch_size: 64
      precision: 16
      index_precision: 16
      other_callbacks:
        - _target_: relik.retriever.callbacks.evaluation_callbacks.RecallAtKEvaluationCallback
          k: ${train.top_k}
          verbose: true
        - _target_: relik.retriever.callbacks.evaluation_callbacks.RecallAtKEvaluationCallback
          k: 50
          verbose: true
          prog_bar: false
        - _target_: relik.retriever.callbacks.evaluation_callbacks.AvgRankingEvaluationCallback
          k: ${train.top_k}
          verbose: true
        - _target_: relik.retriever.callbacks.utils_callbacks.SavePredictionsCallback

    # Negative-augmentation callback settings.
    hard_negatives_callback:
      _target_: relik.retriever.callbacks.prediction_callbacks.NegativeAugmentationCallback
      k: ${train.top_k}
      batch_size: 64
      precision: 16
      index_precision: 16
      stages: [validate]  # [validate, sanity_check]
      # Kept as a single plain string, exactly as in the original.
      # NOTE(review): the commented alternative below uses list syntax -
      # confirm which form the callback expects before enabling it.
      metrics_to_monitor: validate_recall@${train.top_k}
      # - sanity_check_recall@${train.top_k}
      threshold: 0.0
      max_negatives: 20
      add_with_probability: 1.0
      refresh_every_n_epochs: 1
      other_callbacks:
        - _target_: relik.retriever.callbacks.evaluation_callbacks.AvgRankingEvaluationCallback
          k: ${train.top_k}
          verbose: true
          prefix: "train"

    # Utility callbacks; the last entry is disabled.
    utils_callbacks:
      - _target_: relik.retriever.callbacks.utils_callbacks.SaveRetrieverCallback
      - _target_: relik.retriever.callbacks.utils_callbacks.FreeUpIndexerVRAMCallback
      # - _target_: relik.retriever.callbacks.utils_callbacks.ResetModelCallback
      #   question_encoder: ${model.pl_module.model.question_encoder}
      #   passage_encoder: ${model.pl_module.model.passage_encoder}