sayashi committed on
Commit
3fd0c0d
1 Parent(s): f4da48c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +25 -14
app.py CHANGED
@@ -1,15 +1,19 @@
1
  # coding=utf-8
2
  import os
3
  import re
 
4
  import utils
5
  import commons
6
  import json
 
7
  import gradio as gr
8
  from models import SynthesizerTrn
9
  from text import text_to_sequence
10
  from torch import no_grad, LongTensor
11
  import logging
12
  logging.getLogger('numba').setLevel(logging.WARNING)
 
 
13
  hps_ms = utils.get_hparams_from_file(r'config/config.json')
14
 
15
  def get_text(text, hps):
@@ -22,10 +26,11 @@ def get_text(text, hps):
22
  def create_tts_fn(net_g_ms, speaker_id):
23
  def tts_fn(text, language, noise_scale, noise_scale_w, length_scale):
24
  text = text.replace('\n', ' ').replace('\r', '').replace(" ", "")
25
- text_len = len(re.sub("\[([A-Z]{2})\]", "", text))
26
- max_len = 150
27
- if text_len > max_len:
28
- return "Error: Text is too long", None
 
29
  if language == 0:
30
  text = f"[ZH]{text}[ZH]"
31
  elif language == 1:
@@ -34,11 +39,11 @@ def create_tts_fn(net_g_ms, speaker_id):
34
  text = f"{text}"
35
  stn_tst, clean_text = get_text(text, hps_ms)
36
  with no_grad():
37
- x_tst = stn_tst.unsqueeze(0)
38
- x_tst_lengths = LongTensor([stn_tst.size(0)])
39
- sid = LongTensor([speaker_id])
40
  audio = net_g_ms.infer(x_tst, x_tst_lengths, sid=sid, noise_scale=noise_scale, noise_scale_w=noise_scale_w,
41
- length_scale=length_scale)[0][0, 0].data.float().numpy()
42
 
43
  return "Success", (22050, audio)
44
  return tts_fn
@@ -72,23 +77,29 @@ download_audio_js = """
72
  """
73
 
74
  if __name__ == '__main__':
 
 
 
 
 
 
75
  models = []
76
  with open("pretrained_models/info.json", "r", encoding="utf-8") as f:
77
  models_info = json.load(f)
78
  for i, info in models_info.items():
 
 
 
 
 
79
  net_g_ms = SynthesizerTrn(
80
  len(hps_ms.symbols),
81
  hps_ms.data.filter_length // 2 + 1,
82
  hps_ms.train.segment_size // hps_ms.data.hop_length,
83
  n_speakers=hps_ms.data.n_speakers,
84
  **hps_ms.model)
85
- _ = net_g_ms.eval()
86
- sid = info['sid']
87
- name_en = info['name_en']
88
- name_zh = info['name_zh']
89
- title = info['title']
90
- cover = f"pretrained_models/{i}/{info['cover']}"
91
  utils.load_checkpoint(f'pretrained_models/{i}/{i}.pth', net_g_ms, None)
 
92
  models.append((sid, name_en, name_zh, title, cover, net_g_ms, create_tts_fn(net_g_ms, sid)))
93
  with gr.Blocks() as app:
94
  gr.Markdown(
 
1
  # coding=utf-8
2
  import os
3
  import re
4
+ import argparse
5
  import utils
6
  import commons
7
  import json
8
+ import torch
9
  import gradio as gr
10
  from models import SynthesizerTrn
11
  from text import text_to_sequence
12
  from torch import no_grad, LongTensor
13
  import logging
14
  logging.getLogger('numba').setLevel(logging.WARNING)
15
+ limitation = os.getenv("SYSTEM") == "spaces" # limit text and audio length in huggingface spaces
16
+
17
  hps_ms = utils.get_hparams_from_file(r'config/config.json')
18
 
19
  def get_text(text, hps):
 
26
  def create_tts_fn(net_g_ms, speaker_id):
27
  def tts_fn(text, language, noise_scale, noise_scale_w, length_scale):
28
  text = text.replace('\n', ' ').replace('\r', '').replace(" ", "")
29
+ if limitation:
30
+ text_len = len(re.sub("\[([A-Z]{2})\]", "", text))
31
+ max_len = 100
32
+ if text_len > max_len:
33
+ return "Error: Text is too long", None
34
  if language == 0:
35
  text = f"[ZH]{text}[ZH]"
36
  elif language == 1:
 
39
  text = f"{text}"
40
  stn_tst, clean_text = get_text(text, hps_ms)
41
  with no_grad():
42
+ x_tst = stn_tst.unsqueeze(0).to(device)
43
+ x_tst_lengths = LongTensor([stn_tst.size(0)]).to(device)
44
+ sid = LongTensor([speaker_id]).to(device)
45
  audio = net_g_ms.infer(x_tst, x_tst_lengths, sid=sid, noise_scale=noise_scale, noise_scale_w=noise_scale_w,
46
+ length_scale=length_scale)[0][0, 0].data.cpu().float().numpy()
47
 
48
  return "Success", (22050, audio)
49
  return tts_fn
 
77
  """
78
 
79
  if __name__ == '__main__':
80
+ parser = argparse.ArgumentParser()
81
+ parser.add_argument('--device', type=str, default='cpu')
82
+ parser.add_argument("--share", action="store_true", default=False, help="share gradio app")
83
+ args = parser.parse_args()
84
+ device = torch.device(args.device)
85
+
86
  models = []
87
  with open("pretrained_models/info.json", "r", encoding="utf-8") as f:
88
  models_info = json.load(f)
89
  for i, info in models_info.items():
90
+ sid = info['sid']
91
+ name_en = info['name_en']
92
+ name_zh = info['name_zh']
93
+ title = info['title']
94
+ cover = f"pretrained_models/{i}/{info['cover']}"
95
  net_g_ms = SynthesizerTrn(
96
  len(hps_ms.symbols),
97
  hps_ms.data.filter_length // 2 + 1,
98
  hps_ms.train.segment_size // hps_ms.data.hop_length,
99
  n_speakers=hps_ms.data.n_speakers,
100
  **hps_ms.model)
 
 
 
 
 
 
101
  utils.load_checkpoint(f'pretrained_models/{i}/{i}.pth', net_g_ms, None)
102
+ _ = net_g_ms.eval().to(device)
103
  models.append((sid, name_en, name_zh, title, cover, net_g_ms, create_tts_fn(net_g_ms, sid)))
104
  with gr.Blocks() as app:
105
  gr.Markdown(