Spaces:
Sleeping
Sleeping
import gradio as gr | |
import torch | |
import soundfile as sf | |
import numpy as np | |
import yaml | |
from inference import MasteringStyleTransfer | |
from utils import download_youtube_audio | |
from config import args | |
import pyloudnorm as pyln | |
import tempfile | |
import os | |
import pandas as pd | |
# One shared model wrapper, constructed once at import time from `config.args`
# and reused by every request handler in this module.
mastering_transfer: MasteringStyleTransfer = MasteringStyleTransfer(args)
def denormalize_audio(audio, dtype=np.int16):
    """
    Denormalize the audio from the range [-1, 1] to the full range of the specified dtype.

    Args:
        audio: Array-like of float samples, nominally in [-1, 1].
        dtype: Target sample type. Anything ``np.dtype`` understands is
            accepted (``np.int16``, ``"int16"``, ``np.dtype("float32")``, ...),
            but only int16 and float32 are supported.

    Returns:
        For int16: samples with NaN mapped to 0, clipped to [-1, 1], scaled by
        32767 and rounded to the nearest integer.
        For float32: the samples cast to float32, unclipped.

    Raises:
        ValueError: if ``dtype`` is not int16 or float32.
    """
    try:
        target = np.dtype(dtype)
    except TypeError as err:
        raise ValueError("Unsupported dtype. Use np.int16 or np.float32.") from err

    if target == np.int16:
        # NaN has no defined int16 value, so treat it as silence; then clamp
        # so out-of-range samples saturate instead of wrapping on the cast.
        audio = np.clip(np.nan_to_num(np.asarray(audio), nan=0.0), -1, 1)
        # Round rather than truncate: truncation toward zero biases every
        # sample toward 0 and doubles the quantization dead zone around 0.
        return np.round(audio * 32767).astype(np.int16)
    if target == np.float32:
        return np.asarray(audio).astype(np.float32)
    raise ValueError("Unsupported dtype. Use np.int16 or np.float32.")
def loudness_normalize(audio, sample_rate, target_loudness=-12.0):
    """Scale audio so its integrated loudness (BS.1770, via pyloudnorm) equals ``target_loudness``.

    Args:
        audio: Array of samples, shaped (samples,) or (samples, channels).
            It is cast to float32. 1-D (mono) input is reshaped to
            (samples, 1), so the result is always 2-D.
        sample_rate: Sample rate in Hz used to build the loudness meter.
        target_loudness: Target integrated loudness in LUFS.

    Returns:
        The gain-adjusted float32 array. If the measured loudness is not
        finite (e.g. an all-silent signal), the audio is returned unscaled.
        Otherwise the gain would be infinite and the output would become NaN.
    """
    audio = np.asarray(audio)
    if audio.dtype != np.float32:
        audio = audio.astype(np.float32)
    # pyloudnorm is given an explicit channel axis for mono input.
    if audio.ndim == 1:
        audio = audio.reshape(-1, 1)

    meter = pyln.Meter(sample_rate)  # create BS.1770 meter
    loudness = meter.integrated_loudness(audio)
    if not np.isfinite(loudness):
        # Silence measures as -inf LUFS; the gain needed to reach the target
        # would be infinite and 0 * inf == NaN, so leave the signal as-is.
        return audio
    return pyln.normalize.loudness(audio, loudness, target_loudness)
def process_youtube_url(url):
    """Download a YouTube URL's audio and wrap it as a Gradio-style ``(sr, data)`` value.

    Always returns a 2-tuple ``(audio, error)`` so callers can write
    ``audio, error = process_youtube_url(url)``:

    * success: ``((sr, data), None)``
    * failure: ``(None, "Error processing YouTube URL: ...")``
    """
    try:
        audio, sr = download_youtube_audio(url)
    except Exception as e:  # UI boundary: report any download failure as a message
        return None, f"Error processing YouTube URL: {str(e)}"
    # The success value must have the same (value, error) shape as the failure
    # branch. Returning a bare (sr, audio) made the caller unpack sr as the
    # audio and the waveform as the "error", whose truth test raises.
    return (sr, audio), None
def process_audio_with_youtube(input_audio, input_youtube_url, reference_audio, reference_youtube_url):
    """Run mastering style transfer, optionally sourcing either track from YouTube.

    A non-blank URL takes precedence over the corresponding uploaded audio.

    Returns:
        The 3-tuple from ``process_audio`` on success, or
        ``(None, None, message)`` when a download fails or a track is missing.
    """
    # Empty textboxes arrive as "" (None is tolerated too). Pasted URLs often
    # carry stray spaces/newlines that should not be sent to the downloader,
    # and a whitespace-only box should mean "no URL".
    input_youtube_url = (input_youtube_url or "").strip()
    reference_youtube_url = (reference_youtube_url or "").strip()

    if input_youtube_url:
        input_audio, error = process_youtube_url(input_youtube_url)
        if error:
            return None, None, error
    if reference_youtube_url:
        reference_audio, error = process_youtube_url(reference_youtube_url)
        if error:
            return None, None, error

    if input_audio is None or reference_audio is None:
        return None, None, "Both input and reference audio are required."
    return process_audio(input_audio, reference_audio)
def _to_int16_output(audio, sample_rate):
    """Convert a model output waveform into an int16 (samples, channels) array for gr.Audio.

    Shared post-processing for the style-transfer and ITO outputs:
    tensor -> numpy, force a 2-D layout, transpose channel-first input,
    loudness-normalize, then scale to int16.

    Args:
        audio: torch.Tensor or numpy array holding the waveform (presumably
            float samples in [-1, 1] -- TODO confirm against MasteringStyleTransfer).
        sample_rate: Sample rate in Hz passed to the loudness meter.

    Returns:
        numpy int16 array shaped (samples, channels).
    """
    if isinstance(audio, torch.Tensor):
        # detach() so tensors that still carry autograd history convert cleanly.
        audio = audio.detach().cpu().numpy()
    audio = np.asarray(audio)
    # Squeeze first, then lift 1-D to (N, 1). Squeezing e.g. (1, 1, N) yields
    # a 1-D array, which would otherwise crash on shape[1] below.
    if audio.ndim > 2:
        audio = audio.squeeze()
    if audio.ndim == 1:
        audio = audio.reshape(-1, 1)
    # Heuristic: more columns than rows means the array is (channels, samples).
    if audio.shape[1] > audio.shape[0]:
        audio = audio.transpose(1, 0)
    audio = loudness_normalize(audio, sample_rate)
    return denormalize_audio(audio, dtype=np.int16)


def process_audio(input_audio, reference_audio):
    """Run mastering style transfer of ``input_audio`` toward ``reference_audio``.

    Args:
        input_audio: Gradio audio value of the track to master.
        reference_audio: Gradio audio value of the style reference.

    Returns:
        ``((sr, int16 output), parameter summary string, (sr, normalized_input))``.
        ``normalized_input`` is passed through from
        ``mastering_transfer.process_audio`` unchanged.

    Raises:
        gr.Error: if either audio input is missing.
    """
    if input_audio is None or reference_audio is None:
        raise gr.Error("Both input and reference audio are required.")
    output_audio, predicted_params, sr, normalized_input = mastering_transfer.process_audio(
        input_audio, reference_audio
    )
    param_output = mastering_transfer.get_param_output_string(predicted_params)
    return (sr, _to_int16_output(output_audio, sr)), param_output, (sr, normalized_input)


def perform_ito(input_audio, reference_audio, ito_reference_audio, num_steps, optimizer, learning_rate, af_weights):
    """Run inference-time optimization (ITO) and report its last step.

    Args:
        input_audio: Normalized input audio produced by Step 1.
        reference_audio: Reference audio used for the initial reference embedding.
        ito_reference_audio: Optional ITO target; falls back to ``reference_audio``.
        num_steps: Number of optimization steps.
        optimizer: Optimizer name (e.g. "Adam", "RAdam", "SGD").
        learning_rate: Optimizer learning rate.
        af_weights: Comma-separated AudioFeatureLoss weights; blank entries
            (e.g. from a trailing comma) are ignored.

    Returns:
        ``((sr, int16 audio), params string, last step number, log text,
        loss DataFrame with "step"/"loss" columns, all_results)``.
        The last step number is the count of results, so a slider showing it
        selects the final result.

    Raises:
        gr.Error: if an input is missing or the optimization returned no steps.
    """
    if input_audio is None or reference_audio is None:
        raise gr.Error("ITO needs the normalized input audio and the reference audio; run Step 1 first.")
    if ito_reference_audio is None:
        ito_reference_audio = reference_audio
    af_weights = [float(w.strip()) for w in af_weights.split(',') if w.strip()]
    # Sliders may deliver whole numbers as floats; step counts must be ints.
    num_steps = int(num_steps)

    ito_config = {
        'optimizer': optimizer,
        'learning_rate': learning_rate,
        'num_steps': num_steps,
        'af_weights': af_weights,
        'sample_rate': args.sample_rate
    }

    input_tensor = mastering_transfer.preprocess_audio(input_audio, args.sample_rate)
    reference_tensor = mastering_transfer.preprocess_audio(reference_audio, args.sample_rate)
    ito_reference_tensor = mastering_transfer.preprocess_audio(ito_reference_audio, args.sample_rate)

    initial_reference_feature = mastering_transfer.get_reference_embedding(reference_tensor)

    # The second return value (min-loss step) is not used by this UI.
    all_results, _ = mastering_transfer.inference_time_optimization(
        input_tensor, ito_reference_tensor, ito_config, initial_reference_feature
    )
    if not all_results:
        raise gr.Error("ITO produced no results.")

    ito_log = "".join(result['log'] for result in all_results)
    loss_values = [{"step": result['step'], "loss": result['loss']} for result in all_results]

    # Report the final step.
    last_result = all_results[-1]
    ito_param_output = mastering_transfer.get_param_output_string(last_result['params'])
    current_output = _to_int16_output(last_result['audio'], args.sample_rate)

    return (args.sample_rate, current_output), ito_param_output, len(all_results), ito_log, pd.DataFrame(loss_values), all_results


def update_ito_output(all_results, selected_step):
    """Show the ITO result for the 1-based ``selected_step``.

    The step is clamped to the available results: the slider's maximum can
    exceed the number of steps actually run. If no ITO results exist yet,
    the outputs are cleared.

    Returns:
        ``((sr, int16 audio) or None, params string, log string)``.
    """
    if not all_results:
        return None, "", ""
    step = len(all_results) if selected_step is None else int(selected_step)
    index = min(max(step, 1), len(all_results)) - 1
    selected_result = all_results[index]

    ito_param_output = mastering_transfer.get_param_output_string(selected_result['params'])
    current_output = _to_int16_output(selected_result['audio'], args.sample_rate)

    return (args.sample_rate, current_output), ito_param_output, selected_result['log']
""" APP display """ | |
with gr.Blocks() as demo:
    gr.Markdown("# ITO-Master: Inference Time Optimization for Mastering Style Transfer")
    gr.Markdown("# Step 1: Mastering Style Transfer")

    with gr.Tab("Upload Audio"):
        with gr.Row():
            input_audio = gr.Audio(label="Input Audio")
            reference_audio = gr.Audio(label="Reference Audio")

        process_button = gr.Button("Process Mastering Style Transfer")

        with gr.Row():
            with gr.Column():
                output_audio = gr.Audio(label="Output Audio", type='numpy')
                normalized_input = gr.Audio(label="Normalized Input Audio", type='numpy')
            param_output = gr.Textbox(label="Predicted Parameters", lines=5)

        process_button.click(
            process_audio,
            inputs=[input_audio, reference_audio],
            outputs=[output_audio, param_output, normalized_input]
        )

    with gr.Tab("YouTube Audio"):
        with gr.Row():
            input_youtube_url = gr.Textbox(label="Input YouTube URL")
            reference_youtube_url = gr.Textbox(label="Reference YouTube URL")

        with gr.Row():
            input_audio_yt = gr.Audio(label="Input Audio (Do not put when using YouTube URL)")
            reference_audio_yt = gr.Audio(label="Reference Audio (Do not put when using YouTube URL)")

        process_button_yt = gr.Button("Process Mastering Style Transfer")

        with gr.Row():
            output_audio_yt = gr.Audio(label="Output Audio", type='numpy')
            param_output_yt = gr.Textbox(label="Predicted Parameters", lines=5)

        error_message_yt = gr.Textbox(label="Error Message", visible=False)

        def process_and_handle_errors(input_audio, input_youtube_url, reference_audio, reference_youtube_url):
            """Route a YouTube-tab request and surface any error in the error textbox.

            Returns (output audio, parameters, error-box update). On error, the
            first two are None and the error box is shown with the message.
            On success, the error box is hidden and cleared.
            """
            result = process_audio_with_youtube(input_audio, input_youtube_url, reference_audio, reference_youtube_url)
            if len(result) == 3 and isinstance(result[2], str):  # Error occurred
                return None, None, gr.update(visible=True, value=result[2])
            return result[0], result[1], gr.update(visible=False, value="")

        process_button_yt.click(
            process_and_handle_errors,
            inputs=[input_audio_yt, input_youtube_url, reference_audio_yt, reference_youtube_url],
            outputs=[output_audio_yt, param_output_yt, error_message_yt]
        )

    gr.Markdown("## Step 2: Inference Time Optimization (ITO)")

    with gr.Row():
        ito_reference_audio = gr.Audio(label="ITO Reference Audio (optional)")
        with gr.Column():
            num_steps = gr.Slider(minimum=1, maximum=100, value=10, step=1, label="Number of Steps")
            optimizer = gr.Dropdown(["Adam", "RAdam", "SGD"], value="RAdam", label="Optimizer")
            learning_rate = gr.Slider(minimum=0.0001, maximum=0.1, value=0.001, step=0.0001, label="Learning Rate")
            af_weights = gr.Textbox(label="AudioFeatureLoss Weights (comma-separated)", value="0.1,0.001,1.0,1.0,0.1")

    ito_button = gr.Button("Perform ITO")

    with gr.Row():
        with gr.Column():
            ito_output_audio = gr.Audio(label="ITO Output Audio")
            ito_step_slider = gr.Slider(minimum=1, maximum=100, step=1, label="ITO Step", interactive=True)
            ito_param_output = gr.Textbox(label="ITO Predicted Parameters", lines=15)
        with gr.Column():
            ito_loss_plot = gr.LinePlot(
                x="step",
                y="loss",
                title="ITO Loss Curve",
                x_title="Step",
                y_title="Loss",
                height=300,
                width=600,
            )
            ito_log = gr.Textbox(label="ITO Log", lines=10)

    # Per-step ITO results, kept server-side so the step slider can browse them.
    all_results = gr.State([])

    # NOTE(review): ITO reads `normalized_input` and `reference_audio` from the
    # "Upload Audio" tab only. Results produced on the "YouTube Audio" tab are
    # not fed into ITO.
    ito_button.click(
        perform_ito,
        inputs=[normalized_input, reference_audio, ito_reference_audio, num_steps, optimizer, learning_rate, af_weights],
        outputs=[ito_output_audio, ito_param_output, ito_step_slider, ito_log, ito_loss_plot, all_results]
    ).then(
        update_ito_output,
        inputs=[all_results, ito_step_slider],
        outputs=[ito_output_audio, ito_param_output, ito_log]
    )

    ito_step_slider.change(
        update_ito_output,
        inputs=[all_results, ito_step_slider],
        outputs=[ito_output_audio, ito_param_output, ito_log]
    )

# Launch only when run as a script, so importing this module (tests, Gradio
# reload mode) does not start a server as a side effect.
if __name__ == "__main__":
    demo.launch()