from typing import Dict, List, TypedDict, Annotated, Sequence
from langgraph.graph import Graph, StateGraph, END
from langgraph.prebuilt import ToolExecutor
from langchain.schema import StrOutputParser
from langchain.schema.runnable import RunnablePassthrough
from langchain_community.tools.tavily_search import TavilySearchResults
import models
import prompts
from helper_functions import format_docs
from operator import itemgetter
# Define the state structure shared by all graphs
class State(TypedDict):
    messages: Sequence[str]
    topic: str
    research_data: Dict[str, str]
    team_members: List[str]
    draft_posts: Sequence[str]
    final_post: str
    next: str  # set by the supervisor nodes; read by the conditional edges below
research_members = ["Qdrant_researcher", "Web_researcher"]
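# Illustrative only: an example of the initial state the compiled graphs expect.
# The topic is a placeholder; research_data starts empty and is filled in by the
# research nodes, while "next" is set by the supervisor nodes at runtime.
example_state: State = {
    "messages": [],
    "topic": "Placeholder topic for the social media post",
    "research_data": {},
    "team_members": research_members,
    "draft_posts": [],
    "final_post": "",
    "next": "",
}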
# Research Agent Pieces
qdrant_research_chain = (
    {"context": itemgetter("topic") | models.compression_retriever, "topic": itemgetter("topic")}
    | RunnablePassthrough.assign(context=itemgetter("context"))
    | {"response": prompts.research_query_prompt | models.gpt4o_mini | StrOutputParser(), "context": itemgetter("context")}
)
# Web Search Agent Pieces
tavily_tool = TavilySearchResults(max_results=3)
query_chain = prompts.search_query_prompt | models.gpt4o_mini | StrOutputParser()
tavily_simple = {"tav_results": tavily_tool} | prompts.tavily_prompt | models.gpt4o_mini | StrOutputParser()
tavily_chain = {"query": query_chain} | tavily_simple
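# Sketch of the chain interfaces implied above (illustrative; input keys mirror
# how the node functions below call them):
#   qdrant_research_chain.invoke({"topic": "..."})
#       -> {"response": <LLM answer as str>, "context": <retrieved documents>}
#   tavily_chain.invoke({"topic": "...", "qdrant_results": "..."})
#       -> str produced from the Tavily results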
def query_qdrant(state: State) -> State:
    # Pull the topic to research from the state
    topic = state["topic"]
    # Run the retrieval chain against the Qdrant-backed retriever
    result = qdrant_research_chain.invoke({"topic": topic})
    # Update the state with the research results
    state["research_data"]["qdrant_results"] = result
    return state

def web_search(state: State) -> State:
    # Pull the topic to research from the state
    topic = state["topic"]
    # Get the Qdrant results from the state, if any
    qdrant_results = state["research_data"].get("qdrant_results", "No previous results available.")
    # Run the web search chain
    result = tavily_chain.invoke({
        "topic": topic,
        "qdrant_results": qdrant_results,
    })
    # Update the state with the web search results
    state["research_data"]["web_search_results"] = result
    return state
def research_supervisor(state):
    # Implement research supervision logic: set state["next"] to
    # "query_qdrant", "web_search", or "FINISH"
    return state

def post_creation(state):
    # Implement post creation logic
    return state

def copy_editing(state):
    # Implement copy editing logic
    return state

def voice_editing(state):
    # Implement voice editing logic
    return state

def post_review(state):
    # Implement post review logic
    return state

def writing_supervisor(state):
    # Implement writing supervision logic: set state["next"] to one of the
    # writing nodes or "FINISH"
    return state

def overall_supervisor(state):
    # Implement overall supervision logic: set state["next"] to
    # "research_team", "writing_team", or "FINISH"
    return state
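# Minimal sketch of one way a supervisor node could set state["next"], assuming
# a routing prompt named `prompts.supervisor_router_prompt` (hypothetical, not
# defined in this project) that asks the LLM to reply with a member name or
# "FINISH". Illustrative only; it is not wired into the graphs below.
def supervisor_router_sketch(state: State, members: List[str]) -> State:
    router = prompts.supervisor_router_prompt | models.gpt4o_mini | StrOutputParser()
    decision = router.invoke({
        "topic": state["topic"],
        "members": ", ".join(members),
        "research_data": state["research_data"],
    })
    # e.g. "query_qdrant", "web_search", or "FINISH"
    state["next"] = decision.strip()
    return state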
# Create the research team graph
research_graph = StateGraph(State)
research_graph.add_node("query_qdrant", query_qdrant)
research_graph.add_node("web_search", web_search)
research_graph.add_node("research_supervisor", research_supervisor)
research_graph.add_edge("query_qdrant", "research_supervisor")
research_graph.add_edge("web_search", "research_supervisor")
research_graph.add_conditional_edges(
    "research_supervisor",
    lambda x: x["next"],
    {"query_qdrant": "query_qdrant", "web_search": "web_search", "FINISH": END},
)
# research_graph.add_edge("research_supervisor", END)
research_graph.set_entry_point("research_supervisor")
research_graph_comp = research_graph.compile()
# Create the writing team graph
writing_graph = StateGraph(State)
writing_graph.add_node("post_creation", post_creation)
writing_graph.add_node("copy_editing", copy_editing)
writing_graph.add_node("voice_editing", voice_editing)
writing_graph.add_node("post_review", post_review)
writing_graph.add_node("writing_supervisor", writing_supervisor)
writing_graph.add_edge("post_creation", "writing_supervisor")
writing_graph.add_edge("copy_editing", "writing_supervisor")
writing_graph.add_edge("voice_editing", "writing_supervisor")
writing_graph.add_edge("post_review", "writing_supervisor")
writing_graph.add_conditional_edges(
    "writing_supervisor",
    lambda x: x["next"],
    {
        "post_creation": "post_creation",
        "copy_editing": "copy_editing",
        "voice_editing": "voice_editing",
        "post_review": "post_review",
        "FINISH": END,
    },
)
# writing_graph.add_edge("writing_supervisor", END)
writing_graph.set_entry_point("writing_supervisor")
writing_graph_comp = writing_graph.compile()
# Create the overall graph
overall_graph = StateGraph(State)
# Add the compiled research and writing team graphs as nodes
overall_graph.add_node("research_team", research_graph_comp)
overall_graph.add_node("writing_team", writing_graph_comp)
# Add the overall supervisor node
overall_graph.add_node("overall_supervisor", overall_supervisor)
overall_graph.set_entry_point("overall_supervisor")
# Connect the nodes
overall_graph.add_edge("research_team", "overall_supervisor")
overall_graph.add_edge("writing_team", "overall_supervisor")
overall_graph.add_conditional_edges(
    "overall_supervisor",
    lambda x: x["next"],
    {
        "research_team": "research_team",
        "writing_team": "writing_team",
        "FINISH": END,
    },
)
# Compile the graph
app = overall_graph.compile()
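# Example invocation (sketch): assumes the supervisor stubs above have been
# implemented so that they set state["next"]; with the placeholder stubs the
# first conditional edge cannot route, so this stays commented out.
#
# result = app.invoke(example_state)
# print(result["final_post"])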