|
import itertools |
|
import os |
|
|
|
import torch |
|
import numpy as np |
|
from torch_geometric.data import Dataset |
|
from torch_geometric.data import Data |
|
from torch_geometric.data.collate import collate |
|
|
|
import constants |
|
from constants import EdgeTypes |
|
|
|
|
|
def get_node_labels(s_tensor, ones_idxs):
    """Assign a consecutive integer label to every indexed position.

    Args:
        s_tensor: Tensor whose shape (and device) the label map mirrors.
        ones_idxs: Tuple of index tensors (e.g. the result of
            ``torch.nonzero(s_tensor, as_tuple=True)``) selecting the
            positions that correspond to nodes.

    Returns:
        A ``torch.long`` tensor shaped like ``s_tensor``. The k-th indexed
        position holds the value ``k``; every other position holds 0 (so 0
        is also the label of the first node).
    """
    device = s_tensor.device

    # One id per selected position, in the order given by ones_idxs.
    node_ids = torch.arange(len(ones_idxs[0]), device=device)

    label_map = torch.zeros_like(s_tensor, dtype=torch.long, device=device)
    label_map[ones_idxs] = node_ids
    return label_map
|
|
|
|
|
def get_track_edges(s_tensor, ones_idxs=None, node_labels=None):
    """Build edges linking consecutive active positions within each row.

    Rows of ``s_tensor`` are treated as tracks and columns as timesteps.
    For every track, each pair of successive active timesteps is connected
    in both directions.

    Args:
        s_tensor: 2-D tensor whose nonzero entries are the graph nodes.
        ones_idxs: Optional precomputed
            ``torch.nonzero(s_tensor, as_tuple=True)``.
        node_labels: Optional precomputed output of ``get_node_labels``.

    Returns:
        A ``torch.long`` tensor with one row
        ``(src, dst, edge_type, distance)`` per edge, where ``edge_type`` is
        ``EdgeTypes.TRACK.value + track`` and ``distance`` is the timestep
        gap. When no edge exists the result is an empty 1-D tensor.
    """
    if ones_idxs is None:
        ones_idxs = torch.nonzero(s_tensor, as_tuple=True)

    if node_labels is None:
        node_labels = get_node_labels(s_tensor, ones_idxs)

    active_rows, active_cols = ones_idxs[0], ones_idxs[1]
    collected = []

    for track in range(s_tensor.size(0)):
        # Active timesteps of this track, in the order listed by ones_idxs.
        onsets = list(active_cols[active_rows == track])
        # Each track gets its own edge type, offset from the TRACK base.
        kind = EdgeTypes.TRACK.value + track

        forward = []
        for prev_ts, next_ts in zip(onsets[:-1], onsets[1:]):
            forward.append((
                node_labels[track, prev_ts],
                node_labels[track, next_ts],
                kind,
                next_ts - prev_ts,
            ))

        # Mirror each edge; type and distance are kept unchanged.
        backward = [(dst, src, k, dist) for (src, dst, k, dist) in forward]

        collected.extend(forward + backward)

    return torch.tensor(collected, dtype=torch.long)
|
|
|
|
|
def get_onset_edges(s_tensor, ones_idxs=None, node_labels=None):
    """Build edges between nodes that are active at the same timestep.

    Rows of ``s_tensor`` are treated as tracks and columns as timesteps.
    For every timestep, each unordered pair of active tracks is connected
    in both directions.

    Args:
        s_tensor: 2-D tensor whose nonzero entries are the graph nodes.
        ones_idxs: Optional precomputed
            ``torch.nonzero(s_tensor, as_tuple=True)``.
        node_labels: Optional precomputed output of ``get_node_labels``.

    Returns:
        A ``torch.long`` tensor with one row ``(src, dst, edge_type, 0)``
        per edge, with ``edge_type == EdgeTypes.ONSET.value``. When no edge
        exists the result is an empty 1-D tensor.
    """
    kind = EdgeTypes.ONSET.value

    if ones_idxs is None:
        ones_idxs = torch.nonzero(s_tensor, as_tuple=True)

    if node_labels is None:
        node_labels = get_node_labels(s_tensor, ones_idxs)

    active_rows, active_cols = ones_idxs[0], ones_idxs[1]
    collected = []

    for ts in range(s_tensor.size(1)):
        # Tracks that have a node at this timestep.
        simultaneous = list(active_rows[active_cols == ts])

        forward = []
        for first, second in itertools.combinations(simultaneous, 2):
            # Simultaneous nodes are at temporal distance 0.
            forward.append(
                (node_labels[first, ts], node_labels[second, ts], kind, 0)
            )

        backward = [(dst, src, k, dist) for (src, dst, k, dist) in forward]

        collected.extend(forward + backward)

    return torch.tensor(collected, dtype=torch.long)
|
|
|
|
|
def get_next_edges(s_tensor, ones_idxs=None, node_labels=None):
    """Build directed edges between consecutive active timesteps.

    Rows of ``s_tensor`` are treated as tracks and columns as timesteps.
    For each pair of successive timesteps that contain at least one node,
    every node of the earlier timestep is linked to every node of the later
    one, except for pairs lying on the same track.

    Args:
        s_tensor: 2-D tensor whose nonzero entries are the graph nodes.
        ones_idxs: Optional precomputed
            ``torch.nonzero(s_tensor, as_tuple=True)``.
        node_labels: Optional precomputed output of ``get_node_labels``.

    Returns:
        A ``torch.long`` tensor with one row
        ``(src, dst, edge_type, distance)`` per edge, with
        ``edge_type == EdgeTypes.NEXT.value`` and ``distance`` the timestep
        gap. When no edge exists the result is an empty 1-D tensor.
    """
    kind = EdgeTypes.NEXT.value

    if ones_idxs is None:
        ones_idxs = torch.nonzero(s_tensor, as_tuple=True)

    if node_labels is None:
        node_labels = get_node_labels(s_tensor, ones_idxs)

    # Timesteps (columns) in which at least one track is active.
    column_is_active = torch.any(s_tensor.bool(), dim=0)
    tss = torch.nonzero(column_is_active).squeeze()

    # With exactly one active timestep, squeeze() yields a 0-dim tensor:
    # there is nothing to connect.
    if tss.dim() == 0:
        return torch.tensor([], dtype=torch.long)

    active_rows, active_cols = ones_idxs[0], ones_idxs[1]
    collected = []

    for t1, t2 in zip(tss[:-1], tss[1:]):
        src_tracks = active_rows[active_cols == t1]
        dst_tracks = active_rows[active_cols == t2]

        for src_track, dst_track in itertools.product(src_tracks,
                                                      dst_tracks):
            # Same-track pairs are left out of the "next" edges.
            if src_track != dst_track:
                collected.append((
                    node_labels[src_track, t1],
                    node_labels[dst_track, t2],
                    kind,
                    t2 - t1,
                ))

    return torch.tensor(collected, dtype=torch.long)
|
|
|
|
|
def get_track_features(s_tensor):
    """Return one-hot track-membership features for every node.

    Each nonzero entry of ``s_tensor`` is a node; its feature vector is the
    one-hot encoding of the entry's index along dimension 0 (the track).

    Args:
        s_tensor: Tensor whose first dimension indexes tracks and whose
            nonzero entries are the graph nodes.

    Returns:
        A float tensor of shape ``(n_nodes, n_tracks)`` located on
        ``s_tensor.device``. Nodes follow the order of
        ``torch.nonzero(s_tensor)``.
    """
    device = s_tensor.device

    ones_idxs = torch.nonzero(s_tensor)
    n_nodes = len(ones_idxs)
    tracks = ones_idxs[:, 0]
    n_tracks = s_tensor.size(0)

    # Allocate the output and the row index on the input's device: `tracks`
    # lives there, and index tensors must be on the CPU or on the same
    # device as the tensor being indexed.
    features = torch.zeros((n_nodes, n_tracks), device=device)
    features[torch.arange(n_nodes, device=device), tracks] = 1

    return features
|
|
|
|
|
def graph_from_tensor(s_tensor):
    """Convert a stack of per-bar structure tensors into one batched graph.

    Every slice ``s_tensor[i]`` is turned into a ``Data`` graph whose nodes
    are the nonzero entries of the slice and whose edges are the union of
    the track, onset and next edges. The per-bar graphs are then merged
    with ``collate``.

    NOTE: a slice with no active entry gets ``[0, 0]`` set to 1 *in place*,
    so the caller's ``s_tensor`` is modified for empty bars.

    Args:
        s_tensor: Tensor indexed as ``(bar, track, timestep)``.

    Returns:
        The collated graph. It carries ``edge_index``, ``edge_attrs``
        (column 0 is the edge type, the remaining columns one-hot encode the
        timestep distance), ``node_features``, ``is_drum``, ``batch`` and
        ``bars`` (an alias of ``batch``).
    """
    bar_graphs = []

    for bar_idx in range(s_tensor.size(0)):
        bar = s_tensor[bar_idx]

        # Guarantee at least one node per bar. `bar` is a view, so this
        # write is visible through `s_tensor` as well.
        if not torch.any(bar):
            bar[0, 0] = 1

        track_edges = get_track_edges(bar)
        onset_edges = get_onset_edges(bar)
        next_edges = get_next_edges(bar)

        has_edges = (len(track_edges) > 0 or
                     len(onset_edges) > 0 or
                     len(next_edges) > 0)

        if has_edges:
            # Empty edge sets are 1-D and cannot be concatenated with the
            # (n, 4) ones, so they are filtered out first.
            non_empty = [group
                         for group in (track_edges, onset_edges, next_edges)
                         if torch.numel(group) > 0]
            edge_list = torch.cat(non_empty)
            edge_index = edge_list[:, :2].t().contiguous()
            attrs = edge_list[:, 2:]
        else:
            # No edges at all: fall back to a single 0 -> 0 self-loop with
            # type 0 and distance 0.
            edge_index = torch.LongTensor([[0], [0]])
            attrs = torch.Tensor([[0, 0]])

        # Column 0 stores the edge type; column (distance + 1) is set to 1.
        n_edges = attrs.size(0)
        edge_attrs = torch.zeros(n_edges, s_tensor.shape[-1] + 1)
        edge_attrs[:, 0] = attrs[:, 0]
        distance_cols = attrs.long()[:, 1] + 1
        edge_attrs[torch.arange(n_edges), distance_cols] = 1

        node_features = get_track_features(bar)
        # A node is flagged as drum when it belongs to track 0.
        is_drum = node_features[:, 0].bool()
        num_nodes = torch.sum(bar, dtype=torch.long)

        bar_graph = Data(
            edge_index=edge_index,
            edge_attrs=edge_attrs,
            num_nodes=num_nodes,
            node_features=node_features,
            is_drum=is_drum,
        )
        bar_graphs.append(bar_graph.to(s_tensor.device))

    # Merge the per-bar graphs, offsetting node indices and adding `batch`.
    graph, _, _ = collate(
        Data,
        data_list=bar_graphs,
        increment=True,
        add_batch=True,
    )

    # Expose the bar assignment of each node under a dedicated name.
    graph.bars = graph.batch

    return graph
|
|
|
|
|
class PolyphemusDataset(Dataset):
    """Dataset that turns stored (c_tensor, s_tensor) samples into graphs.

    Every entry of ``dir`` is treated as one sample: a file readable by
    ``np.load`` that exposes the arrays ``c_tensor`` and ``s_tensor``.
    """

    def __init__(self, dir, n_bars=2):
        """Index the samples found in ``dir``.

        Args:
            dir: Directory containing one file per sample.
            n_bars: Number of bars each sample is split into.
        """
        # NOTE(review): the base-class initializer is not invoked here --
        # confirm this is intentional.
        self.dir = dir
        self.files = list(os.scandir(self.dir))
        self.len = len(self.files)
        self.n_bars = n_bars

    def __len__(self):
        """Return the number of samples found in the directory."""
        return self.len

    @staticmethod
    def _one_hot(tokens, n_classes):
        """One-hot encode integer ``tokens`` along a new trailing axis."""
        flat = tokens.reshape(-1)
        onehot = torch.zeros((flat.shape[0], n_classes), dtype=torch.float32)
        onehot[torch.arange(0, flat.shape[0]), flat] = 1.
        return onehot.reshape(*tokens.shape, n_classes)

    def __getitem__(self, idx):
        """Load sample ``idx`` and return it as a batched graph.

        Returns:
            The graph built by ``graph_from_tensor`` from the structure
            tensor, with two extra attributes: ``c_tensor`` (one-hot pitch
            and duration tokens of the active positions only) and
            ``s_tensor`` (the structure tensor as float, indexed as
            ``(bar, track, timestep)``).
        """
        sample_path = os.path.join(self.dir, self.files[idx].name)

        # The object returned by np.load keeps the archive open; the
        # `with` block closes it once both arrays have been copied into
        # tensors, so no file handle leaks per sample.
        with np.load(sample_path) as data:
            c_tensor = torch.tensor(data["c_tensor"], dtype=torch.long)
            s_tensor = torch.tensor(data["s_tensor"], dtype=torch.bool)

        # Split the second axis into n_bars chunks and move the bar axis
        # to the front.
        c_tensor = c_tensor.reshape(c_tensor.shape[0], self.n_bars, -1,
                                    c_tensor.shape[2], c_tensor.shape[3])
        c_tensor = c_tensor.permute(1, 0, 2, 3, 4)
        s_tensor = s_tensor.reshape(s_tensor.shape[0], self.n_bars, -1)
        s_tensor = s_tensor.permute(1, 0, 2)

        # The last axis of c_tensor holds (pitch token, duration token).
        onehot_p = self._one_hot(c_tensor[..., 0], constants.N_PITCH_TOKENS)
        onehot_d = self._one_hot(c_tensor[..., 1], constants.N_DUR_TOKENS)

        # Concatenate both encodings into a single content vector.
        c_tensor = torch.cat((onehot_p, onehot_d), dim=-1)

        # graph_from_tensor may activate position (0, 0) of empty bars in
        # place, so the mask below must be taken *after* this call to stay
        # aligned with the graph's nodes.
        graph = graph_from_tensor(s_tensor)

        # Keep only the content of active positions (one row per node).
        c_tensor = c_tensor.reshape(-1, c_tensor.shape[-2], c_tensor.shape[-1])
        c_tensor = c_tensor[s_tensor.reshape(-1).bool()]

        graph.c_tensor = c_tensor
        graph.s_tensor = s_tensor.float()

        return graph
|
|