added combine_masks
Browse files
app.py
CHANGED
@@ -12,7 +12,7 @@ import argparse
|
|
12 |
# Load configuration and models
|
13 |
config = OmegaConf.load("config/inference_config.yaml")
|
14 |
sd_pipeline = StableDiffusionInpaintPipeline.from_pretrained(
|
15 |
-
"
|
16 |
)
|
17 |
clipaway = CLIPAway(
|
18 |
sd_pipe=sd_pipeline,
|
@@ -26,11 +26,19 @@ clipaway = CLIPAway(
|
|
26 |
)
|
27 |
|
28 |
def dilate_mask(mask, kernel_size=5, iterations=5):
|
29 |
-
mask = mask.convert("L")
|
30 |
kernel = np.ones((kernel_size, kernel_size), np.uint8)
|
31 |
mask = cv2.dilate(np.array(mask), kernel, iterations=iterations)
|
32 |
return Image.fromarray(mask)
|
33 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
34 |
def remove_obj(image, uploaded_mask, seed):
|
35 |
image_pil, sketched_mask = image["image"], image["mask"]
|
36 |
mask = dilate_mask(combine_masks(uploaded_mask, sketched_mask))
|
|
|
12 |
# Load configuration and models
|
13 |
config = OmegaConf.load("config/inference_config.yaml")
|
14 |
sd_pipeline = StableDiffusionInpaintPipeline.from_pretrained(
|
15 |
+
"runwayml/stable-diffusion-inpainting", torch_dtype=torch.float32
|
16 |
)
|
17 |
clipaway = CLIPAway(
|
18 |
sd_pipe=sd_pipeline,
|
|
|
26 |
)
|
27 |
|
28 |
def dilate_mask(mask, kernel_size=5, iterations=5):
|
29 |
+
mask = mask.convert("L")
|
30 |
kernel = np.ones((kernel_size, kernel_size), np.uint8)
|
31 |
mask = cv2.dilate(np.array(mask), kernel, iterations=iterations)
|
32 |
return Image.fromarray(mask)
|
33 |
|
34 |
+
def combine_masks(uploaded_mask, sketched_mask):
|
35 |
+
if uploaded_mask is not None:
|
36 |
+
return uploaded_mask
|
37 |
+
elif sketched_mask is not None:
|
38 |
+
return sketched_mask
|
39 |
+
else:
|
40 |
+
raise ValueError("Please provide a mask")
|
41 |
+
|
42 |
def remove_obj(image, uploaded_mask, seed):
|
43 |
image_pil, sketched_mask = image["image"], image["mask"]
|
44 |
mask = dilate_mask(combine_masks(uploaded_mask, sketched_mask))
|