# Hugging Face Space (status: Running)
# Imports | |
import gradio as gr | |
import matplotlib.pyplot as plt | |
import torch | |
import torchaudio | |
from torch import nn | |
import pytorch_lightning as pl | |
from ema_pytorch import EMA | |
import yaml | |
from audio_diffusion_pytorch import DiffusionModel, UNetV0, VDiffusion, VSampler | |
# Load configs | |
def load_configs(config_path):
    """Read the YAML config file and return its two nested ``model`` sections.

    Args:
        config_path: Path to the YAML configuration file.

    Returns:
        A ``(pl_configs, model_configs)`` tuple: ``pl_configs`` is the
        top-level ``model`` mapping of the file and ``model_configs`` is the
        ``model`` mapping nested inside it.

    Raises:
        KeyError: If either ``model`` key is missing from the file.
    """
    with open(config_path, 'r') as stream:
        parsed = yaml.safe_load(stream)
    lightning_section = parsed['model']
    network_section = lightning_section['model']
    return lightning_section, network_section
# plot mel spectrogram | |
def plot_mel_spectrogram(sample, sr):
    """Plot the mel spectrogram of an audio clip, in decibels.

    Args:
        sample: Audio tensor. It is averaged over dim 0 before analysis
            (assumed to be ``[channels, length]`` -- confirm against callers).
        sr: Sample rate passed to the mel-spectrogram transform.

    Returns:
        The Matplotlib figure holding the spectrogram image and its colorbar.
        The figure is left as pyplot's current figure.
    """
    transform = torchaudio.transforms.MelSpectrogram(
        sample_rate=sr,
        n_fft=1024,
        hop_length=512,
        n_mels=80,
        center=True,
        norm="slaney",
    )
    mono = torch.mean(sample, dim=0)  # downmix before computing the spectrogram
    mel_power = transform(mono)

    # MelSpectrogram yields a power spectrogram by default, so the dB
    # conversion is 10 * log10(x) relative to a reference of 1.0
    # (db_multiplier = log10(1.0) = 0). The dynamic range is limited to 80 dB
    # through top_db; passing 80.0 positionally as the fourth argument would
    # set db_multiplier instead and merely offset the values.
    mel_db = torchaudio.functional.amplitude_to_DB(
        mel_power,
        multiplier=10.0,
        amin=1e-10,
        db_multiplier=0.0,
        top_db=80.0,
    )

    # Plot the Mel spectrogram
    fig = plt.figure(figsize=(7, 4))
    plt.imshow(mel_db, aspect='auto', origin='lower')
    plt.colorbar(format='%+2.0f dB')
    plt.xlabel('Frame')
    plt.ylabel('Mel Bin')
    plt.title('Mel Spectrogram')
    plt.tight_layout()
    return fig
# Define PyTorch Lightning model | |
class Model(pl.LightningModule):
    """Lightning module that holds a network together with an EMA copy of it.

    The ``lr*`` arguments are stored as plain attributes (presumably optimizer
    hyperparameters used during training -- no optimizer is configured in the
    visible code). ``ema_beta`` and ``ema_power`` are forwarded to ``EMA``.
    """

    def __init__(
        self,
        lr: float,
        lr_beta1: float,
        lr_beta2: float,
        lr_eps: float,
        lr_weight_decay: float,
        ema_beta: float,
        ema_power: float,
        model: nn.Module,
    ):
        super().__init__()

        # Stored hyperparameters (plain Python values, not submodules).
        self.lr = lr
        self.lr_beta1, self.lr_beta2 = lr_beta1, lr_beta2
        self.lr_eps = lr_eps
        self.lr_weight_decay = lr_weight_decay

        # Submodules: the wrapped network first, then its EMA wrapper, so the
        # registration order (and therefore state-dict key order) is
        # ``model`` followed by ``model_ema``.
        self.model = model
        self.model_ema = EMA(self.model, beta=ema_beta, power=ema_power)
# Instantiate model (must match model that was trained) | |
def load_model(model_configs, pl_configs) -> nn.Module:
    """Build the diffusion network and wrap it in the Lightning ``Model``.

    The architecture must match the one the checkpoints were trained with.

    Args:
        model_configs: Mapping with the U-Net settings (keys listed below).
        pl_configs: Mapping with the ``Model`` constructor settings
            (``lr``, ``lr_beta1``, ``lr_beta2``, ``lr_eps``,
            ``lr_weight_decay``, ``ema_beta``, ``ema_power``).

    Returns:
        A ``Model`` instance wrapping the freshly built ``DiffusionModel``.

    Raises:
        KeyError: If a required key is missing from either mapping.
    """
    unet_keys = (
        'in_channels',         # number of input/output (audio) channels
        'channels',            # channels at each layer
        'factors',             # downsampling and upsampling factors at each layer
        'items',               # number of repeating items at each layer
        'attentions',          # attention enabled/disabled at each layer
        'attention_heads',     # number of attention heads per attention item
        'attention_features',  # number of attention features per attention item
    )
    unet_kwargs = {key: model_configs[key] for key in unet_keys}

    # Diffusion model
    diffusion_net = DiffusionModel(
        net_t=UNetV0,            # the model type used for diffusion (U-Net V0)
        **unet_kwargs,
        diffusion_t=VDiffusion,  # the diffusion method used
        sampler_t=VSampler,      # the diffusion sampler used
    )

    # pl model
    wrapper_keys = (
        'lr',
        'lr_beta1',
        'lr_beta2',
        'lr_eps',
        'lr_weight_decay',
        'ema_beta',
        'ema_power',
    )
    wrapper_kwargs = {key: pl_configs[key] for key in wrapper_keys}
    return Model(**wrapper_kwargs, model=diffusion_net)
# Assign to GPU | |
def assign_to_gpu(model):
    """Move ``model`` to CUDA when it is available and print the device in use.

    Args:
        model: Object exposing ``.to(device)`` and a ``.device`` attribute.

    Returns:
        The result of ``model.to('cuda')`` when CUDA is available, otherwise
        ``model`` itself.
    """
    cuda_ready = torch.cuda.is_available()
    placed = model.to('cuda') if cuda_ready else model
    print(f"Device: {placed.device}")
    return placed
# Load model checkpoint | |
def load_checkpoint(model, ckpt_path) -> None:
    """Load the ``state_dict`` entry of a checkpoint file into ``model``.

    Args:
        model: Module whose parameters are overwritten in place.
        ckpt_path: Path to a checkpoint readable by ``torch.load`` that
            contains a ``'state_dict'`` key.

    Raises:
        KeyError: If the checkpoint has no ``'state_dict'`` entry.
    """
    # Tensors are mapped to CPU on load; load_state_dict copies them into the
    # model's existing parameters.
    payload = torch.load(ckpt_path, map_location='cpu')
    weights = payload['state_dict']
    # The key-matching report returned by load_state_dict is discarded.
    model.load_state_dict(weights)
# Generate Samples | |
def generate_samples(model_name, num_samples, num_steps, duration=32768):
    """Generate audio clips with the selected checkpoint and plot them.

    Loads the checkpoint registered under ``model_name`` into the shared
    module-level ``model``, samples ``num_samples`` stereo clips from its EMA
    network and joins them end to end.

    Args:
        model_name: Key into the module-level ``models`` mapping.
        num_samples: Number of clips to generate (at least 1).
        num_steps: Number of diffusion sampling steps (suggested 10-100).
        duration: Length of one clip in audio samples. It is halved when more
            than one clip is requested.

    Returns:
        ``((sr, audio), fig)`` where ``audio`` is a NumPy array of shape
        ``[length, 2]`` and ``fig`` is the mel-spectrogram figure.

    Raises:
        ValueError: If ``num_samples`` is smaller than 1.
        KeyError: If ``model_name`` is not a known model.
    """
    # UI sliders may deliver numbers as floats; range() and the noise shape
    # below require real integers.
    num_samples = int(num_samples)
    num_steps = int(num_steps)
    if num_samples < 1:
        raise ValueError("num_samples must be at least 1")

    # load_checkpoint
    load_checkpoint(model, models[model_name])

    # Shorter clips when several are concatenated.
    length = int(duration)
    if num_samples > 1:
        length //= 2

    clips = []
    with torch.no_grad():
        for _ in range(num_samples):
            # [batch_size, in_channels, length]
            noise = torch.randn((1, 2, length), device=model.device)
            clip = model.model_ema.ema_model.sample(noise, num_steps=num_steps)
            clips.append(clip.squeeze(0).cpu())
    torch.cuda.empty_cache()

    # Join once at the end instead of re-copying the growing tensor on every
    # iteration.
    all_samples = torch.cat(clips, dim=1)

    fig = plot_mel_spectrogram(all_samples, sr)
    # Title the spectrogram axes of *this* figure (the first axes; the
    # colorbar is added after it) rather than pyplot's global current axes.
    fig.axes[0].set_title(f"{model_name} Mel Spectrogram")
    # Unregister the figure from pyplot so a long-running server does not
    # accumulate open figures; the Figure object itself stays renderable.
    plt.close(fig)

    return (sr, all_samples.detach().numpy().T), fig  # (sample rate, audio), plot
# load model & configs | |
sr = 44100  # sampling rate

# Read the configuration, build the network it describes and move it to the
# GPU when one is available.
config_path = "saved_models/config.yaml"  # config path
pl_configs, model_configs = load_configs(config_path)
model = assign_to_gpu(load_model(model_configs, pl_configs))

# Display name -> checkpoint path for every selectable model.
models = {
    "Kicks": "saved_models/kicks/kicks_v7.ckpt",
    "Snares": "saved_models/snares/snares_v0.ckpt",
    "Hi-hats": "saved_models/hihats/hihats_v2.ckpt",
    "Percussion": "saved_models/percussion/percussion_v0.ckpt",
}
# Gradio UI: a model selector and two sliders feed generate_samples, which
# returns the audio and its spectrogram plot.
demo = gr.Interface(
    fn=generate_samples,
    inputs=[
        gr.Dropdown(
            choices=list(models),
            value=list(models)[3],  # fourth entry is preselected
            label="Model",
        ),
        gr.Slider(
            minimum=1,
            maximum=25,
            value=3,
            step=1,
            label="Number of Samples to Generate",
        ),
        gr.Slider(
            minimum=1,
            maximum=100,
            value=15,
            step=1,
            label="Number of Diffusion Steps",
        ),
    ],
    outputs=[
        gr.Audio(label="Generated Audio Sample"),
        gr.Plot(label="Generated Audio Spectrogram"),
    ],
)

# Start the web app only when the file is executed as a script.
if __name__ == "__main__":
    demo.launch()