Leimingkun committed on
Commit
3b2b77a
·
1 Parent(s): eae0f7a

stylestudio

Browse files
app.py ADDED
@@ -0,0 +1,238 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import sys
2
+ sys.path.append("./")
3
+ import gradio as gr
4
+ import torch
5
+ from ip_adapter.utils import BLOCKS as BLOCKS
6
+ import numpy as np
7
+ import random
8
+ from diffusers import (
9
+ AutoencoderKL,
10
+ StableDiffusionXLPipeline,
11
+ )
12
+ from ip_adapter import StyleStudio_Adapter
13
+
14
# Pick the compute device; half precision is only selected when running on CUDA.
device = "cuda" if torch.cuda.is_available() else "cpu"
dtype = torch.float16 if "cuda" in str(device) else torch.float32

import os
import shutil
import subprocess

# Fetch the IP-Adapter repository and keep only its `sdxl_models` folder.
# The work is skipped when `sdxl_models` is already present so restarts are
# idempotent, and every external command is checked so a failed download stops
# the app here instead of surfacing later as a confusing model-loading error.
if not os.path.isdir("sdxl_models"):
    subprocess.run(["git", "lfs", "install"], check=True)
    if not os.path.isdir("IP-Adapter"):
        subprocess.run(["git", "clone", "https://huggingface.co/h94/IP-Adapter"], check=True)
    shutil.move("IP-Adapter/sdxl_models", "sdxl_models")

from huggingface_hub import hf_hub_download

# hf_hub_download(repo_id="h94/IP-Adapter", filename="sdxl_models/image_encoder", local_dir="./sdxl_models/image_encoder")
hf_hub_download(repo_id="InstantX/CSGO", filename="csgo_4_32.bin", local_dir="./CSGO/")
# Remove the part of the clone that this app does not read.
shutil.rmtree("IP-Adapter/models", ignore_errors=True)

base_model_path = "stabilityai/stable-diffusion-xl-base-1.0"
image_encoder_path = "sdxl_models/image_encoder"
csgo_ckpt = './CSGO/csgo_4_32.bin'
pretrained_vae_name_or_path = 'madebyollin/sdxl-vae-fp16-fix'
weight_dtype = torch.float16

vae = AutoencoderKL.from_pretrained(pretrained_vae_name_or_path, torch_dtype=torch.float16)
pipe = StableDiffusionXLPipeline.from_pretrained(
    base_model_path,
    torch_dtype=torch.float16,
    add_watermarker=False,
    vae=vae
)
pipe.enable_vae_tiling()

target_style_blocks = BLOCKS['style']

csgo = StyleStudio_Adapter(
    pipe, image_encoder_path, csgo_ckpt, device, num_style_tokens=32,
    target_style_blocks=target_style_blocks,
    controlnet_adapter=False,
    style_model_resampler=True,

    fuSAttn=True,
    end_fusion=20,
    adainIP=True,
)

# Upper bound used when drawing a random seed.
MAX_SEED = np.iinfo(np.int32).max
55
+
56
+
57
def get_example():
    """Return the example rows offered by the demo.

    Each row is a list holding, in order: style image path, task name,
    prompt, guidance scale, seed and end-fusion step.
    """
    apple_example = [
        './assets/style1.jpg',
        "Text-Driven Style Synthesis",
        "A red apple",
        7.0,
        42,
        20,
    ]
    return [apple_example]
69
+
70
def run_for_examples(style_image_pil, target, prompt, guidance_scale, seed, end_fusion):
    """Generate the image for one example row.

    The example's own `guidance_scale` and `seed` are forwarded to
    `create_image` (they were previously ignored in favour of hard-coded
    7.0 / 42). The number of inference steps is fixed at 50, and both the
    teacher model and cross-modal AdaIN are always enabled.

    Args:
        style_image_pil: Style reference image.
        target: Task name from the example row; accepted for signature
            compatibility with the example inputs but not used.
        prompt: Text prompt.
        guidance_scale: Guidance scale forwarded to `create_image`.
        seed: Random seed forwarded to `create_image`.
        end_fusion: End-fusion step forwarded to `create_image`.

    Returns:
        Whatever `create_image` returns (a one-element list of images).
    """
    return create_image(
        style_image_pil=style_image_pil,
        prompt=prompt,
        guidance_scale=guidance_scale,
        num_inference_steps=50,
        seed=seed,
        end_fusion=end_fusion,
        use_SAttn=True,
        crossModalAdaIN=True,
    )
82
+
83
def randomize_seed_fn(seed: int, randomize_seed: bool) -> int:
    """Return `seed` unchanged, or a fresh random seed when requested.

    When `randomize_seed` is truthy the result is drawn uniformly from
    0..MAX_SEED (inclusive); otherwise the caller's seed is passed through.
    """
    if not randomize_seed:
        return seed
    return random.randint(0, MAX_SEED)
87
+
88
def create_image(
    style_image_pil,
    prompt,
    guidance_scale,
    num_inference_steps,
    end_fusion,
    crossModalAdaIN,
    use_SAttn,
    seed,
    negative_prompt="text, watermark, lowres, low quality, worst quality, deformed, glitch, low contrast, noisy, saturation, blurry",
):
    """Run the StyleStudio adapter once and return the stylized image.

    Args:
        style_image_pil: Style reference image.
        prompt: Text prompt.
        guidance_scale: Guidance scale passed to the generator.
        num_inference_steps: Number of denoising steps (coerced to int).
        end_fusion: End-fusion step passed through to the adapter.
        crossModalAdaIN: Whether to enable cross-modal AdaIN.
        use_SAttn: Whether to enable the teacher model; when enabled two
            samples are generated and the second one is returned.
        seed: Random seed (coerced to int).
        negative_prompt: Negative text prompt.

    Returns:
        A one-element list containing the selected generated image.
    """
    # UI sliders may deliver floats; the generator seed and step count must be ints.
    seed = int(seed)
    num_inference_steps = int(num_inference_steps)

    generator = torch.Generator(device).manual_seed(seed)
    # The previous code drew initial latents here and never used them. The draw
    # is kept (result discarded) so the generator state, and therefore the image
    # produced for a given seed, does not change. It now runs on the configured
    # device instead of a hard-coded "cuda", which failed on CPU-only hosts.
    torch.randn((1, 4, 128, 128), generator=generator, device=device, dtype=dtype)

    num_sample = 2 if use_SAttn else 1

    with torch.no_grad():
        images = csgo.generate(pil_style_image=style_image_pil,
                               prompt=prompt,
                               negative_prompt=negative_prompt,
                               height=1024,
                               width=1024,
                               guidance_scale=guidance_scale,
                               num_images_per_prompt=1,
                               num_samples=num_sample,
                               num_inference_steps=num_inference_steps,
                               end_fusion=end_fusion,
                               cross_modal_adain=crossModalAdaIN,
                               use_SAttn=use_SAttn,

                               generator=generator,
                               )

    # With the teacher model enabled the second sample is the one shown.
    if use_SAttn:
        return [images[1]]
    return [images[0]]
129
+
130
# Static Markdown/HTML shown in the demo: page title, usage instructions and footer.
title = r"""
<h1 align="center">StyleStudio: Text-Driven Style Transfer with Selective Control of Style Elements</h1>
"""

# NOTE: the bold tag around "Enter your desired prompt" was previously left
# unclosed (`<b>...<b>`); it is now closed.
description = r"""
<b>Official 🤗 Gradio demo</b> for <a href='https://github.com/MingKunLei/StyleStudio' target='_blank'><b>StyleStudio: Text-Driven Style Transfer with Selective Control of Style Elements</b></a>.<br>
How to use:<br>
1. Upload a style image.
2. <b>Enter your desired prompt</b>.
3. Click the <b>Submit</b> button to begin customization.
4. Share your stylized photo with your friends and enjoy! 😊

Advanced usage:<br>
1. Click advanced options.
2. Choose different guidance and steps.
3. Set the timing for the Teacher Model's participation
"""

article = r"""
---
📝 **Tips**
As the value of end_fusion increases, the style gradually diminishes.
---
📝 **Citation**
<br>
If our work is helpful for your research or applications, please cite us via:
```bibtex

```
📧 **Contact**
<br>
If you have any questions, please feel free to open an issue or directly reach us out at <b>leimingkun@westlake.edu.cn</b>.
"""
164
+
165
# Build the Gradio UI: inputs on the left, generated image on the right.
block = gr.Blocks(css="footer {visibility: hidden}").queue(max_size=10, api_open=False)
with block:
    gr.Markdown(title)
    gr.Markdown(description)

    with gr.Tabs():
        with gr.Row():
            with gr.Column():
                with gr.Row():
                    with gr.Column():
                        style_image_pil = gr.Image(label="Style Image", type='pil')

                target = gr.Radio(["Text-Driven Style Synthesis"],
                                  value="Text-Driven Style Synthesis",
                                  label="task")

                prompt = gr.Textbox(label="Prompt",
                                    value="A red apple")

                neg_prompt = gr.Textbox(label="Negative Prompt",
                                        value="text, watermark, lowres, low quality, worst quality, deformed, glitch, low contrast, noisy, saturation, blurry")

                with gr.Accordion(open=True, label="Advanced Options"):

                    guidance_scale = gr.Slider(minimum=1, maximum=15.0, step=0.01, value=7.0, label="guidance scale")

                    num_inference_steps = gr.Slider(minimum=5, maximum=100.0, step=1.0, value=50,
                                                    label="num inference steps")

                    # `maximum` must be a number; it was previously given the
                    # `num_inference_steps` Slider component. 100 mirrors the
                    # upper bound of the inference-steps slider.
                    end_fusion = gr.Slider(minimum=0, maximum=100, step=1.0, value=20.0, label="end fusion")

                    # The upper bound matches MAX_SEED so a randomized seed
                    # always fits inside the slider's range.
                    seed = gr.Slider(minimum=-1000000, maximum=MAX_SEED, value=1, step=1, label="Seed Value")

                    randomize_seed = gr.Checkbox(label="Randomize seed", value=True)

                    crossModalAdaIN = gr.Checkbox(label="Cross Modal AdaIN", value=True)
                    use_SAttn = gr.Checkbox(label="Teacher Model", value=True)

                generate_button = gr.Button("Generate Image")

            with gr.Column():
                generated_image = gr.Gallery(label="Generated Image")

        # First settle the seed (optionally randomizing it), then generate.
        generate_button.click(
            fn=randomize_seed_fn,
            inputs=[seed, randomize_seed],
            outputs=seed,
            queue=False,
            api_name=False,
        ).then(
            fn=create_image,
            inputs=[
                style_image_pil,
                prompt,
                guidance_scale,
                num_inference_steps,
                end_fusion,
                crossModalAdaIN,
                use_SAttn,
                seed,
                neg_prompt,],
            outputs=[generated_image])

    gr.Examples(
        examples=get_example(),
        inputs=[style_image_pil, target, prompt, guidance_scale, seed, end_fusion],
        fn=run_for_examples,
        outputs=[generated_image],
        cache_examples=True,
    )

    gr.Markdown(article)

block.launch()
assets/style1.jpg ADDED
assets/style2.jpg ADDED
ip_adapter/__init__.py ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from .ip_adapter import IPAdapter, IPAdapterPlus, IPAdapterPlusXL, IPAdapterXL, IPAdapterFull,IPAdapterXL_CS,IPAdapter_CS
2
+ from .ip_adapter import CSGO
3
+ from .ip_adapter import StyleStudio_Adapter, StyleStudio_Adapter_exp
4
+ from .ip_adapter import IPAdapterXL_cross_modal
5
# Names exported by `from ip_adapter import *`. Every adapter class imported
# above is listed, including IPAdapterXL_CS and IPAdapter_CS, which were
# imported but previously missing from this list.
__all__ = [
    "IPAdapter",
    "IPAdapterPlus",
    "IPAdapterPlusXL",
    "IPAdapterXL",
    "CSGO",
    "StyleStudio_Adapter",
    "StyleStudio_Adapter_exp",
    "IPAdapterXL_cross_modal",
    "IPAdapterFull",
    "IPAdapterXL_CS",
    "IPAdapter_CS",
]
ip_adapter/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (555 Bytes). View file
 
ip_adapter/__pycache__/__init__.cpython-39.pyc ADDED
Binary file (545 Bytes). View file
 
ip_adapter/__pycache__/attention_processor.cpython-310.pyc ADDED
Binary file (27.2 kB). View file
 
ip_adapter/__pycache__/attention_processor.cpython-39.pyc ADDED
Binary file (27.7 kB). View file
 
ip_adapter/__pycache__/ip_adapter.cpython-310.pyc ADDED
Binary file (32.2 kB). View file
 
ip_adapter/__pycache__/ip_adapter.cpython-39.pyc ADDED
Binary file (33.3 kB). View file
 
ip_adapter/__pycache__/resampler.cpython-310.pyc ADDED
Binary file (4.25 kB). View file
 
ip_adapter/__pycache__/resampler.cpython-39.pyc ADDED
Binary file (4.21 kB). View file
 
ip_adapter/__pycache__/utils.cpython-310.pyc ADDED
Binary file (3.64 kB). View file
 
ip_adapter/__pycache__/utils.cpython-39.pyc ADDED
Binary file (3.61 kB). View file
 
ip_adapter/attention_processor.py ADDED
@@ -0,0 +1,1645 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # modified from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py
2
+ import math
3
+ import torch
4
+ import torch.nn as nn
5
+ import torch.nn.functional as F
6
+ import torch.fft as fft
7
+ import pdb
8
+
9
+
10
class AttnProcessor(nn.Module):
    r"""
    Plain attention processor with no image-prompt branch.

    `hidden_size` and `cross_attention_dim` are accepted by the constructor but
    not used; `save_in_unet` and `atten_control` are only stored on the instance.
    """

    def __init__(
        self,
        hidden_size=None,
        cross_attention_dim=None,
        save_in_unet='down',
        atten_control=None,
    ):
        super().__init__()
        self.save_in_unet = save_in_unet
        self.atten_control = atten_control

    def __call__(
        self,
        attn,
        hidden_states,
        encoder_hidden_states=None,
        attention_mask=None,
        temb=None,
    ):
        """Apply attention using the projections and helpers of `attn`.

        When `encoder_hidden_states` is None the layer attends to
        `hidden_states` itself. A 4-D input is flattened to a token sequence
        for the computation and reshaped back before returning.
        """
        skip_input = hidden_states

        if attn.spatial_norm is not None:
            hidden_states = attn.spatial_norm(hidden_states, temb)

        is_feature_map = hidden_states.ndim == 4
        if is_feature_map:
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

        # The mask is sized from whichever sequence supplies keys and values.
        if encoder_hidden_states is None:
            batch_size, sequence_length, _ = hidden_states.shape
        else:
            batch_size, sequence_length, _ = encoder_hidden_states.shape
        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)

        if attn.group_norm is not None:
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

        q = attn.to_q(hidden_states)

        if encoder_hidden_states is None:
            context = hidden_states
        elif attn.norm_cross:
            context = attn.norm_encoder_hidden_states(encoder_hidden_states)
        else:
            context = encoder_hidden_states

        k = attn.to_k(context)
        v = attn.to_v(context)

        q = attn.head_to_batch_dim(q)
        k = attn.head_to_batch_dim(k)
        v = attn.head_to_batch_dim(v)

        probs = attn.get_attention_scores(q, k, attention_mask)
        out = attn.batch_to_head_dim(torch.bmm(probs, v))

        # Output projection followed by dropout.
        out = attn.to_out[0](out)
        out = attn.to_out[1](out)

        if is_feature_map:
            out = out.transpose(-1, -2).reshape(batch_size, channel, height, width)

        if attn.residual_connection:
            out = out + skip_input

        return out / attn.rescale_output_factor
85
+
86
+
87
class IPAttnProcessor(nn.Module):
    r"""
    Attention processor for IP-Adapter.

    The last `num_tokens` tokens of `encoder_hidden_states` are treated as
    image-prompt tokens; they get their own key/value projections and their
    attention output is added to the text attention output, weighted by `scale`.

    Args:
        hidden_size (`int`):
            The hidden size of the attention layer.
        cross_attention_dim (`int`):
            The number of channels in the `encoder_hidden_states`.
        scale (`float`, defaults to 1.0):
            the weight scale of image prompt.
        num_tokens (`int`, defaults to 4 when do ip_adapter_plus it should be 16):
            The context length of the image features.
        skip (`bool`, defaults to False):
            When True the image-prompt branch is not evaluated.
        save_in_unet, atten_control:
            Stored on the instance; not used by this class.
    """

    def __init__(self, hidden_size, cross_attention_dim=None, scale=1.0, num_tokens=4, skip=False, save_in_unet='down', atten_control=None):
        super().__init__()

        self.hidden_size = hidden_size
        self.cross_attention_dim = cross_attention_dim
        self.scale = scale
        self.num_tokens = num_tokens
        self.skip = skip

        self.atten_control = atten_control
        self.save_in_unet = save_in_unet

        self.to_k_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
        self.to_v_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)

    def __call__(
        self,
        attn,
        hidden_states,
        encoder_hidden_states=None,
        attention_mask=None,
        temb=None,
    ):
        """Apply text attention plus the optional image-prompt branch.

        When `encoder_hidden_states` is None there are no image-prompt tokens;
        the layer then performs plain self-attention and the image-prompt
        branch is skipped (previously this raised NameError).
        """
        residual = hidden_states

        if attn.spatial_norm is not None:
            hidden_states = attn.spatial_norm(hidden_states, temb)

        input_ndim = hidden_states.ndim

        if input_ndim == 4:
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

        batch_size, sequence_length, _ = (
            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
        )
        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)

        if attn.group_norm is not None:
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

        query = attn.to_q(hidden_states)

        # Stays None when no encoder states (and hence no image tokens) are given.
        ip_hidden_states = None
        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        else:
            # Split the text tokens from the trailing image-prompt tokens.
            end_pos = encoder_hidden_states.shape[1] - self.num_tokens
            encoder_hidden_states, ip_hidden_states = (
                encoder_hidden_states[:, :end_pos, :],
                encoder_hidden_states[:, end_pos:, :],
            )
            if attn.norm_cross:
                encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)

        query = attn.head_to_batch_dim(query)
        key = attn.head_to_batch_dim(key)
        value = attn.head_to_batch_dim(value)

        attention_probs = attn.get_attention_scores(query, key, attention_mask)
        hidden_states = torch.bmm(attention_probs, value)
        hidden_states = attn.batch_to_head_dim(hidden_states)

        if not self.skip and ip_hidden_states is not None:
            # for ip-adapter
            ip_key = self.to_k_ip(ip_hidden_states)
            ip_value = self.to_v_ip(ip_hidden_states)

            ip_key = attn.head_to_batch_dim(ip_key)
            ip_value = attn.head_to_batch_dim(ip_value)

            ip_attention_probs = attn.get_attention_scores(query, ip_key, None)
            # Keep the image-prompt attention probabilities for later inspection.
            self.attn_map = ip_attention_probs
            ip_hidden_states = torch.bmm(ip_attention_probs, ip_value)
            ip_hidden_states = attn.batch_to_head_dim(ip_hidden_states)

            hidden_states = hidden_states + self.scale * ip_hidden_states

        # linear proj
        hidden_states = attn.to_out[0](hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)

        if input_ndim == 4:
            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)

        if attn.residual_connection:
            hidden_states = hidden_states + residual

        hidden_states = hidden_states / attn.rescale_output_factor

        return hidden_states
197
+
198
+
199
class AttnProcessor2_0(torch.nn.Module):
    r"""
    Plain attention processor built on `F.scaled_dot_product_attention`
    (requires PyTorch 2.0).

    `hidden_size` and `cross_attention_dim` are accepted by the constructor but
    not used; `save_in_unet` and `atten_control` are only stored on the instance.
    """

    def __init__(
        self,
        hidden_size=None,
        cross_attention_dim=None,
        save_in_unet='down',
        atten_control=None,
    ):
        super().__init__()
        if not hasattr(F, "scaled_dot_product_attention"):
            raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
        self.save_in_unet = save_in_unet
        self.atten_control = atten_control

    def __call__(
        self,
        attn,
        hidden_states,
        encoder_hidden_states=None,
        attention_mask=None,
        temb=None,
    ):
        """Apply attention using the projections of `attn`.

        When `encoder_hidden_states` is None the layer attends to
        `hidden_states` itself. A 4-D input is flattened to a token sequence
        for the computation and reshaped back before returning.
        """
        skip_input = hidden_states

        if attn.spatial_norm is not None:
            hidden_states = attn.spatial_norm(hidden_states, temb)

        is_feature_map = hidden_states.ndim == 4
        if is_feature_map:
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

        if encoder_hidden_states is None:
            batch_size, sequence_length, _ = hidden_states.shape
        else:
            batch_size, sequence_length, _ = encoder_hidden_states.shape

        if attention_mask is not None:
            attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
            # scaled_dot_product_attention expects the mask shaped as
            # (batch, heads, source_length, target_length)
            attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])

        if attn.group_norm is not None:
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

        q = attn.to_q(hidden_states)

        if encoder_hidden_states is None:
            context = hidden_states
        elif attn.norm_cross:
            context = attn.norm_encoder_hidden_states(encoder_hidden_states)
        else:
            context = encoder_hidden_states

        k = attn.to_k(context)
        v = attn.to_v(context)

        n_heads = attn.heads
        head_dim = k.shape[-1] // n_heads

        # (batch, tokens, heads * head_dim) -> (batch, heads, tokens, head_dim)
        q = q.view(batch_size, -1, n_heads, head_dim).transpose(1, 2)
        k = k.view(batch_size, -1, n_heads, head_dim).transpose(1, 2)
        v = v.view(batch_size, -1, n_heads, head_dim).transpose(1, 2)

        # TODO: add support for attn.scale when we move to Torch 2.1
        out = F.scaled_dot_product_attention(
            q, k, v, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
        )

        # Merge the heads back into a single feature dimension.
        out = out.transpose(1, 2).reshape(batch_size, -1, n_heads * head_dim)
        out = out.to(q.dtype)

        # Output projection followed by dropout.
        out = attn.to_out[0](out)
        out = attn.to_out[1](out)

        if is_feature_map:
            out = out.transpose(-1, -2).reshape(batch_size, channel, height, width)

        if attn.residual_connection:
            out = out + skip_input

        return out / attn.rescale_output_factor
290
+
291
+
292
class IPAttnProcessor2_0(torch.nn.Module):
    r"""
    Attention processor for IP-Adapter for PyTorch 2.0.

    The last `num_tokens` tokens of `encoder_hidden_states` are treated as
    image-prompt tokens; they get their own key/value projections and their
    attention output is added to the text attention output, weighted by `scale`.

    Args:
        hidden_size (`int`):
            The hidden size of the attention layer.
        cross_attention_dim (`int`):
            The number of channels in the `encoder_hidden_states`.
        scale (`float`, defaults to 1.0):
            the weight scale of image prompt.
        num_tokens (`int`, defaults to 4 when do ip_adapter_plus it should be 16):
            The context length of the image features.
        skip (`bool`, defaults to False):
            When True the image-prompt branch is not evaluated.
        save_in_unet, atten_control:
            Stored on the instance; not used by this class.
    """

    def __init__(self, hidden_size, cross_attention_dim=None, scale=1.0, num_tokens=4, skip=False, save_in_unet='down', atten_control=None):
        super().__init__()

        if not hasattr(F, "scaled_dot_product_attention"):
            raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")

        self.hidden_size = hidden_size
        self.cross_attention_dim = cross_attention_dim
        self.scale = scale
        self.num_tokens = num_tokens
        self.skip = skip

        self.atten_control = atten_control
        self.save_in_unet = save_in_unet

        self.to_k_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
        self.to_v_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)

    def __call__(
        self,
        attn,
        hidden_states,
        encoder_hidden_states=None,
        attention_mask=None,
        temb=None,
    ):
        """Apply text attention plus the optional image-prompt branch.

        When `encoder_hidden_states` is None there are no image-prompt tokens;
        the layer then performs plain self-attention and the image-prompt
        branch is skipped (previously this raised NameError).

        Side effect: when the image-prompt branch runs, `self.attn_map` is set
        to the softmax attention probabilities of the queries over the
        image-prompt tokens, shaped (batch, heads, query_len, num_ip_tokens).
        """
        residual = hidden_states

        if attn.spatial_norm is not None:
            hidden_states = attn.spatial_norm(hidden_states, temb)

        input_ndim = hidden_states.ndim

        if input_ndim == 4:
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

        batch_size, sequence_length, _ = (
            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
        )

        if attention_mask is not None:
            attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
            # scaled_dot_product_attention expects attention_mask shape to be
            # (batch, heads, source_length, target_length)
            attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])

        if attn.group_norm is not None:
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

        query = attn.to_q(hidden_states)

        # Stays None when no encoder states (and hence no image tokens) are given.
        ip_hidden_states = None
        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        else:
            # Split the text tokens from the trailing image-prompt tokens.
            end_pos = encoder_hidden_states.shape[1] - self.num_tokens
            encoder_hidden_states, ip_hidden_states = (
                encoder_hidden_states[:, :end_pos, :],
                encoder_hidden_states[:, end_pos:, :],
            )
            if attn.norm_cross:
                encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)

        inner_dim = key.shape[-1]
        head_dim = inner_dim // attn.heads

        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

        # the output of sdp = (batch, num_heads, seq_len, head_dim)
        # TODO: add support for attn.scale when we move to Torch 2.1
        hidden_states = F.scaled_dot_product_attention(
            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
        )

        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
        hidden_states = hidden_states.to(query.dtype)

        if not self.skip and ip_hidden_states is not None:
            # for ip-adapter
            ip_key = self.to_k_ip(ip_hidden_states)
            ip_value = self.to_v_ip(ip_hidden_states)

            ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
            ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

            # the output of sdp = (batch, num_heads, seq_len, head_dim)
            # TODO: add support for attn.scale when we move to Torch 2.1
            ip_hidden_states = F.scaled_dot_product_attention(
                query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False
            )
            with torch.no_grad():
                # Softmax must be taken over the scaled query-key scores. The
                # previous expression applied softmax to the transposed key
                # alone (method call binds tighter than `@`) and omitted the
                # 1/sqrt(head_dim) scale used by scaled_dot_product_attention.
                ip_scores = (query @ ip_key.transpose(-2, -1)) * (head_dim ** -0.5)
                self.attn_map = ip_scores.softmax(dim=-1)

            ip_hidden_states = ip_hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
            ip_hidden_states = ip_hidden_states.to(query.dtype)

            hidden_states = hidden_states + self.scale * ip_hidden_states

        # linear proj
        hidden_states = attn.to_out[0](hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)

        if input_ndim == 4:
            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)

        if attn.residual_connection:
            hidden_states = hidden_states + residual

        hidden_states = hidden_states / attn.rescale_output_factor

        return hidden_states
426
+
427
+
428
+ class IP_CS_AttnProcessor2_0(torch.nn.Module):
429
+ r"""
430
+ Attention processor for IP-Adapater for PyTorch 2.0.
431
+ Args:
432
+ hidden_size (`int`):
433
+ The hidden size of the attention layer.
434
+ cross_attention_dim (`int`):
435
+ The number of channels in the `encoder_hidden_states`.
436
+ scale (`float`, defaults to 1.0):
437
+ the weight scale of image prompt.
438
+ num_tokens (`int`, defaults to 4 when do ip_adapter_plus it should be 16):
439
+ The context length of the image features.
440
+ """
441
+
442
+ def __init__(self, hidden_size, cross_attention_dim=None, content_scale=1.0,style_scale=1.0, num_content_tokens=4,num_style_tokens=4,
443
+ skip=False,content=False, style=False):
444
+ super().__init__()
445
+
446
+ if not hasattr(F, "scaled_dot_product_attention"):
447
+ raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
448
+
449
+ self.hidden_size = hidden_size
450
+ self.cross_attention_dim = cross_attention_dim
451
+ self.content_scale = content_scale
452
+ self.style_scale = style_scale
453
+ self.num_content_tokens = num_content_tokens
454
+ self.num_style_tokens = num_style_tokens
455
+ self.skip = skip
456
+
457
+ self.content = content
458
+ self.style = style
459
+
460
+ if self.content or self.style:
461
+ self.to_k_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
462
+ self.to_v_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
463
+ self.to_k_ip_content =None
464
+ self.to_v_ip_content =None
465
+
466
+ def set_content_ipa(self,content_scale=1.0):
467
+
468
+ self.to_k_ip_content = nn.Linear(self.cross_attention_dim or self.hidden_size, self.hidden_size, bias=False)
469
+ self.to_v_ip_content = nn.Linear(self.cross_attention_dim or self.hidden_size, self.hidden_size, bias=False)
470
+ self.content_scale=content_scale
471
+ self.content =True
472
+
473
def __call__(
    self,
    attn,
    hidden_states,
    encoder_hidden_states=None,
    attention_mask=None,
    temb=None,
):
    """Run attention with optional content and style image-prompt branches.

    The trailing ``num_content_tokens + num_style_tokens`` entries of
    ``encoder_hidden_states`` are split off as content tokens followed by style
    tokens; everything before them goes through the regular ``attn.to_k`` /
    ``attn.to_v`` projections. Each enabled image branch runs its own scaled
    dot-product attention against the same query and is added to the text
    result, weighted by ``content_scale`` / ``style_scale``.

    Args:
        attn: Attention layer supplying projections, norms, ``heads``,
            ``residual_connection`` and ``rescale_output_factor``
            (presumably a diffusers ``Attention`` module -- confirm against caller).
        hidden_states: Query-side activations, either 3-D ``(batch, seq, dim)``
            or 4-D ``(batch, channel, height, width)``; the 4-D case is
            flattened here and restored before returning.
        encoder_hidden_states: Text tokens followed by content and style image
            tokens. When ``None`` the layer attends to ``hidden_states`` itself.
        attention_mask: Optional mask, applied to the text attention only.
        temb: Forwarded to ``attn.spatial_norm`` when that norm is present.

    Returns:
        The attended activations, with the same layout as the input.
    """
    residual = hidden_states

    if attn.spatial_norm is not None:
        hidden_states = attn.spatial_norm(hidden_states, temb)

    input_ndim = hidden_states.ndim

    if input_ndim == 4:
        batch_size, channel, height, width = hidden_states.shape
        hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

    batch_size, sequence_length, _ = (
        hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
    )

    if attention_mask is not None:
        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
        # scaled_dot_product_attention expects attention_mask shape to be
        # (batch, heads, source_length, target_length)
        attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])

    if attn.group_norm is not None:
        hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

    query = attn.to_q(hidden_states)

    if encoder_hidden_states is None:
        # NOTE(review): no image tokens exist on this path, so the content/style
        # branches below would hit unbound locals if enabled for self-attention.
        encoder_hidden_states = hidden_states
    else:
        # Split into text tokens, content image tokens and style image tokens.
        end_pos = encoder_hidden_states.shape[1] - self.num_content_tokens - self.num_style_tokens
        encoder_hidden_states, ip_content_hidden_states, ip_style_hidden_states = (
            encoder_hidden_states[:, :end_pos, :],
            encoder_hidden_states[:, end_pos:end_pos + self.num_content_tokens, :],
            encoder_hidden_states[:, end_pos + self.num_content_tokens:, :],
        )
        if attn.norm_cross:
            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

    key = attn.to_k(encoder_hidden_states)
    value = attn.to_v(encoder_hidden_states)

    inner_dim = key.shape[-1]
    head_dim = inner_dim // attn.heads

    query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

    key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
    value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

    # the output of sdp = (batch, num_heads, seq_len, head_dim)
    # TODO: add support for attn.scale when we move to Torch 2.1
    hidden_states = F.scaled_dot_product_attention(
        query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
    )

    hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
    hidden_states = hidden_states.to(query.dtype)

    # A leftover `exit()` used to terminate the whole process here whenever the
    # content adapter was enabled, which made the branch below unreachable.
    if not self.skip and self.content is True:
        # for ip-content-adapter: fall back to the shared image projections when
        # no dedicated content projections have been installed.
        if self.to_k_ip_content is None:
            ip_content_key = self.to_k_ip(ip_content_hidden_states)
            ip_content_value = self.to_v_ip(ip_content_hidden_states)
        else:
            ip_content_key = self.to_k_ip_content(ip_content_hidden_states)
            ip_content_value = self.to_v_ip_content(ip_content_hidden_states)

        ip_content_key = ip_content_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        ip_content_value = ip_content_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

        # the output of sdp = (batch, num_heads, seq_len, head_dim)
        # TODO: add support for attn.scale when we move to Torch 2.1
        ip_content_hidden_states = F.scaled_dot_product_attention(
            query, ip_content_key, ip_content_value, attn_mask=None, dropout_p=0.0, is_causal=False
        )

        ip_content_hidden_states = ip_content_hidden_states.transpose(1, 2).reshape(
            batch_size, -1, attn.heads * head_dim
        )
        ip_content_hidden_states = ip_content_hidden_states.to(query.dtype)

        hidden_states = hidden_states + self.content_scale * ip_content_hidden_states

    if not self.skip and self.style is True:
        # for ip-style-adapter
        ip_style_key = self.to_k_ip(ip_style_hidden_states)
        ip_style_value = self.to_v_ip(ip_style_hidden_states)

        ip_style_key = ip_style_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        ip_style_value = ip_style_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

        # the output of sdp = (batch, num_heads, seq_len, head_dim)
        # TODO: add support for attn.scale when we move to Torch 2.1
        ip_style_hidden_states = F.scaled_dot_product_attention(
            query, ip_style_key, ip_style_value, attn_mask=None, dropout_p=0.0, is_causal=False
        )

        ip_style_hidden_states = ip_style_hidden_states.transpose(1, 2).reshape(
            batch_size, -1, attn.heads * head_dim
        )
        ip_style_hidden_states = ip_style_hidden_states.to(query.dtype)

        hidden_states = hidden_states + self.style_scale * ip_style_hidden_states

    # linear proj
    hidden_states = attn.to_out[0](hidden_states)
    # dropout
    hidden_states = attn.to_out[1](hidden_states)

    if input_ndim == 4:
        hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)

    if attn.residual_connection:
        hidden_states = hidden_states + residual

    hidden_states = hidden_states / attn.rescale_output_factor

    return hidden_states
602
+
603
+ ## for controlnet
604
class CNAttnProcessor:
    r"""
    Attention processor for the ControlNet branch that attends to text tokens only.

    The last `num_tokens` entries of `encoder_hidden_states` are dropped before the
    key/value projections; the remainder is a plain (non-fused) attention computation
    driven by the helpers on the `attn` module.
    """

    def __init__(self, num_tokens=4, save_in_unet='down', atten_control=None):
        # `save_in_unet` and `atten_control` are only stored; `__call__` never reads them.
        self.atten_control = atten_control
        self.save_in_unet = save_in_unet
        self.num_tokens = num_tokens

    def __call__(self, attn, hidden_states, encoder_hidden_states=None, attention_mask=None, temb=None):
        skip_input = hidden_states

        if attn.spatial_norm is not None:
            hidden_states = attn.spatial_norm(hidden_states, temb)

        input_ndim = hidden_states.ndim
        if input_ndim == 4:
            # Flatten (batch, channel, height, width) into (batch, height * width, channel).
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

        # Batch size and mask length come from the context when one is given.
        shape_source = hidden_states if encoder_hidden_states is None else encoder_hidden_states
        batch_size, sequence_length, _ = shape_source.shape
        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)

        if attn.group_norm is not None:
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

        query = attn.to_q(hidden_states)

        if encoder_hidden_states is None:
            context = hidden_states
        else:
            text_len = encoder_hidden_states.shape[1] - self.num_tokens
            context = encoder_hidden_states[:, :text_len]  # only use text
            if attn.norm_cross:
                context = attn.norm_encoder_hidden_states(context)

        key = attn.to_k(context)
        value = attn.to_v(context)

        query, key, value = (attn.head_to_batch_dim(t) for t in (query, key, value))

        attention_probs = attn.get_attention_scores(query, key, attention_mask)
        hidden_states = attn.batch_to_head_dim(torch.bmm(attention_probs, value))

        # linear proj, then dropout
        out_proj, out_dropout = attn.to_out[0], attn.to_out[1]
        hidden_states = out_dropout(out_proj(hidden_states))

        if input_ndim == 4:
            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)

        if attn.residual_connection:
            hidden_states = hidden_states + skip_input

        return hidden_states / attn.rescale_output_factor
669
+
670
+
671
class CNAttnProcessor2_0:
    r"""
    Scaled dot-product (PyTorch 2.0) attention processor for the ControlNet branch.

    Only the text part of `encoder_hidden_states` is attended to: the last `num_tokens`
    entries are cut off before keys and values are projected.
    """

    def __init__(self, num_tokens=4, save_in_unet='down', atten_control=None):
        if not hasattr(F, "scaled_dot_product_attention"):
            raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
        # `save_in_unet` and `atten_control` are only stored; `__call__` never reads them.
        self.atten_control = atten_control
        self.save_in_unet = save_in_unet
        self.num_tokens = num_tokens

    def __call__(
        self,
        attn,
        hidden_states,
        encoder_hidden_states=None,
        attention_mask=None,
        temb=None,
    ):
        skip_input = hidden_states

        if attn.spatial_norm is not None:
            hidden_states = attn.spatial_norm(hidden_states, temb)

        input_ndim = hidden_states.ndim
        if input_ndim == 4:
            # Flatten (batch, channel, height, width) into (batch, height * width, channel).
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

        shape_source = hidden_states if encoder_hidden_states is None else encoder_hidden_states
        batch_size, sequence_length, _ = shape_source.shape

        if attention_mask is not None:
            prepared = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
            # scaled_dot_product_attention expects attention_mask shape to be
            # (batch, heads, source_length, target_length)
            attention_mask = prepared.view(batch_size, attn.heads, -1, prepared.shape[-1])

        if attn.group_norm is not None:
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

        query = attn.to_q(hidden_states)

        if encoder_hidden_states is None:
            context = hidden_states
        else:
            text_len = encoder_hidden_states.shape[1] - self.num_tokens
            context = encoder_hidden_states[:, :text_len]  # only use text
            if attn.norm_cross:
                context = attn.norm_encoder_hidden_states(context)

        key = attn.to_k(context)
        value = attn.to_v(context)

        head_dim = key.shape[-1] // attn.heads

        def to_heads(tensor):
            # (batch, seq, heads * head_dim) -> (batch, heads, seq, head_dim)
            return tensor.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

        query, key, value = to_heads(query), to_heads(key), to_heads(value)

        # the output of sdp = (batch, num_heads, seq_len, head_dim)
        # TODO: add support for attn.scale when we move to Torch 2.1
        attended = F.scaled_dot_product_attention(
            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
        )
        attended = attended.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
        hidden_states = attended.to(query.dtype)

        # linear proj, then dropout
        out_proj, out_dropout = attn.to_out[0], attn.to_out[1]
        hidden_states = out_dropout(out_proj(hidden_states))

        if input_ndim == 4:
            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)

        if attn.residual_connection:
            hidden_states = hidden_states + skip_input

        return hidden_states / attn.rescale_output_factor
759
+
760
class IP_FuAd_AttnProcessor2_0(torch.nn.Module):
    r"""
    Attention processor for IP-Adapter (PyTorch 2.0) with a style-token branch, optional
    attention-map fusion and optional AdaIN.

    The last `num_style_tokens` entries of `encoder_hidden_states` are treated as style image
    tokens and attended to through the dedicated `to_k_ip` / `to_v_ip` projections. When the
    batch has exactly four rows, rows 0 and 2 always leave the attention stage with their plain
    text-attention result, while rows 1 and 3 receive the style contribution.

    Args:
        hidden_size (`int`):
            The hidden size of the attention layer.
        cross_attention_dim (`int`):
            The number of channels in the `encoder_hidden_states`.
        content_scale (`float`, defaults to 1.0):
            Stored only; `__call__` does not read it.
        style_scale (`float`, defaults to 1.0):
            Weight of the style attention output when it is added to the text output.
        num_content_tokens (`int`, defaults to 4):
            Stored only; `__call__` does not read it.
        num_style_tokens (`int`, defaults to 4):
            Number of trailing style image tokens in `encoder_hidden_states`.
        skip (`bool`):
            When true the style branch is bypassed.
        content (`bool`), style (`bool`):
            `to_k_ip` / `to_v_ip` are created when either is true; the style branch runs
            only when `style` is `True`.
        fuAttn (`bool`), fuIPAttn (`bool`):
            Blend the text / style attention probabilities of rows 1 and 3 with those of
            rows 0 and 2 while `denoise_step <= end_fusion`. Both require a batch of four.
        adainIP (`bool`):
            Combine text and style outputs with AdaIN instead of a weighted sum.
        fuScale (`float`):
            Share of a row's own probabilities kept during fusion.
        end_fusion (`int`):
            Last call index (1-based) for which fusion is applied.
        attn_name (`str`, *optional*):
            Name of the attention layer this processor is attached to.
        num_inference_steps (`int`, defaults to 50):
            Number of calls after which the internal call counter wraps to zero. The
            counter advances once per `__call__`; this was previously hard-coded to 50,
            presumably matching a 50-step sampling schedule.
    """

    def __init__(self, hidden_size, cross_attention_dim=None, content_scale=1.0, style_scale=1.0,
                 num_content_tokens=4, num_style_tokens=4,
                 skip=False, content=False, style=False, fuAttn=False, fuIPAttn=False, adainIP=False,
                 fuScale=0, end_fusion=0, attn_name=None, num_inference_steps=50):
        super().__init__()

        if not hasattr(F, "scaled_dot_product_attention"):
            raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")

        self.hidden_size = hidden_size
        self.cross_attention_dim = cross_attention_dim
        self.content_scale = content_scale
        self.style_scale = style_scale
        self.num_content_tokens = num_content_tokens
        self.num_style_tokens = num_style_tokens
        self.skip = skip

        self.content = content
        self.style = style

        self.fuAttn = fuAttn
        self.fuIPAttn = fuIPAttn
        self.adainIP = adainIP
        self.fuScale = fuScale
        self.denoise_step = 0
        self.end_fusion = end_fusion
        self.name = attn_name
        self.num_inference_steps = num_inference_steps

        if self.content or self.style:
            self.to_k_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
            self.to_v_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
        self.to_k_ip_content = None
        self.to_v_ip_content = None

    def reset_denoise_step(self):
        """Wrap the call counter back to zero once `num_inference_steps` calls were made."""
        if self.denoise_step >= self.num_inference_steps:
            self.denoise_step = 0

    @staticmethod
    def _fuse_rows(attn_probs, fu_scale):
        """Blend rows 1 and 3 of `attn_probs` in place with rows 0 and 2 respectively."""
        attn_probs[1] = fu_scale * attn_probs[1] + (1 - fu_scale) * attn_probs[0]
        attn_probs[3] = fu_scale * attn_probs[3] + (1 - fu_scale) * attn_probs[2]
        return attn_probs

    @staticmethod
    def _adain(content, style, eps=1e-5):
        """Re-normalise `content` to the per-channel mean/std that `style` has along dim 1."""
        content_mean = content.mean(dim=1, keepdim=True)
        # Clamp so constant features normalise to zero instead of producing NaN (0 / 0).
        content_std = content.std(dim=1, keepdim=True).clamp_min(eps)
        style_mean = style.mean(dim=1, keepdim=True)
        style_std = style.std(dim=1, keepdim=True)
        normalized_content = (content - content_mean) / content_std
        return normalized_content * style_std + style_mean

    def __call__(
        self,
        attn,
        hidden_states,
        encoder_hidden_states=None,
        attention_mask=None,
        temb=None,
    ):
        """Apply text attention plus the optional style branch and return the result.

        Raises:
            ValueError: if style-attention fusion is active on a layer whose name contains
                "down" (this used to print a message and call `exit()`).
        """
        self.denoise_step += 1
        residual = hidden_states

        if attn.spatial_norm is not None:
            hidden_states = attn.spatial_norm(hidden_states, temb)

        input_ndim = hidden_states.ndim

        if input_ndim == 4:
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

        batch_size, sequence_length, _ = (
            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
        )

        if attention_mask is not None:
            attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
            # scaled_dot_product_attention expects attention_mask shape to be
            # (batch, heads, source_length, target_length)
            attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])

        if attn.group_norm is not None:
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

        query = attn.to_q(hidden_states)

        if encoder_hidden_states is None:
            # NOTE(review): no style tokens exist on this path, so the style branch below
            # would hit an unbound local if `style` is enabled for self-attention.
            encoder_hidden_states = hidden_states
        else:
            # Split into text tokens and trailing style image tokens.
            end_pos = encoder_hidden_states.shape[1] - self.num_style_tokens
            encoder_hidden_states, ip_style_hidden_states = (
                encoder_hidden_states[:, :end_pos, :],
                encoder_hidden_states[:, end_pos:, :],
            )
            if attn.norm_cross:
                encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)

        inner_dim = key.shape[-1]
        head_dim = inner_dim // attn.heads

        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

        fusion_active = self.denoise_step <= self.end_fusion
        # Computed from the Python int so the scale is not rounded through the
        # (possibly half-precision) activation dtype.
        scale_factor = 1 / math.sqrt(head_dim)

        # the output of sdp = (batch, num_heads, seq_len, head_dim)
        # TODO: add support for attn.scale when we move to Torch 2.1
        if self.fuAttn and fusion_active:
            assert query.shape[0] == 4
            # NOTE(review): this manual path does not apply `attention_mask`.
            text_attn_probs = (torch.matmul(query, key.transpose(-2, -1)) * scale_factor).softmax(dim=-1)
            self._fuse_rows(text_attn_probs, self.fuScale)
            hidden_states = torch.matmul(text_attn_probs, value)
        else:
            hidden_states = F.scaled_dot_product_attention(
                query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
            )

        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
        hidden_states = hidden_states.to(query.dtype)

        # Kept so rows 0 and 2 can be restored after the style branch.
        raw_hidden_states = hidden_states

        if not self.skip and self.style is True:
            # for ip-style-adapter
            ip_style_key = self.to_k_ip(ip_style_hidden_states)
            ip_style_value = self.to_v_ip(ip_style_hidden_states)

            ip_style_key = ip_style_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
            ip_style_value = ip_style_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

            # the output of sdp = (batch, num_heads, seq_len, head_dim)
            # TODO: add support for attn.scale when we move to Torch 2.1
            if self.fuIPAttn and fusion_active:
                assert query.shape[0] == 4
                if self.name is not None and "down" in self.name:
                    raise ValueError(
                        f"fuIPAttn must not be enabled on down blocks (attn_name={self.name!r})"
                    )
                ip_attn_probs = torch.matmul(query, ip_style_key.transpose(-2, -1)) * scale_factor
                ip_attn_probs = F.softmax(ip_attn_probs, dim=-1)
                self._fuse_rows(ip_attn_probs, self.fuScale)
                ip_style_hidden_states = torch.matmul(ip_attn_probs, ip_style_value)
            else:
                ip_style_hidden_states = F.scaled_dot_product_attention(
                    query, ip_style_key, ip_style_value, attn_mask=None, dropout_p=0.0, is_causal=False
                )

            ip_style_hidden_states = ip_style_hidden_states.transpose(1, 2).reshape(
                batch_size, -1, attn.heads * head_dim
            )
            ip_style_hidden_states = ip_style_hidden_states.to(query.dtype)

            if not self.adainIP:
                hidden_states = hidden_states + self.style_scale * ip_style_hidden_states
            else:
                hidden_states = self._adain(content=hidden_states, style=ip_style_hidden_states)

        if hidden_states.shape[0] == 4:
            # Rows 0 and 2 keep their text-only attention output.
            hidden_states[0] = raw_hidden_states[0]
            hidden_states[2] = raw_hidden_states[2]

        # linear proj
        hidden_states = attn.to_out[0](hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)

        if input_ndim == 4:
            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)

        if attn.residual_connection:
            hidden_states = hidden_states + residual

        hidden_states = hidden_states / attn.rescale_output_factor

        self.reset_denoise_step()
        return hidden_states
964
+
965
class IP_FuAd_AttnProcessor2_0_exp(torch.nn.Module):
    r"""
    Experimental variant of the fusion/AdaIN IP-Adapter attention processor (PyTorch 2.0)
    that can additionally record text attention maps on the `attn` module.

    The last `num_content_tokens + num_style_tokens` entries of `encoder_hidden_states` are
    split off together and fed to the style branch through `to_k_ip` / `to_v_ip`. When the
    batch has exactly four rows, rows 0 and 2 always leave the attention stage with their
    plain text-attention result.

    Args:
        hidden_size (`int`):
            The hidden size of the attention layer.
        cross_attention_dim (`int`):
            The number of channels in the `encoder_hidden_states`.
        content_scale (`float`, defaults to 1.0):
            Stored only; `__call__` does not read it.
        style_scale (`float`, defaults to 1.0):
            Weight of the style attention output when it is added to the text output.
        num_content_tokens (`int`, defaults to 4), num_style_tokens (`int`, defaults to 4):
            Together give the number of trailing image tokens in `encoder_hidden_states`.
        skip (`bool`):
            When true the style branch is bypassed.
        content (`bool`), style (`bool`):
            `to_k_ip` / `to_v_ip` are created when either is true; the style branch runs
            only when `style` is `True`.
        fuAttn (`bool`), fuIPAttn (`bool`):
            Blend the text / style attention probabilities of rows 1 and 3 with those of
            rows 0 and 2 while `denoise_step <= end_fusion`. Both require a batch of four.
        adainIP (`bool`):
            Combine text and style outputs with AdaIN instead of a weighted sum.
        fuScale (`float`):
            Share of a row's own probabilities kept during fusion.
        end_fusion (`int`):
            Last call index (1-based) for which fusion is applied.
        attn_name (`str`, *optional*):
            Name of the attention layer this processor is attached to.
        save_attn_map (`bool`):
            When true, every call stores the text attention probabilities in
            `attn.attn_map[attn.inference_step]` (on CPU).
        num_inference_steps (`int`, defaults to 50):
            Number of calls after which the internal call counter wraps to zero. The
            counter advances once per `__call__`; this was previously hard-coded to 50,
            presumably matching a 50-step sampling schedule.
    """

    def __init__(self, hidden_size, cross_attention_dim=None, content_scale=1.0, style_scale=1.0,
                 num_content_tokens=4, num_style_tokens=4,
                 skip=False, content=False, style=False, fuAttn=False, fuIPAttn=False, adainIP=False,
                 fuScale=0, end_fusion=0, attn_name=None, save_attn_map=False, num_inference_steps=50):
        super().__init__()

        if not hasattr(F, "scaled_dot_product_attention"):
            raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")

        self.hidden_size = hidden_size
        self.cross_attention_dim = cross_attention_dim
        self.content_scale = content_scale
        self.style_scale = style_scale
        # `__call__` reads this when splitting the context; it was never assigned before.
        self.num_content_tokens = num_content_tokens
        self.num_style_tokens = num_style_tokens
        self.skip = skip

        self.content = content
        self.style = style

        self.fuAttn = fuAttn
        self.fuIPAttn = fuIPAttn
        self.adainIP = adainIP
        self.fuScale = fuScale
        self.denoise_step = 0
        self.end_fusion = end_fusion
        self.name = attn_name
        self.num_inference_steps = num_inference_steps

        self.save_attn_map = save_attn_map

        if self.content or self.style:
            self.to_k_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
            self.to_v_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
        self.to_k_ip_content = None
        self.to_v_ip_content = None

    def reset_denoise_step(self):
        """Wrap the call counter back to zero once `num_inference_steps` calls were made."""
        if self.denoise_step >= self.num_inference_steps:
            self.denoise_step = 0

    @staticmethod
    def _fuse_rows(attn_probs, fu_scale):
        """Blend rows 1 and 3 of `attn_probs` in place with rows 0 and 2 respectively."""
        attn_probs[1] = fu_scale * attn_probs[1] + (1 - fu_scale) * attn_probs[0]
        attn_probs[3] = fu_scale * attn_probs[3] + (1 - fu_scale) * attn_probs[2]
        return attn_probs

    @staticmethod
    def _adain(content, style, eps=1e-5):
        """Re-normalise `content` to the per-channel mean/std that `style` has along dim 1."""
        content_mean = content.mean(dim=1, keepdim=True)
        # Clamp so constant features normalise to zero instead of producing NaN (0 / 0).
        content_std = content.std(dim=1, keepdim=True).clamp_min(eps)
        style_mean = style.mean(dim=1, keepdim=True)
        style_std = style.std(dim=1, keepdim=True)
        normalized_content = (content - content_mean) / content_std
        return normalized_content * style_std + style_mean

    def _record_attention_map(self, attn, query, key, attention_mask):
        """Store the text attention probabilities for this call on the `attn` module."""
        # NOTE(review): `attention_mask` has already been reshaped to the 4-D layout used by
        # scaled_dot_product_attention here; confirm `get_attention_scores` accepts that.
        attention_probs = attn.get_attention_scores(
            attn.head_to_batch_dim(query), attn.head_to_batch_dim(key), attention_mask
        )
        if attention_probs is None:
            print(f"{attn} didn't get the attention probs")
            return
        if not hasattr(attn, "attn_map"):
            attn.attn_map = {}
            attn.inference_step = 0
        else:
            attn.inference_step += 1
        attn.attn_map[attn.inference_step] = attention_probs.clone().cpu().detach()

    def __call__(
        self,
        attn,
        hidden_states,
        encoder_hidden_states=None,
        attention_mask=None,
        temb=None,
    ):
        """Apply text attention plus the optional style branch and return the result.

        Raises:
            ValueError: if style-attention fusion is active on a layer whose name contains
                "down" (this used to print a message and call `exit()`).
        """
        self.denoise_step += 1
        residual = hidden_states

        if attn.spatial_norm is not None:
            hidden_states = attn.spatial_norm(hidden_states, temb)

        input_ndim = hidden_states.ndim

        if input_ndim == 4:
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

        batch_size, sequence_length, _ = (
            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
        )

        if attention_mask is not None:
            attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
            # scaled_dot_product_attention expects attention_mask shape to be
            # (batch, heads, source_length, target_length)
            attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])

        if attn.group_norm is not None:
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

        query = attn.to_q(hidden_states)

        if encoder_hidden_states is None:
            # NOTE(review): no image tokens exist on this path, so the style branch below
            # would hit an unbound local if `style` is enabled for self-attention.
            encoder_hidden_states = hidden_states
        else:
            # Everything after the text tokens (content and style image tokens together)
            # is handed to the style branch.
            end_pos = encoder_hidden_states.shape[1] - self.num_content_tokens - self.num_style_tokens
            encoder_hidden_states, ip_style_hidden_states = (
                encoder_hidden_states[:, :end_pos, :],
                encoder_hidden_states[:, end_pos:, :],
            )
            if attn.norm_cross:
                encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)

        if self.save_attn_map:
            # The map is softmax(Q K^T); the previous code passed `value` in place of `key`.
            self._record_attention_map(attn, query, key, attention_mask)

        inner_dim = key.shape[-1]
        head_dim = inner_dim // attn.heads

        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

        fusion_active = self.denoise_step <= self.end_fusion
        # Computed from the Python int so the scale is not rounded through the
        # (possibly half-precision) activation dtype.
        scale_factor = 1 / math.sqrt(head_dim)

        # the output of sdp = (batch, num_heads, seq_len, head_dim)
        # TODO: add support for attn.scale when we move to Torch 2.1
        if self.fuAttn and fusion_active:
            assert query.shape[0] == 4
            # NOTE(review): this manual path does not apply `attention_mask`.
            text_attn_probs = (torch.matmul(query, key.transpose(-2, -1)) * scale_factor).softmax(dim=-1)
            self._fuse_rows(text_attn_probs, self.fuScale)
            hidden_states = torch.matmul(text_attn_probs, value)
        else:
            hidden_states = F.scaled_dot_product_attention(
                query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
            )

        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
        hidden_states = hidden_states.to(query.dtype)

        # Kept so rows 0 and 2 can be restored after the style branch.
        raw_hidden_states = hidden_states

        if not self.skip and self.style is True:
            # for ip-style-adapter
            ip_style_key = self.to_k_ip(ip_style_hidden_states)
            ip_style_value = self.to_v_ip(ip_style_hidden_states)

            ip_style_key = ip_style_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
            ip_style_value = ip_style_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

            # the output of sdp = (batch, num_heads, seq_len, head_dim)
            # TODO: add support for attn.scale when we move to Torch 2.1
            if self.fuIPAttn and fusion_active:
                assert query.shape[0] == 4
                if self.name is not None and "down" in self.name:
                    raise ValueError(
                        f"fuIPAttn must not be enabled on down blocks (attn_name={self.name!r})"
                    )
                ip_attn_probs = torch.matmul(query, ip_style_key.transpose(-2, -1)) * scale_factor
                ip_attn_probs = F.softmax(ip_attn_probs, dim=-1)
                self._fuse_rows(ip_attn_probs, self.fuScale)
                ip_style_hidden_states = torch.matmul(ip_attn_probs, ip_style_value)
            else:
                ip_style_hidden_states = F.scaled_dot_product_attention(
                    query, ip_style_key, ip_style_value, attn_mask=None, dropout_p=0.0, is_causal=False
                )

            ip_style_hidden_states = ip_style_hidden_states.transpose(1, 2).reshape(
                batch_size, -1, attn.heads * head_dim
            )
            ip_style_hidden_states = ip_style_hidden_states.to(query.dtype)

            if self.adainIP:
                # Debugger breakpoints that used to sit on this path were removed; they
                # blocked every non-interactive run.
                hidden_states = self._adain(content=hidden_states, style=ip_style_hidden_states)
            else:
                hidden_states = hidden_states + self.style_scale * ip_style_hidden_states

        if hidden_states.shape[0] == 4:
            # Rows 0 and 2 keep their text-only attention output.
            hidden_states[0] = raw_hidden_states[0]
            hidden_states[2] = raw_hidden_states[2]

        # linear proj
        hidden_states = attn.to_out[0](hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)

        if input_ndim == 4:
            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)

        if attn.residual_connection:
            hidden_states = hidden_states + residual

        hidden_states = hidden_states / attn.rescale_output_factor

        self.reset_denoise_step()
        return hidden_states
1194
+
1195
class AttnProcessor2_0_hijack(torch.nn.Module):
    r"""
    Scaled dot-product attention processor (PyTorch 2.0) with optional attention-map fusion.

    While fusion is active (`fuSAttn` is set and the call counter is at most `end_fusion`),
    the attention probabilities of batch rows 1 and 3 are blended with those of rows 0 and 2;
    this requires a batch of exactly four. Otherwise the processor is a plain
    `F.scaled_dot_product_attention` call.

    Args:
        hidden_size, cross_attention_dim:
            Accepted for signature compatibility; not used by this processor.
        save_in_unet, atten_control:
            Stored only; `__call__` does not read them.
        fuSAttn (`bool`):
            Enables attention-map fusion.
        fuScale (`float`):
            Share of a row's own probabilities kept during fusion.
        end_fusion (`int`):
            Last call index (1-based) for which fusion is applied.
        attn_name (`str`, *optional*):
            Name of the attention layer this processor is attached to.
        num_inference_steps (`int`, defaults to 50):
            Number of calls after which the internal call counter wraps to zero. The
            counter advances once per `__call__`; this was previously hard-coded to 50,
            presumably matching a 50-step sampling schedule.
    """

    def __init__(
        self,
        hidden_size=None,
        cross_attention_dim=None,
        save_in_unet='down',
        atten_control=None,
        fuSAttn=False,
        fuScale=0,
        end_fusion=0,
        attn_name=None,
        num_inference_steps=50,
    ):
        super().__init__()
        if not hasattr(F, "scaled_dot_product_attention"):
            raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
        self.atten_control = atten_control
        self.save_in_unet = save_in_unet

        self.fuSAttn = fuSAttn
        self.fuScale = fuScale
        self.denoise_step = 0
        self.end_fusion = end_fusion
        self.name = attn_name
        self.num_inference_steps = num_inference_steps

    def reset_denoise_step(self):
        """Wrap the call counter back to zero once `num_inference_steps` calls were made."""
        if self.denoise_step >= self.num_inference_steps:
            self.denoise_step = 0

    def __call__(
        self,
        attn,
        hidden_states,
        encoder_hidden_states=None,
        attention_mask=None,
        temb=None,
    ):
        """Run attention, fusing the probabilities of rows 1/3 with rows 0/2 when active."""
        self.denoise_step += 1
        residual = hidden_states

        if attn.spatial_norm is not None:
            hidden_states = attn.spatial_norm(hidden_states, temb)

        input_ndim = hidden_states.ndim

        if input_ndim == 4:
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

        batch_size, sequence_length, _ = (
            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
        )

        if attention_mask is not None:
            attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
            # scaled_dot_product_attention expects attention_mask shape to be
            # (batch, heads, source_length, target_length)
            attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])

        if attn.group_norm is not None:
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

        query = attn.to_q(hidden_states)

        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        elif attn.norm_cross:
            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)

        inner_dim = key.shape[-1]
        head_dim = inner_dim // attn.heads

        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

        # the output of sdp = (batch, num_heads, seq_len, head_dim)
        # TODO: add support for attn.scale when we move to Torch 2.1
        if self.fuSAttn and self.denoise_step <= self.end_fusion:
            assert query.shape[0] == 4
            # Progress message for one specific layer; skipped when the processor has no
            # name (the membership test used to raise TypeError on the default `None`).
            if (
                self.name is not None
                and "up_blocks.1.attentions.2.transformer_blocks.1" in self.name
                and self.denoise_step == self.end_fusion
            ):
                print("now: ", self.denoise_step, "end now:", self.end_fusion, "scale: ", self.fuScale)
            # Computed from the Python int so the scale is not rounded through the
            # (possibly half-precision) activation dtype.
            scale_factor = 1 / math.sqrt(head_dim)
            # NOTE(review): this manual path does not apply `attention_mask`.
            attn_probs = (torch.matmul(query, key.transpose(-2, -1)) * scale_factor).softmax(dim=-1)
            attn_probs[1] = self.fuScale * attn_probs[1] + (1 - self.fuScale) * attn_probs[0]
            attn_probs[3] = self.fuScale * attn_probs[3] + (1 - self.fuScale) * attn_probs[2]
            hidden_states = torch.matmul(attn_probs, value)
        else:
            hidden_states = F.scaled_dot_product_attention(
                query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
            )

        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
        hidden_states = hidden_states.to(query.dtype)

        # linear proj
        hidden_states = attn.to_out[0](hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)

        if input_ndim == 4:
            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)

        if attn.residual_connection:
            hidden_states = hidden_states + residual

        hidden_states = hidden_states / attn.rescale_output_factor

        self.reset_denoise_step()
        return hidden_states
1316
+
1317
class AttnProcessor2_0_exp(torch.nn.Module):
    r"""
    Experimental attention processor built on PyTorch 2.0 scaled dot-product attention.

    While fusion is active (`fuSAttn` is true and the internal call counter has not passed
    `end_fusion`) attention is computed explicitly instead of through
    `F.scaled_dot_product_attention`. In that mode the batch must hold exactly four samples:
    the attention probabilities of samples 1 and 3 are blended with those of samples 0 and 2,
    and the values of samples 1 and 3 are replaced by an AdaIN re-normalisation of the values
    of samples 0 and 2. Outside the fusion window the processor is a plain SDPA processor.

    Args:
        hidden_size: Accepted for signature compatibility; not used by this class.
        cross_attention_dim: Accepted for signature compatibility; not used by this class.
        save_in_unet: Stored on the instance; not read by this class.
        atten_control: Stored on the instance; not read by this class.
        fuSAttn (`bool`): Enables the fusion path.
        fuScale (`float`): Weight kept from samples 1/3 when blending attention probabilities;
            `1 - fuScale` is taken from samples 0/2.
        end_fusion (`int`): Largest value of the call counter for which fusion is applied.
        attn_name (`str`, *optional*): Layer name, only used to gate a progress message.
    """

    # Layer whose progress message is printed when the fusion window ends.
    _LOGGED_LAYER = "up_blocks.1.attentions.2.transformer_blocks.1"

    def __init__(
        self,
        hidden_size=None,
        cross_attention_dim=None,
        save_in_unet='down',
        atten_control=None,
        fuSAttn=False,
        fuScale=0,
        end_fusion=0,
        attn_name=None,
    ):
        super().__init__()
        if not hasattr(F, "scaled_dot_product_attention"):
            raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
        self.atten_control = atten_control
        self.save_in_unet = save_in_unet

        self.fuSAttn = fuSAttn
        self.fuScale = fuScale
        # Number of calls seen since the last reset; incremented at the top of __call__.
        self.denoise_step = 0
        self.end_fusion = end_fusion
        self.name = attn_name

    def reset_denoise_step(self):
        """Wrap the call counter back to 0 once it has reached 50."""
        # NOTE(review): 50 is hard-coded; presumably the number of inference steps - confirm.
        if self.denoise_step == 50:
            self.denoise_step = 0

    @staticmethod
    def _adain(content, style):
        """Re-normalise `content` to the mean/std of `style`, both taken along dim 1."""
        content_mean = content.mean(dim=1, keepdim=True)
        content_std = content.std(dim=1, keepdim=True)
        style_mean = style.mean(dim=1, keepdim=True)
        style_std = style.std(dim=1, keepdim=True)
        normalized_content = (content - content_mean) / content_std
        return normalized_content * style_std + style_mean

    def __call__(
        self,
        attn,
        hidden_states,
        encoder_hidden_states=None,
        attention_mask=None,
        temb=None,
    ):
        """
        Run attention for one call of the owning attention layer.

        Args:
            attn: Attention module providing `to_q`/`to_k`/`to_v`/`to_out`, `heads`, the optional
                norms and the residual/rescale settings.
            hidden_states: Input features, either 3-D or 4-D (4-D input is flattened and restored).
            encoder_hidden_states: Optional context; `hidden_states` is used when it is `None`.
            attention_mask: Optional mask, applied only on the SDPA (non-fusion) path.
            temb: Optional embedding forwarded to `attn.spatial_norm`.

        Returns:
            The attended features with the same layout as `hidden_states`.
        """
        self.denoise_step += 1
        residual = hidden_states

        if attn.spatial_norm is not None:
            hidden_states = attn.spatial_norm(hidden_states, temb)

        input_ndim = hidden_states.ndim

        if input_ndim == 4:
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

        batch_size, sequence_length, _ = (
            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
        )

        if attention_mask is not None:
            attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
            # scaled_dot_product_attention expects attention_mask shape to be
            # (batch, heads, source_length, target_length)
            attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])

        if attn.group_norm is not None:
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

        query = attn.to_q(hidden_states)

        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        elif attn.norm_cross:
            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)

        inner_dim = key.shape[-1]
        head_dim = inner_dim // attn.heads

        # (batch, seq, heads * head_dim) -> (batch, heads, seq, head_dim)
        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

        # the output of sdp = (batch, num_heads, seq_len, head_dim)
        # TODO: add support for attn.scale when we move to Torch 2.1
        if self.fuSAttn and self.denoise_step <= self.end_fusion:
            assert query.shape[0] == 4
            # `self.name` may be None (the constructor default), so guard the substring test.
            if (
                self.name is not None
                and self._LOGGED_LAYER in self.name
                and self.denoise_step == self.end_fusion
            ):
                print("now: ", self.denoise_step, "end now:", self.end_fusion, "scale: ", self.fuScale)
            # Computed from the Python int so no tensor is allocated and no precision is lost.
            scale_factor = 1 / math.sqrt(head_dim)
            # NOTE(review): `attention_mask` is not applied on this explicit path.
            attn_probs = (torch.matmul(query, key.transpose(-2, -1)) * scale_factor).softmax(dim=-1)

            attn_probs[1] = self.fuScale * attn_probs[1] + (1 - self.fuScale) * attn_probs[0]
            attn_probs[3] = self.fuScale * attn_probs[3] + (1 - self.fuScale) * attn_probs[2]

            # Per-sample values are (heads, seq, head_dim), so dim 1 is the token axis.
            value[1] = self._adain(content=value[0], style=value[1])
            value[3] = self._adain(content=value[2], style=value[3])
            hidden_states = torch.matmul(attn_probs, value)
        else:
            hidden_states = F.scaled_dot_product_attention(
                query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
            )

        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
        hidden_states = hidden_states.to(query.dtype)

        # linear proj
        hidden_states = attn.to_out[0](hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)

        if input_ndim == 4:
            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)

        if attn.residual_connection:
            hidden_states = hidden_states + residual

        hidden_states = hidden_states / attn.rescale_output_factor

        self.reset_denoise_step()
        return hidden_states
1450
+
1451
class IPAttnProcessor2_0_cross_modal(torch.nn.Module):
    r"""
    Attention processor for IP-Adapter for PyTorch 2.0 with optional attention fusion.

    The trailing `num_tokens` entries of `encoder_hidden_states` are treated as image-prompt
    tokens and attended to through the extra `to_k_ip` / `to_v_ip` projections; the remaining
    entries go through the regular key/value projections of the attention layer.

    While the internal call counter has not passed `end_fusion`, `fuAttn` and `fuIPAttn` switch
    the text and image attention respectively to an explicit computation that requires a batch
    of exactly four samples and blends the attention probabilities of samples 1 and 3 with
    those of samples 0 and 2.

    Args:
        hidden_size (`int`):
            The hidden size of the attention layer.
        cross_attention_dim (`int`):
            The number of channels in the `encoder_hidden_states`.
        scale (`float`, defaults to 1.0):
            the weight scale of image prompt.
        num_tokens (`int`, defaults to 4 when do ip_adapter_plus it should be 16):
            The context length of the image features.
        skip (`bool`, defaults to `False`):
            When true the image-prompt branch is not evaluated.
        fuAttn (`bool`): Enables fusion of the text attention probabilities.
        fuIPAttn (`bool`): Enables fusion of the image attention probabilities.
        adainIP (`bool`): Combine text and image results with AdaIN instead of a weighted sum.
        end_fusion (`int`): Largest value of the call counter for which fusion is applied.
        fuScale (`float`): Weight kept from samples 1/3 when blending attention probabilities.
        attn_name (`str`, *optional*): Layer name, used for a progress message and a sanity check.
    """

    # Layer whose progress message is printed when the fusion window ends.
    _LOGGED_LAYER = "up_blocks.1.attentions.2.transformer_blocks.1"

    def __init__(self, hidden_size, cross_attention_dim=None, scale=1.0, num_tokens=4, skip=False,
                 fuAttn=False, fuIPAttn=False, adainIP=False, end_fusion=0, fuScale=0, attn_name=None):
        super().__init__()

        if not hasattr(F, "scaled_dot_product_attention"):
            raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")

        self.hidden_size = hidden_size
        self.cross_attention_dim = cross_attention_dim
        self.scale = scale
        self.num_tokens = num_tokens
        self.skip = skip

        self.fuAttn = fuAttn
        self.fuIPAttn = fuIPAttn
        self.adainIP = adainIP
        # The call counter starts at 0. Seeding it with `fuScale` (a blend weight) made the
        # `== 50` reset unreachable for fractional scales and shifted the fusion window.
        self.denoise_step = 0
        self.end_fusion = end_fusion
        self.fuScale = fuScale
        self.name = attn_name

        self.to_k_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
        self.to_v_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)

    def reset_denoise_step(self):
        """Wrap the call counter back to 0 once it has reached 50."""
        # NOTE(review): 50 is hard-coded; presumably the number of inference steps - confirm.
        if self.denoise_step == 50:
            self.denoise_step = 0

    @staticmethod
    def _adain(content, style):
        """Re-normalise `content` to the mean/std of `style`, both taken along dim 1."""
        # Statistics of the content features.
        content_mean = content.mean(dim=1, keepdim=True)
        content_std = content.std(dim=1, keepdim=True)
        # Statistics of the style features.
        style_mean = style.mean(dim=1, keepdim=True)
        style_std = style.std(dim=1, keepdim=True)
        # Normalise the content, then apply the style mean and standard deviation.
        normalized_content = (content - content_mean) / content_std
        return normalized_content * style_std + style_mean

    def __call__(
        self,
        attn,
        hidden_states,
        encoder_hidden_states=None,
        attention_mask=None,
        temb=None,
    ):
        """
        Run text attention plus the optional image-prompt attention for one call.

        Args:
            attn: Attention module providing `to_q`/`to_k`/`to_v`/`to_out`, `heads`, the optional
                norms and the residual/rescale settings.
            hidden_states: Input features, either 3-D or 4-D (4-D input is flattened and restored).
            encoder_hidden_states: Context whose last `num_tokens` entries are the image tokens.
                When `None`, `hidden_states` is used and the image branch is skipped.
            attention_mask: Optional mask, applied only on the SDPA text-attention path.
            temb: Optional embedding forwarded to `attn.spatial_norm`.

        Returns:
            The attended features with the same layout as `hidden_states`.

        Raises:
            RuntimeError: If image-attention fusion is active on a layer whose name contains "down".
        """
        self.denoise_step += 1
        residual = hidden_states

        if attn.spatial_norm is not None:
            hidden_states = attn.spatial_norm(hidden_states, temb)

        input_ndim = hidden_states.ndim

        if input_ndim == 4:
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

        batch_size, sequence_length, _ = (
            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
        )

        if attention_mask is not None:
            attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
            # scaled_dot_product_attention expects attention_mask shape to be
            # (batch, heads, source_length, target_length)
            attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])

        if attn.group_norm is not None:
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

        query = attn.to_q(hidden_states)

        # Stays None when no context is supplied, so the image branch can be skipped safely.
        ip_hidden_states = None
        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        else:
            # get encoder_hidden_states, ip_hidden_states
            end_pos = encoder_hidden_states.shape[1] - self.num_tokens
            encoder_hidden_states, ip_hidden_states = (
                encoder_hidden_states[:, :end_pos, :],
                encoder_hidden_states[:, end_pos:, :],
            )
            if attn.norm_cross:
                encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)

        inner_dim = key.shape[-1]
        head_dim = inner_dim // attn.heads

        # (batch, seq, heads * head_dim) -> (batch, heads, seq, head_dim)
        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

        fusion_active = self.denoise_step <= self.end_fusion
        # Computed from the Python int so no tensor is allocated and no precision is lost.
        scale_factor = 1 / math.sqrt(head_dim)

        # the output of sdp = (batch, num_heads, seq_len, head_dim)
        # TODO: add support for attn.scale when we move to Torch 2.1
        if self.fuAttn and fusion_active:
            assert query.shape[0] == 4
            # `self.name` may be None (the constructor default), so guard the substring test.
            if (
                self.name is not None
                and self._LOGGED_LAYER in self.name
                and self.denoise_step == self.end_fusion
            ):
                print("fuAttn")
                print("now: ", self.denoise_step, "end now:", self.end_fusion, "scale: ", self.fuScale)
            # NOTE(review): `attention_mask` is not applied on this explicit path.
            text_attn_probs = (torch.matmul(query, key.transpose(-2, -1)) * scale_factor).softmax(dim=-1)
            text_attn_probs[1] = self.fuScale * text_attn_probs[1] + (1 - self.fuScale) * text_attn_probs[0]
            text_attn_probs[3] = self.fuScale * text_attn_probs[3] + (1 - self.fuScale) * text_attn_probs[2]
            hidden_states = torch.matmul(text_attn_probs, value)
        else:
            hidden_states = F.scaled_dot_product_attention(
                query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
            )

        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
        hidden_states = hidden_states.to(query.dtype)

        # Kept so samples 0 and 2 can be restored to the text-only result further down.
        raw_hidden_states = hidden_states

        if not self.skip and ip_hidden_states is not None:
            # for ip-adapter
            ip_key = self.to_k_ip(ip_hidden_states)
            ip_value = self.to_v_ip(ip_hidden_states)

            ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
            ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

            # the output of sdp = (batch, num_heads, seq_len, head_dim)
            # TODO: add support for attn.scale when we move to Torch 2.1
            if self.fuIPAttn and fusion_active:
                assert query.shape[0] == 4
                if self.name is not None and "down" in self.name:
                    # Raise instead of exit(): a processor must not terminate the host process.
                    raise RuntimeError(
                        f"fuIPAttn fusion is not expected on down blocks (layer: {self.name})"
                    )
                ip_attn_probs = torch.matmul(query, ip_key.transpose(-2, -1)) * scale_factor
                ip_attn_probs = F.softmax(ip_attn_probs, dim=-1)
                ip_attn_probs[1] = self.fuScale * ip_attn_probs[1] + (1 - self.fuScale) * ip_attn_probs[0]
                ip_attn_probs[3] = self.fuScale * ip_attn_probs[3] + (1 - self.fuScale) * ip_attn_probs[2]
                ip_hidden_states = torch.matmul(ip_attn_probs, ip_value)
            else:
                ip_hidden_states = F.scaled_dot_product_attention(
                    query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False
                )

            with torch.no_grad():
                # NOTE(review): `.softmax` binds tighter than `@`, so this is
                # query @ softmax(ip_key^T), not softmax(query @ ip_key^T). Left unchanged because
                # consumers of `attn_map` are not visible here - confirm the intended formula.
                self.attn_map = query @ ip_key.transpose(-2, -1).softmax(dim=-1)

            ip_hidden_states = ip_hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
            ip_hidden_states = ip_hidden_states.to(query.dtype)

            if self.adainIP:
                hidden_states = self._adain(content=hidden_states, style=ip_hidden_states)
            else:
                hidden_states = hidden_states + self.scale * ip_hidden_states

            if hidden_states.shape[0] == 4:
                # Samples 0 and 2 keep the text-only attention result.
                hidden_states[0] = raw_hidden_states[0]
                hidden_states[2] = raw_hidden_states[2]

        # linear proj
        hidden_states = attn.to_out[0](hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)

        if input_ndim == 4:
            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)

        if attn.residual_connection:
            hidden_states = hidden_states + residual

        hidden_states = hidden_states / attn.rescale_output_factor

        self.reset_denoise_step()

        return hidden_states
ip_adapter/ip_adapter.py ADDED
@@ -0,0 +1,1757 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch.nn.functional as F
2
+ import os
3
+ from typing import List
4
+
5
+ import torch
6
+ from diffusers import StableDiffusionPipeline
7
+ from diffusers.pipelines.controlnet import MultiControlNetModel
8
+ from PIL import Image
9
+ from safetensors import safe_open
10
+ from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection
11
+ from torchvision import transforms
12
+ from .utils import is_torch2_available, get_generator
13
+
14
+ if is_torch2_available():
15
+ from .attention_processor import (
16
+ AttnProcessor2_0 as AttnProcessor,
17
+ )
18
+ from .attention_processor import (
19
+ CNAttnProcessor2_0 as CNAttnProcessor,
20
+ )
21
+ from .attention_processor import (
22
+ IPAttnProcessor2_0 as IPAttnProcessor,
23
+ )
24
+ from .attention_processor import IP_CS_AttnProcessor2_0 as IP_CS_AttnProcessor
25
+ from .attention_processor import IP_FuAd_AttnProcessor2_0 as IP_FuAd_AttnProcessor
26
+ from .attention_processor import IP_FuAd_AttnProcessor2_0_exp as IP_FuAd_AttnProcessor_exp
27
+ from .attention_processor import AttnProcessor2_0_exp as AttnProcessor_exp
28
+ from .attention_processor import AttnProcessor2_0_hijack as AttnProcessor_hijack
29
+ from .attention_processor import IPAttnProcessor2_0_cross_modal as IPAttnProcessor_cross_modal
30
+ else:
31
+ from .attention_processor import AttnProcessor, CNAttnProcessor, IPAttnProcessor
32
+
33
+ from .resampler import Resampler
34
+
35
+ from transformers import AutoImageProcessor, AutoModel
36
+
37
+
38
class ImageProjModel(torch.nn.Module):
    """Project one image embedding into a fixed number of normalised context tokens."""

    def __init__(self, cross_attention_dim=1024, clip_embeddings_dim=1024, clip_extra_context_tokens=4):
        super().__init__()

        self.generator = None
        self.cross_attention_dim = cross_attention_dim
        self.clip_extra_context_tokens = clip_extra_context_tokens
        # One linear layer emits all tokens at once; they are split apart in forward().
        flat_token_width = self.clip_extra_context_tokens * cross_attention_dim
        self.proj = torch.nn.Linear(clip_embeddings_dim, flat_token_width)
        self.norm = torch.nn.LayerNorm(cross_attention_dim)

    def forward(self, image_embeds):
        """Return layer-normalised tokens of shape (-1, clip_extra_context_tokens, cross_attention_dim)."""
        token_shape = (-1, self.clip_extra_context_tokens, self.cross_attention_dim)
        tokens = self.proj(image_embeds).reshape(*token_shape)
        return self.norm(tokens)
58
+
59
+
60
class MLPProjModel(torch.nn.Module):
    """Two-layer MLP (Linear, GELU, Linear) followed by LayerNorm for projecting image embeddings."""

    def __init__(self, cross_attention_dim=1024, clip_embeddings_dim=1024):
        super().__init__()

        # Layer order fixes the `proj.<index>` keys of the state dict; keep it stable.
        stages = [
            torch.nn.Linear(clip_embeddings_dim, clip_embeddings_dim),
            torch.nn.GELU(),
            torch.nn.Linear(clip_embeddings_dim, cross_attention_dim),
            torch.nn.LayerNorm(cross_attention_dim),
        ]
        self.proj = torch.nn.Sequential(*stages)

    def forward(self, image_embeds):
        """Return the projected embedding; the last dimension becomes `cross_attention_dim`."""
        return self.proj(image_embeds)
76
+
77
+
78
class IPAdapter:
    """
    Wrap a diffusion pipeline with IP-Adapter image-prompt conditioning.

    On construction the pipeline is moved to `device`, the attention processors of its UNet
    (and of its ControlNet(s), when the pipeline has any) are replaced, a CLIP vision encoder and
    an image projection model are created in float16, and the adapter weights are loaded from
    `ip_ckpt`.

    Args:
        sd_pipe: Pipeline exposing `unet`, `encode_prompt` and `__call__`.
        image_encoder_path: Path or id passed to `CLIPVisionModelWithProjection.from_pretrained`.
        ip_ckpt: Adapter checkpoint; `.safetensors` files are read with `safe_open`, anything
            else with `torch.load`.
        device: Device for the pipeline and all adapter modules.
        num_tokens: Number of image-prompt tokens appended to the text embeddings.
        target_blocks: Substrings of attention-processor names; cross-attention layers whose
            name contains one of them get an active image branch, all others are built with
            `skip=True`.
    """

    def __init__(self, sd_pipe, image_encoder_path, ip_ckpt, device, num_tokens=4, target_blocks=["block"]):
        self.device = device
        self.image_encoder_path = image_encoder_path
        self.ip_ckpt = ip_ckpt
        self.num_tokens = num_tokens
        # Copy so the shared default list (or a caller's list) is never aliased by the instance.
        self.target_blocks = list(target_blocks)

        self.pipe = sd_pipe.to(self.device)
        self.set_ip_adapter()

        # load image encoder
        self.image_encoder = CLIPVisionModelWithProjection.from_pretrained(self.image_encoder_path).to(
            self.device, dtype=torch.float16
        )
        self.clip_image_processor = CLIPImageProcessor()
        # image proj model
        self.image_proj_model = self.init_proj()

        self.load_ip_adapter()

    def init_proj(self):
        """Build the float16 projection model sized from the UNet and image-encoder configs."""
        image_proj_model = ImageProjModel(
            cross_attention_dim=self.pipe.unet.config.cross_attention_dim,
            clip_embeddings_dim=self.image_encoder.config.projection_dim,
            clip_extra_context_tokens=self.num_tokens,
        ).to(self.device, dtype=torch.float16)
        return image_proj_model

    def set_ip_adapter(self):
        """
        Install attention processors on the UNet and on any ControlNet of the pipeline.

        Raises:
            ValueError: If a cross-attention processor name starts with none of `mid_block`,
                `up_blocks` or `down_blocks`, so its hidden size cannot be determined.
        """
        unet = self.pipe.unet
        attn_procs = {}
        for name in unet.attn_processors.keys():
            # Names ending in "attn1.processor" are treated as self-attention (no context dim).
            cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim

            # Reset per layer so a value from a previous iteration is never reused by accident.
            hidden_size = None
            if name.startswith("mid_block"):
                hidden_size = unet.config.block_out_channels[-1]
            elif name.startswith("up_blocks"):
                block_id = int(name[len("up_blocks.")])
                hidden_size = list(reversed(unet.config.block_out_channels))[block_id]
            elif name.startswith("down_blocks"):
                block_id = int(name[len("down_blocks.")])
                hidden_size = unet.config.block_out_channels[block_id]

            if cross_attention_dim is None:
                attn_procs[name] = AttnProcessor()
                continue

            if hidden_size is None:
                raise ValueError(f"Cannot determine hidden size for attention processor {name!r}")

            processor_kwargs = {
                "hidden_size": hidden_size,
                "cross_attention_dim": cross_attention_dim,
                "scale": 1.0,
                "num_tokens": self.num_tokens,
            }
            # Layers outside the target blocks keep their projections but skip the image branch.
            if not any(block_name in name for block_name in self.target_blocks):
                processor_kwargs["skip"] = True
            attn_procs[name] = IPAttnProcessor(**processor_kwargs).to(self.device, dtype=torch.float16)

        unet.set_attn_processor(attn_procs)
        if hasattr(self.pipe, "controlnet"):
            if isinstance(self.pipe.controlnet, MultiControlNetModel):
                for controlnet in self.pipe.controlnet.nets:
                    controlnet.set_attn_processor(CNAttnProcessor(num_tokens=self.num_tokens))
            else:
                self.pipe.controlnet.set_attn_processor(CNAttnProcessor(num_tokens=self.num_tokens))

    def load_ip_adapter(self):
        """Load the `image_proj` and `ip_adapter` weight groups from `self.ip_ckpt`."""
        if os.path.splitext(self.ip_ckpt)[-1] == ".safetensors":
            state_dict = {"image_proj": {}, "ip_adapter": {}}
            with safe_open(self.ip_ckpt, framework="pt", device="cpu") as f:
                for key in f.keys():
                    if key.startswith("image_proj."):
                        state_dict["image_proj"][key.replace("image_proj.", "")] = f.get_tensor(key)
                    elif key.startswith("ip_adapter."):
                        state_dict["ip_adapter"][key.replace("ip_adapter.", "")] = f.get_tensor(key)
        else:
            # SECURITY NOTE: torch.load unpickles the file; only load checkpoints from trusted sources.
            state_dict = torch.load(self.ip_ckpt, map_location="cpu")
        self.image_proj_model.load_state_dict(state_dict["image_proj"])
        ip_layers = torch.nn.ModuleList(self.pipe.unet.attn_processors.values())
        # strict=False: not every processor in the list has adapter weights in the checkpoint.
        ip_layers.load_state_dict(state_dict["ip_adapter"], strict=False)

    @torch.inference_mode()
    def get_image_embeds(self, pil_image=None, clip_image_embeds=None, content_prompt_embeds=None):
        """
        Return `(image_prompt_embeds, uncond_image_prompt_embeds)` for the given image input.

        Args:
            pil_image: A PIL image or a list of them; encoded with the CLIP vision encoder.
            clip_image_embeds: Precomputed image embeddings, used when `pil_image` is `None`.
            content_prompt_embeds: Optional embeddings subtracted from the image embeddings
                before projection.

        Raises:
            ValueError: If both `pil_image` and `clip_image_embeds` are `None`.
        """
        if pil_image is None and clip_image_embeds is None:
            raise ValueError("Either `pil_image` or `clip_image_embeds` must be provided.")

        if pil_image is not None:
            if isinstance(pil_image, Image.Image):
                pil_image = [pil_image]
            clip_image = self.clip_image_processor(images=pil_image, return_tensors="pt").pixel_values
            clip_image_embeds = self.image_encoder(clip_image.to(self.device, dtype=torch.float16)).image_embeds
        else:
            clip_image_embeds = clip_image_embeds.to(self.device, dtype=torch.float16)

        if content_prompt_embeds is not None:
            clip_image_embeds = clip_image_embeds - content_prompt_embeds

        image_prompt_embeds = self.image_proj_model(clip_image_embeds)
        # The unconditional branch projects an all-zero embedding of the same shape.
        uncond_image_prompt_embeds = self.image_proj_model(torch.zeros_like(clip_image_embeds))
        return image_prompt_embeds, uncond_image_prompt_embeds

    def set_scale(self, scale):
        """Set the image-prompt weight on every `IPAttnProcessor` of the UNet."""
        for attn_processor in self.pipe.unet.attn_processors.values():
            if isinstance(attn_processor, IPAttnProcessor):
                attn_processor.scale = scale

    def generate(
        self,
        pil_image=None,
        clip_image_embeds=None,
        prompt=None,
        negative_prompt=None,
        scale=1.0,
        num_samples=4,
        seed=None,
        guidance_scale=7.5,
        num_inference_steps=30,
        neg_content_emb=None,
        **kwargs,
    ):
        """
        Generate images conditioned on a text prompt and an image prompt.

        Args:
            pil_image: A PIL image or a list of them used as the image prompt.
            clip_image_embeds: Precomputed image embeddings, used when `pil_image` is `None`.
            prompt: Text prompt or list of prompts; a default quality prompt is used when `None`.
            negative_prompt: Negative prompt or list; a default is used when `None`.
            scale: Image-prompt weight applied through `set_scale`.
            num_samples: Images generated per image prompt.
            seed: Seed passed to `get_generator`.
            guidance_scale: Forwarded to the pipeline.
            num_inference_steps: Forwarded to the pipeline.
            neg_content_emb: Optional embeddings subtracted from the image embeddings.
            **kwargs: Extra arguments forwarded to the pipeline call.

        Returns:
            The `images` attribute of the pipeline output.

        Raises:
            ValueError: If both `pil_image` and `clip_image_embeds` are `None`.
        """
        # Validate up front so the failure is explicit, not an AttributeError on None.
        if pil_image is None and clip_image_embeds is None:
            raise ValueError("Either `pil_image` or `clip_image_embeds` must be provided.")

        self.set_scale(scale)

        if pil_image is not None:
            num_prompts = 1 if isinstance(pil_image, Image.Image) else len(pil_image)
        else:
            num_prompts = clip_image_embeds.size(0)

        if prompt is None:
            prompt = "best quality, high quality"
        if negative_prompt is None:
            negative_prompt = "monochrome, lowres, bad anatomy, worst quality, low quality"

        if not isinstance(prompt, list):
            prompt = [prompt] * num_prompts
        if not isinstance(negative_prompt, list):
            negative_prompt = [negative_prompt] * num_prompts

        image_prompt_embeds, uncond_image_prompt_embeds = self.get_image_embeds(
            pil_image=pil_image, clip_image_embeds=clip_image_embeds, content_prompt_embeds=neg_content_emb
        )
        # Duplicate each image prompt `num_samples` times along the batch dimension.
        bs_embed, seq_len, _ = image_prompt_embeds.shape
        image_prompt_embeds = image_prompt_embeds.repeat(1, num_samples, 1)
        image_prompt_embeds = image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1)
        uncond_image_prompt_embeds = uncond_image_prompt_embeds.repeat(1, num_samples, 1)
        uncond_image_prompt_embeds = uncond_image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1)

        with torch.inference_mode():
            prompt_embeds_, negative_prompt_embeds_ = self.pipe.encode_prompt(
                prompt,
                device=self.device,
                num_images_per_prompt=num_samples,
                do_classifier_free_guidance=True,
                negative_prompt=negative_prompt,
            )
            # Image tokens are appended after the text tokens along the sequence dimension.
            prompt_embeds = torch.cat([prompt_embeds_, image_prompt_embeds], dim=1)
            negative_prompt_embeds = torch.cat([negative_prompt_embeds_, uncond_image_prompt_embeds], dim=1)

        generator = get_generator(seed, self.device)

        images = self.pipe(
            prompt_embeds=prompt_embeds,
            negative_prompt_embeds=negative_prompt_embeds,
            guidance_scale=guidance_scale,
            num_inference_steps=num_inference_steps,
            generator=generator,
            **kwargs,
        ).images

        return images
251
+
252
+
253
+ class IPAdapter_CS:
254
+ def __init__(self, sd_pipe, image_encoder_path, ip_ckpt, device, num_content_tokens=4,
255
+ num_style_tokens=4,
256
+ target_content_blocks=["block"], target_style_blocks=["block"], content_image_encoder_path=None,
257
+ controlnet_adapter=False,
258
+ controlnet_target_content_blocks=None,
259
+ controlnet_target_style_blocks=None,
260
+ content_model_resampler=False,
261
+ style_model_resampler=False,
262
+ ):
263
+ self.device = device
264
+ self.image_encoder_path = image_encoder_path
265
+ self.ip_ckpt = ip_ckpt
266
+ self.num_content_tokens = num_content_tokens
267
+ self.num_style_tokens = num_style_tokens
268
+ self.content_target_blocks = target_content_blocks
269
+ self.style_target_blocks = target_style_blocks
270
+
271
+ self.content_model_resampler = content_model_resampler
272
+ self.style_model_resampler = style_model_resampler
273
+
274
+ self.controlnet_adapter = controlnet_adapter
275
+ self.controlnet_target_content_blocks = controlnet_target_content_blocks
276
+ self.controlnet_target_style_blocks = controlnet_target_style_blocks
277
+
278
+ self.pipe = sd_pipe.to(self.device)
279
+ self.set_ip_adapter()
280
+ self.content_image_encoder_path = content_image_encoder_path
281
+
282
+
283
+ # load image encoder
284
+ if content_image_encoder_path is not None:
285
+ self.content_image_encoder = AutoModel.from_pretrained(content_image_encoder_path).to(self.device,
286
+ dtype=torch.float16)
287
+ self.content_image_processor = AutoImageProcessor.from_pretrained(content_image_encoder_path)
288
+ else:
289
+ self.content_image_encoder = CLIPVisionModelWithProjection.from_pretrained(self.image_encoder_path).to(
290
+ self.device, dtype=torch.float16
291
+ )
292
+ self.content_image_processor = CLIPImageProcessor()
293
+ # model.requires_grad_(False)
294
+
295
+ self.image_encoder = CLIPVisionModelWithProjection.from_pretrained(self.image_encoder_path).to(
296
+ self.device, dtype=torch.float16
297
+ )
298
+ # if self.use_CSD is not None:
299
+ # self.style_image_encoder = CSD_CLIP("vit_large", "default",self.use_CSD+"/ViT-L-14.pt")
300
+ # model_path = self.use_CSD+"/checkpoint.pth"
301
+ # checkpoint = torch.load(model_path, map_location="cpu")
302
+ # state_dict = convert_state_dict(checkpoint['model_state_dict'])
303
+ # self.style_image_encoder.load_state_dict(state_dict, strict=False)
304
+ #
305
+ # normalize = transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711))
306
+ # self.style_preprocess = transforms.Compose([
307
+ # transforms.Resize(size=224, interpolation=Func.InterpolationMode.BICUBIC),
308
+ # transforms.CenterCrop(224),
309
+ # transforms.ToTensor(),
310
+ # normalize,
311
+ # ])
312
+
313
+ self.clip_image_processor = CLIPImageProcessor()
314
+ # image proj model
315
+ self.content_image_proj_model = self.init_proj(self.num_content_tokens, content_or_style_='content',
316
+ model_resampler=self.content_model_resampler)
317
+ self.style_image_proj_model = self.init_proj(self.num_style_tokens, content_or_style_='style',
318
+ model_resampler=self.style_model_resampler)
319
+
320
+ self.load_ip_adapter()
321
+
322
+ def init_proj(self, num_tokens, content_or_style_='content', model_resampler=False):
323
+
324
+ # print('@@@@',self.pipe.unet.config.cross_attention_dim,self.image_encoder.config.projection_dim)
325
+ if content_or_style_ == 'content' and self.content_image_encoder_path is not None:
326
+ image_proj_model = ImageProjModel(
327
+ cross_attention_dim=self.pipe.unet.config.cross_attention_dim,
328
+ clip_embeddings_dim=self.content_image_encoder.config.projection_dim,
329
+ clip_extra_context_tokens=num_tokens,
330
+ ).to(self.device, dtype=torch.float16)
331
+ return image_proj_model
332
+
333
+ image_proj_model = ImageProjModel(
334
+ cross_attention_dim=self.pipe.unet.config.cross_attention_dim,
335
+ clip_embeddings_dim=self.image_encoder.config.projection_dim,
336
+ clip_extra_context_tokens=num_tokens,
337
+ ).to(self.device, dtype=torch.float16)
338
+ return image_proj_model
339
+
340
def set_ip_adapter(self):
    """Install content/style attention processors on the UNet and, if present, the ControlNet.

    Self-attention layers (names ending in ``attn1.processor``) get a plain
    ``AttnProcessor``.  Every cross-attention layer gets an
    ``IP_CS_AttnProcessor`` configured as:

    * style processor when the layer name contains one of the style blocks,
    * content processor when it only matches a content block,
    * style processor with content enabled via ``set_content_ipa`` when it
      matches both,
    * ``skip=True`` processor when it matches neither.

    When the pipeline has a ``controlnet`` attribute, the ControlNet either
    receives ``CNAttnProcessor`` instances (``controlnet_adapter is False``)
    or the same per-layer treatment using the ControlNet target blocks.

    Fix over the previous revision: the UNet loop never marked a layer as
    selected after creating its content-only processor, so that processor
    was immediately replaced by the ``skip=True`` fallback.  Both loops now
    share one builder, which follows the ControlNet loop's behavior.
    """

    def build_processors(model, style_blocks, content_blocks):
        """Return a ``{layer_name: processor}`` dict for every attention layer of ``model``."""
        processors = {}
        config = model.config
        for name in model.attn_processors.keys():
            cross_attention_dim = (
                None if name.endswith("attn1.processor") else config.cross_attention_dim
            )
            # The channel width is derived from the block index embedded in the layer name.
            if name.startswith("mid_block"):
                hidden_size = config.block_out_channels[-1]
            elif name.startswith("up_blocks"):
                block_id = int(name[len("up_blocks.")])
                hidden_size = list(reversed(config.block_out_channels))[block_id]
            elif name.startswith("down_blocks"):
                block_id = int(name[len("down_blocks.")])
                hidden_size = config.block_out_channels[block_id]

            if cross_attention_dim is None:
                processors[name] = AttnProcessor()
                continue

            selected = False
            for block_name in style_blocks:
                if block_name in name:
                    selected = True
                    processors[name] = IP_CS_AttnProcessor(
                        hidden_size=hidden_size,
                        cross_attention_dim=cross_attention_dim,
                        style_scale=1.0,
                        style=True,
                        num_content_tokens=self.num_content_tokens,
                        num_style_tokens=self.num_style_tokens,
                    )

            for block_name in content_blocks:
                if block_name in name:
                    if selected is False:
                        processors[name] = IP_CS_AttnProcessor(
                            hidden_size=hidden_size,
                            cross_attention_dim=cross_attention_dim,
                            content_scale=1.0,
                            content=True,
                            num_content_tokens=self.num_content_tokens,
                            num_style_tokens=self.num_style_tokens,
                        )
                        # Mark the layer as handled so the skip fallback
                        # below does not discard the content processor.
                        selected = True
                    else:
                        # A processor already exists for this layer:
                        # enable content injection on it as well.
                        processors[name].set_content_ipa(content_scale=1.0)

            if selected is False:
                # Layer is not targeted by either branch.
                processors[name] = IP_CS_AttnProcessor(
                    hidden_size=hidden_size,
                    cross_attention_dim=cross_attention_dim,
                    num_content_tokens=self.num_content_tokens,
                    num_style_tokens=self.num_style_tokens,
                    skip=True,
                )

            processors[name].to(self.device, dtype=torch.float16)
        return processors

    unet = self.pipe.unet
    unet.set_attn_processor(
        build_processors(unet, self.style_target_blocks, self.content_target_blocks)
    )

    if not hasattr(self.pipe, "controlnet"):
        return

    if self.controlnet_adapter is False:
        # No trained ControlNet adapter: install CNAttnProcessor, told how
        # many image tokens (content + style) are appended to the prompt.
        if isinstance(self.pipe.controlnet, MultiControlNetModel):
            for controlnet in self.pipe.controlnet.nets:
                controlnet.set_attn_processor(CNAttnProcessor(
                    num_tokens=self.num_content_tokens + self.num_style_tokens))
        else:
            self.pipe.controlnet.set_attn_processor(CNAttnProcessor(
                num_tokens=self.num_content_tokens + self.num_style_tokens))
    else:
        controlnet = self.pipe.controlnet
        controlnet.set_attn_processor(
            build_processors(
                controlnet,
                self.controlnet_target_style_blocks,
                self.controlnet_target_content_blocks,
            )
        )
482
+
483
def load_ip_adapter(self):
    """Load projector, attention-processor and optional ControlNet weights from ``self.ip_ckpt``.

    ``.safetensors`` checkpoints are regrouped by key prefix into the
    ``content_image_proj`` / ``style_image_proj`` / ``ip_adapter``
    sub-dicts; any other file is read with ``torch.load`` and is expected
    to already have that layout.
    """
    if os.path.splitext(self.ip_ckpt)[-1] != ".safetensors":
        # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
        state_dict = torch.load(self.ip_ckpt, map_location="cpu")
    else:
        state_dict = {"content_image_proj": {}, "style_image_proj": {}, "ip_adapter": {}}
        groups = ("content_image_proj", "style_image_proj", "ip_adapter")
        with safe_open(self.ip_ckpt, framework="pt", device="cpu") as f:
            for key in f.keys():
                for group in groups:
                    prefix = group + "."
                    if key.startswith(prefix):
                        state_dict[group][key.replace(prefix, "")] = f.get_tensor(key)
                        break

    self.content_image_proj_model.load_state_dict(state_dict["content_image_proj"])
    self.style_image_proj_model.load_state_dict(state_dict["style_image_proj"])

    # Some checkpoints also ship replacement weights for the UNet input conv.
    if 'conv_in_unet_sd' in state_dict:
        self.pipe.unet.conv_in.load_state_dict(state_dict["conv_in_unet_sd"], strict=True)

    # Processors are loaded as one ModuleList; strict=False tolerates
    # processors that carry no matching weights.
    processor_modules = torch.nn.ModuleList(self.pipe.unet.attn_processors.values())
    processor_modules.load_state_dict(state_dict["ip_adapter"], strict=False)

    if self.controlnet_adapter is True:
        print('loading controlnet_adapter')
        self.pipe.controlnet.load_state_dict(state_dict["controlnet_adapter_modules"], strict=False)
507
+
508
@torch.inference_mode()
def get_image_embeds(self, pil_image=None, clip_image_embeds=None, content_prompt_embeds=None,
                     content_or_style_=''):
    """Project a content or style image into conditional/unconditional prompt embeddings.

    Args:
        pil_image: A PIL image or a list of them; when ``None`` the
            precomputed ``clip_image_embeds`` are used instead.
        clip_image_embeds: Precomputed image embeddings (used only when
            ``pil_image`` is ``None``).
        content_prompt_embeds: Unused by this implementation.
        content_or_style_: ``'content'`` or ``'style'``.

    Returns:
        ``(image_prompt_embeds, uncond_image_prompt_embeds)``; the
        unconditional half is the projection of an all-zero embedding.
        Returns ``None`` for any other ``content_or_style_`` value.
    """
    for_content = content_or_style_ == 'content'
    if not for_content and content_or_style_ != 'style':
        return None

    if pil_image is None:
        embeds = clip_image_embeds.to(self.device, dtype=torch.float16)
    elif for_content:
        if isinstance(pil_image, Image.Image):
            pil_image = [pil_image]
        if self.content_image_proj_model is not None:
            pixels = self.content_image_processor(images=pil_image, return_tensors="pt").pixel_values
            embeds = self.content_image_encoder(
                pixels.to(self.device, dtype=torch.float16)).image_embeds
        else:
            pixels = self.clip_image_processor(images=pil_image, return_tensors="pt").pixel_values
            embeds = self.image_encoder(pixels.to(self.device, dtype=torch.float16)).image_embeds
    elif self.use_CSD is not None:
        # CSD style encoder path: the image is preprocessed as a single
        # sample and fed in float32.
        pixels = self.style_preprocess(pil_image).unsqueeze(0).to(self.device, dtype=torch.float32)
        embeds = self.style_image_encoder(pixels)
    else:
        if isinstance(pil_image, Image.Image):
            pil_image = [pil_image]
        pixels = self.clip_image_processor(images=pil_image, return_tensors="pt").pixel_values
        embeds = self.image_encoder(pixels.to(self.device, dtype=torch.float16)).image_embeds

    projector = self.content_image_proj_model if for_content else self.style_image_proj_model
    image_prompt_embeds = projector(embeds)
    uncond_image_prompt_embeds = projector(torch.zeros_like(embeds))
    return image_prompt_embeds, uncond_image_prompt_embeds
556
+
557
def set_scale(self, content_scale, style_scale):
    """Set the content/style strength on every ``IP_CS_AttnProcessor``.

    Applies to the UNet and, when the pipeline has one, to the ControlNet
    (each sub-net of a ``MultiControlNetModel``).  A processor's
    ``content_scale`` is only touched when its ``content`` flag is
    ``True``; likewise ``style_scale`` and ``style``.

    The previous revision read ``self.pipe.controlnet`` unconditionally
    (its ``controlnet_adapter is not None`` check is always true for the
    boolean flag), raising ``AttributeError`` on pipelines without a
    ControlNet; the access is now guarded like in ``set_ip_adapter``.
    """

    def apply(processors):
        # Update only the scales that are enabled on each IP processor.
        for processor in processors:
            if not isinstance(processor, IP_CS_AttnProcessor):
                continue
            if processor.content is True:
                processor.content_scale = content_scale
            if processor.style is True:
                processor.style_scale = style_scale

    apply(self.pipe.unet.attn_processors.values())

    controlnet = getattr(self.pipe, "controlnet", None)
    if self.controlnet_adapter is None or controlnet is None:
        return

    # A MultiControlNetModel wraps several nets; visit each one.
    if isinstance(controlnet, MultiControlNetModel):
        nets = controlnet.nets
    else:
        nets = [controlnet]
    for net in nets:
        apply(net.attn_processors.values())
576
+
577
def generate(
    self,
    pil_content_image=None,
    pil_style_image=None,
    clip_content_image_embeds=None,
    clip_style_image_embeds=None,
    prompt=None,
    negative_prompt=None,
    content_scale=1.0,
    style_scale=1.0,
    num_samples=4,
    seed=None,
    guidance_scale=7.5,
    num_inference_steps=30,
    neg_content_emb=None,
    **kwargs,
):
    """Generate images conditioned on a text prompt plus content and style images.

    Args:
        pil_content_image / pil_style_image: PIL image(s) for each branch;
            may be ``None`` when the matching ``clip_*_image_embeds`` is given.
        clip_content_image_embeds / clip_style_image_embeds: Precomputed
            image embeddings used when the PIL image is ``None``.
        prompt / negative_prompt: Text prompts (string or list); defaults
            are filled in when ``None``.
        content_scale / style_scale: Strengths forwarded to ``set_scale``.
        num_samples: Images generated per prompt.
        seed: Seed passed to ``get_generator``.
        guidance_scale / num_inference_steps: Forwarded to the pipeline.
        neg_content_emb: Unused by this implementation.
        **kwargs: Extra arguments forwarded to the pipeline call.

    Returns:
        The ``images`` attribute of the pipeline output.

    Fixes over the previous revision: ``get_image_embeds`` is now told
    which branch to compute (it returned ``None`` without the selector, so
    unpacking failed), and the style token count is read from the style
    embeddings rather than from the content embeddings.
    """
    self.set_scale(content_scale, style_scale)

    if pil_content_image is not None:
        num_prompts = 1 if isinstance(pil_content_image, Image.Image) else len(pil_content_image)
    else:
        num_prompts = clip_content_image_embeds.size(0)

    if prompt is None:
        prompt = "best quality, high quality"
    if negative_prompt is None:
        negative_prompt = "monochrome, lowres, bad anatomy, worst quality, low quality"

    if not isinstance(prompt, List):
        prompt = [prompt] * num_prompts
    if not isinstance(negative_prompt, List):
        negative_prompt = [negative_prompt] * num_prompts

    content_image_prompt_embeds, uncond_content_image_prompt_embeds = self.get_image_embeds(
        pil_image=pil_content_image, clip_image_embeds=clip_content_image_embeds,
        content_or_style_='content',
    )
    style_image_prompt_embeds, uncond_style_image_prompt_embeds = self.get_image_embeds(
        pil_image=pil_style_image, clip_image_embeds=clip_style_image_embeds,
        content_or_style_='style',
    )

    # Duplicate each image embedding num_samples times along the batch axis.
    bs_embed, seq_len, _ = content_image_prompt_embeds.shape
    content_image_prompt_embeds = content_image_prompt_embeds.repeat(1, num_samples, 1)
    content_image_prompt_embeds = content_image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1)
    uncond_content_image_prompt_embeds = uncond_content_image_prompt_embeds.repeat(1, num_samples, 1)
    uncond_content_image_prompt_embeds = uncond_content_image_prompt_embeds.view(
        bs_embed * num_samples, seq_len, -1)

    # The style branch may use a different token count than the content branch.
    _, seq_style_len, _ = style_image_prompt_embeds.shape
    style_image_prompt_embeds = style_image_prompt_embeds.repeat(1, num_samples, 1)
    style_image_prompt_embeds = style_image_prompt_embeds.view(bs_embed * num_samples, seq_style_len, -1)
    uncond_style_image_prompt_embeds = uncond_style_image_prompt_embeds.repeat(1, num_samples, 1)
    uncond_style_image_prompt_embeds = uncond_style_image_prompt_embeds.view(
        bs_embed * num_samples, seq_style_len, -1)

    with torch.inference_mode():
        prompt_embeds_, negative_prompt_embeds_ = self.pipe.encode_prompt(
            prompt,
            device=self.device,
            num_images_per_prompt=num_samples,
            do_classifier_free_guidance=True,
            negative_prompt=negative_prompt,
        )
        # Image tokens are appended after the text tokens: content first, then style.
        prompt_embeds = torch.cat(
            [prompt_embeds_, content_image_prompt_embeds, style_image_prompt_embeds], dim=1)
        negative_prompt_embeds = torch.cat(
            [negative_prompt_embeds_, uncond_content_image_prompt_embeds, uncond_style_image_prompt_embeds],
            dim=1)

    generator = get_generator(seed, self.device)

    images = self.pipe(
        prompt_embeds=prompt_embeds,
        negative_prompt_embeds=negative_prompt_embeds,
        guidance_scale=guidance_scale,
        num_inference_steps=num_inference_steps,
        generator=generator,
        **kwargs,
    ).images

    return images
657
+
658
+
659
class IPAdapterXL_CS(IPAdapter_CS):
    """SDXL variant of the content/style adapter (also handles pooled prompt embeddings)."""

    def generate(
            self,
            pil_content_image,
            pil_style_image,
            prompt=None,
            negative_prompt=None,
            content_scale=1.0,
            style_scale=1.0,
            num_samples=4,
            seed=None,
            content_image_embeds=None,
            style_image_embeds=None,
            num_inference_steps=30,
            neg_content_emb=None,
            neg_content_prompt=None,
            neg_content_scale=1.0,

            **kwargs,
    ):
        """Generate images from a prompt plus one content and one style image.

        ``seed`` and the ``neg_content_*`` arguments are accepted but not
        used here; no generator is passed to the pipeline unless the
        caller supplies one through ``**kwargs``.

        Returns:
            The ``images`` attribute of the pipeline output.
        """
        self.set_scale(content_scale, style_scale)

        if isinstance(pil_content_image, Image.Image):
            num_prompts = 1
        else:
            num_prompts = len(pil_content_image)

        prompt = "best quality, high quality" if prompt is None else prompt
        if negative_prompt is None:
            negative_prompt = "monochrome, lowres, bad anatomy, worst quality, low quality"

        if not isinstance(prompt, List):
            prompt = [prompt] * num_prompts
        if not isinstance(negative_prompt, List):
            negative_prompt = [negative_prompt] * num_prompts

        content_cond, content_uncond = self.get_image_embeds(
            pil_content_image, content_image_embeds, content_or_style_='content')
        style_cond, style_uncond = self.get_image_embeds(
            pil_style_image, style_image_embeds, content_or_style_='style')

        def tile(embeds, batch, tokens):
            # Duplicate every embedding num_samples times along the batch axis.
            return embeds.repeat(1, num_samples, 1).view(batch * num_samples, tokens, -1)

        bs_embed, seq_len, _ = content_cond.shape
        content_cond = tile(content_cond, bs_embed, seq_len)
        content_uncond = tile(content_uncond, bs_embed, seq_len)

        # The style branch has its own token count; the batch size is
        # taken from the content branch.
        _, seq_style_len, _ = style_cond.shape
        style_cond = tile(style_cond, bs_embed, seq_style_len)
        style_uncond = tile(style_uncond, bs_embed, seq_style_len)

        with torch.inference_mode():
            encoded = self.pipe.encode_prompt(
                prompt,
                num_images_per_prompt=num_samples,
                do_classifier_free_guidance=True,
                negative_prompt=negative_prompt,
            )
            (
                prompt_embeds,
                negative_prompt_embeds,
                pooled_prompt_embeds,
                negative_pooled_prompt_embeds,
            ) = encoded
            # Image tokens follow the text tokens: content first, then style.
            prompt_embeds = torch.cat([prompt_embeds, content_cond, style_cond], dim=1)
            negative_prompt_embeds = torch.cat(
                [negative_prompt_embeds, content_uncond, style_uncond], dim=1)

        output = self.pipe(
            prompt_embeds=prompt_embeds,
            negative_prompt_embeds=negative_prompt_embeds,
            pooled_prompt_embeds=pooled_prompt_embeds,
            negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
            num_inference_steps=num_inference_steps,
            **kwargs,
        )
        return output.images
752
+
753
+
754
class CSGO(IPAdapterXL_CS):
    """SDXL content/style adapter whose projectors can be ``Resampler`` models."""

    def init_proj(self, num_tokens, content_or_style_='content', model_resampler=False):
        """Create the image projector for the requested branch.

        Both branches are built the same way: a ``Resampler`` sized from
        ``self.content_image_encoder.config.hidden_size`` when
        ``model_resampler`` is truthy, otherwise an ``ImageProjModel``
        sized from ``self.image_encoder.config.projection_dim``.  The
        result is moved to ``self.device`` in float16.
        """
        unet_dim = self.pipe.unet.config.cross_attention_dim
        # Only the two known branch names build a projector; any other
        # value leaves the local unassigned, exactly as before.
        if content_or_style_ in ('content', 'style'):
            if model_resampler:
                image_proj_model = Resampler(
                    dim=unet_dim,
                    depth=4,
                    dim_head=64,
                    heads=12,
                    num_queries=num_tokens,
                    embedding_dim=self.content_image_encoder.config.hidden_size,
                    output_dim=self.pipe.unet.config.cross_attention_dim,
                    ff_mult=4,
                ).to(self.device, dtype=torch.float16)
            else:
                image_proj_model = ImageProjModel(
                    cross_attention_dim=unet_dim,
                    clip_embeddings_dim=self.image_encoder.config.projection_dim,
                    clip_extra_context_tokens=num_tokens,
                ).to(self.device, dtype=torch.float16)
        return image_proj_model

    @torch.inference_mode()
    def get_image_embeds(self, pil_image=None, clip_image_embeds=None, content_or_style_=''):
        """Encode image(s) and project them for the style or content branch.

        ``content_or_style_ == 'style'`` selects the style branch; every
        other value is treated as content.  The ``clip_image_embeds``
        argument is not read; embeddings are always computed from
        ``pil_image``.

        Returns:
            ``(image_prompt_embeds, uncond_image_prompt_embeds)`` where the
            unconditional half is the projection of an all-zero feature tensor.
        """
        if isinstance(pil_image, Image.Image):
            pil_image = [pil_image]

        style_branch = content_or_style_ == 'style'

        if not style_branch and self.content_image_encoder_path is not None:
            # Dedicated content encoder: use its last hidden state.
            pixels = self.content_image_processor(images=pil_image, return_tensors="pt").pixel_values
            encoded = self.content_image_encoder(pixels.to(self.device, dtype=torch.float16),
                                                 output_hidden_states=True)
            features = encoded.last_hidden_state
        else:
            with_resampler = self.style_model_resampler if style_branch else self.content_model_resampler
            pixels = self.clip_image_processor(images=pil_image, return_tensors="pt").pixel_values
            pixels = pixels.to(self.device, dtype=torch.float16)
            if with_resampler:
                # Resampler projectors consume the penultimate hidden states.
                features = self.image_encoder(pixels, output_hidden_states=True).hidden_states[-2]
            else:
                features = self.image_encoder(pixels).image_embeds

        projector = self.style_image_proj_model if style_branch else self.content_image_proj_model
        image_prompt_embeds = projector(features)
        uncond_image_prompt_embeds = projector(torch.zeros_like(features))
        return image_prompt_embeds, uncond_image_prompt_embeds
853
+
854
+ # # clip_image = self.clip_image_processor(images=pil_image, return_tensors="pt").pixel_values
855
+ # clip_image = clip_image.to(self.device, dtype=torch.float16)
856
+ # clip_image_embeds = self.image_encoder(clip_image, output_hidden_states=True).hidden_states[-2]
857
+ # image_prompt_embeds = self.content_image_proj_model(clip_image_embeds)
858
+ # uncond_clip_image_embeds = self.image_encoder(
859
+ # torch.zeros_like(clip_image), output_hidden_states=True
860
+ # ).hidden_states[-2]
861
+ # uncond_image_prompt_embeds = self.content_image_proj_model(uncond_clip_image_embeds)
862
+ # return image_prompt_embeds, uncond_image_prompt_embeds
863
+
864
+
865
+ class StyleStudio_Adapter(CSGO):
866
def __init__(self, sd_pipe, image_encoder_path, ip_ckpt, device,
             num_style_tokens=4,
             target_style_blocks=["block"],
             controlnet_adapter=False,
             controlnet_target_content_blocks=None,
             controlnet_target_style_blocks=None,
             style_model_resampler=False,
             fuAttn=False,
             fuSAttn=False,
             fuIPAttn=False,
             fuScale=0,
             adainIP=False,
             end_fusion=0,
             save_attn_map=False,
             ):
    """Set up the style-only adapter on top of ``sd_pipe``.

    Args:
        sd_pipe: Diffusion pipeline; it is moved to ``device``.
        image_encoder_path: Path/ID passed to
            ``CLIPVisionModelWithProjection.from_pretrained``.
        ip_ckpt: Adapter checkpoint path read by ``load_ip_adapter``.
        device: Device for the pipeline, encoder and projector.
        num_style_tokens: Number of style tokens the projector emits.
        target_style_blocks: Substrings of attention-layer names that
            receive style injection.
        controlnet_adapter, controlnet_target_content_blocks,
        controlnet_target_style_blocks: Stored ControlNet settings.
        style_model_resampler: Use a ``Resampler`` projector when truthy.
        fuAttn, fuSAttn, fuIPAttn, fuScale, adainIP, end_fusion: Options
            forwarded to the attention processors in ``set_ip_adapter``.
        save_attn_map: Stored on the instance.

    The parent initializers are not called; this method sets every
    attribute it needs itself.
    """
    self.fuAttn = fuAttn
    self.fuSAttn = fuSAttn
    self.fuIPAttn = fuIPAttn
    self.adainIP = adainIP
    self.fuScale = fuScale
    if self.fuSAttn:
        print(f"hijack Self AttnMap in {end_fusion} steps", "fuScale is: ", fuScale)
    if self.fuAttn:
        print(f"hijack Cross AttnMap in {end_fusion} steps", "fuScale is: ", fuScale)
    if self.fuIPAttn:
        print(f"hijack IP AttnMap in {end_fusion} steps", "fuScale is: ", fuScale)
    self.end_fusion = end_fusion
    self.save_attn_map = save_attn_map

    self.device = device
    self.image_encoder_path = image_encoder_path
    self.ip_ckpt = ip_ckpt
    self.num_style_tokens = num_style_tokens
    # Copy so the shared default list (a mutable default argument) is
    # never aliased between instances.
    self.style_target_blocks = list(target_style_blocks)

    self.style_model_resampler = style_model_resampler

    self.controlnet_adapter = controlnet_adapter
    self.controlnet_target_content_blocks = controlnet_target_content_blocks
    self.controlnet_target_style_blocks = controlnet_target_style_blocks

    self.pipe = sd_pipe.to(self.device)
    self.set_ip_adapter()

    # The image encoder must exist before init_proj, which reads its config.
    self.image_encoder = CLIPVisionModelWithProjection.from_pretrained(self.image_encoder_path).to(
        self.device, dtype=torch.float16
    )

    self.clip_image_processor = CLIPImageProcessor()
    # Style image projector; its weights come from load_ip_adapter below.
    self.style_image_proj_model = self.init_proj(self.num_style_tokens, content_or_style_='style',
                                                 model_resampler=self.style_model_resampler)
    self.load_ip_adapter()
925
+
926
def set_ip_adapter(self):
    """Install the StyleStudio attention processors on the UNet (and ControlNet, if any).

    Self-attention layers (names ending in ``attn1.processor``) get an
    ``AttnProcessor_hijack``.  Cross-attention layers get an
    ``IP_FuAd_AttnProcessor``: style-enabled when the layer name contains
    one of ``self.style_target_blocks``, otherwise created with
    ``skip=True``.

    When the pipeline has a ``controlnet`` and no ControlNet adapter is
    used, its layers get ``CNAttnProcessor`` instances.  The previous
    revision read ``self.num_content_tokens`` there, an attribute this
    class's ``__init__`` never sets; it now falls back to ``0`` content
    tokens when the attribute is absent.
    """
    unet = self.pipe.unet
    attn_procs = {}
    # Options shared by every IP_FuAd_AttnProcessor created below.
    fusion_options = dict(
        fuAttn=self.fuAttn,
        fuIPAttn=self.fuIPAttn,
        adainIP=self.adainIP,
        fuScale=self.fuScale,
        end_fusion=self.end_fusion,
    )
    for name in unet.attn_processors.keys():
        cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim
        # The channel width is derived from the block index embedded in the layer name.
        if name.startswith("mid_block"):
            hidden_size = unet.config.block_out_channels[-1]
        elif name.startswith("up_blocks"):
            block_id = int(name[len("up_blocks.")])
            hidden_size = list(reversed(unet.config.block_out_channels))[block_id]
        elif name.startswith("down_blocks"):
            block_id = int(name[len("down_blocks.")])
            hidden_size = unet.config.block_out_channels[block_id]

        if cross_attention_dim is None:
            attn_procs[name] = AttnProcessor_hijack(
                fuSAttn=self.fuSAttn,
                fuScale=self.fuScale,
                end_fusion=self.end_fusion,
                attn_name=name)
            continue

        selected = False
        for block_name in self.style_target_blocks:
            if block_name in name:
                selected = True
                attn_procs[name] = IP_FuAd_AttnProcessor(
                    hidden_size=hidden_size,
                    cross_attention_dim=cross_attention_dim,
                    style_scale=1.0,
                    style=True,
                    num_style_tokens=self.num_style_tokens,
                    attn_name=name,
                    **fusion_options,
                )
        if selected is False:
            # Layer is not a style target.
            attn_procs[name] = IP_FuAd_AttnProcessor(
                hidden_size=hidden_size,
                cross_attention_dim=cross_attention_dim,
                num_style_tokens=self.num_style_tokens,
                skip=True,
                attn_name=name,
                **fusion_options,
            )

        attn_procs[name].to(self.device, dtype=torch.float16)
    unet.set_attn_processor(attn_procs)

    if hasattr(self.pipe, "controlnet") and self.controlnet_adapter is False:
        # This adapter only defines style tokens; num_content_tokens is
        # honored if present, otherwise treated as zero.
        num_image_tokens = getattr(self, "num_content_tokens", 0) + self.num_style_tokens
        if isinstance(self.pipe.controlnet, MultiControlNetModel):
            for controlnet in self.pipe.controlnet.nets:
                controlnet.set_attn_processor(CNAttnProcessor(num_tokens=num_image_tokens))
        else:
            self.pipe.controlnet.set_attn_processor(CNAttnProcessor(num_tokens=num_image_tokens))
991
+
992
def load_ip_adapter(self):
    """Load the style projector and attention-processor weights from ``self.ip_ckpt``.

    ``.safetensors`` files are regrouped by key prefix; any other file is
    read with ``torch.load`` and is expected to already contain the
    ``style_image_proj`` and ``ip_adapter`` sub-dicts.  Content projector
    weights are collected from safetensors files but not loaded here.
    """
    if os.path.splitext(self.ip_ckpt)[-1] != ".safetensors":
        # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
        state_dict = torch.load(self.ip_ckpt, map_location="cpu")
    else:
        state_dict = {"content_image_proj": {}, "style_image_proj": {}, "ip_adapter": {}}
        groups = ("content_image_proj", "style_image_proj", "ip_adapter")
        with safe_open(self.ip_ckpt, framework="pt", device="cpu") as f:
            for key in f.keys():
                for group in groups:
                    prefix = group + "."
                    if key.startswith(prefix):
                        state_dict[group][key.replace(prefix, "")] = f.get_tensor(key)
                        break

    self.style_image_proj_model.load_state_dict(state_dict["style_image_proj"])

    # Some checkpoints also ship replacement weights for the UNet input conv.
    if 'conv_in_unet_sd' in state_dict:
        self.pipe.unet.conv_in.load_state_dict(state_dict["conv_in_unet_sd"], strict=True)

    # strict=False tolerates processors that carry no matching weights.
    processor_modules = torch.nn.ModuleList(self.pipe.unet.attn_processors.values())
    processor_modules.load_state_dict(state_dict["ip_adapter"], strict=False)
1011
+
1012
def set_scale(self, style_scale):
    """Assign ``style_scale`` to every style-enabled ``IP_FuAd_AttnProcessor`` of the UNet."""
    for processor in self.pipe.unet.attn_processors.values():
        if not isinstance(processor, IP_FuAd_AttnProcessor):
            continue
        # Processors whose style flag is not True keep their current scale.
        if processor.style is True:
            processor.style_scale = style_scale
1018
+
1019
def init_proj(self, num_tokens, content_or_style_='content', model_resampler=False):
    """Create the image projector, sized from ``self.image_encoder``.

    ``'content'`` and ``'style'`` are built identically: a ``Resampler``
    when ``model_resampler`` is truthy, otherwise an ``ImageProjModel``.
    The result is moved to ``self.device`` in float16.
    """
    unet_dim = self.pipe.unet.config.cross_attention_dim
    # Only the two known branch names build a projector; any other value
    # leaves the local unassigned, exactly as before.
    if content_or_style_ in ('content', 'style'):
        if model_resampler:
            image_proj_model = Resampler(
                dim=unet_dim,
                depth=4,
                dim_head=64,
                heads=12,
                num_queries=num_tokens,
                embedding_dim=self.image_encoder.config.hidden_size,
                output_dim=self.pipe.unet.config.cross_attention_dim,
                ff_mult=4,
            ).to(self.device, dtype=torch.float16)
        else:
            image_proj_model = ImageProjModel(
                cross_attention_dim=unet_dim,
                clip_embeddings_dim=self.image_encoder.config.projection_dim,
                clip_extra_context_tokens=num_tokens,
            ).to(self.device, dtype=torch.float16)
    return image_proj_model
1057
+
1058
@torch.inference_mode()
def get_image_embeds(self, pil_image=None, clip_image_embeds=None):
    """Encode style image(s) and return conditional and unconditional style tokens.

    The ``clip_image_embeds`` argument is not read; embeddings are always
    computed from ``pil_image``.  The unconditional tokens are the
    projection of an all-zero feature tensor.
    """
    if isinstance(pil_image, Image.Image):
        pil_image = [pil_image]

    pixels = self.clip_image_processor(images=pil_image, return_tensors="pt").pixel_values
    pixels = pixels.to(self.device, dtype=torch.float16)
    if self.style_model_resampler:
        # Resampler projectors consume the penultimate hidden states.
        features = self.image_encoder(pixels, output_hidden_states=True).hidden_states[-2]
    else:
        features = self.image_encoder(pixels).image_embeds

    image_prompt_embeds = self.style_image_proj_model(features)
    uncond_image_prompt_embeds = self.style_image_proj_model(torch.zeros_like(features))
    return image_prompt_embeds, uncond_image_prompt_embeds
1076
+
1077
@torch.inference_mode()
def get_neg_image_embeds(self, pil_image=None, clip_image_embeds=None):
    """Encode negative style image(s) and return their projected style tokens.

    Same encoding as ``get_image_embeds`` but only the conditional
    projection is returned.  ``clip_image_embeds`` is not read.
    """
    if isinstance(pil_image, Image.Image):
        pil_image = [pil_image]

    pixels = self.clip_image_processor(images=pil_image, return_tensors="pt").pixel_values
    pixels = pixels.to(self.device, dtype=torch.float16)
    if self.style_model_resampler:
        # Resampler projectors consume the penultimate hidden states.
        features = self.image_encoder(pixels, output_hidden_states=True).hidden_states[-2]
    else:
        features = self.image_encoder(pixels).image_embeds
    return self.style_image_proj_model(features)
1092
+
1093
def set_endFusion(self, end_T):
    """Write ``end_T`` to ``end_fusion`` on every ``AttnProcessor_hijack`` of the UNet."""
    hijacked = (
        proc
        for proc in self.pipe.unet.attn_processors.values()
        if isinstance(proc, AttnProcessor_hijack)
    )
    for proc in hijacked:
        proc.end_fusion = end_T
1097
+
1098
def set_SAttn(self, use_SAttn):
    """Write ``use_SAttn`` to ``fuSAttn`` on every ``AttnProcessor_hijack`` of the UNet."""
    hijacked = (
        proc
        for proc in self.pipe.unet.attn_processors.values()
        if isinstance(proc, AttnProcessor_hijack)
    )
    for proc in hijacked:
        proc.fuSAttn = use_SAttn
1102
+
1103
def set_adain(self, use_CMA):
    """Write ``use_CMA`` to ``adainIP`` on every ``IP_FuAd_AttnProcessor`` of the UNet."""
    style_procs = (
        proc
        for proc in self.pipe.unet.attn_processors.values()
        if isinstance(proc, IP_FuAd_AttnProcessor)
    )
    for proc in style_procs:
        proc.adainIP = use_CMA
1107
+
1108
def generate(
    self,
    pil_style_image,

    neg_pil_style_image=None,

    prompt=None,
    negative_prompt=None,
    num_samples=2,
    style_image_embeds=None,
    num_inference_steps=30,
    end_fusion=20,
    cross_modal_adain=True,
    use_SAttn=True,
    **kwargs,
):
    """Generate images conditioned on a text prompt and a style reference.

    Args:
        pil_style_image: Style reference; a PIL image or a sequence of them.
        neg_pil_style_image: Optional negative style reference. When given, its
            projected tokens replace the zero-feature unconditional tokens.
        prompt: Text prompt (string or list); a generic quality prompt is used
            when None.
        negative_prompt: Negative text prompt; a generic one is used when None.
        num_samples: Passed to ``encode_prompt`` as ``num_images_per_prompt``
            and used to tile the style tokens.
        style_image_embeds: Forwarded to the embedding helpers as their
            ``clip_image_embeds`` argument.
        num_inference_steps: Forwarded to the pipeline call.
        end_fusion, cross_modal_adain, use_SAttn: Written to the attention
            processors through ``set_endFusion`` / ``set_adain`` / ``set_SAttn``
            before anything else happens.
        **kwargs: Forwarded unchanged to ``self.pipe``.

    Returns:
        The ``images`` attribute of the pipeline output.
    """
    # Per-call switches are written to the attention processors first.
    self.set_endFusion(end_T=end_fusion)
    self.set_adain(use_CMA=cross_modal_adain)
    self.set_SAttn(use_SAttn=use_SAttn)

    num_prompts = 1 if isinstance(pil_style_image, Image.Image) else len(pil_style_image)

    prompt = "best quality, high quality" if prompt is None else prompt
    negative_prompt = (
        "monochrome, lowres, bad anatomy, worst quality, low quality"
        if negative_prompt is None
        else negative_prompt
    )

    # A scalar prompt is replicated once per style image.
    if not isinstance(prompt, List):
        prompt = [prompt] * num_prompts
    if not isinstance(negative_prompt, List):
        negative_prompt = [negative_prompt] * num_prompts

    style_tokens, uncond_style_tokens = self.get_image_embeds(pil_style_image, style_image_embeds)

    if neg_pil_style_image is not None:
        print("using neg style image")
        neg_style_tokens = self.get_neg_image_embeds(neg_pil_style_image, style_image_embeds)
        # Diagnostic only: how close the style tokens are to the negative and
        # to the zero-feature tokens.
        cos_sim_neg = F.cosine_similarity(style_tokens, neg_style_tokens.squeeze(0).unsqueeze(1), dim=-1)
        cos_sim_uncond = F.cosine_similarity(style_tokens, uncond_style_tokens.squeeze(0).unsqueeze(1), dim=-1)
        print(f"neg cos sim is: {cos_sim_neg.diagonal()}")
        print(f"uncond cos sim is: {cos_sim_uncond.diagonal()}")
        uncond_style_tokens = neg_style_tokens

    # Tile the image tokens so there is one copy per requested sample.
    bs_embed, seq_style_len, _ = style_tokens.shape
    style_tokens = style_tokens.repeat(1, num_samples, 1)
    style_tokens = style_tokens.view(bs_embed * num_samples, seq_style_len, -1)
    uncond_style_tokens = uncond_style_tokens.repeat(1, num_samples, 1)
    uncond_style_tokens = uncond_style_tokens.view(bs_embed * num_samples, seq_style_len, -1)

    with torch.inference_mode():
        encoded = self.pipe.encode_prompt(
            prompt,
            num_images_per_prompt=num_samples,
            do_classifier_free_guidance=True,
            negative_prompt=negative_prompt,
        )
        text_embeds, neg_text_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds = encoded
        # Image tokens are appended after the text tokens along the sequence axis.
        prompt_embeds = torch.cat([text_embeds, style_tokens], dim=1)
        negative_prompt_embeds = torch.cat([neg_text_embeds, uncond_style_tokens], dim=1)

    result = self.pipe(
        prompt_embeds=prompt_embeds,
        negative_prompt_embeds=negative_prompt_embeds,
        pooled_prompt_embeds=pooled_prompt_embeds,
        negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
        num_inference_steps=num_inference_steps,
        **kwargs,
    )
    return result.images
1190
+
1191
+ # StyleStudio_Adapter experiment code
1192
class StyleStudio_Adapter_exp(StyleStudio_Adapter):
    """Experimental variant that installs the ``*_exp`` attention processors."""

    def set_ip_adapter(self):
        """Build one attention processor per UNet attention layer and install them.

        Self-attention layers (names ending in ``attn1.processor``) receive an
        ``AttnProcessor_exp``. Cross-attention layers receive an
        ``IP_FuAd_AttnProcessor_exp``: a styled one when the layer name contains
        any entry of ``self.style_target_blocks``, otherwise one built with
        ``skip=True``. If the pipeline has a ``controlnet`` attribute and
        ``self.controlnet_adapter`` is False, its attention processors are
        replaced by ``CNAttnProcessor``.
        """
        unet = self.pipe.unet
        attn_procs = {}
        for name in unet.attn_processors.keys():
            # "attn1" layers are self-attention and therefore have no cross-attention dim.
            cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim
            # Channel width of the block this layer lives in, read from the UNet config.
            if name.startswith("mid_block"):
                hidden_size = unet.config.block_out_channels[-1]
            elif name.startswith("up_blocks"):
                # The single character after "up_blocks." is taken as the block index.
                block_id = int(name[len("up_blocks.")])
                hidden_size = list(reversed(unet.config.block_out_channels))[block_id]
            elif name.startswith("down_blocks"):
                block_id = int(name[len("down_blocks.")])
                hidden_size = unet.config.block_out_channels[block_id]
            if cross_attention_dim is None:
                attn_procs[name] = AttnProcessor_exp(
                    fuSAttn=self.fuSAttn,
                    fuScale=self.fuScale,
                    end_fusion=self.end_fusion,
                    attn_name=name)
            else:
                # layername_id += 1
                selected = False
                for block_name in self.style_target_blocks:
                    if block_name in name:
                        selected = True
                        # print(name)
                        # Every attention layer inside a style block is switched to the FuAd processor.
                        attn_procs[name] = IP_FuAd_AttnProcessor_exp(
                            hidden_size=hidden_size,
                            cross_attention_dim=cross_attention_dim,
                            style_scale=1.0,
                            style=True,
                            num_content_tokens=self.num_content_tokens,
                            num_style_tokens=self.num_style_tokens,
                            fuAttn=self.fuAttn,
                            fuIPAttn=self.fuIPAttn,
                            adainIP=self.adainIP,
                            fuScale=self.fuScale,
                            end_fusion=self.end_fusion,
                            attn_name=name,
                            save_attn_map=self.save_attn_map,
                        )
                # CSGO's content control is not needed here, so its cross-attention
                # handling of content tokens was removed. That part of the CSGO code
                # also looked problematic: whatever it assigned was overwritten
                # afterwards anyway. In CSGO's design the content blocks and style
                # blocks do not overlap.
                # selected == False means this layer is not in a style block; the
                # important difference is skip=True.
                if selected is False:
                    attn_procs[name] = IP_FuAd_AttnProcessor_exp(
                        hidden_size=hidden_size,
                        cross_attention_dim=cross_attention_dim,
                        num_content_tokens=self.num_content_tokens,
                        num_style_tokens=self.num_style_tokens,
                        skip=True,
                        fuAttn=self.fuAttn,
                        fuIPAttn=self.fuIPAttn,
                        adainIP=self.adainIP,
                        fuScale=self.fuScale,
                        end_fusion=self.end_fusion,
                        attn_name=name,
                        save_attn_map=self.save_attn_map,
                    )
                # (An older, commented-out skip=True construction without
                # attn_name / save_attn_map used to follow here.)

                # NOTE(review): nesting reconstructed from a flattened listing; this
                # is assumed to apply to the cross-attention processors only - confirm.
                attn_procs[name].to(self.device, dtype=torch.float16)
        unet.set_attn_processor(attn_procs)
        if hasattr(self.pipe, "controlnet"):
            if self.controlnet_adapter is False:
                if isinstance(self.pipe.controlnet, MultiControlNetModel):
                    for controlnet in self.pipe.controlnet.nets:
                        controlnet.set_attn_processor(CNAttnProcessor(
                            num_tokens=self.num_content_tokens + self.num_style_tokens))
                else:
                    self.pipe.controlnet.set_attn_processor(CNAttnProcessor(
                        num_tokens=self.num_content_tokens + self.num_style_tokens))
                # No ControlNet in this code needs style injected (this is not an
                # image-to-image task), so the CSGO code that injected style into
                # the ControlNet was removed.
1277
+
1278
class IPAdapterXL(IPAdapter):
    """SDXL"""

    def generate(
        self,
        pil_image,
        prompt=None,
        negative_prompt=None,
        scale=1.0,
        num_samples=4,
        seed=None,
        num_inference_steps=30,
        neg_content_emb=None,
        neg_content_prompt=None,
        neg_content_scale=1.0,
        **kwargs,
    ):
        """Generate images from a text prompt plus an image prompt.

        Args:
            pil_image: Reference image; a PIL image or a sequence of them.
            prompt: Text prompt (string or list); a generic quality prompt is
                used when None.
            negative_prompt: Negative text prompt; a generic one when None.
            scale: Passed to ``self.set_scale``.
            num_samples: Passed to ``encode_prompt`` as ``num_images_per_prompt``
                and used to tile the image tokens.
            seed: Passed to ``get_generator`` to build the pipeline generator.
            num_inference_steps: Forwarded to the pipeline call.
            neg_content_emb: Optional embedding handed to ``get_image_embeds``
                as ``content_prompt_embeds``. Takes precedence over
                ``neg_content_prompt``.
            neg_content_prompt: Optional text whose pooled embedding, multiplied
                by ``neg_content_scale``, is used as ``content_prompt_embeds``
                when ``neg_content_emb`` is not given.
            neg_content_scale: Multiplier for the pooled ``neg_content_prompt``
                embedding.
            **kwargs: Forwarded unchanged to ``self.pipe``.

        Returns:
            The ``images`` attribute of the pipeline output.
        """
        self.set_scale(scale)

        num_prompts = 1 if isinstance(pil_image, Image.Image) else len(pil_image)

        if prompt is None:
            prompt = "best quality, high quality"
        if negative_prompt is None:
            negative_prompt = "monochrome, lowres, bad anatomy, worst quality, low quality"

        # A scalar prompt is replicated once per reference image.
        if not isinstance(prompt, List):
            prompt = [prompt] * num_prompts
        if not isinstance(negative_prompt, List):
            negative_prompt = [negative_prompt] * num_prompts

        # The branches used to be inverted: a caller-supplied ``neg_content_emb``
        # was replaced by None and therefore never reached ``get_image_embeds``.
        if neg_content_emb is not None:
            pooled_prompt_embeds_ = neg_content_emb
        elif neg_content_prompt is not None:
            with torch.inference_mode():
                (
                    _,
                    _,
                    pooled_prompt_embeds_,  # torch.Size([1, 1280])
                    _,
                ) = self.pipe.encode_prompt(
                    neg_content_prompt,
                    num_images_per_prompt=num_samples,
                    do_classifier_free_guidance=True,
                    negative_prompt=negative_prompt,
                )
                # In-place scaling has to stay inside inference mode.
                pooled_prompt_embeds_ *= neg_content_scale
        else:
            pooled_prompt_embeds_ = None

        image_prompt_embeds, uncond_image_prompt_embeds = self.get_image_embeds(
            pil_image, content_prompt_embeds=pooled_prompt_embeds_
        )
        # Tile the image tokens so there is one copy per requested sample.
        bs_embed, seq_len, _ = image_prompt_embeds.shape
        image_prompt_embeds = image_prompt_embeds.repeat(1, num_samples, 1)
        image_prompt_embeds = image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1)
        uncond_image_prompt_embeds = uncond_image_prompt_embeds.repeat(1, num_samples, 1)
        uncond_image_prompt_embeds = uncond_image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1)

        with torch.inference_mode():
            (
                prompt_embeds,
                negative_prompt_embeds,
                pooled_prompt_embeds,
                negative_pooled_prompt_embeds,
            ) = self.pipe.encode_prompt(
                prompt,
                num_images_per_prompt=num_samples,
                do_classifier_free_guidance=True,
                negative_prompt=negative_prompt,
            )
            # Image tokens are appended after the text tokens along the sequence axis.
            prompt_embeds = torch.cat([prompt_embeds, image_prompt_embeds], dim=1)
            negative_prompt_embeds = torch.cat([negative_prompt_embeds, uncond_image_prompt_embeds], dim=1)

        self.generator = get_generator(seed, self.device)

        images = self.pipe(
            prompt_embeds=prompt_embeds,
            negative_prompt_embeds=negative_prompt_embeds,
            pooled_prompt_embeds=pooled_prompt_embeds,
            negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
            num_inference_steps=num_inference_steps,
            generator=self.generator,
            **kwargs,
        ).images

        return images
1365
+
1366
+
1367
class IPAdapterXL_cross_modal(IPAdapterXL):
    """SDXL IP-Adapter that installs the hijack / cross-modal attention processors.

    The constructor does not call the parent ``__init__``; it stores the fusion
    switches, installs the attention processors, loads the CLIP image encoder
    and the projection model, and finally loads the adapter checkpoint.
    """

    def __init__(self, sd_pipe, image_encoder_path, ip_ckpt, device, num_tokens=4,
                 target_blocks=["block"],
                 fuAttn=False,
                 fuSAttn=False,
                 fuIPAttn=False,
                 fuScale=0,
                 adainIP=False,
                 end_fusion=0,
                 save_attn_map=False,):
        """Store the settings and build every sub-module.

        Args:
            sd_pipe: Diffusion pipeline; moved to ``device``.
            image_encoder_path: Path given to
                ``CLIPVisionModelWithProjection.from_pretrained``.
            ip_ckpt: Adapter checkpoint path (``.safetensors`` or a
                ``torch.load``-able file).
            device: Device for the pipeline, encoder and projection model.
            num_tokens: Number of image tokens produced by the projection.
            target_blocks: Substrings of attention-layer names; matching
                cross-attention layers get an active processor, the others are
                built with ``skip=True``.
            fuAttn, fuSAttn, fuIPAttn, fuScale, adainIP, end_fusion: Stored and
                forwarded to the attention processors.
            save_attn_map: Stored on the instance.
        """
        self.fuAttn = fuAttn
        self.fuSAttn = fuSAttn
        self.fuIPAttn = fuIPAttn
        self.adainIP = adainIP
        self.fuScale = fuScale
        if self.fuSAttn:
            print(f"hijack Self AttnMap in {end_fusion} steps", "fuScale is: ", fuScale)
        if self.fuAttn:
            print(f"hijack Cross AttnMap in {end_fusion} steps", "fuScale is: ", fuScale)
        if self.fuIPAttn:
            print(f"hijack IP AttnMap in {end_fusion} steps", "fuScale is: ", fuScale)
        self.end_fusion = end_fusion
        self.save_attn_map = save_attn_map

        self.device = device
        self.image_encoder_path = image_encoder_path
        self.ip_ckpt = ip_ckpt
        self.num_tokens = num_tokens
        # Copy so the shared default list can never be mutated through an instance.
        self.target_blocks = list(target_blocks)

        self.pipe = sd_pipe.to(self.device)
        self.set_ip_adapter()

        # load image encoder
        self.image_encoder = CLIPVisionModelWithProjection.from_pretrained(self.image_encoder_path).to(
            self.device, dtype=torch.float16
        )
        self.clip_image_processor = CLIPImageProcessor()
        # image proj model
        self.image_proj_model = self.init_proj()

        self.load_ip_adapter()

    def init_proj(self):
        """Build the float16 ``ImageProjModel`` mapping CLIP embeddings to image tokens."""
        image_proj_model = ImageProjModel(
            cross_attention_dim=self.pipe.unet.config.cross_attention_dim,
            clip_embeddings_dim=self.image_encoder.config.projection_dim,
            clip_extra_context_tokens=self.num_tokens,
        ).to(self.device, dtype=torch.float16)
        return image_proj_model

    def set_ip_adapter(self):
        """Install one attention processor per UNet attention layer.

        Self-attention layers (names ending in ``attn1.processor``) get an
        ``AttnProcessor_hijack``; cross-attention layers get an
        ``IPAttnProcessor_cross_modal``, built with ``skip=True`` unless the
        layer name contains one of ``self.target_blocks``. ControlNet attention
        processors, if the pipeline has a controlnet, become ``CNAttnProcessor``.
        """
        unet = self.pipe.unet
        attn_procs = {}
        for name in unet.attn_processors.keys():
            cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim
            # Channel width of the block this layer lives in.
            if name.startswith("mid_block"):
                hidden_size = unet.config.block_out_channels[-1]
            elif name.startswith("up_blocks"):
                block_id = int(name[len("up_blocks.")])
                hidden_size = list(reversed(unet.config.block_out_channels))[block_id]
            elif name.startswith("down_blocks"):
                block_id = int(name[len("down_blocks.")])
                hidden_size = unet.config.block_out_channels[block_id]

            if cross_attention_dim is None:
                # Self attention
                attn_procs[name] = AttnProcessor_hijack(
                    fuSAttn=self.fuSAttn,
                    fuScale=self.fuScale,
                    end_fusion=self.end_fusion,
                    attn_name=name)
                continue

            # Cross attention
            selected = any(block_name in name for block_name in self.target_blocks)
            proc_kwargs = dict(
                hidden_size=hidden_size,
                cross_attention_dim=cross_attention_dim,
                scale=1.0,
                num_tokens=self.num_tokens,
                fuAttn=self.fuAttn,
                fuIPAttn=self.fuIPAttn,
                adainIP=self.adainIP,
                fuScale=self.fuScale,
                end_fusion=self.end_fusion,
                attn_name=name,
            )
            if not selected:
                # Only non-target layers pass ``skip``; target layers keep the
                # processor's own default.
                proc_kwargs["skip"] = True
            attn_procs[name] = IPAttnProcessor_cross_modal(**proc_kwargs).to(self.device, dtype=torch.float16)
        unet.set_attn_processor(attn_procs)
        if hasattr(self.pipe, "controlnet"):
            if isinstance(self.pipe.controlnet, MultiControlNetModel):
                for controlnet in self.pipe.controlnet.nets:
                    controlnet.set_attn_processor(CNAttnProcessor(num_tokens=self.num_tokens))
            else:
                self.pipe.controlnet.set_attn_processor(CNAttnProcessor(num_tokens=self.num_tokens))

    def load_ip_adapter(self):
        """Load ``image_proj`` and ``ip_adapter`` weights from ``self.ip_ckpt``.

        ``.safetensors`` files are split by key prefix into the two groups; any
        other file is read with ``torch.load`` and must already contain both
        groups. The attention weights are loaded non-strictly.
        """
        if os.path.splitext(self.ip_ckpt)[-1] == ".safetensors":
            state_dict = {"image_proj": {}, "ip_adapter": {}}
            with safe_open(self.ip_ckpt, framework="pt", device="cpu") as f:
                for key in f.keys():
                    if key.startswith("image_proj."):
                        state_dict["image_proj"][key.replace("image_proj.", "")] = f.get_tensor(key)
                    elif key.startswith("ip_adapter."):
                        state_dict["ip_adapter"][key.replace("ip_adapter.", "")] = f.get_tensor(key)
        else:
            # NOTE(security): torch.load unpickles the file; only load trusted checkpoints.
            state_dict = torch.load(self.ip_ckpt, map_location="cpu")
        self.image_proj_model.load_state_dict(state_dict["image_proj"])
        ip_layers = torch.nn.ModuleList(self.pipe.unet.attn_processors.values())
        ip_layers.load_state_dict(state_dict["ip_adapter"], strict=False)

    def _clip_embeds(self, pil_image, clip_image_embeds, content_prompt_embeds):
        """Return float16 CLIP image embeddings, minus ``content_prompt_embeds`` if given."""
        if pil_image is not None:
            if isinstance(pil_image, Image.Image):
                pil_image = [pil_image]
            clip_image = self.clip_image_processor(images=pil_image, return_tensors="pt").pixel_values
            clip_image_embeds = self.image_encoder(clip_image.to(self.device, dtype=torch.float16)).image_embeds
        else:
            clip_image_embeds = clip_image_embeds.to(self.device, dtype=torch.float16)

        if content_prompt_embeds is not None:
            clip_image_embeds = clip_image_embeds - content_prompt_embeds
        return clip_image_embeds

    @torch.inference_mode()
    def get_image_embeds(self, pil_image=None, clip_image_embeds=None, content_prompt_embeds=None):
        """Return ``(image_prompt_embeds, uncond_image_prompt_embeds)``.

        The image (or, when ``pil_image`` is None, the pre-computed
        ``clip_image_embeds``) is embedded, ``content_prompt_embeds`` is
        subtracted when given, and the result is projected. The unconditional
        tokens are the projection of an all-zero embedding.
        """
        clip_image_embeds = self._clip_embeds(pil_image, clip_image_embeds, content_prompt_embeds)
        image_prompt_embeds = self.image_proj_model(clip_image_embeds)
        uncond_image_prompt_embeds = self.image_proj_model(torch.zeros_like(clip_image_embeds))
        return image_prompt_embeds, uncond_image_prompt_embeds

    def set_scale(self, scale):
        """Write ``scale`` to every ``IPAttnProcessor_cross_modal`` of the UNet."""
        for attn_processor in self.pipe.unet.attn_processors.values():
            if isinstance(attn_processor, IPAttnProcessor_cross_modal):
                attn_processor.scale = scale

    @torch.inference_mode()
    def get_neg_image_embeds(self, pil_image=None, clip_image_embeds=None, content_prompt_embeds=None):
        """Return the projected tokens of a negative reference image or embedding."""
        clip_image_embeds = self._clip_embeds(pil_image, clip_image_embeds, content_prompt_embeds)
        neg_image_prompt_embeds = self.image_proj_model(clip_image_embeds)
        return neg_image_prompt_embeds

    def generate(
        self,
        pil_image,
        neg_pil_image=None,
        prompt=None,
        negative_prompt=None,
        scale=1.0,
        num_samples=4,
        seed=None,
        num_inference_steps=30,
        neg_content_emb=None,
        neg_content_prompt=None,
        neg_content_scale=1.0,
        **kwargs,
    ):
        """Generate images from a text prompt plus an image prompt.

        Args:
            pil_image: Reference image; a PIL image or a sequence of them.
            neg_pil_image: Optional negative reference; its projected tokens
                replace the zero-embedding unconditional tokens.
            prompt: Text prompt (string or list); a generic quality prompt is
                used when None.
            negative_prompt: Negative text prompt; a generic one when None.
            scale: Passed to ``self.set_scale``.
            num_samples: Passed to ``encode_prompt`` as ``num_images_per_prompt``
                and used to tile the image tokens.
            seed: Accepted for signature compatibility but currently unused;
                pass ``generator=...`` through ``kwargs`` to control randomness.
            num_inference_steps: Forwarded to the pipeline call.
            neg_content_emb: Optional embedding subtracted from the CLIP image
                embedding. Takes precedence over ``neg_content_prompt``.
            neg_content_prompt: Optional text whose pooled embedding, multiplied
                by ``neg_content_scale``, is subtracted instead.
            neg_content_scale: Multiplier for the pooled ``neg_content_prompt``
                embedding.
            **kwargs: Forwarded unchanged to ``self.pipe``.

        Returns:
            The ``images`` attribute of the pipeline output.
        """
        self.set_scale(scale)

        num_prompts = 1 if isinstance(pil_image, Image.Image) else len(pil_image)

        if prompt is None:
            prompt = "best quality, high quality"
        if negative_prompt is None:
            negative_prompt = "monochrome, lowres, bad anatomy, worst quality, low quality"

        # A scalar prompt is replicated once per reference image.
        if not isinstance(prompt, List):
            prompt = [prompt] * num_prompts
        if not isinstance(negative_prompt, List):
            negative_prompt = [negative_prompt] * num_prompts

        # The branches used to be inverted: a caller-supplied ``neg_content_emb``
        # was replaced by None and therefore never reached ``get_image_embeds``.
        if neg_content_emb is not None:
            pooled_prompt_embeds_ = neg_content_emb
        elif neg_content_prompt is not None:
            with torch.inference_mode():
                (
                    _,
                    _,
                    pooled_prompt_embeds_,  # torch.Size([1, 1280])
                    _,
                ) = self.pipe.encode_prompt(
                    neg_content_prompt,
                    num_images_per_prompt=num_samples,
                    do_classifier_free_guidance=True,
                    negative_prompt=negative_prompt,
                )
                # In-place scaling has to stay inside inference mode.
                pooled_prompt_embeds_ *= neg_content_scale
        else:
            pooled_prompt_embeds_ = None

        image_prompt_embeds, uncond_image_prompt_embeds = self.get_image_embeds(pil_image, content_prompt_embeds=pooled_prompt_embeds_)

        if neg_pil_image is not None:
            neg_image_prompt_embeds = self.get_neg_image_embeds(neg_pil_image)
            # Diagnostic only: similarity of the image tokens to the negative
            # and to the zero-embedding tokens.
            cos_sim_neg = F.cosine_similarity(image_prompt_embeds, neg_image_prompt_embeds.squeeze(0).unsqueeze(1), dim=-1)
            cos_sim_uncond = F.cosine_similarity(image_prompt_embeds, uncond_image_prompt_embeds.squeeze(0).unsqueeze(1), dim=-1)
            print(f"neg cos sim is: {cos_sim_neg.diagonal()}")
            print(f"uncond cos sim is: {cos_sim_uncond.diagonal()}")
            uncond_image_prompt_embeds = neg_image_prompt_embeds

        # Tile the image tokens so there is one copy per requested sample.
        bs_embed, seq_len, _ = image_prompt_embeds.shape
        image_prompt_embeds = image_prompt_embeds.repeat(1, num_samples, 1)
        image_prompt_embeds = image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1)
        uncond_image_prompt_embeds = uncond_image_prompt_embeds.repeat(1, num_samples, 1)
        uncond_image_prompt_embeds = uncond_image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1)

        with torch.inference_mode():
            (
                prompt_embeds,
                negative_prompt_embeds,
                pooled_prompt_embeds,
                negative_pooled_prompt_embeds,
            ) = self.pipe.encode_prompt(
                prompt,
                num_images_per_prompt=num_samples,
                do_classifier_free_guidance=True,
                negative_prompt=negative_prompt,
            )
            # Image tokens are appended after the text tokens along the sequence axis.
            prompt_embeds = torch.cat([prompt_embeds, image_prompt_embeds], dim=1)
            negative_prompt_embeds = torch.cat([negative_prompt_embeds, uncond_image_prompt_embeds], dim=1)

        # No generator is built from ``seed`` here (see the docstring).
        images = self.pipe(
            prompt_embeds=prompt_embeds,
            negative_prompt_embeds=negative_prompt_embeds,
            pooled_prompt_embeds=pooled_prompt_embeds,
            negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
            num_inference_steps=num_inference_steps,
            **kwargs,
        ).images

        return images
1624
+
1625
+
1626
class IPAdapterPlus(IPAdapter):
    """IP-Adapter with fine-grained features"""

    def init_proj(self):
        """Build the float16 ``Resampler`` used as the image projection model."""
        unet_cross_dim = self.pipe.unet.config.cross_attention_dim
        resampler = Resampler(
            dim=unet_cross_dim,
            depth=4,
            dim_head=64,
            heads=12,
            num_queries=self.num_tokens,
            embedding_dim=self.image_encoder.config.hidden_size,
            output_dim=unet_cross_dim,
            ff_mult=4,
        )
        return resampler.to(self.device, dtype=torch.float16)

    @torch.inference_mode()
    def get_image_embeds(self, pil_image=None, clip_image_embeds=None):
        """Return ``(image_prompt_embeds, uncond_image_prompt_embeds)`` for ``pil_image``.

        Both are projections of the encoder's penultimate hidden states; the
        unconditional branch encodes an all-zero pixel tensor. The
        ``clip_image_embeds`` argument is not consulted.
        """
        images = [pil_image] if isinstance(pil_image, Image.Image) else pil_image
        pixels = self.clip_image_processor(images=images, return_tensors="pt").pixel_values
        pixels = pixels.to(self.device, dtype=torch.float16)

        def penultimate(batch):
            # Second-to-last hidden state of the CLIP vision encoder.
            return self.image_encoder(batch, output_hidden_states=True).hidden_states[-2]

        image_prompt_embeds = self.image_proj_model(penultimate(pixels))
        uncond_image_prompt_embeds = self.image_proj_model(penultimate(torch.zeros_like(pixels)))
        return image_prompt_embeds, uncond_image_prompt_embeds
1655
+
1656
+
1657
class IPAdapterFull(IPAdapterPlus):
    """IP-Adapter with full features"""

    def init_proj(self):
        """Build the float16 ``MLPProjModel`` used as the image projection model."""
        projector = MLPProjModel(
            cross_attention_dim=self.pipe.unet.config.cross_attention_dim,
            clip_embeddings_dim=self.image_encoder.config.hidden_size,
        )
        return projector.to(self.device, dtype=torch.float16)
1666
+
1667
+
1668
class IPAdapterPlusXL(IPAdapter):
    """SDXL"""

    def init_proj(self):
        """Build the float16 ``Resampler`` (width 1280, 20 heads) used as projection model."""
        resampler = Resampler(
            dim=1280,
            depth=4,
            dim_head=64,
            heads=20,
            num_queries=self.num_tokens,
            embedding_dim=self.image_encoder.config.hidden_size,
            output_dim=self.pipe.unet.config.cross_attention_dim,
            ff_mult=4,
        )
        return resampler.to(self.device, dtype=torch.float16)

    @torch.inference_mode()
    def get_image_embeds(self, pil_image):
        """Return ``(image_prompt_embeds, uncond_image_prompt_embeds)`` for ``pil_image``.

        Both are projections of the encoder's penultimate hidden states; the
        unconditional branch encodes an all-zero pixel tensor.
        """
        images = [pil_image] if isinstance(pil_image, Image.Image) else pil_image
        pixels = self.clip_image_processor(images=images, return_tensors="pt").pixel_values
        pixels = pixels.to(self.device, dtype=torch.float16)

        def penultimate(batch):
            # Second-to-last hidden state of the CLIP vision encoder.
            return self.image_encoder(batch, output_hidden_states=True).hidden_states[-2]

        image_prompt_embeds = self.image_proj_model(penultimate(pixels))
        uncond_image_prompt_embeds = self.image_proj_model(penultimate(torch.zeros_like(pixels)))
        return image_prompt_embeds, uncond_image_prompt_embeds

    def generate(
        self,
        pil_image,
        prompt=None,
        negative_prompt=None,
        scale=1.0,
        num_samples=4,
        seed=None,
        num_inference_steps=30,
        **kwargs,
    ):
        """Generate images from a text prompt plus an image prompt.

        Args:
            pil_image: Reference image; a PIL image or a sequence of them.
            prompt: Text prompt (string or list); a generic quality prompt is
                used when None.
            negative_prompt: Negative text prompt; a generic one when None.
            scale: Passed to ``self.set_scale``.
            num_samples: Passed to ``encode_prompt`` as ``num_images_per_prompt``
                and used to tile the image tokens.
            seed: Passed to ``get_generator`` to build the pipeline generator.
            num_inference_steps: Forwarded to the pipeline call.
            **kwargs: Forwarded unchanged to ``self.pipe``.

        Returns:
            The ``images`` attribute of the pipeline output.
        """
        self.set_scale(scale)

        num_prompts = 1 if isinstance(pil_image, Image.Image) else len(pil_image)

        prompt = "best quality, high quality" if prompt is None else prompt
        negative_prompt = (
            "monochrome, lowres, bad anatomy, worst quality, low quality"
            if negative_prompt is None
            else negative_prompt
        )

        # A scalar prompt is replicated once per reference image.
        if not isinstance(prompt, List):
            prompt = [prompt] * num_prompts
        if not isinstance(negative_prompt, List):
            negative_prompt = [negative_prompt] * num_prompts

        image_tokens, uncond_image_tokens = self.get_image_embeds(pil_image)
        # Tile the image tokens so there is one copy per requested sample.
        bs_embed, seq_len, _ = image_tokens.shape
        image_tokens = image_tokens.repeat(1, num_samples, 1)
        image_tokens = image_tokens.view(bs_embed * num_samples, seq_len, -1)
        uncond_image_tokens = uncond_image_tokens.repeat(1, num_samples, 1)
        uncond_image_tokens = uncond_image_tokens.view(bs_embed * num_samples, seq_len, -1)

        with torch.inference_mode():
            encoded = self.pipe.encode_prompt(
                prompt,
                num_images_per_prompt=num_samples,
                do_classifier_free_guidance=True,
                negative_prompt=negative_prompt,
            )
            text_embeds, neg_text_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds = encoded
            # Image tokens are appended after the text tokens along the sequence axis.
            prompt_embeds = torch.cat([text_embeds, image_tokens], dim=1)
            negative_prompt_embeds = torch.cat([neg_text_embeds, uncond_image_tokens], dim=1)

        generator = get_generator(seed, self.device)

        result = self.pipe(
            prompt_embeds=prompt_embeds,
            negative_prompt_embeds=negative_prompt_embeds,
            pooled_prompt_embeds=pooled_prompt_embeds,
            negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
            num_inference_steps=num_inference_steps,
            generator=generator,
            **kwargs,
        )
        return result.images
ip_adapter/resampler.py ADDED
@@ -0,0 +1,158 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # modified from https://github.com/mlfoundations/open_flamingo/blob/main/open_flamingo/src/helpers.py
2
+ # and https://github.com/lucidrains/imagen-pytorch/blob/main/imagen_pytorch/imagen_pytorch.py
3
+
4
+ import math
5
+
6
+ import torch
7
+ import torch.nn as nn
8
+ from einops import rearrange
9
+ from einops.layers.torch import Rearrange
10
+
11
+
12
+ # FFN
13
def FeedForward(dim, mult=4):
    """Return a bias-free LayerNorm -> Linear -> GELU -> Linear block.

    The hidden width is ``int(dim * mult)``; input and output width are ``dim``.
    """
    hidden = int(dim * mult)
    stages = [
        nn.LayerNorm(dim),
        nn.Linear(dim, hidden, bias=False),
        nn.GELU(),
        nn.Linear(hidden, dim, bias=False),
    ]
    return nn.Sequential(*stages)
21
+
22
+
23
def reshape_tensor(x, heads):
    """Split the last axis of ``x`` into attention heads.

    ``(bs, length, width)`` becomes ``(bs, heads, length, width // heads)``.
    Batch and heads stay separate axes; they are not merged.
    """
    batch, seq_len, _ = x.shape
    # (bs, length, width) -> (bs, length, heads, dim_per_head) -> (bs, heads, length, dim_per_head)
    per_head = x.view(batch, seq_len, heads, -1).transpose(1, 2)
    # The shape is already final; this reshape only restates it.
    return per_head.reshape(batch, heads, seq_len, -1)
32
+
33
+
34
class PerceiverAttention(nn.Module):
    """Cross-attention in which latent queries attend to image features and to themselves."""

    def __init__(self, *, dim, dim_head=64, heads=8):
        """Create the norms and bias-free projections for ``heads`` heads of width ``dim_head``."""
        super().__init__()
        self.scale = dim_head**-0.5
        self.dim_head = dim_head
        self.heads = heads
        hidden = heads * dim_head

        self.norm1 = nn.LayerNorm(dim)
        self.norm2 = nn.LayerNorm(dim)

        self.to_q = nn.Linear(dim, hidden, bias=False)
        self.to_kv = nn.Linear(dim, 2 * hidden, bias=False)
        self.to_out = nn.Linear(hidden, dim, bias=False)

    def forward(self, x, latents):
        """
        Args:
            x (torch.Tensor): image features
                shape (b, n1, D)
            latent (torch.Tensor): latent features
                shape (b, n2, D)
        """
        x = self.norm1(x)
        latents = self.norm2(latents)
        batch, n_latents, _ = latents.shape

        # Queries come from the latents; keys/values from features and latents together.
        queries = self.to_q(latents)
        keys, values = self.to_kv(torch.cat((x, latents), dim=-2)).chunk(2, dim=-1)

        queries = reshape_tensor(queries, self.heads)
        keys = reshape_tensor(keys, self.heads)
        values = reshape_tensor(values, self.heads)

        # Scale q and k separately: more stable with f16 than dividing the product.
        factor = 1 / math.sqrt(math.sqrt(self.dim_head))
        scores = (queries * factor) @ (keys * factor).transpose(-2, -1)
        probs = torch.softmax(scores.float(), dim=-1).type(scores.dtype)
        attended = probs @ values

        # (b, heads, n2, dim_head) -> (b, n2, heads * dim_head)
        attended = attended.permute(0, 2, 1, 3).reshape(batch, n_latents, -1)
        return self.to_out(attended)
79
+
80
+
81
class Resampler(nn.Module):
    """Perceiver-style resampler turning a feature sequence into a fixed set of latent tokens."""

    def __init__(
        self,
        dim=1024,
        depth=8,
        dim_head=64,
        heads=16,
        num_queries=8,
        embedding_dim=768,
        output_dim=1024,
        ff_mult=4,
        max_seq_len: int = 257,  # CLIP tokens + CLS token
        apply_pos_emb: bool = False,
        num_latents_mean_pooled: int = 0,  # number of latents derived from mean pooled representation of the sequence
    ):
        """Create ``num_queries`` learned latents and ``depth`` attention + feed-forward layers.

        Inputs of width ``embedding_dim`` are projected to ``dim``; the latents
        are projected to ``output_dim`` and layer-normalised on the way out.
        """
        super().__init__()
        self.pos_emb = nn.Embedding(max_seq_len, embedding_dim) if apply_pos_emb else None

        self.latents = nn.Parameter(torch.randn(1, num_queries, dim) / dim**0.5)

        self.proj_in = nn.Linear(embedding_dim, dim)

        self.proj_out = nn.Linear(dim, output_dim)
        self.norm_out = nn.LayerNorm(output_dim)

        if num_latents_mean_pooled > 0:
            self.to_latents_from_mean_pooled_seq = nn.Sequential(
                nn.LayerNorm(dim),
                nn.Linear(dim, dim * num_latents_mean_pooled),
                Rearrange("b (n d) -> b n d", n=num_latents_mean_pooled),
            )
        else:
            self.to_latents_from_mean_pooled_seq = None

        self.layers = nn.ModuleList(
            [
                nn.ModuleList(
                    [
                        PerceiverAttention(dim=dim, dim_head=dim_head, heads=heads),
                        FeedForward(dim=dim, mult=ff_mult),
                    ]
                )
                for _ in range(depth)
            ]
        )

    def forward(self, x):
        """Resample ``x`` (batch first, features last) into the latent tokens."""
        if self.pos_emb is not None:
            positions = torch.arange(x.shape[1], device=x.device)
            x = x + self.pos_emb(positions)

        # One copy of the learned latents per batch element.
        latents = self.latents.repeat(x.size(0), 1, 1)

        x = self.proj_in(x)

        if self.to_latents_from_mean_pooled_seq is not None:
            all_valid = torch.ones(x.shape[:2], device=x.device, dtype=torch.bool)
            meanpooled_seq = masked_mean(x, dim=1, mask=all_valid)
            meanpooled_latents = self.to_latents_from_mean_pooled_seq(meanpooled_seq)
            latents = torch.cat((meanpooled_latents, latents), dim=-2)

        # Residual attention followed by residual feed-forward, per layer.
        for attn, ff in self.layers:
            latents = attn(x, latents) + latents
            latents = ff(latents) + latents

        return self.norm_out(self.proj_out(latents))
148
+
149
+
150
def masked_mean(t, *, dim, mask=None):
    """Mean of ``t`` along ``dim``, counting only positions where ``mask`` is True.

    Without a mask this is a plain ``t.mean(dim=dim)``. With a ``(b, n)`` mask,
    masked-out positions are zeroed and the sum is divided by the number of
    kept positions, clamped to at least 1e-5.
    """
    if mask is None:
        return t.mean(dim=dim)

    kept_count = mask.sum(dim=dim, keepdim=True)
    keep = rearrange(mask, "b n -> b n 1")
    total = t.masked_fill(~keep, 0.0).sum(dim=dim)

    return total / kept_count.clamp(min=1e-5)
ip_adapter/utils.py ADDED
@@ -0,0 +1,142 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn.functional as F
3
+ import numpy as np
4
+ from PIL import Image
5
+
6
# Block-name lists keyed by conditioning role ("content" / "style").
# NOTE(review): presumably UNet sub-module name prefixes used to decide where
# each kind of conditioning is injected — confirm against the adapter code.
BLOCKS = {
    "content": ["down_blocks"],
    "style": ["up_blocks"],
}

# Same layout as BLOCKS; the "content" role lists no blocks here.
# NOTE(review): name suggests this variant applies to a ControlNet — confirm.
controlnet_BLOCKS = {
    "content": [],
    "style": ["down_blocks"],
}
16
+
17
+
18
def resize_width_height(width, height, min_short_side=512, max_long_side=1024):
    """Compute a size whose short side is large enough and long side small enough.

    The shorter side is first scaled up to ``min_short_side`` if it falls
    below it; afterwards, if the longer side exceeds ``max_long_side``, both
    sides are scaled down so the longer one fits. Scaled values are
    truncated with ``int``.

    Args:
        width: Original width.
        height: Original height.
        min_short_side: Minimum length enforced on the shorter side.
        max_long_side: Maximum length allowed for the longer side.

    Returns:
        A ``(new_width, new_height)`` tuple.
    """
    new_width, new_height = width, height

    # Step 1: enlarge when the shorter side is under the minimum. The side
    # that triggered the enlargement is pinned to exactly ``min_short_side``.
    if width < height:
        if width < min_short_side:
            grow = min_short_side / width
            new_width, new_height = min_short_side, int(height * grow)
    elif height < min_short_side:
        grow = min_short_side / height
        new_width, new_height = int(width * grow), min_short_side

    # Step 2: shrink when the longer side overshoots the maximum.
    longest = max(new_width, new_height)
    if longest > max_long_side:
        shrink = max_long_side / longest
        new_width = int(new_width * shrink)
        new_height = int(new_height * shrink)

    return new_width, new_height
42
+
43
def resize_content(content_image):
    """Resize a content image to dimensions that are multiples of 16.

    The target size comes from ``resize_width_height`` with both the
    minimum short side and the maximum long side set to 1024; each
    dimension is then rounded down to a multiple of 16.

    Args:
        content_image: Image object exposing ``size`` (width first, then
            height) and ``resize((width, height))``.

    Returns:
        A ``(width, height, resized_image)`` tuple.
    """
    side_limit = 1024

    source_width = content_image.size[0]
    source_height = content_image.size[1]
    new_width, new_height = resize_width_height(
        source_width,
        source_height,
        min_short_side=side_limit,
        max_long_side=side_limit,
    )

    # Floor each side to the nearest multiple of 16.
    width = new_width // 16 * 16
    height = new_height // 16 * 16

    return width, height, content_image.resize((width, height))
54
+
55
# Module-level store filled by the forward hooks: hook name -> attention map.
attn_maps = {}


def hook_fn(name):
    """Build a forward hook that moves a processor's attention map into ``attn_maps``.

    Args:
        name: Key under which the captured map is stored.

    Returns:
        A ``(module, input, output)`` callable. When it runs, if
        ``module.processor`` has an ``attn_map`` attribute, that value is
        stored in ``attn_maps[name]`` and then removed from the processor.
    """
    def forward_hook(module, input, output):
        processor = module.processor
        if not hasattr(processor, "attn_map"):
            return
        attn_maps[name] = processor.attn_map
        # Remove the attribute so the processor no longer holds the map.
        del processor.attn_map

    return forward_hook
63
+
64
def register_cross_attention_hook(unet):
    """Attach attention-capturing forward hooks to every ``attn2*`` sub-module.

    Args:
        unet: Object exposing ``named_modules()``; each yielded module whose
            last dotted name component starts with ``attn2`` receives a
            forward hook created by ``hook_fn`` with its full name.

    Returns:
        The same ``unet`` object, for call chaining.
    """
    for module_name, submodule in unet.named_modules():
        leaf_name = module_name.split('.')[-1]
        if not leaf_name.startswith('attn2'):
            continue
        submodule.register_forward_hook(hook_fn(module_name))

    return unet
70
+
71
def upscale(attn_map, target_size):
    """Resize an attention map to ``target_size`` and softmax it over dim 0.

    The map is averaged over its first dimension, transposed, reshaped to a
    2-D grid per row, bilinearly interpolated to ``target_size`` in float32,
    and normalised with a softmax across the first output dimension.

    Args:
        attn_map: 3-D tensor; after averaging dim 0 the remaining two dims
            are swapped, so the original dim 1 is the flattened spatial axis.
        target_size: ``(height, width)`` of the output grid.

    Returns:
        A float32 tensor shaped ``(attn_map.shape[2], *target_size)``.

    Raises:
        AssertionError: If no downscale factor in ``1, 2, 4, 8, 16`` matches
            the flattened spatial length.
    """
    attn_map = attn_map.mean(dim=0).permute(1, 0)

    # Find the first power-of-two factor for which the flattened spatial
    # length equals (H // f) * (W // f) / 64, i.e. the grid H/(8f) x W/(8f).
    target_h, target_w = target_size[0], target_size[1]
    flat_area = attn_map.shape[1] * 64
    temp_size = next(
        (
            (target_h // (factor * 8), target_w // (factor * 8))
            for factor in (1, 2, 4, 8, 16)
            if (target_h // factor) * (target_w // factor) == flat_area
        ),
        None,
    )

    assert temp_size is not None, "temp_size cannot is None"

    grid = attn_map.view(attn_map.shape[0], *temp_size)

    # interpolate expects a batch axis; add it, resize, then drop it again.
    resized = F.interpolate(
        grid.unsqueeze(0).to(dtype=torch.float32),
        size=target_size,
        mode='bilinear',
        align_corners=False
    )[0]

    return torch.softmax(resized, dim=0)
95
def get_net_attn_map(image_size, batch_size=2, instance_or_negative=False, detach=True):
    """Average all captured attention maps after resizing them to ``image_size``.

    Args:
        image_size: Target size forwarded to ``upscale``.
        batch_size: Number of chunks each stored map is split into along
            its first dimension.
        instance_or_negative: When true, chunk 0 is used; otherwise chunk 1.
        detach: When true, each map is moved to the CPU first. Despite the
            name, no ``.detach()`` call is made.

    Returns:
        The element-wise mean of every upscaled map in ``attn_maps``.
    """
    chunk_index = 0 if instance_or_negative else 1

    upscaled_maps = []
    for stored_map in attn_maps.values():
        if detach:
            stored_map = stored_map.cpu()
        selected = torch.chunk(stored_map, batch_size)[chunk_index].squeeze()
        upscaled_maps.append(upscale(selected, image_size))

    return torch.stack(upscaled_maps, dim=0).mean(dim=0)
109
+
110
def attnmaps2images(net_attn_maps):
    """Convert attention maps into 8-bit images, one per map.

    Each map is min-max normalised on its own to the 0-255 range, cast to
    ``uint8`` and wrapped with ``Image.fromarray``.

    Args:
        net_attn_maps: Iterable of tensors supporting ``.cpu().numpy()``
            (assumed 2-D so that each yields a single-channel image —
            TODO confirm against callers).

    Returns:
        A list of PIL images in the same order as the input maps.
    """
    images = []

    for attn_map in net_attn_maps:
        attn_map = attn_map.cpu().numpy()

        lowest = np.min(attn_map)
        value_range = np.max(attn_map) - lowest

        if value_range > 0:
            normalized_attn_map = (attn_map - lowest) / value_range * 255
        else:
            # A constant map has zero range; dividing by it would produce
            # NaN and an undefined uint8 cast. Render such a map as black.
            normalized_attn_map = np.zeros_like(attn_map)

        images.append(Image.fromarray(normalized_attn_map.astype(np.uint8)))

    return images
129
def is_torch2_available():
    """Return whether ``torch.nn.functional`` has ``scaled_dot_product_attention``."""
    has_sdpa = hasattr(F, "scaled_dot_product_attention")
    return has_sdpa
131
+
132
def get_generator(seed, device):
    """Create seeded ``torch.Generator`` object(s) on ``device``.

    Args:
        seed: ``None``, a single seed, or a list of seeds.
        device: Device passed to ``torch.Generator``.

    Returns:
        ``None`` when ``seed`` is ``None``; a list of generators (one per
        entry) when ``seed`` is a list; otherwise a single generator.
    """
    if seed is None:
        return None

    if isinstance(seed, list):
        return [torch.Generator(device).manual_seed(one_seed) for one_seed in seed]

    return torch.Generator(device).manual_seed(seed)
requirements.txt ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ diffusers==0.25.1
2
+ torch==2.0.1
3
+ torchaudio==2.0.2
4
+ torchvision==0.15.2
5
+ transformers==4.40.2
6
+ accelerate
7
+ safetensors
8
+ einops
9
+ spaces==0.19.4
10
+ omegaconf
11
+ peft
12
+ huggingface-hub==0.24.5
13
+ opencv-python
14
+ insightface
15
+ gradio
16
+ controlnet_aux
17
+ gdown
18
+ peft