#!/usr/bin/python3
# -*- coding: utf-8 -*-
from enum import Enum
from functools import lru_cache
import os
from pathlib import Path

import huggingface_hub
import sherpa
import sherpa_onnx


class EnumDecodingMethod(Enum):
    """Decoding methods supported by the offline recognizers."""
    greedy_search = "greedy_search"
    modified_beam_search = "modified_beam_search"


# Maps a language to the pretrained models available for it. Each entry records
# where the model lives on the Hugging Face Hub and which loader function below
# knows how to build a recognizer from it.
model_map = {
    "Chinese": [
        {
            "repo_id": "csukuangfj/wenet-chinese-model",
            "nn_model_file": "final.zip",
            "tokens_file": "units.txt",
            "sub_folder": ".",
            "loader": "load_sherpa_offline_recognizer",
        },
        {
            "repo_id": "csukuangfj/sherpa-onnx-paraformer-zh-2023-03-28",
            "nn_model_file": "model.int8.onnx",
            "tokens_file": "tokens.txt",
            "sub_folder": ".",
            "loader": "load_sherpa_offline_recognizer_from_paraformer",
        },
    ]
}


def download_model(repo_id: str,
                   nn_model_file: str,
                   tokens_file: str,
                   sub_folder: str,
                   local_model_dir: str,
                   ):
    """Download the model file and the tokens file from the Hugging Face Hub."""
    nn_model_file = huggingface_hub.hf_hub_download(
        repo_id=repo_id,
        filename=nn_model_file,
        subfolder=sub_folder,
        local_dir=local_model_dir,
    )
    tokens_file = huggingface_hub.hf_hub_download(
        repo_id=repo_id,
        filename=tokens_file,
        subfolder=sub_folder,
        local_dir=local_model_dir,
    )
    return nn_model_file, tokens_file


def load_sherpa_offline_recognizer(nn_model_file: str,
                                   tokens_file: str,
                                   sample_rate: int = 16000,
                                   num_active_paths: int = 2,
                                   decoding_method: str = "greedy_search",
                                   num_mel_bins: int = 80,
                                   frame_dither: int = 0,
                                   ):
    """Build a sherpa (torchscript) offline recognizer for a CPU-only setup."""
    feat_config = sherpa.FeatureConfig(normalize_samples=False)
    feat_config.fbank_opts.frame_opts.samp_freq = sample_rate
    feat_config.fbank_opts.mel_opts.num_bins = num_mel_bins
    feat_config.fbank_opts.frame_opts.dither = frame_dither

    config = sherpa.OfflineRecognizerConfig(
        nn_model=nn_model_file,
        tokens=tokens_file,
        use_gpu=False,
        feat_config=feat_config,
        decoding_method=decoding_method,
        num_active_paths=num_active_paths,
    )

    recognizer = sherpa.OfflineRecognizer(config)
    return recognizer


def load_sherpa_offline_recognizer_from_paraformer(nn_model_file: str,
                                                   tokens_file: str,
                                                   sample_rate: int = 16000,
                                                   decoding_method: str = "greedy_search",
                                                   feature_dim: int = 80,
                                                   num_threads: int = 2,
                                                   ):
    """Build a sherpa-onnx offline recognizer from a Paraformer model."""
    recognizer = sherpa_onnx.OfflineRecognizer.from_paraformer(
        paraformer=nn_model_file,
        tokens=tokens_file,
        num_threads=num_threads,
        sample_rate=sample_rate,
        feature_dim=feature_dim,
        decoding_method=decoding_method,
        debug=False,
    )
    return recognizer


def load_recognizer(repo_id: str,
                    nn_model_file: str,
                    tokens_file: str,
                    sub_folder: str,
                    local_model_dir: Path,
                    loader: str,
                    decoding_method: str = "greedy_search",
                    num_active_paths: int = 4,
                    ):
    """Download the model files if they are not cached locally, then build the
    recognizer with the loader named in the corresponding `model_map` entry."""
    if not os.path.exists(local_model_dir):
        download_model(
            repo_id=repo_id,
            nn_model_file=nn_model_file,
            tokens_file=tokens_file,
            sub_folder=sub_folder,
            local_model_dir=local_model_dir.as_posix(),
        )

    nn_model_file = (local_model_dir / nn_model_file).as_posix()
    tokens_file = (local_model_dir / tokens_file).as_posix()

    if loader == "load_sherpa_offline_recognizer":
        recognizer = load_sherpa_offline_recognizer(
            nn_model_file=nn_model_file,
            tokens_file=tokens_file,
            decoding_method=decoding_method,
            num_active_paths=num_active_paths,
        )
    elif loader == "load_sherpa_offline_recognizer_from_paraformer":
        recognizer = load_sherpa_offline_recognizer_from_paraformer(
            nn_model_file=nn_model_file,
            tokens_file=tokens_file,
            decoding_method=decoding_method,
        )
    else:
        raise NotImplementedError("unsupported loader: {}".format(loader))
    return recognizer


if __name__ == "__main__":
    pass
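

# ---------------------------------------------------------------------------
# Usage sketch (not part of the original module): load the sherpa-onnx
# Paraformer entry of `model_map` and decode one wave file with the
# sherpa-onnx offline API (create_stream / accept_waveform / decode_stream).
# Assumptions: network access to huggingface.co on the first run, a writable
# "pretrained_models" cache directory, and a 16 kHz mono 16-bit PCM file at
# the hypothetical path "test.wav". The demo is guarded by its own
# `__main__` check so importing this module stays side-effect free.
# ---------------------------------------------------------------------------
def _demo(wave_filename: str = "test.wav"):
    import wave as wave_module

    import numpy as np

    model = model_map["Chinese"][1]  # the sherpa-onnx Paraformer entry
    local_model_dir = Path("pretrained_models") / model["repo_id"].split("/")[-1]

    recognizer = load_recognizer(
        repo_id=model["repo_id"],
        nn_model_file=model["nn_model_file"],
        tokens_file=model["tokens_file"],
        sub_folder=model["sub_folder"],
        local_model_dir=local_model_dir,
        loader=model["loader"],
        decoding_method=EnumDecodingMethod.greedy_search.value,
    )

    # Read 16-bit PCM samples and scale them to float32 in [-1, 1].
    with wave_module.open(wave_filename, "rb") as f:
        sample_rate = f.getframerate()
        samples = np.frombuffer(f.readframes(f.getnframes()), dtype=np.int16)
    samples = samples.astype(np.float32) / 32768.0

    stream = recognizer.create_stream()
    stream.accept_waveform(sample_rate, samples)
    recognizer.decode_stream(stream)
    print(stream.result.text)


if __name__ == "__main__":
    _demo()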