Fix for FlashAttention RuntimeError & Triton multi-GPU fix.

#17
positional_embedding.py CHANGED
@@ -269,10 +269,10 @@ class RotaryEmbedding(torch.nn.Module):
         return (
             apply_rotary_pos_emb(
                 q, cos_cached[seqlen_offset:seq_len], sin_cached[seqlen_offset:seq_len], seq_dimension=seq_dimension
-            ),
+            ).to(q.dtype),
             apply_rotary_pos_emb(
                 k, cos_cached[seqlen_offset:seq_len], sin_cached[seqlen_offset:seq_len], seq_dimension=seq_dimension
-            ),
+            ).to(q.dtype),
         )
 
     @classmethod
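
Note on the positional_embedding.py change: `apply_rotary_pos_emb` multiplies `q`/`k` against the fp32 `cos_cached`/`sin_cached` buffers, so PyTorch type promotion can return fp32 tensors even when the model runs in fp16/bf16; the FlashAttention kernels then raise a RuntimeError because they only accept fp16/bf16 inputs. Casting both results back to `q.dtype` restores the expected dtype. A minimal standalone sketch of the failure mode and the fix (the simplified `apply_rotary_pos_emb_demo` helper and the shapes are illustrative, not this repo's code):

import torch

def apply_rotary_pos_emb_demo(x, cos, sin):
    # Simplified stand-in for apply_rotary_pos_emb: multiplying an fp16 tensor by an
    # fp32 cos/sin cache promotes the result to fp32 under PyTorch's promotion rules.
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat([x1 * cos - x2 * sin, x1 * sin + x2 * cos], dim=-1)

q = torch.randn(1, 8, 128, 64, dtype=torch.float16)   # (batch, heads, seq, head_dim), illustrative shape
cos = torch.randn(128, 32, dtype=torch.float32)       # fp32 rotary cache
sin = torch.randn(128, 32, dtype=torch.float32)

rotated = apply_rotary_pos_emb_demo(q, cos, sin)
print(rotated.dtype)              # torch.float32 -> fp16/bf16-only attention kernels reject this
print(rotated.to(q.dtype).dtype)  # torch.float16 -> matches q again, as in the diff above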
triton_flash_blocksparse_attn.py CHANGED
@@ -611,30 +611,31 @@ def _forward(ctx, q, k, v, layout_crow_indices, layout_col_indices, sm_scale, BL
     # print(f'> {q.shape=}, {k.shape=}, {layout_crow_indices.shape}, {layout_col_indices.shape}, {layout_crow_indices.stride()}, \
     #   {layout_col_indices.stride()}, {layout_crow_indices=}, {layout_col_indices=}')
 
-    _fwd_kernel[grid](
-        q, k, v, sm_scale,
-        layout_crow_indices,
-        layout_col_indices,
-        layout_crow_indices.stride(0), layout_crow_indices.stride(1),
-        layout_col_indices.stride(0), layout_col_indices.stride(1),
-        tmp, L, m,
-        o,
-        q.stride(0), q.stride(1), q.stride(2), q.stride(3),
-        k.stride(0), k.stride(1), k.stride(2), k.stride(3),
-        v.stride(0), v.stride(1), v.stride(2), v.stride(3),
-        o.stride(0), o.stride(1), o.stride(2), o.stride(3),
-        q.shape[0], q.shape[1], k.shape[2],
-        k.shape[2] - q.shape[2],
-        q_rounded_len,
-        BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N,
-        BLOCK_DMODEL=BLOCK_DMODEL,
-        EVEN_M_BLOCK=q.shape[2] % BLOCK_M == 0,
-        EVEN_N_BLOCK=k.shape[2] % BLOCK_N == 0 ,
-        INFERENCE=inference,
-        NUM_DBLOCKS=q.shape[-1] // BLOCK_DMODEL,
-        num_warps=num_warps,
-        num_stages=num_stages,
-    )
+    with torch.cuda.device(q.device.index):
+        _fwd_kernel[grid](
+            q, k, v, sm_scale,
+            layout_crow_indices,
+            layout_col_indices,
+            layout_crow_indices.stride(0), layout_crow_indices.stride(1),
+            layout_col_indices.stride(0), layout_col_indices.stride(1),
+            tmp, L, m,
+            o,
+            q.stride(0), q.stride(1), q.stride(2), q.stride(3),
+            k.stride(0), k.stride(1), k.stride(2), k.stride(3),
+            v.stride(0), v.stride(1), v.stride(2), v.stride(3),
+            o.stride(0), o.stride(1), o.stride(2), o.stride(3),
+            q.shape[0], q.shape[1], k.shape[2],
+            k.shape[2] - q.shape[2],
+            q_rounded_len,
+            BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N,
+            BLOCK_DMODEL=BLOCK_DMODEL,
+            EVEN_M_BLOCK=q.shape[2] % BLOCK_M == 0,
+            EVEN_N_BLOCK=k.shape[2] % BLOCK_N == 0 ,
+            INFERENCE=inference,
+            NUM_DBLOCKS=q.shape[-1] // BLOCK_DMODEL,
+            num_warps=num_warps,
+            num_stages=num_stages,
+        )
     if inference:
         L, m = None, None
 
@@ -991,37 +992,38 @@ def blocksparse_flash_attn_padded_fwd(
 
     grid = (len(q_start_sids), n_heads)
 
-    _fwd_kernel_batch_inference[grid](
-        q, k, v, out,
-        sm_scale,
-        q_batch_starts,
-        q_batch_ends,
-        k_batch_starts,
-        k_batch_ends,
-        q_batch_ids,
-        q_start_sids,
-
-        *q.stride(),
-        *k.stride(),
-        *v.stride(),
-        *out.stride(),
-
-        layout_crow_indices,
-        layout_col_indices,
-        *layout_crow_indices.stride(),
-        *layout_col_indices.stride(),
-
-        q_k_ratio,
-        HAS_BATCH_DIM = True,
-        D_HEAD = head_size,
-        BLOCK_M = block_size,
-        BLOCK_N = block_size,
-        BLOCK_D = block_d,
-        BLOCK_M_LOADING = 16 if q_len == 1 else block_size, # smaller for decoding
-        EVEN_D = block_d == head_size,
-        num_warps = 1 if q_len == 1 else 4,
-        num_stages = 3
-    )
+    with torch.cuda.device(q.device.index):
+        _fwd_kernel_batch_inference[grid](
+            q, k, v, out,
+            sm_scale,
+            q_batch_starts,
+            q_batch_ends,
+            k_batch_starts,
+            k_batch_ends,
+            q_batch_ids,
+            q_start_sids,
+
+            *q.stride(),
+            *k.stride(),
+            *v.stride(),
+            *out.stride(),
+
+            layout_crow_indices,
+            layout_col_indices,
+            *layout_crow_indices.stride(),
+            *layout_col_indices.stride(),
+
+            q_k_ratio,
+            HAS_BATCH_DIM = True,
+            D_HEAD = head_size,
+            BLOCK_M = block_size,
+            BLOCK_N = block_size,
+            BLOCK_D = block_d,
+            BLOCK_M_LOADING = 16 if q_len == 1 else block_size, # smaller for decoding
+            EVEN_D = block_d == head_size,
+            num_warps = 1 if q_len == 1 else 4,
+            num_stages = 3
+        )
 
     return out
 
@@ -1093,37 +1095,38 @@ def blocksparse_flash_attn_varlen_fwd(
 
     grid = (len(q_start_sids), n_heads)
 
-    _fwd_kernel_batch_inference[grid](
-        q, k, v, out,
-        sm_scale,
-        cu_seqlens_q[:-1],
-        cu_seqlens_q[1:],
-        cu_seqlens_k[:-1],
-        cu_seqlens_k[1:],
-        q_batch_ids,
-        q_start_sids,
-
-        0, *q.stride(),
-        0, *k.stride(),
-        0, *v.stride(),
-        0, *out.stride(),
-
-        layout_crow_indices,
-        layout_col_indices,
-        *layout_crow_indices.stride(),
-        *layout_col_indices.stride(),
-
-        q_k_ratio,
-        HAS_BATCH_DIM = False,
-        D_HEAD = head_size,
-        BLOCK_M = block_size,
-        BLOCK_N = block_size,
-        BLOCK_D = block_d,
-        BLOCK_M_LOADING = 16 if decoding_only else block_size, # smaller for decoding
-        EVEN_D = block_d == head_size,
-        num_warps = 1 if decoding_only else 4,
-        num_stages = 3
-    )
+    with torch.cuda.device(q.device.index):
+        _fwd_kernel_batch_inference[grid](
+            q, k, v, out,
+            sm_scale,
+            cu_seqlens_q[:-1],
+            cu_seqlens_q[1:],
+            cu_seqlens_k[:-1],
+            cu_seqlens_k[1:],
+            q_batch_ids,
+            q_start_sids,
+
+            0, *q.stride(),
+            0, *k.stride(),
+            0, *v.stride(),
+            0, *out.stride(),
+
+            layout_crow_indices,
+            layout_col_indices,
+            *layout_crow_indices.stride(),
+            *layout_col_indices.stride(),
+
+            q_k_ratio,
+            HAS_BATCH_DIM = False,
+            D_HEAD = head_size,
+            BLOCK_M = block_size,
+            BLOCK_N = block_size,
+            BLOCK_D = block_d,
+            BLOCK_M_LOADING = 16 if decoding_only else block_size, # smaller for decoding
+            EVEN_D = block_d == head_size,
+            num_warps = 1 if decoding_only else 4,
+            num_stages = 3
+        )
 
     return out
 
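
Note on the triton_flash_blocksparse_attn.py change: Triton launches a kernel on whichever CUDA device is current at call time, not on the device that owns the tensors, so on a multi-GPU machine a shard placed on cuda:1 can fail or touch the wrong GPU's memory while cuda:0 is still the current device. Wrapping each launch in `with torch.cuda.device(q.device.index):` pins the current device to the query tensor's device for the duration of the call. A minimal sketch of the same pattern with a toy kernel (the `_scale_kernel`/`scale` names are illustrative, not part of this repo):

import torch
import triton
import triton.language as tl

@triton.jit
def _scale_kernel(x_ptr, out_ptr, n_elements, scale, BLOCK: tl.constexpr):
    # Toy kernel: out = x * scale, used only to illustrate the launch pattern.
    pid = tl.program_id(0)
    offs = pid * BLOCK + tl.arange(0, BLOCK)
    mask = offs < n_elements
    x = tl.load(x_ptr + offs, mask=mask)
    tl.store(out_ptr + offs, x * scale, mask=mask)

def scale(x: torch.Tensor, scale_val: float) -> torch.Tensor:
    out = torch.empty_like(x)
    n = x.numel()
    grid = (triton.cdiv(n, 1024),)
    # Triton launches on the *current* CUDA device. If `x` lives on cuda:1 while the
    # current device is still cuda:0 (the usual default), the launch targets the wrong
    # GPU. Pinning the current device to the tensor's device, as this PR does for the
    # blocksparse kernels above, makes the launch follow the data.
    with torch.cuda.device(x.device.index):
        _scale_kernel[grid](x, out, n, scale_val, BLOCK=1024)
    return out

# y = scale(torch.randn(4096, device="cuda:1"), 2.0)  # works even when cuda:0 is current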