|
from collections import namedtuple |
|
from dataclasses import dataclass |
|
import torch |
|
from typing import Tuple, Optional |
|
|
|
@dataclass()
class LongLlamaMemConfig:
    """
    Settings that control the memory caches of the LongLlama model.

    Args:
        positionals (`bool`, defaults to `True`)
            Whether the memory layer uses positional embeddings.
        cache_dtype (`torch.dtype`, defaults to `torch.bfloat16`)
            The dtype in which cached keys and values are stored.
        attention_grouping (`Tuple[int, int]`, *optional*, defaults to `None`)
            Lets memory layers perform attention sequentially in groups,
            trading speed for lower memory use. A value of `(4, 128)` means
            at most 4 heads and at most 128 queries from each head are
            processed at once, i.e. no more than 4 * 128 = 512 queries at once.
    """

    # Toggle for positional embeddings inside memory layers.
    positionals: bool = True
    # Storage dtype for cached keys/values.
    cache_dtype: torch.dtype = torch.bfloat16
    # (max_heads, max_queries_per_head) per attention step.
    # NOTE(review): None presumably means "no grouping" -- confirm in the attention code.
    attention_grouping: Optional[Tuple[int, int]] = None
|
|
|
|
|
@dataclass()
class LongLlamaMemCache:
    """
    Container holding the contents of LongLlama's memory cache.

    Args:
        keys (`torch.FloatTensor` of shape `(batch_size, num_heads, mem_length, embed_size_per_head)`)
            Cached keys.
        values (`torch.FloatTensor` of shape `(batch_size, num_heads, mem_length, embed_size_per_head)`)
            Cached values.
        masks (`torch.FloatTensor` of shape `(batch_size, 1, mem_length, 1)`)
            Used to mask out parts of the memory.
    """

    # Expected layout: (batch_size, num_heads, mem_length, embed_size_per_head).
    keys: torch.FloatTensor
    # Same layout as `keys`.
    values: torch.FloatTensor
    # Expected layout: (batch_size, 1, mem_length, 1).
    masks: torch.FloatTensor
|
|
|
|
|
def mem_apply_update(
    prev_mem_cache: LongLlamaMemCache,
    new_mem_content: LongLlamaMemCache,
    mem_config: LongLlamaMemConfig,
) -> LongLlamaMemCache:
    """
    Append `new_mem_content` to `prev_mem_cache` along the memory-length axis.

    Args:
        prev_mem_cache (`LongLlamaMemCache`)
            The existing memory cache; it is not modified.
        new_mem_content (`LongLlamaMemCache`)
            Entries to append. Its keys, values and masks must agree on
            `mem_length` (dimension -2).
        mem_config (`LongLlamaMemConfig`)
            Currently unused by this function; kept for interface compatibility.

    Returns:
        A new `LongLlamaMemCache` whose keys, values and masks are the
        concatenation (`torch.concat`, dim=-2) of the previous and new tensors.

    Raises:
        ValueError: if any of the six tensors is not 4-dimensional, or if the
            keys, values and masks of `new_mem_content` disagree on `mem_length`.
    """
    pairs = (
        (prev_mem_cache.keys, new_mem_content.keys),
        (prev_mem_cache.values, new_mem_content.values),
        (prev_mem_cache.masks, new_mem_content.masks),
    )

    # Check rank before indexing `shape[-2]` below; otherwise a low-rank tensor
    # surfaces as an IndexError instead of this descriptive ValueError.
    for prev, new in pairs:
        if prev.ndim != 4 or new.ndim != 4:
            raise ValueError(f"Memory cache content should be consistent in shape got {prev.shape} {new.shape}")

    insert_size = new_mem_content.keys.shape[-2]
    if new_mem_content.values.shape[-2] != insert_size or new_mem_content.masks.shape[-2] != insert_size:
        raise ValueError("Inconsistent mem_length in new_mem_content")

    keys, values, masks = (torch.concat([prev, new], dim=-2) for prev, new in pairs)
    return LongLlamaMemCache(keys=keys, values=values, masks=masks)
|
|