import base64
import io
import tempfile

import gradio as gr
import matplotlib.pyplot as plt
import networkx as nx
from matplotlib.figure import Figure
from neo4j import GraphDatabase


class Neo4jGraphVisualizer:
    """Fetch the graph stored in a Neo4j database and render it as a PNG image."""

    # elementId() is used instead of the deprecated ID() function.
    _NODES_QUERY = """
        MATCH (n)
        RETURN elementId(n) as id, labels(n) as labels, properties(n) as properties
    """
    _RELATIONSHIPS_QUERY = """
        MATCH (a)-[r]->(b)
        RETURN elementId(a) as source_id, elementId(b) as target_id,
               type(r) as relationship_type, properties(r) as relationship_properties
    """

    # Fill colours assigned to Neo4j node labels in order of first appearance;
    # the palette is cycled when there are more labels than colours.
    _LABEL_COLORS = (
        'lightblue', 'lightgreen', 'lightsalmon', 'plum',
        'khaki', 'paleturquoise', 'pink', 'lightgray',
    )

    def __init__(self, uri, username, password):
        """
        Initialize the Neo4j graph database connection

        Args:
            uri (str): Neo4j database URI
            username (str): Neo4j username
            password (str): Neo4j password
        """
        self.driver = GraphDatabase.driver(uri, auth=(username, password))

    def fetch_graph_data(self):
        """
        Fetch every node and every directed relationship from the database

        Returns:
            dict: ``{'nodes': [...], 'relationships': [...]}``.
            Each node is ``{'id', 'label', 'properties'}`` where ``label`` is the
            node's first Neo4j label (or ``'Unknown'`` if it has none).
            Each relationship is ``{'source', 'target', 'type', 'properties'}``.
            Ids are the values returned by Cypher ``elementId()``.
        """
        with self.driver.session() as session:
            # Each comprehension fully consumes its result before the next
            # query is run on the same session.
            nodes = [
                {
                    'id': record['id'],
                    'label': record['labels'][0] if record['labels'] else 'Unknown',
                    'properties': dict(record['properties']),
                }
                for record in session.run(self._NODES_QUERY)
            ]
            relationships = [
                {
                    'source': record['source_id'],
                    'target': record['target_id'],
                    'type': record['relationship_type'],
                    'properties': dict(record.get('relationship_properties', {})),
                }
                for record in session.run(self._RELATIONSHIPS_QUERY)
            ]
        return {'nodes': nodes, 'relationships': relationships}

    @staticmethod
    def _build_graph(graph_data):
        """Convert ``fetch_graph_data()`` output into a ``networkx.DiGraph``.

        Every node gets three attributes: ``label`` (display text: the ``name``
        property, else the id), ``neo4j_label`` and ``properties``.
        """
        graph = nx.DiGraph()
        for node in graph_data['nodes']:
            properties = node.get('properties', {})
            graph.add_node(
                node['id'],
                label=properties.get('name', str(node['id'])),
                neo4j_label=node.get('label', 'Unknown'),
                properties=properties,
            )
        for rel in graph_data['relationships']:
            # Nodes and relationships are fetched by two separate queries, so an
            # endpoint may be missing from the node list (e.g. created in
            # between). Register it with defaults so the drawing code can rely
            # on every node having the attributes above.
            for endpoint in (rel['source'], rel['target']):
                if endpoint not in graph:
                    graph.add_node(
                        endpoint,
                        label=str(endpoint),
                        neo4j_label='Unknown',
                        properties={},
                    )
            graph.add_edge(
                rel['source'],
                rel['target'],
                type=rel['type'],
                properties=rel['properties'],
            )
        return graph

    def _draw_graph(self, graph, ax):
        """Draw a non-empty graph onto ``ax`` using a spring layout."""
        pos = nx.spring_layout(graph, k=0.9, iterations=50)
        nodelist = list(graph.nodes())

        # Node size grows with the length of the node's stringified properties.
        node_sizes = [
            300 + len(str(graph.nodes[n]['properties'])) * 10 for n in nodelist
        ]

        # One colour per Neo4j label, so nodes of the same kind share a colour.
        label_colors = {}
        for n in nodelist:
            neo4j_label = graph.nodes[n]['neo4j_label']
            if neo4j_label not in label_colors:
                palette_index = len(label_colors) % len(self._LABEL_COLORS)
                label_colors[neo4j_label] = self._LABEL_COLORS[palette_index]
        node_colors = [label_colors[graph.nodes[n]['neo4j_label']] for n in nodelist]

        nx.draw_networkx_nodes(
            graph, pos, ax=ax, nodelist=nodelist,
            node_color=node_colors, node_size=node_sizes, alpha=0.8,
        )
        # Pass the per-node sizes so arrowheads stop at each node's boundary
        # instead of being hidden underneath the enlarged nodes.
        nx.draw_networkx_edges(
            graph, pos, ax=ax, nodelist=nodelist, node_size=node_sizes,
            edge_color='gray', arrows=True, width=1.5,
        )
        nx.draw_networkx_labels(
            graph, pos, ax=ax,
            labels={n: graph.nodes[n]['label'] for n in nodelist},
            font_size=8,
        )

        # Empty scatters act purely as legend handles for the label colours.
        for neo4j_label, color in label_colors.items():
            ax.scatter([], [], color=color, label=neo4j_label)
        ax.legend(title="Node label", fontsize=8, loc='best')

    def render_png(self, dpi=300):
        """
        Render the graph currently stored in Neo4j as PNG bytes

        Unlike ``visualize_graph()``, errors (connection, query or drawing)
        propagate to the caller.

        Args:
            dpi (int): Output resolution; the figure is 16x12 inches.

        Returns:
            bytes: The PNG-encoded image.
        """
        graph = self._build_graph(self.fetch_graph_data())

        # A standalone Figure (instead of pyplot) is not tracked by pyplot's
        # global figure manager, so nothing leaks if drawing raises. It also
        # avoids pyplot's shared "current figure" state, which is unsafe for
        # concurrent requests.
        fig = Figure(figsize=(16, 12))
        ax = fig.subplots()
        if graph.number_of_nodes() == 0:
            ax.text(
                0.5, 0.5, "No nodes found in the database",
                ha='center', va='center', transform=ax.transAxes,
            )
        else:
            self._draw_graph(graph, ax)
        ax.set_title("Neo4j Graph Visualization")
        ax.set_axis_off()

        buffer = io.BytesIO()
        fig.savefig(buffer, format='png', dpi=dpi, bbox_inches='tight')
        return buffer.getvalue()

    def visualize_graph(self):
        """
        Visualize the graph using NetworkX and Matplotlib

        Returns:
            str: A ``data:image/png;base64,...`` URL on success. On failure the
            exception is printed, not raised, and an error message string is
            returned instead.
        """
        try:
            image_png = self.render_png()
        except Exception as e:
            print(f"Error in graph visualization: {e}")
            return f"Error visualizing graph: {e}"
        graphic = base64.b64encode(image_png).decode('utf-8')
        return f"data:image/png;base64,{graphic}"

    def close(self):
        """Close the Neo4j driver connection"""
        self.driver.close()


def create_gradio_interface(uri, username, password):
    """
    Create a Gradio interface for Neo4j graph visualization

    Args:
        uri (str): Neo4j database URI
        username (str): Neo4j username
        password (str): Neo4j password

    Returns:
        tuple: ``(gr.Interface, Neo4jGraphVisualizer)``. The caller must call
        ``close()`` on the visualizer when done.
    """
    visualizer = Neo4jGraphVisualizer(uri, username, password)

    def visualize_graph():
        """Render the graph to a temporary PNG file and return its path."""
        try:
            image_png = visualizer.render_png()
        except Exception as e:
            # gr.Error shows the message in the UI. Returning the message as
            # the output would make gr.Image try to open it as a file path.
            raise gr.Error(f"Error visualizing graph: {e}") from e
        # gr.Image(type="filepath") expects a path on disk, not a data URL.
        # NOTE(review): these temp files are not deleted here; confirm whether
        # periodic cleanup of the temp directory is needed.
        with tempfile.NamedTemporaryFile(suffix='.png', delete=False) as tmp:
            tmp.write(image_png)
        return tmp.name

    iface = gr.Interface(
        fn=visualize_graph,
        inputs=[],
        outputs=gr.Image(type="filepath"),
        title="Neo4j Graph Visualization",
        description="Visualize graph data from Neo4j database",
    )
    return iface, visualizer
# Configuration is read from environment variables so that credentials are
# never committed to source control. The password that used to be hard-coded
# here has been exposed and should be rotated in the Neo4j Aura console.
import os

NEO4J_URI = os.environ.get("NEO4J_URI", "neo4j+s://b96332bd.databases.neo4j.io")
NEO4J_USERNAME = os.environ.get("NEO4J_USERNAME", "neo4j")
NEO4J_PASSWORD = os.environ.get("NEO4J_PASSWORD")  # deliberately no default
AURA_INSTANCEID = os.environ.get("AURA_INSTANCEID", "b96332bd")
AURA_INSTANCENAME = os.environ.get("AURA_INSTANCENAME", "Instance01")


def main():
    """Launch the Gradio UI on 0.0.0.0:7860 and close the Neo4j driver on exit.

    Raises:
        SystemExit: if ``NEO4J_PASSWORD`` is not configured.
    """
    if not NEO4J_PASSWORD:
        raise SystemExit(
            "NEO4J_PASSWORD environment variable is not set; "
            "export it (and optionally NEO4J_URI / NEO4J_USERNAME) before starting."
        )

    interface, visualizer = create_gradio_interface(
        NEO4J_URI,
        NEO4J_USERNAME,
        NEO4J_PASSWORD,
    )
    try:
        # launch() blocks until the server is shut down.
        interface.launch(server_name='0.0.0.0', server_port=7860)
    except Exception as e:
        print(f"Error launching interface: {e}")
    finally:
        # Ensure the driver is closed however the server stops.
        visualizer.close()


if __name__ == "__main__":
    main()