Spaces:
Runtime error
Runtime error
Update dtype and example
Browse files
app.py
CHANGED
@@ -1,12 +1,27 @@
|
|
1 |
import gradio as gr
|
2 |
import jax
|
3 |
-
|
|
|
4 |
from flax.jax_utils import replicate
|
5 |
from flax.training.common_utils import shard
|
6 |
|
|
|
|
|
7 |
pipeline, pipeline_params = FlaxStableDiffusionPipeline.from_pretrained(
|
8 |
"bguisard/stable-diffusion-nano-2-1",
|
|
|
9 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
10 |
|
11 |
|
12 |
def generate_image(prompt: str, inference_steps: int = 30, prng_seed: int = 0):
|
@@ -51,7 +66,17 @@ app = gr.Interface(
|
|
51 |
"Stable Diffusion Nano allows for fast prototyping of diffusion models, "
|
52 |
"enabling quick experimentation with easily available hardware."
|
53 |
),
|
54 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
55 |
)
|
56 |
|
57 |
app.launch()
|
|
|
|
1 |
import gradio as gr
|
2 |
import jax
|
3 |
+
import jax.numpy as jnp
|
4 |
+
from diffusers import FlaxPNDMScheduler, FlaxStableDiffusionPipeline
|
5 |
from flax.jax_utils import replicate
|
6 |
from flax.training.common_utils import shard
|
7 |
|
8 |
+
DTYPE = jnp.bfloat16
|
9 |
+
|
10 |
pipeline, pipeline_params = FlaxStableDiffusionPipeline.from_pretrained(
|
11 |
"bguisard/stable-diffusion-nano-2-1",
|
12 |
+
dtype=DTYPE,
|
13 |
)
|
14 |
+
if DTYPE != jnp.float32:
|
15 |
+
# There is a known issue with schedulers when loading from a pre trained
|
16 |
+
# pipeline. We need the schedulers to always use float32.
|
17 |
+
# See: https://github.com/huggingface/diffusers/issues/2155
|
18 |
+
scheduler, scheduler_params = FlaxPNDMScheduler.from_pretrained(
|
19 |
+
pretrained_model_name_or_path="bguisard/stable-diffusion-nano-2-1",
|
20 |
+
subfolder="scheduler",
|
21 |
+
dtype=jnp.float32,
|
22 |
+
)
|
23 |
+
pipeline_params["scheduler"] = scheduler_params
|
24 |
+
pipeline.scheduler = scheduler
|
25 |
|
26 |
|
27 |
def generate_image(prompt: str, inference_steps: int = 30, prng_seed: int = 0):
|
|
|
66 |
"Stable Diffusion Nano allows for fast prototyping of diffusion models, "
|
67 |
"enabling quick experimentation with easily available hardware."
|
68 |
),
|
69 |
+
# Some examples were copied from hf.co/spaces/stabilityai/stable-diffusion
|
70 |
+
examples=[
|
71 |
+
# ["A watercolor painting of a bird", 30, 0],
|
72 |
+
[
|
73 |
+
"A small cabin on top of a snowy mountain in the style of Disney, artstation",
|
74 |
+
25,
|
75 |
+
3129302,
|
76 |
+
],
|
77 |
+
# ["A mecha robot in a favela in expressionist style", 30, 827198341273],
|
78 |
+
],
|
79 |
)
|
80 |
|
81 |
app.launch()
|
82 |
+
|