jcarnero commited on
Commit
f4f8da1
·
1 Parent(s): 5d460ce

CenterCropPad transform working

Browse files
Files changed (1) hide show
  1. deployment/transforms.py +75 -3
deployment/transforms.py CHANGED
@@ -1,6 +1,78 @@
1
- from typing import Literal
2
- import torchvision.transforms as tvtfms
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3
 
4
 
5
  def CenterCropPad(size: tuple[Literal[460], Literal[460]]):
6
- return tvtfms.CenterCrop(size)
 
 
 
 
 
1
+ from typing import Literal, Union, Tuple
2
+
3
+ import torch
4
+
5
+ # # import torch.nn.functional as F
6
+ import torchvision.transforms.functional as tvf
7
+
8
+ # import torchvision.transforms as tvtfms
9
+ # # import operator as op
10
+ from PIL import Image
11
+
12
+ # # from torch import nn
13
+ # # from timm import create_model
14
+
15
+
16
+ def center_crop(
17
+ image: Union[Image.Image, torch.tensor], size: Tuple[int, int]
18
+ ) -> Image:
19
+ """
20
+ Takes a `PIL.Image` and crops it `size` unless one
21
+ dimension is larger than the actual image. Padding
22
+ must be performed afterwards if so.
23
+
24
+ Args:
25
+ image (`PIL.Image`):
26
+ An image to perform cropping on
27
+ size (`tuple` of integers):
28
+ A size to crop to, should be in the form
29
+ of (width, height)
30
+
31
+ Returns:
32
+ An augmented `PIL.Image`
33
+ """
34
+ top = (image.shape[-1] - size[0]) // 2
35
+ left = (image.shape[-2] - size[1]) // 2
36
+
37
+ top = max(top, 0)
38
+ left = max(left, 0)
39
+
40
+ height = min(top + size[0], image.shape[-1])
41
+ width = min(left + size[1], image.shape[-2])
42
+ return image.crop((top, left, height, width))
43
+
44
+
45
+ def pad(image, size: Tuple[int, int]) -> Image:
46
+ """
47
+ Takes a `PIL.Image` and pads it to `size` with
48
+ zeros.
49
+
50
+ Args:
51
+ image (`PIL.Image`):
52
+ An image to perform padding on
53
+ size (`tuple` of integers):
54
+ A size to pad to, should be in the form
55
+ of (width, height)
56
+
57
+ Returns:
58
+ An augmented `PIL.Image`
59
+ """
60
+ top = (image.shape[-1] - size[0]) // 2
61
+ left = (image.shape[-2] - size[1]) // 2
62
+
63
+ pad_top = max(-top, 0)
64
+ pad_left = max(-left, 0)
65
+
66
+ height, width = (
67
+ max(size[1] - image.shape[-1] + top, 0),
68
+ max(size[0] - image.shape[-2] + left, 0),
69
+ )
70
+ return tvf.pad(image, [pad_top, pad_left, height, width], padding_mode="constant")
71
 
72
 
73
  def CenterCropPad(size: tuple[Literal[460], Literal[460]]):
74
+ # return tvtfms.CenterCrop(size)
75
+ def _crop_pad(img):
76
+ return pad(center_crop(img, size), size)
77
+
78
+ return _crop_pad