"""Gradio demo: Chinese Stable Diffusion guided by Canny edge maps via ControlNet.

An uploaded image is resized, passed through a Canny edge detector, and the
resulting edge map conditions a ``StableDiffusionControlNetPipeline`` built on
the Taiyi 1B Chinese Stable Diffusion model.

An earlier revision of this script drove the base pipeline with ControlLoRA
attention processors (``ControlLoRA`` / ``ControlLoRACrossAttnProcessor``);
that code was already disabled and has been removed in favour of ControlNet.
"""
import os
import random
import sys

import gradio as gr
import numpy as np
import torch
from PIL import Image
from diffusers.models.unet_2d_condition import UNet2DConditionModel
from diffusers.pipelines import DiffusionPipeline
from diffusers.schedulers import DPMSolverMultistepScheduler
from diffusers import (
    AutoencoderKL,
    ControlNetModel,
    DDPMScheduler,
    StableDiffusionControlNetPipeline,
    UNet2DConditionModel,
    UniPCMultistepScheduler,
)
from diffusers.utils import load_image

from annotator.util import resize_image, HWC3
from annotator.canny import CannyDetector

apply_canny = CannyDetector()
device = "cuda" if torch.cuda.is_available() else "cpu"

controlnet_model_name_or_path = "svjack/ControlNet-Canny-Zh"
controlnet = ControlNetModel.from_pretrained(controlnet_model_name_or_path)

base_model_path = "IDEA-CCNL/Taiyi-Stable-Diffusion-1B-Chinese-v0.1"
pipe = StableDiffusionControlNetPipeline.from_pretrained(
    base_model_path,
    controlnet=controlnet,
    # torch_dtype=torch.float16,
)

# Speed up the diffusion process with a faster scheduler (UniPC), reusing the
# pipeline's existing scheduler configuration.
pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)
# pipe.enable_model_cpu_offload()
if device == "cuda":
    pipe = pipe.to("cuda")


def process(input_image, prompt, a_prompt, n_prompt, num_samples,
            image_resolution, sample_steps, scale, seed, eta,
            low_threshold, high_threshold):
    """Generate images conditioned on the Canny edge map of ``input_image``.

    Args:
        input_image: Uploaded image as a numpy array (``gr.Image(type="numpy")``),
            or ``None`` when nothing has been uploaded.
        prompt: User prompt.
        a_prompt: Extra prompt text appended to ``prompt`` with ``", "``.
        n_prompt: Negative prompt.
        num_samples: Number of images to generate.
        image_resolution: Target resolution passed to ``resize_image``.
        sample_steps: Number of denoising steps.
        scale: Classifier-free guidance scale.
        seed: RNG seed; ``-1`` picks a random seed in ``[0, 65535]``.
        eta: Forwarded unchanged to the pipeline's ``eta`` argument.
        low_threshold: Low threshold for the Canny detector.
        high_threshold: High threshold for the Canny detector.

    Returns:
        A list whose first element is the edge map inverted (``255 - map``),
        followed by ``num_samples`` generated images as numpy arrays.

    Raises:
        gr.Error: If no input image was provided.
    """
    if input_image is None:
        raise gr.Error("Please upload an input image.")

    # Slider values may arrive as floats; range(), manual_seed() and the
    # scheduler's step count all require ints.
    num_samples = int(num_samples)
    image_resolution = int(image_resolution)
    sample_steps = int(sample_steps)
    seed = int(seed)

    with torch.no_grad():
        img = resize_image(HWC3(input_image), image_resolution)
        H, W, _ = img.shape
        detected_map = HWC3(apply_canny(img, low_threshold, high_threshold))
        control_image = Image.fromarray(detected_map)

        if seed == -1:
            seed = random.randint(0, 65535)
        # A single generator is shared by all samples, so each sample draws
        # fresh noise while the whole batch stays reproducible from ``seed``.
        generator = torch.Generator(device=device).manual_seed(seed)

        # Skip the separator when the added prompt (or the prompt) is empty.
        full_prompt = ", ".join(part for part in (prompt, a_prompt) if part)

        images = []
        for _ in range(num_samples):
            image = pipe(
                full_prompt,
                negative_prompt=n_prompt,
                num_inference_steps=sample_steps,
                guidance_scale=scale,
                eta=eta,
                image=control_image,
                generator=generator,
                height=H,
                width=W,
            ).images[0]
            images.append(np.asarray(image))

    # The inverted edge map is shown in the gallery ahead of the results.
    return [255 - detected_map] + images


block = gr.Blocks().queue()
with block:
    with gr.Row():
        gr.Markdown("## Control Stable Diffusion with Canny Edge Maps")
        # This example was derived from
        # https://github.com/svjack/ControlLoRA-Chinese
    with gr.Row():
        with gr.Column():
            input_image = gr.Image(source='upload', type="numpy", value="love_in_rose.png")
            # Default prompt: "a handsome man on the beach".
            prompt = gr.Textbox(label="Prompt", value="沙滩上的俊俏美男子")
            # The first positional argument of gr.Button is its displayed text.
            run_button = gr.Button("Run")
            with gr.Accordion("Advanced options", open=False):
                num_samples = gr.Slider(label="Images", minimum=1, maximum=12, value=1, step=1)
                image_resolution = gr.Slider(label="Image Resolution", minimum=256, maximum=768, value=512, step=256)
                low_threshold = gr.Slider(label="Canny low threshold", minimum=1, maximum=255, value=100, step=1)
                high_threshold = gr.Slider(label="Canny high threshold", minimum=1, maximum=255, value=200, step=1)
                sample_steps = gr.Slider(label="Steps", minimum=1, maximum=100, value=20, step=1)
                scale = gr.Slider(label="Guidance Scale", minimum=0.1, maximum=30.0, value=9.0, step=0.1)
                seed = gr.Slider(label="Seed", minimum=-1, maximum=2147483647, step=1, randomize=True)
                eta = gr.Number(label="eta", value=0.0)
                # Default added prompt: style/quality keywords (mixed-media
                # collage, contemporary art, masterpiece, realistic face, ...).
                a_prompt = gr.Textbox(label="Added Prompt", value='详细的模拟混合媒体拼贴画,帆布质地的当代艺术风格,朋克艺术,逼真主义,感性的身体,表现主义,极简主义。杰作,完美的组成,逼真的美丽的脸')
                # Default negative prompt: "low quality, blurry, chaotic".
                n_prompt = gr.Textbox(label="Negative Prompt", value='低质量,模糊,混乱')
        with gr.Column():
            result_gallery = gr.Gallery(label='Output', show_label=False, elem_id="gallery").style(grid=2, height='auto')
    ips = [input_image, prompt, a_prompt, n_prompt, num_samples, image_resolution,
           sample_steps, scale, seed, eta, low_threshold, high_threshold]
    run_button.click(fn=process, inputs=ips, outputs=[result_gallery], show_progress=True)

block.launch(server_name='0.0.0.0')
# block.launch(server_name='172.16.202.228', share=True)