fffiloni multimodalart HF staff commited on
Commit
2d87298
1 Parent(s): b959ad7

ZeroGPU support (#2)

Browse files

- Support ZeroGPU (aca4f93b09d9e8b8f7406e9c8579c1943d09ebe3)


Co-authored-by: Apolinário from multimodal AI art <multimodalart@users.noreply.huggingface.co>

.gitattributes CHANGED
@@ -33,3 +33,7 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
 
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ doc/cn_example.jpg filter=lfs diff=lfs merge=lfs -text
37
+ doc/md_example.jpg filter=lfs diff=lfs merge=lfs -text
38
+ example_image/camel.png filter=lfs diff=lfs merge=lfs -text
39
+ example_image/train.png filter=lfs diff=lfs merge=lfs -text
LICENSE ADDED
@@ -0,0 +1,202 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ Apache License
3
+ Version 2.0, January 2004
4
+ http://www.apache.org/licenses/
5
+
6
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
7
+
8
+ 1. Definitions.
9
+
10
+ "License" shall mean the terms and conditions for use, reproduction,
11
+ and distribution as defined by Sections 1 through 9 of this document.
12
+
13
+ "Licensor" shall mean the copyright owner or entity authorized by
14
+ the copyright owner that is granting the License.
15
+
16
+ "Legal Entity" shall mean the union of the acting entity and all
17
+ other entities that control, are controlled by, or are under common
18
+ control with that entity. For the purposes of this definition,
19
+ "control" means (i) the power, direct or indirect, to cause the
20
+ direction or management of such entity, whether by contract or
21
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
22
+ outstanding shares, or (iii) beneficial ownership of such entity.
23
+
24
+ "You" (or "Your") shall mean an individual or Legal Entity
25
+ exercising permissions granted by this License.
26
+
27
+ "Source" form shall mean the preferred form for making modifications,
28
+ including but not limited to software source code, documentation
29
+ source, and configuration files.
30
+
31
+ "Object" form shall mean any form resulting from mechanical
32
+ transformation or translation of a Source form, including but
33
+ not limited to compiled object code, generated documentation,
34
+ and conversions to other media types.
35
+
36
+ "Work" shall mean the work of authorship, whether in Source or
37
+ Object form, made available under the License, as indicated by a
38
+ copyright notice that is included in or attached to the work
39
+ (an example is provided in the Appendix below).
40
+
41
+ "Derivative Works" shall mean any work, whether in Source or Object
42
+ form, that is based on (or derived from) the Work and for which the
43
+ editorial revisions, annotations, elaborations, or other modifications
44
+ represent, as a whole, an original work of authorship. For the purposes
45
+ of this License, Derivative Works shall not include works that remain
46
+ separable from, or merely link (or bind by name) to the interfaces of,
47
+ the Work and Derivative Works thereof.
48
+
49
+ "Contribution" shall mean any work of authorship, including
50
+ the original version of the Work and any modifications or additions
51
+ to that Work or Derivative Works thereof, that is intentionally
52
+ submitted to Licensor for inclusion in the Work by the copyright owner
53
+ or by an individual or Legal Entity authorized to submit on behalf of
54
+ the copyright owner. For the purposes of this definition, "submitted"
55
+ means any form of electronic, verbal, or written communication sent
56
+ to the Licensor or its representatives, including but not limited to
57
+ communication on electronic mailing lists, source code control systems,
58
+ and issue tracking systems that are managed by, or on behalf of, the
59
+ Licensor for the purpose of discussing and improving the Work, but
60
+ excluding communication that is conspicuously marked or otherwise
61
+ designated in writing by the copyright owner as "Not a Contribution."
62
+
63
+ "Contributor" shall mean Licensor and any individual or Legal Entity
64
+ on behalf of whom a Contribution has been received by Licensor and
65
+ subsequently incorporated within the Work.
66
+
67
+ 2. Grant of Copyright License. Subject to the terms and conditions of
68
+ this License, each Contributor hereby grants to You a perpetual,
69
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
70
+ copyright license to reproduce, prepare Derivative Works of,
71
+ publicly display, publicly perform, sublicense, and distribute the
72
+ Work and such Derivative Works in Source or Object form.
73
+
74
+ 3. Grant of Patent License. Subject to the terms and conditions of
75
+ this License, each Contributor hereby grants to You a perpetual,
76
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
77
+ (except as stated in this section) patent license to make, have made,
78
+ use, offer to sell, sell, import, and otherwise transfer the Work,
79
+ where such license applies only to those patent claims licensable
80
+ by such Contributor that are necessarily infringed by their
81
+ Contribution(s) alone or by combination of their Contribution(s)
82
+ with the Work to which such Contribution(s) was submitted. If You
83
+ institute patent litigation against any entity (including a
84
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
85
+ or a Contribution incorporated within the Work constitutes direct
86
+ or contributory patent infringement, then any patent licenses
87
+ granted to You under this License for that Work shall terminate
88
+ as of the date such litigation is filed.
89
+
90
+ 4. Redistribution. You may reproduce and distribute copies of the
91
+ Work or Derivative Works thereof in any medium, with or without
92
+ modifications, and in Source or Object form, provided that You
93
+ meet the following conditions:
94
+
95
+ (a) You must give any other recipients of the Work or
96
+ Derivative Works a copy of this License; and
97
+
98
+ (b) You must cause any modified files to carry prominent notices
99
+ stating that You changed the files; and
100
+
101
+ (c) You must retain, in the Source form of any Derivative Works
102
+ that You distribute, all copyright, patent, trademark, and
103
+ attribution notices from the Source form of the Work,
104
+ excluding those notices that do not pertain to any part of
105
+ the Derivative Works; and
106
+
107
+ (d) If the Work includes a "NOTICE" text file as part of its
108
+ distribution, then any Derivative Works that You distribute must
109
+ include a readable copy of the attribution notices contained
110
+ within such NOTICE file, excluding those notices that do not
111
+ pertain to any part of the Derivative Works, in at least one
112
+ of the following places: within a NOTICE text file distributed
113
+ as part of the Derivative Works; within the Source form or
114
+ documentation, if provided along with the Derivative Works; or,
115
+ within a display generated by the Derivative Works, if and
116
+ wherever such third-party notices normally appear. The contents
117
+ of the NOTICE file are for informational purposes only and
118
+ do not modify the License. You may add Your own attribution
119
+ notices within Derivative Works that You distribute, alongside
120
+ or as an addendum to the NOTICE text from the Work, provided
121
+ that such additional attribution notices cannot be construed
122
+ as modifying the License.
123
+
124
+ You may add Your own copyright statement to Your modifications and
125
+ may provide additional or different license terms and conditions
126
+ for use, reproduction, or distribution of Your modifications, or
127
+ for any such Derivative Works as a whole, provided Your use,
128
+ reproduction, and distribution of the Work otherwise complies with
129
+ the conditions stated in this License.
130
+
131
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
132
+ any Contribution intentionally submitted for inclusion in the Work
133
+ by You to the Licensor shall be under the terms and conditions of
134
+ this License, without any additional terms or conditions.
135
+ Notwithstanding the above, nothing herein shall supersede or modify
136
+ the terms of any separate license agreement you may have executed
137
+ with Licensor regarding such Contributions.
138
+
139
+ 6. Trademarks. This License does not grant permission to use the trade
140
+ names, trademarks, service marks, or product names of the Licensor,
141
+ except as required for reasonable and customary use in describing the
142
+ origin of the Work and reproducing the content of the NOTICE file.
143
+
144
+ 7. Disclaimer of Warranty. Unless required by applicable law or
145
+ agreed to in writing, Licensor provides the Work (and each
146
+ Contributor provides its Contributions) on an "AS IS" BASIS,
147
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
148
+ implied, including, without limitation, any warranties or conditions
149
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
150
+ PARTICULAR PURPOSE. You are solely responsible for determining the
151
+ appropriateness of using or redistributing the Work and assume any
152
+ risks associated with Your exercise of permissions under this License.
153
+
154
+ 8. Limitation of Liability. In no event and under no legal theory,
155
+ whether in tort (including negligence), contract, or otherwise,
156
+ unless required by applicable law (such as deliberate and grossly
157
+ negligent acts) or agreed to in writing, shall any Contributor be
158
+ liable to You for damages, including any direct, indirect, special,
159
+ incidental, or consequential damages of any character arising as a
160
+ result of this License or out of the use or inability to use the
161
+ Work (including but not limited to damages for loss of goodwill,
162
+ work stoppage, computer failure or malfunction, or any and all
163
+ other commercial damages or losses), even if such Contributor
164
+ has been advised of the possibility of such damages.
165
+
166
+ 9. Accepting Warranty or Additional Liability. While redistributing
167
+ the Work or Derivative Works thereof, You may choose to offer,
168
+ and charge a fee for, acceptance of support, warranty, indemnity,
169
+ or other liability obligations and/or rights consistent with this
170
+ License. However, in accepting such obligations, You may act only
171
+ on Your own behalf and on Your sole responsibility, not on behalf
172
+ of any other Contributor, and only if You agree to indemnify,
173
+ defend, and hold each Contributor harmless for any liability
174
+ incurred by, or claims asserted against, such Contributor by reason
175
+ of your accepting any such warranty or additional liability.
176
+
177
+ END OF TERMS AND CONDITIONS
178
+
179
+ APPENDIX: How to apply the Apache License to your work.
180
+
181
+ To apply the Apache License to your work, attach the following
182
+ boilerplate notice, with the fields enclosed by brackets "[]"
183
+ replaced with your own identifying information. (Don't include
184
+ the brackets!) The text should be enclosed in the appropriate
185
+ comment syntax for the file format. We also recommend that a
186
+ file or class name and description of purpose be included on the
187
+ same "printed page" as the copyright notice for easier
188
+ identification within third-party archives.
189
+
190
+ Copyright [yyyy] [name of copyright owner]
191
+
192
+ Licensed under the Apache License, Version 2.0 (the "License");
193
+ you may not use this file except in compliance with the License.
194
+ You may obtain a copy of the License at
195
+
196
+ http://www.apache.org/licenses/LICENSE-2.0
197
+
198
+ Unless required by applicable law or agreed to in writing, software
199
+ distributed under the License is distributed on an "AS IS" BASIS,
200
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
201
+ See the License for the specific language governing permissions and
202
+ limitations under the License.
README.md CHANGED
@@ -3,7 +3,7 @@ title: StyleAligned Transfer
3
  emoji: 🐠
4
  colorFrom: blue
5
  colorTo: pink
6
- sdk: docker
7
  pinned: false
8
  ---
9
 
 
3
  emoji: 🐠
4
  colorFrom: blue
5
  colorTo: pink
6
+ sdk: gradio
7
  pinned: false
8
  ---
9
 
app.py CHANGED
@@ -6,6 +6,7 @@ import math
6
  from diffusers.utils import load_image
7
  import inversion
8
  import numpy as np
 
9
 
10
  # init models
11
 
@@ -22,6 +23,7 @@ pipeline = StableDiffusionXLPipeline.from_pretrained(
22
  pipeline.enable_model_cpu_offload()
23
  pipeline.enable_vae_slicing()
24
 
 
25
  def run(ref_path, ref_style, ref_prompt, prompt1, prompt2, prompt3):
26
  # DDIM inversion
27
  src_style = f"{ref_style}"
@@ -41,8 +43,6 @@ def run(ref_path, ref_style, ref_prompt, prompt1, prompt2, prompt3):
41
  prompts = [
42
  src_prompt,
43
  prompt1,
44
- prompt2,
45
- prompt3
46
  ]
47
 
48
  # some parameters you can adjust to control fidelity to reference
@@ -105,8 +105,8 @@ with gr.Blocks(css=css) as demo:
105
  with gr.Group():
106
  results = gr.Gallery()
107
  prompt1 = gr.Textbox(label="Prompt1", value="A man working on a laptop")
108
- prompt2 = gr.Textbox(label="Prompt2", value="A man eating pizza")
109
- prompt3 = gr.Textbox(label="Prompt3", value="A woman playing on saxophone")
110
  run_button = gr.Button("Submit")
111
 
112
 
 
6
  from diffusers.utils import load_image
7
  import inversion
8
  import numpy as np
9
+ import spaces
10
 
11
  # init models
12
 
 
23
  pipeline.enable_model_cpu_offload()
24
  pipeline.enable_vae_slicing()
25
 
26
+ @spaces.GPU(duration=120)
27
  def run(ref_path, ref_style, ref_prompt, prompt1, prompt2, prompt3):
28
  # DDIM inversion
29
  src_style = f"{ref_style}"
 
43
  prompts = [
44
  src_prompt,
45
  prompt1,
 
 
46
  ]
47
 
48
  # some parameters you can adjust to control fidelity to reference
 
105
  with gr.Group():
106
  results = gr.Gallery()
107
  prompt1 = gr.Textbox(label="Prompt1", value="A man working on a laptop")
108
+ prompt2 = gr.Textbox(label="Prompt2", value="A man eating pizza", visible=False)
109
+ prompt3 = gr.Textbox(label="Prompt3", value="A woman playing on saxophone", visible=False)
110
  run_button = gr.Button("Submit")
111
 
112
 
contributing.md ADDED
@@ -0,0 +1,28 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # How to Contribute
2
+
3
+ We'd love to accept your patches and contributions to this project. There are
4
+ just a few small guidelines you need to follow.
5
+
6
+ ## Contributor License Agreement
7
+
8
+ Contributions to this project must be accompanied by a Contributor License
9
+ Agreement. You (or your employer) retain the copyright to your contribution;
10
+ this simply gives us permission to use and redistribute your contributions as
11
+ part of the project. Head over to <https://cla.developers.google.com/> to see
12
+ your current agreements on file or to sign a new one.
13
+
14
+ You generally only need to submit a CLA once, so if you've already submitted one
15
+ (even if it was for a different project), you probably don't need to do it
16
+ again.
17
+
18
+ ## Code Reviews
19
+
20
+ All submissions, including submissions by project members, require review. We
21
+ use GitHub pull requests for this purpose. Consult
22
+ [GitHub Help](https://help.github.com/articles/about-pull-requests/) for more
23
+ information on using pull requests.
24
+
25
+ ## Community Guidelines
26
+
27
+ This project follows [Google's Open Source Community
28
+ Guidelines](https://opensource.google/conduct/).
demo_stylealigned_controlnet.py ADDED
@@ -0,0 +1,139 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from diffusers import ControlNetModel, StableDiffusionXLControlNetPipeline, AutoencoderKL
3
+ from diffusers.utils import load_image
4
+ from transformers import DPTImageProcessor, DPTForDepthEstimation
5
+ import torch
6
+ import sa_handler
7
+ import pipeline_calls
8
+
9
+
10
+
11
+ # Initialize models
12
+ depth_estimator = DPTForDepthEstimation.from_pretrained("Intel/dpt-hybrid-midas").to("cuda")
13
+ feature_processor = DPTImageProcessor.from_pretrained("Intel/dpt-hybrid-midas")
14
+
15
+ controlnet = ControlNetModel.from_pretrained(
16
+ "diffusers/controlnet-depth-sdxl-1.0",
17
+ variant="fp16",
18
+ use_safetensors=True,
19
+ torch_dtype=torch.float16,
20
+ ).to("cuda")
21
+ vae = AutoencoderKL.from_pretrained("madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16).to("cuda")
22
+ pipeline = StableDiffusionXLControlNetPipeline.from_pretrained(
23
+ "stabilityai/stable-diffusion-xl-base-1.0",
24
+ controlnet=controlnet,
25
+ vae=vae,
26
+ variant="fp16",
27
+ use_safetensors=True,
28
+ torch_dtype=torch.float16,
29
+ ).to("cuda")
30
+ # Configure pipeline for CPU offloading and VAE slicing
31
+ pipeline.enable_model_cpu_offload()
32
+ pipeline.enable_vae_slicing()
33
+
34
+ # Initialize style-aligned handler
35
+ sa_args = sa_handler.StyleAlignedArgs(share_group_norm=False,
36
+ share_layer_norm=False,
37
+ share_attention=True,
38
+ adain_queries=True,
39
+ adain_keys=True,
40
+ adain_values=False,
41
+ )
42
+ handler = sa_handler.Handler(pipeline)
43
+ handler.register(sa_args, )
44
+
45
+
46
+ # Function to run ControlNet depth with StyleAligned
47
+ def style_aligned_controlnet(ref_style_prompt, depth_map, ref_image, img_generation_prompt, seed):
48
+ try:
49
+ if depth_map == True:
50
+ image = load_image(ref_image)
51
+ depth_image = pipeline_calls.get_depth_map(image, feature_processor, depth_estimator)
52
+ else:
53
+ depth_image = load_image(ref_image).resize((1024, 1024))
54
+ controlnet_conditioning_scale = 0.8
55
+ gen = None if seed is None else torch.manual_seed(int(seed))
56
+ num_images_per_prompt = 3 # adjust according to VRAM size
57
+ latents = torch.randn(1 + num_images_per_prompt, 4, 128, 128, generator=gen).to(pipeline.unet.dtype)
58
+
59
+ images = pipeline_calls.controlnet_call(pipeline, [ref_style_prompt, img_generation_prompt],
60
+ image=depth_image,
61
+ num_inference_steps=50,
62
+ controlnet_conditioning_scale=controlnet_conditioning_scale,
63
+ num_images_per_prompt=num_images_per_prompt,
64
+ latents=latents)
65
+ return [images[0], depth_image] + images[1:], gr.Image(value=images[0], visible=True)
66
+ except Exception as e:
67
+ raise gr.Error(f"Error in generating images:{e}")
68
+
69
+ # Create a Gradio UI
70
+ with gr.Blocks() as demo:
71
+ gr.HTML('<h1 style="text-align: center;">ControlNet with StyleAligned</h1>')
72
+ with gr.Row():
73
+
74
+ with gr.Column(variant='panel'):
75
+ # Textbox for reference style prompt
76
+ ref_style_prompt = gr.Textbox(
77
+ label='Reference style prompt',
78
+ info="Enter a Prompt to generate the reference image", placeholder='a poster in <style name> style'
79
+ )
80
+ with gr.Row(variant='panel'):
81
+ # Checkbox for using controller depth-map
82
+ depth_map = gr.Checkbox(label='Depth-map',)
83
+ seed = gr.Number(value=1234, label="Seed", precision=0, step=1, scale=3,
84
+ info="Enter a seed of a previous reference image "
85
+ "or leave empty for a random generation.")
86
+ # Image display for the generated reference style image
87
+ ref_style_image = gr.Image(visible=False, label='Reference style image', scale=1)
88
+
89
+
90
+ with gr.Column(variant='panel'):
91
+ # Image upload option for uploading a reference image for controlnet
92
+ ref_image = gr.Image(label="Upload the reference image",
93
+ type='filepath' )
94
+ # Textbox for ControlNet prompt
95
+ img_generation_prompt = gr.Textbox(
96
+ label='Generation Prompt',
97
+ info="Enter a Prompt to generate images using ControlNet and StyleAligned",
98
+ )
99
+
100
+ # Button to trigger image generation
101
+ btn = gr.Button("Generate", size='sm')
102
+ # Gallery to display generated images
103
+ gallery = gr.Gallery(label="Style-Aligned ControlNet - Generated images",
104
+ elem_id="gallery",
105
+ columns=5,
106
+ rows=1,
107
+ object_fit="contain",
108
+ height="auto",
109
+ )
110
+
111
+ btn.click(fn=style_aligned_controlnet,
112
+ inputs=[ref_style_prompt, depth_map, ref_image, img_generation_prompt, seed],
113
+ outputs=[gallery, ref_style_image],
114
+ api_name="style_aligned_controlnet")
115
+
116
+
117
+ # Example inputs for the Gradio interface
118
+ gr.Examples(
119
+ examples=[
120
+ ['A couple sitting a wooden bench, in colorful clay animation, claymation style.', True,
121
+ 'example_image/train.png', 'A train in colorful clay animation, claymation style.',],
122
+ ['A couple sitting a wooden bench, in colorful clay animation, claymation style.', False,
123
+ 'example_image/sun.png', 'Sun in colorful clay animation, claymation style.',],
124
+ ['A poster in a papercut art style.', False,
125
+ 'example_image/A.png', 'Letter A in a papercut art style.', None],
126
+ ['A bull in a low-poly, colorful origami style.', True, 'example_image/whale.png',
127
+ 'A whale in a low-poly, colorful origami style.', None],
128
+ ['An image in ancient egyptian art style, hieroglyphics style.', True, 'example_image/camel.png',
129
+ 'A camel in a painterly, digital illustration style.',],
130
+ ['An image in ancient egyptian art style, hieroglyphics style.', True, 'example_image/whale.png',
131
+ 'A whale in ancient egyptian art style, hieroglyphics style.',],
132
+ ],
133
+ inputs=[ref_style_prompt, depth_map, ref_image, img_generation_prompt,],
134
+ outputs=[gallery, ref_style_image],
135
+ fn=style_aligned_controlnet,
136
+ )
137
+
138
+ # Launch the Gradio demo
139
+ demo.launch()
demo_stylealigned_multidiffusion.py ADDED
@@ -0,0 +1,104 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import torch
3
+ from diffusers import StableDiffusionPanoramaPipeline, DDIMScheduler
4
+ import sa_handler
5
+ import pipeline_calls
6
+
7
+
8
+ # init models
9
+ model_ckpt = "stabilityai/stable-diffusion-2-base"
10
+ scheduler = DDIMScheduler.from_pretrained(model_ckpt, subfolder="scheduler")
11
+ pipeline = StableDiffusionPanoramaPipeline.from_pretrained(
12
+ model_ckpt, scheduler=scheduler, torch_dtype=torch.float16
13
+ ).to("cuda")
14
+ # Configure the pipeline for CPU offloading and VAE slicing
15
+ pipeline.enable_model_cpu_offload()
16
+ pipeline.enable_vae_slicing()
17
+ sa_args = sa_handler.StyleAlignedArgs(share_group_norm=True,
18
+ share_layer_norm=True,
19
+ share_attention=True,
20
+ adain_queries=True,
21
+ adain_keys=True,
22
+ adain_values=False,
23
+ )
24
+ # Initialize the style-aligned handler
25
+ handler = sa_handler.Handler(pipeline)
26
+ handler.register(sa_args)
27
+
28
+
29
+ # Define the function to run MultiDiffusion with StyleAligned
30
+ def style_aligned_multidiff(ref_style_prompt, img_generation_prompt, seed):
31
+ try:
32
+ view_batch_size = 25 # adjust according to VRAM size
33
+ gen = None if seed is None else torch.manual_seed(int(seed))
34
+ reference_latent = torch.randn(1, 4, 64, 64, generator=gen)
35
+ images = pipeline_calls.panorama_call(pipeline,
36
+ [ref_style_prompt, img_generation_prompt],
37
+ reference_latent=reference_latent,
38
+ view_batch_size=view_batch_size)
39
+
40
+ return images, gr.Image(value=images[0], visible=True)
41
+ except Exception as e:
42
+ raise gr.Error(f"Error in generating images:{e}")
43
+
44
+ # Create a Gradio UI
45
+ with gr.Blocks() as demo:
46
+ gr.HTML('<h1 style="text-align: center;">MultiDiffusion with StyleAligned </h1>')
47
+ with gr.Row():
48
+ with gr.Column(variant='panel'):
49
+ # Textbox for reference style prompt
50
+ ref_style_prompt = gr.Textbox(
51
+ label='Reference style prompt',
52
+ info='Enter a Prompt to generate the reference image',
53
+ placeholder='A poster in a papercut art style.'
54
+ )
55
+ seed = gr.Number(value=1234, label="Seed", precision=0, step=1,
56
+ info="Enter a seed of a previous reference image "
57
+ "or leave empty for a random generation.")
58
+ # Image display for the reference style image
59
+ ref_style_image = gr.Image(visible=False, label='Reference style image')
60
+
61
+
62
+ with gr.Column(variant='panel'):
63
+ # Textbox for prompt for MultiDiffusion panoramas
64
+ img_generation_prompt = gr.Textbox(
65
+ label='MultiDiffusion Prompt',
66
+ info='Enter a Prompt to generate panoramic images using Style-aligned combined with MultiDiffusion',
67
+ placeholder= 'A village in a papercut art style.'
68
+ )
69
+
70
+ # Button to trigger image generation
71
+ btn = gr.Button('Style Aligned MultiDiffusion - Generate', size='sm')
72
+ # Gallery to display generated style image and the panorama
73
+ gallery = gr.Gallery(label='StyleAligned MultiDiffusion - generated images',
74
+ elem_id='gallery',
75
+ columns=5,
76
+ rows=1,
77
+ object_fit='contain',
78
+ height='auto',
79
+ allow_preview=True,
80
+ preview=True,
81
+ )
82
+ # Button click event
83
+ btn.click(fn=style_aligned_multidiff,
84
+ inputs=[ref_style_prompt, img_generation_prompt, seed],
85
+ outputs=[gallery, ref_style_image,],
86
+ api_name='style_aligned_multidiffusion')
87
+
88
+ # Example inputs for the Gradio demo
89
+ gr.Examples(
90
+ examples=[
91
+ ['A poster in a papercut art style.', 'A village in a papercut art style.'],
92
+ ['A poster in a papercut art style.', 'Futuristic cityscape in a papercut art style.'],
93
+ ['A poster in a papercut art style.', 'A jungle in a papercut art style.'],
94
+ ['A poster in a flat design style.', 'Giraffes in a flat design style.'],
95
+ ['A poster in a flat design style.', 'Houses in a flat design style.'],
96
+ ['A poster in a flat design style.', 'Mountains in a flat design style.'],
97
+ ],
98
+ inputs=[ref_style_prompt, img_generation_prompt],
99
+ outputs=[gallery, ref_style_image],
100
+ fn=style_aligned_multidiff,
101
+ )
102
+
103
+ # Launch the Gradio demo
104
+ demo.launch()
demo_stylealigned_sdxl.py ADDED
@@ -0,0 +1,85 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from diffusers import StableDiffusionXLPipeline, DDIMScheduler
3
+ import torch
4
+ import sa_handler
5
+
6
+ # init models
7
+ scheduler = DDIMScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", clip_sample=False,
8
+ set_alpha_to_one=False)
9
+ pipeline = StableDiffusionXLPipeline.from_pretrained(
10
+ "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16, variant="fp16", use_safetensors=True,
11
+ scheduler=scheduler
12
+ ).to("cuda")
13
+ # Configure the pipeline for CPU offloading and VAE slicing#pipeline.enable_sequential_cpu_offload()
14
+ pipeline.enable_model_cpu_offload()
15
+ pipeline.enable_vae_slicing()
16
+ # Initialize the style-aligned handler
17
+ handler = sa_handler.Handler(pipeline)
18
+ sa_args = sa_handler.StyleAlignedArgs(share_group_norm=False,
19
+ share_layer_norm=False,
20
+ share_attention=True,
21
+ adain_queries=True,
22
+ adain_keys=True,
23
+ adain_values=False,
24
+ )
25
+
26
+ handler.register(sa_args, )
27
+
28
+ # Define the function to generate style-aligned images
29
+ def style_aligned_sdxl(initial_prompt1, initial_prompt2, initial_prompt3, initial_prompt4,
30
+ initial_prompt5, style_prompt, seed):
31
+ try:
32
+ # Combine the style prompt with each initial prompt
33
+ gen = None if seed is None else torch.manual_seed(int(seed))
34
+ sets_of_prompts = [prompt + " in the style of " + style_prompt for prompt in [initial_prompt1, initial_prompt2, initial_prompt3, initial_prompt4, initial_prompt5] if prompt]
35
+ # Generate images using the pipeline
36
+ images = pipeline(sets_of_prompts, generator=gen).images
37
+ return images
38
+ except Exception as e:
39
+ raise gr.Error(f"Error in generating images: {e}")
40
+
41
+ with gr.Blocks() as demo:
42
+ gr.HTML('<h1 style="text-align: center;">StyleAligned SDXL</h1>')
43
+ with gr.Group():
44
+ with gr.Column():
45
+ with gr.Accordion(label='Enter upto 5 different initial prompts', open=True):
46
+ with gr.Row(variant='panel'):
47
+ # Textboxes for initial prompts
48
+ initial_prompt1 = gr.Textbox(label='Initial prompt 1', value='', show_label=False, container=False, placeholder='a toy train')
49
+ initial_prompt2 = gr.Textbox(label='Initial prompt 2', value='', show_label=False, container=False, placeholder='a toy airplane')
50
+ initial_prompt3 = gr.Textbox(label='Initial prompt 3', value='', show_label=False, container=False, placeholder='a toy bicycle')
51
+ initial_prompt4 = gr.Textbox(label='Initial prompt 4', value='', show_label=False, container=False, placeholder='a toy car')
52
+ initial_prompt5 = gr.Textbox(label='Initial prompt 5', value='', show_label=False, container=False, placeholder='a toy boat')
53
+ with gr.Row():
54
+ # Textbox for the style prompt
55
+ style_prompt = gr.Textbox(label="Enter a style prompt", placeholder='macro photo, 3d game asset', scale=3)
56
+ seed = gr.Number(value=1234, label="Seed", precision=0, step=1, scale=1,
57
+ info="Enter a seed of a previous run "
58
+ "or leave empty for a random generation.")
59
+ # Button to generate images
60
+ btn = gr.Button("Generate a set of Style-aligned SDXL images",)
61
+ # Display the generated images
62
+ output = gr.Gallery(label="Style aligned text-to-image on SDXL ", elem_id="gallery",columns=5, rows=1,
63
+ object_fit="contain", height="auto",)
64
+
65
+ # Button click event
66
+ btn.click(fn=style_aligned_sdxl,
67
+ inputs=[initial_prompt1, initial_prompt2, initial_prompt3, initial_prompt4, initial_prompt5,
68
+ style_prompt, seed],
69
+ outputs=output,
70
+ api_name="style_aligned_sdxl")
71
+
72
+ # Providing Example inputs for the demo
73
+ gr.Examples(examples=[
74
+ ["a toy train", "a toy airplane", "a toy bicycle", "a toy car", "a toy boat", "macro photo. 3d game asset."],
75
+ ["a toy train", "a toy airplane", "a toy bicycle", "a toy car", "a toy boat", "BW logo. high contrast."],
76
+ ["a cat", "a dog", "a bear", "a man on a bicycle", "a girl working on laptop", "minimal origami."],
77
+ ["a firewoman", "a Gardner", "a scientist", "a policewoman", "a saxophone player", "made of claymation, stop motion animation."],
78
+ ["a firewoman", "a Gardner", "a scientist", "a policewoman", "a saxophone player", "sketch, character sheet."],
79
+ ],
80
+ inputs=[initial_prompt1, initial_prompt2, initial_prompt3, initial_prompt4, initial_prompt5, style_prompt],
81
+ outputs=[output],
82
+ fn=style_aligned_sdxl)
83
+
84
+ # Launch the Gradio demo
85
+ demo.launch()
doc/cn_example.jpg ADDED

Git LFS Details

  • SHA256: 76f94e53ddca8389ba142bcb644127a4d8b8b2b13cc8b21f3959434989968dea
  • Pointer size: 132 Bytes
  • Size of remote file: 1.08 MB
doc/md_example.jpg ADDED

Git LFS Details

  • SHA256: 2457056b024e2a1cacac71fe0332e688efb46dd535ace9df51fc326dcf10131a
  • Pointer size: 132 Bytes
  • Size of remote file: 1.62 MB
doc/sa_example.jpg ADDED
doc/sa_transfer_example.jpeg ADDED
example_image/A.png ADDED
example_image/camel.png ADDED

Git LFS Details

  • SHA256: 7b8bf33c6f7837c8de8105aa6943007e63c8590afa2ff1db3e2ca47317597034
  • Pointer size: 132 Bytes
  • Size of remote file: 1.1 MB
example_image/medieval-bed.jpeg ADDED
example_image/sun.png ADDED
example_image/train.png ADDED

Git LFS Details

  • SHA256: 4f6c557bfb56274d3f99f07145c42e3a2380e66849c3b8654efad202c0d09a68
  • Pointer size: 132 Bytes
  • Size of remote file: 1.13 MB
example_image/whale.png ADDED
inversion.py ADDED
@@ -0,0 +1,125 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2023 Google LLC
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+
16
+ from __future__ import annotations
17
+ from typing import Callable
18
+ from diffusers import StableDiffusionXLPipeline
19
+ import torch
20
+ from tqdm import tqdm
21
+ import numpy as np
22
+
23
+
24
+ T = torch.Tensor
25
+ TN = T | None
26
+ InversionCallback = Callable[[StableDiffusionXLPipeline, int, T, dict[str, T]], dict[str, T]]
27
+
28
+
29
+ def _get_text_embeddings(prompt: str, tokenizer, text_encoder, device):
30
+ # Tokenize text and get embeddings
31
+ text_inputs = tokenizer(prompt, padding='max_length', max_length=tokenizer.model_max_length, truncation=True, return_tensors='pt')
32
+ text_input_ids = text_inputs.input_ids
33
+
34
+ with torch.no_grad():
35
+ prompt_embeds = text_encoder(
36
+ text_input_ids.to(device),
37
+ output_hidden_states=True,
38
+ )
39
+
40
+ pooled_prompt_embeds = prompt_embeds[0]
41
+ prompt_embeds = prompt_embeds.hidden_states[-2]
42
+ if prompt == '':
43
+ negative_prompt_embeds = torch.zeros_like(prompt_embeds)
44
+ negative_pooled_prompt_embeds = torch.zeros_like(pooled_prompt_embeds)
45
+ return negative_prompt_embeds, negative_pooled_prompt_embeds
46
+ return prompt_embeds, pooled_prompt_embeds
47
+
48
+
49
+ def _encode_text_sdxl(model: StableDiffusionXLPipeline, prompt: str) -> tuple[dict[str, T], T]:
50
+ device = model._execution_device
51
+ prompt_embeds, pooled_prompt_embeds, = _get_text_embeddings(prompt, model.tokenizer, model.text_encoder, device)
52
+ prompt_embeds_2, pooled_prompt_embeds2, = _get_text_embeddings( prompt, model.tokenizer_2, model.text_encoder_2, device)
53
+ prompt_embeds = torch.cat((prompt_embeds, prompt_embeds_2), dim=-1)
54
+ text_encoder_projection_dim = model.text_encoder_2.config.projection_dim
55
+ add_time_ids = model._get_add_time_ids((1024, 1024), (0, 0), (1024, 1024), torch.float16,
56
+ text_encoder_projection_dim).to(device)
57
+ added_cond_kwargs = {"text_embeds": pooled_prompt_embeds2, "time_ids": add_time_ids}
58
+ return added_cond_kwargs, prompt_embeds
59
+
60
+
61
+ def _encode_text_sdxl_with_negative(model: StableDiffusionXLPipeline, prompt: str) -> tuple[dict[str, T], T]:
62
+ added_cond_kwargs, prompt_embeds = _encode_text_sdxl(model, prompt)
63
+ added_cond_kwargs_uncond, prompt_embeds_uncond = _encode_text_sdxl(model, "")
64
+ prompt_embeds = torch.cat((prompt_embeds_uncond, prompt_embeds, ))
65
+ added_cond_kwargs = {"text_embeds": torch.cat((added_cond_kwargs_uncond["text_embeds"], added_cond_kwargs["text_embeds"])),
66
+ "time_ids": torch.cat((added_cond_kwargs_uncond["time_ids"], added_cond_kwargs["time_ids"])),}
67
+ return added_cond_kwargs, prompt_embeds
68
+
69
+
70
+ def _encode_image(model: StableDiffusionXLPipeline, image: np.ndarray) -> T:
71
+ model.vae.to(dtype=torch.float32)
72
+ image = torch.from_numpy(image).float() / 255.
73
+ image = (image * 2 - 1).permute(2, 0, 1).unsqueeze(0)
74
+ latent = model.vae.encode(image.to(model.vae.device))['latent_dist'].mean * model.vae.config.scaling_factor
75
+ model.vae.to(dtype=torch.float16)
76
+ return latent
77
+
78
+
79
+ def _next_step(model: StableDiffusionXLPipeline, model_output: T, timestep: int, sample: T) -> T:
80
+ timestep, next_timestep = min(timestep - model.scheduler.config.num_train_timesteps // model.scheduler.num_inference_steps, 999), timestep
81
+ alpha_prod_t = model.scheduler.alphas_cumprod[int(timestep)] if timestep >= 0 else model.scheduler.final_alpha_cumprod
82
+ alpha_prod_t_next = model.scheduler.alphas_cumprod[int(next_timestep)]
83
+ beta_prod_t = 1 - alpha_prod_t
84
+ next_original_sample = (sample - beta_prod_t ** 0.5 * model_output) / alpha_prod_t ** 0.5
85
+ next_sample_direction = (1 - alpha_prod_t_next) ** 0.5 * model_output
86
+ next_sample = alpha_prod_t_next ** 0.5 * next_original_sample + next_sample_direction
87
+ return next_sample
88
+
89
+
90
+ def _get_noise_pred(model: StableDiffusionXLPipeline, latent: T, t: T, context: T, guidance_scale: float, added_cond_kwargs: dict[str, T]):
91
+ latents_input = torch.cat([latent] * 2)
92
+ noise_pred = model.unet(latents_input, t, encoder_hidden_states=context, added_cond_kwargs=added_cond_kwargs)["sample"]
93
+ noise_pred_uncond, noise_prediction_text = noise_pred.chunk(2)
94
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_prediction_text - noise_pred_uncond)
95
+ # latents = next_step(model, noise_pred, t, latent)
96
+ return noise_pred
97
+
98
+
99
+ def _ddim_loop(model: StableDiffusionXLPipeline, z0, prompt, guidance_scale) -> T:
100
+ all_latent = [z0]
101
+ added_cond_kwargs, text_embedding = _encode_text_sdxl_with_negative(model, prompt)
102
+ latent = z0.clone().detach().half()
103
+ for i in tqdm(range(model.scheduler.num_inference_steps)):
104
+ t = model.scheduler.timesteps[len(model.scheduler.timesteps) - i - 1]
105
+ noise_pred = _get_noise_pred(model, latent, t, text_embedding, guidance_scale, added_cond_kwargs)
106
+ latent = _next_step(model, noise_pred, t, latent)
107
+ all_latent.append(latent)
108
+ return torch.cat(all_latent).flip(0)
109
+
110
+
111
+ def make_inversion_callback(zts, offset: int = 0) -> [T, InversionCallback]:
112
+
113
+ def callback_on_step_end(pipeline: StableDiffusionXLPipeline, i: int, t: T, callback_kwargs: dict[str, T]) -> dict[str, T]:
114
+ latents = callback_kwargs['latents']
115
+ latents[0] = zts[max(offset + 1, i + 1)].to(latents.device, latents.dtype)
116
+ return {'latents': latents}
117
+ return zts[offset], callback_on_step_end
118
+
119
+
120
+ @torch.no_grad()
121
+ def ddim_inversion(model: StableDiffusionXLPipeline, x0: np.ndarray, prompt: str, num_inference_steps: int, guidance_scale,) -> T:
122
+ z0 = _encode_image(model, x0)
123
+ model.scheduler.set_timesteps(num_inference_steps, device=z0.device)
124
+ zs = _ddim_loop(model, z0, prompt, guidance_scale)
125
+ return zs
pipeline_calls.py ADDED
@@ -0,0 +1,552 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2023 Google LLC
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+
16
+ from __future__ import annotations
17
+ from typing import Any
18
+ import torch
19
+ import numpy as np
20
+ from diffusers import ControlNetModel, StableDiffusionXLControlNetPipeline
21
+ from diffusers.image_processor import PipelineImageInput
22
+ from diffusers.utils.torch_utils import is_compiled_module, is_torch_version
23
+ from transformers import DPTImageProcessor, DPTForDepthEstimation
24
+ from diffusers import StableDiffusionPanoramaPipeline
25
+ from PIL import Image
26
+ import copy
27
+
28
+ T = torch.Tensor
29
+ TN = T | None
30
+
31
+
32
+ def get_depth_map(image: Image, feature_processor: DPTImageProcessor, depth_estimator: DPTForDepthEstimation) -> Image:
33
+ image = feature_processor(images=image, return_tensors="pt").pixel_values.to("cuda")
34
+ with torch.no_grad(), torch.autocast("cuda"):
35
+ depth_map = depth_estimator(image).predicted_depth
36
+
37
+ depth_map = torch.nn.functional.interpolate(
38
+ depth_map.unsqueeze(1),
39
+ size=(1024, 1024),
40
+ mode="bicubic",
41
+ align_corners=False,
42
+ )
43
+ depth_min = torch.amin(depth_map, dim=[1, 2, 3], keepdim=True)
44
+ depth_max = torch.amax(depth_map, dim=[1, 2, 3], keepdim=True)
45
+ depth_map = (depth_map - depth_min) / (depth_max - depth_min)
46
+ image = torch.cat([depth_map] * 3, dim=1)
47
+
48
+ image = image.permute(0, 2, 3, 1).cpu().numpy()[0]
49
+ image = Image.fromarray((image * 255.0).clip(0, 255).astype(np.uint8))
50
+ return image
51
+
52
+
53
+ def concat_zero_control(control_reisduel: T) -> T:
54
+ b = control_reisduel.shape[0] // 2
55
+ zerso_reisduel = torch.zeros_like(control_reisduel[0:1])
56
+ return torch.cat((zerso_reisduel, control_reisduel[:b], zerso_reisduel, control_reisduel[b::]))
57
+
58
+
59
+ @torch.no_grad()
60
+ def controlnet_call(
61
+ pipeline: StableDiffusionXLControlNetPipeline,
62
+ prompt: str | list[str] = None,
63
+ prompt_2: str | list[str] | None = None,
64
+ image: PipelineImageInput = None,
65
+ height: int | None = None,
66
+ width: int | None = None,
67
+ num_inference_steps: int = 50,
68
+ guidance_scale: float = 5.0,
69
+ negative_prompt: str | list[str] | None = None,
70
+ negative_prompt_2: str | list[str] | None = None,
71
+ num_images_per_prompt: int = 1,
72
+ eta: float = 0.0,
73
+ generator: torch.Generator | None = None,
74
+ latents: TN = None,
75
+ prompt_embeds: TN = None,
76
+ negative_prompt_embeds: TN = None,
77
+ pooled_prompt_embeds: TN = None,
78
+ negative_pooled_prompt_embeds: TN = None,
79
+ cross_attention_kwargs: dict[str, Any] | None = None,
80
+ controlnet_conditioning_scale: float | list[float] = 1.0,
81
+ control_guidance_start: float | list[float] = 0.0,
82
+ control_guidance_end: float | list[float] = 1.0,
83
+ original_size: tuple[int, int] = None,
84
+ crops_coords_top_left: tuple[int, int] = (0, 0),
85
+ target_size: tuple[int, int] | None = None,
86
+ negative_original_size: tuple[int, int] | None = None,
87
+ negative_crops_coords_top_left: tuple[int, int] = (0, 0),
88
+ negative_target_size:tuple[int, int] | None = None,
89
+ clip_skip: int | None = None,
90
+ ) -> list[Image]:
91
+ controlnet = pipeline.controlnet._orig_mod if is_compiled_module(pipeline.controlnet) else pipeline.controlnet
92
+
93
+ # align format for control guidance
94
+ if not isinstance(control_guidance_start, list) and isinstance(control_guidance_end, list):
95
+ control_guidance_start = len(control_guidance_end) * [control_guidance_start]
96
+ elif not isinstance(control_guidance_end, list) and isinstance(control_guidance_start, list):
97
+ control_guidance_end = len(control_guidance_start) * [control_guidance_end]
98
+ elif not isinstance(control_guidance_start, list) and not isinstance(control_guidance_end, list):
99
+ mult = 1
100
+ control_guidance_start, control_guidance_end = (
101
+ mult * [control_guidance_start],
102
+ mult * [control_guidance_end],
103
+ )
104
+
105
+ # 1. Check inputs. Raise error if not correct
106
+ pipeline.check_inputs(
107
+ prompt,
108
+ prompt_2,
109
+ image,
110
+ 1,
111
+ negative_prompt,
112
+ negative_prompt_2,
113
+ prompt_embeds,
114
+ negative_prompt_embeds,
115
+ pooled_prompt_embeds,
116
+ negative_pooled_prompt_embeds,
117
+ controlnet_conditioning_scale,
118
+ control_guidance_start,
119
+ control_guidance_end,
120
+ )
121
+
122
+ pipeline._guidance_scale = guidance_scale
123
+
124
+ # 2. Define call parameters
125
+ if prompt is not None and isinstance(prompt, str):
126
+ batch_size = 1
127
+ elif prompt is not None and isinstance(prompt, list):
128
+ batch_size = len(prompt)
129
+ else:
130
+ batch_size = prompt_embeds.shape[0]
131
+
132
+ device = pipeline._execution_device
133
+
134
+ # 3. Encode input prompt
135
+ text_encoder_lora_scale = (
136
+ cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None
137
+ )
138
+ (
139
+ prompt_embeds,
140
+ negative_prompt_embeds,
141
+ pooled_prompt_embeds,
142
+ negative_pooled_prompt_embeds,
143
+ ) = pipeline.encode_prompt(
144
+ prompt,
145
+ prompt_2,
146
+ device,
147
+ 1,
148
+ True,
149
+ negative_prompt,
150
+ negative_prompt_2,
151
+ prompt_embeds=prompt_embeds,
152
+ negative_prompt_embeds=negative_prompt_embeds,
153
+ pooled_prompt_embeds=pooled_prompt_embeds,
154
+ negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
155
+ lora_scale=text_encoder_lora_scale,
156
+ clip_skip=clip_skip,
157
+ )
158
+
159
+ # 4. Prepare image
160
+ if isinstance(controlnet, ControlNetModel):
161
+ image = pipeline.prepare_image(
162
+ image=image,
163
+ width=width,
164
+ height=height,
165
+ batch_size=1,
166
+ num_images_per_prompt=1,
167
+ device=device,
168
+ dtype=controlnet.dtype,
169
+ do_classifier_free_guidance=True,
170
+ guess_mode=False,
171
+ )
172
+ height, width = image.shape[-2:]
173
+ image = torch.stack([image[0]] * num_images_per_prompt + [image[1]] * num_images_per_prompt)
174
+ else:
175
+ assert False
176
+ # 5. Prepare timesteps
177
+ pipeline.scheduler.set_timesteps(num_inference_steps, device=device)
178
+ timesteps = pipeline.scheduler.timesteps
179
+
180
+ # 6. Prepare latent variables
181
+ num_channels_latents = pipeline.unet.config.in_channels
182
+ latents = pipeline.prepare_latents(
183
+ 1 + num_images_per_prompt,
184
+ num_channels_latents,
185
+ height,
186
+ width,
187
+ prompt_embeds.dtype,
188
+ device,
189
+ generator,
190
+ latents,
191
+ )
192
+
193
+ # 6.5 Optionally get Guidance Scale Embedding
194
+ timestep_cond = None
195
+
196
+ # 7. Prepare extra step kwargs.
197
+ extra_step_kwargs = pipeline.prepare_extra_step_kwargs(generator, eta)
198
+
199
+ # 7.1 Create tensor stating which controlnets to keep
200
+ controlnet_keep = []
201
+ for i in range(len(timesteps)):
202
+ keeps = [
203
+ 1.0 - float(i / len(timesteps) < s or (i + 1) / len(timesteps) > e)
204
+ for s, e in zip(control_guidance_start, control_guidance_end)
205
+ ]
206
+ controlnet_keep.append(keeps[0] if isinstance(controlnet, ControlNetModel) else keeps)
207
+
208
+ # 7.2 Prepare added time ids & embeddings
209
+ if isinstance(image, list):
210
+ original_size = original_size or image[0].shape[-2:]
211
+ else:
212
+ original_size = original_size or image.shape[-2:]
213
+ target_size = target_size or (height, width)
214
+
215
+ add_text_embeds = pooled_prompt_embeds
216
+ if pipeline.text_encoder_2 is None:
217
+ text_encoder_projection_dim = int(pooled_prompt_embeds.shape[-1])
218
+ else:
219
+ text_encoder_projection_dim = pipeline.text_encoder_2.config.projection_dim
220
+
221
+ add_time_ids = pipeline._get_add_time_ids(
222
+ original_size,
223
+ crops_coords_top_left,
224
+ target_size,
225
+ dtype=prompt_embeds.dtype,
226
+ text_encoder_projection_dim=text_encoder_projection_dim,
227
+ )
228
+
229
+ if negative_original_size is not None and negative_target_size is not None:
230
+ negative_add_time_ids = pipeline._get_add_time_ids(
231
+ negative_original_size,
232
+ negative_crops_coords_top_left,
233
+ negative_target_size,
234
+ dtype=prompt_embeds.dtype,
235
+ text_encoder_projection_dim=text_encoder_projection_dim,
236
+ )
237
+ else:
238
+ negative_add_time_ids = add_time_ids
239
+
240
+ prompt_embeds = torch.stack([prompt_embeds[0]] + [prompt_embeds[1]] * num_images_per_prompt)
241
+ negative_prompt_embeds = torch.stack([negative_prompt_embeds[0]] + [negative_prompt_embeds[1]] * num_images_per_prompt)
242
+ negative_pooled_prompt_embeds = torch.stack([negative_pooled_prompt_embeds[0]] + [negative_pooled_prompt_embeds[1]] * num_images_per_prompt)
243
+ add_text_embeds = torch.stack([add_text_embeds[0]] + [add_text_embeds[1]] * num_images_per_prompt)
244
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
245
+ add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0)
246
+ add_time_ids = torch.cat([negative_add_time_ids, add_time_ids], dim=0)
247
+
248
+ prompt_embeds = prompt_embeds.to(device)
249
+ add_text_embeds = add_text_embeds.to(device)
250
+ add_time_ids = add_time_ids.to(device).repeat(1 + num_images_per_prompt, 1)
251
+ batch_size = num_images_per_prompt + 1
252
+ # 8. Denoising loop
253
+ num_warmup_steps = len(timesteps) - num_inference_steps * pipeline.scheduler.order
254
+ is_unet_compiled = is_compiled_module(pipeline.unet)
255
+ is_controlnet_compiled = is_compiled_module(pipeline.controlnet)
256
+ is_torch_higher_equal_2_1 = is_torch_version(">=", "2.1")
257
+ added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids}
258
+ controlnet_prompt_embeds = torch.cat((prompt_embeds[1:batch_size], prompt_embeds[1:batch_size]))
259
+ controlnet_added_cond_kwargs = {key: torch.cat((item[1:batch_size,], item[1:batch_size])) for key, item in added_cond_kwargs.items()}
260
+ with pipeline.progress_bar(total=num_inference_steps) as progress_bar:
261
+ for i, t in enumerate(timesteps):
262
+ # Relevant thread:
263
+ # https://dev-discuss.pytorch.org/t/cudagraphs-in-pytorch-2-0/1428
264
+ if (is_unet_compiled and is_controlnet_compiled) and is_torch_higher_equal_2_1:
265
+ torch._inductor.cudagraph_mark_step_begin()
266
+ # expand the latents if we are doing classifier free guidance
267
+ latent_model_input = torch.cat([latents] * 2)
268
+ latent_model_input = pipeline.scheduler.scale_model_input(latent_model_input, t)
269
+
270
+ # controlnet(s) inference
271
+ control_model_input = torch.cat((latent_model_input[1:batch_size,], latent_model_input[batch_size+1:]))
272
+
273
+ if isinstance(controlnet_keep[i], list):
274
+ cond_scale = [c * s for c, s in zip(controlnet_conditioning_scale, controlnet_keep[i])]
275
+ else:
276
+ controlnet_cond_scale = controlnet_conditioning_scale
277
+ if isinstance(controlnet_cond_scale, list):
278
+ controlnet_cond_scale = controlnet_cond_scale[0]
279
+ cond_scale = controlnet_cond_scale * controlnet_keep[i]
280
+ if cond_scale > 0:
281
+ down_block_res_samples, mid_block_res_sample = pipeline.controlnet(
282
+ control_model_input,
283
+ t,
284
+ encoder_hidden_states=controlnet_prompt_embeds,
285
+ controlnet_cond=image,
286
+ conditioning_scale=cond_scale,
287
+ guess_mode=False,
288
+ added_cond_kwargs=controlnet_added_cond_kwargs,
289
+ return_dict=False,
290
+ )
291
+
292
+ mid_block_res_sample = concat_zero_control(mid_block_res_sample)
293
+ down_block_res_samples = [concat_zero_control(down_block_res_sample) for down_block_res_sample in down_block_res_samples]
294
+ else:
295
+ mid_block_res_sample = down_block_res_samples = None
296
+ # predict the noise residual
297
+ noise_pred = pipeline.unet(
298
+ latent_model_input,
299
+ t,
300
+ encoder_hidden_states=prompt_embeds,
301
+ timestep_cond=timestep_cond,
302
+ cross_attention_kwargs=cross_attention_kwargs,
303
+ down_block_additional_residuals=down_block_res_samples,
304
+ mid_block_additional_residual=mid_block_res_sample,
305
+ added_cond_kwargs=added_cond_kwargs,
306
+ return_dict=False,
307
+ )[0]
308
+
309
+ # perform guidance
310
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
311
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
312
+
313
+ # compute the previous noisy sample x_t -> x_t-1
314
+ latents = pipeline.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
315
+
316
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % pipeline.scheduler.order == 0):
317
+ progress_bar.update()
318
+
319
+ # manually for max memory savings
320
+ if pipeline.vae.dtype == torch.float16 and pipeline.vae.config.force_upcast:
321
+ pipeline.upcast_vae()
322
+ latents = latents.to(next(iter(pipeline.vae.post_quant_conv.parameters())).dtype)
323
+
324
+ # make sure the VAE is in float32 mode, as it overflows in float16
325
+ needs_upcasting = pipeline.vae.dtype == torch.float16 and pipeline.vae.config.force_upcast
326
+
327
+ if needs_upcasting:
328
+ pipeline.upcast_vae()
329
+ latents = latents.to(next(iter(pipeline.vae.post_quant_conv.parameters())).dtype)
330
+
331
+ image = pipeline.vae.decode(latents / pipeline.vae.config.scaling_factor, return_dict=False)[0]
332
+
333
+ # cast back to fp16 if needed
334
+ if needs_upcasting:
335
+ pipeline.vae.to(dtype=torch.float16)
336
+
337
+ if pipeline.watermark is not None:
338
+ image = pipeline.watermark.apply_watermark(image)
339
+
340
+ image = pipeline.image_processor.postprocess(image, output_type='pil')
341
+
342
+ # Offload all models
343
+ pipeline.maybe_free_model_hooks()
344
+ return image
345
+
346
+
347
+ @torch.no_grad()
348
+ def panorama_call(
349
+ pipeline: StableDiffusionPanoramaPipeline,
350
+ prompt: list[str],
351
+ height: int | None = 512,
352
+ width: int | None = 2048,
353
+ num_inference_steps: int = 50,
354
+ guidance_scale: float = 7.5,
355
+ view_batch_size: int = 1,
356
+ negative_prompt: str | list[str] | None = None,
357
+ num_images_per_prompt: int | None = 1,
358
+ eta: float = 0.0,
359
+ generator: torch.Generator | None = None,
360
+ reference_latent: TN = None,
361
+ latents: TN = None,
362
+ prompt_embeds: TN = None,
363
+ negative_prompt_embeds: TN = None,
364
+ cross_attention_kwargs: dict[str, Any] | None = None,
365
+ circular_padding: bool = False,
366
+ clip_skip: int | None = None,
367
+ stride=8
368
+ ) -> list[Image]:
369
+ # 0. Default height and width to unet
370
+ height = height or pipeline.unet.config.sample_size * pipeline.vae_scale_factor
371
+ width = width or pipeline.unet.config.sample_size * pipeline.vae_scale_factor
372
+
373
+ # 1. Check inputs. Raise error if not correct
374
+ pipeline.check_inputs(
375
+ prompt, height, width, 1, negative_prompt, prompt_embeds, negative_prompt_embeds
376
+ )
377
+
378
+ device = pipeline._execution_device
379
+ # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
380
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
381
+ # corresponds to doing no classifier free guidance.
382
+ do_classifier_free_guidance = guidance_scale > 1.0
383
+
384
+ # 3. Encode input prompt
385
+ text_encoder_lora_scale = (
386
+ cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None
387
+ )
388
+ prompt_embeds, negative_prompt_embeds = pipeline.encode_prompt(
389
+ prompt,
390
+ device,
391
+ num_images_per_prompt,
392
+ do_classifier_free_guidance,
393
+ negative_prompt,
394
+ prompt_embeds=prompt_embeds,
395
+ negative_prompt_embeds=negative_prompt_embeds,
396
+ lora_scale=text_encoder_lora_scale,
397
+ clip_skip=clip_skip,
398
+ )
399
+ # For classifier free guidance, we need to do two forward passes.
400
+ # Here we concatenate the unconditional and text embeddings into a single batch
401
+ # to avoid doing two forward passes
402
+
403
+ # 4. Prepare timesteps
404
+ pipeline.scheduler.set_timesteps(num_inference_steps, device=device)
405
+ timesteps = pipeline.scheduler.timesteps
406
+
407
+ # 5. Prepare latent variables
408
+ num_channels_latents = pipeline.unet.config.in_channels
409
+ latents = pipeline.prepare_latents(
410
+ 1,
411
+ num_channels_latents,
412
+ height,
413
+ width,
414
+ prompt_embeds.dtype,
415
+ device,
416
+ generator,
417
+ latents,
418
+ )
419
+ if reference_latent is None:
420
+ reference_latent = torch.randn(1, 4, pipeline.unet.config.sample_size, pipeline.unet.config.sample_size,
421
+ generator=generator)
422
+ reference_latent = reference_latent.to(device=device, dtype=pipeline.unet.dtype)
423
+ # 6. Define panorama grid and initialize views for synthesis.
424
+ # prepare batch grid
425
+ views = pipeline.get_views(height, width, circular_padding=circular_padding, stride=stride)
426
+ views_batch = [views[i: i + view_batch_size] for i in range(0, len(views), view_batch_size)]
427
+ views_scheduler_status = [copy.deepcopy(pipeline.scheduler.__dict__)] * len(views_batch)
428
+ count = torch.zeros_like(latents)
429
+ value = torch.zeros_like(latents)
430
+ # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
431
+ extra_step_kwargs = pipeline.prepare_extra_step_kwargs(generator, eta)
432
+
433
+ # 8. Denoising loop
434
+ # Each denoising step also includes refinement of the latents with respect to the
435
+ # views.
436
+ num_warmup_steps = len(timesteps) - num_inference_steps * pipeline.scheduler.order
437
+
438
+ negative_prompt_embeds = torch.cat([negative_prompt_embeds[:1],
439
+ *[negative_prompt_embeds[1:]] * view_batch_size]
440
+ )
441
+ prompt_embeds = torch.cat([prompt_embeds[:1],
442
+ *[prompt_embeds[1:]] * view_batch_size]
443
+ )
444
+
445
+ with pipeline.progress_bar(total=num_inference_steps) as progress_bar:
446
+ for i, t in enumerate(timesteps):
447
+ count.zero_()
448
+ value.zero_()
449
+
450
+ # generate views
451
+ # Here, we iterate through different spatial crops of the latents and denoise them. These
452
+ # denoised (latent) crops are then averaged to produce the final latent
453
+ # for the current timestep via MultiDiffusion. Please see Sec. 4.1 in the
454
+ # MultiDiffusion paper for more details: https://arxiv.org/abs/2302.08113
455
+ # Batch views denoise
456
+ for j, batch_view in enumerate(views_batch):
457
+ vb_size = len(batch_view)
458
+ # get the latents corresponding to the current view coordinates
459
+ if circular_padding:
460
+ latents_for_view = []
461
+ for h_start, h_end, w_start, w_end in batch_view:
462
+ if w_end > latents.shape[3]:
463
+ # Add circular horizontal padding
464
+ latent_view = torch.cat(
465
+ (
466
+ latents[:, :, h_start:h_end, w_start:],
467
+ latents[:, :, h_start:h_end, : w_end - latents.shape[3]],
468
+ ),
469
+ dim=-1,
470
+ )
471
+ else:
472
+ latent_view = latents[:, :, h_start:h_end, w_start:w_end]
473
+ latents_for_view.append(latent_view)
474
+ latents_for_view = torch.cat(latents_for_view)
475
+ else:
476
+ latents_for_view = torch.cat(
477
+ [
478
+ latents[:, :, h_start:h_end, w_start:w_end]
479
+ for h_start, h_end, w_start, w_end in batch_view
480
+ ]
481
+ )
482
+ # rematch block's scheduler status
483
+ pipeline.scheduler.__dict__.update(views_scheduler_status[j])
484
+
485
+ # expand the latents if we are doing classifier free guidance
486
+ latent_reference_plus_view = torch.cat((reference_latent, latents_for_view))
487
+ latent_model_input = latent_reference_plus_view.repeat(2, 1, 1, 1)
488
+ prompt_embeds_input = torch.cat([negative_prompt_embeds[: 1 + vb_size],
489
+ prompt_embeds[: 1 + vb_size]]
490
+ )
491
+ latent_model_input = pipeline.scheduler.scale_model_input(latent_model_input, t)
492
+ # predict the noise residual
493
+ # return
494
+ noise_pred = pipeline.unet(
495
+ latent_model_input,
496
+ t,
497
+ encoder_hidden_states=prompt_embeds_input,
498
+ cross_attention_kwargs=cross_attention_kwargs,
499
+ ).sample
500
+
501
+ # perform guidance
502
+
503
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
504
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
505
+ # compute the previous noisy sample x_t -> x_t-1
506
+ latent_reference_plus_view = pipeline.scheduler.step(
507
+ noise_pred, t, latent_reference_plus_view, **extra_step_kwargs
508
+ ).prev_sample
509
+ if j == len(views_batch) - 1:
510
+ reference_latent = latent_reference_plus_view[:1]
511
+ latents_denoised_batch = latent_reference_plus_view[1:]
512
+ # save views scheduler status after sample
513
+ views_scheduler_status[j] = copy.deepcopy(pipeline.scheduler.__dict__)
514
+
515
+ # extract value from batch
516
+ for latents_view_denoised, (h_start, h_end, w_start, w_end) in zip(
517
+ latents_denoised_batch.chunk(vb_size), batch_view
518
+ ):
519
+ if circular_padding and w_end > latents.shape[3]:
520
+ # Case for circular padding
521
+ value[:, :, h_start:h_end, w_start:] += latents_view_denoised[
522
+ :, :, h_start:h_end, : latents.shape[3] - w_start
523
+ ]
524
+ value[:, :, h_start:h_end, : w_end - latents.shape[3]] += latents_view_denoised[
525
+ :, :, h_start:h_end,
526
+ latents.shape[3] - w_start:
527
+ ]
528
+ count[:, :, h_start:h_end, w_start:] += 1
529
+ count[:, :, h_start:h_end, : w_end - latents.shape[3]] += 1
530
+ else:
531
+ value[:, :, h_start:h_end, w_start:w_end] += latents_view_denoised
532
+ count[:, :, h_start:h_end, w_start:w_end] += 1
533
+
534
+ # take the MultiDiffusion step. Eq. 5 in MultiDiffusion paper: https://arxiv.org/abs/2302.08113
535
+ latents = torch.where(count > 0, value / count, value)
536
+
537
+ # call the callback, if provided
538
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % pipeline.scheduler.order == 0):
539
+ progress_bar.update()
540
+
541
+ if circular_padding:
542
+ image = pipeline.decode_latents_with_padding(latents)
543
+ else:
544
+ image = pipeline.vae.decode(latents / pipeline.vae.config.scaling_factor, return_dict=False)[0]
545
+ reference_image = pipeline.vae.decode(reference_latent / pipeline.vae.config.scaling_factor, return_dict=False)[0]
546
+ # image, has_nsfw_concept = pipeline.run_safety_checker(image, device, prompt_embeds.dtype)
547
+ # reference_image, _ = pipeline.run_safety_checker(reference_image, device, prompt_embeds.dtype)
548
+
549
+ image = pipeline.image_processor.postprocess(image, output_type='pil', do_denormalize=[True])
550
+ reference_image = pipeline.image_processor.postprocess(reference_image, output_type='pil', do_denormalize=[True])
551
+ pipeline.maybe_free_model_hooks()
552
+ return reference_image + image
requirements.txt ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ diffusers
2
+ transformers
3
+ accelerate
4
+ mediapy
5
+ ipywidgets
6
+ einops
sa_handler.py ADDED
@@ -0,0 +1,279 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2023 Google LLC
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+
16
+ from __future__ import annotations
17
+
18
+ from dataclasses import dataclass
19
+ from diffusers import StableDiffusionXLPipeline
20
+ import torch
21
+ import torch.nn as nn
22
+ from torch.nn import functional as nnf
23
+ from diffusers.models import attention_processor
24
+ import einops
25
+
26
+ T = torch.Tensor
27
+
28
+
29
+ @dataclass(frozen=True)
30
+ class StyleAlignedArgs:
31
+ share_group_norm: bool = True
32
+ share_layer_norm: bool = True,
33
+ share_attention: bool = True
34
+ adain_queries: bool = True
35
+ adain_keys: bool = True
36
+ adain_values: bool = False
37
+ full_attention_share: bool = False
38
+ shared_score_scale: float = 1.
39
+ shared_score_shift: float = 0.
40
+ only_self_level: float = 0.
41
+
42
+
43
+ def expand_first(feat: T, scale=1.,) -> T:
44
+ b = feat.shape[0]
45
+ feat_style = torch.stack((feat[0], feat[b // 2])).unsqueeze(1)
46
+ if scale == 1:
47
+ feat_style = feat_style.expand(2, b // 2, *feat.shape[1:])
48
+ else:
49
+ feat_style = feat_style.repeat(1, b // 2, 1, 1, 1)
50
+ feat_style = torch.cat([feat_style[:, :1], scale * feat_style[:, 1:]], dim=1)
51
+ return feat_style.reshape(*feat.shape)
52
+
53
+
54
+ def concat_first(feat: T, dim=2, scale=1.) -> T:
55
+ feat_style = expand_first(feat, scale=scale)
56
+ return torch.cat((feat, feat_style), dim=dim)
57
+
58
+
59
+ def calc_mean_std(feat, eps: float = 1e-5) -> tuple[T, T]:
60
+ feat_std = (feat.var(dim=-2, keepdims=True) + eps).sqrt()
61
+ feat_mean = feat.mean(dim=-2, keepdims=True)
62
+ return feat_mean, feat_std
63
+
64
+
65
+ def adain(feat: T) -> T:
66
+ feat_mean, feat_std = calc_mean_std(feat)
67
+ feat_style_mean = expand_first(feat_mean)
68
+ feat_style_std = expand_first(feat_std)
69
+ feat = (feat - feat_mean) / feat_std
70
+ feat = feat * feat_style_std + feat_style_mean
71
+ return feat
72
+
73
+
74
+ class DefaultAttentionProcessor(nn.Module):
75
+
76
+ def __init__(self):
77
+ super().__init__()
78
+ self.processor = attention_processor.AttnProcessor2_0()
79
+
80
+ def __call__(self, attn: attention_processor.Attention, hidden_states, encoder_hidden_states=None,
81
+ attention_mask=None, **kwargs):
82
+ return self.processor(attn, hidden_states, encoder_hidden_states, attention_mask)
83
+
84
+
85
+ class SharedAttentionProcessor(DefaultAttentionProcessor):
86
+
87
+ def shifted_scaled_dot_product_attention(self, attn: attention_processor.Attention, query: T, key: T, value: T) -> T:
88
+ logits = torch.einsum('bhqd,bhkd->bhqk', query, key) * attn.scale
89
+ logits[:, :, :, query.shape[2]:] += self.shared_score_shift
90
+ probs = logits.softmax(-1)
91
+ return torch.einsum('bhqk,bhkd->bhqd', probs, value)
92
+
93
+ def shared_call(
94
+ self,
95
+ attn: attention_processor.Attention,
96
+ hidden_states,
97
+ encoder_hidden_states=None,
98
+ attention_mask=None,
99
+ **kwargs
100
+ ):
101
+
102
+ residual = hidden_states
103
+ input_ndim = hidden_states.ndim
104
+ if input_ndim == 4:
105
+ batch_size, channel, height, width = hidden_states.shape
106
+ hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
107
+ batch_size, sequence_length, _ = (
108
+ hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
109
+ )
110
+
111
+ if attention_mask is not None:
112
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
113
+ # scaled_dot_product_attention expects attention_mask shape to be
114
+ # (batch, heads, source_length, target_length)
115
+ attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
116
+
117
+ if attn.group_norm is not None:
118
+ hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
119
+
120
+ query = attn.to_q(hidden_states)
121
+ key = attn.to_k(hidden_states)
122
+ value = attn.to_v(hidden_states)
123
+ inner_dim = key.shape[-1]
124
+ head_dim = inner_dim // attn.heads
125
+
126
+ query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
127
+ key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
128
+ value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
129
+ # if self.step >= self.start_inject:
130
+ if self.adain_queries:
131
+ query = adain(query)
132
+ if self.adain_keys:
133
+ key = adain(key)
134
+ if self.adain_values:
135
+ value = adain(value)
136
+ if self.share_attention:
137
+ key = concat_first(key, -2, scale=self.shared_score_scale)
138
+ value = concat_first(value, -2)
139
+ if self.shared_score_shift != 0:
140
+ hidden_states = self.shifted_scaled_dot_product_attention(attn, query, key, value,)
141
+ else:
142
+ hidden_states = nnf.scaled_dot_product_attention(
143
+ query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
144
+ )
145
+ else:
146
+ hidden_states = nnf.scaled_dot_product_attention(
147
+ query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
148
+ )
149
+ # hidden_states = adain(hidden_states)
150
+ hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
151
+ hidden_states = hidden_states.to(query.dtype)
152
+
153
+ # linear proj
154
+ hidden_states = attn.to_out[0](hidden_states)
155
+ # dropout
156
+ hidden_states = attn.to_out[1](hidden_states)
157
+
158
+ if input_ndim == 4:
159
+ hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
160
+
161
+ if attn.residual_connection:
162
+ hidden_states = hidden_states + residual
163
+
164
+ hidden_states = hidden_states / attn.rescale_output_factor
165
+ return hidden_states
166
+
167
+ def __call__(self, attn: attention_processor.Attention, hidden_states, encoder_hidden_states=None,
168
+ attention_mask=None, **kwargs):
169
+ if self.full_attention_share:
170
+ b, n, d = hidden_states.shape
171
+ hidden_states = einops.rearrange(hidden_states, '(k b) n d -> k (b n) d', k=2)
172
+ hidden_states = super().__call__(attn, hidden_states, encoder_hidden_states=encoder_hidden_states,
173
+ attention_mask=attention_mask, **kwargs)
174
+ hidden_states = einops.rearrange(hidden_states, 'k (b n) d -> (k b) n d', n=n)
175
+ else:
176
+ hidden_states = self.shared_call(attn, hidden_states, hidden_states, attention_mask, **kwargs)
177
+
178
+ return hidden_states
179
+
180
+ def __init__(self, style_aligned_args: StyleAlignedArgs):
181
+ super().__init__()
182
+ self.share_attention = style_aligned_args.share_attention
183
+ self.adain_queries = style_aligned_args.adain_queries
184
+ self.adain_keys = style_aligned_args.adain_keys
185
+ self.adain_values = style_aligned_args.adain_values
186
+ self.full_attention_share = style_aligned_args.full_attention_share
187
+ self.shared_score_scale = style_aligned_args.shared_score_scale
188
+ self.shared_score_shift = style_aligned_args.shared_score_shift
189
+
190
+
191
+ def _get_switch_vec(total_num_layers, level):
192
+ if level == 0:
193
+ return torch.zeros(total_num_layers, dtype=torch.bool)
194
+ if level == 1:
195
+ return torch.ones(total_num_layers, dtype=torch.bool)
196
+ to_flip = level > .5
197
+ if to_flip:
198
+ level = 1 - level
199
+ num_switch = int(level * total_num_layers)
200
+ vec = torch.arange(total_num_layers)
201
+ vec = vec % (total_num_layers // num_switch)
202
+ vec = vec == 0
203
+ if to_flip:
204
+ vec = ~vec
205
+ return vec
206
+
207
+
208
+ def init_attention_processors(pipeline: StableDiffusionXLPipeline, style_aligned_args: StyleAlignedArgs | None = None):
209
+ attn_procs = {}
210
+ unet = pipeline.unet
211
+ number_of_self, number_of_cross = 0, 0
212
+ num_self_layers = len([name for name in unet.attn_processors.keys() if 'attn1' in name])
213
+ if style_aligned_args is None:
214
+ only_self_vec = _get_switch_vec(num_self_layers, 1)
215
+ else:
216
+ only_self_vec = _get_switch_vec(num_self_layers, style_aligned_args.only_self_level)
217
+ for i, name in enumerate(unet.attn_processors.keys()):
218
+ is_self_attention = 'attn1' in name
219
+ if is_self_attention:
220
+ number_of_self += 1
221
+ if style_aligned_args is None or only_self_vec[i // 2]:
222
+ attn_procs[name] = DefaultAttentionProcessor()
223
+ else:
224
+ attn_procs[name] = SharedAttentionProcessor(style_aligned_args)
225
+ else:
226
+ number_of_cross += 1
227
+ attn_procs[name] = DefaultAttentionProcessor()
228
+
229
+ unet.set_attn_processor(attn_procs)
230
+
231
+
232
+ def register_shared_norm(pipeline: StableDiffusionXLPipeline,
233
+ share_group_norm: bool = True,
234
+ share_layer_norm: bool = True, ):
235
+ def register_norm_forward(norm_layer: nn.GroupNorm | nn.LayerNorm) -> nn.GroupNorm | nn.LayerNorm:
236
+ if not hasattr(norm_layer, 'orig_forward'):
237
+ setattr(norm_layer, 'orig_forward', norm_layer.forward)
238
+ orig_forward = norm_layer.orig_forward
239
+
240
+ def forward_(hidden_states: T) -> T:
241
+ n = hidden_states.shape[-2]
242
+ hidden_states = concat_first(hidden_states, dim=-2)
243
+ hidden_states = orig_forward(hidden_states)
244
+ return hidden_states[..., :n, :]
245
+
246
+ norm_layer.forward = forward_
247
+ return norm_layer
248
+
249
+ def get_norm_layers(pipeline_, norm_layers_: dict[str, list[nn.GroupNorm | nn.LayerNorm]]):
250
+ if isinstance(pipeline_, nn.LayerNorm) and share_layer_norm:
251
+ norm_layers_['layer'].append(pipeline_)
252
+ if isinstance(pipeline_, nn.GroupNorm) and share_group_norm:
253
+ norm_layers_['group'].append(pipeline_)
254
+ else:
255
+ for layer in pipeline_.children():
256
+ get_norm_layers(layer, norm_layers_)
257
+
258
+ norm_layers = {'group': [], 'layer': []}
259
+ get_norm_layers(pipeline.unet, norm_layers)
260
+ return [register_norm_forward(layer) for layer in norm_layers['group']] + [register_norm_forward(layer) for layer in
261
+ norm_layers['layer']]
262
+
263
+
264
+ class Handler:
265
+
266
+ def register(self, style_aligned_args: StyleAlignedArgs, ):
267
+ self.norm_layers = register_shared_norm(self.pipeline, style_aligned_args.share_group_norm,
268
+ style_aligned_args.share_layer_norm)
269
+ init_attention_processors(self.pipeline, style_aligned_args)
270
+
271
+ def remove(self):
272
+ for layer in self.norm_layers:
273
+ layer.forward = layer.orig_forward
274
+ self.norm_layers = []
275
+ init_attention_processors(self.pipeline, None)
276
+
277
+ def __init__(self, pipeline: StableDiffusionXLPipeline):
278
+ self.pipeline = pipeline
279
+ self.norm_layers = []
style_aligned_sd1.ipynb ADDED
@@ -0,0 +1,162 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "markdown",
5
+ "id": "a885cf5d-c525-4f5b-a8e4-f67d2f699909",
6
+ "metadata": {
7
+ "pycharm": {
8
+ "name": "#%% md\n"
9
+ }
10
+ },
11
+ "source": [
12
+ "## Copyright 2023 Google LLC"
13
+ ]
14
+ },
15
+ {
16
+ "cell_type": "code",
17
+ "execution_count": null,
18
+ "id": "d891d022-8979-40d4-848f-ecb84c17f12c",
19
+ "metadata": {
20
+ "jp-MarkdownHeadingCollapsed": true,
21
+ "pycharm": {
22
+ "name": "#%%\n"
23
+ }
24
+ },
25
+ "outputs": [],
26
+ "source": [
27
+ "# Copyright 2023 Google LLC\n",
28
+ "#\n",
29
+ "# Licensed under the Apache License, Version 2.0 (the \"License\");\n",
30
+ "# you may not use this file except in compliance with the License.\n",
31
+ "# You may obtain a copy of the License at\n",
32
+ "#\n",
33
+ "# http://www.apache.org/licenses/LICENSE-2.0\n",
34
+ "#\n",
35
+ "# Unless required by applicable law or agreed to in writing, software\n",
36
+ "# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
37
+ "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
38
+ "# See the License for the specific language governing permissions and\n",
39
+ "# limitations under the License."
40
+ ]
41
+ },
42
+ {
43
+ "cell_type": "markdown",
44
+ "id": "540d8642-c203-471c-a66d-0d43aabb0706",
45
+ "metadata": {
46
+ "pycharm": {
47
+ "name": "#%% md\n"
48
+ }
49
+ },
50
+ "source": [
51
+ "# StyleAligned over SD1.4"
52
+ ]
53
+ },
54
+ {
55
+ "cell_type": "code",
56
+ "execution_count": null,
57
+ "id": "23d54ea7-f7ab-4548-9b10-ece87216dc18",
58
+ "metadata": {
59
+ "pycharm": {
60
+ "name": "#%%\n"
61
+ }
62
+ },
63
+ "outputs": [],
64
+ "source": [
65
+ "from diffusers import DDIMScheduler,StableDiffusionPipeline\n",
66
+ "import torch\n",
67
+ "import mediapy\n",
68
+ "import sa_handler\n",
69
+ "import math"
70
+ ]
71
+ },
72
+ {
73
+ "cell_type": "code",
74
+ "execution_count": null,
75
+ "id": "522b14e7-9768-4eaa-8433-bf88acb244c4",
76
+ "metadata": {
77
+ "pycharm": {
78
+ "name": "#%%\n"
79
+ }
80
+ },
81
+ "outputs": [],
82
+ "source": [
83
+ "scheduler = DDIMScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule=\"scaled_linear\", clip_sample=False,\n",
84
+ " set_alpha_to_one=False)\n",
85
+ "pipeline = StableDiffusionPipeline.from_pretrained(\n",
86
+ " \"CompVis/stable-diffusion-v1-4\",\n",
87
+ " revision=\"fp16\",\n",
88
+ " scheduler=scheduler\n",
89
+ ")\n",
90
+ "pipeline = pipeline.to(\"cuda\")\n",
91
+ "\n",
92
+ "handler = sa_handler.Handler(pipeline)\n",
93
+ "sa_args = sa_handler.StyleAlignedArgs(share_group_norm=True,\n",
94
+ " share_layer_norm=True,\n",
95
+ " share_attention=True,\n",
96
+ " adain_queries=True,\n",
97
+ " adain_keys=True,\n",
98
+ " adain_values=False,\n",
99
+ " )\n",
100
+ "\n",
101
+ "handler.register(sa_args, )"
102
+ ]
103
+ },
104
+ {
105
+ "cell_type": "code",
106
+ "execution_count": null,
107
+ "id": "5db98c81-8b72-4fc7-8cd0-65eda17198e3",
108
+ "metadata": {
109
+ "pycharm": {
110
+ "name": "#%%\n"
111
+ }
112
+ },
113
+ "outputs": [],
114
+ "source": [
115
+ "# run StyleAligned\n",
116
+ "\n",
117
+ "sets_of_prompts = [\n",
118
+ " \"a toy train. macro photo. 3d game asset\",\n",
119
+ " \"a toy airplane. macro photo. 3d game asset\",\n",
120
+ " \"a toy bicycle. macro photo. 3d game asset\",\n",
121
+ " \"a toy car. macro photo. 3d game asset\",\n",
122
+ " \"a toy boat. macro photo. 3d game asset\",\n",
123
+ "]\n",
124
+ "images = pipeline(sets_of_prompts, generator=None).images\n",
125
+ "mediapy.show_images(images)"
126
+ ]
127
+ },
128
+ {
129
+ "cell_type": "code",
130
+ "execution_count": null,
131
+ "id": "afbe3876-22d9-4735-89b9-d5b5c46aea5c",
132
+ "metadata": {
133
+ "pycharm": {
134
+ "name": "#%%\n"
135
+ }
136
+ },
137
+ "outputs": [],
138
+ "source": []
139
+ }
140
+ ],
141
+ "metadata": {
142
+ "kernelspec": {
143
+ "display_name": "Python 3 (ipykernel)",
144
+ "language": "python",
145
+ "name": "python3"
146
+ },
147
+ "language_info": {
148
+ "codemirror_mode": {
149
+ "name": "ipython",
150
+ "version": 3
151
+ },
152
+ "file_extension": ".py",
153
+ "mimetype": "text/x-python",
154
+ "name": "python",
155
+ "nbconvert_exporter": "python",
156
+ "pygments_lexer": "ipython3",
157
+ "version": "3.11.5"
158
+ }
159
+ },
160
+ "nbformat": 4,
161
+ "nbformat_minor": 5
162
+ }
style_aligned_sdxl.ipynb ADDED
@@ -0,0 +1,142 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "markdown",
5
+ "id": "a885cf5d-c525-4f5b-a8e4-f67d2f699909",
6
+ "metadata": {},
7
+ "source": [
8
+ "## Copyright 2023 Google LLC"
9
+ ]
10
+ },
11
+ {
12
+ "cell_type": "code",
13
+ "execution_count": null,
14
+ "id": "d891d022-8979-40d4-848f-ecb84c17f12c",
15
+ "metadata": {
16
+ "jp-MarkdownHeadingCollapsed": true
17
+ },
18
+ "outputs": [],
19
+ "source": [
20
+ "# Copyright 2023 Google LLC\n",
21
+ "#\n",
22
+ "# Licensed under the Apache License, Version 2.0 (the \"License\");\n",
23
+ "# you may not use this file except in compliance with the License.\n",
24
+ "# You may obtain a copy of the License at\n",
25
+ "#\n",
26
+ "# http://www.apache.org/licenses/LICENSE-2.0\n",
27
+ "#\n",
28
+ "# Unless required by applicable law or agreed to in writing, software\n",
29
+ "# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
30
+ "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
31
+ "# See the License for the specific language governing permissions and\n",
32
+ "# limitations under the License."
33
+ ]
34
+ },
35
+ {
36
+ "cell_type": "markdown",
37
+ "id": "540d8642-c203-471c-a66d-0d43aabb0706",
38
+ "metadata": {},
39
+ "source": [
40
+ "# StyleAligned over SDXL"
41
+ ]
42
+ },
43
+ {
44
+ "cell_type": "code",
45
+ "execution_count": null,
46
+ "id": "23d54ea7-f7ab-4548-9b10-ece87216dc18",
47
+ "metadata": {},
48
+ "outputs": [],
49
+ "source": [
50
+ "from diffusers import StableDiffusionXLPipeline, DDIMScheduler\n",
51
+ "import torch\n",
52
+ "import mediapy\n",
53
+ "import sa_handler"
54
+ ]
55
+ },
56
+ {
57
+ "cell_type": "code",
58
+ "execution_count": null,
59
+ "id": "c2f6f1e6-445f-47bc-b9db-0301caeb7490",
60
+ "metadata": {
61
+ "pycharm": {
62
+ "name": "#%%\n"
63
+ }
64
+ },
65
+ "outputs": [],
66
+ "source": [
67
+ "# init models\n",
68
+ "\n",
69
+ "scheduler = DDIMScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule=\"scaled_linear\", clip_sample=False,\n",
70
+ " set_alpha_to_one=False)\n",
71
+ "pipeline = StableDiffusionXLPipeline.from_pretrained(\n",
72
+ " \"stabilityai/stable-diffusion-xl-base-1.0\", torch_dtype=torch.float16, variant=\"fp16\", use_safetensors=True,\n",
73
+ " scheduler=scheduler\n",
74
+ ").to(\"cuda\")\n",
75
+ "\n",
76
+ "handler = sa_handler.Handler(pipeline)\n",
77
+ "sa_args = sa_handler.StyleAlignedArgs(share_group_norm=False,\n",
78
+ " share_layer_norm=False,\n",
79
+ " share_attention=True,\n",
80
+ " adain_queries=True,\n",
81
+ " adain_keys=True,\n",
82
+ " adain_values=False,\n",
83
+ " )\n",
84
+ "\n",
85
+ "handler.register(sa_args, )"
86
+ ]
87
+ },
88
+ {
89
+ "cell_type": "code",
90
+ "execution_count": null,
91
+ "id": "5cca9256-0ce0-45c3-9cba-68c7eff1452f",
92
+ "metadata": {
93
+ "pycharm": {
94
+ "name": "#%%\n"
95
+ }
96
+ },
97
+ "outputs": [],
98
+ "source": [
99
+ "# run StyleAligned\n",
100
+ "\n",
101
+ "sets_of_prompts = [\n",
102
+ " \"a toy train. macro photo. 3d game asset\",\n",
103
+ " \"a toy airplane. macro photo. 3d game asset\",\n",
104
+ " \"a toy bicycle. macro photo. 3d game asset\",\n",
105
+ " \"a toy car. macro photo. 3d game asset\",\n",
106
+ " \"a toy boat. macro photo. 3d game asset\",\n",
107
+ "]\n",
108
+ "images = pipeline(sets_of_prompts,).images\n",
109
+ "mediapy.show_images(images)"
110
+ ]
111
+ },
112
+ {
113
+ "cell_type": "code",
114
+ "execution_count": null,
115
+ "id": "d819ad6d-0c19-411f-ba97-199909f64805",
116
+ "metadata": {},
117
+ "outputs": [],
118
+ "source": []
119
+ }
120
+ ],
121
+ "metadata": {
122
+ "kernelspec": {
123
+ "display_name": "Python 3 (ipykernel)",
124
+ "language": "python",
125
+ "name": "python3"
126
+ },
127
+ "language_info": {
128
+ "codemirror_mode": {
129
+ "name": "ipython",
130
+ "version": 3
131
+ },
132
+ "file_extension": ".py",
133
+ "mimetype": "text/x-python",
134
+ "name": "python",
135
+ "nbconvert_exporter": "python",
136
+ "pygments_lexer": "ipython3",
137
+ "version": "3.11.5"
138
+ }
139
+ },
140
+ "nbformat": 4,
141
+ "nbformat_minor": 5
142
+ }
style_aligned_transfer_sdxl.ipynb ADDED
@@ -0,0 +1,186 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "markdown",
5
+ "id": "a885cf5d-c525-4f5b-a8e4-f67d2f699909",
6
+ "metadata": {},
7
+ "source": [
8
+ "## Copyright 2023 Google LLC"
9
+ ]
10
+ },
11
+ {
12
+ "cell_type": "code",
13
+ "execution_count": null,
14
+ "id": "d891d022-8979-40d4-848f-ecb84c17f12c",
15
+ "metadata": {
16
+ "jp-MarkdownHeadingCollapsed": true
17
+ },
18
+ "outputs": [],
19
+ "source": [
20
+ "# Copyright 2023 Google LLC\n",
21
+ "#\n",
22
+ "# Licensed under the Apache License, Version 2.0 (the \"License\");\n",
23
+ "# you may not use this file except in compliance with the License.\n",
24
+ "# You may obtain a copy of the License at\n",
25
+ "#\n",
26
+ "# http://www.apache.org/licenses/LICENSE-2.0\n",
27
+ "#\n",
28
+ "# Unless required by applicable law or agreed to in writing, software\n",
29
+ "# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
30
+ "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
31
+ "# See the License for the specific language governing permissions and\n",
32
+ "# limitations under the License."
33
+ ]
34
+ },
35
+ {
36
+ "cell_type": "markdown",
37
+ "id": "540d8642-c203-471c-a66d-0d43aabb0706",
38
+ "metadata": {},
39
+ "source": [
40
+ "# StyleAligned over SDXL from input image"
41
+ ]
42
+ },
43
+ {
44
+ "cell_type": "markdown",
45
+ "id": "483d0cf9",
46
+ "metadata": {},
47
+ "source": [
48
+ "#### Model Load "
49
+ ]
50
+ },
51
+ {
52
+ "cell_type": "code",
53
+ "execution_count": null,
54
+ "id": "23d54ea7-f7ab-4548-9b10-ece87216dc18",
55
+ "metadata": {},
56
+ "outputs": [],
57
+ "source": [
58
+ "from diffusers import StableDiffusionXLPipeline, DDIMScheduler\n",
59
+ "import torch\n",
60
+ "import mediapy\n",
61
+ "import sa_handler\n",
62
+ "import math\n",
63
+ "\n",
64
+ "\n",
65
+ "scheduler = DDIMScheduler(\n",
66
+ " beta_start=0.00085, beta_end=0.012, beta_schedule=\"scaled_linear\",\n",
67
+ " clip_sample=False, set_alpha_to_one=False)\n",
68
+ "\n",
69
+ "pipeline = StableDiffusionXLPipeline.from_pretrained(\n",
70
+ " \"stabilityai/stable-diffusion-xl-base-1.0\", torch_dtype=torch.float16, variant=\"fp16\",\n",
71
+ " use_safetensors=True,\n",
72
+ " scheduler=scheduler\n",
73
+ ").to(\"cuda\")"
74
+ ]
75
+ },
76
+ {
77
+ "cell_type": "markdown",
78
+ "id": "c09b1a68",
79
+ "metadata": {
80
+ "pycharm": {
81
+ "name": "#%% md\n"
82
+ }
83
+ },
84
+ "source": [
85
+ "#### Ref image load and inversion"
86
+ ]
87
+ },
88
+ {
89
+ "cell_type": "code",
90
+ "execution_count": null,
91
+ "id": "f4717854",
92
+ "metadata": {
93
+ "pycharm": {
94
+ "name": "#%%\n"
95
+ }
96
+ },
97
+ "outputs": [],
98
+ "source": [
99
+ "# DDIM inversion\n",
100
+ "\n",
101
+ "from diffusers.utils import load_image\n",
102
+ "import inversion\n",
103
+ "import numpy as np\n",
104
+ "\n",
105
+ "src_style = \"medieval painting\"\n",
106
+ "src_prompt = f'Man laying in a bed, {src_style}.'\n",
107
+ "image_path = './example_image/medieval-bed.jpeg'\n",
108
+ "\n",
109
+ "num_inference_steps = 50\n",
110
+ "x0 = np.array(load_image(image_path).resize((1024, 1024)))\n",
111
+ "zts = inversion.ddim_inversion(pipeline, x0, src_prompt, num_inference_steps, 2)\n",
112
+ "mediapy.show_image(x0, title=\"innput reference image\", height=256)"
113
+ ]
114
+ },
115
+ {
116
+ "cell_type": "code",
117
+ "execution_count": null,
118
+ "id": "1751c4fe",
119
+ "metadata": {},
120
+ "outputs": [],
121
+ "source": [
122
+ "prompts = [\n",
123
+ " src_prompt,\n",
124
+ " \"A man working on a laptop\",\n",
125
+ " \"A man eats pizza\",\n",
126
+ " \"A woman playig on saxophone\",\n",
127
+ "]\n",
128
+ "\n",
129
+ "# some parameters you can adjust to control fidelity to reference\n",
130
+ "shared_score_shift = np.log(2) # higher value induces higher fidelity, set 0 for no shift\n",
131
+ "shared_score_scale = 1.0 # higher value induces higher, set 1 for no rescale\n",
132
+ "\n",
133
+ "# for very famouse images consider supressing attention to refference, here is a configuration example:\n",
134
+ "# shared_score_shift = np.log(1)\n",
135
+ "# shared_score_scale = 0.5\n",
136
+ "\n",
137
+ "for i in range(1, len(prompts)):\n",
138
+ " prompts[i] = f'{prompts[i]}, {src_style}.'\n",
139
+ "\n",
140
+ "handler = sa_handler.Handler(pipeline)\n",
141
+ "sa_args = sa_handler.StyleAlignedArgs(\n",
142
+ " share_group_norm=True, share_layer_norm=True, share_attention=True,\n",
143
+ " adain_queries=True, adain_keys=True, adain_values=False,\n",
144
+ " shared_score_shift=shared_score_shift, shared_score_scale=shared_score_scale,)\n",
145
+ "handler.register(sa_args)\n",
146
+ "\n",
147
+ "zT, inversion_callback = inversion.make_inversion_callback(zts, offset=5)\n",
148
+ "\n",
149
+ "g_cpu = torch.Generator(device='cpu')\n",
150
+ "g_cpu.manual_seed(10)\n",
151
+ "\n",
152
+ "latents = torch.randn(len(prompts), 4, 128, 128, device='cpu', generator=g_cpu,\n",
153
+ " dtype=pipeline.unet.dtype,).to('cuda:0')\n",
154
+ "latents[0] = zT\n",
155
+ "\n",
156
+ "images_a = pipeline(prompts, latents=latents,\n",
157
+ " callback_on_step_end=inversion_callback,\n",
158
+ " num_inference_steps=num_inference_steps, guidance_scale=10.0).images\n",
159
+ "\n",
160
+ "handler.remove()\n",
161
+ "mediapy.show_images(images_a, titles=[p[:-(len(src_style) + 3)] for p in prompts])"
162
+ ]
163
+ }
164
+ ],
165
+ "metadata": {
166
+ "kernelspec": {
167
+ "display_name": "Python 3",
168
+ "language": "python",
169
+ "name": "python3"
170
+ },
171
+ "language_info": {
172
+ "codemirror_mode": {
173
+ "name": "ipython",
174
+ "version": 3
175
+ },
176
+ "file_extension": ".py",
177
+ "mimetype": "text/x-python",
178
+ "name": "python",
179
+ "nbconvert_exporter": "python",
180
+ "pygments_lexer": "ipython3",
181
+ "version": "3.10.13"
182
+ }
183
+ },
184
+ "nbformat": 4,
185
+ "nbformat_minor": 5
186
+ }
style_aligned_w_controlnet.ipynb ADDED
@@ -0,0 +1,200 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "markdown",
5
+ "id": "f86ede39-8d9f-4da9-bc12-955f2fddd484",
6
+ "metadata": {
7
+ "pycharm": {
8
+ "name": "#%% md\n"
9
+ }
10
+ },
11
+ "source": [
12
+ "## Copyright 2023 Google LLC"
13
+ ]
14
+ },
15
+ {
16
+ "cell_type": "code",
17
+ "execution_count": null,
18
+ "id": "3f3cbf47-a52b-48b1-9bd3-3435f92f2174",
19
+ "metadata": {
20
+ "pycharm": {
21
+ "name": "#%%\n"
22
+ }
23
+ },
24
+ "outputs": [],
25
+ "source": [
26
+ "# Copyright 2023 Google LLC\n",
27
+ "#\n",
28
+ "# Licensed under the Apache License, Version 2.0 (the \"License\");\n",
29
+ "# you may not use this file except in compliance with the License.\n",
30
+ "# You may obtain a copy of the License at\n",
31
+ "#\n",
32
+ "# http://www.apache.org/licenses/LICENSE-2.0\n",
33
+ "#\n",
34
+ "# Unless required by applicable law or agreed to in writing, software\n",
35
+ "# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
36
+ "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
37
+ "# See the License for the specific language governing permissions and\n",
38
+ "# limitations under the License."
39
+ ]
40
+ },
41
+ {
42
+ "cell_type": "markdown",
43
+ "id": "22de629b-581f-4335-9e7b-f73221d8dbcb",
44
+ "metadata": {
45
+ "pycharm": {
46
+ "name": "#%% md\n"
47
+ }
48
+ },
49
+ "source": [
50
+ "# ControlNet depth with StyleAligned over SDXL"
51
+ ]
52
+ },
53
+ {
54
+ "cell_type": "code",
55
+ "execution_count": null,
56
+ "id": "486b7ebb-c483-4bf0-ace8-f8092c2d1f23",
57
+ "metadata": {
58
+ "pycharm": {
59
+ "name": "#%%\n"
60
+ }
61
+ },
62
+ "outputs": [],
63
+ "source": [
64
+ "from diffusers import ControlNetModel, StableDiffusionXLControlNetPipeline, AutoencoderKL\n",
65
+ "from diffusers.utils import load_image\n",
66
+ "from transformers import DPTImageProcessor, DPTForDepthEstimation\n",
67
+ "import torch\n",
68
+ "import mediapy\n",
69
+ "import sa_handler\n",
70
+ "import pipeline_calls"
71
+ ]
72
+ },
73
+ {
74
+ "cell_type": "code",
75
+ "execution_count": null,
76
+ "id": "2a7e85e7-b5cf-45b2-946a-5ba1e4923586",
77
+ "metadata": {
78
+ "pycharm": {
79
+ "name": "#%%\n"
80
+ }
81
+ },
82
+ "outputs": [],
83
+ "source": [
84
+ "# init models\n",
85
+ "\n",
86
+ "depth_estimator = DPTForDepthEstimation.from_pretrained(\"Intel/dpt-hybrid-midas\").to(\"cuda\")\n",
87
+ "feature_processor = DPTImageProcessor.from_pretrained(\"Intel/dpt-hybrid-midas\")\n",
88
+ "\n",
89
+ "controlnet = ControlNetModel.from_pretrained(\n",
90
+ " \"diffusers/controlnet-depth-sdxl-1.0\",\n",
91
+ " variant=\"fp16\",\n",
92
+ " use_safetensors=True,\n",
93
+ " torch_dtype=torch.float16,\n",
94
+ ").to(\"cuda\")\n",
95
+ "vae = AutoencoderKL.from_pretrained(\"madebyollin/sdxl-vae-fp16-fix\", torch_dtype=torch.float16).to(\"cuda\")\n",
96
+ "pipeline = StableDiffusionXLControlNetPipeline.from_pretrained(\n",
97
+ " \"stabilityai/stable-diffusion-xl-base-1.0\",\n",
98
+ " controlnet=controlnet,\n",
99
+ " vae=vae,\n",
100
+ " variant=\"fp16\",\n",
101
+ " use_safetensors=True,\n",
102
+ " torch_dtype=torch.float16,\n",
103
+ ").to(\"cuda\")\n",
104
+ "pipeline.enable_model_cpu_offload()\n",
105
+ "\n",
106
+ "sa_args = sa_handler.StyleAlignedArgs(share_group_norm=False,\n",
107
+ " share_layer_norm=False,\n",
108
+ " share_attention=True,\n",
109
+ " adain_queries=True,\n",
110
+ " adain_keys=True,\n",
111
+ " adain_values=False,\n",
112
+ " )\n",
113
+ "handler = sa_handler.Handler(pipeline)\n",
114
+ "handler.register(sa_args, )"
115
+ ]
116
+ },
117
+ {
118
+ "cell_type": "code",
119
+ "execution_count": null,
120
+ "id": "94ca26b4-9061-4012-9400-8d97ef212d87",
121
+ "metadata": {
122
+ "pycharm": {
123
+ "name": "#%%\n"
124
+ }
125
+ },
126
+ "outputs": [],
127
+ "source": [
128
+ "# get depth maps\n",
129
+ "\n",
130
+ "image = load_image(\"./example_image/train.png\")\n",
131
+ "depth_image1 = pipeline_calls.get_depth_map(image, feature_processor, depth_estimator)\n",
132
+ "depth_image2 = load_image(\"./example_image/sun.png\").resize((1024, 1024))\n",
133
+ "mediapy.show_images([depth_image1, depth_image2])"
134
+ ]
135
+ },
136
+ {
137
+ "cell_type": "code",
138
+ "execution_count": null,
139
+ "id": "c8f56fe4-559f-49ff-a2d8-460dcfeb56a0",
140
+ "metadata": {
141
+ "pycharm": {
142
+ "name": "#%%\n"
143
+ }
144
+ },
145
+ "outputs": [],
146
+ "source": [
147
+ "# run ControlNet depth with StyleAligned\n",
148
+ "\n",
149
+ "reference_prompt = \"a poster in flat design style\"\n",
150
+ "target_prompts = [\"a train in flat design style\", \"the sun in flat design style\"]\n",
151
+ "controlnet_conditioning_scale = 0.8\n",
152
+ "num_images_per_prompt = 3 # adjust according to VRAM size\n",
153
+ "latents = torch.randn(1 + num_images_per_prompt, 4, 128, 128).to(pipeline.unet.dtype)\n",
154
+ "for deph_map, target_prompt in zip((depth_image1, depth_image2), target_prompts):\n",
155
+ " latents[1:] = torch.randn(num_images_per_prompt, 4, 128, 128).to(pipeline.unet.dtype)\n",
156
+ " images = pipeline_calls.controlnet_call(pipeline, [reference_prompt, target_prompt],\n",
157
+ " image=deph_map,\n",
158
+ " num_inference_steps=50,\n",
159
+ " controlnet_conditioning_scale=controlnet_conditioning_scale,\n",
160
+ " num_images_per_prompt=num_images_per_prompt,\n",
161
+ " latents=latents)\n",
162
+ " \n",
163
+ " mediapy.show_images([images[0], deph_map] + images[1:], titles=[\"reference\", \"depth\"] + [f'result {i}' for i in range(1, len(images))])\n"
164
+ ]
165
+ },
166
+ {
167
+ "cell_type": "code",
168
+ "execution_count": null,
169
+ "id": "437ba4bd-6243-486b-8ba5-3b7cd661d53a",
170
+ "metadata": {
171
+ "pycharm": {
172
+ "name": "#%%\n"
173
+ }
174
+ },
175
+ "outputs": [],
176
+ "source": []
177
+ }
178
+ ],
179
+ "metadata": {
180
+ "kernelspec": {
181
+ "display_name": "Python 3 (ipykernel)",
182
+ "language": "python",
183
+ "name": "python3"
184
+ },
185
+ "language_info": {
186
+ "codemirror_mode": {
187
+ "name": "ipython",
188
+ "version": 3
189
+ },
190
+ "file_extension": ".py",
191
+ "mimetype": "text/x-python",
192
+ "name": "python",
193
+ "nbconvert_exporter": "python",
194
+ "pygments_lexer": "ipython3",
195
+ "version": "3.11.5"
196
+ }
197
+ },
198
+ "nbformat": 4,
199
+ "nbformat_minor": 5
200
+ }
style_aligned_w_multidiffusion.ipynb ADDED
@@ -0,0 +1,156 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "markdown",
5
+ "id": "50fa980f-1bae-40c1-a1f3-f5f89bef60d3",
6
+ "metadata": {
7
+ "pycharm": {
8
+ "name": "#%% md\n"
9
+ }
10
+ },
11
+ "source": [
12
+ "## Copyright 2023 Google LLC"
13
+ ]
14
+ },
15
+ {
16
+ "cell_type": "code",
17
+ "execution_count": null,
18
+ "id": "5da5f038-057f-4475-a783-95660f98238c",
19
+ "metadata": {
20
+ "pycharm": {
21
+ "name": "#%%\n"
22
+ }
23
+ },
24
+ "outputs": [],
25
+ "source": [
26
+ "# Copyright 2023 Google LLC\n",
27
+ "#\n",
28
+ "# Licensed under the Apache License, Version 2.0 (the \"License\");\n",
29
+ "# you may not use this file except in compliance with the License.\n",
30
+ "# You may obtain a copy of the License at\n",
31
+ "#\n",
32
+ "# http://www.apache.org/licenses/LICENSE-2.0\n",
33
+ "#\n",
34
+ "# Unless required by applicable law or agreed to in writing, software\n",
35
+ "# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
36
+ "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
37
+ "# See the License for the specific language governing permissions and\n",
38
+ "# limitations under the License."
39
+ ]
40
+ },
41
+ {
42
+ "cell_type": "markdown",
43
+ "id": "c3a7c069-c441-4204-a905-59cbd9edc13a",
44
+ "metadata": {
45
+ "pycharm": {
46
+ "name": "#%% md\n"
47
+ }
48
+ },
49
+ "source": [
50
+ "# MultiDiffusion with StyleAligned over SD v2"
51
+ ]
52
+ },
53
+ {
54
+ "cell_type": "code",
55
+ "execution_count": null,
56
+ "id": "14178de7-d4c8-4881-ac1d-ff84bae57c6f",
57
+ "metadata": {
58
+ "pycharm": {
59
+ "name": "#%%\n"
60
+ }
61
+ },
62
+ "outputs": [],
63
+ "source": [
64
+ "import torch\n",
65
+ "from diffusers import StableDiffusionPanoramaPipeline, DDIMScheduler\n",
66
+ "import mediapy\n",
67
+ "import sa_handler\n",
68
+ "import pipeline_calls"
69
+ ]
70
+ },
71
+ {
72
+ "cell_type": "code",
73
+ "execution_count": null,
74
+ "id": "738cee0e-4d6e-4875-b4df-eadff6e27e7f",
75
+ "metadata": {
76
+ "pycharm": {
77
+ "name": "#%%\n"
78
+ }
79
+ },
80
+ "outputs": [],
81
+ "source": [
82
+ "# init models\n",
83
+ "model_ckpt = \"stabilityai/stable-diffusion-2-base\"\n",
84
+ "scheduler = DDIMScheduler.from_pretrained(model_ckpt, subfolder=\"scheduler\")\n",
85
+ "pipeline = StableDiffusionPanoramaPipeline.from_pretrained(\n",
86
+ " model_ckpt, scheduler=scheduler, torch_dtype=torch.float16\n",
87
+ ").to(\"cuda\")\n",
88
+ "\n",
89
+ "sa_args = sa_handler.StyleAlignedArgs(share_group_norm=True,\n",
90
+ " share_layer_norm=True,\n",
91
+ " share_attention=True,\n",
92
+ " adain_queries=True,\n",
93
+ " adain_keys=True,\n",
94
+ " adain_values=False,\n",
95
+ " )\n",
96
+ "handler = sa_handler.Handler(pipeline)\n",
97
+ "handler.register(sa_args)"
98
+ ]
99
+ },
100
+ {
101
+ "cell_type": "code",
102
+ "execution_count": null,
103
+ "id": "ea61e789-2814-4820-8ae7-234c3c6640a0",
104
+ "metadata": {
105
+ "pycharm": {
106
+ "name": "#%%\n"
107
+ }
108
+ },
109
+ "outputs": [],
110
+ "source": [
111
+ "# run MultiDiffusion with StyleAligned\n",
112
+ "\n",
113
+ "reference_prompt = \"a beautiful papercut art design\"\n",
114
+ "target_prompts = [\"mountains in a beautiful papercut art design\", \"giraffes in a beautiful papercut art design\"]\n",
115
+ "view_batch_size = 25 # adjust according to VRAM size\n",
116
+ "reference_latent = torch.randn(1, 4, 64, 64,)\n",
117
+ "for target_prompt in target_prompts:\n",
118
+ " images = pipeline_calls.panorama_call(pipeline, [reference_prompt, target_prompt], reference_latent=reference_latent, view_batch_size=view_batch_size)\n",
119
+ " mediapy.show_images(images, titles=[\"reference\", \"result\"])"
120
+ ]
121
+ },
122
+ {
123
+ "cell_type": "code",
124
+ "execution_count": null,
125
+ "id": "791a9b28-f0ce-4fd0-9f3c-594281c2ae56",
126
+ "metadata": {
127
+ "pycharm": {
128
+ "name": "#%%\n"
129
+ }
130
+ },
131
+ "outputs": [],
132
+ "source": []
133
+ }
134
+ ],
135
+ "metadata": {
136
+ "kernelspec": {
137
+ "display_name": "Python 3 (ipykernel)",
138
+ "language": "python",
139
+ "name": "python3"
140
+ },
141
+ "language_info": {
142
+ "codemirror_mode": {
143
+ "name": "ipython",
144
+ "version": 3
145
+ },
146
+ "file_extension": ".py",
147
+ "mimetype": "text/x-python",
148
+ "name": "python",
149
+ "nbconvert_exporter": "python",
150
+ "pygments_lexer": "ipython3",
151
+ "version": "3.11.5"
152
+ }
153
+ },
154
+ "nbformat": 4,
155
+ "nbformat_minor": 5
156
+ }