# LegalAlly — src/buildgraph.py
# Author: Rohil Bansal
# Commit df844ea: changed recursion limit and search kw
from src.graph import *
from langgraph.graph import END, StateGraph, START
import sys
from langgraph.checkpoint.memory import MemorySaver
from langgraph.errors import GraphRecursionError
# In-memory checkpointer: persists per-conversation graph state keyed by the
# thread_id supplied in the `config` dict (see run_workflow / __main__).
memory = MemorySaver()
try:
    print("Initializing StateGraph...")
    # GraphState and every node callable below come from the star import of
    # src.graph at the top of this file.
    workflow = StateGraph(GraphState)
    print("Adding nodes to the graph...")
    workflow.add_node("understand_intent", understand_intent)
    # workflow.add_node("intent_aware_response", intent_aware_response)
    workflow.add_node("greeting", greeting)
    workflow.add_node("off_topic", off_topic)
    workflow.add_node("route_question", route_question)
    workflow.add_node("retrieve", retrieve)
    workflow.add_node("web_search", web_search)
    workflow.add_node("grade_documents", grade_documents)
    workflow.add_node("generate", generate)
    workflow.add_node("transform_query", transform_query)
    workflow.add_node("grade_generation", grade_generation_v_documents_and_question)
    print("Nodes added successfully.")
    print("Building graph edges...")
    # Entry point: classify the user's intent before anything else runs.
    workflow.add_edge(START, "understand_intent")
    # intent_aware_response returns one of the three keys below, sending
    # small talk to terminal nodes and real questions into the RAG pipeline.
    workflow.add_conditional_edges(
        "understand_intent",
        intent_aware_response,
        {
            "off_topic": "off_topic",
            "greeting": "greeting",
            "route_question": "route_question",
        }
    )
    # Small-talk branches end the run immediately.
    workflow.add_edge("greeting", END)
    workflow.add_edge("off_topic", END)
    # The router's verdict is read back out of the state under the
    # "route_question" key: either live web search or vectorstore retrieval.
    workflow.add_conditional_edges(
        "route_question",
        lambda x: x["route_question"],
        {
            "web_search": "web_search",
            "vectorstore": "retrieve",
        }
    )
    workflow.add_edge(
        "retrieve", "grade_documents",
    )
    workflow.add_edge(
        "generate", "grade_generation",
    )
    # Self-correction loop: "not supported" regenerates, "useful" finishes,
    # "not useful" rewrites the query and re-routes it.
    workflow.add_conditional_edges(
        "grade_generation",
        lambda x: x["grade_generation"],
        {
            "not supported": "generate",
            "useful": END,
            "not useful": "transform_query",
        }
    )
    # A rewritten query re-enters the pipeline at the router.
    workflow.add_edge("transform_query", "route_question")
    workflow.add_conditional_edges(
        "grade_documents",
        decide_to_generate,
        {
            "transform_query": "transform_query",
            "generate": "generate",
        },
    )
    workflow.add_edge("web_search", "generate")
    print("Graph edges built successfully.")
    print("Compiling the workflow...")
    # `app` is the module-level compiled graph consumed by run_workflow.
    app = workflow.compile(checkpointer=memory)
    print("Workflow compiled successfully.")
    # Best-effort visualization; any failure here is logged and never aborts startup.
    try:
        from IPython import get_ipython
        from IPython.display import Image, display
        # Check if we're in an IPython environment
        if get_ipython() is not None:
            print("Attempting to display graph visualization...")
            graph_image = app.get_graph().draw_mermaid_png()
            display(Image(graph_image))
            print("Graph visualization displayed successfully.")
        else:
            print("Not running in IPython environment. Saving graph as JPG...")
            import os
            # NOTE(review): PIL's Image shadows IPython.display.Image here;
            # harmless because the two names are used in mutually exclusive branches.
            from PIL import Image
            import io
            graph_image = app.get_graph().draw_mermaid_png()
            img = Image.open(io.BytesIO(graph_image))
            # JPEG has no alpha channel, so drop any RGBA data from the PNG.
            img = img.convert('RGB')
            # Create a 'graphs' directory if it doesn't exist
            if not os.path.exists('graphs'):
                os.makedirs('graphs')
            img.save('graphs/workflow_graph.jpg', 'JPEG')
            print("Graph saved as 'graphs/workflow_graph.jpg'")
    except ImportError as e:
        print(f"Required libraries not available. Graph visualization skipped. Error: {e}")
    except Exception as e:
        print(f"Error handling graph visualization: {e}")
        print("Graph visualization skipped.")
except GraphRecursionError:
    # NOTE(review): compile() is not expected to raise GraphRecursionError —
    # recursion limits surface at stream/invoke time — so this handler looks
    # dead; confirm against the langgraph version in use.
    print("Graph recursion limit reached during compilation.")
    # Handle the error as needed
except Exception as e:
    print(f"Error building the graph: {e}")
    sys.exit(1)
def run_workflow(user_input, config):
    """Run the compiled graph for one user turn.

    Streams the module-level `app` with the user's message, keeps the output
    of the last terminal-ish node seen, and falls back to a direct
    web-search + generate pass if the graph hits its recursion limit.
    Always returns a dict of the form {"generation": str}.
    """
    try:
        print(f"Running workflow for question: {user_input}")
        # Ensure user_input is a string, not a dict
        if isinstance(user_input, dict):
            print("user_input is a dict, extracting content")
            user_input = user_input.get('content', str(user_input))
        print(f"Processed user_input: {user_input}")

        # Initialize input_state with required fields
        input_state = {"messages": [{"role": "user", "content": user_input}]}
        print(f"Initial input state: {input_state}")

        final_output = None
        recursion_limit_hit = False
        # Nodes whose state snapshot we treat as the answer candidate.
        answer_nodes = ("grade_generation", "off_topic", "greeting", "web_search")
        try:
            print("Starting graph execution")
            for step in app.stream(input_state, config):
                # print(f"Graph output: {step}")
                for node_name, node_state in step.items():
                    print(f"Node '{node_name}'")
                    if node_name in answer_nodes:
                        print(f"Setting final_output from node '{node_name}'")
                        final_output = node_state
            print("Graph execution completed")
        except GraphRecursionError:
            print("Graph recursion limit reached, switching to web search")
            recursion_limit_hit = True

        if recursion_limit_hit:
            # Bypass the graph entirely: one web search, one generation.
            print("Executing web search fallback")
            web_search_result = web_search(input_state)
            print(f"Web search result: {web_search_result}")
            generate_result = generate(web_search_result)
            print(f"Generate result: {generate_result}")
            final_output = generate_result

        print(f"Final output before processing: {final_output}")
        # Normalize whatever we captured into {"generation": str}.
        if final_output is None:
            print("No final output generated")
            return {"generation": "I'm sorry, I couldn't generate a response. Could you please rephrase your question?"}
        if isinstance(final_output, dict) and "generation" in final_output:
            print("Final output is a dict with 'generation' key")
            return {"generation": str(final_output["generation"])}
        if isinstance(final_output, str):
            print("Final output is a string")
            return {"generation": final_output}
        print(f"Unexpected final output type: {type(final_output)}")
        return {"generation": str(final_output)}
    except Exception as e:
        # Top-level boundary: log the full traceback and hand back a
        # user-facing fallback instead of crashing the REPL.
        print(f"Error running the workflow: {e}")
        print("Full traceback:")
        import traceback
        traceback.print_exc()
        return {"generation": "I encountered an error while processing your question. Please try again."}
if __name__ == "__main__":
    # Single conversation thread; MemorySaver keys checkpoints by thread_id,
    # so reusing "1" keeps the whole session in one conversation.
    config = {"configurable": {"thread_id": "1"}}
    while True:
        try:
            question = input("Enter your question (or 'quit' to exit): ")
        except (EOFError, KeyboardInterrupt):
            # Ctrl-D / Ctrl-C at the prompt: exit cleanly instead of
            # dumping a traceback.
            print()
            break
        question = question.strip()
        if not question:
            # Don't run the whole graph on empty input.
            continue
        if question.lower() == 'quit':
            break
        result = run_workflow(question, config)
        print("Chatbot:", result["generation"])