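# Chat workflow entry point: assembles the StateGraph from the node functions
# imported from src.graph, compiles it with an in-memory checkpointer, and
# answers questions through run_workflow() or an interactive prompt.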
# Project nodes/state first: the explicit imports below must keep precedence
# over any same-named names this star import re-exports.
from src.graph import *

import io
import os
import sys
import traceback

from langgraph.checkpoint.memory import MemorySaver
from langgraph.errors import GraphRecursionError
from langgraph.graph import END, START, StateGraph
memory = MemorySaver()


def _build_workflow():
    """Declare the workflow's nodes and edges and return the uncompiled StateGraph.

    ``GraphState``, the node functions and the routers
    (``intent_aware_response``, ``decide_to_generate``) are expected to come
    from the ``src.graph`` star import.
    """
    print("Initializing StateGraph...")
    graph = StateGraph(GraphState)

    print("Adding nodes to the graph...")
    graph.add_node("understand_intent", understand_intent)
    # intent_aware_response is used below as the router out of
    # understand_intent, not registered as a node of its own.
    graph.add_node("greeting", greeting)
    graph.add_node("off_topic", off_topic)
    graph.add_node("route_question", route_question)
    graph.add_node("retrieve", retrieve)
    graph.add_node("web_search", web_search)
    graph.add_node("grade_documents", grade_documents)
    graph.add_node("generate", generate)
    graph.add_node("transform_query", transform_query)
    graph.add_node("grade_generation", grade_generation_v_documents_and_question)
    print("Nodes added successfully.")

    print("Building graph edges...")
    graph.add_edge(START, "understand_intent")
    graph.add_conditional_edges(
        "understand_intent",
        intent_aware_response,
        {
            "off_topic": "off_topic",
            "greeting": "greeting",
            "route_question": "route_question",
        },
    )
    # Small-talk branches answer directly and finish the run.
    graph.add_edge("greeting", END)
    graph.add_edge("off_topic", END)

    # The next two routers read a decision stored in state under the node's
    # own name (presumably written by that node — confirm in src.graph).
    graph.add_conditional_edges(
        "route_question",
        lambda state: state["route_question"],
        {
            "web_search": "web_search",
            "vectorstore": "retrieve",
        },
    )
    graph.add_edge("retrieve", "grade_documents")
    graph.add_conditional_edges(
        "grade_documents",
        decide_to_generate,
        {
            "transform_query": "transform_query",
            "generate": "generate",
        },
    )
    graph.add_edge("web_search", "generate")
    graph.add_edge("generate", "grade_generation")
    # The generate <-> grade_generation and transform_query loops below are
    # bounded only by the run's recursion limit.
    graph.add_conditional_edges(
        "grade_generation",
        lambda state: state["grade_generation"],
        {
            "not supported": "generate",
            "useful": END,
            "not useful": "transform_query",
        },
    )
    graph.add_edge("transform_query", "route_question")
    print("Graph edges built successfully.")
    return graph


def _running_in_ipython():
    """Return True when executing inside an active IPython/Jupyter session."""
    try:
        from IPython import get_ipython
    except ImportError:
        # IPython is optional; without it we cannot be in a notebook, and the
        # file-based fallback below must still run.
        return False
    return get_ipython() is not None


def _visualize_graph(compiled_app):
    """Best-effort rendering of *compiled_app*'s graph.

    Under IPython the diagram is displayed inline; otherwise it is written to
    ``graphs/workflow_graph.jpg``. Any failure (missing optional library,
    rendering error) is printed and swallowed so visualization never blocks
    loading the workflow.
    """
    try:
        # NOTE(review): draw_mermaid_png() may use an external rendering
        # service by default — confirm this is acceptable for offline runs.
        if _running_in_ipython():
            from IPython.display import Image, display

            print("Attempting to display graph visualization...")
            graph_image = compiled_app.get_graph().draw_mermaid_png()
            display(Image(graph_image))
            print("Graph visualization displayed successfully.")
        else:
            print("Not running in IPython environment. Saving graph as JPG...")
            from PIL import Image

            graph_image = compiled_app.get_graph().draw_mermaid_png()
            # JPEG cannot store an alpha channel, so flatten to RGB first.
            img = Image.open(io.BytesIO(graph_image)).convert('RGB')
            os.makedirs('graphs', exist_ok=True)
            img.save('graphs/workflow_graph.jpg', 'JPEG')
            print("Graph saved as 'graphs/workflow_graph.jpg'")
    except ImportError as e:
        print(f"Required libraries not available. Graph visualization skipped. Error: {e}")
    except Exception as e:
        print(f"Error handling graph visualization: {e}")
        print("Graph visualization skipped.")


try:
    workflow = _build_workflow()
    print("Compiling the workflow...")
    app = workflow.compile(checkpointer=memory)
    print("Workflow compiled successfully.")
except Exception as e:
    # Nothing in this module works without a compiled app, so fail fast
    # instead of continuing with `app` undefined.
    print(f"Error building the graph: {e}")
    sys.exit(1)

_visualize_graph(app)
# Nodes whose state update is taken as the turn's answer; the last one seen wins.
_FINAL_OUTPUT_NODES = frozenset({"grade_generation", "off_topic", "greeting", "web_search"})


def _coerce_user_input(user_input):
    """Return the question text, unwrapping a ``{'content': ...}`` message dict."""
    if isinstance(user_input, dict):
        print("user_input is a dict, extracting content")
        user_input = user_input.get('content', str(user_input))
        print(f"Processed user_input: {user_input}")
    return user_input


def _format_generation(final_output, latest_generation):
    """Turn the selected node output into the ``{"generation": str}`` reply.

    ``latest_generation`` is the most recent non-None ``"generation"`` value
    emitted during this run. It is preferred over stringifying an update that
    carries no generation, which would otherwise echo internal state to the
    user as the answer.
    """
    print(f"Final output before processing: {final_output}")
    if isinstance(final_output, dict) and "generation" in final_output:
        print("Final output is a dict with 'generation' key")
        return {"generation": str(final_output["generation"])}
    if isinstance(final_output, str):
        print("Final output is a string")
        return {"generation": final_output}
    if latest_generation is not None:
        print("Final output has no 'generation'; using the latest generation emitted by the graph")
        return {"generation": str(latest_generation)}
    if final_output is None:
        print("No final output generated")
        return {"generation": "I'm sorry, I couldn't generate a response. Could you please rephrase your question?"}
    print(f"Unexpected final output type: {type(final_output)}")
    return {"generation": str(final_output)}


def run_workflow(user_input, config):
    """Run one chat turn through the compiled graph and return the reply.

    Args:
        user_input: The question, as plain text or as a message dict whose
            ``'content'`` holds the text.
        config: Run config forwarded to ``app.stream`` (the CLI passes
            ``{"configurable": {"thread_id": ...}}`` for the checkpointer).

    Returns:
        ``{"generation": str}``. If the graph hits its recursion limit, a
        web-search + generate fallback is tried; any ``Exception`` is printed
        with its traceback and turned into an apology reply instead of raised.
    """
    try:
        print(f"Running workflow for question: {user_input}")
        user_input = _coerce_user_input(user_input)
        input_state = {
            "messages": [{"role": "user", "content": user_input}]
        }
        print(f"Initial input state: {input_state}")

        final_output = None
        latest_generation = None
        use_web_search = False
        try:
            print("Starting graph execution")
            # Each streamed item maps node name -> that node's state update.
            for output in app.stream(input_state, config):
                for key, value in output.items():
                    print(f"Node '{key}'")
                    if isinstance(value, dict) and value.get("generation") is not None:
                        latest_generation = value["generation"]
                    if key in _FINAL_OUTPUT_NODES:
                        print(f"Setting final_output from node '{key}'")
                        final_output = value
            print("Graph execution completed")
        except GraphRecursionError:
            print("Graph recursion limit reached, switching to web search")
            use_web_search = True

        if use_web_search:
            print("Executing web search fallback")
            # NOTE(review): input_state only carries "messages"; confirm that
            # web_search/generate can run without fields (e.g. the question)
            # that upstream graph nodes normally populate.
            web_search_result = web_search(input_state)
            print(f"Web search result: {web_search_result}")
            generate_result = generate(web_search_result)
            print(f"Generate result: {generate_result}")
            final_output = generate_result
            # A run that reached "useful" would have ended at END, so drafts
            # from the aborted run were never accepted; don't fall back to them.
            latest_generation = None

        return _format_generation(final_output, latest_generation)
    except Exception as e:
        print(f"Error running the workflow: {e}")
        print("Full traceback:")
        traceback.print_exc()
        return {"generation": "I encountered an error while processing your question. Please try again."}
if __name__ == "__main__":
    # Interactive chat loop; a fixed thread_id keeps one conversation in the
    # in-memory checkpointer for the whole session.
    config = {"configurable": {"thread_id": "1"}}
    while True:
        try:
            question = input("Enter your question (or 'quit' to exit): ")
        except (EOFError, KeyboardInterrupt):
            # End of input (Ctrl-D / closed pipe) or Ctrl-C: exit cleanly
            # instead of dying with a traceback.
            print()
            break
        question = question.strip()
        if question.lower() == 'quit':
            break
        if not question:
            # Nothing to ask; don't send an empty message through the graph.
            continue
        result = run_workflow(question, config)
        print("Chatbot:", result["generation"])