from dataclasses import asdict, dataclass, field
from typing import Dict, List, Optional

from transformers.configuration_utils import PretrainedConfig
from transformers.utils import logging

logger = logging.get_logger(__name__)


@dataclass
class XTTSAudioConfig:
    """Configuration for audio processing parameters."""

    sample_rate: int = 22050  # input/processing sample rate (Hz)
    output_sample_rate: int = 24000  # sample rate of the synthesized output (Hz)
    mel_channels: int = 80  # number of mel filterbank channels
    hop_length: int = 256  # STFT hop length in samples
    win_length: int = 1024  # STFT window length in samples
    n_fft: int = 1024  # FFT size
    fmin: int = 0  # lower bound of the mel filterbank (Hz)
    fmax: int = 8000  # upper bound of the mel filterbank (Hz)
    power: float = 1.0  # exponent applied to the spectrogram magnitude
    mel_norms_file: Optional[str] = None  # optional path to a file with mel normalization values


class XTTSGPTConfig(PretrainedConfig):
    """Configuration class for the GPT component of XTTS."""

    model_type = "xtts_gpt"

    def __init__(
        self,
        # Text / character vocabulary
        vocab_size: int = 256,
        num_chars: int = 255,
        # GPT architecture and token limits
        gpt_batch_size: int = 1,
        gpt_max_audio_tokens: int = 605,
        gpt_max_text_tokens: int = 402,
        gpt_max_prompt_tokens: int = 70,
        gpt_layers: int = 30,
        gpt_n_model_channels: int = 1024,
        gpt_n_heads: int = 16,
        gpt_number_text_tokens: int = 6681,
        gpt_start_text_token: Optional[int] = None,
        gpt_stop_text_token: Optional[int] = None,
        gpt_num_audio_tokens: int = 1026,
        gpt_start_audio_token: int = 1024,
        gpt_stop_audio_token: int = 1025,
        gpt_code_stride_len: int = 1024,
        gpt_use_masking_gt_prompt_approach: bool = True,
        gpt_use_perceiver_resampler: bool = True,
        gpt_checkpointing: bool = False,
        gpt_train_solo_embeddings: bool = False,
        # Inference and training behaviour
        enable_redaction: bool = False,
        kv_cache: bool = True,
        perceiver_cond_length_compression: int = 256,
        label_smoothing: float = 0.0,
        # Sampling and voice-conditioning defaults
        temperature: float = 0.75,
        length_penalty: float = 1.0,
        repetition_penalty: float = 5.0,
        top_k: int = 50,
        top_p: float = 0.85,
        gpt_cond_len: int = 30,
        gpt_cond_chunk_len: int = 4,
        max_ref_len: int = 30,
        sound_norm_refs: bool = False,
        # Audio processing
        audio_config: Optional[XTTSAudioConfig] = None,
        # Miscellaneous constants, character limits, languages, and special tokens
        duration_const: int = 102400,
        char_limits: Optional[Dict[str, int]] = None,
        languages: Optional[List[str]] = None,
        pad_token_id: Optional[int] = None,
        bos_token_id: Optional[int] = None,
        eos_token_id: Optional[int] = None,
        **kwargs,
    ):
        # Default per-language character limits.
        if char_limits is None:
            char_limits = {
                "en": 250, "de": 253, "fr": 273, "es": 239,
                "it": 213, "pt": 203, "pl": 224, "zh": 82,
                "ar": 166, "cs": 186, "ru": 182, "nl": 251,
                "tr": 226, "ja": 71, "hu": 224, "ko": 95,
            }

        if languages is None:
            languages = [
                "en", "es", "fr", "de", "it", "pt", "pl", "tr", "ru", "nl",
                "cs", "ar", "zh-cn", "hu", "ko", "ja", "hi",
            ]

        # Accept either an XTTSAudioConfig instance or a plain dict (e.g. when the
        # config is reloaded from JSON); otherwise fall back to the defaults.
        if audio_config is None:
            audio_config = XTTSAudioConfig()
        elif isinstance(audio_config, dict):
            audio_config = XTTSAudioConfig(**audio_config)

        super().__init__(
            pad_token_id=pad_token_id,
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            **kwargs
        )

        self.vocab_size = vocab_size
        self.num_chars = num_chars

        # GPT architecture and token limits
        self.gpt_batch_size = gpt_batch_size
        self.gpt_max_audio_tokens = gpt_max_audio_tokens
        self.gpt_max_text_tokens = gpt_max_text_tokens
        self.gpt_max_prompt_tokens = gpt_max_prompt_tokens
        self.gpt_layers = gpt_layers
        self.gpt_n_model_channels = gpt_n_model_channels
        self.gpt_n_heads = gpt_n_heads
        self.gpt_number_text_tokens = gpt_number_text_tokens
        self.gpt_start_text_token = gpt_start_text_token
        self.gpt_stop_text_token = gpt_stop_text_token
        self.gpt_num_audio_tokens = gpt_num_audio_tokens
        self.gpt_start_audio_token = gpt_start_audio_token
        self.gpt_stop_audio_token = gpt_stop_audio_token
        self.gpt_code_stride_len = gpt_code_stride_len
        self.gpt_use_masking_gt_prompt_approach = gpt_use_masking_gt_prompt_approach
        self.gpt_use_perceiver_resampler = gpt_use_perceiver_resampler
        self.gpt_checkpointing = gpt_checkpointing
        self.gpt_train_solo_embeddings = gpt_train_solo_embeddings

        # Inference and training behaviour
        self.enable_redaction = enable_redaction
        self.kv_cache = kv_cache
        self.perceiver_cond_length_compression = perceiver_cond_length_compression
        self.label_smoothing = label_smoothing

        # Sampling and voice-conditioning defaults
        self.temperature = temperature
        self.length_penalty = length_penalty
        self.repetition_penalty = repetition_penalty
        self.top_k = top_k
        self.top_p = top_p
        self.gpt_cond_len = gpt_cond_len
        self.gpt_cond_chunk_len = gpt_cond_chunk_len
        self.max_ref_len = max_ref_len
        self.sound_norm_refs = sound_norm_refs

        # Audio processing
        self.audio_config = audio_config

        # Miscellaneous constants, character limits, and languages
        self.duration_const = duration_const
        self.char_limits = char_limits
        self.languages = languages

    def to_dict(self):
        """Convert the config to a dictionary, flattening the audio dataclass."""
        config_dict = super().to_dict()
        config_dict["audio_config"] = asdict(self.audio_config)
        return config_dict

    @classmethod
    def from_dict(cls, config_dict, **kwargs):
        """Create a config from a dictionary (the inverse of `to_dict`)."""
        config_dict = dict(config_dict)  # copy so the caller's dict is not mutated
        audio_config = config_dict.pop("audio_config", {})
        if isinstance(audio_config, dict):
            audio_config = XTTSAudioConfig(**audio_config)
        return cls(audio_config=audio_config, **config_dict, **kwargs)

    def update_with_tokenizer(self, tokenizer=None):
        """Update configuration values based on the tokenizer, if one is provided."""
        if tokenizer is not None:
            self.gpt_number_text_tokens = tokenizer.get_vocab_size()
            self.gpt_start_text_token = tokenizer.bos_token_id
            self.gpt_stop_text_token = tokenizer.eos_token_id
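

# A minimal usage sketch (illustrative only, not part of the library API): it builds a
# config with a custom audio setting, round-trips it through to_dict/from_dict, and
# fills the text-token fields from a stand-in tokenizer. `_DummyTokenizer` is a
# hypothetical object that only mimics the attributes update_with_tokenizer expects.
if __name__ == "__main__":
    audio = XTTSAudioConfig(output_sample_rate=22050)  # hypothetical override for illustration
    config = XTTSGPTConfig(audio_config=audio, languages=["en", "de"])

    # Round-trip: the audio dataclass is flattened to a dict and restored on the way back.
    restored = XTTSGPTConfig.from_dict(config.to_dict())
    assert restored.audio_config.output_sample_rate == 22050

    class _DummyTokenizer:
        bos_token_id = 1
        eos_token_id = 2

        def get_vocab_size(self):
            return 6681

    config.update_with_tokenizer(_DummyTokenizer())
    print(config.gpt_number_text_tokens, config.gpt_start_text_token, config.gpt_stop_text_token)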