Spaces:
Sleeping
Sleeping
File size: 7,815 Bytes
907070c 2ad3800 907070c 7504db8 907070c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 |
import gradio as gr
from diffusers import ControlNetModel, EulerAncestralDiscreteScheduler
import torch
import numpy as np
from PIL import Image, ImageFilter
from extension import CustomStableDiffusionControlNetPipeline
negative_prompt = ""
device = torch.device('cuda')
controlnet = ControlNetModel.from_pretrained("BlockDetail/PartialSketchControlNet", torch_dtype=torch.float16).to(device)
pipe = CustomStableDiffusionControlNetPipeline.from_pretrained(
"runwayml/stable-diffusion-v1-5",
controlnet=controlnet, torch_dtype=torch.float16
).to(device)
pipe.safety_checker = None
pipe.scheduler = EulerAncestralDiscreteScheduler.from_config(pipe.scheduler.config)
threshold = 250
curr_num_samples = 2
all_gens = []
num_images = 5
with gr.Blocks() as demo:
start_state = []
with gr.Row():
with gr.Column():
with gr.Row():
stroke_type = gr.Radio(["Blocking", "Detail"], value="Detail", label="Stroke Type"),
dilation_strength = gr.Slider(7, 117, value=65, step=2, label="Dilation Strength"),
canvas = gr.Image(source="canvas", shape=(512, 512), tool="color-sketch",
min_width=512, brush_radius = 2).style(width=512, height=512)
prompt_box = gr.Textbox(width="50vw", label="Prompt")
with gr.Row():
btn = gr.Button("Generate").style(width=100, height=80)
btn2 = gr.Button("Reset").style(width=100, height=80)
with gr.Column():
num_samples = gr.Slider(1, 5, value=2, step=1, label="Num Samples to Generate"),
with gr.Tab("Renoised Images"):
gallery0 = gr.Gallery(show_label=False, columns=[num_samples[0].value], rows=[2], object_fit="contain", height="auto", preview=True, interactive=False).style(width=512, height=512)
with gr.Tab("Renoised Overlay"):
gallery1 = gr.Gallery(show_label=False, columns=[num_samples[0].value], rows=[2], object_fit="contain", height="auto", preview=True, interactive=False).style(width=512, height=512)
with gr.Tab("Pre-Renoise Images"):
gallery2 = gr.Gallery(show_label=False, columns=[num_samples[0].value], rows=[2], object_fit="contain", height="auto", preview=True, interactive=False).style(width=512, height=512)
with gr.Tab("Pre-Renoise Overlay"):
gallery3 = gr.Gallery(show_label=False, columns=[num_samples[0].value], rows=[2], object_fit="contain", height="auto", preview=True, interactive=False).style(width=512, height=512)
for k in range(num_images):
start_state.append([None, None])
sketch_states = gr.State(start_state)
checkbox_state = gr.State(True)
def sketch(curr_sketch_image, dilation_mask, prompt, seed, num_steps, dilation):
global curr_num_samples
generator = torch.Generator(device="cuda:0")
generator.manual_seed(seed)
negative_prompt = ""
guidance_scale = 7
controlnet_conditioning_scale = 1.0
images = pipe([prompt]*curr_num_samples, [curr_sketch_image.convert("RGB").point( lambda p: 256 if p > 128 else 0)]*curr_num_samples, guidance_scale=guidance_scale, controlnet_conditioning_scale = controlnet_conditioning_scale, negative_prompt = [negative_prompt] * curr_num_samples, num_inference_steps=num_steps, generator=generator, key_image=None, neg_mask=None).images
# run blended renoising if blocking strokes are provided
if dilation_mask is not None:
new_images = pipe.collage([prompt] * curr_num_samples, images, [dilation_mask] * curr_num_samples, num_inference_steps=50, strength=0.8)["images"]
else:
new_images = images
return images, new_images
def run_sketching(prompt, curr_sketch, sketch_states, dilation, contour_dilation=11):
seed = sketch_states[k][1]
if seed is None:
seed = np.random.randint(1000)
sketch_states[k][1] = seed
curr_sketch_image = Image.fromarray(curr_sketch[:, :, 0]).resize((512, 512))
curr_construction_image = Image.fromarray(255 - curr_sketch[:, :, 2] + curr_sketch[:, :, 0])
if np.sum(255 - np.array(curr_construction_image)) == 0:
curr_construction_image = None
curr_detail_image = Image.fromarray(curr_sketch[:, :, 2]).resize((512, 512))
if curr_construction_image is not None:
dilation_mask = Image.fromarray(255 - np.array(curr_construction_image)).filter(ImageFilter.MaxFilter(dilation))
dilation_mask = dilation_mask.point( lambda p: 256 if p > 0 else 25).filter(ImageFilter.GaussianBlur(radius = 5))
neg_dilation_mask = Image.fromarray(255 - np.array(curr_detail_image)).filter(ImageFilter.MaxFilter(contour_dilation))
neg_dilation_mask = np.array(neg_dilation_mask.point( lambda p: 256 if p > 0 else 0))
dilation_mask = np.array(dilation_mask)
dilation_mask[neg_dilation_mask > 0] = 25
dilation_mask = Image.fromarray(dilation_mask).filter(ImageFilter.GaussianBlur(radius = 5))
else:
dilation_mask = None
images, new_images = sketch(curr_sketch_image, dilation_mask, prompt, seed, num_steps = 40, dilation = dilation)
save_sketch = np.array(Image.fromarray(curr_sketch).convert("RGBA"))
save_sketch[:, :, 3][save_sketch[:, :, 0] > 128] = 0
overlays = []
for i in images:
background = i.copy()
background.putalpha(80)
background = Image.alpha_composite(Image.fromarray(255 * np.ones((512, 512)).astype(np.uint8)).convert("RGBA"), background)
overlay = Image.alpha_composite(background.resize((512, 512)), Image.fromarray(save_sketch).convert("RGBA"))
overlays.append(overlay.convert("RGB"))
new_overlays = []
for i in new_images:
background = i.copy()
background.putalpha(80)
background = Image.alpha_composite(Image.fromarray(255 * np.ones((512, 512)).astype(np.uint8)).convert("RGBA"), background)
overlay = Image.alpha_composite(background.resize((512, 512)), Image.fromarray(save_sketch).convert("RGBA"))
new_overlays.append(overlay.convert("RGB"))
global all_gens
all_gens = new_images
return new_images, new_overlays, images, overlays
def reset(sketch_states):
for k in range(len(sketch_states)):
sketch_states[k] = [None, None]
return None, sketch_states
def change_color(stroke_type):
if stroke_type == "Blocking":
color = "#0000FF"
else:
color = "#000000"
return gr.Image(source="canvas", shape=(512, 512), tool="color-sketch",
min_width=512, brush_radius = 2, brush_color=color).style(width=400, height=400)
def change_background(option):
global all_gens
if option == "None" or len(all_gens) == 0:
return None
elif option == "Sample 0":
image_overlay = all_gens[0].copy()
elif option == "Sample 1":
image_overlay = all_gens[0].copy()
else:
return None
image_overlay.putalpha(80)
return image_overlay
def change_num_samples(change):
global curr_num_samples
curr_num_samples = change
return None
btn.click(run_sketching, [prompt_box, canvas, sketch_states, dilation_strength[0]], [gallery0, gallery1, gallery2, gallery3])
btn2.click(reset, sketch_states, [canvas, sketch_states])
stroke_type[0].change(change_color, [stroke_type[0]], canvas)
num_samples[0].change(change_num_samples, [num_samples[0]], None)
demo.launch(share = True, debug = True) |