Sergidev commited on
Commit
23daa6d
·
verified ·
1 Parent(s): 0388d1b

History v1

Browse files
Files changed (1) hide show
  1. app.py +166 -165
app.py CHANGED
@@ -12,13 +12,17 @@ from PIL import Image, PngImagePlugin
12
  from datetime import datetime
13
  from diffusers.models import AutoencoderKL
14
  from diffusers import StableDiffusionXLPipeline, StableDiffusionXLImg2ImgPipeline
 
 
 
15
 
16
  logging.basicConfig(level=logging.INFO)
17
  logger = logging.getLogger(__name__)
18
 
19
  DESCRIPTION = "PonyDiffusion V6 XL"
20
  if not torch.cuda.is_available():
21
- DESCRIPTION += "\n<p>Running on CPU 🥶 This demo does not work on CPU. </p>"
 
22
  IS_COLAB = utils.is_google_colab() or os.getenv("IS_COLAB") == "1"
23
  HF_TOKEN = os.getenv("HF_TOKEN")
24
  CACHE_EXAMPLES = torch.cuda.is_available() and os.getenv("CACHE_EXAMPLES") == "1"
@@ -27,7 +31,6 @@ MAX_IMAGE_SIZE = int(os.getenv("MAX_IMAGE_SIZE", "2048"))
27
  USE_TORCH_COMPILE = os.getenv("USE_TORCH_COMPILE") == "1"
28
  ENABLE_CPU_OFFLOAD = os.getenv("ENABLE_CPU_OFFLOAD") == "1"
29
  OUTPUT_DIR = os.getenv("OUTPUT_DIR", "./outputs")
30
-
31
  MODEL = os.getenv(
32
  "MODEL",
33
  "https://huggingface.co/AstraliteHeart/pony-diffusion-v6/blob/main/v6.safetensors",
@@ -38,6 +41,8 @@ torch.backends.cudnn.benchmark = False
38
 
39
  device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
40
 
 
 
41
 
42
  def load_pipeline(model_name):
43
  vae = AutoencoderKL.from_pretrained(
@@ -49,7 +54,6 @@ def load_pipeline(model_name):
49
  if MODEL.endswith(".safetensors")
50
  else StableDiffusionXLPipeline.from_pretrained
51
  )
52
-
53
  pipe = pipeline(
54
  model_name,
55
  vae=vae,
@@ -60,11 +64,9 @@ def load_pipeline(model_name):
60
  use_auth_token=HF_TOKEN,
61
  variant="fp16",
62
  )
63
-
64
  pipe.to(device)
65
  return pipe
66
 
67
-
68
  @spaces.GPU
69
  def generate(
70
  prompt: str,
@@ -82,20 +84,16 @@ def generate(
82
  progress=gr.Progress(track_tqdm=True),
83
  ) -> Image:
84
  generator = utils.seed_everything(seed)
85
-
86
  width, height = utils.aspect_ratio_handler(
87
- aspect_ratio_selector,
88
- custom_width,
89
- custom_height,
90
  )
91
-
92
  width, height = utils.preprocess_image_dimensions(width, height)
93
-
94
  backup_scheduler = pipe.scheduler
95
  pipe.scheduler = utils.get_scheduler(pipe.scheduler.config, sampler)
96
 
97
  if use_upscaler:
98
  upscaler_pipe = StableDiffusionXLImg2ImgPipeline(**pipe.components)
 
99
  metadata = {
100
  "prompt": prompt,
101
  "negative_prompt": negative_prompt,
@@ -117,6 +115,7 @@ def generate(
117
  }
118
  else:
119
  metadata["use_upscaler"] = None
 
120
  logger.info(json.dumps(metadata, indent=4))
121
 
122
  try:
@@ -154,12 +153,34 @@ def generate(
154
  output_type="pil",
155
  ).images
156
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
157
  if images and IS_COLAB:
158
  for image in images:
159
  filepath = utils.save_image(image, metadata, OUTPUT_DIR)
160
  logger.info(f"Image saved as {filepath} with metadata")
161
 
162
- return images, metadata
 
163
  except Exception as e:
164
  logger.exception(f"An error occurred: {e}")
165
  raise
@@ -169,7 +190,6 @@ def generate(
169
  pipe.scheduler = backup_scheduler
170
  utils.free_memory()
171
 
172
-
173
  if torch.cuda.is_available():
174
  pipe = load_pipeline(MODEL)
175
  logger.info("Loaded on Device!")
@@ -178,52 +198,32 @@ else:
178
 
179
  with gr.Blocks(css="style.css") as demo:
180
  title = gr.HTML(
181
- f"""<h1><span>{DESCRIPTION}</span></h1>""",
182
- elem_id="title",
183
- )
184
- gr.Markdown(
185
- f"""Gradio demo for ([Pony Diffusion V6]https://civitai.com/models/257749/pony-diffusion-v6-xl/)""",
186
- elem_id="subtitle",
187
  )
188
- gr.DuplicateButton(
189
- value="Duplicate Space for private use",
190
- elem_id="duplicate-button",
191
- visible=os.getenv("SHOW_DUPLICATE_BUTTON") == "1",
192
- )
193
- with gr.Group():
194
- with gr.Row():
195
- prompt = gr.Text(
196
  label="Prompt",
197
  show_label=False,
198
- max_lines=5,
199
  placeholder="Enter your prompt",
200
- container=False,
201
  )
202
- run_button = gr.Button(
203
- "Generate",
204
- variant="primary",
205
- scale=0
 
206
  )
207
- result = gr.Gallery(
208
- label="Result",
209
- columns=1,
210
- preview=True,
211
- show_label=False
212
- )
213
- with gr.Accordion(label="Advanced Settings", open=False):
214
- negative_prompt = gr.Text(
215
- label="Negative Prompt",
216
- max_lines=5,
217
- placeholder="Enter a negative prompt",
218
- value=""
219
- )
220
- aspect_ratio_selector = gr.Radio(
221
- label="Aspect Ratio",
222
- choices=config.aspect_ratios,
223
- value="1024 x 1024",
224
- container=True,
225
- )
226
- with gr.Group(visible=False) as custom_resolution:
227
  with gr.Row():
228
  custom_width = gr.Slider(
229
  label="Width",
@@ -239,125 +239,126 @@ with gr.Blocks(css="style.css") as demo:
239
  step=8,
240
  value=1024,
241
  )
242
- use_upscaler = gr.Checkbox(label="Use Upscaler", value=False)
243
- with gr.Row() as upscaler_row:
244
- upscaler_strength = gr.Slider(
245
- label="Strength",
246
- minimum=0,
247
- maximum=1,
248
- step=0.05,
249
- value=0.55,
250
- visible=False,
251
- )
252
- upscale_by = gr.Slider(
253
- label="Upscale by",
254
- minimum=1,
255
- maximum=1.5,
256
- step=0.1,
257
- value=1.5,
258
- visible=False,
259
- )
260
 
261
- sampler = gr.Dropdown(
262
- label="Sampler",
263
- choices=config.sampler_list,
264
- interactive=True,
265
- value="DPM++ 2M SDE Karras",
266
- )
267
- with gr.Row():
268
- seed = gr.Slider(
269
- label="Seed", minimum=0, maximum=utils.MAX_SEED, step=1, value=0
270
- )
271
- randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
272
- with gr.Group():
273
  with gr.Row():
274
  guidance_scale = gr.Slider(
275
- label="Guidance scale",
276
- minimum=1,
277
- maximum=12,
278
- step=0.1,
279
- value=7.0,
280
  )
281
  num_inference_steps = gr.Slider(
282
- label="Number of inference steps",
283
  minimum=1,
284
- maximum=50,
285
  step=1,
286
- value=28,
287
  )
288
- with gr.Accordion(label="Generation Parameters", open=False):
289
- gr_metadata = gr.JSON(label="Metadata", show_label=False)
290
- gr.Examples(
291
- examples=config.examples,
292
- inputs=prompt,
293
- outputs=[result, gr_metadata],
294
- fn=lambda *args, **kwargs: generate(*args, use_upscaler=True, **kwargs),
295
- cache_examples=CACHE_EXAMPLES,
296
- )
297
- use_upscaler.change(
298
- fn=lambda x: [gr.update(visible=x), gr.update(visible=x)],
299
- inputs=use_upscaler,
300
- outputs=[upscaler_strength, upscale_by],
301
- queue=False,
302
- api_name=False,
303
- )
304
- aspect_ratio_selector.change(
305
- fn=lambda x: gr.update(visible=x == "Custom"),
306
- inputs=aspect_ratio_selector,
307
- outputs=custom_resolution,
308
- queue=False,
309
- api_name=False,
310
- )
311
 
312
- inputs = [
313
- prompt,
314
- negative_prompt,
315
- seed,
316
- custom_width,
317
- custom_height,
318
- guidance_scale,
319
- num_inference_steps,
320
- sampler,
321
- aspect_ratio_selector,
322
- use_upscaler,
323
- upscaler_strength,
324
- upscale_by,
325
- ]
326
-
327
- prompt.submit(
328
- fn=utils.randomize_seed_fn,
329
- inputs=[seed, randomize_seed],
330
- outputs=seed,
331
- queue=False,
332
- api_name=False,
333
- ).then(
334
- fn=generate,
335
- inputs=inputs,
336
- outputs=result,
337
- api_name="run",
338
- )
339
- negative_prompt.submit(
340
- fn=utils.randomize_seed_fn,
341
- inputs=[seed, randomize_seed],
342
- outputs=seed,
343
- queue=False,
344
- api_name=False,
345
- ).then(
346
- fn=generate,
347
- inputs=inputs,
348
- outputs=result,
349
- api_name=False,
350
- )
351
- run_button.click(
352
- fn=utils.randomize_seed_fn,
353
- inputs=[seed, randomize_seed],
354
- outputs=seed,
355
- queue=False,
356
- api_name=False,
357
- ).then(
358
- fn=generate,
359
- inputs=inputs,
360
- outputs=[result, gr_metadata],
361
- api_name=False,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
362
  )
363
- demo.queue(max_size=20).launch(debug=IS_COLAB, share=IS_COLAB)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
12
  from datetime import datetime
13
  from diffusers.models import AutoencoderKL
14
  from diffusers import StableDiffusionXLPipeline, StableDiffusionXLImg2ImgPipeline
15
+ from collections import deque
16
+ import base64
17
+ from io import BytesIO
18
 
19
  logging.basicConfig(level=logging.INFO)
20
  logger = logging.getLogger(__name__)
21
 
22
  DESCRIPTION = "PonyDiffusion V6 XL"
23
  if not torch.cuda.is_available():
24
+ DESCRIPTION += "\n\nRunning on CPU 🥶 This demo does not work on CPU."
25
+
26
  IS_COLAB = utils.is_google_colab() or os.getenv("IS_COLAB") == "1"
27
  HF_TOKEN = os.getenv("HF_TOKEN")
28
  CACHE_EXAMPLES = torch.cuda.is_available() and os.getenv("CACHE_EXAMPLES") == "1"
 
31
  USE_TORCH_COMPILE = os.getenv("USE_TORCH_COMPILE") == "1"
32
  ENABLE_CPU_OFFLOAD = os.getenv("ENABLE_CPU_OFFLOAD") == "1"
33
  OUTPUT_DIR = os.getenv("OUTPUT_DIR", "./outputs")
 
34
  MODEL = os.getenv(
35
  "MODEL",
36
  "https://huggingface.co/AstraliteHeart/pony-diffusion-v6/blob/main/v6.safetensors",
 
41
 
42
  device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
43
 
44
+ MAX_HISTORY_SIZE = 10
45
+ image_history = deque(maxlen=MAX_HISTORY_SIZE)
46
 
47
  def load_pipeline(model_name):
48
  vae = AutoencoderKL.from_pretrained(
 
54
  if MODEL.endswith(".safetensors")
55
  else StableDiffusionXLPipeline.from_pretrained
56
  )
 
57
  pipe = pipeline(
58
  model_name,
59
  vae=vae,
 
64
  use_auth_token=HF_TOKEN,
65
  variant="fp16",
66
  )
 
67
  pipe.to(device)
68
  return pipe
69
 
 
70
  @spaces.GPU
71
  def generate(
72
  prompt: str,
 
84
  progress=gr.Progress(track_tqdm=True),
85
  ) -> Image:
86
  generator = utils.seed_everything(seed)
 
87
  width, height = utils.aspect_ratio_handler(
88
+ aspect_ratio_selector, custom_width, custom_height,
 
 
89
  )
 
90
  width, height = utils.preprocess_image_dimensions(width, height)
 
91
  backup_scheduler = pipe.scheduler
92
  pipe.scheduler = utils.get_scheduler(pipe.scheduler.config, sampler)
93
 
94
  if use_upscaler:
95
  upscaler_pipe = StableDiffusionXLImg2ImgPipeline(**pipe.components)
96
+
97
  metadata = {
98
  "prompt": prompt,
99
  "negative_prompt": negative_prompt,
 
115
  }
116
  else:
117
  metadata["use_upscaler"] = None
118
+
119
  logger.info(json.dumps(metadata, indent=4))
120
 
121
  try:
 
153
  output_type="pil",
154
  ).images
155
 
156
+ if images:
157
+ for image in images:
158
+ # Create thumbnail
159
+ thumbnail = image.copy()
160
+ thumbnail.thumbnail((256, 256))
161
+
162
+ # Convert thumbnail to base64
163
+ buffered = BytesIO()
164
+ thumbnail.save(buffered, format="PNG")
165
+ img_str = base64.b64encode(buffered.getvalue()).decode()
166
+
167
+ # Add to history
168
+ image_history.appendleft({
169
+ "thumbnail": f"data:image/png;base64,{img_str}",
170
+ "prompt": prompt,
171
+ "negative_prompt": negative_prompt,
172
+ "seed": seed,
173
+ "width": width,
174
+ "height": height,
175
+ })
176
+
177
  if images and IS_COLAB:
178
  for image in images:
179
  filepath = utils.save_image(image, metadata, OUTPUT_DIR)
180
  logger.info(f"Image saved as {filepath} with metadata")
181
 
182
+ return images, metadata, list(image_history)
183
+
184
  except Exception as e:
185
  logger.exception(f"An error occurred: {e}")
186
  raise
 
190
  pipe.scheduler = backup_scheduler
191
  utils.free_memory()
192
 
 
193
  if torch.cuda.is_available():
194
  pipe = load_pipeline(MODEL)
195
  logger.info("Loaded on Device!")
 
198
 
199
  with gr.Blocks(css="style.css") as demo:
200
  title = gr.HTML(
201
+ f"""<h1>{DESCRIPTION}</h1>"""
 
 
 
 
 
202
  )
203
+
204
+ with gr.Row():
205
+ with gr.Column(scale=2):
206
+ prompt = gr.Textbox(
 
 
 
 
207
  label="Prompt",
208
  show_label=False,
209
+ max_lines=2,
210
  placeholder="Enter your prompt",
 
211
  )
212
+ negative_prompt = gr.Textbox(
213
+ label="Negative Prompt",
214
+ show_label=False,
215
+ max_lines=2,
216
+ placeholder="Enter a negative prompt",
217
  )
218
+
219
+ with gr.Row():
220
+ seed = gr.Number(
221
+ label="Seed",
222
+ value=0,
223
+ precision=0,
224
+ )
225
+ randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
226
+
 
 
 
 
 
 
 
 
 
 
 
227
  with gr.Row():
228
  custom_width = gr.Slider(
229
  label="Width",
 
239
  step=8,
240
  value=1024,
241
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
242
 
 
 
 
 
 
 
 
 
 
 
 
 
243
  with gr.Row():
244
  guidance_scale = gr.Slider(
245
+ label="Guidance Scale", minimum=0, maximum=20, step=0.1, value=7
 
 
 
 
246
  )
247
  num_inference_steps = gr.Slider(
248
+ label="Num Inference Steps",
249
  minimum=1,
250
+ maximum=100,
251
  step=1,
252
+ value=30,
253
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
254
 
255
+ with gr.Row():
256
+ sampler = gr.Dropdown(
257
+ label="Sampler",
258
+ choices=[
259
+ "DPM++ 2M SDE Karras",
260
+ "DPM++ 2M SDE",
261
+ "Euler a",
262
+ "Euler",
263
+ "DPM++ 2M Karras",
264
+ "DPM++ 2M",
265
+ "LMS Karras",
266
+ "Heun",
267
+ "DPM++ SDE Karras",
268
+ "DPM++ SDE",
269
+ "DPM2 Karras",
270
+ "DPM2",
271
+ "DPM2 a Karras",
272
+ "DPM2 a",
273
+ "LMS",
274
+ "DDIM",
275
+ "PLMS",
276
+ ],
277
+ value="DPM++ 2M SDE Karras",
278
+ )
279
+ aspect_ratio_selector = gr.Dropdown(
280
+ label="Aspect Ratio",
281
+ choices=[
282
+ "1024 x 1024",
283
+ "1152 x 896",
284
+ "896 x 1152",
285
+ "1216 x 832",
286
+ "832 x 1216",
287
+ "1344 x 768",
288
+ "768 x 1344",
289
+ "1536 x 640",
290
+ "640 x 1536",
291
+ ],
292
+ value="1024 x 1024",
293
+ )
294
+
295
+ with gr.Row():
296
+ use_upscaler = gr.Checkbox(label="Use Upscaler", value=False)
297
+ upscaler_strength = gr.Slider(
298
+ label="Upscaler Strength",
299
+ minimum=0,
300
+ maximum=1,
301
+ step=0.05,
302
+ value=0.55,
303
+ )
304
+ upscale_by = gr.Slider(
305
+ label="Upscale By",
306
+ minimum=1,
307
+ maximum=4,
308
+ step=0.1,
309
+ value=1.5,
310
+ )
311
+
312
+ with gr.Column(scale=1):
313
+ output_image = gr.Image(label="Generated Image")
314
+ output_text = gr.JSON(label="Generation Info")
315
+
316
+ with gr.Row():
317
+ generate_button = gr.Button("Generate")
318
+
319
+ # Add the history component
320
+ history = gr.HTML(label="Generation History")
321
+
322
+ # Update the generate_button click event
323
+ generate_button.click(
324
+ generate,
325
+ inputs=[
326
+ prompt,
327
+ negative_prompt,
328
+ seed,
329
+ custom_width,
330
+ custom_height,
331
+ guidance_scale,
332
+ num_inference_steps,
333
+ sampler,
334
+ aspect_ratio_selector,
335
+ use_upscaler,
336
+ upscaler_strength,
337
+ upscale_by,
338
+ ],
339
+ outputs=[output_image, output_text, history],
340
  )
341
+
342
+ # Add a function to update the history display
343
+ def update_history(history_data):
344
+ html = "<div class='history-container'>"
345
+ for item in history_data:
346
+ html += f"""
347
+ <div class='history-item'>
348
+ <img src='{item['thumbnail']}' alt='Generated Image'>
349
+ <div class='history-info'>
350
+ <p><strong>Prompt:</strong> {item['prompt']}</p>
351
+ <p><strong>Negative Prompt:</strong> {item['negative_prompt']}</p>
352
+ <p><strong>Seed:</strong> {item['seed']}</p>
353
+ <p><strong>Size:</strong> {item['width']}x{item['height']}</p>
354
+ </div>
355
+ </div>
356
+ """
357
+ html += "</div>"
358
+ return html
359
+
360
+ # Connect the update_history function to the history component
361
+ history.change(update_history, inputs=[history], outputs=[history])
362
+
363
+ demo.queue(concurrency_count=1, max_size=20)
364
+ demo.launch(debug=True)