import gradio as gr
from diffusers import ControlNetModel, StableDiffusionXLControlNetPipeline, AutoencoderKL
from diffusers.utils import load_image
from transformers import DPTImageProcessor, DPTForDepthEstimation
import torch
import sa_handler
import pipeline_calls



# init models

depth_estimator = DPTForDepthEstimation.from_pretrained("Intel/dpt-hybrid-midas").to("cuda")
feature_processor = DPTImageProcessor.from_pretrained("Intel/dpt-hybrid-midas")

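# ControlNet (depth) for SDXL, plus the fp16-fix VAE, which avoids NaN issues when decoding in float16.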
controlnet = ControlNetModel.from_pretrained(
    "diffusers/controlnet-depth-sdxl-1.0",
    variant="fp16",
    use_safetensors=True,
    torch_dtype=torch.float16,
).to("cuda")
vae = AutoencoderKL.from_pretrained("madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16).to("cuda")
pipeline = StableDiffusionXLControlNetPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    controlnet=controlnet,
    vae=vae,
    variant="fp16",
    use_safetensors=True,
    torch_dtype=torch.float16,
).to("cuda")
pipeline.enable_model_cpu_offload()

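# StyleAligned: share self-attention across the batch (with AdaIN on queries and keys) so that
# every generated image follows the style of the first, reference image.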
sa_args = sa_handler.StyleAlignedArgs(share_group_norm=False,
                                      share_layer_norm=False,
                                      share_attention=True,
                                      adain_queries=True,
                                      adain_keys=True,
                                      adain_values=False,
                                     )
handler = sa_handler.Handler(pipeline)
handler.register(sa_args, )


# get depth maps
def get_depth_maps(image):
    image = load_image(image)
    depth_image = pipeline_calls.get_depth_map(image, feature_processor, depth_estimator)
    return depth_image



# run ControlNet depth with StyleAligned
def style_aligned_controlnet(reference_prompt, target_prompt, image):
    controlnet_conditioning_scale = 0.8
    num_images_per_prompt = 1  # adjust according to VRAM size
    depth_map = get_depth_maps(image)
    # latents[0] drives the reference image; latents[1:] hold fresh noise for the target image(s).
    # 128x128 latents correspond to 1024x1024 SDXL outputs.
    latents = torch.randn(1 + num_images_per_prompt, 4, 128, 128).to(pipeline.unet.dtype)
    latents[1:] = torch.randn(num_images_per_prompt, 4, 128, 128).to(pipeline.unet.dtype)
    images = pipeline_calls.controlnet_call(pipeline, [reference_prompt, target_prompt],
                                            image=depth_map,
                                            num_inference_steps=50,
                                            controlnet_conditioning_scale=controlnet_conditioning_scale,
                                            num_images_per_prompt=num_images_per_prompt,
                                            latents=latents)
    # images[0] is the reference render; the remaining entries are the style-aligned, depth-conditioned results.
    return images[1]




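# Gradio UI: reference prompt and depth-map source image on the left, target prompt and the
# style-aligned result on the right.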
with gr.Blocks() as demo:
    with gr.Row(variant='panel'):
        with gr.Group():
            gr.Markdown("### <center>Reference Prompt and Image</center>")
            ref_prompt = gr.Textbox(label="Enter a Prompt describing the reference image",
                                    placeholder='a photo of <object> in <style name> style')
            depth_map = gr.Image(label="Upload the image to get Depth Map")
        with gr.Group():
            gr.Markdown("### <center>Prompt for generation and generated Image</center>")
            prompt = gr.Textbox(label="Enter a Prompt",
                                placeholder='a photo of <object> in <style name> style')
            output = gr.Image(label="Style-Aligned ControlNet", type='pil')
    btn = gr.Button("Generate", size='sm')

    btn.click(fn=style_aligned_controlnet, inputs=[ref_prompt, prompt, depth_map],
              outputs=output, api_name="style_aligned_controlnet")


demo.launch()