Spaces:
Runtime error
Runtime error
Update app.py
Browse files
app.py
CHANGED
@@ -1,19 +1,15 @@
|
|
1 |
-
import
|
2 |
-
|
3 |
-
from
|
4 |
-
from
|
5 |
from langchain_huggingface import HuggingFaceEndpoint
|
6 |
-
from
|
7 |
-
|
8 |
-
|
9 |
-
load_dotenv()
|
10 |
|
11 |
-
#
|
12 |
-
HF_TOKEN = os.getenv("HF_TOKEN")
|
13 |
-
if not HF_TOKEN:
|
14 |
-
raise ValueError("Hugging Face API token not found. Set HF_TOKEN in your environment variables.")
|
15 |
|
16 |
-
# Define the Hugging Face model endpoint
|
17 |
llm = HuggingFaceEndpoint(
|
18 |
repo_id="mistralai/Mistral-7B-Instruct-v0.3",
|
19 |
huggingfacehub_api_token=HF_TOKEN.strip(),
|
@@ -21,85 +17,72 @@ llm = HuggingFaceEndpoint(
|
|
21 |
max_new_tokens=150,
|
22 |
)
|
23 |
|
24 |
-
# Define
|
25 |
-
|
26 |
-
def
|
27 |
-
|
28 |
-
|
29 |
-
|
30 |
-
|
31 |
-
|
32 |
-
|
33 |
-
|
34 |
-
|
35 |
-
def
|
36 |
-
|
37 |
-
|
38 |
-
|
39 |
-
|
40 |
-
|
41 |
-
|
42 |
-
)
|
43 |
-
|
44 |
-
|
45 |
-
|
46 |
-
|
47 |
-
|
48 |
-
|
49 |
-
|
50 |
-
|
51 |
-
|
52 |
-
|
53 |
-
|
54 |
-
|
55 |
-
|
56 |
-
|
57 |
-
|
58 |
-
|
59 |
-
|
60 |
-
)
|
61 |
-
|
62 |
-
|
63 |
-
|
64 |
-
|
65 |
-
|
66 |
-
|
67 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
68 |
)
|
69 |
-
|
70 |
-
# Create the agent
|
71 |
-
planner_agent = AgentExecutor.from_agent_and_tools(
|
72 |
-
llm=llm,
|
73 |
-
tools=tools,
|
74 |
-
verbose=True
|
75 |
-
)
|
76 |
-
|
77 |
-
# Define the Gradio function
|
78 |
-
def plan_trip(user_input):
|
79 |
-
"""
|
80 |
-
Processes user input through the planner agent and returns the output.
|
81 |
-
"""
|
82 |
-
try:
|
83 |
-
output = planner_agent.run(user_input)
|
84 |
-
return output
|
85 |
-
except Exception as e:
|
86 |
-
return f"Error: {e}"
|
87 |
-
|
88 |
-
# Create the Gradio interface
|
89 |
-
with gr.Blocks() as travel_planner_app:
|
90 |
-
gr.Markdown("## AI Travel Planner")
|
91 |
-
gr.Markdown("Enter your trip details below, and the AI will help you plan your travel itinerary!")
|
92 |
-
|
93 |
-
with gr.Row():
|
94 |
-
user_input = gr.Textbox(
|
95 |
-
label="Trip Details",
|
96 |
-
placeholder="E.g., Plan a trip to Paris from May 10 to May 15 for a family of 4."
|
97 |
-
)
|
98 |
-
submit_button = gr.Button("Plan Trip")
|
99 |
-
output = gr.Textbox(label="Travel Plan")
|
100 |
-
|
101 |
-
submit_button.click(plan_trip, inputs=user_input, outputs=output)
|
102 |
-
|
103 |
-
# Launch the Gradio app
|
104 |
-
if __name__ == "__main__":
|
105 |
-
travel_planner_app.launch()
|
|
|
from langgraph.graph import StateGraph, MessagesState, START, END
from langgraph.checkpoint.memory import MemorySaver
from langgraph.prebuilt import ToolNode
from langchain_core.messages import HumanMessage
from langchain_huggingface import HuggingFaceEndpoint
from langchain_core.tools import tool
from typing import Literal
import os

# Hugging Face Endpoint setup
HF_TOKEN = os.getenv("HF_TOKEN")  # Ensure your API key is in the environment
if not HF_TOKEN:
    # Fail fast with a clear message instead of an AttributeError on None.strip() below.
    raise ValueError("Hugging Face API token not found. Set HF_TOKEN in your environment variables.")

llm = HuggingFaceEndpoint(
    repo_id="mistralai/Mistral-7B-Instruct-v0.3",
    huggingfacehub_api_token=HF_TOKEN.strip(),
    # NOTE(review): one keyword argument (original line 16, e.g. temperature)
    # was hidden by the diff hunk boundary — confirm against the full file.
    max_new_tokens=150,
)
|
19 |
|
# Define tools for travel planning
@tool
def search_destination(query: str):
    """Search for travel destinations based on preferences."""
    wants = query.lower()
    if "beach" in wants:
        return "Destinations: Bali, Maldives, Goa"
    if "adventure" in wants:
        return "Destinations: Patagonia, Kilimanjaro, Swiss Alps"
    # Default suggestions when no known preference keyword is present.
    return "Popular Destinations: Paris, New York, Tokyo"
|
29 |
+
|
@tool
def fetch_flights(destination: str):
    """Fetch flight options to the destination."""
    # Stubbed data source — returns a fixed pair of priced options.
    options = f"Flights available to {destination}: Option 1 - $500, Option 2 - $700."
    return options
|
34 |
+
|
@tool
def fetch_hotels(destination: str):
    """Fetch hotel recommendations for the destination."""
    # Stubbed data source — returns two fixed nightly rates.
    listing = f"Hotels in {destination}: Hotel A ($200/night), Hotel B ($150/night)."
    return listing
|
39 |
+
|
# Register all travel-planning tools in one list for binding and for ToolNode.
tools = [search_destination, fetch_flights, fetch_hotels]

# Bind tools to the Hugging Face model.
# NOTE(review): HuggingFaceEndpoint is a text-completion LLM; bind_tools is
# generally supported only by chat models (e.g. ChatHuggingFace). This line
# may be the source of the Space's runtime error — confirm.
llm = llm.bind_tools(tools)
|
44 |
+
|
# Route the graph after the agent node: execute tools if the last model
# message requested any tool calls, otherwise terminate the run.
def should_continue(state: MessagesState) -> Literal["tools", END]:
    latest = state["messages"][-1]
    return "tools" if latest.tool_calls else END
|
52 |
+
|
# Invoke the (tool-bound) model on the accumulated conversation and
# return its reply wrapped for MessagesState to append.
def call_model(state: MessagesState):
    reply = llm.invoke(state["messages"])
    return {"messages": [reply]}
|
58 |
+
|
# Build the agent graph over the shared message state.
workflow = StateGraph(MessagesState)

# Nodes: the model step and the tool-execution step.
workflow.add_node("agent", call_model)
workflow.add_node("tools", ToolNode(tools))

# Edges: start at the agent; after the agent either run tools or end
# (per should_continue); after tools, hand control back to the agent.
workflow.add_edge(START, "agent")
workflow.add_conditional_edges("agent", should_continue)
workflow.add_edge("tools", "agent")

# In-memory checkpointer so conversation state persists across invocations.
checkpointer = MemorySaver()

# Compile the graph into a runnable app.
app = workflow.compile(checkpointer=checkpointer)
|
76 |
+
|
# Use the graph. A graph compiled with a checkpointer requires a thread_id
# in the config on every invoke; the original first call omitted it (which
# raises a ValueError) and so the second call's thread 123 had no prior
# context to reuse. Share one config across both turns instead.
config = {"configurable": {"thread_id": 123}}

initial_state = {"messages": [HumanMessage(content="I want a beach vacation.")]}
final_state = app.invoke(initial_state, config=config)
print(final_state["messages"][-1].content)

# Continue the conversation on the same thread to maintain context.
next_state = app.invoke(
    {"messages": [HumanMessage(content="Tell me about flights to Bali.")]},
    config=config,
)
print(next_state["messages"][-1].content)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|