Yiqiao Jin
Initial Commit
bdafe83
raw
history blame
4.61 kB
import hashlib
import time
from dataclasses import dataclass, field
from typing import List, Union
from uuid import uuid1
# Reserved role names.
# NOTE(review): no use of SYSTEM_NAME is visible in this file — confirm its meaning with callers.
SYSTEM_NAME = "System"
# Same string that MessagePool.get_visible_messages compares agent names against
# when it lets one agent see every message regardless of `visible_to`.
MODERATOR_NAME = "Moderator"
def _hash(input: str):
    """
    Compute the SHA256 digest of a string.

    The string is encoded with the default ``str.encode()`` codec (UTF-8)
    before hashing.

    Parameters:
        input (str): The text to hash.

    Returns:
        str: The digest as a 64-character lowercase hexadecimal string.
    """
    hasher = hashlib.sha256()
    hasher.update(input.encode())
    return hasher.hexdigest()
@dataclass
class Message:
    """
    Represents a message in the chatArena environment.

    Attributes:
        agent_name (str): Name of the agent who sent the message.
        content (str): Content of the message.
        turn (int): The turn at which the message was sent.
        timestamp (int): Wall time at which the message was sent, in nanoseconds.
            Defaults to the time at which the Message instance is created.
        visible_to (Union[str, List[str]]): The receivers of the message. Can be a single agent,
            multiple agents, or 'all'. Defaults to 'all'.
        msg_type (str): Type of the message, e.g., 'text'. Defaults to 'text'.
        logged (bool): Whether the message is logged in the database. Defaults to False.
    """

    agent_name: str
    content: str
    turn: int
    # Use a default_factory so the clock is read for every new message. A plain
    # `= time.time_ns()` default is evaluated once, when the class is defined,
    # which would give every message the same timestamp (and identical hashes
    # for otherwise-equal messages).
    timestamp: int = field(default_factory=time.time_ns)
    visible_to: Union[str, List[str]] = "all"
    msg_type: str = "text"
    logged: bool = False  # Whether the message is logged in the database

    @property
    def msg_hash(self):
        """
        Message id derived from the message's fields.

        Returns:
            str: SHA256 hex digest over agent name, content, timestamp, turn and message type.
                `visible_to` and `logged` do not take part in the hash.
        """
        return _hash(
            f"agent: {self.agent_name}\ncontent: {self.content}\ntimestamp: {str(self.timestamp)}\nturn: {self.turn}\nmsg_type: {self.msg_type}"
        )
class MessagePool:
    """
    A pool to manage the messages in the chatArena environment.

    The pool is essentially a list of messages, and it allows a unified treatment of the
    visibility of the messages. Several messages may share the same turn number, so multiple
    players can act in the same turn (like in rock-paper-scissors).

    Agents can only see the messages that 1) were sent before the current turn, and
    2) are visible to the current role. The agent named "Moderator" sees every message
    sent before the current turn.
    """

    def __init__(self):
        """Initialize the MessagePool with a unique conversation ID."""
        self.conversation_id = str(uuid1())
        self._messages: List["Message"] = []
        self._last_message_idx = 0

    def reset(self):
        """Clear the message pool."""
        self._messages = []
        # An index into the previous message list is meaningless once the list is emptied.
        self._last_message_idx = 0

    def append_message(self, message: "Message"):
        """
        Append a message to the pool.

        Parameters:
            message (Message): The message to be added to the pool.
        """
        self._messages.append(message)

    def print(self):
        """Print all the messages in the pool, one per line, as ``[sender->receivers]: content``."""
        for message in self._messages:
            print(f"[{message.agent_name}->{message.visible_to}]: {message.content}")

    @property
    def last_turn(self):
        """
        Get the turn of the last message in the pool.

        Returns:
            int: The turn of the last message, or 0 if the pool is empty.
        """
        if not self._messages:
            return 0
        return self._messages[-1].turn

    @property
    def last_message(self):
        """
        Get the last message in the pool.

        Returns:
            Message: The last message, or None if the pool is empty.
        """
        if not self._messages:
            return None
        return self._messages[-1]

    def get_all_messages(self) -> List["Message"]:
        """
        Get all the messages in the pool.

        Returns:
            List[Message]: The pool's own message list (not a copy).
        """
        return self._messages

    @staticmethod
    def _is_visible_to(message: "Message", agent_name) -> bool:
        """Return whether ``message`` may be shown to the agent called ``agent_name``."""
        if agent_name == "Moderator":
            return True
        receivers = message.visible_to
        if isinstance(receivers, str):
            # A string is either the "all" broadcast or exactly one receiver name.
            # Compare whole names: `in` on a str is a substring test, which would
            # let "Player 1" read messages addressed to "Player 10".
            return receivers == "all" or receivers == agent_name
        return agent_name in receivers

    def get_visible_messages(self, agent_name, turn: int) -> List["Message"]:
        """
        Get all the messages that are visible to a given agent before a specified turn.

        Parameters:
            agent_name (str): The name of the agent.
            turn (int): The specified turn; only messages with a strictly smaller turn are returned.

        Returns:
            List[Message]: A new list of the visible messages, in pool order.
        """
        return [
            message
            for message in self._messages
            if message.turn < turn and self._is_visible_to(message, agent_name)
        ]