zpn committed on
Commit
7852674
1 Parent(s): 48b25c8

Delete modeling_hf_nomic_bert.py

Files changed (1)
  1. modeling_hf_nomic_bert.py +0 -2071
modeling_hf_nomic_bert.py DELETED
@@ -1,2071 +0,0 @@
-# Copyright (c) 2022, Tri Dao.
-# This BERT implementation is based on our MLPerf 2.0 and MLPerf 2.1 BERT implementation.
-# https://github.com/mlcommons/training_results_v2.0/blob/main/HazyResearch/benchmarks/bert/implementations/pytorch/modeling.py
-# https://github.com/mlcommons/training_results_v2.1/blob/main/Azure-HazyResearch/benchmarks/bert/implementations/ND96amsr_A100_v4/modeling.py
-
-import logging
-
-# Inspired by https://github.com/huggingface/transformers/blob/main/src/transformers/models/bert/modeling_bert.py
-import math
-import numpy as np
-import collections
-import os
-import re
-from collections import OrderedDict
-from functools import partial
-from typing import List, Optional, Tuple, Union
-
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
-from einops import rearrange, repeat
-from safetensors.torch import load_file as safe_load_file
-from transformers import GPT2Config, PreTrainedModel, ViTModel, ViTConfig
-from transformers.models.bert.modeling_bert import (
-    BaseModelOutputWithPoolingAndCrossAttentions,
-    MaskedLMOutput,
-    SequenceClassifierOutput,
-)
-from transformers.utils import SAFE_WEIGHTS_INDEX_NAME, SAFE_WEIGHTS_NAME, WEIGHTS_INDEX_NAME, WEIGHTS_NAME
-from transformers.utils.hub import cached_file, get_checkpoint_shard_files
-from transformers.modeling_outputs import BaseModelOutputWithPast
-from torch.nn.modules.utils import _pair
-
-from .configuration_hf_nomic_bert import NomicBertConfig
-
-logger = logging.getLogger(__name__)
-
-
-# adapted from flash attention, added safe serialization option for hf models
-def state_dict_from_pretrained(model_name, safe_serialization=False, device=None, dtype=None):
-    # If not fp32, then we don't want to load directly to the GPU
-    mapped_device = "cpu" if dtype not in [torch.float32, None] else device
-    is_sharded = False
-    load_safe = False
-    resolved_archive_file = None
-
-    weights_path = os.path.join(model_name, WEIGHTS_NAME)
-    weights_index_path = os.path.join(model_name, WEIGHTS_INDEX_NAME)
-    safe_weights_path = os.path.join(model_name, SAFE_WEIGHTS_NAME)
-    safe_weights_index_path = os.path.join(model_name, SAFE_WEIGHTS_INDEX_NAME)
-
-    if os.path.isfile(weights_path):
-        resolved_archive_file = cached_file(model_name, WEIGHTS_NAME, _raise_exceptions_for_missing_entries=False)
-    elif os.path.isfile(weights_index_path):
-        resolved_archive_file = cached_file(model_name, WEIGHTS_INDEX_NAME, _raise_exceptions_for_missing_entries=False)
-        is_sharded = True
-    elif os.path.isfile(safe_weights_path):
-        resolved_archive_file = cached_file(model_name, SAFE_WEIGHTS_NAME, _raise_exceptions_for_missing_entries=False)
-        load_safe = True
-    elif os.path.isfile(safe_weights_index_path):
-        resolved_archive_file = cached_file(
-            model_name, SAFE_WEIGHTS_INDEX_NAME, _raise_exceptions_for_missing_entries=False
-        )
-        is_sharded = True
-        load_safe = True
-    else:  # Try loading from HF hub instead of from local files
-        resolved_archive_file = None
-        for weight_name in [WEIGHTS_NAME, SAFE_WEIGHTS_NAME, WEIGHTS_INDEX_NAME, SAFE_WEIGHTS_INDEX_NAME]:
-            resolved_archive_file = cached_file(
-                model_name, weight_name, _raise_exceptions_for_missing_entries=False
-            )
-            if resolved_archive_file is not None:
-                if weight_name in [SAFE_WEIGHTS_NAME, SAFE_WEIGHTS_INDEX_NAME]:
-                    load_safe = True
-                if weight_name in [WEIGHTS_INDEX_NAME, SAFE_WEIGHTS_INDEX_NAME]:
-                    is_sharded = True
-                break
-
-    if resolved_archive_file is None:
-        raise EnvironmentError(f"Model name {model_name} was not found.")
-
-    if load_safe:
-        loader = partial(safe_load_file, device=mapped_device)
-    else:
-        loader = partial(torch.load, map_location=mapped_device)
-
-    if is_sharded:
-        # resolved_archive_file becomes a list of files that point to the different
-        # checkpoint shards in this case.
-        resolved_archive_file, sharded_metadata = get_checkpoint_shard_files(model_name, resolved_archive_file)
-        state_dict = {}
-        for sharded_file in resolved_archive_file:
-            state_dict.update(loader(sharded_file))
-    else:
-        state_dict = loader(resolved_archive_file)
-    # Convert dtype before moving to GPU to save memory
-    if dtype is not None:
-        state_dict = {k: v.to(dtype=dtype) for k, v in state_dict.items()}
-    state_dict = {k: v.to(device=device) for k, v in state_dict.items()}
-    return state_dict
-
-
-def filter_shapes(state_dict, model):
-    """
-    Filters the state dict to match the current model shape.
-    """
-    filtered_state_dict = {}
-    for key, value in state_dict.items():
-        if key in model.state_dict():
-            if value.shape == model.state_dict()[key].shape:
-                filtered_state_dict[key] = value
-    return filtered_state_dict
-
-
-def remap_bert_state_dict(
-    state_dict,
-    config,
-    remove_bert=False,
-    remove_cls_weights=False,
-    add_pooling_layer=False,
-):
-    """
-    Map the state_dict of a Huggingface BERT model to be flash_attn compatible.
-    """
-
-    def add_bert_prefix(key):
-        # prepend bert. to the key
-        if key.startswith("bert.") or key.startswith("cls."):
-            return key
-        return f"bert.{key}"
-
-    state_dict = OrderedDict((add_bert_prefix(k), v) for k, v in state_dict.items())
-
-    # LayerNorm
-    def key_mapping_ln_gamma_beta(key):
-        key = re.sub(r"LayerNorm.gamma$", "LayerNorm.weight", key)
-        key = re.sub(r"LayerNorm.beta$", "LayerNorm.bias", key)
-        return key
-
-    state_dict = OrderedDict((key_mapping_ln_gamma_beta(k), v) for k, v in state_dict.items())
-
-    # Layers
-    def key_mapping_layers(key):
-        return re.sub(r"^bert.encoder.layer\.", "bert.encoder.layers.", key)
-
-    state_dict = OrderedDict((key_mapping_layers(k), v) for k, v in state_dict.items())
-
-    # LayerNorm
-    def key_mapping_ln(key):
-        key = re.sub(r"^bert.embeddings.LayerNorm.", "bert.emb_ln.", key)
-        key = re.sub(
-            r"^bert.encoder.layers.(\d+).attention.output.LayerNorm.(weight|bias)",
-            r"bert.encoder.layers.\1.norm1.\2",
-            key,
-        )
-        key = re.sub(
-            r"^bert.encoder.layers.(\d+).output.LayerNorm.(weight|bias)",
-            r"bert.encoder.layers.\1.norm2.\2",
-            key,
-        )
-        key = re.sub(
-            r"^cls.predictions.transform.LayerNorm.(weight|bias)",
-            r"cls.predictions.transform.layer_norm.\1",
-            key,
-        )
-        return key
-
-    state_dict = OrderedDict((key_mapping_ln(k), v) for k, v in state_dict.items())
-
-    # MLP
-    def key_mapping_mlp(key):
-        key = re.sub(
-            r"^bert.encoder.layers.(\d+).intermediate.dense.(weight|bias)",
-            r"bert.encoder.layers.\1.mlp.fc1.\2",
-            key,
-        )
-        key = re.sub(
-            r"^bert.encoder.layers.(\d+).output.dense.(weight|bias)",
-            r"bert.encoder.layers.\1.mlp.fc2.\2",
-            key,
-        )
-        return key
-
-    state_dict = OrderedDict((key_mapping_mlp(k), v) for k, v in state_dict.items())
-
-    # Attention
-    last_layer_subset = getattr(config, "last_layer_subset", False)
-    for d in range(config.num_hidden_layers):
-        if f"bert.encoder.layers.{d}.attention.self.query.weight" not in state_dict:
-            continue
-        Wq = state_dict.pop(f"bert.encoder.layers.{d}.attention.self.query.weight")
-        Wk = state_dict.pop(f"bert.encoder.layers.{d}.attention.self.key.weight")
-        Wv = state_dict.pop(f"bert.encoder.layers.{d}.attention.self.value.weight")
-        bq = state_dict.pop(f"bert.encoder.layers.{d}.attention.self.query.bias")
-        bk = state_dict.pop(f"bert.encoder.layers.{d}.attention.self.key.bias")
-        bv = state_dict.pop(f"bert.encoder.layers.{d}.attention.self.value.bias")
-        if not (last_layer_subset and d == config.num_hidden_layers - 1):
-            state_dict[f"bert.encoder.layers.{d}.attn.Wqkv.weight"] = torch.cat([Wq, Wk, Wv], dim=0)
-            state_dict[f"bert.encoder.layers.{d}.attn.Wqkv.bias"] = torch.cat([bq, bk, bv], dim=0)
-        else:
-            state_dict[f"bert.encoder.layers.{d}.attn.Wq.weight"] = Wq
-            state_dict[f"bert.encoder.layers.{d}.attn.Wkv.weight"] = torch.cat([Wk, Wv], dim=0)
-            state_dict[f"bert.encoder.layers.{d}.attn.Wq.bias"] = bq
-            state_dict[f"bert.encoder.layers.{d}.attn.Wkv.bias"] = torch.cat([bk, bv], dim=0)
-
-    def key_mapping_attn(key):
-        return re.sub(
-            r"^bert.encoder.layers.(\d+).attention.output.dense.(weight|bias)",
-            r"bert.encoder.layers.\1.attn.out_proj.\2",
-            key,
-        )
-
-    state_dict = OrderedDict((key_mapping_attn(k), v) for k, v in state_dict.items())
-
-    def key_mapping_decoder_bias(key):
-        return re.sub(r"^cls.predictions.bias", "cls.predictions.decoder.bias", key)
-
-    # remove nsp weights, we don't use
-    state_dict.pop("cls.seq_relationship.weight", None)
-    state_dict.pop("cls.seq_relationship.bias", None)
-    state_dict.pop("bert.embeddings.position_ids", None)
-
-    state_dict = OrderedDict((key_mapping_decoder_bias(k), v) for k, v in state_dict.items())
-
-    if remove_cls_weights:
-        cls_weights = [
-            "cls.predictions.decoder.bias",
-            "cls.predictions.transform.dense.weight",
-            "cls.predictions.transform.dense.bias",
-            "cls.predictions.transform.layer_norm.weight",
-            "cls.predictions.transform.layer_norm.bias",
-            "cls.predictions.decoder.weight",
-        ]
-        for weight in cls_weights:
-            state_dict.pop(weight, None)
-
-    # Word embedding
-    pad_vocab_size_multiple = getattr(config, "pad_vocab_size_multiple", 1)
-    if pad_vocab_size_multiple > 1:
-        word_embeddings = state_dict["bert.embeddings.word_embeddings.weight"]
-        state_dict["bert.embeddings.word_embeddings.weight"] = F.pad(
-            word_embeddings, (0, 0, 0, config.vocab_size - word_embeddings.shape[0])
-        )
-        if not remove_cls_weights:
-            decoder_weight = state_dict["cls.predictions.decoder.weight"]
-            state_dict["cls.predictions.decoder.weight"] = F.pad(
-                decoder_weight, (0, 0, 0, config.vocab_size - decoder_weight.shape[0])
-            )
-            # If the vocab was padded, we want to set the decoder bias for those padded indices to be
-            # strongly negative (i.e. the decoder shouldn't predict those indices).
-            # TD [2022-05-09]: I don't think it affects the MLPerf training.
-            if "cls.predictions.decoder.bias" in state_dict:
-                decoder_bias = state_dict["cls.predictions.decoder.bias"]
-                state_dict["cls.predictions.decoder.bias"] = F.pad(
-                    decoder_bias, (0, config.vocab_size - decoder_bias.shape[0]), value=-100.0
-                )
-
-    if add_pooling_layer is False:
-        pooler_weights = [
-            "bert.pooler.dense.weight",
-            "bert.pooler.dense.bias",
-        ]
-        for key in pooler_weights:
-            state_dict.pop(key, None)
-
-    if remove_bert:
-
-        def remove_bert_prefix(key):
-            key = re.sub(r"^bert.", "", key)
-            return key
-
-        state_dict = OrderedDict((remove_bert_prefix(k), v) for k, v in state_dict.items())
-
-    return state_dict
-
-
-def _trunc_normal_(tensor, mean, std, a, b):
-    # Cut & paste from PyTorch official master until it's in a few official releases - RW
-    # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
-    def norm_cdf(x):
-        # Computes standard normal cumulative distribution function
-        return (1. + math.erf(x / math.sqrt(2.))) / 2.
-
-    if (mean < a - 2 * std) or (mean > b + 2 * std):
-        print(
-            "mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
-            "The distribution of values may be incorrect."
-        )
-
-    # Values are generated by using a truncated uniform distribution and
-    # then using the inverse CDF for the normal distribution.
-    # Get upper and lower cdf values
-    l = norm_cdf((a - mean) / std)
-    u = norm_cdf((b - mean) / std)
-
-    # Uniformly fill tensor with values from [l, u], then translate to
-    # [2l-1, 2u-1].
-    tensor.uniform_(2 * l - 1, 2 * u - 1)
-
-    # Use inverse cdf transform for normal distribution to get truncated
-    # standard normal
-    tensor.erfinv_()
-
-    # Transform to proper mean, std
-    tensor.mul_(std * math.sqrt(2.))
-    tensor.add_(mean)
-
-    # Clamp to ensure it's in the proper range
-    tensor.clamp_(min=a, max=b)
-    return tensor
-
-def trunc_normal_tf_(tensor, mean=0., std=1., a=-2., b=2.):
-    r"""Fills the input Tensor with values drawn from a truncated
-    normal distribution. The values are effectively drawn from the
-    normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)`
-    with values outside :math:`[a, b]` redrawn until they are within
-    the bounds. The method used for generating the random values works
-    best when :math:`a \leq \text{mean} \leq b`.
-
-    NOTE: this 'tf' variant behaves closer to Tensorflow / JAX impl where the
-    bounds [a, b] are applied when sampling the normal distribution with mean=0, std=1.0
-    and the result is subsequently scaled and shifted by the mean and std args.
-
-    Args:
-        tensor: an n-dimensional `torch.Tensor`
-        mean: the mean of the normal distribution
-        std: the standard deviation of the normal distribution
-        a: the minimum cutoff value
-        b: the maximum cutoff value
-    Examples:
-        >>> w = torch.empty(3, 5)
-        >>> nn.init.trunc_normal_(w)
-    """
-    with torch.no_grad():
-        _trunc_normal_(tensor, 0, 1.0, a, b)
-        tensor.mul_(std).add_(mean)
-    return tensor
-
-
-class NomicBertPreTrainedModel(PreTrainedModel):
-    """An abstract class to handle weights initialization and
-    a simple interface for downloading and loading pretrained models.
-    """
-
-    config_class = NomicBertConfig
-    base_model_prefix = "model"
-    supports_gradient_checkpointing = True
-    _no_split_modules = ["Block"]
-    _skip_keys_device_placement = "past_key_values"
-
-    def __init__(self, config, *inputs, **kwargs):
-        super().__init__(config)
-        if not isinstance(config, GPT2Config):
-            raise ValueError(
-                "Parameter config in `{}(config)` should be an instance of class `GPT2Config`. "
-                "To create a model from a Google pretrained model use "
-                "`model = {}.from_pretrained(PRETRAINED_MODEL_NAME)`".format(
-                    self.__class__.__name__, self.__class__.__name__
-                )
-            )
-        self.config = config
-
-    @classmethod
-    def from_pretrained(cls, model_name, config=None, *inputs, **kwargs):
-        """
-        Instantiate a NomicBertPreTrainedModel from a pre-trained model file or a pytorch state dict.
-        Download and cache the pre-trained model file if needed.
-
-        Params:
-            pretrained_model_name_or_path: either:
-                - a path or url to a pretrained model archive containing:
-                    . `bert_config.json` a configuration file for the model
-                    . `pytorch_model.bin` a PyTorch dump of a NomicBertForPretraining instance
-                - a path or url to a pretrained model archive containing:
-                    . `bert_config.json` a configuration file for the model
-                    . `model.chkpt` a TensorFlow checkpoint
-            *inputs, **kwargs: additional input for the specific NomicBert class
-                (ex: num_labels for NomicBertForSequenceClassification)
-        """
-        # Instantiate model.
-        if config is None:
-            config = cls.config_class.from_pretrained(model_name)
-        remove_cls = cls != NomicBertForPreTraining
-        remove_bert_prefix = cls != NomicBertForPreTraining and cls != NomicBertForSequenceClassification
-        ignore_mismatched_shapes = kwargs.pop("ignore_mismatched_sizes", False)
-        num_labels = kwargs.pop("num_labels", None)
-        rotary_scaling_factor = kwargs.pop("rotary_scaling_factor", None)
-        strict = kwargs.pop("strict", True)
-        if rotary_scaling_factor:
-            config.rotary_scaling_factor = rotary_scaling_factor
-
-        if config.n_positions <= 0 and config.rotary_emb_fraction > 0:
-            config.n_positions = 2048
-        if num_labels:
-            config.num_labels = num_labels
-
-        if "add_pooling_layer" in kwargs:
-            model = cls(config, *inputs, add_pooling_layer=kwargs.pop("add_pooling_layer"))
-        else:
-            if cls == NomicBertModel:
-                model = cls(config, *inputs, add_pooling_layer=False)
-            else:
-                model = cls(config, *inputs)
-        # TODO: fix this
-        # Assuming we know what we're doing when loading from disk
-        # Prob a bad assumption but i'm tired and want to train this asap
-        if os.path.exists(model_name):
-            model_path = f"{model_name}/pytorch_model.bin"
-            if os.path.exists(model_path):
-                state_dict = torch.load(f"{model_name}/pytorch_model.bin")
-            else:
-                model_path = f"{model_name}/model.safetensors"
-                if not os.path.exists(model_path):
-                    raise ValueError(f"Model path {model_path} not found")
-                state_dict = safe_load_file(model_path)
-
-            if ignore_mismatched_shapes:
-                state_dict = filter_shapes(state_dict, model)
-            load_return = model.load_state_dict(state_dict, strict=False)
-        else:
-            # TODO: can probably check config class and see if we need to remap from a bert model
-            state_dict = state_dict_from_pretrained(model_name)
-            state_dict = remap_bert_state_dict(
-                state_dict,
-                config,
-                remove_bert=remove_bert_prefix,
-                remove_cls_weights=remove_cls,
-                add_pooling_layer=getattr(config, "add_pooling_layer", False),
-            )
-            if ignore_mismatched_shapes:
-                state_dict = filter_shapes(state_dict, model)
-
-            load_return = model.load_state_dict(state_dict, strict=strict)
-        logger.warning(load_return)
-        return model
-
-    def _set_gradient_checkpointing(self, module, value=False):
-        if isinstance(module, NomicBertEncoder):
-            module.gradient_checkpointing = value
-
-
-# https://github.com/huggingface/transformers/blob/7032e0203262ebb2ebf55da8d2e01f873973e835/src/transformers/models/bert/modeling_bert.py#L748
-def _init_weights(module, initializer_range=0.02):
-    if isinstance(module, nn.Linear):
-        nn.init.normal_(module.weight, std=initializer_range)
-        if module.bias is not None:
-            nn.init.zeros_(module.bias)
-    elif isinstance(module, nn.Embedding):
-        nn.init.normal_(module.weight, std=initializer_range)
-        if module.padding_idx is not None:
-            nn.init.zeros_(module.weight[module.padding_idx])
-
-def _ntuple(n):
-    def parse(x):
-        if isinstance(x, collections.abc.Iterable) and not isinstance(x, str):
-            return tuple(x)
-        return tuple(repeat(x, n))
-    return parse
-
-
-to_1tuple = _ntuple(1)
-to_2tuple = _ntuple(2)
-to_3tuple = _ntuple(3)
-to_4tuple = _ntuple(4)
-to_ntuple = _ntuple
-
-
-def get_2d_sincos_pos_embed(embed_dim, grid_size, add_cls_token=False):
-    """
-    Create 2D sin/cos positional embeddings.
-
-    Args:
-        embed_dim (`int`):
-            Embedding dimension.
-        grid_size (`int`):
-            The grid height and width.
-        add_cls_token (`bool`, *optional*, defaults to `False`):
-            Whether or not to add a classification (CLS) token.
-
-    Returns:
-        (`torch.FloatTensor` of shape (grid_size*grid_size, embed_dim) or (1+grid_size*grid_size, embed_dim): the
-        position embeddings (with or without classification token)
-    """
-    grid_h = np.arange(grid_size, dtype=np.float32)
-
-    grid_w = np.arange(grid_size, dtype=np.float32)
-    grid = np.meshgrid(grid_w, grid_h)  # here w goes first
-    grid = np.stack(grid, axis=0)
-
-    grid = grid.reshape([2, 1, grid_size, grid_size])
-    pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
-    if add_cls_token:
-        pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0)
-    return pos_embed
-
-
-def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
-    if embed_dim % 2 != 0:
-        raise ValueError("embed_dim must be even")
-
-    # use half of dimensions to encode grid_h
-    emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0])  # (H*W, D/2)
-    emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1])  # (H*W, D/2)
-
-    emb = np.concatenate([emb_h, emb_w], axis=1)  # (H*W, D)
-    return emb
-
-
-def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
-    """
-    embed_dim: output dimension for each position pos: a list of positions to be encoded: size (M,) out: (M, D)
-    """
-    if embed_dim % 2 != 0:
-        raise ValueError("embed_dim must be even")
-
-    omega = np.arange(embed_dim // 2, dtype=float)
-    omega /= embed_dim / 2.0
-    omega = 1.0 / 10000**omega  # (D/2,)
-
-    pos = pos.reshape(-1)  # (M,)
-    out = np.einsum("m,d->md", pos, omega)  # (M, D/2), outer product
-
-    emb_sin = np.sin(out)  # (M, D/2)
-    emb_cos = np.cos(out)  # (M, D/2)
-
-    emb = np.concatenate([emb_sin, emb_cos], axis=1)  # (M, D)
-    return emb
-
-def ndgrid(*tensors) -> Tuple[torch.Tensor, ...]:
-    """generate N-D grid in dimension order.
-
-    The ndgrid function is like meshgrid except that the order of the first two input arguments are switched.
-
-    That is, the statement
-    [X1,X2,X3] = ndgrid(x1,x2,x3)
-
-    produces the same result as
-
-    [X2,X1,X3] = meshgrid(x2,x1,x3)
-
-    This naming is based on MATLAB, the purpose is to avoid confusion due to torch's change to make
-    torch.meshgrid behaviour move from matching ndgrid ('ij') indexing to numpy meshgrid defaults of ('xy').
-
-    """
-    try:
-        return torch.meshgrid(*tensors, indexing='ij')
-    except TypeError:
-        # old PyTorch < 1.10 will follow this path as it does not have indexing arg,
-        # the old behaviour of meshgrid was 'ij'
-        return torch.meshgrid(*tensors)
-
-def build_fourier_pos_embed(
-    feat_shape: List[int],
-    bands: Optional[torch.Tensor] = None,
-    num_bands: int = 64,
-    max_res: int = 224,
-    temperature: float = 10000.,
-    linear_bands: bool = False,
-    include_grid: bool = False,
-    in_pixels: bool = True,
-    ref_feat_shape: Optional[List[int]] = None,
-    dtype: torch.dtype = torch.float32,
-    device: Optional[torch.device] = None,
-) -> List[torch.Tensor]:
-    """
-
-    Args:
-        feat_shape: Feature shape for embedding.
-        bands: Pre-calculated frequency bands.
-        num_bands: Number of frequency bands (determines output dim).
-        max_res: Maximum resolution for pixel based freq.
-        temperature: Temperature for non-pixel freq.
-        linear_bands: Linear band spacing for pixel based freq.
-        include_grid: Include the spatial grid in output.
-        in_pixels: Output in pixel freq.
-        ref_feat_shape: Reference feature shape for resize / fine-tune.
-        dtype: Output dtype.
-        device: Output device.
-
-    Returns:
-
-    """
-    if bands is None:
-        if in_pixels:
-            bands = pixel_freq_bands(
-                num_bands,
-                float(max_res),
-                linear_bands=linear_bands,
-                device=device,
-            )
-        else:
-            bands = freq_bands(
-                num_bands,
-                temperature=temperature,
-                step=1,
-                device=device,
-            )
-    else:
-        if device is None:
-            device = bands.device
-        if dtype is None:
-            dtype = bands.dtype
-
-    if in_pixels:
-        t = [torch.linspace(-1., 1., steps=s, device=device, dtype=torch.float32) for s in feat_shape]
-    else:
-        t = [torch.arange(s, device=device, dtype=torch.int64).to(torch.float32) for s in feat_shape]
-
-    if ref_feat_shape is not None:
-        # eva's scheme for resizing rope embeddings (ref shape = pretrain)
-        t = [x / f * r for x, f, r in zip(t, feat_shape, ref_feat_shape)]
-
-    grid = torch.stack(ndgrid(t), dim=-1)
-    grid = grid.unsqueeze(-1)
-    pos = grid * bands
-
-    pos_sin, pos_cos = pos.sin().to(dtype=dtype), pos.cos().to(dtype)
-    out = [grid, pos_sin, pos_cos] if include_grid else [pos_sin, pos_cos]
-    return out
-
-
-def build_rotary_pos_embed(
-    feat_shape: List[int],
-    bands: Optional[torch.Tensor] = None,
-    dim: int = 64,
-    max_res: int = 224,
-    temperature: float = 10000.,
-    linear_bands: bool = False,
-    in_pixels: bool = True,
-    ref_feat_shape: Optional[List[int]] = None,
-    dtype: torch.dtype = torch.float32,
-    device: Optional[torch.device] = None,
-):
-    """
-
-    Args:
-        feat_shape: Spatial shape of the target tensor for embedding.
-        bands: Optional pre-generated frequency bands
-        dim: Output dimension of embedding tensor.
-        max_res: Maximum resolution for pixel mode.
-        temperature: Temperature (inv freq) for non-pixel mode
-        linear_bands: Linearly (instead of log) spaced bands for pixel mode
-        in_pixels: Pixel vs language (inv freq) mode.
-        dtype: Output dtype.
-        device: Output device.
-
-    Returns:
-
-    """
-    sin_emb, cos_emb = build_fourier_pos_embed(
-        feat_shape,
-        bands=bands,
-        num_bands=dim // 4,
-        max_res=max_res,
-        temperature=temperature,
-        linear_bands=linear_bands,
-        in_pixels=in_pixels,
-        ref_feat_shape=ref_feat_shape,
-        device=device,
-        dtype=dtype,
-    )
-    num_spatial_dim = 1
-    # this would be much nicer as a .numel() call to torch.Size(), but torchscript sucks
-    for x in feat_shape:
-        num_spatial_dim *= x
-    sin_emb = sin_emb.reshape(num_spatial_dim, -1).repeat_interleave(2, -1)
-    cos_emb = cos_emb.reshape(num_spatial_dim, -1).repeat_interleave(2, -1)
-    return sin_emb, cos_emb
-
-def freq_bands(
-    num_bands: int,
-    temperature: float = 10000.,
-    step: int = 2,
-    device: Optional[torch.device] = None,
-) -> torch.Tensor:
-    exp = torch.arange(0, num_bands, step, dtype=torch.int64, device=device).to(torch.float32) / num_bands
-    bands = 1. / (temperature ** exp)
-    return bands
-
-
-def pixel_freq_bands(
-    num_bands: int,
-    max_freq: float = 224.,
-    linear_bands: bool = True,
-    device: Optional[torch.device] = None,
-):
-    if linear_bands:
-        bands = torch.linspace(1.0, max_freq / 2, num_bands, dtype=torch.float32, device=device)
-    else:
-        bands = 2 ** torch.linspace(0, math.log(max_freq, 2) - 1, num_bands, dtype=torch.float32, device=device)
-    return bands * torch.pi
-
-def rot(x):
-    return torch.stack([-x[..., 1::2], x[..., ::2]], -1).reshape(x.shape)
-
-def apply_rot_embed_cat(x: torch.Tensor, emb):
-    sin_emb, cos_emb = emb.tensor_split(2, -1)
-    if sin_emb.ndim == 3:
-        return x * cos_emb.unsqueeze(1).expand_as(x) + rot(x) * sin_emb.unsqueeze(1).expand_as(x)
-    return x * cos_emb + rot(x) * sin_emb
-
-# taken from https://github.com/huggingface/pytorch-image-models/blob/cb0e4391beedcc5ac3ae4bce16561b95c326f32c/timm/layers/pos_embed_sincos.py#L363
-class NomicVisionRotaryEmbeddingCat(nn.Module):
-    """ Rotary position embedding w/ concatenated sin & cos
-
-    The following impl/resources were referenced for this impl:
-    * https://github.com/lucidrains/vit-pytorch/blob/6f3a5fcf0bca1c5ec33a35ef48d97213709df4ba/vit_pytorch/rvt.py
-    * https://blog.eleuther.ai/rotary-embeddings/
-    """
-
-    def __init__(
-        self,
-        dim,
-        max_res=224,
-        temperature=10000,
-        in_pixels=True,
-        linear_bands: bool = False,
-        feat_shape: Optional[List[int]] = None,
-        ref_feat_shape: Optional[List[int]] = None,
-    ):
-        super().__init__()
-        self.dim = dim
-        self.max_res = max_res
-        self.temperature = temperature
-        self.in_pixels = in_pixels
-        self.feat_shape = feat_shape
-        self.ref_feat_shape = ref_feat_shape
-
-        if feat_shape is None:
-            # only cache bands
-            if in_pixels:
-                bands = pixel_freq_bands(
-                    dim // 4,
-                    float(max_res),
-                    linear_bands=linear_bands,
-                )
-            else:
-                bands = freq_bands(
-                    dim // 4,
-                    temperature=temperature,
-                    step=1,
-                )
-            self.register_buffer(
-                'bands',
-                bands,
-                persistent=False,
-            )
-            self.pos_embed = None
-        else:
-            # cache full sin/cos embeddings if shape provided up front
-            embeds = build_rotary_pos_embed(
-                feat_shape=feat_shape,
-                dim=dim,
-                max_res=max_res,
-                linear_bands=linear_bands,
-                in_pixels=in_pixels,
-                ref_feat_shape=self.ref_feat_shape,
-            )
-            self.bands = None
-            self.register_buffer(
-                'pos_embed',
-                torch.cat(embeds, -1),
-                persistent=False,
-            )
-
-    def get_embed(self, shape: Optional[List[int]] = None):
-        if self.bands is not None and shape is not None:
-            # rebuild embeddings every call, use if target shape changes
-            embeds = build_rotary_pos_embed(
-                shape,
-                self.bands,
-                in_pixels=self.in_pixels,
-                ref_feat_shape=self.ref_feat_shape,
-            )
-            return torch.cat(embeds, -1)
-        elif self.pos_embed is not None:
-            return self.pos_embed
-        else:
-            assert False, "get_embed() requires pre-computed pos_embed or valid shape w/ pre-computed bands"
-
-    def forward(self, x):
-        # assuming channel-first tensor where spatial dim are >= 2
-        pos_embed = self.get_embed(x.shape[2:])
-        return apply_rot_embed_cat(x, pos_embed)
-
-class NomicVisionPatchEmbeddings(nn.Module):
-    def __init__(
-        self,
-        config,
-    ):
-        super().__init__()
-        img_size = _pair(config.img_size)
-        patch_size = _pair(config.patch_size)
-        self.img_size = img_size
-        self.patch_size = patch_size
-        self.grid_size = (img_size[0] // patch_size[0], img_size[1] // patch_size[1])
-        self.num_patches = self.grid_size[0] * self.grid_size[1]
-
-        self.proj = nn.Linear(
-            config.num_channels * patch_size[0] * patch_size[1], config.n_embd, bias=config.patch_embed_bias
-        )
-
-        self.learned_pos_embedding = False
-        self.sinusoidal_pos_embedding = False
-        self.no_embed_class = getattr(config, "no_embed_class", False)
-
-        self.cls_token = nn.Parameter(torch.zeros(1, 1, config.n_embd)) if not getattr(config, "no_cls_token", False) else None
-        if config.learned_pos_embedding:
-            # this is the default in DINO
-            self.learned_pos_embedding = True
-            # hack for timm dinov2 with registers
-            num_patches = self.num_patches if getattr(config, "register_tokens", 0) > 0 else self.num_patches + 1
-            self.pos_embed = nn.Parameter(torch.randn(1, num_patches, config.n_embd) * 0.02) if getattr(config, "use_pos_embed", True) else None
-        elif getattr(config, "sinusoidal_pos_embedding", False):
-            self.sinusoidal_pos_embedding = True
-            if getattr(config, "use_pos_embed", True):
-                self.pos_embed = nn.Parameter(torch.zeros(1, self.num_patches + 1, config.n_embd), requires_grad=False)
-                pos_embed = get_2d_sincos_pos_embed(config.n_embd, self.grid_size[0], add_cls_token=True)
-                self.pos_embed.data.copy_(torch.from_numpy(pos_embed).to(self.pos_embed))
-            else:
-                self.pos_embed = None
-        else:
-            self.pos_embed = nn.Parameter(torch.randn(1, self.num_patches + 1, config.n_embd) * 0.02) if getattr(config, "use_pos_embed", True) else None
-
-        if getattr(config, "register_tokens", 0) > 0:
-            self.reg_token = nn.Parameter(torch.randn(1, config.register_tokens, config.n_embd) * 0.02)
-        else:
-            self.reg_token = None
-
-        if config.mask_token:
-            self.mask_token = nn.Parameter(torch.zeros(1, config.n_embd))
-
-        self.patch_dropout = nn.Identity()
-
-        if getattr(config, "use_rotary_pos_emb", False):
-            ref_feat_shape = getattr(config, "ref_feat_shape", None)
-            ref_feat_shape = to_2tuple(ref_feat_shape) if ref_feat_shape is not None else None
-            self.rope = NomicVisionRotaryEmbeddingCat(
-                config.n_embd // config.n_head,
-                in_pixels=False,
-                feat_shape=self.grid_size,
-                ref_feat_shape=ref_feat_shape,
-            )
-        else:
-            self.rope = None
-
-
-    def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor:
-        """
-        This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher
-        resolution images.
-
-        Source:
-        https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174
-        """
-        num_patches = embeddings.shape[1] - 1
-        num_positions = self.pos_embed.shape[1] - 1
-        if num_patches == num_positions and height == width:
-            return self.pos_embed
-        class_pos_embed = self.pos_embed[:, 0]
-        patch_pos_embed = self.pos_embed[:, 1:]
-        dim = embeddings.shape[-1]
-        height = height // self.patch_size[0]
-        width = width // self.patch_size[1]
-        # we add a small number to avoid floating point error in the interpolation
-        # see discussion at https://github.com/facebookresearch/dino/issues/8
-        height, width = height + 0.1, width + 0.1
-        patch_pos_embed = patch_pos_embed.reshape(1, int(math.sqrt(num_positions)), int(math.sqrt(num_positions)), dim)
-        patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2)
-        patch_pos_embed = nn.functional.interpolate(
-            patch_pos_embed,
-            scale_factor=(height / math.sqrt(num_positions), width / math.sqrt(num_positions)),
-            mode="bicubic",
-            align_corners=False,
-        )
-        if int(height) != patch_pos_embed.shape[-2] or int(width) != patch_pos_embed.shape[-1]:
-            raise ValueError("Width or height does not match with the interpolated position embeddings")
-        patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
-        return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1)
-
-    def forward(self, x):
-        # deepspeed case where the input is in fp32
-        if x.dtype != self.proj.weight.dtype:
-            x = x.to(dtype=self.proj.weight.dtype)
-
-        _, _, height, width = x.shape
-        x = self.proj(
-            rearrange(
-                x,
-                "b c (h p1) (w p2) -> b h w (c p1 p2)",
-                p1=self.patch_size[0],
-                p2=self.patch_size[1],
-            )
-        )
-        embeddings = rearrange(x, "b h w c -> b (h w) c")
-
-        to_cat = []
-        if self.cls_token is not None:
-            if self.sinusoidal_pos_embedding:
-                cls_token = self.cls_token + self.pos_embed[:, 0]
-                cls_token = cls_token.expand(embeddings.shape[0], -1, -1)
-                to_cat += [cls_token]
-            else:
-                cls_token = self.cls_token.expand(embeddings.shape[0], 1, -1)
-                to_cat += [cls_token]
-
-        if self.reg_token is not None:
-            to_cat += [self.reg_token.expand(embeddings.shape[0], -1, -1)]
-
-        rot_pos_embed = self.rope.get_embed() if self.rope is not None else None
-
-        if self.no_embed_class:
-            if self.learned_pos_embedding:
-                embeddings = embeddings + self.interpolate_pos_encoding(embeddings, height, width)
-            else:
-                if self.pos_embed is not None:
-                    embeddings = embeddings + self.pos_embed
-            if to_cat:
-                embeddings = torch.cat(to_cat + [embeddings], dim=1)
-        else:
-            if to_cat:
-                embeddings = torch.cat(to_cat + [embeddings], dim=1)
-            if self.learned_pos_embedding:
-                if self.pos_embed is not None:
-                    embeddings = embeddings + self.interpolate_pos_encoding(embeddings, height, width)
-            else:
-                if self.pos_embed is not None:
-                    embeddings = embeddings + self.pos_embed
-
-        embeddings = self.patch_dropout(embeddings)
-
-        return embeddings, rot_pos_embed
-
-
-class NomicBertEmbeddings(nn.Module):
-    def __init__(self, config):
-        """
-        If max_position_embeddings <= 0, there's no position embeddings
-        If type_vocab_size <= 0, there's no token type embeddings
-        """
-        super().__init__()
-        self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
-        self.max_position_embeddings = config.max_position_embeddings if config.rotary_emb_fraction <= 0 else 0
-        self.type_vocab_size = config.type_vocab_size
-        if self.max_position_embeddings > 0 and config.rotary_emb_fraction <= 0:
-            self.position_embeddings = nn.Embedding(
-                config.max_position_embeddings,
-                config.hidden_size,
-            )
-        if self.type_vocab_size > 0:
-            self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size)
-
-    def forward(self, input_ids, position_ids=None, token_type_ids=None):
-        """
-        input_ids: (batch, seqlen)
-        position_ids: (batch, seqlen)
-        token_type_ids: (batch, seqlen)
-        """
-        batch_size, seqlen = input_ids.shape
-        embeddings = self.word_embeddings(input_ids)
-
-        if self.type_vocab_size > 0:
-            if token_type_ids is None:
-                token_type_ids = torch.zeros(seqlen, dtype=torch.long, device=input_ids.device)
-            token_type_embeddings = self.token_type_embeddings(token_type_ids)
-            embeddings = embeddings + token_type_embeddings
-
-        if self.max_position_embeddings > 0:
-            if position_ids is None:
-                position_ids = torch.arange(seqlen, dtype=torch.long, device=input_ids.device)
-            position_embeddings = self.position_embeddings(position_ids)
-            embeddings = embeddings + position_embeddings
-        return embeddings
-
-
-class NomicBertMLP(nn.Module):
-    def __init__(
-        self,
-        in_features,
-        hidden_features=None,
-        out_features=None,
-        activation=F.gelu,
-        bias1=True,
-        bias2=True,
-        return_residual=False,
-        fused_bias_fc=False,
-    ):
-        super().__init__()
-        out_features = out_features if out_features is not None else in_features
-        hidden_features = hidden_features if hidden_features is not None else in_features * 4
-        self.return_residual = return_residual
-        self.fc1 = nn.Linear(in_features, hidden_features, bias=bias1)
-        approximate = "tanh" if activation in ["gelu_new", "gelu_fast", "gelu_pytorch_tanh"] else "none"
-        self.activation = nn.GELU(approximate=approximate) if activation == "gelu" else activation
-        self.fc2 = nn.Linear(hidden_features, out_features, bias=bias2)
-
-    def forward(self, x):
-        y = self.fc1(x)
-        y = self.activation(y)
-        y = self.fc2(y)
-        return y if not self.return_residual else (y, x)
-
-
-class NomciBertGatedMLP(nn.Module):
-    def __init__(
-        self,
-        in_features,
-        hidden_features=None,
-        out_features=None,
-        activation=F.sigmoid,
-        bias1=True,
-        bias2=True,
-        multiple_of=256,
-        return_residual=False,
-        fused_bias_fc=True,
-        device=None,
-        dtype=None,
-        norm_layer=False,
-    ):
-        super().__init__()
-        out_features = out_features if out_features is not None else in_features
-        hidden_features = hidden_features if hidden_features is not None else int(8 * in_features / 3)
-        hidden_features = int((hidden_features + multiple_of - 1) // multiple_of * multiple_of)
-        self.return_residual = return_residual
-
-        self.fc11 = nn.Linear(in_features, hidden_features, bias=bias1)
-        self.fc12 = nn.Linear(in_features, hidden_features, bias=bias1)
-        self.activation = activation
-        self.fc2 = nn.Linear(hidden_features, out_features, bias=bias2)
-        self.norm = nn.LayerNorm(hidden_features) if norm_layer else nn.Identity()
-
-    def forward(self, x):
-        y = self.fc11(x)
-        gate = self.fc12(x)
-        if self.activation == F.sigmoid:  # Special case for GLU
-            y = F.glu(torch.cat([y, gate], dim=-1), dim=-1)
-        else:
-            y = y * self.activation(gate)
-
-        # eva uses layer norm after the activation
-        y = self.norm(y)
-
-        y = self.fc2(y)
-        return y if not self.return_residual else (y, x)
-
-
-def rotate_half(x, interleaved=False):
-    if not interleaved:
-        x1, x2 = x.chunk(2, dim=-1)
-        return torch.cat((-x2, x1), dim=-1)
-    else:
-        x1, x2 = x[..., ::2], x[..., 1::2]
-        return rearrange(torch.stack((-x2, x1), dim=-1), "... d two -> ... (d two)", two=2)
-
-
-def apply_rotary_emb(x, cos, sin, offset=0, interleaved=False):
-    """
-    x: (batch_size, seqlen, nheads, headdim)
-    cos, sin: (seqlen, rotary_dim / 2) or (batch_size, seqlen, rotary_dim / 2)
-    """
-    ro_dim = cos.shape[-1] * 2
-    assert ro_dim <= x.shape[-1]
-    cos, sin = (
-        cos[offset : offset + x.shape[1]],
-        sin[offset : offset + x.shape[1]],
-    )
-    cos = repeat(cos, "... d -> ... 1 (2 d)" if not interleaved else "... d -> ... 1 (d 2)")
-    sin = repeat(sin, "... d -> ... 1 (2 d)" if not interleaved else "... d -> ... 1 (d 2)")
-    return torch.cat(
-        [x[..., :ro_dim] * cos + rotate_half(x[..., :ro_dim], interleaved) * sin, x[..., ro_dim:]],
-        dim=-1,
-    )
-
-
-class NomicBertRotaryEmbedding(nn.Module):
-    def __init__(
-        self,
-        dim: int,
-        base=10000.0,
-        interleaved=False,
-        scale_base=None,
-        pos_idx_in_fp32=True,
-        device=None,
-    ):
-        """
-        interleaved: if True, rotate pairs of even and odd dimensions (GPT-J style) instead
-            of 1st half and 2nd half (GPT-NeoX style).
-        pos_idx_in_fp32: if True, the position indices [0.0, ..., seqlen - 1] are in fp32,
-            otherwise they might be in lower precision.
-            This option was added because previously (before 2023-07-02), when we construct
-            the position indices, we use the dtype of self.inv_freq. In most cases this would
-            be fp32, but if the model is trained in pure bf16 (not mixed precision), then
-            self.inv_freq would be bf16, and the position indices are also in bf16.
-            Because of the limited precision of bf16 (e.g. 1995.0 is rounded to 2000.0), the
-            embeddings for some positions will coincide.
-            To maintain compatibility with models previously trained in pure bf16,
-            we add this option.
-        """
-        super().__init__()
-        self.dim = dim
-        self.base = float(base)
-        self.pos_idx_in_fp32 = pos_idx_in_fp32
-        # Generate and save the inverse frequency buffer (non trainable)
-        inv_freq = self._compute_inv_freq(device)
-        self.register_buffer("inv_freq", inv_freq, persistent=False)
-        self.interleaved = interleaved
-        self.scale_base = scale_base
-        scale = (
-            (torch.arange(0, dim, 2, device=device, dtype=torch.float32) + 0.4 * dim) / (1.4 * dim)
-            if scale_base is not None
-            else None
-        )
-        self.register_buffer("scale", scale, persistent=False)
-
-        self._seq_len_cached = 0
-        self._cos_cached = None
-        self._sin_cached = None
-        self._cos_k_cached = None
-        self._sin_k_cached = None
-
-    def _compute_inv_freq(self, device=None):
-        return 1.0 / (self.base ** (torch.arange(0, self.dim, 2, device=device, dtype=torch.float32) / self.dim))
-
-    def _update_cos_sin_cache(self, seqlen, device=None, dtype=None):
-        # Reset the tables if the sequence length has changed,
-        # if we're on a new device (possibly due to tracing for instance),
-        # or if we're switching from inference mode to training
-        if (
-            seqlen > self._seq_len_cached
-            or self._cos_cached is None
-            or self._cos_cached.device != device
-            or self._cos_cached.dtype != dtype
-            or (self.training and self._cos_cached.is_inference())
-        ):
-            self._seq_len_cached = seqlen
-            # We want fp32 here, not self.inv_freq.dtype, since the model could be loaded in bf16
-            # And the output of arange can be quite large, so bf16 would lose a lot of precision.
-            # However, for compatibility reason, we add an option to use the dtype of self.inv_freq.
-            if self.pos_idx_in_fp32:
-                t = torch.arange(seqlen, device=device, dtype=torch.float32)
-                # We want fp32 here as well since inv_freq will be multiplied with t, and the output
-                # will be large. Having it in bf16 will lose a lot of precision and cause the
-                # cos & sin output to change significantly.
-                # We want to recompute self.inv_freq if it was not loaded in fp32
-                if self.inv_freq.dtype != torch.float32:
-                    inv_freq = self._compute_inv_freq(device=device)
-                else:
-                    inv_freq = self.inv_freq
-            else:
-                t = torch.arange(seqlen, device=device, dtype=self.inv_freq.dtype)
-                inv_freq = self.inv_freq
-            # Don't do einsum, it converts fp32 to fp16 under AMP
-            # freqs = torch.einsum("i,j->ij", t, self.inv_freq)
-            freqs = torch.outer(t, inv_freq)
-            self._cos_cached = torch.cos(freqs).to(dtype)
-            self._sin_cached = torch.sin(freqs).to(dtype)
-
-    def forward(
-        self,
-        qkv: torch.Tensor,
-        kv: Optional[torch.Tensor] = None,
-        seqlen_offset: Union[int, torch.Tensor] = 0,
-        max_seqlen: Optional[int] = None,
-    ) -> Tuple[torch.Tensor, torch.Tensor]:
-        """
-        qkv: (batch, seqlen, 3, nheads, headdim) if kv is none,
-            else it's just q of shape (batch, seqlen, nheads, headdim)
-        kv: (batch, seqlen, 2, nheads, headdim)
-        seqlen_offset: (batch_size,) or int. Each sequence in x is shifted by this amount.
-            Most commonly used in inference when we have KV cache.
-            If it's a tensor of shape (batch_size,), then to update the cos / sin cache, one
-            should pass in max_seqlen, which will update the cos / sin cache up to that length.
-        Apply rotary embedding *inplace* to qkv and / or kv.
-        """
-        seqlen = qkv.shape[1]
-        if seqlen > self._seq_len_cached:
-            self._update_cos_sin_cache(seqlen, device=qkv.device, dtype=qkv.dtype)
-        elif max_seqlen is not None:
-            self._update_cos_sin_cache(max_seqlen, device=qkv.device, dtype=qkv.dtype)
-        elif isinstance(seqlen_offset, int):
-            self._update_cos_sin_cache(seqlen + seqlen_offset, device=qkv.device, dtype=qkv.dtype)
-
-        q_rot = apply_rotary_emb(qkv[:, :, 0], self._cos_cached, self._sin_cached, seqlen_offset, self.interleaved)
-        k_rot = apply_rotary_emb(qkv[:, :, 1], self._cos_cached, self._sin_cached, seqlen_offset, self.interleaved)
-        return torch.stack((q_rot, k_rot, qkv[:, :, 2]), dim=2)
-
-
-class NomicBertDynamicNTKRotaryEmbedding(NomicBertRotaryEmbedding):
-    def __init__(self, rotary_scaling_factor, max_position_embeddings, **kwargs):
-        super().__init__(**kwargs)
-        self.rotary_scaling_factor = rotary_scaling_factor
-        self.max_position_embeddings = max_position_embeddings
-
-    def _compute_inv_freq(self, base=None, device=None):
-        if base is None:
-            base = self.base
-        return 1.0 / (base ** (torch.arange(0, self.dim, 2, device=device, dtype=torch.float32) / self.dim))
-
-    def _update_cos_sin_cache(self, seqlen, device=None, dtype=None):
-        # Reset the tables if the sequence length has changed,
-        # if we're on a new device (possibly due to tracing for instance),
-        # or if we're switching from inference mode to training
-        if seqlen > self.max_position_embeddings:
-            base = self.base * (
-                (self.rotary_scaling_factor * seqlen / self.max_position_embeddings) - (self.rotary_scaling_factor - 1)
-            ) ** (self.dim / (self.dim - 2))
-            inv_freq = self._compute_inv_freq(base=base, device=device)
-            self.register_buffer("inv_freq", inv_freq, persistent=False)
-
-        if (
-            seqlen > self._seq_len_cached
-            or self._cos_cached is None
-            or self._cos_cached.device != device
-            or self._cos_cached.dtype != dtype
-            or (self.training and self._cos_cached.is_inference())
-        ):
-            self._seq_len_cached = seqlen
-            # We want fp32 here, not self.inv_freq.dtype, since the model could be loaded in bf16
-            # And the output of arange can be quite large, so bf16 would lose a lot of precision.
-            # However, for compatibility reason, we add an option to use the dtype of self.inv_freq.
-            if self.pos_idx_in_fp32:
-                t = torch.arange(seqlen, device=device, dtype=torch.float32)
-                # We want fp32 here as well since inv_freq will be multiplied with t, and the output
-                # will be large. Having it in bf16 will lose a lot of precision and cause the
-                # cos & sin output to change significantly.
-                # We want to recompute self.inv_freq if it was not loaded in fp32
-                if self.inv_freq.dtype != torch.float32:
-                    if seqlen > self.max_position_embeddings:
-                        base = self.base * (
-                            (self.rotary_scaling_factor * seqlen / self.max_position_embeddings) - (self.rotary_scaling_factor - 1)
-                        ) ** (self.dim / (self.dim - 2))
-                    else:
-                        base = self.base
-                    inv_freq = self._compute_inv_freq(device=device, base=base)
-                else:
-                    inv_freq = self.inv_freq
-            else:
-                t = torch.arange(seqlen, device=device, dtype=self.inv_freq.dtype)
-                inv_freq = self.inv_freq
-            # Don't do einsum, it converts fp32 to fp16 under AMP
-            # freqs = torch.einsum("i,j->ij", t, self.inv_freq)
-            freqs = torch.outer(t, inv_freq)
-            if self.scale is None:
-                self._cos_cached = torch.cos(freqs).to(dtype)
-                self._sin_cached = torch.sin(freqs).to(dtype)
-            else:
-                power = (
-                    torch.arange(seqlen, dtype=self.scale.dtype, device=self.scale.device) - seqlen // 2
-                ) / self.scale_base
-                scale = self.scale.to(device=power.device) ** rearrange(power, "s -> s 1")
-                # We want the multiplication by scale to happen in fp32
-                self._cos_cached = (torch.cos(freqs) * scale).to(dtype)
-                self._sin_cached = (torch.sin(freqs) * scale).to(dtype)
-                self._cos_k_cached = (torch.cos(freqs) / scale).to(dtype)
-                self._sin_k_cached = (torch.sin(freqs) / scale).to(dtype)
-
-
-class NomicBertAttention(nn.Module):
-    """Multi-head self-attention and cross-attention"""
-
-    def __init__(
-        self,
-        config,
-    ) -> None:
-        """
-        num_heads_kv: can be used to toggle MQA / GQA. If None, use num_heads.
-        return_residual: whether to return the input x along with the output. This is for
-            performance reason: for post-norm architecture, returning the input allows us
-            to fuse the backward of nn.Linear with the residual connection.
-        """
-        super().__init__()
-        self.embed_dim = config.n_embd
-        self.use_flash_attn = config.use_flash_attn
-        self.fused_bias_fc = config.fused_bias_fc
-
-        self.num_heads = config.n_head
-        self.num_heads_kv = config.num_heads_kv if getattr(config, "num_heads_kv", None) is not None else self.num_heads
-        assert self.embed_dim % self.num_heads == 0, "embed_dim must be divisible by num_heads"
-        self.head_dim = self.embed_dim // self.num_heads
-        # we don't really support mqa / gqa for now
-        qkv_dim = self.head_dim * (self.num_heads + 2 * self.num_heads_kv)
-
-        self.register_buffer(
-            "norm_factor",
-            torch.sqrt(torch.tensor(self.head_dim, dtype=torch.float32)).to(torch.get_default_dtype()),
-            persistent=False,
-        )
-
-        self.rotary_emb_dim = self.head_dim * config.rotary_emb_fraction
-        if self.rotary_emb_dim > 0:
-            if getattr(config, "rotary_scaling_factor", None):
-                self.rotary_emb = NomicBertDynamicNTKRotaryEmbedding(
-                    dim=self.rotary_emb_dim,
-                    base=config.rotary_emb_base,
-                    scale_base=config.rotary_emb_scale_base,
-                    interleaved=config.rotary_emb_interleaved,
-                    rotary_scaling_factor=config.rotary_scaling_factor,
-                    max_position_embeddings=config.max_trained_positions,
-                )
-            else:
-                self.rotary_emb = NomicBertRotaryEmbedding(
-                    dim=self.rotary_emb_dim,
-                    base=config.rotary_emb_base,
-                    scale_base=config.rotary_emb_scale_base,
-                    interleaved=config.rotary_emb_interleaved,
-                )
-            # bug in xformers: https://github.com/facebookresearch/xformers/issues/841
-            # uses the head dimension instead of the sequence dimension
-            self.rotary_head_dim = getattr(config, "rotary_head_dim", False)
-
-        self.Wqkv = nn.Linear(self.embed_dim, qkv_dim, bias=config.qkv_proj_bias)
-
-        self.out_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=config.qkv_proj_bias)
-        self.causal = config.causal
-        self.drop = nn.Dropout(config.attn_pdrop)
-        self.num_prefix_tokens = max(getattr(config, "register_tokens", 1), 1)
-
-    def forward(
-        self,
-        hidden_states: torch.Tensor,
-        attention_mask: Optional[torch.Tensor] = None,
-        position_ids: Optional[torch.LongTensor] = None,
-        past_key_value: Optional[Tuple[torch.Tensor]] = None,
-        output_attentions: bool = False,
-        use_cache: bool = False,
-        is_padded_inputs: Optional[bool] = True,
-        cu_seqlens: Optional[torch.Tensor] = None,
-        max_seq_len: Optional[int] = None,
-        rope: Optional[torch.Tensor] = None,
-    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
-
-        has_layer_past = past_key_value is not None
-
-        if has_layer_past:
-            past_key_value = past_key_value[0]
-            past_len = past_key_value[1]
-        else:
-            past_len = 0
-
-        qkv = self.Wqkv(hidden_states)
-        qkv = rearrange(qkv, "... (three h d) -> ... three h d", three=3, d=self.head_dim)
-
-        past_key_value = (past_key_value, past_len + qkv.size(1)) if use_cache else None
-
-        if self.rotary_emb_dim > 0:
-            if self.rotary_head_dim:
-                qkv = rearrange(qkv, "b s three h d -> b h three s d")
-            qkv = self.rotary_emb(qkv, seqlen_offset=past_len)
-
-            if self.rotary_head_dim:
-                qkv = rearrange(qkv, "b h three s d -> b s three h d")
-        elif rope is not None:
1352
- q, k, v = qkv.permute(0, 3, 1, 2, 4).unbind(dim=-2)
1353
- q = torch.cat([q[:, :, :self.num_prefix_tokens], apply_rot_embed_cat(q[:, :, self.num_prefix_tokens:], rope)], dim=2).type_as(q)
1354
- k = torch.cat([k[:, :, :self.num_prefix_tokens], apply_rot_embed_cat(k[:, :, self.num_prefix_tokens:], rope)], dim=2).type_as(q)
1355
-
1356
- qkv = torch.stack([q, k, v], dim=-2)
1357
- qkv = rearrange(qkv, "b h s three d -> b s three h d")
1358
-
1359
- query, key, value = qkv[:, :, 0], qkv[:, :, 1], qkv[:, :, 2]
1360
-
1361
- query = query.permute(0, 2, 1, 3)
1362
- key = key.permute(0, 2, 1, 3)
1363
- value = value.permute(0, 2, 1, 3)
1364
-
1365
- attention_scores = torch.matmul(query, key.transpose(-1, -2)) / self.norm_factor
1366
- if attention_mask is not None:
1367
- attention_scores = attention_scores + attention_mask
1368
-
1369
- attentions_probs = F.softmax(attention_scores, dim=-1)
1370
- attentions_probs = self.drop(attentions_probs)
1371
-
1372
- attn_output = torch.matmul(attentions_probs, value)
1373
- attn_output = rearrange(attn_output.permute(0, 2, 1, 3), "... h d -> ... (h d)")
1374
-
1375
- attn_output = self.out_proj(attn_output)
1376
-
1377
- return attn_output
1378
-
1379
-
1380
- class NomicBertBlock(NomicBertPreTrainedModel):
1381
- def __init__(
1382
- self,
1383
- config,
1384
- ):
1385
- super().__init__(config=config)
1386
- self.prenorm = config.prenorm
1387
- self.fused_dropout_add_ln = config.fused_dropout_add_ln
1388
-
1389
- self.attn = NomicBertAttention(config)
1390
- activation = (
1391
- F.sigmoid
1392
- if config.activation_function == "glu"
1393
- else (F.silu if config.activation_function == "swiglu" else F.gelu)
1394
- )
1395
- if config.activation_function in ["glu", "swiglu", "geglu"]:
1396
- self.mlp = NomciBertGatedMLP(
1397
- config.n_embd,
1398
- hidden_features=config.n_inner,
1399
- bias1=config.mlp_fc1_bias,
1400
- bias2=config.mlp_fc2_bias,
1401
- activation=activation,
1402
- fused_bias_fc=config.fused_bias_fc,
1403
- norm_layer=getattr(config, "norm_mlp", False),
1404
- )
1405
- else:
1406
- self.mlp = NomicBertMLP(
1407
- config.n_embd,
1408
- hidden_features=config.n_inner,
1409
- bias1=config.mlp_fc1_bias,
1410
- bias2=config.mlp_fc2_bias,
1411
- activation=activation,
1412
- fused_bias_fc=config.fused_bias_fc,
1413
- )
1414
-
1415
- self.dropout1 = nn.Dropout(config.resid_pdrop)
1416
- self.norm1 = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)
1417
- self.norm2 = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)
1418
- self.dropout2 = nn.Dropout(config.resid_pdrop)
1419
-
1420
- def forward(
1421
- self,
1422
- hidden_states: torch.Tensor,
1423
- hidden_states2: torch.Tensor,
1424
- residual: Optional[torch.Tensor] = None,
1425
- attention_mask: Optional[torch.Tensor] = None,
1426
- position_ids: Optional[torch.LongTensor] = None,
1427
- past_key_value: Optional[Tuple[torch.Tensor]] = None,
1428
- is_padded_inputs: Optional[bool] = True,
1429
- output_attentions: Optional[bool] = False,
1430
- use_cache: Optional[bool] = False,
1431
- cu_seqlens: Optional[torch.Tensor] = None,
1432
- max_seq_len: Optional[int] = None,
1433
- rope: Optional[torch.Tensor] = None,
1434
- ):
1435
- r"""Pass the input through the encoder layer.
1436
-
1437
- Args:
1438
- hidden_states: the sequence to the encoder layer (required).
1439
- residual: if postnorm, residual is None; if prenorm, hidden_states = Attn/MLP(LN(residual))
1440
- attention_mask: optional additive mask passed through to the attention module.
1441
- rope: optional precomputed rotary position embeddings forwarded to the attention
1442
- module (used by the vision tower for patch positions).
1443
- """
1444
- if self.prenorm:
1445
- dropped = self.dropout1(hidden_states)
1446
- residual = (dropped + residual) if residual is not None else dropped
1447
- hidden_states = self.norm1(residual.to(dtype=self.norm1.weight.dtype))
1448
- hidden_states = self.attn(
1449
- hidden_states,
1450
- attention_mask=attention_mask,
1451
- is_padded_inputs=is_padded_inputs,
1452
- cu_seqlens=cu_seqlens,
1453
- max_seq_len=max_seq_len,
1454
- rope=rope,
1455
- )
1456
-
1457
- dropped = self.dropout2(hidden_states)
1458
- residual = (dropped + residual) if residual is not None else dropped
1459
- hidden_states = self.norm2(residual.to(dtype=self.norm2.weight.dtype))
1460
- hidden_states = self.mlp(hidden_states)
1461
-
1462
- return hidden_states, None, residual
1463
- else:
1464
- assert residual is None
1465
- attn_outputs = self.attn(
1466
- hidden_states,
1467
- attention_mask=attention_mask,
1468
- is_padded_inputs=is_padded_inputs,
1469
- cu_seqlens=cu_seqlens,
1470
- max_seq_len=max_seq_len,
1471
- rope=rope,
1472
- )
1473
- hidden_states = self.norm1((self.dropout1(attn_outputs) + hidden_states).to(dtype=self.norm1.weight.dtype))
1474
- mlp_out = self.mlp(hidden_states)
1475
-
1476
- hidden_states = self.norm2((self.dropout2(mlp_out) + hidden_states).to(dtype=self.norm2.weight.dtype))
1477
- return hidden_states, None, None
1478
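The block above supports both residual orderings. A small toy sketch (illustrative only; `nn.Linear` stands in for the real attention and MLP sub-modules) of the prenorm stream it threads through `(hidden_states, residual)` versus the postnorm path:

```python
import torch
import torch.nn as nn

D = 8
attn, mlp = nn.Linear(D, D), nn.Linear(D, D)   # stand-ins for NomicBertAttention / the MLP
norm1, norm2, drop = nn.LayerNorm(D), nn.LayerNorm(D), nn.Dropout(0.0)
x = torch.randn(2, 4, D)

# prenorm=True: `residual` accumulates the un-normalized stream; sub-layers see a LayerNorm'd copy.
residual = drop(x)                   # first block: residual is None, so residual = dropped
h = attn(norm1(residual))
residual = drop(h) + residual
h = mlp(norm2(residual))             # the block returns (h, None, residual) for the next layer

# prenorm=False: LayerNorm is applied after each residual addition.
h_post = norm1(drop(attn(x)) + x)
h_post = norm2(drop(mlp(h_post)) + h_post)
```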
-
1479
-
1480
- class NomicBertEncoder(nn.Module):
1481
- def __init__(self, config: GPT2Config):
1482
- super().__init__()
1483
- self.layers = nn.ModuleList([NomicBertBlock(config) for _ in range(config.n_layer)])
1484
- self.gradient_checkpointing = False
1485
- self.config = config
1486
-
1487
- def forward(
1488
- self,
1489
- hidden_states: Optional[torch.Tensor] = None,
1490
- attention_mask: Optional[torch.Tensor] = None,
1491
- position_ids: Optional[torch.LongTensor] = None,
1492
- past_key_values: Optional[List[torch.FloatTensor]] = None,
1493
- inputs_embeds: Optional[torch.FloatTensor] = None,
1494
- use_cache: Optional[bool] = None,
1495
- output_attentions: Optional[bool] = None,
1496
- output_hidden_states: Optional[bool] = None,
1497
- return_dict: Optional[bool] = None,
1498
- is_padded_inputs: Optional[bool] = True,
1499
- rope: Optional[torch.Tensor] = None,
1500
- ):
1501
- """If subset_mask is not None, we only want output for the subset of the sequence.
1502
- This means that we only compute the last layer output for these tokens.
1503
- subset_mask: (batch, seqlen), dtype=torch.bool
1504
- """
1505
- hidden_states2 = None
1506
- residual = None
1507
-
1508
- for _, layer in enumerate(self.layers):
1509
- if self.gradient_checkpointing and self.training:
1510
-
1511
- def create_custom_forward(module):
1512
- def custom_forward(*inputs):
1513
- # None for past_key_value
1514
- return module(*inputs)
1515
-
1516
- return custom_forward
1517
-
1518
- hidden_states, hidden_states2, residual = torch.utils.checkpoint.checkpoint(
1519
- create_custom_forward(layer),
1520
- hidden_states,
1521
- hidden_states2,
1522
- residual,
1523
- attention_mask,
1524
- position_ids,
1525
- past_key_values,
1526
- is_padded_inputs,
1527
- output_attentions,
1528
- use_cache,
1529
- None,
1530
- None,
1531
- rope,
1532
- # if you freeze ANY layers, you need `use_reentrant=False`
1533
- # https://github.com/huggingface/transformers/issues/21381
1534
- # https://discuss.pytorch.org/t/checkpoint-with-no-grad-requiring-inputs-problem/19117/7
1535
- use_reentrant=False,
1536
- )
1537
-
1538
- else:
1539
- hidden_states, hidden_states2, residual = layer(
1540
- hidden_states,
1541
- hidden_states2,
1542
- residual,
1543
- attention_mask,
1544
- position_ids,
1545
- None,
1546
- is_padded_inputs,
1547
- output_attentions,
1548
- use_cache,
1549
- rope=rope,
1550
- )
1551
- return hidden_states
1552
-
1553
-
1554
- class NomicBertPooler(nn.Module):
1555
- def __init__(self, config):
1556
- super().__init__()
1557
- self.dense = nn.Linear(config.n_embd, config.n_embd)
1558
- self.activation = nn.Tanh()
1559
-
1560
- def forward(self, hidden_states, pool=True):
1561
- # We "pool" the model by simply taking the hidden state corresponding
1562
- # to the first token.
1563
- first_token_tensor = hidden_states[:, 0] if pool else hidden_states
1564
- pooled_output = self.dense(first_token_tensor)
1565
- pooled_output = self.activation(pooled_output)
1566
- return pooled_output
1567
-
1568
-
1569
- class NomicBertPredictionHeadTransform(nn.Module):
1570
- def __init__(self, config):
1571
- super().__init__()
1572
- self.dense = nn.Linear(config.n_embd, config.n_embd, bias=config.mlp_fc1_bias)
1573
- approximate = "tanh" if config.activation_function in ["gelu_new", "gelu_fast", "gelu_pytorch_tanh"] else "none"
1574
- if config.activation_function == "swiglu":
1575
- self.transform_act_fn = F.silu
1576
- else:
1577
- self.transform_act_fn = nn.GELU(approximate=approximate)
1578
-
1579
- self.layer_norm = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)
1580
-
1581
- def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
1582
- hidden_states = self.dense(hidden_states)
1583
- hidden_states = self.transform_act_fn(hidden_states)
1584
- hidden_states = self.layer_norm(hidden_states)
1585
-
1586
- return hidden_states
1587
-
1588
-
1589
- class NomicBertLMPredictionHead(nn.Module):
1590
- def __init__(self, config):
1591
- super().__init__()
1592
-
1593
- self.transform = NomicBertPredictionHeadTransform(config)
1594
-
1595
- self.decoder = nn.Linear(config.n_embd, config.vocab_size, bias=config.mlp_fc1_bias)
1596
-
1597
- def forward(self, hidden_states):
1598
- hidden_states = self.transform(hidden_states)
1599
- hidden_states = self.decoder(hidden_states)
1600
- return hidden_states
1601
-
1602
-
1603
- class NomicBertPreTrainingHeads(nn.Module):
1604
- def __init__(self, config):
1605
- super().__init__()
1606
- self.predictions = NomicBertLMPredictionHead(config)
1607
-
1608
- def forward(self, sequence_output):
1609
- prediction_scores = self.predictions(sequence_output)
1610
- return prediction_scores
1611
-
1612
-
1613
- class NomicBertModel(NomicBertPreTrainedModel):
1614
- def __init__(self, config: GPT2Config, add_pooling_layer=True):
1615
- super().__init__(config)
1616
- self.pad_vocab_size_multiple = getattr(config, "pad_vocab_size_multiple", 1)
1617
- if config.vocab_size % self.pad_vocab_size_multiple != 0:
1618
- config.vocab_size += self.pad_vocab_size_multiple - (config.vocab_size % self.pad_vocab_size_multiple)
1619
-
1620
- assert config.activation_function in [
1621
- "gelu",
1622
- "gelu_new",
1623
- "gelu_fast",
1624
- "gelu_pytorch_tanh",
1625
- "swiglu",
1626
- "geglu",
1627
- "glu",
1628
- ]
1629
-
1630
- self.embeddings = NomicBertEmbeddings(config)
1631
- self.emb_drop = nn.Dropout(config.resid_pdrop)
1632
- self.emb_ln = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)
1633
- self.encoder = NomicBertEncoder(config)
1634
- self.pooler = NomicBertPooler(config) if add_pooling_layer else None
1635
-
1636
- self.apply(partial(_init_weights, initializer_range=config.initializer_range))
1637
-
1638
- def forward(
1639
- self,
1640
- input_ids,
1641
- attention_mask=None,
1642
- position_ids=None,
1643
- token_type_ids=None,
1644
- return_dict=None,
1645
- matryoshka_dim=None,
1646
- ):
1647
- if token_type_ids is None:
1648
- token_type_ids = torch.zeros_like(input_ids)
1649
- hidden_states = self.embeddings(input_ids, position_ids=position_ids, token_type_ids=token_type_ids)
1650
- hidden_states = self.emb_ln(hidden_states)
1651
- hidden_states = self.emb_drop(hidden_states)
1652
-
1653
- attention_mask = self.get_extended_attention_mask(attention_mask, input_ids.shape)
1654
- sequence_output = self.encoder(hidden_states, attention_mask=attention_mask, return_dict=return_dict)
1655
-
1656
- pooled_output = self.pooler(sequence_output) if self.pooler is not None else None
1657
-
1658
- if matryoshka_dim:
1659
- sequence_output = sequence_output[..., :matryoshka_dim]  # truncate the embedding dimension, not the sequence
1660
-
1661
- return BaseModelOutputWithPoolingAndCrossAttentions(
1662
- last_hidden_state=sequence_output,
1663
- pooler_output=pooled_output,
1664
- )
1665
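A usage sketch for the encoder above. The checkpoint id, the `search_document:` task prefix, and the availability of the `matryoshka_dim` keyword in the hosted remote code are assumptions here, not guaranteed by this file:

```python
import torch
from transformers import AutoModel, AutoTokenizer

model_id = "nomic-ai/nomic-embed-text-v1.5"          # hypothetical checkpoint shipping this modeling file
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModel.from_pretrained(model_id, trust_remote_code=True)

batch = tokenizer(["search_document: example text"], return_tensors="pt", padding=True)
with torch.no_grad():
    out = model(**batch, matryoshka_dim=256)          # truncates the hidden dimension, per forward()

token_embeddings = out.last_hidden_state              # (batch, seq_len, 256)
pooled = out.pooler_output                            # tanh(W·[CLS]) when add_pooling_layer=True (the default)
```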
-
1666
-
1667
- class NomicBertForPreTraining(NomicBertPreTrainedModel):
1668
- _tied_weights_keys = ["predictions.decoder.bias", "cls.predictions.decoder.weight"]
1669
-
1670
- def __init__(self, config: GPT2Config):
1671
- super().__init__(config)
1672
-
1673
- self.bert = NomicBertModel(config, add_pooling_layer=getattr(config, "add_pooling_layer", False))
1674
- self.cls = NomicBertPreTrainingHeads(config)
1675
- self.mlm_loss = nn.CrossEntropyLoss()
1676
-
1677
- # Initialize weights and apply final processing
1678
- self.apply(partial(_init_weights, initializer_range=config.initializer_range))
1679
- self.tie_weights()
1680
-
1681
- def tie_weights(self):
1682
- self.cls.predictions.decoder.weight = self.bert.embeddings.word_embeddings.weight
1683
-
1684
- def forward(
1685
- self,
1686
- input_ids,
1687
- position_ids=None,
1688
- token_type_ids=None,
1689
- attention_mask=None,
1690
- labels=None,
1691
- ):
1692
- """
1693
- If `labels` are provided, positions that should not contribute to the loss must be set
1694
- to -100 (the default `ignore_index` of `nn.CrossEntropyLoss`).
1695
- Outputs:
1696
- if `labels` is not `None`:
1697
- returns a `MaskedLMOutput` whose `loss` is the masked language modeling loss
1698
- (no next-sentence-prediction head is used by this model) and whose `logits` are the
1699
- prediction scores of shape [batch_size, sequence_length, vocab_size].
1700
- if `labels` is `None`:
1701
- returns a `MaskedLMOutput` with `loss=None` and `logits` of shape
1702
- [batch_size, sequence_length, vocab_size].
1703
-
1704
- """
1705
- outputs = self.bert(
1706
- input_ids,
1707
- position_ids=position_ids,
1708
- token_type_ids=token_type_ids,
1709
- attention_mask=attention_mask.bool() if attention_mask is not None else None,
1710
- )
1711
- sequence_output, _ = outputs.last_hidden_state, outputs.pooler_output
1712
-
1713
- prediction_scores = self.cls(sequence_output)
1714
-
1715
- total_loss = None
1716
- if labels is not None:
1717
- masked_lm_loss = self.mlm_loss(
1718
- rearrange(prediction_scores, "... v -> (...) v"),
1719
- rearrange(labels, "... -> (...)"),
1720
- )
1721
- total_loss = masked_lm_loss.float()
1722
-
1723
- return MaskedLMOutput(
1724
- loss=total_loss,
1725
- logits=prediction_scores,
1726
- hidden_states=outputs.hidden_states,
1727
- attentions=None,
1728
- )
1729
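The MLM loss above flattens logits and labels with `rearrange` and relies on `nn.CrossEntropyLoss`'s default `ignore_index=-100`, so only masked positions contribute. A tiny self-contained check with hypothetical shapes:

```python
import torch
import torch.nn as nn

vocab_size, seq_len = 11, 5
logits = torch.randn(2, seq_len, vocab_size)
labels = torch.full((2, seq_len), -100)   # -100 everywhere = ignored by CrossEntropyLoss
labels[0, 2] = 7                          # only these two positions contribute to the loss
labels[1, 4] = 3
loss = nn.CrossEntropyLoss()(logits.view(-1, vocab_size), labels.view(-1))
```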
-
1730
-
1731
- class NomicBertForSequenceClassification(NomicBertPreTrainedModel):
1732
- def __init__(self, config):
1733
- super().__init__(config)
1734
- self.num_labels = config.num_labels
1735
- self.config = config
1736
-
1737
- self.bert = NomicBertModel(config)
1738
- classifier_dropout = getattr(config, "classifier_dropout", config.embd_pdrop)
1739
- self.dropout = nn.Dropout(classifier_dropout)
1740
- self.classifier = nn.Linear(config.n_embd, config.num_labels)
1741
-
1742
- # Initialize weights and apply final processing
1743
- self.post_init()
1744
-
1745
- def forward(
1746
- self,
1747
- input_ids: Optional[torch.Tensor] = None,
1748
- attention_mask: Optional[torch.Tensor] = None,
1749
- token_type_ids: Optional[torch.Tensor] = None,
1750
- position_ids: Optional[torch.Tensor] = None,
1751
- head_mask: Optional[torch.Tensor] = None,
1752
- inputs_embeds: Optional[torch.Tensor] = None,
1753
- labels: Optional[torch.Tensor] = None,
1754
- output_attentions: Optional[bool] = None,
1755
- output_hidden_states: Optional[bool] = None,
1756
- return_dict: Optional[bool] = None,
1757
- ):
1758
- r"""
1759
- labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
1760
- Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
1761
- config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
1762
- `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
1763
- """
1764
- return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1765
- outputs = self.bert(
1766
- input_ids,
1767
- position_ids=position_ids,
1768
- token_type_ids=token_type_ids,
1769
- attention_mask=attention_mask.bool() if attention_mask is not None else None,
1770
- )
1771
-
1772
- pooled_output = outputs[1]
1773
-
1774
- pooled_output = self.dropout(pooled_output)
1775
- logits = self.classifier(pooled_output)
1776
-
1777
- loss = None
1778
- if labels is not None:
1779
- if self.config.problem_type is None:
1780
- if self.num_labels == 1:
1781
- self.config.problem_type = "regression"
1782
- elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
1783
- self.config.problem_type = "single_label_classification"
1784
- else:
1785
- self.config.problem_type = "multi_label_classification"
1786
-
1787
- if self.config.problem_type == "regression":
1788
- loss_fct = nn.MSELoss()
1789
- if self.num_labels == 1:
1790
- loss = loss_fct(logits.squeeze(), labels.squeeze())
1791
- else:
1792
- loss = loss_fct(logits, labels)
1793
- elif self.config.problem_type == "single_label_classification":
1794
- loss_fct = nn.CrossEntropyLoss()
1795
- loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
1796
- elif self.config.problem_type == "multi_label_classification":
1797
- loss_fct = nn.BCEWithLogitsLoss()
1798
- loss = loss_fct(logits, labels)
1799
- if not return_dict:
1800
- output = (logits,) + outputs[2:]
1801
- return ((loss,) + output) if loss is not None else output
1802
-
1803
- return SequenceClassifierOutput(
1804
- loss=loss,
1805
- logits=logits,
1806
- hidden_states=outputs.hidden_states,
1807
- attentions=outputs.attentions,
1808
- )
1809
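For reference, the `problem_type` dispatch above maps onto three standard loss heads, keyed on `num_labels` and the dtype of `labels`. A minimal sketch with hypothetical tensors:

```python
import torch
import torch.nn as nn

num_labels = 3
logits = torch.randn(4, num_labels)

# single_label_classification: integer class ids -> CrossEntropyLoss
ce = nn.CrossEntropyLoss()(logits.view(-1, num_labels), torch.tensor([0, 2, 1, 2]))

# multi_label_classification: float multi-hot targets -> BCEWithLogitsLoss
bce = nn.BCEWithLogitsLoss()(logits, torch.tensor([[1.0, 0.0, 1.0]] * 4))

# regression: num_labels == 1 and float targets -> MSELoss on the squeezed logits
mse = nn.MSELoss()(torch.randn(4, 1).squeeze(), torch.randn(4))
```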
-
1810
- def hf_vit_config_to_vit_config(vit_config: ViTConfig) -> GPT2Config:
1811
- return GPT2Config(
1812
- n_embd=vit_config.hidden_size,
1813
- n_layer=vit_config.num_hidden_layers,
1814
- n_head=vit_config.num_attention_heads,
1815
- n_inner=vit_config.intermediate_size,
1816
- activation_function=vit_config.hidden_act,
1817
- vocab_size=0, # no vocab since using patches
1818
- n_positions=0, # No absolute position embedding
1819
- resid_pdrop=0.0, # No dropout
1820
- embd_pdrop=getattr(vit_config, "dropout", 0.0),
1821
- attn_pdrop=vit_config.attention_probs_dropout_prob,
1822
- layer_norm_epsilon=vit_config.layer_norm_eps,
1823
- initializer_range=vit_config.initializer_range,
1824
- bos_token_id=None,
1825
- eos_token_id=None,
1826
- # These are new arguments not in the original GPT2Config
1827
- drop_path_rate=0.0,
1828
- # Why is there double layer norm??
1829
- prepre_layernom=False,
1830
- layer_scale=False,
1831
- layer_scale_init=None,
1832
- img_size=vit_config.image_size,
1833
- patch_size=vit_config.patch_size,
1834
- num_channels=vit_config.num_channels,
1835
- prenorm=True,
1836
- parallel_block=False,
1837
- parallel_block_tied_norm=False,
1838
- rotary_emb_fraction=0,
1839
- tie_word_embeddings=False,
1840
- fused_dropout_add_ln=True,
1841
- fused_bias_fc=True,
1842
- patch_embed_bias=True,
1843
- use_flash_attn=True,
1844
- qkv_proj_bias=True,
1845
- mlp_fc1_bias=getattr(vit_config, "mlp_fc1_bias", True),
1846
- mlp_fc2_bias=getattr(vit_config, "mlp_fc2_bias", True),
1847
- use_rms_norm=False,
1848
- causal=False,
1849
- hidden_features_scaling_factor=1.0,
1850
- mask_token=False,
1851
- learned_pos_embedding=False,
1852
- patch_dropout=0,
1853
- sinusoidal_pos_embedding=(vit_config.model_type == "vit_mae"),
1854
- )
1855
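A conversion sketch for the helper above, assuming this module is importable (the `modeling_hf_nomic_bert` import path is an assumption); the printed values are the stock ViT-B/16 defaults of `ViTConfig`:

```python
from transformers import ViTConfig
from modeling_hf_nomic_bert import hf_vit_config_to_vit_config  # assumed import path

vit_cfg = ViTConfig()                     # hidden_size=768, num_hidden_layers=12, patch_size=16, ...
gpt2_style_cfg = hf_vit_config_to_vit_config(vit_cfg)

print(gpt2_style_cfg.n_embd, gpt2_style_cfg.n_layer, gpt2_style_cfg.patch_size)   # 768 12 16
print(gpt2_style_cfg.sinusoidal_pos_embedding)   # True only when vit_cfg.model_type == "vit_mae"
```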
-
1856
-
1857
- class NomicAttentionPooling(nn.Module):
1858
- def __init__(
1859
- self,
1860
- config
1861
- ):
1862
- super().__init__()
1863
- self.embed_dim = config.n_embd
1864
- self.use_flash_attn = config.use_flash_attn
1865
- self.fused_bias_fc = config.fused_bias_fc
1866
-
1867
- self.num_heads = config.n_head
1868
- self.num_heads_kv = config.num_heads_kv if getattr(config, "num_heads_kv", None) is not None else self.num_heads
1869
- assert self.embed_dim % self.num_heads == 0, "embed_dim must be divisible by num_heads"
1870
- self.head_dim = self.embed_dim // self.num_heads
1871
- # we don't really support mqa / gqa for now
1872
- kv_dim = 2 * self.head_dim * self.num_heads_kv
1873
-
1874
- self.register_buffer(
1875
- "norm_factor",
1876
- torch.sqrt(torch.tensor(self.head_dim, dtype=torch.float32)).to(torch.get_default_dtype()),
1877
- persistent=False,
1878
- )
1879
-
1880
- self.Wq = nn.Linear(self.embed_dim, self.embed_dim, bias=config.qkv_proj_bias)
1881
- self.Wkv = nn.Linear(self.embed_dim, kv_dim, bias=config.qkv_proj_bias)
1882
-
1883
- self.latent = nn.Parameter(torch.zeros(1, 1, self.embed_dim))
1884
-
1885
- self.out_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=config.qkv_proj_bias)
1886
- self.causal = config.causal
1887
- self.drop = nn.Dropout(config.attn_pdrop)
1888
-
1889
- def init_weights(self):
1890
- trunc_normal_tf_(self.latent, std=self.embed_dim ** -0.5)
1891
-
1892
- def forward(
1893
- self,
1894
- kv,
1895
- attention_mask=None,
1896
- cu_seqlens_k=None,
1897
- max_seqlen_k=None,
1898
- is_padded_inputs: Optional[bool] = True,
1899
- output_attentions: bool = False,
1900
- ):
1901
- """Implements the multihead softmax attention.
1902
- Arguments
1903
- ---------
1904
- q: not an input; the query is produced from the learned latent parameter of shape
1905
- (1, 1, D), expanded across the batch and projected by Wq.
1906
- kv: the tensor containing the keys and values to pool over, of shape (B, Sk, D)
1907
- before the Wkv projection.
1908
- attention_mask: optional additive mask applied to the pooling attention scores.
1909
- cu_seqlens_k: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
1910
- of the sequences in the batch, used to index into kv; accepted for interface
1911
- compatibility but unused by the unfused attention path implemented here.
1912
- max_seqlen_k: int. Maximum sequence length in the batch of k and v (also unused here).
1913
- """
1914
- q_latent = self.latent.expand(kv.size(0), -1, -1)
1915
- q = self.Wq(q_latent)
1916
- bsz, q_len, h_size = q.shape
1917
- kv = self.Wkv(kv)
1918
- query = rearrange(q, "... (h d) -> ... h d", d=self.head_dim)
1919
- kv = rearrange(kv, "... (two hkv d) -> ... two hkv d", two=2, d=self.head_dim)
1920
-
1921
- key, value = kv[:, :, 0], kv[:, :, 1]
1922
-
1923
- query = query.permute(0, 2, 1, 3)
1924
- key = key.permute(0, 2, 1, 3)
1925
- value = value.permute(0, 2, 1, 3)
1926
-
1927
- attention_scores = torch.matmul(query, key.transpose(-1, -2)) / self.norm_factor
1928
- if attention_mask is not None:
1929
- attention_scores = attention_scores + attention_mask
1930
-
1931
- attentions_probs = F.softmax(attention_scores, dim=-1)
1932
- attentions_probs = self.drop(attentions_probs)
1933
-
1934
- attn_output = torch.matmul(attentions_probs, value)
1935
- attn_output = rearrange(attn_output.permute(0, 2, 1, 3), "... h d -> ... (h d)")
1936
-
1937
- attn_output = self.out_proj(attn_output)
1938
-
1939
- return attn_output
1940
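The module above is latent-query attention pooling: a single learned query attends over the sequence, collapsing (B, S, D) to (B, 1, D). A self-contained shape sketch of that mechanism (plain tensors, hypothetical sizes; no config needed):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

B, S, D, H = 2, 16, 64, 4
d = D // H
latent = nn.Parameter(torch.randn(1, 1, D) * D**-0.5)   # learned query, as in NomicAttentionPooling
Wq, Wkv = nn.Linear(D, D), nn.Linear(D, 2 * D)

x = torch.randn(B, S, D)                                           # token features to pool
q = Wq(latent.expand(B, -1, -1)).view(B, 1, H, d).transpose(1, 2)  # (B, H, 1, d)
k, v = Wkv(x).view(B, S, 2, H, d).unbind(dim=2)
k, v = k.transpose(1, 2), v.transpose(1, 2)                        # (B, H, S, d)

scores = q @ k.transpose(-1, -2) / d**0.5                          # (B, H, 1, S)
pooled = (F.softmax(scores, dim=-1) @ v).transpose(1, 2).reshape(B, 1, D)
print(pooled.shape)                                                # torch.Size([2, 1, 64])
```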
-
1941
-
1942
- class NomicMultiHeadAttentionPooling(nn.Module):
1943
- def __init__(
1944
- self,
1945
- config,
1946
- ):
1947
- super().__init__()
1948
- self.prenorm = config.prenorm
1949
- self.fused_dropout_add_ln = config.fused_dropout_add_ln
1950
-
1951
- self.attn = NomicAttentionPooling(config)
1952
- activation = (
1953
- F.sigmoid
1954
- if config.activation_function == "glu"
1955
- else (F.silu if config.activation_function == "swiglu" else F.gelu)
1956
- )
1957
- if config.activation_function in ["glu", "swiglu", "geglu"]:
1958
- self.mlp = NomciBertGatedMLP(
1959
- config.n_embd,
1960
- hidden_features=config.n_inner,
1961
- bias1=config.mlp_fc1_bias,
1962
- bias2=config.mlp_fc2_bias,
1963
- activation=activation,
1964
- fused_bias_fc=config.fused_bias_fc,
1965
- )
1966
- else:
1967
- self.mlp = NomicBertMLP(
1968
- config.n_embd,
1969
- hidden_features=config.n_inner,
1970
- bias1=config.mlp_fc1_bias,
1971
- bias2=config.mlp_fc2_bias,
1972
- activation=activation,
1973
- fused_bias_fc=config.fused_bias_fc,
1974
- )
1975
-
1976
- self.dropout1 = nn.Dropout(config.resid_pdrop)
1977
- self.norm1 = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)
1978
- self.dropout2 = nn.Dropout(config.resid_pdrop)
1979
-
1980
- def forward(
1981
- self,
1982
- hidden_states: torch.Tensor,
1983
- attention_mask: Optional[torch.Tensor] = None,
1984
- ):
1985
- r"""Pass the input through the encoder layer.
1986
-
1987
- Args:
1988
- hidden_states: the token representations to pool (required).
1989
- attention_mask: optional additive mask applied to the pooling attention scores.
1990
- The module applies latent-query attention pooling (NomicAttentionPooling), then a
1991
- LayerNorm and MLP whose output is added back to the input as a residual; there is
1992
- no cross-attention or mixer_subset handling in this head.
1993
- """
1994
-
1995
- attn_outputs = self.attn(
1996
- hidden_states,
1997
- attention_mask=attention_mask,
1998
- )
1999
-
2000
- normed = self.norm1(attn_outputs)
2001
- hidden_states = hidden_states + self.mlp(normed)
2002
-
2003
- return hidden_states
2004
-
2005
- class NomicVisionPreTrainedModel(PreTrainedModel):
2006
- """An abstract class to handle weights initialization and
2007
- a simple interface for downloading and loading pretrained models.
2008
- """
2009
-
2010
- config_class = NomicBertConfig
2011
- base_model_prefix = "model"
2012
- supports_gradient_checkpointing = True
2013
- _no_split_modules = ["Block"]
2014
- _skip_keys_device_placement = "past_key_values"
2015
-
2016
- def __init__(self, config, *inputs, **kwargs):
2017
- super().__init__(config)
2018
- if not isinstance(config, GPT2Config):
2019
- raise ValueError(
2020
- "Parameter config in `{}(config)` should be an instance of class `GPT2Config`. "
2021
- "To create a model from a Google pretrained model use "
2022
- "`model = {}.from_pretrained(PRETRAINED_MODEL_NAME)`".format(
2023
- self.__class__.__name__, self.__class__.__name__
2024
- )
2025
- )
2026
- self.config = config
2027
-
2028
- class NomicVisionModel(NomicVisionPreTrainedModel):
2029
- def __init__(self, config):
2030
- super().__init__(config)
2031
-
2032
- self.embeddings = NomicVisionPatchEmbeddings(config)
2033
- self.layers = nn.ModuleList([NomicBertBlock(config) for _ in range(config.n_layer)])
2034
-
2035
- self.selector = NomicMultiHeadAttentionPooling(config)
2036
-
2037
- self.global_pool = getattr(config, "global_pool", None)
2038
- self.num_prefix_tokens = (1 if not getattr(config, "no_cls_token", False) else 0) + getattr(config, "register_tokens", 0)
2039
-
2040
- self.apply(partial(_init_weights, initializer_range=config.initializer_range))
2041
-
2042
- def forward(
2043
- self,
2044
- pixel_values,
2045
- attention_mask=None,
2046
- position_ids=None,
2047
- token_type_ids=None,
2048
- return_dict=None,
2049
- matryoshka_dim=None,
2050
- ):
2051
- embeddings, rope = self.embeddings(pixel_values)
2052
-
2053
- original_dtype = embeddings.dtype
2054
-
2055
- hidden_states = embeddings
2056
- # unused here, but simpler to pass through so gradient checkpointing sees the same arguments
2057
- residual = None
2058
- for layer in self.layers:
2059
- # need to pass None for backward compatibility
2060
- hidden_states, _, residual = layer(hidden_states, None, residual=residual, is_padded_inputs=False, rope=rope)
2061
-
2062
- hidden_states = hidden_states + residual
2063
- if self.global_pool == "avg":
2064
- hidden_states = hidden_states[:, self.num_prefix_tokens:].mean(dim=1)
2065
-
2066
- pooled_output = self.selector(hidden_states)
2067
-
2068
- return BaseModelOutputWithPast(
2069
- last_hidden_state=pooled_output,
2070
- hidden_states=hidden_states,
2071
- )
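Finally, a usage sketch for the vision tower above. The checkpoint id and the 224×224 input size are assumptions; the random tensor stands in for real image-processor output:

```python
import torch
from transformers import AutoModel

vision = AutoModel.from_pretrained("nomic-ai/nomic-embed-vision-v1.5", trust_remote_code=True)

pixel_values = torch.randn(1, 3, 224, 224)      # stand-in for processed image pixels
with torch.no_grad():
    out = vision(pixel_values)

img_emb = out.last_hidden_state                 # pooled output from NomicMultiHeadAttentionPooling
```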