"""Gradio demo: generate images with a Segment-Anything-conditioned Flax ControlNet.

An uploaded image is segmented with SAM's automatic mask generator. The
colour-coded segmentation map is then used as the conditioning image for a
Flax Stable Diffusion ControlNet pipeline.
"""
import colorsys
import gc
import logging

import gradio as gr
import jax
import jax.numpy as jnp
import numpy as np
import torch
from diffusers import (  # noqa: F401  (some names are currently unused)
    FlaxControlNetModel,
    FlaxStableDiffusionControlNetPipeline,
    StableDiffusionInpaintPipeline,
    UniPCMultistepScheduler,
)
from flax.jax_utils import replicate
from flax.training.common_utils import shard
from PIL import Image
from segment_anything import SamAutomaticMaskGenerator, SamPredictor, sam_model_registry

sam_checkpoint = "sam_vit_h_4b8939.pth"
model_type = "vit_h"
device = "cpu"

sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
sam.to(device=device)
# Point-prompted predictor; not used by the current UI.
predictor = SamPredictor(sam)
mask_generator = SamAutomaticMaskGenerator(sam)

controlnet, controlnet_params = FlaxControlNetModel.from_pretrained(
    "mfidabel/controlnet-segment-anything", dtype=jnp.float32
)
pipe, params = FlaxStableDiffusionControlNetPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5",
    controlnet=controlnet,
    revision="flax",
    dtype=jnp.bfloat16,
)
# Flax pipelines are stateless: weights (and the scheduler state) live in
# `params`, which must be replicated to every device for the jitted call.
# NOTE: a PyTorch scheduler (e.g. UniPCMultistepScheduler) cannot be assigned
# to a Flax pipeline, and Flax pipelines have no `.to(device)`; placement is
# handled by JAX.
params["controlnet"] = controlnet_params
p_params = replicate(params)


def _empty_cuda_cache():
    """Release torch's cached CUDA memory when a GPU is present."""
    if torch.cuda.is_available():
        torch.cuda.empty_cache()


def generate_mask(image):
    """Return a colour-coded SAM segmentation map for *image*.

    Each mask produced by ``SamAutomaticMaskGenerator`` is painted with its own
    hue. The resulting map is the ControlNet conditioning image consumed by
    :func:`infer`.

    Args:
        image: The uploaded image as delivered by ``gr.Image`` (numpy,
            ``H x W x 3``), or ``None`` when the input has been cleared.

    Returns:
        An ``H x W x 3`` uint8 array, or ``None`` if *image* is ``None``.
    """
    if image is None:
        return None

    segs = mask_generator.generate(image)
    finseg = np.zeros((image.shape[0], image.shape[1], 3), dtype=np.uint8)
    # Paint larger masks first so smaller objects lying on top stay visible.
    # Plain assignment (rather than adding colours) avoids uint8 wrap-around
    # where SAM masks overlap.
    segs = sorted(segs, key=lambda s: s["area"], reverse=True)
    for class_id, seg in enumerate(segs):
        hue = class_id / len(segs)
        rgb = [int(c * 255) for c in colorsys.hsv_to_rgb(hue, 1, 1)]
        finseg[seg["segmentation"]] = rgb

    _empty_cuda_cache()
    return finseg


def infer(
    image, prompts, negative_prompts, num_inference_steps=50, seed=4, num_samples=4
):
    """Generate images from a segmentation map with the Flax ControlNet pipeline.

    Args:
        image: Conditioning segmentation map (numpy array from the "Mask" panel).
        prompts: Positive text prompt.
        negative_prompts: Negative text prompt.
        num_inference_steps: Number of denoising steps.
        seed: Seed for the JAX PRNG.
        num_samples: Requested number of images. It is rounded up to a multiple
            of ``jax.device_count()``, because the batch is sharded across
            devices.

    Returns:
        A list of ``num_samples`` uint8 ``H x W x 3`` arrays. On failure the
        error is logged and 512x512 black images are returned instead.
    """
    n_devices = jax.device_count()
    # `shard` splits the leading batch axis across devices, so the batch size
    # must be a positive multiple of the device count.
    num_samples = max(1, int(num_samples))
    num_samples = -(-num_samples // n_devices) * n_devices

    try:
        if image is None:
            raise ValueError("no segmentation map available; upload an image first")

        rng = jax.random.PRNGKey(int(seed))
        p_rng = jax.random.split(rng, n_devices)
        cond_image = Image.fromarray(np.asarray(image, dtype=np.uint8)).convert("RGB")

        prompt_ids = shard(pipe.prepare_text_inputs([prompts] * num_samples))
        negative_prompt_ids = shard(
            pipe.prepare_text_inputs([negative_prompts] * num_samples)
        )
        processed_image = shard(pipe.prepare_image_inputs([cond_image] * num_samples))

        output = pipe(
            prompt_ids=prompt_ids,
            image=processed_image,
            params=p_params,
            prng_seed=p_rng,
            num_inference_steps=int(num_inference_steps),
            neg_prompt_ids=negative_prompt_ids,
            jit=True,
        ).images
        # Drop large intermediates before the gc.collect() in `finally`.
        del prompt_ids, negative_prompt_ids, processed_image

        # (devices, per_device, H, W, 3) -> (num_samples, H, W, 3)
        output = np.asarray(output.reshape((num_samples,) + output.shape[-3:]))
        logging.debug("generated batch of shape %s", output.shape)
        final_image = [np.clip(x * 255, 0, 255).astype(np.uint8) for x in output]
        del output
    except Exception:
        # UI boundary: keep the app responsive, but keep the full traceback.
        logging.exception("Inference failed; returning blank images")
        final_image = [np.zeros((512, 512, 3), dtype=np.uint8)] * num_samples
    finally:
        gc.collect()
    return final_image


def _clear():
    """Reset the UI: empty images and gallery, blank prompts."""
    return None, None, None, "", ""


with gr.Blocks() as demo:
    gr.Markdown("# WildSynth: Synthetic Wildlife Data Generation")
    gr.Markdown(
        """
This demo uses a JAX ControlNet model (`mfidabel/controlnet-segment-anything`)
that conditions Stable Diffusion v1.5 on Segment Anything (SAM) segmentation maps.

To try the demo, upload an image: its SAM segmentation map is computed
automatically and shown in the "Mask" panel.
Write a prompt & a negative prompt to control the generation.
Click on the "Submit" button to generate images from the segmentation map.
If the demo is slow, clone the space to your own HF account and run on a GPU.
"""
    )
    with gr.Row():
        input_img = gr.Image(label="Input")
        mask_img = gr.Image(label="Mask", interactive=False)
        # `infer` returns several samples, so the output is a gallery.
        output_img = gr.Gallery(label="Output")

    with gr.Row():
        prompt_text = gr.Textbox(lines=1, label="Prompt")
        negative_prompt_text = gr.Textbox(lines=1, label="Negative Prompt")

    with gr.Row():
        submit = gr.Button("Submit")
        clear = gr.Button("Clear")

    input_img.change(
        generate_mask,
        inputs=[input_img],
        outputs=[mask_img],
    )
    submit.click(
        infer,
        inputs=[mask_img, prompt_text, negative_prompt_text],
        outputs=[output_img],
    )
    clear.click(
        _clear,
        inputs=None,
        outputs=[
            input_img,
            mask_img,
            output_img,
            prompt_text,
            negative_prompt_text,
        ],
    )

if __name__ == "__main__":
    demo.queue()
    demo.launch()