sywangyi committed
Commit 72fc4ea
1 Parent(s): 79ec93c

directly use bool instead of torch.float16 to avoid a crash on ASICs like HPU that do not support float16

Files changed (1)
  1. attention.py  +2 -3
attention.py CHANGED
@@ -46,9 +46,8 @@ def scaled_multihead_dot_product_attention(query, key, value, n_heads, past_key_
         attn_weight = attn_weight.masked_fill(~key_padding_mask.view((b, 1, 1, s_k)), min_val)
     if is_causal and (not q.size(2) == 1):
         s = max(s_q, s_k)
-        causal_mask = attn_weight.new_ones(s, s, dtype=torch.float16)
+        causal_mask = attn_weight.new_ones(s, s, dtype=torch.bool)
         causal_mask = causal_mask.tril()
-        causal_mask = causal_mask.to(torch.bool)
         causal_mask = ~causal_mask
         causal_mask = causal_mask[-s_q:, -s_k:]
         attn_weight = attn_weight.masked_fill(causal_mask.view(1, 1, s_q, s_k), min_val)
@@ -297,4 +296,4 @@ def build_alibi_bias(n_heads, seq_len, full=False, alibi_bias_max=8, device=None
     slopes = gen_slopes(n_heads, alibi_bias_max, device=device)
     alibi_bias = alibi_bias * slopes
     return alibi_bias.to(dtype=dtype)
-ATTN_CLASS_REGISTRY = {'multihead_attention': MultiheadAttention, 'multiquery_attention': MultiQueryAttention}
+ATTN_CLASS_REGISTRY = {'multihead_attention': MultiheadAttention, 'multiquery_attention': MultiQueryAttention}
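
For context, here is a minimal standalone sketch of the masking pattern after this patch. It is not part of attention.py; the helper name apply_causal_mask and the toy shapes are assumptions for illustration. Building the mask directly as a bool tensor means no float16 intermediate is ever allocated, so accelerators without float16 support (such as HPU) never see that dtype, and the separate .to(torch.bool) conversion removed by the diff is no longer needed.

# Standalone sketch (not part of the repository); names and shapes are illustrative.
import torch

def apply_causal_mask(attn_weight, s_q, s_k, min_val=float('-inf')):
    # Build the mask directly as bool: no float16 intermediate is allocated,
    # so devices without float16 support never touch that dtype.
    s = max(s_q, s_k)
    causal_mask = attn_weight.new_ones(s, s, dtype=torch.bool)
    causal_mask = causal_mask.tril()           # True at positions a query may attend to
    causal_mask = ~causal_mask                 # flip: True where attention is forbidden
    causal_mask = causal_mask[-s_q:, -s_k:]    # align to the current query window
    return attn_weight.masked_fill(causal_mask.view(1, 1, s_q, s_k), min_val)

# Toy usage: batch 1, 2 heads, 4 queries, 4 keys.
scores = torch.zeros(1, 2, 4, 4)
masked = apply_causal_mask(scores, 4, 4)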