import asyncio
import os
import uuid

from fastapi import (APIRouter, Depends, HTTPException, Path, WebSocket,
                     WebSocketDisconnect, Query)
from firebase_admin import auth
from firebase_admin.exceptions import FirebaseError
from requests import Session

from realtime_ai_character.audio.speech_to_text import (SpeechToText,
                                                        get_speech_to_text)
from realtime_ai_character.audio.text_to_speech import (TextToSpeech,
                                                        get_text_to_speech)
from realtime_ai_character.character_catalog.catalog_manager import (
    CatalogManager, get_catalog_manager)
from realtime_ai_character.database.connection import get_db
from realtime_ai_character.llm import (AsyncCallbackAudioHandler,
                                       AsyncCallbackTextHandler, get_llm, LLM)
from realtime_ai_character.logger import get_logger
from realtime_ai_character.models.interaction import Interaction
from realtime_ai_character.utils import (ConversationHistory, build_history,
                                         get_connection_manager)

logger = get_logger(__name__)

router = APIRouter()

manager = get_connection_manager()

GREETING_TXT = 'Hi, my friend, what brings you here today?'

# Strong references to fire-and-forget tasks. The event loop only keeps weak
# references to tasks, so a task nobody else references can be garbage
# collected before it finishes.
_background_tasks = set()


def _spawn_background(coro):
    """Schedule ``coro`` as a task and hold a reference until it is done."""
    task = asyncio.create_task(coro)
    _background_tasks.add(task)
    task.add_done_callback(_background_tasks.discard)
    return task


async def get_current_user(token: str) -> str:
    """Helper function for auth with Firebase.

    Args:
        token: Firebase ID token supplied by the client; may be empty/None.

    Returns:
        The ``uid`` field of the decoded token, or ``""`` when no token was
        supplied.

    Raises:
        HTTPException: with status 401 when Firebase rejects the token.
    """
    if not token:
        return ""
    try:
        decoded_token = auth.verify_id_token(token)
    except FirebaseError as e:
        # The token is a bearer credential, so only the error is logged.
        logger.info(f'Received invalid token with error {e}')
        raise HTTPException(
            status_code=401,
            detail="Invalid authentication credentials") from e

    return decoded_token['uid']


@router.websocket("/ws/{client_id}")
async def websocket_endpoint(websocket: WebSocket,
                             client_id: int = Path(...),
                             api_key: str = Query(None),
                             llm_model: str = Query(default=os.getenv(
                                 'LLM_MODEL_USE', 'gpt-3.5-turbo-16k')),
                             token: str = Query(None),
                             db: Session = Depends(get_db),
                             catalog_manager=Depends(get_catalog_manager),
                             speech_to_text=Depends(get_speech_to_text),
                             text_to_speech=Depends(get_text_to_speech)):
    """WebSocket entry point for a chat session.

    When the ``USE_AUTH`` environment variable is set, anonymous clients (no
    token) may only use the 'gpt-3.5-turbo-16k' model, and a supplied token
    must verify; otherwise the socket is closed with code 1008. After that
    the connection is registered with the connection manager and handed to
    ``handle_receive``.

    ``api_key`` is accepted as a query parameter but is not read here.
    """
    # Default user_id to client_id. If auth is enabled and token is provided, use
    # the user_id from the token.
    user_id = str(client_id)
    if os.getenv('USE_AUTH', ''):
        # Do not allow anonymous users to use non-GPT3.5 model.
        if not token and llm_model != 'gpt-3.5-turbo-16k':
            await websocket.close(code=1008, reason="Unauthorized")
            return
        # Only consult Firebase when a token was actually sent:
        # get_current_user returns "" for a missing token, which would
        # otherwise overwrite the client_id-based default above.
        if token:
            try:
                user_id = await get_current_user(token)
            except HTTPException:
                await websocket.close(code=1008, reason="Unauthorized")
                return
    llm = get_llm(model=llm_model)
    await manager.connect(websocket)
    try:
        main_task = asyncio.create_task(
            handle_receive(websocket, client_id, db, llm, catalog_manager,
                           speech_to_text, text_to_speech))

        await asyncio.gather(main_task)

    except WebSocketDisconnect:
        await manager.disconnect(websocket)
        await manager.broadcast_message(f"User #{user_id} left the chat")


async def handle_receive(websocket: WebSocket, client_id: int, db: Session,
                         llm: LLM, catalog_manager: CatalogManager,
                         speech_to_text: SpeechToText,
                         text_to_speech: TextToSpeech):
    """Run the per-connection conversation loop.

    Protocol, as implemented below:
      0. The first message's text is taken as the client platform.
      1. The client is prompted until it picks a character by number.
      2. A greeting is sent as text and streamed through text-to-speech.
      3. Each further message is either text (sent to the LLM and awaited
         inline) or bytes (transcribed, then sent to the LLM in a task that
         is cancelled when a newer transcript arrives).

    Finished text and audio exchanges are appended to the conversation
    history and saved as ``Interaction`` rows. A ``WebSocketDisconnect`` ends
    the loop and unregisters the socket from the connection manager.
    """
    try:
        conversation_history = ConversationHistory()
        # TODO: clean up client_id once migration is done.
        user_id = str(client_id)
        session_id = str(uuid.uuid4().hex)

        # 0. Receive client platform info (web, mobile, terminal)
        data = await websocket.receive()
        if data['type'] != 'websocket.receive':
            raise WebSocketDisconnect('disconnected')
        platform = data['text']
        logger.info(f"User #{user_id}:{platform} connected to server with "
                    f"session_id {session_id}")

        # 1. User selected a character
        character = None
        character_list = list(catalog_manager.characters.keys())
        user_input_template = 'Context:{context}\n User:{query}'
        while not character:
            character_message = "\n".join(
                f"{i+1} - {name}" for i, name in enumerate(character_list))
            await manager.send_message(
                message=
                f"Select your character by entering the corresponding number:\n"
                f"{character_message}\n",
                websocket=websocket)
            data = await websocket.receive()

            if data['type'] != 'websocket.receive':
                raise WebSocketDisconnect('disconnected')

            if 'text' not in data:
                # Non-text frame: prompt again.
                continue

            # Non-numeric input is treated as an invalid selection rather
            # than letting ValueError escape and kill the handler.
            try:
                selection = int(data['text'])
            except (TypeError, ValueError):
                selection = None
            if selection is None or not 1 <= selection <= len(character_list):
                await manager.send_message(
                    message=
                    f"Invalid selection. Select your character ["
                    f"{', '.join(catalog_manager.characters.keys())}]\n",
                    websocket=websocket)
                continue
            character = catalog_manager.get_character(
                character_list[selection - 1])
            conversation_history.system_prompt = character.llm_system_prompt
            user_input_template = character.llm_user_prompt
            logger.info(
                f"User #{user_id} selected character: {character.name}")

        tts_event = asyncio.Event()
        tts_task = None
        previous_transcript = None
        token_buffer = []

        # Greet the user
        await manager.send_message(message=GREETING_TXT, websocket=websocket)
        tts_task = asyncio.create_task(
            text_to_speech.stream(
                text=GREETING_TXT,
                websocket=websocket,
                tts_event=tts_event,
                characater_name=character.name,
                first_sentence=True,
            ))
        # Send end of the greeting so the client knows when to start listening
        await manager.send_message(message='[end]\n', websocket=websocket)

        async def on_new_token(token):
            """Forward one LLM token to the client."""
            return await manager.send_message(message=token,
                                              websocket=websocket)

        async def stop_audio():
            """Cancel the in-flight speech/LLM task, if any.

            When a transcript was being answered, it and the tokens streamed
            so far are recorded in the conversation history first.
            """
            if tts_task and not tts_task.done():
                tts_event.set()
                tts_task.cancel()
                if previous_transcript:
                    conversation_history.user.append(previous_transcript)
                    conversation_history.ai.append(' '.join(token_buffer))
                    token_buffer.clear()
                try:
                    await tts_task
                except asyncio.CancelledError:
                    pass
                tts_event.clear()

        while True:
            data = await websocket.receive()
            if data['type'] != 'websocket.receive':
                raise WebSocketDisconnect('disconnected')

            # handle text message
            if 'text' in data:
                msg_data = data['text']
                # 0. intermediate transcript starts with [&]
                if msg_data.startswith('[&]'):
                    logger.info(f'intermediate transcript: {msg_data}')
                    if not os.getenv('EXPERIMENT_CONVERSATION_UTTERANCE', ''):
                        continue
                    # Fire-and-forget, but keep references so the tasks are
                    # not garbage collected mid-execution.
                    _spawn_background(stop_audio())
                    _spawn_background(
                        llm.achat_utterances(
                            history=build_history(conversation_history),
                            user_input=msg_data,
                            callback=AsyncCallbackTextHandler(
                                on_new_token, []),
                            audioCallback=AsyncCallbackAudioHandler(
                                text_to_speech, websocket, tts_event,
                                character.name)))
                    continue
                # 1. Send message to LLM
                logger.info(
                    f'response = await llm.achat, user_input {msg_data}')
                response = await llm.achat(
                    history=build_history(conversation_history),
                    user_input=msg_data,
                    user_input_template=user_input_template,
                    callback=AsyncCallbackTextHandler(on_new_token,
                                                      token_buffer),
                    audioCallback=AsyncCallbackAudioHandler(
                        text_to_speech, websocket, tts_event, character.name),
                    character=character)

                # 2. Send response to client
                await manager.send_message(message='[end]\n',
                                           websocket=websocket)

                # 3. Update conversation history
                conversation_history.user.append(msg_data)
                conversation_history.ai.append(response)
                token_buffer.clear()
                # 4. Persist interaction in the database
                Interaction(client_id=client_id,
                            user_id=user_id,
                            session_id=session_id,
                            client_message_unicode=msg_data,
                            server_message_unicode=response,
                            platform=platform,
                            action_type='text').save(db)

            # handle binary message(audio)
            elif 'bytes' in data:
                binary_data = data['bytes']
                # 1. Transcribe audio
                transcript: str = speech_to_text.transcribe(
                    binary_data, platform=platform,
                    prompt=character.name).strip()

                # ignore audio that picks up background noise
                if not transcript or len(transcript) < 2:
                    continue

                # 2. Send transcript to client
                await manager.send_message(
                    message=f'[+]You said: {transcript}', websocket=websocket)

                # 3. stop the previous audio stream, if new transcript is received
                await stop_audio()

                previous_transcript = transcript

                # The transcript is bound at definition time (keyword-only
                # default). A plain closure would read whatever `transcript`
                # holds when the callback eventually runs, which may already
                # be a later (or empty, noise-only) transcription.
                async def tts_task_done_call_back(
                        response, *, user_transcript=transcript):
                    # Send response to client, [=] indicates the response is done
                    await manager.send_message(message='[=]',
                                               websocket=websocket)
                    # Update conversation history
                    conversation_history.user.append(user_transcript)
                    conversation_history.ai.append(response)
                    token_buffer.clear()
                    # Persist interaction in the database
                    Interaction(client_id=client_id,
                                user_id=user_id,
                                session_id=session_id,
                                client_message_unicode=user_transcript,
                                server_message_unicode=response,
                                platform=platform,
                                action_type='audio').save(db)

                # 4. Send message to LLM
                tts_task = asyncio.create_task(
                    llm.achat(history=build_history(conversation_history),
                              user_input=transcript,
                              user_input_template=user_input_template,
                              callback=AsyncCallbackTextHandler(
                                  on_new_token, token_buffer,
                                  tts_task_done_call_back),
                              audioCallback=AsyncCallbackAudioHandler(
                                  text_to_speech, websocket, tts_event,
                                  character.name),
                              character=character))

    except WebSocketDisconnect:
        logger.info(f"User #{user_id} closed the connection")
        await manager.disconnect(websocket)
        return