mlinmg committed
Commit bfce01d
1 Parent(s): 456660e

Upload 3 files

Files changed (3):
  1. config.json +34 -81
  2. gpt_config.py +85 -54
  3. xtts2_gpt_modeling.py +207 -59
config.json CHANGED
@@ -1,14 +1,6 @@
 {
-  "_name_or_path": "AstraMindAI/xtts2-gpt",
-  "architectures": [
-    "XttsGPT"
-  ],
-  "torch_dtype": "float32",
-  "auto_map": {
-    "AutoConfig": "AstraMindAI/xtts2-gpt--gpt_config.XTTSGPTConfig",
-    "AutoModelForCausalLM": "AstraMindAI/xtts2-gpt--xtts2_gpt_modeling.XttsGPT",
-    "AutoTokenizer": "AstraMindAI/xtts2-gpt--tokenizer.XTTSTokenizerFast"
-  },
+  "activation_function": "gelu",
+  "attn_pdrop": 0.1,
   "audio_config": {
     "fmax": 8000,
     "fmin": 0,
@@ -21,6 +13,7 @@
     "sample_rate": 22050,
     "win_length": 1024
   },
+  "batch_size": 1,
   "char_limits": {
     "ar": 166,
     "cs": 186,
@@ -39,28 +32,14 @@
     "tr": 226,
     "zh": 82
   },
+  "checkpointing": false,
+  "code_stride_len": 1024,
+  "cond_chunk_len": 4,
+  "cond_len": 30,
   "duration_const": 102400,
+  "embd_pdrop": 0.1,
   "enable_redaction": false,
-  "gpt_batch_size": 1,
-  "gpt_checkpointing": false,
-  "gpt_code_stride_len": 1024,
-  "gpt_cond_chunk_len": 4,
-  "gpt_cond_len": 30,
-  "gpt_layers": 30,
-  "gpt_max_audio_tokens": 605,
-  "gpt_max_prompt_tokens": 70,
-  "gpt_max_text_tokens": 402,
-  "gpt_n_heads": 16,
-  "gpt_n_model_channels": 1024,
-  "gpt_num_audio_tokens": 1026,
-  "gpt_number_text_tokens": 6681,
-  "gpt_start_audio_token": 1024,
-  "gpt_start_text_token": null,
-  "gpt_stop_audio_token": 1025,
-  "gpt_stop_text_token": null,
-  "gpt_train_solo_embeddings": false,
-  "gpt_use_masking_gt_prompt_approach": true,
-  "gpt_use_perceiver_resampler": true,
+  "hidden_size": 1024,
   "kv_cache": true,
   "label_smoothing": 0.0,
   "languages": [
@@ -82,60 +61,34 @@
     "ja",
     "hi"
   ],
+  "layer_norm_epsilon": 1e-05,
+  "max_audio_tokens": 605,
+  "max_position_embeddings": 2048,
+  "max_prompt_tokens": 70,
   "max_ref_len": 30,
+  "max_text_tokens": 402,
   "model_type": "xtts_gpt",
+  "n_inner": null,
+  "num_attention_heads": 16,
+  "num_audio_tokens": 1026,
   "num_chars": 255,
+  "num_hidden_layers": 30,
+  "number_text_tokens": 6681,
   "perceiver_cond_length_compression": 256,
+  "reorder_and_upcast_attn": false,
+  "repetition_penalty": 5.0,
+  "resid_pdrop": 0.1,
+  "scale_attn_by_inverse_layer_idx": false,
   "sound_norm_refs": false,
-  "transformers_version": "4.45.1",
-  "vocab_size": 256,
-  "cond_d_vector_in_each_upsampling_layer": true,
-  "d_vector_dim": 512,
-  "decoder_input_dim": 1024,
-  "input_sample_rate": 22050,
-  "hifi_model_type": "xtts_hifigan",
-  "output_hop_length": 256,
-  "output_sample_rate": 24000,
-  "resblock_dilation_sizes": [
-    [
-      1,
-      3,
-      5
-    ],
-    [
-      1,
-      3,
-      5
-    ],
-    [
-      1,
-      3,
-      5
-    ]
-  ],
-  "resblock_kernel_sizes": [
-    3,
-    7,
-    11
-  ],
-  "speaker_encoder_config": {
-    "model_config": null,
-    "model_name": "speaker_encoder",
-    "preprocess_config": null,
-    "speaker_embedding_dim": 512,
-    "use_torch_spec": true
-  },
-  "upsample_initial_channel": 512,
-  "upsample_kernel_sizes": [
-    16,
-    16,
-    4,
-    4
-  ],
-  "upsample_rates": [
-    8,
-    8,
-    2,
-    2
-  ]
+  "start_audio_token": 1024,
+  "start_text_token": null,
+  "stop_audio_token": 1025,
+  "stop_text_token": null,
+  "temperature": 0.75,
+  "top_p": 0.85,
+  "train_solo_embeddings": false,
+  "transformers_version": "4.46.0",
+  "use_masking_gt_prompt_approach": true,
+  "use_perceiver_resampler": true,
+  "vocab_size": 256
 }
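
Read as a whole, this change strips the HiFi-GAN decoder and speaker-encoder fields out of the GPT config and renames every "gpt_"-prefixed key to its GPT-2-style equivalent (gpt_n_model_channels becomes hidden_size, gpt_layers becomes num_hidden_layers, and so on), while adding GPT-2 fields (layer_norm_epsilon, max_position_embeddings, the dropout rates) and generation defaults (temperature, top_p, repetition_penalty). A minimal sketch of the rename, taken directly from the keys visible in this diff, is below; the migrate_config helper itself is hypothetical and not part of the commit.

# Hypothetical helper (not part of this commit): rename old "gpt_*" keys from a
# pre-change config.json to the new GPT-2 style names introduced here.
OLD_TO_NEW = {
    "gpt_n_model_channels": "hidden_size",
    "gpt_layers": "num_hidden_layers",
    "gpt_n_heads": "num_attention_heads",
    "gpt_batch_size": "batch_size",
    "gpt_max_audio_tokens": "max_audio_tokens",
    "gpt_max_text_tokens": "max_text_tokens",
    "gpt_max_prompt_tokens": "max_prompt_tokens",
    "gpt_number_text_tokens": "number_text_tokens",
    "gpt_start_text_token": "start_text_token",
    "gpt_stop_text_token": "stop_text_token",
    "gpt_num_audio_tokens": "num_audio_tokens",
    "gpt_start_audio_token": "start_audio_token",
    "gpt_stop_audio_token": "stop_audio_token",
    "gpt_code_stride_len": "code_stride_len",
    "gpt_cond_len": "cond_len",
    "gpt_cond_chunk_len": "cond_chunk_len",
    "gpt_use_masking_gt_prompt_approach": "use_masking_gt_prompt_approach",
    "gpt_use_perceiver_resampler": "use_perceiver_resampler",
    "gpt_checkpointing": "checkpointing",
    "gpt_train_solo_embeddings": "train_solo_embeddings",
}

def migrate_config(old: dict) -> dict:
    """Return a copy of an old-style config dict with the renamed keys applied."""
    return {OLD_TO_NEW.get(k, k): v for k, v in old.items()}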
gpt_config.py CHANGED
@@ -29,27 +29,34 @@ class XTTSGPTConfig(PretrainedConfig):
         self,
         # Model architecture
         vocab_size: int = 256,
+        hidden_size: int = 1024,  # Changed from gpt_n_model_channels
+        num_hidden_layers: int = 30,  # Changed from gpt_layers
+        num_attention_heads: int = 16,  # Changed from gpt_n_heads
+        n_inner: Optional[int] = None,  # Added for GPT-2 compatibility
+        max_position_embeddings: int = 2048,  # Added for positional embeddings
+        layer_norm_epsilon: float = 1e-5,  # Added for layer norm
+        activation_function: str = "gelu",  # Added activation function
+        resid_pdrop: float = 0.1,  # Added dropout rates
+        embd_pdrop: float = 0.1,
+        attn_pdrop: float = 0.1,
+
+        # Specific XTTS parameters
         num_chars: int = 255,
-
-        # GPT parameters
-        gpt_batch_size: int = 1,
-        gpt_max_audio_tokens: int = 605,
-        gpt_max_text_tokens: int = 402,
-        gpt_max_prompt_tokens: int = 70,
-        gpt_layers: int = 30,
-        gpt_n_model_channels: int = 1024,
-        gpt_n_heads: int = 16,
-        gpt_number_text_tokens: int = 6681,
-        gpt_start_text_token: Optional[int] = None,
-        gpt_stop_text_token: Optional[int] = None,
-        gpt_num_audio_tokens: int = 1026,
-        gpt_start_audio_token: int = 1024,
-        gpt_stop_audio_token: int = 1025,
-        gpt_code_stride_len: int = 1024,
-        gpt_use_masking_gt_prompt_approach: bool = True,
-        gpt_use_perceiver_resampler: bool = True,
-        gpt_checkpointing: bool = False,
-        gpt_train_solo_embeddings: bool = False,
+        batch_size: int = 1,  # Changed from gpt_batch_size
+        max_audio_tokens: int = 605,  # Changed from gpt_max_audio_tokens
+        max_text_tokens: int = 402,  # Changed from gpt_max_text_tokens
+        max_prompt_tokens: int = 70,  # Changed from gpt_max_prompt_tokens
+        number_text_tokens: int = 6681,  # Changed from gpt_number_text_tokens
+        start_text_token: Optional[int] = None,  # Changed from gpt_start_text_token
+        stop_text_token: Optional[int] = None,  # Changed from gpt_stop_text_token
+        num_audio_tokens: int = 1026,  # Changed from gpt_num_audio_tokens
+        start_audio_token: int = 1024,  # Changed from gpt_start_audio_token
+        stop_audio_token: int = 1025,  # Changed from gpt_stop_audio_token
+        code_stride_len: int = 1024,  # Changed from gpt_code_stride_len
+        use_masking_gt_prompt_approach: bool = True,  # Changed from gpt_use_masking_gt_prompt_approach
+        use_perceiver_resampler: bool = True,  # Changed from gpt_use_perceiver_resampler
+        checkpointing: bool = False,  # Changed from gpt_checkpointing
+        train_solo_embeddings: bool = False,  # Changed from gpt_train_solo_embeddings
 
         # Training parameters
         enable_redaction: bool = False,
@@ -58,13 +65,13 @@ class XTTSGPTConfig(PretrainedConfig):
         label_smoothing: float = 0.0,
 
         # Generation parameters
-        #temperature: float = 0.75, will trow a warning
-        #length_penalty: float = 1.0,
-        #repetition_penalty: float = 5.0,
-        #top_k: int = 50,
-        #top_p: float = 0.85,
-        gpt_cond_len: int = 30,
-        gpt_cond_chunk_len: int = 4,
+        temperature: float = 0.75,
+        length_penalty: float = 1.0,
+        repetition_penalty: float = 5.0,
+        top_k: int = 50,
+        top_p: float = 0.85,
+        cond_len: int = 30,  # Changed from gpt_cond_len
+        cond_chunk_len: int = 4,  # Changed from gpt_cond_chunk_len
         max_ref_len: int = 30,
         sound_norm_refs: bool = False,
 
@@ -78,6 +85,12 @@ class XTTSGPTConfig(PretrainedConfig):
         pad_token_id: Optional[int] = None,
         bos_token_id: Optional[int] = None,
         eos_token_id: Optional[int] = None,
+
+        # GPT-2 compatibility flags
+        scale_attn_by_inverse_layer_idx: bool = False,
+        reorder_and_upcast_attn: bool = False,
+        add_cross_attention: bool = False,
+        tie_word_embeddings: bool = True,
         **kwargs,
     ):
         if char_limits is None:
@@ -105,27 +118,34 @@ class XTTSGPTConfig(PretrainedConfig):
         )
 
         self.vocab_size = vocab_size
+        self.hidden_size = hidden_size
+        self.num_hidden_layers = num_hidden_layers
+        self.num_attention_heads = num_attention_heads
+        self.n_inner = n_inner
+        self.max_position_embeddings = max_position_embeddings
+        self.layer_norm_epsilon = layer_norm_epsilon
+        self.activation_function = activation_function
+        self.resid_pdrop = resid_pdrop
+        self.embd_pdrop = embd_pdrop
+        self.attn_pdrop = attn_pdrop
+
+        # XTTS specific parameters
         self.num_chars = num_chars
-
-        # GPT parameters
-        self.gpt_batch_size = gpt_batch_size
-        self.gpt_max_audio_tokens = gpt_max_audio_tokens
-        self.gpt_max_text_tokens = gpt_max_text_tokens
-        self.gpt_max_prompt_tokens = gpt_max_prompt_tokens
-        self.gpt_layers = gpt_layers
-        self.gpt_n_model_channels = gpt_n_model_channels
-        self.gpt_n_heads = gpt_n_heads
-        self.gpt_number_text_tokens = gpt_number_text_tokens
-        self.gpt_start_text_token = gpt_start_text_token
-        self.gpt_stop_text_token = gpt_stop_text_token
-        self.gpt_num_audio_tokens = gpt_num_audio_tokens
-        self.gpt_start_audio_token = gpt_start_audio_token
-        self.gpt_stop_audio_token = gpt_stop_audio_token
-        self.gpt_code_stride_len = gpt_code_stride_len
-        self.gpt_use_masking_gt_prompt_approach = gpt_use_masking_gt_prompt_approach
-        self.gpt_use_perceiver_resampler = gpt_use_perceiver_resampler
-        self.gpt_checkpointing = gpt_checkpointing
-        self.gpt_train_solo_embeddings = gpt_train_solo_embeddings
+        self.batch_size = batch_size
+        self.max_audio_tokens = max_audio_tokens
+        self.max_text_tokens = max_text_tokens
+        self.max_prompt_tokens = max_prompt_tokens
+        self.number_text_tokens = number_text_tokens
+        self.start_text_token = start_text_token
+        self.stop_text_token = stop_text_token
+        self.num_audio_tokens = num_audio_tokens
+        self.start_audio_token = start_audio_token
+        self.stop_audio_token = stop_audio_token
+        self.code_stride_len = code_stride_len
+        self.use_masking_gt_prompt_approach = use_masking_gt_prompt_approach
+        self.use_perceiver_resampler = use_perceiver_resampler
+        self.checkpointing = checkpointing
+        self.train_solo_embeddings = train_solo_embeddings
 
         # Training parameters
         self.enable_redaction = enable_redaction
@@ -134,8 +154,13 @@ class XTTSGPTConfig(PretrainedConfig):
         self.label_smoothing = label_smoothing
 
         # Generation parameters
-        self.gpt_cond_len = gpt_cond_len
-        self.gpt_cond_chunk_len = gpt_cond_chunk_len
+        self.temperature = temperature
+        self.length_penalty = length_penalty
+        self.repetition_penalty = repetition_penalty
+        self.top_k = top_k
+        self.top_p = top_p
+        self.cond_len = cond_len
+        self.cond_chunk_len = cond_chunk_len
         self.max_ref_len = max_ref_len
         self.sound_norm_refs = sound_norm_refs
 
@@ -147,6 +172,12 @@ class XTTSGPTConfig(PretrainedConfig):
         self.char_limits = char_limits
         self.languages = languages
 
+        # GPT-2 compatibility flags
+        self.scale_attn_by_inverse_layer_idx = scale_attn_by_inverse_layer_idx
+        self.reorder_and_upcast_attn = reorder_and_upcast_attn
+        self.add_cross_attention = add_cross_attention
+        self.tie_word_embeddings = tie_word_embeddings
+
     def to_dict(self):
         """Convert config to dictionary"""
         config_dict = super().to_dict()
@@ -154,14 +185,14 @@ class XTTSGPTConfig(PretrainedConfig):
         return config_dict
 
     @classmethod
-    def from_dict(cls, config_dict, *arg, **kwargs):
+    def from_dict(cls, config_dict, *args, **kwargs):
         """Create config from dictionary"""
         audio_config = XTTSAudioConfig(**config_dict.pop("audio_config", {}))
-        return cls(audio_config=audio_config, **config_dict)
+        return cls(audio_config=audio_config, **config_dict, **kwargs)
 
     def update_with_tokenizer(self, tokenizer=None):
         """Update configuration values based on tokenizer"""
         if tokenizer is not None:
-            self.gpt_number_text_tokens = tokenizer.get_vocab_size()
-            self.gpt_start_text_token = tokenizer.bos_token_id
-            self.gpt_stop_text_token = tokenizer.eos_token_id
+            self.number_text_tokens = tokenizer.get_vocab_size()
+            self.start_text_token = tokenizer.bos_token_id
+            self.stop_text_token = tokenizer.eos_token_id
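
With the renames applied, XTTSGPTConfig exposes standard GPT-2 attribute names, and from_dict now forwards extra **kwargs through to the constructor. Below is a small usage sketch; it assumes the constructor arguments not shown in this diff (audio config, char limits, languages) keep their defaults and only exercises arguments visible above.

# Usage sketch (assumes gpt_config.py from this repo is importable and that the
# arguments not shown in the diff keep their defaults).
from gpt_config import XTTSGPTConfig

cfg = XTTSGPTConfig(
    hidden_size=1024,         # was gpt_n_model_channels
    num_hidden_layers=30,     # was gpt_layers
    num_attention_heads=16,   # was gpt_n_heads
    number_text_tokens=6681,  # was gpt_number_text_tokens
    max_audio_tokens=605,     # was gpt_max_audio_tokens
    temperature=0.75,
    top_p=0.85,
)
assert cfg.hidden_size == 1024 and cfg.num_attention_heads == 16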
xtts2_gpt_modeling.py CHANGED
@@ -8,19 +8,20 @@ from torch.nn import functional as F
 from typing import List, Optional, Union, Iterable, Tuple, Mapping
 
 from transformers import PretrainedConfig
-from vllm.attention import AttentionMetadata
-from vllm.config import CacheConfig
-from vllm.distributed import get_pp_group
+from vllm.attention import AttentionMetadata, Attention
+from vllm.config import CacheConfig, MultiModalConfig
+from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
 from vllm.inputs import InputContext, INPUT_REGISTRY
-from vllm.model_executor.layers.linear import ColumnParallelLinear
+from vllm.model_executor.layers.activation import get_act_fn
+from vllm.model_executor.layers.linear import ColumnParallelLinear, QKVParallelLinear, RowParallelLinear
+from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
 from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
-from vllm.model_executor.models.gpt2 import GPT2Block
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalInputs
 from vllm.sequence import IntermediateTensors, SequenceData, VLLM_TOKEN_ID_ARRAY_TYPE
-from vllm.model_executor.models.interfaces import SupportsMultiModal
+from vllm.model_executor.models.interfaces import SupportsMultiModal, SupportsPP
 
 from TTS.tts.layers.xtts.latent_encoder import ConditioningEncoder  # noqa
 from TTS.tts.layers.xtts.perceiver_encoder import PerceiverResampler  # noqa
@@ -32,17 +33,147 @@ _AUDIO_PLACEHOLDER_TOKEN = 8192  # Using XTTS start_audio_token as placeholder
 _AUDIO_TOKENS_PER_SECOND = 6.25
 _CODE_STRIDE_LEN = 1024
 
+class GPT2Attention(nn.Module):
+    def __init__(
+        self,
+        config: PretrainedConfig,
+        cache_config: Optional[CacheConfig] = None,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ):
+        super().__init__()
+        total_num_heads = config.num_attention_heads
+        self.hidden_size = config.hidden_size
+        tensor_model_parallel_world_size = get_tensor_model_parallel_world_size()
+        assert total_num_heads % tensor_model_parallel_world_size == 0
+        self.num_heads = total_num_heads // tensor_model_parallel_world_size
+        self.head_dim = self.hidden_size // total_num_heads
+        self.scale = self.head_dim**-0.5
+
+        self.c_attn = QKVParallelLinear(
+            self.hidden_size,
+            self.head_dim,
+            total_num_heads,
+            bias=True,
+            quant_config=quant_config,
+            prefix=f"{prefix}.c_attn",
+        )
+        self.c_proj = RowParallelLinear(
+            self.hidden_size,
+            self.hidden_size,
+            bias=True,
+            quant_config=quant_config,
+            prefix=f"{prefix}.c_proj",
+        )
+        self.attn = Attention(
+            self.num_heads,
+            self.head_dim,
+            scale=self.scale,
+            cache_config=cache_config,
+            quant_config=quant_config
+        )
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        kv_cache: torch.Tensor,
+        attn_metadata: AttentionMetadata,
+    ) -> torch.Tensor:
+        qkv, _ = self.c_attn(hidden_states)
+        q, k, v = qkv.chunk(chunks=3, dim=-1)
+        attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
+        attn_output, _ = self.c_proj(attn_output)
+        return attn_output
+
+
+class GPT2MLP(nn.Module):
+    def __init__(
+        self,
+        intermediate_size: int,
+        config: PretrainedConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ):
+        super().__init__()
+        hidden_size = config.hidden_size
+
+        self.c_fc = ColumnParallelLinear(
+            hidden_size,
+            intermediate_size,
+            bias=True,
+            quant_config=quant_config,
+            prefix=f"{prefix}.c_fc",
+        )
+        self.c_proj = RowParallelLinear(
+            intermediate_size,
+            hidden_size,
+            bias=True,
+            quant_config=quant_config,
+            prefix=f"{prefix}.c_proj",
+        )
+        self.act = get_act_fn(config.activation_function, quant_config, intermediate_size)
+
+    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+        hidden_states, _ = self.c_fc(hidden_states)
+        hidden_states = self.act(hidden_states)
+        hidden_states, _ = self.c_proj(hidden_states)
+        return hidden_states
+
+
+class GPT2Block(nn.Module):
+    def __init__(
+        self,
+        config: PretrainedConfig,
+        cache_config: Optional[CacheConfig] = None,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ):
+        super().__init__()
+        hidden_size = config.hidden_size
+        inner_dim = config.n_inner if config.n_inner is not None else 4 * hidden_size
+
+        self.ln_1 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
+        self.attn = GPT2Attention(
+            config,
+            cache_config,
+            quant_config,
+            prefix=f"{prefix}.attn"
+        )
+        self.ln_2 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
+        self.mlp = GPT2MLP(
+            inner_dim,
+            config,
+            quant_config,
+            prefix=f"{prefix}.mlp"
+        )
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        kv_cache: torch.Tensor,
+        attn_metadata: AttentionMetadata,
+    ) -> torch.Tensor:
+        residual = hidden_states
+        hidden_states = self.ln_1(hidden_states)
+        attn_output = self.attn(
+            hidden_states=hidden_states,
+            kv_cache=kv_cache,
+            attn_metadata=attn_metadata,
+        )
+        hidden_states = attn_output + residual
+
+        residual = hidden_states
+        hidden_states = self.ln_2(hidden_states)
+        feed_forward_hidden_states = self.mlp(hidden_states)
+        hidden_states = residual + feed_forward_hidden_states
+        return hidden_states
+
+
 
 def get_xtts_max_audio_tokens(ctx: InputContext) -> int:
     """Calculate maximum audio tokens based on text context and audio duration."""
-    # Based on GPT config and common XTTS settings
-    text_context = ctx.model_config.max_seq_len - 100  # Reserve space for text
-    # Allow for ~30 seconds of audio (similar to whisper chunks)
-    max_audio_duration = 30.0
-    audio_tokens = math.ceil(max_audio_duration * _AUDIO_TOKENS_PER_SECOND)
-    total_tokens = text_context + audio_tokens + 4  # +4 for special tokens
-
-    return min(total_tokens, 1000)  # Cap at 1000 tokens as specified
+    # Based on GPT config and XTTSv2 settings
+    return 608
 
 
 def dummy_seq_data_for_xtts(
@@ -73,7 +204,7 @@ def dummy_conditioning_for_xtts(
 ) -> dict:
     """Create dummy conditioning data for XTTS."""
     return {
-        "cond_latents": [(torch.zeros(80, 1024), 22050) for _ in range(audio_count)]
+        "audio": [(torch.zeros(80, 1024), 22050) for _ in range(audio_count)]
     }
 
 
@@ -106,10 +237,11 @@ def input_mapper_for_xtts(ctx: InputContext, data: object) -> MultiModalInputs:
 @MULTIMODAL_REGISTRY.register_input_mapper("audio", input_mapper_for_xtts)
 @MULTIMODAL_REGISTRY.register_max_multimodal_tokens("audio", get_xtts_max_audio_tokens)
 @INPUT_REGISTRY.register_dummy_data(dummy_data_for_xtts)
-class XttsGPT(nn.Module, SupportsMultiModal):
+class XttsGPT(nn.Module, SupportsMultiModal, SupportsPP):
     def __init__(
         self,
         config: PretrainedConfig,
+        multimodal_config: MultiModalConfig,
         cache_config: Optional[CacheConfig] = None,
         quant_config: Optional["QuantizationConfig"] = None,
     ):
@@ -119,14 +251,16 @@ class XttsGPT(nn.Module, SupportsMultiModal):
 
         # XTTS specific components
        self.conditioning_encoder = ConditioningEncoder(
-            80, config.n_embd, num_attn_heads=config.n_head
+            config.audio_config.mel_channels,
+            config.hidden_size,
+            num_attn_heads=config.num_attention_heads
         )
 
         if config.use_perceiver_resampler:
             self.conditioning_perceiver = PerceiverResampler(
-                dim=config.n_embd,
+                dim=config.hidden_size,
                 depth=2,
-                dim_context=config.n_embd,
+                dim_context=config.hidden_size,
                 num_latents=32,
                 dim_head=64,
                 heads=8,
@@ -144,7 +278,7 @@ class XttsGPT(nn.Module, SupportsMultiModal):
 
         # Prediction heads
         self.text_head = ColumnParallelLinear(
-            config.n_embd,
+            config.hidden_size,
             config.vocab_size,
             bias=False,
             quant_config=quant_config,
@@ -152,7 +286,7 @@ class XttsGPT(nn.Module, SupportsMultiModal):
         )
 
         self.mel_head = ColumnParallelLinear(
-            config.n_embd,
+            config.hidden_size,
             config.num_audio_tokens,
             bias=False,
             quant_config=quant_config,
@@ -176,15 +310,9 @@ class XttsGPT(nn.Module, SupportsMultiModal):
         conds = cond_input.unsqueeze(1)
         return conds
 
-    def forward(
-        self,
-        input_ids: torch.Tensor,
-        positions: torch.Tensor,
-        kv_caches: List[torch.Tensor],
-        attn_metadata: AttentionMetadata,
-        intermediate_tensors: Optional[IntermediateTensors] = None,
-        cond_latents: Optional[torch.Tensor] = None,
-    ) -> torch.Tensor:
+    def forward(self, input_ids: torch.Tensor, positions: torch.Tensor, kv_caches: List[torch.Tensor],
+                attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None,
+                cond_latents: Optional[torch.Tensor] = None) -> torch.Tensor:
         """Forward pass following VLLM pattern."""
         if cond_latents is not None:
             # Combine conditioning with input embeddings
@@ -250,25 +378,39 @@ class XttsGPT2Model(nn.Module):
         self,
         config: PretrainedConfig,
         cache_config: Optional[CacheConfig] = None,
-        quant_config: Optional["QuantizationConfig"] = None,
+        quant_config: Optional[QuantizationConfig] = None,
         prefix: str = "",
     ):
         super().__init__()
         self.config = config
-        self.text_embedding = VocabParallelEmbedding(config.number_text_tokens, config.n_embd)
-        self.mel_embedding = VocabParallelEmbedding(config.num_audio_tokens, config.n_embd)
+
+        self.text_embedding = VocabParallelEmbedding(
+            config.number_text_tokens,
+            config.hidden_size
+        )
+        self.mel_embedding = VocabParallelEmbedding(
+            config.num_audio_tokens,
+            config.hidden_size
+        )
 
         self.text_pos_embedding = (
-            LearnedPositionEmbeddings(config.max_text_seq_len, config.n_embd)
-            if config.max_mel_seq_len != -1
-            else functools.partial(config.null_position_embeddings, dim=config.n_embd)
+            LearnedPositionEmbeddings(
+                config.max_text_tokens + 2,
+                config.hidden_size
+            )
+            if config.max_audio_tokens != -1
+            else functools.partial(config.null_position_embeddings, dim=config.hidden_size)
         )
+
         self.mel_pos_embedding = (
-            LearnedPositionEmbeddings(config.max_mel_seq_len, config.n_embd)
-            if config.max_mel_seq_len != -1
-            else functools.partial(config.null_position_embeddings, dim=config.n_embd)
+            LearnedPositionEmbeddings(
+                config.max_audio_tokens + 3,
+                config.hidden_size
+            )
+            if config.max_audio_tokens != -1
+            else functools.partial(config.null_position_embeddings, dim=config.hidden_size)
         )
-        # Build gpt blocks
+
         self.h = nn.ModuleList([
             GPT2Block(
                 config,
@@ -278,32 +420,38 @@ class XttsGPT2Model(nn.Module):
             ) for i in range(config.num_hidden_layers)
         ])
 
-        self.final_norm = nn.LayerNorm(
-            config.n_embd,
-            eps=config.layer_norm_epsilon
-        )
+        self.ln_f = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_epsilon)
+
+    def get_input_embeddings(self):
+        return self.text_embedding
 
-    def forward(  # TODO: this is not correct, allieeate it with the correct implementation
-        self,
-        input_ids: torch.Tensor,
-        position_ids: torch.Tensor,
-        kv_caches: List[torch.Tensor],
-        attn_metadata: AttentionMetadata,
-        intermediate_tensors: Optional[IntermediateTensors],
+    def forward(
+        self,
+        input_ids: Optional[torch.Tensor] = None,
+        positions: Optional[torch.Tensor] = None,
+        inputs_embeds: Optional[torch.Tensor] = None,
+        kv_caches: List[torch.Tensor] = None,
+        attn_metadata: AttentionMetadata = None,
+        intermediate_tensors: Optional[IntermediateTensors] = None,
     ) -> Union[torch.Tensor, IntermediateTensors]:
         if get_pp_group().is_first_rank:
-            inputs_embeds = self.wte(input_ids)
-            position_embeds = self.wpe(position_ids)
-            hidden_states = inputs_embeds + position_embeds
+            if inputs_embeds is None:
+                inputs_embeds = self.text_embedding(input_ids)
+            hidden_states = inputs_embeds
+
+            if positions is not None:
+                position_embeds = self.text_pos_embedding(positions)
+                hidden_states = hidden_states + position_embeds
         else:
             assert intermediate_tensors is not None
            hidden_states = intermediate_tensors["hidden_states"]
 
-        for i in range(self.start_layer, self.end_layer):
-            layer = self.h[i]
-            hidden_states = layer(hidden_states,
-                                  kv_caches[i - self.start_layer],
-                                  attn_metadata)
+        for i, block in enumerate(self.h):
+            hidden_states = block(
+                hidden_states,
+                kv_caches[i],
+                attn_metadata
+            )
 
         if not get_pp_group().is_last_rank:
             return IntermediateTensors({"hidden_states": hidden_states})
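
Instead of importing vllm's GPT2Block, the file now defines its own GPT2Attention / GPT2MLP / GPT2Block on top of QKVParallelLinear, RowParallelLinear and the paged Attention layer: a fused QKV projection is split into three equal chunks, and each block uses pre-norm residual wiring (ln_1 -> attention -> add, then ln_2 -> MLP -> add). The sketch below is a plain-PyTorch approximation of that per-block data flow; it uses ordinary nn.Linear and scaled_dot_product_attention as stand-ins, carries no KV cache or tensor parallelism, and is not the vLLM implementation above.

# Plain-PyTorch sketch (no vLLM) of the GPT2Block data flow in this commit:
# fused QKV projection split into q/k/v chunks, causal attention, and
# pre-norm residual connections around attention and MLP.
import torch
import torch.nn as nn
import torch.nn.functional as F

hidden_size, num_heads, seq_len = 1024, 16, 8
head_dim = hidden_size // num_heads

c_attn = nn.Linear(hidden_size, 3 * hidden_size)   # stand-in for QKVParallelLinear
c_proj = nn.Linear(hidden_size, hidden_size)        # stand-in for RowParallelLinear
ln_1 = nn.LayerNorm(hidden_size, eps=1e-5)
ln_2 = nn.LayerNorm(hidden_size, eps=1e-5)
mlp = nn.Sequential(nn.Linear(hidden_size, 4 * hidden_size), nn.GELU(),
                    nn.Linear(4 * hidden_size, hidden_size))

x = torch.randn(seq_len, hidden_size)

# Attention sub-block (pre-norm + residual)
residual = x
h = ln_1(x)
q, k, v = c_attn(h).chunk(3, dim=-1)                # same split as qkv.chunk(chunks=3)
q = q.view(seq_len, num_heads, head_dim).transpose(0, 1)
k = k.view(seq_len, num_heads, head_dim).transpose(0, 1)
v = v.view(seq_len, num_heads, head_dim).transpose(0, 1)
attn = F.scaled_dot_product_attention(q, k, v, is_causal=True)
attn = attn.transpose(0, 1).reshape(seq_len, hidden_size)
x = residual + c_proj(attn)

# MLP sub-block (pre-norm + residual)
x = x + mlp(ln_2(x))
print(x.shape)  # torch.Size([8, 1024])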