Commit 2d5427f by oweller2
Parent(s): 082b6b3
Commit message: done

Files changed:
- config.json (+1 -1)
- modeling_flexbert.py (+0 -22)
- tokenizer.py (+15 -3)
config.json
CHANGED
@@ -70,7 +70,7 @@
   "num_hidden_layers": 22,
   "num_initial_layers": 1,
   "pad_logits": true,
-  "pad_token_id":
+  "pad_token_id": null,
   "padding": "unpadded",
   "pooling_type": "cls",
   "position_embedding_type": "absolute",
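With `pad_token_id` now `null`, the config no longer pins a pad token, so anything that pads batches has to pick one up from the tokenizer instead. A minimal sketch of that fallback, assuming a hypothetical checkpoint name and a `trust_remote_code` load that are not part of this commit:

```python
from transformers import AutoConfig, AutoTokenizer

checkpoint = "oweller2/flexbert-decoder"  # hypothetical name, for illustration only

config = AutoConfig.from_pretrained(checkpoint, trust_remote_code=True)
tokenizer = AutoTokenizer.from_pretrained(checkpoint, trust_remote_code=True)

# After this commit, config.pad_token_id is None, so fall back to the
# tokenizer's pad token (or its EOS token) when padding is needed.
if config.pad_token_id is None:
    pad_id = tokenizer.pad_token_id
    if pad_id is None:
        pad_id = tokenizer.eos_token_id
    config.pad_token_id = pad_id
```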
modeling_flexbert.py
CHANGED
@@ -1713,36 +1713,14 @@ class FlexBertForCausalLM(FlexBertPreTrainedModel):
         self,
         input_ids: torch.Tensor,
         attention_mask: Optional[torch.Tensor] = None,
-        position_ids: Optional[torch.Tensor] = None,
         **kwargs
     ) -> dict:
         if attention_mask is None:
             attention_mask = torch.ones_like(input_ids)

-        # Calculate sequence-local positions
-        seqlens = attention_mask.sum(dim=-1)  # Get length of each sequence
-        position_ids = torch.zeros_like(input_ids)
-        for i in range(len(seqlens)):
-            position_ids[i, :seqlens[i]] = torch.arange(seqlens[i], device=input_ids.device)
-
-
-        batch_size, seq_len = input_ids.shape[:2]
-        if self.unpad_embeddings:
-            input_ids, indices, cu_seqlens, max_seqlen, position_ids, _ = self.unpad_inputs(
-                input_ids, attention_mask, position_ids, None
-            )
-        else:
-            indices = torch.arange(seq_len, device=input_ids.device).unsqueeze(0).repeat(batch_size, 1)
-            cu_seqlens = None
-            max_seqlen = None
         return {
             "input_ids": input_ids,
             "attention_mask": attention_mask,
-            "position_ids": position_ids,
-            "indices": indices,
-            "cu_seqlens": cu_seqlens,
-            "max_seqlen": max_seqlen,
-            "batch_size": batch_size,
         }

     def get_number_parameters(self, count_embeddings: bool = True, trainable: bool = True) -> int:
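For reference, the deleted block built sequence-local position ids from the attention mask and optionally unpadded the batch inside `prepare_inputs_for_generation`; after this commit the method only returns `input_ids` and `attention_mask` and leaves that work to the forward pass. A standalone sketch of the removed position-id logic, using a made-up toy batch:

```python
import torch

# Toy batch: two right-padded sequences of lengths 3 and 2 (made-up values).
input_ids = torch.tensor([[5, 6, 7, 0], [8, 9, 0, 0]])
attention_mask = torch.tensor([[1, 1, 1, 0], [1, 1, 0, 0]])

# Same idea as the removed code: positions restart at 0 for every sequence,
# and padded slots keep position 0.
seqlens = attention_mask.sum(dim=-1)        # tensor([3, 2])
position_ids = torch.zeros_like(input_ids)
for i, n in enumerate(seqlens.tolist()):
    position_ids[i, :n] = torch.arange(n, device=input_ids.device)

print(position_ids)
# tensor([[0, 1, 2, 0],
#         [0, 1, 0, 0]])
```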
tokenizer.py
CHANGED
@@ -23,7 +23,11 @@ class ModernDecoderBERTTokenizer(PreTrainedTokenizerFast):
                 ends_with_eos(seq) for seq in input_ids
             ], dtype=torch.bool)

-            if last_token_is_eos.any():
+            if last_token_is_eos.all():
+                # If all sequences have EOS, just truncate all
+                for key in ['input_ids', 'attention_mask']:
+                    outputs[key] = outputs[key][..., :-1]
+            elif last_token_is_eos.any():
                 # Process each sequence individually
                 batch_size = input_ids.shape[0]
                 for i in range(batch_size):
@@ -41,7 +45,11 @@ class ModernDecoderBERTTokenizer(PreTrainedTokenizerFast):
                 ends_with_eos(seq) for seq in input_ids
             ], dtype=bool)

-            if last_token_is_eos.any():
+            if last_token_is_eos.all():
+                # If all sequences have EOS, just truncate all
+                for key in ['input_ids', 'attention_mask']:
+                    outputs[key] = outputs[key][..., :-1]
+            elif last_token_is_eos.any():
                 batch_size = input_ids.shape[0]
                 for i in range(batch_size):
                     if last_token_is_eos[i]:
@@ -56,7 +64,11 @@ class ModernDecoderBERTTokenizer(PreTrainedTokenizerFast):
         elif isinstance(input_ids, list):
             last_token_is_eos = [ends_with_eos(seq) for seq in input_ids]

-            if any(last_token_is_eos):
+            if all(last_token_is_eos):
+                # If all sequences have EOS, just truncate all
+                for key in ['input_ids', 'attention_mask']:
+                    outputs[key] = [sequence[:-1] for sequence in outputs[key]]
+            elif any(last_token_is_eos):
                 for key in ['input_ids', 'attention_mask']:
                     outputs[key] = [
                         [0] + sequence[:-1] if is_eos else sequence
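The new branch short-circuits the common case where every sequence in the batch ends with EOS: instead of editing sequences one by one, it drops the final token from all of them at once and only falls back to per-sequence handling for mixed batches. A minimal sketch of that fast path on the tensor branch, with made-up ids and a stand-in `ends_with_eos` helper (the real helper is defined elsewhere in tokenizer.py and is not part of this diff):

```python
import torch

eos_id = 50282  # made-up id standing in for the real EOS token


def ends_with_eos(seq):
    # Stand-in for the ends_with_eos helper defined elsewhere in tokenizer.py.
    return bool(seq[-1] == eos_id)


input_ids = torch.tensor([[10, 11, eos_id], [12, 13, eos_id]])
attention_mask = torch.ones_like(input_ids)
outputs = {'input_ids': input_ids, 'attention_mask': attention_mask}

last_token_is_eos = torch.tensor([ends_with_eos(seq) for seq in input_ids], dtype=torch.bool)

if last_token_is_eos.all():
    # Fast path added in this commit: every sequence ends with EOS,
    # so drop the final column with a single slice.
    for key in ['input_ids', 'attention_mask']:
        outputs[key] = outputs[key][..., :-1]
elif last_token_is_eos.any():
    # Mixed batch: fall back to the existing per-sequence handling.
    pass

print(outputs['input_ids'])  # tensor([[10, 11], [12, 13]])
```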