Spaces:
Runtime error
Runtime error
Contrebande Labs
committed on
Commit
•
e0cb68e
1
Parent(s):
06f2eaf
sync with working jax inference code from main repo
Browse files
app.py
CHANGED
@@ -16,6 +16,7 @@ from diffusers import (
|
|
16 |
|
17 |
from transformers import ByT5Tokenizer, FlaxT5ForConditionalGeneration
|
18 |
|
|
|
19 |
def get_inference_lambda(seed):
|
20 |
|
21 |
tokenizer = ByT5Tokenizer()
|
@@ -51,7 +52,7 @@ def get_inference_lambda(seed):
|
|
51 |
"trained_betas": None,
|
52 |
}
|
53 |
)
|
54 |
-
timesteps =
|
55 |
guidance_scale = jnp.array([7.5], dtype=jnp.float32)
|
56 |
|
57 |
unet, unet_params = FlaxUNet2DConditionModel.from_pretrained(
|
@@ -68,7 +69,13 @@ def get_inference_lambda(seed):
|
|
68 |
|
69 |
image_width = image_height = 256
|
70 |
|
71 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
72 |
|
73 |
def __tokenize_prompt(prompt: str):
|
74 |
|
@@ -78,15 +85,11 @@ def get_inference_lambda(seed):
|
|
78 |
padding="max_length",
|
79 |
truncation=True,
|
80 |
return_tensors="jax",
|
81 |
-
).input_ids
|
82 |
|
83 |
-
def __convert_image(
|
84 |
-
|
85 |
-
return
|
86 |
-
# return [
|
87 |
-
# Image.fromarray(image)
|
88 |
-
# for image in (np.asarray(vae_output) * 255).round().astype(np.uint8)
|
89 |
-
# ]
|
90 |
|
91 |
def __predict_image(tokenized_prompt: jnp.array):
|
92 |
|
@@ -99,14 +102,6 @@ def get_inference_lambda(seed):
|
|
99 |
context = jnp.concatenate(
|
100 |
[negative_prompt_text_encoder_hidden_states, text_encoder_hidden_states]
|
101 |
)
|
102 |
-
jax.debug.print("got text encoding...")
|
103 |
-
|
104 |
-
latent_shape = (
|
105 |
-
tokenized_prompt.shape[0],
|
106 |
-
unet.in_channels,
|
107 |
-
image_width // vae_scale_factor,
|
108 |
-
image_height // vae_scale_factor,
|
109 |
-
)
|
110 |
|
111 |
def ___timestep(step, step_args):
|
112 |
|
@@ -148,15 +143,12 @@ def get_inference_lambda(seed):
|
|
148 |
scheduler_state, guided_unet_prediction_sample, t, latents
|
149 |
).to_tuple()
|
150 |
|
151 |
-
jax.debug.print("did one step...")
|
152 |
-
|
153 |
return latents, scheduler_state
|
154 |
|
155 |
# initialize scheduler state
|
156 |
initial_scheduler_state = scheduler.set_timesteps(
|
157 |
scheduler.create_state(), num_inference_steps=timesteps, shape=latent_shape
|
158 |
)
|
159 |
-
jax.debug.print("initialized scheduler state...")
|
160 |
|
161 |
# initialize latents
|
162 |
initial_latents = (
|
@@ -165,49 +157,33 @@ def get_inference_lambda(seed):
|
|
165 |
)
|
166 |
* initial_scheduler_state.init_noise_sigma
|
167 |
)
|
168 |
-
jax.debug.print("initialized latents...")
|
169 |
|
170 |
final_latents, _ = jax.lax.fori_loop(
|
171 |
0, timesteps, ___timestep, (initial_latents, initial_scheduler_state)
|
172 |
)
|
173 |
-
jax.debug.print("got final latents...")
|
174 |
-
|
175 |
-
# scale and decode the image latents with vae
|
176 |
-
image = (
|
177 |
-
(
|
178 |
-
vae.apply(
|
179 |
-
{"params": vae_params},
|
180 |
-
1 / vae.config.scaling_factor * final_latents,
|
181 |
-
method=vae.decode,
|
182 |
-
).sample
|
183 |
-
/ 2
|
184 |
-
+ 0.5
|
185 |
-
)
|
186 |
-
.clip(0, 1)
|
187 |
-
.transpose(0, 2, 3, 1)
|
188 |
-
)
|
189 |
-
jax.debug.print("got vae processed image output...")
|
190 |
|
191 |
-
|
192 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
193 |
|
194 |
-
|
195 |
|
196 |
return lambda prompt: __convert_image(
|
197 |
-
|
198 |
)
|
199 |
|
200 |
|
201 |
generate_image_for_prompt = get_inference_lambda(87)
|
202 |
|
203 |
-
print(f"JAX devices: {jax.devices()}")
|
204 |
-
print(f"JAX device type: {jax.devices()[0].device_kind}")
|
205 |
-
|
206 |
-
def infer_charred(prompt):
|
207 |
-
# your inference function for charr stable difusion control
|
208 |
-
generate_image_for_prompt(prompt)
|
209 |
-
return None
|
210 |
-
|
211 |
|
212 |
with gr.Blocks(theme="gradio/soft") as demo:
|
213 |
|
@@ -239,10 +215,12 @@ with gr.Blocks(theme="gradio/soft") as demo:
|
|
239 |
submit_btn = gr.Button(value="Submit")
|
240 |
charred_inputs = [prompt_input_charr]
|
241 |
submit_btn.click(
|
242 |
-
fn=
|
|
|
|
|
243 |
)
|
244 |
# examples = [["postage stamp from california", "low quality", "charr_output.png", "charr_output.png" ]]
|
245 |
# gr.Examples(fn = infer_sd, inputs = ["text", "text", "image", "image"], examples=examples, cache_examples=True)
|
246 |
|
247 |
demo.queue(concurrency_count=1)
|
248 |
-
demo.launch(debug=True, show_error=True
|
|
|
16 |
|
17 |
from transformers import ByT5Tokenizer, FlaxT5ForConditionalGeneration
|
18 |
|
19 |
+
|
20 |
def get_inference_lambda(seed):
|
21 |
|
22 |
tokenizer = ByT5Tokenizer()
|
|
|
52 |
"trained_betas": None,
|
53 |
}
|
54 |
)
|
55 |
+
timesteps = 20
|
56 |
guidance_scale = jnp.array([7.5], dtype=jnp.float32)
|
57 |
|
58 |
unet, unet_params = FlaxUNet2DConditionModel.from_pretrained(
|
|
|
69 |
|
70 |
image_width = image_height = 256
|
71 |
|
72 |
+
# Generating latent shape
|
73 |
+
latent_shape = (
|
74 |
+
negative_prompt_text_encoder_hidden_states.shape[0],
|
75 |
+
unet.in_channels,
|
76 |
+
image_width // vae_scale_factor,
|
77 |
+
image_height // vae_scale_factor,
|
78 |
+
)
|
79 |
|
80 |
def __tokenize_prompt(prompt: str):
|
81 |
|
|
|
85 |
padding="max_length",
|
86 |
truncation=True,
|
87 |
return_tensors="jax",
|
88 |
+
).input_ids
|
89 |
|
90 |
+
def __convert_image(image):
|
91 |
+
# create PIL image from JAX tensor converted to numpy
|
92 |
+
return Image.fromarray(np.asarray(image), mode="RGB")
|
|
|
|
|
|
|
|
|
93 |
|
94 |
def __predict_image(tokenized_prompt: jnp.array):
|
95 |
|
|
|
102 |
context = jnp.concatenate(
|
103 |
[negative_prompt_text_encoder_hidden_states, text_encoder_hidden_states]
|
104 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
105 |
|
106 |
def ___timestep(step, step_args):
|
107 |
|
|
|
143 |
scheduler_state, guided_unet_prediction_sample, t, latents
|
144 |
).to_tuple()
|
145 |
|
|
|
|
|
146 |
return latents, scheduler_state
|
147 |
|
148 |
# initialize scheduler state
|
149 |
initial_scheduler_state = scheduler.set_timesteps(
|
150 |
scheduler.create_state(), num_inference_steps=timesteps, shape=latent_shape
|
151 |
)
|
|
|
152 |
|
153 |
# initialize latents
|
154 |
initial_latents = (
|
|
|
157 |
)
|
158 |
* initial_scheduler_state.init_noise_sigma
|
159 |
)
|
|
|
160 |
|
161 |
final_latents, _ = jax.lax.fori_loop(
|
162 |
0, timesteps, ___timestep, (initial_latents, initial_scheduler_state)
|
163 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
164 |
|
165 |
+
vae_output = vae.apply(
|
166 |
+
{"params": vae_params},
|
167 |
+
1 / vae.config.scaling_factor * final_latents,
|
168 |
+
method=vae.decode,
|
169 |
+
).sample
|
170 |
+
|
171 |
+
# return 8 bit RGB image (width, height, rgb)
|
172 |
+
return (
|
173 |
+
((vae_output / 2 + 0.5).transpose(0, 2, 3, 1).clip(0, 1) * 255)
|
174 |
+
.round()
|
175 |
+
.astype(jnp.uint8)[0]
|
176 |
+
)
|
177 |
|
178 |
+
jax_jit_compiled_predict_image = jax.jit(__predict_image)
|
179 |
|
180 |
return lambda prompt: __convert_image(
|
181 |
+
jax_jit_compiled_predict_image(__tokenize_prompt(prompt))
|
182 |
)
|
183 |
|
184 |
|
185 |
generate_image_for_prompt = get_inference_lambda(87)
|
186 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
187 |
|
188 |
with gr.Blocks(theme="gradio/soft") as demo:
|
189 |
|
|
|
215 |
submit_btn = gr.Button(value="Submit")
|
216 |
charred_inputs = [prompt_input_charr]
|
217 |
submit_btn.click(
|
218 |
+
fn=generate_image_for_prompt,
|
219 |
+
inputs=charred_inputs,
|
220 |
+
outputs=[charred_output],
|
221 |
)
|
222 |
# examples = [["postage stamp from california", "low quality", "charr_output.png", "charr_output.png" ]]
|
223 |
# gr.Examples(fn = infer_sd, inputs = ["text", "text", "image", "image"], examples=examples, cache_examples=True)
|
224 |
|
225 |
demo.queue(concurrency_count=1)
|
226 |
+
demo.launch(debug=True, show_error=True)
|