oweller2 committed
Commit 082b6b3 · Parent(s): 8efbef0
Files changed (1): tokenizer.py (+35 −25)
tokenizer.py CHANGED
@@ -22,40 +22,50 @@ class ModernDecoderBERTTokenizer(PreTrainedTokenizerFast):
             last_token_is_eos = torch.tensor([
                 ends_with_eos(seq) for seq in input_ids
             ], dtype=torch.bool)
+
+            if last_token_is_eos.any():
+                # Process each sequence individually
+                batch_size = input_ids.shape[0]
+                for i in range(batch_size):
+                    if last_token_is_eos[i]:
+                        for key in ['input_ids', 'attention_mask']:
+                            # Remove last token and add padding at start for this sequence
+                            truncated = outputs[key][i, :-1]
+                            outputs[key][i] = torch.cat([
+                                torch.zeros_like(truncated[:1]),
+                                truncated
+                            ])
+
         elif isinstance(input_ids, numpy.ndarray):
             last_token_is_eos = numpy.array([
                 ends_with_eos(seq) for seq in input_ids
             ], dtype=bool)
+
+            if last_token_is_eos.any():
+                batch_size = input_ids.shape[0]
+                for i in range(batch_size):
+                    if last_token_is_eos[i]:
+                        for key in ['input_ids', 'attention_mask']:
+                            # Remove last token and add padding at start for this sequence
+                            truncated = outputs[key][i, :-1]
+                            outputs[key][i] = numpy.concatenate([
+                                numpy.zeros_like(truncated[:1]),
+                                truncated
+                            ])
+
         elif isinstance(input_ids, list):
             last_token_is_eos = [ends_with_eos(seq) for seq in input_ids]
-
-        # Use the same last_token_is_eos check for both input_ids and attention_mask
-        for key in ['input_ids', 'attention_mask']:
-            if isinstance(outputs[key], torch.Tensor):
-                # Only remove last token where last_token_is_eos is True
-                mask = last_token_is_eos.unsqueeze(-1)
-                outputs[key] = torch.where(
-                    mask,
-                    outputs[key][..., :-1],
-                    outputs[key]
-                )
-            elif isinstance(outputs[key], numpy.ndarray):
-                # Expand dimensions for broadcasting
-                mask = numpy.expand_dims(last_token_is_eos, -1)
-                outputs[key] = numpy.where(
-                    mask,
-                    outputs[key][..., :-1],
-                    outputs[key]
-                )
-            elif isinstance(outputs[key], list):
-                # For lists, use the same last_token_is_eos list for both keys
-                outputs[key] = [
-                    sequence[:-1] if is_eos else sequence
-                    for sequence, is_eos in zip(outputs[key], last_token_is_eos)
-                ]
+
+            if any(last_token_is_eos):
+                for key in ['input_ids', 'attention_mask']:
+                    outputs[key] = [
+                        [0] + sequence[:-1] if is_eos else sequence
+                        for sequence, is_eos in zip(outputs[key], last_token_is_eos)
+                    ]
 
         return outputs
 
+
 # Register the class
 from transformers import AutoTokenizer
 AutoTokenizer.register(ModernDecoderBERTTokenizer, fast_tokenizer_class=ModernDecoderBERTTokenizer)
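Why the rewrite: the removed torch.where/numpy.where passes compare a truncated branch of shape (batch, length − 1) against the original (batch, length), which cannot broadcast, and the removed list branch returned ragged rows of unequal length. A minimal sketch of the failure, using made-up token ids and an assumed EOS id of 2:

import torch

ids = torch.tensor([[5, 6, 7, 2],    # ends with the assumed EOS id
                    [5, 6, 7, 9]])   # does not
mask = torch.tensor([True, False]).unsqueeze(-1)  # shape (2, 1)

# Pre-commit approach: branch shapes (2, 3) and (2, 4) cannot broadcast,
# so this raises a RuntimeError instead of trimming the EOS column.
try:
    torch.where(mask, ids[..., :-1], ids)
except RuntimeError as err:
    print(f"shape error: {err}")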
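The new code instead edits each offending row in place: drop the trailing EOS, shift the row right, and zero-fill the first slot so the batch stays rectangular. A self-contained sketch of the torch path, again assuming EOS id 2 (a placeholder; the real value comes from the tokenizer's vocabulary) and inlining a stand-in for the repo's ends_with_eos helper:

import torch

EOS_ID = 2  # assumed id for illustration

input_ids = torch.tensor([[5, 6, 7, EOS_ID],   # row 0 ends with EOS
                          [5, 6, 7, 9]])       # row 1 does not
outputs = {'input_ids': input_ids.clone(),
           'attention_mask': torch.ones_like(input_ids)}

# Stand-in for the repo's ends_with_eos helper: test the final column.
last_token_is_eos = input_ids[:, -1] == EOS_ID

for i in range(input_ids.shape[0]):
    if last_token_is_eos[i]:
        for key in ['input_ids', 'attention_mask']:
            truncated = outputs[key][i, :-1]
            # Drop the trailing EOS, shift the row right, and zero-fill slot 0
            # so every row keeps its length and the batch stays rectangular.
            outputs[key][i] = torch.cat([torch.zeros_like(truncated[:1]), truncated])

print(outputs['input_ids'])       # tensor([[0, 5, 6, 7], [5, 6, 7, 9]])
print(outputs['attention_mask'])  # tensor([[0, 1, 1, 1], [1, 1, 1, 1]])

The same zero does double duty here: it acts as an assumed pad id in input_ids and as "do not attend" in attention_mask. The numpy branch mirrors this logic with numpy.concatenate and numpy.zeros_like.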
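One parsing subtlety in the new list branch: Python's conditional expression binds more loosely than +, so the comprehension element is read as ([0] + sequence[:-1]) if is_eos else sequence, matching the tensor paths. A quick check:

sequence, is_eos = [5, 6, 7, 2], True
print([0] + sequence[:-1] if is_eos else sequence)  # [0, 5, 6, 7]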