pratikshahp commited on
Commit
dfde201
·
verified ·
1 Parent(s): d35bebd

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +123 -67
app.py CHANGED
@@ -1,12 +1,15 @@
1
  import os
 
 
2
  from typing import Annotated
3
  from typing_extensions import TypedDict
 
 
4
  from langgraph.graph import StateGraph, START, END
5
  from langgraph.graph.message import add_messages
6
- from langchain_huggingface import HuggingFaceEndpoint
7
  from dotenv import load_dotenv
8
  import logging
9
- import gradio as gr
10
 
11
  # Initialize logging
12
  logging.basicConfig(level=logging.INFO)
@@ -15,7 +18,7 @@ logging.basicConfig(level=logging.INFO)
15
  load_dotenv()
16
  HF_TOKEN = os.getenv("HF_TOKEN")
17
 
18
- # Initialize Hugging Face endpoint
19
  llm = HuggingFaceEndpoint(
20
  repo_id="mistralai/Mistral-7B-Instruct-v0.3",
21
  huggingfacehub_api_token=HF_TOKEN.strip(),
@@ -23,6 +26,10 @@ llm = HuggingFaceEndpoint(
23
  max_new_tokens=200
24
  )
25
 
 
 
 
 
26
  # Define the state structure
27
  class State(TypedDict):
28
  messages: Annotated[list, add_messages]
@@ -33,80 +40,129 @@ graph_builder = StateGraph(State)
33
  # Define the chatbot function
34
  def chatbot(state: State):
35
  try:
36
- logging.info(f"Input Messages: {state['messages']}")
37
- response = llm.invoke(state["messages"])
 
 
 
 
 
 
 
 
 
 
 
 
 
38
  logging.info(f"LLM Response: {response}")
39
- return {"messages": [response]}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
40
  except Exception as e:
41
  logging.error(f"Error: {str(e)}")
42
- return {"messages": [f"Error: {str(e)}"]}
43
 
44
- # Add nodes and edges to the state graph
45
- graph_builder.add_node("chatbot", chatbot)
46
- graph_builder.add_edge(START, "chatbot")
47
- graph_builder.add_edge("chatbot", END)
 
48
 
49
- # Compile the state graph
50
- graph = graph_builder.compile()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
51
 
52
- # Function to stream updates from the graph
53
- def stream_graph_updates(user_input: str):
54
- """
55
- Stream updates from the graph based on user input and return the assistant's reply.
56
- """
57
- assistant_reply = ""
58
- for event in graph.stream({"messages": [("user", user_input)]}):
59
- for value in event.values():
60
- if isinstance(value["messages"][-1], dict):
61
- # If it's a dict, extract 'content'
62
- assistant_reply = value["messages"][-1].get("content", "")
63
- elif isinstance(value["messages"][-1], str):
64
- # If it's a string, use it directly
65
- assistant_reply = value["messages"][-1]
66
- return assistant_reply
67
-
68
- # Gradio chatbot function using the streaming updates
69
- def gradio_chatbot(user_message: str):
70
- """
71
- Handle Gradio user input, process through the graph, and return only the assistant's reply.
72
- """
73
- try:
74
- return stream_graph_updates(user_message)
75
- except Exception as e:
76
- logging.error(f"Error in Gradio chatbot: {str(e)}")
77
- return f"Error: {str(e)}"
78
 
79
- # Terminal-based chat loop
80
- def terminal_chat():
81
  """
82
- Chat loop for terminal interaction.
 
83
  """
84
- while True:
85
- try:
86
- user_input = input("User: ").strip()
87
- # Process input through the chatbot
88
- response = stream_graph_updates(user_input)
89
- print(f"Assistant: {response}")
90
- except Exception as e:
91
- print(f"Error: {str(e)}")
92
- break
93
-
94
- # Create Gradio interface
95
- interface = gr.Interface(
96
- fn=gradio_chatbot,
97
- inputs=gr.Textbox(placeholder="Enter your message", label="Your Message"),
98
- outputs=gr.Textbox(label="Assistant's Reply"),
99
- title="Chatbot",
100
- description="Interactive chatbot using a state graph and Hugging Face Endpoint."
 
101
  )
 
 
 
102
 
103
- if __name__ == "__main__":
104
- # Launch terminal chat and Gradio interface in separate threads
105
- import threading
 
 
 
 
 
 
 
 
 
106
 
107
- # Terminal chat thread
108
- terminal_thread = threading.Thread(target=terminal_chat, daemon=True)
109
- terminal_thread.start()
 
 
 
 
 
 
 
 
 
 
 
110
 
111
- # Launch Gradio interface
112
- interface.launch(share=True)
 
 
1
  import os
2
+ import gradio as gr
3
+ import json
4
  from typing import Annotated
5
  from typing_extensions import TypedDict
6
+ from langchain_huggingface import HuggingFaceEndpoint
7
+ from langchain_community.tools.tavily_search import TavilySearchResults
8
  from langgraph.graph import StateGraph, START, END
9
  from langgraph.graph.message import add_messages
10
+ from langchain_core.messages import ToolMessage
11
  from dotenv import load_dotenv
12
  import logging
 
13
 
14
  # Initialize logging
15
  logging.basicConfig(level=logging.INFO)
 
18
  load_dotenv()
19
  HF_TOKEN = os.getenv("HF_TOKEN")
20
 
21
+ # Initialize the HuggingFace model
22
  llm = HuggingFaceEndpoint(
23
  repo_id="mistralai/Mistral-7B-Instruct-v0.3",
24
  huggingfacehub_api_token=HF_TOKEN.strip(),
 
26
  max_new_tokens=200
27
  )
28
 
29
+ # Initialize Tavily Search tool
30
+ tool = TavilySearchResults(max_results=2)
31
+ tools = [tool]
32
+
33
  # Define the state structure
34
  class State(TypedDict):
35
  messages: Annotated[list, add_messages]
 
40
  # Define the chatbot function
41
  def chatbot(state: State):
42
  try:
43
+ # Get the last message and ensure it's a string
44
+ input_message = state["messages"][-1] if state["messages"] else ""
45
+
46
+ # Ensure that input_message is a string (check the type)
47
+ if isinstance(input_message, str):
48
+ query = input_message # If it's already a string, use it directly
49
+ elif hasattr(input_message, 'content') and isinstance(input_message.content, str):
50
+ query = input_message.content # Extract the content if it's a HumanMessage object
51
+ else:
52
+ raise ValueError("Input message is not in the correct format")
53
+
54
+ logging.info(f"Input Message: {query}")
55
+
56
+ # Invoke the LLM for a response
57
+ response = llm.invoke([query])
58
  logging.info(f"LLM Response: {response}")
59
+
60
+ # Now, invoke Tavily Search and get the results
61
+ search_results = tool.invoke({"query": query})
62
+
63
+ # Extract URLs from search results
64
+ urls = [result.get("url", "No URL found") for result in search_results]
65
+
66
+ # Prepare the result to include URL information
67
+ result_with_url = {
68
+ "role": "assistant", # Set the role to 'assistant'
69
+ "content": response, # Set the response as content
70
+ "urls": urls # Include the URLs of the search results
71
+ }
72
+
73
+ return {"messages": state["messages"] + [result_with_url]}
74
+
75
  except Exception as e:
76
  logging.error(f"Error: {str(e)}")
77
+ return {"messages": state["messages"] + [f"Error: {str(e)}"]}
78
 
79
+ # Add tool node to the graph
80
+ class BasicToolNode:
81
+ """A node that runs the tools requested in the last AIMessage."""
82
+ def __init__(self, tools: list) -> None:
83
+ self.tools_by_name = {tool.name: tool for tool in tools}
84
 
85
+ def __call__(self, inputs: dict):
86
+ if messages := inputs.get("messages", []):
87
+ message = messages[-1]
88
+ else:
89
+ raise ValueError("No message found in input")
90
+
91
+ outputs = []
92
+ for tool_call in message.tool_calls:
93
+ tool_result = self.tools_by_name[tool_call["name"]].invoke(
94
+ tool_call["args"]
95
+ )
96
+ outputs.append(
97
+ ToolMessage(
98
+ content=json.dumps(tool_result),
99
+ name=tool_call["name"],
100
+ tool_call_id=tool_call["id"],
101
+ )
102
+ )
103
+ return {"messages": outputs}
104
 
105
+ # Add tool node to the graph
106
+ tool_node = BasicToolNode(tools=tools)
107
+ graph_builder.add_node("tools", tool_node)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
108
 
109
+ # Define the conditional routing function
110
+ def route_tools(state: State):
111
  """
112
+ Route to the ToolNode if the last message has tool calls.
113
+ Otherwise, route to the end.
114
  """
115
+ if isinstance(state, list):
116
+ ai_message = state[-1]
117
+ elif messages := state.get("messages", []):
118
+ ai_message = messages[-1]
119
+ else:
120
+ raise ValueError(f"No messages found in input state to tool_edge: {state}")
121
+
122
+ if hasattr(ai_message, "tool_calls") and len(ai_message.tool_calls) > 0:
123
+ return "tools"
124
+
125
+ return END
126
+
127
+ # Add nodes and conditional edges to the state graph
128
+ graph_builder.add_node("chatbot", chatbot)
129
+ graph_builder.add_conditional_edges(
130
+ "chatbot",
131
+ route_tools,
132
+ {"tools": "tools", END: END}
133
  )
134
+ graph_builder.add_edge("tools", "chatbot")
135
+ graph_builder.add_edge(START, "chatbot")
136
+ graph = graph_builder.compile()
137
 
138
+ # Gradio interface
139
+ def chat_interface(input_text, state):
140
+ # Prepare state if not provided
141
+ if state is None:
142
+ state = {"messages": []}
143
+
144
+ # Append user input to state
145
+ state["messages"].append(input_text)
146
+
147
+ # Process state through the graph
148
+ updated_state = graph.invoke(state)
149
+ return updated_state["messages"][-1], updated_state
150
 
151
+ # Create Gradio app
152
+ with gr.Blocks() as demo:
153
+ gr.Markdown("### Chatbot with Tavily Search Integration")
154
+ chat_state = gr.State({"messages": []})
155
+
156
+ with gr.Row():
157
+ with gr.Column():
158
+ user_input = gr.Textbox(label="Your Message", placeholder="Type your message here...", lines=2)
159
+ submit_button = gr.Button("Submit")
160
+
161
+ with gr.Column():
162
+ chatbot_output = gr.Textbox(label="Chatbot Response", interactive=False, lines=4)
163
+
164
+ submit_button.click(chat_interface, inputs=[user_input, chat_state], outputs=[chatbot_output, chat_state])
165
 
166
+ # Launch the Gradio app
167
+ if __name__ == "__main__":
168
+ demo.launch()