linzheng committed
Commit d927f57 · verified · 1 Parent(s): d6425d9

Update model and kernels for training support

Files changed (5)
  1. eva.py +22 -17
  2. eva_agg_kernel.py +1349 -52
  3. eva_prep_kv_kernel.py +686 -26
  4. eva_pt_ref.py +1 -3
  5. modeling_evabyte.py +16 -196
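For orientation: this commit renames the Triton kernel entry points used by eva.py and passes the chunk-level mask as a separate argument. A minimal sketch of the updated call sites, pieced together from the added lines below (tensor shapes follow the shape comments in the diff; this is not verbatim source):

    # eva.py now imports the autograd-enabled wrappers
    from .eva_agg_kernel import eva_agg_func_triton
    from .eva_prep_kv_kernel import eva_prep_kv_func_triton

    # chunk-level RFA keys/values: [b, h, n // chunk_size, d]
    rfa_k, rfa_v = eva_prep_kv_func_triton(
        dump_k, dump_v,
        self.adaptive_mu_k, self.adaptive_phi,
        dump_rf_mask, self.head_dim_scaling, self.chunk_size,
    )

    # attention output: [b, h, n, d]; chunk_mask is the newly separated argument
    attn_output = eva_agg_func_triton(
        q, s_k, s_v,
        rfa_k, rfa_v,
        singleton_mask, chunk_mask,
        self.head_dim_scaling, self.window_size, self.chunks_per_window,
    )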
eva.py CHANGED
@@ -2,8 +2,8 @@ from typing import Dict, Optional, Tuple, List, Any, Union
2
  import torch
3
  from torch import nn
4
  import torch.nn.functional as F
5
- from .eva_agg_kernel import triton_eva_agg_fwd
6
- from .eva_prep_kv_kernel import triton_eva_prep_kv_fwd
7
  try:
8
  import triton
9
  USE_TRITON_IMPL = True
@@ -129,10 +129,10 @@ class EvaAttention(nn.Module):
129
  assert not output_attentions
130
  bsz, q_len, _ = hidden_states.size()
131
 
132
- if use_cache and past_key_value is None:
133
- raise ValueError
134
-
135
- assert isinstance(attention_mask, tuple)
136
 
137
  # infer the model's running mode
138
  is_prefilling = use_cache and past_key_value.get_seq_length(self.layer_idx) == 0
@@ -141,13 +141,16 @@ class EvaAttention(nn.Module):
141
  if is_prefilling:
142
  assert len(attention_mask) == 2
143
  window_mask, intra_chunk_mask = attention_mask
144
- chunk_dummpy_mask = None
145
  elif is_decoding:
146
  assert len(attention_mask) == 3
147
- window_mask, intra_chunk_mask, chunk_dummpy_mask = attention_mask
148
  else:
149
- window_mask, intra_chunk_mask = attention_mask
150
- chunk_dummpy_mask = None
151
 
152
  ############################################
153
  # compute q, k, v from hidden states
@@ -201,7 +204,7 @@ class EvaAttention(nn.Module):
201
  # k/v: [b, h, w, d]
202
  # rfa_k/rfa_v: [b, h, w//c, d]
203
  # 3. in forward inference; the seq_len is already divisible
204
- rfa_k, rfa_v = triton_eva_prep_kv_fwd(
205
  dump_k, dump_v,
206
  self.adaptive_mu_k, self.adaptive_phi,
207
  dump_rf_mask, self.head_dim_scaling, self.chunk_size
@@ -227,10 +230,11 @@ class EvaAttention(nn.Module):
227
  # q: [b, h, n, d]
228
  # k/v: [b, h, n, d]
229
  # rfa_k/rfa_v: [b, h, n // c, d]
230
- attn_output = triton_eva_agg_fwd(
231
  q, s_k, s_v,
232
  rfa_k, rfa_v,
233
- singleton_mask, self.head_dim_scaling, self.window_size, self.chunks_per_window
 
234
  )
235
  elif is_decoding:
236
  # 2. in decoding, the input shape is
@@ -258,8 +262,8 @@ class EvaAttention(nn.Module):
258
  agg_k = torch.cat([s_k, rfa_k[..., :num_windows_seen_so_far * self.chunks_per_window, :]], dim=-2)
259
  agg_v = torch.cat([s_v, rfa_v[..., :num_windows_seen_so_far * self.chunks_per_window, :]], dim=-2)
260
  if singleton_mask is not None:
261
- assert chunk_dummpy_mask is not None
262
- attn_mask = torch.cat([singleton_mask, chunk_dummpy_mask], dim=-1)
263
  else:
264
  attn_mask = singleton_mask
265
  else:
@@ -275,10 +279,11 @@ class EvaAttention(nn.Module):
275
  )
276
  else:
277
  # 3. in single-forward inference
278
- attn_output = triton_eva_agg_fwd(
279
  q, s_k, s_v,
280
  rfa_k, rfa_v,
281
- singleton_mask, self.head_dim_scaling, self.window_size, self.chunks_per_window
 
282
  )
283
  if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
284
  raise ValueError(
 
2
  import torch
3
  from torch import nn
4
  import torch.nn.functional as F
5
+ from .eva_agg_kernel import eva_agg_func_triton
6
+ from .eva_prep_kv_kernel import eva_prep_kv_func_triton
7
  try:
8
  import triton
9
  USE_TRITON_IMPL = True
 
129
  assert not output_attentions
130
  bsz, q_len, _ = hidden_states.size()
131
 
132
+ if use_cache:
133
+ if past_key_value is None:
134
+ raise ValueError
135
+ assert isinstance(attention_mask, tuple)
136
 
137
  # infer the model's running mode
138
  is_prefilling = use_cache and past_key_value.get_seq_length(self.layer_idx) == 0
 
141
  if is_prefilling:
142
  assert len(attention_mask) == 2
143
  window_mask, intra_chunk_mask = attention_mask
144
+ chunk_mask = None
145
  elif is_decoding:
146
  assert len(attention_mask) == 3
147
+ window_mask, intra_chunk_mask, chunk_mask = attention_mask
148
  else:
149
+ if attention_mask is not None:
150
+ assert isinstance(attention_mask, tuple) and len(attention_mask) == 3
151
+ window_mask, chunk_mask, intra_chunk_mask = attention_mask
152
+ else:
153
+ window_mask, chunk_mask, intra_chunk_mask = None, None, None
154
 
155
  ############################################
156
  # compute q, k, v from hidden states
 
204
  # k/v: [b, h, w, d]
205
  # rfa_k/rfa_v: [b, h, w//c, d]
206
  # 3. in forward inference; the seq_len is already divisible
207
+ rfa_k, rfa_v = eva_prep_kv_func_triton(
208
  dump_k, dump_v,
209
  self.adaptive_mu_k, self.adaptive_phi,
210
  dump_rf_mask, self.head_dim_scaling, self.chunk_size
 
230
  # q: [b, h, n, d]
231
  # k/v: [b, h, n, d]
232
  # rfa_k/rfa_v: [b, h, n // c, d]
233
+ attn_output = eva_agg_func_triton(
234
  q, s_k, s_v,
235
  rfa_k, rfa_v,
236
+ singleton_mask, chunk_mask,
237
+ self.head_dim_scaling, self.window_size, self.chunks_per_window
238
  )
239
  elif is_decoding:
240
  # 2. in decoding, the input shape is
 
262
  agg_k = torch.cat([s_k, rfa_k[..., :num_windows_seen_so_far * self.chunks_per_window, :]], dim=-2)
263
  agg_v = torch.cat([s_v, rfa_v[..., :num_windows_seen_so_far * self.chunks_per_window, :]], dim=-2)
264
  if singleton_mask is not None:
265
+ assert chunk_mask is not None
266
+ attn_mask = torch.cat([singleton_mask, chunk_mask], dim=-1)
267
  else:
268
  attn_mask = singleton_mask
269
  else:
 
279
  )
280
  else:
281
  # 3. in single-forward inference
282
+ attn_output = eva_agg_func_triton(
283
  q, s_k, s_v,
284
  rfa_k, rfa_v,
285
+ singleton_mask, chunk_mask,
286
+ self.head_dim_scaling, self.window_size, self.chunks_per_window
287
  )
288
  if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
289
  raise ValueError(
eva_agg_kernel.py CHANGED
@@ -4,15 +4,969 @@ import torch
4
  import triton
5
  import triton.language as tl
6
 
7
- # Disabling autotune for now, set num_warps=4 if headdim=64 and num_warps=8 if headdim=128
8
- # @triton.autotune(
9
- # configs=[
10
- # triton.Config({"BLOCK_M": 128, "BLOCK_N": 128}, num_warps=4, num_stages=1),
11
- # # This config has a race condition when EVEN_M == False, disabling it for now.
12
- # # triton.Config({"BLOCK_M": 64, "BLOCK_N": 64}, num_warps=4, num_stages=1),
13
- # ],
14
- # key=['CACHE_KEY_SEQLEN_Q', 'CACHE_KEY_SEQLEN_K', 'BIAS_TYPE', 'IS_CAUSAL', 'BLOCK_HEADDIM']
15
- # )
16
  @triton.heuristics(
17
  {
18
  "EVEN_M": lambda args: args["seqlen_q"] % args["BLOCK_M"] == 0,
@@ -30,23 +984,24 @@ def _fwd_eva_agg_kernel(
30
  RFA_K,
31
  RFA_V,
32
  WindowMask,
 
33
  Out,
 
34
  softmax_scale,
35
  stride_qb, stride_qh, stride_qm,
36
  stride_kb, stride_kh, stride_kn,
37
  stride_vb, stride_vh, stride_vn,
38
  stride_rfa_kb, stride_rfa_kh, stride_rfa_kc,
39
  stride_rfa_vb, stride_rfa_vh, stride_rfa_vc,
40
- stride_mb, stride_mm,
 
41
  stride_ob, stride_oh, stride_om,
 
42
  nheads,
43
  seqlen_q,
44
  seqlen_k,
45
  nchunks,
46
  headdim,
47
- CACHE_KEY_SEQLEN_Q, # TODO: why keeping this
48
- CACHE_KEY_SEQLEN_K, # TODO: why keeping this
49
- CACHE_KEY_NCHUNKS, # TODO: why keeping this
50
  CHUNKS_PER_WINDOW: tl.constexpr,
51
  WINDOW_SIZE: tl.constexpr,
52
  MASK_TYPE: tl.constexpr,
@@ -106,11 +1061,18 @@ def _fwd_eva_agg_kernel(
106
  qk_scale = softmax_scale
107
  qk_scale *= 1.4426950408889634 # log2(e)
108
  if MASK_TYPE == 1:
109
- m_ptrs = (
110
  WindowMask +
111
- off_b * stride_mb +
112
- (offs_m[:, None] * stride_mm + offs_n[None, :])
113
  )
 
114
  m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
115
  d_i = tl.zeros([BLOCK_M], dtype=tl.float32)
116
  acc_o = tl.zeros([BLOCK_M, BLOCK_HEADDIM], dtype=tl.float32)
@@ -181,32 +1143,37 @@ def _fwd_eva_agg_kernel(
181
 
182
  if MASK_TYPE == 1:
183
  if EVEN_M & EVEN_W:
184
- mask = tl.load(
185
- m_ptrs + start_n - start_idx_n
186
- ).to(tl.float32)
187
  else:
188
- mask = tl.load(
189
- m_ptrs + start_n - start_idx_n,
190
  mask=(offs_m[:, None] < seqlen_q)
191
  & ((start_n - start_idx_n + offs_n)[None, :] < WINDOW_SIZE),
192
- other=0.0,
193
- ).to(tl.float32)
194
  # Slightly faster to multiply the softmax_scale in the tl.exp below since the compiler
195
  # can then fuse the mult and add into an fma instruction. But if we have bias we need to
196
  # to multiply with softmax_scale here.
197
  # we assume mask already implies the causal masking
198
- qk = qk * qk_scale + mask
 
199
  m_ij = tl.maximum(tl.max(qk, 1), m_i)
200
- p = tl.exp2(qk - m_ij[:, None])
201
  else:
202
  qk += tl.where(offs_m[:, None] >= (start_n + offs_n)[None, :], 0, float("-inf"))
203
  m_ij = tl.maximum(tl.max(qk, 1) * qk_scale, m_i)
204
- p = tl.exp2(qk * qk_scale - m_ij[:, None])
205
 
206
  d_ij = tl.sum(p, 1)
207
 
208
  # scale acc_o
209
- prev_scale = tl.exp2(m_i - m_ij)
210
  # # -- update output accumulator --
211
  acc_o = acc_o * prev_scale[:, None]
212
  # update acc_o
@@ -278,13 +1245,37 @@ def _fwd_eva_agg_kernel(
278
  if not EVEN_C: # Need to mask out otherwise the softmax is wrong
279
  qk += tl.where((start_c + offs_c)[None, :] < nchunks, 0, float("-inf"))
280
 
281
- m_ij = tl.maximum(tl.max(qk, 1) * qk_scale, m_i)
282
- p = tl.exp2(qk * qk_scale - m_ij[:, None])
283
 
284
  d_ij = tl.sum(p, 1)
285
 
286
  # scale acc_o
287
- prev_scale = tl.exp2(m_i - m_ij)
288
  # # -- update output accumulator --
289
  acc_o = acc_o * prev_scale[:, None]
290
  # update acc_o
@@ -320,7 +1311,10 @@ def _fwd_eva_agg_kernel(
320
  d_i = d_i * prev_scale + d_ij
321
  m_i = m_ij
322
 
323
- # BUG: have to store and immediately load
324
  acc_o = acc_o / d_i[:, None]
325
  # TODO: understand why rematerialize offsets to save registers?
326
  start_m = tl.program_id(0)
@@ -353,8 +1347,30 @@ def _fwd_eva_agg_kernel(
353
  out_ptrs, acc_o,
354
  mask=(offs_m[:, None] < seqlen_q) & (offs_d[None, :] < headdim)
355
  )
356
 
357
- def triton_eva_agg_fwd(q, k, v, rfa_k, rfa_v, window_mask, softmax_scale, window_size, chunks_per_window):
358
  if rfa_k is None and rfa_v is None:
359
  empty_rfa_kv = 1
360
 
@@ -394,13 +1410,27 @@ def triton_eva_agg_fwd(q, k, v, rfa_k, rfa_v, window_mask, softmax_scale, window
394
  mask_type = 0
395
  if window_mask is not None:
396
  mask_type = 1
397
- assert window_mask.dtype == q.dtype, torch.float
398
  assert window_mask.is_cuda
399
  assert window_mask.dim() == 4
400
  assert window_mask.shape == (batch, 1, seqlen_q, window_size)
401
  if window_mask.stride(-1) != 1:
402
  window_mask = window_mask.contiguous()
403
- mask_strides = (
404
  (window_mask.stride(0), window_mask.stride(2))
405
  if mask_type == 1 else
406
  (0, 0)
@@ -416,20 +1446,16 @@ def triton_eva_agg_fwd(q, k, v, rfa_k, rfa_v, window_mask, softmax_scale, window
416
  if empty_rfa_kv == 0 else
417
  (0, 0, 0)
418
  )
419
- assert chunks_per_window > 0, "chunks_per_window must be greater than 0"
420
 
421
  o = torch.empty_like(q)
 
422
 
423
  BLOCK_HEADDIM = max(triton.next_power_of_2(head_dim), 16)
424
- if q.dtype == torch.float:
425
- BLOCK = 64
426
- else:
427
- BLOCK = 128
428
- num_warps = 4 if head_dim <= 64 else 8
429
- assert chunks_per_window >= BLOCK, "chunks_per_window must be greater than BLOCK"
430
- # WINDOW_MASK_TYPE:
431
- # - 0: regular causal mask, simply None
432
- # - 1: the shape must be B, 1, W, I, J
433
 
434
  grid = lambda META: (triton.cdiv(seqlen_q, META["BLOCK_M"]), batch * nheads)
435
  _fwd_eva_agg_kernel[grid](
@@ -439,31 +1465,302 @@ def triton_eva_agg_fwd(q, k, v, rfa_k, rfa_v, window_mask, softmax_scale, window
439
  rfa_k,
440
  rfa_v,
441
  window_mask,
 
442
  o,
 
443
  softmax_scale,
444
  q.stride(0), q.stride(1), q.stride(2),
445
  k.stride(0), k.stride(1), k.stride(2),
446
  v.stride(0), v.stride(1), v.stride(2),
447
  rfa_k_strides[0], rfa_k_strides[1], rfa_k_strides[2],
448
  rfa_v_strides[0], rfa_v_strides[1], rfa_v_strides[2],
449
- mask_strides[0], mask_strides[1],
 
450
  o.stride(0), o.stride(1), o.stride(2),
451
  nheads,
452
  seqlen_q,
453
  seqlen_k,
454
  nchunks,
455
  head_dim,
456
- seqlen_q // 32,
457
- seqlen_k // 32,
458
- nchunks // 32,
459
  chunks_per_window,
460
  window_size,
461
  mask_type,
462
  empty_rfa_kv,
463
  BLOCK_HEADDIM,
464
- BLOCK_M=BLOCK,
465
- BLOCK_N=BLOCK,
466
  num_warps=num_warps,
467
- num_stages=1,
468
  )
469
- return o
 
4
  import triton
5
  import triton.language as tl
6
 
7
+ @triton.heuristics(
8
+ {
9
+ "EVEN_M": lambda args: args["seqlen_q"] % args["BLOCK_M"] == 0,
10
+ "EVEN_N": lambda args: args["seqlen_k"] % args["BLOCK_N"] == 0,
11
+ "EVEN_W": lambda args: args["WINDOW_SIZE"] % args["BLOCK_N"] == 0,
12
+ "EVEN_HEADDIM": lambda args: args["headdim"] == args["BLOCK_HEADDIM"],
13
+ }
14
+ )
15
+ @triton.jit
16
+ def _bwd_eva_agg_kernel_dkdv(
17
+ Q,
18
+ K,
19
+ V,
20
+ WindowMask,
21
+ DO,
22
+ LSE,
23
+ DO_T_O,
24
+ DK,
25
+ DV,
26
+ softmax_scale,
27
+ stride_qb, stride_qh, stride_qm,
28
+ stride_kb, stride_kh, stride_kn,
29
+ stride_vb, stride_vh, stride_vn,
30
+ stride_window_mask_b, stride_window_mask_m,
31
+ stride_do_b, stride_do_h, stride_do_m,
32
+ stride_lse_b, stride_lse_h,
33
+ stride_do_t_o_b, stride_do_t_o_h,
34
+ stride_dk_b, stride_dk_h, stride_dk_n,
35
+ stride_dv_b, stride_dv_h, stride_dv_n,
36
+ nheads,
37
+ seqlen_q,
38
+ seqlen_k,
39
+ headdim,
40
+ WINDOW_SIZE: tl.constexpr,
41
+ MASK_TYPE: tl.constexpr,
42
+ BLOCK_HEADDIM: tl.constexpr,
43
+ EVEN_M: tl.constexpr,
44
+ EVEN_N: tl.constexpr,
45
+ EVEN_W: tl.constexpr,
46
+ EVEN_HEADDIM: tl.constexpr,
47
+ BLOCK_M: tl.constexpr,
48
+ BLOCK_N: tl.constexpr,
49
+ ):
50
+ off_bh = tl.program_id(1)
51
+ off_h = off_bh % nheads
52
+ off_b = off_bh // nheads
53
+
54
+ start_n = tl.program_id(0)
55
+ # determine which window the current KV block belongs to
56
+ offs_w = (start_n * BLOCK_N) // WINDOW_SIZE
57
+ offs_n = start_n * BLOCK_N + tl.arange(0, BLOCK_N)
58
+ offs_m = tl.arange(0, BLOCK_M)
59
+ offs_d = tl.arange(0, BLOCK_HEADDIM)
60
+
61
+ # initialize pointers
62
+ q_ptrs = (
63
+ Q +
64
+ off_b * stride_qb +
65
+ off_h * stride_qh +
66
+ offs_m[:, None] * stride_qm + offs_d[None, :]
67
+ )
68
+ k_ptrs = (
69
+ K +
70
+ off_b * stride_kb +
71
+ off_h * stride_kh +
72
+ offs_n[:, None] * stride_kn + offs_d[None, :]
73
+ )
74
+ v_ptrs = (
75
+ V +
76
+ off_b * stride_vb +
77
+ off_h * stride_vh +
78
+ offs_n[:, None] * stride_vn + offs_d[None, :]
79
+ )
80
+ do_ptrs = (
81
+ DO +
82
+ off_b * stride_do_b +
83
+ off_h * stride_do_h +
84
+ offs_m[:, None] * stride_do_m + offs_d[None, :]
85
+ )
86
+ do_t_o_ptrs = (
87
+ DO_T_O +
88
+ off_b * stride_do_t_o_b +
89
+ off_h * stride_do_t_o_h +
90
+ offs_m[:, None]
91
+ )
92
+ lse_ptrs = (
93
+ LSE +
94
+ off_b * stride_lse_b +
95
+ off_h * stride_lse_h +
96
+ offs_m[:, None]
97
+ )
98
+ if MASK_TYPE == 1:
99
+ m_ptrs = (
100
+ WindowMask +
101
+ off_b * stride_window_mask_b +
102
+ (offs_m[:, None] * stride_window_mask_m + offs_n[None, :])
103
+ )
104
+ dk_ptrs = (
105
+ DK +
106
+ off_b * stride_dk_b +
107
+ off_h * stride_dk_h +
108
+ offs_n[:, None] * stride_dk_n + offs_d[None, :]
109
+ )
110
+ dv_ptrs = (
111
+ DV +
112
+ off_b * stride_dv_b +
113
+ off_h * stride_dv_h +
114
+ offs_n[:, None] * stride_dv_n + offs_d[None, :]
115
+ )
116
+
117
+ # 1. for singletons
118
+ # determine start and end of query block
119
+ begin_m = ((start_n * BLOCK_N) // BLOCK_M) * BLOCK_M
120
+ end_m = tl.minimum((offs_w + 1) * WINDOW_SIZE, seqlen_q)
121
+
122
+ dk = tl.zeros([BLOCK_N, BLOCK_HEADDIM], dtype=tl.float32)
123
+ dv = tl.zeros([BLOCK_N, BLOCK_HEADDIM], dtype=tl.float32)
124
+ if EVEN_N & EVEN_M:
125
+ if EVEN_HEADDIM:
126
+ k = tl.load(k_ptrs)
127
+ v = tl.load(v_ptrs)
128
+ else:
129
+ k = tl.load(k_ptrs, mask=offs_d[None, :] < headdim, other=0.0)
130
+ v = tl.load(v_ptrs, mask=offs_d[None, :] < headdim, other=0.0)
131
+ else:
132
+ if EVEN_HEADDIM:
133
+ k = tl.load(k_ptrs, mask=offs_n[:, None] < seqlen_k, other=0.0)
134
+ v = tl.load(v_ptrs, mask=offs_n[:, None] < seqlen_k, other=0.0)
135
+ else:
136
+ k = tl.load(
137
+ k_ptrs, mask=(offs_n[:, None] < seqlen_k) & (offs_d[None, :] < headdim), other=0.0
138
+ )
139
+ v = tl.load(
140
+ v_ptrs, mask=(offs_n[:, None] < seqlen_k) & (offs_d[None, :] < headdim), other=0.0
141
+ )
142
+ for start_m in range(begin_m, end_m, BLOCK_M):
143
+ start_m = tl.multiple_of(start_m, BLOCK_M)
144
+ # load q, do, and lse
145
+ if EVEN_M & EVEN_N:
146
+ if EVEN_HEADDIM:
147
+ q = tl.load(
148
+ q_ptrs + start_m * stride_qm
149
+ )
150
+ do = tl.load(
151
+ do_ptrs + start_m * stride_do_m
152
+ )
153
+ else:
154
+ q = tl.load(
155
+ q_ptrs + start_m * stride_qm,
156
+ mask=offs_d[None, :] < headdim,
157
+ other=0.0
158
+ )
159
+ do = tl.load(
160
+ do_ptrs + start_m * stride_do_m,
161
+ mask=offs_d[None, :] < headdim,
162
+ other=0.0
163
+ )
164
+ do_t_o = tl.load(
165
+ do_t_o_ptrs + start_m
166
+ )
167
+ lse = tl.load(
168
+ lse_ptrs + start_m
169
+ )
170
+ else:
171
+ if EVEN_HEADDIM:
172
+ q = tl.load(
173
+ q_ptrs + start_m * stride_qm,
174
+ mask=(start_m + offs_m)[:, None] < seqlen_q,
175
+ other=0.0
176
+ )
177
+ do = tl.load(
178
+ do_ptrs + start_m * stride_do_m,
179
+ mask=(start_m + offs_m)[:, None] < seqlen_q,
180
+ other=0.0
181
+ )
182
+ else:
183
+ q = tl.load(
184
+ q_ptrs + start_m * stride_qm,
185
+ mask=((start_m + offs_m)[:, None] < seqlen_q) & (offs_d[None, :] < headdim),
186
+ other=0.0
187
+ )
188
+ do = tl.load(
189
+ do_ptrs + start_m * stride_do_m,
190
+ mask=((start_m + offs_m)[:, None] < seqlen_q) & (offs_d[None, :] < headdim),
191
+ other=0.0
192
+ )
193
+ do_t_o = tl.load(
194
+ do_t_o_ptrs + start_m,
195
+ mask=(start_m + offs_m)[:, None] < seqlen_q,
196
+ other=0.0
197
+ )
198
+ lse = tl.load(
199
+ lse_ptrs + start_m,
200
+ mask=(start_m + offs_m)[:, None] < seqlen_q,
201
+ other=0.0
202
+ )
203
+ lse = tl.where(lse == float("-inf"), 0.0, lse)
204
+ qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
205
+ qk += tl.dot(q, tl.trans(k))
206
+ if not EVEN_M:
207
+ qk += tl.where((start_m + offs_m)[:, None] < seqlen_q, 0, float("-inf"))
208
+
209
+ if MASK_TYPE == 1:
210
+ if EVEN_M & EVEN_W:
211
+ mask = tl.load(
212
+ m_ptrs + (start_m * stride_window_mask_m) - (offs_w * WINDOW_SIZE)
213
+ )
214
+ else:
215
+ mask = tl.load(
216
+ m_ptrs + (start_m * stride_window_mask_m) - (offs_w * WINDOW_SIZE),
217
+ mask=((start_m + offs_m)[:, None] < seqlen_q)
218
+ & (((start_m * stride_window_mask_m) - (offs_w * WINDOW_SIZE) + offs_n)[None, :] < WINDOW_SIZE),
219
+ other=1,
220
+ )
221
+ # Slightly faster to multiply the softmax_scale in the tl.exp below since the compiler
222
+ # can then fuse the mult and add into an fma instruction. But if we have bias we need to
223
+ # to multiply with softmax_scale here.
224
+ # we assume mask already implies the causal masking
225
+ qk = qk * softmax_scale
226
+ qk = tl.where(mask, float("-inf"), qk)
227
+ p = tl.exp(qk - lse)
228
+ else:
229
+ qk += tl.where((start_m + offs_m)[:, None] >= offs_n[None, :], 0, float("-inf"))
230
+ p = tl.exp(qk * softmax_scale - lse)
231
+
232
+ # dp [M, N]
233
+ dp = tl.dot(do, tl.trans(v))
234
+ # p [M, N], dp [M, N], do_t_o [M, 1] -> ds [M, N]
235
+ ds = (p * (dp - do_t_o) * softmax_scale).to(q.dtype)
236
+ # p is fp32 and [M, N], convert to q.dtype
237
+ # do [M, D] -> dv [N, D]
238
+ dv += tl.dot(tl.trans(p.to(do.dtype)), do)
239
+ # dk [N, D]
240
+ dk += tl.dot(tl.trans(ds), q)
241
+ if EVEN_N & EVEN_M:
242
+ if EVEN_HEADDIM:
243
+ tl.store(dv_ptrs, dv)
244
+ tl.store(dk_ptrs, dk)
245
+ else:
246
+ tl.store(dv_ptrs, dv, mask=offs_d[None, :] < headdim)
247
+ tl.store(dk_ptrs, dk, mask=offs_d[None, :] < headdim)
248
+ else:
249
+ if EVEN_HEADDIM:
250
+ tl.store(dv_ptrs, dv, mask=offs_n[:, None] < seqlen_k)
251
+ tl.store(dk_ptrs, dk, mask=offs_n[:, None] < seqlen_k)
252
+ else:
253
+ tl.store(dv_ptrs, dv, mask=(offs_n[:, None] < seqlen_k) & (offs_d[None, :] < headdim))
254
+ tl.store(dk_ptrs, dk, mask=(offs_n[:, None] < seqlen_k) & (offs_d[None, :] < headdim))
255
+
256
+ @triton.heuristics(
257
+ {
258
+ "EVEN_M": lambda args: args["seqlen_q"] % args["BLOCK_M"] == 0,
259
+ "EVEN_C": lambda args: args["nchunks"] % args["BLOCK_N"] == 0,
260
+ "EVEN_HEADDIM": lambda args: args["headdim"] == args["BLOCK_HEADDIM"],
261
+ }
262
+ )
263
+ @triton.jit
264
+ def _bwd_eva_agg_kernel_drfa_kv(
265
+ Q,
266
+ RFA_K,
267
+ RFA_V,
268
+ ChunkMask,
269
+ DO,
270
+ LSE,
271
+ DO_T_O,
272
+ D_RFA_K,
273
+ D_RFA_V,
274
+ softmax_scale,
275
+ stride_qb, stride_qh, stride_qm,
276
+ stride_rfa_kb, stride_rfa_kh, stride_rfa_kc,
277
+ stride_rfa_vb, stride_rfa_vh, stride_rfa_vc,
278
+ stride_chunk_mask_b, stride_chunk_mask_m,
279
+ stride_do_b, stride_do_h, stride_do_m,
280
+ stride_lse_b, stride_lse_h,
281
+ stride_do_t_o_b, stride_do_t_o_h,
282
+ stride_d_rfa_k_b, stride_d_rfa_k_h, stride_d_rfa_k_c,
283
+ stride_d_rfa_v_b, stride_d_rfa_v_h, stride_d_rfa_v_c,
284
+ nheads,
285
+ seqlen_q,
286
+ nchunks,
287
+ headdim,
288
+ CHUNKS_PER_WINDOW: tl.constexpr,
289
+ WINDOW_SIZE: tl.constexpr,
290
+ MASK_TYPE: tl.constexpr,
291
+ BLOCK_HEADDIM: tl.constexpr,
292
+ EVEN_M: tl.constexpr,
293
+ EVEN_C: tl.constexpr,
294
+ EVEN_HEADDIM: tl.constexpr,
295
+ BLOCK_M: tl.constexpr,
296
+ BLOCK_N: tl.constexpr,
297
+ ):
298
+ off_bh = tl.program_id(1)
299
+ off_h = off_bh % nheads
300
+ off_b = off_bh // nheads
301
+ start_c = tl.program_id(0)
302
+ # there are 128 chunks per window
303
+ offs_c = start_c * BLOCK_N + tl.arange(0, BLOCK_N)
304
+ # determine which window the current KV block belongs to
305
+ offs_w = (start_c * BLOCK_N) // CHUNKS_PER_WINDOW
306
+ offs_m = tl.arange(0, BLOCK_M)
307
+ offs_d = tl.arange(0, BLOCK_HEADDIM)
308
+
309
+ # initialize pointers
310
+ q_ptrs = (
311
+ Q +
312
+ off_b * stride_qb +
313
+ off_h * stride_qh +
314
+ (offs_m[:, None] * stride_qm + offs_d[None, :])
315
+ )
316
+ do_ptrs = (
317
+ DO +
318
+ off_b * stride_do_b +
319
+ off_h * stride_do_h +
320
+ (offs_m[:, None] * stride_do_m + offs_d[None, :])
321
+ )
322
+ do_t_o_ptrs = (
323
+ DO_T_O +
324
+ off_b * stride_do_t_o_b +
325
+ off_h * stride_do_t_o_h +
326
+ (offs_m[:, None])
327
+ )
328
+ lse_ptrs = (
329
+ LSE +
330
+ off_b * stride_lse_b +
331
+ off_h * stride_lse_h +
332
+ (offs_m[:, None])
333
+ )
334
+ rfa_k_ptrs = (
335
+ RFA_K +
336
+ off_b * stride_rfa_kb +
337
+ off_h * stride_rfa_kh +
338
+ (offs_c[:, None] * stride_rfa_kc + offs_d[None, :])
339
+ )
340
+ rfa_v_ptrs = (
341
+ RFA_V +
342
+ off_b * stride_rfa_vb +
343
+ off_h * stride_rfa_vh +
344
+ (offs_c[:, None] * stride_rfa_vc + offs_d[None, :])
345
+ )
346
+ if MASK_TYPE == 1:
347
+ rfa_m_ptrs = (
348
+ ChunkMask +
349
+ off_b * stride_chunk_mask_b +
350
+ (offs_m[:, None] * stride_chunk_mask_m + offs_c[None, :])
351
+ )
352
+ d_rfa_k_ptrs = (
353
+ D_RFA_K +
354
+ off_b * stride_d_rfa_k_b +
355
+ off_h * stride_d_rfa_k_h +
356
+ (offs_c[:, None] * stride_d_rfa_k_c + offs_d[None, :])
357
+ )
358
+ d_rfa_v_ptrs = (
359
+ D_RFA_V +
360
+ off_b * stride_d_rfa_v_b +
361
+ off_h * stride_d_rfa_v_h +
362
+ (offs_c[:, None] * stride_d_rfa_v_c + offs_d[None, :])
363
+ )
364
+
365
+ d_rfa_k = tl.zeros([BLOCK_N, BLOCK_HEADDIM], dtype=tl.float32)
366
+ d_rfa_v = tl.zeros([BLOCK_N, BLOCK_HEADDIM], dtype=tl.float32)
367
+ if EVEN_C & EVEN_M:
368
+ if EVEN_HEADDIM:
369
+ rfa_k = tl.load(rfa_k_ptrs)
370
+ rfa_v = tl.load(rfa_v_ptrs)
371
+ else:
372
+ rfa_k = tl.load(rfa_k_ptrs, mask=offs_d[None, :] < headdim, other=0.0)
373
+ rfa_v = tl.load(rfa_v_ptrs, mask=offs_d[None, :] < headdim, other=0.0)
374
+ else:
375
+ if EVEN_HEADDIM:
376
+ rfa_k = tl.load(rfa_k_ptrs, mask=offs_c[:, None] < nchunks, other=0.0)
377
+ rfa_v = tl.load(rfa_v_ptrs, mask=offs_c[:, None] < nchunks, other=0.0)
378
+ else:
379
+ rfa_k = tl.load(
380
+ rfa_k_ptrs, mask=(offs_c[:, None] < nchunks) & (offs_d[None, :] < headdim), other=0.0
381
+ )
382
+ rfa_v = tl.load(
383
+ rfa_v_ptrs, mask=(offs_c[:, None] < nchunks) & (offs_d[None, :] < headdim), other=0.0
384
+ )
385
+ begin_m = tl.minimum((offs_w + 1) * WINDOW_SIZE, seqlen_q)
386
+ end_m = seqlen_q
387
+ for start_m in range(begin_m, end_m, BLOCK_M):
388
+ start_m = tl.multiple_of(start_m, BLOCK_M)
389
+ # load q, do, and lse
390
+ if EVEN_M:
391
+ if EVEN_HEADDIM:
392
+ q = tl.load(
393
+ q_ptrs + start_m * stride_qm
394
+ )
395
+ do = tl.load(
396
+ do_ptrs + start_m * stride_do_m
397
+ )
398
+ else:
399
+ q = tl.load(
400
+ q_ptrs + start_m * stride_qm,
401
+ mask=offs_d[None, :] < headdim,
402
+ other=0.0
403
+ )
404
+ do = tl.load(
405
+ do_ptrs + start_m * stride_do_m,
406
+ mask=offs_d[None, :] < headdim,
407
+ other=0.0
408
+ )
409
+ do_t_o = tl.load(
410
+ do_t_o_ptrs + start_m
411
+ )
412
+ lse = tl.load(
413
+ lse_ptrs + start_m
414
+ )
415
+ else:
416
+ if EVEN_HEADDIM:
417
+ q = tl.load(
418
+ q_ptrs + start_m * stride_qm,
419
+ mask=(start_m + offs_m)[:, None] < seqlen_q,
420
+ other=0.0
421
+ )
422
+ do = tl.load(
423
+ do_ptrs + start_m * stride_do_m,
424
+ mask=(start_m + offs_m)[:, None] < seqlen_q,
425
+ other=0.0
426
+ )
427
+ else:
428
+ q = tl.load(
429
+ q_ptrs + start_m * stride_qm,
430
+ mask=((start_m + offs_m)[:, None] < seqlen_q) & (offs_d[None, :] < headdim),
431
+ other=0.0
432
+ )
433
+ do = tl.load(
434
+ do_ptrs + start_m * stride_do_m,
435
+ mask=((start_m + offs_m)[:, None] < seqlen_q) & (offs_d[None, :] < headdim),
436
+ other=0.0
437
+ )
438
+ do_t_o = tl.load(
439
+ do_t_o_ptrs + start_m,
440
+ mask=(start_m + offs_m)[:, None] < seqlen_q,
441
+ other=0.0
442
+ )
443
+ lse = tl.load(
444
+ lse_ptrs + start_m,
445
+ mask=(start_m + offs_m)[:, None] < seqlen_q,
446
+ other=0.0
447
+ )
448
+ lse = tl.where(lse == float("-inf"), 0.0, lse)
449
+ qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
450
+ qk += tl.dot(q, tl.trans(rfa_k))
451
+ if not EVEN_M:
452
+ qk += tl.where((start_m + offs_m)[:, None] < seqlen_q, 0, float("-inf"))
453
+
454
+ if MASK_TYPE == 1:
455
+ if EVEN_M & EVEN_C:
456
+ mask = tl.load(
457
+ rfa_m_ptrs + (start_m * stride_chunk_mask_m)
458
+ )
459
+ else:
460
+ mask = tl.load(
461
+ rfa_m_ptrs + (start_m * stride_chunk_mask_m),
462
+ mask=((start_m + offs_m)[:, None] < seqlen_q)
463
+ & (offs_c[None, :] < nchunks),
464
+ other=1,
465
+ )
466
+ # Slightly faster to multiply the softmax_scale in the tl.exp below since the compiler
467
+ # can then fuse the mult and add into an fma instruction. But if we have bias we need to
468
+ # to multiply with softmax_scale here.
469
+ # we assume mask already implies the causal masking
470
+ qk = qk * softmax_scale
471
+ qk = tl.where(mask, float("-inf"), qk)
472
+ p = tl.exp(qk - lse)
473
+ else:
474
+ p = tl.exp(qk * softmax_scale - lse)
475
+
476
+ dp = tl.dot(do, tl.trans(rfa_v))
477
+ ds = (p * (dp - do_t_o) * softmax_scale).to(q.dtype)
478
+ # p is fp32, convert to q.dtype
479
+ d_rfa_v += tl.dot(tl.trans(p.to(do.dtype)), do)
480
+ # move softmax_scale to ds to save computation
481
+ d_rfa_k += tl.dot(tl.trans(ds), q)
482
+ if EVEN_C & EVEN_M:
483
+ if EVEN_HEADDIM:
484
+ tl.store(d_rfa_v_ptrs, d_rfa_v)
485
+ tl.store(d_rfa_k_ptrs, d_rfa_k)
486
+ else:
487
+ tl.store(d_rfa_v_ptrs, d_rfa_v, mask=offs_d[None, :] < headdim)
488
+ tl.store(d_rfa_k_ptrs, d_rfa_k, mask=offs_d[None, :] < headdim)
489
+ else:
490
+ if EVEN_HEADDIM:
491
+ tl.store(d_rfa_v_ptrs, d_rfa_v, mask=offs_c[:, None] < nchunks)
492
+ tl.store(d_rfa_k_ptrs, d_rfa_k, mask=offs_c[:, None] < nchunks)
493
+ else:
494
+ tl.store(d_rfa_v_ptrs, d_rfa_v, mask=(offs_c[:, None] < nchunks) & (offs_d[None, :] < headdim))
495
+ tl.store(d_rfa_k_ptrs, d_rfa_k, mask=(offs_c[:, None] < nchunks) & (offs_d[None, :] < headdim))
496
+
497
+ @triton.heuristics(
498
+ {
499
+ "EVEN_M": lambda args: args["seqlen_q"] % args["BLOCK_M"] == 0,
500
+ "EVEN_N": lambda args: args["seqlen_k"] % args["BLOCK_N"] == 0,
501
+ "EVEN_C": lambda args: args["nchunks"] % args["BLOCK_N"] == 0,
502
+ "EVEN_W": lambda args: args["WINDOW_SIZE"] % args["BLOCK_N"] == 0,
503
+ "EVEN_HEADDIM": lambda args: args["headdim"] == args["BLOCK_HEADDIM"],
504
+ }
505
+ )
506
+ @triton.jit
507
+ def _bwd_eva_agg_kernel_dq(
508
+ Q,
509
+ K,
510
+ V,
511
+ RFA_K,
512
+ RFA_V,
513
+ WindowMask,
514
+ ChunkMask,
515
+ DO,
516
+ LSE,
517
+ DO_T_O,
518
+ DQ,
519
+ softmax_scale,
520
+ stride_qb, stride_qh, stride_qm,
521
+ stride_kb, stride_kh, stride_kn,
522
+ stride_vb, stride_vh, stride_vn,
523
+ stride_rfa_kb, stride_rfa_kh, stride_rfa_kc,
524
+ stride_rfa_vb, stride_rfa_vh, stride_rfa_vc,
525
+ stride_window_mask_b, stride_window_mask_m,
526
+ stride_chunk_mask_b, stride_chunk_mask_m,
527
+ stride_do_b, stride_do_h, stride_do_m,
528
+ stride_lse_b, stride_lse_h,
529
+ stride_do_t_o_b, stride_do_t_o_h,
530
+ stride_dq_b, stride_dq_h, stride_dq_m,
531
+ nheads,
532
+ seqlen_q,
533
+ seqlen_k,
534
+ nchunks,
535
+ headdim,
536
+ CHUNKS_PER_WINDOW: tl.constexpr,
537
+ WINDOW_SIZE: tl.constexpr,
538
+ MASK_TYPE: tl.constexpr,
539
+ EMPTY_RFA_KV: tl.constexpr,
540
+ BLOCK_HEADDIM: tl.constexpr,
541
+ EVEN_M: tl.constexpr,
542
+ EVEN_N: tl.constexpr,
543
+ EVEN_W: tl.constexpr,
544
+ EVEN_C: tl.constexpr,
545
+ EVEN_HEADDIM: tl.constexpr,
546
+ BLOCK_M: tl.constexpr,
547
+ BLOCK_N: tl.constexpr,
548
+ ):
549
+ start_m = tl.program_id(0)
550
+ off_bh = tl.program_id(1)
551
+ off_h = off_bh % nheads
552
+ off_b = off_bh // nheads
553
+ # initialize offsets
554
+ offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
555
+ offs_w = (start_m * BLOCK_M) // WINDOW_SIZE
556
+ offs_n = tl.arange(0, BLOCK_N)
557
+ offs_c = tl.arange(0, BLOCK_N)
558
+ offs_d = tl.arange(0, BLOCK_HEADDIM)
559
+ # TODO: add parentheses or not
560
+ q_ptrs = (
561
+ Q +
562
+ off_b * stride_qb +
563
+ off_h * stride_qh +
564
+ (offs_m[:, None] * stride_qm + offs_d[None, :])
565
+ )
566
+ k_ptrs = (
567
+ K +
568
+ off_b * stride_kb +
569
+ off_h * stride_kh +
570
+ (offs_n[:, None] * stride_kn + offs_d[None, :])
571
+ )
572
+ v_ptrs = (
573
+ V +
574
+ off_b * stride_vb +
575
+ off_h * stride_vh +
576
+ (offs_n[:, None] * stride_vn + offs_d[None, :])
577
+ )
578
+ if EMPTY_RFA_KV == 0:
579
+ rfa_k_ptrs = (
580
+ RFA_K +
581
+ off_b * stride_rfa_kb +
582
+ off_h * stride_rfa_kh +
583
+ (offs_c[:, None] * stride_rfa_kc + offs_d[None, :])
584
+ )
585
+ rfa_v_ptrs = (
586
+ RFA_V +
587
+ off_b * stride_rfa_vb +
588
+ off_h * stride_rfa_vh +
589
+ (offs_c[:, None] * stride_rfa_vc + offs_d[None, :])
590
+ )
591
+ dq_ptrs = (
592
+ DQ +
593
+ off_b * stride_dq_b +
594
+ off_h * stride_dq_h +
595
+ (offs_m[:, None] * stride_dq_m + offs_d[None, :])
596
+ )
597
+ do_ptrs = (
598
+ DO +
599
+ off_b * stride_do_b +
600
+ off_h * stride_do_h +
601
+ (offs_m[:, None] * stride_do_m + offs_d[None, :])
602
+ )
603
+ do_t_o_ptrs = (
604
+ DO_T_O +
605
+ off_b * stride_do_t_o_b +
606
+ off_h * stride_do_t_o_h +
607
+ offs_m[:, None]
608
+ )
609
+ lse_ptrs = (
610
+ LSE +
611
+ off_b * stride_lse_b +
612
+ off_h * stride_lse_h +
613
+ offs_m[:, None]
614
+ )
615
+ ### load q, do, do_t_o, lse ####
616
+ if EVEN_M:
617
+ if EVEN_HEADDIM:
618
+ q = tl.load(
619
+ q_ptrs
620
+ )
621
+ do = tl.load(
622
+ do_ptrs
623
+ )
624
+ else:
625
+ q = tl.load(
626
+ q_ptrs,
627
+ mask=offs_d[None, :] < headdim,
628
+ other=0.0
629
+ )
630
+ do = tl.load(
631
+ do_ptrs,
632
+ mask=offs_d[None, :] < headdim,
633
+ other=0.0
634
+ )
635
+ do_t_o = tl.load(
636
+ do_t_o_ptrs
637
+ )
638
+ lse = tl.load(
639
+ lse_ptrs
640
+ )
641
+ else:
642
+ if EVEN_HEADDIM:
643
+ q = tl.load(
644
+ q_ptrs,
645
+ mask=offs_m[:, None] < seqlen_q,
646
+ other=0.0
647
+ )
648
+ do = tl.load(
649
+ do_ptrs,
650
+ mask=offs_m[:, None] < seqlen_q,
651
+ other=0.0
652
+ )
653
+ else:
654
+ q = tl.load(
655
+ q_ptrs,
656
+ mask=(offs_m[:, None] < seqlen_q) & (offs_d[None, :] < headdim),
657
+ other=0.0
658
+ )
659
+ do = tl.load(
660
+ do_ptrs,
661
+ mask=(offs_m[:, None] < seqlen_q) & (offs_d[None, :] < headdim),
662
+ other=0.0
663
+ )
664
+ do_t_o = tl.load(
665
+ do_t_o_ptrs,
666
+ mask=offs_m[:, None] < seqlen_q,
667
+ other=0.0
668
+ )
669
+ lse = tl.load(
670
+ lse_ptrs,
671
+ mask=offs_m[:, None] < seqlen_q,
672
+ other=0.0
673
+ )
674
+ lse = tl.where(lse == float("-inf"), 0.0, lse)
675
+ lse *= 1.4426950408889634 # log2(e)
676
+ qk_scale = softmax_scale
677
+ qk_scale *= 1.4426950408889634 # log2(e)
678
+ if MASK_TYPE == 1:
679
+ window_mask_ptrs = (
680
+ WindowMask +
681
+ off_b * stride_window_mask_b +
682
+ (offs_m[:, None] * stride_window_mask_m + offs_n[None, :])
683
+ )
684
+ if EMPTY_RFA_KV == 0:
685
+ chunk_mask_ptrs = (
686
+ ChunkMask +
687
+ off_b * stride_chunk_mask_b +
688
+ (offs_m[:, None] * stride_chunk_mask_m + offs_c[None, :])
689
+ )
690
+
691
+ dq = tl.zeros([BLOCK_M, BLOCK_HEADDIM], dtype=tl.float32)
692
+ # loop over k, v and update accumulator
693
+ # Iterate over local singletons;
694
+ # so we only iterate over blocks within the current window
695
+ start_idx_n = offs_w * WINDOW_SIZE
696
+ end_idx_n = tl.minimum((start_m + 1) * BLOCK_M, seqlen_k)
697
+ for start_n in range(start_idx_n, end_idx_n, BLOCK_N):
698
+ start_n = tl.multiple_of(start_n, BLOCK_N)
699
+ if EVEN_N & EVEN_M:
700
+ if EVEN_HEADDIM:
701
+ k = tl.load(
702
+ k_ptrs + start_n * stride_kn
703
+ )
704
+ else:
705
+ k = tl.load(
706
+ k_ptrs + start_n * stride_kn,
707
+ mask=offs_d[None, :] < headdim,
708
+ other=0.0
709
+ )
710
+ else:
711
+ if EVEN_HEADDIM:
712
+ k = tl.load(
713
+ k_ptrs + start_n * stride_kn,
714
+ mask=(start_n + offs_n)[:, None] < seqlen_k,
715
+ other=0.0,
716
+ )
717
+ else:
718
+ k = tl.load(
719
+ k_ptrs + start_n * stride_kn,
720
+ mask=((start_n + offs_n)[:, None] < seqlen_k) & (offs_d[None, :] < headdim),
721
+ other=0.0,
722
+ )
723
+ qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
724
+ qk += tl.dot(q, tl.trans(k))
725
+ # Trying to combine the two masks seems to make the result wrong
726
+ if not EVEN_N: # Need to mask out otherwise the softmax is wrong
727
+ qk += tl.where((start_n + offs_n)[None, :] < seqlen_k, 0, float("-inf"))
728
+
729
+ if MASK_TYPE == 1:
730
+ if EVEN_M & EVEN_W:
731
+ window_mask = tl.load(
732
+ window_mask_ptrs + start_n - start_idx_n
733
+ )
734
+ else:
735
+ window_mask = tl.load(
736
+ window_mask_ptrs + start_n - start_idx_n,
737
+ mask=(offs_m[:, None] < seqlen_q)
738
+ & ((start_n - start_idx_n + offs_n)[None, :] < WINDOW_SIZE),
739
+ other=1,
740
+ )
741
+ # Slightly faster to multiply the softmax_scale in the tl.exp below since the compiler
742
+ # can then fuse the mult and add into an fma instruction. But if we have bias we need to
743
+ # to multiply with softmax_scale here.
744
+ # we assume mask already implies the causal masking
745
+ qk = qk * qk_scale
746
+ qk = tl.where(window_mask, float("-inf"), qk)
747
+ p = tl.exp2(qk - lse)
748
+ else:
749
+ qk += tl.where(offs_m[:, None] >= (start_n + offs_n)[None, :], 0, float("-inf"))
750
+ p = tl.exp2(qk * qk_scale - lse)
751
+
752
+ if EVEN_N & EVEN_M:
753
+ if EVEN_HEADDIM:
754
+ v = tl.load(
755
+ v_ptrs + start_n * stride_vn
756
+ )
757
+ else:
758
+ v = tl.load(
759
+ v_ptrs + start_n * stride_vn,
760
+ mask=offs_d[None, :] < headdim,
761
+ other=0.0
762
+ )
763
+ else:
764
+ if EVEN_HEADDIM:
765
+ v = tl.load(
766
+ v_ptrs + start_n * stride_vn,
767
+ mask=(start_n + offs_n)[:, None] < seqlen_k,
768
+ other=0.0,
769
+ )
770
+ else:
771
+ v = tl.load(
772
+ v_ptrs + start_n * stride_vn,
773
+ mask=((start_n + offs_n)[:, None] < seqlen_k) & (offs_d[None, :] < headdim),
774
+ other=0.0,
775
+ )
776
+ dp = tl.dot(do, tl.trans(v))
777
+ ds = (p * (dp - do_t_o) * softmax_scale).to(q.dtype)
778
+ dq += tl.dot(ds, k)
779
+
780
+ if EMPTY_RFA_KV == 0:
781
+ # Iterate over RFA chunks
782
+ # we only iterate over chunks before the current local singleton window
783
+ end_idx_c = tl.minimum(offs_w * CHUNKS_PER_WINDOW, nchunks)
784
+ for start_c in range(0, end_idx_c, BLOCK_N):
785
+ start_c = tl.multiple_of(start_c, BLOCK_N)
786
+ # -- compute qk ----
787
+ if EVEN_C & EVEN_M:
788
+ if EVEN_HEADDIM:
789
+ rfa_k = tl.load(
790
+ rfa_k_ptrs + start_c * stride_rfa_kc
791
+ )
792
+ else:
793
+ rfa_k = tl.load(
794
+ rfa_k_ptrs + start_c * stride_rfa_kc,
795
+ mask=offs_d[None, :] < headdim,
796
+ other=0.0
797
+ )
798
+ else:
799
+ if EVEN_HEADDIM:
800
+ rfa_k = tl.load(
801
+ rfa_k_ptrs + start_c * stride_rfa_kc,
802
+ mask=(start_c + offs_c)[:, None] < nchunks,
803
+ other=0.0,
804
+ )
805
+ else:
806
+ rfa_k = tl.load(
807
+ rfa_k_ptrs + start_c * stride_rfa_kc,
808
+ mask=((start_c + offs_c)[:, None] < nchunks) & (offs_d[None, :] < headdim),
809
+ other=0.0,
810
+ )
811
+ qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
812
+ qk += tl.dot(q, tl.trans(rfa_k))
813
+ # Trying to combine the two masks seems to make the result wrong
814
+ if not EVEN_C: # Need to mask out otherwise the softmax is wrong
815
+ qk += tl.where((start_c + offs_c)[None, :] < nchunks, 0, float("-inf"))
816
+
817
+ if MASK_TYPE == 1:
818
+ if EVEN_C & EVEN_M:
819
+ chunk_mask = tl.load(
820
+ chunk_mask_ptrs + start_c
821
+ )
822
+ else:
823
+ chunk_mask = tl.load(
824
+ chunk_mask_ptrs + start_c,
825
+ mask=(offs_m[:, None] < seqlen_q) & ((start_c + offs_c)[None, :] < nchunks),
826
+ other=1,
827
+ )
828
+ # Slightly faster to multiply the softmax_scale in the tl.exp below since the compiler
829
+ # can then fuse the mult and add into an fma instruction. But if we have bias we need to
830
+ # to multiply with softmax_scale here.
831
+ # we assume mask already implies the causal masking
832
+ qk = qk * qk_scale
833
+ qk = tl.where(chunk_mask, float("-inf"), qk)
834
+ p = tl.exp2(qk - lse)
835
+ else:
836
+ p = tl.exp2(qk * qk_scale - lse)
837
+
838
+ if EVEN_C & EVEN_M:
839
+ if EVEN_HEADDIM:
840
+ rfa_v = tl.load(
841
+ rfa_v_ptrs + start_c * stride_rfa_vc
842
+ )
843
+ else:
844
+ rfa_v = tl.load(
845
+ rfa_v_ptrs + start_c * stride_rfa_vc,
846
+ mask=offs_d[None, :] < headdim,
847
+ other=0.0
848
+ )
849
+ else:
850
+ if EVEN_HEADDIM:
851
+ rfa_v = tl.load(
852
+ rfa_v_ptrs + start_c * stride_rfa_vc,
853
+ mask=(start_c + offs_n)[:, None] < nchunks,
854
+ other=0.0,
855
+ )
856
+ else:
857
+ rfa_v = tl.load(
858
+ rfa_v_ptrs + start_c * stride_rfa_vc,
859
+ mask=((start_c + offs_n)[:, None] < nchunks) & (offs_d[None, :] < headdim),
860
+ other=0.0,
861
+ )
862
+ dp = tl.dot(do, tl.trans(rfa_v))
863
+ ds = (p * (dp - do_t_o) * softmax_scale).to(q.dtype)
864
+ dq += tl.dot(ds, rfa_k)
865
+
866
+ start_m = tl.program_id(0)
867
+ offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
868
+ offs_d = tl.arange(0, BLOCK_HEADDIM)
869
+ dq_ptrs = (
870
+ DQ +
871
+ off_b * stride_dq_b +
872
+ off_h * stride_dq_h +
873
+ (offs_m[:, None] * stride_dq_m + offs_d[None, :])
874
+ )
875
+ if EVEN_M:
876
+ if EVEN_HEADDIM:
877
+ tl.store(
878
+ dq_ptrs, dq
879
+ )
880
+ else:
881
+ tl.store(
882
+ dq_ptrs, dq,
883
+ mask=offs_d[None, :] < headdim
884
+ )
885
+ else:
886
+ if EVEN_HEADDIM:
887
+ tl.store(
888
+ dq_ptrs, dq,
889
+ mask=offs_m[:, None] < seqlen_q
890
+ )
891
+ else:
892
+ tl.store(
893
+ dq_ptrs, dq,
894
+ mask=(offs_m[:, None] < seqlen_q) & (offs_d[None, :] < headdim)
895
+ )
896
+
897
+ _capability_90_config = {
898
+ "fwd": {
899
+ (torch.bfloat16, 64): (128, 128, 4, 3),
900
+ (torch.bfloat16, 128): (128, 128, 8, 3),
901
+ (torch.float32, 64): (128, 64, 8, 3),
902
+ (torch.float32, 128): (64, 32, 4, 3),
903
+ },
904
+ "bwd_dq": {
905
+ (torch.bfloat16, 64): (128, 64, 4, 3),
906
+ (torch.bfloat16, 128): (128, 64, 8, 3),
907
+ (torch.float32, 64): (128, 64, 8, 2),
908
+ (torch.float32, 128): (32, 32, 4, 2),
909
+ },
910
+ "bwd_dkdv": {
911
+ (torch.bfloat16, 64): (128, 64, 4, 2),
912
+ (torch.bfloat16, 128): (128, 64, 8, 2),
913
+ (torch.float32, 64): (128, 64, 8, 2),
914
+ (torch.float32, 128): (32, 32, 4, 1),
915
+ },
916
+ "bwd_drfa_kv": {
917
+ (torch.bfloat16, 64): (128, 64, 4, 2),
918
+ (torch.bfloat16, 128): (128, 64, 8, 2),
919
+ (torch.float32, 64): (128, 64, 8, 2),
920
+ (torch.float32, 128): (32, 32, 4, 1),
921
+ }
922
+ }
923
+
924
+ _capability_80_config = {
925
+ "fwd": {
926
+ (torch.bfloat16, 64): (64, 64, 4, 3),
927
+ (torch.bfloat16, 128): (64, 64, 8, 3),
928
+ (torch.float32, 64): (64, 32, 4, 2),
929
+ (torch.float32, 128): (64, 32, 8, 1),
930
+ },
931
+ "bwd_dq": {
932
+ (torch.bfloat16, 64): (64, 64, 4, 3),
933
+ (torch.bfloat16, 128): (64, 32, 4, 2),
934
+ (torch.float32, 64): (32, 32, 4, 2),
935
+ (torch.float32, 128): (32, 32, 4, 2),
936
+ },
937
+ "bwd_dkdv": {
938
+ (torch.bfloat16, 64): (64, 64, 4, 3),
939
+ (torch.bfloat16, 128): (32, 32, 4, 2),
940
+ (torch.float32, 64): (32, 32, 4, 1),
941
+ (torch.float32, 128): (16, 64, 8, 1),
942
+ },
943
+ "bwd_drfa_kv": {
944
+ (torch.bfloat16, 64): (64, 64, 4, 3),
945
+ (torch.bfloat16, 128): (64, 32, 4, 3),
946
+ (torch.float32, 64): (32, 32, 4, 1),
947
+ (torch.float32, 128): (32, 32, 4, 1),
948
+ }
949
+ }
950
+
951
+ def _get_config(dtype, head_dim, mode) -> tuple[int, int, int, int]:
952
+ capability = torch.cuda.get_device_capability()
953
+ if capability >= (9, 0):
954
+ kernel_config = _capability_90_config[mode].get((dtype, head_dim), (32, 32, 4, 1))
955
+ elif capability >= (8, 0):
956
+ kernel_config = _capability_80_config[mode].get((dtype, head_dim), (16, 16, 4, 1))
957
+ else:
958
+ if mode == "fwd":
959
+ if dtype == torch.float32:
960
+ kernel_config = (32, 16, 4, 2)
961
+ else:
962
+ kernel_config = (64, 32, 4, 2)
963
+ else:
964
+ if dtype == torch.float32:
965
+ kernel_config = (16, 16, 4, 1)
966
+ else:
967
+ kernel_config = (32, 32, 4, 1)
968
+ return kernel_config
969
+
970
  @triton.heuristics(
971
  {
972
  "EVEN_M": lambda args: args["seqlen_q"] % args["BLOCK_M"] == 0,
 
984
  RFA_K,
985
  RFA_V,
986
  WindowMask,
987
+ ChunkMask,
988
  Out,
989
+ LSE,
990
  softmax_scale,
991
  stride_qb, stride_qh, stride_qm,
992
  stride_kb, stride_kh, stride_kn,
993
  stride_vb, stride_vh, stride_vn,
994
  stride_rfa_kb, stride_rfa_kh, stride_rfa_kc,
995
  stride_rfa_vb, stride_rfa_vh, stride_rfa_vc,
996
+ stride_window_mask_b, stride_window_mask_m,
997
+ stride_chunk_mask_b, stride_chunk_mask_m,
998
  stride_ob, stride_oh, stride_om,
999
+ stride_lse_b, stride_lse_h,
1000
  nheads,
1001
  seqlen_q,
1002
  seqlen_k,
1003
  nchunks,
1004
  headdim,
1005
  CHUNKS_PER_WINDOW: tl.constexpr,
1006
  WINDOW_SIZE: tl.constexpr,
1007
  MASK_TYPE: tl.constexpr,
 
1061
  qk_scale = softmax_scale
1062
  qk_scale *= 1.4426950408889634 # log2(e)
1063
  if MASK_TYPE == 1:
1064
+ window_mask_ptrs = (
1065
  WindowMask +
1066
+ off_b * stride_window_mask_b +
1067
+ (offs_m[:, None] * stride_window_mask_m + offs_n[None, :])
1068
  )
1069
+ if EMPTY_RFA_KV == 0:
1070
+ chunk_mask_ptrs = (
1071
+ ChunkMask +
1072
+ off_b * stride_chunk_mask_b +
1073
+ (offs_m[:, None] * stride_chunk_mask_m + offs_c[None, :])
1074
+ )
1075
+
1076
  m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
1077
  d_i = tl.zeros([BLOCK_M], dtype=tl.float32)
1078
  acc_o = tl.zeros([BLOCK_M, BLOCK_HEADDIM], dtype=tl.float32)
 
1143
 
1144
  if MASK_TYPE == 1:
1145
  if EVEN_M & EVEN_W:
1146
+ window_mask = tl.load(
1147
+ window_mask_ptrs + start_n - start_idx_n
1148
+ )
1149
  else:
1150
+ window_mask = tl.load(
1151
+ window_mask_ptrs + start_n - start_idx_n,
1152
  mask=(offs_m[:, None] < seqlen_q)
1153
  & ((start_n - start_idx_n + offs_n)[None, :] < WINDOW_SIZE),
1154
+ other=1,
1155
+ )
1156
  # Slightly faster to multiply the softmax_scale in the tl.exp below since the compiler
1157
  # can then fuse the mult and add into an fma instruction. But if we have bias we need to
1158
  # to multiply with softmax_scale here.
1159
  # we assume mask already implies the causal masking
1160
+ qk = qk * qk_scale
1161
+ qk = tl.where(window_mask, float("-inf"), qk)
1162
  m_ij = tl.maximum(tl.max(qk, 1), m_i)
1163
+ masked_out_rows = (m_ij == float("-inf"))
1164
+ m_ij_masked = tl.where(masked_out_rows, 0, m_ij)
1165
+ p = tl.exp2(qk - m_ij_masked[:, None])
1166
  else:
1167
  qk += tl.where(offs_m[:, None] >= (start_n + offs_n)[None, :], 0, float("-inf"))
1168
  m_ij = tl.maximum(tl.max(qk, 1) * qk_scale, m_i)
1169
+ masked_out_rows = (m_ij == float("-inf"))
1170
+ m_ij_masked = tl.where(masked_out_rows, 0, m_ij)
1171
+ p = tl.exp2(qk * qk_scale - m_ij_masked[:, None])
1172
 
1173
  d_ij = tl.sum(p, 1)
1174
 
1175
  # scale acc_o
1176
+ prev_scale = tl.exp2(m_i - m_ij_masked)
1177
  # # -- update output accumulator --
1178
  acc_o = acc_o * prev_scale[:, None]
1179
  # update acc_o
 
1245
  if not EVEN_C: # Need to mask out otherwise the softmax is wrong
1246
  qk += tl.where((start_c + offs_c)[None, :] < nchunks, 0, float("-inf"))
1247
 
1248
+ if MASK_TYPE == 1:
1249
+ if EVEN_C & EVEN_M:
1250
+ chunk_mask = tl.load(
1251
+ chunk_mask_ptrs + start_c
1252
+ )
1253
+ else:
1254
+ chunk_mask = tl.load(
1255
+ chunk_mask_ptrs + start_c,
1256
+ mask=(offs_m[:, None] < seqlen_q) & ((start_c + offs_c)[None, :] < nchunks),
1257
+ other=1,
1258
+ )
1259
+ # Slightly faster to multiply the softmax_scale in the tl.exp below since the compiler
1260
+ # can then fuse the mult and add into an fma instruction. But if we have bias we need to
1261
+ # to multiply with softmax_scale here.
1262
+ # we assume mask already implies the causal masking
1263
+ qk = qk * qk_scale
1264
+ qk = tl.where(chunk_mask, float("-inf"), qk)
1265
+ m_ij = tl.maximum(tl.max(qk, 1), m_i)
1266
+ masked_out_rows = (m_ij == float("-inf"))
1267
+ m_ij_masked = tl.where(masked_out_rows, 0, m_ij)
1268
+ p = tl.exp2(qk - m_ij_masked[:, None])
1269
+ else:
1270
+ m_ij = tl.maximum(tl.max(qk, 1) * qk_scale, m_i)
1271
+ masked_out_rows = (m_ij == float("-inf"))
1272
+ m_ij_masked = tl.where(masked_out_rows, 0, m_ij)
1273
+ p = tl.exp2(qk * qk_scale - m_ij_masked[:, None])
1274
 
1275
  d_ij = tl.sum(p, 1)
1276
 
1277
  # scale acc_o
1278
+ prev_scale = tl.exp2(m_i - m_ij_masked)
1279
  # # -- update output accumulator --
1280
  acc_o = acc_o * prev_scale[:, None]
1281
  # update acc_o
 
1311
  d_i = d_i * prev_scale + d_ij
1312
  m_i = m_ij
1313
 
1314
+ # for rows that are all -inf, set d_i to 1.0
1315
+ d_i = tl.where(d_i == 0.0, 1.0, d_i)
1316
+ # multiply by log(2)
1317
+ lse_m = (m_i + tl.math.log2(d_i)) * 0.6931471805599453
1318
  acc_o = acc_o / d_i[:, None]
1319
  # TODO: understand why rematerialize offsets to save registers?
1320
  start_m = tl.program_id(0)
 
1347
  out_ptrs, acc_o,
1348
  mask=(offs_m[:, None] < seqlen_q) & (offs_d[None, :] < headdim)
1349
  )
1350
+ lse_ptrs = (
1351
+ LSE +
1352
+ off_b * stride_lse_b +
1353
+ off_h * stride_lse_h +
1354
+ offs_m
1355
+ )
1356
+ if EVEN_M:
1357
+ tl.store(
1358
+ lse_ptrs, lse_m,
1359
+ )
1360
+ else:
1361
+ tl.store(
1362
+ lse_ptrs, lse_m,
1363
+ mask=offs_m < seqlen_q
1364
+ )
1365
 
1366
+ def triton_eva_agg_fwd(
1367
+ q, k, v, rfa_k, rfa_v,
1368
+ window_mask,
1369
+ chunk_mask,
1370
+ softmax_scale,
1371
+ window_size,
1372
+ chunks_per_window
1373
+ ):
1374
  if rfa_k is None and rfa_v is None:
1375
  empty_rfa_kv = 1
1376
 
 
1410
  mask_type = 0
1411
  if window_mask is not None:
1412
  mask_type = 1
1413
+ assert window_mask.dtype == torch.bool
1414
  assert window_mask.is_cuda
1415
  assert window_mask.dim() == 4
1416
  assert window_mask.shape == (batch, 1, seqlen_q, window_size)
1417
  if window_mask.stride(-1) != 1:
1418
  window_mask = window_mask.contiguous()
1419
+
1420
+ assert chunk_mask is not None
1421
+ assert chunk_mask.dtype == torch.bool
1422
+ assert chunk_mask.is_cuda
1423
+ assert chunk_mask.dim() == 4
1424
+ assert chunk_mask.shape == (batch, 1, seqlen_q, nchunks)
1425
+ if chunk_mask.stride(-1) != 1:
1426
+ chunk_mask = chunk_mask.contiguous()
1427
+
1428
+ chunk_mask_strides = (
1429
+ (chunk_mask.stride(0), chunk_mask.stride(2))
1430
+ if mask_type == 1 else
1431
+ (0, 0)
1432
+ )
1433
+ window_mask_strides = (
1434
  (window_mask.stride(0), window_mask.stride(2))
1435
  if mask_type == 1 else
1436
  (0, 0)
 
1446
  if empty_rfa_kv == 0 else
1447
  (0, 0, 0)
1448
  )
 
1449
 
1450
  o = torch.empty_like(q)
1451
+ lse = torch.empty((q.shape[0], q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32)
1452
 
1453
  BLOCK_HEADDIM = max(triton.next_power_of_2(head_dim), 16)
1454
+
1455
+ BLOCK_M, BLOCK_N, num_warps, num_stages = _get_config(q.dtype, head_dim, "fwd")
1456
+
1457
+ assert chunks_per_window >= BLOCK_N, "chunks_per_window must be greater than BLOCK"
1458
+ assert chunks_per_window % BLOCK_N == 0, "chunks_per_window must be a multiple of BLOCK_N"
1459
 
1460
  grid = lambda META: (triton.cdiv(seqlen_q, META["BLOCK_M"]), batch * nheads)
1461
  _fwd_eva_agg_kernel[grid](
 
1465
  rfa_k,
1466
  rfa_v,
1467
  window_mask,
1468
+ chunk_mask,
1469
  o,
1470
+ lse,
1471
  softmax_scale,
1472
  q.stride(0), q.stride(1), q.stride(2),
1473
  k.stride(0), k.stride(1), k.stride(2),
1474
  v.stride(0), v.stride(1), v.stride(2),
1475
  rfa_k_strides[0], rfa_k_strides[1], rfa_k_strides[2],
1476
  rfa_v_strides[0], rfa_v_strides[1], rfa_v_strides[2],
1477
+ window_mask_strides[0], window_mask_strides[1],
1478
+ chunk_mask_strides[0], chunk_mask_strides[1],
1479
  o.stride(0), o.stride(1), o.stride(2),
1480
+ lse.stride(0), lse.stride(1),
1481
+ nheads,
1482
+ seqlen_q,
1483
+ seqlen_k,
1484
+ nchunks,
1485
+ head_dim,
1486
+ chunks_per_window,
1487
+ window_size,
1488
+ mask_type,
1489
+ empty_rfa_kv,
1490
+ BLOCK_HEADDIM,
1491
+ BLOCK_M=BLOCK_M,
1492
+ BLOCK_N=BLOCK_N,
1493
+ num_warps=num_warps,
1494
+ num_stages=num_stages,
1495
+ )
1496
+ return o, lse
1497
+
1498
+ def triton_eva_agg_bwd(
1499
+ do,
1500
+ q, k, v, rfa_k, rfa_v,
1501
+ window_mask, chunk_mask,
1502
+ o, lse,
1503
+ dq, dk, dv, d_rfa_k, d_rfa_v,
1504
+ softmax_scale,
1505
+ window_size,
1506
+ chunks_per_window,
1507
+ empty_rfa_kv,
1508
+ mask_type,
1509
+ ):
1510
+ if do.stride(-1) != 1:
1511
+ do = do.contiguous()
1512
+
1513
+ # shape constraints
1514
+ batch, nheads, seqlen_q, head_dim = q.shape
1515
+ _, _, seqlen_k, _ = k.shape
1516
+ if empty_rfa_kv == 0:
1517
+ nchunks = rfa_k.shape[-2]
1518
+ assert rfa_k.shape == (batch, nheads, nchunks, head_dim)
1519
+ assert rfa_v.shape == (batch, nheads, nchunks, head_dim)
1520
+ assert d_rfa_k.stride(-1) == d_rfa_v.stride(-1) == 1
1521
+ assert q.dtype == k.dtype == v.dtype == rfa_k.dtype == rfa_v.dtype
1522
+ else:
1523
+ nchunks = 0
1524
+ assert q.dtype == k.dtype == v.dtype, "All tensors must have the same type"
1525
+
1526
+ assert lse.shape == (batch, nheads, seqlen_q)
1527
+ assert q.stride(-1) == k.stride(-1) == v.stride(-1) == o.stride(-1) == rfa_k.stride(-1) == rfa_v.stride(-1) == 1
1528
+ assert dq.stride(-1) == dk.stride(-1) == dv.stride(-1) == 1
1529
+ softmax_scale = softmax_scale or 1.0 / math.sqrt(head_dim)
1530
+
1531
+ assert head_dim <= 128, "We only test head dimensions up to 128"
1532
+
1533
+ window_mask_strides = (
1534
+ (window_mask.stride(0), window_mask.stride(2))
1535
+ if mask_type == 1 else
1536
+ (0, 0)
1537
+ )
1538
+ chunk_mask_strides = (
1539
+ (chunk_mask.stride(0), chunk_mask.stride(2))
1540
+ if mask_type == 1 else
1541
+ (0, 0)
1542
+ )
1543
+
1544
+ rfa_k_strides = (
1545
+ (rfa_k.stride(0), rfa_k.stride(1), rfa_k.stride(2))
1546
+ if empty_rfa_kv == 0 else
1547
+ (0, 0, 0)
1548
+ )
1549
+ rfa_v_strides = (
1550
+ (rfa_v.stride(0), rfa_v.stride(1), rfa_v.stride(2))
1551
+ if empty_rfa_kv == 0 else
1552
+ (0, 0, 0)
1553
+ )
1554
+
1555
+ d_rfa_k_strides = (
1556
+ (d_rfa_k.stride(0), d_rfa_k.stride(1), d_rfa_k.stride(2))
1557
+ if empty_rfa_kv == 0 else
1558
+ (0, 0, 0)
1559
+ )
1560
+ d_rfa_v_strides = (
1561
+ (d_rfa_v.stride(0), d_rfa_v.stride(1), d_rfa_v.stride(2))
1562
+ if empty_rfa_kv == 0 else
1563
+ (0, 0, 0)
1564
+ )
1565
+
1566
+ BLOCK_HEADDIM = max(triton.next_power_of_2(head_dim), 16)
1567
+
1568
+ do_t_o = torch.sum(do.to(torch.float32) * o.to(torch.float32), dim=-1).to(do.dtype)
1569
+
1570
+ BLOCK_M, BLOCK_N, num_warps, num_stages = _get_config(q.dtype, head_dim, "bwd_dq")
1571
+
1572
+ assert chunks_per_window >= BLOCK_N, "chunks_per_window must be greater than BLOCK"
1573
+ assert chunks_per_window % BLOCK_N == 0, "chunks_per_window must be a multiple of BLOCK"
1574
+ grid = lambda META: (
1575
+ triton.cdiv(seqlen_q, META["BLOCK_M"]),
1576
+ batch * nheads,
1577
+ )
1578
+ _bwd_eva_agg_kernel_dq[grid](
1579
+ q,
1580
+ k,
1581
+ v,
1582
+ rfa_k,
1583
+ rfa_v,
1584
+ window_mask,
1585
+ chunk_mask,
1586
+ do,
1587
+ lse,
1588
+ do_t_o,
1589
+ dq,
1590
+ softmax_scale,
1591
+ q.stride(0), q.stride(1), q.stride(2),
1592
+ k.stride(0), k.stride(1), k.stride(2),
1593
+ v.stride(0), v.stride(1), v.stride(2),
1594
+ rfa_k_strides[0], rfa_k_strides[1], rfa_k_strides[2],
1595
+ rfa_v_strides[0], rfa_v_strides[1], rfa_v_strides[2],
1596
+ window_mask_strides[0], window_mask_strides[1],
1597
+ chunk_mask_strides[0], chunk_mask_strides[1],
1598
+ do.stride(0), do.stride(1), do.stride(2),
1599
+ lse.stride(0), lse.stride(1),
1600
+ do_t_o.stride(0), do_t_o.stride(1),
1601
+ dq.stride(0), dq.stride(1), dq.stride(2),
1602
  nheads,
1603
  seqlen_q,
1604
  seqlen_k,
1605
  nchunks,
1606
  head_dim,
1607
  chunks_per_window,
1608
  window_size,
1609
  mask_type,
1610
  empty_rfa_kv,
1611
  BLOCK_HEADDIM,
1612
+ BLOCK_M=BLOCK_M,
1613
+ BLOCK_N=BLOCK_N,
1614
+ num_warps=num_warps,
1615
+ num_stages=num_stages,
1616
+ )
1617
+
1618
+ BLOCK_M, BLOCK_N, num_warps, num_stages = _get_config(q.dtype, head_dim, "bwd_dkdv")
1619
+ grid = lambda META: (
1620
+ triton.cdiv(seqlen_k, META["BLOCK_N"]),
1621
+ batch * nheads,
1622
+ )
1623
+ _bwd_eva_agg_kernel_dkdv[grid](
1624
+ q,
1625
+ k,
1626
+ v,
1627
+ window_mask,
1628
+ do,
1629
+ lse,
1630
+ do_t_o,
1631
+ dk,
1632
+ dv,
1633
+ softmax_scale,
1634
+ q.stride(0), q.stride(1), q.stride(2),
1635
+ k.stride(0), k.stride(1), k.stride(2),
1636
+ v.stride(0), v.stride(1), v.stride(2),
1637
+ window_mask_strides[0], window_mask_strides[1],
1638
+ do.stride(0), do.stride(1), do.stride(2),
1639
+ lse.stride(0), lse.stride(1),
1640
+ do_t_o.stride(0), do_t_o.stride(1),
1641
+ dk.stride(0), dk.stride(1), dk.stride(2),
1642
+ dv.stride(0), dv.stride(1), dv.stride(2),
1643
+ nheads,
1644
+ seqlen_q,
1645
+ seqlen_k,
1646
+ head_dim,
1647
+ window_size,
1648
+ mask_type,
1649
+ BLOCK_HEADDIM,
1650
+ BLOCK_M=BLOCK_M,
1651
+ BLOCK_N=BLOCK_N,
1652
  num_warps=num_warps,
1653
+ num_stages=num_stages,
1654
+ )
1655
+ if empty_rfa_kv == 0:
1656
+ BLOCK_M, BLOCK_N, num_warps, num_stages = _get_config(q.dtype, head_dim, "bwd_drfa_kv")
1657
+ grid = lambda META: (
1658
+ triton.cdiv(nchunks, META["BLOCK_N"]),
1659
+ batch * nheads,
1660
+ )
1661
+ _bwd_eva_agg_kernel_drfa_kv[grid](
1662
+ q,
1663
+ rfa_k,
1664
+ rfa_v,
1665
+ chunk_mask,
1666
+ do,
1667
+ lse,
1668
+ do_t_o,
1669
+ d_rfa_k,
1670
+ d_rfa_v,
1671
+ softmax_scale,
1672
+ q.stride(0), q.stride(1), q.stride(2),
1673
+ rfa_k_strides[0], rfa_k_strides[1], rfa_k_strides[2],
1674
+ rfa_v_strides[0], rfa_v_strides[1], rfa_v_strides[2],
1675
+ chunk_mask_strides[0], chunk_mask_strides[1],
1676
+ do.stride(0), do.stride(1), do.stride(2),
1677
+ lse.stride(0), lse.stride(1),
1678
+ do_t_o.stride(0), do_t_o.stride(1),
1679
+ d_rfa_k_strides[0], d_rfa_k_strides[1], d_rfa_k_strides[2],
1680
+ d_rfa_v_strides[0], d_rfa_v_strides[1], d_rfa_v_strides[2],
1681
+ nheads,
1682
+ seqlen_q,
1683
+ nchunks,
1684
+ head_dim,
1685
+ chunks_per_window,
1686
+ window_size,
1687
+ mask_type,
1688
+ BLOCK_HEADDIM,
1689
+ BLOCK_M=BLOCK_M,
1690
+ BLOCK_N=BLOCK_N,
1691
+ num_warps=num_warps,
1692
+ num_stages=num_stages,
1693
+ )
1694
+
1695
+
1696
+ class EvaAggFunc(torch.autograd.Function):
1697
+ @staticmethod
1698
+ def forward(ctx, q, k, v, rfa_k, rfa_v, window_mask, chunk_mask, softmax_scale=None, window_size=None, chunks_per_window=None):
1699
+ if rfa_k is None and rfa_v is None:
1700
+ empty_rfa_kv = 1
1701
+ else:
1702
+ assert rfa_k is not None and rfa_v is not None, "Both rfa_k and rfa_v must either be None or have values at the same time."
1703
+ empty_rfa_kv = 0
1704
+
1705
+ if window_mask is not None:
1706
+ mask_type = 1
1707
+ else:
1708
+ mask_type = 0
1709
+ o, lse = triton_eva_agg_fwd(
1710
+ q, k, v, rfa_k, rfa_v, window_mask, chunk_mask, softmax_scale, window_size, chunks_per_window
1711
+ )
1712
+ ctx.save_for_backward(q, k, v, o, lse, rfa_k, rfa_v, window_mask, chunk_mask)
1713
+ ctx.softmax_scale = softmax_scale
1714
+ ctx.window_size = window_size
1715
+ ctx.chunks_per_window = chunks_per_window
1716
+ ctx.empty_rfa_kv = empty_rfa_kv
1717
+ ctx.mask_type = mask_type
1718
+ return o
1719
+
1720
+ @staticmethod
1721
+ def backward(ctx, do):
1722
+ q, k, v, o, lse, rfa_k, rfa_v, window_mask, chunk_mask = ctx.saved_tensors
1723
+ dq = torch.empty_like(q)
1724
+ dk = torch.empty_like(k)
1725
+ dv = torch.empty_like(v)
1726
+ if ctx.empty_rfa_kv == 0:
1727
+ d_rfa_k = torch.empty_like(rfa_k)
1728
+ d_rfa_v = torch.empty_like(rfa_v)
1729
+ else:
1730
+ d_rfa_k = None
1731
+ d_rfa_v = None
1732
+ triton_eva_agg_bwd(
1733
+ do,
1734
+ q,
1735
+ k,
1736
+ v,
1737
+ rfa_k,
1738
+ rfa_v,
1739
+ window_mask,
1740
+ chunk_mask,
1741
+ o,
1742
+ lse,
1743
+ dq,
1744
+ dk,
1745
+ dv,
1746
+ d_rfa_k,
1747
+ d_rfa_v,
1748
+ softmax_scale=ctx.softmax_scale,
1749
+ window_size=ctx.window_size,
1750
+ chunks_per_window=ctx.chunks_per_window,
1751
+ empty_rfa_kv=ctx.empty_rfa_kv,
1752
+ mask_type=ctx.mask_type,
1753
+ )
1754
+ return dq, dk, dv, d_rfa_k, d_rfa_v, None, None, None, None, None
1755
+
1756
+
1757
+ def eva_agg_func_triton(
1758
+ q, k, v, rfa_k, rfa_v,
1759
+ window_mask, chunk_mask,
1760
+ softmax_scale=None, window_size=None, chunks_per_window=None,
1761
+ ):
1762
+ return EvaAggFunc.apply(
1763
+ q, k, v, rfa_k, rfa_v,
1764
+ window_mask, chunk_mask,
1765
+ softmax_scale, window_size, chunks_per_window,
1766
  )
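A minimal smoke test of the new autograd wrapper is sketched below. Shapes follow the [b, h, n, d] and [b, h, nchunks, d] comments in the kernels; the sizes, device, softmax scale, and import path are illustrative assumptions, not repository defaults.

# Hedged usage sketch for eva_agg_func_triton (illustrative sizes and import path).
import torch
from eva_agg_kernel import eva_agg_func_triton  # adjust to the package layout of your checkout

b, h, n, d = 2, 4, 2048, 64
window_size, chunk_size = 512, 16                 # assumed values; chunks_per_window must satisfy
chunks_per_window = window_size // chunk_size     # the BLOCK_N assertions in the backward launcher
nchunks = n // chunk_size

q = torch.randn(b, h, n, d, device="cuda", dtype=torch.bfloat16, requires_grad=True)
k = torch.randn_like(q, requires_grad=True)
v = torch.randn_like(q, requires_grad=True)
rfa_k = torch.randn(b, h, nchunks, d, device="cuda", dtype=torch.bfloat16, requires_grad=True)
rfa_v = torch.randn_like(rfa_k, requires_grad=True)

# window_mask / chunk_mask are optional boolean masks (True = masked out); None takes the unmasked path.
out = eva_agg_func_triton(q, k, v, rfa_k, rfa_v, None, None,
                          softmax_scale=d ** -0.5,
                          window_size=window_size,
                          chunks_per_window=chunks_per_window)
out.float().sum().backward()   # exercises _bwd_eva_agg_kernel_dq, _dkdv, and _drfa_kv
print(out.shape, q.grad.shape, rfa_k.grad.shape)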
 
eva_prep_kv_kernel.py CHANGED
@@ -16,7 +16,7 @@ def _fwd_eva_prep_kv_kernel(
16
  V, # [b, h, n, d]
17
  PARAM_MU, # [1, h, 1, 1, d]
18
  PARAM_PHI, # [1, h, 1, 1, d]
19
- ChunkMask, # [b, h, n, 1]
20
  Out_RFA_K, # [b, h, c, d]
21
  Out_RFA_V, # [b, h, c, d]
22
  softmax_scale,
@@ -31,8 +31,6 @@ def _fwd_eva_prep_kv_kernel(
31
  seqlen,
32
  nchunks,
33
  headdim,
34
- CACHE_KEY_SEQLEN, # TODO: why keeping this
35
- CACHE_KEY_NCHUNKS, # TODO: why keeping this
36
  CHUNKS_PER_BLOCK: tl.constexpr,
37
  CHUNK_SIZE: tl.constexpr,
38
  MASK_TYPE: tl.constexpr,
@@ -91,7 +89,7 @@ def _fwd_eva_prep_kv_kernel(
91
  log2e = 1.4426950408889634
92
  if MASK_TYPE == 1:
93
  m_ptrs = (
94
- ChunkMask +
95
  offs_b * stride_mb +
96
  (
97
  (
@@ -144,7 +142,7 @@ def _fwd_eva_prep_kv_kernel(
144
  if EVEN_N:
145
  mask = tl.load(
146
  m_ptrs
147
- ).to(tl.float32)
148
  else:
149
  mask = tl.load(
150
  m_ptrs,
@@ -153,12 +151,17 @@ def _fwd_eva_prep_kv_kernel(
153
  offs_c[:, None] * CHUNK_SIZE +
154
  offs_m[None, :]
155
  ) < seqlen,
156
- other=0.0,
157
- ).to(tl.float32)
158
- rfa_k_c_w = rfa_k_c_w + mask
159
-
160
- rfa_k_c_w = tl.exp2(rfa_k_c_w - tl.max(rfa_k_c_w, axis=-1)[:, None])
161
- rfa_k_c_w = rfa_k_c_w / tl.sum(rfa_k_c_w, axis=-1)[:, None]
 
 
 
 
 
162
  rfa_k_c = tl.sum(k * rfa_k_c_w[:, :, None].to(k.dtype), axis=-2)
163
  # TODO: understand why rematerialize offsets to save registers?
164
  offs_out_c = start_n * CHUNKS_PER_BLOCK + tl.arange(0, CHUNKS_PER_BLOCK)
@@ -209,7 +212,7 @@ def _fwd_eva_prep_kv_kernel(
209
  )
210
 
211
  if MASK_TYPE == 1:
212
- rfa_v_c_w = rfa_v_c_w + mask
213
 
214
  if EVEN_N:
215
  if EVEN_HEADDIM:
@@ -246,8 +249,14 @@ def _fwd_eva_prep_kv_kernel(
246
  other=0.0
247
  )
248
 
249
- rfa_v_c_w = tl.exp2(rfa_v_c_w - tl.max(rfa_v_c_w, axis=-1)[:, None])
250
- rfa_v_c_w = rfa_v_c_w / tl.sum(rfa_v_c_w, axis=-1)[:, None]
 
 
 
 
 
 
251
  rfa_v_c = tl.sum(v * rfa_v_c_w[:, :, None].to(v.dtype), axis=-2)
252
 
253
  offs_out_c = start_n * CHUNKS_PER_BLOCK + tl.arange(0, CHUNKS_PER_BLOCK)
@@ -279,7 +288,529 @@ def _fwd_eva_prep_kv_kernel(
279
  mask=(offs_out_c[:, None] < nchunks) & (offs_d[None, :] < headdim)
280
  )
281
 
282
- def triton_eva_prep_kv_fwd(k, v, param_mu, param_phi, chunk_mask, softmax_scale, chunksize):
283
  k, v, param_mu, param_phi = [
284
  x if x.stride(-1) == 1 else x.contiguous()
285
  for x in [k, v, param_mu, param_phi]
@@ -300,16 +831,16 @@ def triton_eva_prep_kv_fwd(k, v, param_mu, param_phi, chunk_mask, softmax_scale,
300
  softmax_scale = softmax_scale or 1.0 / math.sqrt(head_dim)
301
 
302
  mask_type = 0
303
- if chunk_mask is not None:
304
  mask_type = 1
305
- assert chunk_mask.dtype == k.dtype
306
- assert chunk_mask.is_cuda
307
- assert chunk_mask.dim() == 4
308
- assert chunk_mask.shape == (batch, 1, seqlen, 1)
309
- if chunk_mask.stride(-1) != 1:
310
- chunk_mask = chunk_mask.contiguous()
311
  mask_strides = (
312
- (chunk_mask.stride(0), chunk_mask.stride(2))
313
  if mask_type == 1 else
314
  (0, 0)
315
  )
@@ -329,7 +860,7 @@ def triton_eva_prep_kv_fwd(k, v, param_mu, param_phi, chunk_mask, softmax_scale,
329
  v,
330
  param_mu,
331
  param_phi,
332
- chunk_mask,
333
  out_rfa_k,
334
  out_rfa_v,
335
  softmax_scale,
@@ -344,8 +875,6 @@ def triton_eva_prep_kv_fwd(k, v, param_mu, param_phi, chunk_mask, softmax_scale,
344
  seqlen,
345
  nchunks,
346
  head_dim,
347
- seqlen // 32,
348
- nchunks // 32,
349
  chunks_per_block,
350
  chunksize,
351
  mask_type,
@@ -355,3 +884,134 @@ def triton_eva_prep_kv_fwd(k, v, param_mu, param_phi, chunk_mask, softmax_scale,
355
  num_stages=1,
356
  )
357
  return out_rfa_k, out_rfa_v
 
16
  V, # [b, h, n, d]
17
  PARAM_MU, # [1, h, 1, 1, d]
18
  PARAM_PHI, # [1, h, 1, 1, d]
19
+ Mask, # [b, h, n, 1]
20
  Out_RFA_K, # [b, h, c, d]
21
  Out_RFA_V, # [b, h, c, d]
22
  softmax_scale,
 
31
  seqlen,
32
  nchunks,
33
  headdim,
 
 
34
  CHUNKS_PER_BLOCK: tl.constexpr,
35
  CHUNK_SIZE: tl.constexpr,
36
  MASK_TYPE: tl.constexpr,
 
89
  log2e = 1.4426950408889634
90
  if MASK_TYPE == 1:
91
  m_ptrs = (
92
+ Mask +
93
  offs_b * stride_mb +
94
  (
95
  (
 
142
  if EVEN_N:
143
  mask = tl.load(
144
  m_ptrs
145
+ )
146
  else:
147
  mask = tl.load(
148
  m_ptrs,
 
151
  offs_c[:, None] * CHUNK_SIZE +
152
  offs_m[None, :]
153
  ) < seqlen,
154
+ other=1,
155
+ )
156
+ rfa_k_c_w = tl.where(mask, float("-inf"), rfa_k_c_w)
157
+
158
+ m_rfa_k_c_w = tl.max(rfa_k_c_w, axis=-1)
159
+ masked_out_rows_rfa_k = (m_rfa_k_c_w == float("-inf"))
160
+ m_rfa_k_c_w_masked = tl.where(masked_out_rows_rfa_k, 0, m_rfa_k_c_w)
161
+ rfa_k_c_w = tl.exp2(rfa_k_c_w - m_rfa_k_c_w_masked[:, None])
162
+ denom_k = tl.sum(rfa_k_c_w, axis=-1)
163
+ denom_k = tl.where(denom_k == 0.0, 1.0, denom_k)
164
+ rfa_k_c_w = rfa_k_c_w / denom_k[:, None]
165
  rfa_k_c = tl.sum(k * rfa_k_c_w[:, :, None].to(k.dtype), axis=-2)
166
  # TODO: understand why rematerialize offsets to save registers?
167
  offs_out_c = start_n * CHUNKS_PER_BLOCK + tl.arange(0, CHUNKS_PER_BLOCK)
 
212
  )
213
 
214
  if MASK_TYPE == 1:
215
+ rfa_v_c_w = tl.where(mask, float("-inf"), rfa_v_c_w)
216
 
217
  if EVEN_N:
218
  if EVEN_HEADDIM:
 
249
  other=0.0
250
  )
251
 
252
+
253
+ m_rfa_v_c_w = tl.max(rfa_v_c_w, axis=-1)
254
+ masked_out_rows_rfa_v = (m_rfa_v_c_w == float("-inf"))
255
+ m_rfa_v_c_w_masked = tl.where(masked_out_rows_rfa_v, 0, m_rfa_v_c_w)
256
+ rfa_v_c_w = tl.exp2(rfa_v_c_w - m_rfa_v_c_w_masked[:, None])
257
+ denom_v = tl.sum(rfa_v_c_w, axis=-1)
258
+ denom_v = tl.where(denom_v == 0.0, 1.0, denom_v)
259
+ rfa_v_c_w = rfa_v_c_w / denom_v[:, None]
260
  rfa_v_c = tl.sum(v * rfa_v_c_w[:, :, None].to(v.dtype), axis=-2)
261
 
262
  offs_out_c = start_n * CHUNKS_PER_BLOCK + tl.arange(0, CHUNKS_PER_BLOCK)
 
288
  mask=(offs_out_c[:, None] < nchunks) & (offs_d[None, :] < headdim)
289
  )
290
 
291
+
292
+
293
+ @triton.heuristics(
294
+ {
295
+ "EVEN_N": lambda args: args["seqlen"] % args["BLOCK_N"] == 0,
296
+ "EVEN_HEADDIM": lambda args: args["headdim"] == args["BLOCK_HEADDIM"],
297
+ }
298
+ )
299
+ @triton.jit
300
+ def _bwd_eva_prep_kv_kernel(
301
+ RFA_K, # [b, h, c, d]
302
+ RFA_V, # [b, h, c, d]
303
+ K, # [b, h, n, d]
304
+ V, # [b, h, n, d]
305
+ PARAM_MU, # [1, h, 1, 1, d]
306
+ PARAM_PHI, # [1, h, 1, 1, d]
307
+ Mask, # [b, h, n, 1]
308
+ D_RFA_K, # [b, h, c, d]
309
+ D_RFA_V, # [b, h, c, d]
310
+ D_K, # [b, h, n, d]
311
+ D_V, # [b, h, n, d]
312
+ D_PARAM_MU_PARTIAL, # [b, h, g, d]
313
+ D_PARAM_PHI_PARTIAL, # [b, h, g, d]
314
+ softmax_scale,
315
+ stride_rfa_k_b, stride_rfa_k_h, stride_rfa_k_c,
316
+ stride_rfa_v_b, stride_rfa_v_h, stride_rfa_v_c,
317
+ stride_kb, stride_kh, stride_kn,
318
+ stride_vb, stride_vh, stride_vn,
319
+ stride_mu_h,
320
+ stride_phi_h,
321
+ stride_mb, stride_mn,
322
+ stride_d_rfa_k_b, stride_d_rfa_k_h, stride_d_rfa_k_c,
323
+ stride_d_rfa_v_b, stride_d_rfa_v_h, stride_d_rfa_v_c,
324
+ stride_d_k_b, stride_d_k_h, stride_d_k_n,
325
+ stride_d_v_b, stride_d_v_h, stride_d_v_n,
326
+ stride_d_mu_b, stride_d_mu_h, stride_d_mu_g,
327
+ stride_d_phi_b, stride_d_phi_h, stride_d_phi_g,
328
+ nheads,
329
+ seqlen,
330
+ nchunks,
331
+ headdim,
332
+ CHUNKS_PER_BLOCK: tl.constexpr,
333
+ CHUNK_SIZE: tl.constexpr,
334
+ MASK_TYPE: tl.constexpr,
335
+ BLOCK_HEADDIM: tl.constexpr,
336
+ EVEN_N: tl.constexpr,
337
+ EVEN_HEADDIM: tl.constexpr,
338
+ BLOCK_N: tl.constexpr,
339
+ ):
340
+ start_n = tl.program_id(0)
341
+ offs_bh = tl.program_id(1)
342
+ offs_h = offs_bh % nheads
343
+ offs_b = offs_bh // nheads
344
+ # initialize offsets
345
+ # we load BLOCK_N keys and values each time, and
346
+ # reshape it to [CHUNKS_PER_BLOCK, CHUNK_SIZE]
347
+ offs_c = tl.arange(0, CHUNKS_PER_BLOCK)
348
+ offs_m = tl.arange(0, CHUNK_SIZE)
349
+ offs_d = tl.arange(0, BLOCK_HEADDIM)
350
+
351
+ offs_rfa_c = start_n * CHUNKS_PER_BLOCK + offs_c
352
+
353
+ k_ptrs = (
354
+ K +
355
+ offs_b * stride_kb +
356
+ offs_h * stride_kh +
357
+ (
358
+ (
359
+ start_n * BLOCK_N +
360
+ offs_c[:, None, None] * CHUNK_SIZE +
361
+ offs_m[None, :, None]
362
+ ) * stride_kn +
363
+ offs_d[None, None, :]
364
+ )
365
+ )
366
+ rfa_k_ptrs = (
367
+ RFA_K +
368
+ offs_b * stride_rfa_k_b +
369
+ offs_h * stride_rfa_k_h +
370
+ (offs_rfa_c[:, None] * stride_rfa_k_c + offs_d[None, :])
371
+ )
372
+ rfa_v_ptrs = (
373
+ RFA_V +
374
+ offs_b * stride_rfa_v_b +
375
+ offs_h * stride_rfa_v_h +
376
+ (offs_rfa_c[:, None] * stride_rfa_v_c + offs_d[None, :])
377
+ )
378
+
379
+ d_rfa_k_ptrs = (
380
+ D_RFA_K +
381
+ offs_b * stride_d_rfa_k_b +
382
+ offs_h * stride_d_rfa_k_h +
383
+ (offs_rfa_c[:, None] * stride_d_rfa_k_c + offs_d[None, :])
384
+ )
385
+ d_rfa_v_ptrs = (
386
+ D_RFA_V +
387
+ offs_b * stride_d_rfa_v_b +
388
+ offs_h * stride_d_rfa_v_h +
389
+ (offs_rfa_c[:, None] * stride_d_rfa_v_c + offs_d[None, :])
390
+ )
391
+
392
+ param_mu_ptrs = (
393
+ PARAM_MU +
394
+ offs_h * stride_mu_h +
395
+ offs_d[None, None, :]
396
+ )
397
+ param_phi_ptrs = (
398
+ PARAM_PHI +
399
+ offs_h * stride_phi_h +
400
+ offs_d[None, None, :]
401
+ )
402
+
403
+ log2e = 1.4426950408889634
404
+ if MASK_TYPE == 1:
405
+ m_ptrs = (
406
+ Mask +
407
+ offs_b * stride_mb +
408
+ (
409
+ (
410
+ start_n * BLOCK_N +
411
+ offs_c[:, None] * CHUNK_SIZE +
412
+ offs_m[None, :]
413
+ ) * stride_mn
414
+ )
415
+ )
416
+ if EVEN_N:
417
+ if EVEN_HEADDIM:
418
+ k = tl.load(
419
+ k_ptrs
420
+ )
421
+ else:
422
+ k = tl.load(
423
+ k_ptrs,
424
+ mask=offs_d[None, None, :] < headdim,
425
+ other=0.0
426
+ )
427
+ else:
428
+ if EVEN_HEADDIM:
429
+ k = tl.load(
430
+ k_ptrs,
431
+ mask=(
432
+ start_n * BLOCK_N +
433
+ offs_c[:, None, None] * CHUNK_SIZE +
434
+ offs_m[None, :, None]
435
+ ) < seqlen,
436
+ other=0.0
437
+ )
438
+ else:
439
+ k = tl.load(
440
+ k_ptrs,
441
+ mask=(
442
+ (
443
+ start_n * BLOCK_N +
444
+ offs_c[:, None, None] * CHUNK_SIZE +
445
+ offs_m[None, :, None]
446
+ ) < seqlen
447
+ ) & (offs_d[None, None, :] < headdim),
448
+ other=0.0
449
+ )
450
+
451
+ if EVEN_N:
452
+ if EVEN_HEADDIM:
453
+ rfa_k = tl.load(
454
+ rfa_k_ptrs
455
+ )
456
+ else:
457
+ rfa_k = tl.load(
458
+ rfa_k_ptrs,
459
+ mask=offs_d[None, :] < headdim,
460
+ other=0.0
461
+ )
462
+ else:
463
+ if EVEN_HEADDIM:
464
+ rfa_k = tl.load(
465
+ rfa_k_ptrs,
466
+ mask=offs_rfa_c[:, None] < nchunks,
467
+ other=0.0
468
+ )
469
+ else:
470
+ rfa_k = tl.load(
471
+ rfa_k_ptrs,
472
+ mask=(offs_rfa_c[:, None] < nchunks) & (offs_d[None, :] < headdim),
473
+ other=0.0
474
+ )
475
+
476
+ if EVEN_N:
477
+ if EVEN_HEADDIM:
478
+ d_rfa_k = tl.load(
479
+ d_rfa_k_ptrs
480
+ )
481
+ else:
482
+ d_rfa_k = tl.load(
483
+ d_rfa_k_ptrs,
484
+ mask=offs_d[None, :] < headdim,
485
+ other=0.0
486
+ )
487
+ else:
488
+ if EVEN_HEADDIM:
489
+ d_rfa_k = tl.load(
490
+ d_rfa_k_ptrs,
491
+ mask=offs_rfa_c[:, None] < nchunks,
492
+ other=0.0
493
+ )
494
+ else:
495
+ d_rfa_k = tl.load(
496
+ d_rfa_k_ptrs,
497
+ mask=(offs_rfa_c[:, None] < nchunks) & (offs_d[None, :] < headdim),
498
+ other=0.0
499
+ )
500
+
501
+ param_mu = tl.load(param_mu_ptrs).to(k.dtype)
502
+ mu_c_w = tl.zeros([CHUNKS_PER_BLOCK, CHUNK_SIZE], dtype=tl.float32)
503
+ mu_c_w += tl.sum(k * param_mu, axis=-1)
504
+ mu_c_w *= log2e
505
+
506
+ if not EVEN_N: # Need to mask out otherwise the softmax is wrong
507
+ mu_c_w += tl.where(
508
+ (
509
+ start_n * BLOCK_N +
510
+ offs_c[:, None] * CHUNK_SIZE +
511
+ offs_m[None, :]
512
+ ) < seqlen,
513
+ 0,
514
+ float("-inf")
515
+ )
516
+
517
+ if MASK_TYPE == 1:
518
+ if EVEN_N:
519
+ mask = tl.load(
520
+ m_ptrs
521
+ )
522
+ else:
523
+ mask = tl.load(
524
+ m_ptrs,
525
+ mask=(
526
+ start_n * BLOCK_N +
527
+ offs_c[:, None] * CHUNK_SIZE +
528
+ offs_m[None, :]
529
+ ) < seqlen,
530
+ other=1,
531
+ )
532
+ mu_c_w = tl.where(mask, float("-inf"), mu_c_w)
533
+
534
+ # [c, w]
535
+ m_mu_c_w = tl.max(mu_c_w, axis=-1)
536
+ masked_out_rows_mu = (m_mu_c_w == float("-inf"))
537
+ m_mu_c_w_masked = tl.where(masked_out_rows_mu, 0, m_mu_c_w)
538
+ mu_c_w = tl.exp2(mu_c_w - m_mu_c_w_masked[:, None])
539
+ denom_mu = tl.sum(mu_c_w, axis=-1)
540
+ denom_mu = tl.where(denom_mu == 0.0, 1.0, denom_mu)
541
+ mu_tilde_c_w = mu_c_w / denom_mu[:, None]
542
+ mu_tilde_c_w = mu_tilde_c_w.to(k.dtype)
543
+ # [c, d] [c, w, d] -> [c, w]
544
+ d_mu_tilde_c_w = tl.sum(d_rfa_k[:, None, :] * k, axis=-1)
545
+ # [c, d] [c, d] -> [c]
546
+ d_out_rfa_k_t_rfa_k = tl.sum(d_rfa_k * rfa_k, axis=-1)[:, None]
547
+ d_mu_c_w = (d_mu_tilde_c_w - d_out_rfa_k_t_rfa_k) * mu_tilde_c_w
548
+
549
+ # [c, w] [c, w, d] -> [d]
550
+ d_param_mu = tl.sum(tl.sum(d_mu_c_w[:, :, None] * k, axis=0), axis=0)
551
+ # [c, w] [c, d] + [c, w] [1, 1, d] -> [c, w, d]
552
+ d_k = mu_tilde_c_w[:, :, None] * d_rfa_k[:, None, :] + d_mu_c_w[:, :, None] * param_mu
553
+
554
+ d_param_mu_partial_ptrs = (
555
+ D_PARAM_MU_PARTIAL +
556
+ offs_b * stride_d_mu_b +
557
+ offs_h * stride_d_mu_h +
558
+ start_n * stride_d_mu_g +
559
+ offs_d
560
+ )
561
+ if EVEN_HEADDIM:
562
+ tl.store(
563
+ d_param_mu_partial_ptrs, d_param_mu
564
+ )
565
+ else:
566
+ tl.store(
567
+ d_param_mu_partial_ptrs, d_param_mu,
568
+ mask=offs_d < headdim
569
+ )
570
+
571
+
572
+ v_ptrs = (
573
+ V +
574
+ offs_b * stride_vb +
575
+ offs_h * stride_vh +
576
+ (
577
+ (
578
+ start_n * BLOCK_N +
579
+ offs_c[:, None, None] * CHUNK_SIZE +
580
+ offs_m[None, :, None]
581
+ ) * stride_vn +
582
+ offs_d[None, None, :]
583
+ )
584
+ )
585
+ if EVEN_N:
586
+ if EVEN_HEADDIM:
587
+ v = tl.load(
588
+ v_ptrs
589
+ )
590
+ else:
591
+ v = tl.load(
592
+ v_ptrs,
593
+ mask=offs_d[None, None, :] < headdim,
594
+ other=0.0
595
+ )
596
+ else:
597
+ if EVEN_HEADDIM:
598
+ v = tl.load(
599
+ v_ptrs,
600
+ mask=(
601
+ start_n * BLOCK_N +
602
+ offs_c[:, None, None] * CHUNK_SIZE +
603
+ offs_m[None, :, None]
604
+ ) < seqlen,
605
+ other=0.0
606
+ )
607
+ else:
608
+ v = tl.load(
609
+ v_ptrs,
610
+ mask=(
611
+ (
612
+ start_n * BLOCK_N +
613
+ offs_c[:, None, None] * CHUNK_SIZE +
614
+ offs_m[None, :, None]
615
+ ) < seqlen
616
+ ) & (offs_d[None, None, :] < headdim),
617
+ other=0.0
618
+ )
619
+
620
+
621
+ if EVEN_N:
622
+ if EVEN_HEADDIM:
623
+ rfa_v = tl.load(
624
+ rfa_v_ptrs
625
+ )
626
+ else:
627
+ rfa_v = tl.load(
628
+ rfa_v_ptrs,
629
+ mask=offs_d[None, :] < headdim,
630
+ other=0.0
631
+ )
632
+ else:
633
+ if EVEN_HEADDIM:
634
+ rfa_v = tl.load(
635
+ rfa_v_ptrs,
636
+ mask=offs_rfa_c[:, None] < nchunks,
637
+ other=0.0
638
+ )
639
+ else:
640
+ rfa_v = tl.load(
641
+ rfa_v_ptrs,
642
+ mask=(offs_rfa_c[:, None] < nchunks) & (offs_d[None, :] < headdim),
643
+ other=0.0
644
+ )
645
+
646
+ if EVEN_N:
647
+ if EVEN_HEADDIM:
648
+ d_rfa_v = tl.load(
649
+ d_rfa_v_ptrs
650
+ )
651
+ else:
652
+ d_rfa_v = tl.load(
653
+ d_rfa_v_ptrs,
654
+ mask=offs_d[None, :] < headdim,
655
+ other=0.0
656
+ )
657
+ else:
658
+ if EVEN_HEADDIM:
659
+ d_rfa_v = tl.load(
660
+ d_rfa_v_ptrs,
661
+ mask=offs_rfa_c[:, None] < nchunks,
662
+ other=0.0
663
+ )
664
+ else:
665
+ d_rfa_v = tl.load(
666
+ d_rfa_v_ptrs,
667
+ mask=(offs_rfa_c[:, None] < nchunks) & (offs_d[None, :] < headdim),
668
+ other=0.0
669
+ )
670
+
671
+ param_phi = tl.load(param_phi_ptrs).to(k.dtype)
672
+ phi_c_w = tl.zeros([CHUNKS_PER_BLOCK, CHUNK_SIZE], dtype=tl.float32)
673
+ phi_c_w += tl.sum(k * param_phi, axis=-1)
674
+ phi_c_w -= (0.5 * tl.sum(k * k, axis=-1))
675
+ phi_c_w *= log2e * softmax_scale
676
+ if not EVEN_N: # Need to mask out otherwise the softmax is wrong
677
+ phi_c_w += tl.where(
678
+ (
679
+ start_n * BLOCK_N +
680
+ offs_c[:, None] * CHUNK_SIZE +
681
+ offs_m[None, :]
682
+ ) < seqlen,
683
+ 0,
684
+ float("-inf")
685
+ )
686
+
687
+ if MASK_TYPE == 1:
688
+ phi_c_w = tl.where(mask, float("-inf"), phi_c_w)
689
+
690
+
691
+ m_phi_c_w = tl.max(phi_c_w, axis=-1)
692
+ masked_out_rows_phi = (m_phi_c_w == float("-inf"))
693
+ m_phi_c_w_masked = tl.where(masked_out_rows_phi, 0, m_phi_c_w)
694
+ phi_c_w = tl.exp2(phi_c_w - m_phi_c_w_masked[:, None])
695
+ denom_phi = tl.sum(phi_c_w, axis=-1)
696
+ denom_phi = tl.where(denom_phi == 0.0, 1.0, denom_phi)
697
+ phi_tilde_c_w = phi_c_w / denom_phi[:, None]
698
+ # phi_c_w = tl.exp2(phi_c_w - tl.max(phi_c_w, axis=-1)[:, None])
699
+ # phi_tilde_c_w = phi_c_w / tl.sum(phi_c_w, axis=-1)[:, None]
700
+ phi_tilde_c_w = phi_tilde_c_w.to(k.dtype)
701
+ d_phi_tilde_c_w = tl.sum(d_rfa_v[:, None, :] * v, axis=-1)
702
+ d_out_rfa_v_t_rfa_v = tl.sum(d_rfa_v * rfa_v, axis=-1)[:, None]
703
+ d_phi_c_w = (d_phi_tilde_c_w.to(tl.float32) - d_out_rfa_v_t_rfa_v.to(tl.float32)) * phi_tilde_c_w
704
+
705
+ d_param_phi = tl.sum(tl.sum(d_phi_c_w[:, :, None] * k * softmax_scale, axis=0), axis=0)
706
+ d_v = phi_tilde_c_w[:, :, None] * d_rfa_v[:, None, :]
707
+ # [c, w, d] + [c, w] * [1, 1, d] - [c, w, d]
708
+ d_k = d_k + softmax_scale * d_phi_c_w[:, :, None] * (param_phi - k)
709
+
710
+ d_k_ptrs = (
711
+ D_K +
712
+ offs_b * stride_d_k_b +
713
+ offs_h * stride_d_k_h +
714
+ (
715
+ (
716
+ start_n * BLOCK_N +
717
+ offs_c[:, None, None] * CHUNK_SIZE +
718
+ offs_m[None, :, None]
719
+ ) * stride_d_k_n +
720
+ offs_d[None, None, :]
721
+ )
722
+ )
723
+ d_v_ptrs = (
724
+ D_V +
725
+ offs_b * stride_d_v_b +
726
+ offs_h * stride_d_v_h +
727
+ (
728
+ (
729
+ start_n * BLOCK_N +
730
+ offs_c[:, None, None] * CHUNK_SIZE +
731
+ offs_m[None, :, None]
732
+ ) * stride_d_v_n +
733
+ offs_d[None, None, :]
734
+ )
735
+ )
736
+ if EVEN_N:
737
+ if EVEN_HEADDIM:
738
+ tl.store(
739
+ d_k_ptrs, d_k
740
+ )
741
+ tl.store(
742
+ d_v_ptrs, d_v
743
+ )
744
+ else:
745
+ tl.store(
746
+ d_k_ptrs, d_k,
747
+ mask=offs_d[None, None, :] < headdim
748
+ )
749
+ tl.store(
750
+ d_v_ptrs, d_v,
751
+ mask=offs_d[None, None, :] < headdim
752
+ )
753
+ else:
754
+ if EVEN_HEADDIM:
755
+ tl.store(
756
+ d_k_ptrs, d_k,
757
+ mask=(
758
+ (
759
+ start_n * BLOCK_N +
760
+ offs_c[:, None, None] * CHUNK_SIZE +
761
+ offs_m[None, :, None]
762
+ ) < seqlen
763
+ ),
764
+ )
765
+ tl.store(
766
+ d_v_ptrs, d_v,
767
+ mask=(
768
+ (
769
+ start_n * BLOCK_N +
770
+ offs_c[:, None, None] * CHUNK_SIZE +
771
+ offs_m[None, :, None]
772
+ ) < seqlen
773
+ ),
774
+ )
775
+ else:
776
+ tl.store(
777
+ d_k_ptrs, d_k,
778
+ mask=(
779
+ (
780
+ start_n * BLOCK_N +
781
+ offs_c[:, None, None] * CHUNK_SIZE +
782
+ offs_m[None, :, None]
783
+ ) < seqlen
784
+ ) & (offs_d[None, None, :] < headdim),
785
+ )
786
+ tl.store(
787
+ d_v_ptrs, d_v,
788
+ mask=(
789
+ (
790
+ start_n * BLOCK_N +
791
+ offs_c[:, None, None] * CHUNK_SIZE +
792
+ offs_m[None, :, None]
793
+ ) < seqlen
794
+ ) & (offs_d[None, None, :] < headdim),
795
+ )
796
+ d_param_phi_partial_ptrs = (
797
+ D_PARAM_PHI_PARTIAL +
798
+ offs_b * stride_d_phi_b +
799
+ offs_h * stride_d_phi_h +
800
+ start_n * stride_d_phi_g +
801
+ offs_d
802
+ )
803
+ if EVEN_HEADDIM:
804
+ tl.store(
805
+ d_param_phi_partial_ptrs, d_param_phi
806
+ )
807
+ else:
808
+ tl.store(
809
+ d_param_phi_partial_ptrs, d_param_phi,
810
+ mask=offs_d < headdim
811
+ )
812
+
813
+ def triton_eva_prep_kv_fwd(k, v, param_mu, param_phi, mask, softmax_scale, chunksize):
814
  k, v, param_mu, param_phi = [
815
  x if x.stride(-1) == 1 else x.contiguous()
816
  for x in [k, v, param_mu, param_phi]
 
831
  softmax_scale = softmax_scale or 1.0 / math.sqrt(head_dim)
832
 
833
  mask_type = 0
834
+ if mask is not None:
835
  mask_type = 1
836
+ assert mask.dtype == torch.bool
837
+ assert mask.is_cuda
838
+ assert mask.dim() == 4
839
+ assert mask.shape == (batch, 1, seqlen, 1)
840
+ if mask.stride(-1) != 1:
841
+ mask = mask.contiguous()
842
  mask_strides = (
843
+ (mask.stride(0), mask.stride(2))
844
  if mask_type == 1 else
845
  (0, 0)
846
  )
 
860
  v,
861
  param_mu,
862
  param_phi,
863
+ mask,
864
  out_rfa_k,
865
  out_rfa_v,
866
  softmax_scale,
 
875
  seqlen,
876
  nchunks,
877
  head_dim,
 
 
878
  chunks_per_block,
879
  chunksize,
880
  mask_type,
 
884
  num_stages=1,
885
  )
886
  return out_rfa_k, out_rfa_v
887
+
888
+ def triton_eva_prep_kv_bwd(
889
+ d_rfa_k, d_rfa_v,
890
+ k, v, param_mu, param_phi,
891
+ mask,
892
+ rfa_k, rfa_v,
893
+ d_k, d_v, d_param_mu, d_param_phi,
894
+ softmax_scale,
895
+ mask_type,
896
+ chunksize
897
+ ):
898
+ d_rfa_k, d_rfa_v = [
899
+ x if x.stride(-1) == 1 else x.contiguous()
900
+ for x in [d_rfa_k, d_rfa_v]
901
+ ]
902
+
903
+ # shape constraints
904
+ batch, nheads, seqlen, head_dim = k.shape
905
+ assert seqlen % chunksize == 0, "seqlen must be divisible by chunksize"
906
+ nchunks = seqlen // chunksize
907
+ softmax_scale = softmax_scale or 1.0 / math.sqrt(head_dim)
908
+
909
+ mask_strides = (
910
+ (mask.stride(0), mask.stride(2))
911
+ if mask_type == 1 else
912
+ (0, 0)
913
+ )
914
+
915
+ BLOCK_HEADDIM = max(triton.next_power_of_2(head_dim), 16)
916
+ BLOCK = 128
917
+ num_warps = 4 if head_dim <= 64 else 8
918
+
919
+ assert (BLOCK > chunksize) and (BLOCK % chunksize == 0), "BLOCK must be larger than and divisible by chunksize"
920
+ chunks_per_block = BLOCK // chunksize
921
+
922
+ partial_groups = triton.cdiv(seqlen, BLOCK)
923
+ d_param_mu_partial = torch.zeros((batch, nheads, partial_groups, head_dim), dtype=torch.float32, device=d_rfa_k.device)
924
+ d_param_phi_partial = torch.zeros((batch, nheads, partial_groups, head_dim), dtype=torch.float32, device=d_rfa_k.device)
925
+ grid = lambda META: (partial_groups, batch * nheads)
926
+ _bwd_eva_prep_kv_kernel[grid](
927
+ rfa_k, # [b, h, c, d]
928
+ rfa_v, # [b, h, c, d]
929
+ k, # [b, h, n, d]
930
+ v, # [b, h, n, d]
931
+ param_mu, # [1, h, 1, 1, d]
932
+ param_phi, # [1, h, 1, 1, d]
933
+ mask, # [b, h, n, 1]
934
+ d_rfa_k, # [b, h, c, d]
935
+ d_rfa_v, # [b, h, c, d]
936
+ d_k, # [b, h, n, d]
937
+ d_v, # [b, h, n, d]
938
+ d_param_mu_partial, # [b, h, g, d]
939
+ d_param_phi_partial, # [b, h, g, d]
940
+ softmax_scale,
941
+ rfa_k.stride(0), rfa_k.stride(1), rfa_k.stride(2),
942
+ rfa_v.stride(0), rfa_v.stride(1), rfa_v.stride(2),
943
+ k.stride(0), k.stride(1), k.stride(2),
944
+ v.stride(0), v.stride(1), v.stride(2),
945
+ param_mu.stride(1),
946
+ param_phi.stride(1),
947
+ mask_strides[0], mask_strides[1],
948
+ d_rfa_k.stride(0), d_rfa_k.stride(1), d_rfa_k.stride(2),
949
+ d_rfa_v.stride(0), d_rfa_v.stride(1), d_rfa_v.stride(2),
950
+ d_k.stride(0), d_k.stride(1), d_k.stride(2),
951
+ d_v.stride(0), d_v.stride(1), d_v.stride(2),
952
+ d_param_mu_partial.stride(0), d_param_mu_partial.stride(1), d_param_mu_partial.stride(2),
953
+ d_param_phi_partial.stride(0), d_param_phi_partial.stride(1), d_param_phi_partial.stride(2),
954
+ nheads,
955
+ seqlen,
956
+ nchunks,
957
+ head_dim,
958
+ chunks_per_block,
959
+ chunksize,
960
+ mask_type,
961
+ BLOCK_HEADDIM,
962
+ BLOCK_N=BLOCK,
963
+ num_warps=num_warps,
964
+ num_stages=1,
965
+ )
966
+ d_param_mu.copy_(d_param_mu_partial.sum(dim=(0, -2), keepdim=True).unsqueeze(-2).to(d_param_mu.dtype))
967
+ d_param_phi.copy_(d_param_phi_partial.sum(dim=(0, -2), keepdim=True).unsqueeze(-2).to(d_param_phi.dtype))
968
+
969
+
970
+
971
+ class EvaPrepKVFunc(torch.autograd.Function):
972
+ @staticmethod
973
+ def forward(ctx, k, v, param_mu, param_phi, mask, softmax_scale=None, chunksize=None):
974
+ if mask is not None:
975
+ mask_type = 1
976
+ else:
977
+ mask_type = 0
978
+ rfa_k, rfa_v = triton_eva_prep_kv_fwd(
979
+ k, v, param_mu, param_phi, mask, softmax_scale, chunksize
980
+ )
981
+ ctx.save_for_backward(k, v, param_mu, param_phi, mask, rfa_k, rfa_v)
982
+ ctx.softmax_scale = softmax_scale
983
+ ctx.chunksize = chunksize
984
+ ctx.mask_type = mask_type
985
+ return rfa_k, rfa_v
986
+
987
+ @staticmethod
988
+ def backward(ctx, d_rfa_k, d_rfa_v):
989
+ k, v, param_mu, param_phi, mask, rfa_k, rfa_v = ctx.saved_tensors
990
+ d_k = torch.empty_like(k)
991
+ d_v = torch.empty_like(v)
992
+ d_param_mu = torch.empty_like(param_mu)
993
+ d_param_phi = torch.empty_like(param_phi)
994
+ triton_eva_prep_kv_bwd(
995
+ d_rfa_k, d_rfa_v,
996
+ k, v, param_mu, param_phi,
997
+ mask,
998
+ rfa_k, rfa_v,
999
+ d_k, d_v, d_param_mu, d_param_phi,
1000
+ ctx.softmax_scale,
1001
+ ctx.mask_type,
1002
+ ctx.chunksize
1003
+ )
1004
+ return d_k, d_v, d_param_mu, d_param_phi, None, None, None
1005
+
1006
+ def eva_prep_kv_func_triton(
1007
+ k, v,
1008
+ param_mu, param_phi,
1009
+ mask,
1010
+ softmax_scale=None, chunksize=None
1011
+ ):
1012
+ return EvaPrepKVFunc.apply(
1013
+ k, v,
1014
+ param_mu, param_phi,
1015
+ mask,
1016
+ softmax_scale, chunksize
1017
+ )
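A minimal forward/backward sketch of the prep-KV wrapper; shapes follow the [b, h, n, d] and [1, h, 1, 1, d] comments, while the sizes, device, and import path are assumptions.

# Hedged usage sketch for eva_prep_kv_func_triton (illustrative sizes and import path).
import torch
from eva_prep_kv_kernel import eva_prep_kv_func_triton  # adjust to the package layout of your checkout

b, h, n, d, chunk_size = 2, 4, 1024, 64, 16
k = torch.randn(b, h, n, d, device="cuda", dtype=torch.bfloat16, requires_grad=True)
v = torch.randn_like(k, requires_grad=True)
param_mu = torch.randn(1, h, 1, 1, d, device="cuda", dtype=torch.bfloat16, requires_grad=True)
param_phi = torch.randn_like(param_mu, requires_grad=True)
mask = None                                              # or a [b, 1, n, 1] bool mask, True = masked out

rfa_k, rfa_v = eva_prep_kv_func_triton(k, v, param_mu, param_phi, mask,
                                       softmax_scale=d ** -0.5, chunksize=chunk_size)
(rfa_k.float().sum() + rfa_v.float().sum()).backward()
# d_param_mu / d_param_phi are reduced from per-block partials back to the [1, h, 1, 1, d] parameter shape.
print(rfa_k.shape, param_mu.grad.shape)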
eva_pt_ref.py CHANGED
@@ -263,7 +263,6 @@ class EvaAttention(nn.Module):
263
  v,
264
  self.layer_idx,
265
  self.window_size,
266
- self.singleton_update
267
  )
268
  else:
269
  prev_w_q = self.window_partition(q) # [b, h, w, i, d]
@@ -289,10 +288,9 @@ class EvaAttention(nn.Module):
289
  layer_idx=self.layer_idx,
290
  window_size=self.window_size,
291
  chunk_size=self.chunk_size,
292
- singleton_update=self.singleton_update
293
  )
294
  else:
295
- prev_s_mask = window_causal_mask # [1, 1, w, i, j]
296
  cur_s_mask = None
297
  prev_chunk_mask = self.window_partition(chunk_causal_mask)
298
  cur_chunk_mask = None
 
263
  v,
264
  self.layer_idx,
265
  self.window_size,
 
266
  )
267
  else:
268
  prev_w_q = self.window_partition(q) # [b, h, w, i, d]
 
288
  layer_idx=self.layer_idx,
289
  window_size=self.window_size,
290
  chunk_size=self.chunk_size,
 
291
  )
292
  else:
293
+ prev_s_mask = self.window_partition(prev_causal_mask) # [1, 1, w, i, j]
294
  cur_s_mask = None
295
  prev_chunk_mask = self.window_partition(chunk_causal_mask)
296
  cur_chunk_mask = None
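For context, window_partition is defined elsewhere in eva_pt_ref.py; based on the [1, 1, w, i, j] shape comment above, it is assumed to split the sequence axis into windows roughly as in this hypothetical sketch:

# Hypothetical sketch of the window_partition reshape (assumption inferred from the shape comments).
import torch

def window_partition_sketch(x, window_size):
    *lead, n, last = x.shape
    return x.reshape(*lead, n // window_size, window_size, last)   # [.., n, j] -> [.., w, i, j]

prev_causal_mask = torch.ones(1, 1, 512, 512).tril().bool().logical_not()   # illustrative mask
print(window_partition_sketch(prev_causal_mask, 256).shape)                 # torch.Size([1, 1, 2, 256, 512])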
modeling_evabyte.py CHANGED
@@ -148,7 +148,7 @@ class EvaByteRMSNorm(nn.Module):
148
  def __init__(self, config):
149
  super().__init__()
150
  self.config = config
151
- self.fp32_ln = config.fp32_ln
152
  self.variance_epsilon = config.rms_norm_eps
153
  self.add_unit_offset = config.norm_add_unit_offset
154
  if self.add_unit_offset:
@@ -157,18 +157,14 @@ class EvaByteRMSNorm(nn.Module):
157
  self.weight = nn.Parameter(torch.ones(config.hidden_size))
158
 
159
  def forward(self, hidden_states):
160
- if hasattr(self, 'config'):
161
- fp32_ln = self.config.fp32_ln
162
- else:
163
- fp32_ln = self.fp32_ln
164
- hidden_states = hidden_states.to(torch.float32 if fp32_ln else torch.bfloat16)
165
 
166
- variance = hidden_states.pow(2).mean(-1, keepdim=True)
167
- hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
168
  if self.add_unit_offset:
169
- return (1 + self.weight) * hidden_states
170
  else:
171
- return self.weight * hidden_states
172
 
173
  class EvaByteRotaryEmbedding(torch.nn.Module):
174
  def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
@@ -313,7 +309,7 @@ class EvaByteDecoderLayer(nn.Module):
313
  cos=cos,
314
  sin=sin,
315
  multibyte_decoding=multibyte_decoding)
316
- hidden_states = residual + hidden_states
317
 
318
  # Fully Connected
319
  residual = hidden_states
@@ -321,7 +317,7 @@ class EvaByteDecoderLayer(nn.Module):
321
  residual = residual.float()
322
  hidden_states = self.post_attention_layernorm(hidden_states)
323
  hidden_states = self.mlp(hidden_states)
324
- hidden_states = residual + hidden_states
325
 
326
  outputs = (hidden_states, )
327
 
@@ -653,7 +649,7 @@ class EvaByteModel(EvaBytePreTrainedModel):
653
  )
654
  else:
655
  assert self.training
656
- assert seq_len % self.config.window_size == 0
657
  # for training, we need to pass in the attention mask
658
  # usually calculated by _prepare_training_attn_mask()
659
  causal_mask = attention_mask
@@ -683,31 +679,6 @@ class EvaByteModel(EvaBytePreTrainedModel):
683
  cos = cos.unsqueeze(1)
684
  sin = sin.unsqueeze(1)
685
 
686
- if USE_TRITON_IMPL and (not multibyte_decoding):
687
- # the masks generated above for triton kernels are boolean. Convert them to floats
688
- if (
689
- (not use_cache) or
690
- (use_cache and past_seen_tokens == 0)
691
- ):
692
- window_mask, intra_chunk_mask = causal_mask
693
-
694
- if window_mask is not None:
695
- assert window_mask.dtype == torch.bool
696
- window_mask_float = window_mask.to(torch.float)
697
- window_mask_float = window_mask_float.masked_fill(window_mask.to(torch.bool), MASK_MIN_VALUE)
698
- window_mask_float = window_mask_float.reshape(batch_size, 1, -1, self.config.window_size)
699
- window_mask = window_mask_float.to(hidden_states.dtype)
700
-
701
- if intra_chunk_mask is not None:
702
- assert intra_chunk_mask.dtype == torch.bool
703
- intra_chunk_mask_float = intra_chunk_mask.to(torch.float)
704
- intra_chunk_mask_float = intra_chunk_mask_float.masked_fill(intra_chunk_mask.to(torch.bool), MASK_MIN_VALUE)
705
- intra_chunk_mask = intra_chunk_mask_float.to(hidden_states.dtype)
706
- causal_mask = (window_mask, intra_chunk_mask)
707
-
708
- if self.config.fp32_skip_add:
709
- hidden_states = hidden_states.float()
710
-
711
  # decoder layers
712
  all_hidden_states = () if output_hidden_states else None
713
  all_self_attns = () if output_attentions else None
@@ -718,20 +689,17 @@ class EvaByteModel(EvaBytePreTrainedModel):
718
  all_hidden_states += (hidden_states, )
719
 
720
  if self.gradient_checkpointing and self.training:
721
-
722
- def create_custom_forward(module):
723
- def custom_forward(*inputs):
724
- # None for past_key_value
725
- return module(*inputs, output_attentions, use_cache=None)
726
-
727
- return custom_forward
728
-
729
  layer_outputs = torch.utils.checkpoint.checkpoint(
730
- create_custom_forward(decoder_layer),
731
  hidden_states,
732
  causal_mask,
733
  position_ids,
734
- None,
735
  )
736
  else:
737
  layer_outputs = decoder_layer(
@@ -806,154 +774,6 @@ class EvaByteForCausalLM(EvaBytePreTrainedModel, MultiByteDecodingMixin):
806
  def get_decoder(self):
807
  return self.model
808
 
809
- def _prepare_training_attn_mask(
810
- self,
811
- target_token_type_ids,
812
- use_doc_boundary_attention,
813
- EOS_TOKEN_TYPE_ID=None,
814
- PAD_TOKEN_TYPE_ID=None,
815
- ):
816
- '''
817
- This function prepares the attention mask for training byte models.
818
- target_token_type_ids:
819
- Tensor of shape (batch_size, seq_len), marking the token type ids
820
- for the target sequence. In particular, we should have
821
- - target_token_type_ids[i, j] = EOS_TOKEN_TYPE_ID
822
- if the j-th token in the i-th sequence is the end of an article.
823
- - target_token_type_ids[i, j] = PAD_TOKEN_TYPE_ID
824
- if the j-th token in the i-th sequence is the padding token.
825
- use_doc_boundary_attention: bool,
826
- whether to enable doc boundary attention.
827
- EOS_TOKEN_TYPE_ID: int,
828
- the token type id for the end of an article.
829
- PAD_TOKEN_TYPE_ID: int,
830
- the token type id for the padding token.
831
- '''
832
- assert self.training
833
- batch_size, num_tokens = target_token_type_ids.shape
834
-
835
- chunk_causal_mask, window_causal_mask = prepare_eva_attention_mask(
836
- num_tokens,
837
- target_token_type_ids.device,
838
- chunk_size=self.config.chunk_size,
839
- window_size=self.config.window_size,
840
- use_cache=False,
841
- cache=None
842
- )
843
- if use_doc_boundary_attention:
844
- #### step 1: mark each document with a unique id
845
- end_token_ids = {EOS_TOKEN_TYPE_ID, PAD_TOKEN_TYPE_ID}
846
- token_types = torch.zeros(batch_size, num_tokens)
847
- for sequence_idx, sequence in enumerate(target_token_type_ids):
848
- num_articles = 0
849
- start_index = 0
850
- # for each sample in the batch, the collapsed attention mask looks like:
851
- # [1, 1, .... 1, 0, 2, 2, ... 2, 0, ... n, n ..... n], assuming there are n articles in the sequence.
852
- # Each of the n articles are separated by 0.
853
- for token_idx, token_type_id in enumerate(sequence):
854
- if start_index is not None and token_type_id.item() in end_token_ids:
855
- num_articles += 1
856
- end_index = token_idx if token_type_id == PAD_TOKEN_TYPE_ID else token_idx + 1
857
- token_types[sequence_idx][start_index:end_index] = num_articles
858
- start_index = None
859
- elif start_index is None and token_type_id not in end_token_ids:
860
- start_index = token_idx + 1
861
-
862
- assert num_tokens % self.config.chunk_size == 0, "Number of tokens must be divisible by chunk size"
863
- assert num_tokens % self.config.window_size == 0, "Number of tokens must be divisible by window size"
864
- num_chunks = num_tokens // self.config.chunk_size
865
- num_windows = num_tokens // self.config.window_size
866
-
867
- article_separator = 0
868
-
869
- #### step 2: generate attention masks for each window
870
- #### NOTE: we perform exact attention within each window,
871
- #### so we only need to mask out different documents
872
- #### for each window.
873
- token_types_windows = token_types.reshape(batch_size, num_windows, self.config.window_size, 1)
874
- token_types_windows_t = token_types_windows.transpose(-1, -2)
875
- # replace all elements in TOKEN_SEPS with -1
876
- token_types_windows = torch.where(token_types_windows == article_separator, -1, token_types_windows)
877
- window_3d_mask = (token_types_windows == token_types_windows_t)
878
- window_3d_mask = ~window_3d_mask
879
-
880
- #### step 3: generate chunk-level 3D masks
881
- #### NOTE: this is a bit tricky, as we aim to mask out different
882
- #### documents to avoid cross-doc attention across chunks.
883
- #### Example: suppose we have a sequence of length 12 with 3 documents:
884
- #### [1, 1, 1, 1, 1, 2, 2, 3, 3, 3, 3, 3].
885
- #### The chunk-size and window-size are both 4.
886
- #### The chunk-level mask of shape (batch_size, seq_len, num_chunks) is:
887
- #### [
888
- #### [0, 0, 0],
889
- #### [0, 0, 0],
890
- #### [0, 0, 0],
891
- #### [0, 0, 0],
892
- ####
893
- #### [1, 0, 0],
894
- #### [0, 0, 0],
895
- #### [0, 0, 0],
896
- #### [0, 0, 0],
897
- ####
898
- #### [0, 1, 0],
899
- #### [0, 1, 0],
900
- #### [0, 1, 0],
901
- #### [0, 1, 0],
902
- #### ]
903
- #### Explanation:
904
- #### - Tokens will not attend to their own and future chunks.
905
- #### (as tokens within a chunk are captured by the window-level exact attention)
906
- #### - Tokens will attend to a chunk only if there are tokens
907
- #### from the same document in that chunk.
908
- #### The mask within each chunk of shape (batch_size, num_chunks, chunk_size) is:
909
- #### [
910
- #### [1, 1, 1, 1],
911
- #### [0, 0, 0, 1],
912
- #### [1, 1, 1, 1],
913
- #### ]
914
- #### Explanation:
915
- #### - If all tokens in a chunk are from the same document,
916
- #### no tokens will be masked out.
917
- #### - If there are tokens from different documents in a chunk,
918
- #### only tokens from the rightmost document will be kept.
919
- #### (b/c the future chunks might contain tokens from the rightmost document,
920
- #### but all the remaining docs will never get attended by other docs)
921
- token_types_chunks = token_types.reshape(batch_size, num_chunks, self.config.chunk_size)
922
- inter_chunk_mask = torch.zeros((batch_size, num_tokens, num_chunks), dtype=torch.bool)
923
- intra_chunk_mask = torch.ones_like(token_types_chunks, dtype=torch.bool)
924
-
925
- for chunk_idx in range(num_chunks):
926
- for batch_idx in range(batch_size):
927
- # Identify tokens in the current chunk belonging to each sequence
928
- chunk = token_types_chunks[batch_idx, chunk_idx]
929
- unique_elements = torch.unique(chunk, sorted=True).tolist()
930
-
931
- # Create a mask for whether each token can attend to the current chunk
932
- for token_type in unique_elements:
933
- if token_type == article_separator:
934
- continue
935
- token_mask = (token_types[batch_idx] == token_type)
936
- inter_chunk_mask[batch_idx, :, chunk_idx] |= token_mask
937
-
938
- # Create a mask within each chunk
939
- unique_elements = [x for x in unique_elements if x != article_separator]
940
- if len(unique_elements) > 1 and chunk[-1] != article_separator:
941
- intra_chunk_mask[batch_idx, chunk_idx] = (chunk == unique_elements[-1])
942
-
943
- inter_chunk_mask = ~inter_chunk_mask
944
- intra_chunk_mask = ~intra_chunk_mask
945
-
946
- window_mask = torch.logical_or(window_causal_mask, window_3d_mask.unsqueeze(1))
947
- inter_chunk_mask = torch.logical_or(chunk_causal_mask, inter_chunk_mask.unsqueeze(1))
948
- intra_chunk_mask = intra_chunk_mask.unsqueeze(1).unsqueeze(-1)
949
-
950
- joint_mask = torch.cat([window_mask, inter_chunk_mask.reshape(*window_mask.shape)], dim=-1)
951
- attention_mask = (joint_mask, intra_chunk_mask)
952
- else:
953
- joint_mask = torch.cat([window_causal_mask, chunk_causal_mask.reshape(*window_causal_mask.shape)], dim=-1)
954
- attention_mask = (joint_mask, None)
955
- return attention_mask
956
-
957
  def forward(
958
  self,
959
  input_ids: torch.LongTensor = None,
 
148
  def __init__(self, config):
149
  super().__init__()
150
  self.config = config
151
+ self.fp32_ln = True
152
  self.variance_epsilon = config.rms_norm_eps
153
  self.add_unit_offset = config.norm_add_unit_offset
154
  if self.add_unit_offset:
 
157
  self.weight = nn.Parameter(torch.ones(config.hidden_size))
158
 
159
  def forward(self, hidden_states):
160
+ _hidden_states = hidden_states.to(torch.float32 if self.fp32_ln else torch.bfloat16)
 
 
 
 
161
 
162
+ variance = _hidden_states.pow(2).mean(-1, keepdim=True)
163
+ _hidden_states = _hidden_states * torch.rsqrt(variance + self.variance_epsilon)
164
  if self.add_unit_offset:
165
+ return ((1 + self.weight) * _hidden_states).type_as(hidden_states)
166
  else:
167
+ return (self.weight * _hidden_states).type_as(hidden_states)
168
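The rewritten norm computes its statistics in fp32 and casts the result back to the input dtype; a standalone sketch of that behavior (a re-implementation for illustration, not the module itself):

# Sketch of the fp32-compute / cast-back RMSNorm behavior, with the optional unit offset.
import torch

def rms_norm_ref(x, weight, eps=1e-6, add_unit_offset=True):
    xf = x.to(torch.float32)
    xf = xf * torch.rsqrt(xf.pow(2).mean(-1, keepdim=True) + eps)
    scale = (1 + weight) if add_unit_offset else weight
    return (scale * xf).type_as(x)

x = torch.randn(2, 16, 4096, dtype=torch.bfloat16)
w = torch.zeros(4096)                       # in this sketch, zero weight plus unit offset == plain RMSNorm
y = rms_norm_ref(x, w)
print(y.dtype)                              # torch.bfloat16: output dtype follows the input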
 
169
  class EvaByteRotaryEmbedding(torch.nn.Module):
170
  def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
 
309
  cos=cos,
310
  sin=sin,
311
  multibyte_decoding=multibyte_decoding)
312
+ hidden_states = (residual + hidden_states).to(hidden_states.dtype)
313
 
314
  # Fully Connected
315
  residual = hidden_states
 
317
  residual = residual.float()
318
  hidden_states = self.post_attention_layernorm(hidden_states)
319
  hidden_states = self.mlp(hidden_states)
320
+ hidden_states = (residual + hidden_states).to(hidden_states.dtype)
321
 
322
  outputs = (hidden_states, )
323
 
 
649
  )
650
  else:
651
  assert self.training
652
+ assert seq_len % self.config.window_size == 0, "Training is only tested for sequences that are a multiple of window_size"
653
  # for training, we need to pass in the attention mask
654
  # usually calculated by _prepare_training_attn_mask()
655
  causal_mask = attention_mask
 
679
  cos = cos.unsqueeze(1)
680
  sin = sin.unsqueeze(1)
681
 
682
  # decoder layers
683
  all_hidden_states = () if output_hidden_states else None
684
  all_self_attns = () if output_attentions else None
 
689
  all_hidden_states += (hidden_states, )
690
 
691
  if self.gradient_checkpointing and self.training:
 
692
  layer_outputs = torch.utils.checkpoint.checkpoint(
693
+ decoder_layer.__call__,
694
  hidden_states,
695
  causal_mask,
696
  position_ids,
697
+ past_key_values,
698
+ output_attentions,
699
+ use_cache,
700
+ cos,
701
+ sin,
702
+ multibyte_decoding,
703
  )
704
  else:
705
  layer_outputs = decoder_layer(
 
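The checkpointing change above passes decoder_layer.__call__ and the per-layer arguments positionally instead of wrapping the layer in a closure. A generic, hedged sketch of that pattern with a toy layer (use_reentrant=False is one reasonable choice here; the repository may rely on the default):

# Generic sketch of checkpointing a module via __call__ with extra positional args.
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

class ToyLayer(nn.Module):
    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(32, 32)
    def forward(self, hidden_states, attn_mask=None, use_cache=False):
        return (self.proj(hidden_states),)          # tuple output, like the decoder layers

layer = ToyLayer()
x = torch.randn(4, 32, requires_grad=True)
(out,) = checkpoint(layer.__call__, x, None, False, use_reentrant=False)
out.sum().backward()
print(x.grad.shape)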
774
  def get_decoder(self):
775
  return self.model
776
 
777
  def forward(
778
  self,
779
  input_ids: torch.LongTensor = None,