|
from typing import List |
|
|
|
from torchtune.modules.tokenizers import TikTokenTokenizer |
|
from torchtune.modules.tokenizers._utils import _split_long_repetitions |
|
from torchtune.modules.tokenizers._tiktoken import ( |
|
MAX_ENCODE_CHARS, |
|
MAX_NO_WHITESPACE_CHARS, |
|
ALL_SPECIAL_TOKENS, |
|
) |
|
|
|
|
|
|
|
# Delimiter special tokens marking the boundaries of non-text modalities
# in an any-to-any (A2A) multimodal token stream.
START_IMAGE = "<|start_image|>"
END_IMAGE = "<|end_image|>"
START_VIDEO = "<|start_video|>"
END_VIDEO = "<|end_video|>"
START_AUDIO = "<|start_audio|>"
END_AUDIO = "<|end_audio|>"

# Full special-token list for the A2A tokenizer: the base tokenizer's
# special tokens with the six multimodal delimiters spliced in just before
# the last two entries of ALL_SPECIAL_TOKENS. The splice preserves the
# positions of those final two tokens — presumably their ids are fixed by
# convention in torchtune's tiktoken vocabulary; TODO confirm against
# torchtune.modules.tokenizers._tiktoken.ALL_SPECIAL_TOKENS ordering.
A2A_SPECIAL_TOKENS = ALL_SPECIAL_TOKENS[:-2] + [
    START_IMAGE,
    END_IMAGE,
    START_VIDEO,
    END_VIDEO,
    START_AUDIO,
    END_AUDIO,
] + ALL_SPECIAL_TOKENS[-2:]
|
|
|
|
|
class A2ATokenizer(TikTokenTokenizer):
    """TikToken-based tokenizer that additionally recognizes the any-to-any
    (A2A) multimodal delimiter tokens (image/video/audio start/end markers)
    as allowed special tokens during encoding.
    """

    def encode(
        self,
        text: str,
        add_bos: bool,
        add_eos: bool,
    ) -> List[int]:
        """
        Encode a string into a list of token ids.

        Unlike the base tokenizer, the six multimodal delimiter tokens
        (``<|start_image|>``, ``<|end_image|>``, ``<|start_video|>``,
        ``<|end_video|>``, ``<|start_audio|>``, ``<|end_audio|>``) are
        permitted to appear in ``text`` and are encoded to their
        special-token ids. All other special-token text is encoded as
        ordinary text (``disallowed_special=()``) rather than raising.

        Args:
            text (str): The string to encode.
            add_bos (bool): Whether to add the beginning of sequence token.
            add_eos (bool): Whether to add the end of sequence token.

        Returns:
            List[int]: The list of token ids.
        """
        # Loop-invariant: build the allowed-special set once, not per chunk.
        allowed_special = {
            START_IMAGE,
            END_IMAGE,
            START_VIDEO,
            END_VIDEO,
            START_AUDIO,
            END_AUDIO,
        }
        # Chunk the input so tiktoken never sees an overly long string, and
        # further split chunks with long whitespace-free runs (see
        # _split_long_repetitions) before encoding.
        substrs: List[str] = []
        tokens: List[int] = []
        for i in range(0, len(text), MAX_ENCODE_CHARS):
            substr = text[i : i + MAX_ENCODE_CHARS]
            sliced_substr = _split_long_repetitions(substr, MAX_NO_WHITESPACE_CHARS)
            substrs.extend(sliced_substr)
        for substr in substrs:
            tokens.extend(
                self.tt_model.encode(
                    substr,
                    allowed_special=allowed_special,
                    # Empty tuple: never raise on other special-token text;
                    # it is tokenized as plain text instead.
                    disallowed_special=(),
                )
            )
        if add_bos:
            tokens.insert(0, self.bos_id)
        if add_eos:
            tokens.append(self.eos_id)
        return tokens
|
|
|
|
|
def a2a_tokenizer(path: str) -> TikTokenTokenizer:
    """Build an A2A (any-to-any multimodal) tokenizer from a tiktoken model file.

    Args:
        path (str): Path to the tiktoken BPE model file on disk.

    Returns:
        TikTokenTokenizer: An ``A2ATokenizer`` registered with the A2A
        special tokens and with ``pad_id`` set to 0.
    """
    # Renamed from ``tiktoken``: a local shadowing the well-known library
    # name invites confusion with the actual tiktoken package.
    tokenizer = A2ATokenizer(path, all_special_tokens=A2A_SPECIAL_TOKENS)
    # Pad id fixed at 0 — presumably the convention expected by downstream
    # collation; confirm against the model's vocabulary.
    tokenizer.pad_id = 0
    return tokenizer