File size: 7,914 Bytes
a324479
 
 
 
 
 
42d64c8
a324479
 
e6915e1
a324479
 
 
 
 
 
 
 
 
 
 
 
 
 
42d64c8
 
169ec0c
 
892096a
 
6f417f5
892096a
6f417f5
ddb9f2a
42d64c8
 
 
ddb9f2a
892096a
 
6f417f5
 
 
 
 
 
 
 
 
 
ddb9f2a
6f417f5
 
 
 
 
 
 
 
e6915e1
892096a
6f417f5
 
 
 
 
169ec0c
6f417f5
a324479
 
892096a
e6915e1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a324479
 
d2262bb
e6915e1
 
 
 
 
fef68a2
e6915e1
 
 
 
169ec0c
6f417f5
e6915e1
 
 
 
 
 
 
 
 
6f417f5
e6915e1
 
 
6f417f5
 
169ec0c
 
 
892096a
e6915e1
892096a
e6915e1
169ec0c
 
 
 
e6915e1
 
169ec0c
 
 
 
892096a
169ec0c
 
892096a
 
169ec0c
6f417f5
 
169ec0c
 
 
 
a324479
169ec0c
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
import gradio as gr
import jax
from PIL import Image
from flax.jax_utils import replicate
from flax.training.common_utils import shard
from diffusers import FlaxControlNetModel, FlaxStableDiffusionControlNetPipeline
import jax.profiler
import jax.numpy as jnp
import numpy as np
import gc


# ControlNet weights trained on Segment Anything segmentation maps, loaded in fp32.
controlnet, controlnet_params = FlaxControlNetModel.from_pretrained(
    "mfidabel/controlnet-segment-anything",
    dtype=jnp.float32,
)

# Stable Diffusion v1.5 (Flax revision) with the ControlNet above attached.
pipe, params = FlaxStableDiffusionControlNetPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5",
    revision="flax",
    controlnet=controlnet,
    dtype=jnp.float32,
)

# The ControlNet weights live in the shared parameter tree under "controlnet";
# the whole tree is then replicated once so the jitted/pmapped pipeline call
# in `infer` can use it directly.
params["controlnet"] = controlnet_params
p_params = replicate(params)

# Write a device-memory profile snapshot taken after the weights were loaded.
jax.profiler.save_device_memory_profile("memory.prof")

# Markdown shown at the top of the page (rendered with gr.Markdown).
title = "# 🧨 ControlNet on Segment Anything 🤗"
# Intro text next to the title. The indentation inside the triple-quoted string
# is part of the literal; it is left as in the original layout.
description = """This is a demo of 🧨 ControlNet based on Meta's [Segment Anything Model](https://segment-anything.com/).

                Upload a Segment Anything Segmentation Map, write a prompt, and generate images 🤗 This demo is still a Work in Progress, so don't expect it to work well for now !! 

                ⌛️ It takes about 30 seconds to generate 4 samples; to get faster results, reduce the Nº Samples to 1.

                You can obtain the Segmentation Map of any Image through this Colab: [![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/mfidabel/JAX_SPRINT_2023/blob/main/Segment_Anything_JAX_SPRINT.ipynb)

                
                A huge thanks goes out to @Google Cloud, for providing us with powerful TPUs that enabled us to train this model; and to the @HuggingFace Team for organizing the sprint.
              """

# Long-form "about" section rendered at the bottom of the page.
about = """


        # 👨‍💻 About the model

        This model is based on the [ControlNet Model](https://huggingface.co/blog/controlnet), which allows us to generate images using some sort of condition image. For this model, we selected the segmentation maps produced by Meta's new segmentation model called [Segment Anything Model](https://github.com/facebookresearch/segment-anything) as the condition image. We then trained the model to generate images based on the structure of the segmentation maps and the text prompts given.


        # 💾 About the dataset

        For the training, we generated a segmented dataset based on the [COYO-700M](https://huggingface.co/datasets/kakaobrain/coyo-700m) dataset. The dataset provided us with the images, and the text prompts. For the segmented images, we used [Segment Anything Model](https://github.com/facebookresearch/segment-anything). We then created 8k samples to train our model on, which isn't a lot, but as a team, we have been very busy with many other responsibilities and time constraints, which made it challenging to dedicate a lot of time to generating a larger dataset.  Despite the constraints we faced, we have still managed to achieve some nice results 🙌

        You can check the generated datasets below ⬇️
        - [sam-coyo-2k](https://huggingface.co/datasets/mfidabel/sam-coyo-2k)
        - [sam-coyo-2.5k](https://huggingface.co/datasets/mfidabel/sam-coyo-2.5k)
        - [sam-coyo-3k](https://huggingface.co/datasets/mfidabel/sam-coyo-3k)

"""

# Each example row is [prompt, negative prompt, condition-image path], matching
# the inputs order used with gr.Examples. The n-th (prompt, negative) pair is
# paired with examples/condition_image_<n>.png.
examples = [
    [prompt_text, negative_text, f"examples/condition_image_{index}.png"]
    for index, (prompt_text, negative_text) in enumerate(
        [
            ("contemporary living room of a house", "low quality"),
            ("new york buildings,  Vincent Van Gogh starry night ", "low quality, monochrome"),
            ("contemporary living room,  high quality, 4k, realistic", "low quality, monochrome, low res"),
            ("internal stairs of a japanese house", "low quality,  low res,  people, kids"),
            ("a photo of a girl taking notes", "low quality,  low res,  painting"),
            ("painting of an hot air ballon flying over a valley, The Great Wave off Kanagawa style, blue and white colors", "low quality,  low res"),
            ("painting of families enjoying the sunset,  The Garden of Earthly Delights style,  joyful", "low quality,  low res"),
        ],
        start=1,
    )
]

# Page styling: centred headings, and a justified, padded block for elements
# tagged with the "about" class.
css = (
    "h1 { text-align: center } "
    ".about { text-align: justify; padding-left: 10%; padding-right: 10%; }"
)

# Inference Function
def _device_batch_size(requested, n_devices):
    """Return the smallest positive multiple of ``n_devices`` that is >= ``requested``.

    ``shard`` splits the leading batch axis evenly across devices, so the batch
    given to the pmapped pipeline must be divisible by the device count.
    """
    requested = max(1, requested)
    # Ceiling division without floats.
    return -(-requested // n_devices) * n_devices


def _run_pipeline(prompts, negative_prompts, image, num_inference_steps, seed, batch_size):
    """Run the ControlNet pipeline on ``batch_size`` copies of the inputs.

    Returns the pipeline's image output reshaped to ``(batch_size, H, W, C)``.
    All sharded inputs are locals of this helper, so they are released as soon
    as it returns (before the caller's ``gc.collect()``).
    """
    n_devices = jax.device_count()
    # One PRNG key per device for the jitted (pmapped) call.
    p_rng = jax.random.split(jax.random.PRNGKey(seed), n_devices)

    prompt_ids = shard(pipe.prepare_text_inputs([prompts] * batch_size))
    negative_prompt_ids = shard(pipe.prepare_text_inputs([negative_prompts] * batch_size))
    processed_image = shard(pipe.prepare_image_inputs([image] * batch_size))

    images = pipe(
        prompt_ids=prompt_ids,
        image=processed_image,
        params=p_params,
        prng_seed=p_rng,
        num_inference_steps=num_inference_steps,
        neg_prompt_ids=negative_prompt_ids,
        jit=True,
    ).images
    # Collapse any leading device/per-device axes into one batch axis.
    return images.reshape((batch_size,) + images.shape[-3:])


def infer(prompts, negative_prompts, image, num_inference_steps=50, seed=4, num_samples=4):
    """Gradio callback: generate images conditioned on a segmentation map.

    Args:
        prompts: Text prompt, repeated for every sample.
        negative_prompts: Negative prompt, repeated for every sample.
        image: Condition image array from ``gr.Image``; ``None`` when nothing
            has been uploaded.
        num_inference_steps: Number of denoising steps (cast with ``int``).
        seed: PRNG seed (cast with ``int``).
        num_samples: Number of images to return (cast with ``int``; at least 1).

    Returns:
        A list of ``num_samples`` uint8 image arrays. The pipeline always runs a
        batch that is a multiple of the device count, and any extra images are
        dropped. If generation fails, the error is printed and the list holds
        black 512x512x3 placeholders instead.
    """
    # Resolve the requested count before the try, so the fallback can always
    # build a list of the right length.
    requested = max(1, int(num_samples))
    try:
        if image is None:
            raise ValueError("no condition image provided")
        # convert() also normalises RGBA / grayscale uploads to 3 channels.
        condition = Image.fromarray(image).convert("RGB")
        batch_size = _device_batch_size(requested, jax.device_count())
        frames = _run_pipeline(
            prompts,
            negative_prompts,
            condition,
            int(num_inference_steps),
            int(seed),
            batch_size,
        )
        print(frames.shape)
        final_image = [np.array(frame * 255, dtype=np.uint8) for frame in frames[:requested]]
        del frames
    except Exception as e:
        # UI boundary: report the error and show placeholders instead of crashing.
        print("Error: " + str(e))
        final_image = [np.zeros((512, 512, 3), dtype=np.uint8)] * requested
    finally:
        gc.collect()
    return final_image
    

# Example row used to pre-fill the inputs when the page first loads.
default_example = examples[5]

# These components are built here, outside the Blocks context, and placed into
# the layout later with .render(). Creation order is kept as before.
cond_img = (
    gr.Image(label="Input", shape=(512, 512), value=default_example[2])
    .style(height=200)
)

output = (
    gr.Gallery(label="Generated images")
    .style(height=200, rows=[2], columns=[2], object_fit="contain")
)

prompt = gr.Textbox(label="Prompt", lines=1, value=default_example[0])
negative_prompt = gr.Textbox(label="Negative Prompt", lines=1, value=default_example[1])

# Page layout. Inside the Blocks context, components appear in the order they
# are created or .render()-ed, so statement order here is the on-page order.
with gr.Blocks(css=css) as demo:
    # Header row: title and description on the left, examples on the right.
    with gr.Row():
        with gr.Column():
            # Title
            gr.Markdown(title)
            # Description
            gr.Markdown(description)

        with gr.Column():
            # Examples: each row fills prompt, negative prompt and condition
            # image. fn/outputs let Gradio run `infer` for an example; with only
            # three inputs, infer's defaults are used for steps, seed and samples.
            gr.Markdown("Try some of the examples below ⬇️")
            gr.Examples(examples=examples,
                    inputs=[prompt, negative_prompt, cond_img],
                    outputs=output,
                    fn=infer,
                    examples_per_page=4)

    # Images: condition map (wider column) next to the generated gallery.
    with gr.Row(variant="panel"):
        with gr.Column(scale=2):
            cond_img.render()
        with gr.Column(scale=1):
            output.render()
        
    # Prompts and generation controls (there is no Clear button yet).
    with gr.Row():
        with gr.Column():
            prompt.render()
            negative_prompt.render()

        with gr.Column():
            with gr.Accordion("Advanced options", open=False):
                # Positional slider args are (minimum, maximum, initial value).
                num_steps = gr.Slider(10, 60, 50, step=1, label="Steps")
                seed = gr.Slider(0, 1024, 4, step=1, label="Seed")
                num_samples = gr.Slider(1, 4, 4, step=1, label="Nº Samples")
                
            submit = gr.Button("Generate")
            # TODO: Download Button

    # "about" class matches the `.about` rule in `css`.
    with gr.Row():
        gr.Markdown(about, elem_classes="about")
    
    # Inputs map positionally onto infer(prompts, negative_prompts, image,
    # num_inference_steps, seed, num_samples).
    submit.click(infer, 
                 inputs=[prompt, negative_prompt, cond_img, num_steps, seed, num_samples],
                 outputs = output)
    
# Enable request queuing so long-running inference calls are processed in turn.
demo.queue()

# Start the (blocking) server only when run as a script, e.g. `python app.py`.
# Importing this module then only builds and configures `demo`.
if __name__ == "__main__":
    demo.launch()