diff --git a/Dockerfile b/Dockerfile new file mode 100644 index 0000000000000000000000000000000000000000..04298d77cd9726b8143d50277579a28eee000b25 --- /dev/null +++ b/Dockerfile @@ -0,0 +1,8 @@ +FROM ghcr.io/collabora/whisperfusion-base:latest as base + +WORKDIR /root +COPY scripts/setup-whisperfusion.sh scripts/run-whisperfusion.sh scratch-space/models /root/ +RUN ./setup-whisperfusion.sh + +CMD ./run-whisperfusion.sh + diff --git a/README (2).md b/README (2).md new file mode 100644 index 0000000000000000000000000000000000000000..d89385a5dbc570a3e6908e4fc0e5e501bff2ac7c --- /dev/null +++ b/README (2).md @@ -0,0 +1,61 @@ +# WhisperFusion + +
Using state-of-the-art natural language processing techniques, we implemented WhisperFusion, a technology demo that combines live transcription, LLMs, and text-to-speech in a low-latency pipeline. For more details about the demo, check out https://github.com/collabora/WhisperFusion.
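
The three stages are chained with `multiprocessing` queues. As a rough sketch of the wiring (condensed from `main.py` in this repo; the engine and tokenizer paths are placeholders taken from `run-whisperfusion.sh` and will differ per setup):

```python
import multiprocessing
from multiprocessing import Process, Queue

from whisper_live.trt_server import TranscriptionServer
from llm_service import TensorRTLLMEngine
from tts_service import WhisperSpeechTTS

if __name__ == "__main__":
    multiprocessing.set_start_method("spawn")  # CUDA work happens in the child processes

    transcription_queue = Queue()  # WhisperLive -> LLM (prompts)
    llm_queue = Queue()            # LLM -> transcription websocket (text)
    audio_queue = Queue()          # LLM -> WhisperSpeech (text to synthesize)

    # 1. Live transcription server (WhisperLive + TensorRT) on port 6006.
    whisper = Process(target=TranscriptionServer().run,
                      args=("0.0.0.0", 6006, transcription_queue, llm_queue,
                            "/root/whisper_small_en"))            # placeholder path
    # 2. LLM worker (Phi-2 / Dolphin TensorRT engine) fed from the transcription queue.
    llm = Process(target=TensorRTLLMEngine().run,
                  args=("/root/dolphin-2_6-phi-2",                # placeholder path
                        "/root/dolphin-2_6-phi-2/tokenizer",      # placeholder path
                        transcription_queue, llm_queue, audio_queue))
    # 3. WhisperSpeech TTS server on port 8888, streaming synthesized audio back.
    tts = Process(target=WhisperSpeechTTS().run, args=("0.0.0.0", 8888, audio_queue))

    for p in (whisper, llm, tts):
        p.start()
    for p in (whisper, llm, tts):
        p.join()
```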
+" + data["segments"][0].text + "
", "transcription-" + available_transcription_elements); + new_transcription_element_state = false; + } + document.getElementById("transcription-" + available_transcription_elements).innerHTML = "" + data["segments"][0].text + "
"; + + if (data["eos"] == true) { + new_transcription_element_state = true; + } + } else if ("llm_output" in data) { + new_transcription_element("Phi-2", "Phi.svg"); + new_text_element("" + data["llm_output"][0] + "
", "llm-" + available_transcription_elements); + } + + window.scrollTo({ top: document.body.scrollHeight, behavior: 'smooth' }); + } +} + +function new_transcription_element(speaker_name, speaker_avatar) { + var avatar_container = document.createElement("div"); + avatar_container.className = "avatar-container"; + + var avatar_img = document.createElement("div"); + avatar_img.innerHTML = ""; + + var avatar_name = document.createElement("div"); + avatar_name.className = "avatar-name"; + avatar_name.innerHTML = speaker_name; + + var dummy_element = document.createElement("div"); + + avatar_container.appendChild(avatar_img); + avatar_container.appendChild(avatar_name); + avatar_container.appendChild(dummy_element); + + document.getElementById("main-wrapper").appendChild(avatar_container); +} + +function new_text_element(text, id) { + var text_container = document.createElement("div"); + text_container.className = "text-container"; + text_container.style.maxWidth = "500px"; + + var text_element = document.createElement("div"); + text_element.id = id; + text_element.innerHTML = "" + text + "
"; + + var dummy_element = document.createElement("div"); + + text_container.appendChild(text_element); + text_container.appendChild(dummy_element); + + document.getElementById("main-wrapper").appendChild(text_container); +} + +function new_transcription_time_element(time) { + var text_container = document.createElement("div"); + text_container.className = "transcription-timing-container"; + text_container.style.maxWidth = "500px"; + + var text_element = document.createElement("div"); + text_element.innerHTML = "WhisperLive - Transcription time: " + time + "ms"; + + var dummy_element = document.createElement("div"); + + text_container.appendChild(text_element); + text_container.appendChild(dummy_element); + + document.getElementById("main-wrapper").appendChild(text_container); +} + +function new_llm_time_element(time) { + var text_container = document.createElement("div"); + text_container.className = "llm-timing-container"; + text_container.style.maxWidth = "500px"; + + var first_response_text_element = document.createElement("div"); + first_response_text_element.innerHTML = "Phi-2 first response time: " + time + "ms"; + + var complete_response_text_element = document.createElement("div"); + complete_response_text_element.innerHTML = "Phi-2 complete response time: " + time + "ms"; + + var dummy_element = document.createElement("div"); + + text_container.appendChild(first_response_text_element); + text_container.appendChild(complete_response_text_element); + text_container.appendChild(dummy_element); + + document.getElementById("main-wrapper").appendChild(text_container); +} + +function new_whisper_speech_audio_element(id, duration) { + var audio_container = document.createElement("div"); + audio_container.className = "whisperspeech-audio-container"; + audio_container.style.maxWidth = "500px"; + + var audio_div_element = document.createElement("div"); + var audio_element = document.createElement("audio"); + audio_element.style.paddingTop = "20px"; + + if (duration > 10) + duration = 10; + audio_element.src = "static/" + duration + ".mp3"; + + audio_element.id = id; + audio_element.onplay = function() { + console.log(this.id) + var id = this.id.split("-")[1] - 1; + + if (audio_source) { + audio_source.disconnect(); + } + + audio_source = audioContext_tts.createBufferSource(); + audio_source.buffer = audio_sources[id]; + audio_source.connect(audioContext_tts.destination); + audio_source.start() + }; + audio_element.onpause = function() { + this.currentTime = 0; + console.log(this.id) + var id = this.id.split("-")[1] - 1; + if (audio_source) { + audio_source.stop(); + } + }; + audio_element.controls = true; + + audio_div_element.appendChild(audio_element); + + var dummy_element_a = document.createElement("div"); + var dummy_element_b = document.createElement("div"); + + audio_container.appendChild(dummy_element_a); + audio_container.appendChild(audio_div_element); + audio_container.appendChild(dummy_element_b); + + document.getElementById("main-wrapper").appendChild(audio_container); +} + +function new_whisper_speech_time_element(time) { + var text_container = document.createElement("div"); + text_container.className = "whisperspeech-timing-container"; + text_container.style.maxWidth = "500px"; + + var text_element = document.createElement("div"); + text_element.innerHTML = "WhisperSpeech response time: " + time + "ms"; + + var dummy_element = document.createElement("div"); + + text_container.appendChild(text_element); + text_container.appendChild(dummy_element); + + 
document.getElementById("main-wrapper").appendChild(text_container); +} + +document.addEventListener('DOMContentLoaded', function() { + const queryString = window.location.search; + const urlParams = new URLSearchParams(queryString); + if (urlParams.has('name')) { + you_name = urlParams.get('name') + } + }, false); \ No newline at end of file diff --git a/examples/chatbot/html/static/0.mp3 b/examples/chatbot/html/static/0.mp3 new file mode 100644 index 0000000000000000000000000000000000000000..787be87653c143603c96d4327290dedd054b09e1 Binary files /dev/null and b/examples/chatbot/html/static/0.mp3 differ diff --git a/examples/chatbot/html/static/1.mp3 b/examples/chatbot/html/static/1.mp3 new file mode 100644 index 0000000000000000000000000000000000000000..d00f3c41dc160269d38ea6e34fdd8794e6c93008 Binary files /dev/null and b/examples/chatbot/html/static/1.mp3 differ diff --git a/examples/chatbot/html/static/10.mp3 b/examples/chatbot/html/static/10.mp3 new file mode 100644 index 0000000000000000000000000000000000000000..2249cae9a3510a0cb2229a5e733ebc24d3a4acb0 Binary files /dev/null and b/examples/chatbot/html/static/10.mp3 differ diff --git a/examples/chatbot/html/static/2.mp3 b/examples/chatbot/html/static/2.mp3 new file mode 100644 index 0000000000000000000000000000000000000000..0766cc95eb0ed79639f0cea68d765c1f1d196586 Binary files /dev/null and b/examples/chatbot/html/static/2.mp3 differ diff --git a/examples/chatbot/html/static/3.mp3 b/examples/chatbot/html/static/3.mp3 new file mode 100644 index 0000000000000000000000000000000000000000..2060864853559ed9f90ae278a6a1c8ab9300b5c9 Binary files /dev/null and b/examples/chatbot/html/static/3.mp3 differ diff --git a/examples/chatbot/html/static/4.mp3 b/examples/chatbot/html/static/4.mp3 new file mode 100644 index 0000000000000000000000000000000000000000..00c976f0b64804e597f07a015ace360b8073a371 Binary files /dev/null and b/examples/chatbot/html/static/4.mp3 differ diff --git a/examples/chatbot/html/static/5.mp3 b/examples/chatbot/html/static/5.mp3 new file mode 100644 index 0000000000000000000000000000000000000000..cad47f7684b09997ac3096857c4b187634865551 Binary files /dev/null and b/examples/chatbot/html/static/5.mp3 differ diff --git a/examples/chatbot/html/static/6.mp3 b/examples/chatbot/html/static/6.mp3 new file mode 100644 index 0000000000000000000000000000000000000000..43d837c2f6b2850d176f19aa3bb78acb694ccfdd Binary files /dev/null and b/examples/chatbot/html/static/6.mp3 differ diff --git a/examples/chatbot/html/static/7.mp3 b/examples/chatbot/html/static/7.mp3 new file mode 100644 index 0000000000000000000000000000000000000000..e47e161af954f24813333fd81ec54d4c8286396a Binary files /dev/null and b/examples/chatbot/html/static/7.mp3 differ diff --git a/examples/chatbot/html/static/8.mp3 b/examples/chatbot/html/static/8.mp3 new file mode 100644 index 0000000000000000000000000000000000000000..38fb7b8a00f0c209d798056e8bddc8b0945043b1 Binary files /dev/null and b/examples/chatbot/html/static/8.mp3 differ diff --git a/examples/chatbot/html/static/9.mp3 b/examples/chatbot/html/static/9.mp3 new file mode 100644 index 0000000000000000000000000000000000000000..2249cae9a3510a0cb2229a5e733ebc24d3a4acb0 Binary files /dev/null and b/examples/chatbot/html/static/9.mp3 differ diff --git a/llm_service.py b/llm_service.py new file mode 100644 index 0000000000000000000000000000000000000000..3db1fa317261fae82c6144239f3f7048a3477ff9 --- /dev/null +++ b/llm_service.py @@ -0,0 +1,354 @@ +import time +import json +from pathlib import Path +from typing import 
Optional + +import logging +logging.basicConfig(level = logging.INFO) + +import numpy as np +import torch +from transformers import AutoTokenizer +import re + +import tensorrt_llm +from tensorrt_llm.logger import logger +from tensorrt_llm.runtime import PYTHON_BINDINGS, ModelRunner + +if PYTHON_BINDINGS: + from tensorrt_llm.runtime import ModelRunnerCpp + + +def read_model_name(engine_dir: str): + engine_version = tensorrt_llm.runtime.engine.get_engine_version(engine_dir) + + with open(Path(engine_dir) / "config.json", 'r') as f: + config = json.load(f) + + if engine_version is None: + return config['builder_config']['name'] + + return config['pretrained_config']['architecture'] + + +def throttle_generator(generator, stream_interval): + for i, out in enumerate(generator): + if not i % stream_interval: + yield out + + if i % stream_interval: + yield out + + +def load_tokenizer(tokenizer_dir: Optional[str] = None, + vocab_file: Optional[str] = None, + model_name: str = 'gpt', + tokenizer_type: Optional[str] = None): + if vocab_file is None: + use_fast = True + if tokenizer_type is not None and tokenizer_type == "llama": + use_fast = False + # Should set both padding_side and truncation_side to be 'left' + tokenizer = AutoTokenizer.from_pretrained(tokenizer_dir, + legacy=False, + padding_side='left', + truncation_side='left', + trust_remote_code=True, + tokenizer_type=tokenizer_type, + use_fast=use_fast) + else: + # For gpt-next, directly load from tokenizer.model + assert model_name == 'gpt' + tokenizer = T5Tokenizer(vocab_file=vocab_file, + padding_side='left', + truncation_side='left') + + if model_name == 'qwen': + with open(Path(tokenizer_dir) / "generation_config.json") as f: + gen_config = json.load(f) + chat_format = gen_config['chat_format'] + if chat_format == 'raw': + pad_id = gen_config['pad_token_id'] + end_id = gen_config['eos_token_id'] + elif chat_format == 'chatml': + pad_id = tokenizer.im_end_id + end_id = tokenizer.im_end_id + else: + raise Exception(f"unknown chat format: {chat_format}") + elif model_name == 'glm_10b': + pad_id = tokenizer.pad_token_id + end_id = tokenizer.eop_token_id + else: + if tokenizer.pad_token_id is None: + tokenizer.pad_token_id = tokenizer.eos_token_id + pad_id = tokenizer.pad_token_id + end_id = tokenizer.eos_token_id + + return tokenizer, pad_id, end_id + + +class TensorRTLLMEngine: + def __init__(self): + pass + + def initialize_model(self, engine_dir, tokenizer_dir): + self.log_level = 'error' + self.runtime_rank = tensorrt_llm.mpi_rank() + logger.set_level(self.log_level) + model_name = read_model_name(engine_dir) + self.tokenizer, self.pad_id, self.end_id = load_tokenizer( + tokenizer_dir=tokenizer_dir, + vocab_file=None, + model_name=model_name, + tokenizer_type=None, + ) + self.prompt_template = None + self.runner_cls = ModelRunner + self.runner_kwargs = dict(engine_dir=engine_dir, + lora_dir=None, + rank=self.runtime_rank, + debug_mode=False, + lora_ckpt_source='hf') + self.runner = self.runner_cls.from_dir(**self.runner_kwargs) + self.last_prompt = None + self.last_output = None + + def parse_input( + self, + input_text=None, + add_special_tokens=True, + max_input_length=923, + pad_id=None, + ): + if self.pad_id is None: + self.pad_id = self.tokenizer.pad_token_id + + batch_input_ids = [] + for curr_text in input_text: + if self.prompt_template is not None: + curr_text = self.prompt_template.format(input_text=curr_text) + input_ids = self.tokenizer.encode( + curr_text, + add_special_tokens=add_special_tokens, + truncation=True, + 
max_length=max_input_length + ) + batch_input_ids.append(input_ids) + + batch_input_ids = [ + torch.tensor(x, dtype=torch.int32) for x in batch_input_ids + ] + return batch_input_ids + + def decode_tokens( + self, + output_ids, + input_lengths, + sequence_lengths, + transcription_queue + ): + batch_size, num_beams, _ = output_ids.size() + for batch_idx in range(batch_size): + if transcription_queue.qsize() != 0: + return None + + inputs = output_ids[batch_idx][0][:input_lengths[batch_idx]].tolist() + input_text = self.tokenizer.decode(inputs) + output = [] + for beam in range(num_beams): + if transcription_queue.qsize() != 0: + return None + + output_begin = input_lengths[batch_idx] + output_end = sequence_lengths[batch_idx][beam] + outputs = output_ids[batch_idx][beam][ + output_begin:output_end].tolist() + output_text = self.tokenizer.decode(outputs) + output.append(output_text) + return output + + def format_prompt_qa(self, prompt, conversation_history): + formatted_prompt = "" + for user_prompt, llm_response in conversation_history: + formatted_prompt += f"Instruct: {user_prompt}\nOutput:{llm_response}\n" + return f"{formatted_prompt}Instruct: {prompt}\nOutput:" + + def format_prompt_chat(self, prompt, conversation_history): + formatted_prompt = "" + for user_prompt, llm_response in conversation_history: + formatted_prompt += f"Alice: {user_prompt}\nBob:{llm_response}\n" + return f"{formatted_prompt}Alice: {prompt}\nBob:" + + def format_prompt_chatml(self, prompt, conversation_history, system_prompt=""): + formatted_prompt = ("<|im_start|>system\n" + system_prompt + "<|im_end|>\n") + for user_prompt, llm_response in conversation_history: + formatted_prompt += f"<|im_start|>user\n{user_prompt}<|im_end|>\n" + formatted_prompt += f"<|im_start|>assistant\n{llm_response}<|im_end|>\n" + formatted_prompt += f"<|im_start|>user\n{prompt}<|im_end|>\n" + return formatted_prompt + + def run( + self, + model_path, + tokenizer_path, + transcription_queue=None, + llm_queue=None, + audio_queue=None, + input_text=None, + max_output_len=50, + max_attention_window_size=4096, + num_beams=1, + streaming=False, + streaming_interval=4, + debug=False, + ): + self.initialize_model( + model_path, + tokenizer_path, + ) + + logging.info("[LLM INFO:] Loaded LLM TensorRT Engine.") + + conversation_history = {} + + while True: + + # Get the last transcription output from the queue + transcription_output = transcription_queue.get() + if transcription_queue.qsize() != 0: + continue + + if transcription_output["uid"] not in conversation_history: + conversation_history[transcription_output["uid"]] = [] + + prompt = transcription_output['prompt'].strip() + + # if prompt is same but EOS is True, we need that to send outputs to websockets + if self.last_prompt == prompt: + if self.last_output is not None and transcription_output["eos"]: + self.eos = transcription_output["eos"] + llm_queue.put({ + "uid": transcription_output["uid"], + "llm_output": self.last_output, + "eos": self.eos, + "latency": self.infer_time + }) + audio_queue.put({"llm_output": self.last_output, "eos": self.eos}) + conversation_history[transcription_output["uid"]].append( + (transcription_output['prompt'].strip(), self.last_output[0].strip()) + ) + continue + + # input_text=[self.format_prompt_qa(prompt, conversation_history[transcription_output["uid"]])] + input_text=[self.format_prompt_chatml(prompt, conversation_history[transcription_output["uid"]], system_prompt="You are Dolphin, a helpful AI assistant")] + + self.eos = transcription_output["eos"] 
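+            # format_prompt_chatml() wraps the running conversation in ChatML turns
+            # (<|im_start|>system/user/assistant ... <|im_end|>) before tokenization.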
+ + batch_input_ids = self.parse_input( + input_text=input_text, + add_special_tokens=True, + max_input_length=923, + pad_id=None, + ) + + input_lengths = [x.size(0) for x in batch_input_ids] + + logging.info(f"[LLM INFO:] Running LLM Inference with WhisperLive prompt: {prompt}, eos: {self.eos}") + start = time.time() + with torch.no_grad(): + outputs = self.runner.generate( + batch_input_ids, + max_new_tokens=max_output_len, + max_attention_window_size=max_attention_window_size, + end_id=self.end_id, + pad_id=self.pad_id, + temperature=1.0, + top_k=1, + top_p=0.0, + num_beams=num_beams, + length_penalty=1.0, + repetition_penalty=1.0, + stop_words_list=None, + bad_words_list=None, + lora_uids=None, + prompt_table_path=None, + prompt_tasks=None, + streaming=streaming, + output_sequence_lengths=True, + return_dict=True) + torch.cuda.synchronize() + if streaming: + for curr_outputs in throttle_generator(outputs, streaming_interval): + output_ids = curr_outputs['output_ids'] + sequence_lengths = curr_outputs['sequence_lengths'] + output = self.decode_tokens( + output_ids, + input_lengths, + sequence_lengths, + transcription_queue + ) + + if output is None: + break + + # Interrupted by transcription queue + if output is None: + continue + else: + output_ids = outputs['output_ids'] + sequence_lengths = outputs['sequence_lengths'] + context_logits = None + generation_logits = None + if self.runner.gather_context_logits: + context_logits = outputs['context_logits'] + if self.runner.gather_generation_logits: + generation_logits = outputs['generation_logits'] + output = self.decode_tokens( + output_ids, + input_lengths, + sequence_lengths, + transcription_queue + ) + self.infer_time = time.time() - start + + # if self.eos: + if output is not None: + output[0] = clean_llm_output(output[0]) + self.last_output = output + self.last_prompt = prompt + llm_queue.put({ + "uid": transcription_output["uid"], + "llm_output": output, + "eos": self.eos, + "latency": self.infer_time + }) + audio_queue.put({"llm_output": output, "eos": self.eos}) + logging.info(f"[LLM INFO:] Output: {output[0]}\nLLM inference done in {self.infer_time} ms\n\n") + + if self.eos: + conversation_history[transcription_output["uid"]].append( + (transcription_output['prompt'].strip(), output[0].strip()) + ) + self.last_prompt = None + self.last_output = None + +def clean_llm_output(output): + output = output.replace("\n\nDolphin\n\n", "") + output = output.replace("\nDolphin\n\n", "") + output = output.replace("Dolphin: ", "") + output = output.replace("Assistant: ", "") + + if not output.endswith('.') and not output.endswith('?') and not output.endswith('!'): + last_punct = output.rfind('.') + last_q = output.rfind('?') + if last_q > last_punct: + last_punct = last_q + + last_ex = output.rfind('!') + if last_ex > last_punct: + last_punct = last_ex + + if last_punct > 0: + output = output[:last_punct+1] + + return output diff --git a/main.py b/main.py new file mode 100644 index 0000000000000000000000000000000000000000..f4a5732cb0d173809eb05aa16050a7fcf27382a3 --- /dev/null +++ b/main.py @@ -0,0 +1,114 @@ +import multiprocessing +import argparse +import threading +import ssl +import time +import sys +import functools + +from multiprocessing import Process, Manager, Value, Queue + +from whisper_live.trt_server import TranscriptionServer +from llm_service import TensorRTLLMEngine +from tts_service import WhisperSpeechTTS + + +def parse_arguments(): + parser = argparse.ArgumentParser() + parser.add_argument('--whisper_tensorrt_path', + 
type=str, + default="/root/TensorRT-LLM/examples/whisper/whisper_small_en", + help='Whisper TensorRT model path') + parser.add_argument('--mistral', + action="store_true", + help='Mistral') + parser.add_argument('--mistral_tensorrt_path', + type=str, + default=None, + help='Mistral TensorRT model path') + parser.add_argument('--mistral_tokenizer_path', + type=str, + default="teknium/OpenHermes-2.5-Mistral-7B", + help='Mistral TensorRT model path') + parser.add_argument('--phi', + action="store_true", + help='Phi') + parser.add_argument('--phi_tensorrt_path', + type=str, + default="/root/TensorRT-LLM/examples/phi/phi_engine", + help='Phi TensorRT model path') + parser.add_argument('--phi_tokenizer_path', + type=str, + default="/root/TensorRT-LLM/examples/phi/phi-2", + help='Phi Tokenizer path') + return parser.parse_args() + + +if __name__ == "__main__": + args = parse_arguments() + if not args.whisper_tensorrt_path: + raise ValueError("Please provide whisper_tensorrt_path to run the pipeline.") + import sys + sys.exit(0) + + if args.mistral: + if not args.mistral_tensorrt_path or not args.mistral_tokenizer_path: + raise ValueError("Please provide mistral_tensorrt_path and mistral_tokenizer_path to run the pipeline.") + import sys + sys.exit(0) + + if args.phi: + if not args.phi_tensorrt_path or not args.phi_tokenizer_path: + raise ValueError("Please provide phi_tensorrt_path and phi_tokenizer_path to run the pipeline.") + import sys + sys.exit(0) + + multiprocessing.set_start_method('spawn') + + lock = multiprocessing.Lock() + + manager = Manager() + shared_output = manager.list() + + transcription_queue = Queue() + llm_queue = Queue() + audio_queue = Queue() + + + whisper_server = TranscriptionServer() + whisper_process = multiprocessing.Process( + target=whisper_server.run, + args=( + "0.0.0.0", + 6006, + transcription_queue, + llm_queue, + args.whisper_tensorrt_path + ) + ) + whisper_process.start() + + llm_provider = TensorRTLLMEngine() + # llm_provider = MistralTensorRTLLMProvider() + llm_process = multiprocessing.Process( + target=llm_provider.run, + args=( + # args.mistral_tensorrt_path, + # args.mistral_tokenizer_path, + args.phi_tensorrt_path, + args.phi_tokenizer_path, + transcription_queue, + llm_queue, + audio_queue, + ) + ) + llm_process.start() + + # audio process + tts_runner = WhisperSpeechTTS() + tts_process = multiprocessing.Process(target=tts_runner.run, args=("0.0.0.0", 8888, audio_queue)) + tts_process.start() + + llm_process.join() + whisper_process.join() + tts_process.join() diff --git a/publish.sh b/publish.sh new file mode 100644 index 0000000000000000000000000000000000000000..fb3334dbde064ecae59ad05ba84db0c6f2d0be5a --- /dev/null +++ b/publish.sh @@ -0,0 +1,4 @@ +#!/bin/bash -e + +docker push ghcr.io/collabora/whisperfusion-base:latest +docker push ghcr.io/collabora/whisperfusion:latest diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..5901f7cec6e1d8c522d4c2b42373428fcaa35279 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,12 @@ +faster-whisper==0.9.0 +websockets +onnxruntime==1.16.0 +ffmpeg-python +scipy +websocket-client +tiktoken==0.3.3 +kaldialign +braceexpand +openai-whisper +whisperspeech +soundfile diff --git a/run-whisperfusion.sh b/run-whisperfusion.sh new file mode 100644 index 0000000000000000000000000000000000000000..86f0c64d166af5695e74220bddc71636996c1502 --- /dev/null +++ b/run-whisperfusion.sh @@ -0,0 +1,16 @@ +#!/bin/bash -e + +test -f /etc/shinit_v2 && source /etc/shinit_v2 + 
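+# Default to the Phi (dolphin-2_6-phi-2) pipeline; pass "mistral" as the first
+# argument to launch the Mistral engine instead.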
+cd WhisperFusion +if [ "$1" != "mistral" ]; then + exec python3 main.py --phi \ + --whisper_tensorrt_path /root/whisper_small_en \ + --phi_tensorrt_path /root/dolphin-2_6-phi-2 \ + --phi_tokenizer_path /root/dolphin-2_6-phi-2/tokenizer +else + exec python3 main.py --mistral \ + --whisper_tensorrt_path /root/models/whisper_small_en \ + --mistral_tensorrt_path /root/models/mistral \ + --mistral_tokenizer_path teknium/OpenHermes-2.5-Mistral-7B +fi diff --git a/run_client.py b/run_client.py new file mode 100644 index 0000000000000000000000000000000000000000..b114ffbad6b1a6258b6793bbbd0d56aca0d5c794 --- /dev/null +++ b/run_client.py @@ -0,0 +1,7 @@ +from whisper_live.client import TranscriptionClient + +if __name__ == "__main__": + client = TranscriptionClient( + "0.0.0.0", "6006", is_multilingual=False, lang="en", translate=False + ) + client() # uses microphone audio \ No newline at end of file diff --git a/run_faster_whisper_server.py b/run_faster_whisper_server.py new file mode 100644 index 0000000000000000000000000000000000000000..281fe5ded8bc152d0cf0a2c3f3a1cf19a85147a6 --- /dev/null +++ b/run_faster_whisper_server.py @@ -0,0 +1,5 @@ +from whisper_live.server import TranscriptionServer + +if __name__ == "__main__": + server = TranscriptionServer() + server.run("0.0.0.0", 6006) \ No newline at end of file diff --git a/run_trt_server.py b/run_trt_server.py new file mode 100644 index 0000000000000000000000000000000000000000..24bfc49445c3ad77e68dbb0543888c95e03a6a13 --- /dev/null +++ b/run_trt_server.py @@ -0,0 +1,5 @@ +from whisper_live.trt_server import TranscriptionServer + +if __name__ == "__main__": + server = TranscriptionServer() + server.run("0.0.0.0", 6006) \ No newline at end of file diff --git a/setup-whisperfusion.sh b/setup-whisperfusion.sh new file mode 100644 index 0000000000000000000000000000000000000000..154d123ddf0ad397189154c13be9db7f6fc31e1d --- /dev/null +++ b/setup-whisperfusion.sh @@ -0,0 +1,27 @@ +#!/bin/bash -e + +## Clone this repo and install requirements +[ -d "WhisperFusion" ] || git clone https://github.com/collabora/WhisperFusion.git + +cd WhisperFusion +apt update +apt install ffmpeg portaudio19-dev -y + +## Install torchaudio matching the PyTorch from the base image +pip install --extra-index-url https://download.pytorch.org/whl/cu121 torchaudio + +## Install all the other dependencies normally +pip install -r requirements.txt + +## force update huggingface_hub (tokenizers 0.14.1 spuriously require and ancient <=0.18 version) +pip install -U huggingface_hub + +huggingface-cli download collabora/whisperspeech t2s-small-en+pl.model s2a-q4-tiny-en+pl.model +huggingface-cli download charactr/vocos-encodec-24khz + +mkdir -p /root/.cache/torch/hub/checkpoints/ +curl -L -o /root/.cache/torch/hub/checkpoints/encodec_24khz-d7cc33bc.th https://dl.fbaipublicfiles.com/encodec/v0/encodec_24khz-d7cc33bc.th +mkdir -p /root/.cache/whisper-live/ +curl -L -o /root/.cache/whisper-live/silero_vad.onnx https://github.com/snakers4/silero-vad/raw/master/files/silero_vad.onnx + +python -c 'from transformers.utils.hub import move_cache; move_cache()' diff --git a/setup.sh b/setup.sh new file mode 100644 index 0000000000000000000000000000000000000000..82648560c7db44568d2acf3b8d2ef7335cfbc586 --- /dev/null +++ b/setup.sh @@ -0,0 +1,6 @@ +#!/bin/bash -e + +./setup-whisper.sh +#./setup-mistral.sh +./setup-phi-2.sh +./setup-whisperfusion.sh diff --git a/tts_service.py b/tts_service.py new file mode 100644 index 
0000000000000000000000000000000000000000..436b7ba327b1c7b7c23b2b9bcc821d22cf06d7a3 --- /dev/null +++ b/tts_service.py @@ -0,0 +1,68 @@ +import functools +import time +import logging +logging.basicConfig(level = logging.INFO) + +from websockets.sync.server import serve +from whisperspeech.pipeline import Pipeline + +class WhisperSpeechTTS: + def __init__(self): + pass + + def initialize_model(self): + self.pipe = Pipeline(s2a_ref='collabora/whisperspeech:s2a-q4-tiny-en+pl.model', torch_compile=True) + self.last_llm_response = None + + def run(self, host, port, audio_queue=None): + # initialize and warmup model + self.initialize_model() + for i in range(3): self.pipe.generate("Hello, I am warming up.") + + with serve( + functools.partial(self.start_whisperspeech_tts, audio_queue=audio_queue), + host, port + ) as server: + server.serve_forever() + + def start_whisperspeech_tts(self, websocket, audio_queue=None): + self.eos = False + self.output_audio = None + + while True: + llm_response = audio_queue.get() + if audio_queue.qsize() != 0: + continue + + # check if this websocket exists + try: + websocket.ping() + except Exception as e: + del websocket + audio_queue.put(llm_response) + break + + llm_output = llm_response["llm_output"][0] + self.eos = llm_response["eos"] + + def should_abort(): + if not audio_queue.empty(): raise TimeoutError() + + # only process if the output updated + if self.last_llm_response != llm_output.strip(): + try: + start = time.time() + audio = self.pipe.generate(llm_output.strip(), step_callback=should_abort) + inference_time = time.time() - start + logging.info(f"[WhisperSpeech INFO:] TTS inference done in {inference_time} ms.\n\n") + self.output_audio = audio.cpu().numpy() + self.last_llm_response = llm_output.strip() + except TimeoutError: + pass + + if self.eos and self.output_audio is not None: + try: + websocket.send(self.output_audio.tobytes()) + except Exception as e: + logging.error(f"[WhisperSpeech ERROR:] Audio error: {e}") + diff --git a/whisper_live/__init__.py b/whisper_live/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/whisper_live/client.py b/whisper_live/client.py new file mode 100644 index 0000000000000000000000000000000000000000..cee166897d7d227f745643ed4ac8dc2e95e5dd72 --- /dev/null +++ b/whisper_live/client.py @@ -0,0 +1,574 @@ +import os +import wave + +import numpy as np +import scipy +import ffmpeg +import pyaudio +import threading +import textwrap +import json +import websocket +import uuid +import time + + +def resample(file: str, sr: int = 16000): + """ + # https://github.com/openai/whisper/blob/7858aa9c08d98f75575035ecd6481f462d66ca27/whisper/audio.py#L22 + Open an audio file and read as mono waveform, resampling as necessary, + save the resampled audio + + Args: + file (str): The audio file to open + sr (int): The sample rate to resample the audio if necessary + + Returns: + resampled_file (str): The resampled audio file + """ + try: + # This launches a subprocess to decode audio while down-mixing and resampling as necessary. + # Requires the ffmpeg CLI and `ffmpeg-python` package to be installed. 
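+        # Ask ffmpeg for raw 16-bit little-endian PCM ("s16le"), down-mixed to mono
+        # (ac=1) and resampled to `sr` Hz (ar=sr); the buffer is then written back
+        # out by scipy as "<name>_resampled.wav".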
+ out, _ = ( + ffmpeg.input(file, threads=0) + .output("-", format="s16le", acodec="pcm_s16le", ac=1, ar=sr) + .run(cmd=["ffmpeg", "-nostdin"], capture_stdout=True, capture_stderr=True) + ) + except ffmpeg.Error as e: + raise RuntimeError(f"Failed to load audio: {e.stderr.decode()}") from e + np_buffer = np.frombuffer(out, dtype=np.int16) + + resampled_file = f"{file.split('.')[0]}_resampled.wav" + scipy.io.wavfile.write(resampled_file, sr, np_buffer.astype(np.int16)) + return resampled_file + + +class Client: + """ + Handles audio recording, streaming, and communication with a server using WebSocket. + """ + INSTANCES = {} + + def __init__( + self, host=None, port=None, is_multilingual=False, lang=None, translate=False, model_size="small" + ): + """ + Initializes a Client instance for audio recording and streaming to a server. + + If host and port are not provided, the WebSocket connection will not be established. + When translate is True, the task will be set to "translate" instead of "transcribe". + he audio recording starts immediately upon initialization. + + Args: + host (str): The hostname or IP address of the server. + port (int): The port number for the WebSocket server. + is_multilingual (bool, optional): Specifies if multilingual transcription is enabled. Default is False. + lang (str, optional): The selected language for transcription when multilingual is disabled. Default is None. + translate (bool, optional): Specifies if the task is translation. Default is False. + """ + self.chunk = 1024 * 3 + self.format = pyaudio.paInt16 + self.channels = 1 + self.rate = 16000 + self.record_seconds = 60000 + self.recording = False + self.multilingual = False + self.language = None + self.task = "transcribe" + self.uid = str(uuid.uuid4()) + self.waiting = False + self.last_response_recieved = None + self.disconnect_if_no_response_for = 15 + self.multilingual = is_multilingual + self.language = lang + self.model_size = model_size + self.server_error = False + if translate: + self.task = "translate" + + self.timestamp_offset = 0.0 + self.audio_bytes = None + self.p = pyaudio.PyAudio() + self.stream = self.p.open( + format=self.format, + channels=self.channels, + rate=self.rate, + input=True, + frames_per_buffer=self.chunk, + ) + + if host is not None and port is not None: + socket_url = f"ws://{host}:{port}" + self.client_socket = websocket.WebSocketApp( + socket_url, + on_open=lambda ws: self.on_open(ws), + on_message=lambda ws, message: self.on_message(ws, message), + on_error=lambda ws, error: self.on_error(ws, error), + on_close=lambda ws, close_status_code, close_msg: self.on_close( + ws, close_status_code, close_msg + ), + ) + else: + print("[ERROR]: No host or port specified.") + return + + Client.INSTANCES[self.uid] = self + + # start websocket client in a thread + self.ws_thread = threading.Thread(target=self.client_socket.run_forever) + self.ws_thread.setDaemon(True) + self.ws_thread.start() + + self.frames = b"" + print("[INFO]: * recording") + + # TTS audio websocket client + socket_url = f"ws://{host}:8888" + self.tts_client_socket = websocket.WebSocketApp( + socket_url, + on_open=lambda ws: self.on_open_tts(ws), + on_message=lambda ws, message: self.on_message_tts(ws, message), + on_error=lambda ws, error: self.on_error_tts(ws, error), + on_close=lambda ws, close_status_code, close_msg: self.on_close_tts( + ws, close_status_code, close_msg + ), + ) + + self.tts_ws_thread = threading.Thread(target=self.tts_client_socket.run_forever) + self.tts_ws_thread.setDaemon(True) + 
self.tts_ws_thread.start() + + def on_message(self, ws, message): + """ + Callback function called when a message is received from the server. + + It updates various attributes of the client based on the received message, including + recording status, language detection, and server messages. If a disconnect message + is received, it sets the recording status to False. + + Args: + ws (websocket.WebSocketApp): The WebSocket client instance. + message (str): The received message from the server. + + """ + self.last_response_recieved = time.time() + message = json.loads(message) + + if self.uid != message.get("uid"): + print("[ERROR]: invalid client uid") + return + + if "status" in message.keys(): + if message["status"] == "WAIT": + self.waiting = True + print( + f"[INFO]:Server is full. Estimated wait time {round(message['message'])} minutes." + ) + elif message["status"] == "ERROR": + print(f"Message from Server: {message['message']}") + self.server_error = True + return + + if "message" in message.keys() and message["message"] == "DISCONNECT": + print("[INFO]: Server overtime disconnected.") + self.recording = False + + if "message" in message.keys() and message["message"] == "SERVER_READY": + self.recording = True + return + + if "language" in message.keys(): + self.language = message.get("language") + lang_prob = message.get("language_prob") + print( + f"[INFO]: Server detected language {self.language} with probability {lang_prob}" + ) + return + + if "llm_output" in message.keys(): + print("LLM output: ") + for item in message["llm_output"]: + print(item) + + + if "segments" not in message.keys(): + return + + message = message["segments"] + text = [] + print(message) + if len(message): + for seg in message: + if text and text[-1] == seg["text"]: + # already got it + continue + text.append(seg["text"]) + # keep only last 3 + if len(text) > 3: + text = text[-3:] + wrapper = textwrap.TextWrapper(width=60) + word_list = wrapper.wrap(text="".join(text)) + # Print each line. + # if os.name == "nt": + # os.system("cls") + # else: + # os.system("clear") + for element in word_list: + print(element) + + def on_error(self, ws, error): + print(error) + + def on_close(self, ws, close_status_code, close_msg): + print(f"[INFO]: Websocket connection closed: {close_status_code}: {close_msg}") + + def on_open(self, ws): + """ + Callback function called when the WebSocket connection is successfully opened. + + Sends an initial configuration message to the server, including client UID, multilingual mode, + language selection, and task type. + + Args: + ws (websocket.WebSocketApp): The WebSocket client instance. + + """ + print(self.multilingual, self.language, self.task) + + print("[INFO]: Opened connection") + ws.send( + json.dumps( + { + "uid": self.uid, + "multilingual": self.multilingual, + "language": self.language, + "task": self.task, + "model_size": self.model_size, + } + ) + ) + + def on_open_tts(self): + pass + + def on_message_tts(self, ws, message): + # print(message) + print(type(message)) + self.write_audio_frames_to_file(message.tobytes(), "tts_out.wav", rate=24000) + pass + + def on_error_tts(self, ws, error): + print(error) + + def on_close_tts(self, ws, close_status_code, close_msg): + print(f"[INFO]: Websocket connection closed: {close_status_code}: {close_msg}") + + @staticmethod + def bytes_to_float_array(audio_bytes): + """ + Convert audio data from bytes to a NumPy float array. + + It assumes that the audio data is in 16-bit PCM format. 
The audio data is normalized to + have values between -1 and 1. + + Args: + audio_bytes (bytes): Audio data in bytes. + + Returns: + np.ndarray: A NumPy array containing the audio data as float values normalized between -1 and 1. + """ + raw_data = np.frombuffer(buffer=audio_bytes, dtype=np.int16) + return raw_data.astype(np.float32) / 32768.0 + + def send_packet_to_server(self, message): + """ + Send an audio packet to the server using WebSocket. + + Args: + message (bytes): The audio data packet in bytes to be sent to the server. + + """ + try: + self.client_socket.send(message, websocket.ABNF.OPCODE_BINARY) + except Exception as e: + print(e) + + def play_file(self, filename): + """ + Play an audio file and send it to the server for processing. + + Reads an audio file, plays it through the audio output, and simultaneously sends + the audio data to the server for processing. It uses PyAudio to create an audio + stream for playback. The audio data is read from the file in chunks, converted to + floating-point format, and sent to the server using WebSocket communication. + This method is typically used when you want to process pre-recorded audio and send it + to the server in real-time. + + Args: + filename (str): The path to the audio file to be played and sent to the server. + """ + + # read audio and create pyaudio stream + with wave.open(filename, "rb") as wavfile: + self.stream = self.p.open( + format=self.p.get_format_from_width(wavfile.getsampwidth()), + channels=wavfile.getnchannels(), + rate=wavfile.getframerate(), + input=True, + output=True, + frames_per_buffer=self.chunk, + ) + try: + while self.recording: + data = wavfile.readframes(self.chunk) + if data == b"": + break + + audio_array = self.bytes_to_float_array(data) + self.send_packet_to_server(audio_array.tobytes()) + self.stream.write(data) + + wavfile.close() + + assert self.last_response_recieved + while time.time() - self.last_response_recieved < self.disconnect_if_no_response_for: + continue + self.stream.close() + self.close_websocket() + + except KeyboardInterrupt: + wavfile.close() + self.stream.stop_stream() + self.stream.close() + self.p.terminate() + self.close_websocket() + print("[INFO]: Keyboard interrupt.") + + def close_websocket(self): + """ + Close the WebSocket connection and join the WebSocket thread. + + First attempts to close the WebSocket connection using `self.client_socket.close()`. After + closing the connection, it joins the WebSocket thread to ensure proper termination. + + """ + try: + self.client_socket.close() + except Exception as e: + print("[ERROR]: Error closing WebSocket:", e) + + try: + self.ws_thread.join() + except Exception as e: + print("[ERROR:] Error joining WebSocket thread:", e) + + def get_client_socket(self): + """ + Get the WebSocket client socket instance. + + Returns: + WebSocketApp: The WebSocket client socket instance currently in use by the client. + """ + return self.client_socket + + def write_audio_frames_to_file(self, frames, file_name, rate=None): + """ + Write audio frames to a WAV file. + + The WAV file is created or overwritten with the specified name. The audio frames should be + in the correct format and match the specified channel, sample width, and sample rate. + + Args: + frames (bytes): The audio frames to be written to the file. + file_name (str): The name of the WAV file to which the frames will be written. 
+ + """ + with wave.open(file_name, "wb") as wavfile: + wavfile: wave.Wave_write + wavfile.setnchannels(self.channels) + wavfile.setsampwidth(2) + wavfile.setframerate(self.rate if rate is None else rate) + wavfile.writeframes(frames) + + def process_hls_stream(self, hls_url): + """ + Connect to an HLS source, process the audio stream, and send it for transcription. + + Args: + hls_url (str): The URL of the HLS stream source. + """ + print("[INFO]: Connecting to HLS stream...") + process = None # Initialize process to None + + try: + # Connecting to the HLS stream using ffmpeg-python + process = ( + ffmpeg + .input(hls_url, threads=0) + .output('-', format='s16le', acodec='pcm_s16le', ac=1, ar=self.rate) + .run_async(pipe_stdout=True, pipe_stderr=True) + ) + + # Process the stream + while True: + in_bytes = process.stdout.read(self.chunk * 2) # 2 bytes per sample + if not in_bytes: + break + audio_array = self.bytes_to_float_array(in_bytes) + self.send_packet_to_server(audio_array.tobytes()) + + except Exception as e: + print(f"[ERROR]: Failed to connect to HLS stream: {e}") + finally: + if process: + process.kill() + + print("[INFO]: HLS stream processing finished.") + + + def record(self, out_file="output_recording.wav"): + """ + Record audio data from the input stream and save it to a WAV file. + + Continuously records audio data from the input stream, sends it to the server via a WebSocket + connection, and simultaneously saves it to multiple WAV files in chunks. It stops recording when + the `RECORD_SECONDS` duration is reached or when the `RECORDING` flag is set to `False`. + + Audio data is saved in chunks to the "chunks" directory. Each chunk is saved as a separate WAV file. + The recording will continue until the specified duration is reached or until the `RECORDING` flag is set to `False`. + The recording process can be interrupted by sending a KeyboardInterrupt (e.g., pressing Ctrl+C). After recording, + the method combines all the saved audio chunks into the specified `out_file`. + + Args: + out_file (str, optional): The name of the output WAV file to save the entire recording. Default is "output_recording.wav". + + """ + n_audio_file = 0 + if not os.path.exists("chunks"): + os.makedirs("chunks", exist_ok=True) + try: + for _ in range(0, int(self.rate / self.chunk * self.record_seconds)): + if not self.recording: + break + data = self.stream.read(self.chunk) + self.frames += data + + audio_array = Client.bytes_to_float_array(data) + + self.send_packet_to_server(audio_array.tobytes()) + + # save frames if more than a minute + if len(self.frames) > 60 * self.rate: + t = threading.Thread( + target=self.write_audio_frames_to_file, + args=( + self.frames[:], + f"chunks/{n_audio_file}.wav", + ), + ) + t.start() + n_audio_file += 1 + self.frames = b"" + + except KeyboardInterrupt: + if len(self.frames): + self.write_audio_frames_to_file( + self.frames[:], f"chunks/{n_audio_file}.wav" + ) + n_audio_file += 1 + self.stream.stop_stream() + self.stream.close() + self.p.terminate() + self.close_websocket() + + self.write_output_recording(n_audio_file, out_file) + + def write_output_recording(self, n_audio_file, out_file): + """ + Combine and save recorded audio chunks into a single WAV file. + + The individual audio chunk files are expected to be located in the "chunks" directory. Reads each chunk + file, appends its audio data to the final recording, and then deletes the chunk file. After combining + and saving, the final recording is stored in the specified `out_file`. 
+ + + Args: + n_audio_file (int): The number of audio chunk files to combine. + out_file (str): The name of the output WAV file to save the final recording. + + """ + input_files = [ + f"chunks/{i}.wav" + for i in range(n_audio_file) + if os.path.exists(f"chunks/{i}.wav") + ] + with wave.open(out_file, "wb") as wavfile: + wavfile: wave.Wave_write + wavfile.setnchannels(self.channels) + wavfile.setsampwidth(2) + wavfile.setframerate(self.rate) + for in_file in input_files: + with wave.open(in_file, "rb") as wav_in: + while True: + data = wav_in.readframes(self.chunk) + if data == b"": + break + wavfile.writeframes(data) + # remove this file + os.remove(in_file) + wavfile.close() + + +class TranscriptionClient: + """ + Client for handling audio transcription tasks via a WebSocket connection. + + Acts as a high-level client for audio transcription tasks using a WebSocket connection. It can be used + to send audio data for transcription to a server and receive transcribed text segments. + + Args: + host (str): The hostname or IP address of the server. + port (int): The port number to connect to on the server. + is_multilingual (bool, optional): Indicates whether the transcription should support multiple languages (default is False). + lang (str, optional): The primary language for transcription (used if `is_multilingual` is False). Default is None, which defaults to English ('en'). + translate (bool, optional): Indicates whether translation tasks are required (default is False). + + Attributes: + client (Client): An instance of the underlying Client class responsible for handling the WebSocket connection. + + Example: + To create a TranscriptionClient and start transcription on microphone audio: + ```python + transcription_client = TranscriptionClient(host="localhost", port=9090, is_multilingual=True) + transcription_client() + ``` + """ + def __init__(self, host, port, is_multilingual=False, lang=None, translate=False, model_size="small"): + self.client = Client(host, port, is_multilingual, lang, translate, model_size) + + def __call__(self, audio=None, hls_url=None): + """ + Start the transcription process. + + Initiates the transcription process by connecting to the server via a WebSocket. It waits for the server + to be ready to receive audio data and then sends audio for transcription. If an audio file is provided, it + will be played and streamed to the server; otherwise, it will perform live recording. + + Args: + audio (str, optional): Path to an audio file for transcription. Default is None, which triggers live recording. 
+ + """ + print("[INFO]: Waiting for server ready ...") + while not self.client.recording: + if self.client.waiting or self.client.server_error: + self.client.close_websocket() + return + + print("[INFO]: Server Ready!") + if hls_url is not None: + self.client.process_hls_stream(hls_url) + elif audio is not None: + resampled_file = resample(audio) + self.client.play_file(resampled_file) + else: + self.client.record() diff --git a/whisper_live/server.py b/whisper_live/server.py new file mode 100644 index 0000000000000000000000000000000000000000..0ec4346bf4acae6c81cea12a54d019db78e152b2 --- /dev/null +++ b/whisper_live/server.py @@ -0,0 +1,498 @@ +import websockets +import time +import threading +import json +import textwrap + +import logging +logging.basicConfig(level = logging.INFO) + +from websockets.sync.server import serve + +import torch +import numpy as np +import time +from whisper_live.transcriber import WhisperModel + + +class TranscriptionServer: + """ + Represents a transcription server that handles incoming audio from clients. + + Attributes: + RATE (int): The audio sampling rate (constant) set to 16000. + vad_model (torch.Module): The voice activity detection model. + vad_threshold (float): The voice activity detection threshold. + clients (dict): A dictionary to store connected clients. + websockets (dict): A dictionary to store WebSocket connections. + clients_start_time (dict): A dictionary to track client start times. + max_clients (int): Maximum allowed connected clients. + max_connection_time (int): Maximum allowed connection time in seconds. + """ + + RATE = 16000 + + def __init__(self): + # voice activity detection model + + self.clients = {} + self.websockets = {} + self.clients_start_time = {} + self.max_clients = 4 + self.max_connection_time = 600 + + def get_wait_time(self): + """ + Calculate and return the estimated wait time for clients. + + Returns: + float: The estimated wait time in minutes. + """ + wait_time = None + + for k, v in self.clients_start_time.items(): + current_client_time_remaining = self.max_connection_time - (time.time() - v) + + if wait_time is None or current_client_time_remaining < wait_time: + wait_time = current_client_time_remaining + + return wait_time / 60 + + def recv_audio(self, websocket): + """ + Receive audio chunks from a client in an infinite loop. + + Continuously receives audio frames from a connected client + over a WebSocket connection. It processes the audio frames using a + voice activity detection (VAD) model to determine if they contain speech + or not. If the audio frame contains speech, it is added to the client's + audio data for ASR. + If the maximum number of clients is reached, the method sends a + "WAIT" status to the client, indicating that they should wait + until a slot is available. + If a client's connection exceeds the maximum allowed time, it will + be disconnected, and the client's resources will be cleaned up. + + Args: + websocket (WebSocket): The WebSocket connection for the client. + + Raises: + Exception: If there is an error during the audio frame processing. + """ + logging.info("New client connected") + options = websocket.recv() + options = json.loads(options) + + if len(self.clients) >= self.max_clients: + logging.warning("Client Queue Full. 
Asking client to wait ...") + wait_time = self.get_wait_time() + response = { + "uid": options["uid"], + "status": "WAIT", + "message": wait_time, + } + websocket.send(json.dumps(response)) + websocket.close() + del websocket + return + + client = ServeClient( + websocket, + multilingual=options["multilingual"], + language=options["language"], + task=options["task"], + client_uid=options["uid"] + ) + + self.clients[websocket] = client + self.clients_start_time[websocket] = time.time() + + while True: + try: + frame_data = websocket.recv() + frame_np = np.frombuffer(frame_data, dtype=np.float32) + + self.clients[websocket].add_frames(frame_np) + + elapsed_time = time.time() - self.clients_start_time[websocket] + if elapsed_time >= self.max_connection_time: + self.clients[websocket].disconnect() + logging.warning(f"{self.clients[websocket]} Client disconnected due to overtime.") + self.clients[websocket].cleanup() + self.clients.pop(websocket) + self.clients_start_time.pop(websocket) + websocket.close() + del websocket + break + + except Exception as e: + logging.error(e) + self.clients[websocket].cleanup() + self.clients.pop(websocket) + self.clients_start_time.pop(websocket) + logging.info("Connection Closed.") + logging.info(self.clients) + del websocket + break + + def run(self, host, port=9090): + """ + Run the transcription server. + + Args: + host (str): The host address to bind the server. + port (int): The port number to bind the server. + """ + with serve(self.recv_audio, host, port) as server: + server.serve_forever() + + +class ServeClient: + """ + Attributes: + RATE (int): The audio sampling rate (constant) set to 16000. + SERVER_READY (str): A constant message indicating that the server is ready. + DISCONNECT (str): A constant message indicating that the client should disconnect. + client_uid (str): A unique identifier for the client. + data (bytes): Accumulated audio data. + frames (bytes): Accumulated audio frames. + language (str): The language for transcription. + task (str): The task type, e.g., "transcribe." + transcriber (WhisperModel): The Whisper model for speech-to-text. + timestamp_offset (float): The offset in audio timestamps. + frames_np (numpy.ndarray): NumPy array to store audio frames. + frames_offset (float): The offset in audio frames. + text (list): List of transcribed text segments. + current_out (str): The current incomplete transcription. + prev_out (str): The previous incomplete transcription. + t_start (float): Timestamp for the start of transcription. + exit (bool): A flag to exit the transcription thread. + same_output_threshold (int): Threshold for consecutive same output segments. + show_prev_out_thresh (int): Threshold for showing previous output segments. + add_pause_thresh (int): Threshold for adding a pause (blank) segment. + transcript (list): List of transcribed segments. + send_last_n_segments (int): Number of last segments to send to the client. + wrapper (textwrap.TextWrapper): Text wrapper for formatting text. + pick_previous_segments (int): Number of previous segments to include in the output. + websocket: The WebSocket connection for the client. + """ + RATE = 16000 + SERVER_READY = "SERVER_READY" + DISCONNECT = "DISCONNECT" + + def __init__(self, websocket, task="transcribe", device=None, multilingual=False, language=None, client_uid=None): + """ + Initialize a ServeClient instance. + The Whisper model is initialized based on the client's language and device availability. + The transcription thread is started upon initialization. 
A "SERVER_READY" message is sent + to the client to indicate that the server is ready. + + Args: + websocket (WebSocket): The WebSocket connection for the client. + task (str, optional): The task type, e.g., "transcribe." Defaults to "transcribe". + device (str, optional): The device type for Whisper, "cuda" or "cpu". Defaults to None. + multilingual (bool, optional): Whether the client supports multilingual transcription. Defaults to False. + language (str, optional): The language for transcription. Defaults to None. + client_uid (str, optional): A unique identifier for the client. Defaults to None. + + """ + self.client_uid = client_uid + self.data = b"" + self.frames = b"" + self.language = language if multilingual else "en" + self.task = task + device = "cuda" if torch.cuda.is_available() else "cpu" + self.transcriber = WhisperModel( + "small" if multilingual else "small.en", + device=device, + compute_type="int8" if device=="cpu" else "float16", + local_files_only=False, + ) + + self.timestamp_offset = 0.0 + self.frames_np = None + self.frames_offset = 0.0 + self.text = [] + self.current_out = '' + self.prev_out = '' + self.t_start=None + self.exit = False + self.same_output_threshold = 0 + self.show_prev_out_thresh = 5 # if pause(no output from whisper) show previous output for 5 seconds + self.add_pause_thresh = 3 # add a blank to segment list as a pause(no speech) for 3 seconds + self.transcript = [] + self.send_last_n_segments = 10 + + # text formatting + self.wrapper = textwrap.TextWrapper(width=50) + self.pick_previous_segments = 2 + + # threading + self.websocket = websocket + self.trans_thread = threading.Thread(target=self.speech_to_text) + self.trans_thread.start() + self.websocket.send( + json.dumps( + { + "uid": self.client_uid, + "message": self.SERVER_READY + } + ) + ) + + def fill_output(self, output): + """ + Format the current incomplete transcription output by combining it with previous complete segments. + The resulting transcription is wrapped into two lines, each containing a maximum of 50 characters. + + It ensures that the combined transcription fits within two lines, with a maximum of 50 characters per line. + Segments are concatenated in the order they exist in the list of previous segments, with the most + recent complete segment first and older segments prepended as needed to maintain the character limit. + If a 3-second pause is detected in the previous segments, any text preceding it is discarded to ensure + the transcription starts with the most recent complete content. The resulting transcription is returned + as a single string. + + Args: + output(str): The current incomplete transcription segment. + + Returns: + str: A formatted transcription wrapped in two lines. + """ + text = '' + pick_prev = min(len(self.text), self.pick_previous_segments) + for seg in self.text[-pick_prev:]: + # discard everything before a 3 second pause + if seg == '': + text = '' + else: + text += seg + wrapped = "".join(text + output) + return wrapped + + def add_frames(self, frame_np): + """ + Add audio frames to the ongoing audio stream buffer. + + This method is responsible for maintaining the audio stream buffer, allowing the continuous addition + of audio frames as they are received. It also ensures that the buffer does not exceed a specified size + to prevent excessive memory usage. + + If the buffer size exceeds a threshold (45 seconds of audio data), it discards the oldest 30 seconds + of audio data to maintain a reasonable buffer size. 
If the buffer is empty, it initializes it with the provided + audio frame. The audio stream buffer is used for real-time processing of audio data for transcription. + + Args: + frame_np (numpy.ndarray): The audio frame data as a NumPy array. + + """ + if self.frames_np is not None and self.frames_np.shape[0] > 45*self.RATE: + self.frames_offset += 30.0 + self.frames_np = self.frames_np[int(30*self.RATE):] + if self.frames_np is None: + self.frames_np = frame_np.copy() + else: + self.frames_np = np.concatenate((self.frames_np, frame_np), axis=0) + + def speech_to_text(self): + """ + Process an audio stream in an infinite loop, continuously transcribing the speech. + + This method continuously receives audio frames, performs real-time transcription, and sends + transcribed segments to the client via a WebSocket connection. + + If the client's language is not detected, it waits for 30 seconds of audio input to make a language prediction. + It utilizes the Whisper ASR model to transcribe the audio, continuously processing and streaming results. Segments + are sent to the client in real-time, and a history of segments is maintained to provide context.Pauses in speech + (no output from Whisper) are handled by showing the previous output for a set duration. A blank segment is added if + there is no speech for a specified duration to indicate a pause. + + Raises: + Exception: If there is an issue with audio processing or WebSocket communication. + + """ + while True: + if self.exit: + logging.info("Exiting speech to text thread") + break + + if self.frames_np is None: + continue + + # clip audio if the current chunk exceeds 30 seconds, this basically implies that + # no valid segment for the last 30 seconds from whisper + if self.frames_np[int((self.timestamp_offset - self.frames_offset)*self.RATE):].shape[0] > 25 * self.RATE: + duration = self.frames_np.shape[0] / self.RATE + self.timestamp_offset = self.frames_offset + duration - 5 + + samples_take = max(0, (self.timestamp_offset - self.frames_offset)*self.RATE) + input_bytes = self.frames_np[int(samples_take):].copy() + duration = input_bytes.shape[0] / self.RATE + if duration<1.0: + continue + try: + input_sample = input_bytes.copy() + + # whisper transcribe with prompt + result, info = self.transcriber.transcribe( + input_sample, + initial_prompt=None, + language=self.language, + task=self.task, + vad_filter=True, + vad_parameters={"threshold": 0.5} + ) + + if self.language is None: + if info.language_probability > 0.5: + self.language = info.language + logging.info(f"Detected language {self.language} with probability {info.language_probability}") + self.websocket.send(json.dumps( + {"uid": self.client_uid, "language": self.language, "language_prob": info.language_probability})) + else: + # detect language again + continue + + if len(result): + self.t_start = None + last_segment = self.update_segments(result, duration) + if len(self.transcript) < self.send_last_n_segments: + segments = self.transcript + else: + segments = self.transcript[-self.send_last_n_segments:] + if last_segment is not None: + segments = segments + [last_segment] + else: + # show previous output if there is pause i.e. 
no output from whisper + segments = [] + if self.t_start is None: self.t_start = time.time() + if time.time() - self.t_start < self.show_prev_out_thresh: + if len(self.transcript) < self.send_last_n_segments: + segments = self.transcript + else: + segments = self.transcript[-self.send_last_n_segments:] + + # add a blank if there is no speech for 3 seconds + if len(self.text) and self.text[-1] != '': + if time.time() - self.t_start > self.add_pause_thresh: + self.text.append('') + + try: + self.websocket.send( + json.dumps({ + "uid": self.client_uid, + "segments": segments + }) + ) + except Exception as e: + logging.error(f"[ERROR]: {e}") + + except Exception as e: + logging.error(f"[ERROR]: {e}") + time.sleep(0.01) + + def update_segments(self, segments, duration): + """ + Processes the segments from whisper. Appends all the segments to the list + except for the last segment assuming that it is incomplete. + + Updates the ongoing transcript with transcribed segments, including their start and end times. + Complete segments are appended to the transcript in chronological order. Incomplete segments + (assumed to be the last one) are processed to identify repeated content. If the same incomplete + segment is seen multiple times, it updates the offset and appends the segment to the transcript. + A threshold is used to detect repeated content and ensure it is only included once in the transcript. + The timestamp offset is updated based on the duration of processed segments. The method returns the + last processed segment, allowing it to be sent to the client for real-time updates. + + Args: + segments(dict) : dictionary of segments as returned by whisper + duration(float): duration of the current chunk + + Returns: + dict or None: The last processed segment with its start time, end time, and transcribed text. + Returns None if there are no valid segments to process. + """ + offset = None + self.current_out = '' + last_segment = None + # process complete segments + if len(segments) > 1: + for i, s in enumerate(segments[:-1]): + text_ = s.text + self.text.append(text_) + start, end = self.timestamp_offset + s.start, self.timestamp_offset + min(duration, s.end) + self.transcript.append( + { + 'start': start, + 'end': end, + 'text': text_ + } + ) + + offset = min(duration, s.end) + + self.current_out += segments[-1].text + last_segment = { + 'start': self.timestamp_offset + segments[-1].start, + 'end': self.timestamp_offset + min(duration, segments[-1].end), + 'text': self.current_out + } + + # if same incomplete segment is seen multiple times then update the offset + # and append the segment to the list + if self.current_out.strip() == self.prev_out.strip() and self.current_out != '': + self.same_output_threshold += 1 + else: + self.same_output_threshold = 0 + + if self.same_output_threshold > 5: + if not len(self.text) or self.text[-1].strip().lower()!=self.current_out.strip().lower(): + self.text.append(self.current_out) + self.transcript.append( + { + 'start': self.timestamp_offset, + 'end': self.timestamp_offset + duration, + 'text': self.current_out + } + ) + self.current_out = '' + offset = duration + self.same_output_threshold = 0 + last_segment = None + else: + self.prev_out = self.current_out + + # update offset + if offset is not None: + self.timestamp_offset += offset + + return last_segment + + def disconnect(self): + """ + Notify the client of disconnection and send a disconnect message. 
+ + This method sends a disconnect message to the client via the WebSocket connection to notify them + that the transcription service is disconnecting gracefully. + + """ + self.websocket.send( + json.dumps( + { + "uid": self.client_uid, + "message": self.DISCONNECT + } + ) + ) + + def cleanup(self): + """ + Perform cleanup tasks before exiting the transcription service. + + This method performs necessary cleanup tasks, including stopping the transcription thread, marking + the exit flag to indicate the transcription thread should exit gracefully, and destroying resources + associated with the transcription process. + + """ + logging.info("Cleaning up.") + self.exit = True + self.transcriber.destroy() diff --git a/whisper_live/transcriber.py b/whisper_live/transcriber.py new file mode 100644 index 0000000000000000000000000000000000000000..a275878bc1b37672ead3476101eb8e5c47e599a6 --- /dev/null +++ b/whisper_live/transcriber.py @@ -0,0 +1,1023 @@ +# original https://github.com/guillaumekln/faster-whisper/blob/master/faster_whisper/transcribe.py + +import itertools +import logging +import os +import zlib + +from typing import BinaryIO, Iterable, List, NamedTuple, Optional, Tuple, Union + +import ctranslate2 +import numpy as np +import tokenizers + +from faster_whisper.audio import decode_audio +from faster_whisper.feature_extractor import FeatureExtractor +from faster_whisper.tokenizer import _LANGUAGE_CODES, Tokenizer +from faster_whisper.utils import download_model, format_timestamp, get_logger +from faster_whisper.vad import ( + SpeechTimestampsMap, + VadOptions, + collect_chunks, + get_speech_timestamps, +) + + +class Word(NamedTuple): + start: float + end: float + word: str + probability: float + + +class Segment(NamedTuple): + id: int + seek: int + start: float + end: float + text: str + tokens: List[int] + temperature: float + avg_logprob: float + compression_ratio: float + no_speech_prob: float + words: Optional[List[Word]] + + +class TranscriptionOptions(NamedTuple): + beam_size: int + best_of: int + patience: float + length_penalty: float + repetition_penalty: float + no_repeat_ngram_size: int + log_prob_threshold: Optional[float] + no_speech_threshold: Optional[float] + compression_ratio_threshold: Optional[float] + condition_on_previous_text: bool + prompt_reset_on_temperature: float + temperatures: List[float] + initial_prompt: Optional[Union[str, Iterable[int]]] + prefix: Optional[str] + suppress_blank: bool + suppress_tokens: Optional[List[int]] + without_timestamps: bool + max_initial_timestamp: float + word_timestamps: bool + prepend_punctuations: str + append_punctuations: str + + +class TranscriptionInfo(NamedTuple): + language: str + language_probability: float + duration: float + duration_after_vad: float + all_language_probs: Optional[List[Tuple[str, float]]] + transcription_options: TranscriptionOptions + vad_options: VadOptions + + +class WhisperModel: + def __init__( + self, + model_size_or_path: str, + device: str = "auto", + device_index: Union[int, List[int]] = 0, + compute_type: str = "default", + cpu_threads: int = 0, + num_workers: int = 1, + download_root: Optional[str] = None, + local_files_only: bool = False, + ): + """Initializes the Whisper model. + + Args: + model_size_or_path: Size of the model to use (tiny, tiny.en, base, base.en, + small, small.en, medium, medium.en, large-v1, large-v2, or large), a path to a converted + model directory, or a CTranslate2-converted Whisper model ID from the Hugging Face Hub. 
+ When a size or a model ID is configured, the converted model is downloaded + from the Hugging Face Hub. + device: Device to use for computation ("cpu", "cuda", "auto"). + device_index: Device ID to use. + The model can also be loaded on multiple GPUs by passing a list of IDs + (e.g. [0, 1, 2, 3]). In that case, multiple transcriptions can run in parallel + when transcribe() is called from multiple Python threads (see also num_workers). + compute_type: Type to use for computation. + See https://opennmt.net/CTranslate2/quantization.html. + cpu_threads: Number of threads to use when running on CPU (4 by default). + A non zero value overrides the OMP_NUM_THREADS environment variable. + num_workers: When transcribe() is called from multiple Python threads, + having multiple workers enables true parallelism when running the model + (concurrent calls to self.model.generate() will run in parallel). + This can improve the global throughput at the cost of increased memory usage. + download_root: Directory where the models should be saved. If not set, the models + are saved in the standard Hugging Face cache directory. + local_files_only: If True, avoid downloading the file and return the path to the + local cached file if it exists. + """ + self.logger = get_logger() + + if os.path.isdir(model_size_or_path): + model_path = model_size_or_path + else: + model_path = download_model( + model_size_or_path, + local_files_only=local_files_only, + cache_dir=download_root, + ) + + self.model = ctranslate2.models.Whisper( + model_path, + device=device, + device_index=device_index, + compute_type=compute_type, + intra_threads=cpu_threads, + inter_threads=num_workers, + ) + + tokenizer_file = os.path.join(model_path, "tokenizer.json") + if os.path.isfile(tokenizer_file): + self.hf_tokenizer = tokenizers.Tokenizer.from_file(tokenizer_file) + else: + self.hf_tokenizer = tokenizers.Tokenizer.from_pretrained( + "openai/whisper-tiny" + ("" if self.model.is_multilingual else ".en") + ) + + self.feature_extractor = FeatureExtractor() + self.num_samples_per_token = self.feature_extractor.hop_length * 2 + self.frames_per_second = ( + self.feature_extractor.sampling_rate // self.feature_extractor.hop_length + ) + self.tokens_per_second = ( + self.feature_extractor.sampling_rate // self.num_samples_per_token + ) + self.input_stride = 2 + self.time_precision = 0.02 + self.max_length = 448 + + @property + def supported_languages(self) -> List[str]: + """The languages supported by the model.""" + return list(_LANGUAGE_CODES) if self.model.is_multilingual else ["en"] + + def transcribe( + self, + audio: Union[str, BinaryIO, np.ndarray], + language: Optional[str] = None, + task: str = "transcribe", + beam_size: int = 5, + best_of: int = 5, + patience: float = 1, + length_penalty: float = 1, + repetition_penalty: float = 1, + no_repeat_ngram_size: int = 0, + temperature: Union[float, List[float], Tuple[float, ...]] = [ + 0.0, + 0.2, + 0.4, + 0.6, + 0.8, + 1.0, + ], + compression_ratio_threshold: Optional[float] = 2.4, + log_prob_threshold: Optional[float] = -1.0, + no_speech_threshold: Optional[float] = 0.6, + condition_on_previous_text: bool = True, + prompt_reset_on_temperature: float = 0.5, + initial_prompt: Optional[Union[str, Iterable[int]]] = None, + prefix: Optional[str] = None, + suppress_blank: bool = True, + suppress_tokens: Optional[List[int]] = [-1], + without_timestamps: bool = False, + max_initial_timestamp: float = 1.0, + word_timestamps: bool = False, + prepend_punctuations: str = "\"'“¿([{-", + 
append_punctuations: str = "\"'.。,,!!??::”)]}、", + vad_filter: bool = False, + vad_parameters: Optional[Union[dict, VadOptions]] = None, + ) -> Tuple[Iterable[Segment], TranscriptionInfo]: + """Transcribes an input file. + + Arguments: + audio: Path to the input file (or a file-like object), or the audio waveform. + language: The language spoken in the audio. It should be a language code such + as "en" or "fr". If not set, the language will be detected in the first 30 seconds + of audio. + task: Task to execute (transcribe or translate). + beam_size: Beam size to use for decoding. + best_of: Number of candidates when sampling with non-zero temperature. + patience: Beam search patience factor. + length_penalty: Exponential length penalty constant. + repetition_penalty: Penalty applied to the score of previously generated tokens + (set > 1 to penalize). + no_repeat_ngram_size: Prevent repetitions of ngrams with this size (set 0 to disable). + temperature: Temperature for sampling. It can be a tuple of temperatures, + which will be successively used upon failures according to either + `compression_ratio_threshold` or `log_prob_threshold`. + compression_ratio_threshold: If the gzip compression ratio is above this value, + treat as failed. + log_prob_threshold: If the average log probability over sampled tokens is + below this value, treat as failed. + no_speech_threshold: If the no_speech probability is higher than this value AND + the average log probability over sampled tokens is below `log_prob_threshold`, + consider the segment as silent. + condition_on_previous_text: If True, the previous output of the model is provided + as a prompt for the next window; disabling may make the text inconsistent across + windows, but the model becomes less prone to getting stuck in a failure loop, + such as repetition looping or timestamps going out of sync. + prompt_reset_on_temperature: Resets prompt if temperature is above this value. + Arg has effect only if condition_on_previous_text is True. + initial_prompt: Optional text string or iterable of token ids to provide as a + prompt for the first window. + prefix: Optional text to provide as a prefix for the first window. + suppress_blank: Suppress blank outputs at the beginning of the sampling. + suppress_tokens: List of token IDs to suppress. -1 will suppress a default set + of symbols as defined in the model config.json file. + without_timestamps: Only sample text tokens. + max_initial_timestamp: The initial timestamp cannot be later than this. + word_timestamps: Extract word-level timestamps using the cross-attention pattern + and dynamic time warping, and include the timestamps for each word in each segment. + prepend_punctuations: If word_timestamps is True, merge these punctuation symbols + with the next word + append_punctuations: If word_timestamps is True, merge these punctuation symbols + with the previous word + vad_filter: Enable the voice activity detection (VAD) to filter out parts of the audio + without speech. This step is using the Silero VAD model + https://github.com/snakers4/silero-vad. + vad_parameters: Dictionary of Silero VAD parameters or VadOptions class (see available + parameters and default values in the class `VadOptions`). 
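+
+        Example (illustrative sketch; assumes a local "small.en" model and a 16 kHz
+        recording at "audio.wav"):
+
+            model = WhisperModel("small.en", device="cpu", compute_type="int8")
+            segments, info = model.transcribe("audio.wav", vad_filter=True)
+            for segment in segments:
+                print("[%.2fs -> %.2fs] %s" % (segment.start, segment.end, segment.text))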
+ + Returns: + A tuple with: + + - a generator over transcribed segments + - an instance of TranscriptionInfo + """ + sampling_rate = self.feature_extractor.sampling_rate + + if not isinstance(audio, np.ndarray): + audio = decode_audio(audio, sampling_rate=sampling_rate) + + duration = audio.shape[0] / sampling_rate + duration_after_vad = duration + + self.logger.info( + "Processing audio with duration %s", format_timestamp(duration) + ) + + if vad_filter: + if vad_parameters is None: + vad_parameters = VadOptions() + elif isinstance(vad_parameters, dict): + vad_parameters = VadOptions(**vad_parameters) + speech_chunks = get_speech_timestamps(audio, vad_parameters) + audio = collect_chunks(audio, speech_chunks) + duration_after_vad = audio.shape[0] / sampling_rate + + self.logger.info( + "VAD filter removed %s of audio", + format_timestamp(duration - duration_after_vad), + ) + + if self.logger.isEnabledFor(logging.DEBUG): + self.logger.debug( + "VAD filter kept the following audio segments: %s", + ", ".join( + "[%s -> %s]" + % ( + format_timestamp(chunk["start"] / sampling_rate), + format_timestamp(chunk["end"] / sampling_rate), + ) + for chunk in speech_chunks + ), + ) + + else: + speech_chunks = None + + features = self.feature_extractor(audio) + + encoder_output = None + all_language_probs = None + + if language is None: + if not self.model.is_multilingual: + language = "en" + language_probability = 1 + else: + segment = features[:, : self.feature_extractor.nb_max_frames] + encoder_output = self.encode(segment) + # results is a list of tuple[str, float] with language names and + # probabilities. + results = self.model.detect_language(encoder_output)[0] + # Parse language names to strip out markers + all_language_probs = [(token[2:-2], prob) for (token, prob) in results] + # Get top language token and probability + language, language_probability = all_language_probs[0] + + self.logger.info( + "Detected language '%s' with probability %.2f", + language, + language_probability, + ) + else: + if not self.model.is_multilingual and language != "en": + self.logger.warning( + "The current model is English-only but the language parameter is set to '%s'; " + "using 'en' instead." 
% language + ) + language = "en" + + language_probability = 1 + + tokenizer = Tokenizer( + self.hf_tokenizer, + self.model.is_multilingual, + task=task, + language=language, + ) + + options = TranscriptionOptions( + beam_size=beam_size, + best_of=best_of, + patience=patience, + length_penalty=length_penalty, + repetition_penalty=repetition_penalty, + no_repeat_ngram_size=no_repeat_ngram_size, + log_prob_threshold=log_prob_threshold, + no_speech_threshold=no_speech_threshold, + compression_ratio_threshold=compression_ratio_threshold, + condition_on_previous_text=condition_on_previous_text, + prompt_reset_on_temperature=prompt_reset_on_temperature, + temperatures=( + temperature if isinstance(temperature, (list, tuple)) else [temperature] + ), + initial_prompt=initial_prompt, + prefix=prefix, + suppress_blank=suppress_blank, + suppress_tokens=get_suppressed_tokens(tokenizer, suppress_tokens), + without_timestamps=without_timestamps, + max_initial_timestamp=max_initial_timestamp, + word_timestamps=word_timestamps, + prepend_punctuations=prepend_punctuations, + append_punctuations=append_punctuations, + ) + + segments = self.generate_segments(features, tokenizer, options, encoder_output) + + if speech_chunks: + segments = restore_speech_timestamps(segments, speech_chunks, sampling_rate) + + info = TranscriptionInfo( + language=language, + language_probability=language_probability, + duration=duration, + duration_after_vad=duration_after_vad, + transcription_options=options, + vad_options=vad_parameters, + all_language_probs=all_language_probs, + ) + + return segments, info + + def generate_segments( + self, + features: np.ndarray, + tokenizer: Tokenizer, + options: TranscriptionOptions, + encoder_output: Optional[ctranslate2.StorageView] = None, + ) -> Iterable[Segment]: + content_frames = features.shape[-1] - self.feature_extractor.nb_max_frames + idx = 0 + seek = 0 + all_tokens = [] + prompt_reset_since = 0 + + if options.initial_prompt is not None: + if isinstance(options.initial_prompt, str): + initial_prompt = " " + options.initial_prompt.strip() + initial_prompt_tokens = tokenizer.encode(initial_prompt) + all_tokens.extend(initial_prompt_tokens) + else: + all_tokens.extend(options.initial_prompt) + + last_speech_timestamp = 0.0 + all_segments = [] + while seek < content_frames: + time_offset = seek * self.feature_extractor.time_per_frame + segment = features[:, seek : seek + self.feature_extractor.nb_max_frames] + segment_size = min( + self.feature_extractor.nb_max_frames, content_frames - seek + ) + segment_duration = segment_size * self.feature_extractor.time_per_frame + + if self.logger.isEnabledFor(logging.DEBUG): + self.logger.debug( + "Processing segment at %s", format_timestamp(time_offset) + ) + + previous_tokens = all_tokens[prompt_reset_since:] + prompt = self.get_prompt( + tokenizer, + previous_tokens, + without_timestamps=options.without_timestamps, + prefix=options.prefix if seek == 0 else None, + ) + + if seek > 0 or encoder_output is None: + encoder_output = self.encode(segment) + + ( + result, + avg_logprob, + temperature, + compression_ratio, + ) = self.generate_with_fallback(encoder_output, prompt, tokenizer, options) + + if options.no_speech_threshold is not None: + # no voice activity check + should_skip = result.no_speech_prob > options.no_speech_threshold + + if ( + options.log_prob_threshold is not None + and avg_logprob > options.log_prob_threshold + ): + # don't skip if the logprob is high enough, despite the no_speech_prob + should_skip = False + + if 
should_skip: + self.logger.debug( + "No speech threshold is met (%f > %f)", + result.no_speech_prob, + options.no_speech_threshold, + ) + + # fast-forward to the next segment boundary + seek += segment_size + continue + + tokens = result.sequences_ids[0] + + previous_seek = seek + current_segments = [] + + single_timestamp_ending = ( + len(tokens) >= 2 + and tokens[-2] < tokenizer.timestamp_begin + and tokens[-1] >= tokenizer.timestamp_begin + ) + + consecutive_timestamps = [ + i + for i in range(len(tokens)) + if i > 0 + and tokens[i] >= tokenizer.timestamp_begin + and tokens[i - 1] >= tokenizer.timestamp_begin + ] + + if len(consecutive_timestamps) > 0: + slices = list(consecutive_timestamps) + if single_timestamp_ending: + slices.append(len(tokens)) + + last_slice = 0 + for current_slice in slices: + sliced_tokens = tokens[last_slice:current_slice] + start_timestamp_position = ( + sliced_tokens[0] - tokenizer.timestamp_begin + ) + end_timestamp_position = ( + sliced_tokens[-1] - tokenizer.timestamp_begin + ) + start_time = ( + time_offset + start_timestamp_position * self.time_precision + ) + end_time = ( + time_offset + end_timestamp_position * self.time_precision + ) + + current_segments.append( + dict( + seek=seek, + start=start_time, + end=end_time, + tokens=sliced_tokens, + ) + ) + last_slice = current_slice + + if single_timestamp_ending: + # single timestamp at the end means no speech after the last timestamp. + seek += segment_size + else: + # otherwise, ignore the unfinished segment and seek to the last timestamp + last_timestamp_position = ( + tokens[last_slice - 1] - tokenizer.timestamp_begin + ) + seek += last_timestamp_position * self.input_stride + + else: + duration = segment_duration + timestamps = [ + token for token in tokens if token >= tokenizer.timestamp_begin + ] + if len(timestamps) > 0 and timestamps[-1] != tokenizer.timestamp_begin: + last_timestamp_position = timestamps[-1] - tokenizer.timestamp_begin + duration = last_timestamp_position * self.time_precision + + current_segments.append( + dict( + seek=seek, + start=time_offset, + end=time_offset + duration, + tokens=tokens, + ) + ) + + seek += segment_size + + if options.word_timestamps: + self.add_word_timestamps( + current_segments, + tokenizer, + encoder_output, + segment_size, + options.prepend_punctuations, + options.append_punctuations, + last_speech_timestamp=last_speech_timestamp, + ) + + word_end_timestamps = [ + w["end"] for s in current_segments for w in s["words"] + ] + if len(word_end_timestamps) > 0: + last_speech_timestamp = word_end_timestamps[-1] + if not single_timestamp_ending and len(word_end_timestamps) > 0: + seek_shift = round( + (word_end_timestamps[-1] - time_offset) * self.frames_per_second + ) + + if seek_shift > 0: + seek = previous_seek + seek_shift + + for segment in current_segments: + tokens = segment["tokens"] + text = tokenizer.decode(tokens) + + if segment["start"] == segment["end"] or not text.strip(): + continue + + all_tokens.extend(tokens) + idx += 1 + + all_segments.append(Segment( + id=idx, + seek=seek, + start=segment["start"], + end=segment["end"], + text=text, + tokens=tokens, + temperature=temperature, + avg_logprob=avg_logprob, + compression_ratio=compression_ratio, + no_speech_prob=result.no_speech_prob, + words=( + [Word(**word) for word in segment["words"]] + if options.word_timestamps + else None + ), + )) + + if ( + not options.condition_on_previous_text + or temperature > options.prompt_reset_on_temperature + ): + if options.condition_on_previous_text: + 
self.logger.debug( + "Reset prompt. prompt_reset_on_temperature threshold is met %f > %f", + temperature, + options.prompt_reset_on_temperature, + ) + + prompt_reset_since = len(all_tokens) + return all_segments + + def encode(self, features: np.ndarray) -> ctranslate2.StorageView: + # When the model is running on multiple GPUs, the encoder output should be moved + # to the CPU since we don't know which GPU will handle the next job. + to_cpu = self.model.device == "cuda" and len(self.model.device_index) > 1 + + features = np.expand_dims(features, 0) + features = get_ctranslate2_storage(features) + + return self.model.encode(features, to_cpu=to_cpu) + + def generate_with_fallback( + self, + encoder_output: ctranslate2.StorageView, + prompt: List[int], + tokenizer: Tokenizer, + options: TranscriptionOptions, + ) -> Tuple[ctranslate2.models.WhisperGenerationResult, float, float, float]: + decode_result = None + all_results = [] + below_cr_threshold_results = [] + + max_initial_timestamp_index = int( + round(options.max_initial_timestamp / self.time_precision) + ) + + for temperature in options.temperatures: + if temperature > 0: + kwargs = { + "beam_size": 1, + "num_hypotheses": options.best_of, + "sampling_topk": 0, + "sampling_temperature": temperature, + } + else: + kwargs = { + "beam_size": options.beam_size, + "patience": options.patience, + } + + result = self.model.generate( + encoder_output, + [prompt], + length_penalty=options.length_penalty, + repetition_penalty=options.repetition_penalty, + no_repeat_ngram_size=options.no_repeat_ngram_size, + max_length=self.max_length, + return_scores=True, + return_no_speech_prob=True, + suppress_blank=options.suppress_blank, + suppress_tokens=options.suppress_tokens, + max_initial_timestamp_index=max_initial_timestamp_index, + **kwargs, + )[0] + + tokens = result.sequences_ids[0] + + # Recover the average log prob from the returned score. 
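+            # The returned score is length-normalized (divided by seq_len**length_penalty),
+            # so multiplying recovers the cumulative log prob; dividing by (seq_len + 1)
+            # then averages it over the generated tokens plus the end token.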
+ seq_len = len(tokens) + cum_logprob = result.scores[0] * (seq_len**options.length_penalty) + avg_logprob = cum_logprob / (seq_len + 1) + + text = tokenizer.decode(tokens).strip() + compression_ratio = get_compression_ratio(text) + + decode_result = ( + result, + avg_logprob, + temperature, + compression_ratio, + ) + all_results.append(decode_result) + + needs_fallback = False + + if options.compression_ratio_threshold is not None: + if compression_ratio > options.compression_ratio_threshold: + needs_fallback = True # too repetitive + + self.logger.debug( + "Compression ratio threshold is not met with temperature %.1f (%f > %f)", + temperature, + compression_ratio, + options.compression_ratio_threshold, + ) + else: + below_cr_threshold_results.append(decode_result) + + if ( + options.log_prob_threshold is not None + and avg_logprob < options.log_prob_threshold + ): + needs_fallback = True # average log probability is too low + + self.logger.debug( + "Log probability threshold is not met with temperature %.1f (%f < %f)", + temperature, + avg_logprob, + options.log_prob_threshold, + ) + + if ( + options.no_speech_threshold is not None + and result.no_speech_prob > options.no_speech_threshold + ): + needs_fallback = False # silence + + if not needs_fallback: + break + else: + # all failed, select the result with the highest average log probability + decode_result = max( + below_cr_threshold_results or all_results, key=lambda x: x[1] + ) + + return decode_result + + def get_prompt( + self, + tokenizer: Tokenizer, + previous_tokens: List[int], + without_timestamps: bool = False, + prefix: Optional[str] = None, + ) -> List[int]: + prompt = [] + + if previous_tokens: + prompt.append(tokenizer.sot_prev) + prompt.extend(previous_tokens[-(self.max_length // 2 - 1) :]) + + prompt.extend(tokenizer.sot_sequence) + + if without_timestamps: + prompt.append(tokenizer.no_timestamps) + + if prefix: + prefix_tokens = tokenizer.encode(" " + prefix.strip()) + if len(prefix_tokens) >= self.max_length // 2: + prefix_tokens = prefix_tokens[: self.max_length // 2 - 1] + if not without_timestamps: + prompt.append(tokenizer.timestamp_begin) + prompt.extend(prefix_tokens) + + return prompt + + def add_word_timestamps( + self, + segments: List[dict], + tokenizer: Tokenizer, + encoder_output: ctranslate2.StorageView, + num_frames: int, + prepend_punctuations: str, + append_punctuations: str, + last_speech_timestamp: float, + ) -> None: + if len(segments) == 0: + return + + text_tokens_per_segment = [ + [token for token in segment["tokens"] if token < tokenizer.eot] + for segment in segments + ] + + text_tokens = list(itertools.chain.from_iterable(text_tokens_per_segment)) + alignment = self.find_alignment( + tokenizer, text_tokens, encoder_output, num_frames + ) + word_durations = np.array([word["end"] - word["start"] for word in alignment]) + word_durations = word_durations[word_durations.nonzero()] + median_duration = np.median(word_durations) if len(word_durations) > 0 else 0.0 + max_duration = median_duration * 2 + + # hack: truncate long words at sentence boundaries. + # a better segmentation algorithm based on VAD should be able to replace this. + if len(word_durations) > 0: + sentence_end_marks = ".。!!??" + # ensure words at sentence boundaries + # are not longer than twice the median word duration. 
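+            # A word exceeding max_duration is clamped only when it is, or directly
+            # follows, a sentence-end mark; mid-sentence timings are left untouched.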
+ for i in range(1, len(alignment)): + if alignment[i]["end"] - alignment[i]["start"] > max_duration: + if alignment[i]["word"] in sentence_end_marks: + alignment[i]["end"] = alignment[i]["start"] + max_duration + elif alignment[i - 1]["word"] in sentence_end_marks: + alignment[i]["start"] = alignment[i]["end"] - max_duration + + merge_punctuations(alignment, prepend_punctuations, append_punctuations) + + time_offset = ( + segments[0]["seek"] + * self.feature_extractor.hop_length + / self.feature_extractor.sampling_rate + ) + + word_index = 0 + + for segment, text_tokens in zip(segments, text_tokens_per_segment): + saved_tokens = 0 + words = [] + + while word_index < len(alignment) and saved_tokens < len(text_tokens): + timing = alignment[word_index] + + if timing["word"]: + words.append( + dict( + word=timing["word"], + start=round(time_offset + timing["start"], 2), + end=round(time_offset + timing["end"], 2), + probability=timing["probability"], + ) + ) + + saved_tokens += len(timing["tokens"]) + word_index += 1 + + # hack: truncate long words at segment boundaries. + # a better segmentation algorithm based on VAD should be able to replace this. + if len(words) > 0: + # ensure the first and second word after a pause is not longer than + # twice the median word duration. + if words[0]["end"] - last_speech_timestamp > median_duration * 4 and ( + words[0]["end"] - words[0]["start"] > max_duration + or ( + len(words) > 1 + and words[1]["end"] - words[0]["start"] > max_duration * 2 + ) + ): + if ( + len(words) > 1 + and words[1]["end"] - words[1]["start"] > max_duration + ): + boundary = max( + words[1]["end"] / 2, words[1]["end"] - max_duration + ) + words[0]["end"] = words[1]["start"] = boundary + words[0]["start"] = max(0, words[0]["end"] - max_duration) + + # prefer the segment-level start timestamp if the first word is too long. + if ( + segment["start"] < words[0]["end"] + and segment["start"] - 0.5 > words[0]["start"] + ): + words[0]["start"] = max( + 0, min(words[0]["end"] - median_duration, segment["start"]) + ) + else: + segment["start"] = words[0]["start"] + + # prefer the segment-level end timestamp if the last word is too long. 
+ if ( + segment["end"] > words[-1]["start"] + and segment["end"] + 0.5 < words[-1]["end"] + ): + words[-1]["end"] = max( + words[-1]["start"] + median_duration, segment["end"] + ) + else: + segment["end"] = words[-1]["end"] + + last_speech_timestamp = segment["end"] + + segment["words"] = words + + def find_alignment( + self, + tokenizer: Tokenizer, + text_tokens: List[int], + encoder_output: ctranslate2.StorageView, + num_frames: int, + median_filter_width: int = 7, + ) -> List[dict]: + if len(text_tokens) == 0: + return [] + + result = self.model.align( + encoder_output, + tokenizer.sot_sequence, + [text_tokens], + num_frames, + median_filter_width=median_filter_width, + )[0] + + text_token_probs = result.text_token_probs + + alignments = result.alignments + text_indices = np.array([pair[0] for pair in alignments]) + time_indices = np.array([pair[1] for pair in alignments]) + + words, word_tokens = tokenizer.split_to_word_tokens( + text_tokens + [tokenizer.eot] + ) + word_boundaries = np.pad(np.cumsum([len(t) for t in word_tokens[:-1]]), (1, 0)) + if len(word_boundaries) <= 1: + return [] + + jumps = np.pad(np.diff(text_indices), (1, 0), constant_values=1).astype(bool) + jump_times = time_indices[jumps] / self.tokens_per_second + start_times = jump_times[word_boundaries[:-1]] + end_times = jump_times[word_boundaries[1:]] + word_probabilities = [ + np.mean(text_token_probs[i:j]) + for i, j in zip(word_boundaries[:-1], word_boundaries[1:]) + ] + + return [ + dict( + word=word, tokens=tokens, start=start, end=end, probability=probability + ) + for word, tokens, start, end, probability in zip( + words, word_tokens, start_times, end_times, word_probabilities + ) + ] + + def destroy(self): + del self.model + + +def restore_speech_timestamps( + segments: Iterable[Segment], + speech_chunks: List[dict], + sampling_rate: int, +) -> Iterable[Segment]: + ts_map = SpeechTimestampsMap(speech_chunks, sampling_rate) + + for segment in segments: + if segment.words: + words = [] + for word in segment.words: + # Ensure the word start and end times are resolved to the same chunk. + middle = (word.start + word.end) / 2 + chunk_index = ts_map.get_chunk_index(middle) + word = word._replace( + start=ts_map.get_original_time(word.start, chunk_index), + end=ts_map.get_original_time(word.end, chunk_index), + ) + words.append(word) + + segment = segment._replace( + start=words[0].start, + end=words[-1].end, + words=words, + ) + + else: + segment = segment._replace( + start=ts_map.get_original_time(segment.start), + end=ts_map.get_original_time(segment.end), + ) + + return segments + + +def get_ctranslate2_storage(segment: np.ndarray) -> ctranslate2.StorageView: + segment = np.ascontiguousarray(segment) + segment = ctranslate2.StorageView.from_array(segment) + return segment + + +def get_compression_ratio(text: str) -> float: + text_bytes = text.encode("utf-8") + return len(text_bytes) / len(zlib.compress(text_bytes)) + + +def get_suppressed_tokens( + tokenizer: Tokenizer, + suppress_tokens: Optional[List[int]], +) -> Optional[List[int]]: + if not suppress_tokens or -1 in suppress_tokens: + return suppress_tokens + + suppress_tokens = list(suppress_tokens) + + # Ensure the following special tokens are suppressed when the user does + # not use the default set (-1). 
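+    # These are Whisper control tokens (task markers and start-of-transcript variants)
+    # that should never be emitted as text.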
+ suppress_tokens.extend( + [ + tokenizer.transcribe, + tokenizer.translate, + tokenizer.sot, + tokenizer.sot_prev, + tokenizer.sot_lm, + ] + ) + + return sorted(set(suppress_tokens)) + + +def merge_punctuations(alignment: List[dict], prepended: str, appended: str) -> None: + # merge prepended punctuations + i = len(alignment) - 2 + j = len(alignment) - 1 + while i >= 0: + previous = alignment[i] + following = alignment[j] + if previous["word"].startswith(" ") and previous["word"].strip() in prepended: + # prepend it to the following word + following["word"] = previous["word"] + following["word"] + following["tokens"] = previous["tokens"] + following["tokens"] + previous["word"] = "" + previous["tokens"] = [] + else: + j = i + i -= 1 + + # merge appended punctuations + i = 0 + j = 1 + while j < len(alignment): + previous = alignment[i] + following = alignment[j] + if not previous["word"].endswith(" ") and following["word"] in appended: + # append it to the previous word + previous["word"] = previous["word"] + following["word"] + previous["tokens"] = previous["tokens"] + following["tokens"] + following["word"] = "" + following["tokens"] = [] + else: + i = j + j += 1 diff --git a/whisper_live/trt_server.py b/whisper_live/trt_server.py new file mode 100644 index 0000000000000000000000000000000000000000..4cc22d7283191fb9b0dbe8019e7e303e67713acc --- /dev/null +++ b/whisper_live/trt_server.py @@ -0,0 +1,435 @@ +import websockets +import time +import threading +import json +import textwrap + +import logging +logging.basicConfig(level = logging.INFO) + +from websockets.sync.server import serve + +import torch +import numpy as np +import queue + +from whisper_live.vad import VoiceActivityDetection +from whisper_live.trt_transcriber import WhisperTRTLLM + + +from scipy.io.wavfile import write +import functools + +save_counter = 0 +def save_wav(normalized_float32): + global save_counter + scaled_int16 = (normalized_float32 * 32768).astype(np.int16) + write(f"outputs/output{save_counter}.wav", 16000, scaled_int16) + save_counter += 1 + + + +class TranscriptionServer: + """ + Represents a transcription server that handles incoming audio from clients. + + Attributes: + RATE (int): The audio sampling rate (constant) set to 16000. + vad_model (torch.Module): The voice activity detection model. + vad_threshold (float): The voice activity detection threshold. + clients (dict): A dictionary to store connected clients. + websockets (dict): A dictionary to store WebSocket connections. + clients_start_time (dict): A dictionary to track client start times. + max_clients (int): Maximum allowed connected clients. + max_connection_time (int): Maximum allowed connection time in seconds. + """ + + RATE = 16000 + + def __init__(self): + # voice activity detection model + + self.clients = {} + self.websockets = {} + self.clients_start_time = {} + self.max_clients = 4 + self.max_connection_time = 600 + self.transcriber = None + + def get_wait_time(self): + """ + Calculate and return the estimated wait time for clients. + + Returns: + float: The estimated wait time in minutes. 
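+                The estimate is the smallest remaining connection time across the
+                connected clients; e.g. with max_connection_time=600 and a client
+                connected 540 seconds ago, the returned value is (600 - 540) / 60 = 1.0.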
+ """ + wait_time = None + + for k, v in self.clients_start_time.items(): + current_client_time_remaining = self.max_connection_time - (time.time() - v) + + if wait_time is None or current_client_time_remaining < wait_time: + wait_time = current_client_time_remaining + + return wait_time / 60 + + def recv_audio(self, websocket, transcription_queue=None, llm_queue=None, whisper_tensorrt_path=None): + """ + Receive audio chunks from a client in an infinite loop. + + Continuously receives audio frames from a connected client + over a WebSocket connection. It processes the audio frames using a + voice activity detection (VAD) model to determine if they contain speech + or not. If the audio frame contains speech, it is added to the client's + audio data for ASR. + If the maximum number of clients is reached, the method sends a + "WAIT" status to the client, indicating that they should wait + until a slot is available. + If a client's connection exceeds the maximum allowed time, it will + be disconnected, and the client's resources will be cleaned up. + + Args: + websocket (WebSocket): The WebSocket connection for the client. + + Raises: + Exception: If there is an error during the audio frame processing. + """ + self.vad_model = VoiceActivityDetection() + self.vad_threshold = 0.5 + + logging.info("[Whisper INFO:] New client connected") + options = websocket.recv() + options = json.loads(options) + + if len(self.clients) >= self.max_clients: + logging.warning("Client Queue Full. Asking client to wait ...") + wait_time = self.get_wait_time() + response = { + "uid": options["uid"], + "status": "WAIT", + "message": wait_time, + } + websocket.send(json.dumps(response)) + websocket.close() + del websocket + return + + if self.transcriber is None: + self.transcriber = WhisperTRTLLM(whisper_tensorrt_path, assets_dir="assets", device="cuda") + + client = ServeClient( + websocket, + multilingual=options["multilingual"], + language=options["language"], + task=options["task"], + client_uid=options["uid"], + transcription_queue=transcription_queue, + llm_queue=llm_queue, + transcriber=self.transcriber + ) + + self.clients[websocket] = client + self.clients_start_time[websocket] = time.time() + no_voice_activity_chunks = 0 + print() + while True: + try: + frame_data = websocket.recv() + frame_np = np.frombuffer(frame_data, dtype=np.float32) + + # VAD + try: + speech_prob = self.vad_model(torch.from_numpy(frame_np.copy()), self.RATE).item() + if speech_prob < self.vad_threshold: + no_voice_activity_chunks += 1 + if no_voice_activity_chunks > 3: + if not self.clients[websocket].eos: + self.clients[websocket].set_eos(True) + time.sleep(0.1) # EOS stop receiving frames for a 100ms(to send output to LLM.) 
+ continue + no_voice_activity_chunks = 0 + self.clients[websocket].set_eos(False) + + except Exception as e: + logging.error(e) + return + self.clients[websocket].add_frames(frame_np) + + elapsed_time = time.time() - self.clients_start_time[websocket] + if elapsed_time >= self.max_connection_time: + self.clients[websocket].disconnect() + logging.warning(f"{self.clients[websocket]} Client disconnected due to overtime.") + self.clients[websocket].cleanup() + self.clients.pop(websocket) + self.clients_start_time.pop(websocket) + websocket.close() + del websocket + break + + except Exception as e: + logging.error(e) + self.clients[websocket].cleanup() + self.clients.pop(websocket) + self.clients_start_time.pop(websocket) + logging.info("[Whisper INFO:] Connection Closed.") + del websocket + break + + def run(self, host, port=9090, transcription_queue=None, llm_queue=None, whisper_tensorrt_path=None): + """ + Run the transcription server. + + Args: + host (str): The host address to bind the server. + port (int): The port number to bind the server. + """ + with serve( + functools.partial( + self.recv_audio, + transcription_queue=transcription_queue, + llm_queue=llm_queue, + whisper_tensorrt_path=whisper_tensorrt_path + ), + host, + port + ) as server: + server.serve_forever() + + +class ServeClient: + """ + Attributes: + RATE (int): The audio sampling rate (constant) set to 16000. + SERVER_READY (str): A constant message indicating that the server is ready. + DISCONNECT (str): A constant message indicating that the client should disconnect. + client_uid (str): A unique identifier for the client. + data (bytes): Accumulated audio data. + frames (bytes): Accumulated audio frames. + language (str): The language for transcription. + task (str): The task type, e.g., "transcribe." + transcriber (WhisperModel): The Whisper model for speech-to-text. + timestamp_offset (float): The offset in audio timestamps. + frames_np (numpy.ndarray): NumPy array to store audio frames. + frames_offset (float): The offset in audio frames. + exit (bool): A flag to exit the transcription thread. + transcript (list): List of transcribed segments. + websocket: The WebSocket connection for the client. + """ + RATE = 16000 + SERVER_READY = "SERVER_READY" + DISCONNECT = "DISCONNECT" + + def __init__( + self, + websocket, + task="transcribe", + device=None, + multilingual=False, + language=None, + client_uid=None, + transcription_queue=None, + llm_queue=None, + transcriber=None + ): + """ + Initialize a ServeClient instance. + The Whisper model is initialized based on the client's language and device availability. + The transcription thread is started upon initialization. A "SERVER_READY" message is sent + to the client to indicate that the server is ready. + + Args: + websocket (WebSocket): The WebSocket connection for the client. + task (str, optional): The task type, e.g., "transcribe." Defaults to "transcribe". + device (str, optional): The device type for Whisper, "cuda" or "cpu". Defaults to None. + multilingual (bool, optional): Whether the client supports multilingual transcription. Defaults to False. + language (str, optional): The language for transcription. Defaults to None. + client_uid (str, optional): A unique identifier for the client. Defaults to None. 
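+            transcription_queue (Queue, optional): Queue used to hand the latest prompt,
+                together with the client uid and end-of-speech flag, to the downstream LLM.
+                Defaults to None.
+            llm_queue (Queue, optional): Queue from which LLM responses are read; a response
+                is forwarded to the client once its eos flag is set. Defaults to None.
+            transcriber (WhisperTRTLLM): Shared TensorRT-LLM Whisper transcriber; required,
+                a ValueError is raised if it is None.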
+ + """ + if transcriber is None: + raise ValueError("Transcriber is None.") + self.transcriber = transcriber + self.client_uid = client_uid + self.transcription_queue = transcription_queue + self.llm_queue = llm_queue + self.data = b"" + self.frames = b"" + self.task = task + self.last_prompt = None + + self.timestamp_offset = 0.0 + self.frames_np = None + self.frames_offset = 0.0 + self.exit = False + self.transcript = [] + self.prompt = None + self.segment_inference_time = [] + + # threading + self.websocket = websocket + self.lock = threading.Lock() + self.eos = False + self.trans_thread = threading.Thread(target=self.speech_to_text) + self.trans_thread.start() + + self.websocket.send( + json.dumps( + { + "uid": self.client_uid, + "message": self.SERVER_READY + } + ) + ) + + def set_eos(self, eos): + self.lock.acquire() + self.eos = eos + self.lock.release() + + def add_frames(self, frame_np): + """ + Add audio frames to the ongoing audio stream buffer. + + This method is responsible for maintaining the audio stream buffer, allowing the continuous addition + of audio frames as they are received. It also ensures that the buffer does not exceed a specified size + to prevent excessive memory usage. + + If the buffer size exceeds a threshold (45 seconds of audio data), it discards the oldest 30 seconds + of audio data to maintain a reasonable buffer size. If the buffer is empty, it initializes it with the provided + audio frame. The audio stream buffer is used for real-time processing of audio data for transcription. + + Args: + frame_np (numpy.ndarray): The audio frame data as a NumPy array. + + """ + self.lock.acquire() + if self.frames_np is not None and self.frames_np.shape[0] > 45*self.RATE: + self.frames_offset += 30.0 + self.frames_np = self.frames_np[int(30*self.RATE):] + if self.frames_np is None: + self.frames_np = frame_np.copy() + else: + self.frames_np = np.concatenate((self.frames_np, frame_np), axis=0) + self.lock.release() + + def speech_to_text(self): + """ + Process an audio stream in an infinite loop, continuously transcribing the speech. + + This method continuously receives audio frames, performs real-time transcription, and sends + transcribed segments to the client via a WebSocket connection. + + If the client's language is not detected, it waits for 30 seconds of audio input to make a language prediction. + It utilizes the Whisper ASR model to transcribe the audio, continuously processing and streaming results. Segments + are sent to the client in real-time, and a history of segments is maintained to provide context.Pauses in speech + (no output from Whisper) are handled by showing the previous output for a set duration. A blank segment is added if + there is no speech for a specified duration to indicate a pause. + + Raises: + Exception: If there is an issue with audio processing or WebSocket communication. 
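+
+        A transcription message sent to the client is a JSON object of roughly the form
+        {"uid": ..., "segments": [{"text": ...}], "eos": bool, "latency": seconds}, and the
+        same prompt is forwarded to the LLM stage via the transcription queue as
+        {"uid": ..., "prompt": ..., "eos": bool}.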
+ + """ + while True: + # send the LLM outputs + try: + llm_response = None + if self.llm_queue is not None: + while not self.llm_queue.empty(): + llm_response = self.llm_queue.get() + + if llm_response: + eos = llm_response["eos"] + if eos: + self.websocket.send(json.dumps(llm_response)) + except queue.Empty: + pass + + if self.exit: + logging.info("[Whisper INFO:] Exiting speech to text thread") + break + + if self.frames_np is None: + time.sleep(0.02) # wait for any audio to arrive + continue + + # clip audio if the current chunk exceeds 30 seconds, this basically implies that + # no valid segment for the last 30 seconds from whisper + if self.frames_np[int((self.timestamp_offset - self.frames_offset)*self.RATE):].shape[0] > 25 * self.RATE: + duration = self.frames_np.shape[0] / self.RATE + self.timestamp_offset = self.frames_offset + duration - 5 + + samples_take = max(0, (self.timestamp_offset - self.frames_offset)*self.RATE) + input_bytes = self.frames_np[int(samples_take):].copy() + duration = input_bytes.shape[0] / self.RATE + if duration<0.4: + time.sleep(0.01) # 5ms sleep to wait for some voice active audio to arrive + continue + + try: + input_sample = input_bytes.copy() + start = time.time() + mel, duration = self.transcriber.log_mel_spectrogram(input_sample) + last_segment = self.transcriber.transcribe(mel) + infer_time = time.time() - start + self.segment_inference_time.append(infer_time) + + segments = [] + if len(last_segment): + segments.append({"text": last_segment}) + try: + self.prompt = ' '.join(segment['text'] for segment in segments) + if self.last_prompt != self.prompt: + self.websocket.send( + json.dumps({ + "uid": self.client_uid, + "segments": segments, + "eos": self.eos, + "latency": infer_time + }) + ) + + self.transcription_queue.put({"uid": self.client_uid, "prompt": self.prompt, "eos": self.eos}) + if self.eos: + self.timestamp_offset += duration + logging.info(f"[Whisper INFO]: {self.prompt}, eos: {self.eos}") + logging.info( + f"[Whisper INFO]: Average inference time {sum(self.segment_inference_time) / len(self.segment_inference_time)}\n\n") + self.segment_inference_time = [] + + + + except Exception as e: + logging.error(f"[ERROR]: {e}") + + except Exception as e: + logging.error(f"[ERROR]: {e}") + + def disconnect(self): + """ + Notify the client of disconnection and send a disconnect message. + + This method sends a disconnect message to the client via the WebSocket connection to notify them + that the transcription service is disconnecting gracefully. + + """ + self.websocket.send( + json.dumps( + { + "uid": self.client_uid, + "message": self.DISCONNECT + } + ) + ) + + def cleanup(self): + """ + Perform cleanup tasks before exiting the transcription service. + + This method performs necessary cleanup tasks, including stopping the transcription thread, marking + the exit flag to indicate the transcription thread should exit gracefully, and destroying resources + associated with the transcription process. 
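+        The shared TensorRT transcriber itself is left untouched here, since it is reused
+        across client connections.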
+ + """ + logging.info("Cleaning up.") + self.exit = True + # self.transcriber.destroy() diff --git a/whisper_live/trt_transcriber.py b/whisper_live/trt_transcriber.py new file mode 100644 index 0000000000000000000000000000000000000000..bdad5e0a13c8cf3b2d301f29093173f765c261d2 --- /dev/null +++ b/whisper_live/trt_transcriber.py @@ -0,0 +1,347 @@ +import argparse +import json +import re +import time +from collections import OrderedDict +from pathlib import Path +from typing import Dict, Iterable, List, Optional, TextIO, Tuple, Union + +import torch +import numpy as np +from whisper.tokenizer import get_tokenizer +from whisper_live.whisper_utils import (mel_filters, store_transcripts, + write_error_stats, load_audio_wav_format, + pad_or_trim) + +import tensorrt_llm +import tensorrt_llm.logger as logger +from tensorrt_llm._utils import (str_dtype_to_torch, str_dtype_to_trt, + trt_dtype_to_torch) +from tensorrt_llm.runtime import ModelConfig, SamplingConfig +from tensorrt_llm.runtime.session import Session, TensorInfo + + +SAMPLE_RATE = 16000 +N_FFT = 400 +HOP_LENGTH = 160 +CHUNK_LENGTH = 30 +N_SAMPLES = CHUNK_LENGTH * SAMPLE_RATE # 480000 samples in a 30-second chunk + + +class WhisperEncoding: + + def __init__(self, engine_dir): + self.session = self.get_session(engine_dir) + + def get_session(self, engine_dir): + config_path = engine_dir / 'encoder_config.json' + with open(config_path, 'r') as f: + config = json.load(f) + + use_gpt_attention_plugin = config['plugin_config'][ + 'gpt_attention_plugin'] + dtype = config['builder_config']['precision'] + n_mels = config['builder_config']['n_mels'] + num_languages = config['builder_config']['num_languages'] + + self.dtype = dtype + self.n_mels = n_mels + self.num_languages = num_languages + + serialize_path = engine_dir / f'whisper_encoder_{self.dtype}_tp1_rank0.engine' + + with open(serialize_path, 'rb') as f: + session = Session.from_serialized_engine(f.read()) + + return session + + def get_audio_features(self, mel): + inputs = OrderedDict() + output_list = [] + + inputs.update({'x': mel}) + output_list.append( + TensorInfo('x', str_dtype_to_trt(self.dtype), mel.shape)) + + output_info = (self.session).infer_shapes(output_list) + + logger.debug(f'output info {output_info}') + outputs = { + t.name: torch.empty(tuple(t.shape), + dtype=trt_dtype_to_torch(t.dtype), + device='cuda') + for t in output_info + } + stream = torch.cuda.current_stream() + ok = self.session.run(inputs=inputs, + outputs=outputs, + stream=stream.cuda_stream) + assert ok, 'Engine execution failed' + stream.synchronize() + audio_features = outputs['output'] + return audio_features + + +class WhisperDecoding: + + def __init__(self, engine_dir, runtime_mapping, debug_mode=False): + + self.decoder_config = self.get_config(engine_dir) + self.decoder_generation_session = self.get_session( + engine_dir, runtime_mapping, debug_mode) + + def get_config(self, engine_dir): + config_path = engine_dir / 'decoder_config.json' + with open(config_path, 'r') as f: + config = json.load(f) + decoder_config = OrderedDict() + decoder_config.update(config['plugin_config']) + decoder_config.update(config['builder_config']) + return decoder_config + + def get_session(self, engine_dir, runtime_mapping, debug_mode=False): + dtype = self.decoder_config['precision'] + serialize_path = engine_dir / f'whisper_decoder_{dtype}_tp1_rank0.engine' + with open(serialize_path, "rb") as f: + decoder_engine_buffer = f.read() + + decoder_model_config = ModelConfig( + num_heads=self.decoder_config['num_heads'], + 
num_kv_heads=self.decoder_config['num_heads'], + hidden_size=self.decoder_config['hidden_size'], + vocab_size=self.decoder_config['vocab_size'], + num_layers=self.decoder_config['num_layers'], + gpt_attention_plugin=self.decoder_config['gpt_attention_plugin'], + remove_input_padding=self.decoder_config['remove_input_padding'], + cross_attention=self.decoder_config['cross_attention'], + has_position_embedding=self. + decoder_config['has_position_embedding'], + has_token_type_embedding=self. + decoder_config['has_token_type_embedding'], + ) + decoder_generation_session = tensorrt_llm.runtime.GenerationSession( + decoder_model_config, + decoder_engine_buffer, + runtime_mapping, + debug_mode=debug_mode) + + return decoder_generation_session + + def generate(self, + decoder_input_ids, + encoder_outputs, + eot_id, + max_new_tokens=40, + num_beams=1): + encoder_input_lengths = torch.tensor( + [encoder_outputs.shape[1] for x in range(encoder_outputs.shape[0])], + dtype=torch.int32, + device='cuda') + + decoder_input_lengths = torch.tensor([ + decoder_input_ids.shape[-1] + for _ in range(decoder_input_ids.shape[0]) + ], + dtype=torch.int32, + device='cuda') + decoder_max_input_length = torch.max(decoder_input_lengths).item() + + # generation config + sampling_config = SamplingConfig(end_id=eot_id, + pad_id=eot_id, + num_beams=num_beams) + self.decoder_generation_session.setup( + decoder_input_lengths.size(0), + decoder_max_input_length, + max_new_tokens, + beam_width=num_beams, + encoder_max_input_length=encoder_outputs.shape[1]) + + torch.cuda.synchronize() + + decoder_input_ids = decoder_input_ids.type(torch.int32).cuda() + output_ids = self.decoder_generation_session.decode( + decoder_input_ids, + decoder_input_lengths, + sampling_config, + encoder_output=encoder_outputs, + encoder_input_lengths=encoder_input_lengths, + ) + torch.cuda.synchronize() + + # get the list of int from output_ids tensor + output_ids = output_ids.cpu().numpy().tolist() + return output_ids + + +class WhisperTRTLLM(object): + + def __init__( + self, + engine_dir, + debug_mode=False, + assets_dir=None, + device=None + ): + world_size = 1 + runtime_rank = tensorrt_llm.mpi_rank() + runtime_mapping = tensorrt_llm.Mapping(world_size, runtime_rank) + torch.cuda.set_device(runtime_rank % runtime_mapping.gpus_per_node) + engine_dir = Path(engine_dir) + + self.encoder = WhisperEncoding(engine_dir) + self.decoder = WhisperDecoding(engine_dir, + runtime_mapping, + debug_mode=False) + self.n_mels = self.encoder.n_mels + # self.tokenizer = get_tokenizer(num_languages=self.encoder.num_languages, + # tokenizer_dir=assets_dir) + self.device = device + self.tokenizer = get_tokenizer( + False, + # num_languages=self.encoder.num_languages, + language="en", + task="transcribe", + ) + self.filters = mel_filters(self.device, self.encoder.n_mels, assets_dir) + + def log_mel_spectrogram( + self, + audio: Union[str, np.ndarray, torch.Tensor], + padding: int = 0, + return_duration = True + ): + """ + Compute the log-Mel spectrogram of + + Parameters + ---------- + audio: Union[str, np.ndarray, torch.Tensor], shape = (*) + The path to audio or either a NumPy array or Tensor containing the audio waveform in 16 kHz + + n_mels: int + The number of Mel-frequency filters, only 80 and 128 are supported + + padding: int + Number of zero samples to pad to the right + + device: Optional[Union[str, torch.device]] + If given, the audio tensor is moved to this device before STFT + + Returns + ------- + torch.Tensor, shape = (80 or 128, n_frames) + A Tensor 
that contains the Mel spectrogram + """ + if not torch.is_tensor(audio): + if isinstance(audio, str): + if audio.endswith('.wav'): + audio, _ = load_audio_wav_format(audio) + else: + audio = load_audio(audio) + assert isinstance(audio, + np.ndarray), f"Unsupported audio type: {type(audio)}" + duration = audio.shape[-1] / SAMPLE_RATE + audio = pad_or_trim(audio, N_SAMPLES) + audio = audio.astype(np.float32) + audio = torch.from_numpy(audio) + + if self.device is not None: + audio = audio.to(self.device) + if padding > 0: + audio = F.pad(audio, (0, padding)) + window = torch.hann_window(N_FFT).to(audio.device) + stft = torch.stft(audio, + N_FFT, + HOP_LENGTH, + window=window, + return_complex=True) + magnitudes = stft[..., :-1].abs()**2 + + + mel_spec = self.filters @ magnitudes + + log_spec = torch.clamp(mel_spec, min=1e-10).log10() + log_spec = torch.maximum(log_spec, log_spec.max() - 8.0) + log_spec = (log_spec + 4.0) / 4.0 + if return_duration: + return log_spec, duration + else: + return log_spec + + + def process_batch( + self, + mel, + text_prefix="<|startoftranscript|><|en|><|transcribe|><|notimestamps|>", + num_beams=1): + prompt_id = self.tokenizer.encode( + text_prefix, allowed_special=set(self.tokenizer.special_tokens.keys())) + + prompt_id = torch.tensor(prompt_id) + batch_size = mel.shape[0] + decoder_input_ids = prompt_id.repeat(batch_size, 1) + + encoder_output = self.encoder.get_audio_features(mel) + output_ids = self.decoder.generate(decoder_input_ids, + encoder_output, + self.tokenizer.eot, + max_new_tokens=96, + num_beams=num_beams) + texts = [] + for i in range(len(output_ids)): + text = self.tokenizer.decode(output_ids[i][0]).strip() + texts.append(text) + return texts + + def transcribe( + self, + mel, + text_prefix="<|startoftranscript|><|en|><|transcribe|><|notimestamps|>", + dtype='float16', + batch_size=1, + num_beams=1, + ): + mel = mel.type(str_dtype_to_torch(dtype)) + mel = mel.unsqueeze(0) + predictions = self.process_batch(mel, text_prefix, num_beams) + prediction = predictions[0] + + # remove all special tokens in the prediction + prediction = re.sub(r'<\|.*?\|>', '', prediction) + return prediction.strip() + + +def decode_wav_file( + model, + mel, + text_prefix="<|startoftranscript|><|en|><|transcribe|><|notimestamps|>", + dtype='float16', + batch_size=1, + num_beams=1, + normalizer=None, + mel_filters_dir=None): + + mel = mel.type(str_dtype_to_torch(dtype)) + mel = mel.unsqueeze(0) + # repeat the mel spectrogram to match the batch size + mel = mel.repeat(batch_size, 1, 1) + predictions = model.process_batch(mel, text_prefix, num_beams) + prediction = predictions[0] + + # remove all special tokens in the prediction + prediction = re.sub(r'<\|.*?\|>', '', prediction) + if normalizer: + prediction = normalizer(prediction) + + return prediction.strip() + + +if __name__=="__main__": + tensorrt_llm.logger.set_level("error") + model = WhisperTRTLLM("/root/TensorRT-LLM/examples/whisper/whisper_small_en", False, "../assets", device="cuda") + mel, total_duration = model.log_mel_spectrogram( + "../assets/1221-135766-0002.wav", + ) + results = model.transcribe(mel) + print(results, total_duration) \ No newline at end of file diff --git a/whisper_live/vad.py b/whisper_live/vad.py new file mode 100644 index 0000000000000000000000000000000000000000..ec93333f88876744dd004d3679c906f4191814a4 --- /dev/null +++ b/whisper_live/vad.py @@ -0,0 +1,118 @@ +# original: https://github.com/snakers4/silero-vad/blob/master/utils_vad.py + +import os +import subprocess +import torch 
+import numpy as np +import onnxruntime + + +class VoiceActivityDetection(): + + def __init__(self, force_onnx_cpu=True): + print("downloading ONNX model...") + path = self.download() + print("loading session") + + opts = onnxruntime.SessionOptions() + opts.log_severity_level = 3 + + opts.inter_op_num_threads = 1 + opts.intra_op_num_threads = 1 + + print("loading onnx model") + if force_onnx_cpu and 'CPUExecutionProvider' in onnxruntime.get_available_providers(): + self.session = onnxruntime.InferenceSession(path, providers=['CPUExecutionProvider'], sess_options=opts) + else: + self.session = onnxruntime.InferenceSession(path, providers=['CUDAExecutionProvider'], sess_options=opts) + + print("reset states") + self.reset_states() + self.sample_rates = [8000, 16000] + + def _validate_input(self, x, sr: int): + if x.dim() == 1: + x = x.unsqueeze(0) + if x.dim() > 2: + raise ValueError(f"Too many dimensions for input audio chunk {x.dim()}") + + if sr != 16000 and (sr % 16000 == 0): + step = sr // 16000 + x = x[:,::step] + sr = 16000 + + if sr not in self.sample_rates: + raise ValueError(f"Supported sampling rates: {self.sample_rates} (or multiply of 16000)") + + if sr / x.shape[1] > 31.25: + raise ValueError("Input audio chunk is too short") + + return x, sr + + def reset_states(self, batch_size=1): + self._h = np.zeros((2, batch_size, 64)).astype('float32') + self._c = np.zeros((2, batch_size, 64)).astype('float32') + self._last_sr = 0 + self._last_batch_size = 0 + + def __call__(self, x, sr: int): + + x, sr = self._validate_input(x, sr) + batch_size = x.shape[0] + + if not self._last_batch_size: + self.reset_states(batch_size) + if (self._last_sr) and (self._last_sr != sr): + self.reset_states(batch_size) + if (self._last_batch_size) and (self._last_batch_size != batch_size): + self.reset_states(batch_size) + + if sr in [8000, 16000]: + ort_inputs = {'input': x.numpy(), 'h': self._h, 'c': self._c, 'sr': np.array(sr, dtype='int64')} + ort_outs = self.session.run(None, ort_inputs) + out, self._h, self._c = ort_outs + else: + raise ValueError() + + self._last_sr = sr + self._last_batch_size = batch_size + + out = torch.tensor(out) + return out + + def audio_forward(self, x, sr: int, num_samples: int = 512): + outs = [] + x, sr = self._validate_input(x, sr) + + if x.shape[1] % num_samples: + pad_num = num_samples - (x.shape[1] % num_samples) + x = torch.nn.functional.pad(x, (0, pad_num), 'constant', value=0.0) + + self.reset_states(x.shape[0]) + for i in range(0, x.shape[1], num_samples): + wavs_batch = x[:, i:i+num_samples] + out_chunk = self.__call__(wavs_batch, sr) + outs.append(out_chunk) + + stacked = torch.cat(outs, dim=1) + return stacked.cpu() + + @staticmethod + def download(model_url="https://github.com/snakers4/silero-vad/raw/master/files/silero_vad.onnx"): + target_dir = os.path.expanduser("~/.cache/whisper-live/") + + # Ensure the target directory exists + os.makedirs(target_dir, exist_ok=True) + + # Define the target file path + model_filename = os.path.join(target_dir, "silero_vad.onnx") + + # Check if the model file already exists + if not os.path.exists(model_filename): + # If it doesn't exist, download the model using wget + print("Downloading VAD ONNX model...") + try: + subprocess.run(["wget", "-O", model_filename, model_url], check=True) + except subprocess.CalledProcessError: + print("Failed to download the model using wget.") + return model_filename diff --git a/whisper_live/whisper_utils.py b/whisper_live/whisper_utils.py new file mode 100644 index 
0000000000000000000000000000000000000000..6c30d3576f888dfd4935fdc8d1f2ef97940b50cd --- /dev/null +++ b/whisper_live/whisper_utils.py @@ -0,0 +1,365 @@ +# SPDX-FileCopyrightText: Copyright (c) 2022-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import logging +import os +from collections import defaultdict +from functools import lru_cache +from pathlib import Path +from subprocess import CalledProcessError, run +from typing import Dict, Iterable, List, Optional, TextIO, Tuple, Union + +import kaldialign +import numpy as np +import soundfile +import torch +import torch.nn.functional as F + +Pathlike = Union[str, Path] + +SAMPLE_RATE = 16000 +N_FFT = 400 +HOP_LENGTH = 160 +CHUNK_LENGTH = 30 +N_SAMPLES = CHUNK_LENGTH * SAMPLE_RATE # 480000 samples in a 30-second chunk + + +def load_audio(file: str, sr: int = SAMPLE_RATE): + """ + Open an audio file and read as mono waveform, resampling as necessary + + Parameters + ---------- + file: str + The audio file to open + + sr: int + The sample rate to resample the audio if necessary + + Returns + ------- + A NumPy array containing the audio waveform, in float32 dtype. + """ + + # This launches a subprocess to decode audio while down-mixing + # and resampling as necessary. Requires the ffmpeg CLI in PATH. + # fmt: off + cmd = [ + "ffmpeg", "-nostdin", "-threads", "0", "-i", file, "-f", "s16le", "-ac", + "1", "-acodec", "pcm_s16le", "-ar", + str(sr), "-" + ] + # fmt: on + try: + out = run(cmd, capture_output=True, check=True).stdout + except CalledProcessError as e: + raise RuntimeError(f"Failed to load audio: {e.stderr.decode()}") from e + + return np.frombuffer(out, np.int16).flatten().astype(np.float32) / 32768.0 + + +def load_audio_wav_format(wav_path): + # make sure audio in .wav format + assert wav_path.endswith( + '.wav'), f"Only support .wav format, but got {wav_path}" + waveform, sample_rate = soundfile.read(wav_path) + assert sample_rate == 16000, f"Only support 16k sample rate, but got {sample_rate}" + return waveform, sample_rate + + +def pad_or_trim(array, length: int = N_SAMPLES, *, axis: int = -1): + """ + Pad or trim the audio array to N_SAMPLES, as expected by the encoder. 
+ """ + if torch.is_tensor(array): + if array.shape[axis] > length: + array = array.index_select(dim=axis, + index=torch.arange(length, + device=array.device)) + + if array.shape[axis] < length: + pad_widths = [(0, 0)] * array.ndim + pad_widths[axis] = (0, length - array.shape[axis]) + array = F.pad(array, + [pad for sizes in pad_widths[::-1] for pad in sizes]) + else: + if array.shape[axis] > length: + array = array.take(indices=range(length), axis=axis) + + if array.shape[axis] < length: + pad_widths = [(0, 0)] * array.ndim + pad_widths[axis] = (0, length - array.shape[axis]) + array = np.pad(array, pad_widths) + + return array + + +@lru_cache(maxsize=None) +def mel_filters(device, + n_mels: int, + mel_filters_dir: str = None) -> torch.Tensor: + """ + load the mel filterbank matrix for projecting STFT into a Mel spectrogram. + Allows decoupling librosa dependency; saved using: + + np.savez_compressed( + "mel_filters.npz", + mel_80=librosa.filters.mel(sr=16000, n_fft=400, n_mels=80), + ) + """ + assert n_mels in {80, 128}, f"Unsupported n_mels: {n_mels}" + if mel_filters_dir is None: + mel_filters_path = os.path.join(os.path.dirname(__file__), "assets", + "mel_filters.npz") + else: + mel_filters_path = os.path.join(mel_filters_dir, "mel_filters.npz") + with np.load(mel_filters_path) as f: + return torch.from_numpy(f[f"mel_{n_mels}"]).to(device) + + +def log_mel_spectrogram( + audio: Union[str, np.ndarray, torch.Tensor], + n_mels: int, + padding: int = 0, + device: Optional[Union[str, torch.device]] = None, + return_duration: bool = False, + mel_filters_dir: str = None, +): + """ + Compute the log-Mel spectrogram of + + Parameters + ---------- + audio: Union[str, np.ndarray, torch.Tensor], shape = (*) + The path to audio or either a NumPy array or Tensor containing the audio waveform in 16 kHz + + n_mels: int + The number of Mel-frequency filters, only 80 and 128 are supported + + padding: int + Number of zero samples to pad to the right + + device: Optional[Union[str, torch.device]] + If given, the audio tensor is moved to this device before STFT + + Returns + ------- + torch.Tensor, shape = (80 or 128, n_frames) + A Tensor that contains the Mel spectrogram + """ + if not torch.is_tensor(audio): + if isinstance(audio, str): + if audio.endswith('.wav'): + audio, _ = load_audio_wav_format(audio) + else: + audio = load_audio(audio) + assert isinstance(audio, + np.ndarray), f"Unsupported audio type: {type(audio)}" + duration = audio.shape[-1] / SAMPLE_RATE + audio = pad_or_trim(audio, N_SAMPLES) + audio = audio.astype(np.float32) + audio = torch.from_numpy(audio) + + if device is not None: + audio = audio.to(device) + if padding > 0: + audio = F.pad(audio, (0, padding)) + window = torch.hann_window(N_FFT).to(audio.device) + stft = torch.stft(audio, + N_FFT, + HOP_LENGTH, + window=window, + return_complex=True) + magnitudes = stft[..., :-1].abs()**2 + + filters = mel_filters(audio.device, n_mels, mel_filters_dir) + mel_spec = filters @ magnitudes + + log_spec = torch.clamp(mel_spec, min=1e-10).log10() + log_spec = torch.maximum(log_spec, log_spec.max() - 8.0) + log_spec = (log_spec + 4.0) / 4.0 + if return_duration: + return log_spec, duration + else: + return log_spec + + +def store_transcripts(filename: Pathlike, texts: Iterable[Tuple[str, str, + str]]) -> None: + """Save predicted results and reference transcripts to a file. + https://github.com/k2-fsa/icefall/blob/master/icefall/utils.py + Args: + filename: + File to save the results to. + texts: + An iterable of tuples. 
The first element is the cur_id, the second is + the reference transcript and the third element is the predicted result. + Returns: + Return None. + """ + with open(filename, "w") as f: + for cut_id, ref, hyp in texts: + print(f"{cut_id}:\tref={ref}", file=f) + print(f"{cut_id}:\thyp={hyp}", file=f) + + +def write_error_stats( + f: TextIO, + test_set_name: str, + results: List[Tuple[str, str]], + enable_log: bool = True, +) -> float: + """Write statistics based on predicted results and reference transcripts. + https://github.com/k2-fsa/icefall/blob/master/icefall/utils.py + It will write the following to the given file: + + - WER + - number of insertions, deletions, substitutions, corrects and total + reference words. For example:: + + Errors: 23 insertions, 57 deletions, 212 substitutions, over 2606 + reference words (2337 correct) + + - The difference between the reference transcript and predicted result. + An instance is given below:: + + THE ASSOCIATION OF (EDISON->ADDISON) ILLUMINATING COMPANIES + + The above example shows that the reference word is `EDISON`, + but it is predicted to `ADDISON` (a substitution error). + + Another example is:: + + FOR THE FIRST DAY (SIR->*) I THINK + + The reference word `SIR` is missing in the predicted + results (a deletion error). + results: + An iterable of tuples. The first element is the cur_id, the second is + the reference transcript and the third element is the predicted result. + enable_log: + If True, also print detailed WER to the console. + Otherwise, it is written only to the given file. + Returns: + Return None. + """ + subs: Dict[Tuple[str, str], int] = defaultdict(int) + ins: Dict[str, int] = defaultdict(int) + dels: Dict[str, int] = defaultdict(int) + + # `words` stores counts per word, as follows: + # corr, ref_sub, hyp_sub, ins, dels + words: Dict[str, List[int]] = defaultdict(lambda: [0, 0, 0, 0, 0]) + num_corr = 0 + ERR = "*" + for cut_id, ref, hyp in results: + ali = kaldialign.align(ref, hyp, ERR) + for ref_word, hyp_word in ali: + if ref_word == ERR: + ins[hyp_word] += 1 + words[hyp_word][3] += 1 + elif hyp_word == ERR: + dels[ref_word] += 1 + words[ref_word][4] += 1 + elif hyp_word != ref_word: + subs[(ref_word, hyp_word)] += 1 + words[ref_word][1] += 1 + words[hyp_word][2] += 1 + else: + words[ref_word][0] += 1 + num_corr += 1 + ref_len = sum([len(r) for _, r, _ in results]) + sub_errs = sum(subs.values()) + ins_errs = sum(ins.values()) + del_errs = sum(dels.values()) + tot_errs = sub_errs + ins_errs + del_errs + tot_err_rate = "%.2f" % (100.0 * tot_errs / ref_len) + + if enable_log: + logging.info(f"[{test_set_name}] %WER {tot_errs / ref_len:.2%} " + f"[{tot_errs} / {ref_len}, {ins_errs} ins, " + f"{del_errs} del, {sub_errs} sub ]") + + print(f"%WER = {tot_err_rate}", file=f) + print( + f"Errors: {ins_errs} insertions, {del_errs} deletions, " + f"{sub_errs} substitutions, over {ref_len} reference " + f"words ({num_corr} correct)", + file=f, + ) + print( + "Search below for sections starting with PER-UTT DETAILS:, " + "SUBSTITUTIONS:, DELETIONS:, INSERTIONS:, PER-WORD STATS:", + file=f, + ) + + print("", file=f) + print("PER-UTT DETAILS: corr or (ref->hyp) ", file=f) + for cut_id, ref, hyp in results: + ali = kaldialign.align(ref, hyp, ERR) + combine_successive_errors = True + if combine_successive_errors: + ali = [[[x], [y]] for x, y in ali] + for i in range(len(ali) - 1): + if ali[i][0] != ali[i][1] and ali[i + 1][0] != ali[i + 1][1]: + ali[i + 1][0] = ali[i][0] + ali[i + 1][0] + ali[i + 1][1] = ali[i][1] + ali[i + 1][1] + 
ali[i] = [[], []] + ali = [[ + list(filter(lambda a: a != ERR, x)), + list(filter(lambda a: a != ERR, y)), + ] for x, y in ali] + ali = list(filter(lambda x: x != [[], []], ali)) + ali = [[ + ERR if x == [] else " ".join(x), + ERR if y == [] else " ".join(y), + ] for x, y in ali] + + print( + f"{cut_id}:\t" + " ".join((ref_word if ref_word == hyp_word else + f"({ref_word}->{hyp_word})" + for ref_word, hyp_word in ali)), + file=f, + ) + + print("", file=f) + print("SUBSTITUTIONS: count ref -> hyp", file=f) + + for count, (ref, hyp) in sorted([(v, k) for k, v in subs.items()], + reverse=True): + print(f"{count} {ref} -> {hyp}", file=f) + + print("", file=f) + print("DELETIONS: count ref", file=f) + for count, ref in sorted([(v, k) for k, v in dels.items()], reverse=True): + print(f"{count} {ref}", file=f) + + print("", file=f) + print("INSERTIONS: count hyp", file=f) + for count, hyp in sorted([(v, k) for k, v in ins.items()], reverse=True): + print(f"{count} {hyp}", file=f) + + print("", file=f) + print("PER-WORD STATS: word corr tot_errs count_in_ref count_in_hyp", + file=f) + for _, word, counts in sorted([(sum(v[1:]), k, v) + for k, v in words.items()], + reverse=True): + (corr, ref_sub, hyp_sub, ins, dels) = counts + tot_errs = ref_sub + hyp_sub + ins + dels + ref_count = corr + ref_sub + dels + hyp_count = corr + hyp_sub + ins + + print(f"{word} {corr} {tot_errs} {ref_count} {hyp_count}", file=f) + return float(tot_err_rate) diff --git a/whisperfusion b/whisperfusion new file mode 160000 index 0000000000000000000000000000000000000000..1de4c740954848883f911e6c97e1db105b999b82 --- /dev/null +++ b/whisperfusion @@ -0,0 +1 @@ +Subproject commit 1de4c740954848883f911e6c97e1db105b999b82
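
The Silero-based VoiceActivityDetection class added in whisper_live/vad.py is only exercised indirectly by the rest of the patch. A minimal standalone sketch of how it could be driven, assuming a 16 kHz mono WAV at ../assets/1221-135766-0002.wav (the sample already referenced in trt_transcriber.py's __main__ block) and an illustrative 0.5 speech-probability threshold:

```python
# Sketch only: run the ONNX Silero VAD over a whole file and report how much of it
# is speech. The input path and the 0.5 threshold are assumptions for illustration.
import torch

from whisper_live.vad import VoiceActivityDetection
from whisper_live.whisper_utils import load_audio_wav_format

waveform, sr = load_audio_wav_format("../assets/1221-135766-0002.wav")  # 16 kHz mono
vad = VoiceActivityDetection()  # fetches silero_vad.onnx into ~/.cache/whisper-live/

# audio_forward pads the waveform to a multiple of 512 samples and returns one
# speech probability per 512-sample window, shape (batch, n_windows).
probs = vad.audio_forward(torch.from_numpy(waveform).float(), sr)
speech_ratio = (probs > 0.5).float().mean().item()
print(f"speech detected in {speech_ratio:.0%} of 512-sample windows")
```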
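
whisper_utils.write_error_stats derives its insertion/deletion/substitution counts from kaldialign alignments, as the EDISON->ADDISON and SIR examples in its docstring describe. A tiny worked example of that alignment step, with invented reference/hypothesis sentences, shows where those counts come from; only the kaldialign.align(ref, hyp, "*") call mirrors the patch's actual usage.

```python
# Illustration of the alignment underlying write_error_stats; the sentences are
# made up for this sketch.
import kaldialign

ref = "FOR THE FIRST DAY SIR I THINK".split()
hyp = "FOR THE FIRST DAY I THINK".split()

ERR = "*"
ali = kaldialign.align(ref, hyp, ERR)
# -> [..., ('DAY', 'DAY'), ('SIR', '*'), ('I', 'I'), ...]  i.e. one deletion
errs = sum(1 for r, h in ali if r != h)
print(f"%WER = {100.0 * errs / len(ref):.2f}")  # 1 error over 7 reference words
```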