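# ONNX Runtime inference script for an exported VITS-style TTS model: it runs
# the exported graph once per line of the test file and writes 16-bit PCM wav
# files (named after the listed audio paths) into the output directory.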
import argparse
import json
import os
import pathlib
import sys
import time

import numpy as np
import torch
from scipy.io import wavfile

import commons
import utils
from text import text_to_sequence

try:
    import onnxruntime as ort
except ImportError:
    print('Please install onnxruntime!')
    sys.exit(1)


def to_numpy(tensor: torch.Tensor):
    # Detach from the autograd graph (moving to CPU when gradients are
    # attached) before converting to a numpy array.
    return tensor.detach().cpu().numpy() if tensor.requires_grad \
        else tensor.detach().numpy()


def get_args():
    parser = argparse.ArgumentParser(description='inference')
    parser.add_argument('--onnx_model', required=True, help='onnx model')
    parser.add_argument('--cfg', required=True, help='config file')
    parser.add_argument('--outdir', default="onnx_output",
                        help='output directory')
    parser.add_argument('--test_file', required=True, help='test file')
    args = parser.parse_args()
    return args


def get_symbols_from_json(path):
    assert os.path.isfile(path)
    with open(path, 'r') as f:
        data = json.load(f)
    return data['symbols']


def main():
    args = get_args()
    print(args)
    if not pathlib.Path(args.outdir).exists():
        pathlib.Path(args.outdir).mkdir(exist_ok=True, parents=True)

    symbols = get_symbols_from_json(args.cfg)
    # Symbol-to-id mapping built from the config (not used further below).
    phone_dict = {
        symbol: i for i, symbol in enumerate(symbols)
    }
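
    # Load hyper-parameters (text cleaners, add_blank, sampling rate, ...)
    # from the same JSON config.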
    hps = utils.get_hparams_from_file(args.cfg)

    ort_sess = ort.InferenceSession(args.onnx_model)
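
    # Each test-file line is expected to look like "path/to/xxx.wav|...";
    # only the leading path is used, to name the generated output file.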
    with open(args.test_file) as fin:
        for line in fin:
            arr = line.strip().split("|")
            audio_path = arr[0]

            # Speaker id and input text are hard-coded for this demo rather
            # than read from the test file.
            sid = 3
            text = '[ZH]你好,重庆市位于四川省东边[ZH]'

            seq = text_to_sequence(text, cleaner_names=hps.data.text_cleaners)
            if hps.data.add_blank:
                seq = commons.intersperse(seq, 0)

            with torch.no_grad():
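                # Batched inputs in the layout the exported graph expects:
                # int64 token ids, their lengths, and the speaker id.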
                x = np.array([seq], dtype=np.int64)
                x_len = np.array([x.shape[1]], dtype=np.int64)
                sid = np.array([sid], dtype=np.int64)
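
                # Sampling scales expected by the exported model, shaped
                # (1, 3); the values match the usual VITS inference defaults
                # (noise scale 0.667, noise scale w 0.8, length scale 1.0).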
                scales = np.array([0.667, 0.8, 1], dtype=np.float32)
                scales.resize(1, 3)
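
                # Feed dict keyed by the input names chosen when the model
                # was exported to ONNX.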
                ort_inputs = {
                    'input': x,
                    'input_lengths': x_len,
                    'scales': scales,
                    'sid': sid
                }

                start_time = time.perf_counter()
                audio = np.squeeze(ort_sess.run(None, ort_inputs))
                # Peak-normalize to ~60% of the int16 range and clip before
                # writing the wav file.
                audio *= 32767.0 / max(0.01, np.max(np.abs(audio))) * 0.6
                audio = np.clip(audio, -32767.0, 32767.0)
                end_time = time.perf_counter()

                print("infer time cost: ", end_time - start_time, "s")

                wavfile.write(args.outdir + "/" + audio_path.split("/")[-1],
                              hps.data.sampling_rate, audio.astype(np.int16))


if __name__ == '__main__':
    main()