DeepLearning101's picture
Upload 17 files
109bb65
raw
history blame
2.97 kB
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# author: adefossez
import json
from pathlib import Path
import math
import os
import sys
import torchaudio
from torch.nn import functional as F
def find_audio_files(path, exts=(".wav",), progress=True):
    """Recursively collect audio files under `path` together with their lengths.

    Args:
        path: root directory to walk; symbolic links to directories are
            followed.
        exts: collection of file suffixes (leading dot included) to keep.
            Each file's suffix is lower-cased before the membership test, so
            entries should be lower-case. The default is an immutable tuple
            so it cannot be mutated across calls.
        progress: when true, print a percentage progress indicator to stderr
            (terminated by a carriage return) after each file is inspected.

    Returns:
        A sorted list of `(resolved_file_path, length)` tuples, where
        `length` is the per-channel sample count derived from
        `torchaudio.info`.
    """
    audio_files = []
    for root, _folders, files in os.walk(path, followlinks=True):
        for name in files:
            candidate = Path(root) / name
            if candidate.suffix.lower() in exts:
                audio_files.append(str(candidate.resolve()))

    total = len(audio_files)  # loop invariant, used only for the progress ratio
    meta = []
    for idx, file in enumerate(audio_files):
        info = torchaudio.info(file)
        if hasattr(info, "num_frames"):
            # Newer torchaudio returns a single metadata object whose
            # `num_frames` is the per-channel length.
            length = info.num_frames
        else:
            # Legacy torchaudio returns a `(siginfo, encinfo)` pair.
            # NOTE(review): `siginfo.length` presumably counts samples over
            # all channels, hence the division -- confirm against the
            # torchaudio version in use.
            siginfo, _ = info
            length = siginfo.length // siginfo.channels
        meta.append((file, length))
        if progress:
            print(format((1 + idx) / total, " 3.1%"), end='\r', file=sys.stderr)
    meta.sort()
    return meta
class Audioset:
    """Indexable collection of audio excerpts cut from a list of files.

    Each file contributes a number of examples that depends on its length
    and on the `length` / `stride` / `pad` settings; examples are numbered
    consecutively across files in the order the files are given.
    """

    def __init__(self, files=None, length=None, stride=None,
                 pad=True, with_path=False, sample_rate=None):
        """
        files should be a list [(file, length)]

        Args:
            files: list of `(file, length)` pairs; `None` is treated as an
                empty list.
            length: number of frames per example. When `None`, each file
                yields exactly one example covering the whole file.
            stride: hop, in frames, between consecutive examples of the same
                file. Falls back to `length` when `None` or 0.
            pad: when true, a trailing partial window (or a file shorter
                than `length`) still counts as an example and is padded up
                to `length` on load; when false, such windows are dropped.
            with_path: when true, `__getitem__` returns `(audio, file)`
                instead of just `audio`.
            sample_rate: if not `None`, loading a file with a different
                sample rate raises `RuntimeError`.
        """
        # The declared default is None; iterating over None would crash.
        self.files = [] if files is None else files
        self.num_examples = []
        self.length = length
        self.stride = stride or length
        self.with_path = with_path
        self.sample_rate = sample_rate
        for file, file_length in self.files:
            if length is None:
                examples = 1
            elif file_length < length:
                examples = 1 if pad else 0
            elif pad:
                # ceil: count the last, partially filled window as well.
                examples = int(math.ceil((file_length - self.length) / self.stride) + 1)
            else:
                # floor: only windows that fit entirely inside the file.
                examples = (file_length - self.length) // self.stride + 1
            self.num_examples.append(examples)

    def __len__(self):
        """Return the total number of examples over all files."""
        return sum(self.num_examples)

    def __getitem__(self, index):
        """Load and return the example at position `index`.

        Negative indices count from the end, as for built-in sequences.

        Returns:
            The loaded audio, or `(audio, file)` when `with_path` is set.
            When `length` is set, the audio is padded at the end of its
            last dimension up to `length` frames.

        Raises:
            IndexError: if `index` is outside `[-len(self), len(self))`.
            RuntimeError: if `sample_rate` was given and the file's sample
                rate differs from it.
        """
        requested = index
        if index < 0:
            index += len(self)
            if index < 0:
                raise IndexError(f"Audioset index {requested} out of range")
        for (file, _), examples in zip(self.files, self.num_examples):
            if index >= examples:
                # Not in this file: skip past its examples.
                index -= examples
                continue
            num_frames = 0
            offset = 0
            if self.length is not None:
                offset = self.stride * index
                num_frames = self.length
            out, sr = torchaudio.load(str(file), offset=offset, num_frames=num_frames)
            if self.sample_rate is not None:
                if sr != self.sample_rate:
                    raise RuntimeError(f"Expected {file} to have sample rate of "
                                       f"{self.sample_rate}, but got {sr}")
            if num_frames:
                # A window reaching past the end of the file comes back
                # short; right-pad it to the requested length.
                out = F.pad(out, (0, num_frames - out.shape[-1]))
            if self.with_path:
                return out, file
            else:
                return out
        # Falling through means the index is past the last example. Raising
        # (instead of implicitly returning None) also lets plain iteration
        # over the dataset terminate.
        raise IndexError(f"Audioset index {requested} out of range")
if __name__ == "__main__":
    # Script mode: scan every directory given on the command line and write
    # the combined (file, length) metadata to stdout as indented JSON.
    meta = []
    for path in sys.argv[1:]:
        meta.extend(find_audio_files(path))
    json.dump(meta, sys.stdout, indent=4)