asr-demo / app.py
import os

import gradio as gr
import torch
from huggingface_hub import snapshot_download
from nemo.collections.asr.models import ASRModel
from omegaconf import OmegaConf

# Run on GPU when available, otherwise fall back to CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
# Defined for potential half-precision inference; not applied below.
torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32

def load_model(model_id: str):
    """Download a .nemo checkpoint from the Hugging Face Hub and load it for inference."""
    model_dir = snapshot_download(model_id)
    model_ckpt_path = os.path.join(model_dir, "model.nemo")
    asr_model = ASRModel.restore_from(model_ckpt_path)
    asr_model.eval()
    asr_model = asr_model.to(device)
    return asr_model

# Register load_model as an OmegaConf resolver so configs/models.yaml can reference it.
OmegaConf.register_new_resolver("load_model", load_model)
models_config = OmegaConf.to_object(OmegaConf.load("configs/models.yaml"))

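# For illustration only (the config is not part of this file, and the repo ids
# below are placeholders): configs/models.yaml is expected to map display names
# to entries whose "model" field uses the resolver, e.g.
#
#   my-asr-model:
#     model: ${load_model:example-org/example-nemo-asr}
#
# Each top-level key becomes a dropdown choice below, and OmegaConf.to_object
# resolves the interpolation into a loaded ASRModel instance.
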
def automatic_speech_recognition(model_id: str, audio_file: str):
    model = models_config[model_id]["model"]
    # NeMo's transcribe() takes a list of audio file paths and returns a list
    # of transcriptions; take the first (and only) result.
    text = model.transcribe([audio_file])[0]
    return text

# Blocks layout; the title "康統語音辨識系統" is "康統 speech recognition system".
demo = gr.Blocks(
    title="康統語音辨識系統",
)

with demo:
    default_model_id = list(models_config.keys())[0]
    model_drop_down = gr.Dropdown(
        list(models_config.keys()),
        value=default_model_id,
        label="模型",  # "Model"
    )
    gr.Markdown(
        """
        # 康統語音辨識系統
        """
    )
    gr.Interface(
        automatic_speech_recognition,
        inputs=[
            model_drop_down,
            gr.Audio(
                label="上傳或錄音",  # "Upload or record"
                type="filepath",
                waveform_options=gr.WaveformOptions(
                    sample_rate=16000,
                ),
            ),
        ],
        outputs=[
            gr.Text(interactive=False, label="辨識結果"),  # "Recognition result"
        ],
        allow_flagging="auto",
    )

demo.launch()