Spaces:
Runtime error
Runtime error
#!/usr/bin/env python3 | |
# Copyright (c) Facebook, Inc. and its affiliates. | |
# | |
# This source code is licensed under the MIT license found in the | |
# LICENSE file in the root directory of this source tree. | |
""" | |
Helper script to pre-compute embeddings for a flashlight (previously called wav2letter++) dataset | |
""" | |
import argparse | |
import glob | |
import os | |
from shutil import copy | |
import h5py | |
import numpy as np | |
import soundfile as sf | |
import torch | |
import tqdm | |
import fairseq | |
from torch import nn | |
def read_audio(fname):
    """Read an audio file and return its samples together with the sample rate.

    Args:
        fname: Path of the audio file; handed to ``sf.read``.

    Returns:
        A ``(samples, 16e3)`` tuple: the decoded audio and the fixed sample
        rate as a float.

    Raises:
        AssertionError: If the file's sample rate is not 16 kHz.
    """
    samples, sample_rate = sf.read(fname)
    # Only 16 kHz input is accepted; anything else aborts here.
    assert sample_rate == 16e3
    return samples, 16e3
class PretrainedWav2VecModel(nn.Module):
    """Inference-only wrapper around a checkpoint loaded through fairseq."""

    def __init__(self, fname):
        """Load the checkpoint at ``fname`` and keep its first model in eval mode.

        Args:
            fname: Path to the model checkpoint.
        """
        super().__init__()
        # The loader returns (models, cfg, task); only the first model is used.
        loaded = fairseq.checkpoint_utils.load_model_ensemble_and_task([fname])
        self.model = loaded[0][0]
        self.model.eval()

    def forward(self, x):
        """Run the feature extractor and the aggregator without gradients.

        Args:
            x: Input passed unchanged to ``self.model.feature_extractor``.

        Returns:
            A ``(z, c)`` tuple: the extractor output and the aggregator output
            computed from it.
        """
        with torch.no_grad():
            extracted = self.model.feature_extractor(x)
            # When the extractor returns a tuple, only its first item is used.
            z = extracted[0] if isinstance(extracted, tuple) else extracted
            c = self.model.feature_aggregator(z)
        return z, c
class EmbeddingWriterConfig(argparse.ArgumentParser):
    """Command-line parser for the embedding pre-computation script.

    Required options: ``--input/-i``, ``--output/-o``, ``--model`` and
    ``--split`` (one or more values). Optional: ``--ext`` (default ``"wav"``),
    ``--no-copy-labels``, ``--use-feat`` and ``--gpu`` (default ``0``).
    """

    def __init__(self):
        # The first positional parameter of ArgumentParser is ``prog``, so the
        # text must be passed by keyword to end up as the description instead
        # of replacing the program name in the usage line.
        super().__init__(
            description="Pre-compute embeddings for flashlight datasets"
        )

        # Shared settings for the mandatory string-valued options.
        kwargs = {"action": "store", "type": str, "required": True}

        self.add_argument("--input", "-i", help="Input Directory", **kwargs)
        self.add_argument("--output", "-o", help="Output Directory", **kwargs)
        self.add_argument("--model", help="Path to model checkpoint", **kwargs)
        self.add_argument("--split", help="Dataset Splits", nargs="+", **kwargs)

        self.add_argument(
            "--ext", default="wav", required=False, help="Audio file extension"
        )

        self.add_argument(
            "--no-copy-labels",
            action="store_true",
            help="Do not copy label files. Useful for large datasets, use --targetdir in flashlight then.",
        )
        self.add_argument(
            "--use-feat",
            action="store_true",
            help="Use the feature vector ('z') instead of context vector ('c') for features",
        )
        self.add_argument("--gpu", help="GPU to use", default=0, type=int)
class Prediction:
    """Callable that turns a numpy waveform into numpy embeddings on a GPU."""

    def __init__(self, fname, gpu=0):
        """Load the checkpoint ``fname`` and move the model to CUDA device ``gpu``."""
        self.gpu = gpu
        self.model = PretrainedWav2VecModel(fname).cuda(gpu)

    def __call__(self, x):
        """Embed one numpy array.

        Args:
            x: Numpy array; converted to a float tensor and given a leading
                batch dimension before being fed to the model.

        Returns:
            A ``(z, c)`` tuple of numpy arrays with the batch dimension
            squeezed out again.
        """
        batch = torch.from_numpy(x).float().cuda(self.gpu).unsqueeze(0)
        with torch.no_grad():
            feats, context = self.model(batch)
        feats_np = feats.squeeze(0).cpu().numpy()
        context_np = context.squeeze(0).cpu().numpy()
        return feats_np, context_np
class H5Writer:
    """Write features as an hdf5 file in flashlight compatible format."""

    def __init__(self, fname):
        """Remember the target file and make sure its directory exists.

        Args:
            fname: Path of the hdf5 file that ``write`` will create.
        """
        self.fname = fname
        # A bare file name has an empty dirname and os.makedirs("") raises
        # FileNotFoundError; the current directory needs no creation.
        parent = os.path.dirname(self.fname)
        if parent:
            os.makedirs(parent, exist_ok=True)

    def write(self, data):
        """Store a 2-D array in ``self.fname``, overwriting any existing file.

        Args:
            data: Array whose shape unpacks as ``(channel, T)``.

        The file receives two datasets: ``"features"`` holds ``data``
        transposed and flattened, and ``"info"`` holds
        ``[16e3 // 160, T, channel]``.
        """
        channel, T = data.shape

        with h5py.File(self.fname, "w") as out_ds:
            out_ds["features"] = data.T.flatten()
            # NOTE(review): 16e3 // 160 == 100.0, presumably the number of
            # feature frames per second -- confirm against the reader.
            out_ds["info"] = np.array([16e3 // 160, T, channel])
class EmbeddingDatasetWriter(object):
    """Given a model and a flashlight dataset, pre-compute and store embeddings.

    Args:
        input_root (str): Path to the flashlight dataset.
        output_root (str): Desired output directory. Will be created if
            non-existent.
        split (str): Dataset split; used as a sub-directory of both roots.
        model_fname (str): Path to the model checkpoint; must exist.
        extension (str): Audio file extension without the leading dot.
        gpu (int): GPU index handed to ``Prediction``.
        verbose (bool): Wrap the feature loop in a tqdm progress bar.
        use_feat (bool): Store the feature vector ('z') instead of the
            context vector ('c').

    Raises:
        AssertionError: If the checkpoint or the input split directory does
            not exist.
    """

    def __init__(
        self,
        input_root,
        output_root,
        split,
        model_fname,
        extension="wav",
        gpu=0,
        verbose=False,
        use_feat=False,
    ):
        assert os.path.exists(model_fname)

        self.model_fname = model_fname
        self.model = Prediction(self.model_fname, gpu)

        self.input_root = input_root
        self.output_root = output_root
        self.split = split
        self.verbose = verbose
        self.extension = extension
        self.use_feat = use_feat

        assert os.path.exists(self.input_path), "Input path '{}' does not exist".format(
            self.input_path
        )

    def _progress(self, iterable, **kwargs):
        """Return ``iterable`` wrapped in tqdm when verbose, else unchanged."""
        if self.verbose:
            return tqdm.tqdm(iterable, **kwargs)
        return iterable

    def require_output_path(self, fname=None):
        """Create the output directory (or ``fname`` below it) if missing."""
        path = self.get_output_path(fname)
        os.makedirs(path, exist_ok=True)

    # The three accessors below are read as plain attributes throughout the
    # class (e.g. ``os.path.exists(self.input_path)``), so they must be
    # properties; as ordinary methods those uses receive a bound method.
    @property
    def input_path(self):
        """Directory of the input split."""
        return self.get_input_path()

    @property
    def output_path(self):
        """Directory of the output split."""
        return self.get_output_path()

    def get_input_path(self, fname=None):
        """Return the input split directory, or ``fname`` joined below it."""
        if fname is None:
            return os.path.join(self.input_root, self.split)
        return os.path.join(self.get_input_path(), fname)

    def get_output_path(self, fname=None):
        """Return the output split directory, or ``fname`` joined below it."""
        if fname is None:
            return os.path.join(self.output_root, self.split)
        return os.path.join(self.get_output_path(), fname)

    def copy_labels(self):
        """Copy every non-audio entry of the input split to the output split."""
        self.require_output_path()

        # Compare the file suffix rather than searching the whole path for the
        # extension: a directory such as ".../wav2vec/..." would otherwise
        # make every label look like an audio file and nothing would be copied.
        audio_suffix = "." + self.extension
        labels = [
            path
            for path in glob.glob(self.get_input_path("*"))
            if not path.endswith(audio_suffix)
        ]
        for fname in tqdm.tqdm(labels):
            copy(fname, self.output_path)

    @property
    def input_fnames(self):
        """Sorted paths of all ``*.<extension>`` files in the input split."""
        return sorted(glob.glob(self.get_input_path("*.{}".format(self.extension))))

    def __len__(self):
        """Number of audio files currently found in the input split."""
        return len(self.input_fnames)

    def write_features(self):
        """Embed every audio file and write one ``.h5context`` file for each."""
        paths = self.input_fnames

        # Output name: input basename with the audio extension swapped out.
        fnames_context = (
            os.path.join(
                self.output_path,
                os.path.basename(path).replace(
                    "." + self.extension, ".h5context"
                ),
            )
            for path in paths
        )

        for name, target_fname in self._progress(
            zip(paths, fnames_context), total=len(paths)
        ):
            wav, _sr = read_audio(name)
            z, c = self.model(wav)
            feat = z if self.use_feat else c
            writer = H5Writer(target_fname)
            writer.write(feat)

    def __repr__(self):
        return "EmbeddingDatasetWriter ({n_files} files)\n\tinput:\t{input_root}\n\toutput:\t{output_root}\n\tsplit:\t{split})".format(
            n_files=len(self), **self.__dict__
        )
if __name__ == "__main__":
    # Parse the command line once, then process each requested split in turn:
    # featurize all audio first, and copy the label files afterwards unless
    # --no-copy-labels was given.
    cli = EmbeddingWriterConfig().parse_args()
    skip_labels = cli.no_copy_labels

    for split in cli.split:
        writer = EmbeddingDatasetWriter(
            model_fname=cli.model,
            input_root=cli.input,
            output_root=cli.output,
            split=split,
            extension=cli.ext,
            gpu=cli.gpu,
            use_feat=cli.use_feat,
        )
        print(writer)

        writer.require_output_path()

        print("Writing Features...")
        writer.write_features()
        print("Done.")

        if skip_labels:
            continue

        print("Copying label data...")
        writer.copy_labels()
        print("Done.")