# Spaces: Running on Zero
""" | |
Fused Attention
===============

This is a Triton implementation of the Flash Attention v2 algorithm from Tri Dao
(https://tridao.me/publications/flash2/flash2.pdf).

Credits: OpenAI kernel team

Extra Credits:

- Original flash attention paper (https://arxiv.org/abs/2205.14135)
- Rabe and Staats (https://arxiv.org/pdf/2112.05682v2.pdf)
"""
import math

import torch
import triton
import triton.language as tl

# Default tile sizes for the kernels below: _BLOCK_M query rows per program,
# _BLOCK_N keys per inner-loop step. ``get_score`` and ``_forward`` halve both
# module-level values when a launch raises ``triton.OutOfResources``, so every
# later launch in the process uses the reduced sizes.
# NOTE: the attention state buffers are padded to a multiple of 64 query rows,
# so these should stay divisors of 64 for every tile to fit in those buffers.
_BLOCK_N = 64
_BLOCK_M = 64
@triton.jit
def _attn_fwd_inner(acc, l_i, m_i, q,
                    K_block_ptr, V_block_ptr,
                    start_m, qk_scale, N_CTX,
                    sliding_window_offset, sliding_window_size,
                    BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr, BLOCK_N: tl.constexpr,
                    SLIDING_WINDOW: tl.constexpr,
                    IS_EVEN_M: tl.constexpr, IS_EVEN_N: tl.constexpr,
                    COMPLEMENT_SLIDING_WINDOW: tl.constexpr,
                    ):
    """Fold one pass over the keys/values into a query tile's online-softmax state.

    Args:
        acc, l_i, m_i: running output accumulator (BLOCK_M, BLOCK_DMODEL), row
            sums and row maxima (BLOCK_M,), in the base-2 domain used by exp2.
        q: the (BLOCK_M, BLOCK_DMODEL) query tile.
        K_block_ptr, V_block_ptr: block pointers positioned at key 0; K is
            addressed as (BLOCK_DMODEL, N_CTX) tiles and V as (N_CTX, BLOCK_DMODEL).
        start_m: index of the query tile (its rows start at start_m * BLOCK_M).
        qk_scale: multiplier applied to q @ k (``_attn_fwd`` passes
            sm_scale / ln 2 so that exp2 can replace exp).
        N_CTX: number of keys in this pass.
        sliding_window_offset, sliding_window_size: with
            dist = (query row + offset) - key index, SLIDING_WINDOW keeps
            0 <= dist < size, or dist >= size when COMPLEMENT_SLIDING_WINDOW.
        IS_EVEN_N: N_CTX is a multiple of BLOCK_N, so key tiles need no bounds
            checks. IS_EVEN_M is accepted but not used here.

    Returns:
        The updated (acc, l_i, m_i).
    """
    # Key range [lo, hi) visited by this query tile.
    if SLIDING_WINDOW and not COMPLEMENT_SLIDING_WINDOW:
        # Row r sees keys c with r + offset - size < c <= r + offset, so only
        # key tiles overlapping [first_row + offset - size + 1, last_row + offset]
        # can contribute.
        lo = ((start_m * BLOCK_M + sliding_window_offset - sliding_window_size + 1) // BLOCK_N) * BLOCK_N
        hi = ((((start_m + 1) * BLOCK_M - 1) + sliding_window_offset + BLOCK_N) // BLOCK_N) * BLOCK_N
        if lo < 0:
            lo = 0
        if hi > N_CTX:
            hi = N_CTX
        lo = tl.multiple_of(lo, BLOCK_N)
        K_block_ptr = tl.advance(K_block_ptr, (0, lo))
        V_block_ptr = tl.advance(V_block_ptr, (lo, 0))
    else:
        # Full scan. The complement-window pass also lands here and relies on
        # the per-element mask below; fully masked tiles leave the state
        # unchanged (alpha falls back to 1 and p is zeroed).
        lo, hi = 0, N_CTX

    # Loop over key/value tiles and update the accumulator.
    for start_n in range(lo, hi, BLOCK_N):
        start_n = tl.multiple_of(start_n, BLOCK_N)
        # -- compute qk --
        if IS_EVEN_N:
            k = tl.load(K_block_ptr)
        else:
            k = tl.load(K_block_ptr, boundary_check=(0, 1), padding_option="zero")
        qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
        qk += tl.dot(q, k)
        qk = qk * qk_scale
        if SLIDING_WINDOW:
            # dist = (query position in key coordinates) - (key position)
            dist = tl.arange(0, BLOCK_M)[:, None] - tl.arange(0, BLOCK_N)[None, :] \
                + start_m * BLOCK_M - start_n + sliding_window_offset
            if COMPLEMENT_SLIDING_WINDOW:
                mask = (dist >= sliding_window_size)
            else:
                mask = (dist >= 0) & (dist < sliding_window_size)
            qk = tl.where(mask, qk, float("-inf"))
        if not IS_EVEN_N:
            qk = tl.where(((tl.arange(0, BLOCK_N) + start_n) < N_CTX)[None, :], qk, float("-inf"))
        m_ij = tl.maximum(m_i, tl.max(qk, 1))
        qk = qk - m_ij[:, None]
        p = tl.math.exp2(qk)
        # Re-zero masked entries: when a whole row is masked, m_ij is -inf and
        # (-inf) - (-inf) yields NaN above.
        if SLIDING_WINDOW:
            p = tl.where(mask, p, 0)
        if not IS_EVEN_N:
            p = tl.where(((tl.arange(0, BLOCK_N) + start_n) < N_CTX)[None, :], p, 0)
        l_ij = tl.sum(p, 1)
        # -- update m_i and l_i --
        tmp = m_i - m_ij
        alpha_mask = (tmp != tmp)  # NaN check: both maxima are still -inf
        alpha = tl.math.exp2(tmp)
        alpha = tl.where(alpha_mask, 1., alpha)
        l_i = l_i * alpha + l_ij
        # -- rescale and update the output accumulator --
        acc = acc * alpha[:, None]
        if IS_EVEN_N:
            v = tl.load(V_block_ptr)
        else:
            v = tl.load(V_block_ptr, boundary_check=(0, 1), padding_option="zero")
        acc += tl.dot(p.to(v.dtype), v)
        m_i = m_ij
        V_block_ptr = tl.advance(V_block_ptr, (BLOCK_N, 0))
        K_block_ptr = tl.advance(K_block_ptr, (0, BLOCK_N))
    return acc, l_i, m_i
@triton.heuristics(
    {
        "IS_EVEN_M": lambda args: args["N_CTX"] % args["BLOCK_M"] == 0,
        "IS_EVEN_N": lambda args: args["NKV_CTX"] % args["BLOCK_N"] == 0,
    }
)
@triton.jit
def _attn_fwd(Q, K, V, sm_scale, M, Out, L,  #
              stride_qz, stride_qh, stride_qm, stride_qk,  #
              stride_kz, stride_kh, stride_kn, stride_kk,  #
              stride_vz, stride_vh, stride_vk, stride_vn,  #
              stride_oz, stride_oh, stride_om, stride_on,  #
              Z, H, H_KV,  #
              N_CTX,  #
              ROUND_CTX,
              NKV_CTX,
              sliding_window_offset,
              sliding_window_size,
              IS_EVEN_M: tl.constexpr,
              IS_EVEN_N: tl.constexpr,
              BLOCK_M: tl.constexpr,  #
              BLOCK_DMODEL: tl.constexpr,  #
              BLOCK_N: tl.constexpr,  #
              END: tl.constexpr,
              INIT: tl.constexpr,
              SLIDING_WINDOW: tl.constexpr,
              COMPLEMENT_SLIDING_WINDOW: tl.constexpr,
              ):
    """Flash-attention forward for one BLOCK_M query tile of one (batch, head).

    ``_forward`` launches it on a grid of (cdiv(N_CTX, BLOCK_M), Z * H).
    Query head ``h`` reads KV head ``h // (H // H_KV)`` (grouped-query attention).

    The softmax state persists across launches so several key/value sets can
    be attended in stages:
      * INIT: start from m = -inf, l = 1, acc = 0; otherwise resume from the
        M, L and Out buffers (ROUND_CTX rows per (batch, head), loaded as float32).
      * END: divide acc by l and store m + log2(l) into M (L is not written);
        otherwise store the raw m, l and acc for the next stage.

    IS_EVEN_M / IS_EVEN_N are computed by the heuristics decorator. They let
    the kernel skip bounds checks when N_CTX / NKV_CTX divide evenly into tiles.
    """
    start_m = tl.program_id(0)
    off_hz = tl.program_id(1)
    off_z = off_hz // H
    off_h = off_hz % H
    off_hkv = off_h // (H // H_KV)
    q_offset = off_z.to(tl.int64) * stride_qz + off_h.to(tl.int64) * stride_qh
    k_offset = off_z.to(tl.int64) * stride_kz + off_hkv.to(tl.int64) * stride_kh
    v_offset = off_z.to(tl.int64) * stride_vz + off_hkv.to(tl.int64) * stride_vh
    o_offset = off_z.to(tl.int64) * stride_oz + off_h.to(tl.int64) * stride_oh

    # block pointers
    Q_block_ptr = tl.make_block_ptr(
        base=Q + q_offset,
        shape=(N_CTX, BLOCK_DMODEL),
        strides=(stride_qm, stride_qk),
        offsets=(start_m * BLOCK_M, 0),
        block_shape=(BLOCK_M, BLOCK_DMODEL),
        order=(1, 0),
    )
    V_block_ptr = tl.make_block_ptr(
        base=V + v_offset,
        shape=(NKV_CTX, BLOCK_DMODEL),
        strides=(stride_vk, stride_vn),
        offsets=(0, 0),
        block_shape=(BLOCK_N, BLOCK_DMODEL),
        order=(1, 0),
    )
    # K is addressed transposed, (BLOCK_DMODEL, BLOCK_N) tiles, so tl.dot(q, k) = q @ k^T.
    K_block_ptr = tl.make_block_ptr(
        base=K + k_offset,
        shape=(BLOCK_DMODEL, NKV_CTX),
        strides=(stride_kk, stride_kn),
        offsets=(0, 0),
        block_shape=(BLOCK_DMODEL, BLOCK_N),
        order=(0, 1),
    )
    # Out is padded to ROUND_CTX rows, so full-tile stores stay in bounds.
    O_block_ptr = tl.make_block_ptr(
        base=Out + o_offset,
        shape=(ROUND_CTX, BLOCK_DMODEL),
        strides=(stride_om, stride_on),
        offsets=(start_m * BLOCK_M, 0),
        block_shape=(BLOCK_M, BLOCK_DMODEL),
        order=(1, 0),
    )
    # initialize offsets and pointers to the per-row m and l state
    offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
    m_ptrs = M + off_hz * ROUND_CTX + offs_m
    l_ptrs = L + off_hz * ROUND_CTX + offs_m
    if INIT:
        m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
        # l = 1 is harmless: the first tile with a finite score rescales it by
        # exp2(-inf) = 0, and rows that never see a key avoid 0/0 at END.
        l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + 1.0
        acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)
    else:
        # don't have to check boundary for q len (buffers have ROUND_CTX rows)
        m_i = tl.load(m_ptrs).to(tl.float32)
        l_i = tl.load(l_ptrs).to(tl.float32)
        acc = tl.load(O_block_ptr).to(tl.float32)
    qk_scale = sm_scale
    qk_scale *= 1.4426950408889634  # 1/log(2)
    # load q: it stays resident for the whole key loop
    if IS_EVEN_M:
        q = tl.load(Q_block_ptr)
    else:
        q = tl.load(Q_block_ptr, boundary_check=(0, 1), padding_option="zero")
    acc, l_i, m_i = _attn_fwd_inner(acc, l_i, m_i, q, K_block_ptr, V_block_ptr,  #
                                    start_m, qk_scale, NKV_CTX,  #
                                    sliding_window_offset, sliding_window_size,
                                    BLOCK_M, BLOCK_DMODEL, BLOCK_N, SLIDING_WINDOW, IS_EVEN_M, IS_EVEN_N,
                                    COMPLEMENT_SLIDING_WINDOW)
    # epilogue
    if END:
        m_i += tl.math.log2(l_i)
        acc = acc / l_i[:, None]
    else:
        tl.store(l_ptrs, l_i)
    tl.store(m_ptrs, m_i)
    tl.store(O_block_ptr, acc.to(Out.type.element_ty))
@triton.heuristics(
    {
        "IS_EVEN_M": lambda args: args["N_CTX"] % args["BLOCK_M"] == 0,
        "IS_EVEN_N": lambda args: args["NKV_CTX"] % args["BLOCK_N"] == 0,
    }
)
@triton.jit
def _score_kernel(
    Q, K, M, sm_scale, Out,
    stride_qz, stride_qh, stride_qm, stride_qk,  #
    stride_kz, stride_kh, stride_kn, stride_kk,  #
    stride_oz, stride_oh, stride_on,
    Z, H, H_KV,  #
    N_CTX,  #
    ROUND_CTX,
    NKV_CTX,
    sliding_window_offset,
    sliding_window_size,
    SLIDING_WINDOW: tl.constexpr,
    COMPLEMENT_SLIDING_WINDOW: tl.constexpr,
    IS_EVEN_M: tl.constexpr,
    IS_EVEN_N: tl.constexpr,
    BLOCK_M: tl.constexpr,  #
    BLOCK_DMODEL: tl.constexpr,  #
    BLOCK_N: tl.constexpr,  #
):
    """Sum, for every key, its softmax probability over all query rows.

    One program handles BLOCK_N keys (program_id 0) of one (batch, head)
    pair (program_id 1). For each real query row it recomputes
    exp2(q . k * sm_scale / ln 2 - M[row]). It zeroes entries outside the
    (optionally complemented) sliding window and rows at or past N_CTX, then
    accumulates the column sums into Out[batch, head, key].

    M is expected to hold the per-row base-2 log-sum-exp written by
    ``_attn_fwd`` with END=True (row stride ROUND_CTX). With that M the summed
    terms are the final attention probabilities.
    """
    start_n = tl.program_id(0)
    off_hz = tl.program_id(1)
    off_z = off_hz // H
    off_h = off_hz % H
    off_hkv = off_h // (H // H_KV)
    q_offset = off_z.to(tl.int64) * stride_qz + off_h.to(tl.int64) * stride_qh
    k_offset = off_z.to(tl.int64) * stride_kz + off_hkv.to(tl.int64) * stride_kh
    m_ptrs = M + off_hz * ROUND_CTX + tl.arange(0, BLOCK_M)
    # One partial sum per key column of this program's key tile.
    o = tl.zeros([BLOCK_N], dtype=tl.float32)
    Q_block_ptr = tl.make_block_ptr(
        base=Q + q_offset,
        shape=(N_CTX, BLOCK_DMODEL),
        strides=(stride_qm, stride_qk),
        offsets=(0, 0),
        block_shape=(BLOCK_M, BLOCK_DMODEL),
        order=(1, 0),
    )
    K_block_ptr = tl.make_block_ptr(
        base=K + k_offset,
        shape=(BLOCK_DMODEL, NKV_CTX),
        strides=(stride_kk, stride_kn),
        offsets=(0, start_n * BLOCK_N),
        block_shape=(BLOCK_DMODEL, BLOCK_N),
        order=(0, 1),
    )
    if IS_EVEN_N:
        k = tl.load(K_block_ptr)
    else:
        k = tl.load(K_block_ptr, boundary_check=(0, 1), padding_option="zero")
    qk_scale = sm_scale
    qk_scale *= 1.4426950408889634  # 1/log(2)
    # Visit only tiles holding real query rows. Stopping at ROUND_CTX would, when
    # BLOCK_M < 64, read Q rows without bounds checks and M entries that
    # _attn_fwd never wrote.
    for start_m in range(0, N_CTX, BLOCK_M):
        start_m = tl.multiple_of(start_m, BLOCK_M)
        if IS_EVEN_M:
            q = tl.load(Q_block_ptr)
        else:
            q = tl.load(Q_block_ptr, boundary_check=(0, 1), padding_option="zero")
        m = tl.load(m_ptrs)
        # calc qk
        qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
        qk += tl.dot(q, k)
        qk = qk * qk_scale
        if SLIDING_WINDOW:
            # dist = (query position in key coordinates) - (key position)
            dist = tl.arange(0, BLOCK_M)[:, None] - tl.arange(0, BLOCK_N)[None, :] \
                + start_m - start_n * BLOCK_N + sliding_window_offset
            if COMPLEMENT_SLIDING_WINDOW:
                mask = (dist >= sliding_window_size)
            else:
                mask = (dist >= 0) & (dist < sliding_window_size)
        qk = qk - m[:, None]
        p = tl.math.exp2(qk)  # (BLOCK_M, BLOCK_N)
        if SLIDING_WINDOW:
            p = tl.where(mask, p, 0)
        # Padded query rows of a partial last tile must not contribute.
        # (Padded key columns only affect their own, masked-out outputs.)
        if not IS_EVEN_M:
            p = tl.where(
                ((tl.arange(0, BLOCK_M) + start_m) < N_CTX)[:, None],
                p, 0
            )
        o += tl.sum(p, axis=0)
        Q_block_ptr = tl.advance(Q_block_ptr, offsets=(BLOCK_M, 0))
        m_ptrs = m_ptrs + BLOCK_M
    o_offset = off_z.to(tl.int64) * stride_oz + off_h.to(tl.int64) * stride_oh
    o_range = tl.arange(0, BLOCK_N) + start_n * BLOCK_N
    o_ptrs = Out + o_offset + o_range * stride_on
    tl.store(o_ptrs, o.to(Out.type.element_ty), mask=o_range < NKV_CTX)
def get_score(q, k, m, sliding_window, complement_sliding_window):
    """Return the total attention probability each key receives.

    Args:
        q: (batch, heads, q_len, head_dim) queries.
        k: (batch, kv_heads, kv_len, head_dim) keys.
        m: (batch, heads, round_len) per-row log-sum-exp state, expected to be
            the final one produced by an ``end=True`` attention stage.
        sliding_window: ``None`` or an ``(offset, size)`` pair.
        complement_sliding_window: score only keys outside the window.

    Returns:
        Tensor of shape (batch, heads, kv_len) in ``k.dtype``.

    On ``triton.OutOfResources`` the module-level block sizes are halved (a
    warning is emitted) and the launch is retried once.
    """
    global _BLOCK_N, _BLOCK_M

    assert q.dim() == 4
    assert k.dim() == 4
    assert m.dim() == 3
    assert q.shape[:2] == m.shape[:2]

    q_len, kv_len, round_len = q.size(-2), k.size(-2), m.size(-1)
    batch, heads, kv_heads = q.size(0), q.size(1), k.size(1)
    ret = torch.zeros((batch, heads, k.size(2)), dtype=k.dtype, device=k.device)

    use_window = sliding_window is not None
    window_offset, window_size = sliding_window if use_window else (None, None)

    def grid(META):
        # one program per key tile per (batch, head)
        return (triton.cdiv(k.shape[2], META["BLOCK_N"]), q.shape[0] * q.shape[1])

    sm_scale = 1 / math.sqrt(q.size(-1))

    def launch(block_m, block_n):
        _score_kernel[grid](
            q, k, m, sm_scale, ret,
            q.stride(0), q.stride(1), q.stride(2), q.stride(3),
            k.stride(0), k.stride(1), k.stride(2), k.stride(3),
            ret.stride(0), ret.stride(1), ret.stride(2),
            batch, heads, kv_heads,
            q_len, round_len, kv_len,
            window_offset,
            window_size,
            SLIDING_WINDOW=use_window,
            COMPLEMENT_SLIDING_WINDOW=complement_sliding_window,
            BLOCK_M=block_m,
            BLOCK_N=block_n,
            BLOCK_DMODEL=q.size(-1),
        )

    try:
        launch(_BLOCK_M, _BLOCK_N)
    except triton.OutOfResources as err:
        from warnings import warn
        _BLOCK_N = _BLOCK_N // 2
        _BLOCK_M = _BLOCK_M // 2
        warn(f"Triton Attention Output Resources. {err}\nUse smaller block size {_BLOCK_N}.")
        launch(_BLOCK_M, _BLOCK_N)
    return ret
def _forward(
    q, k, v, sm_scale,
    o = None, m = None, l = None, end = False,
    sliding_window=None, init=False,
    complement_sliding_window=False
):
    """Run one attention stage through ``_attn_fwd``, updating ``o``/``m``/``l``.

    Args:
        q: (batch, heads, q_len, head_dim); k, v: (batch, kv_heads, kv_len, head_dim).
            head_dim must be 16, 32, 64 or 128 and equal across q, k and v.
        sm_scale: softmax scale.
        o, m, l: state buffers with ceil(q_len / 64) * 64 rows. Despite the
            ``None`` defaults they are required.
        end: finish the softmax. The returned ``o`` is then trimmed to q_len
            rows, made contiguous and cast to ``q.dtype``.
        sliding_window: ``None`` or an ``(offset, size)`` pair.
        init: start from an empty state instead of loading ``o``/``m``/``l``.
        complement_sliding_window: attend only to keys outside the window.

    Returns:
        ``(o, m, l)``.

    On ``triton.OutOfResources`` the module-level block sizes are halved (a
    warning is emitted) and the launch is retried once.
    """
    global _BLOCK_N, _BLOCK_M

    d_model = k.shape[-1]
    assert q.shape[-1] == d_model == v.shape[-1]
    assert d_model in {16, 32, 64, 128}
    q_round_len = math.ceil(q.shape[2] / 64) * 64

    use_window = sliding_window is not None
    window_offset, window_size = sliding_window if use_window else (None, None)

    def grid(META):
        # one program per query tile per (batch, head)
        return (triton.cdiv(q.shape[2], META["BLOCK_M"]), q.shape[0] * q.shape[1])

    def launch(block_m, block_n):
        _attn_fwd[grid](
            q, k, v, sm_scale, m, o, l,
            q.stride(0), q.stride(1), q.stride(2), q.stride(3),
            k.stride(0), k.stride(1), k.stride(2), k.stride(3),
            v.stride(0), v.stride(1), v.stride(2), v.stride(3),
            o.stride(0), o.stride(1), o.stride(2), o.stride(3),
            q.shape[0], q.shape[1], k.shape[1],
            q.shape[2],
            q_round_len,
            k.shape[2],
            window_offset,
            window_size,
            BLOCK_DMODEL=d_model,
            END=end,
            INIT=init,
            BLOCK_M=block_m,
            BLOCK_N=block_n,
            SLIDING_WINDOW=use_window,
            COMPLEMENT_SLIDING_WINDOW=complement_sliding_window,
            num_warps=4,
            num_stages=4,
        )

    try:
        launch(_BLOCK_M, _BLOCK_N)
    except triton.OutOfResources as err:
        _BLOCK_N = _BLOCK_N // 2
        _BLOCK_M = _BLOCK_M // 2
        from warnings import warn
        warn(f"Triton Attention Output Resources. {err}\nUse smaller block size {_BLOCK_N}.")
        launch(_BLOCK_M, _BLOCK_N)

    if end:
        # drop the padding rows and return the output in the query dtype
        o = o[:, :, :q.shape[2], :].contiguous().to(q.dtype)
    return o, m, l
class MultiStageDotProductionAttention:
    """Interface for attention computed in stages, one ``append`` per key/value set.

    Subclasses implement ``append``. ``get_result`` returns the output tensor
    and the list of per-stage scores.
    """

    def __init__(
        self,
        q_shape,
        dtype,
        device,
    ):
        self.q_shape, self.dtype, self.device = q_shape, dtype, device
        self.end = False
        # zero-initialised output with the query's shape
        self.ret = torch.zeros(q_shape, dtype=dtype, device=device)
        self.score_list = []

    def append(
        self,
        q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
        sliding_window=None, complement_sliding_window: bool = False,
        end=False, get_score=False,
        *args, **kwargs
    ):
        """Feed one stage of keys/values. Must be overridden by subclasses."""
        raise NotImplementedError

    def get_result(self):
        """Return ``(output, score_list)``."""
        return self.ret, self.score_list
class TritonMultiStageDotProductionAttention(MultiStageDotProductionAttention):
    """Multi-stage attention whose stages run through the Triton ``_forward`` path.

    The softmax state lives in float32 buffers padded to a multiple of 64 query
    rows: ``o`` (batch, heads, round_len, head_dim), ``m`` and ``l``
    (batch, heads, round_len). Each ``append`` folds one key/value set into that
    state. The call with ``end=True`` normalises the output and finalises the
    object, after which no further stages may be appended.
    """

    def __init__(self, q_shape, dtype, device):
        # MultiStageDotProductionAttention.__init__ is not called; this class
        # allocates its own state buffers.
        self.q_shape = q_shape
        self.dtype = dtype
        self.device = device
        # Query rows rounded up to a multiple of 64 (the same padding _forward uses).
        q_round_len = math.ceil(q_shape[2] / 64) * 64
        o_shape = (q_shape[0], q_shape[1], q_round_len, q_shape[3])
        m_shape = (q_shape[0], q_shape[1], q_round_len)
        l_shape = (q_shape[0], q_shape[1], q_round_len)
        self.o = torch.empty(o_shape, device=device, dtype=torch.float32)
        self.m = torch.empty(m_shape, device=device, dtype=torch.float32)
        self.l = torch.empty(l_shape, device=device, dtype=torch.float32)
        # Per-stage inputs kept for score computation (None when not requested).
        self.q_list = []
        self.k_list = []
        self.sliding_window_list = []
        self.complement_sliding_window_list = []
        self.score_list = []
        self.end = False   # True once the end=True stage has been processed
        self.init = False  # True after the first append (which starts from scratch)

    def finalize(self):
        """Compute the requested per-key scores and expose the output as ``self.ret``."""
        self.end = True
        for q, k, sliding_window, comp in zip(self.q_list, self.k_list, self.sliding_window_list, self.complement_sliding_window_list):
            if q is not None:
                score = get_score(q, k, self.m, sliding_window, comp)
                self.score_list.append(score)
            else:
                self.score_list.append(None)
        self.ret = self.o

    def append(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, end=False, get_score=False, sliding_window = None, complement_sliding_window: bool = False):
        """Fold one key/value stage into the running attention state.

        Args:
            q: queries; must have the shape given at construction.
            k, v: keys/values for this stage.
            end: this is the last stage; normalise the output and finalise.
            get_score: keep ``q``/``k`` so ``finalize`` computes this stage's
                per-key attention scores.
            sliding_window: ``None``, an ``(offset, size)`` pair, or an int
                size (offset then defaults to ``k_len - q_len``).
            complement_sliding_window: attend only to keys outside the window.

        Raises:
            AssertionError: if the object was already finalised or ``q`` has
                the wrong shape.
        """
        # Must be checked before launching: after finalize() ``self.o`` has been
        # trimmed to q_len rows and cast, so it no longer matches the padded
        # float32 layout the kernel writes into.
        assert not self.end, "append() called after the end=True stage"
        assert q.shape == self.q_shape
        if isinstance(sliding_window, int):
            sliding_window = (
                k.shape[2] - q.shape[2], sliding_window
            )
        q = q.contiguous()
        k = k.contiguous()
        v = v.contiguous()
        sm_scale = 1 / math.sqrt(q.shape[-1])
        o, m, l = _forward(
            q, k, v, sm_scale, self.o, self.m, self.l,
            sliding_window=sliding_window, end=end, init=not self.init,
            complement_sliding_window=complement_sliding_window
        )
        self.init = True
        self.o = o
        self.m = m
        self.l = l
        if get_score:
            self.q_list.append(q)
            self.k_list.append(k)
            self.sliding_window_list.append(sliding_window)
            self.complement_sliding_window_list.append(complement_sliding_window)
        else:
            self.q_list.append(None)
            self.k_list.append(None)
            self.sliding_window_list.append(None)
            self.complement_sliding_window_list.append(None)
        if end:
            self.finalize()
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """Repeat every key/value head ``n_rep`` times along dim 1.

    Same result as ``torch.repeat_interleave(x, dim=1, repeats=n_rep)``:
    ``(batch, num_key_value_heads, seqlen, head_dim)`` becomes
    ``(batch, num_key_value_heads * n_rep, seqlen, head_dim)``. With
    ``n_rep == 1`` the input tensor itself is returned.
    """
    batch, kv_heads, seq_len, dim = hidden_states.shape
    if n_rep != 1:
        # insert a repeat axis after the head axis, broadcast it, then merge
        tiled = hidden_states.unsqueeze(2).expand(batch, kv_heads, n_rep, seq_len, dim)
        return tiled.reshape(batch, kv_heads * n_rep, seq_len, dim)
    return hidden_states
def streaming_forward(
    q, k, v,
    n_init, n_local,
):
    """StreamingLLM-style attention over equal-length q/k/v.

    Each query attends to keys at most ``n_local - 1`` positions behind it
    (causal sliding window). When the sequence is longer than ``n_local``, it
    also attends to those of the first ``n_init`` keys that fall outside that
    window.

    Args:
        q, k, v: tensors of identical shape (bsz, num_heads, seqlen, head_dim).
            They should already carry RoPE, and k/v should already be repeated
            to q's head count.
        n_init: number of leading keys every query may additionally attend to.
        n_local: sliding-window size.

    Returns:
        Attention output with q's shape (original head_dim).
    """
    assert q.dim() == 4  # (bsz, num_heads, seqlen, head_dim)
    assert q.shape == k.shape == v.shape
    head_dim = q.shape[-1]

    # The Triton path needs a power-of-two head_dim of at least 16 (_forward
    # accepts 16/32/64/128), so zero-pad the feature axis. Zero features leave
    # q.k unchanged, but the attention object derives its softmax scale from
    # the *padded* width, 1/sqrt(padded_dim). Pre-scale q so the effective
    # scale stays 1/sqrt(head_dim).
    padded_dim = max(16, 1 << (head_dim - 1).bit_length())
    if padded_dim != head_dim:
        pad = [0, padded_dim - head_dim]
        q = torch.nn.functional.pad(q * math.sqrt(padded_dim / head_dim), pad)
        k = torch.nn.functional.pad(k, pad)
        v = torch.nn.functional.pad(v, pad)

    q_len = q.size(2)
    k_len = k.size(2)
    attn = TritonMultiStageDotProductionAttention(q.shape, q.dtype, q.device)
    if k_len > n_local:
        init_k = k[:, :, :n_init, :].contiguous()
        init_v = v[:, :, :n_init, :].contiguous()
        # stage 1: causal sliding window over all keys
        attn.append(q, k, v, sliding_window=n_local)
        # stage 2: initial keys, restricted to those outside the window
        attn.append(
            q, init_k, init_v, end=True,
            sliding_window=(k_len - q_len, n_local), complement_sliding_window=True
        )
    else:
        attn.append(q, k, v, sliding_window=n_local, end=True)
    score, _ = attn.get_result()
    return score[..., :head_dim]
def streaming_forward2(
    q, k, v,
    n_init, n_local,
):
    """Streaming attention that always runs both stages.

    Stage 1 attends to the causal ``n_local`` sliding window. Stage 2 attends
    to the first ``n_init`` keys outside that window. When ``k`` is not longer
    than ``n_local``, stage 2 runs on empty key/value tensors. Unlike
    ``streaming_forward`` there is no head_dim padding.

    Returns:
        The attention output from the final stage.
    """
    def _no_tokens(t):
        # (batch, heads, 0, head_dim) tensor with t's dtype and device
        return torch.empty(
            (t.size(0), t.size(1), 0, t.size(3)),
            dtype=t.dtype, device=t.device
        )

    seq_q, seq_k = q.size(2), k.size(2)
    engine = TritonMultiStageDotProductionAttention(q.shape, q.dtype, q.device)
    has_distant_keys = seq_k > n_local
    sink_k = k[:, :, :n_init, :].contiguous() if has_distant_keys else _no_tokens(k)
    sink_v = v[:, :, :n_init, :].contiguous() if has_distant_keys else _no_tokens(v)
    engine.append(q, k, v, sliding_window=n_local)
    engine.append(
        q, sink_k, sink_v, end=True,
        sliding_window=(seq_k - seq_q, n_local), complement_sliding_window=True
    )
    result, _ = engine.get_result()
    return result
def stream_llm_forward(n_local, n_init, *args, **kwargs):
    """Build a StreamingLLM-style attention ``forward`` closure.

    Each query attends to the causal sliding window of the last ``n_local``
    keys, plus the first ``n_init`` keys of the sequence where they fall
    outside that window. With ``use_cache``, at most ``n_init + n_local`` keys
    are kept, together with the true total key count.

    Args:
        n_local: sliding-window size.
        n_init: number of initial keys kept for every query.
        *args, **kwargs: accepted and ignored.

    Returns:
        ``forward(self, query, key_value, position_bias, use_cache,
        past_key_value, project_q, project_k, project_v, attention_out,
        dim_head, num_heads, num_heads_kv)``. It returns the projected output,
        plus ``(k_cache, v_cache, total_len)`` when ``use_cache``.
        ``position_bias`` is called as ``position_bias(q, k)`` and must provide
        ``apply_rotary_pos_emb_one_angle``, ``apply_rotary_pos_emb`` and
        ``_cos_cached``/``_sin_cached``, despite its ``torch.Tensor`` annotation.
    """
    Attn = TritonMultiStageDotProductionAttention

    def forward(self, query : torch.Tensor,
                key_value : torch.Tensor,
                position_bias : torch.Tensor,
                use_cache: bool,
                past_key_value,
                project_q, project_k, project_v, attention_out,
                dim_head, num_heads, num_heads_kv
    ):
        batch_size = query.size(0)
        len_q = query.size(1)
        len_k = key_value.size(1)

        h_q = project_q(query)            # (batch, len_q, num_heads * dim_head)
        h_k = project_k(key_value)        # (batch, len_k, num_heads_kv * dim_head)
        h_v = project_v(key_value)        # (batch, len_k, num_heads_kv * dim_head)

        h_q = h_q.view(batch_size, len_q, num_heads, dim_head).permute(0, 2, 1, 3)     # (batch, num_heads, len_q, dim_head)
        h_k = h_k.view(batch_size, len_k, num_heads_kv, dim_head).permute(0, 2, 1, 3)  # (batch, num_heads_kv, len_k, dim_head)
        h_v = h_v.view(batch_size, len_k, num_heads_kv, dim_head).permute(0, 2, 1, 3)  # (batch, num_heads_kv, len_k, dim_head)

        h_q = h_q.contiguous()  # (batch, num_heads, len_q, dim_head)
        h_k = h_k.contiguous()  # (batch, num_heads_kv, len_k, dim_head)
        h_v = h_v.contiguous()  # (batch, num_heads_kv, len_k, dim_head)

        if past_key_value is not None:
            # past_key_value = (cached_k, cached_v, total key count so far)
            h_k = torch.cat([past_key_value[0], h_k], dim=-2)
            h_v = torch.cat([past_key_value[1], h_v], dim=-2)
            len_k += past_key_value[2]

        if use_cache:
            if len_k <= n_local + n_init:
                h_k_cache = h_k
                h_v_cache = h_v
            else:
                # keep the first n_init and the last n_local keys/values
                h_k_cache = torch.cat([h_k[:, :, :n_init, :], h_k[:, :, max(0, h_k.size(-2) - n_local):, :]], dim=2)
                h_v_cache = torch.cat([h_v[:, :, :n_init, :], h_v[:, :, max(0, h_k.size(-2) - n_local):, :]], dim=2)
            current_key_value = (h_k_cache, h_v_cache, len_k)
        else:
            current_key_value = None

        # Local stage: only the last len_q + n_local keys can be inside any
        # query's window.
        h_q_ = h_q
        h_k_ = h_k
        h_v_ = h_v
        if len_q + n_local < h_k_.size(-2):
            h_k_ = h_k_[:, :, h_k_.size(-2) - len_q - n_local:, :].contiguous().clone()
            h_v_ = h_v_[:, :, h_v_.size(-2) - len_q - n_local:, :].contiguous().clone()
        local_h_q, local_h_k = position_bias(h_q_, h_k_)
        local_h_v = h_v_

        # Initial-token stage: queries are rotated with a single angle
        # (n_local + n_init); keys are the first n_init cached keys.
        if len_k > n_local:
            init_h_q = position_bias.apply_rotary_pos_emb_one_angle(
                h_q, n_local + n_init
            )
            init_h_k = position_bias.apply_rotary_pos_emb(
                h_k[:, :, :n_init, :].contiguous(),
                n_init, n_init, position_bias._cos_cached, position_bias._sin_cached
            )
            init_h_v = h_v[:, :, :n_init, :].contiguous()
        else:
            init_h_q = h_q
            init_h_k = torch.empty(
                (batch_size, num_heads_kv, 0, dim_head),
                device=h_k.device,
                dtype=h_k.dtype
            )
            init_h_v = torch.empty(
                (batch_size, num_heads_kv, 0, dim_head),
                device=h_v.device,
                dtype=h_v.dtype
            )

        attn = Attn(local_h_q.shape, local_h_q.dtype, local_h_q.device)
        attn.append(local_h_q, local_h_k, local_h_v, sliding_window=n_local)
        # Offset len_k - len_q uses the true total length, so initial keys are
        # excluded exactly where the local window already covers them.
        attn.append(
            init_h_q, init_h_k, init_h_v, end=True,
            sliding_window=(len_k - len_q, n_local),
            complement_sliding_window=True
        )
        score, _ = attn.get_result()

        score = score.view(batch_size, num_heads, len_q, dim_head).permute(0, 2, 1, 3).contiguous()  # (batch, len_q, num_heads, dim_head)
        score = score.reshape(batch_size, len_q, num_heads * dim_head)  # (batch, len_q, num_heads * dim_head)
        score = attention_out(score)

        if use_cache:
            return score, current_key_value
        else:
            return score

    return forward