radames commited on
Commit
83718c8
·
1 Parent(s): 30970e0

use compel for prompt encoding

Browse files
Files changed (2) hide show
  1. app.py +109 -12
  2. requirements.txt +2 -1
app.py CHANGED
@@ -2,6 +2,7 @@ import gradio as gr
2
  from gradio_imageslider import ImageSlider
3
  import torch
4
  from diffusers import DiffusionPipeline, AutoencoderKL
 
5
  from PIL import Image
6
  from torchvision import transforms
7
  import tempfile
@@ -30,7 +31,12 @@ pipe = DiffusionPipeline.from_pretrained(
30
  use_safetensors=True,
31
  vae=vae,
32
  )
33
-
 
 
 
 
 
34
  pipe = pipe.to(device)
35
 
36
 
@@ -70,6 +76,11 @@ def predict(
70
  prompt,
71
  negative_prompt,
72
  seed,
 
 
 
 
 
73
  scale=2,
74
  progress=gr.Progress(track_tqdm=True),
75
  ):
@@ -77,11 +88,14 @@ def predict(
77
  raise gr.Error("Please upload an image.")
78
  padded_image = pad_image(input_image).resize((1024, 1024)).convert("RGB")
79
  image_lr = load_and_process_image(padded_image).to(device)
 
80
  generator = torch.manual_seed(seed)
81
  last_time = time.time()
82
  images = pipe(
83
- prompt,
84
- negative_prompt=negative_prompt,
 
 
85
  image_lr=image_lr,
86
  width=1024 * scale,
87
  height=1024 * scale,
@@ -89,11 +103,11 @@ def predict(
89
  stride=64,
90
  generator=generator,
91
  num_inference_steps=40,
92
- guidance_scale=8.5,
93
- cosine_scale_1=3,
94
- cosine_scale_2=1,
95
- cosine_scale_3=1,
96
- sigma=0.8,
97
  multi_decoder=1024 * scale > 2048,
98
  show_image=False,
99
  lowvram=LOW_MEMORY,
@@ -145,13 +159,48 @@ GPU Time Comparison: T4: ~276s - A10G: ~113.6s A100: ~43.5s RTX 4090: ~48.1s
145
  label="Negative Prompt",
146
  value="blurry, ugly, duplicate, poorly drawn, deformed, mosaic",
147
  )
 
 
 
 
 
 
 
148
  scale = gr.Slider(
149
  minimum=1,
150
  maximum=5,
151
  value=2,
152
  step=1,
153
  label="x Scale",
154
- interactive=False,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
155
  )
156
  seed = gr.Slider(
157
  minimum=0,
@@ -165,8 +214,19 @@ GPU Time Comparison: T4: ~276s - A10G: ~113.6s A100: ~43.5s RTX 4090: ~48.1s
165
  with gr.Column(scale=2):
166
  image_slider = ImageSlider(position=0.5)
167
  files = gr.Files()
168
- # inputs = [image_input, prompt, negative_prompt, seed, scale]
169
- inputs = [image_input, prompt, negative_prompt, seed]
 
 
 
 
 
 
 
 
 
 
 
170
  outputs = [image_slider, files]
171
  btn.click(predict, inputs=inputs, outputs=outputs, concurrency_limit=1)
172
  gr.Examples(
@@ -177,6 +237,12 @@ GPU Time Comparison: T4: ~276s - A10G: ~113.6s A100: ~43.5s RTX 4090: ~48.1s
177
  "photography of lara croft 8k high definition award winning",
178
  "blurry, ugly, duplicate, poorly drawn, deformed, mosaic",
179
  5436236241,
 
 
 
 
 
 
180
  2,
181
  ],
182
  [
@@ -184,6 +250,12 @@ GPU Time Comparison: T4: ~276s - A10G: ~113.6s A100: ~43.5s RTX 4090: ~48.1s
184
  "photo of tesla cybertruck futuristic car 8k high definition on a sand dune in mars, future",
185
  "blurry, ugly, duplicate, poorly drawn, deformed, mosaic",
186
  383472451451,
 
 
 
 
 
 
187
  2,
188
  ],
189
  [
@@ -191,6 +263,7 @@ GPU Time Comparison: T4: ~276s - A10G: ~113.6s A100: ~43.5s RTX 4090: ~48.1s
191
  "a photorealistic painting of Jesus Christ, 4k high definition",
192
  "blurry, ugly, duplicate, poorly drawn, deformed, mosaic",
193
  13317204146129588000,
 
194
  2,
195
  ],
196
  [
@@ -198,6 +271,12 @@ GPU Time Comparison: T4: ~276s - A10G: ~113.6s A100: ~43.5s RTX 4090: ~48.1s
198
  "A crowded stadium with enthusiastic fans watching a daytime sporting event, the stands filled with colorful attire and the sun casting a warm glow",
199
  "blurry, ugly, duplicate, poorly drawn, deformed, mosaic",
200
  5623124123512,
 
 
 
 
 
 
201
  2,
202
  ],
203
  [
@@ -205,12 +284,30 @@ GPU Time Comparison: T4: ~276s - A10G: ~113.6s A100: ~43.5s RTX 4090: ~48.1s
205
  "a large red flower on a black background 4k high definition",
206
  "blurry, ugly, duplicate, poorly drawn, deformed, mosaic",
207
  23123412341234,
 
 
 
 
 
 
208
  2,
209
  ],
 
 
 
 
 
 
 
 
 
 
 
 
210
  ],
211
  inputs=inputs,
212
  outputs=outputs,
213
- cache_examples=True,
214
  )
215
 
216
 
 
2
  from gradio_imageslider import ImageSlider
3
  import torch
4
  from diffusers import DiffusionPipeline, AutoencoderKL
5
+ from compel import Compel, ReturnedEmbeddingsType
6
  from PIL import Image
7
  from torchvision import transforms
8
  import tempfile
 
31
  use_safetensors=True,
32
  vae=vae,
33
  )
34
+ compel = Compel(
35
+ tokenizer=[pipe.tokenizer, pipe.tokenizer_2],
36
+ text_encoder=[pipe.text_encoder, pipe.text_encoder_2],
37
+ returned_embeddings_type=ReturnedEmbeddingsType.PENULTIMATE_HIDDEN_STATES_NON_NORMALIZED,
38
+ requires_pooled=[False, True],
39
+ )
40
  pipe = pipe.to(device)
41
 
42
 
 
76
  prompt,
77
  negative_prompt,
78
  seed,
79
+ guidance_scale=8.5,
80
+ cosine_scale_1=3,
81
+ cosine_scale_2=1,
82
+ cosine_scale_3=1,
83
+ sigma=0.8,
84
  scale=2,
85
  progress=gr.Progress(track_tqdm=True),
86
  ):
 
88
  raise gr.Error("Please upload an image.")
89
  padded_image = pad_image(input_image).resize((1024, 1024)).convert("RGB")
90
  image_lr = load_and_process_image(padded_image).to(device)
91
+ conditioning, pooled = compel([prompt, negative_prompt])
92
  generator = torch.manual_seed(seed)
93
  last_time = time.time()
94
  images = pipe(
95
+ prompt_embeds=conditioning[0:1],
96
+ pooled_prompt_embeds=pooled[0:1],
97
+ negative_prompt_embeds=conditioning[1:2],
98
+ negative_pooled_prompt_embeds=pooled[1:2],
99
  image_lr=image_lr,
100
  width=1024 * scale,
101
  height=1024 * scale,
 
103
  stride=64,
104
  generator=generator,
105
  num_inference_steps=40,
106
+ guidance_scale=guidance_scale,
107
+ cosine_scale_1=cosine_scale_1,
108
+ cosine_scale_2=cosine_scale_2,
109
+ cosine_scale_3=cosine_scale_3,
110
+ sigma=sigma,
111
  multi_decoder=1024 * scale > 2048,
112
  show_image=False,
113
  lowvram=LOW_MEMORY,
 
159
  label="Negative Prompt",
160
  value="blurry, ugly, duplicate, poorly drawn, deformed, mosaic",
161
  )
162
+ guidance_scale = gr.Slider(
163
+ minimum=0,
164
+ maximum=50,
165
+ value=8.5,
166
+ step=0.001,
167
+ label="Guidance Scale",
168
+ )
169
  scale = gr.Slider(
170
  minimum=1,
171
  maximum=5,
172
  value=2,
173
  step=1,
174
  label="x Scale",
175
+ interactive=True,
176
+ )
177
+ cosine_scale_1 = gr.Slider(
178
+ minimum=0,
179
+ maximum=5,
180
+ value=3,
181
+ step=0.01,
182
+ label="Cosine Scale 1",
183
+ )
184
+ cosine_scale_2 = gr.Slider(
185
+ minimum=0,
186
+ maximum=5,
187
+ value=1,
188
+ step=0.01,
189
+ label="Cosine Scale 2",
190
+ )
191
+ cosine_scale_3 = gr.Slider(
192
+ minimum=0,
193
+ maximum=5,
194
+ value=1,
195
+ step=0.01,
196
+ label="Cosine Scale 3",
197
+ )
198
+ sigma = gr.Slider(
199
+ minimum=0,
200
+ maximum=1,
201
+ value=0.8,
202
+ step=0.01,
203
+ label="Sigma",
204
  )
205
  seed = gr.Slider(
206
  minimum=0,
 
214
  with gr.Column(scale=2):
215
  image_slider = ImageSlider(position=0.5)
216
  files = gr.Files()
217
+ inputs = [
218
+ image_input,
219
+ prompt,
220
+ negative_prompt,
221
+ seed,
222
+ guidance_scale,
223
+ cosine_scale_1,
224
+ cosine_scale_2,
225
+ cosine_scale_3,
226
+ sigma,
227
+ scale,
228
+ ]
229
+ # inputs = [image_input, prompt, negative_prompt, seed]
230
  outputs = [image_slider, files]
231
  btn.click(predict, inputs=inputs, outputs=outputs, concurrency_limit=1)
232
  gr.Examples(
 
237
  "photography of lara croft 8k high definition award winning",
238
  "blurry, ugly, duplicate, poorly drawn, deformed, mosaic",
239
  5436236241,
240
+ 8.5,
241
+ 3,
242
+ 1,
243
+ 1,
244
+ 1,
245
+ 0.8,
246
  2,
247
  ],
248
  [
 
250
  "photo of tesla cybertruck futuristic car 8k high definition on a sand dune in mars, future",
251
  "blurry, ugly, duplicate, poorly drawn, deformed, mosaic",
252
  383472451451,
253
+ 8.5,
254
+ 3,
255
+ 1,
256
+ 1,
257
+ 1,
258
+ 0.8,
259
  2,
260
  ],
261
  [
 
263
  "a photorealistic painting of Jesus Christ, 4k high definition",
264
  "blurry, ugly, duplicate, poorly drawn, deformed, mosaic",
265
  13317204146129588000,
266
+ 8.5,
267
  2,
268
  ],
269
  [
 
271
  "A crowded stadium with enthusiastic fans watching a daytime sporting event, the stands filled with colorful attire and the sun casting a warm glow",
272
  "blurry, ugly, duplicate, poorly drawn, deformed, mosaic",
273
  5623124123512,
274
+ 8.5,
275
+ 3,
276
+ 1,
277
+ 1,
278
+ 1,
279
+ 0.8,
280
  2,
281
  ],
282
  [
 
284
  "a large red flower on a black background 4k high definition",
285
  "blurry, ugly, duplicate, poorly drawn, deformed, mosaic",
286
  23123412341234,
287
+ 8.5,
288
+ 3,
289
+ 1,
290
+ 1,
291
+ 1,
292
+ 0.8,
293
  2,
294
  ],
295
+ [
296
+ "./examples/huggingface.jpg",
297
+ "photo realistic huggingface human+++ emoji costume, round, yellow, skin+++ texture+++",
298
+ "blurry, ugly, duplicate, poorly drawn, deformed, mosaic, emoji cartoon, drawing, pixelated",
299
+ 5532144938416372000,
300
+ 20.0,
301
+ 4.64,
302
+ 1,
303
+ 1,
304
+ 0.49,
305
+ 3,
306
+ ],
307
  ],
308
  inputs=inputs,
309
  outputs=outputs,
310
+ cache_examples=False,
311
  )
312
 
313
 
requirements.txt CHANGED
@@ -10,4 +10,5 @@ accelerate
10
  invisible-watermark
11
  huggingface-hub
12
  hf-transfer
13
- gradio_imageslider==0.0.16
 
 
10
  invisible-watermark
11
  huggingface-hub
12
  hf-transfer
13
+ gradio_imageslider==0.0.16
14
+ compel