|
from fractions import Fraction

import torch
import torch.nn.functional as F
import torchvision
|
|
|
|
|
|
def remap_image_torch(image):
    """Convert an image tensor with values in [-1, 1] to uint8 pixel values.

    The input is shifted and scaled linearly so that -1 maps to 0 and 1 maps
    to 255. Anything outside that range is clamped to [0, 255]. The final
    integer cast drops fractional parts (truncation, not rounding).

    Args:
        image (Tensor): Image values, nominally in [-1, 1].

    Returns:
        Tensor: ``torch.uint8`` tensor with values in [0, 255].
    """
    scaled = (image + 1) / 2.0 * 255.0
    bounded = torch.clamp(scaled, 0, 255)
    return bounded.to(dtype=torch.uint8)
|
|
|
|
|
|
class CenterCrop(torch.nn.Module):
    """Crop the given image at the center, optionally as large as possible for an aspect ratio.

    Args:
        size (sequence or int, optional): Desired output size of the crop. If size
            is an int instead of a sequence like (h, w), a square crop
            (size, size) is made. If None (the default), the largest centered
            crop with aspect ratio ``ratio`` that fits in the image is taken.
        ratio (str): Desired ``"W:H"`` (width:height) aspect ratio of the crop,
            e.g. ``"16:9"``. Decimal parts such as ``"2.35:1"`` are accepted.
            Only used when ``size`` is None.
    """

    def __init__(self, size=None, ratio="1:1"):
        super().__init__()
        self.size = size
        self.ratio = ratio

    @staticmethod
    def _parse_ratio(ratio):
        """Parse a ``"W:H"`` string into an exact, positive width/height Fraction.

        Raises:
            ValueError: if ``ratio`` is not of the form ``"W:H"`` with positive
                numeric parts.
        """
        parts = ratio.split(":")
        if len(parts) != 2:
            raise ValueError(f"ratio must be of the form 'W:H', got {ratio!r}")
        # Fraction keeps the ratio exact. With floats, int(100 * 0.29) == 28
        # instead of 29, so a side could lose a pixel.
        width, height = Fraction(parts[0]), Fraction(parts[1])
        if width <= 0 or height <= 0:
            raise ValueError(f"ratio parts must be positive, got {ratio!r}")
        return width / height

    def _max_crop_size(self, h, w):
        """Return (height, width) of the largest ``self.ratio`` crop fitting an h x w image."""
        ratio = self._parse_ratio(self.ratio)
        num, den = ratio.numerator, ratio.denominator
        # Exact integer floors of h * (W/H) and w * (H/W).
        ratioed_w = h * num // den
        ratioed_h = w * den // num
        if w >= h:
            if ratioed_h <= h:
                return ratioed_h, w
            return h, ratioed_w
        if ratioed_w <= w:
            return h, ratioed_w
        return ratioed_h, w

    def forward(self, img):
        """Crop ``img`` at its center.

        Args:
            img (PIL Image or Tensor): Image to be cropped. For tensors the last
                two dimensions are read as (height, width).

        Returns:
            PIL Image or Tensor: Cropped image.
        """
        if self.size is not None:
            size = self.size
        else:
            if isinstance(img, torch.Tensor):
                h, w = img.shape[-2:]
            else:
                w, h = img.size  # PIL reports (width, height)
            size = self._max_crop_size(h, w)
        return torchvision.transforms.functional.center_crop(img, size)

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}(size={self.size}, ratio={self.ratio!r})"
|
|
|