"""Build, persist, visualize, and query a product-spec Knowledge Graph.

Products are read as JSON documents; triplets are extracted directly from
the JSON payload (no LLM extraction) and indexed with llama-index's
KnowledgeGraphIndex. Pyvis renders HTML visualizations of the graph.
"""
import ast
import html
import json
import os
from datetime import datetime
from typing import List

import networkx as nx
from dotenv import load_dotenv
from llama_index.core import (
    Document,
    KnowledgeGraphIndex,
    ServiceContext,
    SimpleDirectoryReader,
    StorageContext,
    load_index_from_storage,
)
from llama_index.core.graph_stores import SimpleGraphStore
from llama_index.llms.openai import OpenAI
from pyvis.network import Network

from retrieve import get_latest_dir

load_dotenv()

# Module-level setup shared by every index built here: deterministic
# (temperature=0) GPT-3.5 plus an in-memory graph store.
llm = OpenAI(
    temperature=0.0, model="gpt-3.5-turbo", api_key=os.getenv("OPENAI_API_KEY")
)
graph_store = SimpleGraphStore()
storage_context = StorageContext.from_defaults(graph_store=graph_store)
service_context = ServiceContext.from_defaults(
    llm=llm, chunk_size=2048, chunk_overlap=24
)


def create_document(input_dir: str) -> List[Document]:
    """
    Create documents from the given directory.

    Args:
        input_dir (str): The input directory to read the documents from.
            Only non-hidden ``.json`` files are read.

    Returns:
        List[Document]: The list of documents from the directory.
    """
    reader = SimpleDirectoryReader(
        input_dir, exclude_hidden=True, required_exts=[".json"]
    )
    products_document: List[Document] = []
    for docs in reader.iter_data():
        products_document.extend(docs)
    return products_document


def kg_triplet_extract_fn(text: str) -> List[tuple]:
    """
    Extract (subject, predicate, object) triplets from a document chunk.

    The chunk is expected to end with a JSON object (after the last blank
    line) describing one product; every key/value pair becomes a triplet
    anchored on the product's ``name`` field.

    Args:
        text (str): The text to extract the triplets from.

    Returns:
        List[tuple]: Triplets of the form (product_name, key, value).
            NOTE(review): values are passed through as-is; non-string JSON
            values (numbers, lists) are not coerced — confirm downstream
            graph store accepts them.

    Raises:
        json.JSONDecodeError: If the trailing segment is not valid JSON.
        KeyError: If the JSON object has no "name" field.
    """
    # The JSON payload is assumed to follow the final blank-line separator.
    json_text = text.split("\n\n")[-1]
    product_spec = json.loads(json_text)
    product_name = product_spec.pop("name")
    return [(product_name, key, value) for key, value in product_spec.items()]


def generate_graph_visualization(kg_index: KnowledgeGraphIndex) -> str:
    """
    Generate an HTML graph visualization from the KG index.

    Args:
        kg_index (KnowledgeGraphIndex): The Knowledge Graph index to
            generate the visualization from.

    Returns:
        str: The path to the generated graph visualization.
    """
    output_directory = os.getenv("GRAPH_VIS_DIR", "graph_vis")
    # Fix: net.save_graph fails if the target directory does not exist yet.
    os.makedirs(output_directory, exist_ok=True)

    # Timestamped filename so successive runs never overwrite each other.
    timestamp = datetime.now().strftime("%Y%m%d%H%M%S")
    graph_output_file = f"{timestamp}.html"
    graph_output_path = os.path.join(output_directory, graph_output_file)

    g = kg_index.get_networkx_graph(limit=20000)
    net = Network(
        notebook=False,
        cdn_resources="remote",
        height="800px",
        width="100%",
        select_menu=True,
        filter_menu=False,
    )
    net.from_nx(g)
    net.force_atlas_2based(central_gravity=0.015, gravity=-31)
    net.save_graph(graph_output_path)
    print(f"Graph visualization saved to: {graph_output_path}")
    return graph_output_path


def plot_subgraph(triplets: str) -> str:
    """
    Plot a subgraph from the triplets.

    Args:
        triplets (str): String repr of a list whose items are themselves
            stringified (source, action, target) tuples.

    Returns:
        str: The escaped HTML content to display the subgraph.
    """
    G = nx.DiGraph()
    # Security fix: ast.literal_eval replaces eval(). The input is a string
    # repr of literal lists/tuples, which literal_eval parses safely without
    # executing arbitrary code.
    for edge_str in ast.literal_eval(triplets):
        source, action, target = ast.literal_eval(edge_str)
        G.add_edge(source, target, label=action)

    net = Network(notebook=True, cdn_resources="remote", height="400px", width="100%")
    net.from_nx(G)
    net.force_atlas_2based(central_gravity=0.015, gravity=-31)
    html_content = net.generate_html()
    # Escaped so the snippet can be embedded (e.g. inside an iframe srcdoc).
    escaped_html = html.escape(html_content)
    return escaped_html


def create_kg(max_features: int = 60):
    """
    Create a Knowledge Graph from the configured product-spec directory.

    Args:
        max_features (int): The maximum number of triplets per chunk.

    Returns:
        tuple: (KnowledgeGraphIndex, path to the generated visualization).
    """
    input_dir = os.getenv("PROD_SPEC_DIR", "prod_spec")
    product_documents = create_document(input_dir)
    kg_index = KnowledgeGraphIndex.from_documents(
        documents=product_documents,
        max_triplets_per_chunk=max_features,
        storage_context=storage_context,
        service_context=service_context,
        show_progress=True,
        include_embeddings=True,
        kg_triplet_extract_fn=kg_triplet_extract_fn,
    )
    graphvis_path = generate_graph_visualization(kg_index)
    return kg_index, graphvis_path


def persist_kg(kg_index: KnowledgeGraphIndex) -> str:
    """
    Persist the KG index to storage under a timestamped subdirectory.

    Args:
        kg_index (KnowledgeGraphIndex): The Knowledge Graph index to persist.

    Returns:
        str: The path to the persisted KG index.
    """
    output_dir = os.getenv("GRAPH_DIR", "graphs")
    timestamp = datetime.now().strftime("%Y%m%d%H%M%S")
    kg_path = os.path.join(output_dir, timestamp)
    kg_index.storage_context.persist(kg_path)
    return kg_path


def load_kg(kg_dir: str) -> KnowledgeGraphIndex:
    """
    Load the most recently persisted KG index from the given directory.

    Args:
        kg_dir (str): The parent directory to load the KG index from.

    Returns:
        KnowledgeGraphIndex: The loaded Knowledge Graph index.
    """
    # get_latest_dir picks the newest timestamped subdirectory.
    kg_path = get_latest_dir(kg_dir)
    kg_index = load_index_from_storage(
        StorageContext.from_defaults(persist_dir=kg_path)
    )
    return kg_index


def query(kg_dir: str, query: str):
    """
    Query the KG index for a given query.

    Args:
        kg_dir (str): The directory to load the KG index from.
        query (str): The query to ask the KG index.

    Returns:
        Response: The response from the KG index.
    """
    kg_index = load_kg(kg_dir)
    query_engine = kg_index.as_query_engine(
        include_text=True,
        response_mode="refine",
        graph_store_query_depth=6,
        similarity_top_k=5,
    )
    response = query_engine.query(query)
    return response


def query_graph_qa(graph_rag_index, query, search_level):
    """
    Query the Graph-RAG model for a given query.

    Args:
        graph_rag_index (KnowledgeGraphIndex): The Graph-RAG model index.
        query (str): The query to ask the Graph-RAG model.
        search_level (int): The max search level (similarity_top_k) to use.

    Returns:
        tuple: (response, reference, reference_text) where reference is the
            ``kg_rel_texts`` found in the response metadata (empty list if
            none) and reference_text is the text of each retrieved node.
    """
    myretriever = graph_rag_index.as_retriever(
        include_text=True,
        similarity_top_k=search_level,
    )
    query_engine = graph_rag_index.as_query_engine(
        sub_retrievers=[
            myretriever,
        ],
        graph_store_query_depth=6,
        include_text=True,
        similarity_top_k=search_level,
    )
    response = query_engine.query(query)
    nodes = myretriever.retrieve(query)

    # First metadata entry carrying kg_rel_texts wins; absent -> empty list.
    reference = []
    for _, value in response.metadata.items():
        if isinstance(value, dict) and "kg_rel_texts" in value:
            reference = value["kg_rel_texts"]
            break

    reference_text = [node.text for node in nodes]
    return response, reference, reference_text


if __name__ == "__main__":
    kg_index, graphvis_path = create_kg()
    persist_kg(kg_index)
    kg_index = load_kg(os.getenv("GRAPH_DIR", "graphs"))
    generate_graph_visualization(kg_index)
    response = query(
        os.getenv("GRAPH_DIR", "graphs"),
        "Tell me the Built-in memory in Apple iPhone 15 Pro Max 256Gb Blue Titanium?",
    )
    print(response)
    key = list(response.metadata)[-1]
    print(response.metadata[key])