Markus28 commited on
Commit
46df05d
·
1 Parent(s): 6343db7

feat: moved flash attention code into this repository

Browse files
configuration_bert.py CHANGED
@@ -91,6 +91,9 @@ class JinaBertConfig(PretrainedConfig):
91
  assert 'max_position_embeddings' not in kwargs
92
  super().__init__(pad_token_id=pad_token_id, **kwargs)
93
 
 
 
 
94
  self.vocab_size = vocab_size
95
  self.hidden_size = hidden_size
96
  self.num_hidden_layers = num_hidden_layers
@@ -113,4 +116,4 @@ class JinaBertConfig(PretrainedConfig):
113
  self.num_tasks = num_tasks
114
  self.use_flash_attn = use_flash_attn
115
  self.use_qk_norm = use_qk_norm
116
- self.emb_pooler = emb_pooler
 
91
  assert 'max_position_embeddings' not in kwargs
92
  super().__init__(pad_token_id=pad_token_id, **kwargs)
93
 
94
+ if fused_mlp and hidden_act not in ["gelu_new", "gelu_fast", "gelu_pytorch_tanh"]:
95
+ raise ValueError('Fused MLP only supports approximate gelu')
96
+
97
  self.vocab_size = vocab_size
98
  self.hidden_size = hidden_size
99
  self.num_hidden_layers = num_hidden_layers
 
116
  self.num_tasks = num_tasks
117
  self.use_flash_attn = use_flash_attn
118
  self.use_qk_norm = use_qk_norm
119
+ self.emb_pooler = emb_pooler
flash_components/bert_padding.py ADDED
@@ -0,0 +1,215 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Adapted from https://github.com/mlcommons/training_results_v1.1/blob/main/NVIDIA/benchmarks/bert/implementations/pytorch/padding.py
2
+
3
+ import torch
4
+ import torch.nn.functional as F
5
+ from einops import rearrange, repeat
6
+
7
+
8
+ class IndexFirstAxis(torch.autograd.Function):
9
+ @staticmethod
10
+ def forward(ctx, input, indices):
11
+ ctx.save_for_backward(indices)
12
+ assert input.ndim >= 2
13
+ ctx.first_axis_dim, other_shape = input.shape[0], input.shape[1:]
14
+ second_dim = other_shape.numel()
15
+ # TD [2022-03-04] For some reason torch.gather is a bit faster than indexing.
16
+ # return input[indices]
17
+ return torch.gather(
18
+ rearrange(input, "b ... -> b (...)"), 0, repeat(indices, "z -> z d", d=second_dim)
19
+ ).reshape(-1, *other_shape)
20
+
21
+ @staticmethod
22
+ def backward(ctx, grad_output):
23
+ (indices,) = ctx.saved_tensors
24
+ assert grad_output.ndim >= 2
25
+ other_shape = grad_output.shape[1:]
26
+ grad_output = rearrange(grad_output, "b ... -> b (...)")
27
+ grad_input = torch.zeros(
28
+ [ctx.first_axis_dim, grad_output.shape[1]],
29
+ device=grad_output.device,
30
+ dtype=grad_output.dtype,
31
+ )
32
+ # TD [2022-03-04] For some reason torch.scatter is a bit faster than indexing.
33
+ # grad_input[indices] = grad_output
34
+ grad_input.scatter_(0, repeat(indices, "z -> z d", d=grad_output.shape[1]), grad_output)
35
+ return grad_input.reshape(ctx.first_axis_dim, *other_shape), None
36
+
37
+
38
+ index_first_axis = IndexFirstAxis.apply
39
+
40
+
41
+ class IndexPutFirstAxis(torch.autograd.Function):
42
+ @staticmethod
43
+ def forward(ctx, values, indices, first_axis_dim):
44
+ ctx.save_for_backward(indices)
45
+ assert indices.ndim == 1
46
+ assert values.ndim >= 2
47
+ output = torch.zeros(
48
+ first_axis_dim, *values.shape[1:], device=values.device, dtype=values.dtype
49
+ )
50
+ # TD [2022-03-04] For some reason torch.scatter is a bit faster than indexing.
51
+ output[indices] = values
52
+ # output.scatter_(0, repeat(indices, 'z -> z d', d=values.shape[1]), values)
53
+ return output
54
+
55
+ @staticmethod
56
+ def backward(ctx, grad_output):
57
+ (indices,) = ctx.saved_tensors
58
+ # TD [2022-03-04] For some reason torch.gather is a bit faster than indexing.
59
+ grad_values = grad_output[indices]
60
+ # grad_values = torch.gather(grad_output, 0, repeat(indices, 'z -> z d', d=grad_output.shape[1]))
61
+ return grad_values, None, None
62
+
63
+
64
+ index_put_first_axis = IndexPutFirstAxis.apply
65
+
66
+
67
+ class IndexFirstAxisResidual(torch.autograd.Function):
68
+ @staticmethod
69
+ def forward(ctx, input, indices):
70
+ ctx.save_for_backward(indices)
71
+ assert input.ndim >= 2
72
+ ctx.first_axis_dim, other_shape = input.shape[0], input.shape[1:]
73
+ second_dim = other_shape.numel()
74
+ # TD [2022-03-04] For some reason torch.gather is a bit faster than indexing.
75
+ output = input[indices]
76
+ # We don't want to reshape input (b ... -> b (...)) since it could change the channel_last
77
+ # memory format to channel_first. In other words, input might not be contiguous.
78
+ # If we don't detach, Pytorch complains about output being a view and is being modified inplace
79
+ return output, input.detach()
80
+
81
+ @staticmethod
82
+ def backward(ctx, grad_output, grad_residual):
83
+ (indices,) = ctx.saved_tensors
84
+ assert grad_output.ndim >= 2
85
+ other_shape = grad_output.shape[1:]
86
+ assert grad_residual.shape[1:] == other_shape
87
+ grad_input = grad_residual
88
+ # grad_input[indices] += grad_output
89
+ indices = indices.reshape(indices.shape[0], *((1,) * (grad_output.ndim - 1)))
90
+ indices = indices.expand_as(grad_output)
91
+ grad_input.scatter_add_(0, indices, grad_output)
92
+ return grad_input.reshape(ctx.first_axis_dim, *other_shape), None
93
+
94
+
95
+ index_first_axis_residual = IndexFirstAxisResidual.apply
96
+
97
+
98
+ def unpad_input(hidden_states, attention_mask):
99
+ """
100
+ Arguments:
101
+ hidden_states: (batch, seqlen, ...)
102
+ attention_mask: (batch, seqlen), bool / int, 1 means valid and 0 means not valid.
103
+ Return:
104
+ hidden_states: (total_nnz, ...), where total_nnz = number of tokens in selected in attention_mask.
105
+ indices: (total_nnz), the indices of non-masked tokens from the flattened input sequence.
106
+ cu_seqlens: (batch + 1), the cumulative sequence lengths, used to index into hidden_states.
107
+ max_seqlen_in_batch: int
108
+ """
109
+ seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
110
+ indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
111
+ max_seqlen_in_batch = seqlens_in_batch.max().item()
112
+ cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.torch.int32), (1, 0))
113
+ # TD [2022-03-04] We don't want to index with a bool mask, because Pytorch will expand the
114
+ # bool mask, then call nonzero to get the indices, then index with those. The indices is @dim
115
+ # times larger than it needs to be, wasting memory. It's faster and more memory-efficient to
116
+ # index with integer indices. Moreover, torch's index is a bit slower than it needs to be,
117
+ # so we write custom forward and backward to make it a bit faster.
118
+ return (
119
+ index_first_axis(rearrange(hidden_states, "b s ... -> (b s) ..."), indices),
120
+ indices,
121
+ cu_seqlens,
122
+ max_seqlen_in_batch,
123
+ )
124
+
125
+
126
+ def unpad_input_for_concatenated_sequences(hidden_states, attention_mask_in_length):
127
+ """
128
+ Supports concatenating short samples in one sequence. The attention_mask_in_length is utilized to mask other short samples. It helps efficient training of variant lengths-based samples (e.g., the supervised fine-tuning task in large language model).
129
+ The motivation for this function is explained [here](https://github.com/Dao-AILab/flash-attention/issues/432#issuecomment-1668822286).
130
+
131
+ For example, if batch = 3 and seqlen = 6, the attention_mask_in_length is:
132
+ ```
133
+ [
134
+ [2, 3, 0, 0, 0, 0],
135
+ [3, 2, 0, 0, 0, 0],
136
+ [6, 0, 0, 0, 0, 0]
137
+ ]
138
+ ```
139
+ , which refers to the 3D-attention mask:
140
+ ```
141
+ [
142
+ [
143
+ [1, 0, 0, 0, 0, 0],
144
+ [1, 1, 0, 0, 0, 0],
145
+ [0, 0, 1, 0, 0, 0],
146
+ [0, 0, 1, 1, 0, 0],
147
+ [0, 0, 1, 1, 1, 0],
148
+ [0, 0, 0, 0, 0, 1]
149
+ ],
150
+ [
151
+ [1, 0, 0, 0, 0, 0],
152
+ [1, 1, 0, 0, 0, 0],
153
+ [1, 1, 1, 0, 0, 0],
154
+ [0, 0, 0, 1, 0, 0],
155
+ [0, 0, 0, 1, 1, 0],
156
+ [0, 0, 0, 0, 0, 1]
157
+ ],
158
+ [
159
+ [1, 0, 0, 0, 0, 0],
160
+ [1, 1, 0, 0, 0, 0],
161
+ [1, 1, 1, 0, 0, 0],
162
+ [1, 1, 1, 1, 0, 0],
163
+ [1, 1, 1, 1, 1, 0],
164
+ [1, 1, 1, 1, 1, 1]
165
+ ]
166
+ ]
167
+ ```.
168
+
169
+ Arguments:
170
+ hidden_states: (batch, seqlen, ...)
171
+ attention_mask_in_length: (batch, seqlen), int, a nonzero number (e.g., 1, 2, 3, etc.) means length of concatenated sequence in b-th batch, and 0 means none.
172
+ Return:
173
+ hidden_states: (total_nnz, ...), where total_nnz = number of tokens in selected in attention_mask.
174
+ indices: (total_nnz), the indices of non-masked tokens from the flattened input sequence.
175
+ cu_seqlens: (batch + 1), the cumulative sequence lengths, used to index into hidden_states.
176
+ max_seqlen_in_batch: int
177
+ """
178
+ length = attention_mask_in_length.sum(dim=-1)
179
+ seqlen = attention_mask_in_length.size(-1)
180
+ attention_mask_2d = torch.arange(seqlen, device=length.device, dtype=length.dtype).expand(len(length),
181
+ seqlen) < length.unsqueeze(
182
+ 1)
183
+ real_indices_idx = torch.nonzero(attention_mask_in_length.flatten(), as_tuple=False).flatten()
184
+ seqlens_in_batch = attention_mask_in_length.flatten()[real_indices_idx]
185
+ indices = torch.nonzero(attention_mask_2d.flatten(), as_tuple=False).flatten()
186
+ max_seqlen_in_batch = seqlens_in_batch.max().item()
187
+ cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.torch.int32), (1, 0))
188
+ # TD [2022-03-04] We don't want to index with a bool mask, because Pytorch will expand the
189
+ # bool mask, then call nonzero to get the indices, then index with those. The indices is @dim
190
+ # times larger than it needs to be, wasting memory. It's faster and more memory-efficient to
191
+ # index with integer indices. Moreover, torch's index is a bit slower than it needs to be,
192
+ # so we write custom forward and backward to make it a bit faster.
193
+ return (
194
+ index_first_axis(rearrange(hidden_states, "b s ... -> (b s) ..."), indices),
195
+ indices,
196
+ cu_seqlens,
197
+ max_seqlen_in_batch,
198
+ )
199
+
200
+
201
+ def pad_input(hidden_states, indices, batch, seqlen):
202
+ """
203
+ Arguments:
204
+ hidden_states: (total_nnz, ...), where total_nnz = number of tokens in selected in attention_mask.
205
+ indices: (total_nnz), the indices that represent the non-masked tokens of the original padded input sequence.
206
+ batch: int, batch size for the padded sequence.
207
+ seqlen: int, maximum sequence length for the padded sequence.
208
+ Return:
209
+ hidden_states: (batch, seqlen, ...)
210
+ """
211
+ dim = hidden_states.shape[-1]
212
+ # output = torch.zeros((batch * seqlen), dim, device=hidden_states.device, dtype=hidden_states.dtype)
213
+ # output[indices] = hidden_states
214
+ output = index_put_first_axis(hidden_states, indices, batch * seqlen)
215
+ return rearrange(output, "(b s) ... -> b s ...", b=batch)
flash_components/block.py ADDED
@@ -0,0 +1,396 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2024, Tri Dao.
2
+
3
+ from functools import partial
4
+ from typing import Optional
5
+
6
+ import torch
7
+ import torch.nn as nn
8
+ from torch import Tensor
9
+ from torchvision.ops import StochasticDepth
10
+
11
+ from .mha import MHA
12
+ from .mlp import Mlp
13
+
14
+ try:
15
+ from flash_attn.ops.triton.layer_norm import layer_norm_fn, RMSNorm
16
+ except ImportError:
17
+ layer_norm_fn, RMSNorm = None, None
18
+
19
+
20
+ class Block(nn.Module):
21
+ def __init__(
22
+ self,
23
+ dim,
24
+ mixer_cls=None,
25
+ mlp_cls=None,
26
+ norm_cls=nn.LayerNorm,
27
+ dropout_cls=nn.Dropout,
28
+ prenorm=True,
29
+ resid_dropout1=0.0,
30
+ resid_dropout2=0.0,
31
+ drop_path1=0.0,
32
+ drop_path2=0.0,
33
+ fused_dropout_add_ln=False,
34
+ return_residual=False,
35
+ residual_in_fp32=False,
36
+ sequence_parallel=False,
37
+ mark_shared_params=False,
38
+ ):
39
+ """
40
+ For prenorm=True, this Block has a slightly different structure compared to a regular
41
+ prenorm Transformer block.
42
+ The standard block is: LN -> MHA -> Dropout -> Add -> LN -> MLP -> Dropout -> Add.
43
+ [Ref: https://arxiv.org/abs/2002.04745]
44
+ Here we have: Dropout -> Add -> LN -> MHA -> Dropout -> Add -> LN -> MLP, returning both
45
+ the hidden_states (output of the MLP) and the residual.
46
+ This is for performance reasons, as we can fuse the dropout, add and LayerNorm.
47
+ The residual needs to be provided (except for the very first block).
48
+
49
+ For prenorm=False, this Block has the same structure as a regular postnorm Transformer
50
+ block: MHA -> Dropout -> Add -> LN -> MLP -> Dropout -> Add -> LN.
51
+
52
+ return_residual: whether each of the sub-layers (mixer and mlp) will return the residual.
53
+ This is for performance reason: for post-norm architecture, returning the input allows us
54
+ to fuse the backward of nn.Linear with the residual connection.
55
+ """
56
+ super().__init__()
57
+ self.prenorm = prenorm
58
+ self.fused_dropout_add_ln = fused_dropout_add_ln
59
+ self.return_residual = return_residual
60
+ self.residual_in_fp32 = residual_in_fp32
61
+ if self.residual_in_fp32:
62
+ assert self.prenorm, "residual_in_fp32 is only compatible with prenorm=True"
63
+ if mixer_cls is None:
64
+ mixer_cls = partial(MHA, num_heads=dim // 64)
65
+ if mlp_cls is None:
66
+ mlp_cls = partial(Mlp, hidden_features=4 * dim)
67
+ self.mixer = mixer_cls(dim)
68
+ self.dropout1 = dropout_cls(resid_dropout1)
69
+ self.drop_path1 = StochasticDepth(drop_path1, mode="row")
70
+ self.norm1 = norm_cls(dim)
71
+ self.mlp = mlp_cls(dim)
72
+ if not isinstance(self.mlp, nn.Identity):
73
+ self.dropout2 = dropout_cls(resid_dropout2)
74
+ self.drop_path2 = StochasticDepth(drop_path2, mode="row")
75
+ self.norm2 = norm_cls(dim)
76
+
77
+ if self.fused_dropout_add_ln:
78
+ assert layer_norm_fn is not None, "Triton is not installed"
79
+ assert isinstance(self.norm1, (nn.LayerNorm, RMSNorm)) and isinstance(
80
+ self.dropout1, nn.Dropout
81
+ )
82
+
83
+ # TD [2023-01-07]: TODO: During training, if sequence_parallel is False and dropout != 0.0,
84
+ # then the input to each worker in the tensor parallel group will be different.
85
+ # This would produce wrong outputs? Somehow we'd need to sync the RNG state across workers.
86
+ # For now this is not an issue because we always use sequence_parallel=True during training
87
+ # and only use sequence_parallel=False during inference.
88
+
89
+ # Mark the norm parameters as "sequence_parallel" so that we run all-reduce on their grads.
90
+ if sequence_parallel:
91
+ for p in self.norm1.parameters():
92
+ p._sequence_parallel = True
93
+ if hasattr(self, "norm2"):
94
+ for p in self.norm2.parameters():
95
+ p._sequence_parallel = True
96
+ # Mark the norm parameters as "shared_params" so that we sync their values at init.
97
+ if mark_shared_params:
98
+ for p in self.norm1.parameters():
99
+ p._shared_params = True
100
+ if hasattr(self, "norm2"):
101
+ for p in self.norm2.parameters():
102
+ p._shared_params = True
103
+
104
+ def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
105
+ return self.mixer.allocate_inference_cache(batch_size, max_seqlen, dtype=dtype, **kwargs)
106
+
107
+ def forward(
108
+ self,
109
+ hidden_states: Tensor,
110
+ residual: Optional[Tensor] = None,
111
+ mixer_subset=None,
112
+ mixer_kwargs=None,
113
+ ):
114
+ r"""Pass the input through the encoder layer.
115
+
116
+ Args:
117
+ hidden_states: the sequence to the encoder layer (required).
118
+ residual: if postnorm, residual=None, If prenorm, hidden_states = Attn/MLP(LN(residual))
119
+ mixer_subset: for cross-attention only. If not None, will take a subset of x
120
+ before applying the query projection. Useful for e.g., ViT where we only care
121
+ about the CLS token in the last layer.
122
+ """
123
+ if self.prenorm:
124
+ if not self.fused_dropout_add_ln:
125
+ dropped = self.drop_path1(self.dropout1(hidden_states))
126
+ residual = (dropped + residual) if residual is not None else dropped
127
+ hidden_states = self.norm1(residual.to(dtype=self.norm1.weight.dtype))
128
+ if self.residual_in_fp32:
129
+ residual = residual.to(torch.float32)
130
+ else:
131
+ if self.drop_path1.p == 0 or not self.training:
132
+ rowscale1 = None
133
+ else:
134
+ rowscale1 = self.drop_path1(
135
+ torch.ones(
136
+ hidden_states.shape[:-1],
137
+ device=hidden_states.device,
138
+ dtype=hidden_states.dtype,
139
+ )
140
+ )
141
+ hidden_states, residual = layer_norm_fn(
142
+ hidden_states,
143
+ self.norm1.weight,
144
+ self.norm1.bias,
145
+ residual=residual,
146
+ eps=self.norm1.eps,
147
+ dropout_p=self.dropout1.p if self.training else 0.0,
148
+ rowscale=rowscale1,
149
+ prenorm=True,
150
+ residual_in_fp32=self.residual_in_fp32,
151
+ is_rms_norm=isinstance(self.norm1, RMSNorm)
152
+ )
153
+ if mixer_kwargs is None:
154
+ mixer_kwargs = {}
155
+ if mixer_subset is not None:
156
+ mixer_kwargs["mixer_subset"] = mixer_subset
157
+ hidden_states = self.mixer(hidden_states, **mixer_kwargs)
158
+ if mixer_subset is not None:
159
+ residual = residual[:, mixer_subset]
160
+ if not isinstance(self.mlp, nn.Identity):
161
+ if not self.fused_dropout_add_ln:
162
+ dropped = self.drop_path2(self.dropout2(hidden_states))
163
+ residual = (dropped + residual) if residual is not None else dropped
164
+ hidden_states = self.norm2(residual.to(dtype=self.norm2.weight.dtype))
165
+ if self.residual_in_fp32:
166
+ residual = residual.to(torch.float32)
167
+ else:
168
+ if self.drop_path2.p == 0 or not self.training:
169
+ rowscale2 = None
170
+ else:
171
+ rowscale2 = self.drop_path2(
172
+ torch.ones(
173
+ hidden_states.shape[:-1],
174
+ device=hidden_states.device,
175
+ dtype=hidden_states.dtype,
176
+ )
177
+ )
178
+ hidden_states, residual = layer_norm_fn(
179
+ hidden_states,
180
+ self.norm2.weight,
181
+ self.norm2.bias,
182
+ residual=residual,
183
+ eps=self.norm2.eps,
184
+ dropout_p=self.dropout2.p if self.training else 0.0,
185
+ rowscale=rowscale2,
186
+ prenorm=True,
187
+ residual_in_fp32=self.residual_in_fp32,
188
+ is_rms_norm=isinstance(self.norm2, RMSNorm)
189
+ )
190
+ hidden_states = self.mlp(hidden_states)
191
+ return hidden_states, residual
192
+ else:
193
+ assert residual is None
194
+ mixer_out = self.mixer(
195
+ hidden_states, **(mixer_kwargs if mixer_kwargs is not None else {})
196
+ )
197
+ if self.return_residual: # mixer out is actually a pair here
198
+ mixer_out, hidden_states = mixer_out
199
+ if not self.fused_dropout_add_ln:
200
+ hidden_states = self.norm1(
201
+ (self.drop_path1(self.dropout1(mixer_out)) + hidden_states).to(
202
+ dtype=self.norm1.weight.dtype
203
+ )
204
+ )
205
+ else:
206
+ if self.drop_path1.p == 0 or not self.training:
207
+ rowscale1 = None
208
+ else:
209
+ rowscale1 = self.drop_path1(
210
+ torch.ones(
211
+ mixer_out.shape[:-1], device=mixer_out.device, dtype=mixer_out.dtype
212
+ )
213
+ )
214
+ hidden_states = layer_norm_fn(
215
+ mixer_out,
216
+ self.norm1.weight,
217
+ self.norm1.bias,
218
+ residual=hidden_states,
219
+ eps=self.norm1.eps,
220
+ dropout_p=self.dropout1.p if self.training else 0.0,
221
+ rowscale=rowscale1,
222
+ prenorm=False,
223
+ is_rms_norm=isinstance(self.norm1, RMSNorm)
224
+ )
225
+ if not isinstance(self.mlp, nn.Identity):
226
+ mlp_out = self.mlp(hidden_states)
227
+ if self.return_residual: # mlp out is actually a pair here
228
+ mlp_out, hidden_states = mlp_out
229
+ if not self.fused_dropout_add_ln:
230
+ hidden_states = self.norm2(
231
+ (self.drop_path2(self.dropout2(mlp_out)) + hidden_states).to(
232
+ dtype=self.norm2.weight.dtype
233
+ )
234
+ )
235
+ else:
236
+ if self.drop_path2.p == 0 or not self.training:
237
+ rowscale2 = None
238
+ else:
239
+ rowscale2 = self.drop_path2(
240
+ torch.ones(
241
+ mlp_out.shape[:-1], device=mlp_out.device, dtype=mlp_out.dtype
242
+ )
243
+ )
244
+ hidden_states = layer_norm_fn(
245
+ mlp_out,
246
+ self.norm2.weight,
247
+ self.norm2.bias,
248
+ residual=hidden_states,
249
+ eps=self.norm2.eps,
250
+ dropout_p=self.dropout2.p if self.training else 0.0,
251
+ rowscale=rowscale2,
252
+ prenorm=False,
253
+ is_rms_norm=isinstance(self.norm2, RMSNorm)
254
+ )
255
+ return hidden_states
256
+
257
+
258
+ class ParallelBlock(nn.Module):
259
+ """The attention (mixer) and MLP blocks are done in parallel, similar to GPT-J, GPT-NeoX,
260
+ and PaLM.
261
+ """
262
+
263
+ def __init__(
264
+ self,
265
+ dim,
266
+ mixer_cls=None,
267
+ mlp_cls=None,
268
+ norm_cls=nn.LayerNorm,
269
+ dropout_cls=nn.Dropout,
270
+ resid_dropout1=0.0,
271
+ resid_dropout2=0.0,
272
+ tied_norm=False,
273
+ fused_dropout_add_ln=False,
274
+ residual_in_fp32=False,
275
+ sequence_parallel=False,
276
+ mark_shared_params=False,
277
+ ):
278
+ """
279
+ This Block has a slightly different structure compared to a regular
280
+ prenorm Transformer block.
281
+ The standard block is: LN -> MHA / MLP -> Dropout -> Add.
282
+ [Ref: https://arxiv.org/abs/2002.04745]
283
+ Here we have: Dropout -> Add -> LN -> MHA / MLP, returning both
284
+ the hidden_states (output1 of the MHA / MLP) and the residual.
285
+ This is for performance reasons, as we can fuse the dropout, add and LayerNorm.
286
+ The residual needs to be provided (except for the very first block).
287
+ """
288
+ super().__init__()
289
+ self.tied_norm = tied_norm
290
+ self.fused_dropout_add_ln = fused_dropout_add_ln
291
+ self.residual_in_fp32 = residual_in_fp32
292
+ if mixer_cls is None:
293
+ mixer_cls = partial(MHA, num_heads=dim // 64)
294
+ if mlp_cls is None:
295
+ mlp_cls = partial(Mlp, hidden_features=4 * dim)
296
+ self.mixer = mixer_cls(dim)
297
+ self.dropout1 = dropout_cls(resid_dropout1)
298
+ self.norm1 = norm_cls(dim)
299
+ self.mlp = mlp_cls(dim)
300
+ self.dropout2 = dropout_cls(resid_dropout2)
301
+ if not self.tied_norm:
302
+ self.norm2 = norm_cls(dim)
303
+
304
+ if self.fused_dropout_add_ln:
305
+ assert layer_norm_fn is not None, "Triton is not installed"
306
+ assert isinstance(self.norm1, (nn.LayerNorm, RMSNorm)) and isinstance(
307
+ self.dropout1, nn.Dropout
308
+ )
309
+
310
+ # TD [2023-01-07]: TODO: During training, if sequence_parallel is False and dropout != 0.0,
311
+ # then the input to each worker in the tensor parallel group will be different.
312
+ # This would produce wrong outputs? Somehow we'd need to sync the RNG state across workers.
313
+ # For now this is not an issue because we always use sequence_parallel=True during training
314
+ # and only use sequence_parallel=False during inference.
315
+
316
+ # Mark the norm parameters as "sequence_parallel" so that we run all-reduce on their grads.
317
+ if sequence_parallel:
318
+ for p in self.norm1.parameters():
319
+ p._sequence_parallel = True
320
+ if hasattr(self, "norm2"):
321
+ for p in self.norm2.parameters():
322
+ p._sequence_parallel = True
323
+ # Mark the norm parameters as "shared_params" so that we sync their values at init.
324
+ if mark_shared_params:
325
+ for p in self.norm1.parameters():
326
+ p._shared_params = True
327
+ if hasattr(self, "norm2"):
328
+ for p in self.norm2.parameters():
329
+ p._shared_params = True
330
+
331
+ def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
332
+ return self.mixer.allocate_inference_cache(batch_size, max_seqlen, dtype=dtype, **kwargs)
333
+
334
+ def forward(
335
+ self,
336
+ hidden_states1: Tensor,
337
+ hidden_states2: Optional[Tensor] = None,
338
+ residual: Optional[Tensor] = None,
339
+ mixer_kwargs=None,
340
+ ):
341
+ r"""Pass the input through the encoder layer.
342
+
343
+ Args:
344
+ hidden_states1: the output of the previous attention (mixer) or embedding layer.
345
+ hidden_states2: the output of the previous MLP layer (if None, will use hidden_states1).
346
+ residual.
347
+ """
348
+ # TODO: Ideally we should only do the allgather / allreduce once for
349
+ # the Linear to MLP & Attention
350
+ if not self.fused_dropout_add_ln:
351
+ dropped1 = self.dropout1(hidden_states1)
352
+ # For the very 1st block, we only want 1 dropout, not two different dropouts
353
+ if hidden_states2 is not None:
354
+ dropped2 = self.dropout2(hidden_states2)
355
+ residual = (
356
+ (residual + dropped1 + dropped2)
357
+ if residual is not None
358
+ else dropped1 + dropped2
359
+ )
360
+ else:
361
+ residual = (residual + dropped1) if residual is not None else dropped1
362
+ hidden_states1 = self.norm1(residual.to(dtype=self.norm1.weight.dtype))
363
+ hidden_states2 = (
364
+ self.norm2(residual.to(dtype=self.norm2.weight.dtype))
365
+ if not self.tied_norm
366
+ else hidden_states1
367
+ )
368
+ if self.residual_in_fp32:
369
+ residual = residual.to(torch.float32)
370
+ else:
371
+ weight2, bias2 = (
372
+ (self.norm2.weight, self.norm2.bias) if not self.tied_norm else (None, None)
373
+ )
374
+ hidden_states1, *rest, residual = layer_norm_fn(
375
+ hidden_states1,
376
+ self.norm1.weight,
377
+ self.norm1.bias,
378
+ residual=residual,
379
+ x1=hidden_states2,
380
+ weight1=weight2,
381
+ bias1=bias2,
382
+ eps=self.norm1.eps,
383
+ dropout_p=self.dropout1.p if self.training else 0.0,
384
+ prenorm=True,
385
+ residual_in_fp32=self.residual_in_fp32,
386
+ is_rms_norm=isinstance(self.norm1, RMSNorm)
387
+ )
388
+ if self.tied_norm:
389
+ hidden_states2 = hidden_states1
390
+ else:
391
+ hidden_states2, = rest
392
+ if mixer_kwargs is None:
393
+ mixer_kwargs = {}
394
+ hidden_states1 = self.mixer(hidden_states1, **mixer_kwargs)
395
+ hidden_states2 = self.mlp(hidden_states2)
396
+ return hidden_states1, hidden_states2, residual
flash_components/embedding.py ADDED
@@ -0,0 +1,157 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2022, Tri Dao.
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+ from torch import Tensor
6
+
7
+
8
+ class GPT2Embeddings(nn.Module):
9
+ def __init__(
10
+ self,
11
+ embed_dim,
12
+ vocab_size,
13
+ max_position_embeddings,
14
+ padding_idx=None,
15
+ word_embed_proj_dim=None,
16
+ device=None,
17
+ dtype=None,
18
+ ):
19
+ """
20
+ If max_position_embeddings <= 0, there's no position embeddings
21
+ If word_embe_proj_dim is not None (e.g., OPT-350m), we embed to that dimension
22
+ the project up to embed_dim
23
+ """
24
+ factory_kwargs = {"device": device, "dtype": dtype}
25
+ super().__init__()
26
+ if word_embed_proj_dim is None:
27
+ self.word_embeddings = nn.Embedding(
28
+ vocab_size, embed_dim, padding_idx=padding_idx, **factory_kwargs
29
+ )
30
+ self.project_in = None
31
+ else:
32
+ self.word_embeddings = nn.Embedding(
33
+ vocab_size, word_embed_proj_dim, padding_idx=padding_idx, **factory_kwargs
34
+ )
35
+ self.project_in = nn.Linear(
36
+ word_embed_proj_dim, embed_dim, bias=False, **factory_kwargs
37
+ )
38
+ self.max_position_embeddings = max_position_embeddings
39
+ if self.max_position_embeddings > 0:
40
+ self.position_embeddings = nn.Embedding(
41
+ max_position_embeddings, embed_dim, **factory_kwargs
42
+ )
43
+
44
+ def forward(self, input_ids, position_ids=None):
45
+ """
46
+ input_ids: (batch, seqlen)
47
+ position_ids: (batch, seqlen)
48
+ """
49
+ batch_size, seqlen = input_ids.shape
50
+ embeddings = self.word_embeddings(input_ids)
51
+ if self.project_in is not None:
52
+ embeddings = self.project_in(embeddings)
53
+ if self.max_position_embeddings > 0:
54
+ if position_ids is None:
55
+ position_ids = torch.arange(seqlen, dtype=torch.long, device=input_ids.device)
56
+ position_embeddings = self.position_embeddings(position_ids)
57
+ embeddings = embeddings + position_embeddings
58
+ return embeddings
59
+
60
+
61
+ class BertEmbeddings(nn.Module):
62
+ def __init__(
63
+ self,
64
+ embed_dim,
65
+ vocab_size,
66
+ max_position_embeddings,
67
+ type_vocab_size,
68
+ padding_idx=None,
69
+ device=None,
70
+ dtype=None,
71
+ ):
72
+ """
73
+ If max_position_embeddings <= 0, there's no position embeddings
74
+ If type_vocab_size <= 0, there's no token type embeddings
75
+ """
76
+ factory_kwargs = {"device": device, "dtype": dtype}
77
+ super().__init__()
78
+ self.word_embeddings = nn.Embedding(
79
+ vocab_size, embed_dim, padding_idx=padding_idx, **factory_kwargs
80
+ )
81
+ self.max_position_embeddings = max_position_embeddings
82
+ self.type_vocab_size = type_vocab_size
83
+ if self.max_position_embeddings > 0:
84
+ self.position_embeddings = nn.Embedding(
85
+ max_position_embeddings, embed_dim, **factory_kwargs
86
+ )
87
+ if self.type_vocab_size > 0:
88
+ self.token_type_embeddings = nn.Embedding(type_vocab_size, embed_dim, **factory_kwargs)
89
+
90
+ def forward(self, input_ids, position_ids=None, token_type_ids=None):
91
+ """
92
+ input_ids: (batch, seqlen)
93
+ position_ids: (batch, seqlen)
94
+ token_type_ids: (batch, seqlen)
95
+ """
96
+ batch_size, seqlen = input_ids.shape
97
+ embeddings = self.word_embeddings(input_ids)
98
+ if self.max_position_embeddings > 0:
99
+ if position_ids is None:
100
+ position_ids = torch.arange(seqlen, dtype=torch.long, device=input_ids.device)
101
+ position_embeddings = self.position_embeddings(position_ids)
102
+ embeddings = embeddings + position_embeddings
103
+ if self.type_vocab_size > 0:
104
+ if token_type_ids is None:
105
+ token_type_ids = torch.zeros(seqlen, dtype=torch.long, device=input_ids.device)
106
+ token_type_embeddings = self.token_type_embeddings(token_type_ids)
107
+ embeddings = embeddings + token_type_embeddings
108
+ return embeddings
109
+
110
+
111
+ class VocabParallelEmbedding(nn.Embedding):
112
+ def __init__(self, num_embeddings, *args, process_group=None, padding_idx=None, **kwargs):
113
+ self.process_group = process_group
114
+ if process_group is not None:
115
+ world_size = torch.distributed.get_world_size(process_group)
116
+ if num_embeddings % world_size != 0:
117
+ raise ValueError(
118
+ f"num_embeddings ({num_embeddings}) must be divisible by "
119
+ f"world_size ({world_size})"
120
+ )
121
+ if world_size > 1 and padding_idx is not None:
122
+ raise RuntimeError("ParallelEmbedding does not support padding_idx")
123
+ else:
124
+ world_size = 1
125
+ super().__init__(num_embeddings // world_size, *args, padding_idx=padding_idx, **kwargs)
126
+
127
+ def forward(self, input: Tensor) -> Tensor:
128
+ if self.process_group is None:
129
+ return super().forward(input)
130
+ else:
131
+ rank = torch.distributed.get_rank(self.process_group)
132
+ vocab_size = self.num_embeddings
133
+ vocab_start_index, vocab_end_index = rank * vocab_size, (rank + 1) * vocab_size
134
+ # Create a mask of valid vocab ids (1 means it needs to be masked).
135
+ input_ids_mask = (input < vocab_start_index) | (input >= vocab_end_index)
136
+ input = input - vocab_start_index
137
+ input[input_ids_mask] = 0
138
+ embeddings = super().forward(input)
139
+ embeddings[input_ids_mask] = 0.0
140
+ return embeddings
141
+
142
+
143
+ class ColumnParallelEmbedding(nn.Embedding):
144
+ def __init__(self, num_embeddings, embedding_dim, *args, process_group=None, **kwargs):
145
+ self.process_group = process_group
146
+ if process_group is not None:
147
+ world_size = torch.distributed.get_world_size(process_group)
148
+ if embedding_dim % world_size != 0:
149
+ raise ValueError(
150
+ f"embedding_dim ({embedding_dim}) must be divisible by "
151
+ f"world_size ({world_size})"
152
+ )
153
+ else:
154
+ world_size = 1
155
+ super().__init__(num_embeddings, embedding_dim // world_size, *args, **kwargs)
156
+
157
+
flash_components/mha.py ADDED
@@ -0,0 +1,808 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2023, Tri Dao.
2
+
3
+ import math
4
+ from functools import partial
5
+
6
+ import torch
7
+ import torch.nn as nn
8
+ from einops import rearrange, repeat
9
+
10
+ try:
11
+ from flash_attn import (
12
+ flash_attn_kvpacked_func,
13
+ flash_attn_qkvpacked_func,
14
+ flash_attn_varlen_kvpacked_func,
15
+ flash_attn_varlen_qkvpacked_func,
16
+ flash_attn_with_kvcache,
17
+ )
18
+ except ImportError:
19
+ flash_attn_varlen_qkvpacked_func, flash_attn_varlen_kvpacked_func = None, None
20
+ flash_attn_qkvpacked_func, flash_attn_kvpacked_func = None, None
21
+ flash_attn_with_kvcache = None
22
+
23
+ try:
24
+ from flash_attn.ops.fused_dense import ColumnParallelLinear, FusedDense, RowParallelLinear
25
+ except ImportError:
26
+ FusedDense, ColumnParallelLinear, RowParallelLinear = None, None, None
27
+
28
+ try:
29
+ from flash_attn.layers.rotary import RotaryEmbedding
30
+ except ImportError:
31
+ RotaryEmbedding = None
32
+
33
+
34
+ # From https://github.com/ofirpress/attention_with_linear_biases/blob/4b92f28a005ead2567abe2359f633e73e08f3833/fairseq/models/transformer.py#L742
35
+ def get_alibi_slopes(nheads):
36
+ def get_slopes_power_of_2(nheads):
37
+ start = 2 ** (-(2 ** -(math.log2(nheads) - 3)))
38
+ ratio = start
39
+ return [start * ratio**i for i in range(nheads)]
40
+
41
+ if math.log2(nheads).is_integer():
42
+ return get_slopes_power_of_2(nheads)
43
+ else:
44
+ closest_power_of_2 = 2 ** math.floor(math.log2(nheads))
45
+ return (
46
+ get_slopes_power_of_2(closest_power_of_2)
47
+ + get_alibi_slopes(2 * closest_power_of_2)[0::2][: nheads - closest_power_of_2]
48
+ )
49
+
50
+ class MultiHeadLayernorm(nn.Module):
51
+ def __init__(self, head_dim, num_heads, eps=1e-05, shared_normalization=False):
52
+ super().__init__()
53
+ if shared_normalization:
54
+ self._reduce_dims = (-2, -1)
55
+ else:
56
+ self._reduce_dims = (-1,)
57
+ self.weight = nn.Parameter(torch.ones((num_heads, head_dim)))
58
+ self.bias = nn.Parameter(torch.zeros((num_heads, head_dim)))
59
+ self.eps = eps
60
+
61
+ def forward(self, x):
62
+ var, mean = torch.var_mean(x, dim=self._reduce_dims, keepdim=True)
63
+ x = (x - mean) / torch.sqrt(var + self.eps)
64
+ return self.weight * x + self.bias
65
+
66
+ class FlashSelfAttention(nn.Module):
67
+ """Implement the scaled dot product attention with softmax.
68
+ Arguments
69
+ ---------
70
+ softmax_scale: The temperature to use for the softmax attention.
71
+ (default: 1/sqrt(d_keys) where d_keys is computed at
72
+ runtime)
73
+ attention_dropout: The dropout rate to apply to the attention
74
+ (default: 0.0)
75
+ """
76
+
77
+ def __init__(
78
+ self,
79
+ causal=False,
80
+ softmax_scale=None,
81
+ attention_dropout=0.0,
82
+ window_size=(-1, -1),
83
+ alibi_slopes=None,
84
+ deterministic=False,
85
+ qk_norm_kwargs=None,
86
+ ):
87
+ super().__init__()
88
+ assert flash_attn_varlen_qkvpacked_func is not None, "FlashAttention is not installed"
89
+ assert flash_attn_qkvpacked_func is not None, "FlashAttention is not installed"
90
+ self.causal = causal
91
+ self.softmax_scale = softmax_scale
92
+ self.drop = nn.Dropout(attention_dropout)
93
+ self.register_buffer("alibi_slopes", alibi_slopes, persistent=False)
94
+ self.window_size = window_size
95
+ self.deterministic = deterministic
96
+ if qk_norm_kwargs is not None:
97
+ self.qk_norm = True
98
+ self.q_layernorm = MultiHeadLayernorm(**qk_norm_kwargs)
99
+ self.k_layernorm = MultiHeadLayernorm(**qk_norm_kwargs)
100
+ else:
101
+ self.qk_norm = False
102
+ self.q_layernorm = None
103
+ self.k_layernorm = None
104
+
105
+ def forward(self, qkv, causal=None, cu_seqlens=None, max_seqlen=None):
106
+ """Implements the multihead softmax attention.
107
+ Arguments
108
+ ---------
109
+ qkv: The tensor containing the query, key, and value.
110
+ If cu_seqlens is None and max_seqlen is None, then qkv has shape (B, S, 3, H, D).
111
+ If cu_seqlens is not None and max_seqlen is not None, then qkv has shape
112
+ (total, 3, H, D), where total is the sum of the sequence lengths in the batch.
113
+ causal: if passed, will override self.causal
114
+ cu_seqlens: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
115
+ of the sequences in the batch, used to index into qkv.
116
+ max_seqlen: int. Maximum sequence length in the batch.
117
+ Returns:
118
+ --------
119
+ out: (total, H, D) if cu_seqlens is not None and max_seqlen is not None,
120
+ else (B, S, H, D).
121
+ """
122
+ assert qkv.dtype in [torch.float16, torch.bfloat16]
123
+ assert qkv.is_cuda
124
+ if self.qk_norm:
125
+ if cu_seqlens is None:
126
+ assert qkv.shape[2] == 3
127
+ q, k, v = qkv.unbind(2)
128
+ q = self.q_layernorm(q)
129
+ k = self.k_layernorm(k)
130
+ qkv = torch.stack([q, k, v], dim=2)
131
+ else:
132
+ assert qkv.shape[1] == 3
133
+ q, k, v = qkv.unbind(1)
134
+ q = self.q_layernorm(q)
135
+ k = self.k_layernorm(k)
136
+ qkv = torch.stack([q, k, v], dim=1)
137
+ causal = self.causal if causal is None else causal
138
+ unpadded = cu_seqlens is not None
139
+ if self.alibi_slopes is not None:
140
+ self.alibi_slopes = self.alibi_slopes.to(torch.float32)
141
+ if unpadded:
142
+ assert cu_seqlens.dtype == torch.int32
143
+ assert max_seqlen is not None
144
+ assert isinstance(max_seqlen, int)
145
+ return flash_attn_varlen_qkvpacked_func(
146
+ qkv,
147
+ cu_seqlens,
148
+ max_seqlen,
149
+ self.drop.p if self.training else 0.0,
150
+ softmax_scale=self.softmax_scale,
151
+ causal=causal,
152
+ alibi_slopes=self.alibi_slopes,
153
+ window_size=self.window_size,
154
+ deterministic=self.deterministic,
155
+ )
156
+ else:
157
+ return flash_attn_qkvpacked_func(
158
+ qkv,
159
+ self.drop.p if self.training else 0.0,
160
+ softmax_scale=self.softmax_scale,
161
+ causal=causal,
162
+ alibi_slopes=self.alibi_slopes,
163
+ window_size=self.window_size,
164
+ deterministic=self.deterministic,
165
+ )
166
+
167
+
168
+ class FlashCrossAttention(nn.Module):
169
+ """Implement the scaled dot product attention with softmax.
170
+ Arguments
171
+ ---------
172
+ softmax_scale: The temperature to use for the softmax attention.
173
+ (default: 1/sqrt(d_keys) where d_keys is computed at
174
+ runtime)
175
+ attention_dropout: The dropout rate to apply to the attention
176
+ (default: 0.0)
177
+ """
178
+
179
+ def __init__(
180
+ self,
181
+ causal=False,
182
+ softmax_scale=None,
183
+ attention_dropout=0.0,
184
+ alibi_slopes=None,
185
+ window_size=(-1, -1),
186
+ deterministic=False,
187
+ ):
188
+ super().__init__()
189
+ assert flash_attn_varlen_kvpacked_func is not None, "FlashAttention is not installed"
190
+ assert flash_attn_kvpacked_func is not None, "FlashAttention is not installed"
191
+ self.causal = causal
192
+ self.softmax_scale = softmax_scale
193
+ self.drop = nn.Dropout(attention_dropout)
194
+ self.register_buffer("alibi_slopes", alibi_slopes, persistent=False)
195
+ self.window_size = window_size
196
+ self.deterministic = deterministic
197
+
198
+ def forward(
199
+ self,
200
+ q,
201
+ kv,
202
+ causal=None,
203
+ cu_seqlens=None,
204
+ max_seqlen=None,
205
+ cu_seqlens_k=None,
206
+ max_seqlen_k=None,
207
+ ):
208
+ """Implements the multihead softmax attention.
209
+ Arguments
210
+ ---------
211
+ q: The tensor containing the query. (B, Sq, H, D)
212
+ kv: The tensor containing the key and value. (B, Sk, 2, H_k, D)
213
+ causal: if passed, will override self.causal
214
+ cu_seqlens: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
215
+ of the sequences in the batch, used to index into q.
216
+ max_seqlen: int. Maximum sequence length in the batch of q.
217
+ cu_seqlens_k: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
218
+ of the sequences in the batch, used to index into kv.
219
+ max_seqlen_k: int. Maximum sequence length in the batch of k and v.
220
+ """
221
+ assert q.dtype in [torch.float16, torch.bfloat16]
222
+ assert q.is_cuda and kv.is_cuda
223
+ causal = self.causal if causal is None else causal
224
+ unpadded = cu_seqlens is not None
225
+ if self.alibi_slopes is not None:
226
+ self.alibi_slopes = self.alibi_slopes.to(torch.float32)
227
+ if unpadded:
228
+ assert cu_seqlens.dtype == torch.int32
229
+ assert max_seqlen is not None
230
+ assert isinstance(max_seqlen, int)
231
+ assert cu_seqlens_k is not None
232
+ assert cu_seqlens_k.dtype == torch.int32
233
+ assert max_seqlen_k is not None
234
+ assert isinstance(max_seqlen, int)
235
+ return flash_attn_varlen_kvpacked_func(
236
+ q,
237
+ kv,
238
+ cu_seqlens,
239
+ cu_seqlens_k,
240
+ max_seqlen,
241
+ max_seqlen_k,
242
+ self.drop.p if self.training else 0.0,
243
+ softmax_scale=self.softmax_scale,
244
+ causal=causal,
245
+ alibi_slopes=self.alibi_slopes,
246
+ window_size=self.window_size,
247
+ deterministic=self.deterministic,
248
+ )
249
+ else:
250
+ batch_size, seqlen_q = q.shape[0], q.shape[1]
251
+ seqlen_k = kv.shape[1]
252
+ assert kv.shape[0] == batch_size and kv.shape[4] == q.shape[3]
253
+ return flash_attn_kvpacked_func(
254
+ q,
255
+ kv,
256
+ self.drop.p if self.training else 0.0,
257
+ causal=causal,
258
+ softmax_scale=self.softmax_scale,
259
+ alibi_slopes=self.alibi_slopes,
260
+ window_size=self.window_size,
261
+ deterministic=self.deterministic,
262
+ )
263
+
264
+
265
+ class SelfAttention(nn.Module):
266
+ """Implement the scaled dot product attention with softmax.
267
+ Arguments
268
+ ---------
269
+ softmax_scale: The temperature to use for the softmax attention.
270
+ (default: 1/sqrt(d_keys) where d_keys is computed at
271
+ runtime)
272
+ attention_dropout: The dropout rate to apply to the attention
273
+ (default: 0.0)
274
+ """
275
+ def __init__(self,
276
+ causal=False,
277
+ softmax_scale=None,
278
+ attention_dropout=0.0,
279
+ alibi_slopes=None,
280
+ qk_norm_kwargs=None,
281
+ ):
282
+ super().__init__()
283
+ self.causal = causal
284
+ self.softmax_scale = softmax_scale
285
+ self.drop = nn.Dropout(attention_dropout)
286
+ self.register_buffer('alibi_slopes', alibi_slopes, persistent=False)
287
+ if alibi_slopes is not None:
288
+ self.register_buffer('linear_biases', self._build_linear_biases(16), persistent=False)
289
+ else:
290
+ self.linear_biases = None
291
+ if qk_norm_kwargs is not None:
292
+ self.qk_norm = True
293
+ self.q_layernorm = MultiHeadLayernorm(**qk_norm_kwargs)
294
+ self.k_layernorm = MultiHeadLayernorm(**qk_norm_kwargs)
295
+ else:
296
+ self.qk_norm = False
297
+ self.q_layernorm = None
298
+ self.k_layernorm = None
299
+
300
+ def _build_linear_biases(self, seqlen):
301
+ context_position = torch.arange(seqlen, device=self.alibi_slopes.device)[:, None]
302
+ memory_position = torch.arange(seqlen, device=self.alibi_slopes.device)[None, :]
303
+ # distance tensor is of shape (seqlen, seqlen)
304
+ distance = torch.abs(memory_position - context_position)
305
+ # alibi tensor is of shape (1, H, seqlen, seqlen)
306
+ linear_biases = (distance[None, ...] * self.alibi_slopes[:, None, None])[None, ...]
307
+ return linear_biases
308
+
309
+ def forward(self, qkv, causal=None, key_padding_mask=None):
310
+ """Implements the multihead softmax attention.
311
+ Arguments
312
+ ---------
313
+ qkv: The tensor containing the query, key, and value. (B, S, 3, H, D)
314
+ causal: if passed, will override self.causal
315
+ key_padding_mask: boolean mask to apply to the attention weights. True means to keep,
316
+ False means to mask out. (B, S)
317
+ """
318
+ batch_size, seqlen = qkv.shape[0], qkv.shape[1]
319
+ causal = self.causal if causal is None else causal
320
+ q, k, v = qkv.unbind(dim=2)
321
+ if self.qk_norm:
322
+ q = self.q_layernorm(q)
323
+ k = self.k_layernorm(k)
324
+ softmax_scale = self.softmax_scale or 1.0 / math.sqrt(q.shape[-1])
325
+ scores = torch.einsum("bthd,bshd->bhts", q, k * softmax_scale)
326
+ if key_padding_mask is not None:
327
+ padding_mask = torch.full(
328
+ (batch_size, seqlen), -10000.0, dtype=scores.dtype, device=scores.device
329
+ )
330
+ padding_mask.masked_fill_(key_padding_mask, 0.0)
331
+ # TD [2022-09-30]: Adding is faster than masked_fill_ (idk why, just better kernel I guess)
332
+ scores = scores + rearrange(padding_mask, "b s -> b 1 1 s")
333
+ if self.alibi_slopes is not None:
334
+ if seqlen > self.linear_biases.shape[-1]:
335
+ self.linear_biases = self._build_linear_biases(seqlen)
336
+ cropped_biases = self.linear_biases[..., :seqlen, :seqlen]
337
+ scores = scores - cropped_biases
338
+ if causal:
339
+ # "triu_tril_cuda_template" not implemented for 'BFloat16'
340
+ # So we have to construct the mask in float
341
+ causal_mask = torch.triu(
342
+ torch.full((seqlen, seqlen), -10000.0, device=scores.device), 1
343
+ )
344
+ # TD [2022-09-30]: Adding is faster than masked_fill_ (idk why, just better kernel I guess)
345
+ scores = scores + causal_mask.to(dtype=scores.dtype)
346
+ attention = torch.softmax(scores, dim=-1, dtype=v.dtype)
347
+ attention_drop = self.drop(attention)
348
+ output = torch.einsum("bhts,bshd->bthd", attention_drop, v)
349
+ return output
350
+
351
+
352
+ class CrossAttention(nn.Module):
353
+ """Implement the scaled dot product attention with softmax.
354
+ Arguments
355
+ ---------
356
+ softmax_scale: The temperature to use for the softmax attention.
357
+ (default: 1/sqrt(d_keys) where d_keys is computed at
358
+ runtime)
359
+ attention_dropout: The dropout rate to apply to the attention
360
+ (default: 0.0)
361
+ """
362
+
363
+ def __init__(self, causal=False, softmax_scale=None, attention_dropout=0.0):
364
+ super().__init__()
365
+ self.causal = causal
366
+ self.softmax_scale = softmax_scale
367
+ self.drop = nn.Dropout(attention_dropout)
368
+
369
+ def forward(self, q, kv, causal=None, key_padding_mask=None):
370
+ """Implements the multihead softmax attention.
371
+ Arguments
372
+ ---------
373
+ q: The tensor containing the query. (B, Sq, H, D)
374
+ kv: The tensor containing the key and value. (B, Sk, 2, H_k, D)
375
+ causal: if passed, will override self.causal
376
+ key_padding_mask: boolean mask to apply to the attention weights. True means to keep,
377
+ False means to mask out. (B, Sk)
378
+ """
379
+ batch_size, seqlen_q = q.shape[0], q.shape[1]
380
+ causal = self.causal if causal is None else causal
381
+ seqlen_k = kv.shape[1]
382
+ assert kv.shape[0] == batch_size and kv.shape[4] == q.shape[3]
383
+ if kv.shape[3] != q.shape[2]: # MQA/GQA
384
+ kv = repeat(kv, "... hkv d -> ... (hkv g) d", g=q.shape[2] // kv.shape[3])
385
+ k, v = kv.unbind(dim=2)
386
+ softmax_scale = self.softmax_scale or 1.0 / math.sqrt(q.shape[-1])
387
+ scores = torch.einsum("bthd,bshd->bhts", q, k * softmax_scale)
388
+ if key_padding_mask is not None:
389
+ padding_mask = torch.full(
390
+ (batch_size, seqlen_k), -10000.0, dtype=scores.dtype, device=scores.device
391
+ )
392
+ padding_mask.masked_fill_(key_padding_mask, 0.0)
393
+ # TD [2022-09-30]: Adding is faster than masked_fill_ (idk why, just better kernel I guess)
394
+ scores = scores + rearrange(padding_mask, "b s -> b 1 1 s")
395
+ if causal:
396
+ # causal mask needs to take into account the difference between seqlen_q and seqlen_k
397
+ row_idx = rearrange(
398
+ torch.arange(seqlen_q, device=q.device, dtype=torch.long), "s -> s 1"
399
+ )
400
+ col_idx = torch.arange(seqlen_k, device=kv.device, dtype=torch.long)
401
+ sk = (
402
+ seqlen_k
403
+ if key_padding_mask is None
404
+ else rearrange(key_padding_mask.sum(-1), "b -> b 1 1 1")
405
+ )
406
+ causal_mask = col_idx > row_idx + sk - seqlen_q
407
+ scores = scores.masked_fill(causal_mask, -10000.0)
408
+ attention = torch.softmax(scores, dim=-1, dtype=v.dtype)
409
+ attention_drop = self.drop(attention)
410
+ output = torch.einsum("bhts,bshd->bthd", attention_drop, v)
411
+ return output
412
+
413
+
414
+ class LinearResidual(nn.Linear):
415
+ """Wrap nn.Linear to return the residual as well. For compatibility with FusedDense."""
416
+
417
+ def forward(self, input: torch.Tensor) -> torch.Tensor:
418
+ return super().forward(input), input
419
+
420
+
421
+ def _update_kv_cache(kv, inference_params, layer_idx):
422
+ """kv: (batch_size, seqlen, 2, nheads, head_dim) or (batch_size, 1, 2, nheads, head_dim)"""
423
+ # Pre-allocate memory for key-values for inference.
424
+ num_heads, head_dim = kv.shape[-2:]
425
+ if layer_idx not in inference_params.key_value_memory_dict:
426
+ kv_cache = torch.empty(
427
+ inference_params.max_batch_size,
428
+ inference_params.max_seqlen,
429
+ 2,
430
+ num_heads,
431
+ head_dim,
432
+ dtype=kv.dtype,
433
+ device=kv.device,
434
+ )
435
+ inference_params.key_value_memory_dict[layer_idx] = kv_cache
436
+ else:
437
+ kv_cache = inference_params.key_value_memory_dict[layer_idx]
438
+ # Adjust key and value for inference
439
+ batch_start = inference_params.batch_size_offset
440
+ batch_end = batch_start + kv.shape[0]
441
+ sequence_start = inference_params.seqlen_offset
442
+ sequence_end = sequence_start + kv.shape[1]
443
+ assert batch_end <= kv_cache.shape[0]
444
+ assert sequence_end <= kv_cache.shape[1]
445
+ assert kv_cache is not None
446
+ kv_cache[batch_start:batch_end, sequence_start:sequence_end, ...] = kv
447
+ return kv_cache[batch_start:batch_end, :sequence_end, ...]
448
+
449
+
450
+ class MHA(nn.Module):
451
+ """Multi-head self-attention and cross-attention"""
452
+
453
+ def __init__(
454
+ self,
455
+ embed_dim,
456
+ num_heads,
457
+ num_heads_kv=None,
458
+ cross_attn=False,
459
+ qkv_proj_bias=True,
460
+ out_proj_bias=True,
461
+ dropout=0.0,
462
+ softmax_scale=None,
463
+ causal=False,
464
+ layer_idx=None,
465
+ dwconv=False,
466
+ rotary_emb_dim=0,
467
+ rotary_emb_base=10000.0,
468
+ rotary_emb_scale_base=None,
469
+ rotary_emb_interleaved=False,
470
+ use_alibi=False,
471
+ window_size=(-1, -1),
472
+ fused_bias_fc=False,
473
+ use_flash_attn=False,
474
+ return_residual=False,
475
+ checkpointing=False,
476
+ device=None,
477
+ dtype=None,
478
+ qk_norm=False,
479
+ qk_norm_kwargs=None,
480
+ ) -> None:
481
+ """
482
+ num_heads_kv: can be used to toggle MQA / GQA. If None, use num_heads.
483
+ return_residual: whether to return the input x along with the output. This is for
484
+ performance reason: for post-norm architecture, returning the input allows us
485
+ to fuse the backward of nn.Linear with the residual connection.
486
+ """
487
+ if qk_norm and cross_attn:
488
+ raise NotImplementedError('QK normalization is only implemented for self-attention.')
489
+ if qk_norm:
490
+ qk_norm_kwargs = qk_norm_kwargs if qk_norm_kwargs is not None else {}
491
+ qk_norm_kwargs.update({'num_heads': num_heads, 'head_dim': embed_dim // num_heads})
492
+ factory_kwargs = {"device": device, "dtype": dtype}
493
+ super().__init__()
494
+ self.embed_dim = embed_dim
495
+ self.cross_attn = cross_attn
496
+ self.causal = causal
497
+ self.layer_idx = layer_idx
498
+ self.dwconv = dwconv
499
+ self.rotary_emb_dim = rotary_emb_dim
500
+ self.use_flash_attn = use_flash_attn
501
+ self.return_residual = return_residual
502
+ self.checkpointing = checkpointing
503
+ if use_alibi:
504
+ assert not cross_attn or use_flash_attn, "ALiBi code path requires self-attention or cross-attention with flash_attn"
505
+ alibi_slopes = torch.tensor(get_alibi_slopes(num_heads), device=device)
506
+ else:
507
+ alibi_slopes = None
508
+ if window_size != (-1, -1):
509
+ assert use_flash_attn, "Local (sliding window) attention code path requires flash_attn"
510
+
511
+ self.num_heads = num_heads
512
+ self.num_heads_kv = num_heads_kv if num_heads_kv is not None else num_heads
513
+ assert (
514
+ self.num_heads % self.num_heads_kv == 0
515
+ ), "num_heads must be divisible by num_heads_kv"
516
+ assert self.embed_dim % num_heads == 0, "embed_dim must be divisible by num_heads"
517
+ self.head_dim = self.embed_dim // num_heads
518
+ qkv_dim = self.head_dim * (self.num_heads + 2 * self.num_heads_kv)
519
+ kv_dim = 2 * self.head_dim * self.num_heads_kv
520
+
521
+ if self.rotary_emb_dim > 0:
522
+ assert not cross_attn, "MHA with rotary embedding does not support cross-attention yet"
523
+ assert RotaryEmbedding is not None, "rotary_emb is not installed"
524
+ self.rotary_emb = RotaryEmbedding(
525
+ self.rotary_emb_dim,
526
+ base=rotary_emb_base,
527
+ scale_base=rotary_emb_scale_base,
528
+ interleaved=rotary_emb_interleaved,
529
+ device=device,
530
+ )
531
+
532
+ if fused_bias_fc and FusedDense is None:
533
+ raise ImportError("fused_dense is not installed")
534
+ linear_cls = nn.Linear if not fused_bias_fc else FusedDense
535
+ linear_resid_cls = (
536
+ LinearResidual if not fused_bias_fc else partial(FusedDense, return_residual=True)
537
+ )
538
+ wqkv_cls = linear_cls if not self.return_residual else linear_resid_cls
539
+ inner_attn_cls = (
540
+ partial(FlashSelfAttention, alibi_slopes=alibi_slopes, window_size=window_size, qk_norm_kwargs=qk_norm_kwargs)
541
+ if use_flash_attn
542
+ else partial(SelfAttention, alibi_slopes=alibi_slopes, qk_norm_kwargs=qk_norm_kwargs)
543
+ )
544
+ inner_cross_attn_cls = (
545
+ partial(FlashCrossAttention, alibi_slopes=alibi_slopes, window_size=window_size)
546
+ if use_flash_attn
547
+ else CrossAttention
548
+ )
549
+ if not self.cross_attn:
550
+ self.Wqkv = wqkv_cls(embed_dim, qkv_dim, bias=qkv_proj_bias, **factory_kwargs)
551
+ else:
552
+ self.Wq = linear_cls(embed_dim, embed_dim, bias=qkv_proj_bias, **factory_kwargs)
553
+ self.Wkv = wqkv_cls(embed_dim, kv_dim, bias=qkv_proj_bias, **factory_kwargs)
554
+ if self.dwconv:
555
+ if self.num_heads_kv == self.num_heads:
556
+ self.dwconv_qkv = nn.Conv1d(
557
+ qkv_dim, qkv_dim, kernel_size=3, padding=2, groups=qkv_dim
558
+ )
559
+ else:
560
+ self.dwconv_q = nn.Conv1d(
561
+ embed_dim, embed_dim, kernel_size=3, padding=2, groups=embed_dim
562
+ )
563
+ self.dwconv_kv = nn.Conv1d(kv_dim, kv_dim, kernel_size=3, padding=2, groups=kv_dim)
564
+ self.inner_attn = inner_attn_cls(
565
+ causal=causal,
566
+ softmax_scale=softmax_scale,
567
+ attention_dropout=dropout,
568
+ )
569
+ self.inner_cross_attn = inner_cross_attn_cls(
570
+ causal=causal, softmax_scale=softmax_scale, attention_dropout=dropout
571
+ )
572
+ self.out_proj = linear_cls(embed_dim, embed_dim, bias=out_proj_bias, **factory_kwargs)
573
+
574
+ def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None):
575
+ dtype = self.out_proj.weight.dtype if dtype is None else dtype
576
+ device = self.out_proj.weight.device
577
+ return torch.empty(
578
+ batch_size,
579
+ max_seqlen,
580
+ 2,
581
+ self.num_heads_kv,
582
+ self.head_dim,
583
+ dtype=dtype,
584
+ device=device,
585
+ )
586
+
587
+ def _update_kv_cache(self, kv, inference_params):
588
+ """kv: (batch_size, seqlen, 2, nheads, head_dim) or (batch_size, 1, 2, nheads, head_dim)"""
589
+ assert not self.dwconv, "Generation does not support dwconv yet"
590
+ assert self.layer_idx is not None, "Generation requires layer_idx in the constructor"
591
+ return _update_kv_cache(kv, inference_params, self.layer_idx)
592
+
593
+ def _apply_rotary_update_kvcache_attention(self, q, kv, inference_params):
594
+ """
595
+ Fast path that combine 3 steps: apply rotary to Q and K, update kv cache, and apply attention.
596
+ q: (batch_size, seqlen_q, nheads, head_dim)
597
+ kv: (batch_size, seqlen_k, 2, nheads_kv, head_dim)
598
+ """
599
+ assert inference_params is not None and inference_params.seqlen_offset > 0
600
+ assert self.use_flash_attn
601
+ if self.rotary_emb_dim > 0:
602
+ assert self.rotary_emb.scale is None, "This code path does not support xPos"
603
+ self.rotary_emb._update_cos_sin_cache(
604
+ inference_params.max_seqlen, device=q.device, dtype=q.dtype
605
+ )
606
+ rotary_cos, rotary_sin = self.rotary_emb._cos_cached, self.rotary_emb._sin_cached
607
+ else:
608
+ rotary_cos, rotary_sin = None, None
609
+ batch = q.shape[0]
610
+ kv_cache = inference_params.key_value_memory_dict[self.layer_idx][:batch]
611
+ cache_seqlens = (
612
+ inference_params.lengths_per_sample[:batch]
613
+ if inference_params.lengths_per_sample is not None
614
+ else inference_params.seqlen_offset
615
+ )
616
+ alibi_slopes = getattr(self.inner_cross_attn, "alibi_slopes", None)
617
+ context = flash_attn_with_kvcache(
618
+ q,
619
+ kv_cache[:, :, 0],
620
+ kv_cache[:, :, 1],
621
+ kv[:, :, 0],
622
+ kv[:, :, 1],
623
+ rotary_cos=rotary_cos,
624
+ rotary_sin=rotary_sin,
625
+ cache_seqlens=cache_seqlens,
626
+ softmax_scale=self.inner_cross_attn.softmax_scale,
627
+ causal=self.inner_cross_attn.causal,
628
+ rotary_interleaved=self.rotary_emb.interleaved if self.rotary_emb_dim > 0 else False,
629
+ alibi_slopes=alibi_slopes,
630
+ )
631
+ return context
632
+
633
+ def _update_kvcache_attention(self, q, kv, inference_params):
634
+ """Write kv to inference_params, then do attention"""
635
+ if (
636
+ inference_params.seqlen_offset == 0
637
+ or flash_attn_with_kvcache is None
638
+ or not self.use_flash_attn
639
+ ):
640
+ # TODO: this only uses seqlen_offset and not lengths_per_sample.
641
+ kv = self._update_kv_cache(kv, inference_params)
642
+ return self.inner_cross_attn(q, kv)
643
+ else:
644
+ batch = q.shape[0]
645
+ kv_cache = inference_params.key_value_memory_dict[self.layer_idx][:batch]
646
+ cache_seqlens = (
647
+ inference_params.lengths_per_sample[:batch]
648
+ if inference_params.lengths_per_sample is not None
649
+ else inference_params.seqlen_offset
650
+ )
651
+ alibi_slopes = getattr(self.inner_cross_attn, "alibi_slopes", None)
652
+ return flash_attn_with_kvcache(
653
+ q,
654
+ kv_cache[:, :, 0],
655
+ kv_cache[:, :, 1],
656
+ kv[:, :, 0],
657
+ kv[:, :, 1],
658
+ cache_seqlens=cache_seqlens,
659
+ softmax_scale=self.inner_cross_attn.softmax_scale,
660
+ causal=self.inner_cross_attn.causal,
661
+ alibi_slopes=alibi_slopes,
662
+ )
663
+
664
+ def forward(
665
+ self,
666
+ x,
667
+ x_kv=None,
668
+ key_padding_mask=None,
669
+ cu_seqlens=None,
670
+ max_seqlen=None,
671
+ mixer_subset=None,
672
+ inference_params=None,
673
+ **kwargs,
674
+ ):
675
+ """
676
+ Arguments:
677
+ x: (batch, seqlen, hidden_dim) (where hidden_dim = num heads * head dim) if
678
+ cu_seqlens is None and max_seqlen is None, else (total, hidden_dim) where total
679
+ is the is the sum of the sequence lengths in the batch.
680
+ x_kv: (batch, seqlen, hidden_dim), only applicable for cross-attention. If None, use x.
681
+ cu_seqlens: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
682
+ of the sequences in the batch, used to index into x. Only applicable when using
683
+ FlashAttention.
684
+ max_seqlen: int. Maximum sequence length in the batch.
685
+ key_padding_mask: boolean mask, True means to keep, False means to mask out.
686
+ (batch, seqlen). Only applicable when not using FlashAttention.
687
+ mixer_subset: for cross-attention only. If not None, will take a subset of x
688
+ before applying the query projection. Useful for e.g., ViT where we only care
689
+ about the CLS token in the last layer.
690
+ inference_params: for generation. Adapted from Megatron-LM (and Apex)
691
+ https://github.com/NVIDIA/apex/blob/3ff1a10f72ec07067c4e44759442329804ac5162/apex/transformer/testing/standalone_transformer_lm.py#L470
692
+ """
693
+ if cu_seqlens is not None:
694
+ assert max_seqlen is not None
695
+ assert key_padding_mask is None
696
+ assert self.use_flash_attn
697
+ assert not self.dwconv
698
+ assert self.rotary_emb_dim == 0
699
+ if key_padding_mask is not None:
700
+ assert cu_seqlens is None
701
+ assert max_seqlen is None
702
+ assert not self.use_flash_attn
703
+ if inference_params is not None:
704
+ assert key_padding_mask is None
705
+ assert cu_seqlens is None and max_seqlen is None
706
+ assert not self.dwconv
707
+
708
+ kwargs = (
709
+ {"cu_seqlens": cu_seqlens, "max_seqlen": max_seqlen, **kwargs}
710
+ if self.use_flash_attn
711
+ else {"key_padding_mask": key_padding_mask, **kwargs}
712
+ )
713
+ seqlen_offset = (
714
+ 0
715
+ if inference_params is None
716
+ else (
717
+ inference_params.lengths_per_sample
718
+ if inference_params.lengths_per_sample is not None
719
+ else inference_params.seqlen_offset
720
+ )
721
+ )
722
+ rotary_max_seqlen = inference_params.max_seqlen if inference_params is not None else None
723
+ batch, seqlen = x.shape[:2]
724
+ if not self.cross_attn and self.num_heads_kv == self.num_heads:
725
+ assert x_kv is None and mixer_subset is None
726
+ if not self.return_residual:
727
+ qkv = self.Wqkv(x)
728
+ else:
729
+ qkv, x = self.Wqkv(x)
730
+ if self.dwconv:
731
+ qkv = rearrange(
732
+ self.dwconv_qkv(rearrange(qkv, "b s d -> b d s"))[..., :-2], "b d s -> b s d"
733
+ ).contiguous()
734
+ qkv = rearrange(qkv, "... (three h d) -> ... three h d", three=3, d=self.head_dim)
735
+ if (
736
+ inference_params is None
737
+ or inference_params.seqlen_offset == 0
738
+ or (self.rotary_emb_dim == 0 or self.rotary_emb_dim % 16 != 0)
739
+ or not self.use_flash_attn
740
+ ):
741
+ if self.rotary_emb_dim > 0:
742
+ qkv = self.rotary_emb(
743
+ qkv, seqlen_offset=seqlen_offset, max_seqlen=rotary_max_seqlen
744
+ )
745
+ if inference_params is None:
746
+ if not self.checkpointing:
747
+ context = self.inner_attn(qkv, **kwargs)
748
+ else:
749
+ context = torch.utils.checkpoint.checkpoint(self.inner_attn, qkv, use_reentrant=False, **kwargs)
750
+ else:
751
+ context = self._update_kvcache_attention(
752
+ qkv[:, :, 0], qkv[:, :, 1:], inference_params
753
+ )
754
+ else:
755
+ context = self._apply_rotary_update_kvcache_attention(
756
+ qkv[:, :, 0], qkv[:, :, 1:], inference_params
757
+ )
758
+ else:
759
+ if self.cross_attn:
760
+ if not self.return_residual:
761
+ q = self.Wq(x if mixer_subset is None else x[:, mixer_subset])
762
+ kv = self.Wkv(x_kv if x_kv is not None else x)
763
+ else:
764
+ if x_kv is not None:
765
+ kv, x_kv = self.Wkv(x_kv)
766
+ else:
767
+ kv, x = self.Wkv(x)
768
+ q = self.Wq(x if mixer_subset is None else x[:, mixer_subset])
769
+ else:
770
+ assert self.num_heads_kv != self.num_heads
771
+ if not self.return_residual:
772
+ qkv = self.Wqkv(x)
773
+ else:
774
+ qkv, x = self.Wqkv(x)
775
+ q = qkv[..., : self.num_heads * self.head_dim]
776
+ kv = qkv[..., self.num_heads * self.head_dim :]
777
+ q = rearrange(q, "... (h d) -> ... h d", d=self.head_dim)
778
+ kv = rearrange(kv, "... (two hkv d) -> ... two hkv d", two=2, d=self.head_dim)
779
+ if self.dwconv:
780
+ q = rearrange(
781
+ self.dwconv_q(rearrange(q, "b s d -> b d s"))[..., :-2], "b d s -> b s d"
782
+ ).contiguous()
783
+ kv = rearrange(
784
+ self.dwconv_kv(rearrange(kv, "b s d -> b d s"))[..., :-2], "b d s -> b s d"
785
+ ).contiguous()
786
+ if (
787
+ inference_params is None
788
+ or inference_params.seqlen_offset == 0
789
+ or (self.rotary_emb_dim == 0 or self.rotary_emb_dim % 16 != 0)
790
+ or not self.use_flash_attn
791
+ ):
792
+ if self.rotary_emb_dim > 0:
793
+ q, kv = self.rotary_emb(
794
+ q, kv, seqlen_offset=seqlen_offset, max_seqlen=rotary_max_seqlen
795
+ )
796
+ if inference_params is None:
797
+ if not self.checkpointing:
798
+ context = self.inner_cross_attn(q, kv, **kwargs)
799
+ else:
800
+ context = torch.utils.checkpoint.checkpoint(
801
+ self.inner_cross_attn, q, kv, use_reentrant=False, **kwargs
802
+ )
803
+ else:
804
+ context = self._update_kvcache_attention(q, kv, inference_params)
805
+ else:
806
+ context = self._apply_rotary_update_kvcache_attention(q, kv, inference_params)
807
+ out = self.out_proj(rearrange(context, "... h d -> ... (h d)"))
808
+ return out if not self.return_residual else (out, x)
flash_components/mlp.py ADDED
@@ -0,0 +1,191 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2023, Tri Dao.
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+ import torch.nn.functional as F
6
+ from torch.distributed import ProcessGroup
7
+
8
+
9
+ try:
10
+ from flash_attn.ops.activations import swiglu
11
+ except ImportError:
12
+ swiglu = None
13
+
14
+ try:
15
+ from flash_attn.ops.fused_dense import ColumnParallelLinear, RowParallelLinear
16
+ except ImportError:
17
+ ColumnParallelLinear, RowParallelLinear = None, None
18
+
19
+ try:
20
+ from flash_attn.ops.fused_dense import FusedMLP, ParallelFusedMLP
21
+ except ImportError:
22
+ FusedMLP, ParallelFusedMLP = None, None
23
+
24
+
25
+ class Mlp(nn.Module):
26
+ def __init__(
27
+ self,
28
+ in_features,
29
+ hidden_features=None,
30
+ out_features=None,
31
+ activation=F.gelu,
32
+ bias1=True,
33
+ bias2=True,
34
+ return_residual=False,
35
+ device=None,
36
+ dtype=None,
37
+ ):
38
+ factory_kwargs = {"device": device, "dtype": dtype}
39
+ super().__init__()
40
+ out_features = out_features if out_features is not None else in_features
41
+ hidden_features = hidden_features if hidden_features is not None else in_features * 4
42
+ self.return_residual = return_residual
43
+ self.fc1 = nn.Linear(in_features, hidden_features, bias=bias1, **factory_kwargs)
44
+ self.activation = activation
45
+ self.fc2 = nn.Linear(hidden_features, out_features, bias=bias2, **factory_kwargs)
46
+
47
+ def forward(self, x):
48
+ y = self.fc1(x)
49
+ y = self.activation(y)
50
+ y = self.fc2(y)
51
+ return y if not self.return_residual else (y, x)
52
+
53
+
54
+ class ParallelMLP(nn.Module):
55
+ def __init__(
56
+ self,
57
+ in_features,
58
+ hidden_features=None,
59
+ out_features=None,
60
+ activation=F.gelu,
61
+ process_group: ProcessGroup = None,
62
+ sequence_parallel=True,
63
+ bias1=True,
64
+ bias2=True,
65
+ device=None,
66
+ dtype=None,
67
+ ):
68
+ factory_kwargs = {"device": device, "dtype": dtype}
69
+ super().__init__()
70
+ assert ColumnParallelLinear is not None, "Need to install fused_dense"
71
+ assert RowParallelLinear is not None, "Need to install fused_dense"
72
+ out_features = out_features if out_features is not None else in_features
73
+ hidden_features = hidden_features if hidden_features is not None else in_features * 4
74
+ self.fc1 = ColumnParallelLinear(
75
+ in_features,
76
+ hidden_features,
77
+ process_group,
78
+ bias=bias1,
79
+ sequence_parallel=sequence_parallel,
80
+ **factory_kwargs,
81
+ )
82
+ self.activation = activation
83
+ self.fc2 = RowParallelLinear(
84
+ hidden_features,
85
+ out_features,
86
+ process_group,
87
+ bias=bias2,
88
+ sequence_parallel=sequence_parallel,
89
+ **factory_kwargs,
90
+ )
91
+
92
+ def forward(self, x):
93
+ y = self.fc1(x)
94
+ y = self.activation(y)
95
+ y = self.fc2(y)
96
+ return y
97
+
98
+
99
+ class GatedMlp(nn.Module):
100
+ def __init__(
101
+ self,
102
+ in_features,
103
+ hidden_features=None,
104
+ out_features=None,
105
+ activation=F.sigmoid,
106
+ bias1=True,
107
+ bias2=True,
108
+ multiple_of=128,
109
+ return_residual=False,
110
+ device=None,
111
+ dtype=None,
112
+ ):
113
+ factory_kwargs = {"device": device, "dtype": dtype}
114
+ super().__init__()
115
+ out_features = out_features if out_features is not None else in_features
116
+ hidden_features = (
117
+ hidden_features if hidden_features is not None else int(8 * in_features / 3)
118
+ )
119
+ hidden_features = (hidden_features + multiple_of - 1) // multiple_of * multiple_of
120
+ self.return_residual = return_residual
121
+ self.fc1 = nn.Linear(in_features, 2 * hidden_features, bias=bias1, **factory_kwargs)
122
+ self.activation = activation
123
+ self.fc2 = nn.Linear(hidden_features, out_features, bias=bias2, **factory_kwargs)
124
+
125
+ def forward(self, x):
126
+ y = self.fc1(x)
127
+ if self.activation == F.sigmoid: # Special case for GLU
128
+ y = F.glu(y, dim=-1)
129
+ elif self.activation == F.silu and swiglu is not None: # Special case for SwiGLU
130
+ y, gate = y.chunk(2, dim=-1)
131
+ y = swiglu(gate, y)
132
+ else:
133
+ y, gate = y.chunk(2, dim=-1)
134
+ y = y * self.activation(gate)
135
+ y = self.fc2(y)
136
+ return y if not self.return_residual else (y, x)
137
+
138
+
139
+ class ParallelGatedMlp(nn.Module):
140
+ """Parallel GatedMlp"""
141
+
142
+ def __init__(
143
+ self,
144
+ in_features,
145
+ process_group,
146
+ hidden_features=None,
147
+ out_features=None,
148
+ activation=F.sigmoid,
149
+ bias1=True,
150
+ bias2=True,
151
+ multiple_of=128,
152
+ sequence_parallel=True,
153
+ device=None,
154
+ dtype=None,
155
+ ):
156
+ factory_kwargs = {"device": device, "dtype": dtype}
157
+ super().__init__()
158
+ out_features = out_features if out_features is not None else in_features
159
+ hidden_features = (
160
+ hidden_features if hidden_features is not None else int(8 * in_features / 3)
161
+ )
162
+ hidden_features = (hidden_features + multiple_of - 1) // multiple_of * multiple_of
163
+ if ColumnParallelLinear is None or RowParallelLinear is None:
164
+ raise ImportError("fused_dense is not installed")
165
+ self.fc1 = ColumnParallelLinear(
166
+ in_features,
167
+ 2 * hidden_features,
168
+ process_group,
169
+ bias=bias1,
170
+ sequence_parallel=sequence_parallel,
171
+ **factory_kwargs,
172
+ )
173
+ self.activation = activation
174
+ self.fc2 = RowParallelLinear(
175
+ hidden_features,
176
+ out_features,
177
+ process_group,
178
+ bias=bias2,
179
+ sequence_parallel=sequence_parallel,
180
+ **factory_kwargs,
181
+ )
182
+
183
+ def forward(self, x):
184
+ y = self.fc1(x)
185
+ if self.activation == F.sigmoid: # Special case for GLU
186
+ y = F.glu(y, dim=-1)
187
+ else:
188
+ y, gate = y.chunk(2, dim=-1)
189
+ y = y * self.activation(gate)
190
+ y = self.fc2(y)
191
+ return y
modeling_bert.py CHANGED
@@ -29,17 +29,17 @@ from transformers.models.bert.modeling_bert import (
29
  BaseModelOutputWithPoolingAndCrossAttentions,
30
  BertForPreTrainingOutput,
31
  )
32
- from flash_attn.bert_padding import (
33
  index_first_axis,
34
  index_first_axis_residual,
35
  pad_input,
36
  unpad_input,
37
  )
38
 
39
- from flash_attn.modules.block import Block
40
- from flash_attn.modules.embedding import BertEmbeddings
41
- from flash_attn.modules.mha import MHA
42
- from flash_attn.modules.mlp import FusedMLP, Mlp
43
 
44
  try:
45
  from flash_attn.ops.fused_dense import FusedDense
 
29
  BaseModelOutputWithPoolingAndCrossAttentions,
30
  BertForPreTrainingOutput,
31
  )
32
+ from .flash_components.bert_padding import (
33
  index_first_axis,
34
  index_first_axis_residual,
35
  pad_input,
36
  unpad_input,
37
  )
38
 
39
+ from .flash_components.block import Block
40
+ from .flash_components.embedding import BertEmbeddings
41
+ from .flash_components.mha import MHA
42
+ from .flash_components.mlp import FusedMLP, Mlp
43
 
44
  try:
45
  from flash_attn.ops.fused_dense import FusedDense
small_config.json DELETED
@@ -1,30 +0,0 @@
1
- {
2
- "_name_or_path": "jinaai/jina-bert-flash-implementation",
3
- "auto_map": {
4
- "AutoConfig": "jinaai/jina-bert-flash-implementation--configuration_bert.JinaBertConfig",
5
- "AutoModel": "jinaai/jina-bert-flash-implementation--modeling_bert.BertModel",
6
- "AutoModelForPreTraining": "jinaai/jina-bert-flash-implementation--modeling_bert.BertForPreTraining",
7
- "AutoModelForMaskedLM": "jinaai/jina-bert-flash-implementation--modeling_bert.BertForPreTraining"
8
- },
9
- "vocab_size": 30528,
10
- "hidden_size": 512,
11
- "num_hidden_layers": 4,
12
- "num_attention_heads": 8,
13
- "intermediate_size": 2048,
14
- "hidden_act": "gelu",
15
- "hidden_dropout_prob": 0.1,
16
- "attention_probs_dropout_prob": 0.1,
17
- "type_vocab_size": 0,
18
- "initializer_range": 0.02,
19
- "layer_norm_eps": 1e-12,
20
- "pad_token_id": 0,
21
- "dense_seq_output": true,
22
- "fused_mlp": false,
23
- "mlp_checkpoint_lvl": 0,
24
- "last_layer_subset": false,
25
- "fused_dropout_add_ln": false,
26
- "fused_bias_fc": false,
27
- "pad_vocab_size_multiple": 1,
28
- "num_tasks": 6,
29
- "use_flash_attn": true
30
- }