|
import json |
|
import logging |
|
from fastapi import APIRouter, Body |
|
from typing import List, Dict |
|
from pydantic import BaseModel |
|
|
|
# Try the flat module names first (presumably for running this file directly
# as a script) and fall back to the ``routers`` package path.
# Only ImportError triggers the fallback, so that KeyboardInterrupt/SystemExit
# and unrelated errors raised while loading a tool module are not hidden.
try:
    from tool_gpu_checker import gpu_checker_get_message
    from tool_bpy_doc import bpy_doc_get_documentation
    from tool_find_related import find_related
    from tool_wiki_search import wiki_search
except ImportError:
    from routers.tool_gpu_checker import gpu_checker_get_message
    from routers.tool_bpy_doc import bpy_doc_get_documentation
    from routers.tool_find_related import find_related
    from routers.tool_wiki_search import wiki_search
|
|
|
|
|
class ToolCallFunction(BaseModel):
    """Function part of a tool call: which function to run and its arguments."""

    # Name that process_tool_call uses to select the tool.
    name: str
    # Arguments as a JSON-encoded string; decoded with json.loads before use.
    arguments: str
|
|
|
|
|
class ToolCallInput(BaseModel):
    """A single tool call as received in the request body."""

    # Identifier echoed back as "tool_call_id" in the tool output.
    id: str
    # Kind of tool call; the sample payloads in this file use "function".
    type: str
    # Function name and JSON-encoded arguments to execute.
    function: ToolCallFunction
|
|
|
|
|
# Router on which the /function_call endpoint of this module is registered.
router = APIRouter()
|
|
|
|
|
def process_tool_call(tool_call: "ToolCallInput") -> Dict:
    """Run one tool call and return its result keyed by the call id.

    The function name selects the tool and the arguments string is decoded
    as JSON and handed to it. Problems with the request itself (malformed
    JSON, arguments that are not a JSON object, missing or unexpected
    arguments, unknown function name) are logged and reported in the
    ``output`` field instead of being raised, so one bad call does not abort
    the other calls of a batch.

    Args:
        tool_call: Tool call carrying ``id``, ``function.name`` and
            ``function.arguments`` (a JSON-encoded string).

    Returns:
        Dict with ``tool_call_id`` and ``output``; ``output`` holds the tool
        result, or an error message when the call could not be executed.
    """
    output = {"tool_call_id": tool_call.id, "output": ""}
    function_name = tool_call.function.name
    raw_arguments = tool_call.function.arguments
    error_message = None

    # The try block deliberately spans the tool invocations as well, so a
    # JSONDecodeError raised inside a tool is still turned into an output
    # message rather than an unhandled exception.
    try:
        function_args = json.loads(raw_arguments)

        if not isinstance(function_args, dict):
            # Every tool reads its arguments by key, so a JSON list, number,
            # string or null cannot be dispatched.
            error_message = (
                f"Invalid arguments for '{function_name}': expected a JSON "
                f"object, got {type(function_args).__name__}"
            )
        elif function_name == "get_bpy_api_info":
            output["output"] = bpy_doc_get_documentation(
                function_args.get("api", ""))
        elif function_name == "check_gpu":
            output["output"] = gpu_checker_get_message(
                function_args.get("gpu", ""))
        elif function_name == "find_related":
            # Both arguments are mandatory; check up front instead of
            # letting a KeyError escape.
            missing = [key for key in ("repo", "number")
                       if key not in function_args]
            if missing:
                error_message = (
                    f"Missing required argument(s) for '{function_name}': "
                    f"{', '.join(missing)}"
                )
            else:
                output["output"] = find_related(
                    function_args["repo"], function_args["number"])
        elif function_name == "wiki_search":
            try:
                output["output"] = wiki_search(**function_args)
            except TypeError as e:
                # Keyword expansion of caller-supplied keys fails with
                # TypeError on unexpected or missing parameters.
                error_message = f"Invalid arguments for '{function_name}': {e}"
        else:
            error_message = f"Unknown function '{function_name}'"
    except json.JSONDecodeError as e:
        error_message = f"Malformed JSON encountered at position {e.pos}: {e.msg}\n {raw_arguments}"
        output["output"] = error_message
        logging.error("JSONDecodeError in process_tool_call: %s", error_message)
        return output

    if error_message is not None:
        output["output"] = error_message
        logging.error("Invalid tool call in process_tool_call: %s", error_message)

    return output
|
|
|
|
|
@router.post("/function_call", response_model=List[Dict])
def function_call(tool_calls: List[ToolCallInput] = Body(..., description="List of tool calls in the request body")):
    """
    Handle a batch of tool calls sent in the request body.

    Each entry is passed to process_tool_call, and the results are returned
    in the same order as the incoming calls.

    Args:
        tool_calls (List[ToolCallInput]): Tool calls to execute.

    Returns:
        List[Dict]: One dict per tool call, as produced by process_tool_call
            (tool_call_id and output).
    """
    return list(map(process_tool_call, tool_calls))
|
|
|
|
|
if __name__ == "__main__":
    # Manual smoke test: one sample call per supported tool, run through the
    # endpoint function directly (no HTTP request involved).
    sample_payloads = [
        {
            "id": "call_abc123",
            "type": "function",
            "function": {
                "name": "get_bpy_api_info",
                "arguments": '{"api":"bpy.context.scene.world"}',
            },
        },
        {
            "id": "call_abc456",
            "type": "function",
            "function": {
                "name": "check_gpu",
                "arguments": '{"gpu":"Mesa Intel(R) Iris(R) Plus Graphics 640 (Kaby Lake GT3e) (KBL GT3) Intel 4.6 (Core Profile) Mesa 22.2.5"}',
            },
        },
        {
            "id": "call_abc789",
            "type": "function",
            "function": {
                "name": "find_related",
                "arguments": '{"repo":"blender","number":111434}',
            },
        },
        {
            "id": "call_abc101112",
            "type": "function",
            "function": {
                "name": "wiki_search",
                "arguments": '{"query":"Set Snap Base","groups":["manual"]}',
            },
        },
    ]

    # Build the request models from the raw payload dicts.
    sample_calls = []
    for payload in sample_payloads:
        function_part = ToolCallFunction(**payload["function"])
        sample_calls.append(
            ToolCallInput(
                id=payload["id"],
                type=payload["type"],
                function=function_part,
            )
        )

    results = function_call(sample_calls)
    print(results)
|
|