waidhoferj committed
Commit 2e21b83 · 1 Parent(s): 28d5117

added support for direct audio duration input
Files changed (1):
  preprocessing/dataset.py  +17 -6
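In short: SongDataset.__init__ gains an optional audio_durations argument. When it is omitted, the constructor still probes every file with ta.info (presumably a torchaudio alias, given the num_frames / sample_rate fields) to compute durations; when it is provided, those per-file metadata reads are skipped. Music4DanceDataset uses the new argument to pass a fixed 30.0-second duration per song, and the training dataloader's shuffle flag changes from True to False.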
preprocessing/dataset.py CHANGED
@@ -26,6 +26,7 @@ class SongDataset(Dataset):
         audio_start_offset=6,  # seconds
         audio_window_duration=6,  # seconds
         audio_window_jitter=1.0,  # seconds
+        audio_durations=None,
     ):
         assert (
             audio_window_duration > audio_window_jitter
@@ -33,10 +34,15 @@ class SongDataset(Dataset):
 
         self.audio_paths = audio_paths
         self.dance_labels = dance_labels
-        audio_metadata = [ta.info(audio) for audio in audio_paths]
-        self.audio_durations = [
-            meta.num_frames / meta.sample_rate for meta in audio_metadata
-        ]
+
+        # Added to limit file I/O
+        if audio_durations is None:
+            audio_metadata = [ta.info(audio) for audio in audio_paths]
+            self.audio_durations = [
+                meta.num_frames / meta.sample_rate for meta in audio_metadata
+            ]
+        else:
+            self.audio_durations = audio_durations
         self.sample_rate = audio_metadata[0].sample_rate  # assuming same sample rate
         self.audio_window_duration = int(audio_window_duration)
         self.audio_start_offset = audio_start_offset
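A note on this hunk: the unchanged context line `self.sample_rate = audio_metadata[0].sample_rate` still reads audio_metadata, which the new code only binds on the `audio_durations is None` branch. As written, constructing the dataset with explicit durations would appear to raise a NameError on that line. A minimal sketch of one way to keep the sample rate defined on both branches (assuming ta is the module's torchaudio import; the helper name is hypothetical):

import torchaudio as ta  # assumption: dataset.py imports torchaudio as ta

def resolve_durations_and_sample_rate(audio_paths, audio_durations=None):
    """Compute (durations, sample_rate), probing file metadata only when needed."""
    if audio_durations is None:
        # No durations supplied: one ta.info call per file, as before.
        audio_metadata = [ta.info(path) for path in audio_paths]
        durations = [m.num_frames / m.sample_rate for m in audio_metadata]
        sample_rate = audio_metadata[0].sample_rate  # assuming a uniform rate
    else:
        durations = list(audio_durations)
        # Probe a single file so the sample rate is still available.
        sample_rate = ta.info(audio_paths[0]).sample_rate
    return durations, sample_rate

This would preserve the commit's goal of limiting file I/O: at most one metadata read when durations are passed in.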
@@ -209,7 +215,12 @@ class Music4DanceDataset(Dataset):
             multi_label=multi_label,
             min_votes=min_votes,
         )
-        self.song_dataset = SongDataset(song_paths, labels, **kwargs)
+        self.song_dataset = SongDataset(
+            song_paths,
+            labels,
+            audio_durations=[30.0] * len(song_paths),
+            **kwargs,
+        )
 
     def __getitem__(self, index) -> tuple[torch.Tensor, torch.Tensor]:
         return self.song_dataset[index]
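With the new argument, callers that already know their clips' lengths (as Music4DanceDataset does above for its 30-second previews) can build a SongDataset without any upfront metadata reads. A hypothetical usage sketch; the paths and labels are placeholders:

song_paths = ["songs/track_01.wav", "songs/track_02.wav"]  # placeholder paths
dance_labels = [0, 1]  # placeholder labels

dataset = SongDataset(
    song_paths,
    dance_labels,
    audio_durations=[30.0] * len(song_paths),  # known 30-second clips
)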
@@ -306,7 +317,7 @@ class DanceDataModule(pl.LightningDataModule):
             self.train_ds,
             batch_size=self.batch_size,
             num_workers=self.num_workers,
-            shuffle=True,
+            shuffle=False,
         )
 
     def val_dataloader(self):
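Finally, note that the shuffle flag flipped here belongs to train_dataloader, so training batches will now be drawn in dataset order every epoch. If the intent is reproducibility rather than strictly ordered batches, PyTorch's DataLoader also accepts a seeded generator, which keeps shuffling deterministic across runs. A sketch under that assumption, with placeholder values standing in for the module's attributes:

import torch
from torch.utils.data import DataLoader, TensorDataset

train_ds = TensorDataset(torch.randn(8, 4), torch.zeros(8))  # toy stand-in dataset
generator = torch.Generator().manual_seed(42)  # fixed seed for repeatable epochs

loader = DataLoader(
    train_ds,        # stands in for self.train_ds
    batch_size=2,    # stands in for self.batch_size
    num_workers=0,   # stands in for self.num_workers
    shuffle=True,
    generator=generator,
)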