# Travel-planner chat agent: a LangGraph tool-calling agent over a Hugging
# Face chat model, served through a Gradio chat UI.
import os
from typing import Literal

import gradio as gr
from langchain_core.messages import HumanMessage
from langchain_huggingface import ChatHuggingFace, HuggingFaceEndpoint
from langgraph.checkpoint.memory import MemorySaver
from langgraph.graph import StateGraph, MessagesState, START, END
from langgraph.prebuilt import ToolNode
# ---------------------------------------------------------------------------
# Hugging Face chat-model setup.
#
# NOTE(review): the original passed repo_id / token / sampling kwargs directly
# to ChatHuggingFace, but ChatHuggingFace is a chat wrapper that requires an
# underlying LLM via `llm=` (e.g. HuggingFaceEndpoint); passing repo_id
# directly fails validation at import time — the likely "Runtime error".
# ---------------------------------------------------------------------------
HF_TOKEN = os.getenv("HF_TOKEN")  # Ensure your API key is in the environment
if not HF_TOKEN:
    # Fail with a clear message instead of the original AttributeError from
    # calling .strip() on None when the variable is missing.
    raise RuntimeError("HF_TOKEN environment variable is not set")

llm = ChatHuggingFace(
    llm=HuggingFaceEndpoint(
        repo_id="mistralai/Mistral-7B-Instruct-v0.3",
        huggingfacehub_api_token=HF_TOKEN.strip(),
        temperature=0.7,
        max_new_tokens=150,
    )
)
# Define tools for travel planning | |
def search_destination(query: str):
    """Search for travel destinations based on preferences.

    Matches (case-insensitively) on keywords in *query*: "beach" and
    "adventure" get themed suggestions; anything else gets a generic list.
    """
    normalized = query.lower()
    if "beach" in normalized:
        return "Destinations: Bali, Maldives, Goa"
    if "adventure" in normalized:
        return "Destinations: Patagonia, Kilimanjaro, Swiss Alps"
    return "Popular Destinations: Paris, New York, Tokyo"
def fetch_flights(destination: str):
    """Fetch (mock) flight options to *destination*."""
    options = "Option 1 - $500, Option 2 - $700."
    return "Flights available to {}: {}".format(destination, options)
def fetch_hotels(destination: str):
    """Fetch (mock) hotel recommendations for *destination*."""
    template = "Hotels in {}: Hotel A ($200/night), Hotel B ($150/night)."
    return template.format(destination)
# Tools the agent may call; this list feeds both bind_tools and the ToolNode.
tools = [search_destination, fetch_flights, fetch_hotels]
# Bind tools to the Hugging Face model so it can emit structured tool calls.
llm = llm.bind_tools(tools)
# Define the function to determine the next step after the agent node.
def should_continue(state: MessagesState) -> Literal["tools", "__end__"]:
    """Route to the "tools" node when the last message requested tool calls,
    otherwise end the graph run.

    Note: the original annotated the return as ``Literal["tools", END]``;
    ``END`` is a runtime constant (``"__end__"``), not a literal, so type
    checkers reject that form — the literal string is used instead.
    """
    last_message = state["messages"][-1]
    # tool_calls is empty (or absent on non-AI messages) when the model
    # answered directly; getattr guards against message types without it.
    if getattr(last_message, "tool_calls", None):
        return "tools"
    return END
# Node that invokes the (tool-bound) chat model on the running conversation.
def call_model(state: MessagesState):
    """Call the LLM with the accumulated messages and return its reply
    wrapped for MessagesState's append-style update."""
    reply = llm.invoke(state["messages"])
    return {"messages": [reply]}
# Create the graph over the prebuilt MessagesState (holds a "messages" list).
workflow = StateGraph(MessagesState)
# Define nodes: "agent" calls the model; "tools" executes requested tools.
workflow.add_node("agent", call_model)
workflow.add_node("tools", ToolNode(tools))
# Define edges: start at the agent; should_continue then routes to "tools"
# (when the model requested tool calls) or to END; tool output loops back.
workflow.add_edge(START, "agent")
workflow.add_conditional_edges("agent", should_continue)
workflow.add_edge("tools", "agent")
# Initialize memory for persistence of conversation state between invokes.
checkpointer = MemorySaver()
# Compile the graph into a runnable app.
app = workflow.compile(checkpointer=checkpointer)
# Function to process one user turn and update the Gradio chat history.
def chat(user_input, history):
    """Run *user_input* through the agent graph and append the reply.

    Args:
        user_input: Raw text typed by the user.
        history: Gradio chat history as a list of (user, assistant) tuples;
            may be None on the first turn.

    Returns:
        The updated history twice, matching the [chatbot, chatbot] outputs.
    """
    history = history or []  # guard against Gradio passing None initially
    # Prepare the initial state with the user's message
    initial_state = {"messages": [HumanMessage(content=user_input)]}
    # The graph was compiled with a checkpointer, so invoke() REQUIRES a
    # thread_id in the config — the original app.invoke(initial_state)
    # raised a ValueError at runtime. A fixed id keeps one shared thread.
    config = {"configurable": {"thread_id": "gradio-session"}}
    final_state = app.invoke(initial_state, config)
    # Extract the response content from the last message in the final state
    response = final_state["messages"][-1].content
    history.append((user_input, response))
    return history, history
# Create the Gradio interface.
with gr.Blocks() as demo:
    chatbot = gr.Chatbot()
    with gr.Row():
        with gr.Column():
            # NOTE(review): the original chained .style(container=False) on
            # the Textbox; .style() was removed in Gradio 4.x and raises
            # AttributeError at startup. The container flag is now a
            # constructor parameter.
            user_input = gr.Textbox(
                show_label=False,
                placeholder="Enter your message...",
                container=False,
            )
        with gr.Column():
            submit_btn = gr.Button("Send")
    # Wire both the Send button and the textbox Enter key to the handler.
    submit_btn.click(chat, [user_input, chatbot], [chatbot, chatbot])
    user_input.submit(chat, [user_input, chatbot], [chatbot, chatbot])

# Launch the Gradio app (blocking call).
demo.launch()