# ------------------------------------------------------------------------------
# CLIP-DINOiser
# author: Monika Wysoczanska, Warsaw University of Technology
# ------------------------------------------------------------------------------
# Modified from OpenMMLab https://github.com/chongzhou96/MaskCLIP
# ------------------------------------------------------------------------------
import torch
import torch.nn as nn
import torch.nn.functional as F
from mmseg.ops import resize
from typing import Any, List
from torch import Tensor
from mmcv.utils import print_log
from mmseg.utils import get_root_logger
from open_clip import get_tokenizer, create_model_from_pretrained
from models.builder import MODELS
from .vit import VisionTransformer
import torchvision.transforms as T
from .utils.embed import AdaptivePadding
from .utils.prompt_templates import imagenet_templates

OPENAI_NORMALIZE = T.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711))

def make_vision_transformer(backbone_cfg):
    model = VisionTransformer(**backbone_cfg)
    model.init_weights()
    return model

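# A minimal sketch of how the factory above might be used. The keys accepted by
# `backbone_cfg` are defined by `models.vit.VisionTransformer` and by the project
# config files, so the values shown here are illustrative assumptions only:
#
#     backbone_cfg = dict(patch_size=16, ...)            # hypothetical keys
#     backbone = make_vision_transformer(backbone_cfg)   # builds the ViT and initializes weights
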
class MaskClip(nn.Module):
    def __init__(
            self,
            backbone,
            decode_head,
            clip_model,
            class_names
    ):
        super(MaskClip, self).__init__()

        self.decode_head = eval(decode_head.get('type'))(clip_model, class_names, **decode_head)
        self.backbone = make_vision_transformer(backbone)
        self.clip_T = OPENAI_NORMALIZE
        self.to_PIL = T.ToPILImage()
        self.patch_size = backbone.get('patch_size')
        self.padding = AdaptivePadding(self.patch_size, self.patch_size)

    def extract_feat(self, inputs: Tensor) -> List[Tensor]:
        """Extract features from images."""
        x = self.backbone(inputs)
        return x

    def forward(self, inputs: Tensor, return_feat=False) -> Tensor:
        """Encode images with the backbone and decode them into per-class
        segmentation maps."""
        inputs = self.clip_T(inputs)
        x = self.extract_feat(inputs)
        if return_feat:
            seg_logits, feats, k = self.decode_head(x, return_feat)
            return seg_logits, feats, k
        # the decode head returns a single tensor when return_feat is False
        seg_logits = self.decode_head(x, return_feat)
        return seg_logits

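# Usage sketch for the forward pass above (illustrative, not part of the original
# module). `inputs` is assumed to be a batch of RGB images scaled to [0, 1]; the
# OpenAI normalization is applied inside `forward`. The output spatial size is an
# assumption based on the 1x1-conv decode head operating on patch-level features:
#
#     model = MaskClip(backbone_cfg, decode_head_cfg, clip_model_name, class_names)  # hypothetical cfgs
#     images = torch.rand(2, 3, 448, 448)
#     probs = model(images)                            # (2, num_classes, h, w) class probabilities
#     probs, feats, k = model(images, return_feat=True)
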
class MaskClipHead(nn.Module):
    def __init__(self, clip_model, class_names, visual_projs_path=None, in_index=-1, in_channels=3, norm_cfg=None, channels=0,
                 text_channels=512, attn_pooling=False, align_corners=False, model_prefix='hf-hub:laion', use_templates=False, **kwargs):
        super(MaskClipHead, self).__init__()

        self.text_channels = text_channels
        self.visual_projs_path = visual_projs_path
        self.clip_model = clip_model
        self.class_names = class_names
        self.in_channels = in_channels
        self.in_index = in_index  # from base decode head default
        self._init_inputs(in_channels, in_index, None)
        self.channels = channels
        self.norm_cfg = norm_cfg
        self.align_corners = align_corners
        self.use_templates = use_templates

        self.proj = nn.Conv2d(self.in_channels, text_channels, 1, bias=False)
        self.load_visual_projs()

        self.attn_pooling = attn_pooling
        self.tokenizer = get_tokenizer(f'{model_prefix}/{clip_model}')
        self.hf_modelname = f'{model_prefix}/{clip_model}'
        model, _ = create_model_from_pretrained(f'{model_prefix}/{clip_model}')
        model.eval()
        self.register_buffer("class_embeddings", self._get_class_embeddings(model, class_names))

    def update_vocab(self, class_names):
        model, _ = create_model_from_pretrained(self.hf_modelname)
        model.eval()
        self.class_embeddings = self._get_class_embeddings(model, class_names)

    def _embed_label(self, text_model: torch.nn.Module, label: str) -> torch.Tensor:
        """
        Encode a label name into a single vector.
        """
        if self.use_templates:
            templates = imagenet_templates
        else:
            # use "an" when the label starts with a vowel, "a" otherwise
            templates = ['a photo of an {}' if label[0] in 'aeiou' else 'a photo of a {}']

        all_prompts = [self.tokenizer(template.format(label)) for template in templates]
        out = text_model.encode_text(torch.cat(all_prompts))
        out /= out.norm(dim=-1, keepdim=True)
        out = out.mean(dim=0)
        return out

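    # Example of the behaviour above (illustrative): with use_templates=False the
    # label "orange" yields the single prompt "a photo of an orange", which is
    # tokenized, encoded by the CLIP text tower, L2-normalized per prompt and then
    # averaged over prompts; with use_templates=True the same averaging runs over
    # every prompt in imagenet_templates.
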
    def _get_class_embeddings(self, text_model: torch.nn.Module, class_names: List[str]):
        aug_embeddings = torch.stack([self._embed_label(text_model, label) for label in class_names])
        # normalize class embedding vectors
        aug_embeddings = aug_embeddings / aug_embeddings.norm(dim=-1, keepdim=True)
        return aug_embeddings.squeeze(1)

    def load_visual_projs(self):
        device = "cuda" if torch.cuda.is_available() else "cpu"
        loaded = torch.load(self.visual_projs_path, map_location=device)
        attrs = ['proj']
        for attr in attrs:
            current_attr = getattr(self, attr)
            state_dict = loaded[attr]
            for key in state_dict:
                if 'weight' in key:
                    state_dict[key] = state_dict[key][:, :, None, None]
            current_attr.load_state_dict(state_dict)
        print_log(f'Loaded proj weights from {self.visual_projs_path}', logger=get_root_logger())

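    # Note on the loading above: the checkpoint at `visual_projs_path` is assumed to
    # store the CLIP visual projection as a linear weight of shape
    # (text_channels, in_channels); the [:, :, None, None] reshape turns it into a
    # 1x1 convolution kernel so it can be applied densely to the patch-feature map.
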
    def forward(self, inputs, return_feat=False):
        x = self._transform_inputs(inputs)
        q, k, v, cls_token = None, None, None, None
        if isinstance(x, list) and len(x) == 4:
            x, q, k, v = x
        if isinstance(x, list) and len(x) == 2:
            x, cls_token = x

        if v is not None:
            feat = self.proj(v)
        else:
            feat = self.proj(x)

        output = self.cls_seg(feat)
        if return_feat:
            return output, feat, k
        return output

    def _init_inputs(self, in_channels, in_index, input_transform):
        """Check and initialize input transforms.

        The in_channels, in_index and input_transform must match.
        Specifically, when input_transform is None, only a single feature map
        will be selected, so in_channels and in_index must be of type int.
        When input_transform is not None, in_channels and in_index must be
        lists or tuples of the same length.

        Args:
            in_channels (int|Sequence[int]): Input channels.
            in_index (int|Sequence[int]): Input feature index.
            input_transform (str|None): Transformation type of input features.
                Options: 'resize_concat', 'multiple_select', None.
                'resize_concat': Multiple feature maps will be resized to the
                    same size as the first one and then concatenated together.
                    Usually used in the FCN head of HRNet.
                'multiple_select': Multiple feature maps will be bundled into
                    a list and passed into the decode head.
                None: Only one selected feature map is allowed.
        """
        if input_transform is not None:
            assert input_transform in ['resize_concat', 'multiple_select']
        self.input_transform = input_transform
        self.in_index = in_index
        if input_transform is not None:
            assert isinstance(in_channels, (list, tuple))
            assert isinstance(in_index, (list, tuple))
            assert len(in_channels) == len(in_index)
            if input_transform == 'resize_concat':
                self.in_channels = sum(in_channels)
            else:
                self.in_channels = in_channels
        else:
            assert isinstance(in_channels, int)
            assert isinstance(in_index, int)
            self.in_channels = in_channels

    def cls_seg(self, feat):
        feat = feat / feat.norm(dim=1, keepdim=True)
        output = F.conv2d(feat, self.class_embeddings[:, :, None, None])
        output = F.softmax(output * 100, dim=1)  # softmax of similarities with temperature scaling
        return output

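    # The 1x1 convolution above is equivalent to a per-pixel dot product between the
    # L2-normalized patch feature and each (already normalized) class embedding, i.e.
    # cosine similarity; multiplying by 100 before the softmax matches CLIP's usual
    # logit scale (temperature ~0.01). An equivalent einsum sketch (illustrative only):
    #
    #     sims = torch.einsum('bchw,kc->bkhw', feat, self.class_embeddings)
    #     probs = F.softmax(100 * sims, dim=1)
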
    def _transform_inputs(self, inputs):
        """Transform inputs for the decoder.

        Args:
            inputs (list[Tensor]): List of multi-level img features.

        Returns:
            Tensor: The transformed inputs.
        """
        if self.input_transform == 'resize_concat':
            inputs = [inputs[i] for i in self.in_index]
            upsampled_inputs = [
                resize(
                    input=x,
                    size=inputs[0].shape[2:],
                    mode='bilinear',
                    align_corners=self.align_corners) for x in inputs
            ]
            inputs = torch.cat(upsampled_inputs, dim=1)
        elif self.input_transform == 'multiple_select':
            inputs = [inputs[i] for i in self.in_index]
        else:
            inputs = inputs[self.in_index]

        return inputs
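
# ------------------------------------------------------------------------------
# Minimal end-to-end sketch (not part of the original module). All config keys,
# paths and the checkpoint name below are illustrative assumptions; the real
# values come from the project's config files.
# ------------------------------------------------------------------------------
#
#     backbone_cfg = dict(patch_size=16, ...)                                         # hypothetical
#     head_cfg = dict(type='MaskClipHead', visual_projs_path='...', in_channels=768)  # hypothetical
#     model = MaskClip(backbone_cfg, head_cfg,
#                      'CLIP-ViT-B-16-laion2B-s34B-b88K',  # resolved as hf-hub:laion/<name>
#                      ['cat', 'dog'])
#     model.eval()
#     with torch.no_grad():
#         probs = model(torch.rand(1, 3, 448, 448))  # (1, 2, h, w) class probabilities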