hideosnes committed on
Commit 2fdc658
1 Parent(s): 1fd342e

Create ip_adapter.py

Files changed (1)
  1. ip_adapter/ip_adapter.py +461 -0
ip_adapter/ip_adapter.py ADDED
@@ -0,0 +1,461 @@
import os
from typing import List

import torch
from diffusers import StableDiffusionPipeline
from diffusers.pipelines.controlnet import MultiControlNetModel
from PIL import Image
from safetensors import safe_open
from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection

from .utils import is_torch2_available, get_generator

if is_torch2_available():
    from .attention_processor import (
        AttnProcessor2_0 as AttnProcessor,
    )
    from .attention_processor import (
        CNAttnProcessor2_0 as CNAttnProcessor,
    )
    from .attention_processor import (
        IPAttnProcessor2_0 as IPAttnProcessor,
    )
else:
    from .attention_processor import AttnProcessor, CNAttnProcessor, IPAttnProcessor
from .resampler import Resampler


class ImageProjModel(torch.nn.Module):
    """Projection Model"""

    def __init__(self, cross_attention_dim=1024, clip_embeddings_dim=1024, clip_extra_context_tokens=4):
        super().__init__()

        self.generator = None
        self.cross_attention_dim = cross_attention_dim
        self.clip_extra_context_tokens = clip_extra_context_tokens
        self.proj = torch.nn.Linear(clip_embeddings_dim, self.clip_extra_context_tokens * cross_attention_dim)
        self.norm = torch.nn.LayerNorm(cross_attention_dim)

    def forward(self, image_embeds):
        # project the global CLIP image embedding to `clip_extra_context_tokens` context tokens
        embeds = image_embeds
        clip_extra_context_tokens = self.proj(embeds).reshape(
            -1, self.clip_extra_context_tokens, self.cross_attention_dim
        )
        clip_extra_context_tokens = self.norm(clip_extra_context_tokens)
        return clip_extra_context_tokens


class MLPProjModel(torch.nn.Module):
    """MLP projection model for image prompt embeddings"""
    def __init__(self, cross_attention_dim=1024, clip_embeddings_dim=1024):
        super().__init__()

        self.proj = torch.nn.Sequential(
            torch.nn.Linear(clip_embeddings_dim, clip_embeddings_dim),
            torch.nn.GELU(),
            torch.nn.Linear(clip_embeddings_dim, cross_attention_dim),
            torch.nn.LayerNorm(cross_attention_dim)
        )

    def forward(self, image_embeds):
        clip_extra_context_tokens = self.proj(image_embeds)
        return clip_extra_context_tokens


class IPAdapter:
    def __init__(self, sd_pipe, image_encoder_path, ip_ckpt, device, num_tokens=4, target_blocks=["block"]):
        self.device = device
        self.image_encoder_path = image_encoder_path
        self.ip_ckpt = ip_ckpt
        self.num_tokens = num_tokens
        self.target_blocks = target_blocks

        self.pipe = sd_pipe.to(self.device)
        self.set_ip_adapter()

        # load image encoder
        self.image_encoder = CLIPVisionModelWithProjection.from_pretrained(self.image_encoder_path).to(
            self.device, dtype=torch.float16
        )
        self.clip_image_processor = CLIPImageProcessor()
        # image proj model
        self.image_proj_model = self.init_proj()

        self.load_ip_adapter()

    def init_proj(self):
        image_proj_model = ImageProjModel(
            cross_attention_dim=self.pipe.unet.config.cross_attention_dim,
            clip_embeddings_dim=self.image_encoder.config.projection_dim,
            clip_extra_context_tokens=self.num_tokens,
        ).to(self.device, dtype=torch.float16)
        return image_proj_model

    def set_ip_adapter(self):
        unet = self.pipe.unet
        attn_procs = {}
        for name in unet.attn_processors.keys():
            # self-attention layers ("attn1") carry no prompt context; cross-attention layers do
            cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim
            if name.startswith("mid_block"):
                hidden_size = unet.config.block_out_channels[-1]
            elif name.startswith("up_blocks"):
                block_id = int(name[len("up_blocks.")])
                hidden_size = list(reversed(unet.config.block_out_channels))[block_id]
            elif name.startswith("down_blocks"):
                block_id = int(name[len("down_blocks.")])
                hidden_size = unet.config.block_out_channels[block_id]
            if cross_attention_dim is None:
                attn_procs[name] = AttnProcessor()
            else:
                # only cross-attention layers whose name matches one of target_blocks receive
                # the image prompt; the remaining layers get an IP processor with skip=True
                selected = False
                for block_name in self.target_blocks:
                    if block_name in name:
                        selected = True
                        break
                if selected:
                    attn_procs[name] = IPAttnProcessor(
                        hidden_size=hidden_size,
                        cross_attention_dim=cross_attention_dim,
                        scale=1.0,
                        num_tokens=self.num_tokens,
                    ).to(self.device, dtype=torch.float16)
                else:
                    attn_procs[name] = IPAttnProcessor(
                        hidden_size=hidden_size,
                        cross_attention_dim=cross_attention_dim,
                        scale=1.0,
                        num_tokens=self.num_tokens,
                        skip=True
                    ).to(self.device, dtype=torch.float16)
        unet.set_attn_processor(attn_procs)
        if hasattr(self.pipe, "controlnet"):
            if isinstance(self.pipe.controlnet, MultiControlNetModel):
                for controlnet in self.pipe.controlnet.nets:
                    controlnet.set_attn_processor(CNAttnProcessor(num_tokens=self.num_tokens))
            else:
                self.pipe.controlnet.set_attn_processor(CNAttnProcessor(num_tokens=self.num_tokens))

    def load_ip_adapter(self):
        # checkpoints are either .safetensors (keys prefixed "image_proj."/"ip_adapter.") or a torch-format file
        if os.path.splitext(self.ip_ckpt)[-1] == ".safetensors":
            state_dict = {"image_proj": {}, "ip_adapter": {}}
            with safe_open(self.ip_ckpt, framework="pt", device="cpu") as f:
                for key in f.keys():
                    if key.startswith("image_proj."):
                        state_dict["image_proj"][key.replace("image_proj.", "")] = f.get_tensor(key)
                    elif key.startswith("ip_adapter."):
                        state_dict["ip_adapter"][key.replace("ip_adapter.", "")] = f.get_tensor(key)
        else:
            state_dict = torch.load(self.ip_ckpt, map_location="cpu")
        self.image_proj_model.load_state_dict(state_dict["image_proj"])
        ip_layers = torch.nn.ModuleList(self.pipe.unet.attn_processors.values())
        ip_layers.load_state_dict(state_dict["ip_adapter"], strict=False)

    @torch.inference_mode()
    def get_image_embeds(self, pil_image=None, clip_image_embeds=None, content_prompt_embeds=None):
        if pil_image is not None:
            if isinstance(pil_image, Image.Image):
                pil_image = [pil_image]
            clip_image = self.clip_image_processor(images=pil_image, return_tensors="pt").pixel_values
            clip_image_embeds = self.image_encoder(clip_image.to(self.device, dtype=torch.float16)).image_embeds
        else:
            clip_image_embeds = clip_image_embeds.to(self.device, dtype=torch.float16)

        if content_prompt_embeds is not None:
            clip_image_embeds = clip_image_embeds - content_prompt_embeds

        image_prompt_embeds = self.image_proj_model(clip_image_embeds)
        # the unconditional branch uses a zero embedding of the same shape
        uncond_image_prompt_embeds = self.image_proj_model(torch.zeros_like(clip_image_embeds))
        return image_prompt_embeds, uncond_image_prompt_embeds

    def set_scale(self, scale):
        for attn_processor in self.pipe.unet.attn_processors.values():
            if isinstance(attn_processor, IPAttnProcessor):
                attn_processor.scale = scale

    def generate(
        self,
        pil_image=None,
        clip_image_embeds=None,
        prompt=None,
        negative_prompt=None,
        scale=1.0,
        num_samples=4,
        seed=None,
        guidance_scale=7.5,
        num_inference_steps=30,
        neg_content_emb=None,
        **kwargs,
    ):
        self.set_scale(scale)

        if pil_image is not None:
            num_prompts = 1 if isinstance(pil_image, Image.Image) else len(pil_image)
        else:
            num_prompts = clip_image_embeds.size(0)

        if prompt is None:
            prompt = "best quality, high quality"
        if negative_prompt is None:
            negative_prompt = "monochrome, lowres, bad anatomy, worst quality, low quality"

        if not isinstance(prompt, List):
            prompt = [prompt] * num_prompts
        if not isinstance(negative_prompt, List):
            negative_prompt = [negative_prompt] * num_prompts

        image_prompt_embeds, uncond_image_prompt_embeds = self.get_image_embeds(
            pil_image=pil_image, clip_image_embeds=clip_image_embeds, content_prompt_embeds=neg_content_emb
        )
        bs_embed, seq_len, _ = image_prompt_embeds.shape
        image_prompt_embeds = image_prompt_embeds.repeat(1, num_samples, 1)
        image_prompt_embeds = image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1)
        uncond_image_prompt_embeds = uncond_image_prompt_embeds.repeat(1, num_samples, 1)
        uncond_image_prompt_embeds = uncond_image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1)

        with torch.inference_mode():
            prompt_embeds_, negative_prompt_embeds_ = self.pipe.encode_prompt(
                prompt,
                device=self.device,
                num_images_per_prompt=num_samples,
                do_classifier_free_guidance=True,
                negative_prompt=negative_prompt,
            )
            prompt_embeds = torch.cat([prompt_embeds_, image_prompt_embeds], dim=1)
            negative_prompt_embeds = torch.cat([negative_prompt_embeds_, uncond_image_prompt_embeds], dim=1)

        generator = get_generator(seed, self.device)

        images = self.pipe(
            prompt_embeds=prompt_embeds,
            negative_prompt_embeds=negative_prompt_embeds,
            guidance_scale=guidance_scale,
            num_inference_steps=num_inference_steps,
            generator=generator,
            **kwargs,
        ).images

        return images


class IPAdapterXL(IPAdapter):
    """SDXL"""

    def generate(
        self,
        pil_image,
        prompt=None,
        negative_prompt=None,
        scale=1.0,
        num_samples=4,
        seed=None,
        num_inference_steps=30,
        neg_content_emb=None,
        neg_content_prompt=None,
        neg_content_scale=1.0,
        **kwargs,
    ):
        self.set_scale(scale)

        num_prompts = 1 if isinstance(pil_image, Image.Image) else len(pil_image)

        if prompt is None:
            prompt = "best quality, high quality"
        if negative_prompt is None:
            negative_prompt = "monochrome, lowres, bad anatomy, worst quality, low quality"

        if not isinstance(prompt, List):
            prompt = [prompt] * num_prompts
        if not isinstance(negative_prompt, List):
            negative_prompt = [negative_prompt] * num_prompts

        if neg_content_emb is None:
            if neg_content_prompt is not None:
                with torch.inference_mode():
                    (
                        prompt_embeds_,  # torch.Size([1, 77, 2048])
                        negative_prompt_embeds_,
                        pooled_prompt_embeds_,  # torch.Size([1, 1280])
                        negative_pooled_prompt_embeds_,
                    ) = self.pipe.encode_prompt(
                        neg_content_prompt,
                        num_images_per_prompt=num_samples,
                        do_classifier_free_guidance=True,
                        negative_prompt=negative_prompt,
                    )
                    pooled_prompt_embeds_ *= neg_content_scale
            else:
                pooled_prompt_embeds_ = neg_content_emb
        else:
            pooled_prompt_embeds_ = None

        image_prompt_embeds, uncond_image_prompt_embeds = self.get_image_embeds(pil_image, content_prompt_embeds=pooled_prompt_embeds_)
        bs_embed, seq_len, _ = image_prompt_embeds.shape
        image_prompt_embeds = image_prompt_embeds.repeat(1, num_samples, 1)
        image_prompt_embeds = image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1)
        uncond_image_prompt_embeds = uncond_image_prompt_embeds.repeat(1, num_samples, 1)
        uncond_image_prompt_embeds = uncond_image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1)

        with torch.inference_mode():
            (
                prompt_embeds,
                negative_prompt_embeds,
                pooled_prompt_embeds,
                negative_pooled_prompt_embeds,
            ) = self.pipe.encode_prompt(
                prompt,
                num_images_per_prompt=num_samples,
                do_classifier_free_guidance=True,
                negative_prompt=negative_prompt,
            )
            prompt_embeds = torch.cat([prompt_embeds, image_prompt_embeds], dim=1)
            negative_prompt_embeds = torch.cat([negative_prompt_embeds, uncond_image_prompt_embeds], dim=1)

        self.generator = get_generator(seed, self.device)

        images = self.pipe(
            prompt_embeds=prompt_embeds,
            negative_prompt_embeds=negative_prompt_embeds,
            pooled_prompt_embeds=pooled_prompt_embeds,
            negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
            num_inference_steps=num_inference_steps,
            generator=self.generator,
            **kwargs,
        ).images

        return images


class IPAdapterPlus(IPAdapter):
    """IP-Adapter with fine-grained features"""

    def init_proj(self):
        image_proj_model = Resampler(
            dim=self.pipe.unet.config.cross_attention_dim,
            depth=4,
            dim_head=64,
            heads=12,
            num_queries=self.num_tokens,
            embedding_dim=self.image_encoder.config.hidden_size,
            output_dim=self.pipe.unet.config.cross_attention_dim,
            ff_mult=4,
        ).to(self.device, dtype=torch.float16)
        return image_proj_model

    @torch.inference_mode()
    def get_image_embeds(self, pil_image=None, clip_image_embeds=None):
        if isinstance(pil_image, Image.Image):
            pil_image = [pil_image]
        clip_image = self.clip_image_processor(images=pil_image, return_tensors="pt").pixel_values
        clip_image = clip_image.to(self.device, dtype=torch.float16)
        # use the penultimate hidden states of the CLIP vision encoder as fine-grained image features
        clip_image_embeds = self.image_encoder(clip_image, output_hidden_states=True).hidden_states[-2]
        image_prompt_embeds = self.image_proj_model(clip_image_embeds)
        uncond_clip_image_embeds = self.image_encoder(
            torch.zeros_like(clip_image), output_hidden_states=True
        ).hidden_states[-2]
        uncond_image_prompt_embeds = self.image_proj_model(uncond_clip_image_embeds)
        return image_prompt_embeds, uncond_image_prompt_embeds


class IPAdapterFull(IPAdapterPlus):
    """IP-Adapter with full features"""

    def init_proj(self):
        image_proj_model = MLPProjModel(
            cross_attention_dim=self.pipe.unet.config.cross_attention_dim,
            clip_embeddings_dim=self.image_encoder.config.hidden_size,
        ).to(self.device, dtype=torch.float16)
        return image_proj_model


class IPAdapterPlusXL(IPAdapter):
    """SDXL"""

    def init_proj(self):
        image_proj_model = Resampler(
            dim=1280,
            depth=4,
            dim_head=64,
            heads=20,
            num_queries=self.num_tokens,
            embedding_dim=self.image_encoder.config.hidden_size,
            output_dim=self.pipe.unet.config.cross_attention_dim,
            ff_mult=4,
        ).to(self.device, dtype=torch.float16)
        return image_proj_model

    @torch.inference_mode()
    def get_image_embeds(self, pil_image):
        if isinstance(pil_image, Image.Image):
            pil_image = [pil_image]
        clip_image = self.clip_image_processor(images=pil_image, return_tensors="pt").pixel_values
        clip_image = clip_image.to(self.device, dtype=torch.float16)
        clip_image_embeds = self.image_encoder(clip_image, output_hidden_states=True).hidden_states[-2]
        image_prompt_embeds = self.image_proj_model(clip_image_embeds)
        uncond_clip_image_embeds = self.image_encoder(
            torch.zeros_like(clip_image), output_hidden_states=True
        ).hidden_states[-2]
        uncond_image_prompt_embeds = self.image_proj_model(uncond_clip_image_embeds)
        return image_prompt_embeds, uncond_image_prompt_embeds

    def generate(
        self,
        pil_image,
        prompt=None,
        negative_prompt=None,
        scale=1.0,
        num_samples=4,
        seed=None,
        num_inference_steps=30,
        **kwargs,
    ):
        self.set_scale(scale)

        num_prompts = 1 if isinstance(pil_image, Image.Image) else len(pil_image)

        if prompt is None:
            prompt = "best quality, high quality"
        if negative_prompt is None:
            negative_prompt = "monochrome, lowres, bad anatomy, worst quality, low quality"

        if not isinstance(prompt, List):
            prompt = [prompt] * num_prompts
        if not isinstance(negative_prompt, List):
            negative_prompt = [negative_prompt] * num_prompts

        image_prompt_embeds, uncond_image_prompt_embeds = self.get_image_embeds(pil_image)
        bs_embed, seq_len, _ = image_prompt_embeds.shape
        image_prompt_embeds = image_prompt_embeds.repeat(1, num_samples, 1)
        image_prompt_embeds = image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1)
        uncond_image_prompt_embeds = uncond_image_prompt_embeds.repeat(1, num_samples, 1)
        uncond_image_prompt_embeds = uncond_image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1)

        with torch.inference_mode():
            (
                prompt_embeds,
                negative_prompt_embeds,
                pooled_prompt_embeds,
                negative_pooled_prompt_embeds,
            ) = self.pipe.encode_prompt(
                prompt,
                num_images_per_prompt=num_samples,
                do_classifier_free_guidance=True,
                negative_prompt=negative_prompt,
            )
            prompt_embeds = torch.cat([prompt_embeds, image_prompt_embeds], dim=1)
            negative_prompt_embeds = torch.cat([negative_prompt_embeds, uncond_image_prompt_embeds], dim=1)

        generator = get_generator(seed, self.device)

        images = self.pipe(
            prompt_embeds=prompt_embeds,
            negative_prompt_embeds=negative_prompt_embeds,
            pooled_prompt_embeds=pooled_prompt_embeds,
            negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
            num_inference_steps=num_inference_steps,
            generator=generator,
            **kwargs,
        ).images

        return images
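For reference, a minimal usage sketch of the IPAdapter class added in this commit. The pipeline id, image-encoder directory, checkpoint filename, and reference image below are placeholders (assumptions, not files provided by this commit); the generate() arguments match the signature defined above.

import torch
from diffusers import StableDiffusionPipeline
from PIL import Image

from ip_adapter.ip_adapter import IPAdapter

# Placeholders: swap in the base model, CLIP image encoder, and IP-Adapter weights you actually use.
pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
)
ip_model = IPAdapter(
    pipe,
    image_encoder_path="path/to/image_encoder",
    ip_ckpt="path/to/ip-adapter_sd15.bin",
    device="cuda",
)

images = ip_model.generate(
    pil_image=Image.open("reference.png"),  # image prompt
    prompt="best quality, high quality",
    scale=0.8,                              # weight of the image prompt in cross-attention
    num_samples=1,
    num_inference_steps=30,
    seed=42,
)
images[0].save("output.png")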