File size: 7,354 Bytes
8a4a948
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
338f71e
8a4a948
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ef3a17c
 
8a4a948
 
338f71e
8a4a948
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
import os
from PIL import Image
import torch
import torch.nn as nn
from typing import List, Optional
import torch.utils.checkpoint
from torchvision.transforms import ToPILImage
from model_lib.moMA_generator import MoMA_generator
from transformers.activations import ACT2FN
from huggingface_hub import hf_hub_download

from dataset_lib.dataset_eval_MoMA import Dataset_evaluate_MoMA

from llava.model.builder import load_pretrained_model
from llava.mm_utils import tokenizer_image_token, get_model_name_from_path
from llava.constants import IMAGE_TOKEN_INDEX

def add_function(model):
    """Attach a hidden-state-only forward helper to ``model``.

    The helper is stored as the plain attribute ``model.my_llava_forward``.
    It is assigned on the instance, so it is NOT bound: callers have to pass
    the model explicitly as the first argument.
    """
    def my_llava_forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        images: Optional[torch.FloatTensor] = None,
        return_dict: Optional[bool] = None,
    ):
        """Return the first output of ``self.model`` for text + image inputs.

        Only ``input_ids``, ``attention_mask``, ``position_ids`` and ``images``
        are consumed; every other parameter is accepted but ignored.
        """
        # Let the model merge the image features into the token sequence.
        # Of the six returned values only the position ids, attention mask
        # and input embeddings are used.
        prepared = self.prepare_inputs_labels_for_multimodal(
            input_ids, position_ids, attention_mask, None, None, images
        )
        (_, merged_position_ids, merged_attention_mask, _, merged_embeds, _) = prepared

        # The backbone is driven purely by embeddings, hence input_ids=None.
        backbone_kwargs = dict(
            input_ids=None,
            attention_mask=merged_attention_mask,
            position_ids=merged_position_ids,
            past_key_values=None,
            inputs_embeds=merged_embeds,
            use_cache=True,
            output_attentions=False,
            output_hidden_states=False,
            return_dict=True,
        )
        backbone_out = self.model(**backbone_kwargs)
        return backbone_out[0]

    model.my_llava_forward = my_llava_forward


class LlamaMLP_mapping(nn.Module):
    """Bias-free gated MLP projecting ``hidden_size`` features to ``hidden_size_out``.

    Computes ``down_proj(act_fn(gate_proj(x)) * up_proj(x))`` with a SiLU gate.
    """

    def __init__(self, hidden_size,hidden_size_out):
        super().__init__()
        self.hidden_size = hidden_size
        self.hidden_size_out = hidden_size_out

        # Both input projections map hidden_size -> hidden_size_out; the output
        # projection keeps the hidden_size_out width.
        self.gate_proj = nn.Linear(hidden_size, hidden_size_out, bias=False)
        self.up_proj = nn.Linear(hidden_size, hidden_size_out, bias=False)
        self.down_proj = nn.Linear(hidden_size_out, hidden_size_out, bias=False)

        self.act_fn = ACT2FN["silu"]
        # Stored on the module but not applied anywhere in forward().
        self.act_fn_output = ACT2FN["tanh"]

        self.init_linear()

    def forward(self, x):
        """Apply the gated projection to ``x`` and return the result."""
        gate = self.act_fn(self.gate_proj(x))
        value = self.up_proj(x)
        return self.down_proj(gate * value)

    def init_linear(self):
        """Re-initialise all three projections: Xavier-normal, then divided by 4."""
        for layer in (self.gate_proj, self.up_proj, self.down_proj):
            torch.nn.init.xavier_normal_(layer.weight)
            layer.weight.data = layer.weight.data / 4.0

class MoMA_main_modal(nn.Module):
    """Inference wrapper combining a LLaVA multimodal LLM with the MoMA generator.

    ``args`` must provide ``device``, ``model_path``, ``load_8bit``,
    ``load_4bit`` and ``load_attn_adapters`` (path of the adapter checkpoint).
    All wrapped sub-modules are frozen and switched to eval mode after loading.
    """

    def __init__(self,args):
        super().__init__()
        self.args = args
        self.device = args.device

        self.moMA_generator = MoMA_generator(self.device,args)
        self.unet = self.moMA_generator.pipe.unet
        self.vae = self.moMA_generator.pipe.vae
        
        print('Loading MoMA: its Multi-modal LLM...')
        model_name = get_model_name_from_path(args.model_path)
        self.tokenizer_llava, self.model_llava, self.image_processor_llava, self.context_len_llava = load_pretrained_model(args.model_path, None, model_name, load_8bit=self.args.load_8bit, load_4bit=self.args.load_4bit, device=args.device)
        
        # Adds the unbound helper `my_llava_forward` used by forward_MLLM.
        add_function(self.model_llava)

        # Projects 4096-d LLM outputs to 1024-d
        # (presumably the LLaVA hidden size and the generator's embedding width - TODO confirm).
        self.mapping = LlamaMLP_mapping(4096,1024).to(self.device, dtype=torch.float16)
        self.load_saved_components()
        self.freeze_modules()

    @staticmethod
    def _left_pad(sequences, padding_value):
        """Left-pad 1-D tensors to a common length and stack them batch-first."""
        # pad_sequence pads on the right, so reverse every sequence, pad, and
        # reverse the padded batch back.
        reversed_seqs = [seq.flip(dims=[-1]) for seq in sequences]
        padded = torch.nn.utils.rnn.pad_sequence(reversed_seqs, batch_first=True, padding_value=padding_value)
        return padded.flip(dims=[1])

    def load_saved_components(self):
        """Load projector, attention-processor and LLM-mapping weights.

        If ``args.load_attn_adapters`` does not exist, the file
        ``attn_adapters_projectors.th`` is first downloaded from the hub repo
        ``args.model_path`` into the directory part of that path.
        """
        if not os.path.exists(self.args.load_attn_adapters):
            print('Loading Attentions and LLM mappings...')
            hf_hub_download(repo_id=self.args.model_path, filename="attn_adapters_projectors.th",local_dir=os.path.dirname(self.args.load_attn_adapters))

        #load attention adapters and self cross attentions
        # NOTE(review): torch.load unpickles the file; only use checkpoints from a trusted source.
        state_dict = torch.load(self.args.load_attn_adapters, map_location="cpu")
        self.moMA_generator.image_proj_model.load_state_dict(state_dict["projectors"])
        attn_layers = torch.nn.ModuleList(self.unet.attn_processors.values())
        attn_layers.load_state_dict(state_dict["self_cross_attentions"],strict=False)

        #load LLM projectors
        # strict=False: the checkpoint entry is loaded against this whole module,
        # so keys absent from it are tolerated.
        self.load_state_dict(state_dict['llm_mapping'],strict=False)

    def freeze_modules(self): 
        """Put every wrapped sub-module in eval mode and disable its gradients."""
        all_modules = [self.moMA_generator.pipe.vae,self.moMA_generator.pipe.text_encoder,self.unet,self.model_llava,self.mapping]
        for module in all_modules:
            # eval() flips the `training` flag; assigning to `module.train`
            # would replace the method and break later train()/eval() calls.
            module.eval()
            module.requires_grad_(False)

    def forward_MLLM(self,batch):
        """Return one mapped LLM embedding per sample in ``batch``.

        Reads ``batch['llava_processed']`` (image tensor for LLaVA),
        ``batch['label']`` (subject names) and ``batch['text']`` (prompts).
        The result is the mapped output at the last sequence position.
        """
        llava_processeds = batch['llava_processed'].half().to(self.device)
        subjects,prompts = batch['label'],batch['text']
        
        input_ids,attention_masks,position_ids = [],[],[]
        for subject,prompt in zip(subjects,prompts):
            prompt_construct = f"USER: <image>\n A photo of a {subject}. Describe a new image of the same {subject} in: {prompt}. ASSISTANT: *" 
            input_id = tokenizer_image_token(prompt_construct, self.tokenizer_llava, IMAGE_TOKEN_INDEX, return_tensors='pt').to(self.device)
            
            input_ids.append(input_id)
            attention_masks.append(torch.ones(input_id.shape, dtype=torch.long, device=self.device))
            position_ids.append(torch.arange(input_id.shape[-1], device=self.device))
        
        # Any id works for padded token slots because the attention mask
        # built below is 0 there; fall back to 0 if the tokenizer has no pad token.
        pad_token_id = self.tokenizer_llava.pad_token_id
        if pad_token_id is None:
            pad_token_id = 0

        # Left padding keeps the final real token at index -1 for every sample.
        # Masks and positions are padded with 0, never with the pad token id.
        input_ids = self._left_pad(input_ids, pad_token_id)
        position_ids = self._left_pad(position_ids, 0)
        attention_masks = self._left_pad(attention_masks, 0)
        
        # my_llava_forward is a plain (unbound) attribute, so the model is passed explicitly.
        output = self.model_llava.my_llava_forward(self.model_llava,input_ids=input_ids,attention_mask=attention_masks,position_ids=position_ids,images=llava_processeds)
        output = self.mapping(output)
        return output[:,-1,:]

    def reset(self):
        """Reset the generator state via ``moMA_generator.reset_all()``."""
        self.moMA_generator.reset_all()

    def generate_images(self, rgb_path, subject, prompt, strength=1.0, num=1, seed=0):
        """Generate an image for ``subject`` from a reference image and a prompt.

        Args:
            rgb_path: reference image path, forwarded to ``Dataset_evaluate_MoMA``.
            subject: subject name inserted into the LLM prompt.
            prompt: text describing the new image.
            strength: value passed to ``set_selfAttn_strength``.
            num: currently unused.
            seed: seed forwarded to the generator.

        Returns:
            A PIL image built from the first generated image.
        """
        batch = Dataset_evaluate_MoMA(rgb_path, prompt, subject,self)
        self.moMA_generator.set_selfAttn_strength(strength)
        
        with torch.cuda.amp.autocast(enabled=True, dtype=torch.float16, cache_enabled=True):
            with torch.no_grad(): 
                ### key steps
                llava_emb = self.forward_MLLM(batch).clone().detach()
                img,_mask = self.moMA_generator.generate_with_MoMA(batch,llava_emb=llava_emb,seed=seed,device=self.args.device)                            
                self.reset()
        
        result = ToPILImage()(img[0])
        return result