Spaces:
Runtime error
Runtime error
feat: Add FLUX.1 [schnell] model and interface with Gradio
Browse filesThe commit adds the FLUX.1 [schnell] model and integrates it with Gradio for easy generation of images based on user prompts. The code includes the necessary imports, configuration settings, and UI components. It also includes advanced settings for seed, randomizing seed, width, height, and number of inference steps. The commit also includes examples of prompts for generating images.
The commit also includes an update to the requirements.txt file to add the necessary dependencies for Gradio and image loading.
This commit builds upon the previous commits that created the empty app.py file and made the initial commit to the repository.
- app.py +122 -0
- requirements.txt +2 -0
app.py
CHANGED
@@ -0,0 +1,122 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import gradio as gr
|
2 |
+
import numpy as np
|
3 |
+
import random
|
4 |
+
import spaces
|
5 |
+
import torch
|
6 |
+
from diffusers import DiffusionPipeline
|
7 |
+
|
8 |
+
dtype = torch.bfloat16
|
9 |
+
device = "cuda" if torch.cuda.is_available() else "cpu"
|
10 |
+
|
11 |
+
pipe = DiffusionPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16, revision="refs/pr/1").to(device)
|
12 |
+
|
13 |
+
MAX_SEED = np.iinfo(np.int32).max
|
14 |
+
MAX_IMAGE_SIZE = 2048
|
15 |
+
|
16 |
+
@spaces.GPU()
|
17 |
+
def infer(prompt, seed=42, randomize_seed=False, width=1024, height=1024, num_inference_steps=4, progress=gr.Progress(track_tqdm=True)):
|
18 |
+
if randomize_seed:
|
19 |
+
seed = random.randint(0, MAX_SEED)
|
20 |
+
generator = torch.Generator().manual_seed(seed)
|
21 |
+
image = pipe(
|
22 |
+
prompt = prompt,
|
23 |
+
width = width,
|
24 |
+
height = height,
|
25 |
+
num_inference_steps = num_inference_steps,
|
26 |
+
generator = generator,
|
27 |
+
guidance_scale=0.0
|
28 |
+
).images[0]
|
29 |
+
return image, seed
|
30 |
+
|
31 |
+
examples = [
|
32 |
+
"a tiny astronaut hatching from an egg on the moon",
|
33 |
+
"a cat holding a sign that says hello world",
|
34 |
+
"an anime illustration of a wiener schnitzel",
|
35 |
+
]
|
36 |
+
|
37 |
+
css="""
|
38 |
+
#col-container {
|
39 |
+
margin: 0 auto;
|
40 |
+
max-width: 520px;
|
41 |
+
}
|
42 |
+
"""
|
43 |
+
|
44 |
+
with gr.Blocks(css=css) as demo:
|
45 |
+
|
46 |
+
with gr.Column(elem_id="col-container"):
|
47 |
+
gr.Markdown(f"""# FLUX.1 [schnell]
|
48 |
+
12B param rectified flow transformer distilled from [FLUX.1 [pro]](https://blackforestlabs.ai/) for 4 step generation
|
49 |
+
[[blog](https://blackforestlabs.ai/2024/07/31/announcing-black-forest-labs/)] [[model](https://huggingface.co/black-forest-labs/FLUX.1-schnell)]
|
50 |
+
""")
|
51 |
+
|
52 |
+
with gr.Row():
|
53 |
+
|
54 |
+
prompt = gr.Text(
|
55 |
+
label="Prompt",
|
56 |
+
show_label=False,
|
57 |
+
max_lines=1,
|
58 |
+
placeholder="Enter your prompt",
|
59 |
+
container=False,
|
60 |
+
)
|
61 |
+
|
62 |
+
run_button = gr.Button("Run", scale=0)
|
63 |
+
|
64 |
+
result = gr.Image(label="Result", show_label=False)
|
65 |
+
|
66 |
+
with gr.Accordion("Advanced Settings", open=False):
|
67 |
+
|
68 |
+
seed = gr.Slider(
|
69 |
+
label="Seed",
|
70 |
+
minimum=0,
|
71 |
+
maximum=MAX_SEED,
|
72 |
+
step=1,
|
73 |
+
value=0,
|
74 |
+
)
|
75 |
+
|
76 |
+
randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
|
77 |
+
|
78 |
+
with gr.Row():
|
79 |
+
|
80 |
+
width = gr.Slider(
|
81 |
+
label="Width",
|
82 |
+
minimum=256,
|
83 |
+
maximum=MAX_IMAGE_SIZE,
|
84 |
+
step=32,
|
85 |
+
value=1024,
|
86 |
+
)
|
87 |
+
|
88 |
+
height = gr.Slider(
|
89 |
+
label="Height",
|
90 |
+
minimum=256,
|
91 |
+
maximum=MAX_IMAGE_SIZE,
|
92 |
+
step=32,
|
93 |
+
value=1024,
|
94 |
+
)
|
95 |
+
|
96 |
+
with gr.Row():
|
97 |
+
|
98 |
+
|
99 |
+
num_inference_steps = gr.Slider(
|
100 |
+
label="Number of inference steps",
|
101 |
+
minimum=1,
|
102 |
+
maximum=50,
|
103 |
+
step=1,
|
104 |
+
value=4,
|
105 |
+
)
|
106 |
+
|
107 |
+
gr.Examples(
|
108 |
+
examples = examples,
|
109 |
+
fn = infer,
|
110 |
+
inputs = [prompt],
|
111 |
+
outputs = [result, seed],
|
112 |
+
cache_examples="lazy"
|
113 |
+
)
|
114 |
+
|
115 |
+
gr.on(
|
116 |
+
triggers=[run_button.click, prompt.submit],
|
117 |
+
fn = infer,
|
118 |
+
inputs = [prompt, seed, randomize_seed, width, height, num_inference_steps],
|
119 |
+
outputs = [result, seed]
|
120 |
+
)
|
121 |
+
|
122 |
+
demo.launch()
|
requirements.txt
ADDED
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
1 |
+
gradio-imageslider
|
2 |
+
loadimg
|