#!/usr/bin/env python3
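# Export a VITS-fast-fine-tuning SynthesizerTrn checkpoint to ONNX and attach
# descriptive metadata to the exported file. The model is selected via the
# NAME environment variable, e.g. `NAME=C python3 ./export_onnx.py`
# (the script filename here is only an assumption).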
import sys
sys.path.insert(0, "VITS-fast-fine-tuning")
import os
from pathlib import Path
from typing import Any, Dict
import onnx
import torch
import utils
from models import SynthesizerTrn
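

# Thin wrapper around SynthesizerTrn so that torch.onnx.export traces only the
# inference path and returns just the synthesized waveform.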
class OnnxModel(torch.nn.Module):
    def __init__(self, model: SynthesizerTrn):
        super().__init__()
        self.model = model

    def forward(
        self,
        x,
        x_lengths,
        noise_scale=1,
        length_scale=1,
        noise_scale_w=1.0,
        sid=0,
        max_len=None,
    ):
        # infer() returns a tuple; keep only the generated audio tensor.
        return self.model.infer(
            x=x,
            x_lengths=x_lengths,
            sid=sid,
            noise_scale=noise_scale,
            length_scale=length_scale,
            noise_scale_w=noise_scale_w,
            max_len=max_len,
        )[0]

def add_meta_data(filename: str, meta_data: Dict[str, Any]):
    """Add meta data to an ONNX model. It is changed in-place.

    Args:
      filename:
        Filename of the ONNX model to be changed.
      meta_data:
        Key-value pairs.
    """
    model = onnx.load(filename)

    for key, value in meta_data.items():
        meta = model.metadata_props.add()
        meta.key = key
        meta.value = str(value)

    onnx.save(model, filename)

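

# Load the NAME-specific checkpoint and config, export the model to ONNX with
# dynamic batch/length axes, and embed descriptive metadata into the file.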
@torch.no_grad()
def main():
    name = os.environ.get("NAME", None)
    if not name:
        print("Please provide the environment variable NAME")
        return

    print("name", name)

    if name == "C":
        model_path = "G_C.pth"
        config_path = "G_C.json"
    elif name == "ZhiHuiLaoZhe":
        model_path = "G_lkz_lao_new_new1_latest.pth"
        config_path = "G_lkz_lao_new_new1_latest.json"
    elif name == "ZhiHuiLaoZhe_new":
        model_path = "G_lkz_unity_onnx_new1_latest.pth"
        config_path = "G_lkz_unity_onnx_new1_latest.json"
    else:
        model_path = f"G_{name}_latest.pth"
        config_path = f"G_{name}_latest.json"

    print(name, model_path, config_path)

    hps = utils.get_hparams_from_file(config_path)

    net_g = SynthesizerTrn(
        len(hps.symbols),
        hps.data.filter_length // 2 + 1,
        hps.train.segment_size // hps.data.hop_length,
        n_speakers=hps.data.n_speakers,
        **hps.model,
    )
    _ = net_g.eval()
    _ = utils.load_checkpoint(model_path, net_g, None)

    # Dummy inputs used only to trace the graph during export.
    x = torch.randint(low=1, high=50, size=(50,), dtype=torch.int64)
    x = x.unsqueeze(0)
    x_length = torch.tensor([x.shape[1]], dtype=torch.int64)
    noise_scale = torch.tensor([1], dtype=torch.float32)
    length_scale = torch.tensor([1], dtype=torch.float32)
    noise_scale_w = torch.tensor([1], dtype=torch.float32)
    sid = torch.tensor([0], dtype=torch.int64)

    model = OnnxModel(net_g)

    opset_version = 13

    filename = f"vits-zh-hf-fanchen-{name}.onnx"

    torch.onnx.export(
        model,
        (x, x_length, noise_scale, length_scale, noise_scale_w, sid),
        filename,
        opset_version=opset_version,
        input_names=[
            "x",
            "x_length",
            "noise_scale",
            "length_scale",
            "noise_scale_w",
            "sid",
        ],
        output_names=["y"],
        dynamic_axes={
            "x": {0: "N", 1: "L"},  # N: batch size, L: number of tokens
            "x_length": {0: "N"},
            "y": {0: "N", 2: "L"},
        },
    )

    meta_data = {
        "model_type": "vits",
        "comment": f"hf-vits-models-fanchen-{name}",
        "language": "Chinese",
        "add_blank": int(hps.data.add_blank),
        "n_speakers": int(hps.data.n_speakers),
        "sample_rate": hps.data.sampling_rate,
        "punctuation": ", . : ; ! ? ， 。 ： ； ！ ？ 、",
    }
    print("meta_data", meta_data)
    add_meta_data(filename=filename, meta_data=meta_data)


if __name__ == "__main__":
    main()