Spaces:
Runtime error
Runtime error
NTT123
commited on
Commit
•
7383c33
1
Parent(s):
fb9df88
update
Browse files- app.py +11 -9
- inference.py +1 -1
app.py
CHANGED
@@ -5,22 +5,24 @@
|
|
5 |
|
6 |
|
7 |
import gradio as gr
|
8 |
-
|
9 |
from inference import load_tacotron_model, load_wavegru_net, mel_to_wav, text_to_mel
|
10 |
from wavegru_cpp import extract_weight_mask, load_wavegru_cpp
|
11 |
|
12 |
-
alphabet, tacotron_net, tacotron_config = load_tacotron_model(
|
13 |
-
"./alphabet.txt", "./tacotron.toml", "./pretrained_model_ljs_600k.ckpt"
|
14 |
-
)
|
15 |
-
|
16 |
|
17 |
-
|
|
|
|
|
|
|
18 |
|
19 |
-
|
20 |
-
|
|
|
21 |
|
|
|
|
|
|
|
|
|
22 |
|
23 |
-
def speak(text):
|
24 |
mel = text_to_mel(tacotron_net, text, alphabet, tacotron_config)
|
25 |
y = mel_to_wav(wavegru_net, wavecpp, mel, wavegru_config)
|
26 |
return 24_000, y
|
|
|
5 |
|
6 |
|
7 |
import gradio as gr
|
|
|
8 |
from inference import load_tacotron_model, load_wavegru_net, mel_to_wav, text_to_mel
|
9 |
from wavegru_cpp import extract_weight_mask, load_wavegru_cpp
|
10 |
|
|
|
|
|
|
|
|
|
11 |
|
12 |
+
def speak(text):
|
13 |
+
alphabet, tacotron_net, tacotron_config = load_tacotron_model(
|
14 |
+
"./alphabet.txt", "./tacotron.toml", "./pretrained_model_ljs_600k.ckpt"
|
15 |
+
)
|
16 |
|
17 |
+
wavegru_config, wavegru_net = load_wavegru_net(
|
18 |
+
"./wavegru.yaml", "./wavegru_vocoder_1024_v3_1330000.ckpt"
|
19 |
+
)
|
20 |
|
21 |
+
wave_cpp_weight_mask = extract_weight_mask(wavegru_net)
|
22 |
+
wavecpp = load_wavegru_cpp(
|
23 |
+
wave_cpp_weight_mask, wavegru_config["upsample_factors"][-1]
|
24 |
+
)
|
25 |
|
|
|
26 |
mel = text_to_mel(tacotron_net, text, alphabet, tacotron_config)
|
27 |
y = mel_to_wav(wavegru_net, wavecpp, mel, wavegru_config)
|
28 |
return 24_000, y
|
inference.py
CHANGED
@@ -67,7 +67,7 @@ def mel_to_wav(net, netcpp, mel, config):
|
|
67 |
if len(mel.shape) == 2:
|
68 |
mel = mel[None]
|
69 |
pad = config["num_pad_frames"] // 2 + 2
|
70 |
-
mel = np.pad(mel, [(0, 0), (pad, pad), (0, 0)], mode="
|
71 |
ft = wavegru_inference(net, mel)
|
72 |
ft = jax.device_get(ft[0])
|
73 |
wav = netcpp.inference(ft, 1.0)
|
|
|
67 |
if len(mel.shape) == 2:
|
68 |
mel = mel[None]
|
69 |
pad = config["num_pad_frames"] // 2 + 2
|
70 |
+
mel = np.pad(mel, [(0, 0), (pad, pad), (0, 0)], mode="edge")
|
71 |
ft = wavegru_inference(net, mel)
|
72 |
ft = jax.device_get(ft[0])
|
73 |
wav = netcpp.inference(ft, 1.0)
|