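# FLUX.1-dev-fp8-flumina / turbojpeg_imgs.py
# Author: aredden -- "initial commit" (d9aea20), 3.73 kB
# (File-viewer residue kept as comments so the module parses:
#  "Safetensors", "aredden's picture", "raw", "history blame".)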
import numpy as np
import torch
from turbojpeg import (
TurboJPEG,
TJPF_GRAY,
TJFLAG_PROGRESSIVE,
TJFLAG_FASTUPSAMPLE,
TJFLAG_FASTDCT,
TJPF_RGB,
TJPF_BGR,
TJSAMP_GRAY,
TJSAMP_411,
TJSAMP_420,
TJSAMP_422,
TJSAMP_444,
TJSAMP_440,
TJSAMP_441,
)
class Subsampling:
    """Short aliases for the ``TJSAMP_*`` constants imported from ``turbojpeg``.

    The values are the library constants themselves, so any attribute can be
    handed to ``TurboImage.set_subsampling_gray`` / ``set_subsampling_rgb``
    and ends up as the ``jpeg_subsample`` argument of ``TurboJPEG.encode``.
    """

    # Single-channel mode; the default for 2-D input in ``TurboImage``.
    GRAY = TJSAMP_GRAY

    # Colour modes; ``S420`` is the default for 3-channel input in ``TurboImage``.
    S444 = TJSAMP_444
    S422 = TJSAMP_422
    S420 = TJSAMP_420
    S440 = TJSAMP_440
    S411 = TJSAMP_411
    S441 = TJSAMP_441
class Flags:
    """Short aliases for the ``TJFLAG_*`` constants imported from ``turbojpeg``.

    ``TurboImage`` forwards its ``flags`` attribute unchanged as the ``flags``
    argument of ``TurboJPEG.encode``.
    NOTE(review): these are presumably bit flags that can be OR-ed together;
    confirm against the turbojpeg documentation before combining them.
    """

    FASTDCT = TJFLAG_FASTDCT
    FASTUPSAMPLE = TJFLAG_FASTUPSAMPLE
    # Default value of ``TurboImage.flags``.
    PROGRESSIVE = TJFLAG_PROGRESSIVE
class PixelFormat:
    """Short aliases for the ``TJPF_*`` constants imported from ``turbojpeg``.

    ``TurboImage`` passes the selected value as the ``pixel_format`` argument
    of ``TurboJPEG.encode``.
    """

    # Three-channel formats; ``RGB`` is the default for 3-channel input.
    RGB = TJPF_RGB
    BGR = TJPF_BGR

    # Single-channel format; the default for 2-D input.
    GRAY = TJPF_GRAY
class TurboImage:
    """Encode torch tensors and numpy arrays to JPEG through ``TurboJPEG``.

    Two groups of encoder settings are kept: one used for 2-D (single-channel)
    images and one used for 3-D images with three channels. ``encode_torch``
    and ``encode_numpy`` choose the group from the input's shape, clamp the
    values to ``[0, 255]``, convert them to ``uint8`` and hand a C-contiguous
    array to ``encode``.
    """

    def __init__(self):
        self.tj = TurboJPEG()
        # Forwarded as ``flags`` to every ``TurboJPEG.encode`` call.
        self.flags = Flags.PROGRESSIVE
        # Settings used for 2-D input.
        self.subsampling_gray = Subsampling.GRAY
        self.pixel_format_gray = PixelFormat.GRAY
        # Settings used for 3-D, three-channel input.
        self.subsampling_rgb = Subsampling.S420
        self.pixel_format_rgb = PixelFormat.RGB

    def set_subsampling_gray(self, subsampling):
        """Set the subsampling constant used for 2-D images."""
        self.subsampling_gray = subsampling

    def set_subsampling_rgb(self, subsampling):
        """Set the subsampling constant used for three-channel images."""
        self.subsampling_rgb = subsampling

    def set_pixel_format_gray(self, pixel_format):
        """Set the pixel-format constant used for 2-D images."""
        self.pixel_format_gray = pixel_format

    def set_pixel_format_rgb(self, pixel_format):
        """Set the pixel-format constant used for three-channel images."""
        self.pixel_format_rgb = pixel_format

    def set_flags(self, flags):
        """Set the flags forwarded to ``TurboJPEG.encode``."""
        self.flags = flags

    def encode(
        self,
        img,
        subsampling,
        pixel_format,
        quality=90,
    ):
        """Encode an already prepared image with explicit settings.

        Args:
            img: Image data passed straight to ``TurboJPEG.encode``.
            subsampling: Forwarded as ``jpeg_subsample``.
            pixel_format: Forwarded as ``pixel_format``.
            quality: Forwarded as ``quality``.

        Returns:
            Whatever ``TurboJPEG.encode`` returns for these arguments.
        """
        return self.tj.encode(
            img,
            quality=quality,
            flags=self.flags,
            pixel_format=pixel_format,
            jpeg_subsample=subsampling,
        )

    def _resolve_layout(self, shape):
        """Pick encoder settings and memory layout for an image shape.

        Returns:
            ``(subsampling, pixel_format, channels_first)``. ``channels_first``
            is True when the image must be moved from (3, H, W) to (H, W, 3).

        Raises:
            ValueError: If the shape is neither 2-D nor a 3-D shape with a
                leading or trailing dimension of 3.
        """
        ndim = len(shape)
        if ndim == 2:
            return self.subsampling_gray, self.pixel_format_gray, False
        if ndim != 3:
            raise ValueError(f"Unsupported image num dims: {ndim}")
        # The leading dimension is tested first, so a (3, H, 3) input is
        # always interpreted as channels-first.
        if shape[0] == 3:
            return self.subsampling_rgb, self.pixel_format_rgb, True
        if shape[2] == 3:
            return self.subsampling_rgb, self.pixel_format_rgb, False
        raise ValueError(f"Unsupported image shape: {shape}")

    @torch.inference_mode()
    def encode_torch(self, img: torch.Tensor, quality=90):
        """Encode a tensor holding pixel values on a 0-255 scale.

        Args:
            img: 2-D tensor, or 3-D tensor shaped (3, H, W) or (H, W, 3), on
                any device. Values outside ``[0, 255]`` are clamped; the
                fractional part is dropped by the cast to ``uint8``.
            quality: Forwarded to ``encode``.

        Returns:
            The result of ``encode`` for the converted image.

        Raises:
            ValueError: If the tensor's shape is not supported.
        """
        subsampling, pixel_format, channels_first = self._resolve_layout(img.shape)
        if channels_first:
            img = img.permute(1, 2, 0)
        # Narrow to uint8 *before* leaving the source device: this keeps the
        # host transfer small and also supports dtypes NumPy cannot represent
        # (e.g. bfloat16), for which ``Tensor.numpy()`` would raise.
        pixels = img.clamp(0, 255).to(torch.uint8).cpu().contiguous().numpy()
        return self.encode(
            pixels,
            quality=quality,
            subsampling=subsampling,
            pixel_format=pixel_format,
        )

    def encode_numpy(self, img: np.ndarray, quality=90):
        """Encode an array holding pixel values on a 0-255 scale.

        Args:
            img: 2-D array, or 3-D array shaped (3, H, W) or (H, W, 3).
                Values outside ``[0, 255]`` are clipped; the fractional part
                is dropped by the cast to ``uint8``.
            quality: Forwarded to ``encode``.

        Returns:
            The result of ``encode`` for the converted image.

        Raises:
            ValueError: If the array's shape is not supported.
        """
        subsampling, pixel_format, channels_first = self._resolve_layout(img.shape)
        if channels_first:
            img = img.transpose(1, 2, 0)
        # Force C order last so that every input (including transposed or
        # F-ordered 2-D views) reaches the encoder with row-major strides.
        pixels = np.ascontiguousarray(img.clip(0, 255).astype(np.uint8))
        return self.encode(
            pixels, quality=quality, subsampling=subsampling, pixel_format=pixel_format
        )