rhysjones committed on
Commit
3f7726c
1 Parent(s): f2bac03

Upload folder using huggingface_hub

Browse files
Files changed (2) hide show
  1. config.json +2 -1
  2. modeling_gpt2.py +4 -4
config.json CHANGED
@@ -32,5 +32,6 @@
32
  "torch_dtype": "bfloat16",
33
  "transformers_version": "4.41.2",
34
  "use_cache": true,
35
- "vocab_size": 50257
 
36
  }
 
32
  "torch_dtype": "bfloat16",
33
  "transformers_version": "4.41.2",
34
  "use_cache": true,
35
+ "vocab_size": 50257,
36
+ "#": {"_attn_implementation": "flash_attention_2"}
37
  }
modeling_gpt2.py CHANGED
@@ -171,8 +171,8 @@ class GPT2Attention(nn.Module):
171
  self.c_attn = Conv1D(3 * self.embed_dim, self.embed_dim)
172
  self.c_proj = Conv1D(self.embed_dim, self.embed_dim)
173
 
174
- # rhys101 do attention in float32 if model in bfloat16 ?
175
- if self.config.torch_dtype == torch.bfloat16:
176
  self.c_attn = self.c_attn.to(torch.float32)
177
  self.c_proj = self.c_proj.to(torch.float32)
178
 
@@ -315,8 +315,8 @@ class GPT2Attention(nn.Module):
315
  output_attentions: Optional[bool] = False,
316
  ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]], ...]:
317
 
318
- # rhys101 do attention in float32 if model in bfloat16 ?
319
- if self.config.torch_dtype == torch.bfloat16:
320
  hidden_states = hidden_states.to(torch.float32)
321
 
322
  if encoder_hidden_states is not None:
 
171
  self.c_attn = Conv1D(3 * self.embed_dim, self.embed_dim)
172
  self.c_proj = Conv1D(self.embed_dim, self.embed_dim)
173
 
174
+ # rhys101 do non flash attention in float32 if model in bfloat16 ?
175
+ if self.config._attn_implementation == 'eager' and self.config.torch_dtype == torch.bfloat16:
176
  self.c_attn = self.c_attn.to(torch.float32)
177
  self.c_proj = self.c_proj.to(torch.float32)
178
 
 
315
  output_attentions: Optional[bool] = False,
316
  ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]], ...]:
317
 
318
+ # rhys101 do non flash attention in float32 if model in bfloat16 ?
319
+ if self.config._attn_implementation == 'eager' and self.config.torch_dtype == torch.bfloat16:
320
  hidden_states = hidden_states.to(torch.float32)
321
 
322
  if encoder_hidden_states is not None: