# MoMA_demo / model_lib / modules.py
# Kunpeng Song
# bg
# ef3a17c
import os
from PIL import Image
import torch
import torch.nn as nn
from typing import List, Optional
import torch.utils.checkpoint
from torchvision.transforms import ToPILImage
from model_lib.moMA_generator import MoMA_generator
from transformers.activations import ACT2FN
from huggingface_hub import hf_hub_download
from dataset_lib.dataset_eval_MoMA import Dataset_evaluate_MoMA
from llava.model.builder import load_pretrained_model
from llava.mm_utils import tokenizer_image_token, get_model_name_from_path
from llava.constants import IMAGE_TOKEN_INDEX
def add_function(model):
    """Attach a ``my_llava_forward`` helper to ``model``.

    The helper is stored as a plain attribute on the instance (it is not
    bound), so callers must pass the model explicitly as the first argument:
    ``model.my_llava_forward(model, input_ids=..., images=...)``.

    Args:
        model: Object exposing ``prepare_inputs_labels_for_multimodal`` and a
            callable ``model`` attribute (the underlying decoder).

    Returns:
        None. ``model`` is modified in place.
    """

    def my_llava_forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        images: Optional[torch.FloatTensor] = None,
        return_dict: Optional[bool] = None,
    ):
        """Run the decoder on merged text/image inputs and return its first output.

        Only ``input_ids``, ``attention_mask``, ``position_ids`` and ``images``
        are used. The remaining keyword arguments are accepted but ignored;
        the decoder is always called with the fixed flags set below.
        """
        # The multimodal preparation step returns a 6-tuple; only the position
        # ids, attention mask and input embeddings are forwarded to the decoder.
        prepared = self.prepare_inputs_labels_for_multimodal(
            input_ids, position_ids, attention_mask, None, None, images
        )
        _, merged_position_ids, merged_attention_mask, _, merged_embeds, _ = prepared

        decoder_kwargs = dict(
            input_ids=None,
            attention_mask=merged_attention_mask,
            position_ids=merged_position_ids,
            past_key_values=None,
            inputs_embeds=merged_embeds,
            use_cache=True,
            output_attentions=False,
            output_hidden_states=False,
            return_dict=True,
        )
        decoder_outputs = self.model(**decoder_kwargs)
        # First element of the decoder output (presumably the last hidden
        # state -- confirm against the decoder's return type).
        return decoder_outputs[0]

    model.my_llava_forward = my_llava_forward
class LlamaMLP_mapping(nn.Module):
    """Gated MLP projecting features from ``hidden_size`` to ``hidden_size_out``.

    Computes ``down_proj(act_fn(gate_proj(x)) * up_proj(x))`` with bias-free
    linear layers and a SiLU gate activation. All three weight matrices are
    Xavier-normal initialised and then scaled down by a factor of 4.
    """

    def __init__(self, hidden_size, hidden_size_out):
        super().__init__()
        self.hidden_size = hidden_size
        self.hidden_size_out = hidden_size_out

        self.gate_proj = nn.Linear(hidden_size, hidden_size_out, bias=False)
        self.up_proj = nn.Linear(hidden_size, hidden_size_out, bias=False)
        self.down_proj = nn.Linear(hidden_size_out, hidden_size_out, bias=False)

        self.act_fn = ACT2FN["silu"]
        # Stored on the module, but forward() never applies it.
        self.act_fn_output = ACT2FN["tanh"]

        self.init_linear()

    def forward(self, x):
        """Return the gated projection of ``x`` (last dim ``hidden_size_out``)."""
        gate = self.act_fn(self.gate_proj(x))
        lifted = self.up_proj(x)
        return self.down_proj(gate * lifted)

    def init_linear(self):
        """Re-initialise the three projections: Xavier-normal, then divide by 4."""
        # Order matters for RNG reproducibility: gate, up, down.
        for layer in (self.gate_proj, self.up_proj, self.down_proj):
            torch.nn.init.xavier_normal_(layer.weight)
            layer.weight.data = layer.weight.data / 4.0
class MoMA_main_modal(nn.Module):
    """MoMA pipeline wrapper: a LLaVA multi-modal LLM feeding a diffusion generator.

    On construction this builds the ``MoMA_generator`` (exposing its UNet and
    VAE), loads the LLaVA model named by ``args.model_path``, creates the
    4096 -> 1024 ``LlamaMLP_mapping`` projector in float16, loads the saved
    adapter/projector weights and freezes every sub-module for inference.

    Attributes read from ``args``: ``device``, ``model_path``, ``load_8bit``,
    ``load_4bit`` and ``load_attn_adapters`` (path of the adapter checkpoint).
    """

    def __init__(self, args):
        super().__init__()
        self.args = args
        self.device = args.device

        self.moMA_generator = MoMA_generator(self.device, args)
        self.unet = self.moMA_generator.pipe.unet
        self.vae = self.moMA_generator.pipe.vae

        print('Loading MoMA: its Multi-modal LLM...')
        model_name = get_model_name_from_path(args.model_path)
        (
            self.tokenizer_llava,
            self.model_llava,
            self.image_processor_llava,
            self.context_len_llava,
        ) = load_pretrained_model(
            args.model_path,
            None,
            model_name,
            load_8bit=self.args.load_8bit,
            load_4bit=self.args.load_4bit,
            device=args.device,
        )

        add_function(self.model_llava)

        self.mapping = LlamaMLP_mapping(4096, 1024).to(self.device, dtype=torch.float16)

        self.load_saved_components()
        self.freeze_modules()

    def load_saved_components(self):
        """Load adapter, projector and LLM-mapping weights from the checkpoint.

        If ``args.load_attn_adapters`` does not exist on disk, the file
        ``attn_adapters_projectors.th`` is first downloaded from the hub repo
        ``args.model_path`` into the checkpoint's directory.
        """
        checkpoint_path = self.args.load_attn_adapters
        if not os.path.exists(checkpoint_path):
            print('Loading Attentions and LLM mappings...')
            hf_hub_download(
                repo_id=self.args.model_path,
                filename="attn_adapters_projectors.th",
                # A bare file name has no directory part; fall back to the
                # current directory so the download lands next to it.
                local_dir=os.path.dirname(checkpoint_path) or ".",
            )

        # NOTE(security): torch.load unpickles the file; only load checkpoints
        # from a trusted source.
        state_dict = torch.load(checkpoint_path, map_location="cpu")

        # Image projectors of the generator.
        self.moMA_generator.image_proj_model.load_state_dict(state_dict["projectors"])

        # Attention adapters (self/cross attention processors of the UNet).
        attn_layers = torch.nn.ModuleList(self.unet.attn_processors.values())
        attn_layers.load_state_dict(state_dict["self_cross_attentions"], strict=False)

        # LLM projector weights; strict=False because the checkpoint entry
        # does not cover every parameter of this module.
        self.load_state_dict(state_dict['llm_mapping'], strict=False)

    def freeze_modules(self):
        """Put every sub-module in eval mode and disable gradients on it."""
        all_modules = [
            self.moMA_generator.pipe.vae,
            self.moMA_generator.pipe.text_encoder,
            self.unet,
            self.model_llava,
            self.mapping,
        ]
        for module in all_modules:
            # Use eval() rather than assigning to ``module.train``: the
            # assignment would replace the nn.Module.train method with a bool,
            # leaving the mode unchanged and breaking later train()/eval().
            module.eval()
            module.requires_grad_(False)

    @staticmethod
    def _left_pad(sequences, padding_value):
        """Stack 1-D tensors into one batch-first tensor, padding on the left."""
        # pad_sequence pads on the right, so flip, pad, then flip back.
        flipped = [seq.flip(dims=[-1]) for seq in sequences]
        padded = torch.nn.utils.rnn.pad_sequence(
            flipped, batch_first=True, padding_value=padding_value
        )
        return padded.flip(dims=[1])

    def forward_MLLM(self, batch):
        """Encode each (image, subject, prompt) triple with LLaVA and project it.

        Args:
            batch: Mapping with keys ``'llava_processed'`` (image tensor for
                LLaVA), ``'label'`` (iterable of subject names) and ``'text'``
                (iterable of prompts).

        Returns:
            The ``mapping`` projection of the LLaVA output at the last
            sequence position, i.e. ``output[:, -1, :]``.
        """
        llava_processeds = batch['llava_processed'].half().to(self.device)
        subjects, prompts = batch['label'], batch['text']

        input_ids, attention_masks, position_ids = [], [], []
        for subject, prompt in zip(subjects, prompts):
            prompt_construct = f"USER: <image>\n A photo of a {subject}. Describe a new image of the same {subject} in: {prompt}. ASSISTANT: *"
            token_ids = tokenizer_image_token(
                prompt_construct, self.tokenizer_llava, IMAGE_TOKEN_INDEX, return_tensors='pt'
            ).to(self.device)
            input_ids.append(token_ids)
            attention_masks.append(torch.ones_like(token_ids, dtype=torch.long))
            position_ids.append(torch.arange(token_ids.shape[-1], device=self.device))

        # Left-pad so the final position of every row is a real token.
        input_ids = self._left_pad(input_ids, self.tokenizer_llava.pad_token_id)
        # Padding must be masked out (0) whatever the tokenizer's pad id is;
        # padding the mask with a non-zero pad id would mark pads as attended.
        attention_masks = self._left_pad(attention_masks, 0)
        position_ids = self._left_pad(position_ids, 0)

        # my_llava_forward is attached unbound, so the model is passed explicitly.
        output = self.model_llava.my_llava_forward(
            self.model_llava,
            input_ids=input_ids,
            attention_mask=attention_masks,
            position_ids=position_ids,
            images=llava_processeds,
        )
        output = self.mapping(output)
        return output[:, -1, :]

    def reset(self):
        """Reset the generator's state via ``moMA_generator.reset_all()``."""
        self.moMA_generator.reset_all()

    def generate_images(self, rgb_path, subject, prompt, strength=1.0, num=1, seed=0):
        """Generate one image of ``subject`` from a reference image and a prompt.

        Args:
            rgb_path: Reference image location, passed to ``Dataset_evaluate_MoMA``.
            subject: Subject name inserted into the LLaVA prompt.
            prompt: Text description of the new image.
            strength: Self-attention strength forwarded to the generator.
            num: Currently unused.
            seed: Seed forwarded to the generator.

        Returns:
            The first generated image converted with ``ToPILImage``.
        """
        batch = Dataset_evaluate_MoMA(rgb_path, prompt, subject, self)
        self.moMA_generator.set_selfAttn_strength(strength)

        with torch.cuda.amp.autocast(enabled=True, dtype=torch.float16, cache_enabled=True):
            try:
                with torch.no_grad():
                    llava_emb = self.forward_MLLM(batch).clone().detach()
                    img, _mask = self.moMA_generator.generate_with_MoMA(
                        batch, llava_emb=llava_emb, seed=seed, device=self.args.device
                    )
            finally:
                # Always reset, even if generation failed, so a failed call
                # does not leave generator state behind for the next one.
                self.reset()

        return ToPILImage()(img[0])