"""Gradio demo for PEEB, a part-based explainable and editable bird classifier.

Importing this module loads the model and the demo assets, builds the Gradio
``Blocks`` UI and launches it.
"""

import base64
import html
import io
import json
import os
from pathlib import Path
from typing import Optional

import gradio as gr
import numpy as np
import torch
from PIL import Image

from plots import get_pre_define_colors
from utils.load_model import load_xclip
from utils.predict import xclip_pred

# NOTE(review): several HTML/SVG string literals below are empty or hold only
# whitespace, while the values computed for them (bar lengths, colours,
# mouse-over handlers, the selected image, ...) are never used. The markup
# looks like it was stripped from this copy of the file. The literals are kept
# exactly as found; restore the markup from the upstream source before
# deploying. The nesting of the layout containers at the bottom of the file
# was likewise reconstructed from statement order and should be confirmed.


def _load_json(path):
    """Parse the JSON file at *path*, closing the file handle afterwards."""
    with open(path, "r", encoding="utf-8") as handle:
        return json.load(handle)


# NOTE: Hugging Face does not allow loading the model in the main process, so
# there the model has to be loaded on demand; that may not help the speed of
# the app. Here the model is loaded once at import time.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
print("Not at Huggingface demo, load model to main process.")
XCLIP, OWLVIT_PRECESSOR = load_xclip(DEVICE)
print(f"Device: {DEVICE}")

XCLIP_DESC_PATH = "data/jsons/bs_cub_desc.json"
XCLIP_DESC = _load_json(XCLIP_DESC_PATH)
IMAGES_FOLDER = "data/images"
IMAGE2GT = _load_json("data/jsons/image2gt.json")
CUB_DESC_EMBEDS = torch.load('data/text_embeddings/cub_200_desc.pt')
# JSON object keys are strings; the lookup table is keyed by int.
CUB_IDX2NAME = {
    int(k): v for k, v in _load_json('data/jsons/cub_desc_idx2name.json').items()
}
IMAGE_FILE_LIST = _load_json("data/jsons/file_list.json")
IMAGE_GALLERY = [
    Image.open(os.path.join(IMAGES_FOLDER, 'org', file_name)).convert('RGB')
    for file_name in IMAGE_FILE_LIST
]

# Order in which per-part descriptions are zipped with the stored/predicted
# description lists.
ORG_PART_ORDER = ['back', 'beak', 'belly', 'breast', 'crown', 'forehead',
                  'eyes', 'legs', 'wings', 'nape', 'tail', 'throat']
# Order in which the parts are shown in the UI and in the editable textbox.
ORDERED_PARTS = ['crown', 'forehead', 'nape', 'eyes', 'beak', 'throat',
                 'breast', 'belly', 'back', 'wings', 'legs', 'tail']
COLORS = get_pre_define_colors(12, cmap_set=['Set2', 'tab10'])
SACHIT_COLOR = "#ADD8E6"
VISIBILITY_DICT = _load_json("data/jsons/cub_vis_dict_binary.json")
# This image is forced to have every part marked visible.
VISIBILITY_DICT['Eastern_Bluebird.jpg'] = dict(zip(ORDERED_PARTS, [True] * 12))


# --- Image related functions ---
def img_to_base64(img):
    """Encode *img* (a numpy array or a PIL image) as a base64 JPEG string."""
    img_pil = Image.fromarray(img) if isinstance(img, np.ndarray) else img
    buffered = io.BytesIO()
    img_pil.save(buffered, format="JPEG")
    return base64.b64encode(buffered.getvalue()).decode()


def create_blank_image(width=500, height=500, color=(255, 255, 255)):
    """Create a blank image of the given size and color."""
    return np.array(Image.new("RGB", (width, height), color))


def rgb_to_hex(rgb):
    """Convert an iterable of integer colour channels to a ``#rrggbb`` string."""
    return f"#{''.join(f'{x:02x}' for x in rgb)}"


def load_part_images(file_name: str) -> dict:
    """Return ``{part_name: base64 JPEG}`` for the part crops of *file_name*.

    Crops are looked up as ``<IMAGES_FOLDER>/boxes/<stem>_<part>.jpg``; parts
    whose file does not exist are left out of the result.
    """
    base_name = Path(file_name).stem  # loop-invariant
    part_images = {}
    for part_name in ORDERED_PARTS:
        part_image_path = os.path.join(
            IMAGES_FOLDER, "boxes", f"{base_name}_{part_name}.jpg")
        if not Path(part_image_path).exists():
            continue
        image = np.array(Image.open(part_image_path))
        part_images[part_name] = img_to_base64(image)
    # Loading the 12 images was measured at under 0.01 seconds, so this is
    # not the bottleneck.
    return part_images


def _split_description(text, width):
    """Cut *text* into consecutive chunks of at most *width* characters.

    Always returns at least one chunk (an empty string for empty text), and
    never a trailing empty chunk when ``len(text)`` is a multiple of *width*.
    """
    chunks = [text[start:start + width] for start in range(0, len(text), width)]
    return chunks or [""]


def generate_xclip_explanations(result_dict: dict, visibility: dict,
                                part_mask: Optional[dict] = None):
    """Build the HTML explanation for one image.

    The result_dict needs three keys: 'descriptions', 'pred_scores', 'file_name'
        descriptions: {part_name1: desc_1, part_name2: desc_2, ...}
        pred_scores: {part_name1: score_1, part_name2: score_2, ...}
        file_name: str

    Args:
        result_dict: see above.
        visibility: ``{part_name: truthy/falsy}``; falsy parts are skipped.
        part_mask: ``{part_name: truthy/falsy}``; falsy parts are skipped.
            ``None`` (the default) keeps every part.

    Returns:
        The concatenated HTML string.
    """
    if part_mask is None:
        part_mask = dict.fromkeys(ORDERED_PARTS, 1)

    descriptions = result_dict['descriptions']
    image_name = result_dict['file_name']
    part_images = PART_IMAGES_DICT[image_name]
    MAX_LENGTH = 50
    exp_length = 400
    fontsize = 15

    # Start the SVG inside a div
    svg_parts = ['\n', ""]

    # Add a row for each visible bird part
    y_offset = 0
    for part in ORDERED_PARTS:
        if not (visibility[part] and part_mask[part]):
            continue

        # Calculate the length of the bar (scaled to fit within the SVG);
        # negative scores are clamped to zero.
        part_score = max(result_dict['pred_scores'][part], 0)
        bar_length = part_score * exp_length

        # Modify the overlay image's opacity on mouseover and mouseout
        mouseover_action1 = f"document.getElementById('overlayImage').src = 'data:image/jpeg;base64,{part_images[part]}'; document.getElementById('overlayImage').style.opacity = 1;"
        mouseout_action1 = "document.getElementById('overlayImage').style.opacity = 0;"
        combined_mouseover = f"javascript: {mouseover_action1};"
        combined_mouseout = f"javascript: {mouseout_action1};"

        # Add the description. The text can come from the user-edited
        # textbox, so it is HTML-escaped before being embedded.
        for desc_line in _split_description(descriptions[part], MAX_LENGTH):
            y_offset += fontsize
            svg_parts.append(f""" {html.escape(desc_line)} """)

        # Add the bars
        svg_parts.append(f""" """)
        # Add the scores
        svg_parts.append(f'{part_score:.2f}')
        y_offset += fontsize + 3

    svg_parts.extend(("", "\n"))
    # Join everything into a single string
    return "".join(svg_parts)


def generate_sachit_explanations(result_dict: dict):
    """Build the HTML explanation for a list of scored descriptions.

    Args:
        result_dict: needs the keys 'descriptions' and 'scores', two
            sequences that are paired up element by element.

    Returns:
        The concatenated HTML string, one row per description, ordered by
        descending score.
    """
    MAX_LENGTH = 50
    exp_length = 400
    fontsize = 15

    scored = sorted(
        zip(result_dict['scores'], result_dict['descriptions']),
        key=lambda pair: pair[0],
        reverse=True,
    )

    # Start the SVG inside a div
    svg_parts = ['\n', ""]

    y_offset = 0
    for score, desc in scored:
        # Calculate the length of the bar (scaled to fit within the SVG);
        # negative scores are clamped to zero.
        part_score = max(score, 0)
        bar_length = part_score * exp_length

        # Split the description over several lines if it's too long
        for desc_line in _split_description(desc, MAX_LENGTH):
            y_offset += fontsize
            svg_parts.append(f""" {html.escape(desc_line)} """)

        # Add the bar
        svg_parts.append(f""" """)
        # Add the score
        svg_parts.append(f'{part_score:.2f}')
        y_offset += fontsize + 3

    svg_parts.extend(("", "\n"))
    # Join everything into a single string
    return "".join(svg_parts)


# --- Constants created by the functions above ---
BLANK_OVERLAY = img_to_base64(create_blank_image())
PART_COLORS = {part: rgb_to_hex(COLORS[i]) for i, part in enumerate(ORDERED_PARTS)}
blank_image = np.array(Image.open('data/images/final.png').convert('RGB'))
PART_IMAGES_DICT = {
    file_name: load_part_images(file_name) for file_name in IMAGE_FILE_LIST
}


# --- Gradio Functions ---
# NOTE(review): the callbacks below keep the current selection as ``.state``
# attributes on the module-level gr.State components, so the value lives in
# this process rather than in a session -- presumably shared between
# concurrent users; confirm whether per-session state is needed.
def _state_value(component, missing_message):
    """Return the ``.state`` a callback stored on *component*.

    Raises:
        gr.Error: with *missing_message* when nothing has been stored yet.
    """
    value = getattr(component, "state", "")
    if not value:
        raise gr.Error(missing_message)
    return value


def _predict(image_name, new_desc=None, new_part_mask=None, new_class=None):
    """Run ``xclip_pred`` on the stored image *image_name*.

    The three ``new_*`` arguments are forwarded unchanged; the model, the
    processor, the device and the CUB lookup tables are the module-level ones.
    """
    image = Image.open(os.path.join(IMAGES_FOLDER, 'org', image_name)).convert('RGB')
    return xclip_pred(new_desc=new_desc,
                      new_part_mask=new_part_mask,
                      new_class=new_class,
                      org_desc=XCLIP_DESC_PATH,
                      image=image,
                      model=XCLIP,
                      owlvit_processor=OWLVIT_PRECESSOR,
                      device=DEVICE,
                      image_name=image_name,
                      cub_embeds=CUB_DESC_EMBEDS,
                      cub_idx2name=CUB_IDX2NAME,
                      descriptors=XCLIP_DESC)


def _default_description_text(class_name):
    """Return the textbox text for *class_name*'s stored descriptions.

    The first entry is the line ``class name: custom``; the part descriptions
    follow in ``ORDERED_PARTS`` order, joined with ``";\\n"``.
    """
    stored = XCLIP_DESC[class_name]
    by_part = {part: stored[i] for i, part in enumerate(ORG_PART_ORDER)}
    entries = ["class name: custom"] + [by_part[part] for part in ORDERED_PARTS]
    return ";\n".join(entries)


def _editable_textbox(value):
    """Return the visible, interactive description textbox showing *value*."""
    return gr.Textbox(
        value=value,
        lines=12,
        visible=True,
        label="XCLIP descriptions",
        interactive=True,
        info='Please use ";" to separate the descriptions for each part, and keep the format of {part name}: {descriptions}',
        show_label=False,
    )


def update_selected_image(event: gr.SelectData):
    """Handle a gallery selection: predict and fill the panels.

    Stores the selected image name, its ground-truth label and the predicted
    class on the module-level state components.

    Returns:
        ``(ground-truth markdown, image HTML, prediction markdown,
        explanation HTML, current_image, textbox)`` -- one value per output
        wired to ``image_gallery.select``.
    """
    image_height = 400
    image_name = IMAGE_FILE_LIST[event.index]
    current_image.state = image_name
    org_image = Image.open(os.path.join(IMAGES_FOLDER, 'org', image_name)).convert('RGB')
    img_base64 = f"""
"""
    gt_label = IMAGE2GT[image_name]
    gt_class.state = gt_label

    # --- for initial value only ---
    out_dict = _predict(image_name)
    xclip_label = out_dict['pred_class']
    clip_pred_scores = out_dict['pred_score']
    xclip_part_scores = out_dict['pred_desc_scores']
    result_dict = {'descriptions': dict(zip(ORG_PART_ORDER, out_dict["descriptions"])),
                   'pred_scores': xclip_part_scores,
                   'file_name': image_name}
    xclip_exp = generate_xclip_explanations(result_dict, VISIBILITY_DICT[image_name])
    # --- end of initial value ---

    xclip_color = "green" if xclip_label.strip() == gt_label.strip() else "red"
    # \u00a0 is a non-breaking space: it separates the label from the score.
    xclip_pred_markdown = f""" ### {xclip_label} \u00a0\u00a0\u00a0 {clip_pred_scores:.4f} """
    gt_label = f""" ## {gt_label} """
    current_predicted_class.state = xclip_label

    # Populate the textbox with current descriptions
    textbox = _editable_textbox(_default_description_text(xclip_label))
    return gt_label, img_base64, xclip_pred_markdown, xclip_exp, current_image, textbox


def on_edit_button_click_xclip():
    """Reset the textbox to the predicted class' stored descriptions.

    Returns:
        ``(textbox, hidden explanation HTML)``.

    Raises:
        gr.Error: when no image has been selected yet.
    """
    predicted_class = _state_value(current_predicted_class,
                                   "Please select an image first.")
    empty_exp = gr.HTML(visible=False)
    # Populate the textbox with current descriptions
    textbox = _editable_textbox(_default_description_text(predicted_class))
    return textbox, empty_exp


def convert_input_text_to_xclip_format(textbox_input: str):
    """Parse the textbox text into descriptions, a part mask and a class name.

    Entries are separated by ``";\\n"``. The first entry must look like
    ``class name: xxx``; every other non-blank entry must look like
    ``{part name}: {description}``. Only the first colon of an entry acts as
    the separator, so descriptions (and class names) may contain colons.

    Returns:
        ``(descriptions_dict, part_mask, new_class_name)`` where
        ``descriptions_dict`` maps each part name to its full entry text
        (``""`` for parts in ``ORDERED_PARTS`` that were not given) and
        ``part_mask`` maps each part in ``ORDERED_PARTS`` to 1 if it was
        given, else 0.

    Raises:
        IndexError: if the first entry contains no colon.
        ValueError: if a later non-blank entry contains no colon.
    """
    entries = textbox_input.split(";\n")

    # the first line should be "class name: xxx"
    new_class_name = entries[0].split(":", 1)[1].strip()

    # construct description dict with part name as key
    descriptions_dict = {}
    for desc in entries[1:]:
        if desc.strip() == "":
            continue
        part_name, _ = desc.split(":", 1)
        descriptions_dict[part_name.strip()] = desc

    # fill with empty string if the part is not in the descriptions
    part_mask = {}
    for part in ORDERED_PARTS:
        if part in descriptions_dict:
            part_mask[part] = 1
        else:
            descriptions_dict[part] = ""
            part_mask[part] = 0
    return descriptions_dict, part_mask, new_class_name


def on_predict_button_click_xclip(textbox_input: str):
    """Re-run the classifier with the edited extra-class descriptions.

    Returns:
        ``(hidden textbox, prediction markdown, explanation HTML,
        extra-class prediction markdown, extra-class explanation HTML)`` --
        one value per output wired to ``xclip_predict_button.click``.

    Raises:
        gr.Error: when no image has been selected yet.
    """
    image_name = _state_value(current_image, "Please select an image first.")
    descriptions_dict, part_mask, new_class_name = convert_input_text_to_xclip_format(textbox_input)

    # Get the new predictions and explanations
    out_dict = _predict(image_name,
                        new_desc=descriptions_dict,
                        new_part_mask=part_mask,
                        new_class=new_class_name)
    xclip_label = out_dict['pred_class']
    xclip_pred_score = out_dict['pred_score']
    xclip_part_scores = out_dict['pred_desc_scores']
    custom_label = out_dict['modified_class']
    custom_pred_score = out_dict['modified_score']
    custom_part_scores = out_dict['modified_desc_scores']

    # construct a result dict to generate xclip explanations
    visibility = VISIBILITY_DICT[image_name]
    result_dict = {'descriptions': dict(zip(ORG_PART_ORDER, out_dict["descriptions"])),
                   'pred_scores': xclip_part_scores,
                   'file_name': image_name}
    xclip_explanation = generate_xclip_explanations(result_dict, visibility, part_mask)
    modified_result_dict = {'descriptions': dict(zip(ORG_PART_ORDER, out_dict["modified_descriptions"])),
                            'pred_scores': custom_part_scores,
                            'file_name': image_name}
    modified_explanation = generate_xclip_explanations(modified_result_dict, visibility, part_mask)

    gt_label = gt_class.state.strip()
    xclip_color = "green" if xclip_label.strip() == gt_label else "red"
    # \u00a0 is a non-breaking space: it separates the label from the score.
    xclip_pred_markdown = f""" ### {xclip_label} \u00a0\u00a0\u00a0 {xclip_pred_score:.4f} """
    custom_color = "green" if custom_label.strip() == gt_label else "red"
    custom_pred_markdown = f""" ### {custom_label} \u00a0\u00a0\u00a0 {custom_pred_score:.4f} """

    textbox = gr.Textbox(visible=False)
    modified_exp = gr.HTML(value=modified_explanation, visible=True)
    return textbox, xclip_pred_markdown, xclip_explanation, custom_pred_markdown, modified_exp


custom_css = """
html, body {
    margin: 0;
    padding: 0;
}

#container {
    position: relative;
    width: 400px;
    height: 400px;
    border: 1px solid #000;
    margin: 0 auto; /* This will center the container horizontally */
}

#canvas {
    position: absolute;
    top: 0;
    left: 0;
    width: 100%;
    height: 100%;
    object-fit: cover;
}
"""

# Define the Gradio interface
with gr.Blocks(theme=gr.themes.Soft(), css=custom_css, title="PEEB") as demo:
    current_image = gr.State("")
    current_predicted_class = gr.State("")
    gt_class = gr.State("")

    with gr.Column():
        title_text = gr.Markdown("# Demo | A classifier with Part-based Explainable and Editable Bottleneck (PEEB)")
        gr.Markdown("PEEB is an image classifier, here for birds, pre-trained on Bird-11K and finetuned on CUB-200 (see our [NAACL 2024 paper](https://arxiv.org/abs/2403.05297) and [code](https://github.com/anguyen8/peeb/tree/inspect_ddp)).\n This **interactive** demo shows how to run PEEB on an existing image and how to **edit** a class' textual description to directly change the classifier to detect one new bird species (without any re-training).")
        # Heading and list items sit on their own lines so that Markdown
        # renders a heading followed by a numbered list.
        gr.Markdown(
            "### Steps:\n"
            "1. **Select an image**. Then, PEEB will show its grounded explanations and the top-1 predicted label with associated `softmax` confidence score.\n"
            "2. **Hover mouse over text descriptors** to see the corresponding region used to match to each text descriptor.\n"
            "3. **Edit the text under [Extra class]()** which correspond to one extra, new class (i.e. 200+1 = `201`). Further editing will overwrite this class' descriptors.\n"
            "4. **Click on [Predict]()** to see the grounded explanations and the top-1 label for the newly modified CUB-201 classifier.\n"
        )

    # display the gallery of images
    with gr.Column():
        gr.Markdown("## Select an image to start!")
        image_gallery = gr.Gallery(value=IMAGE_GALLERY, label=None, preview=False, allow_preview=False, columns=10, height=250)
        gr.Markdown("### Extra-class descriptors: \n The first row should be `class name: {some name};`, the name of your 201th class. \n For the 12 part descriptors, please use `;` to separate the descriptions for each part, and use the format `{part name}: {descriptions}`.")
        gr.Markdown("**Note:** you can delete a row for any given part (e.g. `nape`) and that part will be removed from all 201 classes in the classifier. For example, you can edit PEEB into a classifier that only identifies birds using 5 parts by deleting all rows corresponding to the other 7 parts.")

    with gr.Row():
        with gr.Column():
            image_label = gr.Markdown("### Class Name")
            org_image = gr.HTML()

        with gr.Column():
            with gr.Row():
                xclip_predict_button = gr.Button(value="Predict")
                xclip_pred_label = gr.Markdown("### Top-1 class:")
            xclip_explanation = gr.HTML()

        with gr.Column():
            xclip_edit_button = gr.Button(value="Reset Descriptions")
            custom_pred_label = gr.Markdown(
                "### Extra class:"
            )
            xclip_textbox = gr.Textbox(lines=12, placeholder="Edit the descriptions here", visible=False)
            custom_explanation = gr.HTML()

    gr.HTML("\n")

    image_gallery.select(update_selected_image, inputs=None, outputs=[image_label, org_image, xclip_pred_label, xclip_explanation, current_image, xclip_textbox])
    xclip_edit_button.click(on_edit_button_click_xclip, inputs=[], outputs=[xclip_textbox, custom_explanation])
    xclip_predict_button.click(on_predict_button_click_xclip, inputs=[xclip_textbox], outputs=[xclip_textbox, xclip_pred_label, xclip_explanation, custom_pred_label, custom_explanation])

demo.launch()