File size: 7,631 Bytes
6eca12e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3308ae3
6eca12e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3308ae3
6eca12e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
import os, sys
from typing import List

import torch
from diffusers import StableDiffusionPipeline
from diffusers.pipelines.controlnet import MultiControlNetModel
from diffusers.models.embeddings import ImageProjection
from transformers import CLIPVisionModelWithProjection, CLIPImageProcessor
from PIL import Image
from typing import Any, Callable, Dict, List, Optional, Union
from copy import deepcopy
import time
sys.path.append(os.path.dirname(__file__))
from resampler import Resampler
from diffusers import DiffusionPipeline
import numpy as np
# sys.path.append(os.environ['LYRASD_WORKDIR'] + "/tests/utils")
from .tools import get_mem_use

class ImageProjModel(torch.nn.Module):
    """Projection Model"""
    def __init__(self, cross_attention_dim=1024, clip_embeddings_dim=1024, clip_extra_context_tokens=4):
        super().__init__()
        
        self.cross_attention_dim = cross_attention_dim
        self.clip_extra_context_tokens = clip_extra_context_tokens
        self.proj = torch.nn.Linear(clip_embeddings_dim, self.clip_extra_context_tokens * cross_attention_dim)
        self.norm = torch.nn.LayerNorm(cross_attention_dim)
        
    def forward(self, image_embeds):
        embeds = image_embeds
        clip_extra_context_tokens = self.proj(embeds).reshape(-1, self.clip_extra_context_tokens, self.cross_attention_dim)
        clip_extra_context_tokens = self.norm(clip_extra_context_tokens)
        return clip_extra_context_tokens


class LyraIPAdapter:
    def __init__(
            self,
            sd_pipe,
            sdxl,
            device,
            ip_ckpt=None,
            ip_plus=False,
            image_encoder_path=None,
            num_ip_tokens=4,
            ip_projection_dim=None,
        ):
        self.pipe = sd_pipe
        self.device = device
        self.ip_ckpt = ip_ckpt
        self.num_ip_tokens = num_ip_tokens
        self.ip_projection_dim = ip_projection_dim
        self.sdxl = sdxl
        self.ip_plus = ip_plus
        self.cross_attention_dim = 2048
        # self.pipe = sd_pipe.to(self.device)
        # self.set_ip_adapter()
        
        if image_encoder_path:
            self.image_encoder = CLIPVisionModelWithProjection.from_pretrained(image_encoder_path).to(self.device, dtype=torch.float16)
            self.clip_image_processor = CLIPImageProcessor()
            self.projection_dim = self.image_encoder.config.projection_dim

        # image proj model
        if self.ip_ckpt:
            if self.ip_plus:
                proj_heads = 20 if self.sdxl else 12
                self.image_proj_model = self.init_proj_plus(proj_heads, self.num_ip_tokens)
            else:
                self.image_proj_model = self.init_proj(self.ip_projection_dim, self.num_ip_tokens)

        self.load_ip_adapter()
        
    def init_proj_diffuser(self, state_dict):
        # diffusers加载版本
        clip_embeddings_dim = state_dict["image_proj"]["proj.weight"].shape[-1]
        cross_attention_dim = state_dict["image_proj"]["proj.weight"].shape[0] // 4

        image_proj_model = ImageProjection(
            cross_attention_dim=cross_attention_dim, image_embed_dim=clip_embeddings_dim, num_image_text_embeds=4
        ).to(dtype=self.dtype, device=self.device)
        return image_proj_model

    # init_proj / init_proj_plus 是 facein里实现的
    def init_proj(self, projection_dim, num_tokens):
        image_proj_model = ImageProjModel(
            cross_attention_dim=self.cross_attention_dim,
            clip_embeddings_dim=projection_dim,
            clip_extra_context_tokens=num_tokens,
        ).to(self.device, dtype=torch.float16)
        return image_proj_model


    def init_proj_plus(self, heads, num_tokens):
        image_proj_model = Resampler(
            dim=1280,
            depth=4,
            dim_head=64,
            heads=heads,
            num_queries=num_tokens,
            embedding_dim=self.image_encoder.config.hidden_size,
            output_dim=self.cross_attention_dim,
            ff_mult=4,
        ).to(self.device, dtype=torch.float16)
        return image_proj_model
        
    def load_ip_adapter(self):
        unet = self.pipe.unet

        def parse_ckpt_path(ckpt):
            ll = ckpt.split("/")
            weight_name = ll[-1]
            subfolder = ll[-2]
            pretrained_path = "/".join(ll[:-2])
            return pretrained_path, subfolder, weight_name

        if self.ip_ckpt:
            state_dict = torch.load(self.ip_ckpt, map_location="cpu")
            self.image_proj_model.load_state_dict(state_dict["image_proj"])
            pretrained_path, subfolder, weight_name = parse_ckpt_path(self.ip_ckpt)
            dir_ipadapter = os.path.join(pretrained_path, "lyra_tran", subfolder, '.'.join(weight_name.split(".")[:-1]))
            unet.load_ip_adapter(dir_ipadapter, "", 1, "fp16")
        
    @torch.inference_mode()
    def get_image_embeds(self, image=None):
        image_prompt_embeds, uncond_image_prompt_embeds = None, None

        if image is not None:
            if not isinstance(image, list):
                image = [image]
            clip_image = self.clip_image_processor(images=image, return_tensors="pt").pixel_values
            clip_image = clip_image.to(self.device, dtype=torch.float16)
            if self.ip_plus:
                clip_image_embeds = self.image_encoder(clip_image, output_hidden_states=True).hidden_states[-2]
                uncond_clip_image_embeds = self.image_encoder(
                    torch.zeros_like(clip_image), output_hidden_states=True
                ).hidden_states[-2]
            else:
                clip_image_embeds = self.image_encoder(clip_image).image_embeds
                uncond_clip_image_embeds = torch.zeros_like(clip_image_embeds)
            clip_image_prompt_embeds = self.image_proj_model(clip_image_embeds)
            uncond_clip_image_prompt_embeds = self.image_proj_model(uncond_clip_image_embeds)
            image_prompt_embeds = clip_image_prompt_embeds
            uncond_image_prompt_embeds = uncond_clip_image_prompt_embeds

        return image_prompt_embeds, uncond_image_prompt_embeds
    
    @torch.inference_mode()
    def get_image_embeds_lyrasd(self, image=None, ip_image_embeds=None, batch_size = 1, ip_scale=1.0, do_classifier_free_guidance=True):
        dict_tensor = {}

        if self.ip_ckpt and ip_scale>0:
            if ip_image_embeds is not None:
                dict_tensor["ip_hidden_states"] = ip_image_embeds
            elif image is not None:
                if not isinstance(image, list):
                    image = [image]
                clip_image = self.clip_image_processor(images=image, return_tensors="pt").pixel_values
                clip_image = clip_image.to(self.device, dtype=torch.float16)
                if self.ip_plus:
                    clip_image_embeds = self.image_encoder(clip_image, output_hidden_states=True).hidden_states[-2]
                    uncond_clip_image_embeds = self.image_encoder(
                        torch.zeros_like(clip_image), output_hidden_states=True
                    ).hidden_states[-2]
                else:
                    clip_image_embeds = self.image_encoder(clip_image).image_embeds
                    uncond_clip_image_embeds = torch.zeros_like(clip_image_embeds)

                if do_classifier_free_guidance:
                    clip_image_embeds = torch.cat([uncond_clip_image_embeds, clip_image_embeds])
                ip_image_embeds = self.image_proj_model(clip_image_embeds)
                dict_tensor["ip_hidden_states"] = ip_image_embeds
        return dict_tensor