[Major] Add list generation
Browse files
app.py
CHANGED
@@ -282,22 +282,138 @@ def generate(
|
|
282 |
|
283 |
return [int(seed), text_cfg_scale, image_cfg_scale, edited_image, mix_image, edited_mask_copy, mask_video_path, image_video_path, input_image_copy, mix_result_with_red_mask]
|
284 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
285 |
def reset():
|
286 |
return [100, "Randomize Seed", 1372, "Fix CFG", 7.5, 1.5, None, None, None, None, None, None, None, "Close Image Video", 10]
|
287 |
|
|
|
288 |
def get_example():
|
289 |
return [
|
290 |
-
["example_images/dufu.png", "black and white suit", 100, "Fix Seed", 1372, "Fix CFG", 7.5, 1.5],
|
291 |
-
["example_images/girl.jpeg", "reflective sunglasses", 100, "Fix Seed", 1372, "Fix CFG", 7.5, 1.5],
|
292 |
-
["example_images/
|
293 |
-
["example_images/
|
294 |
-
["example_images/
|
295 |
-
["example_images/
|
296 |
-
["example_images/
|
297 |
-
["example_images/girl.jpeg", "
|
298 |
-
["example_images/
|
299 |
-
["example_images/girl.jpeg", "
|
300 |
-
["example_images/
|
|
|
|
|
301 |
]
|
302 |
|
303 |
with gr.Blocks(css="footer {visibility: hidden}") as demo:
|
@@ -325,7 +441,14 @@ with gr.Blocks(css="footer {visibility: hidden}") as demo:
|
|
325 |
with gr.Row():
|
326 |
input_image = gr.Image(label="Input Image", type="pil", interactive=True)
|
327 |
with gr.Row():
|
328 |
-
instruction = gr.Textbox(lines=1, label="
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
329 |
with gr.Row():
|
330 |
steps = gr.Number(value=100, precision=0, label="Steps", interactive=True)
|
331 |
randomize_seed = gr.Radio(
|
@@ -347,17 +470,13 @@ with gr.Blocks(css="footer {visibility: hidden}") as demo:
|
|
347 |
)
|
348 |
text_cfg_scale = gr.Number(value=7.5, label=f"Text CFG", interactive=True)
|
349 |
image_cfg_scale = gr.Number(value=1.5, label=f"Image CFG", interactive=True)
|
350 |
-
with gr.Row():
|
351 |
-
reset_button = gr.Button("Reset")
|
352 |
-
generate_button = gr.Button("Generate")
|
353 |
with gr.Column(scale=1, min_width=100):
|
354 |
with gr.Column():
|
355 |
mix_image = gr.Image(label=f"Mix Image", type="pil", interactive=False)
|
356 |
with gr.Column():
|
357 |
edited_mask = gr.Image(label=f"Output Mask", type="pil", interactive=False)
|
358 |
|
359 |
-
|
360 |
-
with gr.Accordion('More outputs', open=False):
|
361 |
with gr.Row():
|
362 |
weather_close_video = gr.Radio(
|
363 |
["Show Image Video", "Close Image Video"],
|
@@ -374,15 +493,11 @@ with gr.Blocks(css="footer {visibility: hidden}") as demo:
|
|
374 |
original_image = gr.Image(label=f"Original Image", type="pil", interactive=False)
|
375 |
edited_image = gr.Image(label=f"Output Image", type="pil", interactive=False)
|
376 |
mix_result_with_red_mask = gr.Image(label=f"Mix Image With Red Mask", type="pil", interactive=False)
|
377 |
-
|
378 |
|
379 |
with gr.Row():
|
380 |
gr.Examples(
|
381 |
examples=get_example(),
|
382 |
-
|
383 |
-
inputs=[input_image, instruction, steps, randomize_seed, seed, randomize_cfg, text_cfg_scale, image_cfg_scale],
|
384 |
-
outputs=[seed, text_cfg_scale, image_cfg_scale, edited_image, mix_image, edited_mask, mask_video, image_video, original_image, mix_result_with_red_mask],
|
385 |
-
cache_examples=False,
|
386 |
)
|
387 |
|
388 |
generate_button.click(
|
@@ -401,6 +516,24 @@ with gr.Blocks(css="footer {visibility: hidden}") as demo:
|
|
401 |
],
|
402 |
outputs=[seed, text_cfg_scale, image_cfg_scale, edited_image, mix_image, edited_mask, mask_video, image_video, original_image, mix_result_with_red_mask],
|
403 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
404 |
reset_button.click(
|
405 |
fn=reset,
|
406 |
inputs=[],
|
|
|
282 |
|
283 |
return [int(seed), text_cfg_scale, image_cfg_scale, edited_image, mix_image, edited_mask_copy, mask_video_path, image_video_path, input_image_copy, mix_result_with_red_mask]
|
284 |
|
285 |
+
@spaces.GPU(duration=30)
|
286 |
+
def generate_list(
|
287 |
+
input_image: Image.Image,
|
288 |
+
generate_list: str,
|
289 |
+
steps: int,
|
290 |
+
randomize_seed: bool,
|
291 |
+
seed: int,
|
292 |
+
randomize_cfg: bool,
|
293 |
+
text_cfg_scale: float,
|
294 |
+
image_cfg_scale: float,
|
295 |
+
weather_close_video: bool,
|
296 |
+
decode_image_batch: int
|
297 |
+
):
|
298 |
+
generate_list = generate_list.split('\n')
|
299 |
+
# Remove the empty element
|
300 |
+
generate_list = [element for element in generate_list if element]
|
301 |
+
|
302 |
+
seed = random.randint(0, 100000) if randomize_seed else seed
|
303 |
+
text_cfg_scale = round(random.uniform(6.0, 9.0), ndigits=2) if randomize_cfg else text_cfg_scale
|
304 |
+
image_cfg_scale = round(random.uniform(1.2, 1.8), ndigits=2) if randomize_cfg else image_cfg_scale
|
305 |
+
|
306 |
+
width, height = input_image.size
|
307 |
+
factor = args.resolution / max(width, height)
|
308 |
+
factor = math.ceil(min(width, height) * factor / 64) * 64 / min(width, height)
|
309 |
+
width = int((width * factor) // 64) * 64
|
310 |
+
height = int((height * factor) // 64) * 64
|
311 |
+
input_image = ImageOps.fit(input_image, (width, height), method=Image.Resampling.LANCZOS)
|
312 |
+
|
313 |
+
if len(generate_list) == 0:
|
314 |
+
return [input_image, seed]
|
315 |
+
|
316 |
+
model.cuda()
|
317 |
+
image_video = [np.array(input_image).astype(np.uint8)]
|
318 |
+
generate_index = 0
|
319 |
+
input_image_copy = input_image.convert("RGB")
|
320 |
+
while generate_index < len(generate_list):
|
321 |
+
print(f'generate_index: {str(generate_index)}')
|
322 |
+
instruction = generate_list[generate_index]
|
323 |
+
with torch.no_grad(), autocast("cuda"), model.ema_scope():
|
324 |
+
cond = {}
|
325 |
+
input_image_torch = 2 * torch.tensor(np.array(input_image_copy.copy())).float() / 255 - 1
|
326 |
+
input_image_torch = rearrange(input_image_torch, "h w c -> 1 c h w").to(model.device)
|
327 |
+
cond["c_crossattn"] = [model.get_learned_conditioning([instruction]).to(model.device)]
|
328 |
+
cond["c_concat"] = [model.encode_first_stage(input_image_torch).mode().to(model.device)]
|
329 |
+
|
330 |
+
uncond = {}
|
331 |
+
uncond["c_crossattn"] = [null_token.to(model.device)]
|
332 |
+
uncond["c_concat"] = [torch.zeros_like(cond["c_concat"][0])]
|
333 |
+
|
334 |
+
sigmas = model_wrap.get_sigmas(steps).to(model.device)
|
335 |
+
|
336 |
+
extra_args = {
|
337 |
+
"cond": cond,
|
338 |
+
"uncond": uncond,
|
339 |
+
"text_cfg_scale": text_cfg_scale,
|
340 |
+
"image_cfg_scale": image_cfg_scale,
|
341 |
+
}
|
342 |
+
torch.manual_seed(seed)
|
343 |
+
z_0 = torch.randn_like(cond["c_concat"][0]).to(model.device) * sigmas[0]
|
344 |
+
z_1 = torch.randn_like(cond["c_concat"][0]).to(model.device) * sigmas[0]
|
345 |
+
|
346 |
+
z_0, z_1, _, _ = sample_euler_ancestral(model_wrap_cfg, z_0, z_1, sigmas, height, width, extra_args=extra_args)
|
347 |
+
|
348 |
+
x_0 = model.decode_first_stage(z_0)
|
349 |
+
|
350 |
+
x_1 = nn.functional.interpolate(z_1, size=(height, width), mode="bilinear", align_corners=False)
|
351 |
+
x_1 = torch.where(x_1 > 0, 1, -1) # Thresholding step
|
352 |
+
|
353 |
+
if torch.sum(x_1).item()/x_1.numel() < -0.99:
|
354 |
+
seed += 1
|
355 |
+
continue
|
356 |
+
else:
|
357 |
+
generate_index += 1
|
358 |
+
|
359 |
+
x_0 = torch.clamp((x_0 + 1.0) / 2.0, min=0.0, max=1.0)
|
360 |
+
x_1 = torch.clamp((x_1 + 1.0) / 2.0, min=0.0, max=1.0)
|
361 |
+
x_0 = 255.0 * rearrange(x_0, "1 c h w -> h w c")
|
362 |
+
x_1 = 255.0 * rearrange(x_1, "1 c h w -> h w c")
|
363 |
+
x_1 = torch.cat([x_1, x_1, x_1], dim=-1)
|
364 |
+
edited_image = Image.fromarray(x_0.type(torch.uint8).cpu().numpy())
|
365 |
+
edited_mask = Image.fromarray(x_1.type(torch.uint8).cpu().numpy())
|
366 |
+
|
367 |
+
# 对edited_mask做膨胀
|
368 |
+
edited_mask_copy = edited_mask.copy()
|
369 |
+
kernel = np.ones((3, 3), np.uint8)
|
370 |
+
edited_mask = cv2.dilate(np.array(edited_mask), kernel, iterations=3)
|
371 |
+
edited_mask = Image.fromarray(edited_mask)
|
372 |
+
|
373 |
+
m_img = edited_mask.filter(ImageFilter.GaussianBlur(radius=3))
|
374 |
+
m_img = np.asarray(m_img).astype('float') / 255.0
|
375 |
+
img_np = np.asarray(input_image_copy).astype('float') / 255.0
|
376 |
+
ours_np = np.asarray(edited_image).astype('float') / 255.0
|
377 |
+
|
378 |
+
mix_image_np = m_img * ours_np + (1 - m_img) * img_np
|
379 |
+
|
380 |
+
image_video.append((mix_image_np * 255).astype(np.uint8))
|
381 |
+
mix_image = Image.fromarray((mix_image_np * 255).astype(np.uint8)).convert('RGB')
|
382 |
+
input_image_copy = mix_image
|
383 |
+
|
384 |
+
mix_result_with_red_mask = None
|
385 |
+
mask_video_path = None
|
386 |
+
edited_mask_copy = None
|
387 |
+
|
388 |
+
image_video_path = "image.mp4"
|
389 |
+
fps = 2
|
390 |
+
with imageio.get_writer(image_video_path, fps=fps) as video:
|
391 |
+
for image in image_video:
|
392 |
+
video.append_data(image)
|
393 |
+
|
394 |
+
|
395 |
+
return [int(seed), text_cfg_scale, image_cfg_scale, edited_image, mix_image, edited_mask_copy, mask_video_path, image_video_path, input_image, mix_result_with_red_mask]
|
396 |
+
|
397 |
+
|
398 |
def reset():
|
399 |
return [100, "Randomize Seed", 1372, "Fix CFG", 7.5, 1.5, None, None, None, None, None, None, None, "Close Image Video", 10]
|
400 |
|
401 |
+
|
402 |
def get_example():
|
403 |
return [
|
404 |
+
["example_images/dufu.png", "", "black and white suit\nsunglasses\nblue medical mask\nyellow schoolbag\nred bow tie\nbrown high-top hat", 100, "Fix Seed", 1372, "Fix CFG", 7.5, 1.5],
|
405 |
+
["example_images/girl.jpeg", "reflective sunglasses", "", 100, "Fix Seed", 1372, "Fix CFG", 7.5, 1.5],
|
406 |
+
["example_images/dufu.png", "black and white suit", "", 100, "Fix Seed", 1372, "Fix CFG", 7.5, 1.5],
|
407 |
+
["example_images/girl.jpeg", "reflective sunglasses", "", 100, "Fix Seed", 1372, "Fix CFG", 7.5, 1.5],
|
408 |
+
["example_images/road_sign.png", "stop sign", "", 100, "Fix Seed", 1372, "Fix CFG", 7.5, 1.5],
|
409 |
+
["example_images/dufu.png", "blue medical mask", "", 100, "Fix Seed", 1372, "Fix CFG", 7.5, 1.5],
|
410 |
+
["example_images/people_standing.png", "dark green pleated skirt", "", 100, "Fix Seed", 1372, "Fix CFG", 7.5, 1.5],
|
411 |
+
["example_images/girl.jpeg", "shiny golden crown", "", 100, "Fix Seed", 1372, "Fix CFG", 7.5, 1.5],
|
412 |
+
["example_images/dufu.png", "sunglasses", "", 100, "Fix Seed", 1372, "Fix CFG", 7.5, 1.5],
|
413 |
+
["example_images/girl.jpeg", "diamond necklace", "", 100, "Fix Seed", 1372, "Fix CFG", 7.5, 1.5],
|
414 |
+
["example_images/iron_man.jpg", "sunglasses", "", 100, "Fix Seed", 1372, "Fix CFG", 7.5, 1.5],
|
415 |
+
["example_images/girl.jpeg", "the queen's crown", "", 100, "Fix Seed", 1372, "Fix CFG", 7.5, 1.5],
|
416 |
+
["example_images/girl.jpeg", "gorgeous yellow gown", "", 100, "Fix Seed", 1372, "Fix CFG", 7.5, 1.5],
|
417 |
]
|
418 |
|
419 |
with gr.Blocks(css="footer {visibility: hidden}") as demo:
|
|
|
441 |
with gr.Row():
|
442 |
input_image = gr.Image(label="Input Image", type="pil", interactive=True)
|
443 |
with gr.Row():
|
444 |
+
instruction = gr.Textbox(lines=1, label="Single object description", interactive=True)
|
445 |
+
with gr.Row():
|
446 |
+
reset_button = gr.Button("Reset")
|
447 |
+
generate_button = gr.Button("Generate")
|
448 |
+
with gr.Row():
|
449 |
+
list_input = gr.Textbox(label="Input List", placeholder="Enter one item per line", lines=10)
|
450 |
+
with gr.Row():
|
451 |
+
list_generate_button = gr.Button("List Generate")
|
452 |
with gr.Row():
|
453 |
steps = gr.Number(value=100, precision=0, label="Steps", interactive=True)
|
454 |
randomize_seed = gr.Radio(
|
|
|
470 |
)
|
471 |
text_cfg_scale = gr.Number(value=7.5, label=f"Text CFG", interactive=True)
|
472 |
image_cfg_scale = gr.Number(value=1.5, label=f"Image CFG", interactive=True)
|
|
|
|
|
|
|
473 |
with gr.Column(scale=1, min_width=100):
|
474 |
with gr.Column():
|
475 |
mix_image = gr.Image(label=f"Mix Image", type="pil", interactive=False)
|
476 |
with gr.Column():
|
477 |
edited_mask = gr.Image(label=f"Output Mask", type="pil", interactive=False)
|
478 |
|
479 |
+
with gr.Accordion('Click to see more (includes generation process per object for list generation and per step for single generation)', open=False):
|
|
|
480 |
with gr.Row():
|
481 |
weather_close_video = gr.Radio(
|
482 |
["Show Image Video", "Close Image Video"],
|
|
|
493 |
original_image = gr.Image(label=f"Original Image", type="pil", interactive=False)
|
494 |
edited_image = gr.Image(label=f"Output Image", type="pil", interactive=False)
|
495 |
mix_result_with_red_mask = gr.Image(label=f"Mix Image With Red Mask", type="pil", interactive=False)
|
|
|
496 |
|
497 |
with gr.Row():
|
498 |
gr.Examples(
|
499 |
examples=get_example(),
|
500 |
+
inputs=[input_image, instruction, list_input, steps, randomize_seed, seed, randomize_cfg, text_cfg_scale, image_cfg_scale],
|
|
|
|
|
|
|
501 |
)
|
502 |
|
503 |
generate_button.click(
|
|
|
516 |
],
|
517 |
outputs=[seed, text_cfg_scale, image_cfg_scale, edited_image, mix_image, edited_mask, mask_video, image_video, original_image, mix_result_with_red_mask],
|
518 |
)
|
519 |
+
|
520 |
+
list_generate_button.click(
|
521 |
+
fn=generate_list,
|
522 |
+
inputs=[
|
523 |
+
input_image,
|
524 |
+
list_input,
|
525 |
+
steps,
|
526 |
+
randomize_seed,
|
527 |
+
seed,
|
528 |
+
randomize_cfg,
|
529 |
+
text_cfg_scale,
|
530 |
+
image_cfg_scale,
|
531 |
+
weather_close_video,
|
532 |
+
decode_image_batch
|
533 |
+
],
|
534 |
+
outputs=[seed, text_cfg_scale, image_cfg_scale, edited_image, mix_image, edited_mask, mask_video, image_video, original_image, mix_result_with_red_mask],
|
535 |
+
)
|
536 |
+
|
537 |
reset_button.click(
|
538 |
fn=reset,
|
539 |
inputs=[],
|