|
import math
|
|
from dataclasses import dataclass
|
|
from pathlib import Path
|
|
from typing import Union
|
|
|
|
import numpy as np
|
|
import torch
|
|
import tqdm
|
|
from audiotools import AudioSignal
|
|
from torch import nn
|
|
|
|
# File-format versions this module accepts when loading a .dac file.
# The last entry is the version stamped into newly created/saved files.
SUPPORTED_VERSIONS = ["1.0.0"]
|
|
|
|
|
|
@dataclass
class DACFile:
    """Compressed audio: quantizer codes plus the metadata needed to decode them.

    Instances are persisted as ``.dac`` files, which are NumPy ``.npy`` files
    holding a single pickled dictionary (see :meth:`save` / :meth:`load`).
    """

    codes: torch.Tensor  # integer code indices; written to disk as uint16

    # Metadata
    chunk_length: int  # number of code frames per encoded chunk
    original_length: int  # length in samples of the source audio
    input_db: float  # loudness of the source; a tensor or ndarray is accepted too
    channels: int
    sample_rate: int  # sample rate of the source audio
    padding: bool  # padding mode that was active while encoding
    dac_version: str

    def save(self, path):
        """Write this object to ``path`` with its suffix replaced by ``.dac``.

        Parameters
        ----------
        path : str or Path
            Destination; any existing suffix is replaced by ``.dac``.

        Returns
        -------
        Path
            The path that was actually written.

        Notes
        -----
        Codes are cast to ``uint16``, so indices of 65536 or more would
        wrap around silently.
        """
        codes = self.codes
        if isinstance(codes, torch.Tensor):
            # .numpy() only works for CPU tensors that do not require grad.
            codes = codes.detach().cpu().numpy()

        # input_db may be a tensor (fresh from compression), an ndarray
        # (after load()) or a plain float (as annotated): accept all three.
        input_db = self.input_db
        if isinstance(input_db, torch.Tensor):
            input_db = input_db.detach().cpu().numpy()

        artifacts = {
            "codes": np.asarray(codes).astype(np.uint16),
            "metadata": {
                "input_db": np.asarray(input_db).astype(np.float32),
                "original_length": self.original_length,
                "sample_rate": self.sample_rate,
                "chunk_length": self.chunk_length,
                "channels": self.channels,
                "padding": self.padding,
                # Files are always stamped with the newest supported version,
                # regardless of self.dac_version.
                "dac_version": SUPPORTED_VERSIONS[-1],
            },
        }
        path = Path(path).with_suffix(".dac")
        with open(path, "wb") as f:
            np.save(f, artifacts)
        return path

    @classmethod
    def load(cls, path):
        """Read a ``.dac`` file written by :meth:`save`.

        Parameters
        ----------
        path : str or Path
            Location of the file.

        Returns
        -------
        DACFile

        Raises
        ------
        RuntimeError
            If the file's ``dac_version`` is missing or not listed in
            ``SUPPORTED_VERSIONS``.
        """
        # SECURITY: allow_pickle=True unpickles arbitrary Python objects.
        # Only load files that come from a trusted source.
        artifacts = np.load(path, allow_pickle=True)[()]
        metadata = artifacts["metadata"]

        # Validate the version before interpreting the rest of the payload.
        if metadata.get("dac_version", None) not in SUPPORTED_VERSIONS:
            raise RuntimeError(
                f"Given file {path} can't be loaded with this version of descript-audio-codec."
            )

        # Explicit int64 keeps the dtype identical on every platform.
        codes = torch.from_numpy(artifacts["codes"].astype(np.int64))
        return cls(codes=codes, **metadata)
|
|
|
|
|
|
class CodecMixin:
    """Mixin adding padding control and chunked (de)compression to a codec.

    The host class must be a ``torch.nn.Module`` (``self.modules()`` and
    ``self.eval()`` are used) and must provide ``sample_rate``,
    ``hop_length``, ``delay``, ``device``, ``preprocess``, ``encode``,
    ``decode`` and ``quantizer``; none of them is defined here.
    """

    def _conv_layers(self):
        """Return every ``Conv1d`` / ``ConvTranspose1d`` submodule, in
        ``self.modules()`` order."""
        return [
            module
            for module in self.modules()
            if isinstance(module, (nn.Conv1d, nn.ConvTranspose1d))
        ]

    @property
    def padding(self):
        """bool: whether the convolution layers use their configured padding.

        Defaults to ``True`` the first time it is read.
        """
        if not hasattr(self, "_padding"):
            self._padding = True
        return self._padding

    @padding.setter
    def padding(self, value):
        """Enable (``True``) or zero out (``False``) the padding of every
        1-D convolution / transposed convolution in the model."""
        assert isinstance(value, bool)

        was_enabled = self.padding

        for layer in self._conv_layers():
            if value:
                if hasattr(layer, "original_padding"):
                    layer.padding = layer.original_padding
            else:
                # Only snapshot the padding while it is still the real one.
                # Disabling twice in a row used to overwrite the snapshot
                # with zeros, making the original padding unrecoverable.
                if was_enabled or not hasattr(layer, "original_padding"):
                    layer.original_padding = layer.padding
                layer.padding = tuple(0 for _ in range(len(layer.padding)))

        self._padding = value

    def get_delay(self):
        """Return half the difference between the input length needed to
        produce a given output and that output's length.

        The output length for an input of 0 is walked backwards through the
        layers (inverse of the formulas in :meth:`get_output_length`,
        rounding up) to recover the corresponding input length.

        Returns
        -------
        int
        """
        # NOTE(review): the probe length 0 looks arbitrary; presumably the
        # result does not depend on it -- confirm.
        l_out = self.get_output_length(0)
        L = l_out

        for layer in reversed(self._conv_layers()):
            d = layer.dilation[0]
            k = layer.kernel_size[0]
            s = layer.stride[0]

            # Going backwards, each layer type uses the other's formula.
            if isinstance(layer, nn.ConvTranspose1d):
                L = ((L - d * (k - 1) - 1) / s) + 1
            elif isinstance(layer, nn.Conv1d):
                L = (L - 1) * s + d * (k - 1) + 1

            L = math.ceil(L)

        l_in = L

        return (l_in - l_out) // 2

    def get_output_length(self, input_length):
        """Return the length produced by passing ``input_length`` samples
        through every convolution layer in ``self.modules()`` order.

        The formulas contain no padding term, i.e. they describe the layers
        with zero padding.

        Parameters
        ----------
        input_length : int

        Returns
        -------
        int
        """
        L = input_length
        for layer in self._conv_layers():
            d = layer.dilation[0]
            k = layer.kernel_size[0]
            s = layer.stride[0]

            if isinstance(layer, nn.Conv1d):
                L = ((L - d * (k - 1) - 1) / s) + 1
            elif isinstance(layer, nn.ConvTranspose1d):
                L = (L - 1) * s + d * (k - 1) + 1

            L = math.floor(L)
        return L

    @torch.no_grad()
    def compress(
        self,
        audio_path_or_signal: "Union[str, Path, AudioSignal]",
        win_duration: float = 1.0,
        verbose: bool = False,
        normalize_db: float = -16,
        n_quantizers: int = None,
    ) -> "DACFile":
        """Processes an audio signal from a file or AudioSignal object into
        discrete codes. This function processes the signal in short windows,
        using constant GPU memory.

        Parameters
        ----------
        audio_path_or_signal : Union[str, Path, AudioSignal]
            audio signal to compress; a path is loaded with ffmpeg
        win_duration : float, optional
            window duration in seconds, by default 1.0. ``None`` means the
            whole signal is processed as a single window.
        verbose : bool, optional
            show a progress bar over the windows, by default False
        normalize_db : float, optional
            loudness to normalize to before encoding, by default -16.
            ``None`` skips normalization.
        n_quantizers : int, optional
            passed to ``self.encode``; when given, at most this many
            codebooks are kept in the result. By default None.

        Returns
        -------
        DACFile
            Object containing compressed codes and metadata
            required for decompression
        """
        audio_signal = audio_path_or_signal
        if isinstance(audio_signal, (str, Path)):
            audio_signal = AudioSignal.load_from_file_with_ffmpeg(str(audio_signal))

        self.eval()
        original_padding = self.padding
        original_device = audio_signal.device

        # try/finally: the padding mode is switched below and must be put
        # back even if encoding fails part-way through.
        try:
            # Work on a copy so the caller's signal is left untouched.
            audio_signal = audio_signal.clone()
            original_sr = audio_signal.sample_rate

            resample_fn = audio_signal.resample
            loudness_fn = audio_signal.loudness

            # For very long signals (duration >= 10 * 60 * 60, i.e. 10 hours
            # if signal_duration is in seconds) use the ffmpeg versions.
            if audio_signal.signal_duration >= 10 * 60 * 60:
                resample_fn = audio_signal.ffmpeg_resample
                loudness_fn = audio_signal.ffmpeg_loudness

            # Length is recorded before resampling to the codec's rate.
            original_length = audio_signal.signal_length
            resample_fn(self.sample_rate)
            input_db = loudness_fn()

            if normalize_db is not None:
                audio_signal.normalize(normalize_db)
            audio_signal.ensure_max_of_audio()

            # Fold channels into the batch axis: every channel is encoded
            # as an independent mono signal.
            nb, nac, nt = audio_signal.audio_data.shape
            audio_signal.audio_data = audio_signal.audio_data.reshape(nb * nac, 1, nt)
            win_duration = (
                audio_signal.signal_duration if win_duration is None else win_duration
            )

            if audio_signal.signal_duration <= win_duration:
                # Unchunked: the whole signal fits in one window.
                self.padding = True
                n_samples = nt
                hop = nt
            else:
                # Chunked: layers run without padding, so the signal is
                # zero-padded on both sides by the delay instead.
                self.padding = False
                audio_signal.zero_pad(self.delay, self.delay)
                n_samples = int(win_duration * self.sample_rate)
                # Round the window up to a multiple of the hop length.
                n_samples = int(math.ceil(n_samples / self.hop_length) * self.hop_length)
                hop = self.get_output_length(n_samples)

            codes = []
            range_fn = range if not verbose else tqdm.trange

            for i in range_fn(0, nt, hop):
                x = audio_signal[..., i : i + n_samples]
                # Right-pad a short final window up to the full window size.
                x = x.zero_pad(0, max(0, n_samples - x.shape[-1]))

                audio_data = x.audio_data.to(self.device)
                audio_data = self.preprocess(audio_data, self.sample_rate)
                _, c, _, _, _ = self.encode(audio_data, n_quantizers)
                codes.append(c.to(original_device))
                chunk_length = c.shape[-1]

            codes = torch.cat(codes, dim=-1)

            # Apply the codebook limit *before* building the DACFile; doing
            # it afterwards (as before) had no effect on the returned object.
            if n_quantizers is not None:
                codes = codes[:, :n_quantizers, :]

            dac_file = DACFile(
                codes=codes,
                chunk_length=chunk_length,
                original_length=original_length,
                input_db=input_db,
                channels=nac,
                sample_rate=original_sr,
                # Record the mode used for encoding, read before restoring.
                padding=self.padding,
                dac_version=SUPPORTED_VERSIONS[-1],
            )
            return dac_file
        finally:
            self.padding = original_padding

    @torch.no_grad()
    def decompress(
        self,
        obj: "Union[str, Path, DACFile]",
        verbose: bool = False,
    ) -> "AudioSignal":
        """Reconstruct audio from a given .dac file

        Parameters
        ----------
        obj : Union[str, Path, DACFile]
            .dac file location or corresponding DACFile object.
        verbose : bool, optional
            Prints progress if True, by default False

        Returns
        -------
        AudioSignal
            Object with the reconstructed audio
        """
        self.eval()
        if isinstance(obj, (str, Path)):
            obj = DACFile.load(obj)

        original_padding = self.padding

        # try/finally: restore the caller's padding mode even on failure.
        try:
            # Decode with the same padding mode that was used for encoding.
            self.padding = obj.padding

            range_fn = range if not verbose else tqdm.trange
            codes = obj.codes
            original_device = codes.device
            chunk_length = obj.chunk_length
            recons = []

            for i in range_fn(0, codes.shape[-1], chunk_length):
                c = codes[..., i : i + chunk_length].to(self.device)
                z = self.quantizer.from_codes(c)[0]
                r = self.decode(z)
                recons.append(r.to(original_device))

            recons = torch.cat(recons, dim=-1)
            recons = AudioSignal(recons, self.sample_rate)

            resample_fn = recons.resample
            loudness_fn = recons.loudness

            # For very long signals (duration >= 10 * 60 * 60, i.e. 10 hours
            # if signal_duration is in seconds) use the ffmpeg versions.
            if recons.signal_duration >= 10 * 60 * 60:
                resample_fn = recons.ffmpeg_resample
                loudness_fn = recons.ffmpeg_loudness

            # Undo the encode-time normalization and resampling, then trim
            # to the original length.
            recons.normalize(obj.input_db)
            resample_fn(obj.sample_rate)
            recons = recons[..., : obj.original_length]
            # NOTE(review): loudness_fn is bound to the pre-slice signal and
            # its result is discarded; presumably called for a caching side
            # effect -- confirm it is still needed.
            loudness_fn()
            # Unfold the channels that compress() merged into the batch axis.
            recons.audio_data = recons.audio_data.reshape(
                -1, obj.channels, obj.original_length
            )

            return recons
        finally:
            self.padding = original_padding
|
|
|