Add `ChatGLMTokenizerFast` and `ChatGLMTokenizerConverter`

#12
by chielo
Files changed (2)
  1. tokenization_chatglm.py +184 -6
  2. tokenizer_config.json +2 -2
tokenization_chatglm.py CHANGED
```diff
@@ -3,13 +3,36 @@ import os
 import re
 from typing import List, Optional, Union, Dict
 from sentencepiece import SentencePieceProcessor
-from transformers import PreTrainedTokenizer
+from transformers import AddedToken, PreTrainedTokenizer, PreTrainedTokenizerFast
+from transformers.convert_slow_tokenizer import (
+    SLOW_TO_FAST_CONVERTERS,
+    SpmConverter,
+    decoders,
+    normalizers,
+    pre_tokenizers,
+    processors,
+)
 from transformers.utils import logging, PaddingStrategy
 from transformers.tokenization_utils_base import EncodedInput, BatchEncoding
 
 
 logger = logging.get_logger(__name__)
 
+ADDITIONAL_SPECIAL_TOKENS = [
+    "[MASK]",
+    "[gMASK]",
+    "[sMASK]",
+    "<!sop!>",
+    "<!eop!>",
+    "<|system|>",
+    "<|user|>",
+    "<|assistant|>",
+    "<|observation|>",
+]
+PREFIX_TOKENS = ["[gMASK]", "<!sop!>"]
+
+DUMMY_PREFIX_INDICATOR_FOR_FAST = "<!dummy-prefix!>"
+
 
 class SPTokenizer:
     def __init__(self, model_path: str):
```
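Of the new module-level constants, `DUMMY_PREFIX_INDICATOR_FOR_FAST` is the interesting one: the converter added at the end of this file rewrites it to SentencePiece's meta symbol `▁`, so a prompt builder can force dummy-prefix behaviour mid-string without inserting a literal space. A minimal sketch of the intended effect, assuming the repo id below and that this PR's files are in place:

```python
from transformers import AutoTokenizer

# Assumption: "THUDM/chatglm3-6b" carries this PR's tokenizer files.
tok = AutoTokenizer.from_pretrained(
    "THUDM/chatglm3-6b", trust_remote_code=True, use_fast=True
)

# The converter's normalizer turns the sentinel into "▁", so the text after it
# tokenizes as if it began a new word, with no literal space in the prompt.
print(tok.tokenize("<!dummy-prefix!>hello"))  # expected: ["▁hello"], vocab permitting
```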
```diff
@@ -24,8 +47,7 @@ class SPTokenizer:
         self.pad_id: int = self.sp_model.unk_id()
         assert self.sp_model.vocab_size() == self.sp_model.get_piece_size()
 
-        role_special_tokens = ["<|system|>", "<|user|>", "<|assistant|>", "<|observation|>"]
-        special_tokens = ["[MASK]", "[gMASK]", "[sMASK]", "sop", "eop"] + role_special_tokens
+        special_tokens = ADDITIONAL_SPECIAL_TOKENS
         self.special_tokens = {}
         self.index_special_tokens = {}
         for token in special_tokens:
```
```diff
@@ -86,7 +108,7 @@ class SPTokenizer:
         """Converts an index (integer) in a token (str) using the vocab."""
         if index in self.index_special_tokens:
             return self.index_special_tokens[index]
-        if index in [self.eos_id, self.bos_id, self.pad_id] or index < 0 or index > self.sp_model.vocab_size():
+        if index in [self.eos_id, self.bos_id, self.pad_id] or index < 0 or index >= self.sp_model.vocab_size():
             return ""
         return self.sp_model.IdToPiece(index)
 
```
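The guard on out-of-range ids also gets an off-by-one fix: valid SentencePiece piece ids run from `0` to `vocab_size() - 1`, so the old `index > vocab_size()` check let `index == vocab_size()` through to `IdToPiece`, which raises. A short sketch of the boundary; the model path is a placeholder, and any SentencePiece model behaves the same:

```python
from sentencepiece import SentencePieceProcessor

sp = SentencePieceProcessor(model_file="tokenizer.model")  # placeholder path
n = sp.vocab_size()  # valid piece ids are 0 .. n-1

sp.IdToPiece(n - 1)  # OK: the last valid id
# sp.IdToPiece(n)    # raises IndexError: piece id is out of range
# Old guard (index > n): index == n slipped through and crashed on the line above.
# New guard (index >= n): index == n now returns "" like other invalid ids.
```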
 
```diff
@@ -216,8 +238,7 @@ class ChatGLMTokenizer(PreTrainedTokenizer):
         return (vocab_file,)
 
     def get_prefix_tokens(self):
-        prefix_tokens = [self.get_command("[gMASK]"), self.get_command("sop")]
-        return prefix_tokens
+        return list(map(self.get_command, PREFIX_TOKENS))
 
     def build_single_message(self, role, metadata, message):
         assert role in ["system", "user", "assistant", "observation"], role
```
```diff
@@ -326,3 +347,160 @@ class ChatGLMTokenizer(PreTrainedTokenizer):
             encoded_inputs[self.model_input_names[0]] = [self.pad_token_id] * difference + required_input
 
         return encoded_inputs
+
+
+class ChatGLMTokenizerFast(PreTrainedTokenizerFast):
+    # multiple breaking changes, no backward-compatibility
+    slow_tokenizer_class = ChatGLMTokenizer
+    vocab_files_names = {
+        **ChatGLMTokenizer.vocab_files_names,
+        **PreTrainedTokenizerFast.vocab_files_names,
+    }
+
+    def __init__(self, **kwargs):
+        kwargs.setdefault("clean_up_tokenization_spaces", False)
+        kwargs.setdefault("bos_token", "<s>")
+        kwargs.setdefault("eos_token", "</s>")
+        kwargs.setdefault("unk_token", "<unk>")
+        kwargs.setdefault("pad_token", "<unk>")
+        super().__init__(**kwargs)
+
+    @property
+    def dummy_prefix_indicator(self):
+        return DUMMY_PREFIX_INDICATOR_FOR_FAST
+
+    @property
+    def can_save_slow_tokenizer(self) -> bool:
+        # multiple breaking changes
+        return False
+
+    def save_pretrained(self, *args, **kwargs):
+        if not self.can_save_slow_tokenizer:
+            logger.warning(
+                f"{type(self).__name__} does not support saving slow tokenizer. "
+                "Saving it at the same directory may break the original tokenizer. "
+                "Please keep a backup beforehand."
+            )
+
+        return super().save_pretrained(*args, **kwargs)
+
+    def build_single_message_prompt(self, role, metadata, message):
+        assert role in ["system", "user", "assistant", "observation"], role
+        return (
+            f"<|{role}|>"
+            f"{self.dummy_prefix_indicator}{metadata}\n"
+            f"{self.dummy_prefix_indicator}{message}"
+        )
+
+    def build_chat_prompt(self, query, history=None, role="user", metadata=""):
+        inputs = []
+
+        for item in history or []:
+            content = item["content"]
+
+            if item["role"] == "system" and "tools" in item:
+                content += "\n" + json.dumps(
+                    item["tools"], indent=4, ensure_ascii=False
+                )
+
+            inputs.append(
+                self.build_single_message_prompt(
+                    item["role"], item.get("metadata", ""), content
+                )
+            )
+
+        inputs.append(self.build_single_message_prompt(role, metadata, query))
+        inputs.append("<|assistant|>")
+
+        return "".join(inputs)
+
+    def build_chat_input(self, *args, **kwargs):
+        return self.batch_encode_plus(
+            [self.build_chat_prompt(*args, **kwargs)],
+            return_tensors="pt",
+        )
+
+
+ChatGLMTokenizer.register_for_auto_class()
+ChatGLMTokenizerFast.register_for_auto_class()
```
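A usage sketch for the new chat helpers: `build_chat_prompt` assembles one flat string, with the dummy-prefix sentinel standing in for the per-message prefixes the slow tokenizer injects as token ids, and `build_chat_input` tokenizes that string in a single fast pass. The repo id is an assumption:

```python
from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained(
    "THUDM/chatglm3-6b", trust_remote_code=True, use_fast=True
)

history = [{"role": "system", "content": "You are a helpful assistant."}]

prompt = tok.build_chat_prompt("Hello there", history=history)
# One flat string: "<|system|><!dummy-prefix!>\n<!dummy-prefix!>You are ..." etc.

inputs = tok.build_chat_input("Hello there", history=history)
print(inputs["input_ids"].shape)  # (1, sequence_length), PyTorch tensors
```

The rest of the hunk adds the converter that teaches the fast backend to reproduce the slow tokenizer: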
```diff
+
+
+class ChatGLMTokenizerConverter(SpmConverter):
+    handle_byte_fallback = True
+
+    def normalizer(self, proto):
+        return normalizers.Sequence(
+            [
+                normalizers.Replace(
+                    pattern=DUMMY_PREFIX_INDICATOR_FOR_FAST, content="▁"
+                ),
+                normalizers.Replace(pattern=" ", content="▁"),
+            ]
+        )
+
+    def pre_tokenizer(self, replacement, add_prefix_space):
+        # NOTE: don't use Metaspace, it won't merge spaces into one token
+        # without Metaspace: "  " => ["▁▁"]
+        # with Metaspace: "  " => ["▁", "▁"]
+        return pre_tokenizers.Split(DUMMY_PREFIX_INDICATOR_FOR_FAST, "merged_with_next")
+
+    def decoder(self, replacement, add_prefix_space):
+        return decoders.Sequence(
+            [
+                decoders.ByteFallback(),
+                decoders.Metaspace(replacement="▁", add_prefix_space=True),
+            ]
+        )
+
+    def tokenizer(self, proto):
+        tokenizer = super().tokenizer(proto)
+
+        tokenizer.model.byte_fallback = True
+
+        assert tokenizer.token_to_id("<unk>") == 0
+        assert tokenizer.token_to_id("<s>") == 1
+        assert tokenizer.token_to_id("</s>") == 2
+        special_tokens = [
+            "<unk>",
+            "<s>",
+            "</s>",
+            *ADDITIONAL_SPECIAL_TOKENS,
+        ]
+
+        tokenizer.add_special_tokens(
+            [AddedToken(token, special=True) for token in special_tokens]
+        )
+
+        return tokenizer
+
+    def converted(self):
+        tokenizer = super().converted()
+
+        # Post processors
+        prefix_token_ids = list(map(tokenizer.token_to_id, PREFIX_TOKENS))
+        assert all(i is not None for i in prefix_token_ids)
+        prefix_template = " ".join(PREFIX_TOKENS)
+
+        template_special_tokens = list(frozenset(zip(PREFIX_TOKENS, prefix_token_ids)))
+
+        if "</s>" not in PREFIX_TOKENS:
+            eos_token_id = tokenizer.token_to_id("</s>")
+            assert eos_token_id is not None
+            template_special_tokens.append(("</s>", eos_token_id))
+
+        post = processors.TemplateProcessing(
+            single=f"{prefix_template} $A",
+            pair=f"{prefix_template} $A $B:1 </s>:1",
+            special_tokens=template_special_tokens,
+        )
+        if tokenizer.post_processor is None:
+            tokenizer.post_processor = post
+        else:
+            tokenizer.post_processor = processors.Sequence(
+                [tokenizer.post_processor, post]
+            )
+
+        return tokenizer
+
+
+SLOW_TO_FAST_CONVERTERS[ChatGLMTokenizer.__name__] = ChatGLMTokenizerConverter
```
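A sanity-check sketch for the two behaviours the converter encodes, assuming the same repo id: `TemplateProcessing` prepends the `[gMASK] <!sop!>` prefix to every encoding, and the `Split`-based pre-tokenizer lets consecutive spaces stay merged in one piece (the reason for the `NOTE` against `Metaspace`):

```python
from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained(
    "THUDM/chatglm3-6b", trust_remote_code=True, use_fast=True
)

# TemplateProcessing: every encoding starts with the prefix tokens.
ids = tok("hello")["input_ids"]
assert ids[:2] == [
    tok.convert_tokens_to_ids("[gMASK]"),
    tok.convert_tokens_to_ids("<!sop!>"),
]

# Split keeps runs of whitespace together, so two spaces can become one "▁▁"
# piece, matching the slow SentencePiece output; Metaspace would emit ["▁", "▁"].
print(tok.tokenize("a  b"))
```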
tokenizer_config.json CHANGED
```diff
@@ -7,7 +7,7 @@
   "auto_map": {
     "AutoTokenizer": [
       "tokenization_chatglm.ChatGLMTokenizer",
-      null
-    ]
+      "tokenization_chatglm.ChatGLMTokenizerFast"
+    ]
   }
 }
```
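With the second `auto_map.AutoTokenizer` slot filled in (it was `null`), `AutoTokenizer` can now resolve the fast class. A quick check, with the repo id as an assumption:

```python
from transformers import AutoTokenizer

slow = AutoTokenizer.from_pretrained(
    "THUDM/chatglm3-6b", trust_remote_code=True, use_fast=False
)
fast = AutoTokenizer.from_pretrained(
    "THUDM/chatglm3-6b", trust_remote_code=True, use_fast=True
)

print(type(slow).__name__)  # ChatGLMTokenizer
print(type(fast).__name__)  # ChatGLMTokenizerFast
```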