"""vLLM-compatible implementation of the XTTS GPT model."""

import functools
import math
from array import array
from typing import Iterable, List, Mapping, Optional, Tuple, Union

import torch
import torch.nn as nn
from torch.nn import functional as F

from transformers import PretrainedConfig

from vllm.attention import AttentionMetadata
from vllm.config import CacheConfig
from vllm.distributed import get_pp_group
from vllm.inputs import INPUT_REGISTRY, InputContext
from vllm.model_executor.layers.linear import ColumnParallelLinear
from vllm.model_executor.layers.quantization.base_config import (
    QuantizationConfig)
from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
from vllm.model_executor.layers.vocab_parallel_embedding import (
    VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.gpt2 import GPT2Block
from vllm.model_executor.models.interfaces import SupportsMultiModal
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalInputs
from vllm.sequence import (IntermediateTensors, SequenceData,
                           VLLM_TOKEN_ID_ARRAY_TYPE)

from TTS.tts.layers.xtts.gpt import LearnedPositionEmbeddings
from TTS.tts.layers.xtts.latent_encoder import ConditioningEncoder
from TTS.tts.layers.xtts.perceiver_encoder import PerceiverResampler

# Token id used as a placeholder for audio positions in the prompt.
_AUDIO_PLACEHOLDER_TOKEN = 8192
# Audio tokens emitted per second of audio.
_AUDIO_TOKENS_PER_SECOND = 6.25
_CODE_STRIDE_LEN = 1024


def get_xtts_max_audio_tokens(ctx: InputContext) -> int:
    """Calculate the maximum number of audio tokens from the text context and audio duration."""
    text_context = ctx.model_config.max_model_len - 100

    max_audio_duration = 30.0
    audio_tokens = math.ceil(max_audio_duration * _AUDIO_TOKENS_PER_SECOND)
    total_tokens = text_context + audio_tokens + 4

    return min(total_tokens, 1000)


def dummy_seq_data_for_xtts(
    ctx: InputContext,
    seq_len: int,
    audio_count: int,
) -> SequenceData:
    """Create dummy sequence data for XTTS profiling."""
    # Assume roughly 5 seconds of audio per conditioning sample.
    audio_len_tokens = math.ceil(_AUDIO_TOKENS_PER_SECOND * 5)
    audio_placeholder = array(
        VLLM_TOKEN_ID_ARRAY_TYPE,
        [_AUDIO_PLACEHOLDER_TOKEN],
    ) * audio_len_tokens

    # One placeholder block per audio input, each followed by a 0 token.
    audio_token_ids = (audio_placeholder
                       + array(VLLM_TOKEN_ID_ARRAY_TYPE, [0])) * audio_count

    # Pad the remainder of the sequence with zeros.
    other_token_ids = array(VLLM_TOKEN_ID_ARRAY_TYPE,
                            [0]) * (seq_len - len(audio_token_ids))

    return SequenceData(audio_token_ids + other_token_ids)


def dummy_conditioning_for_xtts(
    ctx: InputContext,
    audio_count: int,
) -> dict:
    """Create dummy conditioning data for XTTS."""
    return {
        "cond_latents": [
            (torch.zeros(80, 1024), 22050) for _ in range(audio_count)
        ]
    }


def dummy_data_for_xtts(
    ctx: InputContext,
    seq_len: int,
    mm_counts: Mapping[str, int],
) -> Tuple[SequenceData, dict]:
    """Create complete dummy data for XTTS profiling."""
    audio_count = mm_counts["audio"]
    seq_data = dummy_seq_data_for_xtts(ctx, seq_len, audio_count)
    cond_data = dummy_conditioning_for_xtts(ctx, audio_count)
    return seq_data, cond_data


def input_mapper_for_xtts(ctx: InputContext, data: object) -> MultiModalInputs:
    """Map input data to the XTTS multimodal format."""
    if not isinstance(data, list):
        data = [data]

    for audio_input in data:
        if not isinstance(audio_input, tuple):
            raise NotImplementedError(
                f"Unsupported data type: {type(audio_input)}")

    return MultiModalInputs({"cond_latents": data})
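

# Register the audio input mapper, the maximum audio-token budget, and the
# dummy-data generator with vLLM so the engine can profile memory usage and
# route "audio" multimodal inputs to this model.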
@MULTIMODAL_REGISTRY.register_input_mapper("audio", input_mapper_for_xtts)
@MULTIMODAL_REGISTRY.register_max_multimodal_tokens(
    "audio", get_xtts_max_audio_tokens)
@INPUT_REGISTRY.register_dummy_data(dummy_data_for_xtts)
class XttsGPT(nn.Module, SupportsMultiModal):

    def __init__(
        self,
        config: PretrainedConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        super().__init__()
        self.config = config
        self.quant_config = quant_config

        # Conditioning encoder that turns reference mel spectrograms into
        # speaker/style embeddings.
        self.conditioning_encoder = ConditioningEncoder(
            80, config.n_embd, num_attn_heads=config.n_head
        )

        if config.use_perceiver_resampler:
            self.conditioning_perceiver = PerceiverResampler(
                dim=config.n_embd,
                depth=2,
                dim_context=config.n_embd,
                num_latents=32,
                dim_head=64,
                heads=8,
                ff_mult=4,
                use_flash_attn=False,
            )

        # Core GPT-2 decoder built from vLLM layers.
        self.gpt = XttsGPT2Model(
            config,
            cache_config,
            quant_config,
            prefix="gpt",
        )

        # Output heads: text vocabulary logits and audio (mel code) logits.
        self.text_head = ColumnParallelLinear(
            config.n_embd,
            config.vocab_size,
            bias=False,
            quant_config=quant_config,
            prefix="text_head",
        )
        self.mel_head = ColumnParallelLinear(
            config.n_embd,
            config.num_audio_tokens,
            bias=False,
            quant_config=quant_config,
            prefix="mel_head",
        )

        self.sampler = Sampler()

    def get_style_emb(self, cond_input: torch.Tensor,
                      return_latent: bool = False) -> torch.Tensor:
        """Get conditioning embeddings from mel spectrograms."""
        if not return_latent:
            # (b, 1, 80, s) -> (b, 80, s)
            if cond_input.ndim == 4:
                cond_input = cond_input.squeeze(1)
            conds = self.conditioning_encoder(cond_input)

            if hasattr(self, "conditioning_perceiver"):
                # (b, d, s) -> (b, 32, d) -> (b, d, 32)
                conds = self.conditioning_perceiver(
                    conds.permute(0, 2, 1)
                ).transpose(1, 2)
        else:
            conds = cond_input.unsqueeze(1)
        return conds

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        cond_latents: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Forward pass following the vLLM pattern."""
        if cond_latents is not None:
            # Prepend the conditioning latents to the token embeddings
            # (assumes the prompt token ids index the text embedding table).
            input_embeds = self.gpt.text_embedding(input_ids)
            combined_embeds = torch.cat([cond_latents, input_embeds], dim=1)
            hidden_states = self.gpt(
                input_ids=None,
                positions=positions,
                kv_caches=kv_caches,
                attn_metadata=attn_metadata,
                intermediate_tensors=intermediate_tensors,
                inputs_embeds=combined_embeds,
            )
        else:
            hidden_states = self.gpt(
                input_ids=input_ids,
                positions=positions,
                kv_caches=kv_caches,
                attn_metadata=attn_metadata,
                intermediate_tensors=intermediate_tensors,
            )
        return hidden_states

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> torch.Tensor:
        """Compute logits over the concatenated text and audio vocabularies."""
        selected = hidden_states[sampling_metadata.selected_token_indices]
        # ColumnParallelLinear returns (output, output_bias); keep the output.
        text_logits, _ = self.text_head(selected)
        mel_logits, _ = self.mel_head(selected)
        return torch.cat([text_logits, mel_logits], dim=1)

    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        """Sample next tokens using the vLLM sampler."""
        return self.sampler(logits, sampling_metadata)

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        """Load weights following the vLLM pattern."""
        params_dict = dict(self.named_parameters(remove_duplicate=False))

        for name, loaded_weight in weights:
            if name not in params_dict:
                continue

            param = params_dict[name]
            # HF GPT-2 checkpoints store these projections as Conv1D modules,
            # so their weights are transposed relative to nn.Linear.
            if "c_attn" in name or "c_proj" in name or "c_fc" in name:
                if name.endswith(".weight"):
                    loaded_weight = loaded_weight.t()

            weight_loader = getattr(param, "weight_loader",
                                    default_weight_loader)
            weight_loader(param, loaded_weight)


class XttsGPT2Model(nn.Module):
    """vLLM-style implementation of the GPT-2 core used by XTTS, with
    separate text and mel (audio code) embedding tables."""

    def __init__(
        self,
        config: PretrainedConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        super().__init__()
        self.config = config
        self.text_embedding = VocabParallelEmbedding(
            config.number_text_tokens, config.n_embd)
        self.mel_embedding = VocabParallelEmbedding(
            config.num_audio_tokens, config.n_embd)

        self.text_pos_embedding = (
            LearnedPositionEmbeddings(config.max_text_seq_len, config.n_embd)
            if config.max_mel_seq_len != -1
            else functools.partial(config.null_position_embeddings,
                                   dim=config.n_embd)
        )
        self.mel_pos_embedding = (
            LearnedPositionEmbeddings(config.max_mel_seq_len, config.n_embd)
            if config.max_mel_seq_len != -1
            else functools.partial(config.null_position_embeddings,
                                   dim=config.n_embd)
        )

        self.h = nn.ModuleList([
            GPT2Block(
                config,
                cache_config,
                quant_config,
                prefix=f"{prefix}.h.{i}",
            ) for i in range(config.num_hidden_layers)
        ])

        self.final_norm = nn.LayerNorm(
            config.n_embd,
            eps=config.layer_norm_epsilon,
        )

    def forward(
        self,
        input_ids: Optional[torch.Tensor],
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors],
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        if get_pp_group().is_first_rank:
            if inputs_embeds is None:
                # Assumption: bare token ids index the text embedding table;
                # audio conditioning arrives pre-embedded via `inputs_embeds`.
                inputs_embeds = self.text_embedding(input_ids)
            # vLLM supplies flat position ids, so index the learned position
            # table directly (assumes the LearnedPositionEmbeddings path).
            position_embeds = self.text_pos_embedding.emb(positions)
            hidden_states = inputs_embeds + position_embeds
        else:
            assert intermediate_tensors is not None
            hidden_states = intermediate_tensors["hidden_states"]

        for i, layer in enumerate(self.h):
            hidden_states = layer(hidden_states,
                                  kv_caches[i],
                                  attn_metadata)

        if not get_pp_group().is_last_rank:
            return IntermediateTensors({"hidden_states": hidden_states})

        hidden_states = self.final_norm(hidden_states)
        return hidden_states
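

# Usage sketch (assumption, not part of the original module): the architecture
# name must be registered with vLLM before the engine is built so that a
# checkpoint whose config reports "XttsGPT" resolves to this class, e.g.:
#
#   from vllm import ModelRegistry
#   ModelRegistry.register_model("XttsGPT", XttsGPT)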