File size: 22,551 Bytes
cca9b7e |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457 458 459 460 461 462 463 464 465 466 467 468 469 470 471 472 473 474 475 476 477 478 479 480 481 482 483 484 485 486 487 488 489 490 491 492 493 494 495 496 497 498 499 500 501 502 503 504 505 506 507 508 509 510 511 512 513 514 515 516 517 518 519 520 521 522 523 524 525 526 |
import logging
import random
import torch
from torch.cuda.amp import autocast as autocast
from torchvision import models
import torch.nn as nn
from medomni.common.registry import registry
from medomni.models.blip2 import Blip2Base, disabled_train
from medomni.models.modeling_llama import LlamaForCausalLM
from transformers import LlamaTokenizer
from transformers import SwinModel
import torch.nn.functional as F
import math
from einops import rearrange, repeat
from einops_exts import rearrange_many
import open_clip
import segmentation_models_pytorch as smp
from medomni.models.UNet import UNet3d
from huggingface_hub import PyTorchModelHubMixin
import ipdb
from peft import (
get_peft_model,
LoraConfig,
PrefixTuningConfig,
PromptEncoderConfig,
PromptTuningConfig,
TaskType,
)
class GroupNorm(nn.GroupNorm):
    """Subclass torch's GroupNorm to handle fp16.

    The input is upcast to float32 for the normalisation and the result is cast back
    to the input's original dtype, so half-precision activations are normalised in
    full precision.
    """

    def forward(self, x: torch.Tensor):
        orig_type = x.dtype
        ret = super().forward(x.type(torch.float32))
        return ret.type(orig_type)
class LayerNorm(nn.LayerNorm):
    """nn.LayerNorm that always normalises in float32 (fp16/bf16-safe).

    The input is upcast to float32, normalised by the parent implementation, and
    the result is cast back to the input's original dtype.
    """

    def forward(self, x: torch.Tensor):
        return super().forward(x.float()).to(x.dtype)
def replace_batchnorm_2d(model, num_groups=16):
    """Recursively swap every ``nn.BatchNorm2d`` in ``model`` for a float32-safe ``GroupNorm``.

    The model is modified in place (its ``_modules`` entries are reassigned) and is
    also returned. Each replacement is a freshly constructed
    ``GroupNorm(num_groups, num_features)``; the BatchNorm's affine parameters and
    running statistics are not carried over.

    Args:
        model: module to rewrite.
        num_groups: number of groups for every replacement GroupNorm. It must divide
            the ``num_features`` of each replaced BatchNorm. Defaults to 16, the value
            that was previously hard-coded.

    Returns:
        ``model`` itself.
    """
    for name, module in reversed(model._modules.items()):
        if module is None:
            # register_module(name, None) leaves None placeholders; nothing to rewrite.
            continue
        if next(module.children(), None) is not None:
            model._modules[name] = replace_batchnorm_2d(module, num_groups)
        if isinstance(module, nn.BatchNorm2d):
            model._modules[name] = GroupNorm(num_groups=num_groups, num_channels=module.num_features)
    return model
def dice_loss(input, target, smooth=1.0):
    """Return the soft Dice coefficient between ``sigmoid(input)`` and ``target``.

    Despite the name, this is a similarity score, not a loss: higher is better, and
    for targets in [0, 1] it lies in (0, 1]. ``MixedLoss`` turns it into a loss with
    ``-log``. All elements (every sample and pixel) are pooled into one score.

    Args:
        input: logits of any shape.
        target: mask with the same number of elements as ``input``.
        smooth: additive constant in the numerator and denominator. It keeps the
            score well-defined when prediction and target are (nearly) empty.

    Returns:
        0-dim tensor holding the Dice coefficient.
    """
    probs = torch.sigmoid(input)
    # reshape (not view) so non-contiguous tensors such as permuted or sliced masks work.
    iflat = probs.reshape(-1)
    tflat = target.reshape(-1)
    intersection = (iflat * tflat).sum()
    return (2.0 * intersection + smooth) / (iflat.sum() + tflat.sum() + smooth)
class FocalLoss(nn.Module):
    """Binary focal loss computed from logits and averaged over all elements.

    Args:
        gamma: focusing exponent. With binary targets each element's cross-entropy is
            scaled by ``(1 - p_t) ** gamma``, which down-weights well-classified elements.
    """

    def __init__(self, gamma):
        super().__init__()
        self.gamma = gamma

    def forward(self, input, target):
        if target.size() != input.size():
            raise ValueError("Target size ({}) must be the same as input size ({})"
                             .format(target.size(), input.size()))
        # Numerically stable binary cross-entropy with logits (log-sum-exp form).
        neg_part = (-input).clamp(min=0)
        log_term = ((-neg_part).exp() + (-input - neg_part).exp()).log()
        bce = input - input * target + neg_part + log_term
        # log-probability of the "wrong" class; times gamma and exponentiated it is the
        # focal modulating factor.
        signed_logits = -input * (target * 2.0 - 1.0)
        modulator = (F.logsigmoid(signed_logits) * self.gamma).exp()
        return (modulator * bce).mean()
class MixedLoss(nn.Module):
    """Segmentation loss: ``alpha * focal_loss - log(dice_coefficient)``.

    Args:
        alpha: weight of the focal term.
        gamma: focusing exponent forwarded to ``FocalLoss``.
    """

    def __init__(self, alpha, gamma):
        super().__init__()
        self.alpha = alpha
        self.focal = FocalLoss(gamma)

    def forward(self, input, target):
        focal_term = self.alpha * self.focal(input, target)
        dice_term = torch.log(dice_loss(input, target))
        return (focal_term - dice_term).mean()
def trans_seg(sample_num, bsz):
    """Parse numeric label strings into a ``(bsz, 10)`` float tensor.

    Each string in ``sample_num`` is split on ``'-'`` into groups. Every group other
    than ``'n/a'`` is a comma-separated list of numbers and fills the next row,
    left-aligned. Unused cells stay 0.

    NOTE(review): because groups are split on '-', a negative number such as "1,-2"
    is mis-split and makes ``float('')`` raise ValueError.

    Raises:
        IndexError: more groups than ``bsz``, or more than 10 numbers in a group.
    """
    labels = torch.zeros((bsz, 10))
    groups = [grp for entry in sample_num for grp in entry.split('-') if grp != 'n/a']
    for row, grp in enumerate(groups):
        for col, value in enumerate(grp.split(',')):
            labels[row, col] = float(value)
    return labels
def trans_det(sample_num, bsz):
    """Parse box strings into a ``(bsz, 4)`` float tensor.

    Each string in ``sample_num`` holds ``';'``-separated boxes. Every box other than
    ``'n/a'`` is four comma-separated numbers and becomes the next row. Rows beyond
    the parsed boxes stay 0.

    Raises:
        IndexError: more boxes than ``bsz``, or more than 4 numbers in a box.
    """
    labels = torch.zeros((bsz, 4))
    row = 0
    for entry in sample_num:
        for box in entry.split(';'):
            if box == 'n/a':
                continue
            for col, coord in enumerate(box.split(',')):
                labels[row, col] = float(coord)
            row += 1
    return labels
def trans_keypoint(sample_num, bsz):
    """Parse point strings into a ``(bsz, 2)`` float tensor.

    Each string in ``sample_num`` holds ``';'``-separated points. Every point other
    than ``'n/a'`` is two comma-separated numbers and becomes the next row. Rows
    beyond the parsed points stay 0.

    Raises:
        IndexError: more points than ``bsz``, or more than 2 numbers in a point.
    """
    labels = torch.zeros((bsz, 2))
    points = (point for entry in sample_num for point in entry.split(';') if point != 'n/a')
    for row, point in enumerate(points):
        for col, coord in enumerate(point.split(',')):
            labels[row, col] = float(coord)
    return labels
@registry.register_model("medomni")
class MedOmni(Blip2Base, PyTorchModelHubMixin):
    """Medical multimodal LLaMA with 2D/3D vision encoders and task-specific heads.

    Each image is encoded into 9 visual tokens: a Swin encoder handles 2D images and
    a 3D UNet encoder handles CT. The tokens are spliced into the prompt in place of
    ``<ImageHere>`` placeholders and fed to a LoRA-wrapped LLaMA. The last-layer
    hidden states at ``<DET>``, ``<2DSEG>`` and ``<3DSEG>`` target positions drive a
    box-regression head and the 2D/3D mask decoders.
    """

    PRETRAINED_MODEL_CONFIG_DICT = {
        "medomni": "configs/models/medomni.yaml",
    }

    def __init__(
        self,
        config,
    ):
        """Build the encoders, tokenizer, LoRA-wrapped LLaMA and task heads.

        Args:
            config: mapping that provides ``llama_model`` (name or path passed to
                ``from_pretrained``), ``max_txt_len`` and ``end_sym``.
        """
        super().__init__()
        freeze_vit = True
        llama_model = config['llama_model']
        max_txt_len = config['max_txt_len']
        low_resource = False  # use 8 bit and put vit in cpu / have not been tested
        end_sym = config['end_sym']
        self.low_resource = low_resource

        print('Loading VIT')
        self.visual_encoder_2d = SwinModel.from_pretrained('microsoft/swin-base-patch4-window7-224')
        self.visual_encoder_3d = UNet3d(in_channels=1, n_classes=1, n_channels=32)
        self.ln_vision_2d = LayerNorm(1024)
        self.ln_vision_3d = LayerNorm(256)
        if freeze_vit:
            # Freeze both encoders and their LayerNorms, put them in eval mode, and
            # replace their .train with disabled_train (from medomni.models.blip2).
            for attr in ('visual_encoder_2d', 'ln_vision_2d', 'visual_encoder_3d', 'ln_vision_3d'):
                module = getattr(self, attr)
                for param in module.parameters():
                    param.requires_grad = False
                module = module.eval()
                module.train = disabled_train
                setattr(self, attr, module)
            logging.info("freeze vision encoder")
        print('Loading VIT Done')

        print('Loading LLAMA')
        self.llama_tokenizer = LlamaTokenizer.from_pretrained(llama_model, legacy=False, use_fast=False)
        # Token ids depend on insertion order, so keep this order stable.
        self.llama_tokenizer.add_special_tokens({"additional_special_tokens": ['<ImageHere>']})
        self.llama_tokenizer.add_tokens("<DET>")
        self.llama_tokenizer.add_tokens("<2DSEG>")
        self.llama_tokenizer.add_tokens("<3DSEG>")
        # self.llama_tokenizer.add_tokens("<2DPOINT>")  # keypoint head is disabled (see below)
        self.llama_tokenizer.add_tokens("<N/A>")
        self.det_token_idx = self.llama_tokenizer("<DET>", add_special_tokens=False).input_ids[0]
        self.seg_token_idx_2d = self.llama_tokenizer("<2DSEG>", add_special_tokens=False).input_ids[0]
        self.seg_token_idx_3d = self.llama_tokenizer("<3DSEG>", add_special_tokens=False).input_ids[0]
        # self.point_token_idx_2d = self.llama_tokenizer("<2DPOINT>", add_special_tokens=False).input_ids[0]
        self.na_token_idx = self.llama_tokenizer("<N/A>", add_special_tokens=False).input_ids[0]
        # Id of the image placeholder whose embeddings prompt_wrap overwrites with visual
        # tokens. It is looked up rather than hard-coded, so the code does not assume a
        # particular base vocabulary size.
        self.img_token_idx = self.llama_tokenizer.convert_tokens_to_ids('<ImageHere>')
        # NOTE(review): pad_token is normally a string; this assigns the int 0. Kept as-is
        # because forward() masks targets via pad_token_id — confirm the intended pad id.
        self.llama_tokenizer.pad_token = 0

        llama_kwargs = {"torch_dtype": torch.bfloat16}
        if self.low_resource:
            llama_kwargs.update(load_in_8bit=True, device_map="auto")
        self.llama_model = LlamaForCausalLM.from_pretrained(llama_model, **llama_kwargs)
        self.llama_model.resize_token_embeddings(len(self.llama_tokenizer))
        self.embed_tokens = self.llama_model.get_input_embeddings()
        self.embed_states = self.llama_model.get_output_embeddings()  # cannot remove

        # ---LoRA---
        class CastOutputToFloat(nn.Sequential):
            # Despite the name, lm_head outputs are cast to bfloat16.
            def forward(self, x):
                return super().forward(x).to(torch.bfloat16)

        self.llama_model.lm_head = CastOutputToFloat(self.llama_model.lm_head)
        print("Setup PEFT")
        peft_config = LoraConfig(
            task_type="CAUSAL_LM", inference_mode=False,
            r=16,
            lora_alpha=16, lora_dropout=0.1,
            target_modules=['q_proj', 'v_proj'],
        )  # 8 32 hyz 9.21
        self.llama_model = get_peft_model(self.llama_model, peft_config)
        hidden_size = self.llama_model.config.hidden_size

        # Visual features -> LLaMA embedding space.
        self.llama_proj_2d = nn.Linear(1024, hidden_size)
        self.llama_proj_3d = nn.Linear(256, hidden_size)

        # Detection: <DET> hidden state -> 4 box values.
        self.text_det = nn.Sequential(
            LayerNorm(hidden_size),
            nn.Linear(hidden_size, 256),
            nn.ReLU(inplace=True),
            LayerNorm(256),
            nn.Linear(256, 4),
        )
        self.det_loss = torch.nn.SmoothL1Loss()

        # Keypoint head (disabled; forward() raises if a keypoint task is requested).
        # self.text_point = nn.Sequential(
        #     LayerNorm(hidden_size),
        #     nn.Linear(hidden_size, 256),
        #     nn.ReLU(inplace=True),
        #     LayerNorm(256),
        #     nn.Linear(256, 2),
        # )
        # self.keypoint_loss = torch.nn.SmoothL1Loss()

        # Segmentation
        self.model_seg_2d = smp.Unet(encoder_name="resnet18", encoder_weights="imagenet", in_channels=3, classes=1)
        self.model_seg_2d = replace_batchnorm_2d(self.model_seg_2d)  # GN is much better than BN
        self.text2seg_2d = nn.Sequential(
            LayerNorm(hidden_size),
            nn.Linear(hidden_size, 512),
        )
        self.text2seg_2d_ln = LayerNorm(512)
        self.text2seg_2d_gn = GroupNorm(16, 512)
        self.text2seg_3d = nn.Sequential(
            LayerNorm(hidden_size),
            nn.Linear(hidden_size, 256),
        )
        self.text2seg_3d_ln = LayerNorm(256)
        self.text2seg_3d_gn = GroupNorm(16, 256)
        self.seg_loss = MixedLoss(10.0, 2.0)

        self.max_txt_len = max_txt_len
        self.end_sym = end_sym
        self.prompt_list = []

    def vit_to_cpu(self):
        """Move the frozen image encoders to CPU in float32 (low-resource mode).

        The LayerNorms and projections stay on the compute device, because
        encode_img applies them after moving encoder outputs back to the input's
        device.

        NOTE(review): forward() also uses visual_encoder_3d as the segmentation
        decoder on the compute device, so 3D segmentation is not supported in
        low-resource mode. low_resource is hard-coded to False in __init__.
        """
        self.visual_encoder_2d.to("cpu")
        self.visual_encoder_2d.float()
        self.visual_encoder_3d.to("cpu")
        self.visual_encoder_3d.float()

    def encode_img(self, image, modals, task_types=()):
        """Encode images into LLaMA-space visual tokens (9 per image).

        Args:
            image: 5-D tensor unpacked as (B, S, ...). If 'ct' is in ``modals`` it is
                fed as-is to the 3D UNet encoder. The later
                ``(b s) -> b s`` rearrange then only works for S == 1. Otherwise it is
                treated as (B, S, C, H, W) 2D images for the Swin encoder.
            modals: container checked for 'ct'.
            task_types: container checked for 'segmentation'; 2D visual tokens are
                detached in that case.

        Returns:
            inputs_llama: (B, S, 9, hidden) bfloat16 visual tokens.
            atts_llama: (B, S) long tensor of ones on the input's device.
            image_embeds_list: list of 3D encoder features for CT, else None.
        """
        B, S, _, _, _ = image.shape
        device = image.device
        image_embeds_list = None
        if self.low_resource:
            self.vit_to_cpu()
            image = image.to("cpu")
        with self.maybe_autocast():
            if 'ct' in modals:
                # Deepest 3D feature map pooled to 1x3x3 = 9 tokens.
                image_embeds_list = self.visual_encoder_3d(image, encoder_only=True)
                image_embeds_list = [feat.to(device) for feat in image_embeds_list]
                image_embeds = image_embeds_list[-1].detach()
                image_embeds = F.adaptive_avg_pool3d(image_embeds, (1, 3, 3)).view(B, image_embeds.shape[1], -1).permute(0, 2, 1)
                inputs_llama = self.llama_proj_3d(self.ln_vision_3d(image_embeds))
                inputs_llama = rearrange(inputs_llama, "(b s) c d -> b s c d", b=B, s=S).to(torch.bfloat16)
            else:
                # Swin last_hidden_state reshaped to a 7x7 grid and pooled to 3x3 = 9 tokens.
                image = rearrange(image, "b s c h w -> (b s) c h w")
                image_embeds = self.visual_encoder_2d(image)['last_hidden_state'].to(device)
                image_embeds_unp = image_embeds.permute(0, 2, 1).view(B * S, -1, 7, 7)
                image_embeds_unp = F.adaptive_avg_pool2d(image_embeds_unp, (3, 3))
                image_embeds = image_embeds_unp.view(B * S, -1, 9).permute(0, 2, 1)
                inputs_llama = self.llama_proj_2d(self.ln_vision_2d(image_embeds))
                inputs_llama = rearrange(inputs_llama, "(b s) c d -> b s c d", b=B, s=S).to(torch.bfloat16)
                if 'segmentation' in task_types:
                    inputs_llama = inputs_llama.detach()  # add detach() for segmentation tasks
            # Built on `device` (not image.device): image may have been moved to CPU above.
            atts_llama = torch.ones(inputs_llama.size()[:-2], dtype=torch.long, device=device)
        return inputs_llama, atts_llama, image_embeds_list

    def prompt_concat(self, img_embeds, atts_img, prompt):
        """Append embedded question tokens after the image-prefix embeddings.

        Returns the concatenated embeddings and a mask of matching length. The mask
        repeats the first column of ``atts_img``, which is all ones as produced by
        encode_img. If ``prompt`` is falsy, the inputs are returned unchanged.

        NOTE(review): ``prompt.attention_mask`` is ignored, so padding inside the
        question (and in the image prefix) is attended to. This is kept as-is for
        consistency with trained checkpoints.
        """
        if not prompt:
            return img_embeds, atts_img
        batch_size = img_embeds.shape[0]
        p_after_embeds = self.embed_tokens(prompt.input_ids).expand(batch_size, -1, -1)
        wrapped_img_embeds = torch.cat([img_embeds, p_after_embeds], dim=1)
        wrapped_atts_img = atts_img[:, :1].expand(-1, wrapped_img_embeds.shape[1])
        return wrapped_img_embeds, wrapped_atts_img

    def prompt_wrap(self, img_embeds, atts_img, prompt_list, num_imgs, seg=None):
        """Embed the image-prefix prompts and splice visual tokens into them.

        ``<ImageHere>`` tokens are replaced, in row-major order, by the visual tokens
        of sample i's first ``num_imgs[i]`` images (sample-major, then image, then
        token). If ``seg`` is not None, the spliced embeddings are detached.

        Returns:
            (prompt embeddings, atts_img), with ``atts_img`` unchanged. If
            ``prompt_list`` is empty, ``(img_embeds, atts_img)`` is returned as-is.

        NOTE(review): truncation to 256 tokens would drop placeholders and make the
        splice below fail on a count mismatch.
        """
        if not prompt_list:
            return img_embeds, atts_img
        bsz = img_embeds.shape[0]
        # (sample, image) index pairs for every real image, sample-major.
        img_idx = ([], [])
        for i in range(len(num_imgs)):
            for j in range(num_imgs[i]):
                img_idx[0].append(i)
                img_idx[1].append(j)
        prompt_tokens = self.llama_tokenizer(prompt_list, return_tensors="pt", padding="longest", truncation=True, max_length=256).to(img_embeds.device)
        idx = (prompt_tokens.input_ids == self.img_token_idx).nonzero(as_tuple=True)
        # Placeholder ids are swapped for an ordinary id before embedding; their
        # embeddings are overwritten just below.
        prompt_tokens.input_ids[idx] = 123  # avoid memory issue
        p_embeds = self.embed_tokens(prompt_tokens.input_ids).expand(bsz, -1, -1)
        visual = rearrange(img_embeds[img_idx], "b c d -> (b c) d").to(torch.bfloat16)
        if seg is not None:
            visual = visual.detach()
        p_embeds[idx] = visual
        return p_embeds, atts_img

    def forward(self, samples):
        """Compute the training loss for one batch.

        Args:
            samples: dict with keys ``image``, ``modal``, ``task_type``, ``num_imgs``,
                ``question`` and ``answer``, plus ``answer_img`` for segmentation.
                For detection/keypoint tasks each answer is ``"<text>|||<numbers>"``.

        Returns:
            dict with the scalar ``loss`` and the first sample's ``modal`` and
            ``task_type``.

        Raises:
            NotImplementedError: a keypoint task is requested while the keypoint head
                is disabled.
        """
        image = samples["image"]
        bsz = image.shape[0]
        img_embeds, atts_img, img_embeds_list = self.encode_img(image, samples['modal'], samples['task_type'])

        # Image prefix: each image contributes <imgN>, 9 <ImageHere> placeholders, </imgN>.
        prefix_list = []
        tag_list = [[] for _ in range(bsz)]
        placeholder = '<ImageHere>' * 9  # 9 = the number of visual tokens
        for j in range(bsz):
            prefix = ''  # Can add some prompt, such as 'You will be given an image, please describe everything you see'
            for i in range(samples['num_imgs'][j]):
                prefix += '<img' + str(i) + '>' + placeholder + '</img' + str(i) + '>'
                tag_list[j].append('<img' + str(i) + '>')
            prefix_list.append('###Human:' + prefix)
        img_embeds, atts_img = self.prompt_wrap(
            img_embeds, atts_img, prefix_list, samples['num_imgs'],
            seg=None if 'segmentation' not in samples['task_type'] else 'yes',
        )

        # '_*_' in a question is replaced by the tags of the images it does not mention.
        self.llama_tokenizer.padding_side = "right"
        prompt = list(samples['question'])
        for i in range(len(prompt)):
            tags = ''.join(tag for tag in tag_list[i] if tag not in prompt[i])
            prompt[i] = prompt[i].replace('_*_', tags)

        if 'detection' in samples['task_type'] or 'keypoint' in samples['task_type']:
            sample_ans = [ans.split('|||')[0] for ans in samples['answer']]
            sample_num = [ans.split('|||')[1] for ans in samples['answer']]
        else:
            sample_ans = samples['answer']
        text = ['###Assistant: ' + str(t) + self.end_sym for t in sample_ans]

        prompt_tokens = self.llama_tokenizer(
            prompt,
            return_tensors="pt",
            padding='longest',
            truncation=True,
            max_length=256,
            add_special_tokens=False,
        ).to(image.device)
        img_embeds, atts_img = self.prompt_concat(img_embeds, atts_img, prompt_tokens)
        to_regress_tokens = self.llama_tokenizer(
            text,
            return_tensors="pt",
            padding="longest",
            truncation=True,
            max_length=self.max_txt_len,
            add_special_tokens=False,
        ).to(image.device)

        # Only answer tokens are supervised: padding, bos and the whole prompt get -100.
        targets = to_regress_tokens.input_ids.masked_fill(
            to_regress_tokens.input_ids == self.llama_tokenizer.pad_token_id, -100
        )
        empty_targets = (
            torch.ones([atts_img.shape[0], atts_img.shape[1] + 1],
                       dtype=torch.long).to(image.device).fill_(-100)  # plus one for bos
        )
        targets = torch.cat([empty_targets, targets], dim=1)

        batch_size = img_embeds.shape[0]
        bos = torch.ones([batch_size, 1],
                         dtype=to_regress_tokens.input_ids.dtype,
                         device=to_regress_tokens.input_ids.device) * self.llama_tokenizer.bos_token_id
        bos_embeds = self.embed_tokens(bos)
        atts_bos = atts_img[:, :1]
        to_regress_embeds = self.embed_tokens(to_regress_tokens.input_ids)
        inputs_embeds = torch.cat([bos_embeds, img_embeds, to_regress_embeds], dim=1)
        attention_mask = torch.cat([atts_bos, atts_img, to_regress_tokens.attention_mask], dim=1)
        with self.maybe_autocast():
            outputs = self.llama_model(
                inputs_embeds=inputs_embeds,
                attention_mask=attention_mask,
                return_dict=True,
                labels=targets,
                output_hidden_states=True,
            )
        loss = outputs.loss

        # Task heads read the last-layer hidden states at their special-token positions
        # (targets is aligned with inputs_embeds).
        if 'detection' in samples['task_type']:
            with self.maybe_autocast():
                hidden_states = outputs.hidden_states[-1]
                token_mask = targets == self.det_token_idx
                det_states = self.text_det(hidden_states[token_mask])
            labels = trans_det(sample_num, det_states.shape[0]).to(targets.device)
            loss += self.det_loss(det_states, labels) * 1e2

        if 'keypoint' in samples['task_type']:
            if not hasattr(self, 'text_point'):
                raise NotImplementedError(
                    "keypoint task requested, but the keypoint head (<2DPOINT> token, "
                    "text_point, keypoint_loss) is disabled in MedOmni.__init__"
                )
            with self.maybe_autocast():
                hidden_states = outputs.hidden_states[-1]
                token_mask = targets == self.point_token_idx_2d
                point_states = self.text_point(hidden_states[token_mask])
            labels = trans_keypoint(sample_num, point_states.shape[0]).to(targets.device)
            loss += self.keypoint_loss(point_states, labels) * 1e2

        if 'segmentation' in samples['task_type']:
            masks = samples['answer_img']
            hidden_states = outputs.hidden_states[-1]
            if 'ct' in samples['modal']:
                with self.maybe_autocast():
                    # encode_img already ran the frozen 3D encoder on this same image;
                    # reuse its features instead of encoding again.
                    img_embeds_list = [feat.to(targets.device) for feat in img_embeds_list]
                    token_mask = targets == self.seg_token_idx_3d
                    seg_states = self.text2seg_3d(hidden_states[token_mask])
                    # Inject the <3DSEG> embedding as a per-channel bias on the deepest features.
                    last_feats = img_embeds_list[-1] + seg_states.unsqueeze(-1).unsqueeze(-1).unsqueeze(-1)
                    img_embeds_list[-1] = self.text2seg_3d_gn(last_feats)
                    seg_preds = self.visual_encoder_3d(encoder_only=False, x_=img_embeds_list)
                loss += self.seg_loss(seg_preds, masks.float())
            else:
                with self.maybe_autocast():
                    feats = self.model_seg_2d.encoder(image[:, 0])
                    token_mask = targets == self.seg_token_idx_2d
                    seg_states = self.text2seg_2d(hidden_states[token_mask])
                    # Inject the <2DSEG> embedding as a per-channel bias on the deepest features.
                    last_feats = feats[-1] + seg_states.unsqueeze(-1).unsqueeze(-1)
                    feats[-1] = self.text2seg_2d_gn(last_feats)
                    seg_feats = self.model_seg_2d.decoder(*feats)
                    seg_preds = self.model_seg_2d.segmentation_head(seg_feats)
                loss += self.seg_loss(seg_preds, masks.float())

        return {"loss": loss, "modal": samples['modal'][0], "task_type": samples['task_type'][0]}

    @classmethod
    def from_config(cls, cfg, finetune=False):
        """Build a model from ``cfg`` and optionally load the checkpoint at ``cfg["ckpt"]``.

        With ``finetune=True``, only checkpoint tensors whose name and shape match
        the freshly built model are loaded. Other entries keep their initial values,
        and names missing from the checkpoint are printed. Without it, the checkpoint
        is loaded non-strictly.
        """
        model = cls(cfg)
        ckpt_path = cfg.get("ckpt", "")
        if ckpt_path:
            print("Load Checkpoint: {}".format(ckpt_path))
            # NOTE: torch.load unpickles the file; only load trusted checkpoints.
            ckpt = torch.load(ckpt_path, map_location="cpu")
            if finetune:
                current_model_dict = model.state_dict()
                weights = ckpt['model']
                new_state_dict = {}
                for k, current in current_model_dict.items():
                    if k not in weights:
                        print(k)
                        new_state_dict[k] = current
                    elif weights[k].size() == current.size():
                        new_state_dict[k] = weights[k]
                    else:
                        new_state_dict[k] = current
                msg = model.load_state_dict(new_state_dict, strict=False)
            else:
                msg = model.load_state_dict(ckpt['model'], strict=False)
            logging.info("Checkpoint load result: %s", msg)
        return model
|