"""Gradio UI: build a knowledge graph from text, answer a question over it
with GraphQAChain, summarize the text, and generate an illustrative image
with the Flux.1-schnell Space."""

import base64
import os
import traceback
from io import BytesIO

import gradio as gr
import matplotlib
import networkx as nx
import numpy as np
import pandas as pd
from gradio_client import Client
from langchain.chains import GraphQAChain
from langchain_community.graphs.networkx_graph import NetworkxEntityGraph
from langchain_core.documents import Document
from langchain_core.prompts import ChatPromptTemplate
from langchain_experimental.graph_transformers import LLMGraphTransformer
from langchain_groq import ChatGroq
from PIL import Image as PILImage

# Render plots off-screen: Gradio runs handlers in worker threads, where GUI
# backends (e.g. TkAgg) are unsafe, and servers usually have no display.
matplotlib.use("Agg")
import matplotlib.pyplot as plt  # noqa: E402  (must follow matplotlib.use)

# Set the base directory (output images are written here)
BASE_DIR = os.getcwd()

GROQ_API_KEY = os.environ.get('GROQ_API_KEY')
if not GROQ_API_KEY:
    raise RuntimeError("GROQ_API_KEY environment variable is not set.")

# Prefix shared by process_text (producer) and ui_function (consumer) so the
# UI can recognise a failed run.
ERROR_PREFIX = "An error occurred"

# Set up LLM and Flux client
llm = ChatGroq(temperature=0, model_name='llama-3.1-8b-instant', groq_api_key=GROQ_API_KEY)
flux_client = Client("black-forest-labs/Flux.1-schnell")


def create_graph(text):
    """Extract entities/relations from *text* with the LLM into a networkx graph.

    Args:
        text: Free-form input text.

    Returns:
        A ``(graph, graph_documents)`` tuple. ``graph`` is a
        ``NetworkxEntityGraph``; each edge carries its relationship type in the
        ``relation`` attribute, which is what ``GraphQAChain`` reads as the
        triple's predicate. ``graph_documents`` is the raw transformer output.
    """
    documents = [Document(page_content=text)]
    llm_transformer = LLMGraphTransformer(llm=llm)
    graph_documents = llm_transformer.convert_to_graph_documents(documents)

    graph = NetworkxEntityGraph()
    # Iterate every returned document: indexing [0] crashed when the
    # transformer produced nothing.
    for graph_document in graph_documents:
        for node in graph_document.nodes:
            graph.add_node(node.id)
        for edge in graph_document.relationships:
            graph._graph.add_edge(
                edge.source.id,
                edge.target.id,
                relation=edge.type,
            )
    return graph, graph_documents


def visualize_graph(graph):
    """Draw *graph* with edge labels and save it as a PNG.

    Returns:
        Path of the saved image (``BASE_DIR/graph_visualization.png``).
    """
    fig, ax = plt.subplots(figsize=(12, 8))
    try:
        pos = nx.spring_layout(graph._graph)
        nx.draw(graph._graph, pos, ax=ax, with_labels=True, node_color='lightblue',
                node_size=500, font_size=8, font_weight='bold')
        edge_labels = nx.get_edge_attributes(graph._graph, 'relation')
        nx.draw_networkx_edge_labels(graph._graph, pos, edge_labels=edge_labels,
                                     font_size=6, ax=ax)
        ax.set_title("Graph Visualization")
        ax.axis('off')

        # Save the plot as an image file
        graph_viz_path = os.path.join(BASE_DIR, 'graph_visualization.png')
        fig.savefig(graph_viz_path)
    finally:
        # Always release the figure, even if drawing failed.
        plt.close(fig)
    return graph_viz_path


def _flux_result_to_image(result):
    """Convert a Flux ``/infer`` response into a PIL image.

    Accepts a PIL image, or a non-empty tuple/list whose first element is a
    local file path, a base64 string (optionally a ``data:`` URL), or a numpy
    array. Raises ValueError for anything else.
    """
    if isinstance(result, PILImage.Image):
        return result
    if isinstance(result, (tuple, list)) and len(result) > 0:
        payload = result[0]
        if isinstance(payload, str):
            # gradio_client downloads image outputs and returns a file path.
            if os.path.isfile(payload):
                with PILImage.open(payload) as img:
                    return img.copy()
            # Otherwise treat it as inline base64, dropping a data-URL header.
            encoded = payload.split(",", 1)[1] if payload.startswith("data:") else payload
            encoded += '=' * (-len(encoded) % 4)
            return PILImage.open(BytesIO(base64.b64decode(encoded)))
        if isinstance(payload, np.ndarray):
            # Only float arrays are in [0, 1]; scaling integer arrays by 255
            # would overflow uint8.
            if np.issubdtype(payload.dtype, np.floating):
                payload = np.clip(payload * 255, 0, 255)
            return PILImage.fromarray(payload.astype(np.uint8))
    raise ValueError(f"Unexpected result format from flux_client.predict: {type(result)}")


def generate_image(prompt):
    """Generate an image for *prompt* with the Flux Space and save it as PNG.

    Returns:
        Path of the saved image (``BASE_DIR/generated_image.png``), or ``None``
        if generation or decoding failed (errors are printed, not raised).
    """
    try:
        print(f"Generating image with prompt: {prompt}")
        result = flux_client.predict(
            prompt=prompt,
            seed=0,
            randomize_seed=True,
            width=1024,
            height=1024,
            num_inference_steps=4,
            api_name="/infer"
        )
        image = _flux_result_to_image(result)

        image_path = os.path.join(BASE_DIR, 'generated_image.png')
        image.save(image_path)
        print(f"Image saved to: {image_path}")
        return image_path
    except Exception as e:
        # Best-effort: the rest of the pipeline still returns without an image.
        print(f"Error in generate_image: {str(e)}")
        traceback.print_exc()
        return None


def process_text(text, question):
    """Run graph construction, graph QA, visualization, summary and image generation.

    Returns:
        ``(answer, graph_viz_path, summary, image_path)``. On failure, answer
        and summary are both a message starting with ``ERROR_PREFIX`` and both
        paths are ``None``. ``image_path`` may be ``None`` on its own if only
        image generation failed.
    """
    try:
        print("Creating graph...")
        graph, _ = create_graph(text)

        print("Setting up GraphQAChain...")
        graph_rag = GraphQAChain.from_llm(
            llm=llm,
            graph=graph,
            verbose=True
        )

        print("Running question through GraphQAChain...")
        # Chain.run is deprecated; GraphQAChain's input key is "query" and its
        # output key is "result".
        answer = graph_rag.invoke({"query": question})["result"]
        print(f"Answer: {answer}")

        print("Visualizing graph...")
        graph_viz_path = visualize_graph(graph)
        print(f"Graph visualization saved to: {graph_viz_path}")

        print("Generating summary...")
        summary_prompt = f"Summarize the following text in one sentence: {text}"
        summary = llm.invoke(summary_prompt).content
        print(f"Summary: {summary}")

        print("Generating image...")
        image_path = generate_image(summary)

        if image_path and os.path.exists(image_path):
            print(f"Generated image saved to: {image_path}")
        else:
            print("Failed to generate or save image")

        return answer, graph_viz_path, summary, image_path
    except Exception as e:
        print(f"An error occurred in process_text: {str(e)}")
        traceback.print_exc()
        # Prefix the message so ui_function can detect the failure.
        message = f"{ERROR_PREFIX}: {e}"
        return message, None, message, None


def ui_function(text, question):
    """Gradio handler: validate inputs, run the pipeline, and blank outputs on error."""
    if not (text and text.strip()) or not (question and question.strip()):
        message = f"{ERROR_PREFIX}: please provide both input text and a question."
        return message, None, message, None

    answer, graph_viz_path, summary, image_path = process_text(text, question)
    if isinstance(answer, str) and answer.startswith(ERROR_PREFIX):
        return answer, None, answer, None
    return answer, graph_viz_path, summary, image_path


# Create Gradio interface
iface = gr.Interface(
    fn=ui_function,
    inputs=[
        gr.Textbox(label="Input Text"),
        gr.Textbox(label="Question")
    ],
    outputs=[
        gr.Textbox(label="Answer"),
        gr.Image(label="Graph Visualization", type="filepath"),
        gr.Textbox(label="Summary"),
        gr.Image(label="Generated Image", type="filepath")
    ],
    title="GraphRAG and Image Generation UI",
    description="Enter text to create a graph, ask a question, and generate a relevant image."
)

if __name__ == "__main__":
    iface.launch()