crlandsc's picture
playback optimization
ae43eae
raw
history blame
5.61 kB
# Imports
import gradio as gr
import matplotlib.pyplot as plt
import torch
import torchaudio
from torch import nn
import pytorch_lightning as pl
from ema_pytorch import EMA
import yaml
from audio_diffusion_pytorch import DiffusionModel, UNetV0, VDiffusion, VSampler
# Load configs
def load_configs(config_path):
    """Read the training YAML config and return the two sections the app uses.

    Args:
        config_path: path to the YAML config file.

    Returns:
        (pl_configs, model_configs): ``config['model']``, the LightningModule
        hyper-parameters, and ``config['model']['model']``, the
        DiffusionModel / U-Net arguments nested inside it.

    Raises:
        KeyError: if the expected ``model`` sections are missing.
    """
    # Explicit encoding: the default is locale-dependent (e.g. cp1252 on
    # Windows), which can mis-decode non-ASCII characters in the YAML.
    with open(config_path, 'r', encoding='utf-8') as file:
        config = yaml.safe_load(file)
    pl_configs = config['model']
    model_configs = pl_configs['model']
    return pl_configs, model_configs
# plot mel spectrogram
def plot_mel_spectrogram(sample, sr, title='Mel Spectrogram', top_db=80.0):
    """Plot the mel spectrogram, in dB, of a multichannel audio tensor.

    The tensor is averaged over its first dimension (the channel axis,
    presumably ``(channels, samples)``; TODO confirm) before the transform.

    Args:
        sample: audio tensor to analyse.
        sr: sampling rate of ``sample`` in Hz.
        title: axes title. Defaults to the previous fixed title.
        top_db: displayed dynamic range, in dB below the loudest bin.

    Returns:
        The matplotlib Figure. It is created with ``plt.figure`` and left open
        as the current figure, so callers may keep editing it through
        ``plt.*``. Callers are responsible for closing it.
    """
    transform = torchaudio.transforms.MelSpectrogram(
        sample_rate=sr,
        n_fft=1024,
        hop_length=512,
        n_mels=80,
        center=True,
        norm="slaney",
    ).to(sample.device)  # filterbank/window buffers must match the input's device
    mono = torch.mean(sample, dim=0)  # downmix channels
    spectrogram = transform(mono)  # (n_mels, frames) power values (power=2.0 default)
    # Power -> dB is 10*log10(x). The reference is 1.0, so
    # db_multiplier = log10(1.0) = 0. Clamp to `top_db` below the peak.
    # (The previous positional call passed multiplier=1 and db_multiplier=80,
    # which plotted log10(power) - 80 instead of decibels.)
    spectrogram_db = torchaudio.functional.amplitude_to_DB(
        spectrogram,
        multiplier=10.0,
        amin=1e-10,
        db_multiplier=0.0,
        top_db=top_db,
    )
    # Plot the Mel spectrogram
    fig = plt.figure(figsize=(7, 4))
    plt.imshow(spectrogram_db.detach().cpu().numpy(), aspect='auto', origin='lower')
    plt.colorbar(format='%+2.0f dB')
    plt.xlabel('Frame')
    plt.ylabel('Mel Bin')
    plt.title(title)
    plt.tight_layout()
    return fig
# Define PyTorch Lightning model
class Model(pl.LightningModule):
    """LightningModule container for the diffusion network and its EMA copy.

    It mirrors the module used in training so that a saved checkpoint's
    ``state_dict`` can be loaded into it. The optimizer hyper-parameters are
    only stored; no training or optimizer hooks are defined here. ``model_ema``
    wraps ``model`` with ``ema_pytorch.EMA``, and this app samples from
    ``model_ema.ema_model``.
    """

    def __init__(
        self,
        lr: float,
        lr_beta1: float,
        lr_beta2: float,
        lr_eps: float,
        lr_weight_decay: float,
        ema_beta: float,
        ema_power: float,
        model: nn.Module,
    ):
        super().__init__()
        # Optimizer hyper-parameters, assigned in the same order as before.
        (
            self.lr,
            self.lr_beta1,
            self.lr_beta2,
            self.lr_eps,
            self.lr_weight_decay,
        ) = (lr, lr_beta1, lr_beta2, lr_eps, lr_weight_decay)
        # Register the network first, then its exponential-moving-average
        # shadow. The attribute names determine the state_dict keys.
        self.model = model
        self.model_ema = EMA(self.model, beta=ema_beta, power=ema_power)
# Instantiate model (must match model that was trained)
def load_model(model_configs, pl_configs) -> nn.Module:
    """Build the network with the same architecture that was used for training.

    No checkpoint is loaded here; use ``load_checkpoint`` for that.

    Args:
        model_configs: U-Net / diffusion architecture arguments
            (``config['model']['model']``).
        pl_configs: LightningModule hyper-parameters (``config['model']``).

    Returns:
        A ``Model`` wrapping a ``DiffusionModel`` with a U-Net V0 backbone,
        the VDiffusion method and the VSampler sampler.
    """
    # U-Net arguments taken verbatim from the config:
    # - number of input/output (audio) channels
    # - channels at each layer
    # - downsampling/upsampling factors at each layer
    # - number of repeating items at each layer
    # - per-layer attention on/off switches
    # - attention heads and attention features per attention item
    unet_keys = (
        'in_channels',
        'channels',
        'factors',
        'items',
        'attentions',
        'attention_heads',
        'attention_features',
    )
    diffusion = DiffusionModel(
        net_t=UNetV0,
        diffusion_t=VDiffusion,
        sampler_t=VSampler,
        **{key: model_configs[key] for key in unet_keys},
    )

    # Hyper-parameters forwarded to the LightningModule wrapper.
    lightning_keys = (
        'lr',
        'lr_beta1',
        'lr_beta2',
        'lr_eps',
        'lr_weight_decay',
        'ema_beta',
        'ema_power',
    )
    return Model(model=diffusion, **{key: pl_configs[key] for key in lightning_keys})
# Assign to GPU
def assign_to_gpu(model):
    """Move ``model`` to CUDA if a GPU is available; otherwise leave it unchanged.

    Prints the resulting device and returns the (possibly moved) object.
    """
    on_gpu = torch.cuda.is_available()
    placed = model.to('cuda') if on_gpu else model
    print(f"Device: {placed.device}")
    return placed
# Load model checkpoint
def load_checkpoint(model, ckpt_path) -> None:
    """Copy the weights from a Lightning checkpoint into ``model``, in place.

    The file is deserialised onto the CPU, and its ``'state_dict'`` entry is
    passed to ``model.load_state_dict`` with default strict matching.
    Mismatched keys therefore raise an error. The load result is not returned.
    ``torch.load`` unpickles the file, so only use trusted checkpoints.
    """
    payload = torch.load(ckpt_path, map_location='cpu')
    weights = payload['state_dict']
    model.load_state_dict(weights)
# Generate Samples
def generate_samples(model_name, num_samples, num_steps, duration=32768):
    """Generate audio clips with the selected model and plot their spectrogram.

    Uses these module-level names defined at the bottom of this file:
    ``model`` (the shared network), ``models`` (dropdown name to checkpoint
    path) and ``sr`` (sampling rate).

    Args:
        model_name: key of ``models`` selecting the checkpoint to sample from.
        num_samples: number of clips to generate. They are concatenated
            end-to-end along the time axis into a single track.
        num_steps: diffusion sampling steps (suggested range 10-100).
        duration: length of each clip in audio samples. It is halved when more
            than one clip is requested.

    Returns:
        ``((sr, audio), fig)``. ``audio`` is a numpy array of shape
        (n_frames, 2), time-major, which is the layout handed to ``gr.Audio``.
        ``fig`` is the mel-spectrogram figure.
    """
    # Slider values may arrive as floats; range() and int sizes need ints.
    num_samples = int(num_samples)
    num_steps = int(num_steps)

    # Loading a checkpoint re-reads the whole file from disk, so skip it when
    # the shared model already holds the requested weights. The marker is
    # written only after load_checkpoint has returned without raising.
    ckpt_path = models[model_name]
    if getattr(model, "_loaded_ckpt_path", None) != ckpt_path:
        load_checkpoint(model, ckpt_path)
        model._loaded_ckpt_path = ckpt_path

    # Shorter clips when several are requested, so the combined track stays short.
    clip_length = int(duration / 2) if num_samples > 1 else int(duration)

    clips = []
    with torch.no_grad():
        for _ in range(num_samples):
            # [batch_size, in_channels, length]
            noise = torch.randn((1, 2, clip_length), device=model.device)
            clip = model.model_ema.ema_model.sample(noise, num_steps=num_steps)
            clips.append(clip.squeeze(0).cpu())
    # Concatenate once at the end. Concatenating inside the loop re-copied the
    # growing track on every iteration, which is quadratic in num_samples.
    all_samples = torch.cat(clips, dim=1) if clips else torch.zeros(2, 0)
    torch.cuda.empty_cache()

    fig = plot_mel_spectrogram(all_samples, sr)
    # Title the image axes (the first axes; the colorbar is added after it)
    # directly, rather than through pyplot's process-wide "current axes".
    fig.axes[0].set_title(f"{model_name} Mel Spectrogram")
    # Remove the figure from pyplot's registry so each request doesn't leak an
    # open figure. The Figure object itself stays usable (e.g. for savefig).
    plt.close(fig)
    return (sr, all_samples.numpy().T), fig  # (sample rate, audio), plot
# load model & configs
sr = 44100  # sampling rate used for the generated audio and its spectrogram
config_path = "saved_models/config.yaml"  # config path
pl_configs, model_configs = load_configs(config_path)
# Build the architecture once; generate_samples loads checkpoints into it.
model = assign_to_gpu(load_model(model_configs, pl_configs))

# Dropdown label -> checkpoint file for each selectable model.
models = {
    "Kicks": "saved_models/kicks/kicks_v7.ckpt",
    "Snares": "saved_models/snares/snares_v0.ckpt",
    "Hi-hats": "saved_models/hihats/hihats_v2.ckpt",
    "Percussion": "saved_models/percussion/percussion_v0.ckpt",
}
model_choices = list(models)

demo = gr.Interface(
    generate_samples,
    inputs=[
        # Defaults to the fourth entry ("Percussion").
        gr.Dropdown(choices=model_choices, value=model_choices[3], label="Model"),
        gr.Slider(1, 25, step=1, label="Number of Samples to Generate", value=3),
        gr.Slider(1, 100, step=1, label="Number of Diffusion Steps", value=15),
    ],
    outputs=[
        gr.Audio(label="Generated Audio Sample"),
        gr.Plot(label="Generated Audio Spectrogram"),
    ],
)

if __name__ == "__main__":
    demo.launch()