import functools
import math
from array import array
from typing import Iterable, List, Mapping, Optional, Tuple, Union

import torch
import torch.nn as nn
from torch.nn import functional as F
from transformers import PretrainedConfig

from vllm.attention import AttentionMetadata
from vllm.config import CacheConfig
from vllm.distributed import get_pp_group
from vllm.inputs import INPUT_REGISTRY, InputContext
from vllm.model_executor.layers.linear import ColumnParallelLinear
from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.gpt2 import GPT2Block
from vllm.model_executor.models.interfaces import SupportsMultiModal
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalInputs
from vllm.sequence import (IntermediateTensors, SequenceData,
                           VLLM_TOKEN_ID_ARRAY_TYPE)

from TTS.tts.layers.xtts.latent_encoder import ConditioningEncoder  # noqa
from TTS.tts.layers.xtts.perceiver_encoder import PerceiverResampler  # noqa
from TTS.tts.layers.xtts.gpt import LearnedPositionEmbeddings

# Constants for token calculation
_AUDIO_PLACEHOLDER_TOKEN = 8192  # Using the XTTS start_audio_token as placeholder
_AUDIO_TOKENS_PER_SECOND = 6.25
_CODE_STRIDE_LEN = 1024


def get_xtts_max_audio_tokens(ctx: InputContext) -> int:
    """Calculate the maximum number of audio tokens from the text context and audio duration."""
    # Based on the GPT config and common XTTS settings.
    text_context = ctx.model_config.max_model_len - 100  # Reserve space for text
    # Allow for ~30 seconds of audio (similar to whisper chunks).
    max_audio_duration = 30.0
    audio_tokens = math.ceil(max_audio_duration * _AUDIO_TOKENS_PER_SECOND)
    total_tokens = text_context + audio_tokens + 4  # +4 for special tokens
    return min(total_tokens, 1000)  # Cap at 1000 tokens as specified


def dummy_seq_data_for_xtts(
    ctx: InputContext,
    seq_len: int,
    audio_count: int,
) -> SequenceData:
    """Create dummy sequence data for XTTS profiling."""
    # Calculate the audio token space needed (assume 5 s per chunk).
    audio_len_tokens = math.ceil(_AUDIO_TOKENS_PER_SECOND * 5)
    audio_placeholder = array(
        VLLM_TOKEN_ID_ARRAY_TYPE,
        [_AUDIO_PLACEHOLDER_TOKEN]
    ) * audio_len_tokens

    # Add a separator between chunks.
    audio_token_ids = (audio_placeholder +
                       array(VLLM_TOKEN_ID_ARRAY_TYPE, [0])) * audio_count

    # Fill the remaining sequence with padding.
    other_token_ids = array(VLLM_TOKEN_ID_ARRAY_TYPE,
                            [0]) * (seq_len - len(audio_token_ids))

    return SequenceData(audio_token_ids + other_token_ids)


def dummy_conditioning_for_xtts(
    ctx: InputContext,
    audio_count: int,
) -> dict:
    """Create dummy conditioning data for XTTS."""
    return {
        "cond_latents": [(torch.zeros(80, 1024), 22050)
                         for _ in range(audio_count)]
    }


def dummy_data_for_xtts(
    ctx: InputContext,
    seq_len: int,
    mm_counts: Mapping[str, int],
) -> Tuple[SequenceData, dict]:
    """Create complete dummy data for XTTS profiling."""
    audio_count = mm_counts["audio"]
    seq_data = dummy_seq_data_for_xtts(ctx, seq_len, audio_count)
    cond_data = dummy_conditioning_for_xtts(ctx, audio_count)
    return (seq_data, cond_data)


def input_mapper_for_xtts(ctx: InputContext, data: object) -> MultiModalInputs:
    """Map raw multimodal input data to the XTTS format."""
    if not isinstance(data, list):
        data = [data]

    # Each item should be a (mel_spec, sample_rate) tuple.
    for audio_input in data:
        if not isinstance(audio_input, tuple):
            raise NotImplementedError(
                f"Unsupported data type: {type(audio_input)}")

    return MultiModalInputs({"cond_latents": data})
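
# Illustrative usage note (an assumption, not part of the upstream integration):
# the mapper above expects each multimodal "audio" item to already be a
# (mel_spectrogram, sample_rate) tuple, matching dummy_conditioning_for_xtts, e.g.
#
#     mel = torch.zeros(80, 1024)                      # (n_mels, frames)
#     mm_inputs = input_mapper_for_xtts(ctx, [(mel, 22050)])
#     assert "cond_latents" in mm_inputs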


@MULTIMODAL_REGISTRY.register_input_mapper("audio", input_mapper_for_xtts)
@MULTIMODAL_REGISTRY.register_max_multimodal_tokens("audio", get_xtts_max_audio_tokens)
@INPUT_REGISTRY.register_dummy_data(dummy_data_for_xtts)
class XttsGPT(nn.Module, SupportsMultiModal):

    def __init__(
        self,
        config: PretrainedConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        super().__init__()
        self.config = config
        self.quant_config = quant_config

        # XTTS-specific conditioning components.
        self.conditioning_encoder = ConditioningEncoder(
            80, config.n_embd, num_attn_heads=config.n_head
        )

        if config.use_perceiver_resampler:
            self.conditioning_perceiver = PerceiverResampler(
                dim=config.n_embd,
                depth=2,
                dim_context=config.n_embd,
                num_latents=32,
                dim_head=64,
                heads=8,
                ff_mult=4,
                use_flash_attn=False,
            )

        # Core GPT components following the vLLM pattern.
        self.gpt = XttsGPT2Model(
            config,
            cache_config,
            quant_config,
            prefix="gpt"
        )

        # Prediction heads.
        self.text_head = ColumnParallelLinear(
            config.n_embd,
            config.vocab_size,
            bias=False,
            quant_config=quant_config,
            prefix="text_head"
        )

        self.mel_head = ColumnParallelLinear(
            config.n_embd,
            config.num_audio_tokens,
            bias=False,
            quant_config=quant_config,
            prefix="mel_head"
        )

        self.sampler = Sampler()

    def get_style_emb(self, cond_input: torch.Tensor,
                      return_latent: bool = False) -> torch.Tensor:
        """Get conditioning embeddings from mel spectrograms."""
        if not return_latent:
            if cond_input.ndim == 4:
                cond_input = cond_input.squeeze(1)
            conds = self.conditioning_encoder(cond_input)

            if hasattr(self, 'conditioning_perceiver'):
                conds = self.conditioning_perceiver(
                    conds.permute(0, 2, 1)
                ).transpose(1, 2)
        else:
            conds = cond_input.unsqueeze(1)
        return conds

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        cond_latents: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Forward pass following the vLLM pattern."""
        if cond_latents is not None:
            # Prepend the conditioning latents to the input token embeddings.
            input_embeds = self.gpt.text_embedding(input_ids)
            combined_embeds = torch.cat([cond_latents, input_embeds], dim=1)
            hidden_states = self.gpt(
                inputs_embeds=combined_embeds,
                positions=positions,
                kv_caches=kv_caches,
                attn_metadata=attn_metadata,
                intermediate_tensors=intermediate_tensors,
            )
        else:
            hidden_states = self.gpt(
                input_ids=input_ids,
                positions=positions,
                kv_caches=kv_caches,
                attn_metadata=attn_metadata,
                intermediate_tensors=intermediate_tensors,
            )
        return hidden_states

    def compute_logits(  # not strictly needed, kept for compatibility
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> torch.Tensor:
        """Compute combined text and mel output logits."""
        selected = hidden_states[sampling_metadata.selected_token_indices]
        # ColumnParallelLinear returns (output, output_bias); keep only the output.
        text_logits, _ = self.text_head(selected)
        mel_logits, _ = self.mel_head(selected)
        return torch.cat([text_logits, mel_logits], dim=1)

    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        """Sample the next tokens using the vLLM sampler."""
        return self.sampler(logits, sampling_metadata)

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        """Load weights following the vLLM pattern."""
        params_dict = dict(self.named_parameters(remove_duplicate=False))

        for name, loaded_weight in weights:
            if name not in params_dict:
                continue

            param = params_dict[name]
            if "c_attn" in name or "c_proj" in name or "c_fc" in name:
                # GPT-2 Conv1D weights are stored transposed relative to nn.Linear.
                if name.endswith(".weight"):
                    loaded_weight = loaded_weight.t()

            weight_loader = getattr(param, "weight_loader",
                                    default_weight_loader)
            weight_loader(param, loaded_weight)
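
# The combined logits produced by XttsGPT.compute_logits place the text head's
# columns first and the mel head's columns after them. The helper below is an
# illustrative sketch (not part of the upstream XTTS/vLLM code; the name
# `split_combined_token_id` is hypothetical) showing how an index sampled from
# that concatenated space maps back to its own codebook.
def split_combined_token_id(token_id: int, text_vocab_size: int) -> Tuple[str, int]:
    """Map an index from the concatenated [text | mel] logits back to its codebook."""
    if token_id < text_vocab_size:
        return "text", token_id
    return "mel", token_id - text_vocab_size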


class XttsGPT2Model(nn.Module):
    """vLLM-style implementation of the GPT-2 core architecture used by XTTS."""

    def __init__(
        self,
        config: PretrainedConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        super().__init__()
        self.config = config

        self.text_embedding = VocabParallelEmbedding(
            config.number_text_tokens, config.n_embd)
        self.mel_embedding = VocabParallelEmbedding(
            config.num_audio_tokens, config.n_embd)

        self.text_pos_embedding = (
            LearnedPositionEmbeddings(config.max_text_seq_len, config.n_embd)
            if config.max_mel_seq_len != -1
            else functools.partial(config.null_position_embeddings,
                                   dim=config.n_embd)
        )
        self.mel_pos_embedding = (
            LearnedPositionEmbeddings(config.max_mel_seq_len, config.n_embd)
            if config.max_mel_seq_len != -1
            else functools.partial(config.null_position_embeddings,
                                   dim=config.n_embd)
        )

        # Build the GPT-2 blocks.
        self.h = nn.ModuleList([
            GPT2Block(
                config,
                cache_config,
                quant_config,
                prefix=f"{prefix}.h.{i}"
            )
            for i in range(config.num_hidden_layers)
        ])

        self.final_norm = nn.LayerNorm(
            config.n_embd,
            eps=config.layer_norm_epsilon
        )

    # TODO: this is not correct yet; align it with the XTTS text/mel embedding
    # split (LearnedPositionEmbeddings derives positions from the input shape
    # rather than taking position ids directly).
    def forward(
        self,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        input_ids: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        if get_pp_group().is_first_rank:
            if inputs_embeds is None:
                inputs_embeds = self.text_embedding(input_ids)
            position_embeds = self.text_pos_embedding(positions)
            hidden_states = inputs_embeds + position_embeds
        else:
            assert intermediate_tensors is not None
            hidden_states = intermediate_tensors["hidden_states"]

        for i, layer in enumerate(self.h):
            hidden_states = layer(hidden_states,
                                  kv_caches[i],
                                  attn_metadata)

        if not get_pp_group().is_last_rank:
            return IntermediateTensors({"hidden_states": hidden_states})

        hidden_states = self.final_norm(hidden_states)
        return hidden_states
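

if __name__ == "__main__":
    # Minimal smoke test of the profiling helpers above (an illustrative sketch,
    # not part of the upstream integration). Both helpers ignore their `ctx`
    # argument, so passing None is sufficient here.
    seq_data = dummy_seq_data_for_xtts(None, seq_len=256, audio_count=1)
    cond_data = dummy_conditioning_for_xtts(None, audio_count=1)
    mel, sample_rate = cond_data["cond_latents"][0]
    print(f"dummy sequence length: {seq_data.get_len()}")
    print(f"dummy conditioning mel shape: {tuple(mel.shape)}, sr: {sample_rate}")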