wetdog committed
Commit 92df4f5
1 Parent(s): cde96b8

add inference app

Files changed (1)
  1. infer_onnx.py +180 -0
infer_onnx.py ADDED
@@ -0,0 +1,180 @@
+ import tempfile
+
+ import gradio as gr
+ import numpy as np
+ import onnxruntime
+ import soundfile as sf
+ import torch
+ import yaml
+
+ import utils
+ from text import text_to_sequence, sequence_to_text
+
+ def intersperse(lst, item):
+     """Insert `item` between and around the elements of `lst`."""
+     result = [item] * (len(lst) * 2 + 1)
+     result[1::2] = lst
+     return result
+
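+ # For example, intersperse([1, 2, 3], 0) -> [0, 1, 0, 2, 0, 3, 0]; this is the
+ # blank-token interleaving that process_text() below applies to phoneme ids.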
+
+ def process_text(i: int, text: str, device: torch.device):
+     print(f"[{i}] - Input text: {text}")
+     x = torch.tensor(
+         intersperse(text_to_sequence(text, ["catalan_cleaners"]), 0),
+         dtype=torch.long,
+         device=device,
+     )[None]
+     x_lengths = torch.tensor([x.shape[-1]], dtype=torch.long, device=device)
+     x_phones = sequence_to_text(x.squeeze(0).tolist())
+     print(x_phones)
+     return x.numpy(), x_lengths.numpy()
+
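+ # process_text returns int64 NumPy arrays: x shaped (1, 2*len(sequence)+1)
+ # after interspersing, and x_lengths shaped (1,); these feed the "x" and
+ # "x_lengths" inputs of both Matcha graphs below.
+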
+ MODEL_PATH_MATCHA_MEL = "matcha_multispeaker_cat_opset_15.onnx"
+ MODEL_PATH_MATCHA = "matcha_hifigan_multispeaker_cat.onnx"
+ MODEL_PATH_VOCOS = "mel_spec_22khz.onnx"
+ # Vocos feature-extractor config, expected alongside this script
+ CONFIG_PATH = "config_22khz.yaml"
+
+ sess_options = onnxruntime.SessionOptions()
+ model_matcha_mel = onnxruntime.InferenceSession(str(MODEL_PATH_MATCHA_MEL), sess_options=sess_options, providers=["CPUExecutionProvider"])
+ model_vocos = onnxruntime.InferenceSession(str(MODEL_PATH_VOCOS), sess_options=sess_options, providers=["CPUExecutionProvider"])
+ model_matcha = onnxruntime.InferenceSession(str(MODEL_PATH_MATCHA), sess_options=sess_options, providers=["CPUExecutionProvider"])
+
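+ # Sketch for double-checking the exported graphs' tensor names (the "x",
+ # "x_lengths", "scales", "spks" and "mels" keys used below are assumed to
+ # match the export):
+ #
+ #     for session in (model_matcha_mel, model_matcha, model_vocos):
+ #         print([(i.name, i.shape, i.type) for i in session.get_inputs()])
+ #         print([(o.name, o.shape) for o in session.get_outputs()])
+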
+ def vocos_inference(mel: np.ndarray, config_path: str):
+     # load the feature-extractor settings the vocoder was exported with
+     with open(config_path, "r") as f:
+         config = yaml.safe_load(f)
+
+     params = config["feature_extractor"]["init_args"]
+     sample_rate = params["sample_rate"]
+     n_fft = params["n_fft"]
+     hop_length = params["hop_length"]
+     win_length = n_fft
+
+     # ONNX inference: Vocos returns the magnitude and the real/imaginary
+     # parts of the phase
+     mag, x, y = model_vocos.run(
+         None,
+         {
+             "mels": mel.astype(np.float32)
+         },
+     )
+
+     # complex spectrogram from vocos output
+     spectrogram = mag * (x + 1j * y)
+     window = torch.hann_window(win_length)
+
+     # Inverse stft
+     pad = (win_length - hop_length) // 2
+     spectrogram = torch.tensor(spectrogram)
+     B, N, T = spectrogram.shape
+     print("Spectrogram synthesized shape", spectrogram.shape)
+
+     # Inverse FFT of each frame, then apply the synthesis window
+     ifft = torch.fft.irfft(spectrogram, n_fft, dim=1, norm="backward")
+     ifft = ifft * window[None, :, None]
+
+     # Overlap and add the windowed frames
+     output_size = (T - 1) * hop_length + win_length
+     y = torch.nn.functional.fold(
+         ifft, output_size=(1, output_size), kernel_size=(1, win_length), stride=(1, hop_length),
+     )[:, 0, 0, pad:-pad]
+
+     # Window envelope for overlap-add normalization
+     window_sq = window.square().expand(1, T, -1).transpose(1, 2)
+     window_envelope = torch.nn.functional.fold(
+         window_sq, output_size=(1, output_size), kernel_size=(1, win_length), stride=(1, hop_length),
+     ).squeeze()[pad:-pad]
+
+     # Normalize
+     assert (window_envelope > 1e-11).all()
+     y = y / window_envelope
+
+     return y
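+
+ # The fold-based overlap-add above is a hand-rolled inverse STFT. As a sanity
+ # check (a sketch only; the two implementations trim a few edge samples
+ # differently), it can be compared against torch.istft inside the function:
+ #
+ #     y_ref = torch.istft(spectrogram, n_fft, hop_length=hop_length,
+ #                         win_length=win_length, window=window, center=True)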
+
+ def tts(text: str, spk_id: int):
+     sid = np.array([int(spk_id)], dtype=np.int64) if spk_id is not None else None
+     text_matcha, text_lengths = process_text(0, text, "cpu")
+
+     # the same inputs drive both exported graphs
+     inputs = {
+         "x": text_matcha,
+         "x_lengths": text_lengths,
+         "scales": np.array([0.667, 1.0], dtype=np.float32),  # [temperature, length_scale]
+         "spks": sid
+     }
+
+     # MATCHA + VOCOS
+     mel, mel_lengths = model_matcha_mel.run(None, inputs)
+     wavs_vocos = vocos_inference(mel, CONFIG_PATH)
+
+     with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as fp_matcha_vocos:
+         sf.write(fp_matcha_vocos.name, wavs_vocos.squeeze(0), 22050, "PCM_24")
+
+     # MATCHA + HIFIGAN
+     wavs, wav_lengths = model_matcha.run(None, inputs)
+
+     with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as fp_matcha:
+         sf.write(fp_matcha.name, wavs.squeeze(0), 22050, "PCM_24")
+
+     return fp_matcha_vocos.name, fp_matcha.name
+
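+ # Quick non-GUI usage sketch (assuming the .onnx files and config_22khz.yaml
+ # sit in the working directory; the input sentence is just an example):
+ #
+ #     vocos_path, hifigan_path = tts("Bon dia.", spk_id=10)
+ #     print(vocos_path, hifigan_path)  # temp .wav files at 22.05 kHz
+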
+ ## GUI space
+
+ title = """
+ <div style="text-align: center; max-width: 700px; margin: 0 auto;">
+   <div
+     style="display: inline-flex; align-items: center; gap: 0.8rem; font-size: 1.75rem;"
+   >
+     <h1 style="font-weight: 900; margin-bottom: 7px; line-height: normal;">
+       Catalan TTS Comparison
+     </h1>
+   </div>
+ </div>
+ """
+
+ description = """
+ VITS2 is an end-to-end speech synthesis model that predicts a speech waveform conditioned on an input text sequence. It improves
+ training and inference efficiency and naturalness by introducing adversarial learning into the duration predictor, adds a transformer
+ block to the normalizing flows to capture long-term dependencies when transforming the distribution, and improves synthesis quality
+ by incorporating Gaussian noise into the alignment search.
+
+ 🍵 Matcha-TTS is a non-autoregressive neural TTS approach that uses conditional flow matching (similar to rectified flows) to speed up
+ ODE-based speech synthesis.
+
+ Models are being trained on the OpenSLR69 and FestCat datasets.
+ """
+
+ article = "Training and demo by BSC."
+
+ vits2_inference = gr.Interface(
+     fn=tts,
+     inputs=[
+         gr.Textbox(
+             value="m'ha costat desenvolupar molt una veu, i ara que la tinc no estaré en silenci.",
+             max_lines=1,
+             label="Input text",
+         ),
+         gr.Slider(
+             1,
+             47,
+             value=10,
+             step=1,
+             label="Speaker id",
+             info="Models are trained on 47 speakers. You can prompt the model using one of these speaker ids.",
+         ),
+     ],
+     outputs=[
+         gr.Audio(label="Matcha vocos", interactive=False, type="filepath"),
+         gr.Audio(label="Matcha", interactive=False, type="filepath"),
+     ],
+ )
+
+ demo = gr.Blocks()
+
+ with demo:
+     gr.Markdown(title)
+     gr.Markdown(description)
+     gr.TabbedInterface([vits2_inference], ["Multispeaker"])
+     gr.Markdown(article)
+
+ demo.queue(max_size=10)
+ demo.launch(show_api=False, server_name="0.0.0.0", server_port=7860)