# LegalAlly / src/buildgraph.py
# Author: Rohil Bansal (commit d8143c9: "working...")
from src.graph import *
from langgraph.graph import END, StateGraph, START
import sys
from langgraph.checkpoint.memory import MemorySaver
from langgraph.errors import GraphRecursionError
# In-process checkpointer: keeps per-thread graph state between run_workflow() calls.
memory = MemorySaver()


def _show_or_save_graph(compiled_app):
    """Best-effort rendering of the compiled workflow graph.

    Under an active IPython session the Mermaid PNG is displayed inline.
    Otherwise the image is saved to ``graphs/workflow_graph.jpg`` via Pillow.
    "Otherwise" includes the case where IPython is not installed at all.
    Every failure is printed and swallowed so that visualization can never
    prevent the workflow from being used.
    """
    # Probe for IPython separately: its absence must route to the file
    # fallback, not abort visualization altogether.
    try:
        from IPython import get_ipython
    except ImportError:
        get_ipython = None

    try:
        if get_ipython is not None and get_ipython() is not None:
            from IPython.display import Image, display
            print("Attempting to display graph visualization...")
            graph_image = compiled_app.get_graph().draw_mermaid_png()
            display(Image(graph_image))
            print("Graph visualization displayed successfully.")
        else:
            print("Not running in IPython environment. Saving graph as JPG...")
            import io
            import os
            from PIL import Image
            graph_image = compiled_app.get_graph().draw_mermaid_png()
            # JPEG has no alpha channel, so flatten the PNG to RGB first.
            img = Image.open(io.BytesIO(graph_image)).convert('RGB')
            os.makedirs('graphs', exist_ok=True)
            img.save(os.path.join('graphs', 'workflow_graph.jpg'), 'JPEG')
            print("Graph saved as 'graphs/workflow_graph.jpg'")
    except ImportError as e:
        print(f"Required libraries not available. Graph visualization skipped. Error: {e}")
    except Exception as e:
        print(f"Error handling graph visualization: {e}")
        print("Graph visualization skipped.")


try:
    print("Initializing StateGraph...")
    workflow = StateGraph(GraphState)

    print("Adding nodes to the graph...")
    workflow.add_node("understand_intent", understand_intent)
    workflow.add_node("greeting", greeting)
    workflow.add_node("off_topic", off_topic)
    workflow.add_node("route_question", route_question)
    workflow.add_node("retrieve", retrieve)
    workflow.add_node("web_search", web_search)
    workflow.add_node("grade_documents", grade_documents)
    workflow.add_node("generate", generate)
    workflow.add_node("transform_query", transform_query)
    workflow.add_node("grade_generation", grade_generation_v_documents_and_question)
    print("Nodes added successfully.")

    print("Building graph edges...")
    workflow.add_edge(START, "understand_intent")
    # intent_aware_response (from src.graph) picks the branch after intent detection.
    workflow.add_conditional_edges(
        "understand_intent",
        intent_aware_response,
        {
            "off_topic": "off_topic",
            "greeting": "greeting",
            "route_question": "route_question",
        },
    )
    workflow.add_edge("greeting", END)
    workflow.add_edge("off_topic", END)
    # The route_question node writes its decision into the "route_question" state key.
    workflow.add_conditional_edges(
        "route_question",
        lambda x: x["route_question"],
        {
            "web_search": "web_search",
            "vectorstore": "retrieve",
        },
    )
    # check_recursion_limit (from src.graph) can divert to web search instead of continuing.
    workflow.add_conditional_edges(
        "retrieve",
        check_recursion_limit,
        {
            "web_search": "web_search",
            "continue": "grade_documents",
        },
    )
    workflow.add_conditional_edges(
        "generate",
        check_recursion_limit,
        {
            "web_search": "web_search",
            "continue": "grade_generation",
        },
    )
    # Grading outcome read from the "grade_generation" state key decides whether to
    # finish, regenerate, or rewrite the query and route again.
    workflow.add_conditional_edges(
        "grade_generation",
        lambda x: x["grade_generation"],
        {
            "not supported": "generate",
            "useful": END,
            "not useful": "transform_query",
        },
    )
    workflow.add_edge("transform_query", "route_question")
    workflow.add_conditional_edges(
        "grade_documents",
        decide_to_generate,
        {
            "transform_query": "transform_query",
            "generate": "generate",
        },
    )
    workflow.add_edge("web_search", "generate")
    print("Graph edges built successfully.")

    print("Compiling the workflow...")
    app = workflow.compile(checkpointer=memory)
    print("Workflow compiled successfully.")
except Exception as e:
    # compile() does not execute the graph, so GraphRecursionError cannot occur here.
    # The former dedicated handler would also have let the script continue with `app`
    # undefined. Any build failure is fatal.
    print(f"Error building the graph: {e}")
    sys.exit(1)

_show_or_save_graph(app)
def run_workflow(question, config):
    """Run the compiled workflow for one user question and return the answer.

    Args:
        question: The user's question text.
        config: LangGraph run config. Its ``configurable.thread_id`` selects the
            conversation thread checkpointed in ``memory``.

    Returns:
        A dict with a single ``"generation"`` key holding the answer as a string.
        Errors are not raised. They are printed, and a fallback apology message
        is returned instead.
    """
    try:
        print(f"Running workflow for question: {question}")

        # memory.get() returns the thread's latest Checkpoint (or None). The graph
        # state lives under its "channel_values" key, not at the top level, so
        # reading "chat_history" directly from the checkpoint always missed it.
        checkpoint = memory.get(config)
        channel_values = (checkpoint or {}).get("channel_values") or {}
        input_state = {
            "question": question,
            "chat_history": channel_values.get("chat_history") or [],
        }

        final_output = None
        use_web_search = False
        try:
            for output in app.stream(input_state, config):
                for key, value in output.items():
                    print(f"Node '{key}'")
                    # Remember the latest output of the nodes that end (or feed) a run.
                    if key in ("grade_generation", "off_topic", "greeting", "web_search"):
                        final_output = value
        except GraphRecursionError:
            print("Graph recursion limit reached, switching to web search")
            use_web_search = True

        if use_web_search:
            # Force the web-search path outside the graph. Merge each node's result
            # onto the accumulated state: this is needed if nodes return partial
            # updates (unknown from this file) and is harmless if they return the
            # full state.
            fallback_state = dict(input_state)
            fallback_state.update(web_search(fallback_state) or {})
            final_output = generate(fallback_state)

        if final_output is None:
            return {"generation": "I'm sorry, I couldn't generate a response. Could you please rephrase your question?"}
        if isinstance(final_output, dict) and "generation" in final_output:
            return {"generation": str(final_output["generation"])}
        # str() is the identity for strings, so one branch covers str and anything else.
        return {"generation": str(final_output)}
    except Exception as e:
        print(f"Error running the workflow: {e}")
        import traceback
        traceback.print_exc()
        return {"generation": "I encountered an error while processing your question. Please try again."}
if __name__ == "__main__":
    # Simple interactive REPL. All turns share one checkpointed thread.
    config = {"configurable": {"thread_id": "test_thread"}}
    while True:
        try:
            question = input("Enter your question (or 'quit' to exit): ")
        except (EOFError, KeyboardInterrupt):
            # Ctrl-D / Ctrl-C at the prompt ends the session without a traceback.
            print()
            break
        question = question.strip()
        if question.lower() == 'quit':
            break
        if not question:
            # Don't send blank input through the whole workflow.
            continue
        result = run_workflow(question, config)
        print("Chatbot:", result["generation"])