GaMaDHaNi / app.py
Nithya
updated parent repo and restructured things
98eb218
raw
history blame
7.86 kB
import spaces
import gradio as gr
import numpy as np
import torch
import librosa
import matplotlib.pyplot as plt
import pandas as pd
import os
from functools import partial
import gin
from gamadhani.utils.generate_utils import load_pitch_fns, load_audio_fns
import gamadhani.utils.pitch_to_audio_utils as p2a
from gamadhani.utils.utils import get_device
import torchaudio
from absl import app
from torch.nn.functional import interpolate
import logging
import crepe
from hmmlearn import hmm
import soundfile as sf
import pdb
# Checkpoint directories; the loader calls further down read last.ckpt,
# config.gin and qt.joblib from each of them.
pitch_path = "models/diffusion_pitch/"
audio_path = "models/pitch_to_audio/"

# Device reported by the project helper.
device = get_device()
def predict_voicing(confidence):
    # https://github.com/marl/crepe/pull/26
    """
    Find the Viterbi path for voiced versus unvoiced frames.

    Parameters
    ----------
    confidence : array-like [shape=(N,)]
        voicing confidence array, i.e. the confidence in the presence of
        a pitch

    Returns
    -------
    voicing_states : np.ndarray [shape=(N,)]
        HMM predictions for each frame's state, 0 if unvoiced, 1 if
        voiced. An empty input yields an empty integer array.
    """
    # accept lists as well as arrays, and shape the data as one
    # single-feature observation sequence
    observations = np.asarray(confidence).reshape(-1, 1)
    n_frames = observations.shape[0]
    if n_frames == 0:
        # nothing to decode, so skip building and running the HMM
        return np.zeros(0, dtype=int)

    # uniform prior on the voicing confidence
    starting = np.array([0.5, 0.5])

    # transition probabilities inducing continuous voicing state
    transition = np.array([[0.99, 0.01], [0.01, 0.99]])

    # mean and variance for unvoiced and voiced states
    means = np.array([[0.0], [1.0]])
    variances = np.array([[0.25], [0.25]])

    # fix the model parameters because we are not optimizing the model
    model = hmm.GaussianHMM(n_components=2)
    model.startprob_ = starting
    model.covars_ = variances
    model.transmat_ = transition
    model.means_ = means
    model.n_features = 1

    # find the Viterbi path
    voicing_states = model.predict(observations, [n_frames])

    return np.array(voicing_states)
def extract_pitch(audio, unvoice=True, sr=16000, frame_shift_ms=10, log=True):
    """Run CREPE on ``audio`` and return ``(time, f0, confidence)``.

    Parameters
    ----------
    audio : samples handed to ``crepe.predict`` unchanged
    unvoice : when truthy, f0 is multiplied by the 0/1 states from
        ``predict_voicing`` so unvoiced frames become 0
    sr : sampling rate passed to CREPE
    frame_shift_ms : forwarded as CREPE's ``step_size``
    log : truthy selects ``verbose=1``, falsy selects ``verbose=0``
    """
    verbosity = 1 if log else 0
    time, frequency, confidence, _ = crepe.predict(
        audio,
        sr=sr,
        viterbi=True,
        step_size=frame_shift_ms,
        verbose=verbosity,
    )

    if not unvoice:
        # caller wants the raw CREPE track
        return time, frequency, confidence

    # zero out the frames the voicing HMM labels as unvoiced
    voiced_mask = predict_voicing(confidence)
    return time, frequency * voiced_mask, confidence
def generate_pitch_reinterp(
    pitch,
    pitch_model,
    invert_pitch_fn,
    num_samples,
    num_steps,
    noise_std=0.4,
    seq_len=1200,
    clip_value=5.19,
):
    '''Generate pitch values for the melodic reinterpretation task.

    The last ``seq_len`` frames of ``pitch`` (indexed as a 3-D array, time on
    the last axis) are perturbed with Gaussian noise, clamped to
    ``[-clip_value, clip_value]`` and handed to ``pitch_model.sample_sdedit``.

    Parameters
    ----------
    pitch : 3-D tensor or array holding the guide contour
    pitch_model : object exposing ``device`` and ``sample_sdedit``
    invert_pitch_fn : callable taking ``f0=`` and returning a sequence whose
        first item is kept
    num_samples, num_steps : forwarded to ``sample_sdedit``
    noise_std : standard deviation of the added noise
    seq_len : maximum number of trailing frames used as the guide
    clip_value : symmetric clamp bound applied to the noisy guide

    Returns
    -------
    (samples, inverted_pitches) : the raw sampler output, and a one-element
        list holding the inverted first sample
    '''
    device = pitch_model.device

    # as_tensor accepts arrays and tensors alike and does not choke on a
    # tensor that already lives on the target device
    guide = torch.as_tensor(pitch[:, :, -seq_len:], dtype=torch.float32).to(device)

    # size the noise from the guide itself so guides shorter than seq_len
    # still broadcast; one noise vector is shared across leading dimensions
    n_frames = guide.shape[-1]
    noise = torch.normal(mean=0.0, std=noise_std * torch.ones(n_frames)).to(device)

    # clipping the pitch values to be within the range of the model
    noisy_pitch = torch.clamp(guide + noise, -clip_value, clip_value)

    samples = pitch_model.sample_sdedit(noisy_pitch, num_samples, num_steps)
    first_sample = samples.detach().cpu().numpy()[0]
    inverted_pitches = [invert_pitch_fn(f0=first_sample)[0]]  # pitch values in Hz

    return samples, inverted_pitches
def generate_audio(audio_model, f0s, invert_audio_fn, singers=None, num_steps=100):
    '''Generate audio given pitch values.

    Parameters
    ----------
    audio_model : object exposing ``device`` and ``sample_cfg``
    f0s : batch of pitch contours; its first dimension is the batch size
    invert_audio_fn : callable applied to the sampler output to obtain audio
    singers : singer ids, defaults to ``[3]``; each id is repeated once per
        batch item
    num_steps : forwarded to ``sample_cfg``

    Returns
    -------
    Whatever ``invert_audio_fn`` returns for the sampled output.
    '''
    if singers is None:
        # avoid a shared mutable default; same value as before
        singers = [3]

    batch_size = f0s.shape[0]
    # NOTE(review): with more than one singer id the repeated tensor is longer
    # than the batch; confirm that is what sample_cfg expects
    singer_tensor = torch.tensor(np.repeat(singers, repeats=batch_size)).to(audio_model.device)

    samples, _, _ = audio_model.sample_cfg(
        batch_size,
        f0=f0s,
        num_steps=num_steps,
        singer=singer_tensor,
        strength=3,
    )
    return invert_audio_fn(samples)
@spaces.GPU(duration=120)
def generate(pitch, num_samples=1, num_steps=100, singers=None, outfolder='temp', audio_seq_len=750, pitch_qt=None):
    '''Reinterpret a guide pitch contour and synthesise audio from the result.

    Parameters
    ----------
    pitch : guide contour handed to ``generate_pitch_reinterp``
    num_samples : number of pitch samples requested from the pitch model
    num_steps : sampling steps used for both the pitch and the audio model
    singers : singer ids for the audio model, defaults to ``[3]``
    outfolder : accepted for compatibility; not used in this function
    audio_seq_len : target length passed to ``p2a.interpolate_pitch``
    pitch_qt : optional quantile transformer; when given, its inverse
        transform is applied to the sampled pitch

    Returns
    -------
    ((16000, audio), figure, pitch_vals) : the first generated audio clip
        paired with its sample rate, a matplotlib figure of the first pitch
        sample, and the plotted values with zeros replaced by NaN
    '''
    if singers is None:
        # avoid a shared mutable default; same value as before
        singers = [3]

    logging.log(logging.INFO, 'Generate function')
    pitch, _ = generate_pitch_reinterp(pitch, pitch_model, invert_pitch_fn, num_samples=num_samples, num_steps=num_steps)

    if pitch_qt is not None:
        # a quantile transformer was supplied, so undo the quantile
        # transformation applied to the model's pitch representation
        def undo_qt(x, min_clip=200):
            restored = pitch_qt.inverse_transform(x.reshape(-1, 1)).reshape(1, -1)
            # round to nearest integer, done in preprocessing of pitch contour fed into model
            restored = np.around(restored)
            # values below the clip threshold are marked as missing
            restored[restored < min_clip] = np.nan
            return restored

        restored_batch = np.array([undo_qt(x) for x in pitch.detach().cpu().numpy()])
        pitch = torch.tensor(restored_batch).to(pitch_model.device)

    # interpolate pitch values to match the audio model's input size
    interpolated_pitch = p2a.interpolate_pitch(pitch=pitch, audio_seq_len=audio_seq_len)
    # replace nan values with silent token
    interpolated_pitch = torch.nan_to_num(interpolated_pitch, nan=196)
    # to match input size by removing the extra dimension
    interpolated_pitch = interpolated_pitch.squeeze(1)

    audio = generate_audio(audio_model, interpolated_pitch, invert_audio_fn, singers=singers, num_steps=num_steps)
    audio = audio.detach().cpu().numpy()
    pitch = pitch.detach().cpu().numpy()

    # mask zeros frame by frame; the previous test looked only at the first
    # element and broadcast that single decision over the whole contour
    flat_pitch = pitch[0].flatten()
    pitch_vals = np.where(flat_pitch == 0, np.nan, flat_pitch)

    # generate plot of model output to display on interface; draw on an
    # explicit axes instead of relying on pyplot's "current figure" state
    model_output_plot = plt.figure()
    axes = model_output_plot.add_subplot(111)
    axes.plot(pitch_vals, label='Model Output')
    plt.close(model_output_plot)

    return (16000, audio[0]), model_output_plot, pitch_vals
# Load the pitch model together with its quantile transformer, task function
# and inversion function; the fifth return value is not used here.
pitch_model, pitch_qt, pitch_task_fn, invert_pitch_fn, _ = load_pitch_fns(
    os.path.join(pitch_path, 'last.ckpt'),
    model_type='diffusion',
    config_path=os.path.join(pitch_path, 'config.gin'),
    qt_path=os.path.join(pitch_path, 'qt.joblib'),
)

# Load the pitch-to-audio model and its helpers.
audio_model, audio_qt, audio_seq_len, invert_audio_fn = load_audio_fns(
    os.path.join(audio_path, 'last.ckpt'),
    qt_path=os.path.join(audio_path, 'qt.joblib'),
    config_path=os.path.join(audio_path, 'config.gin'),
)

# generate function with default arguments
# NOTE(review): the audio_seq_len loaded above is not forwarded, so generate()
# falls back to its own default; confirm the two values agree
partial_generate = partial(
    generate,
    num_samples=1,
    num_steps=100,
    singers=[3],
    outfolder=None,
    pitch_qt=pitch_qt,
)
@spaces.GPU(duration=120)
def set_guide_and_generate(audio):
    '''Turn a recorded or uploaded clip into a guide contour and generate from it.

    Parameters
    ----------
    audio : ``(sample_rate, samples)`` tuple from the audio input component,
        or ``None`` when nothing was provided

    Returns
    -------
    (generated_audio, user_input_plot, generated_pitch_plot) : one value per
        wired output component; all three are ``None`` when ``audio`` is
        ``None``
    '''
    if audio is None:
        # three outputs are wired to this handler, so return three placeholders
        return None, None, None

    clip_seconds = 12
    model_sr = 16000

    sr, audio = audio
    audio = np.asarray(audio, dtype=np.float32)
    if audio.ndim > 1:
        # assumes (samples, channels) layout for multi-channel input;
        # downmix so padding and resampling run along the time axis
        audio = audio.mean(axis=1)

    # pad short clips with trailing silence up to the guide length
    min_len = clip_seconds * sr
    if len(audio) < min_len:
        audio = np.pad(audio, (0, min_len - len(audio)), mode='constant')

    # peak-normalise; skip for all-zero input to avoid dividing by zero
    peak = np.max(np.abs(audio))
    if peak > 0:
        audio = audio / peak

    # resample the whole clip to the rate used for pitch extraction
    audio = librosa.resample(audio, orig_sr=sr, target_sr=model_sr)
    audio = audio[-clip_seconds * model_sr:]  # consider only last 12 s

    _, f0, _ = extract_pitch(audio)
    mic_f0 = f0.copy()  # save the user input pitch values

    f0 = pitch_task_fn(
        inputs={
            'pitch': {
                'data': torch.Tensor(f0),  # task function expects a tensor
                'sampling_rate': 100
            }
        },
        qt_transform=pitch_qt,
        time_downsample=1,  # pitch will be extracted at 100 Hz, thus no downsampling
        seq_len=None,
    )['sampled_sequence']

    # as_tensor accepts arrays and tensors without the copy-construct warning
    f0 = torch.as_tensor(f0).detach().reshape(1, 1, -1)
    f0 = f0.to(pitch_model.device).float()

    generated_audio, generated_pitch_plot, _ = partial_generate(f0)

    # frames with zero f0 are shown as gaps in the plot
    mic_f0 = np.where(mic_f0 == 0, np.nan, mic_f0)

    # plot user input on an explicit axes rather than pyplot's current figure
    user_input_plot = plt.figure()
    axes = user_input_plot.add_subplot(111)
    axes.plot(np.arange(0, len(mic_f0)), mic_f0, label='User Input')
    plt.close(user_input_plot)

    return generated_audio, user_input_plot, generated_pitch_plot
# Interface: input audio, button and input-pitch plot in one column,
# generated audio and generated-pitch plot in the other.
with gr.Blocks() as demo:
    with gr.Row():
        with gr.Column():
            audio = gr.Audio(label="Input")
            sbmt = gr.Button()
            user_input = gr.Plot(label="User Input")
        with gr.Column():
            generated_audio = gr.Audio(label="Generated Audio")
            generated_pitch = gr.Plot(label="Generated Pitch")

    # the handler's three return values map onto these outputs in order
    sbmt.click(
        fn=set_guide_and_generate,
        inputs=[audio],
        outputs=[generated_audio, user_input, generated_pitch],
    )
def main(argv):
    """Entry point handed to ``app.run``; launches the demo with ``share=True``."""
    del argv  # unused
    demo.launch(share=True)


if __name__ == '__main__':
    app.run(main)