import json
import time

from PIL import Image

from langchain.agents import tool
from langchain_community.document_loaders import GithubFileLoader
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage, ToolMessage
from langgraph.graph import END, MessagesState

from render_mermaid import render_mermaid

# from langchain_ollama import ChatOllama
from prompts import *
from constants import file_extensions
from __init__ import llm, llm_structured


class GraphState(MessagesState):
    working_knowledge: str          # plan / notes accumulated during exploration
    all_files: list[str]            # every file path found in the loaded codebase
    remaining_files: list[str]      # file paths still queued for exploration
    explored_files: list[str]       # file paths already summarized
    explored_summaries: str         # concatenated per-file summaries
    document_summaries_store: dict  # cache: file path -> JSON summary string
    documents: list                 # Documents returned by the GitHub loader
    final_graph: Image.Image        # rendered diagram (rendering currently disabled)


def load_github_codebase(repo: str, branch: str):
    """Load every file with a recognised extension from a GitHub repository."""
    loader = GithubFileLoader(
        repo=repo,
        branch=branch,
        github_api_url="https://api.github.com",
        file_filter=lambda file_path: file_path.endswith(tuple(file_extensions)),
        encoding="utf-8",
    )
    return loader.load()
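
# A hedged usage note: GithubFileLoader authenticates with a GitHub personal
# access token; when no access_token argument is passed it reads the
# GITHUB_PERSONAL_ACCESS_TOKEN environment variable. Hypothetical call:
#   documents = load_github_codebase("owner/repo", "main")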


def get_file_content_summary(file_path: str, state: GraphState):
    """Return the functional summary of a file as a JSON string. `file_path`
    must be a non-null path present in the loaded documents.

    Args:
        file_path: The path of the file for which the summary is required."""

    summary = check_summary_in_store(file_path, state)
    if summary:
        return summary
    doc_content = None
    for document in state["documents"]:
        if document.metadata["path"] == file_path:
            doc_content = document.page_content
            break
    if doc_content is None:
        # Guard against paths the planner invents: without this, an unknown
        # path would leave doc_content unbound and raise a NameError below.
        return json.dumps({"FilePath": file_path, "Summary": "File not found."})
    summary = llm.invoke(
        [SystemMessage(content=summarizer_prompt), HumanMessage(content=doc_content)]
    ).content
    summary = json.dumps({"FilePath": file_path, "Summary": summary})
    save_summary_in_store(file_path, summary, state)
    return summary


def explore_file(state: GraphState):
    """Pop the next file off the queue, summarize it, and fold the summary
    into the accumulated exploration notes."""
    file_path = state["remaining_files"].pop()

    summary_dict = json.loads(get_file_content_summary(file_path, state))
    if summary_dict["FilePath"] in state["explored_files"]:
        return state  # already folded into explored_summaries; skip duplicates
    knowledge_str = f"""* File Path: {summary_dict['FilePath']}\n\tSummary: {summary_dict['Summary']}\n\n"""
    state["explored_summaries"] += knowledge_str
    state["explored_files"].append(file_path)
    return state


@tool
def generate_final_mermaid_code():
    """Generate the final mermaid code for the codebase once all the files are explored and the working knowledge is complete."""
    return "generate_mermaid_code"


def check_summary_in_store(file_path: str, state: GraphState):
    if file_path in state["document_summaries_store"]:
        return state["document_summaries_store"][file_path]
    return None


def save_summary_in_store(file_path: str, summary: str, state: GraphState):
    state["document_summaries_store"][file_path] = summary


def get_all_filenames_in_codebase(state: GraphState):
    """Get a list of all files (as filepaths) in the codebase."""
    filenames = [document.metadata["path"] for document in state["documents"]]

    return {
        "all_files": filenames,
        "explored_files": [],
        # Copy the list so that popping from remaining_files does not also
        # mutate all_files.
        "remaining_files": list(filenames),
        "explored_summaries": "",
        "document_summaries_store": {},
    }


def parse_plan(state: GraphState):
    """Parse the exploration plan into a concrete list of file paths."""
    if "File Exploration Plan" in state["working_knowledge"]:
        plan_working = state["working_knowledge"].split("File Exploration Plan")[1]
    else:
        plan_working = state["working_knowledge"]
    response = llm_structured.invoke(plan_parser.format(plan_list=plan_working))[
        "plan_list"
    ]
    # Cap the plan at 25 files to bound exploration cost.
    response = response[:25]
    return {"remaining_files": response}


def router(state: GraphState):
    """Keep exploring while files remain; otherwise generate the mermaid code."""
    if state["remaining_files"]:
        return "explore_file"
    return "generate_mermaid_code"


def get_plan_for_codebase(state: GraphState):
    """Ask the planner LLM for a file-exploration plan over the full file list."""
    new_state = get_all_filenames_in_codebase(state)
    planner_content = "# File Structure\n" + str(new_state["all_files"])
    plan = llm.invoke(
        [SystemMessage(content=planner_prompt), HumanMessage(content=planner_content)]
    )

    new_state["working_knowledge"] = f"""# Plan\n{plan.content}"""
    return new_state


def final_mermaid_code_generation(state: GraphState):
    """Combine the plan and the per-file summaries and ask the LLM for the
    final mermaid diagram code."""
    final_graph_content = (
        "# Disjoint Codebase Understanding\n"
        + state["working_knowledge"]
        + "\n\n# Completed Explorations\n"
        + state["explored_summaries"]
    )
    response = llm.invoke(
        [
            SystemMessage(content=final_graph_prompt),
            HumanMessage(content=final_graph_content),
        ]
    )
    return {"messages": [response]}


def extract_mermaid_and_generate_graph(state: GraphState):
    """Extract the bare mermaid code from the last LLM message. Rendering the
    diagram to a PNG is currently disabled."""
    mermaid_code = state["messages"][-1].content
    # Drop everything before the mermaid block, if one is fenced
    if "mermaid" in mermaid_code:
        mermaid_code = mermaid_code.split("mermaid")[-1]
    response = llm.invoke(
        [SystemMessage(content=mermaid_extracter), HumanMessage(content=mermaid_code)]
    ).content
    response = response.split("```mermaid")[-1].split("```")[0]
    # Rendering to mermaid/<timestamp>.png is disabled for now; re-enable
    # render_mermaid (and the Image read) to populate final_graph.
    file_name = f"mermaid/{int(time.time())}.png"
    # render_mermaid(response, file_name)
    # img = Image.open(file_name)
    return {"messages": [AIMessage(response)], "final_graph": None}


def need_to_update_working_knowledge(state: GraphState):
    """Route based on the last message: emit the mermaid code once the agent
    calls the finishing tool, update working knowledge after a tool run, and
    otherwise continue with the agent."""
    last_message = state["messages"][-1]
    # Sentinel returned by the generate_final_mermaid_code tool
    if last_message.content == "generate_mermaid_code":
        return "generate_mermaid_code"
    # A tool just ran, so its output needs to be folded into working knowledge
    if isinstance(last_message, ToolMessage):
        return "tools_knowledge_update"
    # Otherwise, continue with the agent
    return "agent"