# File size: 4,705 bytes (commit 8a096e8)
import torch
import torch.nn as nn
import re
import math
from transformers import CLIPVisionModel, CLIPImageProcessor, CLIPVisionConfig
def build_vision_tower():
    """Create a vision tower for the ``openai/clip-vit-large-patch14-336`` checkpoint.

    NOTE(review): ``CLIPVisionTower`` is not defined in this file as shown
    (only ``CLIPVisionTowerHD`` is, and its constructor takes a config
    mapping, not a checkpoint name). Calling this raises ``NameError``
    unless the name is supplied from elsewhere -- confirm before use.
    """
    checkpoint = 'openai/clip-vit-large-patch14-336'
    return CLIPVisionTower(checkpoint)
class CLIPVisionTowerHD(nn.Module):
    """Frozen CLIP vision encoder that turns high-resolution images into token sequences.

    ``forward`` cuts every image into a grid of 336x336 sub-crops plus one
    bicubically resized 336x336 global view, encodes all crops in a single
    batch, merges each 2x2 block of patch tokens into one token of width
    ``4*C`` and lays the tokens out row by row with separator embeddings.
    """

    # Side length, in pixels, of every crop passed to the encoder.
    _CROP_SIZE = 336

    def __init__(self, config, vision_select_layer=-2):
        """Build and freeze the encoder.

        Args:
            config: mapping of keyword arguments for ``CLIPVisionConfig``; the
                encoder is constructed from it, not from a pretrained checkpoint.
            vision_select_layer: index into the encoder's ``hidden_states``
                used as the feature layer (default: second-to-last).
        """
        super().__init__()
        self.is_loaded = False
        self.vis_config = config
        self.select_layer = vision_select_layer
        self.select_feature = 'patch'
        self.load_model()

    def load_model(self):
        """Instantiate the CLIP vision model from ``self.vis_config`` and freeze its parameters."""
        self.vision_tower = CLIPVisionModel(CLIPVisionConfig(**self.vis_config))
        self.vision_tower.requires_grad_(False)
        self.is_loaded = True

    def resize_pos(self):
        """Placeholder kept for interface compatibility; only prints a message."""
        print('Dummy Resized')

    def feature_select(self, image_forward_outs):
        """Pick the configured hidden layer and optionally drop the leading CLS token.

        Raises:
            ValueError: if ``self.select_feature`` is neither ``'patch'`` nor
                ``'cls_patch'``.
        """
        image_features = image_forward_outs.hidden_states[self.select_layer]
        if self.select_feature == 'patch':
            return image_features[:, 1:]  # drop the first (CLS) token
        if self.select_feature == 'cls_patch':
            return image_features
        raise ValueError(f'Unexpected select feature: {self.select_feature}')

    @staticmethod
    def _merge_2x2(features, side):
        """Fold every 2x2 neighbourhood of a (B, side*side, C) token grid into one
        token, returning a contiguous (B, side//2, side//2, 4*C) tensor."""
        batch, _, channels = features.shape
        half = side // 2
        grid = features.reshape(batch, half, 2, half, 2, channels)
        return grid.permute(0, 1, 3, 2, 4, 5).reshape(batch, half, half, 4 * channels).contiguous()

    def forward(self, images, glb_GN, sub_GN):
        """Encode a list of high-resolution images into padded token sequences.

        Args:
            images: list of image tensors, each shaped ``(1, 3, H, W)`` with
                ``H`` and ``W`` multiples of 336 (required by the crop reshape).
            glb_GN: separator placed between an image's global-view tokens and
                its sub-crop tokens; shape ``(1, 1, 4*C)``.
            sub_GN: separator appended at the end of every token row (for both
                the global view and the sub-crop mosaic); shape ``(1, 1, 1, 4*C)``.

        Returns:
            ``(features, lengths)`` where ``features`` has shape
            ``(len(images), max_len, 4*C)`` with shorter sequences zero-padded
            at the end (in the features' own dtype/device), and ``lengths``
            lists each image's unpadded sequence length.
        """
        if not self.is_loaded:
            self.load_model()
        assert isinstance(images, list)

        crop = self._CROP_SIZE
        grid_shapes = []
        crops = []
        for img in images:
            _, _, img_h, img_w = img.shape
            rows, cols = img_h // crop, img_w // crop
            grid_shapes.append([rows, cols])
            # (1, 3, rows*crop, cols*crop) -> (rows*cols, 3, crop, crop), row-major order.
            sub_img = img.reshape(1, 3, rows, crop, cols, crop).permute(0, 2, 4, 1, 3, 5).reshape(-1, 3, crop, crop).contiguous()
            # Resize in float32, then cast back to the crops' dtype.
            glb_img = torch.nn.functional.interpolate(img.float(), size=(crop, crop), mode='bicubic').to(sub_img.dtype)
            # Per image the batch holds the global view first, then its sub-crops.
            crops.append(glb_img)
            crops.append(sub_img)
        input_imgs = torch.cat(crops, dim=0)

        image_forward_outs = self.vision_tower(input_imgs.to(device=self.device, dtype=self.dtype), output_hidden_states=True)
        image_features = self.feature_select(image_forward_outs).to(input_imgs.dtype)  # (num_crops, N, C)
        _, N, C = image_features.shape
        side = int(math.sqrt(N))
        assert N == 24 ** 2  # layout below assumes a 24x24 patch grid
        half = side // 2  # token-grid side after 2x2 merging

        output_imgs = []
        output_len = []
        for h, w in grid_shapes:
            num_sub = h * w
            # Global view -> (1, half, half, 4C); each row terminated by sub_GN.
            glb_img = self._merge_2x2(image_features[:1], side)
            glb_img = torch.cat([glb_img, sub_GN.repeat(1, half, 1, 1)], dim=2).reshape(1, -1, 4 * C)
            # Sub-crops -> one (1, h*half, w*half, 4C) mosaic; each row terminated by sub_GN.
            sub_img = self._merge_2x2(image_features[1:1 + num_sub], side)
            sub_img = sub_img.reshape(1, h, w, half, half, 4 * C).permute(0, 1, 3, 2, 4, 5).reshape(1, h * half, w * half, 4 * C)
            sub_img = torch.cat([sub_img, sub_GN.repeat(1, h * half, 1, 1)], dim=2).reshape(1, -1, 4 * C)
            output_imgs.append(torch.cat([glb_img, glb_GN, sub_img], dim=1))
            # Tokens of all (h*w + 1) views, one glb_GN, and one sub_GN per row
            # (half rows in the global view + h*half rows in the mosaic).
            temp_len = int((h * w + 1) * half * half + 1 + (h + 1) * half)
            assert temp_len == output_imgs[-1].shape[1]
            output_len.append(temp_len)
            # Advance past this image's global view and sub-crops.
            image_features = image_features[1 + num_sub:]

        max_len = max(output_len)
        padded = []
        for img_feat in output_imgs:
            missing = max_len - img_feat.shape[1]
            if missing > 0:
                # new_zeros keeps the features' dtype and device; a default
                # float32 pad would not match half/bfloat16 features.
                img_feat = torch.cat([img_feat, img_feat.new_zeros(1, missing, img_feat.shape[2])], dim=1)
            padded.append(img_feat)
        return torch.cat(padded, dim=0), output_len

    @property
    def dummy_feature(self):
        """A (1, hidden_size) zero tensor on the encoder's device and dtype."""
        return torch.zeros(1, self.num_features, device=self.device, dtype=self.dtype)

    @property
    def dtype(self):
        """Parameter dtype of the wrapped encoder."""
        return self.vision_tower.dtype

    @property
    def device(self):
        """Device of the wrapped encoder."""
        return self.vision_tower.device

    @property
    def config(self):
        """Config of the loaded encoder, or one built from ``self.vis_config`` if not loaded."""
        if self.is_loaded:
            return self.vision_tower.config
        return CLIPVisionConfig(**self.vis_config)

    @property
    def num_features(self):
        """Hidden size of the encoder's output tokens."""
        return self.config.hidden_size

    @property
    def num_patches(self):
        """Number of patch tokens per crop: (image_size // patch_size) ** 2."""
        return (self.config.image_size // self.config.patch_size) ** 2
|