from dataclasses import asdict, dataclass
from typing import Dict, Optional, List
from transformers.configuration_utils import PretrainedConfig
from transformers.utils import logging
logger = logging.get_logger(__name__)
@dataclass
class GPTAudioConfig:
    """Configuration for GPT audio processing parameters"""
    mel_channels: int = 80  # number of mel-spectrogram channels fed to the GPT
    sample_rate: int = 22050  # sample rate (Hz) of the conditioning/input audio
    output_sample_rate: int = 24000  # sample rate (Hz) of the generated audio
@dataclass
class XTTSAudioConfig:
    """Configuration for audio processing parameters"""
    sample_rate: int = 22050  # sample rate (Hz) of input/conditioning audio
    output_sample_rate: int = 24000  # sample rate (Hz) of synthesized output
    mel_channels: int = 80  # number of mel filterbank channels
    hop_length: int = 256  # STFT hop length in samples
    win_length: int = 1024  # STFT window length in samples
    n_fft: int = 1024  # FFT size for the STFT
    fmin: int = 0  # lowest mel filterbank frequency (Hz)
    fmax: int = 8000  # highest mel filterbank frequency (Hz)
    power: float = 1.0  # exponent applied to the spectrogram magnitude
    mel_norms_file: Optional[str] = None  # optional path to per-channel mel normalization stats
class XTTSGPTConfig(PretrainedConfig):
    """Configuration class for the GPT component of XTTS.

    Mirrors the hyper-parameters of the original XTTS GPT implementation
    (inline comments give the original parameter names). The audio-processing
    settings are kept as a :class:`GPTAudioConfig` dataclass on
    ``self.audio_config`` and round-tripped as a plain dict through
    :meth:`to_dict` / :meth:`from_dict`.
    """
    model_type = "xtts_gpt"

    def __init__(
        self,
        # Model architecture
        hidden_size: int = 1024,  # gpt_n_model_channels in original
        n_inner: int = 4096,
        num_hidden_layers: int = 30,  # gpt_layers in original
        num_attention_heads: int = 16,  # gpt_n_heads in original
        # Tokenizer settings
        vocab_size: int = 6681,  # gpt_number_text_tokens in original
        number_text_tokens: int = 6681,  # Explicit text token vocabulary size
        start_text_token: Optional[int] = None,
        stop_text_token: Optional[int] = None,
        # Audio token settings
        num_audio_tokens: int = 1026,  # gpt_num_audio_tokens in original
        start_audio_token: int = 1024,  # gpt_start_audio_token in original
        stop_audio_token: int = 1025,  # gpt_stop_audio_token in original
        # Sequence length settings
        max_audio_tokens: int = 605,  # gpt_max_audio_tokens in original
        max_text_tokens: int = 402,  # gpt_max_text_tokens in original
        max_prompt_tokens: int = 70,  # gpt_max_prompt_tokens in original
        gpt_max_audio_tokens: int = 605,  # Used for generation
        # Model behavior settings
        use_masking_gt_prompt_approach: bool = True,  # gpt_use_masking_gt_prompt_approach in original
        use_perceiver_resampler: bool = True,  # gpt_use_perceiver_resampler in original
        kv_cache: bool = True,
        enable_redaction: bool = False,
        # GPT batch settings
        gpt_batch_size: int = 1,
        # Audio processing
        audio_config: Optional[Dict] = None,
        # Architecture specifics
        layer_norm_epsilon: float = 1e-5,
        initializer_range: float = 0.02,
        add_cross_attention: bool = False,
        scale_attn_by_inverse_layer_idx: bool = False,
        reorder_and_upcast_attn: bool = False,
        # Size settings for the decoder
        decoder_input_dim: int = 1024,
        # None sentinels instead of mutable defaults: a list/dict default
        # would be shared across every call of __init__.
        architectures: Optional[List[str]] = None,
        auto_map: Optional[Dict[str, str]] = None,
        activation_function: str = "gelu",
        attn_pdrop: float = 0.1,
        **kwargs,
    ):
        """Build the config; unknown keyword arguments are forwarded to
        :class:`~transformers.PretrainedConfig`.

        Args:
            audio_config: Optional dict of overrides for
                :class:`GPTAudioConfig`; ``None`` uses the dataclass defaults.
            architectures: Model class names; defaults to ``["XttsGPT"]``.
            auto_map: ``AutoConfig``/``AutoModel`` remote-code mapping;
                defaults to the AstraMindAI XTTS2 repositories.
        """
        super().__init__(**kwargs)
        # Materialize per-instance defaults for the formerly-mutable arguments.
        self.architectures = ["XttsGPT"] if architectures is None else architectures
        self.auto_map = (
            {
                "AutoConfig": "AstraMindAI/xtts2-gpt--gpt_config.XTTSGPTConfig",
                "AutoModelForCausalLM": "AstraMindAI/xtts2-gpt--xtts2_gpt_modeling.XttsGPT",
            }
            if auto_map is None
            else auto_map
        )
        # Explicit parenthesization: unpack the overrides dict (or nothing).
        self.audio_config = GPTAudioConfig(**(audio_config if audio_config is not None else {}))
        self.activation_function = activation_function
        self.attn_pdrop = attn_pdrop
        self.hidden_size = hidden_size
        self.n_inner = n_inner
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.vocab_size = vocab_size
        self.number_text_tokens = number_text_tokens
        self.start_text_token = start_text_token
        self.stop_text_token = stop_text_token
        self.num_audio_tokens = num_audio_tokens
        self.start_audio_token = start_audio_token
        self.stop_audio_token = stop_audio_token
        self.max_audio_tokens = max_audio_tokens
        self.max_text_tokens = max_text_tokens
        self.max_prompt_tokens = max_prompt_tokens
        self.gpt_max_audio_tokens = gpt_max_audio_tokens
        self.use_masking_gt_prompt_approach = use_masking_gt_prompt_approach
        self.use_perceiver_resampler = use_perceiver_resampler
        self.kv_cache = kv_cache
        self.enable_redaction = enable_redaction
        self.gpt_batch_size = gpt_batch_size
        self.layer_norm_epsilon = layer_norm_epsilon
        self.initializer_range = initializer_range
        self.add_cross_attention = add_cross_attention
        self.scale_attn_by_inverse_layer_idx = scale_attn_by_inverse_layer_idx
        self.reorder_and_upcast_attn = reorder_and_upcast_attn
        self.decoder_input_dim = decoder_input_dim

    def to_dict(self) -> Dict:
        """Convert the config to a dictionary (audio_config as a plain dict)."""
        output = super().to_dict()
        # Dataclasses are not JSON-serializable; flatten to a dict so the
        # config can be saved/loaded by the transformers machinery.
        output["audio_config"] = asdict(self.audio_config)
        return output

    @classmethod
    def from_dict(cls, config_dict: Dict, *args, **kwargs) -> "XTTSGPTConfig":
        """Create a config from a dictionary.

        NOTE(review): extra positional/keyword arguments (e.g.
        ``return_unused_kwargs`` supported by ``PretrainedConfig.from_dict``)
        are intentionally ignored here — confirm callers do not rely on them.
        """
        return cls(**config_dict)