import torch
import pytorch_lightning as pl
from torch import nn
from tqdm import tqdm
import numpy as np
import einops
import wandb

# import wandb logging
from pytorch_lightning.loggers import WandbLogger
from stable_audio_tools import get_pretrained_model
from transformers import T5Tokenizer, T5EncoderModel


class SinActivation(nn.Module):
    """Element-wise ``sin`` activation, usable as a layer inside ``nn.Sequential``."""

    def forward(self, x):
        return torch.sin(x)


class FourierFeatures(nn.Module):
    """Stack of ``n_layers`` Linear layers, each followed by a sine activation.

    Maps the last dimension from ``in_features`` to ``out_features``. In this
    file it embeds the scalar flow time ``t`` into the transformer width.
    """

    def __init__(self, in_features, out_features, n_layers):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.n_layers = n_layers
        # Creation order (and therefore Sequential indices / state-dict keys)
        # is: Linear(in, out), Sin, then (Linear(out, out), Sin) * (n_layers - 1).
        layers = [nn.Linear(in_features, out_features), SinActivation()]
        for _ in range(n_layers - 1):
            layers += [nn.Linear(out_features, out_features), SinActivation()]
        self.layers = nn.Sequential(*layers)

    def forward(self, x):
        return self.layers(x)


class FlowMatchingModule(pl.LightningModule):
    """Flow-matching model built on a pretrained transformer, one sequence per pitch channel.

    Each pitch channel is run through the pretrained transformer layers as its
    own sequence. After every layer, the per-channel activations are mixed by
    adding a learned projection of their mean across channels.

    Args:
        main_model: object exposing ``.transformer`` with ``project_in``,
            ``project_out``, ``layers`` and ``rotary_pos_emb`` (as used below).
        text_conditioner: callable ``(text, device=...) -> (conditioning, mask)``
            where conditioning is ``(b, t, d)`` and mask is ``(b, t)`` (per the
            einops patterns used in this class).
        max_tokens: stored in hparams; not otherwise used in this class.
        n_channels: number of pitch channels.
        t_input: number of latent time steps (used when drawing noise).
        lr: Adam learning rate.
    """

    def __init__(self, main_model=None, text_conditioner=None, max_tokens=128,
                 n_channels=None, t_input=None, lr=1e-5):
        super().__init__()
        self.save_hyperparameters(ignore=['main_model', "text_conditioner"])
        self.model = main_model.transformer
        self.input_layer = main_model.transformer.project_in
        self.output_layer = main_model.transformer.project_out
        self.text_conditioner = text_conditioner
        # project_in is a Linear: weight is (d_model, d_input).
        self.d_model = self.input_layer.weight.shape[0]
        self.d_input = self.input_layer.weight.shape[1]
        # Sine-MLP embedding of the scalar flow time t.
        self.schedule_embedding = FourierFeatures(1, self.d_model, 2)
        # Learned per-pitch-channel embedding.
        self.pitch_embedding = nn.Parameter(torch.randn(n_channels, self.d_model))
        self.channels = n_channels
        # One cross-channel mixing projection per transformer layer.
        self.mean_proj = nn.ModuleList(
            [nn.Linear(self.d_model, self.d_model) for _ in self.model.layers]
        )

    def get_example_inputs(self):
        """Return ``(z, conditioning, conditioning_mask, t)`` for a single dummy example.

        The shapes are suitable for ``forward``.
        """
        text = "A piano playing a C major chord"
        conditioning, conditioning_mask = self.text_conditioner(text, device=self.device)
        # repeat conditioning
        conditioning = einops.repeat(conditioning, 'b t d-> b t c d', c=self.channels)
        conditioning_mask = einops.repeat(conditioning_mask, 'b t -> b t c', c=self.channels)
        t = torch.rand(1, device=self.device)
        z = torch.randn(1, self.hparams.t_input, self.hparams.n_channels, self.d_input,
                        device=self.device)
        return z, conditioning, conditioning_mask, t

    def forward(self, x, conditioning, conditioning_mask, t):
        """Predict the flow velocity.

        Args:
            x: ``(batch, t_input, n_channels, d_input)`` noisy latents.
            conditioning: ``(batch, tokens, channels, d)`` text embeddings.
            conditioning_mask: ``(batch, tokens, channels)`` mask.
            t: 1-D tensor of flow times, length ``batch`` (or 1, broadcast).

        Returns:
            Tensor with the leading ``(batch, t_input, n_channels)`` dims of ``x``
            and the output width of ``project_out``.
        """
        batch, t_input, n_channels, d_input = x.shape
        x = self.input_layer(x)
        tz = self.schedule_embedding(t[:, None, None, None])
        pitch_z = self.pitch_embedding[None, None, :n_channels, :]
        x = x + tz + pitch_z
        rot = self.model.rotary_pos_emb.forward_from_seq_len(x.shape[1])
        # Every channel attends to the same text: fold channels into the batch.
        conditioning = einops.rearrange(conditioning, 'b t c d -> (b c) t d', c=self.channels)
        conditioning_mask = einops.rearrange(conditioning_mask, 'b t c -> (b c) t', c=self.channels)
        for layer_idx, layer in enumerate(self.model.layers):
            x = einops.rearrange(x, 'b t c d -> (b c) t d')
            x = layer(x, rotary_pos_emb=rot, context=conditioning, context_mask=conditioning_mask)
            x = einops.rearrange(x, '(b c) t d -> b t c d', c=self.channels)
            # Cross-channel mixing: project the channel mean and add it back.
            x_ch_mean = self.mean_proj[layer_idx](x.mean(dim=2))
            # Out-of-place on purpose: x is a view of the layer's autograd output.
            x = x + x_ch_mean[:, :, None, :]
        return self.output_layer(x)

    def step(self, batch, batch_idx):
        """Compute the rectified-flow MSE loss for one batch of ``{"z", "text"}``."""
        x = batch["z"]
        text = batch["text"]
        conditioning, conditioning_mask = self.text_conditioner(text, device=self.device)
        # repeat conditioning
        conditioning = einops.repeat(conditioning, 'b t d-> b t c d', c=self.channels)
        conditioning_mask = einops.repeat(conditioning_mask, 'b t -> b t c', c=self.channels)
        x = einops.rearrange(x, 'b c d t -> b t c d')
        z0 = torch.randn(x.shape, device=x.device)
        z1 = x
        t = torch.rand(x.shape[0], device=x.device)
        # Linear interpolation between noise (t=0) and data (t=1); target velocity z1 - z0.
        zt = t[:, None, None, None] * z1 + (1 - t[:, None, None, None]) * z0
        vt = self(zt, conditioning, conditioning_mask, t)
        loss = (vt - (z1 - z0)).pow(2).mean()
        return loss

    @torch.inference_mode()
    def sample(self, batch_size, text, steps=10, same_latent=False):
        """Generate latents by Euler-integrating the learned flow from t=0 to t=1.

        Args:
            batch_size: number of samples to draw.
            text: prompt(s) passed to the text conditioner. A single prompt
                (conditioning batch of 1) is shared by all ``batch_size`` samples.
            steps: number of Euler steps.
            same_latent: if True, every sample starts from the same noise.

        Returns:
            ``(batch_size, t_input, n_channels, d)`` tensor of generated latents.
            The module's train/eval mode is restored afterwards.
        """
        device = next(self.parameters()).device
        dtype = self.input_layer.weight.dtype

        was_training = self.training
        self.eval()
        try:
            conditioning, conditioning_mask = self.text_conditioner(text, device=device)
            # Broadcast a single prompt over the whole batch.
            if conditioning.shape[0] == 1 and batch_size > 1:
                conditioning = conditioning.expand(batch_size, -1, -1)
                conditioning_mask = conditioning_mask.expand(batch_size, -1)
            conditioning = einops.repeat(conditioning, "b t d-> b t c d", c=self.channels)
            conditioning_mask = einops.repeat(conditioning_mask, "b t -> b t c", c=self.channels)
            conditioning = conditioning.to(device=device, dtype=dtype)
            conditioning_mask = conditioning_mask.to(device=device)

            z0 = torch.randn(
                batch_size,
                self.hparams.t_input,
                self.hparams.n_channels,
                self.d_input,
                device=device,
                dtype=dtype,
            )
            if same_latent:
                z0 = z0[0].repeat(batch_size, 1, 1, 1)
            zt = z0
            for step in tqdm(range(steps)):
                t = torch.tensor([step / steps], device=device, dtype=dtype)
                zt = zt + (1 / steps) * self.forward(zt, conditioning, conditioning_mask, t)
        finally:
            self.train(was_training)
        return zt

    def training_step(self, batch, batch_idx):
        loss = self.step(batch, batch_idx)
        self.log('trn_loss', loss, on_step=True, on_epoch=True, prog_bar=True, logger=True,
                 batch_size=batch["z"].shape[0])
        return loss

    def validation_step(self, batch, batch_idx):
        loss = self.step(batch, batch_idx)
        self.log('val_loss', loss, on_step=True, on_epoch=True, prog_bar=True, logger=True,
                 batch_size=batch["z"].shape[0])
        return loss

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=self.hparams.lr)


class EncodedAudioDataset(torch.utils.data.Dataset):
    """Pre-encoded latents loaded from ``torch.load``-ed lists of record dicts.

    Each item is ``{"z": record["z"][pitch_range[0]:pitch_range[1]], "text": prompt}``.
    The prompt is randomly composed from the record's name/title, type fields and tags.
    NOTE(review): per-record ``z`` is presumably (pitch, latent_dim, time), which
    matches the ``'b c d t'`` rearrange in ``FlowMatchingModule.step`` — confirm.
    """

    def __init__(self, paths, pitch_range):
        records = []
        print("Loading data")
        for path in tqdm(paths):
            # NOTE(review): torch.load unpickles arbitrary objects; load only trusted files.
            records += torch.load(path)
        self.pitch_range = pitch_range
        # keep only records with z
        self.records = [r for r in records if "z" in r]
        print(f"Loaded {len(self.records)} records")

    def compose_prompt(self, record):
        """Build a random lowercase prompt from a shuffled subset of tags.

        With 75% probability a "head" term is added: the title, type group or type.
        """
        title = record["name"] if "name" in record else record["title"]
        tags = record["tags"]
        # shuffle tags, then keep a random-length prefix (possibly empty)
        tags = np.random.choice(tags, len(tags), replace=False)
        tags = list(tags[:np.random.randint(0, len(tags) + 1)])
        # take either the title or group or type
        if "type_group" in record and "type" in record:
            head = np.random.choice([title, record["type_group"], record["type"]])
        else:
            head = np.random.choice([title])
        elements = tags
        # with 75% chance add head
        if np.random.rand() < 0.75:
            elements = [head] + elements
        elements = np.random.choice(elements, len(elements), replace=False)
        return " ".join(elements).lower()

    def __len__(self):
        return len(self.records)

    def __getitem__(self, idx):
        record = self.records[idx]
        return {
            "z": record["z"][self.pitch_range[0]:self.pitch_range[1]],
            "text": self.compose_prompt(record),
        }

    def check_for_nans(self):
        """Raise ``ValueError`` if any record's ``z`` contains NaN."""
        for r in self.records:
            if np.isnan(r["z"]).any():
                raise ValueError("Nan values in z")

    def get_z_shape(self):
        """Return the list of distinct ``z`` shapes across records."""
        return list({r["z"].shape for r in self.records})


def _make_dataloaders(dataset, batch_size):
    """Return ``(train_loader, val_loader, pitch_range)`` for a named dataset."""
    if dataset == "dataset_a":
        pitch_range = [2, 12]
        trn_paths = [f"artefacts/synth_data_{i}.pt" for i in range(9)]
        val_paths = ["artefacts/synth_data_9.pt"]
    elif dataset == "dataset_b":
        pitch_range = [0, 10]
        trn_paths = [f"artefacts/synth_data_2_joined_{i}.pt" for i in range(3)]
        val_paths = ["artefacts/synth_data_2_joined_3.pt"]
    else:
        raise ValueError(f"Unknown dataset: {dataset!r}")

    trn_ds = EncodedAudioDataset(trn_paths, pitch_range)
    trn_ds.check_for_nans()
    val_ds = EncodedAudioDataset(val_paths, pitch_range)
    val_ds.check_for_nans()
    trn_dl = torch.utils.data.DataLoader(trn_ds, batch_size=batch_size, shuffle=True)
    # Validation order does not need shuffling (Lightning warns when it is).
    val_dl = torch.utils.data.DataLoader(val_ds, batch_size=batch_size, shuffle=False)
    return trn_dl, val_dl, pitch_range


if __name__ == "__main__":
    SEED = 0
    torch.manual_seed(SEED)

    BATCH_SIZE = 1
    LATENT_T = 86

    # Let WandbLogger create the run: a prior bare wandb.init() would be reused
    # by the logger and `project` would be ignored. log_model=False stops
    # checkpoint upload (setting wandb.config.log_model only writes a config value).
    logger = WandbLogger(project="synth_flow", log_model=False)

    DATASET = "dataset_a"
    trn_dl, val_dl, PITCH_RANGE = _make_dataloaders(DATASET, BATCH_SIZE)

    src_model = get_pretrained_model("stabilityai/stable-audio-open-1.0")[0].to("cpu")
    transformer_model = src_model.model.model.train()
    text_conditioner = src_model.conditioner.conditioners.prompt

    lr_callback = pl.callbacks.LearningRateMonitor(logging_interval='step')

    model = FlowMatchingModule(
        main_model=transformer_model,
        text_conditioner=text_conditioner,
        n_channels=PITCH_RANGE[1] - PITCH_RANGE[0],
        t_input=LATENT_T,
    )

    trainer = pl.Trainer(
        devices=[3],
        logger=logger,
        gradient_clip_val=1.0,
        callbacks=[lr_callback],
        max_epochs=1000,
        precision="16-mixed",
    )

    trainer.fit(model, trn_dl, val_dl, ckpt_path="synth_flow/9gzpz0i6/epoch=85-step=774000.ckpt")

    # save checkpoint
    trainer.save_checkpoint("artefacts/model_finetuned_2.ckpt")