from typing import List, Optional, Union, Dict, Any
from functools import cached_property

from transformers import PreTrainedTokenizerFast
from transformers.tokenization_utils_base import TruncationStrategy, PaddingStrategy
from tokenizers import Tokenizer
from tokenizers.pre_tokenizers import WhitespaceSplit
from tokenizers.processors import TemplateProcessing

from hangul_romanize import Transliter
from hangul_romanize.rule import academic
import cutlet

from TTS.tts.layers.xtts.tokenizer import (
    multilingual_cleaners,
    basic_cleaners,
    chinese_transliterate,
    korean_transliterate,
    japanese_cleaners,
)

class XTTSTokenizerFast(PreTrainedTokenizerFast):
    """
    Fast tokenizer implementation for the XTTS model, built on HuggingFace's PreTrainedTokenizerFast.
    """

    def __init__(
        self,
        vocab_file: Optional[str] = None,
        tokenizer_object: Optional[Tokenizer] = None,
        unk_token: str = "[UNK]",
        pad_token: str = "[PAD]",
        bos_token: str = "[START]",
        eos_token: str = "[STOP]",
        clean_up_tokenization_spaces: bool = True,
        **kwargs
    ):
        if tokenizer_object is None and vocab_file is not None:
            tokenizer_object = Tokenizer.from_file(vocab_file)

        if tokenizer_object is not None:
            # Configure the underlying tokenizers.Tokenizer: whitespace pre-tokenization,
            # right padding with the pad token, and [START] ... [STOP] wrapping.
            tokenizer_object.pre_tokenizer = WhitespaceSplit()
            tokenizer_object.enable_padding(
                direction='right',
                pad_id=tokenizer_object.token_to_id(pad_token) or 0,  # fall back to 0 if missing
                pad_token=pad_token
            )
            tokenizer_object.post_processor = TemplateProcessing(
                single=f"{bos_token} $A {eos_token}",
                special_tokens=[
                    (bos_token, tokenizer_object.token_to_id(bos_token)),
                    (eos_token, tokenizer_object.token_to_id(eos_token)),
                ],
            )

        super().__init__(
            tokenizer_object=tokenizer_object,
            unk_token=unk_token,
            pad_token=pad_token,
            bos_token=bos_token,
            eos_token=eos_token,
            clean_up_tokenization_spaces=clean_up_tokenization_spaces,
            **kwargs
        )

        # Character limits per language (used only to warn about likely truncation)
        self.char_limits = {
            "en": 250, "de": 253, "fr": 273, "es": 239,
            "it": 213, "pt": 203, "pl": 224, "zh": 82,
            "ar": 166, "cs": 186, "ru": 182, "nl": 251,
            "tr": 226, "ja": 71, "hu": 224, "ko": 95,
        }

        # Initialize language tools
        self._katsu = None
        self._korean_transliter = Transliter(academic)

    @cached_property
    def katsu(self):
        # Lazily construct the cutlet romanizer; it is only needed for Japanese input.
        if self._katsu is None:
            self._katsu = cutlet.Cutlet()
        return self._katsu
    def check_input_length(self, text: str, lang: str):
        """Warn if the input text exceeds the character limit for the given language."""
        lang = lang.split("-")[0]  # strip region suffix, e.g. "zh-cn" -> "zh"
        limit = self.char_limits.get(lang, 250)
        if len(text) > limit:
            print(f"Warning: Text length exceeds {limit} char limit for '{lang}', may cause truncation.")

    def preprocess_text(self, text: str, lang: str) -> str:
        """Apply language-specific cleaning and transliteration to the text."""
        lang = lang.split("-")[0]  # accept region codes such as "zh-cn" as well
        if lang in {"ar", "cs", "de", "en", "es", "fr", "hu", "it",
                    "nl", "pl", "pt", "ru", "tr", "zh", "ko"}:
            text = multilingual_cleaners(text, lang)
            if lang == "zh":
                text = chinese_transliterate(text)
            if lang == "ko":
                text = korean_transliterate(text)
        elif lang == "ja":
            text = japanese_cleaners(text, self.katsu)
        else:
            text = basic_cleaners(text)
        return text
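
    # Illustrative sketch of the preprocessing step. The exact strings depend on the
    # XTTS cleaners/transliterators imported from TTS, so treat these as approximations:
    #   preprocess_text("Hello, World!", "en")  -> roughly "hello, world!"
    #   preprocess_text("안녕하세요", "ko")      -> a romanized form via korean_transliterate
    #   preprocess_text("こんにちは", "ja")      -> romaji produced by cutlet via japanese_cleaners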

    def _batch_encode_plus(
        self,
        batch_text_or_text_pairs,
        add_special_tokens: bool = True,
        padding_strategy=PaddingStrategy.DO_NOT_PAD,
        truncation_strategy=TruncationStrategy.DO_NOT_TRUNCATE,
        max_length: Optional[int] = 402,
        stride: int = 0,
        is_split_into_words: bool = False,
        pad_to_multiple_of: Optional[int] = None,
        return_tensors: Optional[str] = None,
        return_token_type_ids: Optional[bool] = None,
        return_attention_mask: Optional[bool] = None,
        return_overflowing_tokens: bool = False,
        return_special_tokens_mask: bool = False,
        return_offsets_mapping: bool = False,
        return_length: bool = False,
        verbose: bool = True,
        **kwargs
    ) -> Dict[str, Any]:
        """
        Override batch encoding to apply language-specific preprocessing before tokenization.
        """
        lang = kwargs.pop("lang", ["en"] * len(batch_text_or_text_pairs))
        if isinstance(lang, str):
            lang = [lang] * len(batch_text_or_text_pairs)

        # Preprocess each text in the batch with its corresponding language
        processed_texts = []
        for text, text_lang in zip(batch_text_or_text_pairs, lang):
            if isinstance(text, str):
                # Check length and apply language-specific cleaning
                self.check_input_length(text, text_lang)
                processed_text = self.preprocess_text(text, text_lang)

                # Prepend the language tag and encode spaces as the [SPACE] token
                lang_code = "zh-cn" if text_lang == "zh" else text_lang
                processed_text = f"[{lang_code}]{processed_text}"
                processed_text = processed_text.replace(" ", "[SPACE]")
                processed_texts.append(processed_text)
            else:
                processed_texts.append(text)

        # Call the parent class's encoding method with the processed texts
        return super()._batch_encode_plus(
            processed_texts,
            add_special_tokens=add_special_tokens,
            padding_strategy=padding_strategy,
            truncation_strategy=truncation_strategy,
            max_length=max_length,
            stride=stride,
            is_split_into_words=is_split_into_words,
            pad_to_multiple_of=pad_to_multiple_of,
            return_tensors=return_tensors,
            return_token_type_ids=return_token_type_ids,
            return_attention_mask=return_attention_mask,
            return_overflowing_tokens=return_overflowing_tokens,
            return_special_tokens_mask=return_special_tokens_mask,
            return_offsets_mapping=return_offsets_mapping,
            return_length=return_length,
            verbose=verbose,
            **kwargs
        )
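
    # Sketch of what the override feeds to the base tokenizer (illustrative values,
    # since the cleaned text depends on multilingual_cleaners):
    #   input:  "Hello world", lang="en"
    #   after preprocessing/formatting: "[en]hello[SPACE]world"
    # The TemplateProcessing post-processor then wraps the token sequence as
    # "[START] ... [STOP]" before any padding is applied.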

    def __call__(
        self,
        text: Union[str, List[str]],
        lang: Union[str, List[str]] = "en",
        add_special_tokens: bool = True,
        padding: Union[bool, str, PaddingStrategy] = True,
        truncation: Union[bool, str, TruncationStrategy] = True,
        max_length: Optional[int] = 402,
        stride: int = 0,
        return_tensors: Optional[str] = None,
        return_token_type_ids: Optional[bool] = None,
        return_attention_mask: Optional[bool] = True,
        **kwargs
    ):
        """
        Main tokenization method.

        Args:
            text: Text or list of texts to tokenize
            lang: Language code or list of language codes, one per text
            add_special_tokens: Whether to add special tokens
            padding: Padding strategy (defaults to True, i.e. pad to max_length)
            truncation: Truncation strategy (defaults to True)
            max_length: Maximum sequence length in tokens
            stride: Stride for handling overflow during truncation
            return_tensors: Output tensor format ("pt" for PyTorch)
            return_token_type_ids: Whether to return token type IDs
            return_attention_mask: Whether to return the attention mask (defaults to True)
        """
        # Convert single string inputs to lists for batch processing
        if isinstance(text, str):
            text = [text]
        if isinstance(lang, str):
            lang = [lang]

        # Ensure the text and lang lists have the same length
        if len(text) != len(lang):
            raise ValueError(f"Number of texts ({len(text)}) must match number of language codes ({len(lang)})")

        # Convert the padding flag: here True means fixed-length padding to max_length
        if isinstance(padding, bool):
            padding_strategy = PaddingStrategy.MAX_LENGTH if padding else PaddingStrategy.DO_NOT_PAD
        else:
            padding_strategy = PaddingStrategy(padding)

        # Convert the truncation flag
        if isinstance(truncation, bool):
            truncation_strategy = TruncationStrategy.LONGEST_FIRST if truncation else TruncationStrategy.DO_NOT_TRUNCATE
        else:
            truncation_strategy = TruncationStrategy(truncation)

        # Delegate to the batch encoding method
        encoded = self._batch_encode_plus(
            text,
            add_special_tokens=add_special_tokens,
            padding_strategy=padding_strategy,
            truncation_strategy=truncation_strategy,
            max_length=max_length,
            stride=stride,
            return_tensors=return_tensors,
            return_token_type_ids=return_token_type_ids,
            return_attention_mask=return_attention_mask,
            lang=lang,
            **kwargs
        )

        return encoded
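

# Minimal usage sketch (not part of the tokenizer itself). The vocab path below is a
# placeholder: in practice it should point to the vocab.json shipped with an XTTS checkpoint.
if __name__ == "__main__":
    # Hypothetical path; replace with the actual XTTS vocab.json location.
    tokenizer = XTTSTokenizerFast(vocab_file="path/to/xtts/vocab.json")

    batch = tokenizer(
        ["Hello world", "Bonjour tout le monde"],
        lang=["en", "fr"],
        return_tensors="pt",
    )
    # With the default padding=True and max_length=402, both tensors have shape (2, 402).
    print(batch["input_ids"].shape, batch["attention_mask"].shape)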