Alberto Carmona commited on
Commit
cff97d1
1 Parent(s): ec81c7a

Merge the caption project with the text2image

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. app.py +67 -44
  2. captioning/__init__.py +0 -0
  3. captioning/data/__init__.py +0 -0
  4. captioning/data/dataloader.py +425 -0
  5. captioning/data/pth_loader.py +334 -0
  6. captioning/data/pth_loader_FineCapEval.py +334 -0
  7. captioning/models/AoAModel.py +228 -0
  8. captioning/models/AttEnsemble.py +90 -0
  9. captioning/models/AttModel.py +969 -0
  10. captioning/models/BertCapModel.py +104 -0
  11. captioning/models/CaptionModel.py +407 -0
  12. captioning/models/FCModel.py +204 -0
  13. captioning/models/M2Transformer.py +98 -0
  14. captioning/models/ShowTellModel.py +174 -0
  15. captioning/models/TransformerModel.py +363 -0
  16. captioning/models/__init__.py +73 -0
  17. captioning/models/cachedTransformer.py +420 -0
  18. captioning/models/utils.py +25 -0
  19. captioning/modules/loss_wrapper.py +127 -0
  20. captioning/modules/losses.py +218 -0
  21. captioning/utils/__init__.py +0 -0
  22. captioning/utils/clipscore.py +396 -0
  23. captioning/utils/config.py +153 -0
  24. captioning/utils/dist_utils.py +305 -0
  25. captioning/utils/div_utils.py +38 -0
  26. captioning/utils/eval_multi.py +218 -0
  27. captioning/utils/eval_utils.py +281 -0
  28. captioning/utils/misc.py +251 -0
  29. captioning/utils/opts.py +412 -0
  30. captioning/utils/resnet.py +71 -0
  31. captioning/utils/resnet_utils.py +27 -0
  32. captioning/utils/rewards.py +392 -0
  33. captioning/utils/utils.py +138 -0
  34. clip/__init__.py +1 -0
  35. clip/bpe_simple_vocab_16e6.txt.gz +3 -0
  36. clip/clip.py +193 -0
  37. clip/model.py +437 -0
  38. clip/simple_tokenizer.py +132 -0
  39. configs/phase1/FineCapEval_clipRN50_mle.yml +60 -0
  40. configs/phase1/clipRN50_mle.yml +52 -0
  41. configs/phase1/transformer.yml +41 -0
  42. configs/phase2/FineCapEval_clipRN50_cider.yml +61 -0
  43. configs/phase2/FineCapEval_clipRN50_cider_clips.yml +65 -0
  44. configs/phase2/FineCapEval_clipRN50_clips.yml +64 -0
  45. configs/phase2/FineCapEval_clipRN50_clips_grammar.yml +64 -0
  46. configs/phase2/clipRN50_cider.yml +58 -0
  47. configs/phase2/clipRN50_cider_clips.yml +61 -0
  48. configs/phase2/clipRN50_clips.yml +58 -0
  49. configs/phase2/clipRN50_clips_grammar.yml +64 -0
  50. configs/phase2/transformer.yml +41 -0
app.py CHANGED
@@ -8,6 +8,11 @@ import captioning.utils.misc as utils
8
  import pytorch_lightning as pl
9
  import gradio as gr
10
 
 
 
 
 
 
11
 
12
  # Checkpoint class
13
  class ModelCheckpoint(pl.callbacks.ModelCheckpoint):
@@ -47,7 +52,6 @@ opt.seq_length = seq_length
47
 
48
  opt.batch_size = 1
49
  opt.vocab = ix_to_word
50
- # opt.use_grammar = False
51
 
52
  model = models.setup(opt)
53
  del opt.vocab
@@ -111,55 +115,74 @@ clip_model.visual.attnpool.positional_embedding = pos_embed
111
 
112
 
113
  # End below
114
-
115
- def generate_image(img, steps=100, seed=42, guidance_scale=6.0):
116
-
117
- with torch.no_grad():
118
- image = preprocess(img)
119
- image = torch.tensor(np.stack([image])).to(device)
120
- image -= image_mean
121
- image /= image_std
122
-
123
- tmp_att, tmp_fc = clip_model.encode_image(image)
124
- tmp_att = tmp_att[0].permute(1, 2, 0)
125
- tmp_fc = tmp_fc[0]
126
-
127
- att_feat = tmp_att
128
- fc_feat = tmp_fc
129
-
 
 
 
 
 
 
 
 
130
 
131
- # Inference configurations
132
- eval_kwargs = {}
133
- eval_kwargs.update(vars(opt))
134
 
135
- verbose = eval_kwargs.get('verbose', True)
136
- verbose_beam = eval_kwargs.get('verbose_beam', 0)
137
- verbose_loss = eval_kwargs.get('verbose_loss', 1)
138
 
139
- # dataset = eval_kwargs.get('dataset', 'coco')
140
- beam_size = eval_kwargs.get('beam_size', 1)
141
- sample_n = eval_kwargs.get('sample_n', 1)
142
- remove_bad_endings = eval_kwargs.get('remove_bad_endings', 0)
143
 
144
- with torch.no_grad():
145
- fc_feats = torch.zeros((1,0)).to(device)
146
- att_feats = att_feat.view(1, 196, 2048).float().to(device)
147
- att_masks = None
148
-
149
- # forward the model to also get generated samples for each image
150
- # Only leave one feature for each image, in case duplicate sample
151
- tmp_eval_kwargs = eval_kwargs.copy()
152
- tmp_eval_kwargs.update({'sample_n': 1})
153
- seq, seq_logprobs = model(
154
- fc_feats, att_feats, att_masks, opt=tmp_eval_kwargs, mode='sample')
155
- seq = seq.data
156
-
157
- sents = utils.decode_sequence(model.vocab, seq)
 
 
 
 
 
 
 
 
 
 
 
 
158
 
159
- return sents[0]
160
 
161
  gr.Interface(
162
- generate_image,
163
  inputs=[
164
  gr.Image(type="pil"),
165
  gr.inputs.Slider(1, 100, label='Inference Steps', default=50, step=1),
@@ -168,4 +191,4 @@ gr.Interface(
168
  ],
169
  outputs=gr.Image(shape=[256,256], type="pil", elem_id="output_image"),
170
  css="#output_image{width: 256px}",
171
- ).launch()
 
8
  import pytorch_lightning as pl
9
  import gradio as gr
10
 
11
+ from diffusers import LDMTextToImagePipeline
12
+ # import PIL.Image
13
+ import random
14
+ import os
15
+
16
 
17
  # Checkpoint class
18
  class ModelCheckpoint(pl.callbacks.ModelCheckpoint):
 
52
 
53
  opt.batch_size = 1
54
  opt.vocab = ix_to_word
 
55
 
56
  model = models.setup(opt)
57
  del opt.vocab
 
115
 
116
 
117
  # End below
118
+ print('Loading the model: CompVis/ldm-text2im-large-256')
119
+ ldm_pipeline = LDMTextToImagePipeline.from_pretrained("CompVis/ldm-text2im-large-256")
120
+
121
+ def generate_image_from_text(prompt, steps=100, seed=42, guidance_scale=6.0):
122
+ print('RUN: generate_image_from_text')
123
+ torch.cuda.empty_cache()
124
+ generator = torch.manual_seed(seed)
125
+ images = ldm_pipeline([prompt], generator=generator, num_inference_steps=steps, eta=0.3, guidance_scale=guidance_scale)["sample"]
126
+ return images[0]
127
+
128
+ def generate_text_from_image(img):
129
+ print('RUN: generate_text_from_image')
130
+ with torch.no_grad():
131
+ image = preprocess(img)
132
+ image = torch.tensor(np.stack([image])).to(device)
133
+ image -= image_mean
134
+ image /= image_std
135
+
136
+ tmp_att, tmp_fc = clip_model.encode_image(image)
137
+ tmp_att = tmp_att[0].permute(1, 2, 0)
138
+ tmp_fc = tmp_fc[0]
139
+
140
+ att_feat = tmp_att
141
+ fc_feat = tmp_fc
142
 
143
+ # Inference configurations
144
+ eval_kwargs = {}
145
+ eval_kwargs.update(vars(opt))
146
 
147
+ verbose = eval_kwargs.get('verbose', True)
148
+ verbose_beam = eval_kwargs.get('verbose_beam', 0)
149
+ verbose_loss = eval_kwargs.get('verbose_loss', 1)
150
 
151
+ # dataset = eval_kwargs.get('dataset', 'coco')
152
+ beam_size = eval_kwargs.get('beam_size', 1)
153
+ sample_n = eval_kwargs.get('sample_n', 1)
154
+ remove_bad_endings = eval_kwargs.get('remove_bad_endings', 0)
155
 
156
+ with torch.no_grad():
157
+ fc_feats = torch.zeros((1,0)).to(device)
158
+ att_feats = att_feat.view(1, 196, 2048).float().to(device)
159
+ att_masks = None
160
+
161
+ # forward the model to also get generated samples for each image
162
+ # Only leave one feature for each image, in case duplicate sample
163
+ tmp_eval_kwargs = eval_kwargs.copy()
164
+ tmp_eval_kwargs.update({'sample_n': 1})
165
+ seq, seq_logprobs = model(
166
+ fc_feats, att_feats, att_masks, opt=tmp_eval_kwargs, mode='sample')
167
+ seq = seq.data
168
+
169
+ sents = utils.decode_sequence(model.vocab, seq)
170
+
171
+ return sents[0]
172
+
173
+
174
+ def generate_drawing_from_image(img, steps=100, seed=42, guidance_scale=6.0):
175
+ print('RUN: generate_drawing_from_image')
176
+ caption = generate_text_from_image(img)
177
+ gen_image = generate_image_from_text(caption, steps=steps, seed=seed, guidance_scale=guidance_scale)
178
+ return gen_image
179
+
180
+
181
+ random_seed = random.randint(0, 2147483647)
182
 
 
183
 
184
  gr.Interface(
185
+ generate_drawing_from_image,
186
  inputs=[
187
  gr.Image(type="pil"),
188
  gr.inputs.Slider(1, 100, label='Inference Steps', default=50, step=1),
 
191
  ],
192
  outputs=gr.Image(shape=[256,256], type="pil", elem_id="output_image"),
193
  css="#output_image{width: 256px}",
194
+ ).launch()
captioning/__init__.py ADDED
File without changes
captioning/data/__init__.py ADDED
File without changes
captioning/data/dataloader.py ADDED
@@ -0,0 +1,425 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import absolute_import
2
+ from __future__ import division
3
+ from __future__ import print_function
4
+
5
+ import json
6
+ import h5py
7
+ from lmdbdict import lmdbdict
8
+ from lmdbdict.methods import DUMPS_FUNC, LOADS_FUNC
9
+ import os
10
+ import numpy as np
11
+ import numpy.random as npr
12
+ import random
13
+ from functools import partial
14
+
15
+ import torch
16
+ import torch.utils.data as data
17
+
18
+ import multiprocessing
19
+ import six
20
+
21
+ class HybridLoader:
22
+ """
23
+ If db_path is a director, then use normal file loading
24
+ If lmdb, then load from lmdb
25
+ The loading method depend on extention.
26
+
27
+ in_memory: if in_memory is True, we save all the features in memory
28
+ For individual np(y|z)s, we don't need to do that because the system will do this for us.
29
+ Should be useful for lmdb or h5.
30
+ (Copied this idea from vilbert)
31
+ """
32
+ def __init__(self, db_path, ext, in_memory=False):
33
+ self.db_path = db_path
34
+ self.ext = ext
35
+ if self.ext == '.npy':
36
+ self.loader = lambda x: np.load(six.BytesIO(x))
37
+ else:
38
+ def load_npz(x):
39
+ x = np.load(six.BytesIO(x))
40
+ return x['feat'] if 'feat' in x else x['z'] # normally it should be 'feat', but under cocotest_bu, the key is saved to be 'z' mistakenly.
41
+ self.loader = load_npz
42
+ if db_path.endswith('.lmdb'):
43
+ self.db_type = 'lmdb'
44
+ self.lmdb = lmdbdict(db_path, unsafe=True)
45
+ self.lmdb._key_dumps = DUMPS_FUNC['ascii']
46
+ self.lmdb._value_loads = LOADS_FUNC['identity']
47
+ elif db_path.endswith('.pth'): # Assume a key,value dictionary
48
+ self.db_type = 'pth'
49
+ self.feat_file = torch.load(db_path)
50
+ self.loader = lambda x: x
51
+ print('HybridLoader: ext is ignored')
52
+ elif db_path.endswith('h5'):
53
+ self.db_type = 'h5'
54
+ self.loader = lambda x: np.array(x).astype('float32')
55
+ else:
56
+ self.db_type = 'dir'
57
+
58
+ self.in_memory = in_memory
59
+ if self.in_memory:
60
+ self.features = {}
61
+
62
+ def get(self, key):
63
+
64
+ if self.in_memory and key in self.features:
65
+ # We save f_input because we want to save the
66
+ # compressed bytes to save memory
67
+ f_input = self.features[key]
68
+ elif self.db_type == 'lmdb':
69
+ f_input = self.lmdb[key]
70
+ elif self.db_type == 'pth':
71
+ f_input = self.feat_file[key]
72
+ elif self.db_type == 'h5':
73
+ f_input = h5py.File(self.db_path, 'r')[key]
74
+ else:
75
+ f_input = open(os.path.join(self.db_path, key + self.ext), 'rb').read()
76
+
77
+ if self.in_memory and key not in self.features:
78
+ self.features[key] = f_input
79
+
80
+ # load image
81
+ feat = self.loader(f_input)
82
+
83
+ return feat
84
+
85
+ class Dataset(data.Dataset):
86
+
87
+ def get_vocab_size(self):
88
+ return self.vocab_size
89
+
90
+ def get_vocab(self):
91
+ return self.ix_to_word
92
+
93
+ def get_seq_length(self):
94
+ return self.seq_length
95
+
96
+ def __init__(self, opt):
97
+ self.opt = opt
98
+ self.seq_per_img = opt.seq_per_img
99
+
100
+ # feature related options
101
+ self.use_fc = getattr(opt, 'use_fc', True)
102
+ self.use_att = getattr(opt, 'use_att', True)
103
+ self.use_box = getattr(opt, 'use_box', 0)
104
+ self.norm_att_feat = getattr(opt, 'norm_att_feat', 0)
105
+ self.norm_box_feat = getattr(opt, 'norm_box_feat', 0)
106
+
107
+ # load the json file which contains additional information about the dataset
108
+ print('DataLoader loading json file: ', opt.input_json)
109
+ self.info = json.load(open(self.opt.input_json))
110
+ if 'ix_to_word' in self.info:
111
+ self.ix_to_word = self.info['ix_to_word']
112
+ self.vocab_size = len(self.ix_to_word)
113
+ print('vocab size is ', self.vocab_size)
114
+
115
+ # open the hdf5 file
116
+ print('DataLoader loading h5 file: ', opt.input_fc_dir, opt.input_att_dir, opt.input_box_dir, opt.input_label_h5)
117
+ """
118
+ Setting input_label_h5 to none is used when only doing generation.
119
+ For example, when you need to test on coco test set.
120
+ """
121
+ if self.opt.input_label_h5 != 'none':
122
+ self.h5_label_file = h5py.File(self.opt.input_label_h5, 'r', driver='core')
123
+ # load in the sequence data
124
+ seq_size = self.h5_label_file['labels'].shape
125
+ self.label = self.h5_label_file['labels'][:]
126
+ self.seq_length = seq_size[1]
127
+ print('max sequence length in data is', self.seq_length)
128
+ # load the pointers in full to RAM (should be small enough)
129
+ self.label_start_ix = self.h5_label_file['label_start_ix'][:]
130
+ self.label_end_ix = self.h5_label_file['label_end_ix'][:]
131
+ else:
132
+ self.seq_length = 1
133
+
134
+ self.data_in_memory = getattr(opt, 'data_in_memory', False)
135
+ self.fc_loader = HybridLoader(self.opt.input_fc_dir, '.npy', in_memory=self.data_in_memory)
136
+ self.att_loader = HybridLoader(self.opt.input_att_dir, '.npz', in_memory=self.data_in_memory)
137
+ self.box_loader = HybridLoader(self.opt.input_box_dir, '.npy', in_memory=self.data_in_memory)
138
+
139
+ self.num_images = len(self.info['images']) # self.label_start_ix.shape[0]
140
+ print('read %d image features' %(self.num_images))
141
+
142
+ # separate out indexes for each of the provided splits
143
+ self.split_ix = {'train': [], 'val': [], 'test': []}
144
+ for ix in range(len(self.info['images'])):
145
+ img = self.info['images'][ix]
146
+ if not 'split' in img:
147
+ self.split_ix['train'].append(ix)
148
+ self.split_ix['val'].append(ix)
149
+ self.split_ix['test'].append(ix)
150
+ elif img['split'] == 'train':
151
+ self.split_ix['train'].append(ix)
152
+ elif img['split'] == 'val':
153
+ self.split_ix['val'].append(ix)
154
+ elif img['split'] == 'test':
155
+ self.split_ix['test'].append(ix)
156
+ elif opt.train_only == 0: # restval
157
+ self.split_ix['train'].append(ix)
158
+
159
+ print('assigned %d images to split train' %len(self.split_ix['train']))
160
+ print('assigned %d images to split val' %len(self.split_ix['val']))
161
+ print('assigned %d images to split test' %len(self.split_ix['test']))
162
+
163
+ def get_captions(self, ix, seq_per_img):
164
+ # fetch the sequence labels
165
+ ix1 = self.label_start_ix[ix] - 1 #label_start_ix starts from 1
166
+ ix2 = self.label_end_ix[ix] - 1
167
+ ncap = ix2 - ix1 + 1 # number of captions available for this image
168
+ assert ncap > 0, 'an image does not have any label. this can be handled but right now isn\'t'
169
+
170
+ if ncap < seq_per_img:
171
+ # we need to subsample (with replacement)
172
+ seq = np.zeros([seq_per_img, self.seq_length], dtype = 'int')
173
+ for q in range(seq_per_img):
174
+ ixl = random.randint(ix1,ix2)
175
+ seq[q, :] = self.label[ixl, :self.seq_length]
176
+ else:
177
+ ixl = random.randint(ix1, ix2 - seq_per_img + 1)
178
+ seq = self.label[ixl: ixl + seq_per_img, :self.seq_length]
179
+
180
+ return seq
181
+
182
+ def collate_func(self, batch, split):
183
+ seq_per_img = self.seq_per_img
184
+
185
+ fc_batch = []
186
+ att_batch = []
187
+ label_batch = []
188
+
189
+ wrapped = False
190
+
191
+ infos = []
192
+ gts = []
193
+
194
+ for sample in batch:
195
+ # fetch image
196
+ tmp_fc, tmp_att, tmp_seq, \
197
+ ix, it_pos_now, tmp_wrapped = sample
198
+ if tmp_wrapped:
199
+ wrapped = True
200
+
201
+ fc_batch.append(tmp_fc)
202
+ att_batch.append(tmp_att)
203
+
204
+ tmp_label = np.zeros([seq_per_img, self.seq_length + 2], dtype = 'int')
205
+ if hasattr(self, 'h5_label_file'):
206
+ # if there is ground truth
207
+ tmp_label[:, 1 : self.seq_length + 1] = tmp_seq
208
+ label_batch.append(tmp_label)
209
+
210
+ # Used for reward evaluation
211
+ if hasattr(self, 'h5_label_file'):
212
+ # if there is ground truth
213
+ gts.append(self.label[self.label_start_ix[ix] - 1: self.label_end_ix[ix]])
214
+ else:
215
+ gts.append([])
216
+
217
+ # record associated info as well
218
+ info_dict = {}
219
+ info_dict['ix'] = ix
220
+ info_dict['id'] = self.info['images'][ix]['id']
221
+ info_dict['file_path'] = self.info['images'][ix].get('file_path', '')
222
+ infos.append(info_dict)
223
+
224
+ # #sort by att_feat length
225
+ # fc_batch, att_batch, label_batch, gts, infos = \
226
+ # zip(*sorted(zip(fc_batch, att_batch, np.vsplit(label_batch, batch_size), gts, infos), key=lambda x: len(x[1]), reverse=True))
227
+ fc_batch, att_batch, label_batch, gts, infos = \
228
+ zip(*sorted(zip(fc_batch, att_batch, label_batch, gts, infos), key=lambda x: 0, reverse=True))
229
+ data = {}
230
+ data['fc_feats'] = np.stack(fc_batch)
231
+ # merge att_feats
232
+ max_att_len = max([_.shape[0] for _ in att_batch])
233
+ data['att_feats'] = np.zeros([len(att_batch), max_att_len, att_batch[0].shape[1]], dtype = 'float32')
234
+ for i in range(len(att_batch)):
235
+ data['att_feats'][i, :att_batch[i].shape[0]] = att_batch[i]
236
+ data['att_masks'] = np.zeros(data['att_feats'].shape[:2], dtype='float32')
237
+ for i in range(len(att_batch)):
238
+ data['att_masks'][i, :att_batch[i].shape[0]] = 1
239
+ # set att_masks to None if attention features have same length
240
+ if data['att_masks'].sum() == data['att_masks'].size:
241
+ data['att_masks'] = None
242
+
243
+ data['labels'] = np.vstack(label_batch)
244
+ # generate mask
245
+ nonzeros = np.array(list(map(lambda x: (x != 0).sum()+2, data['labels'])))
246
+ mask_batch = np.zeros([data['labels'].shape[0], self.seq_length + 2], dtype = 'float32')
247
+ for ix, row in enumerate(mask_batch):
248
+ row[:nonzeros[ix]] = 1
249
+ data['masks'] = mask_batch
250
+ data['labels'] = data['labels'].reshape(len(batch), seq_per_img, -1)
251
+ data['masks'] = data['masks'].reshape(len(batch), seq_per_img, -1)
252
+
253
+ data['gts'] = gts # all ground truth captions of each images
254
+ data['bounds'] = {'it_pos_now': it_pos_now, # the it_pos_now of the last sample
255
+ 'it_max': len(self.split_ix[split]), 'wrapped': wrapped}
256
+ data['infos'] = infos
257
+
258
+ data = {k:torch.from_numpy(v) if type(v) is np.ndarray else v for k,v in data.items()} # Turn all ndarray to torch tensor
259
+
260
+ return data
261
+
262
+ def __getitem__(self, index):
263
+ """This function returns a tuple that is further passed to collate_fn
264
+ """
265
+ ix, it_pos_now, wrapped = index #self.split_ix[index]
266
+ if self.use_att:
267
+ att_feat = self.att_loader.get(str(self.info['images'][ix]['id']))
268
+ # Reshape to K x C
269
+ att_feat = att_feat.reshape(-1, att_feat.shape[-1])
270
+ if self.norm_att_feat:
271
+ att_feat = att_feat / np.linalg.norm(att_feat, 2, 1, keepdims=True)
272
+ if self.use_box:
273
+ box_feat = self.box_loader.get(str(self.info['images'][ix]['id']))
274
+ # devided by image width and height
275
+ x1,y1,x2,y2 = np.hsplit(box_feat, 4)
276
+ h,w = self.info['images'][ix]['height'], self.info['images'][ix]['width']
277
+ box_feat = np.hstack((x1/w, y1/h, x2/w, y2/h, (x2-x1)*(y2-y1)/(w*h))) # question? x2-x1+1??
278
+ if self.norm_box_feat:
279
+ box_feat = box_feat / np.linalg.norm(box_feat, 2, 1, keepdims=True)
280
+ att_feat = np.hstack([att_feat, box_feat])
281
+ # sort the features by the size of boxes
282
+ att_feat = np.stack(sorted(att_feat, key=lambda x:x[-1], reverse=True))
283
+ else:
284
+ att_feat = np.zeros((0,0), dtype='float32')
285
+ if self.use_fc:
286
+ try:
287
+ fc_feat = self.fc_loader.get(str(self.info['images'][ix]['id']))
288
+ except:
289
+ # Use average of attention when there is no fc provided (For bottomup feature)
290
+ fc_feat = att_feat.mean(0)
291
+ else:
292
+ fc_feat = np.zeros((0), dtype='float32')
293
+ if hasattr(self, 'h5_label_file'):
294
+ seq = self.get_captions(ix, self.seq_per_img)
295
+ else:
296
+ seq = None
297
+ return (fc_feat,
298
+ att_feat, seq,
299
+ ix, it_pos_now, wrapped)
300
+
301
+ def __len__(self):
302
+ return len(self.info['images'])
303
+
304
+ class DataLoader:
305
+ def __init__(self, opt):
306
+ self.opt = opt
307
+ self.batch_size = self.opt.batch_size
308
+ self.dataset = Dataset(opt)
309
+
310
+ # Initialize loaders and iters
311
+ self.loaders, self.iters = {}, {}
312
+ for split in ['train', 'val', 'test']:
313
+ if split == 'train':
314
+ sampler = MySampler(self.dataset.split_ix[split], shuffle=True, wrap=True)
315
+ else:
316
+ sampler = MySampler(self.dataset.split_ix[split], shuffle=False, wrap=False)
317
+ self.loaders[split] = data.DataLoader(dataset=self.dataset,
318
+ batch_size=self.batch_size,
319
+ sampler=sampler,
320
+ pin_memory=True,
321
+ num_workers=4, # 4 is usually enough
322
+ collate_fn=partial(self.dataset.collate_func, split=split),
323
+ drop_last=False)
324
+ self.iters[split] = iter(self.loaders[split])
325
+
326
+ def get_batch(self, split):
327
+ try:
328
+ data = next(self.iters[split])
329
+ except StopIteration:
330
+ self.iters[split] = iter(self.loaders[split])
331
+ data = next(self.iters[split])
332
+ return data
333
+
334
+ def reset_iterator(self, split):
335
+ self.loaders[split].sampler._reset_iter()
336
+ self.iters[split] = iter(self.loaders[split])
337
+
338
+ def get_vocab_size(self):
339
+ return self.dataset.get_vocab_size()
340
+
341
+ @property
342
+ def vocab_size(self):
343
+ return self.get_vocab_size()
344
+
345
+ def get_vocab(self):
346
+ return self.dataset.get_vocab()
347
+
348
+ def get_seq_length(self):
349
+ return self.dataset.get_seq_length()
350
+
351
+ @property
352
+ def seq_length(self):
353
+ return self.get_seq_length()
354
+
355
+ def state_dict(self):
356
+ def get_prefetch_num(split):
357
+ if self.loaders[split].num_workers > 0:
358
+ return (self.iters[split]._send_idx - self.iters[split]._rcvd_idx) * self.batch_size
359
+ else:
360
+ return 0
361
+ return {split: loader.sampler.state_dict(get_prefetch_num(split)) \
362
+ for split, loader in self.loaders.items()}
363
+
364
+ def load_state_dict(self, state_dict=None):
365
+ if state_dict is None:
366
+ return
367
+ for split in self.loaders.keys():
368
+ self.loaders[split].sampler.load_state_dict(state_dict[split])
369
+
370
+
371
+ class MySampler(data.sampler.Sampler):
372
+ def __init__(self, index_list, shuffle, wrap):
373
+ self.index_list = index_list
374
+ self.shuffle = shuffle
375
+ self.wrap = wrap
376
+ # if wrap, there will be not stop iteration called
377
+ # wrap True used during training, and wrap False used during test.
378
+ self._reset_iter()
379
+
380
+ def __iter__(self):
381
+ return self
382
+
383
+ def __next__(self):
384
+ wrapped = False
385
+ if self.iter_counter == len(self._index_list):
386
+ self._reset_iter()
387
+ if self.wrap:
388
+ wrapped = True
389
+ else:
390
+ raise StopIteration()
391
+ if len(self._index_list) == 0: # overflow when 0 samples
392
+ return None
393
+ elem = (self._index_list[self.iter_counter], self.iter_counter+1, wrapped)
394
+ self.iter_counter += 1
395
+ return elem
396
+
397
+ def next(self):
398
+ return self.__next__()
399
+
400
+ def _reset_iter(self):
401
+ if self.shuffle:
402
+ rand_perm = npr.permutation(len(self.index_list))
403
+ self._index_list = [self.index_list[_] for _ in rand_perm]
404
+ else:
405
+ self._index_list = self.index_list
406
+
407
+ self.iter_counter = 0
408
+
409
+ def __len__(self):
410
+ return len(self.index_list)
411
+
412
+ def load_state_dict(self, state_dict=None):
413
+ if state_dict is None:
414
+ return
415
+ self._index_list = state_dict['index_list']
416
+ self.iter_counter = state_dict['iter_counter']
417
+
418
+ def state_dict(self, prefetched_num=None):
419
+ prefetched_num = prefetched_num or 0
420
+ return {
421
+ 'index_list': self._index_list,
422
+ 'iter_counter': self.iter_counter - prefetched_num
423
+ }
424
+
425
+
captioning/data/pth_loader.py ADDED
@@ -0,0 +1,334 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import absolute_import
2
+ from __future__ import division
3
+ from __future__ import print_function
4
+
5
+ import json
6
+ import h5py
7
+ from lmdbdict import lmdbdict
8
+ from lmdbdict.methods import DUMPS_FUNC, LOADS_FUNC
9
+ import os
10
+ import numpy as np
11
+ import numpy.random as npr
12
+ import random
13
+
14
+ import torch
15
+ import torch.utils.data as data
16
+
17
+ import multiprocessing
18
+ import six
19
+
20
+ verbose = True
21
+ # import torch
22
+ # if torch.cuda.current_device() in [0, -1]:
23
+ if 'LOCAL_RANK' in os.environ and os.environ['LOCAL_RANK'] != '0':
24
+ verbose = False
25
+
26
+ class HybridLoader:
27
+ """
28
+ If db_path is a director, then use normal file loading
29
+ If lmdb, then load from lmdb
30
+ The loading method depend on extention.
31
+
32
+ in_memory: if in_memory is True, we save all the features in memory
33
+ For individual np(y|z)s, we don't need to do that because the system will do this for us.
34
+ Should be useful for lmdb or h5.
35
+ (Copied this idea from vilbert)
36
+ """
37
+ def __init__(self, db_path, ext, in_memory=False):
38
+ self.db_path = db_path
39
+ self.ext = ext
40
+ if self.ext == '.npy':
41
+ self.loader = lambda x: np.load(six.BytesIO(x))
42
+ else:
43
+ self.loader = lambda x: np.load(six.BytesIO(x))['feat']
44
+ if db_path.endswith('.lmdb'):
45
+ self.db_type = 'lmdb'
46
+ self.lmdb = lmdbdict(db_path, unsafe=True)
47
+ self.lmdb._key_dumps = DUMPS_FUNC['ascii']
48
+ self.lmdb._value_loads = LOADS_FUNC['identity']
49
+ elif db_path.endswith('.pth'): # Assume a key,value dictionary
50
+ self.db_type = 'pth'
51
+ self.feat_file = torch.load(db_path)
52
+ self.loader = lambda x: x
53
+ print('HybridLoader: ext is ignored')
54
+ elif db_path.endswith('h5'):
55
+ self.db_type = 'h5'
56
+ self.loader = lambda x: np.array(x).astype('float32')
57
+ else:
58
+ self.db_type = 'dir'
59
+
60
+ self.in_memory = in_memory
61
+ if self.in_memory:
62
+ self.features = {}
63
+
64
+ def get(self, key):
65
+
66
+ if self.in_memory and key in self.features:
67
+ # We save f_input because we want to save the
68
+ # compressed bytes to save memory
69
+ f_input = self.features[key]
70
+ elif self.db_type == 'lmdb':
71
+ f_input = self.lmdb[key]
72
+ elif self.db_type == 'pth':
73
+ f_input = self.feat_file[key]
74
+ elif self.db_type == 'h5':
75
+ f_input = h5py.File(self.db_path, 'r')[key]
76
+ else:
77
+ f_input = open(os.path.join(self.db_path, key + self.ext), 'rb').read()
78
+
79
+ if self.in_memory and key not in self.features:
80
+ self.features[key] = f_input
81
+
82
+ # load image
83
+ feat = self.loader(f_input)
84
+
85
+ return feat
86
+
87
+ class CaptionDataset(data.Dataset):
88
+
89
+ def get_vocab_size(self):
90
+ return self.vocab_size
91
+
92
+ def get_vocab(self):
93
+ return self.ix_to_word
94
+
95
+ def get_seq_length(self):
96
+ return self.seq_length
97
+
98
+ def __init__(self, opt):
99
+ self.opt = opt
100
+ self.seq_per_img = opt.seq_per_img
101
+
102
+ # feature related options
103
+ self.use_fc = getattr(opt, 'use_fc', True)
104
+ self.use_att = getattr(opt, 'use_att', True)
105
+ self.use_box = getattr(opt, 'use_box', 0)
106
+ self.norm_att_feat = getattr(opt, 'norm_att_feat', 0)
107
+ self.norm_box_feat = getattr(opt, 'norm_box_feat', 0)
108
+
109
+ # load the json file which contains additional information about the dataset
110
+ if verbose:
111
+ print('DataLoader loading json file: ', opt.input_json)
112
+ self.info = json.load(open(self.opt.input_json))
113
+ if 'ix_to_word' in self.info:
114
+ self.ix_to_word = self.info['ix_to_word']
115
+ self.vocab_size = len(self.ix_to_word)
116
+ if verbose:
117
+ print('vocab size is ', self.vocab_size)
118
+
119
+ # open the hdf5 file
120
+ if verbose:
121
+ print('DataLoader loading h5 file: ', opt.input_fc_dir, opt.input_att_dir, opt.input_box_dir, opt.input_label_h5)
122
+ """
123
+ Setting input_label_h5 to none is used when only doing generation.
124
+ For example, when you need to test on coco test set.
125
+ """
126
+ if self.opt.input_label_h5 != 'none':
127
+ self.h5_label_file = h5py.File(self.opt.input_label_h5, 'r', driver='core')
128
+ # load in the sequence data
129
+ seq_size = self.h5_label_file['labels'].shape
130
+ self.label = self.h5_label_file['labels'][:]
131
+ self.seq_length = seq_size[1]
132
+ if verbose:
133
+ print('max sequence length in data is', self.seq_length)
134
+ # load the pointers in full to RAM (should be small enough)
135
+ self.label_start_ix = self.h5_label_file['label_start_ix'][:]
136
+ self.label_end_ix = self.h5_label_file['label_end_ix'][:]
137
+ else:
138
+ self.seq_length = 1
139
+
140
+ self.data_in_memory = getattr(opt, 'data_in_memory', False)
141
+ self.fc_loader = HybridLoader(self.opt.input_fc_dir, '.npy', in_memory=self.data_in_memory)
142
+ self.att_loader = HybridLoader(self.opt.input_att_dir, '.npz', in_memory=self.data_in_memory)
143
+ self.box_loader = HybridLoader(self.opt.input_box_dir, '.npy', in_memory=self.data_in_memory)
144
+
145
+ self.use_clipscore = getattr(opt, 'use_clipscore', False)
146
+ # if self.use_clipscore:
147
+ self.clipscore_loader = HybridLoader(self.opt.input_clipscore_vis_dir, '.npy', in_memory=self.data_in_memory)
148
+
149
+
150
+ self.num_images = len(self.info['images']) # self.label_start_ix.shape[0]
151
+ if verbose:
152
+ print('read %d image features' %(self.num_images))
153
+
154
+ # separate out indexes for each of the provided splits
155
+ self.split_ix = {'train': [], 'val': [], 'test': []}
156
+ for ix in range(len(self.info['images'])):
157
+ img = self.info['images'][ix]
158
+ if not 'split' in img:
159
+ self.split_ix['train'].append(ix)
160
+ self.split_ix['val'].append(ix)
161
+ self.split_ix['test'].append(ix)
162
+ elif img['split'] == 'train':
163
+ self.split_ix['train'].append(ix)
164
+ elif img['split'] == 'val':
165
+ self.split_ix['val'].append(ix)
166
+ elif img['split'] == 'test':
167
+ self.split_ix['test'].append(ix)
168
+ elif opt.train_only == 0: # restval
169
+ self.split_ix['train'].append(ix)
170
+
171
+ if verbose:
172
+ print('assigned %d images to split train' %len(self.split_ix['train']))
173
+ print('assigned %d images to split val' %len(self.split_ix['val']))
174
+ print('assigned %d images to split test' %len(self.split_ix['test']))
175
+
176
+ def get_captions(self, ix, seq_per_img):
177
+ # fetch the sequence labels
178
+ ix1 = self.label_start_ix[ix] - 1 #label_start_ix starts from 1
179
+ ix2 = self.label_end_ix[ix] - 1
180
+ ncap = ix2 - ix1 + 1 # number of captions available for this image
181
+ assert ncap > 0, 'an image does not have any label. this can be handled but right now isn\'t'
182
+
183
+ if ncap < seq_per_img:
184
+ # we need to subsample (with replacement)
185
+ seq = np.zeros([seq_per_img, self.seq_length], dtype = 'int')
186
+ for q in range(seq_per_img):
187
+ ixl = random.randint(ix1,ix2)
188
+ seq[q, :] = self.label[ixl, :self.seq_length]
189
+ else:
190
+ ixl = random.randint(ix1, ix2 - seq_per_img + 1)
191
+ seq = self.label[ixl: ixl + seq_per_img, :self.seq_length]
192
+
193
+ return seq
194
+
195
+ def collate_func(self, batch):
196
+ seq_per_img = self.seq_per_img
197
+
198
+ fc_batch = []
199
+ att_batch = []
200
+ label_batch = []
201
+
202
+ clip_vis_feat_batch = []
203
+
204
+ wrapped = False
205
+
206
+ infos = []
207
+ gts = []
208
+
209
+ for sample in batch:
210
+ # fetch image
211
+ # if self.use_clipscore:
212
+ tmp_fc, tmp_att, tmp_seq, \
213
+ ix, tmp_clip_vis_feat = sample
214
+
215
+ clip_vis_feat_batch.append(tmp_clip_vis_feat)
216
+ # else:
217
+ # tmp_fc, tmp_att, tmp_seq, \
218
+ # ix = sample
219
+
220
+ fc_batch.append(tmp_fc)
221
+ att_batch.append(tmp_att)
222
+
223
+ tmp_label = np.zeros([seq_per_img, self.seq_length + 2], dtype = 'int')
224
+ if hasattr(self, 'h5_label_file'):
225
+ # if there is ground truth
226
+ tmp_label[:, 1 : self.seq_length + 1] = tmp_seq
227
+ label_batch.append(tmp_label)
228
+
229
+ # Used for reward evaluation
230
+ if hasattr(self, 'h5_label_file'):
231
+ # if there is ground truth
232
+ gts.append(self.label[self.label_start_ix[ix] - 1: self.label_end_ix[ix]])
233
+ else:
234
+ gts.append([])
235
+
236
+ # record associated info as well
237
+ info_dict = {}
238
+ info_dict['ix'] = ix
239
+ info_dict['id'] = self.info['images'][ix]['id']
240
+ info_dict['file_path'] = self.info['images'][ix].get('file_path', '')
241
+ infos.append(info_dict)
242
+
243
+ # #sort by att_feat length
244
+ # fc_batch, att_batch, label_batch, gts, infos = \
245
+ # zip(*sorted(zip(fc_batch, att_batch, np.vsplit(label_batch, batch_size), gts, infos), key=lambda x: len(x[1]), reverse=True))
246
+ if self.use_clipscore:
247
+ fc_batch, att_batch, label_batch, clip_vis_feat_batch, gts, infos = \
248
+ zip(*sorted(zip(fc_batch, att_batch, label_batch, clip_vis_feat_batch, gts, infos), key=lambda x: 0, reverse=True))
249
+ else:
250
+ fc_batch, att_batch, label_batch, gts, infos = \
251
+ zip(*sorted(zip(fc_batch, att_batch, label_batch, gts, infos), key=lambda x: 0, reverse=True))
252
+ data = {}
253
+ data['fc_feats'] = np.stack(fc_batch)
254
+ # merge att_feats
255
+ max_att_len = max([_.shape[0] for _ in att_batch])
256
+ data['att_feats'] = np.zeros([len(att_batch), max_att_len, att_batch[0].shape[1]], dtype = 'float32')
257
+ for i in range(len(att_batch)):
258
+ data['att_feats'][i, :att_batch[i].shape[0]] = att_batch[i]
259
+ data['att_masks'] = np.zeros(data['att_feats'].shape[:2], dtype='float32')
260
+ for i in range(len(att_batch)):
261
+ data['att_masks'][i, :att_batch[i].shape[0]] = 1
262
+ # set att_masks to None if attention features have same length
263
+ if data['att_masks'].sum() == data['att_masks'].size:
264
+ data['att_masks'] = None
265
+
266
+ # if self.use_clipscore:
267
+ data['clip_vis_feats'] = np.stack(clip_vis_feat_batch)
268
+
269
+ data['labels'] = np.vstack(label_batch)
270
+ # generate mask
271
+ nonzeros = np.array(list(map(lambda x: (x != 0).sum()+2, data['labels'])))
272
+ mask_batch = np.zeros([data['labels'].shape[0], self.seq_length + 2], dtype = 'float32')
273
+ for ix, row in enumerate(mask_batch):
274
+ row[:nonzeros[ix]] = 1
275
+ data['masks'] = mask_batch
276
+ data['labels'] = data['labels'].reshape(len(batch), seq_per_img, -1)
277
+ data['masks'] = data['masks'].reshape(len(batch), seq_per_img, -1)
278
+
279
+ data['gts'] = gts # all ground truth captions of each images
280
+ data['infos'] = infos
281
+
282
+ data = {k:torch.from_numpy(v) if type(v) is np.ndarray else v for k,v in data.items()} # Turn all ndarray to torch tensor
283
+
284
+ return data
285
+
286
+ def __getitem__(self, ix):
287
+ """This function returns a tuple that is further passed to collate_fn
288
+ """
289
+ if self.use_att:
290
+ att_feat = self.att_loader.get(str(self.info['images'][ix]['id']))
291
+ # Reshape to K x C
292
+ att_feat = att_feat.reshape(-1, att_feat.shape[-1])
293
+ if self.norm_att_feat:
294
+ att_feat = att_feat / np.linalg.norm(att_feat, 2, 1, keepdims=True)
295
+ if self.use_box:
296
+ box_feat = self.box_loader.get(str(self.info['images'][ix]['id']))
297
+ # devided by image width and height
298
+ x1,y1,x2,y2 = np.hsplit(box_feat, 4)
299
+ h,w = self.info['images'][ix]['height'], self.info['images'][ix]['width']
300
+ box_feat = np.hstack((x1/w, y1/h, x2/w, y2/h, (x2-x1)*(y2-y1)/(w*h))) # question? x2-x1+1??
301
+ if self.norm_box_feat:
302
+ box_feat = box_feat / np.linalg.norm(box_feat, 2, 1, keepdims=True)
303
+ att_feat = np.hstack([att_feat, box_feat])
304
+ # sort the features by the size of boxes
305
+ att_feat = np.stack(sorted(att_feat, key=lambda x:x[-1], reverse=True))
306
+ else:
307
+ att_feat = np.zeros((0,0), dtype='float32')
308
+ if self.use_fc:
309
+ try:
310
+ fc_feat = self.fc_loader.get(str(self.info['images'][ix]['id']))
311
+ except:
312
+ # Use average of attention when there is no fc provided (For bottomup feature)
313
+ fc_feat = att_feat.mean(0)
314
+ else:
315
+ fc_feat = np.zeros((0), dtype='float32')
316
+ if hasattr(self, 'h5_label_file'):
317
+ seq = self.get_captions(ix, self.seq_per_img)
318
+ else:
319
+ seq = None
320
+
321
+ # if self.use_clipscore:
322
+ clip_vis_feat = self.clipscore_loader.get(
323
+ str(self.info['images'][ix]['id']))
324
+
325
+ return (fc_feat,
326
+ att_feat, seq,
327
+ ix, clip_vis_feat)
328
+
329
+ # return (fc_feat,
330
+ # att_feat, seq,
331
+ # ix)
332
+
333
+ def __len__(self):
334
+ return len(self.info['images'])
captioning/data/pth_loader_FineCapEval.py ADDED
@@ -0,0 +1,334 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import absolute_import
2
+ from __future__ import division
3
+ from __future__ import print_function
4
+
5
+ import json
6
+ import h5py
7
+ from lmdbdict import lmdbdict
8
+ from lmdbdict.methods import DUMPS_FUNC, LOADS_FUNC
9
+ import os
10
+ import numpy as np
11
+ import numpy.random as npr
12
+ import random
13
+
14
+ import torch
15
+ import torch.utils.data as data
16
+
17
+ import multiprocessing
18
+ import six
19
+
20
+ verbose = True
21
+ # import torch
22
+ # if torch.cuda.current_device() in [0, -1]:
23
+ if 'LOCAL_RANK' in os.environ and os.environ['LOCAL_RANK'] != '0':
24
+ verbose = False
25
+
26
+ class HybridLoader:
27
+ """
28
+ If db_path is a director, then use normal file loading
29
+ If lmdb, then load from lmdb
30
+ The loading method depend on extention.
31
+
32
+ in_memory: if in_memory is True, we save all the features in memory
33
+ For individual np(y|z)s, we don't need to do that because the system will do this for us.
34
+ Should be useful for lmdb or h5.
35
+ (Copied this idea from vilbert)
36
+ """
37
+ def __init__(self, db_path, ext, in_memory=False):
38
+ self.db_path = db_path
39
+ self.ext = ext
40
+ if self.ext == '.npy':
41
+ self.loader = lambda x: np.load(six.BytesIO(x))
42
+ else:
43
+ self.loader = lambda x: np.load(six.BytesIO(x))['feat']
44
+ if db_path.endswith('.lmdb'):
45
+ self.db_type = 'lmdb'
46
+ self.lmdb = lmdbdict(db_path, unsafe=True)
47
+ self.lmdb._key_dumps = DUMPS_FUNC['ascii']
48
+ self.lmdb._value_loads = LOADS_FUNC['identity']
49
+ elif db_path.endswith('.pth'): # Assume a key,value dictionary
50
+ self.db_type = 'pth'
51
+ self.feat_file = torch.load(db_path)
52
+ self.loader = lambda x: x
53
+ print('HybridLoader: ext is ignored')
54
+ elif db_path.endswith('h5'):
55
+ self.db_type = 'h5'
56
+ self.loader = lambda x: np.array(x).astype('float32')
57
+ else:
58
+ self.db_type = 'dir'
59
+
60
+ self.in_memory = in_memory
61
+ if self.in_memory:
62
+ self.features = {}
63
+
64
+ def get(self, key):
65
+
66
+ if self.in_memory and key in self.features:
67
+ # We save f_input because we want to save the
68
+ # compressed bytes to save memory
69
+ f_input = self.features[key]
70
+ elif self.db_type == 'lmdb':
71
+ f_input = self.lmdb[key]
72
+ elif self.db_type == 'pth':
73
+ f_input = self.feat_file[key]
74
+ elif self.db_type == 'h5':
75
+ f_input = h5py.File(self.db_path, 'r')[key]
76
+ else:
77
+ f_input = open(os.path.join(self.db_path, key + self.ext), 'rb').read()
78
+
79
+ if self.in_memory and key not in self.features:
80
+ self.features[key] = f_input
81
+
82
+ # load image
83
+ feat = self.loader(f_input)
84
+
85
+ return feat
86
+
87
+ class CaptionDataset(data.Dataset):
88
+
89
+ def get_vocab_size(self):
90
+ return self.vocab_size
91
+
92
+ def get_vocab(self):
93
+ return self.ix_to_word
94
+
95
+ def get_seq_length(self):
96
+ return self.seq_length
97
+
98
+ def __init__(self, opt):
99
+ self.opt = opt
100
+ self.seq_per_img = opt.seq_per_img
101
+
102
+ # feature related options
103
+ self.use_fc = getattr(opt, 'use_fc', True)
104
+ self.use_att = getattr(opt, 'use_att', True)
105
+ self.use_box = getattr(opt, 'use_box', 0)
106
+ self.norm_att_feat = getattr(opt, 'norm_att_feat', 0)
107
+ self.norm_box_feat = getattr(opt, 'norm_box_feat', 0)
108
+
109
+ # load the json file which contains additional information about the dataset
110
+ if verbose:
111
+ print('DataLoader loading json file: ', opt.input_json)
112
+ self.info = json.load(open(self.opt.input_json))
113
+ if 'ix_to_word' in self.info:
114
+ self.ix_to_word = self.info['ix_to_word']
115
+ self.vocab_size = len(self.ix_to_word)
116
+ if verbose:
117
+ print('vocab size is ', self.vocab_size)
118
+
119
+ # open the hdf5 file
120
+ if verbose:
121
+ print('DataLoader loading h5 file: ', opt.input_fc_dir, opt.input_att_dir, opt.input_box_dir, opt.input_label_h5)
122
+ """
123
+ Setting input_label_h5 to none is used when only doing generation.
124
+ For example, when you need to test on coco test set.
125
+ """
126
+ if self.opt.input_label_h5 != 'none':
127
+ self.h5_label_file = h5py.File(self.opt.input_label_h5, 'r', driver='core')
128
+ # load in the sequence data
129
+ seq_size = self.h5_label_file['labels'].shape
130
+ self.label = self.h5_label_file['labels'][:]
131
+ self.seq_length = seq_size[1]
132
+ if verbose:
133
+ print('max sequence length in data is', self.seq_length)
134
+ # load the pointers in full to RAM (should be small enough)
135
+ self.label_start_ix = self.h5_label_file['label_start_ix'][:]
136
+ self.label_end_ix = self.h5_label_file['label_end_ix'][:]
137
+ else:
138
+ self.seq_length = 1
139
+
140
+ self.data_in_memory = getattr(opt, 'data_in_memory', False)
141
+ self.fc_loader = HybridLoader(self.opt.input_fc_dir, '.npy', in_memory=self.data_in_memory)
142
+ self.att_loader = HybridLoader(self.opt.input_att_dir, '.npz', in_memory=self.data_in_memory)
143
+ self.box_loader = HybridLoader(self.opt.input_box_dir, '.npy', in_memory=self.data_in_memory)
144
+
145
+ self.use_clipscore = getattr(opt, 'use_clipscore', False)
146
+ if self.use_clipscore:
147
+ self.clipscore_loader = HybridLoader(self.opt.input_clipscore_vis_dir, '.npy', in_memory=self.data_in_memory)
148
+
149
+
150
+ self.num_images = len(self.info['images']) # self.label_start_ix.shape[0]
151
+ if verbose:
152
+ print('read %d image features' %(self.num_images))
153
+
154
+ # separate out indexes for each of the provided splits
155
+ self.split_ix = {'train': [], 'val': [], 'test': []}
156
+ for ix in range(len(self.info['images'])):
157
+ img = self.info['images'][ix]
158
+ if not 'split' in img:
159
+ self.split_ix['train'].append(ix)
160
+ self.split_ix['val'].append(ix)
161
+ self.split_ix['test'].append(ix)
162
+ elif img['split'] == 'train':
163
+ self.split_ix['train'].append(ix)
164
+ elif img['split'] == 'val':
165
+ self.split_ix['val'].append(ix)
166
+ elif img['split'] == 'test':
167
+ self.split_ix['test'].append(ix)
168
+ elif opt.train_only == 0: # restval
169
+ self.split_ix['train'].append(ix)
170
+
171
+ if verbose:
172
+ print('assigned %d images to split train' %len(self.split_ix['train']))
173
+ print('assigned %d images to split val' %len(self.split_ix['val']))
174
+ print('assigned %d images to split test' %len(self.split_ix['test']))
175
+
176
+ def get_captions(self, ix, seq_per_img):
177
+ # fetch the sequence labels
178
+ ix1 = self.label_start_ix[ix] - 1 #label_start_ix starts from 1
179
+ ix2 = self.label_end_ix[ix] - 1
180
+ ncap = ix2 - ix1 + 1 # number of captions available for this image
181
+ assert ncap > 0, 'an image does not have any label. this can be handled but right now isn\'t'
182
+
183
+ if ncap < seq_per_img:
184
+ # we need to subsample (with replacement)
185
+ seq = np.zeros([seq_per_img, self.seq_length], dtype = 'int')
186
+ for q in range(seq_per_img):
187
+ ixl = random.randint(ix1,ix2)
188
+ seq[q, :] = self.label[ixl, :self.seq_length]
189
+ else:
190
+ ixl = random.randint(ix1, ix2 - seq_per_img + 1)
191
+ seq = self.label[ixl: ixl + seq_per_img, :self.seq_length]
192
+
193
+ return seq
194
+
195
+ def collate_func(self, batch):
196
+ seq_per_img = self.seq_per_img
197
+
198
+ fc_batch = []
199
+ att_batch = []
200
+ label_batch = []
201
+
202
+ clip_vis_feat_batch = []
203
+
204
+ wrapped = False
205
+
206
+ infos = []
207
+ gts = []
208
+
209
+ for sample in batch:
210
+ # fetch image
211
+ if self.use_clipscore:
212
+ tmp_fc, tmp_att, tmp_seq, \
213
+ ix, tmp_clip_vis_feat = sample
214
+
215
+ clip_vis_feat_batch.append(tmp_clip_vis_feat)
216
+ else:
217
+ tmp_fc, tmp_att, tmp_seq, \
218
+ ix = sample
219
+
220
+ fc_batch.append(tmp_fc)
221
+ att_batch.append(tmp_att)
222
+
223
+ tmp_label = np.zeros([seq_per_img, self.seq_length + 2], dtype = 'int')
224
+ if hasattr(self, 'h5_label_file'):
225
+ # if there is ground truth
226
+ tmp_label[:, 1 : self.seq_length + 1] = tmp_seq
227
+ label_batch.append(tmp_label)
228
+
229
+ # Used for reward evaluation
230
+ if hasattr(self, 'h5_label_file'):
231
+ # if there is ground truth
232
+ gts.append(self.label[self.label_start_ix[ix] - 1: self.label_end_ix[ix]])
233
+ else:
234
+ gts.append([])
235
+
236
+ # record associated info as well
237
+ info_dict = {}
238
+ info_dict['ix'] = ix
239
+ info_dict['id'] = self.info['images'][ix]['id']
240
+ info_dict['file_path'] = self.info['images'][ix].get('file_path', '')
241
+ infos.append(info_dict)
242
+
243
+ # #sort by att_feat length
244
+ # fc_batch, att_batch, label_batch, gts, infos = \
245
+ # zip(*sorted(zip(fc_batch, att_batch, np.vsplit(label_batch, batch_size), gts, infos), key=lambda x: len(x[1]), reverse=True))
246
+ if self.use_clipscore:
247
+ fc_batch, att_batch, label_batch, clip_vis_feat_batch, gts, infos = \
248
+ zip(*sorted(zip(fc_batch, att_batch, label_batch, clip_vis_feat_batch, gts, infos), key=lambda x: 0, reverse=True))
249
+ else:
250
+ fc_batch, att_batch, label_batch, gts, infos = \
251
+ zip(*sorted(zip(fc_batch, att_batch, label_batch, gts, infos), key=lambda x: 0, reverse=True))
252
+ data = {}
253
+ data['fc_feats'] = np.stack(fc_batch)
254
+ # merge att_feats
255
+ max_att_len = max([_.shape[0] for _ in att_batch])
256
+ data['att_feats'] = np.zeros([len(att_batch), max_att_len, att_batch[0].shape[1]], dtype = 'float32')
257
+ for i in range(len(att_batch)):
258
+ data['att_feats'][i, :att_batch[i].shape[0]] = att_batch[i]
259
+ data['att_masks'] = np.zeros(data['att_feats'].shape[:2], dtype='float32')
260
+ for i in range(len(att_batch)):
261
+ data['att_masks'][i, :att_batch[i].shape[0]] = 1
262
+ # set att_masks to None if attention features have same length
263
+ if data['att_masks'].sum() == data['att_masks'].size:
264
+ data['att_masks'] = None
265
+
266
+ if self.use_clipscore:
267
+ data['clip_vis_feats'] = np.stack(clip_vis_feat_batch)
268
+
269
+ data['labels'] = np.vstack(label_batch)
270
+ # generate mask
271
+ nonzeros = np.array(list(map(lambda x: (x != 0).sum()+2, data['labels'])))
272
+ mask_batch = np.zeros([data['labels'].shape[0], self.seq_length + 2], dtype = 'float32')
273
+ for ix, row in enumerate(mask_batch):
274
+ row[:nonzeros[ix]] = 1
275
+ data['masks'] = mask_batch
276
+ data['labels'] = data['labels'].reshape(len(batch), seq_per_img, -1)
277
+ data['masks'] = data['masks'].reshape(len(batch), seq_per_img, -1)
278
+
279
+ data['gts'] = gts # all ground truth captions of each images
280
+ data['infos'] = infos
281
+
282
+ data = {k:torch.from_numpy(v) if type(v) is np.ndarray else v for k,v in data.items()} # Turn all ndarray to torch tensor
283
+
284
+ return data
285
+
286
+ def __getitem__(self, ix):
287
+ """This function returns a tuple that is further passed to collate_fn
288
+ """
289
+ if self.use_att:
290
+ att_feat = self.att_loader.get(str(self.info['images'][ix]['id']))
291
+ # Reshape to K x C
292
+ att_feat = att_feat.reshape(-1, att_feat.shape[-1])
293
+ if self.norm_att_feat:
294
+ att_feat = att_feat / np.linalg.norm(att_feat, 2, 1, keepdims=True)
295
+ if self.use_box:
296
+ box_feat = self.box_loader.get(str(self.info['images'][ix]['id']))
297
+ # devided by image width and height
298
+ x1,y1,x2,y2 = np.hsplit(box_feat, 4)
299
+ h,w = self.info['images'][ix]['height'], self.info['images'][ix]['width']
300
+ box_feat = np.hstack((x1/w, y1/h, x2/w, y2/h, (x2-x1)*(y2-y1)/(w*h))) # question? x2-x1+1??
301
+ if self.norm_box_feat:
302
+ box_feat = box_feat / np.linalg.norm(box_feat, 2, 1, keepdims=True)
303
+ att_feat = np.hstack([att_feat, box_feat])
304
+ # sort the features by the size of boxes
305
+ att_feat = np.stack(sorted(att_feat, key=lambda x:x[-1], reverse=True))
306
+ else:
307
+ att_feat = np.zeros((0,0), dtype='float32')
308
+ if self.use_fc:
309
+ try:
310
+ fc_feat = self.fc_loader.get(str(self.info['images'][ix]['id']))
311
+ except:
312
+ # Use average of attention when there is no fc provided (For bottomup feature)
313
+ fc_feat = att_feat.mean(0)
314
+ else:
315
+ fc_feat = np.zeros((0), dtype='float32')
316
+ if hasattr(self, 'h5_label_file'):
317
+ seq = self.get_captions(ix, self.seq_per_img)
318
+ else:
319
+ seq = None
320
+
321
+ if self.use_clipscore:
322
+ clip_vis_feat = self.clipscore_loader.get(
323
+ str(self.info['images'][ix]['id']))
324
+
325
+ return (fc_feat,
326
+ att_feat, seq,
327
+ ix, clip_vis_feat)
328
+
329
+ return (fc_feat,
330
+ att_feat, seq,
331
+ ix)
332
+
333
+ def __len__(self):
334
+ return len(self.info['images'])
captioning/models/AoAModel.py ADDED
@@ -0,0 +1,228 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Implementation for paper 'Attention on Attention for Image Captioning'
2
+ # https://arxiv.org/abs/1908.06954
3
+
4
+ # RT: Code from original author's repo: https://github.com/husthuaan/AoANet/
5
+
6
+ from __future__ import absolute_import
7
+ from __future__ import division
8
+ from __future__ import print_function
9
+
10
+ import torch
11
+ import torch.nn as nn
12
+ import torch.nn.functional as F
13
+
14
+ from .AttModel import pack_wrapper, AttModel, Attention
15
+ from .TransformerModel import LayerNorm, attention, clones, SublayerConnection, PositionwiseFeedForward
16
+
17
+ class MultiHeadedDotAttention(nn.Module):
18
+ def __init__(self, h, d_model, dropout=0.1, scale=1, project_k_v=1, use_output_layer=1, do_aoa=0, norm_q=0, dropout_aoa=0.3):
19
+ super(MultiHeadedDotAttention, self).__init__()
20
+ assert d_model * scale % h == 0
21
+ # We assume d_v always equals d_k
22
+ self.d_k = d_model * scale // h
23
+ self.h = h
24
+
25
+ # Do we need to do linear projections on K and V?
26
+ self.project_k_v = project_k_v
27
+
28
+ # normalize the query?
29
+ if norm_q:
30
+ self.norm = LayerNorm(d_model)
31
+ else:
32
+ self.norm = lambda x:x
33
+ self.linears = clones(nn.Linear(d_model, d_model * scale), 1 + 2 * project_k_v)
34
+
35
+ # output linear layer after the multi-head attention?
36
+ self.output_layer = nn.Linear(d_model * scale, d_model)
37
+
38
+ # apply aoa after attention?
39
+ self.use_aoa = do_aoa
40
+ if self.use_aoa:
41
+ self.aoa_layer = nn.Sequential(nn.Linear((1 + scale) * d_model, 2 * d_model), nn.GLU())
42
+ # dropout to the input of AoA layer
43
+ if dropout_aoa > 0:
44
+ self.dropout_aoa = nn.Dropout(p=dropout_aoa)
45
+ else:
46
+ self.dropout_aoa = lambda x:x
47
+
48
+ if self.use_aoa or not use_output_layer:
49
+ # AoA doesn't need the output linear layer
50
+ del self.output_layer
51
+ self.output_layer = lambda x:x
52
+
53
+ self.attn = None
54
+ self.dropout = nn.Dropout(p=dropout)
55
+
56
+ def forward(self, query, value, key, mask=None):
57
+ if mask is not None:
58
+ if len(mask.size()) == 2:
59
+ mask = mask.unsqueeze(-2)
60
+ # Same mask applied to all h heads.
61
+ mask = mask.unsqueeze(1)
62
+
63
+ single_query = 0
64
+ if len(query.size()) == 2:
65
+ single_query = 1
66
+ query = query.unsqueeze(1)
67
+
68
+ nbatches = query.size(0)
69
+
70
+ query = self.norm(query)
71
+
72
+ # Do all the linear projections in batch from d_model => h x d_k
73
+ if self.project_k_v == 0:
74
+ query_ = self.linears[0](query).view(nbatches, -1, self.h, self.d_k).transpose(1, 2)
75
+ key_ = key.view(nbatches, -1, self.h, self.d_k).transpose(1, 2)
76
+ value_ = value.view(nbatches, -1, self.h, self.d_k).transpose(1, 2)
77
+ else:
78
+ query_, key_, value_ = \
79
+ [l(x).view(nbatches, -1, self.h, self.d_k).transpose(1, 2)
80
+ for l, x in zip(self.linears, (query, key, value))]
81
+
82
+ # Apply attention on all the projected vectors in batch.
83
+ x, self.attn = attention(query_, key_, value_, mask=mask,
84
+ dropout=self.dropout)
85
+
86
+ # "Concat" using a view
87
+ x = x.transpose(1, 2).contiguous() \
88
+ .view(nbatches, -1, self.h * self.d_k)
89
+
90
+ if self.use_aoa:
91
+ # Apply AoA
92
+ x = self.aoa_layer(self.dropout_aoa(torch.cat([x, query], -1)))
93
+ x = self.output_layer(x)
94
+
95
+ if single_query:
96
+ query = query.squeeze(1)
97
+ x = x.squeeze(1)
98
+ return x
99
+
100
+ class AoA_Refiner_Layer(nn.Module):
101
+ def __init__(self, size, self_attn, feed_forward, dropout):
102
+ super(AoA_Refiner_Layer, self).__init__()
103
+ self.self_attn = self_attn
104
+ self.feed_forward = feed_forward
105
+ self.use_ff = 0
106
+ if self.feed_forward is not None:
107
+ self.use_ff = 1
108
+ self.sublayer = clones(SublayerConnection(size, dropout), 1+self.use_ff)
109
+ self.size = size
110
+
111
+ def forward(self, x, mask):
112
+ x = self.sublayer[0](x, lambda x: self.self_attn(x, x, x, mask))
113
+ return self.sublayer[-1](x, self.feed_forward) if self.use_ff else x
114
+
115
+ class AoA_Refiner_Core(nn.Module):
116
+ def __init__(self, opt):
117
+ super(AoA_Refiner_Core, self).__init__()
118
+ attn = MultiHeadedDotAttention(opt.num_heads, opt.rnn_size, project_k_v=1, scale=opt.multi_head_scale, do_aoa=opt.refine_aoa, norm_q=0, dropout_aoa=getattr(opt, 'dropout_aoa', 0.3))
119
+ layer = AoA_Refiner_Layer(opt.rnn_size, attn, PositionwiseFeedForward(opt.rnn_size, 2048, 0.1) if opt.use_ff else None, 0.1)
120
+ self.layers = clones(layer, 6)
121
+ self.norm = LayerNorm(layer.size)
122
+
123
+ def forward(self, x, mask):
124
+ for layer in self.layers:
125
+ x = layer(x, mask)
126
+ return self.norm(x)
127
+
128
+ class AoA_Decoder_Core(nn.Module):
129
+ def __init__(self, opt):
130
+ super(AoA_Decoder_Core, self).__init__()
131
+ self.drop_prob_lm = opt.drop_prob_lm
132
+ self.d_model = opt.rnn_size
133
+ self.use_multi_head = opt.use_multi_head
134
+ self.multi_head_scale = opt.multi_head_scale
135
+ self.use_ctx_drop = getattr(opt, 'ctx_drop', 0)
136
+ self.out_res = getattr(opt, 'out_res', 0)
137
+ self.decoder_type = getattr(opt, 'decoder_type', 'AoA')
138
+ self.att_lstm = nn.LSTMCell(opt.input_encoding_size + opt.rnn_size, opt.rnn_size) # we, fc, h^2_t-1
139
+ self.out_drop = nn.Dropout(self.drop_prob_lm)
140
+
141
+ if self.decoder_type == 'AoA':
142
+ # AoA layer
143
+ self.att2ctx = nn.Sequential(nn.Linear(self.d_model * opt.multi_head_scale + opt.rnn_size, 2 * opt.rnn_size), nn.GLU())
144
+ elif self.decoder_type == 'LSTM':
145
+ # LSTM layer
146
+ self.att2ctx = nn.LSTMCell(self.d_model * opt.multi_head_scale + opt.rnn_size, opt.rnn_size)
147
+ else:
148
+ # Base linear layer
149
+ self.att2ctx = nn.Sequential(nn.Linear(self.d_model * opt.multi_head_scale + opt.rnn_size, opt.rnn_size), nn.ReLU())
150
+
151
+ # if opt.use_multi_head == 1: # TODO, not implemented for now
152
+ # self.attention = MultiHeadedAddAttention(opt.num_heads, opt.d_model, scale=opt.multi_head_scale)
153
+ if opt.use_multi_head == 2:
154
+ self.attention = MultiHeadedDotAttention(opt.num_heads, opt.rnn_size, project_k_v=0, scale=opt.multi_head_scale, use_output_layer=0, do_aoa=0, norm_q=1)
155
+ else:
156
+ self.attention = Attention(opt)
157
+
158
+ if self.use_ctx_drop:
159
+ self.ctx_drop = nn.Dropout(self.drop_prob_lm)
160
+ else:
161
+ self.ctx_drop = lambda x :x
162
+
163
+ def forward(self, xt, mean_feats, att_feats, p_att_feats, state, att_masks=None):
164
+ # state[0][1] is the context vector at the last step
165
+ h_att, c_att = self.att_lstm(torch.cat([xt, mean_feats + self.ctx_drop(state[0][1])], 1), (state[0][0], state[1][0]))
166
+
167
+ if self.use_multi_head == 2:
168
+ att = self.attention(h_att, p_att_feats.narrow(2, 0, self.multi_head_scale * self.d_model), p_att_feats.narrow(2, self.multi_head_scale * self.d_model, self.multi_head_scale * self.d_model), att_masks)
169
+ else:
170
+ att = self.attention(h_att, att_feats, p_att_feats, att_masks)
171
+
172
+ ctx_input = torch.cat([att, h_att], 1)
173
+ if self.decoder_type == 'LSTM':
174
+ output, c_logic = self.att2ctx(ctx_input, (state[0][1], state[1][1]))
175
+ state = (torch.stack((h_att, output)), torch.stack((c_att, c_logic)))
176
+ else:
177
+ output = self.att2ctx(ctx_input)
178
+ # save the context vector to state[0][1]
179
+ state = (torch.stack((h_att, output)), torch.stack((c_att, state[1][1])))
180
+
181
+ if self.out_res:
182
+ # add residual connection
183
+ output = output + h_att
184
+
185
+ output = self.out_drop(output)
186
+ return output, state
187
+
188
+ class AoAModel(AttModel):
189
+ def __init__(self, opt):
190
+ super(AoAModel, self).__init__(opt)
191
+ self.num_layers = 2
192
+ # mean pooling
193
+ self.use_mean_feats = getattr(opt, 'mean_feats', 1)
194
+ if opt.use_multi_head == 2:
195
+ del self.ctx2att
196
+ self.ctx2att = nn.Linear(opt.rnn_size, 2 * opt.multi_head_scale * opt.rnn_size)
197
+
198
+ if self.use_mean_feats:
199
+ del self.fc_embed
200
+ if opt.refine:
201
+ self.refiner = AoA_Refiner_Core(opt)
202
+ else:
203
+ self.refiner = lambda x,y : x
204
+ self.core = AoA_Decoder_Core(opt)
205
+
206
+ self.d_model = getattr(opt, 'd_model', opt.input_encoding_size)
207
+
208
+
209
+ def _prepare_feature(self, fc_feats, att_feats, att_masks):
210
+ att_feats, att_masks = self.clip_att(att_feats, att_masks)
211
+
212
+ # embed att feats
213
+ att_feats = pack_wrapper(self.att_embed, att_feats, att_masks)
214
+ att_feats = self.refiner(att_feats, att_masks)
215
+
216
+ if self.use_mean_feats:
217
+ # meaning pooling
218
+ if att_masks is None:
219
+ mean_feats = torch.mean(att_feats, dim=1)
220
+ else:
221
+ mean_feats = (torch.sum(att_feats * att_masks.unsqueeze(-1), 1) / torch.sum(att_masks.unsqueeze(-1), 1))
222
+ else:
223
+ mean_feats = self.fc_embed(fc_feats)
224
+
225
+ # Project the attention feats first to reduce memory and computation.
226
+ p_att_feats = self.ctx2att(att_feats)
227
+
228
+ return mean_feats, att_feats, p_att_feats, att_masks
captioning/models/AttEnsemble.py ADDED
@@ -0,0 +1,90 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # This file is the implementation for ensemble evaluation.
2
+
3
+ from __future__ import absolute_import
4
+ from __future__ import division
5
+ from __future__ import print_function
6
+
7
+ import numpy as np
8
+ import torch
9
+ import torch.nn as nn
10
+ import torch.nn.functional as F
11
+ from torch.autograd import *
12
+
13
+ from .CaptionModel import CaptionModel
14
+ from .AttModel import pack_wrapper, AttModel
15
+
16
+ class AttEnsemble(AttModel):
17
+ def __init__(self, models, weights=None):
18
+ CaptionModel.__init__(self)
19
+ # super(AttEnsemble, self).__init__()
20
+
21
+ self.models = nn.ModuleList(models)
22
+ self.vocab_size = models[0].vocab_size
23
+ self.seq_length = models[0].seq_length
24
+ self.bad_endings_ix = models[0].bad_endings_ix
25
+ self.ss_prob = 0
26
+ weights = weights or [1.0] * len(self.models)
27
+ self.register_buffer('weights', torch.tensor(weights))
28
+
29
+ def init_hidden(self, batch_size):
30
+ state = [m.init_hidden(batch_size) for m in self.models]
31
+ return self.pack_state(state)
32
+
33
+ def pack_state(self, state):
34
+ self.state_lengths = [len(_) for _ in state]
35
+ return sum([list(_) for _ in state], [])
36
+
37
+ def unpack_state(self, state):
38
+ out = []
39
+ for l in self.state_lengths:
40
+ out.append(state[:l])
41
+ state = state[l:]
42
+ return out
43
+
44
+ def embed(self, it):
45
+ return [m.embed(it) for m in self.models]
46
+
47
+ def core(self, *args):
48
+ return zip(*[m.core(*_) for m, _ in zip(self.models, zip(*args))])
49
+
50
+ def get_logprobs_state(self, it, tmp_fc_feats, tmp_att_feats, tmp_p_att_feats, tmp_att_masks, state, output_logsoftmax=1):
51
+ # 'it' contains a word index
52
+ xt = self.embed(it)
53
+
54
+ state = self.unpack_state(state)
55
+ output, state = self.core(xt, tmp_fc_feats, tmp_att_feats, tmp_p_att_feats, state, tmp_att_masks)
56
+ logprobs = torch.stack([F.softmax(m.logit(output[i]), dim=1) for i,m in enumerate(self.models)], 2).mul(self.weights).div(self.weights.sum()).sum(-1).log()
57
+
58
+ return logprobs, self.pack_state(state)
59
+
60
+ def _prepare_feature(self, *args):
61
+ return tuple(zip(*[m._prepare_feature(*args) for m in self.models]))
62
+
63
+ def _old_sample_beam(self, fc_feats, att_feats, att_masks=None, opt={}):
64
+ beam_size = opt.get('beam_size', 10)
65
+ batch_size = fc_feats.size(0)
66
+
67
+ fc_feats, att_feats, p_att_feats, att_masks = self._prepare_feature(fc_feats, att_feats, att_masks)
68
+
69
+ assert beam_size <= self.vocab_size + 1, 'lets assume this for now, otherwise this corner case causes a few headaches down the road. can be dealt with in future if needed'
70
+ seq = torch.LongTensor(self.seq_length, batch_size).zero_()
71
+ seqLogprobs = torch.FloatTensor(self.seq_length, batch_size, self.vocab_size + 1)
72
+ # lets process every image independently for now, for simplicity
73
+
74
+ self.done_beams = [[] for _ in range(batch_size)]
75
+ for k in range(batch_size):
76
+ state = self.init_hidden(beam_size)
77
+ tmp_fc_feats = [fc_feats[i][k:k+1].expand(beam_size, fc_feats[i].size(1)) for i,m in enumerate(self.models)]
78
+ tmp_att_feats = [att_feats[i][k:k+1].expand(*((beam_size,)+att_feats[i].size()[1:])).contiguous() for i,m in enumerate(self.models)]
79
+ tmp_p_att_feats = [p_att_feats[i][k:k+1].expand(*((beam_size,)+p_att_feats[i].size()[1:])).contiguous() for i,m in enumerate(self.models)]
80
+ tmp_att_masks = [att_masks[i][k:k+1].expand(*((beam_size,)+att_masks[i].size()[1:])).contiguous() if att_masks[i] is not None else att_masks[i] for i,m in enumerate(self.models)]
81
+
82
+ it = fc_feats[0].data.new(beam_size).long().zero_()
83
+ logprobs, state = self.get_logprobs_state(it, tmp_fc_feats, tmp_att_feats, tmp_p_att_feats, tmp_att_masks, state)
84
+
85
+ self.done_beams[k] = self.old_beam_search(state, logprobs, tmp_fc_feats, tmp_att_feats, tmp_p_att_feats, tmp_att_masks, opt=opt)
86
+ seq[:, k] = self.done_beams[k][0]['seq'] # the first beam has highest cumulative score
87
+ seqLogprobs[:, k] = self.done_beams[k][0]['logps']
88
+ # return the samples and their log likelihoods
89
+ return seq.transpose(0, 1), seqLogprobs.transpose(0, 1)
90
+ # return the samples and their log likelihoods
captioning/models/AttModel.py ADDED
@@ -0,0 +1,969 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # This file contains Att2in2, AdaAtt, AdaAttMO, UpDown model
2
+
3
+ # AdaAtt is from Knowing When to Look: Adaptive Attention via A Visual Sentinel for Image Captioning
4
+ # https://arxiv.org/abs/1612.01887
5
+ # AdaAttMO is a modified version with maxout lstm
6
+
7
+ # Att2in is from Self-critical Sequence Training for Image Captioning
8
+ # https://arxiv.org/abs/1612.00563
9
+ # In this file we only have Att2in2, which is a slightly different version of att2in,
10
+ # in which the img feature embedding and word embedding is the same as what in adaatt.
11
+
12
+ # UpDown is from Bottom-Up and Top-Down Attention for Image Captioning and VQA
13
+ # https://arxiv.org/abs/1707.07998
14
+ # However, it may not be identical to the author's architecture.
15
+
16
+ from __future__ import absolute_import
17
+ from __future__ import division
18
+ from __future__ import print_function
19
+
20
+ import numpy as np
21
+ import torch
22
+ import torch.nn as nn
23
+ import torch.nn.functional as F
24
+ from . import utils
25
+ from torch.nn.utils.rnn import PackedSequence, pack_padded_sequence, pad_packed_sequence
26
+
27
+ from .CaptionModel import CaptionModel
28
+
29
+ bad_endings = ['a','an','the','in','for','at','of','with','before','after','on','upon','near','to','is','are','am']
30
+ bad_endings += ['the']
31
+
32
+ def sort_pack_padded_sequence(input, lengths):
33
+ sorted_lengths, indices = torch.sort(lengths, descending=True)
34
+ # tmp = pack_padded_sequence(input[indices], sorted_lengths, batch_first=True)
35
+ tmp = pack_padded_sequence(input[indices], sorted_lengths.cpu(), batch_first=True)
36
+ inv_ix = indices.clone()
37
+ inv_ix[indices] = torch.arange(0,len(indices)).type_as(inv_ix)
38
+ return tmp, inv_ix
39
+
40
+ def pad_unsort_packed_sequence(input, inv_ix):
41
+ tmp, _ = pad_packed_sequence(input, batch_first=True)
42
+ tmp = tmp[inv_ix]
43
+ return tmp
44
+
45
+ def pack_wrapper(module, att_feats, att_masks):
46
+ if att_masks is not None:
47
+ packed, inv_ix = sort_pack_padded_sequence(att_feats, att_masks.data.long().sum(1))
48
+ return pad_unsort_packed_sequence(PackedSequence(module(packed[0]), packed[1]), inv_ix)
49
+ else:
50
+ return module(att_feats)
51
+
52
+ class AttModel(CaptionModel):
53
+ def __init__(self, opt):
54
+ super(AttModel, self).__init__()
55
+ self.vocab_size = opt.vocab_size
56
+ self.input_encoding_size = opt.input_encoding_size
57
+ #self.rnn_type = opt.rnn_type
58
+ self.rnn_size = opt.rnn_size
59
+ self.num_layers = opt.num_layers
60
+ self.drop_prob_lm = opt.drop_prob_lm
61
+ self.seq_length = getattr(opt, 'max_length', 20) or opt.seq_length # maximum sample length
62
+ self.fc_feat_size = opt.fc_feat_size
63
+ self.att_feat_size = opt.att_feat_size
64
+ self.att_hid_size = opt.att_hid_size
65
+
66
+ self.bos_idx = getattr(opt, 'bos_idx', 0)
67
+ self.eos_idx = getattr(opt, 'eos_idx', 0)
68
+ self.pad_idx = getattr(opt, 'pad_idx', 0)
69
+
70
+ self.use_bn = getattr(opt, 'use_bn', 0)
71
+
72
+ self.ss_prob = 0.0 # Schedule sampling probability
73
+
74
+ self.embed = nn.Sequential(nn.Embedding(self.vocab_size + 1, self.input_encoding_size),
75
+ nn.ReLU(),
76
+ nn.Dropout(self.drop_prob_lm))
77
+ self.fc_embed = nn.Sequential(nn.Linear(self.fc_feat_size, self.rnn_size),
78
+ nn.ReLU(),
79
+ nn.Dropout(self.drop_prob_lm))
80
+ self.att_embed = nn.Sequential(*(
81
+ ((nn.BatchNorm1d(self.att_feat_size),) if self.use_bn else ())+
82
+ (nn.Linear(self.att_feat_size, self.rnn_size),
83
+ nn.ReLU(),
84
+ nn.Dropout(self.drop_prob_lm))+
85
+ ((nn.BatchNorm1d(self.rnn_size),) if self.use_bn==2 else ())))
86
+
87
+ self.logit_layers = getattr(opt, 'logit_layers', 1)
88
+ if self.logit_layers == 1:
89
+ self.logit = nn.Linear(self.rnn_size, self.vocab_size + 1)
90
+ else:
91
+ self.logit = [[nn.Linear(self.rnn_size, self.rnn_size), nn.ReLU(), nn.Dropout(0.5)] for _ in range(opt.logit_layers - 1)]
92
+ self.logit = nn.Sequential(*(reduce(lambda x,y:x+y, self.logit) + [nn.Linear(self.rnn_size, self.vocab_size + 1)]))
93
+ self.ctx2att = nn.Linear(self.rnn_size, self.att_hid_size)
94
+
95
+ # For remove bad endding
96
+ self.vocab = opt.vocab
97
+ self.bad_endings_ix = [int(k) for k,v in self.vocab.items() if v in bad_endings]
98
+
99
+ def init_hidden(self, bsz):
100
+ weight = self.logit.weight \
101
+ if hasattr(self.logit, "weight") \
102
+ else self.logit[0].weight
103
+ return (weight.new_zeros(self.num_layers, bsz, self.rnn_size),
104
+ weight.new_zeros(self.num_layers, bsz, self.rnn_size))
105
+
106
+ def clip_att(self, att_feats, att_masks):
107
+ # Clip the length of att_masks and att_feats to the maximum length
108
+ if att_masks is not None:
109
+ max_len = att_masks.data.long().sum(1).max()
110
+ att_feats = att_feats[:, :max_len].contiguous()
111
+ att_masks = att_masks[:, :max_len].contiguous()
112
+ return att_feats, att_masks
113
+
114
+ def _prepare_feature(self, fc_feats, att_feats, att_masks):
115
+ att_feats, att_masks = self.clip_att(att_feats, att_masks)
116
+
117
+ # embed fc and att feats
118
+ fc_feats = self.fc_embed(fc_feats)
119
+ att_feats = pack_wrapper(self.att_embed, att_feats, att_masks)
120
+
121
+ # Project the attention feats first to reduce memory and computation comsumptions.
122
+ p_att_feats = self.ctx2att(att_feats)
123
+
124
+ return fc_feats, att_feats, p_att_feats, att_masks
125
+
126
+ def _forward(self, fc_feats, att_feats, seq, att_masks=None):
127
+ batch_size = fc_feats.size(0)
128
+ if seq.ndim == 3: # B * seq_per_img * seq_len
129
+ seq = seq.reshape(-1, seq.shape[2])
130
+ seq_per_img = seq.shape[0] // batch_size
131
+ state = self.init_hidden(batch_size*seq_per_img)
132
+
133
+ outputs = fc_feats.new_zeros(batch_size*seq_per_img, seq.size(1), self.vocab_size+1)
134
+
135
+ # Prepare the features
136
+ p_fc_feats, p_att_feats, pp_att_feats, p_att_masks = self._prepare_feature(fc_feats, att_feats, att_masks)
137
+ # pp_att_feats is used for attention, we cache it in advance to reduce computation cost
138
+
139
+ if seq_per_img > 1:
140
+ p_fc_feats, p_att_feats, pp_att_feats, p_att_masks = utils.repeat_tensors(seq_per_img,
141
+ [p_fc_feats, p_att_feats, pp_att_feats, p_att_masks]
142
+ )
143
+
144
+ for i in range(seq.size(1)):
145
+ if self.training and i >= 1 and self.ss_prob > 0.0: # otherwiste no need to sample
146
+ sample_prob = fc_feats.new(batch_size*seq_per_img).uniform_(0, 1)
147
+ sample_mask = sample_prob < self.ss_prob
148
+ if sample_mask.sum() == 0:
149
+ it = seq[:, i].clone()
150
+ else:
151
+ sample_ind = sample_mask.nonzero().view(-1)
152
+ it = seq[:, i].data.clone()
153
+ prob_prev = torch.exp(outputs[:, i-1].detach()) # fetch prev distribution: shape Nx(M+1)
154
+ it.index_copy_(0, sample_ind, torch.multinomial(prob_prev, 1).view(-1).index_select(0, sample_ind))
155
+ else:
156
+ it = seq[:, i].clone()
157
+ # break if all the sequences end
158
+ if i >= 1 and seq[:, i].sum() == 0:
159
+ break
160
+
161
+ output, state = self.get_logprobs_state(it, p_fc_feats, p_att_feats, pp_att_feats, p_att_masks, state)
162
+ outputs[:, i] = output
163
+
164
+ return outputs
165
+
166
+ def get_logprobs_state(self, it, fc_feats, att_feats, p_att_feats, att_masks, state, output_logsoftmax=1):
167
+ # 'it' contains a word index
168
+ xt = self.embed(it)
169
+
170
+ output, state = self.core(xt, fc_feats, att_feats, p_att_feats, state, att_masks)
171
+ if output_logsoftmax:
172
+ logprobs = F.log_softmax(self.logit(output), dim=1)
173
+ else:
174
+ logprobs = self.logit(output)
175
+
176
+ return logprobs, state
177
+
178
+ def _old_sample_beam(self, fc_feats, att_feats, att_masks=None, opt={}):
179
+ beam_size = opt.get('beam_size', 10)
180
+ group_size = opt.get('group_size', 1)
181
+ sample_n = opt.get('sample_n', 10)
182
+ # when sample_n == beam_size then each beam is a sample.
183
+ assert sample_n == 1 or sample_n == beam_size // group_size, 'when beam search, sample_n == 1 or beam search'
184
+ batch_size = fc_feats.size(0)
185
+
186
+ p_fc_feats, p_att_feats, pp_att_feats, p_att_masks = self._prepare_feature(fc_feats, att_feats, att_masks)
187
+
188
+ assert beam_size <= self.vocab_size + 1, 'lets assume this for now, otherwise this corner case causes a few headaches down the road. can be dealt with in future if needed'
189
+ seq = fc_feats.new_full((batch_size*sample_n, self.seq_length), self.pad_idx, dtype=torch.long)
190
+ seqLogprobs = fc_feats.new_zeros(batch_size*sample_n, self.seq_length, self.vocab_size + 1)
191
+ # lets process every image independently for now, for simplicity
192
+
193
+ self.done_beams = [[] for _ in range(batch_size)]
194
+ for k in range(batch_size):
195
+ state = self.init_hidden(beam_size)
196
+ tmp_fc_feats, tmp_att_feats, tmp_p_att_feats, tmp_att_masks = utils.repeat_tensors(beam_size,
197
+ [p_fc_feats[k:k+1], p_att_feats[k:k+1], pp_att_feats[k:k+1], p_att_masks[k:k+1] if att_masks is not None else None]
198
+ )
199
+
200
+ for t in range(1):
201
+ if t == 0: # input <bos>
202
+ it = fc_feats.new_full([beam_size], self.bos_idx, dtype=torch.long)
203
+
204
+ logprobs, state = self.get_logprobs_state(it, tmp_fc_feats, tmp_att_feats, tmp_p_att_feats, tmp_att_masks, state)
205
+
206
+ self.done_beams[k] = self.old_beam_search(state, logprobs, tmp_fc_feats, tmp_att_feats, tmp_p_att_feats, tmp_att_masks, opt=opt)
207
+ if sample_n == beam_size:
208
+ for _n in range(sample_n):
209
+ seq[k*sample_n+_n, :] = self.done_beams[k][_n]['seq']
210
+ seqLogprobs[k*sample_n+_n, :] = self.done_beams[k][_n]['logps']
211
+ else:
212
+ seq[k, :] = self.done_beams[k][0]['seq'] # the first beam has highest cumulative score
213
+ seqLogprobs[k, :] = self.done_beams[k][0]['logps']
214
+ # return the samples and their log likelihoods
215
+ return seq, seqLogprobs
216
+
217
+
218
+ def _sample_beam(self, fc_feats, att_feats, att_masks=None, opt={}):
219
+ beam_size = opt.get('beam_size', 10)
220
+ group_size = opt.get('group_size', 1)
221
+ sample_n = opt.get('sample_n', 10)
222
+ # when sample_n == beam_size then each beam is a sample.
223
+ assert sample_n == 1 or sample_n == beam_size // group_size, 'when beam search, sample_n == 1 or beam search'
224
+ batch_size = fc_feats.size(0)
225
+
226
+ p_fc_feats, p_att_feats, pp_att_feats, p_att_masks = self._prepare_feature(fc_feats, att_feats, att_masks)
227
+
228
+ assert beam_size <= self.vocab_size + 1, 'lets assume this for now, otherwise this corner case causes a few headaches down the road. can be dealt with in future if needed'
229
+ seq = fc_feats.new_full((batch_size*sample_n, self.seq_length), self.pad_idx, dtype=torch.long)
230
+ seqLogprobs = fc_feats.new_zeros(batch_size*sample_n, self.seq_length, self.vocab_size + 1)
231
+ # lets process every image independently for now, for simplicity
232
+
233
+ self.done_beams = [[] for _ in range(batch_size)]
234
+
235
+ state = self.init_hidden(batch_size)
236
+
237
+ # first step, feed bos
238
+ it = fc_feats.new_full([batch_size], self.bos_idx, dtype=torch.long)
239
+ logprobs, state = self.get_logprobs_state(it, p_fc_feats, p_att_feats, pp_att_feats, p_att_masks, state)
240
+
241
+ p_fc_feats, p_att_feats, pp_att_feats, p_att_masks = utils.repeat_tensors(beam_size,
242
+ [p_fc_feats, p_att_feats, pp_att_feats, p_att_masks]
243
+ )
244
+ self.done_beams = self.beam_search(state, logprobs, p_fc_feats, p_att_feats, pp_att_feats, p_att_masks, opt=opt)
245
+ for k in range(batch_size):
246
+ if sample_n == beam_size:
247
+ for _n in range(sample_n):
248
+ seq_len = self.done_beams[k][_n]['seq'].shape[0]
249
+ seq[k*sample_n+_n, :seq_len] = self.done_beams[k][_n]['seq']
250
+ seqLogprobs[k*sample_n+_n, :seq_len] = self.done_beams[k][_n]['logps']
251
+ else:
252
+ seq_len = self.done_beams[k][0]['seq'].shape[0]
253
+ seq[k, :seq_len] = self.done_beams[k][0]['seq'] # the first beam has highest cumulative score
254
+ seqLogprobs[k, :seq_len] = self.done_beams[k][0]['logps']
255
+ # return the samples and their log likelihoods
256
+ return seq, seqLogprobs
257
+
258
+ def _sample(self, fc_feats, att_feats, att_masks=None, opt={}):
259
+
260
+ sample_method = opt.get('sample_method', 'greedy')
261
+ beam_size = opt.get('beam_size', 1)
262
+ temperature = opt.get('temperature', 1.0)
263
+ sample_n = int(opt.get('sample_n', 1))
264
+ group_size = opt.get('group_size', 1)
265
+ output_logsoftmax = opt.get('output_logsoftmax', 1)
266
+ decoding_constraint = opt.get('decoding_constraint', 0)
267
+ block_trigrams = opt.get('block_trigrams', 0)
268
+ remove_bad_endings = opt.get('remove_bad_endings', 0)
269
+ if beam_size > 1 and sample_method in ['greedy', 'beam_search']:
270
+ return self._sample_beam(fc_feats, att_feats, att_masks, opt)
271
+ if group_size > 1:
272
+ return self._diverse_sample(fc_feats, att_feats, att_masks, opt)
273
+
274
+ batch_size = fc_feats.size(0)
275
+ state = self.init_hidden(batch_size*sample_n)
276
+
277
+ p_fc_feats, p_att_feats, pp_att_feats, p_att_masks = self._prepare_feature(fc_feats, att_feats, att_masks)
278
+
279
+ if sample_n > 1:
280
+ p_fc_feats, p_att_feats, pp_att_feats, p_att_masks = utils.repeat_tensors(sample_n,
281
+ [p_fc_feats, p_att_feats, pp_att_feats, p_att_masks]
282
+ )
283
+
284
+ trigrams = [] # will be a list of batch_size dictionaries
285
+
286
+ seq = fc_feats.new_full((batch_size*sample_n, self.seq_length), self.pad_idx, dtype=torch.long)
287
+ seqLogprobs = fc_feats.new_zeros(batch_size*sample_n, self.seq_length, self.vocab_size + 1)
288
+ for t in range(self.seq_length + 1):
289
+ if t == 0: # input <bos>
290
+ it = fc_feats.new_full([batch_size*sample_n], self.bos_idx, dtype=torch.long)
291
+
292
+ logprobs, state = self.get_logprobs_state(it, p_fc_feats, p_att_feats, pp_att_feats, p_att_masks, state, output_logsoftmax=output_logsoftmax)
293
+
294
+ if decoding_constraint and t > 0:
295
+ tmp = logprobs.new_zeros(logprobs.size())
296
+ tmp.scatter_(1, seq[:,t-1].data.unsqueeze(1), float('-inf'))
297
+ logprobs = logprobs + tmp
298
+
299
+ if remove_bad_endings and t > 0:
300
+ tmp = logprobs.new_zeros(logprobs.size())
301
+ prev_bad = np.isin(seq[:,t-1].data.cpu().numpy(), self.bad_endings_ix)
302
+ # Make it impossible to generate bad_endings
303
+ tmp[torch.from_numpy(prev_bad.astype('uint8')), 0] = float('-inf')
304
+ logprobs = logprobs + tmp
305
+
306
+ # Mess with trigrams
307
+ # Copy from https://github.com/lukemelas/image-paragraph-captioning
308
+ if block_trigrams and t >= 3:
309
+ # Store trigram generated at last step
310
+ prev_two_batch = seq[:,t-3:t-1]
311
+ for i in range(batch_size): # = seq.size(0)
312
+ prev_two = (prev_two_batch[i][0].item(), prev_two_batch[i][1].item())
313
+ current = seq[i][t-1]
314
+ if t == 3: # initialize
315
+ trigrams.append({prev_two: [current]}) # {LongTensor: list containing 1 int}
316
+ elif t > 3:
317
+ if prev_two in trigrams[i]: # add to list
318
+ trigrams[i][prev_two].append(current)
319
+ else: # create list
320
+ trigrams[i][prev_two] = [current]
321
+ # Block used trigrams at next step
322
+ prev_two_batch = seq[:,t-2:t]
323
+ mask = torch.zeros(logprobs.size(), requires_grad=False).to(logprobs.device) # batch_size x vocab_size
324
+ for i in range(batch_size):
325
+ prev_two = (prev_two_batch[i][0].item(), prev_two_batch[i][1].item())
326
+ if prev_two in trigrams[i]:
327
+ for j in trigrams[i][prev_two]:
328
+ mask[i,j] += 1
329
+ # Apply mask to log probs
330
+ #logprobs = logprobs - (mask * 1e9)
331
+ alpha = 2.0 # = 4
332
+ logprobs = logprobs + (mask * -0.693 * alpha) # ln(1/2) * alpha (alpha -> infty works best)
333
+
334
+ # sample the next word
335
+ if t == self.seq_length: # skip if we achieve maximum length
336
+ break
337
+ it, sampleLogprobs = self.sample_next_word(logprobs, sample_method, temperature)
338
+
339
+ # stop when all finished
340
+ if t == 0:
341
+ unfinished = it != self.eos_idx
342
+ else:
343
+ it[~unfinished] = self.pad_idx # This allows eos_idx not being overwritten to 0
344
+ logprobs = logprobs * unfinished.unsqueeze(1).to(logprobs)
345
+ unfinished = unfinished & (it != self.eos_idx)
346
+ seq[:,t] = it
347
+ seqLogprobs[:,t] = logprobs
348
+ # quit loop if all sequences have finished
349
+ if unfinished.sum() == 0:
350
+ break
351
+
352
+ return seq, seqLogprobs
353
+
354
+ def _diverse_sample(self, fc_feats, att_feats, att_masks=None, opt={}):
355
+
356
+ sample_method = opt.get('sample_method', 'greedy')
357
+ beam_size = opt.get('beam_size', 1)
358
+ temperature = opt.get('temperature', 1.0)
359
+ group_size = opt.get('group_size', 1)
360
+ diversity_lambda = opt.get('diversity_lambda', 0.5)
361
+ decoding_constraint = opt.get('decoding_constraint', 0)
362
+ block_trigrams = opt.get('block_trigrams', 0)
363
+ remove_bad_endings = opt.get('remove_bad_endings', 0)
364
+
365
+ batch_size = fc_feats.size(0)
366
+ state = self.init_hidden(batch_size)
367
+
368
+ p_fc_feats, p_att_feats, pp_att_feats, p_att_masks = self._prepare_feature(fc_feats, att_feats, att_masks)
369
+
370
+ trigrams_table = [[] for _ in range(group_size)] # will be a list of batch_size dictionaries
371
+
372
+ seq_table = [fc_feats.new_full((batch_size, self.seq_length), self.pad_idx, dtype=torch.long) for _ in range(group_size)]
373
+ seqLogprobs_table = [fc_feats.new_zeros(batch_size, self.seq_length) for _ in range(group_size)]
374
+ state_table = [self.init_hidden(batch_size) for _ in range(group_size)]
375
+
376
+ for tt in range(self.seq_length + group_size):
377
+ for divm in range(group_size):
378
+ t = tt - divm
379
+ seq = seq_table[divm]
380
+ seqLogprobs = seqLogprobs_table[divm]
381
+ trigrams = trigrams_table[divm]
382
+ if t >= 0 and t <= self.seq_length-1:
383
+ if t == 0: # input <bos>
384
+ it = fc_feats.new_full([batch_size], self.bos_idx, dtype=torch.long)
385
+ else:
386
+ it = seq[:, t-1] # changed
387
+
388
+ logprobs, state_table[divm] = self.get_logprobs_state(it, p_fc_feats, p_att_feats, pp_att_feats, p_att_masks, state_table[divm]) # changed
389
+ logprobs = F.log_softmax(logprobs / temperature, dim=-1)
390
+
391
+ # Add diversity
392
+ if divm > 0:
393
+ unaug_logprobs = logprobs.clone()
394
+ for prev_choice in range(divm):
395
+ prev_decisions = seq_table[prev_choice][:, t]
396
+ logprobs[:, prev_decisions] = logprobs[:, prev_decisions] - diversity_lambda
397
+
398
+ if decoding_constraint and t > 0:
399
+ tmp = logprobs.new_zeros(logprobs.size())
400
+ tmp.scatter_(1, seq[:,t-1].data.unsqueeze(1), float('-inf'))
401
+ logprobs = logprobs + tmp
402
+
403
+ if remove_bad_endings and t > 0:
404
+ tmp = logprobs.new_zeros(logprobs.size())
405
+ prev_bad = np.isin(seq[:,t-1].data.cpu().numpy(), self.bad_endings_ix)
406
+ # Impossible to generate remove_bad_endings
407
+ tmp[torch.from_numpy(prev_bad.astype('uint8')), 0] = float('-inf')
408
+ logprobs = logprobs + tmp
409
+
410
+ # Mess with trigrams
411
+ if block_trigrams and t >= 3:
412
+ # Store trigram generated at last step
413
+ prev_two_batch = seq[:,t-3:t-1]
414
+ for i in range(batch_size): # = seq.size(0)
415
+ prev_two = (prev_two_batch[i][0].item(), prev_two_batch[i][1].item())
416
+ current = seq[i][t-1]
417
+ if t == 3: # initialize
418
+ trigrams.append({prev_two: [current]}) # {LongTensor: list containing 1 int}
419
+ elif t > 3:
420
+ if prev_two in trigrams[i]: # add to list
421
+ trigrams[i][prev_two].append(current)
422
+ else: # create list
423
+ trigrams[i][prev_two] = [current]
424
+ # Block used trigrams at next step
425
+ prev_two_batch = seq[:,t-2:t]
426
+ mask = torch.zeros(logprobs.size(), requires_grad=False).cuda() # batch_size x vocab_size
427
+ for i in range(batch_size):
428
+ prev_two = (prev_two_batch[i][0].item(), prev_two_batch[i][1].item())
429
+ if prev_two in trigrams[i]:
430
+ for j in trigrams[i][prev_two]:
431
+ mask[i,j] += 1
432
+ # Apply mask to log probs
433
+ #logprobs = logprobs - (mask * 1e9)
434
+ alpha = 2.0 # = 4
435
+ logprobs = logprobs + (mask * -0.693 * alpha) # ln(1/2) * alpha (alpha -> infty works best)
436
+
437
+ it, sampleLogprobs = self.sample_next_word(logprobs, sample_method, 1)
438
+
439
+ # stop when all finished
440
+ if t == 0:
441
+ unfinished = it != self.eos_idx
442
+ else:
443
+ unfinished = (seq[:,t-1] != self.pad_idx) & (seq[:,t-1] != self.eos_idx)
444
+ it[~unfinished] = self.pad_idx
445
+ unfinished = unfinished & (it != self.eos_idx) # changed
446
+ seq[:,t] = it
447
+ seqLogprobs[:,t] = sampleLogprobs.view(-1)
448
+
449
+ return torch.stack(seq_table, 1).reshape(batch_size * group_size, -1), torch.stack(seqLogprobs_table, 1).reshape(batch_size * group_size, -1)
450
+
451
+ class AdaAtt_lstm(nn.Module):
452
+ def __init__(self, opt, use_maxout=True):
453
+ super(AdaAtt_lstm, self).__init__()
454
+ self.input_encoding_size = opt.input_encoding_size
455
+ #self.rnn_type = opt.rnn_type
456
+ self.rnn_size = opt.rnn_size
457
+ self.num_layers = opt.num_layers
458
+ self.drop_prob_lm = opt.drop_prob_lm
459
+ self.fc_feat_size = opt.fc_feat_size
460
+ self.att_feat_size = opt.att_feat_size
461
+ self.att_hid_size = opt.att_hid_size
462
+
463
+ self.use_maxout = use_maxout
464
+
465
+ # Build a LSTM
466
+ self.w2h = nn.Linear(self.input_encoding_size, (4+(use_maxout==True)) * self.rnn_size)
467
+ self.v2h = nn.Linear(self.rnn_size, (4+(use_maxout==True)) * self.rnn_size)
468
+
469
+ self.i2h = nn.ModuleList([nn.Linear(self.rnn_size, (4+(use_maxout==True)) * self.rnn_size) for _ in range(self.num_layers - 1)])
470
+ self.h2h = nn.ModuleList([nn.Linear(self.rnn_size, (4+(use_maxout==True)) * self.rnn_size) for _ in range(self.num_layers)])
471
+
472
+ # Layers for getting the fake region
473
+ if self.num_layers == 1:
474
+ self.r_w2h = nn.Linear(self.input_encoding_size, self.rnn_size)
475
+ self.r_v2h = nn.Linear(self.rnn_size, self.rnn_size)
476
+ else:
477
+ self.r_i2h = nn.Linear(self.rnn_size, self.rnn_size)
478
+ self.r_h2h = nn.Linear(self.rnn_size, self.rnn_size)
479
+
480
+
481
+ def forward(self, xt, img_fc, state):
482
+
483
+ hs = []
484
+ cs = []
485
+ for L in range(self.num_layers):
486
+ # c,h from previous timesteps
487
+ prev_h = state[0][L]
488
+ prev_c = state[1][L]
489
+ # the input to this layer
490
+ if L == 0:
491
+ x = xt
492
+ i2h = self.w2h(x) + self.v2h(img_fc)
493
+ else:
494
+ x = hs[-1]
495
+ x = F.dropout(x, self.drop_prob_lm, self.training)
496
+ i2h = self.i2h[L-1](x)
497
+
498
+ all_input_sums = i2h+self.h2h[L](prev_h)
499
+
500
+ sigmoid_chunk = all_input_sums.narrow(1, 0, 3 * self.rnn_size)
501
+ sigmoid_chunk = torch.sigmoid(sigmoid_chunk)
502
+ # decode the gates
503
+ in_gate = sigmoid_chunk.narrow(1, 0, self.rnn_size)
504
+ forget_gate = sigmoid_chunk.narrow(1, self.rnn_size, self.rnn_size)
505
+ out_gate = sigmoid_chunk.narrow(1, self.rnn_size * 2, self.rnn_size)
506
+ # decode the write inputs
507
+ if not self.use_maxout:
508
+ in_transform = torch.tanh(all_input_sums.narrow(1, 3 * self.rnn_size, self.rnn_size))
509
+ else:
510
+ in_transform = all_input_sums.narrow(1, 3 * self.rnn_size, 2 * self.rnn_size)
511
+ in_transform = torch.max(\
512
+ in_transform.narrow(1, 0, self.rnn_size),
513
+ in_transform.narrow(1, self.rnn_size, self.rnn_size))
514
+ # perform the LSTM update
515
+ next_c = forget_gate * prev_c + in_gate * in_transform
516
+ # gated cells form the output
517
+ tanh_nex_c = torch.tanh(next_c)
518
+ next_h = out_gate * tanh_nex_c
519
+ if L == self.num_layers-1:
520
+ if L == 0:
521
+ i2h = self.r_w2h(x) + self.r_v2h(img_fc)
522
+ else:
523
+ i2h = self.r_i2h(x)
524
+ n5 = i2h+self.r_h2h(prev_h)
525
+ fake_region = torch.sigmoid(n5) * tanh_nex_c
526
+
527
+ cs.append(next_c)
528
+ hs.append(next_h)
529
+
530
+ # set up the decoder
531
+ top_h = hs[-1]
532
+ top_h = F.dropout(top_h, self.drop_prob_lm, self.training)
533
+ fake_region = F.dropout(fake_region, self.drop_prob_lm, self.training)
534
+
535
+ state = (torch.cat([_.unsqueeze(0) for _ in hs], 0),
536
+ torch.cat([_.unsqueeze(0) for _ in cs], 0))
537
+ return top_h, fake_region, state
538
+
539
+ class AdaAtt_attention(nn.Module):
540
+ def __init__(self, opt):
541
+ super(AdaAtt_attention, self).__init__()
542
+ self.input_encoding_size = opt.input_encoding_size
543
+ #self.rnn_type = opt.rnn_type
544
+ self.rnn_size = opt.rnn_size
545
+ self.drop_prob_lm = opt.drop_prob_lm
546
+ self.att_hid_size = opt.att_hid_size
547
+
548
+ # fake region embed
549
+ self.fr_linear = nn.Sequential(
550
+ nn.Linear(self.rnn_size, self.input_encoding_size),
551
+ nn.ReLU(),
552
+ nn.Dropout(self.drop_prob_lm))
553
+ self.fr_embed = nn.Linear(self.input_encoding_size, self.att_hid_size)
554
+
555
+ # h out embed
556
+ self.ho_linear = nn.Sequential(
557
+ nn.Linear(self.rnn_size, self.input_encoding_size),
558
+ nn.Tanh(),
559
+ nn.Dropout(self.drop_prob_lm))
560
+ self.ho_embed = nn.Linear(self.input_encoding_size, self.att_hid_size)
561
+
562
+ self.alpha_net = nn.Linear(self.att_hid_size, 1)
563
+ self.att2h = nn.Linear(self.rnn_size, self.rnn_size)
564
+
565
+ def forward(self, h_out, fake_region, conv_feat, conv_feat_embed, att_masks=None):
566
+
567
+ # View into three dimensions
568
+ att_size = conv_feat.numel() // conv_feat.size(0) // self.rnn_size
569
+ conv_feat = conv_feat.view(-1, att_size, self.rnn_size)
570
+ conv_feat_embed = conv_feat_embed.view(-1, att_size, self.att_hid_size)
571
+
572
+ # view neighbor from bach_size * neighbor_num x rnn_size to bach_size x rnn_size * neighbor_num
573
+ fake_region = self.fr_linear(fake_region)
574
+ fake_region_embed = self.fr_embed(fake_region)
575
+
576
+ h_out_linear = self.ho_linear(h_out)
577
+ h_out_embed = self.ho_embed(h_out_linear)
578
+
579
+ txt_replicate = h_out_embed.unsqueeze(1).expand(h_out_embed.size(0), att_size + 1, h_out_embed.size(1))
580
+
581
+ img_all = torch.cat([fake_region.view(-1,1,self.input_encoding_size), conv_feat], 1)
582
+ img_all_embed = torch.cat([fake_region_embed.view(-1,1,self.input_encoding_size), conv_feat_embed], 1)
583
+
584
+ hA = torch.tanh(img_all_embed + txt_replicate)
585
+ hA = F.dropout(hA,self.drop_prob_lm, self.training)
586
+
587
+ hAflat = self.alpha_net(hA.view(-1, self.att_hid_size))
588
+ PI = F.softmax(hAflat.view(-1, att_size + 1), dim=1)
589
+
590
+ if att_masks is not None:
591
+ att_masks = att_masks.view(-1, att_size)
592
+ PI = PI * torch.cat([att_masks[:,:1], att_masks], 1) # assume one one at the first time step.
593
+ PI = PI / PI.sum(1, keepdim=True)
594
+
595
+ visAtt = torch.bmm(PI.unsqueeze(1), img_all)
596
+ visAttdim = visAtt.squeeze(1)
597
+
598
+ atten_out = visAttdim + h_out_linear
599
+
600
+ h = torch.tanh(self.att2h(atten_out))
601
+ h = F.dropout(h, self.drop_prob_lm, self.training)
602
+ return h
603
+
604
+ class AdaAttCore(nn.Module):
605
+ def __init__(self, opt, use_maxout=False):
606
+ super(AdaAttCore, self).__init__()
607
+ self.lstm = AdaAtt_lstm(opt, use_maxout)
608
+ self.attention = AdaAtt_attention(opt)
609
+
610
+ def forward(self, xt, fc_feats, att_feats, p_att_feats, state, att_masks=None):
611
+ h_out, p_out, state = self.lstm(xt, fc_feats, state)
612
+ atten_out = self.attention(h_out, p_out, att_feats, p_att_feats, att_masks)
613
+ return atten_out, state
614
+
615
+ class UpDownCore(nn.Module):
616
+ def __init__(self, opt, use_maxout=False):
617
+ super(UpDownCore, self).__init__()
618
+ self.drop_prob_lm = opt.drop_prob_lm
619
+
620
+ self.att_lstm = nn.LSTMCell(opt.input_encoding_size + opt.rnn_size * 2, opt.rnn_size) # we, fc, h^2_t-1
621
+ self.lang_lstm = nn.LSTMCell(opt.rnn_size * 2, opt.rnn_size) # h^1_t, \hat v
622
+ self.attention = Attention(opt)
623
+
624
+ def forward(self, xt, fc_feats, att_feats, p_att_feats, state, att_masks=None):
625
+ prev_h = state[0][-1]
626
+ att_lstm_input = torch.cat([prev_h, fc_feats, xt], 1)
627
+
628
+ h_att, c_att = self.att_lstm(att_lstm_input, (state[0][0], state[1][0]))
629
+
630
+ att = self.attention(h_att, att_feats, p_att_feats, att_masks)
631
+
632
+ lang_lstm_input = torch.cat([att, h_att], 1)
633
+ # lang_lstm_input = torch.cat([att, F.dropout(h_att, self.drop_prob_lm, self.training)], 1) ?????
634
+
635
+ h_lang, c_lang = self.lang_lstm(lang_lstm_input, (state[0][1], state[1][1]))
636
+
637
+ output = F.dropout(h_lang, self.drop_prob_lm, self.training)
638
+ state = (torch.stack([h_att, h_lang]), torch.stack([c_att, c_lang]))
639
+
640
+ return output, state
641
+
642
+
643
+ ############################################################################
644
+ # Notice:
645
+ # StackAtt and DenseAtt are models that I randomly designed.
646
+ # They are not related to any paper.
647
+ ############################################################################
648
+
649
+ from .FCModel import LSTMCore
650
+ class StackAttCore(nn.Module):
651
+ def __init__(self, opt, use_maxout=False):
652
+ super(StackAttCore, self).__init__()
653
+ self.drop_prob_lm = opt.drop_prob_lm
654
+
655
+ # self.att0 = Attention(opt)
656
+ self.att1 = Attention(opt)
657
+ self.att2 = Attention(opt)
658
+
659
+ opt_input_encoding_size = opt.input_encoding_size
660
+ opt.input_encoding_size = opt.input_encoding_size + opt.rnn_size
661
+ self.lstm0 = LSTMCore(opt) # att_feat + word_embedding
662
+ opt.input_encoding_size = opt.rnn_size * 2
663
+ self.lstm1 = LSTMCore(opt)
664
+ self.lstm2 = LSTMCore(opt)
665
+ opt.input_encoding_size = opt_input_encoding_size
666
+
667
+ # self.emb1 = nn.Linear(opt.rnn_size, opt.rnn_size)
668
+ self.emb2 = nn.Linear(opt.rnn_size, opt.rnn_size)
669
+
670
+ def forward(self, xt, fc_feats, att_feats, p_att_feats, state, att_masks=None):
671
+ # att_res_0 = self.att0(state[0][-1], att_feats, p_att_feats, att_masks)
672
+ h_0, state_0 = self.lstm0(torch.cat([xt,fc_feats],1), [state[0][0:1], state[1][0:1]])
673
+ att_res_1 = self.att1(h_0, att_feats, p_att_feats, att_masks)
674
+ h_1, state_1 = self.lstm1(torch.cat([h_0,att_res_1],1), [state[0][1:2], state[1][1:2]])
675
+ att_res_2 = self.att2(h_1 + self.emb2(att_res_1), att_feats, p_att_feats, att_masks)
676
+ h_2, state_2 = self.lstm2(torch.cat([h_1,att_res_2],1), [state[0][2:3], state[1][2:3]])
677
+
678
+ return h_2, [torch.cat(_, 0) for _ in zip(state_0, state_1, state_2)]
679
+
680
+ class DenseAttCore(nn.Module):
681
+ def __init__(self, opt, use_maxout=False):
682
+ super(DenseAttCore, self).__init__()
683
+ self.drop_prob_lm = opt.drop_prob_lm
684
+
685
+ # self.att0 = Attention(opt)
686
+ self.att1 = Attention(opt)
687
+ self.att2 = Attention(opt)
688
+
689
+ opt_input_encoding_size = opt.input_encoding_size
690
+ opt.input_encoding_size = opt.input_encoding_size + opt.rnn_size
691
+ self.lstm0 = LSTMCore(opt) # att_feat + word_embedding
692
+ opt.input_encoding_size = opt.rnn_size * 2
693
+ self.lstm1 = LSTMCore(opt)
694
+ self.lstm2 = LSTMCore(opt)
695
+ opt.input_encoding_size = opt_input_encoding_size
696
+
697
+ # self.emb1 = nn.Linear(opt.rnn_size, opt.rnn_size)
698
+ self.emb2 = nn.Linear(opt.rnn_size, opt.rnn_size)
699
+
700
+ # fuse h_0 and h_1
701
+ self.fusion1 = nn.Sequential(nn.Linear(opt.rnn_size*2, opt.rnn_size),
702
+ nn.ReLU(),
703
+ nn.Dropout(opt.drop_prob_lm))
704
+ # fuse h_0, h_1 and h_2
705
+ self.fusion2 = nn.Sequential(nn.Linear(opt.rnn_size*3, opt.rnn_size),
706
+ nn.ReLU(),
707
+ nn.Dropout(opt.drop_prob_lm))
708
+
709
+ def forward(self, xt, fc_feats, att_feats, p_att_feats, state, att_masks=None):
710
+ # att_res_0 = self.att0(state[0][-1], att_feats, p_att_feats, att_masks)
711
+ h_0, state_0 = self.lstm0(torch.cat([xt,fc_feats],1), [state[0][0:1], state[1][0:1]])
712
+ att_res_1 = self.att1(h_0, att_feats, p_att_feats, att_masks)
713
+ h_1, state_1 = self.lstm1(torch.cat([h_0,att_res_1],1), [state[0][1:2], state[1][1:2]])
714
+ att_res_2 = self.att2(h_1 + self.emb2(att_res_1), att_feats, p_att_feats, att_masks)
715
+ h_2, state_2 = self.lstm2(torch.cat([self.fusion1(torch.cat([h_0, h_1], 1)),att_res_2],1), [state[0][2:3], state[1][2:3]])
716
+
717
+ return self.fusion2(torch.cat([h_0, h_1, h_2], 1)), [torch.cat(_, 0) for _ in zip(state_0, state_1, state_2)]
718
+
719
+ class Attention(nn.Module):
720
+ def __init__(self, opt):
721
+ super(Attention, self).__init__()
722
+ self.rnn_size = opt.rnn_size
723
+ self.att_hid_size = opt.att_hid_size
724
+
725
+ self.h2att = nn.Linear(self.rnn_size, self.att_hid_size)
726
+ self.alpha_net = nn.Linear(self.att_hid_size, 1)
727
+
728
+ def forward(self, h, att_feats, p_att_feats, att_masks=None):
729
+ # The p_att_feats here is already projected
730
+ att_size = att_feats.numel() // att_feats.size(0) // att_feats.size(-1)
731
+ att = p_att_feats.view(-1, att_size, self.att_hid_size)
732
+
733
+ att_h = self.h2att(h) # batch * att_hid_size
734
+ att_h = att_h.unsqueeze(1).expand_as(att) # batch * att_size * att_hid_size
735
+ dot = att + att_h # batch * att_size * att_hid_size
736
+ dot = torch.tanh(dot) # batch * att_size * att_hid_size
737
+ dot = dot.view(-1, self.att_hid_size) # (batch * att_size) * att_hid_size
738
+ dot = self.alpha_net(dot) # (batch * att_size) * 1
739
+ dot = dot.view(-1, att_size) # batch * att_size
740
+
741
+ weight = F.softmax(dot, dim=1) # batch * att_size
742
+ if att_masks is not None:
743
+ weight = weight * att_masks.view(-1, att_size).to(weight)
744
+ weight = weight / weight.sum(1, keepdim=True) # normalize to 1
745
+ att_feats_ = att_feats.view(-1, att_size, att_feats.size(-1)) # batch * att_size * att_feat_size
746
+ att_res = torch.bmm(weight.unsqueeze(1), att_feats_).squeeze(1) # batch * att_feat_size
747
+
748
+ return att_res
749
+
750
+ class Att2in2Core(nn.Module):
751
+ def __init__(self, opt):
752
+ super(Att2in2Core, self).__init__()
753
+ self.input_encoding_size = opt.input_encoding_size
754
+ #self.rnn_type = opt.rnn_type
755
+ self.rnn_size = opt.rnn_size
756
+ #self.num_layers = opt.num_layers
757
+ self.drop_prob_lm = opt.drop_prob_lm
758
+ self.fc_feat_size = opt.fc_feat_size
759
+ self.att_feat_size = opt.att_feat_size
760
+ self.att_hid_size = opt.att_hid_size
761
+
762
+ # Build a LSTM
763
+ self.a2c = nn.Linear(self.rnn_size, 2 * self.rnn_size)
764
+ self.i2h = nn.Linear(self.input_encoding_size, 5 * self.rnn_size)
765
+ self.h2h = nn.Linear(self.rnn_size, 5 * self.rnn_size)
766
+ self.dropout = nn.Dropout(self.drop_prob_lm)
767
+
768
+ self.attention = Attention(opt)
769
+
770
+ def forward(self, xt, fc_feats, att_feats, p_att_feats, state, att_masks=None):
771
+ att_res = self.attention(state[0][-1], att_feats, p_att_feats, att_masks)
772
+
773
+ all_input_sums = self.i2h(xt) + self.h2h(state[0][-1])
774
+ sigmoid_chunk = all_input_sums.narrow(1, 0, 3 * self.rnn_size)
775
+ sigmoid_chunk = torch.sigmoid(sigmoid_chunk)
776
+ in_gate = sigmoid_chunk.narrow(1, 0, self.rnn_size)
777
+ forget_gate = sigmoid_chunk.narrow(1, self.rnn_size, self.rnn_size)
778
+ out_gate = sigmoid_chunk.narrow(1, self.rnn_size * 2, self.rnn_size)
779
+
780
+ in_transform = all_input_sums.narrow(1, 3 * self.rnn_size, 2 * self.rnn_size) + \
781
+ self.a2c(att_res)
782
+ in_transform = torch.max(\
783
+ in_transform.narrow(1, 0, self.rnn_size),
784
+ in_transform.narrow(1, self.rnn_size, self.rnn_size))
785
+ next_c = forget_gate * state[1][-1] + in_gate * in_transform
786
+ next_h = out_gate * torch.tanh(next_c)
787
+
788
+ output = self.dropout(next_h)
789
+ state = (next_h.unsqueeze(0), next_c.unsqueeze(0))
790
+ return output, state
791
+
792
+ class Att2inCore(Att2in2Core):
793
+ def __init__(self, opt):
794
+ super(Att2inCore, self).__init__(opt)
795
+ del self.a2c
796
+ self.a2c = nn.Linear(self.att_feat_size, 2 * self.rnn_size)
797
+
798
+ """
799
+ Note this is my attempt to replicate att2all model in self-critical paper.
800
+ However, this is not a correct replication actually. Will fix it.
801
+ """
802
+ class Att2all2Core(nn.Module):
803
+ def __init__(self, opt):
804
+ super(Att2all2Core, self).__init__()
805
+ self.input_encoding_size = opt.input_encoding_size
806
+ #self.rnn_type = opt.rnn_type
807
+ self.rnn_size = opt.rnn_size
808
+ #self.num_layers = opt.num_layers
809
+ self.drop_prob_lm = opt.drop_prob_lm
810
+ self.fc_feat_size = opt.fc_feat_size
811
+ self.att_feat_size = opt.att_feat_size
812
+ self.att_hid_size = opt.att_hid_size
813
+
814
+ # Build a LSTM
815
+ self.a2h = nn.Linear(self.rnn_size, 5 * self.rnn_size)
816
+ self.i2h = nn.Linear(self.input_encoding_size, 5 * self.rnn_size)
817
+ self.h2h = nn.Linear(self.rnn_size, 5 * self.rnn_size)
818
+ self.dropout = nn.Dropout(self.drop_prob_lm)
819
+
820
+ self.attention = Attention(opt)
821
+
822
+ def forward(self, xt, fc_feats, att_feats, p_att_feats, state, att_masks=None):
823
+ att_res = self.attention(state[0][-1], att_feats, p_att_feats, att_masks)
824
+
825
+ all_input_sums = self.i2h(xt) + self.h2h(state[0][-1]) + self.a2h(att_res)
826
+ sigmoid_chunk = all_input_sums.narrow(1, 0, 3 * self.rnn_size)
827
+ sigmoid_chunk = torch.sigmoid(sigmoid_chunk)
828
+ in_gate = sigmoid_chunk.narrow(1, 0, self.rnn_size)
829
+ forget_gate = sigmoid_chunk.narrow(1, self.rnn_size, self.rnn_size)
830
+ out_gate = sigmoid_chunk.narrow(1, self.rnn_size * 2, self.rnn_size)
831
+
832
+ in_transform = all_input_sums.narrow(1, 3 * self.rnn_size, 2 * self.rnn_size)
833
+ in_transform = torch.max(\
834
+ in_transform.narrow(1, 0, self.rnn_size),
835
+ in_transform.narrow(1, self.rnn_size, self.rnn_size))
836
+ next_c = forget_gate * state[1][-1] + in_gate * in_transform
837
+ next_h = out_gate * torch.tanh(next_c)
838
+
839
+ output = self.dropout(next_h)
840
+ state = (next_h.unsqueeze(0), next_c.unsqueeze(0))
841
+ return output, state
842
+
843
+ class AdaAttModel(AttModel):
844
+ def __init__(self, opt):
845
+ super(AdaAttModel, self).__init__(opt)
846
+ self.core = AdaAttCore(opt)
847
+
848
+ # AdaAtt with maxout lstm
849
+ class AdaAttMOModel(AttModel):
850
+ def __init__(self, opt):
851
+ super(AdaAttMOModel, self).__init__(opt)
852
+ self.core = AdaAttCore(opt, True)
853
+
854
+ class Att2in2Model(AttModel):
855
+ def __init__(self, opt):
856
+ super(Att2in2Model, self).__init__(opt)
857
+ self.core = Att2in2Core(opt)
858
+ delattr(self, 'fc_embed')
859
+ self.fc_embed = lambda x : x
860
+
861
+ class Att2all2Model(AttModel):
862
+ def __init__(self, opt):
863
+ super(Att2all2Model, self).__init__(opt)
864
+ self.core = Att2all2Core(opt)
865
+ delattr(self, 'fc_embed')
866
+ self.fc_embed = lambda x : x
867
+
868
+ class UpDownModel(AttModel):
869
+ def __init__(self, opt):
870
+ super(UpDownModel, self).__init__(opt)
871
+ self.num_layers = 2
872
+ self.core = UpDownCore(opt)
873
+
874
+ class StackAttModel(AttModel):
875
+ def __init__(self, opt):
876
+ super(StackAttModel, self).__init__(opt)
877
+ self.num_layers = 3
878
+ self.core = StackAttCore(opt)
879
+
880
+ class DenseAttModel(AttModel):
881
+ def __init__(self, opt):
882
+ super(DenseAttModel, self).__init__(opt)
883
+ self.num_layers = 3
884
+ self.core = DenseAttCore(opt)
885
+
886
+ class Att2inModel(AttModel):
887
+ def __init__(self, opt):
888
+ super(Att2inModel, self).__init__(opt)
889
+ del self.embed, self.fc_embed, self.att_embed
890
+ self.embed = nn.Embedding(self.vocab_size + 1, self.input_encoding_size)
891
+ self.fc_embed = self.att_embed = lambda x: x
892
+ del self.ctx2att
893
+ self.ctx2att = nn.Linear(self.att_feat_size, self.att_hid_size)
894
+ self.core = Att2inCore(opt)
895
+ self.init_weights()
896
+
897
+ def init_weights(self):
898
+ initrange = 0.1
899
+ self.embed.weight.data.uniform_(-initrange, initrange)
900
+ self.logit.bias.data.fill_(0)
901
+ self.logit.weight.data.uniform_(-initrange, initrange)
902
+
903
+
904
+ class NewFCModel(AttModel):
905
+ def __init__(self, opt):
906
+ super(NewFCModel, self).__init__(opt)
907
+ self.fc_embed = nn.Linear(self.fc_feat_size, self.input_encoding_size)
908
+ self.embed = nn.Embedding(self.vocab_size + 1, self.input_encoding_size)
909
+ self._core = LSTMCore(opt)
910
+ delattr(self, 'att_embed')
911
+ self.att_embed = lambda x : x
912
+ delattr(self, 'ctx2att')
913
+ self.ctx2att = lambda x: x
914
+
915
+ def core(self, xt, fc_feats, att_feats, p_att_feats, state, att_masks):
916
+ # Step 0, feed the input image
917
+ # if (self.training and state[0].is_leaf) or \
918
+ # (not self.training and state[0].sum() == 0):
919
+ # _, state = self._core(fc_feats, state)
920
+ # three cases
921
+ # normal mle training
922
+ # Sample
923
+ # beam search (diverse beam search)
924
+ # fixed captioning module.
925
+ is_first_step = (state[0]==0).all(2).all(0) # size: B
926
+ if is_first_step.all():
927
+ _, state = self._core(fc_feats, state)
928
+ elif is_first_step.any():
929
+ # This is mostly for diverse beam search I think
930
+ new_state = [torch.zeros_like(_) for _ in state]
931
+ new_state[0][:, ~is_first_step] = state[0][:, ~is_first_step]
932
+ new_state[1][:, ~is_first_step] = state[1][:, ~is_first_step]
933
+ _, state = self._core(fc_feats, state)
934
+ new_state[0][:, is_first_step] = state[0][:, is_first_step]
935
+ new_state[1][:, is_first_step] = state[1][:, is_first_step]
936
+ state = new_state
937
+ # if (state[0]==0).all():
938
+ # # Let's forget about diverse beam search first
939
+ # _, state = self._core(fc_feats, state)
940
+ return self._core(xt, state)
941
+
942
+ def _prepare_feature(self, fc_feats, att_feats, att_masks):
943
+ fc_feats = self.fc_embed(fc_feats)
944
+
945
+ return fc_feats, att_feats, att_feats, att_masks
946
+
947
+
948
+ class LMModel(AttModel):
949
+ def __init__(self, opt):
950
+ super(LMModel, self).__init__(opt)
951
+ delattr(self, 'fc_embed')
952
+ self.fc_embed = lambda x: x.new_zeros(x.shape[0], self.input_encoding_size)
953
+ self.embed = nn.Embedding(self.vocab_size + 1, self.input_encoding_size)
954
+ self._core = LSTMCore(opt)
955
+ delattr(self, 'att_embed')
956
+ self.att_embed = lambda x : x
957
+ delattr(self, 'ctx2att')
958
+ self.ctx2att = lambda x: x
959
+
960
+ def core(self, xt, fc_feats, att_feats, p_att_feats, state, att_masks):
961
+ if (state[0]==0).all():
962
+ # Let's forget about diverse beam search first
963
+ _, state = self._core(fc_feats, state)
964
+ return self._core(xt, state)
965
+
966
+ def _prepare_feature(self, fc_feats, att_feats, att_masks):
967
+ fc_feats = self.fc_embed(fc_feats)
968
+
969
+ return fc_feats, None, None, None
captioning/models/BertCapModel.py ADDED
@@ -0,0 +1,104 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ BertCapModel is using huggingface transformer bert model as seq2seq model.
3
+
4
+ The result is not as goog as original transformer.
5
+ """
6
+
7
+ from __future__ import absolute_import
8
+ from __future__ import division
9
+ from __future__ import print_function
10
+
11
+ import torch
12
+ import torch.nn as nn
13
+ import torch.nn.functional as F
14
+
15
+ import copy
16
+ import math
17
+ import numpy as np
18
+
19
+ from .CaptionModel import CaptionModel
20
+ from .AttModel import sort_pack_padded_sequence, pad_unsort_packed_sequence, pack_wrapper, AttModel
21
+ try:
22
+ from transformers import BertModel, BertConfig
23
+ except:
24
+ print('Hugginface transformers not installed; please visit https://github.com/huggingface/transformers')
25
+ from .TransformerModel import subsequent_mask, TransformerModel, Generator
26
+
27
+ class EncoderDecoder(nn.Module):
28
+ """
29
+ A standard Encoder-Decoder architecture. Base for this and many
30
+ other models.
31
+ """
32
+ def __init__(self, encoder, decoder, generator):
33
+ super(EncoderDecoder, self).__init__()
34
+ self.encoder = encoder
35
+ self.decoder = decoder
36
+ self.generator = generator
37
+
38
+ def forward(self, src, tgt, src_mask, tgt_mask):
39
+ "Take in and process masked src and target sequences."
40
+ return self.decode(self.encode(src, src_mask), src_mask,
41
+ tgt, tgt_mask)
42
+
43
+ def encode(self, src, src_mask):
44
+ return self.encoder(inputs_embeds=src,
45
+ attention_mask=src_mask)[0]
46
+
47
+ def decode(self, memory, src_mask, tgt, tgt_mask):
48
+ return self.decoder(input_ids=tgt,
49
+ attention_mask=tgt_mask,
50
+ encoder_hidden_states=memory,
51
+ encoder_attention_mask=src_mask)[0]
52
+
53
+
54
+ class BertCapModel(TransformerModel):
55
+
56
+ def make_model(self, src_vocab, tgt_vocab, N_enc=6, N_dec=6,
57
+ d_model=512, d_ff=2048, h=8, dropout=0.1):
58
+ "Helper: Construct a model from hyperparameters."
59
+ enc_config = BertConfig(vocab_size=1,
60
+ hidden_size=d_model,
61
+ num_hidden_layers=N_enc,
62
+ num_attention_heads=h,
63
+ intermediate_size=d_ff,
64
+ hidden_dropout_prob=dropout,
65
+ attention_probs_dropout_prob=dropout,
66
+ max_position_embeddings=1,
67
+ type_vocab_size=1)
68
+ dec_config = BertConfig(vocab_size=tgt_vocab,
69
+ hidden_size=d_model,
70
+ num_hidden_layers=N_dec,
71
+ num_attention_heads=h,
72
+ intermediate_size=d_ff,
73
+ hidden_dropout_prob=dropout,
74
+ attention_probs_dropout_prob=dropout,
75
+ max_position_embeddings=17,
76
+ type_vocab_size=1,
77
+ is_decoder=True)
78
+ encoder = BertModel(enc_config)
79
+ def return_embeds(*args, **kwargs):
80
+ return kwargs['inputs_embeds']
81
+ del encoder.embeddings; encoder.embeddings = return_embeds
82
+ decoder = BertModel(dec_config)
83
+ model = EncoderDecoder(
84
+ encoder,
85
+ decoder,
86
+ Generator(d_model, tgt_vocab))
87
+ return model
88
+
89
+ def __init__(self, opt):
90
+ super(BertCapModel, self).__init__(opt)
91
+
92
+ def core(self, it, fc_feats_ph, att_feats_ph, memory, state, mask):
93
+ """
94
+ state = [ys.unsqueeze(0)]
95
+ """
96
+ if len(state) == 0:
97
+ ys = it.unsqueeze(1)
98
+ else:
99
+ ys = torch.cat([state[0][0], it.unsqueeze(1)], dim=1)
100
+ out = self.model.decode(memory, mask,
101
+ ys,
102
+ subsequent_mask(ys.size(1))
103
+ .to(memory.device))
104
+ return out[:, -1], [ys.unsqueeze(0)]
captioning/models/CaptionModel.py ADDED
@@ -0,0 +1,407 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # This file contains ShowAttendTell and AllImg model
2
+
3
+ # ShowAttendTell is from Show, Attend and Tell: Neural Image Caption Generation with Visual Attention
4
+ # https://arxiv.org/abs/1502.03044
5
+
6
+ # AllImg is a model where
7
+ # img feature is concatenated with word embedding at every time step as the input of lstm
8
+ from __future__ import absolute_import
9
+ from __future__ import division
10
+ from __future__ import print_function
11
+
12
+ import numpy as np
13
+ import torch
14
+ import torch.nn as nn
15
+ import torch.nn.functional as F
16
+ from torch.autograd import *
17
+ from ..utils import misc as utils
18
+ from . import utils as model_utils
19
+
20
+
21
+ class CaptionModel(nn.Module):
22
+ def __init__(self):
23
+ super(CaptionModel, self).__init__()
24
+
25
+ # implements beam search
26
+ # calls beam_step and returns the final set of beams
27
+ # augments log-probabilities with diversity terms when number of groups > 1
28
+
29
+ def forward(self, *args, **kwargs):
30
+ mode = kwargs.get('mode', 'forward')
31
+ if 'mode' in kwargs:
32
+ del kwargs['mode']
33
+ return getattr(self, '_'+mode)(*args, **kwargs)
34
+
35
+ def beam_search(self, init_state, init_logprobs, *args, **kwargs):
36
+
37
+ # function computes the similarity score to be augmented
38
+ def add_diversity(beam_seq_table, logprobs, t, divm, diversity_lambda, bdash):
39
+ local_time = t - divm
40
+ unaug_logprobs = logprobs.clone()
41
+ batch_size = beam_seq_table[0].shape[0]
42
+
43
+ if divm > 0:
44
+ change = logprobs.new_zeros(batch_size, logprobs.shape[-1])
45
+ for prev_choice in range(divm):
46
+ prev_decisions = beam_seq_table[prev_choice][:, :, local_time] # Nxb
47
+ for prev_labels in range(bdash):
48
+ change.scatter_add_(1, prev_decisions[:, prev_labels].unsqueeze(-1), change.new_ones(batch_size, 1))
49
+
50
+ if local_time == 0:
51
+ logprobs = logprobs - change * diversity_lambda
52
+ else:
53
+ logprobs = logprobs - self.repeat_tensor(bdash, change) * diversity_lambda
54
+
55
+ return logprobs, unaug_logprobs
56
+
57
+
58
+ # does one step of classical beam search
59
+
60
+ def beam_step(logprobs, unaug_logprobs, beam_size, t, beam_seq, beam_seq_logprobs, beam_logprobs_sum, state):
61
+ #INPUTS:
62
+ #logprobs: probabilities augmented after diversity N*bxV
63
+ #beam_size: obvious
64
+ #t : time instant
65
+ #beam_seq : tensor contanining the beams
66
+ #beam_seq_logprobs: tensor contanining the beam logprobs
67
+ #beam_logprobs_sum: tensor contanining joint logprobs
68
+ #OUPUTS:
69
+ #beam_seq : tensor containing the word indices of the decoded captions Nxbxl
70
+ #beam_seq_logprobs : log-probability of each decision made, NxbxlxV
71
+ #beam_logprobs_sum : joint log-probability of each beam Nxb
72
+
73
+ batch_size = beam_logprobs_sum.shape[0]
74
+ vocab_size = logprobs.shape[-1]
75
+ logprobs = logprobs.reshape(batch_size, -1, vocab_size) # NxbxV
76
+ if t == 0:
77
+ assert logprobs.shape[1] == 1
78
+ beam_logprobs_sum = beam_logprobs_sum[:, :1]
79
+ candidate_logprobs = beam_logprobs_sum.unsqueeze(-1) + logprobs # beam_logprobs_sum Nxb logprobs is NxbxV
80
+ ys, ix = torch.sort(candidate_logprobs.reshape(candidate_logprobs.shape[0], -1), -1, True)
81
+ ys, ix = ys[:,:beam_size], ix[:,:beam_size]
82
+ beam_ix = ix // vocab_size # Nxb which beam
83
+ selected_ix = ix % vocab_size # Nxb # which world
84
+ state_ix = (beam_ix + torch.arange(batch_size).type_as(beam_ix).unsqueeze(-1) * logprobs.shape[1]).reshape(-1) # N*b which in Nxb beams
85
+
86
+
87
+ if t > 0:
88
+ # gather according to beam_ix
89
+ assert (beam_seq.gather(1, beam_ix.unsqueeze(-1).expand_as(beam_seq)) == beam_seq.reshape(-1, beam_seq.shape[-1])[state_ix].view_as(beam_seq)).all()
90
+ beam_seq = beam_seq.gather(1, beam_ix.unsqueeze(-1).expand_as(beam_seq))
91
+
92
+ beam_seq_logprobs = beam_seq_logprobs.gather(1, beam_ix.unsqueeze(-1).unsqueeze(-1).expand_as(beam_seq_logprobs))
93
+
94
+ beam_seq = torch.cat([beam_seq, selected_ix.unsqueeze(-1)], -1) # beam_seq Nxbxl
95
+ beam_logprobs_sum = beam_logprobs_sum.gather(1, beam_ix) + \
96
+ logprobs.reshape(batch_size, -1).gather(1, ix)
97
+ assert (beam_logprobs_sum == ys).all()
98
+ _tmp_beam_logprobs = unaug_logprobs[state_ix].reshape(batch_size, -1, vocab_size)
99
+ beam_logprobs = unaug_logprobs.reshape(batch_size, -1, vocab_size).gather(1, beam_ix.unsqueeze(-1).expand(-1, -1, vocab_size)) # NxbxV
100
+ assert (_tmp_beam_logprobs == beam_logprobs).all()
101
+ beam_seq_logprobs = torch.cat([
102
+ beam_seq_logprobs,
103
+ beam_logprobs.reshape(batch_size, -1, 1, vocab_size)], 2)
104
+
105
+ new_state = [None for _ in state]
106
+ for _ix in range(len(new_state)):
107
+ # copy over state in previous beam q to new beam at vix
108
+ new_state[_ix] = state[_ix][:, state_ix]
109
+ state = new_state
110
+ return beam_seq,beam_seq_logprobs,beam_logprobs_sum,state
111
+
112
+ # Start diverse_beam_search
113
+ opt = kwargs['opt']
114
+ temperature = opt.get('temperature', 1) # This should not affect beam search, but will affect dbs
115
+ beam_size = opt.get('beam_size', 10)
116
+ group_size = opt.get('group_size', 1)
117
+ diversity_lambda = opt.get('diversity_lambda', 0.5)
118
+ decoding_constraint = opt.get('decoding_constraint', 0)
119
+ remove_bad_endings = opt.get('remove_bad_endings', 0)
120
+ suppress_UNK = opt.get('suppress_UNK', 0)
121
+ length_penalty = utils.penalty_builder(opt.get('length_penalty', ''))
122
+ bdash = beam_size // group_size # beam per group
123
+
124
+ batch_size = init_logprobs.shape[0]
125
+ device = init_logprobs.device
126
+ # INITIALIZATIONS
127
+ beam_seq_table = [torch.LongTensor(batch_size, bdash, 0).to(device) for _ in range(group_size)]
128
+ beam_seq_logprobs_table = [torch.FloatTensor(batch_size, bdash, 0, self.vocab_size + 1).to(device) for _ in range(group_size)]
129
+ beam_logprobs_sum_table = [torch.zeros(batch_size, bdash).to(device) for _ in range(group_size)]
130
+
131
+ # logprobs # logprobs predicted in last time step, shape (beam_size, vocab_size+1)
132
+ done_beams_table = [[[] for __ in range(group_size)] for _ in range(batch_size)]
133
+ # state_table = [list(torch.unbind(_)) for _ in torch.stack(init_state).chunk(group_size, 2)]
134
+ # state_table = list(zip(*[_.reshape(-1, batch_size * bdash, group_size, *_.shape[2:]).chunk(group_size, 2) for _ in init_state]))
135
+ state_table = [[_.clone() for _ in init_state] for _ in range(group_size)]
136
+ # logprobs_table = list(init_logprobs.reshape(batch_size * bdash, group_size, -1).chunk(group_size, 0))
137
+ logprobs_table = [init_logprobs.clone() for _ in range(group_size)]
138
+ # END INIT
139
+
140
+ # Chunk elements in the args
141
+ args = list(args)
142
+ args = model_utils.split_tensors(group_size, args) # For each arg, turn (Bbg)x... to (Bb)x(g)x...
143
+ if self.__class__.__name__ == 'AttEnsemble':
144
+ args = [[[args[j][i][k] for i in range(len(self.models))] for j in range(len(args))] for k in range(group_size)] # group_name, arg_name, model_name
145
+ else:
146
+ args = [[args[i][j] for i in range(len(args))] for j in range(group_size)]
147
+
148
+ for t in range(self.seq_length + group_size - 1):
149
+ for divm in range(group_size):
150
+ if t >= divm and t <= self.seq_length + divm - 1:
151
+ # add diversity
152
+ logprobs = logprobs_table[divm]
153
+ # suppress previous word
154
+ if decoding_constraint and t-divm > 0:
155
+ logprobs.scatter_(1, beam_seq_table[divm][:, :, t-divm-1].reshape(-1, 1).to(device), float('-inf'))
156
+ if remove_bad_endings and t-divm > 0:
157
+ logprobs[torch.from_numpy(np.isin(beam_seq_table[divm][:, :, t-divm-1].cpu().numpy(), self.bad_endings_ix)).reshape(-1), 0] = float('-inf')
158
+ # suppress UNK tokens in the decoding
159
+ if suppress_UNK and hasattr(self, 'vocab') and self.vocab[str(logprobs.size(1)-1)] == 'UNK':
160
+ logprobs[:,logprobs.size(1)-1] = logprobs[:, logprobs.size(1)-1] - 1000
161
+ # diversity is added here
162
+ # the function directly modifies the logprobs values and hence, we need to return
163
+ # the unaugmented ones for sorting the candidates in the end. # for historical
164
+ # reasons :-)
165
+ logprobs, unaug_logprobs = add_diversity(beam_seq_table,logprobs,t,divm,diversity_lambda,bdash)
166
+
167
+ # infer new beams
168
+ beam_seq_table[divm],\
169
+ beam_seq_logprobs_table[divm],\
170
+ beam_logprobs_sum_table[divm],\
171
+ state_table[divm] = beam_step(logprobs,
172
+ unaug_logprobs,
173
+ bdash,
174
+ t-divm,
175
+ beam_seq_table[divm],
176
+ beam_seq_logprobs_table[divm],
177
+ beam_logprobs_sum_table[divm],
178
+ state_table[divm])
179
+
180
+ # if time's up... or if end token is reached then copy beams
181
+ for b in range(batch_size):
182
+ is_end = beam_seq_table[divm][b, :, t-divm] == self.eos_idx
183
+ assert beam_seq_table[divm].shape[-1] == t-divm+1
184
+ if t == self.seq_length + divm - 1:
185
+ is_end.fill_(1)
186
+ for vix in range(bdash):
187
+ if is_end[vix]:
188
+ final_beam = {
189
+ 'seq': beam_seq_table[divm][b, vix].clone(),
190
+ 'logps': beam_seq_logprobs_table[divm][b, vix].clone(),
191
+ 'unaug_p': beam_seq_logprobs_table[divm][b, vix].sum().item(),
192
+ 'p': beam_logprobs_sum_table[divm][b, vix].item()
193
+ }
194
+ final_beam['p'] = length_penalty(t-divm+1, final_beam['p'])
195
+ done_beams_table[b][divm].append(final_beam)
196
+ beam_logprobs_sum_table[divm][b, is_end] -= 1000
197
+
198
+ # move the current group one step forward in time
199
+
200
+ it = beam_seq_table[divm][:, :, t-divm].reshape(-1).to(logprobs.device)
201
+ logprobs_table[divm], state_table[divm] = self.get_logprobs_state(it, *(args[divm] + [state_table[divm]]))
202
+ logprobs_table[divm] = F.log_softmax(logprobs_table[divm] / temperature, dim=-1)
203
+
204
+ # all beams are sorted by their log-probabilities
205
+ done_beams_table = [[sorted(done_beams_table[b][i], key=lambda x: -x['p'])[:bdash] for i in range(group_size)] for b in range(batch_size)]
206
+ done_beams = [sum(_, []) for _ in done_beams_table]
207
+ return done_beams
208
+
209
+ def old_beam_search(self, init_state, init_logprobs, *args, **kwargs):
210
+
211
+ # function computes the similarity score to be augmented
212
+ def add_diversity(beam_seq_table, logprobsf, t, divm, diversity_lambda, bdash):
213
+ local_time = t - divm
214
+ unaug_logprobsf = logprobsf.clone()
215
+ for prev_choice in range(divm):
216
+ prev_decisions = beam_seq_table[prev_choice][local_time]
217
+ for sub_beam in range(bdash):
218
+ for prev_labels in range(bdash):
219
+ logprobsf[sub_beam][prev_decisions[prev_labels]] = logprobsf[sub_beam][prev_decisions[prev_labels]] - diversity_lambda
220
+ return unaug_logprobsf
221
+
222
+ # does one step of classical beam search
223
+
224
+ def beam_step(logprobsf, unaug_logprobsf, beam_size, t, beam_seq, beam_seq_logprobs, beam_logprobs_sum, state):
225
+ #INPUTS:
226
+ #logprobsf: probabilities augmented after diversity
227
+ #beam_size: obvious
228
+ #t : time instant
229
+ #beam_seq : tensor contanining the beams
230
+ #beam_seq_logprobs: tensor contanining the beam logprobs
231
+ #beam_logprobs_sum: tensor contanining joint logprobs
232
+ #OUPUTS:
233
+ #beam_seq : tensor containing the word indices of the decoded captions
234
+ #beam_seq_logprobs : log-probability of each decision made, same size as beam_seq
235
+ #beam_logprobs_sum : joint log-probability of each beam
236
+
237
+ ys,ix = torch.sort(logprobsf,1,True)
238
+ candidates = []
239
+ cols = min(beam_size, ys.size(1))
240
+ rows = beam_size
241
+ if t == 0:
242
+ rows = 1
243
+ for c in range(cols): # for each column (word, essentially)
244
+ for q in range(rows): # for each beam expansion
245
+ #compute logprob of expanding beam q with word in (sorted) position c
246
+ local_logprob = ys[q,c].item()
247
+ candidate_logprob = beam_logprobs_sum[q] + local_logprob
248
+ # local_unaug_logprob = unaug_logprobsf[q,ix[q,c]]
249
+ candidates.append({'c':ix[q,c], 'q':q, 'p':candidate_logprob, 'r':unaug_logprobsf[q]})
250
+ candidates = sorted(candidates, key=lambda x: -x['p'])
251
+
252
+ new_state = [_.clone() for _ in state]
253
+ #beam_seq_prev, beam_seq_logprobs_prev
254
+ if t >= 1:
255
+ #we''ll need these as reference when we fork beams around
256
+ beam_seq_prev = beam_seq[:t].clone()
257
+ beam_seq_logprobs_prev = beam_seq_logprobs[:t].clone()
258
+ for vix in range(beam_size):
259
+ v = candidates[vix]
260
+ #fork beam index q into index vix
261
+ if t >= 1:
262
+ beam_seq[:t, vix] = beam_seq_prev[:, v['q']]
263
+ beam_seq_logprobs[:t, vix] = beam_seq_logprobs_prev[:, v['q']]
264
+ #rearrange recurrent states
265
+ for state_ix in range(len(new_state)):
266
+ # copy over state in previous beam q to new beam at vix
267
+ new_state[state_ix][:, vix] = state[state_ix][:, v['q']] # dimension one is time step
268
+ #append new end terminal at the end of this beam
269
+ beam_seq[t, vix] = v['c'] # c'th word is the continuation
270
+ beam_seq_logprobs[t, vix] = v['r'] # the raw logprob here
271
+ beam_logprobs_sum[vix] = v['p'] # the new (sum) logprob along this beam
272
+ state = new_state
273
+ return beam_seq,beam_seq_logprobs,beam_logprobs_sum,state,candidates
274
+
275
+ # Start diverse_beam_search
276
+ opt = kwargs['opt']
277
+ temperature = opt.get('temperature', 1) # This should not affect beam search, but will affect dbs
278
+ beam_size = opt.get('beam_size', 10)
279
+ group_size = opt.get('group_size', 1)
280
+ diversity_lambda = opt.get('diversity_lambda', 0.5)
281
+ decoding_constraint = opt.get('decoding_constraint', 0)
282
+ remove_bad_endings = opt.get('remove_bad_endings', 0)
283
+ suppress_UNK = opt.get('suppress_UNK', 0)
284
+ length_penalty = utils.penalty_builder(opt.get('length_penalty', ''))
285
+ bdash = beam_size // group_size # beam per group
286
+
287
+ # INITIALIZATIONS
288
+ beam_seq_table = [torch.LongTensor(self.seq_length, bdash).zero_() for _ in range(group_size)]
289
+ beam_seq_logprobs_table = [torch.FloatTensor(self.seq_length, bdash, self.vocab_size + 1).zero_() for _ in range(group_size)]
290
+ beam_logprobs_sum_table = [torch.zeros(bdash) for _ in range(group_size)]
291
+
292
+ # logprobs # logprobs predicted in last time step, shape (beam_size, vocab_size+1)
293
+ done_beams_table = [[] for _ in range(group_size)]
294
+ # state_table = [list(torch.unbind(_)) for _ in torch.stack(init_state).chunk(group_size, 2)]
295
+ state_table = list(zip(*[_.chunk(group_size, 1) for _ in init_state]))
296
+ logprobs_table = list(init_logprobs.chunk(group_size, 0))
297
+ # END INIT
298
+
299
+ # Chunk elements in the args
300
+ args = list(args)
301
+ if self.__class__.__name__ == 'AttEnsemble':
302
+ args = [[_.chunk(group_size) if _ is not None else [None]*group_size for _ in args_] for args_ in args] # arg_name, model_name, group_name
303
+ args = [[[args[j][i][k] for i in range(len(self.models))] for j in range(len(args))] for k in range(group_size)] # group_name, arg_name, model_name
304
+ else:
305
+ args = [_.chunk(group_size) if _ is not None else [None]*group_size for _ in args]
306
+ args = [[args[i][j] for i in range(len(args))] for j in range(group_size)]
307
+
308
+ for t in range(self.seq_length + group_size - 1):
309
+ for divm in range(group_size):
310
+ if t >= divm and t <= self.seq_length + divm - 1:
311
+ # add diversity
312
+ logprobsf = logprobs_table[divm]
313
+ # suppress previous word
314
+ if decoding_constraint and t-divm > 0:
315
+ logprobsf.scatter_(1, beam_seq_table[divm][t-divm-1].unsqueeze(1).to(logprobsf.device), float('-inf'))
316
+ if remove_bad_endings and t-divm > 0:
317
+ logprobsf[torch.from_numpy(np.isin(beam_seq_table[divm][t-divm-1].cpu().numpy(), self.bad_endings_ix)), 0] = float('-inf')
318
+ # suppress UNK tokens in the decoding
319
+ if suppress_UNK and hasattr(self, 'vocab') and self.vocab[str(logprobsf.size(1)-1)] == 'UNK':
320
+ logprobsf[:,logprobsf.size(1)-1] = logprobsf[:, logprobsf.size(1)-1] - 1000
321
+ # diversity is added here
322
+ # the function directly modifies the logprobsf values and hence, we need to return
323
+ # the unaugmented ones for sorting the candidates in the end. # for historical
324
+ # reasons :-)
325
+ unaug_logprobsf = add_diversity(beam_seq_table,logprobsf,t,divm,diversity_lambda,bdash)
326
+
327
+ # infer new beams
328
+ beam_seq_table[divm],\
329
+ beam_seq_logprobs_table[divm],\
330
+ beam_logprobs_sum_table[divm],\
331
+ state_table[divm],\
332
+ candidates_divm = beam_step(logprobsf,
333
+ unaug_logprobsf,
334
+ bdash,
335
+ t-divm,
336
+ beam_seq_table[divm],
337
+ beam_seq_logprobs_table[divm],
338
+ beam_logprobs_sum_table[divm],
339
+ state_table[divm])
340
+
341
+ # if time's up... or if end token is reached then copy beams
342
+ for vix in range(bdash):
343
+ if beam_seq_table[divm][t-divm,vix] == self.eos_idx or t == self.seq_length + divm - 1:
344
+ final_beam = {
345
+ 'seq': beam_seq_table[divm][:, vix].clone(),
346
+ 'logps': beam_seq_logprobs_table[divm][:, vix].clone(),
347
+ 'unaug_p': beam_seq_logprobs_table[divm][:, vix].sum().item(),
348
+ 'p': beam_logprobs_sum_table[divm][vix].item()
349
+ }
350
+ final_beam['p'] = length_penalty(t-divm+1, final_beam['p'])
351
+ done_beams_table[divm].append(final_beam)
352
+ # don't continue beams from finished sequences
353
+ beam_logprobs_sum_table[divm][vix] = -1000
354
+
355
+ # move the current group one step forward in time
356
+
357
+ it = beam_seq_table[divm][t-divm].to(logprobsf.device)
358
+ logprobs_table[divm], state_table[divm] = self.get_logprobs_state(it, *(args[divm] + [state_table[divm]]))
359
+ logprobs_table[divm] = F.log_softmax(logprobs_table[divm] / temperature, dim=-1)
360
+
361
+ # all beams are sorted by their log-probabilities
362
+ done_beams_table = [sorted(done_beams_table[i], key=lambda x: -x['p'])[:bdash] for i in range(group_size)]
363
+ done_beams = sum(done_beams_table, [])
364
+ return done_beams
365
+
366
+ def sample_next_word(self, logprobs, sample_method, temperature):
367
+ if sample_method == 'greedy':
368
+ sampleLogprobs, it = torch.max(logprobs.data, 1)
369
+ it = it.view(-1).long()
370
+ elif sample_method == 'gumbel': # gumbel softmax
371
+ # ref: https://gist.github.com/yzh119/fd2146d2aeb329d067568a493b20172f
372
+ def sample_gumbel(shape, eps=1e-20):
373
+ U = torch.rand(shape).to(logprobs.device)
374
+ return -torch.log(-torch.log(U + eps) + eps)
375
+ def gumbel_softmax_sample(logits, temperature):
376
+ y = logits + sample_gumbel(logits.size())
377
+ return F.log_softmax(y / temperature, dim=-1)
378
+ _logprobs = gumbel_softmax_sample(logprobs, temperature)
379
+ _, it = torch.max(_logprobs.data, 1)
380
+ sampleLogprobs = logprobs.gather(1, it.unsqueeze(1)) # gather the logprobs at sampled positions
381
+ else:
382
+ logprobs = logprobs / temperature
383
+ if sample_method.startswith('top'): # topk sampling
384
+ top_num = float(sample_method[3:])
385
+ if 0 < top_num < 1:
386
+ # nucleus sampling from # The Curious Case of Neural Text Degeneration
387
+ probs = F.softmax(logprobs, dim=1)
388
+ sorted_probs, sorted_indices = torch.sort(probs, descending=True, dim=1)
389
+ _cumsum = sorted_probs.cumsum(1)
390
+ mask = _cumsum < top_num
391
+ mask = torch.cat([torch.ones_like(mask[:,:1]), mask[:,:-1]], 1)
392
+ sorted_probs = sorted_probs * mask.to(sorted_probs)
393
+ sorted_probs = sorted_probs / sorted_probs.sum(1, keepdim=True)
394
+ logprobs.scatter_(1, sorted_indices, sorted_probs.log())
395
+ else:
396
+ the_k = int(top_num)
397
+ tmp = torch.empty_like(logprobs).fill_(float('-inf'))
398
+ topk, indices = torch.topk(logprobs, the_k, dim=1)
399
+ tmp = tmp.scatter(1, indices, topk)
400
+ logprobs = tmp
401
+ it = torch.distributions.Categorical(logits=logprobs.detach()).sample()
402
+ sampleLogprobs = logprobs.gather(1, it.unsqueeze(1)) # gather the logprobs at sampled positions
403
+ return it, sampleLogprobs
404
+
405
+
406
+ def decode_sequence(self, seq):
407
+ return utils.decode_sequence(self.vocab, seq)
captioning/models/FCModel.py ADDED
@@ -0,0 +1,204 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import absolute_import
2
+ from __future__ import division
3
+ from __future__ import print_function
4
+
5
+ import torch
6
+ import torch.nn as nn
7
+ import torch.nn.functional as F
8
+ from torch.autograd import *
9
+ from . import utils
10
+
11
+ from .CaptionModel import CaptionModel
12
+
13
+ class LSTMCore(nn.Module):
14
+ def __init__(self, opt):
15
+ super(LSTMCore, self).__init__()
16
+ self.input_encoding_size = opt.input_encoding_size
17
+ self.rnn_size = opt.rnn_size
18
+ self.drop_prob_lm = opt.drop_prob_lm
19
+
20
+ # Build a LSTM
21
+ self.i2h = nn.Linear(self.input_encoding_size, 5 * self.rnn_size)
22
+ self.h2h = nn.Linear(self.rnn_size, 5 * self.rnn_size)
23
+ self.dropout = nn.Dropout(self.drop_prob_lm)
24
+
25
+ def forward(self, xt, state):
26
+
27
+ all_input_sums = self.i2h(xt) + self.h2h(state[0][-1])
28
+ sigmoid_chunk = all_input_sums.narrow(1, 0, 3 * self.rnn_size)
29
+ sigmoid_chunk = torch.sigmoid(sigmoid_chunk)
30
+ in_gate = sigmoid_chunk.narrow(1, 0, self.rnn_size)
31
+ forget_gate = sigmoid_chunk.narrow(1, self.rnn_size, self.rnn_size)
32
+ out_gate = sigmoid_chunk.narrow(1, self.rnn_size * 2, self.rnn_size)
33
+
34
+ in_transform = torch.max(\
35
+ all_input_sums.narrow(1, 3 * self.rnn_size, self.rnn_size),
36
+ all_input_sums.narrow(1, 4 * self.rnn_size, self.rnn_size))
37
+ next_c = forget_gate * state[1][-1] + in_gate * in_transform
38
+ next_h = out_gate * torch.tanh(next_c)
39
+
40
+ output = self.dropout(next_h)
41
+ state = (next_h.unsqueeze(0), next_c.unsqueeze(0))
42
+ return output, state
43
+
44
+ class FCModel(CaptionModel):
45
+ def __init__(self, opt):
46
+ super(FCModel, self).__init__()
47
+ self.vocab_size = opt.vocab_size
48
+ self.input_encoding_size = opt.input_encoding_size
49
+ self.rnn_type = opt.rnn_type
50
+ self.rnn_size = opt.rnn_size
51
+ self.num_layers = opt.num_layers
52
+ self.drop_prob_lm = opt.drop_prob_lm
53
+ self.seq_length = opt.seq_length
54
+ self.fc_feat_size = opt.fc_feat_size
55
+
56
+ self.ss_prob = 0.0 # Schedule sampling probability
57
+
58
+ self.img_embed = nn.Linear(self.fc_feat_size, self.input_encoding_size)
59
+ self.core = LSTMCore(opt)
60
+ self.embed = nn.Embedding(self.vocab_size + 1, self.input_encoding_size)
61
+ self.logit = nn.Linear(self.rnn_size, self.vocab_size + 1)
62
+
63
+ self.init_weights()
64
+
65
+ def init_weights(self):
66
+ initrange = 0.1
67
+ self.embed.weight.data.uniform_(-initrange, initrange)
68
+ self.logit.bias.data.fill_(0)
69
+ self.logit.weight.data.uniform_(-initrange, initrange)
70
+
71
+ def init_hidden(self, bsz):
72
+ weight = self.logit.weight
73
+ if self.rnn_type == 'lstm':
74
+ return (weight.new_zeros(self.num_layers, bsz, self.rnn_size),
75
+ weight.new_zeros(self.num_layers, bsz, self.rnn_size))
76
+ else:
77
+ return weight.new_zeros(self.num_layers, bsz, self.rnn_size)
78
+
79
+ def _forward(self, fc_feats, att_feats, seq, att_masks=None):
80
+ batch_size = fc_feats.size(0)
81
+ seq_per_img = seq.shape[0] // batch_size
82
+ state = self.init_hidden(batch_size*seq_per_img)
83
+ outputs = []
84
+
85
+ if seq_per_img > 1:
86
+ fc_feats = utils.repeat_tensors(seq_per_img, fc_feats)
87
+
88
+ for i in range(seq.size(1) + 1):
89
+ if i == 0:
90
+ xt = self.img_embed(fc_feats)
91
+ else:
92
+ if self.training and i >= 2 and self.ss_prob > 0.0: # otherwiste no need to sample
93
+ sample_prob = fc_feats.data.new(batch_size*seq_per_img).uniform_(0, 1)
94
+ sample_mask = sample_prob < self.ss_prob
95
+ if sample_mask.sum() == 0:
96
+ it = seq[:, i-1].clone()
97
+ else:
98
+ sample_ind = sample_mask.nonzero().view(-1)
99
+ it = seq[:, i-1].data.clone()
100
+ #prob_prev = torch.exp(outputs[-1].data.index_select(0, sample_ind)) # fetch prev distribution: shape Nx(M+1)
101
+ #it.index_copy_(0, sample_ind, torch.multinomial(prob_prev, 1).view(-1))
102
+ prob_prev = torch.exp(outputs[-1].data) # fetch prev distribution: shape Nx(M+1)
103
+ it.index_copy_(0, sample_ind, torch.multinomial(prob_prev, 1).view(-1).index_select(0, sample_ind))
104
+ else:
105
+ it = seq[:, i-1].clone()
106
+ # break if all the sequences end
107
+ if i >= 2 and seq[:, i-1].sum() == 0:
108
+ break
109
+ xt = self.embed(it)
110
+
111
+ output, state = self.core(xt, state)
112
+ output = F.log_softmax(self.logit(output), dim=1)
113
+ outputs.append(output)
114
+
115
+ return torch.cat([_.unsqueeze(1) for _ in outputs[1:]], 1).contiguous()
116
+
117
+ def get_logprobs_state(self, it, state):
118
+ # 'it' is contains a word index
119
+ xt = self.embed(it)
120
+
121
+ output, state = self.core(xt, state)
122
+ logprobs = F.log_softmax(self.logit(output), dim=1)
123
+
124
+ return logprobs, state
125
+
126
+ def _sample_beam(self, fc_feats, att_feats, att_masks=None, opt={}):
127
+ beam_size = opt.get('beam_size', 10)
128
+ batch_size = fc_feats.size(0)
129
+
130
+ assert beam_size <= self.vocab_size + 1, 'lets assume this for now, otherwise this corner case causes a few headaches down the road. can be dealt with in future if needed'
131
+ seq = torch.LongTensor(self.seq_length, batch_size).zero_()
132
+ seqLogprobs = torch.FloatTensor(self.seq_length, batch_size, self.vocab_size + 1)
133
+ # lets process every image independently for now, for simplicity
134
+
135
+ self.done_beams = [[] for _ in range(batch_size)]
136
+ for k in range(batch_size):
137
+ state = self.init_hidden(beam_size)
138
+ for t in range(2):
139
+ if t == 0:
140
+ xt = self.img_embed(fc_feats[k:k+1]).expand(beam_size, self.input_encoding_size)
141
+ elif t == 1: # input <bos>
142
+ it = fc_feats.data.new(beam_size).long().zero_()
143
+ xt = self.embed(it)
144
+
145
+ output, state = self.core(xt, state)
146
+ logprobs = F.log_softmax(self.logit(output), dim=1)
147
+
148
+ self.done_beams[k] = self.beam_search(state, logprobs, opt=opt)
149
+ seq[:, k] = self.done_beams[k][0]['seq'] # the first beam has highest cumulative score
150
+ seqLogprobs[:, k] = self.done_beams[k][0]['logps']
151
+ # return the samples and their log likelihoods
152
+ return seq.transpose(0, 1), seqLogprobs.transpose(0, 1)
153
+
154
+ def _sample(self, fc_feats, att_feats, att_masks=None, opt={}):
155
+ sample_method = opt.get('sample_method', 'greedy')
156
+ beam_size = opt.get('beam_size', 1)
157
+ temperature = opt.get('temperature', 1.0)
158
+ if beam_size > 1 and sample_method in ['greedy', 'beam_search']:
159
+ return self._sample_beam(fc_feats, att_feats, opt)
160
+
161
+ batch_size = fc_feats.size(0)
162
+ state = self.init_hidden(batch_size)
163
+ seq = fc_feats.new_zeros(batch_size, self.seq_length, dtype=torch.long)
164
+ seqLogprobs = fc_feats.new_zeros(batch_size, self.seq_length, self.vocab_size + 1)
165
+ for t in range(self.seq_length + 2):
166
+ if t == 0:
167
+ xt = self.img_embed(fc_feats)
168
+ else:
169
+ if t == 1: # input <bos>
170
+ it = fc_feats.data.new(batch_size).long().zero_()
171
+ xt = self.embed(it)
172
+
173
+ output, state = self.core(xt, state)
174
+ logprobs = F.log_softmax(self.logit(output), dim=1)
175
+
176
+ # sample the next_word
177
+ if t == self.seq_length + 1: # skip if we achieve maximum length
178
+ break
179
+ if sample_method == 'greedy':
180
+ sampleLogprobs, it = torch.max(logprobs.data, 1)
181
+ it = it.view(-1).long()
182
+ else:
183
+ if temperature == 1.0:
184
+ prob_prev = torch.exp(logprobs.data).cpu() # fetch prev distribution: shape Nx(M+1)
185
+ else:
186
+ # scale logprobs by temperature
187
+ prob_prev = torch.exp(torch.div(logprobs.data, temperature)).cpu()
188
+ it = torch.multinomial(prob_prev, 1).to(logprobs.device)
189
+ sampleLogprobs = logprobs.gather(1, it) # gather the logprobs at sampled positions
190
+ it = it.view(-1).long() # and flatten indices for downstream processing
191
+
192
+ if t >= 1:
193
+ # stop when all finished
194
+ if t == 1:
195
+ unfinished = it > 0
196
+ else:
197
+ unfinished = unfinished & (it > 0)
198
+ it = it * unfinished.type_as(it)
199
+ seq[:,t-1] = it #seq[t] the input of t+2 time step
200
+ seqLogprobs[:,t-1] = sampleLogprobs.view(-1)
201
+ if unfinished.sum() == 0:
202
+ break
203
+
204
+ return seq, seqLogprobs
captioning/models/M2Transformer.py ADDED
@@ -0,0 +1,98 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Instruction to use meshed_memory_transformer (https://arxiv.org/abs/1912.08226)
3
+
4
+ pip install git+https://github.com/ruotianluo/meshed-memory-transformer.git
5
+
6
+ Note:
7
+ Currently m2transformer is not performing as well as original transformer. Not sure why? Still investigating.
8
+ """
9
+
10
+ from __future__ import absolute_import
11
+ from __future__ import division
12
+ from __future__ import print_function
13
+
14
+ import torch
15
+ import torch.nn as nn
16
+ import torch.nn.functional as F
17
+
18
+ import copy
19
+ import math
20
+ import numpy as np
21
+
22
+ from .CaptionModel import CaptionModel
23
+ from .AttModel import sort_pack_padded_sequence, pad_unsort_packed_sequence, pack_wrapper, AttModel
24
+
25
+ try:
26
+ from m2transformer.models.transformer import Transformer, MemoryAugmentedEncoder, MeshedDecoder, ScaledDotProductAttentionMemory
27
+ except:
28
+ print('meshed-memory-transformer not installed; please run `pip install git+https://github.com/ruotianluo/meshed-memory-transformer.git`')
29
+ from .TransformerModel import subsequent_mask, TransformerModel
30
+
31
+
32
+ class M2TransformerModel(TransformerModel):
33
+
34
+ def make_model(self, src_vocab, tgt_vocab, N_enc=6, N_dec=6,
35
+ d_model=512, d_ff=2048, h=8, dropout=0.1):
36
+ "Helper: Construct a model from hyperparameters."
37
+ encoder = MemoryAugmentedEncoder(N_enc, 0, attention_module=ScaledDotProductAttentionMemory,
38
+ attention_module_kwargs={'m': 40})
39
+ # Another implementation is to use MultiLevelEncoder + att_embed
40
+ decoder = MeshedDecoder(tgt_vocab, 54, N_dec, -1) # -1 is padding;
41
+ model = Transformer(0, encoder, decoder) # 0 is bos
42
+ return model
43
+
44
+ def __init__(self, opt):
45
+ super(M2TransformerModel, self).__init__(opt)
46
+ delattr(self, 'att_embed')
47
+ self.att_embed = lambda x: x # The visual embed is in the MAEncoder
48
+ # Notes: The dropout in MAEncoder is different from my att_embed, mine is 0.5?
49
+ # Also the attention mask seems wrong in MAEncoder too...intersting
50
+
51
+ def logit(self, x): # unsafe way
52
+ return x # M2transformer always output logsoftmax
53
+
54
+ def _prepare_feature(self, fc_feats, att_feats, att_masks):
55
+
56
+ att_feats, seq, att_masks, seq_mask = self._prepare_feature_forward(att_feats, att_masks)
57
+ memory, att_masks = self.model.encoder(att_feats)
58
+
59
+ return fc_feats[...,:0], att_feats[...,:0], memory, att_masks
60
+
61
+ def _forward(self, fc_feats, att_feats, seq, att_masks=None):
62
+ if seq.ndim == 3: # B * seq_per_img * seq_len
63
+ seq = seq.reshape(-1, seq.shape[2])
64
+ att_feats, seq, att_masks, seq_mask = self._prepare_feature_forward(att_feats, att_masks, seq)
65
+
66
+ seq = seq.clone()
67
+ seq[~seq_mask.any(-2)] = -1 # Make padding to be -1 (my dataloader uses 0 as padding)
68
+ outputs = self.model(att_feats, seq)
69
+
70
+ return outputs
71
+
72
+ def core(self, it, fc_feats_ph, att_feats_ph, memory, state, mask):
73
+ """
74
+ state = [ys.unsqueeze(0)]
75
+ """
76
+ if len(state) == 0:
77
+ ys = it.unsqueeze(1)
78
+ else:
79
+ ys = torch.cat([state[0][0], it.unsqueeze(1)], dim=1)
80
+ out = self.model.decoder(ys, memory, mask)
81
+ return out[:, -1], [ys.unsqueeze(0)]
82
+
83
+ def _sample_beam(self, fc_feats, att_feats, att_masks=None, opt={}):
84
+ beam_size = opt.get('beam_size', 10)
85
+ group_size = opt.get('group_size', 1)
86
+ sample_n = opt.get('sample_n', 10)
87
+ assert sample_n == 1 or sample_n == beam_size // group_size, 'when beam search, sample_n == 1 or beam search'
88
+
89
+ att_feats, _, __, ___ = self._prepare_feature_forward(att_feats, att_masks)
90
+ seq, logprobs, seqLogprobs = self.model.beam_search(att_feats, self.seq_length, 0,
91
+ beam_size, return_probs=True, out_size=beam_size)
92
+ seq = seq.reshape(-1, *seq.shape[2:])
93
+ seqLogprobs = seqLogprobs.reshape(-1, *seqLogprobs.shape[2:])
94
+
95
+ # if not (seqLogprobs.gather(-1, seq.unsqueeze(-1)).squeeze(-1) == logprobs.reshape(-1, logprobs.shape[-1])).all():
96
+ # import pudb;pu.db
97
+ # seqLogprobs = logprobs.reshape(-1, logprobs.shape[-1]).unsqueeze(-1).expand(-1,-1,seqLogprobs.shape[-1])
98
+ return seq, seqLogprobs
captioning/models/ShowTellModel.py ADDED
@@ -0,0 +1,174 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import absolute_import
2
+ from __future__ import division
3
+ from __future__ import print_function
4
+
5
+ import torch
6
+ import torch.nn as nn
7
+ import torch.nn.functional as F
8
+ from torch.autograd import *
9
+ from . import utils
10
+
11
+ from .CaptionModel import CaptionModel
12
+
13
+ class ShowTellModel(CaptionModel):
14
+ def __init__(self, opt):
15
+ super(ShowTellModel, self).__init__()
16
+ self.vocab_size = opt.vocab_size
17
+ self.input_encoding_size = opt.input_encoding_size
18
+ self.rnn_type = opt.rnn_type
19
+ self.rnn_size = opt.rnn_size
20
+ self.num_layers = opt.num_layers
21
+ self.drop_prob_lm = opt.drop_prob_lm
22
+ self.seq_length = opt.seq_length
23
+ self.fc_feat_size = opt.fc_feat_size
24
+
25
+ self.ss_prob = 0.0 # Schedule sampling probability
26
+
27
+ self.img_embed = nn.Linear(self.fc_feat_size, self.input_encoding_size)
28
+ self.core = getattr(nn, self.rnn_type.upper())(self.input_encoding_size, self.rnn_size, self.num_layers, bias=False, dropout=self.drop_prob_lm)
29
+ self.embed = nn.Embedding(self.vocab_size + 1, self.input_encoding_size)
30
+ self.logit = nn.Linear(self.rnn_size, self.vocab_size + 1)
31
+ self.dropout = nn.Dropout(self.drop_prob_lm)
32
+
33
+ self.init_weights()
34
+
35
+ def init_weights(self):
36
+ initrange = 0.1
37
+ self.embed.weight.data.uniform_(-initrange, initrange)
38
+ self.logit.bias.data.fill_(0)
39
+ self.logit.weight.data.uniform_(-initrange, initrange)
40
+
41
+ def init_hidden(self, bsz):
42
+ weight = self.logit.weight
43
+ if self.rnn_type == 'lstm':
44
+ return (weight.new_zeros(self.num_layers, bsz, self.rnn_size),
45
+ weight.new_zeros(self.num_layers, bsz, self.rnn_size))
46
+ else:
47
+ return weight.new_zeros(self.num_layers, bsz, self.rnn_size)
48
+
49
+ def _forward(self, fc_feats, att_feats, seq, att_masks=None):
50
+ batch_size = fc_feats.size(0)
51
+ seq_per_img = seq.shape[0] // batch_size
52
+ state = self.init_hidden(batch_size*seq_per_img)
53
+ outputs = []
54
+
55
+ if seq_per_img > 1:
56
+ fc_feats = utils.repeat_tensors(seq_per_img, fc_feats)
57
+
58
+ for i in range(seq.size(1) + 1):
59
+ if i == 0:
60
+ xt = self.img_embed(fc_feats)
61
+ else:
62
+ if self.training and i >= 2 and self.ss_prob > 0.0: # otherwiste no need to sample
63
+ sample_prob = fc_feats.data.new(batch_size*seq_per_img).uniform_(0, 1)
64
+ sample_mask = sample_prob < self.ss_prob
65
+ if sample_mask.sum() == 0:
66
+ it = seq[:, i-1].clone()
67
+ else:
68
+ sample_ind = sample_mask.nonzero().view(-1)
69
+ it = seq[:, i-1].data.clone()
70
+ #prob_prev = torch.exp(outputs[-1].data.index_select(0, sample_ind)) # fetch prev distribution: shape Nx(M+1)
71
+ #it.index_copy_(0, sample_ind, torch.multinomial(prob_prev, 1).view(-1))
72
+ prob_prev = torch.exp(outputs[-1].data) # fetch prev distribution: shape Nx(M+1)
73
+ it.index_copy_(0, sample_ind, torch.multinomial(prob_prev, 1).view(-1).index_select(0, sample_ind))
74
+ else:
75
+ it = seq[:, i-1].clone()
76
+ # break if all the sequences end
77
+ if i >= 2 and seq[:, i-1].data.sum() == 0:
78
+ break
79
+ xt = self.embed(it)
80
+
81
+ output, state = self.core(xt.unsqueeze(0), state)
82
+ output = F.log_softmax(self.logit(self.dropout(output.squeeze(0))), dim=1)
83
+ outputs.append(output)
84
+
85
+ return torch.cat([_.unsqueeze(1) for _ in outputs[1:]], 1).contiguous()
86
+
87
+ def get_logprobs_state(self, it, state):
88
+ # 'it' contains a word index
89
+ xt = self.embed(it)
90
+
91
+ output, state = self.core(xt.unsqueeze(0), state)
92
+ logprobs = F.log_softmax(self.logit(self.dropout(output.squeeze(0))), dim=1)
93
+
94
+ return logprobs, state
95
+
96
+ def _sample_beam(self, fc_feats, att_feats, att_masks=None, opt={}):
97
+ beam_size = opt.get('beam_size', 10)
98
+ batch_size = fc_feats.size(0)
99
+
100
+ assert beam_size <= self.vocab_size + 1, 'lets assume this for now, otherwise this corner case causes a few headaches down the road. can be dealt with in future if needed'
101
+ seq = torch.LongTensor(self.seq_length, batch_size).zero_()
102
+ seqLogprobs = torch.FloatTensor(self.seq_length, batch_size)
103
+ # lets process every image independently for now, for simplicity
104
+
105
+ self.done_beams = [[] for _ in range(batch_size)]
106
+ for k in range(batch_size):
107
+ state = self.init_hidden(beam_size)
108
+ for t in range(2):
109
+ if t == 0:
110
+ xt = self.img_embed(fc_feats[k:k+1]).expand(beam_size, self.input_encoding_size)
111
+ elif t == 1: # input <bos>
112
+ it = fc_feats.data.new(beam_size).long().zero_()
113
+ xt = self.embed(it)
114
+
115
+ output, state = self.core(xt.unsqueeze(0), state)
116
+ logprobs = F.log_softmax(self.logit(self.dropout(output.squeeze(0))), dim=1)
117
+
118
+ self.done_beams[k] = self.beam_search(state, logprobs, opt=opt)
119
+ seq[:, k] = self.done_beams[k][0]['seq'] # the first beam has highest cumulative score
120
+ seqLogprobs[:, k] = self.done_beams[k][0]['logps']
121
+ # return the samples and their log likelihoods
122
+ return seq.transpose(0, 1), seqLogprobs.transpose(0, 1)
123
+
124
+ def _sample(self, fc_feats, att_feats, att_masks=None, opt={}):
125
+ sample_method = opt.get('sample_method', 'greedy')
126
+ beam_size = opt.get('beam_size', 1)
127
+ temperature = opt.get('temperature', 1.0)
128
+ if beam_size > 1 and sample_method in ['greedy', 'beam_search']:
129
+ return self.sample_beam(fc_feats, att_feats, opt)
130
+
131
+ batch_size = fc_feats.size(0)
132
+ state = self.init_hidden(batch_size)
133
+ seq = fc_feats.new_zeros(batch_size, self.seq_length, dtype=torch.long)
134
+ seqLogprobs = fc_feats.new_zeros(batch_size, self.seq_length)
135
+ for t in range(self.seq_length + 2):
136
+ if t == 0:
137
+ xt = self.img_embed(fc_feats)
138
+ else:
139
+ if t == 1: # input <bos>
140
+ it = fc_feats.data.new(batch_size).long().zero_()
141
+ xt = self.embed(it)
142
+
143
+ output, state = self.core(xt.unsqueeze(0), state)
144
+ logprobs = F.log_softmax(self.logit(self.dropout(output.squeeze(0))), dim=1)
145
+
146
+ # sample the next word
147
+ if t == self.seq_length + 1: # skip if we achieve maximum length
148
+ break
149
+ if sample_method == 'greedy':
150
+ sampleLogprobs, it = torch.max(logprobs.data, 1)
151
+ it = it.view(-1).long()
152
+ else:
153
+ if temperature == 1.0:
154
+ prob_prev = torch.exp(logprobs.data).cpu() # fetch prev distribution: shape Nx(M+1)
155
+ else:
156
+ # scale logprobs by temperature
157
+ prob_prev = torch.exp(torch.div(logprobs.data, temperature)).cpu()
158
+ it = torch.multinomial(prob_prev, 1).to(logprobs.device)
159
+ sampleLogprobs = logprobs.gather(1, it) # gather the logprobs at sampled positions
160
+ it = it.view(-1).long() # and flatten indices for downstream processing
161
+
162
+ if t >= 1:
163
+ # stop when all finished
164
+ if t == 1:
165
+ unfinished = it > 0
166
+ else:
167
+ unfinished = unfinished & (it > 0)
168
+ it = it * unfinished.type_as(it)
169
+ seq[:,t-1] = it #seq[t] the input of t+2 time step
170
+ seqLogprobs[:,t-1] = sampleLogprobs.view(-1)
171
+ if unfinished.sum() == 0:
172
+ break
173
+
174
+ return seq, seqLogprobs
captioning/models/TransformerModel.py ADDED
@@ -0,0 +1,363 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # This file contains Transformer network
2
+ # Most of the code is copied from http://nlp.seas.harvard.edu/2018/04/03/attention.html
3
+
4
+ # The cfg name correspondance:
5
+ # N=num_layers
6
+ # d_model=input_encoding_size
7
+ # d_ff=rnn_size
8
+ # h is always 8
9
+
10
+ from __future__ import absolute_import
11
+ from __future__ import division
12
+ from __future__ import print_function
13
+
14
+ import torch
15
+ import torch.nn as nn
16
+ import torch.nn.functional as F
17
+ from . import utils
18
+
19
+ import copy
20
+ import math
21
+ import numpy as np
22
+
23
+ from .CaptionModel import CaptionModel
24
+ from .AttModel import sort_pack_padded_sequence, pad_unsort_packed_sequence, pack_wrapper, AttModel
25
+
26
+ class EncoderDecoder(nn.Module):
27
+ """
28
+ A standard Encoder-Decoder architecture. Base for this and many
29
+ other models.
30
+ """
31
+ def __init__(self, encoder, decoder, src_embed, tgt_embed, generator):
32
+ super(EncoderDecoder, self).__init__()
33
+ self.encoder = encoder
34
+ self.decoder = decoder
35
+ self.src_embed = src_embed
36
+ self.tgt_embed = tgt_embed
37
+ self.generator = generator
38
+
39
+ def forward(self, src, tgt, src_mask, tgt_mask):
40
+ "Take in and process masked src and target sequences."
41
+ return self.decode(self.encode(src, src_mask), src_mask,
42
+ tgt, tgt_mask)
43
+
44
+ def encode(self, src, src_mask):
45
+ return self.encoder(self.src_embed(src), src_mask)
46
+
47
+ def decode(self, memory, src_mask, tgt, tgt_mask):
48
+ return self.decoder(self.tgt_embed(tgt), memory, src_mask, tgt_mask)
49
+
50
+ class Generator(nn.Module):
51
+ "Define standard linear + softmax generation step."
52
+ def __init__(self, d_model, vocab):
53
+ super(Generator, self).__init__()
54
+ self.proj = nn.Linear(d_model, vocab)
55
+
56
+ def forward(self, x):
57
+ return F.log_softmax(self.proj(x), dim=-1)
58
+
59
+ def clones(module, N):
60
+ "Produce N identical layers."
61
+ return nn.ModuleList([copy.deepcopy(module) for _ in range(N)])
62
+
63
+ class Encoder(nn.Module):
64
+ "Core encoder is a stack of N layers"
65
+ def __init__(self, layer, N):
66
+ super(Encoder, self).__init__()
67
+ self.layers = clones(layer, N)
68
+ self.norm = LayerNorm(layer.size)
69
+
70
+ def forward(self, x, mask):
71
+ "Pass the input (and mask) through each layer in turn."
72
+ for layer in self.layers:
73
+ x = layer(x, mask)
74
+ return self.norm(x)
75
+
76
+ class LayerNorm(nn.Module):
77
+ "Construct a layernorm module (See citation for details)."
78
+ def __init__(self, features, eps=1e-6):
79
+ super(LayerNorm, self).__init__()
80
+ self.a_2 = nn.Parameter(torch.ones(features))
81
+ self.b_2 = nn.Parameter(torch.zeros(features))
82
+ self.eps = eps
83
+
84
+ def forward(self, x):
85
+ mean = x.mean(-1, keepdim=True)
86
+ std = x.std(-1, keepdim=True)
87
+ return self.a_2 * (x - mean) / (std + self.eps) + self.b_2
88
+
89
+ class SublayerConnection(nn.Module):
90
+ """
91
+ A residual connection followed by a layer norm.
92
+ Note for code simplicity the norm is first as opposed to last.
93
+ """
94
+ def __init__(self, size, dropout):
95
+ super(SublayerConnection, self).__init__()
96
+ self.norm = LayerNorm(size)
97
+ self.dropout = nn.Dropout(dropout)
98
+
99
+ def forward(self, x, sublayer):
100
+ "Apply residual connection to any sublayer with the same size."
101
+ return x + self.dropout(sublayer(self.norm(x)))
102
+
103
+ class EncoderLayer(nn.Module):
104
+ "Encoder is made up of self-attn and feed forward (defined below)"
105
+ def __init__(self, size, self_attn, feed_forward, dropout):
106
+ super(EncoderLayer, self).__init__()
107
+ self.self_attn = self_attn
108
+ self.feed_forward = feed_forward
109
+ self.sublayer = clones(SublayerConnection(size, dropout), 2)
110
+ self.size = size
111
+
112
+ def forward(self, x, mask):
113
+ "Follow Figure 1 (left) for connections."
114
+ x = self.sublayer[0](x, lambda x: self.self_attn(x, x, x, mask))
115
+ return self.sublayer[1](x, self.feed_forward)
116
+
117
+ class Decoder(nn.Module):
118
+ "Generic N layer decoder with masking."
119
+ def __init__(self, layer, N):
120
+ super(Decoder, self).__init__()
121
+ self.layers = clones(layer, N)
122
+ self.norm = LayerNorm(layer.size)
123
+
124
+ def forward(self, x, memory, src_mask, tgt_mask):
125
+ for layer in self.layers:
126
+ x = layer(x, memory, src_mask, tgt_mask)
127
+ return self.norm(x)
128
+
129
+ class DecoderLayer(nn.Module):
130
+ "Decoder is made of self-attn, src-attn, and feed forward (defined below)"
131
+ def __init__(self, size, self_attn, src_attn, feed_forward, dropout):
132
+ super(DecoderLayer, self).__init__()
133
+ self.size = size
134
+ self.self_attn = self_attn
135
+ self.src_attn = src_attn
136
+ self.feed_forward = feed_forward
137
+ self.sublayer = clones(SublayerConnection(size, dropout), 3)
138
+
139
+ def forward(self, x, memory, src_mask, tgt_mask):
140
+ "Follow Figure 1 (right) for connections."
141
+ m = memory
142
+ x = self.sublayer[0](x, lambda x: self.self_attn(x, x, x, tgt_mask))
143
+ x = self.sublayer[1](x, lambda x: self.src_attn(x, m, m, src_mask))
144
+ return self.sublayer[2](x, self.feed_forward)
145
+
146
+ def subsequent_mask(size):
147
+ "Mask out subsequent positions."
148
+ attn_shape = (1, size, size)
149
+ subsequent_mask = np.triu(np.ones(attn_shape), k=1).astype('uint8')
150
+ return torch.from_numpy(subsequent_mask) == 0
151
+
152
+ def attention(query, key, value, mask=None, dropout=None):
153
+ "Compute 'Scaled Dot Product Attention'"
154
+ d_k = query.size(-1)
155
+ scores = torch.matmul(query, key.transpose(-2, -1)) \
156
+ / math.sqrt(d_k)
157
+ if mask is not None:
158
+ scores = scores.masked_fill(mask == 0, float('-inf'))
159
+ p_attn = F.softmax(scores, dim = -1)
160
+ if dropout is not None:
161
+ p_attn = dropout(p_attn)
162
+ return torch.matmul(p_attn, value), p_attn
163
+
164
+ class MultiHeadedAttention(nn.Module):
165
+ def __init__(self, h, d_model, dropout=0.1):
166
+ "Take in model size and number of heads."
167
+ super(MultiHeadedAttention, self).__init__()
168
+ assert d_model % h == 0
169
+ # We assume d_v always equals d_k
170
+ self.d_k = d_model // h
171
+ self.h = h
172
+ self.linears = clones(nn.Linear(d_model, d_model), 4)
173
+ self.attn = None
174
+ self.dropout = nn.Dropout(p=dropout)
175
+
176
+ def forward(self, query, key, value, mask=None):
177
+ "Implements Figure 2"
178
+ if mask is not None:
179
+ # Same mask applied to all h heads.
180
+ mask = mask.unsqueeze(1)
181
+ nbatches = query.size(0)
182
+
183
+ # 1) Do all the linear projections in batch from d_model => h x d_k
184
+ query, key, value = \
185
+ [l(x).view(nbatches, -1, self.h, self.d_k).transpose(1, 2)
186
+ for l, x in zip(self.linears, (query, key, value))]
187
+
188
+ # 2) Apply attention on all the projected vectors in batch.
189
+ x, self.attn = attention(query, key, value, mask=mask,
190
+ dropout=self.dropout)
191
+
192
+ # 3) "Concat" using a view and apply a final linear.
193
+ x = x.transpose(1, 2).contiguous() \
194
+ .view(nbatches, -1, self.h * self.d_k)
195
+ return self.linears[-1](x)
196
+
197
+ class PositionwiseFeedForward(nn.Module):
198
+ "Implements FFN equation."
199
+ def __init__(self, d_model, d_ff, dropout=0.1):
200
+ super(PositionwiseFeedForward, self).__init__()
201
+ self.w_1 = nn.Linear(d_model, d_ff)
202
+ self.w_2 = nn.Linear(d_ff, d_model)
203
+ self.dropout = nn.Dropout(dropout)
204
+
205
+ def forward(self, x):
206
+ return self.w_2(self.dropout(F.relu(self.w_1(x))))
207
+
208
+ class Embeddings(nn.Module):
209
+ def __init__(self, d_model, vocab):
210
+ super(Embeddings, self).__init__()
211
+ self.lut = nn.Embedding(vocab, d_model)
212
+ self.d_model = d_model
213
+
214
+ def forward(self, x):
215
+ return self.lut(x) * math.sqrt(self.d_model)
216
+
217
+ class PositionalEncoding(nn.Module):
218
+ "Implement the PE function."
219
+ def __init__(self, d_model, dropout, max_len=5000):
220
+ super(PositionalEncoding, self).__init__()
221
+ self.dropout = nn.Dropout(p=dropout)
222
+
223
+ # Compute the positional encodings once in log space.
224
+ pe = torch.zeros(max_len, d_model)
225
+ position = torch.arange(0, max_len).unsqueeze(1).float()
226
+ div_term = torch.exp(torch.arange(0, d_model, 2).float() *
227
+ -(math.log(10000.0) / d_model))
228
+ pe[:, 0::2] = torch.sin(position * div_term)
229
+ pe[:, 1::2] = torch.cos(position * div_term)
230
+ pe = pe.unsqueeze(0)
231
+ self.register_buffer('pe', pe)
232
+
233
+ def forward(self, x):
234
+ x = x + self.pe[:, :x.size(1)]
235
+ return self.dropout(x)
236
+
237
+ class TransformerModel(AttModel):
238
+
239
+ def make_model(self, src_vocab, tgt_vocab, N_enc=6, N_dec=6,
240
+ d_model=512, d_ff=2048, h=8, dropout=0.1):
241
+ "Helper: Construct a model from hyperparameters."
242
+ c = copy.deepcopy
243
+ attn = MultiHeadedAttention(h, d_model, dropout)
244
+ ff = PositionwiseFeedForward(d_model, d_ff, dropout)
245
+ position = PositionalEncoding(d_model, dropout)
246
+ model = EncoderDecoder(
247
+ Encoder(EncoderLayer(d_model, c(attn), c(ff), dropout), N_enc),
248
+ Decoder(DecoderLayer(d_model, c(attn), c(attn),
249
+ c(ff), dropout), N_dec),
250
+ lambda x:x, # nn.Sequential(Embeddings(d_model, src_vocab), c(position)),
251
+ nn.Sequential(Embeddings(d_model, tgt_vocab), c(position)),
252
+ Generator(d_model, tgt_vocab))
253
+
254
+ # This was important from their code.
255
+ # Initialize parameters with Glorot / fan_avg.
256
+ for p in model.parameters():
257
+ if p.dim() > 1:
258
+ nn.init.xavier_uniform_(p)
259
+ return model
260
+
261
+ def __init__(self, opt):
262
+ super(TransformerModel, self).__init__(opt)
263
+ self.opt = opt
264
+ # self.config = yaml.load(open(opt.config_file))
265
+
266
+ self.N_enc = getattr(opt, 'N_enc', opt.num_layers)
267
+ self.N_dec = getattr(opt, 'N_dec', opt.num_layers)
268
+ self.d_model = getattr(opt, 'd_model', opt.input_encoding_size)
269
+ self.d_ff = getattr(opt, 'd_ff', opt.rnn_size)
270
+ self.h = getattr(opt, 'num_att_heads', 8)
271
+ self.dropout = getattr(opt, 'dropout', 0.1)
272
+
273
+ delattr(self, 'att_embed')
274
+ self.att_embed = nn.Sequential(*(
275
+ ((nn.BatchNorm1d(self.att_feat_size),) if self.use_bn else ())+
276
+ (nn.Linear(self.att_feat_size, self.d_model),
277
+ nn.ReLU(),
278
+ nn.Dropout(self.drop_prob_lm))+
279
+ ((nn.BatchNorm1d(self.d_model),) if self.use_bn==2 else ())))
280
+
281
+ delattr(self, 'embed')
282
+ self.embed = lambda x : x
283
+ delattr(self, 'fc_embed')
284
+ self.fc_embed = lambda x : x
285
+ delattr(self, 'logit')
286
+ del self.ctx2att
287
+
288
+ tgt_vocab = self.vocab_size + 1
289
+
290
+
291
+ self.model = self.make_model(0, tgt_vocab,
292
+ N_enc=self.N_enc,
293
+ N_dec=self.N_dec,
294
+ d_model=self.d_model,
295
+ d_ff=self.d_ff,
296
+ h=self.h,
297
+ dropout=self.dropout)
298
+
299
+ def logit(self, x): # unsafe way
300
+ return self.model.generator.proj(x)
301
+
302
+ def init_hidden(self, bsz):
303
+ return []
304
+
305
+ def _prepare_feature(self, fc_feats, att_feats, att_masks):
306
+
307
+ att_feats, seq, att_masks, seq_mask = self._prepare_feature_forward(att_feats, att_masks)
308
+ memory = self.model.encode(att_feats, att_masks)
309
+
310
+ return fc_feats[...,:0], att_feats[...,:0], memory, att_masks
311
+
312
+ def _prepare_feature_forward(self, att_feats, att_masks=None, seq=None):
313
+ att_feats, att_masks = self.clip_att(att_feats, att_masks)
314
+
315
+ att_feats = pack_wrapper(self.att_embed, att_feats, att_masks)
316
+
317
+ if att_masks is None:
318
+ att_masks = att_feats.new_ones(att_feats.shape[:2], dtype=torch.long)
319
+ att_masks = att_masks.unsqueeze(-2)
320
+
321
+ if seq is not None:
322
+ # crop the last one
323
+ # seq = seq[:,:-1]
324
+ seq_mask = (seq.data != self.eos_idx) & (seq.data != self.pad_idx)
325
+ seq_mask[:,0] = 1 # bos
326
+
327
+ seq_mask = seq_mask.unsqueeze(-2)
328
+ seq_mask = seq_mask & subsequent_mask(seq.size(-1)).to(seq_mask)
329
+
330
+ seq_per_img = seq.shape[0] // att_feats.shape[0]
331
+ if seq_per_img > 1:
332
+ att_feats, att_masks = utils.repeat_tensors(seq_per_img,
333
+ [att_feats, att_masks]
334
+ )
335
+ else:
336
+ seq_mask = None
337
+
338
+ return att_feats, seq, att_masks, seq_mask
339
+
340
+ def _forward(self, fc_feats, att_feats, seq, att_masks=None):
341
+ if seq.ndim == 3: # B * seq_per_img * seq_len
342
+ seq = seq.reshape(-1, seq.shape[2])
343
+ att_feats, seq, att_masks, seq_mask = self._prepare_feature_forward(att_feats, att_masks, seq)
344
+
345
+ out = self.model(att_feats, seq, att_masks, seq_mask)
346
+
347
+ outputs = self.model.generator(out)
348
+ return outputs
349
+ # return torch.cat([_.unsqueeze(1) for _ in outputs], 1)
350
+
351
+ def core(self, it, fc_feats_ph, att_feats_ph, memory, state, mask):
352
+ """
353
+ state = [ys.unsqueeze(0)]
354
+ """
355
+ if len(state) == 0:
356
+ ys = it.unsqueeze(1)
357
+ else:
358
+ ys = torch.cat([state[0][0], it.unsqueeze(1)], dim=1)
359
+ out = self.model.decode(memory, mask,
360
+ ys,
361
+ subsequent_mask(ys.size(1))
362
+ .to(memory.device))
363
+ return out[:, -1], [ys.unsqueeze(0)]
captioning/models/__init__.py ADDED
@@ -0,0 +1,73 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import absolute_import
2
+ from __future__ import division
3
+ from __future__ import print_function
4
+
5
+ import os
6
+ import copy
7
+
8
+ import numpy as np
9
+ import torch
10
+
11
+ from .ShowTellModel import ShowTellModel
12
+ from .FCModel import FCModel
13
+ from .AttModel import *
14
+ from .TransformerModel import TransformerModel
15
+ from .cachedTransformer import TransformerModel as cachedTransformer
16
+ from .BertCapModel import BertCapModel
17
+ from .M2Transformer import M2TransformerModel
18
+ from .AoAModel import AoAModel
19
+
20
+ def setup(opt):
21
+ if opt.caption_model in ['fc', 'show_tell']:
22
+ print('Warning: %s model is mostly deprecated; many new features are not supported.' %opt.caption_model)
23
+ if opt.caption_model == 'fc':
24
+ print('Use newfc instead of fc')
25
+ if opt.caption_model == 'fc':
26
+ model = FCModel(opt)
27
+ elif opt.caption_model == 'language_model':
28
+ model = LMModel(opt)
29
+ elif opt.caption_model == 'newfc':
30
+ model = NewFCModel(opt)
31
+ elif opt.caption_model == 'show_tell':
32
+ model = ShowTellModel(opt)
33
+ # Att2in model in self-critical
34
+ elif opt.caption_model == 'att2in':
35
+ model = Att2inModel(opt)
36
+ # Att2in model with two-layer MLP img embedding and word embedding
37
+ elif opt.caption_model == 'att2in2':
38
+ model = Att2in2Model(opt)
39
+ elif opt.caption_model == 'att2all2':
40
+ print('Warning: this is not a correct implementation of the att2all model in the original paper.')
41
+ model = Att2all2Model(opt)
42
+ # Adaptive Attention model from Knowing when to look
43
+ elif opt.caption_model == 'adaatt':
44
+ model = AdaAttModel(opt)
45
+ # Adaptive Attention with maxout lstm
46
+ elif opt.caption_model == 'adaattmo':
47
+ model = AdaAttMOModel(opt)
48
+ # Top-down attention model
49
+ elif opt.caption_model in ['topdown', 'updown']:
50
+ model = UpDownModel(opt)
51
+ # StackAtt
52
+ elif opt.caption_model == 'stackatt':
53
+ model = StackAttModel(opt)
54
+ # DenseAtt
55
+ elif opt.caption_model == 'denseatt':
56
+ model = DenseAttModel(opt)
57
+ # Transformer
58
+ elif opt.caption_model == 'transformer':
59
+ if getattr(opt, 'cached_transformer', False):
60
+ model = cachedTransformer(opt)
61
+ else:
62
+ model = TransformerModel(opt)
63
+ # AoANet
64
+ elif opt.caption_model == 'aoa':
65
+ model = AoAModel(opt)
66
+ elif opt.caption_model == 'bert':
67
+ model = BertCapModel(opt)
68
+ elif opt.caption_model == 'm2transformer':
69
+ model = M2TransformerModel(opt)
70
+ else:
71
+ raise Exception("Caption model not supported: {}".format(opt.caption_model))
72
+
73
+ return model
captioning/models/cachedTransformer.py ADDED
@@ -0,0 +1,420 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # This file contains Transformer network
2
+ # Most of the code is copied from http://nlp.seas.harvard.edu/2018/04/03/attention.html
3
+
4
+ # The cfg name correspondance:
5
+ # N=num_layers
6
+ # d_model=input_encoding_size
7
+ # d_ff=rnn_size
8
+ # h is always 8
9
+
10
+ from __future__ import absolute_import
11
+ from __future__ import division
12
+ from __future__ import print_function
13
+
14
+ import torch
15
+ import torch.nn as nn
16
+ import torch.nn.functional as F
17
+ from . import utils
18
+
19
+ import copy
20
+ import math
21
+ import numpy as np
22
+
23
+ from .CaptionModel import CaptionModel
24
+ from .AttModel import sort_pack_padded_sequence, pad_unsort_packed_sequence, pack_wrapper, AttModel
25
+
26
+ class EncoderDecoder(nn.Module):
27
+ """
28
+ A standard Encoder-Decoder architecture. Base for this and many
29
+ other models.
30
+ """
31
+ def __init__(self, encoder, decoder, src_embed, tgt_embed, generator):
32
+ super(EncoderDecoder, self).__init__()
33
+ self.encoder = encoder
34
+ self.decoder = decoder
35
+ self.src_embed = src_embed
36
+ self.tgt_embed = tgt_embed
37
+ self.generator = generator
38
+
39
+ def forward(self, src, tgt, src_mask, tgt_mask):
40
+ "Take in and process masked src and target sequences."
41
+ return self.decode(self.encode(src, src_mask), src_mask,
42
+ tgt, tgt_mask)
43
+
44
+ def encode(self, src, src_mask):
45
+ return self.encoder(self.src_embed(src), src_mask)
46
+
47
+ def decode(self, memory, src_mask, tgt, tgt_mask, past=None):
48
+ return self.decoder(self.tgt_embed(tgt), memory, src_mask, tgt_mask, past=past)
49
+
50
+ class Generator(nn.Module):
51
+ "Define standard linear + softmax generation step."
52
+ def __init__(self, d_model, vocab):
53
+ super(Generator, self).__init__()
54
+ self.proj = nn.Linear(d_model, vocab)
55
+
56
+ def forward(self, x):
57
+ return F.log_softmax(self.proj(x), dim=-1)
58
+
59
+ def clones(module, N):
60
+ "Produce N identical layers."
61
+ return nn.ModuleList([copy.deepcopy(module) for _ in range(N)])
62
+
63
+ class Encoder(nn.Module):
64
+ "Core encoder is a stack of N layers"
65
+ def __init__(self, layer, N):
66
+ super(Encoder, self).__init__()
67
+ self.layers = clones(layer, N)
68
+ self.norm = LayerNorm(layer.size)
69
+
70
+ def forward(self, x, mask):
71
+ "Pass the input (and mask) through each layer in turn."
72
+ for layer in self.layers:
73
+ x = layer(x, mask)
74
+ return self.norm(x)
75
+
76
+ class LayerNorm(nn.Module):
77
+ "Construct a layernorm module (See citation for details)."
78
+ def __init__(self, features, eps=1e-6):
79
+ super(LayerNorm, self).__init__()
80
+ self.a_2 = nn.Parameter(torch.ones(features))
81
+ self.b_2 = nn.Parameter(torch.zeros(features))
82
+ self.eps = eps
83
+
84
+ def forward(self, x):
85
+ mean = x.mean(-1, keepdim=True)
86
+ std = x.std(-1, keepdim=True)
87
+ return self.a_2 * (x - mean) / (std + self.eps) + self.b_2
88
+
89
+ class SublayerConnection(nn.Module):
90
+ """
91
+ A residual connection followed by a layer norm.
92
+ Note for code simplicity the norm is first as opposed to last.
93
+ """
94
+ def __init__(self, size, dropout):
95
+ super(SublayerConnection, self).__init__()
96
+ self.norm = LayerNorm(size)
97
+ self.dropout = nn.Dropout(dropout)
98
+
99
+ def forward(self, x, sublayer):
100
+ "Apply residual connection to any sublayer with the same size."
101
+ _x = sublayer(self.norm(x))
102
+ if type(_x) is tuple: # for multi-head attention that returns past
103
+ return x + self.dropout(_x[0]), _x[1]
104
+ return x + self.dropout(_x)
105
+
106
+ class EncoderLayer(nn.Module):
107
+ "Encoder is made up of self-attn and feed forward (defined below)"
108
+ def __init__(self, size, self_attn, feed_forward, dropout):
109
+ super(EncoderLayer, self).__init__()
110
+ self.self_attn = self_attn
111
+ self.feed_forward = feed_forward
112
+ self.sublayer = clones(SublayerConnection(size, dropout), 2)
113
+ self.size = size
114
+
115
+ def forward(self, x, mask):
116
+ "Follow Figure 1 (left) for connections."
117
+ x = self.sublayer[0](x, lambda x: self.self_attn(x, x, x, mask))
118
+ return self.sublayer[1](x, self.feed_forward)
119
+
120
+ class Decoder(nn.Module):
121
+ "Generic N layer decoder with masking."
122
+ def __init__(self, layer, N):
123
+ super(Decoder, self).__init__()
124
+ self.layers = clones(layer, N)
125
+ self.norm = LayerNorm(layer.size)
126
+
127
+ def forward(self, x, memory, src_mask, tgt_mask, past=None):
128
+ if past is not None:
129
+ present = [[], []]
130
+ x = x[:, -1:]
131
+ tgt_mask = tgt_mask[:, -1:] if tgt_mask is not None else None
132
+ past = list(zip(past[0].split(2, dim=0), past[1].split(2, dim=0)))
133
+ else:
134
+ past = [None] * len(self.layers)
135
+ for i, (layer, layer_past) in enumerate(zip(self.layers, past)):
136
+ x = layer(x, memory, src_mask, tgt_mask,
137
+ layer_past)
138
+ if layer_past is not None:
139
+ present[0].append(x[1][0])
140
+ present[1].append(x[1][1])
141
+ x = x[0]
142
+ if past[0] is None:
143
+ return self.norm(x)
144
+ else:
145
+ return self.norm(x), [torch.cat(present[0], 0), torch.cat(present[1], 0)]
146
+
147
+
148
+ class DecoderLayer(nn.Module):
149
+ "Decoder is made of self-attn, src-attn, and feed forward (defined below)"
150
+ def __init__(self, size, self_attn, src_attn, feed_forward, dropout):
151
+ super(DecoderLayer, self).__init__()
152
+ self.size = size
153
+ self.self_attn = self_attn
154
+ self.src_attn = src_attn
155
+ self.feed_forward = feed_forward
156
+ self.sublayer = clones(SublayerConnection(size, dropout), 3)
157
+
158
+ def forward(self, x, memory, src_mask, tgt_mask, layer_past=None):
159
+ "Follow Figure 1 (right) for connections."
160
+ m = memory
161
+ if layer_past is None:
162
+ x = self.sublayer[0](x, lambda x: self.self_attn(x, x, x, tgt_mask))
163
+ x = self.sublayer[1](x, lambda x: self.src_attn(x, m, m, src_mask))
164
+ return self.sublayer[2](x, self.feed_forward)
165
+ else:
166
+ present = [None, None]
167
+ x, present[0] = self.sublayer[0](x, lambda x: self.self_attn(x, x, x, tgt_mask, layer_past[0]))
168
+ x, present[1] = self.sublayer[1](x, lambda x: self.src_attn(x, m, m, src_mask, layer_past[1]))
169
+ return self.sublayer[2](x, self.feed_forward), present
170
+
171
+ def subsequent_mask(size):
172
+ "Mask out subsequent positions."
173
+ attn_shape = (1, size, size)
174
+ subsequent_mask = np.triu(np.ones(attn_shape), k=1).astype('uint8')
175
+ return torch.from_numpy(subsequent_mask) == 0
176
+
177
+ def attention(query, key, value, mask=None, dropout=None):
178
+ "Compute 'Scaled Dot Product Attention'"
179
+ d_k = query.size(-1)
180
+ scores = torch.matmul(query, key.transpose(-2, -1)) \
181
+ / math.sqrt(d_k)
182
+ if mask is not None:
183
+ scores = scores.masked_fill(mask == 0, float('-inf'))
184
+ p_attn = F.softmax(scores, dim = -1)
185
+ if dropout is not None:
186
+ p_attn = dropout(p_attn)
187
+ return torch.matmul(p_attn, value), p_attn
188
+
189
+ class MultiHeadedAttention(nn.Module):
190
+ def __init__(self, h, d_model, dropout=0.1):
191
+ "Take in model size and number of heads."
192
+ super(MultiHeadedAttention, self).__init__()
193
+ assert d_model % h == 0
194
+ # We assume d_v always equals d_k
195
+ self.d_k = d_model // h
196
+ self.h = h
197
+ self.linears = clones(nn.Linear(d_model, d_model), 4)
198
+ self.attn = None
199
+ self.dropout = nn.Dropout(p=dropout)
200
+
201
+ def forward(self, query, key, value, mask=None, layer_past=None):
202
+ "Implements Figure 2"
203
+ if mask is not None:
204
+ # Same mask applied to all h heads.
205
+ mask = mask.unsqueeze(1)
206
+ nbatches = query.size(0)
207
+
208
+ # The past works differently here. For self attn, the query and key be updated incrementailly
209
+ # For src_attn the past is fixed.
210
+
211
+ # For src_attn, when the layer past is ready
212
+ if layer_past is not None and layer_past.shape[2] == key.shape[1] > 1: # suppose memory size always greater than 1
213
+ query = self.linears[0](query)
214
+ key, value = layer_past[0], layer_past[1]
215
+ present = torch.stack([key, value])
216
+ else:
217
+ # 1) Do all the linear projections in batch from d_model => h x d_k
218
+ query, key, value = \
219
+ [l(x) for l, x in zip(self.linears, (query, key, value))]
220
+
221
+ # self attn + past OR the first time step of src attn
222
+ if layer_past is not None and not (layer_past.shape[2] == key.shape[1] > 1):
223
+ past_key, past_value = layer_past[0], layer_past[1]
224
+ key = torch.cat((past_key, key), dim=1)
225
+ value = torch.cat((past_value, value), dim=1)
226
+ present = torch.stack([key, value])
227
+
228
+ query, key, value = \
229
+ [x.view(nbatches, -1, self.h, self.d_k).transpose(1, 2)
230
+ for x in [query, key, value]]
231
+
232
+ # 2) Apply attention on all the projected vectors in batch.
233
+ x, self.attn = attention(query, key, value, mask=mask,
234
+ dropout=self.dropout)
235
+
236
+ # 3) "Concat" using a view and apply a final linear.
237
+ x = x.transpose(1, 2).contiguous() \
238
+ .view(nbatches, -1, self.h * self.d_k)
239
+ if layer_past is not None:
240
+ return self.linears[-1](x), present
241
+ else:
242
+ return self.linears[-1](x)
243
+
244
+ class PositionwiseFeedForward(nn.Module):
245
+ "Implements FFN equation."
246
+ def __init__(self, d_model, d_ff, dropout=0.1):
247
+ super(PositionwiseFeedForward, self).__init__()
248
+ self.w_1 = nn.Linear(d_model, d_ff)
249
+ self.w_2 = nn.Linear(d_ff, d_model)
250
+ self.dropout = nn.Dropout(dropout)
251
+
252
+ def forward(self, x):
253
+ return self.w_2(self.dropout(F.relu(self.w_1(x))))
254
+
255
+ class Embeddings(nn.Module):
256
+ def __init__(self, d_model, vocab):
257
+ super(Embeddings, self).__init__()
258
+ self.lut = nn.Embedding(vocab, d_model)
259
+ self.d_model = d_model
260
+
261
+ def forward(self, x):
262
+ return self.lut(x) * math.sqrt(self.d_model)
263
+
264
+ class PositionalEncoding(nn.Module):
265
+ "Implement the PE function."
266
+ def __init__(self, d_model, dropout, max_len=5000):
267
+ super(PositionalEncoding, self).__init__()
268
+ self.dropout = nn.Dropout(p=dropout)
269
+
270
+ # Compute the positional encodings once in log space.
271
+ pe = torch.zeros(max_len, d_model)
272
+ position = torch.arange(0, max_len).unsqueeze(1).float()
273
+ div_term = torch.exp(torch.arange(0, d_model, 2).float() *
274
+ -(math.log(10000.0) / d_model))
275
+ pe[:, 0::2] = torch.sin(position * div_term)
276
+ pe[:, 1::2] = torch.cos(position * div_term)
277
+ pe = pe.unsqueeze(0)
278
+ self.register_buffer('pe', pe)
279
+
280
+ def forward(self, x):
281
+ x = x + self.pe[:, :x.size(1)]
282
+ return self.dropout(x)
283
+
284
+ class TransformerModel(AttModel):
285
+
286
+ def make_model(self, src_vocab, tgt_vocab, N_enc=6, N_dec=6,
287
+ d_model=512, d_ff=2048, h=8, dropout=0.1):
288
+ "Helper: Construct a model from hyperparameters."
289
+ c = copy.deepcopy
290
+ attn = MultiHeadedAttention(h, d_model, dropout)
291
+ ff = PositionwiseFeedForward(d_model, d_ff, dropout)
292
+ position = PositionalEncoding(d_model, dropout)
293
+ model = EncoderDecoder(
294
+ Encoder(EncoderLayer(d_model, c(attn), c(ff), dropout), N_enc),
295
+ Decoder(DecoderLayer(d_model, c(attn), c(attn),
296
+ c(ff), dropout), N_dec),
297
+ lambda x:x, # nn.Sequential(Embeddings(d_model, src_vocab), c(position)),
298
+ nn.Sequential(Embeddings(d_model, tgt_vocab), c(position)),
299
+ Generator(d_model, tgt_vocab))
300
+
301
+ # This was important from their code.
302
+ # Initialize parameters with Glorot / fan_avg.
303
+ for p in model.parameters():
304
+ if p.dim() > 1:
305
+ nn.init.xavier_uniform_(p)
306
+ return model
307
+
308
+ def __init__(self, opt):
309
+ super(TransformerModel, self).__init__(opt)
310
+ self.opt = opt
311
+ # self.config = yaml.load(open(opt.config_file))
312
+
313
+ self.N_enc = getattr(opt, 'N_enc', opt.num_layers)
314
+ self.N_dec = getattr(opt, 'N_dec', opt.num_layers)
315
+ self.d_model = getattr(opt, 'd_model', opt.input_encoding_size)
316
+ self.d_ff = getattr(opt, 'd_ff', opt.rnn_size)
317
+ self.h = getattr(opt, 'num_att_heads', 8)
318
+ self.dropout = getattr(opt, 'dropout', 0.1)
319
+
320
+ delattr(self, 'att_embed')
321
+ self.att_embed = nn.Sequential(*(
322
+ ((nn.BatchNorm1d(self.att_feat_size),) if self.use_bn else ())+
323
+ (nn.Linear(self.att_feat_size, self.d_model),
324
+ nn.ReLU(),
325
+ nn.Dropout(self.drop_prob_lm))+
326
+ ((nn.BatchNorm1d(self.d_model),) if self.use_bn==2 else ())))
327
+
328
+ delattr(self, 'embed')
329
+ self.embed = lambda x : x
330
+ delattr(self, 'fc_embed')
331
+ self.fc_embed = lambda x : x
332
+ delattr(self, 'logit')
333
+ del self.ctx2att
334
+
335
+ tgt_vocab = self.vocab_size + 1
336
+
337
+
338
+ self.model = self.make_model(0, tgt_vocab,
339
+ N_enc=self.N_enc,
340
+ N_dec=self.N_dec,
341
+ d_model=self.d_model,
342
+ d_ff=self.d_ff,
343
+ h=self.h,
344
+ dropout=self.dropout)
345
+
346
+ def logit(self, x): # unsafe way
347
+ return self.model.generator.proj(x)
348
+
349
+ def init_hidden(self, bsz):
350
+ return []
351
+
352
+ def _prepare_feature(self, fc_feats, att_feats, att_masks):
353
+
354
+ att_feats, seq, att_masks, seq_mask = self._prepare_feature_forward(att_feats, att_masks)
355
+ memory = self.model.encode(att_feats, att_masks)
356
+
357
+ return fc_feats[...,:0], att_feats[...,:0], memory, att_masks
358
+
359
+ def _prepare_feature_forward(self, att_feats, att_masks=None, seq=None):
360
+ att_feats, att_masks = self.clip_att(att_feats, att_masks)
361
+
362
+ att_feats = pack_wrapper(self.att_embed, att_feats, att_masks)
363
+
364
+ if att_masks is None:
365
+ att_masks = att_feats.new_ones(att_feats.shape[:2], dtype=torch.long)
366
+ att_masks = att_masks.unsqueeze(-2)
367
+
368
+ if seq is not None:
369
+ # crop the last one
370
+ # seq = seq[:,:-1]
371
+ seq_mask = (seq.data != self.eos_idx) & (seq.data != self.pad_idx)
372
+ seq_mask[:,0] = 1 # bos
373
+
374
+ seq_mask = seq_mask.unsqueeze(-2)
375
+ seq_mask = seq_mask & subsequent_mask(seq.size(-1)).to(seq_mask)
376
+
377
+ seq_per_img = seq.shape[0] // att_feats.shape[0]
378
+ if seq_per_img > 1:
379
+ att_feats, att_masks = utils.repeat_tensors(seq_per_img,
380
+ [att_feats, att_masks]
381
+ )
382
+ else:
383
+ seq_mask = None
384
+
385
+ return att_feats, seq, att_masks, seq_mask
386
+
387
+ def _forward(self, fc_feats, att_feats, seq, att_masks=None):
388
+ if seq.ndim == 3: # B * seq_per_img * seq_len
389
+ seq = seq.reshape(-1, seq.shape[2])
390
+ att_feats, seq, att_masks, seq_mask = self._prepare_feature_forward(att_feats, att_masks, seq)
391
+
392
+ out = self.model(att_feats, seq, att_masks, seq_mask)
393
+
394
+ outputs = self.model.generator(out)
395
+ return outputs
396
+ # return torch.cat([_.unsqueeze(1) for _ in outputs], 1)
397
+
398
+ def core(self, it, fc_feats_ph, att_feats_ph, memory, state, mask):
399
+ """
400
+ state is the precomputed key/value. N_dec x seq_len x d_model
401
+ Note: due to the layer norm, it's not equivalant to stateless,
402
+ but it seems behaving similar
403
+ """
404
+ # state is tokens + past
405
+ if len(state) == 0:
406
+ ys = it.unsqueeze(1)
407
+ # basically empty state, just to let it know to return past
408
+ # The second dim has to be batch_size, for beam search purpose
409
+ past = [fc_feats_ph.new_zeros(self.N_dec * 2, fc_feats_ph.shape[0], 0, self.d_model), # self
410
+ fc_feats_ph.new_zeros(self.N_dec * 2, fc_feats_ph.shape[0], 0, self.d_model)] # src
411
+ # 2 for self attn, 2 for src attn
412
+ else:
413
+ ys = torch.cat([state[0][0], it.unsqueeze(1)], dim=1)
414
+ past = state[1:]
415
+ out, past = self.model.decode(memory, mask,
416
+ ys, # We still feed the full past words, because we need it for position embedding to know the position id
417
+ subsequent_mask(ys.size(1))
418
+ .to(memory.device),
419
+ past=past)
420
+ return out[:, -1], [ys.unsqueeze(0)] + past
captioning/models/utils.py ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+
3
+ def repeat_tensors(n, x):
4
+ """
5
+ For a tensor of size Bx..., we repeat it n times, and make it Bnx...
6
+ For collections, do nested repeat
7
+ """
8
+ if torch.is_tensor(x):
9
+ x = x.unsqueeze(1) # Bx1x...
10
+ x = x.expand(-1, n, *([-1]*len(x.shape[2:]))) # Bxnx...
11
+ x = x.reshape(x.shape[0]*n, *x.shape[2:]) # Bnx...
12
+ elif type(x) is list or type(x) is tuple:
13
+ x = [repeat_tensors(n, _) for _ in x]
14
+ return x
15
+
16
+
17
+ def split_tensors(n, x):
18
+ if torch.is_tensor(x):
19
+ assert x.shape[0] % n == 0
20
+ x = x.reshape(x.shape[0] // n, n, *x.shape[1:]).unbind(1)
21
+ elif type(x) is list or type(x) is tuple:
22
+ x = [split_tensors(n, _) for _ in x]
23
+ elif x is None:
24
+ x = [None] * n
25
+ return x
captioning/modules/loss_wrapper.py ADDED
@@ -0,0 +1,127 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from . import losses
3
+ from ..utils.rewards import init_scorer, get_self_critical_reward, get_self_critical_clipscore_reward
4
+ from ..utils.clipscore import CLIPScore
5
+ import numpy as np
6
+
7
+ class LossWrapper(torch.nn.Module):
8
+ def __init__(self, model, opt):
9
+ super(LossWrapper, self).__init__()
10
+ self.opt = opt
11
+ self.model = model
12
+ if opt.label_smoothing > 0:
13
+ self.crit = losses.LabelSmoothing(smoothing=opt.label_smoothing)
14
+ else:
15
+ self.crit = losses.LanguageModelCriterion()
16
+ self.rl_crit = losses.RewardCriterion()
17
+ self.struc_crit = losses.StructureLosses(opt)
18
+
19
+ self.clipscore_model = None
20
+ if self.opt.use_clipscore:
21
+ use_grammar = getattr(self.opt, 'use_grammar', False)
22
+ joint_out = getattr(self.opt, 'joint_out', False)
23
+ self.clipscore_model = CLIPScore(
24
+ mode=opt.clipscore_mode,
25
+ use_grammar=use_grammar,
26
+ joint_out=joint_out,
27
+ )
28
+ for p in self.clipscore_model.parameters():
29
+ p.requires_grad = False
30
+
31
+ if use_grammar:
32
+ state_dict = torch.load(self.opt.clip_load_path, map_location='cpu')
33
+ self.clipscore_model.load_state_dict(state_dict['state_dict'])
34
+
35
+ def forward(self, fc_feats, att_feats, labels, masks, att_masks, gts, gt_indices,
36
+ sc_flag, struc_flag, clip_vis_feats=None):
37
+ opt = self.opt
38
+
39
+ out = {}
40
+ if struc_flag:
41
+ if opt.structure_loss_weight < 1:
42
+ lm_loss = self.crit(self.model(fc_feats, att_feats, labels[..., :-1], att_masks), labels[..., 1:], masks[..., 1:])
43
+ else:
44
+ lm_loss = torch.tensor(0).type_as(fc_feats)
45
+ if opt.structure_loss_weight > 0:
46
+ gen_result, sample_logprobs = self.model(fc_feats, att_feats, att_masks,
47
+ opt={'sample_method':opt.train_sample_method,
48
+ 'beam_size':opt.train_beam_size,
49
+ 'output_logsoftmax': opt.struc_use_logsoftmax or opt.structure_loss_type == 'softmax_margin'\
50
+ or not 'margin' in opt.structure_loss_type,
51
+ 'sample_n': opt.train_sample_n},
52
+ mode='sample')
53
+ gts = [gts[_] for _ in gt_indices.tolist()]
54
+ struc_loss = self.struc_crit(sample_logprobs, gen_result, gts)
55
+ else:
56
+ struc_loss = {'loss': torch.tensor(0).type_as(fc_feats),
57
+ 'reward': torch.tensor(0).type_as(fc_feats)}
58
+ loss = (1-opt.structure_loss_weight) * lm_loss + opt.structure_loss_weight * struc_loss['loss']
59
+ out['lm_loss'] = lm_loss
60
+ out['struc_loss'] = struc_loss['loss']
61
+ out['reward'] = struc_loss['reward']
62
+ elif not sc_flag:
63
+ loss = self.crit(self.model(fc_feats, att_feats, labels[..., :-1], att_masks), labels[..., 1:], masks[..., 1:])
64
+ else:
65
+ self.model.eval()
66
+ with torch.no_grad():
67
+ greedy_res, _ = self.model(fc_feats, att_feats, att_masks,
68
+ mode='sample',
69
+ opt={'sample_method': opt.sc_sample_method,
70
+ 'beam_size': opt.sc_beam_size})
71
+ self.model.train()
72
+ gen_result, sample_logprobs = self.model(fc_feats, att_feats, att_masks,
73
+ opt={'sample_method':opt.train_sample_method,
74
+ 'beam_size':opt.train_beam_size,
75
+ 'sample_n': opt.train_sample_n},
76
+ mode='sample')
77
+ gts = [gts[_] for _ in gt_indices.tolist()]
78
+
79
+ if getattr(self.opt, 'use_multi_rewards', False):
80
+ assert self.opt.use_clipscore
81
+ clipscore_reward_normalized, clipscore_unnormalized_mean, grammar_rewards = get_self_critical_clipscore_reward(
82
+ greedy_res, gts, gen_result, self.opt, self.clipscore_model, clip_vis_feats, self.model.vocab)
83
+
84
+ if self.opt.clipscore_mode == 'clip_s':
85
+ out['CLIP-S'] = clipscore_unnormalized_mean
86
+ elif self.opt.clipscore_mode == 'refclip_s':
87
+ out['RefCLIP-S'] = clipscore_unnormalized_mean
88
+
89
+ if getattr(self.opt, 'use_grammar', False):
90
+ out['grammar_reward'] = grammar_rewards.mean()
91
+
92
+ reward = clipscore_reward_normalized + grammar_rewards
93
+
94
+
95
+ else:
96
+ assert grammar_rewards is None
97
+
98
+ cider_reward_normalized, cider_unnormalized_mean = get_self_critical_reward(
99
+ greedy_res, gts, gen_result, self.opt)
100
+ out['CIDEr'] = cider_unnormalized_mean
101
+ if isinstance(cider_reward_normalized, np.ndarray):
102
+ cider_reward_normalized = torch.from_numpy(cider_reward_normalized).to(clipscore_reward_normalized.device)
103
+
104
+ reward = clipscore_reward_normalized + cider_reward_normalized
105
+ else:
106
+ if self.opt.use_clipscore:
107
+ clipscore_reward_normalized, clipscore_unnormalized_mean, _ = get_self_critical_clipscore_reward(
108
+ greedy_res, gts, gen_result, self.opt, self.clipscore_model, clip_vis_feats, self.model.vocab)
109
+ if self.opt.clipscore_mode == 'clip_s':
110
+ out['CLIP-S'] = clipscore_unnormalized_mean
111
+ elif self.opt.clipscore_mode == 'refclip_s':
112
+ out['RefCLIP-S'] = clipscore_unnormalized_mean
113
+ reward = clipscore_reward_normalized
114
+ else:
115
+ cider_reward_normalized, cider_unnormalized_mean = get_self_critical_reward(
116
+ greedy_res, gts, gen_result, self.opt)
117
+ out['CIDEr'] = cider_unnormalized_mean
118
+ reward = cider_reward_normalized
119
+
120
+ if isinstance(reward, np.ndarray):
121
+ reward = torch.from_numpy(reward)
122
+ reward = reward.to(sample_logprobs)
123
+ loss = self.rl_crit(sample_logprobs, gen_result.data, reward)
124
+ out['reward'] = reward[:,0].mean()
125
+ out['loss'] = loss
126
+ return out
127
+
captioning/modules/losses.py ADDED
@@ -0,0 +1,218 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ from ..utils.rewards import get_scores, get_self_cider_scores
4
+
5
+ class RewardCriterion(nn.Module):
6
+ def __init__(self):
7
+ super(RewardCriterion, self).__init__()
8
+
9
+ def forward(self, input, seq, reward):
10
+ input = input.gather(2, seq.unsqueeze(2)).squeeze(2)
11
+
12
+ input = input.reshape(-1)
13
+ reward = reward.reshape(-1)
14
+ mask = (seq>0).to(input)
15
+ mask = torch.cat([mask.new(mask.size(0), 1).fill_(1), mask[:, :-1]], 1).reshape(-1)
16
+ output = - input * reward * mask
17
+ output = torch.sum(output) / torch.sum(mask)
18
+
19
+ return output
20
+
21
+ class StructureLosses(nn.Module):
22
+ """
23
+ This loss is inspired by Classical Structured Prediction Losses for Sequence to Sequence Learning (Edunov et al., 2018).
24
+ """
25
+ def __init__(self, opt):
26
+ super(StructureLosses, self).__init__()
27
+ self.opt = opt
28
+ self.loss_type = opt.structure_loss_type
29
+
30
+ def forward(self, input, seq, data_gts):
31
+ """
32
+ Input is either logits or log softmax
33
+ """
34
+ out = {}
35
+
36
+ batch_size = input.size(0)# batch_size = sample_size * seq_per_img
37
+ seq_per_img = batch_size // len(data_gts)
38
+
39
+ assert seq_per_img == self.opt.train_sample_n, seq_per_img
40
+
41
+ mask = (seq>0).to(input)
42
+ mask = torch.cat([mask.new_full((mask.size(0), 1), 1), mask[:, :-1]], 1)
43
+
44
+ scores = get_scores(data_gts, seq, self.opt)
45
+ scores = torch.from_numpy(scores).type_as(input).view(-1, seq_per_img)
46
+ out['reward'] = scores #.mean()
47
+ if self.opt.entropy_reward_weight > 0:
48
+ entropy = - (F.softmax(input, dim=2) * F.log_softmax(input, dim=2)).sum(2).data
49
+ entropy = (entropy * mask).sum(1) / mask.sum(1)
50
+ print('entropy', entropy.mean().item())
51
+ scores = scores + self.opt.entropy_reward_weight * entropy.view(-1, seq_per_img)
52
+ # rescale cost to [0,1]
53
+ costs = - scores
54
+ if self.loss_type == 'risk' or self.loss_type == 'softmax_margin':
55
+ costs = costs - costs.min(1, keepdim=True)[0]
56
+ costs = costs / costs.max(1, keepdim=True)[0]
57
+ # in principle
58
+ # Only risk need such rescale
59
+ # margin should be alright; Let's try.
60
+
61
+ # Gather input: BxTxD -> BxT
62
+ input = input.gather(2, seq.unsqueeze(2)).squeeze(2)
63
+
64
+ if self.loss_type == 'seqnll':
65
+ # input is logsoftmax
66
+ input = input * mask
67
+ input = input.sum(1) / mask.sum(1)
68
+ input = input.view(-1, seq_per_img)
69
+
70
+ target = costs.min(1)[1]
71
+ output = F.cross_entropy(input, target)
72
+ elif self.loss_type == 'risk':
73
+ # input is logsoftmax
74
+ input = input * mask
75
+ input = input.sum(1)
76
+ input = input.view(-1, seq_per_img)
77
+
78
+ output = (F.softmax(input.exp()) * costs).sum(1).mean()
79
+
80
+ # test
81
+ # avg_scores = input
82
+ # probs = F.softmax(avg_scores.exp_())
83
+ # loss = (probs * costs.type_as(probs)).sum() / input.size(0)
84
+ # print(output.item(), loss.item())
85
+
86
+ elif self.loss_type == 'max_margin':
87
+ # input is logits
88
+ input = input * mask
89
+ input = input.sum(1) / mask.sum(1)
90
+ input = input.view(-1, seq_per_img)
91
+ _, __ = costs.min(1, keepdim=True)
92
+ costs_star = _
93
+ input_star = input.gather(1, __)
94
+ output = F.relu(costs - costs_star - input_star + input).max(1)[0] / 2
95
+ output = output.mean()
96
+
97
+ # sanity test
98
+ # avg_scores = input + costs
99
+ # scores_with_high_target = avg_scores.clone()
100
+ # scores_with_high_target.scatter_(1, costs.min(1)[1].view(-1, 1), 1e10)
101
+
102
+ # target_and_offender_index = scores_with_high_target.sort(1, True)[1][:, 0:2]
103
+ # avg_scores = avg_scores.gather(1, target_and_offender_index)
104
+ # target_index = avg_scores.new_zeros(avg_scores.size(0), dtype=torch.long)
105
+ # loss = F.multi_margin_loss(avg_scores, target_index, size_average=True, margin=0)
106
+ # print(loss.item() * 2, output.item())
107
+
108
+ elif self.loss_type == 'multi_margin':
109
+ # input is logits
110
+ input = input * mask
111
+ input = input.sum(1) / mask.sum(1)
112
+ input = input.view(-1, seq_per_img)
113
+ _, __ = costs.min(1, keepdim=True)
114
+ costs_star = _
115
+ input_star = input.gather(1, __)
116
+ output = F.relu(costs - costs_star - input_star + input)
117
+ output = output.mean()
118
+
119
+ # sanity test
120
+ # avg_scores = input + costs
121
+ # loss = F.multi_margin_loss(avg_scores, costs.min(1)[1], margin=0)
122
+ # print(output, loss)
123
+
124
+ elif self.loss_type == 'softmax_margin':
125
+ # input is logsoftmax
126
+ input = input * mask
127
+ input = input.sum(1) / mask.sum(1)
128
+ input = input.view(-1, seq_per_img)
129
+
130
+ input = input + costs
131
+ target = costs.min(1)[1]
132
+ output = F.cross_entropy(input, target)
133
+
134
+ elif self.loss_type == 'real_softmax_margin':
135
+ # input is logits
136
+ # This is what originally defined in Kevin's paper
137
+ # The result should be equivalent to softmax_margin
138
+ input = input * mask
139
+ input = input.sum(1) / mask.sum(1)
140
+ input = input.view(-1, seq_per_img)
141
+
142
+ input = input + costs
143
+ target = costs.min(1)[1]
144
+ output = F.cross_entropy(input, target)
145
+
146
+ elif self.loss_type == 'new_self_critical':
147
+ """
148
+ A different self critical
149
+ Self critical uses greedy decoding score as baseline;
150
+ This setting uses the average score of the rest samples as baseline
151
+ (suppose c1...cn n samples, reward1 = score1 - 1/(n-1)(score2+..+scoren) )
152
+ """
153
+ baseline = (scores.sum(1, keepdim=True) - scores) / (scores.shape[1] - 1)
154
+ scores = scores - baseline
155
+ # self cider used as reward to promote diversity (not working that much in this way)
156
+ if getattr(self.opt, 'self_cider_reward_weight', 0) > 0:
157
+ _scores = get_self_cider_scores(data_gts, seq, self.opt)
158
+ _scores = torch.from_numpy(_scores).type_as(scores).view(-1, 1)
159
+ _scores = _scores.expand_as(scores - 1)
160
+ scores += self.opt.self_cider_reward_weight * _scores
161
+ output = - input * mask * scores.view(-1, 1)
162
+ output = torch.sum(output) / torch.sum(mask)
163
+
164
+ out['loss'] = output
165
+ return out
166
+
167
+ class LanguageModelCriterion(nn.Module):
168
+ def __init__(self):
169
+ super(LanguageModelCriterion, self).__init__()
170
+
171
+ def forward(self, input, target, mask):
172
+ if target.ndim == 3:
173
+ target = target.reshape(-1, target.shape[2])
174
+ mask = mask.reshape(-1, mask.shape[2])
175
+ # truncate to the same size
176
+ target = target[:, :input.size(1)]
177
+ mask = mask[:, :input.size(1)].to(input)
178
+
179
+ output = -input.gather(2, target.unsqueeze(2)).squeeze(2) * mask
180
+ # Average over each token
181
+ output = torch.sum(output) / torch.sum(mask)
182
+
183
+ return output
184
+
185
+ class LabelSmoothing(nn.Module):
186
+ "Implement label smoothing."
187
+ def __init__(self, size=0, padding_idx=0, smoothing=0.0):
188
+ super(LabelSmoothing, self).__init__()
189
+ self.criterion = nn.KLDivLoss(size_average=False, reduce=False)
190
+ # self.padding_idx = padding_idx
191
+ self.confidence = 1.0 - smoothing
192
+ self.smoothing = smoothing
193
+ # self.size = size
194
+ self.true_dist = None
195
+
196
+ def forward(self, input, target, mask):
197
+ if target.ndim == 3:
198
+ target = target.reshape(-1, target.shape[2])
199
+ mask = mask.reshape(-1, mask.shape[2])
200
+ # truncate to the same size
201
+ target = target[:, :input.size(1)]
202
+ mask = mask[:, :input.size(1)]
203
+
204
+ input = input.reshape(-1, input.size(-1))
205
+ target = target.reshape(-1)
206
+ mask = mask.reshape(-1).to(input)
207
+
208
+ # assert x.size(1) == self.size
209
+ self.size = input.size(1)
210
+ # true_dist = x.data.clone()
211
+ true_dist = input.data.clone()
212
+ # true_dist.fill_(self.smoothing / (self.size - 2))
213
+ true_dist.fill_(self.smoothing / (self.size - 1))
214
+ true_dist.scatter_(1, target.data.unsqueeze(1), self.confidence)
215
+ # true_dist[:, self.padding_idx] = 0
216
+ # mask = torch.nonzero(target.data == self.padding_idx)
217
+ # self.true_dist = true_dist
218
+ return (self.criterion(input, true_dist).sum(1) * mask).sum() / mask.sum()
captioning/utils/__init__.py ADDED
File without changes
captioning/utils/clipscore.py ADDED
@@ -0,0 +1,396 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import CLIPModel, CLIPTokenizer
2
+ import os
3
+ import json
4
+ import argparse
5
+ from random import shuffle, seed
6
+ import string
7
+ # non-standard dependencies:
8
+ import h5py
9
+ from six.moves import cPickle
10
+ import numpy as np
11
+ import torch
12
+ import torchvision.models as models
13
+ import skimage.io
14
+
15
+ from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize
16
+ from PIL import Image
17
+ from torch import nn
18
+
19
+
20
+ class CLIPScore(nn.Module):
21
+ def __init__(self, clipscore_w=2.5, image_size=224, mode='clip_s', use_grammar=False, joint_out=False):
22
+ super(CLIPScore, self).__init__()
23
+ # from transformers import CLIPModel, CLIPTokenizer
24
+ self.clip_model = CLIPModel.from_pretrained(
25
+ 'openai/clip-vit-base-patch32')
26
+ self.tokenizer = CLIPTokenizer.from_pretrained(
27
+ 'openai/clip-vit-base-patch32')
28
+
29
+ self.clip_model.eval()
30
+
31
+ self.clipscore_w = clipscore_w
32
+
33
+ self.image_transform = self._transform(image_size)
34
+
35
+ self.mode = mode
36
+ assert mode in ['clip_s', 'refclip_s']
37
+
38
+ self.use_grammar = use_grammar
39
+ self.joint_out = joint_out
40
+
41
+ if self.use_grammar and joint_out is False:
42
+ self.grammar_score_head = nn.Sequential(
43
+ nn.Linear(self.clip_model.text_embed_dim, self.clip_model.projection_dim, bias=False),
44
+ nn.ReLU(),
45
+ nn.Linear(self.clip_model.projection_dim, 2, bias=False)
46
+ )
47
+
48
+ def _transform(self, n_px):
49
+ return Compose([
50
+ Resize(n_px, interpolation=Image.BICUBIC),
51
+ CenterCrop(n_px),
52
+ lambda image: image.convert("RGB"),
53
+ ToTensor(),
54
+ Normalize((0.48145466, 0.4578275, 0.40821073),
55
+ (0.26862954, 0.26130258, 0.27577711)),
56
+ ])
57
+
58
+ def load_image(self, image_path):
59
+ image = Image.open(image_path)
60
+ return image
61
+
62
+ # @torch.no_grad()
63
+ def image_extract(self, image):
64
+ if isinstance(image, str):
65
+ image = self.load_image(image)
66
+ if not isinstance(image, torch.Tensor):
67
+ image = self.image_transform(image)
68
+
69
+ img_tensor = image.view(-1, 3, 224, 224)
70
+ device = next(self.clip_model.parameters()).device
71
+ img_tensor = img_tensor.to(device)
72
+
73
+ clip_model = self.clip_model
74
+
75
+ img_feat = clip_model.vision_model(img_tensor).pooler_output
76
+ img_feat = clip_model.visual_projection(img_feat)
77
+ img_feat = img_feat / img_feat.norm(dim=-1, keepdim=True)
78
+
79
+ return img_feat
80
+
81
+ # @torch.no_grad()
82
+ def text_extract(self, text, prompt="A photo depicts", proj_norm=True):
83
+ if isinstance(text, str):
84
+ text_batch = [" ".join([prompt, text])]
85
+ elif isinstance(text, list):
86
+ text_batch = [" ".join([prompt, txt]) for txt in text]
87
+
88
+ if isinstance(text, tuple) and isinstance(text[0], torch.Tensor):
89
+ input_ids, attention_mask = text
90
+ else:
91
+ input_text = text_batch
92
+
93
+ tokenized = self.tokenizer(
94
+ input_text, return_tensors='pt', padding=True, truncation=True)
95
+
96
+ input_ids = tokenized.input_ids
97
+ attention_mask = tokenized.attention_mask
98
+
99
+ clip_model = self.clip_model
100
+ device = next(self.clip_model.parameters()).device
101
+ input_ids = input_ids.to(device)
102
+ attention_mask = attention_mask.to(device)
103
+
104
+ text_feat = clip_model.text_model(input_ids, attention_mask).pooler_output
105
+
106
+ if proj_norm:
107
+ text_feat = clip_model.text_projection(text_feat)
108
+ text_feat = text_feat / text_feat.norm(dim=-1, keepdim=True)
109
+
110
+ return text_feat
111
+
112
+ # @torch.no_grad()
113
+ def calc_clip_s(self, img_feat, text_feat):
114
+ return self.clipscore_w * torch.relu((img_feat * text_feat).sum(dim=-1))
115
+
116
+ # @torch.no_grad()
117
+ def calc_refclip_s(self, img_feat=None, text_feat=None, ref_text_feat=None, ref_text_mask=None, clip_s=None):
118
+
119
+ if clip_s is None:
120
+ clip_s = self.calc_clip_s(img_feat, text_feat)
121
+
122
+ B, dim = img_feat.size()
123
+
124
+ ref_text_feat = ref_text_feat.view(B, -1, dim)
125
+
126
+ K = ref_text_feat.size(1)
127
+
128
+ text_feat = text_feat.view(B, 1, dim).expand(-1, K, -1)
129
+ assert ref_text_feat.size() == text_feat.size(
130
+ ), (ref_text_feat.size(), text_feat.size())
131
+
132
+ ref_score = self.calc_clip_s(text_feat, ref_text_feat)
133
+ if ref_text_mask is not None:
134
+ if not isinstance(ref_text_mask, torch.Tensor):
135
+ ref_text_mask = torch.tensor(
136
+ ref_text_mask, dtype=ref_score.dtype, device=ref_score.device)
137
+ ref_score = ref_score.view(B, K) * ref_text_mask.view(B, K)
138
+
139
+ ref_score = ref_score.view(B, K).max(dim=1).values
140
+
141
+ assert clip_s.size() == (B,)
142
+ assert clip_s.size() == ref_score.size()
143
+
144
+ # harmonic mean
145
+ refclip_s = 2 / (1 / clip_s + 1 / ref_score)
146
+ return refclip_s
147
+
148
+ @torch.no_grad()
149
+ def forward(self,
150
+ images=None, text=None,
151
+ img_feat=None, text_feat=None,
152
+ ref_text=None, ref_text_feat=None, ref_text_mask=None,
153
+ prompt="A photo depicts",
154
+ mode=None):
155
+ if img_feat is None:
156
+ img_feat = self.image_extract(images)
157
+ img_feat = img_feat.view(-1, 512)
158
+
159
+ B = img_feat.size(0)
160
+
161
+ if text_feat is None:
162
+ text_feat = self.text_extract(text, prompt=prompt)
163
+ text_feat = text_feat.view(-1, 512)
164
+
165
+ if mode is None:
166
+ mode = self.mode
167
+ assert mode in ['clip_s', 'refclip_s']
168
+
169
+ if mode == 'clip_s':
170
+ clip_s = self.calc_clip_s(img_feat, text_feat)
171
+ return clip_s
172
+ elif mode == 'refclip_s':
173
+ if ref_text_feat is None:
174
+ ref_text_feat = self.text_extract(ref_text, prompt=prompt)
175
+ ref_text_feat = ref_text_feat.view(-1, 512)
176
+
177
+ refclip_s = self.calc_refclip_s(
178
+ img_feat, text_feat, ref_text_feat, ref_text_mask=ref_text_mask)
179
+ return refclip_s
180
+
181
+
182
+ def train_step(self,
183
+ images=None, text=None,
184
+ img_feat=None, text_feat=None,
185
+ neg_text=None, neg_text_feat=None,
186
+ # ref_text=None, ref_text_feat=None, ref_text_mask=None,
187
+ prompt="A photo depicts",
188
+ # return_loss=True,
189
+ **kwargs):
190
+
191
+ if img_feat is None:
192
+ img_feat = self.image_extract(images)
193
+ img_feat = img_feat.view(-1, 512)
194
+
195
+ B = img_feat.size(0)
196
+
197
+ if text_feat is None:
198
+ text_feat = self.text_extract(text, prompt=prompt, proj_norm=False)
199
+
200
+ text_cont_feat = self.clip_model.text_projection(text_feat)
201
+ text_cont_feat = text_cont_feat / text_cont_feat.norm(dim=-1, keepdim=True)
202
+ text_cont_feat = text_cont_feat.view(B, 512)
203
+
204
+ # cosine similarity as logits
205
+ logit_scale = self.clip_model.logit_scale.exp()
206
+ logits_per_text = torch.matmul(text_cont_feat, img_feat.t()) * logit_scale
207
+ # logits_per_image = logits_per_text.T
208
+
209
+ clip_loss = clip_loss_fn(logits_per_text)
210
+
211
+
212
+ # negative sampling
213
+ pos_text_feat = text_feat.view(B, 512)
214
+ neg_text_feat = self.text_extract(neg_text, prompt=prompt, proj_norm=False).view(B, 512)
215
+
216
+ grammar_text_feat = torch.cat([pos_text_feat, neg_text_feat], dim=0)
217
+
218
+ # 2B, 1
219
+ grammar_text_logit = self.grammar_score_head(grammar_text_feat)
220
+ grammar_labels = torch.LongTensor([1] * B + [0] * B).to(grammar_text_logit.device).view(2 * B)
221
+
222
+ grammar_loss = torch.nn.functional.cross_entropy(grammar_text_logit, grammar_labels)
223
+
224
+ grammar_pred = grammar_text_logit.argmax(dim=1, keepdim=False)
225
+ grammar_pos_pred = grammar_pred[:B]
226
+ grammar_neg_pred = grammar_pred[B:]
227
+ # grammar_acc = (grammar_pred == grammar_labels).float().mean()
228
+
229
+ out = {
230
+ 'clip_loss': clip_loss,
231
+ 'grammar_loss': grammar_loss,
232
+ 'img_feat': img_feat,
233
+ 'text_feat': text_cont_feat,
234
+ 'neg_text_feat': neg_text_feat,
235
+ 'grammar_pos_pred': grammar_pos_pred,
236
+ 'grammar_neg_pred': grammar_neg_pred,
237
+ }
238
+
239
+ return out
240
+
241
+ # contrastive loss function, adapted from
242
+ # https://sachinruk.github.io/blog/pytorch/pytorch%20lightning/loss%20function/gpu/2021/03/07/CLIP.html
243
+ def contrastive_loss(logits: torch.Tensor, dim: int) -> torch.Tensor:
244
+ neg_ce = torch.diag(nn.functional.log_softmax(logits, dim=dim))
245
+ return -neg_ce.mean()
246
+
247
+
248
+ def clip_loss_fn(similarity: torch.Tensor) -> torch.Tensor:
249
+ caption_loss = contrastive_loss(similarity, dim=0)
250
+ image_loss = contrastive_loss(similarity, dim=1)
251
+ return (caption_loss + image_loss) / 2.0
252
+
253
+
254
+
255
+ # class CLIPScore(nn.Module):
256
+ # def __init__(self, clipscore_w=2.5, image_size=224, mode='clip_s'):
257
+ # super(CLIPScore, self).__init__()
258
+ # # from transformers import CLIPModel, CLIPTokenizer
259
+ # self.clip_model = CLIPModel.from_pretrained(
260
+ # 'openai/clip-vit-base-patch32')
261
+ # self.tokenizer = CLIPTokenizer.from_pretrained(
262
+ # 'openai/clip-vit-base-patch32')
263
+
264
+ # self.clip_model.eval()
265
+
266
+ # self.clipscore_w = clipscore_w
267
+
268
+ # self.image_transform = self._transform(image_size)
269
+
270
+ # self.mode = mode
271
+ # assert mode in ['clip_s', 'refclip_s']
272
+
273
+ # def _transform(self, n_px):
274
+ # return Compose([
275
+ # Resize(n_px, interpolation=Image.BICUBIC),
276
+ # CenterCrop(n_px),
277
+ # lambda image: image.convert("RGB"),
278
+ # ToTensor(),
279
+ # Normalize((0.48145466, 0.4578275, 0.40821073),
280
+ # (0.26862954, 0.26130258, 0.27577711)),
281
+ # ])
282
+
283
+ # def load_image(self, image_path):
284
+ # image = Image.open(image_path)
285
+ # return image
286
+
287
+ # @torch.no_grad()
288
+ # def image_extract(self, image):
289
+ # if isinstance(image, str):
290
+ # image = self.load_image(image)
291
+ # if not isinstance(image, torch.Tensor):
292
+ # image = self.image_transform(image)
293
+
294
+ # img_tensor = image.view(-1, 3, 224, 224)
295
+ # device = next(self.clip_model.parameters()).device
296
+ # img_tensor = img_tensor.to(device)
297
+
298
+ # clip_model = self.clip_model
299
+
300
+ # img_feat = clip_model.vision_model(img_tensor).pooler_output
301
+ # img_feat = clip_model.visual_projection(img_feat)
302
+ # img_feat = img_feat / img_feat.norm(dim=-1, keepdim=True)
303
+
304
+ # return img_feat
305
+
306
+ # @torch.no_grad()
307
+ # def text_extract(self, text, prompt="A photo depicts"):
308
+ # if isinstance(text, str):
309
+ # text_batch = [" ".join([prompt, text])]
310
+ # else:
311
+ # text_batch = [" ".join([prompt, txt]) for txt in text]
312
+
313
+ # input_text = text_batch
314
+
315
+ # tokenized = self.tokenizer(
316
+ # input_text, return_tensors='pt', padding=True)
317
+
318
+ # input_ids = tokenized.input_ids
319
+ # attention_mask = tokenized.attention_mask
320
+
321
+ # clip_model = self.clip_model
322
+ # device = next(self.clip_model.parameters()).device
323
+ # input_ids = input_ids.to(device)
324
+ # attention_mask = attention_mask.to(device)
325
+
326
+ # text_feat = clip_model.text_model(input_ids, attention_mask).pooler_output
327
+ # text_feat = clip_model.text_projection(text_feat)
328
+ # text_feat = text_feat / text_feat.norm(dim=-1, keepdim=True)
329
+
330
+ # return text_feat
331
+
332
+ # @torch.no_grad()
333
+ # def calc_clip_s(self, img_feat, text_feat):
334
+ # return self.clipscore_w * torch.relu((img_feat * text_feat).sum(dim=-1))
335
+
336
+ # @torch.no_grad()
337
+ # def calc_refclip_s(self, img_feat=None, text_feat=None, ref_text_feat=None, ref_text_mask=None, clip_s=None):
338
+
339
+ # if clip_s is None:
340
+ # clip_s = self.calc_clip_s(img_feat, text_feat)
341
+
342
+ # B, dim = img_feat.size()
343
+
344
+ # ref_text_feat = ref_text_feat.view(B, -1, dim)
345
+
346
+ # K = ref_text_feat.size(1)
347
+
348
+ # text_feat = text_feat.view(B, 1, dim).expand(-1, K, -1)
349
+ # assert ref_text_feat.size() == text_feat.size(), (ref_text_feat.size(), text_feat.size())
350
+
351
+ # ref_score = self.calc_clip_s(text_feat, ref_text_feat)
352
+ # if ref_text_mask is not None:
353
+ # if not isinstance(ref_text_mask, torch.Tensor):
354
+ # ref_text_mask = torch.tensor(ref_text_mask, dtype=ref_score.dtype, device=ref_score.device)
355
+ # ref_score = ref_score.view(B, K) * ref_text_mask.view(B, K)
356
+
357
+ # ref_score = ref_score.view(B, K).max(dim=1).values
358
+
359
+ # assert clip_s.size() == (B,)
360
+ # assert clip_s.size() == ref_score.size()
361
+
362
+ # # harmonic mean
363
+ # refclip_s = 2 / (1 / clip_s + 1 / ref_score)
364
+ # return refclip_s
365
+
366
+
367
+ # @torch.no_grad()
368
+ # def forward(self,
369
+ # images=None, text=None,
370
+ # img_feat=None, text_feat=None,
371
+ # ref_text=None, ref_text_feat=None, ref_text_mask=None,
372
+ # prompt="A photo depicts",
373
+ # mode=None):
374
+ # if img_feat is None:
375
+ # img_feat = self.image_extract(images)
376
+ # img_feat = img_feat.view(-1, 512)
377
+
378
+ # if text_feat is None:
379
+ # text_feat = self.text_extract(text, prompt=prompt)
380
+ # text_feat = text_feat.view(-1, 512)
381
+
382
+ # if mode is None:
383
+ # mode = self.mode
384
+ # assert mode in ['clip_s', 'refclip_s']
385
+
386
+ # if mode == 'clip_s':
387
+ # clip_s = self.calc_clip_s(img_feat, text_feat)
388
+ # return clip_s
389
+ # elif mode == 'refclip_s':
390
+ # if ref_text_feat is None:
391
+ # ref_text_feat = self.text_extract(ref_text, prompt=prompt)
392
+ # ref_text_feat = ref_text_feat.view(-1, 512)
393
+
394
+ # refclip_s = self.calc_refclip_s(img_feat, text_feat, ref_text_feat, ref_text_mask=ref_text_mask)
395
+ # return refclip_s
396
+
captioning/utils/config.py ADDED
@@ -0,0 +1,153 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
2
+ # Copy from fvcore
3
+
4
+ import logging
5
+ import os
6
+ from typing import Any
7
+ import yaml
8
+ from yacs.config import CfgNode as _CfgNode
9
+
10
+ import io as PathManager
11
+
12
+ BASE_KEY = "_BASE_"
13
+
14
+
15
+ class CfgNode(_CfgNode):
16
+ """
17
+ Our own extended version of :class:`yacs.config.CfgNode`.
18
+ It contains the following extra features:
19
+
20
+ 1. The :meth:`merge_from_file` method supports the "_BASE_" key,
21
+ which allows the new CfgNode to inherit all the attributes from the
22
+ base configuration file.
23
+ 2. Keys that start with "COMPUTED_" are treated as insertion-only
24
+ "computed" attributes. They can be inserted regardless of whether
25
+ the CfgNode is frozen or not.
26
+ 3. With "allow_unsafe=True", it supports pyyaml tags that evaluate
27
+ expressions in config. See examples in
28
+ https://pyyaml.org/wiki/PyYAMLDocumentation#yaml-tags-and-python-types
29
+ Note that this may lead to arbitrary code execution: you must not
30
+ load a config file from untrusted sources before manually inspecting
31
+ the content of the file.
32
+ """
33
+
34
+ @staticmethod
35
+ def load_yaml_with_base(filename, allow_unsafe = False):
36
+ """
37
+ Just like `yaml.load(open(filename))`, but inherit attributes from its
38
+ `_BASE_`.
39
+
40
+ Args:
41
+ filename (str): the file name of the current config. Will be used to
42
+ find the base config file.
43
+ allow_unsafe (bool): whether to allow loading the config file with
44
+ `yaml.unsafe_load`.
45
+
46
+ Returns:
47
+ (dict): the loaded yaml
48
+ """
49
+ with PathManager.open(filename, "r") as f:
50
+ try:
51
+ cfg = yaml.safe_load(f)
52
+ except yaml.constructor.ConstructorError:
53
+ if not allow_unsafe:
54
+ raise
55
+ logger = logging.getLogger(__name__)
56
+ logger.warning(
57
+ "Loading config {} with yaml.unsafe_load. Your machine may "
58
+ "be at risk if the file contains malicious content.".format(
59
+ filename
60
+ )
61
+ )
62
+ f.close()
63
+ with open(filename, "r") as f:
64
+ cfg = yaml.unsafe_load(f)
65
+
66
+ def merge_a_into_b(a, b):
67
+ # merge dict a into dict b. values in a will overwrite b.
68
+ for k, v in a.items():
69
+ if isinstance(v, dict) and k in b:
70
+ assert isinstance(
71
+ b[k], dict
72
+ ), "Cannot inherit key '{}' from base!".format(k)
73
+ merge_a_into_b(v, b[k])
74
+ else:
75
+ b[k] = v
76
+
77
+ if BASE_KEY in cfg:
78
+ base_cfg_file = cfg[BASE_KEY]
79
+ if base_cfg_file.startswith("~"):
80
+ base_cfg_file = os.path.expanduser(base_cfg_file)
81
+ if not any(
82
+ map(base_cfg_file.startswith, ["/", "https://", "http://"])
83
+ ):
84
+ # the path to base cfg is relative to the config file itself.
85
+ base_cfg_file = os.path.join(
86
+ os.path.dirname(filename), base_cfg_file
87
+ )
88
+ base_cfg = CfgNode.load_yaml_with_base(
89
+ base_cfg_file, allow_unsafe=allow_unsafe
90
+ )
91
+ del cfg[BASE_KEY]
92
+
93
+ merge_a_into_b(cfg, base_cfg)
94
+ return base_cfg
95
+ return cfg
96
+
97
+ def merge_from_file(self, cfg_filename, allow_unsafe = False):
98
+ """
99
+ Merge configs from a given yaml file.
100
+
101
+ Args:
102
+ cfg_filename: the file name of the yaml config.
103
+ allow_unsafe: whether to allow loading the config file with
104
+ `yaml.unsafe_load`.
105
+ """
106
+ loaded_cfg = CfgNode.load_yaml_with_base(
107
+ cfg_filename, allow_unsafe=allow_unsafe
108
+ )
109
+ loaded_cfg = type(self)(loaded_cfg)
110
+ self.merge_from_other_cfg(loaded_cfg)
111
+
112
+ # Forward the following calls to base, but with a check on the BASE_KEY.
113
+ def merge_from_other_cfg(self, cfg_other):
114
+ """
115
+ Args:
116
+ cfg_other (CfgNode): configs to merge from.
117
+ """
118
+ assert (
119
+ BASE_KEY not in cfg_other
120
+ ), "The reserved key '{}' can only be used in files!".format(BASE_KEY)
121
+ return super().merge_from_other_cfg(cfg_other)
122
+
123
+ def merge_from_list(self, cfg_list):
124
+ """
125
+ Args:
126
+ cfg_list (list): list of configs to merge from.
127
+ """
128
+ keys = set(cfg_list[0::2])
129
+ assert (
130
+ BASE_KEY not in keys
131
+ ), "The reserved key '{}' can only be used in files!".format(BASE_KEY)
132
+ return super().merge_from_list(cfg_list)
133
+
134
+ def __setattr__(self, name, val):
135
+ if name.startswith("COMPUTED_"):
136
+ if name in self:
137
+ old_val = self[name]
138
+ if old_val == val:
139
+ return
140
+ raise KeyError(
141
+ "Computed attributed '{}' already exists "
142
+ "with a different value! old={}, new={}.".format(
143
+ name, old_val, val
144
+ )
145
+ )
146
+ self[name] = val
147
+ else:
148
+ super().__setattr__(name, val)
149
+
150
+
151
+ if __name__ == '__main__':
152
+ cfg = CfgNode.load_yaml_with_base('configs/updown_long.yml')
153
+ print(cfg)
captioning/utils/dist_utils.py ADDED
@@ -0,0 +1,305 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
2
+ """
3
+ This file contains primitives for multi-gpu communication.
4
+ This is useful when doing distributed training.
5
+ """
6
+
7
+ import functools
8
+ import logging
9
+ import numpy as np
10
+ import pickle
11
+ import torch
12
+ import torch.distributed as dist
13
+
14
+ import torch
15
+
16
+ _LOCAL_PROCESS_GROUP = None
17
+ """
18
+ A torch process group which only includes processes that on the same machine as the current process.
19
+ This variable is set when processes are spawned by `launch()` in "engine/launch.py".
20
+ """
21
+
22
+
23
+ def get_world_size() -> int:
24
+ if not dist.is_available():
25
+ return 1
26
+ if not dist.is_initialized():
27
+ return 1
28
+ return dist.get_world_size()
29
+
30
+
31
+ def get_rank() -> int:
32
+ if not dist.is_available():
33
+ return 0
34
+ if not dist.is_initialized():
35
+ return 0
36
+ return dist.get_rank()
37
+
38
+
39
+ def get_local_rank() -> int:
40
+ """
41
+ Returns:
42
+ The rank of the current process within the local (per-machine) process group.
43
+ """
44
+ if not dist.is_available():
45
+ return 0
46
+ if not dist.is_initialized():
47
+ return 0
48
+ assert _LOCAL_PROCESS_GROUP is not None
49
+ return dist.get_rank(group=_LOCAL_PROCESS_GROUP)
50
+
51
+
52
+ def get_local_size() -> int:
53
+ """
54
+ Returns:
55
+ The size of the per-machine process group,
56
+ i.e. the number of processes per machine.
57
+ """
58
+ if not dist.is_available():
59
+ return 1
60
+ if not dist.is_initialized():
61
+ return 1
62
+ return dist.get_world_size(group=_LOCAL_PROCESS_GROUP)
63
+
64
+
65
+ def is_main_process() -> bool:
66
+ return get_rank() == 0
67
+
68
+
69
+ def synchronize():
70
+ """
71
+ Helper function to synchronize (barrier) among all processes when
72
+ using distributed training
73
+ """
74
+ if not dist.is_available():
75
+ return
76
+ if not dist.is_initialized():
77
+ return
78
+ world_size = dist.get_world_size()
79
+ if world_size == 1:
80
+ return
81
+ dist.barrier()
82
+
83
+
84
+ @functools.lru_cache()
85
+ def _get_global_gloo_group():
86
+ """
87
+ Return a process group based on gloo backend, containing all the ranks
88
+ The result is cached.
89
+ """
90
+ if dist.get_backend() == "nccl":
91
+ return dist.new_group(backend="gloo")
92
+ else:
93
+ return dist.group.WORLD
94
+
95
+
96
+ def _serialize_to_tensor(data, group):
97
+ backend = dist.get_backend(group)
98
+ assert backend in ["gloo", "nccl"]
99
+ device = torch.device("cpu" if backend == "gloo" else "cuda")
100
+
101
+ buffer = pickle.dumps(data)
102
+ if len(buffer) > 1024 ** 3:
103
+ logger = logging.getLogger(__name__)
104
+ logger.warning(
105
+ "Rank {} trying to all-gather {:.2f} GB of data on device {}".format(
106
+ get_rank(), len(buffer) / (1024 ** 3), device
107
+ )
108
+ )
109
+ storage = torch.ByteStorage.from_buffer(buffer)
110
+ tensor = torch.ByteTensor(storage).to(device=device)
111
+ return tensor
112
+
113
+
114
+ def _pad_to_largest_tensor(tensor, group):
115
+ """
116
+ Returns:
117
+ list[int]: size of the tensor, on each rank
118
+ Tensor: padded tensor that has the max size
119
+ """
120
+ world_size = dist.get_world_size(group=group)
121
+ assert (
122
+ world_size >= 1
123
+ ), "comm.gather/all_gather must be called from ranks within the given group!"
124
+ local_size = torch.tensor(
125
+ [tensor.numel()], dtype=torch.int64, device=tensor.device)
126
+ size_list = [
127
+ torch.zeros([1], dtype=torch.int64, device=tensor.device)
128
+ for _ in range(world_size)
129
+ ]
130
+ dist.all_gather(size_list, local_size, group=group)
131
+ size_list = [int(size.item()) for size in size_list]
132
+
133
+ max_size = max(size_list)
134
+
135
+ # we pad the tensor because torch all_gather does not support
136
+ # gathering tensors of different shapes
137
+ if local_size != max_size:
138
+ padding = torch.zeros(
139
+ (max_size - local_size,), dtype=torch.uint8, device=tensor.device
140
+ )
141
+ tensor = torch.cat((tensor, padding), dim=0)
142
+ return size_list, tensor
143
+
144
+
145
+ def all_gather(data, group=None):
146
+ """
147
+ Run all_gather on arbitrary picklable data (not necessarily tensors).
148
+ Args:
149
+ data: any picklable object
150
+ group: a torch process group. By default, will use a group which
151
+ contains all ranks on gloo backend.
152
+ Returns:
153
+ list[data]: list of data gathered from each rank
154
+ """
155
+ if get_world_size() == 1:
156
+ return [data]
157
+ if group is None:
158
+ group = _get_global_gloo_group()
159
+ if dist.get_world_size(group) == 1:
160
+ return [data]
161
+
162
+ tensor = _serialize_to_tensor(data, group)
163
+
164
+ size_list, tensor = _pad_to_largest_tensor(tensor, group)
165
+ max_size = max(size_list)
166
+
167
+ # receiving Tensor from all ranks
168
+ tensor_list = [
169
+ torch.empty((max_size,), dtype=torch.uint8, device=tensor.device)
170
+ for _ in size_list
171
+ ]
172
+ dist.all_gather(tensor_list, tensor, group=group)
173
+
174
+ data_list = []
175
+ for size, tensor in zip(size_list, tensor_list):
176
+ buffer = tensor.cpu().numpy().tobytes()[:size]
177
+ data_list.append(pickle.loads(buffer))
178
+
179
+ return data_list
180
+
181
+
182
+ def gather(data, dst=0, group=None):
183
+ """
184
+ Run gather on arbitrary picklable data (not necessarily tensors).
185
+ Args:
186
+ data: any picklable object
187
+ dst (int): destination rank
188
+ group: a torch process group. By default, will use a group which
189
+ contains all ranks on gloo backend.
190
+ Returns:
191
+ list[data]: on dst, a list of data gathered from each rank. Otherwise,
192
+ an empty list.
193
+ """
194
+ if get_world_size() == 1:
195
+ return [data]
196
+ if group is None:
197
+ group = _get_global_gloo_group()
198
+ if dist.get_world_size(group=group) == 1:
199
+ return [data]
200
+ rank = dist.get_rank(group=group)
201
+
202
+ tensor = _serialize_to_tensor(data, group)
203
+ size_list, tensor = _pad_to_largest_tensor(tensor, group)
204
+
205
+ # receiving Tensor from all ranks
206
+ if rank == dst:
207
+ max_size = max(size_list)
208
+ tensor_list = [
209
+ torch.empty((max_size,), dtype=torch.uint8, device=tensor.device)
210
+ for _ in size_list
211
+ ]
212
+ dist.gather(tensor, tensor_list, dst=dst, group=group)
213
+
214
+ data_list = []
215
+ for size, tensor in zip(size_list, tensor_list):
216
+ buffer = tensor.cpu().numpy().tobytes()[:size]
217
+ data_list.append(pickle.loads(buffer))
218
+ return data_list
219
+ else:
220
+ dist.gather(tensor, [], dst=dst, group=group)
221
+ return []
222
+
223
+
224
+ def shared_random_seed():
225
+ """
226
+ Returns:
227
+ int: a random number that is the same across all workers.
228
+ If workers need a shared RNG, they can use this shared seed to
229
+ create one.
230
+ All workers must call this function, otherwise it will deadlock.
231
+ """
232
+ ints = np.random.randint(2 ** 31)
233
+ all_ints = all_gather(ints)
234
+ return all_ints[0]
235
+
236
+
237
+ # def reduce_dict(input_dict, average=True):
238
+ # """
239
+ # Reduce the values in the dictionary from all processes so that process with rank
240
+ # 0 has the reduced results.
241
+ # Args:
242
+ # input_dict (dict): inputs to be reduced. All the values must be scalar CUDA Tensor.
243
+ # average (bool): whether to do average or sum
244
+ # Returns:
245
+ # a dict with the same keys as input_dict, after reduction.
246
+ # """
247
+ # world_size = get_world_size()
248
+ # if world_size < 2:
249
+ # return input_dict
250
+ # with torch.no_grad():
251
+ # names = []
252
+ # values = []
253
+ # # sort the keys so that they are consistent across processes
254
+ # for k in sorted(input_dict.keys()):
255
+ # names.append(k)
256
+ # values.append(input_dict[k])
257
+ # values = torch.stack(values, dim=0)
258
+ # dist.reduce(values, dst=0)
259
+ # if dist.get_rank() == 0 and average:
260
+ # # only main process gets accumulated, so only divide by
261
+ # # world_size in this case
262
+ # values /= world_size
263
+ # reduced_dict = {k: v for k, v in zip(names, values)}
264
+ # return reduced_dict
265
+
266
+
267
+ def reduce_dict(input_dict, average=True):
268
+ """
269
+ Reduce the values in the dictionary from all processes so that process with rank
270
+ 0 has the reduced results.
271
+ Args:
272
+ input_dict (dict): inputs to be reduced. (values not necessarily tensors).
273
+ average (bool): whether to do average or sum
274
+ Returns:
275
+ a dict with the same keys as input_dict, after reduction.
276
+ """
277
+
278
+ world_size = get_world_size()
279
+ if world_size < 2:
280
+ return input_dict
281
+
282
+ with torch.no_grad():
283
+
284
+ # Convert to CUDA Tensor for dist.reduce()
285
+ input_dict_cuda_vals = {}
286
+ for k, v in input_dict.items():
287
+ if type(v) == torch.Tensor:
288
+ input_dict_cuda_vals[k] = v.to('cuda')
289
+ else:
290
+ input_dict_cuda_vals[k] = torch.tensor(v, device='cuda')
291
+
292
+ names = []
293
+ values = []
294
+ for k, v in sorted(input_dict_cuda_vals.items()):
295
+ names.append(k)
296
+ values.append(v)
297
+ values = torch.stack(values, dim=0)
298
+ dist.reduce(values, dst=0) # reduce to gpu 0
299
+
300
+ if dist.get_rank() == 0 and average:
301
+ # only main process gets accumulated, so only divide by
302
+ # world_size in this case
303
+ values /= world_size
304
+ reduced_dict = {k: v for k, v in zip(names, values)}
305
+ return reduced_dict
captioning/utils/div_utils.py ADDED
@@ -0,0 +1,38 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from random import uniform
2
+ import numpy as np
3
+ from collections import OrderedDict, defaultdict
4
+ from itertools import tee
5
+ import time
6
+
7
+ # -----------------------------------------------
8
+ def find_ngrams(input_list, n):
9
+ return zip(*[input_list[i:] for i in range(n)])
10
+
11
+ def compute_div_n(caps,n=1):
12
+ aggr_div = []
13
+ for k in caps:
14
+ all_ngrams = set()
15
+ lenT = 0.
16
+ for c in caps[k]:
17
+ tkns = c.split()
18
+ lenT += len(tkns)
19
+ ng = find_ngrams(tkns, n)
20
+ all_ngrams.update(ng)
21
+ aggr_div.append(float(len(all_ngrams))/ (1e-6 + float(lenT)))
22
+ return np.array(aggr_div).mean(), np.array(aggr_div)
23
+
24
+ def compute_global_div_n(caps,n=1):
25
+ aggr_div = []
26
+ all_ngrams = set()
27
+ lenT = 0.
28
+ for k in caps:
29
+ for c in caps[k]:
30
+ tkns = c.split()
31
+ lenT += len(tkns)
32
+ ng = find_ngrams(tkns, n)
33
+ all_ngrams.update(ng)
34
+ if n == 1:
35
+ aggr_div.append(float(len(all_ngrams)))
36
+ else:
37
+ aggr_div.append(float(len(all_ngrams))/ (1e-6 + float(lenT)))
38
+ return aggr_div[0], np.repeat(np.array(aggr_div),len(caps))
captioning/utils/eval_multi.py ADDED
@@ -0,0 +1,218 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import absolute_import
2
+ from __future__ import division
3
+ from __future__ import print_function
4
+
5
+ import torch
6
+ import torch.nn as nn
7
+
8
+ import numpy as np
9
+ import json
10
+ from json import encoder
11
+ import random
12
+ import string
13
+ import time
14
+ import os
15
+ import sys
16
+ from . import misc as utils
17
+ from eval_utils import getCOCO
18
+
19
+ from .div_utils import compute_div_n, compute_global_div_n
20
+
21
+ import sys
22
+ try:
23
+ sys.path.append("coco-caption")
24
+ annFile = 'coco-caption/annotations/captions_val2014.json'
25
+ from pycocotools.coco import COCO
26
+ from pycocoevalcap.eval import COCOEvalCap
27
+ from pycocoevalcap.eval_spice import COCOEvalCapSpice
28
+ from pycocoevalcap.tokenizer.ptbtokenizer import PTBTokenizer
29
+ from pycocoevalcap.bleu.bleu import Bleu
30
+ sys.path.append("cider")
31
+ from pyciderevalcap.cider.cider import Cider
32
+ except:
33
+ print('Warning: requirements for eval_multi not satisfied')
34
+
35
+
36
+ def eval_allspice(dataset, preds_n, model_id, split):
37
+ coco = getCOCO(dataset)
38
+ valids = coco.getImgIds()
39
+
40
+ capsById = {}
41
+ for d in preds_n:
42
+ capsById[d['image_id']] = capsById.get(d['image_id'], []) + [d]
43
+
44
+ # filter results to only those in MSCOCO validation set (will be about a third)
45
+ preds_filt_n = [p for p in preds_n if p['image_id'] in valids]
46
+ print('using %d/%d predictions_n' % (len(preds_filt_n), len(preds_n)))
47
+ cache_path_n = os.path.join('eval_results/', model_id + '_' + split + '_n.json')
48
+ json.dump(preds_filt_n, open(cache_path_n, 'w')) # serialize to temporary json file. Sigh, COCO API...
49
+
50
+ # Eval AllSPICE
51
+ cocoRes_n = coco.loadRes(cache_path_n)
52
+ cocoEvalAllSPICE = COCOEvalCapSpice(coco, cocoRes_n)
53
+ cocoEvalAllSPICE.params['image_id'] = cocoRes_n.getImgIds()
54
+ cocoEvalAllSPICE.evaluate()
55
+
56
+ out = {}
57
+ for metric, score in cocoEvalAllSPICE.eval.items():
58
+ out['All'+metric] = score
59
+
60
+ imgToEvalAllSPICE = cocoEvalAllSPICE.imgToEval
61
+ # collect SPICE_sub_score
62
+ for k in list(imgToEvalAllSPICE.values())[0]['SPICE'].keys():
63
+ if k != 'All':
64
+ out['AllSPICE_'+k] = np.array([v['SPICE'][k]['f'] for v in imgToEvalAllSPICE.values()])
65
+ out['AllSPICE_'+k] = (out['AllSPICE_'+k][out['AllSPICE_'+k]==out['AllSPICE_'+k]]).mean()
66
+ for p in preds_filt_n:
67
+ image_id, caption = p['image_id'], p['caption']
68
+ imgToEvalAllSPICE[image_id]['caption'] = capsById[image_id]
69
+ return {'overall': out, 'imgToEvalAllSPICE': imgToEvalAllSPICE}
70
+
71
+ def eval_oracle(dataset, preds_n, model_id, split):
72
+ cache_path = os.path.join('eval_results/', model_id + '_' + split + '_n.json')
73
+
74
+ coco = getCOCO(dataset)
75
+ valids = coco.getImgIds()
76
+
77
+ capsById = {}
78
+ for d in preds_n:
79
+ capsById[d['image_id']] = capsById.get(d['image_id'], []) + [d]
80
+
81
+ sample_n = capsById[list(capsById.keys())[0]]
82
+ for i in range(len(capsById[list(capsById.keys())[0]])):
83
+ preds = [_[i] for _ in capsById.values()]
84
+
85
+ json.dump(preds, open(cache_path, 'w')) # serialize to temporary json file. Sigh, COCO API...
86
+
87
+ cocoRes = coco.loadRes(cache_path)
88
+ cocoEval = COCOEvalCap(coco, cocoRes)
89
+ cocoEval.params['image_id'] = cocoRes.getImgIds()
90
+ cocoEval.evaluate()
91
+
92
+ imgToEval = cocoEval.imgToEval
93
+ for img_id in capsById.keys():
94
+ tmp = imgToEval[img_id]
95
+ for k in tmp['SPICE'].keys():
96
+ if k != 'All':
97
+ tmp['SPICE_'+k] = tmp['SPICE'][k]['f']
98
+ if tmp['SPICE_'+k] != tmp['SPICE_'+k]: # nan
99
+ tmp['SPICE_'+k] = -100
100
+ tmp['SPICE'] = tmp['SPICE']['All']['f']
101
+ if tmp['SPICE'] != tmp['SPICE']: tmp['SPICE'] = -100
102
+ capsById[img_id][i]['scores'] = imgToEval[img_id]
103
+
104
+ out = {'overall': {}, 'ImgToEval': {}}
105
+ for img_id in capsById.keys():
106
+ out['ImgToEval'][img_id] = {}
107
+ for metric in capsById[img_id][0]['scores'].keys():
108
+ if metric == 'image_id': continue
109
+ out['ImgToEval'][img_id]['oracle_'+metric] = max([_['scores'][metric] for _ in capsById[img_id]])
110
+ out['ImgToEval'][img_id]['avg_'+metric] = sum([_['scores'][metric] for _ in capsById[img_id]]) / len(capsById[img_id])
111
+ out['ImgToEval'][img_id]['captions'] = capsById[img_id]
112
+ for metric in list(out['ImgToEval'].values())[0].keys():
113
+ if metric == 'captions':
114
+ continue
115
+ tmp = np.array([_[metric] for _ in out['ImgToEval'].values()])
116
+ tmp = tmp[tmp!=-100]
117
+ out['overall'][metric] = tmp.mean()
118
+
119
+ return out
120
+
121
+ def eval_div_stats(dataset, preds_n, model_id, split):
122
+ tokenizer = PTBTokenizer()
123
+
124
+ capsById = {}
125
+ for i, d in enumerate(preds_n):
126
+ d['id'] = i
127
+ capsById[d['image_id']] = capsById.get(d['image_id'], []) + [d]
128
+
129
+ n_caps_perimg = len(capsById[list(capsById.keys())[0]])
130
+ print(n_caps_perimg)
131
+ _capsById = capsById # save the untokenized version
132
+ capsById = tokenizer.tokenize(capsById)
133
+
134
+ div_1, adiv_1 = compute_div_n(capsById,1)
135
+ div_2, adiv_2 = compute_div_n(capsById,2)
136
+
137
+ globdiv_1, _= compute_global_div_n(capsById,1)
138
+
139
+ print('Diversity Statistics are as follows: \n Div1: %.2f, Div2: %.2f, gDiv1: %d\n'%(div_1,div_2, globdiv_1))
140
+
141
+ # compute mbleu
142
+ scorer = Bleu(4)
143
+ all_scrs = []
144
+ scrperimg = np.zeros((n_caps_perimg, len(capsById)))
145
+
146
+ for i in range(n_caps_perimg):
147
+ tempRefsById = {}
148
+ candsById = {}
149
+ for k in capsById:
150
+ tempRefsById[k] = capsById[k][:i] + capsById[k][i+1:]
151
+ candsById[k] = [capsById[k][i]]
152
+
153
+ score, scores = scorer.compute_score(tempRefsById, candsById)
154
+ all_scrs.append(score)
155
+ scrperimg[i,:] = scores[1]
156
+
157
+ all_scrs = np.array(all_scrs)
158
+
159
+ out = {}
160
+ out['overall'] = {'Div1': div_1, 'Div2': div_2, 'gDiv1': globdiv_1}
161
+ for k, score in zip(range(4), all_scrs.mean(axis=0).tolist()):
162
+ out['overall'].update({'mBLeu_%d'%(k+1): score})
163
+ imgToEval = {}
164
+ for i,imgid in enumerate(capsById.keys()):
165
+ imgToEval[imgid] = {'mBleu_2' : scrperimg[:,i].mean()}
166
+ imgToEval[imgid]['individuals'] = []
167
+ for j, d in enumerate(_capsById[imgid]):
168
+ imgToEval[imgid]['individuals'].append(preds_n[d['id']])
169
+ imgToEval[imgid]['individuals'][-1]['mBleu_2'] = scrperimg[j,i]
170
+ out['ImgToEval'] = imgToEval
171
+
172
+ print('Mean mutual Bleu scores on this set is:\nmBLeu_1, mBLeu_2, mBLeu_3, mBLeu_4')
173
+ print(all_scrs.mean(axis=0))
174
+
175
+ return out
176
+
177
+ def eval_self_cider(dataset, preds_n, model_id, split):
178
+ cache_path = os.path.join('eval_results/', model_id + '_' + split + '_n.json')
179
+
180
+ coco = getCOCO(dataset)
181
+ valids = coco.getImgIds()
182
+
183
+ # Get Cider_scorer
184
+ Cider_scorer = Cider(df='corpus')
185
+
186
+ tokenizer = PTBTokenizer()
187
+ gts = {}
188
+ for imgId in valids:
189
+ gts[imgId] = coco.imgToAnns[imgId]
190
+ gts = tokenizer.tokenize(gts)
191
+
192
+ for imgId in valids:
193
+ Cider_scorer.cider_scorer += (None, gts[imgId])
194
+ Cider_scorer.cider_scorer.compute_doc_freq()
195
+ Cider_scorer.cider_scorer.ref_len = np.log(float(len(Cider_scorer.cider_scorer.crefs)))
196
+
197
+ # Prepare captions
198
+ capsById = {}
199
+ for d in preds_n:
200
+ capsById[d['image_id']] = capsById.get(d['image_id'], []) + [d]
201
+
202
+ capsById = tokenizer.tokenize(capsById)
203
+ imgIds = list(capsById.keys())
204
+ scores = Cider_scorer.my_self_cider([capsById[_] for _ in imgIds])
205
+
206
+ def get_div(eigvals):
207
+ eigvals = np.clip(eigvals, 0, None)
208
+ return -np.log(np.sqrt(eigvals[-1]) / (np.sqrt(eigvals).sum())) / np.log(len(eigvals))
209
+ sc_scores = [get_div(np.linalg.eigvalsh(_/10)) for _ in scores]
210
+ score = np.mean(np.array(sc_scores))
211
+
212
+ imgToEval = {}
213
+ for i, image_id in enumerate(imgIds):
214
+ imgToEval[image_id] = {'self_cider': sc_scores[i], 'self_cider_mat': scores[i].tolist()}
215
+ return {'overall': {'self_cider': score}, 'imgToEval': imgToEval}
216
+
217
+
218
+ return score
captioning/utils/eval_utils.py ADDED
@@ -0,0 +1,281 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import absolute_import
2
+ from __future__ import division
3
+ from __future__ import print_function
4
+
5
+ import torch
6
+ import torch.nn as nn
7
+ import torch.nn.functional as F
8
+
9
+ import numpy as np
10
+ import json
11
+ from json import encoder
12
+ import random
13
+ import string
14
+ import time
15
+ import os
16
+ import sys
17
+ from . import misc as utils
18
+
19
+ # load coco-caption if available
20
+ try:
21
+ sys.path.append("coco-caption")
22
+ from pycocotools.coco import COCO
23
+ from pycocoevalcap.eval import COCOEvalCap
24
+ except:
25
+ print('Warning: coco-caption not available')
26
+
27
+ bad_endings = ['a','an','the','in','for','at','of','with','before','after','on','upon','near','to','is','are','am']
28
+ bad_endings += ['the']
29
+
30
+
31
+ def count_bad(sen):
32
+ sen = sen.split(' ')
33
+ if sen[-1] in bad_endings:
34
+ return 1
35
+ else:
36
+ return 0
37
+
38
+
39
+ def getCOCO(dataset):
40
+ if 'coco' in dataset:
41
+ annFile = 'coco-caption/annotations/captions_val2014.json'
42
+ elif 'flickr30k' in dataset or 'f30k' in dataset:
43
+ annFile = 'data/f30k_captions4eval.json'
44
+ return COCO(annFile)
45
+
46
+
47
+ def language_eval(dataset, preds, preds_n, eval_kwargs, split):
48
+ model_id = eval_kwargs['id']
49
+ eval_oracle = eval_kwargs.get('eval_oracle', 0)
50
+
51
+ # create output dictionary
52
+ out = {}
53
+
54
+ if len(preds_n) > 0:
55
+ # vocab size and novel sentences
56
+ if 'coco' in dataset:
57
+ dataset_file = 'data/dataset_coco.json'
58
+ elif 'flickr30k' in dataset or 'f30k' in dataset:
59
+ dataset_file = 'data/dataset_flickr30k.json'
60
+ training_sentences = set([' '.join(__['tokens']) for _ in json.load(open(dataset_file))['images'] if not _['split'] in ['val', 'test'] for __ in _['sentences']])
61
+ generated_sentences = set([_['caption'] for _ in preds_n])
62
+ novels = generated_sentences - training_sentences
63
+ out['novel_sentences'] = float(len(novels)) / len(preds_n)
64
+ tmp = [_.split() for _ in generated_sentences]
65
+ words = []
66
+ for _ in tmp:
67
+ words += _
68
+ out['vocab_size'] = len(set(words))
69
+
70
+ # encoder.FLOAT_REPR = lambda o: format(o, '.3f')
71
+
72
+ cache_path = os.path.join('eval_results/', '.cache_'+ model_id + '_' + split + '.json')
73
+
74
+ coco = getCOCO(dataset)
75
+ valids = coco.getImgIds()
76
+
77
+ # filter results to only those in MSCOCO validation set
78
+ preds_filt = [p for p in preds if p['image_id'] in valids]
79
+ mean_perplexity = sum([_['perplexity'] for _ in preds_filt]) / len(preds_filt)
80
+ mean_entropy = sum([_['entropy'] for _ in preds_filt]) / len(preds_filt)
81
+ print('using %d/%d predictions' % (len(preds_filt), len(preds)))
82
+ json.dump(preds_filt, open(cache_path, 'w')) # serialize to temporary json file. Sigh, COCO API...
83
+
84
+ cocoRes = coco.loadRes(cache_path)
85
+ cocoEval = COCOEvalCap(coco, cocoRes)
86
+ cocoEval.params['image_id'] = cocoRes.getImgIds()
87
+ cocoEval.evaluate()
88
+
89
+ for metric, score in cocoEval.eval.items():
90
+ out[metric] = score
91
+ # Add mean perplexity
92
+ out['perplexity'] = mean_perplexity
93
+ out['entropy'] = mean_entropy
94
+
95
+ imgToEval = cocoEval.imgToEval
96
+ for k in list(imgToEval.values())[0]['SPICE'].keys():
97
+ if k != 'All':
98
+ out['SPICE_'+k] = np.array([v['SPICE'][k]['f'] for v in imgToEval.values()])
99
+ out['SPICE_'+k] = (out['SPICE_'+k][out['SPICE_'+k]==out['SPICE_'+k]]).mean()
100
+ for p in preds_filt:
101
+ image_id, caption = p['image_id'], p['caption']
102
+ imgToEval[image_id]['caption'] = caption
103
+
104
+ if len(preds_n) > 0:
105
+ from . import eval_multi
106
+ cache_path_n = os.path.join('eval_results/', '.cache_'+ model_id + '_' + split + '_n.json')
107
+ allspice = eval_multi.eval_allspice(dataset, preds_n, model_id, split)
108
+ out.update(allspice['overall'])
109
+ div_stats = eval_multi.eval_div_stats(dataset, preds_n, model_id, split)
110
+ out.update(div_stats['overall'])
111
+ if eval_oracle:
112
+ oracle = eval_multi.eval_oracle(dataset, preds_n, model_id, split)
113
+ out.update(oracle['overall'])
114
+ else:
115
+ oracle = None
116
+ self_cider = eval_multi.eval_self_cider(dataset, preds_n, model_id, split)
117
+ out.update(self_cider['overall'])
118
+ with open(cache_path_n, 'w') as outfile:
119
+ json.dump({'allspice': allspice, 'div_stats': div_stats, 'oracle': oracle, 'self_cider': self_cider}, outfile)
120
+
121
+ out['bad_count_rate'] = sum([count_bad(_['caption']) for _ in preds_filt]) / float(len(preds_filt))
122
+ outfile_path = os.path.join('eval_results/', model_id + '_' + split + '.json')
123
+ with open(outfile_path, 'w') as outfile:
124
+ json.dump({'overall': out, 'imgToEval': imgToEval}, outfile)
125
+
126
+ return out
127
+
128
+ def eval_split(model, crit, loader, eval_kwargs={}):
129
+ verbose = eval_kwargs.get('verbose', True)
130
+ verbose_beam = eval_kwargs.get('verbose_beam', 0)
131
+ verbose_loss = eval_kwargs.get('verbose_loss', 1)
132
+ num_images = eval_kwargs.get('num_images', eval_kwargs.get('val_images_use', -1))
133
+ split = eval_kwargs.get('split', 'val')
134
+ lang_eval = eval_kwargs.get('language_eval', 0)
135
+ dataset = eval_kwargs.get('dataset', 'coco')
136
+ beam_size = eval_kwargs.get('beam_size', 1)
137
+ sample_n = eval_kwargs.get('sample_n', 1)
138
+ remove_bad_endings = eval_kwargs.get('remove_bad_endings', 0)
139
+ os.environ["REMOVE_BAD_ENDINGS"] = str(remove_bad_endings) # Use this nasty way to make other code clean since it's a global configuration
140
+ device = eval_kwargs.get('device', 'cuda')
141
+
142
+ # Make sure in the evaluation mode
143
+ model.eval()
144
+
145
+ loader.reset_iterator(split)
146
+
147
+ n = 0
148
+ loss = 0
149
+ loss_sum = 0
150
+ loss_evals = 1e-8
151
+ predictions = []
152
+ n_predictions = [] # when sample_n > 1
153
+ while True:
154
+ data = loader.get_batch(split)
155
+ n = n + len(data['infos'])
156
+
157
+ tmp = [data['fc_feats'], data['att_feats'], data['labels'], data['masks'], data['att_masks']]
158
+ tmp = [_.to(device) if _ is not None else _ for _ in tmp]
159
+ fc_feats, att_feats, labels, masks, att_masks = tmp
160
+ if labels is not None and verbose_loss:
161
+ # forward the model to get loss
162
+ with torch.no_grad():
163
+ loss = crit(model(fc_feats, att_feats, labels[..., :-1], att_masks), labels[..., 1:], masks[..., 1:]).item()
164
+ loss_sum = loss_sum + loss
165
+ loss_evals = loss_evals + 1
166
+
167
+ # forward the model to also get generated samples for each image
168
+ with torch.no_grad():
169
+ tmp_eval_kwargs = eval_kwargs.copy()
170
+ tmp_eval_kwargs.update({'sample_n': 1})
171
+ seq, seq_logprobs = model(fc_feats, att_feats, att_masks, opt=tmp_eval_kwargs, mode='sample')
172
+ seq = seq.data
173
+ entropy = - (F.softmax(seq_logprobs, dim=2) * seq_logprobs).sum(2).sum(1) / ((seq>0).to(seq_logprobs).sum(1)+1)
174
+ perplexity = - seq_logprobs.gather(2, seq.unsqueeze(2)).squeeze(2).sum(1) / ((seq>0).to(seq_logprobs).sum(1)+1)
175
+
176
+ # Print beam search
177
+ if beam_size > 1 and verbose_beam:
178
+ for i in range(fc_feats.shape[0]):
179
+ print('\n'.join([utils.decode_sequence(model.vocab, _['seq'].unsqueeze(0))[0] for _ in model.done_beams[i]]))
180
+ print('--' * 10)
181
+ sents = utils.decode_sequence(model.vocab, seq)
182
+
183
+ for k, sent in enumerate(sents):
184
+ entry = {'image_id': data['infos'][k]['id'], 'caption': sent, 'perplexity': perplexity[k].item(), 'entropy': entropy[k].item()}
185
+ if eval_kwargs.get('dump_path', 0) == 1:
186
+ entry['file_name'] = data['infos'][k]['file_path']
187
+ predictions.append(entry)
188
+ if eval_kwargs.get('dump_images', 0) == 1:
189
+ # dump the raw image to vis/ folder
190
+ cmd = 'cp "' + os.path.join(eval_kwargs['image_root'], data['infos'][k]['file_path']) + '" vis/imgs/img' + str(len(predictions)) + '.jpg' # bit gross
191
+ print(cmd)
192
+ os.system(cmd)
193
+
194
+ if verbose:
195
+ print('image %s: %s' %(entry['image_id'], entry['caption']))
196
+
197
+ if sample_n > 1:
198
+ eval_split_n(model, n_predictions, [fc_feats, att_feats, att_masks, data], eval_kwargs)
199
+
200
+ # ix0 = data['bounds']['it_pos_now']
201
+ ix1 = data['bounds']['it_max']
202
+ if num_images != -1:
203
+ ix1 = min(ix1, num_images)
204
+ else:
205
+ num_images = ix1
206
+ for i in range(n - ix1):
207
+ predictions.pop()
208
+
209
+ if verbose:
210
+ print('evaluating validation preformance... %d/%d (%f)' %(n, ix1, loss))
211
+
212
+ if num_images >= 0 and n >= num_images:
213
+ break
214
+
215
+ lang_stats = None
216
+ if len(n_predictions) > 0 and 'perplexity' in n_predictions[0]:
217
+ n_predictions = sorted(n_predictions, key=lambda x: x['perplexity'])
218
+ if not os.path.isdir('eval_results'):
219
+ os.mkdir('eval_results')
220
+ torch.save((predictions, n_predictions), os.path.join('eval_results/', '.saved_pred_'+ eval_kwargs['id'] + '_' + split + '.pth'))
221
+ if lang_eval == 1:
222
+ lang_stats = language_eval(dataset, predictions, n_predictions, eval_kwargs, split)
223
+
224
+ # Switch back to training mode
225
+ model.train()
226
+ return loss_sum/loss_evals, predictions, lang_stats
227
+
228
+
229
+ # Only run when sample_n > 0
230
+ def eval_split_n(model, n_predictions, input_data, eval_kwargs={}):
231
+ verbose = eval_kwargs.get('verbose', True)
232
+ beam_size = eval_kwargs.get('beam_size', 1)
233
+ sample_n = eval_kwargs.get('sample_n', 1)
234
+ sample_n_method = eval_kwargs.get('sample_n_method', 'sample')
235
+
236
+ fc_feats, att_feats, att_masks, data = input_data
237
+
238
+ tmp_eval_kwargs = eval_kwargs.copy()
239
+ if sample_n_method == 'bs':
240
+ # case 1 sample_n == beam size
241
+ tmp_eval_kwargs.update({'sample_n': 1, 'beam_size': sample_n, 'group_size': 1}) # randomness from softmax
242
+ with torch.no_grad():
243
+ model(fc_feats, att_feats, att_masks, opt=tmp_eval_kwargs, mode='sample')
244
+ for k in range(fc_feats.shape[0]):
245
+ _sents = utils.decode_sequence(model.vocab, torch.stack([model.done_beams[k][_]['seq'] for _ in range(sample_n)]))
246
+ for sent in _sents:
247
+ entry = {'image_id': data['infos'][k]['id'], 'caption': sent}
248
+ n_predictions.append(entry)
249
+ # case 2 sample / gumbel / topk sampling/ nucleus sampling
250
+ elif sample_n_method == 'sample' or \
251
+ sample_n_method == 'gumbel' or \
252
+ sample_n_method.startswith('top'):
253
+ tmp_eval_kwargs.update({'sample_n': sample_n, 'sample_method': sample_n_method, 'beam_size': 1}) # randomness from sample
254
+ with torch.no_grad():
255
+ _seq, _sampleLogprobs = model(fc_feats, att_feats, att_masks, opt=tmp_eval_kwargs, mode='sample')
256
+ _sents = utils.decode_sequence(model.vocab, _seq)
257
+ _perplexity = - _sampleLogprobs.gather(2, _seq.unsqueeze(2)).squeeze(2).sum(1) / ((_seq>0).to(_sampleLogprobs).sum(1)+1)
258
+ for k, sent in enumerate(_sents):
259
+ entry = {'image_id': data['infos'][k // sample_n]['id'], 'caption': sent, 'perplexity': _perplexity[k].item()}
260
+ n_predictions.append(entry)
261
+ elif sample_n_method == 'dbs':
262
+ # Use diverse beam search
263
+ tmp_eval_kwargs.update({'beam_size': sample_n * beam_size, 'group_size': sample_n}) # randomness from softmax
264
+ with torch.no_grad():
265
+ model(fc_feats, att_feats, att_masks, opt=tmp_eval_kwargs, mode='sample')
266
+ for k in range(loader.batch_size):
267
+ _sents = utils.decode_sequence(model.vocab, torch.stack([model.done_beams[k][_]['seq'] for _ in range(0, sample_n*beam_size, beam_size)]))
268
+ for sent in _sents:
269
+ entry = {'image_id': data['infos'][k]['id'], 'caption': sent}
270
+ n_predictions.append(entry)
271
+ else:
272
+ tmp_eval_kwargs.update({'sample_method': sample_n_method[1:], 'group_size': sample_n, 'beam_size':1}) # randomness from softmax
273
+ with torch.no_grad():
274
+ _seq, _sampleLogprobs = model(fc_feats, att_feats, att_masks, opt=tmp_eval_kwargs, mode='sample')
275
+ _sents = utils.decode_sequence(model.vocab, _seq)
276
+ for k, sent in enumerate(_sents):
277
+ entry = {'image_id': data['infos'][k // sample_n]['id'], 'caption': sent}
278
+ n_predictions.append(entry)
279
+ if verbose:
280
+ for entry in sorted(n_predictions[-fc_feats.shape[0] * sample_n:], key=lambda x: x['image_id']):
281
+ print('image %s: %s' %(entry['image_id'], entry['caption']))
captioning/utils/misc.py ADDED
@@ -0,0 +1,251 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import absolute_import
2
+ from __future__ import division
3
+ from __future__ import print_function
4
+
5
+ import collections
6
+ import torch
7
+ import torch.nn as nn
8
+ import numpy as np
9
+ import torch.optim as optim
10
+ import os
11
+
12
+ import torch.nn.functional as F
13
+
14
+ import six
15
+ from six.moves import cPickle
16
+
17
+ bad_endings = ['with','in','on','of','a','at','to','for','an','this','his','her','that']
18
+ bad_endings += ['the']
19
+
20
+
21
+ def pickle_load(f):
22
+ """ Load a pickle.
23
+ Parameters
24
+ ----------
25
+ f: file-like object
26
+ """
27
+ if six.PY3:
28
+ return cPickle.load(f, encoding='latin-1')
29
+ else:
30
+ return cPickle.load(f)
31
+
32
+
33
+ def pickle_dump(obj, f):
34
+ """ Dump a pickle.
35
+ Parameters
36
+ ----------
37
+ obj: pickled object
38
+ f: file-like object
39
+ """
40
+ if six.PY3:
41
+ return cPickle.dump(obj, f, protocol=2)
42
+ else:
43
+ return cPickle.dump(obj, f)
44
+
45
+
46
+ # modified from https://github.com/facebookresearch/detectron2/blob/master/detectron2/utils/comm.py
47
+ def serialize_to_tensor(data):
48
+ device = torch.device("cpu")
49
+
50
+ buffer = cPickle.dumps(data)
51
+ storage = torch.ByteStorage.from_buffer(buffer)
52
+ tensor = torch.ByteTensor(storage).to(device=device)
53
+ return tensor
54
+
55
+
56
+ def deserialize(tensor):
57
+ buffer = tensor.cpu().numpy().tobytes()
58
+ return cPickle.loads(buffer)
59
+
60
+
61
+ # Input: seq, N*D numpy array, with element 0 .. vocab_size. 0 is END token.
62
+ def decode_sequence(ix_to_word, seq):
63
+ # N, D = seq.size()
64
+ N, D = seq.shape
65
+ out = []
66
+ for i in range(N):
67
+ txt = ''
68
+ for j in range(D):
69
+ ix = seq[i,j]
70
+ if ix > 0 :
71
+ if j >= 1:
72
+ txt = txt + ' '
73
+ txt = txt + ix_to_word[str(ix.item())]
74
+ else:
75
+ break
76
+ if int(os.getenv('REMOVE_BAD_ENDINGS', '0')):
77
+ flag = 0
78
+ words = txt.split(' ')
79
+ for j in range(len(words)):
80
+ if words[-j-1] not in bad_endings:
81
+ flag = -j
82
+ break
83
+ txt = ' '.join(words[0:len(words)+flag])
84
+ out.append(txt.replace('@@ ', ''))
85
+ return out
86
+
87
+
88
+ def save_checkpoint(opt, model, infos, optimizer, histories=None, append=''):
89
+ if len(append) > 0:
90
+ append = '-' + append
91
+ # if checkpoint_path doesn't exist
92
+ if not os.path.isdir(opt.checkpoint_path):
93
+ os.makedirs(opt.checkpoint_path)
94
+ checkpoint_path = os.path.join(opt.checkpoint_path, 'model%s.pth' %(append))
95
+ torch.save(model.state_dict(), checkpoint_path)
96
+ print("model saved to {}".format(checkpoint_path))
97
+ optimizer_path = os.path.join(opt.checkpoint_path, 'optimizer%s.pth' %(append))
98
+ torch.save(optimizer.state_dict(), optimizer_path)
99
+ with open(os.path.join(opt.checkpoint_path, 'infos_'+opt.id+'%s.pkl' %(append)), 'wb') as f:
100
+ pickle_dump(infos, f)
101
+ if histories:
102
+ with open(os.path.join(opt.checkpoint_path, 'histories_'+opt.id+'%s.pkl' %(append)), 'wb') as f:
103
+ pickle_dump(histories, f)
104
+
105
+
106
+ def set_lr(optimizer, lr):
107
+ for group in optimizer.param_groups:
108
+ group['lr'] = lr
109
+
110
+ def get_lr(optimizer):
111
+ for group in optimizer.param_groups:
112
+ return group['lr']
113
+
114
+
115
+ def build_optimizer(params, opt):
116
+ if opt.optim == 'rmsprop':
117
+ return optim.RMSprop(params, opt.learning_rate, opt.optim_alpha, opt.optim_epsilon, weight_decay=opt.weight_decay)
118
+ elif opt.optim == 'adagrad':
119
+ return optim.Adagrad(params, opt.learning_rate, weight_decay=opt.weight_decay)
120
+ elif opt.optim == 'sgd':
121
+ return optim.SGD(params, opt.learning_rate, weight_decay=opt.weight_decay)
122
+ elif opt.optim == 'sgdm':
123
+ return optim.SGD(params, opt.learning_rate, opt.optim_alpha, weight_decay=opt.weight_decay)
124
+ elif opt.optim == 'sgdmom':
125
+ return optim.SGD(params, opt.learning_rate, opt.optim_alpha, weight_decay=opt.weight_decay, nesterov=True)
126
+ elif opt.optim == 'adam':
127
+ return optim.Adam(params, opt.learning_rate, (opt.optim_alpha, opt.optim_beta), opt.optim_epsilon, weight_decay=opt.weight_decay)
128
+ elif opt.optim == 'adamw':
129
+ return optim.AdamW(params, opt.learning_rate, (opt.optim_alpha, opt.optim_beta), opt.optim_epsilon, weight_decay=opt.weight_decay)
130
+ else:
131
+ raise Exception("bad option opt.optim: {}".format(opt.optim))
132
+
133
+
134
+ def penalty_builder(penalty_config):
135
+ if penalty_config == '':
136
+ return lambda x,y: y
137
+ pen_type, alpha = penalty_config.split('_')
138
+ alpha = float(alpha)
139
+ if pen_type == 'wu':
140
+ return lambda x,y: length_wu(x,y,alpha)
141
+ if pen_type == 'avg':
142
+ return lambda x,y: length_average(x,y,alpha)
143
+
144
+ def length_wu(length, logprobs, alpha=0.):
145
+ """
146
+ NMT length re-ranking score from
147
+ "Google's Neural Machine Translation System" :cite:`wu2016google`.
148
+ """
149
+
150
+ modifier = (((5 + length) ** alpha) /
151
+ ((5 + 1) ** alpha))
152
+ return (logprobs / modifier)
153
+
154
+ def length_average(length, logprobs, alpha=0.):
155
+ """
156
+ Returns the average probability of tokens in a sequence.
157
+ """
158
+ return logprobs / length
159
+
160
+
161
+ class NoamOpt(object):
162
+ "Optim wrapper that implements rate."
163
+ def __init__(self, model_size, factor, warmup, optimizer):
164
+ self.optimizer = optimizer
165
+ self._step = 0
166
+ self.warmup = warmup
167
+ self.factor = factor
168
+ self.model_size = model_size
169
+ self._rate = 0
170
+
171
+ def step(self):
172
+ "Update parameters and rate"
173
+ self._step += 1
174
+ rate = self.rate()
175
+ for p in self.optimizer.param_groups:
176
+ p['lr'] = rate
177
+ self._rate = rate
178
+ self.optimizer.step()
179
+
180
+ def rate(self, step = None):
181
+ "Implement `lrate` above"
182
+ if step is None:
183
+ step = self._step
184
+ return self.factor * \
185
+ (self.model_size ** (-0.5) *
186
+ min(step ** (-0.5), step * self.warmup ** (-1.5)))
187
+
188
+ def __getattr__(self, name):
189
+ return getattr(self.optimizer, name)
190
+
191
+ def state_dict(self):
192
+ state_dict = self.optimizer.state_dict()
193
+ state_dict['_step'] = self._step
194
+ return state_dict
195
+
196
+ def load_state_dict(self, state_dict):
197
+ if '_step' in state_dict:
198
+ self._step = state_dict['_step']
199
+ del state_dict['_step']
200
+ self.optimizer.load_state_dict(state_dict)
201
+
202
+ class ReduceLROnPlateau(object):
203
+ "Optim wrapper that implements rate."
204
+ def __init__(self, optimizer, mode='min', factor=0.1, patience=10, verbose=False, threshold=0.0001, threshold_mode='rel', cooldown=0, min_lr=0, eps=1e-08):
205
+ self.scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode, factor, patience, verbose, threshold, threshold_mode, cooldown, min_lr, eps)
206
+ self.optimizer = optimizer
207
+ self.current_lr = get_lr(optimizer)
208
+
209
+ def step(self):
210
+ "Update parameters and rate"
211
+ self.optimizer.step()
212
+
213
+ def scheduler_step(self, val):
214
+ self.scheduler.step(val)
215
+ self.current_lr = get_lr(self.optimizer)
216
+
217
+ def state_dict(self):
218
+ return {'current_lr':self.current_lr,
219
+ 'scheduler_state_dict': self.scheduler.state_dict(),
220
+ 'optimizer_state_dict': self.optimizer.state_dict()}
221
+
222
+ def load_state_dict(self, state_dict):
223
+ if 'current_lr' not in state_dict:
224
+ # it's normal optimizer
225
+ self.optimizer.load_state_dict(state_dict)
226
+ set_lr(self.optimizer, self.current_lr) # use the lr fromt the option
227
+ else:
228
+ # it's a schduler
229
+ self.current_lr = state_dict['current_lr']
230
+ self.scheduler.load_state_dict(state_dict['scheduler_state_dict'])
231
+ self.optimizer.load_state_dict(state_dict['optimizer_state_dict'])
232
+ # current_lr is actually useless in this case
233
+
234
+ def rate(self, step = None):
235
+ "Implement `lrate` above"
236
+ if step is None:
237
+ step = self._step
238
+ return self.factor * \
239
+ (self.model_size ** (-0.5) *
240
+ min(step ** (-0.5), step * self.warmup ** (-1.5)))
241
+
242
+ def __getattr__(self, name):
243
+ return getattr(self.optimizer, name)
244
+
245
+ def get_std_opt(model, optim_func='adam', factor=1, warmup=2000):
246
+ # return NoamOpt(model.tgt_embed[0].d_model, 2, 4000,
247
+ # torch.optim.Adam(model.parameters(), lr=0, betas=(0.9, 0.98), eps=1e-9))
248
+ optim_func = dict(adam=torch.optim.Adam,
249
+ adamw=torch.optim.AdamW)[optim_func]
250
+ return NoamOpt(model.d_model, factor, warmup,
251
+ optim_func(model.parameters(), lr=0, betas=(0.9, 0.98), eps=1e-9))
captioning/utils/opts.py ADDED
@@ -0,0 +1,412 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import print_function
2
+ import argparse
3
+
4
+
5
+ def if_use_feat(caption_model):
6
+ # Decide if load attention feature according to caption model
7
+ if caption_model in ['show_tell', 'all_img', 'fc', 'newfc']:
8
+ use_att, use_fc = False, True
9
+ elif caption_model == 'language_model':
10
+ use_att, use_fc = False, False
11
+ elif caption_model in ['updown', 'topdown']:
12
+ use_fc, use_att = True, True
13
+ else:
14
+ use_att, use_fc = True, False
15
+ return use_fc, use_att
16
+
17
+ import pprint
18
+ class Config(object):
19
+ def __init__(self, **kwargs):
20
+ """Configuration Class: set kwargs as class attributes with setattr"""
21
+ for k, v in kwargs.items():
22
+ setattr(self, k, v)
23
+
24
+ @property
25
+ def config_str(self):
26
+ return pprint.pformat(self.__dict__)
27
+
28
+ def __repr__(self):
29
+ """Pretty-print configurations in alphabetical order"""
30
+ config_str = 'Configurations\n'
31
+ config_str += self.config_str
32
+ return config_str
33
+
34
+
35
+ def parse_opt(parse=True, **optional_kwargs):
36
+ parser = argparse.ArgumentParser()
37
+ # Data input settings
38
+ parser.add_argument('--input_json', type=str, default='data/coco.json',
39
+ help='path to the json file containing additional info and vocab')
40
+ parser.add_argument('--input_fc_dir', type=str, default='data/cocotalk_fc',
41
+ help='path to the directory containing the preprocessed fc feats')
42
+ parser.add_argument('--input_att_dir', type=str, default='data/cocotalk_att',
43
+ help='path to the directory containing the preprocessed att feats')
44
+ parser.add_argument('--input_box_dir', type=str, default='data/cocotalk_box',
45
+ help='path to the directory containing the boxes of att feats')
46
+ parser.add_argument('--input_label_h5', type=str, default='data/coco_label.h5',
47
+ help='path to the h5file containing the preprocessed dataset')
48
+ parser.add_argument('--data_in_memory', action='store_true',
49
+ help='True if we want to save the features in memory')
50
+ parser.add_argument('--start_from', type=str, default=None,
51
+ help="""continue training from saved model at this path. Path must contain files saved by previous training process:
52
+ 'infos.pkl' : configuration;
53
+ 'model.pth' : weights
54
+ """)
55
+ parser.add_argument('--cached_tokens', type=str, default='coco-train-idxs',
56
+ help='Cached token file for calculating cider score during self critical training.')
57
+
58
+ # Model settings
59
+ parser.add_argument('--caption_model', type=str, default="show_tell",
60
+ help='show_tell, show_attend_tell, all_img, fc, att2in, att2in2, att2all2, adaatt, adaattmo, updown, stackatt, denseatt, transformer')
61
+ parser.add_argument('--rnn_size', type=int, default=512,
62
+ help='size of the rnn in number of hidden nodes in each layer')
63
+ parser.add_argument('--num_layers', type=int, default=1,
64
+ help='number of layers in the RNN')
65
+ parser.add_argument('--rnn_type', type=str, default='lstm',
66
+ help='rnn, gru, or lstm')
67
+ parser.add_argument('--input_encoding_size', type=int, default=512,
68
+ help='the encoding size of each token in the vocabulary, and the image.')
69
+ parser.add_argument('--att_hid_size', type=int, default=512,
70
+ help='the hidden size of the attention MLP; only useful in show_attend_tell; 0 if not using hidden layer')
71
+ parser.add_argument('--fc_feat_size', type=int, default=2048,
72
+ help='2048 for resnet, 4096 for vgg')
73
+ parser.add_argument('--att_feat_size', type=int, default=2048,
74
+ help='2048 for resnet, 512 for vgg')
75
+ parser.add_argument('--logit_layers', type=int, default=1,
76
+ help='number of layers in the RNN')
77
+
78
+
79
+ parser.add_argument('--use_bn', type=int, default=0,
80
+ help='If 1, then do batch_normalization first in att_embed, if 2 then do bn both in the beginning and the end of att_embed')
81
+
82
+ # feature manipulation
83
+ parser.add_argument('--norm_att_feat', type=int, default=0,
84
+ help='If normalize attention features')
85
+ parser.add_argument('--use_box', type=int, default=0,
86
+ help='If use box features')
87
+ parser.add_argument('--norm_box_feat', type=int, default=0,
88
+ help='If use box, do we normalize box feature')
89
+
90
+ # Optimization: General
91
+ parser.add_argument('--max_epochs', type=int, default=-1,
92
+ help='number of epochs')
93
+ parser.add_argument('--batch_size', type=int, default=16,
94
+ help='minibatch size')
95
+ parser.add_argument('--grad_clip_mode', type=str, default='value',
96
+ help='value or norm')
97
+ parser.add_argument('--grad_clip_value', type=float, default=0.1,
98
+ help='clip gradients at this value/max_norm, 0 means no clipping')
99
+ parser.add_argument('--drop_prob_lm', type=float, default=0.5,
100
+ help='strength of dropout in the Language Model RNN')
101
+ parser.add_argument('--self_critical_after', type=int, default=-1,
102
+ help='After what epoch do we start finetuning the CNN? (-1 = disable; never finetune, 0 = finetune from start)')
103
+ parser.add_argument('--seq_per_img', type=int, default=5,
104
+ help='number of captions to sample for each image during training. Done for efficiency since CNN forward pass is expensive. E.g. coco has 5 sents/image')
105
+
106
+ parser.add_argument('--verbose', type=int, default=0)
107
+
108
+ # Sample related
109
+ add_eval_sample_opts(parser)
110
+
111
+ #Optimization: for the Language Model
112
+ parser.add_argument('--optim', type=str, default='adam',
113
+ help='what update to use? rmsprop|sgd|sgdmom|adagrad|adam|adamw')
114
+ parser.add_argument('--learning_rate', type=float, default=4e-4,
115
+ help='learning rate')
116
+ parser.add_argument('--learning_rate_decay_start', type=int, default=-1,
117
+ help='at what iteration to start decaying learning rate? (-1 = dont) (in epoch)')
118
+ parser.add_argument('--learning_rate_decay_every', type=int, default=3,
119
+ help='every how many iterations thereafter to drop LR?(in epoch)')
120
+ parser.add_argument('--learning_rate_decay_rate', type=float, default=0.8,
121
+ help='every how many iterations thereafter to drop LR?(in epoch)')
122
+ parser.add_argument('--optim_alpha', type=float, default=0.9,
123
+ help='alpha for adam')
124
+ parser.add_argument('--optim_beta', type=float, default=0.999,
125
+ help='beta used for adam')
126
+ parser.add_argument('--optim_epsilon', type=float, default=1e-8,
127
+ help='epsilon that goes into denominator for smoothing')
128
+ parser.add_argument('--weight_decay', type=float, default=0,
129
+ help='weight_decay')
130
+ # Transformer
131
+ parser.add_argument('--label_smoothing', type=float, default=0,
132
+ help='')
133
+ parser.add_argument('--noamopt', action='store_true',
134
+ help='')
135
+ parser.add_argument('--noamopt_warmup', type=int, default=2000,
136
+ help='')
137
+ parser.add_argument('--noamopt_factor', type=float, default=1,
138
+ help='')
139
+ parser.add_argument('--reduce_on_plateau', action='store_true',
140
+ help='')
141
+ parser.add_argument('--reduce_on_plateau_factor', type=float, default=0.5,
142
+ help='')
143
+ parser.add_argument('--reduce_on_plateau_patience', type=int, default=3,
144
+ help='')
145
+ parser.add_argument('--cached_transformer', action='store_true',
146
+ help='')
147
+
148
+
149
+ parser.add_argument('--use_warmup', action='store_true',
150
+ help='warm up the learing rate?')
151
+
152
+ parser.add_argument('--scheduled_sampling_start', type=int, default=-1,
153
+ help='at what iteration to start decay gt probability')
154
+ parser.add_argument('--scheduled_sampling_increase_every', type=int, default=5,
155
+ help='every how many iterations thereafter to gt probability')
156
+ parser.add_argument('--scheduled_sampling_increase_prob', type=float, default=0.05,
157
+ help='How much to update the prob')
158
+ parser.add_argument('--scheduled_sampling_max_prob', type=float, default=0.25,
159
+ help='Maximum scheduled sampling prob.')
160
+
161
+
162
+ # Evaluation/Checkpointing
163
+ parser.add_argument('--val_images_use', type=int, default=3200,
164
+ help='how many images to use when periodically evaluating the validation loss? (-1 = all)')
165
+ parser.add_argument('--save_checkpoint_every', type=int, default=2500,
166
+ help='how often to save a model checkpoint (in iterations)?')
167
+ parser.add_argument('--save_every_epoch', action='store_true',
168
+ help='Save checkpoint every epoch, will overwrite save_checkpoint_every')
169
+ parser.add_argument('--save_history_ckpt', type=int, default=0,
170
+ help='If save checkpoints at every save point')
171
+ parser.add_argument('--checkpoint_path', type=str, default=None,
172
+ help='directory to store checkpointed models')
173
+ parser.add_argument('--language_eval', type=int, default=0,
174
+ help='Evaluate language as well (1 = yes, 0 = no)? BLEU/CIDEr/METEOR/ROUGE_L? requires coco-caption code from Github.')
175
+ parser.add_argument('--losses_log_every', type=int, default=25,
176
+ help='How often do we snapshot losses, for inclusion in the progress dump? (0 = disable)')
177
+ parser.add_argument('--load_best_score', type=int, default=1,
178
+ help='Do we load previous best score when resuming training.')
179
+
180
+ # misc
181
+ parser.add_argument('--id', type=str, default='',
182
+ help='an id identifying this run/job. used in cross-val and appended when writing progress files')
183
+ parser.add_argument('--train_only', type=int, default=0,
184
+ help='if true then use 80k, else use 110k')
185
+
186
+
187
+ # Reward
188
+ parser.add_argument('--cider_reward_weight', type=float, default=1,
189
+ help='The reward weight from cider')
190
+ parser.add_argument('--bleu_reward_weight', type=float, default=0,
191
+ help='The reward weight from bleu4')
192
+
193
+ # Reward
194
+ parser.add_argument('--clipscore_reward_weight', type=float, default=1,
195
+ help='The reward weight from clipscore')
196
+ parser.add_argument('--use_clipscore', type=float, default=0,
197
+ help='Use CLIPScore')
198
+ parser.add_argument('--clipscore_mode', type=str, default='clip_s',
199
+ help='Which CLIPScore to use: clip_s|refclip_s')
200
+
201
+
202
+ # Structure_loss
203
+ parser.add_argument('--structure_loss_weight', type=float, default=1,
204
+ help='')
205
+ parser.add_argument('--structure_after', type=int, default=-1,
206
+ help='T')
207
+ parser.add_argument('--structure_loss_type', type=str, default='seqnll',
208
+ help='')
209
+ parser.add_argument('--struc_use_logsoftmax', action='store_true', help='')
210
+ parser.add_argument('--entropy_reward_weight', type=float, default=0,
211
+ help='Entropy reward, seems very interesting')
212
+ parser.add_argument('--self_cider_reward_weight', type=float, default=0,
213
+ help='self cider reward')
214
+
215
+ # Used for self critical or structure. Used when sampling is need during training
216
+ parser.add_argument('--train_sample_n', type=int, default=16,
217
+ help='The reward weight from cider')
218
+ parser.add_argument('--train_sample_method', type=str, default='sample',
219
+ help='')
220
+ parser.add_argument('--train_beam_size', type=int, default=1,
221
+ help='')
222
+
223
+ # Used for self critical
224
+ parser.add_argument('--sc_sample_method', type=str, default='greedy',
225
+ help='')
226
+ parser.add_argument('--sc_beam_size', type=int, default=1,
227
+ help='')
228
+
229
+
230
+ # For diversity evaluation during training
231
+ add_diversity_opts(parser)
232
+
233
+
234
+ # config
235
+ parser.add_argument('--cfg', type=str, default=None,
236
+ help='configuration; similar to what is used in detectron')
237
+ parser.add_argument(
238
+ '--set_cfgs', dest='set_cfgs',
239
+ help='Set config keys. Key value sequence seperate by whitespace.'
240
+ 'e.g. [key] [value] [key] [value]\n This has higher priority'
241
+ 'than cfg file but lower than other args. (You can only overwrite'
242
+ 'arguments that have alerady been defined in config file.)',
243
+ default=[], nargs='+')
244
+ # How will config be used
245
+ # 1) read cfg argument, and load the cfg file if it's not None
246
+ # 2) Overwrite cfg argument with set_cfgs
247
+ # 3) parse config argument to args.
248
+ # 4) in the end, parse command line argument and overwrite args
249
+
250
+ # step 1: read cfg_fn
251
+ # args = parser.parse_args()
252
+ # Parse the arguments.
253
+ if parse:
254
+ args = parser.parse_args()
255
+ # For interative engironmnet (ex. jupyter)
256
+ else:
257
+ args = parser.parse_known_args()[0]
258
+ # print(args)
259
+
260
+ # Namespace => Dictionary
261
+ kwargs = vars(args)
262
+ # for k, v in optional_kwargs.items():
263
+ # setattr(args, k, v)
264
+ kwargs.update(optional_kwargs)
265
+
266
+ args = Config(**kwargs)
267
+
268
+
269
+ if args.cfg is not None or args.set_cfgs is not None:
270
+ from .config import CfgNode
271
+ if args.cfg is not None:
272
+ # print('Read Cfg')
273
+ cn = CfgNode(CfgNode.load_yaml_with_base(args.cfg))
274
+ # print(cn)
275
+ else:
276
+ cn = CfgNode()
277
+ if args.set_cfgs is not None:
278
+ cn.merge_from_list(args.set_cfgs)
279
+ for k,v in cn.items():
280
+ if not hasattr(args, k):
281
+ import os
282
+ if 'LOCAL_RANK' in os.environ and os.environ['LOCAL_RANK'] != '0':
283
+ pass
284
+ else:
285
+ print('Warning: key %s not in args' % k)
286
+
287
+ setattr(args, k, v)
288
+
289
+ if parse:
290
+ args = parser.parse_args(namespace=args)
291
+ else:
292
+ args = parser.parse_known_args(namespace=args)[0]
293
+
294
+ # Check if args are valid
295
+ assert args.rnn_size > 0, "rnn_size should be greater than 0"
296
+ assert args.num_layers > 0, "num_layers should be greater than 0"
297
+ assert args.input_encoding_size > 0, "input_encoding_size should be greater than 0"
298
+ assert args.batch_size > 0, "batch_size should be greater than 0"
299
+ assert args.drop_prob_lm >= 0 and args.drop_prob_lm < 1, "drop_prob_lm should be between 0 and 1"
300
+ assert args.seq_per_img > 0, "seq_per_img should be greater than 0"
301
+ assert args.beam_size > 0, "beam_size should be greater than 0"
302
+ assert args.save_checkpoint_every > 0, "save_checkpoint_every should be greater than 0"
303
+ assert args.losses_log_every > 0, "losses_log_every should be greater than 0"
304
+ assert args.language_eval == 0 or args.language_eval == 1, "language_eval should be 0 or 1"
305
+ assert args.load_best_score == 0 or args.load_best_score == 1, "language_eval should be 0 or 1"
306
+ assert args.train_only == 0 or args.train_only == 1, "language_eval should be 0 or 1"
307
+
308
+ # default value for start_from and checkpoint_path
309
+ args.checkpoint_path = args.checkpoint_path or './log_%s' %args.id
310
+ args.start_from = args.start_from or args.checkpoint_path
311
+
312
+ # Deal with feature things before anything
313
+ args.use_fc, args.use_att = if_use_feat(args.caption_model)
314
+ if args.use_box: args.att_feat_size = args.att_feat_size + 5
315
+
316
+ return args
317
+
318
+
319
+ def add_eval_options(parser):
320
+ # Basic options
321
+ parser.add_argument('--batch_size', type=int, default=0,
322
+ help='if > 0 then overrule, otherwise load from checkpoint.')
323
+ parser.add_argument('--num_images', type=int, default=-1,
324
+ help='how many images to use when periodically evaluating the loss? (-1 = all)')
325
+ parser.add_argument('--language_eval', type=int, default=0,
326
+ help='Evaluate language as well (1 = yes, 0 = no)? BLEU/CIDEr/METEOR/ROUGE_L? requires coco-caption code from Github.')
327
+ parser.add_argument('--dump_images', type=int, default=1,
328
+ help='Dump images into vis/imgs folder for vis? (1=yes,0=no)')
329
+ parser.add_argument('--dump_json', type=int, default=1,
330
+ help='Dump json with predictions into vis folder? (1=yes,0=no)')
331
+ parser.add_argument('--dump_path', type=int, default=0,
332
+ help='Write image paths along with predictions into vis json? (1=yes,0=no)')
333
+
334
+ # Sampling options
335
+ add_eval_sample_opts(parser)
336
+
337
+ # For evaluation on a folder of images:
338
+ parser.add_argument('--image_folder', type=str, default='',
339
+ help='If this is nonempty then will predict on the images in this folder path')
340
+ parser.add_argument('--image_root', type=str, default='',
341
+ help='In case the image paths have to be preprended with a root path to an image folder')
342
+ # For evaluation on MSCOCO images from some split:
343
+ parser.add_argument('--input_fc_dir', type=str, default='',
344
+ help='path to the h5file containing the preprocessed dataset')
345
+ parser.add_argument('--input_att_dir', type=str, default='',
346
+ help='path to the h5file containing the preprocessed dataset')
347
+ parser.add_argument('--input_box_dir', type=str, default='',
348
+ help='path to the h5file containing the preprocessed dataset')
349
+ parser.add_argument('--input_label_h5', type=str, default='',
350
+ help='path to the h5file containing the preprocessed dataset')
351
+ parser.add_argument('--input_json', type=str, default='',
352
+ help='path to the json file containing additional info and vocab. empty = fetch from model checkpoint.')
353
+ parser.add_argument('--split', type=str, default='test',
354
+ help='if running on MSCOCO images, which split to use: val|test|train')
355
+ parser.add_argument('--coco_json', type=str, default='',
356
+ help='if nonempty then use this file in DataLoaderRaw (see docs there). Used only in MSCOCO test evaluation, where we have a specific json file of only test set images.')
357
+ # misc
358
+ parser.add_argument('--id', type=str, default='',
359
+ help='an id identifying this run/job. used only if language_eval = 1 for appending to intermediate files')
360
+ parser.add_argument('--verbose_beam', type=int, default=1,
361
+ help='if we need to print out all beam search beams.')
362
+ parser.add_argument('--verbose_loss', type=int, default=0,
363
+ help='If calculate loss using ground truth during evaluation')
364
+
365
+ def add_diversity_opts(parser):
366
+ parser.add_argument('--sample_n', type=int, default=1,
367
+ help='Diverse sampling')
368
+ parser.add_argument('--sample_n_method', type=str, default='sample',
369
+ help='sample, bs, dbs, gumbel, topk, dgreedy, dsample, dtopk, dtopp')
370
+ parser.add_argument('--eval_oracle', type=int, default=1,
371
+ help='if we need to calculate loss.')
372
+
373
+
374
+ # Sampling related options
375
+ def add_eval_sample_opts(parser):
376
+ parser.add_argument('--sample_method', type=str, default='greedy',
377
+ help='greedy; sample; gumbel; top<int>, top<0-1>')
378
+ parser.add_argument('--beam_size', type=int, default=1,
379
+ help='used when sample_method = greedy, indicates number of beams in beam search. Usually 2 or 3 works well. More is not better. Set this to 1 for faster runtime but a bit worse performance.')
380
+ parser.add_argument('--max_length', type=int, default=20,
381
+ help='Maximum length during sampling')
382
+ parser.add_argument('--length_penalty', type=str, default='',
383
+ help='wu_X or avg_X, X is the alpha')
384
+ parser.add_argument('--group_size', type=int, default=1,
385
+ help='used for diverse beam search. if group_size is 1, then it\'s normal beam search')
386
+ parser.add_argument('--diversity_lambda', type=float, default=0.5,
387
+ help='used for diverse beam search. Usually from 0.2 to 0.8. Higher value of lambda produces a more diverse list')
388
+ parser.add_argument('--temperature', type=float, default=1.0,
389
+ help='temperature when sampling from distributions (i.e. when sample_method = sample). Lower = "safer" predictions.')
390
+ parser.add_argument('--decoding_constraint', type=int, default=0,
391
+ help='If 1, not allowing same word in a row')
392
+ parser.add_argument('--block_trigrams', type=int, default=0,
393
+ help='block repeated trigram.')
394
+ parser.add_argument('--remove_bad_endings', type=int, default=0,
395
+ help='Remove bad endings')
396
+ parser.add_argument('--suppress_UNK', type=int, default=1,
397
+ help='Not predicting UNK')
398
+
399
+
400
+ if __name__ == '__main__':
401
+ import sys
402
+ sys.argv = [sys.argv[0]]
403
+ args = parse_opt()
404
+ print(args)
405
+ print()
406
+ sys.argv = [sys.argv[0], '--cfg', 'configs/updown_long.yml']
407
+ args1 = parse_opt()
408
+ print(dict(set(vars(args1).items()) - set(vars(args).items())))
409
+ print()
410
+ sys.argv = [sys.argv[0], '--cfg', 'configs/updown_long.yml', '--caption_model', 'att2in2']
411
+ args2 = parse_opt()
412
+ print(dict(set(vars(args2).items()) - set(vars(args1).items())))
captioning/utils/resnet.py ADDED
@@ -0,0 +1,71 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torchvision.models.resnet
4
+ from torchvision.models.resnet import BasicBlock, Bottleneck
5
+
6
+ class ResNet(torchvision.models.resnet.ResNet):
7
+ def __init__(self, block, layers, num_classes=1000):
8
+ super(ResNet, self).__init__(block, layers, num_classes)
9
+ self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=0, ceil_mode=True) # change
10
+ for i in range(2, 5):
11
+ getattr(self, 'layer%d'%i)[0].conv1.stride = (2,2)
12
+ getattr(self, 'layer%d'%i)[0].conv2.stride = (1,1)
13
+
14
+ def resnet18(pretrained=False):
15
+ """Constructs a ResNet-18 model.
16
+
17
+ Args:
18
+ pretrained (bool): If True, returns a model pre-trained on ImageNet
19
+ """
20
+ model = ResNet(BasicBlock, [2, 2, 2, 2])
21
+ if pretrained:
22
+ model.load_state_dict(model_zoo.load_url(model_urls['resnet18']))
23
+ return model
24
+
25
+
26
+ def resnet34(pretrained=False):
27
+ """Constructs a ResNet-34 model.
28
+
29
+ Args:
30
+ pretrained (bool): If True, returns a model pre-trained on ImageNet
31
+ """
32
+ model = ResNet(BasicBlock, [3, 4, 6, 3])
33
+ if pretrained:
34
+ model.load_state_dict(model_zoo.load_url(model_urls['resnet34']))
35
+ return model
36
+
37
+
38
+ def resnet50(pretrained=False):
39
+ """Constructs a ResNet-50 model.
40
+
41
+ Args:
42
+ pretrained (bool): If True, returns a model pre-trained on ImageNet
43
+ """
44
+ model = ResNet(Bottleneck, [3, 4, 6, 3])
45
+ if pretrained:
46
+ model.load_state_dict(model_zoo.load_url(model_urls['resnet50']))
47
+ return model
48
+
49
+
50
+ def resnet101(pretrained=False):
51
+ """Constructs a ResNet-101 model.
52
+
53
+ Args:
54
+ pretrained (bool): If True, returns a model pre-trained on ImageNet
55
+ """
56
+ model = ResNet(Bottleneck, [3, 4, 23, 3])
57
+ if pretrained:
58
+ model.load_state_dict(model_zoo.load_url(model_urls['resnet101']))
59
+ return model
60
+
61
+
62
+ def resnet152(pretrained=False):
63
+ """Constructs a ResNet-152 model.
64
+
65
+ Args:
66
+ pretrained (bool): If True, returns a model pre-trained on ImageNet
67
+ """
68
+ model = ResNet(Bottleneck, [3, 8, 36, 3])
69
+ if pretrained:
70
+ model.load_state_dict(model_zoo.load_url(model_urls['resnet152']))
71
+ return model
captioning/utils/resnet_utils.py ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+
5
+ class myResnet(nn.Module):
6
+ def __init__(self, resnet):
7
+ super(myResnet, self).__init__()
8
+ self.resnet = resnet
9
+
10
+ def forward(self, img, att_size=14):
11
+ x = img.unsqueeze(0)
12
+
13
+ x = self.resnet.conv1(x)
14
+ x = self.resnet.bn1(x)
15
+ x = self.resnet.relu(x)
16
+ x = self.resnet.maxpool(x)
17
+
18
+ x = self.resnet.layer1(x)
19
+ x = self.resnet.layer2(x)
20
+ x = self.resnet.layer3(x)
21
+ x = self.resnet.layer4(x)
22
+
23
+ fc = x.mean(3).mean(2).squeeze()
24
+ att = F.adaptive_avg_pool2d(x,[att_size,att_size]).squeeze().permute(1, 2, 0)
25
+
26
+ return fc, att
27
+
captioning/utils/rewards.py ADDED
@@ -0,0 +1,392 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import absolute_import
2
+ from __future__ import division
3
+ from __future__ import print_function
4
+
5
+ import numpy as np
6
+ import time
7
+ from collections import OrderedDict
8
+ import torch
9
+
10
+ import sys
11
+ try:
12
+ sys.path.append("cider")
13
+ from pyciderevalcap.ciderD.ciderD import CiderD
14
+ from pyciderevalcap.cider.cider import Cider
15
+ sys.path.append("coco-caption")
16
+ from pycocoevalcap.bleu.bleu import Bleu
17
+ except:
18
+ print('cider or coco-caption missing')
19
+
20
+ CiderD_scorer = None
21
+ Cider_scorer = None
22
+ Bleu_scorer = None
23
+ #CiderD_scorer = CiderD(df='corpus')
24
+
25
+
26
+ from .misc import decode_sequence
27
+
28
+ def init_scorer(cached_tokens):
29
+ global CiderD_scorer
30
+ CiderD_scorer = CiderD_scorer or CiderD(df=cached_tokens)
31
+ global Cider_scorer
32
+ Cider_scorer = Cider_scorer or Cider(df=cached_tokens)
33
+ global Bleu_scorer
34
+ Bleu_scorer = Bleu_scorer or Bleu(4)
35
+
36
+ def array_to_str(arr):
37
+ out = ''
38
+ for i in range(len(arr)):
39
+ out += str(arr[i]) + ' '
40
+ if arr[i] == 0:
41
+ break
42
+ return out.strip()
43
+
44
+ def get_self_critical_reward(greedy_res, data_gts, gen_result, opt):
45
+ batch_size = len(data_gts)
46
+ gen_result_size = gen_result.shape[0]
47
+ seq_per_img = gen_result_size // len(data_gts) # gen_result_size = batch_size * seq_per_img
48
+ assert greedy_res.shape[0] == batch_size
49
+
50
+ res = OrderedDict()
51
+ gen_result = gen_result.data.cpu().numpy()
52
+ greedy_res = greedy_res.data.cpu().numpy()
53
+ for i in range(gen_result_size):
54
+ res[i] = [array_to_str(gen_result[i])]
55
+ for i in range(batch_size):
56
+ res[gen_result_size + i] = [array_to_str(greedy_res[i])]
57
+
58
+ gts = OrderedDict()
59
+ for i in range(len(data_gts)):
60
+ gts[i] = [array_to_str(data_gts[i][j]) for j in range(len(data_gts[i]))]
61
+
62
+ res_ = [{'image_id':i, 'caption': res[i]} for i in range(len(res))]
63
+ res__ = {i: res[i] for i in range(len(res_))}
64
+ gts_ = {i: gts[i // seq_per_img] for i in range(gen_result_size)}
65
+ gts_.update({i+gen_result_size: gts[i] for i in range(batch_size)})
66
+ if opt.cider_reward_weight > 0:
67
+ _, cider_scores = CiderD_scorer.compute_score(gts_, res_)
68
+ if hasattr(opt, 'verbose') and not opt.verbose:
69
+ pass
70
+ else:
71
+ print('Cider scores:', _)
72
+ else:
73
+ cider_scores = 0
74
+ if opt.bleu_reward_weight > 0:
75
+ _, bleu_scores = Bleu_scorer.compute_score(gts_, res__)
76
+ bleu_scores = np.array(bleu_scores[3])
77
+ if hasattr(opt, 'verbose') and not opt.verbose:
78
+ pass
79
+ else:
80
+ print('Bleu scores:', _[3])
81
+ else:
82
+ bleu_scores = 0
83
+ scores = opt.cider_reward_weight * cider_scores + opt.bleu_reward_weight * bleu_scores
84
+
85
+ unnormalized_reward_mean = scores[:gen_result_size].flatten().mean()
86
+
87
+ scores = scores[:gen_result_size].reshape(batch_size, seq_per_img) - scores[-batch_size:][:, np.newaxis]
88
+
89
+ scores = scores.reshape(gen_result_size)
90
+
91
+ rewards = np.repeat(scores[:, np.newaxis], gen_result.shape[1], 1)
92
+
93
+ return rewards, unnormalized_reward_mean
94
+
95
+
96
+ def get_self_critical_clipscore_reward(greedy_res, data_gts, gen_result, opt, clipscore_model, clip_vis_feats, vocab):
97
+ batch_size = len(data_gts)
98
+ gen_result_size = gen_result.shape[0]
99
+ seq_per_img = gen_result_size // len(data_gts) # gen_result_size = batch_size * seq_per_img
100
+ assert greedy_res.shape[0] == batch_size
101
+
102
+ B = batch_size
103
+ K = seq_per_img
104
+ L = gen_result.shape[1]
105
+ assert gen_result.shape == (B*K , L)
106
+
107
+ # res = OrderedDict()
108
+ # gen_result = gen_result.data.cpu().numpy()
109
+ # greedy_res = greedy_res.data.cpu().numpy()
110
+ # for i in range(gen_result_size):
111
+ # res[i] = [array_to_str(gen_result[i])]
112
+ # for i in range(batch_size):
113
+ # res[gen_result_size + i] = [array_to_str(greedy_res[i])]
114
+
115
+ # gts = OrderedDict()
116
+ # for i in range(len(data_gts)):
117
+ # gts[i] = [array_to_str(data_gts[i][j]) for j in range(len(data_gts[i]))]
118
+
119
+ # res_ = [{'image_id':i, 'caption': res[i]} for i in range(len(res))]
120
+ # res__ = {i: res[i] for i in range(len(res_))}
121
+ # gts_ = {i: gts[i // seq_per_img] for i in range(gen_result_size)}
122
+ # gts_.update({i+gen_result_size: gts[i] for i in range(batch_size)})
123
+
124
+ # res = []
125
+ # gen_result = gen_result.data.cpu().numpy()
126
+ # greedy_res = greedy_res.data.cpu().numpy()
127
+ # # for i in range(gen_result_size):
128
+ # # res.append(array_to_str(gen_result[i]))
129
+ # res.extend(decode_sequence(vocab, gen_result))
130
+
131
+
132
+ # # for i in range(batch_size):
133
+ # # res.append(array_to_str(greedy_res[i]))
134
+ # res.extend(decode_sequence(vocab, greedy_res))
135
+
136
+ if clipscore_model.mode == 'refclip_s':
137
+ gts = []
138
+ gts_valid_mask = []
139
+ max_n_refs = max([len(_gts) for _gts in data_gts])
140
+ for i in range(len(data_gts)):
141
+ _gts = decode_sequence(vocab, data_gts[i])
142
+ # pad references
143
+ n_ref = len(_gts)
144
+ _gts.extend([''] * (max_n_refs - n_ref))
145
+ gts.extend(_gts)
146
+ gts_valid_mask.extend([1] * n_ref + [0] * (max_n_refs - n_ref))
147
+ assert len(gts) == B * max_n_refs
148
+ assert len(gts_valid_mask) == B * max_n_refs
149
+
150
+ # print(gts)
151
+ # print(gts_valid_mask)
152
+ # exit()
153
+
154
+
155
+ # assert len(res) == B * K + B, len(res)
156
+
157
+ # print(res)
158
+ # exit()
159
+
160
+ if opt.clipscore_reward_weight > 0:
161
+ with torch.no_grad():
162
+ clipscore_model.eval()
163
+
164
+ # 1) calculate reward
165
+ gen_result = gen_result.data.cpu().numpy()
166
+ res = decode_sequence(vocab, gen_result)
167
+ assert len(res) == B * K, len(res)
168
+
169
+ # [B * K, dim)
170
+ if getattr(opt, 'use_grammar', False) and not getattr(opt, 'joint_out', False):
171
+ text_pre_feat = clipscore_model.text_extract(res, proj_norm=False)
172
+
173
+ grammar_logit = clipscore_model.grammar_score_head(text_pre_feat.view(-1, 512))
174
+ grammar_prob = torch.softmax(grammar_logit, dim=-1)[:, 1]
175
+ grammar_prob = grammar_prob.view(B*K).detach()
176
+
177
+ text_feat = clipscore_model.clip_model.text_projection(text_pre_feat)
178
+ text_feat = text_feat / text_feat.norm(dim=-1, keepdim=True)
179
+
180
+ else:
181
+ text_feat = clipscore_model.text_extract(res)
182
+
183
+
184
+ assert text_feat.size() == (B * K, 512), text_feat.size()
185
+ assert clip_vis_feats.size() == (B, 512), clip_vis_feats.size()
186
+
187
+ # [B * K, dim]
188
+ vis_feat = clip_vis_feats.view(B, 1, -1).expand(-1, K, -1).contiguous().view(B * K, -1)
189
+
190
+ clip_s = clipscore_model(text_feat=text_feat, img_feat=vis_feat, mode='clip_s')
191
+ clip_s = clip_s.view(B * K).detach()
192
+
193
+ if clipscore_model.mode == 'refclip_s':
194
+ # [B * n_ref, dim]
195
+ ref_text_feat = clipscore_model.text_extract(gts)
196
+ ref_text_mask = torch.tensor(gts_valid_mask, dtype=ref_text_feat.dtype, device=ref_text_feat.device)
197
+
198
+ assert ref_text_feat.size() == (B * max_n_refs, 512), ref_text_feat.size()
199
+ assert ref_text_mask.size() == (B * max_n_refs,), ref_text_mask.size()
200
+
201
+ # [B * K]
202
+ refclip_s = clipscore_model.calc_refclip_s(
203
+ text_feat=text_feat, img_feat=vis_feat,
204
+ ref_text_feat=ref_text_feat.view(B, 1, max_n_refs, -1).expand(-1, K, -1, -1).contiguous().view(B * K * max_n_refs, -1),
205
+ ref_text_mask=ref_text_mask.view(B, 1, max_n_refs).expand(-1, K, -1).contiguous().view(B * K * max_n_refs),
206
+ clip_s=clip_s)
207
+ refclip_s = refclip_s.view(B * K).detach()
208
+
209
+ # 2) calcualte reward for baseline (greedy)
210
+ greedy_res = greedy_res.data.cpu().numpy()
211
+ res = decode_sequence(vocab, greedy_res)
212
+ assert len(res) == B, len(res)
213
+
214
+ # [B, dim)
215
+
216
+ if getattr(opt, 'use_grammar', False) and getattr(opt, 'use_grammar_baseline', False) and not getattr(opt, 'joint_out', False):
217
+ text_pre_feat = clipscore_model.text_extract(res, proj_norm=False)
218
+
219
+ grammar_logit = clipscore_model.grammar_score_head(text_pre_feat.view(-1, 512))
220
+ grammar_prob_baseline = torch.softmax(grammar_logit, dim=-1)[:, 1]
221
+ grammar_prob_baseline = grammar_prob_baseline.view(B).detach()
222
+
223
+ text_feat = clipscore_model.clip_model.text_projection(text_pre_feat)
224
+ text_feat = text_feat / text_feat.norm(dim=-1, keepdim=True)
225
+ else:
226
+ text_feat = clipscore_model.text_extract(res)
227
+
228
+ assert text_feat.size() == (B, 512), text_feat.size()
229
+ assert clip_vis_feats.size() == (B, 512), clip_vis_feats.size()
230
+
231
+ vis_feat = clip_vis_feats.view(B, 512)
232
+
233
+ # [B]
234
+ clip_s_baseline = clipscore_model(text_feat=text_feat, img_feat=vis_feat, mode='clip_s')
235
+ clip_s_baseline = clip_s_baseline.view(B).detach()
236
+
237
+ if clipscore_model.mode == 'refclip_s':
238
+ # # [B * n_ref]
239
+ # ref_text_feat = clipscore_model.text_extract(gts)
240
+ # ref_text_mask = torch.tensor(gts_valid_mask, dtype=ref_text_feat.dtype, device=ref_text_feat.device)
241
+ # assert ref_text_feat.size() == (B * max_n_refs, 512), ref_text_feat.size()
242
+ # assert ref_text_mask.size() == (B * max_n_refs), ref_text_mask.size()
243
+
244
+ # [B]
245
+ refclip_s_baseline = clipscore_model.calc_refclip_s(
246
+ text_feat=text_feat, img_feat=vis_feat,
247
+ ref_text_feat=ref_text_feat,
248
+ ref_text_mask=ref_text_mask,
249
+ clip_s=clip_s_baseline)
250
+ refclip_s_baseline = refclip_s_baseline.view(B).detach()
251
+
252
+ if clipscore_model.mode == 'clip_s':
253
+ rewards = clip_s - clip_s_baseline.view(B, 1).expand(-1, K).contiguous().flatten()
254
+ unnormalized_mean_reward = clip_s.mean()
255
+ elif clipscore_model.mode == 'refclip_s':
256
+ rewards = refclip_s - refclip_s_baseline.view(B, 1).expand(-1, K).contiguous().flatten()
257
+ unnormalized_mean_reward = refclip_s.mean()
258
+
259
+ # # [B * K + B, dim)
260
+ # text_feat = clipscore_model.text_extract(res)
261
+ # assert text_feat.size() == (B * K + B, 512), text_feat.size()
262
+
263
+ # assert clip_vis_feats.size() == (B, 512), clip_vis_feats.size()
264
+
265
+ # # [B, dim] -> [B * K + B, dim]
266
+ # # vis_feat = clip_vis_feats.view(B, 1, -1).expand(-1, K + 1, -1).contiguous().view(B * (K + 1), -1)
267
+ # # vis_feat = clip_vis_feats.view(1, B, -1).expand(K + 1, -1, -1).contiguous().view((K + 1) * B, -1)
268
+
269
+ # # [B * K, dim]
270
+ # gen_vis_feat = clip_vis_feats.view(B, 1, -1).expand(-1, K, -1).contiguous().view(B * K, -1)
271
+ # # [B, dim]
272
+ # greedy_vis_feat = clip_vis_feats
273
+ # # [B * K + B, dim]
274
+ # vis_feat = torch.cat([gen_vis_feat, greedy_vis_feat], dim=0)
275
+
276
+ # # if clipscore_model.mode == 'clip_s':
277
+ # # [B * K + B, dim]
278
+ # clip_s = clipscore_model(text_feat=text_feat, img_feat=vis_feat)
279
+ # clip_s = clip_s.view(B * K + B).detach()
280
+
281
+
282
+ # if clipscore_model.mode == 'refclip_s':
283
+ # # [B * K, dim]
284
+ # ref_text_feat = clipscore_model.text_extract(gts)
285
+
286
+ # clipscore_scores = clipscore_model.calc_refclip_s(text_feat=text_feat, img_feat=vis_feat, ref_text_feat=ref_text_feat, clip_s=clip_s)
287
+ # clipscore_scores = clipscore_scores.view(B * K + B).detach()
288
+
289
+ if getattr(opt, 'use_grammar', False) and not getattr(opt, 'joint_out', False):
290
+
291
+ if getattr(opt, 'use_grammar_baseline', False):
292
+ grammar_rewards = grammar_prob - grammar_prob_baseline.view(B, 1).expand(-1, K).contiguous().flatten()
293
+ else:
294
+ grammar_rewards = grammar_prob
295
+ else:
296
+ grammar_rewards = None
297
+
298
+
299
+ if hasattr(opt, 'verbose') and not opt.verbose:
300
+ pass
301
+ else:
302
+ if clipscore_model.mode == 'clip_s':
303
+ print('CLIP-S:', rewards)
304
+ elif clipscore_model.mode == 'refclip_s':
305
+ print('RefCLIP-S:', rewards)
306
+ else:
307
+ rewards = torch.zeros(B, L)
308
+ unnormalized_mean_reward = None
309
+ grammar_rewards = None
310
+
311
+
312
+ rewards = opt.clipscore_reward_weight * rewards
313
+
314
+
315
+ # scores = scores[:gen_result_size].reshape(batch_size, seq_per_img) - scores[-batch_size:][:, np.newaxis]
316
+ # scores = scores.reshape(gen_result_size)
317
+ # rewards = np.repeat(scores[:, np.newaxis], gen_result.shape[1], 1)
318
+
319
+ # [B, K]
320
+ # scores = scores[:gen_result_size].reshape(B, K) - scores[-B:].unsqueeze(1)
321
+
322
+ # [B*K, L]
323
+ # rewards = scores.view(-1, 1).expand(-1, L).contiguous()
324
+ rewards = rewards.view(-1, 1).expand(-1, L).contiguous()
325
+
326
+ if getattr(opt, 'use_grammar', False) and not getattr(opt, 'joint_out', False):
327
+ grammar_rewards = grammar_rewards.view(-1, 1).expand(-1, L).contiguous()
328
+
329
+ return rewards, unnormalized_mean_reward, grammar_rewards
330
+
331
+ def get_scores(data_gts, gen_result, opt):
332
+ batch_size = gen_result.size(0)# batch_size = sample_size * seq_per_img
333
+ seq_per_img = batch_size // len(data_gts)
334
+
335
+ res = OrderedDict()
336
+
337
+ gen_result = gen_result.data.cpu().numpy()
338
+ for i in range(batch_size):
339
+ res[i] = [array_to_str(gen_result[i])]
340
+
341
+ gts = OrderedDict()
342
+ for i in range(len(data_gts)):
343
+ gts[i] = [array_to_str(data_gts[i][j]) for j in range(len(data_gts[i]))]
344
+
345
+ res_ = [{'image_id':i, 'caption': res[i]} for i in range(batch_size)]
346
+ res__ = {i: res[i] for i in range(batch_size)}
347
+ gts = {i: gts[i // seq_per_img] for i in range(batch_size)}
348
+ if opt.cider_reward_weight > 0:
349
+ _, cider_scores = CiderD_scorer.compute_score(gts, res_)
350
+ # print('Cider scores:', _)
351
+ if hasattr(opt, 'verbose') and not opt.verbose:
352
+ pass
353
+ else:
354
+ print('Cider scores:', _)
355
+ else:
356
+ cider_scores = 0
357
+ if opt.bleu_reward_weight > 0:
358
+ _, bleu_scores = Bleu_scorer.compute_score(gts, res__)
359
+ bleu_scores = np.array(bleu_scores[3])
360
+ # print('Bleu scores:', _[3])
361
+ if hasattr(opt, 'verbose') and not opt.verbose:
362
+ pass
363
+ else:
364
+ print('Bleu scores:', _[3])
365
+ else:
366
+ bleu_scores = 0
367
+
368
+ scores = opt.cider_reward_weight * cider_scores + opt.bleu_reward_weight * bleu_scores
369
+
370
+ return scores
371
+
372
+ def get_self_cider_scores(data_gts, gen_result, opt):
373
+ batch_size = gen_result.size(0)# batch_size = sample_size * seq_per_img
374
+ seq_per_img = batch_size // len(data_gts)
375
+
376
+ res = []
377
+
378
+ gen_result = gen_result.data.cpu().numpy()
379
+ for i in range(batch_size):
380
+ res.append(array_to_str(gen_result[i]))
381
+
382
+ scores = []
383
+ for i in range(len(data_gts)):
384
+ tmp = Cider_scorer.my_self_cider([res[i*seq_per_img:(i+1)*seq_per_img]])
385
+ def get_div(eigvals):
386
+ eigvals = np.clip(eigvals, 0, None)
387
+ return -np.log(np.sqrt(eigvals[-1]) / (np.sqrt(eigvals).sum())) / np.log(len(eigvals))
388
+ scores.append(get_div(np.linalg.eigvalsh(tmp[0]/10)))
389
+
390
+ scores = np.array(scores)
391
+
392
+ return scores
captioning/utils/utils.py ADDED
@@ -0,0 +1,138 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import re
2
+ import numpy as np
3
+ import torch
4
+ import torch.distributed as dist
5
+ import collections
6
+ import logging
7
+
8
+ def get_area(pos):
9
+ """
10
+ Args
11
+ pos: [B, N, 4]
12
+ (x1, x2, y1, y2)
13
+
14
+ Return
15
+ area : [B, N]
16
+ """
17
+ # [B, N]
18
+ height = pos[:, :, 3] - pos[:, :, 2]
19
+ width = pos[:, :, 1] - pos[:, :, 0]
20
+ area = height * width
21
+ return area
22
+
23
+ def get_relative_distance(pos):
24
+ """
25
+ Args
26
+ pos: [B, N, 4]
27
+ (x1, x2, y1, y2)
28
+
29
+ Return
30
+ out : [B, N, N, 4]
31
+ """
32
+ # B, N = pos.size()[:-1]
33
+
34
+ # [B, N, N, 4]
35
+ relative_distance = pos.unsqueeze(1) - pos.unsqueeze(2)
36
+
37
+ return relative_distance
38
+
39
+
40
+ class LossMeter(object):
41
+ def __init__(self, maxlen=100):
42
+ """Computes and stores the running average"""
43
+ self.vals = collections.deque([], maxlen=maxlen)
44
+
45
+ def __len__(self):
46
+ return len(self.vals)
47
+
48
+ def update(self, new_val):
49
+ self.vals.append(new_val)
50
+
51
+ @property
52
+ def val(self):
53
+ return sum(self.vals) / len(self.vals)
54
+
55
+ def __repr__(self):
56
+ return str(self.val)
57
+
58
+
59
+ def count_parameters(model):
60
+ return sum(p.numel() for p in model.parameters() if p.requires_grad)
61
+
62
+
63
+ def load_state_dict(state_dict_path, loc='cpu'):
64
+ state_dict = torch.load(state_dict_path, map_location=loc)
65
+ # Change Multi GPU to single GPU
66
+ original_keys = list(state_dict.keys())
67
+ for key in original_keys:
68
+ if key.startswith("module."):
69
+ new_key = key[len("module."):]
70
+ state_dict[new_key] = state_dict.pop(key)
71
+ return state_dict
72
+
73
+
74
+ def set_global_logging_level(level=logging.ERROR, prefices=[""]):
75
+ """
76
+ Override logging levels of different modules based on their name as a prefix.
77
+ It needs to be invoked after the modules have been loaded so that their loggers have been initialized.
78
+
79
+ Args:
80
+ - level: desired level. e.g. logging.INFO. Optional. Default is logging.ERROR
81
+ - prefices: list of one or more str prefices to match (e.g. ["transformers", "torch"]). Optional.
82
+ Default is `[""]` to match all active loggers.
83
+ The match is a case-sensitive `module_name.startswith(prefix)`
84
+ """
85
+ prefix_re = re.compile(fr'^(?:{ "|".join(prefices) })')
86
+ for name in logging.root.manager.loggerDict:
87
+ if re.match(prefix_re, name):
88
+ logging.getLogger(name).setLevel(level)
89
+
90
+
91
+ def get_iou(anchors, gt_boxes):
92
+ """
93
+ anchors: (N, 4) torch floattensor
94
+ gt_boxes: (K, 4) torch floattensor
95
+ overlaps: (N, K) ndarray of overlap between boxes and query_boxes
96
+ """
97
+ N = anchors.size(0)
98
+
99
+ if gt_boxes.size() == (4,):
100
+ gt_boxes = gt_boxes.view(1, 4)
101
+ K = gt_boxes.size(0)
102
+
103
+ gt_boxes_area = (
104
+ (gt_boxes[:, 2] - gt_boxes[:, 0] + 1) *
105
+ (gt_boxes[:, 3] - gt_boxes[:, 1] + 1)
106
+ ).view(1, K)
107
+
108
+ anchors_area = (
109
+ (anchors[:, 2] - anchors[:, 0] + 1) *
110
+ (anchors[:, 3] - anchors[:, 1] + 1)
111
+ ).view(N, 1)
112
+
113
+ boxes = anchors.view(N, 1, 4).expand(N, K, 4)
114
+ query_boxes = gt_boxes.view(1, K, 4).expand(N, K, 4)
115
+
116
+ iw = (
117
+ torch.min(boxes[:, :, 2], query_boxes[:, :, 2])
118
+ - torch.max(boxes[:, :, 0], query_boxes[:, :, 0])
119
+ + 1
120
+ )
121
+ iw[iw < 0] = 0
122
+
123
+ ih = (
124
+ torch.min(boxes[:, :, 3], query_boxes[:, :, 3])
125
+ - torch.max(boxes[:, :, 1], query_boxes[:, :, 1])
126
+ + 1
127
+ )
128
+ ih[ih < 0] = 0
129
+
130
+ ua = anchors_area + gt_boxes_area - (iw * ih)
131
+ overlaps = iw * ih / ua
132
+
133
+ return overlaps
134
+
135
+
136
+ def xywh_to_xyxy(boxes):
137
+ """Convert [x y w h] box format to [x1 y1 x2 y2] format."""
138
+ return np.hstack((boxes[:, 0:2], boxes[:, 0:2] + boxes[:, 2:4] - 1))
clip/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ from .clip import *
clip/bpe_simple_vocab_16e6.txt.gz ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:924691ac288e54409236115652ad4aa250f48203de50a9e4722a6ecd48d6804a
3
+ size 1356917
clip/clip.py ADDED
@@ -0,0 +1,193 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import hashlib
2
+ import os
3
+ import urllib
4
+ import warnings
5
+ from typing import Union, List
6
+
7
+ import torch
8
+ from PIL import Image
9
+ from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize
10
+ from tqdm import tqdm
11
+
12
+ from .model import build_model
13
+ from .simple_tokenizer import SimpleTokenizer as _Tokenizer
14
+
15
+ __all__ = ["available_models", "load", "tokenize"]
16
+ _tokenizer = _Tokenizer()
17
+
18
+ _MODELS = {
19
+ "RN50": "https://openaipublic.azureedge.net/clip/models/afeb0e10f9e5a86da6080e35cf09123aca3b358a0c3e3b6c78a7b63bc04b6762/RN50.pt",
20
+ "RN101": "https://openaipublic.azureedge.net/clip/models/8fa8567bab74a42d41c5915025a8e4538c3bdbe8804a470a72f30b0d94fab599/RN101.pt",
21
+ "RN50x4": "https://openaipublic.azureedge.net/clip/models/7e526bd135e493cef0776de27d5f42653e6b4c8bf9e0f653bb11773263205fdd/RN50x4.pt",
22
+ "ViT-B/32": "https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt",
23
+ }
24
+
25
+
26
+ def _download(url: str, root: str = os.path.expanduser("~/.cache/clip")):
27
+ os.makedirs(root, exist_ok=True)
28
+ filename = os.path.basename(url)
29
+
30
+ expected_sha256 = url.split("/")[-2]
31
+ download_target = os.path.join(root, filename)
32
+
33
+ if os.path.exists(download_target) and not os.path.isfile(download_target):
34
+ raise RuntimeError(f"{download_target} exists and is not a regular file")
35
+
36
+ if os.path.isfile(download_target):
37
+ if hashlib.sha256(open(download_target, "rb").read()).hexdigest() == expected_sha256:
38
+ return download_target
39
+ else:
40
+ warnings.warn(f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file")
41
+
42
+ with urllib.request.urlopen(url) as source, open(download_target, "wb") as output:
43
+ with tqdm(total=int(source.info().get("Content-Length")), ncols=80, unit='iB', unit_scale=True) as loop:
44
+ while True:
45
+ buffer = source.read(8192)
46
+ if not buffer:
47
+ break
48
+
49
+ output.write(buffer)
50
+ loop.update(len(buffer))
51
+
52
+ if hashlib.sha256(open(download_target, "rb").read()).hexdigest() != expected_sha256:
53
+ raise RuntimeError(f"Model has been downloaded but the SHA256 checksum does not not match")
54
+
55
+ return download_target
56
+
57
+
58
+ def _transform(n_px):
59
+ return Compose([
60
+ Resize(n_px, interpolation=Image.BICUBIC),
61
+ CenterCrop(n_px),
62
+ lambda image: image.convert("RGB"),
63
+ ToTensor(),
64
+ Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),
65
+ ])
66
+
67
+
68
+ def available_models() -> List[str]:
69
+ """Returns the names of available CLIP models"""
70
+ return list(_MODELS.keys())
71
+
72
+
73
+ def load(name: str, device: Union[str, torch.device] = "cuda" if torch.cuda.is_available() else "cpu", jit=True):
74
+ """Load a CLIP model
75
+
76
+ Parameters
77
+ ----------
78
+ name : str
79
+ A model name listed by `clip.available_models()`, or the path to a model checkpoint containing the state_dict
80
+
81
+ device : Union[str, torch.device]
82
+ The device to put the loaded model
83
+
84
+ jit : bool
85
+ Whether to load the optimized JIT model (default) or more hackable non-JIT model.
86
+
87
+ Returns
88
+ -------
89
+ model : torch.nn.Module
90
+ The CLIP model
91
+
92
+ preprocess : Callable[[PIL.Image], torch.Tensor]
93
+ A torchvision transform that converts a PIL image into a tensor that the returned model can take as its input
94
+ """
95
+ if name in _MODELS:
96
+ model_path = _download(_MODELS[name])
97
+ elif os.path.isfile(name):
98
+ model_path = name
99
+ else:
100
+ raise RuntimeError(f"Model {name} not found; available models = {available_models()}")
101
+
102
+ try:
103
+ # loading JIT archive
104
+ model = torch.jit.load(model_path, map_location=device if jit else "cpu").eval()
105
+ state_dict = None
106
+ except RuntimeError:
107
+ # loading saved state dict
108
+ if jit:
109
+ warnings.warn(f"File {model_path} is not a JIT archive. Loading as a state dict instead")
110
+ jit = False
111
+ state_dict = torch.load(model_path, map_location="cpu")
112
+
113
+ if not jit:
114
+ model = build_model(state_dict or model.state_dict()).to(device)
115
+ if str(device) == "cpu":
116
+ model.float()
117
+ return model, _transform(model.visual.input_resolution)
118
+
119
+ # patch the device names
120
+ device_holder = torch.jit.trace(lambda: torch.ones([]).to(torch.device(device)), example_inputs=[])
121
+ device_node = [n for n in device_holder.graph.findAllNodes("prim::Constant") if "Device" in repr(n)][-1]
122
+
123
+ def patch_device(module):
124
+ graphs = [module.graph] if hasattr(module, "graph") else []
125
+ if hasattr(module, "forward1"):
126
+ graphs.append(module.forward1.graph)
127
+
128
+ for graph in graphs:
129
+ for node in graph.findAllNodes("prim::Constant"):
130
+ if "value" in node.attributeNames() and str(node["value"]).startswith("cuda"):
131
+ node.copyAttributes(device_node)
132
+
133
+ model.apply(patch_device)
134
+ patch_device(model.encode_image)
135
+ patch_device(model.encode_text)
136
+
137
+ # patch dtype to float32 on CPU
138
+ if str(device) == "cpu":
139
+ float_holder = torch.jit.trace(lambda: torch.ones([]).float(), example_inputs=[])
140
+ float_input = list(float_holder.graph.findNode("aten::to").inputs())[1]
141
+ float_node = float_input.node()
142
+
143
+ def patch_float(module):
144
+ graphs = [module.graph] if hasattr(module, "graph") else []
145
+ if hasattr(module, "forward1"):
146
+ graphs.append(module.forward1.graph)
147
+
148
+ for graph in graphs:
149
+ for node in graph.findAllNodes("aten::to"):
150
+ inputs = list(node.inputs())
151
+ for i in [1, 2]: # dtype can be the second or third argument to aten::to()
152
+ if inputs[i].node()["value"] == 5:
153
+ inputs[i].node().copyAttributes(float_node)
154
+
155
+ model.apply(patch_float)
156
+ patch_float(model.encode_image)
157
+ patch_float(model.encode_text)
158
+
159
+ model.float()
160
+
161
+ return model, _transform(model.input_resolution.item())
162
+
163
+
164
+ def tokenize(texts: Union[str, List[str]], context_length: int = 77) -> torch.LongTensor:
165
+ """
166
+ Returns the tokenized representation of given input string(s)
167
+
168
+ Parameters
169
+ ----------
170
+ texts : Union[str, List[str]]
171
+ An input string or a list of input strings to tokenize
172
+
173
+ context_length : int
174
+ The context length to use; all CLIP models use 77 as the context length
175
+
176
+ Returns
177
+ -------
178
+ A two-dimensional tensor containing the resulting tokens, shape = [number of input strings, context_length]
179
+ """
180
+ if isinstance(texts, str):
181
+ texts = [texts]
182
+
183
+ sot_token = _tokenizer.encoder["<|startoftext|>"]
184
+ eot_token = _tokenizer.encoder["<|endoftext|>"]
185
+ all_tokens = [[sot_token] + _tokenizer.encode(text) + [eot_token] for text in texts]
186
+ result = torch.zeros(len(all_tokens), context_length, dtype=torch.long)
187
+
188
+ for i, tokens in enumerate(all_tokens):
189
+ if len(tokens) > context_length:
190
+ raise RuntimeError(f"Input {texts[i]} is too long for context length {context_length}")
191
+ result[i, :len(tokens)] = torch.tensor(tokens)
192
+
193
+ return result
clip/model.py ADDED
@@ -0,0 +1,437 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from collections import OrderedDict
2
+ from typing import Tuple, Union
3
+
4
+ import torch
5
+ import torch.nn.functional as F
6
+ from torch import nn
7
+
8
+
9
+ class Bottleneck(nn.Module):
10
+ expansion = 4
11
+
12
+ def __init__(self, inplanes, planes, stride=1):
13
+ super().__init__()
14
+
15
+ # all conv layers have stride 1. an avgpool is performed after the second convolution when stride > 1
16
+ self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False)
17
+ self.bn1 = nn.BatchNorm2d(planes)
18
+
19
+ self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, bias=False)
20
+ self.bn2 = nn.BatchNorm2d(planes)
21
+
22
+ self.avgpool = nn.AvgPool2d(stride) if stride > 1 else nn.Identity()
23
+
24
+ self.conv3 = nn.Conv2d(planes, planes * self.expansion, 1, bias=False)
25
+ self.bn3 = nn.BatchNorm2d(planes * self.expansion)
26
+
27
+ self.relu = nn.ReLU(inplace=True)
28
+ self.downsample = None
29
+ self.stride = stride
30
+
31
+ if stride > 1 or inplanes != planes * Bottleneck.expansion:
32
+ # downsampling layer is prepended with an avgpool, and the subsequent convolution has stride 1
33
+ self.downsample = nn.Sequential(OrderedDict([
34
+ ("-1", nn.AvgPool2d(stride)),
35
+ ("0", nn.Conv2d(inplanes, planes * self.expansion, 1, stride=1, bias=False)),
36
+ ("1", nn.BatchNorm2d(planes * self.expansion))
37
+ ]))
38
+
39
+ def forward(self, x: torch.Tensor):
40
+ identity = x
41
+
42
+ out = self.relu(self.bn1(self.conv1(x)))
43
+ out = self.relu(self.bn2(self.conv2(out)))
44
+ out = self.avgpool(out)
45
+ out = self.bn3(self.conv3(out))
46
+
47
+ if self.downsample is not None:
48
+ identity = self.downsample(x)
49
+
50
+ out += identity
51
+ out = self.relu(out)
52
+ return out
53
+
54
+
55
+ class AttentionPool2d(nn.Module):
56
+ def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None):
57
+ super().__init__()
58
+ self.positional_embedding = nn.Parameter(torch.randn(spacial_dim ** 2 + 1, embed_dim) / embed_dim ** 0.5)
59
+ self.k_proj = nn.Linear(embed_dim, embed_dim)
60
+ self.q_proj = nn.Linear(embed_dim, embed_dim)
61
+ self.v_proj = nn.Linear(embed_dim, embed_dim)
62
+ self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim)
63
+ self.num_heads = num_heads
64
+
65
+ def forward(self, x):
66
+ x = x.reshape(x.shape[0], x.shape[1], x.shape[2] * x.shape[3]).permute(2, 0, 1) # NCHW -> (HW)NC
67
+ # print(x.shape, self.positional_embedding.shape)
68
+ x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0) # (HW+1)NC
69
+ x = x + self.positional_embedding[0, :, None, :].to(x.dtype) # (HW+1)NC
70
+ x, _ = F.multi_head_attention_forward(
71
+ query=x, key=x, value=x,
72
+ embed_dim_to_check=x.shape[-1],
73
+ num_heads=self.num_heads,
74
+ q_proj_weight=self.q_proj.weight,
75
+ k_proj_weight=self.k_proj.weight,
76
+ v_proj_weight=self.v_proj.weight,
77
+ in_proj_weight=None,
78
+ in_proj_bias=torch.cat([self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]),
79
+ bias_k=None,
80
+ bias_v=None,
81
+ add_zero_attn=False,
82
+ dropout_p=0,
83
+ out_proj_weight=torch.ones_like(self.q_proj.weight),
84
+ out_proj_bias=torch.zeros_like(self.q_proj.bias),
85
+ # out_proj_weight=self.c_proj.weight,
86
+ # out_proj_bias=self.c_proj.bias,
87
+ use_separate_proj_weight=True,
88
+ training=self.training,
89
+ need_weights=False
90
+ )
91
+
92
+ return x[0]
93
+
94
+
95
+ class ModifiedResNet(nn.Module):
96
+ """
97
+ A ResNet class that is similar to torchvision's but contains the following changes:
98
+ - There are now 3 "stem" convolutions as opposed to 1, with an average pool instead of a max pool.
99
+ - Performs anti-aliasing strided convolutions, where an avgpool is prepended to convolutions with stride > 1
100
+ - The final pooling layer is a QKV attention instead of an average pool
101
+ """
102
+
103
+ def __init__(self, layers, output_dim, heads, input_resolution=224, width=64):
104
+ super().__init__()
105
+ self.output_dim = output_dim
106
+ self.input_resolution = input_resolution
107
+
108
+ # the 3-layer stem
109
+ self.conv1 = nn.Conv2d(3, width // 2, kernel_size=3, stride=2, padding=1, bias=False)
110
+ self.bn1 = nn.BatchNorm2d(width // 2)
111
+ self.conv2 = nn.Conv2d(width // 2, width // 2, kernel_size=3, padding=1, bias=False)
112
+ self.bn2 = nn.BatchNorm2d(width // 2)
113
+ self.conv3 = nn.Conv2d(width // 2, width, kernel_size=3, padding=1, bias=False)
114
+ self.bn3 = nn.BatchNorm2d(width)
115
+ self.avgpool = nn.AvgPool2d(2)
116
+ self.relu = nn.ReLU(inplace=True)
117
+
118
+ # residual layers
119
+ self._inplanes = width # this is a *mutable* variable used during construction
120
+ self.layer1 = self._make_layer(width, layers[0])
121
+ self.layer2 = self._make_layer(width * 2, layers[1], stride=2)
122
+ self.layer3 = self._make_layer(width * 4, layers[2], stride=2)
123
+ self.layer4 = self._make_layer(width * 8, layers[3], stride=2)
124
+
125
+ embed_dim = width * 32 # the ResNet feature dimension
126
+ self.attnpool = AttentionPool2d(input_resolution // 32, embed_dim, heads, output_dim)
127
+
128
+ def _make_layer(self, planes, blocks, stride=1):
129
+ layers = [Bottleneck(self._inplanes, planes, stride)]
130
+
131
+ self._inplanes = planes * Bottleneck.expansion
132
+ for _ in range(1, blocks):
133
+ layers.append(Bottleneck(self._inplanes, planes))
134
+
135
+ return nn.Sequential(*layers)
136
+
137
+ def forward(self, x):
138
+ def stem(x):
139
+ for conv, bn in [(self.conv1, self.bn1), (self.conv2, self.bn2), (self.conv3, self.bn3)]:
140
+ x = self.relu(bn(conv(x)))
141
+ x = self.avgpool(x)
142
+ return x
143
+
144
+ x = x.type(self.conv1.weight.dtype)
145
+ x = stem(x)
146
+ x = self.layer1(x)
147
+ x = self.layer2(x)
148
+ x = self.layer3(x)
149
+ x = self.layer4(x)
150
+ # print(x.shape)
151
+ # x = self.attnpool(x)
152
+ attnpool = self.attnpool(x)
153
+
154
+ return (x, attnpool)
155
+
156
+
157
+ class LayerNorm(nn.LayerNorm):
158
+ """Subclass torch's LayerNorm to handle fp16."""
159
+
160
+ def forward(self, x: torch.Tensor):
161
+ orig_type = x.dtype
162
+ ret = super().forward(x.type(torch.float32))
163
+ return ret.type(orig_type)
164
+
165
+
166
+ class QuickGELU(nn.Module):
167
+ def forward(self, x: torch.Tensor):
168
+ return x * torch.sigmoid(1.702 * x)
169
+
170
+
171
+ class ResidualAttentionBlock(nn.Module):
172
+ def __init__(self, d_model: int, n_head: int, attn_mask: torch.Tensor = None):
173
+ super().__init__()
174
+
175
+ self.attn = nn.MultiheadAttention(d_model, n_head)
176
+ self.ln_1 = LayerNorm(d_model)
177
+ self.mlp = nn.Sequential(OrderedDict([
178
+ ("c_fc", nn.Linear(d_model, d_model * 4)),
179
+ ("gelu", QuickGELU()),
180
+ ("c_proj", nn.Linear(d_model * 4, d_model))
181
+ ]))
182
+ self.ln_2 = LayerNorm(d_model)
183
+ self.attn_mask = attn_mask
184
+
185
+ def attention(self, x: torch.Tensor):
186
+ self.attn_mask = self.attn_mask.to(dtype=x.dtype, device=x.device) if self.attn_mask is not None else None
187
+ return self.attn(x, x, x, need_weights=False, attn_mask=self.attn_mask)[0]
188
+
189
+ def forward(self, x: torch.Tensor):
190
+ x = x + self.attention(self.ln_1(x))
191
+ x = x + self.mlp(self.ln_2(x))
192
+ return x
193
+
194
+
195
+ class Transformer(nn.Module):
196
+ def __init__(self, width: int, layers: int, heads: int, attn_mask: torch.Tensor = None):
197
+ super().__init__()
198
+ self.width = width
199
+ self.layers = layers
200
+ self.resblocks = nn.Sequential(*[ResidualAttentionBlock(width, heads, attn_mask) for _ in range(layers)])
201
+
202
+ def forward(self, x: torch.Tensor):
203
+ return self.resblocks(x)
204
+
205
+
206
+ class VisualTransformer(nn.Module):
207
+ def __init__(self, input_resolution: int, patch_size: int, width: int, layers: int, heads: int, output_dim: int):
208
+ super().__init__()
209
+ self.input_resolution = input_resolution
210
+ self.output_dim = output_dim
211
+ self.conv1 = nn.Conv2d(in_channels=3, out_channels=width, kernel_size=patch_size, stride=patch_size, bias=False)
212
+
213
+ scale = width ** -0.5
214
+ self.class_embedding = nn.Parameter(scale * torch.randn(width))
215
+ self.positional_embedding = nn.Parameter(scale * torch.randn((input_resolution // patch_size) ** 2 + 1, width))
216
+ self.ln_pre = LayerNorm(width)
217
+
218
+ self.transformer = Transformer(width, layers, heads)
219
+
220
+ self.ln_post = LayerNorm(width)
221
+ self.proj = nn.Parameter(scale * torch.randn(width, output_dim))
222
+
223
+ def forward(self, x: torch.Tensor):
224
+ x = self.conv1(x) # shape = [*, width, grid, grid]
225
+ x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2]
226
+ x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width]
227
+ x = torch.cat([self.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), x], dim=1) # shape = [*, grid ** 2 + 1, width]
228
+ x = x + self.positional_embedding.to(x.dtype)
229
+ x = self.ln_pre(x)
230
+
231
+ x = x.permute(1, 0, 2) # NLD -> LND
232
+ x = self.transformer(x)
233
+ x = x.permute(1, 0, 2) # LND -> NLD
234
+
235
+ # x = self.ln_post(x[:, 0, :])
236
+
237
+ x = self.ln_post(x)
238
+ # if self.proj is not None:
239
+ # x = x @ self.proj
240
+
241
+ return x
242
+
243
+
244
+ class CLIP(nn.Module):
245
+ def __init__(self,
246
+ embed_dim: int,
247
+ # vision
248
+ image_resolution: int,
249
+ vision_layers: Union[Tuple[int, int, int, int], int],
250
+ vision_width: int,
251
+ vision_patch_size: int,
252
+ # text
253
+ context_length: int,
254
+ vocab_size: int,
255
+ transformer_width: int,
256
+ transformer_heads: int,
257
+ transformer_layers: int
258
+ ):
259
+ super().__init__()
260
+
261
+ self.context_length = context_length
262
+
263
+ if isinstance(vision_layers, (tuple, list)):
264
+ vision_heads = vision_width * 32 // 64
265
+ self.visual = ModifiedResNet(
266
+ layers=vision_layers,
267
+ output_dim=embed_dim,
268
+ heads=vision_heads,
269
+ input_resolution=image_resolution,
270
+ width=vision_width
271
+ )
272
+ else:
273
+ vision_heads = vision_width // 64
274
+ self.visual = VisualTransformer(
275
+ input_resolution=image_resolution,
276
+ patch_size=vision_patch_size,
277
+ width=vision_width,
278
+ layers=vision_layers,
279
+ heads=vision_heads,
280
+ output_dim=embed_dim
281
+ )
282
+
283
+ self.transformer = Transformer(
284
+ width=transformer_width,
285
+ layers=transformer_layers,
286
+ heads=transformer_heads,
287
+ attn_mask=self.build_attention_mask()
288
+ )
289
+
290
+ self.vocab_size = vocab_size
291
+ self.token_embedding = nn.Embedding(vocab_size, transformer_width)
292
+ self.positional_embedding = nn.Parameter(torch.empty(self.context_length, transformer_width))
293
+ self.ln_final = LayerNorm(transformer_width)
294
+
295
+ self.text_projection = nn.Parameter(torch.empty(transformer_width, embed_dim))
296
+ self.logit_scale = nn.Parameter(torch.ones([]))
297
+
298
+ self.initialize_parameters()
299
+
300
+ def initialize_parameters(self):
301
+ nn.init.normal_(self.token_embedding.weight, std=0.02)
302
+ nn.init.normal_(self.positional_embedding, std=0.01)
303
+
304
+ if isinstance(self.visual, ModifiedResNet):
305
+ if self.visual.attnpool is not None:
306
+ std = self.visual.attnpool.c_proj.in_features ** -0.5
307
+ nn.init.normal_(self.visual.attnpool.q_proj.weight, std=std)
308
+ nn.init.normal_(self.visual.attnpool.k_proj.weight, std=std)
309
+ nn.init.normal_(self.visual.attnpool.v_proj.weight, std=std)
310
+ nn.init.normal_(self.visual.attnpool.c_proj.weight, std=std)
311
+
312
+ for resnet_block in [self.visual.layer1, self.visual.layer2, self.visual.layer3, self.visual.layer4]:
313
+ for name, param in resnet_block.named_parameters():
314
+ if name.endswith("bn3.weight"):
315
+ nn.init.zeros_(param)
316
+
317
+ proj_std = (self.transformer.width ** -0.5) * ((2 * self.transformer.layers) ** -0.5)
318
+ attn_std = self.transformer.width ** -0.5
319
+ fc_std = (2 * self.transformer.width) ** -0.5
320
+ for block in self.transformer.resblocks:
321
+ nn.init.normal_(block.attn.in_proj_weight, std=attn_std)
322
+ nn.init.normal_(block.attn.out_proj.weight, std=proj_std)
323
+ nn.init.normal_(block.mlp.c_fc.weight, std=fc_std)
324
+ nn.init.normal_(block.mlp.c_proj.weight, std=proj_std)
325
+
326
+ if self.text_projection is not None:
327
+ nn.init.normal_(self.text_projection, std=self.transformer.width ** -0.5)
328
+
329
+ def build_attention_mask(self):
330
+ # lazily create causal attention mask, with full attention between the vision tokens
331
+ # pytorch uses additive attention mask; fill with -inf
332
+ mask = torch.empty(self.context_length, self.context_length)
333
+ mask.fill_(float("-inf"))
334
+ mask.triu_(1) # zero out the lower diagonal
335
+ return mask
336
+
337
+ @property
338
+ def dtype(self):
339
+ return self.visual.conv1.weight.dtype
340
+
341
+ def encode_image(self, image):
342
+ return self.visual(image.type(self.dtype))
343
+
344
+ def encode_text(self, text):
345
+ x = self.token_embedding(text).type(self.dtype) # [batch_size, n_ctx, d_model]
346
+
347
+ x = x + self.positional_embedding.type(self.dtype)
348
+ x = x.permute(1, 0, 2) # NLD -> LND
349
+ x = self.transformer(x)
350
+ x = x.permute(1, 0, 2) # LND -> NLD
351
+ x = self.ln_final(x).type(self.dtype)
352
+
353
+ # x.shape = [batch_size, n_ctx, transformer.width]
354
+ # take features from the eot embedding (eot_token is the highest number in each sequence)
355
+ x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection
356
+
357
+ return x
358
+
359
+ def forward(self, image, text):
360
+ image_features = self.encode_image(image)
361
+ text_features = self.encode_text(text)
362
+
363
+ # normalized features
364
+ image_features = image_features / image_features.norm(dim=-1, keepdim=True)
365
+ text_features = text_features / text_features.norm(dim=-1, keepdim=True)
366
+
367
+ # cosine similarity as logits
368
+ logit_scale = self.logit_scale.exp()
369
+ logits_per_image = logit_scale * image_features @ text_features.t()
370
+ logits_per_text = logit_scale * text_features @ image_features.t()
371
+
372
+ # shape = [global_batch_size, global_batch_size]
373
+ return logits_per_image, logits_per_text
374
+
375
+
376
+ def convert_weights(model: nn.Module):
377
+ """Convert applicable model parameters to fp16"""
378
+
379
+ def _convert_weights_to_fp16(l):
380
+ if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Linear)):
381
+ l.weight.data = l.weight.data.half()
382
+ if l.bias is not None:
383
+ l.bias.data = l.bias.data.half()
384
+
385
+ if isinstance(l, nn.MultiheadAttention):
386
+ for attr in [*[f"{s}_proj_weight" for s in ["in", "q", "k", "v"]], "in_proj_bias", "bias_k", "bias_v"]:
387
+ tensor = getattr(l, attr)
388
+ if tensor is not None:
389
+ tensor.data = tensor.data.half()
390
+
391
+ for name in ["text_projection", "proj"]:
392
+ if hasattr(l, name):
393
+ attr = getattr(l, name)
394
+ if attr is not None:
395
+ attr.data = attr.data.half()
396
+
397
+ model.apply(_convert_weights_to_fp16)
398
+
399
+
400
+ def build_model(state_dict: dict):
401
+ vit = "visual.proj" in state_dict
402
+
403
+ if vit:
404
+ vision_width = state_dict["visual.conv1.weight"].shape[0]
405
+ vision_layers = len([k for k in state_dict.keys() if k.startswith("visual.") and k.endswith(".attn.in_proj_weight")])
406
+ vision_patch_size = state_dict["visual.conv1.weight"].shape[-1]
407
+ grid_size = round((state_dict["visual.positional_embedding"].shape[0] - 1) ** 0.5)
408
+ image_resolution = vision_patch_size * grid_size
409
+ else:
410
+ counts: list = [len(set(k.split(".")[2] for k in state_dict if k.startswith(f"visual.layer{b}"))) for b in [1, 2, 3, 4]]
411
+ vision_layers = tuple(counts)
412
+ vision_width = state_dict["visual.layer1.0.conv1.weight"].shape[0]
413
+ output_width = round((state_dict["visual.attnpool.positional_embedding"].shape[0] - 1) ** 0.5)
414
+ vision_patch_size = None
415
+ assert output_width ** 2 + 1 == state_dict["visual.attnpool.positional_embedding"].shape[0]
416
+ image_resolution = output_width * 32
417
+
418
+ embed_dim = state_dict["text_projection"].shape[1]
419
+ context_length = state_dict["positional_embedding"].shape[0]
420
+ vocab_size = state_dict["token_embedding.weight"].shape[0]
421
+ transformer_width = state_dict["ln_final.weight"].shape[0]
422
+ transformer_heads = transformer_width // 64
423
+ transformer_layers = len(set(k.split(".")[2] for k in state_dict if k.startswith(f"transformer.resblocks")))
424
+
425
+ model = CLIP(
426
+ embed_dim,
427
+ image_resolution, vision_layers, vision_width, vision_patch_size,
428
+ context_length, vocab_size, transformer_width, transformer_heads, transformer_layers
429
+ )
430
+
431
+ for key in ["input_resolution", "context_length", "vocab_size"]:
432
+ if key in state_dict:
433
+ del state_dict[key]
434
+
435
+ convert_weights(model)
436
+ model.load_state_dict(state_dict)
437
+ return model.eval()
clip/simple_tokenizer.py ADDED
@@ -0,0 +1,132 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gzip
2
+ import html
3
+ import os
4
+ from functools import lru_cache
5
+
6
+ import ftfy
7
+ import regex as re
8
+
9
+
10
+ @lru_cache()
11
+ def default_bpe():
12
+ return os.path.join(os.path.dirname(os.path.abspath(__file__)), "bpe_simple_vocab_16e6.txt.gz")
13
+
14
+
15
+ @lru_cache()
16
+ def bytes_to_unicode():
17
+ """
18
+ Returns list of utf-8 byte and a corresponding list of unicode strings.
19
+ The reversible bpe codes work on unicode strings.
20
+ This means you need a large # of unicode characters in your vocab if you want to avoid UNKs.
21
+ When you're at something like a 10B token dataset you end up needing around 5K for decent coverage.
22
+ This is a signficant percentage of your normal, say, 32K bpe vocab.
23
+ To avoid that, we want lookup tables between utf-8 bytes and unicode strings.
24
+ And avoids mapping to whitespace/control characters the bpe code barfs on.
25
+ """
26
+ bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1))
27
+ cs = bs[:]
28
+ n = 0
29
+ for b in range(2**8):
30
+ if b not in bs:
31
+ bs.append(b)
32
+ cs.append(2**8+n)
33
+ n += 1
34
+ cs = [chr(n) for n in cs]
35
+ return dict(zip(bs, cs))
36
+
37
+
38
+ def get_pairs(word):
39
+ """Return set of symbol pairs in a word.
40
+ Word is represented as tuple of symbols (symbols being variable-length strings).
41
+ """
42
+ pairs = set()
43
+ prev_char = word[0]
44
+ for char in word[1:]:
45
+ pairs.add((prev_char, char))
46
+ prev_char = char
47
+ return pairs
48
+
49
+
50
+ def basic_clean(text):
51
+ text = ftfy.fix_text(text)
52
+ text = html.unescape(html.unescape(text))
53
+ return text.strip()
54
+
55
+
56
+ def whitespace_clean(text):
57
+ text = re.sub(r'\s+', ' ', text)
58
+ text = text.strip()
59
+ return text
60
+
61
+
62
+ class SimpleTokenizer(object):
63
+ def __init__(self, bpe_path: str = default_bpe()):
64
+ self.byte_encoder = bytes_to_unicode()
65
+ self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
66
+ merges = gzip.open(bpe_path).read().decode("utf-8").split('\n')
67
+ merges = merges[1:49152-256-2+1]
68
+ merges = [tuple(merge.split()) for merge in merges]
69
+ vocab = list(bytes_to_unicode().values())
70
+ vocab = vocab + [v+'</w>' for v in vocab]
71
+ for merge in merges:
72
+ vocab.append(''.join(merge))
73
+ vocab.extend(['<|startoftext|>', '<|endoftext|>'])
74
+ self.encoder = dict(zip(vocab, range(len(vocab))))
75
+ self.decoder = {v: k for k, v in self.encoder.items()}
76
+ self.bpe_ranks = dict(zip(merges, range(len(merges))))
77
+ self.cache = {'<|startoftext|>': '<|startoftext|>', '<|endoftext|>': '<|endoftext|>'}
78
+ self.pat = re.compile(r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", re.IGNORECASE)
79
+
80
+ def bpe(self, token):
81
+ if token in self.cache:
82
+ return self.cache[token]
83
+ word = tuple(token[:-1]) + ( token[-1] + '</w>',)
84
+ pairs = get_pairs(word)
85
+
86
+ if not pairs:
87
+ return token+'</w>'
88
+
89
+ while True:
90
+ bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf')))
91
+ if bigram not in self.bpe_ranks:
92
+ break
93
+ first, second = bigram
94
+ new_word = []
95
+ i = 0
96
+ while i < len(word):
97
+ try:
98
+ j = word.index(first, i)
99
+ new_word.extend(word[i:j])
100
+ i = j
101
+ except:
102
+ new_word.extend(word[i:])
103
+ break
104
+
105
+ if word[i] == first and i < len(word)-1 and word[i+1] == second:
106
+ new_word.append(first+second)
107
+ i += 2
108
+ else:
109
+ new_word.append(word[i])
110
+ i += 1
111
+ new_word = tuple(new_word)
112
+ word = new_word
113
+ if len(word) == 1:
114
+ break
115
+ else:
116
+ pairs = get_pairs(word)
117
+ word = ' '.join(word)
118
+ self.cache[token] = word
119
+ return word
120
+
121
+ def encode(self, text):
122
+ bpe_tokens = []
123
+ text = whitespace_clean(basic_clean(text)).lower()
124
+ for token in re.findall(self.pat, text):
125
+ token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8'))
126
+ bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' '))
127
+ return bpe_tokens
128
+
129
+ def decode(self, tokens):
130
+ text = ''.join([self.decoder[token] for token in tokens])
131
+ text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors="replace").replace('</w>', ' ')
132
+ return text
configs/phase1/FineCapEval_clipRN50_mle.yml ADDED
@@ -0,0 +1,60 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ caption_model: transformer
2
+ noamopt: true
3
+ noamopt_warmup: 20000
4
+ label_smoothing: 0.0
5
+ input_json: data/FineCapEval.json
6
+ input_label_h5: none
7
+ input_fc_dir: data/FineCapEval_clip_RN50_fc
8
+ input_att_dir: data/FineCapEval_clip_RN50_att
9
+ input_clipscore_vis_dir: data/FineCapEval_clipscore_vis
10
+
11
+ seq_per_img: 5
12
+ batch_size: 200
13
+ learning_rate: 0.0005
14
+
15
+ checkpoint_path: ./save/clipRN50_mle/clipRN50_mle
16
+
17
+ # clip_load_path: '/scratch-space/retrieval/save/clip_negative_text/clip_negative_text-epoch=10.ckpt'
18
+
19
+ # Notice: because I'm to lazy, I reuse the option name for RNNs to set the hyperparameters for transformer:
20
+ # N=num_layers
21
+ # d_model=input_encoding_size
22
+ # d_ff=rnn_size
23
+
24
+ # will be ignored
25
+ num_layers: 6
26
+ input_encoding_size: 512
27
+ rnn_size: 2048
28
+
29
+ # Transformer config
30
+ N_enc: 6
31
+ N_dec: 6
32
+ d_model: 512
33
+ d_ff: 2048
34
+ num_att_heads: 8
35
+ dropout: 0.1
36
+
37
+
38
+ learning_rate_decay_start: 0
39
+ scheduled_sampling_start: -1
40
+ save_checkpoint_every: 3000
41
+ language_eval: 1
42
+ val_images_use: 5000
43
+ max_epochs: 15
44
+ train_sample_n: 5
45
+
46
+ REFORWARD: false
47
+
48
+ # _BASE_: transformer.yml
49
+ reduce_on_plateau: false
50
+ noamopt: false
51
+ learning_rate: 0.000005
52
+ learning_rate_decay_start: -1
53
+
54
+ self_critical_after: 15
55
+ max_epochs: 50
56
+
57
+ verbose: false
58
+ precision: 32
59
+
60
+ use_clipscore: false
configs/phase1/clipRN50_mle.yml ADDED
@@ -0,0 +1,52 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ caption_model: transformer
2
+ noamopt: true
3
+ # noamopt: false
4
+ noamopt_warmup: 20000
5
+ label_smoothing: 0.0
6
+ input_json: data/cocotalk.json
7
+ input_label_h5: data/cocotalk_label.h5
8
+ input_fc_dir: data/cocotalk_clip_RN50_fc
9
+ input_att_dir: data/cocotalk_clip_RN50_att
10
+ input_clipscore_vis_dir: data/cocotalk_clipscore_vis
11
+ seq_per_img: 5
12
+ # batch_size: 600
13
+ batch_size: 200
14
+
15
+ learning_rate: 0.0005
16
+
17
+ # checkpoint_path: ./save/trans_clip_rn50_sc_pl
18
+ checkpoint_path: save/clipRN50_mle/clipRN50_mle
19
+
20
+ # Notice: because I'm to lazy, I reuse the option name for RNNs to set the hyperparameters for transformer:
21
+ # N=num_layers
22
+ # d_model=input_encoding_size
23
+ # d_ff=rnn_size
24
+
25
+ # will be ignored
26
+ num_layers: 6
27
+ input_encoding_size: 512
28
+ rnn_size: 2048
29
+
30
+ # Transformer config
31
+ N_enc: 6
32
+ N_dec: 6
33
+ d_model: 512
34
+ d_ff: 2048
35
+ num_att_heads: 8
36
+ dropout: 0.1
37
+
38
+
39
+ learning_rate_decay_start: 0
40
+ scheduled_sampling_start: -1
41
+ save_checkpoint_every: 3000
42
+ language_eval: 1
43
+ val_images_use: 5000
44
+ # max_epochs: 15
45
+ max_epochs: 25
46
+ train_sample_n: 5
47
+
48
+ REFORWARD: false
49
+
50
+
51
+ verbose: false
52
+ precision: 16
configs/phase1/transformer.yml ADDED
@@ -0,0 +1,41 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ caption_model: transformer
2
+ noamopt: true
3
+ noamopt_warmup: 20000
4
+ label_smoothing: 0.0
5
+ input_json: data/cocotalk.json
6
+ input_label_h5: data/cocotalk_label.h5
7
+ input_att_dir: data/cocotalk_att
8
+ seq_per_img: 5
9
+ batch_size: 10
10
+ learning_rate: 0.0005
11
+
12
+ checkpoint_path: ./save/trans_rn50_sc
13
+
14
+ # Notice: because I'm to lazy, I reuse the option name for RNNs to set the hyperparameters for transformer:
15
+ # N=num_layers
16
+ # d_model=input_encoding_size
17
+ # d_ff=rnn_size
18
+
19
+ # will be ignored
20
+ num_layers: 6
21
+ input_encoding_size: 512
22
+ rnn_size: 2048
23
+
24
+ # Transformer config
25
+ N_enc: 6
26
+ N_dec: 6
27
+ d_model: 512
28
+ d_ff: 2048
29
+ num_att_heads: 8
30
+ dropout: 0.1
31
+
32
+
33
+ learning_rate_decay_start: 0
34
+ scheduled_sampling_start: -1
35
+ save_checkpoint_every: 3000
36
+ language_eval: 1
37
+ val_images_use: 5000
38
+ max_epochs: 15
39
+ train_sample_n: 5
40
+
41
+ REFORWARD: false
configs/phase2/FineCapEval_clipRN50_cider.yml ADDED
@@ -0,0 +1,61 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ caption_model: transformer
2
+ noamopt: true
3
+ noamopt_warmup: 20000
4
+ label_smoothing: 0.0
5
+ input_json: data/FineCapEval.json
6
+ input_label_h5: none
7
+ input_fc_dir: data/FineCapEval_clip_RN50_fc
8
+ input_att_dir: data/FineCapEval_clip_RN50_att
9
+ input_clipscore_vis_dir: data/FineCapEval_clipscore_vis
10
+
11
+ seq_per_img: 5
12
+ batch_size: 200
13
+ learning_rate: 0.0005
14
+
15
+ checkpoint_path: ./save/clipRN50_cider/clipRN50_cider
16
+
17
+ # clip_load_path: '/scratch-space/retrieval/save/clip_negative_text/clip_negative_text-epoch=10.ckpt'
18
+
19
+ # Notice: because I'm to lazy, I reuse the option name for RNNs to set the hyperparameters for transformer:
20
+ # N=num_layers
21
+ # d_model=input_encoding_size
22
+ # d_ff=rnn_size
23
+
24
+ # will be ignored
25
+ num_layers: 6
26
+ input_encoding_size: 512
27
+ rnn_size: 2048
28
+
29
+ # Transformer config
30
+ N_enc: 6
31
+ N_dec: 6
32
+ d_model: 512
33
+ d_ff: 2048
34
+ num_att_heads: 8
35
+ dropout: 0.1
36
+
37
+
38
+ learning_rate_decay_start: 0
39
+ scheduled_sampling_start: -1
40
+ save_checkpoint_every: 3000
41
+ language_eval: 1
42
+ val_images_use: 5000
43
+ max_epochs: 15
44
+ train_sample_n: 5
45
+
46
+ REFORWARD: false
47
+
48
+ # _BASE_: transformer.yml
49
+ reduce_on_plateau: false
50
+ noamopt: false
51
+ learning_rate: 0.000005
52
+ learning_rate_decay_start: -1
53
+
54
+ self_critical_after: 15
55
+ max_epochs: 50
56
+
57
+ verbose: false
58
+ precision: 32
59
+
60
+ # use_clipscore: true
61
+ use_clipscore: false
configs/phase2/FineCapEval_clipRN50_cider_clips.yml ADDED
@@ -0,0 +1,65 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ caption_model: transformer
2
+ noamopt: true
3
+ noamopt_warmup: 20000
4
+ label_smoothing: 0.0
5
+ input_json: data/FineCapEval.json
6
+ input_label_h5: none
7
+ input_fc_dir: data/FineCapEval_clip_RN50_fc
8
+ input_att_dir: data/FineCapEval_clip_RN50_att
9
+ input_clipscore_vis_dir: data/FineCapEval_clipscore_vis
10
+
11
+ seq_per_img: 5
12
+ batch_size: 200
13
+ learning_rate: 0.0005
14
+
15
+ checkpoint_path: ./save/clipRN50_cider_clips/clipRN50_cider_clips
16
+
17
+ # clip_load_path: '/scratch-space/retrieval/save/clip_negative_text/clip_negative_text-epoch=10.ckpt'
18
+
19
+ # Notice: because I'm to lazy, I reuse the option name for RNNs to set the hyperparameters for transformer:
20
+ # N=num_layers
21
+ # d_model=input_encoding_size
22
+ # d_ff=rnn_size
23
+
24
+ # will be ignored
25
+ num_layers: 6
26
+ input_encoding_size: 512
27
+ rnn_size: 2048
28
+
29
+ # Transformer config
30
+ N_enc: 6
31
+ N_dec: 6
32
+ d_model: 512
33
+ d_ff: 2048
34
+ num_att_heads: 8
35
+ dropout: 0.1
36
+
37
+
38
+ learning_rate_decay_start: 0
39
+ scheduled_sampling_start: -1
40
+ save_checkpoint_every: 3000
41
+ language_eval: 1
42
+ val_images_use: 5000
43
+ max_epochs: 15
44
+ train_sample_n: 5
45
+
46
+ REFORWARD: false
47
+
48
+ # _BASE_: transformer.yml
49
+ reduce_on_plateau: false
50
+ noamopt: false
51
+ learning_rate: 0.000005
52
+ learning_rate_decay_start: -1
53
+
54
+ self_critical_after: 15
55
+ max_epochs: 50
56
+
57
+ verbose: false
58
+ precision: 32
59
+
60
+ # use_clipscore: true
61
+ use_clipscore: false
62
+ clipscore_reward_weight: 2.0
63
+ clipscore_mode: clip_s
64
+
65
+ use_multi_rewards: true
configs/phase2/FineCapEval_clipRN50_clips.yml ADDED
@@ -0,0 +1,64 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ caption_model: transformer
2
+ noamopt: true
3
+ noamopt_warmup: 20000
4
+ label_smoothing: 0.0
5
+ input_json: data/FineCapEval.json
6
+ input_label_h5: none
7
+ input_fc_dir: data/FineCapEval_clip_RN50_fc
8
+ input_att_dir: data/FineCapEval_clip_RN50_att
9
+ input_clipscore_vis_dir: data/FineCapEval_clipscore_vis
10
+ seq_per_img: 5
11
+ batch_size: 160
12
+ learning_rate: 0.0005
13
+
14
+ checkpoint_path: ./save/clipRN50_clips/clipRN50_clips
15
+
16
+ use_multi_rewards: false
17
+ use_grammar: false
18
+ use_grammar_baseline: false
19
+ # clip_load_path: '/scratch-space/retrieval/save/clip_negative_text/clip_negative_text-epoch=10.ckpt'
20
+
21
+ # Notice: because I'm to lazy, I reuse the option name for RNNs to set the hyperparameters for transformer:
22
+ # N=num_layers
23
+ # d_model=input_encoding_size
24
+ # d_ff=rnn_size
25
+
26
+ # will be ignored
27
+ num_layers: 6
28
+ input_encoding_size: 512
29
+ rnn_size: 2048
30
+
31
+ # Transformer config
32
+ N_enc: 6
33
+ N_dec: 6
34
+ d_model: 512
35
+ d_ff: 2048
36
+ num_att_heads: 8
37
+ dropout: 0.1
38
+
39
+
40
+ learning_rate_decay_start: 0
41
+ scheduled_sampling_start: -1
42
+ save_checkpoint_every: 3000
43
+ language_eval: 0
44
+ val_images_use: 5000
45
+ max_epochs: 15
46
+ train_sample_n: 5
47
+
48
+ REFORWARD: false
49
+
50
+ # _BASE_: transformer.yml
51
+ reduce_on_plateau: false
52
+ noamopt: false
53
+ learning_rate: 0.000005
54
+ learning_rate_decay_start: -1
55
+
56
+ self_critical_after: 15
57
+ max_epochs: 50
58
+
59
+ verbose: false
60
+ precision: 32
61
+
62
+ # use_clipscore: true
63
+ use_clipscore: false
64
+ clipscore_reward_weight: 2.0
configs/phase2/FineCapEval_clipRN50_clips_grammar.yml ADDED
@@ -0,0 +1,64 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ caption_model: transformer
2
+ noamopt: true
3
+ noamopt_warmup: 20000
4
+ label_smoothing: 0.0
5
+ input_json: data/FineCapEval.json
6
+ input_label_h5: none
7
+ input_fc_dir: data/FineCapEval_clip_RN50_fc
8
+ input_att_dir: data/FineCapEval_clip_RN50_att
9
+ input_clipscore_vis_dir: data/FineCapEval_clipscore_vis
10
+ seq_per_img: 5
11
+ batch_size: 160
12
+ learning_rate: 0.0005
13
+
14
+ checkpoint_path: ./save/clipRN50_clips_grammar/clipRN50_clips_grammar
15
+
16
+ use_multi_rewards: true
17
+ use_grammar: true
18
+ use_grammar_baseline: true
19
+ # clip_load_path: '/scratch-space/retrieval/save/clip_negative_text/clip_negative_text-epoch=10.ckpt'
20
+
21
+ # Notice: because I'm to lazy, I reuse the option name for RNNs to set the hyperparameters for transformer:
22
+ # N=num_layers
23
+ # d_model=input_encoding_size
24
+ # d_ff=rnn_size
25
+
26
+ # will be ignored
27
+ num_layers: 6
28
+ input_encoding_size: 512
29
+ rnn_size: 2048
30
+
31
+ # Transformer config
32
+ N_enc: 6
33
+ N_dec: 6
34
+ d_model: 512
35
+ d_ff: 2048
36
+ num_att_heads: 8
37
+ dropout: 0.1
38
+
39
+
40
+ learning_rate_decay_start: 0
41
+ scheduled_sampling_start: -1
42
+ save_checkpoint_every: 3000
43
+ language_eval: 0
44
+ val_images_use: 5000
45
+ max_epochs: 15
46
+ train_sample_n: 5
47
+
48
+ REFORWARD: false
49
+
50
+ # _BASE_: transformer.yml
51
+ reduce_on_plateau: false
52
+ noamopt: false
53
+ learning_rate: 0.000005
54
+ learning_rate_decay_start: -1
55
+
56
+ self_critical_after: 15
57
+ max_epochs: 50
58
+
59
+ verbose: false
60
+ precision: 32
61
+
62
+ # use_clipscore: true
63
+ use_clipscore: false
64
+ clipscore_reward_weight: 2.0
configs/phase2/clipRN50_cider.yml ADDED
@@ -0,0 +1,58 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ caption_model: transformer
2
+ noamopt: true
3
+ noamopt_warmup: 20000
4
+ label_smoothing: 0.0
5
+ input_json: data/cocotalk.json
6
+ input_label_h5: data/cocotalk_label.h5
7
+ input_fc_dir: data/cocotalk_clip_RN50_fc
8
+ input_att_dir: data/cocotalk_clip_RN50_att
9
+ # used only for evaluation
10
+ input_clipscore_vis_dir: data/cocotalk_clipscore_vis
11
+
12
+ seq_per_img: 5
13
+ batch_size: 200
14
+ learning_rate: 0.0005
15
+
16
+ # checkpoint_path: ./save/trans_clip_rn50_sc_pl_scst_cider
17
+ checkpoint_path: save/clipRN50_cider/clipRN50_cider
18
+
19
+ # Notice: because I'm to lazy, I reuse the option name for RNNs to set the hyperparameters for transformer:
20
+ # N=num_layers
21
+ # d_model=input_encoding_size
22
+ # d_ff=rnn_size
23
+
24
+ # will be ignored
25
+ num_layers: 6
26
+ input_encoding_size: 512
27
+ rnn_size: 2048
28
+
29
+ # Transformer config
30
+ N_enc: 6
31
+ N_dec: 6
32
+ d_model: 512
33
+ d_ff: 2048
34
+ num_att_heads: 8
35
+ dropout: 0.1
36
+
37
+
38
+ learning_rate_decay_start: 0
39
+ scheduled_sampling_start: -1
40
+ save_checkpoint_every: 3000
41
+ language_eval: 1
42
+ val_images_use: 5000
43
+ max_epochs: 15
44
+ train_sample_n: 5
45
+
46
+ REFORWARD: false
47
+
48
+ # _BASE_: transformer.yml
49
+ reduce_on_plateau: false
50
+ noamopt: false
51
+ learning_rate: 0.000005
52
+ learning_rate_decay_start: -1
53
+
54
+ self_critical_after: 15
55
+ max_epochs: 40
56
+
57
+ verbose: false
58
+ precision: 32
configs/phase2/clipRN50_cider_clips.yml ADDED
@@ -0,0 +1,61 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ caption_model: transformer
2
+ noamopt: true
3
+ noamopt_warmup: 20000
4
+ label_smoothing: 0.0
5
+ input_json: data/cocotalk.json
6
+ input_label_h5: data/cocotalk_label.h5
7
+ input_fc_dir: data/cocotalk_clip_RN50_fc
8
+ input_att_dir: data/cocotalk_clip_RN50_att
9
+ input_clipscore_vis_dir: data/cocotalk_clipscore_vis
10
+ seq_per_img: 5
11
+ batch_size: 160
12
+ learning_rate: 0.0005
13
+
14
+ checkpoint_path: save/clipRN50_cider_clips/clipRN50_cider_clips
15
+
16
+ # Notice: because I'm to lazy, I reuse the option name for RNNs to set the hyperparameters for transformer:
17
+ # N=num_layers
18
+ # d_model=input_encoding_size
19
+ # d_ff=rnn_size
20
+
21
+ # will be ignored
22
+ num_layers: 6
23
+ input_encoding_size: 512
24
+ rnn_size: 2048
25
+
26
+ # Transformer config
27
+ N_enc: 6
28
+ N_dec: 6
29
+ d_model: 512
30
+ d_ff: 2048
31
+ num_att_heads: 8
32
+ dropout: 0.1
33
+
34
+
35
+ learning_rate_decay_start: 0
36
+ scheduled_sampling_start: -1
37
+ save_checkpoint_every: 3000
38
+ language_eval: 1
39
+ val_images_use: 5000
40
+ max_epochs: 15
41
+ train_sample_n: 5
42
+
43
+ REFORWARD: false
44
+
45
+ # _BASE_: transformer.yml
46
+ reduce_on_plateau: false
47
+ noamopt: false
48
+ learning_rate: 0.000005
49
+ learning_rate_decay_start: -1
50
+
51
+ self_critical_after: 15
52
+ max_epochs: 40
53
+
54
+ verbose: false
55
+ precision: 32
56
+
57
+ use_clipscore: true
58
+ clipscore_reward_weight: 2.0
59
+ clipscore_mode: clip_s
60
+
61
+ use_multi_rewards: true
configs/phase2/clipRN50_clips.yml ADDED
@@ -0,0 +1,58 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ caption_model: transformer
2
+ noamopt: true
3
+ noamopt_warmup: 20000
4
+ label_smoothing: 0.0
5
+ input_json: data/cocotalk.json
6
+ input_label_h5: data/cocotalk_label.h5
7
+ input_fc_dir: data/cocotalk_clip_RN50_fc
8
+ input_att_dir: data/cocotalk_clip_RN50_att
9
+ input_clipscore_vis_dir: data/cocotalk_clipscore_vis
10
+ seq_per_img: 5
11
+ batch_size: 160
12
+ learning_rate: 0.0005
13
+
14
+ checkpoint_path: save/clipRN50_clips/clipRN50_clips
15
+
16
+ # Notice: because I'm to lazy, I reuse the option name for RNNs to set the hyperparameters for transformer:
17
+ # N=num_layers
18
+ # d_model=input_encoding_size
19
+ # d_ff=rnn_size
20
+
21
+ # will be ignored
22
+ num_layers: 6
23
+ input_encoding_size: 512
24
+ rnn_size: 2048
25
+
26
+ # Transformer config
27
+ N_enc: 6
28
+ N_dec: 6
29
+ d_model: 512
30
+ d_ff: 2048
31
+ num_att_heads: 8
32
+ dropout: 0.1
33
+
34
+
35
+ learning_rate_decay_start: 0
36
+ scheduled_sampling_start: -1
37
+ save_checkpoint_every: 3000
38
+ language_eval: 1
39
+ val_images_use: 5000
40
+ max_epochs: 15
41
+ train_sample_n: 5
42
+
43
+ REFORWARD: false
44
+
45
+ # _BASE_: transformer.yml
46
+ reduce_on_plateau: false
47
+ noamopt: false
48
+ learning_rate: 0.000005
49
+ learning_rate_decay_start: -1
50
+
51
+ self_critical_after: 15
52
+ max_epochs: 40
53
+
54
+ verbose: false
55
+ precision: 32
56
+
57
+ use_clipscore: true
58
+ clipscore_reward_weight: 2.0
configs/phase2/clipRN50_clips_grammar.yml ADDED
@@ -0,0 +1,64 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ caption_model: transformer
2
+ noamopt: true
3
+ noamopt_warmup: 20000
4
+ label_smoothing: 0.0
5
+ input_json: data/cocotalk.json
6
+ input_label_h5: data/cocotalk_label.h5
7
+ input_fc_dir: data/cocotalk_clip_RN50_fc
8
+ input_att_dir: data/cocotalk_clip_RN50_att
9
+ input_clipscore_vis_dir: data/cocotalk_clipscore_vis
10
+ seq_per_img: 5
11
+ batch_size: 160
12
+ learning_rate: 0.0005
13
+
14
+ checkpoint_path: save/clipRN50_clips_grammar/clipRN50_clips_grammar
15
+
16
+ use_multi_rewards: true
17
+ use_grammar: true
18
+ use_grammar_baseline: true
19
+ # clip_load_path: '/scratch-space/retrieval/save/clip_negative_text/clip_negative_text-epoch=10.ckpt'
20
+ clip_load_path: 'retrieval/save/clip_negative_text/clip_negative_text-epoch=12.ckpt'
21
+
22
+ # Notice: because I'm to lazy, I reuse the option name for RNNs to set the hyperparameters for transformer:
23
+ # N=num_layers
24
+ # d_model=input_encoding_size
25
+ # d_ff=rnn_size
26
+
27
+ # will be ignored
28
+ num_layers: 6
29
+ input_encoding_size: 512
30
+ rnn_size: 2048
31
+
32
+ # Transformer config
33
+ N_enc: 6
34
+ N_dec: 6
35
+ d_model: 512
36
+ d_ff: 2048
37
+ num_att_heads: 8
38
+ dropout: 0.1
39
+
40
+
41
+ learning_rate_decay_start: 0
42
+ scheduled_sampling_start: -1
43
+ save_checkpoint_every: 3000
44
+ language_eval: 1
45
+ val_images_use: 5000
46
+ max_epochs: 15
47
+ train_sample_n: 5
48
+
49
+ REFORWARD: false
50
+
51
+ # _BASE_: transformer.yml
52
+ reduce_on_plateau: false
53
+ noamopt: false
54
+ learning_rate: 0.000005
55
+ learning_rate_decay_start: -1
56
+
57
+ self_critical_after: 15
58
+ max_epochs: 40
59
+
60
+ verbose: false
61
+ precision: 32
62
+
63
+ use_clipscore: true
64
+ clipscore_reward_weight: 2.0
configs/phase2/transformer.yml ADDED
@@ -0,0 +1,41 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ caption_model: transformer
2
+ noamopt: true
3
+ noamopt_warmup: 20000
4
+ label_smoothing: 0.0
5
+ input_json: data/cocotalk.json
6
+ input_label_h5: data/cocotalk_label.h5
7
+ input_att_dir: data/cocotalk_att
8
+ seq_per_img: 5
9
+ batch_size: 10
10
+ learning_rate: 0.0005
11
+
12
+ checkpoint_path: ./save/trans_rn50_sc
13
+
14
+ # Notice: because I'm to lazy, I reuse the option name for RNNs to set the hyperparameters for transformer:
15
+ # N=num_layers
16
+ # d_model=input_encoding_size
17
+ # d_ff=rnn_size
18
+
19
+ # will be ignored
20
+ num_layers: 6
21
+ input_encoding_size: 512
22
+ rnn_size: 2048
23
+
24
+ # Transformer config
25
+ N_enc: 6
26
+ N_dec: 6
27
+ d_model: 512
28
+ d_ff: 2048
29
+ num_att_heads: 8
30
+ dropout: 0.1
31
+
32
+
33
+ learning_rate_decay_start: 0
34
+ scheduled_sampling_start: -1
35
+ save_checkpoint_every: 3000
36
+ language_eval: 1
37
+ val_images_use: 5000
38
+ max_epochs: 15
39
+ train_sample_n: 5
40
+
41
+ REFORWARD: false