GradientGuru committed
Commit 92dd329 • 1 parent: 85278de
cache alibi_mask to accelerate training
Files changed: modeling_baichuan.py (+8 -2)
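The change makes `BaichuanModel.forward` reuse the ALiBi attention bias across training steps instead of rebuilding it on every call: the mask depends only on the sequence length, not on the batch contents, and during training the sequence length is usually fixed by the data pipeline, so one tensor can serve every step. Below is a minimal sketch of the pattern in isolation; the actual change follows in the diff. `AlibiMaskCache` and `build_alibi_mask` are hypothetical names standing in for `BaichuanModel` and its `get_alibi_mask`, and the mask construction here is a simplified placeholder (a causal additive bias without the per-head ALiBi slopes):

```python
import torch
from torch import nn


class AlibiMaskCache(nn.Module):
    """Minimal sketch of the caching pattern this commit introduces."""

    def __init__(self, num_heads: int):
        super().__init__()
        self.num_heads = num_heads
        self.alibi_mask = None  # built lazily, reused while the length matches

    def build_alibi_mask(self, seq_len: int, dtype, device) -> torch.Tensor:
        # Placeholder for get_alibi_mask: a (num_heads, seq_len, seq_len)
        # additive causal bias. The real builder also applies ALiBi slopes.
        bias = torch.full((seq_len, seq_len), float("-inf"), dtype=dtype, device=device)
        bias = torch.triu(bias, diagonal=1)  # 0 on/below the diagonal, -inf above
        return bias.expand(self.num_heads, -1, -1)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        seq_len = hidden_states.shape[1]
        if self.training:
            # Rebuild only when the cache is empty or the length changed.
            if self.alibi_mask is None or self.alibi_mask.shape[-1] != seq_len:
                self.alibi_mask = self.build_alibi_mask(
                    seq_len, hidden_states.dtype, hidden_states.device
                )
            return self.alibi_mask
        # At inference the length grows by one token per decoding step,
        # so a cache would be invalidated every call; build the mask fresh.
        return self.build_alibi_mask(seq_len, hidden_states.dtype, hidden_states.device)
```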
modeling_baichuan.py
CHANGED
```diff
@@ -249,7 +249,8 @@ class BaichuanModel(BaichuanPreTrainedModel):
         self.gradient_checkpointing = config.gradient_checkpointing
         self.post_init()
         self.max_cache_pos = config.model_max_length
-        self.first_run = True
+        self.first_run = True
+        self.alibi_mask = None

     def get_input_embeddings(self):
         return self.embed_tokens
@@ -306,8 +307,13 @@ class BaichuanModel(BaichuanPreTrainedModel):
         if inputs_embeds is None:
             inputs_embeds = self.embed_tokens(input_ids)

+        if self.training:
+            if self.alibi_mask is None or self.alibi_mask.shape[-1] != seq_length_with_past:
+                self.alibi_mask = self.get_alibi_mask(inputs_embeds, seq_length_with_past)
+            alibi_mask = self.alibi_mask
+        else:
+            alibi_mask = self.get_alibi_mask(inputs_embeds, seq_length_with_past)

-        alibi_mask = self.get_alibi_mask(inputs_embeds, seq_length_with_past)
         if attention_mask is not None:
             if len(attention_mask.shape) == 2:
                 expanded_mask = attention_mask.to(alibi_mask.dtype)
```
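Two design points in the diff are worth noting. The cache is keyed only on `shape[-1]`, i.e. on `seq_length_with_past`, so it is rebuilt whenever the training length changes and otherwise served as-is; and it is gated on `self.training` because at inference `seq_length_with_past` grows with the KV cache on every decoding step, which would invalidate the cache each call. A hypothetical micro-benchmark against the `AlibiMaskCache` sketch above (names and sizes illustrative, not taken from the repository) shows where the saving comes from:

```python
import time
import torch

cache = AlibiMaskCache(num_heads=40)   # head count illustrative
x = torch.randn(2, 2048, 128)          # (batch, seq_len, head_dim), sizes illustrative

cache.train()                          # cached path: mask built once, then reused
t0 = time.perf_counter()
for _ in range(100):
    mask = cache(x)
print("training (cached):   ", time.perf_counter() - t0)

cache.eval()                           # uncached path: mask rebuilt on every call
t0 = time.perf_counter()
for _ in range(100):
    mask = cache(x)
print("inference (uncached):", time.perf_counter() - t0)
```

One caveat the diff accepts: the cached mask is a plain Python attribute rather than a registered buffer, so it is not saved in checkpoints; since it starts out as `None`, it is simply rebuilt on the first training step after loading.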