espnet2_asr_librispeech_100h / gradio_asr_en_libri100.py
jaekookang
change soundfile to librosa
0635e9e
raw
history blame
2.42 kB
'''Librispeech 100h English ASR demo
@ML2 --> @HuggingFace
2022-02-11
2022-02-16
- changed to HF
- server setting commented
- model cache dir commented
'''
import os
from glob import glob
from loguru import logger
# import soundfile as sf
import librosa
import gradio as gr
from espnet_model_zoo.downloader import ModelDownloader
from espnet2.bin.asr_inference import Speech2Text
# ---------- Settings ----------
# Device selection: a GPU_ID of '-1' selects 'cpu', anything else 'cuda'.
GPU_ID = '-1'
DEVICE = 'cpu' if GPU_ID == '-1' else 'cuda'
os.environ['CUDA_VISIBLE_DEVICES'] = GPU_ID

# Server options; in the visible code they are only referenced by the
# commented-out arguments of the launch call.
SERVER_NAME = "0.0.0.0"
SERVER_PORT = 42208
SSL_DIR = './keyble_ssl'

# Local directories for the model files and the example recordings.
# MODEL_DIR = '/home/jkang/HDD4T/jkang/huggingface'
MODEL_DIR = './model'
EXAMPLE_DIR = './examples'

# Every *.wav file found directly under EXAMPLE_DIR, in sorted order.
examples = glob(os.path.join(EXAMPLE_DIR, '*.wav'))
examples.sort()
# ---------- Logging ----------
# Register app.log as an append-mode log destination, then write a banner
# line so each start of the app is easy to spot in the log.
logger.add(
    'app.log',
    mode='a',
)
logger.info(
    '============================= App restarted ============================='
)
# ---------- Model ----------
# Download and unpack the pretrained model through a ModelDownloader built
# on MODEL_DIR, then construct the recognizer from the unpacked files.
logger.info('download model')
d = ModelDownloader(MODEL_DIR)
out = d.download_and_unpack("jkang/espnet2_librispeech_100_conformer")
logger.info('model downloaded')
model = Speech2Text.from_pretrained(
    asr_train_config=out['asr_train_config'],
    asr_model_file=out['asr_model_file'],
    # Honour the DEVICE setting: without this argument the recognizer
    # falls back to its own default device, whatever GPU_ID says.
    device=DEVICE,
)
logger.info('model loaded')
def predict(wav_file):
    """Transcribe one audio file with the module-level ASR model.

    Args:
        wav_file: Path of the audio file to decode. A falsy value
            (``None`` or ``''``, e.g. nothing was uploaded) skips decoding.

    Returns:
        The text of the first hypothesis returned by the model, or an
        empty string when no file was given.
    """
    # Nothing to decode: return early instead of letting the audio loader
    # fail on a missing path.
    if not wav_file:
        return ''
    # Load the audio resampled to 16 kHz (presumably the rate the model
    # expects -- TODO confirm); the returned sample rate is not needed.
    speech, _ = librosa.load(wav_file, sr=16000)
    # Log only once the file has actually been read.
    logger.info('wav file loaded')
    nbests = model(speech)
    # The first element of the best hypothesis is its text; the remaining
    # elements are discarded.
    text, *_ = nbests[0]
    logger.info('predicted')
    return text
# Gradio UI: one uploaded audio file in, one text box out, served by predict.
# The user-facing strings are Korean; they are stored as proper Hangul
# (the previous literals were mojibake and rendered as garbage in the UI).
iface = gr.Interface(
    predict,
    # "English speech recognition demo (espnet libri100) -- prototype"
    title='영어 음성인식 데모 (espnet libri100) -- 프로토타입',
    # "Upload an English speech file and its text content is shown as the result."
    description='영어 음성 파일을 업로드하면 텍스트 내용을 결과로 보여줍니다.',
    inputs=[
        # Label: "English speech". The upload is handed to predict as a file path.
        gr.inputs.Audio(label='영어 음성', source='upload', type='filepath')
    ],
    outputs=[
        # Label: "Speech recognition decoding result".
        gr.outputs.Textbox(label='음성 인식 디코딩결과'),
    ],
    # Example recordings offered in the UI (the module-level `examples` list).
    examples=examples,
    article='<p style="text-align:center">i-Scream AI</p>',
)
if __name__ == '__main__':
    # Script entry point: serve the interface until interrupted, and always
    # close it on the way out.
    try:
        iface.launch(
            debug=True,
            # server_name=SERVER_NAME,
            # server_port=SERVER_PORT,
            enable_queue=True,
            # ssl_keyfile=SSL_DIR,
            # ssl_certfile=SSL_DIR
        )
    except KeyboardInterrupt:
        # A KeyboardInterrupt usually carries no message, so printing it
        # gives an empty line; record the shutdown through the app logger.
        logger.info('interrupted by user, shutting down')
    finally:
        iface.close()