phoebeklett committed on
Commit 77921fc
1 Parent(s): f8a6b8e

Delete modeling_mpt.py

Files changed (1)
  1. modeling_mpt.py +0 -833
modeling_mpt.py DELETED
@@ -1,833 +0,0 @@
- # Adapted from https://github.com/mosaicml/llm-foundry
- # Classes changed: MPTModel, MPTForCausalLM
- # SPDX-License-Identifier: Apache-2.0
-
- """A simple, flexible implementation of a GPT model.
-
- Inspired by https://github.com/karpathy/minGPT/blob/master/mingpt/model.py
- """
-
- import math
- import warnings
- from typing import List, Optional, Tuple, Union
- import torch
- import torch.nn as nn
- import torch.nn.functional as F
- from torch.linalg import vector_norm
- import faiss
- from einops import rearrange
- from composer.utils import dist
- from omegaconf import DictConfig
-
- from transformers import (PreTrainedModel, PreTrainedTokenizer,
-                           PreTrainedTokenizerFast)
- from transformers.modeling_outputs import (BaseModelOutputWithPast,
-                                            CausalLMOutputWithPast)
- from llmfoundry.models.layers.custom_embedding import SharedEmbedding
- from llmfoundry.models.layers.norm import NORM_CLASS_REGISTRY
- from llmfoundry.models.utils.param_init_fns import MODEL_INIT_REGISTRY
-
- from .configuration import ExtendedMPTConfig
- from .attention import attn_bias_shape, build_attn_bias
- from .blocks import MPTBlock
- from .utils import instantiate_from_config
-
- Tokenizer = Union[PreTrainedTokenizer, PreTrainedTokenizerFast]
-
- class MPTPreTrainedModel(PreTrainedModel):
-     config_class = ExtendedMPTConfig
-     base_model_prefix = 'model'
-     _no_split_modules = ['MPTBlock']
-
- class ExtendedMPTModel(MPTPreTrainedModel):
-
-     def __init__(self, config: ExtendedMPTConfig):
-         config._validate_config()
-         super().__init__(config)
-
-         self.attn_impl = config.attn_config['attn_impl']
-         self.prefix_lm = config.attn_config['prefix_lm']
-         self.attn_uses_sequence_id = config.attn_config['attn_uses_sequence_id']
-         self.alibi = config.attn_config['alibi']
-         self.alibi_bias_max = config.attn_config['alibi_bias_max']
-
-         self.mask_by_sim = config.attn_config['mask_by_sim']
-         self.sim_threshold = config.attn_config['sim_threshold']
-         self.topk = config.attn_config['topk']
-         self.use_active_externalism = config.attn_config['use_active_externalism']
-
-         self.use_active_externalism_by_layer = config.use_active_externalism_by_layer
-
-         if config.init_device == 'mixed':
-             if dist.get_local_rank() == 0:
-                 config.init_device = 'cpu'
-             else:
-                 config.init_device = 'meta'
-
-         if config.norm_type.lower() not in NORM_CLASS_REGISTRY.keys():
-             norm_options = ' | '.join(NORM_CLASS_REGISTRY.keys())
-             raise NotImplementedError(
-                 f'Requested norm type ({config.norm_type}) is not implemented within this repo (Options: {norm_options}).'
-             )
-         norm_class = NORM_CLASS_REGISTRY[config.norm_type.lower()]
-
-         # CogView (https://arxiv.org/abs/2105.13290) and GLM-130B (https://arxiv.org/abs/2210.02414)
-         # both report this helping with stabilizing training
-         self.embedding_fraction = config.embedding_fraction
-
-         self.wte = SharedEmbedding(config.vocab_size,
-                                    config.d_model,
-                                    device=config.init_device)
-         if not self.alibi:
-             self.wpe = torch.nn.Embedding(config.max_seq_len,
-                                           config.d_model,
-                                           device=config.init_device)
-         self.emb_drop = nn.Dropout(config.emb_pdrop)
-         self.blocks = nn.ModuleList([
-             MPTBlock(
-                 device=config.init_device,
-                 **config.to_dict(),
-             ) for _ in range(config.n_layers)
-         ])
-         self.norm_f = norm_class(config.d_model, device=config.init_device)
-
-         if config.init_device != 'meta':
-             print(
-                 f'You are using {config.init_device=}, but you can also use config.init_device="meta" with Composer + FSDP for fast initialization.'
-             )
-             self.apply(self.param_init_fn)
-
-         self.is_causal = not self.prefix_lm
-
-         # define attn mask
-         self._attn_bias_initialized = False
-         self.attn_bias = None
-         self.attn_bias_shape = attn_bias_shape(
-             self.attn_impl,
-             config.n_heads,
-             config.max_seq_len,
-             self.alibi,
-             prefix_lm=self.prefix_lm,
-             causal=self.is_causal,
-             use_sequence_id=self.attn_uses_sequence_id,
-         )
-         self._attn_bias_ae_initialized = False  # for active externalism
-         self.attn_bias_ae = None
-
-         if self.config.no_bias:
-             for module in self.modules():
-                 if hasattr(module, 'bias') and isinstance(
-                         module.bias, nn.Parameter):
-                     if self.config.verbose:
-                         warnings.warn(
-                             f'Removing bias ({module.bias}) from {module}.')
-                     module.register_parameter('bias', None)
-
-         # Print verbose info
-         if config.verbose and config.verbose > 2:
-             print(self)
-         if 'verbose' not in self.config.init_config:
-             self.config.init_config['verbose'] = self.config.verbose
-         if self.config.init_config['verbose'] > 1:
-             init_fn_name = self.config.init_config['name']
-             warnings.warn(f'Using {init_fn_name} initialization.')
-
-     def get_input_embeddings(self):
-         return self.wte
-
-     def set_input_embeddings(self, value: nn.Embedding):
-         self.wte = value
-
-     @torch.no_grad()
-     def _attn_bias(
-         self,
-         device,
-         dtype,
-         attention_mask: Optional[torch.ByteTensor] = None,
-         prefix_mask: Optional[torch.ByteTensor] = None,
-         sequence_id: Optional[torch.LongTensor] = None,
-         seq_len: Optional[int] = None,
-         use_active_externalism: Optional[bool] = None,
-         topk=None,
-     ):
-         if not self._attn_bias_initialized:
-             if self.attn_bias_shape:
-                 self.attn_bias = torch.zeros(self.attn_bias_shape,
-                                              device=device,
-                                              dtype=dtype)
-                 self.attn_bias = build_attn_bias(
-                     self.attn_impl,
-                     self.config.n_heads,
-                     self.config.max_seq_len,
-                     device=device,
-                     dtype=dtype,
-                     attn_bias=self.attn_bias,
-                     causal=self.is_causal,
-                     alibi=self.alibi,
-                     alibi_bias_max=self.alibi_bias_max
-                 )
-             self._attn_bias_initialized = True
-
-         if use_active_externalism:  # for active externalism, init every time since seq_len changes
-             self.attn_bias_ae = build_attn_bias(
-                 self.attn_impl,
-                 self.config.n_heads,
-                 seq_len,
-                 device=device,
-                 dtype=dtype,
-                 causal=self.is_causal,
-                 alibi=self.alibi,
-                 alibi_bias_max=self.alibi_bias_max,
-                 for_ae=use_active_externalism,
-                 topk=topk
-             )
-
-             self._attn_bias_ae_initialized = True
-
-         # flash does not support prefix_lm and will incorporate any
-         # attention_mask inside the attention module
-         if self.attn_impl == 'flash':
-             return self.attn_bias, self.attn_bias_ae, attention_mask  # three values, matching the caller's unpacking
-
-         if self.attn_bias is not None:
-             # .to(*args, **kwargs) is a no-op if tensor is already on
-             # specified device or of specified dtype
-             self.attn_bias = self.attn_bias.to(dtype=dtype, device=device)
-
-         attn_bias = self.attn_bias
-
-         if self.attn_bias_ae is not None:  # for active externalism
-             self.attn_bias_ae = self.attn_bias_ae.to(dtype=dtype, device=device)
-         attn_bias_ae = self.attn_bias_ae
-
-         # If using torch or triton, we incorporate the prefix_mask (if appropriate)
-         if self.prefix_lm:
-             assert isinstance(attn_bias, torch.Tensor)  # pyright
-             assert isinstance(prefix_mask, torch.Tensor)  # pyright
-             attn_bias = self._apply_prefix_mask(attn_bias, prefix_mask)
-
-         # If using torch or triton, we incorporate sequence_id (if appropriate)
-         if self.attn_uses_sequence_id and sequence_id is not None:
-             assert isinstance(attn_bias, torch.Tensor)  # pyright
-             attn_bias = self._apply_sequence_id(attn_bias, sequence_id)
-
-         # If using torch or triton, we incorporate attention_mask. This will output
-         # None in place of attention_mask since it will not be further needed in the
-         # attention modules.
-         if attention_mask is not None:
-             s_k = attention_mask.shape[-1]
-             if attn_bias is None:
-                 attn_bias = torch.zeros((1, 1, 1, s_k),
-                                         device=device,
-                                         dtype=dtype)
-             else:
-                 # clamp to 0 necessary for torch 2.0 compile()
-                 _s_k = max(0, attn_bias.size(-1) - s_k)
-                 attn_bias = attn_bias[:, :, :, _s_k:]
-             if prefix_mask is not None and (attention_mask.shape !=
-                                             prefix_mask.shape):
-                 raise ValueError(
-                     f'attention_mask shape={attention_mask.shape} ' +
-                     f'and prefix_mask shape={prefix_mask.shape} are not equal.')
-             min_val = torch.finfo(attn_bias.dtype).min
-             attn_bias = attn_bias.masked_fill(
-                 ~attention_mask.view(-1, 1, 1, s_k), min_val)
-
-         return attn_bias, attn_bias_ae, None
-
-     def _apply_prefix_mask(self, attn_bias: torch.Tensor,
-                            prefix_mask: torch.Tensor):
-         s_k, s_q = attn_bias.shape[-2:]
-         if (s_k != self.config.max_seq_len) or (s_q != self.config.max_seq_len):
-             raise ValueError(
-                 'attn_bias does not match the expected shape. ' +
-                 f'The last two dimensions should both be {self.config.max_seq_len} '
-                 + f'but are {s_k} and {s_q}.')
-         seq_len = prefix_mask.shape[-1]
-         if seq_len > self.config.max_seq_len:
-             raise ValueError(
-                 f'prefix_mask sequence length cannot exceed max_seq_len={self.config.max_seq_len}'
-             )
-
-         # select seq_len subset of attn mask
-         attn_bias = attn_bias[..., :seq_len, :seq_len]
-
-         # Mix the causal mask and the bidirectional mask to get the full
-         # allowable attention (i.e. full = not accounting for padding yet)
-         causal = torch.tril(
-             torch.ones((seq_len, seq_len),
-                        dtype=torch.bool,
-                        device=prefix_mask.device)).view(1, 1, seq_len, seq_len)
-         prefix = prefix_mask.view(-1, 1, 1, seq_len)
-         cannot_attend = ~torch.logical_or(causal, prefix.bool())
-
-         min_val = torch.finfo(attn_bias.dtype).min
-         attn_bias = attn_bias.masked_fill(cannot_attend, min_val)
-
-         return attn_bias
-
-     def _apply_sequence_id(self, attn_bias: torch.Tensor,
-                            sequence_id: torch.LongTensor):
-         seq_len = sequence_id.shape[-1]
-         if seq_len > self.config.max_seq_len:
-             raise ValueError(
-                 f'sequence_id sequence length cannot exceed max_seq_len={self.config.max_seq_len}'
-             )
-
-         # select seq_len subset of attn mask
-         attn_bias = attn_bias[..., :seq_len, :seq_len]
-
-         # Restrict attention to tokens that share the same value
-         # in sequence_id
-         cannot_attend = torch.logical_not(
-             torch.eq(
-                 sequence_id.view(-1, seq_len, 1),
-                 sequence_id.view(-1, 1, seq_len),
-             )).unsqueeze(1)
-         min_val = torch.finfo(attn_bias.dtype).min
-         attn_bias = attn_bias.masked_fill(cannot_attend, min_val)
-
-         return attn_bias
-
-     def forward(
-         self,
-         input_ids: torch.LongTensor,
-         past_key_values: Optional[List[Tuple[torch.FloatTensor]]] = None,
-         attention_mask: Optional[torch.ByteTensor] = None,
-         prefix_mask: Optional[torch.ByteTensor] = None,
-         sequence_id: Optional[torch.LongTensor] = None,
-         return_dict: Optional[bool] = None,
-         output_attentions: Optional[bool] = None,
-         output_hidden_states: Optional[bool] = None,
-         use_cache: Optional[bool] = None,
-         inputs_embeds: Optional[torch.Tensor] = None,
-         use_active_externalism: Optional[bool] = None,
-         long_range_past_key_values: Optional[List[Tuple[torch.FloatTensor]]] = None,
-         faiss_indexes: Optional[Tuple] = None,
-         topk: Optional[int] = None,
-     ):
-         return_dict = (return_dict
-                        if return_dict is not None else self.config.return_dict)
-         use_cache = (use_cache
-                      if use_cache is not None else self.config.use_cache)
-         use_active_externalism = (use_active_externalism
-                                   if use_active_externalism is not None else self.use_active_externalism)
-         topk = (topk if topk is not None else self.topk)
-
-         if attention_mask is not None:
-             attention_mask = attention_mask.bool()
-
-         if prefix_mask is not None:
-             prefix_mask = prefix_mask.bool()
-
-         # These args are passed in by keyword in huggingface's generate function
-         # https://github.com/huggingface/transformers/blob/68287689f2f0d8b7063c400230b3766987abf18d/src/transformers/generation/utils.py#L2201-L2206
-         # but have not yet been fully implemented in MPTModel
-         if not return_dict:
-             raise NotImplementedError(
-                 'return_dict False is not implemented yet for MPT')
-         if output_attentions:
-             if self.attn_impl != 'torch':
-                 raise NotImplementedError(
-                     'output_attentions is not implemented for MPT when using attn_impl `flash` or `triton`.'
-                 )
-
-         if (attention_mask is not None and
-                 attention_mask[:, 0].sum() != attention_mask.shape[0] and
-                 self.training):
-             raise NotImplementedError(
-                 'MPT does not support training with left padding.')
-
-         if self.prefix_lm and prefix_mask is None:
-             raise ValueError(
-                 'prefix_mask is a required argument when MPT is configured with prefix_lm=True.'
-             )
-
-         # Raise a not implemented error if input_embeds is not None (this is an arg in huggingface transformers and we need to support it for PEFT)
-         if inputs_embeds is not None:
-             raise NotImplementedError(
-                 'inputs_embeds is not implemented for MPT.')
-
-         if self.training:
-             if self.attn_uses_sequence_id and sequence_id is None:
-                 raise ValueError(
-                     'sequence_id is a required argument when MPT is configured with attn_uses_sequence_id=True '
-                     + 'and the model is in train mode.')
-             elif (self.attn_uses_sequence_id is False) and (sequence_id
-                                                             is not None):
-                 warnings.warn(
-                     'MPT received non-None input for `sequence_id` but is configured with attn_uses_sequence_id=False. '
-                     +
-                     'This input will be ignored. If you want the model to use `sequence_id`, set attn_uses_sequence_id to True.'
-                 )
-
-         S = input_ids.size(1)
-
-         assert (
-             S <= self.config.max_seq_len
-         ), f'Cannot forward input with seq_len={S}, this model only supports seq_len<={self.config.max_seq_len}'
-
-         tok_emb = self.wte(input_ids)  # type: ignore
-         if self.alibi:
-             x = tok_emb
-         else:
-             past_position = 0
-             if past_key_values is not None:
-                 if len(past_key_values) != self.config.n_layers:
-                     raise ValueError(
-                         f'past_key_values must provide a past_key_value for each attention '
-                         +
-                         f'layer in the network ({len(past_key_values)=}; {self.config.n_layers=}).'
-                     )
-                 # For attn_impl: triton and flash the past key tensor spec is (batch, seq, dim).
-                 # For attn_impl: torch the past key tensor spec is (batch, heads, head_dim, seq).
-                 # Here we shift position embedding using the `seq` dim of the past key
-                 past_position = past_key_values[0][0].size(1)
-                 if self.attn_impl == 'torch':
-                     past_position = past_key_values[0][0].size(3)
-
-             if S + past_position > self.config.max_seq_len:
-                 raise ValueError(
-                     f'Cannot forward input with past sequence length {past_position} and current sequence length '
-                     f'{S + 1}, this model only supports total sequence length <= {self.config.max_seq_len}.'
-                 )
-             pos = torch.arange(
-                 past_position,
-                 S + past_position,
-                 dtype=torch.long,
-                 device=input_ids.device,
-             ).unsqueeze(0)
-             if attention_mask is not None:
-                 # adjust the position indices to account for padding tokens
-                 pos = torch.clamp(
-                     pos - torch.cumsum((~attention_mask).to(torch.int32),
-                                        dim=1)[:, past_position:],
-                     min=0,
-                 )
-
-             pos_emb = self.wpe(pos)  # type: ignore
-             x = tok_emb + pos_emb
-
-         if self.embedding_fraction == 1:
-             x = self.emb_drop(x)  # type: ignore
-         else:
-             # this implementation is proposed on page 7 of the GLM-130B paper https://arxiv.org/abs/2210.02414
-             x_shrunk = (x * self.embedding_fraction) + (
-                 x.detach() * (1 - self.embedding_fraction))
-             assert isinstance(self.emb_drop, nn.Module)  # pyright
-             x = self.emb_drop(x_shrunk)
-
-         seq_len = S  # for active externalism
-         if past_key_values is not None:
-             past_position = past_key_values[0][0].size(-1)
-             seq_len += past_position
-
-         attn_bias, attn_bias_ae, attention_mask = self._attn_bias(
-             device=x.device,
-             dtype=torch.float32,
-             attention_mask=attention_mask,
-             prefix_mask=prefix_mask,
-             sequence_id=sequence_id,
-             seq_len=seq_len,
-             use_active_externalism=use_active_externalism,
-             topk=topk
-         )
-
-         # initialize the past key values cache if it should be used
-         if use_cache and past_key_values is None:
-             past_key_values = [() for _ in range(self.config.n_layers)
-                               ]  # type: ignore
-
-         all_hidden_states = () if output_hidden_states else None
-         all_self_attns = () if output_attentions else None
-         all_idx = () if output_attentions else None
-         for b_idx, block in enumerate(self.blocks):  # type: ignore
-             if output_hidden_states:
-                 assert all_hidden_states is not None  # pyright
-                 all_hidden_states = all_hidden_states + (x,)
-             past_key_value = (past_key_values[b_idx]
-                               if past_key_values is not None else None)
-             long_range_past_key_value = (long_range_past_key_values[b_idx]
-                                          if (long_range_past_key_values is not None and self.use_active_externalism_by_layer[b_idx] and use_active_externalism is True) else None)
-
-             if long_range_past_key_value is not None and faiss_indexes is not None:
-                 raise NotImplementedError(
-                     'Using faiss and passing key value pairs manually are mutually exclusive right now.')
-
-             x, attn_weights, past_key_value, reshaped_idx = block(
-                 x,
-                 past_key_value=past_key_value,
-                 long_range_past_key_value=long_range_past_key_value,
-                 attn_bias=attn_bias,
-                 attention_mask=attention_mask,
-                 attn_bias_ae=attn_bias_ae,
-                 is_causal=self.is_causal,
-                 topk=topk,
-                 needs_weights=output_attentions,
-                 faiss_indexes=faiss_indexes,
-                 n_layers=self.config.n_layers,
-                 current_layer=b_idx,
-                 mask_by_sim=self.mask_by_sim,
-                 sim_threshold=self.sim_threshold,
-             )
-             if past_key_values is not None:
-                 past_key_values[b_idx] = past_key_value
-
-             if output_attentions:
-                 assert all_self_attns is not None  # pyright
-                 all_self_attns = all_self_attns + (attn_weights,)
-
-                 assert all_idx is not None
-                 all_idx = all_idx + (reshaped_idx,)
-
-         x = self.norm_f(x)  # type: ignore
-
-         # add hidden states from the last decoder layer
-         if output_hidden_states:
-             assert all_hidden_states is not None  # pyright
-             all_hidden_states = all_hidden_states + (x,)
-
-         return BaseModelOutputWithPast(
-             last_hidden_state=x,
-             past_key_values=past_key_values,
-             hidden_states=all_hidden_states,
-             attentions=(all_self_attns, all_idx),  # return reshaped_idx for active externalism
-         )
-
-     # Param Initialization, needed for device='meta' fast initialization
-     def param_init_fn(self, module):
-         init_fn_name = self.config.init_config['name']
-         MODEL_INIT_REGISTRY[init_fn_name](
-             module=module,
-             n_layers=self.config.n_layers,
-             d_model=self.config.d_model,
-             **self.config.init_config,
-         )
-
-     # FSDP Wrap function
-     def fsdp_wrap_fn(self, module):
-         return isinstance(module, MPTBlock)
-
-     # Activation Checkpointing
-     def activation_checkpointing_fn(self, module):
-         return isinstance(module, MPTBlock)
-
- class ExtendedMPTForCausalLM(MPTPreTrainedModel):
-
-     def __init__(self, config: ExtendedMPTConfig, external_memories=None):
-         if isinstance(config, DictConfig):
-             config = instantiate_from_config(config)
-
-         super().__init__(config)
-         if not config.tie_word_embeddings:
-             raise ValueError(
-                 'MPTForCausalLM only supports tied word embeddings')
-
-         print(f'Instantiating an MPTForCausalLM model from {__file__}')
-
-         self.transformer: ExtendedMPTModel = ExtendedMPTModel(config)
-
-         self.use_active_externalism = config.attn_config['use_active_externalism']
-         self.memory_type = config.attn_config['memory_type']
-         self._memories = None
-         self.memory_device = config.memory_device
-
-         for child in self.transformer.children():
-             if isinstance(child, torch.nn.ModuleList):
-                 continue
-             if isinstance(child, torch.nn.Module):
-                 child._fsdp_wrap = True
-
-         # enables scaling output logits; similar to a softmax "temperature"
-         # PaLM paper uses scale 1/sqrt(config.d_model)
-         self.logit_scale = None
-         if config.logit_scale is not None:
-             logit_scale = config.logit_scale
-             if isinstance(logit_scale, str):
-                 if logit_scale == 'inv_sqrt_d_model':
-                     logit_scale = 1 / math.sqrt(config.d_model)
-                 else:
-                     raise ValueError(
-                         f"{logit_scale=} is not recognized as an option; use numeric value or 'inv_sqrt_d_model'."
-                     )
-             self.logit_scale = logit_scale
-
-         if external_memories is not None:
-             self._memories = external_memories
-             self.memories = None
-
-     def set_memories(self, memories):
-         self.memories = memories
-
-     def empty_memories(self):
-         self.memories = None
-
-     def get_input_embeddings(self):
-         return self.transformer.wte
-
-     def set_input_embeddings(self, value):
-         self.transformer.wte = value
-
-     def get_output_embeddings(self):
-         return self.transformer.wte
-
-     def set_output_embeddings(self, new_embeddings):
-         self.transformer.wte = new_embeddings
-
-     def set_decoder(self, decoder):
-         self.transformer = decoder
-
-     def get_decoder(self):
-         return self.transformer
-
-     def forward(
-         self,
-         input_ids: torch.LongTensor,
-         past_key_values: Optional[List[Tuple[torch.FloatTensor]]] = None,
-         attention_mask: Optional[torch.ByteTensor] = None,
-         prefix_mask: Optional[torch.ByteTensor] = None,
-         sequence_id: Optional[torch.LongTensor] = None,
-         labels: Optional[torch.LongTensor] = None,
-         return_dict: Optional[bool] = None,
-         output_attentions: Optional[bool] = None,
-         output_hidden_states: Optional[bool] = None,
-         use_cache: Optional[bool] = None,
-         inputs_embeds: Optional[torch.FloatTensor] = None,
-         use_active_externalism: Optional[bool] = None,
-         topk: Optional[int] = None,
-     ):
-         if self._memories is not None and self.memories is None:  # init memories once on first call
-             self.memories = self.generate_cache(self._memories, cache_type=self.memory_type)
-
-         return_dict = (return_dict
-                        if return_dict is not None else self.config.return_dict)
-         use_cache = (use_cache
-                      if use_cache is not None else self.config.use_cache)
-         use_active_externalism = (use_active_externalism
-                                   if use_active_externalism is not None else self.use_active_externalism)
-
-         # topk of None is passed through; the transformer falls back to its configured topk
-
-         # if input_embeds is not none, raise a not implemented error
-         if inputs_embeds is not None:
-             raise NotImplementedError(
-                 'inputs_embeds has to be None (for hf/peft support).')
-         # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
-
-         if hasattr(self, 'memories') and isinstance(self.memories, list):
-             long_range_past_key_values = self.memories
-             faiss_indexes = None
-         elif hasattr(self, 'memories'):
-             long_range_past_key_values = None
-             faiss_indexes = self.memories
-         else:
-             long_range_past_key_values = None
-             faiss_indexes = None
-
-         outputs = self.transformer(
-             input_ids=input_ids,
-             past_key_values=past_key_values,
-             long_range_past_key_values=long_range_past_key_values,
-             faiss_indexes=faiss_indexes,
-             attention_mask=attention_mask,
-             prefix_mask=prefix_mask,
-             sequence_id=sequence_id,
-             return_dict=return_dict,
-             output_attentions=output_attentions,
-             output_hidden_states=output_hidden_states,
-             use_cache=use_cache,
-             use_active_externalism=use_active_externalism,
-             topk=topk
-         )
-
-         # move outputs to same device as weights for token embedding
-         # needed to support HF `device_map`
-         logits = self.transformer.wte(
-             outputs.last_hidden_state.to(self.transformer.wte.weight.device),
-             True,
-         )
-
-         if self.logit_scale is not None:
-             if self.logit_scale == 0:
-                 warnings.warn(
-                     f'Multiplying logits by {self.logit_scale=}. This will produce uniform (uninformative) outputs.'
-                 )
-             logits *= self.logit_scale
-
-         loss = None
-         if labels is not None:
-             _labels = torch.roll(labels, shifts=-1)
-             _labels[:, -1] = -100
-             loss = F.cross_entropy(
-                 logits.view(-1, logits.size(-1)),
-                 _labels.to(logits.device).view(-1),
-             )
-
-         return CausalLMOutputWithPast(
-             loss=loss,
-             logits=logits,
-             past_key_values=outputs.past_key_values,
-             hidden_states=outputs.hidden_states,
-             attentions=outputs.attentions,
-         )
-
-     # Param Initialization, needed for device='meta' fast initialization
-     def param_init_fn(self, module):
-         init_fn_name = self.config.init_config['name']
-         MODEL_INIT_REGISTRY[init_fn_name](
-             module=module,
-             n_layers=self.config.n_layers,
-             d_model=self.config.d_model,
-             **self.config.init_config,
-         )
-
-     # FSDP Wrap function
-     def fsdp_wrap_fn(self, module):
-         return isinstance(module, MPTBlock)
-
-     # Activation Checkpointing
-     def activation_checkpointing_fn(self, module):
-         return isinstance(module, MPTBlock)
-
-     def generate_cache(self,
-                        input_ids: torch.LongTensor,
-                        stride: int = 512,
-                        max_len: int = 2048,
-                        cache_type: str = 'manual'):
-         if cache_type not in ['manual', 'faiss']:
-             raise NotImplementedError(f"Cache type {cache_type} not implemented.")
-
-         prev_end_loc = 0
-         long_range_past_key_values = None
-         faiss_indexes = None
-         for b_idx in range(0, input_ids.size(-1), stride):  # generate kv-pairs using stride
-             end_loc = min(b_idx + max_len, input_ids.size(-1))
-             trg_len = end_loc - prev_end_loc
-             subseq = input_ids[:, b_idx:end_loc].to(self.device)
-             with torch.no_grad():
-                 outputs = self.transformer(subseq, use_cache=True, use_active_externalism=False)
-             to_cache = [(
-                 kv[0][:, :, :, -trg_len:],
-                 kv[1][:, :, -trg_len:])
-                 for kv in outputs.past_key_values
-             ]
-             long_range_past_key_values, faiss_indexes = self.cache(to_cache, cache_type, long_range_past_key_values=long_range_past_key_values, faiss_indexes=faiss_indexes)
-
-             prev_end_loc = end_loc
-             if end_loc == input_ids.size(-1):
-                 break
-         if long_range_past_key_values is not None:
-             return long_range_past_key_values
-         else:
-             return faiss_indexes
-
-     def cache(self,
-               to_cache: List,
-               cache_type: str = 'manual',
-               long_range_past_key_values: Optional[List] = None,
-               faiss_indexes: Optional[Tuple] = None,
-               max_length_cache=100000,
-               verbose=False):
-         if long_range_past_key_values is not None and faiss_indexes is not None:
-             raise NotImplementedError("Using faiss and passing key value pairs manually are mutually exclusive right now.")
-
-         if cache_type == 'faiss':  # add one-hot encoding to match layer, head indices
-             one_hot_encodings = F.one_hot(torch.arange(0, self.config.n_heads * self.config.n_layers)) * 10
-             if faiss_indexes is None:
-                 faiss_indexes = (faiss.IndexFlatIP(to_cache[0][0].size(-2) + one_hot_encodings.size(-1)), faiss.IndexFlatIP(to_cache[0][1].size(-1) * 2))
-             kn_index, kv_index = faiss_indexes
-             for b_idx, (k, v) in enumerate(to_cache):
-                 k_n = (k / vector_norm(k, ord=2, dim=-2, keepdim=True)).to('cpu')
-                 k_n = torch.concat([rearrange(k_n, 'b h d s -> b (h s) d', h=self.config.n_heads), one_hot_encodings[self.config.n_heads * b_idx:self.config.n_heads * (b_idx + 1)].unsqueeze(0).repeat_interleave(repeats=k.size(-1), dim=-2)], dim=-1)
-                 kn_index.add(k_n.squeeze().numpy())
-
-                 k = rearrange(k, 'b h d s -> b (h s) d', h=self.config.n_heads)
-                 v = rearrange(v, 'b h s d -> b (h s) d', h=self.config.n_heads)
-                 kv_index.add(torch.concat([v.squeeze(), k.squeeze()], dim=1).to('cpu').numpy())
-         else:
-             if long_range_past_key_values is None:
-                 long_range_past_key_values = [(k.to(self.memory_device), v.to(self.memory_device)) for k, v in to_cache]
-             else:
-                 long_range_past_key_values = [
-                     (
-                         torch.concat([kv[0], to_cache[ind][0].to(self.memory_device)], dim=3),
-                         torch.concat([kv[1], to_cache[ind][1].to(self.memory_device)], dim=2)
-                     )
-                     for ind, kv in enumerate(long_range_past_key_values)
-                 ]
-         if long_range_past_key_values is not None:  # set a limit on manual memory length
-             if long_range_past_key_values[0][0].size(-1) > max_length_cache:
-                 long_range_past_key_values = [
-                     (
-                         kv[0][:, :, :, -max_length_cache:],
-                         kv[1][:, :, -max_length_cache:]
-                     )
-                     for kv in long_range_past_key_values]
-         if verbose:
-             if cache_type == 'faiss':
-                 print(f"{kn_index.ntotal} keys in faiss index")
-             else:
-                 print(f"{long_range_past_key_values[0][0].size(-1)} cached kvs")
-
-         return long_range_past_key_values, ((kn_index, kv_index) if cache_type == 'faiss' else None)
-
-     def prepare_inputs_for_generation(
-         self,
-         input_ids,
-         past_key_values=None,
-         inputs_embeds=None,
-         **kwargs,
-     ):
-         if inputs_embeds is not None:
-             raise NotImplementedError(
-                 'inputs_embeds is not implemented for MPT yet')
-
-         attention_mask = kwargs['attention_mask'].bool()
-         if attention_mask[:, -1].sum() != attention_mask.shape[0]:
-             raise NotImplementedError(
-                 'MPT does not support generation with right padding.')
-
-         if self.transformer.attn_uses_sequence_id and self.training:
-             sequence_id = torch.zeros_like(input_ids[:1])
-         else:
-             sequence_id = None
-
-         if past_key_values is not None:
-             input_ids = input_ids[:, -1].unsqueeze(-1)
-
-         if self.transformer.prefix_lm:
-             # Leverage a convenience of sequential generation!
-             prefix_mask = torch.ones_like(attention_mask)
-             # This requires that we're using the cache
-             if kwargs.get('use_cache') == False:
-                 raise NotImplementedError(
-                     'MPT with prefix_lm=True does not support use_cache=False.')
-         else:
-             prefix_mask = None
-
-         return {
-             'input_ids': input_ids,
-             'attention_mask': attention_mask,
-             'prefix_mask': prefix_mask,
-             'sequence_id': sequence_id,
-             'past_key_values': past_key_values,
-             'use_cache': kwargs.get('use_cache', True),
-             'use_active_externalism': kwargs.get('use_active_externalism'),  # add a few more kwargs for active externalism
-             'topk': kwargs.get('topk', None),
-         }
-
-     @staticmethod
-     def _reorder_cache(past_key_values, beam_idx):
-         """Used by HuggingFace generate when using beam search with kv-caching.
-
-         See https://github.com/huggingface/transformers/blob/3ec7a47664ebe40c40f4b722f6bb1cd30c3821ec/src/transformers/models/gpt2/modeling_gpt2.py#L1122-L1133
-         for an example in transformers.
-         """
-         reordered_past = []
-         for layer_past in past_key_values:
-             reordered_past += [
-                 tuple(
-                     past_state.index_select(0, beam_idx)
-                     for past_state in layer_past)
-             ]
-         return reordered_past
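
For context, the deleted ExtendedMPTForCausalLM wraps standard Hugging Face generation around an external key/value memory ("active externalism"): generate_cache runs reference text through the transformer in strides and stores the resulting key/value pairs (raw tensors for cache_type='manual', faiss indexes for cache_type='faiss'), and prepare_inputs_for_generation forwards use_active_externalism and topk so retrieval happens during decoding. The sketch below is illustrative only and not part of the deleted file; it assumes an already-built ExtendedMPTConfig named `config` and a compatible `tokenizer`, and simply exercises the entry points defined above.

import torch

# Hypothetical setup: `config` (an ExtendedMPTConfig) and `tokenizer` are assumed,
# not defined in the deleted file.
model = ExtendedMPTForCausalLM(config)
model.eval()

# Cache key/value memories for a long reference text. 'manual' keeps the raw kv
# tensors on config.memory_device; 'faiss' would build inner-product indexes instead.
memory_ids = tokenizer("a long reference document ...", return_tensors='pt').input_ids
model.set_memories(model.generate_cache(memory_ids, cache_type='manual'))

# Decode with retrieval: each query attends to its topk external key/values on the
# layers enabled by config.use_active_externalism_by_layer.
inputs = tokenizer("Question about the document:", return_tensors='pt')
with torch.no_grad():
    out = model.generate(inputs.input_ids,
                         attention_mask=inputs.attention_mask,
                         max_new_tokens=50,
                         use_active_externalism=True,
                         topk=5)
print(tokenizer.decode(out[0], skip_special_tokens=True))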