File size: 4,145 Bytes
8cc439d
c16c201
f4f8da1
 
 
c16c201
f4f8da1
 
 
 
 
 
 
 
 
 
f3555c8
f4f8da1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e338a46
 
f4f8da1
 
 
 
e338a46
 
f4f8da1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e338a46
 
f4f8da1
 
 
 
 
e338a46
 
f4f8da1
 
5d460ce
 
f3555c8
 
 
 
 
8cc439d
f3555c8
 
 
 
8cc439d
 
 
 
f3555c8
8cc439d
f3555c8
 
8cc439d
 
f3555c8
 
 
 
 
8cc439d
f3555c8
 
f4f8da1
f3555c8
c16c201
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
import math
from typing import Union, Tuple

import torch

import torch.nn.functional as F
import torchvision.transforms.functional as tvf

# import torchvision.transforms as tvtfms
# # import operator as op
from PIL import Image

# # from torch import nn
# # from timm import create_model


def crop(image: Image.Image, size: Tuple[int, int]) -> Image.Image:
    """
    Center-crops a `PIL.Image` to `size`.

    Along any axis where `size` is larger than the image, the whole
    extent of the image is kept instead, so the result can be smaller
    than `size` on that axis. Padding must be performed afterwards
    (e.g. with `pad`) if an exact size is required.

    Args:
        image (`PIL.Image`):
            An image to perform cropping on
        size (`tuple` of integers):
            A size to crop to, should be in the form
            of (width, height)

    Returns:
        The cropped `PIL.Image`
    """
    image_width, image_height = image.size[-2], image.size[-1]

    # Offsets of the centered crop window, clamped so the window never
    # starts outside the image.
    left = max((image_width - size[0]) // 2, 0)
    upper = max((image_height - size[1]) // 2, 0)

    # Clamp the far edges to the image bounds when `size` exceeds the image.
    right = min(left + size[0], image_width)
    lower = min(upper + size[1], image_height)

    # PIL expects the crop box as (left, upper, right, lower).
    return image.crop((left, upper, right, lower))


def pad(image: Image.Image, size: Tuple[int, int]) -> Image.Image:
    """
    Center-pads a `PIL.Image` with zeros up to `size`.

    Axes on which the image already reaches (or exceeds) `size` receive
    no padding, so this is intended to run after `crop`.

    Args:
        image (`PIL.Image`):
            An image to perform padding on
        size (`tuple` of integers):
            A size to pad to, should be in the form
            of (width, height)

    Returns:
        The padded `PIL.Image`
    """
    image_width, image_height = image.size[-2], image.size[-1]
    target_width, target_height = size[0], size[1]

    # Offset of the centered target window relative to the image; a
    # negative value means the target is larger than the image on that axis.
    offset_x = (image_width - target_width) // 2
    offset_y = (image_height - target_height) // 2

    # Leading padding fills the part of the window before the image starts.
    pad_left = max(-offset_x, 0)
    pad_top = max(-offset_y, 0)

    # Trailing padding fills whatever is still missing on each axis. Each
    # axis must use its own target dimension (width with width, height with
    # height); mixing them pads non-square targets to the wrong size.
    pad_right = max(target_width - image_width + offset_x, 0)
    pad_bottom = max(target_height - image_height + offset_y, 0)

    # torchvision's 4-element padding order is [left, top, right, bottom].
    return tvf.pad(
        image,
        [pad_left, pad_top, pad_right, pad_bottom],
        padding_mode="constant",
    )


def resized_crop_pad(
    image: Image.Image,
    size: Tuple[int, int],
    extra_crop_ratio: float = 0.14,
) -> Image.Image:
    """
    Takes a `PIL.Image`, resizes it to `size` plus a margin derived from
    `extra_crop_ratio`, and then center-crops and pads it back to `size`.

    The resize targets the enlarged size directly, so the image's original
    aspect ratio is not preserved.

    Args:
        image (`PIL.Image`):
            An image to resize, crop and pad
        size (`tuple` of integers):
            A size to crop and pad to, should be in the form
            of (width, height)
        extra_crop_ratio (float):
            The ratio of size at the edge cropped out. Default 0.14

    Returns:
        The processed `PIL.Image`
    """
    maximum_space = max(size[0], size[1])
    # Margin added to both dimensions, rounded up to a multiple of 8 pixels.
    extra_space = math.ceil(maximum_space * extra_crop_ratio / 8) * 8
    extended_size = (size[0] + extra_space, size[1] + extra_space)
    resized_image = image.resize(extended_size, resample=Image.Resampling.BILINEAR)

    # Skip the crop/pad when the resize already produced exactly `size`.
    if extended_size != size:
        resized_image = pad(crop(resized_image, size), size)

    return resized_image


def gpu_crop(batch: torch.tensor, size: Tuple[int, int]):
    """
    Resamples each image in `batch` onto an output grid of `size` using an
    identity affine transform and bilinear sampling.

    When the input is sufficiently larger than the output, the batch is
    first shrunk with area interpolation (presumably to reduce aliasing in
    the subsequent bilinear sampling).

    Args:
        batch (`torch.Tensor`):
            A 4-D batch of images. `F.affine_grid` / `F.grid_sample` treat
            it as `N x C x H x W`.
        size (`tuple` of integers):
            The output spatial size. It is appended to `batch.shape[:2]`
            before calling `F.affine_grid`, so it is consumed as
            `(height, width)`; it must be a tuple for that concatenation.

    Returns:
        A `torch.Tensor` of shape `N x C x size[0] x size[1]`
    """
    num_images = batch.size(0)

    # One identity 2x3 affine matrix per image (the top two rows of eye(3)).
    identity = torch.eye(3, device=batch.device).float()
    theta = identity.unsqueeze(0).expand(num_images, 3, 3).contiguous()[:, :2]

    grid = F.affine_grid(theta, batch.shape[:2] + size, align_corners=True)

    # Zoom implied by the extent of the sampling grid (1.0 for the identity).
    span = (grid.max() - grid.min()).item()
    zoom = 1 / span * 2

    # NOTE(review): `grid` has shape (N, H_out, W_out, 2), so grid.shape[-2]
    # is the output width and grid.shape[-1] is always 2. The ratios below
    # therefore do not compare matching input/output axes. This appears to
    # match fastai's `_grid_sample`; confirm before "fixing", since changing
    # it alters the preprocessing output.
    shrink = (
        min(batch.shape[-2] / grid.shape[-2], batch.shape[-1] / grid.shape[-1])
        / 2
    )

    if shrink > 1 and shrink > zoom:
        # Pre-shrink with area interpolation before the bilinear resample.
        batch = F.interpolate(
            batch,
            scale_factor=1 / shrink,
            mode="area",
            recompute_scale_factor=True,
        )

    return F.grid_sample(
        batch,
        grid,
        mode="bilinear",
        padding_mode="reflection",
        align_corners=True,
    )