pratikshahp committed
Commit 97867c9 · verified · 1 Parent(s): 30c95ab

Update app.py

Files changed (1)
  1. app.py +77 -94
app.py CHANGED
@@ -1,19 +1,15 @@
-import os
-import gradio as gr
-from langchain.agents import AgentExecutor, Tool
-from langchain.prompts import PromptTemplate
+from langgraph.graph import StateGraph, MessagesState, START, END
+from langgraph.checkpoint.memory import MemorySaver
+from langgraph.prebuilt import ToolNode
+from langchain_core.messages import HumanMessage
 from langchain_huggingface import HuggingFaceEndpoint
-from dotenv import load_dotenv # Optional for .env files
-
-# Load environment variables from .env file
-load_dotenv()
+from langchain_core.tools import tool
+from typing import Literal
+import os
 
-# Get Hugging Face token from environment variable
-HF_TOKEN = os.getenv("HF_TOKEN")
-if not HF_TOKEN:
-    raise ValueError("Hugging Face API token not found. Set HF_TOKEN in your environment variables.")
+# Hugging Face Endpoint setup
+HF_TOKEN = os.getenv("HF_TOKEN") # Ensure your API key is in the environment
 
-# Define the Hugging Face model endpoint
 llm = HuggingFaceEndpoint(
     repo_id="mistralai/Mistral-7B-Instruct-v0.3",
     huggingfacehub_api_token=HF_TOKEN.strip(),
@@ -21,85 +17,72 @@ llm = HuggingFaceEndpoint(
     max_new_tokens=150,
 )
 
-# Define specialized agents
-
-def flight_booking(input):
-    return f"Searching flights for: {input}"
-
-flight_tool = Tool(
-    name="Flight Booking Tool",
-    func=flight_booking,
-    description="Finds and books airline flights"
-)
-
-def hotel_booking(input):
-    return f"Searching hotels for: {input}"
-
-hotel_tool = Tool(
-    name="Hotel Booking Tool",
-    func=hotel_booking,
-    description="Searches and books accommodations"
-)
-
-def transportation_booking(input):
-    return f"Arranging transportation for: {input}"
-
-transport_tool = Tool(
-    name="Transportation Booking Tool",
-    func=transportation_booking,
-    description="Handles rental cars, shuttles, or trains"
-)
-
-def activity_booking(input):
-    return f"Finding activities for: {input}"
-
-activity_tool = Tool(
-    name="Activity Booking Tool",
-    func=activity_booking,
-    description="Books activities, tours, and events"
-)
-
-tools = [flight_tool, hotel_tool, transport_tool, activity_tool]
-
-# Define a prompt for the planner agent
-planner_prompt = PromptTemplate(
-    template="You are a travel planner. Your task is to coordinate tools to plan a trip based on user input: {input}",
-    input_variables=["input"]
+# Define tools for travel planning
+@tool
+def search_destination(query: str):
+    """Search for travel destinations based on preferences."""
+    if "beach" in query.lower():
+        return "Destinations: Bali, Maldives, Goa"
+    elif "adventure" in query.lower():
+        return "Destinations: Patagonia, Kilimanjaro, Swiss Alps"
+    return "Popular Destinations: Paris, New York, Tokyo"
+
+@tool
+def fetch_flights(destination: str):
+    """Fetch flight options to the destination."""
+    return f"Flights available to {destination}: Option 1 - $500, Option 2 - $700."
+
+@tool
+def fetch_hotels(destination: str):
+    """Fetch hotel recommendations for the destination."""
+    return f"Hotels in {destination}: Hotel A ($200/night), Hotel B ($150/night)."
+
+tools = [search_destination, fetch_flights, fetch_hotels]
+
+# Bind tools to the Hugging Face model
+llm = llm.bind_tools(tools)
+
+# Define the function to determine the next step
+def should_continue(state: MessagesState) -> Literal["tools", END]:
+    messages = state['messages']
+    last_message = messages[-1]
+    if last_message.tool_calls:
+        return "tools"
+    return END
+
+# Define the function to call the LLM
+def call_model(state: MessagesState):
+    messages = state['messages']
+    response = llm.invoke(messages)
+    return {"messages": [response]}
+
+# Create the graph
+workflow = StateGraph(MessagesState)
+
+# Define nodes
+workflow.add_node("agent", call_model)
+workflow.add_node("tools", ToolNode(tools))
+
+# Define edges
+workflow.add_edge(START, "agent")
+workflow.add_conditional_edges("agent", should_continue)
+workflow.add_edge("tools", "agent")
+
+# Initialize memory for persistence
+checkpointer = MemorySaver()
+
+# Compile the graph
+app = workflow.compile(checkpointer=checkpointer)
+
+# Use the graph
+initial_state = {"messages": [HumanMessage(content="I want a beach vacation.")]}
+
+final_state = app.invoke(initial_state)
+print(final_state["messages"][-1].content)
+
+# Continue the conversation
+next_state = app.invoke(
+    {"messages": [HumanMessage(content="Tell me about flights to Bali.")]},
+    config={"configurable": {"thread_id": 123}} # Reuse thread_id to maintain context
 )
-
-# Create the agent
-planner_agent = AgentExecutor.from_agent_and_tools(
-    llm=llm,
-    tools=tools,
-    verbose=True
-)
-
-# Define the Gradio function
-def plan_trip(user_input):
-    """
-    Processes user input through the planner agent and returns the output.
-    """
-    try:
-        output = planner_agent.run(user_input)
-        return output
-    except Exception as e:
-        return f"Error: {e}"
-
-# Create the Gradio interface
-with gr.Blocks() as travel_planner_app:
-    gr.Markdown("## AI Travel Planner")
-    gr.Markdown("Enter your trip details below, and the AI will help you plan your travel itinerary!")
-
-    with gr.Row():
-        user_input = gr.Textbox(
-            label="Trip Details",
-            placeholder="E.g., Plan a trip to Paris from May 10 to May 15 for a family of 4."
-        )
-        submit_button = gr.Button("Plan Trip")
-        output = gr.Textbox(label="Travel Plan")
-
-    submit_button.click(plan_trip, inputs=user_input, outputs=output)
-
-# Launch the Gradio app
-if __name__ == "__main__":
-    travel_planner_app.launch()
+print(next_state["messages"][-1].content)
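
A note on the token handling in the new code: the rewrite drops the explicit HF_TOKEN check that the previous app.py performed, so HF_TOKEN.strip() now fails with a bare AttributeError when the variable is unset. A minimal guard, mirroring the removed check (a sketch, not part of the commit):

# Sketch: fail fast with a clear message if HF_TOKEN is missing,
# mirroring the check the pre-rewrite app.py performed.
HF_TOKEN = os.getenv("HF_TOKEN")
if not HF_TOKEN:
    raise ValueError("Hugging Face API token not found. Set HF_TOKEN in your environment variables.")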
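
On llm.bind_tools(tools): bind_tools is a chat-model method, and the plain HuggingFaceEndpoint text-completion class does not generally expose it, so the agent node may never produce tool calls. One option, sketched below under the assumption of a recent langchain_huggingface release, is to wrap the endpoint in ChatHuggingFace from the same package before binding; the repo_id and token reuse the values from the commit.

# Sketch only: wrap the endpoint in the chat interface so bind_tools and
# tool_calls are available to the agent node.
from langchain_huggingface import ChatHuggingFace, HuggingFaceEndpoint

base_llm = HuggingFaceEndpoint(
    repo_id="mistralai/Mistral-7B-Instruct-v0.3",
    huggingfacehub_api_token=HF_TOKEN.strip(),
    max_new_tokens=150,
)
chat_llm = ChatHuggingFace(llm=base_llm).bind_tools(tools)

def call_model(state: MessagesState):
    # Same node body as in the commit, but invoking the chat wrapper.
    response = chat_llm.invoke(state["messages"])
    return {"messages": [response]}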
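
Finally, on the checkpointer: a graph compiled with MemorySaver expects a thread_id in each invocation's config, and two calls only share memory when they pass the same thread_id; the first app.invoke(initial_state) above passes none. A minimal sketch, using an arbitrary illustrative thread_id of "trip-1":

# Sketch: reuse one thread_id so MemorySaver carries the conversation
# across invocations. "trip-1" is an arbitrary identifier for illustration.
config = {"configurable": {"thread_id": "trip-1"}}

first = app.invoke(
    {"messages": [HumanMessage(content="I want a beach vacation.")]},
    config=config,
)
print(first["messages"][-1].content)

# Same thread_id, so the earlier messages are restored before this follow-up.
followup = app.invoke(
    {"messages": [HumanMessage(content="Tell me about flights to Bali.")]},
    config=config,
)
print(followup["messages"][-1].content)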