Ngaima Sandiman
Initial commit.
685ecb2
raw
history blame
1.41 kB
from typing import List, Tuple
import torch
class KVCache:
    """Per-layer store of attention key/value tensors.

    Each layer owns one key tensor and one value tensor. New states are
    concatenated onto the stored tensors along the second-to-last axis.
    According to the original author's notes the tensors are laid out as
    [Batch_Size, Num_Heads_KV, Seq_Len, Head_Dim]; this class itself only
    relies on the second-to-last axis being the sequence axis.
    """

    def __init__(self) -> None:
        # One entry per layer; position i in each list belongs to layer i.
        self.key_cache: List[torch.Tensor] = []
        self.value_cache: List[torch.Tensor] = []

    def num_items(self) -> int:
        """Return the number of cached sequence positions.

        The count is read from the first layer's key tensor only; an empty
        cache reports 0.
        """
        if not self.key_cache:
            return 0
        # The shape of the key_cache is [Batch_Size, Num_Heads_KV, Seq_Len, Head_Dim]
        return self.key_cache[0].shape[-2]

    def update(
        self,
        key_states: torch.Tensor,
        value_states: torch.Tensor,
        layer_idx: int,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Add new key/value states for a layer and return its full cache.

        Args:
            key_states: New keys to store for the layer.
            value_states: New values to store for the layer.
            layer_idx: Index of the layer being updated. It must refer to an
                existing layer or to the next layer to be created.

        Returns:
            The (keys, values) pair now cached for ``layer_idx``, i.e. the
            previously stored states followed by the new ones.

        Raises:
            IndexError: If ``layer_idx`` would skip over layers that have not
                been created yet. The cache is left unmodified in that case.
        """
        num_layers = len(self.key_cache)

        # Reject a skipped layer *before* touching the lists. Appending here
        # would file the tensors under the wrong layer and only fail later,
        # on the final lookup, leaving the cache corrupted.
        if layer_idx > num_layers:
            raise IndexError(
                f"layer_idx {layer_idx} skips uninitialised layers; "
                f"the cache currently holds {num_layers} layer(s)"
            )

        if layer_idx == num_layers:
            # First time we see this layer: the incoming tensors become its
            # cache as-is (they are stored by reference, not copied).
            self.key_cache.append(key_states)
            self.value_cache.append(value_states)
        else:
            # Layer already exists: extend it along the sequence axis.
            # each tensor has shape: [Batch_Size, Num_Heads_KV, Seq_Len, Head_Dim]
            self.key_cache[layer_idx] = torch.cat(
                [self.key_cache[layer_idx], key_states], dim=-2
            )
            self.value_cache[layer_idx] = torch.cat(
                [self.value_cache[layer_idx], value_states], dim=-2
            )

        # Hand back everything cached for this layer, old states plus new.
        return self.key_cache[layer_idx], self.value_cache[layer_idx]