Lisandro commited on
Commit
5c6edfb
·
1 Parent(s): d69d39b

feat: Add FLUX.1 [schnell] model and interface with Gradio

Browse files

The commit adds the FLUX.1 [schnell] model and integrates it with Gradio for easy generation of images based on user prompts. The code includes the necessary imports, configuration settings, and UI components. It also includes advanced settings for seed, randomizing seed, width, height, and number of inference steps. The commit also includes examples of prompts for generating images.

The commit also includes an update to the requirements.txt file to add the necessary dependencies for Gradio and image loading.

This commit builds upon the previous commits that created the empty app.py file and made the initial commit to the repository.

Files changed (2) hide show
  1. app.py +122 -0
  2. requirements.txt +2 -0
app.py CHANGED
@@ -0,0 +1,122 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import numpy as np
3
+ import random
4
+ import spaces
5
+ import torch
6
+ from diffusers import DiffusionPipeline
7
+
8
+ dtype = torch.bfloat16
9
+ device = "cuda" if torch.cuda.is_available() else "cpu"
10
+
11
+ pipe = DiffusionPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16, revision="refs/pr/1").to(device)
12
+
13
+ MAX_SEED = np.iinfo(np.int32).max
14
+ MAX_IMAGE_SIZE = 2048
15
+
16
+ @spaces.GPU()
17
+ def infer(prompt, seed=42, randomize_seed=False, width=1024, height=1024, num_inference_steps=4, progress=gr.Progress(track_tqdm=True)):
18
+ if randomize_seed:
19
+ seed = random.randint(0, MAX_SEED)
20
+ generator = torch.Generator().manual_seed(seed)
21
+ image = pipe(
22
+ prompt = prompt,
23
+ width = width,
24
+ height = height,
25
+ num_inference_steps = num_inference_steps,
26
+ generator = generator,
27
+ guidance_scale=0.0
28
+ ).images[0]
29
+ return image, seed
30
+
31
+ examples = [
32
+ "a tiny astronaut hatching from an egg on the moon",
33
+ "a cat holding a sign that says hello world",
34
+ "an anime illustration of a wiener schnitzel",
35
+ ]
36
+
37
+ css="""
38
+ #col-container {
39
+ margin: 0 auto;
40
+ max-width: 520px;
41
+ }
42
+ """
43
+
44
+ with gr.Blocks(css=css) as demo:
45
+
46
+ with gr.Column(elem_id="col-container"):
47
+ gr.Markdown(f"""# FLUX.1 [schnell]
48
+ 12B param rectified flow transformer distilled from [FLUX.1 [pro]](https://blackforestlabs.ai/) for 4 step generation
49
+ [[blog](https://blackforestlabs.ai/2024/07/31/announcing-black-forest-labs/)] [[model](https://huggingface.co/black-forest-labs/FLUX.1-schnell)]
50
+ """)
51
+
52
+ with gr.Row():
53
+
54
+ prompt = gr.Text(
55
+ label="Prompt",
56
+ show_label=False,
57
+ max_lines=1,
58
+ placeholder="Enter your prompt",
59
+ container=False,
60
+ )
61
+
62
+ run_button = gr.Button("Run", scale=0)
63
+
64
+ result = gr.Image(label="Result", show_label=False)
65
+
66
+ with gr.Accordion("Advanced Settings", open=False):
67
+
68
+ seed = gr.Slider(
69
+ label="Seed",
70
+ minimum=0,
71
+ maximum=MAX_SEED,
72
+ step=1,
73
+ value=0,
74
+ )
75
+
76
+ randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
77
+
78
+ with gr.Row():
79
+
80
+ width = gr.Slider(
81
+ label="Width",
82
+ minimum=256,
83
+ maximum=MAX_IMAGE_SIZE,
84
+ step=32,
85
+ value=1024,
86
+ )
87
+
88
+ height = gr.Slider(
89
+ label="Height",
90
+ minimum=256,
91
+ maximum=MAX_IMAGE_SIZE,
92
+ step=32,
93
+ value=1024,
94
+ )
95
+
96
+ with gr.Row():
97
+
98
+
99
+ num_inference_steps = gr.Slider(
100
+ label="Number of inference steps",
101
+ minimum=1,
102
+ maximum=50,
103
+ step=1,
104
+ value=4,
105
+ )
106
+
107
+ gr.Examples(
108
+ examples = examples,
109
+ fn = infer,
110
+ inputs = [prompt],
111
+ outputs = [result, seed],
112
+ cache_examples="lazy"
113
+ )
114
+
115
+ gr.on(
116
+ triggers=[run_button.click, prompt.submit],
117
+ fn = infer,
118
+ inputs = [prompt, seed, randomize_seed, width, height, num_inference_steps],
119
+ outputs = [result, seed]
120
+ )
121
+
122
+ demo.launch()
requirements.txt ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ gradio-imageslider
2
+ loadimg