aie4-final / graph.py
angry-meow's picture
updated graph structure
06dc1b7
raw
history blame
5.47 kB
from typing import Dict, List, TypedDict, Annotated, Sequence
from langgraph.graph import Graph, StateGraph, END
from langgraph.prebuilt import ToolExecutor
from langchain.schema import StrOutputParser
from langchain.schema.runnable import RunnablePassthrough
from langchain_community.tools.tavily_search import TavilySearchResults
import models
import prompts
from helper_functions import format_docs
from operator import itemgetter
# Define the state structure
class State(TypedDict):
    """Shared state schema used by the research, writing and overall graphs."""

    messages: Sequence[str]
    # Read by query_qdrant and web_search as the input to their chains.
    topic: str
    # Research outputs, stored under "qdrant_results" and "web_search_results".
    research_data: Dict[str, str]
    team_members: List[str]
    draft_posts: Sequence[str]
    final_post: str
    # Routing key read by each supervisor's conditional edge (x["next"]).
    # Declared here so the key is part of the graph's state schema and a
    # supervisor can set it to a node name or "FINISH".
    next: str
# Presumably the names of the research team's agents.
# NOTE(review): this list is not referenced anywhere else in the visible code,
# and its entries differ from the research graph's node names
# ("query_qdrant", "web_search") - confirm the intended use.
research_members = ["Qdrant_researcher", "Web_researcher"]
# Research Agent Pieces
# Chain input: a dict with a "topic" key.
#   1. Build {"context": <topic piped into models.compression_retriever>,
#             "topic": <topic>}.
#   2. Re-assign "context" from the "context" value that step 1 produced.
#      NOTE(review): this looks like a no-op since step 1 already supplies
#      "context" - confirm before removing.
#   3. Emit {"response": <research_query_prompt | gpt4o_mini | string parser>,
#            "context": <context passed through unchanged>}.
qdrant_research_chain = (
    {"context": itemgetter("topic") | models.compression_retriever, "topic": itemgetter("topic")}
    | RunnablePassthrough.assign(context=itemgetter("context"))
    | {"response": prompts.research_query_prompt | models.gpt4o_mini | StrOutputParser(), "context": itemgetter("context")}
)
# Web Search Agent Pieces
# Tavily search tool, configured with max_results=3.
tavily_tool = TavilySearchResults(max_results=3)
# Prompt -> model -> plain string; the output becomes the "query" value in
# tavily_chain below.
query_chain = ( prompts.search_query_prompt | models.gpt4o_mini | StrOutputParser() )
# Runs the Tavily tool on the chain input, exposes the tool output to
# prompts.tavily_prompt as "tav_results", then returns the model's reply as a string.
tavily_simple = ({"tav_results": tavily_tool} | prompts.tavily_prompt | models.gpt4o_mini | StrOutputParser())
# Full web-search pipeline: produce {"query": <query_chain output>} and feed
# that dict into tavily_simple.
tavily_chain = (
    {"query": query_chain} | tavily_simple
)
def query_qdrant(state: State) -> State:
    """Run the Qdrant research chain for the state's topic and store the result.

    Args:
        state: Graph state; ``state["topic"]`` is passed to the chain.

    Returns:
        The same ``state`` object, with the chain output stored under
        ``state["research_data"]["qdrant_results"]``.

    Raises:
        KeyError: If ``state`` has no ``"topic"`` entry.
    """
    # The chain input is the topic, not the message history.
    topic = state["topic"]
    # Run the chain
    result = qdrant_research_chain.invoke({"topic": topic})
    # Create research_data on first use so a state that does not carry it yet
    # does not fail with KeyError, then record the research results.
    state.setdefault("research_data", {})["qdrant_results"] = result
    return state
def web_search(state: State) -> State:
    """Run the web-search chain for the state's topic and store the result.

    Args:
        state: Graph state; ``state["topic"]`` and any earlier
            ``state["research_data"]["qdrant_results"]`` are passed to the chain.

    Returns:
        The same ``state`` object, with the chain output stored under
        ``state["research_data"]["web_search_results"]``.

    Raises:
        KeyError: If ``state`` has no ``"topic"`` entry.
    """
    # The chain input is the topic, not the message history.
    topic = state["topic"]
    # Create research_data on first use so a state that does not carry it yet
    # does not fail with KeyError.
    research_data = state.setdefault("research_data", {})
    # Fall back to a placeholder string when the Qdrant step has not run.
    qdrant_results = research_data.get("qdrant_results", "No previous results available.")
    # Run the web search chain
    result = tavily_chain.invoke({
        "topic": topic,
        "qdrant_results": qdrant_results,
    })
    # Record the web search results alongside any earlier research output.
    research_data["web_search_results"] = result
    return state
def research_supervisor(state):
    """Placeholder for the research team's supervision step.

    No supervision logic exists yet; the incoming state is handed back as is.

    TODO: the research graph routes on ``state["next"]`` after this node, and
    this stub does not set that key.
    """
    return state
def post_creation(state):
    """Placeholder for the post creation step.

    No post creation logic exists yet; the incoming state is handed back as is.
    """
    return state
def copy_editing(state):
    """Placeholder for the copy editing step.

    No copy editing logic exists yet; the incoming state is handed back as is.
    """
    return state
def voice_editing(state):
    """Placeholder for the voice editing step.

    No voice editing logic exists yet; the incoming state is handed back as is.
    """
    return state
def post_review(state):
    """Placeholder for the post review step.

    No post review logic exists yet; the incoming state is handed back as is.
    """
    return state
def writing_supervisor(state):
    """Placeholder for the writing team's supervision step.

    No supervision logic exists yet; the incoming state is handed back as is.

    TODO: the writing graph routes on ``state["next"]`` after this node, and
    this stub does not set that key.
    """
    return state
def overall_supervisor(state):
    """Placeholder for the top-level supervision step.

    No supervision logic exists yet; the incoming state is handed back as is.

    TODO: the overall graph routes on ``state["next"]`` after this node, and
    this stub does not set that key.
    """
    return state
# Research team graph: a supervisor hub that dispatches to two worker nodes.
research_graph = StateGraph(State)

# Register every node before wiring any edge between them.
for _name, _node in (
    ("query_qdrant", query_qdrant),
    ("web_search", web_search),
    ("research_supervisor", research_supervisor),
):
    research_graph.add_node(_name, _node)

# Each worker hands control back to the supervisor when it finishes.
for _name in ("query_qdrant", "web_search"):
    research_graph.add_edge(_name, "research_supervisor")

# The supervisor's next hop is read from state["next"]; "FINISH" ends the run.
research_graph.add_conditional_edges(
    "research_supervisor",
    lambda state: state["next"],
    {"query_qdrant": "query_qdrant", "web_search": "web_search", "FINISH": END},
)

# Execution starts at the supervisor.
research_graph.set_entry_point("research_supervisor")
research_graph_comp = research_graph.compile()
# Create the writing team graph: a supervisor hub with four worker nodes.
writing_graph = StateGraph(State)
writing_graph.add_node("post_creation", post_creation)
writing_graph.add_node("copy_editing", copy_editing)
writing_graph.add_node("voice_editing", voice_editing)
writing_graph.add_node("post_review", post_review)
writing_graph.add_node("writing_supervisor", writing_supervisor)
# Every worker hands control back to the supervisor when it finishes.
writing_graph.add_edge("post_creation", "writing_supervisor")
writing_graph.add_edge("copy_editing", "writing_supervisor")
writing_graph.add_edge("voice_editing", "writing_supervisor")
writing_graph.add_edge("post_review", "writing_supervisor")
# The supervisor's next hop is read from state["next"]; "FINISH" ends the run.
writing_graph.add_conditional_edges(
    "writing_supervisor",
    lambda x: x["next"],
    {
        "post_creation": "post_creation",
        "copy_editing": "copy_editing",
        "voice_editing": "voice_editing",
        "post_review": "post_review",
        "FINISH": END,
    },
)
writing_graph.set_entry_point("writing_supervisor")
# Compile the writing builder itself (not research_graph) so this name holds
# the writing team's executable graph.
writing_graph_comp = writing_graph.compile()
# Create the overall graph
overall_graph = StateGraph(State)
# Add the research and writing teams as nodes. A node must be something the
# graph can execute, so each team's builder is compiled from its own
# StateGraph here rather than passing the uncompiled builder object.
overall_graph.add_node("research_team", research_graph.compile())
overall_graph.add_node("writing_team", writing_graph.compile())
# Add the overall supervisor node
overall_graph.add_node("overall_supervisor", overall_supervisor)
overall_graph.set_entry_point("overall_supervisor")
# Connect the nodes: both teams report back to the overall supervisor.
overall_graph.add_edge("research_team", "overall_supervisor")
overall_graph.add_edge("writing_team", "overall_supervisor")
# The supervisor's next hop is read from state["next"]; "FINISH" ends the run.
overall_graph.add_conditional_edges(
    "overall_supervisor",
    lambda x: x["next"],
    {
        "research_team": "research_team",
        "writing_team": "writing_team",
        "FINISH": END,
    },
)
# Compile the graph
app = overall_graph.compile()