Peijie commited on
Commit
711211a
1 Parent(s): c0d5a1b

initial commit

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. .gitignore +7 -0
  2. README.md +3 -0
  3. app.py +400 -0
  4. data/image_embeddings/American_Goldfinch_0123_32505.jpg.pt +3 -0
  5. data/image_embeddings/Black_Tern_0101_144331.jpg.pt +3 -0
  6. data/image_embeddings/Brandt_Cormorant_0040_23144.jpg.pt +3 -0
  7. data/image_embeddings/Brown_Thrasher_0014_155421.jpg.pt +3 -0
  8. data/image_embeddings/Carolina_Wren_0060_186296.jpg.pt +3 -0
  9. data/image_embeddings/Cedar_Waxwing_0075_179114.jpg.pt +3 -0
  10. data/image_embeddings/Clark_Nutcracker_0126_85134.jpg.pt +3 -0
  11. data/image_embeddings/Gray_Catbird_0071_20974.jpg.pt +3 -0
  12. data/image_embeddings/Heermann_Gull_0097_45783.jpg.pt +3 -0
  13. data/image_embeddings/House_Wren_0137_187273.jpg.pt +3 -0
  14. data/image_embeddings/Ivory_Gull_0004_49019.jpg.pt +3 -0
  15. data/image_embeddings/Northern_Waterthrush_0038_177027.jpg.pt +3 -0
  16. data/image_embeddings/Pine_Warbler_0113_172456.jpg.pt +3 -0
  17. data/image_embeddings/Red_Headed_Woodpecker_0032_182815.jpg.pt +3 -0
  18. data/image_embeddings/Rufous_Hummingbird_0076_59563.jpg.pt +3 -0
  19. data/image_embeddings/Sage_Thrasher_0062_796462.jpg.pt +3 -0
  20. data/image_embeddings/Vesper_Sparrow_0030_125663.jpg.pt +3 -0
  21. data/image_embeddings/Western_Grebe_0064_36613.jpg.pt +3 -0
  22. data/image_embeddings/White_Eyed_Vireo_0046_158849.jpg.pt +3 -0
  23. data/image_embeddings/Winter_Wren_0048_189683.jpg.pt +3 -0
  24. data/images/boxes/American_Goldfinch_0123_32505_all.jpg +0 -0
  25. data/images/boxes/American_Goldfinch_0123_32505_back.jpg +0 -0
  26. data/images/boxes/American_Goldfinch_0123_32505_beak.jpg +0 -0
  27. data/images/boxes/American_Goldfinch_0123_32505_belly.jpg +0 -0
  28. data/images/boxes/American_Goldfinch_0123_32505_breast.jpg +0 -0
  29. data/images/boxes/American_Goldfinch_0123_32505_crown.jpg +0 -0
  30. data/images/boxes/American_Goldfinch_0123_32505_eyes.jpg +0 -0
  31. data/images/boxes/American_Goldfinch_0123_32505_forehead.jpg +0 -0
  32. data/images/boxes/American_Goldfinch_0123_32505_legs.jpg +0 -0
  33. data/images/boxes/American_Goldfinch_0123_32505_nape.jpg +0 -0
  34. data/images/boxes/American_Goldfinch_0123_32505_tail.jpg +0 -0
  35. data/images/boxes/American_Goldfinch_0123_32505_throat.jpg +0 -0
  36. data/images/boxes/American_Goldfinch_0123_32505_visible.jpg +0 -0
  37. data/images/boxes/American_Goldfinch_0123_32505_wings.jpg +0 -0
  38. data/images/boxes/Black_Tern_0101_144331_all.jpg +0 -0
  39. data/images/boxes/Black_Tern_0101_144331_back.jpg +0 -0
  40. data/images/boxes/Black_Tern_0101_144331_beak.jpg +0 -0
  41. data/images/boxes/Black_Tern_0101_144331_belly.jpg +0 -0
  42. data/images/boxes/Black_Tern_0101_144331_breast.jpg +0 -0
  43. data/images/boxes/Black_Tern_0101_144331_crown.jpg +0 -0
  44. data/images/boxes/Black_Tern_0101_144331_eyes.jpg +0 -0
  45. data/images/boxes/Black_Tern_0101_144331_forehead.jpg +0 -0
  46. data/images/boxes/Black_Tern_0101_144331_legs.jpg +0 -0
  47. data/images/boxes/Black_Tern_0101_144331_nape.jpg +0 -0
  48. data/images/boxes/Black_Tern_0101_144331_tail.jpg +0 -0
  49. data/images/boxes/Black_Tern_0101_144331_throat.jpg +0 -0
  50. data/images/boxes/Black_Tern_0101_144331_visible.jpg +0 -0
.gitignore ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ temp*
2
+
3
+
4
+ # python temp files
5
+ __pycache__
6
+ *.pyc
7
+ .vscode
README.md ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ ---
2
+ license: mit
3
+ ---
app.py ADDED
@@ -0,0 +1,400 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import io
2
+ import os
3
+
4
+ import json
5
+ import base64
6
+ import random
7
+ import numpy as np
8
+ import pandas as pd
9
+ import gradio as gr
10
+ from pathlib import Path
11
+ from PIL import Image
12
+
13
+ from plots import get_pre_define_colors
14
+ from utils.load_model import load_xclip
15
+ from utils.predict import xclip_pred
16
+
17
+
18
+ DEVICE = "cpu"
19
+ XCLIP, OWLVIT_PRECESSOR = load_xclip(DEVICE)
20
+ XCLIP_DESC_PATH = "data/jsons/bs_cub_desc.json"
21
+ XCLIP_DESC = json.load(open(XCLIP_DESC_PATH, "r"))
22
+ PREPROCESS = lambda x: OWLVIT_PRECESSOR(images=x, return_tensors='pt')
23
+ IMAGES_FOLDER = "data/images"
24
+ # XCLIP_RESULTS = json.load(open("data/jsons/xclip_org.json", "r"))
25
+ # correct_predictions = [k for k, v in XCLIP_RESULTS.items() if v['prediction']]
26
+
27
+ # get the intersection of sachit and xclip (revised)
28
+ # INTERSECTION = []
29
+ # IMAGE_RES = 400 * 400 # minimum resolution
30
+ # TOTAL_SAMPLES = 20
31
+ # for file_name in XCLIP_RESULTS:
32
+ # image = Image.open(os.path.join(IMAGES_FOLDER, 'org', file_name)).convert('RGB')
33
+ # w, h = image.size
34
+ # if w * h < IMAGE_RES:
35
+ # continue
36
+ # else:
37
+ # INTERSECTION.append(file_name)
38
+
39
+ # IMAGE_FILE_LIST = random.sample(INTERSECTION, TOTAL_SAMPLES)
40
+ IMAGE_FILE_LIST = json.load(open("data/jsons/file_list.json", "r"))
41
+ # IMAGE_FILE_LIST = IMAGE_FILE_LIST[:19]
42
+ # IMAGE_FILE_LIST.append('Eastern_Bluebird.jpg')
43
+ IMAGE_GALLERY = [Image.open(os.path.join(IMAGES_FOLDER, 'org', file_name)).convert('RGB') for file_name in IMAGE_FILE_LIST]
44
+
45
+ ORG_PART_ORDER = ['back', 'beak', 'belly', 'breast', 'crown', 'forehead', 'eyes', 'legs', 'wings', 'nape', 'tail', 'throat']
46
+ ORDERED_PARTS = ['crown', 'forehead', 'nape', 'eyes', 'beak', 'throat', 'breast', 'belly', 'back', 'wings', 'legs', 'tail']
47
+ COLORS = get_pre_define_colors(12, cmap_set=['Set2', 'tab10'])
48
+ SACHIT_COLOR = "#ADD8E6"
49
+ # CUB_BOXES = json.load(open("data/jsons/cub_boxes_owlvit_large.json", "r"))
50
+ VISIBILITY_DICT = json.load(open("data/jsons/cub_vis_dict_binary.json", 'r'))
51
+ VISIBILITY_DICT['Eastern_Bluebird.jpg'] = dict(zip(ORDERED_PARTS, [True]*12))
52
+
53
+ # --- Image related functions ---
54
+ def img_to_base64(img):
55
+ img_pil = Image.fromarray(img) if isinstance(img, np.ndarray) else img
56
+ buffered = io.BytesIO()
57
+ img_pil.save(buffered, format="JPEG")
58
+ img_str = base64.b64encode(buffered.getvalue())
59
+ return img_str.decode()
60
+
61
+ def create_blank_image(width=500, height=500, color=(255, 255, 255)):
62
+ """Create a blank image of the given size and color."""
63
+ return np.array(Image.new("RGB", (width, height), color))
64
+
65
+ # Convert RGB colors to hex
66
+ def rgb_to_hex(rgb):
67
+ return f"#{''.join(f'{x:02x}' for x in rgb)}"
68
+
69
+ def load_part_images(file_name: str) -> dict:
70
+ part_images = {}
71
+ # start_time = time.time()
72
+ for part_name in ORDERED_PARTS:
73
+ base_name = Path(file_name).stem
74
+ part_image_path = os.path.join(IMAGES_FOLDER, "boxes", f"{base_name}_{part_name}.jpg")
75
+ if not Path(part_image_path).exists():
76
+ continue
77
+ image = np.array(Image.open(part_image_path))
78
+ part_images[part_name] = img_to_base64(image)
79
+ # print(f"Time cost to load 12 images: {time.time() - start_time}")
80
+ # This takes less than 0.01 seconds. So the loading time is not the bottleneck.
81
+ return part_images
82
+
83
+ def generate_xclip_explanations(result_dict:dict, visibility: dict, part_mask: dict = dict(zip(ORDERED_PARTS, [1]*12))):
84
+ """
85
+ The result_dict needs three keys: 'descriptions', 'pred_scores', 'file_name'
86
+ descriptions: {part_name1: desc_1, part_name2: desc_2, ...}
87
+ pred_scores: {part_name1: score_1, part_name2: score_2, ...}
88
+ file_name: str
89
+ """
90
+
91
+ descriptions = result_dict['descriptions']
92
+ image_name = result_dict['file_name']
93
+ part_images = PART_IMAGES_DICT[image_name]
94
+ MAX_LENGTH = 50
95
+ exp_length = 400
96
+ fontsize = 15
97
+
98
+ # Start the SVG inside a div
99
+ svg_parts = [f'<div style="width: {exp_length}px; height: 450px; background-color: white;">',
100
+ "<svg width=\"100%\" height=\"100%\">"]
101
+
102
+ # Add a row for each visible bird part
103
+ y_offset = 0
104
+ for part in ORDERED_PARTS:
105
+ if visibility[part] and part_mask[part]:
106
+ # Calculate the length of the bar (scaled to fit within the SVG)
107
+ part_score = max(result_dict['pred_scores'][part], 0)
108
+ bar_length = part_score * exp_length
109
+
110
+ # Modify the overlay image's opacity on mouseover and mouseout
111
+ mouseover_action1 = f"document.getElementById('overlayImage').src = 'data:image/jpeg;base64,{part_images[part]}'; document.getElementById('overlayImage').style.opacity = 1;"
112
+ mouseout_action1 = "document.getElementById('overlayImage').style.opacity = 0;"
113
+
114
+ combined_mouseover = f"javascript: {mouseover_action1};"
115
+ combined_mouseout = f"javascript: {mouseout_action1};"
116
+
117
+ # Add the description
118
+ num_lines = len(descriptions[part]) // MAX_LENGTH + 1
119
+ for line in range(num_lines):
120
+ desc_line = descriptions[part][line*MAX_LENGTH:(line+1)*MAX_LENGTH]
121
+ y_offset += fontsize
122
+ svg_parts.append(f"""
123
+ <text x="0" y="{y_offset}" font-size="{fontsize}"
124
+ onmouseover="{combined_mouseover}"
125
+ onmouseout="{combined_mouseout}">
126
+ {desc_line}
127
+ </text>
128
+ """)
129
+
130
+ # Add the bars
131
+ svg_parts.append(f"""
132
+ <rect x="0" y="{y_offset +3}" width="{bar_length}" height="{fontsize*0.7}" fill="{PART_COLORS[part]}"
133
+ onmouseover="{combined_mouseover}"
134
+ onmouseout="{combined_mouseout}">
135
+ </rect>
136
+ """)
137
+ # Add the scores
138
+ svg_parts.append(f'<text x="{exp_length - 50}" y="{y_offset+fontsize+3}" font-size="{fontsize}" fill="{PART_COLORS[part]}">{part_score:.2f}</text>')
139
+
140
+ y_offset += fontsize + 3
141
+ svg_parts.extend(("</svg>", "</div>"))
142
+ # Join everything into a single string
143
+ html = "".join(svg_parts)
144
+
145
+
146
+ return html
147
+
148
+
149
+
150
+ def generate_sachit_explanations(result_dict:dict):
151
+ descriptions = result_dict['descriptions']
152
+ scores = result_dict['scores']
153
+ MAX_LENGTH = 50
154
+ exp_length = 400
155
+ fontsize = 15
156
+
157
+ descriptions = zip(scores, descriptions)
158
+ descriptions = sorted(descriptions, key=lambda x: x[0], reverse=True)
159
+
160
+ # Start the SVG inside a div
161
+ svg_parts = [f'<div style="width: {exp_length}px; height: 450px; background-color: white;">',
162
+ "<svg width=\"100%\" height=\"100%\">"]
163
+
164
+ # Add a row for each visible bird part
165
+ y_offset = 0
166
+ for score, desc in descriptions:
167
+
168
+ # Calculate the length of the bar (scaled to fit within the SVG)
169
+ part_score = max(score, 0)
170
+ bar_length = part_score * exp_length
171
+
172
+ # Split the description into two lines if it's too long
173
+ num_lines = len(desc) // MAX_LENGTH + 1
174
+ for line in range(num_lines):
175
+ desc_line = desc[line*MAX_LENGTH:(line+1)*MAX_LENGTH]
176
+ y_offset += fontsize
177
+ svg_parts.append(f"""
178
+ <text x="0" y="{y_offset}" font-size="{fontsize}" fill="black">
179
+ {desc_line}
180
+ </text>
181
+ """)
182
+
183
+ # Add the bar
184
+ svg_parts.append(f"""
185
+ <rect x="0" y="{y_offset+3}" width="{bar_length}" height="{fontsize*0.7}" fill="{SACHIT_COLOR}">
186
+ </rect>
187
+ """)
188
+
189
+ # Add the score
190
+ svg_parts.append(f'<text x="{exp_length - 50}" y="{y_offset+fontsize+3}" font-size="fontsize" fill="{SACHIT_COLOR}">{part_score:.2f}</text>') # Added fill color
191
+
192
+ y_offset += fontsize + 3
193
+
194
+
195
+ svg_parts.extend(("</svg>", "</div>"))
196
+ # Join everything into a single string
197
+ html = "".join(svg_parts)
198
+
199
+
200
+ return html
201
+
202
+ # --- Constants created by the functions above ---
203
+ BLANK_OVERLAY = img_to_base64(create_blank_image())
204
+ PART_COLORS = {part: rgb_to_hex(COLORS[i]) for i, part in enumerate(ORDERED_PARTS)}
205
+ blank_image = np.array(Image.open('data/images/final.png').convert('RGB'))
206
+ PART_IMAGES_DICT = {file_name: load_part_images(file_name) for file_name in IMAGE_FILE_LIST}
207
+
208
+ # --- Gradio Functions ---
209
+ def update_selected_image(event: gr.SelectData):
210
+ image_height = 400
211
+ index = event.index
212
+
213
+ image_name = IMAGE_FILE_LIST[index]
214
+ current_image.state = image_name
215
+ org_image = Image.open(os.path.join(IMAGES_FOLDER, 'org', image_name)).convert('RGB')
216
+ img_base64 = f"""
217
+ <div style="position: relative; height: {image_height}px; display: inline-block;">
218
+ <img id="birdImage" src="data:image/jpeg;base64,{img_to_base64(org_image)}" style="height: {image_height}px; width: auto;">
219
+ <img id="overlayImage" src="data:image/jpeg;base64,{BLANK_OVERLAY}" style="position:absolute; top:0; left:0; width:auto; height: {image_height}px; opacity: 0;">
220
+ </div>
221
+ """
222
+ gt_label = XCLIP_RESULTS[image_name]['ground_truth']
223
+ gt_class.state = gt_label
224
+
225
+ # --- for initial value only ---
226
+ out_dict = xclip_pred(new_desc=None, new_part_mask=None, new_class=None, org_desc=XCLIP_DESC_PATH, image=Image.open(os.path.join(IMAGES_FOLDER, 'org', current_image.state)).convert('RGB'), model=XCLIP, owlvit_processor=OWLVIT_PRECESSOR, device=DEVICE, image_name=current_image.state)
227
+ xclip_label = out_dict['pred_class']
228
+ clip_pred_scores = out_dict['pred_score']
229
+ xclip_part_scores = out_dict['pred_desc_scores']
230
+ result_dict = {'descriptions': dict(zip(ORG_PART_ORDER, out_dict["descriptions"])), 'pred_scores': xclip_part_scores, 'file_name': current_image.state}
231
+ xclip_exp = generate_xclip_explanations(result_dict, VISIBILITY_DICT[current_image.state], part_mask=dict(zip(ORDERED_PARTS, [1]*12)))
232
+ # --- end of intial value ---
233
+
234
+ xclip_color = "green" if xclip_label.strip() == gt_label.strip() else "red"
235
+ xclip_pred_markdown = f"""
236
+ ### <span style='color:{xclip_color}'>XCLIP: {xclip_label} &nbsp;&nbsp;&nbsp; {clip_pred_scores:.4f}</span>
237
+ """
238
+
239
+ gt_label = f"""
240
+ ## {gt_label}
241
+ """
242
+ current_predicted_class.state = xclip_label
243
+
244
+ # Populate the textbox with current descriptions
245
+ custom_class_name = "class name: custom"
246
+ descs = XCLIP_DESC[xclip_label]
247
+ descs = {k: descs[i] for i, k in enumerate(ORG_PART_ORDER)}
248
+ descs = {k: descs[k] for k in ORDERED_PARTS}
249
+ custom_text = [custom_class_name] + list(descs.values())
250
+ descriptions = ";\n".join(custom_text)
251
+ textbox = gr.Textbox.update(value=descriptions, lines=12, visible=True, label="XCLIP descriptions", interactive=True, info='Please use ";" to separate the descriptions for each part, and keep the format of {part name}: {descriptions}', show_label=False)
252
+ # modified_exp = gr.HTML().update(value="", visible=True)
253
+ return gt_label, img_base64, xclip_pred_markdown, xclip_exp, current_image, textbox
254
+
255
+ def on_edit_button_click_xclip():
256
+ empty_exp = gr.HTML.update(visible=False)
257
+
258
+ # Populate the textbox with current descriptions
259
+ descs = XCLIP_DESC[current_predicted_class.state]
260
+ descs = {k: descs[i] for i, k in enumerate(ORG_PART_ORDER)}
261
+ descs = {k: descs[k] for k in ORDERED_PARTS}
262
+ custom_text = ["class name: custom"] + list(descs.values())
263
+ descriptions = ";\n".join(custom_text)
264
+ textbox = gr.Textbox.update(value=descriptions, lines=12, visible=True, label="XCLIP descriptions", interactive=True, info='Please use ";" to separate the descriptions for each part, and keep the format of {part name}: {descriptions}', show_label=False)
265
+
266
+ return textbox, empty_exp
267
+
268
+ def convert_input_text_to_xclip_format(textbox_input: str):
269
+
270
+ # Split the descriptions by newline to get individual descriptions for each part
271
+ descriptions_list = textbox_input.split(";\n")
272
+ # the first line should be "class name: xxx"
273
+ class_name_line = descriptions_list[0]
274
+ new_class_name = class_name_line.split(":")[1].strip()
275
+
276
+ descriptions_list = descriptions_list[1:]
277
+
278
+ # construct descripion dict with part name as key
279
+ descriptions_dict = {}
280
+ for desc in descriptions_list:
281
+ if desc.strip() == "":
282
+ continue
283
+ part_name, _ = desc.split(":")
284
+ descriptions_dict[part_name.strip()] = desc
285
+ # fill with empty string if the part is not in the descriptions
286
+ part_mask = {}
287
+ for part in ORDERED_PARTS:
288
+ if part not in descriptions_dict:
289
+ descriptions_dict[part] = ""
290
+ part_mask[part] = 0
291
+ else:
292
+ part_mask[part] = 1
293
+ return descriptions_dict, part_mask, new_class_name
294
+
295
+ def on_predict_button_click_xclip(textbox_input: str):
296
+ descriptions_dict, part_mask, new_class_name = convert_input_text_to_xclip_format(textbox_input)
297
+
298
+ # Get the new predictions and explanations
299
+ out_dict = xclip_pred(new_desc=descriptions_dict, new_part_mask=part_mask, new_class=new_class_name, org_desc=XCLIP_DESC_PATH, image=Image.open(os.path.join(IMAGES_FOLDER, 'org', current_image.state)).convert('RGB'), model=XCLIP, owlvit_processor=OWLVIT_PRECESSOR, device=DEVICE, image_name=current_image.state)
300
+ xclip_label = out_dict['pred_class']
301
+ xclip_pred_score = out_dict['pred_score']
302
+ xclip_part_scores = out_dict['pred_desc_scores']
303
+ custom_label = out_dict['modified_class']
304
+ custom_pred_score = out_dict['modified_score']
305
+ custom_part_scores = out_dict['modified_desc_scores']
306
+
307
+ # construct a result dict to generate xclip explanations
308
+ result_dict = {'descriptions': dict(zip(ORG_PART_ORDER, out_dict["descriptions"])), 'pred_scores': xclip_part_scores, 'file_name': current_image.state}
309
+ xclip_explanation = generate_xclip_explanations(result_dict, VISIBILITY_DICT[current_image.state], part_mask)
310
+ modified_result_dict = {'descriptions': dict(zip(ORG_PART_ORDER, out_dict["modified_descriptions"])), 'pred_scores': custom_part_scores, 'file_name': current_image.state}
311
+ modified_explanation = generate_xclip_explanations(modified_result_dict, VISIBILITY_DICT[current_image.state], part_mask)
312
+
313
+ xclip_color = "green" if xclip_label.strip() == gt_class.state.strip() else "red"
314
+ xclip_pred_markdown = f"""
315
+ ### <span style='color:{xclip_color}'>XCLIP: {xclip_label} &nbsp;&nbsp;&nbsp; {xclip_pred_score:.4f}</span>
316
+ """
317
+ custom_color = "green" if custom_label.strip() == gt_class.state.strip() else "red"
318
+ custom_pred_markdown = f"""
319
+ ### <span style='color:{custom_color}'>XCLIP: {custom_label} &nbsp;&nbsp;&nbsp; {custom_pred_score:.4f}</span>
320
+ """
321
+ textbox = gr.Textbox.update(visible=False)
322
+ # return textbox, xclip_pred_markdown, xclip_explanation, custom_pred_markdown, modified_explanation
323
+
324
+ modified_exp = gr.HTML().update(value=modified_explanation, visible=True)
325
+ return textbox, xclip_pred_markdown, xclip_explanation, custom_pred_markdown, modified_exp
326
+
327
+
328
+ custom_css = """
329
+ html, body {
330
+ margin: 0;
331
+ padding: 0;
332
+ }
333
+
334
+ #container {
335
+ position: relative;
336
+ width: 400px;
337
+ height: 400px;
338
+ border: 1px solid #000;
339
+ margin: 0 auto; /* This will center the container horizontally */
340
+ }
341
+
342
+ #canvas {
343
+ position: absolute;
344
+ top: 0;
345
+ left: 0;
346
+ width: 100%;
347
+ height: 100%;
348
+ object-fit: cover;
349
+ }
350
+
351
+ """
352
+
353
+ # Define the Gradio interface
354
+ with gr.Blocks(theme=gr.themes.Soft(), css=custom_css, title="PEEB") as demo:
355
+ current_image = gr.State("")
356
+ current_predicted_class = gr.State("")
357
+ gt_class = gr.State("")
358
+
359
+ with gr.Column():
360
+ title_text = gr.Markdown("# PEEB - demo")
361
+ gr.Markdown(
362
+ "- In this demo, you can edit the descriptions of a class and see how to model react to it."
363
+ )
364
+
365
+ # display the gallery of images
366
+ with gr.Column():
367
+
368
+ gr.Markdown("## Select an image to start!")
369
+ image_gallery = gr.Gallery(value=IMAGE_GALLERY, label=None, preview=False, allow_preview=False, columns=10, height=250)
370
+ gr.Markdown("### Custom descritions: \n The first row should be **class name: {some name};**, where you can name your descriptions. \n For the remianing descriptions, please use **;** to separate the descriptions for each part, and use the format **{part name}: {descriptions}**. \n Note that you can delete a part completely, in such cases, all descriptions will remove the corresponding part.")
371
+
372
+ with gr.Row():
373
+ with gr.Column():
374
+ image_label = gr.Markdown("### Class Name")
375
+ org_image = gr.HTML()
376
+
377
+ with gr.Column():
378
+ with gr.Row():
379
+ # xclip_predict_button = gr.Button(label="Predict", value="Predict")
380
+ xclip_predict_button = gr.Button(value="Predict")
381
+ xclip_pred_label = gr.Markdown("### XCLIP:")
382
+ xclip_explanation = gr.HTML()
383
+
384
+ with gr.Column():
385
+ # xclip_edit_button = gr.Button(label="Edit", value="Reset Descriptions")
386
+ xclip_edit_button = gr.Button(value="Reset Descriptions")
387
+ custom_pred_label = gr.Markdown(
388
+ "### Custom Descritpions:"
389
+ )
390
+ xclip_textbox = gr.Textbox(lines=12, placeholder="Edit the descriptions here", visible=False)
391
+ # ai_explanation = gr.Image(type="numpy", visible=True, show_label=False, height=500)
392
+ custom_explanation = gr.HTML()
393
+
394
+ gr.HTML("<br>")
395
+
396
+ image_gallery.select(update_selected_image, inputs=None, outputs=[image_label, org_image, xclip_pred_label, xclip_explanation, current_image, xclip_textbox])
397
+ xclip_edit_button.click(on_edit_button_click_xclip, inputs=[], outputs=[xclip_textbox, custom_explanation])
398
+ xclip_predict_button.click(on_predict_button_click_xclip, inputs=[xclip_textbox], outputs=[xclip_textbox, xclip_pred_label, xclip_explanation, custom_pred_label, custom_explanation])
399
+
400
+ demo.launch(server_port=5000, share=True)
data/image_embeddings/American_Goldfinch_0123_32505.jpg.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:4405b6dfc87741cf87aa4887f77308aee46209877a7dcf29caacb4dae12459d5
3
+ size 1770910
data/image_embeddings/Black_Tern_0101_144331.jpg.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:218995c5e9d3256313ead069ff11c89a52ce616221880070d722f27c4227ffe2
3
+ size 1770875
data/image_embeddings/Brandt_Cormorant_0040_23144.jpg.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:c493ed75f6dad68a1336ae3142deea98acb2eec30fbb5345aa1c545660eef4bb
3
+ size 1770900
data/image_embeddings/Brown_Thrasher_0014_155421.jpg.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:4c051c80027beeebfabab679b596f5a2b7536c016c2c966a5736b03a980b96a5
3
+ size 1770895
data/image_embeddings/Carolina_Wren_0060_186296.jpg.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:b0c34e05f759b6244ad50ca5529002e26a9370c9db07d22df91e476f827b7724
3
+ size 1770890
data/image_embeddings/Cedar_Waxwing_0075_179114.jpg.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:9d91e1fd22664d4dbad771f214ae943b60c26a0e52aeefc156eddbddde8cb0fb
3
+ size 1770890
data/image_embeddings/Clark_Nutcracker_0126_85134.jpg.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:99e85d16d9b4b0d62e92926a7cefce6fbd5298daa1632df02d1d2bc1c812ccf4
3
+ size 1770900
data/image_embeddings/Gray_Catbird_0071_20974.jpg.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:e02ea920306d2a41b2f0a46c3205691e1373d3a443714ba31c67bd46fa0baae8
3
+ size 1770880
data/image_embeddings/Heermann_Gull_0097_45783.jpg.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:51ecf397a13ffc0ef481b029c7c54498dd9c0dda7db709f9335dba01faebdc65
3
+ size 1770885
data/image_embeddings/House_Wren_0137_187273.jpg.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:3fab5144fff8e0ff975f9064337dc032d39918bf777d149e02e4952a6ed10d8b
3
+ size 1770875
data/image_embeddings/Ivory_Gull_0004_49019.jpg.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:129b38324da3899caa7182fa0a251c81eba2a8ba8e71995139e269d479456e75
3
+ size 1770870
data/image_embeddings/Northern_Waterthrush_0038_177027.jpg.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:0bd735f0756b810b8c74628ca2285311411cb6fb14639277728a60260e64cda9
3
+ size 1770925
data/image_embeddings/Pine_Warbler_0113_172456.jpg.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:9c48503ff01eb8af79b86315ab9b6abe7d215c32ab37eb5acc54dd99b9877574
3
+ size 1770885
data/image_embeddings/Red_Headed_Woodpecker_0032_182815.jpg.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:54ec8a9edf3bc0e5e21a989596469efec44815f9ac30a0cdbde4f5d1f1952619
3
+ size 1770930
data/image_embeddings/Rufous_Hummingbird_0076_59563.jpg.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:d02487c6d3b10c2bc193547a3ad863b6b02710071e93b2a99e9be17931c9e785
3
+ size 1770910
data/image_embeddings/Sage_Thrasher_0062_796462.jpg.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:294ae723107b6cc26f467ef19018f7d0c27befe0ddbf46ea1432a4440cf538c7
3
+ size 1770890
data/image_embeddings/Vesper_Sparrow_0030_125663.jpg.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:0c3b58049302546a0f19e1a0da37d85ee3841d1f34674a6263b4972229539806
3
+ size 1770895
data/image_embeddings/Western_Grebe_0064_36613.jpg.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:66a4a4c3d9e8c61c729eef180dca7c06dc19748be507798548bb629fb8283645
3
+ size 1770885
data/image_embeddings/White_Eyed_Vireo_0046_158849.jpg.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:31f5601dd90778785d90da4b079faa4e8082da814b0edb75c46c27f7a59bb0c3
3
+ size 1770905
data/image_embeddings/Winter_Wren_0048_189683.jpg.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:aa44fb0827d907160d964837908b8d313bce096d02062be2ea7192e6c2903543
3
+ size 1770880
data/images/boxes/American_Goldfinch_0123_32505_all.jpg ADDED
data/images/boxes/American_Goldfinch_0123_32505_back.jpg ADDED
data/images/boxes/American_Goldfinch_0123_32505_beak.jpg ADDED
data/images/boxes/American_Goldfinch_0123_32505_belly.jpg ADDED
data/images/boxes/American_Goldfinch_0123_32505_breast.jpg ADDED
data/images/boxes/American_Goldfinch_0123_32505_crown.jpg ADDED
data/images/boxes/American_Goldfinch_0123_32505_eyes.jpg ADDED
data/images/boxes/American_Goldfinch_0123_32505_forehead.jpg ADDED
data/images/boxes/American_Goldfinch_0123_32505_legs.jpg ADDED
data/images/boxes/American_Goldfinch_0123_32505_nape.jpg ADDED
data/images/boxes/American_Goldfinch_0123_32505_tail.jpg ADDED
data/images/boxes/American_Goldfinch_0123_32505_throat.jpg ADDED
data/images/boxes/American_Goldfinch_0123_32505_visible.jpg ADDED
data/images/boxes/American_Goldfinch_0123_32505_wings.jpg ADDED
data/images/boxes/Black_Tern_0101_144331_all.jpg ADDED
data/images/boxes/Black_Tern_0101_144331_back.jpg ADDED
data/images/boxes/Black_Tern_0101_144331_beak.jpg ADDED
data/images/boxes/Black_Tern_0101_144331_belly.jpg ADDED
data/images/boxes/Black_Tern_0101_144331_breast.jpg ADDED
data/images/boxes/Black_Tern_0101_144331_crown.jpg ADDED
data/images/boxes/Black_Tern_0101_144331_eyes.jpg ADDED
data/images/boxes/Black_Tern_0101_144331_forehead.jpg ADDED
data/images/boxes/Black_Tern_0101_144331_legs.jpg ADDED
data/images/boxes/Black_Tern_0101_144331_nape.jpg ADDED
data/images/boxes/Black_Tern_0101_144331_tail.jpg ADDED
data/images/boxes/Black_Tern_0101_144331_throat.jpg ADDED
data/images/boxes/Black_Tern_0101_144331_visible.jpg ADDED