oweller2
committed on
Commit
•
d831694
1
Parent(s):
7561dc4
update model
Browse files- modeling_flexbert.py +3 -6
- pytorch_model.bin +1 -1
modeling_flexbert.py
CHANGED
@@ -1654,7 +1654,7 @@ class FlexBertForCausalLM(FlexBertPreTrainedModel):
|
|
1654 |
|
1655 |
hidden_states = self.bert(
|
1656 |
input_ids,
|
1657 |
-
attention_mask=None,
|
1658 |
position_ids=position_ids,
|
1659 |
indices=indices,
|
1660 |
cu_seqlens=cu_seqlens,
|
@@ -1703,11 +1703,8 @@ class FlexBertForCausalLM(FlexBertPreTrainedModel):
|
|
1703 |
shift_labels.view(-1)
|
1704 |
)
|
1705 |
|
1706 |
-
if self.unpad_embeddings:
|
1707 |
-
|
1708 |
-
logits = logits.view(batch_size, -1, self.vocab_size)
|
1709 |
-
except Exception as e:
|
1710 |
-
breakpoint()
|
1711 |
|
1712 |
if self.pad_logits:
|
1713 |
# print(f"Padding logits: {logits.shape}")
|
|
|
1654 |
|
1655 |
hidden_states = self.bert(
|
1656 |
input_ids,
|
1657 |
+
attention_mask=None, # let FA handle it
|
1658 |
position_ids=position_ids,
|
1659 |
indices=indices,
|
1660 |
cu_seqlens=cu_seqlens,
|
|
|
1703 |
shift_labels.view(-1)
|
1704 |
)
|
1705 |
|
1706 |
+
if self.unpad_embeddings: # revert back to normal logits
|
1707 |
+
logits = logits.view(batch_size, -1, self.vocab_size)
|
|
|
|
|
|
|
1708 |
|
1709 |
if self.pad_logits:
|
1710 |
# print(f"Padding logits: {logits.shape}")
|
pytorch_model.bin
CHANGED
@@ -1,3 +1,3 @@
|
|
1 |
version https://git-lfs.github.com/spec/v1
|
2 |
-
oid sha256:
|
3 |
size 598685038
|
|
|
1 |
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:7863cc4c58494c661ffd3c77af90796f5caa8217917c2f6e7c99cc28d65b58c2
|
3 |
size 598685038
|