Spaces:
Runtime error
Runtime error
#!/usr/bin/env python3 | |
# Copyright (c) Facebook, Inc. and its affiliates. | |
# | |
# This source code is licensed under the MIT license found in the | |
# LICENSE file in the root directory of this source tree. | |
from __future__ import absolute_import, division, print_function, unicode_literals | |
import argparse | |
import concurrent.futures | |
import json | |
import multiprocessing | |
import os | |
from collections import namedtuple | |
from itertools import chain | |
import sentencepiece as spm | |
from fairseq.data import Dictionary | |
# Seconds per millisecond; dividing a duration in seconds by this yields ms.
MILLISECONDS_TO_SECONDS = 0.001


def process_sample(aud_path, lable, utt_id, sp, tgt_dict):
    """Build the manifest entry for a single utterance.

    Args:
        aud_path: path of the audio file; probed with ``torchaudio.info`` and
            stored verbatim under ``input.path``.
        lable: transcript of the utterance (the misspelled parameter name is
            kept so existing keyword callers keep working).
        utt_id: utterance id; used as the single key of the returned dict.
        sp: loaded SentencePiece processor used to split ``lable`` into pieces.
        tgt_dict: fairseq ``Dictionary`` that maps the pieces to ids.

    Returns:
        ``{utt_id: {"input": {...}, "output": {...}}}`` where ``input`` holds
        ``length_ms`` (audio duration truncated to whole milliseconds) and
        ``path``, and ``output`` holds ``text`` (the transcript), ``token``
        (space-joined pieces) and ``tokenid`` (comma-separated piece ids).
    """
    import torchaudio

    info = torchaudio.info(aud_path)
    if hasattr(info, "num_frames"):
        # Current torchaudio returns an AudioMetaData object whose num_frames
        # is a per-channel frame count; it cannot be unpacked as (si, ei).
        num_frames = info.num_frames
        sample_rate = info.sample_rate
    else:
        # Legacy sox backend returns (signal_info, encoding_info); its
        # ``length`` counts samples across all channels.
        signal_info, _ = info
        num_frames = signal_info.length / signal_info.channels
        sample_rate = signal_info.rate

    audio = {
        "length_ms": int(num_frames / sample_rate / MILLISECONDS_TO_SECONDS),
        "path": aud_path,
    }

    token = " ".join(sp.EncodeAsPieces(lable))
    ids = tgt_dict.encode_line(token, append_eos=False)
    transcript = {
        "text": lable,
        "token": token,
        "tokenid": ", ".join(str(t.tolist()) for t in ids),
    }
    return {utt_id: {"input": audio, "output": transcript}}
_Sample = namedtuple("Sample", "aud_path utt_id")


def _read_labels(label_file):
    """Parse ``<ID> <LABEL>`` lines from ``label_file`` into ``{id: label}``.

    The line terminator is removed from each label (otherwise it would end up
    in the manifest's ``text`` field) and blank lines are skipped. A line
    with an id but no separating space raises ``ValueError``. A repeated id
    keeps the last label seen.
    """
    source = getattr(label_file, "name", "<labels>")
    labels = {}
    for line_no, raw_line in enumerate(label_file, 1):
        line = raw_line.rstrip("\r\n")
        if not line.strip():
            continue
        parts = line.split(" ", 1)
        if len(parts) != 2:
            raise ValueError(
                "Expected '<ID> <LABEL>' on line {} of {}, got: {!r}".format(
                    line_no, source, line
                )
            )
        utt_id, label = parts
        labels[utt_id] = label
    return labels


def _collect_samples(audio_dirs, audio_format, labels):
    """Recursively find ``<utt_id>.<audio_format>`` files that have a label.

    Files whose extension is not exactly ``.<audio_format>`` or whose stem is
    not a key of ``labels`` are ignored.
    """
    suffix = "." + audio_format
    samples = []
    for root in audio_dirs:
        for dirpath, _, files in os.walk(root):
            for name in files:
                utt_id, ext = os.path.splitext(name)
                if ext != suffix or utt_id not in labels:
                    continue
                samples.append(_Sample(os.path.join(dirpath, name), utt_id))
    return samples


def main():
    """Write a JSON manifest ``{"utts": {utt_id: entry}}`` for labelled audio.

    Each entry is produced by ``process_sample`` in a thread pool. A sample
    whose processing raises is reported on stderr and left out of the
    manifest.
    """
    import sys

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--audio-dirs",
        nargs="+",
        required=True,
        help="input directories with audio files",
    )
    parser.add_argument(
        "--labels",
        required=True,
        help="aggregated input labels with format <ID LABEL> per line",
        type=argparse.FileType("r", encoding="UTF-8"),
    )
    parser.add_argument(
        "--spm-model",
        required=True,
        help="sentencepiece model to use for encoding",
        type=argparse.FileType("r", encoding="UTF-8"),
    )
    parser.add_argument(
        "--dictionary",
        required=True,
        help="file to load fairseq dictionary from",
        type=argparse.FileType("r", encoding="UTF-8"),
    )
    parser.add_argument("--audio-format", choices=["flac", "wav"], default="wav")
    parser.add_argument(
        "--output",
        required=True,
        type=argparse.FileType("w"),
        help="path to save json output",
    )
    args = parser.parse_args()

    sp = spm.SentencePieceProcessor()
    sp.Load(args.spm_model.name)
    tgt_dict = Dictionary.load(args.dictionary)

    labels = _read_labels(args.labels)
    if not labels:
        # The old message referenced a nonexistent ``args.labels_path``.
        raise Exception("No labels found in {}".format(args.labels.name))

    samples = _collect_samples(args.audio_dirs, args.audio_format, labels)

    utts = {}
    num_cpu = multiprocessing.cpu_count()
    with concurrent.futures.ThreadPoolExecutor(max_workers=num_cpu) as executor:
        future_to_sample = {
            executor.submit(
                process_sample, s.aud_path, labels[s.utt_id], s.utt_id, sp, tgt_dict
            ): s
            for s in samples
        }
        for future in concurrent.futures.as_completed(future_to_sample):
            try:
                data = future.result()
            except Exception as exc:
                # Best-effort: report which file failed and keep going.
                print(
                    "{} generated an exception: {}".format(
                        future_to_sample[future].aud_path, exc
                    ),
                    file=sys.stderr,
                )
            else:
                utts.update(data)

    # Close explicitly so a failing final flush raises instead of being
    # silently ignored at interpreter shutdown.
    with args.output as out:
        json.dump({"utts": utts}, out, indent=4)


if __name__ == "__main__":
    main()