Johannes committed on
Commit
eea614c
1 Parent(s): a4c0ed2

initial changes

README.md CHANGED
@@ -1,8 +1,8 @@
  ---
- title: SyntheticDataSAM
- emoji: 🐢
- colorFrom: gray
- colorTo: purple
+ title: ControlNet+SAM WildSynth
+ emoji: 🦬
+ colorFrom: green
+ colorTo: blue
  sdk: gradio
  sdk_version: 3.28.0
  app_file: app.py
__pycache__/controlnet_inpaint.cpython-310.pyc ADDED
Binary file (36.1 kB).
 
app.py ADDED
@@ -0,0 +1,190 @@
+ import colorsys
+ import gc
+
+ import gradio as gr
+ import jax
+ import jax.numpy as jnp
+ import numpy as np
+ import torch
+ from diffusers import FlaxControlNetModel, FlaxStableDiffusionControlNetPipeline
+ from flax.jax_utils import replicate
+ from flax.training.common_utils import shard
+ from PIL import Image
+ from segment_anything import SamPredictor, sam_model_registry, SamAutomaticMaskGenerator
+
+ # Load SAM once for both point-prompted prediction and automatic mask generation.
+ sam_checkpoint = "sam_vit_h_4b8939.pth"
+ model_type = "vit_h"
+ device = "cpu"
+
+ sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
+ sam.to(device=device)
+ predictor = SamPredictor(sam)
+ mask_generator = SamAutomaticMaskGenerator(sam)
+
+ # Load the segmentation-conditioned ControlNet and the Flax Stable Diffusion pipeline.
+ controlnet, controlnet_params = FlaxControlNetModel.from_pretrained(
+     "mfidabel/controlnet-segment-anything", dtype=jnp.float32
+ )
+
+ pipe, params = FlaxStableDiffusionControlNetPipeline.from_pretrained(
+     "runwayml/stable-diffusion-v1-5",
+     controlnet=controlnet,
+     revision="flax",
+     dtype=jnp.bfloat16,
+ )
+
+ # Replicate the parameters across all local devices for pmapped inference.
+ # Note: the Flax pipeline keeps the scheduler bundled with the checkpoint;
+ # PyTorch-style pipe.to(device) and scheduler swaps do not apply to Flax modules.
+ params["controlnet"] = controlnet_params
+ p_params = replicate(params)
+
+
+ with gr.Blocks() as demo:
+     gr.Markdown("# WildSynth: Synthetic Wildlife Data Generation")
+     gr.Markdown(
+         """
+         We have trained a JAX ControlNet model conditioned on Segment Anything (SAM) segmentation maps.
+         To try the demo, upload an image and click on the object(s) you want to inpaint.
+         Write a prompt and a negative prompt to control the inpainting.
+         Click the "Submit" button to inpaint the selected object(s).
+
+         If the demo is slow, clone the Space to your own HF account and run it on a GPU.
+         """
+     )
+     with gr.Row():
+         input_img = gr.Image(label="Input")
+         mask_img = gr.Image(label="Mask", interactive=False)
+         # A gallery, since inference produces one image per sample.
+         output_img = gr.Gallery(label="Output")
+
+     with gr.Row():
+         prompt_text = gr.Textbox(lines=1, label="Prompt")
+         negative_prompt_text = gr.Textbox(lines=1, label="Negative Prompt")
+
+     with gr.Row():
+         submit = gr.Button("Submit")
+         clear = gr.Button("Clear")
+
+     def generate_mask(image, evt: gr.SelectData):
+         # Predict a mask for the clicked object; SAM expects points of shape (N, 2).
+         predictor.set_image(image)
+         input_point = np.array([evt.index])
+         input_label = np.ones(input_point.shape[0])
+         selected_mask, _, _ = predictor.predict(
+             point_coords=input_point,
+             point_labels=input_label,
+             multimask_output=False,
+         )
+
+         # Colorize every automatically generated mask with a distinct hue.
+         segs = mask_generator.generate(image)
+         boolean_masks = [s["segmentation"] for s in segs]
+         finseg = np.zeros(
+             (boolean_masks[0].shape[0], boolean_masks[0].shape[1], 3), dtype=np.uint8
+         )
+         for class_id, boolean_mask in enumerate(boolean_masks):
+             hue = class_id * 1.0 / len(boolean_masks)
+             rgb = tuple(int(i * 255) for i in colorsys.hsv_to_rgb(hue, 1, 1))
+             rgb_mask = np.zeros(
+                 (boolean_mask.shape[0], boolean_mask.shape[1], 3), dtype=np.uint8
+             )
+             rgb_mask[:, :, 0] = boolean_mask * rgb[0]
+             rgb_mask[:, :, 1] = boolean_mask * rgb[1]
+             rgb_mask[:, :, 2] = boolean_mask * rgb[2]
+             finseg += rgb_mask
+
+         # Paint the clicked object white so the selection stands out in the map.
+         finseg[selected_mask[0]] = (255, 255, 255)
+
+         # Free cached GPU memory (a no-op when running on CPU).
+         torch.cuda.empty_cache()
+
+         return finseg
+
+     def infer(
+         image, prompts, negative_prompts, num_inference_steps=50, seed=4, num_samples=4
+     ):
+         try:
+             rng = jax.random.PRNGKey(int(seed))
+             num_inference_steps = int(num_inference_steps)
+             image = Image.fromarray(image, mode="RGB")
+             # Generate at least one sample per device so the batch shards evenly.
+             num_samples = max(jax.device_count(), int(num_samples))
+             p_rng = jax.random.split(rng, jax.device_count())
+
+             prompt_ids = pipe.prepare_text_inputs([prompts] * num_samples)
+             negative_prompt_ids = pipe.prepare_text_inputs(
+                 [negative_prompts] * num_samples
+             )
+             processed_image = pipe.prepare_image_inputs([image] * num_samples)
+
+             # Split the batch across local devices for the pmapped pipeline call.
+             prompt_ids = shard(prompt_ids)
+             negative_prompt_ids = shard(negative_prompt_ids)
+             processed_image = shard(processed_image)
+
+             output = pipe(
+                 prompt_ids=prompt_ids,
+                 image=processed_image,
+                 params=p_params,
+                 prng_seed=p_rng,
+                 num_inference_steps=num_inference_steps,
+                 neg_prompt_ids=negative_prompt_ids,
+                 jit=True,
+             ).images
+
+             del negative_prompt_ids
+             del processed_image
+             del prompt_ids
+
+             # Merge the (devices, per_device_batch) leading axes back into one.
+             output = output.reshape((num_samples,) + output.shape[-3:])
+             final_image = [np.array(x * 255, dtype=np.uint8) for x in output]
+             del output
+
+         except Exception as e:
+             print("Error: " + str(e))
+             final_image = [np.zeros((512, 512, 3), dtype=np.uint8)] * num_samples
+         finally:
+             gc.collect()
+         return final_image
+
+     def _clear(img, mask, out, prompt, neg_prompt):
+         # Reset every component to its empty state.
+         return None, None, None, "", ""
+
+     # SelectData carries the clicked pixel, so bind to .select rather than .change.
+     input_img.select(
+         generate_mask,
+         inputs=[input_img],
+         outputs=[mask_img],
+     )
+     submit.click(
+         infer,
+         inputs=[mask_img, prompt_text, negative_prompt_text],
+         outputs=[output_img],
+     )
+     clear.click(
+         _clear,
+         inputs=[
+             input_img,
+             mask_img,
+             output_img,
+             prompt_text,
+             negative_prompt_text,
+         ],
+         outputs=[
+             input_img,
+             mask_img,
+             output_img,
+             prompt_text,
+             negative_prompt_text,
+         ],
+     )
+
+
+ if __name__ == "__main__":
+     demo.queue()
+     demo.launch()
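
The replicate / shard pair in app.py is what makes the pipeline run data-parallel when called with jit=True: replicate copies the parameter tree to every device, shard splits a batch into a leading device axis, and per-device PRNG keys come from jax.random.split. A minimal, standalone sketch of that layout (the toy shapes and the pmapped lambda are illustrative only, not part of the app):

    import jax
    import numpy as np
    from flax.jax_utils import replicate
    from flax.training.common_utils import shard

    devices = jax.device_count()
    params = {"w": np.ones((4,))}       # toy parameter tree
    batch = np.zeros((2 * devices, 8))  # batch size must be divisible by devices

    p_params = replicate(params)        # adds a leading device axis: w -> (devices, 4)
    sharded = shard(batch)              # (2 * devices, 8) -> (devices, 2, 8)
    p_rng = jax.random.split(jax.random.PRNGKey(0), devices)  # one key per device

    # A pmapped function consumes one slice per device, just as
    # pipe(..., jit=True) does with prompt_ids, image and prng_seed.
    out = jax.pmap(lambda w, x: x @ np.ones((8, 4)) + w["w"])(p_params, sharded)
    print(out.shape)  # (devices, 2, 4); infer() merges these axes with reshape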
requirements.txt ADDED
@@ -0,0 +1,10 @@
+ torch
+ torchvision
+ git+https://github.com/facebookresearch/segment-anything.git
+ transformers
+ flax
+ jax[cuda11_pip]
+ -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
+ jaxlib
+ git+https://github.com/huggingface/diffusers@main
+ opencv-python
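
The "-f" find-links line applies to the whole requirements file, so a plain install is enough for pip to resolve the CUDA 11 JAX wheels from that extra index (assuming a CUDA 11 host; on a CPU-only Space the stock jax/jaxlib wheels are installed instead):

    pip install -r requirements.txt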
sam_vit_h_4b8939.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:a7bf3b02f3ebf1267aba913ff637d9a2d5c33d3173bb679e46d9f338c26f262e
+ size 2564550879
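
The committed .pth is only a Git LFS pointer; the actual SAM checkpoint is about 2.4 GB (the size field is in bytes). When cloning this repo outside the Hub, fetch the real weights through LFS, e.g.:

    git lfs install
    git lfs pull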