doevent committed on
Commit: d1e5ec4
Parent: 98ab872

Upload models/blip_nlvr.py

Files changed (1)
  1. models/blip_nlvr.py +103 -0
models/blip_nlvr.py ADDED
from models.med import BertConfig
from models.nlvr_encoder import BertModel
from models.vit import interpolate_pos_embed
from models.blip import create_vit, init_tokenizer, is_url

from timm.models.hub import download_cached_file

import os
import torch
from torch import nn
import torch.nn.functional as F


class BLIP_NLVR(nn.Module):
    def __init__(self,
                 med_config='configs/med_config.json',
                 image_size=480,
                 vit='base',
                 vit_grad_ckpt=False,
                 vit_ckpt_layer=0,
                 ):
        """
        Args:
            med_config (str): path to the mixture of encoder-decoder model's configuration file
            image_size (int): input image size
            vit (str): model size of the vision transformer
        """
        super().__init__()

        self.visual_encoder, vision_width = create_vit(vit, image_size, vit_grad_ckpt, vit_ckpt_layer, drop_path_rate=0.1)
        self.tokenizer = init_tokenizer()
        med_config = BertConfig.from_json_file(med_config)
        med_config.encoder_width = vision_width
        self.text_encoder = BertModel(config=med_config, add_pooling_layer=False)

        # Two-layer MLP head for the binary (True/False) NLVR classification.
        self.cls_head = nn.Sequential(
            nn.Linear(self.text_encoder.config.hidden_size, self.text_encoder.config.hidden_size),
            nn.ReLU(),
            nn.Linear(self.text_encoder.config.hidden_size, 2)
        )

    def forward(self, image, text, targets, train=True):
        # `image` stacks both images of each NLVR pair along the batch
        # dimension: the first targets.size(0) rows are image 0, the rest image 1.
        image_embeds = self.visual_encoder(image)
        image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to(image.device)
        image0_embeds, image1_embeds = torch.split(image_embeds, targets.size(0))

        text = self.tokenizer(text, padding='longest', return_tensors="pt").to(image.device)
        # Replace [CLS] with the encoder token so the text encoder runs in
        # image-grounded (cross-attention) mode.
        text.input_ids[:, 0] = self.tokenizer.enc_token_id

        # Each layer of the NLVR text encoder has two parallel cross-attention
        # streams, one attending to each image of the pair.
        output = self.text_encoder(text.input_ids,
                                   attention_mask=text.attention_mask,
                                   encoder_hidden_states=[image0_embeds, image1_embeds],
                                   encoder_attention_mask=[image_atts[:image0_embeds.size(0)],
                                                           image_atts[image0_embeds.size(0):]],
                                   return_dict=True,
                                   )
        hidden_state = output.last_hidden_state[:, 0, :]
        prediction = self.cls_head(hidden_state)

        if train:
            loss = F.cross_entropy(prediction, targets)
            return loss
        else:
            return prediction


def blip_nlvr(pretrained='', **kwargs):
    model = BLIP_NLVR(**kwargs)
    if pretrained:
        model, msg = load_checkpoint(model, pretrained)
        print("missing keys:")
        print(msg.missing_keys)
    return model


def load_checkpoint(model, url_or_filename):
    if is_url(url_or_filename):
        cached_file = download_cached_file(url_or_filename, check_hash=False, progress=True)
        checkpoint = torch.load(cached_file, map_location='cpu')
    elif os.path.isfile(url_or_filename):
        checkpoint = torch.load(url_or_filename, map_location='cpu')
    else:
        raise RuntimeError('checkpoint url or path is invalid')
    state_dict = checkpoint['model']

    # Resize the ViT position embeddings if the checkpoint was trained at a
    # different image resolution.
    state_dict['visual_encoder.pos_embed'] = interpolate_pos_embed(state_dict['visual_encoder.pos_embed'], model.visual_encoder)

    # The pretrained checkpoint has a single cross-attention stream per layer;
    # duplicate its weights into the two per-image streams (self0/self1,
    # dense0/dense1) that the NLVR encoder expects.
    for key in list(state_dict.keys()):
        if 'crossattention.self.' in key:
            new_key0 = key.replace('self', 'self0')
            new_key1 = key.replace('self', 'self1')
            state_dict[new_key0] = state_dict[key]
            state_dict[new_key1] = state_dict[key]
        elif 'crossattention.output.dense.' in key:
            new_key0 = key.replace('dense', 'dense0')
            new_key1 = key.replace('dense', 'dense1')
            state_dict[new_key0] = state_dict[key]
            state_dict[new_key1] = state_dict[key]

    msg = model.load_state_dict(state_dict, strict=False)
    print('load checkpoint from %s' % url_or_filename)
    return model, msg
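
For context, a minimal usage sketch of the uploaded module is shown below. It is not part of the commit: the checkpoint path, batch shapes, and captions are placeholders for illustration, and it assumes the script is run from the repository root so that the models package resolves.

import torch
from models.blip_nlvr import blip_nlvr

# Hypothetical checkpoint path; substitute a real BLIP NLVR checkpoint file or URL.
model = blip_nlvr(pretrained='path/to/nlvr_checkpoint.pth', image_size=480, vit='base')
model.eval()

batch_size = 2
# NLVR pairs are concatenated along the batch dimension: image 0 of every
# example first, then image 1 of every example (matching torch.split in forward()).
images = torch.randn(2 * batch_size, 3, 480, 480)
text = ['the left image contains two dogs', 'both images show a beach']
targets = torch.tensor([0, 1])  # 0 = False, 1 = True; also sets the split size

with torch.no_grad():
    logits = model(images, text, targets, train=False)  # shape: (batch_size, 2)
    pred = logits.argmax(dim=1)

Note that targets is required even at inference time, since forward() uses targets.size(0) to split the stacked batch back into the two images of each pair.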