Update modeling_gpt2a.py
modeling_gpt2a.py  CHANGED  (+60, -59)
@@ -952,65 +952,66 @@ class GPT2AModel(GPT2APreTrainedModel):
         all_self_attentions = () if output_attentions else None
         all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None
         all_hidden_states = () if output_hidden_states else None
-        for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)):
-            # Model parallel
-            if self.model_parallel:
-                torch.cuda.set_device(hidden_states.device)
-                # Ensure layer_past is on same device as hidden_states (might not be correct)
-                if layer_past is not None:
-                    layer_past = tuple(past_state.to(hidden_states.device) for past_state in layer_past)
-                # Ensure that attention_mask is always on the same device as hidden_states
-                if attention_mask is not None:
-                    attention_mask = attention_mask.to(hidden_states.device)
-                if isinstance(head_mask, torch.Tensor):
-                    head_mask = head_mask.to(hidden_states.device)
-            if output_hidden_states:
-                all_hidden_states = all_hidden_states + (hidden_states,)
-
-            if self.gradient_checkpointing and self.training:
-
-                def create_custom_forward(module):
-                    def custom_forward(*inputs):
-                        # None for past_key_value
-                        return module(*inputs, use_cache, output_attentions)
-
-                    return custom_forward
-
-                outputs = torch.utils.checkpoint.checkpoint(
-                    create_custom_forward(block),
-                    hidden_states,
-                    None,
-                    attention_mask,
-                    head_mask[i],
-                    encoder_hidden_states,
-                    encoder_attention_mask,
-                )
-            else:
-                outputs = block(
-                    hidden_states,
-                    layer_past=layer_past,
-                    attention_mask=attention_mask,
-                    head_mask=head_mask[i],
-                    encoder_hidden_states=encoder_hidden_states,
-                    encoder_attention_mask=encoder_attention_mask,
-                    use_cache=use_cache,
-                    output_attentions=output_attentions,
-                )
-
-            hidden_states = outputs[0]
-            if use_cache is True:
-                presents = presents + (outputs[1],)
-
-            if output_attentions:
-                all_self_attentions = all_self_attentions + (outputs[2 if use_cache else 1],)
-                if self.config.add_cross_attention:
-                    all_cross_attentions = all_cross_attentions + (outputs[3 if use_cache else 2],)
-
-            # Model Parallel: If it's the last layer for that device, put things on the next device
-            if self.model_parallel:
-                for k, v in self.device_map.items():
-                    if i == v[-1] and "cuda:" + str(k) != self.last_device:
-                        hidden_states = hidden_states.to("cuda:" + str(k + 1))
+        for _full_iteration in range(self.config.full_layer_repetitions):
+            for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)):
+                # Model parallel
+                if self.model_parallel:
+                    torch.cuda.set_device(hidden_states.device)
+                    # Ensure layer_past is on same device as hidden_states (might not be correct)
+                    if layer_past is not None:
+                        layer_past = tuple(past_state.to(hidden_states.device) for past_state in layer_past)
+                    # Ensure that attention_mask is always on the same device as hidden_states
+                    if attention_mask is not None:
+                        attention_mask = attention_mask.to(hidden_states.device)
+                    if isinstance(head_mask, torch.Tensor):
+                        head_mask = head_mask.to(hidden_states.device)
+                if output_hidden_states:
+                    all_hidden_states = all_hidden_states + (hidden_states,)
+
+                if self.gradient_checkpointing and self.training:
+
+                    def create_custom_forward(module):
+                        def custom_forward(*inputs):
+                            # None for past_key_value
+                            return module(*inputs, use_cache, output_attentions)
+
+                        return custom_forward
+
+                    outputs = torch.utils.checkpoint.checkpoint(
+                        create_custom_forward(block),
+                        hidden_states,
+                        None,
+                        attention_mask,
+                        head_mask[i],
+                        encoder_hidden_states,
+                        encoder_attention_mask,
+                    )
+                else:
+                    outputs = block(
+                        hidden_states,
+                        layer_past=layer_past,
+                        attention_mask=attention_mask,
+                        head_mask=head_mask[i],
+                        encoder_hidden_states=encoder_hidden_states,
+                        encoder_attention_mask=encoder_attention_mask,
+                        use_cache=use_cache,
+                        output_attentions=output_attentions,
+                    )
+
+                hidden_states = outputs[0]
+                if use_cache is True:
+                    presents = presents + (outputs[1],)
+
+                if output_attentions:
+                    all_self_attentions = all_self_attentions + (outputs[2 if use_cache else 1],)
+                    if self.config.add_cross_attention:
+                        all_cross_attentions = all_cross_attentions + (outputs[3 if use_cache else 2],)
+
+                # Model Parallel: If it's the last layer for that device, put things on the next device
+                if self.model_parallel:
+                    for k, v in self.device_map.items():
+                        if i == v[-1] and "cuda:" + str(k) != self.last_device:
+                            hidden_states = hidden_states.to("cuda:" + str(k + 1))

         hidden_states = self.ln_f(hidden_states)

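The only functional change in this hunk is the new outer loop: the existing per-block loop body is re-indented unchanged, and the full stack of blocks in self.h is now run config.full_layer_repetitions times (reusing the same weights on every repetition) before the final self.ln_f. A minimal sketch of that control flow, assuming toy stand-ins: ToyConfig, ToyStack, and the plain nn.Linear "blocks" below are illustrative and not part of modeling_gpt2a.py; only the full_layer_repetitions field comes from the hunk.

import torch
import torch.nn as nn


class ToyConfig:
    """Illustrative config; only `full_layer_repetitions` mirrors the new field in the hunk."""

    def __init__(self, n_layer=2, n_embd=16, full_layer_repetitions=3):
        self.n_layer = n_layer
        self.n_embd = n_embd
        self.full_layer_repetitions = full_layer_repetitions


class ToyStack(nn.Module):
    """Same control flow as the patched loop: replay one block list several times."""

    def __init__(self, config):
        super().__init__()
        self.config = config
        # Stand-in for self.h (the list of transformer blocks in the real model).
        self.h = nn.ModuleList(
            nn.Linear(config.n_embd, config.n_embd) for _ in range(config.n_layer)
        )
        self.ln_f = nn.LayerNorm(config.n_embd)

    def forward(self, hidden_states):
        # Outer loop added by the commit: run the whole stack repeatedly.
        for _full_iteration in range(self.config.full_layer_repetitions):
            # Pre-existing inner loop over the blocks (weights shared across repetitions).
            for block in self.h:
                hidden_states = torch.tanh(block(hidden_states))
        return self.ln_f(hidden_states)


if __name__ == "__main__":
    x = torch.randn(1, 4, 16)
    print(ToyStack(ToyConfig())(x).shape)  # torch.Size([1, 4, 16])

Note that the cache and attention bookkeeping (presents, all_self_attentions, all_hidden_states) also sits inside both loops, so with use_cache or output_attentions enabled those tuples collect full_layer_repetitions * len(self.h) entries per forward pass. Assuming the companion config class follows the usual GPT2Config pattern, the new field would be supplied at construction time, e.g. GPT2AConfig(full_layer_repetitions=2); the config class itself is not shown in this diff.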