Spaces:
Sleeping
Sleeping
import torch | |
from torch import nn | |
from torch.nn import functional as F | |
import numpy as np | |
from typing import Any, Dict, List, Tuple | |
from .segment_anything.utils.transforms import ResizeLongestSide | |
from .segment_anything.build_sam import sam_model_registry | |
from .decoder import build_decoder | |
from . import constants | |
from einops import rearrange | |
from .segment_anything.modeling.prompt_engineering import prompt_engineering, get_prompt_templates | |
from .clip import load as load_clip | |
import clip | |
class RegionSpot(nn.Module): | |
TEXT_FEATS_MAP = { | |
'coco': 'text_feats_coco', | |
'objects365': 'text_feats_objects365', | |
'v3det': 'text_feats_v3det', | |
'lvis': 'text_feats_lvis', | |
'openimages': 'text_feats_openimages' | |
} | |
def __init__(self, sam_checkpoint='./sam_checkpoints/sam_vit_b_01ec64.pth', | |
clip_type='CLIP_400M_Large', is_training=True, custom_vocabulary=None, image_size=224): | |
super().__init__() | |
self.sam = sam_model_registry['vit_b'](checkpoint=sam_checkpoint) | |
self._freeze_module(self.sam) | |
self.clip_model, self.text_dim, self.clip_dim = self._load_clip_model(clip_type, image_size) | |
self.clip_model.eval() | |
self._freeze_module(self.clip_model) | |
self.logit_scale = self.clip_model.logit_scale.exp() | |
self.to_clip = nn.Linear(256, self.clip_dim) | |
self.ln_clip = nn.LayerNorm(self.clip_dim, elementwise_affine=False) | |
self.projector = nn.Linear(self.clip_dim, self.text_dim) | |
self.decoder = build_decoder(d_model=self.clip_dim) | |
# Dynamically set attributes based on the datasets in the map | |
if is_training: | |
datasets_to_load = ['objects365', 'v3det', 'openimages'] | |
for dataset in datasets_to_load: | |
setattr(self, self.TEXT_FEATS_MAP[dataset], self.get_text_feat(dataset)) | |
else: | |
dataset_name = 'custom' if custom_vocabulary else 'lvis' | |
# custom_vocabulary += ["background"] | |
self.text_feats = self.get_text_feat(dataset_name, custom_class=custom_vocabulary) | |
def _add_text_vocab(custom_vocabulary): | |
dataset_name = 'custom' | |
setattr(self, self.TEXT_FEATS_MAP['openimages'],custom_class = custom_vocabulary) | |
def _freeze_module(module): | |
for param in module.parameters(): | |
param.requires_grad = False | |
def _load_clip_model(self, clip_type, image_size): | |
clip_model_map = { | |
'CLIP_400M': ("ViT-B/16", 512, 768), | |
'CLIP_400M_Large': ("ViT-L/14", 768, 1024), | |
'CLIP_400M_Large_336': ("ViT-L/14@336px", 768, 1024) | |
} | |
model_type, text_dim, clip_dim = clip_model_map[clip_type] | |
clip_model, _ = load_clip(model_type, image_size=image_size) | |
return clip_model, text_dim, clip_dim | |
def get_text_feat(self, dataset_name: str, custom_class=None) -> torch.Tensor: | |
dataset_map = { | |
'coco': constants.COCO_INSTANCE_CLASSES, | |
'objects365': constants.OBJECTS365V1, | |
'v3det': constants.V3DET, | |
'lvis': constants.LVIS_CATEGORIES, | |
'openimages': constants.OPENIMAGE, | |
'custom': custom_class | |
} | |
# Error handling for custom dataset without custom classes provided | |
if dataset_name == 'custom' and custom_class is None: | |
raise ValueError("For custom datasets, you must provide the 'custom_class' parameter.") | |
class_names = dataset_map.get(dataset_name, []) | |
def clean_class_name(clss: str) -> str: | |
"""Clean class names for prompt templates.""" | |
return clss.replace('-other', '').replace('-merged', '').replace('-stuff', '') | |
def extract_mean_emb(text: str) -> torch.Tensor: | |
"""Extract mean embeddings from text using the clip model.""" | |
tokens = clip.tokenize(text).cuda() | |
if len(tokens) > 10000: | |
split_idx = len(tokens) // 2 | |
text_features = torch.cat([ | |
self.clip_model.encode_text(tokens[:split_idx]), | |
self.clip_model.encode_text(tokens[split_idx:])], | |
dim=0) | |
else: | |
text_features = self.clip_model.encode_text(tokens) | |
return torch.mean(text_features, 0, keepdims=True)[0] | |
templates = get_prompt_templates() | |
clss_embeddings = [] | |
for clss in class_names: | |
txts = [template.format(clss.replace('-other','').replace('-merged','').replace('-stuff','')) for template in templates] | |
# txts = [clss] | |
clss_embeddings.append(extract_mean_emb(txts)) | |
text_emb = torch.stack(clss_embeddings, dim=0) | |
text_emb /= text_emb.norm(dim=-1, keepdim=True) | |
return text_emb | |
def sigmoid_focal_loss(self, inputs, targets, num_boxes, alpha: float = 0.25, gamma: float = 2, reduction=True): | |
"""Compute the sigmoid focal loss.""" | |
prob = inputs.sigmoid() | |
ce_loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction="none") | |
p_t = prob * targets + (1 - prob) * (1 - targets) | |
loss = ce_loss * ((1 - p_t) ** gamma) | |
if alpha >= 0: | |
loss = (alpha * targets + (1 - alpha) * (1 - targets)) * loss | |
return loss.mean(1).sum() / num_boxes | |
def get_logits(self, region_features, text_features, logit_scale): | |
"""Compute logits for region and text features.""" | |
region_features = region_features / (region_features.norm(dim=-1, keepdim=True) + 1e-7) | |
logits_per_image = logit_scale * region_features @ text_features.unsqueeze(0).transpose(1, 2) | |
logits_per_text = logit_scale * text_features.unsqueeze(0) @ region_features.transpose(1, 2) | |
return logits_per_image, logits_per_text | |
def ce_loss(self, region_features, label, logit_scale, dataset_name, focal_alpha=0.25): | |
"""Compute the cross-entropy loss.""" | |
b, n_box, d = region_features.shape | |
text_feats = getattr(self, self.TEXT_FEATS_MAP[dataset_name]) | |
logits_per_image, _ = self.get_logits(region_features, text_feats, logit_scale) | |
target_classes_onehot = torch.zeros(logits_per_image.shape, dtype=logits_per_image.dtype, device=logits_per_image.device) | |
label = label.long() | |
target_classes_onehot.scatter_(2, label.unsqueeze(-1), 1) | |
loss_ce = self.sigmoid_focal_loss(logits_per_image, target_classes_onehot, n_box, alpha=focal_alpha, gamma=2) * logits_per_image.shape[1] | |
return loss_ce | |
def forward_train(self, batched_input: List[Dict[str, Any]]) -> List[Dict[str, torch.Tensor]]: | |
"""Training forward pass.""" | |
resized_image = torch.stack([x["resized_image"] for x in batched_input], dim=0) | |
with torch.no_grad(): | |
clip_feat = self.clip_model.encode_image_featuremap(resized_image).detach() | |
masks_token = torch.stack([x["mask_tokens"] for x in batched_input], dim=0).squeeze(2) | |
dataset_name = batched_input[0]["dataset_name"] | |
masks_token = self.to_clip(masks_token) | |
semantic_token = self.projector(self.decoder(masks_token, clip_feat)) | |
label = torch.stack([x["label"] for x in batched_input], dim=0) | |
return self.ce_loss(semantic_token, label, self.logit_scale, dataset_name) | |
def forward_eval(self, batched_input: List[Dict[str, Any]], multimask_output=False) -> List[Dict[str, torch.Tensor]]: | |
"""Inference forward pass.""" | |
sam_output = self.sam(batched_input, multimask_output=multimask_output) | |
masks_token = torch.stack([x["masks_token"] for x in sam_output], dim=0).squeeze(2) | |
pred_mask = torch.stack([x["masks"] for x in sam_output], dim=0) | |
resized_image = torch.stack([x["resized_image"] for x in batched_input], dim=0) | |
with torch.no_grad(): | |
self.decoder.eval() | |
clip_feat = self.clip_model.encode_image_featuremap(resized_image).detach() | |
masks_token = self.to_clip(masks_token) | |
semantic_token = self.projector(self.decoder(masks_token, clip_feat)) | |
logits_per_image, _ = self.get_logits(semantic_token, self.text_feats, self.logit_scale) | |
return logits_per_image, pred_mask | |
def forward_inference(self, clip_feat, masks_token, resized_image,) -> List[Dict[str, torch.Tensor]]: | |
"""Inference forward pass.""" | |
# if masks_token.shape | |
masks_token = masks_token[None,:] | |
if masks_token.shape[2] == 1: | |
masks_token = masks_token.squeeze(2) | |
else: | |
masks_token = masks_token.permute(2, 1, 0, 3).squeeze(2) | |
clip_feat = clip_feat.repeat(3, 1, 1) | |
with torch.no_grad(): | |
self.decoder.eval() | |
masks_token = self.to_clip(masks_token) | |
semantic_token = self.projector(self.decoder(masks_token, clip_feat)) | |
logits_per_image, _ = self.get_logits(semantic_token, self.text_feats, self.logit_scale) | |
if logits_per_image.shape[0] == 3: | |
logits_per_image = logits_per_image.permute(1, 0, 2) | |
return logits_per_image | |
def build_regionspot_model(clip_type='CLIP_400M_Large', is_training=True, pretrain_ckpt=None, image_size=224, custom_vocabulary=None): | |
model = RegionSpot(clip_type=clip_type, is_training=is_training, image_size=image_size, custom_vocabulary=custom_vocabulary) | |
if pretrain_ckpt: | |
checkpoint = torch.load(pretrain_ckpt, map_location='cpu')['model'] | |
# Remove the 'model.' prefix | |
new_checkpoint = {} | |
for key in checkpoint.keys(): | |
if key.startswith('model.'): | |
new_key = key[len('model.'):] | |
new_checkpoint[new_key] = checkpoint[key] | |
else: | |
new_checkpoint[key] = checkpoint[key] | |
# Load the modified state dict | |
msg = model.load_state_dict(new_checkpoint, strict=False) | |
else: | |
msg= 'training stage' | |
return model, msg | |