zhzluke96 commited on
Commit
374f426
·
1 Parent(s): 650b56c
modules/config.py CHANGED
@@ -5,8 +5,11 @@ from modules.utils.JsonObject import JsonObject
5
 
6
  from modules.utils import git
7
 
 
8
  runtime_env_vars = JsonObject({})
9
 
 
 
10
  api = None
11
 
12
  versions = JsonObject(
 
5
 
6
  from modules.utils import git
7
 
8
+ # TODO impl RuntimeEnvVars() class
9
  runtime_env_vars = JsonObject({})
10
 
11
+ auto_gc = True
12
+
13
  api = None
14
 
15
  versions = JsonObject(
modules/denoise.py ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ from audio_denoiser.AudioDenoiser import AudioDenoiser
2
+ import torch
3
+ import torchaudio
4
+
5
+
6
+ class TTSAudioDenoiser:
7
+ pass
modules/devices/devices.py CHANGED
@@ -120,12 +120,14 @@ dtype_decoder: torch.dtype = torch.float32
120
 
121
 
122
  def reset_device():
 
 
 
 
 
 
 
123
  if config.runtime_env_vars.half:
124
- global dtype
125
- global dtype_dvae
126
- global dtype_vocos
127
- global dtype_gpt
128
- global dtype_decoder
129
  dtype = torch.float16
130
  dtype_dvae = torch.float16
131
  dtype_vocos = torch.float16
@@ -133,15 +135,21 @@ def reset_device():
133
  dtype_decoder = torch.float16
134
 
135
  logger.info("Using half precision: torch.float16")
 
 
 
 
 
 
 
 
136
 
137
- if (
138
- config.runtime_env_vars.device_id is not None
139
- or config.runtime_env_vars.use_cpu is not None
140
- ):
141
- global device
142
  device = get_optimal_device()
143
 
144
- logger.info(f"Using device: {device}")
145
 
146
 
147
  @lru_cache
 
120
 
121
 
122
  def reset_device():
123
+ global device
124
+ global dtype
125
+ global dtype_dvae
126
+ global dtype_vocos
127
+ global dtype_gpt
128
+ global dtype_decoder
129
+
130
  if config.runtime_env_vars.half:
 
 
 
 
 
131
  dtype = torch.float16
132
  dtype_dvae = torch.float16
133
  dtype_vocos = torch.float16
 
135
  dtype_decoder = torch.float16
136
 
137
  logger.info("Using half precision: torch.float16")
138
+ else:
139
+ dtype = torch.float32
140
+ dtype_dvae = torch.float32
141
+ dtype_vocos = torch.float32
142
+ dtype_gpt = torch.float32
143
+ dtype_decoder = torch.float32
144
+
145
+ logger.info("Using full precision: torch.float32")
146
 
147
+ if config.runtime_env_vars.use_cpu == "all":
148
+ device = cpu
149
+ else:
 
 
150
  device = get_optimal_device()
151
 
152
+ logger.info(f"Using device: {device}")
153
 
154
 
155
  @lru_cache
modules/generate_audio.py CHANGED
@@ -7,6 +7,7 @@ from modules.utils.SeedContext import SeedContext
7
  from modules import models, config
8
 
9
  import logging
 
10
 
11
  from modules.devices import devices
12
  from typing import Union
@@ -100,7 +101,9 @@ def generate_audio_batch(
100
 
101
  sample_rate = 24000
102
 
103
- devices.torch_gc()
 
 
104
 
105
  return [(sample_rate, np.array(wav).flatten().astype(np.float32)) for wav in wavs]
106
 
 
7
  from modules import models, config
8
 
9
  import logging
10
+ import gc
11
 
12
  from modules.devices import devices
13
  from typing import Union
 
101
 
102
  sample_rate = 24000
103
 
104
+ if config.auto_gc:
105
+ devices.torch_gc()
106
+ gc.collect()
107
 
108
  return [(sample_rate, np.array(wav).flatten().astype(np.float32)) for wav in wavs]
109
 
modules/hf.py ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # 给huggingface space写的兼容代码
2
+
3
+ try:
4
+ import spaces
5
+ except:
6
+
7
+ class NoneSpaces:
8
+ def __init__(self):
9
+ pass
10
+
11
+ def GPU(self, fn):
12
+ return fn
13
+
14
+ spaces = NoneSpaces()
modules/models.py CHANGED
@@ -51,7 +51,8 @@ def initialize_chat_tts():
51
  def load_chat_tts():
52
  if chat_tts is None:
53
  initialize_chat_tts().join()
54
- load_event.wait()
 
55
  return chat_tts
56
 
57
 
 
51
  def load_chat_tts():
52
  if chat_tts is None:
53
  initialize_chat_tts().join()
54
+ if chat_tts is None:
55
+ raise Exception("Failed to load ChatTTS models")
56
  return chat_tts
57
 
58
 
modules/utils/SeedContext.py CHANGED
@@ -2,6 +2,9 @@ import torch
2
  import random
3
  import numpy as np
4
  from modules.utils import rng
 
 
 
5
 
6
 
7
  def deterministic(seed=0):
@@ -59,8 +62,11 @@ class SeedContext:
59
  try:
60
  deterministic(self.seed)
61
  except Exception as e:
62
- raise ValueError(
63
- f"Seed must be an integer, but: <{type(self.seed)}> {self.seed}"
 
 
 
64
  )
65
 
66
  def __exit__(self, exc_type, exc_value, traceback):
 
2
  import random
3
  import numpy as np
4
  from modules.utils import rng
5
+ import logging
6
+
7
+ logger = logging.getLogger(__name__)
8
 
9
 
10
  def deterministic(seed=0):
 
62
  try:
63
  deterministic(self.seed)
64
  except Exception as e:
65
+ # raise ValueError(
66
+ # f"Seed must be an integer, but: <{type(self.seed)}> {self.seed}"
67
+ # )
68
+ logger.warning(
69
+ f"Deterministic field, with: <{type(self.seed)}> {self.seed}"
70
  )
71
 
72
  def __exit__(self, exc_type, exc_value, traceback):
modules/webui/__init__.py ADDED
File without changes
modules/webui/app.py ADDED
@@ -0,0 +1,110 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ import os
3
+
4
+ import torch
5
+ import gradio as gr
6
+
7
+ from modules import config
8
+
9
+ from modules.webui.tts_tab import create_tts_interface
10
+ from modules.webui.ssml_tab import create_ssml_interface
11
+ from modules.webui.spliter_tab import create_spliter_tab
12
+ from modules.webui.speaker_tab import create_speaker_panel
13
+ from modules.webui.readme_tab import create_readme_tab
14
+
15
+ logger = logging.getLogger(__name__)
16
+
17
+ logging.basicConfig(
18
+ level=os.getenv("LOG_LEVEL", "INFO"),
19
+ format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
20
+ )
21
+
22
+
23
+ def webui_init():
24
+ # fix: If the system proxy is enabled in the Windows system, you need to skip these
25
+ os.environ["NO_PROXY"] = "localhost,127.0.0.1,0.0.0.0"
26
+
27
+ torch._dynamo.config.cache_size_limit = 64
28
+ torch._dynamo.config.suppress_errors = True
29
+ torch.set_float32_matmul_precision("high")
30
+
31
+ logger.info("WebUI module initialized")
32
+
33
+
34
+ def create_app_footer():
35
+ gradio_version = gr.__version__
36
+ git_tag = config.versions.git_tag
37
+ git_commit = config.versions.git_commit
38
+ git_branch = config.versions.git_branch
39
+ python_version = config.versions.python_version
40
+ torch_version = config.versions.torch_version
41
+
42
+ config.versions.gradio_version = gradio_version
43
+
44
+ gr.Markdown(
45
+ f"""
46
+ 🍦 [ChatTTS-Forge](https://github.com/lenML/ChatTTS-Forge)
47
+ version: [{git_tag}](https://github.com/lenML/ChatTTS-Forge/commit/{git_commit}) | branch: `{git_branch}` | python: `{python_version}` | torch: `{torch_version}`
48
+ """
49
+ )
50
+
51
+
52
+ def create_interface():
53
+
54
+ js_func = """
55
+ function refresh() {
56
+ const url = new URL(window.location);
57
+
58
+ if (url.searchParams.get('__theme') !== 'dark') {
59
+ url.searchParams.set('__theme', 'dark');
60
+ window.location.href = url.href;
61
+ }
62
+ }
63
+ """
64
+
65
+ head_js = """
66
+ <script>
67
+ </script>
68
+ """
69
+
70
+ with gr.Blocks(js=js_func, head=head_js, title="ChatTTS Forge WebUI") as demo:
71
+ css = """
72
+ <style>
73
+ .big-button {
74
+ height: 80px;
75
+ }
76
+ #input_title div.eta-bar {
77
+ display: none !important; transform: none !important;
78
+ }
79
+ footer {
80
+ display: none !important;
81
+ }
82
+ </style>
83
+ """
84
+
85
+ gr.HTML(css)
86
+ with gr.Tabs() as tabs:
87
+ with gr.TabItem("TTS"):
88
+ create_tts_interface()
89
+
90
+ with gr.TabItem("SSML", id="ssml"):
91
+ ssml_input = create_ssml_interface()
92
+
93
+ with gr.TabItem("Spilter"):
94
+ create_spliter_tab(ssml_input, tabs=tabs)
95
+
96
+ if config.runtime_env_vars.webui_experimental:
97
+ with gr.TabItem("Speaker"):
98
+ create_speaker_panel()
99
+ with gr.TabItem("Denoise"):
100
+ gr.Markdown("🚧 Under construction")
101
+ with gr.TabItem("Inpainting"):
102
+ gr.Markdown("🚧 Under construction")
103
+ with gr.TabItem("ASR"):
104
+ gr.Markdown("🚧 Under construction")
105
+
106
+ with gr.TabItem("README"):
107
+ create_readme_tab()
108
+
109
+ create_app_footer()
110
+ return demo
modules/webui/asr_tab.py ADDED
File without changes
modules/webui/denoise_tab.py ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ import gradio as gr
2
+
3
+
4
+ def create_denoise_tab():
5
+ pass
modules/webui/examples.py ADDED
@@ -0,0 +1,164 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ example_texts = [
2
+ {
3
+ "text": "大🍌,一条大🍌,嘿,你的感觉真的很奇妙 [lbreak]",
4
+ },
5
+ {"text": "Big 🍌, a big 🍌, hey, your feeling is really wonderful [lbreak]"},
6
+ {
7
+ "text": """
8
+ # 这是 markdown 标题
9
+
10
+ ```
11
+ 代码块将跳过
12
+ ```
13
+
14
+ - **文本标准化**:
15
+ - **Markdown**: 自动检测处理 markdown 格式文本。
16
+ - **数字转写**: 自动将数字转为模型可识别的文本。
17
+ - **Emoji 适配**: 自动翻译 emoji 为可读文本。
18
+ - **基于分词器**: 基于 tokenizer 预处理文本,覆盖模型所有不支持字符范围。
19
+ - **中英文识别**: 适配英文环境。
20
+ """
21
+ },
22
+ {
23
+ "text": "天气预报显示,今天会有小雨,请大家出门时记得带伞。降温的天气也提醒我们要适时添衣保暖 [lbreak]",
24
+ },
25
+ {
26
+ "text": "公司的年度总结会议将在下周三举行,请各部门提前准备好相关材料,确保会议顺利进行 [lbreak]",
27
+ },
28
+ {
29
+ "text": "今天的午餐菜单包括烤鸡、沙拉和蔬菜汤,大家可以根据自己的口味选择适合的菜品 [lbreak]",
30
+ },
31
+ {
32
+ "text": "请注意,电梯将在下午两点进行例行维护,预计需要一个小时的时间,请大家在此期间使用楼梯 [lbreak]",
33
+ },
34
+ {
35
+ "text": "图书馆新到了一批书籍,涵盖了文学、科学和历史等多个领域,欢迎大家前来借阅 [lbreak]",
36
+ },
37
+ {
38
+ "text": "电影中梁朝伟扮演的陈永仁的编号27149 [lbreak]",
39
+ },
40
+ {
41
+ "text": "这块黄金重达324.75克 [lbreak]",
42
+ },
43
+ {
44
+ "text": "我们班的最高总分为583分 [lbreak]",
45
+ },
46
+ {
47
+ "text": "12~23 [lbreak]",
48
+ },
49
+ {
50
+ "text": "-1.5~2 [lbreak]",
51
+ },
52
+ {
53
+ "text": "她出生于86年8月18日,她弟弟出生于1995年3月1日 [lbreak]",
54
+ },
55
+ {
56
+ "text": "等会请在12:05请通知我 [lbreak]",
57
+ },
58
+ {
59
+ "text": "今天的最低气温达到-10°C [lbreak]",
60
+ },
61
+ {
62
+ "text": "现场有7/12的观众投出了赞成票 [lbreak]",
63
+ },
64
+ {
65
+ "text": "明天有62%的概率降雨 [lbreak]",
66
+ },
67
+ {
68
+ "text": "随便来几个价格12块5,34.5元,20.1万 [lbreak]",
69
+ },
70
+ {
71
+ "text": "这是固话0421-33441122 [lbreak]",
72
+ },
73
+ {
74
+ "text": "这是手机+86 18544139121 [lbreak]",
75
+ },
76
+ ]
77
+
78
+ ssml_example1 = """
79
+ <speak version="0.1">
80
+ <voice spk="Bob" seed="42" style="narration-relaxed">
81
+ 下面是一个 ChatTTS 用于合成多角色多情感的有声书示例[lbreak]
82
+ </voice>
83
+ <voice spk="Bob" seed="42" style="narration-relaxed">
84
+ 黛玉冷笑道:[lbreak]
85
+ </voice>
86
+ <voice spk="female2" seed="42" style="angry">
87
+ 我说呢 [uv_break] ,亏了绊住,不然,早就飞起来了[lbreak]
88
+ </voice>
89
+ <voice spk="Bob" seed="42" style="narration-relaxed">
90
+ 宝玉道:[lbreak]
91
+ </voice>
92
+ <voice spk="Alice" seed="42" style="unfriendly">
93
+ “只许和你玩 [uv_break] ,替你解闷。不过偶然到他那里,就说这些闲话。”[lbreak]
94
+ </voice>
95
+ <voice spk="female2" seed="42" style="angry">
96
+ “好没意思的话![uv_break] 去不去,关我什么事儿? 又没叫你替我解闷儿 [uv_break],还许你不理我呢” [lbreak]
97
+ </voice>
98
+ <voice spk="Bob" seed="42" style="narration-relaxed">
99
+ 说着,便赌气回房去了 [lbreak]
100
+ </voice>
101
+ </speak>
102
+ """
103
+ ssml_example2 = """
104
+ <speak version="0.1">
105
+ <voice spk="Bob" seed="42" style="narration-relaxed">
106
+ 使用 prosody 控制生成文本的语速语调和音量,示例如下 [lbreak]
107
+
108
+ <prosody>
109
+ 无任何限制将会继承父级voice配置进行生成 [lbreak]
110
+ </prosody>
111
+ <prosody rate="1.5">
112
+ 设置 rate 大于1表示加速,小于1为减速 [lbreak]
113
+ </prosody>
114
+ <prosody pitch="6">
115
+ 设置 pitch 调整音调,设置为6表示提高6个半音 [lbreak]
116
+ </prosody>
117
+ <prosody volume="2">
118
+ 设置 volume 调整音量,设置为2表示提高2个分贝 [lbreak]
119
+ </prosody>
120
+
121
+ 在 voice 中无prosody包裹的文本即为默认生成状态下的语音 [lbreak]
122
+ </voice>
123
+ </speak>
124
+ """
125
+ ssml_example3 = """
126
+ <speak version="0.1">
127
+ <voice spk="Bob" seed="42" style="narration-relaxed">
128
+ 使用 break 标签将会简单的 [lbreak]
129
+
130
+ <break time="500" />
131
+
132
+ 插入一段空白到生成结果中 [lbreak]
133
+ </voice>
134
+ </speak>
135
+ """
136
+
137
+ ssml_example4 = """
138
+ <speak version="0.1">
139
+ <voice spk="Bob" seed="42" style="excited">
140
+ temperature for sampling (may be overridden by style or speaker) [lbreak]
141
+ <break time="500" />
142
+ 温度值用于采样,这个值有可能被 style 或者 speaker 覆盖 [lbreak]
143
+ <break time="500" />
144
+ temperature for sampling ,这个值有可能被 style 或者 speaker 覆盖 [lbreak]
145
+ <break time="500" />
146
+ 温度值用于采样,(may be overridden by style or speaker) [lbreak]
147
+ </voice>
148
+ </speak>
149
+ """
150
+
151
+ ssml_examples = [
152
+ ssml_example1,
153
+ ssml_example2,
154
+ ssml_example3,
155
+ ssml_example4,
156
+ ]
157
+
158
+ default_ssml = """
159
+ <speak version="0.1">
160
+ <voice spk="Bob" seed="42" style="narration-relaxed">
161
+ 这里是一个简单的 SSML 示例 [lbreak]
162
+ </voice>
163
+ </speak>
164
+ """
modules/webui/readme_tab.py ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+
3
+
4
+ def read_local_readme():
5
+ with open("README.md", "r", encoding="utf-8") as file:
6
+ content = file.read()
7
+ content = content[content.index("# ") :]
8
+ return content
9
+
10
+
11
+ def create_readme_tab():
12
+ readme_content = read_local_readme()
13
+ gr.Markdown(readme_content)
modules/webui/speaker_tab.py ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+
3
+ from modules.webui.webui_utils import get_speakers
4
+
5
+
6
+ # 显示 a b c d 四个选择框,选择一个或多个,然后可以试音,并导出
7
+ def create_speaker_panel():
8
+ speakers = get_speakers()
9
+
10
+ def get_speaker_show_name(spk):
11
+ pass
12
+
13
+ gr.Markdown("🚧 Under construction")
modules/webui/spliter_tab.py ADDED
@@ -0,0 +1,168 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import torch
3
+ from modules.normalization import text_normalize
4
+ from modules.webui.webui_utils import (
5
+ get_speakers,
6
+ get_styles,
7
+ split_long_text,
8
+ synthesize_ssml,
9
+ )
10
+ from modules.webui import webui_config
11
+ from modules.webui.examples import ssml_examples, default_ssml
12
+
13
+
14
+ def merge_dataframe_to_ssml(dataframe, spk, style, seed):
15
+ if style == "*auto":
16
+ style = None
17
+ if spk == "-1" or spk == -1:
18
+ spk = None
19
+ if seed == -1 or seed == "-1":
20
+ seed = None
21
+
22
+ ssml = ""
23
+ indent = " " * 2
24
+
25
+ for i, row in dataframe.iterrows():
26
+ ssml += f"{indent}<voice"
27
+ if spk:
28
+ ssml += f' spk="{spk}"'
29
+ if style:
30
+ ssml += f' style="{style}"'
31
+ if seed:
32
+ ssml += f' seed="{seed}"'
33
+ ssml += ">\n"
34
+ ssml += f"{indent}{indent}{text_normalize(row[1])}\n"
35
+ ssml += f"{indent}</voice>\n"
36
+ return f"<speak version='0.1'>\n{ssml}</speak>"
37
+
38
+
39
+ # 长文本处理
40
+ # 可以输入长文本,并选择切割方法,切割之后可以将拼接的SSML发送到SSML tab
41
+ # 根据 。 句号切割,切割之后显示到 data table
42
+ def create_spliter_tab(ssml_input, tabs):
43
+ speakers = get_speakers()
44
+
45
+ def get_speaker_show_name(spk):
46
+ if spk.gender == "*" or spk.gender == "":
47
+ return spk.name
48
+ return f"{spk.gender} : {spk.name}"
49
+
50
+ speaker_names = ["*random"] + [
51
+ get_speaker_show_name(speaker) for speaker in speakers
52
+ ]
53
+
54
+ styles = ["*auto"] + [s.get("name") for s in get_styles()]
55
+
56
+ with gr.Row():
57
+ with gr.Column(scale=1):
58
+ # 选择说话人 选择风格 选择seed
59
+ with gr.Group():
60
+ gr.Markdown("🗣️Speaker")
61
+ spk_input_text = gr.Textbox(
62
+ label="Speaker (Text or Seed)",
63
+ value="female2",
64
+ show_label=False,
65
+ )
66
+ spk_input_dropdown = gr.Dropdown(
67
+ choices=speaker_names,
68
+ interactive=True,
69
+ value="female : female2",
70
+ show_label=False,
71
+ )
72
+ spk_rand_button = gr.Button(
73
+ value="🎲",
74
+ variant="secondary",
75
+ )
76
+ with gr.Group():
77
+ gr.Markdown("🎭Style")
78
+ style_input_dropdown = gr.Dropdown(
79
+ choices=styles,
80
+ interactive=True,
81
+ show_label=False,
82
+ value="*auto",
83
+ )
84
+ with gr.Group():
85
+ gr.Markdown("🗣️Seed")
86
+ infer_seed_input = gr.Number(
87
+ value=42,
88
+ label="Inference Seed",
89
+ show_label=False,
90
+ minimum=-1,
91
+ maximum=2**32 - 1,
92
+ )
93
+ infer_seed_rand_button = gr.Button(
94
+ value="🎲",
95
+ variant="secondary",
96
+ )
97
+
98
+ send_btn = gr.Button("📩Send to SSML", variant="primary")
99
+
100
+ with gr.Column(scale=3):
101
+ with gr.Group():
102
+ gr.Markdown("📝Long Text Input")
103
+ gr.Markdown("- 此页面用于处理超长文本")
104
+ gr.Markdown("- 切割后,可以选择说话人、风格、seed,然后发送到SSML")
105
+ long_text_input = gr.Textbox(
106
+ label="Long Text Input",
107
+ lines=10,
108
+ placeholder="输入长文本",
109
+ elem_id="long-text-input",
110
+ show_label=False,
111
+ )
112
+ long_text_split_button = gr.Button("🔪Split Text")
113
+
114
+ with gr.Row():
115
+ with gr.Column(scale=3):
116
+ with gr.Group():
117
+ gr.Markdown("🎨Output")
118
+ long_text_output = gr.DataFrame(
119
+ headers=["index", "text", "length"],
120
+ datatype=["number", "str", "number"],
121
+ elem_id="long-text-output",
122
+ interactive=False,
123
+ wrap=True,
124
+ value=[],
125
+ )
126
+
127
+ spk_input_dropdown.change(
128
+ fn=lambda x: x.startswith("*") and "-1" or x.split(":")[-1].strip(),
129
+ inputs=[spk_input_dropdown],
130
+ outputs=[spk_input_text],
131
+ )
132
+ spk_rand_button.click(
133
+ lambda x: int(torch.randint(0, 2**32 - 1, (1,)).item()),
134
+ inputs=[spk_input_text],
135
+ outputs=[spk_input_text],
136
+ )
137
+ infer_seed_rand_button.click(
138
+ lambda x: int(torch.randint(0, 2**32 - 1, (1,)).item()),
139
+ inputs=[infer_seed_input],
140
+ outputs=[infer_seed_input],
141
+ )
142
+ long_text_split_button.click(
143
+ split_long_text,
144
+ inputs=[long_text_input],
145
+ outputs=[long_text_output],
146
+ )
147
+
148
+ infer_seed_rand_button.click(
149
+ lambda x: int(torch.randint(0, 2**32 - 1, (1,)).item()),
150
+ inputs=[infer_seed_input],
151
+ outputs=[infer_seed_input],
152
+ )
153
+
154
+ send_btn.click(
155
+ merge_dataframe_to_ssml,
156
+ inputs=[
157
+ long_text_output,
158
+ spk_input_text,
159
+ style_input_dropdown,
160
+ infer_seed_input,
161
+ ],
162
+ outputs=[ssml_input],
163
+ )
164
+
165
+ def change_tab():
166
+ return gr.Tabs(selected="ssml")
167
+
168
+ send_btn.click(change_tab, inputs=[], outputs=[tabs])
modules/webui/ssml_tab.py ADDED
@@ -0,0 +1,55 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from modules.webui.webui_utils import (
3
+ synthesize_ssml,
4
+ )
5
+ from modules.webui import webui_config
6
+ from modules.webui.examples import ssml_examples, default_ssml
7
+
8
+
9
+ def create_ssml_interface():
10
+ with gr.Row():
11
+ with gr.Column(scale=3):
12
+ with gr.Group():
13
+ gr.Markdown("📝SSML Input")
14
+ gr.Markdown(f"- 最长{webui_config.ssml_max:,}字符,超过会被截断")
15
+ gr.Markdown("- 尽量保证使用相同的 seed")
16
+ gr.Markdown(
17
+ "- 关于SSML可以看这个 [文档](https://github.com/lenML/ChatTTS-Forge/blob/main/docs/SSML.md)"
18
+ )
19
+ ssml_input = gr.Textbox(
20
+ label="SSML Input",
21
+ lines=10,
22
+ value=default_ssml,
23
+ placeholder="输入 SSML 或选择示例",
24
+ elem_id="ssml_input",
25
+ show_label=False,
26
+ )
27
+ ssml_button = gr.Button("🔊Synthesize SSML", variant="primary")
28
+ with gr.Column(scale=1):
29
+ with gr.Group():
30
+ # 参数
31
+ gr.Markdown("🎛️Parameters")
32
+ # batch size
33
+ batch_size_input = gr.Slider(
34
+ label="Batch Size",
35
+ value=4,
36
+ minimum=1,
37
+ maximum=webui_config.max_batch_size,
38
+ step=1,
39
+ )
40
+ with gr.Group():
41
+ gr.Markdown("🎄Examples")
42
+ gr.Examples(
43
+ examples=ssml_examples,
44
+ inputs=[ssml_input],
45
+ )
46
+
47
+ ssml_output = gr.Audio(label="Generated Audio")
48
+
49
+ ssml_button.click(
50
+ synthesize_ssml,
51
+ inputs=[ssml_input, batch_size_input],
52
+ outputs=ssml_output,
53
+ )
54
+
55
+ return ssml_input
modules/webui/tts_tab.py ADDED
@@ -0,0 +1,248 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import torch
3
+ from modules.webui.webui_utils import (
4
+ get_speakers,
5
+ get_styles,
6
+ refine_text,
7
+ tts_generate,
8
+ )
9
+ from modules.webui import webui_config
10
+ from modules.webui.examples import example_texts
11
+ from modules import config
12
+
13
+
14
+ def create_tts_interface():
15
+ speakers = get_speakers()
16
+
17
+ def get_speaker_show_name(spk):
18
+ if spk.gender == "*" or spk.gender == "":
19
+ return spk.name
20
+ return f"{spk.gender} : {spk.name}"
21
+
22
+ speaker_names = ["*random"] + [
23
+ get_speaker_show_name(speaker) for speaker in speakers
24
+ ]
25
+
26
+ styles = ["*auto"] + [s.get("name") for s in get_styles()]
27
+
28
+ history = []
29
+
30
+ with gr.Row():
31
+ with gr.Column(scale=1):
32
+ with gr.Group():
33
+ gr.Markdown("🎛️Sampling")
34
+ temperature_input = gr.Slider(
35
+ 0.01, 2.0, value=0.3, step=0.01, label="Temperature"
36
+ )
37
+ top_p_input = gr.Slider(0.1, 1.0, value=0.7, step=0.1, label="Top P")
38
+ top_k_input = gr.Slider(1, 50, value=20, step=1, label="Top K")
39
+ batch_size_input = gr.Slider(
40
+ 1,
41
+ webui_config.max_batch_size,
42
+ value=4,
43
+ step=1,
44
+ label="Batch Size",
45
+ )
46
+
47
+ with gr.Row():
48
+ with gr.Group():
49
+ gr.Markdown("🎭Style")
50
+ gr.Markdown("- 后缀为 `_p` 表示带prompt,效果更强但是影响质量")
51
+ style_input_dropdown = gr.Dropdown(
52
+ choices=styles,
53
+ # label="Choose Style",
54
+ interactive=True,
55
+ show_label=False,
56
+ value="*auto",
57
+ )
58
+ with gr.Row():
59
+ with gr.Group():
60
+ gr.Markdown("🗣️Speaker")
61
+ with gr.Tabs():
62
+ with gr.Tab(label="Pick"):
63
+ spk_input_text = gr.Textbox(
64
+ label="Speaker (Text or Seed)",
65
+ value="female2",
66
+ show_label=False,
67
+ )
68
+ spk_input_dropdown = gr.Dropdown(
69
+ choices=speaker_names,
70
+ # label="Choose Speaker",
71
+ interactive=True,
72
+ value="female : female2",
73
+ show_label=False,
74
+ )
75
+ spk_rand_button = gr.Button(
76
+ value="🎲",
77
+ # tooltip="Random Seed",
78
+ variant="secondary",
79
+ )
80
+ spk_input_dropdown.change(
81
+ fn=lambda x: x.startswith("*")
82
+ and "-1"
83
+ or x.split(":")[-1].strip(),
84
+ inputs=[spk_input_dropdown],
85
+ outputs=[spk_input_text],
86
+ )
87
+ spk_rand_button.click(
88
+ lambda x: str(torch.randint(0, 2**32 - 1, (1,)).item()),
89
+ inputs=[spk_input_text],
90
+ outputs=[spk_input_text],
91
+ )
92
+
93
+ if config.runtime_env_vars.webui_experimental:
94
+ with gr.Tab(label="Upload"):
95
+ spk_input_upload = gr.File(label="Speaker (Upload)")
96
+ # TODO 读取 speaker
97
+ # spk_input_upload.change(
98
+ # fn=lambda x: x.read().decode("utf-8"),
99
+ # inputs=[spk_input_upload],
100
+ # outputs=[spk_input_text],
101
+ # )
102
+ with gr.Group():
103
+ gr.Markdown("💃Inference Seed")
104
+ infer_seed_input = gr.Number(
105
+ value=42,
106
+ label="Inference Seed",
107
+ show_label=False,
108
+ minimum=-1,
109
+ maximum=2**32 - 1,
110
+ )
111
+ infer_seed_rand_button = gr.Button(
112
+ value="🎲",
113
+ # tooltip="Random Seed",
114
+ variant="secondary",
115
+ )
116
+ use_decoder_input = gr.Checkbox(
117
+ value=True, label="Use Decoder", visible=False
118
+ )
119
+ with gr.Group():
120
+ gr.Markdown("🔧Prompt engineering")
121
+ prompt1_input = gr.Textbox(label="Prompt 1")
122
+ prompt2_input = gr.Textbox(label="Prompt 2")
123
+ prefix_input = gr.Textbox(label="Prefix")
124
+
125
+ if config.runtime_env_vars.webui_experimental:
126
+ prompt_audio = gr.File(label="prompt_audio")
127
+
128
+ infer_seed_rand_button.click(
129
+ lambda x: int(torch.randint(0, 2**32 - 1, (1,)).item()),
130
+ inputs=[infer_seed_input],
131
+ outputs=[infer_seed_input],
132
+ )
133
+ with gr.Column(scale=3):
134
+ with gr.Row():
135
+ with gr.Column(scale=4):
136
+ with gr.Group():
137
+ input_title = gr.Markdown(
138
+ "📝Text Input",
139
+ elem_id="input-title",
140
+ )
141
+ gr.Markdown(
142
+ f"- 字数限制{webui_config.tts_max:,}字,超过部分截断"
143
+ )
144
+ gr.Markdown("- 如果尾字吞字不读,可以试试结尾加上 `[lbreak]`")
145
+ gr.Markdown(
146
+ "- If the input text is all in English, it is recommended to check disable_normalize"
147
+ )
148
+ text_input = gr.Textbox(
149
+ show_label=False,
150
+ label="Text to Speech",
151
+ lines=10,
152
+ placeholder="输入文本或选择示例",
153
+ elem_id="text-input",
154
+ )
155
+ # TODO 字数统计,其实实现很好写,但是就是会触发loading...并且还要和后端交互...
156
+ # text_input.change(
157
+ # fn=lambda x: (
158
+ # f"📝Text Input ({len(x)} char)"
159
+ # if x
160
+ # else (
161
+ # "📝Text Input (0 char)"
162
+ # if not x
163
+ # else "📝Text Input (0 char)"
164
+ # )
165
+ # ),
166
+ # inputs=[text_input],
167
+ # outputs=[input_title],
168
+ # )
169
+ with gr.Row():
170
+ contorl_tokens = [
171
+ "[laugh]",
172
+ "[uv_break]",
173
+ "[v_break]",
174
+ "[lbreak]",
175
+ ]
176
+
177
+ for tk in contorl_tokens:
178
+ t_btn = gr.Button(tk)
179
+ t_btn.click(
180
+ lambda text, tk=tk: text + " " + tk,
181
+ inputs=[text_input],
182
+ outputs=[text_input],
183
+ )
184
+ with gr.Column(scale=1):
185
+ with gr.Group():
186
+ gr.Markdown("🎶Refiner")
187
+ refine_prompt_input = gr.Textbox(
188
+ label="Refine Prompt",
189
+ value="[oral_2][laugh_0][break_6]",
190
+ )
191
+ refine_button = gr.Button("✍️Refine Text")
192
+ # TODO 分割句子,使用当前配置拼接为SSML,然后发送到SSML tab
193
+ # send_button = gr.Button("📩Split and send to SSML")
194
+
195
+ with gr.Group():
196
+ gr.Markdown("🔊Generate")
197
+ disable_normalize_input = gr.Checkbox(
198
+ value=False, label="Disable Normalize"
199
+ )
200
+ tts_button = gr.Button(
201
+ "🔊Generate Audio",
202
+ variant="primary",
203
+ elem_classes="big-button",
204
+ )
205
+
206
+ with gr.Group():
207
+ gr.Markdown("🎄Examples")
208
+ sample_dropdown = gr.Dropdown(
209
+ choices=[sample["text"] for sample in example_texts],
210
+ show_label=False,
211
+ value=None,
212
+ interactive=True,
213
+ )
214
+ sample_dropdown.change(
215
+ fn=lambda x: x,
216
+ inputs=[sample_dropdown],
217
+ outputs=[text_input],
218
+ )
219
+
220
+ with gr.Group():
221
+ gr.Markdown("🎨Output")
222
+ tts_output = gr.Audio(label="Generated Audio")
223
+
224
+ refine_button.click(
225
+ refine_text,
226
+ inputs=[text_input, refine_prompt_input],
227
+ outputs=[text_input],
228
+ )
229
+
230
+ tts_button.click(
231
+ tts_generate,
232
+ inputs=[
233
+ text_input,
234
+ temperature_input,
235
+ top_p_input,
236
+ top_k_input,
237
+ spk_input_text,
238
+ infer_seed_input,
239
+ use_decoder_input,
240
+ prompt1_input,
241
+ prompt2_input,
242
+ prefix_input,
243
+ style_input_dropdown,
244
+ disable_normalize_input,
245
+ batch_size_input,
246
+ ],
247
+ outputs=tts_output,
248
+ )
modules/webui/webui_config.py ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ tts_max = 1000
2
+ ssml_max = 1000
3
+ spliter_threshold = 100
4
+ max_batch_size = 8
modules/webui/webui_utils.py ADDED
@@ -0,0 +1,169 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import logging
3
+ import sys
4
+
5
+ import numpy as np
6
+
7
+ from modules.devices import devices
8
+ from modules.synthesize_audio import synthesize_audio
9
+ from modules.hf import spaces
10
+ from modules.webui import webui_config
11
+
12
+ logging.basicConfig(
13
+ level=os.getenv("LOG_LEVEL", "INFO"),
14
+ format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
15
+ )
16
+
17
+
18
+ import gradio as gr
19
+
20
+ import torch
21
+
22
+ from modules.ssml import parse_ssml
23
+ from modules.SynthesizeSegments import SynthesizeSegments, combine_audio_segments
24
+
25
+ from modules.speaker import speaker_mgr
26
+ from modules.data import styles_mgr
27
+
28
+ from modules.api.utils import calc_spk_style
29
+ import modules.generate_audio as generate
30
+
31
+ from modules.normalization import text_normalize
32
+ from modules import refiner, config
33
+
34
+ from modules.utils import env, audio
35
+ from modules.SentenceSplitter import SentenceSplitter
36
+
37
+
38
+ def get_speakers():
39
+ return speaker_mgr.list_speakers()
40
+
41
+
42
+ def get_styles():
43
+ return styles_mgr.list_items()
44
+
45
+
46
+ def segments_length_limit(segments, total_max: int):
47
+ ret_segments = []
48
+ total_len = 0
49
+ for seg in segments:
50
+ if "text" not in seg:
51
+ continue
52
+ total_len += len(seg["text"])
53
+ if total_len > total_max:
54
+ break
55
+ ret_segments.append(seg)
56
+ return ret_segments
57
+
58
+
59
+ @torch.inference_mode()
60
+ @spaces.GPU
61
+ def synthesize_ssml(ssml: str, batch_size=4):
62
+ try:
63
+ batch_size = int(batch_size)
64
+ except Exception:
65
+ batch_size = 8
66
+
67
+ ssml = ssml.strip()
68
+
69
+ if ssml == "":
70
+ return None
71
+
72
+ segments = parse_ssml(ssml)
73
+ max_len = webui_config.ssml_max
74
+ segments = segments_length_limit(segments, max_len)
75
+
76
+ if len(segments) == 0:
77
+ return None
78
+
79
+ synthesize = SynthesizeSegments(batch_size=batch_size)
80
+ audio_segments = synthesize.synthesize_segments(segments)
81
+ combined_audio = combine_audio_segments(audio_segments)
82
+
83
+ return audio.pydub_to_np(combined_audio)
84
+
85
+
86
+ @torch.inference_mode()
87
+ @spaces.GPU
88
+ def tts_generate(
89
+ text,
90
+ temperature,
91
+ top_p,
92
+ top_k,
93
+ spk,
94
+ infer_seed,
95
+ use_decoder,
96
+ prompt1,
97
+ prompt2,
98
+ prefix,
99
+ style,
100
+ disable_normalize=False,
101
+ batch_size=4,
102
+ ):
103
+ try:
104
+ batch_size = int(batch_size)
105
+ except Exception:
106
+ batch_size = 4
107
+
108
+ max_len = webui_config.tts_max
109
+ text = text.strip()[0:max_len]
110
+
111
+ if text == "":
112
+ return None
113
+
114
+ if style == "*auto":
115
+ style = None
116
+
117
+ if isinstance(top_k, float):
118
+ top_k = int(top_k)
119
+
120
+ params = calc_spk_style(spk=spk, style=style)
121
+ spk = params.get("spk", spk)
122
+
123
+ infer_seed = infer_seed or params.get("seed", infer_seed)
124
+ temperature = temperature or params.get("temperature", temperature)
125
+ prefix = prefix or params.get("prefix", prefix)
126
+ prompt1 = prompt1 or params.get("prompt1", "")
127
+ prompt2 = prompt2 or params.get("prompt2", "")
128
+
129
+ infer_seed = np.clip(infer_seed, -1, 2**32 - 1, out=None, dtype=np.int64)
130
+ infer_seed = int(infer_seed)
131
+
132
+ if not disable_normalize:
133
+ text = text_normalize(text)
134
+
135
+ sample_rate, audio_data = synthesize_audio(
136
+ text=text,
137
+ temperature=temperature,
138
+ top_P=top_p,
139
+ top_K=top_k,
140
+ spk=spk,
141
+ infer_seed=infer_seed,
142
+ use_decoder=use_decoder,
143
+ prompt1=prompt1,
144
+ prompt2=prompt2,
145
+ prefix=prefix,
146
+ batch_size=batch_size,
147
+ )
148
+
149
+ audio_data = audio.audio_to_int16(audio_data)
150
+ return sample_rate, audio_data
151
+
152
+
153
+ @torch.inference_mode()
154
+ @spaces.GPU
155
+ def refine_text(text: str, prompt: str):
156
+ text = text_normalize(text)
157
+ return refiner.refine_text(text, prompt=prompt)
158
+
159
+
160
+ @torch.inference_mode()
161
+ @spaces.GPU
162
+ def split_long_text(long_text_input):
163
+ spliter = SentenceSplitter(webui_config.spliter_threshold)
164
+ sentences = spliter.parse(long_text_input)
165
+ sentences = [text_normalize(s) for s in sentences]
166
+ data = []
167
+ for i, text in enumerate(sentences):
168
+ data.append([i, text, len(text)])
169
+ return data
webui.py CHANGED
@@ -1,859 +1,10 @@
1
- try:
2
- import spaces
3
- except:
4
-
5
- class NoneSpaces:
6
- def __init__(self):
7
- pass
8
-
9
- def GPU(self, fn):
10
- return fn
11
-
12
- spaces = NoneSpaces()
13
-
14
  import os
15
- import logging
16
- import sys
17
-
18
- import numpy as np
19
-
20
  from modules.devices import devices
21
- from modules.synthesize_audio import synthesize_audio
22
-
23
- logging.basicConfig(
24
- level=os.getenv("LOG_LEVEL", "INFO"),
25
- format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
26
- )
27
-
28
-
29
- import gradio as gr
30
-
31
- import torch
32
-
33
- from modules.ssml import parse_ssml
34
- from modules.SynthesizeSegments import SynthesizeSegments, combine_audio_segments
35
-
36
- from modules.speaker import speaker_mgr
37
- from modules.data import styles_mgr
38
-
39
- from modules.api.utils import calc_spk_style
40
- import modules.generate_audio as generate
41
-
42
- from modules.normalization import text_normalize
43
- from modules import refiner, config, models
44
-
45
- from modules.utils import env, audio
46
- from modules.SentenceSplitter import SentenceSplitter
47
-
48
- # fix: If the system proxy is enabled in the Windows system, you need to skip these
49
- os.environ["NO_PROXY"] = "localhost,127.0.0.1,0.0.0.0"
50
-
51
- torch._dynamo.config.cache_size_limit = 64
52
- torch._dynamo.config.suppress_errors = True
53
- torch.set_float32_matmul_precision("high")
54
-
55
- webui_config = {
56
- "tts_max": 1000,
57
- "ssml_max": 5000,
58
- "spliter_threshold": 100,
59
- "max_batch_size": 8,
60
- }
61
-
62
-
63
- def get_speakers():
64
- return speaker_mgr.list_speakers()
65
-
66
-
67
- def get_styles():
68
- return styles_mgr.list_items()
69
-
70
-
71
- def segments_length_limit(segments, total_max: int):
72
- ret_segments = []
73
- total_len = 0
74
- for seg in segments:
75
- if "text" not in seg:
76
- continue
77
- total_len += len(seg["text"])
78
- if total_len > total_max:
79
- break
80
- ret_segments.append(seg)
81
- return ret_segments
82
-
83
-
84
- @torch.inference_mode()
85
- @spaces.GPU
86
- def synthesize_ssml(ssml: str, batch_size=4):
87
- try:
88
- batch_size = int(batch_size)
89
- except Exception:
90
- batch_size = 8
91
-
92
- ssml = ssml.strip()
93
-
94
- if ssml == "":
95
- return None
96
-
97
- segments = parse_ssml(ssml)
98
- max_len = webui_config["ssml_max"]
99
- segments = segments_length_limit(segments, max_len)
100
-
101
- if len(segments) == 0:
102
- return None
103
-
104
- models.load_chat_tts()
105
- synthesize = SynthesizeSegments(batch_size=batch_size)
106
- audio_segments = synthesize.synthesize_segments(segments)
107
- combined_audio = combine_audio_segments(audio_segments)
108
-
109
- return audio.pydub_to_np(combined_audio)
110
-
111
-
112
- @torch.inference_mode()
113
- @spaces.GPU
114
- def tts_generate(
115
- text,
116
- temperature,
117
- top_p,
118
- top_k,
119
- spk,
120
- infer_seed,
121
- use_decoder,
122
- prompt1,
123
- prompt2,
124
- prefix,
125
- style,
126
- disable_normalize=False,
127
- batch_size=4,
128
- ):
129
- try:
130
- batch_size = int(batch_size)
131
- except Exception:
132
- batch_size = 4
133
-
134
- max_len = webui_config["tts_max"]
135
- text = text.strip()[0:max_len]
136
-
137
- if text == "":
138
- return None
139
-
140
- if style == "*auto":
141
- style = None
142
-
143
- if isinstance(top_k, float):
144
- top_k = int(top_k)
145
-
146
- params = calc_spk_style(spk=spk, style=style)
147
- spk = params.get("spk", spk)
148
-
149
- infer_seed = infer_seed or params.get("seed", infer_seed)
150
- temperature = temperature or params.get("temperature", temperature)
151
- prefix = prefix or params.get("prefix", prefix)
152
- prompt1 = prompt1 or params.get("prompt1", "")
153
- prompt2 = prompt2 or params.get("prompt2", "")
154
-
155
- infer_seed = np.clip(infer_seed, -1, 2**32 - 1, out=None, dtype=np.int64)
156
- infer_seed = int(infer_seed)
157
-
158
- if not disable_normalize:
159
- text = text_normalize(text)
160
-
161
- models.load_chat_tts()
162
- sample_rate, audio_data = synthesize_audio(
163
- text=text,
164
- temperature=temperature,
165
- top_P=top_p,
166
- top_K=top_k,
167
- spk=spk,
168
- infer_seed=infer_seed,
169
- use_decoder=use_decoder,
170
- prompt1=prompt1,
171
- prompt2=prompt2,
172
- prefix=prefix,
173
- batch_size=batch_size,
174
- )
175
-
176
- audio_data = audio.audio_to_int16(audio_data)
177
- return sample_rate, audio_data
178
-
179
-
180
- @torch.inference_mode()
181
- @spaces.GPU
182
- def refine_text(text: str, prompt: str):
183
- text = text_normalize(text)
184
- return refiner.refine_text(text, prompt=prompt)
185
-
186
-
187
- def read_local_readme():
188
- with open("README.md", "r", encoding="utf-8") as file:
189
- content = file.read()
190
- content = content[content.index("# ") :]
191
- return content
192
-
193
-
194
- # 演示示例文本
195
- sample_texts = [
196
- {
197
- "text": "大🍌,一条大🍌,嘿,你的感觉真的很奇妙 [lbreak]",
198
- },
199
- {
200
- "text": "天气预报显示,今天会有小雨,请大家出门时记得带伞。降温的天气也提醒我们要适时添衣保暖 [lbreak]",
201
- },
202
- {
203
- "text": "公司的年度总结会议将在下周三举行,请各部门提前准备好相关材料,确保会议顺利进行 [lbreak]",
204
- },
205
- {
206
- "text": "今天的午餐菜单包括烤鸡、���拉和蔬菜汤,大家可以根据自己的口味选择适合的菜品 [lbreak]",
207
- },
208
- {
209
- "text": "请注意,电梯将在下午两点进行例行维护,预计需要一个小时的时间,请大家在此期间使用楼梯 [lbreak]",
210
- },
211
- {
212
- "text": "图书馆新到了一批书籍,涵盖了文学、科学和历史等多个领域,欢迎大家前来借阅 [lbreak]",
213
- },
214
- {
215
- "text": "电影中梁朝伟扮演的陈永仁的编号27149 [lbreak]",
216
- },
217
- {
218
- "text": "这块黄金重达324.75克 [lbreak]",
219
- },
220
- {
221
- "text": "我们班的最高总分为583分 [lbreak]",
222
- },
223
- {
224
- "text": "12~23 [lbreak]",
225
- },
226
- {
227
- "text": "-1.5~2 [lbreak]",
228
- },
229
- {
230
- "text": "她出生于86年8月18日,她弟弟出生于1995年3月1日 [lbreak]",
231
- },
232
- {
233
- "text": "等会请在12:05请通知我 [lbreak]",
234
- },
235
- {
236
- "text": "今天的最低气温达到-10°C [lbreak]",
237
- },
238
- {
239
- "text": "现场有7/12的观众投出了赞成票 [lbreak]",
240
- },
241
- {
242
- "text": "明天有62%的概率降雨 [lbreak]",
243
- },
244
- {
245
- "text": "随便来几个价格12块5,34.5元,20.1万 [lbreak]",
246
- },
247
- {
248
- "text": "这是固话0421-33441122 [lbreak]",
249
- },
250
- {
251
- "text": "这是手机+86 18544139121 [lbreak]",
252
- },
253
- ]
254
-
255
- ssml_example1 = """
256
- <speak version="0.1">
257
- <voice spk="Bob" seed="42" style="narration-relaxed">
258
- 下面是一个 ChatTTS 用于合成多角色多情感的有声书示例[lbreak]
259
- </voice>
260
- <voice spk="Bob" seed="42" style="narration-relaxed">
261
- 黛玉冷笑道:[lbreak]
262
- </voice>
263
- <voice spk="female2" seed="42" style="angry">
264
- 我说呢 [uv_break] ,亏了绊住,不然,早就飞起来了[lbreak]
265
- </voice>
266
- <voice spk="Bob" seed="42" style="narration-relaxed">
267
- 宝玉道:[lbreak]
268
- </voice>
269
- <voice spk="Alice" seed="42" style="unfriendly">
270
- “只许和你玩 [uv_break] ,替你解闷。不过偶然到他那里,就说这些闲话。”[lbreak]
271
- </voice>
272
- <voice spk="female2" seed="42" style="angry">
273
- “好没意思的话![uv_break] 去不去,关我什么事儿? 又没叫你替我解闷儿 [uv_break],还许你不理我呢” [lbreak]
274
- </voice>
275
- <voice spk="Bob" seed="42" style="narration-relaxed">
276
- 说着,便赌气回房去了 [lbreak]
277
- </voice>
278
- </speak>
279
- """
280
- ssml_example2 = """
281
- <speak version="0.1">
282
- <voice spk="Bob" seed="42" style="narration-relaxed">
283
- 使用 prosody 控制生成文本的语速语调和音量,示例如下 [lbreak]
284
-
285
- <prosody>
286
- 无任何限制将会继承父级voice配置进行生成 [lbreak]
287
- </prosody>
288
- <prosody rate="1.5">
289
- 设置 rate 大于1表示加速,小于1为减速 [lbreak]
290
- </prosody>
291
- <prosody pitch="6">
292
- 设置 pitch 调整音调,设置为6表示提高6个半音 [lbreak]
293
- </prosody>
294
- <prosody volume="2">
295
- 设置 volume 调整音量,设置为2表示提高2个分贝 [lbreak]
296
- </prosody>
297
-
298
- 在 voice 中无prosody包裹的文本即为默认生成状态下的语音 [lbreak]
299
- </voice>
300
- </speak>
301
- """
302
- ssml_example3 = """
303
- <speak version="0.1">
304
- <voice spk="Bob" seed="42" style="narration-relaxed">
305
- 使用 break 标签将会简单的 [lbreak]
306
-
307
- <break time="500" />
308
-
309
- 插入一段空白到生成结果中 [lbreak]
310
- </voice>
311
- </speak>
312
- """
313
-
314
- ssml_example4 = """
315
- <speak version="0.1">
316
- <voice spk="Bob" seed="42" style="excited">
317
- temperature for sampling (may be overridden by style or speaker) [lbreak]
318
- <break time="500" />
319
- 温度值用于采样,这个值有可能被 style 或者 speaker 覆盖 [lbreak]
320
- <break time="500" />
321
- temperature for sampling ,这个值有可能被 style 或者 speaker 覆盖 [lbreak]
322
- <break time="500" />
323
- 温度值用于采样,(may be overridden by style or speaker) [lbreak]
324
- </voice>
325
- </speak>
326
- """
327
-
328
- default_ssml = """
329
- <speak version="0.1">
330
- <voice spk="Bob" seed="42" style="narration-relaxed">
331
- 这里是一个简单的 SSML 示例 [lbreak]
332
- </voice>
333
- </speak>
334
- """
335
-
336
-
337
- def create_tts_interface():
338
- speakers = get_speakers()
339
-
340
- def get_speaker_show_name(spk):
341
- if spk.gender == "*" or spk.gender == "":
342
- return spk.name
343
- return f"{spk.gender} : {spk.name}"
344
-
345
- speaker_names = ["*random"] + [
346
- get_speaker_show_name(speaker) for speaker in speakers
347
- ]
348
-
349
- styles = ["*auto"] + [s.get("name") for s in get_styles()]
350
-
351
- history = []
352
-
353
- with gr.Row():
354
- with gr.Column(scale=1):
355
- with gr.Group():
356
- gr.Markdown("🎛️Sampling")
357
- temperature_input = gr.Slider(
358
- 0.01, 2.0, value=0.3, step=0.01, label="Temperature"
359
- )
360
- top_p_input = gr.Slider(0.1, 1.0, value=0.7, step=0.1, label="Top P")
361
- top_k_input = gr.Slider(1, 50, value=20, step=1, label="Top K")
362
- batch_size_input = gr.Slider(
363
- 1,
364
- webui_config["max_batch_size"],
365
- value=4,
366
- step=1,
367
- label="Batch Size",
368
- )
369
-
370
- with gr.Row():
371
- with gr.Group():
372
- gr.Markdown("🎭Style")
373
- gr.Markdown("- 后缀为 `_p` 表示带prompt,效果更强但是影响质量")
374
- style_input_dropdown = gr.Dropdown(
375
- choices=styles,
376
- # label="Choose Style",
377
- interactive=True,
378
- show_label=False,
379
- value="*auto",
380
- )
381
- with gr.Row():
382
- with gr.Group():
383
- gr.Markdown("🗣️Speaker (Name or Seed)")
384
- spk_input_text = gr.Textbox(
385
- label="Speaker (Text or Seed)",
386
- value="female2",
387
- show_label=False,
388
- )
389
- spk_input_dropdown = gr.Dropdown(
390
- choices=speaker_names,
391
- # label="Choose Speaker",
392
- interactive=True,
393
- value="female : female2",
394
- show_label=False,
395
- )
396
- spk_rand_button = gr.Button(
397
- value="🎲",
398
- # tooltip="Random Seed",
399
- variant="secondary",
400
- )
401
- spk_input_dropdown.change(
402
- fn=lambda x: x.startswith("*")
403
- and "-1"
404
- or x.split(":")[-1].strip(),
405
- inputs=[spk_input_dropdown],
406
- outputs=[spk_input_text],
407
- )
408
- spk_rand_button.click(
409
- lambda x: str(torch.randint(0, 2**32 - 1, (1,)).item()),
410
- inputs=[spk_input_text],
411
- outputs=[spk_input_text],
412
- )
413
- with gr.Group():
414
- gr.Markdown("💃Inference Seed")
415
- infer_seed_input = gr.Number(
416
- value=42,
417
- label="Inference Seed",
418
- show_label=False,
419
- minimum=-1,
420
- maximum=2**32 - 1,
421
- )
422
- infer_seed_rand_button = gr.Button(
423
- value="🎲",
424
- # tooltip="Random Seed",
425
- variant="secondary",
426
- )
427
- use_decoder_input = gr.Checkbox(
428
- value=True, label="Use Decoder", visible=False
429
- )
430
- with gr.Group():
431
- gr.Markdown("🔧Prompt engineering")
432
- prompt1_input = gr.Textbox(label="Prompt 1")
433
- prompt2_input = gr.Textbox(label="Prompt 2")
434
- prefix_input = gr.Textbox(label="Prefix")
435
-
436
- infer_seed_rand_button.click(
437
- lambda x: int(torch.randint(0, 2**32 - 1, (1,)).item()),
438
- inputs=[infer_seed_input],
439
- outputs=[infer_seed_input],
440
- )
441
- with gr.Column(scale=3):
442
- with gr.Row():
443
- with gr.Column(scale=4):
444
- with gr.Group():
445
- input_title = gr.Markdown(
446
- "📝Text Input",
447
- elem_id="input-title",
448
- )
449
- gr.Markdown(
450
- f"- 字数限制{webui_config['tts_max']:,}字,超过部分截断"
451
- )
452
- gr.Markdown("- 如果尾字吞字不读,可以试试结尾加上 `[lbreak]`")
453
- gr.Markdown(
454
- "- If the input text is all in English, it is recommended to check disable_normalize"
455
- )
456
- text_input = gr.Textbox(
457
- show_label=False,
458
- label="Text to Speech",
459
- lines=10,
460
- placeholder="输入文本或选择示例",
461
- elem_id="text-input",
462
- )
463
- # TODO 字数统计,其实实现很好写,但是就是会触发loading...并且还要和后端交互...
464
- # text_input.change(
465
- # fn=lambda x: (
466
- # f"📝Text Input ({len(x)} char)"
467
- # if x
468
- # else (
469
- # "📝Text Input (0 char)"
470
- # if not x
471
- # else "📝Text Input (0 char)"
472
- # )
473
- # ),
474
- # inputs=[text_input],
475
- # outputs=[input_title],
476
- # )
477
- with gr.Row():
478
- contorl_tokens = [
479
- "[laugh]",
480
- "[uv_break]",
481
- "[v_break]",
482
- "[lbreak]",
483
- ]
484
-
485
- for tk in contorl_tokens:
486
- t_btn = gr.Button(tk)
487
- t_btn.click(
488
- lambda text, tk=tk: text + " " + tk,
489
- inputs=[text_input],
490
- outputs=[text_input],
491
- )
492
- with gr.Column(scale=1):
493
- with gr.Group():
494
- gr.Markdown("🎶Refiner")
495
- refine_prompt_input = gr.Textbox(
496
- label="Refine Prompt",
497
- value="[oral_2][laugh_0][break_6]",
498
- )
499
- refine_button = gr.Button("✍️Refine Text")
500
- # TODO 分割句子,使用当前配置拼接为SSML,然后发送到SSML tab
501
- # send_button = gr.Button("📩Split and send to SSML")
502
-
503
- with gr.Group():
504
- gr.Markdown("🔊Generate")
505
- disable_normalize_input = gr.Checkbox(
506
- value=False, label="Disable Normalize"
507
- )
508
- tts_button = gr.Button(
509
- "🔊Generate Audio",
510
- variant="primary",
511
- elem_classes="big-button",
512
- )
513
-
514
- with gr.Group():
515
- gr.Markdown("🎄Examples")
516
- sample_dropdown = gr.Dropdown(
517
- choices=[sample["text"] for sample in sample_texts],
518
- show_label=False,
519
- value=None,
520
- interactive=True,
521
- )
522
- sample_dropdown.change(
523
- fn=lambda x: x,
524
- inputs=[sample_dropdown],
525
- outputs=[text_input],
526
- )
527
-
528
- with gr.Group():
529
- gr.Markdown("🎨Output")
530
- tts_output = gr.Audio(label="Generated Audio")
531
-
532
- refine_button.click(
533
- refine_text,
534
- inputs=[text_input, refine_prompt_input],
535
- outputs=[text_input],
536
- )
537
-
538
- tts_button.click(
539
- tts_generate,
540
- inputs=[
541
- text_input,
542
- temperature_input,
543
- top_p_input,
544
- top_k_input,
545
- spk_input_text,
546
- infer_seed_input,
547
- use_decoder_input,
548
- prompt1_input,
549
- prompt2_input,
550
- prefix_input,
551
- style_input_dropdown,
552
- disable_normalize_input,
553
- batch_size_input,
554
- ],
555
- outputs=tts_output,
556
- )
557
-
558
-
559
- def create_ssml_interface():
560
- examples = [
561
- ssml_example1,
562
- ssml_example2,
563
- ssml_example3,
564
- ssml_example4,
565
- ]
566
-
567
- with gr.Row():
568
- with gr.Column(scale=3):
569
- with gr.Group():
570
- gr.Markdown("📝SSML Input")
571
- gr.Markdown(f"- 最长{webui_config['ssml_max']:,}字符,超过会被截断")
572
- gr.Markdown("- 尽量保证使用相同的 seed")
573
- gr.Markdown(
574
- "- 关于SSML可以看这个 [文档](https://github.com/lenML/ChatTTS-Forge/blob/main/docs/SSML.md)"
575
- )
576
- ssml_input = gr.Textbox(
577
- label="SSML Input",
578
- lines=10,
579
- value=default_ssml,
580
- placeholder="输入 SSML 或选择示例",
581
- elem_id="ssml_input",
582
- show_label=False,
583
- )
584
- ssml_button = gr.Button("🔊Synthesize SSML", variant="primary")
585
- with gr.Column(scale=1):
586
- with gr.Group():
587
- # 参数
588
- gr.Markdown("🎛️Parameters")
589
- # batch size
590
- batch_size_input = gr.Slider(
591
- label="Batch Size",
592
- value=4,
593
- minimum=1,
594
- maximum=webui_config["max_batch_size"],
595
- step=1,
596
- )
597
- with gr.Group():
598
- gr.Markdown("🎄Examples")
599
- gr.Examples(
600
- examples=examples,
601
- inputs=[ssml_input],
602
- )
603
-
604
- ssml_output = gr.Audio(label="Generated Audio")
605
-
606
- ssml_button.click(
607
- synthesize_ssml,
608
- inputs=[ssml_input, batch_size_input],
609
- outputs=ssml_output,
610
- )
611
-
612
- return ssml_input
613
-
614
-
615
- # NOTE: 这个其实是需要GPU的...但是spaces会自动卸载,所以不太好使,具体处理在text_normalize中兼容
616
- # @spaces.GPU
617
- def split_long_text(long_text_input):
618
- spliter = SentenceSplitter(webui_config["spliter_threshold"])
619
- sentences = spliter.parse(long_text_input)
620
- sentences = [text_normalize(s) for s in sentences]
621
- data = []
622
- for i, text in enumerate(sentences):
623
- data.append([i, text, len(text)])
624
- return data
625
-
626
-
627
- def merge_dataframe_to_ssml(dataframe, spk, style, seed):
628
- if style == "*auto":
629
- style = None
630
- if spk == "-1" or spk == -1:
631
- spk = None
632
- if seed == -1 or seed == "-1":
633
- seed = None
634
-
635
- ssml = ""
636
- indent = " " * 2
637
-
638
- for i, row in dataframe.iterrows():
639
- ssml += f"{indent}<voice"
640
- if spk:
641
- ssml += f' spk="{spk}"'
642
- if style:
643
- ssml += f' style="{style}"'
644
- if seed:
645
- ssml += f' seed="{seed}"'
646
- ssml += ">\n"
647
- ssml += f"{indent}{indent}{text_normalize(row[1])}\n"
648
- ssml += f"{indent}</voice>\n"
649
- return f"<speak version='0.1'>\n{ssml}</speak>"
650
-
651
-
652
- # 长文本处理
653
- # 可以输入长文本,并选择切割方法,切割之后可以将拼接的SSML发送到SSML tab
654
- # 根据 。 句号切割,切割之后显示到 data table
655
- def create_long_content_tab(ssml_input, tabs):
656
- speakers = get_speakers()
657
-
658
- def get_speaker_show_name(spk):
659
- if spk.gender == "*" or spk.gender == "":
660
- return spk.name
661
- return f"{spk.gender} : {spk.name}"
662
-
663
- speaker_names = ["*random"] + [
664
- get_speaker_show_name(speaker) for speaker in speakers
665
- ]
666
-
667
- styles = ["*auto"] + [s.get("name") for s in get_styles()]
668
-
669
- with gr.Row():
670
- with gr.Column(scale=1):
671
- # 选择说话人 选择风格 选择seed
672
- with gr.Group():
673
- gr.Markdown("🗣️Speaker")
674
- spk_input_text = gr.Textbox(
675
- label="Speaker (Text or Seed)",
676
- value="female2",
677
- show_label=False,
678
- )
679
- spk_input_dropdown = gr.Dropdown(
680
- choices=speaker_names,
681
- interactive=True,
682
- value="female : female2",
683
- show_label=False,
684
- )
685
- spk_rand_button = gr.Button(
686
- value="🎲",
687
- variant="secondary",
688
- )
689
- with gr.Group():
690
- gr.Markdown("🎭Style")
691
- style_input_dropdown = gr.Dropdown(
692
- choices=styles,
693
- interactive=True,
694
- show_label=False,
695
- value="*auto",
696
- )
697
- with gr.Group():
698
- gr.Markdown("🗣️Seed")
699
- infer_seed_input = gr.Number(
700
- value=42,
701
- label="Inference Seed",
702
- show_label=False,
703
- minimum=-1,
704
- maximum=2**32 - 1,
705
- )
706
- infer_seed_rand_button = gr.Button(
707
- value="🎲",
708
- variant="secondary",
709
- )
710
-
711
- send_btn = gr.Button("📩Send to SSML", variant="primary")
712
-
713
- with gr.Column(scale=3):
714
- with gr.Group():
715
- gr.Markdown("📝Long Text Input")
716
- gr.Markdown("- 此页面用于处理超长文本")
717
- gr.Markdown("- 切割后,可以选择说话人、风格、seed,然后发送到SSML")
718
- long_text_input = gr.Textbox(
719
- label="Long Text Input",
720
- lines=10,
721
- placeholder="输入长文本",
722
- elem_id="long-text-input",
723
- show_label=False,
724
- )
725
- long_text_split_button = gr.Button("🔪Split Text")
726
-
727
- with gr.Row():
728
- with gr.Column(scale=3):
729
- with gr.Group():
730
- gr.Markdown("🎨Output")
731
- long_text_output = gr.DataFrame(
732
- headers=["index", "text", "length"],
733
- datatype=["number", "str", "number"],
734
- elem_id="long-text-output",
735
- interactive=False,
736
- wrap=True,
737
- value=[],
738
- )
739
-
740
- spk_input_dropdown.change(
741
- fn=lambda x: x.startswith("*") and "-1" or x.split(":")[-1].strip(),
742
- inputs=[spk_input_dropdown],
743
- outputs=[spk_input_text],
744
- )
745
- spk_rand_button.click(
746
- lambda x: int(torch.randint(0, 2**32 - 1, (1,)).item()),
747
- inputs=[spk_input_text],
748
- outputs=[spk_input_text],
749
- )
750
- infer_seed_rand_button.click(
751
- lambda x: int(torch.randint(0, 2**32 - 1, (1,)).item()),
752
- inputs=[infer_seed_input],
753
- outputs=[infer_seed_input],
754
- )
755
- long_text_split_button.click(
756
- split_long_text,
757
- inputs=[long_text_input],
758
- outputs=[long_text_output],
759
- )
760
-
761
- infer_seed_rand_button.click(
762
- lambda x: int(torch.randint(0, 2**32 - 1, (1,)).item()),
763
- inputs=[infer_seed_input],
764
- outputs=[infer_seed_input],
765
- )
766
-
767
- send_btn.click(
768
- merge_dataframe_to_ssml,
769
- inputs=[
770
- long_text_output,
771
- spk_input_text,
772
- style_input_dropdown,
773
- infer_seed_input,
774
- ],
775
- outputs=[ssml_input],
776
- )
777
-
778
- def change_tab():
779
- return gr.Tabs(selected="ssml")
780
-
781
- send_btn.click(change_tab, inputs=[], outputs=[tabs])
782
-
783
-
784
- def create_readme_tab():
785
- readme_content = read_local_readme()
786
- gr.Markdown(readme_content)
787
-
788
-
789
- def create_app_footer():
790
- gradio_version = gr.__version__
791
- git_tag = config.versions.git_tag
792
- git_commit = config.versions.git_commit
793
- git_branch = config.versions.git_branch
794
- python_version = config.versions.python_version
795
- torch_version = config.versions.torch_version
796
-
797
- config.versions.gradio_version = gradio_version
798
-
799
- gr.Markdown(
800
- f"""
801
- 🍦 [ChatTTS-Forge](https://github.com/lenML/ChatTTS-Forge)
802
- version: [{git_tag}](https://github.com/lenML/ChatTTS-Forge/commit/{git_commit}) | branch: `{git_branch}` | python: `{python_version}` | torch: `{torch_version}`
803
- """
804
- )
805
-
806
-
807
- def create_interface():
808
-
809
- js_func = """
810
- function refresh() {
811
- const url = new URL(window.location);
812
-
813
- if (url.searchParams.get('__theme') !== 'dark') {
814
- url.searchParams.set('__theme', 'dark');
815
- window.location.href = url.href;
816
- }
817
- }
818
- """
819
-
820
- head_js = """
821
- <script>
822
- </script>
823
- """
824
-
825
- with gr.Blocks(js=js_func, head=head_js, title="ChatTTS Forge WebUI") as demo:
826
- css = """
827
- <style>
828
- .big-button {
829
- height: 80px;
830
- }
831
- #input_title div.eta-bar {
832
- display: none !important; transform: none !important;
833
- }
834
- footer {
835
- display: none !important;
836
- }
837
- </style>
838
- """
839
-
840
- gr.HTML(css)
841
- with gr.Tabs() as tabs:
842
- with gr.TabItem("TTS"):
843
- create_tts_interface()
844
-
845
- with gr.TabItem("SSML", id="ssml"):
846
- ssml_input = create_ssml_interface()
847
-
848
- with gr.TabItem("Long Text"):
849
- create_long_content_tab(ssml_input, tabs=tabs)
850
-
851
- with gr.TabItem("README"):
852
- create_readme_tab()
853
-
854
- create_app_footer()
855
- return demo
856
-
857
 
858
  if __name__ == "__main__":
859
  import argparse
@@ -916,6 +67,12 @@ if __name__ == "__main__":
916
  type=str.lower,
917
  )
918
  parser.add_argument("--compile", action="store_true", help="Enable model compile")
 
 
 
 
 
 
919
 
920
  args = parser.parse_args()
921
 
@@ -936,20 +93,23 @@ if __name__ == "__main__":
936
  device_id = get_and_update_env(args, "device_id", None, str)
937
  use_cpu = get_and_update_env(args, "use_cpu", [], list)
938
  compile = get_and_update_env(args, "compile", False, bool)
 
939
 
940
- webui_config["tts_max"] = get_and_update_env(args, "tts_max_len", 1000, int)
941
- webui_config["ssml_max"] = get_and_update_env(args, "ssml_max_len", 5000, int)
942
- webui_config["max_batch_size"] = get_and_update_env(args, "max_batch_size", 8, int)
943
 
944
  demo = create_interface()
945
 
946
  if auth:
947
  auth = tuple(auth.split(":"))
948
 
949
- generate.setup_lru_cache()
950
  devices.reset_device()
951
  devices.first_time_calculation()
952
 
 
 
953
  demo.queue().launch(
954
  server_name=server_name,
955
  server_port=server_port,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import os
 
 
 
 
 
2
  from modules.devices import devices
3
+ from modules.utils import env
4
+ from modules.webui import webui_config
5
+ from modules.webui.app import webui_init, create_interface
6
+ from modules import generate_audio
7
+ from modules import config
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8
 
9
  if __name__ == "__main__":
10
  import argparse
 
67
  type=str.lower,
68
  )
69
  parser.add_argument("--compile", action="store_true", help="Enable model compile")
70
+ # webui_Experimental
71
+ parser.add_argument(
72
+ "--webui_experimental",
73
+ action="store_true",
74
+ help="Enable webui_experimental features",
75
+ )
76
 
77
  args = parser.parse_args()
78
 
 
93
  device_id = get_and_update_env(args, "device_id", None, str)
94
  use_cpu = get_and_update_env(args, "use_cpu", [], list)
95
  compile = get_and_update_env(args, "compile", False, bool)
96
+ webui_experimental = get_and_update_env(args, "webui_experimental", False, bool)
97
 
98
+ webui_config.tts_max = get_and_update_env(args, "tts_max_len", 1000, int)
99
+ webui_config.ssml_max = get_and_update_env(args, "ssml_max_len", 5000, int)
100
+ webui_config.max_batch_size = get_and_update_env(args, "max_batch_size", 8, int)
101
 
102
  demo = create_interface()
103
 
104
  if auth:
105
  auth = tuple(auth.split(":"))
106
 
107
+ generate_audio.setup_lru_cache()
108
  devices.reset_device()
109
  devices.first_time_calculation()
110
 
111
+ webui_init()
112
+
113
  demo.queue().launch(
114
  server_name=server_name,
115
  server_port=server_port,