"""Gradio chatbot backed by a Mistral HuggingFace endpoint, with Tavily search URLs.

A LangGraph graph answers each user message with the LLM, then runs a Tavily
search on the same text and attaches the result URLs to the reply.
"""

import json
import logging
import os
from typing import Annotated

import gradio as gr
from dotenv import load_dotenv
from langchain_community.tools.tavily_search import TavilySearchResults
from langchain_core.messages import AIMessage, ToolMessage
from langchain_huggingface import HuggingFaceEndpoint
from langgraph.graph import END, START, StateGraph
from langgraph.graph.message import add_messages
from typing_extensions import TypedDict

# Initialize logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Load environment variables (HF_TOKEN is read below) from a .env file, if present.
load_dotenv()

HF_TOKEN = (os.getenv("HF_TOKEN") or "").strip()
if not HF_TOKEN:
    # Fail fast with a readable message instead of an AttributeError on None.strip().
    raise RuntimeError(
        "HF_TOKEN is not set; define it in the environment or in a .env file."
    )

# Initialize the HuggingFace model
llm = HuggingFaceEndpoint(
    repo_id="mistralai/Mistral-7B-Instruct-v0.3",
    huggingfacehub_api_token=HF_TOKEN,
    temperature=0.7,
    max_new_tokens=200,
)

# Initialize Tavily Search tool
tool = TavilySearchResults(max_results=2)
tools = [tool]


class State(TypedDict):
    """Graph state: the running conversation.

    ``add_messages`` is the reducer, so message lists returned by nodes are
    merged into the existing history (new messages appended, same-id messages
    replaced) rather than overwriting it.
    """

    messages: Annotated[list, add_messages]


# Create a state graph builder
graph_builder = StateGraph(State)


def _extract_query(message) -> str:
    """Return the text of *message*, which may be a plain string or a message object.

    Raises:
        ValueError: if *message* is neither a string nor has string ``content``.
    """
    if isinstance(message, str):
        return message
    content = getattr(message, "content", None)
    if isinstance(content, str):
        return content
    raise ValueError("Input message is not in the correct format")


def _search_urls(query: str) -> list:
    """Run a Tavily search for *query* and return the result URLs.

    Returns an empty list when the search does not yield a list of results, so
    a search failure never discards an otherwise successful LLM answer.
    """
    search_results = tool.invoke({"query": query})
    # On API errors the tool may return an error string instead of a list of
    # result dicts; iterating that string would fail on ``.get``.
    if not isinstance(search_results, list):
        logger.warning("Tavily search returned no usable results: %s", search_results)
        return []
    return [
        result.get("url", "No URL found")
        for result in search_results
        if isinstance(result, dict)
    ]


def chatbot(state: State):
    """Answer the latest message with the LLM and attach Tavily search URLs.

    Args:
        state: Current graph state; only the last message is sent to the LLM.

    Returns:
        A partial state update ``{"messages": [reply]}``. ``reply`` is an
        ``AIMessage`` whose ``additional_kwargs["urls"]`` holds the search URLs.
        On any failure the reply is an ``AIMessage`` with content
        ``"Error: <details>"``. It is an explicit AIMessage because a bare
        string would be coerced into a HumanMessage by ``add_messages``.
    """
    try:
        messages = state["messages"]
        query = _extract_query(messages[-1] if messages else "")
        logger.info("Input Message: %s", query)

        # Invoke the LLM for a response
        response = llm.invoke([query])
        logger.info("LLM Response: %s", response)

        urls = _search_urls(query)

        # Return only the new message; the add_messages reducer appends it.
        return {
            "messages": [AIMessage(content=response, additional_kwargs={"urls": urls})]
        }
    except Exception as e:
        logger.exception("Error: %s", e)
        return {"messages": [AIMessage(content=f"Error: {e}")]}


class BasicToolNode:
    """A node that runs the tools requested in the last AIMessage."""

    def __init__(self, tools: list) -> None:
        # Map each tool's name to the tool so tool calls can be dispatched by name.
        self.tools_by_name = {tool.name: tool for tool in tools}

    def __call__(self, inputs: dict):
        """Execute every tool call on the last message and return ToolMessages.

        Raises:
            ValueError: if ``inputs`` has no messages.
        """
        if messages := inputs.get("messages", []):
            message = messages[-1]
        else:
            raise ValueError("No message found in input")
        outputs = []
        for tool_call in message.tool_calls:
            tool_result = self.tools_by_name[tool_call["name"]].invoke(
                tool_call["args"]
            )
            outputs.append(
                ToolMessage(
                    content=json.dumps(tool_result),
                    name=tool_call["name"],
                    tool_call_id=tool_call["id"],
                )
            )
        return {"messages": outputs}


# Add tool node to the graph
tool_node = BasicToolNode(tools=tools)
graph_builder.add_node("tools", tool_node)


def route_tools(state: State):
    """
    Route to the ToolNode if the last message has tool calls. Otherwise, route to the end.
    """
    if isinstance(state, list):
        ai_message = state[-1]
    elif messages := state.get("messages", []):
        ai_message = messages[-1]
    else:
        raise ValueError(f"No messages found in input state to tool_edge: {state}")
    if hasattr(ai_message, "tool_calls") and len(ai_message.tool_calls) > 0:
        return "tools"
    return END


# Add nodes and conditional edges to the state graph
graph_builder.add_node("chatbot", chatbot)
graph_builder.add_conditional_edges(
    "chatbot", route_tools, {"tools": "tools", END: END}
)
graph_builder.add_edge("tools", "chatbot")
graph_builder.add_edge(START, "chatbot")
graph = graph_builder.compile()


def _format_reply(message) -> str:
    """Render a graph message as plain text (content plus any source URLs)."""
    if isinstance(message, str):
        return message
    content = getattr(message, "content", "")
    text = content if isinstance(content, str) else str(content)
    urls = (getattr(message, "additional_kwargs", None) or {}).get("urls") or []
    if urls:
        text += "\n\nSources:\n" + "\n".join(f"- {url}" for url in urls)
    return text


def chat_interface(input_text, state):
    """Gradio callback: run one user turn through the graph.

    Args:
        input_text: The user's message from the Textbox.
        state: The per-session ``{"messages": [...]}`` dict (may be ``None``).

    Returns:
        ``(reply_text, new_state)``. The Textbox needs a string; returning the
        message object itself would display its repr.
    """
    # Prepare state if not provided
    if state is None:
        state = {"messages": []}
    # Skip empty submissions rather than sending an empty prompt to the LLM.
    if not input_text or not input_text.strip():
        return "", state
    # Build a new message list instead of mutating the incoming state, so a
    # failed invoke does not leave a dangling user message in the session.
    messages = list(state.get("messages", [])) + [input_text]
    updated_state = graph.invoke({"messages": messages})
    return _format_reply(updated_state["messages"][-1]), updated_state


# Create Gradio app
with gr.Blocks() as demo:
    gr.Markdown("### Chatbot with Tavily Search Integration")
    chat_state = gr.State({"messages": []})
    with gr.Row():
        with gr.Column():
            user_input = gr.Textbox(
                label="Your Message",
                placeholder="Type your message here...",
                lines=2,
            )
            submit_button = gr.Button("Submit")
        with gr.Column():
            chatbot_output = gr.Textbox(
                label="Chatbot Response", interactive=False, lines=4
            )
    submit_button.click(
        chat_interface,
        inputs=[user_input, chat_state],
        outputs=[chatbot_output, chat_state],
    )

# Launch the Gradio app
if __name__ == "__main__":
    demo.launch()