# Hugging Face Spaces status banner (scraped page header): "Running on Zero".
import gradio as gr | |
from diffusers import ControlNetModel, EulerAncestralDiscreteScheduler | |
import torch | |
import numpy as np | |
from PIL import Image, ImageFilter | |
from extension import CustomStableDiffusionControlNetPipeline | |
import spaces | |
# ---------------------------------------------------------------------------
# Model setup: the PartialSketchControlNet weights plugged into the
# runwayml/stable-diffusion-v1-5 pipeline, both loaded in float16 and moved
# to the CUDA device once, at import time.
# ---------------------------------------------------------------------------
# NOTE(review): not read by any visible callback (sketch() uses its own
# empty negative prompt).
negative_prompt = ""
device = torch.device('cuda')

controlnet = ControlNetModel.from_pretrained(
    "BlockDetail/PartialSketchControlNet",
    torch_dtype=torch.float16,
).to(device)

pipe = CustomStableDiffusionControlNetPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5",
    controlnet=controlnet,
    torch_dtype=torch.float16,
).to(device)

# Drop the pipeline's safety checker and switch to the Euler-ancestral
# sampler, reusing the scheduler configuration loaded with the base model.
pipe.safety_checker = None
pipe.scheduler = EulerAncestralDiscreteScheduler.from_config(pipe.scheduler.config)

# Bookkeeping shared by the UI callbacks.
threshold = 250        # NOTE(review): not referenced anywhere in the visible code
curr_num_samples = 2   # samples per generation; overwritten by change_num_samples()
all_gens = []          # latest renoised images, written by run_sketching()
num_images = 5         # number of [None, None] slots in the per-session state
with gr.Blocks() as demo:
    with gr.Row():
        # Left column: description, stroke instructions, canvas, prompt, buttons.
        with gr.Column():
            gr.Textbox(label=None, value="We introduce a novel sketch-to-image tool that aligns with the iterative refinement process of artists. Our tool lets users sketch blocking strokes to coarsely represent the placement and form of objects and detail strokes to refine their shape and silhouettes.")
            with gr.Row():
                gr.Textbox(
                    label="Stroke Type",
                    value="To sketch Blocking strokes, change brush color to green. To sketch Detail strokes, change brush color to black.",
                )
                # Deliberately a 1-tuple: the event wiring reads dilation_strength[0].
                dilation_strength = (
                    gr.Slider(7, 117, value=65, step=2, label="Dilation Strength"),
                )
            # Two fixed brush colours: green (#00FF00) for blocking strokes,
            # black (#000000) for detail strokes.
            brush = gr.Brush(colors=["#00FF00", "#000000"], default_size=2, color_mode="fixed")
            canvas = gr.Sketchpad(
                image_mode="RGBA",
                crop_size="1:1",
                label="Canvas",
                sources=(),
                brush=brush,
            )
            prompt_box = gr.Textbox(label="Prompt")
            with gr.Row():
                btn = gr.Button("Generate")
                btn2 = gr.Button("Reset")
        # Right column: sample-count slider plus one gallery per result tab.
        with gr.Column():
            # Deliberately a 1-tuple: callbacks and wiring read num_samples[0].
            num_samples = (
                gr.Slider(1, 5, value=2, step=1, label="Num Samples to Generate"),
            )
            result_galleries = []
            for tab_title in (
                "Renoised Images",
                "Renoised Overlay",
                "Pre-Renoise Images",
                "Pre-Renoise Overlay",
            ):
                with gr.Tab(tab_title):
                    result_galleries.append(
                        gr.Gallery(
                            show_label=False,
                            columns=[num_samples[0].value],
                            rows=[2],
                            object_fit="contain",
                            height=512,
                            preview=True,
                            interactive=False,
                            min_width=512,
                        )
                    )
            gallery0, gallery1, gallery2, gallery3 = result_galleries
    # Per-session state: num_images slots, each starting as [None, None].
    start_state = []
    # `k` stays bound at module scope after this loop (value num_images - 1);
    # callbacks in this file have indexed sketch_states with it, so keep the name.
    for k in range(num_images):
        start_state.append([None, None])
    sketch_states = gr.State(start_state)
    # NOTE(review): not connected to any event in the visible code.
    checkbox_state = gr.State(True)
def sketch(curr_sketch_image, dilation_mask, prompt, seed, num_steps, dilation,
           *, num_samples=None, negative_prompt="", guidance_scale=7,
           controlnet_conditioning_scale=1.0, renoise_steps=50, renoise_strength=0.8):
    """Run the sketch-conditioned pipeline, then optionally blend-renoise.

    Args:
        curr_sketch_image: PIL image of the sketch. After conversion to RGB,
            values above 128 become 255 and the rest 0 before it is used as
            the ControlNet conditioning image.
        dilation_mask: mask passed to ``pipe.collage`` for blended renoising,
            or None to skip that second pass.
        prompt: text prompt, repeated once per sample.
        seed: integer seed for the CUDA torch generator.
        num_steps: inference steps for the ControlNet pass.
        dilation: not used here; kept so existing callers keep working.
        num_samples: batch size; defaults to the module-level curr_num_samples.
        negative_prompt, guidance_scale, controlnet_conditioning_scale:
            forwarded to ``pipe``.
        renoise_steps, renoise_strength: ``num_inference_steps`` / ``strength``
            forwarded to ``pipe.collage``.

    Returns:
        (images, new_images): the ControlNet outputs and the renoised outputs.
        new_images is the same list as images when dilation_mask is None.
    """
    # NOTE(review): `spaces` is imported at the top of the file but nothing is
    # decorated with @spaces.GPU; on ZeroGPU hardware GPU work is normally
    # wrapped that way - confirm whether this function needs it.
    if num_samples is None:
        num_samples = curr_num_samples
    generator = torch.Generator(device="cuda:0")
    # int() accepts the numpy integer produced by np.random.randint as well.
    generator.manual_seed(int(seed))

    # Binarise the sketch. The original lambda returned 256, which Pillow
    # clamps to 255 for 8-bit bands, so 255 is the same value spelled honestly.
    control_image = curr_sketch_image.convert("RGB").point(lambda p: 255 if p > 128 else 0)
    images = pipe(
        [prompt] * num_samples,
        [control_image] * num_samples,
        guidance_scale=guidance_scale,
        controlnet_conditioning_scale=controlnet_conditioning_scale,
        negative_prompt=[negative_prompt] * num_samples,
        num_inference_steps=num_steps,
        generator=generator,
        key_image=None,
        neg_mask=None,
    ).images

    # Blended renoising only runs when blocking strokes produced a mask.
    if dilation_mask is None:
        return images, images
    new_images = pipe.collage(
        [prompt] * num_samples,
        images,
        [dilation_mask] * num_samples,
        num_inference_steps=renoise_steps,
        strength=renoise_strength,
    )["images"]
    return images, new_images
def _split_canvas(composite):
    """Normalise the canvas composite and separate the two stroke types.

    Args:
        composite: image array from the Sketchpad value (assumed RGBA uint8,
            matching the canvas' image_mode="RGBA" - confirm if that changes).

    Returns:
        (canvas, sketch_image, construction_image, detail_image):
        canvas is the 512x512 array with fully transparent pixels painted
        white; sketch_image is its red channel (both stroke colours are 0
        there); construction_image is 0 where green blocking strokes were
        drawn, or None when nothing differs from 255; detail_image is the
        green channel (0 where black detail strokes were drawn).
    """
    canvas = np.array(Image.fromarray(composite).resize((512, 512), resample=0))
    # Paint fully transparent pixels white in all three colour channels.
    canvas[:, :, :3][canvas[:, :, -1] == 0] = 255
    red = canvas[:, :, 0]
    green = canvas[:, :, 1]
    sketch_image = Image.fromarray(red)
    # 255 - G + R is 0 on pure green and 255 on black or white. Compute it in
    # a wider dtype and clip, because in uint8 any pixel with R > G wraps.
    construction = np.minimum(255 - green.astype(np.int16) + red, 255).astype(np.uint8)
    construction_image = None if np.all(construction == 255) else Image.fromarray(construction)
    detail_image = Image.fromarray(green)
    return canvas, sketch_image, construction_image, detail_image


def _blocking_mask(construction_image, detail_image, dilation, contour_dilation):
    """Build the soft mask that drives blended renoising around blocking strokes.

    Blocking strokes are grown with a ``dilation``-sized MaxFilter. The mask
    is set to 255 there and 25 elsewhere, then blurred (radius 5). Pixels
    within a ``contour_dilation``-sized MaxFilter of a detail stroke are
    forced back to 25, and the result is blurred again (radius 5).
    """
    blocking = Image.fromarray(255 - np.array(construction_image))
    grown = blocking.filter(ImageFilter.MaxFilter(dilation))
    mask = grown.point(lambda p: 255 if p > 0 else 25).filter(ImageFilter.GaussianBlur(radius=5))
    detail = Image.fromarray(255 - np.array(detail_image))
    near_detail = np.array(
        detail.filter(ImageFilter.MaxFilter(contour_dilation)).point(lambda p: 255 if p > 0 else 0)
    )
    mask = np.array(mask)
    mask[near_detail > 0] = 25
    return Image.fromarray(mask).filter(ImageFilter.GaussianBlur(radius=5))


def _overlay_with_sketch(images, sketch_rgba):
    """Draw the sketch strokes over faded copies of ``images``.

    Each image copy gets alpha 80, is composited onto white and resized to
    512x512. The sketch (transparent where not drawn) is then composited on
    top. Returns RGB images and leaves the inputs unmodified.
    """
    strokes = Image.fromarray(sketch_rgba).resize((512, 512)).convert("RGBA")
    overlays = []
    for image in images:
        faded = image.copy()
        faded.putalpha(80)
        white = Image.new("RGBA", faded.size, (255, 255, 255, 255))
        background = Image.alpha_composite(white, faded).resize((512, 512))
        overlays.append(Image.alpha_composite(background, strokes).convert("RGB"))
    return overlays


def run_sketching(prompt, curr_sketch, sketch_states, dilation, contour_dilation=11):
    """Generate images from the current canvas and build the gallery outputs.

    Args:
        prompt: text prompt.
        curr_sketch: Sketchpad value; its "composite" entry holds the drawing.
        sketch_states: per-session list of [None, None] slots. The second
            entry of the last slot holds the seed, which is drawn once and
            reused until reset() clears it.
        dilation: odd MaxFilter window used to grow blocking strokes.
        contour_dilation: odd MaxFilter window around detail strokes, inside
            which the blending mask is lowered back to 25.

    Returns:
        (new_images, new_overlays, images, overlays): renoised images, their
        sketch overlays, pre-renoise images, and their sketch overlays.

    Raises:
        gr.Error: if there is no canvas composite to generate from.

    Side effect: stores new_images in the module-level ``all_gens``.
    """
    global all_gens
    if curr_sketch is None or curr_sketch["composite"] is None:
        raise gr.Error("Please draw a sketch before generating.")

    # Always the last slot; this used to go through a loop variable leaked
    # from the UI setup, which pointed at the same slot.
    slot = sketch_states[-1]
    if slot[1] is None:
        slot[1] = np.random.randint(1000)
    seed = slot[1]

    # MaxFilter needs an int window size; slider values may arrive as floats.
    dilation = int(dilation)
    canvas, sketch_image, construction_image, detail_image = _split_canvas(curr_sketch["composite"])
    dilation_mask = (
        None
        if construction_image is None
        else _blocking_mask(construction_image, detail_image, dilation, contour_dilation)
    )
    images, new_images = sketch(sketch_image, dilation_mask, prompt, seed, num_steps=40, dilation=dilation)

    # Keep only the dark strokes opaque so they can be laid over the results.
    save_sketch = np.array(Image.fromarray(canvas).convert("RGBA"))
    save_sketch[:, :, 3][save_sketch[:, :, 0] > 128] = 0
    overlays = _overlay_with_sketch(images, save_sketch)
    new_overlays = _overlay_with_sketch(new_images, save_sketch)

    all_gens = new_images
    return new_images, new_overlays, images, overlays
def reset(sketch_states):
    """Clear every per-session slot and blank the canvas.

    Each slot of ``sketch_states`` is replaced, in place, by a fresh
    [None, None] list.

    Returns:
        (None, sketch_states): the new canvas value and the same, now
        cleared, state list.
    """
    sketch_states[:] = [[None, None] for _ in sketch_states]
    return None, sketch_states
# def change_color(stroke_type): | |
# if stroke_type == "Blocking": | |
# color = "#00FF00" | |
# else: | |
# color = "#000000" | |
# return gr.Sketchpad(sources = (), width=512, brush = gr.Brush(colors=[color], default_size = 2, color_mode="fixed"), height=512) | |
def change_background(option, gens=None):
    """Return a faded copy of one generated sample to use as a background.

    Args:
        option: "None", or "Sample <i>" selecting the i-th entry of ``gens``.
        gens: images to choose from; defaults to the module-level
            ``all_gens`` (the latest renoised images saved by run_sketching).

    Returns:
        A copy of the selected image with its alpha set to 80. Returns None
        when ``option`` is "None"/None, is not of the "Sample <i>" form, or
        ``i`` is out of range. The stored image itself is never modified.
    """
    if option is None or option == "None":
        return None
    if gens is None:
        gens = all_gens
    label, _, index_text = option.partition(" ")
    if label != "Sample" or not index_text.isdecimal():
        return None
    index = int(index_text)
    # Previously "Sample 1" returned gens[0]; index the requested sample and
    # bounds-check it instead of raising IndexError.
    if index >= len(gens):
        return None
    image_overlay = gens[index].copy()
    image_overlay.putalpha(80)
    return image_overlay
def change_num_samples(change):
    """Remember the sample count chosen on the slider for later generations.

    Gradio documents Slider values as floats, but sketch() uses this count as
    a list-repetition factor (``[prompt] * curr_num_samples``), which requires
    an int. The value is therefore coerced here.

    NOTE(review): curr_num_samples is a module-level global, so the value is
    shared by every session using the app.

    Returns:
        None (the event has no outputs).
    """
    global curr_num_samples
    curr_num_samples = int(change)
    return None
# Event listeners have to be registered inside a Blocks context; `with demo:`
# re-opens the one created for the layout.
with demo:
    btn.click(
        run_sketching,
        inputs=[prompt_box, canvas, sketch_states, dilation_strength[0]],
        outputs=[gallery0, gallery1, gallery2, gallery3],
    )
    btn2.click(reset, inputs=sketch_states, outputs=[canvas, sketch_states])
    # stroke_type[0].change(change_color, [stroke_type[0]], canvas)
    num_samples[0].change(change_num_samples, inputs=[num_samples[0]], outputs=None)

demo.launch(share=True, debug=True)