import io | |
from PIL import Image | |
import numpy as np | |
import torch | |
class ImageEncoder:
    """Encode image tensors into in-memory JPEG streams."""

    def encode_torch(self, img: torch.Tensor, quality=95):
        """Encode a tensor as a JPEG and return it as a rewound ``io.BytesIO``.

        Accepted layouts, checked in this order:

        * ``(H, W)``: grayscale, replicated to three identical channels.
        * ``(3, H, W)``: channel-first, moved to channel-last.
        * ``(H, W, 3)``: channel-last, used as is.
        * ``(1, H, W)`` / ``(H, W, 1)``: single channel, replicated to three
          channels. These are tested only after the three-channel layouts so
          that every shape accepted before keeps its interpretation.

        Values are clamped to ``[0, 255]`` and cast to ``uint8``; no rescaling
        is applied, so the caller is responsible for providing data already
        expressed on a 0-255 scale.

        Args:
            img: Image tensor in one of the layouts above.
            quality: Forwarded as the ``quality`` option of the JPEG writer.

        Returns:
            An ``io.BytesIO`` holding the JPEG bytes, positioned at offset 0.

        Raises:
            ValueError: If ``img`` is not 2-D or 3-D, or is 3-D without a
                channel axis of size 1 or 3 in the first or last position.
        """
        # Step 1: bring the tensor to a channel-last (H, W, 3) view.
        if img.ndim == 2:
            # Broadcast view of the single plane across a new channel axis.
            hwc = img[..., None].expand(-1, -1, 3)
        elif img.ndim == 3:
            if img.shape[0] == 3:
                hwc = img.permute(1, 2, 0)
            elif img.shape[2] == 3:
                hwc = img
            elif img.shape[0] == 1:
                hwc = img.permute(1, 2, 0).expand(-1, -1, 3)
            elif img.shape[2] == 1:
                hwc = img.expand(-1, -1, 3)
            else:
                raise ValueError(f"Unsupported image shape: {img.shape}")
        else:
            raise ValueError(f"Unsupported image num dims: {img.ndim}")

        # Step 2: one shared conversion path. The clamp and cast run on the
        # tensor's own device; only the final uint8 buffer is moved to host.
        pixels = hwc.clamp(0, 255).to(torch.uint8).contiguous().cpu().numpy()

        # Step 3: JPEG-encode into memory and rewind so callers can read
        # the stream from the beginning.
        iob = io.BytesIO()
        Image.fromarray(pixels).save(iob, format="JPEG", quality=quality)
        iob.seek(0)
        return iob