File size: 4,831 Bytes
ce3552b
 
 
 
3f6a115
527bf99
 
ce3552b
 
c590a10
ce3552b
 
7120442
ce3552b
 
7120442
ce3552b
 
527bf99
 
 
 
 
fc8546c
ce3552b
 
 
 
 
 
 
 
d42756a
ce3552b
 
 
527bf99
 
 
 
 
 
 
 
 
ce3552b
 
 
 
 
 
 
 
527bf99
ce3552b
527bf99
ce3552b
527bf99
ce3552b
 
 
79488ea
1467efe
79488ea
da9710b
fce9e32
527bf99
 
 
 
 
 
 
fce9e32
 
 
 
 
 
 
 
 
 
251a915
fce9e32
da9710b
fce9e32
2104e5b
 
 
527bf99
 
d42756a
a8d52ff
527bf99
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
import gradio as gr
import numpy as np
import cv2
from PIL import Image
import torch
from region_control import MultiDiffusion, get_views, preprocess_mask
from sketch_helper import get_high_freq_colors, color_quantization, create_binary_matrix
# Maximum number of coloured regions (one prompt row each) the UI supports.
MAX_COLORS = 12

# Shared MultiDiffusion pipeline, constructed once at import time.
# NOTE(review): arguments look like (device, Stable Diffusion version) — confirm against region_control.
sd = MultiDiffusion("cuda", "2.0")

def process_sketch(image, binary_matrixes):
  """Turn a colour sketch into per-colour masks and reveal their prompt rows.

  Args:
    image: the canvas image (the UI canvas uses ``type="pil"``).
    binary_matrixes: the current value of the session ``gr.State``. It is
      deliberately NOT extended: a fresh list is returned on every call, so
      pressing the button again after editing the sketch replaces the old
      masks instead of appending to them (appending would misalign masks
      with their prompt boxes in ``process_generation``).

  Returns:
    Gradio outputs in the order used by ``button_run.click``: an update that
    shows the post-sketch column, the new list of masks, ``MAX_COLORS``
    row-visibility updates, then ``MAX_COLORS`` swatch-HTML updates.
  """
  high_freq_colors, image = get_high_freq_colors(image)
  pixels = np.array(image)  # height x width x channel
  pixels = color_quantization(pixels, high_freq_colors)

  masks = []
  swatches = []
  for color in high_freq_colors:
    # assumes each entry is (count, (r, g, b)) — TODO confirm against sketch_helper
    r, g, b = color[1]
    if all(c == 255 for c in (r, g, b)):
      continue  # pure white is presumably the blank canvas, not a drawn region
    if len(masks) >= MAX_COLORS:
      break  # the UI only has MAX_COLORS prompt rows; extra colours are ignored
    binary_matrix = create_binary_matrix(pixels, (r, g, b))
    masks.append(binary_matrix)
    swatches.append(gr.update(value=(
        '<div style="display:flex;align-items: center;justify-content: center">'
        f'<img width="20%" style="margin-right: 1em" src="file/{binary_matrix}" />'
        f'<div class="color-bg-item" style="background-color: rgb({r},{g},{b})"></div>'
        '</div>'
    )))

  # Show exactly one row per produced mask. This no longer assumes that exactly
  # one white colour was present, and it cannot index past MAX_COLORS.
  visibilities = [gr.update(visible=n < len(swatches)) for n in range(MAX_COLORS)]
  placeholder = '<div class="color-bg-item" style="background-color: black"></div>'
  colors = swatches + [gr.update(value=placeholder) for _ in range(MAX_COLORS - len(swatches))]
  return [gr.update(visible=True), masks, *visibilities, *colors]

def process_generation(binary_matrixes, master_prompt, *prompts):
    """Run region-controlled generation from the sketch masks and prompts.

    Args:
      binary_matrixes: masks produced by ``process_sketch``, one per region,
        each passed to ``preprocess_mask``.
      master_prompt: prompt for the background region (mask index 0).
      *prompts: one textbox value per UI row. Masks and prompts are paired in
        order, and only as many as both sides provide are used.

    Returns:
      The result of ``sd.generate``, displayed in the output image component.

    Raises:
      gr.Error: if there are no region masks (the sketch was not processed,
        or only white was drawn).
    """
    if not binary_matrixes:
        raise gr.Error("Draw at least one coloured region and press \"I've finished my sketch\" first.")

    # Pair masks with prompts one-to-one so the mask and prompt counts always agree.
    region_count = min(len(binary_matrixes), len(prompts))
    mask_paths = binary_matrixes[:region_count]
    all_prompts = [master_prompt, *prompts[:region_count]]
    neg_prompts = [""] * len(all_prompts)

    latent_size = 512 // 8  # masks are prepared at 1/8 of the 512px output size
    fg_masks = torch.cat([preprocess_mask(path, latent_size, latent_size, "cuda") for path in mask_paths])
    # Background gets whatever the foreground masks don't cover; where masks
    # overlap the sum exceeds 1, so clamp negatives to 0.
    bg_mask = (1 - torch.sum(fg_masks, dim=0, keepdim=True)).clamp(min=0)
    masks = torch.cat([bg_mask, fg_masks])
    return sd.generate(masks, all_prompts, neg_prompts, 512, 512, 50, bootstrapping=20)

# Stylesheet passed to gr.Blocks: centres the swatch box (elem_id "color-bg"),
# sizes each colour swatch (class "color-bg-item"), and makes the sketch
# submit button full-width (elem_id "main_button").
css = '''
#color-bg{display:flex;justify-content: center;align-items: center;}
.color-bg-item{width: 100%; height: 32px}
#main_button{width:100%}
'''
def update_css(aspect):
  """Return visibility updates for the square, horizontal and vertical canvases.

  Only the canvas matching ``aspect`` is made visible. Its only caller (the
  ``aspect.change`` wiring) is currently commented out.

  Args:
    aspect: one of "Square", "Horizontal", "Vertical".

  Returns:
    Three ``gr.update`` dicts, in the order square, horizontal, vertical.

  Raises:
    ValueError: for any other ``aspect``. The function used to fall through
      and implicitly return None, which cannot fill a three-output event.
  """
  order = ("Square", "Horizontal", "Vertical")
  if aspect not in order:
    raise ValueError(f"Unknown aspect ratio: {aspect!r}")
  return [gr.update(visible=(aspect == name)) for name in order]

# UI layout and event wiring. Components appear on the page in creation order,
# so the statements inside each `with` context are order-sensitive.
with gr.Blocks(css=css) as demo:
  # Per-session list of masks: written by process_sketch, read by process_generation.
  binary_matrixes = gr.State([])
  gr.Markdown('''## Control your Stable Diffusion generation with Sketches
  This Space demonstrates MultiDiffusion region-based generation using Stable Diffusion model. To get started, draw your masks and type your prompts. More details in the [project page](https://multidiffusion.github.io).
  ''')
  with gr.Row():
    with gr.Box(elem_id="main-image"):
      # Colour-sketch canvas; process_sketch turns each non-white colour into one region mask.
      # NOTE: aspect-ratio support (extra canvases + update_css) is disabled; kept commented out below.
      #with gr.Row():
      image = gr.Image(interactive=True, tool="color-sketch", source="canvas", type="pil", shape=(512,512), brush_radius=45)
      #image_horizontal = gr.Image(interactive=True, tool="color-sketch", source="canvas", type="pil", shape=(768,512), visible=False, brush_radius=45)
      #image_vertical = gr.Image(interactive=True, tool="color-sketch", source="canvas", type="pil", shape=(512, 768), visible=False, brush_radius=45)
      #with gr.Row():
      #    aspect = gr.Radio(["Square", "Horizontal", "Vertical"], value="Square", label="Aspect Ratio")
      button_run = gr.Button("I've finished my sketch",elem_id="main_button", interactive=True)
      
      # MAX_COLORS pre-built, initially hidden rows (colour swatch + prompt box).
      # process_sketch reveals one row per detected region and fills in its swatch.
      prompts = []
      colors = []
      color_row = [None] * MAX_COLORS
      # Hidden until process_sketch returns gr.update(visible=True) for it.
      with gr.Column(visible=False) as post_sketch:
        general_prompt = gr.Textbox(label="General Prompt")
        for n in range(MAX_COLORS):
          with gr.Row(visible=False) as color_row[n]:
            with gr.Box(elem_id="color-bg"):
              colors.append(gr.HTML('<div class="color-bg-item" style="background-color: black"></div>'))
            prompts.append(gr.Textbox(label="Prompt for this mask"))
        final_run_btn = gr.Button("Generate!")
    
    out_image = gr.Image(label="Result")
  gr.Markdown('''
  ![Examples](https://multidiffusion.github.io/pics/tight.jpg)
  ''')
  #css_height = gr.HTML("<style>#main-image{width: 512px} .fixed-height{height: 512px !important}</style>")
  #aspect.change(update_css, inputs=aspect, outputs=[image, image_horizontal, image_vertical])
  # outputs must match process_sketch's return list: column, mask state,
  # MAX_COLORS row visibilities, MAX_COLORS swatch HTML updates.
  button_run.click(process_sketch, inputs=[image, binary_matrixes], outputs=[post_sketch, binary_matrixes, *color_row, *colors])
  # All MAX_COLORS prompt boxes are passed positionally; process_generation pairs them with masks in order.
  final_run_btn.click(process_generation, inputs=[binary_matrixes, general_prompt, *prompts], outputs=out_image)
# Launches the app whenever this module is executed or imported.
demo.launch(debug=True)