ysharma (HF staff) committed on
Commit 092fcaa
1 Parent(s): b3ca606

Update app.py

Files changed (1):
  1. app.py (+71, -20)
app.py CHANGED
@@ -1,18 +1,36 @@
 import gradio as gr
-from diffusers import StableDiffusionXLPipeline, DDIMScheduler
+from diffusers import ControlNetModel, StableDiffusionXLControlNetPipeline, AutoencoderKL
+from diffusers.utils import load_image
+from transformers import DPTImageProcessor, DPTForDepthEstimation
 import torch
 import mediapy
 import sa_handler
+import pipeline_calls
+
+
 
 # init models
-scheduler = DDIMScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", clip_sample=False,
-                          set_alpha_to_one=False)
-pipeline = StableDiffusionXLPipeline.from_pretrained(
-    "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16, variant="fp16", use_safetensors=True,
-    scheduler=scheduler
+
+depth_estimator = DPTForDepthEstimation.from_pretrained("Intel/dpt-hybrid-midas").to("cuda")
+feature_processor = DPTImageProcessor.from_pretrained("Intel/dpt-hybrid-midas")
+
+controlnet = ControlNetModel.from_pretrained(
+    "diffusers/controlnet-depth-sdxl-1.0",
+    variant="fp16",
+    use_safetensors=True,
+    torch_dtype=torch.float16,
 ).to("cuda")
+vae = AutoencoderKL.from_pretrained("madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16).to("cuda")
+pipeline = StableDiffusionXLControlNetPipeline.from_pretrained(
+    "stabilityai/stable-diffusion-xl-base-1.0",
+    controlnet=controlnet,
+    vae=vae,
+    variant="fp16",
+    use_safetensors=True,
+    torch_dtype=torch.float16,
+).to("cuda")
+pipeline.enable_model_cpu_offload()
 
-handler = sa_handler.Handler(pipeline)
 sa_args = sa_handler.StyleAlignedArgs(share_group_norm=False,
                                       share_layer_norm=False,
                                       share_attention=True,
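Note on the hunk above: the depth conditioning comes from the two Intel/dpt-hybrid-midas components, while the actual image-to-depth-map conversion happens later through pipeline_calls.get_depth_map from the StyleAligned repo. The snippet below is only a rough, self-contained sketch of that step; the function name and exact post-processing are assumptions for illustration, not the repo's implementation.

import numpy as np
import torch
from PIL import Image

def rough_depth_map(image: Image.Image) -> Image.Image:
    # DPT depth estimation with the feature_processor / depth_estimator loaded above
    pixel_values = feature_processor(images=image, return_tensors="pt").pixel_values.to("cuda")
    with torch.no_grad():
        depth = depth_estimator(pixel_values).predicted_depth            # shape (1, H, W)
    # resize to SDXL's 1024x1024 working resolution and normalize to [0, 1]
    depth = torch.nn.functional.interpolate(depth.unsqueeze(1), size=(1024, 1024),
                                            mode="bicubic", align_corners=False)[0, 0]
    depth = (depth - depth.min()) / (depth.max() - depth.min())
    arr = (depth * 255).clamp(0, 255).byte().cpu().numpy()
    return Image.fromarray(np.stack([arr] * 3, axis=-1))                 # 3-channel map for ControlNet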
@@ -20,10 +38,42 @@ sa_args = sa_handler.StyleAlignedArgs(share_group_norm=False,
                                       adain_keys=True,
                                       adain_values=False,
                                      )
-
+handler = sa_handler.Handler(pipeline)
 handler.register(sa_args, )
 
 
+# get depth maps
+def get_depth_maps(image):
+    image = load_image(image) #("./example_image/train.png")
+    depth_image1 = pipeline_calls.get_depth_map(image, feature_processor, depth_estimator)
+    #depth_image2 = load_image("./example_image/sun.png").resize((1024, 1024))
+    #mediapy.show_images([depth_image1, depth_image2])
+    return depth_image1 #[depth_image1, depth_image2]
+
+
+
+# run ControlNet depth with StyleAligned
+def style_aligned_controlnet(reference_prompt, target_prompt, image):
+    #reference_prompt = "a poster in flat design style"
+    #target_prompts = [target_prompts] #["a train in flat design style", "the sun in flat design style"]
+    controlnet_conditioning_scale = 0.8
+    num_images_per_prompt = 1 # adjust according to VRAM size
+    depth_map = get_depth_maps(image)
+    latents = torch.randn(1 + num_images_per_prompt, 4, 128, 128).to(pipeline.unet.dtype)
+    #for deph_map, target_prompt in zip((depth_image1, depth_image2), target_prompts):
+    latents[1:] = torch.randn(num_images_per_prompt, 4, 128, 128).to(pipeline.unet.dtype)
+    images = pipeline_calls.controlnet_call(pipeline, [reference_prompt, target_prompt],
+                                            image=depth_map,
+                                            num_inference_steps=50,
+                                            controlnet_conditioning_scale=controlnet_conditioning_scale,
+                                            num_images_per_prompt=num_images_per_prompt,
+                                            latents=latents)
+    print(f"images -{images}")
+    return images[0]
+
+#mediapy.show_images([images[0], deph_map] + images[1:], titles=["reference", "depth"] + [f'result {i}' for i in range(1, len(images))])
+
+
 # run StyleAligned
 sets_of_prompts = [
   "a toy train. macro photo. 3d game asset",
@@ -33,20 +83,21 @@ sets_of_prompts = [
   "a toy boat. macro photo. 3d game asset",
 ]
 
-def style_aligned_sdxl(prompt):
-    images = pipeline([prompts],).images
-    #mediapy.show_images(images)
-    print(images)
-    return images
 
 with gr.Blocks() as demo:
-    with gr.Group():
-        with gr.Row():
-            prompt = gr.Textbox(label="Prompt", scale=8)
-            btn = gr.Button("Greet", scale=2)
-        output = gr.Image(label="Style-Aligned SDXL")
-
-    btn.click(fn=style_aligned_sdxl, inputs=prompt, outputs=output, api_name="style_aligned_sdxl")
+    with gr.Row(variant='panel'):
+        with gr.Group():
+            gr.Markdown("### <center>Reference Prompt and Image</center>")
+            ref_prompt = gr.Textbox(label="Enter a Prompt describing the reference image", placeholder='a photo of <object> in <style name> style')
+            depth_map = gr.Image(label="Upload the image to get Depth Map", )
+        with gr.Group():
+            gr.Markdown("### <center>Prompt for generation and generated Image</center>")
+            prompt = gr.Textbox(label="Enter a Prompt", placeholder='a photo of <object> in <style name> style')
+            output = gr.Image(label="Style-Aligned ControlNet", type='pil')
+            btn = gr.Button("Generate", size='sm')
+
+    btn.click(fn=style_aligned_controlnet, inputs=[ref_prompt, prompt, depth_map], outputs=output, api_name="style_aligned_controlnet")
+
 
 demo.launch()
 
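Finally, a hedged sketch of calling the new named endpoint once the app is running (for example locally via python app.py, which serves at http://127.0.0.1:7860 by default). The server address and image path are placeholders, the example prompts are taken from the comments in the code above, and it assumes a gradio_client version that accepts a plain filepath for the image input:

from gradio_client import Client

client = Client("http://127.0.0.1:7860")          # placeholder address for the running demo
result = client.predict(
    "a poster in flat design style",              # ref_prompt: describes the reference image
    "a train in flat design style",               # prompt: what to generate in the same style
    "./example_image/train.png",                  # depth_map: image the app turns into a depth map
    api_name="/style_aligned_controlnet",
)
print(result)                                     # path to the generated, style-aligned image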