File size: 3,876 Bytes
2f044c1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
# Write each run's outputs under ./experiments/<model_name>/<date>/<time>,
# making "experiments" the default output directory for the models.
hydra:
  run:
    dir: './experiments/${model_name}/${now:%Y-%m-%d}/${now:%H-%M-%S}'

# Name given to the model in wandb; resolved from `model.language_model`.
model_name: '${model.language_model}'
# Name given to the project in wandb.
project_name: relik-retriever

# Hydra defaults list. `_self_` is listed first, so — per Hydra's composition
# order — the config groups that follow are merged after this file's own values.
defaults:
  - _self_
  - model: golden_retriever
  - index: inmemory
  - loss: nce_loss
  - optimizer: radamw
  - scheduler: linear_scheduler
  - data: dataset_v2 # alternative noted by the author: iterable_in_batch_negatives
  - logging: wandb_logging
  # Select the `colorlog` option for both the job logging and Hydra logging groups.
  - override hydra/job_logging: colorlog
  - override hydra/hydra_logging: colorlog

train:
  # --- reproducibility ---
  seed: 42
  set_determinism_the_old_way: false
  # --- torch parameters ---
  float32_matmul_precision: 'medium'
  # When true, the model is only tested (no training).
  only_test: false
  # Optional checkpoint whose weights are used to initialize the model.
  pretrain_ckpt_path: null
  # Optional checkpoint from which training is started.
  checkpoint_path: null

  # --- task-specific parameter ---
  # Referenced as ${train.top_k} by the callbacks and monitored metric names in this file.
  top_k: 100

  # Arguments for the Lightning trainer named in `_target_`.
  pl_trainer:
    _target_: lightning.Trainer
    accelerator: gpu
    devices: 1
    num_nodes: 1
    strategy: auto
    accumulate_grad_batches: 1
    gradient_clip_val: 1.0
    # An int "n" here means: run validation every "n" steps.
    val_check_interval: 1.0
    check_val_every_n_epoch: 1
    # NOTE(review): 0 is presumably a "no epoch limit, rely on max_steps"
    # sentinel handled by the training script; Lightning itself documents -1
    # (not 0) as the value that disables the epoch limit — confirm.
    max_epochs: 0
    max_steps: 25000
    deterministic: true
    fast_dev_run: false
    precision: 16
    reload_dataloaders_every_n_epochs: 1

  # Early stopping on the validation recall@top_k metric (mode: max).
  # NOTE(review): the original author left `null` here as an alternative value,
  # presumably to disable early stopping — confirm against the training script.
  early_stopping_callback:
    _target_: lightning.callbacks.EarlyStopping
    monitor: 'validate_recall@${train.top_k}'
    mode: max
    patience: 3

  # Checkpointing driven by the validation recall@top_k metric (mode: max),
  # with save_top_k: 1 and save_last disabled.
  model_checkpoint_callback:
    _target_: lightning.callbacks.ModelCheckpoint
    monitor: 'validate_recall@${train.top_k}'
    mode: max
    verbose: true
    save_top_k: 1
    save_last: false
    # The metric name is already spelled out by hand in the filename template,
    # so automatic metric-name insertion is switched off below.
    filename: 'checkpoint-validate_recall@${train.top_k}_{validate_recall@${train.top_k}:.4f}-epoch_{epoch:02d}'
    auto_insert_metric_name: false

  callbacks:
    # Prediction callback configured with k = top_k; the entries nested under
    # `other_callbacks` are handed to it as part of its arguments.
    prediction_callback:
      _target_: relik.retriever.callbacks.training_callbacks.GoldenRetrieverPredictionCallback
      k: ${train.top_k}
      batch_size: 64
      precision: 16
      index_precision: 16
      other_callbacks:
        # Recall at k = top_k.
        - _target_: relik.retriever.callbacks.evaluation_callbacks.RecallAtKEvaluationCallback
          k: ${train.top_k}
          verbose: true
        # Recall at a fixed k = 50, with prog_bar disabled.
        - _target_: relik.retriever.callbacks.evaluation_callbacks.RecallAtKEvaluationCallback
          k: 50
          verbose: true
          prog_bar: false
        # Average ranking at k = top_k.
        - _target_: relik.retriever.callbacks.evaluation_callbacks.AvgRankingEvaluationCallback
          k: ${train.top_k}
          verbose: true
        - _target_: relik.retriever.callbacks.utils_callbacks.SavePredictionsCallback

    # Negative-augmentation callback, configured for the `validate` stage with
    # k = top_k.
    hard_negatives_callback:
      _target_: relik.retriever.callbacks.prediction_callbacks.NegativeAugmentationCallback
      k: ${train.top_k}
      batch_size: 64
      precision: 16
      index_precision: 16
      stages: [validate] # alternative: [validate, sanity_check]
      # Kept as a YAML sequence: with a bare scalar here, adding a second
      # `- metric` line would be silently folded into one long string.
      # NOTE(review): assumes the callback accepts a list of metric names, as
      # the plural key and the commented-out item suggest — confirm.
      metrics_to_monitor:
        - validate_recall@${train.top_k}
        # - sanity_check_recall@${train.top_k}
      threshold: 0.0
      max_negatives: 20
      add_with_probability: 1.0
      refresh_every_n_epochs: 1
      other_callbacks:
        # Average ranking at k = top_k, configured with the "train" prefix.
        - _target_: relik.retriever.callbacks.evaluation_callbacks.AvgRankingEvaluationCallback
          k: ${train.top_k}
          verbose: true
          prefix: "train"

    # Utility callbacks, listed by their `_target_` class paths.
    utils_callbacks:
      - _target_: relik.retriever.callbacks.utils_callbacks.SaveRetrieverCallback
      - _target_: relik.retriever.callbacks.utils_callbacks.FreeUpIndexerVRAMCallback
      # Disabled entry: uncomment to add ResetModelCallback, which takes the
      # question and passage encoders from the model config.
      # - _target_: relik.retriever.callbacks.utils_callbacks.ResetModelCallback
      #   question_encoder: ${model.pl_module.model.question_encoder}
      #   passage_encoder: ${model.pl_module.model.passage_encoder}