# Buffered (streaming) ASR demo by theodotus.
# Commit a20f918: "Buggy version of buffered ASR".
import gradio as gr
import numpy as np
import resampy
import torch
import nemo.collections.asr as nemo_asr
# Load the pretrained CTC/BPE checkpoint named below onto the CPU.
asr_model = nemo_asr.models.EncDecCTCModelBPE.from_pretrained(
    "NeonBohdan/stt_uk_citrinet_512_gamma_0_25",
    map_location="cpu",
)

# Zero out dithering and padding in the feature extractor
# (presumably so streamed chunks are featurized deterministically - confirm).
asr_model.preprocessor.featurizer.dither = 0.0
asr_model.preprocessor.featurizer.pad_to = 0

# Inference only: evaluation mode with encoder and decoder frozen.
asr_model.eval()
asr_model.encoder.freeze()
asr_model.decoder.freeze()

# Processing window in samples: one second of audio at the model sample rate.
total_buffer = asr_model.cfg["sample_rate"]
# Context shared between consecutive windows: a quarter of the window.
overhead_len = total_buffer // 4
# NOTE(review): not referenced anywhere in the visible code.
model_stride = 4
def resample(sr, audio_data):
    """Convert a raw audio chunk to mono float32 at the model sample rate.

    Args:
        sr: Sample rate of ``audio_data`` in Hz.
        audio_data: NumPy array of samples. Integer PCM is scaled by the
            maximum value of its dtype; floating-point input is used as is.
            A 2-D array is assumed to be laid out as (samples, channels)
            and is averaged down to one channel - confirm against the
            audio source if other layouts are possible.

    Returns:
        1-D float32 array resampled to ``asr_model.cfg["sample_rate"]``.
    """
    if audio_data.dtype.kind in "iu":
        # Integer PCM: scale by the dtype maximum to land in roughly [-1, 1].
        audio_fp32 = np.divide(
            audio_data, np.iinfo(audio_data.dtype).max, dtype=np.float32
        )
    else:
        # np.iinfo() raises ValueError for float dtypes; floats need no scaling.
        audio_fp32 = audio_data.astype(np.float32)

    if audio_fp32.ndim > 1:
        # Down-mix to mono: resampling runs along the last axis and the
        # streaming buffer is 1-D, so channel data must not reach either.
        audio_fp32 = audio_fp32.mean(axis=1)

    audio_16k = resampy.resample(audio_fp32, sr, asr_model.cfg["sample_rate"])
    return audio_16k
def model(audio_16k):
    """Transcribe one audio buffer, dropping the overlap context at both edges.

    The buffer is run through the acoustic model, then the output frames
    that correspond to ``overhead_len`` samples are cut from the start and
    from the end before CTC decoding, so that only the central part of the
    buffer contributes text.

    Args:
        audio_16k: 1-D sequence of audio samples at the model sample rate.

    Returns:
        The decoded text for the central part of the buffer ("" for an
        empty buffer).
    """
    n_samples = len(audio_16k)
    if n_samples == 0:
        return ""

    # torch.tensor([ndarray]) goes through a slow list-conversion path;
    # as_tensor + unsqueeze builds the same (1, n_samples) batch directly.
    signal = torch.as_tensor(audio_16k, dtype=torch.float32).unsqueeze(0)
    with torch.no_grad():
        logits, logits_len, greedy_predictions = asr_model.forward(
            input_signal=signal,
            input_signal_length=torch.tensor([n_samples]),
        )

    # cut overhead
    # Scale by the real buffer length: buffers may span several seconds, and
    # dividing by total_buffer would then trim several times overhead_len.
    n_frames = logits.shape[1]
    scaled = n_frames * overhead_len
    logits_overhead = scaled // n_samples
    # When the overhead is not a whole number of frames, drop one more
    # frame on the right-hand side.
    extra = 1 if scaled % n_samples else 0
    # Explicit end index: a negative stop of "-0" would yield an empty slice.
    logits = logits[:, logits_overhead:n_frames - logits_overhead - extra]
    logits_len = logits_len - (2 * logits_overhead + extra)

    current_hypotheses, all_hyp = asr_model.decoding.ctc_decoder_predictions_tensor(
        logits, decoder_lengths=logits_len, return_hypotheses=False,
    )
    return current_hypotheses[0]
def transcribe(audio, state):
    """Consume one streamed audio chunk and return the running transcript.

    Incoming audio is appended to a pending buffer. Once more than
    ``total_buffer`` samples are pending, the largest whole multiple of
    ``total_buffer`` is sent to ``model()`` and its text is appended to the
    transcript.

    Args:
        audio: ``(sample_rate, samples)`` tuple from the audio input, or
            ``None`` when the callback fires without audio.
        state: ``[pending_audio, transcript]`` carried between calls, or
            ``None`` on the first call.

    Returns:
        ``(transcript, state)``; ``state`` is updated in place.
    """
    if audio is None:
        # Nothing to unpack: report the transcript so far, state untouched.
        return ("" if state is None else state[1]), state

    if state is None:
        # Lead-in silence stands in for the left context that model() trims,
        # so the opening samples of the stream are not thrown away.
        state = [np.zeros(overhead_len, dtype=np.float32), ""]

    sr, audio_data = audio
    audio_16k = resample(sr, audio_data)

    # join to audio sequence
    state[0] = np.concatenate([state[0], audio_16k])
    buffer_len = len(state[0])
    if buffer_len <= total_buffer:
        # Not enough audio for a full window yet.
        return state[1], state

    # Process a whole number of windows; the remainder waits for more audio.
    buffer_len -= buffer_len % total_buffer
    buffer = state[0][:buffer_len]
    # model() trims context from both edges of the buffer, so twice
    # overhead_len is carried over: the next buffer's trimmed output then
    # resumes where this one's stops instead of skipping the tail.
    state[0] = state[0][buffer_len - 2 * overhead_len:]

    # run model
    text = model(buffer)
    if text:
        state[1] += text + " "
    return state[1], state
# Wire transcribe() into a live interface fed by a streaming microphone input;
# the state slot is passed back into transcribe() on each call.
demo = gr.Interface(
    fn=transcribe,
    inputs=[
        gr.Audio(source="microphone", type="numpy", streaming=True),
        gr.State(None),
    ],
    outputs=["textbox", "state"],
    live=True,
)
demo.launch()