# xtts2-gpt / tokenizer.py
# mlinmg's picture
# Upload 8 files
# 7eebd5c verified
from typing import List, Optional, Union, Dict, Tuple, Any
import os
from functools import cached_property
from transformers import PreTrainedTokenizerFast
from transformers.tokenization_utils_base import TruncationStrategy, PaddingStrategy
from tokenizers import Tokenizer, processors
from tokenizers.pre_tokenizers import WhitespaceSplit
from tokenizers.processors import TemplateProcessing
import torch
from hangul_romanize import Transliter
from hangul_romanize.rule import academic
import cutlet
from TTS.tts.layers.xtts.tokenizer import (multilingual_cleaners, basic_cleaners,
chinese_transliterate, korean_transliterate,
japanese_cleaners)
class XTTSTokenizerFast(PreTrainedTokenizerFast):
    """
    Fast Tokenizer implementation for XTTS model using HuggingFace's PreTrainedTokenizerFast.

    On top of the parent class, every string in a batch is run through a
    language-specific text cleaner, prefixed with a ``[lang]`` tag, and has its
    spaces replaced by the literal ``[SPACE]`` marker before being encoded.
    """

    # Languages that ``preprocess_text`` routes through ``multilingual_cleaners``.
    _MULTILINGUAL_CLEANER_LANGS = frozenset({
        "ar", "cs", "de", "en", "es", "fr", "hu", "it",
        "nl", "pl", "pt", "ru", "tr", "zh", "ko",
    })

    def __init__(
        self,
        vocab_file: Optional[str] = None,
        tokenizer_object: Optional[Tokenizer] = None,
        unk_token: str = "[UNK]",
        pad_token: str = "[PAD]",
        bos_token: str = "[START]",
        eos_token: str = "[STOP]",
        clean_up_tokenization_spaces: bool = True,
        **kwargs
    ):
        """
        Build the tokenizer.

        Args:
            vocab_file: Path handed to ``Tokenizer.from_file`` when no
                ``tokenizer_object`` is supplied.
            tokenizer_object: Already-built ``tokenizers.Tokenizer``; takes
                precedence over ``vocab_file``.
            unk_token: Unknown token string.
            pad_token: Padding token string.
            bos_token: Token placed before every encoded sequence.
            eos_token: Token placed after every encoded sequence.
            clean_up_tokenization_spaces: Forwarded to the parent constructor.
            **kwargs: Forwarded to the parent constructor.
        """
        if tokenizer_object is None and vocab_file is not None:
            tokenizer_object = Tokenizer.from_file(vocab_file)

        if tokenizer_object is not None:
            # Configure the tokenizer
            tokenizer_object.pre_tokenizer = WhitespaceSplit()

            # Fall back to id 0 when the pad token is not in the vocabulary.
            pad_id = tokenizer_object.token_to_id(pad_token)
            tokenizer_object.enable_padding(
                direction='right',
                pad_id=pad_id if pad_id is not None else 0,
                pad_token=pad_token
            )

            # Wrap every single sequence as "<bos> text <eos>".
            tokenizer_object.post_processor = TemplateProcessing(
                single=f"{bos_token} $A {eos_token}",
                special_tokens=[
                    (bos_token, tokenizer_object.token_to_id(bos_token)),
                    (eos_token, tokenizer_object.token_to_id(eos_token)),
                ],
            )

        super().__init__(
            tokenizer_object=tokenizer_object,
            unk_token=unk_token,
            pad_token=pad_token,
            bos_token=bos_token,
            eos_token=eos_token,
            clean_up_tokenization_spaces=clean_up_tokenization_spaces,
            **kwargs
        )

        # Character limits per language
        self.char_limits = {
            "en": 250, "de": 253, "fr": 273, "es": 239,
            "it": 213, "pt": 203, "pl": 224, "zh": 82,
            "ar": 166, "cs": 186, "ru": 182, "nl": 251,
            "tr": 226, "ja": 71, "hu": 224, "ko": 95,
        }

        # Initialize language tools
        self._katsu = None  # created lazily by the ``katsu`` property
        self._korean_transliter = Transliter(academic)

    @staticmethod
    def _strip_region(lang: str) -> str:
        """Return ``lang`` without its region suffix (``"zh-cn"`` -> ``"zh"``)."""
        return lang.split("-")[0]

    @property
    def katsu(self):
        """``cutlet.Cutlet`` instance, created on first access and then reused."""
        if self._katsu is None:
            self._katsu = cutlet.Cutlet()
        return self._katsu

    def check_input_length(self, text: str, lang: str):
        """
        Check if input text length is within limits for language.

        Prints a warning when ``text`` is longer than the limit registered in
        ``self.char_limits`` for the language (250 when the language has no
        entry). The text itself is not modified.
        """
        lang = self._strip_region(lang)  # remove region
        limit = self.char_limits.get(lang, 250)
        if len(text) > limit:
            print(f"Warning: Text length exceeds {limit} char limit for '{lang}', may cause truncation.")

    def preprocess_text(self, text: str, lang: str) -> str:
        """
        Apply text preprocessing for language.

        Args:
            text: Raw input text.
            lang: Language code; a region suffix (e.g. ``"zh-cn"``) is ignored.

        Returns:
            The cleaned text. Languages without a dedicated cleaner fall back
            to ``basic_cleaners``.
        """
        # Region-qualified codes must select the same cleaner as the bare code;
        # otherwise "zh-cn" would silently skip the Chinese transliteration.
        lang = self._strip_region(lang)

        if lang in self._MULTILINGUAL_CLEANER_LANGS:
            text = multilingual_cleaners(text, lang)
            if lang == "zh":
                text = chinese_transliterate(text)
            if lang == "ko":
                text = korean_transliterate(text)
        elif lang == "ja":
            text = japanese_cleaners(text, self.katsu)
        else:
            text = basic_cleaners(text)
        return text

    def _batch_encode_plus(
        self,
        batch_text_or_text_pairs,
        add_special_tokens: bool = True,
        padding_strategy = PaddingStrategy.DO_NOT_PAD,
        truncation_strategy = TruncationStrategy.DO_NOT_TRUNCATE,
        max_length: Optional[int] = 402,
        stride: int = 0,
        is_split_into_words: bool = False,
        pad_to_multiple_of: Optional[int] = None,
        return_tensors: Optional[str] = None,
        return_token_type_ids: Optional[bool] = None,
        return_attention_mask: Optional[bool] = None,
        return_overflowing_tokens: bool = False,
        return_special_tokens_mask: bool = False,
        return_offsets_mapping: bool = False,
        return_length: bool = False,
        verbose: bool = True,
        **kwargs
    ) -> Dict[str, Any]:
        """
        Override batch encoding to handle language-specific preprocessing.

        The extra keyword ``lang`` (a single code or one code per batch item,
        default ``"en"``) is consumed here and not forwarded to the parent.
        String items are length-checked, cleaned, prefixed with their language
        tag and have spaces replaced by ``[SPACE]``; non-string items are
        forwarded untouched. All other arguments go to the parent
        implementation unchanged.

        Raises:
            ValueError: If ``lang`` is a sequence whose length differs from
                the batch size.
        """
        batch_size = len(batch_text_or_text_pairs)

        lang = kwargs.pop("lang", None)
        if lang is None:
            lang = ["en"] * batch_size
        elif isinstance(lang, str):
            lang = [lang] * batch_size
        elif len(lang) != batch_size:
            # zip() below would otherwise silently drop the unmatched texts.
            raise ValueError(
                f"Number of texts ({batch_size}) must match number of language codes ({len(lang)})"
            )

        # Preprocess each text in the batch with its corresponding language
        processed_texts = []
        for text, text_lang in zip(batch_text_or_text_pairs, lang):
            if not isinstance(text, str):
                processed_texts.append(text)
                continue

            # Drop the region once so that the length check, the cleaner
            # selection and the language tag all agree on the same code.
            base_lang = self._strip_region(text_lang)

            # Check length and preprocess
            self.check_input_length(text, base_lang)
            processed_text = self.preprocess_text(text, base_lang)

            # Format text with language tag and spaces
            lang_code = "zh-cn" if base_lang == "zh" else base_lang
            processed_text = f"[{lang_code}]{processed_text}"
            processed_text = processed_text.replace(" ", "[SPACE]")
            processed_texts.append(processed_text)

        # Call the parent class's encoding method with processed texts
        return super()._batch_encode_plus(
            processed_texts,
            add_special_tokens=add_special_tokens,
            padding_strategy=padding_strategy,
            truncation_strategy=truncation_strategy,
            max_length=max_length,
            stride=stride,
            is_split_into_words=is_split_into_words,
            pad_to_multiple_of=pad_to_multiple_of,
            return_tensors=return_tensors,
            return_token_type_ids=return_token_type_ids,
            return_attention_mask=return_attention_mask,
            return_overflowing_tokens=return_overflowing_tokens,
            return_special_tokens_mask=return_special_tokens_mask,
            return_offsets_mapping=return_offsets_mapping,
            return_length=return_length,
            verbose=verbose,
            **kwargs
        )

    def __call__(
        self,
        text: Union[str, List[str]],
        lang: Union[str, List[str]] = "en",
        add_special_tokens: bool = True,
        padding: Union[bool, str, PaddingStrategy] = True,  # Changed default to True
        truncation: Union[bool, str, TruncationStrategy] = True,  # Changed default to True
        max_length: Optional[int] = 402,
        stride: int = 0,
        return_tensors: Optional[str] = None,
        return_token_type_ids: Optional[bool] = None,
        return_attention_mask: Optional[bool] = True,  # Changed default to True
        **kwargs
    ):
        """
        Main tokenization method

        Args:
            text: Text or list of texts to tokenize
            lang: Language code or list of language codes corresponding to each
                text; a single code is applied to every text
            add_special_tokens: Whether to add special tokens
            padding: Padding strategy (default True). ``True`` selects
                ``PaddingStrategy.MAX_LENGTH`` and ``False`` selects
                ``PaddingStrategy.DO_NOT_PAD``; other values are converted
                with ``PaddingStrategy(padding)``
            truncation: Truncation strategy (default True). ``True`` selects
                ``TruncationStrategy.LONGEST_FIRST`` and ``False`` selects
                ``TruncationStrategy.DO_NOT_TRUNCATE``; other values are
                converted with ``TruncationStrategy(truncation)``
            max_length: Maximum length
            stride: Stride for truncation
            return_tensors: Format of output tensors ("pt" for PyTorch)
            return_token_type_ids: Whether to return token type IDs
            return_attention_mask: Whether to return attention mask (default True)
            **kwargs: Forwarded to ``_batch_encode_plus``

        Returns:
            The result of ``_batch_encode_plus`` for the given texts.

        Raises:
            ValueError: If ``lang`` is a list whose length differs from the
                number of texts.
        """
        # Convert single string to list for batch processing
        if isinstance(text, str):
            text = [text]
        if isinstance(lang, str):
            # Broadcast a single code over the whole batch so the default
            # lang="en" also works for a list of several texts.
            lang = [lang] * len(text)

        # Ensure text and lang lists have same length
        if len(text) != len(lang):
            raise ValueError(f"Number of texts ({len(text)}) must match number of language codes ({len(lang)})")

        # Convert padding strategy
        # NOTE(review): True maps to MAX_LENGTH here, not to "longest" as in
        # the stock HuggingFace __call__ — kept as-is; confirm it is intended.
        if isinstance(padding, bool):
            padding_strategy = PaddingStrategy.MAX_LENGTH if padding else PaddingStrategy.DO_NOT_PAD
        else:
            padding_strategy = PaddingStrategy(padding)

        # Convert truncation strategy
        if isinstance(truncation, bool):
            truncation_strategy = TruncationStrategy.LONGEST_FIRST if truncation else TruncationStrategy.DO_NOT_TRUNCATE
        else:
            truncation_strategy = TruncationStrategy(truncation)

        # Use the batch encoding method
        encoded = self._batch_encode_plus(
            text,
            add_special_tokens=add_special_tokens,
            padding_strategy=padding_strategy,
            truncation_strategy=truncation_strategy,
            max_length=max_length,
            stride=stride,
            return_tensors=return_tensors,
            return_token_type_ids=return_token_type_ids,
            return_attention_mask=return_attention_mask,
            lang=lang,
            **kwargs
        )
        return encoded