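# Scraped from the Hugging Face Spaces file view ("raw" / "history" / "blame" page chrome removed).
# Author: pratikshahp -- commit 267129f (verified): "Update app.py" -- file size 3.41 kB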
import os
from typing import Literal

import gradio as gr
from langchain_core.messages import HumanMessage
from langchain_huggingface import ChatHuggingFace, HuggingFaceEndpoint
from langgraph.checkpoint.memory import MemorySaver
from langgraph.graph import StateGraph, MessagesState, START, END
from langgraph.prebuilt import ToolNode
# Hugging Face Endpoint setup
# The API token is read from the environment. Fail early with an actionable
# message when it is missing, instead of crashing on ``None.strip()``.
HF_TOKEN = os.getenv("HF_TOKEN")  # Ensure your API key is in the environment
if not HF_TOKEN or not HF_TOKEN.strip():
    raise RuntimeError(
        "HF_TOKEN environment variable is not set; "
        "export your Hugging Face API token before starting the app."
    )

# ChatHuggingFace is a chat wrapper around an underlying LLM object (passed as
# ``llm=``); the repo id, token and generation settings belong to the endpoint.
_endpoint = HuggingFaceEndpoint(
    repo_id="mistralai/Mistral-7B-Instruct-v0.3",
    huggingfacehub_api_token=HF_TOKEN.strip(),
    temperature=0.7,
    max_new_tokens=150,
)
llm = ChatHuggingFace(llm=_endpoint)
# Define tools for travel planning
def search_destination(query: str):
    """Search for travel destinations based on preferences."""
    # NOTE(review): the docstring above is kept verbatim because LangChain
    # presumably exposes it to the model as the tool description -- confirm
    # before rewording it.
    #
    # Matching is a case-insensitive substring test; the first keyword that
    # appears in the query wins, so "beach" takes priority over "adventure".
    lowered = query.lower()
    keyword_replies = (
        ("beach", "Destinations: Bali, Maldives, Goa"),
        ("adventure", "Destinations: Patagonia, Kilimanjaro, Swiss Alps"),
    )
    for keyword, reply in keyword_replies:
        if keyword in lowered:
            return reply
    # No known preference keyword found: fall back to a generic list.
    return "Popular Destinations: Paris, New York, Tokyo"
def fetch_flights(destination: str):
    """Fetch flight options to the destination."""
    # NOTE(review): docstring kept verbatim -- it is presumably used as the
    # tool description sent to the model; confirm before rewording.
    #
    # Stub implementation: the same two canned fares are returned for every
    # destination, with the destination name interpolated as given.
    listing = f"Flights available to {destination}: Option 1 - $500, Option 2 - $700."
    return listing
def fetch_hotels(destination: str):
    """Fetch hotel recommendations for the destination."""
    # NOTE(review): docstring kept verbatim -- it is presumably used as the
    # tool description sent to the model; confirm before rewording.
    #
    # Stub implementation: the same two canned hotels are returned for every
    # destination, with the destination name interpolated as given.
    listing = f"Hotels in {destination}: Hotel A ($200/night), Hotel B ($150/night)."
    return listing
# Plain functions offered to the model as callable tools; the same list is
# later handed to the graph's tool-executing node.
tools = [
    search_destination,
    fetch_flights,
    fetch_hotels,
]

# Rebind ``llm`` to a tool-aware version of the chat model so its replies can
# contain tool calls for the functions listed above.
llm = llm.bind_tools(tools)
# Define the function to determine the next step
def should_continue(state: MessagesState) -> Literal["tools", END]:
    """Pick the next graph step after the agent node has run.

    Looks only at the most recent message in ``state['messages']``: if it
    carries any tool calls the graph routes to the ``"tools"`` node, otherwise
    the run finishes at ``END``.
    """
    newest = state['messages'][-1]
    return "tools" if newest.tool_calls else END
# Define the function to call the LLM
def call_model(state: MessagesState):
    """Run the tool-bound chat model on the conversation so far.

    Sends ``state['messages']`` to ``llm`` and returns the reply wrapped as a
    one-element ``messages`` update for the graph state.
    """
    reply = llm.invoke(state['messages'])
    return {"messages": [reply]}
# Create the graph
workflow = StateGraph(MessagesState)

# Nodes: "agent" calls the model, "tools" executes whatever tool calls the
# model requested.
_tool_node = ToolNode(tools)
workflow.add_node("agent", call_model)
workflow.add_node("tools", _tool_node)

# Edges: every run starts at the agent; after the agent, should_continue
# chooses between the tool node and END; tool results always flow back to the
# agent for another turn.
workflow.add_edge(START, "agent")
workflow.add_conditional_edges("agent", should_continue)
workflow.add_edge("tools", "agent")

# In-memory checkpointer used for persistence of graph state.
checkpointer = MemorySaver()

# Compile the graph into the runnable application object.
app = workflow.compile(checkpointer=checkpointer)
# Function to process user input and generate response
def chat(user_input, history, request: gr.Request = None):
    """Run one chat turn through the LangGraph workflow.

    Args:
        user_input: The text the user typed.
        history: The chatbot's current list of ``(user, assistant)`` pairs;
            ``None`` is treated as an empty history.
        request: Optional Gradio request object. Gradio injects it for
            parameters annotated with ``gr.Request``; direct callers may omit
            it.

    Returns:
        A ``(history, history)`` tuple: the updated history, repeated once per
        output component wired to this handler.
    """
    if history is None:
        history = []

    # The graph is compiled with a checkpointer, and LangGraph requires a
    # ``thread_id`` in the run config to know which conversation to load and
    # save. Use the Gradio session hash when available so each browser session
    # gets its own thread; otherwise fall back to a single shared thread.
    thread_id = getattr(request, "session_hash", None) or "default"
    run_config = {"configurable": {"thread_id": thread_id}}

    # Only the new message is sent; earlier turns for this thread come from
    # the checkpointer.
    initial_state = {"messages": [HumanMessage(content=user_input)]}
    final_state = app.invoke(initial_state, config=run_config)

    # The last message in the final state is the model's answer for this turn.
    response = final_state["messages"][-1].content
    history.append((user_input, response))
    return history, history
# Create the Gradio interface
with gr.Blocks() as demo:
    # Conversation display; also serves as the history input/output of `chat`.
    chatbot = gr.Chatbot()
    with gr.Row():
        with gr.Column():
            # `container=False` is passed to the constructor: the older
            # `.style(container=False)` method no longer exists in Gradio 4+
            # and raises AttributeError there.
            user_input = gr.Textbox(
                show_label=False,
                placeholder="Enter your message...",
                container=False,
            )
        with gr.Column():
            submit_btn = gr.Button("Send")
    # Set up the event handlers: clicking "Send" and pressing Enter in the
    # textbox both run `chat`. They must be registered inside the Blocks
    # context.
    submit_btn.click(chat, [user_input, chatbot], [chatbot, chatbot])
    user_input.submit(chat, [user_input, chatbot], [chatbot, chatbot])
# Launch the Gradio app
demo.launch()