bird-of-paradise committed
Commit 2d7348d · 1 Parent(s): 098730b

Update class names to MultiHeadLatentAttention

Files changed (3)
  1. src/__init__.py +2 -2
  2. src/mla.py +1 -1
  3. src/tests/test_mla.py +2 -2
src/__init__.py CHANGED
@@ -5,7 +5,7 @@ Copyright (c) 2025
 Implementation of the Multi-Latent Attention mechanism from the DeepSeek-V2 paper.
 """
 
-from .mla import MultiLatentAttention, precompute_freqs_cis, reshape_for_broadcast, apply_rotary_emb
+from .mla import MultiHeadLatentAttention, precompute_freqs_cis, reshape_for_broadcast, apply_rotary_emb
 
 __version__ = "0.1.0"
-__all__ = ["MultiLatentAttention", "precompute_freqs_cis", "reshape_for_broadcast","apply_rotary_emb"]
+__all__ = ["MultiHeadLatentAttention", "precompute_freqs_cis", "reshape_for_broadcast","apply_rotary_emb"]
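
After this change, the package's public export is MultiHeadLatentAttention. A minimal smoke check of the new __all__ surface (a sketch, assuming the repository root is on sys.path so that src imports as a package; this invocation is not part of the commit):

    # Verify the renamed export and its helpers resolve (hypothetical check).
    from src import (
        MultiHeadLatentAttention,
        precompute_freqs_cis,
        reshape_for_broadcast,
        apply_rotary_emb,
    )

    # The old name is gone, so this would now raise ImportError:
    # from src import MultiLatentAttention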
src/mla.py CHANGED
@@ -58,7 +58,7 @@ def apply_rotary_emb(
 
 
 
-class MultiLatentAttention(nn.Module):
+class MultiHeadLatentAttention(nn.Module):
     """
     Multi-Head Latent Attention(MLA) Module As in DeepSeek_V2 pape
     Key innovation from standard MHA:
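
Renaming a public class is a breaking change for any downstream code that still imports MultiLatentAttention. If backward compatibility mattered, a deprecation alias could sit next to the new class in src/mla.py; this is only a sketch, not something the commit adds:

    import warnings

    # Hypothetical shim (not in this commit): keep the old name importable
    # while pointing callers at MultiHeadLatentAttention.
    class MultiLatentAttention(MultiHeadLatentAttention):
        def __init__(self, *args, **kwargs):
            warnings.warn(
                "MultiLatentAttention is deprecated; use MultiHeadLatentAttention.",
                DeprecationWarning,
                stacklevel=2,
            )
            super().__init__(*args, **kwargs)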
src/tests/test_mla.py CHANGED
@@ -1,6 +1,6 @@
 import unittest
 import torch
-from ..mla import MultiLatentAttention # Using relative import
+from ..mla import MultiHeadLatentAttention # Using relative import
 
 class TestMultiLatentAttention(unittest.TestCase):
     def setUp(self):
@@ -15,7 +15,7 @@ class TestMultiLatentAttention(unittest.TestCase):
         self.seq_len = 10
 
         # Initialize MLA
-        self.mla = MultiLatentAttention(
+        self.mla = MultiHeadLatentAttention(
             d_model=self.d_model,
             num_head=self.num_head,
             d_embed=self.d_embed,
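
Because the test relies on a relative import (from ..mla import ...), it must be run as a module from the repository root rather than as a script. A typical invocation, assuming src/ and src/tests/ both contain an __init__.py (the diff above confirms one for src/):

    python -m unittest src.tests.test_mla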