Upload 8 files

- config.json +142 -0
- gpt_config.py +172 -0
- special_tokens_map.json +6 -0
- tokenizer.json +0 -0
- tokenizer.py +233 -0
- tokenizer_config.json +191 -0
- xtts2_gpt_modeling.py +312 -0
- xttsv2-gpt.safetensors +3 -0
config.json
ADDED
@@ -0,0 +1,142 @@
{
    "_name_or_path": "AstraMindAI/xtts2-gpt",
    "architectures": [
        "XttsGPT"
    ],
    "auto_map": {
        "AutoConfig": "AstraMindAI/xtts2-gpt--gpt_config.XTTSGPTConfig",
        "AutoModelForCausalLM": "AstraMindAI/xtts2-gpt--xtts2_gpt_modeling.XttsGPT"
    },
    "audio_config": {
        "fmax": 8000,
        "fmin": 0,
        "hop_length": 256,
        "mel_channels": 80,
        "mel_norms_file": null,
        "n_fft": 1024,
        "output_sample_rate": 24000,
        "power": 1.0,
        "sample_rate": 22050,
        "win_length": 1024
    },
    "char_limits": {
        "ar": 166,
        "cs": 186,
        "de": 253,
        "en": 250,
        "es": 239,
        "fr": 273,
        "hu": 224,
        "it": 213,
        "ja": 71,
        "ko": 95,
        "nl": 251,
        "pl": 224,
        "pt": 203,
        "ru": 182,
        "tr": 226,
        "zh": 82
    },
    "duration_const": 102400,
    "enable_redaction": false,
    "gpt_batch_size": 1,
    "gpt_checkpointing": false,
    "gpt_code_stride_len": 1024,
    "gpt_cond_chunk_len": 4,
    "gpt_cond_len": 30,
    "gpt_layers": 30,
    "gpt_max_audio_tokens": 605,
    "gpt_max_prompt_tokens": 70,
    "gpt_max_text_tokens": 402,
    "gpt_n_heads": 16,
    "gpt_n_model_channels": 1024,
    "gpt_num_audio_tokens": 1026,
    "gpt_number_text_tokens": 6681,
    "gpt_start_audio_token": 1024,
    "gpt_start_text_token": null,
    "gpt_stop_audio_token": 1025,
    "gpt_stop_text_token": null,
    "gpt_train_solo_embeddings": false,
    "gpt_use_masking_gt_prompt_approach": true,
    "gpt_use_perceiver_resampler": true,
    "kv_cache": true,
    "label_smoothing": 0.0,
    "languages": [
        "en",
        "es",
        "fr",
        "de",
        "it",
        "pt",
        "pl",
        "tr",
        "ru",
        "nl",
        "cs",
        "ar",
        "zh-cn",
        "hu",
        "ko",
        "ja",
        "hi"
    ],
    "max_ref_len": 30,
    "model_type": "xtts_gpt",
    "num_chars": 255,
    "perceiver_cond_length_compression": 256,
    "repetition_penalty": 5.0,
    "sound_norm_refs": false,
    "temperature": 0.75,
    "top_p": 0.85,
    "transformers_version": "4.45.1",
    "vocab_size": 256,
    "cond_d_vector_in_each_upsampling_layer": true,
    "d_vector_dim": 512,
    "decoder_input_dim": 1024,
    "input_sample_rate": 22050,
    "hifi_model_type": "xtts_hifigan",
    "output_hop_length": 256,
    "output_sample_rate": 24000,
    "resblock_dilation_sizes": [
        [
            1,
            3,
            5
        ],
        [
            1,
            3,
            5
        ],
        [
            1,
            3,
            5
        ]
    ],
    "resblock_kernel_sizes": [
        3,
        7,
        11
    ],
    "speaker_encoder_config": {
        "model_config": null,
        "model_name": "speaker_encoder",
        "preprocess_config": null,
        "speaker_embedding_dim": 512,
        "use_torch_spec": true
    },
    "upsample_initial_channel": 512,
    "upsample_kernel_sizes": [
        16,
        16,
        4,
        4
    ],
    "upsample_rates": [
        8,
        8,
        2,
        2
    ]
}
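The auto_map block above is what makes this a remote-code checkpoint: AutoConfig resolves to gpt_config.XTTSGPTConfig and AutoModelForCausalLM to xtts2_gpt_modeling.XttsGPT, so loading goes through trust_remote_code. A minimal sketch of loading just the config (the model class itself targets a vLLM runtime, so it may not load through the plain HF loader):

from transformers import AutoConfig

# Resolved to gpt_config.XTTSGPTConfig via the auto_map entries above.
config = AutoConfig.from_pretrained("AstraMindAI/xtts2-gpt", trust_remote_code=True)
print(config.gpt_layers, config.gpt_n_heads, config.gpt_n_model_channels)  # 30 16 1024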
gpt_config.py
ADDED
@@ -0,0 +1,172 @@
from dataclasses import asdict, dataclass, field
from typing import Dict, Optional, List
from transformers.configuration_utils import PretrainedConfig
from transformers.utils import logging

logger = logging.get_logger(__name__)


@dataclass
class XTTSAudioConfig:
    """Configuration for audio processing parameters"""
    sample_rate: int = 22050
    output_sample_rate: int = 24000
    mel_channels: int = 80
    hop_length: int = 256
    win_length: int = 1024
    n_fft: int = 1024
    fmin: int = 0
    fmax: int = 8000
    power: float = 1.0
    mel_norms_file: Optional[str] = None


class XTTSGPTConfig(PretrainedConfig):
    """Configuration class for the GPT component of XTTS"""
    model_type = "xtts_gpt"

    def __init__(
        self,
        # Model architecture
        vocab_size: int = 256,
        num_chars: int = 255,

        # GPT parameters
        gpt_batch_size: int = 1,
        gpt_max_audio_tokens: int = 605,
        gpt_max_text_tokens: int = 402,
        gpt_max_prompt_tokens: int = 70,
        gpt_layers: int = 30,
        gpt_n_model_channels: int = 1024,
        gpt_n_heads: int = 16,
        gpt_number_text_tokens: int = 6681,
        gpt_start_text_token: Optional[int] = None,
        gpt_stop_text_token: Optional[int] = None,
        gpt_num_audio_tokens: int = 1026,
        gpt_start_audio_token: int = 1024,
        gpt_stop_audio_token: int = 1025,
        gpt_code_stride_len: int = 1024,
        gpt_use_masking_gt_prompt_approach: bool = True,
        gpt_use_perceiver_resampler: bool = True,
        gpt_checkpointing: bool = False,
        gpt_train_solo_embeddings: bool = False,

        # Training parameters
        enable_redaction: bool = False,
        kv_cache: bool = True,
        perceiver_cond_length_compression: int = 256,
        label_smoothing: float = 0.0,

        # Generation parameters
        temperature: float = 0.75,
        length_penalty: float = 1.0,
        repetition_penalty: float = 5.0,
        top_k: int = 50,
        top_p: float = 0.85,
        gpt_cond_len: int = 30,
        gpt_cond_chunk_len: int = 4,
        max_ref_len: int = 30,
        sound_norm_refs: bool = False,

        # Audio processing
        audio_config: Optional[XTTSAudioConfig] = None,

        # Constants and limits
        duration_const: int = 102400,
        char_limits: Optional[Dict[str, int]] = None,
        languages: Optional[List[str]] = None,
        pad_token_id: Optional[int] = None,
        bos_token_id: Optional[int] = None,
        eos_token_id: Optional[int] = None,
        **kwargs,
    ):
        if char_limits is None:
            char_limits = {
                "en": 250, "de": 253, "fr": 273, "es": 239,
                "it": 213, "pt": 203, "pl": 224, "zh": 82,
                "ar": 166, "cs": 186, "ru": 182, "nl": 251,
                "tr": 226, "ja": 71, "hu": 224, "ko": 95,
            }

        if languages is None:
            languages = [
                "en", "es", "fr", "de", "it", "pt", "pl", "tr", "ru", "nl",
                "cs", "ar", "zh-cn", "hu", "ko", "ja", "hi"
            ]

        if audio_config is None:
            audio_config = XTTSAudioConfig()

        super().__init__(
            pad_token_id=pad_token_id,
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            **kwargs
        )

        self.vocab_size = vocab_size
        self.num_chars = num_chars

        # GPT parameters
        self.gpt_batch_size = gpt_batch_size
        self.gpt_max_audio_tokens = gpt_max_audio_tokens
        self.gpt_max_text_tokens = gpt_max_text_tokens
        self.gpt_max_prompt_tokens = gpt_max_prompt_tokens
        self.gpt_layers = gpt_layers
        self.gpt_n_model_channels = gpt_n_model_channels
        self.gpt_n_heads = gpt_n_heads
        self.gpt_number_text_tokens = gpt_number_text_tokens
        self.gpt_start_text_token = gpt_start_text_token
        self.gpt_stop_text_token = gpt_stop_text_token
        self.gpt_num_audio_tokens = gpt_num_audio_tokens
        self.gpt_start_audio_token = gpt_start_audio_token
        self.gpt_stop_audio_token = gpt_stop_audio_token
        self.gpt_code_stride_len = gpt_code_stride_len
        self.gpt_use_masking_gt_prompt_approach = gpt_use_masking_gt_prompt_approach
        self.gpt_use_perceiver_resampler = gpt_use_perceiver_resampler
        self.gpt_checkpointing = gpt_checkpointing
        self.gpt_train_solo_embeddings = gpt_train_solo_embeddings

        # Training parameters
        self.enable_redaction = enable_redaction
        self.kv_cache = kv_cache
        self.perceiver_cond_length_compression = perceiver_cond_length_compression
        self.label_smoothing = label_smoothing

        # Generation parameters
        self.temperature = temperature
        self.length_penalty = length_penalty
        self.repetition_penalty = repetition_penalty
        self.top_k = top_k
        self.top_p = top_p
        self.gpt_cond_len = gpt_cond_len
        self.gpt_cond_chunk_len = gpt_cond_chunk_len
        self.max_ref_len = max_ref_len
        self.sound_norm_refs = sound_norm_refs

        # Audio processing
        self.audio_config = audio_config

        # Constants and limits
        self.duration_const = duration_const
        self.char_limits = char_limits
        self.languages = languages

    def to_dict(self):
        """Convert config to dictionary"""
        config_dict = super().to_dict()
        config_dict["audio_config"] = asdict(self.audio_config)
        return config_dict

    @classmethod
    def from_dict(cls, config_dict):
        """Create config from dictionary"""
        audio_config = XTTSAudioConfig(**config_dict.pop("audio_config", {}))
        return cls(audio_config=audio_config, **config_dict)

    def update_with_tokenizer(self, tokenizer=None):
        """Update configuration values based on tokenizer"""
        if tokenizer is not None:
            self.gpt_number_text_tokens = tokenizer.get_vocab_size()
            self.gpt_start_text_token = tokenizer.bos_token_id
            self.gpt_stop_text_token = tokenizer.eos_token_id
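Because to_dict and from_dict special-case the nested XTTSAudioConfig dataclass, a round-trip is the quickest way to see the intended serialization behavior. A minimal sketch using only the classes above (PretrainedConfig tolerates the extra bookkeeping keys its own to_dict adds):

config = XTTSGPTConfig()

d = config.to_dict()  # audio_config is serialized with dataclasses.asdict
assert d["audio_config"]["sample_rate"] == 22050

restored = XTTSGPTConfig.from_dict(d)  # audio_config is rebuilt as a dataclass
assert restored.audio_config.output_sample_rate == 24000
assert restored.gpt_layers == 30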
special_tokens_map.json
ADDED
@@ -0,0 +1,6 @@
{
  "bos_token": "[START]",
  "eos_token": "[STOP]",
  "pad_token": "[PAD]",
  "unk_token": "[UNK]"
}
tokenizer.json
ADDED
The diff for this file is too large to render.
See raw diff
tokenizer.py
ADDED
@@ -0,0 +1,233 @@
from typing import List, Optional, Union, Dict, Tuple, Any
import os
from functools import cached_property

from transformers import PreTrainedTokenizerFast
from transformers.tokenization_utils_base import TruncationStrategy, PaddingStrategy
from tokenizers import Tokenizer, processors
from tokenizers.pre_tokenizers import WhitespaceSplit
from tokenizers.processors import TemplateProcessing
import torch
from hangul_romanize import Transliter
from hangul_romanize.rule import academic
import cutlet

from TTS.tts.layers.xtts.tokenizer import (multilingual_cleaners, basic_cleaners,
                                           chinese_transliterate, korean_transliterate,
                                           japanese_cleaners)


class XTTSTokenizerFast(PreTrainedTokenizerFast):
    """
    Fast tokenizer implementation for the XTTS model, built on HuggingFace's PreTrainedTokenizerFast
    """
    def __init__(
        self,
        vocab_file: str = None,
        tokenizer_object: Optional[Tokenizer] = None,
        unk_token: str = "[UNK]",
        pad_token: str = "[PAD]",
        bos_token: str = "[START]",
        eos_token: str = "[STOP]",
        clean_up_tokenization_spaces: bool = True,
        **kwargs
    ):
        if tokenizer_object is None and vocab_file is not None:
            tokenizer_object = Tokenizer.from_file(vocab_file)

        if tokenizer_object is not None:
            # Configure the tokenizer
            tokenizer_object.pre_tokenizer = WhitespaceSplit()
            tokenizer_object.enable_padding(
                direction='right',
                pad_id=tokenizer_object.token_to_id(pad_token) or 0,
                pad_token=pad_token
            )
            tokenizer_object.post_processor = TemplateProcessing(
                single=f"{bos_token} $A {eos_token}",
                special_tokens=[
                    (bos_token, tokenizer_object.token_to_id(bos_token)),
                    (eos_token, tokenizer_object.token_to_id(eos_token)),
                ],
            )

        super().__init__(
            tokenizer_object=tokenizer_object,
            unk_token=unk_token,
            pad_token=pad_token,
            bos_token=bos_token,
            eos_token=eos_token,
            clean_up_tokenization_spaces=clean_up_tokenization_spaces,
            **kwargs
        )

        # Character limits per language
        self.char_limits = {
            "en": 250, "de": 253, "fr": 273, "es": 239,
            "it": 213, "pt": 203, "pl": 224, "zh": 82,
            "ar": 166, "cs": 186, "ru": 182, "nl": 251,
            "tr": 226, "ja": 71, "hu": 224, "ko": 95,
        }

        # Initialize language tools
        self._katsu = None
        self._korean_transliter = Transliter(academic)

    @cached_property
    def katsu(self):
        if self._katsu is None:
            self._katsu = cutlet.Cutlet()
        return self._katsu

    def check_input_length(self, text: str, lang: str):
        """Warn if input text length exceeds the limit for the language"""
        lang = lang.split("-")[0]  # remove region
        limit = self.char_limits.get(lang, 250)
        if len(text) > limit:
            print(f"Warning: Text length exceeds {limit} char limit for '{lang}', may cause truncation.")

    def preprocess_text(self, text: str, lang: str) -> str:
        """Apply text preprocessing for language"""
        if lang in {"ar", "cs", "de", "en", "es", "fr", "hu", "it",
                    "nl", "pl", "pt", "ru", "tr", "zh", "ko"}:
            text = multilingual_cleaners(text, lang)
            if lang == "zh":
                text = chinese_transliterate(text)
            if lang == "ko":
                text = korean_transliterate(text)
        elif lang == "ja":
            text = japanese_cleaners(text, self.katsu)
        else:
            text = basic_cleaners(text)
        return text

    def _batch_encode_plus(
        self,
        batch_text_or_text_pairs,
        add_special_tokens: bool = True,
        padding_strategy=PaddingStrategy.DO_NOT_PAD,
        truncation_strategy=TruncationStrategy.DO_NOT_TRUNCATE,
        max_length: Optional[int] = 402,
        stride: int = 0,
        is_split_into_words: bool = False,
        pad_to_multiple_of: Optional[int] = None,
        return_tensors: Optional[str] = None,
        return_token_type_ids: Optional[bool] = None,
        return_attention_mask: Optional[bool] = None,
        return_overflowing_tokens: bool = False,
        return_special_tokens_mask: bool = False,
        return_offsets_mapping: bool = False,
        return_length: bool = False,
        verbose: bool = True,
        **kwargs
    ) -> Dict[str, Any]:
        """
        Override batch encoding to handle language-specific preprocessing
        """
        lang = kwargs.pop("lang", ["en"] * len(batch_text_or_text_pairs))
        if isinstance(lang, str):
            lang = [lang] * len(batch_text_or_text_pairs)

        # Preprocess each text in the batch with its corresponding language
        processed_texts = []
        for text, text_lang in zip(batch_text_or_text_pairs, lang):
            if isinstance(text, str):
                # Check length and preprocess
                self.check_input_length(text, text_lang)
                processed_text = self.preprocess_text(text, text_lang)

                # Format text with language tag and spaces
                lang_code = "zh-cn" if text_lang == "zh" else text_lang
                processed_text = f"[{lang_code}]{processed_text}"
                processed_text = processed_text.replace(" ", "[SPACE]")

                processed_texts.append(processed_text)
            else:
                processed_texts.append(text)

        # Call the parent class's encoding method with processed texts
        return super()._batch_encode_plus(
            processed_texts,
            add_special_tokens=add_special_tokens,
            padding_strategy=padding_strategy,
            truncation_strategy=truncation_strategy,
            max_length=max_length,
            stride=stride,
            is_split_into_words=is_split_into_words,
            pad_to_multiple_of=pad_to_multiple_of,
            return_tensors=return_tensors,
            return_token_type_ids=return_token_type_ids,
            return_attention_mask=return_attention_mask,
            return_overflowing_tokens=return_overflowing_tokens,
            return_special_tokens_mask=return_special_tokens_mask,
            return_offsets_mapping=return_offsets_mapping,
            return_length=return_length,
            verbose=verbose,
            **kwargs
        )

    def __call__(
        self,
        text: Union[str, List[str]],
        lang: Union[str, List[str]] = "en",
        add_special_tokens: bool = True,
        padding: Union[bool, str, PaddingStrategy] = True,  # default changed to True
        truncation: Union[bool, str, TruncationStrategy] = True,  # default changed to True
        max_length: Optional[int] = 402,
        stride: int = 0,
        return_tensors: Optional[str] = None,
        return_token_type_ids: Optional[bool] = None,
        return_attention_mask: Optional[bool] = True,  # default changed to True
        **kwargs
    ):
        """
        Main tokenization method

        Args:
            text: Text or list of texts to tokenize
            lang: Language code or list of language codes corresponding to each text
            add_special_tokens: Whether to add special tokens
            padding: Padding strategy (default True)
            truncation: Truncation strategy (default True)
            max_length: Maximum length
            stride: Stride for truncation
            return_tensors: Format of output tensors ("pt" for PyTorch)
            return_token_type_ids: Whether to return token type IDs
            return_attention_mask: Whether to return attention mask (default True)
        """
        # Convert single string to list for batch processing
        if isinstance(text, str):
            text = [text]
        if isinstance(lang, str):
            lang = [lang]

        # Ensure text and lang lists have the same length
        if len(text) != len(lang):
            raise ValueError(f"Number of texts ({len(text)}) must match number of language codes ({len(lang)})")

        # Convert padding strategy
        if isinstance(padding, bool):
            padding_strategy = PaddingStrategy.MAX_LENGTH if padding else PaddingStrategy.DO_NOT_PAD
        else:
            padding_strategy = PaddingStrategy(padding)

        # Convert truncation strategy
        if isinstance(truncation, bool):
            truncation_strategy = TruncationStrategy.LONGEST_FIRST if truncation else TruncationStrategy.DO_NOT_TRUNCATE
        else:
            truncation_strategy = TruncationStrategy(truncation)

        # Use the batch encoding method
        encoded = self._batch_encode_plus(
            text,
            add_special_tokens=add_special_tokens,
            padding_strategy=padding_strategy,
            truncation_strategy=truncation_strategy,
            max_length=max_length,
            stride=stride,
            return_tensors=return_tensors,
            return_token_type_ids=return_token_type_ids,
            return_attention_mask=return_attention_mask,
            lang=lang,
            **kwargs
        )

        return encoded
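With the overridden __call__, each text carries its own language code: the text is cleaned, prefixed with its [lang] tag, and literal spaces become [SPACE] before encoding, then everything is padded to max_length=402 by default. A usage sketch, assuming tokenizer.json from this repo has been downloaded locally:

tokenizer = XTTSTokenizerFast(vocab_file="tokenizer.json")

batch = tokenizer(
    ["Hello world", "Bonjour tout le monde"],
    lang=["en", "fr"],    # one code per text; a single string is broadcast
    return_tensors="pt",
)
print(batch["input_ids"].shape)       # torch.Size([2, 402]) with the default padding
print(batch["attention_mask"].shape)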
tokenizer_config.json
ADDED
@@ -0,0 +1,191 @@
{
  "added_tokens_decoder": {
    "0": {
      "content": "[STOP]",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "1": {
      "content": "[UNK]",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "2": {
      "content": "[SPACE]",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "259": {
      "content": "[en]",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "260": {
      "content": "[de]",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "261": {
      "content": "[START]",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "262": {
      "content": "[fr]",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "267": {
      "content": "[ru]",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "284": {
      "content": "[es]",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "285": {
      "content": "[it]",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "286": {
      "content": "[pt]",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "293": {
      "content": "[cs]",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "294": {
      "content": "[pl]",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "295": {
      "content": "[tr]",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "297": {
      "content": "[nl]",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "5022": {
      "content": "[ar]",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "5023": {
      "content": "[zh-cn]",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "5412": {
      "content": "[ja]",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "5753": {
      "content": "[hu]",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "6152": {
      "content": "[ko]",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "6680": {
      "content": "[hi]",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "6681": {
      "content": "[PAD]",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    }
  },
  "bos_token": "[START]",
  "clean_up_tokenization_spaces": true,
  "eos_token": "[STOP]",
  "max_length": null,
  "model_max_length": 1000000000000000019884624838656,
  "pad_to_multiple_of": null,
  "pad_token": "[PAD]",
  "pad_token_type_id": 0,
  "padding_side": "right",
  "tokenizer_class": "XTTSTokenizer",
  "unk_token": "[UNK]"
}
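Read together with tokenizer.py above, this file shows how the tagged inputs are encoded: the language tags in added_tokens_decoder ([en] = 259, [fr] = 262, and so on) are the prefixes that _batch_encode_plus prepends, [SPACE] (id 2) is what every literal space becomes, and the template post-processor wraps each sequence in [START] (261) and [STOP] (0). So an English input like "Hello world" is encoded as [START], the tokens of "[en]Hello[SPACE]world", then [STOP].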
xtts2_gpt_modeling.py
ADDED
@@ -0,0 +1,312 @@
import functools
import math
from array import array

import torch
import torch.nn as nn
from torch.nn import functional as F
from typing import List, Optional, Union, Iterable, Tuple, Mapping

from transformers import PretrainedConfig
from vllm.attention import AttentionMetadata
from vllm.config import CacheConfig
from vllm.distributed import get_pp_group
from vllm.inputs import InputContext, INPUT_REGISTRY
from vllm.model_executor.layers.linear import ColumnParallelLinear
from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.gpt2 import GPT2Block
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalInputs
from vllm.sequence import IntermediateTensors, SequenceData, VLLM_TOKEN_ID_ARRAY_TYPE
from vllm.model_executor.models.interfaces import SupportsMultiModal

from TTS.tts.layers.xtts.latent_encoder import ConditioningEncoder  # noqa
from TTS.tts.layers.xtts.perceiver_encoder import PerceiverResampler  # noqa
from TTS.tts.layers.xtts.gpt import LearnedPositionEmbeddings

# Constants for token calculation
_AUDIO_PLACEHOLDER_TOKEN = 8192  # Using XTTS start_audio_token as placeholder
_AUDIO_TOKENS_PER_SECOND = 6.25
_CODE_STRIDE_LEN = 1024


def get_xtts_max_audio_tokens(ctx: InputContext) -> int:
    """Calculate maximum audio tokens based on text context and audio duration."""
    # Based on GPT config and common XTTS settings
    text_context = ctx.model_config.max_seq_len - 100  # Reserve space for text
    # Allow for ~30 seconds of audio (similar to whisper chunks)
    max_audio_duration = 30.0
    audio_tokens = math.ceil(max_audio_duration * _AUDIO_TOKENS_PER_SECOND)
    total_tokens = text_context + audio_tokens + 4  # +4 for special tokens

    return min(total_tokens, 1000)  # Cap at 1000 tokens as specified


def dummy_seq_data_for_xtts(
    ctx: InputContext,
    seq_len: int,
    audio_count: int,
) -> SequenceData:
    """Create dummy sequence data for XTTS profiling."""
    # Calculate audio token space needed
    audio_len_tokens = math.ceil(_AUDIO_TOKENS_PER_SECOND * 5)  # Assume 5s per chunk
    audio_placeholder = array(
        VLLM_TOKEN_ID_ARRAY_TYPE,
        [_AUDIO_PLACEHOLDER_TOKEN]
    ) * audio_len_tokens

    # Add separator between chunks
    audio_token_ids = (audio_placeholder + array(VLLM_TOKEN_ID_ARRAY_TYPE, [0])) * audio_count

    # Fill remaining sequence with padding
    other_token_ids = array(VLLM_TOKEN_ID_ARRAY_TYPE, [0]) * (seq_len - len(audio_token_ids))

    return SequenceData(audio_token_ids + other_token_ids)


def dummy_conditioning_for_xtts(
    ctx: InputContext,
    audio_count: int,
) -> dict:
    """Create dummy conditioning data for XTTS."""
    return {
        "cond_latents": [(torch.zeros(80, 1024), 22050) for _ in range(audio_count)]
    }


def dummy_data_for_xtts(
    ctx: InputContext,
    seq_len: int,
    mm_counts: Mapping[str, int],
) -> Tuple[SequenceData, dict]:
    """Create complete dummy data for XTTS profiling."""
    audio_count = mm_counts["audio"]
    seq_data = dummy_seq_data_for_xtts(ctx, seq_len, audio_count)
    cond_data = dummy_conditioning_for_xtts(ctx, audio_count)
    return (seq_data, cond_data)


def input_mapper_for_xtts(ctx: InputContext, data: object) -> MultiModalInputs:
    """Map input data to XTTS format."""
    if not isinstance(data, list):
        data = [data]

    # Each item should be a tuple of (mel_spec, sample_rate)
    for audio_input in data:
        if not isinstance(audio_input, tuple):
            raise NotImplementedError(f"Unsupported data type: {type(audio_input)}")

    return MultiModalInputs({"cond_latents": data})


@MULTIMODAL_REGISTRY.register_input_mapper("audio", input_mapper_for_xtts)
@MULTIMODAL_REGISTRY.register_max_multimodal_tokens("audio", get_xtts_max_audio_tokens)
@INPUT_REGISTRY.register_dummy_data(dummy_data_for_xtts)
class XttsGPT(nn.Module, SupportsMultiModal):
    def __init__(
        self,
        config: PretrainedConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional["QuantizationConfig"] = None,
    ):
        super().__init__()
        self.config = config
        self.quant_config = quant_config

        # XTTS specific components
        self.conditioning_encoder = ConditioningEncoder(
            80, config.n_embd, num_attn_heads=config.n_head
        )

        if config.use_perceiver_resampler:
            self.conditioning_perceiver = PerceiverResampler(
                dim=config.n_embd,
                depth=2,
                dim_context=config.n_embd,
                num_latents=32,
                dim_head=64,
                heads=8,
                ff_mult=4,
                use_flash_attn=False,
            )

        # Core GPT components following the vLLM pattern
        self.gpt = XttsGPT2Model(
            config,
            cache_config,
            quant_config,
            prefix="gpt"
        )

        # Prediction heads
        self.text_head = ColumnParallelLinear(
            config.n_embd,
            config.vocab_size,
            bias=False,
            quant_config=quant_config,
            prefix="text_head"
        )

        self.mel_head = ColumnParallelLinear(
            config.n_embd,
            config.num_audio_tokens,
            bias=False,
            quant_config=quant_config,
            prefix="mel_head"
        )

        self.sampler = Sampler()

    def get_style_emb(self, cond_input: torch.Tensor, return_latent: bool = False) -> torch.Tensor:
        """Get conditioning embeddings from mel spectrograms."""
        if not return_latent:
            if cond_input.ndim == 4:
                cond_input = cond_input.squeeze(1)
            conds = self.conditioning_encoder(cond_input)

            if hasattr(self, 'conditioning_perceiver'):
                conds = self.conditioning_perceiver(
                    conds.permute(0, 2, 1)
                ).transpose(1, 2)
        else:
            conds = cond_input.unsqueeze(1)
        return conds

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        cond_latents: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Forward pass following the vLLM pattern."""
        if cond_latents is not None:
            # Combine conditioning with input embeddings
            input_embeds = self.gpt.get_input_embeddings()(input_ids)
            combined_embeds = torch.cat([cond_latents, input_embeds], dim=1)
            hidden_states = self.gpt(
                inputs_embeds=combined_embeds,
                positions=positions,
                kv_caches=kv_caches,
                attn_metadata=attn_metadata,
                intermediate_tensors=intermediate_tensors,
            )
        else:
            hidden_states = self.gpt(
                input_ids=input_ids,
                positions=positions,
                kv_caches=kv_caches,
                attn_metadata=attn_metadata,
                intermediate_tensors=intermediate_tensors,
            )
        return hidden_states

    def compute_logits(  # kept for compatibility with the vLLM interface
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> torch.Tensor:
        """Compute output logits."""
        text_logits = self.text_head(hidden_states[sampling_metadata.selected_token_indices])
        mel_logits = self.mel_head(hidden_states[sampling_metadata.selected_token_indices])
        return torch.cat([text_logits, mel_logits], dim=1)

    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        """Sample next tokens using the vLLM sampler."""
        return self.sampler(logits, sampling_metadata)

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        """Load weights following the vLLM pattern."""
        params_dict = dict(self.named_parameters(remove_duplicate=False))

        for name, loaded_weight in weights:
            if name not in params_dict:
                continue

            param = params_dict[name]
            # GPT-2 Conv1D weights are stored transposed relative to nn.Linear
            if "c_attn" in name or "c_proj" in name or "c_fc" in name:
                if name.endswith(".weight"):
                    loaded_weight = loaded_weight.t()

            weight_loader = getattr(param, "weight_loader", default_weight_loader)
            weight_loader(param, loaded_weight)


class XttsGPT2Model(nn.Module):
    """vLLM-style implementation of the GPT-2 core architecture."""

    def __init__(
        self,
        config: PretrainedConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional["QuantizationConfig"] = None,
        prefix: str = "",
    ):
        super().__init__()
        self.config = config
        self.text_embedding = VocabParallelEmbedding(config.number_text_tokens, config.n_embd)
        self.mel_embedding = VocabParallelEmbedding(config.num_audio_tokens, config.n_embd)

        self.text_pos_embedding = (
            LearnedPositionEmbeddings(config.max_text_seq_len, config.n_embd)
            if config.max_mel_seq_len != -1
            else functools.partial(config.null_position_embeddings, dim=config.n_embd)
        )
        self.mel_pos_embedding = (
            LearnedPositionEmbeddings(config.max_mel_seq_len, config.n_embd)
            if config.max_mel_seq_len != -1
            else functools.partial(config.null_position_embeddings, dim=config.n_embd)
        )
        # Build GPT blocks
        self.h = nn.ModuleList([
            GPT2Block(
                config,
                cache_config,
                quant_config,
                prefix=f"{prefix}.h.{i}"
            ) for i in range(config.num_hidden_layers)
        ])

        self.final_norm = nn.LayerNorm(
            config.n_embd,
            eps=config.layer_norm_epsilon
        )

    def forward(  # TODO: this is not correct yet; align it with the reference XTTS implementation
        self,
        input_ids: torch.Tensor,
        position_ids: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors],
    ) -> Union[torch.Tensor, IntermediateTensors]:
        if get_pp_group().is_first_rank:
            # Placeholder embedding path: every token is routed through the
            # text tables; a correct implementation must send audio tokens
            # through mel_embedding / mel_pos_embedding instead.
            inputs_embeds = self.text_embedding(input_ids)
            position_embeds = self.text_pos_embedding(position_ids)
            hidden_states = inputs_embeds + position_embeds
        else:
            assert intermediate_tensors is not None
            hidden_states = intermediate_tensors["hidden_states"]

        for i, layer in enumerate(self.h):
            hidden_states = layer(hidden_states, kv_caches[i], attn_metadata)

        if not get_pp_group().is_last_rank:
            return IntermediateTensors({"hidden_states": hidden_states})

        hidden_states = self.final_norm(hidden_states)
        return hidden_states
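The audio-token budget in get_xtts_max_audio_tokens is easy to check by hand: at _AUDIO_TOKENS_PER_SECOND = 6.25, a 30-second clip costs math.ceil(30.0 * 6.25) = 188 tokens. A small sketch of the arithmetic (the max_seq_len = 2048 here is only an assumed example value, not taken from this repo):

import math

text_context = 2048 - 100                 # assumed max_seq_len, minus the reserved text space
audio_tokens = math.ceil(30.0 * 6.25)     # 188 tokens for a 30 s clip
total = text_context + audio_tokens + 4   # +4 for special tokens -> 2140
print(min(total, 1000))                   # 1000: the cap wins for any realistic max_seq_len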
xttsv2-gpt.safetensors
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:93fa43aaad29e232fa6c85f3d6c3285285c1fe4c89f9505d8153e231b12e1a50
size 1764117740