# Scraped from a Hugging Face Spaces page (page header: "Spaces: Sleeping").
import os | |
import gradio as gr | |
import json | |
from typing import Annotated | |
from typing_extensions import TypedDict | |
from langchain_huggingface import HuggingFaceEndpoint | |
from langchain_community.tools.tavily_search import TavilySearchResults | |
from langgraph.graph import StateGraph, START, END | |
from langgraph.graph.message import add_messages | |
from langchain_core.messages import ToolMessage | |
from dotenv import load_dotenv | |
import logging | |
# Initialize logging
logging.basicConfig(level=logging.INFO)

# Load environment variables (e.g. from a local .env file) before reading them.
load_dotenv()

# Hugging Face API token: a stripped string, or None when unset/blank.
# Calling .strip() directly on os.getenv() would raise AttributeError at
# import time whenever the variable is missing.
HF_TOKEN = (os.getenv("HF_TOKEN") or "").strip() or None
if HF_TOKEN is None:
    logging.warning(
        "HF_TOKEN is not set; requests to the Hugging Face endpoint may fail to authenticate."
    )

# Initialize the HuggingFace model (text-completion endpoint).
llm = HuggingFaceEndpoint(
    repo_id="mistralai/Mistral-7B-Instruct-v0.3",
    huggingfacehub_api_token=HF_TOKEN,
    temperature=0.7,
    max_new_tokens=200,
)

# Initialize Tavily Search tool; each query returns at most two results.
# NOTE(review): presumably needs TAVILY_API_KEY in the environment — confirm.
tool = TavilySearchResults(max_results=2)
tools = [tool]
# Define the state structure
class State(TypedDict):
    """Graph state shared by every node of the chat graph.

    The ``add_messages`` annotation registers that function as LangGraph's
    reducer for the ``messages`` channel. Node updates to ``messages`` are
    combined with the existing list by that reducer instead of being
    assigned directly (see LangGraph reducer docs).
    """

    # Conversation history (user inputs, assistant replies, tool messages).
    messages: Annotated[list, add_messages]
# Create a state graph builder; nodes and edges are registered on it below.
graph_builder = StateGraph(State)
def _message_text(message) -> str:
    """Return the text of a graph message: a plain str, or an object whose ``content`` is a str.

    Raises:
        ValueError: if the message has no string content (e.g. multimodal content).
    """
    if isinstance(message, str):
        return message
    content = getattr(message, "content", None)
    if isinstance(content, str):
        return content
    raise ValueError("Input message is not in the correct format")


def _result_urls(search_results) -> list:
    """Collect the ``url`` of each Tavily search result.

    Tolerates a non-list payload by returning ``[]``. NOTE(review): the
    Tavily tool may return an error string instead of a result list — confirm.
    Non-dict entries are skipped; dicts without a ``url`` contribute
    "No URL found".
    """
    if not isinstance(search_results, list):
        logging.warning("Unexpected Tavily search output: %r", search_results)
        return []
    return [
        result.get("url", "No URL found")
        for result in search_results
        if isinstance(result, dict)
    ]


# Define the chatbot function
def chatbot(state: State):
    """Answer the latest message with the LLM and attach Tavily source URLs.

    Args:
        state: Graph state; ``state["messages"]`` is the conversation so far.

    Returns:
        A partial state update ``{"messages": [reply]}``. Only the new reply is
        returned; the ``add_messages`` reducer on ``State.messages`` merges it
        into the existing history. ``reply`` is an assistant-role dict with
        ``content`` (the LLM text) and ``urls`` (search-result URLs).
        On failure it is an assistant-role ``"Error: ..."`` message.
    """
    messages = state.get("messages") or []
    try:
        query = _message_text(messages[-1]) if messages else ""
        if not query.strip():
            # Nothing to answer or search for; avoid calling either service.
            return {"messages": [{"role": "assistant", "content": "Please enter a message."}]}
        logging.info("Input Message: %s", query)
        # HuggingFaceEndpoint is a text-completion LLM, so pass the prompt
        # string itself. A list would be treated as a chat-message sequence
        # and rendered with a "Human: " prefix.
        response = llm.invoke(query)
        logging.info("LLM Response: %s", response)
    except Exception as e:
        logging.exception("Error: %s", e)
        # Use an assistant-role message: a bare string would be converted
        # into a *human* message by the add_messages reducer.
        return {"messages": [{"role": "assistant", "content": f"Error: {e}"}]}

    # A failed web search should not discard an otherwise good LLM answer.
    try:
        urls = _result_urls(tool.invoke({"query": query}))
    except Exception:
        logging.exception("Tavily search failed for query: %s", query)
        urls = []

    return {
        "messages": [
            {
                "role": "assistant",  # Set the role to 'assistant'
                "content": response,  # The LLM's completion text
                "urls": urls,  # URLs of the search results
            }
        ]
    }
# Graph node that executes tool calls requested by the latest message.
class BasicToolNode:
    """Run each tool call found on the last message in the state.

    Each tool is looked up by name and invoked with the call's ``args``.
    Its result is JSON-encoded into a ``ToolMessage`` that carries the
    call's name and id.
    """

    def __init__(self, tools: list) -> None:
        # Dispatch table: tool name -> tool object.
        self.tools_by_name = dict((t.name, t) for t in tools)

    def __call__(self, inputs: dict):
        """Execute the pending tool calls and return ``{"messages": [ToolMessage, ...]}``.

        Raises:
            ValueError: if ``inputs`` holds no messages.
        """
        history = inputs.get("messages", [])
        if not history:
            raise ValueError("No message found in input")
        latest = history[-1]

        def run(call):
            output = self.tools_by_name[call["name"]].invoke(call["args"])
            return ToolMessage(
                content=json.dumps(output),
                name=call["name"],
                tool_call_id=call["id"],
            )

        return {"messages": [run(call) for call in latest.tool_calls]}
# Add tool node to the graph
# Registered under the name "tools": route_tools returns this name, and the
# conditional-edge map sends it to this node.
tool_node = BasicToolNode(tools=tools)
graph_builder.add_node("tools", tool_node)
# Define the conditional routing function
def route_tools(state: State):
    """Decide where the graph goes after the ``chatbot`` node.

    ``state`` may be a bare list of messages or a mapping with a
    ``"messages"`` list; in both cases the last message is inspected.

    Returns:
        ``"tools"`` when the last message has a non-empty ``tool_calls``
        attribute, otherwise ``END``.

    Raises:
        ValueError: if a mapping state contains no messages.
    """
    if not isinstance(state, list):
        history = state.get("messages", [])
        if not history:
            raise ValueError(f"No messages found in input state to tool_edge: {state}")
        last_message = history[-1]
    else:
        last_message = state[-1]

    wants_tools = hasattr(last_message, "tool_calls") and len(last_message.tool_calls) > 0
    return "tools" if wants_tools else END
# Add nodes and conditional edges to the state graph
graph_builder.add_node("chatbot", chatbot)
# After "chatbot", route_tools picks the next node: "tools" or END.
# NOTE(review): chatbot's reply dicts carry no tool calls, so this route
# appears to always resolve to END — confirm whether the tools branch is
# meant to be reachable.
graph_builder.add_conditional_edges(
    "chatbot",
    route_tools,
    {"tools": "tools", END: END}
)
# Tool output is fed back to the chatbot node.
graph_builder.add_edge("tools", "chatbot")
# Every invocation enters the graph at the chatbot node.
graph_builder.add_edge(START, "chatbot")
# Compiled, runnable graph used by chat_interface via graph.invoke().
graph = graph_builder.compile()
def _format_reply(message) -> str:
    """Render the graph's last message as plain text for the response Textbox.

    Handles a plain string, a ``{"content", "urls"}`` dict, or a message
    object with ``content``. Any source URLs are appended one per line
    under a "Sources:" heading.
    """
    if isinstance(message, str):
        return message
    if isinstance(message, dict):
        content = message.get("content", "")
        urls = message.get("urls") or []
    else:
        content = getattr(message, "content", "")
        # NOTE(review): when add_messages converts the chatbot's reply dict,
        # the extra "urls" key appears to be kept in additional_kwargs —
        # confirm for the pinned langchain_core version.
        extra = getattr(message, "additional_kwargs", None) or {}
        urls = extra.get("urls") or []
    text = content if isinstance(content, str) else str(content)
    if urls:
        text += "\n\nSources:\n" + "\n".join(str(url) for url in urls)
    return text


# Gradio interface
def chat_interface(input_text, state):
    """Run one user turn through the compiled graph for the Gradio UI.

    Args:
        input_text: Text from the "Your Message" box.
        state: Session state ``{"messages": [...]}``, or None on first use.

    Returns:
        ``(reply_text, new_state)``. ``reply_text`` is readable text rather
        than the raw message object, whose repr a Textbox would otherwise
        display. ``new_state`` is the graph's output state.
    """
    if state is None:
        state = {"messages": []}
    if not input_text or not input_text.strip():
        # Nothing to send; leave the conversation unchanged.
        return "Please enter a message.", state
    # Build a new dict instead of appending in place, so a failing invoke
    # does not leave a half-recorded turn in the session state.
    turn_state = {**state, "messages": [*(state.get("messages") or []), input_text]}
    updated_state = graph.invoke(turn_state)
    return _format_reply(updated_state["messages"][-1]), updated_state
# Create Gradio app
with gr.Blocks() as demo:
    gr.Markdown("### Chatbot with Tavily Search Integration")
    # Conversation state: passed into chat_interface and replaced by its
    # second return value after each turn.
    chat_state = gr.State({"messages": []})
    with gr.Row():
        # Left column: user input and the submit button.
        with gr.Column():
            user_input = gr.Textbox(label="Your Message", placeholder="Type your message here...", lines=2)
            submit_button = gr.Button("Submit")
        # Right column: read-only reply display.
        with gr.Column():
            chatbot_output = gr.Textbox(label="Chatbot Response", interactive=False, lines=4)
    # Clicking Submit runs one turn: (message, state) -> (reply, new state).
    submit_button.click(chat_interface, inputs=[user_input, chat_state], outputs=[chatbot_output, chat_state])
# Launch the Gradio app
if __name__ == "__main__":
    demo.launch()