import logging
import random

import torch
from torch.cuda.amp import autocast
from torchvision import models
import torch.nn as nn

from medomni.common.registry import registry
from medomni.models.blip2 import Blip2Base, disabled_train
from medomni.models.modeling_llama import LlamaForCausalLM
from transformers import LlamaTokenizer
from transformers import SwinModel
import torch.nn.functional as F
import math
from einops import rearrange, repeat
from einops_exts import rearrange_many
import open_clip
import segmentation_models_pytorch as smp
from medomni.models.UNet import UNet3d
from huggingface_hub import PyTorchModelHubMixin
from peft import (
    get_peft_model,
    LoraConfig,
    PrefixTuningConfig,
    PromptEncoderConfig,
    PromptTuningConfig,
    TaskType,
)

class GroupNorm(nn.GroupNorm):
    """Subclass torch's GroupNorm to handle fp16: normalize in fp32, cast back."""

    def forward(self, x: torch.Tensor):
        orig_type = x.dtype
        ret = super().forward(x.type(torch.float32))
        return ret.type(orig_type)


class LayerNorm(nn.LayerNorm):
    """Subclass torch's LayerNorm to handle fp16: normalize in fp32, cast back."""

    def forward(self, x: torch.Tensor):
        orig_type = x.dtype
        ret = super().forward(x.type(torch.float32))
        return ret.type(orig_type)


def replace_batchnorm_2d(model):
    """Recursively replace every nn.BatchNorm2d in `model` with a 16-group
    GroupNorm, which is stable under mixed precision and small batches."""
    for name, module in reversed(model._modules.items()):
        if len(list(module.children())) > 0:
            model._modules[name] = replace_batchnorm_2d(module)
        if isinstance(module, nn.BatchNorm2d):
            model._modules[name] = GroupNorm(num_groups=16, num_channels=module.num_features)
    return model
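
# Illustrative usage (mirrors the call in MedOmni.__init__ below):
#   seg = smp.Unet(encoder_name="resnet18", encoder_weights="imagenet", in_channels=3, classes=1)
#   seg = replace_batchnorm_2d(seg)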

def dice_loss(input, target):
    # NOTE: despite the name, this returns the soft Dice *coefficient*
    # (higher is better); MixedLoss below turns it into a loss via -log().
    input = torch.sigmoid(input)
    smooth = 1.0
    iflat = input.view(-1)
    tflat = target.view(-1)
    intersection = (iflat * tflat).sum()
    return (2.0 * intersection + smooth) / (iflat.sum() + tflat.sum() + smooth)


class FocalLoss(nn.Module):
    def __init__(self, gamma):
        super().__init__()
        self.gamma = gamma

    def forward(self, input, target):
        if not (target.size() == input.size()):
            raise ValueError("Target size ({}) must be the same as input size ({})"
                             .format(target.size(), input.size()))
        # Numerically stable binary cross-entropy with logits.
        max_val = (-input).clamp(min=0)
        loss = input - input * target + max_val + \
            ((-max_val).exp() + (-input - max_val).exp()).log()
        # Focal modulation: down-weight easy examples by (1 - p_t)^gamma.
        invprobs = F.logsigmoid(-input * (target * 2.0 - 1.0))
        loss = (invprobs * self.gamma).exp() * loss
        return loss.mean()


class MixedLoss(nn.Module):
    def __init__(self, alpha, gamma):
        super().__init__()
        self.alpha = alpha
        self.focal = FocalLoss(gamma)

    def forward(self, input, target):
        loss = self.alpha * self.focal(input, target) - torch.log(dice_loss(input, target))
        return loss.mean()
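
# Illustrative usage: `seg_preds` are raw mask logits and `masks` the binary
# targets of the same shape, as in MedOmni.forward's segmentation branches:
#   criterion = MixedLoss(alpha=10.0, gamma=2.0)
#   loss = criterion(seg_preds, masks.float())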

def _parse_coord_groups(sample_num, bsz, dim, sep):
    """Parse answer strings of `sep`-separated coordinate groups (each group a
    comma-separated list of numbers, or 'n/a') into a (bsz, dim) float tensor.
    One row is filled per non-'n/a' group, in order."""
    labels = torch.zeros((bsz, dim))
    row = 0
    for sample in sample_num:
        for group in sample.split(sep):
            if group != 'n/a':
                for col, num in enumerate(group.split(',')):
                    labels[row, col] = float(num)
                row += 1
    return labels


def trans_seg(sample_num, bsz):
    # Up to 10 values per '-'-separated group.
    return _parse_coord_groups(sample_num, bsz, dim=10, sep='-')


def trans_det(sample_num, bsz):
    # 4 values per group: one bounding box.
    return _parse_coord_groups(sample_num, bsz, dim=4, sep=';')


def trans_keypoint(sample_num, bsz):
    # 2 values per group: one 2-D keypoint.
    return _parse_coord_groups(sample_num, bsz, dim=2, sep=';')
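
# Example (illustrative): one box plus an 'n/a' group parses to a single
# filled row:
#   trans_det(["12,34,56,78;n/a"], bsz=1)
#   # -> tensor([[12., 34., 56., 78.]])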

@registry.register_model("medomni")
class MedOmni(Blip2Base, PyTorchModelHubMixin):
    PRETRAINED_MODEL_CONFIG_DICT = {
        "medomni": "configs/models/medomni.yaml",
    }

    def __init__(
        self,
        config,
    ):
        super().__init__()
        freeze_vit = True
        llama_model = config['llama_model']
        max_txt_len = config['max_txt_len']
        low_resource = False
        end_sym = config['end_sym']

        self.low_resource = low_resource

        print('Loading VIT')
        # 2-D encoder: Swin-B (feature dim 1024); 3-D encoder: UNet3d (bottleneck dim 256).
        self.visual_encoder_2d = SwinModel.from_pretrained('microsoft/swin-base-patch4-window7-224')
        self.visual_encoder_3d = UNet3d(in_channels=1, n_classes=1, n_channels=32)
        self.ln_vision_2d = LayerNorm(1024)
        self.ln_vision_3d = LayerNorm(256)

        if freeze_vit:
            # Freeze both vision stacks and pin them in eval mode.
            for module in (self.visual_encoder_2d, self.ln_vision_2d,
                           self.visual_encoder_3d, self.ln_vision_3d):
                for param in module.parameters():
                    param.requires_grad = False
                module.eval()
                module.train = disabled_train
            logging.info("freeze vision encoder")
        print('Loading VIT Done')

        print('Loading LLAMA')
        self.llama_tokenizer = LlamaTokenizer.from_pretrained(llama_model, legacy=False, use_fast=False)
        special_token = {}
        special_token["additional_special_tokens"] = ['<ImageHere>']
        self.llama_tokenizer.add_special_tokens(special_token)
        # Task tokens whose hidden states feed the detection/segmentation heads.
        self.llama_tokenizer.add_tokens("<DET>")
        self.llama_tokenizer.add_tokens("<2DSEG>")
        self.llama_tokenizer.add_tokens("<3DSEG>")
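        # Assumed: a dedicated keypoint token. The keypoint branch in forward()
        # reads self.point_token_idx_2d, but no such token is registered
        # anywhere in this file; "<2DPOINT>" is a placeholder name.
        self.llama_tokenizer.add_tokens("<2DPOINT>")
        self.point_token_idx_2d = self.llama_tokenizer("<2DPOINT>", add_special_tokens=False).input_ids[0]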
        self.llama_tokenizer.add_tokens("<N/A>")
        self.det_token_idx = self.llama_tokenizer("<DET>", add_special_tokens=False).input_ids[0]
        self.seg_token_idx_2d = self.llama_tokenizer("<2DSEG>", add_special_tokens=False).input_ids[0]
        self.seg_token_idx_3d = self.llama_tokenizer("<3DSEG>", add_special_tokens=False).input_ids[0]
        self.na_token_idx = self.llama_tokenizer("<N/A>", add_special_tokens=False).input_ids[0]
        # Pad with token id 0 (LLaMA's <unk>).
        self.llama_tokenizer.pad_token_id = 0

        if self.low_resource:
            self.llama_model = LlamaForCausalLM.from_pretrained(
                llama_model,
                torch_dtype=torch.bfloat16,
                load_in_8bit=True,
                device_map="auto"
            )
        else:
            self.llama_model = LlamaForCausalLM.from_pretrained(
                llama_model,
                torch_dtype=torch.bfloat16,
            )

        self.llama_model.resize_token_embeddings(len(self.llama_tokenizer))
        self.embed_tokens = self.llama_model.get_input_embeddings()
        self.embed_states = self.llama_model.get_output_embeddings()

        class CastOutputToFloat(nn.Sequential):
            # Despite the name, this casts lm_head outputs to bfloat16.
            def forward(self, x):
                return super().forward(x).to(torch.bfloat16)
        self.llama_model.lm_head = CastOutputToFloat(self.llama_model.lm_head)

        print("Setup PEFT")
        # LoRA on the attention q/v projections only; the base LLaMA weights stay frozen.
        peft_config = LoraConfig(
            task_type="CAUSAL_LM", inference_mode=False,
            r=16,
            lora_alpha=16, lora_dropout=0.1,
            target_modules=['q_proj', 'v_proj']
        )
        self.llama_model = get_peft_model(self.llama_model, peft_config)
        # Project 2-D (Swin, 1024-d) and 3-D (UNet3d, 256-d) visual features
        # into the LLaMA embedding space.
        self.llama_proj_2d = nn.Linear(1024, self.llama_model.config.hidden_size)
        self.llama_proj_3d = nn.Linear(256, self.llama_model.config.hidden_size)

        # Detection head: regress one box (4 values) from each <DET> hidden state.
        text_det = nn.Sequential(
            LayerNorm(self.llama_model.config.hidden_size),
            nn.Linear(self.llama_model.config.hidden_size, 256),
            nn.ReLU(inplace=True),
            LayerNorm(256),
            nn.Linear(256, 4),
        )
        self.text_det = text_det
        self.det_loss = torch.nn.SmoothL1Loss()
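
        # Keypoint head (assumed): forward() uses self.text_point and
        # self.keypoint_loss, which are otherwise undefined in this file.
        # A minimal sketch mirroring the detection head above, regressing one
        # 2-D coordinate pair per keypoint token.
        self.text_point = nn.Sequential(
            LayerNorm(self.llama_model.config.hidden_size),
            nn.Linear(self.llama_model.config.hidden_size, 256),
            nn.ReLU(inplace=True),
            LayerNorm(256),
            nn.Linear(256, 2),
        )
        self.keypoint_loss = torch.nn.SmoothL1Loss()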

        # 2-D segmentation model; BatchNorm is swapped for GroupNorm so it
        # behaves under autocast and small batches.
        self.model_seg_2d = smp.Unet(encoder_name="resnet18", encoder_weights="imagenet", in_channels=3, classes=1)
        self.model_seg_2d = replace_batchnorm_2d(self.model_seg_2d)

        # Project <2DSEG>/<3DSEG> hidden states into the segmentation
        # bottleneck channel dims (512 for 2-D, 256 for 3-D).
        text2seg_2d = nn.Sequential(
            LayerNorm(self.llama_model.config.hidden_size),
            nn.Linear(self.llama_model.config.hidden_size, 512),
        )
        self.text2seg_2d = text2seg_2d
        self.text2seg_2d_ln = LayerNorm(512)
        self.text2seg_2d_gn = GroupNorm(16, 512)
        text2seg_3d = nn.Sequential(
            LayerNorm(self.llama_model.config.hidden_size),
            nn.Linear(self.llama_model.config.hidden_size, 256),
        )
        self.text2seg_3d = text2seg_3d
        self.text2seg_3d_ln = LayerNorm(256)
        self.text2seg_3d_gn = GroupNorm(16, 256)
        self.seg_loss = MixedLoss(10.0, 2.0)

        self.max_txt_len = max_txt_len
        self.end_sym = end_sym
        self.prompt_list = []

    def vit_to_cpu(self):
        # Low-resource path: keep the frozen vision stacks on CPU in fp32.
        for module in (self.ln_vision_2d, self.visual_encoder_2d,
                       self.ln_vision_3d, self.visual_encoder_3d):
            module.to("cpu")
            module.float()

    def encode_img(self, image, modals, task_types=[]):
        B, S, _, _, _ = image.shape
        device = image.device
        image_embeds_list = None
        if self.low_resource:
            self.vit_to_cpu()
            image = image.to("cpu")

        with self.maybe_autocast():
            if 'ct' in modals:
                # 3-D path: take the UNet3d encoder pyramid, pool the deepest
                # map to 1x3x3, and project into the LLaMA embedding space.
                image_embeds_list = self.visual_encoder_3d(image, encoder_only=True)
                image_embeds_list = [feat.to(device) for feat in image_embeds_list]
                image_embeds = image_embeds_list[-1].detach()
                image_embeds = F.adaptive_avg_pool3d(image_embeds, (1, 3, 3)).view(B, image_embeds.shape[1], -1).permute(0, 2, 1)
                inputs_llama = self.llama_proj_3d(self.ln_vision_3d(image_embeds))
                inputs_llama = rearrange(inputs_llama, "(b s) c d -> b s c d", b=B, s=S).to(torch.bfloat16)
                atts_llama = torch.ones(inputs_llama.size()[:-2], dtype=torch.long).to(image.device)
            else:
                # 2-D path: fold the series dim into the batch, pool Swin's
                # 7x7 grid to 3x3 (9 visual tokens per image), then project.
                image = rearrange(image, "b s c h w -> (b s) c h w")
                image_embeds = self.visual_encoder_2d(image)['last_hidden_state'].to(device)
                image_embeds_unp = image_embeds.permute(0, 2, 1).view(B * S, -1, 7, 7)
                image_embeds_unp = F.adaptive_avg_pool2d(image_embeds_unp, (3, 3))
                image_embeds = image_embeds_unp.view(B * S, -1, 9).permute(0, 2, 1)
                inputs_llama = self.llama_proj_2d(self.ln_vision_2d(image_embeds))
                if 'segmentation' not in task_types:
                    inputs_llama = rearrange(inputs_llama, "(b s) c d -> b s c d", b=B, s=S).to(torch.bfloat16)
                    atts_llama = torch.ones(inputs_llama.size()[:-2], dtype=torch.long).to(image.device)
                else:
                    # For segmentation, detach so the visual prompt does not
                    # receive gradients from the LLM loss.
                    inputs_llama = rearrange(inputs_llama, "(b s) c d -> b s c d", b=B, s=S).to(torch.bfloat16).detach()
                    atts_llama = torch.ones(inputs_llama.size()[:-2], dtype=torch.long).to(image.device).detach()

        return inputs_llama, atts_llama, image_embeds_list

    def prompt_concat(self, img_embeds, atts_img, prompt):
        # Append the tokenized question embeddings after the image embeddings.
        if prompt:
            batch_size = img_embeds.shape[0]
            p_after_embeds = self.embed_tokens(prompt.input_ids).expand(batch_size, -1, -1)
            wrapped_img_embeds = torch.cat([img_embeds, p_after_embeds], dim=1)
            wrapped_atts_img = atts_img[:, :1].expand(-1, wrapped_img_embeds.shape[1])
            return wrapped_img_embeds, wrapped_atts_img
        else:
            return img_embeds, atts_img

    def prompt_wrap(self, img_embeds, atts_img, prompt_list, num_imgs, seg=None):
        bsz = img_embeds.shape[0]
        if prompt_list:
            # Flat (sample, image) index pairs for gathering per-image embeddings.
            img_idx = ([], [])
            for i in range(len(num_imgs)):
                for j in range(num_imgs[i]):
                    img_idx[0].append(i)
                    img_idx[1].append(j)
            prompt_tokens = self.llama_tokenizer(prompt_list, return_tensors="pt", padding="longest", truncation=True, max_length=256).to(img_embeds.device)
            # 32000 is the id of '<ImageHere>' (the first token added on top of
            # LLaMA's 32,000-token base vocabulary). Swap it for a throwaway
            # id (123); the embeddings at those positions are then overwritten
            # with the visual embeddings below.
            idx = (prompt_tokens.input_ids == 32000).nonzero(as_tuple=True)
            prompt_tokens.input_ids[idx] = 123
            p_embeds = self.embed_tokens(prompt_tokens.input_ids).expand(bsz, -1, -1)
            if seg is None:
                p_embeds[idx] = rearrange(img_embeds[img_idx], "b c d -> (b c) d").to(torch.bfloat16)
            else:
                p_embeds[idx] = rearrange(img_embeds[img_idx], "b c d -> (b c) d").to(torch.bfloat16).detach()
            return p_embeds, atts_img
        else:
            return img_embeds, atts_img

    def forward(self, samples):
        image = samples["image"]
        bsz = image.shape[0]
        img_embeds, atts_img, img_embeds_list = self.encode_img(image, samples['modal'], samples['task_type'])
        # Build a '<imgK>' + nine '<ImageHere>' slots + '</imgK>' prefix per image.
        prefix_list = []
        tag_list = [[] for _ in range(bsz)]
        placeholder = ['<ImageHere>'] * 9
        for j in range(bsz):
            num = samples['num_imgs'][j]
            prefix = ''
            for i in range(num):
                prefix += '<img' + str(i) + '>' + ''.join(placeholder) + '</img' + str(i) + '>'
                tag_list[j].append('<img' + str(i) + '>')
            prefix_list.append('###Human:' + prefix)
        img_embeds, atts_img = self.prompt_wrap(img_embeds, atts_img, prefix_list, samples['num_imgs'], seg=None if 'segmentation' not in samples['task_type'] else 'yes')
        self.llama_tokenizer.padding_side = "right"

        # Substitute the '_*_' placeholder in each question with any image
        # tags not already mentioned in it.
        prompt = list(samples['question'])
        for i in range(len(prompt)):
            tags = ''
            for tag in tag_list[i]:
                if tag not in prompt[i]:
                    tags += tag
            prompt[i] = prompt[i].replace('_*_', tags)

        # Detection/keypoint answers carry coordinates after a '|||' separator;
        # split them off so only the text part is used as the LM target.
        if 'detection' in samples['task_type'] or 'keypoint' in samples['task_type']:
            sample_ans = [ans.split('|||')[0] for ans in samples['answer']]
            sample_num = [ans.split('|||')[1] for ans in samples['answer']]
        else:
            sample_ans = samples['answer']
        text = ['###Assistant: ' + str(t) + self.end_sym for t in sample_ans]

        prompt_tokens = self.llama_tokenizer(
            prompt,
            return_tensors="pt",
            padding='longest',
            truncation=True,
            max_length=256,
            add_special_tokens=False
        ).to(image.device)

        img_embeds, atts_img = self.prompt_concat(img_embeds, atts_img, prompt_tokens)

        to_regress_tokens = self.llama_tokenizer(
            text,
            return_tensors="pt",
            padding="longest",
            truncation=True,
            max_length=self.max_txt_len,
            add_special_tokens=False
        ).to(image.device)

        # Ignore pad positions in the LM loss.
        targets = to_regress_tokens.input_ids.masked_fill(
            to_regress_tokens.input_ids == self.llama_tokenizer.pad_token_id, -100
        )

        # Mask out the prefix (bos + image/prompt embeddings) from the loss.
        empty_targets = (
            torch.ones([atts_img.shape[0], atts_img.shape[1] + 1],
                       dtype=torch.long).to(image.device).fill_(-100)
        )
        targets = torch.cat([empty_targets, targets], dim=1)

        batch_size = img_embeds.shape[0]
        bos = torch.ones([batch_size, 1],
                         dtype=to_regress_tokens.input_ids.dtype,
                         device=to_regress_tokens.input_ids.device) * self.llama_tokenizer.bos_token_id
        bos_embeds = self.embed_tokens(bos)
        atts_bos = atts_img[:, :1]

        to_regress_embeds = self.embed_tokens(to_regress_tokens.input_ids)
        inputs_embeds = torch.cat([bos_embeds, img_embeds, to_regress_embeds], dim=1)
        attention_mask = torch.cat([atts_bos, atts_img, to_regress_tokens.attention_mask], dim=1)
        with self.maybe_autocast():
            outputs = self.llama_model(
                inputs_embeds=inputs_embeds,
                attention_mask=attention_mask,
                return_dict=True,
                labels=targets,
                output_hidden_states=True,
            )
        loss = outputs.loss

        if 'detection' in samples['task_type']:
            # Regress boxes from the hidden states at <DET> token positions.
            with self.maybe_autocast():
                hidden_states = outputs.hidden_states[-1]
                token_mask = targets == self.det_token_idx
                target_states = hidden_states[token_mask]
                det_states = self.text_det(target_states)
            labels = trans_det(sample_num, det_states.shape[0])
            labels = labels.to(targets.device)
            det_loss = self.det_loss(det_states, labels)
            loss += det_loss * 1e2

        if 'keypoint' in samples['task_type']:
            # Regress 2-D coordinates from the hidden states at keypoint token
            # positions (see the assumed head in __init__).
            with self.maybe_autocast():
                hidden_states = outputs.hidden_states[-1]
                token_mask = targets == self.point_token_idx_2d
                target_states = hidden_states[token_mask]
                point_states = self.text_point(target_states)
            labels = trans_keypoint(sample_num, point_states.shape[0])
            labels = labels.to(targets.device)
            keypoint_loss = self.keypoint_loss(point_states, labels)
            loss += keypoint_loss * 1e2

        if 'segmentation' in samples['task_type']:
            if 'ct' in samples['modal']:
                # 3-D: condition the UNet3d bottleneck on the <3DSEG> hidden
                # state, then decode a volumetric mask.
                masks = samples['answer_img']
                with self.maybe_autocast():
                    img_embeds_list = self.visual_encoder_3d(image, encoder_only=True)
                    img_embeds_list = [feat.to(targets.device) for feat in img_embeds_list]
                    hidden_states = outputs.hidden_states[-1]
                    token_mask = targets == self.seg_token_idx_3d
                    target_states = hidden_states[token_mask]
                    seg_states = self.text2seg_3d(target_states)
                    last_feats = img_embeds_list[-1]
                    last_feats = last_feats + seg_states.unsqueeze(-1).unsqueeze(-1).unsqueeze(-1)
                    last_feats = self.text2seg_3d_gn(last_feats)
                    img_embeds_list[-1] = last_feats
                    seg_preds = self.visual_encoder_3d(encoder_only=False, x_=img_embeds_list)
                    loss += self.seg_loss(seg_preds, masks.float())
            else:
                # 2-D: condition the smp U-Net bottleneck on the <2DSEG>
                # hidden state, then decode a mask.
                masks = samples['answer_img']
                with self.maybe_autocast():
                    feats = self.model_seg_2d.encoder(image[:, 0])
                    last_feats = feats[-1]
                    hidden_states = outputs.hidden_states[-1]
                    token_mask = targets == self.seg_token_idx_2d
                    target_states = hidden_states[token_mask]
                    seg_states = self.text2seg_2d(target_states)
                    last_feats = last_feats + seg_states.unsqueeze(-1).unsqueeze(-1)
                    last_feats = self.text2seg_2d_gn(last_feats)
                    feats[-1] = last_feats
                    seg_feats = self.model_seg_2d.decoder(*feats)
                    seg_preds = self.model_seg_2d.segmentation_head(seg_feats)
                    loss += self.seg_loss(seg_preds, masks.float())

        return {"loss": loss, "modal": samples['modal'][0], "task_type": samples['task_type'][0]}

    @classmethod
    def from_config(cls, cfg, finetune=False):
        model = cls(cfg)

        ckpt_path = cfg.get("ckpt", "")
        if ckpt_path:
            print("Load Checkpoint: {}".format(ckpt_path))
            ckpt = torch.load(ckpt_path, map_location="cpu")
            if finetune:
                # Keep only checkpoint weights whose shapes match the current
                # model; fall back to freshly initialized weights otherwise.
                current_model_dict = model.state_dict()
                weights = ckpt['model']
                new_state_dict = {}
                for k in list(current_model_dict.keys()):
                    if k in list(weights.keys()):
                        if weights[k].size() == current_model_dict[k].size():
                            new_state_dict[k] = weights[k]
                        else:
                            new_state_dict[k] = current_model_dict[k]
                    else:
                        print("Missing in checkpoint, keeping init: {}".format(k))
                        new_state_dict[k] = current_model_dict[k]
                msg = model.load_state_dict(new_state_dict, strict=False)
            else:
                msg = model.load_state_dict(ckpt['model'], strict=False)

        return model
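

# Illustrative construction (assumed config values; the keys mirror what
# __init__ and from_config read above, and the llama path is a placeholder):
#   cfg = {
#       "llama_model": "/path/to/llama-7b",   # hypothetical local path
#       "max_txt_len": 256,
#       "end_sym": "###",
#       "ckpt": "",                           # optional checkpoint path
#   }
#   model = MedOmni.from_config(cfg)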