Pklett commited on
Commit
c2d160f
·
1 Parent(s): 925cd56

upload custom code

Browse files
Files changed (5) hide show
  1. attention.py +771 -0
  2. blocks.py +120 -0
  3. configuration.py +207 -0
  4. modeling_mpt.py +837 -0
  5. utils.py +17 -0
attention.py ADDED
@@ -0,0 +1,771 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Adapted from https://github.com/mosaicml/llm-foundry
2
+ # Classes changed: MultiheadAttention
3
+ # Functions changed: scaled_multihead_dot_product_attention, build_alibi_bias, build_attn_bias
4
+ # SPDX-License-Identifier: Apache-2.0
5
+
6
+ """Attention layers."""
7
+ import math
8
+ import warnings
9
+ from typing import Optional
10
+ import torch
11
+ import torch.nn as nn
12
+ from einops import rearrange
13
+ from packaging import version
14
+ from torch import nn
15
+ from torch.linalg import vector_norm
16
+ from llmfoundry.models.layers.norm import LPLayerNorm
17
+ from torch.nn import functional as F
18
+
19
+ def _reset_is_causal(num_query_tokens: int, num_key_tokens: int,
20
+ original_is_causal: bool):
21
+ # disable causal when it is not needed
22
+ # necessary for flash & triton for generation with kv_cache
23
+ if original_is_causal and num_query_tokens != num_key_tokens:
24
+ if num_query_tokens != 1:
25
+ raise NotImplementedError(
26
+ 'MPT does not support query and key with different number of tokens, unless number of query tokens is 1.'
27
+ )
28
+ else:
29
+ return False
30
+ return original_is_causal
31
+
32
+
33
+ def scaled_multihead_dot_product_attention(
34
+ query,
35
+ key,
36
+ value,
37
+ n_heads,
38
+ past_key_value=None,
39
+ long_range_past_key_value=None,
40
+ softmax_scale=None,
41
+ attn_bias=None,
42
+ attn_bias_ae=None,
43
+ key_padding_mask=None,
44
+ is_causal=False,
45
+ dropout_p=0.0,
46
+ training=False,
47
+ needs_weights=False,
48
+ multiquery=False,
49
+ topk=None,
50
+ faiss_indexes=None,
51
+ n_layers=None,
52
+ current_layer=None,
53
+ mask_by_sim=False,
54
+ sim_threshold=0.0
55
+ ):
56
+ q = rearrange(query, 'b s (h d) -> b h s d', h=n_heads)
57
+ kv_n_heads = 1 if multiquery else n_heads
58
+ k = rearrange(key, 'b s (h d) -> b h d s', h=kv_n_heads)
59
+ v = rearrange(value, 'b s (h d) -> b h s d', h=kv_n_heads)
60
+
61
+ had_kv=False
62
+ if past_key_value is not None:
63
+ # attn_impl: flash & triton use kernels which expect input shape [b, s, h, d_head].
64
+ # kv_cache is therefore stored using that shape.
65
+ # attn_impl: torch stores the kv_cache in the ordering which is most advantageous
66
+ # for its attn computation ie
67
+ # keys are stored as tensors with shape [b, h, d_head, s] and
68
+ # values are stored as tensors with shape [b, h, s, d_head]
69
+ if len(past_key_value) != 0:
70
+ k = torch.cat([past_key_value[0], k], dim=3)
71
+ v = torch.cat([past_key_value[1], v], dim=2)
72
+ had_kv=True
73
+
74
+ past_key_value = (k, v)
75
+
76
+ b, h, s_q, d = q.shape
77
+ s_k = k.size(-1)
78
+
79
+ if softmax_scale is None:
80
+ softmax_scale = 1 / math.sqrt(d)
81
+
82
+ attn_weight = q.matmul(k) * softmax_scale
83
+
84
+ if attn_bias is not None:
85
+ # clamp to 0 necessary for torch 2.0 compile()
86
+ _s_q = max(0, attn_bias.size(2) - s_q)
87
+ _s_k = max(0, attn_bias.size(3) - s_k)
88
+ attn_bias = attn_bias[:, :, _s_q:, _s_k:]
89
+
90
+ if (attn_bias.size(-1) != 1 and
91
+ attn_bias.size(-1) != s_k) or (attn_bias.size(-2) != 1 and
92
+ attn_bias.size(-2) != s_q):
93
+ raise RuntimeError(
94
+ f'attn_bias (shape: {attn_bias.shape}) is expected to broadcast to shape: {attn_weight.shape}.'
95
+ )
96
+ attn_weight = attn_weight + attn_bias
97
+
98
+ if needs_weights:
99
+ reshaped_idx = None
100
+ if long_range_past_key_value is not None or faiss_indexes is not None:
101
+ if long_range_past_key_value is not None: #manual external memories
102
+
103
+ k_cache, v_cache = long_range_past_key_value
104
+ s_cache = k_cache.size(-1)
105
+
106
+ k_cache = k_cache.to(k.device)
107
+ v_cache = v_cache.to(k.device)
108
+
109
+ q_n = q/vector_norm(q, ord=2, dim=-1, keepdim=True)
110
+ k_n = k_cache/vector_norm(k_cache, ord=2, dim=-2, keepdim=True)
111
+
112
+ sim = q_n.matmul(k_n)
113
+ if s_cache<topk:
114
+ topk = s_cache #number of tokens in cache < topk
115
+ val, idx = torch.topk(sim, k=topk, dim=-1)
116
+
117
+ reshaped_idx = idx.reshape(b, h, s_q * topk)
118
+
119
+ selected_k = k_cache.gather(dim=-1, index=reshaped_idx.unsqueeze(-2).expand(-1, -1, d, -1))
120
+ selected_v = v_cache.gather(dim=-2, index=reshaped_idx.unsqueeze(-1).expand(-1, -1, -1, d))
121
+
122
+ sim_mask = rearrange(~ (val > sim_threshold).bool(), 'b h s i -> b h (s i)').unsqueeze(-2).expand(-1, -1, s_q, -1)
123
+ min_val = torch.finfo(selected_k.dtype).min
124
+
125
+ elif faiss_indexes is not None: #faiss indexes
126
+
127
+ kn_index, kv_index = faiss_indexes
128
+ q_n = q/vector_norm(q, ord=2, dim=-1, keepdim=True)
129
+
130
+ one_hot_encodings = F.one_hot(torch.arange(0, n_heads*n_layers, device=q.device))*10
131
+ q_n = torch.concat([rearrange(q_n, 'b h s d -> b (h s) d', h=n_heads), one_hot_encodings[n_heads*current_layer:n_heads*(current_layer+1)].unsqueeze(0).repeat_interleave(repeats=q.size(-2), dim=-2)], dim=-1).squeeze()
132
+
133
+ D, I = kn_index.search(q_n.to('cpu').numpy(), k=topk)
134
+
135
+ selected_k=rearrange(torch.tensor(kv_index.reconstruct_batch(I.flatten()))[:,:d], '(h s) d -> 1 h d s', h=32).to(q.device)
136
+ selected_v=rearrange(torch.tensor(kv_index.reconstruct_batch(I.flatten()))[:,d:], '(h s) d -> 1 h s d', h=32).to(q.device)
137
+
138
+ s_k_ae = selected_k.size(-1)
139
+ s_k += s_k_ae
140
+ attn_weight_cache = q.matmul(selected_k) * softmax_scale
141
+ if mask_by_sim:
142
+ attn_weight_cache = attn_weight_cache.masked_fill(sim_mask, min_val)
143
+
144
+ if attn_bias_ae is not None:
145
+ # clamp to 0 necessary for torch 2.0 compile()
146
+ _s_q = max(0, attn_bias_ae.size(2) - s_q)
147
+ _s_k = max(0, attn_bias_ae.size(3) - s_k_ae)
148
+ attn_bias_ae = attn_bias_ae[:, :, _s_q:, _s_k:]
149
+
150
+ if (attn_bias_ae.size(-1) != 1 and
151
+ attn_bias_ae.size(-1) != s_k_ae) or (attn_bias_ae.size(-2) != 1 and
152
+ attn_bias_ae.size(-2) != s_q):
153
+ raise RuntimeError(
154
+ f'attn_bias (shape: {attn_bias_ae.shape}) is expected to broadcast to shape: {attn_weight_cache.shape}.'
155
+ )
156
+ attn_weight_cache = attn_weight_cache + attn_bias_ae
157
+
158
+ attn_weight = torch.cat([attn_weight_cache, attn_weight], dim=-1)
159
+ v = torch.cat([selected_v, v], dim=-2)
160
+
161
+ min_val = torch.finfo(q.dtype).min
162
+
163
+ if key_padding_mask is not None:
164
+ if attn_bias is not None:
165
+ warnings.warn(
166
+ 'Propogating key_padding_mask to the attention module ' +\
167
+ 'and applying it within the attention module can cause ' +\
168
+ 'unneccessary computation/memory usage. Consider integrating ' +\
169
+ 'into attn_bias once and passing that to each attention ' +\
170
+ 'module instead.'
171
+ )
172
+ attn_weight = attn_weight.masked_fill(
173
+ ~key_padding_mask.view((b, 1, 1, s_k)), min_val)
174
+
175
+ def _create_active_externalism_mask(k, s_q, device):
176
+ mask = torch.zeros(s_q, s_q * k, device=device, dtype=torch.bool)
177
+ for i in range(s_q):
178
+ mask[i, i * k : (i + 1) * k] = 1
179
+ return ~mask
180
+
181
+ if is_causal and (not q.size(2) == 1):
182
+ s = max(s_q, s_k)
183
+ causal_mask = attn_weight.new_ones(s, s, dtype=torch.float16)
184
+ causal_mask = causal_mask.tril()
185
+ causal_mask = causal_mask.to(torch.bool)
186
+ causal_mask = ~causal_mask
187
+ causal_mask = causal_mask[-s_q:, -s_k:]
188
+
189
+ if long_range_past_key_value is not None:
190
+ mask = _create_active_externalism_mask(k=topk,s_q=s_q, device=attn_weight.device)
191
+ s=s_q
192
+ if had_kv:
193
+ s += (past_key_value[0][0].size(-1) -s_q)
194
+ causal_mask = torch.cat([mask, causal_mask[:,-s:]], dim=1)
195
+
196
+ attn_weight = attn_weight.masked_fill(causal_mask.view(1, 1, s_q, s_k),
197
+ min_val)
198
+
199
+ attn_weight = torch.softmax(attn_weight, dim=-1)
200
+
201
+ if dropout_p:
202
+ attn_weight = torch.nn.functional.dropout(attn_weight,
203
+ p=dropout_p,
204
+ training=training,
205
+ inplace=True)
206
+
207
+ out = attn_weight.to(v.dtype).matmul(v)
208
+ out = rearrange(out, 'b h s d -> b s (h d)')
209
+
210
+ if needs_weights:
211
+ return out, attn_weight, past_key_value, reshaped_idx
212
+ return out, None, past_key_value, None
213
+
214
+
215
+ def check_valid_inputs(*tensors, valid_dtypes=[torch.float16, torch.bfloat16]):
216
+ for tensor in tensors:
217
+ if tensor.dtype not in valid_dtypes:
218
+ raise TypeError(f'{tensor.dtype=} must be in {valid_dtypes=}.')
219
+ if not tensor.is_cuda:
220
+ raise TypeError(f'Inputs must be cuda tensors ({tensor.is_cuda=}).')
221
+
222
+
223
+ def flash_attn_fn(
224
+ query,
225
+ key,
226
+ value,
227
+ n_heads,
228
+ past_key_value=None,
229
+ softmax_scale=None,
230
+ attn_bias=None,
231
+ key_padding_mask=None,
232
+ is_causal=False,
233
+ dropout_p=0.0,
234
+ training=False,
235
+ needs_weights=False,
236
+ multiquery=False,
237
+ ):
238
+ try:
239
+ from flash_attn import bert_padding, flash_attn_interface # type: ignore # yapf: disable # isort: skip
240
+ except:
241
+ raise RuntimeError('Please install flash-attn==1.0.3.post0')
242
+
243
+ check_valid_inputs(query, key, value)
244
+
245
+ if past_key_value is not None:
246
+ if len(past_key_value) != 0:
247
+ key = torch.cat([past_key_value[0], key], dim=1)
248
+ value = torch.cat([past_key_value[1], value], dim=1)
249
+
250
+ past_key_value = (key, value)
251
+
252
+ if attn_bias is not None:
253
+ # clamp to 0 necessary for torch 2.0 compile()
254
+ _s_q = max(0, attn_bias.size(2) - query.size(1))
255
+ _s_k = max(0, attn_bias.size(3) - key.size(1))
256
+ attn_bias = attn_bias[:, :, _s_q:, _s_k:]
257
+
258
+ if attn_bias is not None:
259
+ raise NotImplementedError(f'attn_bias not implemented for flash attn.')
260
+
261
+ batch_size, seqlen = query.shape[:2]
262
+
263
+ if key_padding_mask is None:
264
+ key_padding_mask = torch.ones_like(key[:, :, 0], dtype=torch.bool)
265
+ query_padding_mask = key_padding_mask[:, -query.size(1):]
266
+
267
+ query_unpad, indices_q, cu_seqlens_q, max_seqlen_q = bert_padding.unpad_input(
268
+ query, query_padding_mask)
269
+ query_unpad = rearrange(query_unpad, 'nnz (h d) -> nnz h d', h=n_heads)
270
+
271
+ key_unpad, _, cu_seqlens_k, max_seqlen_k = bert_padding.unpad_input(
272
+ key, key_padding_mask)
273
+ key_unpad = rearrange(key_unpad,
274
+ 'nnz (h d) -> nnz h d',
275
+ h=1 if multiquery else n_heads)
276
+
277
+ value_unpad, _, _, _ = bert_padding.unpad_input(value, key_padding_mask)
278
+ value_unpad = rearrange(value_unpad,
279
+ 'nnz (h d) -> nnz h d',
280
+ h=1 if multiquery else n_heads)
281
+
282
+ if multiquery:
283
+ # Expanding a tensor does not allocate new memory, but only creates a new
284
+ # view on the existing tensor where a dimension of size one is expanded
285
+ # to a larger size by setting the stride to 0.
286
+ # - pytorch docs
287
+ #
288
+ # hopefully the kernels can utilize this and we're jot just wasting BW here
289
+ key_unpad = key_unpad.expand(key_unpad.size(0), n_heads,
290
+ key_unpad.size(-1))
291
+ value_unpad = value_unpad.expand(value_unpad.size(0), n_heads,
292
+ value_unpad.size(-1))
293
+
294
+ dropout_p = dropout_p if training else 0.0
295
+
296
+ reset_is_causal = _reset_is_causal(query.size(1), key.size(1), is_causal)
297
+
298
+ output_unpad = flash_attn_interface.flash_attn_unpadded_func(
299
+ query_unpad,
300
+ key_unpad,
301
+ value_unpad,
302
+ cu_seqlens_q,
303
+ cu_seqlens_k,
304
+ max_seqlen_q,
305
+ max_seqlen_k,
306
+ dropout_p,
307
+ softmax_scale=softmax_scale,
308
+ causal=reset_is_causal,
309
+ return_attn_probs=needs_weights)
310
+
311
+ output = bert_padding.pad_input(
312
+ rearrange(output_unpad, 'nnz h d -> nnz (h d)'), indices_q, batch_size,
313
+ seqlen)
314
+ return output, None, past_key_value
315
+
316
+
317
+ def triton_flash_attn_fn(
318
+ query,
319
+ key,
320
+ value,
321
+ n_heads,
322
+ past_key_value=None,
323
+ softmax_scale=None,
324
+ attn_bias=None,
325
+ key_padding_mask=None,
326
+ is_causal=False,
327
+ dropout_p=0.0,
328
+ training=False,
329
+ needs_weights=False,
330
+ multiquery=False,
331
+ ):
332
+ try:
333
+ from llmfoundry.models.layers.flash_attn_triton import flash_attn_func
334
+ except:
335
+ _installed = False
336
+ if version.parse(torch.__version__) < version.parse('2.0.0'):
337
+ _installed = True
338
+ # if torch1.13.1 revert to using triton flash attn from HazyResearch
339
+ # with flash-attn==1.0.3.post0 and triton==2.0.0.dev20221202
340
+ try:
341
+ from flash_attn.flash_attn_triton import flash_attn_func
342
+ except:
343
+ _installed = False
344
+ if not _installed:
345
+ # installing triton-pre-mlir works for both torch1.13.1 and torch2.0+
346
+ # default recommendation is to install this variant
347
+ raise RuntimeError(
348
+ 'Requirements for `attn_impl: triton` not installed. Either (1) have a CUDA-compatible GPU '
349
+ 'and `pip install .[gpu]` if installing from llm-foundry source or '
350
+ '`pip install triton-pre-mlir@git+https://github.com/vchiley/triton.git@triton_pre_mlir#subdirectory=python` '
351
+ 'if installing from pypi, or (2) use torch attn model.attn_config.attn_impl=torch (torch attn_impl will be slow). '
352
+ 'Note: (1) requires you have CMake and PyTorch already installed.'
353
+ )
354
+
355
+ check_valid_inputs(query, key, value)
356
+
357
+ if past_key_value is not None:
358
+ if len(past_key_value) != 0:
359
+ key = torch.cat([past_key_value[0], key], dim=1)
360
+ value = torch.cat([past_key_value[1], value], dim=1)
361
+
362
+ past_key_value = (key, value)
363
+
364
+ if attn_bias is not None:
365
+ # clamp to 0 necessary for torch 2.0 compile()
366
+ _s_q = max(0, attn_bias.size(2) - query.size(1))
367
+ _s_k = max(0, attn_bias.size(3) - key.size(1))
368
+ attn_bias = attn_bias[:, :, _s_q:, _s_k:]
369
+
370
+ if dropout_p:
371
+ raise NotImplementedError(
372
+ f'Dropout not implemented for attn_impl: triton.')
373
+
374
+ if needs_weights:
375
+ raise NotImplementedError(
376
+ f'attn_impl: triton cannot return attn weights.')
377
+
378
+ if key_padding_mask is not None:
379
+ warnings.warn(
380
+ 'Propagating key_padding_mask to the attention module ' +\
381
+ 'and applying it within the attention module can cause ' +\
382
+ 'unnecessary computation/memory usage. Consider integrating ' +\
383
+ 'into attn_bias once and passing that to each attention ' +\
384
+ 'module instead.'
385
+ )
386
+ b_size, s_k = key_padding_mask.shape[:2]
387
+
388
+ if attn_bias is None:
389
+ attn_bias = query.new_zeros(b_size, 1, 1, s_k)
390
+
391
+ attn_bias = attn_bias.masked_fill(
392
+ ~key_padding_mask.view((b_size, 1, 1, s_k)),
393
+ torch.finfo(query.dtype).min)
394
+
395
+ query = rearrange(query, 'b s (h d) -> b s h d', h=n_heads)
396
+ key = rearrange(key, 'b s (h d) -> b s h d', h=1 if multiquery else n_heads)
397
+ value = rearrange(value,
398
+ 'b s (h d) -> b s h d',
399
+ h=1 if multiquery else n_heads)
400
+
401
+ if multiquery:
402
+ # Expanding a tensor does not allocate new memory, but only creates a new
403
+ # view on the existing tensor where a dimension of size one is expanded
404
+ # to a larger size by setting the stride to 0.
405
+ # - pytorch docs
406
+ #
407
+ # hopefully the kernels can utilize this and we're jot just wasting BW here
408
+ key = key.expand(*key.shape[:2], n_heads, key.size(-1))
409
+ value = value.expand(*value.shape[:2], n_heads, value.size(-1))
410
+
411
+ reset_is_causal = _reset_is_causal(query.size(1), key.size(1), is_causal)
412
+ attn_output = flash_attn_func(query, key, value, attn_bias, reset_is_causal,
413
+ softmax_scale)
414
+
415
+ output = attn_output.view(*attn_output.shape[:2], -1)
416
+
417
+ return output, None, past_key_value
418
+
419
+
420
+ class MultiheadAttention(nn.Module):
421
+ """Multi-head self attention.
422
+
423
+ Using torch or triton attention implemetation enables user to also use
424
+ additive bias.
425
+ """
426
+
427
+ def __init__(
428
+ self,
429
+ d_model: int,
430
+ n_heads: int,
431
+ attn_impl: str = 'triton',
432
+ clip_qkv: Optional[float] = None,
433
+ qk_ln: bool = False,
434
+ softmax_scale: Optional[float] = None,
435
+ attn_pdrop: float = 0.0,
436
+ low_precision_layernorm: bool = False,
437
+ verbose: int = 0,
438
+ device: Optional[str] = None,
439
+ ):
440
+ super().__init__()
441
+
442
+ self.attn_impl = attn_impl
443
+ self.clip_qkv = clip_qkv
444
+ self.qk_ln = qk_ln
445
+
446
+ self.d_model = d_model
447
+ self.n_heads = n_heads
448
+ self.softmax_scale = softmax_scale
449
+ if self.softmax_scale is None:
450
+ self.softmax_scale = 1 / math.sqrt(self.d_model / self.n_heads)
451
+ self.attn_dropout_p = attn_pdrop
452
+
453
+ self.Wqkv = nn.Linear(self.d_model, 3 * self.d_model, device=device)
454
+ # for param init fn; enables shape based init of fused layers
455
+ fuse_splits = (d_model, 2 * d_model)
456
+ self.Wqkv._fused = (0, fuse_splits) # type: ignore
457
+
458
+ if self.qk_ln:
459
+ layernorm_class = LPLayerNorm if low_precision_layernorm else nn.LayerNorm
460
+ self.q_ln = layernorm_class(self.d_model, device=device)
461
+ self.k_ln = layernorm_class(self.d_model, device=device)
462
+
463
+ if self.attn_impl == 'flash':
464
+ self.attn_fn = flash_attn_fn
465
+ elif self.attn_impl == 'triton':
466
+ self.attn_fn = triton_flash_attn_fn
467
+ if verbose:
468
+ warnings.warn(
469
+ 'While `attn_impl: triton` can be faster than `attn_impl: flash` ' +\
470
+ 'it uses more memory. When training larger models this can trigger ' +\
471
+ 'alloc retries which hurts performance. If encountered, we recommend ' +\
472
+ 'using `attn_impl: flash` if your model does not use `alibi` or `prefix_lm`.'
473
+ )
474
+ elif self.attn_impl == 'torch':
475
+ self.attn_fn = scaled_multihead_dot_product_attention
476
+ if torch.cuda.is_available() and verbose:
477
+ warnings.warn(
478
+ 'Using `attn_impl: torch`. If your model does not use `alibi` or ' +\
479
+ '`prefix_lm` we recommend using `attn_impl: flash` otherwise ' +\
480
+ 'we recommend using `attn_impl: triton`.'
481
+ )
482
+ else:
483
+ raise ValueError(f'{attn_impl=} is an invalid setting.')
484
+
485
+ self.out_proj = nn.Linear(self.d_model, self.d_model, device=device)
486
+ self.out_proj._is_residual = True # type: ignore
487
+
488
+ def forward(
489
+ self,
490
+ x,
491
+ past_key_value=None,
492
+ long_range_past_key_value=None,
493
+ attn_bias=None,
494
+ attn_bias_ae=None,
495
+ attention_mask=None,
496
+ is_causal=True,
497
+ needs_weights=False,
498
+ topk=None,
499
+ faiss_indexes=None,
500
+ n_layers=None,
501
+ current_layer=None,
502
+ mask_by_sim=None,
503
+ sim_threshold=None
504
+ ):
505
+ qkv = self.Wqkv(x)
506
+
507
+ if self.clip_qkv:
508
+ qkv.clamp_(min=-self.clip_qkv, max=self.clip_qkv)
509
+
510
+ query, key, value = qkv.chunk(3, dim=2)
511
+
512
+ key_padding_mask = attention_mask
513
+
514
+ if self.qk_ln:
515
+ # Applying layernorm to qk
516
+ dtype = query.dtype
517
+ query = self.q_ln(query).to(dtype)
518
+ key = self.k_ln(key).to(dtype)
519
+
520
+ context, attn_weights, past_key_value, reshaped_idx = self.attn_fn(
521
+ query,
522
+ key,
523
+ value,
524
+ self.n_heads,
525
+ past_key_value=past_key_value,
526
+ long_range_past_key_value=long_range_past_key_value,
527
+ softmax_scale=self.softmax_scale,
528
+ attn_bias=attn_bias,
529
+ attn_bias_ae=attn_bias_ae,
530
+ key_padding_mask=key_padding_mask,
531
+ is_causal=is_causal,
532
+ dropout_p=self.attn_dropout_p,
533
+ training=self.training,
534
+ needs_weights=needs_weights,
535
+ topk=topk,
536
+ faiss_indexes=faiss_indexes,
537
+ n_layers=n_layers,
538
+ current_layer=current_layer,
539
+ mask_by_sim=mask_by_sim,
540
+ sim_threshold=sim_threshold
541
+ )
542
+
543
+ return self.out_proj(context), attn_weights, past_key_value, reshaped_idx
544
+
545
+
546
+ class MultiQueryAttention(nn.Module):
547
+ """Multi-Query self attention.
548
+
549
+ Using torch or triton attention implemetation enables user to also use
550
+ additive bias.
551
+ """
552
+
553
+ def __init__(
554
+ self,
555
+ d_model: int,
556
+ n_heads: int,
557
+ attn_impl: str = 'triton',
558
+ clip_qkv: Optional[float] = None,
559
+ qk_ln: bool = False,
560
+ softmax_scale: Optional[float] = None,
561
+ attn_pdrop: float = 0.0,
562
+ low_precision_layernorm: bool = False,
563
+ verbose: int = 0,
564
+ device: Optional[str] = None,
565
+ ):
566
+ super().__init__()
567
+
568
+ self.attn_impl = attn_impl
569
+ self.clip_qkv = clip_qkv
570
+ self.qk_ln = qk_ln
571
+
572
+ self.d_model = d_model
573
+ self.n_heads = n_heads
574
+ self.head_dim = d_model // n_heads
575
+ self.softmax_scale = softmax_scale
576
+ if self.softmax_scale is None:
577
+ self.softmax_scale = 1 / math.sqrt(self.head_dim)
578
+ self.attn_dropout_p = attn_pdrop
579
+
580
+ # NOTE: if we ever want to make attn TensorParallel, I'm pretty sure we'll
581
+ # want to split Wqkv into Wq and Wkv where Wq can be TensorParallel but
582
+ # Wkv shouldn't be TensorParallel
583
+ # - vchiley
584
+ self.Wqkv = nn.Linear(
585
+ d_model,
586
+ d_model + 2 * self.head_dim,
587
+ device=device,
588
+ )
589
+ # for param init fn; enables shape based init of fused layers
590
+ fuse_splits = (d_model, d_model + self.head_dim)
591
+ self.Wqkv._fused = (0, fuse_splits) # type: ignore
592
+
593
+ if self.qk_ln:
594
+ layernorm_class = LPLayerNorm if low_precision_layernorm else nn.LayerNorm
595
+ self.q_ln = layernorm_class(d_model, device=device)
596
+ self.k_ln = layernorm_class(self.head_dim, device=device)
597
+
598
+ if self.attn_impl == 'flash':
599
+ self.attn_fn = flash_attn_fn
600
+ elif self.attn_impl == 'triton':
601
+ self.attn_fn = triton_flash_attn_fn
602
+ if verbose:
603
+ warnings.warn(
604
+ 'While `attn_impl: triton` can be faster than `attn_impl: flash` ' +\
605
+ 'it uses more memory. When training larger models this can trigger ' +\
606
+ 'alloc retries which hurts performance. If encountered, we recommend ' +\
607
+ 'using `attn_impl: flash` if your model does not use `alibi` or `prefix_lm`.'
608
+ )
609
+ elif self.attn_impl == 'torch':
610
+ self.attn_fn = scaled_multihead_dot_product_attention
611
+ if torch.cuda.is_available() and verbose:
612
+ warnings.warn(
613
+ 'Using `attn_impl: torch`. If your model does not use `alibi` or ' +\
614
+ '`prefix_lm` we recommend using `attn_impl: flash` otherwise ' +\
615
+ 'we recommend using `attn_impl: triton`.'
616
+ )
617
+ else:
618
+ raise ValueError(f'{attn_impl=} is an invalid setting.')
619
+
620
+ self.out_proj = nn.Linear(self.d_model, self.d_model, device=device)
621
+ self.out_proj._is_residual = True # type: ignore
622
+
623
+ def forward(
624
+ self,
625
+ x,
626
+ past_key_value=None,
627
+ attn_bias=None,
628
+ attention_mask=None,
629
+ is_causal=True,
630
+ needs_weights=False,
631
+ ):
632
+ qkv = self.Wqkv(x)
633
+
634
+ if self.clip_qkv:
635
+ qkv.clamp_(min=-self.clip_qkv, max=self.clip_qkv)
636
+
637
+ query, key, value = qkv.split(
638
+ [self.d_model, self.head_dim, self.head_dim], dim=2)
639
+
640
+ key_padding_mask = attention_mask
641
+
642
+ if self.qk_ln:
643
+ # Applying layernorm to qk
644
+ dtype = query.dtype
645
+ query = self.q_ln(query).to(dtype)
646
+ key = self.k_ln(key).to(dtype)
647
+
648
+ context, attn_weights, past_key_value = self.attn_fn(
649
+ query,
650
+ key,
651
+ value,
652
+ self.n_heads,
653
+ past_key_value=past_key_value,
654
+ softmax_scale=self.softmax_scale,
655
+ attn_bias=attn_bias,
656
+ key_padding_mask=key_padding_mask,
657
+ is_causal=is_causal,
658
+ dropout_p=self.attn_dropout_p,
659
+ training=self.training,
660
+ needs_weights=needs_weights,
661
+ multiquery=True,
662
+ )
663
+
664
+ return self.out_proj(context), attn_weights, past_key_value
665
+
666
+
667
+ def attn_bias_shape(attn_impl, n_heads, seq_len, alibi, prefix_lm, causal,
668
+ use_sequence_id):
669
+ if attn_impl == 'flash':
670
+ return None
671
+ elif attn_impl in ['torch', 'triton']:
672
+ if alibi:
673
+ if (prefix_lm or not causal) or use_sequence_id:
674
+ return (1, n_heads, seq_len, seq_len)
675
+ return (1, n_heads, 1, seq_len)
676
+ elif prefix_lm or use_sequence_id:
677
+ return (1, 1, seq_len, seq_len)
678
+ return None
679
+ else:
680
+ raise ValueError(f'{attn_impl=} is an invalid setting.')
681
+
682
+
683
+ def build_attn_bias(
684
+ attn_impl,
685
+ n_heads,
686
+ seq_len,
687
+ attn_bias=None,
688
+ causal=False,
689
+ alibi=False,
690
+ alibi_bias_max=8,
691
+ for_ae=False,
692
+ topk=0,
693
+ device=None,
694
+ dtype=None
695
+ ):
696
+ if attn_impl == 'flash':
697
+ return None
698
+ elif attn_impl in ['torch', 'triton']:
699
+ if alibi:
700
+ # in place add alibi to attn bias
701
+ if attn_bias is not None:
702
+ attn_bias = attn_bias.add(
703
+ build_alibi_bias(
704
+ n_heads,
705
+ seq_len,
706
+ full=not causal,
707
+ alibi_bias_max=alibi_bias_max,
708
+ device=device,
709
+ dtype=dtype,
710
+ for_ae=for_ae,
711
+ topk=topk
712
+ ))
713
+ else:
714
+ attn_bias = build_alibi_bias(
715
+ n_heads,
716
+ seq_len,
717
+ full=not causal,
718
+ alibi_bias_max=alibi_bias_max,
719
+ for_ae=for_ae,
720
+ topk=topk)
721
+ return attn_bias
722
+
723
+
724
+ def gen_slopes(n_heads, alibi_bias_max=8, device=None):
725
+ _n_heads = 2**math.ceil(math.log2(n_heads))
726
+ m = torch.arange(1, _n_heads + 1, dtype=torch.float32, device=device)
727
+ m = m.mul(alibi_bias_max / _n_heads)
728
+ slopes = (1. / torch.pow(2, m))
729
+
730
+ if _n_heads != n_heads:
731
+ # if n_heads is not a power of two,
732
+ # Huggingface and FasterTransformer calculate slopes normally,
733
+ # then return this strided concatenation of slopes
734
+ slopes = torch.concat([slopes[1::2], slopes[::2]])[:n_heads]
735
+
736
+ return slopes.view(1, n_heads, 1, 1)
737
+
738
+
739
+ def build_alibi_bias(
740
+ n_heads,
741
+ seq_len,
742
+ full=False,
743
+ alibi_bias_max=8,
744
+ device=None,
745
+ dtype=None,
746
+ for_ae=False,
747
+ topk=0
748
+ ):
749
+ if not for_ae:
750
+ alibi_bias = torch.arange(1 - seq_len, 1, dtype=torch.int32,
751
+ device=device).view(1, 1, 1, seq_len)
752
+ else:
753
+ alibi_bias = torch.tensor(-seq_len, dtype=torch.int32,
754
+ device=device).repeat(seq_len*topk).view(1, 1, 1, seq_len*(topk))
755
+ if full:
756
+ # generate 1 x Heads x SeqLen x SeqLen alibi bias mask
757
+ # otherwise the mask is 1 x Heads x 1 x SeqLen (which is broadcast to the appropriate size)
758
+ alibi_bias = alibi_bias - torch.arange(
759
+ 1 - seq_len, 1, dtype=torch.int32, device=device).view(
760
+ 1, 1, seq_len, 1)
761
+ alibi_bias = alibi_bias.abs().mul(-1)
762
+
763
+ slopes = gen_slopes(n_heads, alibi_bias_max, device=device)
764
+ alibi_bias = alibi_bias * slopes
765
+ return alibi_bias.to(dtype=dtype)
766
+
767
+
768
+ ATTN_CLASS_REGISTRY = {
769
+ 'multihead_attention': MultiheadAttention,
770
+ 'multiquery_attention': MultiQueryAttention,
771
+ }
blocks.py ADDED
@@ -0,0 +1,120 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Adapted from https://github.com/mosaicml/llm-foundry
2
+ # Classes changed: MPTBlock
3
+ # SPDX-License-Identifier: Apache-2.0
4
+
5
+ """GPT Blocks used for the GPT Model."""
6
+
7
+ from typing import Dict, Optional, Tuple
8
+ import torch
9
+ import torch.nn as nn
10
+ from .attention import ATTN_CLASS_REGISTRY
11
+ from llmfoundry.models.layers.norm import NORM_CLASS_REGISTRY
12
+
13
+ class MPTMLP(nn.Module):
14
+
15
+ def __init__(self,
16
+ d_model: int,
17
+ expansion_ratio: int,
18
+ device: Optional[str] = None):
19
+ super().__init__()
20
+ self.up_proj = nn.Linear(d_model,
21
+ expansion_ratio * d_model,
22
+ device=device)
23
+ self.act = nn.GELU(approximate='none')
24
+ self.down_proj = nn.Linear(expansion_ratio * d_model,
25
+ d_model,
26
+ device=device)
27
+ self.down_proj._is_residual = True # type: ignore
28
+
29
+ def forward(self, x):
30
+ return self.down_proj(self.act(self.up_proj(x)))
31
+
32
+ class MPTBlock(nn.Module):
33
+ def __init__(
34
+ self,
35
+ d_model: int,
36
+ n_heads: int,
37
+ expansion_ratio: int,
38
+ attn_config: Dict = {
39
+ 'attn_type': 'multihead_attention',
40
+ 'attn_pdrop': 0.0,
41
+ 'attn_impl': 'triton',
42
+ 'qk_ln': False,
43
+ 'clip_qkv': None,
44
+ 'softmax_scale': None,
45
+ 'prefix_lm': False,
46
+ 'attn_uses_sequence_id': False,
47
+ 'alibi': False,
48
+ 'alibi_bias_max': 8,
49
+ },
50
+ resid_pdrop: float = 0.0,
51
+ norm_type: str = 'low_precision_layernorm',
52
+ verbose: int = 0,
53
+ device: Optional[str] = None,
54
+ **kwargs):
55
+ del kwargs # unused, just to capture any extra args from the config
56
+ super().__init__()
57
+
58
+ norm_class = NORM_CLASS_REGISTRY[norm_type.lower()]
59
+ attn_class = ATTN_CLASS_REGISTRY[attn_config['attn_type']]
60
+
61
+ self.norm_1 = norm_class(d_model, device=device)
62
+ self.attn = attn_class(
63
+ attn_impl=attn_config['attn_impl'],
64
+ clip_qkv=attn_config['clip_qkv'],
65
+ qk_ln=attn_config['qk_ln'],
66
+ softmax_scale=attn_config['softmax_scale'],
67
+ attn_pdrop=attn_config['attn_pdrop'],
68
+ d_model=d_model,
69
+ n_heads=n_heads,
70
+ verbose=verbose,
71
+ device=device,
72
+ )
73
+ self.norm_2 = norm_class(d_model, device=device)
74
+ self.ffn = MPTMLP(
75
+ d_model=d_model,
76
+ expansion_ratio=expansion_ratio,
77
+ device=device,
78
+ )
79
+ self.resid_attn_dropout = nn.Dropout(resid_pdrop)
80
+ self.resid_ffn_dropout = nn.Dropout(resid_pdrop)
81
+
82
+ def forward(
83
+ self,
84
+ x: torch.Tensor,
85
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
86
+ long_range_past_key_value:Optional[Tuple[torch.Tensor]] = None,
87
+ attn_bias: Optional[torch.Tensor] = None,
88
+ attn_bias_ae: Optional[torch.Tensor] = None,
89
+ attention_mask: Optional[torch.ByteTensor] = None,
90
+ is_causal: bool = True,
91
+ topk:int=None,
92
+ needs_weights:bool=None,
93
+ faiss_indexes:Tuple=None,
94
+ n_layers:int=None,
95
+ current_layer:int=None,
96
+ mask_by_sim:bool=False,
97
+ sim_threshold:float=None
98
+ ) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor]]]:
99
+ a = self.norm_1(x)
100
+ b, attn_weights, past_key_value, reshaped_idx = self.attn(
101
+ a,
102
+ past_key_value=past_key_value,
103
+ long_range_past_key_value=long_range_past_key_value,
104
+ attn_bias=attn_bias,
105
+ attn_bias_ae=attn_bias_ae,
106
+ attention_mask=attention_mask,
107
+ is_causal=is_causal,
108
+ topk=topk,
109
+ needs_weights=needs_weights,
110
+ faiss_indexes=faiss_indexes,
111
+ n_layers=n_layers,
112
+ current_layer=current_layer,
113
+ mask_by_sim=mask_by_sim,
114
+ sim_threshold=sim_threshold
115
+ )
116
+ x = x + self.resid_attn_dropout(b)
117
+ m = self.norm_2(x)
118
+ n = self.ffn(m)
119
+ x = x + self.resid_ffn_dropout(n)
120
+ return x, attn_weights, past_key_value, reshaped_idx
configuration.py ADDED
@@ -0,0 +1,207 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Adapted from https://github.com/mosaicml/llm-foundry
2
+ # Classes changed: MPTConfig
3
+ # SPDX-License-Identifier: Apache-2.0
4
+
5
+
6
+ """A HuggingFace-style model configuration."""
7
+
8
+ from typing import Dict, List, Optional, Union
9
+ from transformers import PretrainedConfig
10
+
11
+ attn_config_defaults: Dict = {
12
+ 'attn_type': 'multihead_attention',
13
+ 'attn_pdrop': 0.0,
14
+ 'attn_impl': 'torch',
15
+ 'qk_ln': False,
16
+ 'clip_qkv': None,
17
+ 'softmax_scale': None,
18
+ 'prefix_lm': False,
19
+ 'attn_uses_sequence_id': False,
20
+ 'alibi': True,
21
+ 'alibi_bias_max': 8,
22
+ "topk": 10,
23
+ 'mask_by_sim':True,
24
+ 'sim_threshold':0.25,
25
+ 'use_active_externalism':True,
26
+ 'memory_type':'manual'
27
+ }
28
+
29
+ init_config_defaults: Dict = {
30
+ 'name': 'kaiming_normal_',
31
+ 'fan_mode': 'fan_in',
32
+ 'init_nonlinearity': 'relu',
33
+ 'init_div_is_residual': True,
34
+ 'emb_init_std': None,
35
+ 'emb_init_uniform_lim': None,
36
+ 'init_std': None,
37
+ 'init_gain': 0.0,
38
+ }
39
+
40
+
41
+ class ExtendedMPTConfig(PretrainedConfig):
42
+ model_type = 'extended-mpt'
43
+
44
+ def __init__(
45
+ self,
46
+ d_model: int = 4096,
47
+ n_heads: int = 32,
48
+ n_layers: int = 32,
49
+ expansion_ratio: int = 4,
50
+ max_seq_len: int = 2048,
51
+ vocab_size: int = 50432,
52
+ resid_pdrop: float = 0.0,
53
+ emb_pdrop: float = 0.0,
54
+ learned_pos_emb: bool = True,
55
+ attn_config: Dict = attn_config_defaults,
56
+ init_device: str = 'cpu',
57
+ logit_scale: Optional[Union[float, str]] = None,
58
+ no_bias: bool = True,
59
+ verbose: int = 0,
60
+ embedding_fraction: float = 1.0,
61
+ norm_type: str = 'low_precision_layernorm',
62
+ use_cache: bool = False,
63
+ init_config: Dict = init_config_defaults,
64
+ use_active_externalism_by_layer: List[bool] = [True for _ in range(32)],
65
+ memory_device:str = 'cpu',
66
+ **kwargs,
67
+ ):
68
+ """The MPT configuration class.
69
+
70
+ Args:
71
+ d_model (int): The size of the embedding dimension of the model.
72
+ n_heads (int): The number of attention heads.
73
+ n_layers (int): The number of layers in the model.
74
+ expansion_ratio (int): The ratio of the up/down scale in the MLP.
75
+ max_seq_len (int): The maximum sequence length of the model.
76
+ vocab_size (int): The size of the vocabulary.
77
+ resid_pdrop (float): The dropout probability applied to the attention output before combining with residual.
78
+ emb_pdrop (float): The dropout probability for the embedding layer.
79
+ learned_pos_emb (bool): Whether to use learned positional embeddings
80
+ attn_config (Dict): A dictionary used to configure the model's attention module:
81
+ attn_type (str): type of attention to use. Options: multihead_attention, multiquery_attention
82
+ attn_pdrop (float): The dropout probability for the attention layers.
83
+ attn_impl (str): The attention implementation to use. One of 'torch', 'flash', or 'triton'.
84
+ qk_ln (bool): Whether to apply layer normalization to the queries and keys in the attention layer.
85
+ clip_qkv (Optional[float]): If not None, clip the queries, keys, and values in the attention layer to
86
+ this value.
87
+ softmax_scale (Optional[float]): If not None, scale the softmax in the attention layer by this value. If None,
88
+ use the default scale of ``1/sqrt(d_keys)``.
89
+ prefix_lm (Optional[bool]): Whether the model should operate as a Prefix LM. This requires passing an
90
+ extra `prefix_mask` argument which indicates which tokens belong to the prefix. Tokens in the prefix
91
+ can attend to one another bi-directionally. Tokens outside the prefix use causal attention.
92
+ attn_uses_sequence_id (Optional[bool]): Whether to restrict attention to tokens that have the same sequence_id.
93
+ When the model is in `train` mode, this requires passing an extra `sequence_id` argument which indicates
94
+ which sub-sequence each token belongs to.
95
+ Defaults to ``False`` meaning any provided `sequence_id` will be ignored.
96
+ alibi (bool): Whether to use the alibi bias instead of position embeddings.
97
+ alibi_bias_max (int): The maximum value of the alibi bias.
98
+ init_device (str): The device to use for parameter initialization.
99
+ logit_scale (Optional[Union[float, str]]): If not None, scale the logits by this value.
100
+ no_bias (bool): Whether to use bias in all layers.
101
+ verbose (int): The verbosity level. 0 is silent.
102
+ embedding_fraction (float): The fraction to scale the gradients of the embedding layer by.
103
+ norm_type (str): choose type of norm to use
104
+ multiquery_attention (bool): Whether to use multiquery attention implementation.
105
+ use_cache (bool): Whether or not the model should return the last key/values attentions
106
+ init_config (Dict): A dictionary used to configure the model initialization:
107
+ init_config.name: The parameter initialization scheme to use. Options: 'default_', 'baseline_',
108
+ 'kaiming_uniform_', 'kaiming_normal_', 'neox_init_', 'small_init_', 'xavier_uniform_', or
109
+ 'xavier_normal_'. These mimic the parameter initialization methods in PyTorch.
110
+ init_div_is_residual (Union[int, float, str, bool]): Value to divide initial weights by if ``module._is_residual`` is True.
111
+ emb_init_std (Optional[float]): The standard deviation of the normal distribution used to initialize the embedding layer.
112
+ emb_init_uniform_lim (Optional[Union[Tuple[float, float], float]]): The lower and upper limits of the uniform distribution
113
+ used to initialize the embedding layer. Mutually exclusive with ``emb_init_std``.
114
+ init_std (float): The standard deviation of the normal distribution used to initialize the model,
115
+ if using the baseline_ parameter initialization scheme.
116
+ init_gain (float): The gain to use for parameter initialization with kaiming or xavier initialization schemes.
117
+ fan_mode (str): The fan mode to use for parameter initialization with kaiming initialization schemes.
118
+ init_nonlinearity (str): The nonlinearity to use for parameter initialization with kaiming initialization schemes.
119
+ ---
120
+ See llmfoundry.models.utils.param_init_fns.py for info on other param init config options
121
+ """
122
+ self.d_model = d_model
123
+ self.n_heads = n_heads
124
+ self.n_layers = n_layers
125
+ self.expansion_ratio = expansion_ratio
126
+ self.max_seq_len = max_seq_len
127
+ self.vocab_size = vocab_size
128
+ self.resid_pdrop = resid_pdrop
129
+ self.emb_pdrop = emb_pdrop
130
+ self.learned_pos_emb = learned_pos_emb
131
+ self.attn_config = attn_config
132
+ self.init_device = init_device
133
+ self.logit_scale = logit_scale
134
+ self.no_bias = no_bias
135
+ self.verbose = verbose
136
+ self.embedding_fraction = embedding_fraction
137
+ self.norm_type = norm_type
138
+ self.use_cache = use_cache
139
+ self.init_config = init_config
140
+ self.use_active_externalism_by_layer = use_active_externalism_by_layer
141
+ self.memory_device = memory_device
142
+ if 'name' in kwargs:
143
+ del kwargs['name']
144
+ if 'loss_fn' in kwargs:
145
+ del kwargs['loss_fn']
146
+ super().__init__(**kwargs)
147
+
148
+ self._validate_config()
149
+
150
+ def _set_config_defaults(self, config, config_defaults):
151
+ # set config defaults
152
+ for k, v in config_defaults.items():
153
+ if k not in config:
154
+ config[k] = v
155
+ return config
156
+
157
+ def _validate_config(self):
158
+ # set config defaults
159
+ self.attn_config = self._set_config_defaults(
160
+ self.attn_config,
161
+ attn_config_defaults,
162
+ )
163
+ self.init_config = self._set_config_defaults(
164
+ self.init_config,
165
+ init_config_defaults,
166
+ )
167
+
168
+ if self.d_model % self.n_heads != 0:
169
+ raise ValueError('d_model must be divisible by n_heads')
170
+ if any(
171
+ prob < 0 or prob > 1 for prob in
172
+ [self.attn_config['attn_pdrop'], self.resid_pdrop, self.emb_pdrop]):
173
+ raise ValueError(
174
+ "self.attn_config['attn_pdrop'], resid_pdrop, emb_pdrop are probabilities and must be between 0 and 1"
175
+ )
176
+ if self.attn_config['attn_impl'] not in ['torch', 'flash', 'triton']:
177
+ raise ValueError(
178
+ f"Unknown attn_impl={self.attn_config['attn_impl']}")
179
+ if self.attn_config['prefix_lm'] and self.attn_config[
180
+ 'attn_impl'] not in ['torch', 'triton']:
181
+ raise NotImplementedError(
182
+ 'prefix_lm only implemented with torch and triton attention.')
183
+ if self.attn_config['alibi'] and self.attn_config['attn_impl'] not in [
184
+ 'torch', 'triton'
185
+ ]:
186
+ raise NotImplementedError(
187
+ 'alibi only implemented with torch and triton attention.')
188
+ if self.attn_config['attn_uses_sequence_id'] and self.attn_config[
189
+ 'attn_impl'] not in ['torch', 'triton']:
190
+ raise NotImplementedError(
191
+ 'attn_uses_sequence_id only implemented with torch and triton attention.'
192
+ )
193
+ if self.embedding_fraction > 1 or self.embedding_fraction <= 0:
194
+ raise ValueError(
195
+ 'model.embedding_fraction must be between 0 (exclusive) and 1 (inclusive)!'
196
+ )
197
+ if isinstance(self.logit_scale,
198
+ str) and self.logit_scale != 'inv_sqrt_d_model':
199
+ raise ValueError(
200
+ f"{self.logit_scale=} is not recognized as an option; use numeric value or 'inv_sqrt_d_model'."
201
+ )
202
+ if self.init_config.get('name', None) is None:
203
+ raise ValueError(f"{self.init_config=} 'name' needs to be set.")
204
+ if not self.learned_pos_emb and not self.attn_config['alibi']:
205
+ raise ValueError(
206
+ f'Positional information must be provided to the model using either learned_pos_emb or alibi.'
207
+ )
modeling_mpt.py ADDED
@@ -0,0 +1,837 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Adapted from https://github.com/mosaicml/llm-foundry
2
+ # Classes changed: MPTModel, MPTForCausalLM
3
+ # SPDX-License-Identifier: Apache-2.0
4
+
5
+ """A simple, flexible implementation of a GPT model.
6
+
7
+ Inspired by https://github.com/karpathy/minGPT/blob/master/mingpt/model.py
8
+ """
9
+
10
+ import math
11
+ import warnings
12
+ from typing import List, Optional, Tuple, Union
13
+ import torch
14
+ import torch.nn as nn
15
+ import torch.nn.functional as F
16
+ from torch.linalg import vector_norm
17
+ import faiss
18
+ from einops import rearrange
19
+ from composer.utils import dist
20
+ from omegaconf import DictConfig
21
+
22
+ from transformers import (PreTrainedModel, PreTrainedTokenizer,
23
+ PreTrainedTokenizerFast)
24
+ from transformers.modeling_outputs import (BaseModelOutputWithPast,
25
+ CausalLMOutputWithPast)
26
+ from llmfoundry.models.layers.custom_embedding import SharedEmbedding
27
+ from llmfoundry.models.layers.norm import NORM_CLASS_REGISTRY
28
+ from llmfoundry.models.utils.param_init_fns import MODEL_INIT_REGISTRY
29
+
30
+ from .configuration import ExtendedMPTConfig
31
+ from .attention import attn_bias_shape, build_attn_bias
32
+ from .blocks import MPTBlock
33
+ from .utils import instantiate_from_config
34
+
35
+ Tokenizer = Union[PreTrainedTokenizer, PreTrainedTokenizerFast]
36
+
37
+ class MPTPreTrainedModel(PreTrainedModel):
38
+ config_class = ExtendedMPTConfig
39
+ base_model_prefix = 'model'
40
+ _no_split_modules = ['MPTBlock']
41
+
42
+ class ExtendedMPTModel(MPTPreTrainedModel):
43
+
44
+ def __init__(self, config: ExtendedMPTConfig):
45
+ config._validate_config()
46
+ super().__init__(config)
47
+
48
+ self.attn_impl = config.attn_config['attn_impl']
49
+ self.prefix_lm = config.attn_config['prefix_lm']
50
+ self.attn_uses_sequence_id = config.attn_config['attn_uses_sequence_id']
51
+ self.alibi = config.attn_config['alibi']
52
+ self.alibi_bias_max = config.attn_config['alibi_bias_max']
53
+
54
+ self.mask_by_sim = config.attn_config['mask_by_sim']
55
+ self.sim_threshold = config.attn_config['sim_threshold']
56
+ self.topk = config.attn_config['topk']
57
+ self.use_active_externalism = config.attn_config['use_active_externalism']
58
+
59
+ self.use_active_externalism_by_layer = config.use_active_externalism_by_layer
60
+
61
+ if config.init_device == 'mixed':
62
+ if dist.get_local_rank() == 0:
63
+ config.init_device = 'cpu'
64
+ else:
65
+ config.init_device = 'meta'
66
+
67
+ if config.norm_type.lower() not in NORM_CLASS_REGISTRY.keys():
68
+ norm_options = ' | '.join(NORM_CLASS_REGISTRY.keys())
69
+ raise NotImplementedError(
70
+ f'Requested norm type ({config.norm_type}) is not implemented within this repo (Options: {norm_options}).'
71
+ )
72
+ norm_class = NORM_CLASS_REGISTRY[config.norm_type.lower()]
73
+
74
+ # CogView (https://arxiv.org/abs/2105.13290) and GLM-130B (https://arxiv.org/abs/2210.02414)
75
+ # both report this helping with stabilizing training
76
+ self.embedding_fraction = config.embedding_fraction
77
+
78
+ self.wte = SharedEmbedding(config.vocab_size,
79
+ config.d_model,
80
+ device=config.init_device)
81
+ if not self.alibi:
82
+ self.wpe = torch.nn.Embedding(config.max_seq_len,
83
+ config.d_model,
84
+ device=config.init_device)
85
+ self.emb_drop = nn.Dropout(config.emb_pdrop)
86
+ self.blocks = nn.ModuleList([
87
+ MPTBlock(
88
+ device=config.init_device,
89
+ **config.to_dict(),
90
+ ) for _ in range(config.n_layers)
91
+ ])
92
+ self.norm_f = norm_class(config.d_model, device=config.init_device)
93
+
94
+ if config.init_device != 'meta':
95
+ print(
96
+ f'You are using {config.init_device=}, but you can also use config.init_device="meta" with Composer + FSDP for fast initialization.'
97
+ )
98
+ self.apply(self.param_init_fn)
99
+
100
+ self.is_causal = not self.prefix_lm
101
+
102
+ # define attn mask
103
+ self._attn_bias_initialized = False
104
+ self.attn_bias = None
105
+ self.attn_bias_shape = attn_bias_shape(
106
+ self.attn_impl,
107
+ config.n_heads,
108
+ config.max_seq_len,
109
+ self.alibi,
110
+ prefix_lm=self.prefix_lm,
111
+ causal=self.is_causal,
112
+ use_sequence_id=self.attn_uses_sequence_id,
113
+ )
114
+ self._attn_bias_ae_initialized = False
115
+ self.attn_bias_ae = None
116
+
117
+ if self.config.no_bias:
118
+ for module in self.modules():
119
+ if hasattr(module, 'bias') and isinstance(
120
+ module.bias, nn.Parameter):
121
+ if self.config.verbose:
122
+ warnings.warn(
123
+ f'Removing bias ({module.bias}) from {module}.')
124
+ module.register_parameter('bias', None)
125
+
126
+ # Print verbose info
127
+ if config.verbose and config.verbose > 2:
128
+ print(self)
129
+ if 'verbose' not in self.config.init_config:
130
+ self.config.init_config['verbose'] = self.config.verbose
131
+ if self.config.init_config['verbose'] > 1:
132
+ init_fn_name = self.config.init_config['name']
133
+ warnings.warn(f'Using {init_fn_name} initialization.')
134
+
135
+ def get_input_embeddings(self):
136
+ return self.wte
137
+
138
+ def set_input_embeddings(self, value: nn.Embedding):
139
+ self.wte = value
140
+
141
+ @torch.no_grad()
142
+ def _attn_bias(
143
+ self,
144
+ device,
145
+ dtype,
146
+ attention_mask: Optional[torch.ByteTensor] = None,
147
+ prefix_mask: Optional[torch.ByteTensor] = None,
148
+ sequence_id: Optional[torch.LongTensor] = None,
149
+ seq_len: Optional[int] = None,
150
+ use_active_externalism:bool=None,
151
+ topk=None,
152
+ ):
153
+ if not self._attn_bias_initialized:
154
+ if self.attn_bias_shape:
155
+ self.attn_bias = torch.zeros(self.attn_bias_shape,
156
+ device=device,
157
+ dtype=dtype)
158
+ self.attn_bias = build_attn_bias(
159
+ self.attn_impl,
160
+ self.config.n_heads,
161
+ self.config.max_seq_len,
162
+ device=device,
163
+ dtype=dtype,
164
+ attn_bias = self.attn_bias,
165
+ causal=self.is_causal,
166
+ alibi=self.alibi,
167
+ alibi_bias_max=self.alibi_bias_max
168
+ )
169
+ self._attn_bias_initialized = True
170
+
171
+ if use_active_externalism:
172
+ self.attn_bias_ae = build_attn_bias(
173
+ self.attn_impl,
174
+ self.config.n_heads,
175
+ seq_len,
176
+ device=device,
177
+ dtype=dtype,
178
+ causal=self.is_causal,
179
+ alibi=self.alibi,
180
+ alibi_bias_max=self.alibi_bias_max,
181
+ for_ae=use_active_externalism,
182
+ topk=topk
183
+ )
184
+
185
+ self._attn_bias_ae_initialized = True
186
+
187
+ # flash does not support prefix_lm and will incorporate any
188
+ # attention_mask inside the attention module
189
+ if self.attn_impl == 'flash':
190
+ return self.attn_bias, attention_mask
191
+
192
+ if self.attn_bias is not None:
193
+ # .to(*args, **kwargs) is a no-op if tensor is already on
194
+ # specified device or of specificed dtype
195
+ self.attn_bias = self.attn_bias.to(dtype=dtype, device=device)
196
+
197
+ attn_bias = self.attn_bias
198
+
199
+ if self.attn_bias_ae is not None:
200
+ self.attn_bias_ae = self.attn_bias_ae.to(dtype=dtype, device=device)
201
+ attn_bias_ae = self.attn_bias_ae
202
+
203
+ # If using torch or triton, we incorporate the prefix_mask (if appropriate)
204
+ if self.prefix_lm:
205
+ assert isinstance(attn_bias, torch.Tensor) # pyright
206
+ assert isinstance(prefix_mask, torch.Tensor) # pyright
207
+ attn_bias = self._apply_prefix_mask(attn_bias, prefix_mask)
208
+
209
+ # If using torch or triton, we incorporate sequence_id (if appropriate)
210
+ if self.attn_uses_sequence_id and sequence_id is not None:
211
+ assert isinstance(attn_bias, torch.Tensor) # pyright
212
+ attn_bias = self._apply_sequence_id(attn_bias, sequence_id)
213
+
214
+ # If using torch or triton, we incorporate attention_mask. This will output
215
+ # None in place of attention_mask since it will not be further needed in the
216
+ # attention modules.
217
+ if attention_mask is not None:
218
+ s_k = attention_mask.shape[-1]
219
+ if attn_bias is None:
220
+ attn_bias = torch.zeros((1, 1, 1, s_k),
221
+ device=device,
222
+ dtype=dtype)
223
+ else:
224
+ # clamp to 0 necessary for torch 2.0 compile()
225
+ _s_k = max(0, attn_bias.size(-1) - s_k)
226
+ attn_bias = attn_bias[:, :, :, _s_k:]
227
+ if prefix_mask is not None and (attention_mask.shape !=
228
+ prefix_mask.shape):
229
+ raise ValueError(
230
+ f'attention_mask shape={attention_mask.shape} ' +
231
+ f'and prefix_mask shape={prefix_mask.shape} are not equal.')
232
+ min_val = torch.finfo(attn_bias.dtype).min
233
+ attn_bias = attn_bias.masked_fill(
234
+ ~attention_mask.view(-1, 1, 1, s_k), min_val)
235
+
236
+ return attn_bias, attn_bias_ae, None
237
+
238
+ def _apply_prefix_mask(self, attn_bias: torch.Tensor,
239
+ prefix_mask: torch.Tensor):
240
+ s_k, s_q = attn_bias.shape[-2:]
241
+ if (s_k != self.config.max_seq_len) or (s_q != self.config.max_seq_len):
242
+ raise ValueError(
243
+ 'attn_bias does not match the expected shape. ' +
244
+ f'The last two dimensions should both be {self.config.max_length} '
245
+ + f'but are {s_k} and {s_q}.')
246
+ seq_len = prefix_mask.shape[-1]
247
+ if seq_len > self.config.max_seq_len:
248
+ raise ValueError(
249
+ f'prefix_mask sequence length cannot exceed max_seq_len={self.config.max_seq_len}'
250
+ )
251
+
252
+ # select seq_len subset of attn mask
253
+ attn_bias = attn_bias[..., :seq_len, :seq_len]
254
+
255
+ # Mix the causal max and the bidirectional mask to get the full
256
+ # allowable attention (i.e. full = not accounting for padding yet)
257
+ causal = torch.tril(
258
+ torch.ones((seq_len, seq_len),
259
+ dtype=torch.bool,
260
+ device=prefix_mask.device)).view(1, 1, seq_len, seq_len)
261
+ prefix = prefix_mask.view(-1, 1, 1, seq_len)
262
+ cannot_attend = ~torch.logical_or(causal, prefix.bool())
263
+
264
+ min_val = torch.finfo(attn_bias.dtype).min
265
+ attn_bias = attn_bias.masked_fill(cannot_attend, min_val)
266
+
267
+ return attn_bias
268
+
269
+ def _apply_sequence_id(self, attn_bias: torch.Tensor,
270
+ sequence_id: torch.LongTensor):
271
+ seq_len = sequence_id.shape[-1]
272
+ if seq_len > self.config.max_seq_len:
273
+ raise ValueError(
274
+ f'sequence_id sequence length cannot exceed max_seq_len={self.config.max_seq_len}'
275
+ )
276
+
277
+ # select seq_len subset of attn mask
278
+ attn_bias = attn_bias[..., :seq_len, :seq_len]
279
+
280
+ # Restrict attention to tokens that share the same value
281
+ # in sequence_id
282
+ cannot_attend = torch.logical_not(
283
+ torch.eq(
284
+ sequence_id.view(-1, seq_len, 1),
285
+ sequence_id.view(-1, 1, seq_len),
286
+ )).unsqueeze(1)
287
+ min_val = torch.finfo(attn_bias.dtype).min
288
+ attn_bias = attn_bias.masked_fill(cannot_attend, min_val)
289
+
290
+ return attn_bias
291
+
292
+ def forward(
293
+ self,
294
+ input_ids: torch.LongTensor,
295
+ past_key_values: Optional[List[Tuple[torch.FloatTensor]]] = None,
296
+ attention_mask: Optional[torch.ByteTensor] = None,
297
+ prefix_mask: Optional[torch.ByteTensor] = None,
298
+ sequence_id: Optional[torch.LongTensor] = None,
299
+ return_dict: Optional[bool] = None,
300
+ output_attentions: Optional[bool] = None,
301
+ output_hidden_states: Optional[bool] = None,
302
+ use_cache: Optional[bool] = None,
303
+ inputs_embeds: Optional[torch.Tensor] = None,
304
+ use_active_externalism:Optional[bool]=None,
305
+ long_range_past_key_values:Optional[List[Tuple[torch.FloatTensor]]] = None,
306
+ faiss_indexes:Tuple=None,
307
+ topk:int=None,
308
+ ):
309
+ return_dict = (return_dict
310
+ if return_dict is not None else self.config.return_dict)
311
+ use_cache = (use_cache
312
+ if use_cache is not None else self.config.use_cache)
313
+ use_active_externalism = (use_active_externalism
314
+ if use_active_externalism is not None else self.use_active_externalism)
315
+ topk = (topk if topk is not None else self.topk)
316
+
317
+ if attention_mask is not None:
318
+ attention_mask = attention_mask.bool()
319
+
320
+ if prefix_mask is not None:
321
+ prefix_mask = prefix_mask.bool()
322
+
323
+ # These args are passed in by keyword in huggingface's generate function
324
+ # https://github.com/huggingface/transformers/blob/68287689f2f0d8b7063c400230b3766987abf18d/src/transformers/generation/utils.py#L2201-L2206
325
+ # but have not yet been fully implemented in MPTModel
326
+ if not return_dict:
327
+ raise NotImplementedError(
328
+ 'return_dict False is not implemented yet for MPT')
329
+ if output_attentions:
330
+ if self.attn_impl != 'torch':
331
+ raise NotImplementedError(
332
+ 'output_attentions is not implemented for MPT when using attn_impl `flash` or `triton`.'
333
+ )
334
+
335
+ if (attention_mask is not None and
336
+ attention_mask[:, 0].sum() != attention_mask.shape[0] and
337
+ self.training):
338
+ raise NotImplementedError(
339
+ 'MPT does not support training with left padding.')
340
+
341
+ if self.prefix_lm and prefix_mask is None:
342
+ raise ValueError(
343
+ 'prefix_mask is a required argument when MPT is configured with prefix_lm=True.'
344
+ )
345
+
346
+ # Raise a not implemented error if input_embeds is not None (this is an arg in huggingface transformers and we need to support it for PEFT)
347
+ if inputs_embeds is not None:
348
+ raise NotImplementedError(
349
+ 'inputs_embeds is not implemented for MPT.')
350
+
351
+ if self.training:
352
+ if self.attn_uses_sequence_id and sequence_id is None:
353
+ raise ValueError(
354
+ 'sequence_id is a required argument when MPT is configured with attn_uses_sequence_id=True '
355
+ + 'and the model is in train mode.')
356
+ elif (self.attn_uses_sequence_id is False) and (sequence_id
357
+ is not None):
358
+ warnings.warn(
359
+ 'MPT received non-None input for `sequence_id` but is configured with attn_uses_sequence_id=False. '
360
+ +
361
+ 'This input will be ignored. If you want the model to use `sequence_id`, set attn_uses_sequence_id to True.'
362
+ )
363
+
364
+ S = input_ids.size(1)
365
+
366
+ assert (
367
+ S <= self.config.max_seq_len
368
+ ), f'Cannot forward input with seq_len={S}, this model only supports seq_len<={self.config.max_seq_len}'
369
+
370
+ tok_emb = self.wte(input_ids) # type: ignore
371
+ if self.alibi:
372
+ x = tok_emb
373
+ else:
374
+ past_position = 0
375
+ if past_key_values is not None:
376
+ if len(past_key_values) != self.config.n_layers:
377
+ raise ValueError(
378
+ f'past_key_values must provide a past_key_value for each attention '
379
+ +
380
+ f'layer in the network ({len(past_key_values)=}; {self.config.n_layers=}).'
381
+ )
382
+ # For attn_impl: triton and flash the past key tensor spec is (batch, seq, dim).
383
+ # For attn_impl: torch the past key tensor spec is (batch, heads, head_dim, seq).
384
+ # Here we shift position embedding using the `seq` dim of the past key
385
+ past_position = past_key_values[0][0].size(1)
386
+ if self.attn_impl == 'torch':
387
+ past_position = past_key_values[0][0].size(3)
388
+
389
+ if S + past_position > self.config.max_seq_len:
390
+ raise ValueError(
391
+ f'Cannot forward input with past sequence length {past_position} and current sequence length '
392
+ f'{S + 1}, this model only supports total sequence length <= {self.config.max_seq_len}.'
393
+ )
394
+ pos = torch.arange(
395
+ past_position,
396
+ S + past_position,
397
+ dtype=torch.long,
398
+ device=input_ids.device,
399
+ ).unsqueeze(0)
400
+ if attention_mask is not None:
401
+ # adjust the position indices to account for padding tokens
402
+ pos = torch.clamp(
403
+ pos - torch.cumsum((~attention_mask).to(torch.int32),
404
+ dim=1)[:, past_position:],
405
+ min=0,
406
+ )
407
+
408
+ pos_emb = self.wpe(pos) # type: ignore
409
+ x = tok_emb + pos_emb
410
+
411
+ if self.embedding_fraction == 1:
412
+ x = self.emb_drop(x) # type: ignore
413
+ else:
414
+ # this implementation is proposed on page 7 of the GLM-130B paper https://arxiv.org/abs/2210.02414
415
+ x_shrunk = (x * self.embedding_fraction) + (
416
+ x.detach() * (1 - self.embedding_fraction))
417
+ assert isinstance(self.emb_drop, nn.Module) # pyright
418
+ x = self.emb_drop(x_shrunk)
419
+
420
+ # self._attn_bias_initialized = False #right now this needs to run each step
421
+
422
+ seq_len = S
423
+ if past_key_values is not None:
424
+ past_position = past_key_values[0][0].size(-1)
425
+ seq_len += past_position
426
+
427
+ attn_bias, attn_bias_ae, attention_mask = self._attn_bias(
428
+ device=x.device,
429
+ dtype=torch.float32,
430
+ attention_mask=attention_mask,
431
+ prefix_mask=prefix_mask,
432
+ sequence_id=sequence_id,
433
+ seq_len = seq_len,
434
+ use_active_externalism=use_active_externalism,
435
+ topk=topk
436
+ )
437
+
438
+ # initialize the past key values cache if it should be used
439
+ if use_cache and past_key_values is None:
440
+ past_key_values = [() for _ in range(self.config.n_layers)
441
+ ] # type: ignore
442
+
443
+ all_hidden_states = () if output_hidden_states else None
444
+ all_self_attns = () if output_attentions else None
445
+ all_idx = () if output_attentions else None
446
+ for b_idx, block in enumerate(self.blocks): # type: ignore
447
+ if output_hidden_states:
448
+ assert all_hidden_states is not None # pyright
449
+ all_hidden_states = all_hidden_states + (x,)
450
+ past_key_value = (past_key_values[b_idx]
451
+ if past_key_values is not None else None)
452
+ long_range_past_key_value = (long_range_past_key_values[b_idx]
453
+ if (long_range_past_key_values is not None and self.use_active_externalism_by_layer[b_idx] and use_active_externalism is True) else None)
454
+
455
+ if long_range_past_key_value is not None and faiss_indexes is not None:
456
+ raise NotImplementedError(
457
+ 'Using faiss and passing key value pairs manually are mutually exclusive right now.')
458
+
459
+ x, attn_weights, past_key_value, reshaped_idx = block(
460
+ x,
461
+ past_key_value=past_key_value,
462
+ long_range_past_key_value=long_range_past_key_value,
463
+ attn_bias=attn_bias,
464
+ attention_mask=attention_mask,
465
+ attn_bias_ae=attn_bias_ae,
466
+ is_causal=self.is_causal,
467
+ topk=topk,
468
+ needs_weights=output_attentions,
469
+ faiss_indexes=faiss_indexes,
470
+ n_layers=self.config.n_layers,
471
+ current_layer=b_idx,
472
+ mask_by_sim=self.mask_by_sim,
473
+ sim_threshold=self.sim_threshold,
474
+ )
475
+ if past_key_values is not None:
476
+ past_key_values[b_idx] = past_key_value
477
+
478
+ if output_attentions:
479
+ assert all_self_attns is not None # pyright
480
+ all_self_attns = all_self_attns + (attn_weights,)
481
+
482
+ assert all_idx is not None
483
+ all_idx = all_idx + (reshaped_idx,)
484
+
485
+ x = self.norm_f(x) # type: ignore
486
+
487
+ # add hidden states from the last decoder layer
488
+ if output_hidden_states:
489
+ assert all_hidden_states is not None # pyright
490
+ all_hidden_states = all_hidden_states + (x,)
491
+
492
+ return BaseModelOutputWithPast(
493
+ last_hidden_state=x,
494
+ past_key_values=past_key_values,
495
+ hidden_states=all_hidden_states,
496
+ attentions=(all_self_attns, all_idx),
497
+ )
498
+
499
+ # Param Initialization, needed for device='meta' fast initialization
500
+ def param_init_fn(self, module):
501
+ init_fn_name = self.config.init_config['name']
502
+ MODEL_INIT_REGISTRY[init_fn_name](
503
+ module=module,
504
+ n_layers=self.config.n_layers,
505
+ d_model=self.config.d_model,
506
+ **self.config.init_config,
507
+ )
508
+
509
+ # FSDP Wrap function
510
+ def fsdp_wrap_fn(self, module):
511
+ return isinstance(module, MPTBlock)
512
+
513
+ # Activation Checkpointing
514
+ def activation_checkpointing_fn(self, module):
515
+ return isinstance(module, MPTBlock)
516
+
517
+ class ExtendedMPTForCausalLM(MPTPreTrainedModel):
518
+
519
+ def __init__(self, config:ExtendedMPTConfig, external_memories=None):
520
+ if isinstance(config, DictConfig):
521
+ config = instantiate_from_config(config)
522
+
523
+ super().__init__(config)
524
+ if not config.tie_word_embeddings:
525
+ raise ValueError(
526
+ 'MPTForCausalLM only supports tied word embeddings')
527
+
528
+ print(f'Instantiating an MPTForCausalLM model from {__file__}')
529
+
530
+ self.transformer: ExtendedMPTModel = ExtendedMPTModel(config)
531
+
532
+ self.use_active_externalism = config.attn_config['use_active_externalism']
533
+ self.memory_type = config.attn_config['memory_type']
534
+ self._memories = None
535
+ self.memory_device = config.memory_device
536
+
537
+ for child in self.transformer.children():
538
+ if isinstance(child, torch.nn.ModuleList):
539
+ continue
540
+ if isinstance(child, torch.nn.Module):
541
+ child._fsdp_wrap = True
542
+
543
+ # enables scaling output logits; similar to a softmax "temperature"
544
+ # PaLM paper uses scale 1/sqrt(config.d_model)
545
+ self.logit_scale = None
546
+ if config.logit_scale is not None:
547
+ logit_scale = config.logit_scale
548
+ if isinstance(logit_scale, str):
549
+ if logit_scale == 'inv_sqrt_d_model':
550
+ logit_scale = 1 / math.sqrt(config.d_model)
551
+ else:
552
+ raise ValueError(
553
+ f"{logit_scale=} is not recognized as an option; use numeric value or 'inv_sqrt_d_model'."
554
+ )
555
+ self.logit_scale = logit_scale
556
+
557
+ if external_memories is not None:
558
+ self._memories = external_memories
559
+ self.memories = None
560
+
561
+ def set_memories(self, memories):
562
+ self.memories = memories
563
+
564
+ def empty_memories(self):
565
+ self.memories = None
566
+
567
+ def get_input_embeddings(self):
568
+ return self.transformer.wte
569
+
570
+ def set_input_embeddings(self, value):
571
+ self.transformer.wte = value
572
+
573
+ def get_output_embeddings(self):
574
+ return self.transformer.wte
575
+
576
+ def set_output_embeddings(self, new_embeddings):
577
+ self.transformer.wte = new_embeddings
578
+
579
+ def set_decoder(self, decoder):
580
+ self.transformer = decoder
581
+
582
+ def get_decoder(self):
583
+ return self.transformer
584
+
585
+ def forward(
586
+ self,
587
+ input_ids: torch.LongTensor,
588
+ past_key_values: Optional[List[Tuple[torch.FloatTensor]]] = None,
589
+ attention_mask: Optional[torch.ByteTensor] = None,
590
+ prefix_mask: Optional[torch.ByteTensor] = None,
591
+ sequence_id: Optional[torch.LongTensor] = None,
592
+ labels: Optional[torch.LongTensor] = None,
593
+ return_dict: Optional[bool] = None,
594
+ output_attentions: Optional[bool] = None,
595
+ output_hidden_states: Optional[bool] = None,
596
+ use_cache: Optional[bool] = None,
597
+ inputs_embeds: Optional[torch.FloatTensor] = None,
598
+ use_active_externalism: Optional[bool]=None,
599
+ topk:int=None
600
+ ):
601
+ if self._memories is not None and self.memories is None:
602
+ self.memories = self.generate_cache(self._memories, cache_type=self.memory_type)
603
+
604
+ return_dict = (return_dict
605
+ if return_dict is not None else self.config.return_dict)
606
+ use_cache = (use_cache
607
+ if use_cache is not None else self.config.use_cache)
608
+ use_active_externalism = (use_active_externalism
609
+ if use_active_externalism is not None else self.use_active_externalism)
610
+
611
+ topk = topk if topk is not None else None
612
+
613
+ # if input_embeds is not none, raise a not implemented error
614
+ if inputs_embeds is not None:
615
+ raise NotImplementedError(
616
+ 'inputs_embeds has to be None (for hf/peft support).')
617
+ # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
618
+
619
+ if hasattr(self, "memories") and type(self.memories)==list:
620
+ long_range_past_key_values = self.memories
621
+ faiss_indexes = None
622
+ elif hasattr(self, "memories"):
623
+ long_range_past_key_values = None
624
+ faiss_indexes = self.memories
625
+ else:
626
+ long_range_past_key_values = None
627
+ faiss_indexes = None
628
+
629
+ outputs = self.transformer(
630
+ input_ids=input_ids,
631
+ past_key_values=past_key_values,
632
+ long_range_past_key_values=long_range_past_key_values,
633
+ faiss_indexes=faiss_indexes,
634
+ attention_mask=attention_mask,
635
+ prefix_mask=prefix_mask,
636
+ sequence_id=sequence_id,
637
+ return_dict=return_dict,
638
+ output_attentions=output_attentions,
639
+ output_hidden_states=output_hidden_states,
640
+ use_cache=use_cache,
641
+ use_active_externalism=use_active_externalism,
642
+ topk=topk
643
+ )
644
+
645
+ # move outputs to same device as weights for token embedding
646
+ # needed to support HF `device_map`
647
+ logits = self.transformer.wte(
648
+ outputs.last_hidden_state.to(self.transformer.wte.weight.device),
649
+ True,
650
+ )
651
+
652
+ if self.logit_scale is not None:
653
+ if self.logit_scale == 0:
654
+ warnings.warn(
655
+ f'Multiplying logits by {self.logit_scale=}. This will produce uniform (uninformative) outputs.'
656
+ )
657
+ logits *= self.logit_scale
658
+
659
+ loss = None
660
+ if labels is not None:
661
+ _labels = torch.roll(labels, shifts=-1)
662
+ _labels[:, -1] = -100
663
+ loss = F.cross_entropy(
664
+ logits.view(-1, logits.size(-1)),
665
+ _labels.to(logits.device).view(-1),
666
+ )
667
+
668
+ return CausalLMOutputWithPast(
669
+ loss=loss,
670
+ logits=logits,
671
+ past_key_values=outputs.past_key_values,
672
+ hidden_states=outputs.hidden_states,
673
+ attentions=outputs.attentions,
674
+ )
675
+
676
+ # Param Initialization, needed for device='meta' fast initialization
677
+ def param_init_fn(self, module):
678
+ init_fn_name = self.config.init_config['name']
679
+ MODEL_INIT_REGISTRY[init_fn_name](
680
+ module=module,
681
+ n_layers=self.config.n_layers,
682
+ d_model=self.config.d_model,
683
+ **self.config.init_config,
684
+ )
685
+
686
+ # FSDP Wrap function
687
+ def fsdp_wrap_fn(self, module):
688
+ return isinstance(module, MPTBlock)
689
+
690
+ # Activation Checkpointing
691
+ def activation_checkpointing_fn(self, module):
692
+ return isinstance(module, MPTBlock)
693
+
694
+ def generate_cache(self,
695
+ input_ids:torch.LongTensor,
696
+ stride:int=512,
697
+ max_len:int=2048,
698
+ cache_type:str='manual'):
699
+ if cache_type not in ['manual', 'faiss']:
700
+ raise NotImplementedError(f"Cache type {cache_type} not implemented.")
701
+
702
+ prev_end_loc=0
703
+ long_range_past_key_values = None
704
+ faiss_indexes= None
705
+ for b_idx in range(0, input_ids.size(-1), stride):
706
+ end_loc = min(b_idx + max_len, input_ids.size(-1))
707
+
708
+ trg_len = end_loc - prev_end_loc
709
+ subseq = input_ids[:, b_idx:end_loc].to(self.device)
710
+ with torch.no_grad():
711
+ outputs = self.transformer(subseq, use_cache=True, use_active_externalism=False)
712
+ to_cache = [(
713
+ kv[0][:,:,:,-trg_len:],
714
+ kv[1][:,:,-trg_len:])
715
+ for kv in outputs.past_key_values
716
+ ]
717
+ long_range_past_key_values, faiss_indexes = self.cache(to_cache, cache_type, long_range_past_key_values=long_range_past_key_values, faiss_indexes=faiss_indexes)
718
+
719
+ prev_end_loc = end_loc
720
+ if end_loc == input_ids.size(-1):
721
+ break
722
+ if long_range_past_key_values is not None:
723
+ return long_range_past_key_values
724
+ else:
725
+ return faiss_indexes
726
+
727
+ def cache(self,
728
+ to_cache:List,
729
+ cache_type:str='manual',
730
+ long_range_past_key_values:List=None,
731
+ faiss_indexes:faiss.IndexFlatIP=None,
732
+ max_length_cache=100000,
733
+ verbose=False):
734
+ if long_range_past_key_values is not None and faiss_indexes is not None:
735
+ raise NotImplementedError("Using faiss and passing key value pairs manually are mutually exclusive right now.")
736
+
737
+ if cache_type=='faiss':
738
+ one_hot_encodings = F.one_hot(torch.arange(0, self.config.n_heads*self.config.n_layers))*10
739
+ if faiss_indexes is None:
740
+ faiss_indexes = (faiss.IndexFlatIP(to_cache[0][0].size(-2)+one_hot_encodings.size(-1)), faiss.IndexFlatIP(to_cache[0][1].size(-1)*2))
741
+ kn_index, kv_index = faiss_indexes
742
+ for b_idx, (k, v) in enumerate(to_cache):
743
+ k_n = (k/vector_norm(k, ord=2, dim=-2, keepdim=True)).to('cpu')
744
+ k_n = torch.concat([rearrange(k_n, 'b h d s -> b (h s) d', h=self.config.n_heads), one_hot_encodings[self.config.n_heads*b_idx:self.config.n_heads*(b_idx+1)].unsqueeze(0).repeat_interleave(repeats=k.size(-1), dim=-2)], dim=-1)
745
+ kn_index.add(k_n.squeeze().numpy())
746
+
747
+ k= rearrange(k, 'b h d s -> b (h s) d', h=self.config.n_heads)
748
+ v= rearrange(v, 'b h s d -> b (h s) d', h=self.config.n_heads)
749
+ kv_index.add(torch.concat([v.squeeze(), k.squeeze()], dim=1).to('cpu').numpy())
750
+
751
+ else:
752
+ if long_range_past_key_values is None:
753
+ long_range_past_key_values = [(k.to(self.memory_device),v.to(self.memory_device)) for k,v in to_cache]
754
+ else:
755
+ long_range_past_key_values = [
756
+ (
757
+ torch.concat([kv[0], to_cache[ind][0].to(self.memory_device)], dim=3),
758
+ torch.concat([kv[1], to_cache[ind][1].to(self.memory_device)], dim=2)
759
+ )
760
+ for ind, kv in enumerate(long_range_past_key_values)
761
+ ]
762
+ if long_range_past_key_values is not None:
763
+ if long_range_past_key_values[0][0].size(-1) > max_length_cache: #set a limit on manual memory length
764
+ long_range_past_key_values = [
765
+ (
766
+ kv[0][:, :, :, -max_length_cache:],
767
+ kv[1][:, :, -max_length_cache:]
768
+ )
769
+ for kv in long_range_past_key_values]
770
+ if verbose:
771
+ if cache_type == 'faiss':
772
+ print(f"{kn_index.ntotal} keys in faiss index")
773
+ else:
774
+ print(f"{long_range_past_key_values[0][0].size(-1)} cached kvs")
775
+
776
+ return long_range_past_key_values, (kn_index, kv_index) if cache_type == 'faiss' else None
777
+
778
+ def prepare_inputs_for_generation(
779
+ self,
780
+ input_ids,
781
+ past_key_values=None,
782
+ inputs_embeds=None,
783
+ **kwargs,
784
+ ):
785
+ if inputs_embeds is not None:
786
+ raise NotImplementedError(
787
+ 'inputs_embeds is not implemented for MPT yet')
788
+
789
+ attention_mask = kwargs['attention_mask'].bool()
790
+ if attention_mask[:, -1].sum() != attention_mask.shape[0]:
791
+ raise NotImplementedError(
792
+ 'MPT does not support generation with right padding.')
793
+
794
+ if self.transformer.attn_uses_sequence_id and self.training:
795
+ sequence_id = torch.zeros_like(input_ids[:1])
796
+ else:
797
+ sequence_id = None
798
+
799
+ if past_key_values is not None:
800
+ input_ids = input_ids[:, -1].unsqueeze(-1)
801
+
802
+ if self.transformer.prefix_lm:
803
+ # Leverage a convenience of sequential generation!
804
+ prefix_mask = torch.ones_like(attention_mask)
805
+ # This requires that we're using the cache
806
+ if kwargs.get('use_cache') == False:
807
+ raise NotImplementedError(
808
+ 'MPT with prefix_lm=True does not support use_cache=False.')
809
+ else:
810
+ prefix_mask = None
811
+
812
+ return {
813
+ 'input_ids': input_ids,
814
+ 'attention_mask': attention_mask,
815
+ 'prefix_mask': prefix_mask,
816
+ 'sequence_id': sequence_id,
817
+ 'past_key_values': past_key_values,
818
+ 'use_cache': kwargs.get('use_cache', True),
819
+ 'use_active_externalism': kwargs.get('use_active_externalism'),
820
+ 'topk': kwargs.get('topk', None),
821
+ }
822
+
823
+ @staticmethod
824
+ def _reorder_cache(past_key_values, beam_idx):
825
+ """Used by HuggingFace generate when using beam search with kv-caching.
826
+
827
+ See https://github.com/huggingface/transformers/blob/3ec7a47664ebe40c40f4b722f6bb1cd30c3821ec/src/transformers/models/gpt2/modeling_gpt2.py#L1122-L1133
828
+ for an example in transformers.
829
+ """
830
+ reordered_past = []
831
+ for layer_past in past_key_values:
832
+ reordered_past += [
833
+ tuple(
834
+ past_state.index_select(0, beam_idx)
835
+ for past_state in layer_past)
836
+ ]
837
+ return reordered_past
utils.py ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from .utils import *
2
+
3
+ import importlib
4
+
5
+
6
+ def instantiate_from_config(config):
7
+ if not "target" in config:
8
+ raise KeyError("Expected key `target` to instantiate.")
9
+ return get_obj_from_str(config["target"])(**config.get("params", dict()))
10
+
11
+
12
+ def get_obj_from_str(string, reload=False):
13
+ module, cls = string.rsplit(".", 1)
14
+ if reload:
15
+ module_imp = importlib.import_module(module)
16
+ importlib.reload(module_imp)
17
+ return getattr(importlib.import_module(module, package=None), cls)