# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import argparse | |
import itertools | |
import logging | |
import re | |
import time | |
from g2p_en import G2p | |
# Module-level logger, named after this module per the stdlib logging convention.
logger = logging.getLogger(__name__)
# Sentinel token standing in for a sentence that failed conversion; it is
# checked for in process_sent and preserved by remove_punc.
FAIL_SENT = "FAILED_SENTENCE"
def parse():
    """Build and parse the command-line arguments for the g2p pipeline."""
    ap = argparse.ArgumentParser()
    # required I/O paths
    ap.add_argument("--data-path", type=str, required=True)
    ap.add_argument("--out-path", type=str, required=True)
    # text normalization and phoneme post-processing switches
    ap.add_argument("--lower-case", action="store_true")
    ap.add_argument("--do-filter", action="store_true")
    ap.add_argument("--use-word-start", action="store_true")
    ap.add_argument("--dup-vowel", default=1, type=int)
    ap.add_argument("--dup-consonant", default=1, type=int)
    ap.add_argument("--no-punc", action="store_true")
    ap.add_argument("--reserve-word", type=str, default="")
    ap.add_argument(
        "--reserve-first-column",
        action="store_true",
        help="first column is sentence id",
    )
    # parallel execution settings
    ap.add_argument("--parallel-process-num", default=1, type=int)
    ap.add_argument("--logdir", default="")
    return ap.parse_args()
def process_sent(sent, g2p, res_wrds, args):
    """Convert one raw sentence into a space-joined phoneme string.

    The sentence is pre-processed, each segment is run through g2p
    (reserved words bypass g2p), and the configured post-processing
    steps (punctuation removal, phoneme duplication, word-start
    markers) are applied in order.
    """
    segments = pre_process_sent(sent, args.do_filter, args.lower_case, res_wrds)
    per_segment = [
        do_g2p(g2p, seg, res_wrds, idx == 0) for idx, seg in enumerate(segments)
    ]
    # one failed segment marks the whole sentence as failed
    if [FAIL_SENT] in per_segment:
        phonemes = [FAIL_SENT]
    else:
        phonemes = list(itertools.chain.from_iterable(per_segment))
    if args.no_punc:
        phonemes = remove_punc(phonemes)
    if args.dup_vowel > 1 or args.dup_consonant > 1:
        phonemes = dup_pho(phonemes, args.dup_vowel, args.dup_consonant)
    if args.use_word_start:
        phonemes = add_word_start(phonemes)
    return " ".join(phonemes)
def remove_punc(sent):
    """Drop punctuation tokens and collapse leading/doubled space tokens.

    Tokens containing any character outside [a-zA-Z0-9 ] are removed,
    except the FAIL_SENT sentinel, which is always kept.
    """
    punc = re.compile("[^a-zA-Z0-9 ]")
    out = []
    for tok in sent:
        if punc.search(tok) and tok != FAIL_SENT:
            continue  # punctuation token: skip it
        # never emit a space first, nor two spaces in a row
        if tok == " " and (not out or out[-1] == " "):
            continue
        out.append(tok)
    return out
def do_g2p(g2p, sent, res_wrds, is_first_sent):
    """Return the phoneme list for one segment.

    A segment that is itself a reserved word maps directly to its stored
    phoneme string instead of going through g2p. Every segment after the
    first gets a leading " " token so words stay separated when joined.
    """
    pho_seq = [res_wrds[sent]] if sent in res_wrds else g2p(sent)
    return pho_seq if is_first_sent else [" "] + pho_seq
def pre_process_sent(sent, do_filter, lower_case, res_wrds):
    """Split a sentence into segments, isolating reserved words.

    Reserved words become their own segment so they can bypass g2p;
    optional filtering replaces hyphens/em-dashes with spaces, and
    optional lowercasing skips reserved segments.
    """
    if do_filter:
        sent = sent.replace("-", " ").replace("—", " ")
    if res_wrds:
        marked = []
        for w in sent.split():
            marked.append("SPLIT_ME " + w + " SPLIT_ME" if w in res_wrds else w)
        pieces = " ".join(marked).split("SPLIT_ME")
        segments = [p.strip() for p in pieces if p.strip() != ""]
    else:
        segments = [sent]
    if lower_case:
        # reserved segments keep their original casing
        segments = [s if s in res_wrds else s.lower() for s in segments]
    return segments
def dup_pho(sent, dup_v_num, dup_c_num):
    """
    duplicate phoneme defined as cmudict
    http://www.speech.cs.cmu.edu/cgi-bin/cmudict
    """
    if dup_v_num == 1 and dup_c_num == 1:
        return sent
    # cmudict vowels end with a stress digit; anything else with a word
    # character is treated as a consonant, the rest (e.g. spaces) untouched
    vowel_re = re.compile(r"\d$")
    word_re = re.compile(r"\w")
    out = []
    for pho in sent:
        out.append(pho)
        if vowel_re.search(pho):
            copies = dup_v_num
        elif word_re.search(pho):
            copies = dup_c_num
        else:
            continue
        out.extend(f"{pho}-{k}P" for k in range(1, copies))
    return out
def add_word_start(sent):
    """Prefix each word-initial phoneme with "▁" and drop space separators.

    A space token resets the word-start flag and is not emitted (unless it
    was itself word-initial, in which case the marked token is kept, matching
    the original marking-before-checking order).
    """
    out = []
    pending = True  # next token starts a word
    for tok in sent:
        cur = ("▁" + tok) if pending else tok
        pending = cur == " "
        if not pending:
            out.append(cur)
    return out
def load_reserve_word(reserve_word):
    """Load the reserved-word -> phoneme-string mapping.

    Args:
        reserve_word: path to a two-column text file (word, phoneme string
            per non-empty line), or "" for no reserved words.

    Returns:
        dict mapping each reserved word to its phoneme string. An empty
        dict is returned when no file is given (previously an empty list;
        callers only use ``len`` and membership tests, so a dict is both
        consistent and backward-compatible).

    Raises:
        ValueError: if a non-empty line does not have exactly two columns.
    """
    if reserve_word == "":
        return {}
    with open(reserve_word, "r") as fp:
        res_wrds = [x.strip().split() for x in fp.readlines() if x.strip() != ""]
    # validate with a real exception: `assert` is stripped under `python -O`
    bad = [x for x in res_wrds if len(x) != 2]
    if bad:
        raise ValueError(
            f"expected two columns per line in {reserve_word}, got: {bad[0]}"
        )
    return dict(res_wrds)
def process_sents(sents, args):
    """Run g2p over a list of sentence lines and return the converted lines.

    With --reserve-first-column, the first whitespace-separated field of
    each line is treated as a sentence id and re-attached to the output.
    """
    g2p = G2p()
    res_wrds = load_reserve_word(args.reserve_word)
    converted = []
    for line in sents:
        sent_id = ""
        if args.reserve_first_column:
            sent_id, line = line.split(None, 1)
        result = process_sent(line, g2p, res_wrds, args)
        if args.reserve_first_column and sent_id != "":
            result = f"{sent_id} {result}"
        converted.append(result)
    return converted
def main():
    """Read sentences from --data-path, convert to phonemes, write --out-path.

    When --parallel-process-num > 1 and submitit is importable, the work is
    split into chunks and fanned out as submitit jobs; otherwise everything
    is processed in-process.
    """
    args = parse()
    with open(args.data_path, "r") as fp:
        sent_list = [x.strip() for x in fp.readlines()]

    # Initialize explicitly: the original relied on short-circuit evaluation
    # to avoid referencing an undefined name when parallel_process_num == 1.
    submitit = None
    if args.parallel_process_num > 1:
        try:
            import submitit
        except ImportError:
            # logger.warn is deprecated; fall back to a single local job
            logger.warning(
                "submitit is not found and only one job is used to process the data"
            )

    if args.parallel_process_num == 1 or submitit is None:
        out_sents = process_sents(sent_list, args)
    else:
        # split the sentence list into roughly equal chunks, one per job
        chunk = len(sent_list) // args.parallel_process_num + 1
        executor = submitit.AutoExecutor(folder=args.logdir)
        executor.update_parameters(timeout_min=1000, cpus_per_task=4)
        jobs = [
            executor.submit(
                process_sents, sent_list[chunk * i : chunk * (i + 1)], args
            )
            for i in range(args.parallel_process_num)
        ]
        # poll until every job reports done, then gather results in order
        while not all(job.done() for job in jobs):
            time.sleep(5)
        out_sents = list(
            itertools.chain.from_iterable(job.result() for job in jobs)
        )

    with open(args.out_path, "w") as fp:
        fp.write("\n".join(out_sents) + "\n")
# Script entry point: only run main() when executed directly, not on import.
if __name__ == "__main__":
    main()