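"""Unit tests for the MultiHeadLatentAttention (MLA) module.

Covers the basic forward pass, cross-attention, KV-cache initialization,
sequential decoding against the cache, and causal masking with cached
keys/values.
"""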
import unittest
import torch
from ..mla import MultiHeadLatentAttention  # Using relative import

class TestMultiLatentAttention(unittest.TestCase):
    def setUp(self):
        # Common dimensions for testing
        self.d_model = 512
        self.num_head = 8
        self.d_embed = 512
        self.d_c = 64  # Compression dim for K/V
        self.d_c1 = 64  # Compression dim for Q
        self.d_rotate = 32  # For future RoPE implementation
        self.batch_size = 2
        self.seq_len = 10
        
        # Initialize MLA
        self.mla = MultiHeadLatentAttention(
            d_model=self.d_model,
            num_head=self.num_head,
            d_embed=self.d_embed,
            d_c=self.d_c,
            d_c1=self.d_c1,
            d_rotate=self.d_rotate
        )
        
    def test_basic_forward(self):
        """Test basic forward pass without caching"""
        x = torch.randn(self.batch_size, self.seq_len, self.d_model)
        output = self.mla(x)
        
        # Check output shape
        self.assertEqual(
            output.shape, 
            (self.batch_size, self.seq_len, self.d_model),
            "Output shape mismatch"
        )
        
    def test_cross_attention(self):
        """Test cross-attention functionality"""
        query = torch.randn(self.batch_size, self.seq_len, self.d_model)
        kv = torch.randn(self.batch_size, self.seq_len * 2, self.d_model)  # Different seq_len
        
        output = self.mla(query, key_value_states=kv)
        self.assertEqual(
            output.shape, 
            (self.batch_size, self.seq_len, self.d_model),
            "Cross-attention output shape mismatch"
        )
        
    def test_cache_initialization(self):
        """Test if cache is properly initialized"""
        x = torch.randn(self.batch_size, self.seq_len, self.d_model)
        _ = self.mla(x, use_cache=True, start_pos=0)
        
        self.assertIsNotNone(self.mla.cache_kv)
        self.assertEqual(
            self.mla.cache_kv.shape[-1],
            self.d_c,
            "Cache compression dimension mismatch"
        )
        
    def test_sequential_caching(self):
        """Test sequential forward passes with caching"""
        # Initial sequence
        prompt_len = 5
        prompt = torch.randn(self.batch_size, prompt_len, self.d_model)
        
        # First forward pass with prompt
        output1 = self.mla(prompt, use_cache=True, start_pos=0)
        cached_kv_1 = self.mla.cache_kv[:, :prompt_len].clone()
        
        # Second forward pass with one new token
        new_token = torch.randn(self.batch_size, 1, self.d_model)
        output2 = self.mla(new_token, use_cache=True, start_pos=prompt_len)
        
        # Verify cache consistency
        # First part of cache should remain unchanged
        self.assertTrue(
            torch.allclose(
                self.mla.cache_kv[:, :prompt_len],
                cached_kv_1,
                rtol=1e-5
            ),
            "Cache was modified for previously processed tokens"
        )
        
        # Verify new token was added to cache
        self.assertFalse(
            torch.allclose(
                self.mla.cache_kv[:, prompt_len:prompt_len+1],
                torch.zeros_like(self.mla.cache_kv[:, prompt_len:prompt_len+1]),
                rtol=1e-5
            ),
            "New token was not added to cache"
        )
        
    def test_attention_mask_with_cache(self):
        """Test attention masking with cached KV"""
        seq_len = 5
        x = torch.randn(self.batch_size, seq_len, self.d_model)
        
        # Create causal mask
        mask = torch.triu(
            torch.ones(seq_len, seq_len) * float('-inf'), 
            diagonal=1
        ).unsqueeze(0)
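        # For illustration, with seq_len = 3 the (unbatched) mask is:
        #   [[0., -inf, -inf],
        #    [0.,   0., -inf],
        #    [0.,   0.,   0.]]
        # so each position can attend only to itself and earlier positions.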
        
        # First forward pass with mask
        output1 = self.mla(x, use_cache=True, start_pos=0, att_mask=mask)
        
        # Second pass with one token
        new_token = torch.randn(self.batch_size, 1, self.d_model)
        extended_mask = torch.triu(
            torch.ones(seq_len + 1, seq_len + 1) * float('-inf'),
            diagonal=1
        ).unsqueeze(0)
        
        output2 = self.mla(
            new_token,
            use_cache=True,
            start_pos=seq_len,
            att_mask=extended_mask
        )
        
        self.assertEqual(
            output2.shape,
            (self.batch_size, 1, self.d_model),
            "Output shape incorrect for cached attention with mask"
        )

def run_tests():
    suite = unittest.TestLoader().loadTestsFromTestCase(TestMultiLatentAttention)
    runner = unittest.TextTestRunner(verbosity=2)
    runner.run(suite)

# Run the tests when this module is executed directly. The __main__ guard keeps
# a plain import (e.g. by a test runner's discovery) from triggering a
# duplicate run.
if __name__ == "__main__":
    run_tests()
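# Note: because of the relative import at the top, this module is expected to be
# run as part of its package, e.g. `python -m <package>.tests.test_mla` (package
# path assumed here), or collected by a test runner from the project root;
# executing the file directly as a standalone script raises an ImportError.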