# Spaces: Running on A10G
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import typing as tp | |
from einops import rearrange | |
from librosa import filters | |
import torch | |
from torch import nn | |
import torch.nn.functional as F | |
import torchaudio | |
class ChromaExtractor(nn.Module):
    """Chroma feature extractor with optional one-hot (argmax) quantization.

    A power spectrogram (``power=2``, ``normalized=True``, ``center=True``) is computed
    with ``torchaudio.transforms.Spectrogram``. It is projected onto chroma bins with a
    ``librosa.filters.chroma`` filterbank (``tuning=0``), and each frame is normalized
    across the chroma bins.

    Args:
        sample_rate (int): Sample rate of the input audio; used to build the filterbank.
        n_chroma (int): Number of chroma bins. Defaults to 12.
        radix2_exp (int): Exponent of the default STFT window length, i.e. the window is
            ``2 ** radix2_exp`` samples (e.g. 12 -> 4096). Defaults to 12.
        nfft (int, optional): FFT size. Falls back to the window length when not given.
        winlen (int, optional): STFT window length. Falls back to ``2 ** radix2_exp``.
        winhop (int, optional): STFT hop size. Falls back to a quarter of the window length.
        argmax (bool, optional): If True, each output frame is replaced by a one-hot vector
            marking its largest chroma bin. Defaults to False.
        norm (float, optional): Order ``p`` of the norm used to normalize every frame over
            the chroma bins. Defaults to inf.
    """
    def __init__(self, sample_rate: int, n_chroma: int = 12, radix2_exp: int = 12, nfft: tp.Optional[int] = None,
                 winlen: tp.Optional[int] = None, winhop: tp.Optional[int] = None, argmax: bool = False,
                 norm: float = torch.inf):
        super().__init__()
        # Falsy arguments (None, or 0) fall back to the derived defaults.
        window_length = winlen if winlen else 2 ** radix2_exp
        fft_size = nfft if nfft else window_length
        hop_length = winhop if winhop else window_length // 4

        self.winlen = window_length
        self.nfft = fft_size
        self.winhop = hop_length
        self.sample_rate = sample_rate
        self.n_chroma = n_chroma
        self.norm = norm
        self.argmax = argmax

        # (n_chroma, nfft // 2 + 1) filterbank mapping STFT bins to chroma bins; it is
        # rebuilt from the arguments, so it is kept out of the state dict.
        chroma_fb = filters.chroma(sr=sample_rate, n_fft=self.nfft, tuning=0,
                                   n_chroma=self.n_chroma)
        self.register_buffer('fbanks', torch.from_numpy(chroma_fb), persistent=False)
        self.spec = torchaudio.transforms.Spectrogram(
            n_fft=self.nfft,
            win_length=self.winlen,
            hop_length=self.winhop,
            power=2,
            center=True,
            pad=0,
            normalized=True,
        )

    def forward(self, wav: torch.Tensor) -> torch.Tensor:
        """Extract chroma features from ``wav``.

        Args:
            wav (torch.Tensor): Waveform with time on the last axis. The squeezed
                spectrogram must be 3-D for the final rearrange, so the input is
                presumably ``(B, 1, T)`` or ``(B, T)`` (TODO: confirm with callers).

        Returns:
            torch.Tensor: Chroma features laid out as ``(B, frames, n_chroma)``. When
            ``argmax`` is enabled, every frame is one-hot over the chroma bins.
        """
        n_samples = wav.shape[-1]
        # A wav that was dropped out (nullified) by the conditioner may be shorter than a
        # single FFT frame: zero-pad it up to exactly ``nfft`` samples, with any odd
        # leftover sample going to the right side.
        if n_samples < self.nfft:
            missing = self.nfft - n_samples
            left = missing // 2
            right = left + missing % 2
            wav = F.pad(wav, (left, right), 'constant', 0)
            assert wav.shape[-1] == self.nfft, f"expected len {self.nfft} but got {wav.shape[-1]}"

        # squeeze(1) drops a singleton axis at position 1 if there is one.
        power_spec = self.spec(wav).squeeze(1)
        # (chroma, freq) x (..., freq, frames) -> (..., chroma, frames)
        chroma = torch.einsum('cf,...ft->...ct', self.fbanks, power_spec)
        chroma = F.normalize(chroma, p=self.norm, dim=-2, eps=1e-6)
        chroma = rearrange(chroma, 'b d t -> b t d')

        if not self.argmax:
            return chroma

        # Quantize: keep only the strongest chroma bin of every frame.
        best_bin = chroma.argmax(-1, keepdim=True)
        chroma.zero_()
        chroma.scatter_(dim=-1, index=best_bin, value=1)
        return chroma