#!/usr/bin/python3
# -*- coding: utf-8 -*-
"""Download and load sherpa / sherpa-onnx offline speech recognizers.

``model_map`` lists Hugging Face Hub repositories per language. Passing one of
its entries to ``load_recognizer`` downloads the listed files (when needed) and
builds the recognizer selected by the entry's ``loader`` field.
"""
from enum import Enum
from functools import lru_cache
import logging
import os
import platform
from pathlib import Path

import huggingface_hub
import sherpa
import sherpa_onnx

main_logger = logging.getLogger("main")


class EnumDecodingMethod(Enum):
    """Decoding-method names.

    The values match the ``decoding_method`` strings accepted by the loader
    functions below; this module itself does not reference the enum.
    """
    greedy_search = "greedy_search"
    modified_beam_search = "modified_beam_search"


# Supported models, keyed by language label. Each entry holds:
#   repo_id                          - Hugging Face Hub repository id
#   <file key> / <file key>_sub_folder - a file to download and the repo folder
#                                      it lives in; file keys are listed in
#                                      _MODEL_FILE_KEYS
#   normalize_samples (optional)     - forwarded to the loader
#   loader                           - name of the loader function to use
model_map = {
    "Chinese": [
        {
            "repo_id": "csukuangfj/wenet-chinese-model",
            "nn_model_file": "final.zip",
            "nn_model_file_sub_folder": ".",
            "tokens_file": "units.txt",
            "tokens_file_sub_folder": ".",
            "normalize_samples": False,
            "loader": "load_sherpa_offline_recognizer",
        },
        {
            "repo_id": "csukuangfj/sherpa-onnx-paraformer-zh-2024-03-09",
            "nn_model_file": "model.int8.onnx",
            "nn_model_file_sub_folder": ".",
            "tokens_file": "tokens.txt",
            "tokens_file_sub_folder": ".",
            "loader": "load_sherpa_offline_recognizer_from_paraformer",
        },
        {
            "repo_id": "csukuangfj/sherpa-onnx-paraformer-zh-small-2024-03-09",
            "nn_model_file": "model.int8.onnx",
            "nn_model_file_sub_folder": ".",
            "tokens_file": "tokens.txt",
            "tokens_file_sub_folder": ".",
            "loader": "load_sherpa_offline_recognizer_from_paraformer",
        },
        {
            "repo_id": "luomingshuang/icefall_asr_wenetspeech_pruned_transducer_stateless2",
            "nn_model_file": "cpu_jit_epoch_10_avg_2_torch_1.7.1.pt",
            "nn_model_file_sub_folder": "exp",
            "tokens_file": "tokens.txt",
            "tokens_file_sub_folder": "data/lang_char",
            "normalize_samples": True,
            "loader": "load_sherpa_offline_recognizer",
        },
        {
            "repo_id": "zrjin/sherpa-onnx-zipformer-multi-zh-hans-2023-9-2",
            "encoder_model_file": "encoder-epoch-20-avg-1.onnx",
            "encoder_model_file_sub_folder": ".",
            "decoder_model_file": "decoder-epoch-20-avg-1.onnx",
            "decoder_model_file_sub_folder": ".",
            "joiner_model_file": "joiner-epoch-20-avg-1.onnx",
            "joiner_model_file_sub_folder": ".",
            "tokens_file": "tokens.txt",
            "tokens_file_sub_folder": ".",
            "loader": "load_sherpa_offline_recognizer_from_transducer",
        },
    ],
    "English": [
        {
            "repo_id": "csukuangfj/sherpa-onnx-paraformer-en-2024-03-09",
            "nn_model_file": "model.int8.onnx",
            "nn_model_file_sub_folder": ".",
            "tokens_file": "tokens.txt",
            "tokens_file_sub_folder": ".",
            "loader": "load_sherpa_offline_recognizer_from_paraformer",
        },
    ],
    "Chinese+English": [
        {
            "repo_id": "csukuangfj/sherpa-onnx-paraformer-zh-2023-03-28",
            "nn_model_file": "model.int8.onnx",
            "nn_model_file_sub_folder": ".",
            "tokens_file": "tokens.txt",
            "tokens_file_sub_folder": ".",
            "loader": "load_sherpa_offline_recognizer_from_paraformer",
        },
    ],
    "Chinese+Cantonese+English": [
        {
            "repo_id": "csukuangfj/sherpa-onnx-paraformer-trilingual-zh-cantonese-en",
            "nn_model_file": "model.int8.onnx",
            "nn_model_file_sub_folder": ".",
            "tokens_file": "tokens.txt",
            "tokens_file_sub_folder": ".",
            "loader": "load_sherpa_offline_recognizer_from_paraformer",
        },
    ]
}

# File keys a model_map entry may carry, in download order. Each may have a
# matching "<key>_sub_folder" naming the repo folder that holds the file.
_MODEL_FILE_KEYS = (
    "nn_model_file",
    "encoder_model_file",
    "decoder_model_file",
    "joiner_model_file",
    "tokens_file",
)

# Loader names load_recognizer knows how to dispatch.
_LOADER_NAMES = (
    "load_sherpa_offline_recognizer",
    "load_sherpa_offline_recognizer_from_paraformer",
    "load_sherpa_offline_recognizer_from_transducer",
)


def download_model(local_model_dir: str, **kwargs):
    """Download every file listed in a ``model_map`` entry into ``local_model_dir``.

    :param local_model_dir: directory passed to ``huggingface_hub.hf_hub_download``
        as ``local_dir``.
    :param kwargs: one ``model_map`` entry. ``repo_id`` is required. For each file
        key in ``_MODEL_FILE_KEYS`` that is present, the file is fetched from the
        repo folder named by ``<key>_sub_folder`` (the repo root if that key is
        absent).
    :raises KeyError: if ``repo_id`` is missing.
    """
    repo_id = kwargs["repo_id"]
    for key in _MODEL_FILE_KEYS:
        if key not in kwargs:
            continue
        filename = kwargs[key]
        sub_folder = kwargs.get("{}_sub_folder".format(key))
        main_logger.info("download %s. filename: %s, subfolder: %s", key, filename, sub_folder)
        huggingface_hub.hf_hub_download(
            repo_id=repo_id,
            filename=filename,
            subfolder=sub_folder,
            local_dir=local_model_dir,
        )


def _model_file_path(local_model_dir: Path, key: str, model_info: dict) -> Path:
    """Return where ``download_model`` stores ``model_info[key]`` under ``local_model_dir``.

    ``hf_hub_download(..., local_dir=...)`` keeps the repository layout, so a file
    fetched from sub folder ``exp`` lands in ``local_model_dir / "exp"``. A sub
    folder of ``"."`` (or none) maps to ``local_model_dir`` itself, because
    pathlib collapses ``"."`` path components.
    """
    sub_folder = model_info.get("{}_sub_folder".format(key)) or "."
    return local_model_dir / sub_folder / model_info[key]


def load_sherpa_offline_recognizer(nn_model_file: str,
                                   tokens_file: str,
                                   sample_rate: int = 16000,
                                   num_active_paths: int = 2,
                                   decoding_method: str = "greedy_search",
                                   num_mel_bins: int = 80,
                                   frame_dither: float = 0,
                                   normalize_samples: bool = False,
                                   ):
    """Build a CPU ``sherpa.OfflineRecognizer`` from a single ``nn_model`` file.

    :param nn_model_file: path to the model file (``final.zip`` / ``*.pt`` entries
        in ``model_map`` use this loader).
    :param tokens_file: path to the token table.
    :param sample_rate: sampling frequency set on the fbank frame options.
    :param num_active_paths: forwarded to ``sherpa.OfflineRecognizerConfig``.
    :param decoding_method: forwarded to ``sherpa.OfflineRecognizerConfig``.
    :param num_mel_bins: number of mel bins for the fbank features.
    :param frame_dither: dither value set on the fbank frame options.
    :param normalize_samples: forwarded to ``sherpa.FeatureConfig``.
    :returns: the constructed ``sherpa.OfflineRecognizer`` (``use_gpu=False``).
    :raises AssertionError: if ``nn_model_file`` or ``tokens_file`` does not exist.
    """
    # Fail early with a readable message instead of an error from inside sherpa.
    if not os.path.exists(nn_model_file):
        raise AssertionError("nn_model_file not found. nn_model_file: {}".format(nn_model_file))
    if not os.path.exists(tokens_file):
        raise AssertionError("tokens_file not found. tokens_file: {}".format(tokens_file))

    feat_config = sherpa.FeatureConfig(normalize_samples=normalize_samples)
    feat_config.fbank_opts.frame_opts.samp_freq = sample_rate
    feat_config.fbank_opts.mel_opts.num_bins = num_mel_bins
    feat_config.fbank_opts.frame_opts.dither = frame_dither

    config = sherpa.OfflineRecognizerConfig(
        nn_model=nn_model_file,
        tokens=tokens_file,
        use_gpu=False,
        feat_config=feat_config,
        decoding_method=decoding_method,
        num_active_paths=num_active_paths,
    )
    recognizer = sherpa.OfflineRecognizer(config)
    return recognizer


def load_sherpa_offline_recognizer_from_paraformer(nn_model_file: str,
                                                   tokens_file: str,
                                                   sample_rate: int = 16000,
                                                   decoding_method: str = "greedy_search",
                                                   feature_dim: int = 80,
                                                   num_threads: int = 2,
                                                   ):
    """Build a ``sherpa_onnx.OfflineRecognizer`` from a paraformer ONNX model.

    All arguments are forwarded to ``sherpa_onnx.OfflineRecognizer.from_paraformer``
    (``nn_model_file`` as ``paraformer``, ``tokens_file`` as ``tokens``), with
    ``debug=False``.

    :returns: the constructed ``sherpa_onnx.OfflineRecognizer``.
    """
    recognizer = sherpa_onnx.OfflineRecognizer.from_paraformer(
        paraformer=nn_model_file,
        tokens=tokens_file,
        num_threads=num_threads,
        sample_rate=sample_rate,
        feature_dim=feature_dim,
        decoding_method=decoding_method,
        debug=False,
    )
    return recognizer


def load_sherpa_offline_recognizer_from_transducer(encoder_model_file: str,
                                                   decoder_model_file: str,
                                                   joiner_model_file: str,
                                                   tokens_file: str,
                                                   sample_rate: int = 16000,
                                                   decoding_method: str = "greedy_search",
                                                   feature_dim: int = 80,
                                                   num_threads: int = 2,
                                                   num_active_paths: int = 2,
                                                   ):
    """Build a ``sherpa_onnx.OfflineRecognizer`` from encoder/decoder/joiner ONNX models.

    Arguments are forwarded to ``sherpa_onnx.OfflineRecognizer.from_transducer``;
    ``num_active_paths`` is passed as ``max_active_paths``.

    :returns: the constructed ``sherpa_onnx.OfflineRecognizer``.
    """
    recognizer = sherpa_onnx.OfflineRecognizer.from_transducer(
        encoder=encoder_model_file,
        decoder=decoder_model_file,
        joiner=joiner_model_file,
        tokens=tokens_file,
        num_threads=num_threads,
        sample_rate=sample_rate,
        feature_dim=feature_dim,
        decoding_method=decoding_method,
        max_active_paths=num_active_paths,
    )
    return recognizer


def load_recognizer(local_model_dir: Path,
                    decoding_method: str = "greedy_search",
                    num_active_paths: int = 4,
                    **kwargs
                    ):
    """Download (if needed) and load the recognizer described by a ``model_map`` entry.

    :param local_model_dir: directory holding the model files; a ``str`` is
        accepted too and converted to ``Path``.
    :param decoding_method: forwarded to the selected loader.
    :param num_active_paths: forwarded to the sherpa and transducer loaders
        (the paraformer loader does not take it).
    :param kwargs: one ``model_map`` entry; ``loader`` selects the loader function.
    :returns: the recognizer built by the selected loader.
    :raises KeyError: if ``loader`` is missing.
    :raises NotImplementedError: if ``loader`` is not a supported loader name.
    """
    local_model_dir = Path(local_model_dir)
    loader = kwargs["loader"]
    # Validate before downloading so an unsupported entry costs no network traffic.
    if loader not in _LOADER_NAMES:
        raise NotImplementedError("loader not support: {}".format(loader))

    # Resolve each file to the location download_model writes it to, including
    # its repo sub folder (e.g. "exp/..." for the icefall entry).
    file_paths = {
        key: _model_file_path(local_model_dir, key, kwargs)
        for key in _MODEL_FILE_KEYS
        if key in kwargs
    }
    # Also re-download when the directory exists but a file is missing,
    # e.g. after an interrupted download.
    if not local_model_dir.exists() or not all(path.exists() for path in file_paths.values()):
        download_model(
            local_model_dir=local_model_dir.as_posix(),
            **kwargs,
        )

    loader_kwargs = {key: path.as_posix() for key, path in file_paths.items()}
    if "normalize_samples" in kwargs:
        loader_kwargs["normalize_samples"] = kwargs["normalize_samples"]

    if loader == "load_sherpa_offline_recognizer":
        return load_sherpa_offline_recognizer(
            decoding_method=decoding_method,
            num_active_paths=num_active_paths,
            **loader_kwargs
        )
    if loader == "load_sherpa_offline_recognizer_from_paraformer":
        return load_sherpa_offline_recognizer_from_paraformer(
            decoding_method=decoding_method,
            **loader_kwargs
        )
    # Only "load_sherpa_offline_recognizer_from_transducer" remains (validated above).
    return load_sherpa_offline_recognizer_from_transducer(
        decoding_method=decoding_method,
        num_active_paths=num_active_paths,
        **loader_kwargs
    )


if __name__ == "__main__":
    pass