File size: 1,414 Bytes
685ecb2 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 |
from typing import List, Tuple
import torch
class KVCache:
    """Per-layer key/value cache for incremental (token-by-token) attention.

    One key tensor and one value tensor are kept per layer, in the order the
    layers are first updated. Each call to :meth:`update` appends new states
    along dim -2, which is the sequence axis under the layout
    ``[Batch_Size, Num_Heads_KV, Seq_Len, Head_Dim]`` this class uses.
    """

    def __init__(self) -> None:
        # Index i holds the cached keys/values for layer i.
        self.key_cache: List[torch.Tensor] = []
        self.value_cache: List[torch.Tensor] = []

    def num_items(self) -> int:
        """Return how many positions are cached (size of dim -2 of layer 0).

        Returns 0 when no layer has been cached yet.
        """
        if not self.key_cache:
            return 0
        # The shape of the key_cache is [Batch_Size, Num_Heads_KV, Seq_Len, Head_Dim]
        return self.key_cache[0].shape[-2]

    def update(
        self,
        key_states: torch.Tensor,
        value_states: torch.Tensor,
        layer_idx: int,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Append new key/value states for one layer and return that layer's full cache.

        Args:
            key_states: New keys; stored as-is for a new layer, otherwise
                concatenated onto the layer's cached keys along dim -2.
            value_states: New values; handled the same way as ``key_states``.
            layer_idx: Layer to update. Must be an already-cached layer or
                exactly the next one (``len(self.key_cache)``).

        Returns:
            ``(keys, values)`` for ``layer_idx``, including the states just added.

        Raises:
            IndexError: If ``layer_idx`` is negative or would skip a layer.
                Raised before the cache is modified.
        """
        num_layers = len(self.key_cache)
        if layer_idx < 0 or layer_idx > num_layers:
            # list.append always writes at position num_layers, so accepting a
            # skipped or negative index would file the tensors under the wrong
            # layer. Reject it before touching any state.
            raise IndexError(
                f"layer_idx {layer_idx} out of range: cache holds {num_layers} "
                f"layer(s); expected 0 <= layer_idx <= {num_layers}"
            )

        if layer_idx == num_layers:
            # If we never added anything to the KV-Cache of this layer, let's create it.
            self.key_cache.append(key_states)
            self.value_cache.append(value_states)
        else:
            # ... otherwise we concatenate the new keys with the existing ones.
            # each tensor has shape: [Batch_Size, Num_Heads_KV, Seq_Len, Head_Dim]
            self.key_cache[layer_idx] = torch.cat(
                [self.key_cache[layer_idx], key_states], dim=-2
            )
            self.value_cache[layer_idx] = torch.cat(
                [self.value_cache[layer_idx], value_states], dim=-2
            )

        # ... and then we return all the existing keys + the new ones.
        return self.key_cache[layer_idx], self.value_cache[layer_idx]