#!/usr/bin/python3
# -*- coding: utf-8 -*-
"""Download speech-recognition models from the Hugging Face Hub and build sherpa recognizers."""
from enum import Enum
from functools import lru_cache
import os

import huggingface_hub
import sherpa


class EnumDecodingMethod(Enum):
    """Decoding method names accepted by the recognizer loaders."""
    greedy_search = "greedy_search"
    modified_beam_search = "modified_beam_search"


class EnumRecognizerType(Enum):
    """Recognizer backends, identified by their dotted class name."""
    sherpa_offline_recognizer = "sherpa.OfflineRecognizer"
    sherpa_online_recognizer = "sherpa.OnlineRecognizer"
    sherpa_onnx_offline_recognizer = "sherpa_onnx.OfflineRecognizer"
    sherpa_onnx_online_recognizer = "sherpa_onnx.OnlineRecognizer"


# Language name -> list of model descriptions. Each description's keys match the
# keyword arguments of `load_recognizer` (except `local_model_dir`).
model_map = {
    "Chinese": [
        {
            "repo_id": "csukuangfj/wenet-chinese-model",
            "nn_model_file": "final.zip",
            "tokens_file": "units.txt",
            "sub_folder": ".",
            "recognizer_type": EnumRecognizerType.sherpa_offline_recognizer.value,
        }
    ]
}


def download_model(repo_id: str,
                   nn_model_file: str,
                   tokens_file: str,
                   sub_folder: str,
                   local_model_dir: str,
                   ):
    """Download the network and tokens files of a Hub repo into `local_model_dir`.

    Args:
        repo_id: Hugging Face Hub repository id.
        nn_model_file: File name of the neural-network model inside `sub_folder`.
        tokens_file: File name of the tokens/units table inside `sub_folder`.
        sub_folder: Sub-folder of the repo that contains both files.
        local_model_dir: Local directory the files are downloaded into.

    Returns:
        Tuple ``(nn_model_path, tokens_path)`` of local paths as returned by
        `huggingface_hub.hf_hub_download`.
    """
    nn_model_file = huggingface_hub.hf_hub_download(
        repo_id=repo_id,
        filename=nn_model_file,
        subfolder=sub_folder,
        local_dir=local_model_dir,
    )

    tokens_file = huggingface_hub.hf_hub_download(
        repo_id=repo_id,
        filename=tokens_file,
        subfolder=sub_folder,
        local_dir=local_model_dir,
    )
    return nn_model_file, tokens_file


def load_sherpa_offline_recognizer(nn_model_file: str,
                                   tokens_file: str,
                                   sample_rate: int = 16000,
                                   num_active_paths: int = 2,
                                   decoding_method: str = "greedy_search",
                                   num_mel_bins: int = 80,
                                   frame_dither: float = 0.0,
                                   ):
    """Build a CPU `sherpa.OfflineRecognizer` from local model files.

    Args:
        nn_model_file: Path to the neural-network model file.
        tokens_file: Path to the tokens/units file.
        sample_rate: Sampling frequency written into the fbank frame options.
        num_active_paths: Passed to the recognizer config; presumably only
            relevant for beam-search decoding — confirm against sherpa docs.
        decoding_method: Decoding method name (see `EnumDecodingMethod`).
        num_mel_bins: Number of mel bins of the fbank features.
        frame_dither: Dither value for feature extraction (0 disables dithering).

    Returns:
        The constructed `sherpa.OfflineRecognizer`.
    """
    # NOTE(review): normalize_samples=False presumably means input samples are
    # expected in the int16 range rather than [-1, 1] — confirm with sherpa docs.
    feat_config = sherpa.FeatureConfig(normalize_samples=False)
    feat_config.fbank_opts.frame_opts.samp_freq = sample_rate
    feat_config.fbank_opts.mel_opts.num_bins = num_mel_bins
    feat_config.fbank_opts.frame_opts.dither = frame_dither

    config = sherpa.OfflineRecognizerConfig(
        nn_model=nn_model_file,
        tokens=tokens_file,
        use_gpu=False,
        feat_config=feat_config,
        decoding_method=decoding_method,
        num_active_paths=num_active_paths,
    )

    recognizer = sherpa.OfflineRecognizer(config)
    return recognizer


def load_recognizer(repo_id: str,
                    nn_model_file: str,
                    tokens_file: str,
                    sub_folder: str,
                    local_model_dir: str,
                    recognizer_type: str,
                    decoding_method: str = "greedy_search",
                    num_active_paths: int = 4,
                    ):
    """Ensure the model files are available locally and build a recognizer.

    The files are expected at ``local_model_dir/sub_folder/<file>``. They are
    downloaded from the Hub only when one of them is missing.

    Args:
        repo_id: Hugging Face Hub repository id.
        nn_model_file: File name of the neural-network model inside `sub_folder`.
        tokens_file: File name of the tokens/units table inside `sub_folder`.
        sub_folder: Sub-folder of the repo that contains both files.
        local_model_dir: Local directory holding (or receiving) the files.
        recognizer_type: One of the `EnumRecognizerType` values; only
            ``"sherpa.OfflineRecognizer"`` is implemented.
        decoding_method: Decoding method name (see `EnumDecodingMethod`).
        num_active_paths: Forwarded to the recognizer config.

    Returns:
        The constructed recognizer.

    Raises:
        NotImplementedError: If `recognizer_type` is not supported.
    """
    # Reject unsupported backends before spending time on a model download.
    if recognizer_type != EnumRecognizerType.sherpa_offline_recognizer.value:
        raise NotImplementedError("recognizer_type not support: {}".format(recognizer_type))

    # Resolve the files relative to the download directory. Passing the bare
    # file names would make them resolve against the current working directory.
    nn_model_path = os.path.join(local_model_dir, sub_folder, nn_model_file)
    tokens_path = os.path.join(local_model_dir, sub_folder, tokens_file)

    # Check each file rather than only the directory, so that a partially
    # populated directory still gets its missing file downloaded.
    if not (os.path.isfile(nn_model_path) and os.path.isfile(tokens_path)):
        nn_model_path, tokens_path = download_model(
            repo_id=repo_id,
            nn_model_file=nn_model_file,
            tokens_file=tokens_file,
            sub_folder=sub_folder,
            local_model_dir=local_model_dir,
        )

    recognizer = load_sherpa_offline_recognizer(
        nn_model_file=nn_model_path,
        tokens_file=tokens_path,
        decoding_method=decoding_method,
        num_active_paths=num_active_paths,
    )
    return recognizer


if __name__ == "__main__":
    pass