Samuel Stevens committed
Commit: c0b4385
Parent(s): 29d1b06

Updates based on lab feedback

Files changed:
- app.py (+140, -89)
- requirements.txt (+5, -5)
app.py CHANGED

@@ -1,3 +1,4 @@
+import base64
 import functools
 import io
 import json
@@ -6,11 +7,11 @@ import math
 import os
 import pathlib
 import random
+import typing

 import beartype
 import einops.layers.torch
 import gradio as gr
-import matplotlib
 import numpy as np
 import open_clip
 import requests
@@ -36,7 +37,7 @@ DEBUG = False
 n_sae_latents = 5
 """Number of SAE latents to show."""

-…
+n_latent_examples = 4
 """Number of SAE examples per latent to show."""

 device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
@@ -52,14 +53,44 @@ max_frequency = 1e-1
 """Maximum frequency. Any feature that fires more than this is ignored."""

 CWD = pathlib.Path(__file__).parent
+"""Current working directory."""
+

 r2_url = "https://pub-289086e849214430853bc87bd8964988.r2.dev/"

-colormap = matplotlib.colormaps.get_cmap("plasma")

 logger.info("Set global constants.")


+@beartype.beartype
+class Example(typing.TypedDict):
+    """Represents an example image and its associated label.
+
+    Used to store examples of SAE latent activations for visualization.
+    """
+
+    orig_url: str
+    """The URL or path to access the original example image."""
+    highlighted_url: typing.NotRequired[str]
+    """The URL or path to access the SAE-highlighted image."""
+    target: int
+    """Class ID."""
+
+
+@beartype.beartype
+class SaeLatent(typing.TypedDict):
+    """Represents a single SAE latent."""
+
+    latent: int
+    """The index of the SAE latent being measured."""
+
+    highlighted_url: str
+    """The image with the colormaps applied."""
+
+    examples: list[Example]
+    """Top examples for this latent."""
+
+
 ###########
 # Helpers #
 ###########
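For reference, these TypedDicts describe the JSON that the reworked endpoints return; a single SaeLatent entry would look roughly like the sketch below (values are illustrative, not taken from the app):

    {
        "latent": 1234,
        "examples": [
            {
                "orig_url": "data:image/webp;base64,...",
                "highlighted_url": "data:image/webp;base64,...",
                "target": 17,
            },
        ],
    }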
@@ -102,26 +133,31 @@ def get_dataset_img(i: int) -> Image.Image:


 @beartype.beartype
-def …
-…
+def to_sized(img: Image.Image) -> Image.Image:
+    # Copied from contrib/classification/transforms.py:for_webapp()
+    w, h = img.size
+    if w > h:
+        resize_w = int(w * 512 / h)
+        resize_px = (resize_w, 512)
+
+        margin_x = (resize_w - 448) // 2
+        crop_px = (margin_x, 32, 448 + margin_x, 480)
+    else:
+        resize_h = int(h * 512 / w)
+        resize_px = (512, resize_h)
+        margin_y = (resize_h - 448) // 2
+        crop_px = (32, margin_y, 480, 448 + margin_y)
+
+    return img.resize(resize_px, resample=Image.Resampling.BICUBIC).crop(crop_px)
+
+
+@beartype.beartype
+def img_to_base64(img: Image.Image) -> str:
+    buf = io.BytesIO()
+    img.save(buf, format="webp", lossless=True)
+    b64 = base64.b64encode(buf.getvalue())
+    s64 = b64.decode("utf8")
+    return "data:image/webp;base64," + s64


 ##########
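Because img_to_base64 ships images as base64 WebP data URLs inside JSON, a client has to decode them itself. A minimal sketch of the inverse step (not part of this commit, assumes Pillow with WebP support):

    import base64
    import io

    from PIL import Image

    def data_url_to_img(url: str) -> Image.Image:
        # Drop the "data:image/webp;base64," prefix that img_to_base64 prepends.
        b64 = url.split(",", 1)[1]
        return Image.open(io.BytesIO(base64.b64decode(b64)))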
@@ -209,7 +245,7 @@ logger.info("Loaded SAE.")
 ############

 human_transform = transforms.Compose([
-    transforms.Resize(…
+    transforms.Resize(512, interpolation=transforms.InterpolationMode.BICUBIC),
     transforms.CenterCrop((448, 448)),
     transforms.ToTensor(),
     einops.layers.torch.Rearrange("channels width height -> width height channels"),
@@ -226,7 +262,7 @@ with open(CWD / "data" / "image_fpaths.json") as fd:


 with open(CWD / "data" / "image_labels.json") as fd:
-…
+    img_labels = json.load(fd)


 logger.info("Loaded all datasets.")
@@ -256,40 +292,41 @@ mask = mask & (sparsity < max_frequency)


 @beartype.beartype
-def …
-…
-    return …
-        Image.fromarray((…
-…
+def get_img(img_i: int) -> Example:
+    img = get_dataset_img(img_i)
+    img = human_transform(img)
+    return {
+        "orig_url": img_to_base64(Image.fromarray((img * 255).to(torch.uint8).numpy())),
+        "target": img_labels[img_i],
+    }


 @beartype.beartype
-def …
-    indices = [i for i, tgt in enumerate(…
+def get_random_class_img(cls: int) -> Example:
+    indices = [i for i, tgt in enumerate(img_labels) if tgt == cls]
     i = random.choice(indices)

-…
+    img = get_dataset_img(i)
+    img = human_transform(img)
+    return {
+        "orig_url": img_to_base64(Image.fromarray((img * 255).to(torch.uint8).numpy())),
+        "target": cls,
+    }


 @torch.inference_mode
-def get_sae_examples(
-    image_i: int, patches: list[int]
-) -> list[None | Image.Image | int]:
+def get_sae_latents(img_i: int, patches: list[int]) -> list[SaeLatent]:
     """
     Given a particular cell, returns some highlighted images showing what feature fires most on this cell.
     """
     if not patches:
-        return […
+        return []

     logger.info("Getting SAE examples for patches %s.", patches)

-    img = get_dataset_img(…
-…
-    x_BPD = split_vit.forward_start(…
+    img = get_dataset_img(img_i)
+    x_BCWH = vit_transform(img)[None, ...].to(device)
+    x_BPD = split_vit.forward_start(x_BCWH)
     # Need to add 1 to account for [CLS] token.
     vit_acts_MD = x_BPD[0, [p + 1 for p in patches]].to(device)
@@ -299,15 +336,19 @@ def get_sae_examples(
     latents = torch.argsort(f_x_S, descending=True).cpu()
     latents = latents[mask[latents]][:n_sae_latents].tolist()

-…
+    sae_latents = []
     for latent in latents:
-        …
+        intermediates, seen_i_im = [], set()
         for i_im, values_p in zip(top_img_i[latent].tolist(), top_values[latent]):
             if i_im in seen_i_im:
                 continue

             example_img = get_dataset_img(i_im)
-            …
+            intermediates.append({
+                "img": example_img,
+                "patches": values_p,
+                "target": img_labels[i_im],
+            })
             seen_i_im.add(i_im)

         # How to scale values.
@@ -315,17 +356,24 @@ def get_sae_examples(
         if top_values[latent].numel() > 0:
             upper = top_values[latent].max().item()

-…
+        examples = []
+        for intermediate in intermediates[:n_latent_examples]:
+            img_sized = to_sized(intermediate["img"])
+            examples.append({
+                "orig_url": img_to_base64(img_sized),
+                "highlighted_url": img_to_base64(
+                    add_highlights(
+                        img_sized,
+                        intermediate["patches"].to(float).numpy(),
+                        upper=upper,
+                    )
+                ),
+                "target": intermediate["target"],
+            })

-…
+        sae_latents.append({"latent": latent, "examples": examples})

-    return …
+    return sae_latents


 @torch.inference_mode
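Since the reworked function returns plain dicts rather than Gradio image components, it can also be scripted outside the UI. A hedged usage sketch, assuming the module-level model and dataset globals are already loaded:

    latents = get_sae_latents(img_i=1234, patches=[0, 1, 14, 15])
    for lat in latents:
        print(lat["latent"], len(lat["examples"]))
        for ex in lat["examples"]:
            print("  class", ex["target"], ex["orig_url"][:40])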
@@ -434,7 +482,8 @@ def add_highlights(
     overlay = Image.new("RGBA", img.size, (0, 0, 0, 0))
     draw = ImageDraw.Draw(overlay)

-    colors = (…
+    colors = np.zeros((len(patches), 3), dtype=np.uint8)
+    colors[:, 0] = ((patches / (upper + 1e-9)) * 255).astype(np.uint8)

     for p, (val, color) in enumerate(zip(patches, colors)):
         assert upper is not None
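The new colors array replaces the old matplotlib colormap lookup with a plain red ramp: each patch's activation is normalized by upper and written into the red channel. A standalone sketch of the same mapping with hypothetical values:

    import numpy as np

    patches = np.array([0.0, 0.5, 2.0])  # hypothetical SAE activations per patch
    upper = 2.0                          # max activation used for scaling

    colors = np.zeros((len(patches), 3), dtype=np.uint8)
    colors[:, 0] = ((patches / (upper + 1e-9)) * 255).astype(np.uint8)
    # -> red values [0, 63, 254]: stronger activations give a more saturated overlay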
@@ -458,43 +507,45 @@ def add_highlights(


 with gr.Blocks() as demo:
-…
+    ###########
+    # get-img #
+    ###########
+
+    # Inputs
+    number = gr.Number(label="Number", precision=0)
+
+    # Outputs
+    json_out = gr.JSON(label="get_img_out", value={})
+
+    get_img_btn = gr.Button(value="Get Input Image")
+    get_img_btn.click(
+        get_img,
+        inputs=[number],
+        outputs=[json_out],
+        api_name="get-img",
     )
+
+    ########################
+    # get-random-class-img #
+    ########################
+
     get_random_class_image_btn = gr.Button(value="Get Random Class Image")
-…
-        inputs=[…
-        outputs=[…
-        api_name="get-random-class-…
+    get_img_btn.click(
+        get_random_class_img,
+        inputs=[number],
+        outputs=[json_out],
+        api_name="get-random-class-img",
     )

     patch_numbers = gr.CheckboxGroup(
         label="Image Patch", choices=list(range(n_patches_per_img))
     )
-…
-        gr.Image(label=f"Latent #{j}, Example #{i + 1}")
-        for i in range(n_sae_examples)
-        for j in range(n_sae_latents)
-    ]
-    get_sae_examples_btn = gr.Button(value="Get SAE Examples")
-    get_sae_examples_btn.click(
-        get_sae_examples,
-        inputs=[image_number, patch_numbers],
-        outputs=sae_example_images + top_latent_numbers,
-        api_name="get-sae-examples",
+    get_sae_latents_btn = gr.Button(value="Get SAE Examples")
+    get_sae_latents_btn.click(
+        get_sae_latents,
+        inputs=[number, patch_numbers],
+        outputs=json_out,
+        api_name="get-sae-latents",
         concurrency_limit=16,
     )
@@ -502,7 +553,7 @@ with gr.Blocks() as demo:
     get_pred_dist_btn = gr.Button(value="Get Pred. Distribution")
     get_pred_dist_btn.click(
         get_pred_dist,
-        inputs=[…
+        inputs=[number],
         outputs=[pred_dist],
         api_name="get-preds",
     )
@@ -514,7 +565,7 @@ with gr.Blocks() as demo:
     get_modified_dist_btn = gr.Button(value="Get Modified Label")
     get_modified_dist_btn.click(
         get_modified_dist,
-        inputs=[…
+        inputs=[number, patch_numbers] + latent_numbers + value_sliders,
         outputs=[pred_dist],
         api_name="get-modified",
         concurrency_limit=16,
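With the Blocks UI now exposing JSON endpoints (get-img, get-random-class-img, get-sae-latents), they can be called from the gradio_client package. A hedged sketch, with the Space id as a placeholder:

    from gradio_client import Client

    client = Client("user/space-name")  # placeholder Space id
    example = client.predict(1234, api_name="/get-img")
    latents = client.predict(1234, [0, 1, 14], api_name="/get-sae-latents")
    print(example["target"], len(latents))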
requirements.txt CHANGED

@@ -47,7 +47,7 @@ contourpy==1.3.1
     # via matplotlib
 cycler==0.12.1
     # via matplotlib
-datasets==3.3.…
+datasets==3.3.1
     # via saev
 dill==0.3.8
     # via
@@ -131,7 +131,7 @@ jsonschema-specifications==2024.10.1
     # via jsonschema
 kiwisolver==1.4.8
     # via matplotlib
-marimo==0.11.…
+marimo==0.11.6
     # via saev
 markdown==3.7
     # via
@@ -155,7 +155,7 @@ multidict==6.1.0
     #   yarl
 multiprocess==0.70.16
     # via datasets
-narwhals==1.…
+narwhals==1.27.1
     # via
     #   altair
     #   marimo
@@ -297,7 +297,7 @@ ruff==0.9.6
     # via
     #   gradio
     #   marimo
-saev @ git+https://github.com/samuelstevens/saev@…
+saev @ git+https://github.com/samuelstevens/saev@298cabdb6b771c76b402d0fdddab6907d1941d7a
     # via saev-image-classification (pyproject.toml)
 safehttpx==0.1.6
     # via gradio
@@ -352,7 +352,7 @@ tqdm==4.67.1
     #   saev
 triton==3.2.0
     # via torch
-typeguard==4.4.…
+typeguard==4.4.2
     # via tyro
 typer==0.15.1
     # via gradio