Spaces:
Build error
Build error
File size: 2,569 Bytes
66a6dc0 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 |
import math
import torch
import librosa
# based on https://github.com/neuralaudio/hear-baseline/blob/main/hearbaseline/naive.py
class RandomMelProjection(torch.nn.Module):
    """Project a mel spectrogram through a frozen random matrix to get a
    clip-level embedding.

    A naive baseline embedder: compute an STFT, map it onto a mel
    filterbank, compress the magnitudes, multiply by a fixed random
    projection matrix, and average over time.

    Args:
        sample_rate: audio sample rate used to build the mel filterbank.
        embed_dim: size of the output embedding per channel (default 4096).
        n_mels: number of mel bands.
        n_fft: FFT size (also the analysis window length).
        hop_size: hop between successive STFT frames, in samples.
        seed: RNG seed used to draw the projection matrix.
        epsilon: small constant added before power/log compression.
    """

    def __init__(
        self,
        sample_rate,
        embed_dim=4096,
        n_mels=128,
        n_fft=4096,
        hop_size=1024,
        seed=0,
        epsilon=1e-4,
    ):
        super().__init__()
        self.sample_rate = sample_rate
        self.embed_dim = embed_dim
        self.n_mels = n_mels
        self.n_fft = n_fft
        self.hop_size = hop_size
        self.seed = seed
        self.epsilon = epsilon
        # Seed so the random projection is reproducible across runs.
        torch.random.manual_seed(self.seed)
        # Hann window buffer applied to frames prior to the FFT.
        self.register_buffer("window", torch.hann_window(self.n_fft))
        # Mel filterbank buffer, shape (n_mels, 1 + n_fft // 2).
        # librosa >= 0.10 made every argument of filters.mel keyword-only;
        # passing the sample rate positionally raises TypeError there, so
        # it must be passed as sr=.
        mel_scale = torch.tensor(
            librosa.filters.mel(
                sr=self.sample_rate,
                n_fft=self.n_fft,
                n_mels=self.n_mels,
            )
        )
        self.register_buffer("mel_scale", mel_scale)
        # Frozen random projection, scaled by 1/sqrt(n_mels) so output
        # magnitude does not grow with the number of mel bands.
        normalization = math.sqrt(self.n_mels)
        self.projection = torch.nn.Parameter(
            torch.rand(self.n_mels, self.embed_dim) / normalization,
            requires_grad=False,
        )

    def forward(self, x):
        """Embed a batch of audio clips.

        Args:
            x: tensor of shape (batch, channels, samples).
               NOTE(review): channels are flattened into one long signal
               before the STFT, so this is only meaningful for mono
               (channels == 1) input — confirm with callers.

        Returns:
            Embedding tensor of shape (batch, embed_dim) for mono input.
        """
        bs, chs, samp = x.size()
        x = torch.stft(
            x.view(bs, -1),
            self.n_fft,
            self.hop_size,
            window=self.window,
            return_complex=True,
        )
        # (bs, freq, time) -> (bs, 1, time, freq)
        x = x.unsqueeze(1).permute(0, 1, 3, 2)
        # Apply the mel filterbank to the magnitude spectrum.
        # NOTE(review): this uses |X| while compute_frame_embedding uses
        # |X|**2 — possibly intentional, but worth confirming.
        x = torch.matmul(x.abs(), self.mel_scale.transpose(0, 1))
        # Power-law (0.3 exponent) compression of the mel magnitudes.
        x = torch.pow(x + self.epsilon, 0.3)
        # Random projection into the embedding space.
        e = x.matmul(self.projection)
        # Mean over the temporal dimension, then flatten per clip.
        e = e.mean(dim=2).view(bs, -1)
        return e

    def compute_frame_embedding(self, x):
        """Embed a single pre-framed window of audio.

        Args:
            x: windowed samples whose last dimension equals n_fft.

        Returns:
            Embedding of size embed_dim per frame.
        """
        # Real-valued Fourier transform of the windowed input signal.
        x = torch.fft.rfft(x * self.window)
        # Convert to a power spectrum.
        x = torch.abs(x) ** 2.0
        # Apply the mel filterbank to the power spectrum.
        x = torch.matmul(x, self.mel_scale.transpose(0, 1))
        # Convert to a log mel spectrum.
        x = torch.log(x + self.epsilon)
        # Random projection to the embed_dim-dimensional embedding.
        embedding = x.matmul(self.projection)
        return embedding
|