XzJosh committed on
Commit
e5f651c
·
verified ·
1 Parent(s): 03da129

Upload 34 files

Browse files
.gitattributes CHANGED
@@ -1,35 +1,34 @@
1
  *.7z filter=lfs diff=lfs merge=lfs -text
2
  *.arrow filter=lfs diff=lfs merge=lfs -text
3
  *.bin filter=lfs diff=lfs merge=lfs -text
 
4
  *.bz2 filter=lfs diff=lfs merge=lfs -text
5
- *.ckpt filter=lfs diff=lfs merge=lfs -text
6
  *.ftz filter=lfs diff=lfs merge=lfs -text
7
  *.gz filter=lfs diff=lfs merge=lfs -text
8
  *.h5 filter=lfs diff=lfs merge=lfs -text
9
  *.joblib filter=lfs diff=lfs merge=lfs -text
10
  *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
- *.mlmodel filter=lfs diff=lfs merge=lfs -text
12
  *.model filter=lfs diff=lfs merge=lfs -text
13
  *.msgpack filter=lfs diff=lfs merge=lfs -text
14
- *.npy filter=lfs diff=lfs merge=lfs -text
15
- *.npz filter=lfs diff=lfs merge=lfs -text
16
  *.onnx filter=lfs diff=lfs merge=lfs -text
17
  *.ot filter=lfs diff=lfs merge=lfs -text
18
  *.parquet filter=lfs diff=lfs merge=lfs -text
19
  *.pb filter=lfs diff=lfs merge=lfs -text
20
- *.pickle filter=lfs diff=lfs merge=lfs -text
21
- *.pkl filter=lfs diff=lfs merge=lfs -text
22
  *.pt filter=lfs diff=lfs merge=lfs -text
23
  *.pth filter=lfs diff=lfs merge=lfs -text
24
  *.rar filter=lfs diff=lfs merge=lfs -text
25
- *.safetensors filter=lfs diff=lfs merge=lfs -text
26
  saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
  *.tar.* filter=lfs diff=lfs merge=lfs -text
28
- *.tar filter=lfs diff=lfs merge=lfs -text
29
  *.tflite filter=lfs diff=lfs merge=lfs -text
30
  *.tgz filter=lfs diff=lfs merge=lfs -text
31
- *.wasm filter=lfs diff=lfs merge=lfs -text
32
  *.xz filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
- *.zst filter=lfs diff=lfs merge=lfs -text
35
- *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
 
 
 
 
 
 
1
  *.7z filter=lfs diff=lfs merge=lfs -text
2
  *.arrow filter=lfs diff=lfs merge=lfs -text
3
  *.bin filter=lfs diff=lfs merge=lfs -text
4
+ *.bin.* filter=lfs diff=lfs merge=lfs -text
5
  *.bz2 filter=lfs diff=lfs merge=lfs -text
 
6
  *.ftz filter=lfs diff=lfs merge=lfs -text
7
  *.gz filter=lfs diff=lfs merge=lfs -text
8
  *.h5 filter=lfs diff=lfs merge=lfs -text
9
  *.joblib filter=lfs diff=lfs merge=lfs -text
10
  *.lfs.* filter=lfs diff=lfs merge=lfs -text
 
11
  *.model filter=lfs diff=lfs merge=lfs -text
12
  *.msgpack filter=lfs diff=lfs merge=lfs -text
 
 
13
  *.onnx filter=lfs diff=lfs merge=lfs -text
14
  *.ot filter=lfs diff=lfs merge=lfs -text
15
  *.parquet filter=lfs diff=lfs merge=lfs -text
16
  *.pb filter=lfs diff=lfs merge=lfs -text
 
 
17
  *.pt filter=lfs diff=lfs merge=lfs -text
18
  *.pth filter=lfs diff=lfs merge=lfs -text
19
  *.rar filter=lfs diff=lfs merge=lfs -text
 
20
  saved_model/**/* filter=lfs diff=lfs merge=lfs -text
21
  *.tar.* filter=lfs diff=lfs merge=lfs -text
 
22
  *.tflite filter=lfs diff=lfs merge=lfs -text
23
  *.tgz filter=lfs diff=lfs merge=lfs -text
 
24
  *.xz filter=lfs diff=lfs merge=lfs -text
25
  *.zip filter=lfs diff=lfs merge=lfs -text
26
+ *.zstandard filter=lfs diff=lfs merge=lfs -text
27
+ *.tfevents* filter=lfs diff=lfs merge=lfs -text
28
+ *.db* filter=lfs diff=lfs merge=lfs -text
29
+ *.ark* filter=lfs diff=lfs merge=lfs -text
30
+ **/*ckpt*data* filter=lfs diff=lfs merge=lfs -text
31
+ **/*ckpt*.meta filter=lfs diff=lfs merge=lfs -text
32
+ **/*ckpt*.index filter=lfs diff=lfs merge=lfs -text
33
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
34
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
35
+ text/cmudict_cache.pickle filter=lfs diff=lfs merge=lfs -text
app.py CHANGED
@@ -1,9 +1,10 @@
1
- import os
 
2
 
3
  gpt_path = os.environ.get(
4
- "gpt_path", "models/Carol/Carol-e15.ckpt"
5
  )
6
- sovits_path = os.environ.get("sovits_path", "models/Carol/Carol_e40_s2160.pth")
7
  cnhubert_base_path = os.environ.get(
8
  "cnhubert_base_path", "pretrained_models/chinese-hubert-base"
9
  )
@@ -21,6 +22,10 @@ import numpy as np
21
  import librosa,torch
22
  from feature_extractor import cnhubert
23
  cnhubert.cnhubert_base_path=cnhubert_base_path
 
 
 
 
24
 
25
  from module.models import SynthesizerTrn
26
  from AR.models.t2s_lightning_module import Text2SemanticLightningModule
@@ -106,29 +111,42 @@ if is_half == True:
106
  else:
107
  ssl_model = ssl_model.to(device)
108
 
109
- vq_model = SynthesizerTrn(
110
- hps.data.filter_length // 2 + 1,
111
- hps.train.segment_size // hps.data.hop_length,
112
- n_speakers=hps.data.n_speakers,
113
- **hps.model
114
- )
115
- if is_half == True:
116
- vq_model = vq_model.half().to(device)
117
- else:
118
- vq_model = vq_model.to(device)
119
- vq_model.eval()
120
- print(vq_model.load_state_dict(dict_s2["weight"], strict=False))
121
- hz = 50
122
- max_sec = config["data"]["max_sec"]
123
- # t2s_model = Text2SemanticLightningModule.load_from_checkpoint(checkpoint_path=gpt_path, config=config, map_location="cpu")#########todo
124
- t2s_model = Text2SemanticLightningModule(config, "ojbk", is_train=False)
125
- t2s_model.load_state_dict(dict_s1["weight"])
126
- if is_half == True:
127
- t2s_model = t2s_model.half()
128
- t2s_model = t2s_model.to(device)
129
- t2s_model.eval()
130
- total = sum([param.nelement() for param in t2s_model.parameters()])
131
- print("Number of parameter: %.2fM" % (total / 1e6))
 
 
 
 
 
 
 
 
 
 
 
 
 
132
 
133
 
134
  def get_spepc(hps, filename):
@@ -150,17 +168,29 @@ def get_spepc(hps, filename):
150
  dict_language = {"中文": "zh", "英文": "en", "日文": "ja"}
151
 
152
 
153
- def get_tts_wav(ref_wav_path, prompt_text, prompt_language, text, text_language):
 
 
 
 
154
  t0 = ttime()
155
  prompt_text = prompt_text.strip("\n")
156
  prompt_language, text = prompt_language, text.strip("\n")
 
 
 
 
157
  with torch.no_grad():
158
- wav16k, sr = librosa.load(ref_wav_path, sr=16000) # 派蒙
159
  wav16k = torch.from_numpy(wav16k)
 
160
  if is_half == True:
161
  wav16k = wav16k.half().to(device)
 
162
  else:
163
  wav16k = wav16k.to(device)
 
 
164
  ssl_content = ssl_model.model(wav16k.unsqueeze(0))[
165
  "last_hidden_state"
166
  ].transpose(
@@ -175,10 +205,7 @@ def get_tts_wav(ref_wav_path, prompt_text, prompt_language, text, text_language)
175
  phones1 = cleaned_text_to_sequence(phones1)
176
  texts = text.split("\n")
177
  audio_opt = []
178
- zero_wav = np.zeros(
179
- int(hps.data.sampling_rate * 0.3),
180
- dtype=np.float16 if is_half == True else np.float32,
181
- )
182
  for text in texts:
183
  # 解决输入目标文本的空行导致报错的问题
184
  if (len(text.strip()) == 0):
@@ -319,28 +346,59 @@ def cut3(inp):
319
  inp = inp.strip("\n")
320
  return "\n".join(["%s。" % item for item in inp.strip("。").split("。")])
321
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
322
 
323
  with gr.Blocks(title="GPT-SoVITS WebUI") as app:
324
  gr.Markdown(value="""
325
- # <center>【AI珈乐】在线语音生成(GPT-SoVITS)\n
326
 
327
  ### <center>模型作者:Xz乔希 https://space.bilibili.com/5859321\n
328
  ### <center>数据集下载:https://huggingface.co/datasets/XzJosh/audiodataset\n
329
- ### <center>声音归属:珈乐Carol https://space.bilibili.com/351609538\n
330
  ### <center>GPT-SoVITS项目:https://github.com/RVC-Boss/GPT-SoVITS\n
331
  ### <center>使用本模型请严格遵守法律法规!发布二创作品请标注本项目作者及链接、作品使用GPT-SoVITS AI生成!\n
332
- ### <center>⚠️在线端极不稳定且生成速度极慢,强烈建议下载模型本地推理!\n
333
  """)
334
  # with gr.Tabs():
335
  # with gr.TabItem(i18n("伴奏人声分离&去混响&去回声")):
336
  with gr.Group():
337
- gr.Markdown(value="*请上传并填写参考信息")
338
  with gr.Row():
339
- inp_ref = gr.Audio(label="请上传参考音频", type="filepath", value="Carol_653.wav")
340
- prompt_text = gr.Textbox(label="参考音频的文本", value="电视剧神话,是的,电视剧神话但那我觉得你们既然猜出来了你们肯定是有。")
341
- prompt_language = gr.Dropdown(
342
- label="参考音频的语种", choices=["中文", "英文", "日文"], value="中文"
343
- )
 
 
 
 
 
 
 
 
 
 
 
344
  gr.Markdown(value="*请填写需要合成的目标文本")
345
  with gr.Row():
346
  text = gr.Textbox(label="需要合成的文本", value="")
@@ -351,21 +409,21 @@ with gr.Blocks(title="GPT-SoVITS WebUI") as app:
351
  output = gr.Audio(label="输出的语音")
352
  inference_button.click(
353
  get_tts_wav,
354
- [inp_ref, prompt_text, prompt_language, text, text_language],
355
  [output],
356
  )
357
 
358
- gr.Markdown(value="文本切分工具。太长的文本合成出来效果不一定好,所以太长建议先切。合成会根据文本的换行分开合成再拼起来。")
359
- with gr.Row():
360
- text_inp = gr.Textbox(label="需要合成的切分前文本", value="")
361
- button1 = gr.Button("凑五句一切", variant="primary")
362
- button2 = gr.Button("凑50字一切", variant="primary")
363
- button3 = gr.Button("按中文句号。切", variant="primary")
364
- text_opt = gr.Textbox(label="切分后文本", value="")
365
- button1.click(cut1, [text_inp], [text_opt])
366
- button2.click(cut2, [text_inp], [text_opt])
367
- button3.click(cut3, [text_inp], [text_opt])
368
- gr.Markdown(value="后续将支持混合语种编码文本输入。")
369
 
370
  app.queue(max_size=10)
371
  app.launch(inbrowser=True)
 
1
+ import os,re
2
+ import gradio as gr
3
 
4
  gpt_path = os.environ.get(
5
+ "gpt_path", "models/Azuma/Azuma-e10.ckpt"
6
  )
7
+ sovits_path = os.environ.get("sovits_path", "models/Azuma/Azuma_e35_s1435.pth")
8
  cnhubert_base_path = os.environ.get(
9
  "cnhubert_base_path", "pretrained_models/chinese-hubert-base"
10
  )
 
22
  import librosa,torch
23
  from feature_extractor import cnhubert
24
  cnhubert.cnhubert_base_path=cnhubert_base_path
25
+ import ssl
26
+ ssl._create_default_https_context = ssl._create_unverified_context
27
+ import nltk
28
+ nltk.download('cmudict')
29
 
30
  from module.models import SynthesizerTrn
31
  from AR.models.t2s_lightning_module import Text2SemanticLightningModule
 
111
  else:
112
  ssl_model = ssl_model.to(device)
113
 
114
+ def change_sovits_weights(sovits_path):
115
+ global vq_model,hps
116
+ dict_s2=torch.load(sovits_path,map_location="cpu")
117
+ hps=dict_s2["config"]
118
+ hps = DictToAttrRecursive(hps)
119
+ hps.model.semantic_frame_rate = "25hz"
120
+ vq_model = SynthesizerTrn(
121
+ hps.data.filter_length // 2 + 1,
122
+ hps.train.segment_size // hps.data.hop_length,
123
+ n_speakers=hps.data.n_speakers,
124
+ **hps.model
125
+ )
126
+ del vq_model.enc_q
127
+ if is_half == True:
128
+ vq_model = vq_model.half().to(device)
129
+ else:
130
+ vq_model = vq_model.to(device)
131
+ vq_model.eval()
132
+ print(vq_model.load_state_dict(dict_s2["weight"], strict=False))
133
+ change_sovits_weights(sovits_path)
134
+
135
+ def change_gpt_weights(gpt_path):
136
+ global hz,max_sec,t2s_model,config
137
+ hz = 50
138
+ dict_s1 = torch.load(gpt_path, map_location="cpu")
139
+ config = dict_s1["config"]
140
+ max_sec = config["data"]["max_sec"]
141
+ t2s_model = Text2SemanticLightningModule(config, "****", is_train=False)
142
+ t2s_model.load_state_dict(dict_s1["weight"])
143
+ if is_half == True:
144
+ t2s_model = t2s_model.half()
145
+ t2s_model = t2s_model.to(device)
146
+ t2s_model.eval()
147
+ total = sum([param.nelement() for param in t2s_model.parameters()])
148
+ print("Number of parameter: %.2fM" % (total / 1e6))
149
+ change_gpt_weights(gpt_path)
150
 
151
 
152
  def get_spepc(hps, filename):
 
168
  dict_language = {"中文": "zh", "英文": "en", "日文": "ja"}
169
 
170
 
171
+ def get_tts_wav(selected_text, prompt_text, prompt_language, text, text_language):
172
+ ref_wav_path = text_to_audio_mappings.get(selected_text, "")
173
+ if not ref_wav_path:
174
+ print("Audio file not found for the selected text.")
175
+ return
176
  t0 = ttime()
177
  prompt_text = prompt_text.strip("\n")
178
  prompt_language, text = prompt_language, text.strip("\n")
179
+ zero_wav = np.zeros(
180
+ int(hps.data.sampling_rate * 0.3),
181
+ dtype=np.float16 if is_half == True else np.float32,
182
+ )
183
  with torch.no_grad():
184
+ wav16k, sr = librosa.load(ref_wav_path, sr=16000)
185
  wav16k = torch.from_numpy(wav16k)
186
+ zero_wav_torch = torch.from_numpy(zero_wav)
187
  if is_half == True:
188
  wav16k = wav16k.half().to(device)
189
+ zero_wav_torch = zero_wav_torch.half().to(device)
190
  else:
191
  wav16k = wav16k.to(device)
192
+ zero_wav_torch = zero_wav_torch.to(device)
193
+ wav16k=torch.cat([wav16k,zero_wav_torch])
194
  ssl_content = ssl_model.model(wav16k.unsqueeze(0))[
195
  "last_hidden_state"
196
  ].transpose(
 
205
  phones1 = cleaned_text_to_sequence(phones1)
206
  texts = text.split("\n")
207
  audio_opt = []
208
+
 
 
 
209
  for text in texts:
210
  # 解决输入目标文本的空行导致报错的问题
211
  if (len(text.strip()) == 0):
 
346
  inp = inp.strip("\n")
347
  return "\n".join(["%s。" % item for item in inp.strip("。").split("。")])
348
 
349
+ def scan_audio_files(folder_path):
350
+ """ 扫描指定文件夹获取音频文件列表 """
351
+ return [f for f in os.listdir(folder_path) if f.endswith('.wav')]
352
+
353
+ def load_audio_text_mappings(folder_path, list_file_name):
354
+ text_to_audio_mappings = {}
355
+ audio_to_text_mappings = {}
356
+ with open(os.path.join(folder_path, list_file_name), 'r', encoding='utf-8') as file:
357
+ for line in file:
358
+ parts = line.strip().split('|')
359
+ if len(parts) >= 4:
360
+ audio_file_name = parts[0]
361
+ text = parts[3]
362
+ audio_file_path = os.path.join(folder_path, audio_file_name)
363
+ text_to_audio_mappings[text] = audio_file_path
364
+ audio_to_text_mappings[audio_file_path] = text
365
+ return text_to_audio_mappings, audio_to_text_mappings
366
+
367
+ audio_folder_path = 'audio/Azuma'
368
+ text_to_audio_mappings, audio_to_text_mappings = load_audio_text_mappings(audio_folder_path, 'Azuma.list')
369
 
370
  with gr.Blocks(title="GPT-SoVITS WebUI") as app:
371
  gr.Markdown(value="""
372
+ # <center>【AI东雪莲】在线语音生成(GPT-SoVITS)\n
373
 
374
  ### <center>模型作者:Xz乔希 https://space.bilibili.com/5859321\n
375
  ### <center>数据集下载:https://huggingface.co/datasets/XzJosh/audiodataset\n
376
+ ### <center>声音归属:東雪蓮Official https://space.bilibili.com/1437582453\n
377
  ### <center>GPT-SoVITS项目:https://github.com/RVC-Boss/GPT-SoVITS\n
378
  ### <center>使用本模型请严格遵守法律法规!发布二创作品请标注本项目作者及链接、作品使用GPT-SoVITS AI生成!\n
379
+ ### <center>⚠️在线端不稳定且生成速度较慢,强烈建议下载模型本地推理!\n
380
  """)
381
  # with gr.Tabs():
382
  # with gr.TabItem(i18n("伴奏人声分离&去混响&去回声")):
383
  with gr.Group():
384
+ gr.Markdown(value="*参考音频选择(必选)")
385
  with gr.Row():
386
+ audio_select = gr.Dropdown(label="选择参考音频(不建议选较长的)", choices=list(text_to_audio_mappings.keys()))
387
+ ref_audio = gr.Audio(label="参考音频试听")
388
+ ref_text = gr.Textbox(label="参考音频文本")
389
+
390
+ # 定义更新参考文本的函数
391
+ def update_ref_text_and_audio(selected_text):
392
+ audio_path = text_to_audio_mappings.get(selected_text, "")
393
+ return selected_text, audio_path
394
+
395
+ # 绑定下拉菜单的变化到更新函数
396
+ audio_select.change(update_ref_text_and_audio, [audio_select], [ref_text, ref_audio])
397
+
398
+ # 其他 Gradio 组件和功能
399
+ prompt_language = gr.Dropdown(
400
+ label="参考音频语种", choices=["中文", "英文", "日文"], value="中文"
401
+ )
402
  gr.Markdown(value="*请填写需要合成的目标文本")
403
  with gr.Row():
404
  text = gr.Textbox(label="需要合成的文本", value="")
 
409
  output = gr.Audio(label="输出的语音")
410
  inference_button.click(
411
  get_tts_wav,
412
+ [audio_select, ref_text, prompt_language, text, text_language],
413
  [output],
414
  )
415
 
416
+
417
+ gr.Markdown(value="文本切分工具。太长的文本合成出来效果不一定好,所以太长建议先切。合成会根据文本的换行分开合成再拼起来。")
418
+ with gr.Row():
419
+ text_inp = gr.Textbox(label="需要合成的切分前文本", value="")
420
+ button1 = gr.Button("凑五句一切", variant="primary")
421
+ button2 = gr.Button("凑50字一切", variant="primary")
422
+ button3 = gr.Button("按中文句号。切", variant="primary")
423
+ text_opt = gr.Textbox(label="切分后文本", value="")
424
+ button1.click(cut1, [text_inp], [text_opt])
425
+ button2.click(cut2, [text_inp], [text_opt])
426
+ button3.click(cut3, [text_inp], [text_opt])
427
 
428
  app.queue(max_size=10)
429
  app.launch(inbrowser=True)
models/Azuma/Azuma-e10.ckpt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:a34b18606751974abdf9178ad76fcda77736693424eb5189384506da80a7b23e
3
+ size 155084485
models/Azuma/Azuma_e35_s1435.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:f923e268a9f7d1b410cac5fb861775c39b4973dbd309381829c36965cfd64ef2
3
+ size 84930071
module/data_utils.py CHANGED
@@ -1,6 +1,8 @@
1
- import time, logging
 
2
  import os
3
- import random, traceback
 
4
  import numpy as np
5
  import torch
6
  import torch.utils.data
@@ -12,15 +14,12 @@ from text import cleaned_text_to_sequence
12
  from utils import load_wav_to_torch, load_filepaths_and_text
13
  import torch.nn.functional as F
14
  from functools import lru_cache
15
- import torch
16
  import requests
17
  from scipy.io import wavfile
18
  from io import BytesIO
19
-
20
- # from config import exp_dir
21
  from my_utils import load_audio
22
 
23
-
24
  class TextAudioSpeakerLoader(torch.utils.data.Dataset):
25
  """
26
  1) loads audio, speaker_id, text pairs
@@ -44,7 +43,7 @@ class TextAudioSpeakerLoader(torch.utils.data.Dataset):
44
 
45
  for line in lines:
46
  tmp = line.split("\t")
47
- if len(tmp) != 4:
48
  continue
49
  self.phoneme_data[tmp[0]] = [tmp[1]]
50
 
@@ -52,7 +51,7 @@ class TextAudioSpeakerLoader(torch.utils.data.Dataset):
52
  tmp = self.audiopaths_sid_text
53
  leng = len(tmp)
54
  min_num = 100
55
- if leng < min_num:
56
  self.audiopaths_sid_text = []
57
  for _ in range(max(2, int(min_num / leng))):
58
  self.audiopaths_sid_text += tmp
@@ -77,20 +76,28 @@ class TextAudioSpeakerLoader(torch.utils.data.Dataset):
77
  for audiopath in tqdm(self.audiopaths_sid_text):
78
  try:
79
  phoneme = self.phoneme_data[audiopath][0]
80
- phoneme = phoneme.split(" ")
81
  phoneme_ids = cleaned_text_to_sequence(phoneme)
82
  except Exception:
83
  print(f"{audiopath} not in self.phoneme_data !")
84
  skipped_phone += 1
85
  continue
 
86
  size = os.path.getsize("%s/%s" % (self.path5, audiopath))
87
  duration = size / self.sampling_rate / 2
 
 
 
 
 
 
88
  if 54 > duration > 0.6 or self.val:
89
  audiopaths_sid_text_new.append([audiopath, phoneme_ids])
90
  lengths.append(size // (2 * self.hop_length))
91
  else:
92
  skipped_dur += 1
93
  continue
 
94
  print("skipped_phone: ", skipped_phone, ", skipped_dur: ", skipped_dur)
95
  print("total left: ", len(audiopaths_sid_text_new))
96
  assert len(audiopaths_sid_text_new) > 1 # 至少能凑够batch size,这里todo
@@ -103,10 +110,8 @@ class TextAudioSpeakerLoader(torch.utils.data.Dataset):
103
  try:
104
  spec, wav = self.get_audio("%s/%s" % (self.path5, audiopath))
105
  with torch.no_grad():
106
- ssl = torch.load(
107
- "%s/%s.pt" % (self.path4, audiopath), map_location="cpu"
108
- )
109
- if ssl.shape[-1] != spec.shape[-1]:
110
  typee = ssl.dtype
111
  ssl = F.pad(ssl.float(), (0, 1), mode="replicate").to(typee)
112
  ssl.requires_grad = False
@@ -117,25 +122,15 @@ class TextAudioSpeakerLoader(torch.utils.data.Dataset):
117
  ssl = torch.zeros(1, 768, 100)
118
  text = text[-1:]
119
  print("load audio or ssl error!!!!!!", audiopath)
120
- # print(ssl.requires_grad,spec.requires_grad,wav.requires_grad,text.requires_grad)
121
  return (ssl, spec, wav, text)
122
 
123
  def get_audio(self, filename):
124
- audio_array = load_audio(
125
- filename, self.sampling_rate
126
- ) # load_audio的方法是已经归一化到-1~1之间的,不用再/32768
127
- # print(filename,audio_array.max(),audio_array.min(),audio_array.mean())
128
  audio = torch.FloatTensor(audio_array) # /32768
129
  audio_norm = audio
130
  audio_norm = audio_norm.unsqueeze(0)
131
- spec = spectrogram_torch(
132
- audio_norm,
133
- self.filter_length,
134
- self.sampling_rate,
135
- self.hop_length,
136
- self.win_length,
137
- center=False,
138
- )
139
  spec = torch.squeeze(spec, 0)
140
  return spec, audio_norm
141
 
@@ -152,14 +147,11 @@ class TextAudioSpeakerLoader(torch.utils.data.Dataset):
152
 
153
  def random_slice(self, ssl, wav, mel):
154
  assert abs(ssl.shape[-1] - wav.shape[-1] // self.hop_length) < 3, (
155
- "first",
156
- ssl.shape,
157
- wav.shape,
158
- )
159
 
160
  len_mel = mel.shape[1]
161
  if self.val:
162
- reference_mel = mel[:, : len_mel // 3]
163
  return reference_mel, ssl, wav, mel
164
  dir = random.randint(0, 1)
165
  sep_point = random.randint(int(len_mel // 3), int(len_mel // 3 * 2))
@@ -167,29 +159,22 @@ class TextAudioSpeakerLoader(torch.utils.data.Dataset):
167
  if dir == 0:
168
  reference_mel = mel[:, :sep_point]
169
  ssl = ssl[:, :, sep_point:]
170
- wav2 = wav[:, sep_point * self.hop_length :]
171
  mel = mel[:, sep_point:]
172
  else:
173
  reference_mel = mel[:, sep_point:]
174
  ssl = ssl[:, :, :sep_point]
175
- wav2 = wav[:, : sep_point * self.hop_length]
176
  mel = mel[:, :sep_point]
177
 
178
  assert abs(ssl.shape[-1] - wav2.shape[-1] // self.hop_length) < 3, (
179
- ssl.shape,
180
- wav.shape,
181
- wav2.shape,
182
- mel.shape,
183
- sep_point,
184
- self.hop_length,
185
- sep_point * self.hop_length,
186
- dir,
187
- )
188
  return reference_mel, ssl, wav2, mel
189
 
190
 
191
- class TextAudioSpeakerCollate:
192
- """Zero-pads model inputs and targets"""
 
193
 
194
  def __init__(self, return_ids=False):
195
  self.return_ids = return_ids
@@ -202,8 +187,8 @@ class TextAudioSpeakerCollate:
202
  """
203
  # Right zero-pad all one-hot text sequences to max input length
204
  _, ids_sorted_decreasing = torch.sort(
205
- torch.LongTensor([x[1].size(1) for x in batch]), dim=0, descending=True
206
- )
207
 
208
  max_ssl_len = max([x[0].size(2) for x in batch])
209
  max_ssl_len = int(2 * ((max_ssl_len // 2) + 1))
@@ -231,31 +216,22 @@ class TextAudioSpeakerCollate:
231
  row = batch[ids_sorted_decreasing[i]]
232
 
233
  ssl = row[0]
234
- ssl_padded[i, :, : ssl.size(2)] = ssl[0, :, :]
235
  ssl_lengths[i] = ssl.size(2)
236
 
237
  spec = row[1]
238
- spec_padded[i, :, : spec.size(1)] = spec
239
  spec_lengths[i] = spec.size(1)
240
 
241
  wav = row[2]
242
- wav_padded[i, :, : wav.size(1)] = wav
243
  wav_lengths[i] = wav.size(1)
244
 
245
  text = row[3]
246
- text_padded[i, : text.size(0)] = text
247
  text_lengths[i] = text.size(0)
248
 
249
- return (
250
- ssl_padded,
251
- ssl_lengths,
252
- spec_padded,
253
- spec_lengths,
254
- wav_padded,
255
- wav_lengths,
256
- text_padded,
257
- text_lengths,
258
- )
259
 
260
 
261
  class DistributedBucketSampler(torch.utils.data.distributed.DistributedSampler):
@@ -268,18 +244,9 @@ class DistributedBucketSampler(torch.utils.data.distributed.DistributedSampler):
268
  Ex) boundaries = [b1, b2, b3] -> any x s.t. length(x) <= b1 or length(x) > b3 are discarded.
269
  """
270
 
271
- def __init__(
272
- self,
273
- dataset,
274
- batch_size,
275
- boundaries,
276
- num_replicas=None,
277
- rank=None,
278
- shuffle=True,
279
- ):
280
  super().__init__(dataset, num_replicas=num_replicas, rank=rank, shuffle=shuffle)
281
  self.lengths = dataset.lengths
282
- # print(233333333333333,self.lengths,dir(dataset))
283
  self.batch_size = batch_size
284
  self.boundaries = boundaries
285
 
@@ -295,24 +262,22 @@ class DistributedBucketSampler(torch.utils.data.distributed.DistributedSampler):
295
  if idx_bucket != -1:
296
  buckets[idx_bucket].append(i)
297
 
298
- for i in range(len(buckets) - 1, 0, -1):
299
- # for i in range(len(buckets) - 1, -1, -1):
300
  if len(buckets[i]) == 0:
301
  buckets.pop(i)
302
  self.boundaries.pop(i + 1)
 
303
 
304
  num_samples_per_bucket = []
305
  for i in range(len(buckets)):
306
  len_bucket = len(buckets[i])
307
  total_batch_size = self.num_replicas * self.batch_size
308
- rem = (
309
- total_batch_size - (len_bucket % total_batch_size)
310
- ) % total_batch_size
311
  num_samples_per_bucket.append(len_bucket + rem)
312
  return buckets, num_samples_per_bucket
313
 
314
  def __iter__(self):
315
- # deterministically shuffle based on epoch
316
  g = torch.Generator()
317
  g.manual_seed(self.epoch)
318
 
@@ -331,25 +296,13 @@ class DistributedBucketSampler(torch.utils.data.distributed.DistributedSampler):
331
  ids_bucket = indices[i]
332
  num_samples_bucket = self.num_samples_per_bucket[i]
333
 
334
- # add extra samples to make it evenly divisible
335
  rem = num_samples_bucket - len_bucket
336
- ids_bucket = (
337
- ids_bucket
338
- + ids_bucket * (rem // len_bucket)
339
- + ids_bucket[: (rem % len_bucket)]
340
- )
341
 
342
- # subsample
343
- ids_bucket = ids_bucket[self.rank :: self.num_replicas]
344
 
345
- # batching
346
  for j in range(len(ids_bucket) // self.batch_size):
347
- batch = [
348
- bucket[idx]
349
- for idx in ids_bucket[
350
- j * self.batch_size : (j + 1) * self.batch_size
351
- ]
352
- ]
353
  batches.append(batch)
354
 
355
  if self.shuffle:
@@ -376,4 +329,4 @@ class DistributedBucketSampler(torch.utils.data.distributed.DistributedSampler):
376
  return -1
377
 
378
  def __len__(self):
379
- return self.num_samples // self.batch_size
 
1
+ import time
2
+ import logging
3
  import os
4
+ import random
5
+ import traceback
6
  import numpy as np
7
  import torch
8
  import torch.utils.data
 
14
  from utils import load_wav_to_torch, load_filepaths_and_text
15
  import torch.nn.functional as F
16
  from functools import lru_cache
 
17
  import requests
18
  from scipy.io import wavfile
19
  from io import BytesIO
 
 
20
  from my_utils import load_audio
21
 
22
+ # ZeroDivisionError fixed by Tybost (https://github.com/RVC-Boss/GPT-SoVITS/issues/79)
23
  class TextAudioSpeakerLoader(torch.utils.data.Dataset):
24
  """
25
  1) loads audio, speaker_id, text pairs
 
43
 
44
  for line in lines:
45
  tmp = line.split("\t")
46
+ if (len(tmp) != 4):
47
  continue
48
  self.phoneme_data[tmp[0]] = [tmp[1]]
49
 
 
51
  tmp = self.audiopaths_sid_text
52
  leng = len(tmp)
53
  min_num = 100
54
+ if (leng < min_num):
55
  self.audiopaths_sid_text = []
56
  for _ in range(max(2, int(min_num / leng))):
57
  self.audiopaths_sid_text += tmp
 
76
  for audiopath in tqdm(self.audiopaths_sid_text):
77
  try:
78
  phoneme = self.phoneme_data[audiopath][0]
79
+ phoneme = phoneme.split(' ')
80
  phoneme_ids = cleaned_text_to_sequence(phoneme)
81
  except Exception:
82
  print(f"{audiopath} not in self.phoneme_data !")
83
  skipped_phone += 1
84
  continue
85
+
86
  size = os.path.getsize("%s/%s" % (self.path5, audiopath))
87
  duration = size / self.sampling_rate / 2
88
+
89
+ if duration == 0:
90
+ print(f"Zero duration for {audiopath}, skipping...")
91
+ skipped_dur += 1
92
+ continue
93
+
94
  if 54 > duration > 0.6 or self.val:
95
  audiopaths_sid_text_new.append([audiopath, phoneme_ids])
96
  lengths.append(size // (2 * self.hop_length))
97
  else:
98
  skipped_dur += 1
99
  continue
100
+
101
  print("skipped_phone: ", skipped_phone, ", skipped_dur: ", skipped_dur)
102
  print("total left: ", len(audiopaths_sid_text_new))
103
  assert len(audiopaths_sid_text_new) > 1 # 至少能凑够batch size,这里todo
 
110
  try:
111
  spec, wav = self.get_audio("%s/%s" % (self.path5, audiopath))
112
  with torch.no_grad():
113
+ ssl = torch.load("%s/%s.pt" % (self.path4, audiopath), map_location="cpu")
114
+ if (ssl.shape[-1] != spec.shape[-1]):
 
 
115
  typee = ssl.dtype
116
  ssl = F.pad(ssl.float(), (0, 1), mode="replicate").to(typee)
117
  ssl.requires_grad = False
 
122
  ssl = torch.zeros(1, 768, 100)
123
  text = text[-1:]
124
  print("load audio or ssl error!!!!!!", audiopath)
 
125
  return (ssl, spec, wav, text)
126
 
127
  def get_audio(self, filename):
128
+ audio_array = load_audio(filename, self.sampling_rate) # load_audio的方法是已经归一化到-1~1之间的,不用再/32768
 
 
 
129
  audio = torch.FloatTensor(audio_array) # /32768
130
  audio_norm = audio
131
  audio_norm = audio_norm.unsqueeze(0)
132
+ spec = spectrogram_torch(audio_norm, self.filter_length, self.sampling_rate, self.hop_length, self.win_length,
133
+ center=False)
 
 
 
 
 
 
134
  spec = torch.squeeze(spec, 0)
135
  return spec, audio_norm
136
 
 
147
 
148
  def random_slice(self, ssl, wav, mel):
149
  assert abs(ssl.shape[-1] - wav.shape[-1] // self.hop_length) < 3, (
150
+ "first", ssl.shape, wav.shape)
 
 
 
151
 
152
  len_mel = mel.shape[1]
153
  if self.val:
154
+ reference_mel = mel[:, :len_mel // 3]
155
  return reference_mel, ssl, wav, mel
156
  dir = random.randint(0, 1)
157
  sep_point = random.randint(int(len_mel // 3), int(len_mel // 3 * 2))
 
159
  if dir == 0:
160
  reference_mel = mel[:, :sep_point]
161
  ssl = ssl[:, :, sep_point:]
162
+ wav2 = wav[:, sep_point * self.hop_length:]
163
  mel = mel[:, sep_point:]
164
  else:
165
  reference_mel = mel[:, sep_point:]
166
  ssl = ssl[:, :, :sep_point]
167
+ wav2 = wav[:, :sep_point * self.hop_length]
168
  mel = mel[:, :sep_point]
169
 
170
  assert abs(ssl.shape[-1] - wav2.shape[-1] // self.hop_length) < 3, (
171
+ ssl.shape, wav.shape, wav2.shape, mel.shape, sep_point, self.hop_length, sep_point * self.hop_length, dir)
 
 
 
 
 
 
 
 
172
  return reference_mel, ssl, wav2, mel
173
 
174
 
175
+ class TextAudioSpeakerCollate():
176
+ """ Zero-pads model inputs and targets
177
+ """
178
 
179
  def __init__(self, return_ids=False):
180
  self.return_ids = return_ids
 
187
  """
188
  # Right zero-pad all one-hot text sequences to max input length
189
  _, ids_sorted_decreasing = torch.sort(
190
+ torch.LongTensor([x[1].size(1) for x in batch]),
191
+ dim=0, descending=True)
192
 
193
  max_ssl_len = max([x[0].size(2) for x in batch])
194
  max_ssl_len = int(2 * ((max_ssl_len // 2) + 1))
 
216
  row = batch[ids_sorted_decreasing[i]]
217
 
218
  ssl = row[0]
219
+ ssl_padded[i, :, :ssl.size(2)] = ssl[0, :, :]
220
  ssl_lengths[i] = ssl.size(2)
221
 
222
  spec = row[1]
223
+ spec_padded[i, :, :spec.size(1)] = spec
224
  spec_lengths[i] = spec.size(1)
225
 
226
  wav = row[2]
227
+ wav_padded[i, :, :wav.size(1)] = wav
228
  wav_lengths[i] = wav.size(1)
229
 
230
  text = row[3]
231
+ text_padded[i, :text.size(0)] = text
232
  text_lengths[i] = text.size(0)
233
 
234
+ return ssl_padded, ssl_lengths, spec_padded, spec_lengths, wav_padded, wav_lengths, text_padded, text_lengths
 
 
 
 
 
 
 
 
 
235
 
236
 
237
  class DistributedBucketSampler(torch.utils.data.distributed.DistributedSampler):
 
244
  Ex) boundaries = [b1, b2, b3] -> any x s.t. length(x) <= b1 or length(x) > b3 are discarded.
245
  """
246
 
247
+ def __init__(self, dataset, batch_size, boundaries, num_replicas=None, rank=None, shuffle=True):
 
 
 
 
 
 
 
 
248
  super().__init__(dataset, num_replicas=num_replicas, rank=rank, shuffle=shuffle)
249
  self.lengths = dataset.lengths
 
250
  self.batch_size = batch_size
251
  self.boundaries = boundaries
252
 
 
262
  if idx_bucket != -1:
263
  buckets[idx_bucket].append(i)
264
 
265
+ i = len(buckets) - 1
266
+ while i >= 0:
267
  if len(buckets[i]) == 0:
268
  buckets.pop(i)
269
  self.boundaries.pop(i + 1)
270
+ i -= 1
271
 
272
  num_samples_per_bucket = []
273
  for i in range(len(buckets)):
274
  len_bucket = len(buckets[i])
275
  total_batch_size = self.num_replicas * self.batch_size
276
+ rem = (total_batch_size - (len_bucket % total_batch_size)) % total_batch_size
 
 
277
  num_samples_per_bucket.append(len_bucket + rem)
278
  return buckets, num_samples_per_bucket
279
 
280
  def __iter__(self):
 
281
  g = torch.Generator()
282
  g.manual_seed(self.epoch)
283
 
 
296
  ids_bucket = indices[i]
297
  num_samples_bucket = self.num_samples_per_bucket[i]
298
 
 
299
  rem = num_samples_bucket - len_bucket
300
+ ids_bucket = ids_bucket + ids_bucket * (rem // len_bucket) + ids_bucket[:(rem % len_bucket)]
 
 
 
 
301
 
302
+ ids_bucket = ids_bucket[self.rank::self.num_replicas]
 
303
 
 
304
  for j in range(len(ids_bucket) // self.batch_size):
305
+ batch = [bucket[idx] for idx in ids_bucket[j * self.batch_size:(j + 1) * self.batch_size]]
 
 
 
 
 
306
  batches.append(batch)
307
 
308
  if self.shuffle:
 
329
  return -1
330
 
331
  def __len__(self):
332
+ return self.num_samples // self.batch_size
requirements.txt CHANGED
@@ -1,18 +1,24 @@
1
  numpy
2
  scipy
3
- torch
4
  librosa==0.9.2
5
  numba==0.56.4
6
- pytorch-lightning
 
7
  gradio==3.47.1
8
  ffmpeg-python
9
- tqdm==4.59.0
 
 
10
  cn2an
11
  pypinyin
12
- pyopenjtalk-prebuilt
13
  g2p_en
14
  torchaudio
 
15
  sentencepiece
16
  transformers
17
- einops
18
- jieba
 
 
 
1
  numpy
2
  scipy
3
+ tensorboard
4
  librosa==0.9.2
5
  numba==0.56.4
6
+ pytorch-lightning==2.1
7
+ torchmetrics==0.10.1
8
  gradio==3.47.1
9
  ffmpeg-python
10
+ onnxruntime
11
+ tqdm
12
+ funasr
13
  cn2an
14
  pypinyin
15
+ pyopenjtalk
16
  g2p_en
17
  torchaudio
18
+ modelscope
19
  sentencepiece
20
  transformers
21
+ chardet
22
+ PyYAML
23
+ psutil
24
+ jieba_fast
text/chinese.py CHANGED
@@ -18,7 +18,7 @@ pinyin_to_symbol_map = {
18
  for line in open(os.path.join(current_file_path, "opencpop-strict.txt")).readlines()
19
  }
20
 
21
- import jieba.posseg as psg
22
 
23
 
24
  rep_map = {
 
18
  for line in open(os.path.join(current_file_path, "opencpop-strict.txt")).readlines()
19
  }
20
 
21
+ import jieba_fast.posseg as psg
22
 
23
 
24
  rep_map = {
text/tone_sandhi.py CHANGED
@@ -14,7 +14,7 @@
14
  from typing import List
15
  from typing import Tuple
16
 
17
- import jieba
18
  from pypinyin import lazy_pinyin
19
  from pypinyin import Style
20
 
 
14
  from typing import List
15
  from typing import Tuple
16
 
17
+ import jieba_fast as jieba
18
  from pypinyin import lazy_pinyin
19
  from pypinyin import Style
20
 
utils.py CHANGED
@@ -18,7 +18,7 @@ logging.getLogger("matplotlib").setLevel(logging.ERROR)
18
 
19
  MATPLOTLIB_FLAG = False
20
 
21
- logging.basicConfig(stream=sys.stdout, level=logging.WARNING)
22
  logger = logging
23
 
24
 
@@ -310,13 +310,13 @@ def check_git_hash(model_dir):
310
  def get_logger(model_dir, filename="train.log"):
311
  global logger
312
  logger = logging.getLogger(os.path.basename(model_dir))
313
- logger.setLevel(logging.WARNING)
314
 
315
  formatter = logging.Formatter("%(asctime)s\t%(name)s\t%(levelname)s\t%(message)s")
316
  if not os.path.exists(model_dir):
317
  os.makedirs(model_dir)
318
  h = logging.FileHandler(os.path.join(model_dir, filename))
319
- h.setLevel(logging.WARNING)
320
  h.setFormatter(formatter)
321
  logger.addHandler(h)
322
  return logger
 
18
 
19
  MATPLOTLIB_FLAG = False
20
 
21
+ logging.basicConfig(stream=sys.stdout, level=logging.DEBUG)
22
  logger = logging
23
 
24
 
 
310
  def get_logger(model_dir, filename="train.log"):
311
  global logger
312
  logger = logging.getLogger(os.path.basename(model_dir))
313
+ logger.setLevel(logging.DEBUG)
314
 
315
  formatter = logging.Formatter("%(asctime)s\t%(name)s\t%(levelname)s\t%(message)s")
316
  if not os.path.exists(model_dir):
317
  os.makedirs(model_dir)
318
  h = logging.FileHandler(os.path.join(model_dir, filename))
319
+ h.setLevel(logging.DEBUG)
320
  h.setFormatter(formatter)
321
  logger.addHandler(h)
322
  return logger