Spaces:
Sleeping
Sleeping
import hashlib
import time
from dataclasses import dataclass, field
from typing import List, Union
from uuid import uuid1
# Reserved role names. The moderator role is special-cased with full
# visibility in MessagePool.get_visible_messages.
MODERATOR_NAME: str = "Moderator"
SYSTEM_NAME: str = "System"
def _hash(input: str): | |
""" | |
Helper function that generates a SHA256 hash of a given input string. | |
Parameters: | |
input (str): The input string to be hashed. | |
Returns: | |
str: The SHA256 hash of the input string. | |
""" | |
hex_dig = hashlib.sha256(input.encode()).hexdigest() | |
return hex_dig | |
@dataclass
class Message:
    """
    Represents a message in the chatArena environment.

    Attributes:
        agent_name (str): Name of the agent who sent the message.
        content (str): Content of the message.
        turn (int): The turn at which the message was sent.
        timestamp (int): Wall time at which the message was sent, in nanoseconds since the epoch.
            Defaults to the time the instance is created.
        visible_to (Union[str, List[str]]): The receivers of the message. Can be a single agent,
            multiple agents, or 'all'. Defaults to 'all'.
        msg_type (str): Type of the message, e.g., 'text'. Defaults to 'text'.
        logged (bool): Whether the message is logged in the database. Defaults to False.
    """

    agent_name: str
    content: str
    turn: int
    # default_factory calls time.time_ns() for every new instance. A plain
    # `= time.time_ns()` default is evaluated once, when the class is defined,
    # and would give every message the same stale timestamp.
    timestamp: int = field(default_factory=time.time_ns)
    visible_to: Union[str, List[str]] = "all"
    msg_type: str = "text"
    logged: bool = False  # Whether the message is logged in the database

    def msg_hash(self) -> str:
        """
        Return a deterministic ID for this message.

        The ID is the SHA-256 hex digest of the sender, content, timestamp, turn and
        message type. `visible_to` and `logged` do not contribute to it.

        Returns:
            str: Hex digest identifying the message.
        """
        return _hash(
            f"agent: {self.agent_name}\ncontent: {self.content}\ntimestamp: {str(self.timestamp)}\nturn: {self.turn}\nmsg_type: {self.msg_type}"
        )
class MessagePool:
    """
    A pool to manage the messages in the chatArena environment.

    The pool is essentially a list of messages, and it allows a unified treatment of the
    visibility of the messages. A step may have one agent act per turn, or several agents
    act in the same turn (like in rock-paper-scissors). Visibility is keyed on the turn
    number, so agents acting in the same turn do not see each other's messages from that turn.

    Agents can only see the messages that 1) were sent before the current turn, and
    2) are visible to the current role.
    """

    def __init__(self):
        """Initialize the MessagePool with a unique conversation ID."""
        self.conversation_id = str(uuid1())
        # Messages in insertion order. "Message" is a forward reference, so this
        # class does not need the Message class at definition time.
        self._messages: List["Message"] = []
        self._last_message_idx = 0

    def reset(self):
        """Clear the message pool. The conversation ID is kept."""
        self._messages = []
        # Keep the bookkeeping consistent with a freshly constructed pool.
        self._last_message_idx = 0

    def append_message(self, message: "Message"):
        """
        Append a message to the pool.

        Parameters:
            message (Message): The message to be added to the pool.
        """
        self._messages.append(message)

    def print(self):
        """Print every message in the pool, one per line, as ``[sender->receivers]: content``."""
        for message in self._messages:
            # Inside this method, `print` refers to the builtin, not to this method.
            print(f"[{message.agent_name}->{message.visible_to}]: {message.content}")

    def last_turn(self) -> int:
        """
        Get the turn of the last message in the pool.

        Returns:
            int: The turn of the last message, or 0 if the pool is empty.
        """
        return self._messages[-1].turn if self._messages else 0

    def last_message(self):
        """
        Get the last message in the pool.

        Returns:
            Message: The last message, or None if the pool is empty.
        """
        return self._messages[-1] if self._messages else None

    def get_all_messages(self) -> List["Message"]:
        """
        Get all the messages in the pool.

        Returns:
            List[Message]: The pool's internal list (not a copy), in insertion order.
        """
        return self._messages

    @staticmethod
    def _is_visible_to(visible_to, agent_name) -> bool:
        """Return True if a message addressed to `visible_to` may be shown to `agent_name`."""
        if visible_to == "all":
            return True
        if isinstance(visible_to, str):
            # A single recipient name: compare exactly. `in` on a str would be
            # a substring test, e.g. "Bob" in "Bobby" is True.
            return agent_name == visible_to
        return agent_name in visible_to

    def get_visible_messages(self, agent_name, turn: int) -> List["Message"]:
        """
        Get all the messages visible to a given agent that were sent strictly before a turn.

        Parameters:
            agent_name (str): The name of the agent.
            turn (int): Only messages with `message.turn < turn` are considered.

        Returns:
            List[Message]: The visible messages, in insertion order.
        """
        # The moderator (same value as the module-level MODERATOR_NAME) sees
        # every earlier message, regardless of its `visible_to`.
        is_moderator = agent_name == "Moderator"
        return [
            message
            for message in self._messages
            if message.turn < turn
            and (is_moderator or self._is_visible_to(message.visible_to, agent_name))
        ]