import multiprocessing
import argparse
import threading
import ssl
import time
import sys
import functools
from multiprocessing import Process, Manager, Value, Queue


def parse_arguments(argv=None):
    """Parse command-line options for the ASR -> LLM -> TTS pipeline.

    Args:
        argv: Optional list of argument strings (defaults to ``sys.argv[1:]``).
              Exposed so the parser can be exercised programmatically.

    Returns:
        argparse.Namespace with the parsed options.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--whisper_tensorrt_path',
                        type=str,
                        default="/root/TensorRT-LLM/examples/whisper/whisper_small_en",
                        help='Whisper TensorRT model path')
    parser.add_argument('--mistral',
                        action="store_true",
                        help='Mistral')
    parser.add_argument('--mistral_tensorrt_path',
                        type=str,
                        default=None,
                        help='Mistral TensorRT model path')
    parser.add_argument('--mistral_tokenizer_path',
                        type=str,
                        default="teknium/OpenHermes-2.5-Mistral-7B",
                        # was mislabeled "Mistral TensorRT model path" (copy-paste)
                        help='Mistral tokenizer path')
    parser.add_argument('--phi',
                        action="store_true",
                        help='Phi')
    parser.add_argument('--phi_tensorrt_path',
                        type=str,
                        default="/root/TensorRT-LLM/examples/phi/phi_engine",
                        help='Phi TensorRT model path')
    parser.add_argument('--phi_tokenizer_path',
                        type=str,
                        default="/root/TensorRT-LLM/examples/phi/phi-2",
                        help='Phi Tokenizer path')
    return parser.parse_args(argv)


def validate_args(args):
    """Fail fast with a clear error when required model paths are missing.

    Raises:
        ValueError: if a selected backend is missing its engine/tokenizer path.

    Note: the original code followed each ``raise`` with an unreachable
    ``sys.exit(0)``; that dead code is removed here.
    """
    if not args.whisper_tensorrt_path:
        raise ValueError("Please provide whisper_tensorrt_path to run the pipeline.")
    if args.mistral and (not args.mistral_tensorrt_path or not args.mistral_tokenizer_path):
        raise ValueError("Please provide mistral_tensorrt_path and mistral_tokenizer_path to run the pipeline.")
    if args.phi and (not args.phi_tensorrt_path or not args.phi_tokenizer_path):
        raise ValueError("Please provide phi_tensorrt_path and phi_tokenizer_path to run the pipeline.")


def main():
    """Spawn the whisper-ASR, LLM, and TTS worker processes and wait on them."""
    # Heavy project imports are deferred so importing this module stays cheap
    # and does not require the TensorRT runtime to be installed.
    from whisper_live.trt_server import TranscriptionServer
    from llm_service import TensorRTLLMEngine
    from tts_service import WhisperSpeechTTS

    args = parse_arguments()
    validate_args(args)

    # CUDA contexts do not survive fork; 'spawn' gives each worker a clean start.
    multiprocessing.set_start_method('spawn')

    # Pipeline plumbing: ASR -> LLM via transcription_queue/llm_queue,
    # LLM -> TTS via audio_queue.
    transcription_queue = Queue()
    llm_queue = Queue()
    audio_queue = Queue()

    whisper_server = TranscriptionServer()
    whisper_process = multiprocessing.Process(
        target=whisper_server.run,
        args=(
            "0.0.0.0",
            6006,
            transcription_queue,
            llm_queue,
            args.whisper_tensorrt_path,
        ),
    )
    whisper_process.start()

    # NOTE(review): only the Phi paths are wired into the engine here even when
    # --mistral is passed — confirm whether TensorRTLLMEngine should switch on
    # args.mistral.
    llm_provider = TensorRTLLMEngine()
    llm_process = multiprocessing.Process(
        target=llm_provider.run,
        args=(
            args.phi_tensorrt_path,
            args.phi_tokenizer_path,
            transcription_queue,
            llm_queue,
            audio_queue,
        ),
    )
    llm_process.start()

    tts_runner = WhisperSpeechTTS()
    tts_process = multiprocessing.Process(
        target=tts_runner.run,
        args=("0.0.0.0", 8888, audio_queue),
    )
    tts_process.start()

    llm_process.join()
    whisper_process.join()
    tts_process.join()


if __name__ == "__main__":
    main()