import os
import gc
import logging
import time
import threading

import fastapi
from fastapi.responses import FileResponse
import numpy as np
import torch
import torchaudio
from silero_vad import get_speech_timestamps, load_silero_vad
import whisperx
import edge_tts
from openai import OpenAI

# Configure logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')

# Configure FastAPI
app = fastapi.FastAPI()

# Load Silero VAD model
device = 'cuda' if torch.cuda.is_available() else 'cpu'
logging.info(f'Using device: {device}')
vad_model = load_silero_vad().to(device)  # Ensure the model is on the correct device
logging.info('Loaded Silero VAD model')

# Load WhisperX model (float16 is only supported on GPU)
compute_type = "float16" if device == 'cuda' else "int8"
whisper_model = whisperx.load_model("tiny", device, compute_type=compute_type)
logging.info('Loaded WhisperX model')

OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
if not OPENAI_API_KEY:
    logging.error("OpenAI API key not found. Please set the OPENAI_API_KEY environment variable.")
    raise ValueError("OpenAI API key not found.")

# Initialize OpenAI client
openai_client = OpenAI(api_key=OPENAI_API_KEY)
logging.info('Initialized OpenAI client')

# TTS Voice
TTS_VOICE = "en-GB-SoniaNeural"


# Function to check voice activity using Silero VAD
def check_vad(audio_data, sample_rate):
    logging.info('Checking voice activity')
    # Resample to 16000 Hz if necessary
    target_sample_rate = 16000
    if sample_rate != target_sample_rate:
        resampler = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=target_sample_rate)
        audio_tensor = resampler(torch.from_numpy(audio_data))
    else:
        audio_tensor = torch.from_numpy(audio_data)
    audio_tensor = audio_tensor.to(device)

    # Log audio data details
    logging.info(f'Audio tensor shape: {audio_tensor.shape}, dtype: {audio_tensor.dtype}, device: {audio_tensor.device}')

    # Get speech timestamps
    speech_timestamps = get_speech_timestamps(audio_tensor, vad_model, sampling_rate=target_sample_rate)
    logging.info(f'Found {len(speech_timestamps)} speech timestamps')
    return len(speech_timestamps) > 0


# Function to transcribe audio using WhisperX
def transcript(audio_data, sample_rate):
    logging.info('Transcribing audio')
    # Resample to 16000 Hz if necessary
    target_sample_rate = 16000
    if sample_rate != target_sample_rate:
        resampler = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=target_sample_rate)
        audio_data = resampler(torch.from_numpy(audio_data)).numpy()

    # Transcribe
    batch_size = 16  # Adjust as needed
    result = whisper_model.transcribe(audio_data, batch_size=batch_size)
    text = " ".join(segment["text"].strip() for segment in result["segments"])
    logging.info(f'Transcription result: {text}')

    # Clear GPU memory
    del result
    gc.collect()
    if device == 'cuda':
        torch.cuda.empty_cache()

    return text


# Function to get streaming response from OpenAI API
def llm(text):
    logging.info('Getting response from OpenAI API')
    response = openai_client.chat.completions.create(
        model="gpt-4o",  # Updated to a more recent model
        messages=[
            {"role": "system", "content": "You respond to the following transcript from the conversation that you are having with the user."},
            {"role": "user", "content": text},
        ],
        stream=True,
        temperature=0.7,  # Optional: Adjust as needed
        top_p=0.9,  # Optional: Adjust as needed
    )
    for chunk in response:
        yield chunk.choices[0].delta.content


# Function to perform TTS per sentence using Edge-TTS
def tts_streaming(text_stream):
    logging.info('Performing TTS')
    buffer = ""
    punctuation = {'.', '!', '?'}
    for text_chunk in text_stream:
        if text_chunk is not None:
            buffer += text_chunk
            # Check for sentence completion
            sentences = []
            start = 0
            for i, char in enumerate(buffer):
                if char in punctuation:
                    sentences.append(buffer[start:i + 1].strip())
                    start = i + 1
            buffer = buffer[start:]

            for sentence in sentences:
                if sentence:
                    communicate = edge_tts.Communicate(sentence, TTS_VOICE)
                    for chunk in communicate.stream_sync():
                        if chunk["type"] == "audio":
                            yield chunk["data"]

    # Process any remaining text
    if buffer.strip():
        communicate = edge_tts.Communicate(buffer.strip(), TTS_VOICE)
        for chunk in communicate.stream_sync():
            if chunk["type"] == "audio":
                yield chunk["data"]


# Function to handle LLM and TTS
def llm_and_tts(transcribed_text, state):
    logging.info('Handling LLM and TTS')
    # Get streaming response from LLM
    for text_chunk in llm(transcribed_text):
        if state.get('stop_signal'):
            logging.info('LLM and TTS task stopped')
            break
        # Get audio data from TTS
        for audio_chunk in tts_streaming([text_chunk]):
            if state.get('stop_signal'):
                logging.info('LLM and TTS task stopped during TTS')
                break
            # Reinterpret the TTS byte stream as 16-bit samples
            yield np.frombuffer(audio_chunk, dtype=np.int16)


state = {
    'mode': 'idle',
    'chunk_queue': [],
    'transcription': '',
    'in_transcription': False,
    'previous_no_vad_audio': None,
    'llm_task': None,
    'instream': None,
    'stop_signal': False,
    'args': {
        'sample_rate': 16000,
        'chunk_size': 0.5,  # seconds
        'transcript_chunk_size': 2,  # seconds
    },
}


def transcript_loop():
    while True:
        if len(state['chunk_queue']) > 0:
            accumulated_audio = np.concatenate(state['chunk_queue'])
            total_samples = sum(len(chunk) for chunk in state['chunk_queue'])
            total_duration = total_samples / state['args']['sample_rate']

            # Run transcription on the first 2 seconds if more than 3 seconds have accumulated
            if total_duration > 3.0 and state['in_transcription']:
                first_two_seconds_samples = int(2.0 * state['args']['sample_rate'])
                first_two_seconds_audio = accumulated_audio[:first_two_seconds_samples]
                transcribed_text = transcript(first_two_seconds_audio, state['args']['sample_rate'])
                state['transcription'] += transcribed_text
                remaining_audio = accumulated_audio[first_two_seconds_samples:]
                # Mutate the queue in place so references held elsewhere stay valid
                state['chunk_queue'][:] = [remaining_audio]
            else:
                # Run transcription on the accumulated audio
                transcribed_text = transcript(accumulated_audio, state['args']['sample_rate'])
                state['transcription'] += transcribed_text
                state['chunk_queue'].clear()
                state['in_transcription'] = False
        else:
            time.sleep(0.1)

        if len(state['chunk_queue']) == 0 and state['mode'] in ('idle', 'processing'):
            state['in_transcription'] = False
            break


def process_audio(audio_chunk):
    # Yields output audio
    sample_rate, audio_data = audio_chunk
    audio_data = np.array(audio_data, dtype=np.float32)

    # Convert to mono if necessary
    if audio_data.ndim > 1:
        audio_data = np.mean(audio_data, axis=1)

    mode = state['mode']
    chunk_queue = state['chunk_queue']
    transcription = state['transcription']
    in_transcription = state['in_transcription']
    previous_no_vad_audio = state['previous_no_vad_audio']
    llm_task = state['llm_task']
    instream = state['instream']
    stop_signal = state['stop_signal']
    args = state['args']

    args['sample_rate'] = sample_rate

    # Check for voice activity
    vad = check_vad(audio_data, sample_rate)

    if vad:
        logging.info(f'Voice activity detected in mode: {mode}')
        if mode == 'idle':
            mode = 'listening'
        elif mode == 'speaking':
            # Stop the LLM and TTS pipeline; llm_and_tts is a generator, so it can simply be closed
            if llm_task is not None:
                logging.info('Stopping LLM and TTS tasks')
                stop_signal = True
                state['stop_signal'] = True
                llm_task.close()
                llm_task = None
                mode = 'listening'

        if mode == 'listening':
            if previous_no_vad_audio is not None:
                chunk_queue.append(previous_no_vad_audio)
                previous_no_vad_audio = None
            # Accumulate audio chunks
            chunk_queue.append(audio_data)

            # Start transcription thread if not already running
            if not in_transcription:
                in_transcription = True
                state['in_transcription'] = True
                transcription_task = threading.Thread(target=transcript_loop)
                transcription_task.start()

        elif mode == 'speaking':
            # Continue accumulating audio chunks
            chunk_queue.append(audio_data)
    else:
        logging.info(f'No voice activity detected in mode: {mode}')
        if mode == 'listening':
            # Add the last chunk to the queue
            chunk_queue.append(audio_data)

            # Change mode to processing
            mode = 'processing'

            # Wait for the transcription thread to drain the queue
            while state['in_transcription']:
                time.sleep(0.1)
            in_transcription = False
            transcription = state['transcription']

            # Check if transcription is complete
            if len(chunk_queue) == 0:
                # Start the LLM and TTS pipeline (llm_and_tts is a generator)
                if llm_task is None:
                    stop_signal = False
                    state['stop_signal'] = False
                    llm_task = llm_and_tts(transcription, state)

        if mode == 'processing':
            # Move on once the LLM and TTS pipeline is ready to yield audio
            if llm_task is not None:
                mode = 'responding'

        if mode == 'responding':
            for audio_chunk in llm_task:
                if instream is None:
                    instream = audio_chunk
                else:
                    instream = np.concatenate((instream, audio_chunk))

                # Send audio to output stream
                yield instream

            # Cleanup
            llm_task = None
            transcription = ''
            mode = 'idle'

            # Update state
            state['mode'] = mode
            state['chunk_queue'] = chunk_queue
            state['transcription'] = transcription
            state['in_transcription'] = in_transcription
            state['previous_no_vad_audio'] = previous_no_vad_audio
            state['llm_task'] = llm_task
            state['instream'] = instream
            state['stop_signal'] = stop_signal
            state['args'] = args

        # Store previous audio chunk with no voice activity
        previous_no_vad_audio = audio_data

    # Update state
    state['mode'] = mode
    state['chunk_queue'] = chunk_queue
    state['transcription'] = transcription
    state['in_transcription'] = in_transcription
    state['previous_no_vad_audio'] = previous_no_vad_audio
    state['llm_task'] = llm_task
    state['instream'] = instream
    state['stop_signal'] = stop_signal
    state['args'] = args


@app.websocket('/ws')
async def websocket_endpoint(websocket: fastapi.WebSocket):
    await websocket.accept()
    logging.info('WebSocket connection established')
    try:
        while True:
            # receive_bytes waits until the client sends the next chunk
            audio_chunk = await websocket.receive_bytes()
            if audio_chunk is None:
                break

            # Assumes the client sends raw 16-bit mono PCM at the configured sample rate
            audio_data = np.frombuffer(audio_chunk, dtype=np.int16).astype(np.float32) / 32768.0
            for output_audio in process_audio((state['args']['sample_rate'], audio_data)):
                await websocket.send_bytes(output_audio.tobytes())
    except Exception as e:
        logging.error(f'WebSocket error: {e}')
    finally:
        logging.info('WebSocket connection closed')
        await websocket.close()


@app.get('/')
def index():
    return FileResponse('index.html')
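

# A minimal sketch for running the app locally; assumes uvicorn is installed
# (e.g. `pip install uvicorn`) and that index.html sits next to this file.
if __name__ == "__main__":
    import uvicorn

    # Serve the FastAPI app on all interfaces, port 8000
    uvicorn.run(app, host="0.0.0.0", port=8000)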