# ITO-Master / inference.py
# jhtonyKoo's picture
# modify app
# 6d6c0d5
# raw
# history blame
# 16.3 kB
import torch
import soundfile as sf
import numpy as np
import argparse
import os
import yaml
import julius
import sys
currentdir = os.path.dirname(os.path.realpath(__file__))
sys.path.append(os.path.dirname(currentdir))
from networks import Dasp_Mastering_Style_Transfer, Effects_Encoder
from modules.loss import AudioFeatureLoss, Loss
def _convert_audio_channels(wav: torch.Tensor, channels: int) -> torch.Tensor:
    """Return `wav` with its channel axis (second to last) resized to `channels`.

    Args:
        wav: Audio laid out as (..., channels, time).
        channels: Requested number of output channels.

    Returns:
        `wav` itself when it already has `channels` channels, otherwise a
        down-mixed, replicated, or truncated version of it.

    Raises:
        ValueError: If more channels are requested than the (non-mono) input has.
    """
    *leading, src_channels, length = wav.shape
    if src_channels == channels:
        return wav
    if channels == 1:
        # Down-mix every source channel into a single one.
        return wav.mean(dim=-2, keepdim=True)
    if src_channels == 1:
        # Replicate the mono signal over all requested channels.
        return wav.expand(*leading, channels, length)
    if src_channels > channels:
        # Keep only the first `channels` channels.
        return wav[..., :channels, :]
    raise ValueError(
        "The audio has fewer channels than requested but is not mono."
    )


def convert_audio(wav: torch.Tensor, from_rate: float,
                  to_rate: float, to_channels: int) -> torch.Tensor:
    """Convert audio to a new sample rate and number of audio channels.

    Args:
        wav: Audio laid out as (..., channels, time).
        from_rate: Sample rate of `wav`.
        to_rate: Target sample rate.
        to_channels: Target number of channels.

    Returns:
        The resampled audio with `to_channels` channels.

    Raises:
        ValueError: If `to_channels` exceeds the channel count of a non-mono input.
    """
    src_rate, dst_rate = int(from_rate), int(to_rate)
    # Equal rates need no resampling, so the resampler is only invoked on a real change.
    if src_rate != dst_rate:
        wav = julius.resample_frac(wav, src_rate, dst_rate)
    # The original called an undefined `convert_audio_channels`; the private
    # helper above supplies that missing step.
    return _convert_audio_channels(wav, to_channels)
class MasteringStyleTransfer:
    """Reference-based mastering style transfer with optional inference-time optimization (ITO).

    Holds an effects encoder (reference audio -> embedding) and a mastering
    converter (input audio + embedding -> processed audio), both loaded from
    the checkpoint paths found on `args`.
    """

    def __init__(self, args):
        """Store `args`, pick the compute device and load both networks.

        Args:
            args: Namespace providing `cfg_enc`, `cfg_converter`, `encoder_path`,
                `model_path` and `sample_rate` (the attributes read by this class).
        """
        self.args = args
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        # Load models
        self.effects_encoder = self.load_effects_encoder()
        self.mastering_converter = self.load_mastering_converter()

    def load_effects_encoder(self):
        """Build the effects encoder, load its weights and return it in eval mode."""
        effects_encoder = Effects_Encoder(self.args.cfg_enc)
        reload_weights(effects_encoder, self.args.encoder_path, self.device)
        effects_encoder.to(self.device)
        effects_encoder.eval()
        return effects_encoder

    def load_mastering_converter(self):
        """Build the mastering converter, load its weights and return it in eval mode."""
        mastering_converter = Dasp_Mastering_Style_Transfer(
            num_features=2048,
            sample_rate=self.args.sample_rate,
            tgt_fx_names=['eq', 'distortion', 'multiband_comp', 'gain', 'imager', 'limiter'],
            model_type='tcn',
            config=self.args.cfg_converter,
            batch_size=1,
        )
        reload_weights(mastering_converter, self.args.model_path, self.device)
        mastering_converter.to(self.device)
        mastering_converter.eval()
        return mastering_converter

    def get_reference_embedding(self, reference_tensor):
        """Run the effects encoder on `reference_tensor` without tracking gradients."""
        with torch.no_grad():
            reference_feature = self.effects_encoder(reference_tensor)
        return reference_feature

    def mastering_style_transfer(self, input_tensor, reference_feature):
        """Apply the mastering converter once, without tracking gradients.

        Returns:
            Tuple `(output_audio, predicted_params)` where `predicted_params` is
            whatever the converter reports via `get_last_predicted_params()`.
        """
        with torch.no_grad():
            output_audio = self.mastering_converter(input_tensor, reference_feature)
            predicted_params = self.mastering_converter.get_last_predicted_params()
        return output_audio, predicted_params

    def inference_time_optimization(self, input_tensor, reference_tensor, ito_config, initial_reference_feature):
        """Optimize the reference embedding so the converter output matches `reference_tensor`.

        Args:
            input_tensor: Audio fed to the mastering converter.
            reference_tensor: Audio the output is compared against by `AudioFeatureLoss`.
            ito_config: Mapping with the keys 'optimizer' (name of a class in
                `torch.optim`), 'learning_rate', 'num_steps', 'af_weights' and
                'sample_rate'.
            initial_reference_feature: Starting embedding; it is copied, never modified.

        Returns:
            Tuple `(best_output, best_params, best_embedding, best_step, log)`:
            the detached output, parameters and embedding of the lowest-loss step,
            that step's 1-based index, and a newline-joined per-step log. When
            'num_steps' is 0 the first three entries are None.
        """
        # Optimize a private copy: wrapping the caller's tensor directly would let
        # the optimizer's in-place updates overwrite the caller's embedding.
        fit_embedding = torch.nn.Parameter(initial_reference_feature.detach().clone())
        optimizer_cls = getattr(torch.optim, ito_config['optimizer'])
        optimizer = optimizer_cls([fit_embedding], lr=ito_config['learning_rate'])

        af_loss = AudioFeatureLoss(
            weights=ito_config['af_weights'],
            sample_rate=ito_config['sample_rate'],
            stem_separation=False,
            use_clap=False
        )

        min_loss = float('inf')
        min_loss_step = 0
        min_loss_output = None
        min_loss_params = None
        min_loss_embedding = None

        loss_history = []
        divergence_counter = 0
        initial_params = None
        ito_log = []

        for step in range(ito_config['num_steps']):
            optimizer.zero_grad()

            output_audio = self.mastering_converter(input_tensor, fit_embedding)
            current_params = self.mastering_converter.get_last_predicted_params()

            losses = af_loss(output_audio, reference_tensor)
            total_loss = sum(losses.values())
            loss_value = total_loss.item()
            loss_history.append(loss_value)

            # Remember the best step seen so far.
            if loss_value < min_loss:
                min_loss = loss_value
                min_loss_step = step
                min_loss_output = output_audio.detach()
                min_loss_params = current_params
                min_loss_embedding = fit_embedding.detach().clone()

            # Check for divergence: the loss is worse than it was 10 steps ago.
            if len(loss_history) > 10 and loss_value > loss_history[-11]:
                divergence_counter += 1
            else:
                divergence_counter = 0

            # Log top 10 parameter differences relative to the first step.
            if step == 0:
                initial_params = current_params
            top_10_diff = self.get_top_10_diff_string(initial_params, current_params)
            ito_log.append(f"Step {step + 1}, Loss: {loss_value:.4f}\n{top_10_diff}\n")

            # Ten consecutive "worse than 10 steps ago" readings end the run early.
            if divergence_counter >= 10:
                print(f"Optimization stopped early due to divergence at step {step}")
                break

            total_loss.backward()
            optimizer.step()

        return min_loss_output, min_loss_params, min_loss_embedding, min_loss_step + 1, "\n".join(ito_log)

    def preprocess_audio(self, audio, target_sample_rate=44100):
        """Turn a `(sample_rate, ndarray)` pair into a batched stereo float tensor.

        Args:
            audio: Tuple `(sample_rate, data)`; `data` is a NumPy array of dtype
                int16, int32 or any float type, shaped (frames,), (frames, channels)
                or (channels, frames).
            target_sample_rate: Rate the audio is resampled to when it differs.

        Returns:
            Tensor of shape (1, 2, frames) on `self.device`.

        Raises:
            ValueError: For unsupported dtypes or arrays with more than 2 dimensions.
        """
        sample_rate, data = audio

        # Normalize audio to -1 to 1 range
        if data.dtype in (np.int16, np.int32):
            # Scale by the magnitude of the most negative value (32768 for int16).
            full_scale = -float(np.iinfo(data.dtype).min)
            data = data.astype(np.float32) / full_scale
        elif np.issubdtype(data.dtype, np.floating):
            data = np.clip(data.astype(np.float32), -1.0, 1.0)
        else:
            raise ValueError(f"Unsupported audio data type: {data.dtype}")

        # Ensure stereo channels, laid out as (2, frames)
        if data.ndim == 1:
            data = np.stack([data, data])
        elif data.ndim == 2:
            if data.shape[0] == 2:
                pass  # Already in correct shape
            elif data.shape[1] == 2:
                data = data.T
            elif data.shape[0] == 1:
                # Channel-first mono: duplicate the single row instead of
                # collapsing it to its first sample.
                data = np.concatenate([data, data], axis=0)
            else:
                data = np.stack([data[:, 0], data[:, 0]])  # Duplicate mono channel
        else:
            raise ValueError(f"Unsupported audio shape: {data.shape}")

        # Convert to torch tensor (a contiguous copy, since `.T` yields a strided view)
        data_tensor = torch.from_numpy(np.ascontiguousarray(data, dtype=np.float32)).unsqueeze(0)

        # Resample if necessary
        if sample_rate != target_sample_rate:
            data_tensor = julius.resample_frac(data_tensor, int(sample_rate), int(target_sample_rate))

        return data_tensor.to(self.device)

    def process_audio(self, input_audio, reference_audio, ito_reference_audio, params, perform_ito, log_ito=False):
        """Run style transfer and, optionally, inference-time optimization.

        Args:
            input_audio: `(sample_rate, ndarray)` pair to be mastered.
            reference_audio: `(sample_rate, ndarray)` pair providing the target style.
            ito_reference_audio: `(sample_rate, ndarray)` pair used as the ITO
                target; only read when `perform_ito` is true.
            params: ITO configuration mapping, passed on as `ito_config` to
                `inference_time_optimization`; only read when `perform_ito` is true.
            perform_ito: Whether to run ITO after the initial transfer.
            log_ito: Whether to return the ITO log text.

        Returns:
            Tuple `(output_audio, predicted_params, ito_output_audio,
            ito_predicted_params, ito_log, sample_rate)`. The three ITO entries
            are None when `perform_ito` is false; `ito_log` is also None when
            `log_ito` is false.
        """
        input_tensor = self.preprocess_audio(input_audio, self.args.sample_rate)
        reference_tensor = self.preprocess_audio(reference_audio, self.args.sample_rate)

        reference_feature = self.get_reference_embedding(reference_tensor)
        output_audio, predicted_params = self.mastering_style_transfer(input_tensor, reference_feature)

        ito_output_audio = None
        ito_predicted_params = None
        ito_log = None
        if perform_ito:
            ito_reference_tensor = self.preprocess_audio(ito_reference_audio, self.args.sample_rate)
            # Delegate to the optimizer defined on this class; the previous body
            # called helpers (`ito_step`, `get_top_10_diff`, `converter`) that do not exist.
            ito_output_audio, ito_predicted_params, _, _, full_log = self.inference_time_optimization(
                input_tensor, ito_reference_tensor, params, reference_feature
            )
            if log_ito:
                ito_log = full_log

        return output_audio, predicted_params, ito_output_audio, ito_predicted_params, ito_log, self.args.sample_rate

    @staticmethod
    def _as_float(value):
        """Convert a single-element tensor or plain number to a Python float."""
        return value.item() if isinstance(value, torch.Tensor) else float(value)

    def _fx_param_diffs(self, fx_name, initial_fx, ito_fx):
        """List `(fx_name, param_name, initial, ito, normalized_diff)` floats for one effect.

        Dict-valued effects are normalized by the converter's `param_ranges`;
        any other value is treated as a single 'width' parameter whose absolute
        difference is used as-is.
        """
        if not isinstance(initial_fx, dict):
            # For 'imager', assume range is 0 to 1
            normalized_diff = abs(ito_fx - initial_fx)
            return [(fx_name, 'width', self._as_float(initial_fx),
                     self._as_float(ito_fx), self._as_float(normalized_diff))]

        rows = []
        for param_name, initial_value in initial_fx.items():
            ito_value = ito_fx[param_name]
            param_range = self.mastering_converter.fx_processors[fx_name].param_ranges[param_name]
            normalized_diff = abs((ito_value - initial_value) / (param_range[1] - param_range[0]))
            rows.append((fx_name, param_name, self._as_float(initial_value),
                         self._as_float(ito_value), self._as_float(normalized_diff)))
        return rows

    def print_param_difference(self, initial_params, ito_params):
        """Print every parameter difference per effect, then the 10 largest normalized ones."""
        all_diffs = []

        print("\nAll parameter differences:")
        for fx_name in initial_params.keys():
            print(f"\n{fx_name.upper()}:")
            fx_rows = self._fx_param_diffs(fx_name, initial_params[fx_name], ito_params[fx_name])
            for _, param_name, initial_value, ito_value, normalized_diff in fx_rows:
                print(f" {param_name}:")
                print(f" Initial: {initial_value:.4f}")
                print(f" ITO: {ito_value:.4f}")
                print(f" Normalized Diff: {normalized_diff:.4f}")
            all_diffs.extend(fx_rows)

        # Sort differences by normalized difference and get top 10
        top_diffs = sorted(all_diffs, key=lambda row: row[4], reverse=True)[:10]

        print("\nTop 10 parameter differences (sorted by normalized difference):")
        for fx_name, param_name, initial_value, ito_value, normalized_diff in top_diffs:
            print(f"{fx_name.upper()} - {param_name}:")
            print(f" Initial: {initial_value:.4f}")
            print(f" ITO: {ito_value:.4f}")
            print(f" Normalized Diff: {normalized_diff:.4f}")
            print()

    def print_predicted_params(self, predicted_params):
        """Print the predicted parameters grouped by effect name."""
        if predicted_params is None:
            print("No predicted parameters available.")
            return

        print("Predicted Parameters:")
        for fx_name, fx_params in predicted_params.items():
            print(f"\n{fx_name.upper()}:")
            if isinstance(fx_params, dict):
                for param_name, param_value in fx_params.items():
                    if isinstance(param_value, torch.Tensor):
                        param_value = param_value.detach().cpu().numpy()
                    print(f" {param_name}: {param_value}")
            elif isinstance(fx_params, torch.Tensor):
                param_value = fx_params.detach().cpu().numpy()
                print(f" {param_value}")
            else:
                print(f" {fx_params}")

    def get_param_output_string(self, params):
        """Format `params` as text, one effect header followed by its values to 4 decimals."""
        if params is None:
            return "No parameters available"

        output = []
        for fx_name, fx_params in params.items():
            output.append(f"{fx_name.upper()}:")
            if isinstance(fx_params, dict):
                for param_name, param_value in fx_params.items():
                    if isinstance(param_value, torch.Tensor):
                        param_value = param_value.item()
                    output.append(f" {param_name}: {param_value:.4f}")
            elif isinstance(fx_params, torch.Tensor):
                output.append(f" {fx_params.item():.4f}")
            else:
                output.append(f" {fx_params:.4f}")

        return "\n".join(output)

    def get_top_10_diff_string(self, initial_params, ito_params):
        """Return a text report of the 10 largest normalized parameter differences.

        Returns "Cannot compare parameters" when either argument is None.
        """
        if initial_params is None or ito_params is None:
            return "Cannot compare parameters"

        all_diffs = []
        for fx_name in initial_params.keys():
            all_diffs.extend(
                self._fx_param_diffs(fx_name, initial_params[fx_name], ito_params[fx_name])
            )

        top_diffs = sorted(all_diffs, key=lambda row: row[4], reverse=True)[:10]

        output = ["Top 10 parameter differences (sorted by normalized difference):"]
        for fx_name, param_name, initial_value, ito_value, normalized_diff in top_diffs:
            output.append(f"{fx_name.upper()} - {param_name}:")
            output.append(f" Initial: {initial_value:.4f}")
            output.append(f" ITO: {ito_value:.4f}")
            output.append(f" Normalized Diff: {normalized_diff:.4f}")
            output.append("")

        return "\n".join(output)
def reload_weights(model, ckpt_path, device):
    """Load the weights stored under the "model" key of a checkpoint into `model`.

    Keys carrying a leading `module.` prefix have that prefix removed; keys
    without it are used unchanged. Loading is non-strict, so keys that do not
    match the model are ignored.

    Args:
        model: Module whose `load_state_dict` receives the weights.
        ckpt_path: Path of the checkpoint file read with `torch.load`.
        device: `map_location` used while loading the checkpoint.
    """
    prefix = "module."
    checkpoint = torch.load(ckpt_path, map_location=device)
    # Strip the prefix only where it is present: slicing every key blindly would
    # mangle un-prefixed names, and strict=False would then load nothing silently.
    new_state_dict = {
        (key[len(prefix):] if key.startswith(prefix) else key): value
        for key, value in checkpoint["model"].items()
    }
    model.load_state_dict(new_state_dict, strict=False)
def main():
    """Command-line entry point: master an input file in the style of a reference.

    Writes "output_mastered.wav" and, when ITO was performed,
    "ito_output_mastered.wav" to the current working directory.
    """
    basis_path = '/data2/tony/Mastering_Style_Transfer/results/dasp_tcn_tuneenc_daspman_loudnessnorm/ckpt/1000/'

    parser = argparse.ArgumentParser(description="Mastering Style Transfer")
    parser.add_argument("--input_path", type=str, required=True, help="Path to input audio file")
    parser.add_argument("--reference_path", type=str, required=True, help="Path to reference audio file")
    parser.add_argument("--ito_reference_path", type=str, required=True, help="Path to ITO reference audio file")
    parser.add_argument("--model_path", type=str, default=f"{basis_path}dasp_tcn_tuneenc_daspman_loudnessnorm_mastering_converter_1000.pt", help="Path to mastering converter model")
    parser.add_argument("--encoder_path", type=str, default=f"{basis_path}dasp_tcn_tuneenc_daspman_loudnessnorm_effects_encoder_1000.pt", help="Path to effects encoder model")
    parser.add_argument("--perform_ito", action="store_true", help="Whether to perform ITO")
    parser.add_argument("--optimizer", type=str, default="RAdam", help="Optimizer for ITO")
    parser.add_argument("--learning_rate", type=float, default=0.001, help="Learning rate for ITO")
    parser.add_argument("--num_steps", type=int, default=100, help="Number of optimization steps for ITO")
    parser.add_argument("--af_weights", nargs='+', type=float, default=[0.1, 0.001, 1.0, 1.0, 0.1], help="Weights for AudioFeatureLoss")
    parser.add_argument("--sample_rate", type=int, default=44100, help="Sample rate for AudioFeatureLoss")
    parser.add_argument("--path_to_config", type=str, default='/home/tony/mastering_transfer/networks/configs.yaml', help="Path to network architecture configuration file")
    args = parser.parse_args()

    # load network configurations
    # NOTE(review): yaml.full_load is more permissive than yaml.safe_load; if the
    # config only holds plain mappings and scalars, safe_load would be the safer
    # choice — confirm before switching.
    with open(args.path_to_config, 'r') as f:
        configs = yaml.full_load(f)
    args.cfg_converter = configs['TCN']['param_mapping']
    args.cfg_enc = configs['Effects_Encoder']['default']

    ito_config = {
        'optimizer': args.optimizer,
        'learning_rate': args.learning_rate,
        'num_steps': args.num_steps,
        'af_weights': args.af_weights,
        'sample_rate': args.sample_rate
    }

    def load_audio(path):
        """Read `path` as float32 and return the (sample_rate, data) pair preprocess_audio expects."""
        data, sample_rate = sf.read(path, dtype='float32')
        return sample_rate, data

    def save_audio(path, audio, sample_rate):
        """Write a tensor to `path` as a (frames, channels) array."""
        # squeeze(0) drops a leading batch axis of size 1 if there is one —
        # assumes the converter output is (1, channels, frames) or (channels, frames).
        sf.write(path, audio.squeeze(0).detach().cpu().numpy().T, sample_rate)

    mastering_style_transfer = MasteringStyleTransfer(args)
    # process_audio returns six values and takes decoded audio, not file paths.
    output_audio, predicted_params, ito_output_audio, ito_predicted_params, ito_log, sr = mastering_style_transfer.process_audio(
        load_audio(args.input_path),
        load_audio(args.reference_path),
        load_audio(args.ito_reference_path),
        ito_config,
        args.perform_ito,
    )

    # Save the output audio
    save_audio("output_mastered.wav", output_audio, sr)
    if ito_output_audio is not None:
        save_audio("ito_output_mastered.wav", ito_output_audio, sr)


if __name__ == "__main__":
    main()