File size: 2,518 Bytes
49041a5
 
e016491
49041a5
 
e016491
 
 
 
49041a5
 
 
 
 
37f6612
9f79e93
37f6612
49041a5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e016491
 
49041a5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9f79e93
 
37f6612
49041a5
 
 
 
 
 
 
bee682c
 
49041a5
b4ac259
49041a5
 
bee682c
49041a5
 
b4ac259
49041a5
 
 
 
 
e016491
 
49041a5
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
'''Librispeech 100h English ASR demo

@ML2 --> @HuggingFace

2022-02-11
2022-02-16
    - changed to HF
    - server settings commented out
    - model cache dir commented out
'''

import os
from glob import glob
from loguru import logger
import soundfile as sf
import librosa
# from scipy.io import wavfile
import gradio as gr

from espnet_model_zoo.downloader import ModelDownloader
from espnet2.bin.asr_inference import Speech2Text


# ---------- Settings ----------
# Device selection: GPU_ID '-1' means CPU; any other value is exported as the
# visible CUDA device and selects 'cuda'.
GPU_ID = '-1'
os.environ['CUDA_VISIBLE_DEVICES'] = GPU_ID
DEVICE = 'cpu' if GPU_ID == '-1' else 'cuda'

# Self-hosting options; only referenced by the commented-out launch kwargs.
SERVER_NAME = "0.0.0.0"
SERVER_PORT = 42208
SSL_DIR = './keyble_ssl'

# Where ModelDownloader stores/unpacks the pretrained model.
MODEL_DIR = './model'

# Bundled example recordings, listed in a deterministic (sorted) order for the UI.
EXAMPLE_DIR = './examples'
examples = sorted(glob(os.path.join(EXAMPLE_DIR, '*.wav')))

# ---------- Logging ----------
# Append all log records to app.log so history survives restarts.
logger.add('app.log', mode='a')
logger.info('============================= App restarted =============================')

# ---------- Model ----------
# Download/unpack the pretrained ESPnet2 Conformer into MODEL_DIR; `out` maps
# artifact names (e.g. 'asr_train_config', 'asr_model_file') to local paths.
logger.info('download model')
d = ModelDownloader(MODEL_DIR)
out = d.download_and_unpack("jkang/espnet2_librispeech_100_conformer")
logger.info('model downloaded')
model = Speech2Text.from_pretrained(
    asr_train_config=out['asr_train_config'],
    asr_model_file=out['asr_model_file'],
    # Speech2Text defaults to CPU; pass the configured device so GPU_ID is honoured.
    device=DEVICE,
)
logger.info('model loaded')

def predict(wav_file):
    """Transcribe a single audio file with the loaded ESPnet2 ASR model.

    Args:
        wav_file: Path to an audio file (the Gradio input uses
            ``type='filepath'``), or ``None`` when no audio was provided.

    Returns:
        The top n-best hypothesis text, or ``''`` if ``wav_file`` is ``None``.
    """
    if wav_file is None:
        # Gradio passes None when the user submits without recording;
        # librosa.load(None) would raise instead of producing a result.
        return ''
    # Resample to 16 kHz on load; presumably the model's training rate
    # (LibriSpeech is 16 kHz) -- TODO confirm against asr_train_config.
    speech, _ = librosa.load(wav_file, sr=16000)
    logger.info('wav file loaded')
    nbests = model(speech)
    # Each n-best entry is a tuple whose first element is the decoded text.
    text, *_ = nbests[0]
    logger.info('predicted')
    return text

# UI: one microphone recording in (delivered to predict() as a file path),
# one text box out showing the transcription.
_audio_input = gr.inputs.Audio(label='wav file', source='microphone', type='filepath')
_text_output = gr.outputs.Textbox(label='decoding result')

iface = gr.Interface(
    fn=predict,
    inputs=[_audio_input],
    outputs=[_text_output],
    examples=examples,
    title='ESPNet2 ASR Librispeech Conformer (trained on clean-100h)',
    description='Upload your wav file to test the model',
    article='<p style="text-align:center">Model URL<a target="_blank" href="https://huggingface.co/jkang/espnet2_librispeech_100_conformer">🤗</a></p>',
)

if __name__ == '__main__':
    try:
        iface.launch(
            debug=True,
            enable_queue=True,
            # For self-hosting, additionally pass:
            #   server_name=SERVER_NAME, server_port=SERVER_PORT,
            #   ssl_keyfile=..., ssl_certfile=...
            # NOTE(review): these expect key/cert file paths; SSL_DIR looks like
            # a directory -- confirm before re-enabling.
        )
    except KeyboardInterrupt:
        # KeyboardInterrupt carries no message, so print(e) only emitted a blank line.
        logger.info('Interrupted by user; shutting down')
    finally:
        # Close the Gradio server whether launch() returned or was interrupted.
        iface.close()