ClaudiaIoana550 commited on
Commit
c563e3f
1 Parent(s): 588e93d

Update modeling_falcon.py

Browse files
Files changed (1) hide show
  1. modeling_falcon.py +638 -0
modeling_falcon.py CHANGED
@@ -171,6 +171,7 @@ def build_alibi_tensor(attention_mask: torch.Tensor, num_heads: int, dtype: torc
171
  def dropout_add(x: torch.Tensor, residual: torch.Tensor, prob: float, training: bool) -> torch.Tensor:
172
  """
173
  Dropout add function
 
174
  Args:
175
  x (`torch.tensor`, *required*):
176
  input tensor
@@ -223,8 +224,10 @@ class FalconAttention(nn.Module):
223
  def _split_heads(self, fused_qkv: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
224
  """
225
  Split the last dimension into (num_heads, head_dim), results share same memory storage as `fused_qkv`
 
226
  Args:
227
  fused_qkv (`torch.tensor`, *required*): [batch_size, seq_length, num_heads * 3 * head_dim]
 
228
  Returns:
229
  query: [batch_size, seq_length, num_heads, head_dim] key: [batch_size, seq_length, num_heads, head_dim]
230
  value: [batch_size, seq_length, num_heads, head_dim]
@@ -253,8 +256,10 @@ class FalconAttention(nn.Module):
253
  def _merge_heads(self, x: torch.Tensor) -> torch.Tensor:
254
  """
255
  Merge heads together over the last dimenstion
 
256
  Args:
257
  x (`torch.tensor`, *required*): [batch_size * num_heads, seq_length, head_dim]
 
258
  Returns:
259
  torch.tensor: [batch_size, seq_length, num_heads * head_dim]
260
  """
@@ -485,11 +490,14 @@ class FalconDecoderLayer(nn.Module):
485
 
486
 
487
  FALCON_START_DOCSTRING = r"""
 
488
  This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
489
  library implements for all its model (such as downloading or saving, resizing the input embeddings etc.)
 
490
  This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
491
  Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
492
  and behavior.
 
493
  Parameters:
494
  config ([`FalconConfig`]): Model configuration class with all the parameters of the model.
495
  Initializing with a config file does not load the weights associated with the model, only the
@@ -501,31 +509,40 @@ FALCON_INPUTS_DOCSTRING = r"""
501
  input_ids (`torch.LongTensor` of shape `(batch_size, input_ids_length)`):
502
  `input_ids_length` = `sequence_length` if `past_key_values` is `None` else `past_key_values[0][0].shape[2]`
503
  (`sequence_length` of input past key value states). Indices of input sequence tokens in the vocabulary.
 
504
  If `past_key_values` is used, only `input_ids` that do not have their past calculated should be passed as
505
  `input_ids`.
 
506
  Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
507
  [`PreTrainedTokenizer.__call__`] for details.
 
508
  [What are input IDs?](../glossary#input-ids)
509
  past_key_values (`Tuple[Tuple[torch.Tensor]]` of length `config.num_hidden_layers`):
510
  Contains precomputed hidden-states (key and values in the attention blocks) as computed by the model (see
511
  `past_key_values` output below). Can be used to speed up sequential decoding. The `input_ids` which have
512
  their past given to this model should not be passed as `input_ids` as they have already been computed.
 
513
  Each element of `past_key_values` is a tuple (past_key, past_value):
514
  - past_key: [batch_size * num_heads, head_dim, kv_length]
515
  - past_value: [batch_size * num_heads, kv_length, head_dim]
516
  attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*):
517
  Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
 
518
  - 1 for tokens that are **not masked**,
519
  - 0 for tokens that are **masked**.
 
520
  [What are attention masks?](../glossary#attention-mask)
521
  head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
522
  Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`:
 
523
  - 1 indicates the head is **not masked**,
524
  - 0 indicates the head is **masked**.
 
525
  inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
526
  Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
527
  is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
528
  model's internal embedding lookup matrix.
 
529
  If `past_key_values` is used, optionally only the last `inputs_embeds` have to be input (see
530
  `past_key_values`).
531
  use_cache (`bool`, *optional*):
@@ -622,3 +639,624 @@ class FalconModel(FalconPreTrainedModel):
622
  def __init__(self, config: FalconConfig):
623
  super().__init__(config)
624
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
171
  def dropout_add(x: torch.Tensor, residual: torch.Tensor, prob: float, training: bool) -> torch.Tensor:
172
  """
173
  Dropout add function
174
+
175
  Args:
176
  x (`torch.tensor`, *required*):
177
  input tensor
 
224
  def _split_heads(self, fused_qkv: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
225
  """
226
  Split the last dimension into (num_heads, head_dim), results share same memory storage as `fused_qkv`
227
+
228
  Args:
229
  fused_qkv (`torch.tensor`, *required*): [batch_size, seq_length, num_heads * 3 * head_dim]
230
+
231
  Returns:
232
  query: [batch_size, seq_length, num_heads, head_dim] key: [batch_size, seq_length, num_heads, head_dim]
233
  value: [batch_size, seq_length, num_heads, head_dim]
 
256
  def _merge_heads(self, x: torch.Tensor) -> torch.Tensor:
257
  """
258
  Merge heads together over the last dimenstion
259
+
260
  Args:
261
  x (`torch.tensor`, *required*): [batch_size * num_heads, seq_length, head_dim]
262
+
263
  Returns:
264
  torch.tensor: [batch_size, seq_length, num_heads * head_dim]
265
  """
 
490
 
491
 
492
  FALCON_START_DOCSTRING = r"""
493
+
494
  This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
495
  library implements for all its model (such as downloading or saving, resizing the input embeddings etc.)
496
+
497
  This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
498
  Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
499
  and behavior.
500
+
501
  Parameters:
502
  config ([`FalconConfig`]): Model configuration class with all the parameters of the model.
503
  Initializing with a config file does not load the weights associated with the model, only the
 
509
  input_ids (`torch.LongTensor` of shape `(batch_size, input_ids_length)`):
510
  `input_ids_length` = `sequence_length` if `past_key_values` is `None` else `past_key_values[0][0].shape[2]`
511
  (`sequence_length` of input past key value states). Indices of input sequence tokens in the vocabulary.
512
+
513
  If `past_key_values` is used, only `input_ids` that do not have their past calculated should be passed as
514
  `input_ids`.
515
+
516
  Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
517
  [`PreTrainedTokenizer.__call__`] for details.
518
+
519
  [What are input IDs?](../glossary#input-ids)
520
  past_key_values (`Tuple[Tuple[torch.Tensor]]` of length `config.num_hidden_layers`):
521
  Contains precomputed hidden-states (key and values in the attention blocks) as computed by the model (see
522
  `past_key_values` output below). Can be used to speed up sequential decoding. The `input_ids` which have
523
  their past given to this model should not be passed as `input_ids` as they have already been computed.
524
+
525
  Each element of `past_key_values` is a tuple (past_key, past_value):
526
  - past_key: [batch_size * num_heads, head_dim, kv_length]
527
  - past_value: [batch_size * num_heads, kv_length, head_dim]
528
  attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*):
529
  Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
530
+
531
  - 1 for tokens that are **not masked**,
532
  - 0 for tokens that are **masked**.
533
+
534
  [What are attention masks?](../glossary#attention-mask)
535
  head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
536
  Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`:
537
+
538
  - 1 indicates the head is **not masked**,
539
  - 0 indicates the head is **masked**.
540
+
541
  inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
542
  Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
543
  is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
544
  model's internal embedding lookup matrix.
545
+
546
  If `past_key_values` is used, optionally only the last `inputs_embeds` have to be input (see
547
  `past_key_values`).
548
  use_cache (`bool`, *optional*):
 
639
  def __init__(self, config: FalconConfig):
640
  super().__init__(config)
641
 
642
+ self.embed_dim = config.hidden_size
643
+ self.num_heads = config.num_attention_heads
644
+ self.use_alibi = config.alibi
645
+
646
+ # Embedding + LN Embedding
647
+ self.word_embeddings = nn.Embedding(config.vocab_size, self.embed_dim)
648
+
649
+ # Transformer blocks
650
+ self.h = nn.ModuleList([FalconDecoderLayer(config) for _ in range(config.num_hidden_layers)])
651
+
652
+ # Final Layer Norm
653
+ self.ln_f = LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)
654
+
655
+ self.gradient_checkpointing = False
656
+
657
+ # Initialize weights and apply final processing
658
+ self.post_init()
659
+
660
+ def get_input_embeddings(self):
661
+ return self.word_embeddings
662
+
663
+ @staticmethod
664
+ def _prepare_attn_mask(
665
+ attention_mask: torch.Tensor, input_shape: Tuple[int, int], past_key_values_length: int
666
+ ) -> torch.BoolTensor:
667
+ # Create a causal mask
668
+ # The attention mask we receive as input should cover the whole extended sequence, including any past
669
+ # cache, so its shape should be [batch_size, seq_length + past_key_values_length]
670
+ # The output shape will be [batch_size, 1, seq_length, seq_length + past_key_values_length]
671
+ if input_shape[1] + past_key_values_length != attention_mask.shape[1]:
672
+ raise ValueError(
673
+ "Attention mask shape should be (batch_size, seq_length + past_key_values_length)"
674
+ f" but is {attention_mask.shape} with input_ids shape {input_shape} and past length"
675
+ f" {past_key_values_length}."
676
+ )
677
+ combined_attention_mask = None
678
+ device = attention_mask.device
679
+ _, seq_length = input_shape
680
+
681
+ if seq_length > 1:
682
+ combined_attention_mask = _make_causal_mask(
683
+ input_shape, device=device, past_key_values_length=past_key_values_length
684
+ )
685
+
686
+ # [batch_size, seq_length + past_key_values_length] -> [batch_size, 1, seq_length, seq_length + past_key_values_length]
687
+ expanded_attn_mask = _expand_mask(attention_mask, past_key_values_length=past_key_values_length)
688
+ combined_attention_mask = (
689
+ expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask | combined_attention_mask
690
+ )
691
+
692
+ return combined_attention_mask
693
+
694
+ def set_input_embeddings(self, new_embeddings: torch.Tensor):
695
+ self.word_embeddings = new_embeddings
696
+
697
+ @add_start_docstrings_to_model_forward(FALCON_INPUTS_DOCSTRING)
698
+ @add_code_sample_docstrings(
699
+ checkpoint=_CHECKPOINT_FOR_DOC,
700
+ output_type=BaseModelOutputWithPastAndCrossAttentions,
701
+ config_class=_CONFIG_FOR_DOC,
702
+ )
703
+ def forward(
704
+ self,
705
+ input_ids: Optional[torch.LongTensor] = None,
706
+ past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None,
707
+ attention_mask: Optional[torch.Tensor] = None,
708
+ head_mask: Optional[torch.LongTensor] = None,
709
+ inputs_embeds: Optional[torch.LongTensor] = None,
710
+ use_cache: Optional[bool] = None,
711
+ output_attentions: Optional[bool] = None,
712
+ output_hidden_states: Optional[bool] = None,
713
+ return_dict: Optional[bool] = None,
714
+ ) -> Union[Tuple[torch.Tensor, ...], BaseModelOutputWithPastAndCrossAttentions]:
715
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
716
+ output_hidden_states = (
717
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
718
+ )
719
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
720
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
721
+
722
+ if input_ids is not None and inputs_embeds is not None:
723
+ raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
724
+ elif input_ids is not None:
725
+ batch_size, seq_length = input_ids.shape
726
+ elif inputs_embeds is not None:
727
+ batch_size, seq_length, _ = inputs_embeds.shape
728
+ else:
729
+ raise ValueError("You have to specify either input_ids or inputs_embeds")
730
+
731
+ if past_key_values is None:
732
+ past_key_values = tuple([None] * len(self.h))
733
+ else:
734
+ past_key_values = self._convert_to_rw_cache(past_key_values)
735
+
736
+ # Prepare head mask if needed
737
+ # 1.0 in head_mask indicate we keep the head
738
+ # attention_probs has shape batch_size x num_heads x N x N
739
+ # head_mask has shape n_layer x batch x num_heads x N x N
740
+ head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
741
+
742
+ if inputs_embeds is None:
743
+ inputs_embeds = self.word_embeddings(input_ids)
744
+
745
+ hidden_states = inputs_embeds
746
+
747
+ presents = () if use_cache else None
748
+ all_self_attentions = () if output_attentions else None
749
+ all_hidden_states = () if output_hidden_states else None
750
+
751
+ # Compute alibi tensor: check build_alibi_tensor documentation
752
+ past_key_values_length = 0
753
+ if past_key_values[0] is not None:
754
+ past_key_values_length = past_key_values[0][0].shape[1] # 1 because RW-cache, not standard format
755
+ if attention_mask is None:
756
+ attention_mask = torch.ones((batch_size, seq_length + past_key_values_length), device=hidden_states.device)
757
+ else:
758
+ attention_mask = attention_mask.to(hidden_states.device)
759
+
760
+ if self.use_alibi:
761
+ alibi = build_alibi_tensor(attention_mask, self.num_heads, dtype=hidden_states.dtype)
762
+ else:
763
+ alibi = None
764
+
765
+ causal_mask = self._prepare_attn_mask(
766
+ attention_mask,
767
+ input_shape=(batch_size, seq_length),
768
+ past_key_values_length=past_key_values_length,
769
+ )
770
+
771
+ for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)):
772
+ if output_hidden_states:
773
+ all_hidden_states = all_hidden_states + (hidden_states,)
774
+
775
+ if self.gradient_checkpointing and self.training:
776
+ if use_cache:
777
+ logger.warning(
778
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
779
+ )
780
+ use_cache = False
781
+
782
+ def create_custom_forward(module):
783
+ def custom_forward(*inputs):
784
+ # None for past_key_value
785
+ return module(*inputs, use_cache=use_cache, output_attentions=output_attentions)
786
+
787
+ return custom_forward
788
+
789
+ outputs = torch.utils.checkpoint.checkpoint(
790
+ create_custom_forward(block),
791
+ hidden_states,
792
+ alibi,
793
+ causal_mask,
794
+ head_mask[i],
795
+ )
796
+ else:
797
+ outputs = block(
798
+ hidden_states,
799
+ layer_past=layer_past,
800
+ attention_mask=causal_mask,
801
+ head_mask=head_mask[i],
802
+ use_cache=use_cache,
803
+ output_attentions=output_attentions,
804
+ alibi=alibi,
805
+ )
806
+
807
+ hidden_states = outputs[0]
808
+ if use_cache is True:
809
+ presents = presents + (outputs[1],)
810
+
811
+ if output_attentions:
812
+ all_self_attentions = all_self_attentions + (outputs[2 if use_cache else 1],)
813
+
814
+ # Add last hidden state
815
+ hidden_states = self.ln_f(hidden_states)
816
+
817
+ if output_hidden_states:
818
+ all_hidden_states = all_hidden_states + (hidden_states,)
819
+
820
+ if presents is not None:
821
+ presents = self._convert_cache_to_standard_format(presents, batch_size)
822
+
823
+ if not return_dict:
824
+ return tuple(v for v in [hidden_states, presents, all_hidden_states, all_self_attentions] if v is not None)
825
+
826
+ return BaseModelOutputWithPastAndCrossAttentions(
827
+ last_hidden_state=hidden_states,
828
+ past_key_values=presents,
829
+ hidden_states=all_hidden_states,
830
+ attentions=all_self_attentions,
831
+ )
832
+
833
+
834
+ @add_start_docstrings(
835
+ "The Falcon Model transformer with a language modeling head on top (linear layer with weights tied to the input embeddings).",
836
+ FALCON_START_DOCSTRING,
837
+ )
838
+ class FalconForCausalLM(FalconPreTrainedModel):
839
+ _tied_weights_keys = ["lm_head.weight"]
840
+
841
+ def __init__(self, config: FalconConfig):
842
+ super().__init__(config)
843
+ self.transformer = FalconModel(config)
844
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
845
+
846
+ # Initialize weights and apply final processing
847
+ self.post_init()
848
+
849
+ def get_output_embeddings(self):
850
+ return self.lm_head
851
+
852
+ def set_output_embeddings(self, new_embeddings: torch.Tensor):
853
+ self.lm_head = new_embeddings
854
+
855
+ def prepare_inputs_for_generation(
856
+ self,
857
+ input_ids: torch.LongTensor,
858
+ past_key_values: Optional[torch.Tensor] = None,
859
+ attention_mask: Optional[torch.Tensor] = None,
860
+ **kwargs,
861
+ ) -> dict:
862
+ if past_key_values is not None:
863
+ input_ids = input_ids[:, -1:]
864
+
865
+ return {
866
+ "input_ids": input_ids,
867
+ "past_key_values": past_key_values,
868
+ "use_cache": kwargs.get("use_cache"),
869
+ "attention_mask": attention_mask,
870
+ }
871
+
872
+ @add_start_docstrings_to_model_forward(FALCON_INPUTS_DOCSTRING)
873
+ @add_code_sample_docstrings(
874
+ checkpoint=_CHECKPOINT_FOR_DOC,
875
+ output_type=CausalLMOutputWithCrossAttentions,
876
+ config_class=_CONFIG_FOR_DOC,
877
+ )
878
+ def forward(
879
+ self,
880
+ input_ids: Optional[torch.LongTensor] = None,
881
+ past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None,
882
+ attention_mask: Optional[torch.Tensor] = None,
883
+ head_mask: Optional[torch.Tensor] = None,
884
+ inputs_embeds: Optional[torch.Tensor] = None,
885
+ labels: Optional[torch.Tensor] = None,
886
+ use_cache: Optional[bool] = None,
887
+ output_attentions: Optional[bool] = None,
888
+ output_hidden_states: Optional[bool] = None,
889
+ return_dict: Optional[bool] = None,
890
+ ) -> Union[Tuple[torch.Tensor], CausalLMOutputWithCrossAttentions]:
891
+ r"""
892
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
893
+ Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
894
+ `labels = input_ids` Indices are selected in `[-100, 0, ..., config.vocab_size]` All labels set to `-100`
895
+ are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size]`
896
+ """
897
+
898
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
899
+
900
+ transformer_outputs = self.transformer(
901
+ input_ids,
902
+ past_key_values=past_key_values,
903
+ attention_mask=attention_mask,
904
+ head_mask=head_mask,
905
+ inputs_embeds=inputs_embeds,
906
+ use_cache=use_cache,
907
+ output_attentions=output_attentions,
908
+ output_hidden_states=output_hidden_states,
909
+ return_dict=return_dict,
910
+ )
911
+ hidden_states = transformer_outputs[0]
912
+
913
+ lm_logits = self.lm_head(hidden_states)
914
+
915
+ loss = None
916
+ if labels is not None:
917
+ # Shift so that tokens < n predict n
918
+ shift_logits = lm_logits[..., :-1, :].contiguous()
919
+ shift_labels = labels[..., 1:].contiguous()
920
+ batch_size, seq_length, vocab_size = shift_logits.shape
921
+ # Flatten the tokens
922
+ loss_fct = CrossEntropyLoss()
923
+ loss = loss_fct(
924
+ shift_logits.view(batch_size * seq_length, vocab_size), shift_labels.view(batch_size * seq_length)
925
+ )
926
+
927
+ if not return_dict:
928
+ output = (lm_logits,) + transformer_outputs[1:]
929
+ return ((loss,) + output) if loss is not None else output
930
+
931
+ return CausalLMOutputWithCrossAttentions(
932
+ loss=loss,
933
+ logits=lm_logits,
934
+ past_key_values=transformer_outputs.past_key_values,
935
+ hidden_states=transformer_outputs.hidden_states,
936
+ attentions=transformer_outputs.attentions,
937
+ )
938
+
939
+ def _reorder_cache(
940
+ self, past: Tuple[Tuple[torch.Tensor, torch.Tensor], ...], beam_idx: torch.LongTensor
941
+ ) -> Tuple[Tuple[torch.Tensor, torch.Tensor], ...]:
942
+ """
943
+ This function is used to re-order the `past_key_values` cache if [`~PreTrainedModel.beam_search`] or
944
+ [`~PreTrainedModel.beam_sample`] is called. This is required to match `past_key_values` with the correct
945
+ beam_idx at every generation step.
946
+
947
+ Output shares the same memory storage as `past`.
948
+ """
949
+
950
+ # Get a copy of `beam_idx` on all the devices where we need those indices.
951
+ device_to_beam_idx = {
952
+ past_state.device: beam_idx.to(past_state.device) for layer_past in past for past_state in layer_past
953
+ }
954
+ reordered_past = tuple(
955
+ (
956
+ layer_past[0].index_select(0, device_to_beam_idx[layer_past[0].device]),
957
+ layer_past[1].index_select(0, device_to_beam_idx[layer_past[0].device]),
958
+ )
959
+ for layer_past in past
960
+ )
961
+ return reordered_past
962
+
963
+
964
+ @add_start_docstrings(
965
+ """
966
+ The Falcon Model transformer with a sequence classification head on top (linear layer).
967
+
968
+ [`FalconForSequenceClassification`] uses the last token in order to do the classification, as other causal models
969
+ (e.g. GPT-1) do.
970
+
971
+ Since it does classification on the last token, it requires to know the position of the last token. If a
972
+ `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If
973
+ no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the
974
+ padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in
975
+ each row of the batch).
976
+ """,
977
+ FALCON_START_DOCSTRING,
978
+ )
979
+ class FalconForSequenceClassification(FalconPreTrainedModel):
980
+ def __init__(self, config: FalconConfig):
981
+ super().__init__(config)
982
+ self.num_labels = config.num_labels
983
+ self.transformer = FalconModel(config)
984
+ self.score = nn.Linear(config.hidden_size, config.num_labels, bias=False)
985
+
986
+ # Initialize weights and apply final processing
987
+ self.post_init()
988
+
989
+ @add_start_docstrings_to_model_forward(FALCON_INPUTS_DOCSTRING)
990
+ @add_code_sample_docstrings(
991
+ checkpoint=_CHECKPOINT_FOR_DOC,
992
+ output_type=SequenceClassifierOutputWithPast,
993
+ config_class=_CONFIG_FOR_DOC,
994
+ )
995
+ def forward(
996
+ self,
997
+ input_ids: Optional[torch.LongTensor] = None,
998
+ past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None,
999
+ attention_mask: Optional[torch.Tensor] = None,
1000
+ head_mask: Optional[torch.Tensor] = None,
1001
+ inputs_embeds: Optional[torch.Tensor] = None,
1002
+ labels: Optional[torch.Tensor] = None,
1003
+ use_cache: Optional[bool] = None,
1004
+ output_attentions: Optional[bool] = None,
1005
+ output_hidden_states: Optional[bool] = None,
1006
+ return_dict: Optional[bool] = None,
1007
+ ) -> Union[Tuple[torch.Tensor], SequenceClassifierOutputWithPast]:
1008
+ r"""
1009
+ labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
1010
+ Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
1011
+ config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
1012
+ `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
1013
+ """
1014
+
1015
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1016
+
1017
+ transformer_outputs = self.transformer(
1018
+ input_ids,
1019
+ past_key_values=past_key_values,
1020
+ attention_mask=attention_mask,
1021
+ head_mask=head_mask,
1022
+ inputs_embeds=inputs_embeds,
1023
+ use_cache=use_cache,
1024
+ output_attentions=output_attentions,
1025
+ output_hidden_states=output_hidden_states,
1026
+ return_dict=return_dict,
1027
+ )
1028
+
1029
+ hidden_states = transformer_outputs[0]
1030
+ logits = self.score(hidden_states)
1031
+
1032
+ if input_ids is not None:
1033
+ batch_size = input_ids.shape[0]
1034
+ else:
1035
+ batch_size = inputs_embeds.shape[0]
1036
+
1037
+ if self.config.pad_token_id is None and batch_size != 1:
1038
+ raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
1039
+ if self.config.pad_token_id is None:
1040
+ sequence_lengths = -1
1041
+ else:
1042
+ if input_ids is not None:
1043
+ sequence_lengths = torch.ne(input_ids, self.config.pad_token_id).sum(dim=-1) - 1
1044
+ else:
1045
+ sequence_lengths = -1
1046
+ logger.warning(
1047
+ f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be "
1048
+ "unexpected if using padding tokens in conjunction with `inputs_embeds.`"
1049
+ )
1050
+
1051
+ pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths]
1052
+
1053
+ loss = None
1054
+ if labels is not None:
1055
+ if self.config.problem_type is None:
1056
+ if self.num_labels == 1:
1057
+ self.config.problem_type = "regression"
1058
+ elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
1059
+ self.config.problem_type = "single_label_classification"
1060
+ else:
1061
+ self.config.problem_type = "multi_label_classification"
1062
+
1063
+ if self.config.problem_type == "regression":
1064
+ loss_fct = MSELoss()
1065
+ if self.num_labels == 1:
1066
+ loss = loss_fct(pooled_logits.squeeze(), labels.squeeze())
1067
+ else:
1068
+ loss = loss_fct(pooled_logits, labels)
1069
+ elif self.config.problem_type == "single_label_classification":
1070
+ loss_fct = CrossEntropyLoss()
1071
+ loss = loss_fct(pooled_logits, labels)
1072
+ elif self.config.problem_type == "multi_label_classification":
1073
+ loss_fct = BCEWithLogitsLoss()
1074
+ loss = loss_fct(pooled_logits, labels)
1075
+ if not return_dict:
1076
+ output = (pooled_logits,) + transformer_outputs[1:]
1077
+ return ((loss,) + output) if loss is not None else output
1078
+
1079
+ return SequenceClassifierOutputWithPast(
1080
+ loss=loss,
1081
+ logits=pooled_logits,
1082
+ past_key_values=transformer_outputs.past_key_values,
1083
+ hidden_states=transformer_outputs.hidden_states,
1084
+ attentions=transformer_outputs.attentions,
1085
+ )
1086
+
1087
+
1088
+ @add_start_docstrings(
1089
+ """
1090
+ Falcon Model with a token classification head on top (a linear layer on top of the hidden-states output) e.g. for
1091
+ Named-Entity-Recognition (NER) tasks.
1092
+ """,
1093
+ FALCON_START_DOCSTRING,
1094
+ )
1095
+ class FalconForTokenClassification(FalconPreTrainedModel):
1096
+ def __init__(self, config: FalconConfig):
1097
+ super().__init__(config)
1098
+ self.num_labels = config.num_labels
1099
+
1100
+ self.transformer = FalconModel(config)
1101
+ if getattr(config, "classifier_dropout", None) is not None:
1102
+ classifier_dropout = config.classifier_dropout
1103
+ elif getattr(config, "hidden_dropout", None) is not None:
1104
+ classifier_dropout = config.hidden_dropout
1105
+ else:
1106
+ classifier_dropout = 0.1
1107
+ self.dropout = nn.Dropout(classifier_dropout)
1108
+ self.classifier = nn.Linear(config.hidden_size, config.num_labels)
1109
+
1110
+ # Initialize weights and apply final processing
1111
+ self.post_init()
1112
+
1113
+ @add_start_docstrings_to_model_forward(FALCON_INPUTS_DOCSTRING)
1114
+ @add_code_sample_docstrings(
1115
+ checkpoint=_CHECKPOINT_FOR_DOC,
1116
+ output_type=TokenClassifierOutput,
1117
+ config_class=_CONFIG_FOR_DOC,
1118
+ )
1119
+ def forward(
1120
+ self,
1121
+ input_ids: Optional[torch.LongTensor] = None,
1122
+ past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None,
1123
+ attention_mask: Optional[torch.Tensor] = None,
1124
+ head_mask: Optional[torch.Tensor] = None,
1125
+ inputs_embeds: Optional[torch.Tensor] = None,
1126
+ labels: Optional[torch.Tensor] = None,
1127
+ use_cache: Optional[bool] = None,
1128
+ output_attentions: Optional[bool] = None,
1129
+ output_hidden_states: Optional[bool] = None,
1130
+ return_dict: Optional[bool] = None,
1131
+ ) -> Union[Tuple[torch.Tensor], TokenClassifierOutput]:
1132
+ r"""
1133
+ labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
1134
+ Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
1135
+ config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
1136
+ `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
1137
+ """
1138
+
1139
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1140
+
1141
+ transformer_outputs = self.transformer(
1142
+ input_ids,
1143
+ past_key_values=past_key_values,
1144
+ attention_mask=attention_mask,
1145
+ head_mask=head_mask,
1146
+ inputs_embeds=inputs_embeds,
1147
+ use_cache=use_cache,
1148
+ output_attentions=output_attentions,
1149
+ output_hidden_states=output_hidden_states,
1150
+ return_dict=return_dict,
1151
+ )
1152
+
1153
+ hidden_states = transformer_outputs[0]
1154
+ hidden_states = self.dropout(hidden_states)
1155
+ logits = self.classifier(hidden_states)
1156
+
1157
+ loss = None
1158
+ if labels is not None:
1159
+ batch_size, seq_length = labels.shape
1160
+ loss_fct = CrossEntropyLoss()
1161
+ loss = loss_fct(
1162
+ logits.view(batch_size * seq_length, self.num_labels), labels.view(batch_size * seq_length)
1163
+ )
1164
+
1165
+ if not return_dict:
1166
+ output = (logits,) + transformer_outputs[2:]
1167
+ return ((loss,) + output) if loss is not None else output
1168
+
1169
+ return TokenClassifierOutput(
1170
+ loss=loss,
1171
+ logits=logits,
1172
+ hidden_states=transformer_outputs.hidden_states,
1173
+ attentions=transformer_outputs.attentions,
1174
+ )
1175
+
1176
+
1177
+ @add_start_docstrings(
1178
+ """
1179
+ The Falcon Model transformer with a span classification head on top for extractive question-answering tasks like
1180
+ SQuAD (a linear layers on top of the hidden-states output to compute `span start logits` and `span end logits`).
1181
+ """,
1182
+ FALCON_START_DOCSTRING,
1183
+ )
1184
+ class FalconForQuestionAnswering(FalconPreTrainedModel):
1185
+ def __init__(self, config):
1186
+ super().__init__(config)
1187
+ self.transformer = FalconModel(config)
1188
+ self.qa_outputs = nn.Linear(config.hidden_size, 2)
1189
+
1190
+ # Initialize weights and apply final processing
1191
+ self.post_init()
1192
+
1193
+ @add_start_docstrings_to_model_forward(FALCON_INPUTS_DOCSTRING)
1194
+ def forward(
1195
+ self,
1196
+ input_ids: Optional[torch.LongTensor] = None,
1197
+ attention_mask: Optional[torch.FloatTensor] = None,
1198
+ head_mask: Optional[torch.FloatTensor] = None,
1199
+ inputs_embeds: Optional[torch.FloatTensor] = None,
1200
+ start_positions: Optional[torch.LongTensor] = None,
1201
+ end_positions: Optional[torch.LongTensor] = None,
1202
+ output_attentions: Optional[bool] = None,
1203
+ output_hidden_states: Optional[bool] = None,
1204
+ return_dict: Optional[bool] = None,
1205
+ ) -> Union[Tuple, QuestionAnsweringModelOutput]:
1206
+ r"""
1207
+ start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
1208
+ Labels for position (index) of the start of the labelled span for computing the token classification loss.
1209
+ Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence
1210
+ are not taken into account for computing the loss.
1211
+ end_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
1212
+ Labels for position (index) of the end of the labelled span for computing the token classification loss.
1213
+ Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence
1214
+ are not taken into account for computing the loss.
1215
+ """
1216
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1217
+
1218
+ outputs = self.transformer(
1219
+ input_ids,
1220
+ attention_mask=attention_mask,
1221
+ head_mask=head_mask,
1222
+ inputs_embeds=inputs_embeds,
1223
+ output_attentions=output_attentions,
1224
+ output_hidden_states=output_hidden_states,
1225
+ return_dict=return_dict,
1226
+ )
1227
+
1228
+ sequence_output = outputs[0]
1229
+
1230
+ logits = self.qa_outputs(sequence_output)
1231
+ start_logits, end_logits = logits.split(1, dim=-1)
1232
+ start_logits = start_logits.squeeze(-1).contiguous()
1233
+ end_logits = end_logits.squeeze(-1).contiguous()
1234
+
1235
+ total_loss = None
1236
+ if start_positions is not None and end_positions is not None:
1237
+ # If we are on multi-GPU, split add a dimension
1238
+ if len(start_positions.size()) > 1:
1239
+ start_positions = start_positions.squeeze(-1)
1240
+ if len(end_positions.size()) > 1:
1241
+ end_positions = end_positions.squeeze(-1)
1242
+ # sometimes the start/end positions are outside our model inputs, we ignore these terms
1243
+ ignored_index = start_logits.size(1)
1244
+ start_positions = start_positions.clamp(0, ignored_index)
1245
+ end_positions = end_positions.clamp(0, ignored_index)
1246
+
1247
+ loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
1248
+ start_loss = loss_fct(start_logits, start_positions)
1249
+ end_loss = loss_fct(end_logits, end_positions)
1250
+ total_loss = (start_loss + end_loss) / 2
1251
+
1252
+ if not return_dict:
1253
+ output = (start_logits, end_logits) + outputs[2:]
1254
+ return ((total_loss,) + output) if total_loss is not None else output
1255
+
1256
+ return QuestionAnsweringModelOutput(
1257
+ loss=total_loss,
1258
+ start_logits=start_logits,
1259
+ end_logits=end_logits,
1260
+ hidden_states=outputs.hidden_states,
1261
+ attentions=outputs.attentions,
1262
+ )