Slep committed on
Commit
ba56501
1 Parent(s): 11f6a98

Upload processor

Browse files
Files changed (2) hide show
  1. preprocessor_config.json +31 -0
  2. processor.py +91 -0
preprocessor_config.json ADDED
@@ -0,0 +1,31 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "auto_map": {
3
+ "AutoImageProcessor": "processor.CondViTProcessor",
4
+ "AutoProcessor": "processor.CondViTProcessor"
5
+ },
6
+ "bkg_color": 255,
7
+ "categories": [
8
+ "Bags",
9
+ "Feet",
10
+ "Hands",
11
+ "Head",
12
+ "Lower Body",
13
+ "Neck",
14
+ "Outwear",
15
+ "Upper Body",
16
+ "Waist",
17
+ "Whole Body"
18
+ ],
19
+ "image_mean": [
20
+ 0.48145466,
21
+ 0.4578275,
22
+ 0.40821073
23
+ ],
24
+ "image_processor_type": "CondViTProcessor",
25
+ "image_std": [
26
+ 0.26862954,
27
+ 0.26130258,
28
+ 0.27577711
29
+ ],
30
+ "input_resolution": 224
31
+ }
processor.py ADDED
@@ -0,0 +1,91 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers.image_processing_utils import ImageProcessingMixin, BatchFeature
2
+
3
+ from torchvision.transforms import transforms as tf
4
+ import torchvision.transforms.functional as F
5
+ from PIL import Image
6
+ import torch
7
+
8
+
9
class CondViTProcessor(ImageProcessingMixin):
    """Image (and optional category) processor for CondViT.

    Pads each image to a square with a constant background color,
    resizes it to ``input_resolution``, converts it to a tensor and
    normalizes it with CLIP-style per-channel statistics. Category
    names, when given, are mapped to their integer index in
    ``self.categories``.
    """

    # Default garment taxonomy; order defines the category indices.
    _DEFAULT_CATEGORIES = (
        "Bags",
        "Feet",
        "Hands",
        "Head",
        "Lower Body",
        "Neck",
        "Outwear",
        "Upper Body",
        "Waist",
        "Whole Body",
    )

    def __init__(
        self,
        bkg_color=255,
        input_resolution=224,
        image_mean=(0.48145466, 0.4578275, 0.40821073),
        image_std=(0.26862954, 0.26130258, 0.27577711),
        categories=None,
        **kwargs,
    ):
        """
        Parameters
        ----------
        bkg_color : int
            Fill value used when padding an image to a square.
        input_resolution : int
            Target side length of the processed (square) image.
        image_mean : Sequence[float]
            Per-channel mean used for normalization.
        image_std : Sequence[float]
            Per-channel std used for normalization.
        categories : Optional[List[str]]
            Ordered category names; indices are taken from this order.
            Defaults to the 10 garment categories in
            ``_DEFAULT_CATEGORIES``.
        """
        super().__init__(**kwargs)

        self.bkg_color = bkg_color
        self.input_resolution = input_resolution
        self.image_mean = image_mean
        self.image_std = image_std

        # Fix: the previous signature used a mutable default list
        # (`categories=[...]`), which is shared across every instance.
        # A None sentinel with a per-instance copy is the safe idiom
        # and remains backward-compatible for all callers.
        if categories is None:
            categories = list(self._DEFAULT_CATEGORIES)
        self.categories = list(categories)

    def square_pad(self, image):
        """Pad a PIL image to a square, centering it on a
        ``self.bkg_color`` background.

        Parameters
        ----------
        image : PIL.Image.Image
            Input image; ``image.size`` is (width, height).

        Returns
        -------
        PIL.Image.Image
            Square image whose side is ``max(width, height)``.
        """
        max_wh = max(image.size)
        # Split the slack evenly; the right/bottom side absorbs any
        # odd pixel left over from integer division.
        p_left, p_top = [(max_wh - s) // 2 for s in image.size]
        p_right, p_bottom = [
            max_wh - (s + pad) for s, pad in zip(image.size, [p_left, p_top])
        ]
        padding = (p_left, p_top, p_right, p_bottom)
        return F.pad(image, padding, self.bkg_color, "constant")

    def process_img(self, image):
        """Pad, resize, tensorize and normalize a single PIL image.

        Returns a float tensor of shape (C, H, W) with
        H == W == ``self.input_resolution``.
        """
        img = self.square_pad(image)
        # The image is square after padding, so resizing the shorter
        # edge yields an input_resolution x input_resolution result.
        img = F.resize(img, self.input_resolution)
        img = F.to_tensor(img)
        img = F.normalize(img, self.image_mean, self.image_std)
        return img

    def process_cat(self, cat):
        """Map a category name to its index tensor; pass None through.

        Raises
        ------
        ValueError
            If ``cat`` is not in ``self.categories``.
        """
        if cat is not None:
            # dtype=int is accepted by torch and maps to torch.int64.
            cat = torch.tensor(self.categories.index(cat), dtype=int)
        return cat

    def __call__(self, images, categories=None):
        """
        Parameters
        ----------
        images : Union[Image.Image, List[Image.Image]]
            Image or list of images to process
        categories : Optional[Union[str, List[str]]]
            Category or list of categories to process

        Returns
        -------
        BatchFeature
            pixel_values : torch.Tensor
                Processed image tensor (B C H W)
            category : torch.Tensor
                Categories indices (B)
        """
        use_cats = categories is not None

        # Promote a single image (and its single category) to a batch
        # of one so the stacking below is uniform.
        if isinstance(images, Image.Image):
            images = [images]
            if use_cats:
                categories = [categories]

        data = {}
        data["pixel_values"] = torch.stack([self.process_img(img) for img in images])

        if use_cats:
            data["category"] = torch.stack([self.process_cat(c) for c in categories])

        return BatchFeature(data=data)