Spaces:
Sleeping
Sleeping
import json | |
import os | |
from anthropic import Anthropic | |
from anthropic.types import TextBlock, ToolUseBlock | |
from base_handler import BaseHandler | |
from constant import GORILLA_TO_OPENAPI | |
from model_style import ModelStyle | |
from utils import ( | |
ast_parse, | |
combine_consecutive_user_prompts, | |
convert_system_prompt_into_user_prompt, | |
convert_to_function_call, | |
convert_to_tool, | |
extract_system_prompt, | |
format_execution_results_prompting, | |
func_doc_language_specific_pre_processing, | |
system_prompt_pre_processing_chat_model, | |
) | |
class ClaudeHandler(BaseHandler):
    """BFCL handler for Anthropic Claude models.

    Model names containing ``"FC"`` are driven through native tool use (the
    ``*_FC`` methods). Any other name is treated as a prompting model whose raw
    text output is parsed with ``ast_parse``.
    """

    def __init__(self, model_name, temperature) -> None:
        super().__init__(model_name, temperature)
        self.model_style = ModelStyle.Anthropic
        # The API key is taken from the ANTHROPIC_API_KEY environment variable.
        self.client = Anthropic(api_key=os.getenv("ANTHROPIC_API_KEY"))

    @staticmethod
    def _bracket_prompting_output(result: str) -> str:
        """Drop one leading space from *result* and ensure it is wrapped in ``[...]``.

        ``removeprefix`` removes exactly one space, as the old ``func[0]``
        check did, but it does not raise ``IndexError`` on an empty string.
        """
        func = result.removeprefix(" ")
        if not func.startswith("["):
            func = "[" + func
        if not func.endswith("]"):
            func = func + "]"
        return func

    def decode_ast(self, result, language="Python"):
        """Decode a model response into a list of ``{function_name: args}`` dicts.

        Args:
            result: For prompting models, the raw text output, which is
                bracket-normalised and handed to ``ast_parse``. For FC models,
                the list of ``{tool_name: json_args_string}`` dicts built by
                ``_parse_query_response_FC``.
            language: Forwarded to ``ast_parse`` (prompting models only).

        Returns:
            A list of ``{function_name: args}`` dicts. On the FC path, args is
            the JSON-decoded argument dict.
        """
        if "FC" not in self.model_name:
            return ast_parse(self._bracket_prompting_output(result), language)

        decoded_output = []
        for invoked_function in result:
            # Each entry holds one call; its first (and only) key is the tool name.
            name = next(iter(invoked_function))
            decoded_output.append({name: json.loads(invoked_function[name])})
        return decoded_output

    def decode_execute(self, result):
        """Decode a model response into executable call strings like ``f(a=1,b='x')``.

        Prompting models' text is bracket-normalised, parsed with ``ast_parse``,
        and re-rendered with ``repr`` for each argument value. The FC path
        delegates to ``convert_to_function_call``.
        """
        if "FC" not in self.model_name:
            decode_output = ast_parse(self._bracket_prompting_output(result))
            return [
                f"{name}({','.join(f'{k}={v!r}' for k, v in args.items())})"
                for function_call in decode_output
                for name, args in function_call.items()
            ]
        return convert_to_function_call(result)

    #### FC methods ####

    def _query_FC(self, inference_data: dict):
        """Send the accumulated conversation and tools to the Messages API.

        Records the request in ``inference_data["inference_input_log"]``. When
        ``inference_data["caching_enabled"]`` is set, the ephemeral
        ``cache_control`` markers are moved onto the two most recent user
        messages before sending. Returns the raw API response.
        """
        inference_data["inference_input_log"] = {
            "message": repr(inference_data["message"]),
            "tools": inference_data["tools"],
        }
        messages = inference_data["message"]

        if inference_data["caching_enabled"]:
            # Keep cache breakpoints on the last two user messages only, and
            # clear stale ones from older turns. Presumably this keeps the request
            # within Anthropic's per-request breakpoint limit, since _compile_tools
            # may already place up to two on the tools.
            user_messages_seen = 0
            for message in reversed(messages):
                # Skip non-user turns, and user turns with no content block to tag
                # (indexing [0] on an empty list would raise IndexError).
                if message["role"] != "user" or not message["content"]:
                    continue
                first_block = message["content"][0]
                if user_messages_seen < 2:
                    first_block["cache_control"] = {"type": "ephemeral"}
                else:
                    first_block.pop("cache_control", None)
                user_messages_seen += 1

        # NOTE(review): the temperature passed to __init__ is not forwarded to
        # this call, so the API default applies. Also, `beta.prompt_caching` only
        # exists in some anthropic SDK releases. Confirm both against the pinned SDK.
        return self.client.beta.prompt_caching.messages.create(
            # removesuffix drops exactly the "-FC" marker. str.strip("-FC") would
            # strip any leading or trailing '-', 'F' or 'C' characters.
            model=self.model_name.removesuffix("-FC"),
            # claude-3-5 models get the larger 8192-token output limit.
            max_tokens=8192 if "claude-3-5" in self.model_name else 4096,
            tools=inference_data["tools"],
            messages=messages,
        )

    def _pre_query_processing_FC(self, inference_data: dict, test_entry: dict) -> dict:
        """Normalise the test entry's rounds and initialise per-entry state.

        For each round in ``test_entry["question"]``, the system prompt is
        folded into a user prompt and consecutive user prompts are merged. The
        list is updated in place. The method also resets the message history
        and decides whether prompt caching is used.
        """
        test_entry["question"][:] = [
            combine_consecutive_user_prompts(
                convert_system_prompt_into_user_prompt(round_messages)
            )
            for round_messages in test_entry["question"]
        ]
        inference_data["message"] = []

        test_category: str = test_entry["id"].rsplit("_", 1)[0]
        # Prompt caching is used only for the multi_turn categories. It is never
        # used for claude-3-sonnet, which is excluded here explicitly.
        inference_data["caching_enabled"] = (
            test_category.startswith("multi_turn")
            and "claude-3-sonnet" not in self.model_name
        )
        return inference_data

    def _compile_tools(self, inference_data: dict, test_entry: dict) -> dict:
        """Convert the entry's function docs into Anthropic tools and store them.

        When caching is enabled, an ephemeral cache breakpoint goes on the last
        tool. If tools were compiled before (more tools added to an existing
        set, e.g. the miss_func category), a breakpoint also goes on the last
        previously existing tool, so the earlier prefix can still hit the cache.
        """
        functions: list = test_entry["function"]
        test_category: str = test_entry["id"].rsplit("_", 1)[0]

        functions = func_doc_language_specific_pre_processing(functions, test_category)
        tools = convert_to_tool(functions, GORILLA_TO_OPENAPI, self.model_style)

        # An empty tool list has nothing to tag (tools[-1] would raise IndexError).
        if inference_data["caching_enabled"] and tools:
            if "tools" in inference_data:
                existing_tool_len = len(inference_data["tools"])
                # Only tag the old boundary when it falls inside the new list.
                if 0 < existing_tool_len <= len(tools):
                    tools[existing_tool_len - 1]["cache_control"] = {"type": "ephemeral"}
            tools[-1]["cache_control"] = {"type": "ephemeral"}

        inference_data["tools"] = tools
        return inference_data

    def _parse_query_response_FC(self, api_response) -> dict:
        """Split a Messages API response (as returned by ``_query_FC``) into BFCL fields.

        ``model_responses`` is the list of ``{tool_name: json_args}`` dicts if
        any tool-use blocks are present, otherwise the list of text strings.
        The raw content blocks are kept for the chat history. The tool-use IDs
        are kept so ``_add_execution_results_FC`` can pair each result with its
        call.
        """
        text_outputs = []
        tool_call_outputs = []
        tool_call_ids = []

        for content in api_response.content:
            if isinstance(content, TextBlock):
                text_outputs.append(content.text)
            elif isinstance(content, ToolUseBlock):
                tool_call_outputs.append({content.name: json.dumps(content.input)})
                tool_call_ids.append(content.id)

        return {
            "model_responses": tool_call_outputs or text_outputs,
            "model_responses_message_for_chat_history": api_response.content,
            "tool_call_ids": tool_call_ids,
            "input_token": api_response.usage.input_tokens,
            "output_token": api_response.usage.output_tokens,
        }

    @staticmethod
    def _to_text_blocks(messages: list[dict]) -> None:
        """Replace each message's plain ``content`` with a one-element text-block list, in place.

        Content must be a list of blocks so that ``_query_FC`` can attach
        ``cache_control`` to ``content[0]``.
        """
        for message in messages:
            message["content"] = [{"type": "text", "text": message["content"]}]

    def add_first_turn_message_FC(
        self, inference_data: dict, first_turn_message: list[dict]
    ) -> dict:
        """Wrap the first turn's messages in text blocks and append them to the history."""
        self._to_text_blocks(first_turn_message)
        inference_data["message"].extend(first_turn_message)
        return inference_data

    def _add_next_turn_user_message_FC(
        self, inference_data: dict, user_message: list[dict]
    ) -> dict:
        """Wrap a later turn's user messages in text blocks and append them to the history."""
        self._to_text_blocks(user_message)
        inference_data["message"].extend(user_message)
        return inference_data

    def _add_assistant_message_FC(
        self, inference_data: dict, model_response_data: dict
    ) -> dict:
        """Append the assistant's raw content blocks to the message history."""
        inference_data["message"].append(
            {
                "role": "assistant",
                "content": model_response_data["model_responses_message_for_chat_history"],
            }
        )
        return inference_data

    def _add_execution_results_FC(
        self,
        inference_data: dict,
        execution_results: list[str],
        model_response_data: dict,
    ) -> dict:
        """Append tool execution results to the history as ``tool_result`` blocks.

        Claude has no tool role, so the results are sent back in a single user
        message. Each block is tied to its call through ``tool_use_id``.
        ``zip`` stops at the shorter of the results and IDs, so any surplus
        entries are dropped.
        """
        tool_message = {
            "role": "user",
            "content": [
                {
                    "type": "tool_result",
                    "content": execution_result,
                    "tool_use_id": tool_call_id,
                }
                for execution_result, tool_call_id in zip(
                    execution_results, model_response_data["tool_call_ids"]
                )
            ],
        }
        inference_data["message"].append(tool_message)
        return inference_data