PoTaTo721 commited on
Commit
315fa0c
·
1 Parent(s): 4bb1f5a
Files changed (1) hide show
  1. app.py +6 -12
app.py CHANGED
@@ -3,6 +3,10 @@ import queue
3
  from huggingface_hub import snapshot_download
4
  import hydra
5
  import numpy as np
 
 
 
 
6
 
7
  # Download if not exists
8
  os.makedirs("checkpoints", exist_ok=True)
@@ -203,7 +207,7 @@ def inference_with_auto_rerank(
203
  best_sample_rate = None
204
 
205
  for attempt in range(max_attempts):
206
- audio_generator = inference(
207
  text,
208
  enable_reference_audio,
209
  reference_audio,
@@ -216,16 +220,6 @@ def inference_with_auto_rerank(
216
  streaming=False,
217
  )
218
 
219
- # 获取音频数据
220
- result = None
221
- for item in audio_generator:
222
- result = item
223
-
224
- if result is None:
225
- return None, None, "No audio generated"
226
-
227
- _, (sample_rate, audio), message = result
228
-
229
  if audio is None:
230
  return None, None, message
231
 
@@ -234,6 +228,7 @@ def inference_with_auto_rerank(
234
 
235
  asr_result = batch_asr(asr_model, [audio], sample_rate)[0]
236
  wer = calculate_wer(text, asr_result["text"])
 
237
  if wer <= 0.3 and not asr_result["huge_gap"]:
238
  return None, (sample_rate, audio), None
239
 
@@ -253,7 +248,6 @@ n_audios = 4
253
  global_audio_list = []
254
  global_error_list = []
255
 
256
-
257
  def inference_wrapper(
258
  text,
259
  enable_reference_audio,
 
3
  from huggingface_hub import snapshot_download
4
  import hydra
5
  import numpy as np
6
+ import wave
7
+ import io
8
+ import pyrootutils
9
+ import gc
10
 
11
  # Download if not exists
12
  os.makedirs("checkpoints", exist_ok=True)
 
207
  best_sample_rate = None
208
 
209
  for attempt in range(max_attempts):
210
+ _, (sample_rate, audio), message = inference(
211
  text,
212
  enable_reference_audio,
213
  reference_audio,
 
220
  streaming=False,
221
  )
222
 
 
 
 
 
 
 
 
 
 
 
223
  if audio is None:
224
  return None, None, message
225
 
 
228
 
229
  asr_result = batch_asr(asr_model, [audio], sample_rate)[0]
230
  wer = calculate_wer(text, asr_result["text"])
231
+
232
  if wer <= 0.3 and not asr_result["huge_gap"]:
233
  return None, (sample_rate, audio), None
234
 
 
248
  global_audio_list = []
249
  global_error_list = []
250
 
 
251
  def inference_wrapper(
252
  text,
253
  enable_reference_audio,