Spaces:
Runtime error
Runtime error
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import unittest | |
import torch | |
from fairseq.modules.multihead_attention import MultiheadAttention | |
class TestMultiheadAttention(unittest.TestCase):
    """Unit tests for helper methods of ``MultiheadAttention``."""

    def test_append_prev_key_padding_mask(self):
        """Check ``_append_prev_key_padding_mask`` across mask combinations.

        Every case is ``(label, key_padding_mask, prev_key_padding_mask,
        expected)``. The second element is the *current* mask and the third
        is the *previous* mask; they are passed to the helper in that order.
        ``expected is None`` means the helper must return ``None``.

        The expected values encode the layout under test: a short previous
        mask fills the leading positions, a short current mask fills the
        trailing positions, and any remaining positions are zero-filled.
        """
        bsz = 1
        src_len = 4
        cases = [
            ("no padding mask", None, None, None),
            (
                "current padding mask only",
                torch.tensor([[1]]).bool(),
                None,
                torch.tensor([[0, 0, 0, 1]]).bool(),
            ),
            (
                "previous padding mask only",
                None,
                torch.tensor([[0, 1, 0]]).bool(),
                torch.tensor([[0, 1, 0, 0]]).bool(),
            ),
            (
                "both padding masks",
                torch.tensor([[1]]).bool(),
                torch.tensor([[0, 1, 0]]).bool(),
                torch.tensor([[0, 1, 0, 1]]).bool(),
            ),
            (
                # The current mask alone already spans all src_len positions.
                "current key_padding_mask already full",
                torch.tensor([[0, 1, 0, 1]]).bool(),
                None,
                torch.tensor([[0, 1, 0, 1]]).bool(),
            ),
            (
                # The previous mask alone already spans all src_len positions.
                "prev_key_padding_mask already full",
                None,
                torch.tensor([[0, 1, 0, 1]]).bool(),
                torch.tensor([[0, 1, 0, 1]]).bool(),
            ),
        ]
        for label, current, previous, expected in cases:
            # subTest reports each failing case separately instead of
            # aborting the whole loop at the first failure.
            with self.subTest(label):
                key_padding_mask = MultiheadAttention._append_prev_key_padding_mask(
                    current,
                    previous,
                    batch_size=bsz,
                    src_len=src_len,
                    static_kv=False,
                )
                if expected is None:
                    self.assertIsNone(key_padding_mask)
                else:
                    self.assertIsNotNone(key_padding_mask)
                    # Check the exact shape first: torch.eq broadcasts, so a
                    # wrongly-shaped result could otherwise compare as equal.
                    self.assertEqual(
                        tuple(key_padding_mask.shape), (bsz, src_len)
                    )
                    self.assertTrue(
                        torch.all(torch.eq(key_padding_mask, expected)),
                        f"Unexpected resultant key padding mask: {key_padding_mask}"
                        f" given current: {current} and previous: {previous}",
                    )
if __name__ == "__main__":
    # Only taken when the file is executed directly rather than imported:
    # hand control to unittest's command-line runner for this module.
    unittest.main(module="__main__")