feat: selective activation checkpointing

#16 opened by Markus28

This PR hasn't been tested yet

This PR adds selective activation checkpointing to the BERT model.
By passing activation_checkpoint_lvl in the config, you can set how many of the BERT layers will be checkpointed once gradient_checkpointing_enable() is called. Reducing this number saves computation at the cost of increased VRAM usage, since fewer layers recompute their activations in the backward pass and the remaining layers keep their activations in memory. Checkpointing does not take effect until gradient_checkpointing_enable() is called.
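As a usage sketch (activation_checkpoint_lvl is the config field added in this PR; the import path and class names below are illustrative and assume the model exposes the usual Hugging Face-style gradient_checkpointing_enable()):

```python
# Illustrative only: the actual config/model classes depend on this repo's API.
from bert import BertConfig, BertModel  # hypothetical import path

config = BertConfig(
    num_hidden_layers=12,
    activation_checkpoint_lvl=6,  # checkpoint the first 6 of the 12 layers
)
model = BertModel(config)

# Nothing is checkpointed yet; checkpointing only takes effect after this call.
model.gradient_checkpointing_enable()
```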

By default, the value is 100, which means that for any reasonable architecture, all layers will be checkpointed. For the base model, it might make sense to set this to something like 6 to checkpoint half of the layers.
We enforce that MLP checkpointing cannot occur within a checkpointed layer.
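A minimal sketch of the intended behavior (not the PR's actual code; the attribute names self.layers, self.gradient_checkpointing, mlp_checkpoint_lvl, and config.activation_checkpoint_lvl are assumptions): layers whose index is below activation_checkpoint_lvl are wrapped in torch.utils.checkpoint once gradient checkpointing is enabled, and any MLP-level checkpointing is disabled inside such a layer so activations are not checkpointed twice.

```python
import torch.utils.checkpoint as cp

def encoder_forward(self, hidden_states):
    for idx, layer in enumerate(self.layers):
        checkpoint_layer = (
            self.gradient_checkpointing  # assumed flag set by gradient_checkpointing_enable()
            and self.training
            and idx < self.config.activation_checkpoint_lvl
        )
        if checkpoint_layer:
            # Whole-layer checkpointing; turn off MLP-level checkpointing
            # (hypothetical attribute) so it cannot occur within a checkpointed layer.
            layer.mlp_checkpoint_lvl = 0
            hidden_states = cp.checkpoint(layer, hidden_states, use_reentrant=False)
        else:
            hidden_states = layer(hidden_states)
    return hidden_states
```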

For pretraining, I think it would make sense to set this parameter to 0; nothing should happen before gradient_checkpointing_enable() is called anyway, but better safe than sorry.
