In [None]:
import torch
import torchaudio
from IPython.display import Audio, display

from models.model import Vocos
from utils.audio import LogMelSpectrogram
from config import MelConfig, VocosConfig

from pathlib import Path
import random

def load_and_resample_audio(audio_path, target_sr):
    """Load an audio file as a single-channel waveform at ``target_sr`` Hz.

    Multi-channel files are reduced to their FIRST channel only (channels are
    not averaged). Resampling happens only when the file's native rate differs
    from ``target_sr``.

    Args:
        audio_path: Anything ``torchaudio.load`` accepts (str or ``pathlib.Path``).
        target_sr: Desired output sample rate in Hz.

    Returns:
        Waveform tensor of shape ``[1, time]``.
    """
    waveform, native_sr = torchaudio.load(audio_path)

    # torchaudio returns [channels, time]; slicing 0:1 keeps the channel axis,
    # so any multi-channel input becomes [1, time] using its first channel.
    if waveform.size(0) > 1:
        waveform = waveform[0:1, :]

    if native_sr == target_sr:
        return waveform
    return torchaudio.functional.resample(waveform, native_sr, target_sr)

# --- Configuration -----------------------------------------------------------
device = 'cpu'
CHECKPOINT_PATH = Path('./checkpoints/generator_0.pt')
AUDIO_DIR = Path('./audios')
SEED = 0  # makes the random sample choice below reproducible on Restart & Run All
random.seed(SEED)

mel_config = MelConfig()
vocos_config = VocosConfig()

# NOTE(review): mel_extractor is never moved to `device`. If LogMelSpectrogram
# holds tensors/buffers, it will likely need `.to(device)` once device != 'cpu'.
mel_extractor = LogMelSpectrogram(mel_config)
model = Vocos(vocos_config, mel_config).to(device)

# The checkpoint is passed straight to load_state_dict, so it only needs
# tensors: weights_only=True refuses to unpickle arbitrary objects, and
# map_location follows `device` instead of a hardcoded 'cpu'.
state_dict = torch.load(CHECKPOINT_PATH, map_location=device, weights_only=True)
model.load_state_dict(state_dict)
model.eval()

# Sorted so that, together with SEED, random.choice picks the same file across
# filesystems and re-runs (rglob order is not guaranteed).
audio_paths = sorted(AUDIO_DIR.rglob('*.wav'))
assert audio_paths, f'no .wav files found under {AUDIO_DIR.resolve()}'

In [None]:
# Pick one clip at random, run it through mel_extractor and the Vocos model,
# then play the original and the reconstruction back to back for comparison.
audio_path = random.choice(audio_paths)
print(f'Sample: {audio_path}')
with torch.inference_mode():
    audio = load_and_resample_audio(audio_path, mel_config.sample_rate).to(device)
    mel = mel_extractor(audio)
    recon_audio = model(mel)

# IPython's Audio converts its input through NumPy, which cannot read tensors
# that live on an accelerator, so bring both waveforms to host memory first.
display(Audio(audio.cpu(), rate=mel_config.sample_rate))
display(Audio(recon_audio.cpu(), rate=mel_config.sample_rate))