import os
import gradio as gr
import json
from typing import Annotated
from typing_extensions import TypedDict
from langchain_huggingface import HuggingFaceEndpoint
from langchain_community.tools.tavily_search import TavilySearchResults
from langgraph.graph import StateGraph, START, END
from langgraph.graph.message import add_messages
from langchain_core.messages import ToolMessage
from dotenv import load_dotenv
import logging
# Initialize logging
logging.basicConfig(level=logging.INFO)
# Load environment variables
load_dotenv()
HF_TOKEN = (os.getenv("HF_TOKEN") or "").strip()  # avoid AttributeError if the var is unset
# Initialize the HuggingFace model
llm = HuggingFaceEndpoint(
    repo_id="mistralai/Mistral-7B-Instruct-v0.3",
    huggingfacehub_api_token=HF_TOKEN,
    temperature=0.7,
    max_new_tokens=200,
)
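# Note: HuggingFaceEndpoint.invoke() takes a plain string prompt and returns the
# completion as a string, e.g. llm.invoke("Hello") -> "Hi! ...". The repo_id above
# must be reachable through the HF Inference API with the supplied token.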
# Initialize Tavily Search tool
tool = TavilySearchResults(max_results=2)
tools = [tool]
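# TavilySearchResults reads the TAVILY_API_KEY environment variable. Invoking it
# returns a list of result dicts, roughly [{"url": ..., "content": ...}, ...],
# which is what the URL extraction in chatbot() below relies on.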
# Define the state structure
class State(TypedDict):
    messages: Annotated[list, add_messages]
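# The add_messages reducer appends (rather than overwrites) whatever a node
# returns under "messages", and coerces plain strings and role/content dicts
# into LangChain message objects.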
# Create a state graph builder
graph_builder = StateGraph(State)
# Define the chatbot function
def chatbot(state: State):
    try:
        # Get the last message from the state (may be a raw string or a message object)
        input_message = state["messages"][-1] if state["messages"] else ""
        # Normalize the input to a plain string query
        if isinstance(input_message, str):
            query = input_message  # already a string, use it directly
        elif hasattr(input_message, "content") and isinstance(input_message.content, str):
            query = input_message.content  # e.g. a coerced HumanMessage
        else:
            raise ValueError("Input message is not in the correct format")
        logging.info(f"Input Message: {query}")
        # Invoke the LLM for a response (HuggingFaceEndpoint takes a plain string prompt)
        response = llm.invoke(query)
        logging.info(f"LLM Response: {response}")
        # Now invoke Tavily Search on the same query
        search_results = tool.invoke({"query": query})
        # Extract URLs from the search results
        urls = [result.get("url", "No URL found") for result in search_results]
        # Fold the source URLs into the reply so the message stays a plain
        # role/content dict that the add_messages reducer can coerce
        result_with_url = {
            "role": "assistant",
            "content": f"{response}\n\nSources:\n" + "\n".join(urls),
        }
        # Return only the new message; add_messages appends it to the state
        return {"messages": [result_with_url]}
    except Exception as e:
        logging.error(f"Error: {str(e)}")
        return {"messages": [{"role": "assistant", "content": f"Error: {str(e)}"}]}
# Define a node that executes any tool calls requested by the model
class BasicToolNode:
    """A node that runs the tools requested in the last AIMessage."""

    def __init__(self, tools: list) -> None:
        self.tools_by_name = {tool.name: tool for tool in tools}

    def __call__(self, inputs: dict):
        if messages := inputs.get("messages", []):
            message = messages[-1]
        else:
            raise ValueError("No message found in input")
        outputs = []
        for tool_call in message.tool_calls:
            tool_result = self.tools_by_name[tool_call["name"]].invoke(
                tool_call["args"]
            )
            outputs.append(
                ToolMessage(
                    content=json.dumps(tool_result),
                    name=tool_call["name"],
                    tool_call_id=tool_call["id"],
                )
            )
        return {"messages": outputs}
# Add tool node to the graph
tool_node = BasicToolNode(tools=tools)
graph_builder.add_node("tools", tool_node)
# Define the conditional routing function
def route_tools(state: State):
    """
    Route to the ToolNode if the last message has tool calls.
    Otherwise, route to the end.
    """
    if isinstance(state, list):
        ai_message = state[-1]
    elif messages := state.get("messages", []):
        ai_message = messages[-1]
    else:
        raise ValueError(f"No messages found in input state to tool_edge: {state}")
    if hasattr(ai_message, "tool_calls") and len(ai_message.tool_calls) > 0:
        return "tools"
    return END
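# The assistant replies produced by chatbot() never carry tool_calls, so this
# route always falls through to END here; the "tools" edge becomes useful once
# the model is bound to tools and returns AIMessages with populated tool_calls.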
# Add nodes and conditional edges to the state graph
graph_builder.add_node("chatbot", chatbot)
graph_builder.add_conditional_edges(
    "chatbot",
    route_tools,
    {"tools": "tools", END: END},
)
graph_builder.add_edge("tools", "chatbot")
graph_builder.add_edge(START, "chatbot")
graph = graph_builder.compile()
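# Minimal sanity check (sketch), assuming HF_TOKEN and TAVILY_API_KEY are set:
#   result = graph.invoke({"messages": ["What is LangGraph?"]})
#   print(result["messages"][-1].content)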
# Gradio interface
def chat_interface(input_text, state):
    # Start a fresh state on the first turn
    if state is None:
        state = {"messages": []}
    # Append the user input; add_messages coerces a plain string to a HumanMessage
    state["messages"].append(input_text)
    # Process the state through the graph
    updated_state = graph.invoke(state)
    # Return the text of the last message (a message object after coercion)
    last_message = updated_state["messages"][-1]
    reply = getattr(last_message, "content", last_message)
    return reply, updated_state
# Create Gradio app
with gr.Blocks() as demo:
    gr.Markdown("### Chatbot with Tavily Search Integration")
    chat_state = gr.State({"messages": []})
    with gr.Row():
        with gr.Column():
            user_input = gr.Textbox(label="Your Message", placeholder="Type your message here...", lines=2)
            submit_button = gr.Button("Submit")
        with gr.Column():
            chatbot_output = gr.Textbox(label="Chatbot Response", interactive=False, lines=4)
    submit_button.click(chat_interface, inputs=[user_input, chat_state], outputs=[chatbot_output, chat_state])
# Launch the Gradio app
if __name__ == "__main__":
    demo.launch()