Update modeling_dbrx.py
modeling_dbrx.py CHANGED (+8, -22)
@@ -1,3 +1,4 @@
+# code adapted from https://huggingface.co/fahadh4ilyas
 """PyTorch Dbrx model."""
 
 import math
@@ -244,28 +245,13 @@ def resolve_ffn_act_fn(
 # Copied from LLaMaAttention
 #############################################################################
 
-
-def get_max_seqlen_in_batch(attention_mask):
-    max_num = int(torch.max(attention_mask).item())
-    counts = []
-    for i in range(1, max_num + 1):
-        counts.append(
-            torch.sum(attention_mask == i, axis=-1)
-        )  # shape: B, count length of data point masked with i
-    result = torch.stack(counts, axis=1)
-    result = result.flatten()
-    return result[result.nonzero()].squeeze(-1).to(dtype=torch.int32)
-
-
-def _get_unpad_data(attention_mask):
-    seqlens_in_batch = get_max_seqlen_in_batch(
-        attention_mask
-    )  # attention_mask.sum(dim=-1, dtype=torch.int32)
+
+def _get_unpad_data(attention_mask: torch.Tensor):
+    seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
     indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
     max_seqlen_in_batch = seqlens_in_batch.max().item()
-    cu_seqlens = F.pad(
-        torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0)
-    )
+    cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32),
+                       (1, 0))
     return (
         indices,
         cu_seqlens,
@@ -426,7 +412,7 @@ class DbrxFlashAttention2(DbrxAttention):
         **kwargs: Any,
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor],
                Optional[Tuple[torch.Tensor]]]:
-        logger.info(
+        logger.debug(
             'Implicitly setting `output_attentions` to False as it is not supported in Flash Attention.'
         )
         output_attentions = False
@@ -1459,4 +1445,4 @@ class DbrxForCausalLM(DbrxPreTrainedModel):
             reordered_past += (tuple(
                 past_state.index_select(0, beam_idx.to(past_state.device))
                 for past_state in layer_past),)
-        return reordered_past
+        return reordered_past
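The final hunk only touches the return line of _reorder_cache, the helper beam search uses to keep the cached key/value tensors aligned with the surviving beams: index_select(0, beam_idx) picks, per layer, the cache rows of the beams that were kept. A minimal standalone sketch of that indexing (shapes and values made up for illustration):

import torch

batch_beams, heads, seq_len, head_dim = 4, 2, 3, 8
layer_past = (
    torch.randn(batch_beams, heads, seq_len, head_dim),  # cached keys
    torch.randn(batch_beams, heads, seq_len, head_dim),  # cached values
)
beam_idx = torch.tensor([2, 2, 0, 1])  # beams selected at this search step

reordered = tuple(
    past_state.index_select(0, beam_idx.to(past_state.device))
    for past_state in layer_past
)
assert torch.equal(reordered[0][0], layer_past[0][2])  # slot 0 now holds old beam 2's keys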