import itertools
import os

import torch
import numpy as np
from torch_geometric.data import Dataset
from torch_geometric.data import Data
from torch_geometric.data.collate import collate

import constants
from constants import EdgeTypes


def get_node_labels(s_tensor, ones_idxs):
    # Build a tensor that holds a node label in place of each activation in
    # the structure tensor
    labels = torch.zeros_like(s_tensor, dtype=torch.long,
                              device=s_tensor.device)
    n_nodes = len(ones_idxs[0])
    labels[ones_idxs] = torch.arange(n_nodes, device=s_tensor.device)
    return labels


def get_track_edges(s_tensor, ones_idxs=None, node_labels=None):
    track_edges = []

    if ones_idxs is None:
        # Indices where the binary structure tensor is active
        ones_idxs = torch.nonzero(s_tensor, as_tuple=True)
    if node_labels is None:
        node_labels = get_node_labels(s_tensor, ones_idxs)

    # For each track, add direct and inverse edges between consecutive nodes
    for track in range(s_tensor.size(0)):
        # List of active timesteps in the current track
        tss = list(ones_idxs[1][ones_idxs[0] == track])
        edge_type = EdgeTypes.TRACK.value + track
        edges = [
            # Edge tuple: (u, v, type, ts_distance). Zip is used to obtain
            # consecutive active timesteps. Edges in different tracks have
            # different types.
            (node_labels[track, t1], node_labels[track, t2],
             edge_type, t2 - t1)
            for t1, t2 in zip(tss[:-1], tss[1:])
        ]
        inverse_edges = [(u, v, t, d) for (v, u, t, d) in edges]
        track_edges.extend(edges + inverse_edges)

    return torch.tensor(track_edges, dtype=torch.long)


def get_onset_edges(s_tensor, ones_idxs=None, node_labels=None):
    onset_edges = []
    edge_type = EdgeTypes.ONSET.value

    if ones_idxs is None:
        # Indices where the binary structure tensor is active
        ones_idxs = torch.nonzero(s_tensor, as_tuple=True)
    if node_labels is None:
        node_labels = get_node_labels(s_tensor, ones_idxs)

    # Add direct and inverse edges between nodes played in the same timestep
    for ts in range(s_tensor.size(1)):
        # List of active tracks in the current timestep
        tracks = list(ones_idxs[0][ones_idxs[1] == ts])
        # Obtain all possible pairwise combinations of active tracks
        combinations = list(itertools.combinations(tracks, 2))
        edges = [
            # Edge tuple: (u, v, type, ts_distance(=0)).
            (node_labels[track1, ts], node_labels[track2, ts], edge_type, 0)
            for track1, track2 in combinations
        ]
        inverse_edges = [(u, v, t, d) for (v, u, t, d) in edges]
        onset_edges.extend(edges + inverse_edges)

    return torch.tensor(onset_edges, dtype=torch.long)


def get_next_edges(s_tensor, ones_idxs=None, node_labels=None):
    next_edges = []
    edge_type = EdgeTypes.NEXT.value

    if ones_idxs is None:
        # Indices where the binary structure tensor is active
        ones_idxs = torch.nonzero(s_tensor, as_tuple=True)
    if node_labels is None:
        node_labels = get_node_labels(s_tensor, ones_idxs)

    # List of active timesteps
    tss = torch.nonzero(torch.any(s_tensor.bool(), dim=0)).squeeze()
    if tss.dim() == 0:
        return torch.tensor([], dtype=torch.long)

    for i in range(tss.size(0) - 1):
        # Get consecutive active timesteps
        t1, t2 = tss[i], tss[i+1]

        # Get all the active tracks in the two timesteps
        t1_tracks = ones_idxs[0][ones_idxs[1] == t1]
        t2_tracks = ones_idxs[0][ones_idxs[1] == t2]

        # Combine the source and destination tracks, removing combinations
        # with the same source and destination track (since these represent
        # track edges).
        tracks_product = list(itertools.product(t1_tracks, t2_tracks))
        tracks_product = [(track1, track2)
                          for (track1, track2) in tracks_product
                          if track1 != track2]

        # Edge tuple: (u, v, type, ts_distance).
        edges = [(node_labels[track1, t1], node_labels[track2, t2],
                  edge_type, t2 - t1)
                 for track1, track2 in tracks_product]
        next_edges.extend(edges)

    return torch.tensor(next_edges, dtype=torch.long)


def get_track_features(s_tensor):
    # Indices where the binary structure tensor is active
    ones_idxs = torch.nonzero(s_tensor)
    n_nodes = len(ones_idxs)
    tracks = ones_idxs[:, 0]
    n_tracks = s_tensor.size(0)

    # The (n_nodes x n_tracks) feature tensor contains a one-hot track
    # representation for each node
    features = torch.zeros((n_nodes, n_tracks))
    features[torch.arange(n_nodes), tracks] = 1

    return features


def graph_from_tensor(s_tensor):
    bars = []

    # Iterate over bars and construct a graph for each bar
    for i in range(s_tensor.size(0)):
        bar = s_tensor[i]

        # If the bar contains no activations, add a fake one to avoid having
        # to deal with empty graphs
        if not torch.any(bar):
            bar[0, 0] = 1

        # Get edges from boolean activations
        track_edges = get_track_edges(bar)
        onset_edges = get_onset_edges(bar)
        next_edges = get_next_edges(bar)
        edges = [track_edges, onset_edges, next_edges]

        # Concatenate edge tensors (N x 4) (if any)
        is_edgeless = (len(track_edges) == 0 and len(onset_edges) == 0
                       and len(next_edges) == 0)
        if not is_edgeless:
            edge_list = torch.cat([x for x in edges if torch.numel(x) > 0])

        # Adapt tensors to torch_geometric's Data format.
        # If there are no edges, add a fake self-edge.
        # edge_list[:, :2] contains source and destination node labels,
        # edge_list[:, 2:] contains edge types and timestep distances.
        edge_index = (edge_list[:, :2].t().contiguous() if not is_edgeless
                      else torch.LongTensor([[0], [0]]))
        attrs = (edge_list[:, 2:] if not is_edgeless
                 else torch.Tensor([[0, 0]]))

        # Edge attributes: edge type in the first column, one-hot timestep
        # distance in the remaining columns
        edge_attrs = torch.zeros(attrs.size(0), s_tensor.shape[-1] + 1)
        edge_attrs[:, 0] = attrs[:, 0]
        edge_attrs[torch.arange(edge_attrs.size(0)),
                   attrs.long()[:, 1] + 1] = 1

        node_features = get_track_features(bar)
        is_drum = node_features[:, 0].bool()
        num_nodes = torch.sum(bar, dtype=torch.long)

        bars.append(Data(edge_index=edge_index, edge_attrs=edge_attrs,
                         num_nodes=num_nodes, node_features=node_features,
                         is_drum=is_drum).to(s_tensor.device))

    # Merge the graphs corresponding to different bars into a single big graph
    graph, _, _ = collate(
        Data,
        data_list=bars,
        increment=True,
        add_batch=True
    )

    # Rename the bar assignment vector (otherwise, the DataLoader's collate
    # would overwrite graph.batch)
    graph.bars = graph.batch

    return graph


class PolyphemusDataset(Dataset):

    def __init__(self, dir, n_bars=2):
        self.dir = dir
        self.files = list(os.scandir(self.dir))
        self.len = len(self.files)
        self.n_bars = n_bars

    def __len__(self):
        return self.len

    def __getitem__(self, idx):
        # Load tensors
        sample_path = os.path.join(self.dir, self.files[idx].name)
        data = np.load(sample_path)
        c_tensor = torch.tensor(data["c_tensor"], dtype=torch.long)
        s_tensor = torch.tensor(data["s_tensor"], dtype=torch.bool)

        # From (n_tracks x n_timesteps x ...)
        # to (n_bars x n_tracks x n_timesteps x ...)
        c_tensor = c_tensor.reshape(c_tensor.shape[0], self.n_bars, -1,
                                    c_tensor.shape[2], c_tensor.shape[3])
        c_tensor = c_tensor.permute(1, 0, 2, 3, 4)
        s_tensor = s_tensor.reshape(s_tensor.shape[0], self.n_bars, -1)
        s_tensor = s_tensor.permute(1, 0, 2)

        # From decimals to onehot (pitches)
        pitches = c_tensor[..., 0]
        onehot_p = torch.zeros(
            (pitches.shape[0]*pitches.shape[1]*pitches.shape[2]
             * pitches.shape[3], constants.N_PITCH_TOKENS),
            dtype=torch.float32
        )
        onehot_p[torch.arange(0, onehot_p.shape[0]), pitches.reshape(-1)] = 1.
        onehot_p = onehot_p.reshape(pitches.shape[0], pitches.shape[1],
                                    pitches.shape[2], pitches.shape[3],
                                    constants.N_PITCH_TOKENS)

        # From decimals to onehot (durations)
        durs = c_tensor[..., 1]
        onehot_d = torch.zeros(
            (durs.shape[0]*durs.shape[1]*durs.shape[2]*durs.shape[3],
             constants.N_DUR_TOKENS),
            dtype=torch.float32
        )
        onehot_d[torch.arange(0, onehot_d.shape[0]), durs.reshape(-1)] = 1.
        onehot_d = onehot_d.reshape(durs.shape[0], durs.shape[1],
                                    durs.shape[2], durs.shape[3],
                                    constants.N_DUR_TOKENS)

        # Concatenate pitches and durations
        c_tensor = torch.cat((onehot_p, onehot_d), dim=-1)

        # Build graph structure from the structure tensor
        graph = graph_from_tensor(s_tensor)

        # Filter silences in order to get a sparse representation
        c_tensor = c_tensor.reshape(-1, c_tensor.shape[-2],
                                    c_tensor.shape[-1])
        c_tensor = c_tensor[s_tensor.reshape(-1).bool()]

        graph.c_tensor = c_tensor
        graph.s_tensor = s_tensor.float()

        return graph
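

# Minimal usage sketch (illustrative addition, not part of the original
# module). It assumes a directory of preprocessed .npz samples, each
# containing "c_tensor" and "s_tensor" arrays shaped as expected by
# PolyphemusDataset.__getitem__; the path and batch size below are
# placeholders.
if __name__ == "__main__":
    from torch_geometric.loader import DataLoader

    # Hypothetical directory produced by the preprocessing step
    dataset = PolyphemusDataset("data/preprocessed", n_bars=2)
    loader = DataLoader(dataset, batch_size=8, shuffle=True)

    for batch in loader:
        # Each batch is a merged graph: batch.c_tensor holds the sparse
        # one-hot pitch/duration content of the active nodes and
        # batch.s_tensor the (float) binary structure tensors.
        print(batch)
        break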