# Scraped file-viewer header (not part of the program):
#   author: ricklamers
#   commit 60ed75d - "fix: hacky sessions."
#   file size: 4.75 kB
import gradio as gr
import json
import os
import numexpr
import uuid
from groq import Groq
from groq.types.chat.chat_completion_tool_param import ChatCompletionToolParam
# Model identifier passed to every chat completion request (see get_model_response).
MODEL = "llama3-groq-8b-8192-tool-use-preview"
# Shared Groq API client. The os.environ[...] lookup raises KeyError at import
# time when GROQ_API_KEY is not set.
client = Groq(api_key=os.environ["GROQ_API_KEY"])
def evaluate_math_expression(expression: str) -> str:
    """Evaluate ``expression`` with numexpr and return the result as JSON text.

    Args:
        expression: The expression string handed to ``numexpr.evaluate``.

    Returns:
        ``json.dumps`` of the evaluated result after ``.tolist()`` conversion.
    """
    # NOTE(review): ``expression`` is produced by the model from user input.
    # Confirm the installed numexpr version sanitizes untrusted expressions.
    computed = numexpr.evaluate(expression)
    as_plain_python = computed.tolist()
    return json.dumps(as_plain_python)
# Schema of the single ``expression`` argument accepted by the calculator tool.
_expression_property = {
    "type": "string",
    "description": "The mathematical expression to evaluate. Must be valid Python syntax.",
}

# Parameter object for the tool: one required string field.
_calculator_parameters = {
    "type": "object",
    "properties": {
        "expression": _expression_property,
    },
    "required": ["expression"],
}

# Tool definition advertised to the model; the name must match the key used
# in the ``available_functions`` mapping built in ``respond``.
calculator_tool: ChatCompletionToolParam = {
    "type": "function",
    "function": {
        "name": "evaluate_math_expression",
        "description":
            "Calculator tool: use this for evaluating numeric expressions with Python. Ensure the expression is valid Python syntax (e.g., use '**' for exponentiation, not '^').",
        "parameters": _calculator_parameters,
    },
}

# Every tool offered to the model on each completion request.
tools = [calculator_tool]
def call_function(tool_call, available_functions):
    """Run one model-requested tool call and build the ``tool`` reply message.

    Every outcome (success, unknown function, bad arguments, tool failure)
    yields a message of the same shape, so callers can always read ``name``
    and JSON-decode ``content``.

    Args:
        tool_call: Object exposing ``id``, ``function.name`` and
            ``function.arguments`` (a JSON-encoded object of keyword arguments).
        available_functions: Mapping of function name to callable.

    Returns:
        A dict with keys ``tool_call_id``, ``role`` (always ``"tool"``),
        ``name`` and ``content`` (JSON text). On failure ``content`` is a
        JSON-encoded error string instead of the function's result.
    """
    function_name = tool_call.function.name

    def tool_message(payload):
        # Single place that defines the reply shape for all outcomes.
        return {
            "tool_call_id": tool_call.id,
            "role": "tool",
            "name": function_name,
            "content": json.dumps(payload),
        }

    if function_name not in available_functions:
        return tool_message(f"Function {function_name} does not exist.")

    function_to_call = available_functions[function_name]
    try:
        function_args = json.loads(tool_call.function.arguments)
        function_response = function_to_call(**function_args)
        return tool_message(function_response)
    except Exception as exc:
        # Tool boundary: arguments come from the model and may be malformed,
        # and the tool itself may reject them. Report the failure back to the
        # model instead of aborting the whole chat turn.
        return tool_message(f"Error while calling {function_name}: {exc}")
def get_model_response(messages):
    """Request a chat completion for ``messages`` with the calculator tool enabled.

    Args:
        messages: The conversation so far, passed through as ``messages``.

    Returns:
        The completion object returned by the client, or ``None`` if any
        exception was raised (the error and the messages are printed).
    """
    try:
        request_options = {
            "model": MODEL,
            "messages": messages,
            "tools": tools,
            "temperature": 0.5,
            "top_p": 0.65,
            "max_tokens": 4096,
        }
        completion = client.chat.completions.create(**request_options)
    except Exception as exc:
        # Best-effort: log the failure and signal it to the caller with None.
        print(f"An error occurred while getting model response: {exc!s}")
        print(messages)
        completion = None
    return completion
# In-memory chat transcripts keyed by session id. Module-level, so it is shared
# by every caller in this process, and entries are never removed.
conversation_state = {}


def _decode_json_or_raw(text):
    """Return ``text`` parsed as JSON, or ``text`` unchanged if it is not valid JSON."""
    try:
        return json.loads(text)
    except (TypeError, ValueError):
        return text


def respond(message, history, system_message):
    """Answer one user message, running any tool calls the model requests.

    Args:
        message: The new user message.
        history: Chat history as a list of pairs; ``history[0][0]`` is used as
            the session id. May be mutated (see the HACK note below).
        system_message: System prompt, stored only when the session's
            transcript is still empty.

    Returns:
        The assistant's final text followed by a markdown rendering of every
        tool call made during this turn. If the model request fails or no
        final answer arrives within the round limit, a short explanatory
        message is returned instead of raising.
    """
    # HACK: session tracking piggybacks on the history list. When there is no
    # usable first entry, a fresh id is inserted as a fake first exchange;
    # otherwise whatever string sits at history[0][0] is treated as the id.
    # NOTE(review): whether the caller preserves this mutation between turns
    # is not visible here - confirm, otherwise the first user message becomes
    # the session key.
    if not history or not isinstance(history[0][0], str):
        session_id = str(uuid.uuid4())
        history.insert(0, (session_id, "Confirmed."))
    else:
        session_id = history[0][0]

    messages = conversation_state.setdefault(session_id, [])
    if not messages:
        messages.append({"role": "system", "content": system_message})
    messages.append({"role": "user", "content": message})

    available_functions = {
        "evaluate_math_expression": evaluate_math_expression,
    }

    # Upper bound on model round-trips per user message, so a model that keeps
    # requesting tools (or keeps returning empty replies) cannot loop forever.
    max_rounds = 10
    tool_records = []

    for _ in range(max_rounds):
        response = get_model_response(messages)
        if response is None:
            # get_model_response has already printed the underlying error.
            final_text = "Sorry, something went wrong while contacting the model. Please try again."
            break

        response_message = response.choices[0].message
        messages.append(response_message)

        if not response_message.tool_calls and response_message.content is not None:
            final_text = response_message.content
            break

        for tool_call in response_message.tool_calls or []:
            function_response = call_function(tool_call, available_functions)
            messages.append(function_response)
            # Arguments and tool content are expected to be JSON text; fall
            # back to the raw string so a malformed payload is still shown.
            tool_records.append({
                "name": tool_call.function.name,
                "arguments": _decode_json_or_raw(tool_call.function.arguments),
                "result": _decode_json_or_raw(function_response["content"]),
            })
    else:
        final_text = response_message.content or (
            f"The model did not produce a final answer after {max_rounds} rounds."
        )

    function_calls_md = "\n\n" + "".join(
        f"**Tool call:**\n```json\n{json.dumps(record, indent=2)}\n```\n"
        for record in tool_records
    )
    return final_text + function_calls_md
# Chat UI wired to ``respond``. The system prompt is exposed as an editable
# additional input and arrives as ``respond``'s third argument.
demo = gr.ChatInterface(
    respond,
    additional_inputs=[
        gr.Textbox(
            value="You are a friendly Chatbot with access to a calculator. Don't mention that we are using functions defined in Python.",
            label="System message",
        ),
    ],
    title="Groq Tool Use Chat",
    # Interpolate MODEL so the description cannot drift from the model that is
    # actually requested.
    description=(
        f"This chatbot uses the `{MODEL}` LLM with tool use capabilities, "
        "including a calculator function."
    ),
)

# Start the web app only when this file is executed as a script.
if __name__ == "__main__":
    demo.launch()