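# NOTE: this file appears to have been captured from a hosted "Spaces" file
# viewer whose page header reported the build status "Build error".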
import asyncio
import concurrent.futures
import functools
import io
import os
import queue
import random
import sys
import time
from threading import Thread

import pyaudio
import speech_recognition as sr
import websockets
from aioconsole import ainput  # for async input
from dotenv import load_dotenv
from pydub import AudioSegment
from simpleaudio import WaveObject
# Load settings from a .env file into the process environment
# (start_client reads AUTH_API_KEY via os.getenv).
load_dotenv()

# Worker pool that handle_audio uses to run the blocking
# speech_recognition listen call off the event loop.
executor = concurrent.futures.ThreadPoolExecutor(max_workers=3)

# True until the first binary frame received on the second websocket has
# been skipped by receive_message.
web2_initial_message = True

# Audio capture settings. RATE is passed to sr.Microphone in handle_audio.
# NOTE(review): CHUNK, FORMAT and CHANNELS are not referenced in the visible
# code — confirm whether they are still needed.
CHUNK = 1024
FORMAT = pyaudio.paInt16
CHANNELS = 1
RATE = 44100
class AudioPlayer:
    """Play queued WAV clips one after another on a background thread.

    Clips are whatever ``WaveObject.from_wave_file`` accepts; in this file the
    callers pass in-memory ``io.BytesIO`` objects holding WAV data.

    Attributes:
        play_thread: The worker ``Thread`` or ``None`` when no worker exists.
        stop_flag: Set to ``True`` to ask the worker to stop and exit.
        queue: FIFO of clips waiting to be played.
    """

    # Seconds between checks of ``stop_flag`` while idle or while a clip plays.
    _POLL_INTERVAL = 0.1

    def __init__(self):
        self.play_thread = None
        self.stop_flag = False
        self.queue = queue.Queue()

    def play_audio(self):
        """Worker loop: play queued clips until a stop is requested.

        Waits on the queue with a timeout instead of spinning on
        ``get_nowait()``, so an idle worker does not burn a CPU core. When a
        stop is requested, the current clip is stopped and every clip still
        queued is discarded without being started.
        """
        while not self.stop_flag:
            try:
                wav_data = self.queue.get(timeout=self._POLL_INTERVAL)
            except queue.Empty:
                continue
            play_obj = WaveObject.from_wave_file(wav_data).play()
            while play_obj.is_playing() and not self.stop_flag:
                time.sleep(self._POLL_INTERVAL)
            if self.stop_flag:
                play_obj.stop()
        self._discard_pending()

    def _discard_pending(self):
        """Empty the queue without playing anything."""
        while True:
            try:
                self.queue.get_nowait()
            except queue.Empty:
                return

    def start_playing(self, wav_data):
        """Queue *wav_data* and make sure a worker thread is running."""
        self.stop_flag = False
        self.queue.put(wav_data)
        if self.play_thread is None or not self.play_thread.is_alive():
            # Daemon thread: an idle worker must not keep the process alive
            # after the main thread has finished.
            self.play_thread = Thread(target=self.play_audio, daemon=True)
            self.play_thread.start()

    def stop_playing(self):
        """Stop playback, drop queued clips and wait for the worker to exit.

        Does nothing when no worker thread is alive.
        """
        if self.play_thread and self.play_thread.is_alive():
            self.stop_flag = True
            self.play_thread.join()
            self.play_thread = None

    def add_to_queue(self, wav_data):
        """Queue *wav_data* without starting a worker thread."""
        self.queue.put(wav_data)


# Shared player instance used by the message-handling code in this module.
audio_player = AudioPlayer()
def get_input_device_id():
    """List the audio input devices and ask the user to pick one.

    Prints every device that reports a non-zero ``maxInputChannels`` and then
    prompts on stdin until an integer is entered.

    Returns:
        int: The device id typed by the user. It is not checked against the
        listed devices.
    """
    p = pyaudio.PyAudio()
    try:
        devices = []
        for index in range(p.get_device_count()):
            info = p.get_device_info_by_index(index)
            if info.get('maxInputChannels'):
                devices.append((index, info['name']))
    finally:
        # Release the PortAudio resources held by this temporary instance.
        p.terminate()

    print('Available devices:')
    for device_id, name in devices:
        print(f"Device id {device_id} - {name}")

    while True:
        try:
            return int(input('Please select device id: '))
        except ValueError:
            # Re-prompt instead of crashing on a typo.
            print('Please enter a numeric device id.')
async def handle_audio(websocket, device_id):
    """Capture speech from a microphone and stream it over *websocket*.

    Opens the microphone with index *device_id* at ``RATE`` Hz, calibrates the
    recognizer against ambient noise for two seconds, then loops forever:
    record one phrase (at most 30 seconds), send its raw frame data, and
    pause two seconds.

    Args:
        websocket: Connection object providing an awaitable ``send``.
        device_id: Input device index passed to ``sr.Microphone``.
    """
    with sr.Microphone(device_index=device_id, sample_rate=RATE) as source:
        recognizer = sr.Recognizer()
        print('Source sample rate: ', source.SAMPLE_RATE)
        print('Source width: ', source.SAMPLE_WIDTH)
        print('Adjusting for ambient noise...Wait for 2 seconds')
        recognizer.energy_threshold = 5000
        recognizer.dynamic_energy_ratio = 6
        recognizer.dynamic_energy_adjustment_damping = 0.85
        recognizer.non_speaking_duration = 0.5
        recognizer.pause_threshold = 0.8
        recognizer.phrase_threshold = 0.5
        recognizer.adjust_for_ambient_noise(source, duration=2)
        listen_func = functools.partial(
            recognizer.listen, source, phrase_time_limit=30)

        # Inside a coroutine the running loop is the one to use.
        loop = asyncio.get_running_loop()
        print('Okay, start talking!')
        while True:
            # Flush so the indicator shows up even without a newline.
            print('[*]', end="", flush=True)  # indicate that we are listening
            # recognizer.listen blocks, so it runs on the shared executor.
            audio = await loop.run_in_executor(executor, listen_func)
            await websocket.send(audio.frame_data)
            print('[-]', end="", flush=True)  # indicate that we are done listening
            await asyncio.sleep(2)
async def handle_text(websocket):
    """Forward every line typed on stdin to *websocket*, forever.

    Args:
        websocket: Connection object providing an awaitable ``send``.
    """
    # Flush so the prompt is visible before the first read; it has no newline.
    print('You: ', end="", flush=True)
    while True:
        message = await ainput()
        await websocket.send(message)
# NOTE(review): not referenced anywhere in the visible code — confirm before
# removing.
initial_message = True
def _strip_speaker_prefix(text):
    """Return the part of *text* after the first '> ', or *text* if absent."""
    _, marker, remainder = text.partition('> ')
    return remainder if marker else text


def _play_mp3_bytes(payload):
    """Convert an MP3 payload to WAV in memory and start playing it."""
    wav_data = io.BytesIO()
    AudioSegment.from_mp3(io.BytesIO(payload)).export(wav_data, format="wav")
    audio_player.start_playing(wav_data)


async def _relay_reply_from_websocket2(websocket, websocket2):
    """Read one complete reply from *websocket2* and send it to *websocket*.

    Text fragments are accumulated until a non-empty reply is terminated by
    ``'[end]\\n'``. Returns early, without sending, if receiving fails.
    """
    global web2_initial_message
    web2_message = ''
    while True:
        try:
            message = await websocket2.recv()
        except websockets.exceptions.ConnectionClosedError as e:
            print("Connection closed unexpectedly: ", e)
            return
        except Exception as e:
            print("An error occurred: ", e)
            return

        if isinstance(message, str):
            if message == '[end]\n':
                if not web2_message:
                    # Nothing collected yet; keep waiting for real content.
                    continue
                print(web2_message)
                await websocket.send(_strip_speaker_prefix(web2_message))
                return
            if message.startswith('[+]'):
                # Transcription finished: stop playing audio.
                audio_player.stop_playing()
            elif not message.startswith('[=]'):
                # '[=]' only marks the end of the response and is ignored.
                web2_message += message
        elif isinstance(message, bytes):
            if web2_initial_message:
                # The very first binary frame from websocket2 is skipped.
                web2_initial_message = False
                continue
            _play_mp3_bytes(message)


async def receive_message(websocket, websocket2):
    """Relay a conversation between two websocket peers.

    Reads and prints one initial message from each connection, sends a fixed
    opening prompt to *websocket*, then loops over messages from *websocket*:

    * plain text is printed and accumulated;
    * ``'[end]\\n'`` forwards the accumulated text (minus anything up to the
      first ``'> '``) to *websocket2*, waits for its complete reply and sends
      that reply back to *websocket*;
    * ``'[+]...'`` stops audio playback and is printed;
    * ``'[=]...'`` is printed;
    * binary frames are treated as MP3 audio and played.

    The loop ends when receiving from *websocket* fails or a message is
    neither ``str`` nor ``bytes``.

    Args:
        websocket: Primary connection (awaitable ``recv``/``send``).
        websocket2: Secondary connection (awaitable ``recv``/``send``).
    """
    web1_init_message = await websocket.recv()
    print('web1_init_message: ', web1_init_message)
    web2_init_message = await websocket2.recv()
    print('web2_init_message: ', web2_init_message)

    message_to_websocket1 = "Suppose I'm Steve Jobs now. What question do you have for me?"
    await websocket.send(message_to_websocket1)

    web1_message = ''
    while True:
        try:
            message = await websocket.recv()
        except websockets.exceptions.ConnectionClosedError as e:
            print("Connection closed unexpectedly: ", e)
            break
        except Exception as e:
            print("An error occurred: ", e)
            break
        print('here')  # debug trace kept from the original

        if isinstance(message, str):
            if message == '[end]\n':
                if not web1_message:
                    continue
                await websocket2.send(_strip_speaker_prefix(web1_message))
                # Start the next turn empty so earlier turns are not
                # forwarded again.
                web1_message = ''
                await _relay_reply_from_websocket2(websocket, websocket2)
            elif message.startswith('[+]'):
                # Transcription finished: stop playing audio.
                audio_player.stop_playing()
                print(f"\n{message}", end="\n", flush=False)
            elif message.startswith('[=]'):
                # Response finished.
                print(f"{message}", end="\n", flush=False)
            else:
                web1_message += message
                print(f"{message}", end="", flush=False)
        elif isinstance(message, bytes):
            _play_mp3_bytes(message)
        else:
            print("Unexpected message")
            break
def select_model():
    """Ask the user which LLM to use and return its model name.

    Prompts on stdin until one of the listed menu numbers is entered, so an
    unexpected answer no longer leaves the result unassigned.

    Returns:
        str: ``'gpt-3.5-turbo-16k'``, ``'gpt-4'`` or ``'claude-2'``.
    """
    models = {
        '1': 'gpt-3.5-turbo-16k',
        '2': 'gpt-4',
        '3': 'claude-2',
    }
    prompt = (
        '1: gpt-3.5-turbo-16k \n'
        '2: gpt-4 \n'
        '3: claude-2 \n'
        'Select llm model:')
    while True:
        choice = input(prompt).strip()
        if choice in models:
            return models[choice]
        print(f"Invalid selection {choice!r}; please enter 1, 2 or 3.")
async def start_client(client_id, url):
    """Open both websocket connections and run one interactive session.

    Asks for the LLM model, connects as *client_id* and as the fixed second
    client ``9999999``, exchanges the initial handshake messages, asks the
    user for a character and an input mode, then runs the sender task
    (audio or text) together with ``receive_message`` until both finish.

    Args:
        client_id: Id used in the first connection's URL path.
        url: ``host:port`` of the server (no scheme).
    """
    api_key = os.getenv('AUTH_API_KEY')
    llm_model = select_model()
    loop = asyncio.get_running_loop()
    second_client_id = 9999999

    uri = f"ws://{url}/ws/{client_id}?api_key={api_key}&llm_model={llm_model}"
    async with websockets.connect(uri) as websocket:
        uri2 = f"ws://{url}/ws/{second_client_id}?api_key={api_key}&llm_model={llm_model}"
        async with websockets.connect(uri2) as websocket2:
            # send client platform info
            await websocket.send('terminal')
            await websocket2.send('terminal')
            print(f"Client #{client_id} connected to websocket1")
            print(f"Client {second_client_id} connected to websocket2")

            welcome_message = await websocket.recv()
            # The second welcome message is read and dropped.
            await websocket2.recv()
            print(f"{welcome_message}")

            # The prompts below block on stdin; running them in a worker
            # thread keeps the event loop free to service both open
            # connections while the user is typing.
            character = await loop.run_in_executor(
                None, input, 'Select character: ')
            await websocket.send(character)
            # NOTE(review): '6' is a fixed selection for the second
            # connection — presumably a character menu entry; confirm.
            await websocket2.send('6')

            mode = await loop.run_in_executor(
                None, input, 'Select mode (1: audio, 2: text): ')
            if mode.strip() == '1':
                device_id = await loop.run_in_executor(
                    None, get_input_device_id)
                send_task = asyncio.create_task(
                    handle_audio(websocket, device_id))
            else:
                send_task = asyncio.create_task(handle_text(websocket))

            receive_task = asyncio.create_task(
                receive_message(websocket, websocket2))
            await asyncio.gather(receive_task, send_task)
async def main(url):
    """Run one client session against *url* using a random client id.

    Args:
        url: ``host:port`` of the server (no scheme).
    """
    client_id = random.randint(0, 1000000)
    await start_client(client_id, url)


if __name__ == "__main__":
    url = sys.argv[1] if len(sys.argv) > 1 else 'localhost:8000'
    try:
        asyncio.run(main(url))
    except KeyboardInterrupt:
        # Ctrl-C surfaces here, out of asyncio.run, rather than inside the
        # coroutine, so this is where it is reported.
        print("Client stopped by user")
    finally:
        # Make sure the playback worker is stopped so the process can exit.
        audio_player.stop_playing()