dinalt committed
Commit 35354b2
1 Parent(s): fc26b50

Update modelling_walsh.py


- Added support for inference cache (see the usage sketch below).
- Refactored common code in attention.
- Removed unused code (fragments from another project).
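For reference, the cache path is exercised through the standard transformers generation API. The sketch below is not part of this commit: the repository id is a placeholder, and it assumes the checkpoint ships this custom code so it can be loaded with trust_remote_code=True; use_cache=True routes past_key_values through the new cache support. Since attention masking is not yet reimplemented (see the TODO in prepare_inputs_for_generation), it assumes a single unpadded sequence.

    import torch
    from transformers import AutoModelForCausalLM, AutoTokenizer

    repo_id = "your-namespace/walsh-causal-v1"  # placeholder; substitute the actual model repo
    tokenizer = AutoTokenizer.from_pretrained(repo_id)
    model = AutoModelForCausalLM.from_pretrained(repo_id, trust_remote_code=True)
    model.eval()

    inputs = tokenizer("Once upon a time", return_tensors="pt")
    with torch.no_grad():
        # With use_cache=True, generate() feeds past_key_values back into forward(),
        # so after the prompt only the newly sampled token passes through the layers.
        output_ids = model.generate(**inputs, max_new_tokens=32, use_cache=True)
    print(tokenizer.decode(output_ids[0], skip_special_tokens=True))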

Files changed (1)
  1. modelling_walsh.py +369 -296
modelling_walsh.py CHANGED
@@ -1,5 +1,5 @@
1
  # See: https://huggingface.co/docs/transformers/custom_models
2
- from typing import Optional, Tuple, Union
3
  import math
4
  import copy
5
  import sys
@@ -9,7 +9,7 @@ import torch
9
  from torch import nn, Tensor
10
  import torch.nn.init as init
11
  from torch.nn import functional as F
12
- from transformers.modeling_outputs import CausalLMOutput
13
  from transformers import (
14
  PreTrainedModel,
15
  PretrainedConfig,
@@ -18,6 +18,10 @@ from transformers import (
18
  AutoModelForCausalLM,
19
  )
20
21
  from transformers.utils import (
22
  is_flash_attn_2_available,
23
  is_flash_attn_greater_or_equal_2_10,
@@ -26,6 +30,8 @@ from transformers.utils import (
26
  if is_flash_attn_2_available():
27
  from flash_attn import flash_attn_qkvpacked_func, flash_attn_func
28
29
  # The model type string to bind.
30
  model_type = "walsh-causal-v1"
31
 
@@ -78,6 +84,10 @@ class Config(PretrainedConfig):
78
  layer_args=dict(),
79
  embedding_args=dict(),
80
  output_proj_args=dict(),
81
 
82
  **kwargs,
83
  ):
@@ -113,6 +123,10 @@ class Config(PretrainedConfig):
113
  self.layer_args = layer_args
114
  self.embedding_args = embedding_args
115
  self.output_proj_args = output_proj_args
116
 
117
  super().__init__(**kwargs)
118
 
@@ -204,6 +218,8 @@ class HFCausalModel(PreTrainedModel):
204
  _no_split_modules = ["DeepNetLayer"]
205
  _supports_flash_attn_2 = True
206
  _supports_sdpa = True
207
 
208
  def __init__(self, config):
209
  super().__init__(config)
@@ -221,40 +237,143 @@ class HFCausalModel(PreTrainedModel):
221
  token_type_ids: Optional[torch.LongTensor] = None,
222
  position_ids: Optional[torch.LongTensor] = None,
223
  labels: Optional[torch.LongTensor] = None,
224
  output_attentions: Optional[bool] = None,
225
  output_hidden_states: Optional[bool] = None,
226
  return_dict: Optional[bool] = None,
227
  **kwargs,
228
  ) -> (Tensor, dict[str, Tensor]):
229
230
  if self.gradient_checkpointing and self.training:
231
  gradient_checkpointing_func = self._gradient_checkpointing_func
232
  else:
233
  gradient_checkpointing_func = None
 
234
 
235
- logits, attentions = self.transformer_head(
236
  input_ids=input_ids,
237
- need_weights=output_attentions,
 
238
  gradient_checkpointing_func=gradient_checkpointing_func,
239
  )
240
 
241
  # Compute loss.
242
  if labels is not None:
243
  loss = self.loss_function(logits=logits, labels=labels, input_ids=input_ids)
244
  else:
245
  loss = None
246
 
247
- return CausalLMOutput(loss=loss, logits=logits, attentions=attentions)
248
-
249
- # Needed for generate() method.
250
- def prepare_inputs_for_generation(self, input_ids, **kwargs):
251
- attention_mask = kwargs.get("attention_mask", None)
252
- model_inputs = {
253
- "input_ids": input_ids,
254
- "attention_mask": attention_mask,
255
- }
256
  return model_inputs
257
258
  def _make_embedding(self, config):
259
  embedding_cls = get_dynamic_class(config.embdding_cls)
260
  return embedding_cls(config.vocab_size, self.d_model, config.pad_index, **config.embedding_args)
@@ -278,7 +397,7 @@ class HFCausalModel(PreTrainedModel):
278
  norm_cls = get_dynamic_class(config.norm_cls)
279
  return norm_cls(self.d_model)
280
 
281
- def _make_self_attention(self, config):
282
  attention_cls = get_dynamic_class(config.attention_cls)
283
  # Map HF _attn_implementation to attn_type
284
  match config._attn_implementation:
@@ -299,28 +418,32 @@ class HFCausalModel(PreTrainedModel):
299
  d_model=self.d_model,
300
  num_heads=config.num_attention_heads,
301
  attn_type=attn_type,
302
  **config.attention_args,
303
  )
304
 
305
- def _make_feedforward(self, config):
306
  feedforward_cls = get_dynamic_class(config.feedforward_cls)
307
  return feedforward_cls(
308
  d_model=self.d_model,
309
  feedforward_dim=config.dim_feedforward,
310
  dropout=config.dropout,
311
  activation=self._make_activation(config),
 
312
  **config.feedforward_args,
313
  )
314
 
315
- def _make_layer(self, config):
316
  layer_cls = get_dynamic_class(config.layer_cls)
317
  return layer_cls(
318
  d_model=self.d_model,
319
  dropout=self._make_dropout(config),
320
- attention=self._make_self_attention(config),
321
- feedforward=self._make_feedforward(config),
322
  norm1=self._make_norm(config),
323
  norm2=self._make_norm(config),
 
324
  **config.layer_args,
325
  )
326
 
@@ -328,7 +451,7 @@ class HFCausalModel(PreTrainedModel):
328
  layer_stack_cls = get_dynamic_class(config.layer_stack_cls)
329
  return layer_stack_cls(
330
  layers=nn.ModuleList([
331
- self._make_layer(config) for _ in range(config.num_hidden_layers)
332
  ]),
333
  **config.layer_stack_args,
334
  )
@@ -364,43 +487,35 @@ class Transformer(nn.Module):
364
  self.sqrt_d_model = d_model**0.5
365
  self.reset_parameters()
366
 
367
- def forward(self, input_ids, need_weights, gradient_checkpointing_func):
368
- x = self.positional_encoder(self.embedding(input_ids) * self.sqrt_d_model)
369
-
370
- x, attentions = self.layer_stack(
371
- x,
372
- need_weights,
373
- gradient_checkpointing_func,
374
  )
375
 
376
- # Translate output embedding ot logits.
377
- logits = self.output_projection(x)
378
- return logits, attentions
 
379
 
380
  def reset_parameters(self):
381
  init.xavier_uniform_(self.output_projection.weight)
382
  init.constant_(self.output_projection.bias, 0.)
383
  init.normal_(self.embedding.weight, std=self.d_model**-0.5)
384
 
385
- # A vanilla positional encoder
386
- class PositionalEncoder(nn.Module):
387
- def __init__(self, d_embed, max_seq):
388
- super().__init__()
389
- self.d_embed = d_embed
390
- self.max_seq = max_seq
391
-
392
- weight = torch.zeros(max_seq, d_embed)
393
- position = torch.arange(0, max_seq, dtype=torch.float).unsqueeze(1)
394
- div_term = torch.exp(torch.arange(0, d_embed, 2).float() * (-math.log(10000.0) / d_embed))
395
- weight[:, 0::2] = torch.sin(position * div_term)
396
- weight[:, 1::2] = torch.cos(position * div_term)
397
- weight = weight.unsqueeze(0)
398
- self.register_buffer('weight', weight)
399
-
400
- def forward(self, x):
401
- seq_len = x.size(-2)
402
- return x + self.weight[:, :seq_len]
403
-
404
  # Converts a torch array of integers into their equivalent binary codes.
405
  def binary_tensor(x, bits):
406
  mask = 2**torch.arange(bits).to(x.device, x.dtype)
@@ -472,7 +587,7 @@ class RSWalshPositionalEncoder(nn.Module):
472
  # walsh = (hadamard_walsh_matrix(k)[:bits,:d_embed] -0.5) * self.gain
473
  self.register_buffer('walsh', walsh, persistent=False)
474
 
475
- def forward(self, x):
476
  seq_len = x.size(-2)
477
 
478
  # Get sequence of binary codes...
@@ -486,6 +601,12 @@ class RSWalshPositionalEncoder(nn.Module):
486
  shift = torch.randint(self.max_seq - seq_len + 1, (1,)).item()
487
  seq = self.binary_code[shift:seq_len + shift,:]
488
 
489
  # Disable shifting when not training. This does not appear to change the evaluation loss, but
490
  # it does makes predictions easier to analyse when the attention weights are not shifting with each step.
491
  else:
@@ -508,25 +629,58 @@ class TransformerLayerStack(nn.Module):
508
  super().__init__()
509
  self.layers = layers
510
 
511
- def forward(self, x, need_weights, gradient_checkpointing_func=None):
512
- attentions = []
513
  for layer in self.layers:
514
  if gradient_checkpointing_func is not None:
515
- x, attention_weights = gradient_checkpointing_func(
516
  layer.__call__,
517
- x,
518
- need_weights,
519
- use_reentrant=False
520
  )
521
  else:
522
- x, attention_weights = layer(x, need_weights=need_weights)
523
- if need_weights:
524
- attentions.append(attention_weights)
525
 
526
- return x, attentions
527
 
528
  # DeepNet: Scaling Transformers to 1,000 Layers
529
  # https://arxiv.org/abs/2203.00555
 
530
  class DeepnetLayer(nn.Module):
531
  def __init__(
532
  self,
@@ -536,6 +690,7 @@ class DeepnetLayer(nn.Module):
536
  norm1,
537
  norm2,
538
  dropout,
 
539
  alpha=1.0,
540
  ):
541
  super().__init__()
@@ -547,27 +702,45 @@ class DeepnetLayer(nn.Module):
547
  self.dropout = dropout
548
  # Deepnet alpha
549
  self.alpha = alpha
 
550
 
551
- def forward(self, x, need_weights=False):
552
  # Keep input as residual
553
- residual = x * self.alpha
554
 
555
  # Compute attention
556
- x, attention_weights = self.attention(x, need_weights)
557
 
558
  # Add attention with residual and normalize.
559
- x = self.norm1(residual + self.dropout(x))
560
 
561
  # Keep output as next residual.
562
- residual = x * self.alpha
563
 
564
  # Pass through feedforward network.
565
- x = self.feedforward(x)
566
 
567
  # Combine residual and ff output, then normalize again.
568
- x = self.norm2(residual + self.dropout(x))
569
 
570
- return x, attention_weights
571
 
572
  # A vanilla MLP transfomer layer.
573
  class FeedforwardLayer(nn.Module):
@@ -576,6 +749,7 @@ class FeedforwardLayer(nn.Module):
576
  d_model: int,
577
  feedforward_dim: int,
578
  dropout,
 
579
  activation=nn.ReLU(),
580
  beta=1.0,
581
  bias=True,
@@ -598,41 +772,6 @@ class FeedforwardLayer(nn.Module):
598
  init.constant_(self.linear1.bias, 0.)
599
  init.constant_(self.linear2.bias, 0.)
600
 
601
- # GLU Variants Improve Transformer
602
- # https://arxiv.org/pdf/2002.05202v1.pdf
603
- class SwiGLUFeedforwardLayer(nn.Module):
604
- def __init__(
605
- self,
606
- d_model,
607
- d_feedforward,
608
- beta=1.0,
609
- dropout=0.1
610
- ):
611
- super().__init__()
612
- self.d_model = d_model
613
- self.d_feedforward = d_feedforward
614
- self.beta = 1.0
615
-
616
- self.linear1 = nn.Linear(self.d_model, self.d_feedforward * 2, bias=False)
617
- self.linear2 = nn.Linear(self.d_feedforward, self.d_model, bias=False)
618
- self.dropout = nn.Dropout(dropout)
619
- self.reset_parameters()
620
-
621
- def forward(self, x):
622
- x, gate = self.linear1(x).chunk(2, dim=-1)
623
- x = x * F.silu(gate)
624
- x = self.dropout(x)
625
- x = self.linear2(x)
626
- return x
627
-
628
- def reset_parameters(self):
629
- # Deepnet initialization
630
- # https://arxiv.org/pdf/2203.00555.pdf
631
- w, g = self.linear1.weight.chunk(2, dim=0)
632
- init.xavier_uniform_(w, gain=self.beta)
633
- init.xavier_uniform_(g, gain=self.beta)
634
- init.xavier_uniform_(self.linear2.weight, gain=self.beta)
635
-
636
  class CausalSelfAttention(nn.Module):
637
  def __init__(
638
  self,
@@ -643,6 +782,8 @@ class CausalSelfAttention(nn.Module):
643
  # torch: Use pytorch "scaled_dot_product_attention()"; faster; generally good compatibility; does not support returning attn weights.
644
  # flash2: Use Flash-Attention2 implementation; fastest; limited to int16 and bfloat16 types; least memory usage.
645
  attn_type,
646
  beta=1.0,
647
  dropout=0.1,
648
  ):
@@ -651,6 +792,8 @@ class CausalSelfAttention(nn.Module):
651
  self.num_heads = num_heads
652
  self.beta = beta
653
  self.attn_type = attn_type
654
 
655
  assert d_model % num_heads == 0, "d_model must be evenly divisible by num_heads"
656
 
@@ -681,29 +824,56 @@ class CausalSelfAttention(nn.Module):
681
  init.constant_(self.in_proj.bias, 0.)
682
  init.constant_(self.output_linear.bias, 0.)
683
 
684
- def project_input(self, qkv):
685
- proj = self.in_proj(qkv)
686
- return proj.chunk(chunks=3, dim=-1)
687
-
688
- def forward(self, qkv, need_weights):
689
- if self.attn_type == "flash2":
690
- return self.flash2_forward(qkv)
691
-
692
- # qkv: (batch_size, seq_len, d_embed)
693
  batch_size, seq_len, d_embed = qkv.shape
694
 
695
- # Feed the inputs through the K, Q, V matrices.
696
- query, key, value = self.project_input(qkv)
697
-
698
  # Split projections into multiple heads and swap position of sequence / heads dimension
699
  query = query.view(batch_size, seq_len, self.num_heads, self.d_head).transpose(1, 2)
700
  key = key.view(batch_size, seq_len, self.num_heads, self.d_head).transpose(1, 2)
701
  value = value.view(batch_size, seq_len, self.num_heads, self.d_head).transpose(1, 2)
702
 
703
  # Default to returning empty attention weights.
704
- attention_weights = None
705
 
706
- if self.attn_type == "torch":
707
  # This context manager can be used to force which implementation to use.
708
  #with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=False, enable_mem_efficient=False):
709
  attended_values = F.scaled_dot_product_attention(
@@ -712,7 +882,7 @@ class CausalSelfAttention(nn.Module):
712
  value,
713
  attn_mask=None,
714
  dropout_p=self.dropout.p if self.training else 0.0,
715
- is_causal=True,
716
  scale=self.dot_product_scale
717
  )
718
  # "native" scaled-dot-product attention implementation.
@@ -721,44 +891,57 @@ class CausalSelfAttention(nn.Module):
721
  scores = torch.matmul(query, key.transpose(-2, -1)) * self.dot_product_scale
722
 
723
  # Mask future positions from the past
724
- scores.masked_fill_(
725
- torch.tril(
726
- torch.ones(seq_len, seq_len, dtype=torch.bool, device=qkv.device),
727
- diagonal=0,
728
- ).logical_not(),
729
- float('-inf'),
730
- )
 
731
 
732
  # Calculate the attention weights; avoid NANs that might emerge from zeros in softmax's denominator
733
- attention_weights = self.dropout(torch.softmax(scores, dim=-1).clamp(min=1e-10))
734
  del scores
735
 
736
  # Use the attention weights to get a weighted combination of value vectors
737
- attended_values = torch.matmul(attention_weights, value)
738
- if not need_weights:
739
- del attention_weights
740
- attention_weights = None
741
 
742
  # Concatenate attention heads and project to original embedding size using the output linear layer
743
  attended_values = attended_values.transpose(1, 2).contiguous().view(batch_size, seq_len, d_embed)
744
 
745
  # Project the concatenated output through the output matrix.
746
  attended_values = self.output_linear(attended_values)
747
- return attended_values, attention_weights
748
-
749
750
  batch_size, seq_len, d_embed = qkv.shape
751
 
752
  # Feed the inputs through the K, Q, V matrices.
753
  # query : (batch_size, seq_len, d_model)
754
  # qkv : (batch_size, seq_len, 3, num_heads, d_kq)
755
  qkv = self.in_proj(qkv).unflatten(
756
  -1,
757
  (3, self.num_heads, self.d_head)
758
  )
759
-
760
  attended_values = flash_attn_qkvpacked_func(
761
- qkv.bfloat16(),
762
  dropout_p=self.dropout.p if self.training else 0.0,
763
  softmax_scale=self.dot_product_scale,
764
  causal=True,
@@ -770,180 +953,70 @@ class CausalSelfAttention(nn.Module):
770
 
771
  # Project the concatenated output through the output matrix.
772
  attended_values = self.output_linear(attended_values)
773
- return attended_values, None
774
-
775
- # Attention layer with ALiBi relative positional encoding
776
- # TRAIN SHORT, TEST LONG: ATTENTION WITH LINEAR BIASES ENABLES INPUT LENGTH EXTRAPOLATION
777
- # https://arxiv.org/pdf/2108.12409.pdf
778
- def alibi_biases(query_len, key_len, device='cpu'):
779
- x = torch.arange(key_len, device=device)[None, :]
780
- y = torch.arange(query_len, device=device)[:, None]
781
- return x - y
782
 
783
- class CausalAlibiAttention(nn.Module):
784
- def __init__(
 
785
  self,
786
- d_model,
787
- num_heads,
788
- beta=1.0,
789
- dropout=0.1,
790
- # values:
791
- # native: Use local impementation; slowest option; good for debugging; useful when experimenting with non-standard stuff.
792
- # torch: Use pytorch "scaled_dot_product_attention()"; faster; generally good compatibility; does not support returning attn weights.
793
- # flash2: Use Flash-Attention2 implementation; fastest; limited to int16 and bfloat16 types; can't train Alibi weights; least memory usage.
794
- # Note: You can perform initial training with "torch," then switch to "flash2," after the Alibi weights have settled.
795
- window_size=None,
796
- attn_type="native",
797
- freeze_alibi=True,
798
  ):
799
- super().__init__()
800
- self.d_model = d_model
801
- self.num_heads = num_heads
802
- self.beta = beta
803
- self.attn_type = attn_type
804
-
805
- assert d_model % num_heads == 0, "d_model must be evenly divisible by num_heads"
806
-
807
- # The dimension of each head.
808
- self.d_head = d_model // num_heads
809
-
810
- # We scale the attention scores by the inverse-square-root of the head dimension
811
- # this shifts the temerature of softmax.
812
- self.dot_product_scale = 1.0 / math.sqrt(self.d_head)
813
-
814
- self.in_proj = nn.Parameter(torch.empty(3 * self.d_model, self.d_model))
815
- self.output_linear = nn.Linear(self.d_model, self.d_model, bias=False)
816
-
817
- if window_size is not None:
818
- self.window_size=(window_size, -1)
819
- else:
820
- self.window_size = (-1, -1)
821
-
822
- self.dropout = nn.Dropout(dropout)
823
-
824
- # This generates the original slope distribution from the paper.
825
- # Observations with trainable slopes suggest that the high half of the slopes shift
826
- # towards / past 1.0 and the low half approach zero or even go slightly negative.
827
- # alibi_slopes = 1.0 / torch.logspace(1, 8, self.num_heads, base=2, dtype=torch.float)
828
-
829
- # These appear to work better, as initial values, in practice.
830
- alibi_slopes = 1.0 / torch.logspace(0, 7, self.num_heads, base=2, dtype=torch.float)
831
-
832
- # If not trainable, it can improve performance somewhat if the low half are set to zero. Apparently
833
- # making roughly half of the slopes position-agnostic is somehow closer to optimal?
834
- # alibi_slopes.masked_fill_(torch.where(torch.arange(0, self.num_heads) >= (self.num_heads / 2), True, False), 0)
835
-
836
- self.alibi_slopes = nn.Parameter(alibi_slopes)
837
-
838
- # Optionally, allow/disallow training of ALiBi slopes.
839
- self.alibi_slopes.requires_grad = (not freeze_alibi)
840
- self.reset_parameters()
841
-
842
- def extra_repr(self) -> str:
843
- return f'd_model={self.d_model}, num_heads={self.num_heads}, beta={self.beta}, attn_type={self.attn_type}, window_size={self.window_size}, dropout={self.dropout}'
844
-
845
- def reset_parameters(self):
846
- # Deepnet initialization
847
- # https://arxiv.org/pdf/2203.00555.pdf
848
-
849
- q, k, v = self.in_proj.chunk(3)
850
- init.xavier_uniform_(q, gain=1.0)
851
- init.xavier_uniform_(k, gain=1.0)
852
- init.xavier_uniform_(v, gain=self.beta)
853
- init.xavier_uniform_(self.output_linear.weight, gain=self.beta)
854
-
855
- def project_input(self, qkv):
856
- proj = F.linear(qkv, self.in_proj)
857
- return proj.chunk(chunks=3, dim=-1)
858
-
859
- def forward(self, qkv, need_weights):
860
- if self.attn_type == "flash2":
861
- return self.flash2_forward(qkv)
862
-
863
- # qkv: (batch_size, seq_len, d_embed)
864
- batch_size, seq_len, d_embed = qkv.shape
865
-
866
- # Feed the inputs through the K, Q, V matrices.
867
- query, key, value = self.project_input(qkv)
868
-
869
- # Split projections into multiple heads and swap position of sequence / heads dimension
870
- query = query.view(batch_size, seq_len, self.num_heads, self.d_head).transpose(1, 2)
871
- key = key.view(batch_size, seq_len, self.num_heads, self.d_head).transpose(1, 2)
872
- value = value.view(batch_size, seq_len, self.num_heads, self.d_head).transpose(1, 2)
873
-
874
- # Apply Alibi relative positional biases.
875
- attn_bias = alibi_biases(seq_len, seq_len, device=query.device) * self.alibi_slopes.view(-1, 1, 1)
876
-
877
- # Mask future positions from the past
878
- causal_mask = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool, device=qkv.device), diagonal=0)
879
- attn_bias.masked_fill_(causal_mask.logical_not(), float('-inf'))
880
- del causal_mask
881
-
882
- # Default to returning empty attention weights.
883
- attention_weights = None
884
-
885
- if self.attn_type == "torch":
886
- # This context manager can be used to force which implementation to use.
887
- #with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=False, enable_mem_efficient=False):
888
- attended_values = F.scaled_dot_product_attention(
889
- query,
890
- key,
891
- value,
892
- attn_mask=attn_bias.to(dtype=query.dtype),
893
- dropout_p=self.dropout.p if self.training else 0.0,
894
- is_causal=False,
895
- scale=self.dot_product_scale
896
- )
897
- # "native" scaled-dot-product attention implementation.
898
- else:
899
- # Compute attention scores
900
- scores = torch.matmul(query, key.transpose(-2, -1)) * self.dot_product_scale
901
-
902
- # Adjust scores with attn_mask
903
- scores += attn_bias
904
-
905
- # Calculate the attention weights; avoid NANs that might emerge from zeros in softmax's denominator
906
- attention_weights = self.dropout(torch.softmax(scores, dim=-1).clamp(min=1e-10))
907
-
908
- # Use the attention weights to get a weighted combination of value vectors
909
- attended_values = torch.matmul(attention_weights, value)
910
- if not need_weights:
911
- attention_weights = None
912
-
913
- # Concatenate attention heads and project to original embedding size using the output linear layer
914
- attended_values = attended_values.transpose(1, 2).contiguous().view(batch_size, seq_len, d_embed)
915
-
916
- # Project the concatenated output through the output matrix.
917
- attended_values = self.output_linear(attended_values)
918
- return attended_values, attention_weights
919
-
920
- def flash2_forward(self, qkv):
921
  batch_size, seq_len, d_embed = qkv.shape
922
 
923
  # Feed the inputs through the K, Q, V matrices.
924
- # query : (batch_size, seq_len, d_model)
925
- # qkv : (batch_size, seq_len, 3, num_heads, d_kq)
926
- qkv = F.linear(
927
- qkv,
928
- self.in_proj,
929
- ).unflatten(
930
- -1,
931
- (3, self.num_heads, self.d_head)
932
- )
933
-
934
- attended_values = flash_attn_qkvpacked_func(
935
- qkv.bfloat16(),
936
  dropout_p=self.dropout.p if self.training else 0.0,
937
  softmax_scale=self.dot_product_scale,
938
  causal=True,
939
- window_size=self.window_size,
940
- alibi_slopes=self.alibi_slopes.float(),
941
- ).to(dtype=qkv.dtype)
942
  # attended_values: (batch_size, seqlen, nheads, headdim)
943
-
944
  # Concatentate heads back into d_embed
945
  attended_values = attended_values.view(batch_size, seq_len, d_embed)
946
 
947
  # Project the concatenated output through the output matrix.
948
  attended_values = self.output_linear(attended_values)
949
- return attended_values, None
1
  # See: https://huggingface.co/docs/transformers/custom_models
2
+ from typing import Optional, Tuple, Union, List
3
  import math
4
  import copy
5
  import sys
 
9
  from torch import nn, Tensor
10
  import torch.nn.init as init
11
  from torch.nn import functional as F
12
+ from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutput, CausalLMOutputWithPast
13
  from transformers import (
14
  PreTrainedModel,
15
  PretrainedConfig,
 
18
  AutoModelForCausalLM,
19
  )
20
 
21
+ from transformers.utils import logging
22
+
23
+ from transformers.cache_utils import Cache, DynamicCache
24
+
25
  from transformers.utils import (
26
  is_flash_attn_2_available,
27
  is_flash_attn_greater_or_equal_2_10,
 
30
  if is_flash_attn_2_available():
31
  from flash_attn import flash_attn_qkvpacked_func, flash_attn_func
32
 
33
+ logger = logging.get_logger(__name__)
34
+
35
  # The model type string to bind.
36
  model_type = "walsh-causal-v1"
37
84
  layer_args=dict(),
85
  embedding_args=dict(),
86
  output_proj_args=dict(),
87
+
88
+ output_attentions=False,
89
+ output_hidden_states=False,
90
+ use_cache=True,
91
 
92
  **kwargs,
93
  ):
 
123
  self.layer_args = layer_args
124
  self.embedding_args = embedding_args
125
  self.output_proj_args = output_proj_args
126
+
127
+ self.output_attentions = output_attentions
128
+ self.output_hidden_states = output_hidden_states
129
+ self.use_cache = use_cache
130
 
131
  super().__init__(**kwargs)
132
 
 
218
  _no_split_modules = ["DeepNetLayer"]
219
  _supports_flash_attn_2 = True
220
  _supports_sdpa = True
221
+ _supports_cache_class = True
222
+ _skip_keys_device_placement = "past_key_values"
223
 
224
  def __init__(self, config):
225
  super().__init__(config)
 
237
  token_type_ids: Optional[torch.LongTensor] = None,
238
  position_ids: Optional[torch.LongTensor] = None,
239
  labels: Optional[torch.LongTensor] = None,
240
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
241
+ use_cache: Optional[bool] = None,
242
  output_attentions: Optional[bool] = None,
243
  output_hidden_states: Optional[bool] = None,
244
  return_dict: Optional[bool] = None,
245
  **kwargs,
246
  ) -> (Tensor, dict[str, Tensor]):
247
 
248
+ batch_size, seq_len = input_ids.shape
249
+
250
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
251
+ output_hidden_states = (
252
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
253
+ )
254
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
255
+
256
+ if use_cache:
257
+ # If legacy cache, convert to DynamicCache
258
+ use_legacy_cache = not isinstance(past_key_values, Cache)
259
+ if use_legacy_cache:
260
+ past_key_values = DynamicCache.from_legacy_cache(past_key_values)
261
+
262
+
263
  if self.gradient_checkpointing and self.training:
264
+ if use_cache:
265
+ logger.warning_once(
266
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
267
+ )
268
+ use_cache = False
269
  gradient_checkpointing_func = self._gradient_checkpointing_func
270
  else:
271
  gradient_checkpointing_func = None
272
+
273
 
274
+ outputs = self.transformer_head(
275
  input_ids=input_ids,
276
+ position_ids=position_ids,
277
+ output_attentions=output_attentions,
278
  gradient_checkpointing_func=gradient_checkpointing_func,
279
+ past_key_values=past_key_values,
280
+ use_cache=use_cache,
281
+ output_hidden_states=output_hidden_states,
282
  )
283
+
284
+ logits = outputs["logits"].float()
285
+ attentions = outputs["attentions"]
286
 
287
  # Compute loss.
288
  if labels is not None:
289
  loss = self.loss_function(logits=logits, labels=labels, input_ids=input_ids)
290
  else:
291
  loss = None
292
+
293
+ # Convert back to legacy cache, if that's what we received
294
+ new_cache = outputs["past_key_values"]
295
+ if use_cache and new_cache is not None and use_legacy_cache:
296
+ new_cache = new_cache.to_legacy_cache()
297
 
298
+ return CausalLMOutputWithPast(
299
+ loss=loss,
300
+ logits=logits,
301
+ past_key_values=new_cache,
302
+ hidden_states=outputs["hidden_states"],
303
+ attentions=outputs["attentions"],
304
+ )
305
+
306
+ # Implementation from Huggingface Transformers,
307
+ # https://github.com/huggingface/transformers/blob/main/src/transformers/models/mistral/modeling_mistral.py
308
+ # Note: We do not implement attention mask at present, so some of this code is not applicable
309
+ # TODO: Reenable attention mask support for batch inference..
310
+ def prepare_inputs_for_generation(
311
+ self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
312
+ ):
313
+ # Omit tokens covered by past_key_values
314
+ if past_key_values is not None:
315
+ if isinstance(past_key_values, Cache):
316
+ cache_length = past_key_values.get_seq_length()
317
+ past_length = past_key_values.seen_tokens
318
+ max_cache_length = past_key_values.get_max_length()
319
+ else:
320
+ cache_length = past_length = past_key_values[0][0].shape[2]
321
+ max_cache_length = None
322
+
323
+ # Keep only the unprocessed tokens:
324
+ # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
325
+ # some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as
326
+ # input)
327
+ if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]:
328
+ input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :]
329
+ # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard
330
+ # input_ids based on the past_length.
331
+ elif past_length < input_ids.shape[1]:
332
+ input_ids = input_ids[:, past_length:]
333
+ # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens.
334
+
335
+ # If we are about to go beyond the maximum cache length, we need to crop the input attention mask.
336
+ if (
337
+ max_cache_length is not None
338
+ and attention_mask is not None
339
+ and cache_length + input_ids.shape[1] > max_cache_length
340
+ ):
341
+ attention_mask = attention_mask[:, -max_cache_length:]
342
+
343
+ position_ids = kwargs.get("position_ids", None)
344
+ if attention_mask is not None and position_ids is None:
345
+ # create position_ids on the fly for batch generation
346
+ position_ids = attention_mask.long().cumsum(-1) - 1
347
+ position_ids.masked_fill_(attention_mask == 0, 1)
348
+ if past_key_values:
349
+ position_ids = position_ids[:, -input_ids.shape[1] :]
350
+
351
+ # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
352
+ # NOTE: Injecting positional embeddings is not yet supported.
353
+ if inputs_embeds is not None and past_key_values is None:
354
+ model_inputs = {"inputs_embeds": inputs_embeds}
355
+ else:
356
+ model_inputs = {"input_ids": input_ids}
357
+
358
+ model_inputs.update(
359
+ {
360
+ "position_ids": position_ids,
361
+ "past_key_values": past_key_values,
362
+ "use_cache": kwargs.get("use_cache"),
363
+ "attention_mask": attention_mask,
364
+ }
365
+ )
366
  return model_inputs
367
 
368
+ @staticmethod
369
+ def _reorder_cache(past_key_values, beam_idx):
370
+ reordered_past = ()
371
+ for layer_past in past_key_values:
372
+ reordered_past += (
373
+ tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),
374
+ )
375
+ return reordered_past
376
+
377
  def _make_embedding(self, config):
378
  embedding_cls = get_dynamic_class(config.embdding_cls)
379
  return embedding_cls(config.vocab_size, self.d_model, config.pad_index, **config.embedding_args)
 
397
  norm_cls = get_dynamic_class(config.norm_cls)
398
  return norm_cls(self.d_model)
399
 
400
+ def _make_self_attention(self, layer_idx, config):
401
  attention_cls = get_dynamic_class(config.attention_cls)
402
  # Map HF _attn_implementation to attn_type
403
  match config._attn_implementation:
 
418
  d_model=self.d_model,
419
  num_heads=config.num_attention_heads,
420
  attn_type=attn_type,
421
+ layer_idx=layer_idx,
422
+ config=config,
423
  **config.attention_args,
424
  )
425
 
426
+ def _make_feedforward(self, layer_idx, config):
427
  feedforward_cls = get_dynamic_class(config.feedforward_cls)
428
  return feedforward_cls(
429
  d_model=self.d_model,
430
  feedforward_dim=config.dim_feedforward,
431
  dropout=config.dropout,
432
  activation=self._make_activation(config),
433
+ layer_idx=layer_idx,
434
  **config.feedforward_args,
435
  )
436
 
437
+ def _make_layer(self, layer_idx, config):
438
  layer_cls = get_dynamic_class(config.layer_cls)
439
  return layer_cls(
440
  d_model=self.d_model,
441
  dropout=self._make_dropout(config),
442
+ attention=self._make_self_attention(layer_idx, config),
443
+ feedforward=self._make_feedforward(layer_idx, config),
444
  norm1=self._make_norm(config),
445
  norm2=self._make_norm(config),
446
+ layer_idx=layer_idx,
447
  **config.layer_args,
448
  )
449
 
 
451
  layer_stack_cls = get_dynamic_class(config.layer_stack_cls)
452
  return layer_stack_cls(
453
  layers=nn.ModuleList([
454
+ self._make_layer(layer_idx, config) for layer_idx in range(config.num_hidden_layers)
455
  ]),
456
  **config.layer_stack_args,
457
  )
 
487
  self.sqrt_d_model = d_model**0.5
488
  self.reset_parameters()
489
 
490
+ def forward(
491
+ self,
492
+ input_ids,
493
+ position_ids,
494
+ output_attentions,
495
+ gradient_checkpointing_func,
496
+ past_key_values,
497
+ use_cache,
498
+ output_hidden_states,
499
+ ):
500
+ outputs = self.layer_stack(
501
+ self.positional_encoder(self.embedding(input_ids) * self.sqrt_d_model, position_ids),
502
+ output_attentions=output_attentions,
503
+ gradient_checkpointing_func=gradient_checkpointing_func,
504
+ past_key_values=past_key_values,
505
+ use_cache=use_cache,
506
+ output_hidden_states=output_hidden_states,
507
  )
508
 
509
+ # Translate output states to logits.
510
+ outputs["logits"] = self.output_projection(outputs["last_hidden_state"])
511
+ del outputs["last_hidden_state"]
512
+ return outputs
513
 
514
  def reset_parameters(self):
515
  init.xavier_uniform_(self.output_projection.weight)
516
  init.constant_(self.output_projection.bias, 0.)
517
  init.normal_(self.embedding.weight, std=self.d_model**-0.5)
518
 
519
  # Converts a torch array of integers into their equivalent binary codes.
520
  def binary_tensor(x, bits):
521
  mask = 2**torch.arange(bits).to(x.device, x.dtype)
 
587
  # walsh = (hadamard_walsh_matrix(k)[:bits,:d_embed] -0.5) * self.gain
588
  self.register_buffer('walsh', walsh, persistent=False)
589
 
590
+ def forward(self, x, position_ids=None):
591
  seq_len = x.size(-2)
592
 
593
  # Get sequence of binary codes...
 
601
  shift = torch.randint(self.max_seq - seq_len + 1, (1,)).item()
602
  seq = self.binary_code[shift:seq_len + shift,:]
603
 
604
+ # When the cache is used for generation, after the first call, we are only passed a single token at a time,
605
+ # with the remaining tokens being in the cache. We need to make sure that the newly injected tokens have the
606
+ # correct relative position by indexing the codes with the position_ids.
607
+ elif position_ids != None:
608
+ seq = self.binary_code[position_ids, :]
609
+
610
  # Disable shifting when not training. This does not appear to change the evaluation loss, but
611
  # it does makes predictions easier to analyse when the attention weights are not shifting with each step.
612
  else:
 
629
  super().__init__()
630
  self.layers = layers
631
 
632
+ def forward(
633
+ self,
634
+ hidden_states,
635
+ output_attentions,
636
+ past_key_values,
637
+ use_cache,
638
+ output_hidden_states,
639
+ gradient_checkpointing_func=None,
640
+ ):
641
+ present_key_value = None
642
+ all_attentions = [] if output_attentions else None
643
+ all_hidden_states = [hidden_states] if output_hidden_states else None
644
+
645
  for layer in self.layers:
646
  if gradient_checkpointing_func is not None:
647
+ layer_outputs = gradient_checkpointing_func(
648
  layer.__call__,
649
+ hidden_states,
650
+ output_attentions,
651
+ past_key_values,
652
+ use_cache,
653
+ use_reentrant=False,
654
  )
655
  else:
656
+ layer_outputs = layer(
657
+ hidden_states,
658
+ output_attentions,
659
+ past_key_values,
660
+ use_cache,
661
+ )
662
 
663
+ hidden_states = layer_outputs["hidden_states"]
664
+
665
+ if output_hidden_states:
666
+ all_hidden_states.append(hidden_states)
667
+
668
+ if use_cache:
669
+ present_key_value = layer_outputs["past_key_values"]
670
+
671
+ if output_attentions:
672
+ all_attentions.append(layer_outputs["attentions"])
673
+
674
+ return dict(
675
+ last_hidden_state=hidden_states,
676
+ past_key_values=present_key_value,
677
+ hidden_states=hidden_states,
678
+ attentions=all_attentions,
679
+ )
680
 
681
  # DeepNet: Scaling Transformers to 1,000 Layers
682
  # https://arxiv.org/abs/2203.00555
683
+ # Note: This is a type of Pre-Layer-Norm Transformer layer.
684
  class DeepnetLayer(nn.Module):
685
  def __init__(
686
  self,
 
690
  norm1,
691
  norm2,
692
  dropout,
693
+ layer_idx,
694
  alpha=1.0,
695
  ):
696
  super().__init__()
 
702
  self.dropout = dropout
703
  # Deepnet alpha
704
  self.alpha = alpha
705
+ self.layer_idx = layer_idx
706
 
707
+ def forward(
708
+ self,
709
+ hidden_states,
710
+ output_attentions,
711
+ past_key_values,
712
+ use_cache,
713
+ ):
714
  # Keep input as residual
715
+ residual = hidden_states * self.alpha
716
 
717
  # Compute attention
718
+ attn_outputs = self.attention(
719
+ hidden_states,
720
+ past_key_values=past_key_values,
721
+ use_cache=use_cache,
722
+ output_attentions=output_attentions
723
+ )
724
+
725
+ hidden_states = attn_outputs["hidden_states"]
726
 
727
  # Add attention with residual and normalize.
728
+ hidden_states = self.norm1(residual + self.dropout(hidden_states))
729
 
730
  # Keep output as next residual.
731
+ residual = hidden_states * self.alpha
732
 
733
  # Pass through feedforward network.
734
+ hidden_states = self.feedforward(hidden_states)
735
 
736
  # Combine residual and ff output, then normalize again.
737
+ hidden_states = self.norm2(residual + self.dropout(hidden_states))
738
 
739
+ return dict(
740
+ hidden_states=hidden_states,
741
+ attentions=attn_outputs["attentions"],
742
+ past_key_values=attn_outputs["past_key_values"]
743
+ )
744
 
745
  # A vanilla MLP transfomer layer.
746
  class FeedforwardLayer(nn.Module):
 
749
  d_model: int,
750
  feedforward_dim: int,
751
  dropout,
752
+ layer_idx,
753
  activation=nn.ReLU(),
754
  beta=1.0,
755
  bias=True,
 
772
  init.constant_(self.linear1.bias, 0.)
773
  init.constant_(self.linear2.bias, 0.)
774
 
775
  class CausalSelfAttention(nn.Module):
776
  def __init__(
777
  self,
 
782
  # torch: Use pytorch "scaled_dot_product_attention()"; faster; generally good compatibility; does not support returning attn weights.
783
  # flash2: Use Flash-Attention2 implementation; fastest; limited to int16 and bfloat16 types; least memory usage.
784
  attn_type,
785
+ layer_idx,
786
+ config,
787
  beta=1.0,
788
  dropout=0.1,
789
  ):
 
792
  self.num_heads = num_heads
793
  self.beta = beta
794
  self.attn_type = attn_type
795
+ self.layer_idx = layer_idx
796
+ self.config = config
797
 
798
  assert d_model % num_heads == 0, "d_model must be evenly divisible by num_heads"
799
 
 
824
  init.constant_(self.in_proj.bias, 0.)
825
  init.constant_(self.output_linear.bias, 0.)
826
 
827
+ # Project QKV input through input matrices, reshape to (batch_size, n_heads, seq_len, d_model), and apply cache.
828
+ def project_input(self, qkv, past_key_values):
829
  batch_size, seq_len, d_embed = qkv.shape
830
+ proj = self.in_proj(qkv)
831
+ query, key, value = proj.chunk(chunks=3, dim=-1)
832
 
  # Split projections into multiple heads and swap position of sequence / heads dimension
834
  query = query.view(batch_size, seq_len, self.num_heads, self.d_head).transpose(1, 2)
835
  key = key.view(batch_size, seq_len, self.num_heads, self.d_head).transpose(1, 2)
836
  value = value.view(batch_size, seq_len, self.num_heads, self.d_head).transpose(1, 2)
837
 
838
+ # Update the cache values.
839
+ if past_key_values is not None:
840
+ key, value = past_key_values.update(key, value, self.layer_idx)
841
+ return query, key, value
842
+
843
+ def forward(
844
+ self,
845
+ qkv,
846
+ output_attentions,
847
+ past_key_values,
848
+ use_cache,
849
+ ):
850
+ attn_type = self.attn_type
851
+ if output_attentions and attn_type != "native":
852
+ logger.warning_once(
853
+ "CausalSelfAttention(output_attentions=True) and attn_type is not 'native': "
854
+ "Forcing native attention."
855
+ )
856
+ attn_type = "native"
857
+
858
+ if attn_type == "flash2":
859
+ if use_cache is None or use_cache == False:
860
+ return self.flash2_forward(qkv)
861
+ else:
862
+ return self.flash2_forward_cached(qkv, past_key_values)
863
+
864
+ # qkv: (batch_size, seq_len, d_embed)
865
+ batch_size, seq_len, d_embed = qkv.shape
866
+
867
+ # Feed the inputs through the K, Q, V matrices.
868
+ query, key, value = self.project_input(qkv, past_key_values)
869
+ kv_seq_len = key.shape[-2]
870
+
871
  # Default to returning empty attention weights.
872
+ attentions = None
873
+
874
+ # https://github.com/pytorch/pytorch/issues/112577
875
 
876
+ if attn_type == "torch":
877
  # This context manager can be used to force which implementation to use.
878
  #with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=False, enable_mem_efficient=False):
879
  attended_values = F.scaled_dot_product_attention(
 
882
  value,
883
  attn_mask=None,
884
  dropout_p=self.dropout.p if self.training else 0.0,
885
+ is_causal=(seq_len > 1),
886
  scale=self.dot_product_scale
887
  )
888
  # "native" scaled-dot-product attention implementation.
 
891
  scores = torch.matmul(query, key.transpose(-2, -1)) * self.dot_product_scale
892
 
893
  # Mask future positions from the past
894
+ if seq_len > 1:
895
+ scores.masked_fill_(
896
+ torch.tril(
897
+ torch.ones(seq_len, kv_seq_len, dtype=torch.bool, device=qkv.device),
898
+ diagonal=0,
899
+ ).logical_not(),
900
+ float('-inf'),
901
+ )
902
 
903
  # Calculate the attention weights; avoid NANs that might emerge from zeros in softmax's denominator
904
+ attentions = self.dropout(torch.softmax(scores, dim=-1).clamp(min=1e-10))
905
  del scores
906
 
907
  # Use the attention weights to get a weighted combination of value vectors
908
+ attended_values = torch.matmul(attentions, value)
909
+ if not output_attentions:
910
+ del attentions
911
+ attentions = None
912
 
913
  # Concatenate attention heads and project to original embedding size using the output linear layer
914
  attended_values = attended_values.transpose(1, 2).contiguous().view(batch_size, seq_len, d_embed)
915
 
916
  # Project the concatenated output through the output matrix.
917
  attended_values = self.output_linear(attended_values)
918
+ return dict(
919
+ hidden_states=attended_values,
920
+ attentions=attentions,
921
+ past_key_values=past_key_values
922
+ )
923
+
924
+ # No cache support, but faster
925
+ def flash2_forward(
926
+ self,
927
+ qkv,
928
+ ):
929
  batch_size, seq_len, d_embed = qkv.shape
930
 
931
  # Feed the inputs through the K, Q, V matrices.
932
  # query : (batch_size, seq_len, d_model)
933
  # qkv : (batch_size, seq_len, 3, num_heads, d_kq)
934
+ # Feed the inputs through the K, Q, V matrices.
935
+ # query : (batch_size, seq_len, d_model)
936
+ # qkv : (batch_size, seq_len, 3, num_heads, d_kq)
937
+
938
  qkv = self.in_proj(qkv).unflatten(
939
  -1,
940
  (3, self.num_heads, self.d_head)
941
  )
942
+
943
  attended_values = flash_attn_qkvpacked_func(
944
+ self._downcast_to_float16(qkv)[0],
945
  dropout_p=self.dropout.p if self.training else 0.0,
946
  softmax_scale=self.dot_product_scale,
947
  causal=True,
 
953
 
954
  # Project the concatenated output through the output matrix.
955
  attended_values = self.output_linear(attended_values)
956
+ return dict(
957
+ hidden_states=attended_values,
958
+ attentions=None,
959
+ past_key_values=None
960
+ )
961
 
962
+ # See https://github.com/huggingface/transformers/blob/main/src/transformers/cache_utils.py
963
+ #https://huggingface.co/docs/transformers/internal/generation_utils
964
+ def flash2_forward_cached(
965
  self,
966
+ qkv,
967
+ past_key_values,
968
  ):
969
  batch_size, seq_len, d_embed = qkv.shape
970
 
971
  # Feed the inputs through the K, Q, V matrices.
972
+ query, key, value = self.project_input(qkv, past_key_values)
973
+ query, key, value = self._downcast_to_float16(query, key, value)
974
+
975
+ # Expected inputs to flash2:
976
+ # q: (batch_size, seqlen, nheads, headdim)
977
+ # k: (batch_size, seqlen, nheads_k, headdim)
978
+ # v: (batch_size, seqlen, nheads_k, headdim)
979
+ query = query.transpose(1, 2)
980
+ key = key.transpose(1, 2)
981
+ value = value.transpose(1, 2)
982
+
983
+ attended_values = flash_attn_func(
984
+ q=query,
985
+ k=key,
986
+ v=value,
987
  dropout_p=self.dropout.p if self.training else 0.0,
988
  softmax_scale=self.dot_product_scale,
989
  causal=True,
990
+ )
991
  # attended_values: (batch_size, seqlen, nheads, headdim)
992
+
993
  # Concatentate heads back into d_embed
994
  attended_values = attended_values.view(batch_size, seq_len, d_embed)
995
 
996
  # Project the concatenated output through the output matrix.
997
  attended_values = self.output_linear(attended_values)
998
+ return dict(
999
+ hidden_states=attended_values,
1000
+ attentions=None,
1001
+ past_key_values=past_key_values
1002
+ )
1003
+
1004
+ def _downcast_to_float16(self, *args):
1005
+ if args[0].dtype != torch.float32:
1006
+ return args
1007
+
1008
+ if torch.is_autocast_enabled():
1009
+ target_dtype = torch.get_autocast_gpu_dtype()
1010
+ # Handle the case where the model is quantized
1011
+ elif hasattr(self.config, "_pre_quantization_dtype"):
1012
+ target_dtype = self.config._pre_quantization_dtype
1013
+ else:
1014
+ target_dtype = self.output_linear.weight.dtype
1015
+
1016
+ logger.warning_once(
1017
+ f"The input hidden states seems to be silently casted in float32, this might be related to"
1018
+ f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
1019
+ f" {target_dtype}."
1020
+ )
1021
+
1022
+ return (arg.to(target_dtype) for arg in args)