import math
from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor

from .nn_future import (FNNSwiGLU, MistralTransformer, ModelArgs,
                        RotatingBufferCache, SinePositionalEmbedding)
from .utils import construct_padding_mask, length_to_mask

LAYERNORM_EPS = 4e-5


# ------------------------
# Code adapted from OpenAI guided diffusion repo

def timestep_embedding(timesteps, dim, max_period=10000, dtype=torch.float32):
    """
    Create sinusoidal timestep embeddings.

    :param timesteps: a 1-D Tensor of N indices, one per batch element. These may be fractional.
    :param dim: the dimension of the output.
    :param max_period: controls the minimum frequency of the embeddings.
    :return: an [N x dim] Tensor of positional embeddings.
    """
    half = dim // 2
    freqs = torch.exp(
        -math.log(max_period) * torch.arange(start=0, end=half) / half
    ).to(device=timesteps.device)
    args = timesteps[:, None].float() * freqs[None]
    embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1).to(dtype)
    if dim % 2:
        embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
    return embedding
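
# Illustrative usage sketch (not part of the model code): for a hypothetical batch of four
# fractional diffusion timesteps and dim=1024, the call below would produce a (4, 1024)
# tensor of sinusoidal features, analogous to fixed positional encodings over time.
#
#   t = torch.tensor([0.0, 3.5, 17.0, 49.0])
#   t_feats = timestep_embedding(t, dim=1024)  # -> (4, 1024)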

# --------------------------------
# autoregressive codec language model

class CodecLM(nn.Module):

    def __init__(self, n_vocab, dim=1536, nhead=24, n_layers=26, n_spk_layers=2,
                 dim_ff_scale=None, sliding_window=3000) -> None:
        super().__init__()
        if dim_ff_scale is None:
            hidden_dim = int(dim*4*(3/4))
        else:
            hidden_dim = int(dim*dim_ff_scale)
        self.cfg = ModelArgs(n_vocab, dim=dim, n_layers=n_layers, n_heads=nhead, n_kv_heads=nhead,
                             hidden_dim=hidden_dim, sliding_window=sliding_window)
        self.ar = MistralTransformer(self.cfg)
        self.embed = nn.Embedding(n_vocab, dim)

        # --- spk embedding network
        dim_ff = int(dim*4*(3/4))
        self.pos_embedding = SinePositionalEmbedding(dim, scale=False, alpha=True)
        self.ref_chunked_emb = ChunkedEmbedding(1024 + 1, 8, dim)  # add 1 for pad idx
        self.spk_identity_emb = nn.Embedding(1, dim)
        # define custom speaker encoder
        encoder_layer = nn.TransformerEncoderLayer(dim, nhead, dim_ff, activation=FNNSwiGLU(dim, dim_ff),
                                                   dropout=0, batch_first=True, norm_first=True,
                                                   layer_norm_eps=LAYERNORM_EPS)
        encoder_layer.linear1 = nn.Identity()
        self.spk_encoder = nn.TransformerEncoder(encoder_layer, n_spk_layers,
                                                 norm=nn.LayerNorm(dim, eps=LAYERNORM_EPS))
        # monkeypatch for broken copy.deepcopy of nn.Modules in nn.TransformerDecoder
        for l in self.spk_encoder.layers:
            l.activation = FNNSwiGLU(dim, dim_ff)

    @torch.inference_mode()
    def get_spk_embedding(self, spk_reference, c_codes_lengths=None) -> Tensor:
        """ Gets speaker reference embeddings using `spk_reference` codes of shape (bs, seq_len, n_codebooks). """
        bs = spk_reference.shape[0]
        if bs != 1:
            raise AssertionError("Speaker embedding extraction only implemented for bs=1 currently.")
        spk_seq = self.ref_chunked_emb(spk_reference)  # (bs, sl, dim)
        spk_ref_emb = self.spk_identity_emb.weight[None].expand(bs, -1, -1)  # (bs, 1, dim)
        spk_seq = torch.cat([spk_ref_emb, spk_seq], dim=1)  # (bs, 1+sl, dim)
        # add pos encoding
        spk_seq = self.pos_embedding(spk_seq)
        # codebook goes from indices 0->1023, padding is idx 1024 (the 1025th entry)
        src_key_padding_mask = construct_padding_mask(spk_reference[:, :, 0], 1024)
        src_key_padding_mask = torch.cat((
            # prepend a zero here since we DO want to attend to the initial position.
            torch.zeros(src_key_padding_mask.shape[0], 1, dtype=torch.bool, device=src_key_padding_mask.device),
            src_key_padding_mask
        ), dim=1)
        # pass through transformer, select first element -> now (bs, 1, dim).
        res = self.spk_encoder(spk_seq, is_causal=False, src_key_padding_mask=src_key_padding_mask)[:, :1]
        return res.squeeze(1)

    def forward(self, x: Tensor, x_padding_mask: Optional[Tensor] = None,
                spk_reference: Optional[Tensor] = None,
                cache: Optional[RotatingBufferCache] = None, counter: int = 0) -> Tensor:
        """
        Inputs:
        - `x`: (bs, seq_len) codec token indices.
        - `x_padding_mask`: (bs, seq_len) mask for each input, True for positions to *ignore*,
          False otherwise. Note that since this is an autoregressive model, this mask is
          ignored at inference.
        - `spk_reference`: (bs, seq_len, n_codebooks) corresponding to the speaker reference to clone from.
        - `cache` and `counter`: used for kv caching, optional.

        Returns logits of shape (bs, seq_len, n_vocab).
        """
        x = self.embed(x)

        # --- speaker reference/embedding
        if spk_reference is not None:
            # compute ref
            bs = spk_reference.shape[0]
            spk_seq = self.ref_chunked_emb(spk_reference)  # (bs, sl, dim)
            spk_ref_emb = self.spk_identity_emb.weight[None].expand(bs, -1, -1)  # (bs, 1, dim)
            spk_seq = torch.cat([spk_ref_emb, spk_seq], dim=1)  # (bs, 1+sl, dim)
            # add pos encoding
            spk_seq = self.pos_embedding(spk_seq)
            # codebook goes from indices 0->1023, padding is idx 1024 (the 1025th entry)
            src_key_padding_mask = construct_padding_mask(spk_reference[:, :, 0], 1024)
            src_key_padding_mask = torch.cat((
                # prepend a zero here since we DO want to attend to the initial position.
                torch.zeros(src_key_padding_mask.shape[0], 1, dtype=torch.bool, device=src_key_padding_mask.device),
                src_key_padding_mask
            ), dim=1)
            # pass through transformer, select first element -> now (bs, 1, dim).
            res = self.spk_encoder(spk_seq, is_causal=False, src_key_padding_mask=src_key_padding_mask)[:, :1]
            x = torch.cat([res, x], dim=1)

        positions = torch.arange(0, x.shape[1], device=x.device, dtype=torch.long)

        if cache is not None and counter != 1:
            # using only the last token to predict the next one
            x = x[:, -1, :].unsqueeze(1)
            positions = positions[-1:]

        x = self.ar(x, positions, cache)  # (bs, seq_len, vocab)

        if spk_reference is not None and (cache is None or counter == 1):
            # strip out the first output token corresponding to the speaker embedding token.
            x = x[:, 1:]
        return x
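
# Usage sketch (illustrative only; the vocab size below assumes 1024 codec entries plus one
# pad index, mirroring ref_chunked_emb above, and is not fixed by this class):
#
#   lm = CodecLM(n_vocab=1025).eval()
#   tokens = torch.randint(0, 1024, (1, 250))    # (bs, seq_len) codec token indices
#   ref = torch.randint(0, 1024, (1, 500, 8))    # (bs, ref_len, n_codebooks) speaker reference
#   logits = lm(tokens, spk_reference=ref)       # (bs, seq_len, n_vocab), no kv cache
#   spk_emb = lm.get_spk_embedding(ref)          # (bs, dim) pooled speaker embedding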

# -------------------------
# residual discrete diffusion model

class ChunkedEmbedding(nn.Module):

    def __init__(self, codebook_size: int, n_quantizer: int, dim: int) -> None:
        super().__init__()
        assert dim % n_quantizer == 0, f"ChunkedEmbedding output dim ({dim}) must be divisible by n_quant {n_quantizer}"
        self.embs = nn.ModuleList([nn.Embedding(codebook_size, dim//n_quantizer) for _ in range(n_quantizer)])

    def forward(self, x: Tensor) -> Tensor:
        """ Embeds each codebook index in `x` (bs, seq_len, n_quantizer) to an embedding vector,
        concatenating the results. Returns output of shape (bs, seq_len, dim).
        """
        y = torch.cat([self.embs[i](x[..., i]) for i in range(x.shape[-1])], dim=-1)
        return y
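
# Illustrative sketch: a ChunkedEmbedding(1025, 8, 1024) embeds each of the 8 codebooks into
# 1024 // 8 = 128 dims and concatenates them, so a (bs, seq_len, 8) index tensor becomes a
# (bs, seq_len, 1024) feature tensor. The sizes here are assumptions for illustration only.
#
#   chunked = ChunkedEmbedding(codebook_size=1025, n_quantizer=8, dim=1024)
#   codes = torch.randint(0, 1024, (2, 75, 8))
#   feats = chunked(codes)                       # -> (2, 75, 1024)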

class ResidualTransformer(nn.Module):

    def __init__(self, n_text_vocab, n_quant=1024, dim=1024, nhead=16,
                 enc_layers=8, dec_layers=16, n_spk_layers=3,
                 c_quant_levels=8, pred_quant_levels=8,
                 t_emb_dim=1024, norm_first=True, p_cond_drop=0.1, dropout=0) -> None:
        super().__init__()
        self.cond_pos_embedding = SinePositionalEmbedding(dim, scale=False, alpha=True)
        self.pos_embedding = SinePositionalEmbedding(dim, scale=False, alpha=True)
        # *4 from the usual transformer heuristic, scaled back down (here by 3/4) since SwiGLU
        # has 3 linear matrices not 2, so we must keep the number of params roughly the same.
        dim_ff = int(dim*4*(3/4))

        # define custom encoder
        encoder_layer = nn.TransformerEncoderLayer(dim, nhead, dim_ff, activation=FNNSwiGLU(dim, dim_ff),
                                                   dropout=dropout, batch_first=True, norm_first=norm_first,
                                                   layer_norm_eps=LAYERNORM_EPS)
        encoder_layer.linear1 = nn.Identity()
        encoder = nn.TransformerEncoder(encoder_layer, enc_layers,
                                        norm=nn.LayerNorm(dim, eps=LAYERNORM_EPS) if norm_first else None)
        # define custom decoder
        decoder_layer = nn.TransformerDecoderLayer(dim, nhead, dim_ff, activation=FNNSwiGLU(dim, dim_ff),
                                                   dropout=dropout, batch_first=True, norm_first=norm_first,
                                                   layer_norm_eps=LAYERNORM_EPS)
        decoder_layer.linear1 = nn.Identity()
        decoder = nn.TransformerDecoder(decoder_layer, dec_layers,
                                        norm=nn.LayerNorm(dim, eps=LAYERNORM_EPS) if norm_first else None)
        # monkeypatch for broken copy.deepcopy of nn.Modules in nn.TransformerDecoder
        for l in decoder.layers:
            l.activation = FNNSwiGLU(dim, dim_ff)

        self.tfm = nn.Transformer(dim, nhead, dim_feedforward=dim_ff, batch_first=True, norm_first=norm_first,
                                  num_encoder_layers=enc_layers, num_decoder_layers=dec_layers,
                                  custom_encoder=encoder, custom_decoder=decoder,
                                  layer_norm_eps=LAYERNORM_EPS, dropout=dropout)

        # Timestep embedding network
        self.t_emb_dim = t_emb_dim
        self.timestep_encoder_emb = nn.Sequential(
            nn.Linear(t_emb_dim, dim),
            nn.SiLU(),
            nn.Linear(dim, dim)
        )
        self.timestep_decoder_emb = nn.Sequential(
            nn.Linear(t_emb_dim, dim),
            nn.SiLU(),
            nn.Linear(dim, dim)
        )

        self.text_embed = nn.Embedding(n_text_vocab, dim)

        ## ----> reference / conditioning encoder:
        self.ref_embedder = ChunkedEmbedding(n_quant, c_quant_levels, dim)
        self.ref_pos_embedding = SinePositionalEmbedding(dim, scale=False, alpha=True)
        self.spk_identity_emb = nn.Embedding(1, dim)
        spk_encoder_layer = nn.TransformerEncoderLayer(dim, nhead, dim_ff, activation=FNNSwiGLU(dim, dim_ff),
                                                       dropout=dropout, batch_first=True, norm_first=True,
                                                       layer_norm_eps=LAYERNORM_EPS)
        spk_encoder_layer.linear1 = nn.Identity()
        self.spk_encoder = nn.TransformerEncoder(spk_encoder_layer, n_spk_layers,
                                                 norm=nn.LayerNorm(dim, eps=LAYERNORM_EPS))
        # monkeypatch for broken copy.deepcopy of nn.Modules in nn.TransformerDecoder
        for l in self.spk_encoder.layers:
            l.activation = FNNSwiGLU(dim, dim_ff)
        # ----> end speaker encoder network

        # self.residual_encoder = nn.Embedding(n_quant, dim)  # only encode first quantization level of decoder input.
        self.residual_encoder = ChunkedEmbedding(n_quant, c_quant_levels, dim)
        self.residual_decoder = nn.ModuleList([
            nn.Sequential(
                nn.LayerNorm(dim),
                nn.Linear(dim, n_quant)
            ) for _ in range(pred_quant_levels)
        ])
        self.n_quantizer = pred_quant_levels
        self.p_cond_drop = p_cond_drop

    @torch.inference_mode()
    def get_spk_embedding(self, c_codes, c_codes_length) -> Tensor:
        """ Obtain speaker embedding vectors from the reference encodec sequences `c_codes`,
        where `c_codes_length` gives the length of each sequence.
        """
        bs = c_codes.shape[0]
        spk_seq = self.ref_embedder(c_codes)  # (bs, sl, dim)
        spk_ref_emb = self.spk_identity_emb.weight[None].expand(bs, -1, -1)  # (bs, 1, dim)
        spk_seq = torch.cat([spk_ref_emb, spk_seq], dim=1)  # (bs, 1+sl, dim)
        # add pos encoding
        spk_seq = self.ref_pos_embedding(spk_seq)
        # add 1 to c_codes_length to account for the fact that we concatenate the spk_ref_emb to it.
        src_key_padding_mask = length_to_mask(c_codes_length+1, torch.zeros_like(c_codes_length),
                                              max_len=spk_seq.shape[1])
        src_key_padding_mask = src_key_padding_mask.to(dtype=torch.bool, device=spk_seq.device)
        # pass through transformer, select first element -> now (bs, 1, dim).
        res = self.spk_encoder(spk_seq, is_causal=False, src_key_padding_mask=src_key_padding_mask)[:, :1]
        return res.squeeze(1)
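
    # Classifier-free guidance sketch (an assumption about intended inference usage, inferred
    # from p_cond_drop / drop_cond rather than stated in this module): conditional and
    # unconditional logits from forward() could be mixed as
    #
    #   cond = model(c_text, c_codes, c_texts_length, c_codes_length, x, x_padding_mask, t)
    #   uncond = model(c_text, c_codes, c_texts_length, c_codes_length, x, x_padding_mask, t,
    #                  drop_cond=True)
    #   guided = uncond + w * (cond - uncond)    # guidance weight w > 1 sharpens conditioning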

    def forward(self, c_text: Tensor, c_codes: Tensor, c_texts_length: Tensor, c_codes_length: Tensor,
                x: Tensor, x_padding_mask: Tensor, t: Tensor, drop_cond=False):
        """
        Input:
        - `c_text`: (bs, seq_len1) the prompt text (BPE encoded)
        - `c_codes`: (bs, seq_len2, n_quant) the full tokenized codes of the reference speech
        - `c_texts_length`: (bs, ) the length of each text prompt
        - `c_codes_length`: (bs, ) the length of the prompt acoustic token codes in `c_codes`.
        - `x`: (bs, seq_len3, n_quant) L0 residual codes
        - `x_padding_mask`: (bs, seq_len3) masking for residual codes
        - `t`: (bs, ) timestep
        - `drop_cond`: bool, whether or not to forcibly drop the conditioning information.

        Returns:
        - outs: (bs, seq_len3, codebook_size, n_quantizer)
        """
        c_text = self.text_embed(c_text)  # (bs, seq_len1, dim)

        ## ----> reference / conditioning encoder:
        bs = c_codes.shape[0]
        if self.training:
            zero_cond_inds = torch.rand_like(t, dtype=c_text.dtype) < self.p_cond_drop
        else:
            # never randomly zero when in eval mode
            zero_cond_inds = torch.zeros_like(t, dtype=torch.bool)

        if drop_cond:
            # force drop conditioning
            zero_cond_inds = torch.ones_like(t, dtype=torch.bool)

        c_codes_length[zero_cond_inds] = 0
        c_codes[zero_cond_inds] = 1024

        spk_seq = self.ref_embedder(c_codes)  # (bs, sl, dim)
        spk_ref_emb = self.spk_identity_emb.weight[None].expand(bs, -1, -1)  # (bs, 1, dim)
        spk_seq = torch.cat([spk_ref_emb, spk_seq], dim=1)  # (bs, 1+sl, dim)
        # add pos encoding
        spk_seq = self.ref_pos_embedding(spk_seq)
        # add 1 to c_codes_length to account for the fact that we concatenate the spk_ref_emb to it.
        src_key_padding_mask = length_to_mask(c_codes_length+1, torch.zeros_like(c_codes_length),
                                              max_len=spk_seq.shape[1])
        src_key_padding_mask = src_key_padding_mask.to(dtype=torch.bool, device=spk_seq.device)
        # pass through transformer, select first element -> now (bs, 1, dim).
        res = self.spk_encoder(spk_seq, is_causal=False, src_key_padding_mask=src_key_padding_mask)[:, :1]
        c_codes = res  # (bs, 1, dim)
        # manually override all the code lengths to equal 1, since we only have 1 spk embedding.
        c_codes_lengths_extract = torch.ones_like(c_codes_length)
        ## ----> end reference / conditioning encoder:

        ## ----> timestep embeddings and parsing
        t_emb = timestep_embedding(t, self.t_emb_dim, dtype=c_text.dtype)
        t_emb_encoder = self.timestep_encoder_emb(t_emb)  # (bs, t_dim)
        t_emb_decoder = self.timestep_decoder_emb(t_emb)

        ## ----> concatenating text/phone inputs and implicit speaker embedding.
        c_phones_unpacked = nn.utils.rnn.unpad_sequence(c_text, c_texts_length.cpu(), batch_first=True)
        c_codes_unpacked = nn.utils.rnn.unpad_sequence(c_codes, c_codes_lengths_extract.cpu(), batch_first=True)
        # >>> Concat [speaker codes, text codes]
        assert all(b.shape[0] == 1 for b in c_codes_unpacked)
        c_joined = [torch.cat((b, a), dim=0) for a, b in zip(c_phones_unpacked, c_codes_unpacked)]
        c = nn.utils.rnn.pad_sequence(c_joined, batch_first=True)
        c_joined_lengths = torch.tensor([p.shape[0] for p in c_joined], device=c.device, dtype=torch.long)
        c_padding_mask = length_to_mask(c_joined_lengths, torch.zeros_like(c_joined_lengths))

        c = self.cond_pos_embedding(c)

        ## Format input:
        x = self.residual_encoder(x)  # (bs, seq_len3, dim)
        x = self.pos_embedding(x)

        x = x + t_emb_decoder[:, None]
        c = c + t_emb_encoder[:, None]

        ## Perform prediction:
        output = self.tfm(c, x,
                          src_key_padding_mask=c_padding_mask,
                          tgt_key_padding_mask=x_padding_mask,
                          memory_key_padding_mask=c_padding_mask)  # (bs, seq_len, dim)
        outs = torch.stack([self.residual_decoder[i](output) for i in range(self.n_quantizer)],
                           dim=-1)  # (bs, seq_len, logit_dim, n_quant)
        return outs
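
# End-to-end usage sketch (illustrative; the vocab sizes, pad index 1024, and 200 diffusion
# steps below are assumptions, not values fixed by this module):
#
#   model = ResidualTransformer(n_text_vocab=512, n_quant=1025).eval()
#   c_text = torch.randint(0, 512, (1, 40))                  # BPE-encoded prompt text
#   c_codes = torch.randint(0, 1024, (1, 300, 8))            # reference acoustic codes
#   c_texts_length = torch.tensor([40])
#   c_codes_length = torch.tensor([300])
#   x = torch.randint(0, 1024, (1, 120, 8))                  # residual codes being denoised
#   x_padding_mask = torch.zeros(1, 120, dtype=torch.bool)
#   t = torch.randint(0, 200, (1,))                          # diffusion timestep
#   outs = model(c_text, c_codes, c_texts_length, c_codes_length, x, x_padding_mask, t)
#   # outs: (1, 120, 1025, 8) = (bs, seq_len, codebook_size, n_quantizer)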