Samuel Stevens committed on
Commit
e508563
·
1 Parent(s): dc20bdb

cleaning code up

Browse files
Files changed (6) hide show
  1. .gitignore +8 -0
  2. README.md +26 -0
  3. app.py +75 -66
  4. constants.py +776 -0
  5. data.py +187 -0
  6. pyproject.toml +1 -0
.gitignore ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ __pycache__/
2
+ .venv/
3
+ .hypothesis/
4
+ .aider*
5
+ .env
6
+ .DS_Store
7
+ .coverage
8
+ saev.egg-info/
README.md CHANGED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ title: SAEs for Semantic Segmentation
3
+ emoji: 🐨
4
+ colorFrom: gray
5
+ colorTo: blue
6
+ sdk: gradio
7
+ sdk_version: 5.9.1
8
+ python_version: 3.12.8
9
+ app_file: app.py
10
+ pinned: false
11
+ license: mit
12
+ short_description: Interpret semantic segmentation models using SAEs.
13
+ ---
14
+
15
+
16
+ I used [s5cmd](https://github.com/peak/s5cmd) to upload ADE20K to Cloudflare R2.
17
+
18
+ ```sh
19
+ # in images/
20
+ s5cmd --credentials-file ~/.local/etc/cloudflare/r2-credentials --endpoint-url https://6391ae4399fb354a41cab96372935a6e.r2.cloudflarestorage.com \
21
+ cp validation/ s3://saev-ade20k/images/
22
+
23
+ # in annotations/
24
+ s5cmd --credentials-file ~/.local/etc/cloudflare/r2-credentials --endpoint-url https://6391ae4399fb354a41cab96372935a6e.r2.cloudflarestorage.com \
25
+ cp validation/ s3://saev-ade20k/annotations/
26
+ ```
app.py CHANGED
@@ -1,23 +1,28 @@
 
 
 
 
1
  import os.path
 
2
  import typing
3
- import functools
4
 
5
  import beartype
6
  import einops
7
  import einops.layers.torch
8
  import gradio as gr
 
 
 
 
9
  import torch
10
  from jaxtyping import Float, Int, UInt8, jaxtyped
11
  from PIL import Image
12
  from torch import Tensor
13
 
14
- import saev.activations
15
- import saev.config
16
- import saev.nn
17
- import saev.visuals
18
 
19
- from .. import training
20
- from . import data
21
 
22
  ####################
23
  # Global Constants #
@@ -30,9 +35,6 @@ DEBUG = False
30
  max_frequency = 1e-2
31
  """Maximum frequency. Any feature that fires more than this is ignored."""
32
 
33
- ckpt = "oebd6e6i"
34
- """Which SAE checkpoint to use."""
35
-
36
  n_sae_latents = 3
37
  """Number of SAE latents to show."""
38
 
@@ -51,14 +53,8 @@ CROP_SIZE = (448, 448)
51
  DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
52
  """Hardware accelerator, if any."""
53
 
54
- ####################
55
- # Helper Functions #
56
- ####################
57
-
58
-
59
- @beartype.beartype
60
- def load_tensor(path: str) -> Tensor:
61
- return torch.load(path, weights_only=True, map_location="cpu")
62
 
63
 
64
  ##########
@@ -67,43 +63,49 @@ def load_tensor(path: str) -> Tensor:
67
 
68
 
69
  @functools.cache
70
- def load_vit(
71
- model_cfg: modeling.Config,
72
- ) -> tuple[
73
- activations.WrappedVisionTransformer,
74
- typing.Callable,
75
- float,
76
- Float[Tensor, " d_vit"],
77
- ]:
78
  vit = (
79
- saev.activations.WrappedVisionTransformer(model_cfg.wrapped_cfg)
 
 
 
 
 
 
 
80
  .to(DEVICE)
81
  .eval()
82
  )
83
- vit_transform = saev.activations.make_img_transform(
84
- model_cfg.vit_family, model_cfg.vit_ckpt
85
- )
86
- logger.info("Loaded ViT: %s.", model_cfg.key)
87
 
88
- try:
89
- # Normalizing constants
90
- acts_dataset = saev.activations.Dataset(model_cfg.acts_cfg)
91
- logger.info("Loaded dataset norms: %s.", model_cfg.key)
92
- except RuntimeError as err:
93
- logger.warning("Error loading ViT: %s", err)
94
- return None, None, None, None
95
 
96
- return vit, vit_transform, acts_dataset.scalar.item(), acts_dataset.act_mean
97
 
 
 
 
 
 
 
 
 
 
98
 
99
- sae_ckpt_fpath = f"/home/stevens.994/projects/saev/checkpoints/{ckpt}/sae.pt"
100
- sae = saev.nn.load(sae_ckpt_fpath)
101
- sae.to(device).eval()
102
 
 
 
 
 
 
 
 
103
 
104
- head_ckpt_fpath = "/home/stevens.994/projects/saev/checkpoints/contrib/semseg/lr_0_001__wd_0_001/model_step8000.pt"
105
- head = training.load(head_ckpt_fpath)
106
- head = head.to(device).eval()
 
 
107
 
108
 
109
  class RestOfDinoV2(torch.nn.Module):
@@ -136,17 +138,18 @@ rest_of_vit = rest_of_vit.to(device)
136
  ####################
137
 
138
 
139
- ckpt_data_root = (
140
- f"/research/nfs_su_809/workspace/stevens.994/saev/features/{ckpt}/sort_by_patch"
141
- )
142
 
143
- top_img_i = load_tensor(os.path.join(ckpt_data_root, "top_img_i.pt"))
144
- top_values = load_tensor(os.path.join(ckpt_data_root, "top_values.pt"))
145
- sparsity = load_tensor(os.path.join(ckpt_data_root, "sparsity.pt"))
146
 
 
 
 
147
 
148
- mask = torch.ones((sae.cfg.d_sae), dtype=bool)
149
- mask = mask & (sparsity < max_frequency)
 
150
 
151
 
152
  ############
@@ -229,9 +232,13 @@ class SaeActivation(typing.TypedDict):
229
 
230
  @beartype.beartype
231
  def get_image(image_i: int) -> tuple[str, str, int]:
232
- img_sized, labels_sized = data.get_sample(image_i)
 
 
 
 
233
 
234
- return data.pil_to_base64(img_sized), data.pil_to_base64(labels_sized), image_i
235
 
236
 
237
  @beartype.beartype
@@ -243,27 +250,29 @@ def get_sae_activations(image_i: int, patches: list[int]) -> list[SaeActivation]
243
  if not patches:
244
  return []
245
 
246
- vit, vit_transform, scalar, mean = load_vit(model_cfg)
247
- if vit is None:
248
- logger.warning("Skipping ViT '%s'", model_name)
249
- return []
250
- sae = load_sae(model_cfg)
251
 
252
- mean = mean.to(DEVICE)
253
- x = vit_transform(img_p)[None, ...].to(DEVICE)
 
254
 
255
  _, vit_acts_BLPD = vit(x)
256
- vit_acts_PD = (vit_acts_BLPD[0, 0, 1:].to(DEVICE).clamp(-1e-5, 1e5) - mean) / scalar
 
 
 
257
 
258
  _, f_x_PS, _ = sae(vit_acts_PD)
259
  # Ignore [CLS] token and get just the requested latents.
260
  acts_SP = einops.rearrange(f_x_PS, "patches n_latents -> n_latents patches")
261
- logger.info("Got SAE activations for '%s'.", model_name)
262
- top_img_i, top_values = load_tensors(model_cfg)
263
- logger.info("Loaded top SAE activations for '%s'.", model_name)
264
 
265
  breakpoint()
266
 
 
 
 
267
  vit_acts_MD = torch.stack([
268
  acts_dataset[image_i * acts_dataset.metadata.n_patches_per_img + i]["act"]
269
  for i in patches
 
1
+ import functools
2
+ import io
3
+ import json
4
+ import logging
5
  import os.path
6
+ import pathlib
7
  import typing
 
8
 
9
  import beartype
10
  import einops
11
  import einops.layers.torch
12
  import gradio as gr
13
+ import saev.activations
14
+ import saev.config
15
+ import saev.nn
16
+ import saev.visuals
17
  import torch
18
  from jaxtyping import Float, Int, UInt8, jaxtyped
19
  from PIL import Image
20
  from torch import Tensor
21
 
22
+ import constants
23
+ import data
 
 
24
 
25
+ logger = logging.getLogger("app.py")
 
26
 
27
  ####################
28
  # Global Constants #
 
35
  max_frequency = 1e-2
36
  """Maximum frequency. Any feature that fires more than this is ignored."""
37
 
 
 
 
38
  n_sae_latents = 3
39
  """Number of SAE latents to show."""
40
 
 
53
  DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
54
  """Hardware accelerator, if any."""
55
 
56
+ CWD = pathlib.Path(".")
57
+ """Current working directory."""
 
 
 
 
 
 
58
 
59
 
60
  ##########
 
63
 
64
 
65
  @functools.cache
66
+ def load_vit() -> tuple[saev.activations.WrappedVisionTransformer, typing.Callable]:
 
 
 
 
 
 
 
67
  vit = (
68
+ saev.activations.WrappedVisionTransformer(
69
+ saev.config.Activations(
70
+ model_family="dinov2",
71
+ model_ckpt="dinov2_vitb14_reg",
72
+ layers=[-2],
73
+ n_patches_per_img=256,
74
+ )
75
+ )
76
  .to(DEVICE)
77
  .eval()
78
  )
79
+ vit_transform = saev.activations.make_img_transform("dinov2", "dinov2_vitb14_reg")
80
+ logger.info("Loaded ViT.")
 
 
81
 
82
+ return vit, vit_transform
 
 
 
 
 
 
83
 
 
84
 
85
+ @functools.cache
86
+ def load_sae() -> saev.nn.SparseAutoencoder:
87
+ """
88
+ Loads a sparse autoencoder from disk.
89
+ """
90
+ sae_ckpt_fpath = CWD / "assets" / "sae.pt"
91
+ sae = saev.nn.load(str(sae_ckpt_fpath))
92
+ sae.to(device).eval()
93
+ return sae
94
 
 
 
 
95
 
96
+ @functools.cache
97
+ def load_clf() -> torch.nn.Module:
98
+ # /home/stevens.994/projects/saev/checkpoints/contrib/semseg/lr_0_001__wd_0_001/model_step8000.pt
99
+ head_ckpt_fpath = CWD / "assets" / "clf.pt"
100
+ with open(head_ckpt_fpath, "rb") as fd:
101
+ kwargs = json.loads(fd.readline().decode())
102
+ buffer = io.BytesIO(fd.read())
103
 
104
+ model = torch.nn.Linear(**kwargs)
105
+ state_dict = torch.load(buffer, weights_only=True, map_location=device)
106
+ model.load_state_dict(state_dict)
107
+ model = model.to(device).eval()
108
+ return model
109
 
110
 
111
  class RestOfDinoV2(torch.nn.Module):
 
138
  ####################
139
 
140
 
141
+ @beartype.beartype
142
+ def load_tensor(path: str | pathlib.Path) -> Tensor:
143
+ return torch.load(path, weights_only=True, map_location="cpu")
144
 
 
 
 
145
 
146
+ top_img_i = load_tensor(CWD / "assets" / "top_img_i.pt")
147
+ top_values = load_tensor(CWD / "assets" / "top_values_uint8.pt")
148
+ sparsity = load_tensor(CWD / "assets" / "sparsity.pt")
149
 
150
+
151
+ # mask = torch.ones((sae.cfg.d_sae), dtype=bool)
152
+ # mask = mask & (sparsity < max_frequency)
153
 
154
 
155
  ############
 
232
 
233
  @beartype.beartype
234
  def get_image(image_i: int) -> tuple[str, str, int]:
235
+ sample = data.get_sample(image_i)
236
+ img_sized = data.to_sized(sample["image"])
237
+ seg_sized = data.to_sized(sample["segmentation"])
238
+ seg_u8_sized = data.to_u8(seg_sized)
239
+ seg_img_sized = data.u8_to_img(seg_u8_sized)
240
 
241
+ return data.img_to_base64(img_sized), data.img_to_base64(seg_img_sized), image_i
242
 
243
 
244
  @beartype.beartype
 
250
  if not patches:
251
  return []
252
 
253
+ vit, vit_transform = load_vit()
254
+ sae = load_sae()
 
 
 
255
 
256
+ sample = data.get_sample(image_i)
257
+
258
+ x = vit_transform(sample["image"])[None, ...].to(DEVICE)
259
 
260
  _, vit_acts_BLPD = vit(x)
261
+ vit_acts_PD = (
262
+ vit_acts_BLPD[0, 0, 1:].to(DEVICE).clamp(-1e-5, 1e5)
263
+ - (constants.DINOV2_IMAGENET1K_MEAN).to(DEVICE)
264
+ ) / constants.DINOV2_IMAGENET1K_SCALAR
265
 
266
  _, f_x_PS, _ = sae(vit_acts_PD)
267
  # Ignore [CLS] token and get just the requested latents.
268
  acts_SP = einops.rearrange(f_x_PS, "patches n_latents -> n_latents patches")
269
+ logger.info("Got SAE activations.")
 
 
270
 
271
  breakpoint()
272
 
273
+ top_img_i, top_values = load_tensors(model_cfg)
274
+ logger.info("Loaded top SAE activations for '%s'.", model_name)
275
+
276
  vit_acts_MD = torch.stack([
277
  acts_dataset[image_i * acts_dataset.metadata.n_patches_per_img + i]["act"]
278
  for i in patches
constants.py ADDED
@@ -0,0 +1,776 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+
3
+
4
+ DINOV2_IMAGENET1K_SCALAR = 2.0181241035461426
5
+
6
+
7
+ DINOV2_IMAGENET1K_MEAN = torch.tensor([
8
+ 0.1450997292995453,
9
+ -1.0630134344100952,
10
+ -0.3518574833869934,
11
+ -0.38624095916748047,
12
+ -0.4866980314254761,
13
+ -0.28983384370803833,
14
+ 0.9997676014900208,
15
+ -1.231179118156433,
16
+ -0.7889889478683472,
17
+ -0.4450306296348572,
18
+ -0.09231726080179214,
19
+ 0.13243812322616577,
20
+ 0.09571082890033722,
21
+ -0.29342857003211975,
22
+ 0.05933428555727005,
23
+ -0.21923032402992249,
24
+ 0.08959043025970459,
25
+ -0.6981018781661987,
26
+ 0.4853704869747162,
27
+ -0.29948222637176514,
28
+ 0.3107207119464874,
29
+ -0.3812718093395233,
30
+ -0.5013473033905029,
31
+ 2.88395094871521,
32
+ -0.5611682534217834,
33
+ -0.3514024615287781,
34
+ 0.025546086952090263,
35
+ -0.24438244104385376,
36
+ -0.23365195095539093,
37
+ -0.2533780336380005,
38
+ 0.4445696473121643,
39
+ 1.1176759004592896,
40
+ -0.4188934564590454,
41
+ 0.09051182866096497,
42
+ -0.04133417829871178,
43
+ -0.008052834309637547,
44
+ -0.5118610858917236,
45
+ 0.22084011137485504,
46
+ -0.7333402633666992,
47
+ 0.8644523620605469,
48
+ -0.43727627396583557,
49
+ -0.22333095967769623,
50
+ -1.5415295362472534,
51
+ -0.24187016487121582,
52
+ -0.33239033818244934,
53
+ -1.2828021049499512,
54
+ -0.21485395729541779,
55
+ 0.6667488813400269,
56
+ -0.25890952348709106,
57
+ -0.8630414009094238,
58
+ 1.5059994459152222,
59
+ -0.00952776987105608,
60
+ 0.18695995211601257,
61
+ 0.0200128685683012,
62
+ -0.221832275390625,
63
+ 1.2800148725509644,
64
+ -0.1416555792093277,
65
+ 0.61446613073349,
66
+ 0.053658585995435715,
67
+ -0.08877403289079666,
68
+ 1.0190010070800781,
69
+ -0.308927446603775,
70
+ -0.3903353214263916,
71
+ -0.35504740476608276,
72
+ -0.7907304763793945,
73
+ -0.18439480662345886,
74
+ -0.1797204464673996,
75
+ 0.8199827075004578,
76
+ -0.1736353039741516,
77
+ -0.16373644769191742,
78
+ 0.7541728019714355,
79
+ -0.3236996829509735,
80
+ 0.8245170712471008,
81
+ 0.3411649167537689,
82
+ -0.21873517334461212,
83
+ -0.7620954513549805,
84
+ -0.10635858029127121,
85
+ -0.592278003692627,
86
+ 0.8314691781997681,
87
+ -0.2021609991788864,
88
+ -0.24301563203334808,
89
+ -0.03504444658756256,
90
+ -0.061244938522577286,
91
+ -0.36000630259513855,
92
+ -0.38578882813453674,
93
+ -1.2314008474349976,
94
+ -0.3416382968425751,
95
+ 0.5925644636154175,
96
+ 0.32259607315063477,
97
+ 0.13169726729393005,
98
+ -0.131134033203125,
99
+ 0.05763484537601471,
100
+ -0.7130515575408936,
101
+ -0.5685354471206665,
102
+ 0.04428980499505997,
103
+ 0.9245452880859375,
104
+ 0.37724241614341736,
105
+ -0.4426809549331665,
106
+ 0.5091503262519836,
107
+ -0.08006338775157928,
108
+ -0.18945513665676117,
109
+ -0.770736575126648,
110
+ -0.3588047921657562,
111
+ 0.04727765917778015,
112
+ -0.16137081384658813,
113
+ -0.021555813029408455,
114
+ 0.6381930708885193,
115
+ 0.30161890387535095,
116
+ -0.0710706040263176,
117
+ -0.13884945213794708,
118
+ -0.22726555168628693,
119
+ -0.6134527921676636,
120
+ 0.2969088852405548,
121
+ -0.2334780991077423,
122
+ -0.46334928274154663,
123
+ -0.3058214485645294,
124
+ 0.5196799039840698,
125
+ 0.6341780424118042,
126
+ 0.12271945178508759,
127
+ -1.0072089433670044,
128
+ -0.1198473796248436,
129
+ -0.24667270481586456,
130
+ -0.19228138029575348,
131
+ -0.3955901861190796,
132
+ -0.19902971386909485,
133
+ 0.7407659292221069,
134
+ 2.3908257484436035,
135
+ 0.02820657566189766,
136
+ 0.07064329087734222,
137
+ -0.2637694776058197,
138
+ 0.2560977339744568,
139
+ 0.3973558247089386,
140
+ -0.17345857620239258,
141
+ -0.9541534185409546,
142
+ -0.21434728801250458,
143
+ 0.41178393363952637,
144
+ -0.008175228722393513,
145
+ 0.5115303993225098,
146
+ -0.9667210578918457,
147
+ 1.6499103307724,
148
+ -1.8320564031600952,
149
+ 1.1143667697906494,
150
+ 0.24006624519824982,
151
+ -0.02112947776913643,
152
+ -0.4952388405799866,
153
+ 1.1000680923461914,
154
+ -0.4901401102542877,
155
+ 0.22758258879184723,
156
+ -0.6699370741844177,
157
+ 0.6926363706588745,
158
+ -0.5719613432884216,
159
+ 0.008403707295656204,
160
+ 2.0220773220062256,
161
+ -0.1789812445640564,
162
+ -0.8777256011962891,
163
+ 0.3709064722061157,
164
+ -0.2629733681678772,
165
+ 0.08407248556613922,
166
+ -0.27063870429992676,
167
+ 0.09993340820074081,
168
+ -0.3755860924720764,
169
+ 0.07000888139009476,
170
+ 0.3775370419025421,
171
+ 0.5653945207595825,
172
+ -0.11404427886009216,
173
+ -0.06088113784790039,
174
+ -0.0898045226931572,
175
+ 0.19868576526641846,
176
+ 0.14287644624710083,
177
+ -0.669394314289093,
178
+ -0.07882463932037354,
179
+ -0.12379930168390274,
180
+ -0.010277876630425453,
181
+ -0.5625343918800354,
182
+ -0.6508009433746338,
183
+ 0.06929764896631241,
184
+ -2.0470166206359863,
185
+ 1.0193544626235962,
186
+ -0.9747569561004639,
187
+ -0.25624850392341614,
188
+ -0.04412469267845154,
189
+ -0.01941649615764618,
190
+ 0.04781557247042656,
191
+ -0.2561051845550537,
192
+ -0.09596704691648483,
193
+ -1.0529744625091553,
194
+ -0.32774603366851807,
195
+ -0.1931363344192505,
196
+ -0.36885082721710205,
197
+ -0.9351740479469299,
198
+ -0.47905397415161133,
199
+ -0.678762674331665,
200
+ 2.336048126220703,
201
+ 0.26323413848876953,
202
+ -0.36512619256973267,
203
+ -0.3650853633880615,
204
+ -0.8287989497184753,
205
+ 0.5866581201553345,
206
+ -0.420742005109787,
207
+ 0.008546118624508381,
208
+ -0.7811568975448608,
209
+ -0.34993329644203186,
210
+ -0.373068243265152,
211
+ 0.028424998745322227,
212
+ -0.537581205368042,
213
+ -0.15937983989715576,
214
+ -0.5638740062713623,
215
+ -0.4413940906524658,
216
+ -0.05887509509921074,
217
+ -0.12291032075881958,
218
+ -0.26565149426460266,
219
+ -0.23059803247451782,
220
+ -0.2925986349582672,
221
+ 0.04849022254347801,
222
+ -0.4770037531852722,
223
+ 0.040383752435445786,
224
+ -0.8186637759208679,
225
+ -0.062463242560625076,
226
+ -0.3251510262489319,
227
+ -0.4319412112236023,
228
+ -0.34569647908210754,
229
+ 0.9713658690452576,
230
+ -0.25668394565582275,
231
+ -0.37531179189682007,
232
+ 0.5259386301040649,
233
+ -0.06112021207809448,
234
+ 0.06980857998132706,
235
+ -0.38363778591156006,
236
+ -0.1948518007993698,
237
+ -0.7897586822509766,
238
+ -0.600932776927948,
239
+ -0.4269576072692871,
240
+ -0.32002967596054077,
241
+ 0.08897170424461365,
242
+ -0.3079395294189453,
243
+ -0.05779555067420006,
244
+ -0.782086968421936,
245
+ 1.9608103036880493,
246
+ 0.1145739033818245,
247
+ 0.06164107844233513,
248
+ -0.3024725317955017,
249
+ -0.6308553218841553,
250
+ -0.7640243172645569,
251
+ -4.433685302734375,
252
+ -0.31690648198127747,
253
+ -0.019084235653281212,
254
+ -0.09761863201856613,
255
+ -0.029514605179429054,
256
+ -0.5096182823181152,
257
+ 1.112805962562561,
258
+ -0.3302820324897766,
259
+ -0.23730400204658508,
260
+ 0.044646695256233215,
261
+ -0.805400013923645,
262
+ -7.766678333282471,
263
+ -0.2016162872314453,
264
+ -0.5018128752708435,
265
+ 0.6819560527801514,
266
+ -0.2735823392868042,
267
+ -2.2288968563079834,
268
+ -0.36170846223831177,
269
+ -0.7745882868766785,
270
+ 0.4644778370857239,
271
+ 0.2525951564311981,
272
+ -0.22642317414283752,
273
+ -0.5394997596740723,
274
+ -0.5064775347709656,
275
+ -0.5716705918312073,
276
+ 0.19713695347309113,
277
+ -0.5411649942398071,
278
+ -0.17092496156692505,
279
+ 0.45778003334999084,
280
+ 0.6894896030426025,
281
+ -0.21671152114868164,
282
+ -0.9160588383674622,
283
+ -0.10307890176773071,
284
+ 0.11703722178936005,
285
+ -0.7433905601501465,
286
+ -1.5170584917068481,
287
+ 2.163774013519287,
288
+ -1.542649507522583,
289
+ -0.1601075381040573,
290
+ -0.5249155163764954,
291
+ 0.44509291648864746,
292
+ -0.5261067152023315,
293
+ -0.02273540571331978,
294
+ -0.28311043977737427,
295
+ 0.9144242405891418,
296
+ 0.43954336643218994,
297
+ -0.2469814419746399,
298
+ 0.18752114474773407,
299
+ -0.6066163778305054,
300
+ -0.14480441808700562,
301
+ -0.3546217679977417,
302
+ -0.11870954185724258,
303
+ -0.09891107678413391,
304
+ -0.377458781003952,
305
+ 0.33304381370544434,
306
+ -0.156569704413414,
307
+ -0.9730328321456909,
308
+ -0.5034677386283875,
309
+ 0.042613230645656586,
310
+ 0.08271210640668869,
311
+ -0.2368200123310089,
312
+ -0.07397157698869705,
313
+ 0.011974042281508446,
314
+ -0.2115129977464676,
315
+ -0.3752884566783905,
316
+ -0.24985794723033905,
317
+ -0.25223013758659363,
318
+ 1.8311675786972046,
319
+ -0.1650543361902237,
320
+ -0.031050190329551697,
321
+ 0.10702164471149445,
322
+ 0.8963613510131836,
323
+ -0.9483885169029236,
324
+ -0.8156309723854065,
325
+ -1.7132004499435425,
326
+ 0.08163392543792725,
327
+ 0.4886241555213928,
328
+ -0.016470594331622124,
329
+ -0.37671732902526855,
330
+ -0.025105634704232216,
331
+ -0.2695018947124481,
332
+ -0.8450148701667786,
333
+ -0.9802296757698059,
334
+ -0.21868866682052612,
335
+ -0.5872927308082581,
336
+ 1.019242763519287,
337
+ 0.01872517168521881,
338
+ 0.5087792873382568,
339
+ 0.06771136820316315,
340
+ 1.4142885208129883,
341
+ 0.13146139681339264,
342
+ -0.36489933729171753,
343
+ 0.37572142481803894,
344
+ -0.3490581810474396,
345
+ -0.13830198347568512,
346
+ -1.8019393682479858,
347
+ 1.5129766464233398,
348
+ 0.07059808075428009,
349
+ 1.7206473350524902,
350
+ 0.02890164405107498,
351
+ 0.3628808557987213,
352
+ 0.3914141058921814,
353
+ 0.4993101954460144,
354
+ 0.3969678580760956,
355
+ -0.058554816991090775,
356
+ -0.3434300422668457,
357
+ -0.4157616198062897,
358
+ -0.7624511122703552,
359
+ -0.3997197449207306,
360
+ 1.4573990106582642,
361
+ -0.3363801836967468,
362
+ -0.46490129828453064,
363
+ -0.7445303797721863,
364
+ -0.3460237979888916,
365
+ -0.6315308809280396,
366
+ 0.8536337018013,
367
+ -0.08939796686172485,
368
+ -0.21093742549419403,
369
+ -0.08742645382881165,
370
+ -0.020040960982441902,
371
+ 0.09354449808597565,
372
+ -0.809800386428833,
373
+ -0.0018062496092170477,
374
+ -1.0083088874816895,
375
+ 0.3428219258785248,
376
+ 0.012708818539977074,
377
+ -0.3535612225532532,
378
+ 1.9481208324432373,
379
+ 0.013826621696352959,
380
+ -0.026771225035190582,
381
+ 0.18734635412693024,
382
+ 0.9365230798721313,
383
+ 1.247671025339514e-05,
384
+ -0.4420109987258911,
385
+ 0.10769690573215485,
386
+ -0.6858118176460266,
387
+ -0.24754805862903595,
388
+ 1.0027467012405396,
389
+ -0.26436665654182434,
390
+ -0.33883318305015564,
391
+ 0.38209766149520874,
392
+ 0.479579895734787,
393
+ -0.5910238027572632,
394
+ 0.1890297830104828,
395
+ -0.29854580760002136,
396
+ -0.5636696219444275,
397
+ -0.504091739654541,
398
+ -0.32814571261405945,
399
+ -0.748496949672699,
400
+ -0.3217906653881073,
401
+ -0.12439341843128204,
402
+ -0.3949342668056488,
403
+ 0.09739203751087189,
404
+ -0.4254276752471924,
405
+ 0.8690429329872131,
406
+ -0.26380032300949097,
407
+ -1.2738139629364014,
408
+ -0.12694764137268066,
409
+ -0.7331164479255676,
410
+ 0.11337947845458984,
411
+ -0.7573927640914917,
412
+ -0.41507089138031006,
413
+ -0.18960340321063995,
414
+ 1.2390563488006592,
415
+ -0.10859012603759766,
416
+ -0.021934548392891884,
417
+ -0.05041227489709854,
418
+ -0.055214136838912964,
419
+ 0.20024456083774567,
420
+ -0.2689618766307831,
421
+ -0.3135489821434021,
422
+ -0.07520166784524918,
423
+ -0.5906742811203003,
424
+ 0.2828388512134552,
425
+ 0.05117213353514671,
426
+ 1.4600849151611328,
427
+ -0.1967628449201584,
428
+ 0.011182722635567188,
429
+ 0.028878701850771904,
430
+ -0.12146933376789093,
431
+ 0.6056286096572876,
432
+ 0.22920559346675873,
433
+ -0.008979334495961666,
434
+ -0.2874019742012024,
435
+ -0.4887332320213318,
436
+ 0.8754663467407227,
437
+ -0.05393843352794647,
438
+ -0.2956174910068512,
439
+ -0.18953847885131836,
440
+ -0.19063766300678253,
441
+ -0.8141281008720398,
442
+ 0.11052622646093369,
443
+ -0.020359158515930176,
444
+ -0.1262499988079071,
445
+ -1.7762614488601685,
446
+ -0.4864279627799988,
447
+ -0.8644945621490479,
448
+ 0.1278448849916458,
449
+ 1.1127605438232422,
450
+ -0.595068097114563,
451
+ -0.06630692631006241,
452
+ 1.5608118772506714,
453
+ -0.9473971724510193,
454
+ -0.1827543079853058,
455
+ -0.25564679503440857,
456
+ -0.4378860294818878,
457
+ -0.8285927176475525,
458
+ -1.1397618055343628,
459
+ -0.06226593255996704,
460
+ -0.09025824069976807,
461
+ -0.518083393573761,
462
+ -0.893482506275177,
463
+ 0.5022943615913391,
464
+ -0.5922176837921143,
465
+ 0.2571451961994171,
466
+ 0.25571396946907043,
467
+ 0.832092821598053,
468
+ -0.061823680996894836,
469
+ -0.08963754773139954,
470
+ -0.42173218727111816,
471
+ -0.4375287890434265,
472
+ -0.43921560049057007,
473
+ 0.5626742243766785,
474
+ -0.011294233612716198,
475
+ 0.626301646232605,
476
+ -0.28029197454452515,
477
+ 0.15464802086353302,
478
+ -0.7071759700775146,
479
+ -0.0337684191763401,
480
+ -0.20901329815387726,
481
+ -0.29788798093795776,
482
+ 0.6644192934036255,
483
+ -0.049459852278232574,
484
+ 0.039552830159664154,
485
+ -0.2790898084640503,
486
+ 0.3250356614589691,
487
+ -0.12668772041797638,
488
+ -0.46142634749412537,
489
+ -0.35542988777160645,
490
+ -1.1817448139190674,
491
+ 0.007615066133439541,
492
+ -0.43865758180618286,
493
+ -0.16142761707305908,
494
+ -0.37852972745895386,
495
+ -0.582589328289032,
496
+ 0.4371003210544586,
497
+ -0.2603273391723633,
498
+ -0.03284638375043869,
499
+ 0.8895729184150696,
500
+ -0.025997856631875038,
501
+ 0.5761443376541138,
502
+ -0.28437164425849915,
503
+ -0.11191761493682861,
504
+ -0.07794637233018875,
505
+ 0.02127309888601303,
506
+ -0.10069284588098526,
507
+ -0.2177346795797348,
508
+ -1.029278039932251,
509
+ -0.5014596581459045,
510
+ -0.5774326920509338,
511
+ -0.2856050431728363,
512
+ -0.24715296924114227,
513
+ 0.1243511438369751,
514
+ 0.042631667107343674,
515
+ -0.846584677696228,
516
+ -0.7308683395385742,
517
+ -0.09307371079921722,
518
+ -0.35250845551490784,
519
+ 0.12801845371723175,
520
+ -0.5423708558082581,
521
+ -0.22422067821025848,
522
+ 1.574460744857788,
523
+ -0.27640238404273987,
524
+ -0.37266722321510315,
525
+ -0.12533603608608246,
526
+ 0.3177711069583893,
527
+ -0.4530303478240967,
528
+ 0.24940718710422516,
529
+ -0.1272897720336914,
530
+ 0.6882254481315613,
531
+ -0.2153051793575287,
532
+ -0.6189695000648499,
533
+ -0.38704702258110046,
534
+ -0.14360225200653076,
535
+ -0.08159925043582916,
536
+ 0.4714410603046417,
537
+ -0.16035029292106628,
538
+ 0.005880486220121384,
539
+ -0.5742312669754028,
540
+ -0.33733850717544556,
541
+ -0.39702731370925903,
542
+ -0.14614750444889069,
543
+ -0.06936132907867432,
544
+ 0.2528288662433624,
545
+ -0.25900882482528687,
546
+ 0.45907658338546753,
547
+ -0.20694994926452637,
548
+ 0.4083366394042969,
549
+ -0.9925484657287598,
550
+ -0.17098328471183777,
551
+ 0.3215583860874176,
552
+ -0.33823585510253906,
553
+ -0.07112737745046616,
554
+ -0.05322866141796112,
555
+ 0.19237284362316132,
556
+ -0.6257429122924805,
557
+ 0.23328493535518646,
558
+ -0.17247024178504944,
559
+ -0.3362499177455902,
560
+ -0.17041970789432526,
561
+ -0.014526017010211945,
562
+ -0.12138030678033829,
563
+ 0.0698552280664444,
564
+ -0.609315037727356,
565
+ 0.8142863512039185,
566
+ -2.295081615447998,
567
+ -0.07903101295232773,
568
+ -0.48268306255340576,
569
+ -0.2097805291414261,
570
+ -0.4481655955314636,
571
+ -1.059373378753662,
572
+ 0.17675237357616425,
573
+ -0.5335419774055481,
574
+ 0.7713444232940674,
575
+ 0.6341530084609985,
576
+ 1.1411781311035156,
577
+ -0.18365903198719025,
578
+ -0.4029919505119324,
579
+ -0.34328755736351013,
580
+ -1.1935101747512817,
581
+ -0.4249494671821594,
582
+ 0.10720300674438477,
583
+ -0.13509584963321686,
584
+ -0.610278844833374,
585
+ -0.1007867231965065,
586
+ -0.13094481825828552,
587
+ 0.3319343030452728,
588
+ -0.22466504573822021,
589
+ -0.33384865522384644,
590
+ -0.3001727759838104,
591
+ -0.48621413111686707,
592
+ 0.10271137952804565,
593
+ -0.3953743577003479,
594
+ -0.3412061631679535,
595
+ -1.3808176517486572,
596
+ -0.3035687804222107,
597
+ 0.27737119793891907,
598
+ -0.10266303271055222,
599
+ -0.472690224647522,
600
+ 0.03376518189907074,
601
+ -0.2053908109664917,
602
+ -0.46477705240249634,
603
+ -0.0046875146217644215,
604
+ 0.8462978005409241,
605
+ -0.7554765343666077,
606
+ -0.9736349582672119,
607
+ -0.14118513464927673,
608
+ -0.2665828466415405,
609
+ -0.9371470212936401,
610
+ -0.007497116923332214,
611
+ 0.6816821098327637,
612
+ 0.20980679988861084,
613
+ -0.5602611303329468,
614
+ -0.7874919176101685,
615
+ -0.01479698158800602,
616
+ -0.45345690846443176,
617
+ -0.12117742747068405,
618
+ -0.5790822505950928,
619
+ -0.27737149596214294,
620
+ 0.08818025887012482,
621
+ -0.25239622592926025,
622
+ 1.1271374225616455,
623
+ 0.0044799973256886005,
624
+ 0.2183203548192978,
625
+ -2.0634095668792725,
626
+ -0.007129574194550514,
627
+ 0.32677894830703735,
628
+ 0.019878007471561432,
629
+ 0.060301825404167175,
630
+ -0.6844122409820557,
631
+ 0.35185739398002625,
632
+ -0.0028550554998219013,
633
+ -0.5629953145980835,
634
+ 0.06621643155813217,
635
+ -0.043473124504089355,
636
+ -0.3398932218551636,
637
+ -0.1782192587852478,
638
+ -0.24575252830982208,
639
+ -0.20299431681632996,
640
+ -0.3652290999889374,
641
+ -0.9888001680374146,
642
+ -0.30628740787506104,
643
+ 0.6184420585632324,
644
+ -0.33409008383750916,
645
+ 0.20486755669116974,
646
+ -0.8251897692680359,
647
+ -0.08471876382827759,
648
+ -0.5613390803337097,
649
+ 0.057765014469623566,
650
+ 0.5359746813774109,
651
+ -0.7063419818878174,
652
+ 0.28122395277023315,
653
+ -0.004502696450799704,
654
+ -0.6543170213699341,
655
+ 0.04663177207112312,
656
+ -0.05775964632630348,
657
+ -6.37779594399035e-05,
658
+ 0.46121329069137573,
659
+ -0.004464420489966869,
660
+ 1.4332563877105713,
661
+ 0.20597098767757416,
662
+ -0.17879879474639893,
663
+ 0.4316228926181793,
664
+ -1.2352955341339111,
665
+ -0.19363455474376678,
666
+ -0.32174810767173767,
667
+ -0.23037514090538025,
668
+ 0.17044368386268616,
669
+ 0.13070613145828247,
670
+ 1.2171069383621216,
671
+ -1.171966314315796,
672
+ 0.04596511274576187,
673
+ -0.1690378040075302,
674
+ -0.030221890658140182,
675
+ 0.3216114342212677,
676
+ -0.08577033132314682,
677
+ -0.26656001806259155,
678
+ -0.4321160316467285,
679
+ -0.22010475397109985,
680
+ -0.6187731623649597,
681
+ -0.4711909890174866,
682
+ -0.3499036431312561,
683
+ 0.13558903336524963,
684
+ -0.2124641239643097,
685
+ -0.28327351808547974,
686
+ 0.12788993120193481,
687
+ -1.3083688020706177,
688
+ -0.0332779586315155,
689
+ -0.4718656837940216,
690
+ 1.031941533088684,
691
+ -0.07811620831489563,
692
+ -0.5331435799598694,
693
+ -0.2602376341819763,
694
+ -0.8461449146270752,
695
+ 0.18593788146972656,
696
+ 0.5763140320777893,
697
+ -0.45714831352233887,
698
+ -0.1056162416934967,
699
+ 0.2665534019470215,
700
+ -0.4580163061618805,
701
+ -0.25224190950393677,
702
+ -0.2334505170583725,
703
+ -0.6723064184188843,
704
+ 0.12331533432006836,
705
+ 0.054681699723005295,
706
+ -0.14116793870925903,
707
+ -0.10254379361867905,
708
+ 2.0082550048828125,
709
+ -1.4980225563049316,
710
+ 0.00379346776753664,
711
+ -0.8470208644866943,
712
+ 0.06866040825843811,
713
+ -0.3133383095264435,
714
+ -0.20381635427474976,
715
+ -0.03295162320137024,
716
+ 1.1624072790145874,
717
+ -1.2590479850769043,
718
+ -0.5051106810569763,
719
+ -0.5310556292533875,
720
+ 0.11350126564502716,
721
+ -0.5141156315803528,
722
+ 1.0333826541900635,
723
+ -0.5528491735458374,
724
+ -0.6508246064186096,
725
+ -1.0594176054000854,
726
+ -0.03546600416302681,
727
+ -0.0008655009442009032,
728
+ 0.06422116607427597,
729
+ -0.5845358371734619,
730
+ -0.049052149057388306,
731
+ -0.578079104423523,
732
+ -0.46709108352661133,
733
+ -0.6544204354286194,
734
+ -0.13105393946170807,
735
+ -0.12359122931957245,
736
+ 0.19125737249851227,
737
+ -0.9108084440231323,
738
+ -0.24640944600105286,
739
+ -0.5813102126121521,
740
+ -0.2342103123664856,
741
+ 0.645296573638916,
742
+ 0.4200597405433655,
743
+ 1.030412197113037,
744
+ 0.026015933603048325,
745
+ 0.03929654508829117,
746
+ -0.18394766747951508,
747
+ -0.2946997582912445,
748
+ 0.029773380607366562,
749
+ -1.1292797327041626,
750
+ -0.3272054195404053,
751
+ -0.19441728293895721,
752
+ -0.8372487425804138,
753
+ 0.5765964984893799,
754
+ -0.28797629475593567,
755
+ -0.6211466789245605,
756
+ 0.09933445602655411,
757
+ -0.5617806911468506,
758
+ 1.163861870765686,
759
+ 0.1421220600605011,
760
+ -0.790323793888092,
761
+ -0.4003753960132599,
762
+ -0.6941299438476562,
763
+ -0.5033494830131531,
764
+ -0.2234964221715927,
765
+ -0.12398113310337067,
766
+ -0.26237404346466064,
767
+ -0.4991702139377594,
768
+ -0.7963886260986328,
769
+ -0.012063371017575264,
770
+ -1.1415417194366455,
771
+ 0.40668150782585144,
772
+ 0.33048388361930847,
773
+ 1.3195141553878784,
774
+ -0.0008099540136754513,
775
+ -0.06793856620788574,
776
+ ])
data.py CHANGED
@@ -0,0 +1,187 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import base64
2
+ import dataclasses
3
+ import functools
4
+ import io
5
+ import logging
6
+ import os.path
7
+ import random
8
+
9
+ import beartype
10
+ import einops.layers.torch
11
+ import numpy as np
12
+ import torchvision.datasets.folder
13
+ from jaxtyping import UInt8, jaxtyped
14
+ from PIL import Image
15
+ from torch import Tensor
16
+ from torchvision.transforms import v2
17
+
18
+ logger = logging.getLogger("data.py")
19
+
20
+
21
@beartype.beartype
class Ade20k:
    """ADE20K images, segmentation maps and scene labels for one split.

    `root` must contain `images/` and `annotations/` directories, whose
    sub-directories are treated as split names, plus `sceneCategories.txt`.
    """

    @beartype.beartype
    @dataclasses.dataclass(frozen=True)
    class Sample:
        # Path of the image file on disk.
        img_path: str
        # Path of the segmentation annotation file on disk.
        seg_path: str
        # Scene label: second whitespace-separated token of a sceneCategories.txt line.
        label: str
        # Position of `label` in the sorted list of unique labels.
        target: int

    samples: list[Sample]

    def __init__(self, root: str, split: str):
        """Index every image/annotation pair of `split` under `root`.

        Args:
            root: Directory holding `images/`, `annotations/` and `sceneCategories.txt`.
            split: Name of a sub-directory of `images/` (and `annotations/`).

        Raises:
            ValueError: If a required directory is missing or `split` is unknown.
        """
        self.logger = logging.getLogger("ade20k")
        self.root = root
        self.split = split
        self.img_dir = os.path.join(root, "images")
        self.seg_dir = os.path.join(root, "annotations")

        # Check that we have the right path.
        for subdir in ("images", "annotations"):
            if not os.path.isdir(os.path.join(root, subdir)):
                # Something is missing.
                if os.path.realpath(root).endswith(subdir):
                    # The caller probably passed the sub-directory itself as root.
                    self.logger.warning(
                        "The ADE20K root should contain 'images/' and 'annotations/' directories."
                    )
                raise ValueError(f"Can't find path '{os.path.join(root, subdir)}'.")

        _, split_mapping = torchvision.datasets.folder.find_classes(self.img_dir)
        split_lookup: dict[int, str] = {
            value: key for key, value in split_mapping.items()
        }
        self.loader = torchvision.datasets.folder.default_loader

        # Raise instead of assert so the check survives `python -O`.
        if split not in split_mapping:
            raise ValueError(f"Split '{split}' not in '{set(split_lookup.values())}'.")

        # Load all the image paths.
        imgs: list[str] = [
            path
            for path, s in torchvision.datasets.folder.make_dataset(
                self.img_dir,
                split_mapping,
                extensions=torchvision.datasets.folder.IMG_EXTENSIONS,
            )
            if split_lookup[s] == split
        ]

        # The annotation tree is walked with the split mapping discovered under images/.
        segs: list[str] = [
            path
            for path, s in torchvision.datasets.folder.make_dataset(
                self.seg_dir,
                split_mapping,
                extensions=torchvision.datasets.folder.IMG_EXTENSIONS,
            )
            if split_lookup[s] == split
        ]

        # Load all the targets, classes and mappings
        with open(os.path.join(root, "sceneCategories.txt")) as fd:
            rows = [line.split() for line in fd if line.strip()]
        img_labels: list[str] = [row[1] for row in rows]

        # The file lists labels for every line it contains, not only for `split`,
        # so pairing by position alone can attach another split's labels.
        # NOTE(review): assumes the first token of each line is the image file
        # name without its extension -- confirm against the dataset release.
        label_by_name = {row[0]: row[1] for row in rows}

        # Targets index into the labels of the whole file, independent of `split`.
        label_set = sorted(set(img_labels))
        label_to_idx = {label: i for i, label in enumerate(label_set)}

        samples = []
        for img_path, seg_path, positional_label in zip(imgs, segs, img_labels):
            name = os.path.splitext(os.path.basename(img_path))[0]
            # Prefer the label recorded for this exact image; fall back to the
            # positional label when the name is not present in the file.
            label = label_by_name.get(name, positional_label)
            samples.append(
                self.Sample(img_path, seg_path, label, label_to_idx[label])
            )
        self.samples = samples

    def __getitem__(self, index: int) -> dict[str, object]:
        """Return sample `index` as a dict.

        Keys: "label", "target", "image" (loaded with `self.loader`),
        "segmentation" (annotation converted to mode "L") and "index".
        """
        # Convert to dict.
        sample = dataclasses.asdict(self.samples[index])

        sample["image"] = self.loader(sample.pop("img_path"))
        # convert() loads the pixel data, so the file can be closed right away
        # instead of waiting for garbage collection.
        with Image.open(sample.pop("seg_path")) as seg_file:
            sample["segmentation"] = seg_file.convert("L")
        sample["index"] = index

        return sample

    def __len__(self) -> int:
        """Return the number of indexed samples."""
        return len(self.samples)
104
+
105
+
106
@functools.cache
def get_dataset() -> Ade20k:
    """Return the ADE20K validation split, built on first call and cached afterwards."""
    root = "/research/nfs_su_809/workspace/stevens.994/datasets/ade20k/"
    return Ade20k(root=root, split="validation")
112
+
113
+
114
@beartype.beartype
def get_sample(i: int) -> dict[str, object]:
    """Return sample `i` of the cached dataset as a dict."""
    return get_dataset()[i]
118
+
119
+
120
@jaxtyped(typechecker=beartype.beartype)
def make_colors() -> UInt8[np.ndarray, "n 3"]:
    """Build a deterministic RGB palette as a uint8 array with one row per color.

    Every combination of six evenly spaced channel levels is generated,
    shuffled with a fixed seed, and a handful of rows are then overwritten
    with hand-picked colors.
    """
    levels = (0, 51, 102, 153, 204, 255)
    palette = [(r, g, b) for r in levels for g in levels for b in levels]
    # Fixed seed
    random.Random(42).shuffle(palette)
    palette = np.array(palette, dtype=np.uint8)

    # Fixed colors for example 3122
    overrides = {
        2: (201, 249, 255),
        4: (151, 204, 4),
        13: (104, 139, 88),
        16: (54, 48, 32),
        26: (45, 125, 210),
        46: (238, 185, 2),
        52: (88, 91, 86),
        72: (76, 46, 5),
        94: (12, 15, 10),
    }
    for row, rgb in overrides.items():
        palette[row] = np.array(rgb, dtype=np.uint8)

    return palette
144
+
145
+
146
# Palette built once at import time; u8_to_img reads it.
colors = make_colors()

# Resize to 512x512 with nearest-neighbour interpolation, then keep the
# central 448x448 region.
resize_transform = v2.Compose(
    transforms=[
        v2.Resize(size=(512, 512), interpolation=v2.InterpolationMode.NEAREST),
        v2.CenterCrop(size=(448, 448)),
    ]
)
152
+
153
+
154
@beartype.beartype
def to_sized(img_raw: Image.Image) -> Image.Image:
    """Apply `resize_transform` (resize, then center crop) to a PIL image."""
    sized = resize_transform(img_raw)
    return sized
157
+
158
+
159
# Convert a PIL image with v2.ToImage, then drop the leading singleton axis.
# NOTE(review): the axis names in the pattern are only labels; ToImage
# presumably yields (channels, height, width), so "width"/"height" may be
# swapped relative to the real layout -- confirm.
u8_transform = v2.Compose(
    transforms=[
        v2.ToImage(),
        einops.layers.torch.Rearrange(pattern="() width height -> width height"),
    ]
)
163
+
164
+
165
@beartype.beartype
def to_u8(seg_raw: Image.Image) -> UInt8[Tensor, "width height"]:
    """Convert a single-channel PIL segmentation image into a 2-D tensor via `u8_transform`."""
    seg_u8 = u8_transform(seg_raw)
    return seg_u8
168
+
169
+
170
@jaxtyped(typechecker=beartype.beartype)
def u8_to_img(map: UInt8[Tensor, "width height"]) -> Image.Image:
    """Render a 2-D uint8 class-id map as an RGB image using the module palette.

    Class id `k` (for `k >= 1`) is painted with `colors[k - 1]`; id 0 and any
    id beyond the palette are left black.

    Args:
        map: 2-D uint8 tensor of class ids.

    Returns:
        An RGB PIL image with the same two leading dimensions as `map`.
    """
    class_ids = map.cpu().numpy()

    # One lookup table replaces a boolean mask per palette entry: row k holds
    # the color for class id k. Rows that are never filled stay zero, which
    # matches the black background the per-class masks produced.
    n_used = min(len(colors), 255)
    table = np.zeros((256, 3), dtype=np.uint8)
    table[1 : n_used + 1] = colors[:n_used]

    # uint8 ids are always in [0, 255], so indexing the 256-row table is safe.
    return Image.fromarray(table[class_ids])
179
+
180
+
181
@beartype.beartype
def img_to_base64(img: Image.Image) -> str:
    """Encode a PIL image as a WebP `data:` URI with a base64 payload."""
    with io.BytesIO() as buffer:
        img.save(buffer, format="webp")
        payload = buffer.getvalue()
    encoded = base64.b64encode(payload).decode("utf8")
    return "data:image/webp;base64," + encoded
pyproject.toml CHANGED
@@ -9,6 +9,7 @@ dependencies = [
9
  "einops>=0.8.0",
10
  "gradio>=5.3.0",
11
  "numpy>=2.2.2",
 
12
  "torch>=2.6.0",
13
  "torchvision>=0.21.0",
14
  ]
 
9
  "einops>=0.8.0",
10
  "gradio>=5.3.0",
11
  "numpy>=2.2.2",
12
+ "saev",
13
  "torch>=2.6.0",
14
  "torchvision>=0.21.0",
15
  ]