crumb committed on
Commit 69affb4 · 1 Parent(s): e9b4d3b

Update modeling_gpt2a.py

Files changed (1)
  1. modeling_gpt2a.py +60 -59
modeling_gpt2a.py CHANGED
@@ -952,65 +952,66 @@ class GPT2AModel(GPT2APreTrainedModel):
         all_self_attentions = () if output_attentions else None
         all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None
         all_hidden_states = () if output_hidden_states else None
-        for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)):
-            # Model parallel
-            if self.model_parallel:
-                torch.cuda.set_device(hidden_states.device)
-                # Ensure layer_past is on same device as hidden_states (might not be correct)
-                if layer_past is not None:
-                    layer_past = tuple(past_state.to(hidden_states.device) for past_state in layer_past)
-                # Ensure that attention_mask is always on the same device as hidden_states
-                if attention_mask is not None:
-                    attention_mask = attention_mask.to(hidden_states.device)
-                if isinstance(head_mask, torch.Tensor):
-                    head_mask = head_mask.to(hidden_states.device)
-            if output_hidden_states:
-                all_hidden_states = all_hidden_states + (hidden_states,)
-
-            if self.gradient_checkpointing and self.training:
-
-                def create_custom_forward(module):
-                    def custom_forward(*inputs):
-                        # None for past_key_value
-                        return module(*inputs, use_cache, output_attentions)
-
-                    return custom_forward
-
-                outputs = torch.utils.checkpoint.checkpoint(
-                    create_custom_forward(block),
-                    hidden_states,
-                    None,
-                    attention_mask,
-                    head_mask[i],
-                    encoder_hidden_states,
-                    encoder_attention_mask,
-                )
-            else:
-                outputs = block(
-                    hidden_states,
-                    layer_past=layer_past,
-                    attention_mask=attention_mask,
-                    head_mask=head_mask[i],
-                    encoder_hidden_states=encoder_hidden_states,
-                    encoder_attention_mask=encoder_attention_mask,
-                    use_cache=use_cache,
-                    output_attentions=output_attentions,
-                )
-
-            hidden_states = outputs[0]
-            if use_cache is True:
-                presents = presents + (outputs[1],)
-
-            if output_attentions:
-                all_self_attentions = all_self_attentions + (outputs[2 if use_cache else 1],)
-                if self.config.add_cross_attention:
-                    all_cross_attentions = all_cross_attentions + (outputs[3 if use_cache else 2],)
-
-            # Model Parallel: If it's the last layer for that device, put things on the next device
-            if self.model_parallel:
-                for k, v in self.device_map.items():
-                    if i == v[-1] and "cuda:" + str(k) != self.last_device:
-                        hidden_states = hidden_states.to("cuda:" + str(k + 1))
+        for _full_iteration in range(self.config.full_layer_repetitions):
+            for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)):
+                # Model parallel
+                if self.model_parallel:
+                    torch.cuda.set_device(hidden_states.device)
+                    # Ensure layer_past is on same device as hidden_states (might not be correct)
+                    if layer_past is not None:
+                        layer_past = tuple(past_state.to(hidden_states.device) for past_state in layer_past)
+                    # Ensure that attention_mask is always on the same device as hidden_states
+                    if attention_mask is not None:
+                        attention_mask = attention_mask.to(hidden_states.device)
+                    if isinstance(head_mask, torch.Tensor):
+                        head_mask = head_mask.to(hidden_states.device)
+                if output_hidden_states:
+                    all_hidden_states = all_hidden_states + (hidden_states,)
+
+                if self.gradient_checkpointing and self.training:
+
+                    def create_custom_forward(module):
+                        def custom_forward(*inputs):
+                            # None for past_key_value
+                            return module(*inputs, use_cache, output_attentions)
+
+                        return custom_forward
+
+                    outputs = torch.utils.checkpoint.checkpoint(
+                        create_custom_forward(block),
+                        hidden_states,
+                        None,
+                        attention_mask,
+                        head_mask[i],
+                        encoder_hidden_states,
+                        encoder_attention_mask,
+                    )
+                else:
+                    outputs = block(
+                        hidden_states,
+                        layer_past=layer_past,
+                        attention_mask=attention_mask,
+                        head_mask=head_mask[i],
+                        encoder_hidden_states=encoder_hidden_states,
+                        encoder_attention_mask=encoder_attention_mask,
+                        use_cache=use_cache,
+                        output_attentions=output_attentions,
+                    )
+
+                hidden_states = outputs[0]
+                if use_cache is True:
+                    presents = presents + (outputs[1],)
+
+                if output_attentions:
+                    all_self_attentions = all_self_attentions + (outputs[2 if use_cache else 1],)
+                    if self.config.add_cross_attention:
+                        all_cross_attentions = all_cross_attentions + (outputs[3 if use_cache else 2],)
+
+                # Model Parallel: If it's the last layer for that device, put things on the next device
+                if self.model_parallel:
+                    for k, v in self.device_map.items():
+                        if i == v[-1] and "cuda:" + str(k) != self.last_device:
+                            hidden_states = hidden_states.to("cuda:" + str(k + 1))
 
         hidden_states = self.ln_f(hidden_states)
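
The net effect of the hunk is that the block loop is wrapped in an outer `for _full_iteration in range(self.config.full_layer_repetitions):` loop, so the same weight-tied stack of blocks is applied `full_layer_repetitions` times per forward pass. Since `presents` and the attention tuples are still appended inside the inner loop, they collect one entry per block application, i.e. `full_layer_repetitions * n_layer` entries when `use_cache` or `output_attentions` is enabled. Below is a minimal, illustrative sketch of that control flow only; apart from `full_layer_repetitions`, the module, names, and dimensions are hypothetical and not the repository's GPT2A code.

# Illustrative sketch (not from the repository): a toy module showing the
# control flow introduced by the diff -- the same block stack is reused
# `full_layer_repetitions` times per forward pass.
import torch
import torch.nn as nn

class TinyRepeatedStack(nn.Module):
    def __init__(self, n_layer=2, d_model=8, full_layer_repetitions=3):
        super().__init__()
        self.h = nn.ModuleList(nn.Linear(d_model, d_model) for _ in range(n_layer))
        self.full_layer_repetitions = full_layer_repetitions
        self.ln_f = nn.LayerNorm(d_model)

    def forward(self, hidden_states):
        for _full_iteration in range(self.full_layer_repetitions):
            for block in self.h:
                # Weight-tied reuse: each repetition runs the same blocks again.
                hidden_states = block(hidden_states)
        return self.ln_f(hidden_states)

x = torch.randn(1, 4, 8)    # (batch, sequence, hidden)
y = TinyRepeatedStack()(x)  # 3 repetitions x 2 blocks = 6 block applications
print(y.shape)              # torch.Size([1, 4, 8])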