Samuel Stevens committed
Commit dc20bdb · 0 Parent(s)

initial commit

Files changed (6)
  1. .python-version +1 -0
  2. README.md +0 -0
  3. app.py +512 -0
  4. data.py +0 -0
  5. justfile +9 -0
  6. pyproject.toml +20 -0
.python-version ADDED
@@ -0,0 +1 @@
+ 3.12
README.md ADDED
File without changes
app.py ADDED
@@ -0,0 +1,512 @@
+ import functools
+ import logging
+ import os.path
+ import typing
+
+ import beartype
+ import einops
+ import einops.layers.torch
+ import gradio as gr
+ import torch
+ from jaxtyping import Float, Int, UInt8, jaxtyped
+ from PIL import Image
+ from torch import Tensor
+
+ import saev.activations
+ import saev.config
+ import saev.imaging
+ import saev.nn
+ import saev.visuals
+
+ from .. import training
+ from . import data
+
+ ####################
+ # Global Constants #
+ ####################
+
+
+ DEBUG = False
+ """Whether we are debugging."""
+
+ logger = logging.getLogger(__name__)
+ """Module-level logger."""
+
+ max_frequency = 1e-2
+ """Maximum firing frequency. Any latent that fires more often than this is ignored."""
+
+ ckpt = "oebd6e6i"
+ """Which SAE checkpoint to use."""
+
+ n_sae_latents = 3
+ """Number of SAE latents to show."""
+
+ n_sae_examples = 4
+ """Number of SAE examples per latent to show."""
+
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+ """Hardware accelerator, if any."""
+
+ RESIZE_SIZE = 512
+ """Resize the shorter side to this size in pixels."""
+
+ CROP_SIZE = (448, 448)
+ """Crop size in pixels."""
+
+ DEVICE = device
+ """Hardware accelerator, if any (alias of `device`)."""
+
+ ####################
+ # Helper Functions #
+ ####################
+
+
+ @beartype.beartype
+ def load_tensor(path: str) -> Tensor:
+     return torch.load(path, weights_only=True, map_location="cpu")
+
+
+ ##########
+ # Models #
+ ##########
+
+
+ @functools.cache
+ def load_vit(
+     model_cfg: modeling.Config,
+ ) -> tuple[
+     saev.activations.WrappedVisionTransformer,
+     typing.Callable,
+     float,
+     Float[Tensor, " d_vit"],
+ ]:
+     vit = (
+         saev.activations.WrappedVisionTransformer(model_cfg.wrapped_cfg)
+         .to(DEVICE)
+         .eval()
+     )
+     vit_transform = saev.activations.make_img_transform(
+         model_cfg.vit_family, model_cfg.vit_ckpt
+     )
+     logger.info("Loaded ViT: %s.", model_cfg.key)
+
+     try:
+         # Normalizing constants
+         acts_dataset = saev.activations.Dataset(model_cfg.acts_cfg)
+         logger.info("Loaded dataset norms: %s.", model_cfg.key)
+     except RuntimeError as err:
+         logger.warning("Error loading dataset norms: %s", err)
+         return None, None, None, None
+
+     return vit, vit_transform, acts_dataset.scalar.item(), acts_dataset.act_mean
+
+
+ sae_ckpt_fpath = f"/home/stevens.994/projects/saev/checkpoints/{ckpt}/sae.pt"
+ sae = saev.nn.load(sae_ckpt_fpath)
+ sae.to(device).eval()
+
+
+ head_ckpt_fpath = "/home/stevens.994/projects/saev/checkpoints/contrib/semseg/lr_0_001__wd_0_001/model_step8000.pt"
+ head = training.load(head_ckpt_fpath)
+ head = head.to(device).eval()
+
+
+ class RestOfDinoV2(torch.nn.Module):
+     def __init__(self, *, n_end_layers: int):
+         super().__init__()
+         self.vit = torch.hub.load("facebookresearch/dinov2", "dinov2_vitb14_reg")
+         self.n_end_layers = n_end_layers
+
+     def forward_start(self, x: Float[Tensor, "batch channels width height"]):
+         x_BPD = self.vit.prepare_tokens_with_masks(x)
+         for blk in self.vit.blocks[: -self.n_end_layers]:
+             x_BPD = blk(x_BPD)
+
+         return x_BPD
+
+     def forward_end(self, x_BPD: Float[Tensor, "batch n_patches dim"]):
+         for blk in self.vit.blocks[-self.n_end_layers :]:
+             x_BPD = blk(x_BPD)
+
+         x_BPD = self.vit.norm(x_BPD)
+         return x_BPD[:, self.vit.num_register_tokens + 1 :]
+
+
+ rest_of_vit = RestOfDinoV2(n_end_layers=1)
+ rest_of_vit = rest_of_vit.to(device)
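+
+ # A minimal sketch of how the two halves compose (hypothetical input; with a
+ # 224x224 image, DINOv2's 14-pixel patches give a 16x16 grid of patch tokens):
+ #
+ #     x = torch.rand(1, 3, 224, 224, device=device)
+ #     x_BPD = rest_of_vit.forward_start(x)    # every block except the last
+ #     x_BPD = rest_of_vit.forward_end(x_BPD)  # last block + norm, patch tokens only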
+
+
+ ####################
+ # Global Variables #
+ ####################
+
+
+ ckpt_data_root = (
+     f"/research/nfs_su_809/workspace/stevens.994/saev/features/{ckpt}/sort_by_patch"
+ )
+
+ top_img_i = load_tensor(os.path.join(ckpt_data_root, "top_img_i.pt"))
+ top_values = load_tensor(os.path.join(ckpt_data_root, "top_values.pt"))
+ sparsity = load_tensor(os.path.join(ckpt_data_root, "sparsity.pt"))
+
+
+ mask = torch.ones((sae.cfg.d_sae), dtype=bool)
+ mask = mask & (sparsity < max_frequency)
+
+
+ ############
+ # Datasets #
+ ############
+
+
+ # in1k_dataset = saev.activations.get_dataset(
+ #     saev.config.ImagenetDataset(),
+ #     img_transform=v2.Compose([
+ #         v2.Resize(size=(512, 512)),
+ #         v2.CenterCrop(size=(448, 448)),
+ #     ]),
+ # )
+
+
+ # acts_dataset = saev.activations.Dataset(
+ #     saev.config.DataLoad(
+ #         shard_root="/local/scratch/stevens.994/cache/saev/a1f842330bb568b2fb05c15d4fa4252fb7f5204837335000d9fd420f120cd03e",
+ #         scale_mean=not DEBUG,
+ #         scale_norm=not DEBUG,
+ #         layer=-2,
+ #     )
+ # )
+
+
+ # vit_dataset = saev.activations.Ade20k(
+ #     saev.config.Ade20kDataset(
+ #         root="/research/nfs_su_809/workspace/stevens.994/datasets/ade20k/"
+ #     ),
+ #     img_transform=v2.Compose([
+ #         v2.Resize(size=(256, 256)),
+ #         v2.CenterCrop(size=(224, 224)),
+ #         v2.ToImage(),
+ #         v2.ToDtype(torch.float32, scale=True),
+ #         v2.Normalize(mean=[0.4850, 0.4560, 0.4060], std=[0.2290, 0.2240, 0.2250]),
+ #     ]),
+ # )
+
+
+ #######################
+ # Inference Functions #
+ #######################
+
+
+ @beartype.beartype
+ class Example(typing.TypedDict):
+     """Represents an example image and its associated label.
+
+     Used to store examples of SAE latent activations for visualization.
+     """
+
+     orig_url: str
+     """The URL or path to access the original example image."""
+     highlighted_url: str
+     """The URL or path to access the SAE-highlighted image."""
+     index: int
+     """Dataset index."""
+
+
+ @beartype.beartype
+ class SaeActivation(typing.TypedDict):
+     """Represents the activation pattern of a single SAE latent across patches.
+
+     This captures how strongly a particular SAE latent fires on different patches of an input image.
+     """
+
+     latent: int
+     """The index of the SAE latent being measured."""
+
+     highlighted_url: str
+     """The image with the colormaps applied."""
+
+     activations: list[float]
+     """The activation values of this latent across different patches. Each value represents how strongly this latent fired on a particular patch."""
+
+     examples: list[Example]
+     """Top examples for this latent."""
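+
+     # A hypothetical payload, just to illustrate the shape this type describes
+     # (all values are made up):
+     #
+     #     SaeActivation(
+     #         latent=1024,
+     #         highlighted_url="data:image/png;base64,...",
+     #         activations=[0.0, 0.3, 2.1],  # one value per patch
+     #         examples=[Example(orig_url="...", highlighted_url="...", index=17)],
+     #     )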
+
+
+ @beartype.beartype
+ def get_image(image_i: int) -> tuple[str, str, int]:
+     img_sized, labels_sized = data.get_sample(image_i)
+
+     return data.pil_to_base64(img_sized), data.pil_to_base64(labels_sized), image_i
+
+
+ @beartype.beartype
+ @torch.inference_mode
+ def get_sae_activations(image_i: int, patches: list[int]) -> list[SaeActivation]:
+     """
+     Given an image and a set of its patches, return the SAE latents that fire most strongly on those patches, along with highlighted example images for each latent.
+     """
+     if not patches:
+         return []
+
+     vit, vit_transform, scalar, mean = load_vit(model_cfg)
+     if vit is None:
+         logger.warning("Skipping ViT '%s'", model_name)
+         return []
+     sae = load_sae(model_cfg)
+
+     mean = mean.to(DEVICE)
+     x = vit_transform(img_p)[None, ...].to(DEVICE)
+
+     _, vit_acts_BLPD = vit(x)
+     vit_acts_PD = (vit_acts_BLPD[0, 0, 1:].to(DEVICE).clamp(-1e-5, 1e5) - mean) / scalar
+
+     _, f_x_PS, _ = sae(vit_acts_PD)
+     # The [CLS] token was already dropped above; transpose to (latents, patches).
+     acts_SP = einops.rearrange(f_x_PS, "patches n_latents -> n_latents patches")
+     logger.info("Got SAE activations for '%s'.", model_name)
+     top_img_i, top_values = load_tensors(model_cfg)
+     logger.info("Loaded top SAE activations for '%s'.", model_name)
+
+     vit_acts_MD = torch.stack([
+         acts_dataset[image_i * acts_dataset.metadata.n_patches_per_img + i]["act"]
+         for i in patches
+     ]).to(device)
+
+     _, f_x_MS, _ = sae(vit_acts_MD)
+     f_x_S = f_x_MS.sum(axis=0)
+
+     latents = torch.argsort(f_x_S, descending=True).cpu()
+     latents = latents[mask[latents]][:n_sae_latents].tolist()
+
+     images = []
+     for latent in latents:
+         elems, seen_i_im = [], set()
+         for i_im, values_p in zip(top_img_i[latent].tolist(), top_values[latent]):
+             if i_im in seen_i_im:
+                 continue
+
+             example = in1k_dataset[i_im]
+             elems.append(
+                 saev.visuals.GridElement(example["image"], example["label"], values_p)
+             )
+             seen_i_im.add(i_im)
+
+         # How to scale values.
+         upper = None
+         if top_values[latent].numel() > 0:
+             upper = top_values[latent].max().item()
+
+         latent_images = [make_img(elem, upper=upper) for elem in elems[:n_sae_examples]]
+
+         while len(latent_images) < n_sae_examples:
+             latent_images += [None]
+
+         images.extend(latent_images)
+
+     return images + latents
+
+
+ @torch.inference_mode
+ def get_true_labels(image_i: int) -> Image.Image:
+     seg = human_dataset[image_i]["segmentation"]
+     image = seg_to_img(seg)
+     return image
+
+
+ @torch.inference_mode
+ def get_pred_labels(i: int) -> list[Image.Image | list[int]]:
+     sample = vit_dataset[i]
+     x = sample["image"][None, ...].to(device)
+     x_BPD = rest_of_vit.forward_start(x)
+     x_BPD = rest_of_vit.forward_end(x_BPD)
+
+     x_WHD = einops.rearrange(x_BPD, "() (w h) dim -> w h dim", w=16, h=16)
+
+     logits_WHC = head(x_WHD)
+
+     pred_WH = logits_WHC.argmax(axis=-1)
+     preds = einops.rearrange(pred_WH, "w h -> (w h)").tolist()
+     return [seg_to_img(upsample(pred_WH)), preds]
+
+
+ @beartype.beartype
+ def unscaled(x: float, max_obs: float) -> float:
+     """Scale from [-10, 10] to [-10 * max_obs, 10 * max_obs]."""
+     return map_range(x, (-10.0, 10.0), (-10.0 * max_obs, 10.0 * max_obs))
+
+
+ @beartype.beartype
+ def map_range(
+     x: float,
+     domain: tuple[float | int, float | int],
+     range: tuple[float | int, float | int],
+ ):
+     a, b = domain
+     c, d = range
+     if not (a <= x <= b):
+         raise ValueError(f"x={x:.3f} must be in {[a, b]}.")
+     return c + (x - a) * (d - c) / (b - a)
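+
+
+ # To make the scaling concrete: unscaled(x, max_obs) reduces to x * max_obs, so a
+ # slider value of 5.0 with max_obs=2.0 sets a latent to 10.0, and
+ # map_range(5.0, (-10.0, 10.0), (0.0, 1.0)) returns 0.75.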
+
+
+ @torch.inference_mode
+ def get_modified_labels(
+     i: int,
+     latent1: int,
+     latent2: int,
+     latent3: int,
+     value1: float,
+     value2: float,
+     value3: float,
+ ) -> list[Image.Image | list[int]]:
+     sample = vit_dataset[i]
+     x = sample["image"][None, ...].to(device)
+     x_BPD = rest_of_vit.forward_start(x)
+
+     x_hat_BPD, f_x_BPS, _ = sae(x_BPD)
+
+     err_BPD = x_BPD - x_hat_BPD
+
+     values = torch.tensor(
+         [
+             unscaled(float(value), top_values[latent].max().item())
+             for value, latent in [
+                 (value1, latent1),
+                 (value2, latent2),
+                 (value3, latent3),
+             ]
+         ],
+         device=device,
+     )
+     f_x_BPS[..., torch.tensor([latent1, latent2, latent3], device=device)] = values
+
+     # Reproduce the SAE forward pass after f_x.
+     modified_x_hat_BPD = (
+         einops.einsum(
+             f_x_BPS,
+             sae.W_dec,
+             "batch patches d_sae, d_sae d_vit -> batch patches d_vit",
+         )
+         + sae.b_dec
+     )
+     modified_BPD = err_BPD + modified_x_hat_BPD
+
+     modified_BPD = rest_of_vit.forward_end(modified_BPD)
+
+     logits_BPC = head(modified_BPD)
+     pred_P = logits_BPC[0].argmax(axis=-1)
+     pred_WH = einops.rearrange(pred_P, "(w h) -> w h", w=16, h=16)
+     return seg_to_img(upsample(pred_WH)), pred_P.tolist()
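+
+
+ # Why err_BPD is added back in get_modified_labels: since err = x - (W_dec f(x) + b_dec),
+ # adding it to the modified reconstruction W_dec f'(x) + b_dec preserves everything the
+ # SAE fails to capture, so only the three edited latents change the activations that
+ # flow into the rest of the ViT.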
+
+
+ @jaxtyped(typechecker=beartype.beartype)
+ @torch.inference_mode
+ def upsample(
+     x_WH: Int[Tensor, "width_ps height_ps"],
+ ) -> UInt8[Tensor, "width_px height_px"]:
+     return (
+         torch.nn.functional.interpolate(
+             x_WH.view((1, 1, 16, 16)).float(),
+             scale_factor=28,
+         )
+         .view((448, 448))
+         .type(torch.uint8)
+     )
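+
+
+ # upsample repeats each of the 16x16 patch predictions over a 28x28 pixel block
+ # (16 * 28 = 448), i.e. nearest-neighbor upsampling of the patch-level class map
+ # to pixel resolution.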
+
+
+ @beartype.beartype
+ def make_img(
+     elem: saev.visuals.GridElement, *, upper: float | None = None
+ ) -> Image.Image:
+     # Resize to 512x512 and center-crop to 448x448.
+     resize_size_px = (512, 512)
+     resize_w_px, resize_h_px = resize_size_px
+     crop_size_px = (448, 448)
+     crop_w_px, crop_h_px = crop_size_px
+     crop_coords_px = (
+         (resize_w_px - crop_w_px) // 2,
+         (resize_h_px - crop_h_px) // 2,
+         (resize_w_px + crop_w_px) // 2,
+         (resize_h_px + crop_h_px) // 2,
+     )
+
+     img = elem.img.resize(resize_size_px).crop(crop_coords_px)
+     img = saev.imaging.add_highlights(
+         img, elem.patches.numpy(), upper=upper, opacity=0.5
+     )
+     return img
+
+
+ with gr.Blocks() as demo:
+     image_number = gr.Number(label="Validation Example")
+
+     input_image_base64 = gr.Text(label="Image in Base64")
+     true_labels_base64 = gr.Text(label="Labels in Base64")
+
+     get_input_image_btn = gr.Button(value="Get Input Image")
+     get_input_image_btn.click(
+         get_image,
+         inputs=[image_number],
+         outputs=[input_image_base64, true_labels_base64, image_number],
+         api_name="get-image",
+     )
+
+     # input_image = gr.Image(
+     #     label="Input Image",
+     #     sources=["upload", "clipboard"],
+     #     type="pil",
+     #     interactive=True,
+     # )
+     # patch_numbers = gr.CheckboxGroup(label="Image Patch", choices=list(range(256)))
+     # top_latent_numbers = gr.CheckboxGroup(label="Top Latents")
+     # top_latent_numbers = [
+     #     gr.Number(label=f"Top Latents #{j + 1}") for j in range(n_sae_latents)
+     # ]
+     # sae_example_images = [
+     #     gr.Image(label=f"Latent #{j}, Example #{i + 1}", format="png")
+     #     for i in range(n_sae_examples)
+     #     for j in range(n_sae_latents)
+     # ]
+
+     patches_json = gr.JSON(label="Patches", value=[])
+     activations_json = gr.JSON(label="Activations", value=[])
+
+     get_sae_activations_btn = gr.Button(value="Get SAE Activations")
+     get_sae_activations_btn.click(
+         get_sae_activations,
+         inputs=[image_number, patches_json],
+         outputs=[activations_json],
+         api_name="get-sae-examples",
+     )
+     # semseg_image = gr.Image(label="Semantic Segmentations", format="png")
+     # semseg_colors = gr.CheckboxGroup(
+     #     label="Sem Seg Colors", choices=list(range(1, 151))
+     # )
+
+     # get_pred_labels_btn = gr.Button(value="Get Pred. Labels")
+     # get_pred_labels_btn.click(
+     #     get_pred_labels,
+     #     inputs=[image_number],
+     #     outputs=[semseg_image, semseg_colors],
+     #     api_name="get-pred-labels",
+     # )
+
+     # get_true_labels_btn = gr.Button(value="Get True Label")
+     # get_true_labels_btn.click(
+     #     get_true_labels,
+     #     inputs=[image_number],
+     #     outputs=semseg_image,
+     #     api_name="get-true-labels",
+     # )
+
+     # latent_numbers = [gr.Number(label=f"Latent {i + 1}") for i in range(3)]
+     # value_sliders = [
+     #     gr.Slider(label=f"Value {i + 1}", minimum=-10, maximum=10) for i in range(3)
+     # ]
+
+     # get_modified_labels_btn = gr.Button(value="Get Modified Label")
+     # get_modified_labels_btn.click(
+     #     get_modified_labels,
+     #     inputs=[image_number] + latent_numbers + value_sliders,
+     #     outputs=[semseg_image, semseg_colors],
+     #     api_name="get-modified-labels",
+     # )
+
+ if __name__ == "__main__":
+     demo.launch()
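+
+
+ # A sketch of how a client could call the endpoints exposed above through their
+ # api_name values (the Space id here is hypothetical):
+ #
+ #     from gradio_client import Client
+ #
+ #     client = Client("samuelstevens/saev-semantic-segmentation")
+ #     img_b64, labels_b64, i = client.predict(0, api_name="/get-image")
+ #     acts = client.predict(0, [0, 1, 2], api_name="/get-sae-examples")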
data.py ADDED
File without changes
justfile ADDED
@@ -0,0 +1,9 @@
+ build: lint
+     uv pip compile pyproject.toml > requirements.txt
+
+ lint: fmt
+     git ls-files "*.py" --cached --others --exclude-standard | xargs uv run ruff check
+
+ fmt:
+     git ls-files "*.py" --cached --others --exclude-standard | xargs uv run isort
+     git ls-files "*.py" --cached --others --exclude-standard | xargs uv run ruff format --preview
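+
+ # `just build` runs fmt, then lint, then regenerates requirements.txt, because
+ # build depends on lint and lint depends on fmt.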
pyproject.toml ADDED
@@ -0,0 +1,20 @@
+ [project]
+ name = "saev-semantic-segmentation"
+ version = "0.1.0"
+ description = "Gradio app space for semantic segmentation with SAEs"
+ readme = "README.md"
+ requires-python = ">=3.12"
+ dependencies = [
+     "beartype>=0.19.0",
+     "einops>=0.8.0",
+     "gradio>=5.3.0",
+     "numpy>=2.2.2",
+     "saev",
+     "torch>=2.6.0",
+     "torchvision>=0.21.0",
+ ]
+
+ [tool.ruff.lint]
+ ignore = ["F722"]
+
+ [tool.uv.sources]
+ saev = { git = "https://github.com/samuelstevens/saev" }