# Streaming transcription server built on whisper_streaming.
#
# Exposes a FastAPI WebSocket endpoint that receives raw 16 kHz mono PCM16
# audio chunks and streams incremental transcription results back as JSON.

# Import the necessary components from whisper_online.py
import argparse
import io
import logging
import os
import sys
from typing import Optional

import librosa
import numpy as np
import requests
import soundfile
import uvicorn
import wave
from fastapi import FastAPI, WebSocket
from pydantic import BaseModel, ConfigDict
from starlette.websockets import WebSocketDisconnect

from libs.whisper_streaming.whisper_online import (
    ASRBase,
    OnlineASRProcessor,
    VACOnlineASRProcessor,
    add_shared_args,
    asr_factory,
    set_logging,
    create_tokenizer,
    load_audio,
    load_audio_chunk,
    OpenaiApiASR,
)

# from libs.whisper_streaming.whisper_online_server import online

logger = logging.getLogger(__name__)

SAMPLING_RATE = 16000  # Whisper models expect 16 kHz mono input.
WARMUP_FILE = "mono16k.test_hebrew.wav"
AUDIO_FILE_URL = "https://raw.githubusercontent.com/AshDavid12/runpod-serverless-forked/main/test_hebrew.wav"

app = FastAPI()

# NOTE: `args` starts life as the ArgumentParser; main() replaces it with the
# parsed Namespace before the server starts (the websocket handler reads
# attributes such as `args.min_chunk_size` from the Namespace).
args = argparse.ArgumentParser()
add_shared_args(args)


def drop_option_from_parser(parser, option_name):
    """
    Reinitializes the parser and copies all options except the specified option.

    Args:
        parser (argparse.ArgumentParser): The original argument parser.
        option_name (str): The option string to drop (e.g., '--model').

    Returns:
        argparse.ArgumentParser: A new parser without the specified option.
    """
    # Create a new parser with the same description and other attributes.
    new_parser = argparse.ArgumentParser(
        description=parser.description,
        epilog=parser.epilog,
        formatter_class=parser.formatter_class,
    )

    # NOTE(review): this relies on argparse's private `_actions`/`_add_action`
    # API; there is no public way to copy actions between parsers.
    for action in parser._actions:
        # Skip the auto-generated help action: the new parser adds its own.
        if "-h" in action.option_strings:
            continue
        # Copy every action except the one being dropped.
        if option_name not in action.option_strings:
            new_parser._add_action(action)
    return new_parser


def convert_to_mono_16k(input_wav: str, output_wav: str) -> None:
    """
    Converts any .wav file to mono 16 kHz.

    Args:
        input_wav (str): Path to the input .wav file.
        output_wav (str): Path to save the output .wav file with mono 16 kHz.
    """
    # Step 1: Load the audio at its original sampling rate, keeping channels.
    audio_data, original_sr = librosa.load(input_wav, sr=None, mono=False)
    logger.info("Loaded audio with shape: %s, original sampling rate: %d" % (audio_data.shape, original_sr))

    # Step 2: If the audio has multiple channels, average them to make it mono.
    if audio_data.ndim > 1:
        audio_data = librosa.to_mono(audio_data)

    # Step 3: Resample the audio to 16 kHz.
    resampled_audio = librosa.resample(audio_data, orig_sr=original_sr, target_sr=SAMPLING_RATE)

    # Step 4: Save the resampled audio as a .wav file in mono at 16 kHz.
    # BUG FIX: the original called `sf.write(...)`, but the module is imported
    # as `soundfile` (no `sf` alias), which raised NameError at runtime.
    soundfile.write(output_wav, resampled_audio, SAMPLING_RATE)
    logger.info(f"Converted audio saved to {output_wav}")


def download_warmup_file():
    """Download the warm-up audio if needed and convert it to mono 16 kHz."""
    # Download the audio file if not already present.
    audio_file_path = "test_hebrew.wav"
    if not os.path.exists(WARMUP_FILE):
        if not os.path.exists(audio_file_path):
            # NOTE(review): the HTTP status is not checked; a failed download
            # would write an error body to disk. Consider raise_for_status().
            response = requests.get(AUDIO_FILE_URL)
            with open(audio_file_path, 'wb') as f:
                f.write(response.content)
        convert_to_mono_16k(audio_file_path, WARMUP_FILE)


class State(BaseModel):
    """Per-connection transcription state."""

    model_config = ConfigDict(arbitrary_types_allowed=True)

    websocket: WebSocket
    asr: ASRBase
    online: OnlineASRProcessor
    # Minimum number of samples to accumulate before running an ASR iteration.
    min_limit: int
    # True until the first sufficiently large chunk has been received.
    is_first: bool = True
    # End timestamp (ms) of the previously emitted segment, for monotonic output.
    last_end: Optional[float] = None


async def receive_audio_chunk(state: State) -> Optional[np.ndarray]:
    """Accumulate at least `state.min_limit` samples from the WebSocket.

    Returns the concatenated float32 audio, or None when the connection sent
    nothing (or the very first chunk is still too small).
    """
    # receive all audio that is available by this time
    # blocks operation if less than self.min_chunk seconds is available
    # unblocks if connection is closed or a chunk is available
    out = []
    while sum(len(x) for x in out) < state.min_limit:
        raw_bytes = await state.websocket.receive_bytes()
        if not raw_bytes:
            break
        # print("received audio:",len(raw_bytes), "bytes", raw_bytes[:10])
        # Interpret the bytes as headerless little-endian 16-bit PCM mono.
        sf = soundfile.SoundFile(
            io.BytesIO(raw_bytes),
            channels=1,
            endian="LITTLE",
            samplerate=SAMPLING_RATE,
            subtype="PCM_16",
            format="RAW",
        )
        audio, _ = librosa.load(sf, sr=SAMPLING_RATE, dtype=np.float32)
        out.append(audio)
    if not out:
        return None
    flat_out = np.concatenate(out)
    # Drop an undersized very first chunk; wait until enough audio arrives.
    if state.is_first and len(flat_out) < state.min_limit:
        return None
    state.is_first = False
    return flat_out


def format_output_transcript(state, o) -> Optional[dict]:
    """Convert one ASR output tuple (beg, end, text) into a JSON-able dict.

    Returns None when the segment carries no text.
    """
    # output format in stdout is like:
    # 0 1720 Takhle to je
    # - the first two words are:
    #    - beg and end timestamp of the text segment, as estimated by Whisper
    #      model. The timestamps are not accurate, but they're useful anyway
    # - the next words: segment transcript
    # This function differs from whisper_online.output_transcript in the following:
    # succeeding [beg,end] intervals are not overlapping because ELITR protocol
    # (implemented in online-text-flow events) requires it.
    # Therefore, beg, is max of previous end and current beg outputed by Whisper.
    # Usually it differs negligibly, by appx 20 ms.
    if o[0] is not None:
        beg, end = o[0] * 1000, o[1] * 1000
        if state.last_end is not None:
            beg = max(beg, state.last_end)
        state.last_end = end
        print("%1.0f %1.0f %s" % (beg, end, o[2]), flush=True, file=sys.stderr)
        return {
            "start": "%1.0f" % beg,
            "end": "%1.0f" % end,
            "text": "%s" % o[2],
        }
    else:
        logger.debug("No text in this segment")
        return None


# Define WebSocket endpoint
@app.websocket("/ws_transcribe_streaming")
async def websocket_transcribe(websocket: WebSocket):
    """Accept a WebSocket connection and stream transcription results back."""
    logger.info("New WebSocket connection request received.")
    await websocket.accept()
    logger.info("WebSocket connection established successfully.")

    # initialize the ASR model
    logger.info("Loading whisper model...")
    asr, online = asr_factory(args)
    state = State(
        websocket=websocket,
        asr=asr,
        online=online,
        min_limit=args.min_chunk_size * SAMPLING_RATE,
    )

    # warm up the ASR because the very first transcribe takes more time than the others.
    # Test results in https://github.com/ufal/whisper_streaming/pull/81
    logger.info("Warming up the whisper model...")
    # NOTE(review): WARMUP_FILE is assumed to exist here; download_warmup_file()
    # is never called in this module — confirm deployment provides the file.
    a = load_audio_chunk(WARMUP_FILE, 0, 1)
    asr.transcribe(a)
    logger.info("Whisper is warmed up.")

    try:
        while True:
            a = await receive_audio_chunk(state)
            if a is None:
                break
            state.online.insert_audio_chunk(a)
            # Use state.online consistently (same object as the local `online`).
            o = state.online.process_iter()
            try:
                if result := format_output_transcript(state, o):
                    await websocket.send_json(result)
            except BrokenPipeError:
                logger.info("broken pipe -- connection closed?")
                break
            except WebSocketDisconnect:
                logger.info("WebSocket connection closed by the client.")
                break
    except Exception as e:
        logger.error(f"Unexpected error during WebSocket transcription: {e}")
        await websocket.send_json({"error": str(e)})
    finally:
        logger.info("Cleaning up and closing WebSocket connection.")


def main():
    """Build the CLI configuration and launch the uvicorn server."""
    global args
    # Replace the shared '--model' choice option with a free-form string so a
    # HuggingFace model id can be passed.
    parser = drop_option_from_parser(args, '--model')
    parser.add_argument(
        '--model',
        type=str,
        help="Name size of the Whisper model to use. The model is automatically downloaded from the model hub if not present in model cache dir.",
    )
    # BUG FIX: the original discarded the Namespace returned by parse_args(),
    # leaving the global `args` as the parser itself; the websocket handler
    # then failed when reading `args.min_chunk_size`. Store the parsed result.
    args = parser.parse_args([
        '--lan', 'he',
        '--model', 'ivrit-ai/faster-whisper-v2-d4',
        '--backend', 'faster-whisper',
        '--vad',
        # '--vac',
        '--buffer_trimming', 'segment',
        '--buffer_trimming_sec', '15',
        '--min_chunk_size', '1.0',
        '--vac_chunk_size', '0.04',
        '--start_at', '0.0',
        '--offline',
        '--comp_unaware',
        '--log_level', 'DEBUG',
    ])
    uvicorn.run(app)


if __name__ == "__main__":
    main()