Upload 34 files
- .gitattributes +10 -11
- app.py +111 -53
- models/Azuma/Azuma-e10.ckpt +3 -0
- models/Azuma/Azuma_e35_s1435.pth +3 -0
- module/data_utils.py +45 -92
- requirements.txt +12 -6
- text/chinese.py +1 -1
- text/tone_sandhi.py +1 -1
- utils.py +3 -3
.gitattributes
CHANGED
@@ -1,35 +1,34 @@
 *.7z filter=lfs diff=lfs merge=lfs -text
 *.arrow filter=lfs diff=lfs merge=lfs -text
 *.bin filter=lfs diff=lfs merge=lfs -text
+*.bin.* filter=lfs diff=lfs merge=lfs -text
 *.bz2 filter=lfs diff=lfs merge=lfs -text
-*.ckpt filter=lfs diff=lfs merge=lfs -text
 *.ftz filter=lfs diff=lfs merge=lfs -text
 *.gz filter=lfs diff=lfs merge=lfs -text
 *.h5 filter=lfs diff=lfs merge=lfs -text
 *.joblib filter=lfs diff=lfs merge=lfs -text
 *.lfs.* filter=lfs diff=lfs merge=lfs -text
-*.mlmodel filter=lfs diff=lfs merge=lfs -text
 *.model filter=lfs diff=lfs merge=lfs -text
 *.msgpack filter=lfs diff=lfs merge=lfs -text
-*.npy filter=lfs diff=lfs merge=lfs -text
-*.npz filter=lfs diff=lfs merge=lfs -text
 *.onnx filter=lfs diff=lfs merge=lfs -text
 *.ot filter=lfs diff=lfs merge=lfs -text
 *.parquet filter=lfs diff=lfs merge=lfs -text
 *.pb filter=lfs diff=lfs merge=lfs -text
-*.pickle filter=lfs diff=lfs merge=lfs -text
-*.pkl filter=lfs diff=lfs merge=lfs -text
 *.pt filter=lfs diff=lfs merge=lfs -text
 *.pth filter=lfs diff=lfs merge=lfs -text
 *.rar filter=lfs diff=lfs merge=lfs -text
-*.safetensors filter=lfs diff=lfs merge=lfs -text
 saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.tar.* filter=lfs diff=lfs merge=lfs -text
-*.tar filter=lfs diff=lfs merge=lfs -text
 *.tflite filter=lfs diff=lfs merge=lfs -text
 *.tgz filter=lfs diff=lfs merge=lfs -text
-*.wasm filter=lfs diff=lfs merge=lfs -text
 *.xz filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
-*.
-…
+*.zstandard filter=lfs diff=lfs merge=lfs -text
+*.tfevents* filter=lfs diff=lfs merge=lfs -text
+*.db* filter=lfs diff=lfs merge=lfs -text
+*.ark* filter=lfs diff=lfs merge=lfs -text
+**/*ckpt*data* filter=lfs diff=lfs merge=lfs -text
+**/*ckpt*.meta filter=lfs diff=lfs merge=lfs -text
+**/*ckpt*.index filter=lfs diff=lfs merge=lfs -text
+*.safetensors filter=lfs diff=lfs merge=lfs -text
+*.ckpt filter=lfs diff=lfs merge=lfs -text
+text/cmudict_cache.pickle filter=lfs diff=lfs merge=lfs -text
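All of the entries above are standard Git LFS track rules; the commit adds patterns for the checkpoint formats this Space ships (*.ckpt, *.safetensors, the **/*ckpt* shard globs, and the cached cmudict pickle). As a rough illustration only, not part of the commit, a filename can be checked against such patterns with the Python standard library; the basename fnmatch here is a simplification of Git's real attribute matching:

from fnmatch import fnmatch

# A few of the patterns tracked in the .gitattributes above.
LFS_PATTERNS = ["*.ckpt", "*.pth", "*.safetensors", "*.bin.*", "*ckpt*data*"]

def tracked_by_lfs(path: str) -> bool:
    name = path.rsplit("/", 1)[-1]          # match on the basename only (simplification)
    return any(fnmatch(name, pattern) for pattern in LFS_PATTERNS)

print(tracked_by_lfs("models/Azuma/Azuma-e10.ckpt"))       # True
print(tracked_by_lfs("models/Azuma/Azuma_e35_s1435.pth"))  # True
print(tracked_by_lfs("app.py"))                            # False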
app.py
CHANGED
@@ -1,9 +1,10 @@
-import os
+import os,re
+import gradio as gr
 
 gpt_path = os.environ.get(
-    "gpt_path", "models/
+    "gpt_path", "models/Azuma/Azuma-e10.ckpt"
 )
-sovits_path = os.environ.get("sovits_path", "models/
+sovits_path = os.environ.get("sovits_path", "models/Azuma/Azuma_e35_s1435.pth")
 cnhubert_base_path = os.environ.get(
     "cnhubert_base_path", "pretrained_models/chinese-hubert-base"
 )
@@ -21,6 +22,10 @@ import numpy as np
 import librosa,torch
 from feature_extractor import cnhubert
 cnhubert.cnhubert_base_path=cnhubert_base_path
+import ssl
+ssl._create_default_https_context = ssl._create_unverified_context
+import nltk
+nltk.download('cmudict')
 
 from module.models import SynthesizerTrn
 from AR.models.t2s_lightning_module import Text2SemanticLightningModule
@@ -106,29 +111,42 @@ if is_half == True:
 else:
     ssl_model = ssl_model.to(device)
 
-…
+def change_sovits_weights(sovits_path):
+    global vq_model,hps
+    dict_s2=torch.load(sovits_path,map_location="cpu")
+    hps=dict_s2["config"]
+    hps = DictToAttrRecursive(hps)
+    hps.model.semantic_frame_rate = "25hz"
+    vq_model = SynthesizerTrn(
+        hps.data.filter_length // 2 + 1,
+        hps.train.segment_size // hps.data.hop_length,
+        n_speakers=hps.data.n_speakers,
+        **hps.model
+    )
+    del vq_model.enc_q
+    if is_half == True:
+        vq_model = vq_model.half().to(device)
+    else:
+        vq_model = vq_model.to(device)
+    vq_model.eval()
+    print(vq_model.load_state_dict(dict_s2["weight"], strict=False))
+change_sovits_weights(sovits_path)
+
+def change_gpt_weights(gpt_path):
+    global hz,max_sec,t2s_model,config
+    hz = 50
+    dict_s1 = torch.load(gpt_path, map_location="cpu")
+    config = dict_s1["config"]
+    max_sec = config["data"]["max_sec"]
+    t2s_model = Text2SemanticLightningModule(config, "****", is_train=False)
+    t2s_model.load_state_dict(dict_s1["weight"])
+    if is_half == True:
+        t2s_model = t2s_model.half()
+    t2s_model = t2s_model.to(device)
+    t2s_model.eval()
+    total = sum([param.nelement() for param in t2s_model.parameters()])
+    print("Number of parameter: %.2fM" % (total / 1e6))
+change_gpt_weights(gpt_path)
 
 
 def get_spepc(hps, filename):
@@ -150,17 +168,29 @@ def get_spepc(hps, filename):
 dict_language = {"中文": "zh", "英文": "en", "日文": "ja"}
 
 
-def get_tts_wav(ref_wav_path, prompt_text, prompt_language, text, text_language):
+def get_tts_wav(selected_text, prompt_text, prompt_language, text, text_language):
+    ref_wav_path = text_to_audio_mappings.get(selected_text, "")
+    if not ref_wav_path:
+        print("Audio file not found for the selected text.")
+        return
     t0 = ttime()
     prompt_text = prompt_text.strip("\n")
     prompt_language, text = prompt_language, text.strip("\n")
+    zero_wav = np.zeros(
+        int(hps.data.sampling_rate * 0.3),
+        dtype=np.float16 if is_half == True else np.float32,
+    )
     with torch.no_grad():
-        wav16k, sr = librosa.load(ref_wav_path, sr=16000)
+        wav16k, sr = librosa.load(ref_wav_path, sr=16000)
         wav16k = torch.from_numpy(wav16k)
+        zero_wav_torch = torch.from_numpy(zero_wav)
         if is_half == True:
            wav16k = wav16k.half().to(device)
+            zero_wav_torch = zero_wav_torch.half().to(device)
         else:
            wav16k = wav16k.to(device)
+            zero_wav_torch = zero_wav_torch.to(device)
+        wav16k=torch.cat([wav16k,zero_wav_torch])
         ssl_content = ssl_model.model(wav16k.unsqueeze(0))[
             "last_hidden_state"
         ].transpose(
@@ -175,10 +205,7 @@ def get_tts_wav(ref_wav_path, prompt_text, prompt_language, text, text_language)
     phones1 = cleaned_text_to_sequence(phones1)
     texts = text.split("\n")
     audio_opt = []
-    zero_wav = np.zeros(
-        int(hps.data.sampling_rate * 0.3),
-        dtype=np.float16 if is_half == True else np.float32,
-    )
+
     for text in texts:
         # 解决输入目标文本的空行导致报错的问题
         if (len(text.strip()) == 0):
@@ -319,28 +346,59 @@ def cut3(inp):
     inp = inp.strip("\n")
     return "\n".join(["%s。" % item for item in inp.strip("。").split("。")])
 
+def scan_audio_files(folder_path):
+    """ 扫描指定文件夹获取音频文件列表 """
+    return [f for f in os.listdir(folder_path) if f.endswith('.wav')]
+
+def load_audio_text_mappings(folder_path, list_file_name):
+    text_to_audio_mappings = {}
+    audio_to_text_mappings = {}
+    with open(os.path.join(folder_path, list_file_name), 'r', encoding='utf-8') as file:
+        for line in file:
+            parts = line.strip().split('|')
+            if len(parts) >= 4:
+                audio_file_name = parts[0]
+                text = parts[3]
+                audio_file_path = os.path.join(folder_path, audio_file_name)
+                text_to_audio_mappings[text] = audio_file_path
+                audio_to_text_mappings[audio_file_path] = text
+    return text_to_audio_mappings, audio_to_text_mappings
+
+audio_folder_path = 'audio/Azuma'
+text_to_audio_mappings, audio_to_text_mappings = load_audio_text_mappings(audio_folder_path, 'Azuma.list')
 
 with gr.Blocks(title="GPT-SoVITS WebUI") as app:
     gr.Markdown(value="""
-# <center>【AI
+# <center>【AI东雪莲】在线语音生成(GPT-SoVITS)\n
 
 ### <center>模型作者:Xz乔希 https://space.bilibili.com/5859321\n
 ### <center>数据集下载:https://huggingface.co/datasets/XzJosh/audiodataset\n
-### <center
+### <center>声音归属:東雪蓮Official https://space.bilibili.com/1437582453\n
 ### <center>GPT-SoVITS项目:https://github.com/RVC-Boss/GPT-SoVITS\n
 ### <center>使用本模型请严格遵守法律法规!发布二创作品请标注本项目作者及链接、作品使用GPT-SoVITS AI生成!\n
-### <center
+### <center>⚠️在线端不稳定且生成速度较慢,强烈建议下载模型本地推理!\n
     """)
     # with gr.Tabs():
     #     with gr.TabItem(i18n("伴奏人声分离&去混响&去回声")):
     with gr.Group():
-        gr.Markdown(value="
+        gr.Markdown(value="*参考音频选择(必选)")
         with gr.Row():
-…
+            audio_select = gr.Dropdown(label="选择参考音频(不建议选较长的)", choices=list(text_to_audio_mappings.keys()))
+            ref_audio = gr.Audio(label="参考音频试听")
+            ref_text = gr.Textbox(label="参考音频文本")
+
+        # 定义更新参考文本的函数
+        def update_ref_text_and_audio(selected_text):
+            audio_path = text_to_audio_mappings.get(selected_text, "")
+            return selected_text, audio_path
+
+        # 绑定下拉菜单的变化到更新函数
+        audio_select.change(update_ref_text_and_audio, [audio_select], [ref_text, ref_audio])
+
+        # 其他 Gradio 组件和功能
+        prompt_language = gr.Dropdown(
+            label="参考音频语种", choices=["中文", "英文", "日文"], value="中文"
+        )
         gr.Markdown(value="*请填写需要合成的目标文本")
         with gr.Row():
             text = gr.Textbox(label="需要合成的文本", value="")
@@ -351,21 +409,21 @@ with gr.Blocks(title="GPT-SoVITS WebUI") as app:
         output = gr.Audio(label="输出的语音")
         inference_button.click(
             get_tts_wav,
-            [
+            [audio_select, ref_text, prompt_language, text, text_language],
             [output],
         )
 
-…
+
+    gr.Markdown(value="文本切分工具。太长的文本合成出来效果不一定好,所以太长建议先切。合成会根据文本的换行分开合成再拼起来。")
+    with gr.Row():
+        text_inp = gr.Textbox(label="需要合成的切分前文本", value="")
+        button1 = gr.Button("凑五句一切", variant="primary")
+        button2 = gr.Button("凑50字一切", variant="primary")
+        button3 = gr.Button("按中文句号。切", variant="primary")
+        text_opt = gr.Textbox(label="切分后文本", value="")
+        button1.click(cut1, [text_inp], [text_opt])
+        button2.click(cut2, [text_inp], [text_opt])
+        button3.click(cut3, [text_inp], [text_opt])
 
 app.queue(max_size=10)
 app.launch(inbrowser=True)
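The main behavioural change in app.py is that the reference audio is no longer supplied by the user: load_audio_text_mappings reads audio/Azuma/Azuma.list (pipe-separated lines whose first field is a wav name and fourth field its transcript), the dropdown lists the transcripts, and get_tts_wav resolves the selection back to a wav path. A minimal, self-contained sketch of that dropdown wiring, with placeholder entries standing in for the real Azuma.list data:

import gradio as gr

# Placeholder mapping; in app.py this comes from load_audio_text_mappings().
text_to_audio = {
    "示例文本一": "audio/Azuma/sample1.wav",   # hypothetical paths, not real files
    "示例文本二": "audio/Azuma/sample2.wav",
}

def update_ref(selected_text):
    # Same idea as update_ref_text_and_audio above: echo the transcript and
    # resolve it to the wav path the Audio component should preview.
    return selected_text, text_to_audio.get(selected_text, "")

with gr.Blocks() as demo:
    audio_select = gr.Dropdown(choices=list(text_to_audio.keys()), label="参考音频")
    ref_text = gr.Textbox(label="参考音频文本")
    ref_audio = gr.Audio(label="参考音频试听")
    audio_select.change(update_ref, [audio_select], [ref_text, ref_audio])

if __name__ == "__main__":
    demo.launch()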
models/Azuma/Azuma-e10.ckpt
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:a34b18606751974abdf9178ad76fcda77736693424eb5189384506da80a7b23e
+size 155084485
models/Azuma/Azuma_e35_s1435.pth
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:f923e268a9f7d1b410cac5fb861775c39b4973dbd309381829c36965cfd64ef2
+size 84930071
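Both model files are committed as Git LFS pointer files: the three lines shown (version, oid, size) are the entire checked-in content, and the real weights are fetched by LFS when the repo is downloaded. A small sketch, outside this commit, that reads such a pointer and reports the expected payload size:

def read_lfs_pointer(path):
    # Parses the "key value" lines of a Git LFS pointer file
    # (version / oid sha256:<hash> / size <bytes>).
    info = {}
    with open(path, "r", encoding="utf-8") as f:
        for line in f:
            key, _, value = line.strip().partition(" ")
            if key:
                info[key] = value
    return info

# Hypothetical usage on the checkpoint added in this commit:
# ptr = read_lfs_pointer("models/Azuma/Azuma-e10.ckpt")
# print(ptr["oid"], int(ptr["size"]) / 1e6, "MB")   # roughly 155 MB per the pointer above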
module/data_utils.py
CHANGED
@@ -1,6 +1,8 @@
-import time
+import time
+import logging
 import os
-import random
+import random
+import traceback
 import numpy as np
 import torch
 import torch.utils.data
@@ -12,15 +14,12 @@ from text import cleaned_text_to_sequence
 from utils import load_wav_to_torch, load_filepaths_and_text
 import torch.nn.functional as F
 from functools import lru_cache
-import torch
 import requests
 from scipy.io import wavfile
 from io import BytesIO
-…
-# from config import exp_dir
 from my_utils import load_audio
 
-…
+# ZeroDivisionError fixed by Tybost (https://github.com/RVC-Boss/GPT-SoVITS/issues/79)
 class TextAudioSpeakerLoader(torch.utils.data.Dataset):
     """
     1) loads audio, speaker_id, text pairs
@@ -44,7 +43,7 @@ class TextAudioSpeakerLoader(torch.utils.data.Dataset):
 
         for line in lines:
             tmp = line.split("\t")
-            if len(tmp) != 4:
+            if (len(tmp) != 4):
                 continue
             self.phoneme_data[tmp[0]] = [tmp[1]]
 
@@ -52,7 +51,7 @@ class TextAudioSpeakerLoader(torch.utils.data.Dataset):
         tmp = self.audiopaths_sid_text
         leng = len(tmp)
         min_num = 100
-        if leng < min_num:
+        if (leng < min_num):
             self.audiopaths_sid_text = []
             for _ in range(max(2, int(min_num / leng))):
                 self.audiopaths_sid_text += tmp
@@ -77,20 +76,28 @@ class TextAudioSpeakerLoader(torch.utils.data.Dataset):
         for audiopath in tqdm(self.audiopaths_sid_text):
             try:
                 phoneme = self.phoneme_data[audiopath][0]
-                phoneme = phoneme.split(
+                phoneme = phoneme.split(' ')
                 phoneme_ids = cleaned_text_to_sequence(phoneme)
             except Exception:
                 print(f"{audiopath} not in self.phoneme_data !")
                 skipped_phone += 1
                 continue
+
             size = os.path.getsize("%s/%s" % (self.path5, audiopath))
             duration = size / self.sampling_rate / 2
+
+            if duration == 0:
+                print(f"Zero duration for {audiopath}, skipping...")
+                skipped_dur += 1
+                continue
+
             if 54 > duration > 0.6 or self.val:
                 audiopaths_sid_text_new.append([audiopath, phoneme_ids])
                 lengths.append(size // (2 * self.hop_length))
             else:
                 skipped_dur += 1
                 continue
+
         print("skipped_phone: ", skipped_phone, ", skipped_dur: ", skipped_dur)
         print("total left: ", len(audiopaths_sid_text_new))
         assert len(audiopaths_sid_text_new) > 1  # 至少能凑够batch size,这里todo
@@ -103,10 +110,8 @@ class TextAudioSpeakerLoader(torch.utils.data.Dataset):
         try:
             spec, wav = self.get_audio("%s/%s" % (self.path5, audiopath))
             with torch.no_grad():
-                ssl = torch.load(
-…
-                )
-                if ssl.shape[-1] != spec.shape[-1]:
+                ssl = torch.load("%s/%s.pt" % (self.path4, audiopath), map_location="cpu")
+                if (ssl.shape[-1] != spec.shape[-1]):
                     typee = ssl.dtype
                     ssl = F.pad(ssl.float(), (0, 1), mode="replicate").to(typee)
                     ssl.requires_grad = False
@@ -117,25 +122,15 @@ class TextAudioSpeakerLoader(torch.utils.data.Dataset):
             ssl = torch.zeros(1, 768, 100)
             text = text[-1:]
             print("load audio or ssl error!!!!!!", audiopath)
-        # print(ssl.requires_grad,spec.requires_grad,wav.requires_grad,text.requires_grad)
         return (ssl, spec, wav, text)
 
     def get_audio(self, filename):
-        audio_array = load_audio(
-            filename, self.sampling_rate
-        )  # load_audio的方法是已经归一化到-1~1之间的,不用再/32768
-        # print(filename,audio_array.max(),audio_array.min(),audio_array.mean())
+        audio_array = load_audio(filename, self.sampling_rate)  # load_audio的方法是已经归一化到-1~1之间的,不用再/32768
         audio = torch.FloatTensor(audio_array)  # /32768
         audio_norm = audio
         audio_norm = audio_norm.unsqueeze(0)
-        spec = spectrogram_torch(
-…
-            self.filter_length,
-            self.sampling_rate,
-            self.hop_length,
-            self.win_length,
-            center=False,
-        )
+        spec = spectrogram_torch(audio_norm, self.filter_length, self.sampling_rate, self.hop_length, self.win_length,
+                                 center=False)
         spec = torch.squeeze(spec, 0)
         return spec, audio_norm
 
@@ -152,14 +147,11 @@ class TextAudioSpeakerLoader(torch.utils.data.Dataset):
 
     def random_slice(self, ssl, wav, mel):
         assert abs(ssl.shape[-1] - wav.shape[-1] // self.hop_length) < 3, (
-…
-            ssl.shape,
-            wav.shape,
-        )
+            "first", ssl.shape, wav.shape)
 
         len_mel = mel.shape[1]
         if self.val:
-            reference_mel = mel[:, :
+            reference_mel = mel[:, :len_mel // 3]
             return reference_mel, ssl, wav, mel
         dir = random.randint(0, 1)
         sep_point = random.randint(int(len_mel // 3), int(len_mel // 3 * 2))
@@ -167,29 +159,22 @@ class TextAudioSpeakerLoader(torch.utils.data.Dataset):
         if dir == 0:
             reference_mel = mel[:, :sep_point]
             ssl = ssl[:, :, sep_point:]
-            wav2 = wav[:, sep_point * self.hop_length
+            wav2 = wav[:, sep_point * self.hop_length:]
             mel = mel[:, sep_point:]
         else:
             reference_mel = mel[:, sep_point:]
             ssl = ssl[:, :, :sep_point]
-            wav2 = wav[:, :
+            wav2 = wav[:, :sep_point * self.hop_length]
             mel = mel[:, :sep_point]
 
         assert abs(ssl.shape[-1] - wav2.shape[-1] // self.hop_length) < 3, (
-…
-            wav.shape,
-            wav2.shape,
-            mel.shape,
-            sep_point,
-            self.hop_length,
-            sep_point * self.hop_length,
-            dir,
-        )
+            ssl.shape, wav.shape, wav2.shape, mel.shape, sep_point, self.hop_length, sep_point * self.hop_length, dir)
         return reference_mel, ssl, wav2, mel
 
 
-class TextAudioSpeakerCollate:
-    """Zero-pads model inputs and targets
+class TextAudioSpeakerCollate():
+    """ Zero-pads model inputs and targets
+    """
 
     def __init__(self, return_ids=False):
         self.return_ids = return_ids
@@ -202,8 +187,8 @@ class TextAudioSpeakerCollate:
         """
         # Right zero-pad all one-hot text sequences to max input length
         _, ids_sorted_decreasing = torch.sort(
-            torch.LongTensor([x[1].size(1) for x in batch]),
-…
+            torch.LongTensor([x[1].size(1) for x in batch]),
+            dim=0, descending=True)
 
         max_ssl_len = max([x[0].size(2) for x in batch])
         max_ssl_len = int(2 * ((max_ssl_len // 2) + 1))
@@ -231,31 +216,22 @@ class TextAudioSpeakerCollate:
             row = batch[ids_sorted_decreasing[i]]
 
             ssl = row[0]
-            ssl_padded[i, :, :
+            ssl_padded[i, :, :ssl.size(2)] = ssl[0, :, :]
             ssl_lengths[i] = ssl.size(2)
 
             spec = row[1]
-            spec_padded[i, :, :
+            spec_padded[i, :, :spec.size(1)] = spec
             spec_lengths[i] = spec.size(1)
 
             wav = row[2]
-            wav_padded[i, :, :
+            wav_padded[i, :, :wav.size(1)] = wav
             wav_lengths[i] = wav.size(1)
 
             text = row[3]
-            text_padded[i, :
+            text_padded[i, :text.size(0)] = text
             text_lengths[i] = text.size(0)
 
-        return
-            ssl_padded,
-            ssl_lengths,
-            spec_padded,
-            spec_lengths,
-            wav_padded,
-            wav_lengths,
-            text_padded,
-            text_lengths,
-        )
+        return ssl_padded, ssl_lengths, spec_padded, spec_lengths, wav_padded, wav_lengths, text_padded, text_lengths
 
 
 class DistributedBucketSampler(torch.utils.data.distributed.DistributedSampler):
@@ -268,18 +244,9 @@ class DistributedBucketSampler(torch.utils.data.distributed.DistributedSampler):
     Ex) boundaries = [b1, b2, b3] -> any x s.t. length(x) <= b1 or length(x) > b3 are discarded.
     """
 
-    def __init__(
-        self,
-        dataset,
-        batch_size,
-        boundaries,
-        num_replicas=None,
-        rank=None,
-        shuffle=True,
-    ):
+    def __init__(self, dataset, batch_size, boundaries, num_replicas=None, rank=None, shuffle=True):
         super().__init__(dataset, num_replicas=num_replicas, rank=rank, shuffle=shuffle)
         self.lengths = dataset.lengths
-        # print(233333333333333,self.lengths,dir(dataset))
         self.batch_size = batch_size
         self.boundaries = boundaries
 
@@ -295,24 +262,22 @@ class DistributedBucketSampler(torch.utils.data.distributed.DistributedSampler):
             if idx_bucket != -1:
                 buckets[idx_bucket].append(i)
 
-…
+        i = len(buckets) - 1
+        while i >= 0:
             if len(buckets[i]) == 0:
                 buckets.pop(i)
                 self.boundaries.pop(i + 1)
+            i -= 1
 
         num_samples_per_bucket = []
         for i in range(len(buckets)):
             len_bucket = len(buckets[i])
             total_batch_size = self.num_replicas * self.batch_size
-            rem = (
-                total_batch_size - (len_bucket % total_batch_size)
-            ) % total_batch_size
+            rem = (total_batch_size - (len_bucket % total_batch_size)) % total_batch_size
             num_samples_per_bucket.append(len_bucket + rem)
         return buckets, num_samples_per_bucket
 
     def __iter__(self):
-        # deterministically shuffle based on epoch
         g = torch.Generator()
         g.manual_seed(self.epoch)
 
@@ -331,25 +296,13 @@ class DistributedBucketSampler(torch.utils.data.distributed.DistributedSampler):
             ids_bucket = indices[i]
             num_samples_bucket = self.num_samples_per_bucket[i]
 
-            # add extra samples to make it evenly divisible
             rem = num_samples_bucket - len_bucket
-            ids_bucket = (
-                ids_bucket
-                + ids_bucket * (rem // len_bucket)
-                + ids_bucket[: (rem % len_bucket)]
-            )
+            ids_bucket = ids_bucket + ids_bucket * (rem // len_bucket) + ids_bucket[:(rem % len_bucket)]
 
-…
-            ids_bucket = ids_bucket[self.rank :: self.num_replicas]
+            ids_bucket = ids_bucket[self.rank::self.num_replicas]
 
-            # batching
             for j in range(len(ids_bucket) // self.batch_size):
-                batch = [
-                    bucket[idx]
-                    for idx in ids_bucket[
-                        j * self.batch_size : (j + 1) * self.batch_size
-                    ]
-                ]
+                batch = [bucket[idx] for idx in ids_bucket[j * self.batch_size:(j + 1) * self.batch_size]]
                 batches.append(batch)
 
         if self.shuffle:
@@ -376,4 +329,4 @@ class DistributedBucketSampler(torch.utils.data.distributed.DistributedSampler):
             return -1
 
     def __len__(self):
-        return self.num_samples // self.batch_size
+        return self.num_samples // self.batch_size
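Most of the data_utils.py change is reformatting; the functional part is the zero-duration guard credited to GPT-SoVITS issue #79. Duration is estimated straight from the file size, assuming 2 bytes per sample at the dataset's sampling rate, so an empty or truncated file is now skipped instead of feeding a zero into later divisions. A small sketch of that estimate and filter (the 32 kHz rate is an assumed example; the real value comes from hps.data):

import os

SAMPLING_RATE = 32000  # assumed for illustration; the loader takes it from the dataset config

def estimate_duration(path, sampling_rate=SAMPLING_RATE):
    # Same arithmetic as the loader: bytes / (sample_rate * 2 bytes per 16-bit sample).
    size = os.path.getsize(path)
    return size / sampling_rate / 2

def keep_sample(path, is_val=False):
    duration = estimate_duration(path)
    if duration == 0:
        return False            # the new guard: skip empty or truncated files
    return 54 > duration > 0.6 or is_val

# Worked example: a 320,000-byte file at 32 kHz / 16-bit is
# 320000 / 32000 / 2 = 5.0 seconds, which passes the 0.6-54 s window.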
requirements.txt
CHANGED
@@ -1,18 +1,24 @@
 numpy
 scipy
-…
+tensorboard
 librosa==0.9.2
 numba==0.56.4
-pytorch-lightning
+pytorch-lightning==2.1
+torchmetrics==0.10.1
 gradio==3.47.1
 ffmpeg-python
-…
+onnxruntime
+tqdm
+funasr
 cn2an
 pypinyin
-pyopenjtalk
+pyopenjtalk
 g2p_en
 torchaudio
+modelscope
 sentencepiece
 transformers
-…
+chardet
+PyYAML
+psutil
+jieba_fast
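The notable pins are pytorch-lightning==2.1 with torchmetrics==0.10.1, plus the new jieba_fast dependency used by the text/ changes below. A quick, illustrative way (not part of the repo) to confirm an environment matches the pinned entries:

from importlib.metadata import version, PackageNotFoundError

# Only the pinned entries from requirements.txt above.
PINS = {
    "librosa": "0.9.2",
    "numba": "0.56.4",
    "pytorch-lightning": "2.1",
    "torchmetrics": "0.10.1",
    "gradio": "3.47.1",
}

for name, wanted in PINS.items():
    try:
        got = version(name)
    except PackageNotFoundError:
        got = "not installed"
    # Treat e.g. 2.1.0 as satisfying the ==2.1 pin.
    ok = got == wanted or got.startswith(wanted + ".")
    print(f"{name}: pinned {wanted}, installed {got} ({'ok' if ok else 'check'})")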
text/chinese.py
CHANGED
@@ -18,7 +18,7 @@ pinyin_to_symbol_map = {
     for line in open(os.path.join(current_file_path, "opencpop-strict.txt")).readlines()
 }
 
-import
+import jieba_fast.posseg as psg
 
 
 rep_map = {
text/tone_sandhi.py
CHANGED
@@ -14,7 +14,7 @@
 from typing import List
 from typing import Tuple
 
-import jieba
+import jieba_fast as jieba
 from pypinyin import lazy_pinyin
 from pypinyin import Style
 
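Both text/chinese.py and text/tone_sandhi.py switch from jieba to jieba_fast, which exposes the same segmentation API with a faster C core. A short usage sketch of the two imports as this code uses them; exact segmentations depend on the installed dictionary, so the outputs shown are only indicative:

import jieba_fast as jieba          # drop-in replacement used in tone_sandhi.py
import jieba_fast.posseg as psg     # part-of-speech variant used in chinese.py

sentence = "今天天气真不错"
print(list(jieba.cut(sentence)))    # e.g. ['今天', '天气', '真', '不错']
for pair in psg.cut(sentence):      # (word, POS flag) pairs consumed by the sandhi rules
    print(pair.word, pair.flag)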
utils.py
CHANGED
@@ -18,7 +18,7 @@ logging.getLogger("matplotlib").setLevel(logging.ERROR)
 
 MATPLOTLIB_FLAG = False
 
-logging.basicConfig(stream=sys.stdout, level=logging.
+logging.basicConfig(stream=sys.stdout, level=logging.DEBUG)
 logger = logging
 
 
@@ -310,13 +310,13 @@ def check_git_hash(model_dir):
 def get_logger(model_dir, filename="train.log"):
     global logger
     logger = logging.getLogger(os.path.basename(model_dir))
-    logger.setLevel(logging.
+    logger.setLevel(logging.DEBUG)
 
     formatter = logging.Formatter("%(asctime)s\t%(name)s\t%(levelname)s\t%(message)s")
     if not os.path.exists(model_dir):
         os.makedirs(model_dir)
     h = logging.FileHandler(os.path.join(model_dir, filename))
-    h.setLevel(logging.
+    h.setLevel(logging.DEBUG)
     h.setFormatter(formatter)
     logger.addHandler(h)
     return logger
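The utils.py change only raises verbosity: both the stdout basicConfig handler and the per-run FileHandler are now set to logging.DEBUG, so debug records reach train.log as well as the console. A minimal stdlib sketch of the same two-handler setup (names here are illustrative):

import logging, sys

logging.basicConfig(stream=sys.stdout, level=logging.DEBUG)

logger = logging.getLogger("example_run")
logger.setLevel(logging.DEBUG)
file_handler = logging.FileHandler("train.log")
file_handler.setLevel(logging.DEBUG)
file_handler.setFormatter(logging.Formatter("%(asctime)s\t%(name)s\t%(levelname)s\t%(message)s"))
logger.addHandler(file_handler)

logger.debug("visible on stdout and in train.log now that both handlers are at DEBUG")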