cognitivess
commited on
Commit
•
cab70b7
1
Parent(s):
98759b5
Update cognitivess_model/tokenization_cognitivess.py
Browse files
cognitivess_model/tokenization_cognitivess.py
CHANGED
@@ -3,8 +3,8 @@ import json
|
|
3 |
import os
|
4 |
|
5 |
class CognitivessTokenizer(PreTrainedTokenizer):
|
6 |
-
def __init__(self, vocab_file, merges_file,
|
7 |
-
super().__init__(
|
8 |
self.vocab_file = vocab_file
|
9 |
self.merges_file = merges_file
|
10 |
self.encoder = self.load_vocab(vocab_file)
|
@@ -15,6 +15,12 @@ class CognitivessTokenizer(PreTrainedTokenizer):
|
|
15 |
bpe_merges = [tuple(merge_str.split()) for merge_str in bpe_data.split("\n")[1:-1]]
|
16 |
self.bpe_ranks = dict(zip(bpe_merges, range(len(bpe_merges))))
|
17 |
|
|
|
|
|
|
|
|
|
|
|
|
|
18 |
@property
|
19 |
def vocab_size(self):
|
20 |
return len(self.encoder)
|
@@ -54,4 +60,4 @@ class CognitivessTokenizer(PreTrainedTokenizer):
|
|
54 |
|
55 |
def load_vocab(self, vocab_file):
|
56 |
with open(vocab_file, "r", encoding="utf-8") as f:
|
57 |
-
return json.load(f)
|
|
|
3 |
import os
|
4 |
|
5 |
class CognitivessTokenizer(PreTrainedTokenizer):
|
6 |
+
def __init__(self, vocab_file, merges_file, **kwargs):
|
7 |
+
super().__init__(**kwargs)
|
8 |
self.vocab_file = vocab_file
|
9 |
self.merges_file = merges_file
|
10 |
self.encoder = self.load_vocab(vocab_file)
|
|
|
15 |
bpe_merges = [tuple(merge_str.split()) for merge_str in bpe_data.split("\n")[1:-1]]
|
16 |
self.bpe_ranks = dict(zip(bpe_merges, range(len(bpe_merges))))
|
17 |
|
18 |
+
@classmethod
|
19 |
+
def from_pretrained(cls, pretrained_model_name_or_path, *init_inputs, **kwargs):
|
20 |
+
vocab_file = os.path.join(pretrained_model_name_or_path, "vocab.json")
|
21 |
+
merges_file = os.path.join(pretrained_model_name_or_path, "merges.txt")
|
22 |
+
return cls(vocab_file, merges_file, **kwargs)
|
23 |
+
|
24 |
@property
|
25 |
def vocab_size(self):
|
26 |
return len(self.encoder)
|
|
|
60 |
|
61 |
def load_vocab(self, vocab_file):
|
62 |
with open(vocab_file, "r", encoding="utf-8") as f:
|
63 |
+
return json.load(f)
|