oweller2 committed on
Commit
2d5427f
1 Parent(s): 082b6b3
Files changed (3)
  1. config.json +1 -1
  2. modeling_flexbert.py +0 -22
  3. tokenizer.py +15 -3
config.json CHANGED
@@ -70,7 +70,7 @@
   "num_hidden_layers": 22,
   "num_initial_layers": 1,
   "pad_logits": true,
-  "pad_token_id": 50283,
+  "pad_token_id": null,
   "padding": "unpadded",
   "pooling_type": "cls",
   "position_embedding_type": "absolute",
modeling_flexbert.py CHANGED
@@ -1713,36 +1713,14 @@ class FlexBertForCausalLM(FlexBertPreTrainedModel):
         self,
         input_ids: torch.Tensor,
         attention_mask: Optional[torch.Tensor] = None,
-        position_ids: Optional[torch.Tensor] = None,
         **kwargs
     ) -> dict:
         if attention_mask is None:
             attention_mask = torch.ones_like(input_ids)
 
-        # Calculate sequence-local positions
-        seqlens = attention_mask.sum(dim=-1)  # Get length of each sequence
-        position_ids = torch.zeros_like(input_ids)
-        for i in range(len(seqlens)):
-            position_ids[i, :seqlens[i]] = torch.arange(seqlens[i], device=input_ids.device)
-
-
-        batch_size, seq_len = input_ids.shape[:2]
-        if self.unpad_embeddings:
-            input_ids, indices, cu_seqlens, max_seqlen, position_ids, _ = self.unpad_inputs(
-                input_ids, attention_mask, position_ids, None
-            )
-        else:
-            indices = torch.arange(seq_len, device=input_ids.device).unsqueeze(0).repeat(batch_size, 1)
-            cu_seqlens = None
-            max_seqlen = None
         return {
             "input_ids": input_ids,
             "attention_mask": attention_mask,
-            "position_ids": position_ids,
-            "indices": indices,
-            "cu_seqlens": cu_seqlens,
-            "max_seqlen": max_seqlen,
-            "batch_size": batch_size,
         }
 
     def get_number_parameters(self, count_embeddings: bool = True, trainable: bool = True) -> int:
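After this commit, the input-preparation hook on FlexBertForCausalLM no longer precomputes position ids, unpadding indices, cu_seqlens, max_seqlen, or batch size; it only defaults the attention mask and passes the two core tensors through. A minimal standalone sketch of the trimmed behavior (the function name below is illustrative, since the method's def line falls outside the hunk):

```python
# Minimal sketch of the trimmed input-preparation step, based on the kept
# lines of the hunk above. The free function name is an assumption; in the
# model this is a method whose def line is not shown in this diff.
from typing import Optional
import torch

def prepare_generation_inputs(input_ids: torch.Tensor,
                              attention_mask: Optional[torch.Tensor] = None,
                              **kwargs) -> dict:
    # Default to an all-ones mask, as in the kept context lines.
    if attention_mask is None:
        attention_mask = torch.ones_like(input_ids)
    # position_ids, indices, cu_seqlens, max_seqlen, and batch_size are no
    # longer computed here; downstream code is expected to derive them.
    return {"input_ids": input_ids, "attention_mask": attention_mask}

batch = prepare_generation_inputs(torch.tensor([[1, 2, 3]]))
print(sorted(batch))  # ['attention_mask', 'input_ids']
```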
tokenizer.py CHANGED
@@ -23,7 +23,11 @@ class ModernDecoderBERTTokenizer(PreTrainedTokenizerFast):
                 ends_with_eos(seq) for seq in input_ids
             ], dtype=torch.bool)
 
-            if last_token_is_eos.any():
+            if last_token_is_eos.all():
+                # If all sequences have EOS, just truncate all
+                for key in ['input_ids', 'attention_mask']:
+                    outputs[key] = outputs[key][..., :-1]
+            elif last_token_is_eos.any():
                 # Process each sequence individually
                 batch_size = input_ids.shape[0]
                 for i in range(batch_size):
@@ -41,7 +45,11 @@ class ModernDecoderBERTTokenizer(PreTrainedTokenizerFast):
                 ends_with_eos(seq) for seq in input_ids
             ], dtype=bool)
 
-            if last_token_is_eos.any():
+            if last_token_is_eos.all():
+                # If all sequences have EOS, just truncate all
+                for key in ['input_ids', 'attention_mask']:
+                    outputs[key] = outputs[key][..., :-1]
+            elif last_token_is_eos.any():
                 batch_size = input_ids.shape[0]
                 for i in range(batch_size):
                     if last_token_is_eos[i]:
@@ -56,7 +64,11 @@ class ModernDecoderBERTTokenizer(PreTrainedTokenizerFast):
         elif isinstance(input_ids, list):
            last_token_is_eos = [ends_with_eos(seq) for seq in input_ids]
 
-            if any(last_token_is_eos):
+            if all(last_token_is_eos):
+                # If all sequences have EOS, just truncate all
+                for key in ['input_ids', 'attention_mask']:
+                    outputs[key] = [sequence[:-1] for sequence in outputs[key]]
+            elif any(last_token_is_eos):
                 for key in ['input_ids', 'attention_mask']:
                     outputs[key] = [
                         [0] + sequence[:-1] if is_eos else sequence
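The tokenizer change adds a fast path: when every sequence in the batch ends with EOS, the trailing column is dropped in one slice, and the existing per-sequence handling is kept only for mixed batches. A standalone sketch of the tensor fast path (the EOS id below is assumed for illustration):

```python
# Standalone sketch of the new ".all()" fast path for the torch.Tensor branch.
# eos_id = 50282 is an assumed value for illustration only.
import torch

eos_id = 50282

def ends_with_eos(seq: torch.Tensor) -> bool:
    # True when the last token of the sequence is the EOS id.
    return seq.numel() > 0 and seq[-1].item() == eos_id

outputs = {
    "input_ids": torch.tensor([[11, 12, eos_id], [21, 22, eos_id]]),
    "attention_mask": torch.ones(2, 3, dtype=torch.long),
}
input_ids = outputs["input_ids"]
last_token_is_eos = torch.tensor(
    [ends_with_eos(seq) for seq in input_ids], dtype=torch.bool
)

if last_token_is_eos.all():
    # Every row ends with EOS: drop the last column in a single slice.
    for key in ['input_ids', 'attention_mask']:
        outputs[key] = outputs[key][..., :-1]

print(outputs["input_ids"])  # tensor([[11, 12], [21, 22]])
```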