File size: 5,882 Bytes
ce3552b
 
 
 
3f6a115
94ca5b6
 
 
527bf99
 
ce3552b
 
c590a10
ce3552b
94ca5b6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ce3552b
94ca5b6
ce3552b
94ca5b6
 
527bf99
 
 
fc8546c
ce3552b
 
 
 
 
94ca5b6
ce3552b
 
d42756a
ce3552b
 
 
527bf99
 
 
 
 
 
 
 
 
ce3552b
 
 
 
 
 
 
 
527bf99
ce3552b
527bf99
ce3552b
527bf99
ce3552b
 
 
79488ea
1467efe
79488ea
da9710b
fce9e32
527bf99
94ca5b6
 
 
 
527bf99
 
 
 
 
fce9e32
 
 
 
 
 
 
 
 
 
251a915
fce9e32
da9710b
fce9e32
2104e5b
 
 
527bf99
 
94ca5b6
a8d52ff
94ca5b6
527bf99
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
import gradio as gr
import numpy as np
import cv2
from PIL import Image
import torch
import base64
import requests 
from io import BytesIO
from region_control import MultiDiffusion, get_views, preprocess_mask
from sketch_helper import get_high_freq_colors, color_quantization, create_binary_matrix
# Number of color rows (HTML swatch + prompt textbox) created in the UI, and
# the number of visibility/color updates process_sketch returns.
MAX_COLORS = 12

# Model wrapper, instantiated once at import time and used by
# process_generation. NOTE(review): the arguments look like the device and the
# Stable Diffusion version -- confirm against region_control.MultiDiffusion.
sd = MultiDiffusion("cuda", "2.0")

# Placeholder element rendered through gr.HTML; the JS snippets below look it
# up by its "canvas-root" id.
canvas_html = "<div id='canvas-root'></div>"

# JS run on page load (passed to demo.load): fetches the sketch-canvas script
# text, wraps it in a Blob and appends it to <head> as a module script.
load_js = """
async () => {
const url = "https://huggingface.co/datasets/radames/gradio-components/raw/main/sketch-canvas.js"
fetch(url)
  .then(res => res.text())
  .then(text => {
    const script = document.createElement('script');
    script.type = "module"
    script.src = URL.createObjectURL(new Blob([text], { type: 'application/javascript' }));
    document.head.appendChild(script);
  });
}
"""

# JS snippet that reads the `_data` property of the #canvas-root element and
# returns it wrapped in a one-element array.
# NOTE(review): `_data` is presumably populated by the fetched canvas script
# with the 'image' and 'colors' entries process_sketch reads -- confirm.
get_js_colors = """
async (canvasData) => {
  const canvasEl = document.getElementById("canvas-root");
  return [canvasEl._data]
}
"""

# JS snippet that calls `_updateCanvas(width, height)` with 512x512, 768x512
# or 512x768 for the aspect values 'square', 'horizontal' and 'vertical'.
# `_updateCanvas` is not defined in this file; presumably it is provided by
# the fetched canvas script -- confirm.
set_canvas_size ="""
async (aspect) => {
  if(aspect ==='square'){
    _updateCanvas(512,512)
  }
  if(aspect ==='horizontal'){
    _updateCanvas(768,512)
  }
  if(aspect ==='vertical'){
    _updateCanvas(512,768)
  }
}
"""

def process_sketch(canvas_data, binary_matrixes):
  """Convert the sketch payload into one mask per drawn color plus UI updates.

  Args:
    canvas_data: Mapping with an 'image' entry (a data URL; everything after
      the first comma is base64-encoded image bytes) and a 'colors' entry (an
      iterable of 'rrggbb' hex strings, optionally prefixed with '#').
    binary_matrixes: List kept in the session state. It is emptied and then
      refilled in place with one `create_binary_matrix` result per non-white
      color, so repeated calls do not accumulate masks from earlier sketches.

  Returns:
    A flat list made of: an update that shows the post-sketch column, the
    refilled `binary_matrixes` list, MAX_COLORS row-visibility updates and
    MAX_COLORS HTML updates (swatches for the detected colors, black
    placeholders for the remaining hidden rows).
  """
  # Reset the state: without this every click on the button would stack the
  # new masks on top of the previous ones and desynchronise masks and prompts.
  binary_matrixes.clear()

  base64_img = canvas_data['image']
  image_data = base64.b64decode(base64_img.split(',')[1])
  image = Image.open(BytesIO(image_data))
  im2arr = np.array(image)

  colors_fixed = []
  for hex_color in canvas_data['colors']:
    # The UI only owns MAX_COLORS rows/prompts; extra colors could never be
    # given a prompt, so they must not produce a mask either.
    if len(colors_fixed) == MAX_COLORS:
      break
    digits = hex_color.lstrip('#')
    r, g, b = (int(digits[i:i+2], 16) for i in (0, 2, 4))
    # Pure white is skipped (treated as "nothing drawn").
    if (r, g, b) == (255, 255, 255):
      continue
    binary_matrix = create_binary_matrix(im2arr, (r,g,b))
    binary_matrixes.append(binary_matrix)
    # NOTE(review): the value is interpolated into a "file/..." URL, so
    # create_binary_matrix is presumably expected to return a path -- confirm.
    colors_fixed.append(gr.update(value=f'<div style="display:flex;align-items: center;justify-content: center"><img width="20%" style="margin-right: 1em" src="file/{binary_matrix}" /><div class="color-bg-item" style="background-color: rgb({r},{g},{b})"></div></div>'))

  # Default every row to hidden with a black placeholder swatch ...
  visibilities = [gr.update(visible=False) for _ in range(MAX_COLORS)]
  colors = [
    gr.update(value='<div class="color-bg-item" style="background-color: black"></div>')
    for _ in range(MAX_COLORS)
  ]
  # ... then reveal exactly one row per detected color. Iterating over
  # colors_fixed (not over the MAX_COLORS-long placeholder list) avoids the
  # IndexError raised when fewer colors than rows were drawn.
  for n, color_update in enumerate(colors_fixed):
    visibilities[n] = gr.update(visible=True)
    colors[n] = color_update
  return [gr.update(visible=True), binary_matrixes, *visibilities, *colors]

def process_generation(binary_matrixes, master_prompt, *prompts):
    """Generate an image from the stored region masks and their prompts.

    Args:
        binary_matrixes: Sequence of masks collected from the sketch; each
            item is handed to `preprocess_mask`.
        master_prompt: Prompt paired with the background mask (first entry).
        *prompts: Per-region prompts; only the first `len(binary_matrixes)`
            of them are used.

    Returns:
        Whatever `sd.generate` returns for the assembled masks and prompts.
    """
    region_count = len(binary_matrixes)
    all_prompts = [master_prompt, *prompts[:region_count]]
    negative_prompts = ["" for _ in all_prompts]

    # Each mask is preprocessed with both size arguments set to 512 // 8.
    mask_side = 512 // 8
    region_masks = []
    for mask_path in binary_matrixes:
        region_masks.append(preprocess_mask(mask_path, mask_side, mask_side, "cuda"))
    fg_masks = torch.cat(region_masks)

    # Background = everything not covered by a region; overlapping regions
    # would push the value below zero, so negatives are floored at 0.
    bg_mask = 1 - torch.sum(fg_masks, dim=0, keepdim=True)
    bg_mask[bg_mask < 0] = 0

    masks = torch.cat([bg_mask, fg_masks])
    print(masks.size())
    return sd.generate(masks, all_prompts, negative_prompts, 512, 512, 50, bootstrapping=20)

# Stylesheet passed to gr.Blocks(css=...). The selectors target the
# "color-bg" boxes around each swatch, the "color-bg-item" swatch divs emitted
# in the HTML snippets, and the "main_button" sketch-finished button.
css = '''
#color-bg{display:flex;justify-content: center;align-items: center;}
.color-bg-item{width: 100%; height: 32px}
#main_button{width:100%}
'''
def update_css(aspect):
  """Return three visibility updates selecting the component for `aspect`.

  The updates are ordered Square, Horizontal, Vertical; only the entry that
  matches `aspect` is visible. Any other value yields None.
  """
  layouts = ('Square', 'Horizontal', 'Vertical')
  if aspect not in layouts:
    return None
  return [gr.update(visible=(aspect == layout)) for layout in layouts]

# Build the UI: a sketch canvas with a "finished" button, a hidden column of
# per-color prompt rows that process_sketch reveals, and the result image.
with gr.Blocks(css=css) as demo:
  # Per-session list of masks: filled by process_sketch, read by
  # process_generation.
  binary_matrixes = gr.State([])
  gr.Markdown('''## Control your Stable Diffusion generation with Sketches
  This Space demonstrates MultiDiffusion region-based generation using Stable Diffusion model. To get started, draw your masks and type your prompts. More details in the [project page](https://multidiffusion.github.io).
  ''')
  with gr.Row():
    with gr.Box(elem_id="main-image"):
      #with gr.Row():
      # Hidden carrier for the sketch payload; its value is supplied by the
      # get_js_colors hook attached to button_run below.
      canvas_data = gr.JSON(value={}, visible=False)
      canvas = gr.HTML(canvas_html)
      #image = gr.Image(interactive=True, tool="color-sketch", source="canvas", type="pil", shape=(512,512), brush_radius=45)
      
      #image_horizontal = gr.Image(interactive=True, tool="color-sketch", source="canvas", type="pil", shape=(768,512), visible=False, brush_radius=45)
      #image_vertical = gr.Image(interactive=True, tool="color-sketch", source="canvas", type="pil", shape=(512, 768), visible=False, brush_radius=45)
      #with gr.Row():
      #    aspect = gr.Radio(["Square", "Horizontal", "Vertical"], value="Square", label="Aspect Ratio")
      button_run = gr.Button("I've finished my sketch",elem_id="main_button", interactive=True)
      
      prompts = []
      colors = []
      color_row = [None] * MAX_COLORS
      # Hidden until a sketch has been processed; holds MAX_COLORS rows, each
      # with a color swatch and the prompt textbox for that mask.
      with gr.Column(visible=False) as post_sketch:
        general_prompt = gr.Textbox(label="General Prompt")
        for n in range(MAX_COLORS):
          with gr.Row(visible=False) as color_row[n]:
            with gr.Box(elem_id="color-bg"):
              colors.append(gr.HTML('<div class="color-bg-item" style="background-color: black"></div>'))
            prompts.append(gr.Textbox(label="Prompt for this mask"))
        final_run_btn = gr.Button("Generate!")
    
    out_image = gr.Image(label="Result")
  gr.Markdown('''
  ![Examples](https://multidiffusion.github.io/pics/tight.jpg)
  ''')
  #css_height = gr.HTML("<style>#main-image{width: 512px} .fixed-height{height: 512px !important}</style>")
  #aspect.change(update_css, inputs=aspect, outputs=[image, image_horizontal, image_vertical])
  # The `_js` hook runs in the browser first and replaces the (otherwise
  # always empty) canvas_data value with the canvas element's data; without it
  # process_sketch receives {} and fails on canvas_data['image'].
  button_run.click(process_sketch, inputs=[canvas_data, binary_matrixes], outputs=[post_sketch, binary_matrixes, *color_row, *colors], _js=get_js_colors)
  final_run_btn.click(process_generation, inputs=[binary_matrixes, general_prompt, *prompts], outputs=out_image)
  # Inject the sketch-canvas script when the page loads.
  demo.load(None, None, None, _js=load_js)
demo.launch(debug=True)