|
import pathlib |
|
import yaml |
|
import torch |
|
import torchaudio |
|
import numpy as np |
|
from lightning_module import SSLDualLightningModule |
|
import gradio as gr |
|
import subprocess |
|
import requests |
|
|
|
def normalize_waveform(wav, sr, db=-3):
    """Normalize a waveform with the SoX ``norm`` effect.

    Args:
        wav: Waveform tensor. A leading dimension is added before the
            tensor is handed to SoX (presumably 1-D mono samples --
            confirm against callers).
        sr: Sampling rate passed to SoX alongside the waveform.
        db: Level given to the ``norm`` effect. Defaults to -3.

    Returns:
        The processed waveform with the leading dimension squeezed out.
    """
    effects = [["norm", f"{db}"]]
    batched = wav.unsqueeze(0)
    # apply_effects_tensor also returns the sample rate, which is not needed.
    normalized, _ = torchaudio.sox_effects.apply_effects_tensor(
        batched, sr, effects
    )
    return normalized.squeeze(0)
|
|
|
def download_file_from_google_drive(id, destination):
    """Download a Google Drive file to a local path.

    A first request is made with only the file id. If the response carries
    a ``download_warning`` cookie (see ``get_confirm_token``), the request
    is repeated with that value as the ``confirm`` parameter. The body of
    the final response is streamed to ``destination``.

    Args:
        id: Google Drive file id. The name shadows the builtin but is kept
            so existing keyword callers continue to work.
        destination: Path of the file to write.
    """
    URL = "https://docs.google.com/uc?export=download"

    # Use the session as a context manager so its pooled connections are
    # released once the download has been written (it was never closed).
    with requests.Session() as session:
        response = session.get(URL, params={"id": id}, stream=True)
        token = get_confirm_token(response)

        if token:
            # The first streamed response is not going to be read; close it
            # before issuing the confirmed request.
            response.close()
            params = {"id": id, "confirm": token}
            response = session.get(URL, params=params, stream=True)

        # Must run inside the ``with`` block: the body is streamed lazily
        # over the session's connection.
        save_response_content(response, destination)
|
|
|
def get_confirm_token(response):
    """Return the value of the first ``download_warning*`` cookie.

    Args:
        response: Object exposing a ``cookies`` mapping with ``items()``.

    Returns:
        The value of the first cookie whose name starts with
        ``download_warning``, or ``None`` when there is no such cookie.
    """
    matching_values = (
        cookie_value
        for cookie_name, cookie_value in response.cookies.items()
        if cookie_name.startswith('download_warning')
    )
    return next(matching_values, None)
|
|
|
def save_response_content(response, destination):
    """Write the body of a streamed response to a file.

    Args:
        response: Object exposing ``iter_content(chunk_size)`` that yields
            byte chunks.
        destination: Path of the file to create or overwrite in binary mode.
    """
    CHUNK_SIZE = 32768

    with open(destination, "wb") as out_file:
        # filter(None, ...) skips empty chunks, exactly like ``if chunk``.
        non_empty_chunks = filter(None, response.iter_content(CHUNK_SIZE))
        out_file.writelines(non_empty_chunks)
|
|
|
def calc_spectrogram(wav, config):
    """Compute a log-compressed mel spectrogram of a waveform.

    Args:
        wav: Waveform tensor passed to ``torchaudio``'s ``MelSpectrogram``.
        config: Mapping whose ``"preprocess"`` section provides
            ``sampling_rate``, ``fft_length``, ``frame_length``,
            ``frame_shift``, ``fmin``, ``fmax``, ``n_mels``,
            ``min_magnitude`` and ``comp_factor``.

    Returns:
        ``log(clamp_min(mel, min_magnitude) * comp_factor)`` as a float32
        tensor.
    """
    prep = config["preprocess"]

    mel_transform = torchaudio.transforms.MelSpectrogram(
        sample_rate=prep["sampling_rate"],
        n_fft=prep["fft_length"],
        win_length=prep["frame_length"],
        hop_length=prep["frame_shift"],
        f_min=prep["fmin"],
        f_max=prep["fmax"],
        n_mels=prep["n_mels"],
        power=1,
        center=True,
        norm="slaney",
        mel_scale="slaney",
    )
    mel = mel_transform(wav)

    # Floor the values before the log so it never sees zero.
    floored = torch.clamp_min(mel, prep["min_magnitude"])
    compressed = torch.log(floored * prep["comp_factor"])
    return compressed.to(torch.float32)
|
|
|
def transfer(audio):
    """Apply the channel characteristics of the source sample to ``audio``.

    The source recording ``aet_sample/src.wav`` is converted to a log-mel
    spectrogram and passed through the pretrained model's encoder and
    channel-feature extractor. The resulting channel features are then
    applied to the input audio by the model's channel module.

    Args:
        audio: ``(sample_rate, samples)`` tuple where ``samples`` is a
            NumPy array (assumed to be ``(n,)`` or ``(n, channels)`` as
            delivered by the Gradio audio input -- confirm against the
            installed Gradio version).

    Returns:
        ``(sr, waveform)`` where ``sr`` is the sampling rate of the source
        sample and ``waveform`` is the first item of the model output as a
        NumPy array.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    wp_src = pathlib.Path("aet_sample/src.wav")
    wav_src, sr = torchaudio.load(wp_src)

    sr_inp, wav_tar = audio
    if wav_tar.ndim > 1:
        # Down-mix multi-channel input; otherwise the resampler would run
        # along the channel axis instead of the time axis.
        wav_tar = wav_tar.mean(axis=1)
    peak = np.max(np.abs(wav_tar))
    if peak > 0:
        # Scale into (-1, 1) with a little headroom. Skipped for an
        # all-zero input, where the division would produce NaNs.
        wav_tar = wav_tar / (peak * 1.1)
    wav_tar = torch.from_numpy(wav_tar.astype(np.float32))
    resampler = torchaudio.transforms.Resample(
        orig_freq=sr_inp,
        new_freq=sr,
    )
    wav_tar = resampler(wav_tar)

    config_path = pathlib.Path("configs/test/melspec/ssl_tono.yaml")
    # Context manager so the config file handle is closed after parsing.
    with open(config_path, "r") as config_file:
        config = yaml.load(config_file, Loader=yaml.FullLoader)

    melspec_src = calc_spectrogram(
        normalize_waveform(wav_src.squeeze(0), sr), config
    )
    wav_tar = normalize_waveform(wav_tar.squeeze(0), sr)

    ckpt_path = pathlib.Path("tono_aet_melspec.ckpt").resolve()
    # load_from_checkpoint is a classmethod: call it on the class instead of
    # building a throwaway instance first (newer Lightning releases reject
    # the instance call).
    src_model = SSLDualLightningModule.load_from_checkpoint(
        checkpoint_path=ckpt_path,
        config=config,
        strict=False,
    ).eval()

    encoder_src = src_model.encoder.to(device)
    channelfeats_src = src_model.channelfeats.to(device)
    channel_src = src_model.channel.to(device)

    with torch.no_grad():
        # Add two leading dimensions and swap the last two axes before
        # feeding the spectrogram to the encoder.
        _, enc_hidden_src = encoder_src(
            melspec_src.unsqueeze(0).unsqueeze(1).transpose(2, 3).to(device)
        )
        chfeats_src = channelfeats_src(enc_hidden_src)
        # The target waveform must live on the same device as the channel
        # module and its features; it was previously left on the CPU.
        wav_transfer = channel_src(wav_tar.unsqueeze(0).to(device), chfeats_src)

    wav_transfer = wav_transfer.cpu().detach().numpy()[0, :]
    return sr, wav_transfer
|
|
|
if __name__ == "__main__":
    # Fetch the pretrained checkpoint into the working directory
    # (curl -O keeps the remote file name, -L follows redirects).
    ckpt_url = "https://sarulab.sakura.ne.jp/saeki/selfremaster/pretrained/tono_aet_melspec.ckpt"
    subprocess.run(["curl", "-OL", ckpt_url])

    # Fetch the HiFi-GAN file from Google Drive.
    download_file_from_google_drive(
        "10OJ2iznutxzp8MEIS6lBVaIS_g5c_70V",
        "hifigan/hifigan_melspec_universal",
    )

    # Build the web demo: one audio input, one NumPy audio output.
    iface = gr.Interface(
        fn=transfer,
        inputs="audio",
        outputs=gr.outputs.Audio(type="numpy"),
        examples=[["aet_sample/tar.wav"]],
        layout="horizontal",
        title='Audio effect transfer with SelfRemaster',
        description='Extracting the channel feature of a historical audio recording with a pretrained SelfRemaster and adding it to any high-quality audio. (Source audio is aet_sample/src.wav)',
    )

    iface.launch()