import numpy as np
import torch
from turbojpeg import (
    TurboJPEG,
    TJPF_GRAY,
    TJFLAG_PROGRESSIVE,
    TJFLAG_FASTUPSAMPLE,
    TJFLAG_FASTDCT,
    TJPF_RGB,
    TJPF_BGR,
    TJSAMP_GRAY,
    TJSAMP_411,
    TJSAMP_420,
    TJSAMP_422,
    TJSAMP_444,
    TJSAMP_440,
    TJSAMP_441,
)


class Subsampling:
    """Chroma subsampling constants re-exported from PyTurboJPEG."""

    S411 = TJSAMP_411
    S420 = TJSAMP_420
    S422 = TJSAMP_422
    S444 = TJSAMP_444
    S440 = TJSAMP_440
    S441 = TJSAMP_441
    GRAY = TJSAMP_GRAY


class Flags:
    """Encoder flags re-exported from PyTurboJPEG."""

    PROGRESSIVE = TJFLAG_PROGRESSIVE
    FASTUPSAMPLE = TJFLAG_FASTUPSAMPLE
    FASTDCT = TJFLAG_FASTDCT


class PixelFormat:
    """Pixel format constants re-exported from PyTurboJPEG."""

    GRAY = TJPF_GRAY
    RGB = TJPF_RGB
    BGR = TJPF_BGR


class TurboImage:
    """Thin wrapper around TurboJPEG for encoding NumPy arrays and torch tensors."""

    def __init__(self):
        self.tj = TurboJPEG()
        self.flags = Flags.PROGRESSIVE
        self.subsampling_gray = Subsampling.GRAY
        self.pixel_format_gray = PixelFormat.GRAY
        self.subsampling_rgb = Subsampling.S420
        self.pixel_format_rgb = PixelFormat.RGB

    def set_subsampling_gray(self, subsampling):
        self.subsampling_gray = subsampling

    def set_subsampling_rgb(self, subsampling):
        self.subsampling_rgb = subsampling

    def set_pixel_format_gray(self, pixel_format):
        self.pixel_format_gray = pixel_format

    def set_pixel_format_rgb(self, pixel_format):
        self.pixel_format_rgb = pixel_format

    def set_flags(self, flags):
        self.flags = flags

    def encode(
        self,
        img,
        subsampling,
        pixel_format,
        quality=90,
    ):
        """Encode a contiguous uint8 HW or HWC array to JPEG bytes."""
        return self.tj.encode(
            img,
            quality=quality,
            flags=self.flags,
            pixel_format=pixel_format,
            jpeg_subsample=subsampling,
        )

    @torch.inference_mode()
    def encode_torch(self, img: torch.Tensor, quality=90):
        """Encode a 2D (H, W) or 3D (3, H, W) / (H, W, 3) tensor to JPEG bytes."""
        if img.ndim == 2:
            subsampling = self.subsampling_gray
            pixel_format = self.pixel_format_gray
            img = img.clamp(0, 255).cpu().contiguous().numpy().astype(np.uint8)
        elif img.ndim == 3:
            subsampling = self.subsampling_rgb
            pixel_format = self.pixel_format_rgb
            if img.shape[0] == 3:
                # CHW -> HWC before handing the buffer to TurboJPEG.
                img = (
                    img.permute(1, 2, 0)
                    .clamp(0, 255)
                    .cpu()
                    .contiguous()
                    .numpy()
                    .astype(np.uint8)
                )
            elif img.shape[2] == 3:
                img = img.clamp(0, 255).cpu().contiguous().numpy().astype(np.uint8)
            else:
                raise ValueError(f"Unsupported image shape: {img.shape}")
        else:
            raise ValueError(f"Unsupported image num dims: {img.ndim}")
        return self.encode(
            img,
            quality=quality,
            subsampling=subsampling,
            pixel_format=pixel_format,
        )

    def encode_numpy(self, img: np.ndarray, quality=90):
        """Encode a 2D (H, W) or 3D (3, H, W) / (H, W, 3) array to JPEG bytes."""
        if img.ndim == 2:
            subsampling = self.subsampling_gray
            pixel_format = self.pixel_format_gray
        elif img.ndim == 3:
            if img.shape[0] == 3:
                # CHW -> HWC before handing the buffer to TurboJPEG.
                img = np.ascontiguousarray(img.transpose(1, 2, 0))
            elif img.shape[2] == 3:
                img = np.ascontiguousarray(img)
            else:
                raise ValueError(f"Unsupported image shape: {img.shape}")
            subsampling = self.subsampling_rgb
            pixel_format = self.pixel_format_rgb
        else:
            raise ValueError(f"Unsupported image num dims: {img.ndim}")
        img = img.clip(0, 255).astype(np.uint8)
        return self.encode(
            img, quality=quality, subsampling=subsampling, pixel_format=pixel_format
        )