Spaces:
Sleeping
Sleeping
from multi_turn_utils import ( | |
execute_multi_turn_func_call, | |
is_empty_execute_response, | |
) | |
#### Main functions #### | |
def multi_turn_checker(
    multi_turn_model_result_list_decoded: list[list[list[str]]],
    multi_turn_ground_truth_list: list[list[str]],
    test_entry: dict,
    test_category: str,
    model_name: str,
) -> dict:
    """
    Execute the model's and the ground truth's function calls turn by turn and compare the outcomes.

    For every turn, each step of the model's decoded calls is executed, followed by the
    ground-truth calls for that turn. Turns whose ground truth is empty are still executed
    but are not checked here. For all other turns the model response must be non-empty,
    the resulting instances must pass `state_checker`, and the ground-truth execution
    results must pass `response_checker` against the model results accumulated over all
    turns so far.

    Note: the `test_category` argument is not used as passed in; the category is re-derived
    from `test_entry["id"]` (everything before the last underscore).

    Returns `{"valid": True}` when every checked turn passes, otherwise the error dict
    produced by the first failing check.
    """
    initial_config: dict = test_entry["initial_config"]
    involved_classes: list = test_entry["involved_classes"]
    test_entry_id: str = test_entry["id"]
    derived_category: str = test_entry_id.rsplit("_", 1)[0]

    # The long-context flag depends only on the test id, so compute it once up front.
    use_long_context = (
        "long_context" in derived_category or "composite" in derived_category
    )

    def run_calls(func_calls, executor_name):
        # Both the model and the ground truth run with identical settings;
        # only the call list and the executor name differ.
        return execute_multi_turn_func_call(
            func_call_list=func_calls,
            initial_config=initial_config,
            involved_classes=involved_classes,
            model_name=executor_name,
            test_entry_id=test_entry_id,
            long_context=use_long_context,
            is_evaL_run=True,
        )

    execution_results: list[dict] = []
    all_turn_model_execution_results: list[str] = []

    for turn_index, turn_ground_truth in enumerate(multi_turn_ground_truth_list):
        turn_model_response = multi_turn_model_result_list_decoded[turn_index]

        # Step results are kept both flattened (for comparison) and per step (for reporting).
        turn_model_results = []
        turn_model_results_by_step = []
        model_instances = {}  # Stays empty when the model produced no steps this turn
        for step_calls in turn_model_response:
            step_results, model_instances = run_calls(step_calls, model_name)
            turn_model_results.extend(step_results)
            turn_model_results_by_step.append(step_results)

        # Execute the ground truth function calls
        turn_ground_truth_results, ground_truth_instances = run_calls(
            turn_ground_truth, model_name + "_ground_truth"
        )

        all_turn_model_execution_results.extend(turn_model_results)
        execution_results.append(
            {
                "model": turn_model_results_by_step,
                "ground_truth": turn_ground_truth_results,
            }
        )

        # An empty ground truth marks a turn where the model is expected to be unable to
        # fulfil the request. The irrelevance check itself lives in
        # multi_turn_irrelevance_checker; any calls the model made were still executed
        # above so that the state check of the next turn is accurate.
        if not turn_ground_truth:
            continue

        # The ground truth is non-empty from here on, so the model must have responded.
        if not turn_model_response or is_empty_execute_response(turn_model_response):
            return {
                "valid": False,
                "error_message": f"Model response list is empty for turn {turn_index}",
                "error_type": "multi_turn:empty_turn_model_response",
                "details": {
                    "execution_result": execution_results,
                },
            }

        ## Check after each turn ##
        assert len(model_instances) == len(
            ground_truth_instances
        ), f"Model instances and ground truth instances do not match in length for turn {turn_index}. Model instances: {len(model_instances)}, Ground truth instances: {len(ground_truth_instances)}"
        assert set(model_instances.keys()) == set(ground_truth_instances.keys())

        # Check the state of the instances
        state_outcome = state_checker(model_instances, ground_truth_instances)
        if not state_outcome["valid"]:
            state_outcome["execution_result"] = execution_results
            return state_outcome

        # Check the responses of the function calls. The results of all turns so far are
        # used to accommodate a model that already invoked a function in an earlier turn
        # and therefore does not need to invoke it again in the current one.
        response_outcome = response_checker(
            all_turn_model_execution_results,
            turn_ground_truth_results,
            turn_index,
        )
        if not response_outcome["valid"]:
            return response_outcome

        # NOTE: the method-invoke-order check (method_invoke_order_checker) is
        # currently disabled and is not run as part of this loop.

    return {"valid": True}
def multi_turn_irrelevance_checker(
    multi_turn_model_result_list_decoded: list[list[list[str]]],
    multi_turn_ground_truth_list: list[list[str]],
) -> dict:
    """
    Verify that the model stays silent on turns where no function call is expected.

    A turn whose ground truth is an empty list must have a model response that
    `is_empty_execute_response` accepts as empty. Turns with a non-empty ground truth
    are not inspected here.

    Returns `{"valid": True}` if every such turn is clean, otherwise an error dict
    describing the first offending turn.
    """
    for turn_index, expected_calls in enumerate(multi_turn_ground_truth_list):
        model_turn_response = multi_turn_model_result_list_decoded[turn_index]

        # Only turns with an empty ground truth are subject to the irrelevance check.
        if len(expected_calls) != 0:
            continue

        if not is_empty_execute_response(model_turn_response):
            return {
                "valid": False,
                "error_message": f"Model outputs valid function calls when it should not for turn {turn_index}.",
                "error_type": "multi_turn:irrelevance_error:decoder_success",
                "details": {
                    "model response decoded": model_turn_response,
                },
            }

    return {"valid": True}
#### Sub-Checkers ####
def state_checker(model_instances: dict, ground_truth_instances: dict):
    """
    Compare the state of each model instance with its ground-truth counterpart.

    Instances are paired by their key in the two dicts (named `class_name` here), and
    each pair is compared with `_compare_instances`. The first pair that differs
    produces an error dict carrying the differences plus a snapshot of the
    non-underscore attributes of both instances.

    Returns `{"valid": True}` when every pair matches.
    """

    def public_state(instance) -> dict:
        # Attributes whose name starts with "_" are left out of the report.
        return {
            name: value
            for name, value in vars(instance).items()
            if not name.startswith("_")
        }

    for class_name in ground_truth_instances:
        expected_instance = ground_truth_instances[class_name]
        actual_instance = model_instances[class_name]

        is_same, differences = _compare_instances(actual_instance, expected_instance)
        if is_same:
            continue

        return {
            "valid": False,
            "error_message": f"Model instance for {class_name} does not match the state with ground truth instance.",
            "error_type": "multi_turn:instance_state_mismatch",
            "details": {
                "differences": differences,
                "model_instance_state": public_state(actual_instance),
                "ground_truth_instance_state": public_state(expected_instance),
            },
        }

    return {"valid": True}
def response_checker(
    model_response_list: list, ground_truth_response_list: list, turn_index: int
):
    """
    Verify that every ground-truth execution result also appears among the model's results.

    `ground_truth_response_list` holds the execution results of the current turn, while
    `model_response_list` holds the model's execution results (the error details label
    it as including all previous turns). Order is deliberately ignored: many entries
    contain parallel operations that the model may execute in any order.

    Returns `{"valid": True}` on success, otherwise an error dict listing the
    ground-truth results that could not be found.
    """
    all_found, missing_items = _is_subsequence_unordered(
        ground_truth_response_list, model_response_list
    )
    if all_found:
        return {"valid": True}

    return {
        "valid": False,
        "error_message": f"Model response execution results so far does not contain all the ground truth response execution results for turn {turn_index}.",
        "error_type": "multi_turn:execution_response_mismatch",
        "details": {
            "missing_items": missing_items,
            "model_response (including all previous turns)": model_response_list,
            "ground_truth_response (only the current turn)": ground_truth_response_list,
        },
    }
def method_invoke_order_checker(model_instances: dict, ground_truth_instances: dict):
    """
    Verify that each model instance invoked the ground truth's methods in the same relative order.

    The model may call extra methods in between, but the ground truth's sequence of
    method names must be an ordered subsequence of the model's sequence.
    Only method names are compared; call arguments are ignored.

    Returns `{"valid": True}` on success, otherwise an error dict for the first
    instance whose invoke order does not match.
    """

    def method_names(call_log) -> list:
        # Each logged call is a mapping; only its "method" entry is compared.
        return [entry["method"] for entry in call_log]

    for class_name in ground_truth_instances:
        ground_truth_instance = ground_truth_instances[class_name]
        model_instance = model_instances[class_name]

        # The get_method_called method is added by the LoggingMeta metaclass automatically
        model_call_log = model_instance.get_method_called()
        ground_truth_call_log = ground_truth_instance.get_method_called()

        model_sequence = method_names(model_call_log)
        ground_truth_sequence = method_names(ground_truth_call_log)

        in_order, missing_items = _is_subsequence(ground_truth_sequence, model_sequence)
        if in_order:
            continue

        return {
            "valid": False,
            "error_message": f"Model instance for {class_name} does not match the method invoke order with ground truth instance. Missing items: {missing_items}",
            "error_type": "multi_turn:method_invoke_order_mismatch",
        }

    return {"valid": True}
#### Helper functions #### | |
def _compare_instances(model_obect, ground_truth_object): | |
""" | |
Checks if the model_object has the same attributes as the ground_truth_object. They are instances of the same class. | |
""" | |
assert type(model_obect) == type( | |
ground_truth_object | |
), "Objects are not of the same type." | |
differences = {} | |
valid = True | |
for attr_name in vars(ground_truth_object): | |
# We don't check for private attributes | |
if attr_name.startswith("_"): | |
continue | |
model_attr = getattr(model_obect, attr_name) | |
ground_truth_attr = getattr(ground_truth_object, attr_name) | |
if model_attr != ground_truth_attr: | |
valid = False | |
differences[attr_name] = {"model": model_attr, "ground_truth": ground_truth_attr} | |
return valid, differences | |
def _is_subsequence(list1, list2) -> tuple[bool, list]: | |
""" | |
Checks if list1 is a subsequence of list2, i.e., all elements of list1 are present in list2 in the same order. | |
Also returns the elements of list1 that are not present in list2. | |
""" | |
# Convert list2 to an iterator to ensure that the elements are consumed only once. | |
iter_list2 = iter(list2) | |
return all(item in iter_list2 for item in list1), [ | |
item for item in list1 if item not in list2 | |
] | |
def _is_subsequence_unordered(list1, list2) -> tuple[bool, list]: | |
""" | |
Checks if all elements of list1 are present in list2, regardless of order. | |
Also returns the elements of list1 that are not present in list2. | |
""" | |
# Copy list2 to avoid modifying the original list during checks | |
list2_copy = list2[:] | |
# Check each item in list1 to see if it exists in list2_copy | |
missing_elements = [] | |
for item in list1: | |
try: | |
# Attempt to remove one occurrence of `item` from list2_copy to handle duplicates | |
list2_copy.remove(item) | |
except ValueError: | |
# If item is not found, add it to missing_elements | |
missing_elements.append(item) | |
# If there are missing elements, list1 is not a subsequence of list2 | |
is_subsequence = len(missing_elements) == 0 | |
return is_subsequence, missing_elements | |