Spaces:
Sleeping
Sleeping
from unittest import TestCase, mock | |
from lagent.actions import ActionExecutor | |
from lagent.actions.llm_qa import LLMQA | |
from lagent.actions.serper_search import SerperSearch | |
from lagent.agents.rewoo import ReWOO, ReWOOProtocol | |
from lagent.schema import ActionReturn, ActionStatusCode | |
class TestReWOO(TestCase):
    """Unit tests for the ReWOO agent and the ReWOOProtocol worker parser."""

    @staticmethod
    def _successful_return(text):
        """Build an ActionReturn in the SUCCESS state whose result is ``{'text': text}``."""
        action_return = ActionReturn(args=None)
        action_return.state = ActionStatusCode.SUCCESS
        action_return.result = dict(text=text)
        return action_return

    # The three ``mock_*`` parameters can only be supplied by ``mock.patch``
    # decorators; without them the runner calls the test with ``self`` alone
    # and it fails with a TypeError before any assertion runs.
    # Decorators are applied bottom-up, so the lowest one feeds the first
    # mock parameter.
    # NOTE(review): ``'run'`` as the patch target on the two actions is an
    # assumption -- confirm it is the method ActionExecutor actually invokes.
    @mock.patch.object(SerperSearch, 'run')
    @mock.patch.object(LLMQA, 'run')
    @mock.patch.object(ReWOOProtocol, 'parse_worker')
    def test_normal_chat(self, mock_parse_worker_func, mock_qa_func,
                         mock_search_func):
        """chat() should return the stubbed LLM output as the agent response."""
        mock_model = mock.Mock()
        mock_model.generate_from_template.return_value = 'LLM response'

        # parse_worker is stubbed to yield (thoughts, actions, action inputs).
        mock_parse_worker_func.return_value = (
            ['Thought1', 'Thought2'],
            ['LLMQA', 'SerperSearch'],
            ['abc', 'abc'],
        )
        mock_search_func.return_value = self._successful_return(
            'search_return')
        mock_qa_func.return_value = self._successful_return('qa_return')

        chatbot = ReWOO(
            llm=mock_model,
            action_executor=ActionExecutor(actions=[
                LLMQA(mock_model),
                SerperSearch(api_key=''),
            ]))
        agent_return = chatbot.chat('abc')

        self.assertEqual(agent_return.response, 'LLM response')

    def test_parse_worker(self):
        """parse_worker() should reject malformed plans and split valid ones."""
        prompt = ReWOOProtocol()

        # Two "#E" actions under a single "Plan:" must be rejected with a
        # specific error message.
        message = ('\n'
                   'Plan: a.\n'
                   '#E1 = tool1["a"]\n'
                   '#E2 = tool2["b"]\n')
        with self.assertRaises(
                Exception,
                msg='it should raise exception when the format is incorrect'
        ) as raised:
            prompt.parse_worker(message)
        self.assertEqual(
            'Each Plan should only correspond to only ONE action',
            str(raised.exception))

        # The first action uses parentheses instead of square brackets; any
        # exception is accepted here.
        message = ('\n'
                   'Plan: a.\n'
                   '#E1 = tool1("a")\n'
                   'Plan: b.\n'
                   '#E2 = tool2["b"]\n')
        with self.assertRaises(
                Exception,
                msg='it should raise exception when the format is incorrect'):
            prompt.parse_worker(message)

        # One action per plan, both in bracket form: must parse cleanly.
        message = ('\n'
                   'Plan: a.\n'
                   '#E1 = tool1["a"]\n'
                   'Plan: b.\n'
                   '#E2 = tool2["b"]\n')
        try:
            thoughts, actions, actions_input = prompt.parse_worker(message)
        except Exception:
            # Report a test failure (not an error) with the intended message.
            self.fail(
                'it should not raise exception when the format is correct')
        self.assertEqual(thoughts, ['a.', 'b.'])
        self.assertEqual(actions, ['tool1', 'tool2'])
        self.assertEqual(actions_input, ['"a"', '"b"'])