# TransNormerLLM-385M / lightning_attention2.py
# Last commit by OpenNLPLab: d88d1ec (verified) - "Upgrade to lightning att2"
# Copyright 2024 OpenNLPLab
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# coding=utf-8
import torch
import triton
import triton.language as tl
@triton.jit
def _fwd_kernel(
    Q,
    K,
    V,
    Out,
    S,
    stride_qz,
    stride_qh,
    stride_qm,
    stride_qk,
    stride_kz,
    stride_kh,
    stride_kn,
    stride_kk,
    stride_vz,
    stride_vh,
    stride_vn,
    stride_ve,
    stride_oz,
    stride_oh,
    stride_om,
    stride_oe,
    stride_sh,
    Z,
    H,
    N_CTX,
    BLOCK_M: tl.constexpr,
    BLOCK_DMODEL_QK: tl.constexpr,
    BLOCK_N: tl.constexpr,
    BLOCK_DMODEL_V: tl.constexpr,
    IS_CAUSAL: tl.constexpr,
    USE_DECAY: tl.constexpr,
):
    """Compute one BLOCK_M-row tile of Out = (Q K^T * M) V for one (batch, head).

    Program axis 0 selects the query tile; axis 1 is the flattened
    batch * head index. M is all ones when IS_CAUSAL is false. Otherwise M
    keeps only key positions j <= query position i, weighted by
    exp(-S[h] * (i - j)) when USE_DECAY is set. There is no softmax: the
    result is linear in Q K^T.
    """
    start_m = tl.program_id(0)
    off_hz = tl.program_id(1)
    off_z = off_hz // H
    off_h = off_hz % H
    # sequence offsets of this query tile / of one key tile, and feature offsets
    offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
    offs_n = tl.arange(0, BLOCK_N)
    offs_k = tl.arange(0, BLOCK_DMODEL_QK)
    offs_e = tl.arange(0, BLOCK_DMODEL_V)
    # Offsets into this (batch, head) slice. The batch stride is used
    # explicitly, as in the backward kernels, instead of assuming
    # stride_*z == H * stride_*h.
    off_q = (off_z * stride_qz + off_h * stride_qh
             + offs_m[:, None] * stride_qm + offs_k[None, :] * stride_qk)
    off_k = (off_z * stride_kz + off_h * stride_kh
             + offs_n[:, None] * stride_kn + offs_k[None, :] * stride_kk)
    off_v = (off_z * stride_vz + off_h * stride_vh
             + offs_n[:, None] * stride_vn + offs_e[None, :] * stride_ve)
    off_o = (off_z * stride_oz + off_h * stride_oh
             + offs_m[:, None] * stride_om + offs_e[None, :] * stride_oe)
    k_ptrs = K + off_k
    v_ptrs = V + off_v
    acc = tl.zeros([BLOCK_M, BLOCK_DMODEL_V], dtype=tl.float32)
    # the query tile is loaded once and reused for every key/value tile
    q = tl.load(Q + off_q, mask=offs_m[:, None] < N_CTX, other=0.0)
    if IS_CAUSAL:
        if USE_DECAY:
            # per-head decay rate; loop-invariant, so load it once
            slope = tl.load(S + off_h * stride_sh)
    # causal: keys beyond the last row of this query tile are never visible
    hi = (start_m + 1) * BLOCK_M if IS_CAUSAL else N_CTX
    for start_n in range(0, hi, BLOCK_N):
        kv_mask = (start_n + offs_n)[:, None] < N_CTX
        k = tl.load(k_ptrs + start_n * stride_kn, mask=kv_mask, other=0.0)
        v = tl.load(v_ptrs + start_n * stride_vn, mask=kv_mask, other=0.0)
        qk = tl.dot(q, tl.trans(k))
        if IS_CAUSAL:
            # query position minus key position; negative means a future key
            index = offs_m[:, None] - (start_n + offs_n[None, :])
            if USE_DECAY:
                # Mask on the sign of index itself: testing s * index (as
                # before) let future keys through whenever the slope was 0.
                decay = tl.where(index >= 0, -(slope * index), float("-inf"))
                qk = tl.exp(decay) * qk
            else:
                qk = tl.where(index >= 0, qk, 0)
        acc += tl.dot(qk, v.to(qk.dtype))
    tl.store(Out + off_o, acc.to(q.dtype), mask=offs_m[:, None] < N_CTX)
@triton.jit
def _bwd_kernel_kv(
    Q,
    K,
    V,
    S,
    DO,
    DQ,
    DK,
    DV,
    stride_qz,
    stride_qh,
    stride_qm,
    stride_qk,
    stride_kz,
    stride_kh,
    stride_kn,
    stride_kk,
    stride_vz,
    stride_vh,
    stride_vn,
    stride_ve,
    stride_oz,
    stride_oh,
    stride_om,
    stride_oe,
    stride_sh,
    Z,
    H,
    N_CTX,
    num_block,
    BLOCK_M: tl.constexpr,
    BLOCK_DMODEL_QK: tl.constexpr,
    BLOCK_N: tl.constexpr,
    BLOCK_DMODEL_V: tl.constexpr,
    CAUSAL: tl.constexpr,
    USE_DECAY: tl.constexpr,
):
    """Compute dK and dV for one BLOCK_N-row key/value tile of one (batch, head).

    With P = Q K^T * M (M as in the forward kernel):
        dV = P^T dO,   dK = (dO V^T * M)^T Q.
    The tile's K and V are loaded once. The kernel then walks the query
    tiles; in causal mode it starts at query tile ``start_n * BLOCK_M``,
    which is the diagonal tile only when BLOCK_M == BLOCK_N. ``num_block``
    query tiles of BLOCK_M rows are assumed to cover the sequence. DQ is not
    touched here (see _bwd_kernel_q).
    """
    start_n = tl.program_id(0)
    off_hz = tl.program_id(1)
    off_z = off_hz // H
    off_h = off_hz % H
    # offset pointers for batch/head
    Q += off_z * stride_qz + off_h * stride_qh
    K += off_z * stride_kz + off_h * stride_kh
    V += off_z * stride_vz + off_h * stride_vh
    DO += off_z * stride_oz + off_h * stride_oh
    DK += off_z * stride_kz + off_h * stride_kh
    DV += off_z * stride_vz + off_h * stride_vh
    # first query row that can see this key tile
    if CAUSAL:
        lo = start_n * BLOCK_M
    else:
        lo = 0
    # sequence offsets
    offs_qm = lo + tl.arange(0, BLOCK_M)
    offs_kvn = start_n * BLOCK_N + tl.arange(0, BLOCK_N)
    # feature offsets
    offs_qkk = tl.arange(0, BLOCK_DMODEL_QK)
    offs_ve = tl.arange(0, BLOCK_DMODEL_V)
    # row offsets within a query tile
    offs_m = tl.arange(0, BLOCK_M)
    q_ptrs = Q + (offs_qm[:, None] * stride_qm + offs_qkk[None, :] * stride_qk)
    k_ptrs = K + (offs_kvn[:, None] * stride_kn + offs_qkk[None, :] * stride_kk)
    v_ptrs = V + (offs_kvn[:, None] * stride_vn + offs_ve[None, :] * stride_ve)
    do_ptrs = DO + (offs_qm[:, None] * stride_om + offs_ve[None, :] * stride_oe)
    # float32 accumulators for dv and dk
    dv = tl.zeros([BLOCK_N, BLOCK_DMODEL_V], dtype=tl.float32)
    dk = tl.zeros([BLOCK_N, BLOCK_DMODEL_QK], dtype=tl.float32)
    # this program's k and v tiles are loaded once and reused for every query tile
    k = tl.load(k_ptrs, mask=offs_kvn[:, None] < N_CTX, other=0.0)
    v = tl.load(v_ptrs, mask=offs_kvn[:, None] < N_CTX, other=0.0)
    if CAUSAL:
        if USE_DECAY:
            # per-head decay rate; loop-invariant, so load it once
            slope = tl.load(S + off_h * stride_sh)
    for start_m in range(lo, num_block * BLOCK_M, BLOCK_M):
        offs_m_curr = start_m + offs_m
        q = tl.load(q_ptrs, mask=offs_m_curr[:, None] < N_CTX, other=0.0)
        qk = tl.dot(q, tl.trans(k))
        if CAUSAL:
            # query position minus key position; negative means a future key
            index = offs_m_curr[:, None] - offs_kvn[None, :]
            if USE_DECAY:
                # Mask on the sign of index itself: testing s * index (as
                # before) let future keys through whenever the slope was 0.
                decay = tl.exp(
                    tl.where(index >= 0, -(slope * index), float("-inf"))
                )
                qk = qk * decay
            else:
                qk = tl.where(index >= 0, qk, 0)
        do = tl.load(do_ptrs, mask=offs_m_curr[:, None] < N_CTX, other=0.0)
        # dV += P^T dO
        dv += tl.dot(tl.trans(qk.to(do.dtype)), do)
        # dP = dO V^T, masked/decayed exactly like P
        dp = tl.dot(do, tl.trans(v).to(do.dtype))
        if CAUSAL:
            if USE_DECAY:
                dp = dp * decay
            else:
                dp = tl.where(index >= 0, dp, 0)
        # dK += dP^T Q
        dk += tl.dot(tl.trans(dp.to(q.dtype)), q).to(tl.float32)
        q_ptrs += BLOCK_M * stride_qm
        do_ptrs += BLOCK_M * stride_om
    dv_ptrs = DV + (offs_kvn[:, None] * stride_vn + offs_ve[None, :] * stride_ve)
    dk_ptrs = DK + (offs_kvn[:, None] * stride_kn + offs_qkk[None, :] * stride_kk)
    tl.store(dv_ptrs, dv, mask=offs_kvn[:, None] < N_CTX)
    tl.store(dk_ptrs, dk, mask=offs_kvn[:, None] < N_CTX)
@triton.jit
def _bwd_kernel_q(
    Q,
    K,
    V,
    S,
    DO,
    DQ,
    DK,
    DV,
    stride_qz,
    stride_qh,
    stride_qm,
    stride_qk,
    stride_kz,
    stride_kh,
    stride_kn,
    stride_kk,
    stride_vz,
    stride_vh,
    stride_vn,
    stride_ve,
    stride_oz,
    stride_oh,
    stride_om,
    stride_oe,
    stride_sh,
    Z,
    H,
    N_CTX,
    num_block,
    BLOCK_M: tl.constexpr,
    BLOCK_DMODEL_QK: tl.constexpr,
    BLOCK_N: tl.constexpr,
    BLOCK_DMODEL_V: tl.constexpr,
    CAUSAL: tl.constexpr,
    USE_DECAY: tl.constexpr,
):
    """Compute dQ = (dO V^T * M) K for one BLOCK_M-row query tile of one (batch, head).

    M is the forward kernel's mask/decay. In causal mode only key positions
    below ``(start_m + 1) * BLOCK_M`` are visited, mirroring the forward
    kernel. Later key tiles are entirely above the diagonal and contribute
    zero. ``num_block`` is accepted for interface compatibility but is not
    needed.
    """
    start_m = tl.program_id(0)
    off_hz = tl.program_id(1)
    off_z = off_hz // H
    off_h = off_hz % H
    # offset pointers for batch/head
    K += off_z * stride_kz + off_h * stride_kh
    V += off_z * stride_vz + off_h * stride_vh
    DO += off_z * stride_oz + off_h * stride_oh
    DQ += off_z * stride_qz + off_h * stride_qh
    # feature offsets
    offs_qkk = tl.arange(0, BLOCK_DMODEL_QK)
    offs_ve = tl.arange(0, BLOCK_DMODEL_V)
    # sequence offsets of this query tile
    offs_qm = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
    do_ptrs = DO + (offs_qm[:, None] * stride_om + offs_ve[None, :] * stride_oe)
    dq_ptrs = DQ + (offs_qm[:, None] * stride_qm + offs_qkk[None, :] * stride_qk)
    do = tl.load(do_ptrs, mask=offs_qm[:, None] < N_CTX, other=0.0)
    dq = tl.zeros([BLOCK_M, BLOCK_DMODEL_QK], dtype=tl.float32)
    if CAUSAL:
        if USE_DECAY:
            # per-head decay rate; loop-invariant, so load it once
            slope = tl.load(S + off_h * stride_sh)
    # causal: keys beyond the last row of this query tile are never visible
    hi = (start_m + 1) * BLOCK_M if CAUSAL else N_CTX
    for start_n in range(0, hi, BLOCK_N):
        offs_kvn = start_n + tl.arange(0, BLOCK_N)
        kv_mask = offs_kvn[:, None] < N_CTX
        k = tl.load(
            K + (offs_kvn[:, None] * stride_kn + offs_qkk[None, :] * stride_kk),
            mask=kv_mask,
            other=0.0,
        )
        v = tl.load(
            V + (offs_kvn[:, None] * stride_vn + offs_ve[None, :] * stride_ve),
            mask=kv_mask,
            other=0.0,
        )
        # dP = dO V^T
        dp = tl.dot(do, tl.trans(v).to(do.dtype))
        if CAUSAL:
            # query position minus key position; negative means a future key
            index = offs_qm[:, None] - offs_kvn[None, :]
            if USE_DECAY:
                # Mask on the sign of index itself: testing s * index (as
                # before) let future keys through whenever the slope was 0.
                decay = tl.where(index >= 0, -(slope * index), float("-inf"))
                dp = dp * tl.exp(decay)
            else:
                dp = tl.where(index >= 0, dp, 0)
        # dQ += dP K
        dq += tl.dot(dp.to(k.dtype), k)
    tl.store(dq_ptrs, dq, mask=offs_qm[:, None] < N_CTX)
class _attention(torch.autograd.Function):
    """Autograd bridge between PyTorch and the lightning-attention Triton kernels.

    Layout (fixed by the launch code below):
    - q and k are (Z, H, N, D) and v is (Z, H, N, E).
    - The output is (Z, H, N, E) in q's dtype.
    - ``s`` is indexed by head along dim 0 (``S[h]``) and supplies the
      per-head decay rate.
    - A tensor with ``s.shape[0] == 0`` disables decay, and decay is only
      applied when ``causal`` is true.
    - No gradient is returned for ``causal`` or ``s``.

    NOTE: D and E are used as ``tl.arange`` extents in the kernels, which
    Triton requires to be powers of two.
    """

    @staticmethod
    def forward(ctx, q, k, v, causal, s):
        q = q.contiguous()
        k = k.contiguous()
        v = v.contiguous()
        s = s.contiguous()
        # Only Ampere (sm_80) and newer are supported. Query the device that
        # actually holds the inputs, not whatever the current device is.
        capability = torch.cuda.get_device_capability(q.device)
        if capability[0] < 8:
            raise RuntimeError(
                "Lightning attention currently only supported for compute capability >= 80"
            )
        Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1]
        if Lq != Lk:
            # the kernel reads q with k's feature width, so a mismatch would
            # silently read the wrong elements
            raise ValueError(
                f"q and k must share the last dimension, got {Lq} and {Lk}"
            )
        Z, H, N_CTX = q.shape[0], q.shape[1], q.shape[2]
        o = torch.empty((Z, H, N_CTX, Lv), dtype=q.dtype, device=q.device)
        BLOCK_M = 128
        BLOCK_N = 64
        num_warps = 4 if Lk <= 64 else 8
        num_stages = 1
        grid = (triton.cdiv(N_CTX, BLOCK_M), Z * H, 1)
        use_decay = s.shape[0] > 0
        # four strides each for q, k, v and o, in the kernel's argument order
        strides = [t.stride(i) for t in (q, k, v, o) for i in range(4)]
        _fwd_kernel[grid](
            q,
            k,
            v,
            o,
            s,
            *strides,
            s.stride(0),
            Z,
            H,
            N_CTX,
            BLOCK_M=BLOCK_M,
            BLOCK_DMODEL_QK=Lk,
            BLOCK_N=BLOCK_N,
            BLOCK_DMODEL_V=Lv,
            IS_CAUSAL=causal,
            USE_DECAY=use_decay,
            num_warps=num_warps,
            num_stages=num_stages,
        )
        ctx.save_for_backward(q, k, v, s)
        ctx.BLOCK_DMODEL_QK = Lk
        ctx.BLOCK_DMODEL_V = Lv
        ctx.causal = causal
        ctx.use_decay = use_decay
        return o

    @staticmethod
    def backward(ctx, do):
        q, k, v, s = ctx.saved_tensors
        # _bwd_kernel_kv begins its causal loop at query row start_n * BLOCK_M,
        # which is the diagonal tile only when BLOCK_M == BLOCK_N.
        BLOCK_M = 32
        BLOCK_N = 32
        num_warps = 4
        num_stages = 1
        do = do.contiguous()
        # dq is accumulated in float32 and cast back to q's dtype on return
        dq = torch.zeros_like(q, dtype=torch.float32)
        dk = torch.empty_like(k)
        dv = torch.empty_like(v)
        Z, H, N_CTX = q.shape[0], q.shape[1], q.shape[2]
        # both backward kernels share the same leading arguments
        args = (
            q,
            k,
            v,
            s,
            do,
            dq,
            dk,
            dv,
            *[t.stride(i) for t in (q, k, v, do) for i in range(4)],
            s.stride(0),
            Z,
            H,
            N_CTX,
        )
        meta = dict(
            BLOCK_M=BLOCK_M,
            BLOCK_DMODEL_QK=ctx.BLOCK_DMODEL_QK,
            BLOCK_N=BLOCK_N,
            BLOCK_DMODEL_V=ctx.BLOCK_DMODEL_V,
            CAUSAL=ctx.causal,
            USE_DECAY=ctx.use_decay,
            num_warps=num_warps,
            num_stages=num_stages,
        )
        grid_kv = (triton.cdiv(k.shape[2], BLOCK_N), k.shape[0] * k.shape[1], 1)
        _bwd_kernel_kv[grid_kv](*args, grid_kv[0], **meta)
        grid_q = (triton.cdiv(N_CTX, BLOCK_M), Z * H, 1)
        _bwd_kernel_q[grid_q](*args, grid_q[0], **meta)
        return dq.to(q.dtype), dk, dv, None, None


# Functional entry point: attention(q, k, v, causal, s) -> o.
attention = _attention.apply
def lightning_attention(q, k, v, causal, ed):
    """Run ``attention`` over chunks of the q/k feature dimension and sum the results.

    The kernel applies its mask/decay elementwise to Q K^T and then
    multiplies by V. Q K^T is the sum of Q[..., c] K[..., c]^T over disjoint
    feature chunks c, so summing the per-chunk outputs gives the full
    result. Chunks are 128 wide when d >= 128 and 64 wide otherwise; the
    last chunk holds any remainder.

    Args:
        q, k: query/key tensors sharing the last (feature) dimension d.
        v: value tensor, passed unchanged to every chunk.
        causal: forwarded to ``attention``.
        ed: decay tensor forwarded to ``attention`` as ``s``
            (``ed.shape[0] == 0`` disables decay).

    Returns:
        The sum of the per-chunk outputs (the integer 0 when d == 0).

    NOTE(review): a remainder chunk whose width is not a power of two
    (e.g. 36 for d == 100) becomes a ``tl.arange`` extent in the kernels
    and will likely fail to compile; confirm supported head sizes.
    """
    d = q.shape[-1]
    chunk = 128 if d >= 128 else 64
    # chunk boundaries 0, chunk, 2 * chunk, ..., always ending exactly at d
    bounds = [*range(0, d, chunk), d]
    output = 0
    for start, end in zip(bounds, bounds[1:]):
        output = output + attention(
            q[..., start:end], k[..., start:end], v, causal, ed
        )
    return output