Xuyang Shen committed
Commit: b9d0948
1 Parent(s): 201dd8a

Update modeling_transnormer.py

Files changed (1)
modeling_transnormer.py  +157 -166
modeling_transnormer.py CHANGED
@@ -53,8 +53,13 @@ logger = logging.get_logger(__name__)
 
 _CONFIG_FOR_DOC = "TransnormerConfig"
 
+# TODO: fix environment: https://huggingface.co/OpenNLPLab/TransNormerLLM-7B/discussions/1
 use_triton = eval(os.environ.get("use_triton", default="True"))
 debug = eval(os.environ.get("debug", default="False"))
+do_eval = eval(os.environ.get("do_eval", default="False"))
+eval_and_not_generate = eval(
+    os.environ.get("eval_and_not_generate", default="False"))
+BLOCK = 256
 
 if use_triton:
     try:
@@ -84,6 +89,7 @@ if not has_lightning_attention:
 ########## start Transnormer
 ##### Linearized Relative Positional Encoding: https://openreview.net/forum?id=xoLyps2qWc&referrer=%5BAuthor%20Console%5D(%2Fgroup%3Fid%3DTMLR%2FAuthors%23your-submissions)
 class Lrpe(nn.Module):
+
     def __init__(
         self,
         num_heads=8,
@@ -93,9 +99,8 @@ class Lrpe(nn.Module):
         d = num_heads * embed_dim
 
         self.index = torch.empty(0)
-        self.theta = nn.Parameter(
-            10000 ** (-2 / d * torch.arange(d)).reshape(num_heads, 1, -1)
-        )
+        self.theta = nn.Parameter(10000**(-2 / d * torch.arange(d)).reshape(
+            num_heads, 1, -1))
 
     def extra_repr(self):
         return print_module(self)
@@ -114,6 +119,7 @@ class Lrpe(nn.Module):
 
 
 class GLU(nn.Module):
+
     def __init__(self, d1, d2, bias=False):
         super().__init__()
         if debug:
@@ -136,6 +142,7 @@ class GLU(nn.Module):
 
 
 class NormLinearAttention(nn.Module):
+
     def __init__(
         self,
         embed_dim,
@@ -169,7 +176,7 @@ class NormLinearAttention(nn.Module):
         )
 
         self.qkv_proj = nn.Linear(embed_dim, 3 * hidden_dim, bias=bias)
-        self.output_gate = nn.Sequential(
+        self.output_gate = nn.Sequential(
             nn.Linear(embed_dim, gate_dim, bias=bias),
             nn.Linear(gate_dim, hidden_dim, bias=bias),
         )
@@ -187,7 +194,6 @@ class NormLinearAttention(nn.Module):
         use_cache: bool = False,
         slope_rate: Optional[torch.Tensor] = None,
     ):
-        do_eval = eval(os.environ.get("do_eval", default="False"))
         if (not self.training) and (not do_eval):
             return self.inference(
                 x,
@@ -203,11 +209,11 @@ class NormLinearAttention(nn.Module):
         # linear map
         qkv = self.act(self.qkv_proj(x))
         q, k, v = qkv.split([d, d, d], dim=-1)
-
+
         # reshape
         q, k, v = map(
-            lambda x: rearrange(x, "b n (h d) -> b h n d", h=self.num_heads), [q, k, v]
-        )
+            lambda x: rearrange(x, "b n (h d) -> b h n d", h=self.num_heads),
+            [q, k, v])
 
         q_offset = 0
         # lrpe relys on position, get cache first
@@ -222,12 +228,12 @@ class NormLinearAttention(nn.Module):
         # lrpe
         if self.linear_use_lrpe:
             q = self.lrpe(q, offset=q_offset)
-            k = self.lrpe(k)
+            k = self.lrpe(k, offset=q_offset)
 
         if attn_padding_mask is not None:
             v = v.masked_fill(
-                (1 - attn_padding_mask).unsqueeze(1).unsqueeze(-1).to(torch.bool), 0
-            )
+                (1 - attn_padding_mask).unsqueeze(1).unsqueeze(-1).to(
+                    torch.bool), 0)
 
         if not has_lightning_attention:
             if attn_mask == None:
@@ -236,9 +242,8 @@ class NormLinearAttention(nn.Module):
             attn_mask = torch.exp(slope_rate * attn_mask)
             output = linear_attention(q, k, v, attn_mask)
         else:
-            output = lightning_attention(
-                q, k, v, True, slope_rate.squeeze(-1).squeeze(-1)
-            )
+            output = lightning_attention(q, k, v, True,
+                                         slope_rate.squeeze(-1).squeeze(-1))
 
         # reshape
         output = rearrange(output, "b h n d -> b n (h d)")
@@ -257,14 +262,14 @@ class NormLinearAttention(nn.Module):
         return output, attn_weights, past_key_value
 
     def inference(
-        self,
-        x,
-        attn_mask: Optional[torch.Tensor] = None,  # (b, h, n, m)
-        attn_padding_mask: Optional[torch.Tensor] = None,  # (b, m)
-        output_attentions: bool = False,
-        past_key_value: Optional[Tuple[torch.Tensor]] = None,
-        use_cache: bool = False,
-        slope_rate: Optional[torch.Tensor] = None,  # (h, 1, 1)
+            self,
+            x,
+            attn_mask: Optional[torch.Tensor] = None,  # (b, h, n, m)
+            attn_padding_mask: Optional[torch.Tensor] = None,  # (b, m)
+            output_attentions: bool = False,
+            past_key_value: Optional[Tuple[torch.Tensor]] = None,
+            use_cache: bool = False,
+            slope_rate: Optional[torch.Tensor] = None,  # (h, 1, 1)
     ):
         # x: b n d
         b, n, d = x.shape
@@ -273,13 +278,13 @@ class NormLinearAttention(nn.Module):
         q, k, v = qkv.split([d, d, d], dim=-1)
         # reshape
         q, k, v = map(
-            lambda x: rearrange(x, "b n (h d) -> b h n d", h=self.num_heads), [q, k, v]
-        )
-
+            lambda x: rearrange(x, "b n (h d) -> b h n d", h=self.num_heads),
+            [q, k, v])
+
         # rpe
         if self.linear_use_lrpe:
             q = self.lrpe(q, offset=self.offset)
-            k = self.lrpe(k)
+            k = self.lrpe(k, offset=self.offset)
 
         if past_key_value == None:
             self.offset = q.shape[-2]
@@ -290,38 +295,47 @@ class NormLinearAttention(nn.Module):
 
         # only use for the first time
         if past_key_value == None:
-            if attn_mask == None:
-                attn_mask = (torch.tril(torch.ones(n, n))).to(q)
-            if slope_rate != None:
-                attn_mask = torch.exp(slope_rate * attn_mask)
-
+            slope_rate = slope_rate.to(torch.float32)
             if attn_padding_mask is not None:
-                attn_mask = attn_mask.masked_fill(
-                    (1 - attn_padding_mask).unsqueeze(1).unsqueeze(2).to(torch.bool),
-                    0,
-                )
-            energy = torch.einsum("... n d, ... m d -> ... n m", q, k)
-
-            if attn_mask != None:
-                energy = energy * attn_mask
-
-            output = torch.einsum("... n m, ... m d -> ... n d", energy, v)
-
-            eval_and_not_generate = eval(
-                os.environ.get("eval_and_not_generate", default="False")
-            )
-            if eval_and_not_generate:
-                kv = None
-            else:
-                # b, h, n, e, d
-                kv_outproduct = torch.einsum("... n e, ... n d -> ... n e d", k, v)
-                # 1, 1, n, 1, 1
-                index = torch.arange(n - 1, -1, -1).reshape(1, 1, -1, 1, 1).to(x)
-                # (h, 1, 1) -> (1, h, 1, 1, 1); (1, h, 1, 1, 1), (1, 1, n, 1, 1) -> (1, h, n, 1, 1)
-                decay = ratio.unsqueeze(0).unsqueeze(-1) ** index
-
-                kv_outproduct_with_decay = kv_outproduct * decay
-                kv = torch.sum(kv_outproduct_with_decay, dim=-3)
+                v = v.masked_fill(
+                    (1 - attn_padding_mask).unsqueeze(1).unsqueeze(-1).to(
+                        torch.bool), 0)
+            NUM_BLOCK = (n + BLOCK - 1) // BLOCK
+            b, h, n, d = q.shape
+            e = v.shape[-1]
+            # other
+            array = torch.arange(BLOCK).to(q) + 1  ## !!!! important
+            q_decay = torch.exp(-slope_rate * array.reshape(-1, 1))
+            k_decay = torch.exp(-slope_rate * (BLOCK - array.reshape(-1, 1)))
+            index = array[:, None] - array[None, :]
+            s_index = slope_rate * index[
+                None,
+                None,
+            ]
+            s_index = torch.where(index >= 0, -s_index, float("-inf"))
+            diag_decay = torch.exp(s_index)
+
+            kv = torch.zeros(b, h, d, e).to(torch.float32).to(q.device)
+            output = torch.empty((b, h, n, e), dtype=q.dtype, device=q.device)
+            for i in range(NUM_BLOCK):
+                si = i * BLOCK
+                ei = min(si + BLOCK, n)
+                m = ei - si
+
+                qi = q[:, :, si:ei].contiguous()
+                ki = k[:, :, si:ei].contiguous()
+                vi = v[:, :, si:ei].contiguous()
+                qkv_none_diag = torch.matmul(qi * q_decay[:, :m],
+                                             kv).to(torch.float32)
+
+                # diag
+                qk = torch.matmul(qi, ki.transpose(-1, -2)).to(
+                    torch.float32) * diag_decay[:, :, :m, :m]
+                qkv_diag = torch.matmul(qk, vi.to(torch.float32))
+                block_decay = torch.exp(-slope_rate * m)
+                output[:, :, si:ei] = qkv_none_diag + qkv_diag
+                kv = block_decay * kv + torch.matmul(
+                    (ki * k_decay[:, -m:]).transpose(-1, -2).to(vi.dtype), vi)
         else:
            kv = past_key_value
 
@@ -329,12 +343,11 @@ class NormLinearAttention(nn.Module):
             for i in range(n):
                 kv = ratio * kv + torch.einsum(
                     "... n d, ... n e -> ... d e",
-                    k[:, :, i : i + 1],
-                    v[:, :, i : i + 1],
-                )
-                qkv = torch.einsum(
-                    "... n e, ... e d -> ... n d", q[:, :, i : i + 1], kv
+                    k[:, :, i:i + 1],
+                    v[:, :, i:i + 1],
                 )
+                qkv = torch.einsum("... n e, ... e d -> ... n d",
+                                   q[:, :, i:i + 1], kv)
                 output.append(qkv)
             output = torch.concat(output, dim=-2)
 
@@ -353,6 +366,7 @@ class NormLinearAttention(nn.Module):
 
 
 class TransnormerDecoderLayer(nn.Module):
+
     def __init__(self, config: TransnormerConfig):
         super().__init__()
         self.embed_dim = config.decoder_embed_dim
@@ -392,14 +406,14 @@ class TransnormerDecoderLayer(nn.Module):
         return residual + x
 
     def forward(
-        self,
-        x,
-        attn_mask: Optional[torch.Tensor] = None,
-        attn_padding_mask: Optional[torch.Tensor] = None,
-        past_key_value: Optional[Tuple[torch.Tensor]] = None,
-        output_attentions: Optional[bool] = False,
-        use_cache: Optional[bool] = False,
-        slope_rate: Optional[torch.Tensor] = None,  # (h, 1, 1)
+            self,
+            x,
+            attn_mask: Optional[torch.Tensor] = None,
+            attn_padding_mask: Optional[torch.Tensor] = None,
+            past_key_value: Optional[Tuple[torch.Tensor]] = None,
+            output_attentions: Optional[bool] = False,
+            use_cache: Optional[bool] = False,
+            slope_rate: Optional[torch.Tensor] = None,  # (h, 1, 1)
     ):
         residual = x
         x = self.token_norm(x)
@@ -419,13 +433,13 @@ class TransnormerDecoderLayer(nn.Module):
         x = self.channel_mixer(x)
         x = self.residual_connection(x, residual)
 
-        outputs = (x,)
+        outputs = (x, )
 
         if output_attentions:
-            outputs += (self_attn_weights,)
+            outputs += (self_attn_weights, )
 
         if use_cache:
-            outputs += (present_key_value,)
+            outputs += (present_key_value, )
 
         return outputs
 
@@ -447,9 +461,7 @@ TRANSNORMER_START_DOCSTRING = r"""
 """
 
 
-@add_start_docstrings(
-    TRANSNORMER_START_DOCSTRING,
-)
+@add_start_docstrings(TRANSNORMER_START_DOCSTRING, )
 class TransnormerPreTrainedModel(PreTrainedModel):
     config_class = TransnormerConfig
     base_model_prefix = "model"
@@ -534,9 +546,7 @@ TRANSNORMER_INPUTS_DOCSTRING = r"""
 """
 
 
-@add_start_docstrings(
-    TRANSNORMER_START_DOCSTRING,
-)
+@add_start_docstrings(TRANSNORMER_START_DOCSTRING, )
 class TransnormerModel(TransnormerPreTrainedModel):
     """
     Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`TransnormerDecoderLayer`]
@@ -560,29 +570,31 @@ class TransnormerModel(TransnormerPreTrainedModel):
         self.slopes = self._build_slope_tensor(config.decoder_attention_heads)
 
         # params
-        self.embed_tokens = nn.Embedding(
-            config.vocab_size, config.decoder_embed_dim, self.padding_idx
-        )
+        self.embed_tokens = nn.Embedding(config.vocab_size,
+                                         config.decoder_embed_dim,
+                                         self.padding_idx)
         self.layers = nn.ModuleList([])
         for i in range(config.decoder_layers):
             if len(self.linear_use_lrpe_list) > 0:
                 config.linear_use_lrpe = self.linear_use_lrpe_list[i]
             self.layers.append(TransnormerDecoderLayer(config))
 
-        self.final_norm = get_norm_fn(config.norm_type)(config.decoder_embed_dim)
+        self.final_norm = get_norm_fn(config.norm_type)(
+            config.decoder_embed_dim)
         self.embed_dim = config.decoder_embed_dim
-        self.embed_scale = (
-            1.0 if config.no_scale_embedding else math.sqrt(self.embed_dim)
-        )
+        self.embed_scale = (1.0 if config.no_scale_embedding else math.sqrt(
+            self.embed_dim))
 
         # Initialize weights and apply final processing
         self.post_init()
 
     @staticmethod
     def _build_slope_tensor(n_attention_heads: int):
+
         def get_slopes(n):
+
             def get_slopes_power_of_2(n):
-                start = 2 ** (-(2 ** -(math.log2(n) - 3)))
+                start = 2**(-(2**-(math.log2(n) - 3)))
                 ratio = start
                 return [start * ratio**i for i in range(n)]
 
@@ -591,18 +603,15 @@ class TransnormerModel(TransnormerPreTrainedModel):
                     n
                 )  # In the paper, we only train models that have 2^a heads for some a. This function has
             else:  # some good properties that only occur when the input is a power of 2. To maintain that even
-                closest_power_of_2 = 2 ** math.floor(
+                closest_power_of_2 = 2**math.floor(
                    math.log2(n)
                 )  # when the number of heads is not a power of 2, we use this workaround.
-                return (
-                    get_slopes_power_of_2(closest_power_of_2)
-                    + get_slopes(2 * closest_power_of_2)[0::2][: n - closest_power_of_2]
-                )
+                return (get_slopes_power_of_2(closest_power_of_2) + get_slopes(
+                    2 * closest_power_of_2)[0::2][:n - closest_power_of_2])
 
         # h, 1, 1
         slopes = torch.tensor(get_slopes(n_attention_heads)).reshape(
-            n_attention_heads, 1, 1
-        )
+            n_attention_heads, 1, 1)
 
         return slopes
 
@@ -615,26 +624,26 @@ class TransnormerModel(TransnormerPreTrainedModel):
     def set_input_embeddings(self, value):
         self.embed_tokens = value
 
-    def _prepare_decoder_linear_attn_mask(
-        self, input_shape, inputs_embeds, past_key_values_length
-    ):
+    def _prepare_decoder_linear_attn_mask(self, input_shape, inputs_embeds,
+                                          past_key_values_length):
         bsz, tgt_len = input_shape
         src_len = tgt_len + past_key_values_length
 
         def power_log(x):
-            return 2 ** (math.ceil(math.log(x, 2)))
+            return 2**(math.ceil(math.log(x, 2)))
 
         n = power_log(max(tgt_len, src_len))
         if self._linear_attn_mask.shape[-1] < n:
 
             def get_mask(n):
-                mask = torch.triu(torch.zeros(n, n).float().fill_(float("-inf")), 1)
+                mask = torch.triu(
+                    torch.zeros(n, n).float().fill_(float("-inf")), 1)
                 # no slope version
                 # -n, ..., -2, -1, 0
                 for i in range(n):
                     x = torch.arange(i + 1)
                    y = x
-                    mask[i, : i + 1] = -torch.flip(y, [0])
+                    mask[i, :i + 1] = -torch.flip(y, [0])
 
                 return mask
 
@@ -646,7 +655,8 @@ class TransnormerModel(TransnormerPreTrainedModel):
         linear_attn_mask = self._linear_attn_mask[:, -tgt_len:, -src_len:]
         num_heads = linear_attn_mask.shape[0]
 
-        return linear_attn_mask[None, :, :, :].expand(bsz, num_heads, tgt_len, src_len)
+        return linear_attn_mask[None, :, :, :].expand(bsz, num_heads, tgt_len,
+                                                      src_len)
 
     @add_start_docstrings_to_model_forward(TRANSNORMER_INPUTS_DOCSTRING)
     def forward(
@@ -660,21 +670,15 @@ class TransnormerModel(TransnormerPreTrainedModel):
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
     ) -> Union[Tuple, BaseModelOutputWithPast]:
-        output_attentions = (
-            output_attentions
-            if output_attentions is not None
-            else self.config.output_attentions
-        )
-        output_hidden_states = (
-            output_hidden_states
-            if output_hidden_states is not None
-            else self.config.output_hidden_states
-        )
+        output_attentions = (output_attentions if output_attentions is not None
+                             else self.config.output_attentions)
+        output_hidden_states = (output_hidden_states
+                                if output_hidden_states is not None else
+                                self.config.output_hidden_states)
         use_cache = use_cache if use_cache is not None else self.config.use_cache
 
-        return_dict = (
-            return_dict if return_dict is not None else self.config.use_return_dict
-        )
+        return_dict = (return_dict if return_dict is not None else
+                       self.config.use_return_dict)
 
         # retrieve input_ids and inputs_embeds
         if input_ids is not None and inputs_embeds is not None:
@@ -696,7 +700,7 @@ class TransnormerModel(TransnormerPreTrainedModel):
         if past_key_values is not None:
             past_key_values_length = past_key_values[0][0].shape[-2]
             seq_length_with_past = seq_length_with_past + past_key_values_length
-
+
         if inputs_embeds is None:
             # !!! use embed_scale
             inputs_embeds = self.embed_scale * self.embed_tokens(input_ids)
@@ -718,23 +722,23 @@ class TransnormerModel(TransnormerPreTrainedModel):
         ##### norm linear layers
         linear_attn_padding_mask = attn_padding_mask
         linear_attn_mask = self._prepare_decoder_linear_attn_mask(
-            (batch_size, seq_length), inputs_embeds, past_key_values_length
-        )
+            (batch_size, seq_length), inputs_embeds, past_key_values_length)
 
-        slope_rates = [self.slopes.to(input_ids.device) for _ in range(self.num_layers)]
+        slope_rates = [
+            self.slopes.to(input_ids.device) for _ in range(self.num_layers)
+        ]
 
         for idx, layer in enumerate(self.layers):
             if output_hidden_states:
-                all_hidden_states += (hidden_states,)
+                all_hidden_states += (hidden_states, )
 
-            past_key_value = (
-                past_key_values[idx] if past_key_values is not None else None
-            )
+            past_key_value = (past_key_values[idx]
+                              if past_key_values is not None else None)
 
             slope_rate = slope_rates[idx]
             slope_rate = slope_rate * (1 - idx / (self.num_layers - 1) + 1e-5)
             mask = linear_attn_mask
-
+
             layer_outputs = layer(
                 hidden_states,
                 attn_mask=mask,
@@ -748,27 +752,24 @@ class TransnormerModel(TransnormerPreTrainedModel):
             hidden_states = layer_outputs[0]
 
             if use_cache:
-                next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)
+                next_decoder_cache += (
+                    layer_outputs[2 if output_attentions else 1], )
 
             if output_attentions:
-                all_self_attns += (layer_outputs[1],)
-
-            # if idx == 0:
-            #     break
+                all_self_attns += (layer_outputs[1], )
 
         hidden_states = self.final_norm(hidden_states)
 
         # add hidden states from the last decoder layer
         if output_hidden_states:
-            all_hidden_states += (hidden_states,)
+            all_hidden_states += (hidden_states, )
 
         next_cache = next_decoder_cache if use_cache else None
         if not return_dict:
             return tuple(
-                v
-                for v in [hidden_states, next_cache, all_hidden_states, all_self_attns]
-                if v is not None
-            )
+                v for v in
+                [hidden_states, next_cache, all_hidden_states, all_self_attns]
+                if v is not None)
         return BaseModelOutputWithPast(
             last_hidden_state=hidden_states,
             past_key_values=next_cache,
@@ -778,6 +779,7 @@ class TransnormerModel(TransnormerPreTrainedModel):
 
 
 class TransnormerForCausalLM(TransnormerPreTrainedModel):
+
     def __init__(self, config):
         super().__init__(config)
         self.model = TransnormerModel(config)
@@ -785,9 +787,9 @@ class TransnormerForCausalLM(TransnormerPreTrainedModel):
         logging_info(self.model)
 
         # the lm_head weight is automatically tied to the embed tokens weight
-        self.lm_head = nn.Linear(
-            config.decoder_embed_dim, config.vocab_size, bias=False
-        )
+        self.lm_head = nn.Linear(config.decoder_embed_dim,
                                 config.vocab_size,
+                                 bias=False)
 
         # Initialize weights and apply final processing
         self.post_init()
@@ -811,9 +813,8 @@ class TransnormerForCausalLM(TransnormerPreTrainedModel):
         return self.model
 
     @add_start_docstrings_to_model_forward(TRANSNORMER_INPUTS_DOCSTRING)
-    @replace_return_docstrings(
-        output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC
-    )
+    @replace_return_docstrings(output_type=CausalLMOutputWithPast,
+                               config_class=_CONFIG_FOR_DOC)
     def forward(
         self,
        input_ids: torch.LongTensor = None,
@@ -851,19 +852,13 @@ class TransnormerForCausalLM(TransnormerPreTrainedModel):
         >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
         "Hey, are you consciours? Can you talk to me?\nI'm not consciours, but I can talk to you."
         ```"""
-        output_attentions = (
-            output_attentions
-            if output_attentions is not None
-            else self.config.output_attentions
-        )
-        output_hidden_states = (
-            output_hidden_states
-            if output_hidden_states is not None
-            else self.config.output_hidden_states
-        )
-        return_dict = (
-            return_dict if return_dict is not None else self.config.use_return_dict
-        )
+        output_attentions = (output_attentions if output_attentions is not None
+                             else self.config.output_attentions)
+        output_hidden_states = (output_hidden_states
+                                if output_hidden_states is not None else
+                                self.config.output_hidden_states)
+        return_dict = (return_dict if return_dict is not None else
+                       self.config.use_return_dict)
 
         # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
         outputs = self.model(
@@ -894,8 +889,8 @@ class TransnormerForCausalLM(TransnormerPreTrainedModel):
             loss = loss_fct(shift_logits, shift_labels)
 
         if not return_dict:
-            output = (logits,) + outputs[1:]
-            return (loss,) + output if loss is not None else output
+            output = (logits, ) + outputs[1:]
+            return (loss, ) + output if loss is not None else output
 
         return CausalLMOutputWithPast(
             loss=loss,
@@ -922,22 +917,18 @@ class TransnormerForCausalLM(TransnormerPreTrainedModel):
         else:
             model_inputs = {"input_ids": input_ids}
 
-        model_inputs.update(
-            {
-                "past_key_values": past_key_values,
-                "use_cache": kwargs.get("use_cache"),
-                "attention_mask": attention_mask,
-            }
-        )
+        model_inputs.update({
+            "past_key_values": past_key_values,
+            "use_cache": kwargs.get("use_cache"),
+            "attention_mask": attention_mask,
+        })
 
         return model_inputs
 
     @staticmethod
     def _reorder_cache(past_key_values, beam_idx):
         reordered_past = ()
         for layer_past in past_key_values:
-            reordered_past += (
-                tuple(
-                    past_state.index_select(0, beam_idx) for past_state in layer_past
-                ),
-            )
+            reordered_past += (tuple(
+                past_state.index_select(0, beam_idx)
+                for past_state in layer_past), )
         return reordered_past
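Note on the new module-level switches: the commit hoists `do_eval`, `eval_and_not_generate`, and the `BLOCK = 256` chunk size up next to `use_triton` and `debug`, all parsed with `eval(os.environ.get(...))` when modeling_transnormer.py is imported. A minimal sketch of setting them before loading the checkpoint; the `from_pretrained` calls and the chosen values are illustrative assumptions, not part of this diff:

```python
import os

# These switches are read with eval(os.environ.get(...)) at import time,
# so they have to be set before the modeling code is imported.
os.environ["use_triton"] = "False"             # skip the Triton lightning-attention import
os.environ["debug"] = "False"
os.environ["do_eval"] = "False"                # False: a non-training forward() routes to inference()
os.environ["eval_and_not_generate"] = "False"

from transformers import AutoModelForCausalLM, AutoTokenizer

# Assumed loading path: trust_remote_code pulls in this modeling_transnormer.py.
tokenizer = AutoTokenizer.from_pretrained("OpenNLPLab/TransNormerLLM-7B",
                                          trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained("OpenNLPLab/TransNormerLLM-7B",
                                             trust_remote_code=True)
```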
 
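For reference, the decode branch of `NormLinearAttention.inference` keeps one exponentially decayed key-value state per head rather than attending over the whole prefix, and the new prefill branch builds the same state in `BLOCK`-sized chunks. A self-contained sketch of the per-token recurrence (the function name and the toy shapes below are assumptions, not the committed code):

```python
import torch


def decayed_linear_attention_decode(q, k, v, slope_rate):
    """Per-token form of the decayed linear-attention recurrence:
        kv_t = exp(-s) * kv_{t-1} + k_t^T v_t
        o_t  = q_t @ kv_t
    q, k: (b, h, n, d); v: (b, h, n, e); slope_rate: (h, 1, 1) with s > 0."""
    b, h, n, d = q.shape
    e = v.shape[-1]
    ratio = torch.exp(-slope_rate)  # per-head decay factor in (0, 1)
    kv = torch.zeros(b, h, d, e, dtype=torch.float32, device=q.device)
    output = []
    for i in range(n):
        # decay the running outer-product state and add k_t^T v_t
        kv = ratio * kv + torch.einsum("... n d, ... n e -> ... d e",
                                       k[:, :, i:i + 1], v[:, :, i:i + 1])
        # read the state with the current query
        output.append(
            torch.einsum("... n e, ... e d -> ... n d", q[:, :, i:i + 1], kv))
    return torch.concat(output, dim=-2)


# Toy usage (shapes are assumptions): 1 sequence, 2 heads, 5 tokens, head dim 4.
q = torch.randn(1, 2, 5, 4)
k = torch.randn(1, 2, 5, 4)
v = torch.randn(1, 2, 5, 4)
slope_rate = torch.full((2, 1, 1), 0.5)
print(decayed_linear_attention_decode(q, k, v, slope_rate).shape)  # (1, 2, 5, 4)
```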