import torch
from torch import nn
from torch.nn import functional as F
import numpy as np
from typing import Any, Dict, List, Tuple
from .segment_anything.utils.transforms import ResizeLongestSide
from .segment_anything.build_sam import sam_model_registry
from .decoder import build_decoder
from . import constants
from einops import rearrange
from .segment_anything.modeling.prompt_engineering import prompt_engineering, get_prompt_templates
from .clip import load as load_clip
import clip
class RegionSpot(nn.Module):
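    """Open-vocabulary region recognition model that couples a frozen SAM (mask
    prompts and mask tokens) with a frozen CLIP image/text encoder: SAM mask tokens
    are projected into the CLIP feature space, fused with the CLIP image feature map
    by a lightweight decoder, and scored against prompt-averaged text embeddings."""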
TEXT_FEATS_MAP = {
'coco': 'text_feats_coco',
'objects365': 'text_feats_objects365',
'v3det': 'text_feats_v3det',
'lvis': 'text_feats_lvis',
'openimages': 'text_feats_openimages'
}
def __init__(self, sam_checkpoint='./sam_checkpoints/sam_vit_b_01ec64.pth',
clip_type='CLIP_400M_Large', is_training=True, custom_vocabulary=None, image_size=224):
super().__init__()
self.sam = sam_model_registry['vit_b'](checkpoint=sam_checkpoint)
self._freeze_module(self.sam)
self.clip_model, self.text_dim, self.clip_dim = self._load_clip_model(clip_type, image_size)
self.clip_model.eval()
self._freeze_module(self.clip_model)
self.logit_scale = self.clip_model.logit_scale.exp()
self.to_clip = nn.Linear(256, self.clip_dim)
self.ln_clip = nn.LayerNorm(self.clip_dim, elementwise_affine=False)
self.projector = nn.Linear(self.clip_dim, self.text_dim)
self.decoder = build_decoder(d_model=self.clip_dim)
# Dynamically set attributes based on the datasets in the map
if is_training:
datasets_to_load = ['objects365', 'v3det', 'openimages']
for dataset in datasets_to_load:
setattr(self, self.TEXT_FEATS_MAP[dataset], self.get_text_feat(dataset))
else:
dataset_name = 'custom' if custom_vocabulary else 'lvis'
# custom_vocabulary += ["background"]
self.text_feats = self.get_text_feat(dataset_name, custom_class=custom_vocabulary)
    def _add_text_vocab(self, custom_vocabulary):
        """Replace the current inference vocabulary with custom class names."""
        self.text_feats = self.get_text_feat('custom', custom_class=custom_vocabulary)
@staticmethod
def _freeze_module(module):
for param in module.parameters():
param.requires_grad = False
def _load_clip_model(self, clip_type, image_size):
clip_model_map = {
'CLIP_400M': ("ViT-B/16", 512, 768),
'CLIP_400M_Large': ("ViT-L/14", 768, 1024),
'CLIP_400M_Large_336': ("ViT-L/14@336px", 768, 1024)
}
model_type, text_dim, clip_dim = clip_model_map[clip_type]
clip_model, _ = load_clip(model_type, image_size=image_size)
return clip_model, text_dim, clip_dim
@torch.no_grad()
def get_text_feat(self, dataset_name: str, custom_class=None) -> torch.Tensor:
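        """Return L2-normalized, prompt-averaged CLIP text embeddings for the classes of `dataset_name`."""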
dataset_map = {
'coco': constants.COCO_INSTANCE_CLASSES,
'objects365': constants.OBJECTS365V1,
'v3det': constants.V3DET,
'lvis': constants.LVIS_CATEGORIES,
'openimages': constants.OPENIMAGE,
'custom': custom_class
}
# Error handling for custom dataset without custom classes provided
if dataset_name == 'custom' and custom_class is None:
raise ValueError("For custom datasets, you must provide the 'custom_class' parameter.")
class_names = dataset_map.get(dataset_name, [])
def clean_class_name(clss: str) -> str:
"""Clean class names for prompt templates."""
return clss.replace('-other', '').replace('-merged', '').replace('-stuff', '')
        def extract_mean_emb(texts: List[str]) -> torch.Tensor:
            """Encode a list of prompts with CLIP and return their mean embedding."""
            tokens = clip.tokenize(texts).cuda()
            # Split very long prompt batches in two to bound text-encoder memory.
            if len(tokens) > 10000:
                split_idx = len(tokens) // 2
                text_features = torch.cat([
                    self.clip_model.encode_text(tokens[:split_idx]),
                    self.clip_model.encode_text(tokens[split_idx:])],
                    dim=0)
            else:
                text_features = self.clip_model.encode_text(tokens)
            return text_features.mean(dim=0)
templates = get_prompt_templates()
clss_embeddings = []
        for clss in class_names:
            # Fill every prompt template with the cleaned class name and average the embeddings.
            txts = [template.format(clean_class_name(clss)) for template in templates]
            clss_embeddings.append(extract_mean_emb(txts))
text_emb = torch.stack(clss_embeddings, dim=0)
text_emb /= text_emb.norm(dim=-1, keepdim=True)
return text_emb
def sigmoid_focal_loss(self, inputs, targets, num_boxes, alpha: float = 0.25, gamma: float = 2, reduction=True):
"""Compute the sigmoid focal loss."""
prob = inputs.sigmoid()
ce_loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction="none")
p_t = prob * targets + (1 - prob) * (1 - targets)
loss = ce_loss * ((1 - p_t) ** gamma)
if alpha >= 0:
loss = (alpha * targets + (1 - alpha) * (1 - targets)) * loss
return loss.mean(1).sum() / num_boxes
def get_logits(self, region_features, text_features, logit_scale):
"""Compute logits for region and text features."""
region_features = region_features / (region_features.norm(dim=-1, keepdim=True) + 1e-7)
logits_per_image = logit_scale * region_features @ text_features.unsqueeze(0).transpose(1, 2)
logits_per_text = logit_scale * text_features.unsqueeze(0) @ region_features.transpose(1, 2)
return logits_per_image, logits_per_text
def ce_loss(self, region_features, label, logit_scale, dataset_name, focal_alpha=0.25):
"""Compute the cross-entropy loss."""
b, n_box, d = region_features.shape
text_feats = getattr(self, self.TEXT_FEATS_MAP[dataset_name])
logits_per_image, _ = self.get_logits(region_features, text_feats, logit_scale)
target_classes_onehot = torch.zeros(logits_per_image.shape, dtype=logits_per_image.dtype, device=logits_per_image.device)
label = label.long()
target_classes_onehot.scatter_(2, label.unsqueeze(-1), 1)
loss_ce = self.sigmoid_focal_loss(logits_per_image, target_classes_onehot, n_box, alpha=focal_alpha, gamma=2) * logits_per_image.shape[1]
return loss_ce
def forward_train(self, batched_input: List[Dict[str, Any]]) -> List[Dict[str, torch.Tensor]]:
"""Training forward pass."""
resized_image = torch.stack([x["resized_image"] for x in batched_input], dim=0)
with torch.no_grad():
clip_feat = self.clip_model.encode_image_featuremap(resized_image).detach()
masks_token = torch.stack([x["mask_tokens"] for x in batched_input], dim=0).squeeze(2)
dataset_name = batched_input[0]["dataset_name"]
masks_token = self.to_clip(masks_token)
semantic_token = self.projector(self.decoder(masks_token, clip_feat))
label = torch.stack([x["label"] for x in batched_input], dim=0)
return self.ce_loss(semantic_token, label, self.logit_scale, dataset_name)
def forward_eval(self, batched_input: List[Dict[str, Any]], multimask_output=False) -> List[Dict[str, torch.Tensor]]:
"""Inference forward pass."""
sam_output = self.sam(batched_input, multimask_output=multimask_output)
masks_token = torch.stack([x["masks_token"] for x in sam_output], dim=0).squeeze(2)
pred_mask = torch.stack([x["masks"] for x in sam_output], dim=0)
resized_image = torch.stack([x["resized_image"] for x in batched_input], dim=0)
with torch.no_grad():
self.decoder.eval()
clip_feat = self.clip_model.encode_image_featuremap(resized_image).detach()
masks_token = self.to_clip(masks_token)
semantic_token = self.projector(self.decoder(masks_token, clip_feat))
logits_per_image, _ = self.get_logits(semantic_token, self.text_feats, self.logit_scale)
return logits_per_image, pred_mask
    def forward_inference(self, clip_feat, masks_token, resized_image) -> List[Dict[str, torch.Tensor]]:
        """Inference forward pass from precomputed CLIP features and SAM mask tokens."""
        masks_token = masks_token[None, :]
        if masks_token.shape[2] == 1:
            # Single-mask output: drop the mask dimension.
            masks_token = masks_token.squeeze(2)
        else:
            # Multimask output (3 masks per prompt): fold the mask dimension into the batch
            # and repeat the CLIP feature map to match.
            masks_token = masks_token.permute(2, 1, 0, 3).squeeze(2)
            clip_feat = clip_feat.repeat(3, 1, 1)
with torch.no_grad():
self.decoder.eval()
masks_token = self.to_clip(masks_token)
semantic_token = self.projector(self.decoder(masks_token, clip_feat))
logits_per_image, _ = self.get_logits(semantic_token, self.text_feats, self.logit_scale)
if logits_per_image.shape[0] == 3:
logits_per_image = logits_per_image.permute(1, 0, 2)
return logits_per_image
def build_regionspot_model(clip_type='CLIP_400M_Large', is_training=True, pretrain_ckpt=None, image_size=224, custom_vocabulary=None):
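    """Construct a RegionSpot model and optionally load a pretrained checkpoint,
    stripping the 'model.' prefix that training wrappers add to state-dict keys."""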
model = RegionSpot(clip_type=clip_type, is_training=is_training, image_size=image_size, custom_vocabulary=custom_vocabulary)
if pretrain_ckpt:
checkpoint = torch.load(pretrain_ckpt, map_location='cpu')['model']
# Remove the 'model.' prefix
new_checkpoint = {}
for key in checkpoint.keys():
if key.startswith('model.'):
new_key = key[len('model.'):]
new_checkpoint[new_key] = checkpoint[key]
else:
new_checkpoint[key] = checkpoint[key]
# Load the modified state dict
msg = model.load_state_dict(new_checkpoint, strict=False)
    else:
        msg = 'training stage'
return model, msg
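

if __name__ == "__main__":
    # Minimal usage sketch: build an inference-mode RegionSpot with a custom vocabulary.
    # Assumptions: the pretrain checkpoint path below is hypothetical, the default SAM
    # checkpoint ('./sam_checkpoints/sam_vit_b_01ec64.pth') exists, and a CUDA device is
    # available (get_text_feat tokenizes and encodes on .cuda()).
    model, msg = build_regionspot_model(
        clip_type='CLIP_400M_Large',
        is_training=False,
        pretrain_ckpt='./regionspot_checkpoints/regionspot_vit_b.pth',  # hypothetical path
        custom_vocabulary=['cat', 'dog', 'person'],
    )
    model = model.cuda().eval()
    print(msg)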