import contextlib
import dataclasses
from collections import defaultdict
from typing import DefaultDict, Dict

import torch

from pipeline import are_two_tensors_similar

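
# Per-session cache state: named tensor buffers reused across forward passes,
# plus per-name counters for generating unique incremental buffer names.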
@dataclasses.dataclass
class CacheContext:
    buffers: Dict[str, torch.Tensor] = dataclasses.field(default_factory=dict)
    incremental_name_counters: DefaultDict[str, int] = dataclasses.field(default_factory=lambda: defaultdict(int))

    def get_incremental_name(self, name=None):
        if name is None:
            name = "default"
        idx = self.incremental_name_counters[name]
        self.incremental_name_counters[name] += 1
        return f"{name}_{idx}"

    def reset_incremental_names(self):
        self.incremental_name_counters.clear()

    @torch.compiler.disable
    def get_buffer(self, name):
        return self.buffers.get(name)

    @torch.compiler.disable
    def set_buffer(self, name, buffer):
        self.buffers[name] = buffer

    def clear_buffers(self):
        self.buffers.clear()

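
# Module-level buffer accessors that operate on the currently active
# CacheContext; they assert that a context has been installed first.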
@torch.compiler.disable
def get_buffer(name):
    cache_context = get_current_cache_context()
    assert cache_context is not None, "cache_context must be set before use"
    return cache_context.get_buffer(name)


@torch.compiler.disable
def set_buffer(name, buffer):
    cache_context = get_current_cache_context()
    assert cache_context is not None, "cache_context must be set before use"
    cache_context.set_buffer(name, buffer)

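
# The active CacheContext, installed via set_current_cache_context() or the
# cache_context() context manager below.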
_current_cache_context = None


def create_cache_context():
    return CacheContext()


def get_current_cache_context():
    return _current_cache_context


def set_current_cache_context(cache_context=None):
    global _current_cache_context
    _current_cache_context = cache_context

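
# Temporarily install a CacheContext, restoring the previous one on exit so
# usages can nest.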
@contextlib.contextmanager
def cache_context(cache_context):
    global _current_cache_context
    old_cache_context = _current_cache_context
    _current_cache_context = cache_context
    try:
        yield
    finally:
        _current_cache_context = old_cache_context

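
# Older similarity check (note the `_old` suffix); the module now imports
# are_two_tensors_similar from pipeline. Compares the relative mean absolute
# difference |t1 - t2| / |t1| against `threshold`; `parallelized` is accepted
# only for signature compatibility and is unused here.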
@torch.compiler.disable
def are_two_tensors_similar_old(t1, t2, *, threshold, parallelized=False):
    mean_diff = (t1 - t2).abs().mean()
    mean_t1 = t1.abs().mean()
    diff = mean_diff / mean_t1
    return diff.item() < threshold

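
# Reapply the residuals cached by the last fully computed pass on top of the
# first block's current outputs, approximating the remaining blocks' work.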
@torch.compiler.disable
def apply_prev_hidden_states_residual(hidden_states, encoder_hidden_states):
    hidden_states_residual = get_buffer("hidden_states_residual")
    assert hidden_states_residual is not None, "hidden_states_residual must be set before use"
    hidden_states = hidden_states_residual + hidden_states

    encoder_hidden_states_residual = get_buffer("encoder_hidden_states_residual")
    assert encoder_hidden_states_residual is not None, "encoder_hidden_states_residual must be set before use"
    encoder_hidden_states = encoder_hidden_states_residual + encoder_hidden_states

    hidden_states = hidden_states.contiguous()
    encoder_hidden_states = encoder_hidden_states.contiguous()

    return hidden_states, encoder_hidden_states

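
# The cache is considered valid when the first block's residual is similar
# enough to the one stored on the last fully computed pass; that cheap
# comparison stands in for re-running the remaining blocks.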
@torch.compiler.disable
def get_can_use_cache(first_hidden_states_residual, threshold, parallelized=False):
    prev_first_hidden_states_residual = get_buffer("first_hidden_states_residual")
    can_use_cache = prev_first_hidden_states_residual is not None and are_two_tensors_similar(
        prev_first_hidden_states_residual,
        first_hidden_states_residual,
        threshold=threshold,
        parallelized=parallelized,
    )
    return can_use_cache

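
# Drop-in replacement for a transformer's block stack. Each step runs only the
# first block, then either reuses cached residuals (when the first block's
# residual is similar to the cached one) or runs the remaining blocks and
# refreshes the cache.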
class CachedTransformerBlocks(torch.nn.Module):
    def __init__(
        self,
        transformer_blocks,
        single_transformer_blocks=None,
        *,
        transformer=None,
        residual_diff_threshold,
        return_hidden_states_first=True,
    ):
        super().__init__()
        self.transformer = transformer
        self.transformer_blocks = transformer_blocks
        self.single_transformer_blocks = single_transformer_blocks
        self.residual_diff_threshold = residual_diff_threshold
        self.return_hidden_states_first = return_hidden_states_first

    def forward(self, hidden_states, encoder_hidden_states, *args, **kwargs):
        # A non-positive threshold disables caching entirely: run every block.
        if self.residual_diff_threshold <= 0.0:
            for block in self.transformer_blocks:
                hidden_states, encoder_hidden_states = block(hidden_states, encoder_hidden_states, *args, **kwargs)
                if not self.return_hidden_states_first:
                    hidden_states, encoder_hidden_states = encoder_hidden_states, hidden_states
            if self.single_transformer_blocks is not None:
                hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
                for block in self.single_transformer_blocks:
                    hidden_states = block(hidden_states, *args, **kwargs)
                hidden_states = hidden_states[:, encoder_hidden_states.shape[1] :]
            return (
                (hidden_states, encoder_hidden_states)
                if self.return_hidden_states_first
                else (encoder_hidden_states, hidden_states)
            )

        # Always run the first block; its residual is the cheap probe that
        # decides whether the cached residuals are still valid.
        original_hidden_states = hidden_states
        first_transformer_block = self.transformer_blocks[0]
        hidden_states, encoder_hidden_states = first_transformer_block(
            hidden_states, encoder_hidden_states, *args, **kwargs
        )
        if not self.return_hidden_states_first:
            hidden_states, encoder_hidden_states = encoder_hidden_states, hidden_states
        first_hidden_states_residual = hidden_states - original_hidden_states
        del original_hidden_states

        can_use_cache = get_can_use_cache(
            first_hidden_states_residual,
            threshold=self.residual_diff_threshold,
            parallelized=self.transformer is not None and getattr(self.transformer, "_is_parallelized", False),
        )

        # The cache decision depends on runtime tensor values, so graph breaks
        # keep this branch out of the Dynamo-compiled regions on either side.
        torch._dynamo.graph_break()
        if can_use_cache:
            del first_hidden_states_residual
            hidden_states, encoder_hidden_states = apply_prev_hidden_states_residual(
                hidden_states, encoder_hidden_states
            )
        else:
            set_buffer("first_hidden_states_residual", first_hidden_states_residual)
            del first_hidden_states_residual
            (
                hidden_states,
                encoder_hidden_states,
                hidden_states_residual,
                encoder_hidden_states_residual,
            ) = self.call_remaining_transformer_blocks(hidden_states, encoder_hidden_states, *args, **kwargs)
            set_buffer("hidden_states_residual", hidden_states_residual)
            set_buffer("encoder_hidden_states_residual", encoder_hidden_states_residual)
        torch._dynamo.graph_break()

        return (
            (hidden_states, encoder_hidden_states)
            if self.return_hidden_states_first
            else (encoder_hidden_states, hidden_states)
        )

    def call_remaining_transformer_blocks(self, hidden_states, encoder_hidden_states, *args, **kwargs):
        original_hidden_states = hidden_states
        original_encoder_hidden_states = encoder_hidden_states
        for block in self.transformer_blocks[1:]:
            hidden_states, encoder_hidden_states = block(hidden_states, encoder_hidden_states, *args, **kwargs)
            if not self.return_hidden_states_first:
                hidden_states, encoder_hidden_states = encoder_hidden_states, hidden_states
        if self.single_transformer_blocks is not None:
            hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
            for block in self.single_transformer_blocks:
                hidden_states = block(hidden_states, *args, **kwargs)
            encoder_hidden_states, hidden_states = hidden_states.split(
                [encoder_hidden_states.shape[1], hidden_states.shape[1] - encoder_hidden_states.shape[1]], dim=1
            )

        # Round-trip through a flat view to ensure contiguous tensors in the
        # original shapes before the residuals are computed and cached.
        hidden_states = hidden_states.reshape(-1).contiguous().reshape(original_hidden_states.shape)
        encoder_hidden_states = (
            encoder_hidden_states.reshape(-1).contiguous().reshape(original_encoder_hidden_states.shape)
        )

        hidden_states_residual = hidden_states - original_hidden_states
        encoder_hidden_states_residual = encoder_hidden_states - original_encoder_hidden_states

        hidden_states_residual = hidden_states_residual.reshape(-1).contiguous().reshape(original_hidden_states.shape)
        encoder_hidden_states_residual = (
            encoder_hidden_states_residual.reshape(-1).contiguous().reshape(original_encoder_hidden_states.shape)
        )

        return hidden_states, encoder_hidden_states, hidden_states_residual, encoder_hidden_states_residual
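

# Example usage (a minimal sketch; `transformer` is assumed to be a
# diffusers-style model exposing `transformer_blocks` and, optionally,
# `single_transformer_blocks` module lists):
#
#     cached_blocks = CachedTransformerBlocks(
#         transformer.transformer_blocks,
#         transformer.single_transformer_blocks,
#         transformer=transformer,
#         residual_diff_threshold=0.1,
#     )
#     with cache_context(create_cache_context()):
#         hidden_states, encoder_hidden_states = cached_blocks(
#             hidden_states, encoder_hidden_states
#         )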