|
from typing import List, Tuple |
|
import torch |
|
|
|
class KVCache:
    """Per-layer cache of attention key/value tensors.

    Each layer's keys and values are stored in ``key_cache[layer_idx]`` and
    ``value_cache[layer_idx]``. New states are appended along dim -2 on every
    call to :meth:`update`.

    NOTE(review): assumes tensors are laid out as
    (batch, num_heads, seq_len, head_dim), so dim -2 is the sequence axis.
    Confirm against the attention module that calls ``update``.
    """

    def __init__(self) -> None:
        # One entry per layer, indexed by layer_idx; filled in layer order.
        self.key_cache: List[torch.Tensor] = []
        self.value_cache: List[torch.Tensor] = []

    def num_items(self) -> int:
        """Return how many positions are cached.

        The count is the size of dim -2 of layer 0's key tensor, i.e. the
        axis ``update`` concatenates along. Returns 0 when nothing has been
        cached yet.
        """
        if not self.key_cache:
            return 0
        return self.key_cache[0].shape[-2]

    def update(
        self,
        key_states: torch.Tensor,
        value_states: torch.Tensor,
        layer_idx: int,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Add new key/value states for ``layer_idx``; return the full cache.

        Args:
            key_states: New keys for this layer. They are concatenated onto
                the existing keys along dim -2.
            value_states: New values for this layer, concatenated likewise.
            layer_idx: Layer whose cache is updated. The first call for a
                layer must use the next unused index, ``len(self.key_cache)``.

        Returns:
            ``(keys, values)``: the full cached tensors for ``layer_idx`` after
            the update.

        Raises:
            IndexError: If ``layer_idx`` skips past the next unused layer
                index. The cache is left unchanged in that case.
        """
        num_layers = len(self.key_cache)
        if layer_idx == num_layers:
            # First time this layer is seen: its states become the cache.
            self.key_cache.append(key_states)
            self.value_cache.append(value_states)
        elif layer_idx > num_layers:
            # Appending here would store the tensors under the wrong index.
            # Fail before mutating so the cache stays consistent.
            raise IndexError(
                f"layer_idx {layer_idx} skips ahead; expected at most "
                f"{num_layers}"
            )
        else:
            # Layer already cached: extend it along dim -2.
            self.key_cache[layer_idx] = torch.cat(
                [self.key_cache[layer_idx], key_states], dim=-2
            )
            self.value_cache[layer_idx] = torch.cat(
                [self.value_cache[layer_idx], value_states], dim=-2
            )

        return self.key_cache[layer_idx], self.value_cache[layer_idx]