|
import torch
|
|
import numpy as np
|
|
from PIL import Image
|
|
from typing import Union, List
|
|
|
|
|
|
def pil2tensor(image: Union[Image.Image, List[Image.Image]]) -> torch.Tensor:
    """Convert a PIL image (or a list of them) to a float32 tensor in [0, 1].

    A single image is turned into ``np.array(image)``, scaled by 1/255 and
    given a new leading dimension of size 1. A list is converted element-wise
    and concatenated along that leading dimension, so every image must yield
    an array of the same shape.

    The 1/255 scale assumes 8-bit channel values (NOTE(review): 16/32-bit
    modes will exceed 1.0). 1-bit (mode "1") images come out of ``np.array``
    as booleans; those are mapped to 0.0/1.0 directly.

    Args:
        image: A PIL image or a list of PIL images.

    Returns:
        A float32 tensor whose first dimension is the number of images.

    Raises:
        RuntimeError: From ``torch.cat`` if ``image`` is an empty list or the
            converted images have mismatched shapes.
    """
    if isinstance(image, list):
        return torch.cat([pil2tensor(img) for img in image], dim=0)

    arr = np.array(image)
    if arr.dtype == np.bool_:
        # Booleans already encode 0/1; dividing them by 255 would turn white
        # pixels into ~0.0039.
        data = arr.astype(np.float32)
    else:
        data = arr.astype(np.float32) / 255.0
    return torch.from_numpy(data).unsqueeze(0)
|
|
|
|
|
|
def np2tensor(img_np: Union[np.ndarray, List[np.ndarray]]) -> torch.Tensor:
    """Convert a NumPy array (or a list of them) to a float32 tensor in [0, 1].

    A single array is cast to float32, scaled by 1/255 and given a new
    leading dimension of size 1. A list is converted element-wise and
    concatenated along that leading dimension, so every array must have the
    same shape.

    The 1/255 scale assumes 8-bit values (NOTE(review): confirm for callers
    passing wider integer types). Boolean arrays are mapped to 0.0/1.0
    directly, since dividing them by 255 would make ``True`` ~0.0039.

    Args:
        img_np: A NumPy array or a list of NumPy arrays.

    Returns:
        A float32 tensor whose first dimension is the number of arrays.

    Raises:
        RuntimeError: From ``torch.cat`` if ``img_np`` is an empty list or the
            arrays have mismatched shapes.
    """
    if isinstance(img_np, list):
        return torch.cat([np2tensor(img) for img in img_np], dim=0)

    if img_np.dtype == np.bool_:
        data = img_np.astype(np.float32)
    else:
        data = img_np.astype(np.float32) / 255.0
    return torch.from_numpy(data).unsqueeze(0)
|
|
|
|
|
|
def tensor2np(tensor: torch.Tensor) -> Union[np.ndarray, List[np.ndarray]]:
    """Convert a [0, 1] float tensor to uint8 NumPy array(s).

    Values are scaled by 255, clipped to [0, 255] and truncated (not rounded)
    by the ``uint8`` cast. Gradient tracking is dropped first, so tensors
    with ``requires_grad=True`` can be converted too.

    Args:
        tensor: A 3-D tensor (a single item) or a tensor whose first
            dimension is iterated (presumably a batch; confirm with callers).

    Returns:
        A single array for 3-D input; otherwise a list with one array per
        entry along the first dimension.
    """
    def _to_uint8(t: torch.Tensor) -> np.ndarray:
        # detach() is required: Tensor.numpy() refuses tensors that require grad.
        return np.clip(255.0 * t.detach().cpu().numpy(), 0, 255).astype(np.uint8)

    if tensor.dim() == 3:
        return _to_uint8(tensor)
    # NOTE(review): a 2-D tensor also lands here and is split into its rows.
    return [_to_uint8(t) for t in tensor]
|
|
|
|
def tensor2pil(image: torch.Tensor) -> List[Image.Image]:
    """Convert a [0, 1] float tensor to a list of PIL images.

    If ``image`` has more than three dimensions and its first dimension is
    not 1, each entry along that dimension is converted separately. An empty
    batch yields ``[]``. Otherwise (including a leading dimension of exactly
    1), the tensor is processed as a whole:

    - it is scaled by 255 and clipped to [0, 255];
    - it is ``squeeze()``d, which drops every size-1 dimension (e.g. a
      trailing single channel gives a grayscale image);
    - it is truncated to ``uint8`` and passed to ``Image.fromarray``.

    Gradient tracking is dropped first, so tensors with
    ``requires_grad=True`` can be converted.

    Args:
        image: The tensor to convert; presumably (B, H, W, C) or (H, W, C),
            but confirm with callers.

    Returns:
        A list of PIL images, one per converted entry.
    """
    if image.dim() > 3 and image.size(0) != 1:
        # Recurse per entry. This covers batch > 1 and also the empty batch,
        # which would otherwise reach fromarray with a 4-D array and fail.
        return [pil for item in image for pil in tensor2pil(item)]

    # NOTE(review): squeeze() also drops a height or width of 1, which changes
    # how fromarray interprets the array; kept for backward compatibility.
    # detach() is required: Tensor.numpy() refuses tensors that require grad.
    pixels = np.clip(255.0 * image.detach().cpu().numpy().squeeze(), 0, 255)
    return [Image.fromarray(pixels.astype(np.uint8))]