bedead commited on
Commit
225113d
·
verified ·
1 Parent(s): ac9b118

added combine_masks

Browse files
Files changed (1) hide show
  1. app.py +10 -2
app.py CHANGED
@@ -12,7 +12,7 @@ import argparse
12
  # Load configuration and models
13
  config = OmegaConf.load("config/inference_config.yaml")
14
  sd_pipeline = StableDiffusionInpaintPipeline.from_pretrained(
15
- "botp/stable-diffusion-v1-5-inpainting", torch_dtype=torch.float32
16
  )
17
  clipaway = CLIPAway(
18
  sd_pipe=sd_pipeline,
@@ -26,11 +26,19 @@ clipaway = CLIPAway(
26
  )
27
 
28
  def dilate_mask(mask, kernel_size=5, iterations=5):
29
- mask = mask.convert("L").resize((512, 512), Image.NEAREST)
30
  kernel = np.ones((kernel_size, kernel_size), np.uint8)
31
  mask = cv2.dilate(np.array(mask), kernel, iterations=iterations)
32
  return Image.fromarray(mask)
33
 
 
 
 
 
 
 
 
 
34
  def remove_obj(image, uploaded_mask, seed):
35
  image_pil, sketched_mask = image["image"], image["mask"]
36
  mask = dilate_mask(combine_masks(uploaded_mask, sketched_mask))
 
12
  # Load configuration and models
13
  config = OmegaConf.load("config/inference_config.yaml")
14
  sd_pipeline = StableDiffusionInpaintPipeline.from_pretrained(
15
+ "runwayml/stable-diffusion-inpainting", torch_dtype=torch.float32
16
  )
17
  clipaway = CLIPAway(
18
  sd_pipe=sd_pipeline,
 
26
  )
27
 
28
  def dilate_mask(mask, kernel_size=5, iterations=5):
29
+ mask = mask.convert("L")
30
  kernel = np.ones((kernel_size, kernel_size), np.uint8)
31
  mask = cv2.dilate(np.array(mask), kernel, iterations=iterations)
32
  return Image.fromarray(mask)
33
 
34
+ def combine_masks(uploaded_mask, sketched_mask):
35
+ if uploaded_mask is not None:
36
+ return uploaded_mask
37
+ elif sketched_mask is not None:
38
+ return sketched_mask
39
+ else:
40
+ raise ValueError("Please provide a mask")
41
+
42
  def remove_obj(image, uploaded_mask, seed):
43
  image_pil, sketched_mask = image["image"], image["mask"]
44
  mask = dilate_mask(combine_masks(uploaded_mask, sketched_mask))