OriLib commited on
Commit
ce305f6
·
verified ·
1 Parent(s): 28f8f41

Update MyPipe.py

Browse files

added support to pil image as input

Files changed (1) hide show
  1. MyPipe.py +11 -10
MyPipe.py CHANGED
@@ -3,6 +3,7 @@ import torch.nn.functional as F
3
  from torchvision.transforms.functional import normalize
4
  import numpy as np
5
  from transformers import Pipeline
 
6
  from skimage import io
7
  from PIL import Image
8
 
@@ -23,34 +24,35 @@ class RMBGPipe(Pipeline):
23
  postprocess_kwargs["return_mask"] = kwargs["return_mask"]
24
  return preprocess_kwargs, {}, postprocess_kwargs
25
 
26
- def preprocess(self,im_path:str,model_input_size: list=[1024,1024]):
27
  # preprocess the input
28
- orig_im = io.imread(im_path)
 
29
  orig_im_size = orig_im.shape[0:2]
30
- image = self.preprocess_image(orig_im, model_input_size).to(self.device)
31
  inputs = {
32
- "image":image,
33
  "orig_im_size":orig_im_size,
34
- "im_path" : im_path
35
  }
36
  return inputs
37
 
38
  def _forward(self,inputs):
39
- result = self.model(inputs.pop("image"))
40
  inputs["result"] = result
41
  return inputs
42
 
43
  def postprocess(self,inputs,return_mask:bool=False ):
44
  result = inputs.pop("result")
45
  orig_im_size = inputs.pop("orig_im_size")
46
- im_path = inputs.pop("im_path")
47
  result_image = self.postprocess_image(result[0][0], orig_im_size)
48
  pil_im = Image.fromarray(result_image)
49
  if return_mask ==True :
50
  return pil_im
51
  no_bg_image = Image.new("RGBA", pil_im.size, (0,0,0,0))
52
- orig_image = Image.fromarray(io.imread(im_path))
53
- no_bg_image.paste(orig_image, mask=pil_im)
54
  return no_bg_image
55
 
56
  # utilities functions
@@ -58,7 +60,6 @@ class RMBGPipe(Pipeline):
58
  # same as utilities.py with minor modification
59
  if len(im.shape) < 3:
60
  im = im[:, :, np.newaxis]
61
- # orig_im_size=im.shape[0:2]
62
  im_tensor = torch.tensor(im, dtype=torch.float32).permute(2,0,1)
63
  im_tensor = F.interpolate(torch.unsqueeze(im_tensor,0), size=model_input_size, mode='bilinear')
64
  image = torch.divide(im_tensor,255.0)
 
3
  from torchvision.transforms.functional import normalize
4
  import numpy as np
5
  from transformers import Pipeline
6
+ from transformers.image_utils import load_image
7
  from skimage import io
8
  from PIL import Image
9
 
 
24
  postprocess_kwargs["return_mask"] = kwargs["return_mask"]
25
  return preprocess_kwargs, {}, postprocess_kwargs
26
 
27
+ def preprocess(self,input_image,model_input_size: list=[1024,1024]):
28
  # preprocess the input
29
+ orig_im = load_image(input_image)
30
+ orig_im = np.array(orig_im)
31
  orig_im_size = orig_im.shape[0:2]
32
+ preprocessed_image = self.preprocess_image(orig_im, model_input_size).to(self.device)
33
  inputs = {
34
+ "preprocessed_image":preprocessed_image,
35
  "orig_im_size":orig_im_size,
36
+ "input_image" : input_image
37
  }
38
  return inputs
39
 
40
  def _forward(self,inputs):
41
+ result = self.model(inputs.pop("preprocessed_image"))
42
  inputs["result"] = result
43
  return inputs
44
 
45
  def postprocess(self,inputs,return_mask:bool=False ):
46
  result = inputs.pop("result")
47
  orig_im_size = inputs.pop("orig_im_size")
48
+ input_image = inputs.pop("input_image")
49
  result_image = self.postprocess_image(result[0][0], orig_im_size)
50
  pil_im = Image.fromarray(result_image)
51
  if return_mask ==True :
52
  return pil_im
53
  no_bg_image = Image.new("RGBA", pil_im.size, (0,0,0,0))
54
+ input_image = load_image(input_image)
55
+ no_bg_image.paste(input_image, mask=pil_im)
56
  return no_bg_image
57
 
58
  # utilities functions
 
60
  # same as utilities.py with minor modification
61
  if len(im.shape) < 3:
62
  im = im[:, :, np.newaxis]
 
63
  im_tensor = torch.tensor(im, dtype=torch.float32).permute(2,0,1)
64
  im_tensor = F.interpolate(torch.unsqueeze(im_tensor,0), size=model_input_size, mode='bilinear')
65
  image = torch.divide(im_tensor,255.0)