import torch
import pytorch_lightning as pl
from torch import nn
from tqdm import tqdm
import numpy as np
import einops
import wandb
# W&B logging
from pytorch_lightning.loggers import WandbLogger
from stable_audio_tools import get_pretrained_model
class SinActivation(nn.Module):
    def forward(self, x):
        return torch.sin(x)
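# SinActivation applies an element-wise sine, in the spirit of SIREN-style networks;
# it serves as the nonlinearity of the Fourier-feature MLP below.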
class FourierFeatures(nn.Module):
    def __init__(self, in_features, out_features, n_layers):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.n_layers = n_layers
        layers = [nn.Linear(in_features, out_features), SinActivation()]
        for _ in range(n_layers - 1):
            layers += [nn.Linear(out_features, out_features), SinActivation()]
        self.layers = nn.Sequential(*layers)

    def forward(self, x):
        return self.layers(x)
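# FourierFeatures is used below to embed the scalar flow time t into the
# transformer's model dimension.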
class FlowMatchingModule(pl.LightningModule):
    def __init__(self, main_model=None, text_conditioner=None, max_tokens=128, n_channels=None, t_input=None):
        super().__init__()
        self.save_hyperparameters(ignore=["main_model", "text_conditioner"])
        self.model = main_model.transformer
        self.input_layer = main_model.transformer.project_in
        self.output_layer = main_model.transformer.project_out
        self.text_conditioner = text_conditioner
        self.d_model = self.input_layer.weight.shape[0]
        self.d_input = self.input_layer.weight.shape[1]
        # Fourier features embed the flow time t
        self.schedule_embedding = FourierFeatures(1, self.d_model, 2)
        # learned positional encoding over pitch channels
        self.pitch_embedding = nn.Parameter(torch.randn(n_channels, self.d_model))
        self.channels = n_channels
        # one projection per transformer layer, used in forward() to mix
        # information across pitch channels
        self.mean_proj = nn.ModuleList(
            [nn.Linear(self.d_model, self.d_model) for _ in self.model.layers]
        )
    def get_example_inputs(self):
        text = "A piano playing a C major chord"
        conditioning, conditioning_mask = self.text_conditioner(text, device=self.device)
        # repeat conditioning across pitch channels
        conditioning = einops.repeat(conditioning, "b t d -> b t c d", c=self.channels)
        conditioning_mask = einops.repeat(conditioning_mask, "b t -> b t c", c=self.channels)
        t = torch.rand(1, device=self.device)
        z = torch.randn(1, self.hparams.t_input, self.hparams.n_channels, self.d_input, device=self.device)
        return z, conditioning, conditioning_mask, t
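    # forward() operates on tensors of shape (batch, time, channel, d_input). Each
    # pitch channel is folded into the batch and attended as its own sequence; after
    # every transformer layer, a projected mean over channels is added back so the
    # channels can exchange information.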
    def forward(self, x, conditioning, conditioning_mask, t):
        batch, t_input, n_channels, d_input = x.shape
        x = self.input_layer(x)
        # add time-schedule and pitch embeddings
        tz = self.schedule_embedding(t[:, None, None, None])
        pitch_z = self.pitch_embedding[None, None, :n_channels, :]
        x = x + tz + pitch_z
        rot = self.model.rotary_pos_emb.forward_from_seq_len(x.shape[1])
        # fold the channel axis into the batch for cross-attention conditioning
        conditioning = einops.rearrange(conditioning, "b t c d -> (b c) t d", c=self.channels)
        conditioning_mask = einops.rearrange(conditioning_mask, "b t c -> (b c) t", c=self.channels)
        for layer_idx, layer in enumerate(self.model.layers):
            x = einops.rearrange(x, "b t c d -> (b c) t d")
            x = layer(x, rotary_pos_emb=rot, context=conditioning, context_mask=conditioning_mask)
            x = einops.rearrange(x, "(b c) t d -> b t c d", c=self.channels)
            # mix information across pitch channels via a projected channel mean
            x_ch_mean = self.mean_proj[layer_idx](x.mean(dim=2))
            # optional nonlinearity / layer norm on the channel mean (disabled):
            # x_ch_mean = torch.relu(x_ch_mean)
            # x_ch_mean = torch.layer_norm(x_ch_mean, x_ch_mean.shape[1:])
            x = x + x_ch_mean[:, :, None, :]
        x = self.output_layer(x)
        return x
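    # Training objective (rectified-flow-style flow matching): for data z1 and noise
    # z0, interpolate zt = t * z1 + (1 - t) * z0 at a random t and regress the model
    # output toward the constant target velocity z1 - z0 with an MSE loss.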
    def step(self, batch, batch_idx):
        x = batch["z"]
        text = batch["text"]
        conditioning, conditioning_mask = self.text_conditioner(text, device=self.device)
        # repeat conditioning across pitch channels
        conditioning = einops.repeat(conditioning, "b t d -> b t c d", c=self.channels)
        conditioning_mask = einops.repeat(conditioning_mask, "b t -> b t c", c=self.channels)
        x = einops.rearrange(x, "b c d t -> b t c d")
        # interpolate between noise z0 and data z1 at a random time t
        z0 = torch.randn(x.shape, device=x.device)
        z1 = x
        t = torch.rand(x.shape[0], device=x.device)
        zt = t[:, None, None, None] * z1 + (1 - t[:, None, None, None]) * z0
        # regress the predicted velocity onto the target z1 - z0
        vt = self(zt, conditioning, conditioning_mask, t)
        loss = (vt - (z1 - z0)).pow(2).mean()
        return loss
    def sample(self, batch_size, text, steps=10, same_latent=False):
        # ensure everything runs on the module's device and dtype
        device = next(self.parameters()).device
        dtype = self.input_layer.weight.dtype
        conditioning, conditioning_mask = self.text_conditioner(text, device=device)
        conditioning = einops.repeat(conditioning, "b t d -> b t c d", c=self.channels)
        conditioning_mask = einops.repeat(conditioning_mask, "b t -> b t c", c=self.channels)
        conditioning = conditioning.to(device=device, dtype=dtype)
        conditioning_mask = conditioning_mask.to(device=device)
        self.eval()
        with torch.no_grad():
            # initial noise on the correct device and dtype
            z0 = torch.randn(
                batch_size,
                self.hparams.t_input,
                self.hparams.n_channels,
                self.d_input,
                device=device,
                dtype=dtype,
            )
            if same_latent:
                # share a single latent across the whole batch
                z0 = z0[0].repeat(batch_size, 1, 1, 1)
            zt = z0
            for step in tqdm(range(steps)):
                t = torch.tensor([step / steps], device=device, dtype=dtype)
                zt = zt + (1 / steps) * self.forward(zt, conditioning, conditioning_mask, t)
        return zt
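    # sample() integrates dz/dt = v(z, t) with `steps` uniform Euler updates from
    # t=0 to t=1. A minimal usage sketch (the prompt below is a hypothetical example):
    #   latents = module.sample(batch_size=4, text=["a piano playing a c major chord"] * 4, steps=50)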
    def training_step(self, batch, batch_idx):
        loss = self.step(batch, batch_idx)
        self.log("trn_loss", loss, on_step=True, on_epoch=True, prog_bar=True, logger=True)
        return loss

    def validation_step(self, batch, batch_idx):
        loss = self.step(batch, batch_idx)
        self.log("val_loss", loss, on_step=True, on_epoch=True, prog_bar=True, logger=True)
        return loss

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=1e-5)
class EncodedAudioDataset(torch.utils.data.Dataset):
    def __init__(self, paths, pitch_range):
        records = []
        print("Loading data")
        for path in tqdm(paths):
            records += torch.load(path)
        self.records = records
        self.pitch_range = pitch_range
        # keep only records that have an encoded latent
        self.records = [r for r in self.records if "z" in r]
        print(f"Loaded {len(self.records)} records")

    def compose_prompt(self, record):
        title = record["name"] if "name" in record else record["title"]
        # shuffle the tags and keep a random-sized subset
        tags = list(np.random.permutation(record["tags"]))
        tags = tags[: np.random.randint(0, len(tags) + 1)]
        # use the title, type group, or type as the prompt head when available
        if "type_group" in record and "type" in record:
            head = np.random.choice([record["type_group"], record["type"], title])
        else:
            head = title
        # with 75% probability prepend the head to the tags
        elements = tags
        if np.random.rand() < 0.75:
            elements = [head] + elements
        # shuffle the final elements and join into a lowercase prompt
        elements = list(np.random.permutation(elements))
        return " ".join(elements).lower()

    def __len__(self):
        return len(self.records)

    def __getitem__(self, idx):
        return {
            "z": self.records[idx]["z"][self.pitch_range[0]:self.pitch_range[1]],
            "text": self.compose_prompt(self.records[idx]),
        }

    def check_for_nans(self):
        for r in self.records:
            if np.isnan(r["z"]).any():
                raise ValueError("NaN values in z")

    def get_z_shape(self):
        # return the unique latent shapes present in the dataset
        shapes = [r["z"].shape for r in self.records]
        return list(set(shapes))
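# Each dataset path holds a torch-saved list of record dicts, each with a latent
# under "z" and prompt metadata ("name"/"title", "tags", and optionally
# "type_group"/"type"); compose_prompt() turns that metadata into a training prompt.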
if __name__ == "__main__":
    # reproducibility
    SEED = 0
    torch.manual_seed(SEED)
    BATCH_SIZE = 1
    LATENT_T = 86

    # initialize the W&B logger; don't upload model checkpoints
    wandb.init()
    logger = WandbLogger(project="synth_flow")
    wandb.config.log_model = False

    DATASET = "dataset_a"
    if DATASET == "dataset_a":
        PITCH_RANGE = [2, 12]
        trn_ds = EncodedAudioDataset([f"artefacts/synth_data_{i}.pt" for i in range(9)], PITCH_RANGE)
        trn_ds.check_for_nans()
        trn_dl = torch.utils.data.DataLoader(trn_ds, batch_size=BATCH_SIZE, shuffle=True)
        val_ds = EncodedAudioDataset(["artefacts/synth_data_9.pt"], PITCH_RANGE)
        val_ds.check_for_nans()
        val_dl = torch.utils.data.DataLoader(val_ds, batch_size=BATCH_SIZE, shuffle=True)
    elif DATASET == "dataset_b":
        PITCH_RANGE = [0, 10]
        trn_ds = EncodedAudioDataset([f"artefacts/synth_data_2_joined_{i}.pt" for i in range(3)], PITCH_RANGE)
        trn_ds.check_for_nans()
        trn_dl = torch.utils.data.DataLoader(trn_ds, batch_size=BATCH_SIZE, shuffle=True)
        val_ds = EncodedAudioDataset(["artefacts/synth_data_2_joined_3.pt"], PITCH_RANGE)
        val_ds.check_for_nans()
        val_dl = torch.utils.data.DataLoader(val_ds, batch_size=BATCH_SIZE, shuffle=True)

    # load the pretrained Stable Audio Open model and reuse its transformer and text conditioner
    src_model = get_pretrained_model("stabilityai/stable-audio-open-1.0")[0].to("cpu")
    transformer_model = src_model.model.model.train()
    text_conditioner = src_model.conditioner.conditioners.prompt

    lr_callback = pl.callbacks.LearningRateMonitor(logging_interval="step")
    model = FlowMatchingModule(
        main_model=transformer_model,
        text_conditioner=text_conditioner,
        n_channels=PITCH_RANGE[1] - PITCH_RANGE[0],
        t_input=LATENT_T,
    )
    trainer = pl.Trainer(
        devices=[3],
        logger=logger,
        gradient_clip_val=1.0,
        callbacks=[lr_callback],
        max_epochs=1000,
        precision="16-mixed",
    )
    trainer.fit(model, trn_dl, val_dl, ckpt_path="synth_flow/9gzpz0i6/epoch=85-step=774000.ckpt")
    # save the fine-tuned checkpoint
    trainer.save_checkpoint("artefacts/model_finetuned_2.ckpt")