Rahatara committed on
Commit
a0cfb1c
·
verified ·
1 Parent(s): 85963c6

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +163 -0
app.py ADDED
@@ -0,0 +1,163 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import os
3
+ from io import BytesIO
4
+ from PIL import Image, ImageDraw, ImageFont
5
+ from PIL import ImageColor
6
+ import json
7
+ from google import genai
8
+ from google.genai import types
9
+
10
# --- Gemini configuration -------------------------------------------------
# The model identifier and the API client are module-level so every request
# reuses them. The key is read from the GEM_API_KEY environment variable; a
# missing variable raises KeyError at import time.
model_name = "gemini-2.0-flash-exp"
client = genai.Client(api_key=os.environ['GEM_API_KEY'])

# System prompt sent with every request: asks for a plain JSON array of
# labelled boxes (no masks, no Markdown fences, at most 25 objects).
bounding_box_system_instructions = """
Return bounding boxes as a JSON array with labels. Never return masks or code fencing. Limit to 25 objects.
If an object is present multiple times, name them according to their unique characteristic (colors, size, position, unique characteristics, etc..).
"""

# Every colour name Pillow knows, used to extend the drawing palette.
additional_colors = list(ImageColor.colormap)
20
+
21
def parse_json(json_output):
    """
    Extract the JSON payload from a Gemini text response.

    The model is instructed not to use code fencing, but when it does anyway
    the payload sits between an opening fence line and the next closing
    fence. The opening fence is recognised regardless of surrounding
    whitespace or letter case, and with or without the "json" language tag.

    Args:
        json_output: Raw response text. ``None`` or an empty string (for
            example a response that carries no text) is treated as empty.

    Returns:
        The text between the first opening fence line and the following
        "```", or the input unchanged when no fence line is present. An empty
        string is returned for empty/``None`` input.
    """
    if not json_output:
        return ""

    lines = json_output.splitlines()
    for i, line in enumerate(lines):
        stripped = line.strip()
        # Accept "```json", "```JSON", "``` json" and a bare "```" as openers.
        if stripped.startswith("```") and stripped[3:].strip().lower() in ("", "json"):
            json_output = "\n".join(lines[i + 1:])  # Drop everything up to the opening fence
            json_output = json_output.split("```")[0]  # Drop the closing fence and what follows
            break
    return json_output
32
+
33
def plot_bounding_boxes(im, bounding_boxes):
    """
    Draw labelled bounding boxes on a copy of an image.

    Args:
        im: PIL image to annotate. It is copied; the argument itself is not
            drawn on.
        bounding_boxes: JSON text holding a list of objects (a single object
            is also accepted). Each object needs a "box_2d" entry of
            [y1, x1, y2, x2] on a 0-1000 scale relative to the image size,
            and may carry a "label".

    Returns:
        The annotated copy. Problems are printed rather than raised: if the
        JSON cannot be parsed the unannotated copy is returned, and a
        malformed entry is skipped without discarding the remaining boxes.
    """
    im = im.copy()
    width, height = im.size
    draw = ImageDraw.Draw(im)
    colors = [
        'red', 'green', 'blue', 'yellow', 'orange', 'pink', 'purple', 'cyan',
        'lime', 'magenta', 'violet', 'gold', 'silver'
    ] + additional_colors

    try:
        font = ImageFont.load_default()
    except OSError as e:
        # Keep going: boxes can still be drawn, and draw.text() is given
        # font=None so Pillow attempts its own default lookup for labels.
        print(f"Error drawing bounding boxes: {e}")
        font = None

    try:
        boxes = json.loads(bounding_boxes)
    except (TypeError, ValueError) as e:  # JSONDecodeError is a ValueError
        print(f"Error drawing bounding boxes: {e}")
        return im

    if isinstance(boxes, dict):
        # A lone object instead of an array: treat it as a one-element list.
        boxes = [boxes]
    if not isinstance(boxes, list):
        print(
            "Error drawing bounding boxes: expected a JSON array, "
            f"got {type(boxes).__name__}"
        )
        return im

    for i, bounding_box in enumerate(boxes):
        color = colors[i % len(colors)]
        try:
            y1, x1, y2, x2 = bounding_box["box_2d"][:4]

            # Scale from the 0-1000 grid to pixels; sorting puts the smaller
            # coordinate first in case the corners arrive swapped.
            abs_y1, abs_y2 = sorted((int(y1 / 1000 * height), int(y2 / 1000 * height)))
            abs_x1, abs_x2 = sorted((int(x1 / 1000 * width), int(x2 / 1000 * width)))

            # Draw bounding box and label
            draw.rectangle(((abs_x1, abs_y1), (abs_x2, abs_y2)), outline=color, width=4)
            label = bounding_box.get("label")
            if label is not None:
                # str() guards against numeric labels in the model output.
                draw.text((abs_x1 + 8, abs_y1 + 6), str(label), fill=color, font=font)
        except Exception as e:
            # Best-effort drawing: report the bad entry and move on to the next.
            print(f"Error drawing bounding boxes: {e}")

    return im
75
+
76
def predict_bounding_boxes(image, prompt):
    """
    Send an image and prompt to Gemini and draw the returned bounding boxes.

    This is the Gradio click handler and is wired to a single image output,
    so every path returns exactly one value.

    Args:
        image: PIL image from the input component, or ``None`` when nothing
            has been uploaded.
        prompt: Text describing what to detect.

    Returns:
        The resized image with boxes drawn on success. On any failure the
        error is printed and the image as it stands at that point is returned
        without boxes. ``None`` is returned when no image was supplied.
    """
    if image is None:
        # Nothing to process; clearing the output is the only sensible result.
        return None

    try:
        # Resize to a width of 1024 keeping the aspect ratio. max() avoids a
        # zero height for extremely wide images.
        image = image.resize((1024, max(1, int(1024 * image.height / image.width))))

        # NOTE: the PIL image is passed to the SDK directly. A previous JPEG
        # re-encoding step produced bytes that were never used and raised for
        # images with an alpha channel, so it has been removed.

        # Make API request to Gemini
        response = client.models.generate_content(
            model=model_name,
            contents=[prompt, image],
            config=types.GenerateContentConfig(
                system_instruction=bounding_box_system_instructions,
                temperature=0.5,
                safety_settings=[
                    types.SafetySetting(
                        category="HARM_CATEGORY_DANGEROUS_CONTENT",
                        threshold="BLOCK_ONLY_HIGH",
                    )
                ],
            ),
        )

        print("Gemini response:", response.text)

        # Parse and plot bounding boxes
        bounding_boxes = parse_json(response.text)
        if not bounding_boxes:
            raise ValueError("No bounding boxes returned.")

        return plot_bounding_boxes(image, bounding_boxes)
    except Exception as e:
        # Top-level UI boundary: log the failure and hand back a single image
        # so the return shape matches the one declared output component.
        print(f"Error during processing: {e}")
        return image
115
+
116
def gradio_interface():
    """
    Build the Gradio Blocks UI for the bounding-box demo.

    The layout is a two-column row (image + prompt + button on the left, the
    annotated image on the right) followed by a set of clickable example
    image/prompt pairs. Pressing the button runs ``predict_bounding_boxes``.

    Returns:
        The constructed ``gr.Blocks`` app; the caller is expected to launch it.
    """
    # Image file / prompt pairs offered as one-click examples.
    example_pairs = [
        ["cookies.jpg", "Detect the cookies and label their types."],
        ["messed_room.jpg", "Find the unorganized item and suggest action in label in the image to fix them."],
        ["yoga.jpg", "Show the different yoga poses and name them."],
        ["zoom_face.png", "Label the tired faces in the image."],
    ]

    glass_theme = gr.themes.Glass(secondary_hue="rose")

    with gr.Blocks(glass_theme) as demo:
        gr.Markdown("# Gemini Bounding Box Generator")

        with gr.Row():
            # Left column: user inputs.
            with gr.Column():
                gr.Markdown("### Input Section")
                image_in = gr.Image(type="pil", label="Input Image")
                prompt_in = gr.Textbox(
                    lines=2,
                    label="Input Prompt",
                    placeholder="Describe what to detect.",
                )
                generate_button = gr.Button("Generate")

            # Right column: the annotated result (image only; no JSON output
            # component is currently part of the layout).
            with gr.Column():
                gr.Markdown("### Output Section")
                image_out = gr.Image(type="pil", label="Output Image")

        gr.Markdown("### Examples")
        gr.Examples(
            examples=example_pairs,
            inputs=[image_in, prompt_in],
            label="Example Images with Prompts",
        )

        # Run detection when the button is pressed.
        generate_button.click(
            fn=predict_bounding_boxes,
            inputs=[image_in, prompt_in],
            outputs=[image_out],
        )

    return demo
158
+
159
+
160
+
161
if __name__ == "__main__":
    # Script entry point: build the Blocks app and start the Gradio server.
    gradio_interface().launch()