yizhangliu committed on
Commit
43859c3
·
1 Parent(s): c69b167

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +28 -14
app.py CHANGED
@@ -104,6 +104,21 @@ def preprocess_mask(mask):
104
  mask = torch.from_numpy(mask)
105
  return mask
106
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
107
  model = None
108
  def model_process(image, mask):
109
  global model
@@ -112,19 +127,23 @@ def model_process(image, mask):
112
  # RGB
113
  # origin_image_bytes = input["image"].read()
114
 
115
-
116
  print(f'liuyz_2_here_')
117
-
118
- # image, alpha_channel = load_img(origin_image_bytes)
 
 
 
 
119
  # Origin image shape: (512, 512, 3)
120
- alpha_channel = None
 
121
  original_shape = image.shape
122
  interpolation = cv2.INTER_CUBIC
123
 
124
  # form = request.form
125
- print(f'liuyz_3_here_', original_shape)
126
 
127
- size_limit = 512 # : Union[int, str] = form.get("sizeLimit", "1080")
128
  if size_limit == "Original":
129
  size_limit = max(image.shape)
130
  else:
@@ -167,9 +186,8 @@ def model_process(image, mask):
167
  print(f"Resized image shape: {image.shape} / {image[250][250]}")
168
 
169
  #mask, _ = load_img(input["mask"].read(), gray=True)
170
- mask_image = Image.fromarray(mask).convert("L")
171
- mask_image.save(f'./mask_image.png')
172
- mask = np.array(mask_image)
173
  mask = resize_max_size(mask, size_limit=size_limit, interpolation=interpolation)
174
  print(f"mask image shape: {mask.shape} / {type(mask)} / {mask[250][250]}")
175
 
@@ -238,11 +256,7 @@ def predict(dict):
238
  print(f'liuyz_3_', image.convert("RGB").resize((512, 512)).shape)
239
  # mask = dict["mask"] # .convert("RGB") #.resize((512, 512))
240
  '''
241
- print(f'size__', dict["image"].size)
242
- image = Image.fromarray(dict["image"])
243
- mask = np.array(Image.fromarray(dict["mask"]).convert("L"))
244
- print(f'mask___1 = {mask.shape}')
245
-
246
  output = model_process(dict["image"], dict["mask"])
247
  # output = mask #output.images[0]
248
  # output = pipe(prompt = prompt, image=init_image, mask_image=mask,guidance_scale=7.5)
 
104
  mask = torch.from_numpy(mask)
105
  return mask
106
 
107
def load_img(nparr, gray: bool = False):
    """Decode an image from a buffer of encoded image bytes.

    Args:
        nparr: 1-D ``np.uint8`` buffer of *encoded* image bytes (e.g. from
            ``np.frombuffer(img_bytes, np.uint8)``), fed to ``cv2.imdecode``.
            NOTE(review): the caller in ``model_process`` appears to pass an
            already-decoded array — confirm the call site.
        gray: when True, decode as a single-channel grayscale image.

    Returns:
        Tuple ``(np_img, alpha_channel)``: the decoded image as RGB (or 2-D
        grayscale when ``gray``), and the source alpha channel if the input
        had one, otherwise ``None``.
    """
    # Must be initialized up front: every path that is not 4-channel BGRA
    # returns it, and leaving it commented out (as in the original diff)
    # raises UnboundLocalError at the return statement.
    alpha_channel = None
    if gray:
        np_img = cv2.imdecode(nparr, cv2.IMREAD_GRAYSCALE)
    else:
        np_img = cv2.imdecode(nparr, cv2.IMREAD_UNCHANGED)
        # Color conversion belongs only to the non-gray branch: running
        # cvtColor(..., COLOR_BGR2RGB) on a 2-D grayscale array raises.
        if len(np_img.shape) == 3 and np_img.shape[2] == 4:
            alpha_channel = np_img[:, :, -1]
            np_img = cv2.cvtColor(np_img, cv2.COLOR_BGRA2RGB)
        else:
            np_img = cv2.cvtColor(np_img, cv2.COLOR_BGR2RGB)
    return np_img, alpha_channel
121
+
122
  model = None
123
  def model_process(image, mask):
124
  global model
 
127
  # RGB
128
  # origin_image_bytes = input["image"].read()
129
 
 
130
  print(f'liuyz_2_here_')
131
+ image_pil = Image.fromarray(image)
132
+ mask_pil = Image.fromarray(mask).convert("L")
133
+ print(f'image_pil_ = {image_pil.shape}')
134
+ print(f'mask_pil_ = {mask_pil.shape}')
135
+
136
+ image, alpha_channel = load_img(image)
137
  # Origin image shape: (512, 512, 3)
138
+
139
+ # alpha_channel = None
140
  original_shape = image.shape
141
  interpolation = cv2.INTER_CUBIC
142
 
143
  # form = request.form
144
+ print(f'liuyz_3_here_', original_shape, alpha_channel)
145
 
146
+ size_limit = image_pil.shape[1] # : Union[int, str] = form.get("sizeLimit", "1080")
147
  if size_limit == "Original":
148
  size_limit = max(image.shape)
149
  else:
 
186
  print(f"Resized image shape: {image.shape} / {image[250][250]}")
187
 
188
  #mask, _ = load_img(input["mask"].read(), gray=True)
189
+ mask_pil.save(f'./mask_pil.png')
190
+ mask = np.array(mask_pil)
 
191
  mask = resize_max_size(mask, size_limit=size_limit, interpolation=interpolation)
192
  print(f"mask image shape: {mask.shape} / {type(mask)} / {mask[250][250]}")
193
 
 
256
  print(f'liuyz_3_', image.convert("RGB").resize((512, 512)).shape)
257
  # mask = dict["mask"] # .convert("RGB") #.resize((512, 512))
258
  '''
259
+
 
 
 
 
260
  output = model_process(dict["image"], dict["mask"])
261
  # output = mask #output.images[0]
262
  # output = pipe(prompt = prompt, image=init_image, mask_image=mask,guidance_scale=7.5)