# Checkers for multi-turn function-call execution results: compare a model's
# executed calls against ground-truth execution, by instance state and by response.
from multi_turn_utils import (
execute_multi_turn_func_call,
is_empty_execute_response,
)
#### Main functions ####
def multi_turn_checker(
    multi_turn_model_result_list_decoded: list[list[list[str]]],
    multi_turn_ground_truth_list: list[list[str]],
    test_entry: dict,
    test_category: str,
    model_name: str,
) -> dict:
    """
    The main function that checks the correctness of the model's function call execution.

    For each user turn, executes the model's decoded function calls (step by step)
    and the ground-truth calls for that turn, then compares:
      1. the state of the API class instances after the turn (state_checker), and
      2. the accumulated execution responses so far (response_checker).

    Args:
        multi_turn_model_result_list_decoded: Per-turn, per-step lists of the
            model's decoded function-call strings.
        multi_turn_ground_truth_list: Per-turn lists of ground-truth function-call strings.
        test_entry: Dataset entry; must provide "initial_config", "involved_classes"
            and "id".
        test_category: Category label; note it is recomputed from the entry id below.
        model_name: Name of the model under evaluation (ground-truth execution uses
            a "_ground_truth" suffix so the two runs are kept separate).

    Returns:
        {"valid": True} if every turn passes, otherwise a dict with "valid": False,
        an "error_message", an "error_type", and diagnostic "details".
    """
    initial_config: dict = test_entry["initial_config"]
    involved_classes: list = test_entry["involved_classes"]
    test_entry_id: str = test_entry["id"]
    # NOTE(review): this deliberately shadows the `test_category` parameter with
    # the category derived from the entry id; the parameter value is unused after this.
    test_category: str = test_entry_id.rsplit("_", 1)[0]
    execution_results: list[dict] = []
    # Model execution responses accumulated across ALL turns so far (see response check below).
    all_turn_model_execution_results: list[str] = []

    # First execute all the function calls
    for turn_index, single_turn_ground_truth_list in enumerate(
        multi_turn_ground_truth_list
    ):
        single_turn_model_response_list = multi_turn_model_result_list_decoded[turn_index]
        # Note that we combine all the sub-step results into a single list, for easier comparison
        single_turn_model_execution_results = []
        single_turn_model_execution_results_uncombined = []
        single_turn_ground_truth_execution_results = []
        model_instances = {}  # Will be overwritten in the for loop
        single_step_model_execution_results = []  # Will be overwritten in the for loop

        for single_step_model_response in single_turn_model_response_list:
            single_step_model_execution_results, model_instances = (
                execute_multi_turn_func_call(
                    func_call_list=single_step_model_response,
                    initial_config=initial_config,
                    involved_classes=involved_classes,
                    model_name=model_name,
                    test_entry_id=test_entry_id,
                    long_context=(
                        "long_context" in test_category or "composite" in test_category
                    ),
                    is_evaL_run=True,
                )
            )
            single_turn_model_execution_results.extend(single_step_model_execution_results)
            single_turn_model_execution_results_uncombined.append(single_step_model_execution_results)

        # Execute the ground truth function calls
        single_turn_ground_truth_execution_results, ground_truth_instances = (
            execute_multi_turn_func_call(
                func_call_list=single_turn_ground_truth_list,
                initial_config=initial_config,
                involved_classes=involved_classes,
                model_name=model_name + "_ground_truth",
                test_entry_id=test_entry_id,
                long_context=(
                    "long_context" in test_category or "composite" in test_category
                ),
                is_evaL_run=True,
            )
        )

        all_turn_model_execution_results.extend(single_turn_model_execution_results)
        execution_results.append(
            {
                "model": single_turn_model_execution_results_uncombined,
                "ground_truth": single_turn_ground_truth_execution_results,
            }
        )

        # If the ground truth list is not empty, then the model response list should not be empty
        if len(single_turn_ground_truth_list) > 0:
            if not single_turn_model_response_list or is_empty_execute_response(
                single_turn_model_response_list
            ):
                return {
                    "valid": False,
                    "error_message": f"Model response list is empty for turn {turn_index}",
                    "error_type": "multi_turn:empty_turn_model_response",
                    "details": {
                        "execution_result": execution_results,
                    },
                }

        # If the ground truth list is empty, this is the turn where the model should eventually fail to achieve the user request.
        # The actual check for irrelevance is done in the multi_turn_irrelevance_checker function
        # Note: If the model outputs any function call in this turn, we will still execute it so that the state check at the next turn is accurate.
        if not single_turn_ground_truth_list:
            continue

        ## Check after each turn ##
        assert len(model_instances) == len(
            ground_truth_instances
        ), f"Model instances and ground truth instances do not match in length for turn {turn_index}. Model instances: {len(model_instances)}, Ground truth instances: {len(ground_truth_instances)}"
        assert set(model_instances.keys()) == set(ground_truth_instances.keys())

        # Check the state of the instances
        state_check_result = state_checker(model_instances, ground_truth_instances)
        if not state_check_result["valid"]:
            state_check_result["execution_result"] = execution_results
            return state_check_result

        # Check the response of the function calls
        # We use the all_turn_model_execution_results to accomodate the situation where the model invokes a function in a previous turn, and thus don't need to invoke it again in the current turn.
        response_check_result = response_checker(
            all_turn_model_execution_results,
            single_turn_ground_truth_execution_results,
            turn_index,
        )
        if not response_check_result["valid"]:
            return response_check_result

        # # Check the method invoke order
        # method_invoke_order_check_result = method_invoke_order_checker(
        #     model_instances, ground_truth_instances
        # )
        # if not method_invoke_order_check_result["valid"]:
        #     return method_invoke_order_check_result

    return {"valid": True}
def multi_turn_irrelevance_checker(
    multi_turn_model_result_list_decoded: list[list[list[str]]],
    multi_turn_ground_truth_list: list[list[str]],
) -> dict:
    """
    Check if the model's output are irrelevant when it should be.
    It should be empty when the ground truth is a empty list for that turn.
    """
    for turn_index, ground_truth_for_turn in enumerate(multi_turn_ground_truth_list):
        # Only turns whose ground truth is empty are subject to the irrelevance check.
        if len(ground_truth_for_turn) != 0:
            continue

        model_response_for_turn = multi_turn_model_result_list_decoded[turn_index]
        if not is_empty_execute_response(model_response_for_turn):
            return {
                "valid": False,
                "error_message": f"Model outputs valid function calls when it should not for turn {turn_index}.",
                "error_type": "multi_turn:irrelevance_error:decoder_success",
                "details": {
                    "model response decoded": model_response_for_turn,
                },
            }
    return {"valid": True}
#### Sub-Checkers ####
def state_checker(model_instances: dict, ground_truth_instances: dict):
    """
    Checks if, after executing the function calls, the model_instance has the same state (defined by the attributes) as the ground_truth_instance.
    It checks if every instance in the model_instances has the same attributes as their corresponding instance (of the same class) from ground_truth_instances.
    """

    def _public_state(instance) -> dict:
        # Snapshot of the instance's non-private attributes, for error reporting.
        return {
            name: value
            for name, value in vars(instance).items()
            if not name.startswith("_")
        }

    for class_name, ground_truth_instance in ground_truth_instances.items():
        model_instance = model_instances[class_name]
        valid, differences = _compare_instances(model_instance, ground_truth_instance)
        if valid:
            continue
        # Format the error message for better readability
        return {
            "valid": False,
            "error_message": f"Model instance for {class_name} does not match the state with ground truth instance.",
            "error_type": "multi_turn:instance_state_mismatch",
            "details": {
                "differences": differences,
                "model_instance_state": _public_state(model_instance),
                "ground_truth_instance_state": _public_state(ground_truth_instance),
            },
        }
    return {"valid": True}
def response_checker(
    model_response_list: list, ground_truth_response_list: list, turn_index: int
):
    """
    Checks if the model_response is a subsequence of the ground_truth_response.
    Each list contains the response of the function calls executed in that single turn.
    """
    # Order is deliberately not enforced: many entries contain parallel operations,
    # so the model may execute them in any order.
    contains_all, missing_items = _is_subsequence_unordered(
        ground_truth_response_list, model_response_list
    )
    if contains_all:
        return {"valid": True}

    return {
        "valid": False,
        "error_message": f"Model response execution results so far does not contain all the ground truth response execution results for turn {turn_index}.",
        "error_type": "multi_turn:execution_response_mismatch",
        "details": {
            "missing_items": missing_items,
            "model_response (including all previous turns)": model_response_list,
            "ground_truth_response (only the current turn)": ground_truth_response_list,
        },
    }
def method_invoke_order_checker(model_instances: dict, ground_truth_instances: dict):
    """
    Checks if the model_instance called the same order of methods as the ground_truth_instance.
    model_instance can call additional methods, but not skip any method that the ground_truth_instance called.
    Note: Currently, this functions only checks for the method names and not the arguments.
    """
    for class_name, ground_truth_instance in ground_truth_instances.items():
        model_instance = model_instances[class_name]
        # The get_method_called method is added by the LoggingMeta metaclass automatically;
        # extract just the method names from each recorded call.
        model_method_names = [
            call["method"] for call in model_instance.get_method_called()
        ]
        ground_truth_method_names = [
            call["method"] for call in ground_truth_instance.get_method_called()
        ]

        in_order, missing_items = _is_subsequence(
            ground_truth_method_names, model_method_names
        )
        if not in_order:
            return {
                "valid": False,
                "error_message": f"Model instance for {class_name} does not match the method invoke order with ground truth instance. Missing items: {missing_items}",
                "error_type": "multi_turn:method_invoke_order_mismatch",
            }
    return {"valid": True}
#### Helper functions ####
def _compare_instances(model_obect, ground_truth_object):
"""
Checks if the model_object has the same attributes as the ground_truth_object. They are instances of the same class.
"""
assert type(model_obect) == type(
ground_truth_object
), "Objects are not of the same type."
differences = {}
valid = True
for attr_name in vars(ground_truth_object):
# We don't check for private attributes
if attr_name.startswith("_"):
continue
model_attr = getattr(model_obect, attr_name)
ground_truth_attr = getattr(ground_truth_object, attr_name)
if model_attr != ground_truth_attr:
valid = False
differences[attr_name] = {"model": model_attr, "ground_truth": ground_truth_attr}
return valid, differences
def _is_subsequence(list1, list2) -> tuple[bool, list]:
"""
Checks if list1 is a subsequence of list2, i.e., all elements of list1 are present in list2 in the same order.
Also returns the elements of list1 that are not present in list2.
"""
# Convert list2 to an iterator to ensure that the elements are consumed only once.
iter_list2 = iter(list2)
return all(item in iter_list2 for item in list1), [
item for item in list1 if item not in list2
]
def _is_subsequence_unordered(list1, list2) -> tuple[bool, list]:
"""
Checks if all elements of list1 are present in list2, regardless of order.
Also returns the elements of list1 that are not present in list2.
"""
# Copy list2 to avoid modifying the original list during checks
list2_copy = list2[:]
# Check each item in list1 to see if it exists in list2_copy
missing_elements = []
for item in list1:
try:
# Attempt to remove one occurrence of `item` from list2_copy to handle duplicates
list2_copy.remove(item)
except ValueError:
# If item is not found, add it to missing_elements
missing_elements.append(item)
# If there are missing elements, list1 is not a subsequence of list2
is_subsequence = len(missing_elements) == 0
return is_subsequence, missing_elements