"""Gradio demo for Stable Diffusion v1-5 driven by the VIDIT-FAID ControlNet.

The user supplies a prompt, a negative prompt and a conditioning image; the
image is reduced to a 512x512 grayscale map and fed to the Flax ControlNet
pipeline, whose generated images are shown in a gallery.
"""

import gradio as gr
import jax
import numpy as np
import jax.numpy as jnp
from flax.jax_utils import replicate
from flax.training.common_utils import shard
from PIL import Image
from diffusers import FlaxStableDiffusionControlNetPipeline, FlaxControlNetModel
import cv2

# Load the ControlNet and Stable Diffusion v1-5 (Flax weights, bfloat16).
controlnet, controlnet_params = FlaxControlNetModel.from_pretrained(
    "Nahrawy/controlnet-VIDIT-FAID",
    dtype=jnp.bfloat16,
    revision="615ba4a457b95a0eba813bcc8caf842c03a4f7bd",
)
pipe, params = FlaxStableDiffusionControlNetPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5",
    controlnet=controlnet,
    revision="flax",
    dtype=jnp.bfloat16,
)
# The pipeline expects the ControlNet weights under the "controlnet" key.
# This only has to happen once, so it is done here rather than on every call.
params["controlnet"] = controlnet_params

# Markdown heading rendered at the top of the page.
title = "# ControlNet VIDIT-FAID: light-conditioned image generation"


def create_key(seed=0):
    """Return a JAX PRNG key built from ``seed``."""
    return jax.random.PRNGKey(seed)


def process_mask(image):
    """Convert a conditioning image to a 512x512 single-channel array.

    Args:
        image: NumPy image array as delivered by ``gr.Image``; either 2-D
            (already grayscale) or 3-D with 1, 3 (RGB) or 4 (RGBA) channels.

    Returns:
        The grayscale image resized to 512x512.
    """
    if image.ndim == 2:
        mask = image
    elif image.shape[2] == 1:
        mask = image[:, :, 0]
    elif image.shape[2] == 4:
        mask = cv2.cvtColor(image, cv2.COLOR_RGBA2GRAY)
    else:
        # Gradio hands over RGB data, not OpenCV's native BGR ordering.
        mask = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)
    return cv2.resize(mask, (512, 512))


def infer(prompts, negative_prompts, image):
    """Generate images for a prompt conditioned on ``image``.

    Args:
        prompts: Text prompt describing the desired image.
        negative_prompts: Text describing what the image should avoid.
        image: Conditioning image as a NumPy array (see ``process_mask``).

    Returns:
        A list of PIL images, one per JAX device.

    Raises:
        gr.Error: If no conditioning image was provided.
    """
    if image is None:
        raise gr.Error("Please provide a conditioning image.")

    # One sample per device: ``shard`` splits the leading batch axis across
    # all devices, so the batch size must match the device count.
    num_devices = jax.device_count()
    num_samples = num_devices
    rng = jax.random.split(create_key(0), num_devices)

    mask = Image.fromarray(process_mask(image))

    prompt_ids = pipe.prepare_text_inputs([prompts] * num_samples)
    negative_prompt_ids = pipe.prepare_text_inputs([negative_prompts] * num_samples)
    processed_image = pipe.prepare_image_inputs([mask] * num_samples)

    # Replicate the weights and shard the inputs for the pmapped pipeline.
    p_params = replicate(params)
    prompt_ids = shard(prompt_ids)
    negative_prompt_ids = shard(negative_prompt_ids)
    processed_image = shard(processed_image)

    output = pipe(
        prompt_ids=prompt_ids,
        image=processed_image,
        params=p_params,
        prng_seed=rng,
        num_inference_steps=50,
        neg_prompt_ids=negative_prompt_ids,
        jit=True,
    ).images

    # Collapse the (devices, per-device batch) axes into one batch axis.
    flat_output = output.reshape((num_samples,) + output.shape[-3:])
    return pipe.numpy_to_pil(np.asarray(flat_output))


e_images = ['0.png', '1.png', '2.png']
e_prompts = [
    'a dog in the middle of the road, shadow on the ground,light direction north-east',
    'a skyscraper in the middle of an intersection, shadow on the ground, light direction east',
    'a red rural house, light temperature 5500, shadow on the ground, light direction south-west',
]
# The same negative prompt is used for every example.
e_negative_prompts = ['monochromatic, unrealistic, bad looking, full of glitches'] * 3

# Each example row follows the order of the Gradio inputs:
# [prompt, negative prompt, image].
examples = [
    [prompt, negative_prompt, image]
    for image, prompt, negative_prompt in zip(e_images, e_prompts, e_negative_prompts)
]

with gr.Blocks() as demo:
    gr.Markdown(title)
    prompts = gr.Textbox(label='prompts')
    negative_prompts = gr.Textbox(label='negative_prompts')
    with gr.Row():
        with gr.Column():
            in_image = gr.Image(label="Depth Map Conditioning")
        with gr.Column():
            out_image = gr.Gallery(label="Generated Image")
    with gr.Row():
        btn = gr.Button("Run")
    # Caching examples requires the function that produces their outputs.
    gr.Examples(
        examples=examples,
        fn=infer,
        inputs=[prompts, negative_prompts, in_image],
        outputs=out_image,
        cache_examples=True,
    )
    btn.click(fn=infer, inputs=[prompts, negative_prompts, in_image], outputs=out_image)

demo.launch()