KaleiNeely
commited on
Update tokenization_rwkv5.py
Browse files- tokenization_rwkv5.py +12 -13
tokenization_rwkv5.py
CHANGED
@@ -15,8 +15,8 @@
|
|
15 |
"""Tokenization classes for RWKV5."""
|
16 |
|
17 |
import os
|
18 |
-
from typing import TYPE_CHECKING, List, Optional, Tuple
|
19 |
import re
|
|
|
20 |
|
21 |
from transformers.tokenization_utils import AddedToken, PreTrainedTokenizer
|
22 |
from transformers.utils import logging
|
@@ -37,7 +37,6 @@ PRETRAINED_VOCAB_FILES_MAP = {
|
|
37 |
}
|
38 |
|
39 |
|
40 |
-
|
41 |
def whitespace_tokenize(text):
|
42 |
"""Runs basic whitespace cleaning and splitting on a piece of text.
|
43 |
The separators are kept
|
@@ -52,10 +51,9 @@ def whitespace_tokenize(text):
|
|
52 |
class WordpieceTokenizer(object):
|
53 |
"""Runs WordPiece tokenization."""
|
54 |
|
55 |
-
def __init__(self, vocab, unk_token
|
56 |
self.vocab = vocab
|
57 |
self.unk_token = unk_token
|
58 |
-
self.max_input_chars_per_word = max_input_chars_per_word
|
59 |
|
60 |
def tokenize(self, text):
|
61 |
"""
|
@@ -75,10 +73,6 @@ class WordpieceTokenizer(object):
|
|
75 |
output_tokens = []
|
76 |
for token in whitespace_tokenize(text):
|
77 |
chars = list(token)
|
78 |
-
if len(chars) > self.max_input_chars_per_word:
|
79 |
-
output_tokens.append(self.unk_token)
|
80 |
-
continue
|
81 |
-
|
82 |
is_bad = False
|
83 |
start = 0
|
84 |
sub_tokens = []
|
@@ -94,9 +88,12 @@ class WordpieceTokenizer(object):
|
|
94 |
if cur_substr is None:
|
95 |
is_bad = True
|
96 |
break
|
97 |
-
|
|
|
|
|
|
|
|
|
98 |
start = end
|
99 |
-
|
100 |
if is_bad:
|
101 |
output_tokens.append(self.unk_token)
|
102 |
else:
|
@@ -111,7 +108,7 @@ class Rwkv5Tokenizer(PreTrainedTokenizer):
|
|
111 |
|
112 |
model_input_names = ["input_ids", "attention_mask"]
|
113 |
|
114 |
-
def __init__(self, vocab_file, bos_token="<s>", eos_token="<s>", unk_token="<s>",
|
115 |
if not os.path.isfile(vocab_file):
|
116 |
raise ValueError(
|
117 |
f"Can't find a vocabulary file at path '{vocab_file}'. To load the vocabulary from a Google pretrained"
|
@@ -130,7 +127,7 @@ class Rwkv5Tokenizer(PreTrainedTokenizer):
|
|
130 |
self.decoder = {v: k for k, v in vocab.items()}
|
131 |
self.wordpiece_tokenizer = WordpieceTokenizer(vocab=self.encoder, unk_token=str(unk_token))
|
132 |
self._added_tokens_decoder = {0: AddedToken(str(bos_token))}
|
133 |
-
super().__init__(bos_token=bos_token, eos_token=eos_token, unk_token=unk_token,
|
134 |
|
135 |
@property
|
136 |
def vocab_size(self):
|
@@ -146,7 +143,9 @@ class Rwkv5Tokenizer(PreTrainedTokenizer):
|
|
146 |
|
147 |
def _convert_token_to_id(self, token):
|
148 |
"""Converts a token (byte) to an id using the vocab."""
|
149 |
-
if
|
|
|
|
|
150 |
token = token.encode("utf-8", errors="replace")
|
151 |
return self.encoder.get(token, self.unk_token_id)
|
152 |
|
|
|
15 |
"""Tokenization classes for RWKV5."""
|
16 |
|
17 |
import os
|
|
|
18 |
import re
|
19 |
+
from typing import TYPE_CHECKING, List, Optional, Tuple
|
20 |
|
21 |
from transformers.tokenization_utils import AddedToken, PreTrainedTokenizer
|
22 |
from transformers.utils import logging
|
|
|
37 |
}
|
38 |
|
39 |
|
|
|
40 |
def whitespace_tokenize(text):
|
41 |
"""Runs basic whitespace cleaning and splitting on a piece of text.
|
42 |
The separators are kept
|
|
|
51 |
class WordpieceTokenizer(object):
|
52 |
"""Runs WordPiece tokenization."""
|
53 |
|
54 |
+
def __init__(self, vocab, unk_token):
|
55 |
self.vocab = vocab
|
56 |
self.unk_token = unk_token
|
|
|
57 |
|
58 |
def tokenize(self, text):
|
59 |
"""
|
|
|
73 |
output_tokens = []
|
74 |
for token in whitespace_tokenize(text):
|
75 |
chars = list(token)
|
|
|
|
|
|
|
|
|
76 |
is_bad = False
|
77 |
start = 0
|
78 |
sub_tokens = []
|
|
|
88 |
if cur_substr is None:
|
89 |
is_bad = True
|
90 |
break
|
91 |
+
try:
|
92 |
+
cur_substr = cur_substr.decode()
|
93 |
+
except UnicodeDecodeError:
|
94 |
+
cur_substr = str(cur_substr)
|
95 |
+
sub_tokens.append(cur_substr)
|
96 |
start = end
|
|
|
97 |
if is_bad:
|
98 |
output_tokens.append(self.unk_token)
|
99 |
else:
|
|
|
108 |
|
109 |
model_input_names = ["input_ids", "attention_mask"]
|
110 |
|
111 |
+
def __init__(self, vocab_file, bos_token="<s>", eos_token="<s>", unk_token="<s>", **kwargs):
|
112 |
if not os.path.isfile(vocab_file):
|
113 |
raise ValueError(
|
114 |
f"Can't find a vocabulary file at path '{vocab_file}'. To load the vocabulary from a Google pretrained"
|
|
|
127 |
self.decoder = {v: k for k, v in vocab.items()}
|
128 |
self.wordpiece_tokenizer = WordpieceTokenizer(vocab=self.encoder, unk_token=str(unk_token))
|
129 |
self._added_tokens_decoder = {0: AddedToken(str(bos_token))}
|
130 |
+
super().__init__(bos_token=bos_token, eos_token=eos_token, unk_token=unk_token, **kwargs)
|
131 |
|
132 |
@property
|
133 |
def vocab_size(self):
|
|
|
143 |
|
144 |
def _convert_token_to_id(self, token):
|
145 |
"""Converts a token (byte) to an id using the vocab."""
|
146 |
+
if token.startswith("b'\\"):
|
147 |
+
token = eval(token)
|
148 |
+
elif not isinstance(token, bytes):
|
149 |
token = token.encode("utf-8", errors="replace")
|
150 |
return self.encoder.get(token, self.unk_token_id)
|
151 |
|