Weird Behavior of Eager Attention Implementation during Fine-tuning

#5
by delinqu - opened

Hi!
When fine-tuning PaliGemma 2, I followed the Jupyter fine-tuning tutorial and used eager attention, as recommended for Gemma 2. However, training converged very quickly: the model reached nearly 100% accuracy on the training set but performed poorly on the test set. In contrast, with flash_attention_2 the model converged more slowly but performed well on the test set. Why does Gemma 2 recommend eager, and what explains this difference? I'd greatly appreciate any insights you can provide.

  • eager: image.png

  • flash_attention_2: image.png

Google org

Hi @delinqu ,

With eager attention, training converged very quickly and the model performed well on the training data but not on the test data, which indicates overfitting. This may happen because eager attention explicitly computes and stores the full attention matrix, which can overfit noise or minor patterns in the training data.

With flash_attention_2, training converged more slowly and the model performed well on both the training and test sets, which indicates no overfitting. FlashAttention computes attention incrementally, in blocks, which may reduce the chance of overfitting: the model has less capacity to memorize specific patterns from the training data, and the training dynamics encourage it to focus on broader patterns that generalize better.
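To make the "full matrix" versus "incrementally and in blocks" distinction concrete, here is a minimal pure-Python sketch for a single query vector. It is an illustration only, not the actual FlashAttention kernel: the function names, the block size, and the toy inputs are all assumptions made for this example. The blockwise version keeps a running maximum and running sum (an online softmax) instead of holding every score at once.

```python
import math

def full_attention(q, keys, values):
    """Eager-style: compute every score first, then softmax over all of them."""
    scale = 1.0 / math.sqrt(len(q))
    scores = [scale * sum(a * b for a, b in zip(q, k)) for k in keys]
    top = max(scores)
    weights = [math.exp(s - top) for s in scores]
    total = sum(weights)
    dim = len(values[0])
    return [sum(w * v[j] for w, v in zip(weights, values)) / total
            for j in range(dim)]

def blockwise_attention(q, keys, values, block_size=2):
    """Block-style: visit keys/values in chunks, rescaling running statistics."""
    scale = 1.0 / math.sqrt(len(q))
    dim = len(values[0])
    running_max = float("-inf")
    running_sum = 0.0
    acc = [0.0] * dim  # unnormalized weighted sum of values seen so far
    for start in range(0, len(keys), block_size):
        k_blk = keys[start:start + block_size]
        v_blk = values[start:start + block_size]
        scores = [scale * sum(a * b for a, b in zip(q, k)) for k in k_blk]
        new_max = max(running_max, max(scores))
        # Rescale earlier contributions to the new maximum (0.0 on the first block).
        correction = math.exp(running_max - new_max)
        running_sum *= correction
        acc = [a * correction for a in acc]
        for s, v in zip(scores, v_blk):
            w = math.exp(s - new_max)
            running_sum += w
            acc = [a + w * vj for a, vj in zip(acc, v)]
        running_max = new_max
    return [a / running_sum for a in acc]

# Toy inputs (hypothetical values chosen only for the demonstration).
q = [0.5, -1.0, 0.25]
keys = [[1.0, 0.0, 2.0], [0.5, 0.5, -1.0], [-2.0, 1.0, 0.0],
        [0.0, 3.0, 1.0], [1.5, -0.5, 0.5]]
values = [[1.0, 2.0], [0.0, -1.0], [3.0, 0.5], [-2.0, 4.0], [0.5, 0.5]]

out_full = full_attention(q, keys, values)
out_block = blockwise_attention(q, keys, values, block_size=2)
```

In this toy case the two functions agree up to floating-point rounding; what differs is the order of operations and how much intermediate state is kept in memory at once.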

Gemma 2 recommends using eager attention for batch inference with bfloat16 precision. This is because, when using the default attention mechanism (torch.nn.functional.scaled_dot_product_attention) with bfloat16, the model can produce NaN values for input sequences that contain padding. For more details, please refer to this reference.
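As a rough illustration of how padding can lead to NaN values, the pure-Python sketch below applies a masked softmax to one row of attention scores. It assumes the failure mode is a row in which every key position is masked out (for example, a padding token that may attend to nothing); that mechanism is an assumption for this example, not something confirmed in this thread, and the code does not reproduce the PyTorch kernel itself. The helper `masked_softmax` is hypothetical.

```python
import math

NEG_INF = float("-inf")

def masked_softmax(scores, keep):
    """Softmax over one row of scores; positions with keep=False are masked."""
    masked = [s if k else NEG_INF for s, k in zip(scores, keep)]
    top = max(masked)
    # If every position is masked, top is -inf and (-inf) - (-inf) is NaN.
    exps = [math.exp(s - top) for s in masked]
    total = sum(exps)
    return [e / total for e in exps]

scores = [0.3, -1.2, 2.0]

# A normal row: at least one key is visible, so the weights are finite.
partly_masked = masked_softmax(scores, keep=[True, True, False])

# A fully masked row: every weight becomes NaN.
fully_masked = masked_softmax(scores, keep=[False, False, False])
```

Once a single row holds NaN weights, the NaNs propagate through the weighted sum of values and into later layers, which is why one padded sequence in a batch can corrupt the output for that sequence.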

Thank you.

delinqu changed discussion status to closed
