kevinwang676's picture
Duplicate from zlc99/M4Singer
26925fd
raw
history blame
4.75 kB
# -*- coding: utf-8 -*-
# Copyright 2019 Tomoki Hayashi
# MIT License (https://opensource.org/licenses/MIT)
"""Utility functions."""
import fnmatch
import logging
import os
import sys
import h5py
import numpy as np
def find_files(root_dir, query="*.wav", include_root_dir=True):
    """Find files recursively.

    Walks ``root_dir`` (following symbolic links) and collects every file
    whose base name matches the ``fnmatch``-style pattern ``query``.

    Args:
        root_dir (str): Root directory to search.
        query (str): Filename pattern to match (e.g. ``"*.wav"``).
        include_root_dir (bool): If False, the returned paths are relative to
            ``root_dir`` instead of being prefixed with it.

    Returns:
        list: List of found filenames, in ``os.walk`` traversal order.

    """
    files = []
    for root, _, filenames in os.walk(root_dir, followlinks=True):
        files.extend(
            os.path.join(root, filename)
            for filename in fnmatch.filter(filenames, query)
        )
    if not include_root_dir:
        # Use relpath rather than str.replace: replace() would also remove
        # later occurrences of "<root_dir>/" inside the path, silently fails
        # when root_dir carries a trailing separator, and assumes "/".
        files = [os.path.relpath(file_, root_dir) for file_ in files]
    return files
def read_hdf5(hdf5_name, hdf5_path):
    """Read hdf5 dataset.

    The whole dataset is loaded into memory and the file is closed before
    returning, including when the requested dataset is missing or the read
    itself fails.

    Args:
        hdf5_name (str): Filename of hdf5 file.
        hdf5_path (str): Dataset name in hdf5 file.

    Return:
        any: Dataset values.

    Raises:
        SystemExit: With exit code 1 if the file does not exist or the
            dataset is not present in the file (an error is logged first).

    """
    if not os.path.exists(hdf5_name):
        logging.error(f"There is no such a hdf5 file ({hdf5_name}).")
        sys.exit(1)

    # The context manager guarantees the handle is released on every exit
    # path, including the sys.exit() below.
    with h5py.File(hdf5_name, "r") as hdf5_file:
        if hdf5_path not in hdf5_file:
            logging.error(f"There is no such a data in hdf5 file. ({hdf5_path})")
            sys.exit(1)

        # Indexing with an empty tuple reads the full dataset contents.
        return hdf5_file[hdf5_path][()]
def write_hdf5(hdf5_name, hdf5_path, write_data, is_overwrite=True):
    """Write dataset to hdf5.

    Creates the parent directory and the hdf5 file if they do not exist yet.
    If the file already exists it is opened for update so that other datasets
    stored in it are preserved.

    Args:
        hdf5_name (str): Hdf5 dataset filename.
        hdf5_path (str): Dataset path in hdf5.
        write_data (ndarray): Data to write.
        is_overwrite (bool): Whether to overwrite dataset.

    Raises:
        SystemExit: With exit code 1 if the dataset already exists and
            ``is_overwrite`` is False (an error is logged first).

    """
    # convert to numpy array
    write_data = np.array(write_data)

    # check folder existence
    folder_name = os.path.dirname(hdf5_name)
    if folder_name and not os.path.exists(folder_name):
        # exist_ok guards against another process creating it concurrently
        os.makedirs(folder_name, exist_ok=True)

    # open with r+ if the file already exists (keep its other datasets),
    # otherwise create it with w
    mode = "r+" if os.path.exists(hdf5_name) else "w"

    # The context manager closes the file on every exit path, including
    # sys.exit() and exceptions raised while writing.
    with h5py.File(hdf5_name, mode) as hdf5_file:
        # check dataset existence
        if hdf5_path in hdf5_file:
            if not is_overwrite:
                logging.error("Dataset in hdf5 file already exists. "
                              "if you want to overwrite, please set is_overwrite = True.")
                sys.exit(1)
            logging.warning("Dataset in hdf5 file already exists. "
                            "recreate dataset in hdf5.")
            del hdf5_file[hdf5_path]

        # write data to hdf5
        hdf5_file.create_dataset(hdf5_path, data=write_data)
        hdf5_file.flush()
class HDF5ScpLoader(object):
    """Loader class for a feats.scp file of hdf5 file.

    Each non-blank line of the scp file is ``<key> <value>``, where the value
    is either ``<hdf5 file>:<dataset path>`` or just ``<hdf5 file>``; in the
    latter case ``default_hdf5_path`` is used as the dataset path.

    Examples:
        key1 /some/path/a.h5:feats
        key2 /some/path/b.h5:feats
        key3 /some/path/c.h5:feats
        key4 /some/path/d.h5:feats
        ...
        >>> loader = HDF5ScpLoader("hdf5.scp")
        >>> array = loader["key1"]

        key1 /some/path/a.h5
        key2 /some/path/b.h5
        key3 /some/path/c.h5
        key4 /some/path/d.h5
        ...
        >>> loader = HDF5ScpLoader("hdf5.scp", "feats")
        >>> array = loader["key1"]

    """

    def __init__(self, feats_scp, default_hdf5_path="feats"):
        """Initialize HDF5 scp loader.

        Args:
            feats_scp (str): Kaldi-style feats.scp file with hdf5 format.
            default_hdf5_path (str): Path in hdf5 file. If the scp contain the info, not used.

        Raises:
            ValueError: If a non-blank line does not contain both a key and
                a value.

        """
        self.default_hdf5_path = default_hdf5_path
        self.data = {}
        with open(feats_scp) as f:
            for line in f:
                line = line.strip()
                if not line:
                    # tolerate blank / trailing empty lines
                    continue
                # Split on the first whitespace run only, so values that
                # themselves contain spaces are kept intact.
                fields = line.split(maxsplit=1)
                if len(fields) != 2:
                    raise ValueError(
                        f"Invalid line in scp file ({feats_scp}): {line!r}"
                    )
                key, value = fields
                self.data[key] = value

    def get_path(self, key):
        """Get hdf5 file path for a given key."""
        return self.data[key]

    def __getitem__(self, key):
        """Get ndarray for a given key."""
        p = self.data[key]
        if ":" in p:
            # Split at the last colon only, so a file path that itself
            # contains colons still yields exactly (filename, dataset path).
            hdf5_name, hdf5_path = p.rsplit(":", 1)
            return read_hdf5(hdf5_name, hdf5_path)
        return read_hdf5(p, self.default_hdf5_path)

    def __len__(self):
        """Return the length of the scp file."""
        return len(self.data)

    def __iter__(self):
        """Return the iterator of the scp file."""
        return iter(self.data)

    def keys(self):
        """Return the keys of the scp file."""
        return self.data.keys()