File size: 14,630 Bytes
cd1e8dc
628e6c3
cd1e8dc
 
 
 
 
 
 
 
 
 
628e6c3
35f97ba
628e6c3
426fb9c
822b647
426fb9c
 
 
1195790
39db156
 
1195790
 
 
ba29a7c
39db156
 
 
ba29a7c
39db156
 
a30c911
39db156
 
 
 
 
822b647
 
 
39db156
822b647
39db156
a30c911
ba29a7c
 
663f236
39db156
 
 
 
663f236
ba29a7c
 
60c2a6b
a5f0d0c
39db156
 
 
 
60c2a6b
1d71dff
60c2a6b
 
39db156
 
 
 
 
60c2a6b
cd1e8dc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0a95bff
cd1e8dc
 
 
 
 
a5e9129
 
 
cd1e8dc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
35f97ba
 
be5cb04
f2819ff
ce09356
 
ccf1a03
cd1e8dc
 
 
 
 
 
 
 
4598830
cd1e8dc
 
 
4598830
 
 
cd1e8dc
 
 
e57f9cc
cd1e8dc
 
9702a1f
cd1e8dc
 
a5e9129
 
cd1e8dc
 
9702a1f
a5e9129
cd1e8dc
 
 
 
7b66f42
cd1e8dc
 
 
 
 
 
 
 
 
788a013
dc72f49
788a013
 
 
 
cd1e8dc
788a013
 
 
 
cd1e8dc
 
68696f0
71f4cfe
 
5461399
cd1e8dc
 
 
 
 
 
 
71f4cfe
 
cd1e8dc
 
 
71f4cfe
 
 
783c45d
cd1e8dc
 
 
783c45d
cd1e8dc
 
 
a5e9129
cd1e8dc
628e6c3
 
1195790
628e6c3
 
 
 
 
 
cd1e8dc
 
628e6c3
 
 
 
 
cd1e8dc
628e6c3
 
 
 
 
 
 
 
 
 
 
 
cd1e8dc
 
 
 
 
 
628e6c3
 
 
 
 
 
 
 
3f2581f
628e6c3
14811bd
4598830
 
2d0240c
628e6c3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a5f0d0c
 
 
 
 
ba29a7c
 
 
a30c911
bf28d41
f7a5714
e5488f2
f7a5714
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
14811bd
f7a5714
14811bd
f7a5714
 
ba29a7c
60c2a6b
 
 
 
 
 
39db156
 
 
ba29a7c
 
 
628e6c3
 
 
cd1e8dc
628e6c3
cd1e8dc
 
 
628e6c3
 
4598830
628e6c3
cd1e8dc
628e6c3
 
 
 
 
 
88bd8ea
 
628e6c3
 
a5e9129
 
 
628e6c3
a5e9129
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
# This file is adapted from https://huggingface.co/spaces/diffusers/controlnet-canny/blob/main/app.py
# The original license file is LICENSE.ControlNet in this repo.
import functools
import os

import cv2
import gradio as gr
import jax
import jax.numpy as jnp
import numpy as np
from diffusers import FlaxStableDiffusionControlNetPipeline, FlaxControlNetModel, FlaxDPMSolverMultistepScheduler
from diffusers.utils import load_image
from flax.jax_utils import replicate
from flax.training.common_utils import shard
from PIL import Image
from transformers import CLIPTokenizer, FlaxCLIPTextModel, set_seed


# Try to force Gradio 3.28.3 (the UI below uses 3.x-era calls such as
# ``.style(...)`` on components).
# NOTE(review): ``gradio`` has already been imported above, so uninstalling /
# reinstalling it here does not change the module used by this running
# process — presumably why it is marked "doesn't work". A different version
# would only take effect after the process restarts. The return codes of
# ``os.system`` are not checked, so install failures go unnoticed.
if gr.__version__ != "3.28.3": #doesn't work...
    os.system("pip uninstall -y gradio")
    os.system("pip install gradio==3.28.3")

# Markdown snippets rendered by ``create_demo`` via ``gr.Markdown``.
title_description = """
# UCDR-Net
## Unlimited Controlled Domain Randomization Network for Bridging the Sim2Real Gap in Robotics

"""

description = """
While existing ControlNet and public diffusion models are predominantly geared towards high-resolution images (512x512 or above) and intricate artistic detail generation, there's an untapped potential of these models in Automatic Data Augmentation (ADA). 
By harnessing the inherent variance in prompt-conditioned generated images, we can significantly boost the visual diversity of training samples for computer vision pipelines. 
This is particularly relevant in the field of robotics, where deep learning is increasingly playing a pivotal role in training policies for robotic manipulation from images.

In this HuggingFace sprint, we present UCDR-Net (Unlimited Controlled Domain Randomization Network), a novel CannyEdge mini-ControlNet trained on Stable Diffusion 1.5 with mixed datasets. 
Our model generates photorealistic and varied renderings from simplistic robotic simulation images, enabling real-time data augmentation for robotic vision training.

We specifically designed UCDR-Net to be fast and composition preserving, with an emphasis on lower resolution images (128x128) for online data augmentation in typical preprocessing pipelines. 
Our choice of Canny Edge version of ControlNet ensures shape and structure preservation in the image, which is crucial for visuomotor policy learning.

We trained ControlNet from scratch using only 128x128 images, preprocessing the training datasets and extracting Canny Edge maps. 
We then trained four Control-Nets with different mixtures of 2 datasets (Coyo-700M and Bridge Data) and showcased the results.
* [Coyo-700M](https://github.com/kakaobrain/coyo-dataset)
* [Bridge](https://sites.google.com/view/bridgedata)

Model Description and Training Process: Please refer to the readme file attached to the model repository.

Model Repository: [ControlNet repo](https://huggingface.co/Baptlem/baptlem-controlnet)

"""

traj_description = """ 
To demonstrate UCDR-Net's capabilities, we generated a trajectory of our simulated robotic environment and presented the resulting videos for each model. 
We batched the frames for each video and performed independent inference for each frame, which explains the "wobbling" effect.
Prompt used for every video: "A robotic arm with a gripper and a small cube on a table, super realistic, industrial background"

"""


# Benchmark text: typos and grammar fixed ("increazing", "24Go", "one of our
# model on every GPUs", "Batch Size", ...).
perfo_description = """
Our model has been benchmarked on a node of 4 Titan RTX 24GB GPUs, achieving an impressive 13 FPS image generation rate!
The table on the right shows the performance of our models running on different nodes.
For the benchmark, we loaded one of our models on every GPU of the node. We then retrieved an episode of our simulation.
For every frame of the episode, we preprocess the image (resize, canny, …) and process the Canny image on the GPUs.
We repeated this procedure for different batch sizes (BS).

We can see that the greater the BS, the greater the FPS. By increasing the BS, we take advantage of the parallelization of the GPUs.
"""

conclusion_description = """
UCDR-Net stands as a natural development in bridging the Sim2Real gap in robotics by providing real-time data augmentation for training visual policies. 
We are excited to share our work with the HuggingFace community and contribute to the advancement of robotic vision training techniques.

"""

def create_key(seed=0):
    """Return a JAX PRNG key derived from the integer ``seed`` (default 0)."""
    key = jax.random.PRNGKey(seed)
    return key

def load_controlnet(controlnet_version, repo_id="Baptlem/baptlem-controlnet"):
    """Load one ControlNet checkpoint (Flax weights) from the Hugging Face Hub.

    Args:
        controlnet_version: Sub-folder of ``repo_id`` that holds the
            checkpoint, e.g. ``"coyo-500k"``.
        repo_id: Hub repository containing the checkpoints. Defaults to the
            UCDR-Net repository that was previously hard-coded here.

    Returns:
        ``(controlnet, controlnet_params)`` as returned by
        ``FlaxControlNetModel.from_pretrained``, loaded in float32.
    """
    controlnet, controlnet_params = FlaxControlNetModel.from_pretrained(
        repo_id,
        subfolder=controlnet_version,
        from_flax=True,
        dtype=jnp.float32,
    )
    return controlnet, controlnet_params


@functools.lru_cache(maxsize=1)
def load_sb_pipe(controlnet_version, sb_path="runwayml/stable-diffusion-v1-5"):
    """Build the Flax Stable Diffusion + ControlNet pipeline and its params.

    The most recently built pipeline is cached (``maxsize=1``). Consecutive
    requests for the same ControlNet version therefore reuse the loaded
    weights instead of rebuilding the whole pipeline on every call. Selecting
    a different version evicts the previous pipeline, so at most one is kept
    alive by the cache.

    Args:
        controlnet_version: Sub-folder of the ControlNet checkpoint to load
            (forwarded to ``load_controlnet``).
        sb_path: Hub id of the base Stable Diffusion model. The scheduler and
            the pipeline are loaded from it.

    Returns:
        ``(pipe, params)``. ``pipe`` uses a DPM-Solver multistep scheduler.
        ``params`` holds the base model weights plus the ``"controlnet"`` and
        ``"scheduler"`` entries. Because the result is cached, ``params`` is
        shared between calls and must not be mutated by callers.
    """
    controlnet, controlnet_params = load_controlnet(controlnet_version)

    scheduler, scheduler_params = FlaxDPMSolverMultistepScheduler.from_pretrained(
        sb_path,
        subfolder="scheduler"
    )

    pipe, params = FlaxStableDiffusionControlNetPipeline.from_pretrained(
        sb_path,
        controlnet=controlnet,
        revision="flax",
        dtype=jnp.bfloat16
    )

    # Swap in the DPM-Solver scheduler and attach the ControlNet/scheduler
    # parameters so a single ``params`` tree can be passed to the pipeline.
    pipe.scheduler = scheduler
    params["controlnet"] = controlnet_params
    params["scheduler"] = scheduler_params
    return pipe, params

    

# Hub repository and default sub-folder of the UCDR-Net ControlNet checkpoints.
# NOTE(review): neither name is referenced by the visible code (the loaders
# take the repo id / version as arguments) — confirm before removing.
controlnet_path = "Baptlem/baptlem-controlnet"
controlnet_version = "coyo-500k"

# Canny edge-detector thresholds used by preprocess_canny.
low_threshold = 100
high_threshold = 200

# Startup diagnostics: working directory, its contents and the Gradio version.
print(os.path.abspath('.'))
print(os.listdir("."))
print("Gradio version:", gr.__version__)
# Nothing is loaded at import time: pipe_inference loads the selected
# pipeline via load_sb_pipe when a request arrives. The old message
# "Loaded models..." was therefore misleading.
print("Models are loaded on demand when an inference request arrives...")
def pipe_inference(
    image,
    prompt,
    is_canny=False,
    num_samples=4,
    resolution=128,
    num_inference_steps=50,
    guidance_scale=7.5,
    model="coyo-500k",
    seed=0,
    negative_prompt="",
    ):
    """Generate ``num_samples`` images conditioned on ``image`` and ``prompt``.

    Args:
        image: Input image as a numpy array or anything ``np.array`` accepts
            (e.g. a PIL image). Must not be ``None``.
        prompt: Text prompt used for every sample.
        is_canny: If True, ``image`` is assumed to already be a Canny edge map
            and is only resized. Otherwise Canny edges are extracted from it.
        num_samples: Number of images to generate. The batch is split evenly
            across ``jax.device_count()`` devices, so it must be a multiple of
            the device count.
        resolution: Length, in pixels, of the shorter image side after
            resizing (aspect ratio is kept).
        num_inference_steps: Number of scheduler steps.
        guidance_scale: Classifier-free guidance scale.
        model: ControlNet checkpoint sub-folder passed to ``load_sb_pipe``.
        seed: Seed for the JAX PRNG.
        negative_prompt: Negative prompt used for every sample.

    Returns:
        A list of ``num_samples`` PIL images.

    Raises:
        gr.Error: If no image was provided, or if ``num_samples`` is not a
            multiple of the number of devices.
    """
    # Fail fast with a readable UI message instead of an opaque unpacking
    # error inside resize_image.
    if image is None:
        raise gr.Error("Please upload an input image.")

    # Slider values may arrive as floats. These are used for list repetition,
    # array shapes and step counts, so normalise them to plain ints.
    num_samples = int(num_samples)
    resolution = int(resolution)
    num_inference_steps = int(num_inference_steps)
    seed = int(seed)

    n_devices = jax.device_count()
    # The PRNG keys and the inputs are split across devices below, which only
    # works when the batch divides evenly.
    if num_samples % n_devices != 0:
        raise gr.Error(
            f"Number of images ({num_samples}) must be a multiple of the "
            f"number of devices ({n_devices})."
        )

    print("Loading pipe")
    pipe, params = load_sb_pipe(model)

    # resize_image accepts numpy arrays or PIL images and returns a PIL image.
    processed_image = resize_image(image, resolution)
    if not is_canny:
        # Replace the photo by its 3-channel Canny edge map.
        _, processed_image = preprocess_canny(processed_image, resolution)

    rng = jax.random.split(create_key(seed), n_devices)

    prompt_ids = pipe.prepare_text_inputs([prompt] * num_samples)
    negative_prompt_ids = pipe.prepare_text_inputs([negative_prompt] * num_samples)
    processed_image = pipe.prepare_image_inputs([processed_image] * num_samples)

    # Replicate the weights on every device and shard the per-sample inputs.
    p_params = replicate(params)
    prompt_ids = shard(prompt_ids)
    negative_prompt_ids = shard(negative_prompt_ids)
    processed_image = shard(processed_image)

    print("Inference...")
    output = pipe(
        prompt_ids=prompt_ids,
        image=processed_image,
        params=p_params,
        prng_seed=rng,
        num_inference_steps=num_inference_steps,
        guidance_scale=guidance_scale,
        neg_prompt_ids=negative_prompt_ids,
        jit=True,
    ).images
    print("Finished inference...")

    # Flatten the leading (sharded) axes into a single batch of num_samples
    # images, keeping the trailing three image axes.
    images = np.asarray(output.reshape((num_samples,) + output.shape[-3:]))
    return pipe.numpy_to_pil(images)

def resize_image(image, resolution):
    """Resize ``image`` so its shorter side equals ``resolution``.

    The aspect ratio is preserved (the longer side is truncated to an int)
    and nearest-neighbour interpolation is used.

    Args:
        image: A numpy array, or anything ``np.array`` can convert
            (e.g. a PIL image).
        resolution: Target length, in pixels, of the shorter side.

    Returns:
        The resized image as a PIL ``Image``.
    """
    pixels = image if isinstance(image, np.ndarray) else np.array(image)
    height, width = pixels.shape[:2]
    aspect = width / height

    # cv2.resize expects the target size as (width, height).
    if aspect > 1:
        target_size = (int(resolution * aspect), resolution)
    elif aspect < 1:
        target_size = (resolution, int(resolution / aspect))
    else:
        target_size = (resolution, resolution)

    resized = cv2.resize(pixels, target_size, interpolation=cv2.INTER_NEAREST)
    return Image.fromarray(resized)
    
    
def preprocess_canny(image, resolution=128):
    """Compute a 3-channel Canny edge map of ``image``.

    Uses the module-level ``low_threshold`` / ``high_threshold`` for the
    detector. ``resolution`` is accepted for call-site compatibility but is
    not used; the image is not resized here.

    Args:
        image: A numpy array, or anything ``np.array`` can convert.
        resolution: Unused.

    Returns:
        ``(original, edges)`` as PIL images. ``original`` is the unchanged
        input, and ``edges`` is the single-channel Canny output repeated over
        three channels.
    """
    pixels = np.array(image) if not isinstance(image, np.ndarray) else image

    edges = cv2.Canny(pixels, low_threshold, high_threshold)
    # Stack the single edge channel three times to get an RGB-shaped image.
    edges_rgb = np.repeat(edges[:, :, None], 3, axis=2)

    return Image.fromarray(pixels), Image.fromarray(edges_rgb)


def create_demo(process, max_images=12, default_num_images=4):
    """Build the Gradio Blocks page of the UCDR-Net demo.

    Args:
        process: Callback run on prompt submit and on the Run button. It is
            called with the values of the components in ``inputs`` (bottom of
            this function), in that order, and its return value is shown in
            the output gallery.
        max_images: Upper bound of the "Images" (number of samples) slider.
        default_num_images: Initial value of the "Images" slider.

    Returns:
        The constructed ``gr.Blocks`` object (not launched).
    """
    # Component creation order inside the nested ``with`` contexts defines the
    # page layout, so the statements below must keep their order.
    with gr.Blocks() as demo:
        with gr.Row():
            gr.Markdown(title_description)
        with gr.Row():
            # Left column: input image, prompt, Run button and advanced options.
            with gr.Column():
                input_image = gr.Image(source='upload', type='numpy')
                prompt = gr.Textbox(label='Prompt')
                run_button = gr.Button(label='Run')
                with gr.Accordion('Advanced options', open=False):
                    # Presumably tells ``process`` that the upload is already a
                    # Canny edge map (pipe_inference then skips edge extraction).
                    is_canny = gr.Checkbox(
                        label='Is canny', value=False)
                    num_samples = gr.Slider(label='Images',
                                            minimum=1,
                                            maximum=max_images,
                                            value=default_num_images,
                                            step=1)
                    # Disabled Canny threshold sliders. The string below is an
                    # unused expression acting as commented-out code.
                    # preprocess_canny uses the module-level low_threshold /
                    # high_threshold constants instead.
                    """
                    canny_low_threshold = gr.Slider(
                        label='Canny low threshold',
                        minimum=1,
                        maximum=255,
                        value=100,
                        step=1)
                    canny_high_threshold = gr.Slider(
                        label='Canny high threshold',
                        minimum=1,
                        maximum=255,
                        value=200,
                        step=1)
                    """
                    # Pinned to 128: the models were trained on 128x128 images.
                    resolution = gr.Slider(label='Resolution',
                                          minimum=128,
                                          maximum=128,
                                          value=128,
                                          step=1)
                    num_steps = gr.Slider(label='Steps',
                                          minimum=1,
                                          maximum=100,
                                          value=20,
                                          step=1)
                    guidance_scale = gr.Slider(label='Guidance Scale',
                                               minimum=0.1,
                                               maximum=30.0,
                                               value=7.5,
                                               step=0.1)
                    # Choices are sub-folder names of the Baptlem/baptlem-controlnet repo.
                    model = gr.Dropdown(choices=["coyo-500k", "bridge-2M", "coyo1M-bridge2M", "coyo2M-bridge325k"],
                                        value="coyo-500k",
                                        label="Model used for inference", 
                                        info="Find every models at https://huggingface.co/Baptlem/baptlem-controlnet")
                    # NOTE(review): the slider allows -1, which is forwarded
                    # unchanged to the PRNG seed; confirm whether -1 was meant
                    # to mean "random seed".
                    seed = gr.Slider(label='Seed',
                                     minimum=-1,
                                     maximum=2147483647,
                                     step=1,
                                     randomize=True)
                    n_prompt = gr.Textbox(
                        label='Negative Prompt',
                        value=
                        'longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality'
                    )
            # Right column: generated images, two per row.
            with gr.Column():
                result = gr.Gallery(label='Output',
                                    show_label=False,
                                    elem_id='gallery').style(grid=2,
                                                             height='auto')

        # Static showcase content below. Media paths are relative to the
        # working directory the app is launched from.
        with gr.Row():
            gr.Video("./trajectory_hf/trajectory_coyo2M-bridge325k_64.avi",
                        format="avi",
                        interactive=False)
        
        with gr.Row():
            gr.Markdown(description)

        with gr.Row():
            with gr.Column():
                gr.Markdown(traj_description)
            with gr.Column():
                gr.Video("./trajectory_hf/trajectory.avi",
                        format="avi",
                        interactive=False)

        # One row per model: label on the left, processed trajectory video on the right.
        with gr.Row():
            with gr.Column():
                gr.Markdown("Trajectory processed with coyo-500k model :")
            with gr.Column():
                gr.Video("./trajectory_hf/trajectory_coyo-500k.avi",
                        format="avi",
                        interactive=False)

        with gr.Row():
            with gr.Column():
                gr.Markdown("Trajectory processed with bridge-2M model :")
            with gr.Column():
                gr.Video("./trajectory_hf/trajectory_bridge-2M.avi",
                        format="avi",
                        interactive=False)

        with gr.Row():
            with gr.Column():
                gr.Markdown("Trajectory processed with coyo1M-bridge2M model :")
            with gr.Column():
                gr.Video("./trajectory_hf/trajectory_coyo1M-bridge2M.avi",
                        format="avi",
                        interactive=False)

        with gr.Row():
            with gr.Column():
                gr.Markdown("Trajectory processed with coyo2M-bridge325k model :")
            with gr.Column():
                gr.Video("./trajectory_hf/trajectory_coyo2M-bridge325k.avi",
                        format="avi",
                        interactive=False)
        
        # Benchmark text next to the benchmark results image.
        with gr.Row():
            with gr.Column():
                gr.Markdown(perfo_description)
            with gr.Column():
                gr.Image("./perfo_rtx.png",
                        interactive=False)

        with gr.Row():
            gr.Markdown(conclusion_description)
        
        
        
        # Positional order must match the parameters of ``process``. For
        # pipe_inference: image, prompt, is_canny, num_samples, resolution,
        # num_inference_steps, guidance_scale, model, seed, negative_prompt.
        inputs = [
            input_image,
            prompt,
            is_canny,
            num_samples,
            resolution,
            #canny_low_threshold,
            #canny_high_threshold,
            num_steps,
            guidance_scale,
            model,
            seed,
            n_prompt,
        ]
        # Pressing Enter in the prompt box and clicking Run trigger the same
        # callback. The click handler is also exposed under the API name "canny".
        prompt.submit(fn=process, inputs=inputs, outputs=result)
        run_button.click(fn=process,
                         inputs=inputs,
                         outputs=result,
                         api_name='canny')
    
    return demo

if __name__ == '__main__':
    # Build the UI around the inference callback. Gradio's request queue is
    # enabled before launching because inference calls are long-running.
    # (Removed a no-op bare ``pipe_inference`` expression statement and a
    # dead commented-out launch line.)
    demo = create_demo(pipe_inference)
    demo.queue().launch()