"""Gradio demo: audio effect transfer with a pretrained SelfRemaster model."""

import pathlib
import subprocess

import gradio as gr
import numpy as np
import requests
import torch
import torchaudio
import yaml

from lightning_module import SSLDualLightningModule


def normalize_waveform(wav, sr, db=-3):
    """Normalize a waveform with the SoX ``norm`` effect.

    Args:
        wav: Waveform tensor. A leading axis is added before the SoX call and
            removed again afterwards (callers in this file pass a tensor that
            has already been through ``squeeze(0)``).
        sr: Sampling rate handed to SoX.
        db: Argument of the ``norm`` effect.

    Returns:
        The processed waveform with the leading axis squeezed away.
    """
    wav, _ = torchaudio.sox_effects.apply_effects_tensor(
        wav.unsqueeze(0),
        sr,
        [["norm", f"{db}"]],
    )
    return wav.squeeze(0)


def download_file_from_google_drive(id, destination):  # noqa: A002 - public parameter name
    """Download a Google Drive file to ``destination``.

    A first request is issued; if the response carries a ``download_warning``
    cookie, the request is repeated with that cookie value as the ``confirm``
    parameter before the body is written to disk.

    Args:
        id: Google Drive file id.
        destination: Path of the file to write.
    """
    URL = "https://docs.google.com/uc?export=download"

    # The session is closed on exit so pooled connections are released.
    with requests.Session() as session:
        response = session.get(URL, params={"id": id}, stream=True)
        token = get_confirm_token(response)

        if token:
            # Release the unread streamed response before replacing it.
            response.close()
            response = session.get(
                URL, params={"id": id, "confirm": token}, stream=True
            )

        try:
            save_response_content(response, destination)
        finally:
            response.close()


def get_confirm_token(response):
    """Return the value of the first ``download_warning*`` cookie, or ``None``."""
    for key, value in response.cookies.items():
        if key.startswith("download_warning"):
            return value
    return None


def save_response_content(response, destination):
    """Stream the body of ``response`` into the file ``destination``.

    The parent directory of ``destination`` is created when missing.
    """
    CHUNK_SIZE = 32768

    pathlib.Path(destination).parent.mkdir(parents=True, exist_ok=True)
    with open(destination, "wb") as f:
        for chunk in response.iter_content(CHUNK_SIZE):
            if chunk:  # filter out keep-alive new chunks
                f.write(chunk)


def calc_spectrogram(wav, config):
    """Compute a log-compressed magnitude mel spectrogram.

    Args:
        wav: Waveform tensor passed to ``torchaudio.transforms.MelSpectrogram``.
        config: Mapping whose ``"preprocess"`` entry provides
            ``sampling_rate``, ``fft_length``, ``frame_length``,
            ``frame_shift``, ``fmin``, ``fmax``, ``n_mels``,
            ``min_magnitude`` and ``comp_factor``.

    Returns:
        ``log(clamp_min(mel, min_magnitude) * comp_factor)`` as ``float32``.
    """
    pre = config["preprocess"]
    spec_module = torchaudio.transforms.MelSpectrogram(
        sample_rate=pre["sampling_rate"],
        n_fft=pre["fft_length"],
        win_length=pre["frame_length"],
        hop_length=pre["frame_shift"],
        f_min=pre["fmin"],
        f_max=pre["fmax"],
        n_mels=pre["n_mels"],
        power=1,  # magnitude rather than power spectrogram
        center=True,
        norm="slaney",
        mel_scale="slaney",
    )
    specs = spec_module(wav)
    log_spec = torch.log(
        torch.clamp_min(specs, pre["min_magnitude"]) * pre["comp_factor"]
    ).to(torch.float32)
    return log_spec


def transfer(audio):
    """Apply the channel feature of ``aet_sample/src.wav`` to the given audio.

    Args:
        audio: ``(sampling_rate, samples)`` tuple with ``samples`` as a NumPy
            array. NOTE(review): Gradio is expected to deliver either a 1-D
            array or a ``(samples, channels)`` array - confirm for the
            installed Gradio version.

    Returns:
        ``(sampling_rate, waveform)`` where the sampling rate is that of the
        source file and ``waveform`` is a NumPy array.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    wp_src = pathlib.Path("aet_sample/src.wav")
    wav_src, sr = torchaudio.load(wp_src)

    sr_inp, wav_tar = audio
    # Convert to float before taking magnitudes: np.abs on int16 overflows
    # for the sample value -32768.
    wav_tar = np.asarray(wav_tar, dtype=np.float32)
    if wav_tar.ndim == 2:
        # Downmix to mono, treating the shorter axis as the channel axis, so
        # that resampling below runs along time rather than across channels.
        wav_tar = wav_tar.mean(axis=int(np.argmin(wav_tar.shape)))
    peak = np.max(np.abs(wav_tar))
    if peak > 0:  # a silent clip would otherwise turn into NaNs
        wav_tar = wav_tar / (peak * 1.1)
    wav_tar = torch.from_numpy(wav_tar.astype(np.float32))

    resampler = torchaudio.transforms.Resample(
        orig_freq=sr_inp,
        new_freq=sr,
    )
    wav_tar = resampler(wav_tar)

    config_path = pathlib.Path("configs/test/melspec/ssl_tono.yaml")
    # NOTE(review): FullLoader is kept to preserve behavior; prefer
    # yaml.safe_load if the config only contains plain data.
    with open(config_path, "r") as f:
        config = yaml.load(f, Loader=yaml.FullLoader)

    melspec_src = calc_spectrogram(
        normalize_waveform(wav_src.squeeze(0), sr), config
    )
    wav_tar = normalize_waveform(wav_tar.squeeze(0), sr)

    ckpt_path = pathlib.Path("tono_aet_melspec.ckpt").resolve()
    # load_from_checkpoint is a classmethod: calling it on the class avoids
    # constructing a throwaway model first.
    src_model = SSLDualLightningModule.load_from_checkpoint(
        checkpoint_path=ckpt_path, config=config, strict=False
    ).eval()

    encoder_src = src_model.encoder.to(device)
    channelfeats_src = src_model.channelfeats.to(device)
    channel_src = src_model.channel.to(device)

    with torch.no_grad():
        _, enc_hidden_src = encoder_src(
            melspec_src.unsqueeze(0).unsqueeze(1).transpose(2, 3).to(device)
        )
        chfeats_src = channelfeats_src(enc_hidden_src)
        # The input must live on the same device as the channel module.
        wav_transfer = channel_src(wav_tar.unsqueeze(0).to(device), chfeats_src)

    wav_transfer = wav_transfer.cpu().detach().numpy()[0, :]
    return sr, wav_transfer


if __name__ == "__main__":
    # Fetch the pretrained checkpoint into the working directory; the exit
    # status of curl is not checked (best effort, as before).
    subprocess.run(
        [
            "curl",
            "-OL",
            "https://sarulab.sakura.ne.jp/saeki/selfremaster/pretrained/tono_aet_melspec.ckpt",
        ],
        check=False,
    )
    download_file_from_google_drive(
        "10OJ2iznutxzp8MEIS6lBVaIS_g5c_70V", "hifigan/hifigan_melspec_universal"
    )

    iface = gr.Interface(
        transfer,
        "audio",
        gr.outputs.Audio(type="numpy"),
        examples=[
            ["aet_sample/tar.wav"]
        ],
        layout="horizontal",
        title='Audio effect transfer with SelfRemaster',
        description='Extracting the channel feature of a historical audio recording with a pretrained SelfRemaster and adding it to any high-quality audio. (Source audio is aet_sample/src.wav)'
    )

    iface.launch()