Fine-tuning memory requirements?
I am trying to fine-tune this model using DeepSpeed, as suggested in the model's repo: https://github.com/salesforce/jaxformer#a100-fine-tune
I have tried on up to 4 x A100 GPUs with a total of 360 GB of system RAM, but the training crashes every time before it even starts, right after memory fills up completely (monitored with htop).
How much memory do I need to fine-tune this?
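For context, here is my rough back-of-envelope estimate of the persistent training state for full fine-tuning with Adam (just a sketch; the 16B parameter count is my assumption, adjust for the actual checkpoint):

```python
# Back-of-envelope sketch: persistent training state for full fine-tuning
# with Adam in fp32, ignoring activations and framework overhead.
# The 16e9 parameter count is an assumption; adjust for your checkpoint.
params = 16e9
bytes_fp32 = 4
weights = params * bytes_fp32           # model weights
grads = params * bytes_fp32             # gradients
adam_states = params * 2 * bytes_fp32   # Adam first and second moments

total_gib = (weights + grads + adam_states) / 2**30
print(f"~{total_gib:.0f} GiB of training state")  # ~238 GiB for 16B params
```

That ignores activations and the transient peaks during loading, so exhausting 360 GB does not seem impossible to me.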
Here is a configuration for DeepSpeed which should fit on a single A100 with CPU offloading; however, it may be slow:
https://github.com/salesforce/jaxformer/blob/main/jaxformer/hf/train.py
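For illustration, a minimal sketch of what ZeRO-3 with CPU offloading can look like. This is not the exact config in train.py; the checkpoint name, batch sizes, and learning rate are placeholders, and it assumes a DeepSpeed version whose `initialize()` accepts a config dict:

```python
# Minimal sketch of DeepSpeed ZeRO-3 with CPU offloading -- not the exact
# config from train.py. Checkpoint name, batch sizes, and lr are placeholders.
import deepspeed
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("Salesforce/codegen-16B-mono")

ds_config = {
    "train_micro_batch_size_per_gpu": 1,
    "gradient_accumulation_steps": 8,
    "fp16": {"enabled": True},
    "optimizer": {"type": "AdamW", "params": {"lr": 1e-5}},
    "zero_optimization": {
        "stage": 3,
        # Both sections below must be present for CPU offloading to be active.
        "offload_param": {"device": "cpu", "pin_memory": True},
        "offload_optimizer": {"device": "cpu", "pin_memory": True},
    },
}

engine, optimizer, _, _ = deepspeed.initialize(
    model=model,
    model_parameters=model.parameters(),
    config=ds_config,
)
```

Launching through the `deepspeed` CLI handles the distributed initialization for you.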
Thanks for replying @enijkamp. This is exactly what I am trying to use (with my own training data, a longer run, and checkpoint saving), but as I said above, just loading the model uses more than 360 GB of RAM.
I am not sure whether I am actually activating CPU offloading, though. Are the default params in that file enough to enable it?
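For anyone else debugging this, here is the quick sanity check I am using (a sketch; `ds_config.json` is a placeholder path for whatever config ends up being passed to `deepspeed.initialize`):

```python
import json

# Quick sanity check (sketch): offloading is only active if the ZeRO section
# of the config actually contains the offload entries.
# "ds_config.json" is a placeholder for whatever config train.py builds.
with open("ds_config.json") as f:
    ds_config = json.load(f)

zero = ds_config.get("zero_optimization", {})
print("ZeRO stage:       ", zero.get("stage"))
print("param offload:    ", zero.get("offload_param", {}).get("device"))      # want "cpu"
print("optimizer offload:", zero.get("offload_optimizer", {}).get("device"))  # want "cpu"
```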