import os
from functools import cached_property
from typing import List, Optional, Union, Dict, Tuple, Any

import torch
import cutlet
from hangul_romanize import Transliter
from hangul_romanize.rule import academic
from tokenizers import Tokenizer, processors
from tokenizers.pre_tokenizers import WhitespaceSplit
from tokenizers.processors import TemplateProcessing
from transformers import PreTrainedTokenizerFast
from transformers.tokenization_utils_base import TruncationStrategy, PaddingStrategy

from TTS.tts.layers.xtts.tokenizer import (
    multilingual_cleaners,
    basic_cleaners,
    chinese_transliterate,
    korean_transliterate,
    japanese_cleaners,
)


class XTTSTokenizerFast(PreTrainedTokenizerFast):
    """
    Fast tokenizer for the XTTS model, built on Hugging Face's PreTrainedTokenizerFast.

    Wraps a `tokenizers.Tokenizer` vocabulary and applies XTTS-specific text
    preprocessing (language cleaning, transliteration, the [lang] prefix and
    [SPACE] substitution) before encoding.
    """
    def __init__(
        self,
        vocab_file: Optional[str] = None,
        tokenizer_object: Optional[Tokenizer] = None,
        unk_token: str = "[UNK]",
        pad_token: str = "[PAD]",
        bos_token: str = "[START]",
        eos_token: str = "[STOP]",
        clean_up_tokenization_spaces: bool = True,
        **kwargs
    ):
        if tokenizer_object is None and vocab_file is not None:
            tokenizer_object = Tokenizer.from_file(vocab_file)

        if tokenizer_object is not None:
            # Tokenize on whitespace; input spaces are replaced with the
            # literal [SPACE] token before encoding (see _batch_encode_plus).
            tokenizer_object.pre_tokenizer = WhitespaceSplit()
            tokenizer_object.enable_padding(
                direction="right",
                # Fall back to id 0 if the pad token is not in the vocabulary
                pad_id=tokenizer_object.token_to_id(pad_token) or 0,
                pad_token=pad_token,
            )
            # Wrap every sequence as "[START] ... [STOP]"
            tokenizer_object.post_processor = TemplateProcessing(
                single=f"{bos_token} $A {eos_token}",
                special_tokens=[
                    (bos_token, tokenizer_object.token_to_id(bos_token)),
                    (eos_token, tokenizer_object.token_to_id(eos_token)),
                ],
            )

        super().__init__(
            tokenizer_object=tokenizer_object,
            unk_token=unk_token,
            pad_token=pad_token,
            bos_token=bos_token,
            eos_token=eos_token,
            clean_up_tokenization_spaces=clean_up_tokenization_spaces,
            **kwargs
        )

        # Per-language character limits used by check_input_length(); texts
        # longer than this are only warned about, not truncated here.
        self.char_limits = {
            "en": 250, "de": 253, "fr": 273, "es": 239,
            "it": 213, "pt": 203, "pl": 224, "zh": 82,
            "ar": 166, "cs": 186, "ru": 182, "nl": 251,
            "tr": 226, "ja": 71, "hu": 224, "ko": 95,
        }

        # cutlet (Japanese romanizer) is created lazily via the `katsu` property
        self._katsu = None
        self._korean_transliter = Transliter(academic)

    @cached_property
    def katsu(self):
        # Built lazily so cutlet/MeCab is only loaded when Japanese text is
        # actually tokenized; cached_property keeps the instance afterwards.
        return cutlet.Cutlet()

    def check_input_length(self, text: str, lang: str):
        """Warn if the input text exceeds the character limit for the language."""
        lang = lang.split("-")[0]
        limit = self.char_limits.get(lang, 250)
        if len(text) > limit:
            print(f"Warning: Text length exceeds {limit} char limit for '{lang}', may cause truncation.")

    def preprocess_text(self, text: str, lang: str) -> str:
        """Apply text preprocessing for the given language."""
        if lang in {"ar", "cs", "de", "en", "es", "fr", "hu", "it",
                    "nl", "pl", "pt", "ru", "tr", "zh", "ko"}:
            text = multilingual_cleaners(text, lang)
            if lang == "zh":
                text = chinese_transliterate(text)
            if lang == "ko":
                text = korean_transliterate(text)
        elif lang == "ja":
            text = japanese_cleaners(text, self.katsu)
        else:
            text = basic_cleaners(text)
        return text

    def _batch_encode_plus(
        self,
        batch_text_or_text_pairs,
        add_special_tokens: bool = True,
        padding_strategy=PaddingStrategy.DO_NOT_PAD,
        truncation_strategy=TruncationStrategy.DO_NOT_TRUNCATE,
        max_length: Optional[int] = 402,
        stride: int = 0,
        is_split_into_words: bool = False,
        pad_to_multiple_of: Optional[int] = None,
        return_tensors: Optional[str] = None,
        return_token_type_ids: Optional[bool] = None,
        return_attention_mask: Optional[bool] = None,
        return_overflowing_tokens: bool = False,
        return_special_tokens_mask: bool = False,
        return_offsets_mapping: bool = False,
        return_length: bool = False,
        verbose: bool = True,
        **kwargs
    ) -> Dict[str, Any]:
        """
        Override batch encoding to apply language-specific preprocessing
        before delegating to the fast tokenizer.
        """
        lang = kwargs.pop("lang", ["en"] * len(batch_text_or_text_pairs))
        if isinstance(lang, str):
            lang = [lang] * len(batch_text_or_text_pairs)

        processed_texts = []
        for text, text_lang in zip(batch_text_or_text_pairs, lang):
            if isinstance(text, str):
                # Strip any region suffix (e.g. "zh-cn" -> "zh") so the cleaner
                # tables and character limits are looked up consistently.
                text_lang = text_lang.split("-")[0]
                self.check_input_length(text, text_lang)
                processed_text = self.preprocess_text(text, text_lang)

                # Prepend the language tag and encode spaces as the explicit
                # [SPACE] token expected by the vocabulary.
                lang_code = "zh-cn" if text_lang == "zh" else text_lang
                processed_text = f"[{lang_code}]{processed_text}"
                processed_text = processed_text.replace(" ", "[SPACE]")

                processed_texts.append(processed_text)
            else:
                processed_texts.append(text)

        return super()._batch_encode_plus(
            processed_texts,
            add_special_tokens=add_special_tokens,
            padding_strategy=padding_strategy,
            truncation_strategy=truncation_strategy,
            max_length=max_length,
            stride=stride,
            is_split_into_words=is_split_into_words,
            pad_to_multiple_of=pad_to_multiple_of,
            return_tensors=return_tensors,
            return_token_type_ids=return_token_type_ids,
            return_attention_mask=return_attention_mask,
            return_overflowing_tokens=return_overflowing_tokens,
            return_special_tokens_mask=return_special_tokens_mask,
            return_offsets_mapping=return_offsets_mapping,
            return_length=return_length,
            verbose=verbose,
            **kwargs
        )

    def __call__(
        self,
        text: Union[str, List[str]],
        lang: Union[str, List[str]] = "en",
        add_special_tokens: bool = True,
        padding: Union[bool, str, PaddingStrategy] = True,
        truncation: Union[bool, str, TruncationStrategy] = True,
        max_length: Optional[int] = 402,
        stride: int = 0,
        return_tensors: Optional[str] = None,
        return_token_type_ids: Optional[bool] = None,
        return_attention_mask: Optional[bool] = True,
        **kwargs
    ):
        """
        Main tokenization entry point.

        Args:
            text: Text or list of texts to tokenize.
            lang: Language code or list of language codes, one per text.
            add_special_tokens: Whether to add special tokens.
            padding: Padding strategy (default True, which pads to max_length).
            truncation: Truncation strategy (default True).
            max_length: Maximum sequence length in tokens.
            stride: Stride used when truncating with overflow.
            return_tensors: Output tensor format ("pt" for PyTorch).
            return_token_type_ids: Whether to return token type IDs.
            return_attention_mask: Whether to return the attention mask (default True).
        """
        if isinstance(text, str):
            text = [text]
        if isinstance(lang, str):
            lang = [lang]

        if len(text) != len(lang):
            raise ValueError(
                f"Number of texts ({len(text)}) must match number of language codes ({len(lang)})"
            )

        # A bare `padding=True` pads every sequence to max_length rather than
        # to the longest sequence in the batch.
        if isinstance(padding, bool):
            padding_strategy = PaddingStrategy.MAX_LENGTH if padding else PaddingStrategy.DO_NOT_PAD
        else:
            padding_strategy = PaddingStrategy(padding)

        if isinstance(truncation, bool):
            truncation_strategy = TruncationStrategy.LONGEST_FIRST if truncation else TruncationStrategy.DO_NOT_TRUNCATE
        else:
            truncation_strategy = TruncationStrategy(truncation)

        encoded = self._batch_encode_plus(
            text,
            add_special_tokens=add_special_tokens,
            padding_strategy=padding_strategy,
            truncation_strategy=truncation_strategy,
            max_length=max_length,
            stride=stride,
            return_tensors=return_tensors,
            return_token_type_ids=return_token_type_ids,
            return_attention_mask=return_attention_mask,
            lang=lang,
            **kwargs
        )

        return encoded
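

# A minimal usage sketch, not part of the tokenizer itself. The vocab path is a
# placeholder: point it at the vocab.json shipped with your XTTS checkpoint.
# It assumes the packages imported above (TTS, cutlet, hangul_romanize) are
# installed; output keys and shapes follow PreTrainedTokenizerFast conventions.
if __name__ == "__main__":
    tokenizer = XTTSTokenizerFast(vocab_file="path/to/xtts/vocab.json")  # placeholder path

    batch = tokenizer(
        ["Hello world, this is a test.", "Hallo Welt, das ist ein Test."],
        lang=["en", "de"],
        return_tensors="pt",
    )
    # input_ids are padded to max_length (402 by default); attention_mask
    # marks the non-padding positions.
    print(batch["input_ids"].shape)
    print(batch["attention_mask"].shape)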