Spaces:
Running
Running
import torch | |
import torch.nn.functional as F | |
import torchvision | |
def remap_image_torch(image):
    """Convert an image tensor from the [-1, 1] range to 8-bit [0, 255] values.

    Args:
        image (Tensor): Image data expected to lie in [-1, 1]. Values outside
            that range are accepted and end up clamped to 0 or 255.

    Returns:
        Tensor: A ``torch.uint8`` tensor with the same shape as ``image``.
    """
    # Affine map: -1 -> 0, 0 -> 127.5, 1 -> 255.
    scaled = ((image + 1) / 2.0) * 255.0
    # Clamp before the cast so out-of-range inputs saturate instead of
    # wrapping around; the uint8 cast then drops the fractional part.
    return torch.clip(scaled, 0, 255).to(torch.uint8)
class CenterCrop(torch.nn.Module):
    """Crop the given image at its center.

    With an explicit ``size`` the value is handed unchanged to
    ``torchvision.transforms.functional.center_crop``. With ``size=None`` the
    crop is the largest centered region of the image whose width:height
    proportion matches ``ratio``.

    Args:
        size (sequence or int, optional): Desired output size of the crop. If
            size is an int instead of a sequence like (h, w), a square crop
            (size, size) is made. When ``None``, the size is derived from
            ``ratio`` and the input image on every call.
        ratio (str): Target aspect ratio written as ``"width:height"`` (for
            example ``"16:9"``). Only consulted when ``size`` is ``None``.
    """

    def __init__(self, size=None, ratio="1:1"):
        super().__init__()
        self.size = size
        self.ratio = ratio

    def _parse_ratio(self):
        """Return ``self.ratio`` as a ``(width_part, height_part)`` pair of floats.

        Raises:
            ValueError: If the ratio is not of the form ``"<number>:<number>"``
                with two positive, finite numbers.
        """
        parts = self.ratio.split(":")
        if len(parts) != 2:
            raise ValueError(
                f"ratio must look like 'W:H' (e.g. '16:9'), got {self.ratio!r}"
            )
        try:
            width_part, height_part = float(parts[0]), float(parts[1])
        except ValueError as err:
            raise ValueError(
                f"ratio must look like 'W:H' (e.g. '16:9'), got {self.ratio!r}"
            ) from err
        infinity = float("inf")
        # The chained comparisons also reject NaN, which fails every comparison.
        if not (0 < width_part < infinity and 0 < height_part < infinity):
            raise ValueError(
                f"ratio parts must be positive and finite, got {self.ratio!r}"
            )
        return width_part, height_part

    def _ratio_crop_size(self, height, width):
        """Return the ``(h, w)`` of the largest ``ratio`` crop fitting ``height`` x ``width``."""
        width_part, height_part = self._parse_ratio()
        # Multiply before dividing. Pre-computing width_part / height_part as a
        # single float and dividing by it rounds twice and can land just below
        # an exact integer (2100 / (21 / 9) -> 899.999...), which int() would
        # then truncate to one pixel too few.
        fitted_w = int(height * width_part / height_part)
        fitted_h = int(width * height_part / width_part)

        if width >= height:
            keep_full_width = fitted_h <= height
        else:
            keep_full_width = fitted_w > width

        if keep_full_width:
            # Full width is kept; the height is trimmed to match the ratio.
            return fitted_h, width
        # Full height is kept; the width is trimmed to match the ratio.
        return height, fitted_w

    def forward(self, img):
        """Center-crop ``img``.

        Args:
            img (PIL Image or Tensor): Image to be cropped. Tensors are read as
                ``[..., H, W]``; anything else must expose a PIL-style
                ``size`` attribute ordered ``(width, height)``.

        Returns:
            PIL Image or Tensor: Cropped image.

        Raises:
            ValueError: If ``size`` is ``None`` and ``ratio`` is malformed.
        """
        if self.size is not None:
            size = self.size
        else:
            if isinstance(img, torch.Tensor):
                height, width = img.shape[-2:]
            else:
                width, height = img.size
            size = self._ratio_crop_size(height, width)
        return torchvision.transforms.functional.center_crop(img, size)

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}(size={self.size}, ratio={self.ratio!r})"