AIMS168 committed on
Commit
22d83cd
·
verified ·
1 Parent(s): fb09aef

Upload 5 files

Browse files
ip_adapter/__init__.py ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
# Public entry points of the package, re-exported from the implementation module.
from .ip_adapter import IPAdapter, IPAdapterPlus, IPAdapterPlusXL, IPAdapterXL, IPAdapterFull, IPAdapterXL_CS, IPAdapter_CS
from .ip_adapter import CSGO

# Each entry needs its own trailing comma: Python concatenates adjacent string
# literals, so a missing comma silently yields one bogus export name and makes
# `from ip_adapter import *` fail with AttributeError.
__all__ = [
    "IPAdapter",
    "IPAdapterPlus",
    "IPAdapterPlusXL",
    "IPAdapterXL",
    "CSGO",
    "IPAdapterFull",
    "IPAdapterXL_CS",
    "IPAdapter_CS",
]
ip_adapter/attention_processor.py ADDED
@@ -0,0 +1,754 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # modified from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py
2
+ import torch
3
+ import torch.nn as nn
4
+ import torch.nn.functional as F
5
+
6
+
7
class AttnProcessor(nn.Module):
    r"""
    Baseline attention processor: runs the attention computation of `attn` with no image-prompt branch.

    `hidden_size` and `cross_attention_dim` are accepted by the constructor but not used by this class;
    `save_in_unet` and `atten_control` are only stored on the instance.
    """

    def __init__(
        self,
        hidden_size=None,
        cross_attention_dim=None,
        save_in_unet='down',
        atten_control=None,
    ):
        super().__init__()
        self.save_in_unet = save_in_unet
        self.atten_control = atten_control

    def __call__(
        self,
        attn,
        hidden_states,
        encoder_hidden_states=None,
        attention_mask=None,
        temb=None,
    ):
        """Apply `attn` to `hidden_states`; keys/values come from `encoder_hidden_states` when it is given."""
        shortcut = hidden_states

        if attn.spatial_norm is not None:
            hidden_states = attn.spatial_norm(hidden_states, temb)

        # 4-D inputs are flattened to (batch, height * width, channel) and restored before returning.
        is_spatial = hidden_states.ndim == 4
        if is_spatial:
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

        # The key/value source supplies the batch size and sequence length given to the mask helper.
        kv_source = hidden_states if encoder_hidden_states is None else encoder_hidden_states
        batch_size, sequence_length, _ = kv_source.shape
        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)

        if attn.group_norm is not None:
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

        query = attn.head_to_batch_dim(attn.to_q(hidden_states))

        if encoder_hidden_states is None:
            context = hidden_states
        elif attn.norm_cross:
            context = attn.norm_encoder_hidden_states(encoder_hidden_states)
        else:
            context = encoder_hidden_states

        key = attn.head_to_batch_dim(attn.to_k(context))
        value = attn.head_to_batch_dim(attn.to_v(context))

        probs = attn.get_attention_scores(query, key, attention_mask)
        hidden_states = attn.batch_to_head_dim(torch.bmm(probs, value))

        # to_out[0] is the linear projection, to_out[1] the dropout.
        hidden_states = attn.to_out[1](attn.to_out[0](hidden_states))

        if is_spatial:
            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)

        if attn.residual_connection:
            hidden_states = hidden_states + shortcut

        return hidden_states / attn.rescale_output_factor
82
+
83
+
84
class IPAttnProcessor(nn.Module):
    r"""
    Attention processor for IP-Adapter.

    When `encoder_hidden_states` is given, its trailing `num_tokens` positions are split off as
    image-prompt tokens. They are attended to through this processor's own `to_k_ip` / `to_v_ip`
    projections, and the result is added, weighted by `scale`, to the regular attention output.

    Args:
        hidden_size (`int`):
            The hidden size of the attention layer.
        cross_attention_dim (`int`):
            The number of channels in the `encoder_hidden_states`.
        scale (`float`, defaults to 1.0):
            the weight scale of image prompt.
        num_tokens (`int`, defaults to 4 when do ip_adapter_plus it should be 16):
            The context length of the image features.
        skip (`bool`, defaults to `False`):
            When `True` the image-prompt branch is not evaluated.
        save_in_unet (`str`, defaults to `'down'`):
            Stored on the instance; not read by this class.
        atten_control:
            Stored on the instance; not read by this class.
    """

    def __init__(self, hidden_size, cross_attention_dim=None, scale=1.0, num_tokens=4, skip=False, save_in_unet='down', atten_control=None):
        super().__init__()

        self.hidden_size = hidden_size
        self.cross_attention_dim = cross_attention_dim
        self.scale = scale
        self.num_tokens = num_tokens
        self.skip = skip

        self.atten_control = atten_control
        self.save_in_unet = save_in_unet

        # Separate key/value projections used only for the image-prompt tokens.
        self.to_k_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
        self.to_v_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)

    def __call__(
        self,
        attn,
        hidden_states,
        encoder_hidden_states=None,
        attention_mask=None,
        temb=None,
    ):
        """
        Run attention and, unless `skip` is set, add the image-prompt attention on top.

        As a side effect the image-prompt attention probabilities are stored in `self.attn_map`
        whenever the image branch runs.
        """
        residual = hidden_states

        if attn.spatial_norm is not None:
            hidden_states = attn.spatial_norm(hidden_states, temb)

        input_ndim = hidden_states.ndim

        if input_ndim == 4:
            # Flatten the spatial grid to (batch, height * width, channel); restored before returning.
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

        batch_size, sequence_length, _ = (
            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
        )
        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)

        if attn.group_norm is not None:
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

        query = attn.to_q(hidden_states)

        if encoder_hidden_states is None:
            # Self-attention: there are no image tokens to split off. Leaving the name unbound
            # used to raise UnboundLocalError in the image branch below.
            encoder_hidden_states = hidden_states
            ip_hidden_states = None
        else:
            # get encoder_hidden_states, ip_hidden_states
            end_pos = encoder_hidden_states.shape[1] - self.num_tokens
            encoder_hidden_states, ip_hidden_states = (
                encoder_hidden_states[:, :end_pos, :],
                encoder_hidden_states[:, end_pos:, :],
            )
            if attn.norm_cross:
                encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)

        query = attn.head_to_batch_dim(query)
        key = attn.head_to_batch_dim(key)
        value = attn.head_to_batch_dim(value)

        attention_probs = attn.get_attention_scores(query, key, attention_mask)
        hidden_states = torch.bmm(attention_probs, value)
        hidden_states = attn.batch_to_head_dim(hidden_states)

        if not self.skip and ip_hidden_states is not None:
            # for ip-adapter
            ip_key = self.to_k_ip(ip_hidden_states)
            ip_value = self.to_v_ip(ip_hidden_states)

            ip_key = attn.head_to_batch_dim(ip_key)
            ip_value = attn.head_to_batch_dim(ip_value)

            # No mask is applied to the image tokens.
            ip_attention_probs = attn.get_attention_scores(query, ip_key, None)
            self.attn_map = ip_attention_probs
            ip_hidden_states = torch.bmm(ip_attention_probs, ip_value)
            ip_hidden_states = attn.batch_to_head_dim(ip_hidden_states)

            hidden_states = hidden_states + self.scale * ip_hidden_states

        # linear proj
        hidden_states = attn.to_out[0](hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)

        if input_ndim == 4:
            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)

        if attn.residual_connection:
            hidden_states = hidden_states + residual

        hidden_states = hidden_states / attn.rescale_output_factor

        return hidden_states
194
+
195
+
196
class AttnProcessor2_0(torch.nn.Module):
    r"""
    Attention processor built on `F.scaled_dot_product_attention` (available from PyTorch 2.0).

    `hidden_size` and `cross_attention_dim` are accepted by the constructor but not used by this class;
    `save_in_unet` and `atten_control` are only stored on the instance.
    """

    def __init__(
        self,
        hidden_size=None,
        cross_attention_dim=None,
        save_in_unet='down',
        atten_control=None,
    ):
        super().__init__()
        if not hasattr(F, "scaled_dot_product_attention"):
            raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
        self.save_in_unet = save_in_unet
        self.atten_control = atten_control

    def __call__(
        self,
        attn,
        hidden_states,
        encoder_hidden_states=None,
        attention_mask=None,
        temb=None,
    ):
        """Apply `attn` to `hidden_states`; keys/values come from `encoder_hidden_states` when it is given."""
        shortcut = hidden_states

        if attn.spatial_norm is not None:
            hidden_states = attn.spatial_norm(hidden_states, temb)

        # 4-D inputs are flattened to (batch, height * width, channel) and restored before returning.
        is_spatial = hidden_states.ndim == 4
        if is_spatial:
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

        kv_source = hidden_states if encoder_hidden_states is None else encoder_hidden_states
        batch_size, sequence_length, _ = kv_source.shape

        if attention_mask is not None:
            prepared = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
            # scaled_dot_product_attention expects attention_mask shape to be
            # (batch, heads, source_length, target_length)
            attention_mask = prepared.view(batch_size, attn.heads, -1, prepared.shape[-1])

        if attn.group_norm is not None:
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

        query = attn.to_q(hidden_states)

        if encoder_hidden_states is None:
            context = hidden_states
        elif attn.norm_cross:
            context = attn.norm_encoder_hidden_states(encoder_hidden_states)
        else:
            context = encoder_hidden_states

        key = attn.to_k(context)
        value = attn.to_v(context)

        num_heads = attn.heads
        head_dim = key.shape[-1] // num_heads

        def split_heads(tensor):
            # (batch, seq, heads * head_dim) -> (batch, heads, seq, head_dim)
            return tensor.view(batch_size, -1, num_heads, head_dim).transpose(1, 2)

        query, key, value = split_heads(query), split_heads(key), split_heads(value)

        # the output of sdp = (batch, num_heads, seq_len, head_dim)
        # TODO: add support for attn.scale when we move to Torch 2.1
        attended = F.scaled_dot_product_attention(
            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
        )

        merged = attended.transpose(1, 2).reshape(batch_size, -1, num_heads * head_dim)
        hidden_states = merged.to(query.dtype)

        # to_out[0] is the linear projection, to_out[1] the dropout.
        hidden_states = attn.to_out[1](attn.to_out[0](hidden_states))

        if is_spatial:
            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)

        if attn.residual_connection:
            hidden_states = hidden_states + shortcut

        return hidden_states / attn.rescale_output_factor
287
+
288
+
289
class IPAttnProcessor2_0(torch.nn.Module):
    r"""
    Attention processor for IP-Adapter for PyTorch 2.0.

    When `encoder_hidden_states` is given, its trailing `num_tokens` positions are split off as
    image-prompt tokens. They are attended to through this processor's own `to_k_ip` / `to_v_ip`
    projections, and the result is added, weighted by `scale`, to the regular attention output.

    Args:
        hidden_size (`int`):
            The hidden size of the attention layer.
        cross_attention_dim (`int`):
            The number of channels in the `encoder_hidden_states`.
        scale (`float`, defaults to 1.0):
            the weight scale of image prompt.
        num_tokens (`int`, defaults to 4 when do ip_adapter_plus it should be 16):
            The context length of the image features.
        skip (`bool`, defaults to `False`):
            When `True` the image-prompt branch is not evaluated.
        save_in_unet (`str`, defaults to `'down'`):
            Stored on the instance; not read by this class.
        atten_control:
            Stored on the instance; not read by this class.
    """

    def __init__(self, hidden_size, cross_attention_dim=None, scale=1.0, num_tokens=4, skip=False, save_in_unet='down', atten_control=None):
        super().__init__()

        if not hasattr(F, "scaled_dot_product_attention"):
            raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")

        self.hidden_size = hidden_size
        self.cross_attention_dim = cross_attention_dim
        self.scale = scale
        self.num_tokens = num_tokens
        self.skip = skip

        self.atten_control = atten_control
        self.save_in_unet = save_in_unet

        # Separate key/value projections used only for the image-prompt tokens.
        self.to_k_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
        self.to_v_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)

    def __call__(
        self,
        attn,
        hidden_states,
        encoder_hidden_states=None,
        attention_mask=None,
        temb=None,
    ):
        """
        Run attention and, unless `skip` is set, add the image-prompt attention on top.

        As a side effect the image-prompt attention probabilities, shaped
        (batch, heads, query_len, image_tokens), are stored in `self.attn_map`
        whenever the image branch runs.
        """
        residual = hidden_states

        if attn.spatial_norm is not None:
            hidden_states = attn.spatial_norm(hidden_states, temb)

        input_ndim = hidden_states.ndim

        if input_ndim == 4:
            # Flatten the spatial grid to (batch, height * width, channel); restored before returning.
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

        batch_size, sequence_length, _ = (
            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
        )

        if attention_mask is not None:
            attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
            # scaled_dot_product_attention expects attention_mask shape to be
            # (batch, heads, source_length, target_length)
            attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])

        if attn.group_norm is not None:
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

        query = attn.to_q(hidden_states)

        if encoder_hidden_states is None:
            # Self-attention: there are no image tokens to split off. Leaving the name unbound
            # used to raise UnboundLocalError in the image branch below.
            encoder_hidden_states = hidden_states
            ip_hidden_states = None
        else:
            # get encoder_hidden_states, ip_hidden_states
            end_pos = encoder_hidden_states.shape[1] - self.num_tokens
            encoder_hidden_states, ip_hidden_states = (
                encoder_hidden_states[:, :end_pos, :],
                encoder_hidden_states[:, end_pos:, :],
            )
            if attn.norm_cross:
                encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)

        inner_dim = key.shape[-1]
        head_dim = inner_dim // attn.heads

        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

        # the output of sdp = (batch, num_heads, seq_len, head_dim)
        # TODO: add support for attn.scale when we move to Torch 2.1
        hidden_states = F.scaled_dot_product_attention(
            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
        )

        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
        hidden_states = hidden_states.to(query.dtype)

        if not self.skip and ip_hidden_states is not None:
            # for ip-adapter
            ip_key = self.to_k_ip(ip_hidden_states)
            ip_value = self.to_v_ip(ip_hidden_states)

            ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
            ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

            # the output of sdp = (batch, num_heads, seq_len, head_dim)
            # TODO: add support for attn.scale when we move to Torch 2.1
            ip_hidden_states = F.scaled_dot_product_attention(
                query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False
            )
            with torch.no_grad():
                # Softmax must be taken over the scaled query-key logits. Without the
                # parentheses `.softmax` bound to the transposed key alone, so the stored
                # map was not a probability distribution. The 1/sqrt(head_dim) factor
                # mirrors the default scaling of scaled_dot_product_attention.
                ip_logits = query @ ip_key.transpose(-2, -1) * head_dim ** -0.5
                self.attn_map = ip_logits.softmax(dim=-1)

            ip_hidden_states = ip_hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
            ip_hidden_states = ip_hidden_states.to(query.dtype)

            hidden_states = hidden_states + self.scale * ip_hidden_states

        # linear proj
        hidden_states = attn.to_out[0](hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)

        if input_ndim == 4:
            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)

        if attn.residual_connection:
            hidden_states = hidden_states + residual

        hidden_states = hidden_states / attn.rescale_output_factor

        return hidden_states
423
+
424
+
425
class IP_CS_AttnProcessor2_0(torch.nn.Module):
    r"""
    Attention processor for IP-Adapter with separate content and style image tokens, for PyTorch 2.0.

    When `encoder_hidden_states` is given, its trailing `num_content_tokens + num_style_tokens`
    positions are split off: first the content tokens, then the style tokens. Each enabled
    branch attends to its tokens and adds the result, weighted by its scale, to the regular
    attention output.

    Args:
        hidden_size (`int`):
            The hidden size of the attention layer.
        cross_attention_dim (`int`):
            The number of channels in the `encoder_hidden_states`.
        content_scale (`float`, defaults to 1.0):
            the weight scale of the content image prompt.
        style_scale (`float`, defaults to 1.0):
            the weight scale of the style image prompt.
        num_content_tokens (`int`, defaults to 4):
            The context length of the content image features.
        num_style_tokens (`int`, defaults to 4):
            The context length of the style image features.
        skip (`bool`, defaults to `False`):
            When `True` neither image branch is evaluated.
        content (`bool`, defaults to `False`):
            Enables the content branch.
        style (`bool`, defaults to `False`):
            Enables the style branch.
    """

    def __init__(self, hidden_size, cross_attention_dim=None, content_scale=1.0, style_scale=1.0, num_content_tokens=4, num_style_tokens=4,
                 skip=False, content=False, style=False):
        super().__init__()

        if not hasattr(F, "scaled_dot_product_attention"):
            raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")

        self.hidden_size = hidden_size
        self.cross_attention_dim = cross_attention_dim
        self.content_scale = content_scale
        self.style_scale = style_scale
        self.num_content_tokens = num_content_tokens
        self.num_style_tokens = num_style_tokens
        self.skip = skip

        self.content = content
        self.style = style

        # The shared image projections are only created when at least one branch is enabled.
        if self.content or self.style:
            self.to_k_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
            self.to_v_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
        # Dedicated content projections; stay None until set_content_ipa() is called.
        self.to_k_ip_content = None
        self.to_v_ip_content = None

    def set_content_ipa(self, content_scale=1.0):
        """Create dedicated content key/value projections, set `content_scale` and enable the content branch."""
        self.to_k_ip_content = nn.Linear(self.cross_attention_dim or self.hidden_size, self.hidden_size, bias=False)
        self.to_v_ip_content = nn.Linear(self.cross_attention_dim or self.hidden_size, self.hidden_size, bias=False)
        self.content_scale = content_scale
        self.content = True

    def _ip_attention(self, attn, query, ip_tokens, to_k, to_v, batch_size, head_dim):
        """Attend `query` to `ip_tokens` projected by `to_k` / `to_v`; returns (batch, seq, heads * head_dim)."""
        ip_key = to_k(ip_tokens).view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        ip_value = to_v(ip_tokens).view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

        # the output of sdp = (batch, num_heads, seq_len, head_dim)
        # TODO: add support for attn.scale when we move to Torch 2.1
        ip_out = F.scaled_dot_product_attention(
            query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False
        )
        ip_out = ip_out.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
        return ip_out.to(query.dtype)

    def __call__(
        self,
        attn,
        hidden_states,
        encoder_hidden_states=None,
        attention_mask=None,
        temb=None,
    ):
        """Run attention and add the enabled content / style image-prompt attention on top."""
        residual = hidden_states

        if attn.spatial_norm is not None:
            hidden_states = attn.spatial_norm(hidden_states, temb)

        input_ndim = hidden_states.ndim

        if input_ndim == 4:
            # Flatten the spatial grid to (batch, height * width, channel); restored before returning.
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

        batch_size, sequence_length, _ = (
            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
        )

        if attention_mask is not None:
            attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
            # scaled_dot_product_attention expects attention_mask shape to be
            # (batch, heads, source_length, target_length)
            attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])

        if attn.group_norm is not None:
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

        query = attn.to_q(hidden_states)

        if encoder_hidden_states is None:
            # Self-attention: there are no image tokens to split off. Leaving these names
            # unbound used to raise UnboundLocalError in the image branches below.
            encoder_hidden_states = hidden_states
            ip_content_hidden_states = None
            ip_style_hidden_states = None
        else:
            # get encoder_hidden_states, ip_hidden_states
            end_pos = encoder_hidden_states.shape[1] - self.num_content_tokens - self.num_style_tokens
            encoder_hidden_states, ip_content_hidden_states, ip_style_hidden_states = (
                encoder_hidden_states[:, :end_pos, :],
                encoder_hidden_states[:, end_pos:end_pos + self.num_content_tokens, :],
                encoder_hidden_states[:, end_pos + self.num_content_tokens:, :],
            )
            if attn.norm_cross:
                encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)

        inner_dim = key.shape[-1]
        head_dim = inner_dim // attn.heads

        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

        # the output of sdp = (batch, num_heads, seq_len, head_dim)
        # TODO: add support for attn.scale when we move to Torch 2.1
        hidden_states = F.scaled_dot_product_attention(
            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
        )

        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
        hidden_states = hidden_states.to(query.dtype)

        if not self.skip and self.content is True and ip_content_hidden_states is not None:
            # for ip-content-adapter: use the dedicated content projections when
            # set_content_ipa() created them, otherwise fall back to the shared ones.
            if self.to_k_ip_content is None:
                content_k, content_v = self.to_k_ip, self.to_v_ip
            else:
                content_k, content_v = self.to_k_ip_content, self.to_v_ip_content

            ip_content_hidden_states = self._ip_attention(
                attn, query, ip_content_hidden_states, content_k, content_v, batch_size, head_dim
            )
            hidden_states = hidden_states + self.content_scale * ip_content_hidden_states

        if not self.skip and self.style is True and ip_style_hidden_states is not None:
            # for ip-style-adapter
            ip_style_hidden_states = self._ip_attention(
                attn, query, ip_style_hidden_states, self.to_k_ip, self.to_v_ip, batch_size, head_dim
            )
            hidden_states = hidden_states + self.style_scale * ip_style_hidden_states

        # linear proj
        hidden_states = attn.to_out[0](hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)

        if input_ndim == 4:
            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)

        if attn.residual_connection:
            hidden_states = hidden_states + residual

        hidden_states = hidden_states / attn.rescale_output_factor

        return hidden_states
598
+
599
+ ## for controlnet
600
class CNAttnProcessor:
    r"""
    Text-only attention processor (used for ControlNet).

    When `encoder_hidden_states` is given, its trailing `num_tokens` positions are dropped, so keys
    and values are computed from the leading positions only. `save_in_unet` and `atten_control` are
    stored on the instance and not read by this class.
    """

    def __init__(self, num_tokens=4, save_in_unet='down', atten_control=None):
        self.num_tokens = num_tokens
        self.save_in_unet = save_in_unet
        self.atten_control = atten_control

    def __call__(self, attn, hidden_states, encoder_hidden_states=None, attention_mask=None, temb=None):
        """Apply `attn` to `hidden_states`, ignoring the trailing `num_tokens` context positions."""
        shortcut = hidden_states

        if attn.spatial_norm is not None:
            hidden_states = attn.spatial_norm(hidden_states, temb)

        # 4-D inputs are flattened to (batch, height * width, channel) and restored before returning.
        is_spatial = hidden_states.ndim == 4
        if is_spatial:
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

        # Note: the mask helper receives the length of the untrimmed context.
        kv_source = hidden_states if encoder_hidden_states is None else encoder_hidden_states
        batch_size, sequence_length, _ = kv_source.shape
        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)

        if attn.group_norm is not None:
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

        query = attn.head_to_batch_dim(attn.to_q(hidden_states))

        if encoder_hidden_states is None:
            context = hidden_states
        else:
            text_len = encoder_hidden_states.shape[1] - self.num_tokens
            context = encoder_hidden_states[:, :text_len]  # only use text
            if attn.norm_cross:
                context = attn.norm_encoder_hidden_states(context)

        key = attn.head_to_batch_dim(attn.to_k(context))
        value = attn.head_to_batch_dim(attn.to_v(context))

        probs = attn.get_attention_scores(query, key, attention_mask)
        hidden_states = attn.batch_to_head_dim(torch.bmm(probs, value))

        # to_out[0] is the linear projection, to_out[1] the dropout.
        hidden_states = attn.to_out[1](attn.to_out[0](hidden_states))

        if is_spatial:
            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)

        if attn.residual_connection:
            hidden_states = hidden_states + shortcut

        return hidden_states / attn.rescale_output_factor
665
+
666
+
667
class CNAttnProcessor2_0:
    r"""
    Text-only attention processor built on `F.scaled_dot_product_attention` (PyTorch 2.0).

    When `encoder_hidden_states` is given, its trailing `num_tokens` positions are dropped, so keys
    and values are computed from the leading positions only. `save_in_unet` and `atten_control` are
    stored on the instance and not read by this class.
    """

    def __init__(self, num_tokens=4, save_in_unet='down', atten_control=None):
        if not hasattr(F, "scaled_dot_product_attention"):
            raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
        self.num_tokens = num_tokens
        self.save_in_unet = save_in_unet
        self.atten_control = atten_control

    def __call__(
        self,
        attn,
        hidden_states,
        encoder_hidden_states=None,
        attention_mask=None,
        temb=None,
    ):
        """Apply `attn` to `hidden_states`, ignoring the trailing `num_tokens` context positions."""
        shortcut = hidden_states

        if attn.spatial_norm is not None:
            hidden_states = attn.spatial_norm(hidden_states, temb)

        # 4-D inputs are flattened to (batch, height * width, channel) and restored before returning.
        is_spatial = hidden_states.ndim == 4
        if is_spatial:
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

        # Note: the mask helper receives the length of the untrimmed context.
        kv_source = hidden_states if encoder_hidden_states is None else encoder_hidden_states
        batch_size, sequence_length, _ = kv_source.shape

        if attention_mask is not None:
            prepared = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
            # scaled_dot_product_attention expects attention_mask shape to be
            # (batch, heads, source_length, target_length)
            attention_mask = prepared.view(batch_size, attn.heads, -1, prepared.shape[-1])

        if attn.group_norm is not None:
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

        query = attn.to_q(hidden_states)

        if encoder_hidden_states is None:
            context = hidden_states
        else:
            text_len = encoder_hidden_states.shape[1] - self.num_tokens
            context = encoder_hidden_states[:, :text_len]  # only use text
            if attn.norm_cross:
                context = attn.norm_encoder_hidden_states(context)

        key = attn.to_k(context)
        value = attn.to_v(context)

        num_heads = attn.heads
        head_dim = key.shape[-1] // num_heads

        def split_heads(tensor):
            # (batch, seq, heads * head_dim) -> (batch, heads, seq, head_dim)
            return tensor.view(batch_size, -1, num_heads, head_dim).transpose(1, 2)

        query, key, value = split_heads(query), split_heads(key), split_heads(value)

        # the output of sdp = (batch, num_heads, seq_len, head_dim)
        # TODO: add support for attn.scale when we move to Torch 2.1
        attended = F.scaled_dot_product_attention(
            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
        )

        merged = attended.transpose(1, 2).reshape(batch_size, -1, num_heads * head_dim)
        hidden_states = merged.to(query.dtype)

        # to_out[0] is the linear projection, to_out[1] the dropout.
        hidden_states = attn.to_out[1](attn.to_out[0](hidden_states))

        if is_spatial:
            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)

        if attn.residual_connection:
            hidden_states = hidden_states + shortcut

        return hidden_states / attn.rescale_output_factor
ip_adapter/ip_adapter.py ADDED
@@ -0,0 +1,1078 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from typing import List
3
+
4
+ import torch
5
+ from diffusers import StableDiffusionPipeline
6
+ from diffusers.pipelines.controlnet import MultiControlNetModel
7
+ from PIL import Image
8
+ from safetensors import safe_open
9
+ from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection
10
+ from torchvision import transforms
11
+ from .utils import is_torch2_available, get_generator
12
+
13
+ # import torchvision.transforms.functional as Func
14
+
15
+ # from .clip_style_models import CSD_CLIP, convert_state_dict
16
+
17
+ if is_torch2_available():
18
+ from .attention_processor import (
19
+ AttnProcessor2_0 as AttnProcessor,
20
+ )
21
+ from .attention_processor import (
22
+ CNAttnProcessor2_0 as CNAttnProcessor,
23
+ )
24
+ from .attention_processor import (
25
+ IPAttnProcessor2_0 as IPAttnProcessor,
26
+ )
27
+ from .attention_processor import IP_CS_AttnProcessor2_0 as IP_CS_AttnProcessor
28
+ else:
29
+ from .attention_processor import AttnProcessor, CNAttnProcessor, IPAttnProcessor
30
+ from .resampler import Resampler
31
+
32
+ from transformers import AutoImageProcessor, AutoModel
33
+
34
+
35
class ImageProjModel(torch.nn.Module):
    """Project an image embedding into a short sequence of context tokens.

    A single linear layer maps each input vector of width
    ``clip_embeddings_dim`` to ``clip_extra_context_tokens * cross_attention_dim``
    values, which are reshaped into ``clip_extra_context_tokens`` tokens of
    width ``cross_attention_dim`` and layer-normalised.
    """

    def __init__(self, cross_attention_dim=1024, clip_embeddings_dim=1024, clip_extra_context_tokens=4):
        super().__init__()

        self.generator = None
        self.cross_attention_dim = cross_attention_dim
        self.clip_extra_context_tokens = clip_extra_context_tokens

        # One flat output vector holds all tokens; forward() splits it up.
        flat_width = self.clip_extra_context_tokens * cross_attention_dim
        self.proj = torch.nn.Linear(clip_embeddings_dim, flat_width)
        self.norm = torch.nn.LayerNorm(cross_attention_dim)

    def forward(self, image_embeds):
        """Return normalised tokens shaped ``(-1, num_tokens, cross_attention_dim)``."""
        token_shape = (-1, self.clip_extra_context_tokens, self.cross_attention_dim)
        tokens = self.proj(image_embeds).reshape(*token_shape)
        return self.norm(tokens)
55
+
56
+
57
class MLPProjModel(torch.nn.Module):
    """Two-layer MLP that maps image embeddings to the cross-attention width.

    The stack is Linear -> GELU -> Linear -> LayerNorm; the last dimension goes
    from ``clip_embeddings_dim`` to ``cross_attention_dim``.
    """

    def __init__(self, cross_attention_dim=1024, clip_embeddings_dim=1024):
        super().__init__()

        # Keep the layers in this exact order: their positions inside the
        # Sequential determine the state-dict key names.
        layers = [
            torch.nn.Linear(clip_embeddings_dim, clip_embeddings_dim),
            torch.nn.GELU(),
            torch.nn.Linear(clip_embeddings_dim, cross_attention_dim),
            torch.nn.LayerNorm(cross_attention_dim),
        ]
        self.proj = torch.nn.Sequential(*layers)

    def forward(self, image_embeds):
        """Run ``image_embeds`` through the MLP and return the result."""
        return self.proj(image_embeds)
73
+
74
+
75
class IPAdapter:
    """Attach image-prompt attention processors to a diffusion pipeline.

    On construction the pipeline is moved to ``device``, the UNet's attention
    processors are replaced (cross-attention layers get ``IPAttnProcessor``),
    a CLIP vision encoder and an image projection model are created, and the
    adapter weights are loaded from ``ip_ckpt``.

    Args:
        sd_pipe: Pipeline object exposing ``unet`` (and optionally ``controlnet``).
        image_encoder_path: Path/identifier passed to
            ``CLIPVisionModelWithProjection.from_pretrained``.
        ip_ckpt: Adapter checkpoint; ``.safetensors`` files are read with
            ``safe_open``, anything else with ``torch.load``.
        device: Device everything is moved to.
        num_tokens: Number of image context tokens.
        target_blocks: Substrings of processor names; cross-attention layers
            whose name contains one of them are active, all others are created
            with ``skip=True``. Defaults to ``["block"]``.
    """

    def __init__(self, sd_pipe, image_encoder_path, ip_ckpt, device, num_tokens=4, target_blocks=None):
        self.device = device
        self.image_encoder_path = image_encoder_path
        self.ip_ckpt = ip_ckpt
        self.num_tokens = num_tokens
        # Build the default per instance instead of sharing one list object
        # between every instance through a mutable default argument.
        self.target_blocks = ["block"] if target_blocks is None else target_blocks

        self.pipe = sd_pipe.to(self.device)
        self.set_ip_adapter()

        # load image encoder
        self.image_encoder = CLIPVisionModelWithProjection.from_pretrained(self.image_encoder_path).to(
            self.device, dtype=torch.float16
        )
        self.clip_image_processor = CLIPImageProcessor()
        # image proj model
        self.image_proj_model = self.init_proj()

        self.load_ip_adapter()

    def init_proj(self):
        """Create the fp16 ``ImageProjModel`` sized from the UNet and image encoder configs."""
        image_proj_model = ImageProjModel(
            cross_attention_dim=self.pipe.unet.config.cross_attention_dim,
            clip_embeddings_dim=self.image_encoder.config.projection_dim,
            clip_extra_context_tokens=self.num_tokens,
        ).to(self.device, dtype=torch.float16)
        return image_proj_model

    @staticmethod
    def _hidden_size_for(config, name):
        """Return the channel width of the block that owns processor ``name``.

        Raises:
            ValueError: If ``name`` does not start with ``mid_block``,
                ``up_blocks`` or ``down_blocks``.
        """
        widths = config.block_out_channels
        if name.startswith("mid_block"):
            return widths[-1]
        if name.startswith("up_blocks"):
            # Up blocks walk the channel list in reverse order.
            return list(reversed(widths))[int(name[len("up_blocks.")])]
        if name.startswith("down_blocks"):
            return widths[int(name[len("down_blocks.")])]
        raise ValueError(f"Cannot infer hidden size for attention processor {name!r}")

    def set_ip_adapter(self):
        """Install attention processors on the UNet and on any attached controlnet."""
        unet = self.pipe.unet
        attn_procs = {}
        for name in unet.attn_processors.keys():
            cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim
            if cross_attention_dim is None:
                # Self-attention layers keep the plain processor.
                attn_procs[name] = AttnProcessor()
                continue

            selected = any(block_name in name for block_name in self.target_blocks)
            # Non-targeted layers still get an IP processor, but flagged to skip.
            extra = {} if selected else {"skip": True}
            attn_procs[name] = IPAttnProcessor(
                hidden_size=self._hidden_size_for(unet.config, name),
                cross_attention_dim=cross_attention_dim,
                scale=1.0,
                num_tokens=self.num_tokens,
                **extra,
            ).to(self.device, dtype=torch.float16)
        unet.set_attn_processor(attn_procs)

        if hasattr(self.pipe, "controlnet"):
            if isinstance(self.pipe.controlnet, MultiControlNetModel):
                for controlnet in self.pipe.controlnet.nets:
                    controlnet.set_attn_processor(CNAttnProcessor(num_tokens=self.num_tokens))
            else:
                self.pipe.controlnet.set_attn_processor(CNAttnProcessor(num_tokens=self.num_tokens))

    def load_ip_adapter(self):
        """Load ``image_proj`` and ``ip_adapter`` weights from ``self.ip_ckpt``."""
        if os.path.splitext(self.ip_ckpt)[-1] == ".safetensors":
            state_dict = {"image_proj": {}, "ip_adapter": {}}
            with safe_open(self.ip_ckpt, framework="pt", device="cpu") as f:
                for key in f.keys():
                    for group in state_dict:
                        prefix = group + "."
                        if key.startswith(prefix):
                            # Strip only the leading prefix, never a later occurrence.
                            state_dict[group][key[len(prefix):]] = f.get_tensor(key)
                            break
        else:
            # NOTE(security): torch.load unpickles the file; only load
            # checkpoints from a trusted source.
            state_dict = torch.load(self.ip_ckpt, map_location="cpu")
        self.image_proj_model.load_state_dict(state_dict["image_proj"])
        ip_layers = torch.nn.ModuleList(self.pipe.unet.attn_processors.values())
        # strict=False: plain processors carry no adapter weights.
        ip_layers.load_state_dict(state_dict["ip_adapter"], strict=False)

    @torch.inference_mode()
    def get_image_embeds(self, pil_image=None, clip_image_embeds=None, content_prompt_embeds=None):
        """Return ``(image_prompt_embeds, uncond_image_prompt_embeds)``.

        Args:
            pil_image: A PIL image or a list of them; encoded with the CLIP
                image encoder when given.
            clip_image_embeds: Precomputed image embeddings, used only when
                ``pil_image`` is None.
            content_prompt_embeds: Optional tensor subtracted from the image
                embeddings before projection.

        Raises:
            ValueError: If both ``pil_image`` and ``clip_image_embeds`` are None.
        """
        if pil_image is not None:
            if isinstance(pil_image, Image.Image):
                pil_image = [pil_image]
            clip_image = self.clip_image_processor(images=pil_image, return_tensors="pt").pixel_values
            clip_image_embeds = self.image_encoder(clip_image.to(self.device, dtype=torch.float16)).image_embeds
        elif clip_image_embeds is not None:
            clip_image_embeds = clip_image_embeds.to(self.device, dtype=torch.float16)
        else:
            raise ValueError("Either pil_image or clip_image_embeds must be provided")

        if content_prompt_embeds is not None:
            clip_image_embeds = clip_image_embeds - content_prompt_embeds

        image_prompt_embeds = self.image_proj_model(clip_image_embeds)
        # The unconditional branch projects an all-zero embedding.
        uncond_image_prompt_embeds = self.image_proj_model(torch.zeros_like(clip_image_embeds))
        return image_prompt_embeds, uncond_image_prompt_embeds

    def set_scale(self, scale):
        """Set ``scale`` on every ``IPAttnProcessor`` installed on the UNet."""
        for attn_processor in self.pipe.unet.attn_processors.values():
            if isinstance(attn_processor, IPAttnProcessor):
                attn_processor.scale = scale

    def generate(
        self,
        pil_image=None,
        clip_image_embeds=None,
        prompt=None,
        negative_prompt=None,
        scale=1.0,
        num_samples=4,
        seed=None,
        guidance_scale=7.5,
        num_inference_steps=30,
        neg_content_emb=None,
        **kwargs,
    ):
        """Run the pipeline with text and image prompt embeddings concatenated.

        Args:
            pil_image: Image prompt(s); takes precedence over ``clip_image_embeds``.
            clip_image_embeds: Precomputed image embeddings.
            prompt: Text prompt or list of prompts; a default is used when None.
            negative_prompt: Negative prompt(s); a default is used when None.
            scale: Value written to each image-prompt processor before sampling.
            num_samples: Images generated per prompt.
            seed: Forwarded to ``get_generator``.
            guidance_scale: Forwarded to the pipeline.
            num_inference_steps: Forwarded to the pipeline.
            neg_content_emb: Forwarded to ``get_image_embeds`` as
                ``content_prompt_embeds``.
            **kwargs: Extra arguments forwarded to the pipeline call.

        Returns:
            The ``images`` attribute of the pipeline output.

        Raises:
            ValueError: If both ``pil_image`` and ``clip_image_embeds`` are None.
        """
        self.set_scale(scale)

        if pil_image is not None:
            num_prompts = 1 if isinstance(pil_image, Image.Image) else len(pil_image)
        elif clip_image_embeds is not None:
            num_prompts = clip_image_embeds.size(0)
        else:
            raise ValueError("Either pil_image or clip_image_embeds must be provided")

        if prompt is None:
            prompt = "best quality, high quality"
        if negative_prompt is None:
            negative_prompt = "monochrome, lowres, bad anatomy, worst quality, low quality"

        if not isinstance(prompt, list):
            prompt = [prompt] * num_prompts
        if not isinstance(negative_prompt, list):
            negative_prompt = [negative_prompt] * num_prompts

        image_prompt_embeds, uncond_image_prompt_embeds = self.get_image_embeds(
            pil_image=pil_image, clip_image_embeds=clip_image_embeds, content_prompt_embeds=neg_content_emb
        )
        # Duplicate each image prompt num_samples times along the batch axis.
        bs_embed, seq_len, _ = image_prompt_embeds.shape
        image_prompt_embeds = image_prompt_embeds.repeat(1, num_samples, 1)
        image_prompt_embeds = image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1)
        uncond_image_prompt_embeds = uncond_image_prompt_embeds.repeat(1, num_samples, 1)
        uncond_image_prompt_embeds = uncond_image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1)

        with torch.inference_mode():
            prompt_embeds_, negative_prompt_embeds_ = self.pipe.encode_prompt(
                prompt,
                device=self.device,
                num_images_per_prompt=num_samples,
                do_classifier_free_guidance=True,
                negative_prompt=negative_prompt,
            )
            # Image tokens are appended after the text tokens.
            prompt_embeds = torch.cat([prompt_embeds_, image_prompt_embeds], dim=1)
            negative_prompt_embeds = torch.cat([negative_prompt_embeds_, uncond_image_prompt_embeds], dim=1)

        generator = get_generator(seed, self.device)

        images = self.pipe(
            prompt_embeds=prompt_embeds,
            negative_prompt_embeds=negative_prompt_embeds,
            guidance_scale=guidance_scale,
            num_inference_steps=num_inference_steps,
            generator=generator,
            **kwargs,
        ).images

        return images
248
+
249
+
250
class IPAdapter_CS:
    """Image-prompt adapter with separate content and style image branches.

    On construction the pipeline is moved to ``device``, attention processors
    are installed on the UNet (and on the controlnet, if the pipeline has one),
    a content image encoder and a CLIP style image encoder are loaded, one
    projection model per branch is created, and weights are loaded from
    ``ip_ckpt``.

    Args:
        sd_pipe: Pipeline object exposing ``unet`` (and optionally ``controlnet``).
        image_encoder_path: Path/identifier for ``CLIPVisionModelWithProjection``.
        ip_ckpt: Adapter checkpoint (``.safetensors`` or a ``torch.load`` file).
        device: Device everything is moved to.
        num_content_tokens: Number of content context tokens.
        num_style_tokens: Number of style context tokens.
        target_content_blocks: Name substrings of UNet layers receiving the
            content branch. Defaults to ``["block"]``.
        target_style_blocks: Name substrings of UNet layers receiving the
            style branch. Defaults to ``["block"]``.
        content_image_encoder_path: Optional path loaded with ``AutoModel`` /
            ``AutoImageProcessor`` for the content branch; when None the CLIP
            encoder at ``image_encoder_path`` is used instead.
        controlnet_adapter: When True, the controlnet gets content/style
            processors and ``controlnet_adapter_modules`` weights are loaded.
        controlnet_target_content_blocks: Like ``target_content_blocks`` but
            for the controlnet; None is treated as no blocks.
        controlnet_target_style_blocks: Like ``target_style_blocks`` but for
            the controlnet; None is treated as no blocks.
        content_model_resampler: Forwarded to ``init_proj`` for the content branch.
        style_model_resampler: Forwarded to ``init_proj`` for the style branch.
    """

    def __init__(self, sd_pipe, image_encoder_path, ip_ckpt, device, num_content_tokens=4,
                 num_style_tokens=4,
                 target_content_blocks=None, target_style_blocks=None, content_image_encoder_path=None,
                 controlnet_adapter=False,
                 controlnet_target_content_blocks=None,
                 controlnet_target_style_blocks=None,
                 content_model_resampler=False,
                 style_model_resampler=False,
                 ):
        self.device = device
        self.image_encoder_path = image_encoder_path
        self.ip_ckpt = ip_ckpt
        self.num_content_tokens = num_content_tokens
        self.num_style_tokens = num_style_tokens
        # Defaults are built per instance rather than shared mutable lists.
        self.content_target_blocks = ["block"] if target_content_blocks is None else target_content_blocks
        self.style_target_blocks = ["block"] if target_style_blocks is None else target_style_blocks

        self.content_model_resampler = content_model_resampler
        self.style_model_resampler = style_model_resampler

        self.controlnet_adapter = controlnet_adapter
        self.controlnet_target_content_blocks = controlnet_target_content_blocks
        self.controlnet_target_style_blocks = controlnet_target_style_blocks

        self.pipe = sd_pipe.to(self.device)
        self.set_ip_adapter()
        self.content_image_encoder_path = content_image_encoder_path

        # load image encoder
        if content_image_encoder_path is not None:
            self.content_image_encoder = AutoModel.from_pretrained(content_image_encoder_path).to(
                self.device, dtype=torch.float16
            )
            self.content_image_processor = AutoImageProcessor.from_pretrained(content_image_encoder_path)
        else:
            self.content_image_encoder = CLIPVisionModelWithProjection.from_pretrained(self.image_encoder_path).to(
                self.device, dtype=torch.float16
            )
            self.content_image_processor = CLIPImageProcessor()

        self.image_encoder = CLIPVisionModelWithProjection.from_pretrained(self.image_encoder_path).to(
            self.device, dtype=torch.float16
        )

        self.clip_image_processor = CLIPImageProcessor()
        # image proj model
        self.content_image_proj_model = self.init_proj(self.num_content_tokens, content_or_style_='content',
                                                       model_resampler=self.content_model_resampler)
        self.style_image_proj_model = self.init_proj(self.num_style_tokens, content_or_style_='style',
                                                     model_resampler=self.style_model_resampler)

        self.load_ip_adapter()

    def init_proj(self, num_tokens, content_or_style_='content', model_resampler=False):
        """Create the fp16 ``ImageProjModel`` for one branch.

        The content branch sizes its input from ``content_image_encoder`` when
        a dedicated content encoder path was given; otherwise the CLIP
        ``image_encoder`` config is used. ``model_resampler`` is accepted but
        not used by this implementation.
        """
        use_content_encoder = content_or_style_ == 'content' and self.content_image_encoder_path is not None
        encoder = self.content_image_encoder if use_content_encoder else self.image_encoder
        image_proj_model = ImageProjModel(
            cross_attention_dim=self.pipe.unet.config.cross_attention_dim,
            clip_embeddings_dim=encoder.config.projection_dim,
            clip_extra_context_tokens=num_tokens,
        ).to(self.device, dtype=torch.float16)
        return image_proj_model

    @staticmethod
    def _hidden_size_for(config, name):
        """Return the channel width of the block that owns processor ``name``.

        Raises:
            ValueError: If ``name`` does not start with ``mid_block``,
                ``up_blocks`` or ``down_blocks``.
        """
        widths = config.block_out_channels
        if name.startswith("mid_block"):
            return widths[-1]
        if name.startswith("up_blocks"):
            # Up blocks walk the channel list in reverse order.
            return list(reversed(widths))[int(name[len("up_blocks.")])]
        if name.startswith("down_blocks"):
            return widths[int(name[len("down_blocks.")])]
        raise ValueError(f"Cannot infer hidden size for attention processor {name!r}")

    @staticmethod
    def _tile_for_samples(embeds, num_samples):
        """Repeat each batch entry of a 3-D tensor ``num_samples`` times along the batch axis."""
        batch, seq_len, _ = embeds.shape
        return embeds.repeat(1, num_samples, 1).view(batch * num_samples, seq_len, -1)

    def _build_cs_processors(self, model, style_blocks, content_blocks, keep_content_only):
        """Build the processor dict for ``model`` (the UNet or a controlnet).

        Args:
            model: Object exposing ``attn_processors`` and ``config``.
            style_blocks: Name substrings selecting style layers.
            content_blocks: Name substrings selecting content layers.
            keep_content_only: When False, a layer matched only by
                ``content_blocks`` ends up with a ``skip=True`` processor.
        """
        processors = {}
        for name in model.attn_processors.keys():
            cross_attention_dim = None if name.endswith("attn1.processor") else model.config.cross_attention_dim
            if cross_attention_dim is None:
                # Self-attention layers keep the plain processor.
                processors[name] = AttnProcessor()
                continue

            shared = dict(
                hidden_size=self._hidden_size_for(model.config, name),
                cross_attention_dim=cross_attention_dim,
                num_content_tokens=self.num_content_tokens,
                num_style_tokens=self.num_style_tokens,
            )
            selected = False
            for block_name in style_blocks:
                if block_name in name:
                    selected = True
                    processors[name] = IP_CS_AttnProcessor(style_scale=1.0, style=True, **shared)
            for block_name in content_blocks:
                if block_name in name:
                    if selected:
                        # Layer already has a processor: enable content on it too.
                        processors[name].set_content_ipa(content_scale=1.0)
                    else:
                        processors[name] = IP_CS_AttnProcessor(content_scale=1.0, content=True, **shared)
                        if keep_content_only:
                            selected = True
            if not selected:
                processors[name] = IP_CS_AttnProcessor(skip=True, **shared)
            processors[name].to(self.device, dtype=torch.float16)
        return processors

    def set_ip_adapter(self):
        """Install attention processors on the UNet and on any attached controlnet."""
        unet = self.pipe.unet
        # NOTE(review): keep_content_only=False reproduces the existing UNet
        # behaviour, where a layer matched only by the content blocks is
        # replaced by a skip processor (the controlnet path below keeps it).
        # Confirm whether that asymmetry is intended before changing it;
        # flipping it changes what released checkpoints produce.
        unet.set_attn_processor(
            self._build_cs_processors(
                unet, self.style_target_blocks, self.content_target_blocks, keep_content_only=False
            )
        )

        if not hasattr(self.pipe, "controlnet"):
            return

        if self.controlnet_adapter is False:
            num_tokens = self.num_content_tokens + self.num_style_tokens
            if isinstance(self.pipe.controlnet, MultiControlNetModel):
                for controlnet in self.pipe.controlnet.nets:
                    controlnet.set_attn_processor(CNAttnProcessor(num_tokens=num_tokens))
            else:
                self.pipe.controlnet.set_attn_processor(CNAttnProcessor(num_tokens=num_tokens))
        else:
            # None target lists (the constructor defaults) mean "no blocks".
            controlnet_attn_procs = self._build_cs_processors(
                self.pipe.controlnet,
                self.controlnet_target_style_blocks or (),
                self.controlnet_target_content_blocks or (),
                keep_content_only=True,
            )
            self.pipe.controlnet.set_attn_processor(controlnet_attn_procs)

    def load_ip_adapter(self):
        """Load projection, adapter and optional extra weights from ``self.ip_ckpt``."""
        if os.path.splitext(self.ip_ckpt)[-1] == ".safetensors":
            state_dict = {"content_image_proj": {}, "style_image_proj": {}, "ip_adapter": {}}
            with safe_open(self.ip_ckpt, framework="pt", device="cpu") as f:
                for key in f.keys():
                    for group in state_dict:
                        prefix = group + "."
                        if key.startswith(prefix):
                            # Strip only the leading prefix, never a later occurrence.
                            state_dict[group][key[len(prefix):]] = f.get_tensor(key)
                            break
        else:
            # NOTE(security): torch.load unpickles the file; only load
            # checkpoints from a trusted source.
            state_dict = torch.load(self.ip_ckpt, map_location="cpu")
        self.content_image_proj_model.load_state_dict(state_dict["content_image_proj"])
        self.style_image_proj_model.load_state_dict(state_dict["style_image_proj"])

        if 'conv_in_unet_sd' in state_dict:
            self.pipe.unet.conv_in.load_state_dict(state_dict["conv_in_unet_sd"], strict=True)
        ip_layers = torch.nn.ModuleList(self.pipe.unet.attn_processors.values())
        # strict=False: plain processors carry no adapter weights.
        ip_layers.load_state_dict(state_dict["ip_adapter"], strict=False)

        if self.controlnet_adapter is True:
            print('loading controlnet_adapter')
            self.pipe.controlnet.load_state_dict(state_dict["controlnet_adapter_modules"], strict=False)

    @torch.inference_mode()
    def get_image_embeds(self, pil_image=None, clip_image_embeds=None, content_prompt_embeds=None,
                         content_or_style_=''):
        """Return ``(image_prompt_embeds, uncond_image_prompt_embeds)`` for one branch.

        Args:
            pil_image: A PIL image or list of them; encoded when given.
            clip_image_embeds: Precomputed embeddings, used when ``pil_image`` is None.
            content_prompt_embeds: Accepted for interface compatibility; unused.
            content_or_style_: ``'content'`` or ``'style'``.

        Raises:
            ValueError: If ``content_or_style_`` is not a known branch, or if
                both ``pil_image`` and ``clip_image_embeds`` are None.
        """
        if content_or_style_ not in ('content', 'style'):
            raise ValueError(
                f"content_or_style_ must be 'content' or 'style', got {content_or_style_!r}"
            )
        if pil_image is None and clip_image_embeds is None:
            raise ValueError("Either pil_image or clip_image_embeds must be provided")

        if content_or_style_ == 'content':
            if pil_image is not None:
                if isinstance(pil_image, Image.Image):
                    pil_image = [pil_image]
                if self.content_image_proj_model is not None:
                    clip_image = self.content_image_processor(images=pil_image, return_tensors="pt").pixel_values
                    clip_image_embeds = self.content_image_encoder(
                        clip_image.to(self.device, dtype=torch.float16)).image_embeds
                else:
                    clip_image = self.clip_image_processor(images=pil_image, return_tensors="pt").pixel_values
                    clip_image_embeds = self.image_encoder(
                        clip_image.to(self.device, dtype=torch.float16)).image_embeds
            else:
                clip_image_embeds = clip_image_embeds.to(self.device, dtype=torch.float16)
            proj_model = self.content_image_proj_model
        else:
            if pil_image is not None:
                # use_CSD is not assigned by __init__; treat a missing
                # attribute as "CSD disabled" instead of raising AttributeError.
                if getattr(self, "use_CSD", None) is not None:
                    clip_image = self.style_preprocess(pil_image).unsqueeze(0).to(self.device, dtype=torch.float32)
                    clip_image_embeds = self.style_image_encoder(clip_image)
                else:
                    if isinstance(pil_image, Image.Image):
                        pil_image = [pil_image]
                    clip_image = self.clip_image_processor(images=pil_image, return_tensors="pt").pixel_values
                    clip_image_embeds = self.image_encoder(
                        clip_image.to(self.device, dtype=torch.float16)).image_embeds
            else:
                clip_image_embeds = clip_image_embeds.to(self.device, dtype=torch.float16)
            proj_model = self.style_image_proj_model

        image_prompt_embeds = proj_model(clip_image_embeds)
        # The unconditional branch projects an all-zero embedding.
        uncond_image_prompt_embeds = proj_model(torch.zeros_like(clip_image_embeds))
        return image_prompt_embeds, uncond_image_prompt_embeds

    def set_scale(self, content_scale, style_scale):
        """Write the content/style scales to every ``IP_CS_AttnProcessor``.

        UNet processors are always visited. Controlnet processors are visited
        only when the pipeline has a controlnet that exposes ``attn_processors``.
        """
        processor_groups = [self.pipe.unet.attn_processors.values()]
        controlnet = getattr(self.pipe, "controlnet", None)
        if controlnet is not None and hasattr(controlnet, "attn_processors"):
            processor_groups.append(controlnet.attn_processors.values())

        for group in processor_groups:
            for attn_processor in group:
                if isinstance(attn_processor, IP_CS_AttnProcessor):
                    if attn_processor.content is True:
                        attn_processor.content_scale = content_scale
                    if attn_processor.style is True:
                        attn_processor.style_scale = style_scale

    def generate(
            self,
            pil_content_image=None,
            pil_style_image=None,
            clip_content_image_embeds=None,
            clip_style_image_embeds=None,
            prompt=None,
            negative_prompt=None,
            content_scale=1.0,
            style_scale=1.0,
            num_samples=4,
            seed=None,
            guidance_scale=7.5,
            num_inference_steps=30,
            neg_content_emb=None,
            **kwargs,
    ):
        """Run the pipeline with text, content-image and style-image embeddings.

        Args:
            pil_content_image: Content image(s); takes precedence over
                ``clip_content_image_embeds``.
            pil_style_image: Style image(s); takes precedence over
                ``clip_style_image_embeds``.
            clip_content_image_embeds: Precomputed content embeddings.
            clip_style_image_embeds: Precomputed style embeddings.
            prompt: Text prompt or list of prompts; a default is used when None.
            negative_prompt: Negative prompt(s); a default is used when None.
            content_scale: Scale written to content processors.
            style_scale: Scale written to style processors.
            num_samples: Images generated per prompt.
            seed: Forwarded to ``get_generator``.
            guidance_scale: Forwarded to the pipeline.
            num_inference_steps: Forwarded to the pipeline.
            neg_content_emb: Accepted for interface compatibility; unused.
            **kwargs: Extra arguments forwarded to the pipeline call.

        Returns:
            The ``images`` attribute of the pipeline output.

        Raises:
            ValueError: If no content input or no style input is provided.
        """
        self.set_scale(content_scale, style_scale)

        if pil_content_image is not None:
            num_prompts = 1 if isinstance(pil_content_image, Image.Image) else len(pil_content_image)
        elif clip_content_image_embeds is not None:
            num_prompts = clip_content_image_embeds.size(0)
        else:
            raise ValueError("Either pil_content_image or clip_content_image_embeds must be provided")

        if prompt is None:
            prompt = "best quality, high quality"
        if negative_prompt is None:
            negative_prompt = "monochrome, lowres, bad anatomy, worst quality, low quality"

        if not isinstance(prompt, list):
            prompt = [prompt] * num_prompts
        if not isinstance(negative_prompt, list):
            negative_prompt = [negative_prompt] * num_prompts

        # The branch must be named explicitly: get_image_embeds has no
        # usable default for content_or_style_.
        content_image_prompt_embeds, uncond_content_image_prompt_embeds = self.get_image_embeds(
            pil_image=pil_content_image, clip_image_embeds=clip_content_image_embeds,
            content_or_style_='content',
        )
        style_image_prompt_embeds, uncond_style_image_prompt_embeds = self.get_image_embeds(
            pil_image=pil_style_image, clip_image_embeds=clip_style_image_embeds,
            content_or_style_='style',
        )

        # Each tensor is tiled using its own shape, so differing content and
        # style token counts are handled correctly.
        content_image_prompt_embeds = self._tile_for_samples(content_image_prompt_embeds, num_samples)
        uncond_content_image_prompt_embeds = self._tile_for_samples(uncond_content_image_prompt_embeds, num_samples)
        style_image_prompt_embeds = self._tile_for_samples(style_image_prompt_embeds, num_samples)
        uncond_style_image_prompt_embeds = self._tile_for_samples(uncond_style_image_prompt_embeds, num_samples)

        with torch.inference_mode():
            prompt_embeds_, negative_prompt_embeds_ = self.pipe.encode_prompt(
                prompt,
                device=self.device,
                num_images_per_prompt=num_samples,
                do_classifier_free_guidance=True,
                negative_prompt=negative_prompt,
            )
            # Token order along dim 1: text, then content, then style.
            prompt_embeds = torch.cat([prompt_embeds_, content_image_prompt_embeds, style_image_prompt_embeds], dim=1)
            negative_prompt_embeds = torch.cat([negative_prompt_embeds_,
                                                uncond_content_image_prompt_embeds, uncond_style_image_prompt_embeds],
                                               dim=1)

        generator = get_generator(seed, self.device)

        images = self.pipe(
            prompt_embeds=prompt_embeds,
            negative_prompt_embeds=negative_prompt_embeds,
            guidance_scale=guidance_scale,
            num_inference_steps=num_inference_steps,
            generator=generator,
            **kwargs,
        ).images

        return images
654
+
655
+
656
class IPAdapterXL_CS(IPAdapter_CS):
    """Content/style adapter whose ``generate`` also passes pooled prompt embeddings.

    ``encode_prompt`` is expected to return four tensors here (prompt,
    negative prompt, pooled prompt, negative pooled prompt), all of which are
    forwarded to the pipeline call.
    """

    @staticmethod
    def _tile_for_samples(embeds, num_samples):
        """Repeat each batch entry of a 3-D tensor ``num_samples`` times along the batch axis."""
        batch, seq_len, _ = embeds.shape
        return embeds.repeat(1, num_samples, 1).view(batch * num_samples, seq_len, -1)

    def generate(
            self,
            pil_content_image,
            pil_style_image,
            prompt=None,
            negative_prompt=None,
            content_scale=1.0,
            style_scale=1.0,
            num_samples=4,
            seed=None,
            content_image_embeds=None,
            style_image_embeds=None,
            num_inference_steps=30,
            neg_content_emb=None,
            neg_content_prompt=None,
            neg_content_scale=1.0,
            **kwargs,
    ):
        """Run the pipeline with text, content-image and style-image embeddings.

        Args:
            pil_content_image: Content image, list of images, or None when
                ``content_image_embeds`` is supplied instead.
            pil_style_image: Style image(s), or None when
                ``style_image_embeds`` is supplied instead.
            prompt: Text prompt or list of prompts; a default is used when None.
            negative_prompt: Negative prompt(s); a default is used when None.
            content_scale: Scale written to content processors.
            style_scale: Scale written to style processors.
            num_samples: Images generated per prompt.
            seed: Forwarded to ``get_generator``; the generator is also kept
                on ``self.generator``.
            content_image_embeds: Precomputed content embeddings.
            style_image_embeds: Precomputed style embeddings.
            num_inference_steps: Forwarded to the pipeline.
            neg_content_emb: Accepted for interface compatibility; unused.
            neg_content_prompt: Accepted for interface compatibility; unused.
            neg_content_scale: Accepted for interface compatibility; unused.
            **kwargs: Extra arguments forwarded to the pipeline call.

        Returns:
            The ``images`` attribute of the pipeline output.

        Raises:
            ValueError: If both ``pil_content_image`` and
                ``content_image_embeds`` are None.
        """
        self.set_scale(content_scale, style_scale)

        if pil_content_image is None:
            # Fall back to the embedding batch size instead of calling len(None).
            if content_image_embeds is None:
                raise ValueError("Either pil_content_image or content_image_embeds must be provided")
            num_prompts = content_image_embeds.size(0)
        elif isinstance(pil_content_image, Image.Image):
            num_prompts = 1
        else:
            num_prompts = len(pil_content_image)

        if prompt is None:
            prompt = "best quality, high quality"
        if negative_prompt is None:
            negative_prompt = "monochrome, lowres, bad anatomy, worst quality, low quality"

        if not isinstance(prompt, list):
            prompt = [prompt] * num_prompts
        if not isinstance(negative_prompt, list):
            negative_prompt = [negative_prompt] * num_prompts

        content_image_prompt_embeds, uncond_content_image_prompt_embeds = self.get_image_embeds(
            pil_content_image, content_image_embeds, content_or_style_='content'
        )
        style_image_prompt_embeds, uncond_style_image_prompt_embeds = self.get_image_embeds(
            pil_style_image, style_image_embeds, content_or_style_='style'
        )

        # Each tensor is tiled using its own batch size and token count.
        content_image_prompt_embeds = self._tile_for_samples(content_image_prompt_embeds, num_samples)
        uncond_content_image_prompt_embeds = self._tile_for_samples(uncond_content_image_prompt_embeds, num_samples)
        style_image_prompt_embeds = self._tile_for_samples(style_image_prompt_embeds, num_samples)
        uncond_style_image_prompt_embeds = self._tile_for_samples(uncond_style_image_prompt_embeds, num_samples)

        with torch.inference_mode():
            (
                prompt_embeds,
                negative_prompt_embeds,
                pooled_prompt_embeds,
                negative_pooled_prompt_embeds,
            ) = self.pipe.encode_prompt(
                prompt,
                num_images_per_prompt=num_samples,
                do_classifier_free_guidance=True,
                negative_prompt=negative_prompt,
            )
            # Token order along dim 1: text, then content, then style.
            prompt_embeds = torch.cat([prompt_embeds, content_image_prompt_embeds, style_image_prompt_embeds], dim=1)
            negative_prompt_embeds = torch.cat([negative_prompt_embeds,
                                                uncond_content_image_prompt_embeds, uncond_style_image_prompt_embeds],
                                               dim=1)

        self.generator = get_generator(seed, self.device)

        images = self.pipe(
            prompt_embeds=prompt_embeds,
            negative_prompt_embeds=negative_prompt_embeds,
            pooled_prompt_embeds=pooled_prompt_embeds,
            negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
            num_inference_steps=num_inference_steps,
            generator=self.generator,
            **kwargs,
        ).images
        return images
745
+
746
+
747
class CSGO(IPAdapterXL_CS):
    """SDXL adapter that projects content and style reference images into prompt tokens.

    ``init_proj`` builds one projection model per role ('content' or 'style');
    ``get_image_embeds`` encodes reference images and runs them through the
    projection model that matches the requested role.
    """

    def init_proj(self, num_tokens, content_or_style_='content', model_resampler=False):
        """Build the image projection model for one role.

        Args:
            num_tokens: Number of prompt tokens the projection produces
                (``num_queries`` for the Resampler, ``clip_extra_context_tokens``
                for ImageProjModel).
            content_or_style_: Either ``'content'`` or ``'style'``.
            model_resampler: When truthy build a ``Resampler``; otherwise an
                ``ImageProjModel``.

        Returns:
            The projection model moved to ``self.device`` in float16.

        Raises:
            ValueError: If ``content_or_style_`` is neither ``'content'`` nor
                ``'style'`` (previously this surfaced as an UnboundLocalError).
        """
        if content_or_style_ not in ('content', 'style'):
            raise ValueError(
                f"content_or_style_ must be 'content' or 'style', got {content_or_style_!r}"
            )

        # Both roles are configured identically, so a single construction path
        # replaces the two duplicated branches.
        cross_attention_dim = self.pipe.unet.config.cross_attention_dim
        if model_resampler:
            # NOTE(review): the resampler input width is taken from
            # content_image_encoder for the 'style' role as well, although
            # get_image_embeds feeds the style resampler with image_encoder
            # hidden states - confirm the two encoders share hidden_size.
            image_proj_model = Resampler(
                dim=cross_attention_dim,
                depth=4,
                dim_head=64,
                heads=12,
                num_queries=num_tokens,
                embedding_dim=self.content_image_encoder.config.hidden_size,
                output_dim=cross_attention_dim,
                ff_mult=4,
            )
        else:
            image_proj_model = ImageProjModel(
                cross_attention_dim=cross_attention_dim,
                clip_embeddings_dim=self.image_encoder.config.projection_dim,
                clip_extra_context_tokens=num_tokens,
            )
        return image_proj_model.to(self.device, dtype=torch.float16)

    @torch.inference_mode()
    def get_image_embeds(self, pil_image=None, clip_image_embeds=None, content_or_style_=''):
        """Encode reference image(s) and project them into prompt embeddings.

        Args:
            pil_image: A single PIL image or a list of them.
            clip_image_embeds: Accepted for signature compatibility; any value
                passed in is ignored and recomputed from ``pil_image``.
            content_or_style_: ``'style'`` selects the style projection model;
                every other value (including the default ``''``) selects the
                content path.

        Returns:
            ``(image_prompt_embeds, uncond_image_prompt_embeds)``. The
            unconditional embeddings are the projection of an all-zero tensor
            shaped like the encoder features.
        """
        if isinstance(pil_image, Image.Image):
            pil_image = [pil_image]

        if content_or_style_ == 'style':
            proj_model = self.style_image_proj_model
            clip_image_embeds = self._csgo_clip_features(pil_image, self.style_model_resampler)
        else:
            proj_model = self.content_image_proj_model
            if self.content_image_encoder_path is not None:
                # A dedicated content encoder is configured: use its own
                # processor and its last hidden state.
                clip_image = self.content_image_processor(images=pil_image, return_tensors="pt").pixel_values
                outputs = self.content_image_encoder(
                    clip_image.to(self.device, dtype=torch.float16),
                    output_hidden_states=True,
                )
                clip_image_embeds = outputs.last_hidden_state
            else:
                clip_image_embeds = self._csgo_clip_features(pil_image, self.content_model_resampler)

        image_prompt_embeds = proj_model(clip_image_embeds)
        uncond_image_prompt_embeds = proj_model(torch.zeros_like(clip_image_embeds))
        return image_prompt_embeds, uncond_image_prompt_embeds

    def _csgo_clip_features(self, pil_image, use_hidden_states):
        """Run the shared image encoder; return penultimate hidden states or pooled image embeds."""
        clip_image = self.clip_image_processor(images=pil_image, return_tensors="pt").pixel_values
        clip_image = clip_image.to(self.device, dtype=torch.float16)
        if use_hidden_states:
            return self.image_encoder(clip_image, output_hidden_states=True).hidden_states[-2]
        return self.image_encoder(clip_image).image_embeds
856
+
857
+
858
class IPAdapterXL(IPAdapter):
    """SDXL"""

    def generate(
        self,
        pil_image,
        prompt=None,
        negative_prompt=None,
        scale=1.0,
        num_samples=4,
        seed=None,
        num_inference_steps=30,
        neg_content_emb=None,
        neg_content_prompt=None,
        neg_content_scale=1.0,
        **kwargs,
    ):
        """Generate images conditioned on text and reference image(s).

        Args:
            pil_image: A single PIL image or a list of them.
            prompt: Text prompt (string or list); a default is used when None.
            negative_prompt: Negative prompt (string or list); a default is
                used when None.
            scale: Passed to ``self.set_scale``.
            num_samples: Images generated per prompt.
            seed: Seed (or list of seeds) handed to ``get_generator``.
            num_inference_steps: Forwarded to the pipeline.
            neg_content_emb: Optional precomputed content embedding, forwarded
                to ``get_image_embeds`` as ``content_prompt_embeds``. It takes
                precedence over ``neg_content_prompt`` and is not scaled.
            neg_content_prompt: Optional text whose pooled prompt embedding,
                multiplied by ``neg_content_scale``, is used as the content
                embedding when ``neg_content_emb`` is not given.
            neg_content_scale: Multiplier for the pooled embedding derived
                from ``neg_content_prompt``.
            **kwargs: Forwarded to the pipeline call.

        Returns:
            The ``.images`` attribute of the pipeline output.
        """
        self.set_scale(scale)

        num_prompts = 1 if isinstance(pil_image, Image.Image) else len(pil_image)

        if prompt is None:
            prompt = "best quality, high quality"
        if negative_prompt is None:
            negative_prompt = "monochrome, lowres, bad anatomy, worst quality, low quality"

        if not isinstance(prompt, list):
            prompt = [prompt] * num_prompts
        if not isinstance(negative_prompt, list):
            negative_prompt = [negative_prompt] * num_prompts

        # The original branches were inverted: a caller-supplied
        # neg_content_emb was discarded and replaced by None. An explicit
        # embedding now wins, then the prompt-derived one, then None.
        if neg_content_emb is not None:
            pooled_prompt_embeds_ = neg_content_emb
        elif neg_content_prompt is not None:
            with torch.inference_mode():
                (
                    _prompt_embeds,
                    _negative_prompt_embeds,
                    pooled_prompt_embeds_,
                    _negative_pooled_prompt_embeds,
                ) = self.pipe.encode_prompt(
                    neg_content_prompt,
                    num_images_per_prompt=num_samples,
                    do_classifier_free_guidance=True,
                    negative_prompt=negative_prompt,
                )
                pooled_prompt_embeds_ *= neg_content_scale
        else:
            pooled_prompt_embeds_ = None

        image_prompt_embeds, uncond_image_prompt_embeds = self.get_image_embeds(
            pil_image, content_prompt_embeds=pooled_prompt_embeds_
        )
        # Tile the image tokens so every sample of a prompt gets its own copy.
        bs_embed, seq_len, _ = image_prompt_embeds.shape
        image_prompt_embeds = image_prompt_embeds.repeat(1, num_samples, 1)
        image_prompt_embeds = image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1)
        uncond_image_prompt_embeds = uncond_image_prompt_embeds.repeat(1, num_samples, 1)
        uncond_image_prompt_embeds = uncond_image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1)

        with torch.inference_mode():
            (
                prompt_embeds,
                negative_prompt_embeds,
                pooled_prompt_embeds,
                negative_pooled_prompt_embeds,
            ) = self.pipe.encode_prompt(
                prompt,
                num_images_per_prompt=num_samples,
                do_classifier_free_guidance=True,
                negative_prompt=negative_prompt,
            )
            # Image tokens are appended after the text tokens along the sequence axis.
            prompt_embeds = torch.cat([prompt_embeds, image_prompt_embeds], dim=1)
            negative_prompt_embeds = torch.cat([negative_prompt_embeds, uncond_image_prompt_embeds], dim=1)

        self.generator = get_generator(seed, self.device)

        images = self.pipe(
            prompt_embeds=prompt_embeds,
            negative_prompt_embeds=negative_prompt_embeds,
            pooled_prompt_embeds=pooled_prompt_embeds,
            negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
            num_inference_steps=num_inference_steps,
            generator=self.generator,
            **kwargs,
        ).images

        return images
945
+
946
+
947
class IPAdapterPlus(IPAdapter):
    """IP-Adapter variant that resamples penultimate image-encoder hidden states."""

    def init_proj(self):
        """Return a float16 Resampler sized from the UNet and image-encoder configs."""
        cross_dim = self.pipe.unet.config.cross_attention_dim
        resampler = Resampler(
            dim=cross_dim,
            depth=4,
            dim_head=64,
            heads=12,
            num_queries=self.num_tokens,
            embedding_dim=self.image_encoder.config.hidden_size,
            output_dim=cross_dim,
            ff_mult=4,
        )
        return resampler.to(self.device, dtype=torch.float16)

    @torch.inference_mode()
    def get_image_embeds(self, pil_image=None, clip_image_embeds=None):
        """Return ``(cond, uncond)`` image prompt embeddings for ``pil_image``.

        ``clip_image_embeds`` is accepted but not used; features are always
        recomputed from ``pil_image``. The unconditional branch encodes an
        all-zero pixel tensor.
        """
        images = [pil_image] if isinstance(pil_image, Image.Image) else pil_image
        pixels = self.clip_image_processor(images=images, return_tensors="pt").pixel_values
        pixels = pixels.to(self.device, dtype=torch.float16)

        def penultimate_features(batch):
            # Second-to-last hidden state of the image encoder.
            return self.image_encoder(batch, output_hidden_states=True).hidden_states[-2]

        cond_embeds = self.image_proj_model(penultimate_features(pixels))
        uncond_embeds = self.image_proj_model(penultimate_features(torch.zeros_like(pixels)))
        return cond_embeds, uncond_embeds
976
+
977
+
978
class IPAdapterFull(IPAdapterPlus):
    """IP-Adapter variant that swaps the Resampler for an MLP projection."""

    def init_proj(self):
        """Return a float16 MLPProjModel sized from the UNet and image-encoder configs."""
        projector = MLPProjModel(
            cross_attention_dim=self.pipe.unet.config.cross_attention_dim,
            clip_embeddings_dim=self.image_encoder.config.hidden_size,
        )
        return projector.to(self.device, dtype=torch.float16)
987
+
988
+
989
class IPAdapterPlusXL(IPAdapter):
    """SDXL IP-Adapter that resamples penultimate image-encoder hidden states."""

    def init_proj(self):
        """Return a float16 Resampler (width 1280, 20 heads) for the SDXL UNet."""
        resampler = Resampler(
            dim=1280,
            depth=4,
            dim_head=64,
            heads=20,
            num_queries=self.num_tokens,
            embedding_dim=self.image_encoder.config.hidden_size,
            output_dim=self.pipe.unet.config.cross_attention_dim,
            ff_mult=4,
        )
        return resampler.to(self.device, dtype=torch.float16)

    @torch.inference_mode()
    def get_image_embeds(self, pil_image):
        """Return ``(cond, uncond)`` image prompt embeddings for ``pil_image``.

        The unconditional branch encodes an all-zero pixel tensor.
        """
        images = [pil_image] if isinstance(pil_image, Image.Image) else pil_image
        pixels = self.clip_image_processor(images=images, return_tensors="pt").pixel_values
        pixels = pixels.to(self.device, dtype=torch.float16)

        def penultimate_features(batch):
            # Second-to-last hidden state of the image encoder.
            return self.image_encoder(batch, output_hidden_states=True).hidden_states[-2]

        cond_embeds = self.image_proj_model(penultimate_features(pixels))
        uncond_embeds = self.image_proj_model(penultimate_features(torch.zeros_like(pixels)))
        return cond_embeds, uncond_embeds

    def generate(
        self,
        pil_image,
        prompt=None,
        negative_prompt=None,
        scale=1.0,
        num_samples=4,
        seed=None,
        num_inference_steps=30,
        **kwargs,
    ):
        """Generate images conditioned on text and reference image(s).

        Default prompts are substituted when ``prompt`` / ``negative_prompt``
        are None, and non-list prompts are repeated once per reference image.
        Returns the ``.images`` attribute of the pipeline output.
        """
        self.set_scale(scale)

        if isinstance(pil_image, Image.Image):
            num_prompts = 1
        else:
            num_prompts = len(pil_image)

        prompt = "best quality, high quality" if prompt is None else prompt
        if negative_prompt is None:
            negative_prompt = "monochrome, lowres, bad anatomy, worst quality, low quality"

        if not isinstance(prompt, List):
            prompt = [prompt] * num_prompts
        if not isinstance(negative_prompt, List):
            negative_prompt = [negative_prompt] * num_prompts

        cond_tokens, uncond_tokens = self.get_image_embeds(pil_image)
        batch, seq_len, _ = cond_tokens.shape
        tiled_batch = batch * num_samples
        # Tile the image tokens so every sample of a prompt gets its own copy.
        cond_tokens = cond_tokens.repeat(1, num_samples, 1).view(tiled_batch, seq_len, -1)
        uncond_tokens = uncond_tokens.repeat(1, num_samples, 1).view(tiled_batch, seq_len, -1)

        with torch.inference_mode():
            encoded = self.pipe.encode_prompt(
                prompt,
                num_images_per_prompt=num_samples,
                do_classifier_free_guidance=True,
                negative_prompt=negative_prompt,
            )
            (
                prompt_embeds,
                negative_prompt_embeds,
                pooled_prompt_embeds,
                negative_pooled_prompt_embeds,
            ) = encoded
            # Image tokens are appended after the text tokens along the sequence axis.
            prompt_embeds = torch.cat([prompt_embeds, cond_tokens], dim=1)
            negative_prompt_embeds = torch.cat([negative_prompt_embeds, uncond_tokens], dim=1)

        generator = get_generator(seed, self.device)

        result = self.pipe(
            prompt_embeds=prompt_embeds,
            negative_prompt_embeds=negative_prompt_embeds,
            pooled_prompt_embeds=pooled_prompt_embeds,
            negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
            num_inference_steps=num_inference_steps,
            generator=generator,
            **kwargs,
        )
        return result.images
ip_adapter/resampler.py ADDED
@@ -0,0 +1,158 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # modified from https://github.com/mlfoundations/open_flamingo/blob/main/open_flamingo/src/helpers.py
2
+ # and https://github.com/lucidrains/imagen-pytorch/blob/main/imagen_pytorch/imagen_pytorch.py
3
+
4
+ import math
5
+
6
+ import torch
7
+ import torch.nn as nn
8
+ from einops import rearrange
9
+ from einops.layers.torch import Rearrange
10
+
11
+
12
+ # FFN
13
def FeedForward(dim, mult=4):
    """Return a LayerNorm -> Linear -> GELU -> Linear block with bias-free linears.

    The hidden width is ``int(dim * mult)``; input and output width are ``dim``.
    """
    hidden = int(dim * mult)
    stages = [
        nn.LayerNorm(dim),
        nn.Linear(dim, hidden, bias=False),
        nn.GELU(),
        nn.Linear(hidden, dim, bias=False),
    ]
    return nn.Sequential(*stages)
21
+
22
+
23
def reshape_tensor(x, heads):
    """Split the last axis of a 3-D tensor into heads.

    (bs, length, width) -> (bs, heads, length, width // heads)
    """
    bs, length, _ = x.shape
    # Split the width into heads, then move the head axis ahead of the length axis.
    per_head = x.view(bs, length, heads, -1).transpose(1, 2)
    # reshape materialises the transposed layout; the result stays 4-D.
    return per_head.reshape(bs, heads, length, -1)
32
+
33
+
34
class PerceiverAttention(nn.Module):
    """Cross-attention where latents query the concatenation of inputs and latents."""

    def __init__(self, *, dim, dim_head=64, heads=8):
        super().__init__()
        self.scale = dim_head**-0.5
        self.dim_head = dim_head
        self.heads = heads
        inner_dim = heads * dim_head

        # Separate norms for the image features and the latents.
        self.norm1 = nn.LayerNorm(dim)
        self.norm2 = nn.LayerNorm(dim)

        self.to_q = nn.Linear(dim, inner_dim, bias=False)
        self.to_kv = nn.Linear(dim, 2 * inner_dim, bias=False)
        self.to_out = nn.Linear(inner_dim, dim, bias=False)

    def forward(self, x, latents):
        """
        Args:
            x (torch.Tensor): image features
                shape (b, n1, D)
            latents (torch.Tensor): latent features
                shape (b, n2, D)

        Returns:
            torch.Tensor: attended latents, projected by ``to_out``.
        """
        x = self.norm1(x)
        latents = self.norm2(latents)
        batch, num_latents, _ = latents.shape

        queries = self.to_q(latents)
        # Keys and values see the image features followed by the latents.
        keys, values = self.to_kv(torch.cat((x, latents), dim=-2)).chunk(2, dim=-1)

        queries = reshape_tensor(queries, self.heads)
        keys = reshape_tensor(keys, self.heads)
        values = reshape_tensor(values, self.heads)

        # Scale q and k separately by dim_head ** -0.25:
        # more stable with f16 than dividing afterwards.
        scale = 1 / math.sqrt(math.sqrt(self.dim_head))
        weight = (queries * scale) @ (keys * scale).transpose(-2, -1)
        weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)
        attended = weight @ values

        # (b, heads, n2, dim_head) -> (b, n2, heads * dim_head)
        merged = attended.permute(0, 2, 1, 3).reshape(batch, num_latents, -1)
        return self.to_out(merged)
79
+
80
+
81
class Resampler(nn.Module):
    """Stack of PerceiverAttention + FeedForward layers over learned latent queries."""

    def __init__(
        self,
        dim=1024,
        depth=8,
        dim_head=64,
        heads=16,
        num_queries=8,
        embedding_dim=768,
        output_dim=1024,
        ff_mult=4,
        max_seq_len: int = 257,  # CLIP tokens + CLS token
        apply_pos_emb: bool = False,
        num_latents_mean_pooled: int = 0,  # number of latents derived from mean pooled representation of the sequence
    ):
        super().__init__()
        # Optional learned positional embedding added to the raw input features.
        if apply_pos_emb:
            self.pos_emb = nn.Embedding(max_seq_len, embedding_dim)
        else:
            self.pos_emb = None

        # Learned query latents, shared across the batch.
        self.latents = nn.Parameter(torch.randn(1, num_queries, dim) / dim**0.5)

        self.proj_in = nn.Linear(embedding_dim, dim)

        self.proj_out = nn.Linear(dim, output_dim)
        self.norm_out = nn.LayerNorm(output_dim)

        # Optional extra latents derived from the mean-pooled input sequence.
        if num_latents_mean_pooled > 0:
            self.to_latents_from_mean_pooled_seq = nn.Sequential(
                nn.LayerNorm(dim),
                nn.Linear(dim, dim * num_latents_mean_pooled),
                Rearrange("b (n d) -> b n d", n=num_latents_mean_pooled),
            )
        else:
            self.to_latents_from_mean_pooled_seq = None

        self.layers = nn.ModuleList(
            [
                nn.ModuleList(
                    [
                        PerceiverAttention(dim=dim, dim_head=dim_head, heads=heads),
                        FeedForward(dim=dim, mult=ff_mult),
                    ]
                )
                for _ in range(depth)
            ]
        )

    def forward(self, x):
        """Project ``x``, refine the latents against it layer by layer, and return the normed output."""
        if self.pos_emb is not None:
            positions = torch.arange(x.shape[1], device=x.device)
            x = x + self.pos_emb(positions)

        latents = self.latents.repeat(x.size(0), 1, 1)

        x = self.proj_in(x)

        if self.to_latents_from_mean_pooled_seq:
            # All positions are valid, so the mask is all ones.
            full_mask = torch.ones(x.shape[:2], device=x.device, dtype=torch.bool)
            pooled = masked_mean(x, dim=1, mask=full_mask)
            pooled_latents = self.to_latents_from_mean_pooled_seq(pooled)
            latents = torch.cat((pooled_latents, latents), dim=-2)

        # Residual attention followed by residual feed-forward, per layer.
        for attn, ff in self.layers:
            latents = attn(x, latents) + latents
            latents = ff(latents) + latents

        return self.norm_out(self.proj_out(latents))
148
+
149
+
150
def masked_mean(t, *, dim, mask=None):
    """Average ``t`` along ``dim``, counting only positions where ``mask`` is True.

    Without a mask this is a plain ``t.mean(dim=dim)``. With a mask (rearranged
    from ``b n`` to ``b n 1``), masked-out entries are zeroed before summing and
    the sum is divided by the per-row count of kept positions, clamped to at
    least 1e-5.
    """
    if mask is None:
        return t.mean(dim=dim)

    kept_count = mask.sum(dim=dim, keepdim=True)
    expanded_mask = rearrange(mask, "b n -> b n 1")
    zeroed = t.masked_fill(~expanded_mask, 0.0)
    total = zeroed.sum(dim=dim)
    return total / kept_count.clamp(min=1e-5)
ip_adapter/utils.py ADDED
@@ -0,0 +1,142 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn.functional as F
3
+ import numpy as np
4
+ from PIL import Image
5
+
6
# Block-name lists keyed by role ('content' / 'style').
# NOTE(review): presumably matched against UNet / ControlNet attention-module
# names to decide where each adapter is injected - confirm against callers.
BLOCKS = dict(
    content=['down_blocks'],
    style=["up_blocks"],
)

controlnet_BLOCKS = dict(
    content=[],
    style=["down_blocks"],
)
16
+
17
+
18
def resize_width_height(width, height, min_short_side=512, max_long_side=1024):
    """Compute a new (width, height) that respects short- and long-side limits.

    The shorter side is first scaled up to ``min_short_side`` when it is
    smaller; afterwards the longer side is scaled down to ``max_long_side``
    when it is larger. Scaled values are truncated with ``int``.
    """
    new_width, new_height = width, height

    shorter = width if width < height else height
    if shorter < min_short_side:
        grow = min_short_side / shorter
        if width < height:
            new_width, new_height = min_short_side, int(height * grow)
        else:
            new_width, new_height = int(width * grow), min_short_side

    longer = max(new_width, new_height)
    if longer > max_long_side:
        shrink = max_long_side / longer
        new_width = int(new_width * shrink)
        new_height = int(new_height * shrink)

    return new_width, new_height
42
+
43
def resize_content(content_image):
    """Resize ``content_image`` with both side limits at 1024, rounded down to multiples of 16.

    Returns:
        ``(width, height, resized_image)``.
    """
    side_limit = 1024
    source_width, source_height = content_image.size[0], content_image.size[1]
    scaled_width, scaled_height = resize_width_height(
        source_width,
        source_height,
        min_short_side=side_limit,
        max_long_side=side_limit,
    )
    # Round each side down to a multiple of 16.
    width = scaled_width // 16 * 16
    height = scaled_height // 16 * 16
    content_image = content_image.resize((width, height))

    return width, height, content_image
54
+
55
# Module-level store filled by the forward hooks: module name -> attention map.
attn_maps = {}


def hook_fn(name):
    """Return a forward hook that moves ``module.processor.attn_map`` into ``attn_maps[name]``."""

    def forward_hook(module, _inputs, _output):
        processor = module.processor
        if not hasattr(processor, "attn_map"):
            return
        attn_maps[name] = processor.attn_map
        # Remove the attribute from the processor once it has been stored.
        del processor.attn_map

    return forward_hook
63
+
64
def register_cross_attention_hook(unet):
    """Attach an attention-map capturing hook to every submodule whose last name part starts with 'attn2'.

    Returns the same ``unet`` object.
    """
    for qualified_name, submodule in unet.named_modules():
        leaf_name = qualified_name.split('.')[-1]
        if not leaf_name.startswith('attn2'):
            continue
        submodule.register_forward_hook(hook_fn(qualified_name))

    return unet
70
+
71
def upscale(attn_map, target_size):
    """Average an attention map over its first axis and resize it to ``target_size``.

    After averaging and transposing, the flattened spatial axis is folded back
    into a 2-D grid whose size is found by trying downscale factors 1, 2, 4, 8
    and 16 of ``target_size``. The grid is bilinearly interpolated to
    ``target_size`` and a softmax is applied over the leading axis.
    """
    attn_map = torch.mean(attn_map, dim=0).permute(1, 0)
    flat_positions = attn_map.shape[1]
    temp_size = None

    for scale in (1, 2, 4, 8, 16):
        candidate = (target_size[0] // scale) * (target_size[1] // scale)
        if candidate == flat_positions * 64:
            divisor = scale * 8
            temp_size = (target_size[0] // divisor, target_size[1] // divisor)
            break

    assert temp_size is not None, "temp_size cannot is None"

    grid = attn_map.view(attn_map.shape[0], *temp_size)

    resized = F.interpolate(
        grid.unsqueeze(0).to(dtype=torch.float32),
        size=target_size,
        mode='bilinear',
        align_corners=False
    )[0]

    return torch.softmax(resized, dim=0)
95
def get_net_attn_map(image_size, batch_size=2, instance_or_negative=False, detach=True):
    """Average all stored attention maps after upscaling each to ``image_size``.

    Each stored map is split into ``batch_size`` chunks; chunk 0 is used when
    ``instance_or_negative`` is truthy, chunk 1 otherwise. When ``detach`` is
    truthy the map is moved to the CPU first.
    """
    chunk_index = 0 if instance_or_negative else 1
    upscaled_maps = []

    for stored_map in attn_maps.values():
        if detach:
            stored_map = stored_map.cpu()
        selected = torch.chunk(stored_map, batch_size)[chunk_index].squeeze()
        upscaled_maps.append(upscale(selected, image_size))

    stacked = torch.stack(upscaled_maps, dim=0)
    return torch.mean(stacked, dim=0)
109
+
110
def attnmaps2images(net_attn_maps):
    """Convert each attention map into an 8-bit PIL image.

    Every map is min-max normalised to the 0..255 range independently. A map
    whose values are all equal (or whose range is NaN) becomes an all-zero
    image instead of dividing by zero.

    Args:
        net_attn_maps: Iterable of tensors, one per output image.

    Returns:
        List of PIL images in the same order as the input.
    """
    images = []

    for attn_map in net_attn_maps:
        # detach() lets tensors that still carry gradients be converted too.
        values = attn_map.detach().cpu().numpy()

        lowest = np.min(values)
        value_range = np.max(values) - lowest
        if value_range > 0:
            normalized = (values - lowest) / value_range * 255
        else:
            # Constant map: 0/0 would give NaN and an undefined uint8 cast.
            normalized = np.zeros_like(values)

        images.append(Image.fromarray(normalized.astype(np.uint8)))

    return images
129
def is_torch2_available():
    """Return True when ``torch.nn.functional`` exposes ``scaled_dot_product_attention``."""
    has_sdpa = hasattr(F, "scaled_dot_product_attention")
    return has_sdpa
131
+
132
def get_generator(seed, device):
    """Build seeded ``torch.Generator`` object(s) on ``device``.

    Returns None when ``seed`` is None, a list of generators (one per entry)
    when ``seed`` is a list, and a single generator otherwise.
    """
    if seed is None:
        return None

    if isinstance(seed, list):
        return [torch.Generator(device).manual_seed(seed_item) for seed_item in seed]

    return torch.Generator(device).manual_seed(seed)