Samuel Stevens committed
Commit c0b4385 · 1 Parent(s): 29d1b06

Updates based on lab feedback

Files changed (2):
  1. app.py +140 -89
  2. requirements.txt +5 -5
app.py CHANGED
@@ -1,3 +1,4 @@
+import base64
 import functools
 import io
 import json
@@ -6,11 +7,11 @@ import math
 import os
 import pathlib
 import random
+import typing

 import beartype
 import einops.layers.torch
 import gradio as gr
-import matplotlib
 import numpy as np
 import open_clip
 import requests
@@ -36,7 +37,7 @@ DEBUG = False
 n_sae_latents = 5
 """Number of SAE latents to show."""

-n_sae_examples = 4
+n_latent_examples = 4
 """Number of SAE examples per latent to show."""

 device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
@@ -52,14 +53,44 @@ max_frequency = 1e-1
 """Maximum frequency. Any feature that fires more than this is ignored."""

 CWD = pathlib.Path(__file__).parent
+"""Current working directory."""
+

 r2_url = "https://pub-289086e849214430853bc87bd8964988.r2.dev/"

-colormap = matplotlib.colormaps.get_cmap("plasma")

 logger.info("Set global constants.")


+@beartype.beartype
+class Example(typing.TypedDict):
+    """Represents an example image and its associated label.
+
+    Used to store examples of SAE latent activations for visualization.
+    """
+
+    orig_url: str
+    """The URL or path to access the original example image."""
+    highlighted_url: typing.NotRequired[str]
+    """The URL or path to access the SAE-highlighted image."""
+    target: int
+    """Class ID."""
+
+
+@beartype.beartype
+class SaeLatent(typing.TypedDict):
+    """Represents a single SAE latent."""
+
+    latent: int
+    """The index of the SAE latent being measured."""
+
+    highlighted_url: str
+    """The image with the colormaps applied."""
+
+    examples: list[Example]
+    """Top examples for this latent."""
+
+
 ###########
 # Helpers #
 ###########
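The two TypedDicts above define the JSON payloads the Gradio endpoints now return; `typing.NotRequired` (Python 3.11+) lets `highlighted_url` be omitted from an `Example`. A minimal sketch of a valid payload (the URL value is illustrative and truncated):

```python
import json

# "highlighted_url" is NotRequired, so a payload without it still type-checks.
example = {
    "orig_url": "data:image/webp;base64,UklGR...",  # illustrative, truncated
    "target": 3,
}
print(json.dumps(example))  # serializes directly for a gr.JSON output
```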
@@ -102,26 +133,31 @@ def get_dataset_img(i: int) -> Image.Image:


 @beartype.beartype
-def make_img(
-    img: Image.Image,
-    patches: Float[Tensor, " n_patches"],
-    *,
-    upper: int | None = None,
-) -> Image.Image:
-    # Resize to 256x256 and crop to 224x224
-    resize_size_px = (512, 512)
-    resize_w_px, resize_h_px = resize_size_px
-    crop_size_px = (448, 448)
-    crop_w_px, crop_h_px = crop_size_px
-    crop_coords_px = (
-        (resize_w_px - crop_w_px) // 2,
-        (resize_h_px - crop_h_px) // 2,
-        (resize_w_px + crop_w_px) // 2,
-        (resize_h_px + crop_h_px) // 2,
-    )
-    img = img.resize(resize_size_px).crop(crop_coords_px)
-    img = add_highlights(img, patches.numpy(), upper=upper, opacity=0.5)
-    return img
+def to_sized(img: Image.Image) -> Image.Image:
+    # Copied from contrib/classification/transforms.py:for_webapp()
+    w, h = img.size
+    if w > h:
+        resize_w = int(w * 512 / h)
+        resize_px = (resize_w, 512)
+
+        margin_x = (resize_w - 448) // 2
+        crop_px = (margin_x, 32, 448 + margin_x, 480)
+    else:
+        resize_h = int(h * 512 / w)
+        resize_px = (512, resize_h)
+        margin_y = (resize_h - 448) // 2
+        crop_px = (32, margin_y, 480, 448 + margin_y)
+
+    return img.resize(resize_px, resample=Image.Resampling.BICUBIC).crop(crop_px)
+
+
+@beartype.beartype
+def img_to_base64(img: Image.Image) -> str:
+    buf = io.BytesIO()
+    img.save(buf, format="webp", lossless=True)
+    b64 = base64.b64encode(buf.getvalue())
+    s64 = b64.decode("utf8")
+    return "data:image/webp;base64," + s64


 ##########
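`to_sized` replaces the old square `make_img` resize: the short edge is scaled to 512 px and a 448×448 center crop is taken, so aspect ratio is preserved instead of squashed. A quick check with hypothetical dimensions, assuming the two helpers above are in scope:

```python
from PIL import Image

# Landscape 640x480: short edge 480 -> 512, width -> int(640 * 512 / 480) = 682,
# then crop (117, 32, 565, 480) leaves a 448x448 center square.
img = Image.new("RGB", (640, 480), "gray")
sized = to_sized(img)
assert sized.size == (448, 448)

url = img_to_base64(sized)
assert url.startswith("data:image/webp;base64,")
```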
@@ -209,7 +245,7 @@ logger.info("Loaded SAE.")
 ############

 human_transform = transforms.Compose([
-    transforms.Resize((448,), interpolation=transforms.InterpolationMode.BICUBIC),
+    transforms.Resize(512, interpolation=transforms.InterpolationMode.BICUBIC),
     transforms.CenterCrop((448, 448)),
     transforms.ToTensor(),
     einops.layers.torch.Rearrange("channels width height -> width height channels"),
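With an integer size, `transforms.Resize(512, ...)` scales the short edge and keeps aspect ratio, matching `to_sized` above; the old `(448,)` form left no margin for the 448×448 crop. A shape check, assuming torchvision is available:

```python
from PIL import Image
from torchvision import transforms

human_transform_sketch = transforms.Compose([
    transforms.Resize(512, interpolation=transforms.InterpolationMode.BICUBIC),
    transforms.CenterCrop((448, 448)),
    transforms.ToTensor(),
])
x = human_transform_sketch(Image.new("RGB", (640, 480)))
assert tuple(x.shape) == (3, 448, 448)  # channels, height, width
```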
@@ -226,7 +262,7 @@ with open(CWD / "data" / "image_fpaths.json") as fd:


 with open(CWD / "data" / "image_labels.json") as fd:
-    image_labels = json.load(fd)
+    img_labels = json.load(fd)


 logger.info("Loaded all datasets.")
@@ -256,40 +292,41 @@ mask = mask & (sparsity < max_frequency)


 @beartype.beartype
-def get_image(image_i: int) -> list[Image.Image | int]:
-    image = get_dataset_img(image_i)
-    image = human_transform(image)
-    return [
-        Image.fromarray((image * 255).to(torch.uint8).numpy()),
-        image_labels[image_i],
-    ]
+def get_img(img_i: int) -> Example:
+    img = get_dataset_img(img_i)
+    img = human_transform(img)
+    return {
+        "orig_url": img_to_base64(Image.fromarray((img * 255).to(torch.uint8).numpy())),
+        "target": img_labels[img_i],
+    }


 @beartype.beartype
-def get_random_class_image(cls: int) -> Image.Image:
-    indices = [i for i, tgt in enumerate(image_labels) if tgt == cls]
+def get_random_class_img(cls: int) -> Example:
+    indices = [i for i, tgt in enumerate(img_labels) if tgt == cls]
     i = random.choice(indices)

-    image = get_dataset_img(i)
-    image = human_transform(image)
-    return Image.fromarray((image * 255).to(torch.uint8).numpy())
+    img = get_dataset_img(i)
+    img = human_transform(img)
+    return {
+        "orig_url": img_to_base64(Image.fromarray((img * 255).to(torch.uint8).numpy())),
+        "target": cls,
+    }


 @torch.inference_mode
-def get_sae_examples(
-    image_i: int, patches: list[int]
-) -> list[None | Image.Image | int]:
+def get_sae_latents(img_i: int, patches: list[int]) -> list[SaeLatent]:
     """
     Given a particular cell, returns some highlighted images showing what feature fires most on this cell.
     """
     if not patches:
-        return [None] * n_sae_latents * n_sae_examples + [-1] * n_sae_latents
+        return []

     logger.info("Getting SAE examples for patches %s.", patches)

-    img = get_dataset_img(image_i)
-    x = vit_transform(img)[None, ...].to(device)
-    x_BPD = split_vit.forward_start(x)
+    img = get_dataset_img(img_i)
+    x_BCWH = vit_transform(img)[None, ...].to(device)
+    x_BPD = split_vit.forward_start(x_BCWH)
     # Need to add 1 to account for [CLS] token.
     vit_acts_MD = x_BPD[0, [p + 1 for p in patches]].to(device)
@@ -299,15 +336,19 @@ def get_sae_examples(
     latents = torch.argsort(f_x_S, descending=True).cpu()
     latents = latents[mask[latents]][:n_sae_latents].tolist()

-    images = []
+    sae_latents = []
     for latent in latents:
-        img_patch_pairs, seen_i_im = [], set()
+        intermediates, seen_i_im = [], set()
         for i_im, values_p in zip(top_img_i[latent].tolist(), top_values[latent]):
             if i_im in seen_i_im:
                 continue

             example_img = get_dataset_img(i_im)
-            img_patch_pairs.append((example_img, values_p))
+            intermediates.append({
+                "img": example_img,
+                "patches": values_p,
+                "target": img_labels[i_im],
+            })
             seen_i_im.add(i_im)

         # How to scale values.
@@ -315,17 +356,24 @@
         if top_values[latent].numel() > 0:
             upper = top_values[latent].max().item()

-        latent_images = [
-            make_img(img, patches.to(float), upper=upper)
-            for img, patches in img_patch_pairs[:n_sae_examples]
-        ]
-
-        while len(latent_images) < n_sae_examples:
-            latent_images += [None]
+        examples = []
+        for intermediate in intermediates[:n_latent_examples]:
+            img_sized = to_sized(intermediate["img"])
+            examples.append({
+                "orig_url": img_to_base64(img_sized),
+                "highlighted_url": img_to_base64(
+                    add_highlights(
+                        img_sized,
+                        intermediate["patches"].to(float).numpy(),
+                        upper=upper,
+                    )
+                ),
+                "target": intermediate["target"],
+            })

-        images.extend(latent_images)
+        sae_latents.append({"latent": latent, "examples": examples})

-    return images + latents
+    return sae_latents


 @torch.inference_mode
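`get_sae_latents` returns a variable-length `list[SaeLatent]` rather than a fixed grid of images padded with `None`, so clients iterate over whatever comes back. A consumption sketch with hypothetical inputs:

```python
for sae_latent in get_sae_latents(0, [0, 1]):  # hypothetical image index and patches
    print("latent", sae_latent["latent"])
    for example in sae_latent["examples"]:  # at most n_latent_examples entries
        print("  class", example["target"], example["highlighted_url"][:32])
```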
@@ -434,7 +482,8 @@ def add_highlights(
     overlay = Image.new("RGBA", img.size, (0, 0, 0, 0))
     draw = ImageDraw.Draw(overlay)

-    colors = (colormap(patches / (upper + 1e-9))[:, :3] * 255).astype(np.uint8)
+    colors = np.zeros((len(patches), 3), dtype=np.uint8)
+    colors[:, 0] = ((patches / (upper + 1e-9)) * 255).astype(np.uint8)

     for p, (val, color) in enumerate(zip(patches, colors)):
         assert upper is not None
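The red-channel ramp replaces matplotlib's plasma colormap, which is what lets the commit drop `import matplotlib` entirely. A numpy-only check of the resulting colors, assuming `patches` holds activations and `upper` their maximum:

```python
import numpy as np

patches = np.array([0.0, 0.5, 1.0])
upper = patches.max()
colors = np.zeros((len(patches), 3), dtype=np.uint8)
colors[:, 0] = ((patches / (upper + 1e-9)) * 255).astype(np.uint8)
print(colors.tolist())  # [[0, 0, 0], [127, 0, 0], [254, 0, 0]]
```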
@@ -458,43 +507,45 @@


 with gr.Blocks() as demo:
-    image_number = gr.Number(label="Test Example", precision=0)
-    class_number = gr.Number(label="Test Class", precision=0)
-    input_image = gr.Image(label="Input Image")
-    get_input_image_btn = gr.Button(value="Get Input Image")
-    get_input_image_btn.click(
-        get_image,
-        inputs=[image_number],
-        outputs=[input_image, class_number],
-        api_name="get-image",
+    ###########
+    # get-img #
+    ###########
+
+    # Inputs
+    number = gr.Number(label="Number", precision=0)
+
+    # Outputs
+    json_out = gr.JSON(label="get_img_out", value={})
+
+    get_img_btn = gr.Button(value="Get Input Image")
+    get_img_btn.click(
+        get_img,
+        inputs=[number],
+        outputs=[json_out],
+        api_name="get-img",
     )
+
+    ########################
+    # get-random-class-img #
+    ########################
+
     get_random_class_image_btn = gr.Button(value="Get Random Class Image")
-    get_input_image_btn.click(
-        get_random_class_image,
-        inputs=[image_number],
-        outputs=[input_image],
-        api_name="get-random-class-image",
+    get_img_btn.click(
+        get_random_class_img,
+        inputs=[number],
+        outputs=[json_out],
+        api_name="get-random-class-img",
     )

     patch_numbers = gr.CheckboxGroup(
         label="Image Patch", choices=list(range(n_patches_per_img))
     )
-    top_latent_numbers = gr.CheckboxGroup(label="Top Latents")
-    top_latent_numbers = [
-        gr.Number(label=f"Top Latents #{j + 1}", precision=0)
-        for j in range(n_sae_latents)
-    ]
-    sae_example_images = [
-        gr.Image(label=f"Latent #{j}, Example #{i + 1}")
-        for i in range(n_sae_examples)
-        for j in range(n_sae_latents)
-    ]
-    get_sae_examples_btn = gr.Button(value="Get SAE Examples")
-    get_sae_examples_btn.click(
-        get_sae_examples,
-        inputs=[image_number, patch_numbers],
-        outputs=sae_example_images + top_latent_numbers,
-        api_name="get-sae-examples",
+    get_sae_latents_btn = gr.Button(value="Get SAE Examples")
+    get_sae_latents_btn.click(
+        get_sae_latents,
+        inputs=[number, patch_numbers],
+        outputs=json_out,
+        api_name="get-sae-latents",
         concurrency_limit=16,
     )
@@ -502,7 +553,7 @@ with gr.Blocks() as demo:
     get_pred_dist_btn = gr.Button(value="Get Pred. Distribution")
     get_pred_dist_btn.click(
         get_pred_dist,
-        inputs=[image_number],
+        inputs=[number],
         outputs=[pred_dist],
         api_name="get-preds",
     )
@@ -514,7 +565,7 @@ with gr.Blocks() as demo:
     get_modified_dist_btn = gr.Button(value="Get Modified Label")
     get_modified_dist_btn.click(
         get_modified_dist,
-        inputs=[image_number, patch_numbers] + latent_numbers + value_sliders,
+        inputs=[number, patch_numbers] + latent_numbers + value_sliders,
         outputs=[pred_dist],
         api_name="get-modified",
         concurrency_limit=16,
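Since every endpoint now writes to the single `json_out` component, the renamed routes can be smoke-tested over HTTP. A hypothetical `gradio_client` call against a local deployment:

```python
from gradio_client import Client

client = Client("http://localhost:7860")  # hypothetical local URL
img = client.predict(0, api_name="/get-img")
latents = client.predict(0, [0, 1, 2], api_name="/get-sae-latents")
print(img["target"], [s["latent"] for s in latents])
```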
 
requirements.txt CHANGED
@@ -47,7 +47,7 @@ contourpy==1.3.1
     # via matplotlib
 cycler==0.12.1
     # via matplotlib
-datasets==3.3.0
+datasets==3.3.1
     # via saev
 dill==0.3.8
     # via
@@ -131,7 +131,7 @@ jsonschema-specifications==2024.10.1
     # via jsonschema
 kiwisolver==1.4.8
     # via matplotlib
-marimo==0.11.5
+marimo==0.11.6
     # via saev
 markdown==3.7
     # via
@@ -155,7 +155,7 @@ multidict==6.1.0
     #   yarl
 multiprocess==0.70.16
     # via datasets
-narwhals==1.26.0
+narwhals==1.27.1
     # via
     #   altair
     #   marimo
@@ -297,7 +297,7 @@ ruff==0.9.6
     # via
     #   gradio
     #   marimo
-saev @ git+https://github.com/samuelstevens/saev@928cb62084e88118e792ff6fc8cc043ec250f0ff
+saev @ git+https://github.com/samuelstevens/saev@298cabdb6b771c76b402d0fdddab6907d1941d7a
     # via saev-image-classification (pyproject.toml)
 safehttpx==0.1.6
     # via gradio
@@ -352,7 +352,7 @@ tqdm==4.67.1
     #   saev
 triton==3.2.0
     # via torch
-typeguard==4.4.1
+typeguard==4.4.2
     # via tyro
 typer==0.15.1
     # via gradio
47
  # via matplotlib
48
  cycler==0.12.1
49
  # via matplotlib
50
+ datasets==3.3.1
51
  # via saev
52
  dill==0.3.8
53
  # via
 
131
  # via jsonschema
132
  kiwisolver==1.4.8
133
  # via matplotlib
134
+ marimo==0.11.6
135
  # via saev
136
  markdown==3.7
137
  # via
 
155
  # yarl
156
  multiprocess==0.70.16
157
  # via datasets
158
+ narwhals==1.27.1
159
  # via
160
  # altair
161
  # marimo
 
297
  # via
298
  # gradio
299
  # marimo
300
+ saev @ git+https://github.com/samuelstevens/saev@298cabdb6b771c76b402d0fdddab6907d1941d7a
301
  # via saev-image-classification (pyproject.toml)
302
  safehttpx==0.1.6
303
  # via gradio
 
352
  # saev
353
  triton==3.2.0
354
  # via torch
355
+ typeguard==4.4.2
356
  # via tyro
357
  typer==0.15.1
358
  # via gradio