### code reference: https://github.com/openai/whisper/blob/main/whisper/audio.py
import numpy as np
import torch
import torch.nn.functional as F
import torchaudio
from torch import Tensor, nn
# hard-coded audio hyperparameters
SAMPLE_RATE = 16000
N_FFT = 1024
N_MELS = 128
HOP_LENGTH = int(0.01 * SAMPLE_RATE)  # 160 samples = 10 ms per hop
DURATION = 10  # seconds per input clip
N_SAMPLES = int(DURATION * SAMPLE_RATE)  # 160000 samples
N_FRAMES = N_SAMPLES // HOP_LENGTH + 1  # 1001 frames with center=True padding
def sinusoids(length, channels, max_timescale=10000):
    """Returns sinusoids for positional embedding"""
    assert channels % 2 == 0
    log_timescale_increment = np.log(max_timescale) / (channels // 2 - 1)
    inv_timescales = torch.exp(-log_timescale_increment * torch.arange(channels // 2))
    scaled_time = torch.arange(length)[:, np.newaxis] * inv_timescales[np.newaxis, :]
    return torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], dim=1)
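
# Illustrative shape check (example values assumed, not from the source file):
# the sin and cos halves are concatenated along the channel axis.
# >>> sinusoids(length=501, channels=512).shape
# torch.Size([501, 512])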
class MelEncoder(nn.Module):
    """
    Converts a mono waveform into a log-scaled mel spectrogram
    (time-frequency representation).
    """
    def __init__(
        self,
        sample_rate=16000,
        f_min=0,
        f_max=8000,
        n_fft=1024,
        win_length=1024,
        hop_length=int(0.01 * 16000),
        n_mels=128,
        power=None,
        pad=0,
        normalized=False,
        center=True,
        pad_mode="reflect",
    ):
        super().__init__()
        # power=None keeps the complex STFT; the magnitude is taken in forward().
        # The STFT options are forwarded to Spectrogram so they actually take effect.
        self.spec_fn = torchaudio.transforms.Spectrogram(
            n_fft=n_fft,
            win_length=win_length,
            hop_length=hop_length,
            pad=pad,
            window_fn=torch.hann_window,
            power=power,
            normalized=normalized,
            center=center,
            pad_mode=pad_mode,
        )
        self.mel_scale = torchaudio.transforms.MelScale(
            n_mels, sample_rate, f_min, f_max, n_fft // 2 + 1
        )
        self.amplitude_to_db = torchaudio.transforms.AmplitudeToDB()
    def forward(self, wav):
        spec = self.spec_fn(wav)  # complex STFT, shape (..., n_fft // 2 + 1, n_frames)
        power_spec = spec.abs().pow(2)  # magnitude squared |STFT|^2, not just the real part
        mel_spec = self.mel_scale(power_spec)
        mel_spec = self.amplitude_to_db(mel_spec)  # 10 * log10(max(power, amin))
        return mel_spec
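
# Illustrative check (assumed input, not from the source file): a 10 s,
# 16 kHz mono clip yields N_FRAMES = 1001 log-mel frames.
# >>> MelEncoder()(torch.randn(1, N_SAMPLES)).shape
# torch.Size([1, 128, 1001])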
class AudioEncoder(nn.Module):
    def __init__(
        self, n_mels: int, n_ctx: int, audio_dim: int, text_dim: int, num_of_stride_conv: int,
    ):
        super().__init__()
        self.mel_encoder = MelEncoder(n_mels=n_mels)
        self.conv1 = nn.Conv1d(n_mels, audio_dim, kernel_size=3, padding=1)
        # each stride-2 conv halves the number of frames
        self.conv_stack = nn.ModuleList([])
        for _ in range(num_of_stride_conv):
            self.conv_stack.append(
                nn.Conv1d(audio_dim, audio_dim, kernel_size=3, stride=2, padding=1)
            )
        # self.proj = nn.Linear(audio_dim, text_dim, bias=False)
        # NOTE: with self.proj disabled, audio_dim must equal text_dim for the
        # positional embedding to broadcast in forward()
        self.register_buffer("positional_embedding", sinusoids(n_ctx, text_dim))
    def forward(self, x: Tensor):
        """
        x : torch.Tensor, shape = (batch_size, n_samples)
            single-channel waveform
        """
        x = self.mel_encoder(x)  # (batch_size, n_mels, n_frames)
        x = F.gelu(self.conv1(x))
        for conv in self.conv_stack:
            x = F.gelu(conv(x))  # frames halved by each stride-2 conv
        x = x.permute(0, 2, 1)  # (batch_size, n_ctx, audio_dim); n_ctx must match the conv output length
        x = (x + self.positional_embedding).to(x.dtype)
        return x
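
if __name__ == "__main__":
    # Minimal smoke test; the dimensions below are illustrative assumptions,
    # not values from the referenced Whisper file. With a 10 s clip,
    # MelEncoder yields 1001 frames and one stride-2 conv halves that to 501,
    # so n_ctx=501 here; audio_dim == text_dim because self.proj is disabled.
    encoder = AudioEncoder(
        n_mels=N_MELS, n_ctx=501, audio_dim=512, text_dim=512, num_of_stride_conv=1
    )
    wav = torch.randn(2, N_SAMPLES)  # batch of two 10 s mono waveforms
    out = encoder(wav)
    print(out.shape)  # torch.Size([2, 501, 512])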