Spaces:
Build error
Build error
File size: 5,245 Bytes
8fbac9e 2145817 7aa3400 8fbac9e 2145817 83700f0 2145817 ac86774 8fbac9e 504c2ad 8fbac9e 9eae628 8fbac9e 5cbb0f4 58c2805 7aa3400 5cbb0f4 ac86774 8fbac9e 19d656f 2145817 8fbac9e 2145817 8fbac9e 19d656f 7aa3400 8fbac9e 2145817 f6a8ecc 2145817 8fbac9e a9023d8 8fbac9e ac86774 504c2ad 9eae628 504c2ad ac86774 2145817 ac86774 8fbac9e 7aa3400 8fbac9e 2145817 8fbac9e 9f9014d 7593627 2145817 83700f0 8fbac9e a9568f2 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 |
import gradio as gr
import numpy as np
from CLIP.clip import ClipWrapper, saliency_configs
from time import time
from matplotlib import pyplot as plt
import io
from PIL import Image, ImageDraw, ImageFont
import matplotlib
# Select the non-interactive Agg backend: figures in this file are only ever
# rendered into in-memory PNG buffers, never shown in a window.
matplotlib.use("Agg")
# Google tag (gtag.js) snippet, assembled one HTML line at a time. The leading
# empty entry reproduces the newline that precedes the first <script> element.
tag = "\n".join(
    (
        "",
        '<script async src="https://www.googletagmanager.com/gtag/js?id=G-T5BQ1GP083"></script>',
        "<script>",
        "window.dataLayer = window.dataLayer || [];",
        "function gtag(){dataLayer.push(arguments);}",
        "gtag('js', new Date());",
        "gtag('config', 'G-T5BQ1GP083');",
        "</script>",
    )
)
def plot_to_png(fig) -> np.ndarray:
    """Rasterize a Matplotlib figure to PNG in memory and return its pixels.

    Args:
        fig: The Matplotlib figure to render.

    Returns:
        The decoded PNG image as a ``uint8`` NumPy array.
    """
    # Save through the figure object itself. ``plt.savefig`` would render
    # pyplot's *current* figure, which silently ignores ``fig`` whenever the
    # two differ.
    with io.BytesIO() as buf:
        fig.savefig(buf, format="png")
        buf.seek(0)
        # ``np.array`` forces the image to be decoded before the buffer closes.
        return np.array(Image.open(buf)).astype(np.uint8)
def add_text_to_image(
    image: np.ndarray,
    text,
    position,
    color="rgb(255, 255, 255)",
    fontsize=60,
):
    """Draw ``text`` on an image array and return the annotated pixels.

    Args:
        image: Pixel array, converted with ``PIL.Image.fromarray``.
        text: String to draw.
        position: Location argument forwarded to ``ImageDraw.text``.
        color: Fill colour for the text; the default is white.
        fontsize: Size at which the Lato Medium font is loaded.

    Returns:
        A new NumPy array holding the image with the text drawn on it.
    """
    canvas = Image.fromarray(image)
    pen = ImageDraw.Draw(canvas)
    # The font file is read from a fixed system path on every call.
    lato = ImageFont.truetype(
        "/usr/share/fonts/truetype/lato/Lato-Medium.ttf", fontsize
    )
    pen.text(position, text, fill=color, font=lato)
    return np.array(canvas)
def generate_relevancy(
    img: np.ndarray, labels: str, prompt: str, saliency_config: str, subtract_mean: bool
):
    """Render one relevancy overlay per label for the given image.

    Args:
        img: Input image array. It must still be ``uint8`` after resizing.
        labels: Comma-separated label names. Surrounding whitespace is
            stripped, empty entries are dropped, and at most 32 labels are
            used.
        prompt: Prompt string, passed to the extractor as the single entry of
            ``prompts``.
        saliency_config: Key into ``saliency_configs`` selecting the preset.
        subtract_mean: When true, the mean over the first axis of the
            relevancy maps is subtracted before visualisation.

    Returns:
        A ``(images, tag)`` tuple. ``images`` holds one rendered overlay per
        label, or only the resized input image when no usable label was given
        or the saliency extractor raised.
    """
    max_labels = 32
    stripped = (name.strip() for name in labels.split(","))
    label_list = [name for name in stripped if name][:max_labels]
    prompts = [prompt]

    # Scale so the longest side becomes 224 * 4 pixels, keeping aspect ratio.
    height_width = np.array(img.shape[:2])
    scaled = ((height_width / height_width.max()) * 224 * 4).astype(int)
    # Guard against a zero-length side for extremely elongated images.
    new_h, new_w = np.maximum(scaled, 1).tolist()
    # ndarray.shape is (height, width) but PIL's resize expects (width, height);
    # passing the shape through unchanged would transpose non-square images.
    img = np.asarray(Image.fromarray(img).resize((new_w, new_h)))
    assert img.dtype == np.uint8
    # Unpacking three values also rejects arrays without a channel axis.
    h, _, _ = img.shape

    if not label_list:
        # Nothing to visualise; mirror the failure path and show the image.
        return ([img], tag)

    start = time()
    try:
        grads = ClipWrapper.get_clip_saliency(
            img=img,
            text_labels=np.array(label_list),
            prompts=prompts,
            **saliency_configs[saliency_config](h),
        )[0]
    except Exception as e:
        # UI boundary: report the error and fall back to the plain image
        # instead of failing the whole request.
        print(e)
        return ([img], tag)
    print("inference took", float(time() - start))

    if subtract_mean:
        grads -= grads.mean(axis=0)
    # NOTE(review): presumably a torch tensor, given the .cpu() call — confirm.
    grads = grads.cpu().numpy()

    # Relevancy values are normalised from [vmin, vmax] into [0, 1].
    vmin = 0.002
    vmax = 0.008
    cmap = plt.get_cmap("jet")
    returns = []
    for label_grad, label in zip(grads, label_list):
        fig, ax = plt.subplots(1, 1, figsize=(4, 4))
        try:
            ax.axis("off")
            ax.imshow(img)
            grad = np.clip((label_grad - vmin) / (vmax - vmin), a_min=0.0, a_max=1.0)
            colored_grad = cmap(grad)
            # Overlay opacity is 0.7 * (1 - normalised relevancy).
            colored_grad[..., -1] = (1 - grad) * 0.7
            colored_grad = add_text_to_image(
                (colored_grad * 255).astype(np.uint8), text=label, position=(0, 0)
            )
            colored_grad = colored_grad.astype(float) / 255
            ax.imshow(colored_grad)
            fig.tight_layout(pad=0)
            returns.append(plot_to_png(fig))
        finally:
            # Always release the figure, even if rendering fails part-way.
            plt.close(fig)
    return (returns, tag)
# Markdown shown beneath the title of the demo page.
_DEMO_DESCRIPTION = """A CPU-only demo of [Semantic Abstraction](https://semantic-abstraction.cs.columbia.edu/)'s Multi-Scale Relevancy Extractor. To run GPU inference locally, use the [official codebase release](https://github.com/columbia-ai-robotics/semantic-abstraction).
This relevancy extractor builds heavily on [Chefer et al.'s codebase](https://github.com/hila-chefer/Transformer-MM-Explainability) and [CLIP on Wheels' codebase](https://cow.cs.columbia.edu/)."""

# Input widgets, in the same order as the parameters of generate_relevancy.
_DEMO_INPUTS = [
    gr.Image(type="numpy", label="Image"),
    gr.Textbox(label="Labels (comma separated without spaces in between)"),
    gr.Textbox(
        label="Prompt. (Make sure to include '{}' in the prompt like examples below)"
    ),
    gr.Dropdown(
        value="ours",
        choices=["ours", "ours_fast", "chefer_et_al"],
        label="Relevancy Configuration",
    ),
    gr.Checkbox(value=True, label="subtract mean"),
]

# Output widgets, matching the (images, tag) tuple returned by generate_relevancy.
_DEMO_OUTPUTS = [
    gr.Gallery(label="Relevancy Maps", type="numpy"),
    gr.HTML(value=tag),
]

# Example rows: image URL, labels, prompt, relevancy configuration, subtract mean.
_DEMO_EXAMPLES = [
    [
        "https://semantic-abstraction.cs.columbia.edu/downloads/gameroom.png",
        "basketball jersey,nintendo switch,television,ping pong table,vase,fireplace,abstract painting of a vespa,carpet,wall",
        "a photograph of a {} in a home.",
        "ours_fast",
        True,
    ],
    [
        "https://semantic-abstraction.cs.columbia.edu/downloads/livingroom.png",
        "monopoly boardgame set,door knob,sofa,coffee table,plant,carpet,wall",
        "a photograph of a {} in a home.",
        "ours_fast",
        True,
    ],
    [
        "https://semantic-abstraction.cs.columbia.edu/downloads/fireplace.png",
        "fireplace,beige armchair,candle,large indoor plant in a pot,forest painting,cheetah-patterned pillow,floor,carpet,wall",
        "a photograph of a {} in a home.",
        "ours_fast",
        True,
    ],
    [
        "https://semantic-abstraction.cs.columbia.edu/downloads/walle.png",
        "WALL-E,a fire extinguisher",
        "a 3D render of {}.",
        "ours_fast",
        True,
    ],
]

# Assemble the Gradio interface from the pieces above and start serving it.
iface = gr.Interface(
    title="Semantic Abstraction Multi-scale Relevancy Extractor",
    description=_DEMO_DESCRIPTION,
    fn=generate_relevancy,
    cache_examples=True,
    inputs=_DEMO_INPUTS,
    outputs=_DEMO_OUTPUTS,
    examples=_DEMO_EXAMPLES,
)
iface.launch()
|