Plachta commited on
Commit
a5ba843
·
1 Parent(s): dc5116b

Replaced Encodec with Vocos

Browse files
Files changed (2) hide show
  1. app.py +56 -62
  2. requirements.txt +1 -0
app.py CHANGED
@@ -1,4 +1,3 @@
1
- import argparse
2
  import logging
3
  import os
4
  import pathlib
@@ -19,7 +18,6 @@ langid.set_languages(['en', 'zh', 'ja'])
19
 
20
  import torch
21
  import torchaudio
22
- import random
23
 
24
  import numpy as np
25
 
@@ -35,7 +33,8 @@ from macros import *
35
  from examples import *
36
 
37
  import gradio as gr
38
- import whisper
 
39
 
40
  torch._C._jit_set_profiling_executor(False)
41
  torch._C._jit_set_profiling_mode(False)
@@ -72,8 +71,13 @@ model.eval()
72
  # Encodec model
73
  audio_tokenizer = AudioTokenizer(device)
74
 
 
 
 
75
  # ASR
76
- whisper_model = whisper.load_model("medium").to(device)
 
 
77
 
78
  # Voice Presets
79
  preset_list = os.walk("./presets/").__next__()[2]
@@ -89,34 +93,33 @@ def clear_prompts():
89
  endfiletime = time.time() - 60
90
  if endfiletime > lastmodifytime:
91
  os.remove(filename)
 
 
92
  except:
93
  return
94
 
95
- def transcribe_one(model, audio_path):
96
- # load audio and pad/trim it to fit 30 seconds
97
- audio = whisper.load_audio(audio_path)
98
- audio = whisper.pad_or_trim(audio)
 
99
 
100
- # make log-Mel spectrogram and move to the same device as the model
101
- mel = whisper.log_mel_spectrogram(audio).to(model.device)
102
 
103
- # detect the spoken language
104
- _, probs = model.detect_language(mel)
105
- print(f"Detected language: {max(probs, key=probs.get)}")
106
- lang = max(probs, key=probs.get)
107
- # decode the audio
108
- options = whisper.DecodingOptions(temperature=1.0, best_of=5, fp16=False if device == torch.device("cpu") else True, sample_len=150)
109
- result = whisper.decode(model, mel, options)
110
 
111
  # print the recognized text
112
- print(result.text)
113
 
114
- text_pr = result.text
115
  if text_pr.strip(" ")[-1] not in "?!.,。,?!。、":
116
  text_pr += "."
117
 
118
  # delete all variables
119
- del audio, mel, probs, result
120
  gc.collect()
121
  return lang, text_pr
122
 
@@ -137,7 +140,7 @@ def make_npz_prompt(name, uploaded_audio, recorded_audio, transcript_content):
137
  assert wav_pr.ndim and wav_pr.size(0) == 1
138
 
139
  if transcript_content == "":
140
- text_pr, lang_pr = make_prompt(name, wav_pr, sr, save=False)
141
  else:
142
  lang_pr = langid.classify(str(transcript_content))[0]
143
  lang_token = lang2token[lang_pr]
@@ -147,6 +150,8 @@ def make_npz_prompt(name, uploaded_audio, recorded_audio, transcript_content):
147
  audio_tokens = encoded_frames[0][0].transpose(2, 1).cpu().numpy()
148
 
149
  # tokenize text
 
 
150
  phonemes, _ = text_tokenizer.tokenize(text=f"{text_pr}".strip())
151
  text_tokens, enroll_x_lens = text_collater(
152
  [
@@ -155,6 +160,8 @@ def make_npz_prompt(name, uploaded_audio, recorded_audio, transcript_content):
155
  )
156
 
157
  message = f"Detected language: {lang_pr}\n Detected text {text_pr}\n"
 
 
158
 
159
  # save as npz file
160
  np.savez(os.path.join(tempfile.gettempdir(), f"{name}.npz"),
@@ -166,30 +173,6 @@ def make_npz_prompt(name, uploaded_audio, recorded_audio, transcript_content):
166
  return message, os.path.join(tempfile.gettempdir(), f"{name}.npz")
167
 
168
 
169
- def make_prompt(name, wav, sr, save=True):
170
- if not isinstance(wav, torch.FloatTensor):
171
- wav = torch.tensor(wav)
172
- if wav.abs().max() > 1:
173
- wav /= wav.abs().max()
174
- if wav.size(-1) == 2:
175
- wav = wav.mean(-1, keepdim=False)
176
- if wav.ndim == 1:
177
- wav = wav.unsqueeze(0)
178
- assert wav.ndim and wav.size(0) == 1
179
- torchaudio.save(f"./prompts/{name}.wav", wav, sr)
180
- lang, text = transcribe_one(whisper_model, f"./prompts/{name}.wav")
181
- lang_token = lang2token[lang]
182
- text = lang_token + text + lang_token
183
- with open(f"./prompts/{name}.txt", 'w') as f:
184
- f.write(text)
185
- if not save:
186
- os.remove(f"./prompts/{name}.wav")
187
- os.remove(f"./prompts/{name}.txt")
188
- # delete all variables
189
- del lang_token, wav, sr
190
- gc.collect()
191
- return text, lang
192
-
193
  @torch.no_grad()
194
  def infer_from_audio(text, language, accent, audio_prompt, record_audio_prompt, transcript_content):
195
  if len(text) > 150:
@@ -209,7 +192,7 @@ def infer_from_audio(text, language, accent, audio_prompt, record_audio_prompt,
209
  assert wav_pr.ndim and wav_pr.size(0) == 1
210
 
211
  if transcript_content == "":
212
- text_pr, lang_pr = make_prompt('dummy', wav_pr, sr, save=False)
213
  else:
214
  lang_pr = langid.classify(str(transcript_content))[0]
215
  lang_token = lang2token[lang_pr]
@@ -222,6 +205,9 @@ def infer_from_audio(text, language, accent, audio_prompt, record_audio_prompt,
222
  lang = token2lang[lang_token]
223
  text = lang_token + text + lang_token
224
 
 
 
 
225
  # tokenize audio
226
  encoded_frames = tokenize_audio(audio_tokenizer, (wav_pr, sr))
227
  audio_prompts = encoded_frames[0][0].transpose(2, 1).to(device)
@@ -237,6 +223,8 @@ def infer_from_audio(text, language, accent, audio_prompt, record_audio_prompt,
237
 
238
  enroll_x_lens = None
239
  if text_pr:
 
 
240
  text_prompts, _ = text_tokenizer.tokenize(text=f"{text_pr}".strip())
241
  text_prompts, enroll_x_lens = text_collater(
242
  [
@@ -256,15 +244,16 @@ def infer_from_audio(text, language, accent, audio_prompt, record_audio_prompt,
256
  prompt_language=lang_pr,
257
  text_language=langs if accent == "no-accent" else lang,
258
  )
259
- samples = audio_tokenizer.decode(
260
- [(encoded_frames.transpose(2, 1), None)]
261
- )
 
262
 
263
  message = f"text prompt: {text_pr}\nsythesized text: {text}"
264
  # delete all variables
265
  del audio_prompts, text_tokens, text_prompts, phone_tokens, encoded_frames, wav_pr, sr, audio_prompt, record_audio_prompt, transcript_content
266
  gc.collect()
267
- return message, (24000, samples[0][0].cpu().numpy())
268
 
269
  @torch.no_grad()
270
  def infer_from_prompt(text, language, accent, preset_prompt, prompt_file):
@@ -315,16 +304,17 @@ def infer_from_prompt(text, language, accent, preset_prompt, prompt_file):
315
  prompt_language=lang_pr,
316
  text_language=langs if accent == "no-accent" else lang,
317
  )
318
- samples = audio_tokenizer.decode(
319
- [(encoded_frames.transpose(2, 1), None)]
320
- )
 
321
 
322
  message = f"sythesized text: {text}"
323
 
324
  # delete all variables
325
  del audio_prompts, text_tokens, text_prompts, phone_tokens, encoded_frames, prompt_file, preset_prompt
326
  gc.collect()
327
- return message, (24000, samples[0][0].cpu().numpy())
328
 
329
 
330
  from utils.sentence_cutter import split_text_into_sentences
@@ -407,11 +397,13 @@ def infer_long_text(text, preset_prompt, prompt=None, language='auto', accent='n
407
  text_language=langs if accent == "no-accent" else lang,
408
  )
409
  complete_tokens = torch.cat([complete_tokens, encoded_frames.transpose(2, 1)], dim=-1)
410
- samples = audio_tokenizer.decode(
411
- [(complete_tokens, None)]
412
- )
 
 
413
  message = f"Cut into {len(sentences)} sentences"
414
- return message, (24000, samples[0][0].cpu().numpy())
415
  elif mode == "sliding-window":
416
  complete_tokens = torch.zeros([1, NUM_QUANTIZERS, 0]).type(torch.LongTensor).to(device)
417
  original_audio_prompts = audio_prompts
@@ -453,12 +445,14 @@ def infer_long_text(text, preset_prompt, prompt=None, language='auto', accent='n
453
  else:
454
  audio_prompts = original_audio_prompts
455
  text_prompts = original_text_prompts
456
- samples = audio_tokenizer.decode(
457
- [(complete_tokens, None)]
458
- )
 
 
459
  message = f"Cut into {len(sentences)} sentences"
460
 
461
- return message, (24000, samples[0][0].cpu().numpy())
462
  else:
463
  raise ValueError(f"No such mode {mode}")
464
 
 
 
1
  import logging
2
  import os
3
  import pathlib
 
18
 
19
  import torch
20
  import torchaudio
 
21
 
22
  import numpy as np
23
 
 
33
  from examples import *
34
 
35
  import gradio as gr
36
+ from vocos import Vocos
37
+ from transformers import WhisperProcessor, WhisperForConditionalGeneration
38
 
39
  torch._C._jit_set_profiling_executor(False)
40
  torch._C._jit_set_profiling_mode(False)
 
71
  # Encodec model
72
  audio_tokenizer = AudioTokenizer(device)
73
 
74
+ # Vocos decoder
75
+ vocos = Vocos.from_pretrained('charactr/vocos-encodec-24khz').to(device)
76
+
77
  # ASR
78
+ whisper_processor = WhisperProcessor.from_pretrained("openai/whisper-medium")
79
+ whisper = WhisperForConditionalGeneration.from_pretrained("openai/whisper-medium").to(device)
80
+ whisper.config.forced_decoder_ids = None
81
 
82
  # Voice Presets
83
  preset_list = os.walk("./presets/").__next__()[2]
 
93
  endfiletime = time.time() - 60
94
  if endfiletime > lastmodifytime:
95
  os.remove(filename)
96
+ del path, filename, lastmodifytime, endfiletime
97
+ gc.collect()
98
  except:
99
  return
100
 
101
+ def transcribe_one(wav, sr):
102
+ if sr != 16000:
103
+ wav4trans = torchaudio.transforms.Resample(sr, 16000)(wav)
104
+ else:
105
+ wav4trans = wav
106
 
107
+ input_features = whisper_processor(wav4trans.squeeze(0), sampling_rate=16000, return_tensors="pt").input_features
 
108
 
109
+ # generate token ids
110
+ predicted_ids = whisper.generate(input_features.to(device))
111
+ lang = whisper_processor.batch_decode(predicted_ids[:, 1])[0].strip("<|>")
112
+ # decode token ids to text
113
+ text_pr = whisper_processor.batch_decode(predicted_ids, skip_special_tokens=True)[0]
 
 
114
 
115
  # print the recognized text
116
+ print(text_pr)
117
 
 
118
  if text_pr.strip(" ")[-1] not in "?!.,。,?!。、":
119
  text_pr += "."
120
 
121
  # delete all variables
122
+ del wav4trans, input_features, predicted_ids
123
  gc.collect()
124
  return lang, text_pr
125
 
 
140
  assert wav_pr.ndim and wav_pr.size(0) == 1
141
 
142
  if transcript_content == "":
143
+ lang_pr, text_pr = transcribe_one(wav_pr, sr)
144
  else:
145
  lang_pr = langid.classify(str(transcript_content))[0]
146
  lang_token = lang2token[lang_pr]
 
150
  audio_tokens = encoded_frames[0][0].transpose(2, 1).cpu().numpy()
151
 
152
  # tokenize text
153
+ lang_token = lang2token[lang_pr]
154
+ text_pr = lang_token + text_pr + lang_token
155
  phonemes, _ = text_tokenizer.tokenize(text=f"{text_pr}".strip())
156
  text_tokens, enroll_x_lens = text_collater(
157
  [
 
160
  )
161
 
162
  message = f"Detected language: {lang_pr}\n Detected text {text_pr}\n"
163
+ if lang_pr not in ['ja', 'zh', 'en']:
164
+ return f"Prompt can only made with one of model-supported languages, got {lang_pr} instead", None
165
 
166
  # save as npz file
167
  np.savez(os.path.join(tempfile.gettempdir(), f"{name}.npz"),
 
173
  return message, os.path.join(tempfile.gettempdir(), f"{name}.npz")
174
 
175
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
176
  @torch.no_grad()
177
  def infer_from_audio(text, language, accent, audio_prompt, record_audio_prompt, transcript_content):
178
  if len(text) > 150:
 
192
  assert wav_pr.ndim and wav_pr.size(0) == 1
193
 
194
  if transcript_content == "":
195
+ lang_pr, text_pr = transcribe_one(wav_pr, sr)
196
  else:
197
  lang_pr = langid.classify(str(transcript_content))[0]
198
  lang_token = lang2token[lang_pr]
 
205
  lang = token2lang[lang_token]
206
  text = lang_token + text + lang_token
207
 
208
+ if lang_pr not in ['ja', 'zh', 'en']:
209
+ return f"Reference audio must be a speech of one of model-supported languages, got {lang_pr} instead", None
210
+
211
  # tokenize audio
212
  encoded_frames = tokenize_audio(audio_tokenizer, (wav_pr, sr))
213
  audio_prompts = encoded_frames[0][0].transpose(2, 1).to(device)
 
223
 
224
  enroll_x_lens = None
225
  if text_pr:
226
+ lang_token = lang2token[lang_pr]
227
+ text_pr = lang_token + text_pr + lang_token
228
  text_prompts, _ = text_tokenizer.tokenize(text=f"{text_pr}".strip())
229
  text_prompts, enroll_x_lens = text_collater(
230
  [
 
244
  prompt_language=lang_pr,
245
  text_language=langs if accent == "no-accent" else lang,
246
  )
247
+ # Decode with Vocos
248
+ frames = encoded_frames.permute(2,0,1)
249
+ features = vocos.codes_to_features(frames)
250
+ samples = vocos.decode(features, bandwidth_id=torch.tensor([2], device=device))
251
 
252
  message = f"text prompt: {text_pr}\nsythesized text: {text}"
253
  # delete all variables
254
  del audio_prompts, text_tokens, text_prompts, phone_tokens, encoded_frames, wav_pr, sr, audio_prompt, record_audio_prompt, transcript_content
255
  gc.collect()
256
+ return message, (24000, samples.squeeze(0).cpu().numpy())
257
 
258
  @torch.no_grad()
259
  def infer_from_prompt(text, language, accent, preset_prompt, prompt_file):
 
304
  prompt_language=lang_pr,
305
  text_language=langs if accent == "no-accent" else lang,
306
  )
307
+ # Decode with Vocos
308
+ frames = encoded_frames.permute(2,0,1)
309
+ features = vocos.codes_to_features(frames)
310
+ samples = vocos.decode(features, bandwidth_id=torch.tensor([2], device=device))
311
 
312
  message = f"sythesized text: {text}"
313
 
314
  # delete all variables
315
  del audio_prompts, text_tokens, text_prompts, phone_tokens, encoded_frames, prompt_file, preset_prompt
316
  gc.collect()
317
+ return message, (24000, samples.squeeze(0).cpu().numpy())
318
 
319
 
320
  from utils.sentence_cutter import split_text_into_sentences
 
397
  text_language=langs if accent == "no-accent" else lang,
398
  )
399
  complete_tokens = torch.cat([complete_tokens, encoded_frames.transpose(2, 1)], dim=-1)
400
+ # Decode with Vocos
401
+ frames = encoded_frames.permute(2, 0, 1)
402
+ features = vocos.codes_to_features(frames)
403
+ samples = vocos.decode(features, bandwidth_id=torch.tensor([2], device=device))
404
+
405
  message = f"Cut into {len(sentences)} sentences"
406
+ return message, (24000, samples.squeeze(0).cpu().numpy())
407
  elif mode == "sliding-window":
408
  complete_tokens = torch.zeros([1, NUM_QUANTIZERS, 0]).type(torch.LongTensor).to(device)
409
  original_audio_prompts = audio_prompts
 
445
  else:
446
  audio_prompts = original_audio_prompts
447
  text_prompts = original_text_prompts
448
+ # Decode with Vocos
449
+ frames = encoded_frames.permute(2, 0, 1)
450
+ features = vocos.codes_to_features(frames)
451
+ samples = vocos.decode(features, bandwidth_id=torch.tensor([2], device=device))
452
+
453
  message = f"Cut into {len(sentences)} sentences"
454
 
455
+ return message, (24000, samples.squeeze(0).cpu().numpy())
456
  else:
457
  raise ValueError(f"No such mode {mode}")
458
 
requirements.txt CHANGED
@@ -5,6 +5,7 @@ torchvision==0.15.2
5
  torchaudio
6
  tokenizers
7
  encodec
 
8
  langid
9
  unidecode
10
  pyopenjtalk
 
5
  torchaudio
6
  tokenizers
7
  encodec
8
+ vocos
9
  langid
10
  unidecode
11
  pyopenjtalk