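"""Streaming Ukrainian speech-to-text demo.

A Gradio app that streams microphone audio from the browser, buffers it into
overlapping one-second windows, and transcribes each window with a Ukrainian
Citrinet CTC model from NVIDIA NeMo, accumulating the running transcript in
session state.
"""
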
import gradio as gr
import numpy as np
import resampy
import torch

import nemo.collections.asr as nemo_asr


asr_model = nemo_asr.models.EncDecCTCModelBPE.from_pretrained(
    "NeonBohdan/stt_uk_citrinet_512_gamma_0_25", map_location="cpu"
)

# Disable dithering and feature padding in the preprocessor so the streaming
# output is deterministic, then put the model in inference mode with frozen weights
asr_model.preprocessor.featurizer.dither = 0.0
asr_model.preprocessor.featurizer.pad_to = 0
asr_model.eval()
asr_model.encoder.freeze()
asr_model.decoder.freeze()


# Stream in 1-second windows (total_buffer samples), trimming 0.25 s of
# overlap context (overhead_len samples) from each side of the logits
total_buffer = asr_model.cfg["sample_rate"]
overhead_len = asr_model.cfg["sample_rate"] // 4
model_stride = 4  # defined but not used below


def resample(sr, audio_data):
    # Convert integer PCM from the browser to float32 in [-1, 1], then
    # resample to the sample rate the model expects
    audio_fp32 = np.divide(audio_data, np.iinfo(audio_data.dtype).max, dtype=np.float32)
    audio_16k = resampy.resample(audio_fp32, sr, asr_model.cfg["sample_rate"])

    return audio_16k


def model(audio_16k):
    # Run the acoustic model on one buffered window (no gradients needed)
    with torch.no_grad():
        logits, logits_len, greedy_predictions = asr_model.forward(
            input_signal=torch.tensor([audio_16k]),
            input_signal_length=torch.tensor([len(audio_16k)])
        )

    # Cut the logit frames that correspond to the overlap regions on both
    # sides, rounding up by one frame when the split is not exact
    logits_overhead = logits.shape[1] * overhead_len // total_buffer
    extra = 1 if (logits.shape[1] * overhead_len % total_buffer) else 0
    logits = logits[:, logits_overhead:-logits_overhead - extra]
    logits_len -= 2 * logits_overhead + extra

    # Greedy CTC decoding of the trimmed logits
    current_hypotheses, all_hyp = asr_model.decoding.ctc_decoder_predictions_tensor(
        logits, decoder_lengths=logits_len, return_hypotheses=False,
    )

    return current_hypotheses[0]


def transcribe(audio, state):
    # state[0] holds buffered audio samples, state[1] the transcript so far
    if state is None:
        state = [np.array([], dtype=np.float32), ""]

    sr, audio_data = audio
    audio_16k = resample(sr, audio_data)

    # Append the new chunk to the pending audio buffer
    state[0] = np.concatenate([state[0], audio_16k])

    buffer_len = len(state[0])
    if buffer_len > total_buffer:
        # Transcribe the largest whole number of windows; keep the remainder
        # plus overhead_len samples of left context for the next call.
        # E.g. at 16 kHz, 19200 buffered samples yield one 16000-sample
        # window, and 7200 samples (4000 overlap + 3200 remainder) carry over.
        buffer_len = buffer_len - buffer_len % total_buffer
        buffer = state[0][:buffer_len]
        state[0] = state[0][buffer_len - overhead_len:]
        text = model(buffer)
    else:
        text = ""

    if text:
        state[1] += text + " "
    return state[1], state


gr.Interface(
    fn=transcribe,
    inputs=[
        gr.Audio(source="microphone", type="numpy", streaming=True),
        gr.State(None)
    ],
    outputs=[
        "textbox",
        "state"
    ],
    live=True,
).launch()
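
# A minimal sketch of exercising transcribe() without the UI, assuming a local
# mono 16-bit WAV named sample.wav (hypothetical file; Gradio normally supplies
# the (sample_rate, np.ndarray) tuples from the microphone stream):
#
#     import scipy.io.wavfile
#
#     sr, data = scipy.io.wavfile.read("sample.wav")
#     state = None
#     for start in range(0, len(data), sr // 2):  # feed 0.5 s chunks
#         text, state = transcribe((sr, data[start:start + sr // 2]), state)
#     print(text)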