Upload folder using huggingface_hub

- config.json +2 -1
- modeling_gpt2.py +4 -4
config.json CHANGED

@@ -32,5 +32,6 @@
   "torch_dtype": "bfloat16",
   "transformers_version": "4.41.2",
   "use_cache": true,
-  "vocab_size": 50257
+  "vocab_size": 50257,
+  "#": {"_attn_implementation": "flash_attention_2"}
 }
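Note on the config change: JSON has no comment syntax, so the new "#" entry is an inert key that transformers ignores when parsing the config; it only records that the custom attention code now supports flash_attention_2. The implementation is actually selected at load time. A minimal sketch, assuming a bfloat16 checkpoint that ships the custom modeling_gpt2.py; the repo id and the trust_remote_code requirement are assumptions, not taken from this commit:

import torch
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(
    "user/gpt2-bf16",                         # hypothetical repo id
    torch_dtype=torch.bfloat16,               # matches "torch_dtype" in config.json
    attn_implementation="flash_attention_2",  # what the "#" note documents
    trust_remote_code=True,                   # assumption: needed for the bundled modeling_gpt2.py
)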
modeling_gpt2.py CHANGED

@@ -171,8 +171,8 @@ class GPT2Attention(nn.Module):
         self.c_attn = Conv1D(3 * self.embed_dim, self.embed_dim)
         self.c_proj = Conv1D(self.embed_dim, self.embed_dim)
 
-        # rhys101 do attention in float32 if model in bfloat16 ?
-        if self.config.torch_dtype == torch.bfloat16:
+        # rhys101: do non-flash (eager) attention in float32 if the model is in bfloat16
+        if self.config._attn_implementation == 'eager' and self.config.torch_dtype == torch.bfloat16:
             self.c_attn = self.c_attn.to(torch.float32)
             self.c_proj = self.c_proj.to(torch.float32)
 
@@ -315,8 +315,8 @@ class GPT2Attention(nn.Module):
         output_attentions: Optional[bool] = False,
     ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]], ...]:
 
-        # rhys101 do attention in float32 if model in bfloat16 ?
-        if self.config.torch_dtype == torch.bfloat16:
+        # rhys101: do non-flash (eager) attention in float32 if the model is in bfloat16
+        if self.config._attn_implementation == 'eager' and self.config.torch_dtype == torch.bfloat16:
             hidden_states = hidden_states.to(torch.float32)
 
         if encoder_hidden_states is not None:
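The point of the new guard: FlashAttention-2 kernels accept only fp16/bf16 inputs, so the previous unconditional upcast to float32 would break the flash path. The float32 trick is now applied only on the eager (matmul + softmax) path, where it improves numerical stability for bfloat16 weights. A minimal sketch of the dispatch with a stub config; maybe_upcast is an illustrative helper, not a function in this diff:

import torch
from types import SimpleNamespace

def maybe_upcast(hidden_states: torch.Tensor, config) -> torch.Tensor:
    # The guard this commit adds: only the eager path upcasts to float32;
    # flash_attention_2 kernels require fp16/bf16 inputs, so they stay in bf16.
    if config._attn_implementation == "eager" and config.torch_dtype == torch.bfloat16:
        return hidden_states.to(torch.float32)
    return hidden_states

h = torch.randn(1, 4, 8, dtype=torch.bfloat16)
eager = SimpleNamespace(_attn_implementation="eager", torch_dtype=torch.bfloat16)
flash = SimpleNamespace(_attn_implementation="flash_attention_2", torch_dtype=torch.bfloat16)
print(maybe_upcast(h, eager).dtype)  # torch.float32
print(maybe_upcast(h, flash).dtype)  # torch.bfloat16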