Spaces:
Runtime error
Runtime error
'''LibriSpeech 100h English ASR demo
@ML2 --> @HuggingFace
2022-02-11 (initial version)
2022-02-16
- ported to Hugging Face Spaces
- server settings commented out
- local model cache dir commented out
'''
import os | |
from glob import glob | |
from loguru import logger | |
# import soundfile as sf | |
import librosa | |
import gradio as gr | |
from espnet_model_zoo.downloader import ModelDownloader | |
from espnet2.bin.asr_inference import Speech2Text | |
# ---------- Settings ----------
# Device selection: '-1' means CPU; any other value is exported as the only
# visible CUDA device and selects 'cuda'.
GPU_ID = '-1'
os.environ['CUDA_VISIBLE_DEVICES'] = GPU_ID
DEVICE = 'cpu' if GPU_ID == '-1' else 'cuda'

# Network / TLS options (only referenced by the commented-out launch arguments).
SERVER_NAME = "0.0.0.0"
SERVER_PORT = 42208
SSL_DIR = './keyble_ssl'

# Where ModelDownloader keeps downloaded model files.
# MODEL_DIR = '/home/jkang/HDD4T/jkang/huggingface'
MODEL_DIR = './model'

# Sample recordings offered under the UI, in a stable (sorted) order.
EXAMPLE_DIR = './examples'
examples = glob(os.path.join(EXAMPLE_DIR, '*.wav'))
examples.sort()
# ---------- Logging ----------
# Append this run's records to app.log (kept across restarts) and write a
# banner so each restart is easy to spot in the file.
logger.add(sink='app.log', mode='a')
logger.info('============================= App restarted =============================')
# ---------- Model ----------
# Download and unpack the packed ESPnet model into MODEL_DIR, then build a
# single recognizer at import time; `predict` reuses this global `model`.
logger.info('download model')
d = ModelDownloader(MODEL_DIR)
out = d.download_and_unpack("jkang/espnet2_librispeech_100_conformer")
logger.info('model downloaded')
# `out` maps the config / checkpoint names to local file paths.
model = Speech2Text.from_pretrained(
    asr_train_config=out['asr_train_config'],
    asr_model_file=out['asr_model_file'],
    # Honour the GPU_ID / DEVICE setting. Without this, DEVICE was computed
    # but ignored, and inference always ran on Speech2Text's default device.
    device=DEVICE,
)
logger.info('model loaded')
def predict(wav_file):
    """Transcribe an uploaded English recording with the global ASR model.

    Args:
        wav_file: Path to the audio file provided by the Gradio ``Audio``
            input (``type='filepath'``); ``None`` or ``''`` when the user
            submitted without uploading anything.

    Returns:
        The text of the best n-best hypothesis, or ``''`` when there is no
        input file or the recognizer produced no hypotheses.
    """
    # Gradio hands over None when nothing was uploaded; librosa would raise
    # on it, so reply with an empty transcript instead of crashing.
    if not wav_file:
        return ''
    # Resample to 16 kHz on load; the returned sample rate is therefore
    # known and not needed.
    speech, _ = librosa.load(wav_file, sr=16000)
    logger.info('wav file loaded')
    nbests = model(speech)
    if not nbests:
        logger.warning('no hypothesis returned')
        return ''
    # Each n-best entry is a tuple whose first element is the decoded text.
    text, *_ = nbests[0]
    logger.info('predicted')
    return text
# Web UI: upload an audio file, get its transcript from `predict`.
# The Korean UI strings were mojibake (UTF-8 mis-decoded); restored below.
iface = gr.Interface(
    predict,
    # "English speech recognition demo (espnet libri100) -- prototype"
    title='영어 음성인식 데모 (espnet libri100) -- 프로토타입',
    # "Upload an English speech file and its text content is shown as the result."
    description='영어 음성 파일을 업로드하면 텍스트 내용을 결과로 보여줍니다.',
    inputs=[
        # "English speech"
        gr.inputs.Audio(label='영어 음성', source='upload', type='filepath')
    ],
    outputs=[
        # "Speech recognition decoding result"
        gr.outputs.Textbox(label='음성 인식 디코딩결과'),
    ],
    # One row per sample with one value per input component (the form every
    # Gradio version accepts); None hides the panel when no .wav files exist.
    examples=[[path] for path in examples] or None,
    article='<p style="text-align:center">i-Scream AI</p>',
)
if __name__ == '__main__':
    try:
        # Server/TLS options stay commented out for the hosted deployment;
        # re-enable them when self-hosting.
        # NOTE(review): ssl_keyfile / ssl_certfile expect file paths, but both
        # point at SSL_DIR (a directory name) -- fix before re-enabling.
        iface.launch(debug=True,
                     # server_name=SERVER_NAME,
                     # server_port=SERVER_PORT,
                     enable_queue=True,
                     # ssl_keyfile=SSL_DIR,
                     # ssl_certfile=SSL_DIR
                     )
    except KeyboardInterrupt:
        # str(KeyboardInterrupt()) is empty, so printing the exception only
        # produced a blank line; record the shutdown in the app log instead.
        logger.info('App stopped by user (KeyboardInterrupt)')
    finally:
        iface.close()