Spaces:
Paused
Paused
Commit
·
5d7959c
1
Parent(s):
795e5f6
add support for phi
Browse files
main.py
CHANGED
@@ -16,8 +16,11 @@ def parse_arguments():
|
|
16 |
parser = argparse.ArgumentParser()
|
17 |
parser.add_argument('--whisper_tensorrt_path',
|
18 |
type=str,
|
19 |
-
default=
|
20 |
help='Whisper TensorRT model path')
|
|
|
|
|
|
|
21 |
parser.add_argument('--mistral_tensorrt_path',
|
22 |
type=str,
|
23 |
default=None,
|
@@ -26,6 +29,17 @@ def parse_arguments():
|
|
26 |
type=str,
|
27 |
default="teknium/OpenHermes-2.5-Mistral-7B",
|
28 |
help='Mistral TensorRT model path')
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
29 |
return parser.parse_args()
|
30 |
|
31 |
|
@@ -36,10 +50,17 @@ if __name__ == "__main__":
|
|
36 |
import sys
|
37 |
sys.exit(0)
|
38 |
|
39 |
-
if
|
40 |
-
|
41 |
-
|
42 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
43 |
|
44 |
multiprocessing.set_start_method('spawn')
|
45 |
|
@@ -70,8 +91,10 @@ if __name__ == "__main__":
|
|
70 |
llm_process = multiprocessing.Process(
|
71 |
target=llm_provider.run,
|
72 |
args=(
|
73 |
-
args.mistral_tensorrt_path,
|
74 |
-
args.mistral_tokenizer_path,
|
|
|
|
|
75 |
transcription_queue,
|
76 |
llm_queue,
|
77 |
)
|
|
|
16 |
parser = argparse.ArgumentParser()
|
17 |
parser.add_argument('--whisper_tensorrt_path',
|
18 |
type=str,
|
19 |
+
default="/root/TensorRT-LLM/examples/whisper/whisper_small_en",
|
20 |
help='Whisper TensorRT model path')
|
21 |
+
parser.add_argument('--mistral',
|
22 |
+
action="store_true",
|
23 |
+
help='Mistral')
|
24 |
parser.add_argument('--mistral_tensorrt_path',
|
25 |
type=str,
|
26 |
default=None,
|
|
|
29 |
type=str,
|
30 |
default="teknium/OpenHermes-2.5-Mistral-7B",
|
31 |
help='Mistral TensorRT model path')
|
32 |
+
parser.add_argument('--phi',
|
33 |
+
action="store_true",
|
34 |
+
help='Phi')
|
35 |
+
parser.add_argument('--phi_tensorrt_path',
|
36 |
+
type=str,
|
37 |
+
default="/root/TensorRT-LLM/examples/phi/phi_engine",
|
38 |
+
help='Phi TensorRT model path')
|
39 |
+
parser.add_argument('--phi_tokenizer_path',
|
40 |
+
type=str,
|
41 |
+
default="/root/TensorRT-LLM/examples/phi/phi-2",
|
42 |
+
help='Phi Tokenizer path')
|
43 |
return parser.parse_args()
|
44 |
|
45 |
|
|
|
50 |
import sys
|
51 |
sys.exit(0)
|
52 |
|
53 |
+
if args.mistral:
|
54 |
+
if not args.mistral_tensorrt_path or not args.mistral_tokenizer_path:
|
55 |
+
raise ValueError("Please provide mistral_tensorrt_path and mistral_tokenizer_path to run the pipeline.")
|
56 |
+
import sys
|
57 |
+
sys.exit(0)
|
58 |
+
|
59 |
+
if args.phi:
|
60 |
+
if not args.phi_tensorrt_path or not args.phi_tokenizer_path:
|
61 |
+
raise ValueError("Please provide phi_tensorrt_path and phi_tokenizer_path to run the pipeline.")
|
62 |
+
import sys
|
63 |
+
sys.exit(0)
|
64 |
|
65 |
multiprocessing.set_start_method('spawn')
|
66 |
|
|
|
91 |
llm_process = multiprocessing.Process(
|
92 |
target=llm_provider.run,
|
93 |
args=(
|
94 |
+
# args.mistral_tensorrt_path,
|
95 |
+
# args.mistral_tokenizer_path,
|
96 |
+
args.phi_tensorrt_path,
|
97 |
+
args.phi_tokenizer_path,
|
98 |
transcription_queue,
|
99 |
llm_queue,
|
100 |
)
|