|
|
|
|
|
import argparse |
|
from collections import defaultdict |
|
from datetime import datetime |
|
import functools |
|
import io |
|
import logging |
|
from pathlib import Path |
|
import platform |
|
import time |
|
import tempfile |
|
|
|
from project_settings import project_path, log_directory |
|
import log |
|
|
|
# Configure logging through the project's ``log`` module before the
# third-party imports below.
# NOTE(review): presumably done this early so handlers are installed before
# gradio/torch start emitting log records — confirm.
log.setup(log_directory=log_directory)
|
|
|
import gradio as gr |
|
import torch |
|
import torchaudio |
|
|
|
from toolbox.k2_sherpa.examples import examples |
|
from toolbox.k2_sherpa import decode, nn_models |
|
from toolbox.k2_sherpa.utils import audio_convert |
|
|
|
# Logger named "main"; used by process() and process_uploaded_file().
main_logger = logging.getLogger("main")
|
|
|
|
|
def get_args():
    """Parse the command-line options of this script.

    Returns:
        The ``argparse.Namespace`` produced by ``parse_args()``. It carries a
        single string attribute, ``pretrained_model_dir``, which defaults to
        the POSIX form of ``project_path / "pretrained_models"``.
    """
    default_model_dir = (project_path / "pretrained_models").as_posix()

    arg_parser = argparse.ArgumentParser()
    arg_parser.add_argument(
        "--pretrained_model_dir",
        type=str,
        default=default_model_dir,
    )
    return arg_parser.parse_args()
|
|
|
|
|
def update_model_dropdown(language: str):
    """Build the model dropdown for the given language.

    Args:
        language: a key of ``nn_models.model_map``.

    Returns:
        An interactive ``gr.Dropdown`` whose choices are the ``repo_id`` of
        every entry registered for ``language``; the first one is selected.

    Raises:
        ValueError: if ``language`` is not a key of ``nn_models.model_map``.
    """
    model_map = nn_models.model_map
    if language not in model_map:
        raise ValueError(f"Unsupported language: {language}")

    repo_ids = [entry["repo_id"] for entry in model_map[language]]
    return gr.Dropdown(choices=repo_ids, value=repo_ids[0], interactive=True)
|
|
|
|
|
def build_html_output(s: str, style: str = "result_item_success"):
    """Wrap ``s`` in the nested ``<div>`` markup used for result messages.

    Args:
        s: content placed inside the inner ``<div>``. It is inserted
            verbatim (no HTML escaping), so it may itself contain markup.
        style: CSS class added next to ``result_item`` on the inner ``<div>``.

    Returns:
        The HTML snippet as a string.
    """
    html_snippet = f"""
    <div class='result'>
        <div class='result_item {style}'>
          {s}
        </div>
    </div>
    """
    return html_snippet
|
|
|
|
|
@torch.no_grad()
def process(
    language: str,
    repo_id: str,
    decoding_method: str,
    num_active_paths: int,
    add_punctuation: str,
    in_filename: str,
    pretrained_model_dir: Path,
):
    """Transcribe one audio file and report timing statistics.

    The input is first converted with ``audio_convert`` into a file of the
    same name under ``<system temp dir>/asr``; that converted file is what
    gets decoded and measured.

    Args:
        language: key into ``nn_models.model_map``.
        repo_id: ``repo_id`` of one of the entries listed for ``language``.
        decoding_method: forwarded to ``nn_models.load_recognizer``.
        num_active_paths: forwarded to ``nn_models.load_recognizer``.
        add_punctuation: accepted for interface compatibility; this function
            does not use it.
        in_filename: path of the audio file to transcribe.
        pretrained_model_dir: root directory of local models; the model files
            are looked up under ``<pretrained_model_dir>/huggingface/<repo_id>``.

    Returns:
        A ``(text, html)`` tuple: the recognized text and an HTML snippet
        (built with ``build_html_output``) showing the wave duration, the
        processing time and the real-time factor.

    Raises:
        AssertionError: if ``language`` is not in ``nn_models.model_map`` or
            ``repo_id`` is not listed for that language.
    """
    main_logger.info("language: {}".format(language))
    main_logger.info("repo_id: {}".format(repo_id))
    main_logger.info("decoding_method: {}".format(decoding_method))
    main_logger.info("num_active_paths: {}".format(num_active_paths))
    main_logger.info("in_filename: {}".format(in_filename))

    # Convert the input into <tmp>/asr/<original file name>.
    in_filename = Path(in_filename)
    out_filename = Path(tempfile.gettempdir()) / "asr" / in_filename.name
    out_filename.parent.mkdir(parents=True, exist_ok=True)

    audio_convert(
        in_filename=in_filename.as_posix(),
        out_filename=out_filename.as_posix(),
    )

    # Find the model description for (language, repo_id).
    m_list = nn_models.model_map.get(language)
    if m_list is None:
        raise AssertionError("language invalid: {}".format(language))

    m_dict = None
    for m in m_list:
        if m["repo_id"] == repo_id:
            m_dict = m
    if m_dict is None:
        raise AssertionError("repo_id invalid: {}".format(repo_id))

    # Resolve the model files inside the local model directory.
    local_model_dir = pretrained_model_dir / "huggingface" / repo_id
    nn_model_file = local_model_dir / m_dict["nn_model_file"]
    tokens_file = local_model_dir / m_dict["tokens_file"]

    recognizer = nn_models.load_recognizer(
        repo_id=m_dict["repo_id"],
        nn_model_file=nn_model_file.as_posix(),
        tokens_file=tokens_file.as_posix(),
        sub_folder=m_dict["sub_folder"],
        local_model_dir=local_model_dir,
        recognizer_type=m_dict["recognizer_type"],
        decoding_method=decoding_method,
        num_active_paths=num_active_paths,
    )

    # Time only the decoding step.
    timestamp_format = "%Y-%m-%d %H:%M:%S.%f"
    started_at = datetime.now().strftime(timestamp_format)
    main_logger.info(f"Started at {started_at}")
    start = time.time()

    text = decode.decode_by_recognizer(
        recognizer=recognizer,
        filename=out_filename.as_posix(),
    )

    end = time.time()
    elapsed = end - start
    # Take a fresh timestamp here: re-formatting the one captured before
    # decoding would report the start time as the finish time.
    finished_at = datetime.now().strftime(timestamp_format)

    metadata = torchaudio.info(out_filename.as_posix())
    # Use the sample rate stored in the file rather than assuming 16 kHz.
    duration = metadata.num_frames / metadata.sample_rate
    # A zero-length file must not turn a finished transcription into an error.
    rtf = elapsed / duration if duration > 0 else float("nan")

    main_logger.info(f"Finished at {finished_at}. Elapsed: {elapsed: .3f} s")

    info = f"""
    Wave duration : {duration: .3f} s <br/>
    Processing time: {elapsed: .3f} s <br/>
    RTF: {elapsed: .3f}/{duration: .3f} = {rtf:.3f} <br/>
    """

    main_logger.info(info)
    main_logger.info(f"\nrepo_id: {repo_id}\nhyp: {text}")

    return text, build_html_output(info)
|
|
|
|
|
def process_uploaded_file(
    language: str,
    repo_id: str,
    decoding_method: str,
    num_active_paths: int,
    add_punctuation: str,
    in_filename: str,
    pretrained_model_dir: Path,
):
    """Run ``process`` on an uploaded file and turn failures into UI messages.

    Args:
        language: forwarded to ``process``.
        repo_id: forwarded to ``process``.
        decoding_method: forwarded to ``process``.
        num_active_paths: forwarded to ``process``.
        add_punctuation: forwarded to ``process``.
        in_filename: path of the uploaded audio file; ``None`` or ``""`` means
            nothing was uploaded.
        pretrained_model_dir: forwarded to ``process``.

    Returns:
        A ``(text, html)`` tuple. On success it is the result of ``process``.
        When no file was uploaded, or when ``process`` raises, ``text`` is
        ``""`` and ``html`` is an error box built with ``build_html_output``.
    """
    # Function-scope import so this block does not depend on the file header.
    from html import escape

    if not in_filename:
        return "", build_html_output(
            "Please first upload a file and then click "
            'the button "submit for recognition"',
            "result_item_error",
        )
    main_logger.info(f"Processing uploaded file: {in_filename}")

    try:
        return process(
            in_filename=in_filename,
            language=language,
            repo_id=repo_id,
            decoding_method=decoding_method,
            num_active_paths=num_active_paths,
            add_punctuation=add_punctuation,
            pretrained_model_dir=pretrained_model_dir,
        )
    except Exception as e:
        # This is the UI boundary: report every failure to the user instead
        # of letting it propagate into the event handler.
        msg = "transcribe error: {}".format(str(e))
        # exception() logs at ERROR level and keeps the traceback.
        main_logger.exception(msg)
        # The message is rendered as HTML; escape it so that characters such
        # as "<" in paths or library messages cannot be read as markup.
        return "", build_html_output(escape(msg), "result_item_error")
|
|
|
|
|
|
|
|
|
# Stylesheet handed to gr.Blocks in main(). The class names are the ones
# emitted by build_html_output: "result", "result_item" and the two status
# variants "result_item_success" / "result_item_error".
css = """
.result {display:flex;flex-direction:column}
.result_item {padding:15px;margin-bottom:8px;border-radius:15px;width:100%}
.result_item_success {background-color:mediumaquamarine;color:white;align-self:start}
.result_item_error {background-color:#ff7070;color:white;align-self:start}
"""
|
|
|
|
|
def main():
    """Build the Gradio demo and serve it on port 7860.

    Reads ``--pretrained_model_dir`` from the command line, creates that
    directory if needed, builds a single "Upload from disk" tab (language,
    model, decoding options, audio upload, examples) and launches the app.
    The server binds to 127.0.0.1 on Windows and to 0.0.0.0 elsewhere.
    """
    args = get_args()

    pretrained_model_dir = Path(args.pretrained_model_dir)
    # parents=True: a nested --pretrained_model_dir must not fail just
    # because an intermediate directory does not exist yet.
    pretrained_model_dir.mkdir(parents=True, exist_ok=True)

    # The UI only supplies the widget values; bind the model directory here.
    process_uploaded_file_ = functools.partial(
        process_uploaded_file,
        pretrained_model_dir=pretrained_model_dir,
    )

    title = "# Automatic Speech Recognition with Next-gen Kaldi"

    language_choices = list(nn_models.model_map.keys())
    language_to_models = {
        language: [m["repo_id"] for m in models]
        for language, models in nn_models.model_map.items()
    }
    default_language = language_choices[0]
    default_models = language_to_models[default_language]

    with gr.Blocks(css=css) as blocks:
        gr.Markdown(value=title)

        with gr.Tabs():
            with gr.TabItem("Upload from disk"):
                language_radio = gr.Radio(
                    label="Language",
                    choices=language_choices,
                    value=default_language,
                )
                model_dropdown = gr.Dropdown(
                    choices=default_models,
                    label="Select a model",
                    value=default_models[0],
                )
                decoding_method_radio = gr.Radio(
                    label="Decoding method",
                    choices=["greedy_search", "modified_beam_search"],
                    value="greedy_search",
                )
                num_active_paths_slider = gr.Slider(
                    minimum=1,
                    value=4,
                    step=1,
                    label="Number of active paths for modified_beam_search",
                )
                punct_radio = gr.Radio(
                    label="Whether to add punctuation (Only for Chinese and English)",
                    choices=["Yes", "No"],
                    value="Yes",
                )

                uploaded_file = gr.Audio(
                    sources=["upload"],
                    type="filepath",
                    label="Upload from disk",
                )
                upload_button = gr.Button("Submit for recognition")
                uploaded_output = gr.Textbox(label="Recognized speech from uploaded file")
                uploaded_html_info = gr.HTML(label="Info")

                # Order matches the positional parameters of
                # process_uploaded_file.
                recognition_inputs = [
                    language_radio,
                    model_dropdown,
                    decoding_method_radio,
                    num_active_paths_slider,
                    punct_radio,
                    uploaded_file,
                ]
                recognition_outputs = [uploaded_output, uploaded_html_info]

                gr.Examples(
                    examples=examples,
                    inputs=recognition_inputs,
                    outputs=recognition_outputs,
                    fn=process_uploaded_file_,
                )

                upload_button.click(
                    process_uploaded_file_,
                    inputs=recognition_inputs,
                    outputs=recognition_outputs,
                )

                # Refresh the model list whenever another language is chosen.
                language_radio.change(
                    update_model_dropdown,
                    inputs=language_radio,
                    outputs=model_dropdown,
                )

    is_windows = platform.system() == "Windows"
    blocks.queue().launch(
        # The original conditional evaluated to False on every platform.
        share=False,
        server_name="127.0.0.1" if is_windows else "0.0.0.0",
        server_port=7860,
    )
|
|
|
|
|
# Start the demo only when this file is executed as a script.
if __name__ == "__main__":
    main()
|
|