theodotus committed on
Commit
a20f918
·
1 Parent(s): 4f90f68

Buggy version of buffered ASR

Browse files
Files changed (1) hide show
  1. app.py +30 -7
app.py CHANGED
@@ -16,6 +16,11 @@ asr_model.encoder.freeze()
16
  asr_model.decoder.freeze()
17
 
18
 
 
 
 
 
 
19
 
20
  def resample(sr, audio_data):
21
  audio_fp32 = np.divide(audio_data, np.iinfo(audio_data.dtype).max, dtype=np.float32)
@@ -30,6 +35,12 @@ def model(audio_16k):
30
  input_signal_length=torch.tensor([len(audio_16k)])
31
  )
32
 
 
 
 
 
 
 
33
  current_hypotheses, all_hyp = asr_model.decoding.ctc_decoder_predictions_tensor(
34
  logits, decoder_lengths=logits_len, return_hypotheses=False,
35
  )
@@ -37,24 +48,36 @@ def model(audio_16k):
37
  return current_hypotheses[0]
38
 
39
 
40
- def transcribe(audio, state=""):
41
- # if state is None:
42
- # pass
43
 
44
  sr, audio_data = audio
45
  audio_16k = resample(sr, audio_data)
46
 
47
- text = model(audio_16k)
 
 
 
 
 
 
 
 
 
 
 
48
 
49
- state += text + " "
50
- return state, state
 
51
 
52
 
53
  gr.Interface(
54
  fn=transcribe,
55
  inputs=[
56
  gr.Audio(source="microphone", type="numpy", streaming=True),
57
- "state"
58
  ],
59
  outputs=[
60
  "textbox",
 
16
  asr_model.decoder.freeze()
17
 
18
 
19
# One-second inference window, expressed in samples at the model's rate.
total_buffer = asr_model.cfg["sample_rate"]
# Quarter-second of audio kept as overlap/context between windows.
overhead_len = total_buffer // 4
# NOTE(review): presumably the encoder's time-downsampling factor — confirm.
model_stride = 4
22
+
23
+
24
 
25
  def resample(sr, audio_data):
26
  audio_fp32 = np.divide(audio_data, np.iinfo(audio_data.dtype).max, dtype=np.float32)
 
35
  input_signal_length=torch.tensor([len(audio_16k)])
36
  )
37
 
38
+ # cut overhead
39
+ logits_overhead = logits.shape[1] * overhead_len // total_buffer
40
+ extra = 1 if (logits.shape[1] * overhead_len % total_buffer) else 0
41
+ logits = logits[:,logits_overhead:-logits_overhead-extra]
42
+ logits_len -= 2 * logits_overhead + extra
43
+
44
  current_hypotheses, all_hyp = asr_model.decoding.ctc_decoder_predictions_tensor(
45
  logits, decoder_lengths=logits_len, return_hypotheses=False,
46
  )
 
48
  return current_hypotheses[0]
49
 
50
 
51
def transcribe(audio, state):
    """Gradio streaming callback: accumulate mic chunks, transcribe full windows.

    Parameters
    ----------
    audio : tuple
        ``(sample_rate, samples)`` as delivered by ``gr.Audio(streaming=True)``.
    state : list | None
        ``[pending_audio (float32 ndarray), transcript (str)]``; ``None`` on
        the first call of a stream.

    Returns
    -------
    tuple
        ``(transcript, state)`` — the text shown in the textbox and the state
        carried into the next invocation.
    """
    if state is None:
        # First chunk of the stream: empty audio queue, empty transcript.
        state = [np.zeros(0, dtype=np.float32), ""]

    rate, samples = audio
    # Append the freshly received chunk to the queued audio.
    state[0] = np.concatenate([state[0], resample(rate, samples)])

    pending = len(state[0])
    text = ""
    if pending > total_buffer:
        # Trim to a whole number of windows and run the model on them;
        # keep the remainder plus overhead_len samples of trailing context
        # for the next call.
        window = pending - pending % total_buffer
        text = model(state[0][:window])
        state[0] = state[0][window - overhead_len:]

    if text:
        state[1] += text + " "
    return state[1], state
74
 
75
 
76
  gr.Interface(
77
  fn=transcribe,
78
  inputs=[
79
  gr.Audio(source="microphone", type="numpy", streaming=True),
80
+ gr.State(None)
81
  ],
82
  outputs=[
83
  "textbox",