XzJosh committed on
Commit
2fe559d
1 Parent(s): f606b6e

Upload app.py

Files changed (1)
  1. app.py +40 -118
app.py CHANGED
@@ -1,13 +1,5 @@
-import os,re,logging
-logging.getLogger("markdown_it").setLevel(logging.ERROR)
-logging.getLogger("urllib3").setLevel(logging.ERROR)
-logging.getLogger("httpcore").setLevel(logging.ERROR)
-logging.getLogger("httpx").setLevel(logging.ERROR)
-logging.getLogger("asyncio").setLevel(logging.ERROR)
-
-logging.getLogger("charset_normalizer").setLevel(logging.ERROR)
-logging.getLogger("torchaudio._extension").setLevel(logging.ERROR)
-import pdb
+import os,re
+import gradio as gr
 
 gpt_path = os.environ.get(
     "gpt_path", "models/Taffy/Taffy-e5.ckpt"
@@ -57,6 +49,7 @@ else:
     bert_model = bert_model.to(device)
 
 
+# bert_model=bert_model.to(device)
 def get_bert_feature(text, word2ph):
     with torch.no_grad():
         inputs = tokenizer(text, return_tensors="pt")
@@ -70,8 +63,15 @@ def get_bert_feature(text, word2ph):
         repeat_feature = res[i].repeat(word2ph[i], 1)
         phone_level_feature.append(repeat_feature)
     phone_level_feature = torch.cat(phone_level_feature, dim=0)
+    # if(is_half==True):phone_level_feature=phone_level_feature.half()
     return phone_level_feature.T
 
+
+n_semantic = 1024
+
+dict_s2=torch.load(sovits_path,map_location="cpu")
+hps=dict_s2["config"]
+
 class DictToAttrRecursive(dict):
     def __init__(self, input_dict):
         super().__init__(input_dict)
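For context, the get_bert_feature helper shown in this hunk maps word-level BERT hidden states to phone level by repeating each word's vector word2ph[i] times and concatenating the results. A minimal standalone sketch of that expansion, using a dummy tensor in place of the real BERT output (all values and sizes here are illustrative only):

# --- illustrative sketch, not part of the commit ---
import torch

word_features = torch.randn(3, 1024)   # 3 words, 1024-dim hidden states (dummy values)
word2ph = [2, 3, 1]                    # phones per word, from the text frontend

phone_level = torch.cat(
    [word_features[i].repeat(word2ph[i], 1) for i in range(len(word2ph))], dim=0
)
print(phone_level.shape)               # torch.Size([6, 1024]); the app returns the transpose, (1024, 6)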
@@ -99,6 +99,12 @@ class DictToAttrRecursive(dict):
         except KeyError:
             raise AttributeError(f"Attribute {item} not found")
 
+
+hps = DictToAttrRecursive(hps)
+
+hps.model.semantic_frame_rate = "25hz"
+dict_s1 = torch.load(gpt_path, map_location="cpu")
+config = dict_s1["config"]
 ssl_model = cnhubert.get_model()
 if is_half == True:
     ssl_model = ssl_model.half().to(device)
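DictToAttrRecursive is what lets the loaded checkpoint config be read with attribute syntax (hps.model.semantic_frame_rate) instead of nested key lookups. A minimal sketch of the same idea, assuming the recursive wrapping behaviour of the class shown above (the AttrDict name and sample config are invented for illustration):

# --- illustrative sketch, not part of the commit ---
class AttrDict(dict):
    """Dict with attribute access, wrapping nested dicts recursively (same idea as DictToAttrRecursive)."""
    def __init__(self, d):
        super().__init__(d)
        for k, v in d.items():
            if isinstance(v, dict):
                v = AttrDict(v)
            self[k] = v
            setattr(self, k, v)

cfg = AttrDict({"model": {"semantic_frame_rate": "25hz"}, "data": {"sampling_rate": 32000}})
print(cfg.model.semantic_frame_rate)   # "25hz"
print(cfg.data.sampling_rate)          # 32000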
@@ -117,8 +123,7 @@ def change_sovits_weights(sovits_path):
         n_speakers=hps.data.n_speakers,
         **hps.model
     )
-    if("pretrained"not in sovits_path):
-        del vq_model.enc_q
+    del vq_model.enc_q
     if is_half == True:
         vq_model = vq_model.half().to(device)
     else:
@@ -160,88 +165,14 @@ def get_spepc(hps, filename):
     return spec
 
 
-dict_language={
-    ("中文"):"zh",
-    ("英文"):"en",
-    ("日文"):"ja"
-}
-
-
-def splite_en_inf(sentence, language):
-    pattern = re.compile(r'[a-zA-Z. ]+')
-    textlist = []
-    langlist = []
-    pos = 0
-    for match in pattern.finditer(sentence):
-        start, end = match.span()
-        if start > pos:
-            textlist.append(sentence[pos:start])
-            langlist.append(language)
-        textlist.append(sentence[start:end])
-        langlist.append("en")
-        pos = end
-    if pos < len(sentence):
-        textlist.append(sentence[pos:])
-        langlist.append(language)
-
-    return textlist, langlist
-
-
-def clean_text_inf(text, language):
-    phones, word2ph, norm_text = clean_text(text, language)
-    phones = cleaned_text_to_sequence(phones)
-
-    return phones, word2ph, norm_text
-def get_bert_inf(phones, word2ph, norm_text, language):
-    if language == "zh":
-        bert = get_bert_feature(norm_text, word2ph).to(device)
-    else:
-        bert = torch.zeros(
-            (1024, len(phones)),
-            dtype=torch.float16 if is_half == True else torch.float32,
-        ).to(device)
-
-    return bert
-
-
-def nonen_clean_text_inf(text, language):
-    textlist, langlist = splite_en_inf(text, language)
-    phones_list = []
-    word2ph_list = []
-    norm_text_list = []
-    for i in range(len(textlist)):
-        lang = langlist[i]
-        phones, word2ph, norm_text = clean_text_inf(textlist[i], lang)
-        phones_list.append(phones)
-        if lang == "en" or "ja":
-            pass
-        else:
-            word2ph_list.append(word2ph)
-        norm_text_list.append(norm_text)
-    print(word2ph_list)
-    phones = sum(phones_list, [])
-    word2ph = sum(word2ph_list, [])
-    norm_text = ' '.join(norm_text_list)
-
-    return phones, word2ph, norm_text
-
-
-def nonen_get_bert_inf(text, language):
-    textlist, langlist = splite_en_inf(text, language)
-    print(textlist)
-    print(langlist)
-    bert_list = []
-    for i in range(len(textlist)):
-        text = textlist[i]
-        lang = langlist[i]
-        phones, word2ph, norm_text = clean_text_inf(text, lang)
-        bert = get_bert_inf(phones, word2ph, norm_text, lang)
-        bert_list.append(bert)
-    bert = torch.cat(bert_list, dim=1)
-
-    return bert
-
-def get_tts_wav(ref_wav_path, prompt_text, prompt_language, text, text_language,how_to_cut=("不切")):
+dict_language = {"中文": "zh", "英文": "en", "日文": "ja"}
+
+
+def get_tts_wav(selected_text, prompt_text, prompt_language, text, text_language):
+    ref_wav_path = text_to_audio_mappings.get(selected_text, "")
+    if not ref_wav_path:
+        print("Audio file not found for the selected text.")
+        return
     t0 = ttime()
     prompt_text = prompt_text.strip("\n")
     prompt_language, text = prompt_language, text.strip("\n")
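The removed splite_en_inf helper segmented mixed-language input by treating runs of ASCII letters, periods, and spaces as English, keeping the caller's language tag for everything in between. For reference, a self-contained copy of that splitting logic with a usage example (the function name and sample sentence are illustrative):

# --- illustrative sketch, not part of the commit ---
import re

def split_en_segments(sentence, language):
    """Same splitting logic as the removed splite_en_inf."""
    pattern = re.compile(r'[a-zA-Z. ]+')
    textlist, langlist, pos = [], [], 0
    for match in pattern.finditer(sentence):
        start, end = match.span()
        if start > pos:
            textlist.append(sentence[pos:start])
            langlist.append(language)
        textlist.append(sentence[start:end])
        langlist.append("en")
        pos = end
    if pos < len(sentence):
        textlist.append(sentence[pos:])
        langlist.append(language)
    return textlist, langlist

print(split_en_segments("我喜欢machine learning模型", "zh"))
# (['我喜欢', 'machine learning', '模型'], ['zh', 'en', 'zh'])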
@@ -270,37 +201,28 @@ def get_tts_wav(ref_wav_path, prompt_text, prompt_language, text, text_language,
     t1 = ttime()
     prompt_language = dict_language[prompt_language]
     text_language = dict_language[text_language]
-
-    if prompt_language == "en":
-        phones1, word2ph1, norm_text1 = clean_text_inf(prompt_text, prompt_language)
-    else:
-        phones1, word2ph1, norm_text1 = nonen_clean_text_inf(prompt_text, prompt_language)
-    if(how_to_cut==("凑五句一切")):text=cut1(text)
-    elif(how_to_cut==("凑50字一切")):text=cut2(text)
-    elif(how_to_cut==("按中文句号。切")):text=cut3(text)
-    elif(how_to_cut==("按英文句号.切")):text=cut4(text)
-    text = text.replace("\n\n","\n").replace("\n\n","\n").replace("\n\n","\n")
-    if(text[-1]not in splits):text+="。"if text_language!="en"else "."
-    texts=text.split("\n")
+    phones1, word2ph1, norm_text1 = clean_text(prompt_text, prompt_language)
+    phones1 = cleaned_text_to_sequence(phones1)
+    texts = text.split("\n")
     audio_opt = []
-    if prompt_language == "en":
-        bert1 = get_bert_inf(phones1, word2ph1, norm_text1, prompt_language)
-    else:
-        bert1 = nonen_get_bert_inf(prompt_text, prompt_language)
 
     for text in texts:
         # 解决输入目标文本的空行导致报错的问题
         if (len(text.strip()) == 0):
             continue
-        if text_language == "en":
-            phones2, word2ph2, norm_text2 = clean_text_inf(text, text_language)
+        phones2, word2ph2, norm_text2 = clean_text(text, text_language)
+        phones2 = cleaned_text_to_sequence(phones2)
+        if prompt_language == "zh":
+            bert1 = get_bert_feature(norm_text1, word2ph1).to(device)
         else:
-            phones2, word2ph2, norm_text2 = nonen_clean_text_inf(text, text_language)
-
-        if text_language == "en":
-            bert2 = get_bert_inf(phones2, word2ph2, norm_text2, text_language)
+            bert1 = torch.zeros(
+                (1024, len(phones1)),
+                dtype=torch.float16 if is_half == True else torch.float32,
+            ).to(device)
+        if text_language == "zh":
+            bert2 = get_bert_feature(norm_text2, word2ph2).to(device)
         else:
-            bert2 = nonen_get_bert_inf(text, text_language)
+            bert2 = torch.zeros((1024, len(phones2))).to(bert1)
         bert = torch.cat([bert1, bert2], 1)
 
         all_phoneme_ids = torch.LongTensor(phones1 + phones2).to(device).unsqueeze(0)
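Shape-wise, the loop above keeps the BERT features aligned with the phoneme sequence: bert1 and bert2 are (1024, len(phonesN)) whether they come from get_bert_feature or from the zero placeholder, so their concatenation matches all_phoneme_ids in length. A small sketch of that invariant with dummy values (the phoneme ids and lengths are made up):

# --- illustrative sketch, not part of the commit ---
import torch

phones1 = [5, 12, 7, 3]          # dummy phoneme ids for the prompt
phones2 = [9, 2, 14]             # dummy phoneme ids for one target segment

bert1 = torch.zeros((1024, len(phones1)))           # placeholder, as for a non-Chinese prompt
bert2 = torch.zeros((1024, len(phones2))).to(bert1)
bert = torch.cat([bert1, bert2], 1)

all_phoneme_ids = torch.LongTensor(phones1 + phones2).unsqueeze(0)
assert bert.shape[1] == all_phoneme_ids.shape[1]    # 7 positions in both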
@@ -458,7 +380,7 @@ with gr.Blocks(title="GPT-SoVITS WebUI") as app:
     ### <center>⚠️在线端不稳定且生成速度较慢,强烈建议下载模型本地推理!\n
     """)
     # with gr.Tabs():
-
+    # with gr.TabItem(i18n("伴奏人声分离&去混响&去回声")):
     with gr.Group():
         gr.Markdown(value="*参考音频选择(必选)")
         with gr.Row():