Birchlabs commited on
Commit
e6c58da
1 Parent(s): aef6d89

Fix RuntimeError: pad attn scores back to original query sequence length, instead of unpadded sequence length (i.e. no change).

Browse files

Prevents RuntimeError on line 382's pad_input(…).reshape()
shape '[1, 4096, 4096]' is invalid for input of size 9400320

before this change, pad_input() was basically just doing a .unsqueeze(0):
attn_output.shape
torch.Size([2295, 32, 128])
pad_input(attn_output, indices_q, bsz, max_seqlen_q).shape
torch.Size([1, 2295, 32, 128])

after this change: pad_input actually pads the input back to the original query sequence length:
pad_input(attn_output, indices_q, bsz, q_len).shape
torch.Size([1, 4096, 32, 128])
and the reshape succeeds:
pad_input(attn_output, indices_q, bsz, q_len).reshape(bsz, q_len, h_size).shape
torch.Size([1, 4096, 4096])

Files changed (1) hide show
  1. modeling_flash_llama.py +1 -1
modeling_flash_llama.py CHANGED
@@ -378,7 +378,7 @@ class LlamaAttention(nn.Module):
378
 
379
  attn_output = attn_outputs[0] if output_attentions else attn_outputs
380
  attn_output = pad_input(
381
- attn_output, indices_q, bsz, max_seqlen_q
382
  ).reshape(bsz, q_len, h_size)
383
  attn_weights = attn_outputs[2] if output_attentions else None
384
 
 
378
 
379
  attn_output = attn_outputs[0] if output_attentions else attn_outputs
380
  attn_output = pad_input(
381
+ attn_output, indices_q, bsz, q_len
382
  ).reshape(bsz, q_len, h_size)
383
  attn_weights = attn_outputs[2] if output_attentions else None
384