Spaces:
Running
Running
import os | |
from dotenv import load_dotenv | |
from swarm_models import OpenAIChat | |
from swarms.structs.agent import Agent | |
from swarms.structs.groupchat import GroupChat, expertise_based | |
def setup_test_agents(): | |
model = OpenAIChat( | |
openai_api_key=os.getenv("OPENAI_API_KEY"), | |
model_name="gpt-4", | |
temperature=0.1, | |
) | |
return [ | |
Agent( | |
agent_name="Agent1", | |
system_prompt="You only respond with 'A'", | |
llm=model, | |
), | |
Agent( | |
agent_name="Agent2", | |
system_prompt="You only respond with 'B'", | |
llm=model, | |
), | |
Agent( | |
agent_name="Agent3", | |
system_prompt="You only respond with 'C'", | |
llm=model, | |
), | |
] | |
def test_round_robin_speaking(): | |
chat = GroupChat(agents=setup_test_agents()) | |
history = chat.run("Say your letter") | |
# Verify agents speak in order | |
responses = [ | |
r.message for t in history.turns for r in t.responses | |
] | |
assert responses == ["A", "B", "C"] * (len(history.turns)) | |
def test_concurrent_processing(): | |
chat = GroupChat(agents=setup_test_agents()) | |
tasks = ["Task1", "Task2", "Task3"] | |
histories = chat.concurrent_run(tasks) | |
assert len(histories) == len(tasks) | |
for history in histories: | |
assert history.total_messages > 0 | |
def test_expertise_based_speaking(): | |
agents = setup_test_agents() | |
chat = GroupChat(agents=agents, speaker_fn=expertise_based) | |
# Test each agent's expertise trigger | |
for agent in agents: | |
history = chat.run(f"Trigger {agent.system_prompt}") | |
first_response = history.turns[0].responses[0] | |
assert first_response.agent_name == agent.agent_name | |
def test_max_turns_limit(): | |
max_turns = 3 | |
chat = GroupChat(agents=setup_test_agents(), max_turns=max_turns) | |
history = chat.run("Test message") | |
assert len(history.turns) == max_turns | |
def test_error_handling(): | |
broken_agent = Agent( | |
agent_name="BrokenAgent", | |
system_prompt="You raise errors", | |
llm=None, | |
) | |
chat = GroupChat(agents=[broken_agent]) | |
history = chat.run("Trigger error") | |
assert "Error" in history.turns[0].responses[0].message | |
def test_conversation_context(): | |
agents = setup_test_agents() | |
complex_prompt = "Previous message refers to A. Now trigger B. Finally discuss C." | |
chat = GroupChat(agents=agents, speaker_fn=expertise_based) | |
history = chat.run(complex_prompt) | |
responses = [ | |
r.agent_name for t in history.turns for r in t.responses | |
] | |
assert all(agent.agent_name in responses for agent in agents) | |
def test_large_agent_group(): | |
large_group = setup_test_agents() * 5 # 15 agents | |
chat = GroupChat(agents=large_group) | |
history = chat.run("Test scaling") | |
assert history.total_messages > len(large_group) | |
def test_long_conversations(): | |
chat = GroupChat(agents=setup_test_agents(), max_turns=50) | |
history = chat.run("Long conversation test") | |
assert len(history.turns) == 50 | |
assert history.total_messages > 100 | |
def test_stress_batched_runs(): | |
chat = GroupChat(agents=setup_test_agents()) | |
tasks = ["Task"] * 100 | |
histories = chat.batched_run(tasks) | |
assert len(histories) == len(tasks) | |
total_messages = sum(h.total_messages for h in histories) | |
assert total_messages > len(tasks) * 3 | |
if __name__ == "__main__": | |
load_dotenv() | |
functions = [ | |
test_round_robin_speaking, | |
test_concurrent_processing, | |
test_expertise_based_speaking, | |
test_max_turns_limit, | |
test_error_handling, | |
test_conversation_context, | |
test_large_agent_group, | |
test_long_conversations, | |
test_stress_batched_runs, | |
] | |
for func in functions: | |
try: | |
print(f"Running {func.__name__}...") | |
func() | |
print("β Passed") | |
except Exception as e: | |
print(f"β Failed: {str(e)}") | |