zifei9 committed
Commit 9ada1ae
Parent: 1305186

Update modeling_gpt2.py

Updating to be compatible with the latest transformers version.

Files changed (1)
  1. modeling_gpt2.py +558 -151
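
In summary, the update replaces the CUDA-only `torch.cuda.amp.autocast` context with the device-agnostic `torch.amp.autocast`, adds `GPT2FlashAttention2` and `GPT2SdpaAttention` classes selected through `GPT2_ATTENTION_CLASSES` via `config._attn_implementation`, adds the model-parallel `parallelize`/`deparallelize` helpers, and drops the custom `prepare_inputs_for_generation` overrides in favor of inheriting `GenerationMixin`. A rough usage sketch, not part of the commit (it assumes a recent transformers release that accepts `attn_implementation` in `from_pretrained` and uses the stock `openai-community/gpt2` checkpoint as a stand-in; a repo shipping this file as custom code would also need `trust_remote_code=True`):

```python
# Rough usage sketch, not part of the commit. Assumes a recent transformers
# release that accepts `attn_implementation` in from_pretrained; the stock
# "openai-community/gpt2" checkpoint is used as a stand-in.
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("openai-community/gpt2")
model = AutoModelForCausalLM.from_pretrained(
    "openai-community/gpt2",
    attn_implementation="sdpa",  # one of "eager", "sdpa", "flash_attention_2" (see GPT2_ATTENTION_CLASSES)
)

inputs = tokenizer("Hello, my name is", return_tensors="pt")
# generate() is inherited from GenerationMixin; the custom
# prepare_inputs_for_generation override removed in this diff is not needed.
output_ids = model.generate(**inputs, max_new_tokens=20, do_sample=False)
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))
```
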
modeling_gpt2.py CHANGED
@@ -23,11 +23,16 @@ from typing import Optional, Tuple, Union
23
 
24
  import torch
25
  import torch.utils.checkpoint
 
26
  from torch import nn
27
- from torch.cuda.amp import autocast
28
  from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
29
 
30
  from transformers.activations import ACT2FN
31
  from transformers.modeling_outputs import (
32
  BaseModelOutputWithPastAndCrossAttentions,
33
  CausalLMOutputWithCrossAttentions,
@@ -46,26 +51,25 @@ from transformers.utils import (
46
  add_code_sample_docstrings,
47
  add_start_docstrings,
48
  add_start_docstrings_to_model_forward,
49
  logging,
50
  replace_return_docstrings,
51
  )
 
52
  from .configuration_gpt2 import GPT2Config
53
- # from mltools.dmx import DmxPreTrainedModel as PreTrainedModel
54
 
55
  logger = logging.get_logger(__name__)
56
 
57
- _CHECKPOINT_FOR_DOC = "gpt2"
58
  _CONFIG_FOR_DOC = "GPT2Config"
59
 
60
- GPT2_PRETRAINED_MODEL_ARCHIVE_LIST = [
61
- "gpt2",
62
- "gpt2-medium",
63
- "gpt2-large",
64
- "gpt2-xl",
65
- "distilgpt2",
66
- # See all GPT-2 models at https://huggingface.co/models?filter=gpt2
67
- ]
68
-
69
 
70
  def load_tf_weights_in_gpt2(model, config, gpt2_checkpoint_path):
71
  """Load tf checkpoints in a pytorch model"""
@@ -128,7 +132,7 @@ def load_tf_weights_in_gpt2(model, config, gpt2_checkpoint_path):
128
  class GPT2Attention(nn.Module):
129
  def __init__(self, config, is_cross_attention=False, layer_idx=None):
130
  super().__init__()
131
-
132
  max_positions = config.max_position_embeddings
133
  self.register_buffer(
134
  "bias",
@@ -166,6 +170,7 @@ class GPT2Attention(nn.Module):
166
 
167
  self.attn_dropout = nn.Dropout(config.attn_pdrop)
168
  self.resid_dropout = nn.Dropout(config.resid_pdrop)
 
169
 
170
  self.pruned_heads = set()
171
 
@@ -210,7 +215,7 @@ class GPT2Attention(nn.Module):
210
  query_length, key_length = query.size(-2), key.size(-2)
211
  causal_mask = self.bias[
212
  :, :, key_length - query_length : key_length, :key_length
213
- ].to(attn_weights.device)
214
  mask_value = torch.finfo(attn_weights.dtype).min
215
  # Need to be a tensor, otherwise we get error: `RuntimeError: expected scalar type float but found double`.
216
  # Need to be on the same device, otherwise `RuntimeError: ..., x and y to be on the same device`
@@ -264,7 +269,7 @@ class GPT2Attention(nn.Module):
264
  scale_factor /= float(self.layer_idx + 1)
265
 
266
  # Upcast (turn off autocast) and reorder (Scale K by 1 / root(dk))
267
- with autocast(enabled=False):
268
  q, k = query.reshape(-1, q_seq_len, dk), key.transpose(-1, -2).reshape(
269
  -1, dk, k_seq_len
270
  )
@@ -385,6 +390,244 @@ class GPT2Attention(nn.Module):
385
  return outputs # a, present, (attentions)
386
 
387
 
388
  class GPT2MLP(nn.Module):
389
  def __init__(self, intermediate_size, config):
390
  super().__init__()
@@ -404,19 +647,27 @@ class GPT2MLP(nn.Module):
404
  return hidden_states
405
 
406
 
407
  class GPT2Block(nn.Module):
408
  def __init__(self, config, layer_idx=None):
409
  super().__init__()
410
  hidden_size = config.hidden_size
411
  inner_dim = config.n_inner if config.n_inner is not None else 4 * hidden_size
 
412
 
413
  self.ln_1 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
414
- self.attn = GPT2Attention(config, layer_idx=layer_idx)
415
  self.ln_2 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
416
 
417
  if config.add_cross_attention:
418
- self.crossattention = GPT2Attention(
419
- config, is_cross_attention=True, layer_idx=layer_idx
420
  )
421
  self.ln_cross_attn = nn.LayerNorm(
422
  hidden_size, eps=config.layer_norm_epsilon
@@ -500,9 +751,12 @@ class GPT2PreTrainedModel(PreTrainedModel):
500
  config_class = GPT2Config
501
  load_tf_weights = load_tf_weights_in_gpt2
502
  base_model_prefix = "transformer"
 
503
  supports_gradient_checkpointing = True
504
  _no_split_modules = ["GPT2Block"]
505
  _skip_keys_device_placement = "past_key_values"
 
 
506
 
507
  def __init__(self, *inputs, **kwargs):
508
  super().__init__(*inputs, **kwargs)
@@ -666,6 +920,56 @@ GPT2_INPUTS_DOCSTRING = r"""
666
  return_dict (`bool`, *optional*):
667
  Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
668
  """
669
 
670
 
671
  @add_start_docstrings(
@@ -673,6 +977,8 @@ GPT2_INPUTS_DOCSTRING = r"""
673
  GPT2_START_DOCSTRING,
674
  )
675
  class GPT2Model(GPT2PreTrainedModel):
 
 
676
  def __init__(self, config):
677
  super().__init__(config)
678
 
@@ -687,11 +993,65 @@ class GPT2Model(GPT2PreTrainedModel):
687
  )
688
  self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)
689
 
690
  self.gradient_checkpointing = False
 
691
 
692
  # Initialize weights and apply final processing
693
  self.post_init()
694
 
695
  def get_input_embeddings(self):
696
  return self.wte
697
 
@@ -776,38 +1136,71 @@ class GPT2Model(GPT2PreTrainedModel):
776
  )
777
  position_ids = position_ids.unsqueeze(0)
778
 
779
- # GPT2Attention mask.
780
- if attention_mask is not None:
781
- if batch_size <= 0:
782
- raise ValueError("batch_size has to be defined and > 0")
783
- attention_mask = attention_mask.view(batch_size, -1)
784
- # We create a 3D attention mask from a 2D tensor mask.
785
- # Sizes are [batch_size, 1, 1, to_seq_length]
786
- # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
787
- # this attention mask is more simple than the triangular masking of causal attention
788
- # used in OpenAI GPT, we just need to prepare the broadcast dimension here.
789
- attention_mask = attention_mask[:, None, None, :]
790
-
791
- # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
792
- # masked positions, this operation will create a tensor which is 0.0 for
793
- # positions we want to attend and the dtype's smallest value for masked positions.
794
- # Since we are adding it to the raw scores before the softmax, this is
795
- # effectively the same as removing these entirely.
796
- attention_mask = attention_mask.to(dtype=self.dtype) # fp16 compatibility
797
- attention_mask = (1.0 - attention_mask) * torch.finfo(self.dtype).min
798
 
799
  # If a 2D or 3D attention mask is provided for the cross-attention
800
  # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
801
  if self.config.add_cross_attention and encoder_hidden_states is not None:
802
- (
803
- encoder_batch_size,
804
- encoder_sequence_length,
805
- _,
806
- ) = encoder_hidden_states.size()
807
  encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
808
  if encoder_attention_mask is None:
809
  encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)
810
- encoder_attention_mask = self.invert_attention_mask(encoder_attention_mask)
811
  else:
812
  encoder_attention_mask = None
813
 
@@ -817,11 +1210,6 @@ class GPT2Model(GPT2PreTrainedModel):
817
  # head_mask has shape n_layer x batch x n_heads x N x N
818
  head_mask = self.get_head_mask(head_mask, self.config.n_layer)
819
 
820
- if inputs_embeds is None:
821
- inputs_embeds = self.wte(input_ids)
822
- position_embeds = self.wpe(position_ids)
823
- hidden_states = inputs_embeds.to(position_embeds.device) + position_embeds
824
-
825
  if token_type_ids is not None:
826
  token_type_embeds = self.wte(token_type_ids)
827
  hidden_states = hidden_states + token_type_embeds
@@ -843,17 +1231,21 @@ class GPT2Model(GPT2PreTrainedModel):
843
  () if output_attentions and self.config.add_cross_attention else None
844
  )
845
  all_hidden_states = () if output_hidden_states else None
846
- for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)):
847
- # # Ensure layer_past is on same device as hidden_states (might not be correct)
848
- if layer_past is not None:
849
- layer_past = tuple(
850
- past_state.to(hidden_states.device) for past_state in layer_past
851
- )
852
- # # Ensure that attention_mask is always on the same device as hidden_states
853
- if attention_mask is not None:
854
- attention_mask = attention_mask.to(hidden_states.device)
855
- if isinstance(head_mask, torch.Tensor):
856
- head_mask = head_mask.to(hidden_states.device)
857
  if output_hidden_states:
858
  all_hidden_states = all_hidden_states + (hidden_states,)
859
 
@@ -894,6 +1286,12 @@ class GPT2Model(GPT2PreTrainedModel):
894
  outputs[3 if use_cache else 2],
895
  )
896
 
897
  hidden_states = self.ln_f(hidden_states)
898
 
899
  hidden_states = hidden_states.view(output_shape)
@@ -930,7 +1328,7 @@ class GPT2Model(GPT2PreTrainedModel):
930
  """,
931
  GPT2_START_DOCSTRING,
932
  )
933
- class GPT2LMHeadModel(GPT2PreTrainedModel):
934
  _tied_weights_keys = ["lm_head.weight"]
935
 
936
  def __init__(self, config):
@@ -938,64 +1336,50 @@ class GPT2LMHeadModel(GPT2PreTrainedModel):
938
  self.transformer = GPT2Model(config)
939
  self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
940
 
941
  # Initialize weights and apply final processing
942
  self.post_init()
943
 
944
  def get_output_embeddings(self):
945
  return self.lm_head
946
 
947
  def set_output_embeddings(self, new_embeddings):
948
  self.lm_head = new_embeddings
949
 
950
- def prepare_inputs_for_generation(
951
- self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs
952
- ):
953
- token_type_ids = kwargs.get("token_type_ids", None)
954
- # Omit tokens covered by past_key_values
955
- if past_key_values:
956
- past_length = past_key_values[0][0].shape[2]
957
-
958
- # Some generation methods already pass only the last input ID
959
- if input_ids.shape[1] > past_length:
960
- remove_prefix_length = past_length
961
- else:
962
- # Default to old behavior: keep only final ID
963
- remove_prefix_length = input_ids.shape[1] - 1
964
-
965
- input_ids = input_ids[:, remove_prefix_length:]
966
- if token_type_ids is not None:
967
- token_type_ids = token_type_ids[:, -input_ids.shape[1] :]
968
-
969
- attention_mask = kwargs.get("attention_mask", None)
970
- position_ids = kwargs.get("position_ids", None)
971
-
972
- if attention_mask is not None and position_ids is None:
973
- # create position_ids on the fly for batch generation
974
- position_ids = attention_mask.long().cumsum(-1) - 1
975
- position_ids.masked_fill_(attention_mask == 0, 1)
976
- if past_key_values:
977
- position_ids = position_ids[:, -input_ids.shape[1] :]
978
- else:
979
- position_ids = None
980
-
981
- # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
982
- if inputs_embeds is not None and past_key_values is None:
983
- model_inputs = {"inputs_embeds": inputs_embeds}
984
- else:
985
- model_inputs = {"input_ids": input_ids}
986
-
987
- model_inputs.update(
988
- {
989
- "past_key_values": past_key_values,
990
- "use_cache": kwargs.get("use_cache"),
991
- "position_ids": position_ids,
992
- "attention_mask": attention_mask,
993
- "token_type_ids": token_type_ids,
994
- }
995
- )
996
-
997
- return model_inputs
998
-
999
  @add_start_docstrings_to_model_forward(GPT2_INPUTS_DOCSTRING)
1000
  @add_code_sample_docstrings(
1001
  checkpoint=_CHECKPOINT_FOR_DOC,
@@ -1046,6 +1430,11 @@ class GPT2LMHeadModel(GPT2PreTrainedModel):
1046
  )
1047
  hidden_states = transformer_outputs[0]
1048
 
 
1049
  lm_logits = self.lm_head(hidden_states)
1050
 
1051
  loss = None
@@ -1101,7 +1490,7 @@ input sequence).
1101
  """,
1102
  GPT2_START_DOCSTRING,
1103
  )
1104
- class GPT2DoubleHeadsModel(GPT2PreTrainedModel):
1105
  _tied_weights_keys = ["lm_head.weight"]
1106
 
1107
  def __init__(self, config):
@@ -1111,53 +1500,54 @@ class GPT2DoubleHeadsModel(GPT2PreTrainedModel):
1111
  self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
1112
  self.multiple_choice_head = SequenceSummary(config)
1113
 
1114
  # Initialize weights and apply final processing
1115
  self.post_init()
1116
 
1117
  def get_output_embeddings(self):
1118
  return self.lm_head
1119
 
1120
  def set_output_embeddings(self, new_embeddings):
1121
  self.lm_head = new_embeddings
1122
 
1123
- def prepare_inputs_for_generation(self, input_ids, past_key_values=None, **kwargs):
1124
- token_type_ids = kwargs.get("token_type_ids", None)
1125
- # Omit tokens covered by past_key_values
1126
- if past_key_values:
1127
- past_length = past_key_values[0][0].shape[2]
1128
-
1129
- # Some generation methods already pass only the last input ID
1130
- if input_ids.shape[1] > past_length:
1131
- remove_prefix_length = past_length
1132
- else:
1133
- # Default to old behavior: keep only final ID
1134
- remove_prefix_length = input_ids.shape[1] - 1
1135
-
1136
- input_ids = input_ids[:, remove_prefix_length:]
1137
- if token_type_ids is not None:
1138
- token_type_ids = token_type_ids[:, -input_ids.shape[1] :]
1139
-
1140
- attention_mask = kwargs.get("attention_mask", None)
1141
- position_ids = kwargs.get("position_ids", None)
1142
-
1143
- if attention_mask is not None and position_ids is None:
1144
- # create position_ids on the fly for batch generation
1145
- position_ids = attention_mask.long().cumsum(-1) - 1
1146
- position_ids.masked_fill_(attention_mask == 0, 1)
1147
- if past_key_values:
1148
- position_ids = position_ids[:, -input_ids.shape[1] :]
1149
- else:
1150
- position_ids = None
1151
-
1152
- return {
1153
- "input_ids": input_ids,
1154
- "past_key_values": past_key_values,
1155
- "use_cache": kwargs.get("use_cache"),
1156
- "position_ids": position_ids,
1157
- "attention_mask": attention_mask,
1158
- "token_type_ids": token_type_ids,
1159
- }
1160
-
1161
  @add_start_docstrings_to_model_forward(GPT2_INPUTS_DOCSTRING)
1162
  @replace_return_docstrings(
1163
  output_type=GPT2DoubleHeadsModelOutput, config_class=_CONFIG_FOR_DOC
@@ -1200,8 +1590,8 @@ class GPT2DoubleHeadsModel(GPT2PreTrainedModel):
1200
  >>> import torch
1201
  >>> from transformers import AutoTokenizer, GPT2DoubleHeadsModel
1202
 
1203
- >>> tokenizer = AutoTokenizer.from_pretrained("gpt2")
1204
- >>> model = GPT2DoubleHeadsModel.from_pretrained("gpt2")
1205
 
1206
  >>> # Add a [CLS] to the vocabulary (we should train it also!)
1207
  >>> num_added_tokens = tokenizer.add_special_tokens({"cls_token": "[CLS]"})
@@ -1239,6 +1629,11 @@ class GPT2DoubleHeadsModel(GPT2PreTrainedModel):
1239
 
1240
  hidden_states = transformer_outputs[0]
1241
 
1242
  lm_logits = self.lm_head(hidden_states)
1243
  mc_logits = self.multiple_choice_head(hidden_states, mc_token_ids).squeeze(-1)
1244
 
@@ -1314,6 +1709,10 @@ class GPT2ForSequenceClassification(GPT2PreTrainedModel):
1314
  self.transformer = GPT2Model(config)
1315
  self.score = nn.Linear(config.n_embd, self.num_labels, bias=False)
1316
 
1317
  # Initialize weights and apply final processing
1318
  self.post_init()
1319
 
@@ -1384,7 +1783,7 @@ class GPT2ForSequenceClassification(GPT2PreTrainedModel):
1384
  sequence_lengths = sequence_lengths.to(logits.device)
1385
  else:
1386
  sequence_lengths = -1
1387
- logger.warning(
1388
  f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be "
1389
  "unexpected if using padding tokens in conjunction with `inputs_embeds.`"
1390
  )
@@ -1457,6 +1856,10 @@ class GPT2ForTokenClassification(GPT2PreTrainedModel):
1457
  self.dropout = nn.Dropout(classifier_dropout)
1458
  self.classifier = nn.Linear(config.hidden_size, config.num_labels)
1459
 
1460
  # Initialize weights and apply final processing
1461
  self.post_init()
1462
 
@@ -1558,6 +1961,10 @@ class GPT2ForQuestionAnswering(GPT2PreTrainedModel):
1558
  self.transformer = GPT2Model(config)
1559
  self.qa_outputs = nn.Linear(config.hidden_size, 2)
1560
 
1561
  # Initialize weights and apply final processing
1562
  self.post_init()
1563
 
modeling_gpt2.py (updated file: same hunks, added lines marked with +)
23
 
24
  import torch
25
  import torch.utils.checkpoint
26
+ from packaging import version
27
  from torch import nn
 
28
  from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
29
 
30
  from transformers.activations import ACT2FN
31
+ from transformers.generation import GenerationMixin
32
+ from transformers.modeling_attn_mask_utils import (
33
+ _prepare_4d_attention_mask_for_sdpa,
34
+ _prepare_4d_causal_attention_mask_for_sdpa,
35
+ )
36
  from transformers.modeling_outputs import (
37
  BaseModelOutputWithPastAndCrossAttentions,
38
  CausalLMOutputWithCrossAttentions,
 
51
  add_code_sample_docstrings,
52
  add_start_docstrings,
53
  add_start_docstrings_to_model_forward,
54
+ get_torch_version,
55
+ is_flash_attn_2_available,
56
+ is_flash_attn_greater_or_equal_2_10,
57
  logging,
58
  replace_return_docstrings,
59
  )
60
+ from transformers.utils.model_parallel_utils import assert_device_map, get_device_map
61
  from .configuration_gpt2 import GPT2Config
62
+
63
+
64
+ if is_flash_attn_2_available():
65
+ from transformers.modeling_flash_attention_utils import _flash_attention_forward
66
+
67
 
68
  logger = logging.get_logger(__name__)
69
 
70
+ _CHECKPOINT_FOR_DOC = "openai-community/gpt2"
71
  _CONFIG_FOR_DOC = "GPT2Config"
72
 
73
 
74
  def load_tf_weights_in_gpt2(model, config, gpt2_checkpoint_path):
75
  """Load tf checkpoints in a pytorch model"""
 
132
  class GPT2Attention(nn.Module):
133
  def __init__(self, config, is_cross_attention=False, layer_idx=None):
134
  super().__init__()
135
+ self.config = config
136
  max_positions = config.max_position_embeddings
137
  self.register_buffer(
138
  "bias",
 
170
 
171
  self.attn_dropout = nn.Dropout(config.attn_pdrop)
172
  self.resid_dropout = nn.Dropout(config.resid_pdrop)
173
+ self.is_causal = True
174
 
175
  self.pruned_heads = set()
176
 
 
215
  query_length, key_length = query.size(-2), key.size(-2)
216
  causal_mask = self.bias[
217
  :, :, key_length - query_length : key_length, :key_length
218
+ ]
219
  mask_value = torch.finfo(attn_weights.dtype).min
220
  # Need to be a tensor, otherwise we get error: `RuntimeError: expected scalar type float but found double`.
221
  # Need to be on the same device, otherwise `RuntimeError: ..., x and y to be on the same device`
 
269
  scale_factor /= float(self.layer_idx + 1)
270
 
271
  # Upcast (turn off autocast) and reorder (Scale K by 1 / root(dk))
272
+ with torch.amp.autocast(query.device.type, enabled=False):
273
  q, k = query.reshape(-1, q_seq_len, dk), key.transpose(-1, -2).reshape(
274
  -1, dk, k_seq_len
275
  )
 
390
  return outputs # a, present, (attentions)
391
 
392
 
393
+ class GPT2FlashAttention2(GPT2Attention):
394
+ """
395
+ GPT2 flash attention module. This module inherits from `GPT2Attention` as the weights of the module stays
396
+ untouched. The only required change would be on the forward pass where it needs to correctly call the public API of
397
+ flash attention and deal with padding tokens in case the input contains any of them.
398
+ """
399
+
400
+ # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__
401
+ def __init__(self, *args, **kwargs):
402
+ super().__init__(*args, **kwargs)
403
+
404
+ # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
405
+ # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
406
+ # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
407
+ self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
408
+
409
+ def forward(
410
+ self,
411
+ hidden_states: Optional[Tuple[torch.FloatTensor]],
412
+ layer_past: Optional[Tuple[torch.Tensor]] = None,
413
+ attention_mask: Optional[torch.FloatTensor] = None,
414
+ head_mask: Optional[torch.FloatTensor] = None,
415
+ encoder_hidden_states: Optional[torch.Tensor] = None,
416
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
417
+ use_cache: Optional[bool] = False,
418
+ output_attentions: Optional[bool] = False,
419
+ ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]], ...]:
420
+ bsz, _, _ = hidden_states.size()
421
+ if encoder_hidden_states is not None:
422
+ if not hasattr(self, "q_attn"):
423
+ raise ValueError(
424
+ "If class is used as cross attention, the weights `q_attn` have to be defined. "
425
+ "Please make sure to instantiate class with `GPT2Attention(..., is_cross_attention=True)`."
426
+ )
427
+
428
+ query = self.q_attn(hidden_states)
429
+ key, value = self.c_attn(encoder_hidden_states).split(
430
+ self.split_size, dim=2
431
+ )
432
+ attention_mask = encoder_attention_mask
433
+ else:
434
+ query, key, value = self.c_attn(hidden_states).split(self.split_size, dim=2)
435
+
436
+ query = self._split_heads(query, self.num_heads, self.head_dim)
437
+ key = self._split_heads(key, self.num_heads, self.head_dim)
438
+ value = self._split_heads(value, self.num_heads, self.head_dim)
439
+
440
+ if layer_past is not None:
441
+ past_key = layer_past[0]
442
+ past_value = layer_past[1]
443
+ key = torch.cat((past_key, key), dim=-2)
444
+ value = torch.cat((past_value, value), dim=-2)
445
+
446
+ present = None
447
+ if use_cache is True:
448
+ present = (key, value)
449
+
450
+ query_length = query.shape[2]
451
+ tgt_len = key.shape[2]
452
+
453
+ # Flash attention requires the input to have the shape
454
+ # batch_size x seq_length x head_dim x hidden_dim
455
+ query = query.transpose(1, 2).view(
456
+ bsz, query_length, self.num_heads, self.head_dim
457
+ )
458
+ key = key.transpose(1, 2).view(bsz, tgt_len, self.num_heads, self.head_dim)
459
+ value = value.transpose(1, 2).view(bsz, tgt_len, self.num_heads, self.head_dim)
460
+
461
+ attn_dropout = self.attn_dropout.p if self.training else 0.0
462
+
463
+ # In PEFT, usually we cast the layer norms in float32 for training stability reasons
464
+ # therefore the input hidden states gets silently casted in float32. Hence, we need
465
+ # cast them back in the correct dtype just to be sure everything works as expected.
466
+ # This might slowdown training & inference so it is recommended to not cast the LayerNorms
467
+ # in fp32. (LlamaRMSNorm handles it correctly)
468
+
469
+ if query.dtype == torch.float32:
470
+ if torch.is_autocast_enabled():
471
+ target_dtype = torch.get_autocast_gpu_dtype()
472
+ # Handle the case where the model is quantized
473
+ elif hasattr(self.config, "_pre_quantization_dtype"):
474
+ target_dtype = self.config._pre_quantization_dtype
475
+ else:
476
+ target_dtype = self.c_proj.weight.dtype
477
+
478
+ logger.warning_once(
479
+ f"The input hidden states seems to be silently casted in float32, this might be related to"
480
+ f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
481
+ f" {target_dtype}."
482
+ )
483
+
484
+ query = query.to(target_dtype)
485
+ key = key.to(target_dtype)
486
+ value = value.to(target_dtype)
487
+
488
+ attn_output = _flash_attention_forward(
489
+ query,
490
+ key,
491
+ value,
492
+ attention_mask,
493
+ query_length,
494
+ dropout=attn_dropout,
495
+ is_causal=self.is_causal,
496
+ use_top_left_mask=self._flash_attn_uses_top_left_mask,
497
+ )
498
+
499
+ attn_weights_reshaped = attn_output.reshape(
500
+ bsz, query_length, self.num_heads * self.head_dim
501
+ )
502
+ attn_output = self.c_proj(attn_weights_reshaped)
503
+ attn_output = self.resid_dropout(attn_output)
504
+
505
+ outputs = (attn_output, present)
506
+ if output_attentions:
507
+ outputs += (attn_weights_reshaped,)
508
+
509
+ return outputs
510
+
511
+
512
+ class GPT2SdpaAttention(GPT2Attention):
513
+ """
514
+ GPT2 attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from
515
+ `GPT2Attention` as the weights of the module stays untouched. The only changes are on the forward pass
516
+ to adapt to the SDPA API.
517
+ """
518
+
519
+ def __init__(self, *args, **kwargs):
520
+ super().__init__(*args, **kwargs)
521
+
522
+ # Idea adapted from transformers.models.bert.modeling_bert.BertSdpaSelfAttention.__init__
523
+ # SDPA with memory-efficient backend is broken in torch==2.1.2 when using non-contiguous inputs and a custom
524
+ # attn_mask, so we need to call `.contiguous()`. This was fixed in torch==2.2.0.
525
+ # Reference: https://github.com/pytorch/pytorch/issues/112577
526
+ self.require_contiguous_qkv = version.parse(
527
+ get_torch_version()
528
+ ) < version.parse("2.2.0")
529
+
530
+ def forward(
531
+ self,
532
+ hidden_states: Optional[Tuple[torch.FloatTensor]],
533
+ layer_past: Optional[Tuple[torch.Tensor]] = None,
534
+ attention_mask: Optional[torch.FloatTensor] = None,
535
+ head_mask: Optional[torch.FloatTensor] = None,
536
+ encoder_hidden_states: Optional[torch.Tensor] = None,
537
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
538
+ use_cache: Optional[bool] = False,
539
+ output_attentions: Optional[bool] = False,
540
+ ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]], ...]:
541
+ if output_attentions or head_mask is not None:
542
+ logger.warning_once(
543
+ "`GPT2SdpaAttention` is used but `torch.nn.functional.scaled_dot_product_attention` does not support "
544
+ "`output_attentions=True` or `head_mask`. Falling back to the manual attention implementation, but "
545
+ "specifying the manual implementation will be required from Transformers version v5.0.0 onwards. "
546
+ 'This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
547
+ )
548
+ return super().forward(
549
+ hidden_states=hidden_states,
550
+ layer_past=layer_past,
551
+ attention_mask=attention_mask,
552
+ head_mask=head_mask,
553
+ encoder_hidden_states=encoder_hidden_states,
554
+ encoder_attention_mask=encoder_attention_mask,
555
+ use_cache=use_cache,
556
+ output_attentions=output_attentions,
557
+ )
558
+
559
+ bsz, q_len, _ = hidden_states.size()
560
+
561
+ # Initial attention projections
562
+ is_cross_attention = encoder_hidden_states is not None
563
+ if is_cross_attention:
564
+ if not hasattr(self, "q_attn"):
565
+ raise ValueError(
566
+ "If class is used as cross attention, the weights `q_attn` have to be defined. "
567
+ "Please make sure to instantiate class with `GPT2SdpaAttention(..., is_cross_attention=True)`."
568
+ )
569
+
570
+ query = self.q_attn(hidden_states)
571
+ key, value = self.c_attn(encoder_hidden_states).split(
572
+ self.split_size, dim=2
573
+ )
574
+ attention_mask = encoder_attention_mask
575
+ else:
576
+ query, key, value = self.c_attn(hidden_states).split(self.split_size, dim=2)
577
+
578
+ query = self._split_heads(query, self.num_heads, self.head_dim)
579
+ key = self._split_heads(key, self.num_heads, self.head_dim)
580
+ value = self._split_heads(value, self.num_heads, self.head_dim)
581
+
582
+ # Optional kv caching
583
+ if layer_past is not None:
584
+ past_key = layer_past[0]
585
+ past_value = layer_past[1]
586
+ key = torch.cat((past_key, key), dim=-2)
587
+ value = torch.cat((past_value, value), dim=-2)
588
+
589
+ present = None
590
+ if use_cache is True:
591
+ present = (key, value)
592
+
593
+ # Avoid torch==2.1.2 specific bug for the memory-efficient backend in SDPA
594
+ if (
595
+ self.require_contiguous_qkv
596
+ and query.device.type == "cuda"
597
+ and attention_mask is not None
598
+ ):
599
+ query = query.contiguous()
600
+ key = key.contiguous()
601
+ value = value.contiguous()
602
+
603
+ # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
604
+ # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
605
+ is_causal = (
606
+ True
607
+ if attention_mask is None and q_len > 1 and not is_cross_attention
608
+ else False
609
+ )
610
+
611
+ attn_output = torch.nn.functional.scaled_dot_product_attention(
612
+ query,
613
+ key,
614
+ value,
615
+ attn_mask=attention_mask,
616
+ dropout_p=self.attn_dropout.p if self.training else 0.0,
617
+ is_causal=is_causal,
618
+ )
619
+
620
+ # Reshape outputs
621
+ attn_output = attn_output.transpose(1, 2).contiguous()
622
+ attn_output = attn_output.view(bsz, q_len, self.embed_dim)
623
+
624
+ # Final projection
625
+ attn_output = self.c_proj(attn_output)
626
+ attn_output = self.resid_dropout(attn_output)
627
+
628
+ return attn_output, present, None
629
+
630
+
631
  class GPT2MLP(nn.Module):
632
  def __init__(self, intermediate_size, config):
633
  super().__init__()
 
647
  return hidden_states
648
 
649
 
650
+ GPT2_ATTENTION_CLASSES = {
651
+ "eager": GPT2Attention,
652
+ "flash_attention_2": GPT2FlashAttention2,
653
+ "sdpa": GPT2SdpaAttention,
654
+ }
655
+
656
+
657
  class GPT2Block(nn.Module):
658
  def __init__(self, config, layer_idx=None):
659
  super().__init__()
660
  hidden_size = config.hidden_size
661
  inner_dim = config.n_inner if config.n_inner is not None else 4 * hidden_size
662
+ attention_class = GPT2_ATTENTION_CLASSES[config._attn_implementation]
663
 
664
  self.ln_1 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
665
+ self.attn = attention_class(config=config, layer_idx=layer_idx)
666
  self.ln_2 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
667
 
668
  if config.add_cross_attention:
669
+ self.crossattention = attention_class(
670
+ config=config, is_cross_attention=True, layer_idx=layer_idx
671
  )
672
  self.ln_cross_attn = nn.LayerNorm(
673
  hidden_size, eps=config.layer_norm_epsilon
 
751
  config_class = GPT2Config
752
  load_tf_weights = load_tf_weights_in_gpt2
753
  base_model_prefix = "transformer"
754
+ is_parallelizable = True
755
  supports_gradient_checkpointing = True
756
  _no_split_modules = ["GPT2Block"]
757
  _skip_keys_device_placement = "past_key_values"
758
+ _supports_flash_attn_2 = True
759
+ _supports_sdpa = True
760
 
761
  def __init__(self, *inputs, **kwargs):
762
  super().__init__(*inputs, **kwargs)
 
920
  return_dict (`bool`, *optional*):
921
  Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
922
  """
923
+ PARALLELIZE_DOCSTRING = r"""
924
+ This is an experimental feature and is a subject to change at a moment's notice.
925
+
926
+ Uses a device map to distribute attention modules of the model across several devices. If no device map is given,
927
+ it will evenly distribute blocks across all devices.
928
+
929
+ Args:
930
+ device_map (`Dict[int, list]`, *optional*):
931
+ A dictionary that maps attention modules to devices. Note that the embedding module and LMHead are always
932
+ automatically mapped to the first device (for esoteric reasons). That means that the first device should
933
+ have fewer attention modules mapped to it than other devices. For reference, the gpt2 models have the
934
+ following number of attention modules:
935
+
936
+ - openai-community/gpt2: 12
937
+ - openai-community/gpt2-medium: 24
938
+ - openai-community/gpt2-large: 36
939
+ - openai-community/gpt2-xl: 48
940
+
941
+ Example:
942
+
943
+ ```python
944
+ # Here is an example of a device map on a machine with 4 GPUs using gpt2-xl, which has a total of 48 attention modules:
945
+ model = GPT2LMHeadModel.from_pretrained("openai-community/gpt2-xl")
946
+ device_map = {
947
+ 0: [0, 1, 2, 3, 4, 5, 6, 7, 8],
948
+ 1: [9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21],
949
+ 2: [22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34],
950
+ 3: [35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47],
951
+ }
952
+ model.parallelize(device_map)
953
+ ```
954
+ """
955
+ DEPARALLELIZE_DOCSTRING = r"""
956
+ Moves the model to cpu from a model parallel state.
957
+
958
+ Example:
959
+
960
+ ```python
961
+ # On a 4 GPU machine with openai-community/gpt2-large:
962
+ model = GPT2LMHeadModel.from_pretrained("openai-community/gpt2-large")
963
+ device_map = {
964
+ 0: [0, 1, 2, 3, 4, 5, 6, 7],
965
+ 1: [8, 9, 10, 11, 12, 13, 14, 15],
966
+ 2: [16, 17, 18, 19, 20, 21, 22, 23],
967
+ 3: [24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35],
968
+ }
969
+ model.parallelize(device_map) # Splits the model across several devices
970
+ model.deparallelize() # Put the model back on cpu and cleans memory by calling torch.cuda.empty_cache()
971
+ ```
972
+ """
973
 
974
 
975
  @add_start_docstrings(
 
977
  GPT2_START_DOCSTRING,
978
  )
979
  class GPT2Model(GPT2PreTrainedModel):
980
+ _supports_param_buffer_assignment = False
981
+
982
  def __init__(self, config):
983
  super().__init__(config)
984
 
 
993
  )
994
  self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)
995
 
996
+ # Model parallel
997
+ self.model_parallel = False
998
+ self.device_map = None
999
  self.gradient_checkpointing = False
1000
+ self._attn_implementation = config._attn_implementation
1001
 
1002
  # Initialize weights and apply final processing
1003
  self.post_init()
1004
 
1005
+ @add_start_docstrings(PARALLELIZE_DOCSTRING)
1006
+ def parallelize(self, device_map=None):
1007
+ # Check validity of device_map
1008
+ warnings.warn(
1009
+ "`GPT2Model.parallelize` is deprecated and will be removed in v5 of Transformers, you should load your"
1010
+ " model with `device_map='balanced'` in the call to `from_pretrained`. You can also provide your own"
1011
+ " `device_map` but it needs to be a dictionary module_name to device, so for instance {'h.0': 0, 'h.1': 1,"
1012
+ " ...}",
1013
+ FutureWarning,
1014
+ )
1015
+ self.device_map = (
1016
+ get_device_map(len(self.h), range(torch.cuda.device_count()))
1017
+ if device_map is None
1018
+ else device_map
1019
+ )
1020
+ assert_device_map(self.device_map, len(self.h))
1021
+ self.model_parallel = True
1022
+ self.first_device = (
1023
+ "cpu"
1024
+ if "cpu" in self.device_map.keys()
1025
+ else "cuda:" + str(min(self.device_map.keys()))
1026
+ )
1027
+ self.last_device = "cuda:" + str(max(self.device_map.keys()))
1028
+ self.wte = self.wte.to(self.first_device)
1029
+ self.wpe = self.wpe.to(self.first_device)
1030
+ # Load onto devices
1031
+ for k, v in self.device_map.items():
1032
+ for block in v:
1033
+ cuda_device = "cuda:" + str(k)
1034
+ self.h[block] = self.h[block].to(cuda_device)
1035
+ # ln_f to last
1036
+ self.ln_f = self.ln_f.to(self.last_device)
1037
+
1038
+ @add_start_docstrings(DEPARALLELIZE_DOCSTRING)
1039
+ def deparallelize(self):
1040
+ warnings.warn(
1041
+ "Like `parallelize`, `deparallelize` is deprecated and will be removed in v5 of Transformers.",
1042
+ FutureWarning,
1043
+ )
1044
+ self.model_parallel = False
1045
+ self.device_map = None
1046
+ self.first_device = "cpu"
1047
+ self.last_device = "cpu"
1048
+ self.wte = self.wte.to("cpu")
1049
+ self.wpe = self.wpe.to("cpu")
1050
+ for index in range(len(self.h)):
1051
+ self.h[index] = self.h[index].to("cpu")
1052
+ self.ln_f = self.ln_f.to("cpu")
1053
+ torch.cuda.empty_cache()
1054
+
1055
  def get_input_embeddings(self):
1056
  return self.wte
1057
 
 
1136
  )
1137
  position_ids = position_ids.unsqueeze(0)
1138
 
1139
+ if inputs_embeds is None:
1140
+ inputs_embeds = self.wte(input_ids)
1141
+ position_embeds = self.wpe(position_ids)
1142
+ hidden_states = inputs_embeds + position_embeds
1143
+
1144
+ # Attention mask.
1145
+ _use_sdpa = (
1146
+ self._attn_implementation == "sdpa"
1147
+ and output_attentions is False
1148
+ and head_mask is None
1149
+ )
1150
+ attention_mask = (
1151
+ attention_mask.view(batch_size, -1) if attention_mask is not None else None
1152
+ )
1153
+ if self._attn_implementation == "flash_attention_2":
1154
+ attention_mask = (
1155
+ attention_mask
1156
+ if (attention_mask is not None and 0 in attention_mask)
1157
+ else None
1158
+ )
1159
+ elif _use_sdpa:
1160
+ attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
1161
+ attention_mask=attention_mask,
1162
+ input_shape=(batch_size, input_shape[-1]),
1163
+ inputs_embeds=inputs_embeds,
1164
+ past_key_values_length=past_length,
1165
+ )
1166
+ else:
1167
+ if attention_mask is not None:
1168
+ # We create a 3D attention mask from a 2D tensor mask.
1169
+ # Sizes are [batch_size, 1, 1, to_seq_length]
1170
+ # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
1171
+ # this attention mask is more simple than the triangular masking of causal attention
1172
+ # used in OpenAI GPT, we just need to prepare the broadcast dimension here.
1173
+ attention_mask = attention_mask[:, None, None, :]
1174
+
1175
+ # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
1176
+ # masked positions, this operation will create a tensor which is 0.0 for
1177
+ # positions we want to attend and the dtype's smallest value for masked positions.
1178
+ # Since we are adding it to the raw scores before the softmax, this is
1179
+ # effectively the same as removing these entirely.
1180
+ attention_mask = attention_mask.to(
1181
+ dtype=self.dtype
1182
+ ) # fp16 compatibility
1183
+ attention_mask = (1.0 - attention_mask) * torch.finfo(self.dtype).min
1184
 
1185
  # If a 2D or 3D attention mask is provided for the cross-attention
1186
  # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
1187
  if self.config.add_cross_attention and encoder_hidden_states is not None:
1188
+ encoder_batch_size, encoder_sequence_length, _ = (
1189
+ encoder_hidden_states.size()
1190
+ )
 
 
1191
  encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
1192
  if encoder_attention_mask is None:
1193
  encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)
1194
+ if _use_sdpa:
1195
+ encoder_attention_mask = _prepare_4d_attention_mask_for_sdpa(
1196
+ mask=encoder_attention_mask,
1197
+ dtype=inputs_embeds.dtype,
1198
+ tgt_len=input_shape[-1],
1199
+ )
1200
+ elif not self._attn_implementation == "flash_attention_2":
1201
+ encoder_attention_mask = self.invert_attention_mask(
1202
+ encoder_attention_mask
1203
+ )
1204
  else:
1205
  encoder_attention_mask = None
1206
 
 
1210
  # head_mask has shape n_layer x batch x n_heads x N x N
1211
  head_mask = self.get_head_mask(head_mask, self.config.n_layer)
1212
 
1213
  if token_type_ids is not None:
1214
  token_type_embeds = self.wte(token_type_ids)
1215
  hidden_states = hidden_states + token_type_embeds
 
1231
  () if output_attentions and self.config.add_cross_attention else None
1232
  )
1233
  all_hidden_states = () if output_hidden_states else None
1234
+ for i in range(len(self.h)):
1235
+ block, layer_past = self.h[i], past_key_values[i]
1236
+ # Model parallel
1237
+ if self.model_parallel:
1238
+ torch.cuda.set_device(hidden_states.device)
1239
+ # Ensure layer_past is on same device as hidden_states (might not be correct)
1240
+ if layer_past is not None:
1241
+ layer_past = tuple(
1242
+ past_state.to(hidden_states.device) for past_state in layer_past
1243
+ )
1244
+ # Ensure that attention_mask is always on the same device as hidden_states
1245
+ if attention_mask is not None:
1246
+ attention_mask = attention_mask.to(hidden_states.device)
1247
+ if isinstance(head_mask, torch.Tensor):
1248
+ head_mask = head_mask.to(hidden_states.device)
1249
  if output_hidden_states:
1250
  all_hidden_states = all_hidden_states + (hidden_states,)
1251
 
 
1286
  outputs[3 if use_cache else 2],
1287
  )
1288
 
1289
+ # Model Parallel: If it's the last layer for that device, put things on the next device
1290
+ if self.model_parallel:
1291
+ for k, v in self.device_map.items():
1292
+ if i == v[-1] and "cuda:" + str(k) != self.last_device:
1293
+ hidden_states = hidden_states.to("cuda:" + str(k + 1))
1294
+
1295
  hidden_states = self.ln_f(hidden_states)
1296
 
1297
  hidden_states = hidden_states.view(output_shape)
 
1328
  """,
1329
  GPT2_START_DOCSTRING,
1330
  )
1331
+ class GPT2LMHeadModel(GPT2PreTrainedModel, GenerationMixin):
1332
  _tied_weights_keys = ["lm_head.weight"]
1333
 
1334
  def __init__(self, config):
 
1336
  self.transformer = GPT2Model(config)
1337
  self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
1338
 
1339
+ # Model parallel
1340
+ self.model_parallel = False
1341
+ self.device_map = None
1342
+
1343
  # Initialize weights and apply final processing
1344
  self.post_init()
1345
 
1346
+ @add_start_docstrings(PARALLELIZE_DOCSTRING)
1347
+ def parallelize(self, device_map=None):
1348
+ warnings.warn(
1349
+ "`GPT2LMHeadModel.parallelize` is deprecated and will be removed in v5 of Transformers, you should load"
1350
+ " your model with `device_map='balanced'` in the call to `from_pretrained`. You can also provide your own"
1351
+ " `device_map` but it needs to be a dictionary module_name to device, so for instance {'transformer.h.0':"
1352
+ " 0, 'transformer.h.1': 1, ...}",
1353
+ FutureWarning,
1354
+ )
1355
+ self.device_map = (
1356
+ get_device_map(len(self.transformer.h), range(torch.cuda.device_count()))
1357
+ if device_map is None
1358
+ else device_map
1359
+ )
1360
+ assert_device_map(self.device_map, len(self.transformer.h))
1361
+ self.transformer.parallelize(self.device_map)
1362
+ self.lm_head = self.lm_head.to(self.transformer.first_device)
1363
+ self.model_parallel = True
1364
+
1365
+ @add_start_docstrings(DEPARALLELIZE_DOCSTRING)
1366
+ def deparallelize(self):
1367
+ warnings.warn(
1368
+ "Like `parallelize`, `deparallelize` is deprecated and will be removed in v5 of Transformers.",
1369
+ FutureWarning,
1370
+ )
1371
+ self.transformer.deparallelize()
1372
+ self.transformer = self.transformer.to("cpu")
1373
+ self.lm_head = self.lm_head.to("cpu")
1374
+ self.model_parallel = False
1375
+ torch.cuda.empty_cache()
1376
+
1377
  def get_output_embeddings(self):
1378
  return self.lm_head
1379
 
1380
  def set_output_embeddings(self, new_embeddings):
1381
  self.lm_head = new_embeddings
1382
 
1383
  @add_start_docstrings_to_model_forward(GPT2_INPUTS_DOCSTRING)
1384
  @add_code_sample_docstrings(
1385
  checkpoint=_CHECKPOINT_FOR_DOC,
 
1430
  )
1431
  hidden_states = transformer_outputs[0]
1432
 
1433
+ # Set device for model parallelism
1434
+ if self.model_parallel:
1435
+ torch.cuda.set_device(self.transformer.first_device)
1436
+ hidden_states = hidden_states.to(self.lm_head.weight.device)
1437
+
1438
  lm_logits = self.lm_head(hidden_states)
1439
 
1440
  loss = None
 
1490
  """,
1491
  GPT2_START_DOCSTRING,
1492
  )
1493
+ class GPT2DoubleHeadsModel(GPT2PreTrainedModel, GenerationMixin):
1494
  _tied_weights_keys = ["lm_head.weight"]
1495
 
1496
  def __init__(self, config):
 
1500
  self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
1501
  self.multiple_choice_head = SequenceSummary(config)
1502
 
1503
+ # Model parallel
1504
+ self.model_parallel = False
1505
+ self.device_map = None
1506
+
1507
  # Initialize weights and apply final processing
1508
  self.post_init()
1509
 
1510
+ @add_start_docstrings(PARALLELIZE_DOCSTRING)
1511
+ def parallelize(self, device_map=None):
1512
+ warnings.warn(
1513
+ "`GPT2DoubleHeadsModel.parallelize` is deprecated and will be removed in v5 of Transformers, you should"
1514
+ " load your model with `device_map='balanced'` in the call to `from_pretrained`. You can also provide your"
1515
+ " own `device_map` but it needs to be a dictionary module_name to device, so for instance"
1516
+ " {'transformer.h.0': 0, 'transformer.h.1': 1, ...}",
1517
+ FutureWarning,
1518
+ )
1519
+ self.device_map = (
1520
+ get_device_map(len(self.transformer.h), range(torch.cuda.device_count()))
1521
+ if device_map is None
1522
+ else device_map
1523
+ )
1524
+ assert_device_map(self.device_map, len(self.transformer.h))
1525
+ self.transformer.parallelize(self.device_map)
1526
+ self.lm_head = self.lm_head.to(self.transformer.first_device)
1527
+ self.multiple_choice_head = self.multiple_choice_head.to(
1528
+ self.transformer.first_device
1529
+ )
1530
+ self.model_parallel = True
1531
+
1532
+ @add_start_docstrings(DEPARALLELIZE_DOCSTRING)
1533
+ def deparallelize(self):
1534
+ warnings.warn(
1535
+ "Like `parallelize`, `deparallelize` is deprecated and will be removed in v5 of Transformers.",
1536
+ FutureWarning,
1537
+ )
1538
+ self.transformer.deparallelize()
1539
+ self.transformer = self.transformer.to("cpu")
1540
+ self.lm_head = self.lm_head.to("cpu")
1541
+ self.multiple_choice_head = self.multiple_choice_head.to("cpu")
1542
+ self.model_parallel = False
1543
+ torch.cuda.empty_cache()
1544
+
1545
  def get_output_embeddings(self):
1546
  return self.lm_head
1547
 
1548
  def set_output_embeddings(self, new_embeddings):
1549
  self.lm_head = new_embeddings
1550
 
1551
  @add_start_docstrings_to_model_forward(GPT2_INPUTS_DOCSTRING)
1552
  @replace_return_docstrings(
1553
  output_type=GPT2DoubleHeadsModelOutput, config_class=_CONFIG_FOR_DOC
 
1590
  >>> import torch
1591
  >>> from transformers import AutoTokenizer, GPT2DoubleHeadsModel
1592
 
1593
+ >>> tokenizer = AutoTokenizer.from_pretrained("openai-community/gpt2")
1594
+ >>> model = GPT2DoubleHeadsModel.from_pretrained("openai-community/gpt2")
1595
 
1596
  >>> # Add a [CLS] to the vocabulary (we should train it also!)
1597
  >>> num_added_tokens = tokenizer.add_special_tokens({"cls_token": "[CLS]"})
 
1629
 
1630
  hidden_states = transformer_outputs[0]
1631
 
1632
+ # Set device for model parallelism
1633
+ if self.model_parallel:
1634
+ torch.cuda.set_device(self.transformer.first_device)
1635
+ hidden_states = hidden_states.to(self.lm_head.weight.device)
1636
+
1637
  lm_logits = self.lm_head(hidden_states)
1638
  mc_logits = self.multiple_choice_head(hidden_states, mc_token_ids).squeeze(-1)
1639
 
 
1709
  self.transformer = GPT2Model(config)
1710
  self.score = nn.Linear(config.n_embd, self.num_labels, bias=False)
1711
 
1712
+ # Model parallel
1713
+ self.model_parallel = False
1714
+ self.device_map = None
1715
+
1716
  # Initialize weights and apply final processing
1717
  self.post_init()
1718
 
 
1783
  sequence_lengths = sequence_lengths.to(logits.device)
1784
  else:
1785
  sequence_lengths = -1
1786
+ logger.warning_once(
1787
  f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be "
1788
  "unexpected if using padding tokens in conjunction with `inputs_embeds.`"
1789
  )
 
1856
  self.dropout = nn.Dropout(classifier_dropout)
1857
  self.classifier = nn.Linear(config.hidden_size, config.num_labels)
1858
 
1859
+ # Model parallel
1860
+ self.model_parallel = False
1861
+ self.device_map = None
1862
+
1863
  # Initialize weights and apply final processing
1864
  self.post_init()
1865
 
 
1961
  self.transformer = GPT2Model(config)
1962
  self.qa_outputs = nn.Linear(config.hidden_size, 2)
1963
 
1964
+ # Model parallel
1965
+ self.model_parallel = False
1966
+ self.device_map = None
1967
+
1968
  # Initialize weights and apply final processing
1969
  self.post_init()
1970
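
One behavioral detail worth calling out from the attention upcast hunk above: `torch.cuda.amp.autocast(enabled=False)` only exists for CUDA and is deprecated in recent PyTorch releases, whereas `torch.amp.autocast(query.device.type, enabled=False)` takes the device type explicitly and also works on CPU. A minimal sketch of the new call, with an illustrative tensor and matmul that are not taken from the model code (assumes a recent PyTorch 2.x release):

```python
# Minimal sketch of the device-agnostic autocast context used in the diff above.
# The tensor and the bmm below are placeholders, not the model's attention math.
import torch

x = torch.randn(2, 8, 16)  # (batch, seq, dim), float32
with torch.amp.autocast(x.device.type, enabled=False):
    # Autocast is explicitly disabled, so the matmul stays in x's own dtype
    # even if an outer mixed-precision autocast region is active.
    scores = torch.bmm(x, x.transpose(-1, -2))

print(scores.dtype)  # torch.float32
```
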