6Morpheus6 commited on
Commit
a2643c3
1 Parent(s): f07b0b1

gradio UI fix

Browse files

- removed gr.Tabs
- improved window resizing
- fixed CSS
- added download element
- added progress bar

Files changed (1) hide show
  1. app.py +67 -40
app.py CHANGED
@@ -3,6 +3,7 @@ sys.path.append('./')
3
 
4
 
5
  import os
 
6
  import cv2
7
  import torch
8
  import random
@@ -161,7 +162,9 @@ def create_image(image_pil,
161
  seed,
162
  target="Load only style blocks",
163
  neg_content_prompt=None,
164
- neg_content_scale=0):
 
 
165
 
166
  if target =="Load original IP-Adapter":
167
  # target_blocks=["blocks"] for original IP-Adapter
@@ -211,13 +214,27 @@ def create_image(image_pil,
211
  image=canny_map,
212
  controlnet_conditioning_scale=float(control_scale),
213
  )
214
- return images
 
 
 
 
 
 
215
 
216
  def pil_to_cv2(image_pil):
217
  image_np = np.array(image_pil)
218
  image_cv2 = cv2.cvtColor(image_np, cv2.COLOR_RGB2BGR)
219
  return image_cv2
220
 
 
 
 
 
 
 
 
 
221
  # Description
222
  title = r"""
223
  <h1 align="center">InstantStyle: Free Lunch towards Style-Preserving in Text-to-Image Generation</h1>
@@ -258,53 +275,58 @@ If our work is helpful for your research or applications, please cite us via:
258
  If you have any questions, please feel free to open an issue or directly reach us out at <b>haofanwang.ai@gmail.com</b>.
259
  """
260
 
261
- block = gr.Blocks(css="footer {visibility: hidden}").queue(max_size=10, api_open=False)
 
 
 
 
 
262
  with block:
263
 
264
  # description
265
  gr.Markdown(title)
266
  #gr.Markdown(description)
267
 
268
- with gr.Tabs():
269
- with gr.Row():
270
- with gr.Column():
 
 
 
 
 
 
 
 
 
 
 
 
 
 
271
 
272
- with gr.Row():
273
- with gr.Column():
274
- image_pil = gr.Image(label="Style Image", type='pil')
275
 
276
- target = gr.Radio(["Load only style blocks", "Load style+layout block", "Load original IP-Adapter"],
277
- value="Load only style blocks",
278
- label="Style mode")
279
 
280
- prompt = gr.Textbox(label="Prompt",
281
- value="a cat, masterpiece, best quality, high quality")
282
-
283
- scale = gr.Slider(minimum=0,maximum=2.0, step=0.01,value=1.0, label="Scale")
284
-
285
- with gr.Accordion(open=False, label="Advanced Options"):
286
-
287
- with gr.Column():
288
- src_image_pil = gr.Image(label="Source Image (optional)", type='pil')
289
- control_scale = gr.Slider(minimum=0,maximum=1.0, step=0.01,value=0.5, label="Controlnet conditioning scale")
290
-
291
- n_prompt = gr.Textbox(label="Neg Prompt", value="text, watermark, lowres, low quality, worst quality, deformed, glitch, low contrast, noisy, saturation, blurry")
292
-
293
- neg_content_prompt = gr.Textbox(label="Neg Content Prompt", value="")
294
- neg_content_scale = gr.Slider(minimum=0, maximum=1.0, step=0.01,value=0.5, label="Neg Content Scale")
295
-
296
- guidance_scale = gr.Slider(minimum=1,maximum=15.0, step=0.01,value=5.0, label="guidance scale")
297
- num_samples= gr.Slider(minimum=1,maximum=4.0, step=1.0,value=1.0, label="num samples")
298
- num_inference_steps = gr.Slider(minimum=5,maximum=50.0, step=1.0,value=20, label="num inference steps")
299
- seed = gr.Slider(minimum=-1000000,maximum=1000000,value=1, step=1, label="Seed Value")
300
- randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
301
-
302
- #generate_button = gr.Button("Generate Image")
303
 
304
- with gr.Column():
305
- generated_image = gr.Gallery(label="Generated Image")
306
- generate_button = gr.Button("Generate Image")
307
- gr.Markdown(description)
 
 
 
308
 
309
  generate_button.click(
310
  fn=randomize_seed_fn,
@@ -327,7 +349,12 @@ with block:
327
  target,
328
  neg_content_prompt,
329
  neg_content_scale],
330
- outputs=[generated_image])
 
 
 
 
 
331
 
332
  gr.Examples(
333
  examples=get_example(),
 
3
 
4
 
5
  import os
6
+ import gc
7
  import cv2
8
  import torch
9
  import random
 
162
  seed,
163
  target="Load only style blocks",
164
  neg_content_prompt=None,
165
+ neg_content_scale=0,
166
+ progress=gr.Progress(track_tqdm=True)
167
+ ):
168
 
169
  if target =="Load original IP-Adapter":
170
  # target_blocks=["blocks"] for original IP-Adapter
 
214
  image=canny_map,
215
  controlnet_conditioning_scale=float(control_scale),
216
  )
217
+
218
+ gradio_temp_dir = os.environ['GRADIO_TEMP_DIR']
219
+ temp_file_path = os.path.join(gradio_temp_dir, "image.png")
220
+ images[0].save(temp_file_path, format="PNG")
221
+ print(f"Image saved in: {temp_file_path}")
222
+
223
+ return images, temp_file_path
224
 
225
  def pil_to_cv2(image_pil):
226
  image_np = np.array(image_pil)
227
  image_cv2 = cv2.cvtColor(image_np, cv2.COLOR_RGB2BGR)
228
  return image_cv2
229
 
230
+ def clear_cache(device="cuda"):
231
+ gc.collect()
232
+ if device == 'mps':
233
+ torch.mps.empty_cache()
234
+ elif device == 'cuda':
235
+ torch.cuda.empty_cache()
236
+ print(f"{device} cache cleared!")
237
+
238
  # Description
239
  title = r"""
240
  <h1 align="center">InstantStyle: Free Lunch towards Style-Preserving in Text-to-Image Generation</h1>
 
275
  If you have any questions, please feel free to open an issue or directly reach us out at <b>haofanwang.ai@gmail.com</b>.
276
  """
277
 
278
+ css = """
279
+ footer { visibility: hidden }
280
+ #row-height { height: 65px !important }
281
+ """
282
+
283
+ block = gr.Blocks(css=css).queue(max_size=10, api_open=False)
284
  with block:
285
 
286
  # description
287
  gr.Markdown(title)
288
  #gr.Markdown(description)
289
 
290
+ with gr.Row(equal_height=True):
291
+ with gr.Column():
292
+
293
+ with gr.Row():
294
+ with gr.Column():
295
+ image_pil = gr.Image(label="Style Image", type='pil')
296
+
297
+ target = gr.Radio(["Load only style blocks", "Load style+layout block", "Load original IP-Adapter"],
298
+ value="Load only style blocks",
299
+ label="Style mode")
300
+
301
+ prompt = gr.Textbox(label="Prompt",
302
+ value="a cat, masterpiece, best quality, high quality")
303
+
304
+ scale = gr.Slider(minimum=0,maximum=2.0, step=0.01,value=1.0, label="Scale")
305
+
306
+ with gr.Accordion(open=False, label="Advanced Options"):
307
 
308
+ with gr.Column():
309
+ src_image_pil = gr.Image(label="Source Image (optional)", type='pil')
310
+ control_scale = gr.Slider(minimum=0,maximum=1.0, step=0.01,value=0.5, label="Controlnet conditioning scale")
311
 
312
+ n_prompt = gr.Textbox(label="Neg Prompt", value="text, watermark, lowres, low quality, worst quality, deformed, glitch, low contrast, noisy, saturation, blurry")
 
 
313
 
314
+ neg_content_prompt = gr.Textbox(label="Neg Content Prompt", value="")
315
+ neg_content_scale = gr.Slider(minimum=0, maximum=1.0, step=0.01,value=0.5, label="Neg Content Scale")
316
+
317
+ guidance_scale = gr.Slider(minimum=1,maximum=15.0, step=0.01,value=5.0, label="guidance scale")
318
+ num_samples= gr.Slider(minimum=1,maximum=4.0, step=1.0,value=1.0, label="num samples")
319
+ num_inference_steps = gr.Slider(minimum=5,maximum=50.0, step=1.0,value=20, label="num inference steps")
320
+ seed = gr.Slider(minimum=-1000000,maximum=1000000,value=1, step=1, label="Seed Value")
321
+ randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
322
 
323
+ #generate_button = gr.Button("Generate Image")
324
+
325
+ with gr.Column():
326
+ generated_image = gr.Gallery(label="Generated Image", scale=0.3)
327
+ download_image = gr.File(label="Download Image", elem_id="row-height", scale=0)
328
+ generate_button = gr.Button("Generate Image", min_width=2000, scale=0)
329
+ gr.Markdown(description)
330
 
331
  generate_button.click(
332
  fn=randomize_seed_fn,
 
349
  target,
350
  neg_content_prompt,
351
  neg_content_scale],
352
+ outputs=[generated_image, download_image]
353
+ ).then(
354
+ fn=clear_cache,
355
+ inputs=[],
356
+ outputs=None
357
+ )
358
 
359
  gr.Examples(
360
  examples=get_example(),