from typing import Callable, List, Optional, Tuple, Union
import json
import glob
import math
import numpy as np
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import transformers
from transformers import LlamaForCausalLM, CLIPVisionModel, BitsAndBytesConfig
from peft import (
    LoraConfig,
    get_peft_model,
    get_peft_model_state_dict,
    prepare_model_for_int8_training,
    set_peft_model_state_dict,
)
from .llava.model.llava import LlavaLlamaForCausalLM
from .segment_anything import build_sam_vit_l, build_sam_vit_h
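
# Special tokens used by LLaVA's multimodal prompt format.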
DEFAULT_IMAGE_TOKEN = "<image>"
DEFAULT_IMAGE_PATCH_TOKEN = "<im_patch>"
DEFAULT_IM_START_TOKEN = "<im_start>"
DEFAULT_IM_END_TOKEN = "<im_end>"
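

# List the nn.Linear module names in a model (e.g. to pick LoRA target modules);
# lm_head and mm_projector are deliberately excluded.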
def find_all_linear_names(model):
    cls = torch.nn.Linear
    lora_module_names = set()
    for name, module in model.named_modules():
        if isinstance(module, cls):
            names = name.split('.')
            lora_module_names.add(names[0] if len(names) == 1 else names[-1])

    if 'lm_head' in lora_module_names:  # needed for 16-bit
        lora_module_names.remove('lm_head')
    if 'mm_projector' in lora_module_names:
        lora_module_names.remove('mm_projector')
    return sorted(list(lora_module_names))


class LISA(nn.Module):
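    """LISA: a LLaVA-based multimodal language model coupled with a SAM mask decoder.

    Hidden states at the segmentation token positions are projected by `text_hidden_fcs`
    and fed to SAM's prompt encoder as text prompts; the mask decoder then predicts
    binary masks for the referred regions.
    """
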
    def __init__(
        self,
        local_rank,
        seg_token_idx,
        tokenizer,
        llm_version,
        lora_r,
        precision,
        load_in_4bit=False,
        load_in_8bit=False,
        lora_target_modules=['q_proj', 'v_proj'],
        lora_alpha=16,
        lora_dropout=0.05,
        vision_tower='openai/clip-vit-large-patch14',
        mm_vision_select_layer=-2,
        freeze_lm=True,
        train_mask_decoder=True,
        out_dim=256,
    ):
        super().__init__()
        self.tokenizer = tokenizer
        self.image_token = tokenizer.cls_token_id
        self.precision = precision

        # LLaVA
        tokenizer.add_tokens([DEFAULT_IMAGE_PATCH_TOKEN], special_tokens=True)
        num_new_tokens = tokenizer.add_tokens([DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True)
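
        # Load the LLaVA language model in the requested precision; the 4-bit / 8-bit
        # paths rely on bitsandbytes quantization and automatic device placement.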
        if precision == "bf16":
            self.lm = LlavaLlamaForCausalLM.from_pretrained(llm_version, torch_dtype=torch.bfloat16, cache_dir=None, low_cpu_mem_usage=True)
        elif precision == "fp16":
            if load_in_4bit:
                self.lm = LlavaLlamaForCausalLM.from_pretrained(
                    llm_version,
                    load_in_4bit=True,
                    cache_dir=None,
                    low_cpu_mem_usage=True,
                    device_map='auto',
                    quantization_config=BitsAndBytesConfig(
                        load_in_4bit=True,
                        bnb_4bit_compute_dtype=torch.float16,
                        bnb_4bit_use_double_quant=True,
                        bnb_4bit_quant_type='nf4',
                    ),
                )
            elif load_in_8bit:
                self.lm = LlavaLlamaForCausalLM.from_pretrained(llm_version, load_in_8bit=True, cache_dir=None, low_cpu_mem_usage=True, device_map='auto')
            else:
                self.lm = LlavaLlamaForCausalLM.from_pretrained(llm_version, torch_dtype=torch.half, cache_dir=None, low_cpu_mem_usage=True)
        else:
            self.lm = LlavaLlamaForCausalLM.from_pretrained(llm_version, torch_dtype=torch.float32, cache_dir=None, low_cpu_mem_usage=True)

        self.lm.enable_input_require_grads()
        self.lm.gradient_checkpointing_enable()
        self.lm.config.use_cache = False
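
        # Initialize LLaVA's CLIP vision modules and move the vision tower onto the GPU
        # in the matching dtype (it may still sit on the 'meta' device after loading).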
        model_vision_dict = self.lm.get_model().initialize_vision_modules(vision_tower=vision_tower, mm_vision_select_layer=mm_vision_select_layer, precision=precision)
        vision_config = model_vision_dict['vision_config']
        vision_tower = self.lm.get_model().vision_tower[0]

        self.lm.model.config.eos_token_id = tokenizer.eos_token_id
        self.lm.model.config.bos_token_id = tokenizer.bos_token_id
        self.lm.model.config.pad_token_id = tokenizer.pad_token_id

        if vision_tower.device.type == 'meta':
            if precision == 'bf16':
                vision_tower = CLIPVisionModel.from_pretrained(vision_tower.config._name_or_path, torch_dtype=torch.bfloat16, low_cpu_mem_usage=True).cuda(local_rank)
            elif precision == 'fp16':
                vision_tower = CLIPVisionModel.from_pretrained(vision_tower.config._name_or_path, torch_dtype=torch.half, low_cpu_mem_usage=True).cuda(local_rank)
            else:
                vision_tower = CLIPVisionModel.from_pretrained(vision_tower.config._name_or_path, torch_dtype=torch.float32, low_cpu_mem_usage=True).cuda(local_rank)
            self.lm.get_model().vision_tower[0] = vision_tower
        else:
            if precision == "bf16":
                vision_tower.to(device='cuda', dtype=torch.bfloat16)
            elif precision == "fp16":
                vision_tower.to(device='cuda', dtype=torch.half)
            else:
                vision_tower.to(device='cuda', dtype=torch.float32)

        self.lm.config.tune_mm_mlp_adapter = False
        self.lm.config.freeze_mm_mlp_adapter = False
        self.lm.config.mm_use_im_start_end = True
        vision_config.use_im_start_end = True
        self.lm.config.sep_image_conv_front = False
        self.lm.initialize_vision_tokenizer(mm_use_im_start_end=True, tokenizer=tokenizer, num_new_tokens=num_new_tokens, device=local_rank, tune_mm_mlp_adapter=False)
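
        # Optionally freeze the whole language model, then re-enable gradients only for
        # the resized embedding matrix and lm_head so the newly added tokens stay trainable.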
        if freeze_lm:
            for n, param in self.lm.named_parameters():
                param.requires_grad = False

        self.llm_version = llm_version
        self.seg_token_idx = seg_token_idx

        self.lm.resize_token_embeddings(len(tokenizer))
        for n, p in self.lm.named_parameters():
            if any([x in n for x in ['lm_head', 'embed_tokens']]) and p.shape[0] == len(tokenizer):
                p.requires_grad = True

        # SAM: the image and prompt encoders stay frozen; optionally train only the mask decoder.
        self.visual_model = build_sam_vit_h(None)
        for param in self.visual_model.parameters():
            param.requires_grad = False
        if train_mask_decoder:
            self.visual_model.mask_decoder.train()
            for param in self.visual_model.mask_decoder.parameters():
                param.requires_grad = True

        # Projection layer: maps LM hidden states at segmentation tokens into SAM's prompt embedding space.
        in_dim = self.lm.config.hidden_size
        text_fc = [nn.Linear(in_dim, in_dim), nn.ReLU(inplace=True), nn.Linear(in_dim, out_dim), nn.Dropout(0.0)]
        self.text_hidden_fcs = nn.ModuleList([nn.Sequential(*text_fc)])

    def get_visual_embs(self, pixel_values: torch.FloatTensor):
        """Encode preprocessed images with SAM's image encoder."""
        image_embeddings = self.visual_model.image_encoder(pixel_values)
        return image_embeddings

    def evaluate(self, images_clip, images, input_ids, resize_list, original_size_list, max_new_tokens=32, tokenizer=None):
        """Generate a text response and predict one SAM mask per segmentation token in the output."""
        with torch.no_grad():
            outputs = self.lm.generate(images=images_clip, input_ids=input_ids, max_new_tokens=max_new_tokens, num_beams=1, output_hidden_states=True, return_dict_in_generate=True)
            output_hidden_states = outputs.hidden_states[-1]
            output_ids = outputs.sequences
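
            # Generated ids equal to seg_token_idx mark the positions whose hidden states
            # are projected into SAM prompt embeddings (the leading id is skipped).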
            seg_token_mask = (output_ids[:, 1:] == self.seg_token_idx)

            last_embedding = None
            last_output_logit = None
            hidden_states = []

            assert len(self.text_hidden_fcs) == 1
            hidden_states.append(self.text_hidden_fcs[0](output_hidden_states))
            last_hidden_state = torch.stack(hidden_states, dim=-1).sum(dim=-1)
            pred_embeddings = last_hidden_state[seg_token_mask]

            seg_token_counts = seg_token_mask.int().sum(-1)  # [bs, ]
            seg_token_offset = seg_token_counts.cumsum(-1)
            seg_token_offset = torch.cat([torch.zeros(1).long().cuda(), seg_token_offset], dim=0)

            pred_embeddings_ = []
            for i in range(len(seg_token_offset) - 1):
                start_i, end_i = seg_token_offset[i], seg_token_offset[i + 1]
                pred_embeddings_.append(pred_embeddings[start_i:end_i])
            pred_embeddings = pred_embeddings_
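
            # Encode the raw images with SAM's image encoder, then decode one binary mask
            # per predicted segmentation embedding by passing it to the prompt encoder as a text prompt.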
            image_embeddings = self.get_visual_embs(images)

            multimask_output = False
            pred_masks = []
            for i in range(len(pred_embeddings)):
                sparse_embeddings, dense_embeddings = self.visual_model.prompt_encoder(
                    points=None,
                    boxes=None,
                    masks=None,
                    text_embeds=pred_embeddings[i].unsqueeze(1),
                )
                sparse_embeddings = sparse_embeddings.to(pred_embeddings[i].dtype)
                low_res_masks, iou_predictions = self.visual_model.mask_decoder(
                    image_embeddings=image_embeddings[i].unsqueeze(0),
                    image_pe=self.visual_model.prompt_encoder.get_dense_pe(),
                    sparse_prompt_embeddings=sparse_embeddings,
                    dense_prompt_embeddings=dense_embeddings,
                    multimask_output=multimask_output,
                )
                pred_mask = self.visual_model.postprocess_masks(
                    low_res_masks,
                    input_size=resize_list[i],
                    original_size=original_size_list[i],
                )
                pred_masks.append(pred_mask[:, 0])

        return output_ids, pred_masks