from operator import itemgetter
from typing import Any, Dict, List, TypedDict, Annotated, Sequence

from langgraph.graph import Graph, StateGraph, END
from langgraph.prebuilt import ToolExecutor
from langchain.schema import StrOutputParser
from langchain.schema.runnable import RunnablePassthrough
from langchain_community.tools.tavily_search import TavilySearchResults

import models
import prompts
from helper_functions import format_docs


# Define the state structure
class State(TypedDict):
    """Shared state passed between every node of the research/writing graphs."""

    messages: Sequence[str]
    topic: str
    # Keyed by source: "qdrant_results" holds the dict produced by
    # qdrant_research_chain ("response" + "context"); "web_search_results"
    # holds the string produced by tavily_chain.
    research_data: Dict[str, Any]
    team_members: List[str]
    draft_posts: Sequence[str]
    final_post: str
    # Routing decision written by a supervisor node and read by the
    # conditional edges below. LangGraph builds its state channels from the
    # keys declared here, so this key must be declared for a supervisor's
    # choice to reach the routing functions.
    next: str


research_members = ["Qdrant_researcher", "Web_researcher"]

# Research Agent Pieces
qdrant_research_chain = (
    {"context": itemgetter("topic") | models.compression_retriever, "topic": itemgetter("topic")}
    | RunnablePassthrough.assign(context=itemgetter("context"))
    | {"response": prompts.research_query_prompt | models.gpt4o_mini | StrOutputParser(), "context": itemgetter("context")}
)

# Web Search Agent Pieces
tavily_tool = TavilySearchResults(max_results=3)

query_chain = (
    prompts.search_query_prompt
    | models.gpt4o_mini
    | StrOutputParser()
)

tavily_simple = ({"tav_results": tavily_tool} | prompts.tavily_prompt | models.gpt4o_mini | StrOutputParser())

tavily_chain = (
    {"query": query_chain} | tavily_simple
)


def _route(state: State) -> str:
    """Return the supervisor's routing decision, defaulting to "FINISH".

    The supervisor nodes below are still stubs that never set ``next``;
    falling back to "FINISH" lets a graph terminate instead of raising
    KeyError. Once the supervisors are implemented they should always set
    ``next`` explicitly.
    """
    return state.get("next") or "FINISH"


def query_qdrant(state: State) -> State:
    """Run the Qdrant retrieval chain on ``state["topic"]`` and record it.

    The chain output (a dict with "response" and "context") is stored under
    ``research_data["qdrant_results"]``. A new state dict is returned; the
    incoming state and its ``research_data`` dict are left untouched.
    """
    # Use the topic (not the message history) as the retrieval query
    topic = state["topic"]

    result = qdrant_research_chain.invoke({"topic": topic})

    # Copy rather than mutate: research_data may be missing from the
    # incoming state, and mutating it in place would also alter the
    # caller's object.
    research_data = dict(state.get("research_data") or {})
    research_data["qdrant_results"] = result
    return {**state, "research_data": research_data}


def web_search(state: State) -> State:
    """Run the Tavily web-search chain for ``state["topic"]`` and record it.

    Any earlier Qdrant results are passed to the chain as extra context.
    The chain's string output is stored under
    ``research_data["web_search_results"]``. A new state dict is returned;
    the incoming state is left untouched.
    """
    topic = state["topic"]
    previous = state.get("research_data") or {}

    # Fall back to a placeholder when the Qdrant node has not run yet
    qdrant_results = previous.get("qdrant_results", "No previous results available.")

    result = tavily_chain.invoke({
        "topic": topic,
        "qdrant_results": qdrant_results,
    })

    research_data = dict(previous)
    research_data["web_search_results"] = result
    return {**state, "research_data": research_data}


def research_supervisor(state):
    """Placeholder: should set ``state["next"]`` to the next research step."""
    # Implement research supervision logic
    return state


def post_creation(state):
    """Placeholder for drafting a post from the research data."""
    # Implement post creation logic
    return state


def copy_editing(state):
    """Placeholder for the copy-editing pass."""
    # Implement copy editing logic
    return state


def voice_editing(state):
    """Placeholder for the voice/tone editing pass."""
    # Implement voice editing logic
    return state


def post_review(state):
    """Placeholder for the final post review."""
    # Implement post review logic
    return state


def writing_supervisor(state):
    """Placeholder: should set ``state["next"]`` to the next writing step."""
    # Implement writing supervision logic
    return state


def overall_supervisor(state):
    """Placeholder: should set ``state["next"]`` to the next team to run."""
    # Implement overall supervision logic
    return state


# Create the research team graph
research_graph = StateGraph(State)
research_graph.add_node("query_qdrant", query_qdrant)
research_graph.add_node("web_search", web_search)
research_graph.add_node("research_supervisor", research_supervisor)

research_graph.add_edge("query_qdrant", "research_supervisor")
research_graph.add_edge("web_search", "research_supervisor")
research_graph.add_conditional_edges(
    "research_supervisor",
    _route,
    {"query_qdrant": "query_qdrant", "web_search": "web_search", "FINISH": END},
)

research_graph.set_entry_point("research_supervisor")
research_graph_comp = research_graph.compile()

# Create the writing team graph
writing_graph = StateGraph(State)
writing_graph.add_node("post_creation", post_creation)
writing_graph.add_node("copy_editing", copy_editing)
writing_graph.add_node("voice_editing", voice_editing)
writing_graph.add_node("post_review", post_review)
writing_graph.add_node("writing_supervisor", writing_supervisor)

writing_graph.add_edge("post_creation", "writing_supervisor")
writing_graph.add_edge("copy_editing", "writing_supervisor")
writing_graph.add_edge("voice_editing", "writing_supervisor")
writing_graph.add_edge("post_review", "writing_supervisor")
writing_graph.add_conditional_edges(
    "writing_supervisor",
    _route,
    {
        "post_creation": "post_creation",
        "copy_editing": "copy_editing",
        "voice_editing": "voice_editing",
        "post_review": "post_review",
        "FINISH": END,
    },
)

writing_graph.set_entry_point("writing_supervisor")
# Compile the *writing* graph (previously research_graph was compiled twice)
writing_graph_comp = writing_graph.compile()

# Create the overall graph
overall_graph = StateGraph(State)

# Add the compiled research and writing team graphs as nodes; the
# uncompiled StateGraph builders are not runnables.
# NOTE(review): both subgraphs share the parent's State, including `next`,
# so a subgraph's final routing value ("FINISH") flows back to the parent.
# overall_supervisor must overwrite `next` once implemented.
overall_graph.add_node("research_team", research_graph_comp)
overall_graph.add_node("writing_team", writing_graph_comp)

# Add the overall supervisor node
overall_graph.add_node("overall_supervisor", overall_supervisor)

overall_graph.set_entry_point("overall_supervisor")

# Connect the nodes
overall_graph.add_edge("research_team", "overall_supervisor")
overall_graph.add_edge("writing_team", "overall_supervisor")
overall_graph.add_conditional_edges(
    "overall_supervisor",
    _route,
    {"research_team": "research_team", "writing_team": "writing_team", "FINISH": END},
)

# Compile the graph
app = overall_graph.compile()