"""Helpers for fetching validation images/annotations and rendering segmentations.

Images and annotation maps are downloaded over HTTP from `R2_URL`, resized to a
fixed 448x448 crop, and converted between class-id tensors, colored PIL images,
overlays and base64 data URIs.
"""

import base64
import functools
import io
import logging
import random

import beartype
import einops.layers.torch
import numpy as np
import requests
import torch
from jaxtyping import Int, Integer, UInt8, jaxtyped
from PIL import Image, ImageDraw
from torch import Tensor
from torchvision.transforms import v2

logger = logging.getLogger("data.py")

R2_URL = "https://pub-129e98faed1048af94c4d4119ea47be7.r2.dev"

# Upper bound (seconds) for a single HTTP request, so a stalled connection
# cannot block the caller forever.
_REQUEST_TIMEOUT_S = 30.0


def _fetch_image(url: str) -> Image.Image:
    """Download `url` and open the payload as a PIL image.

    Raises:
        requests.HTTPError: if the server answers with a 4xx/5xx status.
        requests.RequestException: on connection errors or timeouts.
    """
    resp = requests.get(url, timeout=_REQUEST_TIMEOUT_S)
    # Fail loudly on HTTP errors instead of handing an error page to PIL.
    resp.raise_for_status()
    # The bytes are fully buffered, so the (lazily decoded) image does not
    # depend on a live network stream.
    return Image.open(io.BytesIO(resp.content))


@beartype.beartype
@functools.lru_cache(maxsize=512)
def get_img(i: int) -> Image.Image:
    """Return validation image number `i` (0-indexed).

    Remote file names are 1-indexed and zero-padded to 8 digits. Results are
    memoized, so repeated calls return the *same* `Image` object; callers
    should not mutate it in place.
    """
    fpath = f"/images/ADE_val_{i + 1:08}.jpg"
    url = R2_URL + fpath
    logger.info("Getting image from '%s'.", url)
    return _fetch_image(url)


@beartype.beartype
@functools.lru_cache(maxsize=512)
def get_seg(i: int) -> Image.Image:
    """Return the annotation (segmentation) image for example `i` (0-indexed).

    Remote file names are 1-indexed and zero-padded to 8 digits. Results are
    memoized, so repeated calls return the *same* `Image` object; callers
    should not mutate it in place.
    """
    fpath = f"/annotations/ADE_val_{i + 1:08}.png"
    url = R2_URL + fpath
    logger.info("Getting annotations from '%s'.", url)
    return _fetch_image(url)


@jaxtyped(typechecker=beartype.beartype)
def make_colors() -> UInt8[np.ndarray, "n 3"]:
    """Build the deterministic RGB palette used to color class ids.

    The palette is the 6x6x6 grid of RGB values, shuffled with a fixed seed,
    with a handful of entries overridden by hand-picked colors.

    Returns:
        A `(216, 3)` uint8 array; row `k` is the color of class id `k + 1`
        (see `u8_to_img`).
    """
    values = (0, 51, 102, 153, 204, 255)
    # Same iteration order as three nested loops (r outermost, b innermost);
    # the order matters because the seeded shuffle below depends on it.
    rgb_grid = [(r, g, b) for r in values for g in values for b in values]

    # Fixed seed
    random.Random(42).shuffle(rgb_grid)

    palette = np.array(rgb_grid, dtype=np.uint8)

    # Fixed colors. Must be synced with Segmentation.elm.
    fixed = {
        2: (201, 249, 255),
        4: (151, 204, 4),
        13: (104, 139, 88),
        16: (54, 48, 32),
        21: (120, 202, 210),  # water
        26: (45, 125, 210),
        29: (116, 142, 84),
        46: (238, 185, 2),
        52: (88, 91, 86),
        60: (72, 99, 156),  # river
        72: (76, 46, 5),
        94: (12, 15, 10),
    }
    for idx, rgb in fixed.items():
        palette[idx] = rgb

    return palette


colors = make_colors()

# Resize to 512x512 and keep the central 448x448 window. Nearest-neighbor
# interpolation never blends neighboring pixel values.
resize_transform = v2.Compose([
    v2.Resize((512, 512), interpolation=v2.InterpolationMode.NEAREST),
    v2.CenterCrop((448, 448)),
])


@beartype.beartype
def to_sized(img_raw: Image.Image) -> Image.Image:
    """Resize `img_raw` to 512x512 and center-crop it to 448x448."""
    return resize_transform(img_raw)


# Convert a PIL image to a tensor and drop its leading singleton axis; the
# `()` in the pattern requires that axis to have length 1.
u8_transform = v2.Compose([
    v2.ToImage(),
    einops.layers.torch.Rearrange("() width height -> width height"),
])


@beartype.beartype
def to_u8(seg_raw: Image.Image) -> UInt8[Tensor, "width height"]:
    """Convert a single-channel PIL image into a 2-D uint8 tensor.

    NOTE(review): torchvision's `ToImage` is documented to produce
    channel-first `(C, H, W)` tensors, so the first returned axis presumably
    indexes rows despite being named `width` -- confirm.
    """
    return u8_transform(seg_raw)


@jaxtyped(typechecker=beartype.beartype)
def upsample(
    x_WH: Int[Tensor, "width_ps height_ps"],
) -> UInt8[Tensor, "width_px height_px"]:
    """Upscale a 2-D integer grid by a factor of 28 along both axes.

    Uses `interpolate`'s default mode, so a 16x16 grid becomes a 448x448
    uint8 tensor. Any 2-D grid size is accepted.
    """
    # `interpolate` expects (batch, channel, *spatial) floating-point input.
    batched = x_WH[None, None].float()
    upscaled = torch.nn.functional.interpolate(batched, scale_factor=28)
    return upscaled[0, 0].type(torch.uint8)


@jaxtyped(typechecker=beartype.beartype)
def u8_to_overlay(
    map: Integer[Tensor, "width_ps height_ps"],
    img: Image.Image,
    *,
    opacity: float = 0.5,
) -> Image.Image:
    """Draw a semi-transparent, per-cell class coloring of `map` over `img`.

    Args:
        map: 2-D grid of class ids. Axis 0 is treated as rows (vertical) and
            axis 1 as columns (horizontal), matching `u8_to_img`.
        img: Image to draw on; it is not modified.
        opacity: Overlay opacity; `int(opacity * 256)` clamped to 0..255 is
            used as the alpha channel.

    Returns:
        A new RGBA image: `img` composited with the colored overlay.
    """
    n_rows, n_cols = map.shape
    img_w_px, img_h_px = img.size
    cell_w_px, cell_h_px = img_w_px // n_cols, img_h_px // n_rows

    # Keep alpha a valid 8-bit value even for opacity >= 1.0 or < 0.
    alpha = max(0, min(255, int(opacity * 256)))

    # Create a transparent overlay
    overlay = Image.new("RGBA", img.size, (0, 0, 0, 0))
    draw = ImageDraw.Draw(overlay)

    # Fill every grid cell with its class color from the shared palette.
    for p, class_id in enumerate(map.reshape(-1).tolist()):
        # Row-major flattening: p == row * n_cols + col.
        row, col = divmod(p, n_cols)
        left, top = col * cell_w_px, row * cell_h_px
        # NOTE(review): class id 0 selects colors[-1] through negative
        # indexing, whereas u8_to_img leaves id 0 black -- confirm intent.
        draw.rectangle(
            [(left, top), (left + cell_w_px, top + cell_h_px)],
            fill=(*colors[class_id - 1], alpha),
        )

    # Composite the original image and the overlay
    return Image.alpha_composite(img.convert("RGBA"), overlay)


@jaxtyped(typechecker=beartype.beartype)
def u8_to_img(map: UInt8[Tensor, "width height"]) -> Image.Image:
    """Render a class-id map as an RGB image using the shared palette.

    Class id `k` (for `1 <= k <= len(colors)`) is painted with
    `colors[k - 1]`; every other id, including 0, stays black.
    """
    class_ids = map.cpu().numpy()

    # 256-row lookup table covering every possible uint8 id; rows that have
    # no palette entry remain (0, 0, 0).
    n_used = min(len(colors), 255)
    lut = np.zeros((256, 3), dtype=np.uint8)
    lut[1 : n_used + 1] = colors[:n_used]

    return Image.fromarray(lut[class_ids])


@jaxtyped(typechecker=beartype.beartype)
def to_classes(map: Integer[Tensor, "width height"]) -> list[int]:
    """Return the distinct class ids present in `map`, in ascending order."""
    # Integer is any signed or unsigned int: https://docs.kidger.site/jaxtyping/api/array/#dtype
    return sorted(set(map.reshape(-1).tolist()))


@beartype.beartype
def img_to_base64(img: Image.Image) -> str:
    """Encode `img` as a lossless WebP `data:` URI."""
    buf = io.BytesIO()
    img.save(buf, format="webp", lossless=True)
    b64 = base64.b64encode(buf.getvalue())
    s64 = b64.decode("utf8")
    return "data:image/webp;base64," + s64