Spaces:
Sleeping
Sleeping
import io | |
import os | |
# os.system("pip uninstall -y gradio") | |
# os.system("pip install gradio==3.41.0") | |
import torch | |
import json | |
import base64 | |
import random | |
import numpy as np | |
import pandas as pd | |
import gradio as gr | |
from pathlib import Path | |
from PIL import Image | |
from plots import get_pre_define_colors | |
from utils.load_model import load_xclip | |
from utils.predict import xclip_pred | |
DEVICE = "cpu"


def _load_json(path):
    """Read and parse the JSON file at ``path``, closing the file afterwards."""
    with open(path, "r") as f:
        return json.load(f)


def _open_rgb(path):
    """Open the image at ``path`` as an RGB copy and release the file handle."""
    with Image.open(path) as img:
        return img.convert("RGB")


XCLIP, OWLVIT_PRECESSOR = load_xclip(DEVICE)
XCLIP_DESC_PATH = "data/jsons/bs_cub_desc.json"
XCLIP_DESC = _load_json(XCLIP_DESC_PATH)


def PREPROCESS(x):
    """Run the OWL-ViT processor on image(s) ``x``, returning PyTorch tensors."""
    return OWLVIT_PRECESSOR(images=x, return_tensors='pt')


IMAGES_FOLDER = "data/images"
XCLIP_RESULTS = _load_json("data/jsons/xclip_org.json")
CUB_DESC_EMBEDS = torch.load('data/text_embeddings/cub_200_desc.pt')
# JSON object keys are always strings; convert the class indices back to int.
CUB_IDX2NAME = {int(k): v for k, v in _load_json('data/jsons/cub_desc_idx2name.json').items()}

# The gallery file list is cached on disk. It was presumably produced by an
# earlier (now removed) step that randomly sampled 20 images of at least
# 400x400 pixels from XCLIP_RESULTS.
IMAGE_FILE_LIST = _load_json("data/jsons/file_list.json")
IMAGE_GALLERY = [_open_rgb(os.path.join(IMAGES_FOLDER, 'org', file_name)) for file_name in IMAGE_FILE_LIST]

# Part order used by the description lists in XCLIP_DESC and by the model outputs.
ORG_PART_ORDER = ['back', 'beak', 'belly', 'breast', 'crown', 'forehead', 'eyes', 'legs', 'wings', 'nape', 'tail', 'throat']
# Part order used for display (explanations, textbox contents, colors).
ORDERED_PARTS = ['crown', 'forehead', 'nape', 'eyes', 'beak', 'throat', 'breast', 'belly', 'back', 'wings', 'legs', 'tail']
COLORS = get_pre_define_colors(len(ORDERED_PARTS), cmap_set=['Set2', 'tab10'])
SACHIT_COLOR = "#ADD8E6"
VISIBILITY_DICT = _load_json("data/jsons/cub_vis_dict_binary.json")
# NOTE(review): presumably this extra sample has no entry in the visibility
# file, so all of its parts are marked visible. Confirm against the data.
VISIBILITY_DICT['Eastern_Bluebird.jpg'] = dict.fromkeys(ORDERED_PARTS, True)
# --- Image related functions --- | |
def img_to_base64(img):
    """Encode an image as a base64 JPEG string (without a ``data:`` prefix).

    Args:
        img: a PIL image, or a numpy array accepted by ``Image.fromarray``.

    Returns:
        The base64-encoded JPEG bytes decoded to ``str``, ready to embed in a
        ``data:image/jpeg;base64,...`` URI.
    """
    img_pil = Image.fromarray(img) if isinstance(img, np.ndarray) else img
    # Pillow's JPEG writer only accepts these modes. Anything else (RGBA,
    # palette, LA, ...) would raise OSError, so flatten it to RGB first.
    if img_pil.mode not in ("1", "L", "RGB", "RGBX", "CMYK", "YCbCr"):
        img_pil = img_pil.convert("RGB")
    buffered = io.BytesIO()
    img_pil.save(buffered, format="JPEG")
    return base64.b64encode(buffered.getvalue()).decode()
def create_blank_image(width=500, height=500, color=(255, 255, 255)):
    """Build a solid-color RGB canvas and return it as a numpy array.

    Args:
        width: canvas width in pixels.
        height: canvas height in pixels.
        color: fill color passed to ``Image.new`` (default white).
    """
    canvas = Image.new(mode="RGB", size=(width, height), color=color)
    pixels = np.array(canvas)
    return pixels
def rgb_to_hex(rgb):
    """Format a sequence of integer channel values as a ``#rrggbb``-style hex string.

    Each channel is written as at least two lowercase hex digits.
    """
    hex_digits = "".join(format(channel, "02x") for channel in rgb)
    return "#" + hex_digits
def load_part_images(file_name: str) -> dict:
    """Load the per-part crops of one image as base64 JPEG strings.

    For each part in ORDERED_PARTS, reads
    ``IMAGES_FOLDER/boxes/<stem of file_name>_<part>.jpg``. Parts whose crop
    file does not exist are left out of the result.

    Args:
        file_name: the original image's file name (only its stem is used).

    Returns:
        ``{part_name: base64_jpeg_str}`` for every crop found on disk.
    """
    # Loop-invariant: the stem and the crops directory are the same for every part.
    base_name = Path(file_name).stem
    boxes_dir = os.path.join(IMAGES_FOLDER, "boxes")
    part_images = {}
    for part_name in ORDERED_PARTS:
        part_image_path = os.path.join(boxes_dir, f"{base_name}_{part_name}.jpg")
        if not Path(part_image_path).exists():
            continue
        # The context manager closes the file once the pixels are copied out.
        with Image.open(part_image_path) as part_image:
            part_images[part_name] = img_to_base64(np.array(part_image))
    return part_images
def generate_xclip_explanations(result_dict: dict, visibility: dict, part_mask: "dict | None" = None):
    """Render per-part XCLIP descriptions and scores as an inline SVG inside a ``<div>``.

    Args:
        result_dict: must contain three keys:
            'descriptions': {part_name1: desc_1, part_name2: desc_2, ...}
            'pred_scores': {part_name1: score_1, part_name2: score_2, ...}
            'file_name': key into PART_IMAGES_DICT for the hover overlays.
        visibility: ``{part_name: truthy if visible}`` for the image.
        part_mask: ``{part_name: 0/1}``; parts with a falsy value are skipped.
            Defaults to every part enabled.

    Returns:
        An HTML string. For every visible, unmasked part (in ORDERED_PARTS
        order) it contains the description wrapped at 50 characters, a bar
        whose width is the score (clamped at 0) times 400px, and the score
        printed with two decimals. Hovering the text or bar shows that part's
        crop in the ``overlayImage`` element.
    """
    from html import escape  # descriptions can come from the user's textbox

    if part_mask is None:
        # Built per call rather than as a shared mutable default argument.
        part_mask = dict(zip(ORDERED_PARTS, [1] * len(ORDERED_PARTS)))
    descriptions = result_dict['descriptions']
    image_name = result_dict['file_name']
    part_images = PART_IMAGES_DICT[image_name]
    MAX_LENGTH = 50  # characters per wrapped text line
    exp_length = 400  # SVG width in px; a score of 1.0 spans the full width
    fontsize = 15
    # Start the SVG inside a div
    svg_parts = [f'<div style="width: {exp_length}px; height: 450px; background-color: white;">',
                 "<svg width=\"100%\" height=\"100%\">"]
    mouseout_action = "document.getElementById('overlayImage').style.opacity = 0;"
    combined_mouseout = f"javascript: {mouseout_action};"
    # Add a row for each visible bird part
    y_offset = 0
    for part in ORDERED_PARTS:
        if not (visibility[part] and part_mask[part]):
            continue
        part_score = max(result_dict['pred_scores'][part], 0)
        bar_length = part_score * exp_length
        part_image = part_images.get(part)
        if part_image is None:
            # load_part_images skips crops missing on disk; keep the overlay hidden.
            combined_mouseover = combined_mouseout
        else:
            mouseover_action = f"document.getElementById('overlayImage').src = 'data:image/jpeg;base64,{part_image}'; document.getElementById('overlayImage').style.opacity = 1;"
            combined_mouseover = f"javascript: {mouseover_action};"
        # Wrap the description; an empty description still yields one (empty) line.
        text = descriptions[part]
        desc_lines = [text[i:i + MAX_LENGTH] for i in range(0, len(text), MAX_LENGTH)] or [""]
        for desc_line in desc_lines:
            y_offset += fontsize
            svg_parts.append(f"""
<text x="0" y="{y_offset}" font-size="{fontsize}"
onmouseover="{combined_mouseover}"
onmouseout="{combined_mouseout}">
{escape(desc_line)}
</text>
""")
        # Add the bar
        svg_parts.append(f"""
<rect x="0" y="{y_offset + 3}" width="{bar_length}" height="{fontsize * 0.7}" fill="{PART_COLORS[part]}"
onmouseover="{combined_mouseover}"
onmouseout="{combined_mouseout}">
</rect>
""")
        # Add the score
        svg_parts.append(f'<text x="{exp_length - 50}" y="{y_offset + fontsize + 3}" font-size="{fontsize}" fill="{PART_COLORS[part]}">{part_score:.2f}</text>')
        y_offset += fontsize + 3
    svg_parts.extend(("</svg>", "</div>"))
    return "".join(svg_parts)
def generate_sachit_explanations(result_dict: dict, color: "str | None" = None):
    """Render score-ranked free-form descriptions as an inline SVG inside a ``<div>``.

    Args:
        result_dict: must contain 'descriptions' (sequence of strings) and
            'scores' (parallel sequence of numbers).
        color: fill color for bars and score labels; defaults to SACHIT_COLOR.

    Returns:
        An HTML string with one row per description, sorted by score (highest
        first). Each row has the description wrapped at 50 characters, a bar
        whose width is the score (clamped at 0) times 400px, and the score
        printed with two decimals.
    """
    from html import escape  # keep arbitrary description text from breaking the SVG

    if color is None:
        color = SACHIT_COLOR
    descriptions = result_dict['descriptions']
    scores = result_dict['scores']
    MAX_LENGTH = 50  # characters per wrapped text line
    exp_length = 400  # SVG width in px; a score of 1.0 spans the full width
    fontsize = 15
    ranked = sorted(zip(scores, descriptions), key=lambda x: x[0], reverse=True)
    # Start the SVG inside a div
    svg_parts = [f'<div style="width: {exp_length}px; height: 450px; background-color: white;">',
                 "<svg width=\"100%\" height=\"100%\">"]
    y_offset = 0
    for score, desc in ranked:
        part_score = max(score, 0)
        bar_length = part_score * exp_length
        # Wrap the description; an empty description still yields one (empty) line.
        desc_lines = [desc[i:i + MAX_LENGTH] for i in range(0, len(desc), MAX_LENGTH)] or [""]
        for desc_line in desc_lines:
            y_offset += fontsize
            svg_parts.append(f"""
<text x="0" y="{y_offset}" font-size="{fontsize}" fill="black">
{escape(desc_line)}
</text>
""")
        # Add the bar
        svg_parts.append(f"""
<rect x="0" y="{y_offset + 3}" width="{bar_length}" height="{fontsize * 0.7}" fill="{color}">
</rect>
""")
        # Add the score (font-size was previously the literal text "fontsize")
        svg_parts.append(f'<text x="{exp_length - 50}" y="{y_offset + fontsize + 3}" font-size="{fontsize}" fill="{color}">{part_score:.2f}</text>')
        y_offset += fontsize + 3
    svg_parts.extend(("</svg>", "</div>"))
    return "".join(svg_parts)
# --- Constants created by the functions above ---
# Base64 JPEG of a 500x500 white image, used as the hidden hover-overlay placeholder.
BLANK_OVERLAY = img_to_base64(create_blank_image())
# Display color per part, as "#rrggbb", indexed in ORDERED_PARTS order.
PART_COLORS = {part: rgb_to_hex(COLORS[i]) for i, part in enumerate(ORDERED_PARTS)}
# NOTE(review): not referenced by the handlers in this file; kept for compatibility.
with Image.open('data/images/final.png') as _final_png:
    blank_image = np.array(_final_png.convert('RGB'))
# Base64 part crops for every gallery image, loaded once at startup.
PART_IMAGES_DICT = {file_name: load_part_images(file_name) for file_name in IMAGE_FILE_LIST}
# --- Gradio Functions --- | |
def update_selected_image(event: gr.SelectData):
    """Handle a gallery click: show the image and its initial XCLIP prediction.

    Side effects: stores the selected file name in ``current_image.state``,
    the ground-truth label in ``gt_class.state`` and the predicted class in
    ``current_predicted_class.state``; the other handlers read these.
    NOTE(review): these are attributes of module-level component objects, so
    they appear to be shared by all sessions. Confirm whether concurrent users
    are expected.

    Returns, in the order wired to the gallery's select event:
        ground-truth markdown, image HTML (photo plus hidden hover overlay),
        XCLIP prediction markdown, XCLIP explanation HTML, the current_image
        component, and a textbox update pre-filled with the predicted class's
        descriptions.
    """
    image_height = 400
    image_name = IMAGE_FILE_LIST[event.index]
    current_image.state = image_name
    # Open the file once and reuse it for both display and prediction.
    with Image.open(os.path.join(IMAGES_FOLDER, 'org', image_name)) as raw_image:
        org_image = raw_image.convert('RGB')
    img_base64 = f"""
<div style="position: relative; height: {image_height}px; display: inline-block;">
<img id="birdImage" src="data:image/jpeg;base64,{img_to_base64(org_image)}" style="height: {image_height}px; width: auto;">
<img id="overlayImage" src="data:image/jpeg;base64,{BLANK_OVERLAY}" style="position:absolute; top:0; left:0; width:auto; height: {image_height}px; opacity: 0;">
</div>
"""
    gt_label = XCLIP_RESULTS[image_name]['ground_truth']
    gt_class.state = gt_label

    # --- initial prediction with the original descriptions ---
    out_dict = xclip_pred(new_desc=None,
                          new_part_mask=None,
                          new_class=None,
                          org_desc=XCLIP_DESC_PATH,
                          image=org_image,
                          model=XCLIP,
                          owlvit_processor=OWLVIT_PRECESSOR,
                          device=DEVICE,
                          image_name=image_name,
                          cub_embeds=CUB_DESC_EMBEDS,
                          cub_idx2name=CUB_IDX2NAME,
                          descriptors=XCLIP_DESC)
    xclip_label = out_dict['pred_class']
    clip_pred_scores = out_dict['pred_score']
    xclip_part_scores = out_dict['pred_desc_scores']
    # Model descriptions come in ORG_PART_ORDER; key them by part name.
    result_dict = {'descriptions': dict(zip(ORG_PART_ORDER, out_dict["descriptions"])), 'pred_scores': xclip_part_scores, 'file_name': image_name}
    xclip_exp = generate_xclip_explanations(result_dict, VISIBILITY_DICT[image_name], part_mask=dict(zip(ORDERED_PARTS, [1] * 12)))
    # --- end of initial value ---

    xclip_color = "green" if xclip_label.strip() == gt_label.strip() else "red"
    xclip_pred_markdown = f"""
### <span style='color:{xclip_color}'>XCLIP: {xclip_label} {clip_pred_scores:.4f}</span>
"""
    gt_markdown = f"""
## {gt_label}
"""
    current_predicted_class.state = xclip_label

    # Populate the textbox with the predicted class's descriptions, in display order.
    custom_class_name = "class name: custom"
    descs = XCLIP_DESC[xclip_label]
    descs = {k: descs[i] for i, k in enumerate(ORG_PART_ORDER)}
    descs = {k: descs[k] for k in ORDERED_PARTS}
    custom_text = [custom_class_name] + list(descs.values())
    descriptions = ";\n".join(custom_text)
    textbox = gr.Textbox.update(value=descriptions, lines=12, visible=True, label="XCLIP descriptions", interactive=True, info='Please use ";" to separate the descriptions for each part, and keep the format of {part name}: {descriptions}', show_label=False)
    return gt_markdown, img_base64, xclip_pred_markdown, xclip_exp, current_image, textbox
def on_edit_button_click_xclip():
    """Reset the description editor to the currently predicted class's text.

    Returns:
        A visible, editable textbox update. Its value is "class name: custom"
        followed by the predicted class's part descriptions in ORDERED_PARTS
        order, joined with a semicolon plus newline. The second value is an
        update that hides the custom explanation HTML.
    """
    class_descs = XCLIP_DESC[current_predicted_class.state]
    # Stored descriptions follow ORG_PART_ORDER; re-key them by part name.
    by_part = {part: class_descs[idx] for idx, part in enumerate(ORG_PART_ORDER)}
    editor_lines = ["class name: custom", *(by_part[part] for part in ORDERED_PARTS)]
    hidden_explanation = gr.HTML.update(visible=False)
    textbox_update = gr.Textbox.update(
        value=";\n".join(editor_lines),
        lines=12,
        visible=True,
        label="XCLIP descriptions",
        interactive=True,
        info='Please use ";" to separate the descriptions for each part, and keep the format of {part name}: {descriptions}',
        show_label=False,
    )
    return textbox_update, hidden_explanation
def convert_input_text_to_xclip_format(textbox_input: str, part_order=None):
    """Parse the description editor's text into the inputs ``xclip_pred`` expects.

    Expected format (entries separated by ";" followed by a line break)::

        class name: <new class name>;
        <part name>: <description>;
        ...

    Args:
        textbox_input: raw textbox contents. A trailing ";", surrounding
            whitespace and CRLF line breaks are tolerated.
        part_order: parts to fill in and mask; defaults to ORDERED_PARTS.

    Returns:
        ``(descriptions_dict, part_mask, new_class_name)``:
        descriptions_dict maps each part name found in the text to its full
        stripped ``"<part>: <description>"`` entry, and every part of
        ``part_order`` that is missing to ``""``. part_mask maps each part of
        ``part_order`` to 1 if it was provided, else 0. new_class_name is the
        text after the first ":" of the first entry.

    Raises:
        ValueError: if the class-name entry or a part entry has no ":".
    """
    import re

    if part_order is None:
        part_order = ORDERED_PARTS
    # Split on ";" + optional whitespace + newline (also covers "\r\n" and "; \n").
    class_name_line, *entries = re.split(r";\s*\n", textbox_input)
    # The first entry should be "class name: xxx"; keep any ":" inside the name.
    _, sep, new_class_name = class_name_line.partition(":")
    if not sep:
        raise ValueError(f"Expected 'class name: <name>' as the first entry, got {class_name_line!r}")
    new_class_name = new_class_name.strip().rstrip(";").strip()

    # Construct the description dict with the part name as key.
    descriptions_dict = {}
    for entry in entries:
        entry = entry.strip().rstrip(";").strip()
        if not entry:
            continue
        # Only the first ":" separates the part name; descriptions may contain more.
        part_name, sep, _ = entry.partition(":")
        if not sep:
            raise ValueError(f"Expected '<part name>: <description>', got {entry!r}")
        descriptions_dict[part_name.strip()] = entry

    # Fill missing parts with "" and mask them out.
    part_mask = {}
    for part in part_order:
        if part in descriptions_dict:
            part_mask[part] = 1
        else:
            descriptions_dict[part] = ""
            part_mask[part] = 0
    return descriptions_dict, part_mask, new_class_name
def on_predict_button_click_xclip(textbox_input: str):
    """Re-run XCLIP with the user's edited descriptions and render both results.

    Args:
        textbox_input: the description editor's text (see
            convert_input_text_to_xclip_format for the format).

    Returns, in the order wired to the Predict button:
        a textbox update that hides the editor, the original XCLIP prediction
        markdown, the original XCLIP explanation HTML, the custom prediction
        markdown, and an HTML update showing the custom explanation. Both
        explanations apply the user's part mask. Labels are green when they
        match ``gt_class.state``, red otherwise.
    """
    descriptions_dict, part_mask, new_class_name = convert_input_text_to_xclip_format(textbox_input)
    image_name = current_image.state
    with Image.open(os.path.join(IMAGES_FOLDER, 'org', image_name)) as raw_image:
        image = raw_image.convert('RGB')

    # Get the new predictions and explanations
    out_dict = xclip_pred(new_desc=descriptions_dict,
                          new_part_mask=part_mask,
                          new_class=new_class_name,
                          org_desc=XCLIP_DESC_PATH,
                          image=image,
                          model=XCLIP,
                          owlvit_processor=OWLVIT_PRECESSOR,
                          device=DEVICE,
                          image_name=image_name,
                          cub_embeds=CUB_DESC_EMBEDS,
                          cub_idx2name=CUB_IDX2NAME,
                          descriptors=XCLIP_DESC)
    xclip_label = out_dict['pred_class']
    xclip_pred_score = out_dict['pred_score']
    xclip_part_scores = out_dict['pred_desc_scores']
    custom_label = out_dict['modified_class']
    custom_pred_score = out_dict['modified_score']
    custom_part_scores = out_dict['modified_desc_scores']

    # Model descriptions come in ORG_PART_ORDER; key them by part name.
    visibility = VISIBILITY_DICT[image_name]
    result_dict = {'descriptions': dict(zip(ORG_PART_ORDER, out_dict["descriptions"])), 'pred_scores': xclip_part_scores, 'file_name': image_name}
    xclip_explanation = generate_xclip_explanations(result_dict, visibility, part_mask)
    modified_result_dict = {'descriptions': dict(zip(ORG_PART_ORDER, out_dict["modified_descriptions"])), 'pred_scores': custom_part_scores, 'file_name': image_name}
    modified_explanation = generate_xclip_explanations(modified_result_dict, visibility, part_mask)

    gt_label = gt_class.state.strip()
    xclip_color = "green" if xclip_label.strip() == gt_label else "red"
    xclip_pred_markdown = f"""
### <span style='color:{xclip_color}'> {xclip_label} {xclip_pred_score:.4f}</span>
"""
    custom_color = "green" if custom_label.strip() == gt_label else "red"
    custom_pred_markdown = f"""
### <span style='color:{custom_color}'> {custom_label} {custom_pred_score:.4f}</span>
"""
    textbox = gr.Textbox.update(visible=False)
    # Static update helper; no throwaway component instance is needed.
    modified_exp = gr.HTML.update(value=modified_explanation, visible=True)
    return textbox, xclip_pred_markdown, xclip_explanation, custom_pred_markdown, modified_exp
custom_css = """
html, body {
    margin: 0;
    padding: 0;
}
#container {
    position: relative;
    width: 400px;
    height: 400px;
    border: 1px solid #000;
    margin: 0 auto; /* This will center the container horizontally */
}
#canvas {
    position: absolute;
    top: 0;
    left: 0;
    width: 100%;
    height: 100%;
    object-fit: cover;
}
"""

# Define the Gradio interface
with gr.Blocks(theme=gr.themes.Soft(), css=custom_css, title="PEEB") as demo:
    # State holders; the handlers read and write their ``.state`` attribute.
    current_image = gr.State("")
    current_predicted_class = gr.State("")
    gt_class = gr.State("")
    with gr.Column():
        title_text = gr.Markdown("# PEEB - demo")
        gr.Markdown(
            """
- This is a demo for the PEEB paper (NAACL Findings 2024).
- paper: https://arxiv.org/abs/2403.05297
- code: https://github.com/anguyen8/peeb/tree/inspect_ddp
"""
        )
    # Display the gallery of images
    with gr.Column():
        gr.Markdown("## Select an image to start!")
        image_gallery = gr.Gallery(value=IMAGE_GALLERY, label=None, preview=False, allow_preview=False, columns=10, height=250)
    gr.Markdown("### Custom descriptions: \n The first row should be **class name: {some name};**, where you can name your descriptions. \n For the remaining descriptions, please use **;** at the end of each line to separate the descriptions for each part (one part per line), and use the format **{part name}: {descriptions}**. \n Note that you can delete a part completely; in that case, the corresponding part is removed from all explanations.")
    with gr.Row():
        with gr.Column():
            image_label = gr.Markdown("### Class Name")
            org_image = gr.HTML()
        with gr.Column():
            with gr.Row():
                xclip_predict_button = gr.Button(value="Predict")
            xclip_pred_label = gr.Markdown("### XCLIP:")
            xclip_explanation = gr.HTML()
        with gr.Column():
            xclip_edit_button = gr.Button(value="Reset Descriptions")
            custom_pred_label = gr.Markdown(
                "### Custom Descriptions:"
            )
            xclip_textbox = gr.Textbox(lines=12, placeholder="Edit the descriptions here", visible=False)
            custom_explanation = gr.HTML()
    gr.HTML("<br>")

    # Event wiring: the output order must match each handler's return tuple.
    image_gallery.select(update_selected_image, inputs=None, outputs=[image_label, org_image, xclip_pred_label, xclip_explanation, current_image, xclip_textbox])
    xclip_edit_button.click(on_edit_button_click_xclip, inputs=[], outputs=[xclip_textbox, custom_explanation])
    xclip_predict_button.click(on_predict_button_click_xclip, inputs=[xclip_textbox], outputs=[xclip_textbox, xclip_pred_label, xclip_explanation, custom_pred_label, custom_explanation])

demo.launch(server_port=5000, share=True)