zhzluke96
commited on
Commit
·
02e90e4
1
Parent(s):
d6fe286
update
Browse files- modules/ChatTTS/ChatTTS/core.py +21 -5
- modules/ChatTTS/ChatTTS/model/dvae.py +3 -3
- modules/ChatTTS/ChatTTS/model/gpt.py +2 -3
- modules/ChatTTS/ChatTTS/utils/gpu_utils.py +3 -1
- modules/ChatTTS/ChatTTS/utils/infer_utils.py +5 -5
- modules/SynthesizeSegments.py +2 -2
- modules/api/Api.py +12 -1
- modules/api/impl/google_api.py +16 -3
- modules/api/impl/models_api.py +11 -0
- modules/api/impl/openai_api.py +19 -7
- modules/api/impl/ping_api.py +8 -0
- modules/api/utils.py +6 -2
- modules/config.py +2 -8
- modules/devices/__init__.py +0 -0
- modules/devices/devices.py +160 -0
- modules/devices/mac_devices.py +42 -0
- modules/generate_audio.py +33 -5
- modules/models.py +24 -20
- modules/normalization.py +47 -2
- modules/refiner.py +1 -1
- modules/speaker.py +11 -2
- modules/synthesize_audio.py +2 -1
- modules/utils/JsonObject.py +113 -0
- modules/utils/cache.py +92 -0
- modules/utils/zh_normalization/text_normlization.py +3 -3
- webui.py +49 -22
modules/ChatTTS/ChatTTS/core.py
CHANGED
@@ -101,13 +101,27 @@ class Chat:
|
|
101 |
tokenizer_path: str = None,
|
102 |
device: str = None,
|
103 |
compile: bool = True,
|
|
|
|
|
|
|
|
|
|
|
104 |
):
|
105 |
if not device:
|
106 |
device = select_device(4096)
|
107 |
self.logger.log(logging.INFO, f"use {device}")
|
108 |
|
|
|
|
|
|
|
|
|
|
|
109 |
if vocos_config_path:
|
110 |
-
vocos =
|
|
|
|
|
|
|
|
|
111 |
assert vocos_ckpt_path, "vocos_ckpt_path should not be None"
|
112 |
vocos.load_state_dict(torch.load(vocos_ckpt_path))
|
113 |
self.pretrain_models["vocos"] = vocos
|
@@ -115,7 +129,7 @@ class Chat:
|
|
115 |
|
116 |
if dvae_config_path:
|
117 |
cfg = OmegaConf.load(dvae_config_path)
|
118 |
-
dvae = DVAE(**cfg).to(device).eval()
|
119 |
assert dvae_ckpt_path, "dvae_ckpt_path should not be None"
|
120 |
dvae.load_state_dict(torch.load(dvae_ckpt_path, map_location=device))
|
121 |
self.pretrain_models["dvae"] = dvae
|
@@ -123,7 +137,7 @@ class Chat:
|
|
123 |
|
124 |
if gpt_config_path:
|
125 |
cfg = OmegaConf.load(gpt_config_path)
|
126 |
-
gpt = GPT_warpper(**cfg).to(device).eval()
|
127 |
assert gpt_ckpt_path, "gpt_ckpt_path should not be None"
|
128 |
gpt.load_state_dict(torch.load(gpt_ckpt_path, map_location=device))
|
129 |
if compile and "cuda" in str(device):
|
@@ -136,12 +150,14 @@ class Chat:
|
|
136 |
assert os.path.exists(
|
137 |
spk_stat_path
|
138 |
), f"Missing spk_stat.pt: {spk_stat_path}"
|
139 |
-
self.pretrain_models["spk_stat"] = torch.load(spk_stat_path).to(
|
|
|
|
|
140 |
self.logger.log(logging.INFO, "gpt loaded.")
|
141 |
|
142 |
if decoder_config_path:
|
143 |
cfg = OmegaConf.load(decoder_config_path)
|
144 |
-
decoder = DVAE(**cfg).to(device).eval()
|
145 |
assert decoder_ckpt_path, "decoder_ckpt_path should not be None"
|
146 |
decoder.load_state_dict(torch.load(decoder_ckpt_path, map_location=device))
|
147 |
self.pretrain_models["decoder"] = decoder
|
|
|
101 |
tokenizer_path: str = None,
|
102 |
device: str = None,
|
103 |
compile: bool = True,
|
104 |
+
dtype: torch.dtype = torch.float32,
|
105 |
+
dtype_vocos: torch.dtype = None,
|
106 |
+
dtype_dvae: torch.dtype = None,
|
107 |
+
dtype_gpt: torch.dtype = None,
|
108 |
+
dtype_decoder: torch.dtype = None,
|
109 |
):
|
110 |
if not device:
|
111 |
device = select_device(4096)
|
112 |
self.logger.log(logging.INFO, f"use {device}")
|
113 |
|
114 |
+
dtype_vocos = dtype_vocos or dtype
|
115 |
+
dtype_dvae = dtype_dvae or dtype
|
116 |
+
dtype_gpt = dtype_gpt or dtype
|
117 |
+
dtype_decoder = dtype_decoder or dtype
|
118 |
+
|
119 |
if vocos_config_path:
|
120 |
+
vocos = (
|
121 |
+
Vocos.from_hparams(vocos_config_path)
|
122 |
+
.to(device=device, dtype=dtype_vocos)
|
123 |
+
.eval()
|
124 |
+
)
|
125 |
assert vocos_ckpt_path, "vocos_ckpt_path should not be None"
|
126 |
vocos.load_state_dict(torch.load(vocos_ckpt_path))
|
127 |
self.pretrain_models["vocos"] = vocos
|
|
|
129 |
|
130 |
if dvae_config_path:
|
131 |
cfg = OmegaConf.load(dvae_config_path)
|
132 |
+
dvae = DVAE(**cfg).to(device=device, dtype=dtype_dvae).eval()
|
133 |
assert dvae_ckpt_path, "dvae_ckpt_path should not be None"
|
134 |
dvae.load_state_dict(torch.load(dvae_ckpt_path, map_location=device))
|
135 |
self.pretrain_models["dvae"] = dvae
|
|
|
137 |
|
138 |
if gpt_config_path:
|
139 |
cfg = OmegaConf.load(gpt_config_path)
|
140 |
+
gpt = GPT_warpper(**cfg).to(device=device, dtype=dtype_gpt).eval()
|
141 |
assert gpt_ckpt_path, "gpt_ckpt_path should not be None"
|
142 |
gpt.load_state_dict(torch.load(gpt_ckpt_path, map_location=device))
|
143 |
if compile and "cuda" in str(device):
|
|
|
150 |
assert os.path.exists(
|
151 |
spk_stat_path
|
152 |
), f"Missing spk_stat.pt: {spk_stat_path}"
|
153 |
+
self.pretrain_models["spk_stat"] = torch.load(spk_stat_path).to(
|
154 |
+
device=device, dtype=dtype
|
155 |
+
)
|
156 |
self.logger.log(logging.INFO, "gpt loaded.")
|
157 |
|
158 |
if decoder_config_path:
|
159 |
cfg = OmegaConf.load(decoder_config_path)
|
160 |
+
decoder = DVAE(**cfg).to(device=device, dtype=dtype_decoder).eval()
|
161 |
assert decoder_ckpt_path, "decoder_ckpt_path should not be None"
|
162 |
decoder.load_state_dict(torch.load(decoder_ckpt_path, map_location=device))
|
163 |
self.pretrain_models["decoder"] = decoder
|
modules/ChatTTS/ChatTTS/model/dvae.py
CHANGED
@@ -143,9 +143,9 @@ class DVAE(nn.Module):
|
|
143 |
else:
|
144 |
vq_feats = inp.detach().clone()
|
145 |
|
146 |
-
|
147 |
-
|
148 |
-
|
149 |
|
150 |
vq_feats = vq_feats.transpose(1, 2)
|
151 |
dec_out = self.decoder(input=vq_feats)
|
|
|
143 |
else:
|
144 |
vq_feats = inp.detach().clone()
|
145 |
|
146 |
+
vq_feats = vq_feats.view(
|
147 |
+
(vq_feats.size(0), 2, vq_feats.size(1)//2, vq_feats.size(2)),
|
148 |
+
).permute(0, 2, 3, 1).flatten(2)
|
149 |
|
150 |
vq_feats = vq_feats.transpose(1, 2)
|
151 |
dec_out = self.decoder(input=vq_feats)
|
modules/ChatTTS/ChatTTS/model/gpt.py
CHANGED
@@ -190,6 +190,8 @@ class GPT_warpper(nn.Module):
|
|
190 |
attention_mask_cache[:, :attention_mask.shape[1]] = attention_mask
|
191 |
|
192 |
for i in tqdm(range(max_new_token)):
|
|
|
|
|
193 |
|
194 |
model_input = self.prepare_inputs_for_generation(inputs_ids,
|
195 |
outputs.past_key_values if i!=0 else None,
|
@@ -250,9 +252,6 @@ class GPT_warpper(nn.Module):
|
|
250 |
|
251 |
end_idx = end_idx + (~finish).int()
|
252 |
|
253 |
-
if finish.all():
|
254 |
-
break
|
255 |
-
|
256 |
inputs_ids = [inputs_ids[idx, start_idx: start_idx+i] for idx, i in enumerate(end_idx.int())]
|
257 |
inputs_ids = [i[:, 0] for i in inputs_ids] if infer_text else inputs_ids
|
258 |
|
|
|
190 |
attention_mask_cache[:, :attention_mask.shape[1]] = attention_mask
|
191 |
|
192 |
for i in tqdm(range(max_new_token)):
|
193 |
+
if finish.all():
|
194 |
+
continue
|
195 |
|
196 |
model_input = self.prepare_inputs_for_generation(inputs_ids,
|
197 |
outputs.past_key_values if i!=0 else None,
|
|
|
252 |
|
253 |
end_idx = end_idx + (~finish).int()
|
254 |
|
|
|
|
|
|
|
255 |
inputs_ids = [inputs_ids[idx, start_idx: start_idx+i] for idx, i in enumerate(end_idx.int())]
|
256 |
inputs_ids = [i[:, 0] for i in inputs_ids] if infer_text else inputs_ids
|
257 |
|
modules/ChatTTS/ChatTTS/utils/gpu_utils.py
CHANGED
@@ -16,8 +16,10 @@ def select_device(min_memory = 2048):
|
|
16 |
if free_memory_mb < min_memory:
|
17 |
logger.log(logging.WARNING, f'GPU {selected_gpu} has {round(free_memory_mb, 2)} MB memory left.')
|
18 |
device = torch.device('cpu')
|
|
|
|
|
19 |
else:
|
20 |
logger.log(logging.WARNING, f'No GPU found, use CPU instead')
|
21 |
device = torch.device('cpu')
|
22 |
|
23 |
-
return device
|
|
|
16 |
if free_memory_mb < min_memory:
|
17 |
logger.log(logging.WARNING, f'GPU {selected_gpu} has {round(free_memory_mb, 2)} MB memory left.')
|
18 |
device = torch.device('cpu')
|
19 |
+
elif torch.backends.mps.is_available() and torch.backends.mps.is_built():
|
20 |
+
device = torch.device('mps')
|
21 |
else:
|
22 |
logger.log(logging.WARNING, f'No GPU found, use CPU instead')
|
23 |
device = torch.device('cpu')
|
24 |
|
25 |
+
return device
|
modules/ChatTTS/ChatTTS/utils/infer_utils.py
CHANGED
@@ -101,8 +101,8 @@ character_map = {
|
|
101 |
"!": ".",
|
102 |
"(": ",",
|
103 |
")": ",",
|
104 |
-
|
105 |
-
|
106 |
">": ",",
|
107 |
"<": ",",
|
108 |
"-": ",",
|
@@ -131,11 +131,11 @@ halfwidth_2_fullwidth_map = {
|
|
131 |
">": ">",
|
132 |
"?": "?",
|
133 |
"@": "@",
|
134 |
-
|
135 |
"\\": "\",
|
136 |
-
|
137 |
"^": "^",
|
138 |
-
|
139 |
"`": "`",
|
140 |
"{": "{",
|
141 |
"|": "|",
|
|
|
101 |
"!": ".",
|
102 |
"(": ",",
|
103 |
")": ",",
|
104 |
+
"[": ",",
|
105 |
+
"]": ",",
|
106 |
">": ",",
|
107 |
"<": ",",
|
108 |
"-": ",",
|
|
|
131 |
">": ">",
|
132 |
"?": "?",
|
133 |
"@": "@",
|
134 |
+
"[": "[",
|
135 |
"\\": "\",
|
136 |
+
"]": "]",
|
137 |
"^": "^",
|
138 |
+
"_": "_",
|
139 |
"`": "`",
|
140 |
"{": "{",
|
141 |
"|": "|",
|
modules/SynthesizeSegments.py
CHANGED
@@ -1,6 +1,6 @@
|
|
1 |
import numpy as np
|
2 |
from pydub import AudioSegment
|
3 |
-
from typing import Any, List, Dict
|
4 |
from scipy.io.wavfile import write
|
5 |
import io
|
6 |
from modules.utils.audio import time_stretch, pitch_shift
|
@@ -211,7 +211,7 @@ def generate_audio_segment(
|
|
211 |
return AudioSegment.from_file(byte_io, format="wav")
|
212 |
|
213 |
|
214 |
-
def synthesize_segment(segment: Dict[str, Any]) -> AudioSegment
|
215 |
if "break" in segment:
|
216 |
pause_segment = AudioSegment.silent(duration=segment["break"])
|
217 |
return pause_segment
|
|
|
1 |
import numpy as np
|
2 |
from pydub import AudioSegment
|
3 |
+
from typing import Any, List, Dict, Union
|
4 |
from scipy.io.wavfile import write
|
5 |
import io
|
6 |
from modules.utils.audio import time_stretch, pitch_shift
|
|
|
211 |
return AudioSegment.from_file(byte_io, format="wav")
|
212 |
|
213 |
|
214 |
+
def synthesize_segment(segment: Dict[str, Any]) -> Union[AudioSegment, None]:
|
215 |
if "break" in segment:
|
216 |
pause_segment = AudioSegment.silent(duration=segment["break"])
|
217 |
return pause_segment
|
modules/api/Api.py
CHANGED
@@ -27,7 +27,18 @@ class APIManager:
|
|
27 |
def __init__(self, no_docs=False, exclude_patterns=[]):
|
28 |
self.app = FastAPI(
|
29 |
title="ChatTTS Forge API",
|
30 |
-
description="
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
31 |
version="0.1.0",
|
32 |
redoc_url=None if no_docs else "/redoc",
|
33 |
docs_url=None if no_docs else "/docs",
|
|
|
27 |
def __init__(self, no_docs=False, exclude_patterns=[]):
|
28 |
self.app = FastAPI(
|
29 |
title="ChatTTS Forge API",
|
30 |
+
description="""
|
31 |
+
ChatTTS-Forge 是一个功能强大的文本转语音生成工具,支持通过类 SSML 语法生成丰富的音频长文本,并提供全面的 API 服务,适用于各种场景。<br/>
|
32 |
+
ChatTTS-Forge is a powerful text-to-speech generation tool that supports generating rich audio long texts through class SSML syntax
|
33 |
+
|
34 |
+
项目地址: [https://github.com/lenML/ChatTTS-Forge](https://github.com/lenML/ChatTTS-Forge)
|
35 |
+
|
36 |
+
> 所有生成音频的 POST api都无法在此页面调试,调试建议使用 playground <br/>
|
37 |
+
> All audio generation POST APIs cannot be debugged on this page, it is recommended to use playground for debugging
|
38 |
+
|
39 |
+
> 如果你不熟悉本系统,建议从这个一键脚本开始,在colab中尝试一下:<br/>
|
40 |
+
> [https://colab.research.google.com/github/lenML/ChatTTS-Forge/blob/main/colab.ipynb](https://colab.research.google.com/github/lenML/ChatTTS-Forge/blob/main/colab.ipynb)
|
41 |
+
""",
|
42 |
version="0.1.0",
|
43 |
redoc_url=None if no_docs else "/redoc",
|
44 |
docs_url=None if no_docs else "/docs",
|
modules/api/impl/google_api.py
CHANGED
@@ -30,6 +30,7 @@ class SynthesisInput(BaseModel):
|
|
30 |
|
31 |
class VoiceSelectionParams(BaseModel):
|
32 |
languageCode: str = "ZH-CN"
|
|
|
33 |
name: str = "female2"
|
34 |
style: str = ""
|
35 |
temperature: float = 0.3
|
@@ -160,6 +161,18 @@ async def google_text_synthesize(request: GoogleTextSynthesizeRequest):
|
|
160 |
|
161 |
|
162 |
def setup(app: APIManager):
|
163 |
-
app.post(
|
164 |
-
|
165 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
30 |
|
31 |
class VoiceSelectionParams(BaseModel):
|
32 |
languageCode: str = "ZH-CN"
|
33 |
+
|
34 |
name: str = "female2"
|
35 |
style: str = ""
|
36 |
temperature: float = 0.3
|
|
|
161 |
|
162 |
|
163 |
def setup(app: APIManager):
|
164 |
+
app.post(
|
165 |
+
"/v1/text:synthesize",
|
166 |
+
response_model=GoogleTextSynthesizeResponse,
|
167 |
+
description="""
|
168 |
+
google api document: <br/>
|
169 |
+
[https://cloud.google.com/text-to-speech/docs/reference/rest/v1/text/synthesize](https://cloud.google.com/text-to-speech/docs/reference/rest/v1/text/synthesize)
|
170 |
+
|
171 |
+
- 多个属性在本系统中无用仅仅是为了兼容google api
|
172 |
+
- voice 中的 topP, topK, temperature 为本系统中的参数
|
173 |
+
- voice.name 即 speaker name (或者speaker seed)
|
174 |
+
- voice.seed 为 infer seed (可在webui中测试具体作用)
|
175 |
+
|
176 |
+
- 编码格式影响的是 audioContent 的二进制格式,所以所有format都是返回带有base64数据的json
|
177 |
+
""",
|
178 |
+
)(google_text_synthesize)
|
modules/api/impl/models_api.py
ADDED
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from modules.api import utils as api_utils
|
2 |
+
from modules.api.Api import APIManager
|
3 |
+
from modules.models import reload_chat_tts
|
4 |
+
|
5 |
+
|
6 |
+
def setup(app: APIManager):
|
7 |
+
@app.get("/v1/models/reload", response_model=api_utils.BaseResponse)
|
8 |
+
async def reload_models():
|
9 |
+
# Reload models
|
10 |
+
reload_chat_tts()
|
11 |
+
return api_utils.success_response("Models reloaded")
|
modules/api/impl/openai_api.py
CHANGED
@@ -28,11 +28,11 @@ class AudioSpeechRequest(BaseModel):
|
|
28 |
model: str = "chattts-4w"
|
29 |
voice: str = "female2"
|
30 |
response_format: Literal["mp3", "wav"] = "mp3"
|
31 |
-
speed:
|
32 |
style: str = ""
|
33 |
# 是否开启batch合成,小于等于1表示不适用batch
|
34 |
# 开启batch合成会自动分割句子
|
35 |
-
batch_size: int = Field(1, ge=1, le=
|
36 |
spliter_threshold: float = Field(
|
37 |
100, ge=10, le=1024, description="Threshold for sentence spliter"
|
38 |
)
|
@@ -64,8 +64,8 @@ async def openai_speech_api(
|
|
64 |
params = api_utils.calc_spk_style(spk=voice, style=style)
|
65 |
|
66 |
spk = params.get("spk", -1)
|
67 |
-
seed = params.get("seed", 42)
|
68 |
-
temperature = params.get("temperature", 0.3)
|
69 |
prompt1 = params.get("prompt1", "")
|
70 |
prompt2 = params.get("prompt2", "")
|
71 |
prefix = params.get("prefix", "")
|
@@ -107,6 +107,18 @@ async def openai_speech_api(
|
|
107 |
|
108 |
|
109 |
def setup(api_manager: APIManager):
|
110 |
-
api_manager.post(
|
111 |
-
|
112 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
28 |
model: str = "chattts-4w"
|
29 |
voice: str = "female2"
|
30 |
response_format: Literal["mp3", "wav"] = "mp3"
|
31 |
+
speed: float = Field(1, ge=0.1, le=10, description="Speed of the audio")
|
32 |
style: str = ""
|
33 |
# 是否开启batch合成,小于等于1表示不适用batch
|
34 |
# 开启batch合成会自动分割句子
|
35 |
+
batch_size: int = Field(1, ge=1, le=20, description="Batch size")
|
36 |
spliter_threshold: float = Field(
|
37 |
100, ge=10, le=1024, description="Threshold for sentence spliter"
|
38 |
)
|
|
|
64 |
params = api_utils.calc_spk_style(spk=voice, style=style)
|
65 |
|
66 |
spk = params.get("spk", -1)
|
67 |
+
seed = params.get("seed", request.seed or 42)
|
68 |
+
temperature = params.get("temperature", request.temperature or 0.3)
|
69 |
prompt1 = params.get("prompt1", "")
|
70 |
prompt2 = params.get("prompt2", "")
|
71 |
prefix = params.get("prefix", "")
|
|
|
107 |
|
108 |
|
109 |
def setup(api_manager: APIManager):
|
110 |
+
api_manager.post(
|
111 |
+
"/v1/audio/speech",
|
112 |
+
response_class=FileResponse,
|
113 |
+
description="""
|
114 |
+
openai api document:
|
115 |
+
[https://platform.openai.com/docs/guides/text-to-speech](https://platform.openai.com/docs/guides/text-to-speech)
|
116 |
+
|
117 |
+
以下属性为本系统自定义属性,不在openai文档中:
|
118 |
+
- batch_size: 是否开启batch合成,小于等于1表示不使用batch (不推荐)
|
119 |
+
- spliter_threshold: 开启batch合成时,句子分割的阈值
|
120 |
+
- style: 风格
|
121 |
+
|
122 |
+
> model 可填任意值
|
123 |
+
""",
|
124 |
+
)(openai_speech_api)
|
modules/api/impl/ping_api.py
ADDED
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from modules.api import utils as api_utils
|
2 |
+
from modules.api.Api import APIManager
|
3 |
+
|
4 |
+
|
5 |
+
def setup(app: APIManager):
|
6 |
+
@app.get("/v1/ping", response_model=api_utils.BaseResponse)
|
7 |
+
async def ping():
|
8 |
+
return {"message": "ok", "data": "pong"}
|
modules/api/utils.py
CHANGED
@@ -1,5 +1,5 @@
|
|
1 |
from pydantic import BaseModel
|
2 |
-
from typing import Any
|
3 |
|
4 |
import torch
|
5 |
|
@@ -36,6 +36,10 @@ class BaseResponse(BaseModel):
|
|
36 |
}
|
37 |
|
38 |
|
|
|
|
|
|
|
|
|
39 |
def wav_to_mp3(wav_data, bitrate="48k"):
|
40 |
audio = AudioSegment.from_wav(
|
41 |
wav_data,
|
@@ -51,7 +55,7 @@ def to_number(value, t, default=0):
|
|
51 |
return default
|
52 |
|
53 |
|
54 |
-
def calc_spk_style(spk: str
|
55 |
voice_attrs = {
|
56 |
"spk": None,
|
57 |
"seed": None,
|
|
|
1 |
from pydantic import BaseModel
|
2 |
+
from typing import Any, Union
|
3 |
|
4 |
import torch
|
5 |
|
|
|
36 |
}
|
37 |
|
38 |
|
39 |
+
def success_response(data: Any, message: str = "Success") -> BaseResponse:
|
40 |
+
return BaseResponse(message=message, data=data)
|
41 |
+
|
42 |
+
|
43 |
def wav_to_mp3(wav_data, bitrate="48k"):
|
44 |
audio = AudioSegment.from_wav(
|
45 |
wav_data,
|
|
|
55 |
return default
|
56 |
|
57 |
|
58 |
+
def calc_spk_style(spk: Union[str, int], style: Union[str, int]):
|
59 |
voice_attrs = {
|
60 |
"spk": None,
|
61 |
"seed": None,
|
modules/config.py
CHANGED
@@ -1,11 +1,5 @@
|
|
1 |
-
|
2 |
|
3 |
-
|
4 |
-
|
5 |
-
args = {}
|
6 |
|
7 |
api = None
|
8 |
-
|
9 |
-
model_config = {"half": False}
|
10 |
-
|
11 |
-
disable_tqdm = False
|
|
|
1 |
+
from modules.utils.JsonObject import JsonObject
|
2 |
|
3 |
+
runtime_env_vars = JsonObject({})
|
|
|
|
|
4 |
|
5 |
api = None
|
|
|
|
|
|
|
|
modules/devices/__init__.py
ADDED
File without changes
|
modules/devices/devices.py
ADDED
@@ -0,0 +1,160 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from functools import lru_cache
|
2 |
+
import sys
|
3 |
+
import torch
|
4 |
+
from modules import config
|
5 |
+
|
6 |
+
import logging
|
7 |
+
|
8 |
+
logger = logging.getLogger(__name__)
|
9 |
+
|
10 |
+
if sys.platform == "darwin":
|
11 |
+
from modules.devices import mac_devices
|
12 |
+
|
13 |
+
|
14 |
+
def has_mps() -> bool:
|
15 |
+
if sys.platform != "darwin":
|
16 |
+
return False
|
17 |
+
else:
|
18 |
+
return mac_devices.has_mps
|
19 |
+
|
20 |
+
|
21 |
+
def get_cuda_device_id():
|
22 |
+
return (
|
23 |
+
int(config.runtime_env_vars.device_id)
|
24 |
+
if config.runtime_env_vars.device_id is not None
|
25 |
+
and config.runtime_env_vars.device_id.isdigit()
|
26 |
+
else 0
|
27 |
+
) or torch.cuda.current_device()
|
28 |
+
|
29 |
+
|
30 |
+
def get_cuda_device_string():
|
31 |
+
if config.runtime_env_vars.device_id is not None:
|
32 |
+
return f"cuda:{config.runtime_env_vars.device_id}"
|
33 |
+
|
34 |
+
return "cuda"
|
35 |
+
|
36 |
+
|
37 |
+
def get_available_gpus() -> list[tuple[int, int]]:
|
38 |
+
"""
|
39 |
+
Get the list of available GPUs and their free memory.
|
40 |
+
|
41 |
+
:return: A list of tuples where each tuple contains (GPU index, free memory in bytes).
|
42 |
+
"""
|
43 |
+
available_gpus = []
|
44 |
+
for i in range(torch.cuda.device_count()):
|
45 |
+
props = torch.cuda.get_device_properties(i)
|
46 |
+
free_memory = props.total_memory - torch.cuda.memory_reserved(i)
|
47 |
+
available_gpus.append((i, free_memory))
|
48 |
+
return available_gpus
|
49 |
+
|
50 |
+
|
51 |
+
def get_memory_available_gpus(min_memory=2048):
|
52 |
+
available_gpus = get_available_gpus()
|
53 |
+
memory_available_gpus = [
|
54 |
+
gpu for gpu, free_memory in available_gpus if free_memory > min_memory
|
55 |
+
]
|
56 |
+
return memory_available_gpus
|
57 |
+
|
58 |
+
|
59 |
+
def get_target_device_id_or_memory_available_gpu():
|
60 |
+
memory_available_gpus = get_memory_available_gpus()
|
61 |
+
device_id = get_cuda_device_id()
|
62 |
+
if device_id not in memory_available_gpus:
|
63 |
+
if len(memory_available_gpus) != 0:
|
64 |
+
logger.warning(
|
65 |
+
f"Device {device_id} is not available or does not have enough memory. will try to use {memory_available_gpus}"
|
66 |
+
)
|
67 |
+
config.runtime_env_vars.device_id = str(memory_available_gpus[0])
|
68 |
+
else:
|
69 |
+
logger.warning(
|
70 |
+
f"Device {device_id} is not available or does not have enough memory. Using CPU instead."
|
71 |
+
)
|
72 |
+
return "cpu"
|
73 |
+
return get_cuda_device_string()
|
74 |
+
|
75 |
+
|
76 |
+
def get_optimal_device_name():
|
77 |
+
if config.runtime_env_vars.use_cpu:
|
78 |
+
return "cpu"
|
79 |
+
|
80 |
+
if torch.cuda.is_available():
|
81 |
+
return get_target_device_id_or_memory_available_gpu()
|
82 |
+
|
83 |
+
if has_mps():
|
84 |
+
return "mps"
|
85 |
+
|
86 |
+
return "cpu"
|
87 |
+
|
88 |
+
|
89 |
+
def get_optimal_device():
|
90 |
+
return torch.device(get_optimal_device_name())
|
91 |
+
|
92 |
+
|
93 |
+
def get_device_for(task):
|
94 |
+
if task in config.cmd_opts.use_cpu or "all" in config.cmd_opts.use_cpu:
|
95 |
+
return cpu
|
96 |
+
|
97 |
+
return get_optimal_device()
|
98 |
+
|
99 |
+
|
100 |
+
def torch_gc():
|
101 |
+
try:
|
102 |
+
if torch.cuda.is_available():
|
103 |
+
with torch.cuda.device(get_cuda_device_string()):
|
104 |
+
torch.cuda.empty_cache()
|
105 |
+
torch.cuda.ipc_collect()
|
106 |
+
|
107 |
+
if has_mps():
|
108 |
+
mac_devices.torch_mps_gc()
|
109 |
+
except Exception as e:
|
110 |
+
logger.error(f"Error in torch_gc", exc_info=True)
|
111 |
+
|
112 |
+
|
113 |
+
cpu: torch.device = torch.device("cpu")
|
114 |
+
device: torch.device = get_optimal_device()
|
115 |
+
dtype: torch.dtype = torch.float32
|
116 |
+
dtype_dvae: torch.dtype = torch.float32
|
117 |
+
dtype_vocos: torch.dtype = torch.float32
|
118 |
+
dtype_gpt: torch.dtype = torch.float32
|
119 |
+
dtype_decoder: torch.dtype = torch.float32
|
120 |
+
|
121 |
+
|
122 |
+
def reset_device():
|
123 |
+
if config.runtime_env_vars.half:
|
124 |
+
global dtype
|
125 |
+
global dtype_dvae
|
126 |
+
global dtype_vocos
|
127 |
+
global dtype_gpt
|
128 |
+
global dtype_decoder
|
129 |
+
dtype = torch.float16
|
130 |
+
dtype_dvae = torch.float16
|
131 |
+
dtype_vocos = torch.float16
|
132 |
+
dtype_gpt = torch.float16
|
133 |
+
dtype_decoder = torch.float16
|
134 |
+
|
135 |
+
logger.info("Using half precision: torch.float16")
|
136 |
+
|
137 |
+
if (
|
138 |
+
config.runtime_env_vars.device_id is not None
|
139 |
+
or config.runtime_env_vars.use_cpu is not None
|
140 |
+
):
|
141 |
+
global device
|
142 |
+
device = get_optimal_device()
|
143 |
+
|
144 |
+
logger.info(f"Using device: {device}")
|
145 |
+
|
146 |
+
|
147 |
+
@lru_cache
|
148 |
+
def first_time_calculation():
|
149 |
+
"""
|
150 |
+
just do any calculation with pytorch layers - the first time this is done it allocaltes about 700MB of memory and
|
151 |
+
spends about 2.7 seconds doing that, at least wih NVidia.
|
152 |
+
"""
|
153 |
+
|
154 |
+
x = torch.zeros((1, 1)).to(device, dtype)
|
155 |
+
linear = torch.nn.Linear(1, 1).to(device, dtype)
|
156 |
+
linear(x)
|
157 |
+
|
158 |
+
x = torch.zeros((1, 1, 3, 3)).to(device, dtype)
|
159 |
+
conv2d = torch.nn.Conv2d(1, 1, (3, 3)).to(device, dtype)
|
160 |
+
conv2d(x)
|
modules/devices/mac_devices.py
ADDED
@@ -0,0 +1,42 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import logging
|
3 |
+
from packaging import version
|
4 |
+
import torch.backends
|
5 |
+
import torch.backends.mps
|
6 |
+
|
7 |
+
logger = logging.getLogger(__name__)
|
8 |
+
|
9 |
+
|
10 |
+
def check_for_mps() -> bool:
|
11 |
+
if version.parse(torch.__version__) <= version.parse("2.0.1"):
|
12 |
+
if not getattr(torch, "has_mps", False):
|
13 |
+
return False
|
14 |
+
try:
|
15 |
+
torch.zeros(1).to(torch.device("mps"))
|
16 |
+
return True
|
17 |
+
except Exception:
|
18 |
+
return False
|
19 |
+
else:
|
20 |
+
try:
|
21 |
+
return torch.backends.mps.is_available() and torch.backends.mps.is_built()
|
22 |
+
except:
|
23 |
+
logger.warning("MPS garbage collection failed", exc_info=True)
|
24 |
+
return False
|
25 |
+
|
26 |
+
|
27 |
+
has_mps = check_for_mps()
|
28 |
+
|
29 |
+
|
30 |
+
def torch_mps_gc() -> None:
|
31 |
+
try:
|
32 |
+
from torch.mps import empty_cache
|
33 |
+
|
34 |
+
empty_cache()
|
35 |
+
except Exception:
|
36 |
+
logger.warning("MPS garbage collection failed", exc_info=True)
|
37 |
+
|
38 |
+
|
39 |
+
if __name__ == "__main__":
|
40 |
+
print(torch.__version__)
|
41 |
+
print(has_mps)
|
42 |
+
torch_mps_gc()
|
modules/generate_audio.py
CHANGED
@@ -8,18 +8,20 @@ from modules import models, config
|
|
8 |
|
9 |
import logging
|
10 |
|
11 |
-
from modules import devices
|
|
|
|
|
|
|
12 |
|
13 |
logger = logging.getLogger(__name__)
|
14 |
|
15 |
|
16 |
-
@torch.inference_mode()
|
17 |
def generate_audio(
|
18 |
text: str,
|
19 |
temperature: float = 0.3,
|
20 |
top_P: float = 0.7,
|
21 |
top_K: float = 20,
|
22 |
-
spk: int
|
23 |
infer_seed: int = -1,
|
24 |
use_decoder: bool = True,
|
25 |
prompt1: str = "",
|
@@ -48,7 +50,7 @@ def generate_audio_batch(
|
|
48 |
temperature: float = 0.3,
|
49 |
top_P: float = 0.7,
|
50 |
top_K: float = 20,
|
51 |
-
spk: int
|
52 |
infer_seed: int = -1,
|
53 |
use_decoder: bool = True,
|
54 |
prompt1: str = "",
|
@@ -65,7 +67,7 @@ def generate_audio_batch(
|
|
65 |
"prompt2": prompt2 or "",
|
66 |
"prefix": prefix or "",
|
67 |
"repetition_penalty": 1.0,
|
68 |
-
"disable_tqdm": config.
|
69 |
}
|
70 |
|
71 |
if isinstance(spk, int):
|
@@ -103,6 +105,32 @@ def generate_audio_batch(
|
|
103 |
return [(sample_rate, np.array(wav).flatten().astype(np.float32)) for wav in wavs]
|
104 |
|
105 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
106 |
if __name__ == "__main__":
|
107 |
import soundfile as sf
|
108 |
|
|
|
8 |
|
9 |
import logging
|
10 |
|
11 |
+
from modules.devices import devices
|
12 |
+
from typing import Union
|
13 |
+
|
14 |
+
from modules.utils.cache import conditional_cache
|
15 |
|
16 |
logger = logging.getLogger(__name__)
|
17 |
|
18 |
|
|
|
19 |
def generate_audio(
|
20 |
text: str,
|
21 |
temperature: float = 0.3,
|
22 |
top_P: float = 0.7,
|
23 |
top_K: float = 20,
|
24 |
+
spk: Union[int, Speaker] = -1,
|
25 |
infer_seed: int = -1,
|
26 |
use_decoder: bool = True,
|
27 |
prompt1: str = "",
|
|
|
50 |
temperature: float = 0.3,
|
51 |
top_P: float = 0.7,
|
52 |
top_K: float = 20,
|
53 |
+
spk: Union[int, Speaker] = -1,
|
54 |
infer_seed: int = -1,
|
55 |
use_decoder: bool = True,
|
56 |
prompt1: str = "",
|
|
|
67 |
"prompt2": prompt2 or "",
|
68 |
"prefix": prefix or "",
|
69 |
"repetition_penalty": 1.0,
|
70 |
+
"disable_tqdm": config.runtime_env_vars.off_tqdm,
|
71 |
}
|
72 |
|
73 |
if isinstance(spk, int):
|
|
|
105 |
return [(sample_rate, np.array(wav).flatten().astype(np.float32)) for wav in wavs]
|
106 |
|
107 |
|
108 |
+
lru_cache_enabled = False
|
109 |
+
|
110 |
+
|
111 |
+
def setup_lru_cache():
|
112 |
+
global generate_audio_batch
|
113 |
+
global lru_cache_enabled
|
114 |
+
|
115 |
+
if lru_cache_enabled:
|
116 |
+
return
|
117 |
+
lru_cache_enabled = True
|
118 |
+
|
119 |
+
def should_cache(*args, **kwargs):
|
120 |
+
spk_seed = kwargs.get("spk", -1)
|
121 |
+
infer_seed = kwargs.get("infer_seed", -1)
|
122 |
+
return spk_seed != -1 and infer_seed != -1
|
123 |
+
|
124 |
+
lru_size = config.runtime_env_vars.lru_size
|
125 |
+
if isinstance(lru_size, int):
|
126 |
+
generate_audio_batch = conditional_cache(lru_size, should_cache)(
|
127 |
+
generate_audio_batch
|
128 |
+
)
|
129 |
+
logger.info(f"LRU cache enabled with size {lru_size}")
|
130 |
+
else:
|
131 |
+
logger.debug(f"LRU cache failed to enable, invalid size {lru_size}")
|
132 |
+
|
133 |
+
|
134 |
if __name__ == "__main__":
|
135 |
import soundfile as sf
|
136 |
|
modules/models.py
CHANGED
@@ -1,15 +1,11 @@
|
|
1 |
-
from modules.ChatTTS import ChatTTS
|
2 |
import torch
|
3 |
-
|
4 |
from modules import config
|
|
|
5 |
|
6 |
import logging
|
7 |
|
8 |
logger = logging.getLogger(__name__)
|
9 |
-
|
10 |
-
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
|
11 |
-
print(f"device use {device}")
|
12 |
-
|
13 |
chat_tts = None
|
14 |
|
15 |
|
@@ -17,25 +13,33 @@ def load_chat_tts():
|
|
17 |
global chat_tts
|
18 |
if chat_tts:
|
19 |
return chat_tts
|
|
|
20 |
chat_tts = ChatTTS.Chat()
|
21 |
chat_tts.load_models(
|
22 |
-
compile=config.
|
23 |
source="local",
|
24 |
local_path="./models/ChatTTS",
|
25 |
-
device=device,
|
|
|
|
|
|
|
|
|
|
|
26 |
)
|
27 |
|
28 |
-
|
29 |
-
logging.info("half precision enabled")
|
30 |
-
for model_name, model in chat_tts.pretrain_models.items():
|
31 |
-
if isinstance(model, torch.nn.Module):
|
32 |
-
model.cpu()
|
33 |
-
if torch.cuda.is_available():
|
34 |
-
torch.cuda.empty_cache()
|
35 |
-
model.half()
|
36 |
-
if torch.cuda.is_available():
|
37 |
-
model.cuda()
|
38 |
-
model.eval()
|
39 |
-
logger.log(logging.INFO, f"{model_name} converted to half precision.")
|
40 |
|
41 |
return chat_tts
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
import torch
|
2 |
+
from modules.ChatTTS import ChatTTS
|
3 |
from modules import config
|
4 |
+
from modules.devices import devices
|
5 |
|
6 |
import logging
|
7 |
|
8 |
logger = logging.getLogger(__name__)
|
|
|
|
|
|
|
|
|
9 |
chat_tts = None
|
10 |
|
11 |
|
|
|
13 |
global chat_tts
|
14 |
if chat_tts:
|
15 |
return chat_tts
|
16 |
+
|
17 |
chat_tts = ChatTTS.Chat()
|
18 |
chat_tts.load_models(
|
19 |
+
compile=config.runtime_env_vars.compile,
|
20 |
source="local",
|
21 |
local_path="./models/ChatTTS",
|
22 |
+
device=devices.device,
|
23 |
+
dtype=devices.dtype,
|
24 |
+
dtype_vocos=devices.dtype_vocos,
|
25 |
+
dtype_dvae=devices.dtype_dvae,
|
26 |
+
dtype_gpt=devices.dtype_gpt,
|
27 |
+
dtype_decoder=devices.dtype_decoder,
|
28 |
)
|
29 |
|
30 |
+
devices.torch_gc()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
31 |
|
32 |
return chat_tts
|
33 |
+
|
34 |
+
|
35 |
+
def reload_chat_tts():
|
36 |
+
logging.info("Reloading ChatTTS models")
|
37 |
+
global chat_tts
|
38 |
+
if chat_tts:
|
39 |
+
if torch.cuda.is_available():
|
40 |
+
for model_name, model in chat_tts.pretrain_models.items():
|
41 |
+
if isinstance(model, torch.nn.Module):
|
42 |
+
model.cpu()
|
43 |
+
torch.cuda.empty_cache()
|
44 |
+
chat_tts = None
|
45 |
+
return load_chat_tts()
|
modules/normalization.py
CHANGED
@@ -1,6 +1,15 @@
|
|
1 |
from modules.utils.zh_normalization.text_normlization import *
|
2 |
import emojiswitch
|
3 |
from modules.utils.markdown import markdown_to_text
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
4 |
|
5 |
post_normalize_pipeline = []
|
6 |
pre_normalize_pipeline = []
|
@@ -87,12 +96,17 @@ character_map = {
|
|
87 |
">": ",",
|
88 |
"<": ",",
|
89 |
"-": ",",
|
|
|
|
|
|
|
90 |
}
|
91 |
|
92 |
character_to_word = {
|
93 |
" & ": " and ",
|
94 |
}
|
95 |
|
|
|
|
|
96 |
|
97 |
@post_normalize()
|
98 |
def apply_character_to_word(text):
|
@@ -109,7 +123,8 @@ def apply_character_map(text):
|
|
109 |
|
110 |
@post_normalize()
|
111 |
def apply_emoji_map(text):
|
112 |
-
|
|
|
113 |
|
114 |
|
115 |
@post_normalize()
|
@@ -122,6 +137,26 @@ def insert_spaces_between_uppercase(s):
|
|
122 |
)
|
123 |
|
124 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
125 |
@pre_normalize()
|
126 |
def apply_markdown_to_text(text):
|
127 |
if is_markdown(text):
|
@@ -186,7 +221,7 @@ def sentence_normalize(sentence_text: str):
|
|
186 |
pattern = re.compile(r"(\[.+?\])|([^[]+)")
|
187 |
|
188 |
def normalize_part(part):
|
189 |
-
sentences = tx.normalize(part)
|
190 |
dest_text = ""
|
191 |
for sentence in sentences:
|
192 |
sentence = apply_post_normalize(sentence)
|
@@ -244,6 +279,16 @@ console.log('1')
|
|
244 |
“我们是玫瑰花。”花儿们说道。
|
245 |
“啊!”小王子说……。
|
246 |
""",
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
247 |
]
|
248 |
|
249 |
for i, test_case in enumerate(test_cases):
|
|
|
1 |
from modules.utils.zh_normalization.text_normlization import *
|
2 |
import emojiswitch
|
3 |
from modules.utils.markdown import markdown_to_text
|
4 |
+
from modules import models
|
5 |
+
import re
|
6 |
+
|
7 |
+
|
8 |
+
def is_chinese(text):
|
9 |
+
# 中文字符的 Unicode 范围是 \u4e00-\u9fff
|
10 |
+
chinese_pattern = re.compile(r"[\u4e00-\u9fff]")
|
11 |
+
return bool(chinese_pattern.search(text))
|
12 |
+
|
13 |
|
14 |
post_normalize_pipeline = []
|
15 |
pre_normalize_pipeline = []
|
|
|
96 |
">": ",",
|
97 |
"<": ",",
|
98 |
"-": ",",
|
99 |
+
"~": " ",
|
100 |
+
"~": " ",
|
101 |
+
"/": " ",
|
102 |
}
|
103 |
|
104 |
character_to_word = {
|
105 |
" & ": " and ",
|
106 |
}
|
107 |
|
108 |
+
## ---------- post normalize ----------
|
109 |
+
|
110 |
|
111 |
@post_normalize()
|
112 |
def apply_character_to_word(text):
|
|
|
123 |
|
124 |
@post_normalize()
|
125 |
def apply_emoji_map(text):
|
126 |
+
lang = "zh" if is_chinese(text) else "en"
|
127 |
+
return emojiswitch.demojize(text, delimiters=("", ""), lang=lang)
|
128 |
|
129 |
|
130 |
@post_normalize()
|
|
|
137 |
)
|
138 |
|
139 |
|
140 |
+
@post_normalize()
|
141 |
+
def replace_unk_tokens(text):
|
142 |
+
"""
|
143 |
+
把不在字典里的字符替换为 " , "
|
144 |
+
"""
|
145 |
+
chat_tts = models.load_chat_tts()
|
146 |
+
tokenizer = chat_tts.pretrain_models["tokenizer"]
|
147 |
+
vocab = tokenizer.get_vocab()
|
148 |
+
vocab_set = set(vocab.keys())
|
149 |
+
# 添加所有英语字符
|
150 |
+
vocab_set.update(set("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ"))
|
151 |
+
vocab_set.update(set(" \n\r\t"))
|
152 |
+
replaced_chars = [char if char in vocab_set else " , " for char in text]
|
153 |
+
output_text = "".join(replaced_chars)
|
154 |
+
return output_text
|
155 |
+
|
156 |
+
|
157 |
+
## ---------- pre normalize ----------
|
158 |
+
|
159 |
+
|
160 |
@pre_normalize()
|
161 |
def apply_markdown_to_text(text):
|
162 |
if is_markdown(text):
|
|
|
221 |
pattern = re.compile(r"(\[.+?\])|([^[]+)")
|
222 |
|
223 |
def normalize_part(part):
|
224 |
+
sentences = tx.normalize(part) if is_chinese(part) else [part]
|
225 |
dest_text = ""
|
226 |
for sentence in sentences:
|
227 |
sentence = apply_post_normalize(sentence)
|
|
|
279 |
“我们是玫瑰花。”花儿们说道。
|
280 |
“啊!”小王子说……。
|
281 |
""",
|
282 |
+
"""
|
283 |
+
State-of-the-art Machine Learning for PyTorch, TensorFlow, and JAX.
|
284 |
+
|
285 |
+
🤗 Transformers provides APIs and tools to easily download and train state-of-the-art pretrained models. Using pretrained models can reduce your compute costs, carbon footprint, and save you the time and resources required to train a model from scratch. These models support common tasks in different modalities, such as:
|
286 |
+
|
287 |
+
📝 Natural Language Processing: text classification, named entity recognition, question answering, language modeling, summarization, translation, multiple choice, and text generation.
|
288 |
+
🖼️ Computer Vision: image classification, object detection, and segmentation.
|
289 |
+
🗣️ Audio: automatic speech recognition and audio classification.
|
290 |
+
🐙 Multimodal: table question answering, optical character recognition, information extraction from scanned documents, video classification, and visual question answering.
|
291 |
+
""",
|
292 |
]
|
293 |
|
294 |
for i, test_case in enumerate(test_cases):
|
modules/refiner.py
CHANGED
@@ -29,7 +29,7 @@ def refine_text(
|
|
29 |
"temperature": temperature,
|
30 |
"repetition_penalty": repetition_penalty,
|
31 |
"max_new_token": max_new_token,
|
32 |
-
"disable_tqdm": config.
|
33 |
},
|
34 |
do_text_normalization=False,
|
35 |
)
|
|
|
29 |
"temperature": temperature,
|
30 |
"repetition_penalty": repetition_penalty,
|
31 |
"max_new_token": max_new_token,
|
32 |
+
"disable_tqdm": config.runtime_env_vars.off_tqdm,
|
33 |
},
|
34 |
do_text_normalization=False,
|
35 |
)
|
modules/speaker.py
CHANGED
@@ -1,4 +1,5 @@
|
|
1 |
import os
|
|
|
2 |
import torch
|
3 |
|
4 |
from modules import models
|
@@ -53,6 +54,14 @@ class Speaker:
|
|
53 |
|
54 |
return is_update
|
55 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
56 |
|
57 |
# 每个speaker就是一个 emb 文件 .pt
|
58 |
# 管理 speaker 就是管理 ./data/speaker/ 下的所有 speaker
|
@@ -105,13 +114,13 @@ class SpeakerManager:
|
|
105 |
self.refresh_speakers()
|
106 |
return speaker
|
107 |
|
108 |
-
def get_speaker(self, name) -> Speaker
|
109 |
for speaker in self.speakers.values():
|
110 |
if speaker.name == name:
|
111 |
return speaker
|
112 |
return None
|
113 |
|
114 |
-
def get_speaker_by_id(self, id) -> Speaker
|
115 |
for speaker in self.speakers.values():
|
116 |
if str(speaker.id) == str(id):
|
117 |
return speaker
|
|
|
1 |
import os
|
2 |
+
from typing import Union
|
3 |
import torch
|
4 |
|
5 |
from modules import models
|
|
|
54 |
|
55 |
return is_update
|
56 |
|
57 |
+
def __hash__(self):
|
58 |
+
return hash(str(self.id))
|
59 |
+
|
60 |
+
def __eq__(self, other):
|
61 |
+
if not isinstance(other, Speaker):
|
62 |
+
return False
|
63 |
+
return str(self.id) == str(other.id)
|
64 |
+
|
65 |
|
66 |
# 每个speaker就是一个 emb 文件 .pt
|
67 |
# 管理 speaker 就是管理 ./data/speaker/ 下的所有 speaker
|
|
|
114 |
self.refresh_speakers()
|
115 |
return speaker
|
116 |
|
117 |
+
def get_speaker(self, name) -> Union[Speaker, None]:
|
118 |
for speaker in self.speakers.values():
|
119 |
if speaker.name == name:
|
120 |
return speaker
|
121 |
return None
|
122 |
|
123 |
+
def get_speaker_by_id(self, id) -> Union[Speaker, None]:
|
124 |
for speaker in self.speakers.values():
|
125 |
if str(speaker.id) == str(id):
|
126 |
return speaker
|
modules/synthesize_audio.py
CHANGED
@@ -1,4 +1,5 @@
|
|
1 |
import io
|
|
|
2 |
from modules.SentenceSplitter import SentenceSplitter
|
3 |
from modules.SynthesizeSegments import SynthesizeSegments, combine_audio_segments
|
4 |
|
@@ -14,7 +15,7 @@ def synthesize_audio(
|
|
14 |
temperature: float = 0.3,
|
15 |
top_P: float = 0.7,
|
16 |
top_K: float = 20,
|
17 |
-
spk: int
|
18 |
infer_seed: int = -1,
|
19 |
use_decoder: bool = True,
|
20 |
prompt1: str = "",
|
|
|
1 |
import io
|
2 |
+
from typing import Union
|
3 |
from modules.SentenceSplitter import SentenceSplitter
|
4 |
from modules.SynthesizeSegments import SynthesizeSegments, combine_audio_segments
|
5 |
|
|
|
15 |
temperature: float = 0.3,
|
16 |
top_P: float = 0.7,
|
17 |
top_K: float = 20,
|
18 |
+
spk: Union[int, Speaker] = -1,
|
19 |
infer_seed: int = -1,
|
20 |
use_decoder: bool = True,
|
21 |
prompt1: str = "",
|
modules/utils/JsonObject.py
ADDED
@@ -0,0 +1,113 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
class JsonObject:
|
2 |
+
def __init__(self, initial_dict=None):
|
3 |
+
"""
|
4 |
+
Initialize the JsonObject with an optional initial dictionary.
|
5 |
+
|
6 |
+
:param initial_dict: A dictionary to initialize the JsonObject.
|
7 |
+
"""
|
8 |
+
# If no initial dictionary is provided, use an empty dictionary
|
9 |
+
self._dict_obj = initial_dict if initial_dict is not None else {}
|
10 |
+
|
11 |
+
def __getattr__(self, name):
|
12 |
+
"""
|
13 |
+
Get an attribute value. If the attribute does not exist,
|
14 |
+
look it up in the internal dictionary.
|
15 |
+
|
16 |
+
:param name: The name of the attribute.
|
17 |
+
:return: The value of the attribute.
|
18 |
+
:raises AttributeError: If the attribute is not found in the dictionary.
|
19 |
+
"""
|
20 |
+
try:
|
21 |
+
return self._dict_obj[name]
|
22 |
+
except KeyError:
|
23 |
+
return None
|
24 |
+
|
25 |
+
def __setattr__(self, name, value):
|
26 |
+
"""
|
27 |
+
Set an attribute value. If the attribute name is '_dict_obj',
|
28 |
+
set it directly as an instance attribute. Otherwise,
|
29 |
+
store it in the internal dictionary.
|
30 |
+
|
31 |
+
:param name: The name of the attribute.
|
32 |
+
:param value: The value to set.
|
33 |
+
"""
|
34 |
+
if name == "_dict_obj":
|
35 |
+
super().__setattr__(name, value)
|
36 |
+
else:
|
37 |
+
self._dict_obj[name] = value
|
38 |
+
|
39 |
+
def __delattr__(self, name):
|
40 |
+
"""
|
41 |
+
Delete an attribute. If the attribute does not exist,
|
42 |
+
look it up in the internal dictionary and remove it.
|
43 |
+
|
44 |
+
:param name: The name of the attribute.
|
45 |
+
:raises AttributeError: If the attribute is not found in the dictionary.
|
46 |
+
"""
|
47 |
+
try:
|
48 |
+
del self._dict_obj[name]
|
49 |
+
except KeyError:
|
50 |
+
return
|
51 |
+
|
52 |
+
def __getitem__(self, key):
|
53 |
+
"""
|
54 |
+
Get an item value from the internal dictionary.
|
55 |
+
|
56 |
+
:param key: The key of the item.
|
57 |
+
:return: The value of the item.
|
58 |
+
:raises KeyError: If the key is not found in the dictionary.
|
59 |
+
"""
|
60 |
+
if key not in self._dict_obj:
|
61 |
+
return None
|
62 |
+
return self._dict_obj[key]
|
63 |
+
|
64 |
+
def __setitem__(self, key, value):
|
65 |
+
"""
|
66 |
+
Set an item value in the internal dictionary.
|
67 |
+
|
68 |
+
:param key: The key of the item.
|
69 |
+
:param value: The value to set.
|
70 |
+
"""
|
71 |
+
self._dict_obj[key] = value
|
72 |
+
|
73 |
+
def __delitem__(self, key):
|
74 |
+
"""
|
75 |
+
Delete an item from the internal dictionary.
|
76 |
+
|
77 |
+
:param key: The key of the item.
|
78 |
+
:raises KeyError: If the key is not found in the dictionary.
|
79 |
+
"""
|
80 |
+
del self._dict_obj[key]
|
81 |
+
|
82 |
+
def to_dict(self):
|
83 |
+
"""
|
84 |
+
Convert the JsonObject back to a regular dictionary.
|
85 |
+
|
86 |
+
:return: The internal dictionary.
|
87 |
+
"""
|
88 |
+
return self._dict_obj
|
89 |
+
|
90 |
+
def has_key(self, key):
|
91 |
+
"""
|
92 |
+
Check if the key exists in the internal dictionary.
|
93 |
+
|
94 |
+
:param key: The key to check.
|
95 |
+
:return: True if the key exists, False otherwise.
|
96 |
+
"""
|
97 |
+
return key in self._dict_obj
|
98 |
+
|
99 |
+
def keys(self):
|
100 |
+
"""
|
101 |
+
Get a list of keys in the internal dictionary.
|
102 |
+
|
103 |
+
:return: A list of keys.
|
104 |
+
"""
|
105 |
+
return self._dict_obj.keys()
|
106 |
+
|
107 |
+
def values(self):
|
108 |
+
"""
|
109 |
+
Get a list of values in the internal dictionary.
|
110 |
+
|
111 |
+
:return: A list of values.
|
112 |
+
"""
|
113 |
+
return self._dict_obj.values()
|
modules/utils/cache.py
ADDED
@@ -0,0 +1,92 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Callable, TypeVar, Any
|
2 |
+
from typing_extensions import ParamSpec
|
3 |
+
|
4 |
+
from functools import lru_cache, _CacheInfo
|
5 |
+
|
6 |
+
|
7 |
+
def conditional_cache(maxsize: int, condition: Callable):
|
8 |
+
def decorator(func):
|
9 |
+
@lru_cache_ext(maxsize=maxsize)
|
10 |
+
def cached_func(*args, **kwargs):
|
11 |
+
return func(*args, **kwargs)
|
12 |
+
|
13 |
+
def wrapper(*args, **kwargs):
|
14 |
+
if condition(*args, **kwargs):
|
15 |
+
return cached_func(*args, **kwargs)
|
16 |
+
else:
|
17 |
+
return func(*args, **kwargs)
|
18 |
+
|
19 |
+
return wrapper
|
20 |
+
|
21 |
+
return decorator
|
22 |
+
|
23 |
+
|
24 |
+
def hash_list(l: list) -> int:
|
25 |
+
__hash = 0
|
26 |
+
for i, e in enumerate(l):
|
27 |
+
__hash = hash((__hash, i, hash_item(e)))
|
28 |
+
return __hash
|
29 |
+
|
30 |
+
|
31 |
+
def hash_dict(d: dict) -> int:
|
32 |
+
__hash = 0
|
33 |
+
for k, v in d.items():
|
34 |
+
__hash = hash((__hash, k, hash_item(v)))
|
35 |
+
return __hash
|
36 |
+
|
37 |
+
|
38 |
+
def hash_item(e) -> int:
|
39 |
+
if hasattr(e, "__hash__") and callable(e.__hash__):
|
40 |
+
try:
|
41 |
+
return hash(e)
|
42 |
+
except TypeError:
|
43 |
+
pass
|
44 |
+
if isinstance(e, (list, set, tuple)):
|
45 |
+
return hash_list(list(e))
|
46 |
+
elif isinstance(e, (dict)):
|
47 |
+
return hash_dict(e)
|
48 |
+
else:
|
49 |
+
raise TypeError(f"unhashable type: {e.__class__}")
|
50 |
+
|
51 |
+
|
52 |
+
PT = ParamSpec("PT")
|
53 |
+
RT = TypeVar("RT")
|
54 |
+
|
55 |
+
|
56 |
+
def lru_cache_ext(
|
57 |
+
*opts, hashfunc: Callable[..., int] = hash_item, **kwopts
|
58 |
+
) -> Callable[[Callable[PT, RT]], Callable[PT, RT]]:
|
59 |
+
def decorator(func: Callable[PT, RT]) -> Callable[PT, RT]:
|
60 |
+
class _lru_cache_ext_wrapper:
|
61 |
+
args: tuple
|
62 |
+
kwargs: dict[str, Any]
|
63 |
+
|
64 |
+
def cache_info(self) -> _CacheInfo: ...
|
65 |
+
def cache_clear(self) -> None: ...
|
66 |
+
|
67 |
+
@classmethod
|
68 |
+
@lru_cache(*opts, **kwopts)
|
69 |
+
def cached_func(cls, args_hash: int) -> RT:
|
70 |
+
return func(*cls.args, **cls.kwargs)
|
71 |
+
|
72 |
+
@classmethod
|
73 |
+
def __call__(cls, *args: PT.args, **kwargs: PT.kwargs) -> RT:
|
74 |
+
__hash = hashfunc(
|
75 |
+
(
|
76 |
+
id(func),
|
77 |
+
*[hashfunc(a) for a in args],
|
78 |
+
*[(hashfunc(k), hashfunc(v)) for k, v in kwargs.items()],
|
79 |
+
)
|
80 |
+
)
|
81 |
+
|
82 |
+
cls.args = args
|
83 |
+
cls.kwargs = kwargs
|
84 |
+
|
85 |
+
cls.cache_info = cls.cached_func.cache_info
|
86 |
+
cls.cache_clear = cls.cached_func.cache_clear
|
87 |
+
|
88 |
+
return cls.cached_func(__hash)
|
89 |
+
|
90 |
+
return _lru_cache_ext_wrapper()
|
91 |
+
|
92 |
+
return decorator
|
modules/utils/zh_normalization/text_normlization.py
CHANGED
@@ -72,9 +72,9 @@ class TextNormalizer():
|
|
72 |
return sentences
|
73 |
|
74 |
def _post_replace(self, sentence: str) -> str:
|
75 |
-
sentence = sentence.replace('/', '每')
|
76 |
-
sentence = sentence.replace('~', '至')
|
77 |
-
sentence = sentence.replace('~', '至')
|
78 |
sentence = sentence.replace('①', '一')
|
79 |
sentence = sentence.replace('②', '二')
|
80 |
sentence = sentence.replace('③', '三')
|
|
|
72 |
return sentences
|
73 |
|
74 |
def _post_replace(self, sentence: str) -> str:
|
75 |
+
# sentence = sentence.replace('/', '每')
|
76 |
+
# sentence = sentence.replace('~', '至')
|
77 |
+
# sentence = sentence.replace('~', '至')
|
78 |
sentence = sentence.replace('①', '一')
|
79 |
sentence = sentence.replace('②', '二')
|
80 |
sentence = sentence.replace('③', '三')
|
webui.py
CHANGED
@@ -14,9 +14,11 @@ except:
|
|
14 |
import os
|
15 |
import logging
|
16 |
|
17 |
-
|
18 |
|
|
|
19 |
from modules.synthesize_audio import synthesize_audio
|
|
|
20 |
|
21 |
logging.basicConfig(
|
22 |
level=os.getenv("LOG_LEVEL", "INFO"),
|
@@ -25,20 +27,17 @@ logging.basicConfig(
|
|
25 |
|
26 |
|
27 |
import gradio as gr
|
28 |
-
import io
|
29 |
-
import re
|
30 |
-
import numpy as np
|
31 |
|
32 |
import torch
|
33 |
|
34 |
from modules.ssml import parse_ssml
|
35 |
from modules.SynthesizeSegments import SynthesizeSegments, combine_audio_segments
|
36 |
-
from modules.generate_audio import generate_audio, generate_audio_batch
|
37 |
|
38 |
from modules.speaker import speaker_mgr
|
39 |
from modules.data import styles_mgr
|
40 |
|
41 |
from modules.api.utils import calc_spk_style
|
|
|
42 |
|
43 |
from modules.normalization import text_normalize
|
44 |
from modules import refiner, config
|
@@ -147,7 +146,7 @@ def tts_generate(
|
|
147 |
prompt1 = prompt1 or params.get("prompt1", "")
|
148 |
prompt2 = prompt2 or params.get("prompt2", "")
|
149 |
|
150 |
-
infer_seed = clip(infer_seed, -1, 2**32 - 1)
|
151 |
infer_seed = int(infer_seed)
|
152 |
|
153 |
if not disable_normalize:
|
@@ -869,31 +868,59 @@ if __name__ == "__main__":
|
|
869 |
type=int,
|
870 |
help="Max batch size for TTS",
|
871 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
872 |
|
873 |
args = parser.parse_args()
|
874 |
|
875 |
-
|
876 |
-
|
877 |
-
|
878 |
-
|
879 |
-
|
880 |
-
|
881 |
-
|
882 |
-
|
883 |
-
|
884 |
-
|
885 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
886 |
|
887 |
demo = create_interface()
|
888 |
|
889 |
if auth:
|
890 |
auth = tuple(auth.split(":"))
|
891 |
|
892 |
-
|
893 |
-
|
894 |
-
|
895 |
-
if off_tqdm:
|
896 |
-
config.disable_tqdm = True
|
897 |
|
898 |
demo.queue().launch(
|
899 |
server_name=server_name,
|
|
|
14 |
import os
|
15 |
import logging
|
16 |
|
17 |
+
import numpy as np
|
18 |
|
19 |
+
from modules.devices import devices
|
20 |
from modules.synthesize_audio import synthesize_audio
|
21 |
+
from modules.utils.cache import conditional_cache
|
22 |
|
23 |
logging.basicConfig(
|
24 |
level=os.getenv("LOG_LEVEL", "INFO"),
|
|
|
27 |
|
28 |
|
29 |
import gradio as gr
|
|
|
|
|
|
|
30 |
|
31 |
import torch
|
32 |
|
33 |
from modules.ssml import parse_ssml
|
34 |
from modules.SynthesizeSegments import SynthesizeSegments, combine_audio_segments
|
|
|
35 |
|
36 |
from modules.speaker import speaker_mgr
|
37 |
from modules.data import styles_mgr
|
38 |
|
39 |
from modules.api.utils import calc_spk_style
|
40 |
+
import modules.generate_audio as generate
|
41 |
|
42 |
from modules.normalization import text_normalize
|
43 |
from modules import refiner, config
|
|
|
146 |
prompt1 = prompt1 or params.get("prompt1", "")
|
147 |
prompt2 = prompt2 or params.get("prompt2", "")
|
148 |
|
149 |
+
infer_seed = np.clip(infer_seed, -1, 2**32 - 1)
|
150 |
infer_seed = int(infer_seed)
|
151 |
|
152 |
if not disable_normalize:
|
|
|
868 |
type=int,
|
869 |
help="Max batch size for TTS",
|
870 |
)
|
871 |
+
parser.add_argument(
|
872 |
+
"--lru_size",
|
873 |
+
type=int,
|
874 |
+
default=64,
|
875 |
+
help="Set the size of the request cache pool, set it to 0 will disable lru_cache",
|
876 |
+
)
|
877 |
+
parser.add_argument(
|
878 |
+
"--device_id",
|
879 |
+
type=str,
|
880 |
+
help="Select the default CUDA device to use (export CUDA_VISIBLE_DEVICES=0,1,etc might be needed before)",
|
881 |
+
default=None,
|
882 |
+
)
|
883 |
+
parser.add_argument(
|
884 |
+
"--use_cpu",
|
885 |
+
nargs="+",
|
886 |
+
help="use CPU as torch device for specified modules",
|
887 |
+
default=[],
|
888 |
+
type=str.lower,
|
889 |
+
)
|
890 |
+
parser.add_argument("--compile", action="store_true", help="Enable model compile")
|
891 |
|
892 |
args = parser.parse_args()
|
893 |
|
894 |
+
def get_and_update_env(*args):
|
895 |
+
val = env.get_env_or_arg(*args)
|
896 |
+
key = args[1]
|
897 |
+
config.runtime_env_vars[key] = val
|
898 |
+
return val
|
899 |
+
|
900 |
+
server_name = get_and_update_env(args, "server_name", "0.0.0.0", str)
|
901 |
+
server_port = get_and_update_env(args, "server_port", 7860, int)
|
902 |
+
share = get_and_update_env(args, "share", False, bool)
|
903 |
+
debug = get_and_update_env(args, "debug", False, bool)
|
904 |
+
auth = get_and_update_env(args, "auth", None, str)
|
905 |
+
half = get_and_update_env(args, "half", False, bool)
|
906 |
+
off_tqdm = get_and_update_env(args, "off_tqdm", False, bool)
|
907 |
+
lru_size = get_and_update_env(args, "lru_size", 64, int)
|
908 |
+
device_id = get_and_update_env(args, "device_id", None, str)
|
909 |
+
use_cpu = get_and_update_env(args, "use_cpu", [], list)
|
910 |
+
compile = get_and_update_env(args, "compile", False, bool)
|
911 |
+
|
912 |
+
webui_config["tts_max"] = get_and_update_env(args, "tts_max_len", 1000, int)
|
913 |
+
webui_config["ssml_max"] = get_and_update_env(args, "ssml_max_len", 5000, int)
|
914 |
+
webui_config["max_batch_size"] = get_and_update_env(args, "max_batch_size", 8, int)
|
915 |
|
916 |
demo = create_interface()
|
917 |
|
918 |
if auth:
|
919 |
auth = tuple(auth.split(":"))
|
920 |
|
921 |
+
generate.setup_lru_cache()
|
922 |
+
devices.reset_device()
|
923 |
+
devices.first_time_calculation()
|
|
|
|
|
924 |
|
925 |
demo.queue().launch(
|
926 |
server_name=server_name,
|