|
import base64 |
|
import functools |
|
import io |
|
import logging |
|
import random |
|
|
|
import beartype |
|
import einops.layers.torch |
|
import numpy as np |
|
import requests |
|
import torch |
|
from jaxtyping import Int, Integer, UInt8, jaxtyped |
|
from PIL import Image, ImageDraw |
|
from torch import Tensor |
|
from torchvision.transforms import v2 |
|
|
|
# Module logger. Its name is the literal "data.py", not ``__name__``.
logger: logging.Logger = logging.getLogger("data.py")

# Base URL of the public R2 bucket that get_img / get_seg download from.
R2_URL: str = "https://pub-129e98faed1048af94c4d4119ea47be7.r2.dev"
|
|
|
|
|
@beartype.beartype
def _fetch_image(url: str, *, timeout: float = 30.0) -> Image.Image:
    """Download ``url`` and decode it as a PIL image.

    Args:
        url: Absolute URL of an image file.
        timeout: Seconds to wait for the server (passed to ``requests.get``),
            so a stalled connection raises instead of hanging forever.

    Returns:
        The decoded image, fully loaded into memory.

    Raises:
        requests.HTTPError: If the server answers with an error status
            (instead of handing the error page to PIL to fail on).
        requests.Timeout: If the server does not respond within ``timeout``.
    """
    response = requests.get(url, timeout=timeout)
    response.raise_for_status()
    img = Image.open(io.BytesIO(response.content))
    # Decode eagerly: decoding errors surface here, and the image kept by the
    # callers' lru_cache is already decoded.
    img.load()
    return img


@beartype.beartype
@functools.lru_cache(maxsize=512)
def get_img(i: int) -> Image.Image:
    """Fetch validation image ``ADE_val_{i + 1:08}.jpg`` from the R2 bucket.

    ``i`` is zero-based; the file name uses ``i + 1`` zero-padded to eight
    digits (``i=0`` -> ``/images/ADE_val_00000001.jpg``).

    Results are memoized (up to 512 entries), so repeated calls with the same
    ``i`` return the *same* ``Image`` object; do not modify it in place.
    Failed downloads raise and are not cached.
    """
    url = R2_URL + f"/images/ADE_val_{i + 1:08}.jpg"
    logger.info("Getting image from '%s'.", url)
    return _fetch_image(url)


@beartype.beartype
@functools.lru_cache(maxsize=512)
def get_seg(i: int) -> Image.Image:
    """Fetch annotation ``ADE_val_{i + 1:08}.png`` from the R2 bucket.

    ``i`` is zero-based and uses the same numbering as ``get_img``, so
    ``get_seg(i)`` is the annotation file named like ``get_img(i)``.

    Results are memoized (up to 512 entries), so repeated calls with the same
    ``i`` return the *same* ``Image`` object; do not modify it in place.
    Failed downloads raise and are not cached.
    """
    url = R2_URL + f"/annotations/ADE_val_{i + 1:08}.png"
    logger.info("Getting annotations from '%s'.", url)
    return _fetch_image(url)
|
|
|
|
|
@jaxtyped(typechecker=beartype.beartype)
def make_colors() -> UInt8[np.ndarray, "n 3"]:
    """Build the deterministic RGB palette used to paint label maps.

    The palette starts as every combination of six channel levels
    (0, 51, 102, 153, 204, 255), i.e. 6**3 = 216 colors. It is shuffled with
    a fixed seed (42) so the order is reproducible across runs. A handful of
    rows are then replaced by hand-picked colors.

    Returns:
        ``(216, 3)`` uint8 array. Row ``k`` is the color that ``u8_to_img``
        paints for label value ``k + 1``.
    """
    levels = (0, 51, 102, 153, 204, 255)
    # Same order as three nested loops (r outermost, b innermost). The order
    # matters: the seeded shuffle below permutes this exact list.
    rgb_triples = [(r, g, b) for r in levels for g in levels for b in levels]
    random.Random(42).shuffle(rgb_triples)
    palette = np.array(rgb_triples, dtype=np.uint8)

    # Hand-picked replacements for specific palette rows.
    # NOTE(review): why these rows are overridden isn't visible here;
    # presumably they tune colors for particular classes. Confirm.
    overrides = {
        2: (201, 249, 255),
        4: (151, 204, 4),
        13: (104, 139, 88),
        16: (54, 48, 32),
        21: (120, 202, 210),
        26: (45, 125, 210),
        29: (116, 142, 84),
        46: (238, 185, 2),
        52: (88, 91, 86),
        60: (72, 99, 156),
        72: (76, 46, 5),
        94: (12, 15, 10),
    }
    for row, rgb in overrides.items():
        palette[row] = rgb

    return palette
|
|
|
|
|
# Shared palette, built once at import time. Row k is the color used for
# label value k + 1.
colors = make_colors()

# Resize to 512x512 with nearest-neighbour sampling, then center-crop to
# 448x448.
# NOTE(review): NEAREST is applied to everything passed through this
# transform, photos included. Presumably this is so label maps keep only
# existing label values; confirm whether photos should use a smoother mode.
resize_transform = v2.Compose(
    transforms=[
        v2.Resize((512, 512), interpolation=v2.InterpolationMode.NEAREST),
        v2.CenterCrop((448, 448)),
    ]
)
|
|
|
|
|
@beartype.beartype
def to_sized(img_raw: Image.Image) -> Image.Image:
    """Resize ``img_raw`` to 512x512 (nearest) and center-crop it to 448x448.

    Thin wrapper around the module-level ``resize_transform``.
    """
    sized = resize_transform(img_raw)
    return sized
|
|
|
|
|
# PIL label image -> 2-D tensor.
# NOTE(review): v2.ToImage presumably yields a channels-first (C, H, W)
# tensor. If so, the axis labelled "width" below is really rows (height).
# The labels are names only and do not affect the result. Confirm.
u8_transform = v2.Compose(
    transforms=[
        v2.ToImage(),
        # "()" matches a length-1 axis only, so this drops the single channel
        # and fails loudly on multi-channel input.
        einops.layers.torch.Rearrange("() width height -> width height"),
    ]
)
|
|
|
|
|
@beartype.beartype
def to_u8(seg_raw: Image.Image) -> UInt8[Tensor, "width height"]:
    """Convert a single-channel PIL label image to a 2-D uint8 tensor.

    Delegates to the module-level ``u8_transform``. The return annotation is
    checked by beartype, so a non-uint8 result raises.
    """
    label_tensor = u8_transform(seg_raw)
    return label_tensor
|
|
|
|
|
@jaxtyped(typechecker=beartype.beartype)
def upsample(
    x_WH: Int[Tensor, "width_ps height_ps"],
    *,
    scale_factor: int = 28,
) -> UInt8[Tensor, "width_px height_px"]:
    """Nearest-neighbour upsample a 2-D label grid by an integer factor.

    Every cell of ``x_WH`` becomes a ``scale_factor x scale_factor`` block of
    identical values. With the default factor, a 16x16 grid becomes 448x448,
    the crop size used by ``resize_transform``. Other grid sizes are also
    supported; previously only 16x16 worked.

    Args:
        x_WH: 2-D integer label grid.
        scale_factor: Integer upsampling factor applied to both axes.

    Returns:
        ``(rows * scale_factor, cols * scale_factor)`` uint8 tensor. Input
        values outside ``0..255`` do not fit in uint8 and should not be
        passed in.
    """
    n_rows, n_cols = x_WH.shape
    # interpolate wants (batch, channel, H, W). reshape, unlike view, also
    # accepts non-contiguous input. Nearest mode only copies values, so the
    # float round-trip is exact for in-range labels.
    grid = x_WH.reshape(1, 1, n_rows, n_cols).float()
    up = torch.nn.functional.interpolate(
        grid, scale_factor=scale_factor, mode="nearest"
    )
    return up.reshape(n_rows * scale_factor, n_cols * scale_factor).to(torch.uint8)
|
|
|
|
|
@jaxtyped(typechecker=beartype.beartype)
def u8_to_overlay(
    map: Integer[Tensor, "width_ps height_ps"],
    img: Image.Image,
    *,
    opacity: float = 0.5,
) -> Image.Image:
    """Paint a label grid over ``img`` as translucent colored rectangles.

    ``map[row, col]`` covers a ``(img.width // n_cols, img.height // n_rows)``
    rectangle at grid column ``col`` and grid row ``row``. If the image size
    is not a multiple of the grid size, the leftover right/bottom pixels are
    left uncovered.

    Coloring follows the same palette mapping as ``u8_to_img``:

    - A cell with label ``v > 0`` is filled with ``colors[v - 1]``.
    - Cells with label ``<= 0`` have no palette entry and stay transparent,
      just as ``u8_to_img`` leaves label 0 black.

    Args:
        map: 2-D integer label grid, indexed ``[row, col]``.
        img: Image to draw on. It is converted to RGBA and not modified.
        opacity: Overlay opacity. The alpha is ``int(opacity * 256)`` clamped
            to ``[0, 255]``, so ``1.0`` gives a fully opaque overlay.

    Returns:
        A new RGBA image.

    Raises:
        ValueError: If the image has fewer pixels than the grid has cells
            along either axis.
        IndexError: If a label exceeds the palette size.
    """
    n_rows, n_cols = map.shape
    img_w, img_h = img.size
    base = img.convert("RGBA")
    if n_rows == 0 or n_cols == 0:
        # Empty grid: nothing to draw.
        return base
    if img_w < n_cols or img_h < n_rows:
        raise ValueError(
            f"image of size {img.size} is smaller than the "
            f"{n_rows}x{n_cols} label grid"
        )
    cell_w, cell_h = img_w // n_cols, img_h // n_rows
    # int(opacity * 256) is 128 at the default 0.5 but 256 at opacity=1.0,
    # which is outside the 8-bit alpha range, so clamp it.
    alpha = max(0, min(255, int(opacity * 256)))

    overlay = Image.new("RGBA", img.size, (0, 0, 0, 0))
    draw = ImageDraw.Draw(overlay)
    for row, labels in enumerate(map.tolist()):
        for col, label in enumerate(labels):
            if label <= 0:
                continue
            r, g, b = (int(c) for c in colors[label - 1])
            x0, y0 = col * cell_w, row * cell_h
            # ImageDraw.rectangle includes its end point; subtract 1 so
            # adjacent cells do not overlap.
            draw.rectangle(
                [(x0, y0), (x0 + cell_w - 1, y0 + cell_h - 1)],
                fill=(r, g, b, alpha),
            )

    return Image.alpha_composite(base, overlay)
|
|
|
|
|
@jaxtyped(typechecker=beartype.beartype)
def u8_to_img(map: UInt8[Tensor, "width height"]) -> Image.Image:
    """Render a uint8 label map as an RGB image.

    Label value ``v`` in ``1..len(colors)`` is painted ``colors[v - 1]``.
    Label ``0``, and any value without a palette entry, stays black. The
    output array has the same 2-D shape as ``map`` plus a trailing RGB axis.
    """
    labels = map.cpu().numpy()
    # uint8 labels can only be 0..255, so a 256-entry lookup table colors
    # every pixel in one vectorized pass instead of one full-image mask per
    # palette color.
    lut = np.zeros((256, 3), dtype=np.uint8)
    n_colors = min(len(colors), 255)
    lut[1 : n_colors + 1] = colors[:n_colors]
    return Image.fromarray(lut[labels])
|
|
|
|
|
@jaxtyped(typechecker=beartype.beartype)
def to_classes(map: Integer[Tensor, "width height"]) -> list[int]:
    """Return the distinct label values present in ``map``, sorted ascending.

    Uses ``torch.unique`` with ``sorted=True``, so the order is deterministic
    (iterating a Python ``set`` gives no ordering guarantee). It also works
    for tensors of any memory layout. An empty tensor yields ``[]``.
    """
    return torch.unique(map, sorted=True).tolist()
|
|
|
|
|
@beartype.beartype
def img_to_base64(img: Image.Image) -> str:
    """Encode ``img`` as a ``data:`` URL holding a lossless WebP image."""
    with io.BytesIO() as stream:
        img.save(stream, format="webp", lossless=True)
        payload = stream.getvalue()
    encoded = base64.b64encode(payload).decode("utf8")
    return f"data:image/webp;base64,{encoded}"
|
|