File size: 16,055 Bytes
0743270 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 |
# Copyright (c) 2023, Tri Dao.
# Copyright 2024 CATIE. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Modification to the original version from Tri Dao:
# - support for torch.compile
from typing import Tuple, Optional
import torch
import torch.nn.functional as F
import triton
import triton.language as tl
# `all_gather_into_tensor` (and `reduce_scatter_tensor`) are the newer public names for
# `_all_gather_base` (and `_reduce_scatter_base`). Only `all_gather_into_tensor` is used in
# this file, so only that one is aliased for older PyTorch versions. The alias is skipped
# when torch.distributed is unavailable, so importing this module never touches
# `_all_gather_base` on builds without distributed support.
if torch.distributed.is_available() and not hasattr(torch.distributed, "all_gather_into_tensor"):
    torch.distributed.all_gather_into_tensor = torch.distributed._all_gather_base
# Triton forward kernel for cross-entropy with optional label smoothing and z-loss.
# One program handles one row: row_idx = tl.program_id(0).
#
# With x = logits[row] * logit_scale and y = labels[row] - class_start_idx it writes:
#   - lse_ptr[row]: log-sum-exp of x, computed with an online softmax over
#     BLOCK_SIZE-wide column chunks (or read from lse_ptr when PRECOMPUTED_LSE).
#   - loss_ptr[row]:
#       lse - x[y]                                                        (no smoothing)
#       lse - smoothing * sum(x) / total_classes - (1 - smoothing) * x[y] (smoothing)
#     If y is outside [0, n_cols) (label belongs to another rank's class range), the
#     x[y] term is dropped: smoothing * (lse - sum(x) / total_classes), or 0.0.
#     When not SPLIT, lse_square_scale * lse * lse (z-loss) is added to the loss.
#   - z_loss_ptr[row]: the z-loss term; only stored when not SPLIT.
# With SPLIT the lse term is left out of the loss, since only a partial LSE is known here.
# Rows whose label equals ignore_index get loss 0.0.
# HAS_SMOOTHING is derived by the heuristic below from `smoothing > 0.0`.
@triton.heuristics(
    {
        "HAS_SMOOTHING": lambda args: args["smoothing"] > 0.0,
    }
)
@triton.jit
def cross_entropy_fwd_kernel(
    loss_ptr,  # data ptrs
    lse_ptr,
    z_loss_ptr,
    logits_ptr,
    labels_ptr,
    smoothing,
    logit_scale,
    lse_square_scale,
    ignore_index,
    total_classes,
    class_start_idx,  # Useful for tensor parallel when each rank only has a subset of classes
    n_cols,  # shapes
    logits_row_stride,  # strides
    BLOCK_SIZE: tl.constexpr,
    HAS_SMOOTHING: tl.constexpr,
    # if SPLIT (e.g. tensor parallel), don't include the LSE in the loss since it's not the final LSE
    SPLIT: tl.constexpr,
    PRECOMPUTED_LSE: tl.constexpr,  # If LSE is already computed (also no smoothing and logit_scale == 1.0)
):
    row_idx = tl.program_id(0)
    # Row offset computed in int64 (presumably so row_idx * stride cannot overflow int32
    # on large inputs). Columns are addressed as logits_ptr + col, i.e. unit column stride.
    logits_ptr = logits_ptr + row_idx * logits_row_stride.to(tl.int64)
    sum_logits = 0.0  # For smoothing; only accumulated on the non-precomputed path below
    if not PRECOMPUTED_LSE:
        # Statistics for online softmax: running max m_i and running sum l_i of exp(x - m_i)
        m_i = -float("inf")
        l_i = 0.0
        for col_offset in range(0, n_cols, BLOCK_SIZE):
            cols = col_offset + tl.arange(0, BLOCK_SIZE)
            # Out-of-range lanes load -inf, so they add exp(-inf) = 0 to l_i and never win the max.
            logits = tl.load(logits_ptr + cols, mask=cols < n_cols, other=-float("inf")).to(
                tl.float32
            ) * logit_scale
            if HAS_SMOOTHING:
                # Exclude the -inf padding lanes from the plain sum.
                sum_logits += tl.sum(tl.where(cols < n_cols, logits, 0.0))
            m_i_new = tl.maximum(m_i, tl.max(logits))
            # Rescale the previous partial sum to the new max before adding this chunk.
            l_i = tl.exp(m_i - m_i_new) * l_i + tl.sum(tl.exp(logits - m_i_new))
            m_i = m_i_new
        lse = tl.log(l_i) + m_i
        tl.store(lse_ptr + row_idx, lse)
    else:
        lse = tl.load(lse_ptr + row_idx)
    label_idx = tl.load(labels_ptr + row_idx)
    if label_idx == ignore_index:
        loss = 0.0
        z_loss = 0.0
    else:
        # Convert the global class index into an index local to this rank's slice of classes.
        label_idx -= class_start_idx
        if label_idx >= 0 and label_idx < n_cols:
            logits_label = tl.load(logits_ptr + label_idx) * logit_scale
            if HAS_SMOOTHING:
                loss = (
                    (lse if not SPLIT else 0.0)
                    - smoothing * sum_logits / total_classes
                    - (1 - smoothing) * logits_label
                )
            else:
                loss = (lse if not SPLIT else 0.0) - logits_label
        else:
            # If label is out of bounds, we set the CE loss to 0.0. But we still want the smoothing loss
            if HAS_SMOOTHING:
                loss = smoothing * ((lse if not SPLIT else 0.0) - sum_logits / total_classes)
            else:
                loss = 0.0
        if not SPLIT:
            # z-loss: lse_square_scale * lse^2, added to the loss and also reported separately.
            z_loss = lse_square_scale * lse * lse
            loss += z_loss
        else:
            z_loss = 0.0
    tl.store(loss_ptr + row_idx, loss)
    if not SPLIT:
        tl.store(z_loss_ptr + row_idx, z_loss)
# Triton backward kernel for cross-entropy. Grid axis 0 is the row and axis 1 the
# BLOCK_SIZE-wide column block (row_idx = program_id(0), col_block_idx = program_id(1)).
#
# With x = logits[row] * logit_scale, p = exp(x - lse[row]) and
# y = labels[row] - class_start_idx, it stores for each column c of the block:
#   dlogits[row, c] = dloss[row] * logit_scale * (p_c * (1 + 2 * lse_square_scale * lse) - t_c)
# where t_c = [c == y] without smoothing, and
#       t_c = (1 - smoothing) * [c == y] + smoothing / total_classes with smoothing.
# The 2 * lse_square_scale * lse * p term is the gradient of the z-loss lse_square_scale * lse^2.
# Rows whose label equals ignore_index are multiplied by dloss = 0.0.
# HAS_SMOOTHING is derived by the heuristic below from `smoothing > 0.0`.
@triton.heuristics(
    {
        "HAS_SMOOTHING": lambda args: args["smoothing"] > 0.0,
    }
)
@triton.jit
def cross_entropy_bwd_kernel(
    dlogits_ptr,  # data ptrs
    dloss_ptr,
    logits_ptr,
    lse_ptr,
    labels_ptr,
    smoothing,
    logit_scale,
    lse_square_scale,
    ignore_index,
    total_classes,
    class_start_idx,  # Useful for tensor parallel when each rank only has a subset of classes
    n_cols,  # shapes
    logits_row_stride,  # strides
    dlogits_row_stride,
    dloss_row_stride,
    BLOCK_SIZE: tl.constexpr,
    HAS_SMOOTHING: tl.constexpr,
):
    row_idx = tl.program_id(0)
    col_block_idx = tl.program_id(1)
    # Row offsets in int64; columns are addressed with unit stride for both tensors.
    logits_ptr = logits_ptr + row_idx * logits_row_stride.to(tl.int64)
    dlogits_ptr = dlogits_ptr + row_idx * dlogits_row_stride.to(tl.int64)
    col_offsets = col_block_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    label_idx = tl.load(labels_ptr + row_idx)
    if label_idx != ignore_index:
        dloss = tl.load(dloss_ptr + row_idx * dloss_row_stride)
    else:
        dloss = 0.0
    # Out-of-range lanes load -inf, giving p = 0; they are also masked out of the store below.
    logits = tl.load(logits_ptr + col_offsets, mask=col_offsets < n_cols, other=-float("inf")).to(
        tl.float32
    ) * logit_scale
    lse = tl.load(lse_ptr + row_idx)
    probs = tl.exp(logits - lse)
    # Add the z-loss gradient: d(lse_square_scale * lse^2)/dx = 2 * lse_square_scale * lse * p.
    probs += 2.0 * lse_square_scale * lse * probs
    # Convert the global class index into an index local to this rank's slice of classes.
    label_idx -= class_start_idx
    if HAS_SMOOTHING:
        smooth_positive = 1.0 - smoothing
        smooth_negative = smoothing / total_classes
        probs = tl.where(col_offsets == label_idx, probs - smooth_positive, probs) - smooth_negative
    else:
        probs = tl.where(col_offsets == label_idx, probs - 1.0, probs)
    # Chain rule through x = logits * logit_scale contributes the extra logit_scale factor.
    tl.store(dlogits_ptr + col_offsets, (dloss * logit_scale) * probs, mask=col_offsets < n_cols)
@torch.library.custom_op("flasht5::cross_entropy_triton_fwd", mutates_args=(), device_types="cuda")
def cross_entropy_triton_fwd(
logits: torch.Tensor,
labels: torch.Tensor,
precomputed_lse: torch.Tensor,
use_precomputed_lse: bool,
split: bool,
smoothing: float,
logit_scale: float,
lse_square_scale: float,
ignore_index: int,
total_classes: int,
class_start_idx: int,
n_cols: int,
n_rows: int,
BLOCK_SIZE: int,
num_warps: int
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
if logits.stride(-1) != 1:
logits = logits.contiguous()
losses = torch.empty(n_rows, dtype=torch.float, device=logits.device)
if use_precomputed_lse:
assert precomputed_lse.shape == (n_rows,)
lse = precomputed_lse.contiguous()
else:
lse = torch.empty(n_rows, dtype=torch.float, device=logits.device)
z_losses = torch.empty(n_rows, dtype=torch.float, device=logits.device)
# Need this, otherwise Triton tries to launch from cuda:0 and we get
# ValueError: Pointer argument (at 0) cannot be accessed from Triton (cpu tensor?)
with torch.cuda.device(logits.device.index):
cross_entropy_fwd_kernel[(n_rows,)](
losses, # data ptrs
lse,
z_losses,
logits,
labels,
smoothing,
logit_scale,
lse_square_scale,
ignore_index,
total_classes,
class_start_idx,
n_cols, # shapes
logits.stride(0), # strides
BLOCK_SIZE=BLOCK_SIZE, # constants
SPLIT=split,
PRECOMPUTED_LSE=use_precomputed_lse,
num_warps=num_warps,
)
return losses, z_losses, lse
@torch.library.register_fake("flasht5::cross_entropy_triton_fwd")
def cross_entropy_triton_fwd_abstract(logits, labels, precomputed_lse, use_precomputed_lse, split, smoothing, logit_scale, lse_square_scale, ignore_index, total_classes, class_start_idx, n_cols, n_rows, BLOCK_SIZE, num_warps):
losses = torch.empty(n_rows, dtype=torch.float32, device=logits.device)
z_losses = torch.empty(n_rows, dtype=torch.float32, device=logits.device)
logsumexp = torch.empty(n_rows, dtype=torch.float32, device=logits.device)
return losses, z_losses, logsumexp
@torch.library.custom_op("flasht5::cross_entropy_triton_bwd", mutates_args={"logits"}, device_types="cuda")
def cross_entropy_triton_bwd(
dlosses: torch.Tensor,
logits: torch.Tensor,
lse: torch.Tensor,
labels: torch.Tensor,
inplace_backward: bool,
smoothing: float,
logit_scale: float,
lse_square_scale: float,
ignore_index: int,
total_classes: int,
class_start_idx: int,
n_cols: int,
n_rows: int,
BLOCK_SIZE: int,
num_warps: int
) -> torch.Tensor:
dlogits = logits if inplace_backward else torch.empty_like(logits)
grid = lambda META: (n_rows, triton.cdiv(n_cols, META["BLOCK_SIZE"])) # noqa
# Need this, otherwise Triton tries to launch from cuda:0 and we get
# ValueError: Pointer argument (at 0) cannot be accessed from Triton (cpu tensor?)
with torch.cuda.device(logits.device.index):
cross_entropy_bwd_kernel[grid](
dlogits, # data ptrs
dlosses,
logits,
lse,
labels,
smoothing,
logit_scale,
lse_square_scale,
ignore_index,
total_classes,
class_start_idx,
n_cols, # shapes
logits.stride(0), # strides
dlogits.stride(0),
dlosses.stride(0),
BLOCK_SIZE=BLOCK_SIZE, # constants
num_warps=num_warps,
)
return dlogits if not inplace_backward else None
@torch.library.register_fake("flasht5::cross_entropy_triton_bwd")
def cross_entropy_triton_bwd_abstract(dlosses, logits, lse, labels, inplace_backward, smoothing, logit_scale, lse_square_scale, ignore_index, total_classes, class_start_idx, n_cols, n_rows, BLOCK_SIZE, num_warps):
return torch.empty_like(logits)
class CrossEntropyLoss(torch.autograd.Function):
    """Autograd function wrapping the Triton cross-entropy forward/backward custom ops.

    ``forward`` returns ``(losses, z_losses)``, both float32 tensors of shape (n_rows,).
    ``z_losses`` is marked non-differentiable, and its incoming gradient is discarded.
    """

    @staticmethod
    def forward(
        ctx,
        logits,
        labels,
        precomputed_lse=None,
        smoothing=0.0,
        logit_scale=1.0,
        lse_square_scale=0.0,
        ignore_index=-100,
        inplace_backward=False,
        process_group=None,
    ):
        """Compute per-row losses for 2-D ``logits`` of shape (n_rows, n_cols).

        With a ``process_group`` (world size > 1), each rank holds ``n_cols`` of the
        ``world_size * n_cols`` classes. The partial losses are all-reduced, and the
        per-rank LSEs are combined with ``logsumexp``.
        """
        # For some reason Triton generates wrong code when labels has dtype long and its address
        # is not aligned to 16 bytes. The ld.global.b64 seems to load the wrong label index.
        if labels.dtype == torch.long and labels.data_ptr() % 16 != 0:
            labels = F.pad(labels, (0, 1))[..., :-1]
            assert labels.data_ptr() % 16 == 0
        n_rows, n_cols = logits.shape
        assert labels.shape == (n_rows,)
        # The kernels assume unit stride along the class dimension. Make the copy here, not
        # only inside the forward op, so that the tensor saved for backward has that layout
        # too. The backward kernel reads it and, with inplace_backward, overwrites it.
        if logits.stride(-1) != 1:
            logits = logits.contiguous()
        world_size = 1 if process_group is None else torch.distributed.get_world_size(process_group)
        total_classes = world_size * n_cols
        rank = 0 if process_group is None else torch.distributed.get_rank(process_group)
        class_start_idx = rank * n_cols
        split = world_size > 1
        # The kernel's precomputed-LSE path skips the pass over the logits (no smoothing sum,
        # no rescaling), so it is only taken for unsmoothed, unscaled logits.
        use_precomputed_lse = precomputed_lse is not None and logit_scale == 1.0 and smoothing == 0.0
        MAX_BLOCK_SIZE = 16 * 1024
        BLOCK_SIZE = min(triton.next_power_of_2(n_cols), MAX_BLOCK_SIZE)
        num_warps = (
            4
            if BLOCK_SIZE < 2048
            else (8 if BLOCK_SIZE < 8192 else (16 if BLOCK_SIZE < 128 * 1024 else 32))
        )
        losses, z_losses, lse = torch.ops.flasht5.cross_entropy_triton_fwd(
            logits, labels, precomputed_lse, use_precomputed_lse,
            split, smoothing, logit_scale, lse_square_scale,
            ignore_index, total_classes, class_start_idx,
            n_cols, n_rows, BLOCK_SIZE, num_warps,
        )
        if split:
            # Without smoothing, for labels in this partition's vocab, `losses` holds
            # -predicted logit; otherwise it holds 0.
            # With smoothing=0.1, for labels in this partition's vocab, `losses` holds
            # -0.9 * predicted logit - 0.1 * sum logit / total_classes; for labels outside
            # the partition it holds -0.1 * sum logit / total_classes.
            lse_allgather = torch.empty(world_size, n_rows, dtype=lse.dtype, device=lse.device)
            torch.distributed.all_gather_into_tensor(lse_allgather, lse, group=process_group)
            handle_losses = torch.distributed.all_reduce(
                losses, op=torch.distributed.ReduceOp.SUM, group=process_group, async_op=True
            )
            # Combine the per-rank LSEs while the async all-reduce of the losses is in flight.
            lse = torch.logsumexp(lse_allgather, dim=0)
            handle_losses.wait()
            # After the all-reduce, without smoothing the total loss is -predicted_logit, and
            # with smoothing=0.1 it is -0.9 * predicted_logit - 0.1 * sum logit / total_classes.
            # Either way, only the global lse still has to be added.
            losses += lse
            if lse_square_scale != 0.0:
                z_losses = lse_square_scale * lse.square()
                z_losses.masked_fill_(labels == ignore_index, 0.0)
                losses += z_losses
            else:
                z_losses = torch.zeros_like(losses)
            losses.masked_fill_(labels == ignore_index, 0.0)
        ctx.save_for_backward(logits, lse, labels)
        ctx.mark_non_differentiable(z_losses)
        ctx.smoothing = smoothing
        ctx.logit_scale = logit_scale
        ctx.lse_square_scale = lse_square_scale
        ctx.ignore_index = ignore_index
        ctx.total_classes = total_classes
        ctx.class_start_idx = class_start_idx
        ctx.inplace_backward = inplace_backward
        return losses, z_losses

    @staticmethod
    def backward(ctx, grad_losses, grad_z_losses):
        """Return the gradient w.r.t. ``logits``; all other inputs get ``None``."""
        del grad_z_losses  # z_losses are only for logging.
        logits, lse, labels = ctx.saved_tensors
        n_rows, n_cols = logits.shape
        BLOCK_SIZE = min(triton.next_power_of_2(n_cols), 4 * 1024)
        num_warps = 4 if BLOCK_SIZE < 2048 else (8 if BLOCK_SIZE < 8192 else 16)
        dlogits = torch.ops.flasht5.cross_entropy_triton_bwd(
            grad_losses, logits, lse, labels,
            ctx.inplace_backward, ctx.smoothing, ctx.logit_scale,
            ctx.lse_square_scale, ctx.ignore_index, ctx.total_classes,
            ctx.class_start_idx, n_cols, n_rows, BLOCK_SIZE, num_warps,
        )
        if ctx.inplace_backward:
            # The op wrote the gradient into the saved logits tensor itself.
            dlogits = logits
        # One entry per forward input: logits, labels, precomputed_lse, smoothing,
        # logit_scale, lse_square_scale, ignore_index, inplace_backward, process_group.
        return dlogits, None, None, None, None, None, None, None, None
def cross_entropy_loss(
    logits: torch.Tensor,
    labels: torch.Tensor,
    precomputed_lse: Optional[torch.Tensor] = None,
    label_smoothing: float = 0.0,
    logit_scale: float = 1.0,
    lse_square_scale: float = 0.0,
    ignore_index: int = -100,
    inplace_backward: bool = False,
    process_group=None,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Cross-entropy loss computed with Triton kernels.

    Arguments:
        logits: (..., vocab_size). All leading dimensions are flattened into rows.
        labels: (...,). Must have one element per logit row; flattened the same way.
        precomputed_lse: optional (n_rows,) log-sum-exp of the logits. It is only used when
            label_smoothing == 0.0 and logit_scale == 1.0; otherwise the LSE is recomputed.
        label_smoothing: float
        logit_scale: float. Multiply logits by this scale before calculating the loss.
        lse_square_scale: float. If > 0, we add lse_square_scale * lse(logits) ^ 2 to the loss.
            This is also referred to as "z-loss".
        ignore_index: int. If labels == ignore_index, the loss is set to 0.0.
        inplace_backward: bool. If True, we do the backward pass in-place by modifying the logits.
            This saves memory. If logits cannot be viewed as 2-D (or lack unit last-dim
            stride), a copy is made and that copy is overwritten instead.
        process_group: if not None, we're doing Tensor Parallel: each process is responsible for
            one part of the vocab. The loss will be aggregated across processes.
    Returns:
        losses: (n_rows,), float
        z_losses: (n_rows,), float
    """
    # reshape rather than view: labels are often a non-contiguous slice such as
    # `input_ids[:, 1:]`, which `.view(-1)` rejects. reshape still returns a view when possible.
    return CrossEntropyLoss.apply(
        logits.reshape(-1, logits.shape[-1]),
        labels.reshape(-1),
        precomputed_lse,
        label_smoothing,
        logit_scale,
        lse_square_scale,
        ignore_index,
        inplace_backward,
        process_group,
    )
|