import numpy as np import torch from turbojpeg import ( TurboJPEG, TJPF_GRAY, TJFLAG_PROGRESSIVE, TJFLAG_FASTUPSAMPLE, TJFLAG_FASTDCT, TJPF_RGB, TJPF_BGR, TJSAMP_GRAY, TJSAMP_411, TJSAMP_420, TJSAMP_422, TJSAMP_444, TJSAMP_440, TJSAMP_441, ) class Subsampling: S411 = TJSAMP_411 S420 = TJSAMP_420 S422 = TJSAMP_422 S444 = TJSAMP_444 S440 = TJSAMP_440 S441 = TJSAMP_441 GRAY = TJSAMP_GRAY class Flags: PROGRESSIVE = TJFLAG_PROGRESSIVE FASTUPSAMPLE = TJFLAG_FASTUPSAMPLE FASTDCT = TJFLAG_FASTDCT class PixelFormat: GRAY = TJPF_GRAY RGB = TJPF_RGB BGR = TJPF_BGR class TurboImage: def __init__(self): self.tj = TurboJPEG() self.flags = Flags.PROGRESSIVE self.subsampling_gray = Subsampling.GRAY self.pixel_format_gray = PixelFormat.GRAY self.subsampling_rgb = Subsampling.S420 self.pixel_format_rgb = PixelFormat.RGB def set_subsampling_gray(self, subsampling): self.subsampling_gray = subsampling def set_subsampling_rgb(self, subsampling): self.subsampling_rgb = subsampling def set_pixel_format_gray(self, pixel_format): self.pixel_format_gray = pixel_format def set_pixel_format_rgb(self, pixel_format): self.pixel_format_rgb = pixel_format def set_flags(self, flags): self.flags = flags def encode( self, img, subsampling, pixel_format, quality=90, ): return self.tj.encode( img, quality=quality, flags=self.flags, pixel_format=pixel_format, jpeg_subsample=subsampling, ) @torch.inference_mode() def encode_torch(self, img: torch.Tensor, quality=90): if img.ndim == 2: subsampling = self.subsampling_gray pixel_format = self.pixel_format_gray img = img.clamp(0, 255).cpu().contiguous().numpy().astype(np.uint8) elif img.ndim == 3: subsampling = self.subsampling_rgb pixel_format = self.pixel_format_rgb if img.shape[0] == 3: img = ( img.permute(1, 2, 0) .clamp(0, 255) .cpu() .contiguous() .numpy() .astype(np.uint8) ) elif img.shape[2] == 3: img = img.clamp(0, 255).cpu().contiguous().numpy().astype(np.uint8) else: raise ValueError(f"Unsupported image shape: {img.shape}") else: raise ValueError(f"Unsupported image num dims: {img.ndim}") return self.encode( img, quality=quality, subsampling=subsampling, pixel_format=pixel_format, ) def encode_numpy(self, img: np.ndarray, quality=90): if img.ndim == 2: subsampling = self.subsampling_gray pixel_format = self.pixel_format_gray elif img.ndim == 3: if img.shape[0] == 3: img = np.ascontiguousarray(img.transpose(1, 2, 0)) elif img.shape[2] == 3: img = np.ascontiguousarray(img) else: raise ValueError(f"Unsupported image shape: {img.shape}") subsampling = self.subsampling_rgb pixel_format = self.pixel_format_rgb else: raise ValueError(f"Unsupported image num dims: {img.ndim}") img = img.clip(0, 255).astype(np.uint8) return self.encode( img, quality=quality, subsampling=subsampling, pixel_format=pixel_format )