HoneyTian committed on
Commit
0a78294
1 Parent(s): 5353e72
examples/add_punctuation/add_punctuation.py ADDED
@@ -0,0 +1,43 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/python3
2
+ # -*- coding: utf-8 -*-
3
+ import argparse
4
+
5
+ import sherpa_onnx
6
+
7
+ from project_settings import project_path
8
+
9
+
10
def get_args():
    """Parse the command-line options of the punctuation example.

    Options:
        --model_file: path of the punctuation model file. Defaults to the
            ``model.onnx`` of the csukuangfj CT-transformer zh-en model
            located under ``project_path``.
        --text: the unpunctuated text to process. Defaults to a sample
            English sentence.

    Returns:
        The ``argparse.Namespace`` produced by ``parse_args()``.
    """
    default_model_file = project_path / "pretrained_models/huggingface/csukuangfj/sherpa-onnx-punct-ct-transformer-zh-en-vocab272727-2024-04-12/model.onnx"
    default_text = "i'm a google virtual assistant recording this call for the person you're trying to reach before i try to connect you can ask what you're calling about"

    arg_parser = argparse.ArgumentParser()
    arg_parser.add_argument("--model_file", default=default_model_file.as_posix(), type=str)
    arg_parser.add_argument("--text", default=default_text, type=str)
    return arg_parser.parse_args()
24
+
25
+
26
def main():
    """Punctuate the ``--text`` argument with a sherpa-onnx model and print it.

    The model file given by ``--model_file`` is passed as the
    ``ct_transformer`` field of the punctuation model config.
    """
    cli_args = get_args()

    # Build the configuration inside-out: model config first, then the
    # punctuation config that wraps it.
    model_config = sherpa_onnx.OfflinePunctuationModelConfig(
        ct_transformer=cli_args.model_file,
    )
    punctuation_config = sherpa_onnx.OfflinePunctuationConfig(model=model_config)
    punctuator = sherpa_onnx.OfflinePunctuation(punctuation_config)

    punctuated = punctuator.add_punctuation(cli_args.text)
    print("text: {}".format(punctuated))
40
+
41
+
42
+ if __name__ == '__main__':
43
+ main()
examples/add_punctuation/download_model.py ADDED
@@ -0,0 +1,58 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/python3
2
+ # -*- coding: utf-8 -*-
3
+ import argparse
4
+ import os
5
+ from pathlib import Path
6
+ import sys
7
+
8
+ pwd = os.path.abspath(os.path.dirname(__file__))
9
+ sys.path.append(os.path.join(pwd, "../../"))
10
+
11
+ import huggingface_hub
12
+
13
+ from project_settings import project_path
14
+
15
+
16
def get_args():
    """Parse the command-line options of the model download script.

    Options (all strings):
        --repo_id: Hugging Face repository to download from.
        --model_filename: name of the file to fetch (default ``model.onnx``).
        --model_sub_folder: sub folder passed to the download call
            (default ``.``).
        --pretrained_model_dir: local root directory for downloaded models
            (default ``project_path / "pretrained_models"``).

    Returns:
        The ``argparse.Namespace`` produced by ``parse_args()``.
    """
    # (flag, default) pairs; every option is a plain string.
    string_options = (
        ("--repo_id", "csukuangfj/sherpa-onnx-punct-ct-transformer-zh-en-vocab272727-2024-04-12"),
        ("--model_filename", "model.onnx"),
        ("--model_sub_folder", "."),
        ("--pretrained_model_dir", (project_path / "pretrained_models").as_posix()),
    )

    arg_parser = argparse.ArgumentParser()
    for flag, default_value in string_options:
        arg_parser.add_argument(flag, default=default_value, type=str)
    return arg_parser.parse_args()
34
+
35
+
36
def main():
    """Download one model file from the Hugging Face Hub and print its path.

    The file is requested with ``local_dir`` set to
    ``<pretrained_model_dir>/huggingface/<repo_id>``; the value returned by
    ``hf_hub_download`` is printed.
    """
    args = get_args()

    # One mkdir(parents=True) creates the whole directory chain, including
    # ``pretrained_model_dir`` itself. The former separate
    # ``pretrained_model_dir.mkdir(exist_ok=True)`` had no ``parents=True``
    # and raised FileNotFoundError when a custom --pretrained_model_dir
    # pointed below a directory that did not exist yet.
    local_model_dir = Path(args.pretrained_model_dir) / "huggingface" / args.repo_id
    local_model_dir.mkdir(parents=True, exist_ok=True)

    print("download model")
    model_filename = huggingface_hub.hf_hub_download(
        repo_id=args.repo_id,
        filename=args.model_filename,
        subfolder=args.model_sub_folder,
        local_dir=local_model_dir.as_posix(),
    )
    print(model_filename)
55
+
56
+
57
+ if __name__ == "__main__":
58
+ main()
examples/gradio_client/{predict.py → asr.py} RENAMED
File without changes
examples/gradio_client/whisper_large_v3.py ADDED
@@ -0,0 +1,37 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/python3
2
+ # -*- coding: utf-8 -*-
3
+ import argparse
4
+
5
+ from gradio_client import Client, file
6
+
7
+ from project_settings import project_path
8
+
9
+
10
def get_args():
    """Parse the command-line options of the whisper-large-v3 client example.

    Options:
        --filename: path of the audio file to send. Defaults to the
            ``si_chuan_hua.wav`` test wav located under ``project_path``.

    Returns:
        The ``argparse.Namespace`` produced by ``parse_args()``.
    """
    default_wav = project_path / "data/test_wavs/paraformer-zh/si_chuan_hua.wav"

    arg_parser = argparse.ArgumentParser()
    arg_parser.add_argument("--filename", default=default_wav.as_posix(), type=str)
    return arg_parser.parse_args()
19
+
20
+
21
def main():
    """Send ``--filename`` to the ``hf-audio/whisper-large-v3`` Space and print the result.

    The request goes to the ``/predict`` endpoint with ``task="transcribe"``.
    """
    cli_args = get_args()

    whisper_client = Client("hf-audio/whisper-large-v3")
    response = whisper_client.predict(
        inputs=file(cli_args.filename),
        task="transcribe",
        api_name="/predict",
    )
    print(response)
34
+
35
+
36
+ if __name__ == '__main__':
37
+ main()
examples/wenet/toolbox_download.py DELETED
File without changes
main.py CHANGED
@@ -148,10 +148,20 @@ def process(
148
  filename=out_filename.as_posix(),
149
  )
150
 
 
 
 
 
 
 
 
 
 
 
 
151
  date_time = now.strftime("%Y-%m-%d %H:%M:%S.%f")
152
  end = time.time()
153
 
154
- # statistics
155
  metadata = torchaudio.info(out_filename.as_posix())
156
  duration = metadata.num_frames / 16000
157
  rtf = (end - start) / duration
 
148
  filename=out_filename.as_posix(),
149
  )
150
 
151
+ # load_punctuation_model
152
+ if add_punctuation == "Yes":
153
+ local_model_dir = pretrained_model_dir / "huggingface" / md5_encrypt("csukuangfj/sherpa-onnx-punct-ct-transformer-zh-en-vocab272727-2024-04-12")
154
+ punctuation_model = nn_models.load_punctuation_model(
155
+ local_model_dir=local_model_dir,
156
+ nn_model_file="model.onnx",
157
+ nn_model_file_sub_folder=".",
158
+ )
159
+ text = punctuation_model.add_punctuation(text)
160
+
161
+ # statistics
162
  date_time = now.strftime("%Y-%m-%d %H:%M:%S.%f")
163
  end = time.time()
164
 
 
165
  metadata = torchaudio.info(out_filename.as_posix())
166
  duration = metadata.num_frames / 16000
167
  rtf = (end - start) / duration
toolbox/k2_sherpa/nn_models.py CHANGED
@@ -764,7 +764,7 @@ def load_sherpa_onnx_online_recognizer_from_paraformer(encoder_model_file: str,
764
  def load_recognizer(local_model_dir: Path,
765
  decoding_method: str = "greedy_search",
766
  num_active_paths: int = 4,
767
- **kwargs
768
  ):
769
  if not local_model_dir.exists():
770
  download_model(
@@ -839,5 +839,29 @@ def load_recognizer(local_model_dir: Path,
839
  return recognizer
840
 
841
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
842
  if __name__ == "__main__":
843
  pass
 
764
  def load_recognizer(local_model_dir: Path,
765
  decoding_method: str = "greedy_search",
766
  num_active_paths: int = 4,
767
+ **kwargs,
768
  ):
769
  if not local_model_dir.exists():
770
  download_model(
 
839
  return recognizer
840
 
841
 
842
def load_punctuation_model(local_model_dir: Path,
                           nn_model_file: str,
                           nn_model_file_sub_folder: str,
                           ):
    """Create a sherpa-onnx offline punctuation model from a local directory.

    Args:
        local_model_dir: directory expected to hold the model. When it does
            not exist, ``download_model`` is called before loading.
        nn_model_file: file name of the model inside the directory.
        nn_model_file_sub_folder: sub folder of ``local_model_dir`` that
            contains ``nn_model_file``.

    Returns:
        A ``sherpa_onnx.OfflinePunctuation`` instance configured with the
        model file as its ``ct_transformer``.
    """
    # Only the directory's existence is checked, not the model file itself.
    # NOTE(review): no repo id is forwarded to download_model here; confirm
    # that download_model can resolve the source repository without it.
    if not local_model_dir.exists():
        download_model(
            local_model_dir=local_model_dir.as_posix(),
            nn_model_file=nn_model_file,
            nn_model_file_sub_folder=nn_model_file_sub_folder,
        )

    model_path = local_model_dir / nn_model_file_sub_folder / nn_model_file

    model_config = sherpa_onnx.OfflinePunctuationModelConfig(
        ct_transformer=model_path.as_posix(),
    )
    punctuation_config = sherpa_onnx.OfflinePunctuationConfig(model=model_config)
    return sherpa_onnx.OfflinePunctuation(punctuation_config)
864
+
865
+
866
  if __name__ == "__main__":
867
  pass