|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import ast |
|
import difflib |
|
from collections.abc import Mapping |
|
from typing import Any, Callable, Dict |
|
|
|
|
|
class InterpretorError(ValueError):
    """
    Error signalling that the interpretor refuses or is unable to evaluate a piece of code, for instance because it
    uses an unsupported syntax node, calls a function that is not one of the allowed tools or reads an unknown
    variable.
    """
|
|
|
|
|
def evaluate(code: str, tools: Dict[str, Callable], state=None, chat_mode=False):
    """
    Run a snippet of Python code one top-level statement at a time, reading and writing variables in `state` and
    allowing calls to the functions listed in `tools` only.

    Problems are reported with `print` rather than raised: code that does not parse produces a message and no
    result, and an `InterpretorError` raised by a statement stops the evaluation at that statement.

    Args:
        code (`str`):
            The code to evaluate.
        tools (`Dict[str, Callable]`):
            The functions that may be called during the evaluation. Any call to another function will fail with an
            `InterpretorError`.
        state (`Dict[str, Any]`, *optional*):
            A dictionary mapping variable names to values. It should hold the initial inputs and is updated in place
            with every variable assigned during the evaluation. A fresh dictionary is used when `None` is passed.
        chat_mode (`bool`, *optional*, defaults to `False`):
            Whether or not the function is called from `Agent.chat`. It only changes the wording of the message
            printed when a statement fails.

    Returns:
        The last non-`None` value produced by a statement that was evaluated, or `None` when there is none (which
        includes the case where the code could not be parsed).
    """
    try:
        tree = ast.parse(code)
    except SyntaxError as e:
        print("The code generated by the agent is not valid.\n", e)
        return

    state = {} if state is None else state

    last_value = None
    for position, statement in enumerate(tree.body):
        try:
            value = evaluate_ast(statement, state, tools)
        except InterpretorError as e:
            # `position` is the index of the failing statement in the module body.
            report = f"Evaluation of the code stopped at line {position} before the end because of the following error"
            if chat_mode:
                report += (
                    f". Copy paste the following error message and send it back to the agent:\nI get an error: '{e}'"
                )
            else:
                report += f":\n{e}"
            print(report)
            # Stop at the first failure, keeping whatever was computed so far.
            return last_value
        if value is not None:
            last_value = value

    return last_value
|
|
|
|
|
def evaluate_ast(expression: ast.AST, state: Dict[str, Any], tools: Dict[str, Callable]):
    """
    Evaluate one node of an abstract syntax tree, using the variables stored in `state` and allowing calls to the
    functions listed in `tools` only.

    The function dispatches on the type of the node and recurses through its children. Supported nodes are
    assignments, calls, constants, dict and list displays, expression statements, `for` loops, `if` statements,
    f-strings, names and subscripts.

    Args:
        expression (`ast.AST`):
            The code to evaluate, as an abstract syntax tree.
        state (`Dict[str, Any]`):
            A dictionary mapping variable names to values. The `state` is updated if need be when the evaluation
            encounters assignments.
        tools (`Dict[str, Callable]`):
            The functions that may be called during the evaluation. Any call to another function will fail with an
            `InterpretorError`.

    Returns:
        The value the node evaluates to.

    Raises:
        InterpretorError: If the node is of a type that is not handled here.
    """
    if isinstance(expression, ast.Assign):
        # The assigned value is returned as it may be the final result of the evaluation.
        return evaluate_assign(expression, state, tools)

    if isinstance(expression, ast.Call):
        return evaluate_call(expression, state, tools)

    if isinstance(expression, ast.Constant):
        # Literal: the node carries its own value.
        return expression.value

    if isinstance(expression, ast.Dict):
        # Every key is evaluated first, then every value, before pairing them up.
        dict_keys = [evaluate_ast(key_node, state, tools) for key_node in expression.keys]
        dict_values = [evaluate_ast(value_node, state, tools) for value_node in expression.values]
        return dict(zip(dict_keys, dict_values))

    if isinstance(expression, ast.Expr):
        # Expression statement: evaluate what it wraps.
        return evaluate_ast(expression.value, state, tools)

    if isinstance(expression, ast.For):
        return evaluate_for(expression, state, tools)

    if isinstance(expression, ast.FormattedValue):
        # Placeholder of an f-string: only the inner expression is evaluated.
        return evaluate_ast(expression.value, state, tools)

    if isinstance(expression, ast.If):
        return evaluate_if(expression, state, tools)

    if hasattr(ast, "Index") and isinstance(expression, ast.Index):
        # `ast.Index` may not exist in every Python version, hence the `hasattr` guard.
        return evaluate_ast(expression.value, state, tools)

    if isinstance(expression, ast.JoinedStr):
        # f-string: concatenate the string form of each evaluated part.
        return "".join(str(evaluate_ast(part, state, tools)) for part in expression.values)

    if isinstance(expression, ast.List):
        return [evaluate_ast(item, state, tools) for item in expression.elts]

    if isinstance(expression, ast.Name):
        return evaluate_name(expression, state, tools)

    if isinstance(expression, ast.Subscript):
        return evaluate_subscript(expression, state, tools)

    # Anything else is outside of what this interpretor allows.
    raise InterpretorError(f"{expression.__class__.__name__} is not supported.")
|
|
|
|
|
def evaluate_assign(assign, state, tools):
    """
    Evaluate an assignment statement and store the outcome in `state`.

    The structure of `ast.Assign` is followed: each entry of `assign.targets` receives the same evaluated value
    (`a = b = value`), and a `Tuple` or `List` target unpacks that value over its elements (`a, b = value`).

    Args:
        assign (`ast.Assign`):
            The assignment node to evaluate.
        state (`Dict[str, Any]`):
            A dictionary mapping variable names to values, updated in place with the assigned variables.
        tools (`Dict[str, Callable]`):
            The functions that may be called while evaluating the right-hand side.

    Returns:
        The evaluated right-hand side of the assignment.

    Raises:
        InterpretorError: If a target is not a plain name or a tuple/list of plain names, or if the value cannot be
            unpacked into the requested number of variables.
    """
    result = evaluate_ast(assign.value, state, tools)

    for target in assign.targets:
        if isinstance(target, ast.Name):
            state[target.id] = result
            continue

        if not isinstance(target, (ast.Tuple, ast.List)):
            # Attribute or subscript targets would mutate arbitrary objects, which the interpretor does not allow.
            raise InterpretorError(f"{target.__class__.__name__} is not supported as an assignment target.")

        for element in target.elts:
            if not isinstance(element, ast.Name):
                raise InterpretorError(f"{element.__class__.__name__} is not supported as an assignment target.")

        try:
            values = list(result)
        except TypeError as e:
            raise InterpretorError(f"Cannot unpack a value of type {type(result).__name__}.") from e

        if len(values) != len(target.elts):
            raise InterpretorError(f"Expected {len(target.elts)} values but got {len(values)}.")
        for element, value in zip(target.elts, values):
            state[element.id] = value

    return result
|
|
|
|
|
def evaluate_call(call, state, tools):
    """
    Evaluate a function call, which is only allowed when the function is one of the provided tools.

    Args:
        call (`ast.Call`):
            The call node to evaluate.
        state (`Dict[str, Any]`):
            A dictionary mapping variable names to values, used to evaluate the arguments.
        tools (`Dict[str, Callable]`):
            The functions that may be called, indexed by the name under which the code refers to them.

    Returns:
        Whatever the tool returns when called with the evaluated positional and keyword arguments.

    Raises:
        InterpretorError: If the called object is not a plain name, if that name is not one of the tools, or if the
            call unpacks keyword arguments with `**`.
    """
    # Only `tool_name(...)` is accepted: method calls, calls on the result of a call, etc. are refused.
    if not isinstance(call.func, ast.Name):
        raise InterpretorError(
            f"It is not permitted to evaluate other functions than the provided tools (tried to execute {call.func} of "
            f"type {type(call.func)})."
        )
    func_name = call.func.id
    if func_name not in tools:
        raise InterpretorError(
            f"It is not permitted to evaluate other functions than the provided tools (tried to execute {call.func.id})."
        )

    # A `**mapping` argument is a keyword whose `arg` is `None`; passing it on as a `None` key would make the call
    # fail with a `TypeError` that the interpretor does not report, so it is rejected explicitly.
    for keyword in call.keywords:
        if keyword.arg is None:
            raise InterpretorError(f"Unpacking keyword arguments with `**` is not supported (in call to {func_name}).")

    func = tools[func_name]

    # Positional arguments are evaluated before keyword arguments.
    args = [evaluate_ast(arg, state, tools) for arg in call.args]
    kwargs = {keyword.arg: evaluate_ast(keyword.value, state, tools) for keyword in call.keywords}
    return func(*args, **kwargs)
|
|
|
|
|
def evaluate_subscript(subscript, state, tools):
    """
    Evaluate an indexing expression such as `value[index]`.

    Lists and tuples are indexed with the index converted to an integer. Other containers are indexed directly when
    they contain the index. As a last resort, a string index is matched approximately against the keys of a mapping
    and the value of the closest key is returned.

    Args:
        subscript (`ast.Subscript`):
            The subscript node to evaluate.
        state (`Dict[str, Any]`):
            A dictionary mapping variable names to values.
        tools (`Dict[str, Callable]`):
            The functions that may be called while evaluating the index and the indexed value.

    Returns:
        The selected element.

    Raises:
        InterpretorError: If the value cannot be indexed with the index (out of range, wrong type, missing key with
            no close match, ...).
    """
    index = evaluate_ast(subscript.slice, state, tools)
    value = evaluate_ast(subscript.value, state, tools)

    try:
        if isinstance(value, (list, tuple)):
            return value[int(index)]
        if index in value:
            return value[index]
    except (IndexError, KeyError, TypeError, ValueError) as e:
        # Report indexing failures as interpretor errors, like every other evaluation failure, so that the code
        # driving the evaluation can handle them uniformly.
        raise InterpretorError(f"Could not index {value} with '{index}'.") from e

    # Be lenient with slightly misspelled keys.
    if isinstance(index, str) and isinstance(value, Mapping):
        close_matches = difflib.get_close_matches(index, list(value.keys()))
        if len(close_matches) > 0:
            return value[close_matches[0]]

    raise InterpretorError(f"Could not index {value} with '{index}'.")
|
|
|
|
|
def evaluate_name(name, state, tools):
    """
    Look up the value of a variable in `state`.

    When the exact name is unknown, the closest name found by `difflib.get_close_matches` among the keys of `state`
    is used instead, if there is one.

    Args:
        name (`ast.Name`):
            The name node to resolve.
        state (`Dict[str, Any]`):
            A dictionary mapping variable names to values.
        tools (`Dict[str, Callable]`):
            Unused here; accepted so that all the evaluation helpers share the same signature.

    Returns:
        The value stored under the name, or under its closest match.

    Raises:
        InterpretorError: If neither the name nor a close match is present in `state`.
    """
    identifier = name.id
    if identifier in state:
        return state[identifier]

    # Fall back on an approximate match to tolerate small misspellings.
    candidates = difflib.get_close_matches(identifier, list(state.keys()))
    if not candidates:
        raise InterpretorError(f"The variable `{name.id}` is not defined.")
    return state[candidates[0]]
|
|
|
|
|
def evaluate_condition(condition, state, tools):
    """
    Evaluate the test of an `if` statement.

    A comparison with a single operator (`==`, `!=`, `<`, `<=`, `>`, `>=`, `is`, `is not`, `in`, `not in`) is
    computed here. Any other kind of test (a variable, a call to a tool, ...) is evaluated like a regular expression
    and its value is returned as is, leaving it to the caller to test its truthiness.

    Args:
        condition (`ast.expr`):
            The node holding the condition.
        state (`Dict[str, Any]`):
            A dictionary mapping variable names to values.
        tools (`Dict[str, Callable]`):
            The functions that may be called while evaluating the condition.

    Returns:
        The result of the comparison, or the value of the expression when the condition is not a comparison.

    Raises:
        InterpretorError: If the comparison chains several operators or uses an unsupported operator.
    """
    if not isinstance(condition, ast.Compare):
        # Only `ast.Compare` nodes have `ops`/`left`/`comparators`; everything else goes through the generic
        # evaluation, which rejects unsupported nodes with an `InterpretorError`.
        return evaluate_ast(condition, state, tools)

    if len(condition.ops) > 1:
        raise InterpretorError("Cannot evaluate conditions with multiple operators")

    left = evaluate_ast(condition.left, state, tools)
    comparator = condition.ops[0]
    right = evaluate_ast(condition.comparators[0], state, tools)

    comparisons = (
        (ast.Eq, lambda a, b: a == b),
        (ast.NotEq, lambda a, b: a != b),
        (ast.Lt, lambda a, b: a < b),
        (ast.LtE, lambda a, b: a <= b),
        (ast.Gt, lambda a, b: a > b),
        (ast.GtE, lambda a, b: a >= b),
        (ast.Is, lambda a, b: a is b),
        (ast.IsNot, lambda a, b: a is not b),
        (ast.In, lambda a, b: a in b),
        (ast.NotIn, lambda a, b: a not in b),
    )
    for operator_type, compare in comparisons:
        if isinstance(comparator, operator_type):
            return compare(left, right)

    raise InterpretorError(f"Operator not supported: {comparator}")
|
|
|
|
|
def evaluate_if(if_statement, state, tools):
    """
    Evaluate an `if` statement: run the body when the test holds and the `else` part otherwise.

    Args:
        if_statement (`ast.If`):
            The `if` node to evaluate.
        state (`Dict[str, Any]`):
            A dictionary mapping variable names to values, updated by the statements that are executed.
        tools (`Dict[str, Callable]`):
            The functions that may be called during the evaluation.

    Returns:
        The last non-`None` value produced by a statement of the executed branch, or `None` when there is none.
    """
    # Pick the branch once, then run its statements in order.
    if evaluate_condition(if_statement.test, state, tools):
        branch = if_statement.body
    else:
        branch = if_statement.orelse

    last_value = None
    for statement in branch:
        value = evaluate_ast(statement, state, tools)
        if value is not None:
            last_value = value
    return last_value
|
|
|
|
|
def evaluate_for(for_loop, state, tools):
    """
    Evaluate a `for` loop, storing the loop variable(s) in `state` at each iteration.

    The target may be a single name (`for x in ...`) or a tuple/list of names (`for a, b in ...`), in which case
    each item is unpacked over the names. The `else` clause of the loop, if any, runs once the iteration is over.
    The loop can only end through exhaustion or through an error, as `evaluate_ast` has no branch for a `break`
    statement and rejects it.

    Args:
        for_loop (`ast.For`):
            The loop node to evaluate.
        state (`Dict[str, Any]`):
            A dictionary mapping variable names to values, updated with the loop variables and by the statements of
            the body.
        tools (`Dict[str, Callable]`):
            The functions that may be called during the evaluation.

    Returns:
        The last non-`None` value produced by a statement of the loop, or `None` when there is none.

    Raises:
        InterpretorError: If the loop target is not a name or a tuple/list of names, or if an item cannot be
            unpacked into the requested number of variables.
    """
    target = for_loop.target

    def bind(item):
        # Store the current item under the loop variable(s).
        if isinstance(target, ast.Name):
            state[target.id] = item
            return
        if not isinstance(target, (ast.Tuple, ast.List)):
            raise InterpretorError(f"{target.__class__.__name__} is not supported as a loop target.")
        for element in target.elts:
            if not isinstance(element, ast.Name):
                raise InterpretorError(f"{element.__class__.__name__} is not supported as a loop target.")
        try:
            values = list(item)
        except TypeError as e:
            raise InterpretorError(f"Cannot unpack a value of type {type(item).__name__}.") from e
        if len(values) != len(target.elts):
            raise InterpretorError(f"Expected {len(target.elts)} values but got {len(values)}.")
        for element, value in zip(target.elts, values):
            state[element.id] = value

    result = None
    iterator = evaluate_ast(for_loop.iter, state, tools)
    for counter in iterator:
        bind(counter)
        for expression in for_loop.body:
            line_result = evaluate_ast(expression, state, tools)
            if line_result is not None:
                result = line_result

    # `for ... else`: the loop cannot be left early, so the `else` clause always follows the iteration.
    for expression in for_loop.orelse:
        line_result = evaluate_ast(expression, state, tools)
        if line_result is not None:
            result = line_result
    return result
|
|