# Origin: OpenNLPLab repository, commit bc6e7dd ("Upload codebase"), 5.76 kB.
# CREDITS: This comes almost as-is from the Triton layer norm tutorial
# https://github.com/openai/triton/blob/master/python/tutorials/05-layer-norm.py
# Copyright 2024 OpenNLPLab
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# coding=utf-8
import torch
import torch.nn.functional as F
import triton
import triton.language as tl
# fmt: off
@triton.jit
def srms_norm_fw(X, Y, V, stride, N, eps, BLOCK_SIZE_N: tl.constexpr):
    # fmt: on
    """Forward kernel of the simple RMS norm; one program handles one row.

    The program loads one row of ``X`` (``N`` valid entries, padded up to
    ``BLOCK_SIZE_N`` lanes) and converts it to float32. It computes
    ``inv_rms = 1 / sqrt(mean(x * x) + eps)`` and writes ``x * inv_rms`` to the
    same row of ``Y``. It also stores ``inv_rms`` in ``V[row]`` so the backward
    kernel can reuse it. ``X`` and ``Y`` are addressed with the same row
    ``stride``. There is no affine weight or bias.
    """
    row_idx = tl.program_id(0)
    offsets = tl.arange(0, BLOCK_SIZE_N)
    in_bounds = offsets < N
    row_offsets = row_idx * stride + offsets

    # Padding lanes are loaded as 0.0, so they contribute nothing to the sum.
    vals = tl.load(X + row_offsets, mask=in_bounds, other=0.0).to(tl.float32)
    mean_sq = tl.sum(vals * vals, axis=0) / N
    inv_rms = 1.0 / tl.sqrt(mean_sq + eps)

    # Per-row 1/rms, read back by the backward kernel.
    tl.store(V + row_idx, inv_rms)
    tl.store(Y + row_offsets, vals * inv_rms, mask=in_bounds)
# Backward pass (DX only: this norm has no affine weight/bias, so there is no DW/DB)
# fmt: off
@triton.jit
def srms_norm_bwd_dx_fused(
    DX, DY,
    X, V,
    stride, N,
    # META-parameters
    BLOCK_SIZE_N: tl.constexpr,
):
    # fmt: on
    """Backward kernel of the simple RMS norm; one program handles one row.

    For one row, with ``rstd = V[row]`` (saved by the forward kernel) and
    ``xhat = x * rstd``, it writes::

        dx = rstd * (dy - xhat * sum(xhat * dy) / N)

    ``DX``, ``DY`` and ``X`` are all addressed with the same row ``stride``.
    Lanes at or beyond ``N`` are masked out of loads, reductions and the store.
    """
    row_idx = tl.program_id(0)
    offsets = tl.arange(0, BLOCK_SIZE_N)
    in_bounds = offsets < N
    row_offsets = row_idx * stride + offsets

    x = tl.load(X + row_offsets, mask=in_bounds, other=0)
    grad_out = tl.load(DY + row_offsets, mask=in_bounds, other=0)
    inv_rms = tl.load(V + row_idx)

    # Zero the padding lanes so they drop out of the row reduction.
    x_hat = tl.where(in_bounds, x * inv_rms, 0.)
    grad_out = tl.where(in_bounds, grad_out, 0.)

    proj = tl.sum(x_hat * grad_out, axis=0) / N
    grad_in = (grad_out - x_hat * proj) * inv_rms

    tl.store(DX + row_offsets, grad_in, mask=in_bounds)
class _SrmsNorm(torch.autograd.Function):
    """Autograd wrapper around the simple RMS norm Triton kernels.

    The input is flattened to ``(M, N)`` with ``N = x.shape[-1]``. Each of the
    ``M`` rows is scaled by ``1 / sqrt(mean(row ** 2) + eps)``. There is no
    learnable weight or bias, so only ``x`` receives a gradient.
    """

    @staticmethod
    def forward(ctx, x, eps):
        """Normalize ``x`` over its last dimension.

        Args:
            x: input tensor; its last dimension is normalized.
            eps: stabilizer added to the mean of squares. For float16 inputs
                it is raised to at least 1.6e-5.

        Returns:
            A tensor with the same shape as ``x``.

        Raises:
            RuntimeError: if a single row needs more than 64 KiB
                (65536 bytes) of storage.
        """
        # catch eps being too small if the tensors are fp16
        if x.dtype == torch.float16:
            eps = max(eps, 1.6e-5)
        # The kernel addresses X and Y with one shared row stride. So work on
        # a contiguous (M, N) tensor and allocate the output from it. This also
        # avoids copying an uninitialized, non-contiguous output buffer.
        x_arg = x.reshape(-1, x.shape[-1]).contiguous()
        M, N = x_arg.shape
        y = torch.empty_like(x_arg)
        # per-row 1/rms, reused by the backward pass
        rstd = torch.empty((M,), dtype=torch.float32, device=x.device)
        # Less than 64KB per feature: enqueue fused kernel
        MAX_FUSED_SIZE = 65536 // x.element_size()
        BLOCK_SIZE_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(N))
        if N > BLOCK_SIZE_N:
            raise RuntimeError(
                "This layer norm doesn't support feature dim >= 64KB.")
        # heuristics for number of warps.
        num_warps = min(max(BLOCK_SIZE_N // 256, 1), 16)
        # fmt: off
        srms_norm_fw[(M,)](
            x_arg, y, rstd,
            x_arg.stride(0),
            N,
            eps,
            num_warps=num_warps,
            BLOCK_SIZE_N=BLOCK_SIZE_N,
        )
        # fmt: on
        ctx.save_for_backward(x, rstd)
        ctx.BLOCK_SIZE_N = BLOCK_SIZE_N
        ctx.num_warps = num_warps
        return y.reshape_as(x)

    @staticmethod
    def backward(
        ctx, dy
    ):  # pragma: no cover # this is covered, but called directly from C++
        """Return the gradient w.r.t. ``x`` and ``None`` for ``eps``.

        Uses the per-row ``rstd`` saved by ``forward``. For each row,
        ``dx = rstd * (dy - xhat * mean(xhat * dy))`` with ``xhat = x * rstd``.
        """
        x, rstd = ctx.saved_tensors
        N = x.shape[-1]
        # The kernel addresses X, DY and DX with a single row stride, so all
        # three must be contiguous (M, N) buffers. The saved ``x`` is the
        # caller's original tensor and may be a strided view (e.g. a column
        # slice) whose row stride would not match the contiguous dy/dx.
        x_arg = x.reshape(-1, N).contiguous()
        M = x_arg.shape[0]
        dy = dy.contiguous()
        # Check the tensor shapes and layouts
        # we suppose in the kernel that they have the same size and are contiguous
        assert (
            dy.numel() == x_arg.numel()
        ), "Something is wrong in the backward graph, possibly because of an inplace operation after the layernorm"
        # allocate output (contiguous, same shape as dy / the original x)
        dx = torch.empty_like(dy)
        # reuse the launch configuration chosen in forward
        # fmt: off
        srms_norm_bwd_dx_fused[(M,)](
            dx, dy, x_arg,
            rstd,
            x_arg.stride(0),
            N,
            BLOCK_SIZE_N=ctx.BLOCK_SIZE_N,
            num_warps=ctx.num_warps,
        )
        # fmt: on
        # one gradient per forward input: x, eps
        return dx, None
class SimpleRMSNorm(torch.nn.Module):
    """RMS normalization over the last dimension, with no learnable scale.

    Args:
        dim: size of the normalized (last) dimension. It is stored as
            ``self.dim`` but not used by ``forward``.
        eps: stabilizer added to the mean of squares before the square root.
            For float16 inputs the kernel wrapper raises it to at least 1.6e-5.
    """

    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.dim = dim
        self.eps = eps

    def forward(self, x):
        """Return ``x`` divided by its root-mean-square along the last dim."""
        normalized = _SrmsNorm.apply(x, self.eps)
        return normalized