bedead commited on
Commit
2c45810
·
verified ·
1 Parent(s): bd7a154

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +13 -8
app.py CHANGED
@@ -31,15 +31,21 @@ def dilate_mask(mask, kernel_size=5, iterations=5):
31
  mask = cv2.dilate(np.array(mask), kernel, iterations=iterations)
32
  return Image.fromarray(mask)
33
 
34
- @spaces.GPU
35
- def remove_obj(image, uploaded_mask, seed):
36
- image_pil = image.resize((512, 512), Image.LANCZOS)
 
 
 
37
  mask = dilate_mask(uploaded_mask)
38
  seed = int(seed)
39
- latents = torch.randn((1, 4, 64, 64), generator=torch.Generator().manual_seed(seed)).to("cuda")
40
  final_image = clipaway.generate(
41
  prompt=[""], scale=1, seed=seed,
42
- pil_image=[image_pil], alpha=[mask], strength=1, latents=latents
 
 
 
43
  )[0]
44
  return final_image
45
 
@@ -73,8 +79,7 @@ with gr.Blocks(theme="gradio/monochrome") as demo:
73
 
74
  with gr.Row():
75
  with gr.Column():
76
- image_input = gr.Image(label="Upload Image and Sketch Mask", type="pil", tool="sketch")
77
- uploaded_mask = gr.Image(label="Upload Mask", type="pil", image_mode="L")
78
  seed_input = gr.Number(value=42, label="Seed")
79
  process_button = gr.Button("Remove Object")
80
  with gr.Column():
@@ -82,7 +87,7 @@ with gr.Blocks(theme="gradio/monochrome") as demo:
82
 
83
  process_button.click(
84
  fn=remove_obj,
85
- inputs=[image_input, uploaded_mask, seed_input],
86
  outputs=result_image
87
  )
88
 
 
31
  mask = cv2.dilate(np.array(mask), kernel, iterations=iterations)
32
  return Image.fromarray(mask)
33
 
34
+ def remove_obj(image, seed):
35
+ alpha_channel = img["layers"][0][:, :, 3]
36
+ mask = np.where(alpha_channel == 0, 0, 255).astype(np.uint8)
37
+ uploaded_mask = Image.fromarray(mask)
38
+ background = Image.fromarray(img["background"])
39
+
40
  mask = dilate_mask(uploaded_mask)
41
  seed = int(seed)
42
+ latents = torch.randn((1, 4, 64, 64), generator=torch.Generator().manual_seed(seed)).to("cpu")
43
  final_image = clipaway.generate(
44
  prompt=[""], scale=1, seed=seed,
45
+ pil_image=[background],
46
+ alpha=[mask],
47
+ strength=1,
48
+ latents=latents
49
  )[0]
50
  return final_image
51
 
 
79
 
80
  with gr.Row():
81
  with gr.Column():
82
+ image_input = gr.ImageMask(label="Upload Image and Sketch Mask", type="pil")
 
83
  seed_input = gr.Number(value=42, label="Seed")
84
  process_button = gr.Button("Remove Object")
85
  with gr.Column():
 
87
 
88
  process_button.click(
89
  fn=remove_obj,
90
+ inputs=[image_input, seed_input],
91
  outputs=result_image
92
  )
93