import functools
import io
import json
import logging
import math
import os
import pathlib
import random

import beartype
import einops
import einops.layers.torch
import gradio as gr
import numpy as np
import open_clip
import requests
import saev.nn
import torch
from jaxtyping import Float, jaxtyped
from PIL import Image, ImageDraw
from torch import Tensor
from torchvision.transforms import v2

log_format = "[%(asctime)s] [%(levelname)s] [%(name)s] %(message)s"
logging.basicConfig(level=logging.INFO, format=log_format)
logger = logging.getLogger("app.py")


DEBUG = True
"""Whether we are debugging."""

n_sae_latents = 3
"""Number of SAE latents to show."""

n_sae_examples = 4
"""Number of SAE examples per latent to show."""

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
"""Hardware accelerator, if any."""

vit_ckpt = "ViT-B-16/openai"
"""CLIP checkpoint."""

n_patches_per_img: int = 196
"""Number of patches per image in vit_ckpt."""

max_frequency = 1e-1
"""Maximum frequency. Any feature that fires more often than this is ignored."""

CWD = pathlib.Path(__file__).parent

r2_url = "https://pub-289086e849214430853bc87bd8964988.r2.dev/"

logger.info("Set global constants.")


@beartype.beartype
def get_cache_dir() -> str:
    """
    Get the cache directory from environment variables, defaulting to the current working directory (.).

    Returns:
        A path to a cache directory (might not exist yet).
    """
    cache_dir = ""
    for var in ("HF_HOME", "HF_HUB_CACHE"):
        cache_dir = cache_dir or os.environ.get(var, "")
    return cache_dir or "."
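
# For example (illustrative values, not checked against a real environment):
# with HF_HOME unset and HF_HUB_CACHE=/data/hub, get_cache_dir() returns
# "/data/hub"; with both unset, it returns ".".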

@beartype.beartype
def load_model(fpath: str | pathlib.Path, *, device: str = "cpu") -> torch.nn.Module:
    """
    Loads a linear layer from disk.
    """
    with open(fpath, "rb") as fd:
        kwargs = json.loads(fd.readline().decode())
        buffer = io.BytesIO(fd.read())

    model = torch.nn.Linear(**kwargs)
    state_dict = torch.load(buffer, weights_only=True, map_location=device)
    model.load_state_dict(state_dict)
    model = model.to(device)
    return model
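
# load_model expects a file with one line of JSON-encoded torch.nn.Linear
# kwargs, followed by the raw bytes of a torch.save()'d state dict. A minimal
# writer sketch under that assumption (hypothetical helper, not part of this
# app):
#
#     def dump_model(fpath, model: torch.nn.Linear):
#         kwargs = {
#             "in_features": model.in_features,
#             "out_features": model.out_features,
#             "bias": model.bias is not None,
#         }
#         with open(fpath, "wb") as fd:
#             fd.write((json.dumps(kwargs) + "\n").encode())
#             torch.save(model.state_dict(), fd)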

@beartype.beartype
@functools.lru_cache(maxsize=512)
def get_dataset_img(i: int) -> Image.Image:
    # Stream images from R2 on demand; lru_cache keeps the most recent 512.
    return Image.open(requests.get(r2_url + image_fpaths[i], stream=True).raw)


@beartype.beartype
def make_img(
    img: Image.Image,
    patches: Float[Tensor, " n_patches"],
    *,
    upper: int | None = None,
) -> Image.Image:
    # Match human_transform: resize to 512x512, then center-crop to 448x448.
    resize_size_px = (512, 512)
    resize_w_px, resize_h_px = resize_size_px
    crop_size_px = (448, 448)
    crop_w_px, crop_h_px = crop_size_px
    crop_coords_px = (
        (resize_w_px - crop_w_px) // 2,
        (resize_h_px - crop_h_px) // 2,
        (resize_w_px + crop_w_px) // 2,
        (resize_h_px + crop_h_px) // 2,
    )

    img = img.resize(resize_size_px).crop(crop_coords_px)
    img = add_highlights(img, patches.numpy(), upper=upper, opacity=0.5)
    return img
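
# Worked example of make_img's crop arithmetic: with a 512x512 resize and a
# 448x448 crop, crop_coords_px = (32, 32, 480, 480), i.e. the central 448x448
# region, matching human_transform's CenterCrop below.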

@jaxtyped(typechecker=beartype.beartype)
class SplitClip(torch.nn.Module):
    """A CLIP ViT split in two so SAE edits can be applied to intermediate patch activations."""

    def __init__(self, *, n_end_layers: int):
        super().__init__()

        if vit_ckpt.startswith("hf-hub:"):
            clip, _ = open_clip.create_model_from_pretrained(
                vit_ckpt, cache_dir=get_cache_dir()
            )
        else:
            arch, ckpt = vit_ckpt.split("/")
            clip, _ = open_clip.create_model_from_pretrained(
                arch, pretrained=ckpt, cache_dir=get_cache_dir()
            )
        model = clip.visual
        # Drop the projection and return patch tokens so we can edit them.
        model.proj = None
        model.output_tokens = True
        self.vit = model.eval()
        assert not isinstance(self.vit, open_clip.timm_model.TimmModel)

        self.n_end_layers = n_end_layers

    @staticmethod
    def _expand_token(token, batch_size: int):
        return token.view(1, 1, -1).expand(batch_size, -1, -1)

    def forward_start(self, x: Float[Tensor, "batch channels height width"]):
        # Patchify: [batch, dim, grid, grid] -> [batch, grid ** 2, dim].
        x = self.vit.conv1(x)
        x = x.reshape(x.shape[0], x.shape[1], -1)
        x = x.permute(0, 2, 1)

        # Prepend the class embedding token.
        x = torch.cat(
            [self._expand_token(self.vit.class_embedding, x.shape[0]).to(x.dtype), x],
            dim=1,
        )

        x = x + self.vit.positional_embedding.to(x.dtype)

        x = self.vit.patch_dropout(x)
        x = self.vit.ln_pre(x)
        # Run everything except the last n_end_layers transformer blocks.
        for r in self.vit.transformer.resblocks[: -self.n_end_layers]:
            x = r(x)
        return x

    def forward_end(self, x: Float[Tensor, "batch n_patches dim"]):
        for r in self.vit.transformer.resblocks[-self.n_end_layers :]:
            x = r(x)

        x = self.vit.ln_post(x)
        pooled, _ = self.vit._global_pool(x)
        if self.vit.proj is not None:
            pooled = pooled @ self.vit.proj

        return pooled


split_vit = SplitClip(n_end_layers=1)
split_vit = split_vit.to(device)
logger.info("Initialized CLIP ViT.")
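
# A quick sanity check for the split (illustrative; assumes a 224x224 input and
# that ViT-B/16's width is 768). forward_end(forward_start(x)) should equal the
# pooled, pre-projection output of the unsplit model:
#
#     x = torch.randn(1, 3, 224, 224, device=device)
#     pooled = split_vit.forward_end(split_vit.forward_start(x))
#     assert pooled.shape == (1, 768)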

clf_ckpt_fpath = CWD / "ckpts" / "clf.pt"
clf = load_model(clf_ckpt_fpath)
clf = clf.to(device).eval()
logger.info("Loaded linear classifier.")


sae_ckpt_fpath = CWD / "ckpts" / "sae.pt"
sae = saev.nn.load(sae_ckpt_fpath.as_posix())
sae.to(device).eval()
logger.info("Loaded SAE.")


# For showing to humans: resize/crop only, keeping uint8 pixels in HWC order.
human_transform = v2.Compose([
    v2.Resize((512, 512), interpolation=v2.InterpolationMode.NEAREST),
    v2.CenterCrop((448, 448)),
    v2.ToImage(),
    einops.layers.torch.Rearrange("channels height width -> height width channels"),
])

arch, ckpt = vit_ckpt.split("/")
_, vit_transform = open_clip.create_model_from_pretrained(
    arch, pretrained=ckpt, cache_dir=get_cache_dir()
)


with open(CWD / "data" / "image_fpaths.json") as fd:
    image_fpaths = json.load(fd)

with open(CWD / "data" / "image_labels.json") as fd:
    image_labels = json.load(fd)


logger.info("Loaded all datasets.")


@beartype.beartype
def load_tensor(path: str | pathlib.Path) -> Tensor:
    return torch.load(path, weights_only=True, map_location="cpu")


top_img_i = load_tensor(CWD / "data" / "top_img_i.pt")
top_values = load_tensor(CWD / "data" / "top_values_uint8.pt")
sparsity = load_tensor(CWD / "data" / "sparsity.pt")


# Ignore latents that fire too frequently; they tend not to be interpretable.
mask = torch.ones((sae.cfg.d_sae,), dtype=bool)
mask = mask & (sparsity < max_frequency)
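logger.info(
    "Keeping %d of %d SAE latents with firing frequency under %g.",
    mask.sum().item(),
    mask.numel(),
    max_frequency,
)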


@beartype.beartype
def get_image(image_i: int) -> list[Image.Image | int]:
    image = get_dataset_img(image_i)
    image = human_transform(image)
    return [Image.fromarray(image.numpy()), image_labels[image_i]]


@beartype.beartype
def get_random_class_image(cls: int) -> Image.Image:
    indices = [i for i, tgt in enumerate(image_labels) if tgt == cls]
    i = random.choice(indices)

    image = get_dataset_img(i)
    image = human_transform(image)
    return Image.fromarray(image.numpy())


@torch.inference_mode()
def get_sae_examples(
    image_i: int, patches: list[int]
) -> list[None | Image.Image | int]:
    """
    Given a set of selected patches, returns highlighted examples of the SAE latents that fire most strongly on those patches.
    """
    if not patches:
        return [None] * (n_sae_latents * n_sae_examples) + [-1] * n_sae_latents

    logger.info("Getting SAE examples for patches %s.", patches)

    img = get_dataset_img(image_i)
    x = vit_transform(img)[None, ...].to(device)
    x_BPD = split_vit.forward_start(x)

    # Offset patch indices by one to skip the [CLS] token.
    vit_acts_MD = x_BPD[0, [p + 1 for p in patches]].to(device)

    _, f_x_MS, _ = sae(vit_acts_MD)
    f_x_S = f_x_MS.sum(axis=0)

    # Rank latents by total activation, then drop overly frequent ones.
    latents = torch.argsort(f_x_S, descending=True).cpu()
    latents = latents[mask[latents]][:n_sae_latents].tolist()

    images = []
    for latent in latents:
        img_patch_pairs, seen_i_im = [], set()
        for i_im, values_p in zip(top_img_i[latent].tolist(), top_values[latent]):
            if i_im in seen_i_im:
                continue

            example_img = get_dataset_img(i_im)
            img_patch_pairs.append((example_img, values_p))
            seen_i_im.add(i_im)

        # Upper bound for the highlight colormap.
        upper = None
        if top_values[latent].numel() > 0:
            upper = top_values[latent].max().item()

        latent_images = [
            make_img(ex_img, ex_patches.to(float), upper=upper)
            for ex_img, ex_patches in img_patch_pairs[:n_sae_examples]
        ]

        while len(latent_images) < n_sae_examples:
            latent_images += [None]

        images.extend(latent_images)

    return images + latents
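
# get_sae_examples lays its outputs out latent-major: all n_sae_examples images
# for the first latent, then the second, then the third, followed by the three
# latent ids. The Gradio `outputs` list below must use the same order.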


@torch.inference_mode()
def get_pred_dist(i: int) -> dict[int, float]:
    img = get_dataset_img(i)
    x = vit_transform(img)[None, ...].to(device)
    x_BPD = split_vit.forward_start(x)
    x_BD = split_vit.forward_end(x_BPD)

    logits_BC = clf(x_BD)

    probs = torch.nn.functional.softmax(logits_BC[0], dim=0).cpu().tolist()
    return {i: prob for i, prob in enumerate(probs)}


@torch.inference_mode()
def get_modified_dist(
    image_i: int,
    patches: list[int],
    latent1: int,
    latent2: int,
    latent3: int,
    value1: float,
    value2: float,
    value3: float,
) -> dict[int, float]:
    img = get_dataset_img(image_i)
    x = vit_transform(img)[None, ...].to(device)
    x_BPD = split_vit.forward_start(x)

    # Split off the [CLS] token; the SAE only operates on patch tokens.
    cls_B1D, x_BPD = x_BPD[:, :1, :], x_BPD[:, 1:, :]

    x_hat_BPD, f_x_BPS, _ = sae(x_BPD)
    err_BPD = x_BPD - x_hat_BPD

    values = torch.tensor(
        [
            unscaled(value, top_values[latent].max().item())
            for value, latent in [
                (value1, latent1),
                (value2, latent2),
                (value3, latent3),
            ]
        ],
        device=device,
    )
    # Overwrite the chosen latents' activations at the selected patches.
    patches = torch.tensor(patches, device=device)
    latents = torch.tensor([latent1, latent2, latent3], device=device)
    f_x_BPS[:, patches[:, None], latents[None, :]] = values

    # Decode the edited codes, then add back the SAE's reconstruction error so
    # everything except the edit is preserved: x' = err + W_dec f'(x) + b_dec.
    modified_x_hat_BPD = (
        einops.einsum(
            f_x_BPS,
            sae.W_dec,
            "batch patches d_sae, d_sae d_vit -> batch patches d_vit",
        )
        + sae.b_dec
    )

    modified_BPD = torch.cat([cls_B1D, err_BPD + modified_x_hat_BPD], axis=1)

    modified_BD = split_vit.forward_end(modified_BPD)
    logits_BC = clf(modified_BD)

    probs = torch.nn.functional.softmax(logits_BC[0], dim=0).cpu().tolist()
    return {i: prob for i, prob in enumerate(probs)}
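
# The UI sliders below run over [-10, 10]; unscaled() maps a slider value into
# units of each latent's maximum observed activation before it is written into
# f_x_BPS above.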

@beartype.beartype
def unscaled(x: float | int, max_obs: float | int) -> float:
    """Scale from [-20, 20] to [-20 * max_obs, 20 * max_obs]."""
    return map_range(x, (-20.0, 20.0), (-20.0 * max_obs, 20.0 * max_obs))


@beartype.beartype
def map_range(
    x: float | int,
    domain: tuple[float | int, float | int],
    codomain: tuple[float | int, float | int],
) -> float:
    a, b = domain
    c, d = codomain
    if not (a <= x <= b):
        raise ValueError(f"x={x:.3f} must be in {[a, b]}.")
    return c + (x - a) * (d - c) / (b - a)
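
# Worked example: unscaled(10, 4.2) = map_range(10, (-20, 20), (-84.0, 84.0))
# = -84 + (10 + 20) * 168 / 40 = 42.0. In general this affine map simplifies to
# unscaled(x, max_obs) == x * max_obs on the valid domain.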

@jaxtyped(typechecker=beartype.beartype)
def add_highlights(
    img: Image.Image,
    patches: Float[np.ndarray, " n_patches"],
    *,
    upper: int | None = None,
    opacity: float = 0.9,
) -> Image.Image:
    if not len(patches):
        return img

    # Assume a square grid of patches.
    iw_np, ih_np = int(math.sqrt(len(patches))), int(math.sqrt(len(patches)))
    iw_px, ih_px = img.size
    pw_px, ph_px = iw_px // iw_np, ih_px // ih_np
    assert iw_np * ih_np == len(patches)
    assert upper is not None

    # Draw one semi-transparent red rectangle per patch on an RGBA overlay.
    overlay = Image.new("RGBA", img.size, (0, 0, 0, 0))
    draw = ImageDraw.Draw(overlay)

    for p, val in enumerate(patches):
        val /= upper + 1e-9
        x_np, y_np = p % iw_np, p // iw_np
        draw.rectangle(
            [
                (x_np * pw_px, y_np * ph_px),
                (x_np * pw_px + pw_px, y_np * ph_px + ph_px),
            ],
            fill=(int(val * 255), 0, 0, int(opacity * val * 255)),
        )

    return Image.alpha_composite(img.convert("RGBA"), overlay)
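
# Example usage (illustrative): highlight the top-left patch of a 14x14 grid at
# half strength.
#
#     patches = np.zeros(196)
#     patches[0] = 0.5
#     highlighted = add_highlights(img, patches, upper=1, opacity=0.5)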


with gr.Blocks() as demo:
    image_number = gr.Number(label="Test Example", precision=0)
    class_number = gr.Number(label="Test Class", precision=0)
    input_image = gr.Image(label="Input Image")
    get_input_image_btn = gr.Button(value="Get Input Image")
    get_input_image_btn.click(
        get_image,
        inputs=[image_number],
        outputs=[input_image, class_number],
        api_name="get-image",
    )
    get_random_class_image_btn = gr.Button(value="Get Random Class Image")
    get_random_class_image_btn.click(
        get_random_class_image,
        inputs=[class_number],
        outputs=[input_image],
        api_name="get-random-class-image",
    )

    patch_numbers = gr.CheckboxGroup(
        label="Image Patch", choices=list(range(n_patches_per_img))
    )
    top_latent_numbers = [
        gr.Number(label=f"Top Latents #{j + 1}", precision=0)
        for j in range(n_sae_latents)
    ]
    # Latent-major order, matching the list returned by get_sae_examples.
    sae_example_images = [
        gr.Image(label=f"Latent #{j + 1}, Example #{i + 1}")
        for j in range(n_sae_latents)
        for i in range(n_sae_examples)
    ]
    get_sae_examples_btn = gr.Button(value="Get SAE Examples")
    get_sae_examples_btn.click(
        get_sae_examples,
        inputs=[image_number, patch_numbers],
        outputs=sae_example_images + top_latent_numbers,
        api_name="get-sae-examples",
        concurrency_limit=16,
    )

    pred_dist = gr.Label(label="Pred. Dist.")
    get_pred_dist_btn = gr.Button(value="Get Pred. Distribution")
    get_pred_dist_btn.click(
        get_pred_dist,
        inputs=[image_number],
        outputs=[pred_dist],
        api_name="get-preds",
    )

    latent_numbers = [gr.Number(label=f"Latent {i + 1}", precision=0) for i in range(3)]
    value_sliders = [
        gr.Slider(label=f"Value {i + 1}", minimum=-10, maximum=10) for i in range(3)
    ]
    get_modified_dist_btn = gr.Button(value="Get Modified Dist.")
    get_modified_dist_btn.click(
        get_modified_dist,
        inputs=[image_number, patch_numbers] + latent_numbers + value_sliders,
        outputs=[pred_dist],
        api_name="get-modified",
        concurrency_limit=16,
    )


if __name__ == "__main__":
    demo.launch()