# Spaces:
# Sleeping
# Sleeping
import json
import time
from copy import deepcopy
from typing import Any

from overrides import final

from constant import (
    DEFAULT_USER_PROMPT_FOR_ADDITIONAL_FUNCTION_FC,
    DEFAULT_USER_PROMPT_FOR_ADDITIONAL_FUNCTION_PROMPTING,
    MAXIMUM_STEP_LIMIT,
)
from model_style import ModelStyle
from multi_turn_eval.multi_turn_utils import (
    STATELESS_CLASSES,
    execute_multi_turn_func_call,
    is_empty_execute_response,
)
class BaseHandler:
    """Base class for model handlers that run multi-turn function-calling inference.

    The multi-turn driver (:meth:`inference_multi_turn_FC`) is implemented here.
    Everything model-specific -- querying the API, parsing the response,
    maintaining the chat history, decoding tool calls -- is delegated to the
    ``*_FC`` / ``*_prompting`` hooks and the ``decode_*`` methods, all of which
    raise ``NotImplementedError`` in this class and must be overridden.
    """

    model_name: str
    model_style: ModelStyle

    def __init__(self, model_name, temperature) -> None:
        """Store the model name, a sanitized variant of it, and the temperature.

        Args:
            model_name: Name of the model; kept verbatim in ``self.model_name``.
            temperature: Sampling temperature; stored for use by subclasses.
        """
        self.model_name = model_name
        # Replace the slash with underscore to avoid creating subdirectories
        # Replace the dash and dot with underscore for valid variable name
        self.model_name_underline_replaced = (
            model_name.replace("/", "_").replace("-", "_").replace(".", "_")
        )
        self.temperature = temperature
        self.is_fc_model = False  # Whether the model is a function calling model

    def inference(
        self,
        test_entry: dict,
        include_input_log: bool = False,
        include_state_log: bool = False,
    ):
        """Retrieve the model responses for one test entry.

        Thin wrapper that forwards to :meth:`inference_multi_turn_FC` and
        returns the generator it produces (see that method for what is yielded).
        """
        return self.inference_multi_turn_FC(
            test_entry, include_input_log, include_state_log
        )

    @staticmethod
    def _build_state_log(involved_instances: dict) -> list[dict]:
        """Snapshot the public attributes of every stateful involved instance.

        Classes listed in ``STATELESS_CLASSES`` are skipped. Each instance is
        deep-copied before its attributes are read so that the logged values
        are not affected by modifications made in later turns. Attributes
        whose name starts with an underscore are left out.
        """
        state_log = []
        for class_name, class_instance in involved_instances.items():
            if class_name in STATELESS_CLASSES:
                continue
            snapshot = deepcopy(class_instance)  # Avoid modification in future turns
            state_log.append(
                {
                    "role": "state_info",
                    "class_name": class_name,
                    "content": {
                        key: value
                        for key, value in vars(snapshot).items()
                        if not key.startswith("_")
                    },
                }
            )
        return state_log

    def inference_multi_turn_FC(
        self, test_entry: dict, include_input_log: bool, include_state_log: bool
    ):
        """Drive a multi-turn function-calling conversation, yielding progress events.

        For every turn in ``test_entry["question"]`` the model is queried
        repeatedly; each decoded response is executed and the execution results
        are fed back into the chat history. A turn ends when the model returns
        an empty response, when its response cannot be decoded, or when the
        step count exceeds ``MAXIMUM_STEP_LIMIT`` (which also ends the whole
        entry).

        Args:
            test_entry: Test case; the keys ``initial_config``,
                ``involved_classes``, ``id`` and ``question`` are read, plus the
                optional ``missed_function`` (turn index as string -> functions
                to add at that turn). When a holdout turn is reached,
                ``test_entry["function"]`` is extended in place.
            include_input_log: When true, log ``inference_data["inference_input_log"]``
                for every step (verbose).
            include_state_log: When true, log a snapshot of the involved
                instances before the first turn and after every turn.

        Yields:
            ``("regular", decoded_model_responses, execution_results, model_name)``
                after each executed step; occurrences of ``"error"`` in string
                results are marked for display.
            ``("summary", model_responses, None, model_name)``
                when the model response could not be decoded.
            ``("final", current_turn_response, inference_data, involved_instances)``
                once, at the very end. ``current_turn_response`` holds the
                responses of the last turn processed (empty if there was no
                turn); ``involved_instances`` is ``None`` when the last turn
                executed nothing.
        """
        initial_config: dict = test_entry["initial_config"]
        involved_classes: list = test_entry["involved_classes"]
        test_entry_id: str = test_entry["id"]
        test_category: str = test_entry_id.rsplit("_", 1)[0]
        long_context = "long_context" in test_category or "composite" in test_category

        # This is only for the miss function category
        # A mapping from turn index (as a string) to function to holdout
        holdout_function: dict[str, list] = test_entry.get("missed_function", {})

        total_input_token_count: list[list[float]] = []
        total_output_token_count: list[list[float]] = []
        total_latency: list[list[float]] = []
        # The model response that will be used for later evaluation
        all_model_response: list[list] = []
        # The debugging log for human to understand. Holds per-turn dicts and,
        # when include_state_log is set, lists of state snapshots in between.
        all_inference_log: list = []
        # Whether the model has been forced to quit. If True, this whole entry will be failed.
        force_quit = False

        # Bound up front so the final yield is well-defined even when the
        # entry contains no turns at all.
        current_turn_response: list = []
        involved_instances = None
        # Most recent instances returned by an execution. Used for state
        # logging, because `involved_instances` is reset at every turn and
        # stays None when a turn ends before anything is executed.
        latest_instances = None

        # Execute no function call, but just to get a reference to all the
        # instances to get the initial state for logging purpose
        if include_state_log:
            _, involved_instances = execute_multi_turn_func_call(
                [],
                initial_config,
                involved_classes,
                self.model_name_underline_replaced,
                test_entry_id,
                long_context=long_context,
                is_evaL_run=False,
            )
            latest_instances = involved_instances
            all_inference_log.append(self._build_state_log(latest_instances))

        inference_data: dict = {}
        inference_data = self._pre_query_processing_FC(inference_data, test_entry)
        inference_data = self._compile_tools(inference_data, test_entry)

        all_multi_turn_messages: list[list[dict]] = test_entry["question"]
        for turn_idx, current_turn_message in enumerate(all_multi_turn_messages):
            current_turn_message: list[dict]

            if str(turn_idx) in holdout_function:
                test_entry["function"].extend(holdout_function[str(turn_idx)])
                # Since we have added new functions, we need to recompile the tools
                inference_data = self._compile_tools(inference_data, test_entry)
                assert (
                    len(current_turn_message) == 0
                ), "Holdout turn should not have user message."
                current_turn_message = [
                    {
                        "role": "user",
                        "content": DEFAULT_USER_PROMPT_FOR_ADDITIONAL_FUNCTION_FC,
                    }
                ]

            if turn_idx == 0:
                # NOTE(review): the first turn is wrapped in an extra list while
                # later turns are passed as-is -- confirm subclasses rely on this.
                inference_data = self.add_first_turn_message_FC(
                    inference_data, [current_turn_message]
                )
            else:
                assert isinstance(
                    current_turn_message, list
                ), "Current turn message is not a list"
                inference_data = self._add_next_turn_user_message_FC(
                    inference_data, current_turn_message
                )

            current_turn_response = []
            current_turn_inference_log: dict[str, list] = {
                "begin_of_turn_query": current_turn_message
            }
            current_turn_input_token_count: list[float] = []
            current_turn_output_token_count: list[float] = []
            current_turn_latency: list[float] = []
            involved_instances = None

            count = 0
            while True:
                print("-" * 100)
                print(
                    f"ID: {test_entry_id.replace('multi_turn_', '')}, Turn: {turn_idx}, Step: {count}"
                )
                current_step_inference_log: list[dict] = []
                # Add to the current_turn_inference_log at beginning of each step
                # so that we don't need to bother dealing with the break statements
                current_turn_inference_log[f"step_{count}"] = current_step_inference_log

                start_time = time.time()
                api_response = self._query_FC(inference_data)
                query_latency = time.time() - start_time

                # This part of logging is disabled by default because it is too
                # verbose and will make the result file extremely large.
                # It is only useful to see if the inference pipeline is working
                # as expected (eg, does it convert all the inputs correctly)
                if include_input_log:
                    current_step_inference_log.append(
                        {
                            "role": "handler_log",
                            "content": inference_data.get("inference_input_log", ""),
                        }
                    )

                # Try parsing the model response
                model_response_data = self._parse_query_response_FC(api_response)
                model_responses = model_response_data["model_responses"]

                # Add the assistant message to the chat history
                inference_data = self._add_assistant_message_FC(
                    inference_data, model_response_data
                )

                # Process the metadata
                current_turn_input_token_count.append(model_response_data["input_token"])
                current_turn_output_token_count.append(model_response_data["output_token"])
                current_turn_latency.append(query_latency)

                current_turn_response.append(model_responses)
                current_step_inference_log.append(
                    {"role": "assistant", "content": model_responses}
                )

                # Try decoding the model response
                try:
                    decoded_model_responses = self.decode_execute(model_responses)
                    current_step_inference_log.append(
                        {
                            "role": "handler_log",
                            "content": "Successfully decoded model response.",
                            "model_response_decoded": decoded_model_responses,
                        }
                    )

                    if is_empty_execute_response(decoded_model_responses):
                        print("Empty response from the model. Proceed to next turn.")
                        current_step_inference_log.append(
                            {
                                "role": "handler_log",
                                "content": "Empty response from the model. Proceed to next turn.",
                                "model_response_decoded": decoded_model_responses,
                            }
                        )
                        break

                except Exception as e:
                    # Deliberately broad: any failure while decoding ends the
                    # turn and is reported to the consumer as a summary event.
                    print("Failed to decode the model response. Proceed to next turn.")
                    current_step_inference_log.append(
                        {
                            "role": "handler_log",
                            "content": "Error decoding the model response. Proceed to next turn.",
                            "error": str(e),
                        }
                    )
                    yield ("summary", model_responses, None, self.model_name)
                    break

                # Obtain the execution results
                execution_results, involved_instances = execute_multi_turn_func_call(
                    decoded_model_responses,
                    initial_config,
                    involved_classes,
                    self.model_name_underline_replaced,
                    test_entry_id,
                    long_context=long_context,
                    is_evaL_run=False,
                )
                latest_instances = involved_instances

                # Add the execution results to the chat history for the next turn
                inference_data = self._add_execution_results_FC(
                    inference_data, execution_results, model_response_data
                )

                for execution_result in execution_results:
                    current_step_inference_log.append(
                        {
                            "role": "tool",
                            "content": execution_result,
                        }
                    )

                # Mark errors on a copy, so the results already stored in the
                # log and handed to the chat history stay untouched.
                display_results = deepcopy(execution_results)
                for idx, result in enumerate(display_results):
                    # Only string results can be annotated; anything else is
                    # passed through unchanged instead of failing on .replace.
                    if isinstance(result, str) and "error" in result:
                        display_results[idx] = result.replace("error", "error❗️")

                yield ("regular", decoded_model_responses, display_results, self.model_name)

                count += 1
                # Force quit after too many steps
                if count > MAXIMUM_STEP_LIMIT:
                    force_quit = True
                    current_step_inference_log.append(
                        {
                            "role": "handler_log",
                            "content": f"Model has been forced to quit after {MAXIMUM_STEP_LIMIT} steps.",
                        }
                    )
                    break

            # Add to the total list
            all_model_response.append(current_turn_response)
            all_inference_log.append(current_turn_inference_log)
            total_input_token_count.append(current_turn_input_token_count)
            total_output_token_count.append(current_turn_output_token_count)
            total_latency.append(current_turn_latency)

            if include_state_log:
                # Snapshot from the most recent instances: a turn that executed
                # nothing leaves `involved_instances` as None, and nothing has
                # run since those instances were returned.
                all_inference_log.append(self._build_state_log(latest_instances))

            if force_quit:
                break

        # NOTE(review): this metadata is assembled but not yielded or stored
        # anywhere in this method -- confirm whether it should be exposed.
        metadata = {
            "input_token_count": total_input_token_count,
            "output_token_count": total_output_token_count,
            "latency": total_latency,
            "inference_log": all_inference_log,
        }

        yield ("final", current_turn_response, inference_data, involved_instances)

    def decode_ast(self, result, language="Python"):
        """Convert raw model output to standard AST checker input.

        Must be overridden by subclasses.
        """
        raise NotImplementedError

    def decode_execute(self, result):
        """Convert raw model output to standard execute checker input.

        Must be overridden by subclasses.
        """
        raise NotImplementedError

    #### FC methods ####

    def _query_FC(self, inference_data: dict):
        """
        Call the model API in FC mode to get the response.
        Return the response object that can be used to feed into the decode method.
        """
        raise NotImplementedError

    def _pre_query_processing_FC(self, inference_data: dict, test_entry: dict) -> dict:
        """
        Preprocess the testset entry before sending it to the model.
        This includes transforming the input user message into the format expected by the model, and any other necessary preprocessing steps.
        The inference_data dict is updated in place and returned.
        """
        raise NotImplementedError

    def _compile_tools(self, inference_data: dict, test_entry: dict) -> dict:
        """
        Compile the tools from the test entry and add them to the inference data.
        This method is used to prepare the tools for the model query in FC mode.
        The inference_data dict is updated in place and returned.
        """
        raise NotImplementedError

    def _parse_query_response_FC(self, api_response: Any) -> dict:
        """
        Parses the raw response from the model API to extract the result, input token count, and output token count.

        Args:
            api_response (Any): The raw response from the model API.

        Returns:
            A dict containing the following elements:
                - model_responses (Any): The parsed result that can be directly used as input to the decode method.
                - input_token (int): The number of tokens used in the input to the model.
                - output_token (int): The number of tokens generated by the model as output.
                - tool_call_ids (list[str]): The IDs of the tool calls that are generated by the model. Optional.
                - Any other metadata that is specific to the model.
        """
        raise NotImplementedError

    def add_first_turn_message_FC(
        self, inference_data: dict, first_turn_message: list[dict]
    ) -> dict:
        """
        Add the first turn message to the chat history.
        """
        raise NotImplementedError

    def _add_next_turn_user_message_FC(
        self, inference_data: dict, user_message: list[dict]
    ) -> dict:
        """
        [Only for multi-turn]
        Add next turn user message to the chat history for query.
        user_message is a list of 1 element, which is the user message.
        """
        raise NotImplementedError

    def _add_assistant_message_FC(
        self, inference_data: dict, model_response_data: dict
    ) -> dict:
        """
        Add assistant message to the chat history.
        """
        raise NotImplementedError

    def _add_execution_results_FC(
        self, inference_data: dict, execution_results: list[str], model_response_data: dict
    ) -> dict:
        """
        Add the execution results to the chat history to prepare for the next turn of query.
        Some models may need to add additional information to the chat history, such as tool call IDs.
        """
        raise NotImplementedError

    #### Prompting methods ####

    def _query_prompting(self, inference_data: dict):
        """
        Call the model API in prompting mode to get the response.
        Return the response object that can be used to feed into the decode method.
        """
        raise NotImplementedError

    def _pre_query_processing_prompting(self, test_entry: dict) -> dict:
        """
        Preprocess the testset entry before sending it to the model.
        Returns a dict that contains all the necessary information for the query method.
        `tools` and `message` must be included in the returned dict.
        Things like `system_prompt` and `chat_history` are optional, specific to the model.
        """
        raise NotImplementedError

    def _parse_query_response_prompting(self, api_response: Any) -> dict:
        """
        Parses the raw response from the model API to extract the result, input token count, and output token count.

        Args:
            api_response (Any): The raw response from the model API.

        Returns:
            A dict containing the following elements:
                - model_responses (Any): The parsed result that can be directly used as input to the decode method.
                - input_token (int): The number of tokens used in the input to the model.
                - output_token (int): The number of tokens generated by the model as output.
                - tool_call_ids (list[str]): The IDs of the tool calls that are generated by the model. Optional.
                - Any other metadata that is specific to the model.
        """
        raise NotImplementedError

    def add_first_turn_message_prompting(
        self, inference_data: dict, first_turn_message: list[dict]
    ) -> dict:
        """
        Add the first turn message to the chat history.
        """
        raise NotImplementedError

    def _add_next_turn_user_message_prompting(
        self, inference_data: dict, user_message: list[dict]
    ) -> dict:
        """
        [Only for multi-turn]
        Add next turn user message to the chat history for query.
        user_message is a list of 1 element, which is the user message.
        """
        raise NotImplementedError

    def _add_assistant_message_prompting(
        self, inference_data: dict, model_response_data: dict
    ) -> dict:
        """
        Add assistant message to the chat history.
        """
        raise NotImplementedError

    def _add_execution_results_prompting(
        self, inference_data: dict, execution_results: list[str], model_response_data: dict
    ) -> dict:
        """
        Add the execution results to the chat history to prepare for the next turn of query.
        Some models may need to add additional information to the chat history, such as tool call IDs.
        """
        raise NotImplementedError