# GraphRAG / app.py
# Web file-viewer residue, preserved as comments so that the module parses:
#   girishwangikar's picture
#   Update app.py
#   fc5da88 verified
#   raw
#   history blame
#   5.47 kB
import os
import gradio as gr
import networkx as nx
import matplotlib.pyplot as plt
from langchain_experimental.graph_transformers import LLMGraphTransformer
from langchain.chains import GraphQAChain
from langchain_core.documents import Document
from langchain_community.graphs.networkx_graph import NetworkxEntityGraph
from langchain_core.prompts import ChatPromptTemplate
from langchain_groq import ChatGroq
import pandas as pd
from gradio_client import Client
import numpy as np
from PIL import Image as PILImage
import base64
from io import BytesIO
# Base directory: the process's current working directory. The graph and
# generated-image PNG files are written here.
BASE_DIR = os.getcwd()
# Groq API key read from the environment; this is None when GROQ_API_KEY is unset.
GROQ_API_KEY = os.environ.get('GROQ_API_KEY')
# Set up LLM and Flux client. Both objects are created once, at import time,
# and shared by every function in this module.
llm = ChatGroq(temperature=0, model_name='llama-3.1-8b-instant', groq_api_key=GROQ_API_KEY)
flux_client = Client("black-forest-labs/Flux.1-schnell")
def create_graph(text):
    """Build an entity graph from ``text`` using the module-level LLM.

    The text is wrapped in a single ``Document`` and passed through
    ``LLMGraphTransformer``. The nodes and relationships of the first
    resulting graph document are copied into a ``NetworkxEntityGraph``;
    each edge stores the relationship type under the ``relation`` attribute.

    Args:
        text: Source text to extract the graph from.

    Returns:
        A ``(graph, graph_documents)`` tuple: the populated
        ``NetworkxEntityGraph`` and the list returned by the transformer.
    """
    source_documents = [Document(page_content=text)]
    transformer = LLMGraphTransformer(llm=llm)
    graph_documents = transformer.convert_to_graph_documents(source_documents)

    entity_graph = NetworkxEntityGraph()
    first_document = graph_documents[0]

    for entity in first_document.nodes:
        entity_graph.add_node(entity.id)

    # Edges are added on the wrapped networkx graph so that the relationship
    # type can be attached as an edge attribute.
    nx_graph = entity_graph._graph
    for relationship in first_document.relationships:
        nx_graph.add_edge(
            relationship.source.id,
            relationship.target.id,
            relation=relationship.type,
        )

    return entity_graph, graph_documents
def visualize_graph(graph):
    """Render the entity graph to a PNG file and return the file's path.

    The graph is laid out with ``nx.spring_layout``, drawn with node labels,
    and annotated with each edge's ``relation`` attribute.

    Args:
        graph: Object exposing the networkx graph to draw as ``graph._graph``.

    Returns:
        Path of the written image, ``<BASE_DIR>/graph_visualization.png``.
        The file name is fixed, so every call overwrites the previous image.
    """
    figure = plt.figure(figsize=(12, 8))
    try:
        nx_graph = graph._graph
        pos = nx.spring_layout(nx_graph)
        nx.draw(
            nx_graph,
            pos,
            with_labels=True,
            node_color='lightblue',
            node_size=500,
            font_size=8,
            font_weight='bold',
        )
        edge_labels = nx.get_edge_attributes(nx_graph, 'relation')
        nx.draw_networkx_edge_labels(nx_graph, pos, edge_labels=edge_labels, font_size=6)
        plt.title("Graph Visualization")
        plt.axis('off')

        # Save the plot as an image file
        graph_viz_path = os.path.join(BASE_DIR, 'graph_visualization.png')
        figure.savefig(graph_viz_path)
    finally:
        # Always release the figure: without this, a failure while laying out,
        # drawing or saving would leave it registered with pyplot.
        plt.close(figure)
    return graph_viz_path
def generate_image(prompt):
    """Generate an image for ``prompt`` via the Flux client and save it as PNG.

    The image payload is the first element of the prediction result when a
    tuple is returned, otherwise the result itself. Supported payloads:

    * ``str`` naming an existing file: opened directly as an image file;
    * any other ``str``: treated as base64 image data (an optional
      ``data:...,`` prefix is dropped and missing padding is restored);
    * ``numpy.ndarray``: float arrays are taken as values in [0, 1] and
      scaled to 0-255; other dtypes are converted to ``uint8`` unchanged;
    * ``PIL.Image.Image``: used as is.

    Args:
        prompt: Text prompt sent to the ``/infer`` endpoint.

    Returns:
        Path of the saved image, ``<BASE_DIR>/generated_image.png``, or
        ``None`` if anything failed (the error and traceback are printed).
        The file name is fixed, so every call overwrites the previous image.
    """
    import traceback

    try:
        print(f"Generating image with prompt: {prompt}")
        result = flux_client.predict(
            prompt=prompt,
            seed=0,
            randomize_seed=True,
            width=1024,
            height=1024,
            num_inference_steps=4,
            api_name="/infer",
        )

        if isinstance(result, tuple) and len(result) > 0:
            payload = result[0]
        else:
            payload = result

        if isinstance(payload, str):
            if os.path.isfile(payload):
                # The client handed back a local file path; base64-decoding a
                # path would only yield garbage bytes, so open the file itself.
                with PILImage.open(payload) as downloaded:
                    image = downloaded.copy()
            else:
                img_str = payload
                if img_str.startswith('data:') and ',' in img_str:
                    img_str = img_str.split(',', 1)[1]
                # Restore any stripped base64 padding before decoding.
                img_str += '=' * (-len(img_str) % 4)
                img_data = base64.b64decode(img_str)
                image = PILImage.open(BytesIO(img_data))
        elif isinstance(payload, np.ndarray):
            pixels = payload
            if np.issubdtype(pixels.dtype, np.floating):
                pixels = np.clip(pixels, 0.0, 1.0) * 255
            # NOTE(review): non-float arrays are assumed to already hold
            # 0-255 values; scaling them by 255 would wrap around in uint8.
            image = PILImage.fromarray(pixels.astype(np.uint8))
        elif isinstance(payload, PILImage.Image):
            image = payload
        else:
            raise ValueError(f"Unexpected result format from flux_client.predict: {type(result)}")

        image_path = os.path.join(BASE_DIR, 'generated_image.png')
        image.save(image_path)
        print(f"Image saved to: {image_path}")
        return image_path
    except Exception as e:
        # Best effort: image generation must not break the rest of the
        # pipeline, so report the failure and signal it with None.
        print(f"Error in generate_image: {str(e)}")
        traceback.print_exc()
        return None
def process_text(text, question):
    """Run the full pipeline: graph extraction, graph QA, summary and image.

    Steps, in order: build the entity graph from ``text``, answer
    ``question`` with a ``GraphQAChain`` over that graph, render the graph
    to an image, ask the LLM for a one-sentence summary of ``text``, and
    generate an image from that summary. Progress is printed at each step.

    Args:
        text: Source text the graph and the summary are derived from.
        question: Question answered against the extracted graph.

    Returns:
        ``(answer, graph_viz_path, summary, image_path)``. ``image_path`` is
        ``None`` when image generation fails. If the input text is empty or
        any step raises, the result is ``(message, None, message, None)``
        where ``message`` starts with ``"An error occurred"``, so callers can
        recognise a failure from the prefix.
    """
    import traceback

    if text is None or not str(text).strip():
        # Nothing to extract or summarise: skip the LLM calls entirely.
        error_message = "An error occurred in process_text: input text is empty"
        print(error_message)
        return error_message, None, error_message, None

    try:
        print("Creating graph...")
        graph, graph_documents_filtered = create_graph(text)

        print("Setting up GraphQAChain...")
        graph_rag = GraphQAChain.from_llm(
            llm=llm,
            graph=graph,
            verbose=True,
        )

        print("Running question through GraphQAChain...")
        answer = graph_rag.run(question)
        print(f"Answer: {answer}")

        print("Visualizing graph...")
        graph_viz_path = visualize_graph(graph)
        print(f"Graph visualization saved to: {graph_viz_path}")

        print("Generating summary...")
        summary_prompt = f"Summarize the following text in one sentence: {text}"
        summary = llm.invoke(summary_prompt).content
        print(f"Summary: {summary}")

        print("Generating image...")
        image_path = generate_image(summary)
        if image_path and os.path.exists(image_path):
            print(f"Generated image saved to: {image_path}")
        else:
            print("Failed to generate or save image")

        return answer, graph_viz_path, summary, image_path
    except Exception as e:
        # Return the same prefixed message that is logged; a bare str(e)
        # cannot be told apart from a normal answer by the caller.
        error_message = f"An error occurred in process_text: {str(e)}"
        print(error_message)
        traceback.print_exc()
        return error_message, None, error_message, None
def ui_function(text, question):
    """Gradio callback: run ``process_text`` and shape its result for the UI.

    When the answer is a string beginning with ``"An error occurred"``, the
    answer is also shown in the summary slot and both image slots are
    cleared; otherwise the four values are passed through unchanged.

    Args:
        text: Contents of the "Input Text" box.
        question: Contents of the "Question" box.

    Returns:
        ``(answer, graph_viz_path, summary, image_path)``.
    """
    answer, graph_viz_path, summary, image_path = process_text(text, question)

    failed = isinstance(answer, str) and answer.startswith("An error occurred")
    if failed:
        graph_viz_path, summary, image_path = None, answer, None

    return answer, graph_viz_path, summary, image_path
# Build the Gradio interface: the two text inputs are passed to ui_function,
# and its four return values fill the four outputs in the order listed.
iface = gr.Interface(
    title="GraphRAG and Image Generation UI",
    description="Enter text to create a graph, ask a question, and generate a relevant image.",
    fn=ui_function,
    inputs=[
        gr.Textbox(label="Input Text"),
        gr.Textbox(label="Question"),
    ],
    outputs=[
        gr.Textbox(label="Answer"),
        gr.Image(label="Graph Visualization", type="filepath"),
        gr.Textbox(label="Summary"),
        gr.Image(label="Generated Image", type="filepath"),
    ],
)

# Launch the app as soon as the module is executed.
iface.launch()