Spaces:
Sleeping
Sleeping
File size: 7,096 Bytes
e21f690 897452e e21f690 f85d15d e21f690 897452e e21f690 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 |
# Disabled snippet: it is wrapped in a bare string literal, so it never runs.
# If re-enabled, it would replace the `deprecate` helper in several diffusers
# modules with a no-op lambda (presumably to silence deprecation warnings).
'''
from diffusers import utils
from diffusers.utils import deprecation_utils
from diffusers.models import cross_attention
utils.deprecate = lambda *arg, **kwargs: None
deprecation_utils.deprecate = lambda *arg, **kwargs: None
cross_attention.deprecate = lambda *arg, **kwargs: None
'''
import os
import sys
# Disabled snippet: it is wrapped in a bare string literal, so it never runs.
# If re-enabled, it would prepend the parent directory of this file to
# sys.path and make that directory the current working directory.
'''
MAIN_DIR = os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))
sys.path.insert(0, MAIN_DIR)
os.chdir(MAIN_DIR)
'''
import gradio as gr
import numpy as np
import torch
import random
from annotator.util import resize_image, HWC3
from annotator.canny import CannyDetector
from diffusers.models.unet_2d_condition import UNet2DConditionModel
from diffusers.pipelines import DiffusionPipeline
from diffusers.schedulers import DPMSolverMultistepScheduler
#from models import ControlLoRA, ControlLoRACrossAttnProcessor
# A single Canny edge detector instance, created once at import time and
# reused by every request.
apply_canny = CannyDetector()

# Target device for torch work: the GPU when torch can see one, else the CPU.
# Kept as a plain string because it is compared against "cuda" later on.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
# Disabled legacy pipeline: this whole section is a bare string literal and is
# never executed. It described the previous approach:
#   - a DiffusionPipeline with a DPMSolverMultistepScheduler;
#   - UNet attention processors replaced by ControlLoRA layers loaded from
#     "svjack/canny-control-lora-zh".
# NOTE(review): ControlLoRA / ControlLoRACrossAttnProcessor are not imported
# (their `from models import ...` line is commented out), and the app below
# uses `pipe`, not `pipeline`. Re-enabling this as-is would raise NameError.
# The active ControlNet-based pipeline is constructed further below.
'''
pipeline = DiffusionPipeline.from_pretrained(
'IDEA-CCNL/Taiyi-Stable-Diffusion-1B-Chinese-v0.1', safety_checker=None
)
pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config)
pipeline = pipeline.to(device)
unet: UNet2DConditionModel = pipeline.unet
#ckpt_path = "ckpts/sd-diffusiondb-canny-model-control-lora-zh"
ckpt_path = "svjack/canny-control-lora-zh"
control_lora = ControlLoRA.from_pretrained(ckpt_path)
control_lora = control_lora.to(device)
# load control lora attention processors
lora_attn_procs = {}
lora_layers_list = list([list(layer_list) for layer_list in control_lora.lora_layers])
n_ch = len(unet.config.block_out_channels)
control_ids = [i for i in range(n_ch)]
for name in pipeline.unet.attn_processors.keys():
cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim
if name.startswith("mid_block"):
control_id = control_ids[-1]
elif name.startswith("up_blocks"):
block_id = int(name[len("up_blocks.")])
control_id = list(reversed(control_ids))[block_id]
elif name.startswith("down_blocks"):
block_id = int(name[len("down_blocks.")])
control_id = control_ids[block_id]
lora_layers = lora_layers_list[control_id]
if len(lora_layers) != 0:
lora_layer: ControlLoRACrossAttnProcessor = lora_layers.pop(0)
lora_attn_procs[name] = lora_layer
unet.set_attn_processor(lora_attn_procs)
'''
from diffusers import (
AutoencoderKL,
ControlNetModel,
DDPMScheduler,
StableDiffusionControlNetPipeline,
UNet2DConditionModel,
UniPCMultistepScheduler,
)
import torch
from diffusers.utils import load_image
# ControlNet conditioned on Canny edge maps, paired with the Chinese Taiyi
# Stable Diffusion base model.
controlnet_model_name_or_path = "svjack/ControlNet-Canny-Zh"
controlnet = ControlNetModel.from_pretrained(controlnet_model_name_or_path)
base_model_path = "IDEA-CCNL/Taiyi-Stable-Diffusion-1B-Chinese-v0.1"

# The app never uses the safety checker. Passing safety_checker=None at load
# time skips loading its weights, instead of loading them and then discarding
# them by assigning None afterwards.
# No torch_dtype is passed, so weights load in diffusers' default precision
# (torch_dtype=torch.float16 is an option when running on a GPU).
pipe = StableDiffusionControlNetPipeline.from_pretrained(
    base_model_path,
    controlnet=controlnet,
    safety_checker=None,
)

# Swap in the UniPC multistep scheduler, reusing the base scheduler's config.
pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)
# pipe.enable_model_cpu_offload()

# Move the pipeline to the selected device; a no-op when that is the CPU.
pipe = pipe.to(device)
def process(input_image, prompt, a_prompt, n_prompt, num_samples, image_resolution, sample_steps, scale, seed, eta, low_threshold, high_threshold):
    """Generate images conditioned on the Canny edge map of ``input_image``.

    Args:
        input_image: Uploaded image as a numpy array (Gradio ``type="numpy"``).
        prompt: Main text prompt.
        a_prompt: Extra text appended to ``prompt`` after ``", "``.
        n_prompt: Negative prompt.
        num_samples: Number of images to generate, one pipeline call each.
        image_resolution: Target resolution handed to ``resize_image``.
        sample_steps: Number of inference steps per image.
        scale: Classifier-free guidance scale.
        seed: RNG seed; ``-1`` picks a random seed in ``[0, 65535]``.
        eta: Forwarded to the pipeline as ``eta``.
        low_threshold: Canny low threshold.
        high_threshold: Canny high threshold.

    Returns:
        A list whose first element is the inverted edge map
        (``255 - detected_map``), followed by ``num_samples`` generated
        images as numpy arrays.

    Raises:
        ValueError: If no input image was provided.
    """
    from PIL import Image

    if input_image is None:
        raise ValueError("Please upload an input image.")

    # UI widgets may hand numbers over as floats, but range(),
    # torch.Generator.manual_seed() and the scheduler's step count need ints.
    num_samples = int(num_samples)
    sample_steps = int(sample_steps)
    seed = int(seed)

    with torch.no_grad():
        img = resize_image(HWC3(input_image), image_resolution)
        H, W = img.shape[:2]
        detected_map = HWC3(apply_canny(img, low_threshold, high_threshold))

        if seed == -1:
            seed = random.randint(0, 65535)
        control_image = Image.fromarray(detected_map)

        # One generator seeded once: successive samples continue the same
        # random stream, so each image differs while the whole run stays
        # reproducible for a given seed.
        generator = torch.Generator(device=device).manual_seed(seed)

        images = []
        for _ in range(num_samples):
            image = pipe(
                prompt + ', ' + a_prompt,
                negative_prompt=n_prompt,
                num_inference_steps=sample_steps,
                guidance_scale=scale,
                eta=eta,
                image=control_image,
                generator=generator,
                height=H,
                width=W,
            ).images[0]
            images.append(np.asarray(image))

    return [255 - detected_map] + images
# Gradio UI with request queueing enabled. It is built and launched at import
# time, as in the original script.
block = gr.Blocks().queue()
with block:
    with gr.Row():
        gr.Markdown("## Control Stable Diffusion with Canny Edge Maps")
        #gr.Markdown("This _example_ was **drive** from <br/><b><h4>[https://github.com/svjack/ControlLoRA-Chinese](https://github.com/svjack/ControlLoRA-Chinese)</h4></b>\n")
    with gr.Row():
        with gr.Column():
            input_image = gr.Image(source='upload', type="numpy", value="hate_dog.png")
            # Default prompt means "cute puppy".
            prompt = gr.Textbox(label="Prompt", value="可爱的狗宝宝")
            # A Button's visible text is its `value`; `label` is not shown on
            # buttons, so pass the caption as `value`.
            run_button = gr.Button(value="Run")
            with gr.Accordion("Advanced options", open=False):
                num_samples = gr.Slider(label="Images", minimum=1, maximum=12, value=1, step=1)
                image_resolution = gr.Slider(label="Image Resolution", minimum=256, maximum=768, value=512, step=256)
                low_threshold = gr.Slider(label="Canny low threshold", minimum=1, maximum=255, value=100, step=1)
                high_threshold = gr.Slider(label="Canny high threshold", minimum=1, maximum=255, value=200, step=1)
                sample_steps = gr.Slider(label="Steps", minimum=1, maximum=100, value=20, step=1)
                scale = gr.Slider(label="Guidance Scale", minimum=0.1, maximum=30.0, value=9.0, step=0.1)
                # -1 asks `process` to draw a random seed.
                seed = gr.Slider(label="Seed", minimum=-1, maximum=2147483647, step=1, randomize=True)
                eta = gr.Number(label="eta", value=0.0)
                a_prompt = gr.Textbox(label="Added Prompt", value='')
                # Default negative prompt means "low quality, blurry, messy".
                n_prompt = gr.Textbox(label="Negative Prompt",
                                      value='低质量,模糊,混乱')
        with gr.Column():
            result_gallery = gr.Gallery(label='Output', show_label=False, elem_id="gallery").style(grid=2, height='auto')
    # Order must match the positional parameters of `process`.
    ips = [input_image, prompt, a_prompt, n_prompt, num_samples, image_resolution, sample_steps, scale, seed, eta, low_threshold, high_threshold]
    run_button.click(fn=process, inputs=ips, outputs=[result_gallery], show_progress=True)
block.launch(server_name='0.0.0.0')
#### block.launch(server_name='172.16.202.228', share=True)
|