File size: 7,374 Bytes
36d9761 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 |
import torch
import torchvision
import torch.nn.functional as F
def attn_cosine_sim(x, eps=1e-08):
    """Return the pairwise cosine-similarity matrix of the rows of ``x[0]``.

    Only the first entry along the leading axis of ``x`` is used; the
    remaining tensor is treated as 3-D (the norm is taken over ``dim=2``
    and the transpose is ``permute(0, 2, 1)``).

    Args:
        x: Tensor whose first entry ``x[0]`` is a 3-D tensor of row vectors.
        eps: Lower bound applied to the product of norms so the division
            never hits zero.

    Returns:
        A tensor holding, for every pair of rows of ``x[0]``, their dot
        product divided by the (clamped) product of their L2 norms.
    """
    # TEMP: drops the redundant leading dimension, TBF
    feats = x[0]
    lengths = feats.norm(dim=2, keepdim=True)
    # Outer product of the row norms, floored at eps to keep the division safe.
    denom = torch.clamp(torch.matmul(lengths, lengths.permute(0, 2, 1)), min=eps)
    dots = torch.matmul(feats, feats.permute(0, 2, 1))
    return dots / denom
class VitExtractor:
    """Collect intermediate activations of a DINO ViT through forward hooks.

    The model is fetched with ``torch.hub.load('facebookresearch/dino:main',
    model_name)``.  Hooks are attached to every transformer block right
    before a forward pass and removed again afterwards, so the wrapped model
    carries no hooks between calls.
    """

    BLOCK_KEY = 'block'
    ATTN_KEY = 'attn'
    PATCH_IMD_KEY = 'patch_imd'
    QKV_KEY = 'qkv'
    KEY_LIST = [BLOCK_KEY, ATTN_KEY, PATCH_IMD_KEY, QKV_KEY]

    def __init__(self, model_name, device):
        """Load ``model_name`` from the DINO hub repo, move it to ``device``
        and put it in eval mode."""
        self.model = torch.hub.load('facebookresearch/dino:main', model_name).to(device)
        self.model.eval()
        self.model_name = model_name
        self.hook_handlers = []
        self.layers_dict = {}
        self.outputs_dict = {}
        for key in VitExtractor.KEY_LIST:
            self.layers_dict[key] = []
            self.outputs_dict[key] = []
        self._init_hooks_data()

    def _init_hooks_data(self):
        """Select every block of the model for every key and empty the
        per-key output buffers."""
        # Derived from the model instead of a literal 0..11 list; identical
        # for the 12-block models and correct for any other depth.
        layer_indices = list(range(len(self.model.blocks)))
        for key in VitExtractor.KEY_LIST:
            self.layers_dict[key] = list(layer_indices)
            self.outputs_dict[key] = []

    def _register_hooks(self, keys=None, **kwargs):
        """Attach forward hooks to the selected blocks.

        Args:
            keys: Optional iterable of entries from ``KEY_LIST`` (or a single
                key string) restricting which activation kinds are hooked.
                ``None`` hooks all of them, as before.
            **kwargs: Accepted for backward compatibility; ignored.
        """
        if keys is None:
            wanted = set(VitExtractor.KEY_LIST)
        elif isinstance(keys, str):
            wanted = {keys}
        else:
            wanted = set(keys)

        def selected(key, block_idx):
            return key in wanted and block_idx in self.layers_dict[key]

        for block_idx, block in enumerate(self.model.blocks):
            if selected(VitExtractor.BLOCK_KEY, block_idx):
                self.hook_handlers.append(block.register_forward_hook(self._get_block_hook()))
            if selected(VitExtractor.ATTN_KEY, block_idx):
                self.hook_handlers.append(block.attn.attn_drop.register_forward_hook(self._get_attn_hook()))
            if selected(VitExtractor.QKV_KEY, block_idx):
                self.hook_handlers.append(block.attn.qkv.register_forward_hook(self._get_qkv_hook()))
            if selected(VitExtractor.PATCH_IMD_KEY, block_idx):
                self.hook_handlers.append(block.attn.register_forward_hook(self._get_patch_imd_hook()))

    def _clear_hooks(self):
        """Remove every hook registered through ``_register_hooks``."""
        for handler in self.hook_handlers:
            handler.remove()
        self.hook_handlers = []

    def _get_block_hook(self):
        """Return a hook that stores each block's output."""
        def _get_block_output(model, input, output):
            self.outputs_dict[VitExtractor.BLOCK_KEY].append(output)
        return _get_block_output

    def _get_attn_hook(self):
        """Return a hook that stores the output of ``attn.attn_drop``."""
        def _get_attn_output(model, inp, output):
            self.outputs_dict[VitExtractor.ATTN_KEY].append(output)
        return _get_attn_output

    def _get_qkv_hook(self):
        """Return a hook that stores the output of ``attn.qkv``."""
        def _get_qkv_output(model, inp, output):
            self.outputs_dict[VitExtractor.QKV_KEY].append(output)
        return _get_qkv_output

    # TODO: CHECK ATTN OUTPUT TUPLE
    def _get_patch_imd_hook(self):
        """Return a hook that stores element 0 of the attention module output.

        NOTE(review): assumes ``block.attn`` returns a sequence whose first
        element is the token tensor -- confirm against the model source.
        """
        def _get_attn_output(model, inp, output):
            self.outputs_dict[VitExtractor.PATCH_IMD_KEY].append(output[0])
        return _get_attn_output

    def _collect(self, input_img, key):
        """Run one forward pass and return the activations stored for ``key``.

        Only the hooks needed for ``key`` are attached.  Hooks and buffers
        are reset in ``finally`` so that a failing forward pass cannot leave
        hooks on the model (which would duplicate outputs on the next call).
        """
        try:
            self._register_hooks(keys=(key,))
            self.model(input_img)
            # The returned list stays valid: the reset below rebinds the
            # dictionary entry to a new list instead of clearing this one.
            return self.outputs_dict[key]
        finally:
            self._clear_hooks()
            self._init_hooks_data()

    def get_feature_from_input(self, input_img):  # List([B, N, D])
        """Return the list of per-block outputs for ``input_img``."""
        return self._collect(input_img, VitExtractor.BLOCK_KEY)

    def get_qkv_feature_from_input(self, input_img):
        """Return the list of per-block ``attn.qkv`` outputs for ``input_img``."""
        return self._collect(input_img, VitExtractor.QKV_KEY)

    def get_attn_feature_from_input(self, input_img):
        """Return the list of per-block ``attn.attn_drop`` outputs for ``input_img``."""
        return self._collect(input_img, VitExtractor.ATTN_KEY)

    def get_patch_size(self):
        """Return 8 if the model name contains ``"8"``, otherwise 16."""
        return 8 if "8" in self.model_name else 16

    def get_width_patch_num(self, input_img_shape):
        """Return ``w // patch_size`` for a ``(b, c, h, w)`` shape."""
        b, c, h, w = input_img_shape
        patch_size = self.get_patch_size()
        return w // patch_size

    def get_height_patch_num(self, input_img_shape):
        """Return ``h // patch_size`` for a ``(b, c, h, w)`` shape."""
        b, c, h, w = input_img_shape
        patch_size = self.get_patch_size()
        return h // patch_size

    def get_patch_num(self, input_img_shape):
        """Return ``1 + height_patches * width_patches``.

        The extra 1 presumably accounts for the class token -- TODO confirm.
        """
        patch_num = 1 + (self.get_height_patch_num(input_img_shape) * self.get_width_patch_num(input_img_shape))
        return patch_num

    def get_head_num(self):
        """Return the head count inferred from the model name (6 or 12)."""
        if "dino" in self.model_name:
            return 6 if "s" in self.model_name else 12
        return 6 if "small" in self.model_name else 12

    def get_embedding_dim(self):
        """Return the embedding width inferred from the model name (384 or 768)."""
        if "dino" in self.model_name:
            return 384 if "s" in self.model_name else 768
        return 384 if "small" in self.model_name else 768

    def _split_qkv(self, qkv, input_img_shape):
        """Reshape a raw qkv activation to ``(3, heads, tokens, head_dim)``.

        The reshape has no batch axis, so ``qkv`` must hold exactly
        ``patch_num * 3 * embedding_dim`` elements (i.e. a single image).
        """
        patch_num = self.get_patch_num(input_img_shape)
        head_num = self.get_head_num()
        embedding_dim = self.get_embedding_dim()
        return qkv.reshape(patch_num, 3, head_num, embedding_dim // head_num).permute(1, 2, 0, 3)

    def get_queries_from_qkv(self, qkv, input_img_shape):
        """Return the query slice, shaped ``(heads, tokens, head_dim)``."""
        q = self._split_qkv(qkv, input_img_shape)[0]
        return q

    def get_keys_from_qkv(self, qkv, input_img_shape):
        """Return the key slice, shaped ``(heads, tokens, head_dim)``."""
        k = self._split_qkv(qkv, input_img_shape)[1]
        return k

    def get_values_from_qkv(self, qkv, input_img_shape):
        """Return the value slice, shaped ``(heads, tokens, head_dim)``."""
        v = self._split_qkv(qkv, input_img_shape)[2]
        return v

    def get_keys_from_input(self, input_img, layer_num):
        """Run the model on ``input_img`` and return the keys of block ``layer_num``."""
        qkv_features = self.get_qkv_feature_from_input(input_img)[layer_num]
        keys = self.get_keys_from_qkv(qkv_features, input_img.shape)
        return keys

    def get_keys_self_sim_from_input(self, input_img, layer_num):
        """Return the cosine self-similarity of the keys of block ``layer_num``.

        The per-head keys are concatenated per token into ``(tokens,
        heads * head_dim)`` before being passed to ``attn_cosine_sim``.
        """
        keys = self.get_keys_from_input(input_img, layer_num=layer_num)
        h, t, d = keys.shape
        concatenated_keys = keys.transpose(0, 1).reshape(t, h * d)
        ssim_map = attn_cosine_sim(concatenated_keys[None, None, ...])
        return ssim_map
class DinoStructureLoss:
    """Structure loss comparing key self-similarity maps from a DINO ViT."""

    def __init__(self, model_name="dino_vitb8", device=None):
        """Build the feature extractor and the image preprocessing pipeline.

        Args:
            model_name: Hub model name handed to ``VitExtractor``.
            device: Device for the extractor.  ``None`` selects ``"cuda"``
                when CUDA is available and ``"cpu"`` otherwise, so the class
                can also be constructed on machines without a GPU.
        """
        if device is None:
            device = "cuda" if torch.cuda.is_available() else "cpu"
        self.extractor = VitExtractor(model_name=model_name, device=device)
        self.preprocess = torchvision.transforms.Compose([
            torchvision.transforms.Resize(224),
            torchvision.transforms.ToTensor(),
            torchvision.transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
        ])

    def calculate_global_ssim_loss(self, outputs, inputs, layer_num=11):
        """Sum the MSE between self-similarity maps of paired images.

        ``inputs`` and ``outputs`` are iterated in lockstep; each item gets
        a leading axis via ``unsqueeze(0)`` before being passed to the
        extractor.  The map of the ``inputs`` item is computed under
        ``torch.no_grad()`` and acts as the target; the map of the
        ``outputs`` item is computed with autograd left as-is.

        Args:
            outputs: Iterable of image tensors being evaluated.
            inputs: Iterable of reference image tensors.
            layer_num: Index of the transformer block whose keys are compared.

        Returns:
            The accumulated loss (the float ``0.0`` if nothing was iterated).
        """
        loss = 0.0
        for target_img, output_img in zip(inputs, outputs):  # avoid memory limitations
            with torch.no_grad():
                target_keys_self_sim = self.extractor.get_keys_self_sim_from_input(
                    target_img.unsqueeze(0), layer_num=layer_num)
            keys_ssim = self.extractor.get_keys_self_sim_from_input(
                output_img.unsqueeze(0), layer_num=layer_num)
            loss += F.mse_loss(keys_ssim, target_keys_self_sim)
        return loss
|