zhzluke96 commited on
Commit
ebc4336
·
1 Parent(s): 21473c0
language/zh-CN.json CHANGED
@@ -57,8 +57,8 @@
57
  "🔊Generate speaker.pt": "🔊生成 speaker.pt",
58
  "Save .pt file": "保存.pt文件",
59
  "Save to File": "保存到文件",
60
- "🎤Test voice": "🎤测试语音",
61
- "Test Voice": "测试语音",
62
  "Current Seed": "当前种子",
63
  "Output Audio": "输出音频",
64
  "Merger": "融合",
@@ -79,6 +79,7 @@
79
  "README": "README",
80
  "readme": "readme",
81
  "changelog": "changelog",
 
82
  "TTS_STYLE_GUIDE": ["后缀为 _p 表示带prompt,效果更强但是影响质量"],
83
  "SSML_SPLITER_GUIDE": [
84
  "- 字数限制详见README,超过部分将截断",
 
57
  "🔊Generate speaker.pt": "🔊生成 speaker.pt",
58
  "Save .pt file": "保存.pt文件",
59
  "Save to File": "保存到文件",
60
+ "🎤Test voice": "🎤试语",
61
+ "Test Voice": "试语",
62
  "Current Seed": "当前种子",
63
  "Output Audio": "输出音频",
64
  "Merger": "融合",
 
79
  "README": "README",
80
  "readme": "readme",
81
  "changelog": "changelog",
82
+ "💼Speaker file": "💼音色文件",
83
  "TTS_STYLE_GUIDE": ["后缀为 _p 表示带prompt,效果更强但是影响质量"],
84
  "SSML_SPLITER_GUIDE": [
85
  "- 字数限制详见README,超过部分将截断",
modules/api/impl/google_api.py CHANGED
@@ -14,7 +14,7 @@ from modules import generate_audio as generate
14
  from modules.speaker import speaker_mgr
15
 
16
 
17
- from modules.ssml import parse_ssml
18
  from modules.SynthesizeSegments import (
19
  SynthesizeSegments,
20
  combine_audio_segments,
@@ -65,6 +65,8 @@ async def google_text_synthesize(request: GoogleTextSynthesizeRequest):
65
  audioConfig = request.audioConfig
66
 
67
  # 提取参数
 
 
68
  language_code = voice.languageCode
69
  voice_name = voice.name
70
  infer_seed = voice.seed or 42
@@ -86,9 +88,8 @@ async def google_text_synthesize(request: GoogleTextSynthesizeRequest):
86
  # TODO maybe need to change the sample rate
87
  sample_rate = 24000
88
 
89
- # TODO 使用 speaker
90
- spk = speaker_mgr.get_speaker(voice_name)
91
- if spk is None:
92
  raise HTTPException(
93
  status_code=400, detail="The specified voice name is not supported."
94
  )
@@ -120,7 +121,8 @@ async def google_text_synthesize(request: GoogleTextSynthesizeRequest):
120
 
121
  elif input.ssml:
122
  # 处理SSML合成逻辑
123
- segments = parse_ssml(input.ssml)
 
124
  for seg in segments:
125
  seg["text"] = text_normalize(seg["text"], is_end=True)
126
 
@@ -171,7 +173,11 @@ async def google_text_synthesize(request: GoogleTextSynthesizeRequest):
171
  import logging
172
 
173
  logging.exception(e)
174
- raise HTTPException(status_code=500, detail=str(e))
 
 
 
 
175
 
176
 
177
  def setup(app: APIManager):
 
14
  from modules.speaker import speaker_mgr
15
 
16
 
17
+ from modules.ssml_parser.SSMLParser import create_ssml_parser
18
  from modules.SynthesizeSegments import (
19
  SynthesizeSegments,
20
  combine_audio_segments,
 
65
  audioConfig = request.audioConfig
66
 
67
  # 提取参数
68
+
69
+ # TODO 这个也许应该传给 normalizer
70
  language_code = voice.languageCode
71
  voice_name = voice.name
72
  infer_seed = voice.seed or 42
 
88
  # TODO maybe need to change the sample rate
89
  sample_rate = 24000
90
 
91
+ # 虽然 calc_spk_style 可以解析 seed 形式,但是这个接口只准备支持 speakers list 中存在的 speaker
92
+ if speaker_mgr.get_speaker(voice_name) is None:
 
93
  raise HTTPException(
94
  status_code=400, detail="The specified voice name is not supported."
95
  )
 
121
 
122
  elif input.ssml:
123
  # 处理SSML合成逻辑
124
+ parser = create_ssml_parser()
125
+ segments = parser.parse(input.ssml)
126
  for seg in segments:
127
  seg["text"] = text_normalize(seg["text"], is_end=True)
128
 
 
173
  import logging
174
 
175
  logging.exception(e)
176
+
177
+ if isinstance(e, HTTPException):
178
+ raise e
179
+ else:
180
+ raise HTTPException(status_code=500, detail=str(e))
181
 
182
 
183
  def setup(app: APIManager):
modules/api/impl/openai_api.py CHANGED
@@ -115,7 +115,11 @@ async def openai_speech_api(
115
  import logging
116
 
117
  logging.exception(e)
118
- raise HTTPException(status_code=500, detail=str(e))
 
 
 
 
119
 
120
 
121
  class TranscribeSegment(BaseModel):
 
115
  import logging
116
 
117
  logging.exception(e)
118
+
119
+ if isinstance(e, HTTPException):
120
+ raise e
121
+ else:
122
+ raise HTTPException(status_code=500, detail=str(e))
123
 
124
 
125
  class TranscribeSegment(BaseModel):
modules/api/impl/refiner_api.py CHANGED
@@ -42,7 +42,11 @@ async def refiner_prompt_post(request: RefineTextRequest):
42
  import logging
43
 
44
  logging.exception(e)
45
- raise HTTPException(status_code=500, detail=str(e))
 
 
 
 
46
 
47
 
48
  def setup(api_manager: APIManager):
 
42
  import logging
43
 
44
  logging.exception(e)
45
+
46
+ if isinstance(e, HTTPException):
47
+ raise e
48
+ else:
49
+ raise HTTPException(status_code=500, detail=str(e))
50
 
51
 
52
  def setup(api_manager: APIManager):
modules/api/impl/ssml_api.py CHANGED
@@ -7,7 +7,7 @@ from fastapi.responses import FileResponse
7
 
8
 
9
  from modules.normalization import text_normalize
10
- from modules.ssml import parse_ssml
11
  from modules.SynthesizeSegments import (
12
  SynthesizeSegments,
13
  combine_audio_segments,
@@ -34,7 +34,7 @@ async def synthesize_ssml(
34
  ):
35
  try:
36
  ssml = request.ssml
37
- format = request.format
38
  batch_size = request.batch_size
39
 
40
  if batch_size < 1:
@@ -42,10 +42,16 @@ async def synthesize_ssml(
42
  status_code=400, detail="Batch size must be greater than 0."
43
  )
44
 
45
- if not ssml:
46
  raise HTTPException(status_code=400, detail="SSML content is required.")
47
 
48
- segments = parse_ssml(ssml)
 
 
 
 
 
 
49
  for seg in segments:
50
  seg["text"] = text_normalize(seg["text"], is_end=True)
51
 
@@ -63,7 +69,11 @@ async def synthesize_ssml(
63
  import logging
64
 
65
  logging.exception(e)
66
- raise HTTPException(status_code=500, detail=str(e))
 
 
 
 
67
 
68
 
69
  def setup(api_manager: APIManager):
 
7
 
8
 
9
  from modules.normalization import text_normalize
10
+ from modules.ssml_parser.SSMLParser import create_ssml_parser
11
  from modules.SynthesizeSegments import (
12
  SynthesizeSegments,
13
  combine_audio_segments,
 
34
  ):
35
  try:
36
  ssml = request.ssml
37
+ format = request.format.lower()
38
  batch_size = request.batch_size
39
 
40
  if batch_size < 1:
 
42
  status_code=400, detail="Batch size must be greater than 0."
43
  )
44
 
45
+ if not ssml or ssml == "":
46
  raise HTTPException(status_code=400, detail="SSML content is required.")
47
 
48
+ if format not in ["mp3", "wav"]:
49
+ raise HTTPException(
50
+ status_code=400, detail="Format must be 'mp3' or 'wav'."
51
+ )
52
+
53
+ parser = create_ssml_parser()
54
+ segments = parser.parse(ssml)
55
  for seg in segments:
56
  seg["text"] = text_normalize(seg["text"], is_end=True)
57
 
 
69
  import logging
70
 
71
  logging.exception(e)
72
+
73
+ if isinstance(e, HTTPException):
74
+ raise e
75
+ else:
76
+ raise HTTPException(status_code=500, detail=str(e))
77
 
78
 
79
  def setup(api_manager: APIManager):
modules/api/impl/tts_api.py CHANGED
@@ -44,6 +44,39 @@ class TTSParams(BaseModel):
44
 
45
  async def synthesize_tts(params: TTSParams = Depends()):
46
  try:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
47
  text = text_normalize(params.text, is_end=False)
48
 
49
  calc_params = api_utils.calc_spk_style(spk=params.spk, style=params.style)
@@ -87,7 +120,11 @@ async def synthesize_tts(params: TTSParams = Depends()):
87
  import logging
88
 
89
  logging.exception(e)
90
- raise HTTPException(status_code=500, detail=str(e))
 
 
 
 
91
 
92
 
93
  def setup(api_manager: APIManager):
 
44
 
45
  async def synthesize_tts(params: TTSParams = Depends()):
46
  try:
47
+ # Validate text
48
+ if not params.text.strip():
49
+ raise HTTPException(
50
+ status_code=422, detail="Text parameter cannot be empty"
51
+ )
52
+
53
+ # Validate temperature
54
+ if not (0 <= params.temperature <= 1):
55
+ raise HTTPException(
56
+ status_code=422, detail="Temperature must be between 0 and 1"
57
+ )
58
+
59
+ # Validate top_P
60
+ if not (0 <= params.top_P <= 1):
61
+ raise HTTPException(status_code=422, detail="top_P must be between 0 and 1")
62
+
63
+ # Validate top_K
64
+ if params.top_K <= 0:
65
+ raise HTTPException(
66
+ status_code=422, detail="top_K must be a positive integer"
67
+ )
68
+ if params.top_K > 100:
69
+ raise HTTPException(
70
+ status_code=422, detail="top_K must be less than or equal to 100"
71
+ )
72
+
73
+ # Validate format
74
+ if params.format not in ["mp3", "wav"]:
75
+ raise HTTPException(
76
+ status_code=422,
77
+ detail="Invalid format. Supported formats are mp3 and wav",
78
+ )
79
+
80
  text = text_normalize(params.text, is_end=False)
81
 
82
  calc_params = api_utils.calc_spk_style(spk=params.spk, style=params.style)
 
120
  import logging
121
 
122
  logging.exception(e)
123
+
124
+ if isinstance(e, HTTPException):
125
+ raise e
126
+ else:
127
+ raise HTTPException(status_code=500, detail=str(e))
128
 
129
 
130
  def setup(api_manager: APIManager):
modules/generate_audio.py CHANGED
@@ -79,7 +79,11 @@ def generate_audio_batch(
79
  params_infer_code["spk_emb"] = spk.emb
80
  logger.info(("spk", spk.name))
81
  else:
82
- raise ValueError(f"spk must be int or Speaker, but: <{type(spk)}> {spk}")
 
 
 
 
83
 
84
  logger.info(
85
  {
 
79
  params_infer_code["spk_emb"] = spk.emb
80
  logger.info(("spk", spk.name))
81
  else:
82
+ logger.warn(
83
+ f"spk must be int or Speaker, but: <{type(spk)}> {spk}, wiil set to default voice"
84
+ )
85
+ with SeedContext(2, True):
86
+ params_infer_code["spk_emb"] = chat_tts.sample_random_speaker()
87
 
88
  logger.info(
89
  {
modules/normalization.py CHANGED
@@ -5,6 +5,10 @@ from modules.utils.markdown import markdown_to_text
5
  from modules import models
6
  import re
7
 
 
 
 
 
8
 
9
  @lru_cache(maxsize=64)
10
  def is_chinese(text):
@@ -159,6 +163,8 @@ def replace_unk_tokens(text):
159
  """
160
  把不在字典里的字符替换为 " , "
161
  """
 
 
162
  chat_tts = models.load_chat_tts()
163
  if "tokenizer" not in chat_tts.pretrain_models:
164
  # 这个地方只有在 huggingface spaces 中才会触发
 
5
  from modules import models
6
  import re
7
 
8
+ # 是否关闭 unk token 检查
9
+ # NOTE: 单测的时候用于跳过模型加载
10
+ DISABLE_UNK_TOKEN_CHECK = False
11
+
12
 
13
  @lru_cache(maxsize=64)
14
  def is_chinese(text):
 
163
  """
164
  把不在字典里的字符替换为 " , "
165
  """
166
+ if DISABLE_UNK_TOKEN_CHECK:
167
+ return text
168
  chat_tts = models.load_chat_tts()
169
  if "tokenizer" not in chat_tts.pretrain_models:
170
  # 这个地方只有在 huggingface spaces 中才会触发
modules/ssml.py CHANGED
@@ -66,245 +66,3 @@ def apply_random_seed(attrs: dict):
66
  seed = random.randint(0, 2**32 - 1)
67
  attrs["seed"] = seed
68
  logger.info(f"random seed: {seed}")
69
-
70
-
71
- class NotSupportSSML(Exception):
72
- pass
73
-
74
-
75
- def parse_ssml(ssml: str) -> List[Dict[str, Any]]:
76
- root = etree.fromstring(ssml)
77
-
78
- ssml_version = root.get("version", "NONE")
79
- if ssml_version != "0.1":
80
- raise NotSupportSSML("Unsupported ssml version: {ssml_version}")
81
-
82
- segments = []
83
-
84
- for voice in root.findall(".//voice"):
85
- voice_attrs = {
86
- "spk": voice.get("spk"),
87
- "style": voice.get("style"),
88
- "seed": voice.get("seed"),
89
- "top_p": voice.get("top_p"),
90
- "top_k": voice.get("top_k"),
91
- "temp": voice.get("temp"),
92
- "prompt1": voice.get("prompt1"),
93
- "prompt2": voice.get("prompt2"),
94
- "prefix": voice.get("prefix"),
95
- "normalize": voice.get("normalize"),
96
- }
97
-
98
- voice_attrs = {k: v for k, v in voice_attrs.items() if v is not None}
99
-
100
- expand_spk(voice_attrs)
101
- expand_style(voice_attrs)
102
-
103
- merge_prompt(voice_attrs, voice)
104
- apply_random_seed(voice_attrs)
105
-
106
- voice_segments = []
107
-
108
- if voice_attrs.get("temp", "") == "min":
109
- # ref: https://github.com/2noise/ChatTTS/issues/123#issue-2326908144
110
- voice_attrs["temp"] = 0.000000000001
111
- if voice_attrs.get("temp", "") == "max":
112
- voice_attrs["temp"] = 1
113
-
114
- # 处理 voice 开头的文本
115
- if voice.text and voice.text.strip():
116
- voice_segments.append(
117
- {"text": voice.text.strip(), "attrs": voice_attrs.copy()}
118
- )
119
-
120
- # 处理 voice 内部的文本和 prosody 元素
121
- for node in voice.iterchildren():
122
- if node.tag == "prosody":
123
- prosody_attrs = voice_attrs.copy()
124
- new_attrs = {
125
- "rate": node.get("rate"),
126
- "volume": node.get("volume"),
127
- "pitch": node.get("pitch"),
128
- }
129
- prosody_attrs.update(
130
- {k: v for k, v in new_attrs.items() if v is not None}
131
- )
132
- expand_style(prosody_attrs)
133
- merge_prompt(prosody_attrs, node)
134
- apply_random_seed(voice_attrs)
135
-
136
- if node.text and node.text.strip():
137
- voice_segments.append(
138
- {"text": node.text.strip(), "attrs": prosody_attrs}
139
- )
140
- elif node.tag == "break":
141
- time_ms = int(node.get("time", "0").replace("ms", ""))
142
- segment = {"break": time_ms}
143
- voice_segments.append(segment)
144
-
145
- if node.tail and node.tail.strip():
146
- voice_segments.append(
147
- {"text": node.tail.strip(), "attrs": voice_attrs.copy()}
148
- )
149
-
150
- end_segment = voice_segments[-1]
151
- end_segment["is_end"] = True
152
-
153
- segments = segments + voice_segments
154
-
155
- logger.info(f"collect len(segments): {len(segments)}")
156
- # logger.info(f"segments: {json.dumps(segments, ensure_ascii=False)}")
157
-
158
- return segments
159
-
160
-
161
- if __name__ == "__main__":
162
- # 示例 SSML 输入
163
- ssml1 = """
164
- <speak version="0.1">
165
- <voice spk="20398768" seed="42" temp="min" top_p="0.9" top_k="20">
166
- 电影中梁朝伟扮演的陈永仁的
167
- <prosody volume="5">
168
- 编号27149
169
- </prosody>
170
- <prosody rate="2">
171
- 编号27149
172
- </prosody>
173
- <prosody pitch="-12">
174
- 编号27149
175
- </prosody>
176
- <prosody pitch="12">
177
- 编号27149
178
- </prosody>
179
- </voice>
180
- <voice spk="20398768" seed="42" speed="9">
181
- 编号27149
182
- </voice>
183
- <voice spk="20398768" seed="42">
184
- 电影中梁朝伟扮演的陈永仁的编号27149
185
- </voice>
186
- </speak>
187
- """
188
-
189
- ssml2 = """
190
- <speak version="0.1">
191
- <voice spk="Bob">
192
- 也可以合成多角色多情感的有声 [uv_break] 书 [uv_break] ,例如:
193
- </voice>
194
- <voice spk="Bob">
195
- 黛玉冷笑道:
196
- </voice>
197
- <voice spk="female2">
198
- 我说呢,亏了绊住,不然,早就飞了来了。
199
- </voice>
200
- <voice spk="Bob" speed="0">
201
- 宝玉道:
202
- </voice>
203
- <voice spk="Alice">
204
- “只许和你玩,替你解闷。不过偶然到他那里,就说这些闲话。”
205
- </voice>
206
- <voice spk="female2">
207
- “好没意思的话!去不去,关我什么事儿?又没叫你替我解闷儿,还许你不理我呢”
208
- </voice>
209
- <voice spk="Bob">
210
- 说着,便赌气回房去了。
211
- </voice>
212
- </speak>
213
- """
214
- ssml22 = """
215
- <speak version="0.1">
216
- <voice spk="Bob" style="narration-relaxed">
217
- 下面是一个 ChatTTS 用于合成多角色多情感的有声书示例
218
- </voice>
219
- <voice spk="Bob" style="narration-relaxed">
220
- 黛玉冷笑道:
221
- </voice>
222
- <voice spk="female2" style="angry">
223
- 我说呢 [uv_break] ,亏了绊住,不然,早就飞起来了。
224
- </voice>
225
- <voice spk="Bob" style="narration-relaxed">
226
- 宝玉道:
227
- </voice>
228
- <voice spk="Alice" style="unfriendly">
229
- “只许和你玩 [uv_break] ,替你解闷。不过偶然到他那里,就说这些闲话。”
230
- </voice>
231
- <voice spk="female2" style="angry">
232
- “好没意思的话![uv_break] 去不去,关我什么事儿? 又没叫你替我解闷儿 [uv_break],还许你不理我呢”
233
- </voice>
234
- <voice spk="Bob" style="narration-relaxed">
235
- 说着,便赌气回房去了。
236
- </voice>
237
- </speak>
238
- """
239
-
240
- ssml3 = """
241
- <speak version="0.1">
242
- <voice spk="Bob" style="angry">
243
- “你到底在想什么?这已经是第三次了!每次我都告诉你要按时完成任务,可你总是拖延。你知道这对整个团队有多大的影响吗?!”
244
- </voice>
245
- <voice spk="Bob" style="assistant">
246
- “你到底在想什么?这已经是第三次了!每次我都告诉你要按时完成任务,可你总是拖延。你知道这对整个团队有多大的影响吗?!”
247
- </voice>
248
- <voice spk="Bob" style="gentle">
249
- “你到底在想什么?这已经是第三次了!每次我都告诉你要按时完成任务,可你总是拖延。你知道这对整个团队有多大的影响吗?!”
250
- </voice>
251
- </speak>
252
- """
253
-
254
- ssml4 = """
255
- <speak version="0.1">
256
- <voice spk="Bob" style="narration-relaxed">
257
- 使用 prosody 控制生成文本的语速语调和音量,示例如下
258
-
259
- <prosody>
260
- 无任何限制将会继承父级voice配置进行生成
261
- </prosody>
262
- <prosody rate="1.5">
263
- 设置 rate 大于1表示加速,小于1为减速
264
- </prosody>
265
- <prosody pitch="6">
266
- 设置 pitch 调整音调,设置为6表示提高6个半音
267
- </prosody>
268
- <prosody volume="2">
269
- 设置 volume 调整音量,设置为2表示提高2个分贝
270
- </prosody>
271
-
272
- 在 voice 中无prosody包裹的文本即为默认生成状态下的语音
273
- </voice>
274
- </speak>
275
- """
276
-
277
- ssml5 = """
278
- <speak version="0.1">
279
- <voice spk="Bob" style="narration-relaxed">
280
- 使用 break 标签将会简单的
281
-
282
- <break time="500" />
283
-
284
- 插入一段空白到生成结果中
285
- </voice>
286
- </speak>
287
- """
288
-
289
- ssml6 = """
290
- <speak version="0.1">
291
- <voice spk="Bob" style="excited">
292
- temperature for sampling (may be overridden by style or speaker)
293
- <break time="500" />
294
- 温度值用于采样,这个值有可能被 style 或者 speaker 覆盖
295
- <break time="500" />
296
- temperature for sampling ,这个值有可能被 style 或者 speaker 覆盖
297
- <break time="500" />
298
- 温度值用于采样,(may be overridden by style or speaker)
299
- </voice>
300
- </speak>
301
- """
302
-
303
- segments = parse_ssml(ssml6)
304
-
305
- print(segments)
306
-
307
- # audio_segments = synthesize_segments(segments)
308
- # combined_audio = combine_audio_segments(audio_segments)
309
-
310
- # combined_audio.export("output.wav", format="wav")
 
66
  seed = random.randint(0, 2**32 - 1)
67
  attrs["seed"] = seed
68
  logger.info(f"random seed: {seed}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
modules/ssml_parser/SSMLParser.py CHANGED
@@ -29,6 +29,12 @@ class SSMLContext(Box):
29
  self.prompt2 = None
30
  self.prefix = None
31
 
 
 
 
 
 
 
32
 
33
  class SSMLSegment(Box):
34
  def __init__(self, text: str, attrs=SSMLContext()):
@@ -84,7 +90,7 @@ def create_ssml_parser():
84
 
85
  @parser.resolver("speak")
86
  def tag_speak(element, context, segments, parser):
87
- ctx = copy.deepcopy(context)
88
 
89
  version = element.get("version")
90
  if version != "0.1":
@@ -95,7 +101,7 @@ def create_ssml_parser():
95
 
96
  @parser.resolver("voice")
97
  def tag_voice(element, context, segments, parser):
98
- ctx = copy.deepcopy(context)
99
 
100
  ctx.spk = element.get("spk", ctx.spk)
101
  ctx.style = element.get("style", ctx.style)
@@ -131,7 +137,7 @@ def create_ssml_parser():
131
 
132
  @parser.resolver("prosody")
133
  def tag_prosody(element, context, segments, parser):
134
- ctx = copy.deepcopy(context)
135
 
136
  ctx.spk = element.get("spk", ctx.spk)
137
  ctx.style = element.get("style", ctx.style)
 
29
  self.prompt2 = None
30
  self.prefix = None
31
 
32
+ def clone(self):
33
+ ctx = SSMLContext()
34
+ for k, v in self.items():
35
+ ctx[k] = v
36
+ return ctx
37
+
38
 
39
  class SSMLSegment(Box):
40
  def __init__(self, text: str, attrs=SSMLContext()):
 
90
 
91
  @parser.resolver("speak")
92
  def tag_speak(element, context, segments, parser):
93
+ ctx = context.clone() if context is not None else SSMLContext()
94
 
95
  version = element.get("version")
96
  if version != "0.1":
 
101
 
102
  @parser.resolver("voice")
103
  def tag_voice(element, context, segments, parser):
104
+ ctx = context.clone() if context is not None else SSMLContext()
105
 
106
  ctx.spk = element.get("spk", ctx.spk)
107
  ctx.style = element.get("style", ctx.style)
 
137
 
138
  @parser.resolver("prosody")
139
  def tag_prosody(element, context, segments, parser):
140
+ ctx = context.clone() if context is not None else SSMLContext()
141
 
142
  ctx.spk = element.get("spk", ctx.spk)
143
  ctx.style = element.get("style", ctx.style)
modules/utils/git.py CHANGED
@@ -2,14 +2,25 @@ from functools import lru_cache
2
  import os
3
  import subprocess
4
 
 
5
  from modules.utils import constants
6
 
 
 
 
 
 
 
7
  git = os.environ.get("GIT", "git")
8
 
 
 
9
 
10
  @lru_cache()
11
  def commit_hash():
12
  try:
 
 
13
  return subprocess.check_output(
14
  [git, "-C", constants.ROOT_DIR, "rev-parse", "HEAD"],
15
  shell=False,
@@ -22,6 +33,8 @@ def commit_hash():
22
  @lru_cache()
23
  def git_tag():
24
  try:
 
 
25
  return subprocess.check_output(
26
  [git, "-C", constants.ROOT_DIR, "describe", "--tags"],
27
  shell=False,
@@ -44,6 +57,8 @@ def git_tag():
44
  @lru_cache()
45
  def branch_name():
46
  try:
 
 
47
  return subprocess.check_output(
48
  [git, "-C", constants.ROOT_DIR, "rev-parse", "--abbrev-ref", "HEAD"],
49
  shell=False,
 
2
  import os
3
  import subprocess
4
 
5
+
6
  from modules.utils import constants
7
 
8
+ # 用于判断是否在hf spaces
9
+ try:
10
+ import spaces
11
+ except:
12
+ spaces = None
13
+
14
  git = os.environ.get("GIT", "git")
15
 
16
+ in_hf_spaces = spaces is not None
17
+
18
 
19
  @lru_cache()
20
  def commit_hash():
21
  try:
22
+ if in_hf_spaces:
23
+ return "<hf>"
24
  return subprocess.check_output(
25
  [git, "-C", constants.ROOT_DIR, "rev-parse", "HEAD"],
26
  shell=False,
 
33
  @lru_cache()
34
  def git_tag():
35
  try:
36
+ if in_hf_spaces:
37
+ return "<hf>"
38
  return subprocess.check_output(
39
  [git, "-C", constants.ROOT_DIR, "describe", "--tags"],
40
  shell=False,
 
57
  @lru_cache()
58
  def branch_name():
59
  try:
60
+ if in_hf_spaces:
61
+ return "<hf>"
62
  return subprocess.check_output(
63
  [git, "-C", constants.ROOT_DIR, "rev-parse", "--abbrev-ref", "HEAD"],
64
  shell=False,
modules/utils/markdown.py CHANGED
@@ -46,6 +46,10 @@ class PlainTextRenderer(mistune.HTMLRenderer):
46
  # remove code
47
  return ""
48
 
 
 
 
 
49
 
50
  def markdown_to_text(markdown_text):
51
  renderer = PlainTextRenderer()
@@ -69,6 +73,9 @@ console.log(1)
69
  - 列表项 2
70
  - 列表项 3
71
 
 
 
 
72
  > 这是一个引用。
73
 
74
  `代码片段`
 
46
  # remove code
47
  return ""
48
 
49
+ def thematic_break(self) -> str:
50
+ # remove break
51
+ return "\n"
52
+
53
 
54
  def markdown_to_text(markdown_text):
55
  renderer = PlainTextRenderer()
 
73
  - 列表项 2
74
  - 列表项 3
75
 
76
+ 1. 第一
77
+ 2. 第二
78
+
79
  > 这是一个引用。
80
 
81
  `代码片段`
modules/webui/localization.py CHANGED
@@ -1,7 +1,9 @@
1
  import json
2
  import os
3
  import gradio as gr
 
4
 
 
5
 
6
  current_translation = {}
7
  localization_root = os.path.join(
@@ -24,11 +26,15 @@ def localization_js(filename):
24
  assert isinstance(v, str) or isinstance(
25
  v, list
26
  ), f"Value for key {k} is not a string or list"
 
 
27
  except Exception as e:
28
- print(str(e))
29
- print(f"Failed to load localization file {full_name}")
30
  else:
31
- print(f"Localization file {full_name} not found")
 
 
32
 
33
  # current_translation = {k: 'XXX' for k in current_translation.keys()} # use this to see if all texts are covered
34
 
 
1
  import json
2
  import os
3
  import gradio as gr
4
+ import logging
5
 
6
+ logger = logging.getLogger(__name__)
7
 
8
  current_translation = {}
9
  localization_root = os.path.join(
 
26
  assert isinstance(v, str) or isinstance(
27
  v, list
28
  ), f"Value for key {k} is not a string or list"
29
+
30
+ logger.info(f"Loaded localization file {full_name}")
31
  except Exception as e:
32
+ logger.warning(str(e))
33
+ logger.warning(f"Failed to load localization file {full_name}")
34
  else:
35
+ logger.warning(f"Localization file {full_name} does not exist")
36
+ else:
37
+ logger.warning(f"Localization file {filename} is not a string")
38
 
39
  # current_translation = {k: 'XXX' for k in current_translation.keys()} # use this to see if all texts are covered
40
 
modules/webui/speaker/speaker_editor.py ADDED
@@ -0,0 +1,147 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import torch
3
+ from modules.speaker import Speaker
4
+ from modules.hf import spaces
5
+ from modules.webui import webui_config
6
+ from modules.webui.webui_utils import tts_generate
7
+
8
+ import tempfile
9
+
10
+
11
+ @torch.inference_mode()
12
+ @spaces.GPU
13
+ def test_spk_voice(spk_file, text: str):
14
+ if spk_file == "" or spk_file is None:
15
+ return None
16
+ spk = Speaker.from_file(spk_file)
17
+ return tts_generate(
18
+ spk=spk,
19
+ text=text,
20
+ )
21
+
22
+
23
+ def speaker_editor_ui():
24
+ def on_generate(spk_file, name, gender, desc):
25
+ spk: Speaker = Speaker.from_file(spk_file)
26
+ spk.name = name
27
+ spk.gender = gender
28
+ spk.desc = desc
29
+
30
+ with tempfile.NamedTemporaryFile(delete=False, suffix=".pt") as tmp_file:
31
+ torch.save(spk, tmp_file)
32
+ tmp_file_path = tmp_file.name
33
+
34
+ return tmp_file_path
35
+
36
+ def create_test_voice_card(spk_file):
37
+ with gr.Group():
38
+ gr.Markdown("🎤Test voice")
39
+ with gr.Row():
40
+ test_voice_btn = gr.Button(
41
+ "Test Voice", variant="secondary", interactive=False
42
+ )
43
+
44
+ with gr.Column(scale=4):
45
+ test_text = gr.Textbox(
46
+ label="Test Text",
47
+ placeholder="Please input test text",
48
+ value=webui_config.localization.DEFAULT_SPEAKER_TEST_TEXT,
49
+ )
50
+ with gr.Row():
51
+ with gr.Column(scale=4):
52
+ output_audio = gr.Audio(label="Output Audio", format="mp3")
53
+
54
+ test_voice_btn.click(
55
+ fn=test_spk_voice,
56
+ inputs=[spk_file, test_text],
57
+ outputs=[output_audio],
58
+ )
59
+
60
+ return test_voice_btn
61
+
62
+ has_file = gr.State(False)
63
+
64
+ # TODO 也许需要写个说明?
65
+ # gr.Markdown("SPEAKER_CREATOR_GUIDE")
66
+
67
+ with gr.Row():
68
+ with gr.Column(scale=2):
69
+ with gr.Group():
70
+ gr.Markdown("💼Speaker file")
71
+ spk_file = gr.File(label="*.pt file", file_types=[".pt"])
72
+
73
+ with gr.Group():
74
+ gr.Markdown("ℹ️Speaker info")
75
+ name_input = gr.Textbox(
76
+ label="Name",
77
+ placeholder="Enter speaker name",
78
+ value="*",
79
+ interactive=False,
80
+ )
81
+ gender_input = gr.Textbox(
82
+ label="Gender",
83
+ placeholder="Enter gender",
84
+ value="*",
85
+ interactive=False,
86
+ )
87
+ desc_input = gr.Textbox(
88
+ label="Description",
89
+ placeholder="Enter description",
90
+ value="*",
91
+ interactive=False,
92
+ )
93
+ with gr.Group():
94
+ gr.Markdown("🔊Generate speaker.pt")
95
+ generate_button = gr.Button("Save .pt file", interactive=False)
96
+ output_file = gr.File(label="Save to File")
97
+ with gr.Column(scale=5):
98
+ btn1 = create_test_voice_card(spk_file=spk_file)
99
+ btn2 = create_test_voice_card(spk_file=spk_file)
100
+ btn3 = create_test_voice_card(spk_file=spk_file)
101
+ btn4 = create_test_voice_card(spk_file=spk_file)
102
+
103
+ generate_button.click(
104
+ fn=on_generate,
105
+ inputs=[spk_file, name_input, gender_input, desc_input],
106
+ outputs=[output_file],
107
+ )
108
+
109
+ def spk_file_change(spk_file):
110
+ empty = spk_file is None or spk_file == ""
111
+ if empty:
112
+ return [
113
+ gr.Textbox(value="*", interactive=False),
114
+ gr.Textbox(value="*", interactive=False),
115
+ gr.Textbox(value="*", interactive=False),
116
+ gr.Button(interactive=False),
117
+ gr.Button(interactive=False),
118
+ gr.Button(interactive=False),
119
+ gr.Button(interactive=False),
120
+ gr.Button(interactive=False),
121
+ ]
122
+ spk: Speaker = Speaker.from_file(spk_file)
123
+ return [
124
+ gr.Textbox(value=spk.name, interactive=True),
125
+ gr.Textbox(value=spk.gender, interactive=True),
126
+ gr.Textbox(value=spk.describe, interactive=True),
127
+ gr.Button(interactive=True),
128
+ gr.Button(interactive=True),
129
+ gr.Button(interactive=True),
130
+ gr.Button(interactive=True),
131
+ gr.Button(interactive=True),
132
+ ]
133
+
134
+ spk_file.change(
135
+ fn=spk_file_change,
136
+ inputs=[spk_file],
137
+ outputs=[
138
+ name_input,
139
+ gender_input,
140
+ desc_input,
141
+ generate_button,
142
+ btn1,
143
+ btn2,
144
+ btn3,
145
+ btn4,
146
+ ],
147
+ )
modules/webui/speaker_tab.py CHANGED
@@ -1,5 +1,6 @@
1
  import gradio as gr
2
 
 
3
  from modules.webui.speaker.speaker_merger import create_speaker_merger
4
  from modules.webui.speaker.speaker_creator import speaker_creator_ui
5
 
@@ -7,6 +8,8 @@ from modules.webui.speaker.speaker_creator import speaker_creator_ui
7
  def create_speaker_panel():
8
 
9
  with gr.Tabs():
 
 
10
  with gr.TabItem("Creator"):
11
  speaker_creator_ui()
12
  with gr.TabItem("Merger"):
 
1
  import gradio as gr
2
 
3
+ from modules.webui.speaker.speaker_editor import speaker_editor_ui
4
  from modules.webui.speaker.speaker_merger import create_speaker_merger
5
  from modules.webui.speaker.speaker_creator import speaker_creator_ui
6
 
 
8
  def create_speaker_panel():
9
 
10
  with gr.Tabs():
11
+ with gr.Tab("Editor"):
12
+ speaker_editor_ui()
13
  with gr.TabItem("Creator"):
14
  speaker_creator_ui()
15
  with gr.TabItem("Merger"):
modules/webui/ssml/podcast_tab.py CHANGED
@@ -7,45 +7,65 @@ from modules.webui import webui_utils
7
  from modules.hf import spaces
8
 
9
  podcast_default_case = [
10
- [1, "female2", "你好,欢迎收听今天的播客内容。今天我们要聊的是中华料理。", "chat"],
11
- [2, "Alice", "嗨,我特别期待这个话题!中华料理真的是博大精深。", "chat"],
 
 
 
 
 
 
 
 
 
 
12
  [
13
  3,
14
  "Bob",
15
- "没错,中华料理有着几千年的历史,而且每个地区都有自己的特色菜。",
16
- "chat",
17
  ],
18
  [
19
  4,
20
  "female2",
21
- "那我们先从最有名的川菜开始吧。川菜以其麻辣著称,是很多人的最爱。",
22
- "chat",
23
  ],
24
  [
25
  5,
26
  "Alice",
27
- "对,我特别喜欢吃麻婆豆腐和辣子鸡。那种麻辣的感觉真是让人难以忘怀。",
28
- "chat",
29
  ],
30
  [
31
  6,
32
  "Bob",
33
- "除了川菜,粤菜也是很受欢迎的。粤菜讲究鲜美,像是白切鸡和蒸鱼都是经典。",
34
- "chat",
 
 
 
 
 
 
 
 
 
 
 
 
35
  ],
36
- [7, "female2", "对啊,粤菜的烹饪方式比较清淡,更注重食材本身的味道。", "chat"],
37
- [8, "Alice", "还有北京的京菜,像北京烤鸭,那可是来北京必吃的美食。", "chat"],
38
  [
39
  9,
40
  "Bob",
41
- "不仅如此,还有淮扬菜、湘菜、鲁菜等等,每个菜系都有其独特的风味。",
42
- "chat",
43
  ],
44
  [
45
  10,
46
  "female2",
47
- "对对对,像淮扬菜的狮子头,湘菜的剁椒鱼头,都是让人垂涎三尺的美味。",
48
- "chat",
49
  ],
50
  ]
51
 
@@ -111,10 +131,11 @@ def create_ssml_podcast_tab(ssml_input: gr.Textbox, tabs1: gr.Tabs, tabs2: gr.Ta
111
  script_table = gr.DataFrame(
112
  headers=["index", "speaker", "text", "style"],
113
  datatype=["number", "str", "str", "str"],
114
- interactive=False,
115
  wrap=True,
116
  value=podcast_default_case,
117
  row_count=(0, "dynamic"),
 
118
  )
119
 
120
  send_to_ssml_btn = gr.Button("📩Send to SSML", variant="primary")
 
7
  from modules.hf import spaces
8
 
9
  podcast_default_case = [
10
+ [
11
+ 1,
12
+ "female2",
13
+ "你好,欢迎收听今天的播客内容。今天我们要聊的是中华料理。 [lbreak]",
14
+ "podcast_p",
15
+ ],
16
+ [
17
+ 2,
18
+ "Alice",
19
+ "嗨,我特别期待这个话题!中华料理真的是博大精深。 [lbreak]",
20
+ "podcast_p",
21
+ ],
22
  [
23
  3,
24
  "Bob",
25
+ "没错,中华料理有着几千年的历史,而且每个地区都有自己的特色菜。 [lbreak]",
26
+ "podcast_p",
27
  ],
28
  [
29
  4,
30
  "female2",
31
+ "那我们先从最有名的川菜开始吧。川菜以其麻辣著称,是很多人的最爱。 [lbreak]",
32
+ "podcast_p",
33
  ],
34
  [
35
  5,
36
  "Alice",
37
+ "对,我特别喜欢吃麻婆豆腐和辣子鸡。那种麻辣的感觉真是让人难以忘怀。 [lbreak]",
38
+ "podcast_p",
39
  ],
40
  [
41
  6,
42
  "Bob",
43
+ "除了川菜,粤菜也是很受欢迎的。粤菜讲究鲜美,像是白切鸡和蒸鱼都是经典。 [lbreak]",
44
+ "podcast_p",
45
+ ],
46
+ [
47
+ 7,
48
+ "female2",
49
+ "对啊,粤菜的烹饪方式比较清淡,更注重食材本身的味道。 [lbreak]",
50
+ "podcast_p",
51
+ ],
52
+ [
53
+ 8,
54
+ "Alice",
55
+ "还有北京的京菜,像北京烤鸭,那可是来北京必吃的美食。 [lbreak]",
56
+ "podcast_p",
57
  ],
 
 
58
  [
59
  9,
60
  "Bob",
61
+ "不仅如此,还有淮扬菜、湘菜、鲁菜等等,每个菜系都有其独特的风味。 [lbreak]",
62
+ "podcast_p",
63
  ],
64
  [
65
  10,
66
  "female2",
67
+ "对对对,像淮扬菜的狮子头,湘菜的剁椒鱼头,都是让人垂涎三尺的美味。 [lbreak]",
68
+ "podcast_p",
69
  ],
70
  ]
71
 
 
131
  script_table = gr.DataFrame(
132
  headers=["index", "speaker", "text", "style"],
133
  datatype=["number", "str", "str", "str"],
134
+ interactive=True,
135
  wrap=True,
136
  value=podcast_default_case,
137
  row_count=(0, "dynamic"),
138
+ col_count=(4, "fixed"),
139
  )
140
 
141
  send_to_ssml_btn = gr.Button("📩Send to SSML", variant="primary")
modules/webui/tts_tab.py CHANGED
@@ -91,7 +91,9 @@ def create_tts_interface():
91
  )
92
 
93
  with gr.Tab(label="Upload"):
94
- spk_file_upload = gr.File(label="Speaker (Upload)")
 
 
95
 
96
  gr.Markdown("📝Speaker info")
97
  infos = gr.Markdown("empty")
 
91
  )
92
 
93
  with gr.Tab(label="Upload"):
94
+ spk_file_upload = gr.File(
95
+ label="Speaker (Upload)", file_types=[".pt"]
96
+ )
97
 
98
  gr.Markdown("📝Speaker info")
99
  infos = gr.Markdown("empty")
modules/webui/webui_utils.py CHANGED
@@ -93,13 +93,11 @@ def apply_audio_enhance(audio_data, sr, enable_denoise, enable_enhance):
93
  tensor = torch.from_numpy(audio_data).float().squeeze().cpu()
94
  enhancer = load_enhancer(device)
95
 
96
- if enable_enhance:
97
  lambd = 0.9 if enable_denoise else 0.1
98
  tensor, sr = enhancer.enhance(
99
  tensor, sr, tau=0.5, nfe=64, solver="rk4", lambd=lambd, device=device
100
  )
101
- elif enable_denoise:
102
- tensor, sr = enhancer.denoise(tensor, sr)
103
 
104
  audio_data = tensor.cpu().numpy()
105
  return audio_data, int(sr)
 
93
  tensor = torch.from_numpy(audio_data).float().squeeze().cpu()
94
  enhancer = load_enhancer(device)
95
 
96
+ if enable_enhance or enable_denoise:
97
  lambd = 0.9 if enable_denoise else 0.1
98
  tensor, sr = enhancer.enhance(
99
  tensor, sr, tau=0.5, nfe=64, solver="rk4", lambd=lambd, device=device
100
  )
 
 
101
 
102
  audio_data = tensor.cpu().numpy()
103
  return audio_data, int(sr)
webui.py CHANGED
@@ -84,7 +84,6 @@ if __name__ == "__main__":
84
  parser.add_argument(
85
  "--language",
86
  type=str,
87
- default="zh-CN",
88
  help="Set the default language for the webui",
89
  )
90
  args = parser.parse_args()
@@ -106,7 +105,7 @@ if __name__ == "__main__":
106
  device_id = get_and_update_env(args, "device_id", None, str)
107
  use_cpu = get_and_update_env(args, "use_cpu", [], list)
108
  compile = get_and_update_env(args, "compile", False, bool)
109
- language = get_and_update_env(args, "language", False, bool)
110
 
111
  webui_config.experimental = get_and_update_env(
112
  args, "webui_experimental", False, bool
@@ -115,8 +114,6 @@ if __name__ == "__main__":
115
  webui_config.ssml_max = get_and_update_env(args, "ssml_max_len", 5000, int)
116
  webui_config.max_batch_size = get_and_update_env(args, "max_batch_size", 8, int)
117
 
118
- config.runtime_env_vars.language = "zh-CN"
119
-
120
  webui_init()
121
  demo = create_interface()
122
 
 
84
  parser.add_argument(
85
  "--language",
86
  type=str,
 
87
  help="Set the default language for the webui",
88
  )
89
  args = parser.parse_args()
 
105
  device_id = get_and_update_env(args, "device_id", None, str)
106
  use_cpu = get_and_update_env(args, "use_cpu", [], list)
107
  compile = get_and_update_env(args, "compile", False, bool)
108
+ language = get_and_update_env(args, "language", "zh-CN", str)
109
 
110
  webui_config.experimental = get_and_update_env(
111
  args, "webui_experimental", False, bool
 
114
  webui_config.ssml_max = get_and_update_env(args, "ssml_max_len", 5000, int)
115
  webui_config.max_batch_size = get_and_update_env(args, "max_batch_size", 8, int)
116
 
 
 
117
  webui_init()
118
  demo = create_interface()
119