OriLib commited on
Commit
b0fb67c
·
verified ·
1 Parent(s): 524ae85

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +13 -26
README.md CHANGED
@@ -98,34 +98,25 @@ pip install -r requirements.txt
98
  ## Usage
99
 
100
  ```python
101
- import numpy as np
102
  from skimage import io
103
- import torch
104
- import torch.nn.functional as F
105
- from torchvision.transforms.functional import normalize
106
- from briarmbg import BriaRMBG
107
  from PIL import Image
 
 
108
 
109
- model_path = "./model.pth"
110
- im_path = "./example_input.jpg"
111
 
112
  net = BriaRMBG()
113
- if torch.cuda.is_available():
114
- net.load_state_dict(torch.load(model_path)).cuda()
115
- else:
116
- net.load_state_dict(torch.load(model_path,map_location="cpu"))
117
  net.eval()
118
 
119
  # prepare input
120
- model_input_size=[1024,1024]
121
- im = io.imread(im_path)
122
- if len(im.shape) < 3:
123
- im = im[:, :, np.newaxis]
124
- im_size=im.shape[0:2]
125
- im_tensor = torch.tensor(im, dtype=torch.float32).permute(2,0,1)
126
- im_tensor = F.interpolate(torch.unsqueeze(im_tensor,0), size=model_input_size, mode='bilinear').type(torch.uint8)
127
- image = torch.divide(im_tensor,255.0)
128
- image = normalize(image,[0.5,0.5,0.5],[1.0,1.0,1.0])
129
 
130
  if torch.cuda.is_available():
131
  image=image.cuda()
@@ -134,14 +125,10 @@ if torch.cuda.is_available():
134
  result=net(image)
135
 
136
  # post process
137
- result = torch.squeeze(F.interpolate(result[0][0], size=im_size, mode='bilinear') ,0)
138
- ma = torch.max(result)
139
- mi = torch.min(result)
140
- result = (result-mi)/(ma-mi)
141
 
142
  # save result
143
- im_array = (result*255).permute(1,2,0).cpu().data.numpy().astype(np.uint8)
144
- pil_im = Image.fromarray(np.squeeze(im_array))
145
  no_bg_image = Image.new("RGBA", pil_im.size, (0,0,0,0))
146
  orig_image = Image.open(im_path)
147
  no_bg_image.paste(orig_image, mask=pil_im)
 
98
  ## Usage
99
 
100
  ```python
 
101
  from skimage import io
102
+ import torch, os
 
 
 
103
  from PIL import Image
104
+ from briarmbg import BriaRMBG
105
+ from utilities import preprocess_image, postprocess_image
106
 
107
+ model_path = f"{os.path.dirname(__file__)}/model.pth"
108
+ im_path = f"{os.path.dirname(__file__)}/example_input.jpg"
109
 
110
  net = BriaRMBG()
111
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
112
+ net.load_state_dict(torch.load(model_path, map_location=device))
 
 
113
  net.eval()
114
 
115
  # prepare input
116
+ model_input_size = [1024,1024]
117
+ orig_im = io.imread(im_path)
118
+ orig_im_size = orig_im.shape[0:2]
119
+ image = preprocess_image(orig_im, model_input_size)
 
 
 
 
 
120
 
121
  if torch.cuda.is_available():
122
  image=image.cuda()
 
125
  result=net(image)
126
 
127
  # post process
128
+ result_image = postprocess_image(result[0][0], orig_im_size)
 
 
 
129
 
130
  # save result
131
+ pil_im = Image.fromarray(result_image)
 
132
  no_bg_image = Image.new("RGBA", pil_im.size, (0,0,0,0))
133
  orig_image = Image.open(im_path)
134
  no_bg_image.paste(orig_image, mask=pil_im)