zhzluke96 commited on
Commit
8c22399
1 Parent(s): 4554b6b
launch.py CHANGED
@@ -1,109 +1,201 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
  import torch
2
  from modules import config
 
3
  from modules import generate_audio as generate
4
-
5
- from functools import lru_cache
6
- from typing import Callable
7
-
8
  from modules.api.Api import APIManager
9
 
10
  from modules.api.impl import (
11
- base_api,
12
  tts_api,
13
  ssml_api,
14
  google_api,
15
  openai_api,
16
  refiner_api,
 
 
 
17
  )
18
 
 
 
19
  torch._dynamo.config.cache_size_limit = 64
20
  torch._dynamo.config.suppress_errors = True
21
  torch.set_float32_matmul_precision("high")
22
 
23
 
24
def create_api():
    """Create an ``APIManager`` and register every API implementation on it.

    Returns:
        The manager after ``setup`` has been called on it by the base, tts,
        ssml, google, openai and refiner API modules, in that order.
    """
    manager = APIManager()
    for impl in (base_api, tts_api, ssml_api, google_api, openai_api, refiner_api):
        impl.setup(manager)
    return manager
35
 
36
 
37
def conditional_cache(condition: Callable):
    """Build a decorator that memoizes a function only for selected calls.

    Args:
        condition: Called with the same positional and keyword arguments as
            the decorated function. A truthy result routes the call through
            an unbounded ``lru_cache``; a falsy result calls the function
            directly.

    Returns:
        A decorator. The wrapper it produces keeps the wrapped function's
        metadata (``__name__``, ``__doc__``, ...) via ``functools.wraps``.
    """
    # Local import: the file's import block only brings in ``lru_cache``.
    from functools import wraps

    def decorator(func):
        @lru_cache(None)
        def cached_func(*args, **kwargs):
            return func(*args, **kwargs)

        @wraps(func)
        def wrapper(*args, **kwargs):
            if condition(*args, **kwargs):
                # Cached path: the arguments must be hashable here.
                return cached_func(*args, **kwargs)
            return func(*args, **kwargs)

        return wrapper

    return decorator
52
-
53
-
54
if __name__ == "__main__":
    import argparse
    import uvicorn

    # Command-line interface of the standalone API server.
    parser = argparse.ArgumentParser(
        description="Start the FastAPI server with command line arguments"
    )
    for flag, options in (
        ("--host", dict(type=str, default="0.0.0.0", help="Host to run the server on")),
        ("--port", dict(type=int, default=8000, help="Port to run the server on")),
        (
            "--reload",
            dict(action="store_true", help="Enable auto-reload for development"),
        ),
        ("--compile", dict(action="store_true", help="Enable model compile")),
        (
            "--lru_size",
            dict(
                type=int,
                default=64,
                help="Set the size of the request cache pool, set it to 0 will disable lru_cache",
            ),
        ),
        (
            "--cors_origin",
            dict(
                type=str,
                default="*",
                help="Allowed CORS origins. Use '*' to allow all origins.",
            ),
        ),
    ):
        parser.add_argument(flag, **options)

    args = parser.parse_args()

    # Expose the parsed arguments through the shared config module.
    config.args = args

    if args.compile:
        print("Model compile is enabled")
        config.enable_model_compile = True

    def should_cache(*args, **kwargs):
        """Return True when both ``spk_seed`` and ``infer_seed`` kwargs differ from -1."""
        return kwargs.get("spk_seed", -1) != -1 and kwargs.get("infer_seed", -1) != -1

    if args.lru_size > 0:
        config.lru_size = args.lru_size
        # Swap the module-level generator for its conditionally cached wrapper.
        cached_generate = conditional_cache(should_cache)(generate.generate_audio)
        generate.generate_audio = cached_generate

    api = create_api()
    config.api = api

    if args.cors_origin:
        api.set_cors(allow_origins=[args.cors_origin])

    uvicorn.run(api.app, host=args.host, port=args.port, reload=args.reload)
 
1
+ import os
2
+ import logging
3
+
4
+ logging.basicConfig(
5
+ level=os.getenv("LOG_LEVEL", "INFO"),
6
+ format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
7
+ )
8
+
9
+ from modules.devices import devices
10
+ import argparse
11
+ import uvicorn
12
+
13
  import torch
14
  from modules import config
15
+ from modules.utils import env
16
  from modules import generate_audio as generate
 
 
 
 
17
  from modules.api.Api import APIManager
18
 
19
  from modules.api.impl import (
20
+ style_api,
21
  tts_api,
22
  ssml_api,
23
  google_api,
24
  openai_api,
25
  refiner_api,
26
+ speaker_api,
27
+ ping_api,
28
+ models_api,
29
  )
30
 
31
+ logger = logging.getLogger(__name__)
32
+
33
  torch._dynamo.config.cache_size_limit = 64
34
  torch._dynamo.config.suppress_errors = True
35
  torch.set_float32_matmul_precision("high")
36
 
37
 
38
def create_api(app, no_docs=False, exclude=None):
    """Wrap ``app`` in an ``APIManager`` and register every API implementation.

    Args:
        app: Application object handed to ``APIManager`` as ``app``.
        no_docs: Forwarded to ``APIManager`` as ``no_docs``.
        exclude: Optional list of patterns forwarded to ``APIManager`` as
            ``exclude_patterns``. ``None`` (the default) means an empty list;
            a fresh list is created per call so calls never share one
            mutable default object.

    Returns:
        The manager after ``setup`` has been called on it by the ping, models,
        style, speaker, tts, ssml, google, openai and refiner API modules,
        in that order.
    """
    patterns = [] if exclude is None else exclude
    app_mgr = APIManager(app=app, no_docs=no_docs, exclude_patterns=patterns)

    for impl in (
        ping_api,
        models_api,
        style_api,
        speaker_api,
        tts_api,
        ssml_api,
        google_api,
        openai_api,
        refiner_api,
    ):
        impl.setup(app_mgr)

    return app_mgr
52
 
53
 
54
def get_and_update_env(*args):
    """Resolve a setting and record it in ``config.runtime_env_vars``.

    All positional arguments are forwarded unchanged to
    ``env.get_env_or_arg``; the second one is used as the key under which
    the resolved value is stored.

    Returns:
        The value returned by ``env.get_env_or_arg``.
    """
    resolved = env.get_env_or_arg(*args)
    name = args[1]
    config.runtime_env_vars[name] = resolved
    return resolved
59
 
 
 
 
 
 
60
 
61
def setup_model_args(parser: argparse.ArgumentParser):
    """Register the model-related command-line options on ``parser``.

    Adds, in order: ``--compile``, ``--half``, ``--off_tqdm``,
    ``--device_id``, ``--use_cpu`` and ``--lru_size``.
    """
    option_table = (
        ("--compile", dict(action="store_true", help="Enable model compile")),
        (
            "--half",
            dict(
                action="store_true",
                help="Enable half precision for model inference",
            ),
        ),
        (
            "--off_tqdm",
            dict(action="store_true", help="Disable tqdm progress bar"),
        ),
        (
            "--device_id",
            dict(
                type=str,
                help="Select the default CUDA device to use (export CUDA_VISIBLE_DEVICES=0,1,etc might be needed before)",
                default=None,
            ),
        ),
        (
            "--use_cpu",
            dict(
                nargs="+",
                help="use CPU as torch device for specified modules",
                default=[],
                # Each given module name is lower-cased while parsing.
                type=str.lower,
            ),
        ),
        (
            "--lru_size",
            dict(
                type=int,
                default=64,
                help="Set the size of the request cache pool, set it to 0 will disable lru_cache",
            ),
        ),
    )
    for flag, options in option_table:
        parser.add_argument(flag, **options)
92
+
93
+
94
def setup_api_args(parser: argparse.ArgumentParser):
    """Register the API-server command-line options on ``parser``.

    Adds, in order: ``--api_host``, ``--api_port``, ``--reload``,
    ``--cors_origin``, ``--no_playground``, ``--no_docs`` and ``--exclude``.
    None of them declares an argparse default, so unset value options parse
    to ``None`` and unset flags to ``False``.
    """
    option_table = (
        ("--api_host", dict(type=str, help="Host to run the server on")),
        ("--api_port", dict(type=int, help="Port to run the server on")),
        (
            "--reload",
            dict(action="store_true", help="Enable auto-reload for development"),
        ),
        (
            "--cors_origin",
            dict(
                type=str,
                help="Allowed CORS origins. Use '*' to allow all origins.",
            ),
        ),
        (
            "--no_playground",
            dict(action="store_true", help="Disable the playground entry"),
        ),
        (
            "--no_docs",
            dict(action="store_true", help="Disable the documentation entry"),
        ),
        # Configure which APIs to skip, e.g. exclude="/v1/speakers/*,/v1/tts/*"
        (
            "--exclude",
            dict(type=str, help="Exclude the specified API from the server"),
        ),
    )
    for flag, options in option_table:
        parser.add_argument(flag, **options)
121
 
 
122
 
123
def process_model_args(args):
    """Resolve the model settings, then initialise the cache and devices.

    Each setting is resolved through ``get_and_update_env`` purely for its
    side effect of recording the value in ``config.runtime_env_vars``. The
    results are not bound to locals, which avoids unused variables and
    shadowing the builtin ``compile``.

    Args:
        args: Parsed command-line namespace passed on to ``get_and_update_env``.
    """
    settings = (
        ("lru_size", 64, int),
        ("compile", False, bool),
        ("device_id", None, str),
        ("use_cpu", [], list),
        ("half", False, bool),
        ("off_tqdm", False, bool),
    )
    for key, default, kind in settings:
        get_and_update_env(args, key, default, kind)

    # These run after the settings above have been recorded.
    generate.setup_lru_cache()
    devices.reset_device()
    devices.first_time_calculation()
134
 
 
 
 
 
135
 
136
def process_api_args(args, app):
    """Resolve the API settings and attach the API routes to ``app``.

    Args:
        args: Parsed command-line namespace passed on to ``get_and_update_env``.
        app: Application object handed to ``create_api``.

    Side effects:
        Stores the created manager in ``config.api``, installs CORS when an
        origin is configured, and sets up the playground unless disabled.
    """
    cors_origin = get_and_update_env(args, "cors_origin", "*", str)
    no_playground = get_and_update_env(args, "no_playground", False, bool)
    no_docs = get_and_update_env(args, "no_docs", False, bool)
    exclude = get_and_update_env(args, "exclude", "", str)
    # Resolve the flag explicitly: a bare ``compile`` here would name the
    # builtin function, which is always truthy.
    # NOTE(review): assumes ``args`` also carries the model options, as both
    # entry points that call this register them.
    model_compile = get_and_update_env(args, "compile", False, bool)

    # Drop empty entries: "".split(",") would otherwise yield [""].
    exclude_patterns = [
        pattern.strip() for pattern in (exclude or "").split(",") if pattern.strip()
    ]

    api = create_api(app=app, no_docs=no_docs, exclude=exclude_patterns)
    config.api = api

    if cors_origin:
        api.set_cors(allow_origins=[cors_origin])

    if not no_playground:
        api.setup_playground()

    if model_compile:
        logger.info("Model compile is enabled")
153
+
154
+
155
+ app_description = """
156
+ ChatTTS-Forge 是一个功能强大的文本转语音生成工具,支持通过类 SSML 语法生成丰富的音频长文本,并提供全面的 API 服务,适用于各种场景。<br/>
157
+ ChatTTS-Forge is a powerful text-to-speech generation tool that supports generating rich audio long texts through class SSML syntax
158
+
159
+ 项目地址: [https://github.com/lenML/ChatTTS-Forge](https://github.com/lenML/ChatTTS-Forge)
160
+
161
+ > 所有生成音频的 POST api都无法在此页面调试,调试建议使用 playground <br/>
162
+ > All audio generation POST APIs cannot be debugged on this page, it is recommended to use playground for debugging
163
+
164
+ > 如果你不熟悉本系统,建议从这个一键脚本开始,在colab中尝试一下:<br/>
165
+ > [https://colab.research.google.com/github/lenML/ChatTTS-Forge/blob/main/colab.ipynb](https://colab.research.google.com/github/lenML/ChatTTS-Forge/blob/main/colab.ipynb)
166
+ """
167
+ app_title = "ChatTTS Forge API"
168
+ app_version = "0.1.0"
169
+
170
if __name__ == "__main__":
    import dotenv
    from fastapi import FastAPI

    # Load the API env file before any setting is resolved.
    dotenv.load_dotenv(
        dotenv_path=os.getenv("ENV_FILE", ".env.api"),
    )

    parser = argparse.ArgumentParser(
        description="Start the FastAPI server with command line arguments"
    )
    setup_api_args(parser)
    setup_model_args(parser)

    args = parser.parse_args()

    # Resolve ``no_docs`` before building the app. ``process_api_args`` only
    # records it in ``config.runtime_env_vars`` later, so reading it from
    # there at this point would miss the --no_docs flag / env setting.
    no_docs = get_and_update_env(args, "no_docs", False, bool)

    app = FastAPI(
        title=app_title,
        description=app_description,
        version=app_version,
        redoc_url=None if no_docs else "/redoc",
        docs_url=None if no_docs else "/docs",
    )

    process_model_args(args)
    process_api_args(args, app)

    host = get_and_update_env(args, "api_host", "0.0.0.0", str)
    port = get_and_update_env(args, "api_port", 7870, int)
    reload = get_and_update_env(args, "reload", False, bool)

    # NOTE(review): uvicorn documents that ``reload`` needs the application
    # passed as an import string; with an app object the flag is presumably
    # ignored -- confirm against the uvicorn settings docs.
    uvicorn.run(app, host=host, port=port, reload=reload)
modules/api/Api.py CHANGED
@@ -1,4 +1,4 @@
1
- from fastapi import FastAPI, APIRouter
2
  from fastapi.middleware.cors import CORSMiddleware
3
 
4
  import logging
@@ -24,25 +24,8 @@ def is_excluded(path, exclude_patterns):
24
 
25
 
26
  class APIManager:
27
- def __init__(self, no_docs=False, exclude_patterns=[]):
28
- self.app = FastAPI(
29
- title="ChatTTS Forge API",
30
- description="""
31
- ChatTTS-Forge 是一个功能强大的文本转语音生成工具,支持通过类 SSML 语法生成丰富的音频长文本,并提供全面的 API 服务,适用于各种场景。<br/>
32
- ChatTTS-Forge is a powerful text-to-speech generation tool that supports generating rich audio long texts through class SSML syntax
33
-
34
- 项目地址: [https://github.com/lenML/ChatTTS-Forge](https://github.com/lenML/ChatTTS-Forge)
35
-
36
- > 所有生成音频的 POST api都无法在此页面调试,调试建议使用 playground <br/>
37
- > All audio generation POST APIs cannot be debugged on this page, it is recommended to use playground for debugging
38
-
39
- > 如果你不熟悉本系统,建议从这个一键脚本开始,在colab中尝试一下:<br/>
40
- > [https://colab.research.google.com/github/lenML/ChatTTS-Forge/blob/main/colab.ipynb](https://colab.research.google.com/github/lenML/ChatTTS-Forge/blob/main/colab.ipynb)
41
- """,
42
- version="0.1.0",
43
- redoc_url=None if no_docs else "/redoc",
44
- docs_url=None if no_docs else "/docs",
45
- )
46
  self.registered_apis = {}
47
  self.logger = logging.getLogger(__name__)
48
  self.exclude = exclude_patterns
@@ -57,6 +40,8 @@ ChatTTS-Forge is a powerful text-to-speech generation tool that supports generat
57
  allow_methods: list = ["*"],
58
  allow_headers: list = ["*"],
59
  ):
 
 
60
  self.app.add_middleware(
61
  CORSMiddleware,
62
  allow_origins=allow_origins,
@@ -64,6 +49,7 @@ ChatTTS-Forge is a powerful text-to-speech generation tool that supports generat
64
  allow_methods=allow_methods,
65
  allow_headers=allow_headers,
66
  )
 
67
 
68
  def setup_playground(self):
69
  app = self.app
 
1
+ from fastapi import FastAPI
2
  from fastapi.middleware.cors import CORSMiddleware
3
 
4
  import logging
 
24
 
25
 
26
  class APIManager:
27
+ def __init__(self, app: FastAPI, no_docs=False, exclude_patterns=[]):
28
+ self.app = app
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
29
  self.registered_apis = {}
30
  self.logger = logging.getLogger(__name__)
31
  self.exclude = exclude_patterns
 
40
  allow_methods: list = ["*"],
41
  allow_headers: list = ["*"],
42
  ):
43
+ # reset middleware stack
44
+ self.app.middleware_stack = None
45
  self.app.add_middleware(
46
  CORSMiddleware,
47
  allow_origins=allow_origins,
 
49
  allow_methods=allow_methods,
50
  allow_headers=allow_headers,
51
  )
52
+ self.app.build_middleware_stack()
53
 
54
  def setup_playground(self):
55
  app = self.app
modules/api/impl/refiner_api.py CHANGED
@@ -7,6 +7,7 @@ from modules import refiner
7
 
8
  from modules.api import utils as api_utils
9
  from modules.api.Api import APIManager
 
10
 
11
 
12
  class RefineTextRequest(BaseModel):
@@ -18,6 +19,7 @@ class RefineTextRequest(BaseModel):
18
  temperature: float = 0.7
19
  repetition_penalty: float = 1.0
20
  max_new_token: int = 384
 
21
 
22
 
23
  async def refiner_prompt_post(request: RefineTextRequest):
@@ -26,8 +28,11 @@ async def refiner_prompt_post(request: RefineTextRequest):
26
  """
27
 
28
  try:
 
 
 
29
  refined_text = refiner.refine_text(
30
- text=request.text,
31
  prompt=request.prompt,
32
  seed=request.seed,
33
  top_P=request.top_P,
 
7
 
8
  from modules.api import utils as api_utils
9
  from modules.api.Api import APIManager
10
+ from modules.normalization import text_normalize
11
 
12
 
13
  class RefineTextRequest(BaseModel):
 
19
  temperature: float = 0.7
20
  repetition_penalty: float = 1.0
21
  max_new_token: int = 384
22
+ normalize: bool = True
23
 
24
 
25
  async def refiner_prompt_post(request: RefineTextRequest):
 
28
  """
29
 
30
  try:
31
+ text = request.text
32
+ if request.normalize:
33
+ text = text_normalize(request.text)
34
  refined_text = refiner.refine_text(
35
+ text=text,
36
  prompt=request.prompt,
37
  seed=request.seed,
38
  top_P=request.top_P,
modules/api/impl/speaker_api.py CHANGED
@@ -35,10 +35,14 @@ def setup(app: APIManager):
35
 
36
  @app.get("/v1/speakers/list", response_model=api_utils.BaseResponse)
37
  async def list_speakers():
38
- return {
39
- "message": "ok",
40
- "data": [spk.to_json() for spk in speaker_mgr.list_speakers()],
41
- }
 
 
 
 
42
 
43
  @app.post("/v1/speakers/update", response_model=api_utils.BaseResponse)
44
  async def update_speakers(request: SpeakersUpdate):
@@ -59,7 +63,8 @@ def setup(app: APIManager):
59
  # number array => Tensor
60
  speaker.emb = torch.tensor(spk["tensor"])
61
  speaker_mgr.save_all()
62
- return {"message": "ok", "data": None}
 
63
 
64
  @app.post("/v1/speaker/create", response_model=api_utils.BaseResponse)
65
  async def create_speaker(request: CreateSpeaker):
@@ -88,12 +93,7 @@ def setup(app: APIManager):
88
  raise HTTPException(
89
  status_code=400, detail="Missing tensor or seed in request"
90
  )
91
- return {"message": "ok", "data": speaker.to_json()}
92
-
93
- @app.post("/v1/speaker/refresh", response_model=api_utils.BaseResponse)
94
- async def refresh_speakers():
95
- speaker_mgr.refresh_speakers()
96
- return {"message": "ok"}
97
 
98
  @app.post("/v1/speaker/update", response_model=api_utils.BaseResponse)
99
  async def update_speaker(request: UpdateSpeaker):
@@ -113,11 +113,11 @@ def setup(app: APIManager):
113
  # number array => Tensor
114
  speaker.emb = torch.tensor(request.tensor)
115
  speaker_mgr.update_speaker(speaker)
116
- return {"message": "ok"}
117
 
118
  @app.post("/v1/speaker/detail", response_model=api_utils.BaseResponse)
119
  async def speaker_detail(request: SpeakerDetail):
120
  speaker = speaker_mgr.get_speaker_by_id(request.id)
121
  if speaker is None:
122
  raise HTTPException(status_code=404, detail="Speaker not found")
123
- return {"message": "ok", "data": speaker.to_json(with_emb=request.with_emb)}
 
35
 
36
  @app.get("/v1/speakers/list", response_model=api_utils.BaseResponse)
37
  async def list_speakers():
38
+ return api_utils.success_response(
39
+ [spk.to_json() for spk in speaker_mgr.list_speakers()]
40
+ )
41
+
42
+ @app.post("/v1/speakers/refresh", response_model=api_utils.BaseResponse)
43
+ async def refresh_speakers():
44
+ speaker_mgr.refresh_speakers()
45
+ return api_utils.success_response(None)
46
 
47
  @app.post("/v1/speakers/update", response_model=api_utils.BaseResponse)
48
  async def update_speakers(request: SpeakersUpdate):
 
63
  # number array => Tensor
64
  speaker.emb = torch.tensor(spk["tensor"])
65
  speaker_mgr.save_all()
66
+
67
+ return api_utils.success_response(None)
68
 
69
  @app.post("/v1/speaker/create", response_model=api_utils.BaseResponse)
70
  async def create_speaker(request: CreateSpeaker):
 
93
  raise HTTPException(
94
  status_code=400, detail="Missing tensor or seed in request"
95
  )
96
+ return api_utils.success_response(speaker.to_json())
 
 
 
 
 
97
 
98
  @app.post("/v1/speaker/update", response_model=api_utils.BaseResponse)
99
  async def update_speaker(request: UpdateSpeaker):
 
113
  # number array => Tensor
114
  speaker.emb = torch.tensor(request.tensor)
115
  speaker_mgr.update_speaker(speaker)
116
+ return api_utils.success_response(None)
117
 
118
  @app.post("/v1/speaker/detail", response_model=api_utils.BaseResponse)
119
  async def speaker_detail(request: SpeakerDetail):
120
  speaker = speaker_mgr.get_speaker_by_id(request.id)
121
  if speaker is None:
122
  raise HTTPException(status_code=404, detail="Speaker not found")
123
+ return api_utils.success_response(speaker.to_json(with_emb=request.with_emb))
modules/api/impl/tts_api.py CHANGED
@@ -9,8 +9,6 @@ from fastapi.responses import FileResponse
9
 
10
  from modules.normalization import text_normalize
11
 
12
- from modules import generate_audio as generate
13
-
14
  from modules.api import utils as api_utils
15
  from modules.api.Api import APIManager
16
  from modules.synthesize_audio import synthesize_audio
 
9
 
10
  from modules.normalization import text_normalize
11
 
 
 
12
  from modules.api import utils as api_utils
13
  from modules.api.Api import APIManager
14
  from modules.synthesize_audio import synthesize_audio
modules/gradio_dcls_fix.py ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
def dcls_patch():
    """Replace gradio's ``PredictBody`` JSON-schema hook with a fixed schema.

    After this call ``PredictBody.__get_pydantic_json_schema__`` ignores its
    two arguments and returns ``{"type": "object"}``.
    """
    from gradio import data_classes

    def _object_schema(x, y):
        return {"type": "object"}

    data_classes.PredictBody.__get_pydantic_json_schema__ = _object_schema
modules/webui/app.py CHANGED
@@ -46,11 +46,19 @@ def create_app_footer():
46
 
47
  config.versions.gradio_version = gradio_version
48
 
 
 
 
 
 
 
 
 
 
 
 
49
  gr.Markdown(
50
- f"""
51
- 🍦 [ChatTTS-Forge](https://github.com/lenML/ChatTTS-Forge)
52
- version: [{git_tag}](https://github.com/lenML/ChatTTS-Forge/commit/{git_commit}) | branch: `{git_branch}` | python: `{python_version}` | torch: `{torch_version}`
53
- """,
54
  elem_classes=["no-translate"],
55
  )
56
 
 
46
 
47
  config.versions.gradio_version = gradio_version
48
 
49
+ footer_items = ["🍦 [ChatTTS-Forge](https://github.com/lenML/ChatTTS-Forge)"]
50
+ footer_items.append(
51
+ f"version: [{git_tag}](https://github.com/lenML/ChatTTS-Forge/commit/{git_commit})"
52
+ )
53
+ footer_items.append(f"branch: `{git_branch}`")
54
+ footer_items.append(f"python: `{python_version}`")
55
+ footer_items.append(f"torch: `{torch_version}`")
56
+
57
+ if config.runtime_env_vars.api and not config.runtime_env_vars.no_docs:
58
+ footer_items.append(f"[API](/docs)")
59
+
60
  gr.Markdown(
61
+ " | ".join(footer_items),
 
 
 
62
  elem_classes=["no-translate"],
63
  )
64
 
modules/webui/js/localization.js CHANGED
@@ -163,6 +163,23 @@ function localizeWholePage() {
163
  }
164
  }
165
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
166
  document.addEventListener("DOMContentLoaded", function () {
167
  if (!hasLocalization()) {
168
  return;
@@ -170,9 +187,11 @@ document.addEventListener("DOMContentLoaded", function () {
170
 
171
  onUiUpdate(function (m) {
172
  m.forEach(function (mutation) {
173
- mutation.addedNodes.forEach(function (node) {
174
- processNode(node);
175
- });
 
 
176
  });
177
  });
178
 
 
163
  }
164
  }
165
 
166
/**
 * Decide whether a newly added DOM node should be handed to the localizer.
 *
 * Missing nodes are rejected. Nodes that are not HTMLElements (for example
 * text nodes) are always accepted. An HTMLElement is rejected when it, or
 * any of its ancestors, carries the "no-translate" class. This includes
 * elements that are direct children of <body>, which a manual parent walk
 * stopping at `parentElement === document.body` would never test.
 *
 * @param {Node} node
 * @returns {boolean} true when the node should be translated
 */
function isNeedTranslate(node) {
  if (!node) return false;
  if (!(node instanceof HTMLElement)) return true;
  // closest() tests the element itself first, then each ancestor in turn.
  return node.closest(".no-translate") === null;
}
182
+
183
  document.addEventListener("DOMContentLoaded", function () {
184
  if (!hasLocalization()) {
185
  return;
 
187
 
188
  onUiUpdate(function (m) {
189
  m.forEach(function (mutation) {
190
+ Array.from(mutation.addedNodes)
191
+ .filter(isNeedTranslate)
192
+ .forEach(function (node) {
193
+ processNode(node);
194
+ });
195
  });
196
  });
197
 
modules/webui/tts_tab.py CHANGED
@@ -96,7 +96,7 @@ def create_tts_interface():
96
  )
97
 
98
  gr.Markdown("📝Speaker info")
99
- infos = gr.Markdown("empty")
100
 
101
  spk_file_upload.change(
102
  fn=load_spk_info,
 
96
  )
97
 
98
  gr.Markdown("📝Speaker info")
99
+ infos = gr.Markdown("empty", elem_classes=["no-translate"])
100
 
101
  spk_file_upload.change(
102
  fn=load_spk_info,
webui.py CHANGED
@@ -1,27 +1,30 @@
1
  import os
2
  import logging
3
 
4
- # logging.basicConfig(
5
- # level=os.getenv("LOG_LEVEL", "INFO"),
6
- # format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
7
- # )
8
 
9
- from modules.devices import devices
10
- from modules.utils import env
 
 
 
 
 
 
 
 
11
  from modules.webui import webui_config
12
  from modules.webui.app import webui_init, create_interface
13
- from modules import generate_audio
14
- from modules import config
15
 
16
- if __name__ == "__main__":
17
- import argparse
18
- import dotenv
19
 
20
- dotenv.load_dotenv(
21
- dotenv_path=os.getenv("ENV_FILE", ".env.webui"),
22
- )
23
 
24
- parser = argparse.ArgumentParser(description="Gradio App")
25
  parser.add_argument("--server_name", type=str, help="server name")
26
  parser.add_argument("--server_port", type=int, help="server port")
27
  parser.add_argument(
@@ -29,16 +32,6 @@ if __name__ == "__main__":
29
  )
30
  parser.add_argument("--debug", action="store_true", help="enable debug mode")
31
  parser.add_argument("--auth", type=str, help="username:password for authentication")
32
- parser.add_argument(
33
- "--half",
34
- action="store_true",
35
- help="Enable half precision for model inference",
36
- )
37
- parser.add_argument(
38
- "--off_tqdm",
39
- action="store_true",
40
- help="Disable tqdm progress bar",
41
- )
42
  parser.add_argument(
43
  "--tts_max_len",
44
  type=int,
@@ -54,58 +47,39 @@ if __name__ == "__main__":
54
  type=int,
55
  help="Max batch size for TTS",
56
  )
57
- parser.add_argument(
58
- "--lru_size",
59
- type=int,
60
- default=64,
61
- help="Set the size of the request cache pool, set it to 0 will disable lru_cache",
62
- )
63
- parser.add_argument(
64
- "--device_id",
65
- type=str,
66
- help="Select the default CUDA device to use (export CUDA_VISIBLE_DEVICES=0,1,etc might be needed before)",
67
- default=None,
68
- )
69
- parser.add_argument(
70
- "--use_cpu",
71
- nargs="+",
72
- help="use CPU as torch device for specified modules",
73
- default=[],
74
- type=str.lower,
75
- )
76
- parser.add_argument("--compile", action="store_true", help="Enable model compile")
77
  # webui_Experimental
78
  parser.add_argument(
79
  "--webui_experimental",
80
  action="store_true",
81
  help="Enable webui_experimental features",
82
  )
83
-
84
  parser.add_argument(
85
  "--language",
86
  type=str,
87
  help="Set the default language for the webui",
88
  )
89
- args = parser.parse_args()
 
 
 
 
90
 
91
- def get_and_update_env(*args):
92
- val = env.get_env_or_arg(*args)
93
- key = args[1]
94
- config.runtime_env_vars[key] = val
95
- return val
96
 
 
97
  server_name = get_and_update_env(args, "server_name", "0.0.0.0", str)
98
  server_port = get_and_update_env(args, "server_port", 7860, int)
99
  share = get_and_update_env(args, "share", False, bool)
100
  debug = get_and_update_env(args, "debug", False, bool)
101
  auth = get_and_update_env(args, "auth", None, str)
102
- half = get_and_update_env(args, "half", False, bool)
103
- off_tqdm = get_and_update_env(args, "off_tqdm", False, bool)
104
- lru_size = get_and_update_env(args, "lru_size", 64, int)
105
- device_id = get_and_update_env(args, "device_id", None, str)
106
- use_cpu = get_and_update_env(args, "use_cpu", [], list)
107
- compile = get_and_update_env(args, "compile", False, bool)
108
  language = get_and_update_env(args, "language", "zh-CN", str)
 
 
 
 
 
 
 
 
109
 
110
  webui_config.experimental = get_and_update_env(
111
  args, "webui_experimental", False, bool
@@ -120,15 +94,57 @@ if __name__ == "__main__":
120
  if auth:
121
  auth = tuple(auth.split(":"))
122
 
123
- generate_audio.setup_lru_cache()
124
- devices.reset_device()
125
- devices.first_time_calculation()
126
-
127
- demo.queue().launch(
128
  server_name=server_name,
129
  server_port=server_port,
130
  share=share,
131
  debug=debug,
132
  auth=auth,
133
  show_api=False,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
134
  )
 
 
 
 
 
 
 
 
 
 
 
 
1
  import os
2
  import logging
3
 
4
+ logging.basicConfig(
5
+ level=os.getenv("LOG_LEVEL", "INFO"),
6
+ format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
7
+ )
8
 
9
+ from launch import (
10
+ get_and_update_env,
11
+ setup_api_args,
12
+ setup_model_args,
13
+ process_api_args,
14
+ process_model_args,
15
+ app_description,
16
+ app_title,
17
+ app_version,
18
+ )
19
  from modules.webui import webui_config
20
  from modules.webui.app import webui_init, create_interface
21
+ import argparse
22
+ from modules.gradio_dcls_fix import dcls_patch
23
 
24
+ dcls_patch()
 
 
25
 
 
 
 
26
 
27
+ def setup_webui_args(parser: argparse.ArgumentParser):
28
  parser.add_argument("--server_name", type=str, help="server name")
29
  parser.add_argument("--server_port", type=int, help="server port")
30
  parser.add_argument(
 
32
  )
33
  parser.add_argument("--debug", action="store_true", help="enable debug mode")
34
  parser.add_argument("--auth", type=str, help="username:password for authentication")
 
 
 
 
 
 
 
 
 
 
35
  parser.add_argument(
36
  "--tts_max_len",
37
  type=int,
 
47
  type=int,
48
  help="Max batch size for TTS",
49
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
50
  # webui_Experimental
51
  parser.add_argument(
52
  "--webui_experimental",
53
  action="store_true",
54
  help="Enable webui_experimental features",
55
  )
 
56
  parser.add_argument(
57
  "--language",
58
  type=str,
59
  help="Set the default language for the webui",
60
  )
61
+ parser.add_argument(
62
+ "--api",
63
+ action="store_true",
64
+ help="use api=True to launch the API together with the webui (run launch.py for only API server)",
65
+ )
66
 
 
 
 
 
 
67
 
68
+ def process_webui_args(args):
69
  server_name = get_and_update_env(args, "server_name", "0.0.0.0", str)
70
  server_port = get_and_update_env(args, "server_port", 7860, int)
71
  share = get_and_update_env(args, "share", False, bool)
72
  debug = get_and_update_env(args, "debug", False, bool)
73
  auth = get_and_update_env(args, "auth", None, str)
 
 
 
 
 
 
74
  language = get_and_update_env(args, "language", "zh-CN", str)
75
+ api = get_and_update_env(args, "api", "zh-CN", str)
76
+
77
+ webui_config.experimental = get_and_update_env(
78
+ args, "webui_experimental", False, bool
79
+ )
80
+ webui_config.tts_max = get_and_update_env(args, "tts_max_len", 1000, int)
81
+ webui_config.ssml_max = get_and_update_env(args, "ssml_max_len", 5000, int)
82
+ webui_config.max_batch_size = get_and_update_env(args, "max_batch_size", 8, int)
83
 
84
  webui_config.experimental = get_and_update_env(
85
  args, "webui_experimental", False, bool
 
94
  if auth:
95
  auth = tuple(auth.split(":"))
96
 
97
+ app, local_url, share_url = demo.queue().launch(
 
 
 
 
98
  server_name=server_name,
99
  server_port=server_port,
100
  share=share,
101
  debug=debug,
102
  auth=auth,
103
  show_api=False,
104
+ prevent_thread_lock=True,
105
+ app_kwargs={
106
+ "title": app_title,
107
+ "description": app_description,
108
+ "version": app_version,
109
+ # "redoc_url": (
110
+ # None
111
+ # if api is False
112
+ # else None if config.runtime_env_vars.no_docs else "/redoc"
113
+ # ),
114
+ # "docs_url": (
115
+ # None
116
+ # if api is False
117
+ # else None if config.runtime_env_vars.no_docs else "/docs"
118
+ # ),
119
+ "docs_url": "/docs",
120
+ },
121
+ )
122
+ # gradio uses a very open CORS policy via app.user_middleware, which makes it possible for
123
+ # an attacker to trick the user into opening a malicious HTML page, which makes a request to the
124
+ # running web ui and do whatever the attacker wants, including installing an extension and
125
+ # running its code. We disable this here. Suggested by RyotaK.
126
+ app.user_middleware = [
127
+ x for x in app.user_middleware if x.cls.__name__ != "CustomCORSMiddleware"
128
+ ]
129
+
130
+ if api:
131
+ process_api_args(args, app)
132
+
133
+
134
if __name__ == "__main__":
    import dotenv

    # Load the webui env file before any setting is resolved.
    env_file = os.getenv("ENV_FILE", ".env.webui")
    dotenv.load_dotenv(dotenv_path=env_file)

    parser = argparse.ArgumentParser(description="Gradio App")

    # Register the webui, model and API option groups, in that order.
    for register in (setup_webui_args, setup_model_args, setup_api_args):
        register(parser)

    args = parser.parse_args()

    process_model_args(args)
    process_webui_args(args)