# ------------------------------------------------------------------------------
# CLIP-DINOiser
# author: Monika Wysoczanska, Warsaw University of Technology
# ------------------------------------------------------------------------------
# Modified from OpenMMLab https://github.com/chongzhou96/MaskCLIP
# ------------------------------------------------------------------------------
import torch
import torch.nn as nn
import torch.nn.functional as F
from mmseg.ops import resize
from typing import Any, List
from torch import Tensor
from mmcv.utils import print_log
from mmseg.utils import get_root_logger
from open_clip import get_tokenizer, create_model_from_pretrained
from models.builder import MODELS
from .vit import VisionTransformer
import torchvision.transforms as T
from .utils.embed import AdaptivePadding
from .utils.prompt_templates import imagenet_templates
OPENAI_NORMALIZE = T.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711))
def make_vision_transformer(backbone_cfg):
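    """Instantiate the VisionTransformer backbone from its config dict and run
    its weight initialization."""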
model = VisionTransformer(**backbone_cfg)
model.init_weights()
return model
@MODELS.register_module()
class MaskClip(nn.Module):
def __init__(
self,
backbone,
decode_head,
clip_model,
class_names
):
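        """
        Args:
            backbone (dict): Config of the ViT backbone; must contain 'patch_size'.
            decode_head (dict): Config of the decoding head; 'type' names the head
                class to instantiate (e.g. MaskClipHead), the remaining keys are
                passed through as keyword arguments.
            clip_model (str): Name of the pretrained CLIP model to load from the hub.
            class_names (list[str]): Vocabulary of class names to segment.
        """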
super(MaskClip, self).__init__()
self.decode_head = eval(decode_head.get('type'))(clip_model, class_names, **decode_head)
self.backbone = make_vision_transformer(backbone)
self.clip_T = OPENAI_NORMALIZE
self.to_PIL = T.ToPILImage()
self.patch_size = backbone.get('patch_size')
self.padding = AdaptivePadding(self.patch_size, self.patch_size)
def extract_feat(self, inputs: Tensor) -> List[Tensor]:
"""Extract features from images."""
x = self.backbone(inputs)
return x
def forward(self, inputs: Tensor, return_feat=False) -> Tensor:
"""Encode images with backbone and decode into a semantic segmentation
map of the same size as input."""
inputs = self.clip_T(inputs)
x = self.extract_feat(inputs)
seg_logits, feats, k = self.decode_head(x, return_feat)
if return_feat:
return seg_logits, feats, k
return seg_logits
class MaskClipHead(nn.Module):
def __init__(self, clip_model, class_names, visual_projs_path=None, in_index=-1, in_channels=3, norm_cfg=None, channels=0,
text_channels=512, attn_pooling=False, align_corners=False, model_prefix='hf-hub:laion', use_templates=False, **kwargs):
super(MaskClipHead, self).__init__()
self.text_channels = text_channels
self.visual_projs_path = visual_projs_path
self.clip_model = clip_model
self.class_names = class_names
self.in_channels = in_channels
self.in_index = in_index # from base decode head default
self._init_inputs(in_channels, in_index, None)
self.channels = channels
self.norm_cfg = norm_cfg
self.align_corners = align_corners
self.use_templates = use_templates
self.proj = nn.Conv2d(self.in_channels, text_channels, 1, bias=False)
self.load_visual_projs()
self.attn_pooling = attn_pooling
self.tokenizer = get_tokenizer(f'{model_prefix}/{clip_model}')
self.hf_modelname = f'{model_prefix}/{clip_model}'
model, _ = create_model_from_pretrained(f'{model_prefix}/{clip_model}')
model.eval()
self.register_buffer("class_embeddings", self._get_class_embeddings(model, class_names))
@torch.no_grad()
def update_vocab(self, class_names):
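        """Recompute the class text embeddings for a new vocabulary.

        The CLIP text encoder is re-instantiated, the new class names are
        embedded, and the stored `class_embeddings` buffer is replaced.
        """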
model, _ = create_model_from_pretrained(self.hf_modelname)
model.eval()
self.class_embeddings = self._get_class_embeddings(model, class_names)
@torch.no_grad()
def _embed_label(self, text_model: torch.nn.Module, label: str) -> torch.Tensor:
"""
Encode label name into a single vector
"""
if self.use_templates:
templates = imagenet_templates
else:
            templates = ['a photo of an {}' if label[0].lower() in 'aeiou' else 'a photo of a {}']
all_prompts = [self.tokenizer(template.format(label)) for template in templates]
out = text_model.encode_text(torch.cat(all_prompts))
out /= out.norm(dim=-1, keepdim=True)
out = out.mean(dim=0)
return out
def _get_class_embeddings(self, text_model: torch.nn.Module, class_names: List[str]):
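        """Embed every class name and stack the results into a
        (num_classes, text_channels) tensor of L2-normalized embeddings."""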
aug_embeddings = torch.stack([self._embed_label(text_model, label) for label in class_names])
# normalize vector
aug_embeddings = aug_embeddings / aug_embeddings.norm(dim=-1, keepdim=True)
return aug_embeddings.squeeze(1)
def load_visual_projs(self):
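        """Load the pretrained visual projection weights and copy them into
        `self.proj`, reshaping the linear weights into 1x1 convolution kernels."""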
device = "cuda" if torch.cuda.is_available() else "cpu"
loaded = torch.load(self.visual_projs_path, map_location=device)
attrs = ['proj']
for attr in attrs:
current_attr = getattr(self, attr)
state_dict = loaded[attr]
for key in state_dict:
if 'weight' in key:
state_dict[key] = state_dict[key][:, :, None, None]
current_attr.load_state_dict(state_dict)
print_log(f'Loaded proj weights from {self.visual_projs_path}', logger=get_root_logger())
def forward(self, inputs, return_feat=False):
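        """Project the backbone features into the text embedding space and
        classify them against the class embeddings.

        `inputs` may carry the extra tensors exposed by the modified backbone,
        either [x, q, k, v] or [x, cls_token]; when the value features `v` are
        available they are used for the projection, as in MaskCLIP.
        """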
x = self._transform_inputs(inputs)
q, k, v, cls_token = None, None, None, None
if isinstance(x, list) and len(x) == 4:
x, q, k, v = x
if isinstance(x, list) and len(x) == 2:
x, cls_token = x
if v is not None:
feat = self.proj(v)
else:
feat = self.proj(x)
output = self.cls_seg(feat)
if return_feat:
return output, feat, k
return output
def _init_inputs(self, in_channels, in_index, input_transform):
"""Check and initialize input transforms.
The in_channels, in_index and input_transform must match.
Specifically, when input_transform is None, only single feature map
will be selected. So in_channels and in_index must be of type int.
When input_transform
Args:
in_channels (int|Sequence[int]): Input channels.
in_index (int|Sequence[int]): Input feature index.
input_transform (str|None): Transformation type of input features.
Options: 'resize_concat', 'multiple_select', None.
'resize_concat': Multiple feature maps will be resize to the
same size as first one and than concat together.
Usually used in FCN head of HRNet.
'multiple_select': Multiple feature maps will be bundle into
a list and passed into decode head.
None: Only one select feature map is allowed.
"""
if input_transform is not None:
assert input_transform in ['resize_concat', 'multiple_select']
self.input_transform = input_transform
self.in_index = in_index
if input_transform is not None:
assert isinstance(in_channels, (list, tuple))
assert isinstance(in_index, (list, tuple))
assert len(in_channels) == len(in_index)
if input_transform == 'resize_concat':
self.in_channels = sum(in_channels)
else:
self.in_channels = in_channels
else:
assert isinstance(in_channels, int)
assert isinstance(in_index, int)
self.in_channels = in_channels
def cls_seg(self, feat):
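        """Per-pixel classification: L2-normalize the projected features,
        correlate them with the class text embeddings via a 1x1 convolution
        (cosine similarity), and apply a temperature-scaled softmax (scale 100)."""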
feat = feat / feat.norm(dim=1, keepdim=True)
output = F.conv2d(feat, self.class_embeddings[:, :, None, None])
output = F.softmax(output * 100, dim=1) # softmax of similarities with temp scaling
return output
def _transform_inputs(self, inputs):
"""Transform inputs for decoder.
Args:
inputs (list[Tensor]): List of multi-level img features.
Returns:
Tensor: The transformed inputs
"""
if self.input_transform == 'resize_concat':
inputs = [inputs[i] for i in self.in_index]
upsampled_inputs = [
resize(
input=x,
size=inputs[0].shape[2:],
mode='bilinear',
align_corners=self.align_corners) for x in inputs
]
inputs = torch.cat(upsampled_inputs, dim=1)
elif self.input_transform == 'multiple_select':
inputs = [inputs[i] for i in self.in_index]
else:
inputs = inputs[self.in_index]
return inputs
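# ------------------------------------------------------------------------------
# Usage sketch (illustrative only). The dict keys and values below are
# hypothetical placeholders, not the project's shipped configs, which live in
# the repository's config files:
#
#   backbone = dict(patch_size=16, ...)  # VisionTransformer kwargs
#   decode_head = dict(type='MaskClipHead', visual_projs_path='<path/to/projs.pth>',
#                      in_channels=768, text_channels=512, use_templates=True)
#   model = MaskClip(backbone, decode_head,
#                    clip_model='CLIP-ViT-B-16-laion2B-s34B-b88K',
#                    class_names=['cat', 'dog'])
#   probs = model(images)  # images: (B, 3, H, W) tensor scaled to [0, 1]
# ------------------------------------------------------------------------------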