# Spaces: Running  (page-header text captured along with this source; not part of the program)
import json
import os
from collections import defaultdict
from random import shuffle
from typing import Optional

import click
import torch
from tqdm import tqdm

from text.cleaner import clean_text_bert
from text.symbols import symbols, num_languages, num_tones
def _bert_path_for(utt: str) -> str:
    """Return the ``.bert.pt`` path used to store BERT features for *utt*.

    Keeps the original naming rule (every ``.wav`` substring is replaced by
    ``.bert.pt``). Raises ``ValueError`` when *utt* contains no ``.wav``: the
    replacement would then be a no-op and ``torch.save`` would overwrite the
    audio file itself.
    """
    bert_path = utt.replace(".wav", ".bert.pt")
    if bert_path == utt:
        raise ValueError(f"utterance path has no '.wav' component: {utt!r}")
    return bert_path


def _clean_metadata(metadata: str, cleaned_path: str, device: str) -> None:
    """Clean every ``utt|spk|language|text`` line of *metadata*.

    For each line, runs ``clean_text_bert`` on *device*, saves the returned
    BERT features with ``torch.save`` and appends
    ``utt|spk|language|norm_text|phones|tones|word2ph`` to *cleaned_path*.
    Failing lines are printed and skipped (deliberate best effort).
    Phones missing from ``symbols`` are collected and dumped to
    ``{language}_symbol.txt`` in the current working directory.
    """
    known_symbols = set(symbols)  # hoisted: O(1) membership test per phone
    new_symbols = []
    with open(metadata, encoding="utf-8") as src:
        lines = src.readlines()
    with open(cleaned_path, "w", encoding="utf-8") as out_file:
        for line in tqdm(lines):
            try:
                utt, spk, language, text = line.strip().split("|")
                norm_text, phones, tones, word2ph, bert = clean_text_bert(
                    text, language, device=device
                )
                for ph in phones:
                    if ph not in known_symbols and ph not in new_symbols:
                        new_symbols.append(ph)
                        print('update!, now symbols:')
                        print(new_symbols)
                        # NOTE(review): new_symbols accumulates across all
                        # languages, but the file is named after the current
                        # line's language — confirm this is intended.
                        with open(f'{language}_symbol.txt', 'w', encoding="utf-8") as f:
                            f.write(f'{new_symbols}')
                # Explicit checks rather than `assert`, which `python -O` strips.
                if len(phones) != len(tones):
                    raise ValueError(
                        f"len(phones)={len(phones)} != len(tones)={len(tones)}"
                    )
                if len(phones) != sum(word2ph):
                    raise ValueError(
                        f"len(phones)={len(phones)} != sum(word2ph)={sum(word2ph)}"
                    )
                # Save the features BEFORE writing the cleaned line, so the
                # cleaned list never references an utterance without features.
                bert_path = _bert_path_for(utt)
                bert_dir = os.path.dirname(bert_path)
                if bert_dir:  # os.makedirs("") raises for bare file names
                    os.makedirs(bert_dir, exist_ok=True)
                torch.save(bert.cpu(), bert_path)
                out_file.write(
                    "{}|{}|{}|{}|{}|{}|{}\n".format(
                        utt,
                        spk,
                        language,
                        norm_text,
                        " ".join(phones),
                        " ".join([str(i) for i in tones]),
                        " ".join([str(i) for i in word2ph]),
                    )
                )
            except Exception as error:
                # Deliberate best effort: report the offending line and move on.
                print("err!", line, error)


def _split_train_val(metadata: str, val_per_spk: int, max_val_total: int):
    """Group cleaned 7-field lines by speaker and split them into train/val.

    Returns ``(spk_id_map, train_list, val_list)``. Speaker ids are assigned
    in order of first appearance. Each speaker contributes up to
    *val_per_spk* randomly chosen lines to validation; if the total exceeds
    *max_val_total*, the surplus is moved back into training. Raises
    ``ValueError`` for a non-blank line that does not have 7 fields.
    """
    spk_utt_map = defaultdict(list)
    spk_id_map = {}
    with open(metadata, encoding="utf-8") as f:
        for line in f:
            if not line.strip():
                continue  # tolerate blank / trailing empty lines
            fields = line.strip().split("|")
            if len(fields) != 7:
                raise ValueError(
                    f"expected 7 '|'-separated fields in {metadata!r}, "
                    f"got {len(fields)}: {line!r}"
                )
            spk = fields[1]
            # An unterminated last line would otherwise be glued to whatever
            # line follows it once the lists are shuffled and concatenated.
            if not line.endswith("\n"):
                line += "\n"
            spk_utt_map[spk].append(line)
            if spk not in spk_id_map:
                spk_id_map[spk] = len(spk_id_map)

    train_list = []
    val_list = []
    for utts in spk_utt_map.values():
        shuffle(utts)
        val_list += utts[:val_per_spk]
        train_list += utts[val_per_spk:]
    if len(val_list) > max_val_total:
        train_list += val_list[max_val_total:]
        val_list = val_list[:max_val_total]
    return spk_id_map, train_list, val_list


def main(
    metadata: str,
    cleaned_path: Optional[str],
    train_path: Optional[str],
    val_path: Optional[str],
    config_path: str,
    val_per_spk: int,
    max_val_total: int,
    clean: bool,
    device: str = "cuda:0",
):
    """Build train/val file lists and an updated config from a transcription file.

    Args:
        metadata: ``utt|spk|language|text`` transcription file; with
            ``clean=False`` it must already be in the cleaned 7-field format.
        cleaned_path: Output of the cleaning step; defaults to
            ``metadata + ".cleaned"``.
        train_path: Output train list; defaults to ``train.list`` next to
            *metadata*.
        val_path: Output validation list; defaults to ``val.list`` next to
            *metadata*.
        config_path: JSON config to read. The updated copy is written to
            ``config.json`` next to *metadata*.
        val_per_spk: Maximum validation lines drawn per speaker.
        max_val_total: Cap on the total number of validation lines.
        clean: Run text cleaning / BERT feature extraction before splitting.
        device: Device passed to ``clean_text_bert`` (was hard-coded to
            ``"cuda:0"``; kept as the default).
    """
    metadata_dir = os.path.dirname(metadata)
    if train_path is None:
        train_path = os.path.join(metadata_dir, 'train.list')
    if val_path is None:
        val_path = os.path.join(metadata_dir, 'val.list')
    out_config_path = os.path.join(metadata_dir, 'config.json')
    if cleaned_path is None:
        cleaned_path = metadata + ".cleaned"

    if clean:
        _clean_metadata(metadata, cleaned_path, device)
        metadata = cleaned_path

    spk_id_map, train_list, val_list = _split_train_val(
        metadata, val_per_spk, max_val_total
    )

    with open(train_path, "w", encoding="utf-8") as f:
        f.writelines(train_list)
    with open(val_path, "w", encoding="utf-8") as f:
        f.writelines(val_list)

    with open(config_path, encoding="utf-8") as f:
        config = json.load(f)
    config["data"]["spk2id"] = spk_id_map
    config["data"]["training_files"] = train_path
    config["data"]["validation_files"] = val_path
    config["data"]["n_speakers"] = len(spk_id_map)
    config["num_languages"] = num_languages
    config["num_tones"] = num_tones
    config["symbols"] = symbols
    with open(out_config_path, "w", encoding="utf-8") as f:
        json.dump(config, f, indent=2, ensure_ascii=False)
if __name__ == "__main__":
    # `main` is a plain function in this file (it carries no click
    # decorators), so a bare `main()` call raises TypeError for its missing
    # arguments. Build the click command around it here instead of decorating
    # it, so `main` stays callable with keyword arguments when imported.
    _cli_options = [
        click.option(
            "--metadata",
            required=True,
            type=click.Path(exists=True, file_okay=True, dir_okay=False),
        ),
        click.option("--cleaned-path", default=None),
        click.option("--train-path", default=None),
        click.option("--val-path", default=None),
        click.option(
            "--config-path",
            required=True,
            type=click.Path(exists=True, file_okay=True, dir_okay=False),
        ),
        click.option("--val-per-spk", default=4, show_default=True),
        click.option("--max-val-total", default=8, show_default=True),
        click.option("--clean/--no-clean", default=True, show_default=True),
    ]
    _command = main
    # Applied in reverse so that --help lists the options in the order above.
    for _option in reversed(_cli_options):
        _command = _option(_command)
    click.command()(_command)()