Spaces:
Runtime error
Runtime error
import gradio as gr | |
import json | |
import os | |
import numexpr | |
from groq import Groq | |
from groq.types.chat.chat_completion_tool_param import ChatCompletionToolParam | |
# Model identifier passed to every chat-completion request made via `client`.
MODEL = "llama3-groq-8b-8192-tool-use-preview"

# The API key is read from the environment at import time, so a missing
# GROQ_API_KEY raises KeyError here rather than on the first request.
client = Groq(
    api_key=os.environ["GROQ_API_KEY"],
)
def evaluate_math_expression(expression: str):
    """Evaluate *expression* with numexpr and return the result as JSON text.

    Args:
        expression: The expression string handed to ``numexpr.evaluate``.

    Returns:
        The evaluated result, converted with ``.tolist()`` and serialised
        with ``json.dumps``.
    """
    # NOTE(review): this function is registered as a model-callable tool, so
    # `expression` is model-generated text. Confirm that numexpr's evaluation
    # of such input is acceptable for this deployment.
    evaluated = numexpr.evaluate(expression)
    as_python_value = evaluated.tolist()
    return json.dumps(as_python_value)
# Schema of the single string argument accepted by the calculator tool.
_expression_schema = {
    "type": "string",
    "description": "The mathematical expression to evaluate. Must be valid Python syntax.",
}

# JSON-schema object describing the tool's keyword arguments.
_calculator_parameters = {
    "type": "object",
    "properties": {
        "expression": _expression_schema,
    },
    "required": ["expression"],
}

# Tool definition advertised to the model; the "name" must match the key used
# to look up the Python callable when a tool call comes back.
calculator_tool: ChatCompletionToolParam = {
    "type": "function",
    "function": {
        "name": "evaluate_math_expression",
        "description": "Calculator tool: use this for evaluating numeric expressions with Python. Ensure the expression is valid Python syntax (e.g., use '**' for exponentiation, not '^').",
        "parameters": _calculator_parameters,
    },
}

# Every tool offered to the model on each request.
tools = [calculator_tool]
def call_function(tool_call, available_functions):
    """Run the function requested by *tool_call* and wrap the result as a tool message.

    Args:
        tool_call: Object exposing ``id``, ``function.name`` and
            ``function.arguments`` (a JSON-encoded object of keyword
            arguments).
        available_functions: Mapping of function name to callable.

    Returns:
        A dict with ``tool_call_id``, ``role`` (always ``"tool"``) and
        ``content``; ``name`` is included when the function exists.
        ``content`` is always valid JSON text: the JSON-encoded return value
        on success, a JSON string when the function is unknown, or a JSON
        object ``{"error": "..."}`` when decoding the arguments or running
        the function fails.
    """
    function_name = tool_call.function.name
    if function_name not in available_functions:
        return {
            "tool_call_id": tool_call.id,
            "role": "tool",
            # JSON-encoded so that callers can json.loads() every `content`,
            # exactly as they do for successful calls.
            "content": json.dumps(f"Function {function_name} does not exist."),
        }

    function_to_call = available_functions[function_name]
    try:
        function_args = json.loads(tool_call.function.arguments)
        function_response = function_to_call(**function_args)
        content = json.dumps(function_response)
    except Exception as exc:
        # Tool-dispatch boundary: arguments and tool behaviour are driven by
        # model output, so any failure is reported back as the tool result
        # instead of aborting the whole conversation.
        content = json.dumps({"error": f"{type(exc).__name__}: {exc}"})

    return {
        "tool_call_id": tool_call.id,
        "role": "tool",
        "name": function_name,
        "content": content,
    }
def get_model_response(messages, inner_messages, message, system_message):
    """Build the full message list and request one chat completion.

    The list sent to the model is, in order: the system message, the prior
    history, the current user message, and the messages produced so far in
    the current turn (tool calls and tool results).

    Args:
        messages: Prior chat history as a list of dicts. An entry whose
            ``metadata["native_messages"]`` is set is replaced by those
            native messages; any other entry is forwarded itself.
        inner_messages: Messages accumulated during the current turn.
        message: Text of the current user message.
        system_message: Text of the system prompt.

    Returns:
        The completion object returned by the client, or ``None`` when the
        request raised; the error and the attempted message list are printed.
    """
    # NOTE(review): "metadata" and "options" are assumed to be chat-UI-only
    # keys that the chat-completions API does not accept on a message --
    # confirm against the UI framework's history format.
    ui_only_keys = ("metadata", "options")

    messages_for_model = [
        {
            "role": "system",
            "content": system_message,
        }
    ]
    for msg in messages:
        # "metadata" may be present but None, so `.get("metadata", {})`
        # alone is not enough to guarantee a dict.
        metadata = msg.get("metadata") or {}
        native_messages = metadata.get("native_messages")
        if native_messages is None:
            messages_for_model.append(
                {key: value for key, value in msg.items() if key not in ui_only_keys}
            )
        elif isinstance(native_messages, list):
            messages_for_model.extend(native_messages)
        else:
            messages_for_model.append(native_messages)

    messages_for_model.append(
        {
            "role": "user",
            "content": message,
        }
    )
    messages_for_model.extend(inner_messages)

    try:
        return client.chat.completions.create(
            model=MODEL,
            messages=messages_for_model,
            tools=tools,
            temperature=0.5,
            top_p=0.65,
            max_tokens=4096,
        )
    except Exception as e:
        # Best-effort boundary: report the failure and let the caller decide
        # how to surface it (callers must handle a None return).
        print(f"An error occurred while getting model response: {str(e)}")
        print(messages_for_model)
        return None
def _loads_or_raw(text):
    """Return *text* decoded as JSON, or *text* itself when it is not valid JSON."""
    try:
        return json.loads(text)
    except (TypeError, ValueError):
        return text


def respond(message, history, system_message):
    """Answer one user message, running tool calls until the model replies in text.

    This is a generator: it yields a progressively longer assistant message
    after each batch of tool calls is announced, after each tool result, and
    finally with the model's textual answer appended.

    Args:
        message: Text of the current user message.
        history: Prior chat history, forwarded to ``get_model_response``.
        system_message: Text of the system prompt.

    Yields:
        Dicts with ``role`` ("assistant"), ``content`` (the text accumulated
        so far, with tool calls and results rendered as fenced JSON) and
        ``metadata["native_messages"]`` (the model/tool messages produced
        during this turn, so that later turns can replay them).
    """
    inner_history = []
    available_functions = {
        "evaluate_math_expression": evaluate_math_expression,
    }
    assistant_content = ""
    assistant_native_message_list = []

    while True:
        response = get_model_response(history, inner_history, message, system_message)
        if response is None or not response.choices:
            # get_model_response returns None when the request failed; show
            # an error in the chat instead of raising AttributeError.
            assistant_content += (
                "Sorry, something went wrong while contacting the model. "
                "Please try again."
            )
            yield {
                "role": "assistant",
                "content": assistant_content,
                "metadata": {"native_messages": assistant_native_message_list},
            }
            return

        response_message = response.choices[0].message
        tool_calls = response_message.tool_calls
        if not tool_calls:
            # No tool calls means this is the final answer, even when the
            # content is None; looping again would repeat the same request
            # forever.
            break

        assistant_native_message_list.append(response_message)
        inner_history.append(response_message)
        assistant_content += (
            "```json\n"
            + json.dumps(
                [tool_call.model_dump() for tool_call in tool_calls],
                indent=2,
            )
            + "\n```\n"
        )
        yield {
            "role": "assistant",
            "content": assistant_content,
            "metadata": {"native_messages": assistant_native_message_list},
        }

        for tool_call in tool_calls:
            function_response = call_function(tool_call, available_functions)
            assistant_content += (
                "```json\n"
                + json.dumps(
                    {
                        "name": tool_call.function.name,
                        # Display only: fall back to the raw text when the
                        # arguments or the tool output are not valid JSON.
                        "arguments": _loads_or_raw(tool_call.function.arguments),
                        "response": _loads_or_raw(function_response["content"]),
                    },
                    indent=2,
                )
                + "\n```\n"
            )
            native_tool_message = {
                "tool_call_id": tool_call.id,
                "role": "tool",
                "content": function_response["content"],
            }
            assistant_native_message_list.append(native_tool_message)
            yield {
                "role": "assistant",
                "content": assistant_content,
                "metadata": {"native_messages": assistant_native_message_list},
            }
            inner_history.append(native_tool_message)

    assistant_content += response_message.content or ""
    assistant_native_message_list.append(response_message)
    yield {
        "role": "assistant",
        "content": assistant_content,
        "metadata": {"native_messages": assistant_native_message_list},
    }
# Default system prompt; also supplied as the second column of every example.
system_prompt = "You are a friendly Chatbot with access to a calculator. Don't mention that we are using functions defined in Python."

# Sample questions offered in the UI.
_example_questions = [
    "What is 42 to the power of 2?",
    "If I have 3 apples and multiply them by 7, how many do I have?",
]

# Editable system-prompt field, passed to `respond` as its extra input.
_system_message_box = gr.Textbox(
    value=system_prompt,
    label="System message",
)

demo = gr.ChatInterface(
    respond,
    additional_inputs=[_system_message_box],
    type="messages",
    # Each example row is [user message, system message].
    examples=[[question, system_prompt] for question in _example_questions],
    title="Groq Tool Use Chat",
    description="This chatbot uses the `llama3-groq-8b-8192-tool-use-preview` LLM with tool use capabilities, including a calculator function.",
)

if __name__ == "__main__":
    demo.launch()