MarcusSu1216 committed on
Commit
50caebe
1 Parent(s): 7be9abd

Update data_utils.py

Browse files
Files changed (1) hide show
  1. data_utils.py +2 -15
data_utils.py CHANGED
@@ -23,7 +23,7 @@ class TextAudioSpeakerLoader(torch.utils.data.Dataset):
23
  3) computes spectrograms from audio files.
24
  """
25
 
26
- def __init__(self, audiopaths, hparams, all_in_mem: bool = False):
27
  self.audiopaths = load_filepaths_and_text(audiopaths)
28
  self.max_wav_value = hparams.data.max_wav_value
29
  self.sampling_rate = hparams.data.sampling_rate
@@ -37,10 +37,6 @@ class TextAudioSpeakerLoader(torch.utils.data.Dataset):
37
 
38
  random.seed(1234)
39
  random.shuffle(self.audiopaths)
40
-
41
- self.all_in_mem = all_in_mem
42
- if self.all_in_mem:
43
- self.cache = [self.get_audio(p[0]) for p in self.audiopaths]
44
 
45
  def get_audio(self, filename):
46
  filename = filename.replace("\\", "/")
@@ -51,8 +47,6 @@ class TextAudioSpeakerLoader(torch.utils.data.Dataset):
51
  audio_norm = audio / self.max_wav_value
52
  audio_norm = audio_norm.unsqueeze(0)
53
  spec_filename = filename.replace(".wav", ".spec.pt")
54
-
55
- # Ideally, all data generated after Mar 25 should have .spec.pt
56
  if os.path.exists(spec_filename):
57
  spec = torch.load(spec_filename)
58
  else:
@@ -79,10 +73,6 @@ class TextAudioSpeakerLoader(torch.utils.data.Dataset):
79
  assert abs(audio_norm.shape[1]-lmin * self.hop_length) < 3 * self.hop_length
80
  spec, c, f0, uv = spec[:, :lmin], c[:, :lmin], f0[:lmin], uv[:lmin]
81
  audio_norm = audio_norm[:, :lmin * self.hop_length]
82
-
83
- return c, f0, spec, audio_norm, spk, uv
84
-
85
- def random_slice(self, c, f0, spec, audio_norm, spk, uv):
86
  # if spec.shape[1] < 30:
87
  # print("skip too short audio:", filename)
88
  # return None
@@ -95,10 +85,7 @@ class TextAudioSpeakerLoader(torch.utils.data.Dataset):
95
  return c, f0, spec, audio_norm, spk, uv
96
 
97
  def __getitem__(self, index):
98
- if self.all_in_mem:
99
- return self.random_slice(*self.cache[index])
100
- else:
101
- return self.random_slice(*self.get_audio(self.audiopaths[index][0]))
102
 
103
  def __len__(self):
104
  return len(self.audiopaths)
 
23
  3) computes spectrograms from audio files.
24
  """
25
 
26
+ def __init__(self, audiopaths, hparams):
27
  self.audiopaths = load_filepaths_and_text(audiopaths)
28
  self.max_wav_value = hparams.data.max_wav_value
29
  self.sampling_rate = hparams.data.sampling_rate
 
37
 
38
  random.seed(1234)
39
  random.shuffle(self.audiopaths)
 
 
 
 
40
 
41
  def get_audio(self, filename):
42
  filename = filename.replace("\\", "/")
 
47
  audio_norm = audio / self.max_wav_value
48
  audio_norm = audio_norm.unsqueeze(0)
49
  spec_filename = filename.replace(".wav", ".spec.pt")
 
 
50
  if os.path.exists(spec_filename):
51
  spec = torch.load(spec_filename)
52
  else:
 
73
  assert abs(audio_norm.shape[1]-lmin * self.hop_length) < 3 * self.hop_length
74
  spec, c, f0, uv = spec[:, :lmin], c[:, :lmin], f0[:lmin], uv[:lmin]
75
  audio_norm = audio_norm[:, :lmin * self.hop_length]
 
 
 
 
76
  # if spec.shape[1] < 30:
77
  # print("skip too short audio:", filename)
78
  # return None
 
85
  return c, f0, spec, audio_norm, spk, uv
86
 
87
  def __getitem__(self, index):
88
+ return self.get_audio(self.audiopaths[index][0])
 
 
 
89
 
90
  def __len__(self):
91
  return len(self.audiopaths)