dual_window / multi_turn_eval /multi_turn_utils.py
HuanzhiMao's picture
Fixing Dual Window Bug and Adding Single Model Demo (#1)
0157229 verified
raw
history blame
5.84 kB
import importlib
import inspect
import json
import re
import copy
CLASS_FILE_PATH_MAPPING = {
"GorillaFileSystem": "multi_turn_eval.func_source_code.gorilla_file_system",
"MathAPI": "multi_turn_eval.func_source_code.math_api",
"MessageAPI": "multi_turn_eval.func_source_code.message_api",
"TwitterAPI": "multi_turn_eval.func_source_code.posting_api",
"TicketAPI": "multi_turn_eval.func_source_code.ticket_api",
"TradingBot": "multi_turn_eval.func_source_code.trading_bot",
"TravelAPI": "multi_turn_eval.func_source_code.travel_booking",
"VehicleControlAPI": "multi_turn_eval.func_source_code.vehicle_control",
}
# These classes are stateless and do not require any initial configuration
STATELESS_CLASSES = [
"MathAPI",
]
def execute_multi_turn_func_call(
func_call_list: list[str], # a list of strings of func calls
initial_config: dict,
involved_classes: list,
model_name: str,
test_entry_id: str,
long_context: bool = False,
is_evaL_run: bool = False,
) -> tuple[list[str], dict]:
"""
TODO: Add docstring
"""
if is_evaL_run:
model_name += "_eval"
class_method_name_mapping = {}
involved_instances = {}
for class_name in involved_classes:
module_name = CLASS_FILE_PATH_MAPPING[class_name]
# TODO: Handler the model name issue from handler more elegantly
instance_name = (
f"{model_name.replace('-', '_').replace('.', '_').replace('/', '_')}_{str(test_entry_id).replace('-', '_')}_{class_name.lower()}_instance"
)
if instance_name not in globals():
module = importlib.import_module(module_name)
class_ = getattr(module, class_name)
class_instance = class_()
if class_name not in STATELESS_CLASSES:
class_initial_config = initial_config.get(class_name, {})
# Deep copy the initial configuration to avoid mutation issues
class_instance._load_scenario(
copy.deepcopy(class_initial_config), long_context=long_context
)
globals()[instance_name] = class_instance
# This happens in subsequent turns
else:
class_instance = globals()[instance_name]
involved_instances[class_name] = class_instance
# Retrieve all method names and map them to the instance
for method_name, method in inspect.getmembers(
class_instance, predicate=inspect.ismethod
):
# Skip private methods
if method_name.startswith("_"):
continue
class_method_name_mapping[method_name] = instance_name
execution_results = []
for func_call in func_call_list:
# Add the instance name to the method calls
func_call = _process_method_calls(func_call, class_method_name_mapping)
# Evaluate the function call
try:
# We need to make a copy here because otherwise the `eval(func_call)` would error.
func_call_copy = func_call
# Before calling `eval`, we need to make sure that the function call is safe
# We do so by checking if the function is `kill` or `exit`, etc.
# Extract the function name first
if "(" in func_call_copy:
func_call_copy = func_call_copy.split("(")[0]
# Situation where the function call is a method call
if "." in func_call_copy:
func_call_copy = func_call_copy.split(".")[1]
if func_call_copy in ["kill", "exit", "quit", "remove", "unlink", "popen", "Popen", "run"]:
raise Exception(f"Function call {func_call_copy} is not allowed.")
func_call_result = eval(func_call)
if type(func_call_result) == str:
pass
elif type(func_call_result) == dict:
# Some function returns a object instance, which is not serializable
try:
func_call_result = json.dumps(func_call_result)
except:
func_call_result = str(func_call_result)
else:
func_call_result = str(func_call_result)
execution_results.append(func_call_result)
except Exception as e:
execution_results.append(f"Error during execution: {str(e)}")
return execution_results, involved_instances
def is_empty_execute_response(input_list: list):
if len(input_list) == 0:
return True
if len(input_list) == 1 and len(input_list[0]) == 0:
return True
return False
def _process_method_calls(function_call_string: str, instance_mapping: dict) -> str:
"""
Prepends the instance name to the function name for each of the function name represented in the string, you will
also be provided with the mapping of method name to instance name.
Example input:
```
f(x = g((1, 2), h(3)), y = (4), z = (5, 6))
```
Example return:
```
a.f(x=a.g((1, 2), a.h(3)), y=(4), z=(5, 6))
```
Args:
function_call_string (str): The function call string to parse.
class_mapping (dict): A dictionary mapping method names to instance names.
Returns:
str: The parsed function call string with instance names prepended to method names.
"""
def replace_function(match):
func_name = match.group(1)
if func_name in instance_mapping:
return f"{instance_mapping[func_name]}.{func_name}"
return func_name
# Regular expression to match function names
pattern = r"\b([a-zA-Z_]\w*)\s*(?=\()"
# Replace function names with their class-prepended versions
processed_string = re.sub(pattern, replace_function, function_call_string)
return processed_string