# mcrlab_gemini2_test / flowchart.py
# Rahatara: "Rename app.py to flowchart.py" (commit 85963c6, verified)
import os
from PIL import Image, ImageDraw, ImageFont
import json
import gradio as gr
from google import genai
from google.genai import types
# Gemini model used for every flowchart request.
model_name = "gemini-2.0-flash-exp"
# Shared Gemini client. Indexing os.environ raises KeyError at import time
# when GOOGLE_API_KEY is unset, so the app fails fast without credentials.
client = genai.Client(api_key=os.environ["GOOGLE_API_KEY"])
# Function to parse JSON output from Gemini
def parse_json(json_output):
    """
    Parse JSON output from the Gemini model.

    The model often wraps its answer in a Markdown code fence
    (```json ... ``` or a bare ``` ... ```). If such a fence line is found,
    only the text between it and the next closing fence is decoded;
    otherwise the whole string is decoded as-is.

    Args:
        json_output: Raw text returned by the model (may be None).

    Returns:
        The decoded JSON value, or an empty dict when the input is missing
        or is not valid JSON.
    """
    try:
        lines = json_output.splitlines()
        for i, line in enumerate(lines):
            # Tolerate surrounding whitespace, an upper-case tag and a bare fence.
            if line.strip().lower() in ("```json", "```"):
                json_output = "\n".join(lines[i + 1:])  # drop everything before the fence
                json_output = json_output.split("```")[0]  # drop the closing fence onward
                break
        return json.loads(json_output)
    except (AttributeError, TypeError, ValueError) as e:
        # AttributeError: no text at all (e.g. None); ValueError covers
        # json.JSONDecodeError. Best-effort parser: report and return {}.
        print(f"Error parsing JSON: {e}")
        return {}
# Function to draw a flowchart
def _arrowhead_points(start, end, size=10):
    """
    Return the three vertices of a triangular arrowhead whose tip is at *end*.

    The triangle is aligned with the segment start -> end, so the head points
    along the connector. Returns None for a zero-length segment, which has no
    direction.
    """
    (x1, y1), (x2, y2) = start, end
    dx, dy = x2 - x1, y2 - y1
    length = (dx * dx + dy * dy) ** 0.5
    if length == 0:
        return None
    ux, uy = dx / length, dy / length  # unit vector along the connector
    base_x, base_y = x2 - ux * size, y2 - uy * size  # middle of the back edge
    half = size / 2
    return [
        (x2, y2),
        (base_x - uy * half, base_y + ux * half),
        (base_x + uy * half, base_y - ux * half),
    ]


def _connection_endpoints(src, dst):
    """
    Pick the edge midpoints used to join shape *src* to shape *dst*.

    A target entirely below the source is joined bottom-to-top and one
    entirely above is joined top-to-bottom. A target that overlaps the
    source's rows is joined side-to-side.
    """
    sx, sy, sw, sh = src["x"], src["y"], src["width"], src["height"]
    tx, ty, tw, th = dst["x"], dst["y"], dst["width"], dst["height"]
    if ty >= sy + sh:  # target entirely below the source
        return (sx + sw // 2, sy + sh), (tx + tw // 2, ty)
    if ty + th <= sy:  # target entirely above the source
        return (sx + sw // 2, sy), (tx + tw // 2, ty + th)
    if tx + tw // 2 >= sx + sw // 2:  # same row, target to the right
        return (sx + sw, sy + sh // 2), (tx, ty + th // 2)
    return (sx, sy + sh // 2), (tx + tw, ty + th // 2)  # same row, target to the left


def draw_flowchart(image, flowchart_json):
    """
    Draw a flowchart on a copy of the given image based on JSON input.

    Args:
        image: PIL image used as the canvas; it is copied, not modified.
        flowchart_json: dict with optional "shapes" and "connections" lists.
            Each shape needs x, y, width, height and may carry id, label,
            type ("rectangle", "ellipse"/"oval", "diamond", "parallelogram";
            anything else is drawn as a rectangle) and color. Each
            connection references shape ids via "from"/"to" and may carry
            a "type" ("arrow" by default).

    Returns:
        A new PIL image with the shapes, labels and connectors drawn on it.
        Connections whose endpoints do not match any shape id are skipped.
    """
    im = image.copy()
    draw = ImageDraw.Draw(im)

    # Load default font
    try:
        font = ImageFont.load_default()
    except Exception as e:
        print(f"Error loading font: {e}")
        return im

    # `or []` also covers an explicit null from the model.
    shapes = flowchart_json.get("shapes") or []
    connections = flowchart_json.get("connections") or []

    # Draw shapes and their centred labels.
    for shape in shapes:
        x, y, w, h = shape["x"], shape["y"], shape["width"], shape["height"]
        shape_type = str(shape.get("type") or "rectangle").lower()
        label = shape.get("label")
        label = "" if label is None else str(label)
        color = shape.get("color") or "white"
        box = [x, y, x + w, y + h]

        if shape_type in ("ellipse", "oval"):
            draw.ellipse(box, fill=color, outline="black", width=3)
        elif shape_type == "diamond":
            points = [
                (x + w // 2, y),  # Top
                (x + w, y + h // 2),  # Right
                (x + w // 2, y + h),  # Bottom
                (x, y + h // 2),  # Left
            ]
            draw.polygon(points, fill=color, outline="black", width=3)
        elif shape_type == "parallelogram":
            skew = w // 5  # horizontal slant of the left/right edges
            points = [(x + skew, y), (x + w, y), (x + w - skew, y + h), (x, y + h)]
            draw.polygon(points, fill=color, outline="black", width=3)
        else:
            # "rectangle" and any unrecognised type: a box keeps the label framed.
            draw.rectangle(box, fill=color, outline="black", width=3)

        # Centre the label; subtract the bbox origin so glyph offsets don't skew it.
        left, top, right, bottom = font.getbbox(label)
        text_x = x + (w - (right - left)) // 2 - left
        text_y = y + (h - (bottom - top)) // 2 - top
        draw.text((text_x, text_y), label, fill="black", font=font)

    # Index shapes by id once (first occurrence wins, like a linear search).
    # Ids are compared as strings so 1 and "1" refer to the same shape.
    shapes_by_id = {}
    for shape in shapes:
        if shape.get("id") is not None:
            shapes_by_id.setdefault(str(shape["id"]), shape)

    # Draw connections
    for conn in connections:
        src = shapes_by_id.get(str(conn.get("from")))
        dst = shapes_by_id.get(str(conn.get("to")))
        if src is None or dst is None:
            # A dangling reference should not abort the whole rendering.
            print(f"Skipping connection with unknown endpoint: {conn}")
            continue
        start, end = _connection_endpoints(src, dst)
        draw.line([start, end], fill="black", width=2)
        if str(conn.get("type", "arrow")).lower() == "arrow":
            head = _arrowhead_points(start, end, size=10)
            if head:
                draw.polygon(head, fill="black")

    return im
# Older variant of draw_flowchart; not called anywhere in this file.
def olddraw_flowchart(image, flowchart_json):
    """
    Legacy entry point for rendering a flowchart; delegates to draw_flowchart.

    The previous body duplicated draw_flowchart except that it measured
    labels with ``font.getsize``. Pillow 10 removed that method, so the old
    body raised AttributeError as soon as it drew a shape. Delegating keeps
    this name working and leaves a single rendering implementation.

    Args:
        image: PIL image used as the canvas (copied by draw_flowchart).
        flowchart_json: dict with "shapes" and "connections" lists.

    Returns:
        A new PIL image with the flowchart drawn on it.
    """
    return draw_flowchart(image, flowchart_json)
# Function to generate flowchart JSON via Gemini
def generate_flowchart(prompt, canvas_size=(1000, 800)):
    """
    Use Google Gemini to generate JSON for a flowchart.

    Args:
        prompt: Natural-language description of the process to chart.
        canvas_size: (width, height) in pixels of the canvas the result will
            be drawn on. It is given to the model so coordinates fit. The
            default matches the canvas created by predict_flowchart.

    Returns:
        A dict with "shapes" and "connections" lists, or {} if the prompt is
        empty, the request or parsing fails, or the model returns something
        other than a JSON object.
    """
    if prompt is None or not str(prompt).strip():
        print("Error generating flowchart JSON: empty prompt")
        return {}

    width, height = canvas_size
    # The renderer reads the "shapes"/"connections" keys and pixel coordinates,
    # so both are spelled out explicitly for the model.
    system_instruction = f"""
Return a JSON object describing a flowchart.
The object must have exactly two top-level keys: "shapes" (a list) and "connections" (a list).
Use formal flowchart conventions with shapes like rectangles, ellipses, and diamonds.
Each shape should have attributes: id, label, x, y, width, height, type (e.g., 'rectangle', 'ellipse', 'diamond'), and color.
x and y are the pixel coordinates of the shape's top-left corner on a {width}x{height} canvas (origin at the top-left); every shape must fit inside the canvas and shapes must not overlap.
Also include connections with attributes: from (id), to (id), and type (e.g., 'arrow').
"""
    try:
        response = client.models.generate_content(
            model=model_name,
            contents=[prompt],
            config=types.GenerateContentConfig(
                system_instruction=system_instruction,
                temperature=0.5,
                # Ask for bare JSON rather than prose or a Markdown fence.
                response_mime_type="application/json",
            ),
        )
        print("Gemini Response:", response.text)
        result = parse_json(response.text)
    except Exception as e:
        # API boundary: network/quota/safety errors are reported, not raised.
        print(f"Error generating flowchart JSON: {e}")
        return {}

    if not isinstance(result, dict):
        # e.g. a top-level JSON array, which draw_flowchart cannot use.
        print(
            "Error generating flowchart JSON: expected a JSON object, "
            f"got {type(result).__name__}"
        )
        return {}
    return result
# Function to predict the flowchart
def predict_flowchart(prompt):
    """
    Turn a natural-language description into a rendered flowchart image.

    Failures are never raised to the caller. This covers an empty model
    result and any error while generating or drawing. The error message is
    instead written in red on a blank 1000x800 canvas, so the UI always
    receives an image.
    """
    canvas_size = (1000, 800)
    try:
        spec = generate_flowchart(prompt)
        if not spec:
            raise ValueError("Could not generate flowchart JSON.")
        return draw_flowchart(Image.new("RGB", canvas_size, "white"), spec)
    except Exception as err:
        print(f"Error during processing: {err}")
        fallback = Image.new("RGB", canvas_size, "white")
        ImageDraw.Draw(fallback).text((50, 50), f"Error: {str(err)}", fill="red")
        return fallback
# Define the Gradio interface for flowcharts
def gradio_interface_flowcharts():
    """
    Build the Gradio Blocks UI for flowchart generation.

    The left column holds the prompt box and the submit button. The right
    column shows the rendered image. Clicking the button calls
    predict_flowchart with the prompt text.

    Returns:
        The gr.Blocks app (not yet launched).
    """
    theme = gr.themes.Glass(secondary_hue="blue")
    with gr.Blocks(theme=theme) as app:
        gr.Markdown("# Flowchart Generator with Gemini")
        with gr.Row():
            with gr.Column():
                gr.Markdown("### Input Section")
                prompt_box = gr.Textbox(
                    lines=2,
                    label="Input Prompt",
                    placeholder="Describe the flowchart process.",
                )
                generate_button = gr.Button("Generate Flowchart")
            with gr.Column():
                gr.Markdown("### Output Section")
                flowchart_view = gr.Image(type="pil", label="Output Flowchart")
        # Wire the button: prompt text in, rendered PIL image out.
        generate_button.click(
            fn=predict_flowchart,
            inputs=[prompt_box],
            outputs=[flowchart_view],
        )
    return app
# Run the app
if __name__ == "__main__":
    # Build the Gradio UI and start serving it when run as a script.
    demo = gradio_interface_flowcharts()
    demo.launch()