Fix RuntimeError: pad attn scores back to original query sequence length, instead of unpadded sequence length (i.e. no change).
e6c58da
Birchlabs
commited on