{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/home/barrel/aai/.venv/lib/python3.10/site-packages/pyannote/audio/core/io.py:43: UserWarning: torchaudio._backend.set_audio_backend has been deprecated. With dispatcher enabled, this function is no-op. You can remove the function call.\n",
" torchaudio.set_audio_backend(\"soundfile\")\n"
]
}
],
"source": [
"import gradio as gr\n",
"import numpy as np\n",
"import torch\n",
"import torchaudio\n",
"from silero_vad import get_speech_timestamps, load_silero_vad\n",
"import whisperx\n",
"import openai\n",
"import asyncio\n",
"import edge_tts\n",
"import gc\n",
"import logging\n",
"import time"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"2024-09-23 13:50:24,408 - INFO - Using device: cuda\n",
"2024-09-23 13:50:24,660 - INFO - Loaded Silero VAD model\n",
"Lightning automatically upgraded your loaded checkpoint from v1.5.4 to v2.4.0. To apply the upgrade to your files permanently, run `python -m pytorch_lightning.utilities.upgrade_checkpoint ../.cache/torch/whisperx-vad-segmentation.bin`\n",
"2024-09-23 13:50:24,994 - INFO - Loaded WhisperX model\n",
"2024-09-23 13:50:24,994 - INFO - Set OpenAI API key\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"No language specified, language will be first be detected for each audio file (increases inference time).\n",
"Model was trained with pyannote.audio 0.0.1, yours is 3.1.1. Bad things might happen unless you revert pyannote.audio to 0.x.\n",
"Model was trained with torch 1.10.0+cu102, yours is 2.3.1+cu121. Bad things might happen unless you revert torch to 1.x.\n"
]
}
],
"source": [
"# Configure logging\n",
"logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')\n",
"\n",
"# Load Silero VAD model\n",
"device = 'cuda' if torch.cuda.is_available() else 'cpu'\n",
"logging.info(f'Using device: {device}')\n",
"vad_model = load_silero_vad().to(device) # Ensure the model is on the correct device\n",
"logging.info('Loaded Silero VAD model')\n",
"\n",
"# Load WhisperX model\n",
"whisper_model = whisperx.load_model(\"tiny\", device, compute_type=\"float16\")\n",
"logging.info('Loaded WhisperX model')\n",
"\n",
"openai.api_key = \"\"\n",
"logging.info('Set OpenAI API key')\n",
"\n",
"# TTS Voice\n",
"TTS_VOICE = \"en-GB-SoniaNeural\""
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"import torchaudio\n",
"import logging\n",
"\n",
"def check_vad(audio_data, sample_rate):\n",
" logging.info('Checking voice activity')\n",
" # Resample to 16000 Hz if necessary\n",
" target_sample_rate = 16000\n",
" if sample_rate != target_sample_rate:\n",
" resampler = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=target_sample_rate)\n",
" audio_tensor = resampler(torch.from_numpy(audio_data))\n",
" else:\n",
" audio_tensor = torch.from_numpy(audio_data)\n",
" audio_tensor = audio_tensor.to(device)\n",
"\n",
" # Log audio data details\n",
" logging.info(f'Audio tensor shape: {audio_tensor.shape}, dtype: {audio_tensor.dtype}, device: {audio_tensor.device}')\n",
"\n",
" # Get speech timestamps with optimized parameters\n",
" speech_timestamps = get_speech_timestamps(\n",
" audio=audio_tensor,\n",
" model=vad_model,\n",
" sampling_rate=target_sample_rate,\n",
" min_speech_duration_ms=250,\n",
" min_silence_duration_ms=80,\n",
" speech_pad_ms=30\n",
" )\n",
" logging.info(f'Found {len(speech_timestamps)} speech timestamps')\n",
" return len(speech_timestamps) > 0"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"def transcript(audio_data, sample_rate):\n",
" logging.info('Transcribing audio')\n",
" # Resample to 16000 Hz if necessary\n",
" target_sample_rate = 16000\n",
" if sample_rate != target_sample_rate:\n",
" resampler = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=target_sample_rate)\n",
" audio_data = resampler(torch.from_numpy(audio_data)).numpy()\n",
" else:\n",
" audio_data = audio_data\n",
"\n",
" # Transcribe\n",
" batch_size = 16 # Adjust as needed\n",
" result = whisper_model.transcribe(audio_data, batch_size=batch_size)\n",
" text = result['segments'][0]['text']\n",
" logging.info(f'Transcription result: {text}')\n",
" # Clear GPU memory\n",
" del result\n",
" gc.collect()\n",
" if device == 'cuda':\n",
" torch.cuda.empty_cache()\n",
" return text"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
"from openai import OpenAI\n",
"\n",
"openai_client = OpenAI(api_key='')\n",
"\n",
"def llm(text):\n",
" logging.info('Getting response from OpenAI API')\n",
" response = openai_client.chat.completions.create(\n",
" model=\"gpt-4o\", # Updated to a more recent model\n",
" messages=[\n",
" {\"role\": \"system\", \"content\": \"You respond to the following transcript from the conversation that you are having with the user.\"},\n",
" {\"role\": \"user\", \"content\": text} \n",
" ],\n",
" stream=True,\n",
" temperature=0.7, # Optional: Adjust as needed\n",
" top_p=0.9, # Optional: Adjust as needed\n",
" )\n",
" for chunk in response:\n",
" yield chunk.choices[0].delta.content"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [],
"source": [
"def tts_streaming(text_stream):\n",
" logging.info('Performing TTS')\n",
" buffer = \"\"\n",
" punctuation = {'.', '!', '?'}\n",
" for text_chunk in text_stream:\n",
" if text_chunk is not None:\n",
" buffer += text_chunk\n",
" # Check for sentence completion\n",
" sentences = []\n",
" start = 0\n",
" for i, char in enumerate(buffer):\n",
" if (char in punctuation):\n",
" sentences.append(buffer[start:i+1].strip())\n",
" start = i+1\n",
" buffer = buffer[start:]\n",
"\n",
" for sentence in sentences:\n",
" if sentence:\n",
" communicate = edge_tts.Communicate(sentence, TTS_VOICE)\n",
" for chunk in communicate.stream_sync():\n",
" if chunk[\"type\"] == \"audio\":\n",
" yield chunk[\"data\"]\n",
" # Process any remaining text\n",
" if buffer.strip():\n",
" communicate = edge_tts.Communicate(buffer.strip(), TTS_VOICE)\n",
" for chunk in communicate.stream_sync():\n",
" if chunk[\"type\"] == \"audio\":\n",
" yield chunk[\"data\"]"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [],
"source": [
"# load audio to numpy array\n",
"def load_audio(audio_path):\n",
" audio_data, sample_rate = torchaudio.load(audio_path)\n",
" audio_data = audio_data[0].numpy()\n",
" if audio_data.ndim > 1:\n",
" audio_data = np.mean(audio_data, axis=1)\n",
" return audio_data, sample_rate"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [],
"source": [
"# Testing the pipeline\n",
"\n",
"# 1. Load audio\n",
"audio_path = 'audio.mp3'\n",
"audio_data, sample_rate = load_audio(audio_path)"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"2024-09-23 13:50:49,248 - INFO - Checking voice activity\n",
"2024-09-23 13:50:49,253 - INFO - Audio tensor shape: torch.Size([8000]), dtype: torch.float32, device: cuda:0\n",
"2024-09-23 13:50:49,494 - INFO - Found 1 speech timestamps\n",
"2024-09-23 13:50:49,495 - INFO - Checking voice activity\n",
"2024-09-23 13:50:49,498 - INFO - Audio tensor shape: torch.Size([8000]), dtype: torch.float32, device: cuda:0\n",
"2024-09-23 13:50:49,506 - INFO - Found 1 speech timestamps\n",
"2024-09-23 13:50:49,507 - INFO - Checking voice activity\n",
"2024-09-23 13:50:49,511 - INFO - Audio tensor shape: torch.Size([8000]), dtype: torch.float32, device: cuda:0\n",
"2024-09-23 13:50:49,518 - INFO - Found 1 speech timestamps\n",
"2024-09-23 13:50:49,519 - INFO - Checking voice activity\n",
"2024-09-23 13:50:49,523 - INFO - Audio tensor shape: torch.Size([8000]), dtype: torch.float32, device: cuda:0\n",
"2024-09-23 13:50:49,531 - INFO - Found 1 speech timestamps\n",
"2024-09-23 13:50:49,532 - INFO - Checking voice activity\n",
"2024-09-23 13:50:49,535 - INFO - Audio tensor shape: torch.Size([8000]), dtype: torch.float32, device: cuda:0\n",
"2024-09-23 13:50:49,543 - INFO - Found 1 speech timestamps\n",
"2024-09-23 13:50:49,543 - INFO - Checking voice activity\n",
"2024-09-23 13:50:49,546 - INFO - Audio tensor shape: torch.Size([8000]), dtype: torch.float32, device: cuda:0\n",
"2024-09-23 13:50:49,557 - INFO - Found 1 speech timestamps\n",
"2024-09-23 13:50:49,558 - INFO - Checking voice activity\n",
"2024-09-23 13:50:49,561 - INFO - Audio tensor shape: torch.Size([8000]), dtype: torch.float32, device: cuda:0\n",
"2024-09-23 13:50:49,569 - INFO - Found 1 speech timestamps\n",
"2024-09-23 13:50:49,570 - INFO - Checking voice activity\n",
"2024-09-23 13:50:49,573 - INFO - Audio tensor shape: torch.Size([8000]), dtype: torch.float32, device: cuda:0\n",
"2024-09-23 13:50:49,581 - INFO - Found 1 speech timestamps\n",
"2024-09-23 13:50:49,582 - INFO - Checking voice activity\n",
"2024-09-23 13:50:49,585 - INFO - Audio tensor shape: torch.Size([8000]), dtype: torch.float32, device: cuda:0\n",
"2024-09-23 13:50:49,593 - INFO - Found 1 speech timestamps\n",
"2024-09-23 13:50:49,593 - INFO - Checking voice activity\n",
"2024-09-23 13:50:49,595 - INFO - Audio tensor shape: torch.Size([8000]), dtype: torch.float32, device: cuda:0\n",
"2024-09-23 13:50:49,604 - INFO - Found 1 speech timestamps\n",
"2024-09-23 13:50:49,605 - INFO - Checking voice activity\n",
"2024-09-23 13:50:49,607 - INFO - Audio tensor shape: torch.Size([8000]), dtype: torch.float32, device: cuda:0\n",
"2024-09-23 13:50:49,616 - INFO - Found 1 speech timestamps\n",
"2024-09-23 13:50:49,617 - INFO - Checking voice activity\n",
"2024-09-23 13:50:49,619 - INFO - Audio tensor shape: torch.Size([8000]), dtype: torch.float32, device: cuda:0\n",
"2024-09-23 13:50:49,628 - INFO - Found 1 speech timestamps\n",
"2024-09-23 13:50:49,629 - INFO - Checking voice activity\n",
"2024-09-23 13:50:49,632 - INFO - Audio tensor shape: torch.Size([8000]), dtype: torch.float32, device: cuda:0\n",
"2024-09-23 13:50:49,640 - INFO - Found 1 speech timestamps\n",
"2024-09-23 13:50:49,641 - INFO - Checking voice activity\n",
"2024-09-23 13:50:49,644 - INFO - Audio tensor shape: torch.Size([8000]), dtype: torch.float32, device: cuda:0\n",
"2024-09-23 13:50:49,651 - INFO - Found 0 speech timestamps\n",
"2024-09-23 13:50:49,652 - INFO - Checking voice activity\n",
"2024-09-23 13:50:49,654 - INFO - Audio tensor shape: torch.Size([8000]), dtype: torch.float32, device: cuda:0\n",
"2024-09-23 13:50:49,665 - INFO - Found 0 speech timestamps\n",
"2024-09-23 13:50:49,665 - INFO - Checking voice activity\n",
"2024-09-23 13:50:49,669 - INFO - Audio tensor shape: torch.Size([8000]), dtype: torch.float32, device: cuda:0\n",
"2024-09-23 13:50:49,678 - INFO - Found 0 speech timestamps\n",
"2024-09-23 13:50:49,678 - INFO - Checking voice activity\n",
"2024-09-23 13:50:49,681 - INFO - Audio tensor shape: torch.Size([8000]), dtype: torch.float32, device: cuda:0\n",
"2024-09-23 13:50:49,690 - INFO - Found 0 speech timestamps\n",
"2024-09-23 13:50:49,691 - INFO - Checking voice activity\n",
"2024-09-23 13:50:49,693 - INFO - Audio tensor shape: torch.Size([8000]), dtype: torch.float32, device: cuda:0\n",
"2024-09-23 13:50:49,703 - INFO - Found 1 speech timestamps\n",
"2024-09-23 13:50:49,704 - INFO - Checking voice activity\n",
"2024-09-23 13:50:49,707 - INFO - Audio tensor shape: torch.Size([8000]), dtype: torch.float32, device: cuda:0\n",
"2024-09-23 13:50:49,718 - INFO - Found 1 speech timestamps\n",
"2024-09-23 13:50:49,719 - INFO - Checking voice activity\n",
"2024-09-23 13:50:49,722 - INFO - Audio tensor shape: torch.Size([8000]), dtype: torch.float32, device: cuda:0\n",
"2024-09-23 13:50:49,731 - INFO - Found 0 speech timestamps\n",
"2024-09-23 13:50:49,732 - INFO - Checking voice activity\n",
"2024-09-23 13:50:49,734 - INFO - Audio tensor shape: torch.Size([8000]), dtype: torch.float32, device: cuda:0\n",
"2024-09-23 13:50:49,743 - INFO - Found 1 speech timestamps\n",
"2024-09-23 13:50:49,744 - INFO - Checking voice activity\n",
"2024-09-23 13:50:49,746 - INFO - Audio tensor shape: torch.Size([8000]), dtype: torch.float32, device: cuda:0\n",
"2024-09-23 13:50:49,759 - INFO - Found 1 speech timestamps\n",
"2024-09-23 13:50:49,760 - INFO - Checking voice activity\n",
"2024-09-23 13:50:49,762 - INFO - Audio tensor shape: torch.Size([8000]), dtype: torch.float32, device: cuda:0\n",
"2024-09-23 13:50:49,773 - INFO - Found 1 speech timestamps\n",
"2024-09-23 13:50:49,773 - INFO - Checking voice activity\n",
"2024-09-23 13:50:49,776 - INFO - Audio tensor shape: torch.Size([8000]), dtype: torch.float32, device: cuda:0\n",
"2024-09-23 13:50:49,784 - INFO - Found 1 speech timestamps\n",
"2024-09-23 13:50:49,785 - INFO - Checking voice activity\n",
"2024-09-23 13:50:49,789 - INFO - Audio tensor shape: torch.Size([8000]), dtype: torch.float32, device: cuda:0\n",
"2024-09-23 13:50:49,798 - INFO - Found 1 speech timestamps\n",
"2024-09-23 13:50:49,799 - INFO - Checking voice activity\n",
"2024-09-23 13:50:49,801 - INFO - Audio tensor shape: torch.Size([8000]), dtype: torch.float32, device: cuda:0\n",
"2024-09-23 13:50:49,810 - INFO - Found 1 speech timestamps\n",
"2024-09-23 13:50:49,810 - INFO - Checking voice activity\n",
"2024-09-23 13:50:49,813 - INFO - Audio tensor shape: torch.Size([8000]), dtype: torch.float32, device: cuda:0\n",
"2024-09-23 13:50:49,821 - INFO - Found 1 speech timestamps\n",
"2024-09-23 13:50:49,822 - INFO - Checking voice activity\n",
"2024-09-23 13:50:49,824 - INFO - Audio tensor shape: torch.Size([8000]), dtype: torch.float32, device: cuda:0\n",
"2024-09-23 13:50:49,833 - INFO - Found 1 speech timestamps\n",
"2024-09-23 13:50:49,834 - INFO - Checking voice activity\n",
"2024-09-23 13:50:49,836 - INFO - Audio tensor shape: torch.Size([8000]), dtype: torch.float32, device: cuda:0\n",
"2024-09-23 13:50:49,844 - INFO - Found 1 speech timestamps\n",
"2024-09-23 13:50:49,845 - INFO - Checking voice activity\n",
"2024-09-23 13:50:49,847 - INFO - Audio tensor shape: torch.Size([8000]), dtype: torch.float32, device: cuda:0\n",
"2024-09-23 13:50:49,856 - INFO - Found 1 speech timestamps\n",
"2024-09-23 13:50:49,857 - INFO - Checking voice activity\n",
"2024-09-23 13:50:49,860 - INFO - Audio tensor shape: torch.Size([8000]), dtype: torch.float32, device: cuda:0\n",
"2024-09-23 13:50:49,871 - INFO - Found 1 speech timestamps\n",
"2024-09-23 13:50:49,872 - INFO - Checking voice activity\n",
"2024-09-23 13:50:49,875 - INFO - Audio tensor shape: torch.Size([8000]), dtype: torch.float32, device: cuda:0\n",
"2024-09-23 13:50:49,883 - INFO - Found 1 speech timestamps\n",
"2024-09-23 13:50:49,884 - INFO - Checking voice activity\n",
"2024-09-23 13:50:49,887 - INFO - Audio tensor shape: torch.Size([644]), dtype: torch.float32, device: cuda:0\n",
"2024-09-23 13:50:49,889 - INFO - Found 0 speech timestamps\n"
]
}
],
"source": [
"chunk_size = 500 # ms\n",
"chunk_size_samples = int(sample_rate * chunk_size / 1000)\n",
"chunks = [audio_data[i:i + chunk_size_samples] for i in range(0, len(audio_data), chunk_size_samples)]\n",
"\n",
"# 2. Check voice activity\n",
"voice_activity = [check_vad(chunk, sample_rate) for chunk in chunks]"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"2024-09-23 13:50:50,691 - INFO - Transcribing audio\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Warning: audio is shorter than 30s, language detection may be inaccurate.\n",
"Detected language: en (0.99) in first 30s of audio...\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"2024-09-23 13:50:51,041 - INFO - Transcription result: What's this the reporter tried to make a hit piece about Wu Kong is not happy. I wonder why? What a shock. Well wait a second. Should we get to the bottom of this?\n"
]
}
],
"source": [
"text = transcript(audio_data, sample_rate)"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [],
"source": [
"text = llm(text)\n",
"tts_audio = tts_streaming(text)"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"2024-09-23 13:50:53,979 - INFO - Performing TTS\n",
"2024-09-23 13:50:53,980 - INFO - Getting response from OpenAI API\n",
"2024-09-23 13:50:54,236 - INFO - HTTP Request: POST https://api.openai.com/v1/chat/completions \"HTTP/1.1 200 OK\"\n"
]
},
{
"data": {
"text/html": [
"\n",
" \n",
" "
],
"text/plain": [
""
]
},
"execution_count": 13,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"from IPython.display import Audio\n",
"from pydub import AudioSegment\n",
"from io import BytesIO\n",
"import base64\n",
"\n",
"# Combine audio chunk bytes\n",
"audio_bytes = b''.join(tts_audio)\n",
"\n",
"# Play audio\n",
"audio_segment = AudioSegment.from_file(BytesIO(audio_bytes), format=\"raw\", frame_rate=16000, channels=1, sample_width=2)\n",
"\n",
"Audio(audio_bytes, rate=16000)"
]
},
{
"cell_type": "code",
"execution_count": 19,
"metadata": {},
"outputs": [],
"source": [
"np_audio = np.frombuffer(audio_bytes, dtype=np.int16)\n",
"\n",
"# export audio with numpy\n",
"np_audio.tofile(\"output.wav\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# function to process audio input\n",
"def process_audio_old(audio, state):\n",
" \"\"\"\n",
" Flow:\n",
" 1. Sleep for 0.5 seconds to allow the audio buffer to accumulate\n",
" 2. Check for voice activity\n",
" 3. If voice activity is detected and mode is \"idle\":\n",
" - Set mode to \"listening\"\n",
" 4. If voice activity is detected and mode is \"speaking\":\n",
" - Stop the llm and tts tasks\n",
" - Set mode to \"listening\"\n",
" 5. If voice activity is detected and mode is \"listening\":\n",
" - If there's previous_no_vad_audio, add it to chunk_queue\n",
" - Start accumulating audio chunks in chunk_queue\n",
" - If the length of chunk_queue is greater than 3 seconds\n",
" - Get the first 2 seconds of audio from chunk_queue\n",
" - Run transcription on the first 2 seconds\n",
" - Store the transcription in the state\n",
" - Remove the first 2 seconds of audio from chunk_queue\n",
" 6. If voice activity is not detected:\n",
" - If mode is \"listening\" and there's audio in chunk_queue\n",
" - Add the chunk to chunk_queue\n",
" - Set mode to \"processing\"\n",
" - Run transcription on the leftover audio in chunk_queue\n",
" - Store the transcription in the state\n",
" - Set the mode to \"processing\"\n",
" - If mode is \"processing\"\n",
" - Check if there's any leftover audio in chunk_queue\n",
" - If there is, run transcription on the leftover audio\n",
" - Store the transcription in the state\n",
" - Start LLM and TTS in the background\n",
" - Set mode to \"responding\"\n",
" - If mode is \"responding\"\n",
" - Get the audio byte chunks from TTS\n",
" - Output the full audio\n",
" - Set mode to \"idle\"\n",
" - If mode is \"idle\"\n",
" - do nothing\n",
" \n",
" Ex: Gradio Streaming Audio Example:\n",
" import gradio as gr\n",
" import numpy as np\n",
" import time\n",
"\n",
" def add_to_stream(audio, instream):\n",
" time.sleep(1)\n",
" if audio is None:\n",
" return gr.update(), instream\n",
" if instream is None:\n",
" ret = audio\n",
" else:\n",
" ret = (audio[0], np.concatenate((instream[1], audio[1])))\n",
" return ret, ret\n",
"\n",
"\n",
" with gr.Blocks() as demo:\n",
" inp = gr.Audio(source=\"microphone\")\n",
" out = gr.Audio()\n",
" stream = gr.State()\n",
" clear = gr.Button(\"Clear\")\n",
"\n",
" inp.stream(add_to_stream, [inp, stream], [out, stream])\n",
" clear.click(lambda: [None, None, None], None, [inp, out, stream])\n",
"\n",
"\n",
" if __name__ == \"__main__\":\n",
" demo.launch()\n",
" \"\"\"\n",
" \"\"\"old code:\n",
" time.sleep(0.5)\n",
" if audio is None:\n",
" return None, state\n",
"\n",
" sample_rate, audio_data = audio\n",
" audio_data = np.array(audio_data, dtype=np.float32)\n",
"\n",
" # Convert to mono if stereo\n",
" if audio_data.ndim > 1:\n",
" audio_data = np.mean(audio_data, axis=1)\n",
"\n",
" # Check for voice activity\n",
" vad_result = check_vad(audio_data, sample_rate)\n",
" if vad_result:\n",
" logging.info('Voice activity detected')\n",
" # Voice activity detected\n",
" if state.get(\"previous_audio_chunk\") is not None:\n",
" state[\"audio_buffer\"].append(state[\"previous_audio_chunk\"])\n",
" state[\"audio_buffer\"].append(audio_data)\n",
" state[\"is_speaking\"] = True\n",
" state[\"previous_audio_chunk\"] = audio_data\n",
"\n",
" # Update total speaking time\n",
" chunk_duration = len(audio_data) / sample_rate\n",
" state[\"total_speaking_time\"] += chunk_duration\n",
"\n",
" # Start transcription after 3 seconds\n",
" if state[\"total_speaking_time\"] >= 3.0 and not state[\"transcription_started\"]:\n",
" logging.info('Starting transcription')\n",
" # Start transcribing the first 2 seconds\n",
" accumulated_audio = np.concatenate(state[\"audio_buffer\"])\n",
" first_two_seconds_samples = int(2.0 * sample_rate)\n",
" first_two_seconds_audio = accumulated_audio[:first_two_seconds_samples]\n",
"\n",
" # Transcribe asynchronously\n",
" transcribed_text = transcript(first_two_seconds_audio, sample_rate)\n",
" state[\"transcription\"] += transcribed_text\n",
" state[\"transcription_started\"] = True\n",
"\n",
" # Start LLM and TTS in the background\n",
" state[\"llm_task\"] = llm_and_tts(state[\"transcription\"], state)\n",
" else:\n",
" if state[\"is_speaking\"]:\n",
" logging.info('Voice activity ended')\n",
" # Voice activity just ended\n",
" # Process the accumulated audio\n",
" full_audio = np.concatenate(state[\"audio_buffer\"])\n",
" # Reset the state\n",
" state[\"audio_buffer\"] = []\n",
" state[\"is_speaking\"] = False\n",
" state[\"total_speaking_time\"] = 0.0\n",
" state[\"transcription_started\"] = False\n",
"\n",
" # Transcribe the remaining audio\n",
" transcribed_text = transcript(full_audio, sample_rate)\n",
" state[\"transcription\"] += transcribed_text\n",
"\n",
" # Start LLM and TTS if not already started\n",
" if not state.get(\"llm_task\"):\n",
" state[\"llm_task\"] = llm_and_tts(state[\"transcription\"], state)\n",
"\n",
" # Check if there's audio to output\n",
" if state.get(\"tts_audio_chunks\"):\n",
" logging.info('Outputting audio')\n",
" # Collect audio chunks\n",
" audio_chunks = state[\"tts_audio_chunks\"]\n",
" state[\"tts_audio_chunks\"] = []\n",
" response_audio = b\"\".join(audio_chunks)\n",
" np_response_audio = np.frombuffer(response_audio, dtype=np.int16)\n",
" return (sample_rate, np_response_audio), state\n",
"\n",
" # Collect the last chunk if it exists\n",
" if state.get(\"previous_audio_chunk\") is not None:\n",
" state[\"audio_buffer\"].append(state[\"previous_audio_chunk\"])\n",
"\n",
" return None, state\n",
" \"\"\"\n",
" ...\n"
]
},
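{
"cell_type": "markdown",
"metadata": {},
"source": [
"The streaming handler in the next cell launches `llm_and_tts(transcription, state)` on a background thread, but that helper is not defined anywhere in this notebook. The cell below is a minimal sketch (an assumption, not the original implementation): it chains the `llm` and `tts_streaming` generators defined above, appends the resulting audio bytes to `state['tts_audio_chunks']`, and stops early when `state['stop_signal']` is set."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Sketch of the llm_and_tts helper assumed by process_audio_chunk below.\n",
"# It streams LLM text into TTS and collects the audio bytes in the shared state.\n",
"def llm_and_tts(transcription, state):\n",
"    logging.info('Running LLM and TTS for transcription')\n",
"    text_stream = llm(transcription)  # streaming text chunks from the OpenAI API\n",
"    for audio_chunk in tts_streaming(text_stream):  # streaming MP3 bytes from edge-tts\n",
"        if state.get('stop_signal'):\n",
"            logging.info('Stop signal received, aborting LLM/TTS')\n",
"            break\n",
"        state['tts_audio_chunks'].append(audio_chunk)\n",
"    logging.info('LLM and TTS finished')"
]
},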
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Function to process audio input\n",
"def process_audio_chunk(audio, state):\n",
" if audio is None:\n",
" return None, state\n",
" if state is None:\n",
" state = {\n",
" 'mode': 'idle',\n",
" 'chunk_queue': [],\n",
" 'transcription': '',\n",
" 'previous_no_vad_audio': None,\n",
" 'tts_audio_chunks': [],\n",
" 'llm_task': None,\n",
" 'instream': None,\n",
" }\n",
"\n",
" sample_rate, audio_data = audio\n",
" audio_data = np.array(audio_data, dtype=np.float32)\n",
"\n",
" # Convert to mono if stereo\n",
" if audio_data.ndim > 1:\n",
" audio_data = np.mean(audio_data, axis=1)\n",
"\n",
" mode = state['mode']\n",
" chunk_queue = state['chunk_queue']\n",
" transcription = state['transcription']\n",
" previous_no_vad_audio = state['previous_no_vad_audio']\n",
" tts_audio_chunks = state['tts_audio_chunks']\n",
" llm_task = state['llm_task']\n",
" instream = state['instream']\n",
"\n",
" # Check for voice activity\n",
" vad_result = check_vad(audio_data, sample_rate)\n",
"\n",
" if vad_result:\n",
" logging.info(f'Voice activity detected in mode: {mode}')\n",
" if mode == 'idle':\n",
" mode = 'listening'\n",
" elif mode == 'speaking':\n",
" # Stop llm and tts tasks\n",
" if llm_task and llm_task.is_alive():\n",
" # Implement task cancellation logic if possible\n",
" logging.info('Stopping LLM and TTS tasks')\n",
" # Since we cannot kill threads directly, we need to handle this in the tasks\n",
" state['stop_signal'] = True\n",
" llm_task.join()\n",
" mode = 'listening'\n",
" \n",
" if vad_result:\n",
" if mode == 'listening':\n",
" if previous_no_vad_audio is not None:\n",
" chunk_queue.append(previous_no_vad_audio)\n",
" previous_no_vad_audio = None\n",
" # Accumulate audio chunks\n",
" chunk_queue.append(audio_data)\n",
" # Calculate the length of chunk_queue in seconds\n",
" total_samples = sum(len(chunk) for chunk in chunk_queue)\n",
" total_duration = total_samples / sample_rate\n",
" if total_duration > 3.0:\n",
" # Get the first 2 seconds of audio\n",
" first_two_seconds_samples = int(2.0 * sample_rate)\n",
" accumulated_audio = np.concatenate(chunk_queue)\n",
" first_two_seconds_audio = accumulated_audio[:first_two_seconds_samples]\n",
" # Run transcription on the first 2 seconds\n",
" transcribed_text = transcript(first_two_seconds_audio, sample_rate)\n",
" transcription += transcribed_text\n",
" # Remove the first 2 seconds from chunk_queue\n",
" remaining_audio = accumulated_audio[first_two_seconds_samples:]\n",
" chunk_queue = [remaining_audio] if len(remaining_audio) > 0 else []\n",
" elif mode == 'speaking':\n",
" # Continue accumulating audio chunks\n",
" chunk_queue.append(audio_data)\n",
" else:\n",
" logging.info(f'No voice activity detected in mode: {mode}')\n",
" if mode == 'listening' and chunk_queue:\n",
" # Add the chunk to chunk_queue\n",
" chunk_queue.append(audio_data)\n",
" # Run transcription on leftover audio in chunk_queue\n",
" accumulated_audio = np.concatenate(chunk_queue)\n",
" transcribed_text = transcript(accumulated_audio, sample_rate)\n",
" transcription += transcribed_text\n",
" # Clear chunk_queue\n",
" chunk_queue = []\n",
" mode = 'processing'\n",
" # Start LLM and TTS in the background\n",
" if not llm_task or not llm_task.is_alive():\n",
" state['stop_signal'] = False\n",
" llm_task = threading.Thread(target=llm_and_tts, args=(transcription, state))\n",
" llm_task.start()\n",
" elif mode == 'processing':\n",
" # Wait for LLM and TTS to finish\n",
" if llm_task and not llm_task.is_alive():\n",
" mode = 'responding'\n",
" elif mode == 'responding':\n",
" # Get the audio byte chunks from TTS\n",
" if tts_audio_chunks:\n",
" logging.info('Outputting audio response')\n",
" # Collect audio chunks\n",
" response_audio = b\"\".join(tts_audio_chunks)\n",
" np_response_audio = np.frombuffer(response_audio, dtype=np.int16)\n",
" \n",
" if instream is None:\n",
" instream = np_response_audio\n",
" else:\n",
" instream = np.concatenate((instream, np_response_audio))\n",
" \n",
" # Clear tts_audio_chunks\n",
" tts_audio_chunks.clear()\n",
" # Reset transcription for next interaction\n",
" transcription = ''\n",
" # Set mode to \"idle\"\n",
" mode = 'idle'\n",
" \n",
" # Update state\n",
" state.update({\n",
" 'mode': mode,\n",
" 'chunk_queue': chunk_queue,\n",
" 'transcription': transcription,\n",
" 'previous_no_vad_audio': previous_no_vad_audio,\n",
" 'tts_audio_chunks': tts_audio_chunks,\n",
" 'llm_task': None,\n",
" 'instream': instream\n",
" })\n",
" return (sample_rate, instream), state\n",
" elif mode == 'idle':\n",
" # Do nothing\n",
" pass\n",
" else:\n",
" # Store the audio when no VAD is detected\n",
" previous_no_vad_audio = audio_data\n",
"\n",
" # Update state\n",
" state.update({\n",
" 'mode': mode,\n",
" 'chunk_queue': chunk_queue,\n",
" 'transcription': transcription,\n",
" 'previous_no_vad_audio': previous_no_vad_audio,\n",
" 'tts_audio_chunks': tts_audio_chunks,\n",
" 'llm_task': llm_task,\n",
" 'instream': instream\n",
" })\n",
"\n",
" return None, state\n",
"\n",
"# Initialize the state\n",
"initial_state = {\n",
" 'mode': 'idle',\n",
" 'chunk_queue': [],\n",
" 'transcription': '',\n",
" 'previous_no_vad_audio': None,\n",
" 'tts_audio_chunks': [],\n",
" 'llm_task': None,\n",
" 'instream': None,\n",
"}\n",
"\n",
"# Create Gradio interface\n",
"with gr.Blocks() as demo:\n",
" gr.Markdown(\"## Voice-Activated Transcription and Response System\")\n",
" audio_input = gr.Audio(sources=\"microphone\", type=\"numpy\", streaming=True)\n",
" state = gr.State(initial_state)\n",
" audio_output = gr.Audio(label=\"Response Audio\", autoplay=True)\n",
" audio_input.stream(process_audio, [audio_input, state], [audio_output, state])\n",
"\n",
"if __name__ == \"__main__\":\n",
" logging.info('Launching Gradio interface')\n",
" demo.launch()\n"
]
}
],
"metadata": {
"kernelspec": {
"display_name": ".venv",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.12"
}
},
"nbformat": 4,
"nbformat_minor": 2
}