# PitchVC-vino / app.py
# Last commit b562f6f by OlaWod: "support tune f0"
import os
import json
import math
import torch
import torch.nn.functional as F
import librosa
import numpy as np
import soundfile as sf
import gradio as gr
import openvino as ov
from env import AttrDict
from meldataset import mel_spectrogram, MAX_WAV_VALUE
from stft import TorchSTFT
# files
# Generator config and the two model graphs compiled with OpenVINO below.
hpfile = "config_v1_16k.json"
g1path, g2path = "exp/g1.xml", "exp/g2.xml"
# Per-speaker metadata, read as JSON below.
spk2id_path = "filelists/spk2id.json"
f0_stats_path = "filelists/f0_stats.json"
spk_stats_path = "filelists/spk_stats.json"
# Speaker assets, laid out as <dir>/<speaker>/<stem>.npy (embeddings) and .wav (reference audio).
spk_emb_dir, spk_wav_dir = "dataset/spk", "dataset/audio"
# load config
# JSON is UTF-8 by spec; state the encoding instead of relying on the
# platform's locale default, which differs on e.g. Windows.
with open(hpfile, encoding="utf-8") as f:
    data = f.read()
json_config = json.loads(data)
# Attribute-style access to the hyperparameters (h.n_fft, h.sampling_rate, ...).
h = AttrDict(json_config)
# load models
# Both generator stages are read from disk and compiled for CPU inference.
core = ov.Core()
g1 = core.compile_model(model=core.read_model(model=g1path), device_name="CPU")
g2 = core.compile_model(model=core.read_model(model=g2path), device_name="CPU")
# The STFT/iSTFT between and after the two stages runs in torch; its sizes
# come from the generator's iSTFT hyperparameters.
stft = TorchSTFT(
    filter_length=h.gen_istft_n_fft,
    hop_length=h.gen_istft_hop_size,
    win_length=h.gen_istft_n_fft,
)
# load stats
# JSON is UTF-8 by spec; pass the encoding explicitly instead of relying on
# the platform's locale default.
with open(spk2id_path, encoding="utf-8") as f:
    spk2id = json.load(f)  # speaker key -> integer id passed to g1
with open(f0_stats_path, encoding="utf-8") as f:
    f0_stats = json.load(f)  # speaker key -> stats dict; convert() reads its "mean"
with open(spk_stats_path, encoding="utf-8") as f:
    spk_stats = json.load(f)  # speaker key -> stats dict; "best_spk_emb" is a file stem
# tune f0
threshold = 10  # F0 values at or below this are treated as unvoiced (presumably Hz)
step = (math.log(1100) - math.log(50)) / 256  # one shift unit in log-F0 space


def tune_f0(initial_f0, i):
    """Shift the voiced values of ``initial_f0`` by ``i`` steps in log-F0 space.

    Each step multiplies a voiced value by ``exp(step)``, where ``step`` is
    1/256 of the log range between 50 and 1100. Values at or below
    ``threshold`` are treated as unvoiced and returned unchanged.

    Args:
        initial_f0: ndarray of F0 values. ``convert`` passes a (1, 1)
            float32 array holding the target speaker's mean F0.
        i: number of steps. Positive raises, negative lowers, and 0 returns
            ``initial_f0`` itself (same object).

    Returns:
        ndarray of the same shape as ``initial_f0``. For float input the
        dtype is also kept, which matters because the result is fed to g1.
    """
    if i == 0:
        return initial_f0
    voiced = initial_f0 > threshold
    # Take the log only of voiced values. Unvoiced entries get a dummy 1.0 so
    # np.log never sees 0 (divide-by-zero warning) or negatives (NaN); their
    # shifted values are discarded below anyway. ones_like keeps the dtype.
    safe_f0 = np.where(voiced, initial_f0, np.ones_like(initial_f0))
    shifted = np.exp(np.log(safe_f0) + step * i)
    return np.where(voiced, shifted, initial_f0)
# infer
def infer(wav, mel, spk_emb, spk_id, f0_mean_tgt):
    """Synthesise a waveform with the g1 -> STFT -> g2 -> iSTFT pipeline.

    The five arguments are fed, in this order, to the compiled ``g1`` model.
    Its second output (a harmonic source signal) is moved into the STFT
    domain with ``stft.transform``. ``g2`` consumes that together with g1's
    first output, and its spectrum/phase outputs are turned back into a
    waveform with ``stft.inverse``.

    Returns:
        The tensor produced by ``stft.inverse``.
    """
    # Stage 1 (OpenVINO).
    stage1 = g1([wav, mel, spk_emb, spk_id, f0_mean_tgt])
    feats = stage1[g1.output(0)]
    source = torch.from_numpy(stage1[g1.output(1)])

    # torch STFT of the harmonic source; g2 takes numpy arrays.
    src_spec, src_phase = (t.numpy() for t in stft.transform(source))

    # Stage 2 (OpenVINO), then back to the time domain.
    stage2 = g2([feats, src_spec, src_phase])
    out_spec = torch.from_numpy(stage2[g2.output(0)])
    out_phase = torch.from_numpy(stage2[g2.output(1)])
    return stft.inverse(out_spec, out_phase)
# convert function
def convert(tgt_spk, src_wav, f0_shift=0):
    """Convert ``src_wav`` to the voice of ``tgt_spk`` and return the output path.

    Args:
        tgt_spk: target speaker key; looked up in ``spk2id``, ``spk_stats``
            and ``f0_stats``.
        src_wav: path to the source audio file (Gradio ``type='filepath'``).
        f0_shift: number of log-F0 steps passed to ``tune_f0`` to shift the
            target speaker's mean F0 (0 = no shift).

    Returns:
        Path of the written 16-bit PCM WAV file (``"out.wav"``).

    Raises:
        gr.Error: if no target speaker or no source audio was supplied, so
            the UI shows a readable message instead of a KeyError.
    """
    if not tgt_spk:
        raise gr.Error("Please select a target speaker.")
    if not src_wav:
        raise gr.Error("Please provide a source audio.")
    tgt_ref = spk_stats[tgt_spk]["best_spk_emb"]
    tgt_emb = f"{spk_emb_dir}/{tgt_spk}/{tgt_ref}.npy"
    with torch.no_grad():
        # tgt: speaker id, speaker embedding and (optionally shifted) mean F0,
        # each given a leading batch axis.
        spk_id = np.array([spk2id[tgt_spk]], dtype=np.int64)[None, :]
        spk_emb = np.load(tgt_emb)[None, :]
        f0_mean_tgt = np.array([f0_stats[tgt_spk]["mean"]], dtype=np.float32)[None, :]
        f0_mean_tgt = tune_f0(f0_mean_tgt, f0_shift)
        # src: load at the configured rate, because mel_spectrogram is told the
        # audio is at h.sampling_rate and the output is written at that rate.
        wav, _ = librosa.load(src_wav, sr=h.sampling_rate)
        wav = wav[None, :]
        mel = mel_spectrogram(torch.from_numpy(wav), h.n_fft, h.num_mels, h.sampling_rate, h.hop_size, h.win_size, h.fmin, h.fmax).numpy()
        # cvt
        y = infer(wav, mel, spk_emb, spk_id, f0_mean_tgt)
        audio = y.squeeze()
        # Peak-normalise to 0.95 of full scale. Skip this for all-zero output:
        # dividing by a zero peak yields NaNs, which cast to garbage int16.
        peak = torch.max(torch.abs(audio))
        if peak > 0:
            audio = audio / peak * 0.95
        audio = (audio * MAX_WAV_VALUE).cpu().numpy().astype('int16')
    # NOTE(review): every request writes the same "out.wav" in the working
    # directory; if requests are served concurrently, outputs may overwrite
    # each other. Consider a per-request temp file.
    out_wav = "out.wav"
    sf.write(out_wav, audio, h.sampling_rate, "PCM_16")
    return out_wav
# change spk
def change_spk(tgt_spk):
    """Return the reference-audio path to display for speaker ``tgt_spk``.

    The file stem is the speaker's ``best_spk_emb`` entry in ``spk_stats``,
    the same stem ``convert`` uses to pick the speaker embedding.
    """
    best_ref = spk_stats[tgt_spk]["best_spk_emb"]
    return "{}/{}/{}.wav".format(spk_wav_dir, tgt_spk, best_ref)
# interface
with gr.Blocks() as demo:
    gr.Markdown("# PitchVC-vino")
    gr.Markdown("Gradio Demo for PitchVC with OpenVINO on CPU. ([Github Repo](https://github.com/OlaWod/PitchVC))")
    with gr.Row():
        with gr.Column():
            # Pass a real list: a dict_keys view is not a list and may fail
            # to serialise into the component config on some Gradio versions.
            tgt_spk = gr.Dropdown(choices=list(spk2id), type="value", label="Target Speaker")
            ref_audio = gr.Audio(label="Reference Audio", type='filepath')
            src_audio = gr.Audio(label="Source Audio", type='filepath')
            f0_shift = gr.Slider(minimum=-30, maximum=30, value=0, step=1, label="F0 Shift")
        with gr.Column():
            out_audio = gr.Audio(label="Output Audio", type='filepath')
            submit = gr.Button(value="Submit")
    # Picking a speaker shows that speaker's reference recording.
    tgt_spk.change(fn=change_spk, inputs=[tgt_spk], outputs=[ref_audio])
    submit.click(convert, [tgt_spk, src_audio, f0_shift], [out_audio])
    examples = gr.Examples(
        examples=[["p225", 'dataset/audio/p226/p226_341.wav', 0],
                  ["p226", 'dataset/audio/p225/p225_220.wav', -5]],
        inputs=[tgt_spk, src_audio, f0_shift])
demo.launch()