linxy97 commited on
Commit
266eeb8
1 Parent(s): 694e47c

Upload MyPipe.py

Browse files
Files changed (1) hide show
  1. MyPipe.py +67 -63
MyPipe.py CHANGED
@@ -1,4 +1,3 @@
1
-
2
  import torch, os
3
  import torch.nn.functional as F
4
  from torchvision.transforms.functional import normalize
@@ -7,70 +6,75 @@ from transformers import Pipeline
7
  from skimage import io
8
  from PIL import Image
9
 
 
10
  class RMBGPipe(Pipeline):
11
- def __init__(self,**kwargs):
12
- Pipeline.__init__(self,**kwargs)
13
- self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
14
- self.model.to(self.device)
15
- self.model.eval()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
16
 
17
- def _sanitize_parameters(self, **kwargs):
18
- # parse parameters
19
- preprocess_kwargs = {}
20
- postprocess_kwargs = {}
21
- if "model_input_size" in kwargs :
22
- preprocess_kwargs["model_input_size"] = kwargs["model_input_size"]
23
- if "out_name" in kwargs:
24
- postprocess_kwargs["out_name"] = kwargs["out_name"]
25
- return preprocess_kwargs, {}, postprocess_kwargs
26
 
27
- def preprocess(self,im_path:str,model_input_size: list=[1024,1024]):
28
- # preprocess the input
29
- orig_im = io.imread(im_path)
30
- orig_im_size = orig_im.shape[0:2]
31
- image = self.preprocess_image(orig_im, model_input_size).to(self.device)
32
- inputs = {
33
- "image":image,
34
- "orig_im_size":orig_im_size,
35
- "im_path" : im_path
36
- }
37
- return inputs
 
 
38
 
39
- def _forward(self,inputs):
40
- result = self.model(inputs.pop("image"))
41
- inputs["result"] = result
42
- return inputs
43
- def postprocess(self,inputs,out_name = ""):
44
- result = inputs.pop("result")
45
- orig_im_size = inputs.pop("orig_im_size")
46
- im_path = inputs.pop("im_path")
47
- result_image = self.postprocess_image(result[0][0], orig_im_size)
48
- if out_name != "" :
49
- # if out_name is specified we save the image using that name
50
- pil_im = Image.fromarray(result_image)
51
- no_bg_image = Image.new("RGBA", pil_im.size, (0,0,0,0))
52
- orig_image = Image.open(im_path)
53
- no_bg_image.paste(orig_image, mask=pil_im)
54
- no_bg_image.save(out_name)
55
- else :
56
- return result_image
57
 
58
- # utilities functions
59
- def preprocess_image(self,im: np.ndarray, model_input_size: list=[1024,1024]) -> torch.Tensor:
60
- # same as utilities.py with minor modification
61
- if len(im.shape) < 3:
62
- im = im[:, :, np.newaxis]
63
- # orig_im_size=im.shape[0:2]
64
- im_tensor = torch.tensor(im, dtype=torch.float32).permute(2,0,1)
65
- im_tensor = F.interpolate(torch.unsqueeze(im_tensor,0), size=model_input_size, mode='bilinear').type(torch.uint8)
66
- image = torch.divide(im_tensor,255.0)
67
- image = normalize(image,[0.5,0.5,0.5],[1.0,1.0,1.0])
68
- return image
69
- def postprocess_image(self,result: torch.Tensor, im_size: list)-> np.ndarray:
70
- result = torch.squeeze(F.interpolate(result, size=im_size, mode='bilinear') ,0)
71
- ma = torch.max(result)
72
- mi = torch.min(result)
73
- result = (result-mi)/(ma-mi)
74
- im_array = (result*255).permute(1,2,0).cpu().data.numpy().astype(np.uint8)
75
- im_array = np.squeeze(im_array)
76
- return im_array
 
 
1
  import torch, os
2
  import torch.nn.functional as F
3
  from torchvision.transforms.functional import normalize
 
6
  from skimage import io
7
  from PIL import Image
8
 
9
+
10
  class RMBGPipe(Pipeline):
11
+ def __init__(self, **kwargs):
12
+ Pipeline.__init__(self, **kwargs)
13
+ self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
14
+ self.model.to(self.device)
15
+ self.model.eval()
16
+
17
+ def _sanitize_parameters(self, **kwargs):
18
+ # parse parameters
19
+ preprocess_kwargs = {}
20
+ postprocess_kwargs = {}
21
+ if "model_input_size" in kwargs:
22
+ preprocess_kwargs["model_input_size"] = kwargs["model_input_size"]
23
+ if "out_name" in kwargs:
24
+ postprocess_kwargs["out_name"] = kwargs["out_name"]
25
+ return preprocess_kwargs, {}, postprocess_kwargs
26
+
27
+ def preprocess(self, orig_im: Image, model_input_size: list = [1024, 1024]):
28
+ # preprocess the input
29
+ orig_im_size = orig_im.shape[0:2]
30
+ image = self.preprocess_image(orig_im, model_input_size).to(self.device)
31
+ inputs = {
32
+ "orig_im": orig_im,
33
+ "image": image,
34
+ "orig_im_size": orig_im_size,
35
+ }
36
+ return inputs
37
 
38
+ def _forward(self, inputs):
39
+ result = self.model(inputs.pop("image"))
40
+ inputs["result"] = result
41
+ return inputs
 
 
 
 
 
42
 
43
+ def postprocess(self, inputs, out_name=""):
44
+ result = inputs.pop("result")
45
+ orig_im_size = inputs.pop("orig_im_size")
46
+ orig_image = inputs.pop("orig_image")
47
+ result_image = self.postprocess_image(result[0][0], orig_im_size)
48
+ if out_name != "":
49
+ # if out_name is specified we save the image using that name
50
+ pil_im = Image.fromarray(result_image)
51
+ no_bg_image = Image.new("RGBA", pil_im.size, (0, 0, 0, 0))
52
+ no_bg_image.paste(orig_image, mask=pil_im)
53
+ no_bg_image.save(out_name)
54
+ else:
55
+ return result_image
56
 
57
+ # utilities functions
58
+ def preprocess_image(
59
+ self, im: np.ndarray, model_input_size: list = [1024, 1024]
60
+ ) -> torch.Tensor:
61
+ # same as utilities.py with minor modification
62
+ if len(im.shape) < 3:
63
+ im = im[:, :, np.newaxis]
64
+ # orig_im_size=im.shape[0:2]
65
+ im_tensor = torch.tensor(im, dtype=torch.float32).permute(2, 0, 1)
66
+ im_tensor = F.interpolate(
67
+ torch.unsqueeze(im_tensor, 0), size=model_input_size, mode="bilinear"
68
+ ).type(torch.uint8)
69
+ image = torch.divide(im_tensor, 255.0)
70
+ image = normalize(image, [0.5, 0.5, 0.5], [1.0, 1.0, 1.0])
71
+ return image
 
 
 
72
 
73
+ def postprocess_image(self, result: torch.Tensor, im_size: list) -> np.ndarray:
74
+ result = torch.squeeze(F.interpolate(result, size=im_size, mode="bilinear"), 0)
75
+ ma = torch.max(result)
76
+ mi = torch.min(result)
77
+ result = (result - mi) / (ma - mi)
78
+ im_array = (result * 255).permute(1, 2, 0).cpu().data.numpy().astype(np.uint8)
79
+ im_array = np.squeeze(im_array)
80
+ return im_array