from typing import Union, Tuple

import torch
import torch.nn.functional as F
from torch import nn, Tensor
from torch_sparse import SparseTensor, masked_select_nnz
from torch_geometric.typing import OptTensor, Adj
from torch_geometric.nn.inits import reset
from torch_geometric.nn.norm import BatchNorm
from torch_geometric.nn.glob import GlobalAttention
from torch_geometric.data import Batch
from torch_geometric.nn.conv import RGCNConv

import constants
from data import graph_from_tensor


@torch.jit._overload
def masked_edge_index(edge_index, edge_mask):
    # type: (Tensor, Tensor) -> Tensor
    pass


@torch.jit._overload
def masked_edge_index(edge_index, edge_mask):
    # type: (SparseTensor, Tensor) -> SparseTensor
    pass


def masked_edge_index(edge_index, edge_mask):
    if isinstance(edge_index, Tensor):
        return edge_index[:, edge_mask]
    else:
        return masked_select_nnz(edge_index, edge_mask, layout='coo')


def masked_edge_attrs(edge_attrs, edge_mask):
    return edge_attrs[edge_mask, :]


class GCL(RGCNConv):

    def __init__(self, in_channels, out_channels, num_relations, nn,
                 dropout=0.1, **kwargs):
        super().__init__(in_channels=in_channels, out_channels=out_channels,
                         num_relations=num_relations, **kwargs)
        self.nn = nn
        self.dropout = dropout
        self.reset_edge_nn()

    def reset_edge_nn(self):
        reset(self.nn)

    def forward(self, x: Union[OptTensor, Tuple[OptTensor, Tensor]],
                edge_index: Adj, edge_type: OptTensor = None,
                edge_attr: OptTensor = None):
        # Convert input features to a pair of node features or node indices.
        x_l: OptTensor = None
        if isinstance(x, tuple):
            x_l = x[0]
        else:
            x_l = x
        if x_l is None:
            x_l = torch.arange(self.in_channels_l, device=self.weight.device)

        x_r: Tensor = x_l
        if isinstance(x, tuple):
            x_r = x[1]

        size = (x_l.size(0), x_r.size(0))

        if isinstance(edge_index, SparseTensor):
            edge_type = edge_index.storage.value()
        assert edge_type is not None

        # propagate_type: (x: Tensor, edge_attr: OptTensor)
        out = torch.zeros(x_r.size(0), self.out_channels, device=x_r.device)

        weight = self.weight

        # Basis-decomposition
        if self.num_bases is not None:
            weight = (self.comp @ weight.view(self.num_bases, -1)).view(
                self.num_relations, self.in_channels_l, self.out_channels)

        # Block-diagonal-decomposition
        if self.num_blocks is not None:
            if x_l.dtype == torch.long:
                raise ValueError('Block-diagonal decomposition not supported '
                                 'for non-continuous input features.')
            for i in range(self.num_relations):
                tmp = masked_edge_index(edge_index, edge_type == i)
                # Forward the masked edge attributes so that the overridden
                # message() below receives them.
                attr = masked_edge_attrs(edge_attr, edge_type == i)
                h = self.propagate(tmp, x=x_l, size=size, edge_attr=attr)
                h = h.view(-1, weight.size(1), weight.size(2))
                h = torch.einsum('abc,bcd->abd', h, weight[i])
                out += h.contiguous().view(-1, self.out_channels)
        else:
            # No regularization/Basis-decomposition
            for i in range(self.num_relations):
                tmp = masked_edge_index(edge_index, edge_type == i)
                attr = masked_edge_attrs(edge_attr, edge_type == i)
                if x_l.dtype == torch.long:
                    out += self.propagate(tmp, x=weight[i, x_l], size=size,
                                          edge_attr=attr)
                else:
                    h = self.propagate(tmp, x=x_l, size=size, edge_attr=attr)
                    out = out + (h @ weight[i])

        root = self.root
        if root is not None:
            out += root[x_r] if x_r.dtype == torch.long else x_r @ root

        if self.bias is not None:
            out += self.bias

        return out

    def message(self, x_j: Tensor, edge_attr: Tensor) -> Tensor:
        # Use the edge nn to compute a weight tensor from edge attributes
        # (= one-hot timestep distances between nodes)
        weights = self.nn(edge_attr)
        weights = weights[..., :self.in_channels_l]
        weights = weights.view(-1, self.in_channels_l)
        out = x_j * weights
        out = F.relu(out)
        out = F.dropout(out, p=self.dropout, training=self.training)
        return out
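
# Illustrative sketch (not used by the model): how GCL splits edges by
# relation before message passing. Given an edge_index and a per-edge
# relation id, masked_edge_index() keeps only the columns of one relation,
# exactly as done inside GCL.forward() for each i in range(num_relations).
# The toy tensors below are made up for the example.
def _demo_masked_edge_index():
    edge_index = torch.tensor([[0, 1, 2, 3],
                               [1, 2, 3, 0]])  # 4 edges over 4 nodes
    edge_type = torch.tensor([0, 1, 0, 2])     # relation id of each edge
    # Edges of relation 0 sit in columns 0 and 2 of edge_index
    rel0 = masked_edge_index(edge_index, edge_type == 0)
    assert rel0.tolist() == [[0, 2], [1, 3]]
    return rel0
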
class MLP(nn.Module):

    def __init__(self, input_dim=256, hidden_dim=256, output_dim=256,
                 num_layers=2, activation=True, dropout=0.1):
        super().__init__()
        self.layers = nn.ModuleList()
        if num_layers == 1:
            self.layers.append(nn.Linear(input_dim, output_dim))
        else:
            # Input layer (1) + intermediate layers (n-2) + output layer (1)
            self.layers.append(nn.Linear(input_dim, hidden_dim))
            for _ in range(num_layers - 2):
                self.layers.append(nn.Linear(hidden_dim, hidden_dim))
            self.layers.append(nn.Linear(hidden_dim, output_dim))
        self.activation = activation
        self.p = dropout

    def forward(self, x):
        for layer in self.layers:
            x = F.dropout(x, p=self.p, training=self.training)
            x = layer(x)
            if self.activation:
                x = F.relu(x)
        return x


class GCN(nn.Module):

    def __init__(self, input_dim=256, hidden_dim=256, n_layers=3,
                 num_relations=3, num_dists=32, batch_norm=False,
                 dropout=0.1):
        super().__init__()
        self.layers = nn.ModuleList()
        self.norm_layers = nn.ModuleList()
        # The same edge network is shared by all GCL layers
        edge_nn = nn.Linear(num_dists, input_dim)
        self.batch_norm = batch_norm
        self.layers.append(GCL(input_dim, hidden_dim, num_relations, edge_nn))
        if self.batch_norm:
            self.norm_layers.append(BatchNorm(hidden_dim))
        for _ in range(n_layers - 1):
            self.layers.append(
                GCL(hidden_dim, hidden_dim, num_relations, edge_nn))
            if self.batch_norm:
                self.norm_layers.append(BatchNorm(hidden_dim))
        self.p = dropout

    def forward(self, data):
        x, edge_index, edge_attrs = data.x, data.edge_index, data.edge_attrs
        edge_type = edge_attrs[:, 0]
        edge_attr = edge_attrs[:, 1:]

        for i in range(len(self.layers)):
            residual = x
            x = F.dropout(x, p=self.p, training=self.training)
            x = self.layers[i](x, edge_index, edge_type, edge_attr)
            if self.batch_norm:
                x = self.norm_layers[i](x)
            x = F.relu(x)
            # Residual connection (requires input_dim == hidden_dim)
            x = residual + x

        return x


class CNNEncoder(nn.Module):

    def __init__(self, output_dim=256, dense_dim=256, batch_norm=False,
                 dropout=0.1):
        super().__init__()

        # Convolutional layers
        if batch_norm:
            self.conv = nn.Sequential(
                # From (4 x 32) to (8 x 4 x 32)
                nn.Conv2d(1, 8, 3, padding=1),
                nn.BatchNorm2d(8),
                nn.ReLU(True),
                # From (8 x 4 x 32) to (8 x 4 x 8)
                nn.MaxPool2d((1, 4), stride=(1, 4)),
                # From (8 x 4 x 8) to (16 x 4 x 8)
                nn.Conv2d(8, 16, 3, padding=1),
                nn.BatchNorm2d(16),
                nn.ReLU(True)
            )
        else:
            self.conv = nn.Sequential(
                nn.Conv2d(1, 8, 3, padding=1),
                nn.ReLU(True),
                nn.MaxPool2d((1, 4), stride=(1, 4)),
                nn.Conv2d(8, 16, 3, padding=1),
                nn.ReLU(True)
            )

        self.flatten = nn.Flatten(start_dim=1)

        # Linear layers
        self.lin = nn.Sequential(
            nn.Dropout(dropout),
            nn.Linear(16 * 4 * 8, dense_dim),
            nn.ReLU(True),
            nn.Dropout(dropout),
            nn.Linear(dense_dim, output_dim)
        )

    def forward(self, x):
        x = x.unsqueeze(1)
        x = self.conv(x)
        x = self.flatten(x)
        x = self.lin(x)
        return x


class CNNDecoder(nn.Module):

    def __init__(self, input_dim=256, dense_dim=256, batch_norm=False,
                 dropout=0.1):
        super().__init__()

        # Linear decompressors
        self.lin = nn.Sequential(
            nn.Dropout(dropout),
            nn.Linear(input_dim, dense_dim),
            nn.ReLU(True),
            nn.Dropout(dropout),
            nn.Linear(dense_dim, 16 * 4 * 8),
            nn.ReLU(True)
        )

        self.unflatten = nn.Unflatten(dim=1, unflattened_size=(16, 4, 8))

        # Upsample and convolutional layers
        if batch_norm:
            self.conv = nn.Sequential(
                nn.Upsample(scale_factor=(1, 4), mode='nearest'),
                nn.Conv2d(16, 8, 3, padding=1),
                nn.BatchNorm2d(8),
                nn.ReLU(True),
                nn.Conv2d(8, 1, 3, padding=1)
            )
        else:
            self.conv = nn.Sequential(
                nn.Upsample(scale_factor=(1, 4), mode='nearest'),
                nn.Conv2d(16, 8, 3, padding=1),
                nn.ReLU(True),
                nn.Conv2d(8, 1, 3, padding=1)
            )

    def forward(self, x):
        x = self.lin(x)
        x = self.unflatten(x)
        x = self.conv(x)
        # Drop the singleton channel dimension, mirroring the encoder's
        # initial unsqueeze
        x = x.squeeze(1)
        return x
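
# Illustrative sketch (assumes 4 tracks and 32 timesteps per bar, matching
# the (4 x 32) shapes in the comments above): the encoder maps each bar's
# binary structure matrix to a d-dimensional vector, and the decoder maps
# such a vector back to per-cell logits of the original (4 x 32) shape.
def _demo_cnn_shapes(batch_size=2, d=256):
    enc = CNNEncoder(output_dim=d, dense_dim=d)
    dec = CNNDecoder(input_dim=d, dense_dim=d)
    # Toy binary bar structures: (bs, n_tracks, n_timesteps)
    bars = (torch.rand(batch_size, 4, 32) > 0.5).float()
    h = enc(bars)       # (bs, d) bar encodings
    logits = dec(h)     # (bs, 4, 32) structure logits
    assert h.shape == (batch_size, d)
    assert logits.shape == bars.shape
    return logits
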
class ContentEncoder(nn.Module):

    def __init__(self, **kwargs):
        super().__init__()
        self.__dict__.update(kwargs)

        self.dropout_layer = nn.Dropout(p=self.dropout)

        # Pitch and duration embedding layers (separate layers for drums
        # and non-drums)
        self.non_drums_pitch_emb = nn.Linear(constants.N_PITCH_TOKENS,
                                             self.d // 2)
        self.drums_pitch_emb = nn.Linear(constants.N_PITCH_TOKENS,
                                         self.d // 2)
        self.dur_emb = nn.Linear(constants.N_DUR_TOKENS, self.d // 2)

        # Batch norm layers
        self.bn_non_drums = nn.BatchNorm1d(num_features=self.d // 2)
        self.bn_drums = nn.BatchNorm1d(num_features=self.d // 2)
        self.bn_dur = nn.BatchNorm1d(num_features=self.d // 2)

        self.chord_encoder = nn.Linear(
            self.d * (constants.MAX_SIMU_TOKENS - 1), self.d)

        self.graph_encoder = GCN(
            dropout=self.dropout,
            input_dim=self.d,
            hidden_dim=self.d,
            n_layers=self.gnn_n_layers,
            num_relations=constants.N_EDGE_TYPES,
            batch_norm=self.batch_norm
        )

        # Soft attention node-aggregation layer
        gate_nn = nn.Sequential(
            MLP(input_dim=self.d, output_dim=1, num_layers=1,
                activation=False, dropout=self.dropout),
            nn.BatchNorm1d(1)
        )
        self.graph_attention = GlobalAttention(gate_nn)

        self.bars_encoder = nn.Linear(self.n_bars * self.d, self.d)

    def forward(self, graph):
        c_tensor = graph.c_tensor

        # Discard SOS token
        c_tensor = c_tensor[:, 1:, :]

        # Get drums and non-drums tensors
        drums = c_tensor[graph.is_drum]
        non_drums = c_tensor[torch.logical_not(graph.is_drum)]

        # Compute drums embeddings
        sz = drums.size()
        drums_pitch = self.drums_pitch_emb(
            drums[..., :constants.N_PITCH_TOKENS])
        drums_pitch = self.bn_drums(drums_pitch.view(-1, self.d // 2))
        drums_pitch = drums_pitch.view(sz[0], sz[1], self.d // 2)
        drums_dur = self.dur_emb(drums[..., constants.N_PITCH_TOKENS:])
        drums_dur = self.bn_dur(drums_dur.view(-1, self.d // 2))
        drums_dur = drums_dur.view(sz[0], sz[1], self.d // 2)
        drums = torch.cat((drums_pitch, drums_dur), dim=-1)
        # n_nodes_drums x (MAX_SIMU_TOKENS-1) x d

        # Compute non-drums embeddings
        sz = non_drums.size()
        non_drums_pitch = self.non_drums_pitch_emb(
            non_drums[..., :constants.N_PITCH_TOKENS])
        non_drums_pitch = self.bn_non_drums(
            non_drums_pitch.view(-1, self.d // 2))
        non_drums_pitch = non_drums_pitch.view(sz[0], sz[1], self.d // 2)
        non_drums_dur = self.dur_emb(non_drums[..., constants.N_PITCH_TOKENS:])
        non_drums_dur = self.bn_dur(non_drums_dur.view(-1, self.d // 2))
        non_drums_dur = non_drums_dur.view(sz[0], sz[1], self.d // 2)
        non_drums = torch.cat((non_drums_pitch, non_drums_dur), dim=-1)
        # n_nodes_non_drums x (MAX_SIMU_TOKENS-1) x d

        # Compute chord embeddings (drums and non-drums)
        drums = self.chord_encoder(
            drums.view(-1, self.d * (constants.MAX_SIMU_TOKENS - 1)))
        non_drums = self.chord_encoder(
            non_drums.view(-1, self.d * (constants.MAX_SIMU_TOKENS - 1)))
        drums = F.relu(drums)
        non_drums = F.relu(non_drums)
        drums = self.dropout_layer(drums)
        non_drums = self.dropout_layer(non_drums)
        # n_nodes x d

        # Merge drums and non-drums
        out = torch.zeros((c_tensor.size(0), self.d), device=self.device,
                          dtype=drums.dtype)
        out[graph.is_drum] = drums
        out[torch.logical_not(graph.is_drum)] = non_drums
        # n_nodes x d

        # Set initial graph node states to the intermediate chord
        # representations and pass through the GCN
        graph.x = out
        graph.distinct_bars = graph.bars + self.n_bars * graph.batch
        out = self.graph_encoder(graph)
        # n_nodes x d

        # Aggregate final node states into bar encodings with soft attention
        with torch.cuda.amp.autocast(enabled=False):
            out = self.graph_attention(out, batch=graph.distinct_bars)
        # (bs*n_bars) x d

        out = out.view(-1, self.n_bars * self.d)
        # bs x (n_bars*d)
        z_c = self.bars_encoder(out)
        # bs x d

        return z_c
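
# Illustrative sketch (toy numbers): GlobalAttention pools a variable number
# of node states into one vector per bar. The `batch` argument plays the
# role of `distinct_bars` above: nodes sharing an id are aggregated together
# through a learned softmax gate. A plain Linear gate is used here instead
# of the MLP + BatchNorm gate the model builds.
def _demo_bar_pooling(d=8):
    gate_nn = nn.Linear(d, 1)
    pool = GlobalAttention(gate_nn)
    x = torch.randn(5, d)                          # 5 node states
    distinct_bars = torch.tensor([0, 0, 1, 1, 1])  # nodes of 2 distinct bars
    out = pool(x, batch=distinct_bars)
    assert out.shape == (2, d)  # one pooled vector per distinct bar
    return out
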
class StructureEncoder(nn.Module):

    def __init__(self, **kwargs):
        super().__init__()
        self.__dict__.update(kwargs)

        self.cnn_encoder = CNNEncoder(
            dense_dim=self.d,
            output_dim=self.d,
            dropout=self.dropout,
            batch_norm=self.batch_norm
        )
        self.bars_encoder = nn.Linear(self.n_bars * self.d, self.d)

    def forward(self, graph):
        s_tensor = graph.s_tensor
        out = self.cnn_encoder(
            s_tensor.view(-1, constants.N_TRACKS, self.resolution * 4))
        # (bs*n_bars) x d
        out = out.view(-1, self.n_bars * self.d)
        # bs x (n_bars*d)
        z_s = self.bars_encoder(out)
        # bs x d
        return z_s


class Encoder(nn.Module):

    def __init__(self, **kwargs):
        super().__init__()
        self.__dict__.update(kwargs)

        self.s_encoder = StructureEncoder(**kwargs)
        self.c_encoder = ContentEncoder(**kwargs)
        self.dropout_layer = nn.Dropout(p=self.dropout)

        # Linear layer that merges content and structure representations
        self.linear_merge = nn.Linear(2 * self.d, self.d)
        self.bn_linear_merge = nn.BatchNorm1d(num_features=self.d)

        self.linear_mu = nn.Linear(self.d, self.d)
        self.linear_log_var = nn.Linear(self.d, self.d)

    def forward(self, graph):
        z_s = self.s_encoder(graph)
        z_c = self.c_encoder(graph)

        # Merge content and structure representations
        z_g = torch.cat((z_c, z_s), dim=1)
        z_g = self.dropout_layer(z_g)
        z_g = self.linear_merge(z_g)
        z_g = self.bn_linear_merge(z_g)
        z_g = F.relu(z_g)

        # Compute mu and log(std^2)
        z_g = self.dropout_layer(z_g)
        mu = self.linear_mu(z_g)
        log_var = self.linear_log_var(z_g)

        return mu, log_var


class StructureDecoder(nn.Module):

    def __init__(self, **kwargs):
        super().__init__()
        self.__dict__.update(kwargs)

        self.bars_decoder = nn.Linear(self.d, self.d * self.n_bars)
        self.cnn_decoder = CNNDecoder(
            input_dim=self.d,
            dense_dim=self.d,
            dropout=self.dropout,
            batch_norm=self.batch_norm
        )

    def forward(self, z_s):
        # z_s: bs x d
        out = self.bars_decoder(z_s)
        # bs x (n_bars*d)
        out = self.cnn_decoder(out.reshape(-1, self.d))
        out = out.view(z_s.size(0), self.n_bars, constants.N_TRACKS, -1)
        return out


class ContentDecoder(nn.Module):

    def __init__(self, **kwargs):
        super().__init__()
        self.__dict__.update(kwargs)

        self.bars_decoder = nn.Linear(self.d, self.d * self.n_bars)
        self.graph_decoder = GCN(
            dropout=self.dropout,
            input_dim=self.d,
            hidden_dim=self.d,
            n_layers=self.gnn_n_layers,
            num_relations=constants.N_EDGE_TYPES,
            batch_norm=self.batch_norm
        )
        self.chord_decoder = nn.Linear(
            self.d, self.d * (constants.MAX_SIMU_TOKENS - 1))

        # Pitch and duration (un)embedding linear layers
        self.drums_pitch_emb = nn.Linear(self.d // 2,
                                         constants.N_PITCH_TOKENS)
        self.non_drums_pitch_emb = nn.Linear(self.d // 2,
                                             constants.N_PITCH_TOKENS)
        self.dur_emb = nn.Linear(self.d // 2, constants.N_DUR_TOKENS)

        self.dropout_layer = nn.Dropout(p=self.dropout)

    def forward(self, z_c, s):
        out = self.bars_decoder(z_c)
        # bs x (n_bars*d)

        # Initialize node features with the corresponding z_bar
        # and propagate with the GNN
        s.distinct_bars = s.bars + self.n_bars * s.batch
        _, counts = torch.unique(s.distinct_bars, return_counts=True)
        out = out.view(-1, self.d)
        out = torch.repeat_interleave(out, counts, dim=0)
        # n_nodes x d
        s.x = out
        out = self.graph_decoder(s)
        # n_nodes x d

        out = self.chord_decoder(out)
        # n_nodes x ((MAX_SIMU_TOKENS-1)*d)
        out = out.view(-1, constants.MAX_SIMU_TOKENS - 1, self.d)

        drums = out[s.is_drum]
        # n_nodes_drums x (MAX_SIMU_TOKENS-1) x d
        non_drums = out[torch.logical_not(s.is_drum)]
        # n_nodes_non_drums x (MAX_SIMU_TOKENS-1) x d

        # Obtain final pitch and dur logits (softmax will be applied
        # outside forward)
        non_drums = self.dropout_layer(non_drums)
        drums = self.dropout_layer(drums)
        drums_pitch = self.drums_pitch_emb(drums[..., :self.d // 2])
        drums_dur = self.dur_emb(drums[..., self.d // 2:])
        drums = torch.cat((drums_pitch, drums_dur), dim=-1)
        # n_nodes_drums x (MAX_SIMU_TOKENS-1) x d_token
        non_drums_pitch = self.non_drums_pitch_emb(
            non_drums[..., :self.d // 2])
        non_drums_dur = self.dur_emb(non_drums[..., self.d // 2:])
        non_drums = torch.cat((non_drums_pitch, non_drums_dur), dim=-1)
        # n_nodes_non_drums x (MAX_SIMU_TOKENS-1) x d_token

        # Merge drums and non-drums in the final output tensor
        d_token = constants.D_TOKEN_PAIR
        out = torch.zeros((s.num_nodes, constants.MAX_SIMU_TOKENS - 1,
                           d_token), device=self.device, dtype=drums.dtype)
        out[s.is_drum] = drums
        out[torch.logical_not(s.is_drum)] = non_drums

        return out


class Decoder(nn.Module):

    def __init__(self, **kwargs):
        super().__init__()
        self.__dict__.update(kwargs)

        self.lin_decoder = nn.Linear(self.d, 2 * self.d)
        self.batch_norm = nn.BatchNorm1d(num_features=2 * self.d)
        self.dropout = nn.Dropout(p=self.dropout)

        self.s_decoder = StructureDecoder(**kwargs)
        self.c_decoder = ContentDecoder(**kwargs)

        self.sigmoid_thresh = 0.5

    def _structure_from_binary(self, s_tensor):
        # Create a graph structure for each batch element
        s = []
        for i in range(s_tensor.size(0)):
            s.append(graph_from_tensor(s_tensor[i]))

        # Create a batch of graphs from the single graphs
        s = Batch.from_data_list(s, exclude_keys=['batch'])
        s = s.to(next(self.parameters()).device)

        return s

    def _binary_from_logits(self, s_logits):
        # A hard threshold instead of sampling gives more pleasant results
        s_tensor = torch.sigmoid(s_logits)
        s_tensor[s_tensor >= self.sigmoid_thresh] = 1
        s_tensor[s_tensor < self.sigmoid_thresh] = 0
        s_tensor = s_tensor.bool()

        # Avoid empty bars by creating a fake activation in position [0, 0]
        # of each empty (n_tracks x n_timesteps) bar matrix
        empty_mask = ~s_tensor.any(dim=-1).any(dim=-1)
        idxs = torch.nonzero(empty_mask, as_tuple=True)
        s_tensor[idxs + (0, 0)] = True

        return s_tensor

    def _structure_from_logits(self, s_logits):
        # Compute the binary structure tensor from the logits and build the
        # torch geometric structure from the binary tensor
        s_tensor = self._binary_from_logits(s_logits)
        s = self._structure_from_binary(s_tensor)
        return s

    def forward(self, z, s=None):
        # Obtain z_s and z_c from z
        z = self.lin_decoder(z)
        z = self.batch_norm(z)
        z = F.relu(z)
        z = self.dropout(z)
        # bs x (2*d)
        z_s, z_c = z[:, :self.d], z[:, self.d:]

        # Obtain the tensor containing structure logits
        s_logits = self.s_decoder(z_s)

        if s is None:
            # Build the torch geometric graph structure from the structure
            # logits. This step involves non-differentiable operations:
            # no gradients pass through here.
            s = self._structure_from_logits(s_logits.detach())

        # Obtain the tensor containing content logits
        c_logits = self.c_decoder(z_c, s)

        return s_logits, c_logits


class VAE(nn.Module):

    def __init__(self, **kwargs):
        super().__init__()
        self.encoder = Encoder(**kwargs)
        self.decoder = Decoder(**kwargs)

    def forward(self, graph):
        # Encoder pass
        mu, log_var = self.encoder(graph)

        # Reparameterization trick: z = mu + sigma * eps, eps ~ N(0, I)
        z = torch.exp(0.5 * log_var)
        z = z * torch.randn_like(z)
        z = z + mu

        # Decoder pass
        out = self.decoder(z, graph)

        return out, mu, log_var
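
# Illustrative sketch (not part of this file's training code): the standard
# closed-form KL term that a VAE objective would pair with reconstruction
# losses on s_logits and c_logits. `recon_loss`, `targets` and `beta` below
# are hypothetical names used only to show where the term fits:
#   (s_logits, c_logits), mu, log_var = model(graph)
#   loss = recon_loss(s_logits, c_logits, targets) \
#          + beta * _kl_divergence(mu, log_var)
def _kl_divergence(mu, log_var):
    # KL(q(z|x) || N(0, I)) for a diagonal Gaussian, averaged over the batch:
    # -0.5 * sum(1 + log_var - mu^2 - exp(log_var))
    return -0.5 * torch.mean(
        torch.sum(1 + log_var - mu.pow(2) - log_var.exp(), dim=1))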