Fix RuntimeError: pad the attention output back to the original query sequence length, instead of the unpadded sequence length (which amounts to no change).
Prevents a RuntimeError in line 382's pad_input(…).reshape():
shape '[1, 4096, 4096]' is invalid for input of size 9400320
Before this change, pad_input() was effectively just doing a .unsqueeze(0):
attn_output.shape
torch.Size([2295, 32, 128])
pad_input(attn_output, indices_q, bsz, max_seqlen_q).shape
torch.Size([1, 2295, 32, 128])
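To make the shapes concrete, here is a minimal pure-PyTorch sketch of pad_input's scatter-back behaviour (pad_input_sketch is a hypothetical stand-in written for illustration, not the flash_attn implementation); the tensor sizes are the ones from the traceback above, assuming a single right-padded sequence with 2295 non-pad tokens:

import torch

def pad_input_sketch(hidden_states, indices, batch, seqlen):
    # Illustration-only stand-in for flash_attn.bert_padding.pad_input:
    # scatter the unpadded rows back into a zero-filled (batch * seqlen) buffer.
    out = torch.zeros(batch * seqlen, *hidden_states.shape[1:], dtype=hidden_states.dtype)
    out[indices] = hidden_states
    return out.reshape(batch, seqlen, *hidden_states.shape[1:])

bsz, q_len, n_heads, head_dim = 1, 4096, 32, 128  # sizes from the traceback above
n_tokens = max_seqlen_q = 2295                    # non-pad tokens in the single sequence

attn_output = torch.randn(n_tokens, n_heads, head_dim)
indices_q = torch.arange(n_tokens)                # non-pad positions in the flattened (bsz * q_len) layout

# Old call: seqlen = max_seqlen_q, so every slot is filled and the result is just unsqueeze(0)
print(pad_input_sketch(attn_output, indices_q, bsz, max_seqlen_q).shape)  # torch.Size([1, 2295, 32, 128])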
After this change, pad_input() actually pads its input back to the original query sequence length:
pad_input(attn_output, indices_q, bsz, q_len).shape
torch.Size([1, 4096, 32, 128])
and the reshape succeeds:
pad_input(attn_output, indices_q, bsz, q_len).reshape(bsz, q_len, h_size).shape
torch.Size([1, 4096, 4096])
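Continuing the sketch above, the element counts explain both the error and the fix: the unpadded output has 2295 × 32 × 128 = 9,400,320 elements, which cannot be viewed as [1, 4096, 4096] (16,777,216 elements), while the re-padded output can:

h_size = n_heads * head_dim  # 32 * 128 = 4096

# New call: seqlen = q_len, so the output is scattered back to the full padded query length
padded = pad_input_sketch(attn_output, indices_q, bsz, q_len)
print(padded.shape)                              # torch.Size([1, 4096, 32, 128])
print(padded.reshape(bsz, q_len, h_size).shape)  # torch.Size([1, 4096, 4096])

# The old result cannot be reshaped that way: 2295 * 32 * 128 = 9400320 elements,
# but a [1, 4096, 4096] view needs 4096 * 4096 = 16777216 -- hence the RuntimeError quoted above.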
modeling_flash_llama.py  +1 -1

@@ -378,7 +378,7 @@ class LlamaAttention(nn.Module):
 
         attn_output = attn_outputs[0] if output_attentions else attn_outputs
         attn_output = pad_input(
-            attn_output, indices_q, bsz, max_seqlen_q
+            attn_output, indices_q, bsz, q_len
        ).reshape(bsz, q_len, h_size)
         attn_weights = attn_outputs[2] if output_attentions else None
 