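"""Unit tests for MultiHeadLatentAttention (MLA).

Covers the basic forward pass, cross-attention, KV-cache initialization,
sequential decoding with the cache, and causal masking with cached keys/values.
"""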
import unittest

import torch

from ..mla import MultiHeadLatentAttention  # Using relative import


class TestMultiLatentAttention(unittest.TestCase):
    def setUp(self):
        # Common dimensions for testing
        self.d_model = 512
        self.num_head = 8
        self.d_embed = 512
        self.d_c = 64       # Compression dim for K/V
        self.d_c1 = 64      # Compression dim for Q
        self.d_rotate = 32  # For future RoPE implementation
        self.batch_size = 2
        self.seq_len = 10

        # Initialize MLA
        self.mla = MultiHeadLatentAttention(
            d_model=self.d_model,
            num_head=self.num_head,
            d_embed=self.d_embed,
            d_c=self.d_c,
            d_c1=self.d_c1,
            d_rotate=self.d_rotate
        )
    def test_basic_forward(self):
        """Test basic forward pass without caching"""
        x = torch.randn(self.batch_size, self.seq_len, self.d_model)
        output = self.mla(x)

        # Check output shape
        self.assertEqual(
            output.shape,
            (self.batch_size, self.seq_len, self.d_model),
            "Output shape mismatch"
        )
    def test_cross_attention(self):
        """Test cross-attention functionality"""
        query = torch.randn(self.batch_size, self.seq_len, self.d_model)
        kv = torch.randn(self.batch_size, self.seq_len * 2, self.d_model)  # Different seq_len
        output = self.mla(query, key_value_states=kv)

        self.assertEqual(
            output.shape,
            (self.batch_size, self.seq_len, self.d_model),
            "Cross-attention output shape mismatch"
        )
    def test_cache_initialization(self):
        """Test if cache is properly initialized"""
        x = torch.randn(self.batch_size, self.seq_len, self.d_model)
        _ = self.mla(x, use_cache=True, start_pos=0)

        self.assertIsNotNone(self.mla.cache_kv)
        self.assertEqual(
            self.mla.cache_kv.shape[-1],
            self.d_c,
            "Cache compression dimension mismatch"
        )
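    # Note: MLA's memory saving comes from caching only the d_c-dimensional
    # compressed latent per token (rather than full per-head keys and values),
    # which is why the test above checks the cache's last dimension against d_c.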
    def test_sequential_caching(self):
        """Test sequential forward passes with caching"""
        # Initial sequence
        prompt_len = 5
        prompt = torch.randn(self.batch_size, prompt_len, self.d_model)

        # First forward pass with prompt
        output1 = self.mla(prompt, use_cache=True, start_pos=0)
        cached_kv_1 = self.mla.cache_kv[:, :prompt_len].clone()

        # Second forward pass with one new token
        new_token = torch.randn(self.batch_size, 1, self.d_model)
        output2 = self.mla(new_token, use_cache=True, start_pos=prompt_len)

        # Verify cache consistency:
        # first part of cache should remain unchanged
        self.assertTrue(
            torch.allclose(
                self.mla.cache_kv[:, :prompt_len],
                cached_kv_1,
                rtol=1e-5
            ),
            "Cache was modified for previously processed tokens"
        )

        # Verify the new token was added to the cache
        self.assertFalse(
            torch.allclose(
                self.mla.cache_kv[:, prompt_len:prompt_len + 1],
                torch.zeros_like(self.mla.cache_kv[:, prompt_len:prompt_len + 1]),
                rtol=1e-5
            ),
            "New token was not added to cache"
        )
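    # Note: the zero-comparison above assumes the module pre-allocates cache_kv
    # (e.g. up to a maximum sequence length) and zero-initializes it, so a
    # non-zero slot at position prompt_len indicates the new token was written.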
    def test_attention_mask_with_cache(self):
        """Test attention masking with cached KV"""
        seq_len = 5
        x = torch.randn(self.batch_size, seq_len, self.d_model)

        # Create causal mask
        mask = torch.triu(
            torch.ones(seq_len, seq_len) * float('-inf'),
            diagonal=1
        ).unsqueeze(0)

        # First forward pass with mask
        output1 = self.mla(x, use_cache=True, start_pos=0, att_mask=mask)

        # Second pass with one token
        new_token = torch.randn(self.batch_size, 1, self.d_model)
        extended_mask = torch.triu(
            torch.ones(seq_len + 1, seq_len + 1) * float('-inf'),
            diagonal=1
        ).unsqueeze(0)
        output2 = self.mla(
            new_token,
            use_cache=True,
            start_pos=seq_len,
            att_mask=extended_mask
        )

        self.assertEqual(
            output2.shape,
            (self.batch_size, 1, self.d_model),
            "Output shape incorrect for cached attention with mask"
        )
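    # Note: extended_mask spans all (seq_len + 1) positions; this assumes the
    # MLA forward selects the mask rows relevant to the new query token when
    # attending over the cached keys/values.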
def run_tests():
    suite = unittest.TestLoader().loadTestsFromTestCase(TestMultiLatentAttention)
    runner = unittest.TextTestRunner(verbosity=2)
    runner.run(suite)


# Run the tests when executed directly (not on import)
if __name__ == "__main__":
    run_tests()
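# Typical invocation (module path is an assumption; adjust to the repo layout):
#   python -m <package>.tests.test_mla
# Because of the relative import at the top, this file must be run as part of
# its package (e.g. via `python -m` or pytest from the repo root), not as a
# standalone script.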