Feature Extraction
Transformers
PyTorch
bbsnet
custom_code
thinh-huynh-re committed
Commit 855518d · 1 Parent(s): 8ab3009

Upload processor

image_processor_bbsnet.py ADDED
@@ -0,0 +1,57 @@
+ from typing import Dict, Optional, Tuple
+
+ import numpy as np
+ import torch.nn.functional as F
+ import torchvision.transforms as transforms
+ from PIL.Image import Image
+ from torch import Tensor
+ from transformers.image_processing_utils import BaseImageProcessor
+
+ # Imported for reference only; see VideoMAEImageProcessor and ViTImageProcessor for more examples.
+ from transformers import VideoMAEImageProcessor, ViTImageProcessor
+
+ INPUT_IMAGE_SIZE = (352, 352)
+
+ # RGB frames are resized to the network input size and normalized with ImageNet statistics.
+ rgb_transform = transforms.Compose(
+     [
+         transforms.Resize(INPUT_IMAGE_SIZE),
+         transforms.ToTensor(),
+         transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
+     ]
+ )
+ # Ground-truth masks are only converted to tensors; depth maps are resized and converted.
+ gt_transform = transforms.ToTensor()
+ depth_transform = transforms.Compose(
+     [transforms.Resize(INPUT_IMAGE_SIZE), transforms.ToTensor()]
+ )
+
+
+ class BBSNetImageProcessor(BaseImageProcessor):
+     model_input_names = ["bbsnet_preprocessor"]
+
+     def __init__(self, testsize: Optional[int] = 352, **kwargs) -> None:
+         super().__init__(**kwargs)
+         self.testsize = testsize
+
+     def preprocess(
+         self,
+         inputs: Dict[str, Image],  # {'rgb': ..., 'gt': ..., 'depth': ...}
+         **kwargs
+     ) -> Dict[str, Tensor]:
+         # Apply the per-modality transform and add a batch dimension to each tensor.
+         rs = dict()
+         if "rgb" in inputs:
+             rs["rgb"] = rgb_transform(inputs["rgb"]).unsqueeze(0)
+         if "gt" in inputs:
+             rs["gt"] = gt_transform(inputs["gt"]).unsqueeze(0)
+         if "depth" in inputs:
+             rs["depth"] = depth_transform(inputs["depth"]).unsqueeze(0)
+         return rs
+
+     def postprocess(
+         self, logits: Tensor, size: Tuple[int, int], **kwargs
+     ) -> np.ndarray:
+         # F.interpolate replaces the deprecated F.upsample; the resize behavior is unchanged.
+         logits: Tensor = F.interpolate(
+             logits, size=size, mode="bilinear", align_corners=False
+         )
+         # Map logits to [0, 1] with a sigmoid, then min-max normalize the saliency map.
+         res: np.ndarray = logits.sigmoid().squeeze().data.cpu().numpy()
+         res = (res - res.min()) / (res.max() - res.min() + 1e-8)
+         return res
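
For context, a minimal usage sketch of the processor added above (not part of the commit): the image paths are placeholders, and the random tensor stands in for the saliency logits a BBSNet model would produce.

# Usage sketch, assuming the uploaded file is importable as image_processor_bbsnet.
import torch
from PIL import Image

from image_processor_bbsnet import BBSNetImageProcessor

processor = BBSNetImageProcessor()

rgb = Image.open("example_rgb.jpg").convert("RGB")    # placeholder RGB image
depth = Image.open("example_depth.png").convert("L")  # placeholder depth map

inputs = processor.preprocess({"rgb": rgb, "depth": depth})
print(inputs["rgb"].shape)    # torch.Size([1, 3, 352, 352])
print(inputs["depth"].shape)  # torch.Size([1, 1, 352, 352])

logits = torch.randn(1, 1, 44, 44)                             # stand-in for a model's output
saliency = processor.postprocess(logits, size=rgb.size[::-1])  # (H, W) map with values in [0, 1]
print(saliency.shape)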
preprocessor_config.json ADDED
@@ -0,0 +1,7 @@
+ {
+   "auto_map": {
+     "AutoImageProcessor": "image_processor_bbsnet.BBSNetImageProcessor"
+   },
+   "image_processor_type": "BBSNetImageProcessor",
+   "testsize": 352
+ }
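
The auto_map entry is what lets AutoImageProcessor resolve the custom class from image_processor_bbsnet.py. A minimal loading sketch, assuming the files above are pushed to a Hub repo; the repo id is illustrative, and trust_remote_code is required because the processor is custom code.

# Loading sketch: auto_map resolves image_processor_bbsnet.BBSNetImageProcessor.
from transformers import AutoImageProcessor

processor = AutoImageProcessor.from_pretrained(
    "thinh-huynh-re/bbsnet",  # hypothetical repo id
    trust_remote_code=True,   # needed to execute the custom processor code
)
print(type(processor).__name__)  # BBSNetImageProcessor
print(processor.testsize)        # 352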