File size: 2,208 Bytes
49041a5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
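'''LibriSpeech 100h English ASR demo

Gradio web app that transcribes an uploaded English speech file with the
pretrained ESPnet2 Conformer model `jkang/espnet2_librispeech_100_conformer`.

@ML2

2022-02-11
'''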

import os
from glob import glob
from loguru import logger
import soundfile as sf
import gradio as gr

from espnet_model_zoo.downloader import ModelDownloader
from espnet2.bin.asr_inference import Speech2Text


# ---------- Settings ----------
# GPU index exposed to this process. '-1' (the default) hides every GPU and
# runs inference on the CPU. Override with the ASR_GPU_ID environment variable
# instead of editing this file.
GPU_ID = os.environ.get('ASR_GPU_ID', '-1')
os.environ['CUDA_VISIBLE_DEVICES'] = GPU_ID
# An empty CUDA_VISIBLE_DEVICES also hides every GPU, so it must map to CPU too;
# otherwise DEVICE would say 'cuda' while no GPU is visible.
DEVICE = 'cuda' if GPU_ID not in ('-1', '') else 'cpu'

SERVER_PORT = 42208
SERVER_NAME = "0.0.0.0"  # listen on all network interfaces

SSL_DIR = './keyble_ssl'
# Directory handed to ModelDownloader for downloaded model files.
# Override with ASR_MODEL_DIR on machines where the default path does not exist.
MODEL_DIR = os.environ.get('ASR_MODEL_DIR', '/home/jkang/HDD4T/jkang/huggingface')

# Example clips shown under the UI: every .wav in EXAMPLE_DIR, in sorted order.
EXAMPLE_DIR = './examples'
examples = sorted(glob(os.path.join(EXAMPLE_DIR, '*.wav')))

# ---------- Logging ----------
# Append to app.log so logs survive restarts; the banner marks each restart.
logger.add('app.log', mode='a')
logger.info('============================= App restarted =============================')

# ---------- Model ----------
# Download and unpack the pretrained ESPnet2 model into MODEL_DIR, then build
# a single Speech2Text instance at import time that every request reuses.
logger.info('download model')
d = ModelDownloader(MODEL_DIR)
out = d.download_and_unpack("jkang/espnet2_librispeech_100_conformer")
logger.info('model downloaded')
model = Speech2Text.from_pretrained(
    asr_train_config=out['asr_train_config'],
    asr_model_file=out['asr_model_file'],
    # Speech2Text defaults to device='cpu'; without this argument the GPU
    # selected via GPU_ID/DEVICE in the settings would never be used.
    device=DEVICE,
)
logger.info('model loaded')

def predict(wav_file, target_rate=16000):
    """Transcribe one uploaded audio file and return the decoded text.

    Args:
        wav_file: Path to the uploaded audio file (the Gradio input uses
            ``type='filepath'``), or None if the user submitted without a file.
        target_rate: Sample rate in Hz that the model expects. Defaults to
            16 kHz, the LibriSpeech sampling rate. NOTE(review): confirm
            against the ``fs`` in the downloaded asr_train_config.

    Returns:
        The text of the first entry in the model's n-best list
        (presumably the best hypothesis), or '' when no file was given.
    """
    if wav_file is None:
        logger.warning('predict called without an audio file')
        return ''

    speech, rate = sf.read(wav_file)
    logger.info('wav file loaded: {} ({} Hz)', wav_file, rate)

    # For multi-channel files sf.read returns a (frames, channels) array, but
    # the model consumes a single 1-D waveform: average channels to mono.
    if speech.ndim > 1:
        speech = speech.mean(axis=1)

    # model(speech) is never told the sample rate, so audio at any other rate
    # would be decoded as if it were target_rate. Resample it first.
    if rate != target_rate:
        from scipy.signal import resample_poly
        speech = resample_poly(speech, target_rate, rate)

    nbests = model(speech)
    text, *_ = nbests[0]
    logger.info('predicted')
    return text

# Gradio UI: one audio upload in, one textbox out. The UI strings are Korean
# (keep this file saved as UTF-8). English meaning:
#   title:        "English speech recognition demo (espnet libri100) -- prototype"
#   description:  "Upload an English speech file to see its text content as the result."
#   input label:  "English speech"
#   output label: "Speech recognition decoding result"
iface = gr.Interface(
    predict,
    title='영어 음성인식 데모 (espnet libri100) -- 프로토타입',
    description='영어 음성 파일을 업로드하면 텍스트 내용을 결과로 보여줍니다.',
    inputs=[
        gr.inputs.Audio(label='영어 음성', source='upload', type='filepath')
    ],
    outputs=[
        gr.outputs.Textbox(label='음성 인식 디코딩결과'),
    ],
    examples=examples,
    article='<p style="text-align:center">i-Scream AI</p>',
)

if __name__ == '__main__':
    try:
        # enable_queue routes requests through Gradio's queue rather than
        # parallel threads, so requests share the single model one at a time.
        iface.launch(debug=True,
                     server_name=SERVER_NAME,
                     server_port=SERVER_PORT,
                     enable_queue=True,
                     # To serve over HTTPS, add ssl_keyfile= and ssl_certfile=
                     # pointing at the key and certificate *files* (e.g. inside
                     # SSL_DIR), not at the directory itself.
                     )
    except KeyboardInterrupt:
        # A bare KeyboardInterrupt carries no message, so log an explicit one.
        logger.info('Server stopped by user (KeyboardInterrupt)')
    finally:
        # Always release the server port, however launch() exited.
        iface.close()