Upload tokenization_dart.py
Browse files- tokenization_dart.py +10 -23
tokenization_dart.py
CHANGED
@@ -1,7 +1,5 @@
|
|
1 |
import logging
|
2 |
-
import
|
3 |
-
from typing import Dict, List
|
4 |
-
from pydantic.dataclasses import dataclass
|
5 |
|
6 |
from transformers import PreTrainedTokenizerFast
|
7 |
from tokenizers.decoders import Decoder
|
@@ -39,35 +37,24 @@ PROMPT_TEMPLATE = (
|
|
39 |
"{{ '</character>' }}"
|
40 |
|
41 |
"{{ '<general>' }}"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
42 |
"{% if 'general' not in messages or messages['general'] is none %}"
|
43 |
"{{ '' }}"
|
44 |
"{% else %}"
|
45 |
"{{ messages['general'] }}"
|
46 |
"{% endif %}"
|
|
|
47 |
).strip()
|
48 |
# fmt: on
|
49 |
|
50 |
|
51 |
-
@dataclass
|
52 |
-
class Category:
|
53 |
-
name: str
|
54 |
-
bos_token_id: int
|
55 |
-
eos_token_id: int
|
56 |
-
|
57 |
-
|
58 |
-
@dataclass
|
59 |
-
class TagCategoryConfig:
|
60 |
-
categories: Dict[str, Category]
|
61 |
-
category_to_token_ids: Dict[str, List[int]]
|
62 |
-
|
63 |
-
|
64 |
-
def load_tag_category_config(config_json: str):
|
65 |
-
with open(config_json, "rb") as file:
|
66 |
-
config: TagCategoryConfig = TagCategoryConfig(**json.loads(file.read()))
|
67 |
-
|
68 |
-
return config
|
69 |
-
|
70 |
-
|
71 |
class DartDecoder:
|
72 |
def __init__(self, special_tokens: List[str]):
|
73 |
self.special_tokens = list(special_tokens)
|
|
|
1 |
import logging
|
2 |
+
from typing import List
|
|
|
|
|
3 |
|
4 |
from transformers import PreTrainedTokenizerFast
|
5 |
from tokenizers.decoders import Decoder
|
|
|
37 |
"{{ '</character>' }}"
|
38 |
|
39 |
"{{ '<general>' }}"
|
40 |
+
# length token
|
41 |
+
"{% if 'length' not in messages or messages['length'] is none %}"
|
42 |
+
"{{ '<|long|>' }}"
|
43 |
+
"{% else %}"
|
44 |
+
"{{ messages['length'] }}"
|
45 |
+
"{% endif %}"
|
46 |
+
|
47 |
+
# general token
|
48 |
"{% if 'general' not in messages or messages['general'] is none %}"
|
49 |
"{{ '' }}"
|
50 |
"{% else %}"
|
51 |
"{{ messages['general'] }}"
|
52 |
"{% endif %}"
|
53 |
+
"{{ '<|input_end|>' }}"
|
54 |
).strip()
|
55 |
# fmt: on
|
56 |
|
57 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
58 |
class DartDecoder:
|
59 |
def __init__(self, special_tokens: List[str]):
|
60 |
self.special_tokens = list(special_tokens)
|