import io |
import json |
import logging |
import math |
import os |
import pathlib |
import random |
import beartype |
import einops.layers.torch |
import gradio as gr |
import numpy as np |
import open_clip |
import requests |
import saev.nn |
import torch |
from jaxtyping import Float, jaxtyped |
from PIL import Image, ImageDraw |
from torch import Tensor |
from torchvision.transforms import v2 |
log_format = "[%(asctime)s] [%(levelname)s] [%(name)s] %(message)s"
logging.basicConfig(format=log_format, level=logging.INFO)
logger = logging.getLogger("app.py")

DEBUG = True
"""Debug switch for the app."""

n_sae_latents = 3
"""How many SAE latents the UI shows per query."""

n_sae_examples = 4
"""How many example images the UI shows for each SAE latent."""

device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
"""Torch device: CUDA when available, otherwise CPU."""

vit_ckpt = "ViT-B-16/openai"
"""open_clip checkpoint: either "ARCH/PRETRAINED" or an "hf-hub:..." model id."""

n_patches_per_img: int = 196
"""Patch tokens per image for vit_ckpt (the class token is not counted)."""

max_frequency = 1e-2
"""Firing-frequency cutoff: latents whose frequency is not below this are hidden."""

# Directory holding this file; ckpts/ and data/ are resolved relative to it.
CWD = pathlib.Path(__file__).parent
# Public bucket serving the dataset images (joined with image_fpaths entries).
r2_url = "https://pub-289086e849214430853bc87bd8964988.r2.dev/"
logger.info("Set global constants.")
@beartype.beartype
def get_cache_dir() -> str:
    """
    Pick the directory used for downloaded model weights.

    HF_HOME is checked first, then HF_HUB_CACHE; the first one set to a
    non-empty value wins. Falls back to the current directory (".").

    Returns:
        A cache directory path; it is not created here and may not exist yet.
    """
    candidates = (os.environ.get(name, "") for name in ("HF_HOME", "HF_HUB_CACHE"))
    return next((value for value in candidates if value), ".")
@beartype.beartype
def load_model(fpath: str | pathlib.Path, *, device: str = "cpu") -> torch.nn.Module:
    """
    Load a `torch.nn.Linear` checkpoint stored as a two-part file.

    The first line of the file is a JSON object of `torch.nn.Linear` constructor
    keyword arguments; the remaining bytes are a state dict readable by
    `torch.load(..., weights_only=True)`.

    Args:
        fpath: Path to the checkpoint file.
        device: Device to map the weights onto and to place the layer on.

    Returns:
        The linear layer with its weights loaded, on `device`.
    """
    with open(fpath, "rb") as fd:
        layer_kwargs = json.loads(fd.readline().decode())
        payload = fd.read()
    layer = torch.nn.Linear(**layer_kwargs)
    weights = torch.load(io.BytesIO(payload), weights_only=True, map_location=device)
    layer.load_state_dict(weights)
    return layer.to(device)
@beartype.beartype
def get_dataset_img(i: int, *, timeout: float = 30.0) -> Image.Image:
    """
    Download dataset image `i` from the R2 bucket.

    Args:
        i: Index into `image_fpaths`; the entry is appended to `r2_url`.
        timeout: Seconds to wait on the HTTP request before giving up.

    Returns:
        The decoded PIL image.

    Raises:
        requests.HTTPError: If the server answers with an error status.
        requests.Timeout: If the server does not respond within `timeout`.
    """
    url = r2_url + image_fpaths[i]
    # Read the full (Content-Encoding-decoded) body and close the connection,
    # rather than handing PIL a raw, still-open stream that is never released.
    with requests.get(url, timeout=timeout) as response:
        response.raise_for_status()
        payload = response.content
    return Image.open(io.BytesIO(payload))
@beartype.beartype
def make_img(
    img: Image.Image, patches: Float[Tensor, " n_patches"], *, upper: float | None = None
) -> Image.Image:
    """
    Resize and center-crop an image, then overlay per-patch highlights on it.

    Args:
        img: Source image.
        patches: One value per ViT patch (a 1-D tensor); forwarded to
            `add_highlights`, which expects a square number of entries.
        upper: Value drawn at full highlight intensity (forwarded to
            `add_highlights`).

    Returns:
        The 448x448 center crop of the 512x512-resized image with highlights.
    """
    resize_size_px = (512, 512)
    resize_w_px, resize_h_px = resize_size_px
    crop_size_px = (448, 448)
    crop_w_px, crop_h_px = crop_size_px
    # (left, upper, right, lower) box of a centered crop.
    crop_coords_px = (
        (resize_w_px - crop_w_px) // 2,
        (resize_h_px - crop_h_px) // 2,
        (resize_w_px + crop_w_px) // 2,
        (resize_h_px + crop_h_px) // 2,
    )
    img = img.resize(resize_size_px).crop(crop_coords_px)
    # .numpy() only works on CPU tensors without autograd history.
    values = patches.detach().cpu().numpy()
    return add_highlights(img, values, upper=upper, opacity=0.5)
@jaxtyped(typechecker=beartype.beartype)
class SplitClip(torch.nn.Module):
    """
    The open_clip vision tower for `vit_ckpt`, split into two stages.

    `forward_start` runs the patch embedding and every residual block except
    the last `n_end_layers`; `forward_end` runs those final blocks, the final
    norm and pooling. Token activations between the stages can be edited.
    """

    def __init__(self, *, n_end_layers: int):
        """
        Args:
            n_end_layers: Number of trailing residual blocks run by
                `forward_end`; must be in [0, number of blocks].

        Raises:
            ValueError: If `n_end_layers` is outside that range.
        """
        super().__init__()
        if vit_ckpt.startswith("hf-hub:"):
            clip, _ = open_clip.create_model_from_pretrained(
                vit_ckpt, cache_dir=get_cache_dir()
            )
        else:
            arch, ckpt = vit_ckpt.split("/")
            clip, _ = open_clip.create_model_from_pretrained(
                arch, pretrained=ckpt, cache_dir=get_cache_dir()
            )
        model = clip.visual
        # With proj=None, forward_end returns the un-projected pooled embedding.
        model.proj = None
        model.output_tokens = True
        self.vit = model.eval()
        assert not isinstance(self.vit, open_clip.timm_model.TimmModel)

        n_blocks = len(self.vit.transformer.resblocks)
        if not 0 <= n_end_layers <= n_blocks:
            raise ValueError(
                f"n_end_layers={n_end_layers} must be in [0, {n_blocks}]."
            )
        self.n_end_layers = n_end_layers

    def _split_index(self) -> int:
        """Index of the first residual block that belongs to `forward_end`."""
        # Computed explicitly: slicing with [:-n] / [-n:] misbehaves for n == 0.
        return len(self.vit.transformer.resblocks) - self.n_end_layers

    @staticmethod
    def _expand_token(token, batch_size: int):
        """Broadcast a single token vector to shape (batch_size, 1, dim)."""
        return token.view(1, 1, -1).expand(batch_size, -1, -1)

    def forward_start(self, x: Float[Tensor, "batch channels width height"]):
        """
        Embed an image batch and run the first stage of residual blocks.

        Returns:
            Token activations with the class token at position 0 followed by
            the patch tokens.
        """
        x = self.vit.conv1(x)
        x = x.reshape(x.shape[0], x.shape[1], -1)
        x = x.permute(0, 2, 1)
        # Prepend the class token to the patch tokens.
        x = torch.cat(
            [self._expand_token(self.vit.class_embedding, x.shape[0]).to(x.dtype), x],
            dim=1,
        )
        x = x + self.vit.positional_embedding.to(x.dtype)
        x = self.vit.patch_dropout(x)
        x = self.vit.ln_pre(x)
        for block in self.vit.transformer.resblocks[: self._split_index()]:
            x = block(x)
        return x

    def forward_end(self, x: Float[Tensor, "batch n_patches dim"]):
        """
        Run the remaining residual blocks, the final norm and global pooling.

        Returns:
            The pooled embedding.
        """
        for block in self.vit.transformer.resblocks[self._split_index() :]:
            x = block(x)
        x = self.vit.ln_post(x)
        pooled, _ = self.vit._global_pool(x)
        if self.vit.proj is not None:
            pooled = pooled @ self.vit.proj
        return pooled
# Vision backbone split so the last residual block runs separately.
split_vit = SplitClip(n_end_layers=1).to(device)
logger.info("Initialized CLIP ViT.")

clf_ckpt_fpath = CWD / "ckpts" / "clf.pt"
clf = load_model(clf_ckpt_fpath).to(device).eval()
logger.info("Loaded linear classifier.")

sae_ckpt_fpath = CWD / "ckpts" / "sae.pt"
sae = saev.nn.load(sae_ckpt_fpath.as_posix())
# Return value deliberately unused; assumes sae is a torch.nn.Module, whose
# .to()/.eval() act in place — TODO confirm for saev.nn.load.
sae.to(device).eval()
logger.info("Loaded SAE.")
# Display-side preprocessing: the same 512 resize / 448 center-crop geometry as
# make_img, returned channels-last so it can be turned into a PIL image.
human_transform = v2.Compose([
    v2.Resize((512, 512), interpolation=v2.InterpolationMode.NEAREST),
    v2.CenterCrop((448, 448)),
    v2.ToImage(),
    # ToImage yields channels-first; this moves channels last. The axis names
    # are only labels, so the pattern is a plain (C, A, B) -> (A, B, C) permute.
    einops.layers.torch.Rearrange("channels width height -> width height channels"),
])

# Model-side preprocessing. vit_ckpt is parsed exactly as SplitClip does so that
# "hf-hub:" checkpoints work here too.
# NOTE(review): this instantiates a second copy of the CLIP model just to obtain
# its transform; consider reusing the one SplitClip already loads.
if vit_ckpt.startswith("hf-hub:"):
    _, vit_transform = open_clip.create_model_from_pretrained(
        vit_ckpt, cache_dir=get_cache_dir()
    )
else:
    arch, ckpt = vit_ckpt.split("/")
    _, vit_transform = open_clip.create_model_from_pretrained(
        arch, pretrained=ckpt, cache_dir=get_cache_dir()
    )
# Dataset metadata: relative image paths (appended to r2_url) and the label of
# each image; both lists are indexed by the same dataset image index.
with open(CWD / "data" / "image_fpaths.json", encoding="utf-8") as fd:
    image_fpaths = json.load(fd)
with open(CWD / "data" / "image_labels.json", encoding="utf-8") as fd:
    image_labels = json.load(fd)
if len(image_fpaths) != len(image_labels):
    # Labels would silently be attached to the wrong images.
    logger.warning(
        "image_fpaths.json has %d entries but image_labels.json has %d.",
        len(image_fpaths),
        len(image_labels),
    )
logger.info("Loaded all datasets.")
@beartype.beartype
def load_tensor(path: str | pathlib.Path) -> Tensor:
    """Load a saved tensor onto the CPU, using weights-only unpickling."""
    tensor = torch.load(path, map_location="cpu", weights_only=True)
    return tensor
# Per-latent statistics precomputed offline (indexed by SAE latent id).
top_img_i = load_tensor(CWD / "data" / "top_img_i.pt")
top_values = load_tensor(CWD / "data" / "top_values.pt")
sparsity = load_tensor(CWD / "data" / "sparsity.pt")
# Boolean keep-mask over SAE latents: only latents whose sparsity value is
# below max_frequency are offered in the UI.
mask = torch.ones(sae.cfg.d_sae, dtype=torch.bool) & (sparsity < max_frequency)
@beartype.beartype
def get_image(image_i: int) -> list[Image.Image | int]:
    """
    Fetch dataset image `image_i` for display, along with its label.

    Returns:
        `[image, label]`: the display-transformed image and
        `image_labels[image_i]`.
    """
    pixels = human_transform(get_dataset_img(image_i))
    return [Image.fromarray(pixels.numpy()), image_labels[image_i]]
@beartype.beartype
def get_random_class_image(cls: int) -> Image.Image:
    """
    Pick a random dataset image labelled `cls` and return it for display.

    Args:
        cls: Class label to sample from.

    Returns:
        The display-transformed image.

    Raises:
        ValueError: If no dataset image has label `cls`.
    """
    candidates = [i for i, label in enumerate(image_labels) if label == cls]
    if not candidates:
        # random.choice([]) would raise an unhelpful IndexError.
        raise ValueError(f"No dataset images have class {cls}.")
    pixels = human_transform(get_dataset_img(random.choice(candidates)))
    return Image.fromarray(pixels.numpy())
@torch.inference_mode
def get_sae_examples(
    image_i: int, patches: list[int]
) -> list[None | Image.Image | int]:
    """
    Find the SAE latents firing most on the selected patches and show examples.

    Runs the image through `split_vit.forward_start` and encodes the selected
    patch tokens with the SAE. Latent activations are summed over those
    patches. The top `n_sae_latents` latents allowed by `mask` are kept, and for
    each one up to `n_sae_examples` distinct top-activating dataset images are
    highlighted.

    Args:
        image_i: Index of the dataset image.
        patches: Selected patch indices (0-based; the class token is excluded).

    Returns:
        `n_sae_latents * n_sae_examples` images in latent-major order (all
        examples of the first latent, then the second, ...; missing slots are
        None), followed by `n_sae_latents` latent ids (-1 for missing slots).
    """
    n_images = n_sae_latents * n_sae_examples
    if not patches:
        return [None] * n_images + [-1] * n_sae_latents

    img = get_dataset_img(image_i)
    x = vit_transform(img)[None, ...].to(device)
    x_BPD = split_vit.forward_start(x)
    # Token 0 is the class token prepended by forward_start, so patch p is token
    # p + 1; drop it first (as get_modified_dist does) before indexing patches.
    vit_acts_MD = x_BPD[0, 1:][patches]
    _, f_x_MS, _ = sae(vit_acts_MD)
    f_x_S = f_x_MS.sum(dim=0)

    ranked = torch.argsort(f_x_S, descending=True).cpu()
    latents = ranked[mask[ranked]][:n_sae_latents].tolist()

    images = []
    for latent in latents:
        images.extend(_get_latent_examples(latent))
    # Pad so the output count always matches the UI components, even when fewer
    # than n_sae_latents latents survive the mask.
    images += [None] * (n_images - len(images))
    return images + latents + [-1] * (n_sae_latents - len(latents))


def _get_latent_examples(latent: int) -> list[Image.Image | None]:
    """Return exactly `n_sae_examples` highlighted top images for `latent`, padded with None."""
    # Assumes top_img_i[latent] lists dataset image indices and top_values[latent]
    # holds the matching per-patch activation rows — TODO confirm against the
    # script that produced data/top_*.pt.
    pairs, seen = [], set()
    for i_im, values_p in zip(top_img_i[latent].tolist(), top_values[latent]):
        if len(pairs) >= n_sae_examples:
            break  # Avoid downloading images that would never be shown.
        if i_im in seen:
            continue
        seen.add(i_im)
        pairs.append((get_dataset_img(i_im), values_p))

    # Shared normalization so highlights are comparable across this latent's examples.
    upper = None
    if top_values[latent].numel() > 0:
        upper = top_values[latent].max().item()
    examples = [make_img(example, values_p, upper=upper) for example, values_p in pairs]
    return examples + [None] * (n_sae_examples - len(examples))
@torch.inference_mode
def get_pred_dist(i: int) -> dict[int, float]:
    """
    Classify dataset image `i` with the unmodified model.

    Returns:
        Mapping from class index to softmax probability.
    """
    pixels = vit_transform(get_dataset_img(i)).unsqueeze(0).to(device)
    tokens_BPD = split_vit.forward_start(pixels)
    embedding_BD = split_vit.forward_end(tokens_BPD)
    logits_C = clf(embedding_BD)[0]
    return dict(enumerate(torch.nn.functional.softmax(logits_C, dim=0).cpu().tolist()))
@torch.inference_mode
def get_modified_dist(
    image_i: int,
    patches: list[int],
    latent1: int,
    latent2: int,
    latent3: int,
    value1: float,
    value2: float,
    value3: float,
) -> dict[int, float]:
    """
    Classify dataset image `image_i` after overriding SAE latents on some patches.

    The patch tokens from `split_vit.forward_start` are encoded with the SAE. On
    every selected patch, each chosen latent's activation is overwritten with
    `unscaled(value, max of top_values[latent])`. The codes are decoded as
    `f @ W_dec + b_dec` and the SAE reconstruction error is added back. The
    class token is re-attached and the result goes through `forward_end` and
    `clf`.

    Args:
        image_i: Index of the dataset image.
        patches: Patch indices to edit (0-based; the class token is excluded).
        latent1, latent2, latent3: SAE latent ids to override. None or a
            negative id (the -1 placeholder emitted by get_sae_examples) means
            "no override" for that slot.
        value1, value2, value3: Slider values for the corresponding latents.

    Returns:
        Mapping from class index to softmax probability.
    """
    img = get_dataset_img(image_i)
    x = vit_transform(img)[None, ...].to(device)
    x_BPD = split_vit.forward_start(x)
    # Keep the class token aside; only patch tokens go through the SAE.
    cls_B1D, x_BPD = x_BPD[:, :1, :], x_BPD[:, 1:, :]
    x_hat_BPD, f_x_BPS, _ = sae(x_BPD)
    err_BPD = x_BPD - x_hat_BPD

    edits = [
        (int(latent), value)
        for latent, value in ((latent1, value1), (latent2, value2), (latent3, value3))
        if latent is not None and latent >= 0
    ]
    if patches and edits:
        values = torch.tensor(
            [
                unscaled(float(value), top_values[latent].max().item())
                for latent, value in edits
            ],
            dtype=f_x_BPS.dtype,
            device=device,
        )
        # Explicit long dtype: index tensors must be integral.
        patch_idx = torch.tensor(patches, dtype=torch.long, device=device)
        latent_idx = torch.tensor(
            [latent for latent, _ in edits], dtype=torch.long, device=device
        )
        # Broadcast to a (n_patches, n_latents) grid of assignments.
        f_x_BPS[:, patch_idx[:, None], latent_idx[None, :]] = values

    modified_x_hat_BPD = (
        einops.einsum(
            f_x_BPS,
            sae.W_dec,
            "batch patches d_sae, d_sae d_vit -> batch patches d_vit",
        )
        + sae.b_dec
    )
    modified_BPD = torch.cat([cls_B1D, err_BPD + modified_x_hat_BPD], dim=1)
    modified_BD = split_vit.forward_end(modified_BPD)
    logits_BC = clf(modified_BD)
    probs = torch.nn.functional.softmax(logits_BC[0], dim=0).cpu().tolist()
    return dict(enumerate(probs))
@beartype.beartype
def unscaled(x: float, max_obs: float) -> float:
    """
    Map a value in [-20, 20] linearly onto [-20 * max_obs, 20 * max_obs].

    Mathematically this is `x * max_obs`; going through `map_range` also rejects
    values of `x` outside [-20, 20] with a ValueError.
    """
    lo, hi = -20.0, 20.0
    return map_range(x, (lo, hi), (lo * max_obs, hi * max_obs))
@beartype.beartype
def map_range(
    x: float | int,
    domain: tuple[float | int, float | int],
    range: tuple[float | int, float | int],  # shadows builtin; kept for keyword callers
) -> float:
    """
    Linearly map `x` from the interval `domain` onto the interval `range`.

    `domain[0]` maps to `range[0]` and `domain[1]` maps to `range[1]`.

    Args:
        x: Value to map; must lie in the closed interval `domain`.
        domain: `(a, b)` source interval with `a < b`.
        range: `(c, d)` target interval (may be reversed or degenerate).

    Returns:
        `c + (x - a) * (d - c) / (b - a)`.

    Raises:
        ValueError: If `domain` is empty/degenerate or `x` lies outside it.
    """
    a, b = domain
    c, d = range
    if not a < b:
        # A degenerate domain would otherwise divide by zero below.
        raise ValueError(f"domain={domain} must satisfy domain[0] < domain[1].")
    if not (a <= x <= b):
        raise ValueError(f"x={x:.3f} must be in {[a, b]}.")
    return c + (x - a) * (d - c) / (b - a)
@jaxtyped(typechecker=beartype.beartype)
def add_highlights(
    img: Image.Image,
    patches: Float[np.ndarray, " n_patches"],
    *,
    upper: float | None = None,
    opacity: float = 0.9,
) -> Image.Image:
    """
    Overlay a red heat map of per-patch values on an image.

    The patches form a square grid laid out row-major over the image. Each
    value is divided by `upper`, clamped to [0, 1], and drawn as a red
    rectangle whose alpha is that fraction times `opacity`.

    Args:
        img: Image to annotate.
        patches: One value per patch; the count must be a perfect square.
        upper: Value drawn at full intensity. Defaults to the largest patch value.
        opacity: Maximum overlay alpha, as a fraction of 255.

    Returns:
        `img` unchanged if `patches` is empty, otherwise an RGBA composite.

    Raises:
        ValueError: If the number of patches is not a perfect square.
    """
    if not len(patches):
        return img

    n_side = math.isqrt(len(patches))
    if n_side * n_side != len(patches):
        raise ValueError(f"Expected a square number of patches, got {len(patches)}.")
    if upper is None:
        upper = float(patches.max())

    iw_px, ih_px = img.size
    pw_px, ph_px = iw_px // n_side, ih_px // n_side

    overlay = Image.new("RGBA", img.size, (0, 0, 0, 0))
    draw = ImageDraw.Draw(overlay)
    for p, val in enumerate(patches):
        # Normalize, then clamp so every channel stays within 0..255.
        frac = min(max(float(val) / (upper + 1e-9), 0.0), 1.0)
        x_np, y_np = p % n_side, p // n_side  # column, row
        draw.rectangle(
            [
                (x_np * pw_px, y_np * ph_px),
                (x_np * pw_px + pw_px, y_np * ph_px + ph_px),
            ],
            fill=(int(frac * 255), 0, 0, int(opacity * frac * 255)),
        )
    return Image.alpha_composite(img.convert("RGBA"), overlay)
with gr.Blocks() as demo:
    image_number = gr.Number(label="Test Example", precision=0)
    class_number = gr.Number(label="Test Class", precision=0)
    input_image = gr.Image(label="Input Image")

    get_input_image_btn = gr.Button(value="Get Input Image")
    get_input_image_btn.click(
        get_image,
        inputs=[image_number],
        outputs=[input_image, class_number],
        api_name="get-image",
    )

    # The random-class button drives its own handler and reads the class number.
    get_random_class_image_btn = gr.Button(value="Get Random Class Image")
    get_random_class_image_btn.click(
        get_random_class_image,
        inputs=[class_number],
        outputs=[input_image],
        api_name="get-random-class-image",
    )

    patch_numbers = gr.CheckboxGroup(
        label="Image Patch", choices=list(range(n_patches_per_img))
    )

    top_latent_numbers = [
        gr.Number(label=f"Top Latents #{j + 1}", precision=0)
        for j in range(n_sae_latents)
    ]
    # get_sae_examples returns images latent-major (all examples of latent 1,
    # then latent 2, ...), so the components are created in that same order.
    # Latent labels are 1-based to match the "Top Latents #" labels above.
    sae_example_images = [
        gr.Image(label=f"Latent #{j + 1}, Example #{i + 1}")
        for j in range(n_sae_latents)
        for i in range(n_sae_examples)
    ]
    get_sae_examples_btn = gr.Button(value="Get SAE Examples")
    get_sae_examples_btn.click(
        get_sae_examples,
        inputs=[image_number, patch_numbers],
        outputs=sae_example_images + top_latent_numbers,
        api_name="get-sae-examples",
    )

    pred_dist = gr.Label(label="Pred. Dist.")
    get_pred_dist_btn = gr.Button(value="Get Pred. Distribution")
    get_pred_dist_btn.click(
        get_pred_dist,
        inputs=[image_number],
        outputs=[pred_dist],
        api_name="get-preds",
    )

    # get_modified_dist takes exactly three (latent, value) pairs.
    latent_numbers = [gr.Number(label=f"Latent {i + 1}", precision=0) for i in range(3)]
    value_sliders = [
        gr.Slider(label=f"Value {i + 1}", minimum=-10, maximum=10) for i in range(3)
    ]
    get_modified_dist_btn = gr.Button(value="Get Modified Label")
    get_modified_dist_btn.click(
        get_modified_dist,
        inputs=[image_number, patch_numbers] + latent_numbers + value_sliders,
        outputs=[pred_dist],
        api_name="get-modified",
    )

if __name__ == "__main__":
    demo.launch()