cognitivess committed · verified
Commit 17b367c · 1 Parent(s): a3e9b32

Update cognitivess_model/tokenization_cognitivess.py

cognitivess_model/tokenization_cognitivess.py CHANGED
@@ -1,8 +1,19 @@
 from transformers import PreTrainedTokenizer
+import json
+import os
 
 class CognitivessTokenizer(PreTrainedTokenizer):
-    def __init__(self, *args, **kwargs):
+    def __init__(self, vocab_file, merges_file, *args, **kwargs):
         super().__init__(*args, **kwargs)
+        self.vocab_file = vocab_file
+        self.merges_file = merges_file
+        self.encoder = self.load_vocab(vocab_file)
+        self.decoder = {v: k for k, v in self.encoder.items()}
+
+        with open(merges_file, encoding="utf-8") as merges_handle:
+            bpe_data = merges_handle.read()
+        bpe_merges = [tuple(merge_str.split()) for merge_str in bpe_data.split("\n")[1:-1]]
+        self.bpe_ranks = dict(zip(bpe_merges, range(len(bpe_merges))))
 
     @property
     def vocab_size(self):
@@ -24,12 +35,23 @@ class CognitivessTokenizer(PreTrainedTokenizer):
         return " ".join(tokens)
 
     def save_vocabulary(self, save_directory):
+        if not os.path.isdir(save_directory):
+            os.makedirs(save_directory)
         vocab_file = os.path.join(save_directory, "vocab.json")
-        with open(vocab_file, "w", encoding="utf-8") as f:
-            json.dump(self.encoder, f, ensure_ascii=False)
-        return (vocab_file,)
+        merges_file = os.path.join(save_directory, "merges.txt")
+
+        with open(vocab_file, "w", encoding="utf-8") as vocab_handle:
+            json.dump(self.encoder, vocab_handle, ensure_ascii=False)
+        with open(merges_file, "w", encoding="utf-8") as merges_handle:
+            merges_handle.write("\n".join(" ".join(pair) for pair in self.bpe_ranks.keys()))
+
+        return (vocab_file, merges_file)
 
     def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None):
         bos_token_id = [self.bos_token_id]
         eos_token_id = [self.eos_token_id]
         return bos_token_id + token_ids_0 + eos_token_id
+
+    def load_vocab(self, vocab_file):
+        with open(vocab_file, "r", encoding="utf-8") as f:
+            return json.load(f)
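
For context, the new __init__ parses merges.txt in the GPT-2 style: the first line is treated as a version header and the trailing newline leaves an empty last entry, both of which the [1:-1] slice drops. A minimal, self-contained sketch of that parsing step; the sample file contents here are invented for illustration:

    # Hypothetical merges.txt contents: a version header, two merge rules,
    # and a trailing newline, as a GPT-2-style BPE trainer would write them.
    sample = "#version: 0.2\ng e\nge t\n"

    # Same slicing as the commit: [1:-1] drops the header line and the
    # empty string left after the final newline.
    bpe_merges = [tuple(merge_str.split()) for merge_str in sample.split("\n")[1:-1]]
    bpe_ranks = dict(zip(bpe_merges, range(len(bpe_merges))))
    print(bpe_ranks)  # {('g', 'e'): 0, ('ge', 't'): 1}

Lower ranks mark merges learned earlier in BPE training, which is why the dict maps each pair to its position in the file.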
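
The save_vocabulary/load_vocab pair now round-trips the vocabulary to disk. Below is a standalone sketch of the vocab.json half, written without the tokenizer class since instantiating a PreTrainedTokenizer subclass needs the rest of the file, which this diff omits; the toy vocabulary and temp directory are assumptions for illustration:

    import json
    import os
    import tempfile

    encoder = {"<s>": 0, "</s>": 1, "ge": 2, "t": 3}  # toy token-to-id map
    save_directory = tempfile.mkdtemp()  # stand-in for a real model directory

    # Mirrors save_vocabulary: ensure the directory exists, dump the encoder.
    if not os.path.isdir(save_directory):
        os.makedirs(save_directory)
    vocab_file = os.path.join(save_directory, "vocab.json")
    with open(vocab_file, "w", encoding="utf-8") as vocab_handle:
        json.dump(encoder, vocab_handle, ensure_ascii=False)

    # Mirrors load_vocab: read the JSON back into a dict.
    with open(vocab_file, "r", encoding="utf-8") as f:
        assert json.load(f) == encoder  # round trip preserves the mapping

Note that save_vocabulary now returns a tuple of both file paths, (vocab_file, merges_file), which is the shape transformers expects from save_vocabulary implementations.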