anonymous9a7b
1
f032e68
raw
history blame
1.59 kB
import torch
from audiotools import AudioSignal
from audiotools.ml import BaseModel
from encodec import EncodecModel
class Encodec(BaseModel):
def __init__(self, sample_rate: int = 24000, bandwidth: float = 24.0):
super().__init__()
if sample_rate == 24000:
self.model = EncodecModel.encodec_model_24khz()
else:
self.model = EncodecModel.encodec_model_48khz()
self.model.set_target_bandwidth(bandwidth)
self.sample_rate = 44100
def forward(
self,
audio_data: torch.Tensor,
sample_rate: int = 44100,
n_quantizers: int = None,
):
signal = AudioSignal(audio_data, sample_rate)
signal.resample(self.model.sample_rate)
recons = self.model(signal.audio_data)
recons = AudioSignal(recons, self.model.sample_rate)
recons.resample(sample_rate)
return {"audio": recons.audio_data}
if __name__ == "__main__":
import numpy as np
from functools import partial
model = Encodec()
for n, m in model.named_modules():
o = m.extra_repr()
p = sum([np.prod(p.size()) for p in m.parameters()])
fn = lambda o, p: o + f" {p/1e6:<.3f}M params."
setattr(m, "extra_repr", partial(fn, o=o, p=p))
print(model)
print("Total # of params: ", sum([np.prod(p.size()) for p in model.parameters()]))
length = 88200 * 2
x = torch.randn(1, 1, length).to(model.device)
x.requires_grad_(True)
x.retain_grad()
# Make a forward pass
out = model(x)["audio"]
print(x.shape, out.shape)