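"""Tokenizer for A2A multimodal models.

Wraps torchtune's TikTokenTokenizer and adds special tokens that delimit
image, video, and audio segments.
"""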
from typing import List
from torchtune.modules.tokenizers import TikTokenTokenizer
from torchtune.modules.tokenizers._utils import _split_long_repetitions
from torchtune.modules.tokenizers._tiktoken import (
    MAX_ENCODE_CHARS,
    MAX_NO_WHITESPACE_CHARS,
    ALL_SPECIAL_TOKENS,
)

# Reuse the special tokens from TikTokenTokenizer and add delimiters for
# multimodal (image/video/audio) content.
START_IMAGE = "<|start_image|>"
END_IMAGE = "<|end_image|>"
START_VIDEO = "<|start_video|>"
END_VIDEO = "<|end_video|>"
START_AUDIO = "<|start_audio|>"
END_AUDIO = "<|end_audio|>"
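
# Insert the multimodal delimiters just before the final two entries of
# ALL_SPECIAL_TOKENS so those original entries stay at the end of the list.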
A2A_SPECIAL_TOKENS = ALL_SPECIAL_TOKENS[:-2] + [
    START_IMAGE,
    END_IMAGE,
    START_VIDEO,
    END_VIDEO,
    START_AUDIO,
    END_AUDIO,
] + ALL_SPECIAL_TOKENS[-2:]


# Override encode() so that the multimodal delimiters (the image, video, and
# audio start/end tokens) can be encoded as special tokens.
class A2ATokenizer(TikTokenTokenizer):
    def encode(
        self,
        text: str,
        add_bos: bool,
        add_eos: bool,
    ) -> List[int]:
"""
Encode a string into a list of token ids. Assumes that the string
contains no special tokens.
Args:
text (str): The string to encode.
add_bos (bool): Whether to add the beginning of sequence token.
add_eos (bool): Whether to add the end of sequence token.
Returns:
List[int]: The list of token ids.
"""
        substrs: List[str] = []
        tokens = []
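        # Chunk the text so tiktoken never encodes one very long string, and
        # further split long whitespace-free runs, which can trigger
        # pathological regex behavior (see the linked tiktoken issue below).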
        for i in range(0, len(text), MAX_ENCODE_CHARS):
            substr = text[i : i + MAX_ENCODE_CHARS]
            # See https://github.com/openai/tiktoken/issues/195
            sliced_substr = _split_long_repetitions(substr, MAX_NO_WHITESPACE_CHARS)
            substrs.extend(sliced_substr)
        for substr in substrs:
            # allowed_special and disallowed_special control how tiktoken handles
            # special tokens: the multimodal delimiters listed below are encoded
            # to their special token ids, while any other special-token text is
            # encoded as regular text rather than raising an error.
            tokens.extend(
                self.tt_model.encode(
                    substr,
                    allowed_special={
                        START_IMAGE,
                        END_IMAGE,
                        START_VIDEO,
                        END_VIDEO,
                        START_AUDIO,
                        END_AUDIO,
                    },
                    disallowed_special=(),
                )
            )
        if add_bos:
            tokens.insert(0, self.bos_id)
        if add_eos:
            tokens.append(self.eos_id)
        return tokens


def a2a_tokenizer(path: str) -> TikTokenTokenizer:
    tiktoken = A2ATokenizer(path, all_special_tokens=A2A_SPECIAL_TOKENS)
    tiktoken.pad_id = 0
    return tiktoken
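

# Example usage (illustrative sketch; the tokenizer model path below is an
# assumption, not a file shipped with this module):
#
#   tokenizer = a2a_tokenizer("/path/to/tokenizer.model")
#   ids = tokenizer.encode(
#       f"{START_IMAGE}<image placeholder>{END_IMAGE} A caption.",
#       add_bos=True,
#       add_eos=True,
#   )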