Plonk / utils /image_processing.py
nicolas-dufour's picture
squash: merge all unpushed commits
c4c7cee
raw
history blame
1.86 kB
import torch
import torch.nn.functional as F
import torchvision
def remap_image_torch(image):
    """Convert a tensor from the [-1, 1] range to 8-bit pixel values.

    The input is shifted and scaled so that -1 maps to 0 and 1 maps to 255.
    Anything falling outside [0, 255] after scaling is clamped, and the
    result is cast to ``torch.uint8`` (the cast drops the fractional part).

    Args:
        image (Tensor): Values nominally in [-1, 1].

    Returns:
        Tensor: ``torch.uint8`` tensor with the same shape as ``image``.
    """
    unit_range = (image + 1) / 2.0
    pixel_values = unit_range * 255.0
    return pixel_values.clamp(0, 255).to(torch.uint8)
class CenterCrop(torch.nn.Module):
    """Crops the given image at the center. Allows to crop to the maximum possible size.

    Args:
        size (sequence or int): Desired output size of the crop. If size is an
            int instead of sequence like (h, w), a square crop (size, size) is
            made. When ``None``, the crop size is derived from ``ratio``.
        ratio (str): Desired ``"W:H"`` aspect ratio (width over height). Only
            used when ``size`` is ``None``; the crop is then the largest one
            with that ratio that fits inside the image.
    """

    def __init__(self, size=None, ratio="1:1"):
        super().__init__()
        self.size = size
        self.ratio = ratio

    def _ratio_crop_size(self, h, w):
        """Return the largest ``(height, width)`` crop matching ``self.ratio``.

        Args:
            h (int): Image height.
            w (int): Image width.

        Returns:
            tuple: ``(crop_height, crop_width)``, each no larger than the
            corresponding image dimension.

        Raises:
            ValueError: If ``self.ratio`` is not of the form ``"W:H"`` with
                two positive numbers.
        """
        parts = self.ratio.split(":")
        if len(parts) < 2:
            raise ValueError(
                f"ratio must look like 'W:H', got {self.ratio!r}"
            )
        # Fields after the second one are ignored.
        num, den = float(parts[0]), float(parts[1])
        # Written as a negated conjunction so that NaN is rejected as well.
        if not (num > 0 and den > 0):
            raise ValueError(
                f"ratio parts must be positive numbers, got {self.ratio!r}"
            )

        # Multiply before dividing: forming the quotient num / den first
        # rounds it, and that error can push an exactly-integer result just
        # below the integer (e.g. 100 * (29 / 100) -> 28.999...), which
        # int() would then truncate to one pixel too few.
        ratioed_w = int(h * num / den)  # width if the full height is kept
        ratioed_h = int(w * den / num)  # height if the full width is kept

        if w >= h:
            if ratioed_h <= h:
                return (ratioed_h, w)
            return (h, ratioed_w)
        if ratioed_w <= w:
            return (h, ratioed_w)
        return (ratioed_h, w)

    def forward(self, img):
        """
        Args:
            img (PIL Image or Tensor): Image to be cropped.

        Returns:
            PIL Image or Tensor: Cropped image.
        """
        if self.size is None:
            if isinstance(img, torch.Tensor):
                # Tensors carry height and width in their last two dims.
                h, w = img.shape[-2:]
            else:
                # Non-tensor images expose ``size`` as (width, height).
                w, h = img.size
            size = self._ratio_crop_size(h, w)
        else:
            size = self.size
        return torchvision.transforms.functional.center_crop(img, size)

    def __repr__(self) -> str:
        return (
            f"{self.__class__.__name__}"
            f"(size={self.size}, ratio={self.ratio!r})"
        )