BlockDetail
fix
984cccb
raw
history blame
No virus
8.24 kB
import gradio as gr
from diffusers import ControlNetModel, EulerAncestralDiscreteScheduler
import torch
import numpy as np
from PIL import Image, ImageFilter
from extension import CustomStableDiffusionControlNetPipeline
import spaces
# Module-level handle for the diffusion pipeline; stays None until load() runs.
pipe = None
# All model weights are placed on the CUDA device.
device = torch.device("cuda")
# Default (empty) negative prompt.
negative_prompt = ""
@spaces.GPU
def load():
    """Build the ControlNet pipeline and publish it as the module-level ``pipe``.

    Loads the "BlockDetail/PartialSketchControlNet" ControlNet and the
    "runwayml/stable-diffusion-v1-5" base weights in float16, moves both to
    ``device``, sets the pipeline's ``safety_checker`` to ``None`` and replaces
    its scheduler with an Euler-ancestral scheduler built from the existing
    scheduler config.
    """
    global pipe
    half_precision = torch.float16

    sketch_controlnet = ControlNetModel.from_pretrained(
        "BlockDetail/PartialSketchControlNet",
        torch_dtype=half_precision,
    ).to(device)

    pipe = CustomStableDiffusionControlNetPipeline.from_pretrained(
        "runwayml/stable-diffusion-v1-5",
        controlnet=sketch_controlnet,
        torch_dtype=half_precision,
    ).to(device)

    pipe.safety_checker = None
    euler_ancestral = EulerAncestralDiscreteScheduler.from_config(pipe.scheduler.config)
    pipe.scheduler = euler_ancestral
# Number of [None, None] slots created for the per-session sketch state.
num_images = 5
# Batch size used by sketch(); updated by change_num_samples().
curr_num_samples = 2
# Most recent renoised outputs; written by run_sketching().
all_gens = []
# Not referenced anywhere in the visible code.
threshold = 250
# ---------------------------------------------------------------------------
# UI layout: sketch canvas, prompt and controls, plus four result galleries.
# NOTE(review): the indentation of this file has been lost, so the nesting of
# the `with` contexts below cannot be read off the text; restore it before
# running.
# ---------------------------------------------------------------------------
with gr.Blocks() as demo:
# Filled by the `for k` loop further down, then wrapped in gr.State.
start_state = []
with gr.Row():
with gr.Column():
with gr.Row():
# Read-only hint explaining which brush colour maps to which stroke type.
gr.Textbox(label="Stroke Type", value="To sketch Blocking strokes, change brush color to green. To sketch Detail strokes, change brush color to black."),
# The trailing comma makes this a 1-tuple; the slider itself is
# dilation_strength[0] (that is how the click handler references it).
dilation_strength = gr.Slider(7, 117, value=65, step=2, label="Dilation Strength"),
# RGBA sketchpad with two fixed brush colours: green (#00FF00) and black (#000000).
canvas = gr.Sketchpad(image_mode="RGBA", crop_size="1:1", label="Sketch", sources=(), brush = gr.Brush(colors=["#00FF00", "#000000"], default_size = 2, color_mode="fixed"))
prompt_box = gr.Textbox(label="Prompt")
with gr.Row():
btn = gr.Button("Generate")
btn2 = gr.Button("Reset")
with gr.Column():
# Also a 1-tuple because of the trailing comma; use num_samples[0].
num_samples = gr.Slider(1, 5, value=2, step=1, label="Num Samples to Generate"),
# Each gallery's column count is taken from the slider's initial value when
# the UI is built.
with gr.Tab("Renoised Images"):
gallery0 = gr.Gallery(show_label=False, columns=[num_samples[0].value], rows=[2], object_fit="contain", height=512, preview=True, interactive=False, min_width=512)
with gr.Tab("Renoised Overlay"):
gallery1 = gr.Gallery(show_label=False, columns=[num_samples[0].value], rows=[2], object_fit="contain", height=512, preview=True, interactive=False, min_width=512)
with gr.Tab("Pre-Renoise Images"):
gallery2 = gr.Gallery(show_label=False, columns=[num_samples[0].value], rows=[2], object_fit="contain", height=512, preview=True, interactive=False, min_width=512)
with gr.Tab("Pre-Renoise Overlay"):
gallery3 = gr.Gallery(show_label=False, columns=[num_samples[0].value], rows=[2], object_fit="contain", height=512, preview=True, interactive=False, min_width=512)
# One [None, None] slot per image. NOTE: `k` stays bound at module level after
# this loop (its last value is num_images - 1) and run_sketching() reads it as
# a global, so do not rename or remove it without updating that function.
for k in range(num_images):
start_state.append([None, None])
sketch_states = gr.State(start_state)
# Not referenced elsewhere in the visible code.
checkbox_state = gr.State(True)
@spaces.GPU
def sketch(curr_sketch_image, dilation_mask, prompt, seed, num_steps, dilation):
    """Run the ControlNet pipeline on a sketch, then optionally renoise the result.

    Args:
        curr_sketch_image: PIL image of the strokes. It is converted to RGB and
            thresholded at 128 before being handed to the pipeline.
        dilation_mask: Mask forwarded to ``pipe.collage``; ``None`` skips the
            second (blended renoising) pass.
        prompt: Text prompt, repeated once per requested sample.
        seed: Seed for the CUDA random generator.
        num_steps: Inference steps for the first pass.
        dilation: Accepted for call compatibility; not used in this function.

    Returns:
        ``(images, new_images)``: the first-pass outputs and the renoised
        outputs. Without a mask both entries are the same object.
    """
    batch_size = curr_num_samples

    rng = torch.Generator(device="cuda:0")
    rng.manual_seed(seed)

    # NOTE(review): 256 lies outside the 8-bit range; presumably Pillow clamps
    # it to 255 when building the lookup table. Confirm.
    control_image = curr_sketch_image.convert("RGB").point(lambda p: 256 if p > 128 else 0)

    first_pass = pipe(
        [prompt] * batch_size,
        [control_image] * batch_size,
        guidance_scale=7,
        controlnet_conditioning_scale=1.0,
        negative_prompt=[""] * batch_size,
        num_inference_steps=num_steps,
        generator=rng,
        key_image=None,
        neg_mask=None,
    ).images

    # Without blocking strokes there is nothing to renoise.
    if dilation_mask is None:
        return first_pass, first_pass

    renoised = pipe.collage(
        [prompt] * batch_size,
        first_pass,
        [dilation_mask] * batch_size,
        num_inference_steps=50,
        strength=0.8,
    )["images"]
    return first_pass, renoised
def _overlay_sketch(images, sketch_layer):
    """Return RGB previews: each image faded over white with the sketch on top."""
    # The white backdrop is never modified by alpha_composite, so build it once.
    white = Image.fromarray(255 * np.ones((512, 512)).astype(np.uint8)).convert("RGBA")
    previews = []
    for image in images:
        faded = image.copy()
        faded.putalpha(80)
        faded = Image.alpha_composite(white, faded)
        combined = Image.alpha_composite(faded.resize((512, 512)), sketch_layer)
        previews.append(combined.convert("RGB"))
    return previews


def run_sketching(prompt, curr_sketch, sketch_states, dilation, contour_dilation=11):
    """Turn the canvas contents into generated images and sketch overlays.

    The canvas composite is resized to 512x512 and split by channel: the red
    channel is dark under every stroke and is used as the control image. The
    blocking-stroke image is ``255 - G + R`` in uint8 arithmetic, which is 255
    wherever red and green agree (black strokes, white background) and 0 for
    pure green. When blocking strokes exist, a dilated and blurred mask is
    built from them, with the dilated detail strokes pushed back down to the
    low value 25.

    Args:
        prompt: Text prompt forwarded to ``sketch``.
        curr_sketch: Editor value from the canvas; ``curr_sketch["composite"]``
            must be an RGBA array (the canvas uses ``image_mode="RGBA"``).
        sketch_states: List of ``[_, seed]`` slots. The seed is kept in the
            last slot and reused until the slots are cleared.
        dilation: Kernel size passed to ``ImageFilter.MaxFilter`` for the
            blocking strokes.
        contour_dilation: Kernel size passed to ``ImageFilter.MaxFilter`` for
            the detail strokes.

    Returns:
        ``(new_images, new_overlays, images, overlays)``: renoised images,
        their sketch overlays, first-pass images and their sketch overlays.

    Raises:
        gr.Error: If the canvas supplies no composite image.
    """
    global all_gens

    # Fail with a readable message rather than an opaque exception from
    # Image.fromarray, in case the editor supplies no image (for example after
    # Reset has set the canvas to None).
    if curr_sketch is None or curr_sketch["composite"] is None:
        raise gr.Error("Please draw a sketch before generating.")

    # The seed used to be looked up through the module-level loop variable `k`
    # left over from building the initial state, i.e. the last slot. Address
    # that slot directly so the function no longer depends on a leaked global.
    seed_slot = sketch_states[-1]
    seed = seed_slot[1]
    if seed is None:
        seed = np.random.randint(1000)
        seed_slot[1] = seed

    sketch_rgba = np.array(
        Image.fromarray(curr_sketch["composite"]).resize((512, 512), resample=0)
    )
    # Fully transparent pixels become white in all three colour channels.
    transparent = sketch_rgba[:, :, -1] == 0
    sketch_rgba[transparent, :3] = 255

    red = sketch_rgba[:, :, 0]
    green = sketch_rgba[:, :, 1]
    curr_sketch_image = Image.fromarray(red).resize((512, 512))

    # uint8 arithmetic, evaluated left to right exactly as (255 - G) + R.
    construction = 255 - green + red
    has_blocking_strokes = not (construction == 255).all()

    if has_blocking_strokes:
        # 255 is the 8-bit maximum (the lookup previously asked for 256).
        blocking = Image.fromarray(255 - construction).filter(ImageFilter.MaxFilter(dilation))
        blocking = blocking.point(lambda p: 255 if p > 0 else 25).filter(
            ImageFilter.GaussianBlur(radius=5)
        )

        detail_image = Image.fromarray(green).resize((512, 512))
        detail = Image.fromarray(255 - np.array(detail_image)).filter(
            ImageFilter.MaxFilter(contour_dilation)
        )
        detail = np.array(detail.point(lambda p: 255 if p > 0 else 0))

        mask = np.array(blocking)
        mask[detail > 0] = 25
        dilation_mask = Image.fromarray(mask).filter(ImageFilter.GaussianBlur(radius=5))
    else:
        dilation_mask = None

    images, new_images = sketch(
        curr_sketch_image, dilation_mask, prompt, seed, num_steps=40, dilation=dilation
    )

    # Sketch layer for the overlays: pixels with a bright red channel are made
    # fully transparent so that only the strokes remain visible.
    save_sketch = np.array(Image.fromarray(sketch_rgba).convert("RGBA"))
    save_sketch[save_sketch[:, :, 0] > 128, 3] = 0
    sketch_layer = Image.fromarray(save_sketch).resize((512, 512)).convert("RGBA")

    # NOTE(review): assumes the pipeline returns 512x512 PIL images, since
    # alpha_composite needs matching sizes. Confirm.
    overlays = _overlay_sketch(images, sketch_layer)
    new_overlays = _overlay_sketch(new_images, sketch_layer)

    all_gens = new_images
    return new_images, new_overlays, images, overlays
def reset(sketch_states):
    """Clear every state slot and blank the first output.

    Args:
        sketch_states: List of two-item slots; it is modified in place.

    Returns:
        ``(None, sketch_states)``: ``None`` for the first output (wired to the
        canvas by the Reset button) and the same list object with every slot
        replaced by a fresh ``[None, None]``.
    """
    # Slice assignment keeps the caller's list object while giving each slot
    # its own new inner list.
    sketch_states[:] = [[None, None] for _ in sketch_states]
    return None, sketch_states
# def change_color(stroke_type):
# if stroke_type == "Blocking":
# color = "#00FF00"
# else:
# color = "#000000"
# return gr.Sketchpad(sources = (), width=512, brush = gr.Brush(colors=[color], default_size = 2, color_mode="fixed"), height=512)
def change_background(option):
    """Return a translucent copy of the generated sample named by ``option``.

    Args:
        option: A label of the form ``"Sample <i>"`` where ``<i>`` indexes the
            module-level ``all_gens`` list. Any other value, including
            ``"None"``, selects nothing.

    Returns:
        A copy of ``all_gens[i]`` with its alpha set to 80, or ``None`` when
        ``option`` does not name an available sample.
    """
    # Match against labels built from the samples that actually exist, so
    # "Sample 1" resolves to index 1 (it previously returned sample 0) and an
    # index beyond the current batch yields None instead of raising IndexError.
    for index, image in enumerate(all_gens):
        if option == f"Sample {index}":
            image_overlay = image.copy()
            image_overlay.putalpha(80)
            return image_overlay
    return None
def change_num_samples(change):
    """Store ``change`` as the module-level sample count read by ``sketch``.

    The value lives in a module global, so it is shared by every caller in
    this process. Returns ``None``.
    """
    global curr_num_samples
    curr_num_samples = change
# Generate: run the sketch pipeline and fill the four galleries in the order
# renoised images, renoised overlays, pre-renoise images, pre-renoise overlays.
btn.click(
    fn=run_sketching,
    inputs=[prompt_box, canvas, sketch_states, dilation_strength[0]],
    outputs=[gallery0, gallery1, gallery2, gallery3],
)
# Reset: blank the canvas and clear the stored state slots.
btn2.click(fn=reset, inputs=sketch_states, outputs=[canvas, sketch_states])
# stroke_type[0].change(change_color, [stroke_type[0]], canvas)
# Keep the module-level sample count in sync with the slider.
num_samples[0].change(fn=change_num_samples, inputs=[num_samples[0]], outputs=None)
# Populate the module-level pipeline before the app starts serving requests.
load()
# Start the Gradio server with a public share link and debug mode enabled.
demo.launch(share=True, debug=True)