Spaces:
Runtime error
Runtime error
import gradio as gr | |
import numpy as np | |
import cv2 | |
from PIL import Image | |
import torch | |
import base64 | |
import requests | |
from io import BytesIO | |
from region_control import MultiDiffusion, get_views, preprocess_mask | |
from sketch_helper import get_high_freq_colors, color_quantization, create_binary_matrix | |
# Maximum number of sketch colors (mask regions) supported; the UI creates
# exactly this many hidden (swatch, prompt) rows.
MAX_COLORS = 12
# Region-based Stable Diffusion pipeline, constructed once at import time and
# shared by every generation request.
# NOTE(review): the arguments look like (device, Stable Diffusion version) —
# confirm against region_control.MultiDiffusion.
sd = MultiDiffusion("cuda", "2.0")
# Mount point for the external sketch-canvas script; get_js_colors reads the
# drawing state back from this same element.
canvas_html = "<div id='canvas-root'></div>"

# Run once on page load: download the sketch-canvas script and inject it as an
# ES module via a blob URL (presumably because the raw dataset URL is not served
# with a JavaScript MIME type — confirm). A failed download is logged instead of
# injecting an error page as a script or leaving an unhandled rejection.
load_js = """
async () => {
  const url = "https://huggingface.co/datasets/radames/gradio-components/raw/main/sketch-canvas.js"
  try {
    const res = await fetch(url);
    if (!res.ok) {
      throw new Error(`HTTP ${res.status} while fetching ${url}`);
    }
    const text = await res.text();
    const script = document.createElement('script');
    script.type = "module"
    script.src = URL.createObjectURL(new Blob([text], { type: 'application/javascript' }));
    document.head.appendChild(script);
  } catch (err) {
    console.error("Failed to load sketch canvas:", err);
  }
}
"""

# Client-side hook for the "I've finished my sketch" button: its return value
# replaces the Python handler's inputs with the canvas element's `_data`
# (null if the canvas has not been mounted).
get_js_colors = """
async (canvasData) => {
  const canvasEl = document.getElementById("canvas-root");
  return [canvasEl ? canvasEl._data : null]
}
"""

# Resize the canvas for an aspect name ('square' | 'horizontal' | 'vertical').
# Relies on a global `_updateCanvas(width, height)`, presumably defined by the
# sketch-canvas script — confirm.
set_canvas_size = """
async (aspect) => {
  if(aspect ==='square'){
    _updateCanvas(512,512)
  }
  if(aspect ==='horizontal'){
    _updateCanvas(768,512)
  }
  if(aspect ==='vertical'){
    _updateCanvas(512,768)
  }
}
"""
def process_sketch(canvas_data, binary_matrixes):
    """Split the submitted sketch into one binary mask per drawn color.

    Args:
        canvas_data: the canvas state supplied by the ``get_js_colors`` hook.
            Only its ``'image'`` key (a base64 data URL) and its ``'colors'``
            key (hex strings such as ``"#ff8000"``) are read.
        binary_matrixes: the mask list from a previous submission (the
            ``gr.State`` value). It is deliberately not reused: each
            submission rebuilds the list so repeated clicks cannot accumulate
            stale masks that no longer line up with the prompt rows.

    Returns:
        Updates in the order the click handler's outputs expect:
        - a visibility update for the prompt column;
        - the new mask list;
        - ``MAX_COLORS`` row-visibility updates;
        - ``MAX_COLORS`` swatch HTML updates.

    Raises:
        gr.Error: if no sketch image was received from the canvas.
    """
    if not canvas_data or not canvas_data.get('image'):
        raise gr.Error("No sketch received from the canvas. Please draw something and try again.")

    # Data URL form: "data:<mime>;base64,<payload>".
    image_data = base64.b64decode(canvas_data['image'].split(',', 1)[1])
    # Force three channels so each pixel can be compared with an (r, g, b)
    # tuple even if the canvas PNG carries an alpha channel.
    image = Image.open(BytesIO(image_data)).convert('RGB')
    im2arr = np.array(image)

    new_matrixes = []
    swatches = []
    for hex_color in canvas_data.get('colors', []):
        # Only MAX_COLORS prompt rows exist; further colors could not be prompted.
        if len(new_matrixes) == MAX_COLORS:
            break
        digits = hex_color.lstrip('#')
        r, g, b = (int(digits[i:i + 2], 16) for i in (0, 2, 4))
        # Pure white gets no mask (treated as the unpainted background).
        if (r, g, b) == (255, 255, 255):
            continue
        # Presumably returns a file path: it is used both as an <img> source
        # below and as a mask path in process_generation — confirm in sketch_helper.
        binary_matrix = create_binary_matrix(im2arr, (r, g, b))
        new_matrixes.append(binary_matrix)
        swatches.append(gr.update(value=f'<div style="display:flex;align-items: center;justify-content: center"><img width="20%" style="margin-right: 1em" src="file/{binary_matrix}" /><div class="color-bg-item" style="background-color: rgb({r},{g},{b})"></div></div>'))

    # Show exactly one row per mask; reset unused rows to the black placeholder.
    n_used = len(new_matrixes)
    placeholder = '<div class="color-bg-item" style="background-color: black"></div>'
    visibilities = [gr.update(visible=slot < n_used) for slot in range(MAX_COLORS)]
    swatch_updates = swatches + [gr.update(value=placeholder) for _ in range(MAX_COLORS - n_used)]
    return [gr.update(visible=True), new_matrixes, *visibilities, *swatch_updates]
def process_generation(binary_matrixes, master_prompt, *prompts):
    """Run MultiDiffusion region-based generation for the sketched masks.

    Args:
        binary_matrixes: mask references from ``process_sketch``, passed to
            ``preprocess_mask`` as paths.
        master_prompt: prompt paired with the background mask (the area not
            covered by any drawn mask).
        *prompts: per-mask prompts in the same order as ``binary_matrixes``.
            The UI always sends one per row; entries without a matching mask
            are ignored, and vice versa.

    Returns:
        The result of ``sd.generate``, shown in the output image component.

    Raises:
        gr.Error: if there are no masks to condition on (e.g. "Generate!" was
            clicked before submitting a sketch).
    """
    # Keep masks and prompts one-to-one; a length mismatch would misalign regions.
    n_regions = min(len(binary_matrixes), len(prompts))
    if n_regions == 0:
        raise gr.Error("No masks found. Draw a sketch and click \"I've finished my sketch\" first.")
    mask_paths = binary_matrixes[:n_regions]
    # Index 0 is the background prompt, matching bg_mask at index 0 below.
    region_prompts = [master_prompt, *prompts[:n_regions]]
    neg_prompts = [""] * len(region_prompts)

    # 512 // 8: masks are built at 1/8 of the 512x512 output size
    # (presumably latent resolution — confirm in region_control).
    fg_masks = torch.cat([preprocess_mask(mask_path, 512 // 8, 512 // 8, "cuda") for mask_path in mask_paths])
    # Background is whatever no foreground mask covers; overlaps would go negative.
    bg_mask = 1 - torch.sum(fg_masks, dim=0, keepdim=True)
    bg_mask[bg_mask < 0] = 0
    masks = torch.cat([bg_mask, fg_masks])
    return sd.generate(masks, region_prompts, neg_prompts, 512, 512, 50, bootstrapping=20)
# Page styles passed to gr.Blocks(css=...). Gradio wraps this text in its own
# style element, so it must contain plain CSS only (no HTML tags).
# NOTE(review): `.isPopup.svelte-160vdtq` looks like a Svelte-generated scoped
# class, so it is tied to one Gradio frontend build and may silently stop
# matching after an upgrade.
css = '''
#color-bg{display:flex;justify-content: center;align-items: center;}
.color-bg-item{width: 100%; height: 32px}
#main_button{width:100%}
.isPopup.svelte-160vdtq {
top: -342px !important;
z-index: 10001 !important;
left: -25px !important;
}
'''
def update_css(aspect):
    """Return visibility updates that show only the canvas for ``aspect``.

    Args:
        aspect: "Square", "Horizontal" or "Vertical". Matching is
            case-insensitive, so the lowercase names used by the
            ``set_canvas_size`` JS hook are accepted too.

    Returns:
        Three ``gr.update`` objects in square, horizontal, vertical order;
        only the selected one is visible.

    Raises:
        ValueError: for any other value. Previously the function fell through
            and returned None, which cannot be mapped onto three outputs.
    """
    layouts = ("square", "horizontal", "vertical")
    selected = str(aspect).lower()
    if selected not in layouts:
        raise ValueError(f"Unknown aspect {aspect!r}; expected one of Square, Horizontal, Vertical")
    return [gr.update(visible=(selected == layout)) for layout in layouts]
with gr.Blocks(css=css) as demo:
    # Per-session mask references: written by process_sketch, read by process_generation.
    binary_matrixes = gr.State([])
    gr.Markdown('''## Control your Stable Diffusion generation with Sketches
This Space demonstrates MultiDiffusion region-based generation using Stable Diffusion model. To get started, draw your masks and type your prompts. More details in the [project page](https://multidiffusion.github.io).
''')
    with gr.Row():
        with gr.Box(elem_id="main-image"):
            # Hidden input; get_js_colors replaces its value with the canvas
            # state when button_run is clicked.
            canvas_data = gr.JSON(value={}, visible=False)
            canvas = gr.HTML(canvas_html)
            button_run = gr.Button("I've finished my sketch", elem_id="main_button", interactive=True)

            prompts = []
            colors = []
            color_row = [None] * MAX_COLORS
            # Hidden until process_sketch reveals it.
            with gr.Column(visible=False) as post_sketch:
                general_prompt = gr.Textbox(label="General Prompt")
                # One (swatch, prompt) row per possible sketch color;
                # process_sketch toggles row visibility and fills the swatches.
                for n in range(MAX_COLORS):
                    with gr.Row(visible=False) as color_row[n]:
                        with gr.Box(elem_id="color-bg"):
                            colors.append(gr.HTML('<div class="color-bg-item" style="background-color: black"></div>'))
                        prompts.append(gr.Textbox(label="Prompt for this mask"))
                final_run_btn = gr.Button("Generate!")

        out_image = gr.Image(label="Result")
    gr.Markdown('''
![Examples](https://multidiffusion.github.io/pics/tight.jpg)
''')

    # get_js_colors runs in the browser first; its result becomes canvas_data.
    button_run.click(process_sketch, inputs=[canvas_data, binary_matrixes], outputs=[post_sketch, binary_matrixes, *color_row, *colors], _js=get_js_colors)
    final_run_btn.click(process_generation, inputs=[binary_matrixes, general_prompt, *prompts], outputs=out_image)
    # No Python handler: only injects the sketch-canvas script on page load.
    demo.load(None, None, None, _js=load_js)

if __name__ == "__main__":
    # Route events through Gradio's queue so requests into the single shared
    # `sd` pipeline are not run concurrently on one GPU.
    # NOTE(review): relies on the queue's default concurrency of 1 in Gradio 3.x;
    # confirm for the pinned version.
    demo.queue()
    demo.launch(debug=True)