cognitivess committed
Update cognitivess_model/tokenization_cognitivess.py
cognitivess_model/tokenization_cognitivess.py (CHANGED)
```diff
@@ -1,8 +1,19 @@
 from transformers import PreTrainedTokenizer
+import json
+import os
 
 class CognitivessTokenizer(PreTrainedTokenizer):
-    def __init__(self, *args, **kwargs):
+    def __init__(self, vocab_file, merges_file, *args, **kwargs):
         super().__init__(*args, **kwargs)
+        self.vocab_file = vocab_file
+        self.merges_file = merges_file
+        self.encoder = self.load_vocab(vocab_file)
+        self.decoder = {v: k for k, v in self.encoder.items()}
+
+        with open(merges_file, encoding="utf-8") as merges_handle:
+            bpe_data = merges_handle.read()
+        bpe_merges = [tuple(merge_str.split()) for merge_str in bpe_data.split("\n")[1:-1]]
+        self.bpe_ranks = dict(zip(bpe_merges, range(len(bpe_merges))))
 
     @property
     def vocab_size(self):
@@ -24,12 +35,23 @@ class CognitivessTokenizer(PreTrainedTokenizer):
         return " ".join(tokens)
 
     def save_vocabulary(self, save_directory):
+        if not os.path.isdir(save_directory):
+            os.makedirs(save_directory)
         vocab_file = os.path.join(save_directory, "vocab.json")
-
-
-
+        merges_file = os.path.join(save_directory, "merges.txt")
+
+        with open(vocab_file, "w", encoding="utf-8") as vocab_handle:
+            json.dump(self.encoder, vocab_handle, ensure_ascii=False)
+        with open(merges_file, "w", encoding="utf-8") as merges_handle:
+            merges_handle.write("\n".join(" ".join(pair) for pair in self.bpe_ranks.keys()))
+
+        return (vocab_file, merges_file)
 
     def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None):
         bos_token_id = [self.bos_token_id]
         eos_token_id = [self.eos_token_id]
         return bos_token_id + token_ids_0 + eos_token_id
+
+    def load_vocab(self, vocab_file):
+        with open(vocab_file, "r", encoding="utf-8") as f:
+            return json.load(f)
```
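
For context, the constructor added in this commit reads GPT-2-style vocabulary files: `vocab.json` maps token strings to ids, and `merges.txt` lists BPE merge pairs, where the `split("\n")[1:-1]` slice drops the first line (conventionally a `#version` header) and the empty string after the trailing newline. Below is a minimal standalone sketch of that loading logic with toy contents; the paths and tokens are illustrative, not from the repository, and the elided middle of the class (between `vocab_size` and `convert_tokens_to_string`) is assumed to supply the remaining tokenizer methods.

```python
import json
import os
import tempfile

# Toy files mirroring the formats the committed loader expects (illustrative contents).
tmp = tempfile.mkdtemp()
vocab_path = os.path.join(tmp, "vocab.json")
merges_path = os.path.join(tmp, "merges.txt")

with open(vocab_path, "w", encoding="utf-8") as f:
    json.dump({"h": 0, "e": 1, "l": 2, "o": 3, "he": 4, "ll": 5}, f)

with open(merges_path, "w", encoding="utf-8") as f:
    # First line is a header; split("\n")[1:-1] drops it and the
    # empty string produced by the final newline.
    f.write("#version: 0.2\nh e\nl l\n")

# Same parsing steps as the committed __init__:
with open(vocab_path, encoding="utf-8") as f:
    encoder = json.load(f)
decoder = {v: k for k, v in encoder.items()}

with open(merges_path, encoding="utf-8") as merges_handle:
    bpe_data = merges_handle.read()
bpe_merges = [tuple(merge_str.split()) for merge_str in bpe_data.split("\n")[1:-1]]
bpe_ranks = dict(zip(bpe_merges, range(len(bpe_merges))))

print(bpe_ranks)  # {('h', 'e'): 0, ('l', 'l'): 1}; lower rank means higher merge priority
```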
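
One detail worth noting about the rewritten `save_vocabulary`: it writes the merges with neither a header line nor a trailing newline, while the loader's `split("\n")[1:-1]` assumes both, so merges can be lost when a re-saved file is loaded back. A hedged sketch of the committed save logic under the same toy data (paths hypothetical) makes the round-trip behavior visible:

```python
import json
import os
import tempfile

# Toy state standing in for self.encoder and self.bpe_ranks.
encoder = {"h": 0, "e": 1, "l": 2, "o": 3, "he": 4, "ll": 5}
bpe_ranks = {("h", "e"): 0, ("l", "l"): 1}

save_directory = tempfile.mkdtemp()
if not os.path.isdir(save_directory):
    os.makedirs(save_directory)
vocab_file = os.path.join(save_directory, "vocab.json")
merges_file = os.path.join(save_directory, "merges.txt")

# Same write steps as the committed save_vocabulary:
with open(vocab_file, "w", encoding="utf-8") as vocab_handle:
    json.dump(encoder, vocab_handle, ensure_ascii=False)
with open(merges_file, "w", encoding="utf-8") as merges_handle:
    merges_handle.write("\n".join(" ".join(pair) for pair in bpe_ranks.keys()))

# Reloading with the constructor's slice shows the caveat:
with open(merges_file, encoding="utf-8") as f:
    reloaded = [tuple(m.split()) for m in f.read().split("\n")[1:-1]]
print(reloaded)  # prints [] because the header-skipping slice eats merges from a headerless file
```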