zhzluke96 committed on
Commit
84cfd61
·
1 Parent(s): 22884c9
modules/devices.py ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+
3
+
4
def torch_gc():
    """Ask PyTorch to release cached CUDA memory.

    Does nothing when CUDA is not available. Otherwise, inside a
    ``torch.cuda.device("cuda")`` context, calls
    ``torch.cuda.empty_cache()`` followed by ``torch.cuda.ipc_collect()``.

    Returns:
        None.
    """
    # Guard clause: nothing to clean up on CPU-only machines.
    if not torch.cuda.is_available():
        return

    with torch.cuda.device("cuda"):
        torch.cuda.empty_cache()
        torch.cuda.ipc_collect()
modules/generate_audio.py CHANGED
@@ -8,6 +8,8 @@ from modules import models, config
8
 
9
  import logging
10
 
 
 
11
  logger = logging.getLogger(__name__)
12
 
13
 
@@ -96,6 +98,8 @@ def generate_audio_batch(
96
 
97
  sample_rate = 24000
98
 
 
 
99
  return [(sample_rate, np.array(wav).flatten().astype(np.float32)) for wav in wavs]
100
 
101
 
 
8
 
9
  import logging
10
 
11
+ from modules import devices
12
+
13
  logger = logging.getLogger(__name__)
14
 
15
 
 
98
 
99
  sample_rate = 24000
100
 
101
+ devices.torch_gc()
102
+
103
  return [(sample_rate, np.array(wav).flatten().astype(np.float32)) for wav in wavs]
104
 
105
 
modules/normalization.py CHANGED
@@ -75,13 +75,15 @@ character_map = {
75
  "“": " ",
76
  "’": " ",
77
  "”": " ",
 
 
78
  ":": ",",
79
  ";": ",",
80
  "!": ".",
81
  "(": ",",
82
  ")": ",",
83
- # '[': ',',
84
- # ']': ',',
85
  ">": ",",
86
  "<": ",",
87
  "-": ",",
@@ -110,13 +112,6 @@ def apply_emoji_map(text):
110
  return emojiswitch.demojize(text, delimiters=("", ""), lang="zh")
111
 
112
 
113
- @pre_normalize()
114
- def apply_markdown_to_text(text):
115
- if is_markdown(text):
116
- text = markdown_to_text(text)
117
- return text
118
-
119
-
120
  @post_normalize()
121
  def insert_spaces_between_uppercase(s):
122
  # 使用正则表达式在每个相邻的大写字母之间插入空格
@@ -127,6 +122,29 @@ def insert_spaces_between_uppercase(s):
127
  )
128
 
129
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
130
  def ensure_suffix(a: str, b: str, c: str):
131
  a = a.strip()
132
  if not a.endswith(b):
@@ -171,6 +189,7 @@ def sentence_normalize(sentence_text: str):
171
  sentences = tx.normalize(part)
172
  dest_text = ""
173
  for sentence in sentences:
 
174
  dest_text += sentence
175
  return dest_text
176
 
@@ -197,7 +216,6 @@ def text_normalize(text, is_end=False):
197
  lines = [line for line in lines if line]
198
  lines = [sentence_normalize(line) for line in lines]
199
  content = "\n".join(lines)
200
- content = apply_post_normalize(content)
201
  return content
202
 
203
 
@@ -216,6 +234,16 @@ console.log('1')
216
 
217
  *一条文本*
218
  """,
 
 
 
 
 
 
 
 
 
 
219
  ]
220
 
221
  for i, test_case in enumerate(test_cases):
 
75
  "“": " ",
76
  "’": " ",
77
  "”": " ",
78
+ '"': " ",
79
+ "'": " ",
80
  ":": ",",
81
  ";": ",",
82
  "!": ".",
83
  "(": ",",
84
  ")": ",",
85
+ "[": ",",
86
+ "]": ",",
87
  ">": ",",
88
  "<": ",",
89
  "-": ",",
 
112
  return emojiswitch.demojize(text, delimiters=("", ""), lang="zh")
113
 
114
 
 
 
 
 
 
 
 
115
  @post_normalize()
116
  def insert_spaces_between_uppercase(s):
117
  # 使用正则表达式在每个相邻的大写字母之间插入空格
 
122
  )
123
 
124
 
125
@pre_normalize()
def apply_markdown_to_text(text):
    """Return ``markdown_to_text(text)`` when ``is_markdown(text)`` is truthy.

    Text that is not detected as markdown is returned unchanged.
    """
    return markdown_to_text(text) if is_markdown(text) else text
130
+
131
+
132
# Put every quoted span on its own line: "xxx" => \n"xxx"\n
# and 'xxx' => \n'xxx'\n (the quote characters stay inside the span).
@pre_normalize()
def replace_quotes(text):
    """Surround each quoted span in ``text`` with newline characters.

    Four opener/closer pairs are processed in order: ASCII double quotes,
    ASCII single quotes, curly double quotes and curly single quotes. A span
    is an opener, one or more characters that are neither the opener nor the
    closer, then the closer. The whole span, quotes included, is captured and
    re-emitted between two newlines.
    """
    quote_pairs = (
        ('"', '"'),
        ("'", "'"),
        ("“", "”"),
        ("‘", "’"),
    )
    for opener, closer in quote_pairs:
        pattern = rf"({opener}[^{opener}{closer}]+?{closer})"
        text = re.sub(pattern, r"\n\1\n", text)
    return text
146
+
147
+
148
  def ensure_suffix(a: str, b: str, c: str):
149
  a = a.strip()
150
  if not a.endswith(b):
 
189
  sentences = tx.normalize(part)
190
  dest_text = ""
191
  for sentence in sentences:
192
+ sentence = apply_post_normalize(sentence)
193
  dest_text += sentence
194
  return dest_text
195
 
 
216
  lines = [line for line in lines if line]
217
  lines = [sentence_normalize(line) for line in lines]
218
  content = "\n".join(lines)
 
219
  return content
220
 
221
 
 
234
 
235
  *一条文本*
236
  """,
237
+ """
238
+ 在沙漠、岩石、雪地上行走了很长的时间以后,小王子终于发现了一条大路。所有的大路都是通往人住的地方的。
239
+ “你们好。”小王子说。
240
+ 这是一个玫瑰盛开的花园。
241
+ “你好。”玫瑰花说道。
242
+ 小王子瞅着这些花,它们全都和他的那朵花一样。
243
+ “你们是什么花?”小王子惊奇地问。
244
+ “我们是玫瑰花。”花儿们说道。
245
+ “啊!”小王子说……。
246
+ """,
247
  ]
248
 
249
  for i, test_case in enumerate(test_cases):
modules/utils/audio.py CHANGED
@@ -5,6 +5,16 @@ import pyrubberband as pyrb
5
  import numpy as np
6
  from io import BytesIO
7
 
 
 
 
 
 
 
 
 
 
 
8
 
9
  def audiosegment_to_librosawav(audiosegment):
10
  channel_sounds = audiosegment.split_to_mono()
 
5
  import numpy as np
6
  from io import BytesIO
7
 
8
# Largest value of a signed 16-bit integer (32767); used as the PCM scale.
INT16_MAX = np.iinfo(np.int16).max


def audio_to_int16(audio_data):
    """Convert floating-point audio samples to 16-bit integers.

    Floating-point NumPy input is treated as normalised audio: samples are
    clipped to [-1.0, 1.0], multiplied by ``INT16_MAX`` and cast to
    ``np.int16``. Everything else is returned untouched.

    Args:
        audio_data: A NumPy array (or NumPy scalar) of samples. Non-NumPy
            values such as raw ``bytes`` are passed through unchanged
            instead of raising ``AttributeError`` on ``.dtype``.

    Returns:
        An ``np.int16`` array for floating-point input; otherwise the
        original ``audio_data`` object.
    """
    # Values without a NumPy dtype (e.g. bytes read from a buffer) cannot be
    # rescaled here, so hand them back as-is.
    if not isinstance(audio_data, (np.ndarray, np.generic)):
        return audio_data

    # Integer (already-PCM) and other non-float data needs no conversion.
    if not np.issubdtype(audio_data.dtype, np.floating):
        return audio_data

    # Scale in float32: in float16, 32767 rounds up to 32768, which does not
    # fit in int16 and would flip full-scale samples to negative values.
    samples = audio_data.astype(np.float32)

    # Saturate instead of letting out-of-range samples wrap around on the cast.
    samples = np.clip(samples, -1.0, 1.0)
    return (samples * INT16_MAX).astype(np.int16)
17
+
18
 
19
  def audiosegment_to_librosawav(audiosegment):
20
  channel_sounds = audiosegment.split_to_mono()
webui.py CHANGED
@@ -1,4 +1,16 @@
1
- import spaces
 
 
 
 
 
 
 
 
 
 
 
 
2
  import os
3
  import logging
4
 
@@ -29,7 +41,7 @@ from modules.api.utils import calc_spk_style
29
  from modules.normalization import text_normalize
30
  from modules import refiner, config
31
 
32
- from modules.utils import env
33
  from modules.SentenceSplitter import SentenceSplitter
34
 
35
  torch._dynamo.config.cache_size_limit = 64
@@ -40,7 +52,7 @@ webui_config = {
40
  "tts_max": 1000,
41
  "ssml_max": 5000,
42
  "spliter_threshold": 100,
43
- "max_batch_size": 12,
44
  }
45
 
46
 
@@ -65,7 +77,7 @@ def segments_length_limit(segments, total_max: int):
65
 
66
  @torch.inference_mode()
67
  @spaces.GPU
68
- def synthesize_ssml(ssml: str, batch_size=8):
69
  try:
70
  batch_size = int(batch_size)
71
  except Exception:
@@ -92,7 +104,10 @@ def synthesize_ssml(ssml: str, batch_size=8):
92
 
93
  buffer.seek(0)
94
 
95
- return buffer.read()
 
 
 
96
 
97
 
98
  @torch.inference_mode()
@@ -110,12 +125,12 @@ def tts_generate(
110
  prefix,
111
  style,
112
  disable_normalize=False,
113
- batch_size=8,
114
  ):
115
  try:
116
  batch_size = int(batch_size)
117
  except Exception:
118
- batch_size = 8
119
 
120
  max_len = webui_config["tts_max"]
121
  text = text.strip()[0:max_len]
@@ -157,8 +172,6 @@ def tts_generate(
157
  prompt2=prompt2,
158
  prefix=prefix,
159
  )
160
-
161
- return sample_rate, audio_data
162
  else:
163
  spliter = SentenceSplitter(webui_config["spliter_threshold"])
164
  sentences = spliter.parse(text)
@@ -178,7 +191,8 @@ def tts_generate(
178
  sample_rate = audio_data_batch[0][0]
179
  audio_data = np.concatenate([data for _, data in audio_data_batch])
180
 
181
- return sample_rate, audio_data
 
182
 
183
 
184
  @torch.inference_mode()
@@ -366,7 +380,7 @@ def create_tts_interface():
366
  batch_size_input = gr.Slider(
367
  1,
368
  webui_config["max_batch_size"],
369
- value=8,
370
  step=1,
371
  label="Batch Size",
372
  )
@@ -593,7 +607,7 @@ def create_ssml_interface():
593
  # batch size
594
  batch_size_input = gr.Slider(
595
  label="Batch Size",
596
- value=8,
597
  minimum=1,
598
  maximum=webui_config["max_batch_size"],
599
  step=1,
@@ -892,7 +906,7 @@ if __name__ == "__main__":
892
 
893
  webui_config["tts_max"] = env.get_env_or_arg(args, "tts_max_len", 1000, int)
894
  webui_config["ssml_max"] = env.get_env_or_arg(args, "ssml_max_len", 5000, int)
895
- webui_config["max_batch_size"] = env.get_env_or_arg(args, "max_batch_size", 12, int)
896
 
897
  demo = create_interface()
898
 
 
1
try:
    import spaces
except Exception:
    # The `spaces` package is optional. When it cannot be imported, install a
    # stand-in so the `@spaces.GPU` decorators in this file become no-ops.
    # `Exception` (not a bare `except:`) keeps this best-effort while letting
    # KeyboardInterrupt / SystemExit propagate.

    class NoneSpaces:
        """No-op replacement for the `spaces` module."""

        def GPU(self, fn=None, **kwargs):
            """Pass-through decorator.

            Supports both the bare form ``@spaces.GPU`` and the called form
            ``@spaces.GPU(...)``; keyword options are accepted and ignored.

            Args:
                fn: The function being decorated, or None when the decorator
                    is called with options only.
                **kwargs: Ignored.

            Returns:
                ``fn`` itself, or an identity decorator when ``fn`` is None.
            """
            if fn is None:
                return lambda wrapped: wrapped
            return fn

    spaces = NoneSpaces()
13
+
14
  import os
15
  import logging
16
 
 
41
  from modules.normalization import text_normalize
42
  from modules import refiner, config
43
 
44
+ from modules.utils import env, audio
45
  from modules.SentenceSplitter import SentenceSplitter
46
 
47
  torch._dynamo.config.cache_size_limit = 64
 
52
  "tts_max": 1000,
53
  "ssml_max": 5000,
54
  "spliter_threshold": 100,
55
+ "max_batch_size": 8,
56
  }
57
 
58
 
 
77
 
78
  @torch.inference_mode()
79
  @spaces.GPU
80
+ def synthesize_ssml(ssml: str, batch_size=4):
81
  try:
82
  batch_size = int(batch_size)
83
  except Exception:
 
104
 
105
  buffer.seek(0)
106
 
107
+ audio_data = buffer.read()
108
+ audio_data = audio.audio_to_int16(audio_data)
109
+
110
+ return audio_data
111
 
112
 
113
  @torch.inference_mode()
 
125
  prefix,
126
  style,
127
  disable_normalize=False,
128
+ batch_size=4,
129
  ):
130
  try:
131
  batch_size = int(batch_size)
132
  except Exception:
133
+ batch_size = 4
134
 
135
  max_len = webui_config["tts_max"]
136
  text = text.strip()[0:max_len]
 
172
  prompt2=prompt2,
173
  prefix=prefix,
174
  )
 
 
175
  else:
176
  spliter = SentenceSplitter(webui_config["spliter_threshold"])
177
  sentences = spliter.parse(text)
 
191
  sample_rate = audio_data_batch[0][0]
192
  audio_data = np.concatenate([data for _, data in audio_data_batch])
193
 
194
+ audio_data = audio.audio_to_int16(audio_data)
195
+ return sample_rate, audio_data
196
 
197
 
198
  @torch.inference_mode()
 
380
  batch_size_input = gr.Slider(
381
  1,
382
  webui_config["max_batch_size"],
383
+ value=4,
384
  step=1,
385
  label="Batch Size",
386
  )
 
607
  # batch size
608
  batch_size_input = gr.Slider(
609
  label="Batch Size",
610
+ value=4,
611
  minimum=1,
612
  maximum=webui_config["max_batch_size"],
613
  step=1,
 
906
 
907
  webui_config["tts_max"] = env.get_env_or_arg(args, "tts_max_len", 1000, int)
908
  webui_config["ssml_max"] = env.get_env_or_arg(args, "ssml_max_len", 5000, int)
909
+ webui_config["max_batch_size"] = env.get_env_or_arg(args, "max_batch_size", 8, int)
910
 
911
  demo = create_interface()
912