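# NOTE: this file appears to have been captured from a Hugging Face Spaces page
# whose status banner read "Build error" at capture time.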
import gradio as gr | |
import numpy as np | |
from CLIP.clip import ClipWrapper, saliency_configs | |
from time import time | |
from matplotlib import pyplot as plt | |
import io | |
from PIL import Image, ImageDraw, ImageFont | |
def plot_to_png(fig):
    """Render a Matplotlib figure to PNG and decode it into a pixel array.

    The figure is written to an in-memory PNG and read back with Pillow.

    Args:
        fig: The Matplotlib figure to render.

    Returns:
        A ``np.uint8`` array with the decoded PNG pixels.
        NOTE(review): presumably RGBA with shape (H, W, 4) -- confirm.
    """
    # The context manager releases the buffer even if rendering or decoding
    # raises.
    with io.BytesIO() as buf:
        # Save the figure that was passed in. ``plt.savefig`` would write
        # whichever figure pyplot currently considers active and silently
        # ignore ``fig``.
        fig.savefig(buf, format="png")
        buf.seek(0)
        # ``np.array`` forces Pillow to decode while the buffer is still open.
        return np.array(Image.open(buf)).astype(np.uint8)
def add_text_to_image(
    image: np.ndarray,
    text,
    position,
    color="rgb(255, 255, 255)",
    fontsize=60,
    font_path="/usr/share/fonts/truetype/lato/Lato-Medium.ttf",
):
    """Draw ``text`` on an image array and return the result as a new array.

    Args:
        image: Pixel array accepted by ``PIL.Image.fromarray``.
        text: The string to draw.
        position: Coordinates forwarded to ``ImageDraw.text``.
        color: Fill colour forwarded to ``ImageDraw.text``.
        fontsize: Size used when loading the TrueType font.
        font_path: Path of the TrueType font file to load. Defaults to the
            Lato font location this function has always used.

    Returns:
        A new array built from the image after the text has been drawn.
    """
    try:
        font = ImageFont.truetype(font_path, fontsize)
    except OSError:
        # The font file is missing or unreadable on this machine. Fall back
        # to Pillow's built-in font instead of failing the whole request.
        # The fallback does not honour ``fontsize``.
        font = ImageFont.load_default()

    canvas = Image.fromarray(image)
    ImageDraw.Draw(canvas).text(position, text, fill=color, font=font)
    return np.array(canvas)
def generate_relevancy(
    img: np.array, labels: str, prompt: str, saliency_config: str, subtract_mean: bool
):
    """Compute one relevancy overlay image per label for ``img``.

    Args:
        img: Input image array; it is resized to 976x976 before inference.
        labels: Comma-separated label names. Surrounding whitespace is
            stripped and empty entries are ignored.
        prompt: Prompt template forwarded to ``ClipWrapper.get_clip_saliency``.
        saliency_config: Key into ``saliency_configs`` selecting the
            saliency settings.
        subtract_mean: When true, the mean over the first axis of the
            saliency output is subtracted before rendering.

    Returns:
        A list with one rendered image array per label, in label order.
        The list is empty when ``labels`` contains no usable label.

    Raises:
        ValueError: If the resized image is not ``uint8``.
    """
    # Users naturally type "a, b"; strip padding and drop blank entries so
    # the text handed to the model is exactly the label.
    label_names = [name.strip() for name in labels.split(",")]
    label_names = [name for name in label_names if name]
    if not label_names:
        return []

    prompts = [prompt]
    # NOTE(review): 244 may have been intended as 224 -- confirm before changing.
    img = np.asarray(Image.fromarray(img).resize((244 * 4, 244 * 4)))
    # Explicit check instead of ``assert``, which is stripped under ``-O``.
    if img.dtype != np.uint8:
        raise ValueError(f"expected a uint8 image, got dtype {img.dtype}")
    # Three-way unpack also rejects images that are not 3-dimensional.
    height, _width, _channels = img.shape

    start = time()
    grads = ClipWrapper.get_clip_saliency(
        img=img,
        text_labels=np.array(label_names),
        prompts=prompts,
        **saliency_configs[saliency_config](height),
    )[0]
    print("inference took", float(time() - start))

    if subtract_mean:
        grads -= grads.mean(axis=0)
    grads = grads.cpu().numpy()

    # Saliency values are normalised from [vmin, vmax] to [0, 1] for colouring.
    vmin = 0.002
    vmax = 0.008
    cmap = plt.get_cmap("jet")

    returns = []
    for label_grad, label in zip(grads, label_names):
        fig, ax = plt.subplots(1, 1, figsize=(4, 4))
        try:
            ax.axis("off")
            ax.imshow(img)
            grad = np.clip((label_grad - vmin) / (vmax - vmin), a_min=0.0, a_max=1.0)
            colored_grad = cmap(grad)
            # Alpha channel is the inverted normalised saliency scaled by 0.7.
            grad = 1 - grad
            colored_grad[..., -1] = grad * 0.7
            # Text drawing works on uint8 pixels; convert there and back.
            colored_grad = add_text_to_image(
                (colored_grad * 255).astype(np.uint8), text=label, position=(0, 0)
            )
            colored_grad = colored_grad.astype(float) / 255
            ax.imshow(colored_grad)
            plt.tight_layout(pad=0)
            returns.append(plot_to_png(fig))
        finally:
            # Always release the figure, even when rendering fails, so a
            # long-running server does not accumulate open figures.
            plt.close(fig)
    return returns
# Markdown blurb shown under the title of the demo page.
_DESCRIPTION = """A demo of [Semantic Abstraction](https://semantic-abstraction.cs.columbia.edu/)'s Multi-Scale Relevancy Extractor. To run GPU inference locally, use the [official codebase release](https://github.com/columbia-ai-robotics/semantic-abstraction).
This relevancy extractor builds heavily on [Chefer et al.'s codebase](https://github.com/hila-chefer/Transformer-MM-Explainability) and [CLIP on Wheels' codebase](https://cow.cs.columbia.edu/)."""

# (image URL, comma-separated labels, prompt template) for each example.
_EXAMPLE_SCENES = (
    (
        "https://semantic-abstraction.cs.columbia.edu/downloads/gameroom.png",
        "basketball jersey,nintendo switch,television,ping pong table,vase,fireplace,abstract painting of a vespa,carpet,wall",
        "a photograph of a {} in a home.",
    ),
    (
        "https://semantic-abstraction.cs.columbia.edu/downloads/livingroom.png",
        "monopoly boardgame set,door knob,sofa,coffee table,plant,carpet,wall",
        "a photograph of a {} in a home.",
    ),
    (
        "https://semantic-abstraction.cs.columbia.edu/downloads/fireplace.png",
        "fireplace,beige armchair,candle,large indoor plant in a pot,forest painting,cheetah-patterned pillow,floor,carpet,wall",
        "a photograph of a {} in a home.",
    ),
    (
        "https://semantic-abstraction.cs.columbia.edu/downloads/walle.png",
        "WALL-E,a fire extinguisher",
        "a 3D render of {}.",
    ),
)

# Input widgets, in the positional order expected by generate_relevancy.
_input_components = [
    gr.Image(type="numpy", label="Image"),
    gr.Textbox(label="Labels (comma separated)"),
    gr.Textbox(label="Prompt"),
    gr.Dropdown(
        value="ours",
        choices=["ours", "ours_fast", "chefer_et_al"],
        label="Relevancy Configuration",
    ),
    gr.Checkbox(value=True, label="subtract mean"),
]
_output_component = gr.Gallery(label="Relevancy Maps", type="numpy")

iface = gr.Interface(
    fn=generate_relevancy,
    inputs=_input_components,
    outputs=_output_component,
    # Every example uses the "ours_fast" configuration with mean subtraction on.
    examples=[
        [image_url, label_text, prompt_template, "ours_fast", True]
        for image_url, label_text, prompt_template in _EXAMPLE_SCENES
    ],
    cache_examples=True,
    title="Semantic Abstraction Multi-scale Relevancy Extractor",
    description=_DESCRIPTION,
)

# Alternative kept from the original: iface.launch(share=True)
iface.launch()