ariG23498 commited on
Commit
3d3d52b
1 Parent(s): 0ad1f77

get better masks

Browse files
Files changed (1) hide show
  1. app.py +7 -5
app.py CHANGED
@@ -112,8 +112,10 @@ def run_segmentation(image, object_to_segment, device):
112
  outputs.pred_masks.cpu(),
113
  inputs["original_sizes"].cpu(),
114
  inputs["reshaped_input_sizes"].cpu(),
115
- )
116
-
 
 
117
  delete_model(seg_model)
118
  return masks
119
 
@@ -133,10 +135,10 @@ def run_inpainting(image, replaced_caption, masks, device):
133
  output = pipeline(
134
  prompt=prompt,
135
  image=image,
136
- mask_image=Image.fromarray(masks[0][0][0, :, :].numpy()),
137
  negative_prompt=negative_prompt,
138
  guidance_scale=7.5,
139
- strength=0.6,
140
  ).images[0]
141
 
142
  delete_model(pipeline)
@@ -174,7 +176,7 @@ def run_open_gen_fill(image, edit_prompt):
174
  replace_with,
175
  caption,
176
  replaced_caption,
177
- Image.fromarray(masks[0][0][0, :, :].numpy()),
178
  output,
179
  )
180
 
 
112
  outputs.pred_masks.cpu(),
113
  inputs["original_sizes"].cpu(),
114
  inputs["reshaped_input_sizes"].cpu(),
115
+ )[0]
116
+ # Merge the masks
117
+ masks = torch.max(masks[:, 0, ...], dim=0, keepdim=False).values
118
+
119
  delete_model(seg_model)
120
  return masks
121
 
 
135
  output = pipeline(
136
  prompt=prompt,
137
  image=image,
138
+ mask_image=Image.fromarray(masks.numpy()),
139
  negative_prompt=negative_prompt,
140
  guidance_scale=7.5,
141
+ strength=1.0,
142
  ).images[0]
143
 
144
  delete_model(pipeline)
 
176
  replace_with,
177
  caption,
178
  replaced_caption,
179
+ Image.fromarray(masks.numpy()),
180
  output,
181
  )
182