multimodalart's picture
Update app.py
94ca5b6
raw
history blame
No virus
5.88 kB
import gradio as gr
import numpy as np
import cv2
from PIL import Image
import torch
import base64
import requests
from io import BytesIO
from region_control import MultiDiffusion, get_views, preprocess_mask
from sketch_helper import get_high_freq_colors, color_quantization, create_binary_matrix
# Number of per-colour prompt rows created in the UI; also the number of
# visibility/swatch updates process_sketch returns.
MAX_COLORS = 12
# Model wrapper shared by every request; built once at import time.
# NOTE(review): the arguments are presumably the torch device and the Stable
# Diffusion version -- confirm against region_control.MultiDiffusion.
sd = MultiDiffusion("cuda", "2.0")
# Mount point for the sketch canvas; the JS snippets below look it up by id.
canvas_html = "<div id='canvas-root'></div>"
# Browser-side loader: downloads sketch-canvas.js as text and injects it as a
# module <script> through a Blob URL. Passed to demo.load(..., _js=load_js).
load_js = """
async () => {
const url = "https://huggingface.co/datasets/radames/gradio-components/raw/main/sketch-canvas.js"
fetch(url)
.then(res => res.text())
.then(text => {
const script = document.createElement('script');
script.type = "module"
script.src = URL.createObjectURL(new Blob([text], { type: 'application/javascript' }));
document.head.appendChild(script);
});
}
"""
# Browser-side reader: returns the `_data` property of #canvas-root as a
# one-element list.
# NOTE(review): presumably `_data` is populated by sketch-canvas.js with the
# {'image', 'colors'} payload process_sketch expects -- confirm.
get_js_colors = """
async (canvasData) => {
const canvasEl = document.getElementById("canvas-root");
return [canvasEl._data]
}
"""
# Browser-side resize helper: calls _updateCanvas(width, height) for the
# lowercase aspect names 'square', 'horizontal' and 'vertical'.
# NOTE(review): _updateCanvas is not defined in this file; presumably it is a
# global exposed by sketch-canvas.js -- confirm.
set_canvas_size ="""
async (aspect) => {
if(aspect ==='square'){
_updateCanvas(512,512)
}
if(aspect ==='horizontal'){
_updateCanvas(768,512)
}
if(aspect ==='vertical'){
_updateCanvas(512,768)
}
}
"""
def process_sketch(canvas_data, binary_matrixes):
    """Build one mask per drawn colour and refresh the per-colour prompt rows.

    Args:
        canvas_data: Mapping with an ``'image'`` data-URL string (everything
            after the first comma is base64-decoded and opened as an image)
            and a ``'colors'`` list of ``#rrggbb`` hex strings.
        binary_matrixes: Previous value of the masks state. It is accepted
            so the callback signature stays the same, but it is neither
            mutated nor carried over: every call recomputes the masks from
            the current sketch, so repeated clicks cannot pile up stale
            masks that no longer line up with the prompt rows.

    Returns:
        A flat list: an update that reveals the post-sketch column, the new
        list of masks, ``MAX_COLORS`` row-visibility updates and
        ``MAX_COLORS`` swatch HTML updates.

    Raises:
        KeyError: If ``canvas_data`` lacks ``'image'`` or ``'colors'``.
    """
    encoded_image = canvas_data['image'].split(',')[1]
    sketch = np.array(Image.open(BytesIO(base64.b64decode(encoded_image))))

    masks = []
    swatches = []
    for hex_color in canvas_data['colors']:
        digits = hex_color.lstrip('#')
        r, g, b = (int(digits[i:i + 2], 16) for i in (0, 2, 4))
        # Pure white is the untouched canvas, not a region.
        if (r, g, b) == (255, 255, 255):
            continue
        # Only MAX_COLORS prompt rows exist; extra colours cannot be prompted.
        if len(masks) == MAX_COLORS:
            break
        # NOTE(review): the value is interpolated into a "file/..." URL below,
        # so create_binary_matrix presumably returns a file path -- confirm.
        mask = create_binary_matrix(sketch, (r, g, b))
        masks.append(mask)
        swatches.append(gr.update(value=f'<div style="display:flex;align-items: center;justify-content: center"><img width="20%" style="margin-right: 1em" src="file/{mask}" /><div class="color-bg-item" style="background-color: rgb({r},{g},{b})"></div></div>'))

    shown = len(swatches)
    # Reveal exactly one row per detected colour and hide the rest.
    visibilities = [gr.update(visible=slot < shown) for slot in range(MAX_COLORS)]
    # Unused rows fall back to the black placeholder swatch.
    swatch_updates = swatches + [
        gr.update(value='<div class="color-bg-item" style="background-color: black"></div>')
        for _ in range(MAX_COLORS - shown)
    ]
    return [gr.update(visible=True), masks, *visibilities, *swatch_updates]
def process_generation(binary_matrixes, master_prompt, *prompts):
    """Generate an image from the region masks and their prompts.

    Args:
        binary_matrixes: Masks produced by the sketch step; each one is
            handed to ``preprocess_mask``.
        master_prompt: Prompt paired with the background mask.
        *prompts: Per-region prompts; only the first ``len(binary_matrixes)``
            are used.

    Returns:
        Whatever ``sd.generate`` returns for the assembled masks and prompts.
    """
    all_prompts = [master_prompt, *prompts[:len(binary_matrixes)]]
    negative_prompts = [""] * len(all_prompts)

    mask_side = 512 // 8
    region_masks = torch.cat(
        [preprocess_mask(path, mask_side, mask_side, "cuda") for path in binary_matrixes]
    )
    # Background is whatever no region covers; overlapping regions would push
    # the remainder below zero, so it is floored at 0.
    background = torch.clamp(1 - torch.sum(region_masks, dim=0, keepdim=True), min=0)
    masks = torch.cat([background, region_masks])
    print(masks.size())

    return sd.generate(masks, all_prompts, negative_prompts, 512, 512, 50, bootstrapping=20)
# Custom stylesheet passed to gr.Blocks(css=css). The selectors match the
# elem_id values "color-bg" and "main_button" and the "color-bg-item" class
# used by the swatch HTML elsewhere in this file.
css = '''
#color-bg{display:flex;justify-content: center;align-items: center;}
.color-bg-item{width: 100%; height: 32px}
#main_button{width:100%}
'''
def update_css(aspect):
    """Return visibility updates that show only the widget for ``aspect``.

    Args:
        aspect: One of ``'Square'``, ``'Horizontal'`` or ``'Vertical'``.

    Returns:
        Three ``gr.update`` objects ordered square, horizontal, vertical,
        with only the matching one visible. Any other value yields ``None``.
    """
    layouts = ('Square', 'Horizontal', 'Vertical')
    if aspect not in layouts:
        return None
    return [gr.update(visible=(layout == aspect)) for layout in layouts]
# Browser-side hook run before process_sketch. The hidden canvas_data JSON
# component starts as {} and nothing else writes to it, so the sketch payload
# has to be pulled from the canvas element at click time. The second slot
# (the masks state) is passed through untouched so the number of returned
# values matches the number of inputs.
# NOTE(review): assumes sketch-canvas.js stores its payload on #canvas-root
# as `_data`, the same source get_js_colors reads -- confirm.
read_canvas_js = """
async (canvasData, binaryMatrixes) => {
  const canvasEl = document.getElementById("canvas-root");
  return [canvasEl._data, binaryMatrixes]
}
"""

# NOTE(review): the nesting below was reconstructed; the source this was taken
# from had lost its indentation. Confirm the layout against the running app.
with gr.Blocks(css=css) as demo:
    # Per-session list of masks shared between the two callbacks.
    binary_matrixes = gr.State([])
    gr.Markdown('''## Control your Stable Diffusion generation with Sketches
This Space demonstrates MultiDiffusion region-based generation using Stable Diffusion model. To get started, draw your masks and type your prompts. More details in the [project page](https://multidiffusion.github.io).
''')
    with gr.Row():
        with gr.Box(elem_id="main-image"):
            #with gr.Row():
            canvas_data = gr.JSON(value={}, visible=False)
            canvas = gr.HTML(canvas_html)
            #image = gr.Image(interactive=True, tool="color-sketch", source="canvas", type="pil", shape=(512,512), brush_radius=45)
            #image_horizontal = gr.Image(interactive=True, tool="color-sketch", source="canvas", type="pil", shape=(768,512), visible=False, brush_radius=45)
            #image_vertical = gr.Image(interactive=True, tool="color-sketch", source="canvas", type="pil", shape=(512, 768), visible=False, brush_radius=45)
            #with gr.Row():
            #    aspect = gr.Radio(["Square", "Horizontal", "Vertical"], value="Square", label="Aspect Ratio")
            button_run = gr.Button("I've finished my sketch",elem_id="main_button", interactive=True)

        # One hidden (swatch, prompt) row per possible colour; process_sketch
        # reveals as many as it finds.
        prompts = []
        colors = []
        color_row = [None] * MAX_COLORS
        with gr.Column(visible=False) as post_sketch:
            general_prompt = gr.Textbox(label="General Prompt")
            for n in range(MAX_COLORS):
                with gr.Row(visible=False) as color_row[n]:
                    with gr.Box(elem_id="color-bg"):
                        colors.append(gr.HTML('<div class="color-bg-item" style="background-color: black"></div>'))
                    prompts.append(gr.Textbox(label="Prompt for this mask"))
            final_run_btn = gr.Button("Generate!")

        out_image = gr.Image(label="Result")
    gr.Markdown('''
![Examples](https://multidiffusion.github.io/pics/tight.jpg)
''')
    #css_height = gr.HTML("<style>#main-image{width: 512px} .fixed-height{height: 512px !important}</style>")
    #aspect.change(update_css, inputs=aspect, outputs=[image, image_horizontal, image_vertical])

    # Output order must match process_sketch's return value: column, state,
    # MAX_COLORS rows, MAX_COLORS swatches.
    button_run.click(
        process_sketch,
        inputs=[canvas_data, binary_matrixes],
        outputs=[post_sketch, binary_matrixes, *color_row, *colors],
        _js=read_canvas_js,
    )
    final_run_btn.click(process_generation, inputs=[binary_matrixes, general_prompt, *prompts], outputs=out_image)
    # Inject the sketch canvas script once the page has loaded.
    demo.load(None, None, None, _js=load_js)
demo.launch(debug=True)