from openai import OpenAI
import os
import networkx as nx
from dotenv import load_dotenv
from constants import DOCUMENTS
from tqdm import tqdm
from cdlib import algorithms
import matplotlib.pyplot as plt

load_dotenv(".env.example")
client = OpenAI(api_key=os.getenv("OPENAI_API_KEY"))


def draw_graph(graph):
    """Render the entity/relationship graph, save it to graph.png, and show it.

    Edge labels are read from the ``label`` attribute set by build_graph.
    """
    pos = nx.spring_layout(graph)  # Position the nodes
    plt.figure(figsize=(12, 8))
    nx.draw(
        graph,
        pos,
        with_labels=True,
        node_color="skyblue",
        edge_color="gray",
        node_size=1500,
        font_size=10,
        font_weight="bold",
    )
    edge_labels = nx.get_edge_attributes(graph, "label")
    nx.draw_networkx_edge_labels(graph, pos, edge_labels=edge_labels, font_size=8)
    plt.title("Graph Visualization of Extracted Entities and Relationships")
    plt.savefig("graph.png")
    plt.show()


# Source texts -> chunks
def get_chunks(documents, chunk_size=1000, overlap_size=200):
    """Split each document into overlapping character chunks.

    Args:
        documents: iterable of strings.
        chunk_size: length of each chunk in characters.
        overlap_size: number of characters shared between consecutive chunks.

    Returns:
        Flat list of chunk strings across all documents.

    Raises:
        ValueError: if overlap_size >= chunk_size (the stride would be <= 0,
            which previously surfaced as an opaque range() error).
    """
    if overlap_size >= chunk_size:
        raise ValueError("overlap_size must be smaller than chunk_size")
    chunks = []
    for doc in documents:
        for i in range(0, len(doc), chunk_size - overlap_size):
            chunks.append(doc[i : i + chunk_size])
    return chunks


# print(get_chunks(DOCUMENTS))


# Chunks -> Element instances
def extract_elements(chunks):
    """Ask the model to extract entities/relationships from each chunk.

    Returns one raw model response string per chunk.
    """
    elements = []
    for index, chunk in enumerate(chunks):
        print(f"Processing chunk {index + 1}/{len(chunks)}")
        response = client.chat.completions.create(
            model="gpt-4o",
            messages=[
                {
                    "role": "system",
                    "content": "Extract entities and relationships from the following text.",
                },
                {"role": "user", "content": chunk},
            ],
        )
        print(response.choices[0].message.content)
        entities_and_relations = response.choices[0].message.content
        elements.append(entities_and_relations)
    return elements


# print(extract_elements(get_chunks(DOCUMENTS)))


# Element instances -> Element summaries
def summarize_elements(elements):
    """Normalize each extracted element into a structured summary.

    The system prompt instructs the model to use "source -> relation -> target"
    lines after a "Relationships:" header, which build_graph parses.
    """
    summaries = []
    for index, element in enumerate(elements):
        print(f"Summarizing element {index + 1}/{len(elements)}")
        response = client.chat.completions.create(
            model="gpt-4o",
            messages=[
                {
                    "role": "system",
                    "content": (
                        "Summarize the following entities and relationships in a "
                        'structured format. Use "->" to represent relationships, '
                        'after the "Relationships:" word.'
                    ),
                },
                {"role": "user", "content": element},
            ],
        )
        print("Element summary:", response.choices[0].message.content)
        summary = response.choices[0].message.content
        summaries.append(summary)
    return summaries


# print(summarize_elements(extract_elements(get_chunks(DOCUMENTS))))


# Element summaries -> Graph communities
def build_graph(summaries):
    """Parse structured summaries into a networkx Graph.

    Entities become nodes; "source -> relation -> target" lines become
    labeled edges.
    """
    G = nx.Graph()
    for index, summary in enumerate(summaries):
        print(f"Summary index {index + 1}/{len(summaries)}")
        lines = summary.split("\n")
        entities_section = False
        relationships_section = False
        entities = []
        for line in tqdm(lines):
            if line.startswith("### Entities:") or line.startswith("**Entities:**"):
                entities_section = True
                relationships_section = False
                continue
            elif line.startswith("### Relationships:") or line.startswith(
                "**Relationships:**"
            ):
                entities_section = False
                relationships_section = True
                continue
            if entities_section and line.strip():
                # Strip a leading "1." style list marker if present.
                # Guard len(line) > 1: a bare one-char line previously raised
                # IndexError on line[1].
                if len(line) > 1 and line[0].isdigit() and line[1] == ".":
                    line = line.split(".", 1)[1].strip()
                entity = line.strip()
                entity = entity.replace("**", "")
                entities.append(entity)
                G.add_node(entity)
            elif relationships_section and line.strip():
                parts = line.split("->")
                # The prompt requests "source -> relation -> target" (3 parts);
                # the old `== 2` check silently dropped every such line and,
                # for 2-part lines, always produced an empty label. Accept any
                # line with at least one arrow: first part is the source, last
                # is the target, anything in between is the relation label.
                if len(parts) >= 2:
                    source = parts[0].strip()
                    target = parts[-1].strip()
                    relation = " -> ".join(parts[1:-1]).strip()
                    G.add_edge(source, target, label=relation)
    return G


# Graph communities -> Graph summaries
def detect_communities(graph):
    """Partition the graph into communities via Leiden, per connected component.

    Singleton components are kept as one-node communities; Leiden failures on a
    component are logged and that component is skipped (best-effort).
    """
    communities = []
    # Materialize once: the old code rebuilt the component list on every
    # iteration just to print its length (O(components^2)).
    components = list(nx.connected_components(graph))
    for index, component in enumerate(components):
        print(f"Component index {index} of {len(components)}:")
        subgraph = graph.subgraph(component)
        if len(subgraph.nodes) > 1:
            try:
                sub_communities = algorithms.leiden(subgraph)
                for community in sub_communities.communities:
                    communities.append(list(community))
            except Exception as e:
                print(f"Error processing community {index}: {e}")
        else:
            communities.append(list(subgraph.nodes))
    print("Communities from detect_communities:", communities)
    return communities


# summarize communities
def summarize_communities(communities, graph):
    """Describe each community's nodes and labeled edges, then summarize via LLM."""
    community_summaries = []
    for index, community in enumerate(communities):
        print(f"Summarize Community index {index + 1}/{len(communities)}:")
        subgraph = graph.subgraph(community)
        nodes = list(subgraph.nodes)
        edges = list(subgraph.edges(data=True))
        description = "Entities: " + ", ".join(nodes) + "\nRelationships: "
        relationships = []
        for source, target, data in edges:
            # Use .get(): edges may carry an empty/missing label; the old code
            # computed this default and then indexed data['label'] anyway,
            # risking KeyError.
            relation = data.get("label", "")
            relationships.append(f"{source} -> {relation} -> {target}")
        description += ", ".join(relationships)
        response = client.chat.completions.create(
            model="gpt-4o",
            messages=[
                {
                    "role": "system",
                    "content": "Summarize the following community of entities and relationships.",
                },
                {"role": "user", "content": description},
            ],
        )
        summary = response.choices[0].message.content.strip()
        community_summaries.append(summary)
    return community_summaries


# Community Summaries -> Community Answers -> Global Answer
def generate_answer(community_summaries, query):
    """Answer the query per community, then combine into one final answer."""
    intermediate_answers = []
    for index, summary in enumerate(community_summaries):
        print(f"Answering community {index + 1}/{len(community_summaries)}:")
        response = client.chat.completions.create(
            model="gpt-4o",
            messages=[
                {
                    "role": "system",
                    "content": "Answer the following query based on the provided summary.",
                },
                {"role": "user", "content": f"Query: {query} Summary: {summary}"},
            ],
        )
        print("Intermediate answer:", response.choices[0].message.content)
        intermediate_answers.append(response.choices[0].message.content)
    final_response = client.chat.completions.create(
        model="gpt-4o",
        messages=[
            {
                "role": "system",
                "content": "Combine these answers into a final, concise response.",
            },
            {
                "role": "user",
                "content": f"Intermediate answers: {intermediate_answers}",
            },
        ],
    )
    final_answer = final_response.choices[0].message.content
    return final_answer


def graphrag_pipeline(documents, query):
    """Full pipeline: chunk -> extract -> summarize -> graph -> communities -> answer."""
    chunks = get_chunks(documents)
    elements = extract_elements(chunks)
    summaries = summarize_elements(elements)
    graph = build_graph(summaries)
    num_entities = graph.number_of_nodes()
    print(f"Number of entities in the graph: {num_entities}")
    draw_graph(graph)
    communities = detect_communities(graph)
    print(communities)
    community_summaries = summarize_communities(communities, graph)
    final_answer = generate_answer(community_summaries, query)
    return final_answer


if __name__ == "__main__":
    # Guarded so importing this module no longer fires the whole (paid,
    # network-bound) pipeline as a side effect; `python thisfile.py` behaves
    # exactly as before.
    query = (
        "What factors in these articles can impact medical inflation "
        "in the UK in the short term?"
    )
    # "What are the main themes in these documents?"
    print("Query:", query)
    answer = graphrag_pipeline(DOCUMENTS, query)
    print("Answer:", answer)