import warnings
import torch
import torch.nn as nn
from torch.nn import functional as F
class SelfAttention(nn.Module):
def __init__(self, config):
"""
Initializes the SelfAttention module.
Args:
config: An object containing the configuration parameters for the SelfAttention module.
"""
super().__init__()
self._validate_config(config)
self._initialize_parameters(config)
def empty_kv_cache(self, batch_size: int, kv_cache_maxlen: int, dtype: torch.dtype):
"""
        Allocates (or re-allocates) a zeroed key-value cache and resets the write pointer.
Args:
batch_size: The batch size.
kv_cache_maxlen: The maximum length of the key-value cache.
dtype: The data type of the cache.
Raises:
Exception: If trying to empty the KV cache when it is disabled.
"""
        if not self.kv_cache_enabled:
            raise Exception("Trying to empty KV cache when it is disabled")
# register so that the cache moves devices along with the module
# TODO: get rid of re-allocation.
self.register_buffer(
"kv_cache",
torch.zeros(
2,
batch_size,
kv_cache_maxlen,
self.n_head,
self.n_embd // self.n_head,
dtype=dtype,
device=self.c_attn.weight.device,
),
persistent=False,
)
self.kv_cache_first_empty_index = 0
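    # Typical usage (a sketch; how kv_cache_enabled gets toggled is up to the surrounding
    # model and is assumed here):
    #   attn.kv_cache_enabled = True
    #   attn.empty_kv_cache(batch_size=1, kv_cache_maxlen=128, dtype=torch.float32)
    # The cache must be enabled before calling empty_kv_cache; calling it again re-allocates
    # the buffer and resets kv_cache_first_empty_index to 0 for a fresh generation.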
def _initialize_parameters(self, config):
"""
Initializes the parameters of the SelfAttention module.
Args:
config: An object containing the configuration parameters for the SelfAttention module.
"""
# key, query, value projections for all heads, but in a batch
self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd, bias=config.bias)
# output projection
self.c_proj = nn.Linear(config.n_embd, config.n_embd, bias=config.bias)
# regularization
self.resid_dropout = nn.Dropout(config.dropout)
self.n_head = config.n_head
self.n_embd = config.n_embd
self.dropout = config.dropout
self.causal = config.causal
self.attn_kernel_type = config.attn_kernel_type
self.attn_dropout = nn.Dropout(config.dropout)
self.kv_cache_enabled = False
def _validate_config(self, config):
"""
Validates the configuration parameters.
Args:
config: An object containing the configuration parameters for the SelfAttention module.
Raises:
AssertionError: If the embedding dimension is not divisible by the number of heads.
"""
assert config.n_embd % config.n_head == 0, "Embedding dimension must be divisible by number of heads"
def _update_kv_cache(self, q, k, v):
"""
Updates the key-value cache.
Args:
q: The query tensor.
k: The key tensor.
v: The value tensor.
Returns:
The updated key and value tensors.
Raises:
AssertionError: If the dimensions of the query, key, and value tensors are not compatible.
"""
q_time, k_time, v_time = q.shape[1], k.shape[1], v.shape[1]
if self.kv_cache_first_empty_index == 0:
assert q_time == k_time and q_time == v_time
else:
assert (
q_time == 1
), f"Only one query at a time is supported, but got q_time={q_time} for kv_cache_first_empty_index={self.kv_cache_first_empty_index}"
self.kv_cache[0, :, self.kv_cache_first_empty_index : self.kv_cache_first_empty_index + q_time] = k
self.kv_cache[1, :, self.kv_cache_first_empty_index : self.kv_cache_first_empty_index + q_time] = v
self.kv_cache_first_empty_index += q_time
k = self.kv_cache[0, :, : self.kv_cache_first_empty_index]
v = self.kv_cache[1, :, : self.kv_cache_first_empty_index]
return k, v
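    # Cache layout: kv_cache[0] holds keys and kv_cache[1] holds values, each of shape
    # (batch, kv_cache_maxlen, n_head, head_size); kv_cache_first_empty_index is the write
    # pointer along the time axis. For example, a prefill with T=8 fills slots 0..7 and
    # moves the pointer to 8; each subsequent single-token step fills slot 8, then 9, etc.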
def _torch_attn(self, c_x: torch.Tensor) -> torch.Tensor:
"""
Performs attention using the torch.nn.functional.scaled_dot_product_attention function.
Args:
c_x: The input tensor.
Returns:
The output tensor.
"""
q, k, v = c_x.split(1, dim=2) # q, k, v of shape (B, T, 1, nh, hs)
q = q.squeeze(2) # (B, T, nh, hs)
k = k.squeeze(2) # (B, T, nh, hs)
v = v.squeeze(2) # (B, T, nh, hs)
        # With kv-caching and causal attention, the "prefill" pass needs a causal mask,
        # while the subsequent single-time-step passes need no mask at all.
        # Compute this before updating the cache so kv_cache_first_empty_index still
        # reflects the pre-update state.
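        # For example, with the cache enabled: a prefill over T=8 tokens runs with
        # is_causal=True; each later one-token call sees kv_cache_first_empty_index > 0
        # and runs with is_causal=False, since the lone query may attend to every cached position.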
is_causal_attn_mask = self.causal and (not self.kv_cache_enabled or self.kv_cache_first_empty_index == 0)
if self.kv_cache_enabled:
k, v = self._update_kv_cache(q, k, v)
q = q.transpose(1, 2) # (B, nh, T, hs)
k = k.transpose(1, 2) # (B, nh, T, hs)
v = v.transpose(1, 2) # (B, nh, T, hs)
        y = F.scaled_dot_product_attention(
            q,
            k,
            v,
            attn_mask=None,
            dropout_p=self.dropout if self.training else 0,
            is_causal=is_causal_attn_mask,
        ).transpose(1, 2)  # (B, nh, T, hs) -> (B, T, nh, hs)
return y
def forward(self, x):
"""
Performs the forward pass of the SelfAttention module.
Args:
x: The input tensor.
Returns:
The output tensor.
"""
B, T, C = x.size() # batch size, sequence length, embedding dimensionality (n_embd)
# calculate query, key, values for all heads in batch and move head forward to be the batch dim
c_x = self.c_attn(x).view(B, T, 3, self.n_head, C // self.n_head) # (B, T, 3, nh, hs)
# causal self-attention;
if self.attn_kernel_type == "torch_attn":
y = self._torch_attn(c_x)
else:
raise Exception(f"Unknown attention kernel type: {self.attn_kernel_type}")
y = y.contiguous().view(B, T, C) # re-assemble all head outputs side by side: (B, T, nh, hs) -> (B, T, hs * nh)
# output projection
y = self.resid_dropout(self.c_proj(y))
return y
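

# A minimal, self-contained usage sketch, assuming the config is a plain namespace with
# the fields this module reads (n_embd, n_head, bias, dropout, causal, attn_kernel_type).
# The SimpleNamespace config and the tensor sizes below are illustrative, not part of the
# original model code.
if __name__ == "__main__":
    from types import SimpleNamespace

    config = SimpleNamespace(
        n_embd=64,
        n_head=4,
        bias=True,
        dropout=0.0,
        causal=True,
        attn_kernel_type="torch_attn",
    )
    attn = SelfAttention(config).eval()

    # Plain (non-cached) forward pass over a full sequence.
    x = torch.randn(2, 10, config.n_embd)  # (B, T, C)
    with torch.no_grad():
        y = attn(x)
    print(y.shape)  # torch.Size([2, 10, 64])

    # Incremental decoding with the KV cache: prefill once, then feed one token at a time.
    attn.kv_cache_enabled = True
    attn.empty_kv_cache(batch_size=2, kv_cache_maxlen=16, dtype=x.dtype)
    with torch.no_grad():
        _ = attn(x)  # prefill: causal mask over the first 10 positions
        step = torch.randn(2, 1, config.n_embd)  # one new time step
        y_step = attn(step)  # the single query attends to all 11 cached positions
    print(y_step.shape)  # torch.Size([2, 1, 64])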