NTT123 commited on
Commit
7383c33
1 Parent(s): fb9df88
Files changed (2) hide show
  1. app.py +11 -9
  2. inference.py +1 -1
app.py CHANGED
@@ -5,22 +5,24 @@
5
 
6
 
7
  import gradio as gr
8
-
9
  from inference import load_tacotron_model, load_wavegru_net, mel_to_wav, text_to_mel
10
  from wavegru_cpp import extract_weight_mask, load_wavegru_cpp
11
 
12
- alphabet, tacotron_net, tacotron_config = load_tacotron_model(
13
- "./alphabet.txt", "./tacotron.toml", "./pretrained_model_ljs_600k.ckpt"
14
- )
15
-
16
 
17
- wavegru_config, wavegru_net = load_wavegru_net("./wavegru.yaml", "./wavegru_vocoder_1024_v3_1330000.ckpt")
 
 
 
18
 
19
- wave_cpp_weight_mask = extract_weight_mask(wavegru_net)
20
- wavecpp = load_wavegru_cpp(wave_cpp_weight_mask, wavegru_config["upsample_factors"][-1])
 
21
 
 
 
 
 
22
 
23
- def speak(text):
24
  mel = text_to_mel(tacotron_net, text, alphabet, tacotron_config)
25
  y = mel_to_wav(wavegru_net, wavecpp, mel, wavegru_config)
26
  return 24_000, y
 
5
 
6
 
7
  import gradio as gr
 
8
  from inference import load_tacotron_model, load_wavegru_net, mel_to_wav, text_to_mel
9
  from wavegru_cpp import extract_weight_mask, load_wavegru_cpp
10
 
 
 
 
 
11
 
12
+ def speak(text):
13
+ alphabet, tacotron_net, tacotron_config = load_tacotron_model(
14
+ "./alphabet.txt", "./tacotron.toml", "./pretrained_model_ljs_600k.ckpt"
15
+ )
16
 
17
+ wavegru_config, wavegru_net = load_wavegru_net(
18
+ "./wavegru.yaml", "./wavegru_vocoder_1024_v3_1330000.ckpt"
19
+ )
20
 
21
+ wave_cpp_weight_mask = extract_weight_mask(wavegru_net)
22
+ wavecpp = load_wavegru_cpp(
23
+ wave_cpp_weight_mask, wavegru_config["upsample_factors"][-1]
24
+ )
25
 
 
26
  mel = text_to_mel(tacotron_net, text, alphabet, tacotron_config)
27
  y = mel_to_wav(wavegru_net, wavecpp, mel, wavegru_config)
28
  return 24_000, y
inference.py CHANGED
@@ -67,7 +67,7 @@ def mel_to_wav(net, netcpp, mel, config):
67
  if len(mel.shape) == 2:
68
  mel = mel[None]
69
  pad = config["num_pad_frames"] // 2 + 2
70
- mel = np.pad(mel, [(0, 0), (pad, pad), (0, 0)], mode="reflect")
71
  ft = wavegru_inference(net, mel)
72
  ft = jax.device_get(ft[0])
73
  wav = netcpp.inference(ft, 1.0)
 
67
  if len(mel.shape) == 2:
68
  mel = mel[None]
69
  pad = config["num_pad_frames"] // 2 + 2
70
+ mel = np.pad(mel, [(0, 0), (pad, pad), (0, 0)], mode="edge")
71
  ft = wavegru_inference(net, mel)
72
  ft = jax.device_get(ft[0])
73
  wav = netcpp.inference(ft, 1.0)