# "Spaces: Sleeping" — Hugging Face Spaces status banner captured by the page
# scrape; kept as a comment so it is not mistaken for code.
import gradio as gr
import numpy as np
import resampy
import torch

import nemo.collections.asr as nemo_asr

# Load the pretrained Ukrainian Citrinet-512 CTC model; force CPU so the demo
# runs on Spaces hardware without a GPU.
asr_model = nemo_asr.models.EncDecCTCModelBPE.from_pretrained(
    "NeonBohdan/stt_uk_citrinet_512_gamma_0_25", map_location="cpu"
)

# Disable feature dithering and padding so repeated decodes of the same
# streamed audio window produce deterministic features.
asr_model.preprocessor.featurizer.dither = 0.0
asr_model.preprocessor.featurizer.pad_to = 0

# Inference only: eval mode and frozen encoder/decoder (no gradients).
asr_model.eval()
asr_model.encoder.freeze()
asr_model.decoder.freeze()

# Window sizes in samples: decode one second of audio at a time, carrying a
# quarter-second overlap between successive windows for context.
total_buffer = asr_model.cfg["sample_rate"]
overhead_len = asr_model.cfg["sample_rate"] // 4
model_stride = 4  # NOTE(review): unused in this file — confirm before removing
def resample(sr, audio_data):
    """Convert integer PCM audio to float32 and resample to the model's rate.

    Args:
        sr: Source sample rate of ``audio_data``.
        audio_data: Integer numpy array (as delivered by gradio's microphone
            input); its dtype's max value is used for normalization.

    Returns:
        float32 numpy array scaled to roughly [-1, 1] and resampled to
        ``asr_model.cfg["sample_rate"]``.
    """
    # Normalize by the integer dtype's maximum so the waveform lands in the
    # float range the acoustic model expects.
    audio_fp32 = np.divide(audio_data, np.iinfo(audio_data.dtype).max, dtype=np.float32)
    audio_16k = resampy.resample(audio_fp32, sr, asr_model.cfg["sample_rate"])
    return audio_16k
def model(audio_16k):
    """Run one CTC decode over a buffered audio window and return its text.

    Args:
        audio_16k: 1-D float32 numpy array at the model sample rate; expected
            to contain whole ``total_buffer``-sized windows plus the carried
            ``overhead_len`` overlap (see ``transcribe``).

    Returns:
        The greedy-decoded transcript string for the trimmed window.
    """
    logits, logits_len, greedy_predictions = asr_model.forward(
        input_signal=torch.tensor([audio_16k]),
        input_signal_length=torch.tensor([len(audio_16k)]),
    )
    # Trim the frames corresponding to the overlap region from both ends of
    # the logits so audio carried between windows is not transcribed twice.
    # `extra` rounds the trim up by one frame when the division is inexact.
    logits_overhead = logits.shape[1] * overhead_len // total_buffer
    extra = 1 if (logits.shape[1] * overhead_len % total_buffer) else 0
    logits = logits[:, logits_overhead:-logits_overhead - extra]
    logits_len -= 2 * logits_overhead + extra
    current_hypotheses, all_hyp = asr_model.decoding.ctc_decoder_predictions_tensor(
        logits, decoder_lengths=logits_len, return_hypotheses=False,
    )
    return current_hypotheses[0]
def transcribe(audio, state):
    """Gradio streaming callback: accumulate audio and decode full windows.

    Args:
        audio: ``(sample_rate, numpy_array)`` tuple from the microphone stream.
        state: ``[pending_audio, transcript]`` carried between calls, or
            ``None`` on the first invocation.

    Returns:
        ``(transcript_so_far, state)`` — the running transcript string and the
        updated state to be passed back on the next chunk.
    """
    if state is None:
        state = [np.array([], dtype=np.float32), ""]
    sr, audio_data = audio
    audio_16k = resample(sr, audio_data)
    # Append the new chunk to the pending audio sequence.
    state[0] = np.concatenate([state[0], audio_16k])
    buffer_len = len(state[0])
    if buffer_len > total_buffer:
        # Decode only whole total_buffer-sized windows; keep the remainder
        # plus an overhead_len tail so the next decode has left context.
        buffer_len = buffer_len - buffer_len % total_buffer
        buffer = state[0][:buffer_len]
        state[0] = state[0][buffer_len - overhead_len:]
        text = model(buffer)
    else:
        text = ""
    if len(text) != 0:
        state[1] += text + " "
    return state[1], state
# Build and launch the live-streaming demo UI: microphone chunks are fed to
# `transcribe` together with the per-session state, and the running transcript
# is shown in a textbox.
gr.Interface(
    fn=transcribe,
    inputs=[
        gr.Audio(source="microphone", type="numpy", streaming=True),
        gr.State(None),
    ],
    outputs=[
        "textbox",
        "state",
    ],
    live=True,
).launch()