"""Build and run the LangGraph question-answering workflow.

Importing this module builds the graph, compiles it into ``app`` (backed by an
in-memory ``MemorySaver`` checkpointer) and tries to render a diagram of it.
Running it as a script starts a simple interactive chat loop.
"""

# Wildcard import first so the explicit imports below win any name clash
# (same precedence as the original ordering). It supplies GraphState and the
# node/router callables referenced when building the graph.
from src.graph import *  # noqa: F401,F403

import io
import os
import sys
import traceback

from langgraph.checkpoint.memory import MemorySaver
from langgraph.errors import GraphRecursionError
from langgraph.graph import END, START, StateGraph

# Nodes whose state update is kept as the candidate answer for a turn; the
# last one streamed wins.
_FINAL_OUTPUT_NODES = frozenset(
    {"grade_generation", "off_topic", "greeting", "web_search"}
)

_GRAPH_DIR = "graphs"
_GRAPH_PATH = os.path.join(_GRAPH_DIR, "workflow_graph.jpg")

memory = MemorySaver()


def _render_workflow_graph(compiled_app):
    """Display or save a diagram of ``compiled_app``'s graph (best-effort).

    Inside an IPython session the PNG from ``draw_mermaid_png()`` is shown
    inline. Otherwise it is converted to RGB and written as a JPEG to
    ``graphs/workflow_graph.jpg``. Any failure is printed and swallowed so
    that rendering never blocks startup.
    """
    try:
        from IPython import get_ipython
    except ImportError:
        # IPython is only needed for inline display; without it we can still
        # save the diagram to disk.
        get_ipython = None

    try:
        if get_ipython is not None and get_ipython() is not None:
            from IPython.display import Image, display

            print("Attempting to display graph visualization...")
            graph_image = compiled_app.get_graph().draw_mermaid_png()
            display(Image(graph_image))
            print("Graph visualization displayed successfully.")
        else:
            print("Not running in IPython environment. Saving graph as JPG...")
            from PIL import Image as PILImage

            graph_image = compiled_app.get_graph().draw_mermaid_png()
            # JPEG cannot store an alpha channel, so flatten to RGB first.
            img = PILImage.open(io.BytesIO(graph_image)).convert('RGB')
            os.makedirs(_GRAPH_DIR, exist_ok=True)
            img.save(_GRAPH_PATH, 'JPEG')
            print("Graph saved as 'graphs/workflow_graph.jpg'")
    except ImportError as e:
        print(f"Required libraries not available. Graph visualization skipped. Error: {e}")
    except Exception as e:
        print(f"Error handling graph visualization: {e}")
        print("Graph visualization skipped.")


try:
    print("Initializing StateGraph...")
    workflow = StateGraph(GraphState)

    print("Adding nodes to the graph...")
    workflow.add_node("understand_intent", understand_intent)
    # intent_aware_response is used as the router on the conditional edge out
    # of "understand_intent" below, not as a node.
    workflow.add_node("greeting", greeting)
    workflow.add_node("off_topic", off_topic)
    workflow.add_node("route_question", route_question)
    workflow.add_node("retrieve", retrieve)
    workflow.add_node("web_search", web_search)
    workflow.add_node("grade_documents", grade_documents)
    workflow.add_node("generate", generate)
    workflow.add_node("transform_query", transform_query)
    workflow.add_node("grade_generation", grade_generation_v_documents_and_question)
    print("Nodes added successfully.")

    print("Building graph edges...")
    workflow.add_edge(START, "understand_intent")
    workflow.add_conditional_edges(
        "understand_intent",
        intent_aware_response,
        {
            "off_topic": "off_topic",
            "greeting": "greeting",
            "route_question": "route_question",
        },
    )
    workflow.add_edge("greeting", END)
    workflow.add_edge("off_topic", END)

    # NOTE(review): the two lambda routers below read state keys named exactly
    # like their source nodes ("route_question", "grade_generation"). Recent
    # LangGraph versions reject a node name that is also a state key; confirm
    # against GraphState / the installed langgraph version.
    workflow.add_conditional_edges(
        "route_question",
        lambda x: x["route_question"],
        {
            "web_search": "web_search",
            "vectorstore": "retrieve",
        },
    )
    workflow.add_edge("retrieve", "grade_documents")
    workflow.add_edge("generate", "grade_generation")
    workflow.add_conditional_edges(
        "grade_generation",
        lambda x: x["grade_generation"],
        {
            "not supported": "generate",
            "useful": END,
            "not useful": "transform_query",
        },
    )
    workflow.add_edge("transform_query", "route_question")
    workflow.add_conditional_edges(
        "grade_documents",
        decide_to_generate,
        {
            "transform_query": "transform_query",
            "generate": "generate",
        },
    )
    workflow.add_edge("web_search", "generate")
    print("Graph edges built successfully.")

    print("Compiling the workflow...")
    app = workflow.compile(checkpointer=memory)
    print("Workflow compiled successfully.")

    _render_workflow_graph(app)
except GraphRecursionError:
    print("Graph recursion limit reached during compilation.")
    # Without a compiled graph `app` is undefined; exit rather than fail later
    # with a NameError inside run_workflow.
    sys.exit(1)
except Exception as e:
    print(f"Error building the graph: {e}")
    sys.exit(1)


def _to_response(final_output):
    """Normalize the last captured node output into ``{"generation": str}``."""
    if final_output is None:
        print("No final output generated")
        return {"generation": "I'm sorry, I couldn't generate a response. Could you please rephrase your question?"}
    if isinstance(final_output, dict) and "generation" in final_output:
        print("Final output is a dict with 'generation' key")
        return {"generation": str(final_output["generation"])}
    if isinstance(final_output, str):
        print("Final output is a string")
        return {"generation": final_output}
    print(f"Unexpected final output type: {type(final_output)}")
    return {"generation": str(final_output)}


def run_workflow(user_input, config):
    """Stream one user turn through the compiled workflow and return its answer.

    Args:
        user_input: The question text, or a message dict whose ``"content"``
            holds it (the whole dict is stringified if ``"content"`` is absent).
        config: Run config passed straight to ``app.stream`` (the CLI below
            uses ``{"configurable": {"thread_id": "1"}}``).

    Returns:
        A dict ``{"generation": str}``. If the graph hits its recursion limit,
        a direct ``web_search`` -> ``generate`` fallback is used instead. Any
        other exception is printed with a traceback and replaced by a fixed
        apology message rather than propagated.
    """
    try:
        print(f"Running workflow for question: {user_input}")

        # Callers may pass a chat-message dict; the graph is fed plain text.
        if isinstance(user_input, dict):
            print("user_input is a dict, extracting content")
            user_input = user_input.get('content', str(user_input))
        print(f"Processed user_input: {user_input}")

        input_state = {
            "messages": [{"role": "user", "content": user_input}]
        }
        print(f"Initial input state: {input_state}")

        use_web_search = False
        final_output = None
        try:
            print("Starting graph execution")
            for output in app.stream(input_state, config):
                for key, value in output.items():
                    print(f"Node '{key}'")
                    if key in _FINAL_OUTPUT_NODES:
                        print(f"Setting final_output from node '{key}'")
                        final_output = value
            print("Graph execution completed")
        except GraphRecursionError:
            print("Graph recursion limit reached, switching to web search")
            use_web_search = True

        if use_web_search:
            # NOTE(review): input_state only carries "messages"; confirm that
            # web_search can derive the question from it.
            print("Executing web search fallback")
            web_search_result = web_search(input_state)
            print(f"Web search result: {web_search_result}")
            generate_result = generate(web_search_result)
            print(f"Generate result: {generate_result}")
            final_output = generate_result

        print(f"Final output before processing: {final_output}")
        return _to_response(final_output)
    except Exception as e:
        print(f"Error running the workflow: {e}")
        print("Full traceback:")
        traceback.print_exc()
        return {"generation": "I encountered an error while processing your question. Please try again."}


if __name__ == "__main__":
    config = {"configurable": {"thread_id": "1"}}
    while True:
        try:
            question = input("Enter your question (or 'quit' to exit): ")
        except (EOFError, KeyboardInterrupt):
            # Ctrl-D / Ctrl-C at the prompt: leave cleanly instead of crashing
            # with a traceback.
            print()
            break
        if question.strip().lower() == 'quit':
            break
        if not question.strip():
            # Don't send blank input through the graph.
            continue
        result = run_workflow(question, config)
        print("Chatbot:", result["generation"])