|
import json |
|
|
|
import gradio as gr |
|
from huggingface_hub import snapshot_download |
|
from omegaconf import OmegaConf |
|
from vosk import KaldiRecognizer, Model |
|
|
|
|
|
def load_vosk(model_id: str):
    """Fetch a model snapshot from the Hugging Face Hub and load it with Vosk.

    Args:
        model_id: Repository identifier handed to ``snapshot_download``.

    Returns:
        A ``vosk.Model`` built from the local snapshot directory.
    """
    return Model(model_path=snapshot_download(model_id))
|
|
|
|
|
# Make ``load_vosk`` callable from the YAML config as a custom resolver.
# ``replace=True`` keeps a second execution of this module in the same
# interpreter (e.g. a hot reload) from raising because the resolver name is
# already registered.
OmegaConf.register_new_resolver("load_vosk", load_vosk, replace=True)

# ``to_object`` converts the config to plain Python containers and resolves
# every interpolation while doing so, so any resolver calls in the YAML run
# here, at import time.
# NOTE(review): presumably configs/models.yaml uses ``${load_vosk:...}`` so
# that the "model" entries end up as loaded Vosk models -- confirm.
models_config = OmegaConf.to_object(OmegaConf.load("configs/models.yaml"))
|
|
|
|
|
def _to_pcm16_mono(audio_array):
    """Return *audio_array* as a one-dimensional int16 (16-bit PCM) array."""
    if audio_array.ndim == 2:
        # Multi-channel input: keep the first channel only.
        audio_array = audio_array[:, 0]

    dtype = audio_array.dtype
    if dtype.kind == "f":
        # Float samples are assumed to be normalised to [-1.0, 1.0];
        # clip and rescale them to the 16-bit integer range.
        return (audio_array.clip(-1.0, 1.0) * 32767).astype("int16")
    if dtype.kind == "i" and dtype.itemsize > 2:
        # Wider integer PCM (e.g. int32): keep the 16 most significant bits.
        return (audio_array >> (8 * (dtype.itemsize - 2))).astype("int16")
    if dtype != "int16":
        return audio_array.astype("int16")
    return audio_array


def automatic_speech_recognition(model_id: str, dialect_id: str, audio_data):
    """Transcribe an audio clip with the Vosk model selected in the UI.

    Args:
        model_id: Key into ``models_config`` selecting the model entry.
        dialect_id: Key selecting the dialect-specific model when the entry's
            ``"model"`` value is a dict; ignored otherwise.
        audio_data: ``(sample_rate, samples)`` tuple where ``samples`` is a
            NumPy array, or ``None`` when no audio was supplied.

    Returns:
        The recognised segments with spaces removed, joined by ``", "`` and
        terminated with ``"."``; an empty string when there is no audio or
        nothing was recognised.
    """
    if audio_data is None:
        # Nothing recorded or uploaded: there is nothing to transcribe.
        return ""

    model_entry = models_config[model_id]["model"]
    if isinstance(model_entry, dict):
        model = model_entry[dialect_id]
    else:
        model = model_entry

    sample_rate, audio_array = audio_data
    # Vosk consumes raw 16-bit mono PCM, so normalise the samples before
    # serialising them instead of passing whatever dtype arrived.
    audio_bytes = _to_pcm16_mono(audio_array).tobytes()

    rec = KaldiRecognizer(model, sample_rate)
    # Word-level details are requested, although only "text" is read below.
    rec.SetWords(True)

    chunk_size = 4000  # bytes fed to the recogniser per call
    results = []
    for start in range(0, len(audio_bytes), chunk_size):
        # A truthy return means the recogniser has a finished segment ready.
        if rec.AcceptWaveform(audio_bytes[start:start + chunk_size]):
            results.append(json.loads(rec.Result()))
    # Flush whatever is still buffered after the last chunk.
    results.append(json.loads(rec.FinalResult()))

    filtered_lines = []
    for result in results:
        # All spaces are removed from each segment.
        # NOTE(review): presumably the models emit space-separated units that
        # should be displayed joined -- confirm.
        text = result.get("text", "").replace(" ", "")
        if text:
            filtered_lines.append(text)

    if not filtered_lines:
        # Avoid returning a lone "." when nothing was recognised.
        return ""
    return ", ".join(filtered_lines) + "."
|
|
|
|
|
def when_model_selected(model_id: str):
    """Build the component update for the dialect selector of a model.

    Args:
        model_id: Key into ``models_config`` for the newly selected model.

    Returns:
        A ``gr.update`` that hides the selector when the model has no
        ``"dialect_mapping"``; otherwise one that shows it with the mapping's
        items as choices and the first item's value selected.
    """
    selected = models_config[model_id]

    if "dialect_mapping" in selected:
        # Each (key, value) pair of the mapping becomes one choice.
        choices = list(selected["dialect_mapping"].items())
        return gr.update(
            choices=choices,
            value=choices[0][1],
            visible=True,
        )

    return gr.update(visible=False)
|
|
|
|
|
# Font preference order for the UI. "tauhu-oo" comes first; it is presumably
# the font provided by the stylesheet imported through ``css`` below.
_font_stack = (
    "tauhu-oo",
    gr.themes.GoogleFont("Source Sans Pro"),
    "ui-sans-serif",
    "system-ui",
    "sans-serif",
)
_theme = gr.themes.Default(font=_font_stack)

# Top-level Blocks container that the rest of the UI is built inside.
demo = gr.Blocks(
    title="臺灣南島語語音辨識系統",
    css="@import url(https://tauhu.tw/tauhu-oo.css);",
    theme=_theme,
)
|
|
|
with demo:
    # The first configured model is the initial selection.
    model_ids = list(models_config)
    default_model_id = model_ids[0]
    model_drop_down = gr.Dropdown(
        model_ids,
        value=default_model_id,
        label="模型",
    )

    # ``when_model_selected`` already tolerates models without a
    # "dialect_mapping"; do the same for the default model instead of failing
    # at start-up, and hide the selector when there is nothing to choose.
    default_dialect_choices = list(
        models_config[default_model_id].get("dialect_mapping", {}).items()
    )
    dialect_drop_down = gr.Radio(
        choices=default_dialect_choices,
        value=default_dialect_choices[0][1] if default_dialect_choices else None,
        label="族別",
        visible=bool(default_dialect_choices),
    )

    # Refresh the dialect choices whenever the user picks another model.
    model_drop_down.input(
        when_model_selected,
        inputs=[model_drop_down],
        outputs=[dialect_drop_down],
    )

    # Read the description explicitly as UTF-8 so the result does not depend
    # on the platform's default encoding (DEMO.md is expected to be UTF-8).
    with open("DEMO.md", encoding="utf-8") as tong:
        gr.Markdown(tong.read())

    gr.Interface(
        automatic_speech_recognition,
        inputs=[
            model_drop_down,
            dialect_drop_down,
            gr.Audio(
                label="上傳或錄音",
                type="numpy",
                format="wav",
                waveform_options=gr.WaveformOptions(
                    sample_rate=16000,
                ),
            ),
        ],
        outputs=[
            gr.Text(interactive=False, label="辨識結果"),
        ],
        allow_flagging="auto",
    )

# Start the web server for the assembled UI.
demo.launch()
|
|