# Hugging Face Space status: Running on Zero
import sys | |
import os | |
# Debug mode is enabled by a `--debug` command-line flag or by DEBUG=True in the
# environment; it puts a local checkout of the package ahead of site-packages.
debug_mode = ('--debug' in sys.argv) or (os.environ.get('DEBUG') == 'True')
if not debug_mode:
    print("Running in normal mode. Using package from site-packages.")
else:
    # Relative path, so it is resolved against the current working directory.
    local_package_path = "../../GaMaDHaNi"
    sys.path.insert(0, local_package_path)
    print(f"Running in debug mode. Using package from: {local_package_path}")
import spaces | |
import gradio as gr | |
import numpy as np | |
import torch | |
import librosa | |
import matplotlib.pyplot as plt | |
import pandas as pd | |
from functools import partial | |
import gin | |
import torchaudio | |
from absl import app | |
from torch.nn.functional import interpolate | |
import logging | |
import crepe | |
from hmmlearn import hmm | |
import soundfile as sf | |
import pdb | |
from gamadhani.utils.generate_utils import load_pitch_fns, load_audio_fns | |
import gamadhani.utils.pitch_to_audio_utils as p2a | |
from gamadhani.utils.utils import get_device | |
# Checkpoint directories for the pitch (diffusion) and pitch-to-audio models.
# These are relative paths, resolved against the current working directory.
pitch_path, audio_path = 'models/diffusion_pitch/', 'models/pitch_to_audio/'
# Device selected by gamadhani's `get_device` helper; shared by both model loaders.
device = get_device()
def predict_voicing(confidence):
    """Label every frame as voiced (1) or unvoiced (0) with a fixed 2-state HMM.

    Adapted from https://github.com/marl/crepe/pull/26. The HMM is not fitted;
    its parameters are set by hand and the state path is decoded with
    hmmlearn's ``predict``.

    Parameters
    ----------
    confidence : np.ndarray [shape=(N,)]
        Per-frame voicing confidence, i.e. the confidence that a pitch is
        present in that frame.

    Returns
    -------
    np.ndarray [shape=(N,)]
        Decoded state per frame: 0 if unvoiced, 1 if voiced.
    """
    voicing_hmm = hmm.GaussianHMM(n_components=2)
    # Parameters are assigned in the same order as the original tuple assignment.
    # Uniform prior over the two starting states.
    voicing_hmm.startprob_ = np.array([0.5, 0.5])
    # Both emission Gaussians share variance 0.25.
    voicing_hmm.covars_ = np.array([[0.25], [0.25]])
    # 0.99 self-transition probability favours long runs of the same state.
    voicing_hmm.transmat_ = np.array([[0.99, 0.01], [0.01, 0.99]])
    # Emission means: 0.0 for the unvoiced state, 1.0 for the voiced state.
    voicing_hmm.means_ = np.array([[0.0], [1.0]])
    voicing_hmm.n_features = 1

    # One feature per frame, decoded as a single sequence of length N.
    observations = confidence.reshape(-1, 1)
    decoded = voicing_hmm.predict(observations, [len(confidence)])
    return np.array(decoded)
def extract_pitch(audio, unvoice=True, sr=16000, frame_shift_ms=10, log=True):
    """Estimate the f0 contour of ``audio`` with CREPE.

    Args:
        audio: waveform passed directly to ``crepe.predict``.
        unvoice: when true, frames that ``predict_voicing`` labels unvoiced
            have their frequency multiplied by 0.
        sr: sample rate of ``audio``.
        frame_shift_ms: CREPE step size in milliseconds.
        log: when true, CREPE runs with ``verbose=1``; otherwise ``verbose=0``.

    Returns:
        ``(time, f0, confidence)`` as produced by CREPE, with ``f0`` masked when
        ``unvoice`` is set.
    """
    timestamps, frequency, confidence, _ = crepe.predict(
        audio,
        sr=sr,
        viterbi=True,
        step_size=frame_shift_ms,
        verbose=1 if log else 0,
    )
    if not unvoice:
        return timestamps, frequency, confidence
    # The voicing mask is 0/1 per frame, so multiplying zeroes out unvoiced frames.
    return timestamps, frequency * predict_voicing(confidence), confidence
def generate_pitch_reinterp(pitch, pitch_model, invert_pitch_fn, num_samples, num_steps, noise_std=0.4,
                            seq_len=1200, clip_value=5.19):
    '''Generate pitch values for the melodic reinterpretation task.

    The last ``seq_len`` frames of ``pitch`` are perturbed with Gaussian noise,
    clipped to ``[-clip_value, clip_value]`` and passed to
    ``pitch_model.sample_sdedit``.

    Args:
        pitch: numpy array or torch tensor indexed as ``pitch[:, :, t]`` (time
            on the last axis). It may already live on any device.
        pitch_model: object exposing ``device`` and
            ``sample_sdedit(x, num_samples, num_steps)``.
        invert_pitch_fn: callable taking ``f0=`` and returning a sequence whose
            first item is kept.
        num_samples: forwarded to ``sample_sdedit``.
        num_steps: forwarded to ``sample_sdedit``.
        noise_std: standard deviation of the added noise. One noise vector over
            the time axis is shared across the leading dimensions.
        seq_len: number of trailing frames used (default 1200).
        clip_value: symmetric clipping bound applied after adding noise.

    Returns:
        ``(samples, inverted_pitches)``: the raw model samples, and a
        one-element list containing ``invert_pitch_fn`` applied to the first
        sample only.
    '''
    # as_tensor accepts numpy arrays and tensors on any device. The legacy
    # torch.Tensor constructor rejects non-CPU tensors.
    window = torch.as_tensor(pitch[:, :, -seq_len:], dtype=torch.float32, device=pitch_model.device)
    # Noise length follows the actual window length, so inputs shorter than
    # `seq_len` still broadcast. It is drawn on the CPU, as before.
    noise = torch.normal(mean=0.0, std=noise_std * torch.ones(window.shape[-1]))
    noisy_pitch = window + noise.to(pitch_model.device)
    # Clip the pitch values to be within the range of the model.
    noisy_pitch = torch.clamp(noisy_pitch, -clip_value, clip_value)
    samples = pitch_model.sample_sdedit(noisy_pitch, num_samples, num_steps)
    inverted_pitches = [invert_pitch_fn(f0=samples.detach().cpu().numpy()[0])[0]]
    return samples, inverted_pitches
def generate_audio(audio_model, f0s, invert_audio_fn, singers=None, num_steps=100):
    '''Generate audio given pitch values.

    Args:
        audio_model: object exposing ``device`` and ``sample_cfg``.
        f0s: batch of conditioning pitch contours; ``f0s.shape[0]`` is the
            batch size.
        invert_audio_fn: maps the model samples to the returned audio.
        singers: singer ids (default ``[3]``). Each id is repeated
            ``f0s.shape[0]`` times with ``np.repeat``.
        num_steps: sampling steps forwarded to ``sample_cfg``.

    Returns:
        Whatever ``invert_audio_fn`` returns for the sampled output.
    '''
    # None instead of a mutable list default; [3] keeps the previous behaviour.
    if singers is None:
        singers = [3]
    # NOTE(review): with more than one singer id, np.repeat yields
    # len(singers) * batch entries, which may not match the batch size. Confirm
    # the intended usage.
    singer_tensor = torch.tensor(np.repeat(singers, repeats=f0s.shape[0])).to(audio_model.device)
    # Guidance strength is fixed at 3. The other returned values are not used.
    samples, _, _ = audio_model.sample_cfg(f0s.shape[0], f0=f0s, num_steps=num_steps, singer=singer_tensor, strength=3)
    return invert_audio_fn(samples)
def generate(pitch, num_samples=1, num_steps=100, singers=None, outfolder='temp', audio_seq_len=750, pitch_qt=None):
    '''Reinterpret a pitch contour and synthesize audio for it.

    Uses the module-level ``pitch_model``, ``invert_pitch_fn``,
    ``audio_model`` and ``invert_audio_fn``.

    Args:
        pitch: model-space pitch passed to ``generate_pitch_reinterp``.
        num_samples: forwarded to the pitch model.
        num_steps: sampling steps for both the pitch and the audio model.
        singers: singer ids for ``generate_audio`` (default ``[3]``).
        outfolder: unused; kept for interface compatibility.
        audio_seq_len: target length for ``p2a.interpolate_pitch``.
        pitch_qt: optional quantile transformer. When given, the generated
            pitch is mapped back with ``inverse_transform``.

    Returns:
        ``((16000, audio), model_output_plot, pitch_vals)``. ``audio`` is the
        first generated waveform. ``pitch_vals`` is the flattened first
        contour with exact zeros replaced by NaN.
    '''
    logging.info('Generate function')
    if singers is None:
        singers = [3]
    # num_steps was previously ignored here in favour of a hard-coded 100.
    pitch, _ = generate_pitch_reinterp(pitch, pitch_model, invert_pitch_fn, num_samples=num_samples,
                                       num_steps=num_steps)
    if pitch_qt is not None:
        # A quantile transformer was supplied: undo the quantile transformation
        # applied to the model's pitch.
        def undo_qt(x, min_clip=200):
            values = pitch_qt.inverse_transform(x.reshape(-1, 1)).reshape(1, -1)
            # Round to nearest integer, as done when preprocessing the model's input contour.
            values = np.around(values)
            # Values below `min_clip` become NaN (previously the argument was ignored).
            values[values < min_clip] = np.nan
            return values
        pitch = torch.tensor(np.array([undo_qt(x) for x in pitch.detach().cpu().numpy()])).to(pitch_model.device)
    # Interpolate pitch values to match the audio model's input size.
    interpolated_pitch = p2a.interpolate_pitch(pitch=pitch, audio_seq_len=audio_seq_len)
    # Replace NaN values with the silent token (196).
    interpolated_pitch = torch.nan_to_num(interpolated_pitch, nan=196)
    # Drop the extra dimension to match the audio model's input size.
    interpolated_pitch = interpolated_pitch.squeeze(1)
    audio = generate_audio(audio_model, interpolated_pitch, invert_audio_fn, singers=singers, num_steps=num_steps)
    audio = audio.detach().cpu().numpy()
    pitch = pitch.detach().cpu().numpy()
    # Mask zeros elementwise. Previously only the first frame was tested, and
    # a zero there blanked the whole contour.
    contour = pitch[0].flatten()
    pitch_vals = np.where(contour == 0, np.nan, contour)
    # Plot of the model output, displayed on the interface.
    model_output_plot, ax = plt.subplots()
    ax.plot(pitch_vals, label='Model Output')
    plt.close(model_output_plot)
    return (16000, audio[0]), model_output_plot, pitch_vals
# Pitch model (diffusion) plus its quantile transformer, task function and
# inversion function, loaded from `pitch_path`.
pitch_model, pitch_qt, pitch_task_fn, invert_pitch_fn, _ = load_pitch_fns(
    os.path.join(pitch_path, 'last.ckpt'),
    model_type='diffusion',
    config_path=os.path.join(pitch_path, 'config.gin'),
    qt_path=os.path.join(pitch_path, 'qt.joblib'),
    device=device,
)

# Pitch-to-audio model and its helpers, loaded from `audio_path`.
audio_model, audio_qt, audio_seq_len, invert_audio_fn = load_audio_fns(
    os.path.join(audio_path, 'last.ckpt'),
    qt_path=os.path.join(audio_path, 'qt.joblib'),
    config_path=os.path.join(audio_path, 'config.gin'),
    device=device,
)

# `generate` bound to the demo's fixed settings.
# NOTE(review): the `audio_seq_len` loaded above is not passed here, so
# `generate` falls back to its own default. Confirm this is intended.
partial_generate = partial(
    generate,
    num_samples=1,
    num_steps=100,
    singers=[3],
    outfolder=None,
    pitch_qt=pitch_qt,
)
def _prepare_guide_audio(sr, audio, seconds=12, target_sr=16000):
    """Return the last ``seconds`` of ``audio`` as peak-normalised float32 mono at ``target_sr``."""
    audio = np.asarray(audio)
    if audio.ndim > 1:
        # Down-mix multi-channel input to mono. Assumes a (samples, channels)
        # layout; TODO confirm against the Gradio version in use.
        audio = audio.mean(axis=1)
    # Zero-pad short recordings at the end so at least `seconds` are available.
    if len(audio) < seconds * sr:
        audio = np.pad(audio, (0, seconds * sr - len(audio)), mode='constant')
    audio = audio.astype(np.float32)
    # Peak-normalise, skipping all-zero input to avoid dividing by zero.
    peak = np.max(np.abs(audio))
    if peak > 0:
        audio /= peak
    # Resample the whole signal, then keep only the trailing window.
    audio = librosa.resample(audio, orig_sr=sr, target_sr=target_sr)
    return audio[-seconds * target_sr:]


def set_guide_and_generate(audio):
    """Gradio callback: extract pitch from the user's audio and generate a reinterpretation.

    Args:
        audio: ``(sample_rate, samples)`` tuple from ``gr.Audio``, or ``None``.

    Returns:
        ``(generated_audio, user_input_plot, generated_pitch_plot)``, one value
        per output component wired in the UI. All three are ``None`` when no
        audio is given.

    NOTE(review): `spaces` is imported but this callback is not decorated with
    ``@spaces.GPU``. On ZeroGPU hardware that decorator may be required for GPU
    access. Confirm.
    """
    if audio is None:
        # Three outputs are wired to this callback, so return three values.
        return None, None, None
    sr, samples = audio
    guide = _prepare_guide_audio(sr, samples)
    _, f0, _ = extract_pitch(guide)
    # Keep the raw user pitch for plotting.
    mic_f0 = f0.copy()
    f0 = pitch_task_fn(
        inputs={
            'pitch': {
                'data': torch.Tensor(f0),  # task function expects a tensor
                'sampling_rate': 100,
            }
        },
        qt_transform=pitch_qt,
        time_downsample=1,  # pitch is extracted at 100 Hz, so no downsampling
        seq_len=None,
    )['sampled_sequence']
    f0 = torch.as_tensor(f0.reshape(1, 1, -1)).to(pitch_model.device).float()
    generated_audio, generated_pitch_plot, _ = partial_generate(f0)
    mic_f0 = np.where(mic_f0 == 0, np.nan, mic_f0)
    # Plot of the user input.
    user_input_plot, ax = plt.subplots()
    ax.plot(np.arange(len(mic_f0)), mic_f0, label='User Input')
    plt.close(user_input_plot)
    return generated_audio, user_input_plot, generated_pitch_plot
# Two-column layout: audio input, button and user-pitch plot on the left;
# generated audio and generated-pitch plot on the right.
with gr.Blocks() as demo:
    with gr.Row():
        with gr.Column():
            audio = gr.Audio(label="Input")
            sbmt = gr.Button()
            user_input = gr.Plot(label="User Input")
        with gr.Column():
            generated_audio = gr.Audio(label="Generated Audio")
            generated_pitch = gr.Plot(label="Generated Pitch")
    # Outputs map, in order, onto the 3-tuple returned by set_guide_and_generate.
    sbmt.click(
        fn=set_guide_and_generate,
        inputs=[audio],
        outputs=[generated_audio, user_input, generated_pitch],
    )
def main(argv):
    """Launch the Gradio demo with a public share link; ``argv`` is accepted but not used."""
    del argv  # unused
    demo.launch(share=True)


if __name__ == '__main__':
    main(sys.argv)