chenjgtea
commited on
Commit
·
52b0147
1
Parent(s):
638a294
gpu模式下解决特殊字符转换处理
Browse files- Chat2TTS/core.py +14 -14
- web/app_cpu.py +3 -1
- web/app_gpu.py +15 -12
Chat2TTS/core.py
CHANGED
@@ -161,22 +161,22 @@ class Chat:
|
|
161 |
|
162 |
assert self.check_model(use_decoder=use_decoder)
|
163 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
164 |
if skip_refine_text:
|
165 |
self.logger.info("========对文本内容不做优化处理,仅做规则处理======")
|
166 |
-
|
167 |
-
text = [text]
|
168 |
-
|
169 |
-
text = [
|
170 |
-
self.normalizer(
|
171 |
-
text=t,
|
172 |
-
do_text_normalization=True,
|
173 |
-
do_homophone_replacement=True,
|
174 |
-
lang=None,
|
175 |
-
)
|
176 |
-
for t in text
|
177 |
-
]
|
178 |
-
|
179 |
-
if not skip_refine_text:
|
180 |
text_tokens = refine_text(self.pretrain_models, text, **params_refine_text)['ids']
|
181 |
text_tokens = [i[i < self.pretrain_models['tokenizer'].convert_tokens_to_ids('[break_0]')] for i in text_tokens]
|
182 |
text = self.pretrain_models['tokenizer'].batch_decode(text_tokens)
|
|
|
161 |
|
162 |
assert self.check_model(use_decoder=use_decoder)
|
163 |
|
164 |
+
if not isinstance(text, list):
|
165 |
+
text = [text]
|
166 |
+
|
167 |
+
text = [
|
168 |
+
self.normalizer(
|
169 |
+
text=t,
|
170 |
+
do_text_normalization=True,
|
171 |
+
do_homophone_replacement=True,
|
172 |
+
lang=None,
|
173 |
+
)
|
174 |
+
for t in text
|
175 |
+
]
|
176 |
+
|
177 |
if skip_refine_text:
|
178 |
self.logger.info("========对文本内容不做优化处理,仅做规则处理======")
|
179 |
+
else:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
180 |
text_tokens = refine_text(self.pretrain_models, text, **params_refine_text)['ids']
|
181 |
text_tokens = [i[i < self.pretrain_models['tokenizer'].convert_tokens_to_ids('[break_0]')] for i in text_tokens]
|
182 |
text = self.pretrain_models['tokenizer'].batch_decode(text_tokens)
|
web/app_cpu.py
CHANGED
@@ -175,7 +175,9 @@ def main(args):
|
|
175 |
logger.info("元素初始化完成,启动gradio服务=======")
|
176 |
|
177 |
# 运行gradio服务
|
178 |
-
demo.launch(
|
|
|
|
|
179 |
|
180 |
|
181 |
'''
|
|
|
175 |
logger.info("元素初始化完成,启动gradio服务=======")
|
176 |
|
177 |
# 运行gradio服务
|
178 |
+
demo.launch(server_name=args.server_name,
|
179 |
+
server_port=args.server_port,
|
180 |
+
share=False)
|
181 |
|
182 |
|
183 |
'''
|
web/app_gpu.py
CHANGED
@@ -60,16 +60,17 @@ def main(args):
|
|
60 |
interactive=True,
|
61 |
)
|
62 |
with gr.Row():
|
63 |
-
|
64 |
-
|
65 |
-
|
66 |
-
|
67 |
-
|
68 |
-
|
69 |
-
|
70 |
-
|
71 |
-
|
72 |
-
|
|
|
73 |
temperature_slider = gr.Slider(
|
74 |
minimum=0.00001,
|
75 |
maximum=1.0,
|
@@ -179,7 +180,9 @@ def main(args):
|
|
179 |
logger.info("元素初始化完成,启动gradio服务=======")
|
180 |
|
181 |
# 运行gradio服务
|
182 |
-
demo.launch(
|
|
|
|
|
183 |
|
184 |
|
185 |
'''
|
@@ -299,7 +302,7 @@ def general_chat_infer_audio(text,
|
|
299 |
if __name__ == "__main__":
|
300 |
parser = argparse.ArgumentParser(description="ChatTTS demo Launch")
|
301 |
parser.add_argument(
|
302 |
-
"--server_name", type=str, default="
|
303 |
)
|
304 |
parser.add_argument("--server_port", type=int, default=7860, help="server port")
|
305 |
parser.add_argument(
|
|
|
60 |
interactive=True,
|
61 |
)
|
62 |
with gr.Row():
|
63 |
+
with gr.Column():
|
64 |
+
refine_text_checkBox = gr.Checkbox(
|
65 |
+
label="是否优化文本,如是则会对文本内容做基于模型优化",
|
66 |
+
interactive=True,
|
67 |
+
value=True
|
68 |
+
)
|
69 |
+
refine_audio_checkBox = gr.Checkbox(
|
70 |
+
label="是否生成音频文件,如是才会生成音频文件",
|
71 |
+
interactive=True,
|
72 |
+
value=True
|
73 |
+
)
|
74 |
temperature_slider = gr.Slider(
|
75 |
minimum=0.00001,
|
76 |
maximum=1.0,
|
|
|
180 |
logger.info("元素初始化完成,启动gradio服务=======")
|
181 |
|
182 |
# 运行gradio服务
|
183 |
+
demo.launch(server_name=args.server_name,
|
184 |
+
server_port=args.server_port,
|
185 |
+
share=False)
|
186 |
|
187 |
|
188 |
'''
|
|
|
302 |
if __name__ == "__main__":
|
303 |
parser = argparse.ArgumentParser(description="ChatTTS demo Launch")
|
304 |
parser.add_argument(
|
305 |
+
"--server_name", type=str, default="127.0.0.1", help="server name"
|
306 |
)
|
307 |
parser.add_argument("--server_port", type=int, default=7860, help="server port")
|
308 |
parser.add_argument(
|