Spaces:
Runtime error
Runtime error
import pytest | |
from unittest.mock import Mock, patch | |
from swarms.agents.agents import AgentNodeInitializer, AgentNode, agent # replace with actual import | |
# For initializing AgentNodeInitializer in multiple tests | |
def mock_agent_node_initializer(): | |
with patch('swarms.agents.agents.ChatOpenAI') as mock_llm, \ | |
patch('swarms.agents.agents.AutoGPT') as mock_agent: | |
initializer = AgentNodeInitializer(model_type='openai', model_id='test', openai_api_key='test_key', temperature=0.5) | |
initializer.llm = mock_llm | |
initializer.tools = [Mock(spec=BaseTool)] | |
initializer.vectorstore = Mock() | |
initializer.agent = mock_agent | |
return initializer | |
# Test initialize_llm method of AgentNodeInitializer class | |
def test_agent_node_initializer_initialize_llm(model_type, mock_agent_node_initializer): | |
with patch('swarms.agents.agents.ChatOpenAI') as mock_openai, \ | |
patch('swarms.agents.agents.HuggingFaceLLM') as mock_huggingface: | |
if model_type == 'invalid': | |
with pytest.raises(ValueError): | |
mock_agent_node_initializer.initialize_llm(model_type, 'model_id', 'openai_api_key', 0.5) | |
else: | |
mock_agent_node_initializer.initialize_llm(model_type, 'model_id', 'openai_api_key', 0.5) | |
if model_type == 'openai': | |
mock_openai.assert_called_once() | |
elif model_type == 'huggingface': | |
mock_huggingface.assert_called_once() | |
# Test add_tool method of AgentNodeInitializer class | |
def test_agent_node_initializer_add_tool(mock_agent_node_initializer): | |
with patch('swarms.agents.agents.BaseTool') as mock_base_tool: | |
mock_agent_node_initializer.add_tool(mock_base_tool) | |
assert mock_base_tool in mock_agent_node_initializer.tools | |
# Test run method of AgentNodeInitializer class | |
def test_agent_node_initializer_run(prompt, mock_agent_node_initializer): | |
if prompt == '': | |
with pytest.raises(ValueError): | |
mock_agent_node_initializer.run(prompt) | |
else: | |
assert mock_agent_node_initializer.run(prompt) == "Task completed by AgentNode" | |
# For initializing AgentNode in multiple tests | |
def mock_agent_node(): | |
with patch('swarms.agents.agents.ChatOpenAI') as mock_llm, \ | |
patch('swarms.agents.agents.AgentNodeInitializer') as mock_agent_node_initializer: | |
mock_agent_node = AgentNode('test_key') | |
mock_agent_node.llm_class = mock_llm | |
mock_agent_node.vectorstore = Mock() | |
mock_agent_node_initializer.llm = mock_llm | |
return mock_agent_node | |
# Test initialize_llm method of AgentNode class | |
def test_agent_node_initialize_llm(llm_class, mock_agent_node): | |
with patch('swarms.agents.agents.ChatOpenAI') as mock_openai, \ | |
patch('swarms.agents.agents.HuggingFaceLLM') as mock_huggingface: | |
mock_agent_node.initialize_llm(llm_class) | |
if llm_class == 'openai': | |
mock_openai.assert_called_once() | |
elif llm_class == 'huggingface': | |
mock_huggingface.assert_called_once() | |
# Test initialize_tools method of AgentNode class | |
def test_agent_node_initialize_tools(mock_agent_node): | |
with patch('swarms.agents.agents.DuckDuckGoSearchRun') as mock_ddg, \ | |
patch('swarms.agents.agents.WriteFileTool') as mock_write_file, \ | |
patch('swarms.agents.agents.ReadFileTool') as mock_read_file, \ | |
patch('swarms.agents.agents.process_csv') as mock_process_csv, \ | |
patch('swarms.agents.agents.WebpageQATool') as mock_webpage_qa: | |
mock_agent_node.initialize_tools('openai') | |
assert mock_ddg.called | |
assert mock_write_file.called | |
assert mock_read_file.called | |
assert mock_process_csv.called | |
assert mock_webpage_qa.called | |
# Test create_agent method of AgentNode class | |
def test_agent_node_create_agent(mock_agent_node): | |
with patch.object(mock_agent_node, 'initialize_llm'), \ | |
patch.object(mock_agent_node, 'initialize_tools'), \ | |
patch.object(mock_agent_node, 'initialize_vectorstore'), \ | |
patch('swarms.agents.agents.AgentNodeInitializer') as mock_agent_node_initializer: | |
mock_agent_node.create_agent() | |
mock_agent_node_initializer.assert_called_once() | |
mock_agent_node_initializer.return_value.create_agent.assert_called_once() | |
# Test agent function | |
def test_agent(openai_api_key, objective): | |
if openai_api_key == '' or objective == '': | |
with pytest.raises(ValueError): | |
agent(openai_api_key, objective) | |
else: | |
with patch('swarms.agents.agents.AgentNodeInitializer') as mock_agent_node_initializer: | |
mock_agent_node = mock_agent_node_initializer.return_value.create_agent.return_value | |
mock_agent_node.run.return_value = 'Agent output' | |
result = agent(openai_api_key, objective) | |
assert result == 'Agent output' | |