oweller2 committed on
Commit
d831694
1 Parent(s): 7561dc4

update model

Browse files
Files changed (2) hide show
  1. modeling_flexbert.py +3 -6
  2. pytorch_model.bin +1 -1
modeling_flexbert.py CHANGED
@@ -1654,7 +1654,7 @@ class FlexBertForCausalLM(FlexBertPreTrainedModel):
1654
 
1655
  hidden_states = self.bert(
1656
  input_ids,
1657
- attention_mask=None,
1658
  position_ids=position_ids,
1659
  indices=indices,
1660
  cu_seqlens=cu_seqlens,
@@ -1703,11 +1703,8 @@ class FlexBertForCausalLM(FlexBertPreTrainedModel):
1703
  shift_labels.view(-1)
1704
  )
1705
 
1706
- if self.unpad_embeddings:
1707
- try:
1708
- logits = logits.view(batch_size, -1, self.vocab_size)
1709
- except Exception as e:
1710
- breakpoint()
1711
 
1712
  if self.pad_logits:
1713
  # print(f"Padding logits: {logits.shape}")
 
1654
 
1655
  hidden_states = self.bert(
1656
  input_ids,
1657
+ attention_mask=None, # let FA handle it
1658
  position_ids=position_ids,
1659
  indices=indices,
1660
  cu_seqlens=cu_seqlens,
 
1703
  shift_labels.view(-1)
1704
  )
1705
 
1706
+ if self.unpad_embeddings: # revert back to normal logits
1707
+ logits = logits.view(batch_size, -1, self.vocab_size)
 
 
 
1708
 
1709
  if self.pad_logits:
1710
  # print(f"Padding logits: {logits.shape}")
pytorch_model.bin CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:fea155de40c6fd0d7f58f431f493c5f614a64dabe168b72cbc74421a9bf17baf
3
  size 598685038
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:7863cc4c58494c661ffd3c77af90796f5caa8217917c2f6e7c99cc28d65b58c2
3
  size 598685038