import functools
import io
import json
import logging
import math
import os
import pathlib
import random

import beartype
import einops.layers.torch
import gradio as gr
import numpy as np
import open_clip
import requests
import saev.nn
import torch
from jaxtyping import Float, jaxtyped
from PIL import Image, ImageDraw
from torch import Tensor
from torchvision.transforms import v2

log_format = "[%(asctime)s] [%(levelname)s] [%(name)s] %(message)s"
logging.basicConfig(level=logging.INFO, format=log_format)
logger = logging.getLogger("app.py")

####################
# Global Constants #
####################

DEBUG = True
"""Whether we are debugging."""

n_sae_latents = 3
"""Number of SAE latents to show."""

n_sae_examples = 4
"""Number of SAE examples per latent to show."""

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
"""Hardware accelerator, if any."""

vit_ckpt = "ViT-B-16/openai"
"""CLIP checkpoint."""

n_patches_per_img: int = 196
"""Number of patches per image in vit_ckpt."""

max_frequency = 1e-1
"""Maximum frequency. Any feature that fires more than this is ignored."""

CWD = pathlib.Path(__file__).parent

r2_url = "https://pub-289086e849214430853bc87bd8964988.r2.dev/"

logger.info("Set global constants.")

###########
# Helpers #
###########


@beartype.beartype
def get_cache_dir() -> str:
    """
    Get cache directory from environment variables, defaulting to the current
    working directory (.).

    Returns:
        A path to a cache directory (might not exist yet).
    """
    cache_dir = ""
    for var in ("HF_HOME", "HF_HUB_CACHE"):
        cache_dir = cache_dir or os.environ.get(var, "")
    return cache_dir or "."


@beartype.beartype
def load_model(fpath: str | pathlib.Path, *, device: str = "cpu") -> torch.nn.Module:
    """
    Loads a linear layer from disk.
    """
    with open(fpath, "rb") as fd:
        kwargs = json.loads(fd.readline().decode())
        buffer = io.BytesIO(fd.read())

    model = torch.nn.Linear(**kwargs)
    state_dict = torch.load(buffer, weights_only=True, map_location=device)
    model.load_state_dict(state_dict)
    model = model.to(device)
    return model


@beartype.beartype
@functools.lru_cache(maxsize=512)
def get_dataset_img(i: int) -> Image.Image:
    return Image.open(requests.get(r2_url + image_fpaths[i], stream=True).raw)


@beartype.beartype
def make_img(
    img: Image.Image,
    patches: Float[Tensor, " n_patches"],
    *,
    upper: float | None = None,
) -> Image.Image:
    # Resize to 512x512 and center-crop to 448x448 (matches human_transform below).
    resize_size_px = (512, 512)
    resize_w_px, resize_h_px = resize_size_px
    crop_size_px = (448, 448)
    crop_w_px, crop_h_px = crop_size_px
    crop_coords_px = (
        (resize_w_px - crop_w_px) // 2,
        (resize_h_px - crop_h_px) // 2,
        (resize_w_px + crop_w_px) // 2,
        (resize_h_px + crop_h_px) // 2,
    )
    img = img.resize(resize_size_px).crop(crop_coords_px)
    img = add_highlights(img, patches.numpy(), upper=upper, opacity=0.5)
    return img


##########
# Models #
##########


@jaxtyped(typechecker=beartype.beartype)
class SplitClip(torch.nn.Module):
    def __init__(self, *, n_end_layers: int):
        super().__init__()

        if vit_ckpt.startswith("hf-hub:"):
            clip, _ = open_clip.create_model_from_pretrained(
                vit_ckpt, cache_dir=get_cache_dir()
            )
        else:
            arch, ckpt = vit_ckpt.split("/")
            clip, _ = open_clip.create_model_from_pretrained(
                arch, pretrained=ckpt, cache_dir=get_cache_dir()
            )

        model = clip.visual
        model.proj = None
        model.output_tokens = True  # type: ignore
        self.vit = model.eval()
        assert not isinstance(self.vit, open_clip.timm_model.TimmModel)

        self.n_end_layers = n_end_layers

    @staticmethod
    def _expand_token(token, batch_size: int):
        return token.view(1, 1, -1).expand(batch_size, -1, -1)
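
    # forward_start embeds the image and runs all but the last `n_end_layers`
    # transformer blocks; forward_end runs the remaining blocks, the final
    # layernorm, and pooling. forward_end(forward_start(x)) therefore matches the
    # full visual forward pass, which lets us edit patch-level activations (e.g.
    # with the SAE) in between the two halves.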
    def forward_start(self, x: Float[Tensor, "batch channels width height"]):
        x = self.vit.conv1(x)  # shape = [*, width, grid, grid]
        x = x.reshape(x.shape[0], x.shape[1], -1)  # shape = [*, width, grid ** 2]
        x = x.permute(0, 2, 1)  # shape = [*, grid ** 2, width]

        # class embeddings and positional embeddings
        x = torch.cat(
            [self._expand_token(self.vit.class_embedding, x.shape[0]).to(x.dtype), x],
            dim=1,
        )  # shape = [*, grid ** 2 + 1, width]
        x = x + self.vit.positional_embedding.to(x.dtype)

        x = self.vit.patch_dropout(x)
        x = self.vit.ln_pre(x)
        for r in self.vit.transformer.resblocks[: -self.n_end_layers]:
            x = r(x)
        return x

    def forward_end(self, x: Float[Tensor, "batch n_patches dim"]):
        for r in self.vit.transformer.resblocks[-self.n_end_layers :]:
            x = r(x)

        x = self.vit.ln_post(x)
        pooled, _ = self.vit._global_pool(x)
        if self.vit.proj is not None:
            pooled = pooled @ self.vit.proj

        return pooled


# ViT
split_vit = SplitClip(n_end_layers=1)
split_vit = split_vit.to(device)
logger.info("Initialized CLIP ViT.")

# Linear classifier
clf_ckpt_fpath = CWD / "ckpts" / "clf.pt"
clf = load_model(clf_ckpt_fpath)
clf = clf.to(device).eval()
logger.info("Loaded linear classifier.")

# SAE
sae_ckpt_fpath = CWD / "ckpts" / "sae.pt"
sae = saev.nn.load(sae_ckpt_fpath.as_posix())
sae.to(device).eval()
logger.info("Loaded SAE.")

############
# Datasets #
############

human_transform = v2.Compose([
    v2.Resize((512, 512), interpolation=v2.InterpolationMode.NEAREST),
    v2.CenterCrop((448, 448)),
    v2.ToImage(),
    einops.layers.torch.Rearrange("channels width height -> width height channels"),
])

arch, ckpt = vit_ckpt.split("/")
_, vit_transform = open_clip.create_model_from_pretrained(
    arch, pretrained=ckpt, cache_dir=get_cache_dir()
)

with open(CWD / "data" / "image_fpaths.json") as fd:
    image_fpaths = json.load(fd)

with open(CWD / "data" / "image_labels.json") as fd:
    image_labels = json.load(fd)

logger.info("Loaded all datasets.")

#############
# Variables #
#############


@beartype.beartype
def load_tensor(path: str | pathlib.Path) -> Tensor:
    return torch.load(path, weights_only=True, map_location="cpu")


top_img_i = load_tensor(CWD / "data" / "top_img_i.pt")
top_values = load_tensor(CWD / "data" / "top_values.pt")
sparsity = load_tensor(CWD / "data" / "sparsity.pt")

mask = torch.ones((sae.cfg.d_sae), dtype=bool)
mask = mask & (sparsity < max_frequency)

#############
# Inference #
#############


@beartype.beartype
def get_image(image_i: int) -> list[Image.Image | int]:
    image = get_dataset_img(image_i)
    image = human_transform(image)
    return [Image.fromarray(image.numpy()), image_labels[image_i]]


@beartype.beartype
def get_random_class_image(cls: int) -> Image.Image:
    indices = [i for i, tgt in enumerate(image_labels) if tgt == cls]
    i = random.choice(indices)
    image = get_dataset_img(i)
    image = human_transform(image)
    return Image.fromarray(image.numpy())


@torch.inference_mode
def get_sae_examples(
    image_i: int, patches: list[int]
) -> list[None | Image.Image | int]:
    """
    Given a particular cell, returns some highlighted images showing what feature
    fires most on this cell.
    """
    if not patches:
        return [None] * 12 + [-1] * 3

    logger.info("Getting SAE examples for patches %s.", patches)

    img = get_dataset_img(image_i)
    x = vit_transform(img)[None, ...].to(device)
    x_BPD = split_vit.forward_start(x)
    # Need to add 1 to account for [CLS] token.
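    # Axis-suffix convention for the tensors below: B=batch, P=tokens ([CLS] plus
    # patches), D=ViT feature dim, M=selected patches, S=SAE latent dim, C=classes.
    # x_BPD[0] drops the batch axis; token 0 is [CLS], so patch p sits at row p + 1.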
    vit_acts_MD = x_BPD[0, [p + 1 for p in patches]].to(device)

    _, f_x_MS, _ = sae(vit_acts_MD)
    f_x_S = f_x_MS.sum(axis=0)

    latents = torch.argsort(f_x_S, descending=True).cpu()
    latents = latents[mask[latents]][:n_sae_latents].tolist()

    images = []
    for latent in latents:
        img_patch_pairs, seen_i_im = [], set()
        for i_im, values_p in zip(top_img_i[latent].tolist(), top_values[latent]):
            if i_im in seen_i_im:
                continue

            example_img = get_dataset_img(i_im)
            img_patch_pairs.append((example_img, values_p))
            seen_i_im.add(i_im)

        # How to scale values.
        upper = None
        if top_values[latent].numel() > 0:
            upper = top_values[latent].max().item()

        latent_images = [
            make_img(img, patches, upper=upper)
            for img, patches in img_patch_pairs[:n_sae_examples]
        ]
        while len(latent_images) < n_sae_examples:
            latent_images += [None]

        images.extend(latent_images)

    return images + latents


@torch.inference_mode
def get_pred_dist(i: int) -> dict[int, float]:
    img = get_dataset_img(i)
    x = vit_transform(img)[None, ...].to(device)
    x_BPD = split_vit.forward_start(x)
    x_BD = split_vit.forward_end(x_BPD)

    logits_BC = clf(x_BD)
    probs = torch.nn.functional.softmax(logits_BC[0], dim=0).cpu().tolist()
    return {i: prob for i, prob in enumerate(probs)}


@torch.inference_mode
def get_modified_dist(
    image_i: int,
    patches: list[int],
    latent1: int,
    latent2: int,
    latent3: int,
    value1: float,
    value2: float,
    value3: float,
) -> dict[int, float]:
    img = get_dataset_img(image_i)
    x = vit_transform(img)[None, ...].to(device)
    x_BPD = split_vit.forward_start(x)

    cls_B1D, x_BPD = x_BPD[:, :1, :], x_BPD[:, 1:, :]

    x_hat_BPD, f_x_BPS, _ = sae(x_BPD)
    err_BPD = x_BPD - x_hat_BPD

    values = torch.tensor(
        [
            unscaled(float(value), top_values[latent].max().item())
            for value, latent in [
                (value1, latent1),
                (value2, latent2),
                (value3, latent3),
            ]
        ],
        device=device,
    )

    patches = torch.tensor(patches, device=device)
    latents = torch.tensor([latent1, latent2, latent3], device=device)
    f_x_BPS[:, patches[:, None], latents[None, :]] = values

    # Reproduce the SAE forward pass after f_x is modified.
    modified_x_hat_BPD = (
        einops.einsum(
            f_x_BPS,
            sae.W_dec,
            "batch patches d_sae, d_sae d_vit -> batch patches d_vit",
        )
        + sae.b_dec
    )

    modified_BPD = torch.cat([cls_B1D, err_BPD + modified_x_hat_BPD], axis=1)

    modified_BD = split_vit.forward_end(modified_BPD)
    logits_BC = clf(modified_BD)
    probs = torch.nn.functional.softmax(logits_BC[0], dim=0).cpu().tolist()
    return {i: prob for i, prob in enumerate(probs)}


@beartype.beartype
def unscaled(x: float, max_obs: float) -> float:
    """Scale from [-20, 20] to [-20 * max_obs, 20 * max_obs]."""
    return map_range(x, (-20.0, 20.0), (-20.0 * max_obs, 20.0 * max_obs))


@beartype.beartype
def map_range(
    x: float,
    domain: tuple[float | int, float | int],
    range: tuple[float | int, float | int],
):
    a, b = domain
    c, d = range
    if not (a <= x <= b):
        raise ValueError(f"x={x:.3f} must be in {[a, b]}.")
    return c + (x - a) * (d - c) / (b - a)
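
# For example, with max_obs = 2.5, unscaled(10.0, 2.5) maps 10.0 from [-20, 20]
# into [-50, 50] and returns 25.0, so slider values scale linearly with the
# largest activation observed for that latent.
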
@jaxtyped(typechecker=beartype.beartype)
def add_highlights(
    img: Image.Image,
    patches: Float[np.ndarray, " n_patches"],
    *,
    upper: float | None = None,
    opacity: float = 0.9,
) -> Image.Image:
    if not len(patches):
        return img

    iw_np, ih_np = int(math.sqrt(len(patches))), int(math.sqrt(len(patches)))
    iw_px, ih_px = img.size
    pw_px, ph_px = iw_px // iw_np, ih_px // ih_np
    assert iw_np * ih_np == len(patches)

    # Create a transparent overlay
    overlay = Image.new("RGBA", img.size, (0, 0, 0, 0))
    draw = ImageDraw.Draw(overlay)

    # Using semi-transparent red (255, 0, 0, alpha)
    for p, val in enumerate(patches):
        assert upper is not None
        val /= upper + 1e-9
        x_np, y_np = p % iw_np, p // ih_np
        draw.rectangle(
            [
                (x_np * pw_px, y_np * ph_px),
                (x_np * pw_px + pw_px, y_np * ph_px + ph_px),
            ],
            fill=(int(val * 256), 0, 0, int(opacity * val * 256)),
        )

    # Composite the original image and the overlay
    return Image.alpha_composite(img.convert("RGBA"), overlay)


#############
# Interface #
#############

with gr.Blocks() as demo:
    image_number = gr.Number(label="Test Example", precision=0)
    class_number = gr.Number(label="Test Class", precision=0)
    input_image = gr.Image(label="Input Image")

    get_input_image_btn = gr.Button(value="Get Input Image")
    get_input_image_btn.click(
        get_image,
        inputs=[image_number],
        outputs=[input_image, class_number],
        api_name="get-image",
    )

    get_random_class_image_btn = gr.Button(value="Get Random Class Image")
    get_random_class_image_btn.click(
        get_random_class_image,
        inputs=[class_number],
        outputs=[input_image],
        api_name="get-random-class-image",
    )

    patch_numbers = gr.CheckboxGroup(
        label="Image Patch", choices=list(range(n_patches_per_img))
    )
    top_latent_numbers = [
        gr.Number(label=f"Top Latents #{j + 1}", precision=0)
        for j in range(n_sae_latents)
    ]
    sae_example_images = [
        gr.Image(label=f"Latent #{j + 1}, Example #{i + 1}")
        for j in range(n_sae_latents)
        for i in range(n_sae_examples)
    ]

    get_sae_examples_btn = gr.Button(value="Get SAE Examples")
    get_sae_examples_btn.click(
        get_sae_examples,
        inputs=[image_number, patch_numbers],
        outputs=sae_example_images + top_latent_numbers,
        api_name="get-sae-examples",
        concurrency_limit=16,
    )

    pred_dist = gr.Label(label="Pred. Dist.")
    get_pred_dist_btn = gr.Button(value="Get Pred. Distribution")
    get_pred_dist_btn.click(
        get_pred_dist,
        inputs=[image_number],
        outputs=[pred_dist],
        api_name="get-preds",
    )

    latent_numbers = [
        gr.Number(label=f"Latent {i + 1}", precision=0) for i in range(3)
    ]
    value_sliders = [
        gr.Slider(label=f"Value {i + 1}", minimum=-10, maximum=10) for i in range(3)
    ]

    get_modified_dist_btn = gr.Button(value="Get Modified Label")
    get_modified_dist_btn.click(
        get_modified_dist,
        inputs=[image_number, patch_numbers] + latent_numbers + value_sliders,
        outputs=[pred_dist],
        api_name="get-modified",
        concurrency_limit=16,
    )


if __name__ == "__main__":
    demo.launch()
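
# Untested sketch: because each .click() above sets an api_name, the endpoints can
# also be called programmatically with the gradio_client package, assuming the app
# is running on the default local port:
#
#   from gradio_client import Client
#
#   client = Client("http://127.0.0.1:7860")
#   img, label = client.predict(0, api_name="/get-image")
#   probs = client.predict(0, api_name="/get-preds")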