Fill-Mask
Transformers
PyTorch
Safetensors
English
nomic_bert
custom_code
zpn commited on
Commit
7680a1c
·
verified ·
1 Parent(s): 7b260c5

Update modeling_hf_nomic_bert.py

Browse files
Files changed (1) hide show
  1. modeling_hf_nomic_bert.py +840 -3
modeling_hf_nomic_bert.py CHANGED
@@ -6,6 +6,9 @@
6
  import logging
7
 
8
  # Inspired by https://github.com/huggingface/transformers/blob/main/src/transformers/models/bert/modeling_bert.py
 
 
 
9
  import os
10
  import re
11
  from collections import OrderedDict
@@ -17,7 +20,7 @@ import torch.nn as nn
17
  import torch.nn.functional as F
18
  from einops import rearrange, repeat
19
  from safetensors.torch import load_file as safe_load_file
20
- from transformers import GPT2Config, PreTrainedModel
21
  from transformers.models.bert.modeling_bert import (
22
  BaseModelOutputWithPoolingAndCrossAttentions,
23
  MaskedLMOutput,
@@ -25,6 +28,8 @@ from transformers.models.bert.modeling_bert import (
25
  )
26
  from transformers.utils import SAFE_WEIGHTS_INDEX_NAME, SAFE_WEIGHTS_NAME, WEIGHTS_INDEX_NAME, WEIGHTS_NAME
27
  from transformers.utils.hub import cached_file, get_checkpoint_shard_files
 
 
28
 
29
  from .configuration_hf_nomic_bert import NomicBertConfig
30
 
@@ -268,6 +273,68 @@ def remap_bert_state_dict(
268
 
269
  return state_dict
270
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
271
 
272
  class NomicBertPreTrainedModel(PreTrainedModel):
273
  """An abstract class to handle weights initialization and
@@ -382,6 +449,487 @@ def _init_weights(module, initializer_range=0.02):
382
  if module.padding_idx is not None:
383
  nn.init.zeros_(module.weight[module.padding_idx])
384
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
385
 
386
  class NomicBertEmbeddings(nn.Module):
387
  def __init__(self, config):
@@ -466,17 +1014,19 @@ class NomciBertGatedMLP(nn.Module):
466
  fused_bias_fc=True,
467
  device=None,
468
  dtype=None,
 
469
  ):
470
  super().__init__()
471
  out_features = out_features if out_features is not None else in_features
472
  hidden_features = hidden_features if hidden_features is not None else int(8 * in_features / 3)
473
- hidden_features = (hidden_features + multiple_of - 1) // multiple_of * multiple_of
474
  self.return_residual = return_residual
475
 
476
  self.fc11 = nn.Linear(in_features, hidden_features, bias=bias1)
477
  self.fc12 = nn.Linear(in_features, hidden_features, bias=bias1)
478
  self.activation = activation
479
  self.fc2 = nn.Linear(hidden_features, out_features, bias=bias2)
 
480
 
481
  def forward(self, x):
482
  y = self.fc11(x)
@@ -485,6 +1035,10 @@ class NomciBertGatedMLP(nn.Module):
485
  y = F.glu(torch.cat([y, gate], dim=-1), dim=-1)
486
  else:
487
  y = y * self.activation(gate)
 
 
 
 
488
  y = self.fc2(y)
489
  return y if not self.return_residual else (y, x)
490
 
@@ -758,6 +1312,7 @@ class NomicBertAttention(nn.Module):
758
  self.out_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=config.qkv_proj_bias)
759
  self.causal = config.causal
760
  self.drop = nn.Dropout(config.attn_pdrop)
 
761
 
762
  def forward(
763
  self,
@@ -770,6 +1325,7 @@ class NomicBertAttention(nn.Module):
770
  is_padded_inputs: Optional[bool] = True,
771
  cu_seqlens: Optional[torch.Tensor] = None,
772
  max_seq_len: Optional[int] = None,
 
773
  ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
774
 
775
  has_layer_past = past_key_value is not None
@@ -792,6 +1348,13 @@ class NomicBertAttention(nn.Module):
792
 
793
  if self.rotary_head_dim:
794
  qkv = rearrange(qkv, "b h three s d -> b s three h d")
 
 
 
 
 
 
 
795
 
796
  query, key, value = qkv[:, :, 0], qkv[:, :, 1], qkv[:, :, 2]
797
 
@@ -837,6 +1400,7 @@ class NomicBertBlock(NomicBertPreTrainedModel):
837
  bias2=config.mlp_fc2_bias,
838
  activation=activation,
839
  fused_bias_fc=config.fused_bias_fc,
 
840
  )
841
  else:
842
  self.mlp = NomicBertMLP(
@@ -866,6 +1430,7 @@ class NomicBertBlock(NomicBertPreTrainedModel):
866
  use_cache: Optional[bool] = False,
867
  cu_seqlens: Optional[torch.Tensor] = None,
868
  max_seq_len: Optional[int] = None,
 
869
  ):
870
  r"""Pass the input through the encoder layer.
871
 
@@ -886,6 +1451,7 @@ class NomicBertBlock(NomicBertPreTrainedModel):
886
  is_padded_inputs=is_padded_inputs,
887
  cu_seqlens=cu_seqlens,
888
  max_seq_len=max_seq_len,
 
889
  )
890
 
891
  dropped = self.dropout2(hidden_states)
@@ -902,6 +1468,7 @@ class NomicBertBlock(NomicBertPreTrainedModel):
902
  is_padded_inputs=is_padded_inputs,
903
  cu_seqlens=cu_seqlens,
904
  max_seq_len=max_seq_len,
 
905
  )
906
  hidden_states = self.norm1((self.dropout1(attn_outputs) + hidden_states).to(dtype=self.norm1.weight.dtype))
907
  mlp_out = self.mlp(hidden_states)
@@ -929,6 +1496,7 @@ class NomicBertEncoder(nn.Module):
929
  output_hidden_states: Optional[bool] = None,
930
  return_dict: Optional[bool] = None,
931
  is_padded_inputs: Optional[bool] = True,
 
932
  ):
933
  """If subset_mask is not None, we only want output for the subset of the sequence.
934
  This means that we only compute the last layer output for these tokens.
@@ -953,9 +1521,14 @@ class NomicBertEncoder(nn.Module):
953
  hidden_states2,
954
  residual,
955
  attention_mask,
 
 
 
 
 
956
  None,
957
  None,
958
- is_padded_inputs,
959
  # if you freeze ANY layers, you need `use_reentrant=False`
960
  # https://github.com/huggingface/transformers/issues/21381
961
  # https://discuss.pytorch.org/t/checkpoint-with-no-grad-requiring-inputs-problem/19117/7
@@ -973,6 +1546,7 @@ class NomicBertEncoder(nn.Module):
973
  is_padded_inputs,
974
  output_attentions,
975
  use_cache,
 
976
  )
977
  return hidden_states
978
 
@@ -1232,3 +1806,266 @@ class NomicBertForSequenceClassification(NomicBertPreTrainedModel):
1232
  hidden_states=outputs.hidden_states,
1233
  attentions=outputs.attentions,
1234
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6
  import logging
7
 
8
  # Inspired by https://github.com/huggingface/transformers/blob/main/src/transformers/models/bert/modeling_bert.py
9
+ import math
10
+ import numpy as np
11
+ import collections
12
  import os
13
  import re
14
  from collections import OrderedDict
 
20
  import torch.nn.functional as F
21
  from einops import rearrange, repeat
22
  from safetensors.torch import load_file as safe_load_file
23
+ from transformers import GPT2Config, PreTrainedModel, ViTModel, ViTConfig
24
  from transformers.models.bert.modeling_bert import (
25
  BaseModelOutputWithPoolingAndCrossAttentions,
26
  MaskedLMOutput,
 
28
  )
29
  from transformers.utils import SAFE_WEIGHTS_INDEX_NAME, SAFE_WEIGHTS_NAME, WEIGHTS_INDEX_NAME, WEIGHTS_NAME
30
  from transformers.utils.hub import cached_file, get_checkpoint_shard_files
31
+ from transformers.modeling_outputs import BaseModelOutputWithPast
32
+ from torch.nn.modules.utils import _pair
33
 
34
  from .configuration_hf_nomic_bert import NomicBertConfig
35
 
 
273
 
274
  return state_dict
275
 
276
+
277
+ def _trunc_normal_(tensor, mean, std, a, b):
278
+ # Cut & paste from PyTorch official master until it's in a few official releases - RW
279
+ # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
280
+ def norm_cdf(x):
281
+ # Computes standard normal cumulative distribution function
282
+ return (1. + math.erf(x / math.sqrt(2.))) / 2.
283
+
284
+ if (mean < a - 2 * std) or (mean > b + 2 * std):
285
+ print("mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
286
+ "The distribution of values may be incorrect.",
287
+ stacklevel=2)
288
+
289
+ # Values are generated by using a truncated uniform distribution and
290
+ # then using the inverse CDF for the normal distribution.
291
+ # Get upper and lower cdf values
292
+ l = norm_cdf((a - mean) / std)
293
+ u = norm_cdf((b - mean) / std)
294
+
295
+ # Uniformly fill tensor with values from [l, u], then translate to
296
+ # [2l-1, 2u-1].
297
+ tensor.uniform_(2 * l - 1, 2 * u - 1)
298
+
299
+ # Use inverse cdf transform for normal distribution to get truncated
300
+ # standard normal
301
+ tensor.erfinv_()
302
+
303
+ # Transform to proper mean, std
304
+ tensor.mul_(std * math.sqrt(2.))
305
+ tensor.add_(mean)
306
+
307
+ # Clamp to ensure it's in the proper range
308
+ tensor.clamp_(min=a, max=b)
309
+ return tensor
310
+
311
+ def trunc_normal_tf_(tensor, mean=0., std=1., a=-2., b=2.):
312
+ r"""Fills the input Tensor with values drawn from a truncated
313
+ normal distribution. The values are effectively drawn from the
314
+ normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)`
315
+ with values outside :math:`[a, b]` redrawn until they are within
316
+ the bounds. The method used for generating the random values works
317
+ best when :math:`a \leq \text{mean} \leq b`.
318
+
319
+ NOTE: this 'tf' variant behaves closer to Tensorflow / JAX impl where the
320
+ bounds [a, b] are applied when sampling the normal distribution with mean=0, std=1.0
321
+ and the result is subsquently scaled and shifted by the mean and std args.
322
+
323
+ Args:
324
+ tensor: an n-dimensional `torch.Tensor`
325
+ mean: the mean of the normal distribution
326
+ std: the standard deviation of the normal distribution
327
+ a: the minimum cutoff value
328
+ b: the maximum cutoff value
329
+ Examples:
330
+ >>> w = torch.empty(3, 5)
331
+ >>> nn.init.trunc_normal_(w)
332
+ """
333
+ with torch.no_grad():
334
+ _trunc_normal_(tensor, 0, 1.0, a, b)
335
+ tensor.mul_(std).add_(mean)
336
+ return tensor
337
+
338
 
339
  class NomicBertPreTrainedModel(PreTrainedModel):
340
  """An abstract class to handle weights initialization and
 
449
  if module.padding_idx is not None:
450
  nn.init.zeros_(module.weight[module.padding_idx])
451
 
452
+ def _ntuple(n):
453
+ def parse(x):
454
+ if isinstance(x, collections.abc.Iterable) and not isinstance(x, str):
455
+ return tuple(x)
456
+ return tuple(repeat(x, n))
457
+ return parse
458
+
459
+
460
+ to_1tuple = _ntuple(1)
461
+ to_2tuple = _ntuple(2)
462
+ to_3tuple = _ntuple(3)
463
+ to_4tuple = _ntuple(4)
464
+ to_ntuple = _ntuple
465
+
466
+
467
+ def get_2d_sincos_pos_embed(embed_dim, grid_size, add_cls_token=False):
468
+ """
469
+ Create 2D sin/cos positional embeddings.
470
+
471
+ Args:
472
+ embed_dim (`int`):
473
+ Embedding dimension.
474
+ grid_size (`int`):
475
+ The grid height and width.
476
+ add_cls_token (`bool`, *optional*, defaults to `False`):
477
+ Whether or not to add a classification (CLS) token.
478
+
479
+ Returns:
480
+ (`torch.FloatTensor` of shape (grid_size*grid_size, embed_dim) or (1+grid_size*grid_size, embed_dim): the
481
+ position embeddings (with or without classification token)
482
+ """
483
+ grid_h = np.arange(grid_size, dtype=np.float32)
484
+
485
+ grid_w = np.arange(grid_size, dtype=np.float32)
486
+ grid = np.meshgrid(grid_w, grid_h) # here w goes first
487
+ grid = np.stack(grid, axis=0)
488
+
489
+ grid = grid.reshape([2, 1, grid_size, grid_size])
490
+ pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
491
+ if add_cls_token:
492
+ pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0)
493
+ return pos_embed
494
+
495
+
496
+ def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
497
+ if embed_dim % 2 != 0:
498
+ raise ValueError("embed_dim must be even")
499
+
500
+ # use half of dimensions to encode grid_h
501
+ emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2)
502
+ emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2)
503
+
504
+ emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D)
505
+ return emb
506
+
507
+
508
+ def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
509
+ """
510
+ embed_dim: output dimension for each position pos: a list of positions to be encoded: size (M,) out: (M, D)
511
+ """
512
+ if embed_dim % 2 != 0:
513
+ raise ValueError("embed_dim must be even")
514
+
515
+ omega = np.arange(embed_dim // 2, dtype=float)
516
+ omega /= embed_dim / 2.0
517
+ omega = 1.0 / 10000**omega # (D/2,)
518
+
519
+ pos = pos.reshape(-1) # (M,)
520
+ out = np.einsum("m,d->md", pos, omega) # (M, D/2), outer product
521
+
522
+ emb_sin = np.sin(out) # (M, D/2)
523
+ emb_cos = np.cos(out) # (M, D/2)
524
+
525
+ emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D)
526
+ return emb
527
+
528
+ def ndgrid(*tensors) -> Tuple[torch.Tensor, ...]:
529
+ """generate N-D grid in dimension order.
530
+
531
+ The ndgrid function is like meshgrid except that the order of the first two input arguments are switched.
532
+
533
+ That is, the statement
534
+ [X1,X2,X3] = ndgrid(x1,x2,x3)
535
+
536
+ produces the same result as
537
+
538
+ [X2,X1,X3] = meshgrid(x2,x1,x3)
539
+
540
+ This naming is based on MATLAB, the purpose is to avoid confusion due to torch's change to make
541
+ torch.meshgrid behaviour move from matching ndgrid ('ij') indexing to numpy meshgrid defaults of ('xy').
542
+
543
+ """
544
+ try:
545
+ return torch.meshgrid(*tensors, indexing='ij')
546
+ except TypeError:
547
+ # old PyTorch < 1.10 will follow this path as it does not have indexing arg,
548
+ # the old behaviour of meshgrid was 'ij'
549
+ return torch.meshgrid(*tensors)
550
+
551
+ def build_fourier_pos_embed(
552
+ feat_shape: List[int],
553
+ bands: Optional[torch.Tensor] = None,
554
+ num_bands: int = 64,
555
+ max_res: int = 224,
556
+ temperature: float = 10000.,
557
+ linear_bands: bool = False,
558
+ include_grid: bool = False,
559
+ in_pixels: bool = True,
560
+ ref_feat_shape: Optional[List[int]] = None,
561
+ dtype: torch.dtype = torch.float32,
562
+ device: Optional[torch.device] = None,
563
+ ) -> List[torch.Tensor]:
564
+ """
565
+
566
+ Args:
567
+ feat_shape: Feature shape for embedding.
568
+ bands: Pre-calculated frequency bands.
569
+ num_bands: Number of frequency bands (determines output dim).
570
+ max_res: Maximum resolution for pixel based freq.
571
+ temperature: Temperature for non-pixel freq.
572
+ linear_bands: Linear band spacing for pixel based freq.
573
+ include_grid: Include the spatial grid in output.
574
+ in_pixels: Output in pixel freq.
575
+ ref_feat_shape: Reference feature shape for resize / fine-tune.
576
+ dtype: Output dtype.
577
+ device: Output device.
578
+
579
+ Returns:
580
+
581
+ """
582
+ if bands is None:
583
+ if in_pixels:
584
+ bands = pixel_freq_bands(
585
+ num_bands,
586
+ float(max_res),
587
+ linear_bands=linear_bands,
588
+ device=device,
589
+ )
590
+ else:
591
+ bands = freq_bands(
592
+ num_bands,
593
+ temperature=temperature,
594
+ step=1,
595
+ device=device,
596
+ )
597
+ else:
598
+ if device is None:
599
+ device = bands.device
600
+ if dtype is None:
601
+ dtype = bands.dtype
602
+
603
+ if in_pixels:
604
+ t = [torch.linspace(-1., 1., steps=s, device=device, dtype=torch.float32) for s in feat_shape]
605
+ else:
606
+ t = [torch.arange(s, device=device, dtype=torch.int64).to(torch.float32) for s in feat_shape]
607
+
608
+ if ref_feat_shape is not None:
609
+ # eva's scheme for resizing rope embeddings (ref shape = pretrain)
610
+ t = [x / f * r for x, f, r in zip(t, feat_shape, ref_feat_shape)]
611
+
612
+ grid = torch.stack(ndgrid(t), dim=-1)
613
+ grid = grid.unsqueeze(-1)
614
+ pos = grid * bands
615
+
616
+ pos_sin, pos_cos = pos.sin().to(dtype=dtype), pos.cos().to(dtype)
617
+ out = [grid, pos_sin, pos_cos] if include_grid else [pos_sin, pos_cos]
618
+ return out
619
+
620
+
621
+ def build_rotary_pos_embed(
622
+ feat_shape: List[int],
623
+ bands: Optional[torch.Tensor] = None,
624
+ dim: int = 64,
625
+ max_res: int = 224,
626
+ temperature: float = 10000.,
627
+ linear_bands: bool = False,
628
+ in_pixels: bool = True,
629
+ ref_feat_shape: Optional[List[int]] = None,
630
+ dtype: torch.dtype = torch.float32,
631
+ device: Optional[torch.device] = None,
632
+ ):
633
+ """
634
+
635
+ Args:
636
+ feat_shape: Spatial shape of the target tensor for embedding.
637
+ bands: Optional pre-generated frequency bands
638
+ dim: Output dimension of embedding tensor.
639
+ max_res: Maximum resolution for pixel mode.
640
+ temperature: Temperature (inv freq) for non-pixel mode
641
+ linear_bands: Linearly (instead of log) spaced bands for pixel mode
642
+ in_pixels: Pixel vs language (inv freq) mode.
643
+ dtype: Output dtype.
644
+ device: Output device.
645
+
646
+ Returns:
647
+
648
+ """
649
+ sin_emb, cos_emb = build_fourier_pos_embed(
650
+ feat_shape,
651
+ bands=bands,
652
+ num_bands=dim // 4,
653
+ max_res=max_res,
654
+ temperature=temperature,
655
+ linear_bands=linear_bands,
656
+ in_pixels=in_pixels,
657
+ ref_feat_shape=ref_feat_shape,
658
+ device=device,
659
+ dtype=dtype,
660
+ )
661
+ num_spatial_dim = 1
662
+ # this would be much nicer as a .numel() call to torch.Size(), but torchscript sucks
663
+ for x in feat_shape:
664
+ num_spatial_dim *= x
665
+ sin_emb = sin_emb.reshape(num_spatial_dim, -1).repeat_interleave(2, -1)
666
+ cos_emb = cos_emb.reshape(num_spatial_dim, -1).repeat_interleave(2, -1)
667
+ return sin_emb, cos_emb
668
+
669
+ def freq_bands(
670
+ num_bands: int,
671
+ temperature: float = 10000.,
672
+ step: int = 2,
673
+ device: Optional[torch.device] = None,
674
+ ) -> torch.Tensor:
675
+ exp = torch.arange(0, num_bands, step, dtype=torch.int64, device=device).to(torch.float32) / num_bands
676
+ bands = 1. / (temperature ** exp)
677
+ return bands
678
+
679
+
680
+ def pixel_freq_bands(
681
+ num_bands: int,
682
+ max_freq: float = 224.,
683
+ linear_bands: bool = True,
684
+ device: Optional[torch.device] = None,
685
+ ):
686
+ if linear_bands:
687
+ bands = torch.linspace(1.0, max_freq / 2, num_bands, dtype=torch.float32, device=device)
688
+ else:
689
+ bands = 2 ** torch.linspace(0, math.log(max_freq, 2) - 1, num_bands, dtype=torch.float32, device=device)
690
+ return bands * torch.pi
691
+
692
+ def rot(x):
693
+ return torch.stack([-x[..., 1::2], x[..., ::2]], -1).reshape(x.shape)
694
+
695
+ def apply_rot_embed_cat(x: torch.Tensor, emb):
696
+ sin_emb, cos_emb = emb.tensor_split(2, -1)
697
+ if sin_emb.ndim == 3:
698
+ return x * cos_emb.unsqueeze(1).expand_as(x) + rot(x) * sin_emb.unsqueeze(1).expand_as(x)
699
+ return x * cos_emb + rot(x) * sin_emb
700
+
701
+ # taken from https://github.com/huggingface/pytorch-image-models/blob/cb0e4391beedcc5ac3ae4bce16561b95c326f32c/timm/layers/pos_embed_sincos.py#L363
702
+ class NomicVisionRotaryEmbeddingCat(nn.Module):
703
+ """ Rotary position embedding w/ concatenatd sin & cos
704
+
705
+ The following impl/resources were referenced for this impl:
706
+ * https://github.com/lucidrains/vit-pytorch/blob/6f3a5fcf0bca1c5ec33a35ef48d97213709df4ba/vit_pytorch/rvt.py
707
+ * https://blog.eleuther.ai/rotary-embeddings/
708
+ """
709
+
710
+ def __init__(
711
+ self,
712
+ dim,
713
+ max_res=224,
714
+ temperature=10000,
715
+ in_pixels=True,
716
+ linear_bands: bool = False,
717
+ feat_shape: Optional[List[int]] = None,
718
+ ref_feat_shape: Optional[List[int]] = None,
719
+ ):
720
+ super().__init__()
721
+ self.dim = dim
722
+ self.max_res = max_res
723
+ self.temperature = temperature
724
+ self.in_pixels = in_pixels
725
+ self.feat_shape = feat_shape
726
+ self.ref_feat_shape = ref_feat_shape
727
+
728
+ if feat_shape is None:
729
+ # only cache bands
730
+ if in_pixels:
731
+ bands = pixel_freq_bands(
732
+ dim // 4,
733
+ float(max_res),
734
+ linear_bands=linear_bands,
735
+ )
736
+ else:
737
+ bands = freq_bands(
738
+ dim // 4,
739
+ temperature=temperature,
740
+ step=1,
741
+ )
742
+ self.register_buffer(
743
+ 'bands',
744
+ bands,
745
+ persistent=False,
746
+ )
747
+ self.pos_embed = None
748
+ else:
749
+ # cache full sin/cos embeddings if shape provided up front
750
+ embeds = build_rotary_pos_embed(
751
+ feat_shape=feat_shape,
752
+ dim=dim,
753
+ max_res=max_res,
754
+ linear_bands=linear_bands,
755
+ in_pixels=in_pixels,
756
+ ref_feat_shape=self.ref_feat_shape,
757
+ )
758
+ self.bands = None
759
+ self.register_buffer(
760
+ 'pos_embed',
761
+ torch.cat(embeds, -1),
762
+ persistent=False,
763
+ )
764
+
765
+ def get_embed(self, shape: Optional[List[int]] = None):
766
+ if self.bands is not None and shape is not None:
767
+ # rebuild embeddings every call, use if target shape changes
768
+ embeds = build_rotary_pos_embed(
769
+ shape,
770
+ self.bands,
771
+ in_pixels=self.in_pixels,
772
+ ref_feat_shape=self.ref_feat_shape,
773
+ )
774
+ return torch.cat(embeds, -1)
775
+ elif self.pos_embed is not None:
776
+ return self.pos_embed
777
+ else:
778
+ assert False, "get_embed() requires pre-computed pos_embed or valid shape w/ pre-computed bands"
779
+
780
+ def forward(self, x):
781
+ # assuming channel-first tensor where spatial dim are >= 2
782
+ pos_embed = self.get_embed(x.shape[2:])
783
+ return apply_rot_embed_cat(x, pos_embed)
784
+
785
+ class NomicVisionPatchEmbeddings(nn.Module):
786
+ def __init__(
787
+ self,
788
+ config,
789
+ ):
790
+ super().__init__()
791
+ img_size = _pair(config.img_size)
792
+ patch_size = _pair(config.patch_size)
793
+ self.img_size = img_size
794
+ self.patch_size = patch_size
795
+ self.grid_size = (img_size[0] // patch_size[0], img_size[1] // patch_size[1])
796
+ self.num_patches = self.grid_size[0] * self.grid_size[1]
797
+
798
+ self.proj = nn.Linear(
799
+ config.num_channels * patch_size[0] * patch_size[1], config.n_embd, bias=config.patch_embed_bias
800
+ )
801
+
802
+ self.learned_pos_embedding = False
803
+ self.sinusoidal_pos_embedding = False
804
+ self.no_embed_class = getattr(config, "no_embed_class", False)
805
+
806
+ self.cls_token = nn.Parameter(torch.zeros(1, 1, config.n_embd)) if not getattr(config, "no_cls_token", False) else None
807
+ if config.learned_pos_embedding:
808
+ # this is the default in DINO
809
+ self.learned_pos_embedding = True
810
+ # hack for timm dinov2 with registers
811
+ num_patches = self.num_patches if getattr(config, "register_tokens", 0) > 0 else self.num_patches + 1
812
+ self.pos_embed = nn.Parameter(torch.randn(1, num_patches, config.n_embd) * 0.02) if getattr(config, "use_pos_embed", True) else None
813
+ elif getattr(config, "sinusoidal_pos_embedding", False):
814
+ self.sinusoidal_pos_embedding = True
815
+ if getattr(config, "use_pos_embed", True):
816
+ self.pos_embed = nn.Parameter(torch.zeros(1, self.num_patches + 1, config.n_embd), requires_grad=False)
817
+ pos_embed = get_2d_sincos_pos_embed(config.n_embd, self.grid_size[0], add_cls_token=True)
818
+ self.pos_embed.data.copy_(torch.from_numpy(pos_embed).to(self.pos_embed))
819
+ else:
820
+ self.pos_embed = None
821
+ else:
822
+ self.pos_embed = nn.Parameter(torch.randn(1, self.num_patches + 1, config.n_embd) * 0.02) if getattr(config, "use_pos_embed", True) else None
823
+
824
+ if getattr(config, "register_tokens", 0) > 0:
825
+ self.reg_token = nn.Parameter(torch.randn(1, config.register_tokens, config.n_embd) * 0.02)
826
+ else:
827
+ self.reg_token = None
828
+
829
+ if config.mask_token:
830
+ self.mask_token = nn.Parameter(torch.zeros(1, config.n_embd))
831
+
832
+ self.patch_dropout = nn.Identity()
833
+
834
+ if getattr(config, "use_rotary_pos_emb", False):
835
+ ref_feat_shape = getattr(config, "ref_feat_shape", None)
836
+ ref_feat_shape = to_2tuple(ref_feat_shape) if ref_feat_shape is not None else None
837
+ self.rope = NomicVisionRotaryEmbeddingCat(
838
+ config.n_embd // config.n_head,
839
+ in_pixels=False,
840
+ feat_shape=self.grid_size,
841
+ ref_feat_shape=ref_feat_shape,
842
+ )
843
+ else:
844
+ self.rope = None
845
+
846
+
847
+ def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor:
848
+ """
849
+ This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher
850
+ resolution images.
851
+
852
+ Source:
853
+ https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174
854
+ """
855
+ num_patches = embeddings.shape[1] - 1
856
+ num_positions = self.pos_embed.shape[1] - 1
857
+ if num_patches == num_positions and height == width:
858
+ return self.pos_embed
859
+ class_pos_embed = self.pos_embed[:, 0]
860
+ patch_pos_embed = self.pos_embed[:, 1:]
861
+ dim = embeddings.shape[-1]
862
+ height = height // self.patch_size[0]
863
+ width = width // self.patch_size[1]
864
+ # we add a small number to avoid floating point error in the interpolation
865
+ # see discussion at https://github.com/facebookresearch/dino/issues/8
866
+ height, width = height + 0.1, width + 0.1
867
+ patch_pos_embed = patch_pos_embed.reshape(1, int(math.sqrt(num_positions)), int(math.sqrt(num_positions)), dim)
868
+ patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2)
869
+ patch_pos_embed = nn.functional.interpolate(
870
+ patch_pos_embed,
871
+ scale_factor=(height / math.sqrt(num_positions), width / math.sqrt(num_positions)),
872
+ mode="bicubic",
873
+ align_corners=False,
874
+ )
875
+ if int(height) != patch_pos_embed.shape[-2] or int(width) != patch_pos_embed.shape[-1]:
876
+ raise ValueError("Width or height does not match with the interpolated position embeddings")
877
+ patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
878
+ return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1)
879
+
880
+ def forward(self, x):
881
+ # deepspeed case where the input is in fp32
882
+ if x.dtype != self.proj.weight.dtype:
883
+ x = x.to(dtype=self.proj.weight.dtype)
884
+
885
+ _, _, height, width = x.shape
886
+ x = self.proj(
887
+ rearrange(
888
+ x,
889
+ "b c (h p1) (w p2) -> b h w (c p1 p2)",
890
+ p1=self.patch_size[0],
891
+ p2=self.patch_size[1],
892
+ )
893
+ )
894
+ embeddings = rearrange(x, "b h w c -> b (h w) c")
895
+
896
+ to_cat = []
897
+ if self.cls_token is not None:
898
+ if self.sinusoidal_pos_embedding:
899
+ cls_token = self.cls_token + self.pos_embed[:, 0]
900
+ cls_token = cls_token.expand(embeddings.shape[0], -1, -1)
901
+ to_cat += [cls_token]
902
+ else:
903
+ cls_token = self.cls_token.expand(embeddings.shape[0], 1, -1)
904
+ to_cat += [cls_token]
905
+
906
+ if self.reg_token is not None:
907
+ to_cat += [self.reg_token.expand(embeddings.shape[0], -1, -1)]
908
+
909
+ rot_pos_embed = self.rope.get_embed() if self.rope is not None else None
910
+
911
+ if self.no_embed_class:
912
+ if self.learned_pos_embedding:
913
+ embeddings = embeddings + self.interpolate_pos_encoding(embeddings, height, width)
914
+ else:
915
+ if self.pos_embed is not None:
916
+ embeddings = embeddings + self.pos_embed
917
+ if to_cat:
918
+ embeddings = torch.cat(to_cat + [embeddings], dim=1)
919
+ else:
920
+ if to_cat:
921
+ embeddings = torch.cat(to_cat + [embeddings], dim=1)
922
+ if self.learned_pos_embedding:
923
+ if self.pos_embed is not None:
924
+ embeddings = embeddings + self.interpolate_pos_encoding(embeddings, height, width)
925
+ else:
926
+ if self.pos_embed is not None:
927
+ embeddings = embeddings + self.pos_embed
928
+
929
+ embeddings = self.patch_dropout(embeddings)
930
+
931
+ return embeddings, rot_pos_embed
932
+
933
 
934
  class NomicBertEmbeddings(nn.Module):
935
  def __init__(self, config):
 
1014
  fused_bias_fc=True,
1015
  device=None,
1016
  dtype=None,
1017
+ norm_layer=False,
1018
  ):
1019
  super().__init__()
1020
  out_features = out_features if out_features is not None else in_features
1021
  hidden_features = hidden_features if hidden_features is not None else int(8 * in_features / 3)
1022
+ hidden_features = int((hidden_features + multiple_of - 1) // multiple_of * multiple_of)
1023
  self.return_residual = return_residual
1024
 
1025
  self.fc11 = nn.Linear(in_features, hidden_features, bias=bias1)
1026
  self.fc12 = nn.Linear(in_features, hidden_features, bias=bias1)
1027
  self.activation = activation
1028
  self.fc2 = nn.Linear(hidden_features, out_features, bias=bias2)
1029
+ self.norm = nn.LayerNorm(hidden_features) if norm_layer else nn.Identity()
1030
 
1031
  def forward(self, x):
1032
  y = self.fc11(x)
 
1035
  y = F.glu(torch.cat([y, gate], dim=-1), dim=-1)
1036
  else:
1037
  y = y * self.activation(gate)
1038
+
1039
+ # eva uses layer norm after the activation
1040
+ y = self.norm(y)
1041
+
1042
  y = self.fc2(y)
1043
  return y if not self.return_residual else (y, x)
1044
 
 
1312
  self.out_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=config.qkv_proj_bias)
1313
  self.causal = config.causal
1314
  self.drop = nn.Dropout(config.attn_pdrop)
1315
+ self.num_prefix_tokens = max(getattr(config, "register_tokens", 1), 1)
1316
 
1317
  def forward(
1318
  self,
 
1325
  is_padded_inputs: Optional[bool] = True,
1326
  cu_seqlens: Optional[torch.Tensor] = None,
1327
  max_seq_len: Optional[int] = None,
1328
+ rope: Optional[torch.Tensor] = None,
1329
  ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
1330
 
1331
  has_layer_past = past_key_value is not None
 
1348
 
1349
  if self.rotary_head_dim:
1350
  qkv = rearrange(qkv, "b h three s d -> b s three h d")
1351
+ elif rope is not None:
1352
+ q, k, v = qkv.permute(0, 3, 1, 2, 4).unbind(dim=-2)
1353
+ q = torch.cat([q[:, :, :self.num_prefix_tokens], apply_rot_embed_cat(q[:, :, self.num_prefix_tokens:], rope)], dim=2).type_as(q)
1354
+ k = torch.cat([k[:, :, :self.num_prefix_tokens], apply_rot_embed_cat(k[:, :, self.num_prefix_tokens:], rope)], dim=2).type_as(q)
1355
+
1356
+ qkv = torch.stack([q, k, v], dim=-2)
1357
+ qkv = rearrange(qkv, "b h s three d -> b s three h d")
1358
 
1359
  query, key, value = qkv[:, :, 0], qkv[:, :, 1], qkv[:, :, 2]
1360
 
 
1400
  bias2=config.mlp_fc2_bias,
1401
  activation=activation,
1402
  fused_bias_fc=config.fused_bias_fc,
1403
+ norm_layer=getattr(config, "norm_mlp", False),
1404
  )
1405
  else:
1406
  self.mlp = NomicBertMLP(
 
1430
  use_cache: Optional[bool] = False,
1431
  cu_seqlens: Optional[torch.Tensor] = None,
1432
  max_seq_len: Optional[int] = None,
1433
+ rope: Optional[torch.Tensor] = None,
1434
  ):
1435
  r"""Pass the input through the encoder layer.
1436
 
 
1451
  is_padded_inputs=is_padded_inputs,
1452
  cu_seqlens=cu_seqlens,
1453
  max_seq_len=max_seq_len,
1454
+ rope=rope,
1455
  )
1456
 
1457
  dropped = self.dropout2(hidden_states)
 
1468
  is_padded_inputs=is_padded_inputs,
1469
  cu_seqlens=cu_seqlens,
1470
  max_seq_len=max_seq_len,
1471
+ rope=rope,
1472
  )
1473
  hidden_states = self.norm1((self.dropout1(attn_outputs) + hidden_states).to(dtype=self.norm1.weight.dtype))
1474
  mlp_out = self.mlp(hidden_states)
 
1496
  output_hidden_states: Optional[bool] = None,
1497
  return_dict: Optional[bool] = None,
1498
  is_padded_inputs: Optional[bool] = True,
1499
+ rope: Optional[torch.Tensor] = None,
1500
  ):
1501
  """If subset_mask is not None, we only want output for the subset of the sequence.
1502
  This means that we only compute the last layer output for these tokens.
 
1521
  hidden_states2,
1522
  residual,
1523
  attention_mask,
1524
+ position_ids,
1525
+ past_key_values,
1526
+ is_padded_inputs,
1527
+ output_attentions,
1528
+ use_cache,
1529
  None,
1530
  None,
1531
+ rope,
1532
  # if you freeze ANY layers, you need `use_reentrant=False`
1533
  # https://github.com/huggingface/transformers/issues/21381
1534
  # https://discuss.pytorch.org/t/checkpoint-with-no-grad-requiring-inputs-problem/19117/7
 
1546
  is_padded_inputs,
1547
  output_attentions,
1548
  use_cache,
1549
+ rope=rope,
1550
  )
1551
  return hidden_states
1552
 
 
1806
  hidden_states=outputs.hidden_states,
1807
  attentions=outputs.attentions,
1808
  )
1809
+
1810
+ def hf_vit_config_to_vit_config(vit_config: ViTConfig) -> GPT2Config:
1811
+ return GPT2Config(
1812
+ n_embd=vit_config.hidden_size,
1813
+ n_layer=vit_config.num_hidden_layers,
1814
+ n_head=vit_config.num_attention_heads,
1815
+ n_inner=vit_config.intermediate_size,
1816
+ activation_function=vit_config.hidden_act,
1817
+ vocab_size=0, # no vocab since using patches
1818
+ n_positions=0, # No absolute position embedding
1819
+ resid_pdrop=0.0, # No dropout
1820
+ embd_pdrop=getattr(vit_config, "dropout", 0.0),
1821
+ attn_pdrop=vit_config.attention_probs_dropout_prob,
1822
+ layer_norm_epsilon=vit_config.layer_norm_eps,
1823
+ initializer_range=vit_config.initializer_range,
1824
+ bos_token_id=None,
1825
+ eos_token_id=None,
1826
+ # These are new arguments not in the original GPT2Config
1827
+ drop_path_rate=0.0,
1828
+ # Why is there double layer norm??
1829
+ prepre_layernom=False,
1830
+ layer_scale=False,
1831
+ layer_scale_init=None,
1832
+ img_size=vit_config.image_size,
1833
+ patch_size=vit_config.patch_size,
1834
+ num_channels=vit_config.num_channels,
1835
+ prenorm=True,
1836
+ parallel_block=False,
1837
+ parallel_block_tied_norm=False,
1838
+ rotary_emb_fraction=0,
1839
+ tie_word_embeddings=False,
1840
+ fused_dropout_add_ln=True,
1841
+ fused_bias_fc=True,
1842
+ patch_embed_bias=True,
1843
+ use_flash_attn=True,
1844
+ qkv_proj_bias=True,
1845
+ mlp_fc1_bias=getattr(vit_config, "mlp_fc1_bias", True),
1846
+ mlp_fc2_bias=getattr(vit_config, "mlp_fc2_bias", True),
1847
+ use_rms_norm=False,
1848
+ causal=False,
1849
+ hidden_features_scaling_factor=1.0,
1850
+ mask_token=False,
1851
+ learned_pos_embedding=False,
1852
+ patch_dropout=0,
1853
+ sinusoidal_pos_embedding=vit_config.model_type == "vit_mae"
1854
+ )
1855
+
1856
+
1857
+ class NomicAttentionPooling(nn.Module):
1858
+ def __init__(
1859
+ self,
1860
+ config
1861
+ ):
1862
+ super().__init__()
1863
+ self.embed_dim = config.n_embd
1864
+ self.use_flash_attn = config.use_flash_attn
1865
+ self.fused_bias_fc = config.fused_bias_fc
1866
+
1867
+ self.num_heads = config.n_head
1868
+ self.num_heads_kv = config.num_heads_kv if getattr(config, "num_heads_kv", None) is not None else self.num_heads
1869
+ assert self.embed_dim % self.num_heads == 0, "embed_dim must be divisible by num_heads"
1870
+ self.head_dim = self.embed_dim // self.num_heads
1871
+ # we don't really support mqa / gqa for now
1872
+ kv_dim = 2 * self.head_dim * self.num_heads_kv
1873
+
1874
+ self.register_buffer(
1875
+ "norm_factor",
1876
+ torch.sqrt(torch.tensor(self.head_dim, dtype=torch.float32)).to(torch.get_default_dtype()),
1877
+ persistent=False,
1878
+ )
1879
+
1880
+ self.Wq = nn.Linear(self.embed_dim, self.embed_dim, bias=config.qkv_proj_bias)
1881
+ self.Wkv = nn.Linear(self.embed_dim, kv_dim, bias=config.qkv_proj_bias)
1882
+
1883
+ self.latent = nn.Parameter(torch.zeros(1, 1, self.embed_dim))
1884
+
1885
+ self.out_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=config.qkv_proj_bias)
1886
+ self.causal = config.causal
1887
+ self.drop = nn.Dropout(config.attn_pdrop)
1888
+
1889
+ def init_weights(self):
1890
+ trunc_normal_tf_(self.latent, std=self.embed_dim ** -0.5)
1891
+
1892
+ def forward(
1893
+ self,
1894
+ kv,
1895
+ attention_mask=None,
1896
+ cu_seqlens_k=None,
1897
+ max_seqlen_k=None,
1898
+ is_padded_inputs: Optional[bool] = True,
1899
+ output_attentions: bool = False,
1900
+ ):
1901
+ """Implements the multihead softmax attention.
1902
+ Arguments
1903
+ ---------
1904
+ q: The tensor containing the query. (B, Sq, H, D)
1905
+ kv: The tensor containing the key and value. (B, Sk, 2, H_k, D)
1906
+ causal: if passed, will override self.causal
1907
+ cu_seqlens: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
1908
+ of the sequences in the batch, used to index into q.
1909
+ max_seqlen: int. Maximum sequence length in the batch of q.
1910
+ cu_seqlens_k: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
1911
+ of the sequences in the batch, used to index into kv.
1912
+ max_seqlen_k: int. Maximum sequence length in the batch of k and v.
1913
+ """
1914
+ q_latent = self.latent.expand(kv.size(0), -1, -1)
1915
+ q = self.Wq(q_latent)
1916
+ bsz, q_len, h_size = q.shape
1917
+ kv = self.Wkv(kv)
1918
+ query = rearrange(q, "... (h d) -> ... h d", d=self.head_dim)
1919
+ kv = rearrange(kv, "... (two hkv d) -> ... two hkv d", two=2, d=self.head_dim)
1920
+
1921
+ key, value = kv[:, :, 0], kv[:, :, 1]
1922
+
1923
+ query = query.permute(0, 2, 1, 3)
1924
+ key = key.permute(0, 2, 1, 3)
1925
+ value = value.permute(0, 2, 1, 3)
1926
+
1927
+ attention_scores = torch.matmul(query, key.transpose(-1, -2)) / self.norm_factor
1928
+ if attention_mask is not None:
1929
+ attention_scores = attention_scores + attention_mask
1930
+
1931
+ attentions_probs = F.softmax(attention_scores, dim=-1)
1932
+ attentions_probs = self.drop(attentions_probs)
1933
+
1934
+ attn_output = torch.matmul(attentions_probs, value)
1935
+ attn_output = rearrange(attn_output.permute(0, 2, 1, 3), "... h d -> ... (h d)")
1936
+
1937
+ attn_output = self.out_proj(attn_output)
1938
+
1939
+ return attn_output
1940
+
1941
+
1942
+ class NomicMultiHeadAttentionPooling(nn.Module):
1943
+ def __init__(
1944
+ self,
1945
+ config,
1946
+ ):
1947
+ super().__init__()
1948
+ self.prenorm = config.prenorm
1949
+ self.fused_dropout_add_ln = config.fused_dropout_add_ln
1950
+
1951
+ self.attn = NomicAttentionPooling(config)
1952
+ activation = (
1953
+ F.sigmoid
1954
+ if config.activation_function == "glu"
1955
+ else (F.silu if config.activation_function == "swiglu" else F.gelu)
1956
+ )
1957
+ if config.activation_function in ["glu", "swiglu", "geglu"]:
1958
+ self.mlp = NomciBertGatedMLP(
1959
+ config.n_embd,
1960
+ hidden_features=config.n_inner,
1961
+ bias1=config.mlp_fc1_bias,
1962
+ bias2=config.mlp_fc2_bias,
1963
+ activation=activation,
1964
+ fused_bias_fc=config.fused_bias_fc,
1965
+ )
1966
+ else:
1967
+ self.mlp = NomicBertMLP(
1968
+ config.n_embd,
1969
+ hidden_features=config.n_inner,
1970
+ bias1=config.mlp_fc1_bias,
1971
+ bias2=config.mlp_fc2_bias,
1972
+ activation=activation,
1973
+ fused_bias_fc=config.fused_bias_fc,
1974
+ )
1975
+
1976
+ self.dropout1 = nn.Dropout(config.resid_pdrop)
1977
+ self.norm1 = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)
1978
+ self.dropout2 = nn.Dropout(config.resid_pdrop)
1979
+
1980
+ def forward(
1981
+ self,
1982
+ hidden_states: torch.Tensor,
1983
+ attention_mask: Optional[torch.Tensor] = None,
1984
+ ):
1985
+ r"""Pass the input through the encoder layer.
1986
+
1987
+ Args:
1988
+ hidden_states: the sequence to the encoder layer (required).
1989
+ residual: if postnorm, residual=None, If prenorm, hidden_states = Attn/MLP(LN(residual))
1990
+ mixer_subset: for cross-attention only. If not None, will take a subset of x
1991
+ before applying the query projection. Useful for e.g., ViT where we only care
1992
+ about the CLS token in the last layer.
1993
+ """
1994
+
1995
+ attn_outputs = self.attn(
1996
+ hidden_states,
1997
+ attention_mask=attention_mask,
1998
+ )
1999
+
2000
+ normed = self.norm1(attn_outputs)
2001
+ hidden_states = hidden_states + self.mlp(normed)
2002
+
2003
+ return hidden_states
2004
+
2005
+ class NomicVisionPreTrainedModel(PreTrainedModel):
2006
+ """An abstract class to handle weights initialization and
2007
+ a simple interface for dowloading and loading pretrained models.
2008
+ """
2009
+
2010
+ config_class = NomicBertConfig
2011
+ base_model_prefix = "model"
2012
+ supports_gradient_checkpointing = True
2013
+ _no_split_modules = ["Block"]
2014
+ _skip_keys_device_placement = "past_key_values"
2015
+
2016
+ def __init__(self, config, *inputs, **kwargs):
2017
+ super().__init__(config)
2018
+ if not isinstance(config, GPT2Config):
2019
+ raise ValueError(
2020
+ "Parameter config in `{}(config)` should be an instance of class `GPT2Config`. "
2021
+ "To create a model from a Google pretrained model use "
2022
+ "`model = {}.from_pretrained(PRETRAINED_MODEL_NAME)`".format(
2023
+ self.__class__.__name__, self.__class__.__name__
2024
+ )
2025
+ )
2026
+ self.config = config
2027
+
2028
+ class NomicVisionModel(NomicVisionPreTrainedModel):
2029
+ def __init__(self, config):
2030
+ super().__init__(config)
2031
+
2032
+ self.embeddings = NomicVisionPatchEmbeddings(config)
2033
+ self.layers = nn.ModuleList([NomicBertBlock(config) for _ in range(config.n_layer)])
2034
+
2035
+ self.selector = NomicMultiHeadAttentionPooling(config)
2036
+
2037
+ self.global_pool = getattr(config, "global_pool", None)
2038
+ self.num_prefix_tokens = (1 if not getattr(config, "no_cls_token", False) else 0) + getattr(config, "register_tokens", 0)
2039
+
2040
+ self.apply(partial(_init_weights, initializer_range=config.initializer_range))
2041
+
2042
+ def forward(
2043
+ self,
2044
+ pixel_values,
2045
+ attention_mask=None,
2046
+ position_ids=None,
2047
+ token_type_ids=None,
2048
+ return_dict=None,
2049
+ matryoshka_dim=None,
2050
+ ):
2051
+ embeddings, rope = self.embeddings(pixel_values)
2052
+
2053
+ original_dtype = embeddings.dtype
2054
+
2055
+ hidden_states = embeddings
2056
+ # unused but easier to pass to gradient checkpointing as words
2057
+ residual = None
2058
+ for layer in self.layers:
2059
+ # need to pass none for backwards compatability
2060
+ hidden_states, _, residual = layer(hidden_states, None, residual=residual, is_padded_inputs=False, rope=rope)
2061
+
2062
+ hidden_states = hidden_states + residual
2063
+ if self.global_pool == "avg":
2064
+ hidden_states = hidden_states[:, self.num_prefix_tokens:].mean(dim=1)
2065
+
2066
+ pooled_output = self.selector(hidden_states)
2067
+
2068
+ return BaseModelOutputWithPast(
2069
+ last_hidden_state=pooled_output,
2070
+ hidden_states=hidden_states,
2071
+ )