# ------------------------------------------------------------------------------
# CLIP-DINOiser
# author: Monika Wysoczanska, Warsaw University of Technology
# ------------------------------------------------------------------------------
# Modified from OpenMMLab https://github.com/chongzhou96/MaskCLIP
# ------------------------------------------------------------------------------
import torch
import torch.nn as nn
import torch.nn.functional as F
from mmseg.ops import resize
from typing import List
from torch import Tensor
from mmcv.utils import print_log
from mmseg.utils import get_root_logger
from open_clip import get_tokenizer, create_model_from_pretrained
from models.builder import MODELS
from .vit import VisionTransformer
import torchvision.transforms as T
from .utils.embed import AdaptivePadding
from .utils.prompt_templates import imagenet_templates
OPENAI_NORMALIZE = T.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711))
def make_vision_transformer(backbone_cfg):
model = VisionTransformer(**backbone_cfg)
model.init_weights()
return model
@MODELS.register_module()
class MaskClip(nn.Module):
def __init__(
self,
backbone,
decode_head,
clip_model,
class_names
):
super(MaskClip, self).__init__()
self.decode_head = eval(decode_head.get('type'))(clip_model, class_names, **decode_head)
self.backbone = make_vision_transformer(backbone)
self.clip_T = OPENAI_NORMALIZE
self.to_PIL = T.ToPILImage()
self.patch_size = backbone.get('patch_size')
self.padding = AdaptivePadding(self.patch_size, self.patch_size)
def extract_feat(self, inputs: Tensor) -> List[Tensor]:
"""Extract features from images."""
x = self.backbone(inputs)
return x
def forward(self, inputs: Tensor, return_feat=False) -> Tensor:
"""Encode images with backbone and decode into a semantic segmentation
map of the same size as input."""
inputs = self.clip_T(inputs)
x = self.extract_feat(inputs)
        if return_feat:
            seg_logits, feats, k = self.decode_head(x, return_feat)
            return seg_logits, feats, k
        seg_logits = self.decode_head(x, return_feat)
return seg_logits
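

# ------------------------------------------------------------------------------
# Illustrative usage sketch (not part of the original model code). It assumes a
# fully constructed `MaskClip` instance and a float image batch of shape
# (B, 3, H, W) with values in [0, 1]; the CLIP normalization is applied inside
# `forward`. The helper name `_example_maskclip_inference` is hypothetical and
# only documents the expected call pattern.
# ------------------------------------------------------------------------------
def _example_maskclip_inference(model: MaskClip, images: Tensor) -> Tensor:
    with torch.no_grad():
        # per-class probability maps at roughly (H // patch_size, W // patch_size)
        seg_probs = model(images)
    return seg_probs
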
class MaskClipHead(nn.Module):
    def __init__(self, clip_model, class_names, visual_projs_path=None, in_index=-1,
                 in_channels=3, norm_cfg=None, channels=0, text_channels=512,
                 attn_pooling=False, align_corners=False, model_prefix='hf-hub:laion',
                 use_templates=False, **kwargs):
super(MaskClipHead, self).__init__()
self.text_channels = text_channels
self.visual_projs_path = visual_projs_path
self.clip_model = clip_model
self.class_names = class_names
self.in_channels = in_channels
self.in_index = in_index # from base decode head default
self._init_inputs(in_channels, in_index, None)
self.channels = channels
self.norm_cfg = norm_cfg
self.align_corners = align_corners
self.use_templates = use_templates
self.proj = nn.Conv2d(self.in_channels, text_channels, 1, bias=False)
self.load_visual_projs()
self.attn_pooling = attn_pooling
self.tokenizer = get_tokenizer(f'{model_prefix}/{clip_model}')
self.hf_modelname = f'{model_prefix}/{clip_model}'
model, _ = create_model_from_pretrained(f'{model_prefix}/{clip_model}')
model.eval()
self.register_buffer("class_embeddings", self._get_class_embeddings(model, class_names))
@torch.no_grad()
def update_vocab(self, class_names):
model, _ = create_model_from_pretrained(self.hf_modelname)
model.eval()
        self.class_embeddings = self._get_class_embeddings(model, class_names).to(self.class_embeddings.device)
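
    # Example (illustrative, not in the original code): the vocabulary can be swapped
    # at any point without touching the visual weights, e.g.
    #   head.update_vocab(['sky', 'building', 'person'])
    # which re-encodes the new class names with the CLIP text encoder and overwrites
    # the `class_embeddings` buffer used by `cls_seg`.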
@torch.no_grad()
def _embed_label(self, text_model: torch.nn.Module, label: str) -> torch.Tensor:
"""
Encode label name into a single vector
"""
if self.use_templates:
templates = imagenet_templates
else:
            templates = ['a photo of an {}' if label[0].lower() in 'aeiou' else 'a photo of a {}']
all_prompts = [self.tokenizer(template.format(label)) for template in templates]
out = text_model.encode_text(torch.cat(all_prompts))
out /= out.norm(dim=-1, keepdim=True)
out = out.mean(dim=0)
return out
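
    # Illustrative example (assumed behaviour, matching the code above): with
    # use_templates=False and label "orange", the single prompt
    # "a photo of an orange" is tokenized, encoded by the CLIP text encoder,
    # L2-normalized and averaged over the (single) prompt. With use_templates=True,
    # the label is formatted into every ImageNet prompt template and the resulting
    # embeddings are averaged into one vector per class.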
def _get_class_embeddings(self, text_model: torch.nn.Module, class_names: List[str]):
aug_embeddings = torch.stack([self._embed_label(text_model, label) for label in class_names])
# normalize vector
aug_embeddings = aug_embeddings / aug_embeddings.norm(dim=-1, keepdim=True)
return aug_embeddings.squeeze(1)
def load_visual_projs(self):
device = "cuda" if torch.cuda.is_available() else "cpu"
loaded = torch.load(self.visual_projs_path, map_location=device)
attrs = ['proj']
for attr in attrs:
current_attr = getattr(self, attr)
state_dict = loaded[attr]
for key in state_dict:
if 'weight' in key:
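                    # reshape the linear projection weight (C_out, C_in) into the
                    # 1x1 convolution weight (C_out, C_in, 1, 1) expected by self.proj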
state_dict[key] = state_dict[key][:, :, None, None]
current_attr.load_state_dict(state_dict)
print_log(f'Loaded proj weights from {self.visual_projs_path}', logger=get_root_logger())
def forward(self, inputs, return_feat=False):
x = self._transform_inputs(inputs)
q, k, v, cls_token = None, None, None, None
if isinstance(x, list) and len(x) == 4:
x, q, k, v = x
if isinstance(x, list) and len(x) == 2:
x, cls_token = x
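        # MaskCLIP-style decoding: when the backbone also returns the last attention
        # layer's query/key/value maps, project the value features with the CLIP
        # visual projection instead of the plain patch tokens.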
if v is not None:
feat = self.proj(v)
else:
feat = self.proj(x)
output = self.cls_seg(feat)
if return_feat:
return output, feat, k
return output
def _init_inputs(self, in_channels, in_index, input_transform):
"""Check and initialize input transforms.
The in_channels, in_index and input_transform must match.
Specifically, when input_transform is None, only single feature map
will be selected. So in_channels and in_index must be of type int.
When input_transform
Args:
in_channels (int|Sequence[int]): Input channels.
in_index (int|Sequence[int]): Input feature index.
input_transform (str|None): Transformation type of input features.
Options: 'resize_concat', 'multiple_select', None.
'resize_concat': Multiple feature maps will be resize to the
same size as first one and than concat together.
Usually used in FCN head of HRNet.
'multiple_select': Multiple feature maps will be bundle into
a list and passed into decode head.
None: Only one select feature map is allowed.
"""
if input_transform is not None:
assert input_transform in ['resize_concat', 'multiple_select']
self.input_transform = input_transform
self.in_index = in_index
if input_transform is not None:
assert isinstance(in_channels, (list, tuple))
assert isinstance(in_index, (list, tuple))
assert len(in_channels) == len(in_index)
if input_transform == 'resize_concat':
self.in_channels = sum(in_channels)
else:
self.in_channels = in_channels
else:
assert isinstance(in_channels, int)
assert isinstance(in_index, int)
self.in_channels = in_channels
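
    # Illustrative examples of the three modes (assumed feature shapes):
    #   input_transform=None,              in_index=-1:     pick inputs[-1], a single (B, C, H, W) map.
    #   input_transform='resize_concat',   in_index=[0, 1]: resize inputs[1] to inputs[0]'s spatial size,
    #                                                       then concatenate along channels (C0 + C1).
    #   input_transform='multiple_select', in_index=[0, 1]: return [inputs[0], inputs[1]] as a list.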
def cls_seg(self, feat):
feat = feat / feat.norm(dim=1, keepdim=True)
output = F.conv2d(feat, self.class_embeddings[:, :, None, None])
output = F.softmax(output * 100, dim=1) # softmax of similarities with temp scaling
return output
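
    # The 1x1 convolution above is equivalent to a per-pixel dot product with the
    # class embeddings; since both `feat` and `class_embeddings` are L2-normalized,
    # the logits are cosine similarities. An equivalent (illustrative) formulation:
    #   sims = torch.einsum('bchw,kc->bkhw', feat, self.class_embeddings)
    #   output = F.softmax(sims * 100, dim=1)  # temperature 1/100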
def _transform_inputs(self, inputs):
"""Transform inputs for decoder.
Args:
inputs (list[Tensor]): List of multi-level img features.
Returns:
Tensor: The transformed inputs
"""
if self.input_transform == 'resize_concat':
inputs = [inputs[i] for i in self.in_index]
upsampled_inputs = [
resize(
input=x,
size=inputs[0].shape[2:],
mode='bilinear',
align_corners=self.align_corners) for x in inputs
]
inputs = torch.cat(upsampled_inputs, dim=1)
elif self.input_transform == 'multiple_select':
inputs = [inputs[i] for i in self.in_index]
else:
inputs = inputs[self.in_index]
return inputs