The dtype of attention mask should match the qkv states

#3
by Yiming1894 - opened

The code snippet below highlights a potential risk in mixed-precision training/inference:

with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=True, enable_mem_efficient=True):
    attn_output = F.scaled_dot_product_attention(query_states, key_states, value_states, attn_mask=attention_mask)

The q/k/v states may be in half precision (float16 or bfloat16), depending on how the accelerator or Seq2SeqTrainer is configured. However, the attn_mask is always built in float32, with masked positions filled with torch.finfo(torch.float32).min. When that value is cast down to the half-precision compute dtype, it overflows to -inf, and a row whose entries are all -inf makes the softmax return NaN.
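A minimal sketch of the failure mode, assuming float16 q/k/v and a float32 mask built with torch.finfo(torch.float32).min (the tensors below are illustrative, not taken from the model):

import torch

# A float32 mask built the usual way: masked positions hold the float32 minimum.
mask = torch.tensor([0.0, torch.finfo(torch.float32).min], dtype=torch.float32)

# Casting down to the qkv dtype overflows, because the float32 minimum is far
# below torch.finfo(torch.float16).min.
print(mask.to(torch.float16))  # tensor([0., -inf], dtype=torch.float16)

# A fully masked row (e.g. a padding position) then has every score at -inf,
# and its softmax degenerates to 0/0 = NaN.
scores = torch.full((4,), float("-inf"))
print(torch.softmax(scores, dim=-1))  # tensor([nan, nan, nan, nan])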

Following Baichuan-7B (https://huggingface.co/baichuan-inc/Baichuan-7B/blob/c1a5c7d5b7f50ecc51bb0e08150a9f12e5656756/modeling_baichuan.py#L229), a feasible fix could be:

with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=True, enable_mem_efficient=True):
    attention_mask = torch.max(attention_mask, torch.tensor(torch.finfo(query_states.dtype).min))
    attn_output = F.scaled_dot_product_attention(query_states, key_states, value_states, attn_mask=attention_mask)
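Clamping the mask with torch.max keeps masked positions at the most negative value representable in the query dtype, rather than letting the float32 minimum overflow to -inf after the downcast, so even fully masked rows keep finite scores and the softmax stays well-defined.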
