File size: 3,411 Bytes
267129f
 
97867c9
 
 
b6ed368
97867c9
 
8f767c1
97867c9
 
8f767c1
b6ed368
8f767c1
 
 
 
 
 
97867c9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
267129f
 
 
 
97867c9
267129f
 
97867c9
267129f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
import os
from typing import Literal

import gradio as gr
from langchain_core.messages import HumanMessage
from langchain_huggingface import ChatHuggingFace, HuggingFaceEndpoint
from langgraph.checkpoint.memory import MemorySaver
from langgraph.graph import StateGraph, MessagesState, START, END
from langgraph.prebuilt import ToolNode

# Hugging Face Endpoint setup
HF_TOKEN = os.getenv("HF_TOKEN")  # Ensure your API key is in the environment
if not HF_TOKEN:
    # Fail fast with a clear message instead of an AttributeError on .strip()
    raise RuntimeError("HF_TOKEN environment variable is not set")

# ChatHuggingFace is a chat wrapper around an LLM; it does not accept
# endpoint kwargs (repo_id, temperature, ...) directly -- those belong on
# a HuggingFaceEndpoint, which is passed in via the `llm` parameter.
llm = ChatHuggingFace(
    llm=HuggingFaceEndpoint(
        repo_id="mistralai/Mistral-7B-Instruct-v0.3",
        huggingfacehub_api_token=HF_TOKEN.strip(),
        temperature=0.7,
        max_new_tokens=150,
    )
)

# Define tools for travel planning
def search_destination(query: str):
    """Search for travel destinations based on preferences."""
    normalized = query.lower()
    if "beach" in normalized:
        return "Destinations: Bali, Maldives, Goa"
    if "adventure" in normalized:
        return "Destinations: Patagonia, Kilimanjaro, Swiss Alps"
    # No recognized preference keyword: fall back to general suggestions.
    return "Popular Destinations: Paris, New York, Tokyo"

def fetch_flights(destination: str):
    """Fetch flight options to the destination."""
    # Canned options; a real implementation would query a flights API.
    options = "Option 1 - $500, Option 2 - $700."
    return f"Flights available to {destination}: {options}"

def fetch_hotels(destination: str):
    """Fetch hotel recommendations for the destination."""
    # Canned listings; a real implementation would query a hotels API.
    listings = "Hotel A ($200/night), Hotel B ($150/night)."
    return f"Hotels in {destination}: {listings}"

# Collect the tool functions exposed to the agent; LangChain infers each
# tool's schema from the function signature and docstring.
tools = [search_destination, fetch_flights, fetch_hotels]

# Bind tools to the Hugging Face model so its responses can carry
# structured tool calls that the ToolNode below knows how to execute.
llm = llm.bind_tools(tools)

# Define the function to determine the next step
def should_continue(state: MessagesState) -> Literal["tools", END]:
    """Route to the tool node if the last message requests tool calls, else finish."""
    last = state['messages'][-1]
    return "tools" if last.tool_calls else END

# Define the function to call the LLM
def call_model(state: MessagesState):
    """Invoke the LLM on the accumulated conversation and append its reply."""
    reply = llm.invoke(state['messages'])
    # Returning a messages list lets MessagesState merge it into history.
    return {"messages": [reply]}

# Create the graph
# Create the graph: agent -> (tools -> agent)* -> END
workflow = StateGraph(MessagesState)

# Define nodes: the LLM step and the executor for its tool calls
workflow.add_node("agent", call_model)
workflow.add_node("tools", ToolNode(tools))

# Define edges: start at the agent; after each agent turn, should_continue
# decides between the tools node and END; tool results loop back to the agent.
workflow.add_edge(START, "agent")
workflow.add_conditional_edges("agent", should_continue)
workflow.add_edge("tools", "agent")

# Initialize memory for persistence across invocations of the same thread
checkpointer = MemorySaver()

# Compile the graph (invocations now require a configurable thread_id)
app = workflow.compile(checkpointer=checkpointer)

# Function to process user input and generate response
def chat(user_input, history):
    """Run one conversational turn through the LangGraph workflow.

    Args:
        user_input: The user's message text.
        history: The Gradio chat history, a list of (user, bot) tuples
            (may be None on the first turn).

    Returns:
        A (history, history) pair -- the same list twice, matching the two
        chatbot outputs wired up in the Gradio interface.
    """
    # The graph was compiled with a checkpointer, so invoke() requires a
    # thread_id in the config; omitting it raises a ValueError. A fixed id
    # keeps one persistent conversation thread for this UI session.
    config = {"configurable": {"thread_id": "gradio-session"}}

    # Prepare the initial state with the user's message
    initial_state = {"messages": [HumanMessage(content=user_input)]}

    # Invoke the workflow
    final_state = app.invoke(initial_state, config)

    # Extract the response content from the last (assistant) message
    response = final_state["messages"][-1].content

    # Append the user input and response to history (guard first-turn None)
    history = history or []
    history.append((user_input, response))

    return history, history

# Create the Gradio interface
# Create the Gradio interface
with gr.Blocks() as demo:
    chatbot = gr.Chatbot()
    with gr.Row():
        with gr.Column():
            # .style() was deprecated and removed in Gradio 4.x;
            # `container` is now a constructor argument on Textbox.
            user_input = gr.Textbox(
                show_label=False,
                placeholder="Enter your message...",
                container=False,
            )
        with gr.Column():
            submit_btn = gr.Button("Send")

    # Fire the chat handler on both the Send button and Enter in the textbox
    submit_btn.click(chat, [user_input, chatbot], [chatbot, chatbot])
    user_input.submit(chat, [user_input, chatbot], [chatbot, chatbot])

# Launch the Gradio app
demo.launch()