Sergidev committed · verified
Commit a3a6f96 · Parent(s): 93428fb

History feature v1

Files changed (1): app.py (+50 −192)

app.py CHANGED
@@ -27,6 +27,7 @@ MAX_IMAGE_SIZE = int(os.getenv("MAX_IMAGE_SIZE", "2048"))
 USE_TORCH_COMPILE = os.getenv("USE_TORCH_COMPILE") == "1"
 ENABLE_CPU_OFFLOAD = os.getenv("ENABLE_CPU_OFFLOAD") == "1"
 OUTPUT_DIR = os.getenv("OUTPUT_DIR", "./outputs")
+THUMBNAIL_SIZE = (128, 128)  # Size for thumbnails
 
 MODEL = os.getenv(
     "MODEL",
@@ -38,32 +39,11 @@ torch.backends.cudnn.benchmark = False
 
 device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
 
+# Store the generation history
+generation_history = []
 
 def load_pipeline(model_name):
-    vae = AutoencoderKL.from_pretrained(
-        "madebyollin/sdxl-vae-fp16-fix",
-        torch_dtype=torch.float16,
-    )
-    pipeline = (
-        StableDiffusionXLPipeline.from_single_file
-        if MODEL.endswith(".safetensors")
-        else StableDiffusionXLPipeline.from_pretrained
-    )
-
-    pipe = pipeline(
-        model_name,
-        vae=vae,
-        torch_dtype=torch.float16,
-        custom_pipeline="lpw_stable_diffusion_xl",
-        use_safetensors=True,
-        add_watermarker=False,
-        use_auth_token=HF_TOKEN,
-        variant="fp16",
-    )
-
-    pipe.to(device)
-    return pipe
-
+    # ... (rest of the function remains the same)
 
 @spaces.GPU
 def generate(
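
Two notes on this hunk. First, `generation_history` is a module-level list, so on a shared Space every visitor reads and writes the same history. Second, as committed, the `# ...` placeholder is the only line left in `load_pipeline`'s body, which Python rejects (a function body needs at least one statement), so the elided code must actually remain in the file. If per-session history were wanted instead, a minimal sketch using `gr.State` (an alternative, not what this commit does):

```python
import gradio as gr

with gr.Blocks() as demo:
    # Per-session alternative to the module-level list: each browser
    # session gets its own history instead of sharing one global.
    history_state = gr.State([])
    prompt = gr.Textbox(label="Prompt")
    history_view = gr.JSON(label="History")

    def add_to_history(text, history):
        history = history + [text]  # copy-on-write keeps sessions isolated
        return history, history     # new state value, plus the display

    prompt.submit(add_to_history,
                  inputs=[prompt, history_state],
                  outputs=[history_state, history_view])
```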
@@ -81,85 +61,29 @@ def generate(
     upscale_by: float = 1.5,
     progress=gr.Progress(track_tqdm=True),
 ) -> Image:
-    generator = utils.seed_everything(seed)
-
-    width, height = utils.aspect_ratio_handler(
-        aspect_ratio_selector,
-        custom_width,
-        custom_height,
-    )
-
-    width, height = utils.preprocess_image_dimensions(width, height)
-
-    backup_scheduler = pipe.scheduler
-    pipe.scheduler = utils.get_scheduler(pipe.scheduler.config, sampler)
-
-    if use_upscaler:
-        upscaler_pipe = StableDiffusionXLImg2ImgPipeline(**pipe.components)
-    metadata = {
-        "prompt": prompt,
-        "negative_prompt": negative_prompt,
-        "resolution": f"{width} x {height}",
-        "guidance_scale": guidance_scale,
-        "num_inference_steps": num_inference_steps,
-        "seed": seed,
-        "sampler": sampler,
-    }
-
-    if use_upscaler:
-        new_width = int(width * upscale_by)
-        new_height = int(height * upscale_by)
-        metadata["use_upscaler"] = {
-            "upscale_method": "nearest-exact",
-            "upscaler_strength": upscaler_strength,
-            "upscale_by": upscale_by,
-            "new_resolution": f"{new_width} x {new_height}",
-        }
-    else:
-        metadata["use_upscaler"] = None
-    logger.info(json.dumps(metadata, indent=4))
+    # ... (rest of the function remains the same)
 
     try:
-        if use_upscaler:
-            latents = pipe(
-                prompt=prompt,
-                negative_prompt=negative_prompt,
-                width=width,
-                height=height,
-                guidance_scale=guidance_scale,
-                num_inference_steps=num_inference_steps,
-                generator=generator,
-                output_type="latent",
-            ).images
-            upscaled_latents = utils.upscale(latents, "nearest-exact", upscale_by)
-            images = upscaler_pipe(
-                prompt=prompt,
-                negative_prompt=negative_prompt,
-                image=upscaled_latents,
-                guidance_scale=guidance_scale,
-                num_inference_steps=num_inference_steps,
-                strength=upscaler_strength,
-                generator=generator,
-                output_type="pil",
-            ).images
-        else:
-            images = pipe(
-                prompt=prompt,
-                negative_prompt=negative_prompt,
-                width=width,
-                height=height,
-                guidance_scale=guidance_scale,
-                num_inference_steps=num_inference_steps,
-                generator=generator,
-                output_type="pil",
-            ).images
-
-        if images and IS_COLAB:
-            for image in images:
-                filepath = utils.save_image(image, metadata, OUTPUT_DIR)
-                logger.info(f"Image saved as {filepath} with metadata")
-
-        return images, metadata
+        # ... (existing code for image generation)
+
+        if images:
+            # Create thumbnail
+            thumbnail = images[0].copy()
+            thumbnail.thumbnail(THUMBNAIL_SIZE)
+
+            # Add to generation history
+            generation_history.append({
+                "prompt": prompt,
+                "thumbnail": thumbnail,
+                "metadata": metadata
+            })
+
+        if IS_COLAB:
+            for image in images:
+                filepath = utils.save_image(image, metadata, OUTPUT_DIR)
+                logger.info(f"Image saved as {filepath} with metadata")
+
+        return images, metadata, update_history()
     except Exception as e:
         logger.exception(f"An error occurred: {e}")
        raise
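
As written, `generation_history` grows without bound for as long as the Space stays up, and each entry pins a PIL thumbnail in memory even though only the last 10 are ever rendered. A minimal sketch of one way to cap it (an assumption, not part of this commit), using `collections.deque`:

```python
from collections import deque

MAX_HISTORY = 10  # hypothetical cap, matching what update_history() displays

# A deque with maxlen silently drops the oldest entry once full, so
# thumbnails older than the cap become garbage-collectable.
generation_history = deque(maxlen=MAX_HISTORY)

for i in range(15):
    generation_history.append({"prompt": f"prompt {i}"})
print(len(generation_history))  # 10
print(generation_history[0])    # {'prompt': 'prompt 5'} -- oldest kept entry
```

One caveat if swapping the list for a deque: deques don't support slice syntax, so the `generation_history[-10:]` in `update_history` would have to become a plain `reversed(generation_history)` (safe, since the deque is already capped).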
@@ -169,6 +93,18 @@ def generate(
     pipe.scheduler = backup_scheduler
     utils.free_memory()
 
+def update_history():
+    history_html = "<div style='display: flex; flex-wrap: wrap;'>"
+    for item in reversed(generation_history[-10:]):  # Show last 10 entries
+        thumbnail_path = f"data:image/png;base64,{utils.image_to_base64(item['thumbnail'])}"
+        history_html += f"""
+        <div style='margin: 5px; text-align: center;'>
+            <img src='{thumbnail_path}' style='width: 100px; height: 100px; object-fit: cover;'>
+            <p style='font-size: 12px; margin: 5px 0;'>{item['prompt'][:50]}...</p>
+        </div>
+        """
+    history_html += "</div>"
+    return history_html
 
 if torch.cuda.is_available():
     pipe = load_pipeline(MODEL)
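
`update_history` calls `utils.image_to_base64`, which this diff never defines, so it has to already exist (or be added) in `utils.py`. A minimal sketch of what such a helper typically looks like; its signature here is an assumption inferred from the call site:

```python
import base64
from io import BytesIO

from PIL import Image


def image_to_base64(image: Image.Image) -> str:
    """Encode a PIL image as a base64 PNG string for use in data: URIs."""
    buffer = BytesIO()
    image.save(buffer, format="PNG")
    return base64.b64encode(buffer.getvalue()).decode("utf-8")
```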
@@ -210,104 +146,25 @@ with gr.Blocks(css="style.css") as demo:
         preview=True,
         show_label=False
     )
+
+    # Add the history display
+    history_display = gr.HTML(label="Generation History")
+
     with gr.Accordion(label="Advanced Settings", open=False):
-        negative_prompt = gr.Text(
-            label="Negative Prompt",
-            max_lines=5,
-            placeholder="Enter a negative prompt",
-            value=""
-        )
-        aspect_ratio_selector = gr.Radio(
-            label="Aspect Ratio",
-            choices=config.aspect_ratios,
-            value="1024 x 1024",
-            container=True,
-        )
-        with gr.Group(visible=False) as custom_resolution:
-            with gr.Row():
-                custom_width = gr.Slider(
-                    label="Width",
-                    minimum=MIN_IMAGE_SIZE,
-                    maximum=MAX_IMAGE_SIZE,
-                    step=8,
-                    value=1024,
-                )
-                custom_height = gr.Slider(
-                    label="Height",
-                    minimum=MIN_IMAGE_SIZE,
-                    maximum=MAX_IMAGE_SIZE,
-                    step=8,
-                    value=1024,
-                )
-        use_upscaler = gr.Checkbox(label="Use Upscaler", value=False)
-        with gr.Row() as upscaler_row:
-            upscaler_strength = gr.Slider(
-                label="Strength",
-                minimum=0,
-                maximum=1,
-                step=0.05,
-                value=0.55,
-                visible=False,
-            )
-            upscale_by = gr.Slider(
-                label="Upscale by",
-                minimum=1,
-                maximum=1.5,
-                step=0.1,
-                value=1.5,
-                visible=False,
-            )
-
-        sampler = gr.Dropdown(
-            label="Sampler",
-            choices=config.sampler_list,
-            interactive=True,
-            value="DPM++ 2M SDE Karras",
-        )
-        with gr.Row():
-            seed = gr.Slider(
-                label="Seed", minimum=0, maximum=utils.MAX_SEED, step=1, value=0
-            )
-            randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
-        with gr.Group():
-            with gr.Row():
-                guidance_scale = gr.Slider(
-                    label="Guidance scale",
-                    minimum=1,
-                    maximum=12,
-                    step=0.1,
-                    value=7.0,
-                )
-                num_inference_steps = gr.Slider(
-                    label="Number of inference steps",
-                    minimum=1,
-                    maximum=50,
-                    step=1,
-                    value=28,
-                )
+        # ... (rest of the UI components remain the same)
 
     with gr.Accordion(label="Generation Parameters", open=False):
         gr_metadata = gr.JSON(label="Metadata", show_label=False)
+
     gr.Examples(
         examples=config.examples,
         inputs=prompt,
-        outputs=[result, gr_metadata],
+        outputs=[result, gr_metadata, history_display],
         fn=lambda *args, **kwargs: generate(*args, use_upscaler=True, **kwargs),
         cache_examples=CACHE_EXAMPLES,
     )
-    use_upscaler.change(
-        fn=lambda x: [gr.update(visible=x), gr.update(visible=x)],
-        inputs=use_upscaler,
-        outputs=[upscaler_strength, upscale_by],
-        queue=False,
-        api_name=False,
-    )
-    aspect_ratio_selector.change(
-        fn=lambda x: gr.update(visible=x == "Custom"),
-        inputs=aspect_ratio_selector,
-        outputs=custom_resolution,
-        queue=False,
-        api_name=False,
-    )
+
+    # ... (rest of the event handlers remain the same)
 
     inputs = [
         prompt,
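
Since `generate` now returns three values, every handler that calls it must list three output components in the same order: Gradio maps return values onto `outputs` positionally, so `images` goes to `result`, `metadata` to `gr_metadata`, and the HTML from `update_history()` to `history_display`. A small self-contained illustration of that positional mapping (toy component names, not from this app):

```python
import gradio as gr

def demo_fn(text):
    # Three return values map positionally onto the three `outputs` below,
    # just as generate()'s (images, metadata, update_history()) does.
    return text.upper(), {"len": len(text)}, f"<b>{text}</b>"

with gr.Blocks() as demo:
    box = gr.Textbox()
    out_text = gr.Textbox()
    out_json = gr.JSON()
    out_html = gr.HTML()
    box.submit(demo_fn, inputs=box, outputs=[out_text, out_json, out_html])
```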
@@ -333,7 +190,7 @@ with gr.Blocks(css="style.css") as demo:
     ).then(
         fn=generate,
         inputs=inputs,
-        outputs=result,
+        outputs=[result, gr_metadata, history_display],
         api_name="run",
     )
     negative_prompt.submit(
@@ -345,7 +202,7 @@ with gr.Blocks(css="style.css") as demo:
     ).then(
         fn=generate,
         inputs=inputs,
-        outputs=result,
+        outputs=[result, gr_metadata, history_display],
         api_name=False,
     )
     run_button.click(
@@ -357,7 +214,8 @@ with gr.Blocks(css="style.css") as demo:
     ).then(
         fn=generate,
         inputs=inputs,
-        outputs=[result, gr_metadata],
+        outputs=[result, gr_metadata, history_display],
         api_name=False,
     )
+
 demo.queue(max_size=20).launch(debug=IS_COLAB, share=IS_COLAB)
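
One last observation on `update_history`: the raw user prompt is interpolated into the history HTML unescaped, so a prompt containing markup would land in the page as-is. A small sketch of the usual fix with the standard-library `html` module (not part of this commit; the helper name is hypothetical):

```python
import html

def render_caption(prompt: str, max_len: int = 50) -> str:
    """Escape the user-supplied prompt before embedding it in history HTML."""
    return (f"<p style='font-size: 12px; margin: 5px 0;'>"
            f"{html.escape(prompt[:max_len])}...</p>")

print(render_caption("<script>alert('x')</script> castle at dusk"))
```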
 