Spaces:
Runtime error
Runtime error
File size: 7,914 Bytes
a324479 42d64c8 a324479 e6915e1 a324479 42d64c8 169ec0c 892096a 6f417f5 892096a 6f417f5 ddb9f2a 42d64c8 ddb9f2a 892096a 6f417f5 ddb9f2a 6f417f5 e6915e1 892096a 6f417f5 169ec0c 6f417f5 a324479 892096a e6915e1 a324479 d2262bb e6915e1 fef68a2 e6915e1 169ec0c 6f417f5 e6915e1 6f417f5 e6915e1 6f417f5 169ec0c 892096a e6915e1 892096a e6915e1 169ec0c e6915e1 169ec0c 892096a 169ec0c 892096a 169ec0c 6f417f5 169ec0c a324479 169ec0c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 |
import gradio as gr
import jax
from PIL import Image
from flax.jax_utils import replicate
from flax.training.common_utils import shard
from diffusers import FlaxControlNetModel, FlaxStableDiffusionControlNetPipeline
import jax.profiler
import jax.numpy as jnp
import numpy as np
import gc
# Hub checkpoints used by the demo and the dtype both are loaded with.
CONTROLNET_CHECKPOINT = "mfidabel/controlnet-segment-anything"
BASE_CHECKPOINT = "runwayml/stable-diffusion-v1-5"
WEIGHTS_DTYPE = jnp.float32

# Each from_pretrained call hands back the model object plus its parameters.
controlnet, controlnet_params = FlaxControlNetModel.from_pretrained(
    CONTROLNET_CHECKPOINT,
    dtype=WEIGHTS_DTYPE,
)
pipe, params = FlaxStableDiffusionControlNetPipeline.from_pretrained(
    BASE_CHECKPOINT,
    controlnet=controlnet,
    revision="flax",
    dtype=WEIGHTS_DTYPE,
)

# Store the ControlNet weights under the "controlnet" key, then replicate the
# whole parameter tree; `p_params` is what `infer` passes to the pipeline.
params["controlnet"] = controlnet_params
p_params = replicate(params)

# Dump a device-memory profile once the weights are loaded.
jax.profiler.save_device_memory_profile("memory.prof")
# Page heading: a Markdown H1 string, rendered with gr.Markdown(title).
title = "# 🧨 ControlNet on Segment Anything 🤗"
# Introductory Markdown rendered with gr.Markdown(description): what the demo
# does, a rough timing note, the Colab link for producing segmentation maps,
# and acknowledgements. This is user-facing text; edit with care.
description = """This is a demo on 🧨 ControlNet based on Meta's [Segment Anything Model](https://segment-anything.com/).
Upload a Segment Anything Segmentation Map, write a prompt, and generate images 🤗 This demo is still a Work in Progress, so don't expect it to work well for now !!
⌛️ It takes about 30~ seconds to generate 4 samples, to get faster results, don't forget to reduce the Nº Samples to 1.
You can obtain the Segmentation Map of any Image through this Colab: [![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/mfidabel/JAX_SPRINT_2023/blob/main/Segment_Anything_JAX_SPRINT.ipynb)
A huge thanks goes out to @Google Cloud, for providing us with powerful TPUs that enabled us to train this model; and to the @HuggingFace Team for organizing the sprint.
"""
# Markdown for the "about" section at the bottom of the page, rendered with
# gr.Markdown(about, elem_classes="about"): background on the model, how the
# training dataset was built, and links to the published datasets.
about = """
# 👨💻 About the model
This model is based on the [ControlNet Model](https://huggingface.co/blog/controlnet), which allow us to generate Images using some sort of condition image. For this model, we selected the segmentation maps produced by Meta's new segmentation model called [Segment Anything Model](https://github.com/facebookresearch/segment-anything) as the condition image. We then trained the model to generate images based on the structure of the segmentation maps and the text prompts given.
# 💾 About the dataset
For the training, we generated a segmented dataset based on the [COYO-700M](https://huggingface.co/datasets/kakaobrain/coyo-700m) dataset. The dataset provided us with the images, and the text prompts. For the segmented images, we used [Segment Anything Model](https://github.com/facebookresearch/segment-anything). We then created 8k samples to train our model on, which isn't a lot, but as a team, we have been very busy with many other responsibilities and time constraints, which made it challenging to dedicate a lot of time to generating a larger dataset. Despite the constraints we faced, we have still managed to achieve some nice results 🙌
You can check the generated datasets below ⬇️
- [sam-coyo-2k](https://huggingface.co/datasets/mfidabel/sam-coyo-2k)
- [sam-coyo-2.5k](https://huggingface.co/datasets/mfidabel/sam-coyo-2.5k)
- [sam-coyo-3k](https://huggingface.co/datasets/mfidabel/sam-coyo-3k)
"""
# Rows offered in the examples table. Columns, in order:
# [prompt, negative prompt, path to the condition image].
examples = [
    [
        "contemporary living room of a house",
        "low quality",
        "examples/condition_image_1.png",
    ],
    [
        "new york buildings, Vincent Van Gogh starry night ",
        "low quality, monochrome",
        "examples/condition_image_2.png",
    ],
    [
        "contemporary living room, high quality, 4k, realistic",
        "low quality, monochrome, low res",
        "examples/condition_image_3.png",
    ],
    [
        "internal stairs of a japanese house",
        "low quality, low res, people, kids",
        "examples/condition_image_4.png",
    ],
    [
        "a photo of a girl taking notes",
        "low quality, low res, painting",
        "examples/condition_image_5.png",
    ],
    [
        "painting of an hot air ballon flying over a valley, The Great Wave off Kanagawa style, blue and white colors",
        "low quality, low res",
        "examples/condition_image_6.png",
    ],
    [
        "painting of families enjoying the sunset, The Garden of Earthly Delights style, joyful",
        "low quality, low res",
        "examples/condition_image_7.png",
    ],
]

# Stylesheet handed to gr.Blocks: centres H1 headings and justifies/pads
# anything carrying the "about" class.
css = (
    "h1 { text-align: center } "
    ".about { text-align: justify; padding-left: 10%; padding-right: 10%; }"
)
# Inference Function
def infer(prompts, negative_prompts, image, num_inference_steps=50, seed=4, num_samples=4):
    """Generate images from a prompt pair and a condition image.

    Args:
        prompts: Text prompt; the same string is used for every sample.
        negative_prompts: Negative text prompt, likewise repeated per sample.
        image: Condition image as an array that ``Image.fromarray`` accepts in
            RGB mode, or ``None`` when no image was supplied.
        num_inference_steps: Number of inference steps (coerced to ``int``).
        seed: Seed for the JAX PRNG key (coerced to ``int``).
        num_samples: Number of images wanted (coerced to ``int``, minimum 1).

    Returns:
        A list of exactly ``num_samples`` uint8 image arrays. If anything
        fails during generation, the error is printed and black 512x512 RGB
        placeholders are returned instead of raising.
    """
    # Coerce up front so the fallback below can always build its list, even
    # when the failure happens on the very first line of the try block
    # (multiplying a list by a float would itself raise).
    requested = max(1, int(num_samples))
    try:
        if image is None:
            raise ValueError("no condition image was provided")

        n_devices = jax.device_count()
        # shard() splits the batch evenly across devices, so the batch must be
        # a multiple of the device count: round the request up, trim later.
        batch_size = -(-requested // n_devices) * n_devices

        rng = jax.random.PRNGKey(int(seed))
        steps = int(num_inference_steps)
        condition = Image.fromarray(image, mode="RGB")
        p_rng = jax.random.split(rng, n_devices)

        prompt_ids = shard(pipe.prepare_text_inputs([prompts] * batch_size))
        negative_prompt_ids = shard(
            pipe.prepare_text_inputs([negative_prompts] * batch_size)
        )
        processed_image = shard(pipe.prepare_image_inputs([condition] * batch_size))

        output = pipe(
            prompt_ids=prompt_ids,
            image=processed_image,
            params=p_params,
            prng_seed=p_rng,
            num_inference_steps=steps,
            neg_prompt_ids=negative_prompt_ids,
            jit=True,
        ).images

        # Drop the sharded inputs before post-processing.
        del negative_prompt_ids
        del processed_image
        del prompt_ids

        # Collapse the leading axes into one flat batch axis.
        output = output.reshape((batch_size,) + output.shape[-3:])
        print(output.shape)
        # Only hand back as many images as were asked for; the remainder was
        # padding needed to fill every device.
        final_image = [np.array(x * 255, dtype=np.uint8) for x in output[:requested]]
        del output
    except Exception as e:
        # Boundary of the UI callback: report the problem and show blanks
        # rather than surfacing a stack trace to the user.
        print("Error: " + str(e))
        final_image = [np.zeros((512, 512, 3), dtype=np.uint8)] * requested
    finally:
        gc.collect()
    return final_image
# The sixth example row pre-fills the form when the page loads.
default_example = examples[5]
default_prompt, default_negative, default_condition = default_example

# These components are created before the layout exists because gr.Examples
# refers to them ahead of the rows in which they are placed with .render().
cond_img = gr.Image(
    label="Input",
    shape=(512, 512),
    value=default_condition,
).style(height=200)
output = gr.Gallery(label="Generated images").style(
    height=200,
    rows=[2],
    columns=[2],
    object_fit="contain",
)
prompt = gr.Textbox(lines=1, label="Prompt", value=default_prompt)
negative_prompt = gr.Textbox(lines=1, label="Negative Prompt", value=default_negative)

# NOTE(review): the nesting below was reconstructed from a copy without
# indentation - confirm the Generate button belongs outside the accordion.
with gr.Blocks(css=css) as demo:
    # Header row: title and description on the left, examples on the right.
    with gr.Row():
        with gr.Column():
            gr.Markdown(title)
            gr.Markdown(description)

        with gr.Column():
            gr.Markdown("Try some of the examples below ⬇️")
            gr.Examples(
                examples=examples,
                inputs=[prompt, negative_prompt, cond_img],
                outputs=output,
                fn=infer,
                examples_per_page=4,
            )

    # Condition image next to the gallery of results.
    with gr.Row(variant="panel"):
        with gr.Column(scale=2):
            cond_img.render()
        with gr.Column(scale=1):
            output.render()

    # Prompts on the left; advanced options and the submit button on the right.
    with gr.Row():
        with gr.Column():
            prompt.render()
            negative_prompt.render()

        with gr.Column():
            with gr.Accordion("Advanced options", open=False):
                num_steps = gr.Slider(minimum=10, maximum=60, value=50, step=1, label="Steps")
                seed = gr.Slider(minimum=0, maximum=1024, value=4, step=1, label="Seed")
                num_samples = gr.Slider(minimum=1, maximum=4, value=4, step=1, label="Nº Samples")

            submit = gr.Button("Generate")
            # TODO: Download Button

    # Background information about the model and dataset.
    with gr.Row():
        gr.Markdown(about, elem_classes="about")

    # Input order matches infer's positional parameters.
    submit.click(
        fn=infer,
        inputs=[prompt, negative_prompt, cond_img, num_steps, seed, num_samples],
        outputs=output,
    )

demo.queue()
demo.launch()