import unittest

import torch

from ..mla import MultiHeadLatentAttention

class TestMultiLatentAttention(unittest.TestCase):

    def setUp(self):
        # Model hyperparameters for the module under test
        self.d_model = 512
        self.num_head = 8
        self.d_embed = 512
        self.d_c = 64        # compressed KV (latent) dimension; the cache's last dim is checked against this below
        self.d_c1 = 64
        self.d_rotate = 32

        # Input dimensions
        self.batch_size = 2
        self.seq_len = 10

        self.mla = MultiHeadLatentAttention(
            d_model=self.d_model,
            num_head=self.num_head,
            d_embed=self.d_embed,
            d_c=self.d_c,
            d_c1=self.d_c1,
            d_rotate=self.d_rotate
        )

    def test_basic_forward(self):
        """Test basic forward pass without caching"""
        x = torch.randn(self.batch_size, self.seq_len, self.d_model)
        output = self.mla(x)

        self.assertEqual(
            output.shape,
            (self.batch_size, self.seq_len, self.d_model),
            "Output shape mismatch"
        )

    def test_cross_attention(self):
        """Test cross-attention functionality"""
        query = torch.randn(self.batch_size, self.seq_len, self.d_model)
        kv = torch.randn(self.batch_size, self.seq_len * 2, self.d_model)

        output = self.mla(query, key_value_states=kv)
        self.assertEqual(
            output.shape,
            (self.batch_size, self.seq_len, self.d_model),
            "Cross-attention output shape mismatch"
        )

    def test_cache_initialization(self):
        """Test if cache is properly initialized"""
        x = torch.randn(self.batch_size, self.seq_len, self.d_model)
        _ = self.mla(x, use_cache=True, start_pos=0)

        self.assertIsNotNone(self.mla.cache_kv)
        self.assertEqual(
            self.mla.cache_kv.shape[-1],
            self.d_c,
            "Cache compression dimension mismatch"
        )

    def test_sequential_caching(self):
        """Test sequential forward passes with caching"""
        prompt_len = 5
        prompt = torch.randn(self.batch_size, prompt_len, self.d_model)

        # Process the prompt and snapshot its cached entries
        output1 = self.mla(prompt, use_cache=True, start_pos=0)
        cached_kv_1 = self.mla.cache_kv[:, :prompt_len].clone()

        # Feed a single new token starting right after the prompt
        new_token = torch.randn(self.batch_size, 1, self.d_model)
        output2 = self.mla(new_token, use_cache=True, start_pos=prompt_len)

        # The prompt's cache entries must be unchanged by the second pass
        self.assertTrue(
            torch.allclose(
                self.mla.cache_kv[:, :prompt_len],
                cached_kv_1,
                rtol=1e-5
            ),
            "Cache was modified for previously processed tokens"
        )

        # The slot for the new token must no longer be all zeros
        self.assertFalse(
            torch.allclose(
                self.mla.cache_kv[:, prompt_len:prompt_len + 1],
                torch.zeros_like(self.mla.cache_kv[:, prompt_len:prompt_len + 1]),
                rtol=1e-5
            ),
            "New token was not added to cache"
        )

    def test_attention_mask_with_cache(self):
        """Test attention masking with cached KV"""
        seq_len = 5
        x = torch.randn(self.batch_size, seq_len, self.d_model)

        # Additive causal mask: -inf above the diagonal, 0 elsewhere
        mask = torch.triu(
            torch.ones(seq_len, seq_len) * float('-inf'),
            diagonal=1
        ).unsqueeze(0)

        output1 = self.mla(x, use_cache=True, start_pos=0, att_mask=mask)

        # Decode one more token with the causal mask extended to cover the cached prefix
        new_token = torch.randn(self.batch_size, 1, self.d_model)
        extended_mask = torch.triu(
            torch.ones(seq_len + 1, seq_len + 1) * float('-inf'),
            diagonal=1
        ).unsqueeze(0)

        output2 = self.mla(
            new_token,
            use_cache=True,
            start_pos=seq_len,
            att_mask=extended_mask
        )

        self.assertEqual(
            output2.shape,
            (self.batch_size, 1, self.d_model),
            "Output shape incorrect for cached attention with mask"
        )

def run_tests():
    suite = unittest.TestLoader().loadTestsFromTestCase(TestMultiLatentAttention)
    runner = unittest.TextTestRunner(verbosity=2)
    runner.run(suite)


if __name__ == "__main__":
    run_tests()