mfidabel committed on
Commit
169ec0c
1 Parent(s): 8489717

Updated Layout

Browse files
Files changed (1) hide show
  1. app.py +49 -13
app.py CHANGED
@@ -4,15 +4,11 @@ from PIL import Image
4
  from flax.jax_utils import replicate
5
  from flax.training.common_utils import shard
6
  from diffusers import FlaxControlNetModel, FlaxStableDiffusionControlNetPipeline
 
7
  import jax.numpy as jnp
8
  import numpy as np
9
 
10
 
11
- title = "🧨 ControlNet on Segment Anything 🤗"
12
- description = "This is a demo on ControlNet based on Segment Anything"
13
-
14
- examples = [["a modern main room of a house", "low quality", "condition_image_1.png", 50, 4]]
15
-
16
  controlnet, controlnet_params = FlaxControlNetModel.from_pretrained(
17
  "mfidabel/controlnet-segment-anything", dtype=jnp.float32
18
  )
@@ -25,13 +21,18 @@ pipe, params = FlaxStableDiffusionControlNetPipeline.from_pretrained(
25
  params["controlnet"] = controlnet_params
26
  p_params = replicate(params)
27
 
 
 
 
 
 
28
 
29
  # Inference Function
30
- def infer(prompts, negative_prompts, image, num_inference_steps, seed):
31
  rng = jax.random.PRNGKey(int(seed))
32
  num_inference_steps = int(num_inference_steps)
33
  image = Image.fromarray(image, mode="RGB")
34
- num_samples = jax.device_count()
35
  p_rng = jax.random.split(rng, jax.device_count())
36
 
37
  prompt_ids = pipe.prepare_text_inputs([prompts] * num_samples)
@@ -59,10 +60,45 @@ def infer(prompts, negative_prompts, image, num_inference_steps, seed):
59
  del output
60
 
61
  return final_image
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
62
 
63
- gr.Interface(fn = infer,
64
- inputs = ["text", "text", "image", "number", "number"],
65
- outputs = gr.Gallery(label="Generated images", show_label=False, elem_id="gallery").style(columns=[2], rows=[2], object_fit="contain", height="auto", preview=True),
66
- title = title,
67
- description = description,
68
- examples = examples).launch()
 
4
  from flax.jax_utils import replicate
5
  from flax.training.common_utils import shard
6
  from diffusers import FlaxControlNetModel, FlaxStableDiffusionControlNetPipeline
7
+ from diffusers.utils import load_image
8
  import jax.numpy as jnp
9
  import numpy as np
10
 
11
 
 
 
 
 
 
12
  controlnet, controlnet_params = FlaxControlNetModel.from_pretrained(
13
  "mfidabel/controlnet-segment-anything", dtype=jnp.float32
14
  )
 
21
  params["controlnet"] = controlnet_params
22
  p_params = replicate(params)
23
 
24
# UI copy and pre-baked demo rows for the Gradio app.
description = "This is a demo on ControlNet based on Segment Anything"
title = "# 🧨 ControlNet on Segment Anything 🤗"

# Each example row matches the ordered inputs of `infer`:
# (prompt, negative prompt, conditioning image, steps, seed, nº samples)
examples = [
    [
        "a modern main room of a house",
        "low quality",
        "condition_image_1.png",
        50,
        4,
        4,
    ],
]
29
 
30
  # Inference Function
31
+ def infer(prompts, negative_prompts, image, num_inference_steps, seed, num_samples):
32
  rng = jax.random.PRNGKey(int(seed))
33
  num_inference_steps = int(num_inference_steps)
34
  image = Image.fromarray(image, mode="RGB")
35
+ num_samples = max(jax.device_count(), int(num_samples))
36
  p_rng = jax.random.split(rng, jax.device_count())
37
 
38
  prompt_ids = pipe.prepare_text_inputs([prompts] * num_samples)
 
60
  del output
61
 
62
  return final_image
63
# Assemble and launch the Gradio UI around `infer`.
with gr.Blocks(css="h1 { text-align: center }") as demo:
    # Header: centered title followed by a short description.
    gr.Markdown(title)
    gr.Markdown(description)

    # Image panel: conditioning input on the left, generated gallery on the right.
    with gr.Row(variant="panel"):
        condition_image = gr.Image(label="Input")
        condition_image.style(height=400)
        gallery = gr.Gallery(label="Generated images")
        gallery.style(height=400, rows=[2], columns=[2])

    # Prompt fields plus collapsed advanced generation options.
    with gr.Row():
        with gr.Column():
            prompt_box = gr.Textbox(lines=1, label="Prompt")
            negative_box = gr.Textbox(lines=1, label="Negative Prompt")

        with gr.Column():
            with gr.Accordion("Advanced options", open=False):
                steps_slider = gr.Slider(10, 60, 50, step=1, label="Steps")
                seed_slider = gr.Slider(0, 1024, 0, step=1, label="Seed")
                samples_slider = gr.Slider(1, 4, 4, step=1, label="Nº Samples")

            submit_btn = gr.Button("Submit")

    # The ordered component list matching `infer`'s signature; shared by the
    # examples widget and the submit handler so the wiring stays in sync.
    infer_inputs = [prompt_box, negative_box, condition_image,
                    steps_slider, seed_slider, samples_slider]

    # Pre-computed example runs rendered below the interface.
    # NOTE(review): cache_examples=True runs `infer` at startup — assumes an
    # accelerator is available at launch time; confirm for the deploy target.
    gr.Examples(examples=examples,
                inputs=infer_inputs,
                outputs=gallery,
                fn=infer,
                cache_examples=True)

    submit_btn.click(infer, inputs=infer_inputs, outputs=gallery)

# Queue requests so long-running generations don't time out the HTTP worker.
demo.queue()
demo.launch()