Update attention.py
Browse files- attention.py +51 -29
attention.py
CHANGED
@@ -5,6 +5,7 @@ from typing import Optional
|
|
5 |
import torch
|
6 |
import torch.nn as nn
|
7 |
from einops import rearrange
|
|
|
8 |
from torch import nn
|
9 |
from .norm import LPLayerNorm
|
10 |
|
@@ -16,25 +17,34 @@ def _reset_is_causal(num_query_tokens: int, num_key_tokens: int, original_is_cau
|
|
16 |
return False
|
17 |
return original_is_causal
|
18 |
|
19 |
-
def scaled_multihead_dot_product_attention(query, key, value, n_heads, softmax_scale=None, attn_bias=None, key_padding_mask=None, is_causal=False, dropout_p=0.0, training=False, needs_weights=False, multiquery=False):
|
20 |
q = rearrange(query, 'b s (h d) -> b h s d', h=n_heads)
|
21 |
-
|
22 |
-
|
23 |
-
|
|
|
|
|
|
|
|
|
|
|
24 |
(b, _, s_q, d) = q.shape
|
25 |
s_k = k.size(-1)
|
26 |
if softmax_scale is None:
|
27 |
softmax_scale = 1 / math.sqrt(d)
|
28 |
attn_weight = q.matmul(k) * softmax_scale
|
29 |
if attn_bias is not None:
|
|
|
|
|
|
|
30 |
if attn_bias.size(-1) != 1 and attn_bias.size(-1) != s_k or (attn_bias.size(-2) != 1 and attn_bias.size(-2) != s_q):
|
31 |
raise RuntimeError(f'attn_bias (shape: {attn_bias.shape}) is expected to broadcast to shape: {attn_weight.shape}.')
|
32 |
attn_weight = attn_weight + attn_bias
|
|
|
33 |
if key_padding_mask is not None:
|
34 |
if attn_bias is not None:
|
35 |
warnings.warn('Propogating key_padding_mask to the attention module ' + 'and applying it within the attention module can cause ' + 'unneccessary computation/memory usage. Consider integrating ' + 'into attn_bias once and passing that to each attention ' + 'module instead.')
|
36 |
attn_weight = attn_weight.masked_fill(~key_padding_mask.view((b, 1, 1, s_k)), min_val)
|
37 |
-
if is_causal:
|
38 |
s = max(s_q, s_k)
|
39 |
causal_mask = attn_weight.new_ones(s, s, dtype=torch.float16)
|
40 |
causal_mask = causal_mask.tril()
|
@@ -45,11 +55,11 @@ def scaled_multihead_dot_product_attention(query, key, value, n_heads, softmax_s
|
|
45 |
attn_weight = torch.softmax(attn_weight, dim=-1)
|
46 |
if dropout_p:
|
47 |
attn_weight = torch.nn.functional.dropout(attn_weight, p=dropout_p, training=training, inplace=True)
|
48 |
-
out = attn_weight.matmul(v)
|
49 |
out = rearrange(out, 'b h s d -> b s (h d)')
|
50 |
if needs_weights:
|
51 |
-
return (out, attn_weight)
|
52 |
-
return (out, None)
|
53 |
|
54 |
def check_valid_inputs(*tensors, valid_dtypes=[torch.float16, torch.bfloat16]):
|
55 |
for tensor in tensors:
|
@@ -58,12 +68,21 @@ def check_valid_inputs(*tensors, valid_dtypes=[torch.float16, torch.bfloat16]):
|
|
58 |
if not tensor.is_cuda:
|
59 |
raise TypeError(f'Inputs must be cuda tensors (tensor.is_cuda={tensor.is_cuda!r}).')
|
60 |
|
61 |
-
def flash_attn_fn(query, key, value, n_heads, softmax_scale=None, attn_bias=None, key_padding_mask=None, is_causal=False, dropout_p=0.0, training=False, needs_weights=False, multiquery=False):
|
62 |
try:
|
63 |
from flash_attn import bert_padding, flash_attn_interface
|
64 |
except:
|
65 |
raise RuntimeError('Please install flash-attn==1.0.3.post0')
|
66 |
check_valid_inputs(query, key, value)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
67 |
if attn_bias is not None:
|
68 |
raise NotImplementedError(f'attn_bias not implemented for flash attn.')
|
69 |
(batch_size, seqlen) = query.shape[:2]
|
@@ -83,14 +102,31 @@ def flash_attn_fn(query, key, value, n_heads, softmax_scale=None, attn_bias=None
|
|
83 |
reset_is_causal = _reset_is_causal(query.size(1), key.size(1), is_causal)
|
84 |
output_unpad = flash_attn_interface.flash_attn_unpadded_func(query_unpad, key_unpad, value_unpad, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, dropout_p, softmax_scale=softmax_scale, causal=reset_is_causal, return_attn_probs=needs_weights)
|
85 |
output = bert_padding.pad_input(rearrange(output_unpad, 'nnz h d -> nnz (h d)'), indices_q, batch_size, seqlen)
|
86 |
-
return (output, None)
|
87 |
|
88 |
-
def triton_flash_attn_fn(query, key, value, n_heads, softmax_scale=None, attn_bias=None, key_padding_mask=None, is_causal=False, dropout_p=0.0, training=False, needs_weights=False, multiquery=False):
|
89 |
try:
|
90 |
from .flash_attn_triton import flash_attn_func
|
91 |
except:
|
92 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
93 |
check_valid_inputs(query, key, value)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
94 |
if dropout_p:
|
95 |
raise NotImplementedError(f'Dropout not implemented for attn_impl: triton.')
|
96 |
if needs_weights:
|
@@ -110,7 +146,7 @@ def triton_flash_attn_fn(query, key, value, n_heads, softmax_scale=None, attn_bi
|
|
110 |
reset_is_causal = _reset_is_causal(query.size(1), key.size(1), is_causal)
|
111 |
attn_output = flash_attn_func(query, key, value, attn_bias, reset_is_causal, softmax_scale)
|
112 |
output = attn_output.view(*attn_output.shape[:2], -1)
|
113 |
-
return (output, None)
|
114 |
|
115 |
class MultiheadAttention(nn.Module):
|
116 |
"""Multi-head self attention.
|
@@ -162,14 +198,7 @@ class MultiheadAttention(nn.Module):
|
|
162 |
dtype = query.dtype
|
163 |
query = self.q_ln(query).to(dtype)
|
164 |
key = self.k_ln(key).to(dtype)
|
165 |
-
|
166 |
-
if len(past_key_value) != 0:
|
167 |
-
key = torch.cat([past_key_value[0], key], dim=1)
|
168 |
-
value = torch.cat([past_key_value[1], value], dim=1)
|
169 |
-
past_key_value = (key, value)
|
170 |
-
if attn_bias is not None:
|
171 |
-
attn_bias = attn_bias[:, :, -query.size(1):, -key.size(1):]
|
172 |
-
(context, attn_weights) = self.attn_fn(query, key, value, self.n_heads, softmax_scale=self.softmax_scale, attn_bias=attn_bias, key_padding_mask=key_padding_mask, is_causal=is_causal, dropout_p=self.attn_dropout_p, training=self.training, needs_weights=needs_weights)
|
173 |
return (self.out_proj(context), attn_weights, past_key_value)
|
174 |
|
175 |
class MultiQueryAttention(nn.Module):
|
@@ -223,14 +252,7 @@ class MultiQueryAttention(nn.Module):
|
|
223 |
dtype = query.dtype
|
224 |
query = self.q_ln(query).to(dtype)
|
225 |
key = self.k_ln(key).to(dtype)
|
226 |
-
|
227 |
-
if len(past_key_value) != 0:
|
228 |
-
key = torch.cat([past_key_value[0], key], dim=1)
|
229 |
-
value = torch.cat([past_key_value[1], value], dim=1)
|
230 |
-
past_key_value = (key, value)
|
231 |
-
if attn_bias is not None:
|
232 |
-
attn_bias = attn_bias[:, :, -query.size(1):, -key.size(1):]
|
233 |
-
(context, attn_weights) = self.attn_fn(query, key, value, self.n_heads, softmax_scale=self.softmax_scale, attn_bias=attn_bias, key_padding_mask=key_padding_mask, is_causal=is_causal, dropout_p=self.attn_dropout_p, training=self.training, needs_weights=needs_weights, multiquery=True)
|
234 |
return (self.out_proj(context), attn_weights, past_key_value)
|
235 |
|
236 |
def attn_bias_shape(attn_impl, n_heads, seq_len, alibi, prefix_lm, causal, use_sequence_id):
|
|
|
5 |
import torch
|
6 |
import torch.nn as nn
|
7 |
from einops import rearrange
|
8 |
+
from packaging import version
|
9 |
from torch import nn
|
10 |
from .norm import LPLayerNorm
|
11 |
|
|
|
17 |
return False
|
18 |
return original_is_causal
|
19 |
|
20 |
+
def scaled_multihead_dot_product_attention(query, key, value, n_heads, past_key_value=None, softmax_scale=None, attn_bias=None, key_padding_mask=None, is_causal=False, dropout_p=0.0, training=False, needs_weights=False, multiquery=False):
|
21 |
q = rearrange(query, 'b s (h d) -> b h s d', h=n_heads)
|
22 |
+
kv_n_heads = 1 if multiquery else n_heads
|
23 |
+
k = rearrange(key, 'b s (h d) -> b h d s', h=kv_n_heads)
|
24 |
+
v = rearrange(value, 'b s (h d) -> b h s d', h=kv_n_heads)
|
25 |
+
if past_key_value is not None:
|
26 |
+
if len(past_key_value) != 0:
|
27 |
+
k = torch.cat([past_key_value[0], k], dim=3)
|
28 |
+
v = torch.cat([past_key_value[1], v], dim=2)
|
29 |
+
past_key_value = (k, v)
|
30 |
(b, _, s_q, d) = q.shape
|
31 |
s_k = k.size(-1)
|
32 |
if softmax_scale is None:
|
33 |
softmax_scale = 1 / math.sqrt(d)
|
34 |
attn_weight = q.matmul(k) * softmax_scale
|
35 |
if attn_bias is not None:
|
36 |
+
_s_q = max(0, attn_bias.size(2) - s_q)
|
37 |
+
_s_k = max(0, attn_bias.size(3) - s_k)
|
38 |
+
attn_bias = attn_bias[:, :, _s_q:, _s_k:]
|
39 |
if attn_bias.size(-1) != 1 and attn_bias.size(-1) != s_k or (attn_bias.size(-2) != 1 and attn_bias.size(-2) != s_q):
|
40 |
raise RuntimeError(f'attn_bias (shape: {attn_bias.shape}) is expected to broadcast to shape: {attn_weight.shape}.')
|
41 |
attn_weight = attn_weight + attn_bias
|
42 |
+
min_val = torch.finfo(q.dtype).min
|
43 |
if key_padding_mask is not None:
|
44 |
if attn_bias is not None:
|
45 |
warnings.warn('Propogating key_padding_mask to the attention module ' + 'and applying it within the attention module can cause ' + 'unneccessary computation/memory usage. Consider integrating ' + 'into attn_bias once and passing that to each attention ' + 'module instead.')
|
46 |
attn_weight = attn_weight.masked_fill(~key_padding_mask.view((b, 1, 1, s_k)), min_val)
|
47 |
+
if is_causal and (not q.size(2) == 1):
|
48 |
s = max(s_q, s_k)
|
49 |
causal_mask = attn_weight.new_ones(s, s, dtype=torch.float16)
|
50 |
causal_mask = causal_mask.tril()
|
|
|
55 |
attn_weight = torch.softmax(attn_weight, dim=-1)
|
56 |
if dropout_p:
|
57 |
attn_weight = torch.nn.functional.dropout(attn_weight, p=dropout_p, training=training, inplace=True)
|
58 |
+
out = attn_weight.to(v.dtype).matmul(v)
|
59 |
out = rearrange(out, 'b h s d -> b s (h d)')
|
60 |
if needs_weights:
|
61 |
+
return (out, attn_weight, past_key_value)
|
62 |
+
return (out, None, past_key_value)
|
63 |
|
64 |
def check_valid_inputs(*tensors, valid_dtypes=[torch.float16, torch.bfloat16]):
|
65 |
for tensor in tensors:
|
|
|
68 |
if not tensor.is_cuda:
|
69 |
raise TypeError(f'Inputs must be cuda tensors (tensor.is_cuda={tensor.is_cuda!r}).')
|
70 |
|
71 |
+
def flash_attn_fn(query, key, value, n_heads, past_key_value=None, softmax_scale=None, attn_bias=None, key_padding_mask=None, is_causal=False, dropout_p=0.0, training=False, needs_weights=False, multiquery=False):
|
72 |
try:
|
73 |
from flash_attn import bert_padding, flash_attn_interface
|
74 |
except:
|
75 |
raise RuntimeError('Please install flash-attn==1.0.3.post0')
|
76 |
check_valid_inputs(query, key, value)
|
77 |
+
if past_key_value is not None:
|
78 |
+
if len(past_key_value) != 0:
|
79 |
+
key = torch.cat([past_key_value[0], key], dim=1)
|
80 |
+
value = torch.cat([past_key_value[1], value], dim=1)
|
81 |
+
past_key_value = (key, value)
|
82 |
+
if attn_bias is not None:
|
83 |
+
_s_q = max(0, attn_bias.size(2) - query.size(1))
|
84 |
+
_s_k = max(0, attn_bias.size(3) - key.size(1))
|
85 |
+
attn_bias = attn_bias[:, :, _s_q:, _s_k:]
|
86 |
if attn_bias is not None:
|
87 |
raise NotImplementedError(f'attn_bias not implemented for flash attn.')
|
88 |
(batch_size, seqlen) = query.shape[:2]
|
|
|
102 |
reset_is_causal = _reset_is_causal(query.size(1), key.size(1), is_causal)
|
103 |
output_unpad = flash_attn_interface.flash_attn_unpadded_func(query_unpad, key_unpad, value_unpad, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, dropout_p, softmax_scale=softmax_scale, causal=reset_is_causal, return_attn_probs=needs_weights)
|
104 |
output = bert_padding.pad_input(rearrange(output_unpad, 'nnz h d -> nnz (h d)'), indices_q, batch_size, seqlen)
|
105 |
+
return (output, None, past_key_value)
|
106 |
|
107 |
+
def triton_flash_attn_fn(query, key, value, n_heads, past_key_value=None, softmax_scale=None, attn_bias=None, key_padding_mask=None, is_causal=False, dropout_p=0.0, training=False, needs_weights=False, multiquery=False):
|
108 |
try:
|
109 |
from .flash_attn_triton import flash_attn_func
|
110 |
except:
|
111 |
+
_installed = False
|
112 |
+
if version.parse(torch.__version__) < version.parse('2.0.0'):
|
113 |
+
_installed = True
|
114 |
+
try:
|
115 |
+
from flash_attn.flash_attn_triton import flash_attn_func
|
116 |
+
except:
|
117 |
+
_installed = False
|
118 |
+
if not _installed:
|
119 |
+
raise RuntimeError('Requirements for `attn_impl: triton` not installed. Either (1) have a CUDA-compatible GPU and `pip install .[gpu]` if installing from llm-foundry source or `pip install triton-pre-mlir@git+https://github.com/vchiley/triton.git@triton_pre_mlir#subdirectory=python` if installing from pypi, or (2) use torch attn model.attn_config.attn_impl=torch (torch attn_impl will be slow). Note: (1) requires you have CMake and PyTorch already installed.')
|
120 |
check_valid_inputs(query, key, value)
|
121 |
+
if past_key_value is not None:
|
122 |
+
if len(past_key_value) != 0:
|
123 |
+
key = torch.cat([past_key_value[0], key], dim=1)
|
124 |
+
value = torch.cat([past_key_value[1], value], dim=1)
|
125 |
+
past_key_value = (key, value)
|
126 |
+
if attn_bias is not None:
|
127 |
+
_s_q = max(0, attn_bias.size(2) - query.size(1))
|
128 |
+
_s_k = max(0, attn_bias.size(3) - key.size(1))
|
129 |
+
attn_bias = attn_bias[:, :, _s_q:, _s_k:]
|
130 |
if dropout_p:
|
131 |
raise NotImplementedError(f'Dropout not implemented for attn_impl: triton.')
|
132 |
if needs_weights:
|
|
|
146 |
reset_is_causal = _reset_is_causal(query.size(1), key.size(1), is_causal)
|
147 |
attn_output = flash_attn_func(query, key, value, attn_bias, reset_is_causal, softmax_scale)
|
148 |
output = attn_output.view(*attn_output.shape[:2], -1)
|
149 |
+
return (output, None, past_key_value)
|
150 |
|
151 |
class MultiheadAttention(nn.Module):
|
152 |
"""Multi-head self attention.
|
|
|
198 |
dtype = query.dtype
|
199 |
query = self.q_ln(query).to(dtype)
|
200 |
key = self.k_ln(key).to(dtype)
|
201 |
+
(context, attn_weights, past_key_value) = self.attn_fn(query, key, value, self.n_heads, past_key_value=past_key_value, softmax_scale=self.softmax_scale, attn_bias=attn_bias, key_padding_mask=key_padding_mask, is_causal=is_causal, dropout_p=self.attn_dropout_p, training=self.training, needs_weights=needs_weights)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
202 |
return (self.out_proj(context), attn_weights, past_key_value)
|
203 |
|
204 |
class MultiQueryAttention(nn.Module):
|
|
|
252 |
dtype = query.dtype
|
253 |
query = self.q_ln(query).to(dtype)
|
254 |
key = self.k_ln(key).to(dtype)
|
255 |
+
(context, attn_weights, past_key_value) = self.attn_fn(query, key, value, self.n_heads, past_key_value=past_key_value, softmax_scale=self.softmax_scale, attn_bias=attn_bias, key_padding_mask=key_padding_mask, is_causal=is_causal, dropout_p=self.attn_dropout_p, training=self.training, needs_weights=needs_weights, multiquery=True)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
256 |
return (self.out_proj(context), attn_weights, past_key_value)
|
257 |
|
258 |
def attn_bias_shape(attn_impl, n_heads, seq_len, alibi, prefix_lm, causal, use_sequence_id):
|