Commit 2e21b83 · Parent(s): 28d5117
added support for direct audio duration input

preprocessing/dataset.py · +17 -6 · CHANGED

@@ -26,6 +26,7 @@ class SongDataset(Dataset):
         audio_start_offset=6, # seconds
         audio_window_duration=6, # seconds
         audio_window_jitter=1.0, # seconds
+        audio_durations=None,
     ):
         assert (
             audio_window_duration > audio_window_jitter

@@ -33,10 +34,15 @@ class SongDataset(Dataset):
 
         self.audio_paths = audio_paths
         self.dance_labels = dance_labels
-        audio_metadata = [ta.info(audio) for audio in audio_paths]
-        self.audio_durations = [
-            meta.num_frames / meta.sample_rate for meta in audio_metadata
-        ]
+
+        # Added to limit file I/O
+        if audio_durations is None:
+            audio_metadata = [ta.info(audio) for audio in audio_paths]
+            self.audio_durations = [
+                meta.num_frames / meta.sample_rate for meta in audio_metadata
+            ]
+        else:
+            self.audio_durations = audio_durations
         self.sample_rate = audio_metadata[0].sample_rate # assuming same sample rate
         self.audio_window_duration = int(audio_window_duration)
         self.audio_start_offset = audio_start_offset

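The new keyword lets callers that already know their clip lengths skip the per-file probing. A minimal caller-side sketch, assuming ta is torchaudio (as the ta.info call above suggests), that paths and labels are the first two positional parameters (as the Music4DanceDataset call below does), and with the file paths and labels invented for illustration:

import torchaudio as ta

from preprocessing.dataset import SongDataset  # import path assumed from the file shown above

# Hypothetical inputs.
audio_paths = ["data/song_a.wav", "data/song_b.wav"]
dance_labels = [[1, 0], [0, 1]]

# Probe each file once up front (duration in seconds)...
durations = []
for path in audio_paths:
    info = ta.info(path)
    durations.append(info.num_frames / info.sample_rate)

# ...and hand the durations to the dataset so it skips its own ta.info pass.
dataset = SongDataset(audio_paths, dance_labels, audio_durations=durations)

One caveat: the unchanged self.sample_rate = audio_metadata[0].sample_rate line still reads audio_metadata, which the new else branch never defines, so passing audio_durations appears to leave the constructor raising a NameError unless the sample rate is obtained some other way.
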
@@ -209,7 +215,12 @@ class Music4DanceDataset(Dataset):
             multi_label=multi_label,
             min_votes=min_votes,
         )
-        self.song_dataset = SongDataset(
+        self.song_dataset = SongDataset(
+            song_paths,
+            labels,
+            audio_durations=[30.0] * len(song_paths),
+            **kwargs,
+        )
 
     def __getitem__(self, index) -> tuple[torch.Tensor, torch.Tensor]:
         return self.song_dataset[index]

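Music4DanceDataset now hands SongDataset a constant 30.0-second duration per song instead of probing each file, presumably because its audio files are fixed-length preview clips (an assumption; the diff only shows the constant). If the clips were not uniform, one way to keep file I/O out of this constructor is a precomputed duration cache; a hypothetical sketch, with durations.json, build_song_dataset, and the 30-second fallback invented for illustration:

import json

from preprocessing.dataset import SongDataset  # import path assumed


def build_song_dataset(song_paths, labels, cache_file="durations.json", **kwargs):
    # Durations cached as {audio_path: seconds} by an earlier offline probing pass.
    with open(cache_file) as f:
        cached = json.load(f)
    durations = [cached.get(path, 30.0) for path in song_paths]  # 30 s fallback, as in the diff
    return SongDataset(song_paths, labels, audio_durations=durations, **kwargs)
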
@@ -306,7 +317,7 @@ class DanceDataModule(pl.LightningDataModule):
             self.train_ds,
             batch_size=self.batch_size,
             num_workers=self.num_workers,
-            shuffle=
+            shuffle=False,
         )
 
     def val_dataloader(self):

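The last hunk pins the training DataLoader to shuffle=False, so batches are drawn in a fixed order every epoch. If the intent was reproducibility rather than a fixed order, a seeded generator is a common alternative; a sketch only, with the dataset and batch settings standing in for the module's own attributes:

import torch
from torch.utils.data import DataLoader


def train_dataloader(train_ds, batch_size, num_workers, seed=42):
    # Shuffle every epoch, but deterministically across runs.
    return DataLoader(
        train_ds,
        batch_size=batch_size,
        num_workers=num_workers,
        shuffle=True,
        generator=torch.Generator().manual_seed(seed),
    )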