|
import torch |
|
from audiotools import AudioSignal |
|
from audiotools.ml import BaseModel |
|
from encodec import EncodecModel |
|
|
|
|
|
class Encodec(BaseModel):
    """Expose an ``EncodecModel`` through the ``BaseModel`` interface.

    The wrapped codec is stored on ``self.model``. ``self.sample_rate`` is
    always set to 44100, independent of the ``sample_rate`` constructor
    argument, which is only used to choose the codec variant.
    """

    def __init__(self, sample_rate: int = 24000, bandwidth: float = 24.0):
        """Build the codec and configure its target bandwidth.

        Args:
            sample_rate: Selects the codec variant. Exactly ``24000`` uses
                ``EncodecModel.encodec_model_24khz``; any other value uses
                ``EncodecModel.encodec_model_48khz``.
            bandwidth: Forwarded unchanged to
                ``EncodecModel.set_target_bandwidth``.
        """
        super().__init__()

        # Only an exact 24000 picks the 24 kHz factory; everything else
        # falls through to the 48 kHz one.
        build_codec = (
            EncodecModel.encodec_model_24khz
            if sample_rate == 24000
            else EncodecModel.encodec_model_48khz
        )
        self.model = build_codec()
        self.model.set_target_bandwidth(bandwidth)

        # Fixed value, deliberately not derived from the argument above.
        self.sample_rate = 44100

    def forward(
        self,
        audio_data: torch.Tensor,
        sample_rate: int = 44100,
        n_quantizers: int = None,
    ):
        """Pass audio through the codec and return it at the input rate.

        Args:
            audio_data: Audio tensor handed to ``AudioSignal`` as-is.
            sample_rate: Sample rate of ``audio_data``; the output is
                resampled back to this rate.
            n_quantizers: Accepted but not used by this method.

        Returns:
            A dict with the single key ``"audio"`` holding the ``audio_data``
            of the reconstructed signal.
        """
        codec_rate = self.model.sample_rate

        # The return value of ``resample`` is ignored here, so this assumes
        # it updates the signal in place -- TODO confirm against audiotools.
        source = AudioSignal(audio_data, sample_rate)
        source.resample(codec_rate)

        decoded = AudioSignal(self.model(source.audio_data), codec_rate)
        decoded.resample(sample_rate)

        return {"audio": decoded.audio_data}
|
|
|
|
|
if __name__ == "__main__":
    # Script entry point: print the module tree annotated with per-module
    # parameter counts, print the total, then run one forward pass on noise.
    from functools import partial

    def _repr_with_param_count(base: str, n_params: int) -> str:
        """Return ``base`` followed by the parameter count in millions."""
        return base + f" {n_params / 1e6:<.3f}M params."

    model = Encodec()

    for _name, module in model.named_modules():
        # Capture the original text before ``extra_repr`` is overridden on
        # this instance.
        base_repr = module.extra_repr()
        # ``numel()`` is an exact Python int; ``np.prod(p.size())`` returned
        # NumPy scalars and the float 1.0 for 0-dim parameters.
        n_params = sum(param.numel() for param in module.parameters())
        # ``partial`` freezes this iteration's values, so each module keeps
        # its own text and count.
        setattr(
            module,
            "extra_repr",
            partial(_repr_with_param_count, base_repr, n_params),
        )

    print(model)
    print(
        "Total # of params: ",
        sum(param.numel() for param in model.parameters()),
    )

    # 176400 samples, i.e. 4 s at forward()'s default sample_rate of 44100.
    length = 88200 * 2
    x = torch.randn(1, 1, length).to(model.device)
    x.requires_grad_(True)
    x.retain_grad()

    out = model(x)["audio"]

    print(x.shape, out.shape)
|
|