import operator
import os
import sys
from contextlib import contextmanager
from typing import Annotated
from typing import List, Dict

from IPython.display import display, HTML, Image
from langchain.schema import Document
from langchain_core.runnables.graph import CurveStyle, MermaidDrawMethod
from langgraph.graph import END, StateGraph
from typing_extensions import TypedDict

from .chains.answer_chitchat import make_chitchat_node
from .chains.answer_ai_impact import make_ai_impact_node
from .chains.query_transformation import make_query_transform_node
from .chains.translation import make_translation_node
from .chains.intent_categorization import make_intent_categorization_node
from .chains.retrieve_documents import make_IPx_retriever_node, make_POC_retriever_node
from .chains.answer_rag import make_rag_node
from .chains.graph_retriever import make_graph_retriever_node
from .chains.chitchat_categorization import make_chitchat_intent_categorization_node
# from .chains.set_defaults import set_defaults


class GraphState(TypedDict):
    """
    Represents the state of our graph.
    """
    # NOTE(review): several fields below carry "= value" defaults. Whether a
    # TypedDict applies those at runtime should be confirmed; they are kept
    # exactly as written.
    user_input : str
    language : str
    intent : str
    search_graphs_chitchat : bool
    query: str
    questions_list : List[dict]
    # Annotated with operator.add as in the original declaration.
    handled_questions_index : Annotated[list[int], operator.add]
    n_questions : int
    answer: str
    audience: str = "experts"
    sources_input: List[str] = ["IPCC","IPBES"]
    relevant_content_sources_selection: List[str] = ["Figures (IPCC/IPBES)"]
    sources_auto: bool = True
    min_year: int = 1960
    max_year: int = None
    documents: Annotated[List[Document], operator.add]
    related_contents : Annotated[List[Document], operator.add]
    recommended_content : List[Document]
    search_only : bool = False
    reports : List[str] = []


def dummy(state):
    """Placeholder node: ignores ``state`` and returns ``None``."""
    return


def search(state):  # TODO
    """Placeholder (not implemented yet): returns ``None``."""
    return


def answer_search(state):  # TODO
    """Placeholder node (not implemented yet): returns ``None``."""
    return


def route_intent(state):
    """Pick the branch that follows intent categorization.

    Returns ``"answer_chitchat"`` when ``state["intent"]`` is ``"chitchat"``
    or ``"esg"``, and ``"answer_climate"`` (the search route) otherwise.
    """
    intent = state["intent"]
    if intent in ["chitchat", "esg"]:
        return "answer_chitchat"
    # TODO: an "ai_impact" intent routed to "answer_ai_impact" is not enabled.
    # Search route
    return "answer_climate"


def chitchat_route_intent(state):
    """Decide whether graphs are retrieved after a chitchat answer.

    Returns ``"retrieve_graphs_chitchat"`` when
    ``state["search_graphs_chitchat"]`` is ``True`` and ``END`` when it is
    ``False``.
    """
    intent = state["search_graphs_chitchat"]
    if intent is True:
        return "retrieve_graphs_chitchat"
    elif intent is False:
        return END
    # NOTE(review): any other value falls through and yields None, which is
    # not one of the mapped destinations - confirm upstream always sets a bool.


def route_translation(state):
    """Pick the node that follows ``answer_climate``.

    Both branches currently return ``"transform_query"``; the language check
    is kept so the translation route can be re-enabled later.
    """
    if state["language"].lower() == "english":
        return "transform_query"
    else:
        return "transform_query"
        # return "translate_query" #TODO : add translation


def route_based_on_relevant_docs(state, threshold_docs=0.2):
    """Route to the RAG answer with or without documents.

    A document counts as relevant when its ``metadata["reranking_score"]`` is
    strictly greater than ``threshold_docs``. Returns ``"answer_rag"`` if at
    least one document is relevant, otherwise ``"answer_rag_no_docs"``. The
    chosen route is also printed.
    """
    docs = [x for x in state["documents"] if x.metadata["reranking_score"] > threshold_docs]
    route = "answer_rag" if len(docs) > 0 else "answer_rag_no_docs"
    print("Route : ", [route])
    return route


def route_continue_retrieve_documents(state):
    """Tell whether every "IPx" question has been handled.

    Returns ``"end_retrieve_IPx_documents"`` once the index of every entry of
    ``state["questions_list"]`` whose ``source_type`` is ``"IPx"`` appears in
    ``state["handled_questions_index"]``; otherwise ``"retrieve_documents"``.
    """
    index_question_ipx = [
        i for i, x in enumerate(state["questions_list"]) if x["source_type"] == "IPx"
    ]
    questions_ipx_finished = all(
        elem in state["handled_questions_index"] for elem in index_question_ipx
    )
    # TODO: short-circuit to END when state["search_only"] is set, and route
    # to "answer_search" when done (sketched previously, not enabled).
    if questions_ipx_finished:
        return "end_retrieve_IPx_documents"
    else:
        return "retrieve_documents"


def route_continue_retrieve_local_documents(state):
    """Tell whether local ("POC") retrieval should continue.

    Returns ``"end_retrieve_local_documents"`` when every "POC" question index
    is in ``state["handled_questions_index"]`` or when ``"POC region"`` is not
    in ``state["relevant_content_sources_selection"]``; otherwise
    ``"retrieve_local_data"``.
    """
    index_question_poc = [
        i for i, x in enumerate(state["questions_list"]) if x["source_type"] == "POC"
    ]
    questions_poc_finished = all(
        elem in state["handled_questions_index"] for elem in index_question_poc
    )
    # TODO: short-circuit to END when state["search_only"] is set, and route
    # to "answer_search" when done (sketched previously, not enabled).
    if questions_poc_finished or ("POC region" not in state["relevant_content_sources_selection"]):
        return "end_retrieve_local_documents"
    else:
        return "retrieve_local_data"


def route_retrieve_documents(state):
    """List the extra retrieval nodes to run after query transformation.

    Returns ``["retrieve_graphs"]`` when ``"Graphs (OurWorldInData)"`` is in
    ``state["relevant_content_sources_selection"]``, and ``END`` when no
    source is selected.
    """
    sources_to_retrieve = []
    if "Graphs (OurWorldInData)" in state["relevant_content_sources_selection"]:
        sources_to_retrieve.append("retrieve_graphs")
    if not sources_to_retrieve:
        return END
    return sources_to_retrieve


def make_id_dict(values):
    """Return a dict mapping each item of ``values`` to itself."""
    return {k: k for k in values}


def _build_graph_agent(llm, vectorstore_ipcc, vectorstore_graphs, vectorstore_region,
                       reranker, threshold_docs, *, with_local_data):
    """Assemble and compile the workflow shared by both public builders.

    ``with_local_data`` toggles the ``retrieve_local_data`` node together with
    its incoming edge from ``transform_query`` and outgoing edge to
    ``answer_search``. Nodes and edges are registered in the same order the
    two original builders used.
    """
    workflow = StateGraph(GraphState)

    # Define the node functions
    categorize_intent = make_intent_categorization_node(llm)
    transform_query = make_query_transform_node(llm)
    translate_query = make_translation_node(llm)
    answer_chitchat = make_chitchat_node(llm)
    # Built as before but not registered as a node anywhere below.
    answer_ai_impact = make_ai_impact_node(llm)
    retrieve_documents = make_IPx_retriever_node(vectorstore_ipcc, reranker, llm)
    retrieve_graphs = make_graph_retriever_node(vectorstore_graphs, reranker)
    retrieve_local_data = (
        make_POC_retriever_node(vectorstore_region, reranker, llm)
        if with_local_data
        else None
    )
    answer_rag = make_rag_node(llm, with_docs=True)
    answer_rag_no_docs = make_rag_node(llm, with_docs=False)
    chitchat_categorize_intent = make_chitchat_intent_categorization_node(llm)

    # Define the nodes
    # workflow.add_node("set_defaults", set_defaults)
    workflow.add_node("categorize_intent", categorize_intent)
    workflow.add_node("answer_climate", dummy)
    workflow.add_node("answer_search", answer_search)
    workflow.add_node("transform_query", transform_query)
    workflow.add_node("translate_query", translate_query)
    workflow.add_node("answer_chitchat", answer_chitchat)
    workflow.add_node("chitchat_categorize_intent", chitchat_categorize_intent)
    workflow.add_node("retrieve_graphs", retrieve_graphs)
    if with_local_data:
        workflow.add_node("retrieve_local_data", retrieve_local_data)
    # The same graph retriever is registered a second time for the chitchat path.
    workflow.add_node("retrieve_graphs_chitchat", retrieve_graphs)
    workflow.add_node("retrieve_documents", retrieve_documents)
    workflow.add_node("answer_rag", answer_rag)
    workflow.add_node("answer_rag_no_docs", answer_rag_no_docs)

    # Entry point
    workflow.set_entry_point("categorize_intent")

    # CONDITIONAL EDGES
    workflow.add_conditional_edges(
        "categorize_intent",
        route_intent,
        make_id_dict(["answer_chitchat", "answer_climate"])
    )

    workflow.add_conditional_edges(
        "chitchat_categorize_intent",
        chitchat_route_intent,
        make_id_dict(["retrieve_graphs_chitchat", END])
    )

    workflow.add_conditional_edges(
        "answer_climate",
        route_translation,
        make_id_dict(["translate_query", "transform_query"])
    )

    workflow.add_conditional_edges(
        "answer_search",
        lambda x : route_based_on_relevant_docs(x, threshold_docs=threshold_docs),
        make_id_dict(["answer_rag", "answer_rag_no_docs"])
    )

    workflow.add_conditional_edges(
        "transform_query",
        route_retrieve_documents,
        make_id_dict(["retrieve_graphs", END])
    )

    # Define the edges
    workflow.add_edge("translate_query", "transform_query")
    workflow.add_edge("transform_query", "retrieve_documents")  # TODO put back
    if with_local_data:
        workflow.add_edge("transform_query", "retrieve_local_data")
    # workflow.add_edge("transform_query", END) # TODO remove

    workflow.add_edge("retrieve_graphs", END)
    workflow.add_edge("answer_rag", END)
    workflow.add_edge("answer_rag_no_docs", END)
    workflow.add_edge("answer_chitchat", "chitchat_categorize_intent")
    workflow.add_edge("retrieve_graphs_chitchat", END)

    if with_local_data:
        workflow.add_edge("retrieve_local_data", "answer_search")
    workflow.add_edge("retrieve_documents", "answer_search")

    # Compile
    app = workflow.compile()
    return app


def make_graph_agent(llm, vectorstore_ipcc, vectorstore_graphs, vectorstore_region, reranker, threshold_docs=0.2):
    """Build and compile the agent graph without the local-data retriever.

    Args:
        llm: passed to every ``make_*_node`` factory that takes an LLM.
        vectorstore_ipcc: passed to ``make_IPx_retriever_node``.
        vectorstore_graphs: passed to ``make_graph_retriever_node``.
        vectorstore_region: not used by this variant; kept so the signature
            matches ``make_graph_agent_poc``.
        reranker: passed to the retriever node factories.
        threshold_docs: score threshold forwarded to
            ``route_based_on_relevant_docs``.

    Returns:
        The result of ``workflow.compile()``.
    """
    return _build_graph_agent(
        llm, vectorstore_ipcc, vectorstore_graphs, vectorstore_region,
        reranker, threshold_docs, with_local_data=False,
    )


def make_graph_agent_poc(llm, vectorstore_ipcc, vectorstore_graphs, vectorstore_region, reranker, threshold_docs=0.2):
    """Build and compile the agent graph including the local-data retriever.

    Same wiring as ``make_graph_agent`` plus a ``retrieve_local_data`` node
    (built by ``make_POC_retriever_node`` from ``vectorstore_region``) that is
    fed by ``transform_query`` and feeds ``answer_search``.

    Returns:
        The result of ``workflow.compile()``.
    """
    return _build_graph_agent(
        llm, vectorstore_ipcc, vectorstore_graphs, vectorstore_region,
        reranker, threshold_docs, with_local_data=True,
    )


def display_graph(app):
    """Display the compiled graph as a Mermaid PNG image.

    Uses ``app.get_graph(xray=True)`` and ``MermaidDrawMethod.API`` to draw
    the image, then shows it via IPython's ``display``.
    """
    display(
        Image(
            app.get_graph(xray = True).draw_mermaid_png(
                draw_method=MermaidDrawMethod.API,
            )
        )
    )