# MARS5-TTS / mars5 / ar_generate.py
# Source snapshot: commit 8520a55 ("Add files (#1)"), uploaded by arnavmehta7, 8.87 kB.
import torch
import torch.nn.functional as F
import torchaudio
import copy
from torch import Tensor, nn
import logging
from .model import length_to_mask
from .samplers import (apply_typical_p, early_eos_penalty,
top_k_top_p_filtering, freq_rep_penalty)
from .nn_future import RotatingBufferCache
from .minbpe.codebook import CodebookTokenizer
from .minbpe.regex import RegexTokenizer
@torch.inference_mode()
def ar_generate(texttok: RegexTokenizer, speechtok: CodebookTokenizer,
                codeclm: nn.Module, xx: Tensor, ss_gen: Tensor, first_codex_idx: int,
                max_len: int = 1500, fp16: bool = True, temperature: float = 1.0, topk: int = None,
                top_p=1.0, alpha_frequency=0, alpha_presence=0, penalty_window=100,
                typical_p=1.0, eos_penalty_factor=1.0, eos_penalty_decay=0, n_phones_gen=None, vocode=True,
                beam_width: int = 1, beam_length_penalty=2, use_kv_cache: bool = True) -> Tensor:
    """Autoregressively generate first-quantizer codes that complete the prompt `xx`.

    The `codeclm` language model extends `xx` (shape (seq_len,)) one token at a time. The first
    `first_codex_idx` - 1 positions of `xx` correspond to the input phones. Generation stops when
    the sampled token is the end-of-speech token, or when the sequence (prompt included) reaches
    `max_len` tokens.

    Args:
        texttok: text tokenizer. Only `len(texttok.vocab)` is used here, as the offset at which
            speech-code ids start in the joint vocabulary.
        speechtok: speech-code tokenizer. Its vocab size and its `<|endofspeech|>` special token
            define the valid output range and the stop token.
        codeclm: model called as `codeclm(x, padding_mask, spk_reference=..., cache=..., counter=...)`.
            Its logits at the last time step are sampled from. It must expose `codeclm.ar.args`
            (sliding_window, n_layers, n_kv_heads, head_dim) when `use_kv_cache` is True.
        xx: 1-D prompt token ids (only batch size 1 is supported).
        ss_gen: speaker reference, passed to `codeclm` as `spk_reference`.
        first_codex_idx: boundary between phone tokens and acoustic codes in `xx`.
        max_len: maximum total sequence length, prompt included.
        fp16: enable CUDA autocast for the forward pass.
        temperature, topk, top_p, typical_p: sampling controls applied to the logits.
        alpha_frequency, alpha_presence, penalty_window: repetition penalty over recent tokens.
        eos_penalty_factor, eos_penalty_decay, n_phones_gen: early end-of-speech penalty. It is
            only applied when `n_phones_gen` is given.
        vocode: must be False; this function does not vocode.
        beam_width: must be 1; only single-beam sampling is currently supported.
        beam_length_penalty: exponent of the length normalisation used to score finished beams.
        use_kv_cache: allocate and use a rotating KV cache for incremental decoding.

    Returns:
        The generated 1-D token sequence of shape (seq_len,). It includes the prompt.

    Raises:
        AssertionError: if `xx` is not 1-D, `beam_width != 1`, or `vocode` is True.

    See `InferenceConfig` for further documentation of the sampling options.
    """
    assert xx.dim() == 1, "Only batch size of 1 is currently supported."
    assert beam_width == 1, "Only beam size of 1 is currently supported."
    if vocode:
        # Vocoding is not implemented here. Fail before spending time on generation.
        raise AssertionError("ar_generate does not vocode; call it with vocode=False and decode the returned codes separately.")

    # Internally the batch size is the beam width.
    bs = beam_width
    x_inp = xx[None].repeat(bs, 1)  # (bs, seq_len)
    ss_gen = ss_gen[None].repeat(bs, 1, 1)
    # We must subtract 1 in the line below so that we match the train-time conditions of having a
    # False padding value for the <bos> token position. This is needed so that we correctly use the
    # _acoustic_ and not the linguistic language embedding for the <bos> token.
    offsets = torch.full((bs,), first_codex_idx - 1, dtype=torch.long, device=xx.device)

    # Speech-code ids appear to follow the text ids in the joint vocabulary, offset by len(texttok.vocab).
    valid_logit_idx_start = len(texttok.vocab)  # vocab['s2i']['quant0-0000']
    valid_logit_idx_end = len(texttok.vocab) + len(speechtok.vocab) + 1  # vocab['s2i']['quant1-0000']
    # NOTE(review): the masks below keep ids in [valid_logit_idx_start - 1, valid_logit_idx_end),
    # which is one id beyond the speech-code range on each side. Kept unchanged; confirm it is intended.
    eos_idx = len(texttok.vocab) + speechtok.special_tokens['<|endofspeech|>']

    cum_logprobs = torch.zeros(bs, dtype=torch.float, device=x_inp.device)
    prev_ids = [[] for _ in range(bs)]  # per-beam history of sampled ids (feeds the repetition penalty)
    logging.info(f"Starting beam decoding with beam_width={beam_width}")

    cache = None
    if use_kv_cache:
        # Initialise kv cache
        cache_window = min(codeclm.ar.args.sliding_window, x_inp.shape[-1] + max_len)
        cache = RotatingBufferCache(codeclm.ar.args.n_layers, bs, cache_window,
                                    codeclm.ar.args.n_kv_heads, codeclm.ar.args.head_dim)
        # NOTE(review): the cache is always fp16, even when fp16=False. Confirm this is intended.
        cache.to(device=x_inp.device, dtype=torch.float16)

    best_beam = None
    counter = 0
    while x_inp.shape[-1] < max_len:
        counter += 1
        gen_length = torch.full((bs,), x_inp.shape[-1], dtype=torch.long, device=xx.device)
        padding_mask = length_to_mask(gen_length, offsets)
        with torch.autocast('cuda', enabled=fp16):
            logits: Tensor = codeclm(x_inp, padding_mask, spk_reference=ss_gen, cache=cache, counter=counter)
        logits = logits.float()
        logits = logits[:, -1]  # select last index, now (bs, logit_dim)

        # <---------------------- logit filtering ---------------------->
        filtered_logits = logits.clone()
        # Apply the repetition penalty before the logit mask, once the beam has more than 1 prior token.
        if len(prev_ids[0]) > 1:
            filtered_logits = freq_rep_penalty(filtered_logits, previous=torch.tensor(prev_ids, dtype=torch.long),
                                               alpha_frequency=alpha_frequency, alpha_presence=alpha_presence,
                                               penalty_window=penalty_window)
        filtered_logits[..., :valid_logit_idx_start - 1] = float('-inf')
        filtered_logits[..., valid_logit_idx_end:] = float('-inf')
        if n_phones_gen is not None:
            # apply eos penalty
            filtered_logits = early_eos_penalty(filtered_logits, len(prev_ids[0]), n_phones_gen,
                                                eos_penalty_decay, eos_penalty_factor,
                                                eos_index=eos_idx)
        filtered_logits = filtered_logits / temperature
        filtered_logits = top_k_top_p_filtering(filtered_logits, top_k=topk, top_p=top_p)
        filtered_logits = apply_typical_p(filtered_logits, mass=typical_p)
        # Mask out anything that isn't a first-quantizer output code (re-applied after the samplers).
        filtered_logits[..., :valid_logit_idx_start - 1] = float('-inf')
        filtered_logits[..., valid_logit_idx_end:] = float('-inf')
        logits = filtered_logits

        # <---------------------- next frame prediction --------------------->
        logprobs = logits.log_softmax(dim=-1)
        # A beam that already ended in <eos> must emit <eos> again. Every other token gets zero
        # probability and <eos> gets probability 1 (log-prob 0), so its cumulative log-prob is unchanged.
        for j in range(bs):
            if x_inp[j, -1] == eos_idx:
                logprobs[j] = float('-inf')
                logprobs[j, eos_idx] = 0
        candidate_cum_logprobs = cum_logprobs[:, None] + logprobs  # (bs, 1) + (bs, vocab) -> (bs, vocab)
        # NOTE(review): candidates are drawn from per-step probabilities rather than cumulative ones.
        # This is equivalent while beam_width == 1; revisit it if beam search is enabled.
        logp_flat = logprobs.flatten()
        candidates = torch.multinomial(logp_flat.exp(), num_samples=beam_width, replacement=False)  # (bs,)
        # Unravel the flat indices using the actual logit width, so every sampled index maps back to
        # a valid (beam, token) pair even when the logit dimension differs from the vocab sizes.
        logit_width = logprobs.shape[-1]
        beam_idxs = candidates // logit_width  # (bs,)
        tok_inds_in_each_beam = candidates % logit_width  # (bs,)

        # check for breaks
        if torch.all(tok_inds_in_each_beam == eos_idx):
            # apply length penalty:
            non_eos_toks = (x_inp != eos_idx).sum(dim=-1)  # (bs,) number of non eos toks
            # Clamp to 1 so a beam with no generated codes cannot divide by zero.
            gen_length = (non_eos_toks - first_codex_idx).clamp(min=1)
            penalties = (gen_length**beam_length_penalty)
            penalized_cum_tok_logp = candidate_cum_logprobs / penalties[:, None]
            eos_avg_logps = penalized_cum_tok_logp[:, eos_idx]
            best_beam_idx = eos_avg_logps.argmax()
            best_avg_logp = eos_avg_logps[best_beam_idx]
            best_beam = x_inp[best_beam_idx]
            logging.info((f"best beam = {best_beam_idx} @ penalized_cum_tok_logp = {best_avg_logp.item():.3f} |\n num toks: {non_eos_toks.cpu().tolist()}. "
                          f"Candidates: {eos_avg_logps.cpu()} |\n non-eos toks: {non_eos_toks.cpu().tolist()} |\n penalties: {penalties.cpu().tolist()} | "
                          f"raw cumulative probs: {candidate_cum_logprobs[:, eos_idx].cpu().tolist()}"))
            break

        # update beam histories:
        x_inp = x_inp[beam_idxs]
        # update next token
        next_sample = tok_inds_in_each_beam
        # update cum logprob
        cum_logprobs = cum_logprobs[beam_idxs] + logprobs[beam_idxs, tok_inds_in_each_beam]
        # Point each history at its parent beam and append the new token. Building fresh lists keeps
        # beams that share a parent independent of each other.
        prev_ids = [prev_ids[b] + [t] for b, t in zip(beam_idxs.tolist(), tok_inds_in_each_beam.tolist())]
        logging.debug("L%d | next sample: %s | beam: %s | cum_logp: %s", len(x_inp[0]), next_sample.cpu().tolist(), beam_idxs.cpu().tolist(), cum_logprobs.cpu())
        # update cache with beam indexes
        if cache is not None:
            cache.cache_k = cache.cache_k[:, beam_idxs]
            cache.cache_v = cache.cache_v[:, beam_idxs]
        # add 1 None below to make (bs,) -> (bs, 1) so we can concat along seq len dim.
        x_inp = torch.cat([x_inp, next_sample[:, None]], dim=-1)

    if x_inp.shape[-1] >= max_len - 1:
        logging.warning(f"[autoregressive generation] output length = {x_inp.shape[-1]} -- inference likely failed or input too long!")
    if best_beam is None:
        # No beam emitted <eos> before max_len was reached: fall back to the first beam.
        best_beam = x_inp[0]
    return best_beam  # (seq_len,)