pycui's picture
Add RealChar deployment for HuggingFace (V0)
babeaf6
raw
history blame contribute delete
No virus
2.53 kB
from dataclasses import field
from typing import List
from langchain.schema import AIMessage, BaseMessage, HumanMessage, SystemMessage
from pydantic.dataclasses import dataclass
from starlette.websockets import WebSocket, WebSocketState
@dataclass
class Character:
    """A chat character: its name plus the prompt text used with the LLM."""

    # Human-readable name of the character.
    name: str
    # Prompt text intended for the LLM's system role.
    llm_system_prompt: str
    # Prompt text intended for the user turn — presumably a template filled
    # in per message; confirm against callers.
    llm_user_prompt: str
@dataclass
class ConversationHistory:
    """Running transcript of one conversation.

    ``user`` and ``ai`` are parallel lists: ``user[k]`` is answered by
    ``ai[k]``. Iterating yields the system prompt first, then alternating
    user/AI messages.
    """

    # Prompt emitted as the first item when iterating.
    system_prompt: str = ''
    # Messages sent by the human side, in order.
    user: list[str] = field(default_factory=list)
    # Replies produced by the model, in order.
    ai: list[str] = field(default_factory=list)

    def __iter__(self):
        """Yield system_prompt, then user/ai messages interleaved pairwise.

        Only complete (user, ai) pairs are produced: if one list is longer,
        its unmatched trailing entries are not yielded.
        """
        yield self.system_prompt
        for exchange in zip(self.user, self.ai):
            yield from exchange
def build_history(conversation_history: ConversationHistory) -> List[BaseMessage]:
    """Convert a conversation transcript into LangChain message objects.

    The first item produced by iterating *conversation_history* becomes a
    ``SystemMessage``. After that, odd positions become ``HumanMessage`` and
    even positions become ``AIMessage``.

    Returns:
        A new list of messages in transcript order.
    """
    messages: List[BaseMessage] = []
    for position, text in enumerate(conversation_history):
        if position == 0:
            message_cls = SystemMessage
        elif position % 2:
            # Odd slots are the human turns.
            message_cls = HumanMessage
        else:
            # Even slots after the system prompt are the model's replies.
            message_cls = AIMessage
        messages.append(message_cls(content=text))
    return messages
class Singleton:
    """Base class giving each subclass one lazily created, shared instance.

    Instances are stored per concrete class, so every subclass gets its own
    instance. Constructor arguments are used only by the call that actually
    creates the instance; later calls ignore them.
    """

    # Registry shared by all subclasses, keyed by the concrete class object.
    _instances = {}

    @classmethod
    def get_instance(cls, *args, **kwargs):
        """Return the instance of *cls*, constructing it on the first call.

        ``*args`` / ``**kwargs`` are forwarded to ``cls(...)`` only when no
        instance exists yet; otherwise they are ignored.
        """
        if cls not in cls._instances:
            cls._instances[cls] = cls(*args, **kwargs)
        return cls._instances[cls]

    @classmethod
    def initialize(cls, *args, **kwargs):
        """Create the instance of *cls* if it does not exist yet; return None.

        Behaves like :meth:`get_instance` without returning the instance.
        Call it before any other access if the singleton must be built with
        specific arguments, since arguments passed after creation are ignored.
        """
        # Delegate so creation logic lives in exactly one place.
        cls.get_instance(*args, **kwargs)
class ConnectionManager(Singleton):
    """Track active WebSocket connections and send text messages to them.

    Intended to be used as a shared instance via ``get_instance()``
    (inherited from ``Singleton``) rather than constructed directly.
    """

    def __init__(self):
        self.active_connections: List[WebSocket] = []

    async def connect(self, websocket: WebSocket):
        """Accept the WebSocket handshake, then register the connection."""
        await websocket.accept()
        self.active_connections.append(websocket)

    async def disconnect(self, websocket: WebSocket):
        """Unregister *websocket*; a socket that is not registered is ignored.

        This only removes the bookkeeping entry; it does not close the socket.
        """
        # Guard so a repeated disconnect for the same socket (e.g. from both
        # an error path and a cleanup path) does not raise ValueError.
        if websocket not in self.active_connections:
            return
        self.active_connections.remove(websocket)
        print(f"Client #{id(websocket)} left the chat")

    async def send_message(self, message: str, websocket: WebSocket):
        """Send *message* to one socket if its application side is CONNECTED.

        Otherwise the message is dropped silently.
        """
        if websocket.application_state == WebSocketState.CONNECTED:
            await websocket.send_text(message)

    async def broadcast_message(self, message: str):
        """Send *message* to every registered socket that is still CONNECTED."""
        # Iterate over a snapshot: each send_text() awaits, and other
        # coroutines may connect/disconnect meanwhile. Mutating the live list
        # during iteration would skip or repeat connections.
        for connection in list(self.active_connections):
            if connection.application_state == WebSocketState.CONNECTED:
                await connection.send_text(message)
def get_connection_manager():
    """Return the shared ConnectionManager, creating it on first use."""
    manager = ConnectionManager.get_instance()
    return manager