# NOTE: stray "Last commit not found" banner left over from extraction — not source code.
import typing as tp

import numpy as np
import torch
from einops import rearrange
from torch import nn
class EncodecModel(nn.Module):
    """Neural audio codec model combining a quantizer and a decoder.

    Only the decoding half of the codec is shown here: discrete codes are
    mapped back to continuous latents by the quantizer and then to a
    waveform by the decoder.

    Args:
        decoder: module mapping quantized latents back to waveforms.
        quantizer: quantizer exposing ``decode``, ``set_num_codebooks`` and
            the ``total_codebooks`` / ``num_codebooks`` / ``bins`` attributes.
        frame_rate: number of latent frames per second of audio.
        sample_rate: audio sample rate in Hz.
        channels: number of audio channels.
        causal (bool): whether the model is causal. Incompatible with
            ``renormalize`` (see comment in ``__init__``).
        renormalize (bool): if True, ``preprocess`` normalizes the input
            loudness and ``postprocess`` undoes it.
    """

    def __init__(self,
                 decoder=None,
                 quantizer=None,
                 frame_rate=None,
                 sample_rate=None,
                 channels=None,
                 causal=False,
                 renormalize=False):
        super().__init__()
        # The original code pre-assigned frame_rate/sample_rate/channels to 0
        # and immediately overwrote them below; those dead assignments are removed.
        self.decoder = decoder
        self.quantizer = quantizer
        self.frame_rate = frame_rate
        self.sample_rate = sample_rate
        self.channels = channels
        self.renormalize = renormalize
        self.causal = causal
        if self.causal:
            # we force disabling here to avoid handling linear overlap of segments
            # as supported in original EnCodec codebase.
            assert not self.renormalize, 'Causal model does not support renormalize'

    def total_codebooks(self):
        """Total number of quantizer codebooks available."""
        return self.quantizer.total_codebooks

    def num_codebooks(self):
        """Active number of codebooks used by the quantizer."""
        return self.quantizer.num_codebooks

    def set_num_codebooks(self, n):
        """Set the active number of codebooks used by the quantizer."""
        self.quantizer.set_num_codebooks(n)

    def cardinality(self):
        """Cardinality of each codebook."""
        return self.quantizer.bins

    def preprocess(self, x: torch.Tensor) -> tp.Tuple[torch.Tensor, tp.Optional[torch.Tensor]]:
        """Optionally renormalize the input waveform.

        Args:
            x: input waveform, expected ``[B, C, T]`` (channel mean over
                dim 1, time over dim 2).
        Returns:
            The (possibly rescaled) waveform and the per-item scale needed
            to undo the normalization; the scale is ``None`` when
            ``renormalize`` is disabled.
        """
        scale: tp.Optional[torch.Tensor]
        if self.renormalize:
            # RMS volume of the mono mix, per batch item.
            mono = x.mean(dim=1, keepdim=True)
            volume = mono.pow(2).mean(dim=2, keepdim=True).sqrt()
            # Epsilon keeps the division safe for silent inputs.
            scale = 1e-8 + volume
            x = x / scale
            scale = scale.view(-1, 1)
        else:
            scale = None
        return x, scale

    def postprocess(self,
                    x: torch.Tensor,
                    scale: tp.Optional[torch.Tensor] = None) -> torch.Tensor:
        """Undo the loudness normalization applied by :meth:`preprocess`."""
        if scale is not None:
            # A scale may only be provided when renormalization is active.
            assert self.renormalize
            x = x * scale.view(-1, 1, 1)
        return x

    def decode(self, codes: torch.Tensor, scale: tp.Optional[torch.Tensor] = None):
        """Decode discrete codes to a waveform.

        Args:
            codes: code tensor of shape ``[B, K, T]``.
            scale: optional rescaling factor as returned by :meth:`preprocess`.
        Returns:
            Decoded audio, ``[B, C, T']`` (shape determined by the decoder).
        """
        # B,K,T -> B,C,T
        emb = self.decode_latent(codes)
        out = self.decoder(emb)
        out = self.postprocess(out, scale)
        return out

    def decode_latent(self, codes: torch.Tensor):
        """Decode from the discrete codes to continuous latent space."""
        return self.quantizer.decode(codes)