Spaces:
Sleeping
Sleeping
File size: 4,612 Bytes
bdafe83 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 |
import hashlib
import time
from dataclasses import dataclass, field
from typing import List, Union
from uuid import uuid1
# Reserved role names.
# NOTE(review): MessagePool.get_visible_messages grants full visibility to the
# literal "Moderator" rather than referencing MODERATOR_NAME; keep the two in sync.
SYSTEM_NAME = "System"
MODERATOR_NAME = "Moderator"
def _hash(input: str):
"""
Helper function that generates a SHA256 hash of a given input string.
Parameters:
input (str): The input string to be hashed.
Returns:
str: The SHA256 hash of the input string.
"""
hex_dig = hashlib.sha256(input.encode()).hexdigest()
return hex_dig
@dataclass
class Message:
    """
    Represents a message in the chatArena environment.

    Attributes:
        agent_name (str): Name of the agent who sent the message.
        content (str): Content of the message.
        turn (int): The turn at which the message was sent.
        timestamp (int): Wall time at which the message was sent, in
            nanoseconds since the epoch (``time.time_ns()``). Defaults to the
            moment this Message instance is created.
        visible_to (Union[str, List[str]]): The receivers of the message: a
            single agent name, a list of agent names, or 'all'. Defaults to 'all'.
        msg_type (str): Type of the message, e.g., 'text'. Defaults to 'text'.
        logged (bool): Whether the message is logged in the database. Defaults to False.
    """

    agent_name: str
    content: str
    turn: int
    # default_factory is required: a plain `= time.time_ns()` default is
    # evaluated once at class-definition time, so every message would share
    # the same (import-time) timestamp and could produce colliding msg_hash ids.
    timestamp: int = field(default_factory=time.time_ns)
    visible_to: Union[str, List[str]] = "all"
    msg_type: str = "text"
    logged: bool = False  # Whether the message is logged in the database

    @property
    def msg_hash(self):
        """
        Return a SHA-256 hex digest used as this message's id.

        The digest covers agent_name, content, timestamp, turn and msg_type;
        visible_to and logged do not contribute to it.
        """
        return _hash(
            f"agent: {self.agent_name}\ncontent: {self.content}\ntimestamp: {str(self.timestamp)}\nturn: {self.turn}\nmsg_type: {self.msg_type}"
        )
class MessagePool:
    """
    A pool to manage the messages in the chatArena environment.

    The pool is essentially a list of messages, and it allows a unified
    treatment of the visibility of the messages.

    Several players may act within the same turn (as in rock-paper-scissors).
    A message only becomes visible once a strictly later turn is requested, so
    players acting in the same turn cannot see each other's messages.

    Agents can only see the messages that 1) were sent before the current
    turn, and 2) are visible to the current role.
    """

    def __init__(self):
        """Initialize the MessagePool with a unique conversation ID."""
        self.conversation_id = str(uuid1())
        self._messages: List["Message"] = []
        self._last_message_idx = 0

    def reset(self):
        """
        Clear the message pool.

        The conversation ID is kept. The stored messages are dropped, and
        ``_last_message_idx`` returns to its initial value, as in ``__init__``.
        """
        self._messages = []
        self._last_message_idx = 0

    def append_message(self, message: "Message"):
        """
        Append a message to the pool.

        Parameters:
            message (Message): The message to be added to the pool.
        """
        self._messages.append(message)

    def print(self):
        """Print every message in the pool as ``[sender->receivers]: content``."""
        for message in self._messages:
            # Name lookup inside a method body resolves to the builtin print.
            print(f"[{message.agent_name}->{message.visible_to}]: {message.content}")

    @property
    def last_turn(self):
        """
        Get the turn of the last message in the pool.

        Returns:
            int: The turn of the last message, or 0 if the pool is empty.
        """
        if not self._messages:
            return 0
        return self._messages[-1].turn

    @property
    def last_message(self):
        """
        Get the last message in the pool.

        Returns:
            Message: The last message, or None if the pool is empty.
        """
        if not self._messages:
            return None
        return self._messages[-1]

    def get_all_messages(self) -> List["Message"]:
        """
        Get all the messages in the pool.

        Returns:
            List[Message]: The pool's internal list (not a copy), in insertion order.
        """
        return self._messages

    def get_visible_messages(self, agent_name, turn: int) -> List["Message"]:
        """
        Get all the messages visible to a given agent before a specified turn.

        Parameters:
            agent_name (str): The name of the agent.
            turn (int): Only messages with ``message.turn < turn`` are considered.

        Returns:
            List[Message]: The visible messages, in insertion order.
        """
        return [
            message
            for message in self._messages
            if message.turn < turn and self._is_visible_to(message, agent_name)
        ]

    @staticmethod
    def _is_visible_to(message, agent_name) -> bool:
        """Return True if ``agent_name`` may read ``message``."""
        # The moderator sees everything (mirrors the MODERATOR_NAME value).
        if agent_name == "Moderator":
            return True
        receivers = message.visible_to
        if receivers == "all":
            return True
        # A single receiver must match exactly. A plain `in` on a str would do
        # substring matching, e.g. "Player1" in "Player10".
        if isinstance(receivers, str):
            return agent_name == receivers
        return agent_name in receivers
|