# (export artifact — original header read "Spaces: / Sleeping / Sleeping";
# preserved here as a comment so the module parses)
# --- Module-level graph assembly -------------------------------------------
# Builds and compiles the LangGraph workflow once, at import time. The node
# callables (understand_intent, greeting, retrieve, generate, ...) and
# GraphState come in via the wildcard import from src.graph — TODO confirm
# exactly which names that module exports.
from src.graph import *
from langgraph.graph import END, StateGraph, START
import sys
from langgraph.checkpoint.memory import MemorySaver
from langgraph.errors import GraphRecursionError

# In-memory checkpointer shared by the compiled app and by run_workflow(),
# which reads it directly to recover per-thread chat history.
memory = MemorySaver()

try:
    print("Initializing StateGraph...")
    workflow = StateGraph(GraphState)

    print("Adding nodes to the graph...")
    workflow.add_node("understand_intent", understand_intent)
    # workflow.add_node("intent_aware_response", intent_aware_response)
    workflow.add_node("greeting", greeting)
    workflow.add_node("off_topic", off_topic)
    workflow.add_node("route_question", route_question)
    workflow.add_node("retrieve", retrieve)
    workflow.add_node("web_search", web_search)
    workflow.add_node("grade_documents", grade_documents)
    workflow.add_node("generate", generate)
    workflow.add_node("transform_query", transform_query)
    workflow.add_node("grade_generation", grade_generation_v_documents_and_question)
    print("Nodes added successfully.")

    print("Building graph edges...")
    # Entry point: every question first passes through intent detection.
    workflow.add_edge(START, "understand_intent")
    # intent_aware_response classifies the turn as small talk, off-topic,
    # or a real question that needs routing.
    workflow.add_conditional_edges(
        "understand_intent",
        intent_aware_response,
        {
            "off_topic": "off_topic",
            "greeting": "greeting",
            "route_question": "route_question",
        }
    )
    # Small talk and off-topic turns end the run immediately.
    workflow.add_edge("greeting", END)
    workflow.add_edge("off_topic", END)
    # The route_question node writes its decision into state["route_question"];
    # the lambda reads it back to pick web search vs. the vector store.
    workflow.add_conditional_edges(
        "route_question",
        lambda x: x["route_question"],
        {
            "web_search": "web_search",
            "vectorstore": "retrieve",
        }
    )
    # After retrieval, check_recursion_limit can divert to web search
    # ("web_search") instead of continuing to document grading ("continue").
    workflow.add_conditional_edges(
        "retrieve",
        check_recursion_limit,
        {
            "web_search": "web_search",
            "continue": "grade_documents",
        }
    )
    # Same escape hatch after generation, before grading the answer.
    workflow.add_conditional_edges(
        "generate",
        check_recursion_limit,
        {
            "web_search": "web_search",
            "continue": "grade_generation",
        }
    )
    # The grade_generation node stores its verdict in state["grade_generation"]:
    #   "not supported" -> regenerate; "useful" -> finish;
    #   "not useful"    -> rewrite the query and re-route.
    workflow.add_conditional_edges(
        "grade_generation",
        lambda x: x["grade_generation"],
        {
            "not supported": "generate",
            "useful": END,
            "not useful": "transform_query",
        }
    )
    workflow.add_edge("transform_query", "route_question")
    # decide_to_generate: answer from the graded documents, or rewrite the
    # query when the documents were judged not relevant.
    workflow.add_conditional_edges(
        "grade_documents",
        decide_to_generate,
        {
            "transform_query": "transform_query",
            "generate": "generate",
        },
    )
    workflow.add_edge("web_search", "generate")
    print("Graph edges built successfully.")

    print("Compiling the workflow...")
    app = workflow.compile(checkpointer=memory)
    print("Workflow compiled successfully.")

    # Best-effort visualization: render inline under IPython, otherwise save
    # a JPG under graphs/. Any failure here is logged and never aborts startup.
    try:
        from IPython import get_ipython
        from IPython.display import Image, display
        # Check if we're in an IPython environment
        if get_ipython() is not None:
            print("Attempting to display graph visualization...")
            graph_image = app.get_graph().draw_mermaid_png()
            display(Image(graph_image))
            print("Graph visualization displayed successfully.")
        else:
            print("Not running in IPython environment. Saving graph as JPG...")
            import os
            from PIL import Image  # NOTE: shadows IPython's Image within this branch
            import io
            graph_image = app.get_graph().draw_mermaid_png()
            img = Image.open(io.BytesIO(graph_image))
            img = img.convert('RGB')  # JPEG has no alpha channel
            # Create a 'graphs' directory if it doesn't exist
            if not os.path.exists('graphs'):
                os.makedirs('graphs')
            img.save('graphs/workflow_graph.jpg', 'JPEG')
            print("Graph saved as 'graphs/workflow_graph.jpg'")
    except ImportError as e:
        print(f"Required libraries not available. Graph visualization skipped. Error: {e}")
    except Exception as e:
        print(f"Error handling graph visualization: {e}")
        print("Graph visualization skipped.")
except GraphRecursionError:
    # NOTE(review): compile() is not expected to raise GraphRecursionError —
    # this handler looks unreachable; confirm before removing.
    print("Graph recursion limit reached during compilation.")
    # Handle the error as needed
except Exception as e:
    print(f"Error building the graph: {e}")
    sys.exit(1)
def run_workflow(question, config):
    """Run the compiled workflow for one question and return a reply dict.

    Streams the graph, keeping the output of the last terminal-ish node seen.
    If the graph hits its recursion limit, falls back to a direct
    web_search -> generate pass. Always returns ``{"generation": <str>}``.

    NOTE(review): assumes ``memory.get(config)`` yields a mapping exposing
    "chat_history" — verify against the checkpointer's actual schema.
    """
    try:
        print(f"Running workflow for question: {question}")
        # Recover any prior checkpoint so chat history carries across turns.
        prior = memory.get(config)
        history = prior.get("chat_history", []) if prior else []
        state = {"question": question, "chat_history": history}

        # Nodes whose output we treat as the candidate final answer.
        terminal_nodes = ("grade_generation", "off_topic", "greeting", "web_search")
        result = None
        recursion_hit = False
        try:
            for step in app.stream(state, config):
                for node_name, node_value in step.items():
                    print(f"Node '{node_name}'")
                    if node_name in terminal_nodes:
                        result = node_value
        except GraphRecursionError:
            print("Graph recursion limit reached, switching to web search")
            recursion_hit = True

        if recursion_hit:
            # Bypass the graph entirely: search the web, then generate.
            result = generate(web_search(state))

        # Normalize whatever we collected into a {"generation": str} payload.
        if result is None:
            return {"generation": "I'm sorry, I couldn't generate a response. Could you please rephrase your question?"}
        if isinstance(result, dict) and "generation" in result:
            return {"generation": str(result["generation"])}
        if isinstance(result, str):
            return {"generation": result}
        return {"generation": str(result)}
    except Exception as e:
        # Top-level boundary: log with traceback and return a safe fallback.
        print(f"Error running the workflow: {e}")
        import traceback
        traceback.print_exc()
        return {"generation": "I encountered an error while processing your question. Please try again."}
if __name__ == "__main__":
    # Minimal REPL for manual testing; one fixed conversation thread.
    session_config = {"configurable": {"thread_id": "test_thread"}}
    while (user_text := input("Enter your question (or 'quit' to exit): ")).lower() != 'quit':
        reply = run_workflow(user_text, session_config)
        print("Chatbot:", reply["generation"])