# soundfont-generator / train_lfm.py
import torch
import pytorch_lightning as pl
from torch import nn
from tqdm import tqdm
import numpy as np
import einops
import wandb
# wandb logging
from pytorch_lightning.loggers import WandbLogger
from stable_audio_tools import get_pretrained_model
from transformers import T5Tokenizer, T5EncoderModel
class SinActivation(nn.Module):
def forward(self, x):
return torch.sin(x)
class FourierFeatures(nn.Module):
def __init__(self, in_features, out_features, n_layers):
super().__init__()
self.in_features = in_features
self.out_features = out_features
self.n_layers = n_layers
layers = []
layers += [nn.Linear(in_features, out_features)]
# add sin activation
layers += [SinActivation()]
for i in range(n_layers-1):
layers += [nn.Linear(out_features, out_features)]
layers += [SinActivation()]
self.layers = nn.Sequential(*layers)
def forward(self, x):
return self.layers(x)
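# Note: FourierFeatures is a small MLP with sine activations that lifts the
# scalar flow time t in [0, 1] to a d_model-sized embedding (SIREN-style),
# rather than the fixed sin/cos table of a classical positional encoding.
# Shape sketch (the sizes are illustrative, not taken from this repo):
#   emb = FourierFeatures(in_features=1, out_features=1536, n_layers=2)
#   t = torch.rand(4, 1)       # one scalar time per example
#   emb(t).shape               # torch.Size([4, 1536])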
class FlowMatchingModule(pl.LightningModule):
def __init__(self, main_model=None, text_conditioner=None, max_tokens=128, n_channels=None, t_input=None):
super().__init__()
self.save_hyperparameters(ignore=['main_model', "text_conditioner"])
self.model = main_model.transformer
self.input_layer = main_model.transformer.project_in
self.output_layer = main_model.transformer.project_out
self.text_conditioner = text_conditioner
        # nn.Linear weights are [out_features, in_features]
        self.d_model = self.input_layer.weight.shape[0]
        self.d_input = self.input_layer.weight.shape[1]
        # Fourier-feature embedding of the flow time t
        self.schedule_embedding = FourierFeatures(1, self.d_model, 2)
        # learned embedding for each pitch channel
        self.pitch_embedding = nn.Parameter(torch.randn(n_channels, self.d_model))
        self.channels = n_channels
        # one linear projection per transformer layer, applied to the mean over
        # pitch channels so channels can exchange information between layers
        mean_proj = []
        for _ in self.model.layers:
            mean_proj += [nn.Linear(self.d_model, self.d_model)]
        self.mean_proj = nn.ModuleList(mean_proj)
def get_example_inputs(self):
text = "A piano playing a C major chord"
conditioning, conditioning_mask = self.text_conditioner(text, device = self.device)
# repeat conditioning
conditioning = einops.repeat(conditioning, 'b t d-> b t c d', c=self.channels)
conditioning_mask = einops.repeat(conditioning_mask, 'b t -> b t c', c=self.channels)
t = torch.rand(1, device=self.device)
        z = torch.randn(1, self.hparams.t_input, self.hparams.n_channels, self.d_input, device=self.device)
return z, conditioning, conditioning_mask, t
def forward(self, x, conditioning, conditioning_mask, t):
batch, t_input, n_channels, d_input = x.shape
        # project latents into the transformer width
        x = self.input_layer(x)
        tz = self.schedule_embedding(t[:, None, None, None])
        pitch_z = self.pitch_embedding[None, None, :n_channels, :]
        # add time and pitch embeddings (broadcast over batch and sequence)
        x = x + tz + pitch_z
rot = self.model.rotary_pos_emb.forward_from_seq_len(x.shape[1])
conditioning = einops.rearrange(conditioning, 'b t c d -> (b c) t d', c=self.channels)
conditioning_mask = einops.rearrange(conditioning_mask, 'b t c -> (b c) t', c=self.channels)
for layer_idx, layer in enumerate(self.model.layers):
x = einops.rearrange(x, 'b t c d -> (b c) t d')
x = layer(x, rotary_pos_emb=rot, context = conditioning, context_mask = conditioning_mask)
x = einops.rearrange(x, '(b c) t d -> b t c d', c=self.channels)
x_ch_mean = x.mean(dim=2)
x_ch_mean = self.mean_proj[layer_idx](x_ch_mean)
# non linearity
# x_ch_mean = torch.relu(x_ch_mean)
# # layer norm
# x_ch_mean = torch.layer_norm(x_ch_mean, x_ch_mean.shape[1:])
x += x_ch_mean[:, :, None, :]
x = self.output_layer(x)
return x
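    # How forward mixes pitch channels (descriptive note): each of the n_channels
    # latents is processed by the shared transformer layer as its own sequence
    # (channels folded into the batch dimension), with the repeated text embedding
    # as cross-attention context. After every layer, the mean over channels is
    # projected by mean_proj[layer_idx] and added back to all channels, which is
    # the only place information flows between pitch channels.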
def step(self, batch, batch_idx):
x = batch["z"]
text = batch["text"]
conditioning, conditioning_mask = self.text_conditioner(text, device = self.device)
# repeat conditioning
conditioning = einops.repeat(conditioning, 'b t d-> b t c d', c=self.channels)
conditioning_mask = einops.repeat(conditioning_mask, 'b t -> b t c', c=self.channels)
        # [batch, pitch, dim, time] -> [batch, time, pitch, dim]
        x = einops.rearrange(x, 'b c d t -> b t c d')
        z0 = torch.randn(x.shape, device=x.device)
        z1 = x
        t = torch.rand(x.shape[0], device=x.device)
        # linear interpolation between noise z0 and data z1 at time t
        zt = t[:, None, None, None] * z1 + (1 - t[:, None, None, None]) * z0
        # predicted velocity regressed onto the constant target (z1 - z0)
        vt = self(zt, conditioning, conditioning_mask, t)
        loss = (vt - (z1 - z0)).pow(2).mean()
return loss
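    # Training objective (restating the code above): with data latent z1, noise
    # z0 ~ N(0, I) and t ~ U(0, 1),
    #   z_t  = t * z1 + (1 - t) * z0
    #   loss = || v_theta(z_t, t, text) - (z1 - z0) ||^2
    # i.e. conditional flow matching / rectified flow, where the network learns
    # the constant velocity field pointing from noise to data.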
@torch.inference_mode()
def sample(self, batch_size, text, steps=10, same_latent=False):
# Ensure model is on the correct device
device = next(self.parameters()).device
dtype = self.input_layer.weight.dtype
# Move conditioning to the correct device and dtype
conditioning, conditioning_mask = self.text_conditioner(text, device=device)
conditioning = einops.repeat(conditioning, "b t d-> b t c d", c=self.channels)
conditioning_mask = einops.repeat(
conditioning_mask, "b t -> b t c", c=self.channels
)
conditioning = conditioning.to(device=device, dtype=dtype)
conditioning_mask = conditioning_mask.to(device=device)
self.eval()
with torch.no_grad():
# Create initial noise on the correct device and dtype
z0 = torch.randn(
batch_size,
self.hparams.t_input,
self.hparams.n_channels,
self.d_input,
device=device,
dtype=dtype,
)
if same_latent:
z0 = z0[0].repeat(batch_size, 1, 1, 1)
zt = z0
for step in tqdm(range(steps)):
t = torch.tensor([step / steps], device=device, dtype=dtype)
zt = zt + (1 / steps) * self.forward(
zt, conditioning, conditioning_mask, t
)
return zt
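    # Sampling integrates dz/dt = v_theta(z, t, text) from t=0 (pure noise) to
    # t=1 with fixed-step Euler updates: z <- z + (1/steps) * v_theta(z, t).
    # Hypothetical usage (decoding latents back to audio would go through the
    # pretrained autoencoder, which is not part of this module):
    #   latents = model.sample(batch_size=1, text="a warm electric piano", steps=50)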
def training_step(self, batch, batch_idx):
loss = self.step(batch, batch_idx)
self.log('trn_loss', loss, on_step=True, on_epoch=True, prog_bar=True, logger=True)
return loss
def validation_step(self, batch, batch_idx):
loss = self.step(batch, batch_idx)
self.log('val_loss', loss, on_step=True, on_epoch=True, prog_bar=True, logger=True)
return loss
def configure_optimizers(self):
return torch.optim.Adam(self.parameters(), lr=1e-5)
class EncodedAudioDataset(torch.utils.data.Dataset):
def __init__(self, paths, pitch_range):
records = []
print("Loading data")
for path in tqdm(paths):
            records += torch.load(path)
self.records = records
self.pitch_range = pitch_range
# keep only records with z
self.records = [r for r in self.records if "z" in r]
print(f"Loaded {len(self.records)} records")
def compose_prompt(self,record):
title = record["name"] if "name" in record else record["title"]
tags = record["tags"]
        # shuffle the tags
        tags = np.random.choice(tags, len(tags), replace=False)
        # keep a random prefix of the shuffled tags (possibly none)
        tags = list(tags[:np.random.randint(0, len(tags) + 1)])
# take either the title or group or type or nothing
if "type_group" in record and "type" in record:
type_group = record["type_group"]
type = record["type"]
head = np.random.choice([title, type_group, type])
else:
head = np.random.choice([title])
        # prepend the head (title / type) to the tags with 75% probability
elements = tags
if np.random.rand() < 0.75:
elements = [head] + elements
# shuffle elements
elements = np.random.choice(elements, len(elements), replace=False)
prompt = " ".join(elements)
# make everything lowercase
prompt = prompt.lower()
return prompt
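    # compose_prompt builds a randomized lowercase prompt from a record's metadata:
    # the tags are shuffled, a random prefix of them is kept, and a "head" drawn
    # from the title / type_group / type is mixed in with 75% probability.
    # Illustrative outputs for title "Grand Piano" and tags ["bright", "acoustic"]
    # (the result is random): "grand piano acoustic", "bright", "acoustic bright grand piano".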
def __len__(self):
return len(self.records)
def __getitem__(self, idx):
return {
"z": self.records[idx]["z"][self.pitch_range[0]:self.pitch_range[1]],
"text": self.compose_prompt(self.records[idx])
}
def check_for_nans(self):
for r in self.records:
            # check that z has no NaN values
            if torch.isnan(r["z"]).any():
                raise ValueError("NaN values in z")
def get_z_shape(self):
shapes = [r["z"].shape for r in self.records]
# return unique shapes
return list(set(shapes))
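# Expected record layout (inferred from the code above, not a documented schema):
# each artefacts/*.pt file holds a list of dicts with at least
#   "z": a latent tensor of shape [pitch, latent_dim, latent_t], sliced by pitch_range
#   text metadata: "name" or "title", "tags", and optionally "type" / "type_group".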
if __name__ == "__main__":
# set seed
SEED = 0
torch.manual_seed(SEED)
BATCH_SIZE = 1
LATENT_T = 86
    # initialize the wandb logger; don't upload model checkpoints
    wandb.init()
    logger = WandbLogger(project="synth_flow", log_model=False)
DATASET = "dataset_a"
if DATASET == "dataset_a":
PITCH_RANGE = [2,12]
trn_ds = EncodedAudioDataset([f"artefacts/synth_data_{i}.pt" for i in range(9)], PITCH_RANGE)
trn_ds.check_for_nans()
trn_dl = torch.utils.data.DataLoader(trn_ds, batch_size=BATCH_SIZE, shuffle=True)
val_ds = EncodedAudioDataset([f"artefacts/synth_data_9.pt"], PITCH_RANGE)
val_ds.check_for_nans()
val_dl = torch.utils.data.DataLoader(val_ds, batch_size=BATCH_SIZE, shuffle=True)
elif DATASET == "dataset_b":
PITCH_RANGE = [0,10]
trn_ds = EncodedAudioDataset([f"artefacts/synth_data_2_joined_{i}.pt" for i in range(3)], PITCH_RANGE)
trn_ds.check_for_nans()
trn_dl = torch.utils.data.DataLoader(trn_ds, batch_size=BATCH_SIZE, shuffle=True)
val_ds = EncodedAudioDataset([f"artefacts/synth_data_2_joined_3.pt"], PITCH_RANGE)
val_ds.check_for_nans()
val_dl = torch.utils.data.DataLoader(val_ds, batch_size=BATCH_SIZE, shuffle=True)
    # load the pretrained Stable Audio Open model on CPU
    src_model = get_pretrained_model("stabilityai/stable-audio-open-1.0")[0].to("cpu")
    # pull the diffusion transformer and the text (prompt) conditioner out of the pretrained wrapper
    transformer_model = src_model.model.model
    transformer_model = transformer_model.train()
    text_conditioner = src_model.conditioner.conditioners.prompt
t5_version = "google-t5/t5-base"
lr_callback = pl.callbacks.LearningRateMonitor(logging_interval='step')
model = FlowMatchingModule(
main_model=transformer_model,
text_conditioner=text_conditioner,
n_channels=PITCH_RANGE[1] - PITCH_RANGE[0],
t_input=LATENT_T,
)
trainer = pl.Trainer(devices = [3], logger=logger, gradient_clip_val=1.0, callbacks=[lr_callback], max_epochs=1000, precision="16-mixed")
trainer.fit(model, trn_dl, val_dl, ckpt_path="synth_flow/9gzpz0i6/epoch=85-step=774000.ckpt")
# save checkpoint
trainer.save_checkpoint("artefacts/model_finetuned_2.ckpt")