from typing import List, Optional, Union, Dict, Tuple, Any
import os
from functools import cached_property

from transformers import PreTrainedTokenizerFast
from transformers.tokenization_utils_base import TruncationStrategy, PaddingStrategy
from tokenizers import Tokenizer, processors
from tokenizers.pre_tokenizers import WhitespaceSplit
from tokenizers.processors import TemplateProcessing
import torch
from hangul_romanize import Transliter
from hangul_romanize.rule import academic
import cutlet

from TTS.tts.layers.xtts.tokenizer import (
    multilingual_cleaners,
    basic_cleaners,
    chinese_transliterate,
    korean_transliterate,
    japanese_cleaners,
)


class XTTSTokenizerFast(PreTrainedTokenizerFast):
    """
    Fast tokenizer implementation for the XTTS model, built on HuggingFace's
    PreTrainedTokenizerFast.
    """

    def __init__(
        self,
        vocab_file: Optional[str] = None,
        tokenizer_object: Optional[Tokenizer] = None,
        unk_token: str = "[UNK]",
        pad_token: str = "[PAD]",
        bos_token: str = "[START]",
        eos_token: str = "[STOP]",
        clean_up_tokenization_spaces: bool = True,
        **kwargs
    ):
        if tokenizer_object is None and vocab_file is not None:
            tokenizer_object = Tokenizer.from_file(vocab_file)

        if tokenizer_object is not None:
            # Configure the tokenizer: split on whitespace, pad on the right,
            # and wrap every sequence with the BOS/EOS special tokens.
            tokenizer_object.pre_tokenizer = WhitespaceSplit()
            tokenizer_object.enable_padding(
                direction='right',
                pad_id=tokenizer_object.token_to_id(pad_token) or 0,
                pad_token=pad_token
            )
            tokenizer_object.post_processor = TemplateProcessing(
                single=f"{bos_token} $A {eos_token}",
                special_tokens=[
                    (bos_token, tokenizer_object.token_to_id(bos_token)),
                    (eos_token, tokenizer_object.token_to_id(eos_token)),
                ],
            )

        super().__init__(
            tokenizer_object=tokenizer_object,
            unk_token=unk_token,
            pad_token=pad_token,
            bos_token=bos_token,
            eos_token=eos_token,
            clean_up_tokenization_spaces=clean_up_tokenization_spaces,
            **kwargs
        )

        # Character limits per language
        self.char_limits = {
            "en": 250, "de": 253, "fr": 273, "es": 239,
            "it": 213, "pt": 203, "pl": 224, "zh": 82,
            "ar": 166, "cs": 186, "ru": 182, "nl": 251,
            "tr": 226, "ja": 71, "hu": 224, "ko": 95,
        }

        # Initialize language tools
        self._katsu = None
        self._korean_transliter = Transliter(academic)

    @cached_property
    def katsu(self):
        # cached_property memoizes the Cutlet instance after the first access.
        if self._katsu is None:
            self._katsu = cutlet.Cutlet()
        return self._katsu

    def check_input_length(self, text: str, lang: str):
        """Check whether the input text length is within the limit for the language."""
        lang = lang.split("-")[0]  # remove region suffix, e.g. "zh-cn" -> "zh"
        limit = self.char_limits.get(lang, 250)
        if len(text) > limit:
            print(f"Warning: Text length exceeds {limit} char limit for '{lang}', may cause truncation.")

    def preprocess_text(self, text: str, lang: str) -> str:
        """Apply language-specific text preprocessing."""
        if lang in {"ar", "cs", "de", "en", "es", "fr", "hu", "it", "nl",
                    "pl", "pt", "ru", "tr", "zh", "ko"}:
            text = multilingual_cleaners(text, lang)
            if lang == "zh":
                text = chinese_transliterate(text)
            if lang == "ko":
                text = korean_transliterate(text)
        elif lang == "ja":
            text = japanese_cleaners(text, self.katsu)
        else:
            text = basic_cleaners(text)
        return text

    def _batch_encode_plus(
        self,
        batch_text_or_text_pairs,
        add_special_tokens: bool = True,
        padding_strategy=PaddingStrategy.DO_NOT_PAD,
        truncation_strategy=TruncationStrategy.DO_NOT_TRUNCATE,
        max_length: Optional[int] = 402,
        stride: int = 0,
        is_split_into_words: bool = False,
        pad_to_multiple_of: Optional[int] = None,
        return_tensors: Optional[str] = None,
        return_token_type_ids: Optional[bool] = None,
        return_attention_mask: Optional[bool] = None,
        return_overflowing_tokens: bool = False,
        return_special_tokens_mask: bool = False,
        return_offsets_mapping: bool = False,
        return_length: bool = False,
        verbose: bool = True,
        **kwargs
    ) -> Dict[str, Any]:
        """
        Override batch encoding to handle language-specific preprocessing.
        """
        lang = kwargs.pop("lang", ["en"] * len(batch_text_or_text_pairs))
        if isinstance(lang, str):
            lang = [lang] * len(batch_text_or_text_pairs)

        # Preprocess each text in the batch with its corresponding language
        processed_texts = []
        for text, text_lang in zip(batch_text_or_text_pairs, lang):
            if isinstance(text, str):
                # Check length and preprocess
                self.check_input_length(text, text_lang)
                processed_text = self.preprocess_text(text, text_lang)

                # Format text with a language tag and encode spaces explicitly
                lang_code = "zh-cn" if text_lang == "zh" else text_lang
                processed_text = f"[{lang_code}]{processed_text}"
                processed_text = processed_text.replace(" ", "[SPACE]")

                processed_texts.append(processed_text)
            else:
                processed_texts.append(text)

        # Call the parent class's encoding method with the processed texts
        return super()._batch_encode_plus(
            processed_texts,
            add_special_tokens=add_special_tokens,
            padding_strategy=padding_strategy,
            truncation_strategy=truncation_strategy,
            max_length=max_length,
            stride=stride,
            is_split_into_words=is_split_into_words,
            pad_to_multiple_of=pad_to_multiple_of,
            return_tensors=return_tensors,
            return_token_type_ids=return_token_type_ids,
            return_attention_mask=return_attention_mask,
            return_overflowing_tokens=return_overflowing_tokens,
            return_special_tokens_mask=return_special_tokens_mask,
            return_offsets_mapping=return_offsets_mapping,
            return_length=return_length,
            verbose=verbose,
            **kwargs
        )

    def __call__(
        self,
        text: Union[str, List[str]],
        lang: Union[str, List[str]] = "en",
        add_special_tokens: bool = True,
        padding: Union[bool, str, PaddingStrategy] = True,  # Changed default to True
        truncation: Union[bool, str, TruncationStrategy] = True,  # Changed default to True
        max_length: Optional[int] = 402,
        stride: int = 0,
        return_tensors: Optional[str] = None,
        return_token_type_ids: Optional[bool] = None,
        return_attention_mask: Optional[bool] = True,  # Changed default to True
        **kwargs
    ):
        """
        Main tokenization method.

        Args:
            text: Text or list of texts to tokenize.
            lang: Language code or list of language codes corresponding to each text.
            add_special_tokens: Whether to add special tokens.
            padding: Padding strategy (default True).
            truncation: Truncation strategy (default True).
            max_length: Maximum sequence length.
            stride: Stride for truncation.
            return_tensors: Format of output tensors ("pt" for PyTorch).
            return_token_type_ids: Whether to return token type IDs.
            return_attention_mask: Whether to return the attention mask (default True).
        """
        # Convert a single string to a list for batch processing
        if isinstance(text, str):
            text = [text]
        if isinstance(lang, str):
            lang = [lang]

        # Ensure the text and lang lists have the same length
        if len(text) != len(lang):
            raise ValueError(
                f"Number of texts ({len(text)}) must match number of language codes ({len(lang)})"
            )

        # Convert padding strategy
        if isinstance(padding, bool):
            padding_strategy = PaddingStrategy.MAX_LENGTH if padding else PaddingStrategy.DO_NOT_PAD
        else:
            padding_strategy = PaddingStrategy(padding)

        # Convert truncation strategy
        if isinstance(truncation, bool):
            truncation_strategy = (
                TruncationStrategy.LONGEST_FIRST if truncation else TruncationStrategy.DO_NOT_TRUNCATE
            )
        else:
            truncation_strategy = TruncationStrategy(truncation)

        # Use the batch encoding method
        encoded = self._batch_encode_plus(
            text,
            add_special_tokens=add_special_tokens,
            padding_strategy=padding_strategy,
            truncation_strategy=truncation_strategy,
            max_length=max_length,
            stride=stride,
            return_tensors=return_tensors,
            return_token_type_ids=return_token_type_ids,
            return_attention_mask=return_attention_mask,
            lang=lang,
            **kwargs
        )

        return encoded
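

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only, not part of the module): assumes an
# XTTS vocabulary file is available at the hypothetical path "vocab.json";
# adjust the path to match your checkpoint layout.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    tokenizer = XTTSTokenizerFast(vocab_file="vocab.json")  # hypothetical path

    batch = tokenizer(
        ["Hello world!", "Bonjour le monde !"],
        lang=["en", "fr"],
        return_tensors="pt",
    )

    # With the default arguments, each text is prefixed with its language tag,
    # spaces are replaced by the [SPACE] token, and the batch is padded to
    # max_length (402) with an attention mask returned alongside.
    print(batch["input_ids"].shape)
    print(batch["attention_mask"].shape)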