bedead committed · Commit 6467917 · verified · 1 Parent(s): 225113d

Upload folder using huggingface_hub

README.md CHANGED
@@ -5,10 +5,10 @@ colorFrom: green
 colorTo: red
 sdk: gradio
 python_version: 3.10.13
-sdk_version: 3.48.0
+sdk_version: 4.36.1
 app_file: app.py
 pinned: false
 license: mit
 ---
 
-Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
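
Note on the `sdk_version` bump: for a Gradio Space, `sdk_version` in the README front matter selects the Gradio runtime the Space is built with, so moving from 3.48.0 to 4.36.1 switches the app onto the Gradio 4.x API. A minimal startup guard, purely illustrative and not part of this commit, could make a runtime mismatch fail loudly:

```python
# Hypothetical guard (not in the commit): fail fast if the Space runtime
# provides a Gradio major version the app was not written for.
import gradio as gr

major = int(gr.__version__.split(".")[0])
if major < 4:
    raise RuntimeError(f"This app targets Gradio 4.x, but found {gr.__version__}")
```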
app.py CHANGED
@@ -1,5 +1,5 @@
-import spaces
 import gradio as gr
+import sys
 import torch
 from omegaconf import OmegaConf
 from PIL import Image
@@ -9,10 +9,18 @@ import cv2
 import numpy as np
 import argparse
 
+# Parse command line arguments
+parser = argparse.ArgumentParser()
+parser.add_argument("--config", type=str, default="config/inference_config.yaml", help="Path to the config file")
+parser.add_argument("--share", action="store_true", help="Share the interface if provided")
+args = parser.parse_args()
+
+device = 'cuda' if torch.cuda.is_available() else 'cpu'
+
 # Load configuration and models
-config = OmegaConf.load("config/inference_config.yaml")
+config = OmegaConf.load(args.config)
 sd_pipeline = StableDiffusionInpaintPipeline.from_pretrained(
-    "runwayml/stable-diffusion-inpainting", torch_dtype=torch.float32
+    "runwayml/stable-diffusion-inpainting", safety_checker=None, torch_dtype=torch.float32
 )
 clipaway = CLIPAway(
     sd_pipe=sd_pipeline,
@@ -21,7 +29,7 @@ clipaway = CLIPAway(
     alpha_clip_path=config.alpha_clip_ckpt_pth,
     config=config,
     alpha_clip_id=config.alpha_clip_id,
-    device="cpu",
+    device=device,
     num_tokens=4
 )
 
@@ -43,7 +51,7 @@ def remove_obj(image, uploaded_mask, seed):
     image_pil, sketched_mask = image["image"], image["mask"]
     mask = dilate_mask(combine_masks(uploaded_mask, sketched_mask))
     seed = int(seed)
-    latents = torch.randn((1, 4, 64, 64), generator=torch.Generator().manual_seed(seed)).to("cpu")
+    latents = torch.randn((1, 4, 64, 64), generator=torch.Generator().manual_seed(seed)).to(device)
     final_image = clipaway.generate(
         prompt=[""], scale=1, seed=seed,
         pil_image=[image_pil], alpha=[mask], strength=1, latents=latents
@@ -52,29 +60,29 @@ def remove_obj(image, uploaded_mask, seed):
 
 # Define example data
 examples = [
-    ["gradio_examples/images/1.jpg", "gradio_examples/masks/1.png", 42],
-    ["gradio_examples/images/2.jpg", "gradio_examples/masks/2.png", 42],
-    ["gradio_examples/images/3.jpg", "gradio_examples/masks/3.png", 464],
+    ["assets/gradio_examples/images/1.jpg", "assets/gradio_examples/masks/1.png", 42],
+    ["assets/gradio_examples/images/2.jpg", "assets/gradio_examples/masks/2.png", 42],
+    ["assets/gradio_examples/images/3.jpg", "assets/gradio_examples/masks/3.png", 2024],
 ]
 
+# Define the Gradio interface
 with gr.Blocks() as demo:
     gr.Markdown("<h1 style='text-align:center'>CLIPAway: Harmonizing Focused Embeddings for Removing Objects via Diffusion Models</h1>")
     gr.Markdown("""
     <div style='display:flex; justify-content:center; align-items:center;'>
-        <a href='https://arxiv.org/abs/2406.09368' style="margin-right:10px;">Paper</a> |
+        <a href='https://arxiv.org/abs/2406.09368' style="margin:10px;">Paper</a> |
         <a href='https://yigitekin.github.io/CLIPAway/' style="margin:10px;">Project Website</a> |
-        <a href='https://github.com/YigitEkin/CLIPAway' style="margin-left:10px;">GitHub</a>
+        <a href='https://github.com/YigitEkin/CLIPAway' style="margin:10px;">GitHub</a>
     </div>
     """)
     gr.Markdown("""
     This application allows you to remove objects from images using the CLIPAway method with diffusion models.
     To use this tool:
-    1. Upload an image. (NOTE: We expect a 512x512 image, if you upload a different size, it will be resized to 512x512 which can affect the results.)
-    2. Upload a pre-defined mask if you have one. (If you don't have a mask, and want to sketch one,
-       we have provided a gradio demo in our github repository. <br/> Unfortunately, we cannot provide it here due to the compatibility issues with zerogpu.)
-    3. Set the seed for reproducibility (default is 42).
-    4. Click 'Remove Object' to process the image.
-    5. The result will be displayed on the right side.
+    1. Upload an image.
+    2. Either sketch a mask over the object you want to remove or upload a pre-defined mask if you have one.
+    3. Set the seed for reproducibility (default is 42).
+    4. Click 'Remove Object' to process the image.
+    5. The result will be displayed on the right side.
     Note: The mask should be a binary image where the object to be removed is white and the background is black.
     """)
 
@@ -89,10 +97,10 @@ with gr.Blocks() as demo:
 
     process_button.click(
         fn=remove_obj,
-        inputs=[image_input, seed_input],
+        inputs=[image_input, uploaded_mask, seed_input],
         outputs=result_image
     )
 
+# Launch the interface with caching
 
-
-demo.launch(share=True)
+demo.launch(share=True)
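
The functional changes above are: device selection at import time instead of hard-coding "cpu", a CLI flag for the config path, passing the uploaded mask through to `remove_obj`, and disabling the safety checker on the inpainting pipeline. A minimal, self-contained sketch of the device/seed pattern the commit adopts (assumes only PyTorch; the 1x4x64x64 latent shape matches Stable Diffusion's latent grid for a 512x512 image):

```python
# Sketch of the device/seed pattern introduced in app.py (PyTorch only).
import torch

device = "cuda" if torch.cuda.is_available() else "cpu"

# Seed a CPU generator for reproducibility, then move the sampled latents
# to whichever device the pipeline runs on.
generator = torch.Generator().manual_seed(42)
latents = torch.randn((1, 4, 64, 64), generator=generator).to(device)
print(latents.shape, latents.device)
```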
 
model/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (194 Bytes).

model/__pycache__/attention_processor.cpython-310.pyc ADDED
Binary file (3.94 kB).

model/__pycache__/clip_away.cpython-310.pyc ADDED
Binary file (8.85 kB).

model/__pycache__/utils.cpython-310.pyc ADDED
Binary file (2.91 kB).
model/clip_away.py CHANGED
@@ -38,8 +38,8 @@ class ImageProjModel(torch.nn.Module):
         return clip_extra_context_tokens
 
 class CLIPAway:
-    def __init__(self, sd_pipe, image_encoder_path, ip_ckpt, alpha_clip_path, config, alpha_clip_id="ViT-L/14", device="cuda", num_tokens=4):
-        super().__init__()
+    def __init__(self, sd_pipe, image_encoder_path, ip_ckpt, alpha_clip_path, config, device, alpha_clip_id="ViT-L/14", num_tokens=4):
+        super().__init__()
         self.device = device
         self.ipadapter_image_encoder_path = image_encoder_path
         self.ipadapter_ckpt = ip_ckpt
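
The only change to `CLIPAway.__init__` is that `device` loses its `"cuda"` default and becomes a required parameter, so every caller must state the device explicitly (as app.py now does with `device=device`). A toy illustration of the behavioural difference, using hypothetical stand-in functions rather than the real class:

```python
# Toy stand-ins (not the real CLIPAway class) showing why dropping the
# default matters on CPU-only machines such as the Space runtime.
def old_init(sd_pipe, alpha_clip_id="ViT-L/14", device="cuda", num_tokens=4):
    return device

def new_init(sd_pipe, device, alpha_clip_id="ViT-L/14", num_tokens=4):
    return device

print(old_init(sd_pipe=None))                  # silently "cuda", even without a GPU
print(new_init(sd_pipe=None, device="cpu"))    # caller must choose explicitly
# new_init(sd_pipe=None)                       # TypeError: missing 'device'
```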
requirements.txt CHANGED
@@ -15,5 +15,4 @@ transformers==4.39.3
 git+https://github.com/openai/CLIP.git
 git+https://github.com/tencent-ailab/IP-Adapter.git
 git+https://github.com/SunzeY/AlphaCLIP.git
-loralib
-gradio==4.44.1
+loralib
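
Dropping the `gradio==4.44.1` pin fits the Space setup: with `sdk: gradio`, the runtime's Gradio version is governed by `sdk_version` in README.md (now 4.36.1), and a conflicting pin in requirements.txt can fight that choice. A small, purely illustrative check, not part of the commit, that the pin stays out:

```python
# Hypothetical CI-style check: requirements.txt should not pin gradio,
# since the Space's README front matter already fixes the SDK version.
from pathlib import Path

lines = Path("requirements.txt").read_text().splitlines()
pinned = [l for l in lines if l.strip().lower().startswith("gradio")]
assert not pinned, f"remove gradio pin(s) from requirements.txt: {pinned}"
```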