Spaces: Runtime error

Commit 7ec5667 · 1 Parent(s): 5fb99b3
Alberto Carmona committed

Track error cloning the repo

This view is limited to 50 files because it contains too many changes.
- app.py +148 -144
- captioning/__init__.py +0 -0
- captioning/data/__init__.py +0 -0
- captioning/data/dataloader.py +0 -425
- captioning/data/pth_loader.py +0 -334
- captioning/data/pth_loader_FineCapEval.py +0 -334
- captioning/models/AoAModel.py +0 -228
- captioning/models/AttEnsemble.py +0 -90
- captioning/models/AttModel.py +0 -969
- captioning/models/BertCapModel.py +0 -104
- captioning/models/CaptionModel.py +0 -407
- captioning/models/FCModel.py +0 -204
- captioning/models/M2Transformer.py +0 -98
- captioning/models/ShowTellModel.py +0 -174
- captioning/models/TransformerModel.py +0 -363
- captioning/models/__init__.py +0 -73
- captioning/models/cachedTransformer.py +0 -420
- captioning/models/utils.py +0 -25
- captioning/modules/loss_wrapper.py +0 -127
- captioning/modules/losses.py +0 -218
- captioning/utils/__init__.py +0 -0
- captioning/utils/clipscore.py +0 -396
- captioning/utils/config.py +0 -153
- captioning/utils/dist_utils.py +0 -305
- captioning/utils/div_utils.py +0 -38
- captioning/utils/eval_multi.py +0 -218
- captioning/utils/eval_utils.py +0 -281
- captioning/utils/misc.py +0 -251
- captioning/utils/opts.py +0 -412
- captioning/utils/resnet.py +0 -71
- captioning/utils/resnet_utils.py +0 -27
- captioning/utils/rewards.py +0 -392
- captioning/utils/utils.py +0 -138
- clip/__init__.py +0 -1
- clip/bpe_simple_vocab_16e6.txt.gz +0 -3
- clip/clip.py +0 -193
- clip/model.py +0 -437
- clip/simple_tokenizer.py +0 -132
- configs/phase1/FineCapEval_clipRN50_mle.yml +0 -60
- configs/phase1/clipRN50_mle.yml +0 -52
- configs/phase1/transformer.yml +0 -41
- configs/phase2/FineCapEval_clipRN50_cider.yml +0 -61
- configs/phase2/FineCapEval_clipRN50_cider_clips.yml +0 -65
- configs/phase2/FineCapEval_clipRN50_clips.yml +0 -64
- configs/phase2/FineCapEval_clipRN50_clips_grammar.yml +0 -64
- configs/phase2/clipRN50_cider.yml +0 -58
- configs/phase2/clipRN50_cider_clips.yml +0 -61
- configs/phase2/clipRN50_clips.yml +0 -58
- configs/phase2/clipRN50_clips_grammar.yml +0 -64
- configs/phase2/transformer.yml +0 -41
app.py
CHANGED
@@ -1,194 +1,198 @@

Before:

import torch
import torch.nn as nn
import numpy as np
import json
import captioning.utils.opts as opts
import captioning.models as models
import captioning.utils.misc as utils
import pytorch_lightning as pl
import gradio as gr

from diffusers import LDMTextToImagePipeline
# import PIL.Image
import random
import os


# Checkpoint class
class ModelCheckpoint(pl.callbacks.ModelCheckpoint):
    def on_keyboard_interrupt(self, trainer, pl_module):
        # Save model when keyboard interrupt
        filepath = os.path.join(self.dirpath, self.prefix + 'interrupt.ckpt')
        self._save_model(filepath)

device = 'cpu' #@param ["cuda", "cpu"] {allow-input: true}
reward = 'clips_grammar'

cfg = f'./configs/phase2/clipRN50_{reward}.yml'

print("Loading cfg from", cfg)

opt = opts.parse_opt(parse=False, cfg=cfg)

import gdown

url = "https://drive.google.com/drive/folders/1nSX9aS7pPK4-OTHYtsUD_uEkwIQVIV7W"
gdown.download_folder(url, quiet=True, use_cookies=False, output="save/")

url = "https://drive.google.com/uc?id=1HNRE1MYO9wxmtMHLC8zURraoNFu157Dp"
gdown.download(url, quiet=True, use_cookies=False, output="data/")

dict_json = json.load(open('./data/cocotalk.json'))
print(dict_json.keys())

ix_to_word = dict_json['ix_to_word']
vocab_size = len(ix_to_word)
print('vocab size:', vocab_size)

seq_length = 1

opt.vocab_size = vocab_size
opt.seq_length = seq_length

opt.batch_size = 1
opt.vocab = ix_to_word

model = models.setup(opt)
del opt.vocab

ckpt_path = opt.checkpoint_path + '-last.ckpt'

print("Loading checkpoint from", ckpt_path)
raw_state_dict = torch.load(
    ckpt_path,
    map_location=device)

strict = True

state_dict = raw_state_dict['state_dict']

if '_vocab' in state_dict:
    model.vocab = utils.deserialize(state_dict['_vocab'])
    del state_dict['_vocab']
elif strict:
    raise KeyError
if '_opt' in state_dict:
    saved_model_opt = utils.deserialize(state_dict['_opt'])
    del state_dict['_opt']
    # Make sure the saved opt is compatible with the curren topt
    need_be_same = ["caption_model",
                    "rnn_type", "rnn_size", "num_layers"]
    for checkme in need_be_same:
        if getattr(saved_model_opt, checkme) in ['updown', 'topdown'] and \
            getattr(opt, checkme) in ['updown', 'topdown']:
            continue
        assert getattr(saved_model_opt, checkme) == getattr(
            opt, checkme), "Command line argument and saved model disagree on '%s' " % checkme
elif strict:
    raise KeyError
res = model.load_state_dict(state_dict, strict)
print(res)

model = model.to(device)
model.eval();

import clip
from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize
from PIL import Image
from timm.models.vision_transformer import resize_pos_embed

clip_model, clip_transform = clip.load("RN50", jit=False, device=device)

preprocess = Compose([
    Resize((448, 448), interpolation=Image.BICUBIC),
    CenterCrop((448, 448)),
    ToTensor()
])

image_mean = torch.Tensor([0.48145466, 0.4578275, 0.40821073]).to(device).reshape(3, 1, 1)
image_std = torch.Tensor([0.26862954, 0.26130258, 0.27577711]).to(device).reshape(3, 1, 1)

num_patches = 196 #600 * 1000 // 32 // 32
pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, clip_model.visual.attnpool.positional_embedding.shape[-1], device=device),)
pos_embed.weight = resize_pos_embed(clip_model.visual.attnpool.positional_embedding.unsqueeze(0), pos_embed)
clip_model.visual.attnpool.positional_embedding = pos_embed


# End below
print('Loading the model: CompVis/ldm-text2im-large-256')
ldm_pipeline = LDMTextToImagePipeline.from_pretrained("CompVis/ldm-text2im-large-256")

def generate_image_from_text(prompt, steps=100, seed=42, guidance_scale=6.0):
    print('RUN: generate_image_from_text')
    torch.cuda.empty_cache()
    generator = torch.manual_seed(seed)
    images = ldm_pipeline([prompt], generator=generator, num_inference_steps=steps, eta=0.3, guidance_scale=guidance_scale)["sample"]
    return images[0]

def generate_text_from_image(img):
    print('RUN: generate_text_from_image')
    with torch.no_grad():
        image = preprocess(img)
        image = torch.tensor(np.stack([image])).to(device)
        image -= image_mean
        image /= image_std

        tmp_att, tmp_fc = clip_model.encode_image(image)
        tmp_att = tmp_att[0].permute(1, 2, 0)
        tmp_fc = tmp_fc[0]

    att_feat = tmp_att
    fc_feat = tmp_fc

    # Inference configurations
    eval_kwargs = {}
    eval_kwargs.update(vars(opt))

    verbose = eval_kwargs.get('verbose', True)
    verbose_beam = eval_kwargs.get('verbose_beam', 0)
    verbose_loss = eval_kwargs.get('verbose_loss', 1)

    # dataset = eval_kwargs.get('dataset', 'coco')
    beam_size = eval_kwargs.get('beam_size', 1)
    sample_n = eval_kwargs.get('sample_n', 1)
    remove_bad_endings = eval_kwargs.get('remove_bad_endings', 0)

    with torch.no_grad():
        fc_feats = torch.zeros((1,0)).to(device)
        att_feats = att_feat.view(1, 196, 2048).float().to(device)
        att_masks = None

        # forward the model to also get generated samples for each image
        # Only leave one feature for each image, in case duplicate sample
        tmp_eval_kwargs = eval_kwargs.copy()
        tmp_eval_kwargs.update({'sample_n': 1})
        seq, seq_logprobs = model(
            fc_feats, att_feats, att_masks, opt=tmp_eval_kwargs, mode='sample')
        seq = seq.data

    sents = utils.decode_sequence(model.vocab, seq)

    return sents[0]


def generate_drawing_from_image(img, steps=100, seed=42, guidance_scale=6.0):
    print('RUN: generate_drawing_from_image')
    caption = generate_text_from_image(img)
    gen_image = generate_image_from_text(caption, steps=steps, seed=seed, guidance_scale=guidance_scale)
    return gen_image


random_seed = random.randint(0, 2147483647)

gr.Interface(
    generate_drawing_from_image,
    ...
).launch()

After:

# import torch
# import torch.nn as nn
# import numpy as np
# import json
# import captioning.utils.opts as opts
# import captioning.models as models
# import captioning.utils.misc as utils
# import pytorch_lightning as pl
import gradio as gr

# from diffusers import LDMTextToImagePipeline
# # import PIL.Image
import random
# import os


# # Checkpoint class
# class ModelCheckpoint(pl.callbacks.ModelCheckpoint):
#     def on_keyboard_interrupt(self, trainer, pl_module):
#         # Save model when keyboard interrupt
#         filepath = os.path.join(self.dirpath, self.prefix + 'interrupt.ckpt')
#         self._save_model(filepath)

# device = 'cpu' #@param ["cuda", "cpu"] {allow-input: true}
# reward = 'clips_grammar'

# cfg = f'./configs/phase2/clipRN50_{reward}.yml'

# print("Loading cfg from", cfg)

# opt = opts.parse_opt(parse=False, cfg=cfg)

# import gdown

# url = "https://drive.google.com/drive/folders/1nSX9aS7pPK4-OTHYtsUD_uEkwIQVIV7W"
# gdown.download_folder(url, quiet=True, use_cookies=False, output="save/")

# url = "https://drive.google.com/uc?id=1HNRE1MYO9wxmtMHLC8zURraoNFu157Dp"
# gdown.download(url, quiet=True, use_cookies=False, output="data/")

# dict_json = json.load(open('./data/cocotalk.json'))
# print(dict_json.keys())

# ix_to_word = dict_json['ix_to_word']
# vocab_size = len(ix_to_word)
# print('vocab size:', vocab_size)

# seq_length = 1

# opt.vocab_size = vocab_size
# opt.seq_length = seq_length

# opt.batch_size = 1
# opt.vocab = ix_to_word

# model = models.setup(opt)
# del opt.vocab

# ckpt_path = opt.checkpoint_path + '-last.ckpt'

# print("Loading checkpoint from", ckpt_path)
# raw_state_dict = torch.load(
#     ckpt_path,
#     map_location=device)

# strict = True

# state_dict = raw_state_dict['state_dict']

# if '_vocab' in state_dict:
#     model.vocab = utils.deserialize(state_dict['_vocab'])
#     del state_dict['_vocab']
# elif strict:
#     raise KeyError
# if '_opt' in state_dict:
#     saved_model_opt = utils.deserialize(state_dict['_opt'])
#     del state_dict['_opt']
#     # Make sure the saved opt is compatible with the curren topt
#     need_be_same = ["caption_model",
#                     "rnn_type", "rnn_size", "num_layers"]
#     for checkme in need_be_same:
#         if getattr(saved_model_opt, checkme) in ['updown', 'topdown'] and \
#             getattr(opt, checkme) in ['updown', 'topdown']:
#             continue
#         assert getattr(saved_model_opt, checkme) == getattr(
#             opt, checkme), "Command line argument and saved model disagree on '%s' " % checkme
# elif strict:
#     raise KeyError
# res = model.load_state_dict(state_dict, strict)
# print(res)

# model = model.to(device)
# model.eval();

# import clip
# from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize
# from PIL import Image
# from timm.models.vision_transformer import resize_pos_embed

# clip_model, clip_transform = clip.load("RN50", jit=False, device=device)

# preprocess = Compose([
#     Resize((448, 448), interpolation=Image.BICUBIC),
#     CenterCrop((448, 448)),
#     ToTensor()
# ])

# image_mean = torch.Tensor([0.48145466, 0.4578275, 0.40821073]).to(device).reshape(3, 1, 1)
# image_std = torch.Tensor([0.26862954, 0.26130258, 0.27577711]).to(device).reshape(3, 1, 1)

# num_patches = 196 #600 * 1000 // 32 // 32
# pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, clip_model.visual.attnpool.positional_embedding.shape[-1], device=device),)
# pos_embed.weight = resize_pos_embed(clip_model.visual.attnpool.positional_embedding.unsqueeze(0), pos_embed)
# clip_model.visual.attnpool.positional_embedding = pos_embed


# # End below
# print('Loading the model: CompVis/ldm-text2im-large-256')
# ldm_pipeline = LDMTextToImagePipeline.from_pretrained("CompVis/ldm-text2im-large-256")

# def generate_image_from_text(prompt, steps=100, seed=42, guidance_scale=6.0):
#     print('RUN: generate_image_from_text')
#     torch.cuda.empty_cache()
#     generator = torch.manual_seed(seed)
#     images = ldm_pipeline([prompt], generator=generator, num_inference_steps=steps, eta=0.3, guidance_scale=guidance_scale)["sample"]
#     return images[0]

# def generate_text_from_image(img):
#     print('RUN: generate_text_from_image')
#     with torch.no_grad():
#         image = preprocess(img)
#         image = torch.tensor(np.stack([image])).to(device)
#         image -= image_mean
#         image /= image_std

#         tmp_att, tmp_fc = clip_model.encode_image(image)
#         tmp_att = tmp_att[0].permute(1, 2, 0)
#         tmp_fc = tmp_fc[0]

#     att_feat = tmp_att
#     fc_feat = tmp_fc

#     # Inference configurations
#     eval_kwargs = {}
#     eval_kwargs.update(vars(opt))

#     verbose = eval_kwargs.get('verbose', True)
#     verbose_beam = eval_kwargs.get('verbose_beam', 0)
#     verbose_loss = eval_kwargs.get('verbose_loss', 1)

#     # dataset = eval_kwargs.get('dataset', 'coco')
#     beam_size = eval_kwargs.get('beam_size', 1)
#     sample_n = eval_kwargs.get('sample_n', 1)
#     remove_bad_endings = eval_kwargs.get('remove_bad_endings', 0)

#     with torch.no_grad():
#         fc_feats = torch.zeros((1,0)).to(device)
#         att_feats = att_feat.view(1, 196, 2048).float().to(device)
#         att_masks = None

#         # forward the model to also get generated samples for each image
#         # Only leave one feature for each image, in case duplicate sample
#         tmp_eval_kwargs = eval_kwargs.copy()
#         tmp_eval_kwargs.update({'sample_n': 1})
#         seq, seq_logprobs = model(
#             fc_feats, att_feats, att_masks, opt=tmp_eval_kwargs, mode='sample')
#         seq = seq.data

#     sents = utils.decode_sequence(model.vocab, seq)

#     return sents[0]


# def generate_drawing_from_image(img, steps=100, seed=42, guidance_scale=6.0):
#     print('RUN: generate_drawing_from_image')
#     caption = generate_text_from_image(img)
#     gen_image = generate_image_from_text(caption, steps=steps, seed=seed, guidance_scale=guidance_scale)
#     return gen_image


random_seed = random.randint(0, 2147483647)

def test_fn(**kwargs):
    return None


gr.Interface(
    # generate_drawing_from_image,
    test_fn,
    inputs=[
        gr.Image(type="pil"),
        gr.inputs.Slider(1, 100, label='Inference Steps', default=50, step=1),
        gr.inputs.Slider(0, 2147483647, label='Seed', default=random_seed, step=1),
        gr.inputs.Slider(1.0, 20.0, label='Guidance Scale - how much the prompt will influence the results', default=6.0, step=0.1),
    ],
    outputs=gr.Image(shape=[256,256], type="pil", elem_id="output_image"),
    css="#output_image{width: 256px}",
).launch()
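The new app.py keeps the whole captioning/diffusion pipeline only as comments and points gr.Interface at a no-op test_fn, a common way to check whether the Space's runtime error comes from cloning and loading the model repositories rather than from the UI itself. Below is a minimal sketch of the same debugging stub, assuming a Gradio 3.x-style API (gr.Slider in place of the deprecated gr.inputs.Slider); component signatures vary across Gradio versions, so treat the exact arguments as illustrative rather than as this Space's code.

import random
import gradio as gr

random_seed = random.randint(0, 2147483647)

def test_fn(image, steps, seed, guidance_scale):
    # Stand-in for generate_drawing_from_image: echo the input image so the
    # UI can be exercised without downloading checkpoints or loading models.
    return image

gr.Interface(
    test_fn,
    inputs=[
        gr.Image(type="pil"),
        gr.Slider(1, 100, value=50, step=1, label="Inference Steps"),
        gr.Slider(0, 2147483647, value=random_seed, step=1, label="Seed"),
        gr.Slider(1.0, 20.0, value=6.0, step=0.1, label="Guidance Scale"),
    ],
    outputs=gr.Image(type="pil"),
).launch()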
captioning/__init__.py
DELETED
File without changes

captioning/data/__init__.py
DELETED
File without changes
captioning/data/dataloader.py
DELETED
@@ -1,425 +0,0 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import json
import h5py
from lmdbdict import lmdbdict
from lmdbdict.methods import DUMPS_FUNC, LOADS_FUNC
import os
import numpy as np
import numpy.random as npr
import random
from functools import partial

import torch
import torch.utils.data as data

import multiprocessing
import six

class HybridLoader:
    """
    If db_path is a director, then use normal file loading
    If lmdb, then load from lmdb
    The loading method depend on extention.

    in_memory: if in_memory is True, we save all the features in memory
               For individual np(y|z)s, we don't need to do that because the system will do this for us.
               Should be useful for lmdb or h5.
               (Copied this idea from vilbert)
    """
    def __init__(self, db_path, ext, in_memory=False):
        self.db_path = db_path
        self.ext = ext
        if self.ext == '.npy':
            self.loader = lambda x: np.load(six.BytesIO(x))
        else:
            def load_npz(x):
                x = np.load(six.BytesIO(x))
                return x['feat'] if 'feat' in x else x['z']  # normally it should be 'feat', but under cocotest_bu, the key is saved to be 'z' mistakenly.
            self.loader = load_npz
        if db_path.endswith('.lmdb'):
            self.db_type = 'lmdb'
            self.lmdb = lmdbdict(db_path, unsafe=True)
            self.lmdb._key_dumps = DUMPS_FUNC['ascii']
            self.lmdb._value_loads = LOADS_FUNC['identity']
        elif db_path.endswith('.pth'):  # Assume a key,value dictionary
            self.db_type = 'pth'
            self.feat_file = torch.load(db_path)
            self.loader = lambda x: x
            print('HybridLoader: ext is ignored')
        elif db_path.endswith('h5'):
            self.db_type = 'h5'
            self.loader = lambda x: np.array(x).astype('float32')
        else:
            self.db_type = 'dir'

        self.in_memory = in_memory
        if self.in_memory:
            self.features = {}

    def get(self, key):

        if self.in_memory and key in self.features:
            # We save f_input because we want to save the
            # compressed bytes to save memory
            f_input = self.features[key]
        elif self.db_type == 'lmdb':
            f_input = self.lmdb[key]
        elif self.db_type == 'pth':
            f_input = self.feat_file[key]
        elif self.db_type == 'h5':
            f_input = h5py.File(self.db_path, 'r')[key]
        else:
            f_input = open(os.path.join(self.db_path, key + self.ext), 'rb').read()

        if self.in_memory and key not in self.features:
            self.features[key] = f_input

        # load image
        feat = self.loader(f_input)

        return feat

class Dataset(data.Dataset):

    def get_vocab_size(self):
        return self.vocab_size

    def get_vocab(self):
        return self.ix_to_word

    def get_seq_length(self):
        return self.seq_length

    def __init__(self, opt):
        self.opt = opt
        self.seq_per_img = opt.seq_per_img

        # feature related options
        self.use_fc = getattr(opt, 'use_fc', True)
        self.use_att = getattr(opt, 'use_att', True)
        self.use_box = getattr(opt, 'use_box', 0)
        self.norm_att_feat = getattr(opt, 'norm_att_feat', 0)
        self.norm_box_feat = getattr(opt, 'norm_box_feat', 0)

        # load the json file which contains additional information about the dataset
        print('DataLoader loading json file: ', opt.input_json)
        self.info = json.load(open(self.opt.input_json))
        if 'ix_to_word' in self.info:
            self.ix_to_word = self.info['ix_to_word']
            self.vocab_size = len(self.ix_to_word)
            print('vocab size is ', self.vocab_size)

        # open the hdf5 file
        print('DataLoader loading h5 file: ', opt.input_fc_dir, opt.input_att_dir, opt.input_box_dir, opt.input_label_h5)
        """
        Setting input_label_h5 to none is used when only doing generation.
        For example, when you need to test on coco test set.
        """
        if self.opt.input_label_h5 != 'none':
            self.h5_label_file = h5py.File(self.opt.input_label_h5, 'r', driver='core')
            # load in the sequence data
            seq_size = self.h5_label_file['labels'].shape
            self.label = self.h5_label_file['labels'][:]
            self.seq_length = seq_size[1]
            print('max sequence length in data is', self.seq_length)
            # load the pointers in full to RAM (should be small enough)
            self.label_start_ix = self.h5_label_file['label_start_ix'][:]
            self.label_end_ix = self.h5_label_file['label_end_ix'][:]
        else:
            self.seq_length = 1

        self.data_in_memory = getattr(opt, 'data_in_memory', False)
        self.fc_loader = HybridLoader(self.opt.input_fc_dir, '.npy', in_memory=self.data_in_memory)
        self.att_loader = HybridLoader(self.opt.input_att_dir, '.npz', in_memory=self.data_in_memory)
        self.box_loader = HybridLoader(self.opt.input_box_dir, '.npy', in_memory=self.data_in_memory)

        self.num_images = len(self.info['images'])  # self.label_start_ix.shape[0]
        print('read %d image features' %(self.num_images))

        # separate out indexes for each of the provided splits
        self.split_ix = {'train': [], 'val': [], 'test': []}
        for ix in range(len(self.info['images'])):
            img = self.info['images'][ix]
            if not 'split' in img:
                self.split_ix['train'].append(ix)
                self.split_ix['val'].append(ix)
                self.split_ix['test'].append(ix)
            elif img['split'] == 'train':
                self.split_ix['train'].append(ix)
            elif img['split'] == 'val':
                self.split_ix['val'].append(ix)
            elif img['split'] == 'test':
                self.split_ix['test'].append(ix)
            elif opt.train_only == 0:  # restval
                self.split_ix['train'].append(ix)

        print('assigned %d images to split train' %len(self.split_ix['train']))
        print('assigned %d images to split val' %len(self.split_ix['val']))
        print('assigned %d images to split test' %len(self.split_ix['test']))

    def get_captions(self, ix, seq_per_img):
        # fetch the sequence labels
        ix1 = self.label_start_ix[ix] - 1  # label_start_ix starts from 1
        ix2 = self.label_end_ix[ix] - 1
        ncap = ix2 - ix1 + 1  # number of captions available for this image
        assert ncap > 0, 'an image does not have any label. this can be handled but right now isn\'t'

        if ncap < seq_per_img:
            # we need to subsample (with replacement)
            seq = np.zeros([seq_per_img, self.seq_length], dtype = 'int')
            for q in range(seq_per_img):
                ixl = random.randint(ix1,ix2)
                seq[q, :] = self.label[ixl, :self.seq_length]
        else:
            ixl = random.randint(ix1, ix2 - seq_per_img + 1)
            seq = self.label[ixl: ixl + seq_per_img, :self.seq_length]

        return seq

    def collate_func(self, batch, split):
        seq_per_img = self.seq_per_img

        fc_batch = []
        att_batch = []
        label_batch = []

        wrapped = False

        infos = []
        gts = []

        for sample in batch:
            # fetch image
            tmp_fc, tmp_att, tmp_seq, \
                ix, it_pos_now, tmp_wrapped = sample
            if tmp_wrapped:
                wrapped = True

            fc_batch.append(tmp_fc)
            att_batch.append(tmp_att)

            tmp_label = np.zeros([seq_per_img, self.seq_length + 2], dtype = 'int')
            if hasattr(self, 'h5_label_file'):
                # if there is ground truth
                tmp_label[:, 1 : self.seq_length + 1] = tmp_seq
            label_batch.append(tmp_label)

            # Used for reward evaluation
            if hasattr(self, 'h5_label_file'):
                # if there is ground truth
                gts.append(self.label[self.label_start_ix[ix] - 1: self.label_end_ix[ix]])
            else:
                gts.append([])

            # record associated info as well
            info_dict = {}
            info_dict['ix'] = ix
            info_dict['id'] = self.info['images'][ix]['id']
            info_dict['file_path'] = self.info['images'][ix].get('file_path', '')
            infos.append(info_dict)

        # #sort by att_feat length
        # fc_batch, att_batch, label_batch, gts, infos = \
        #     zip(*sorted(zip(fc_batch, att_batch, np.vsplit(label_batch, batch_size), gts, infos), key=lambda x: len(x[1]), reverse=True))
        fc_batch, att_batch, label_batch, gts, infos = \
            zip(*sorted(zip(fc_batch, att_batch, label_batch, gts, infos), key=lambda x: 0, reverse=True))
        data = {}
        data['fc_feats'] = np.stack(fc_batch)
        # merge att_feats
        max_att_len = max([_.shape[0] for _ in att_batch])
        data['att_feats'] = np.zeros([len(att_batch), max_att_len, att_batch[0].shape[1]], dtype = 'float32')
        for i in range(len(att_batch)):
            data['att_feats'][i, :att_batch[i].shape[0]] = att_batch[i]
        data['att_masks'] = np.zeros(data['att_feats'].shape[:2], dtype='float32')
        for i in range(len(att_batch)):
            data['att_masks'][i, :att_batch[i].shape[0]] = 1
        # set att_masks to None if attention features have same length
        if data['att_masks'].sum() == data['att_masks'].size:
            data['att_masks'] = None

        data['labels'] = np.vstack(label_batch)
        # generate mask
        nonzeros = np.array(list(map(lambda x: (x != 0).sum()+2, data['labels'])))
        mask_batch = np.zeros([data['labels'].shape[0], self.seq_length + 2], dtype = 'float32')
        for ix, row in enumerate(mask_batch):
            row[:nonzeros[ix]] = 1
        data['masks'] = mask_batch
        data['labels'] = data['labels'].reshape(len(batch), seq_per_img, -1)
        data['masks'] = data['masks'].reshape(len(batch), seq_per_img, -1)

        data['gts'] = gts  # all ground truth captions of each images
        data['bounds'] = {'it_pos_now': it_pos_now,  # the it_pos_now of the last sample
                          'it_max': len(self.split_ix[split]), 'wrapped': wrapped}
        data['infos'] = infos

        data = {k:torch.from_numpy(v) if type(v) is np.ndarray else v for k,v in data.items()}  # Turn all ndarray to torch tensor

        return data

    def __getitem__(self, index):
        """This function returns a tuple that is further passed to collate_fn
        """
        ix, it_pos_now, wrapped = index  #self.split_ix[index]
        if self.use_att:
            att_feat = self.att_loader.get(str(self.info['images'][ix]['id']))
            # Reshape to K x C
            att_feat = att_feat.reshape(-1, att_feat.shape[-1])
            if self.norm_att_feat:
                att_feat = att_feat / np.linalg.norm(att_feat, 2, 1, keepdims=True)
            if self.use_box:
                box_feat = self.box_loader.get(str(self.info['images'][ix]['id']))
                # devided by image width and height
                x1,y1,x2,y2 = np.hsplit(box_feat, 4)
                h,w = self.info['images'][ix]['height'], self.info['images'][ix]['width']
                box_feat = np.hstack((x1/w, y1/h, x2/w, y2/h, (x2-x1)*(y2-y1)/(w*h)))  # question? x2-x1+1??
                if self.norm_box_feat:
                    box_feat = box_feat / np.linalg.norm(box_feat, 2, 1, keepdims=True)
                att_feat = np.hstack([att_feat, box_feat])
                # sort the features by the size of boxes
                att_feat = np.stack(sorted(att_feat, key=lambda x:x[-1], reverse=True))
        else:
            att_feat = np.zeros((0,0), dtype='float32')
        if self.use_fc:
            try:
                fc_feat = self.fc_loader.get(str(self.info['images'][ix]['id']))
            except:
                # Use average of attention when there is no fc provided (For bottomup feature)
                fc_feat = att_feat.mean(0)
        else:
            fc_feat = np.zeros((0), dtype='float32')
        if hasattr(self, 'h5_label_file'):
            seq = self.get_captions(ix, self.seq_per_img)
        else:
            seq = None
        return (fc_feat,
                att_feat, seq,
                ix, it_pos_now, wrapped)

    def __len__(self):
        return len(self.info['images'])

class DataLoader:
    def __init__(self, opt):
        self.opt = opt
        self.batch_size = self.opt.batch_size
        self.dataset = Dataset(opt)

        # Initialize loaders and iters
        self.loaders, self.iters = {}, {}
        for split in ['train', 'val', 'test']:
            if split == 'train':
                sampler = MySampler(self.dataset.split_ix[split], shuffle=True, wrap=True)
            else:
                sampler = MySampler(self.dataset.split_ix[split], shuffle=False, wrap=False)
            self.loaders[split] = data.DataLoader(dataset=self.dataset,
                                                  batch_size=self.batch_size,
                                                  sampler=sampler,
                                                  pin_memory=True,
                                                  num_workers=4,  # 4 is usually enough
                                                  collate_fn=partial(self.dataset.collate_func, split=split),
                                                  drop_last=False)
            self.iters[split] = iter(self.loaders[split])

    def get_batch(self, split):
        try:
            data = next(self.iters[split])
        except StopIteration:
            self.iters[split] = iter(self.loaders[split])
            data = next(self.iters[split])
        return data

    def reset_iterator(self, split):
        self.loaders[split].sampler._reset_iter()
        self.iters[split] = iter(self.loaders[split])

    def get_vocab_size(self):
        return self.dataset.get_vocab_size()

    @property
    def vocab_size(self):
        return self.get_vocab_size()

    def get_vocab(self):
        return self.dataset.get_vocab()

    def get_seq_length(self):
        return self.dataset.get_seq_length()

    @property
    def seq_length(self):
        return self.get_seq_length()

    def state_dict(self):
        def get_prefetch_num(split):
            if self.loaders[split].num_workers > 0:
                return (self.iters[split]._send_idx - self.iters[split]._rcvd_idx) * self.batch_size
            else:
                return 0
        return {split: loader.sampler.state_dict(get_prefetch_num(split)) \
                    for split, loader in self.loaders.items()}

    def load_state_dict(self, state_dict=None):
        if state_dict is None:
            return
        for split in self.loaders.keys():
            self.loaders[split].sampler.load_state_dict(state_dict[split])


class MySampler(data.sampler.Sampler):
    def __init__(self, index_list, shuffle, wrap):
        self.index_list = index_list
        self.shuffle = shuffle
        self.wrap = wrap
        # if wrap, there will be not stop iteration called
        # wrap True used during training, and wrap False used during test.
        self._reset_iter()

    def __iter__(self):
        return self

    def __next__(self):
        wrapped = False
        if self.iter_counter == len(self._index_list):
            self._reset_iter()
            if self.wrap:
                wrapped = True
            else:
                raise StopIteration()
        if len(self._index_list) == 0:  # overflow when 0 samples
            return None
        elem = (self._index_list[self.iter_counter], self.iter_counter+1, wrapped)
        self.iter_counter += 1
        return elem

    def next(self):
        return self.__next__()

    def _reset_iter(self):
        if self.shuffle:
            rand_perm = npr.permutation(len(self.index_list))
            self._index_list = [self.index_list[_] for _ in rand_perm]
        else:
            self._index_list = self.index_list

        self.iter_counter = 0

    def __len__(self):
        return len(self.index_list)

    def load_state_dict(self, state_dict=None):
        if state_dict is None:
            return
        self._index_list = state_dict['index_list']
        self.iter_counter = state_dict['iter_counter']

    def state_dict(self, prefetched_num=None):
        prefetched_num = prefetched_num or 0
        return {
            'index_list': self._index_list,
            'iter_counter': self.iter_counter - prefetched_num
        }
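For context on how the deleted DataLoader above was driven: a hedged usage sketch. The option values and feature paths below are illustrative placeholders (real values come from captioning/utils/opts.py and the configs/ files), and the import refers to the module as it existed before this commit.

from argparse import Namespace
from captioning.data.dataloader import DataLoader  # pre-deletion module path

opt = Namespace(
    batch_size=10, seq_per_img=5, train_only=0,
    input_json='data/cocotalk.json',          # vocab and split info
    input_label_h5='data/cocotalk_label.h5',  # tokenized captions; 'none' means generation only
    input_fc_dir='data/cocotalk_fc',          # .npy global features
    input_att_dir='data/cocotalk_att',        # .npz attention features
    input_box_dir='data/cocotalk_box',        # .npy box features (unused with use_box=0)
    use_fc=True, use_att=True, use_box=0,
    norm_att_feat=0, norm_box_feat=0, data_in_memory=False,
)

loader = DataLoader(opt)
batch = loader.get_batch('train')   # MySampler(wrap=True) keeps the train split cycling
print(batch['att_feats'].shape, batch['labels'].shape)

# Mid-epoch checkpointing goes through the sampler state:
sampler_state = loader.state_dict()
loader.load_state_dict(sampler_state)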
captioning/data/pth_loader.py
DELETED
@@ -1,334 +0,0 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import json
import h5py
from lmdbdict import lmdbdict
from lmdbdict.methods import DUMPS_FUNC, LOADS_FUNC
import os
import numpy as np
import numpy.random as npr
import random

import torch
import torch.utils.data as data

import multiprocessing
import six

verbose = True
# import torch
# if torch.cuda.current_device() in [0, -1]:
if 'LOCAL_RANK' in os.environ and os.environ['LOCAL_RANK'] != '0':
    verbose = False

class HybridLoader:
    """
    If db_path is a director, then use normal file loading
    If lmdb, then load from lmdb
    The loading method depend on extention.

    in_memory: if in_memory is True, we save all the features in memory
               For individual np(y|z)s, we don't need to do that because the system will do this for us.
               Should be useful for lmdb or h5.
               (Copied this idea from vilbert)
    """
    def __init__(self, db_path, ext, in_memory=False):
        self.db_path = db_path
        self.ext = ext
        if self.ext == '.npy':
            self.loader = lambda x: np.load(six.BytesIO(x))
        else:
            self.loader = lambda x: np.load(six.BytesIO(x))['feat']
        if db_path.endswith('.lmdb'):
            self.db_type = 'lmdb'
            self.lmdb = lmdbdict(db_path, unsafe=True)
            self.lmdb._key_dumps = DUMPS_FUNC['ascii']
            self.lmdb._value_loads = LOADS_FUNC['identity']
        elif db_path.endswith('.pth'):  # Assume a key,value dictionary
            self.db_type = 'pth'
            self.feat_file = torch.load(db_path)
            self.loader = lambda x: x
            print('HybridLoader: ext is ignored')
        elif db_path.endswith('h5'):
            self.db_type = 'h5'
            self.loader = lambda x: np.array(x).astype('float32')
        else:
            self.db_type = 'dir'

        self.in_memory = in_memory
        if self.in_memory:
            self.features = {}

    def get(self, key):

        if self.in_memory and key in self.features:
            # We save f_input because we want to save the
            # compressed bytes to save memory
            f_input = self.features[key]
        elif self.db_type == 'lmdb':
            f_input = self.lmdb[key]
        elif self.db_type == 'pth':
            f_input = self.feat_file[key]
        elif self.db_type == 'h5':
            f_input = h5py.File(self.db_path, 'r')[key]
        else:
            f_input = open(os.path.join(self.db_path, key + self.ext), 'rb').read()

        if self.in_memory and key not in self.features:
            self.features[key] = f_input

        # load image
        feat = self.loader(f_input)

        return feat

class CaptionDataset(data.Dataset):

    def get_vocab_size(self):
        return self.vocab_size

    def get_vocab(self):
        return self.ix_to_word

    def get_seq_length(self):
        return self.seq_length

    def __init__(self, opt):
        self.opt = opt
        self.seq_per_img = opt.seq_per_img

        # feature related options
        self.use_fc = getattr(opt, 'use_fc', True)
        self.use_att = getattr(opt, 'use_att', True)
        self.use_box = getattr(opt, 'use_box', 0)
        self.norm_att_feat = getattr(opt, 'norm_att_feat', 0)
        self.norm_box_feat = getattr(opt, 'norm_box_feat', 0)

        # load the json file which contains additional information about the dataset
        if verbose:
            print('DataLoader loading json file: ', opt.input_json)
        self.info = json.load(open(self.opt.input_json))
        if 'ix_to_word' in self.info:
            self.ix_to_word = self.info['ix_to_word']
            self.vocab_size = len(self.ix_to_word)
            if verbose:
                print('vocab size is ', self.vocab_size)

        # open the hdf5 file
        if verbose:
            print('DataLoader loading h5 file: ', opt.input_fc_dir, opt.input_att_dir, opt.input_box_dir, opt.input_label_h5)
        """
        Setting input_label_h5 to none is used when only doing generation.
        For example, when you need to test on coco test set.
        """
        if self.opt.input_label_h5 != 'none':
            self.h5_label_file = h5py.File(self.opt.input_label_h5, 'r', driver='core')
            # load in the sequence data
            seq_size = self.h5_label_file['labels'].shape
            self.label = self.h5_label_file['labels'][:]
            self.seq_length = seq_size[1]
            if verbose:
                print('max sequence length in data is', self.seq_length)
            # load the pointers in full to RAM (should be small enough)
            self.label_start_ix = self.h5_label_file['label_start_ix'][:]
            self.label_end_ix = self.h5_label_file['label_end_ix'][:]
        else:
            self.seq_length = 1

        self.data_in_memory = getattr(opt, 'data_in_memory', False)
        self.fc_loader = HybridLoader(self.opt.input_fc_dir, '.npy', in_memory=self.data_in_memory)
        self.att_loader = HybridLoader(self.opt.input_att_dir, '.npz', in_memory=self.data_in_memory)
        self.box_loader = HybridLoader(self.opt.input_box_dir, '.npy', in_memory=self.data_in_memory)

        self.use_clipscore = getattr(opt, 'use_clipscore', False)
        # if self.use_clipscore:
        self.clipscore_loader = HybridLoader(self.opt.input_clipscore_vis_dir, '.npy', in_memory=self.data_in_memory)


        self.num_images = len(self.info['images'])  # self.label_start_ix.shape[0]
        if verbose:
            print('read %d image features' %(self.num_images))

        # separate out indexes for each of the provided splits
        self.split_ix = {'train': [], 'val': [], 'test': []}
        for ix in range(len(self.info['images'])):
            img = self.info['images'][ix]
            if not 'split' in img:
                self.split_ix['train'].append(ix)
                self.split_ix['val'].append(ix)
                self.split_ix['test'].append(ix)
            elif img['split'] == 'train':
                self.split_ix['train'].append(ix)
            elif img['split'] == 'val':
                self.split_ix['val'].append(ix)
            elif img['split'] == 'test':
                self.split_ix['test'].append(ix)
            elif opt.train_only == 0:  # restval
                self.split_ix['train'].append(ix)

        if verbose:
            print('assigned %d images to split train' %len(self.split_ix['train']))
            print('assigned %d images to split val' %len(self.split_ix['val']))
            print('assigned %d images to split test' %len(self.split_ix['test']))

    def get_captions(self, ix, seq_per_img):
        # fetch the sequence labels
        ix1 = self.label_start_ix[ix] - 1  # label_start_ix starts from 1
        ix2 = self.label_end_ix[ix] - 1
        ncap = ix2 - ix1 + 1  # number of captions available for this image
        assert ncap > 0, 'an image does not have any label. this can be handled but right now isn\'t'

        if ncap < seq_per_img:
            # we need to subsample (with replacement)
            seq = np.zeros([seq_per_img, self.seq_length], dtype = 'int')
            for q in range(seq_per_img):
                ixl = random.randint(ix1,ix2)
                seq[q, :] = self.label[ixl, :self.seq_length]
        else:
            ixl = random.randint(ix1, ix2 - seq_per_img + 1)
            seq = self.label[ixl: ixl + seq_per_img, :self.seq_length]

        return seq

    def collate_func(self, batch):
        seq_per_img = self.seq_per_img

        fc_batch = []
        att_batch = []
        label_batch = []

        clip_vis_feat_batch = []

        wrapped = False

        infos = []
        gts = []

        for sample in batch:
            # fetch image
            # if self.use_clipscore:
            tmp_fc, tmp_att, tmp_seq, \
                ix, tmp_clip_vis_feat = sample

            clip_vis_feat_batch.append(tmp_clip_vis_feat)
            # else:
            #     tmp_fc, tmp_att, tmp_seq, \
            #         ix = sample

            fc_batch.append(tmp_fc)
            att_batch.append(tmp_att)

            tmp_label = np.zeros([seq_per_img, self.seq_length + 2], dtype = 'int')
            if hasattr(self, 'h5_label_file'):
                # if there is ground truth
                tmp_label[:, 1 : self.seq_length + 1] = tmp_seq
            label_batch.append(tmp_label)

            # Used for reward evaluation
            if hasattr(self, 'h5_label_file'):
                # if there is ground truth
                gts.append(self.label[self.label_start_ix[ix] - 1: self.label_end_ix[ix]])
            else:
                gts.append([])

            # record associated info as well
            info_dict = {}
            info_dict['ix'] = ix
            info_dict['id'] = self.info['images'][ix]['id']
            info_dict['file_path'] = self.info['images'][ix].get('file_path', '')
            infos.append(info_dict)

        # #sort by att_feat length
        # fc_batch, att_batch, label_batch, gts, infos = \
        #     zip(*sorted(zip(fc_batch, att_batch, np.vsplit(label_batch, batch_size), gts, infos), key=lambda x: len(x[1]), reverse=True))
        if self.use_clipscore:
            fc_batch, att_batch, label_batch, clip_vis_feat_batch, gts, infos = \
                zip(*sorted(zip(fc_batch, att_batch, label_batch, clip_vis_feat_batch, gts, infos), key=lambda x: 0, reverse=True))
        else:
            fc_batch, att_batch, label_batch, gts, infos = \
                zip(*sorted(zip(fc_batch, att_batch, label_batch, gts, infos), key=lambda x: 0, reverse=True))
        data = {}
        data['fc_feats'] = np.stack(fc_batch)
        # merge att_feats
        max_att_len = max([_.shape[0] for _ in att_batch])
        data['att_feats'] = np.zeros([len(att_batch), max_att_len, att_batch[0].shape[1]], dtype = 'float32')
        for i in range(len(att_batch)):
            data['att_feats'][i, :att_batch[i].shape[0]] = att_batch[i]
        data['att_masks'] = np.zeros(data['att_feats'].shape[:2], dtype='float32')
        for i in range(len(att_batch)):
            data['att_masks'][i, :att_batch[i].shape[0]] = 1
        # set att_masks to None if attention features have same length
        if data['att_masks'].sum() == data['att_masks'].size:
            data['att_masks'] = None

        # if self.use_clipscore:
        data['clip_vis_feats'] = np.stack(clip_vis_feat_batch)

        data['labels'] = np.vstack(label_batch)
        # generate mask
        nonzeros = np.array(list(map(lambda x: (x != 0).sum()+2, data['labels'])))
        mask_batch = np.zeros([data['labels'].shape[0], self.seq_length + 2], dtype = 'float32')
        for ix, row in enumerate(mask_batch):
            row[:nonzeros[ix]] = 1
        data['masks'] = mask_batch
        data['labels'] = data['labels'].reshape(len(batch), seq_per_img, -1)
        data['masks'] = data['masks'].reshape(len(batch), seq_per_img, -1)

        data['gts'] = gts  # all ground truth captions of each images
        data['infos'] = infos

        data = {k:torch.from_numpy(v) if type(v) is np.ndarray else v for k,v in data.items()}  # Turn all ndarray to torch tensor

        return data

    def __getitem__(self, ix):
        """This function returns a tuple that is further passed to collate_fn
        """
        if self.use_att:
            att_feat = self.att_loader.get(str(self.info['images'][ix]['id']))
            # Reshape to K x C
            att_feat = att_feat.reshape(-1, att_feat.shape[-1])
            if self.norm_att_feat:
                att_feat = att_feat / np.linalg.norm(att_feat, 2, 1, keepdims=True)
            if self.use_box:
                box_feat = self.box_loader.get(str(self.info['images'][ix]['id']))
                # devided by image width and height
                x1,y1,x2,y2 = np.hsplit(box_feat, 4)
                h,w = self.info['images'][ix]['height'], self.info['images'][ix]['width']
                box_feat = np.hstack((x1/w, y1/h, x2/w, y2/h, (x2-x1)*(y2-y1)/(w*h)))  # question? x2-x1+1??
                if self.norm_box_feat:
                    box_feat = box_feat / np.linalg.norm(box_feat, 2, 1, keepdims=True)
                att_feat = np.hstack([att_feat, box_feat])
                # sort the features by the size of boxes
                att_feat = np.stack(sorted(att_feat, key=lambda x:x[-1], reverse=True))
        else:
            att_feat = np.zeros((0,0), dtype='float32')
        if self.use_fc:
            try:
                fc_feat = self.fc_loader.get(str(self.info['images'][ix]['id']))
            except:
                # Use average of attention when there is no fc provided (For bottomup feature)
                fc_feat = att_feat.mean(0)
        else:
            fc_feat = np.zeros((0), dtype='float32')
        if hasattr(self, 'h5_label_file'):
            seq = self.get_captions(ix, self.seq_per_img)
        else:
            seq = None

        # if self.use_clipscore:
        clip_vis_feat = self.clipscore_loader.get(
            str(self.info['images'][ix]['id']))

        return (fc_feat,
                att_feat, seq,
                ix, clip_vis_feat)

        # return (fc_feat,
        #         att_feat, seq,
        #         ix)

    def __len__(self):
        return len(self.info['images'])
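Unlike dataloader.py, the CaptionDataset above is written to be wrapped in a standard torch DataLoader with its own collate_func, and in this variant it always reads CLIP visual features through input_clipscore_vis_dir. A hedged sketch follows, reusing the same style of opt namespace as in the previous example with that extra field added; the paths remain placeholders and the import refers to the module as it existed before this commit.

import torch.utils.data as data
from captioning.data.pth_loader import CaptionDataset  # pre-deletion module path

opt.input_clipscore_vis_dir = 'data/cocotalk_clipscore_vis'  # .npy CLIP image features

dataset = CaptionDataset(opt)
train_loader = data.DataLoader(
    dataset,
    batch_size=opt.batch_size,
    shuffle=True,
    num_workers=4,
    collate_fn=dataset.collate_func,  # pads labels, builds att_masks, stacks clip_vis_feats
)

batch = next(iter(train_loader))
print(batch['fc_feats'].shape, batch['clip_vis_feats'].shape)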
captioning/data/pth_loader_FineCapEval.py
DELETED
@@ -1,334 +0,0 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import json
import h5py
from lmdbdict import lmdbdict
from lmdbdict.methods import DUMPS_FUNC, LOADS_FUNC
import os
import numpy as np
import numpy.random as npr
import random

import torch
import torch.utils.data as data

import multiprocessing
import six

verbose = True
# import torch
# if torch.cuda.current_device() in [0, -1]:
if 'LOCAL_RANK' in os.environ and os.environ['LOCAL_RANK'] != '0':
    verbose = False

class HybridLoader:
    """
    If db_path is a director, then use normal file loading
    If lmdb, then load from lmdb
    The loading method depend on extention.

    in_memory: if in_memory is True, we save all the features in memory
               For individual np(y|z)s, we don't need to do that because the system will do this for us.
               Should be useful for lmdb or h5.
               (Copied this idea from vilbert)
    """
    def __init__(self, db_path, ext, in_memory=False):
        self.db_path = db_path
        self.ext = ext
        if self.ext == '.npy':
            self.loader = lambda x: np.load(six.BytesIO(x))
        else:
            self.loader = lambda x: np.load(six.BytesIO(x))['feat']
        if db_path.endswith('.lmdb'):
            self.db_type = 'lmdb'
            self.lmdb = lmdbdict(db_path, unsafe=True)
            self.lmdb._key_dumps = DUMPS_FUNC['ascii']
            self.lmdb._value_loads = LOADS_FUNC['identity']
        elif db_path.endswith('.pth'):  # Assume a key,value dictionary
            self.db_type = 'pth'
            self.feat_file = torch.load(db_path)
            self.loader = lambda x: x
            print('HybridLoader: ext is ignored')
        elif db_path.endswith('h5'):
            self.db_type = 'h5'
            self.loader = lambda x: np.array(x).astype('float32')
        else:
            self.db_type = 'dir'

        self.in_memory = in_memory
        if self.in_memory:
            self.features = {}

    def get(self, key):

        if self.in_memory and key in self.features:
            # We save f_input because we want to save the
            # compressed bytes to save memory
            f_input = self.features[key]
        elif self.db_type == 'lmdb':
            f_input = self.lmdb[key]
        elif self.db_type == 'pth':
            f_input = self.feat_file[key]
        elif self.db_type == 'h5':
            f_input = h5py.File(self.db_path, 'r')[key]
        else:
            f_input = open(os.path.join(self.db_path, key + self.ext), 'rb').read()

        if self.in_memory and key not in self.features:
            self.features[key] = f_input

        # load image
        feat = self.loader(f_input)

        return feat

class CaptionDataset(data.Dataset):

    def get_vocab_size(self):
        return self.vocab_size

    def get_vocab(self):
        return self.ix_to_word

    def get_seq_length(self):
        return self.seq_length

    def __init__(self, opt):
        self.opt = opt
        self.seq_per_img = opt.seq_per_img

        # feature related options
        self.use_fc = getattr(opt, 'use_fc', True)
        self.use_att = getattr(opt, 'use_att', True)
        self.use_box = getattr(opt, 'use_box', 0)
        self.norm_att_feat = getattr(opt, 'norm_att_feat', 0)
        self.norm_box_feat = getattr(opt, 'norm_box_feat', 0)

        # load the json file which contains additional information about the dataset
        if verbose:
            print('DataLoader loading json file: ', opt.input_json)
        self.info = json.load(open(self.opt.input_json))
        if 'ix_to_word' in self.info:
            self.ix_to_word = self.info['ix_to_word']
            self.vocab_size = len(self.ix_to_word)
            if verbose:
                print('vocab size is ', self.vocab_size)

        # open the hdf5 file
        if verbose:
            print('DataLoader loading h5 file: ', opt.input_fc_dir, opt.input_att_dir, opt.input_box_dir, opt.input_label_h5)
        """
        Setting input_label_h5 to none is used when only doing generation.
        For example, when you need to test on coco test set.
        """
        if self.opt.input_label_h5 != 'none':
            self.h5_label_file = h5py.File(self.opt.input_label_h5, 'r', driver='core')
            # load in the sequence data
            seq_size = self.h5_label_file['labels'].shape
            self.label = self.h5_label_file['labels'][:]
            self.seq_length = seq_size[1]
            if verbose:
                print('max sequence length in data is', self.seq_length)
            # load the pointers in full to RAM (should be small enough)
            self.label_start_ix = self.h5_label_file['label_start_ix'][:]
            self.label_end_ix = self.h5_label_file['label_end_ix'][:]
        else:
            self.seq_length = 1

        self.data_in_memory = getattr(opt, 'data_in_memory', False)
        self.fc_loader = HybridLoader(self.opt.input_fc_dir, '.npy', in_memory=self.data_in_memory)
        self.att_loader = HybridLoader(self.opt.input_att_dir, '.npz', in_memory=self.data_in_memory)
        self.box_loader = HybridLoader(self.opt.input_box_dir, '.npy', in_memory=self.data_in_memory)

        self.use_clipscore = getattr(opt, 'use_clipscore', False)
        if self.use_clipscore:
            self.clipscore_loader = HybridLoader(self.opt.input_clipscore_vis_dir, '.npy', in_memory=self.data_in_memory)


        self.num_images = len(self.info['images'])  # self.label_start_ix.shape[0]
        if verbose:
            print('read %d image features' %(self.num_images))

        # separate out indexes for each of the provided splits
        self.split_ix = {'train': [], 'val': [], 'test': []}
        for ix in range(len(self.info['images'])):
            img = self.info['images'][ix]
            if not 'split' in img:
                self.split_ix['train'].append(ix)
                self.split_ix['val'].append(ix)
                self.split_ix['test'].append(ix)
            elif img['split'] == 'train':
                self.split_ix['train'].append(ix)
            elif img['split'] == 'val':
                self.split_ix['val'].append(ix)
            elif img['split'] == 'test':
                self.split_ix['test'].append(ix)
            elif opt.train_only == 0:  # restval
                self.split_ix['train'].append(ix)

        if verbose:
            print('assigned %d images to split train' %len(self.split_ix['train']))
            print('assigned %d images to split val' %len(self.split_ix['val']))
            print('assigned %d images to split test' %len(self.split_ix['test']))

    def get_captions(self, ix, seq_per_img):
        # fetch the sequence labels
        ix1 = self.label_start_ix[ix] - 1  # label_start_ix starts from 1
        ix2 = self.label_end_ix[ix] - 1
        ncap = ix2 - ix1 + 1  # number of captions available for this image
        assert ncap > 0, 'an image does not have any label. this can be handled but right now isn\'t'

        if ncap < seq_per_img:
            # we need to subsample (with replacement)
            seq = np.zeros([seq_per_img, self.seq_length], dtype = 'int')
            for q in range(seq_per_img):
                ixl = random.randint(ix1,ix2)
                seq[q, :] = self.label[ixl, :self.seq_length]
        else:
            ixl = random.randint(ix1, ix2 - seq_per_img + 1)
            seq = self.label[ixl: ixl + seq_per_img, :self.seq_length]

        return seq

    def collate_func(self, batch):
        seq_per_img = self.seq_per_img

        fc_batch = []
        att_batch = []
        label_batch = []

        clip_vis_feat_batch = []

        wrapped = False

        infos = []
        gts = []

        for sample in batch:
            # fetch image
            if self.use_clipscore:
                tmp_fc, tmp_att, tmp_seq, \
                    ix, tmp_clip_vis_feat = sample

                clip_vis_feat_batch.append(tmp_clip_vis_feat)
            else:
                tmp_fc, tmp_att, tmp_seq, \
                    ix = sample

            fc_batch.append(tmp_fc)
            att_batch.append(tmp_att)

            tmp_label = np.zeros([seq_per_img, self.seq_length + 2], dtype = 'int')
            if hasattr(self, 'h5_label_file'):
                # if there is ground truth
                tmp_label[:, 1 : self.seq_length + 1] = tmp_seq
            label_batch.append(tmp_label)

            # Used for reward evaluation
            if hasattr(self, 'h5_label_file'):
                # if there is ground truth
                gts.append(self.label[self.label_start_ix[ix] - 1: self.label_end_ix[ix]])
            else:
                gts.append([])

            # record associated info as well
            info_dict = {}
            info_dict['ix'] = ix
            info_dict['id'] = self.info['images'][ix]['id']
            info_dict['file_path'] = self.info['images'][ix].get('file_path', '')
            infos.append(info_dict)

        # #sort by att_feat length
        # fc_batch, att_batch, label_batch, gts, infos = \
        #     zip(*sorted(zip(fc_batch, att_batch, np.vsplit(label_batch, batch_size), gts, infos), key=lambda x: len(x[1]), reverse=True))
        if self.use_clipscore:
            fc_batch, att_batch, label_batch, clip_vis_feat_batch, gts, infos = \
                zip(*sorted(zip(fc_batch, att_batch, label_batch, clip_vis_feat_batch, gts, infos), key=lambda x: 0, reverse=True))
        else:
            fc_batch, att_batch, label_batch, gts, infos = \
                zip(*sorted(zip(fc_batch, att_batch, label_batch, gts, infos), key=lambda x: 0, reverse=True))
        data = {}
        data['fc_feats'] = np.stack(fc_batch)
        # merge att_feats
        max_att_len = max([_.shape[0] for _ in att_batch])
        data['att_feats'] = np.zeros([len(att_batch), max_att_len, att_batch[0].shape[1]], dtype = 'float32')
        for i in range(len(att_batch)):
            data['att_feats'][i, :att_batch[i].shape[0]] = att_batch[i]
        data['att_masks'] = np.zeros(data['att_feats'].shape[:2], dtype='float32')
        for i in range(len(att_batch)):
            data['att_masks'][i, :att_batch[i].shape[0]] = 1
        # set att_masks to None if attention features have same length
        if data['att_masks'].sum() == data['att_masks'].size:
            data['att_masks'] = None

        if self.use_clipscore:
            data['clip_vis_feats'] = np.stack(clip_vis_feat_batch)

        data['labels'] = np.vstack(label_batch)
        # generate mask
        nonzeros = np.array(list(map(lambda x: (x != 0).sum()+2, data['labels'])))
        mask_batch = np.zeros([data['labels'].shape[0], self.seq_length + 2], dtype = 'float32')
        for ix, row in enumerate(mask_batch):
            row[:nonzeros[ix]] = 1
|
275 |
-
data['masks'] = mask_batch
|
276 |
-
data['labels'] = data['labels'].reshape(len(batch), seq_per_img, -1)
|
277 |
-
data['masks'] = data['masks'].reshape(len(batch), seq_per_img, -1)
|
278 |
-
|
279 |
-
data['gts'] = gts # all ground truth captions of each images
|
280 |
-
data['infos'] = infos
|
281 |
-
|
282 |
-
data = {k:torch.from_numpy(v) if type(v) is np.ndarray else v for k,v in data.items()} # Turn all ndarray to torch tensor
|
283 |
-
|
284 |
-
return data
|
285 |
-
|
286 |
-
def __getitem__(self, ix):
|
287 |
-
"""This function returns a tuple that is further passed to collate_fn
|
288 |
-
"""
|
289 |
-
if self.use_att:
|
290 |
-
att_feat = self.att_loader.get(str(self.info['images'][ix]['id']))
|
291 |
-
# Reshape to K x C
|
292 |
-
att_feat = att_feat.reshape(-1, att_feat.shape[-1])
|
293 |
-
if self.norm_att_feat:
|
294 |
-
att_feat = att_feat / np.linalg.norm(att_feat, 2, 1, keepdims=True)
|
295 |
-
if self.use_box:
|
296 |
-
box_feat = self.box_loader.get(str(self.info['images'][ix]['id']))
|
297 |
-
# devided by image width and height
|
298 |
-
x1,y1,x2,y2 = np.hsplit(box_feat, 4)
|
299 |
-
h,w = self.info['images'][ix]['height'], self.info['images'][ix]['width']
|
300 |
-
box_feat = np.hstack((x1/w, y1/h, x2/w, y2/h, (x2-x1)*(y2-y1)/(w*h))) # question? x2-x1+1??
|
301 |
-
if self.norm_box_feat:
|
302 |
-
box_feat = box_feat / np.linalg.norm(box_feat, 2, 1, keepdims=True)
|
303 |
-
att_feat = np.hstack([att_feat, box_feat])
|
304 |
-
# sort the features by the size of boxes
|
305 |
-
att_feat = np.stack(sorted(att_feat, key=lambda x:x[-1], reverse=True))
|
306 |
-
else:
|
307 |
-
att_feat = np.zeros((0,0), dtype='float32')
|
308 |
-
if self.use_fc:
|
309 |
-
try:
|
310 |
-
fc_feat = self.fc_loader.get(str(self.info['images'][ix]['id']))
|
311 |
-
except:
|
312 |
-
# Use average of attention when there is no fc provided (For bottomup feature)
|
313 |
-
fc_feat = att_feat.mean(0)
|
314 |
-
else:
|
315 |
-
fc_feat = np.zeros((0), dtype='float32')
|
316 |
-
if hasattr(self, 'h5_label_file'):
|
317 |
-
seq = self.get_captions(ix, self.seq_per_img)
|
318 |
-
else:
|
319 |
-
seq = None
|
320 |
-
|
321 |
-
if self.use_clipscore:
|
322 |
-
clip_vis_feat = self.clipscore_loader.get(
|
323 |
-
str(self.info['images'][ix]['id']))
|
324 |
-
|
325 |
-
return (fc_feat,
|
326 |
-
att_feat, seq,
|
327 |
-
ix, clip_vis_feat)
|
328 |
-
|
329 |
-
return (fc_feat,
|
330 |
-
att_feat, seq,
|
331 |
-
ix)
|
332 |
-
|
333 |
-
def __len__(self):
|
334 |
-
return len(self.info['images'])
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
captioning/models/AoAModel.py
DELETED
@@ -1,228 +0,0 @@
|
|
1 |
-
# Implementation for paper 'Attention on Attention for Image Captioning'
|
2 |
-
# https://arxiv.org/abs/1908.06954
|
3 |
-
|
4 |
-
# RT: Code from original author's repo: https://github.com/husthuaan/AoANet/
|
5 |
-
|
6 |
-
from __future__ import absolute_import
|
7 |
-
from __future__ import division
|
8 |
-
from __future__ import print_function
|
9 |
-
|
10 |
-
import torch
|
11 |
-
import torch.nn as nn
|
12 |
-
import torch.nn.functional as F
|
13 |
-
|
14 |
-
from .AttModel import pack_wrapper, AttModel, Attention
|
15 |
-
from .TransformerModel import LayerNorm, attention, clones, SublayerConnection, PositionwiseFeedForward
|
16 |
-
|
17 |
-
class MultiHeadedDotAttention(nn.Module):
|
18 |
-
def __init__(self, h, d_model, dropout=0.1, scale=1, project_k_v=1, use_output_layer=1, do_aoa=0, norm_q=0, dropout_aoa=0.3):
|
19 |
-
super(MultiHeadedDotAttention, self).__init__()
|
20 |
-
assert d_model * scale % h == 0
|
21 |
-
# We assume d_v always equals d_k
|
22 |
-
self.d_k = d_model * scale // h
|
23 |
-
self.h = h
|
24 |
-
|
25 |
-
# Do we need to do linear projections on K and V?
|
26 |
-
self.project_k_v = project_k_v
|
27 |
-
|
28 |
-
# normalize the query?
|
29 |
-
if norm_q:
|
30 |
-
self.norm = LayerNorm(d_model)
|
31 |
-
else:
|
32 |
-
self.norm = lambda x:x
|
33 |
-
self.linears = clones(nn.Linear(d_model, d_model * scale), 1 + 2 * project_k_v)
|
34 |
-
|
35 |
-
# output linear layer after the multi-head attention?
|
36 |
-
self.output_layer = nn.Linear(d_model * scale, d_model)
|
37 |
-
|
38 |
-
# apply aoa after attention?
|
39 |
-
self.use_aoa = do_aoa
|
40 |
-
if self.use_aoa:
|
41 |
-
self.aoa_layer = nn.Sequential(nn.Linear((1 + scale) * d_model, 2 * d_model), nn.GLU())
|
42 |
-
# dropout to the input of AoA layer
|
43 |
-
if dropout_aoa > 0:
|
44 |
-
self.dropout_aoa = nn.Dropout(p=dropout_aoa)
|
45 |
-
else:
|
46 |
-
self.dropout_aoa = lambda x:x
|
47 |
-
|
48 |
-
if self.use_aoa or not use_output_layer:
|
49 |
-
# AoA doesn't need the output linear layer
|
50 |
-
del self.output_layer
|
51 |
-
self.output_layer = lambda x:x
|
52 |
-
|
53 |
-
self.attn = None
|
54 |
-
self.dropout = nn.Dropout(p=dropout)
|
55 |
-
|
56 |
-
def forward(self, query, value, key, mask=None):
|
57 |
-
if mask is not None:
|
58 |
-
if len(mask.size()) == 2:
|
59 |
-
mask = mask.unsqueeze(-2)
|
60 |
-
# Same mask applied to all h heads.
|
61 |
-
mask = mask.unsqueeze(1)
|
62 |
-
|
63 |
-
single_query = 0
|
64 |
-
if len(query.size()) == 2:
|
65 |
-
single_query = 1
|
66 |
-
query = query.unsqueeze(1)
|
67 |
-
|
68 |
-
nbatches = query.size(0)
|
69 |
-
|
70 |
-
query = self.norm(query)
|
71 |
-
|
72 |
-
# Do all the linear projections in batch from d_model => h x d_k
|
73 |
-
if self.project_k_v == 0:
|
74 |
-
query_ = self.linears[0](query).view(nbatches, -1, self.h, self.d_k).transpose(1, 2)
|
75 |
-
key_ = key.view(nbatches, -1, self.h, self.d_k).transpose(1, 2)
|
76 |
-
value_ = value.view(nbatches, -1, self.h, self.d_k).transpose(1, 2)
|
77 |
-
else:
|
78 |
-
query_, key_, value_ = \
|
79 |
-
[l(x).view(nbatches, -1, self.h, self.d_k).transpose(1, 2)
|
80 |
-
for l, x in zip(self.linears, (query, key, value))]
|
81 |
-
|
82 |
-
# Apply attention on all the projected vectors in batch.
|
83 |
-
x, self.attn = attention(query_, key_, value_, mask=mask,
|
84 |
-
dropout=self.dropout)
|
85 |
-
|
86 |
-
# "Concat" using a view
|
87 |
-
x = x.transpose(1, 2).contiguous() \
|
88 |
-
.view(nbatches, -1, self.h * self.d_k)
|
89 |
-
|
90 |
-
if self.use_aoa:
|
91 |
-
# Apply AoA
|
92 |
-
x = self.aoa_layer(self.dropout_aoa(torch.cat([x, query], -1)))
|
93 |
-
x = self.output_layer(x)
|
94 |
-
|
95 |
-
if single_query:
|
96 |
-
query = query.squeeze(1)
|
97 |
-
x = x.squeeze(1)
|
98 |
-
return x
|
99 |
-
|
100 |
-
class AoA_Refiner_Layer(nn.Module):
|
101 |
-
def __init__(self, size, self_attn, feed_forward, dropout):
|
102 |
-
super(AoA_Refiner_Layer, self).__init__()
|
103 |
-
self.self_attn = self_attn
|
104 |
-
self.feed_forward = feed_forward
|
105 |
-
self.use_ff = 0
|
106 |
-
if self.feed_forward is not None:
|
107 |
-
self.use_ff = 1
|
108 |
-
self.sublayer = clones(SublayerConnection(size, dropout), 1+self.use_ff)
|
109 |
-
self.size = size
|
110 |
-
|
111 |
-
def forward(self, x, mask):
|
112 |
-
x = self.sublayer[0](x, lambda x: self.self_attn(x, x, x, mask))
|
113 |
-
return self.sublayer[-1](x, self.feed_forward) if self.use_ff else x
|
114 |
-
|
115 |
-
class AoA_Refiner_Core(nn.Module):
|
116 |
-
def __init__(self, opt):
|
117 |
-
super(AoA_Refiner_Core, self).__init__()
|
118 |
-
attn = MultiHeadedDotAttention(opt.num_heads, opt.rnn_size, project_k_v=1, scale=opt.multi_head_scale, do_aoa=opt.refine_aoa, norm_q=0, dropout_aoa=getattr(opt, 'dropout_aoa', 0.3))
|
119 |
-
layer = AoA_Refiner_Layer(opt.rnn_size, attn, PositionwiseFeedForward(opt.rnn_size, 2048, 0.1) if opt.use_ff else None, 0.1)
|
120 |
-
self.layers = clones(layer, 6)
|
121 |
-
self.norm = LayerNorm(layer.size)
|
122 |
-
|
123 |
-
def forward(self, x, mask):
|
124 |
-
for layer in self.layers:
|
125 |
-
x = layer(x, mask)
|
126 |
-
return self.norm(x)
|
127 |
-
|
128 |
-
class AoA_Decoder_Core(nn.Module):
|
129 |
-
def __init__(self, opt):
|
130 |
-
super(AoA_Decoder_Core, self).__init__()
|
131 |
-
self.drop_prob_lm = opt.drop_prob_lm
|
132 |
-
self.d_model = opt.rnn_size
|
133 |
-
self.use_multi_head = opt.use_multi_head
|
134 |
-
self.multi_head_scale = opt.multi_head_scale
|
135 |
-
self.use_ctx_drop = getattr(opt, 'ctx_drop', 0)
|
136 |
-
self.out_res = getattr(opt, 'out_res', 0)
|
137 |
-
self.decoder_type = getattr(opt, 'decoder_type', 'AoA')
|
138 |
-
self.att_lstm = nn.LSTMCell(opt.input_encoding_size + opt.rnn_size, opt.rnn_size) # we, fc, h^2_t-1
|
139 |
-
self.out_drop = nn.Dropout(self.drop_prob_lm)
|
140 |
-
|
141 |
-
if self.decoder_type == 'AoA':
|
142 |
-
# AoA layer
|
143 |
-
self.att2ctx = nn.Sequential(nn.Linear(self.d_model * opt.multi_head_scale + opt.rnn_size, 2 * opt.rnn_size), nn.GLU())
|
144 |
-
elif self.decoder_type == 'LSTM':
|
145 |
-
# LSTM layer
|
146 |
-
self.att2ctx = nn.LSTMCell(self.d_model * opt.multi_head_scale + opt.rnn_size, opt.rnn_size)
|
147 |
-
else:
|
148 |
-
# Base linear layer
|
149 |
-
self.att2ctx = nn.Sequential(nn.Linear(self.d_model * opt.multi_head_scale + opt.rnn_size, opt.rnn_size), nn.ReLU())
|
150 |
-
|
151 |
-
# if opt.use_multi_head == 1: # TODO, not implemented for now
|
152 |
-
# self.attention = MultiHeadedAddAttention(opt.num_heads, opt.d_model, scale=opt.multi_head_scale)
|
153 |
-
if opt.use_multi_head == 2:
|
154 |
-
self.attention = MultiHeadedDotAttention(opt.num_heads, opt.rnn_size, project_k_v=0, scale=opt.multi_head_scale, use_output_layer=0, do_aoa=0, norm_q=1)
|
155 |
-
else:
|
156 |
-
self.attention = Attention(opt)
|
157 |
-
|
158 |
-
if self.use_ctx_drop:
|
159 |
-
self.ctx_drop = nn.Dropout(self.drop_prob_lm)
|
160 |
-
else:
|
161 |
-
self.ctx_drop = lambda x :x
|
162 |
-
|
163 |
-
def forward(self, xt, mean_feats, att_feats, p_att_feats, state, att_masks=None):
|
164 |
-
# state[0][1] is the context vector at the last step
|
165 |
-
h_att, c_att = self.att_lstm(torch.cat([xt, mean_feats + self.ctx_drop(state[0][1])], 1), (state[0][0], state[1][0]))
|
166 |
-
|
167 |
-
if self.use_multi_head == 2:
|
168 |
-
att = self.attention(h_att, p_att_feats.narrow(2, 0, self.multi_head_scale * self.d_model), p_att_feats.narrow(2, self.multi_head_scale * self.d_model, self.multi_head_scale * self.d_model), att_masks)
|
169 |
-
else:
|
170 |
-
att = self.attention(h_att, att_feats, p_att_feats, att_masks)
|
171 |
-
|
172 |
-
ctx_input = torch.cat([att, h_att], 1)
|
173 |
-
if self.decoder_type == 'LSTM':
|
174 |
-
output, c_logic = self.att2ctx(ctx_input, (state[0][1], state[1][1]))
|
175 |
-
state = (torch.stack((h_att, output)), torch.stack((c_att, c_logic)))
|
176 |
-
else:
|
177 |
-
output = self.att2ctx(ctx_input)
|
178 |
-
# save the context vector to state[0][1]
|
179 |
-
state = (torch.stack((h_att, output)), torch.stack((c_att, state[1][1])))
|
180 |
-
|
181 |
-
if self.out_res:
|
182 |
-
# add residual connection
|
183 |
-
output = output + h_att
|
184 |
-
|
185 |
-
output = self.out_drop(output)
|
186 |
-
return output, state
|
187 |
-
|
188 |
-
class AoAModel(AttModel):
|
189 |
-
def __init__(self, opt):
|
190 |
-
super(AoAModel, self).__init__(opt)
|
191 |
-
self.num_layers = 2
|
192 |
-
# mean pooling
|
193 |
-
self.use_mean_feats = getattr(opt, 'mean_feats', 1)
|
194 |
-
if opt.use_multi_head == 2:
|
195 |
-
del self.ctx2att
|
196 |
-
self.ctx2att = nn.Linear(opt.rnn_size, 2 * opt.multi_head_scale * opt.rnn_size)
|
197 |
-
|
198 |
-
if self.use_mean_feats:
|
199 |
-
del self.fc_embed
|
200 |
-
if opt.refine:
|
201 |
-
self.refiner = AoA_Refiner_Core(opt)
|
202 |
-
else:
|
203 |
-
self.refiner = lambda x,y : x
|
204 |
-
self.core = AoA_Decoder_Core(opt)
|
205 |
-
|
206 |
-
self.d_model = getattr(opt, 'd_model', opt.input_encoding_size)
|
207 |
-
|
208 |
-
|
209 |
-
def _prepare_feature(self, fc_feats, att_feats, att_masks):
|
210 |
-
att_feats, att_masks = self.clip_att(att_feats, att_masks)
|
211 |
-
|
212 |
-
# embed att feats
|
213 |
-
att_feats = pack_wrapper(self.att_embed, att_feats, att_masks)
|
214 |
-
att_feats = self.refiner(att_feats, att_masks)
|
215 |
-
|
216 |
-
if self.use_mean_feats:
|
217 |
-
# meaning pooling
|
218 |
-
if att_masks is None:
|
219 |
-
mean_feats = torch.mean(att_feats, dim=1)
|
220 |
-
else:
|
221 |
-
mean_feats = (torch.sum(att_feats * att_masks.unsqueeze(-1), 1) / torch.sum(att_masks.unsqueeze(-1), 1))
|
222 |
-
else:
|
223 |
-
mean_feats = self.fc_embed(fc_feats)
|
224 |
-
|
225 |
-
# Project the attention feats first to reduce memory and computation.
|
226 |
-
p_att_feats = self.ctx2att(att_feats)
|
227 |
-
|
228 |
-
return mean_feats, att_feats, p_att_feats, att_masks
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
captioning/models/AttEnsemble.py
DELETED
@@ -1,90 +0,0 @@
|
|
1 |
-
# This file is the implementation for ensemble evaluation.
|
2 |
-
|
3 |
-
from __future__ import absolute_import
|
4 |
-
from __future__ import division
|
5 |
-
from __future__ import print_function
|
6 |
-
|
7 |
-
import numpy as np
|
8 |
-
import torch
|
9 |
-
import torch.nn as nn
|
10 |
-
import torch.nn.functional as F
|
11 |
-
from torch.autograd import *
|
12 |
-
|
13 |
-
from .CaptionModel import CaptionModel
|
14 |
-
from .AttModel import pack_wrapper, AttModel
|
15 |
-
|
16 |
-
class AttEnsemble(AttModel):
|
17 |
-
def __init__(self, models, weights=None):
|
18 |
-
CaptionModel.__init__(self)
|
19 |
-
# super(AttEnsemble, self).__init__()
|
20 |
-
|
21 |
-
self.models = nn.ModuleList(models)
|
22 |
-
self.vocab_size = models[0].vocab_size
|
23 |
-
self.seq_length = models[0].seq_length
|
24 |
-
self.bad_endings_ix = models[0].bad_endings_ix
|
25 |
-
self.ss_prob = 0
|
26 |
-
weights = weights or [1.0] * len(self.models)
|
27 |
-
self.register_buffer('weights', torch.tensor(weights))
|
28 |
-
|
29 |
-
def init_hidden(self, batch_size):
|
30 |
-
state = [m.init_hidden(batch_size) for m in self.models]
|
31 |
-
return self.pack_state(state)
|
32 |
-
|
33 |
-
def pack_state(self, state):
|
34 |
-
self.state_lengths = [len(_) for _ in state]
|
35 |
-
return sum([list(_) for _ in state], [])
|
36 |
-
|
37 |
-
def unpack_state(self, state):
|
38 |
-
out = []
|
39 |
-
for l in self.state_lengths:
|
40 |
-
out.append(state[:l])
|
41 |
-
state = state[l:]
|
42 |
-
return out
|
43 |
-
|
44 |
-
def embed(self, it):
|
45 |
-
return [m.embed(it) for m in self.models]
|
46 |
-
|
47 |
-
def core(self, *args):
|
48 |
-
return zip(*[m.core(*_) for m, _ in zip(self.models, zip(*args))])
|
49 |
-
|
50 |
-
def get_logprobs_state(self, it, tmp_fc_feats, tmp_att_feats, tmp_p_att_feats, tmp_att_masks, state, output_logsoftmax=1):
|
51 |
-
# 'it' contains a word index
|
52 |
-
xt = self.embed(it)
|
53 |
-
|
54 |
-
state = self.unpack_state(state)
|
55 |
-
output, state = self.core(xt, tmp_fc_feats, tmp_att_feats, tmp_p_att_feats, state, tmp_att_masks)
|
56 |
-
logprobs = torch.stack([F.softmax(m.logit(output[i]), dim=1) for i,m in enumerate(self.models)], 2).mul(self.weights).div(self.weights.sum()).sum(-1).log()
|
57 |
-
|
58 |
-
return logprobs, self.pack_state(state)
|
59 |
-
|
60 |
-
def _prepare_feature(self, *args):
|
61 |
-
return tuple(zip(*[m._prepare_feature(*args) for m in self.models]))
|
62 |
-
|
63 |
-
def _old_sample_beam(self, fc_feats, att_feats, att_masks=None, opt={}):
|
64 |
-
beam_size = opt.get('beam_size', 10)
|
65 |
-
batch_size = fc_feats.size(0)
|
66 |
-
|
67 |
-
fc_feats, att_feats, p_att_feats, att_masks = self._prepare_feature(fc_feats, att_feats, att_masks)
|
68 |
-
|
69 |
-
assert beam_size <= self.vocab_size + 1, 'lets assume this for now, otherwise this corner case causes a few headaches down the road. can be dealt with in future if needed'
|
70 |
-
seq = torch.LongTensor(self.seq_length, batch_size).zero_()
|
71 |
-
seqLogprobs = torch.FloatTensor(self.seq_length, batch_size, self.vocab_size + 1)
|
72 |
-
# lets process every image independently for now, for simplicity
|
73 |
-
|
74 |
-
self.done_beams = [[] for _ in range(batch_size)]
|
75 |
-
for k in range(batch_size):
|
76 |
-
state = self.init_hidden(beam_size)
|
77 |
-
tmp_fc_feats = [fc_feats[i][k:k+1].expand(beam_size, fc_feats[i].size(1)) for i,m in enumerate(self.models)]
|
78 |
-
tmp_att_feats = [att_feats[i][k:k+1].expand(*((beam_size,)+att_feats[i].size()[1:])).contiguous() for i,m in enumerate(self.models)]
|
79 |
-
tmp_p_att_feats = [p_att_feats[i][k:k+1].expand(*((beam_size,)+p_att_feats[i].size()[1:])).contiguous() for i,m in enumerate(self.models)]
|
80 |
-
tmp_att_masks = [att_masks[i][k:k+1].expand(*((beam_size,)+att_masks[i].size()[1:])).contiguous() if att_masks[i] is not None else att_masks[i] for i,m in enumerate(self.models)]
|
81 |
-
|
82 |
-
it = fc_feats[0].data.new(beam_size).long().zero_()
|
83 |
-
logprobs, state = self.get_logprobs_state(it, tmp_fc_feats, tmp_att_feats, tmp_p_att_feats, tmp_att_masks, state)
|
84 |
-
|
85 |
-
self.done_beams[k] = self.old_beam_search(state, logprobs, tmp_fc_feats, tmp_att_feats, tmp_p_att_feats, tmp_att_masks, opt=opt)
|
86 |
-
seq[:, k] = self.done_beams[k][0]['seq'] # the first beam has highest cumulative score
|
87 |
-
seqLogprobs[:, k] = self.done_beams[k][0]['logps']
|
88 |
-
# return the samples and their log likelihoods
|
89 |
-
return seq.transpose(0, 1), seqLogprobs.transpose(0, 1)
|
90 |
-
# return the samples and their log likelihoods
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
captioning/models/AttModel.py
DELETED
@@ -1,969 +0,0 @@
|
|
1 |
-
# This file contains Att2in2, AdaAtt, AdaAttMO, UpDown model
|
2 |
-
|
3 |
-
# AdaAtt is from Knowing When to Look: Adaptive Attention via A Visual Sentinel for Image Captioning
|
4 |
-
# https://arxiv.org/abs/1612.01887
|
5 |
-
# AdaAttMO is a modified version with maxout lstm
|
6 |
-
|
7 |
-
# Att2in is from Self-critical Sequence Training for Image Captioning
|
8 |
-
# https://arxiv.org/abs/1612.00563
|
9 |
-
# In this file we only have Att2in2, which is a slightly different version of att2in,
|
10 |
-
# in which the img feature embedding and word embedding is the same as what in adaatt.
|
11 |
-
|
12 |
-
# UpDown is from Bottom-Up and Top-Down Attention for Image Captioning and VQA
|
13 |
-
# https://arxiv.org/abs/1707.07998
|
14 |
-
# However, it may not be identical to the author's architecture.
|
15 |
-
|
16 |
-
from __future__ import absolute_import
|
17 |
-
from __future__ import division
|
18 |
-
from __future__ import print_function
|
19 |
-
|
20 |
-
import numpy as np
|
21 |
-
import torch
|
22 |
-
import torch.nn as nn
|
23 |
-
import torch.nn.functional as F
|
24 |
-
from . import utils
|
25 |
-
from torch.nn.utils.rnn import PackedSequence, pack_padded_sequence, pad_packed_sequence
|
26 |
-
|
27 |
-
from .CaptionModel import CaptionModel
|
28 |
-
|
29 |
-
bad_endings = ['a','an','the','in','for','at','of','with','before','after','on','upon','near','to','is','are','am']
|
30 |
-
bad_endings += ['the']
|
31 |
-
|
32 |
-
def sort_pack_padded_sequence(input, lengths):
|
33 |
-
sorted_lengths, indices = torch.sort(lengths, descending=True)
|
34 |
-
# tmp = pack_padded_sequence(input[indices], sorted_lengths, batch_first=True)
|
35 |
-
tmp = pack_padded_sequence(input[indices], sorted_lengths.cpu(), batch_first=True)
|
36 |
-
inv_ix = indices.clone()
|
37 |
-
inv_ix[indices] = torch.arange(0,len(indices)).type_as(inv_ix)
|
38 |
-
return tmp, inv_ix
|
39 |
-
|
40 |
-
def pad_unsort_packed_sequence(input, inv_ix):
|
41 |
-
tmp, _ = pad_packed_sequence(input, batch_first=True)
|
42 |
-
tmp = tmp[inv_ix]
|
43 |
-
return tmp
|
44 |
-
|
45 |
-
def pack_wrapper(module, att_feats, att_masks):
|
46 |
-
if att_masks is not None:
|
47 |
-
packed, inv_ix = sort_pack_padded_sequence(att_feats, att_masks.data.long().sum(1))
|
48 |
-
return pad_unsort_packed_sequence(PackedSequence(module(packed[0]), packed[1]), inv_ix)
|
49 |
-
else:
|
50 |
-
return module(att_feats)
|
51 |
-
|
52 |
-
class AttModel(CaptionModel):
|
53 |
-
def __init__(self, opt):
|
54 |
-
super(AttModel, self).__init__()
|
55 |
-
self.vocab_size = opt.vocab_size
|
56 |
-
self.input_encoding_size = opt.input_encoding_size
|
57 |
-
#self.rnn_type = opt.rnn_type
|
58 |
-
self.rnn_size = opt.rnn_size
|
59 |
-
self.num_layers = opt.num_layers
|
60 |
-
self.drop_prob_lm = opt.drop_prob_lm
|
61 |
-
self.seq_length = getattr(opt, 'max_length', 20) or opt.seq_length # maximum sample length
|
62 |
-
self.fc_feat_size = opt.fc_feat_size
|
63 |
-
self.att_feat_size = opt.att_feat_size
|
64 |
-
self.att_hid_size = opt.att_hid_size
|
65 |
-
|
66 |
-
self.bos_idx = getattr(opt, 'bos_idx', 0)
|
67 |
-
self.eos_idx = getattr(opt, 'eos_idx', 0)
|
68 |
-
self.pad_idx = getattr(opt, 'pad_idx', 0)
|
69 |
-
|
70 |
-
self.use_bn = getattr(opt, 'use_bn', 0)
|
71 |
-
|
72 |
-
self.ss_prob = 0.0 # Schedule sampling probability
|
73 |
-
|
74 |
-
self.embed = nn.Sequential(nn.Embedding(self.vocab_size + 1, self.input_encoding_size),
|
75 |
-
nn.ReLU(),
|
76 |
-
nn.Dropout(self.drop_prob_lm))
|
77 |
-
self.fc_embed = nn.Sequential(nn.Linear(self.fc_feat_size, self.rnn_size),
|
78 |
-
nn.ReLU(),
|
79 |
-
nn.Dropout(self.drop_prob_lm))
|
80 |
-
self.att_embed = nn.Sequential(*(
|
81 |
-
((nn.BatchNorm1d(self.att_feat_size),) if self.use_bn else ())+
|
82 |
-
(nn.Linear(self.att_feat_size, self.rnn_size),
|
83 |
-
nn.ReLU(),
|
84 |
-
nn.Dropout(self.drop_prob_lm))+
|
85 |
-
((nn.BatchNorm1d(self.rnn_size),) if self.use_bn==2 else ())))
|
86 |
-
|
87 |
-
self.logit_layers = getattr(opt, 'logit_layers', 1)
|
88 |
-
if self.logit_layers == 1:
|
89 |
-
self.logit = nn.Linear(self.rnn_size, self.vocab_size + 1)
|
90 |
-
else:
|
91 |
-
self.logit = [[nn.Linear(self.rnn_size, self.rnn_size), nn.ReLU(), nn.Dropout(0.5)] for _ in range(opt.logit_layers - 1)]
|
92 |
-
self.logit = nn.Sequential(*(reduce(lambda x,y:x+y, self.logit) + [nn.Linear(self.rnn_size, self.vocab_size + 1)]))
|
93 |
-
self.ctx2att = nn.Linear(self.rnn_size, self.att_hid_size)
|
94 |
-
|
95 |
-
# For remove bad endding
|
96 |
-
self.vocab = opt.vocab
|
97 |
-
self.bad_endings_ix = [int(k) for k,v in self.vocab.items() if v in bad_endings]
|
98 |
-
|
99 |
-
def init_hidden(self, bsz):
|
100 |
-
weight = self.logit.weight \
|
101 |
-
if hasattr(self.logit, "weight") \
|
102 |
-
else self.logit[0].weight
|
103 |
-
return (weight.new_zeros(self.num_layers, bsz, self.rnn_size),
|
104 |
-
weight.new_zeros(self.num_layers, bsz, self.rnn_size))
|
105 |
-
|
106 |
-
def clip_att(self, att_feats, att_masks):
|
107 |
-
# Clip the length of att_masks and att_feats to the maximum length
|
108 |
-
if att_masks is not None:
|
109 |
-
max_len = att_masks.data.long().sum(1).max()
|
110 |
-
att_feats = att_feats[:, :max_len].contiguous()
|
111 |
-
att_masks = att_masks[:, :max_len].contiguous()
|
112 |
-
return att_feats, att_masks
|
113 |
-
|
114 |
-
def _prepare_feature(self, fc_feats, att_feats, att_masks):
|
115 |
-
att_feats, att_masks = self.clip_att(att_feats, att_masks)
|
116 |
-
|
117 |
-
# embed fc and att feats
|
118 |
-
fc_feats = self.fc_embed(fc_feats)
|
119 |
-
att_feats = pack_wrapper(self.att_embed, att_feats, att_masks)
|
120 |
-
|
121 |
-
# Project the attention feats first to reduce memory and computation comsumptions.
|
122 |
-
p_att_feats = self.ctx2att(att_feats)
|
123 |
-
|
124 |
-
return fc_feats, att_feats, p_att_feats, att_masks
|
125 |
-
|
126 |
-
def _forward(self, fc_feats, att_feats, seq, att_masks=None):
|
127 |
-
batch_size = fc_feats.size(0)
|
128 |
-
if seq.ndim == 3: # B * seq_per_img * seq_len
|
129 |
-
seq = seq.reshape(-1, seq.shape[2])
|
130 |
-
seq_per_img = seq.shape[0] // batch_size
|
131 |
-
state = self.init_hidden(batch_size*seq_per_img)
|
132 |
-
|
133 |
-
outputs = fc_feats.new_zeros(batch_size*seq_per_img, seq.size(1), self.vocab_size+1)
|
134 |
-
|
135 |
-
# Prepare the features
|
136 |
-
p_fc_feats, p_att_feats, pp_att_feats, p_att_masks = self._prepare_feature(fc_feats, att_feats, att_masks)
|
137 |
-
# pp_att_feats is used for attention, we cache it in advance to reduce computation cost
|
138 |
-
|
139 |
-
if seq_per_img > 1:
|
140 |
-
p_fc_feats, p_att_feats, pp_att_feats, p_att_masks = utils.repeat_tensors(seq_per_img,
|
141 |
-
[p_fc_feats, p_att_feats, pp_att_feats, p_att_masks]
|
142 |
-
)
|
143 |
-
|
144 |
-
for i in range(seq.size(1)):
|
145 |
-
if self.training and i >= 1 and self.ss_prob > 0.0: # otherwiste no need to sample
|
146 |
-
sample_prob = fc_feats.new(batch_size*seq_per_img).uniform_(0, 1)
|
147 |
-
sample_mask = sample_prob < self.ss_prob
|
148 |
-
if sample_mask.sum() == 0:
|
149 |
-
it = seq[:, i].clone()
|
150 |
-
else:
|
151 |
-
sample_ind = sample_mask.nonzero().view(-1)
|
152 |
-
it = seq[:, i].data.clone()
|
153 |
-
prob_prev = torch.exp(outputs[:, i-1].detach()) # fetch prev distribution: shape Nx(M+1)
|
154 |
-
it.index_copy_(0, sample_ind, torch.multinomial(prob_prev, 1).view(-1).index_select(0, sample_ind))
|
155 |
-
else:
|
156 |
-
it = seq[:, i].clone()
|
157 |
-
# break if all the sequences end
|
158 |
-
if i >= 1 and seq[:, i].sum() == 0:
|
159 |
-
break
|
160 |
-
|
161 |
-
output, state = self.get_logprobs_state(it, p_fc_feats, p_att_feats, pp_att_feats, p_att_masks, state)
|
162 |
-
outputs[:, i] = output
|
163 |
-
|
164 |
-
return outputs
|
165 |
-
|
166 |
-
def get_logprobs_state(self, it, fc_feats, att_feats, p_att_feats, att_masks, state, output_logsoftmax=1):
|
167 |
-
# 'it' contains a word index
|
168 |
-
xt = self.embed(it)
|
169 |
-
|
170 |
-
output, state = self.core(xt, fc_feats, att_feats, p_att_feats, state, att_masks)
|
171 |
-
if output_logsoftmax:
|
172 |
-
logprobs = F.log_softmax(self.logit(output), dim=1)
|
173 |
-
else:
|
174 |
-
logprobs = self.logit(output)
|
175 |
-
|
176 |
-
return logprobs, state
|
177 |
-
|
178 |
-
def _old_sample_beam(self, fc_feats, att_feats, att_masks=None, opt={}):
|
179 |
-
beam_size = opt.get('beam_size', 10)
|
180 |
-
group_size = opt.get('group_size', 1)
|
181 |
-
sample_n = opt.get('sample_n', 10)
|
182 |
-
# when sample_n == beam_size then each beam is a sample.
|
183 |
-
assert sample_n == 1 or sample_n == beam_size // group_size, 'when beam search, sample_n == 1 or beam search'
|
184 |
-
batch_size = fc_feats.size(0)
|
185 |
-
|
186 |
-
p_fc_feats, p_att_feats, pp_att_feats, p_att_masks = self._prepare_feature(fc_feats, att_feats, att_masks)
|
187 |
-
|
188 |
-
assert beam_size <= self.vocab_size + 1, 'lets assume this for now, otherwise this corner case causes a few headaches down the road. can be dealt with in future if needed'
|
189 |
-
seq = fc_feats.new_full((batch_size*sample_n, self.seq_length), self.pad_idx, dtype=torch.long)
|
190 |
-
seqLogprobs = fc_feats.new_zeros(batch_size*sample_n, self.seq_length, self.vocab_size + 1)
|
191 |
-
# lets process every image independently for now, for simplicity
|
192 |
-
|
193 |
-
self.done_beams = [[] for _ in range(batch_size)]
|
194 |
-
for k in range(batch_size):
|
195 |
-
state = self.init_hidden(beam_size)
|
196 |
-
tmp_fc_feats, tmp_att_feats, tmp_p_att_feats, tmp_att_masks = utils.repeat_tensors(beam_size,
|
197 |
-
[p_fc_feats[k:k+1], p_att_feats[k:k+1], pp_att_feats[k:k+1], p_att_masks[k:k+1] if att_masks is not None else None]
|
198 |
-
)
|
199 |
-
|
200 |
-
for t in range(1):
|
201 |
-
if t == 0: # input <bos>
|
202 |
-
it = fc_feats.new_full([beam_size], self.bos_idx, dtype=torch.long)
|
203 |
-
|
204 |
-
logprobs, state = self.get_logprobs_state(it, tmp_fc_feats, tmp_att_feats, tmp_p_att_feats, tmp_att_masks, state)
|
205 |
-
|
206 |
-
self.done_beams[k] = self.old_beam_search(state, logprobs, tmp_fc_feats, tmp_att_feats, tmp_p_att_feats, tmp_att_masks, opt=opt)
|
207 |
-
if sample_n == beam_size:
|
208 |
-
for _n in range(sample_n):
|
209 |
-
seq[k*sample_n+_n, :] = self.done_beams[k][_n]['seq']
|
210 |
-
seqLogprobs[k*sample_n+_n, :] = self.done_beams[k][_n]['logps']
|
211 |
-
else:
|
212 |
-
seq[k, :] = self.done_beams[k][0]['seq'] # the first beam has highest cumulative score
|
213 |
-
seqLogprobs[k, :] = self.done_beams[k][0]['logps']
|
214 |
-
# return the samples and their log likelihoods
|
215 |
-
return seq, seqLogprobs
|
216 |
-
|
217 |
-
|
218 |
-
def _sample_beam(self, fc_feats, att_feats, att_masks=None, opt={}):
|
219 |
-
beam_size = opt.get('beam_size', 10)
|
220 |
-
group_size = opt.get('group_size', 1)
|
221 |
-
sample_n = opt.get('sample_n', 10)
|
222 |
-
# when sample_n == beam_size then each beam is a sample.
|
223 |
-
assert sample_n == 1 or sample_n == beam_size // group_size, 'when beam search, sample_n == 1 or beam search'
|
224 |
-
batch_size = fc_feats.size(0)
|
225 |
-
|
226 |
-
p_fc_feats, p_att_feats, pp_att_feats, p_att_masks = self._prepare_feature(fc_feats, att_feats, att_masks)
|
227 |
-
|
228 |
-
assert beam_size <= self.vocab_size + 1, 'lets assume this for now, otherwise this corner case causes a few headaches down the road. can be dealt with in future if needed'
|
229 |
-
seq = fc_feats.new_full((batch_size*sample_n, self.seq_length), self.pad_idx, dtype=torch.long)
|
230 |
-
seqLogprobs = fc_feats.new_zeros(batch_size*sample_n, self.seq_length, self.vocab_size + 1)
|
231 |
-
# lets process every image independently for now, for simplicity
|
232 |
-
|
233 |
-
self.done_beams = [[] for _ in range(batch_size)]
|
234 |
-
|
235 |
-
state = self.init_hidden(batch_size)
|
236 |
-
|
237 |
-
# first step, feed bos
|
238 |
-
it = fc_feats.new_full([batch_size], self.bos_idx, dtype=torch.long)
|
239 |
-
logprobs, state = self.get_logprobs_state(it, p_fc_feats, p_att_feats, pp_att_feats, p_att_masks, state)
|
240 |
-
|
241 |
-
p_fc_feats, p_att_feats, pp_att_feats, p_att_masks = utils.repeat_tensors(beam_size,
|
242 |
-
[p_fc_feats, p_att_feats, pp_att_feats, p_att_masks]
|
243 |
-
)
|
244 |
-
self.done_beams = self.beam_search(state, logprobs, p_fc_feats, p_att_feats, pp_att_feats, p_att_masks, opt=opt)
|
245 |
-
for k in range(batch_size):
|
246 |
-
if sample_n == beam_size:
|
247 |
-
for _n in range(sample_n):
|
248 |
-
seq_len = self.done_beams[k][_n]['seq'].shape[0]
|
249 |
-
seq[k*sample_n+_n, :seq_len] = self.done_beams[k][_n]['seq']
|
250 |
-
seqLogprobs[k*sample_n+_n, :seq_len] = self.done_beams[k][_n]['logps']
|
251 |
-
else:
|
252 |
-
seq_len = self.done_beams[k][0]['seq'].shape[0]
|
253 |
-
seq[k, :seq_len] = self.done_beams[k][0]['seq'] # the first beam has highest cumulative score
|
254 |
-
seqLogprobs[k, :seq_len] = self.done_beams[k][0]['logps']
|
255 |
-
# return the samples and their log likelihoods
|
256 |
-
return seq, seqLogprobs
|
257 |
-
|
258 |
-
def _sample(self, fc_feats, att_feats, att_masks=None, opt={}):
|
259 |
-
|
260 |
-
sample_method = opt.get('sample_method', 'greedy')
|
261 |
-
beam_size = opt.get('beam_size', 1)
|
262 |
-
temperature = opt.get('temperature', 1.0)
|
263 |
-
sample_n = int(opt.get('sample_n', 1))
|
264 |
-
group_size = opt.get('group_size', 1)
|
265 |
-
output_logsoftmax = opt.get('output_logsoftmax', 1)
|
266 |
-
decoding_constraint = opt.get('decoding_constraint', 0)
|
267 |
-
block_trigrams = opt.get('block_trigrams', 0)
|
268 |
-
remove_bad_endings = opt.get('remove_bad_endings', 0)
|
269 |
-
if beam_size > 1 and sample_method in ['greedy', 'beam_search']:
|
270 |
-
return self._sample_beam(fc_feats, att_feats, att_masks, opt)
|
271 |
-
if group_size > 1:
|
272 |
-
return self._diverse_sample(fc_feats, att_feats, att_masks, opt)
|
273 |
-
|
274 |
-
batch_size = fc_feats.size(0)
|
275 |
-
state = self.init_hidden(batch_size*sample_n)
|
276 |
-
|
277 |
-
p_fc_feats, p_att_feats, pp_att_feats, p_att_masks = self._prepare_feature(fc_feats, att_feats, att_masks)
|
278 |
-
|
279 |
-
if sample_n > 1:
|
280 |
-
p_fc_feats, p_att_feats, pp_att_feats, p_att_masks = utils.repeat_tensors(sample_n,
|
281 |
-
[p_fc_feats, p_att_feats, pp_att_feats, p_att_masks]
|
282 |
-
)
|
283 |
-
|
284 |
-
trigrams = [] # will be a list of batch_size dictionaries
|
285 |
-
|
286 |
-
seq = fc_feats.new_full((batch_size*sample_n, self.seq_length), self.pad_idx, dtype=torch.long)
|
287 |
-
seqLogprobs = fc_feats.new_zeros(batch_size*sample_n, self.seq_length, self.vocab_size + 1)
|
288 |
-
for t in range(self.seq_length + 1):
|
289 |
-
if t == 0: # input <bos>
|
290 |
-
it = fc_feats.new_full([batch_size*sample_n], self.bos_idx, dtype=torch.long)
|
291 |
-
|
292 |
-
logprobs, state = self.get_logprobs_state(it, p_fc_feats, p_att_feats, pp_att_feats, p_att_masks, state, output_logsoftmax=output_logsoftmax)
|
293 |
-
|
294 |
-
if decoding_constraint and t > 0:
|
295 |
-
tmp = logprobs.new_zeros(logprobs.size())
|
296 |
-
tmp.scatter_(1, seq[:,t-1].data.unsqueeze(1), float('-inf'))
|
297 |
-
logprobs = logprobs + tmp
|
298 |
-
|
299 |
-
if remove_bad_endings and t > 0:
|
300 |
-
tmp = logprobs.new_zeros(logprobs.size())
|
301 |
-
prev_bad = np.isin(seq[:,t-1].data.cpu().numpy(), self.bad_endings_ix)
|
302 |
-
# Make it impossible to generate bad_endings
|
303 |
-
tmp[torch.from_numpy(prev_bad.astype('uint8')), 0] = float('-inf')
|
304 |
-
logprobs = logprobs + tmp
|
305 |
-
|
306 |
-
# Mess with trigrams
|
307 |
-
# Copy from https://github.com/lukemelas/image-paragraph-captioning
|
308 |
-
if block_trigrams and t >= 3:
|
309 |
-
# Store trigram generated at last step
|
310 |
-
prev_two_batch = seq[:,t-3:t-1]
|
311 |
-
for i in range(batch_size): # = seq.size(0)
|
312 |
-
prev_two = (prev_two_batch[i][0].item(), prev_two_batch[i][1].item())
|
313 |
-
current = seq[i][t-1]
|
314 |
-
if t == 3: # initialize
|
315 |
-
trigrams.append({prev_two: [current]}) # {LongTensor: list containing 1 int}
|
316 |
-
elif t > 3:
|
317 |
-
if prev_two in trigrams[i]: # add to list
|
318 |
-
trigrams[i][prev_two].append(current)
|
319 |
-
else: # create list
|
320 |
-
trigrams[i][prev_two] = [current]
|
321 |
-
# Block used trigrams at next step
|
322 |
-
prev_two_batch = seq[:,t-2:t]
|
323 |
-
mask = torch.zeros(logprobs.size(), requires_grad=False).to(logprobs.device) # batch_size x vocab_size
|
324 |
-
for i in range(batch_size):
|
325 |
-
prev_two = (prev_two_batch[i][0].item(), prev_two_batch[i][1].item())
|
326 |
-
if prev_two in trigrams[i]:
|
327 |
-
for j in trigrams[i][prev_two]:
|
328 |
-
mask[i,j] += 1
|
329 |
-
# Apply mask to log probs
|
330 |
-
#logprobs = logprobs - (mask * 1e9)
|
331 |
-
alpha = 2.0 # = 4
|
332 |
-
logprobs = logprobs + (mask * -0.693 * alpha) # ln(1/2) * alpha (alpha -> infty works best)
|
333 |
-
|
334 |
-
# sample the next word
|
335 |
-
if t == self.seq_length: # skip if we achieve maximum length
|
336 |
-
break
|
337 |
-
it, sampleLogprobs = self.sample_next_word(logprobs, sample_method, temperature)
|
338 |
-
|
339 |
-
# stop when all finished
|
340 |
-
if t == 0:
|
341 |
-
unfinished = it != self.eos_idx
|
342 |
-
else:
|
343 |
-
it[~unfinished] = self.pad_idx # This allows eos_idx not being overwritten to 0
|
344 |
-
logprobs = logprobs * unfinished.unsqueeze(1).to(logprobs)
|
345 |
-
unfinished = unfinished & (it != self.eos_idx)
|
346 |
-
seq[:,t] = it
|
347 |
-
seqLogprobs[:,t] = logprobs
|
348 |
-
# quit loop if all sequences have finished
|
349 |
-
if unfinished.sum() == 0:
|
350 |
-
break
|
351 |
-
|
352 |
-
return seq, seqLogprobs
|
353 |
-
|
354 |
-
def _diverse_sample(self, fc_feats, att_feats, att_masks=None, opt={}):
|
355 |
-
|
356 |
-
sample_method = opt.get('sample_method', 'greedy')
|
357 |
-
beam_size = opt.get('beam_size', 1)
|
358 |
-
temperature = opt.get('temperature', 1.0)
|
359 |
-
group_size = opt.get('group_size', 1)
|
360 |
-
diversity_lambda = opt.get('diversity_lambda', 0.5)
|
361 |
-
decoding_constraint = opt.get('decoding_constraint', 0)
|
362 |
-
block_trigrams = opt.get('block_trigrams', 0)
|
363 |
-
remove_bad_endings = opt.get('remove_bad_endings', 0)
|
364 |
-
|
365 |
-
batch_size = fc_feats.size(0)
|
366 |
-
state = self.init_hidden(batch_size)
|
367 |
-
|
368 |
-
p_fc_feats, p_att_feats, pp_att_feats, p_att_masks = self._prepare_feature(fc_feats, att_feats, att_masks)
|
369 |
-
|
370 |
-
trigrams_table = [[] for _ in range(group_size)] # will be a list of batch_size dictionaries
|
371 |
-
|
372 |
-
seq_table = [fc_feats.new_full((batch_size, self.seq_length), self.pad_idx, dtype=torch.long) for _ in range(group_size)]
|
373 |
-
seqLogprobs_table = [fc_feats.new_zeros(batch_size, self.seq_length) for _ in range(group_size)]
|
374 |
-
state_table = [self.init_hidden(batch_size) for _ in range(group_size)]
|
375 |
-
|
376 |
-
for tt in range(self.seq_length + group_size):
|
377 |
-
for divm in range(group_size):
|
378 |
-
t = tt - divm
|
379 |
-
seq = seq_table[divm]
|
380 |
-
seqLogprobs = seqLogprobs_table[divm]
|
381 |
-
trigrams = trigrams_table[divm]
|
382 |
-
if t >= 0 and t <= self.seq_length-1:
|
383 |
-
if t == 0: # input <bos>
|
384 |
-
it = fc_feats.new_full([batch_size], self.bos_idx, dtype=torch.long)
|
385 |
-
else:
|
386 |
-
it = seq[:, t-1] # changed
|
387 |
-
|
388 |
-
logprobs, state_table[divm] = self.get_logprobs_state(it, p_fc_feats, p_att_feats, pp_att_feats, p_att_masks, state_table[divm]) # changed
|
389 |
-
logprobs = F.log_softmax(logprobs / temperature, dim=-1)
|
390 |
-
|
391 |
-
# Add diversity
|
392 |
-
if divm > 0:
|
393 |
-
unaug_logprobs = logprobs.clone()
|
394 |
-
for prev_choice in range(divm):
|
395 |
-
prev_decisions = seq_table[prev_choice][:, t]
|
396 |
-
logprobs[:, prev_decisions] = logprobs[:, prev_decisions] - diversity_lambda
|
397 |
-
|
398 |
-
if decoding_constraint and t > 0:
|
399 |
-
tmp = logprobs.new_zeros(logprobs.size())
|
400 |
-
tmp.scatter_(1, seq[:,t-1].data.unsqueeze(1), float('-inf'))
|
401 |
-
logprobs = logprobs + tmp
|
402 |
-
|
403 |
-
if remove_bad_endings and t > 0:
|
404 |
-
tmp = logprobs.new_zeros(logprobs.size())
|
405 |
-
prev_bad = np.isin(seq[:,t-1].data.cpu().numpy(), self.bad_endings_ix)
|
406 |
-
# Impossible to generate remove_bad_endings
|
407 |
-
tmp[torch.from_numpy(prev_bad.astype('uint8')), 0] = float('-inf')
|
408 |
-
logprobs = logprobs + tmp
|
409 |
-
|
410 |
-
# Mess with trigrams
|
411 |
-
if block_trigrams and t >= 3:
|
412 |
-
# Store trigram generated at last step
|
413 |
-
prev_two_batch = seq[:,t-3:t-1]
|
414 |
-
for i in range(batch_size): # = seq.size(0)
|
415 |
-
prev_two = (prev_two_batch[i][0].item(), prev_two_batch[i][1].item())
|
416 |
-
current = seq[i][t-1]
|
417 |
-
if t == 3: # initialize
|
418 |
-
trigrams.append({prev_two: [current]}) # {LongTensor: list containing 1 int}
|
419 |
-
elif t > 3:
|
420 |
-
if prev_two in trigrams[i]: # add to list
|
421 |
-
trigrams[i][prev_two].append(current)
|
422 |
-
else: # create list
|
423 |
-
trigrams[i][prev_two] = [current]
|
424 |
-
# Block used trigrams at next step
|
425 |
-
prev_two_batch = seq[:,t-2:t]
|
426 |
-
mask = torch.zeros(logprobs.size(), requires_grad=False).cuda() # batch_size x vocab_size
|
427 |
-
for i in range(batch_size):
|
428 |
-
prev_two = (prev_two_batch[i][0].item(), prev_two_batch[i][1].item())
|
429 |
-
if prev_two in trigrams[i]:
|
430 |
-
for j in trigrams[i][prev_two]:
|
431 |
-
mask[i,j] += 1
|
432 |
-
# Apply mask to log probs
|
433 |
-
#logprobs = logprobs - (mask * 1e9)
|
434 |
-
alpha = 2.0 # = 4
|
435 |
-
logprobs = logprobs + (mask * -0.693 * alpha) # ln(1/2) * alpha (alpha -> infty works best)
|
436 |
-
|
437 |
-
it, sampleLogprobs = self.sample_next_word(logprobs, sample_method, 1)
|
438 |
-
|
439 |
-
# stop when all finished
|
440 |
-
if t == 0:
|
441 |
-
unfinished = it != self.eos_idx
|
442 |
-
else:
|
443 |
-
unfinished = (seq[:,t-1] != self.pad_idx) & (seq[:,t-1] != self.eos_idx)
|
444 |
-
it[~unfinished] = self.pad_idx
|
445 |
-
unfinished = unfinished & (it != self.eos_idx) # changed
|
446 |
-
seq[:,t] = it
|
447 |
-
seqLogprobs[:,t] = sampleLogprobs.view(-1)
|
448 |
-
|
449 |
-
return torch.stack(seq_table, 1).reshape(batch_size * group_size, -1), torch.stack(seqLogprobs_table, 1).reshape(batch_size * group_size, -1)
|
450 |
-
|
451 |
-
class AdaAtt_lstm(nn.Module):
|
452 |
-
def __init__(self, opt, use_maxout=True):
|
453 |
-
super(AdaAtt_lstm, self).__init__()
|
454 |
-
self.input_encoding_size = opt.input_encoding_size
|
455 |
-
#self.rnn_type = opt.rnn_type
|
456 |
-
self.rnn_size = opt.rnn_size
|
457 |
-
self.num_layers = opt.num_layers
|
458 |
-
self.drop_prob_lm = opt.drop_prob_lm
|
459 |
-
self.fc_feat_size = opt.fc_feat_size
|
460 |
-
self.att_feat_size = opt.att_feat_size
|
461 |
-
self.att_hid_size = opt.att_hid_size
|
462 |
-
|
463 |
-
self.use_maxout = use_maxout
|
464 |
-
|
465 |
-
# Build a LSTM
|
466 |
-
self.w2h = nn.Linear(self.input_encoding_size, (4+(use_maxout==True)) * self.rnn_size)
|
467 |
-
self.v2h = nn.Linear(self.rnn_size, (4+(use_maxout==True)) * self.rnn_size)
|
468 |
-
|
469 |
-
self.i2h = nn.ModuleList([nn.Linear(self.rnn_size, (4+(use_maxout==True)) * self.rnn_size) for _ in range(self.num_layers - 1)])
|
470 |
-
self.h2h = nn.ModuleList([nn.Linear(self.rnn_size, (4+(use_maxout==True)) * self.rnn_size) for _ in range(self.num_layers)])
|
471 |
-
|
472 |
-
# Layers for getting the fake region
|
473 |
-
if self.num_layers == 1:
|
474 |
-
self.r_w2h = nn.Linear(self.input_encoding_size, self.rnn_size)
|
475 |
-
self.r_v2h = nn.Linear(self.rnn_size, self.rnn_size)
|
476 |
-
else:
|
477 |
-
self.r_i2h = nn.Linear(self.rnn_size, self.rnn_size)
|
478 |
-
self.r_h2h = nn.Linear(self.rnn_size, self.rnn_size)
|
479 |
-
|
480 |
-
|
481 |
-
def forward(self, xt, img_fc, state):
|
482 |
-
|
483 |
-
hs = []
|
484 |
-
cs = []
|
485 |
-
for L in range(self.num_layers):
|
486 |
-
# c,h from previous timesteps
|
487 |
-
prev_h = state[0][L]
|
488 |
-
prev_c = state[1][L]
|
489 |
-
# the input to this layer
|
490 |
-
if L == 0:
|
491 |
-
x = xt
|
492 |
-
i2h = self.w2h(x) + self.v2h(img_fc)
|
493 |
-
else:
|
494 |
-
x = hs[-1]
|
495 |
-
x = F.dropout(x, self.drop_prob_lm, self.training)
|
496 |
-
i2h = self.i2h[L-1](x)
|
497 |
-
|
498 |
-
all_input_sums = i2h+self.h2h[L](prev_h)
|
499 |
-
|
500 |
-
sigmoid_chunk = all_input_sums.narrow(1, 0, 3 * self.rnn_size)
|
501 |
-
sigmoid_chunk = torch.sigmoid(sigmoid_chunk)
|
502 |
-
# decode the gates
|
503 |
-
in_gate = sigmoid_chunk.narrow(1, 0, self.rnn_size)
|
504 |
-
forget_gate = sigmoid_chunk.narrow(1, self.rnn_size, self.rnn_size)
|
505 |
-
out_gate = sigmoid_chunk.narrow(1, self.rnn_size * 2, self.rnn_size)
|
506 |
-
# decode the write inputs
|
507 |
-
if not self.use_maxout:
|
508 |
-
in_transform = torch.tanh(all_input_sums.narrow(1, 3 * self.rnn_size, self.rnn_size))
|
509 |
-
else:
|
510 |
-
in_transform = all_input_sums.narrow(1, 3 * self.rnn_size, 2 * self.rnn_size)
|
511 |
-
in_transform = torch.max(\
|
512 |
-
in_transform.narrow(1, 0, self.rnn_size),
|
513 |
-
in_transform.narrow(1, self.rnn_size, self.rnn_size))
|
514 |
-
# perform the LSTM update
|
515 |
-
next_c = forget_gate * prev_c + in_gate * in_transform
|
516 |
-
# gated cells form the output
|
517 |
-
tanh_nex_c = torch.tanh(next_c)
|
518 |
-
next_h = out_gate * tanh_nex_c
|
519 |
-
if L == self.num_layers-1:
|
520 |
-
if L == 0:
|
521 |
-
i2h = self.r_w2h(x) + self.r_v2h(img_fc)
|
522 |
-
else:
|
523 |
-
i2h = self.r_i2h(x)
|
524 |
-
n5 = i2h+self.r_h2h(prev_h)
|
525 |
-
fake_region = torch.sigmoid(n5) * tanh_nex_c
|
526 |
-
|
527 |
-
cs.append(next_c)
|
528 |
-
hs.append(next_h)
|
529 |
-
|
530 |
-
# set up the decoder
|
531 |
-
top_h = hs[-1]
|
532 |
-
top_h = F.dropout(top_h, self.drop_prob_lm, self.training)
|
533 |
-
fake_region = F.dropout(fake_region, self.drop_prob_lm, self.training)
|
534 |
-
|
535 |
-
state = (torch.cat([_.unsqueeze(0) for _ in hs], 0),
|
536 |
-
torch.cat([_.unsqueeze(0) for _ in cs], 0))
|
537 |
-
return top_h, fake_region, state
|
538 |
-
|
539 |
-
class AdaAtt_attention(nn.Module):
|
540 |
-
def __init__(self, opt):
|
541 |
-
super(AdaAtt_attention, self).__init__()
|
542 |
-
self.input_encoding_size = opt.input_encoding_size
|
543 |
-
#self.rnn_type = opt.rnn_type
|
544 |
-
self.rnn_size = opt.rnn_size
|
545 |
-
self.drop_prob_lm = opt.drop_prob_lm
|
546 |
-
self.att_hid_size = opt.att_hid_size
|
547 |
-
|
548 |
-
# fake region embed
|
549 |
-
self.fr_linear = nn.Sequential(
|
550 |
-
nn.Linear(self.rnn_size, self.input_encoding_size),
|
551 |
-
nn.ReLU(),
|
552 |
-
nn.Dropout(self.drop_prob_lm))
|
553 |
-
self.fr_embed = nn.Linear(self.input_encoding_size, self.att_hid_size)
|
554 |
-
|
555 |
-
# h out embed
|
556 |
-
self.ho_linear = nn.Sequential(
|
557 |
-
nn.Linear(self.rnn_size, self.input_encoding_size),
|
558 |
-
nn.Tanh(),
|
559 |
-
nn.Dropout(self.drop_prob_lm))
|
560 |
-
self.ho_embed = nn.Linear(self.input_encoding_size, self.att_hid_size)
|
561 |
-
|
562 |
-
self.alpha_net = nn.Linear(self.att_hid_size, 1)
|
563 |
-
self.att2h = nn.Linear(self.rnn_size, self.rnn_size)
|
564 |
-
|
565 |
-
def forward(self, h_out, fake_region, conv_feat, conv_feat_embed, att_masks=None):
|
566 |
-
|
567 |
-
# View into three dimensions
|
568 |
-
att_size = conv_feat.numel() // conv_feat.size(0) // self.rnn_size
|
569 |
-
conv_feat = conv_feat.view(-1, att_size, self.rnn_size)
|
570 |
-
conv_feat_embed = conv_feat_embed.view(-1, att_size, self.att_hid_size)
|
571 |
-
|
572 |
-
# view neighbor from bach_size * neighbor_num x rnn_size to bach_size x rnn_size * neighbor_num
|
573 |
-
fake_region = self.fr_linear(fake_region)
|
574 |
-
fake_region_embed = self.fr_embed(fake_region)
|
575 |
-
|
576 |
-
h_out_linear = self.ho_linear(h_out)
|
577 |
-
h_out_embed = self.ho_embed(h_out_linear)
|
578 |
-
|
579 |
-
txt_replicate = h_out_embed.unsqueeze(1).expand(h_out_embed.size(0), att_size + 1, h_out_embed.size(1))
|
580 |
-
|
581 |
-
img_all = torch.cat([fake_region.view(-1,1,self.input_encoding_size), conv_feat], 1)
|
582 |
-
img_all_embed = torch.cat([fake_region_embed.view(-1,1,self.input_encoding_size), conv_feat_embed], 1)
|
583 |
-
|
584 |
-
hA = torch.tanh(img_all_embed + txt_replicate)
|
585 |
-
hA = F.dropout(hA,self.drop_prob_lm, self.training)
|
586 |
-
|
587 |
-
hAflat = self.alpha_net(hA.view(-1, self.att_hid_size))
|
588 |
-
PI = F.softmax(hAflat.view(-1, att_size + 1), dim=1)
|
589 |
-
|
590 |
-
if att_masks is not None:
|
591 |
-
att_masks = att_masks.view(-1, att_size)
|
592 |
-
PI = PI * torch.cat([att_masks[:,:1], att_masks], 1) # assume one one at the first time step.
|
593 |
-
PI = PI / PI.sum(1, keepdim=True)
|
594 |
-
|
595 |
-
visAtt = torch.bmm(PI.unsqueeze(1), img_all)
|
596 |
-
visAttdim = visAtt.squeeze(1)
|
597 |
-
|
598 |
-
atten_out = visAttdim + h_out_linear
|
599 |
-
|
600 |
-
h = torch.tanh(self.att2h(atten_out))
|
601 |
-
h = F.dropout(h, self.drop_prob_lm, self.training)
|
602 |
-
return h
|
603 |
-
|
604 |
-
class AdaAttCore(nn.Module):
|
605 |
-
def __init__(self, opt, use_maxout=False):
|
606 |
-
super(AdaAttCore, self).__init__()
|
607 |
-
self.lstm = AdaAtt_lstm(opt, use_maxout)
|
608 |
-
self.attention = AdaAtt_attention(opt)
|
609 |
-
|
610 |
-
def forward(self, xt, fc_feats, att_feats, p_att_feats, state, att_masks=None):
|
611 |
-
h_out, p_out, state = self.lstm(xt, fc_feats, state)
|
612 |
-
atten_out = self.attention(h_out, p_out, att_feats, p_att_feats, att_masks)
|
613 |
-
return atten_out, state
|
614 |
-
|
615 |
-
class UpDownCore(nn.Module):
|
616 |
-
def __init__(self, opt, use_maxout=False):
|
617 |
-
super(UpDownCore, self).__init__()
|
618 |
-
self.drop_prob_lm = opt.drop_prob_lm
|
619 |
-
|
620 |
-
self.att_lstm = nn.LSTMCell(opt.input_encoding_size + opt.rnn_size * 2, opt.rnn_size) # we, fc, h^2_t-1
|
621 |
-
self.lang_lstm = nn.LSTMCell(opt.rnn_size * 2, opt.rnn_size) # h^1_t, \hat v
|
622 |
-
self.attention = Attention(opt)
|
623 |
-
|
624 |
-
    def forward(self, xt, fc_feats, att_feats, p_att_feats, state, att_masks=None):
        prev_h = state[0][-1]
        att_lstm_input = torch.cat([prev_h, fc_feats, xt], 1)

        h_att, c_att = self.att_lstm(att_lstm_input, (state[0][0], state[1][0]))

        att = self.attention(h_att, att_feats, p_att_feats, att_masks)

        lang_lstm_input = torch.cat([att, h_att], 1)
        # lang_lstm_input = torch.cat([att, F.dropout(h_att, self.drop_prob_lm, self.training)], 1) ?????

        h_lang, c_lang = self.lang_lstm(lang_lstm_input, (state[0][1], state[1][1]))

        output = F.dropout(h_lang, self.drop_prob_lm, self.training)
        state = (torch.stack([h_att, h_lang]), torch.stack([c_att, c_lang]))

        return output, state


############################################################################
# Notice:
# StackAtt and DenseAtt are models that I randomly designed.
# They are not related to any paper.
############################################################################

from .FCModel import LSTMCore
class StackAttCore(nn.Module):
    def __init__(self, opt, use_maxout=False):
        super(StackAttCore, self).__init__()
        self.drop_prob_lm = opt.drop_prob_lm

        # self.att0 = Attention(opt)
        self.att1 = Attention(opt)
        self.att2 = Attention(opt)

        opt_input_encoding_size = opt.input_encoding_size
        opt.input_encoding_size = opt.input_encoding_size + opt.rnn_size
        self.lstm0 = LSTMCore(opt) # att_feat + word_embedding
        opt.input_encoding_size = opt.rnn_size * 2
        self.lstm1 = LSTMCore(opt)
        self.lstm2 = LSTMCore(opt)
        opt.input_encoding_size = opt_input_encoding_size

        # self.emb1 = nn.Linear(opt.rnn_size, opt.rnn_size)
        self.emb2 = nn.Linear(opt.rnn_size, opt.rnn_size)

    def forward(self, xt, fc_feats, att_feats, p_att_feats, state, att_masks=None):
        # att_res_0 = self.att0(state[0][-1], att_feats, p_att_feats, att_masks)
        h_0, state_0 = self.lstm0(torch.cat([xt, fc_feats], 1), [state[0][0:1], state[1][0:1]])
        att_res_1 = self.att1(h_0, att_feats, p_att_feats, att_masks)
        h_1, state_1 = self.lstm1(torch.cat([h_0, att_res_1], 1), [state[0][1:2], state[1][1:2]])
        att_res_2 = self.att2(h_1 + self.emb2(att_res_1), att_feats, p_att_feats, att_masks)
        h_2, state_2 = self.lstm2(torch.cat([h_1, att_res_2], 1), [state[0][2:3], state[1][2:3]])

        return h_2, [torch.cat(_, 0) for _ in zip(state_0, state_1, state_2)]

class DenseAttCore(nn.Module):
    def __init__(self, opt, use_maxout=False):
        super(DenseAttCore, self).__init__()
        self.drop_prob_lm = opt.drop_prob_lm

        # self.att0 = Attention(opt)
        self.att1 = Attention(opt)
        self.att2 = Attention(opt)

        opt_input_encoding_size = opt.input_encoding_size
        opt.input_encoding_size = opt.input_encoding_size + opt.rnn_size
        self.lstm0 = LSTMCore(opt) # att_feat + word_embedding
        opt.input_encoding_size = opt.rnn_size * 2
        self.lstm1 = LSTMCore(opt)
        self.lstm2 = LSTMCore(opt)
        opt.input_encoding_size = opt_input_encoding_size

        # self.emb1 = nn.Linear(opt.rnn_size, opt.rnn_size)
        self.emb2 = nn.Linear(opt.rnn_size, opt.rnn_size)

        # fuse h_0 and h_1
        self.fusion1 = nn.Sequential(nn.Linear(opt.rnn_size*2, opt.rnn_size),
                                     nn.ReLU(),
                                     nn.Dropout(opt.drop_prob_lm))
        # fuse h_0, h_1 and h_2
        self.fusion2 = nn.Sequential(nn.Linear(opt.rnn_size*3, opt.rnn_size),
                                     nn.ReLU(),
                                     nn.Dropout(opt.drop_prob_lm))

    def forward(self, xt, fc_feats, att_feats, p_att_feats, state, att_masks=None):
        # att_res_0 = self.att0(state[0][-1], att_feats, p_att_feats, att_masks)
        h_0, state_0 = self.lstm0(torch.cat([xt, fc_feats], 1), [state[0][0:1], state[1][0:1]])
        att_res_1 = self.att1(h_0, att_feats, p_att_feats, att_masks)
        h_1, state_1 = self.lstm1(torch.cat([h_0, att_res_1], 1), [state[0][1:2], state[1][1:2]])
        att_res_2 = self.att2(h_1 + self.emb2(att_res_1), att_feats, p_att_feats, att_masks)
        h_2, state_2 = self.lstm2(torch.cat([self.fusion1(torch.cat([h_0, h_1], 1)), att_res_2], 1), [state[0][2:3], state[1][2:3]])

        return self.fusion2(torch.cat([h_0, h_1, h_2], 1)), [torch.cat(_, 0) for _ in zip(state_0, state_1, state_2)]

class Attention(nn.Module):
    def __init__(self, opt):
        super(Attention, self).__init__()
        self.rnn_size = opt.rnn_size
        self.att_hid_size = opt.att_hid_size

        self.h2att = nn.Linear(self.rnn_size, self.att_hid_size)
        self.alpha_net = nn.Linear(self.att_hid_size, 1)

    def forward(self, h, att_feats, p_att_feats, att_masks=None):
        # The p_att_feats here is already projected
        att_size = att_feats.numel() // att_feats.size(0) // att_feats.size(-1)
        att = p_att_feats.view(-1, att_size, self.att_hid_size)

        att_h = self.h2att(h)                          # batch * att_hid_size
        att_h = att_h.unsqueeze(1).expand_as(att)      # batch * att_size * att_hid_size
        dot = att + att_h                              # batch * att_size * att_hid_size
        dot = torch.tanh(dot)                          # batch * att_size * att_hid_size
        dot = dot.view(-1, self.att_hid_size)          # (batch * att_size) * att_hid_size
        dot = self.alpha_net(dot)                      # (batch * att_size) * 1
        dot = dot.view(-1, att_size)                   # batch * att_size

        weight = F.softmax(dot, dim=1)                 # batch * att_size
        if att_masks is not None:
            weight = weight * att_masks.view(-1, att_size).to(weight)
            weight = weight / weight.sum(1, keepdim=True) # normalize to 1
        att_feats_ = att_feats.view(-1, att_size, att_feats.size(-1)) # batch * att_size * att_feat_size
        att_res = torch.bmm(weight.unsqueeze(1), att_feats_).squeeze(1) # batch * att_feat_size

        return att_res

class Att2in2Core(nn.Module):
    def __init__(self, opt):
        super(Att2in2Core, self).__init__()
        self.input_encoding_size = opt.input_encoding_size
        #self.rnn_type = opt.rnn_type
        self.rnn_size = opt.rnn_size
        #self.num_layers = opt.num_layers
        self.drop_prob_lm = opt.drop_prob_lm
        self.fc_feat_size = opt.fc_feat_size
        self.att_feat_size = opt.att_feat_size
        self.att_hid_size = opt.att_hid_size

        # Build a LSTM
        self.a2c = nn.Linear(self.rnn_size, 2 * self.rnn_size)
        self.i2h = nn.Linear(self.input_encoding_size, 5 * self.rnn_size)
        self.h2h = nn.Linear(self.rnn_size, 5 * self.rnn_size)
        self.dropout = nn.Dropout(self.drop_prob_lm)

        self.attention = Attention(opt)

    def forward(self, xt, fc_feats, att_feats, p_att_feats, state, att_masks=None):
        att_res = self.attention(state[0][-1], att_feats, p_att_feats, att_masks)

        all_input_sums = self.i2h(xt) + self.h2h(state[0][-1])
        sigmoid_chunk = all_input_sums.narrow(1, 0, 3 * self.rnn_size)
        sigmoid_chunk = torch.sigmoid(sigmoid_chunk)
        in_gate = sigmoid_chunk.narrow(1, 0, self.rnn_size)
        forget_gate = sigmoid_chunk.narrow(1, self.rnn_size, self.rnn_size)
        out_gate = sigmoid_chunk.narrow(1, self.rnn_size * 2, self.rnn_size)

        in_transform = all_input_sums.narrow(1, 3 * self.rnn_size, 2 * self.rnn_size) + \
            self.a2c(att_res)
        in_transform = torch.max(\
            in_transform.narrow(1, 0, self.rnn_size),
            in_transform.narrow(1, self.rnn_size, self.rnn_size))
        next_c = forget_gate * state[1][-1] + in_gate * in_transform
        next_h = out_gate * torch.tanh(next_c)

        output = self.dropout(next_h)
        state = (next_h.unsqueeze(0), next_c.unsqueeze(0))
        return output, state

class Att2inCore(Att2in2Core):
    def __init__(self, opt):
        super(Att2inCore, self).__init__(opt)
        del self.a2c
        self.a2c = nn.Linear(self.att_feat_size, 2 * self.rnn_size)

"""
Note this is my attempt to replicate att2all model in self-critical paper.
However, this is not a correct replication actually. Will fix it.
"""
class Att2all2Core(nn.Module):
    def __init__(self, opt):
        super(Att2all2Core, self).__init__()
        self.input_encoding_size = opt.input_encoding_size
        #self.rnn_type = opt.rnn_type
        self.rnn_size = opt.rnn_size
        #self.num_layers = opt.num_layers
        self.drop_prob_lm = opt.drop_prob_lm
        self.fc_feat_size = opt.fc_feat_size
        self.att_feat_size = opt.att_feat_size
        self.att_hid_size = opt.att_hid_size

        # Build a LSTM
        self.a2h = nn.Linear(self.rnn_size, 5 * self.rnn_size)
        self.i2h = nn.Linear(self.input_encoding_size, 5 * self.rnn_size)
        self.h2h = nn.Linear(self.rnn_size, 5 * self.rnn_size)
        self.dropout = nn.Dropout(self.drop_prob_lm)

        self.attention = Attention(opt)

    def forward(self, xt, fc_feats, att_feats, p_att_feats, state, att_masks=None):
        att_res = self.attention(state[0][-1], att_feats, p_att_feats, att_masks)

        all_input_sums = self.i2h(xt) + self.h2h(state[0][-1]) + self.a2h(att_res)
        sigmoid_chunk = all_input_sums.narrow(1, 0, 3 * self.rnn_size)
        sigmoid_chunk = torch.sigmoid(sigmoid_chunk)
        in_gate = sigmoid_chunk.narrow(1, 0, self.rnn_size)
        forget_gate = sigmoid_chunk.narrow(1, self.rnn_size, self.rnn_size)
        out_gate = sigmoid_chunk.narrow(1, self.rnn_size * 2, self.rnn_size)

        in_transform = all_input_sums.narrow(1, 3 * self.rnn_size, 2 * self.rnn_size)
        in_transform = torch.max(\
            in_transform.narrow(1, 0, self.rnn_size),
            in_transform.narrow(1, self.rnn_size, self.rnn_size))
        next_c = forget_gate * state[1][-1] + in_gate * in_transform
        next_h = out_gate * torch.tanh(next_c)

        output = self.dropout(next_h)
        state = (next_h.unsqueeze(0), next_c.unsqueeze(0))
        return output, state

class AdaAttModel(AttModel):
    def __init__(self, opt):
        super(AdaAttModel, self).__init__(opt)
        self.core = AdaAttCore(opt)

# AdaAtt with maxout lstm
class AdaAttMOModel(AttModel):
    def __init__(self, opt):
        super(AdaAttMOModel, self).__init__(opt)
        self.core = AdaAttCore(opt, True)

class Att2in2Model(AttModel):
    def __init__(self, opt):
        super(Att2in2Model, self).__init__(opt)
        self.core = Att2in2Core(opt)
        delattr(self, 'fc_embed')
        self.fc_embed = lambda x : x

class Att2all2Model(AttModel):
    def __init__(self, opt):
        super(Att2all2Model, self).__init__(opt)
        self.core = Att2all2Core(opt)
        delattr(self, 'fc_embed')
        self.fc_embed = lambda x : x

class UpDownModel(AttModel):
    def __init__(self, opt):
        super(UpDownModel, self).__init__(opt)
        self.num_layers = 2
        self.core = UpDownCore(opt)

class StackAttModel(AttModel):
    def __init__(self, opt):
        super(StackAttModel, self).__init__(opt)
        self.num_layers = 3
        self.core = StackAttCore(opt)

class DenseAttModel(AttModel):
    def __init__(self, opt):
        super(DenseAttModel, self).__init__(opt)
        self.num_layers = 3
        self.core = DenseAttCore(opt)

class Att2inModel(AttModel):
    def __init__(self, opt):
        super(Att2inModel, self).__init__(opt)
        del self.embed, self.fc_embed, self.att_embed
        self.embed = nn.Embedding(self.vocab_size + 1, self.input_encoding_size)
        self.fc_embed = self.att_embed = lambda x: x
        del self.ctx2att
        self.ctx2att = nn.Linear(self.att_feat_size, self.att_hid_size)
        self.core = Att2inCore(opt)
        self.init_weights()

    def init_weights(self):
        initrange = 0.1
        self.embed.weight.data.uniform_(-initrange, initrange)
        self.logit.bias.data.fill_(0)
        self.logit.weight.data.uniform_(-initrange, initrange)


class NewFCModel(AttModel):
    def __init__(self, opt):
        super(NewFCModel, self).__init__(opt)
        self.fc_embed = nn.Linear(self.fc_feat_size, self.input_encoding_size)
        self.embed = nn.Embedding(self.vocab_size + 1, self.input_encoding_size)
        self._core = LSTMCore(opt)
        delattr(self, 'att_embed')
        self.att_embed = lambda x : x
        delattr(self, 'ctx2att')
        self.ctx2att = lambda x: x

    def core(self, xt, fc_feats, att_feats, p_att_feats, state, att_masks):
        # Step 0, feed the input image
        # if (self.training and state[0].is_leaf) or \
        #     (not self.training and state[0].sum() == 0):
        #     _, state = self._core(fc_feats, state)
        # three cases
        # normal mle training
        # Sample
        # beam search (diverse beam search)
        # fixed captioning module.
        is_first_step = (state[0]==0).all(2).all(0) # size: B
        if is_first_step.all():
            _, state = self._core(fc_feats, state)
        elif is_first_step.any():
            # This is mostly for diverse beam search I think
            new_state = [torch.zeros_like(_) for _ in state]
            new_state[0][:, ~is_first_step] = state[0][:, ~is_first_step]
            new_state[1][:, ~is_first_step] = state[1][:, ~is_first_step]
            _, state = self._core(fc_feats, state)
            new_state[0][:, is_first_step] = state[0][:, is_first_step]
            new_state[1][:, is_first_step] = state[1][:, is_first_step]
            state = new_state
        # if (state[0]==0).all():
        #     # Let's forget about diverse beam search first
        #     _, state = self._core(fc_feats, state)
        return self._core(xt, state)

    def _prepare_feature(self, fc_feats, att_feats, att_masks):
        fc_feats = self.fc_embed(fc_feats)

        return fc_feats, att_feats, att_feats, att_masks


class LMModel(AttModel):
    def __init__(self, opt):
        super(LMModel, self).__init__(opt)
        delattr(self, 'fc_embed')
        self.fc_embed = lambda x: x.new_zeros(x.shape[0], self.input_encoding_size)
        self.embed = nn.Embedding(self.vocab_size + 1, self.input_encoding_size)
        self._core = LSTMCore(opt)
        delattr(self, 'att_embed')
        self.att_embed = lambda x : x
        delattr(self, 'ctx2att')
        self.ctx2att = lambda x: x

    def core(self, xt, fc_feats, att_feats, p_att_feats, state, att_masks):
        if (state[0]==0).all():
            # Let's forget about diverse beam search first
            _, state = self._core(fc_feats, state)
        return self._core(xt, state)

    def _prepare_feature(self, fc_feats, att_feats, att_masks):
        fc_feats = self.fc_embed(fc_feats)

        return fc_feats, None, None, None
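The Attention module above is the additive attention shared by most of these cores. As a quick illustration, here is a minimal standalone sketch of the same computation on dummy tensors; the sizes (batch 2, 5 regions, rnn_size 4, att_hid_size 3, att_feat_size 6) are invented for this example and do not come from the repo's configs.

import torch
import torch.nn as nn
import torch.nn.functional as F

# Dummy sizes, chosen only for illustration.
batch, att_size, rnn_size, att_hid_size, att_feat_size = 2, 5, 4, 3, 6

h = torch.randn(batch, rnn_size)                           # decoder hidden state
att_feats = torch.randn(batch, att_size, att_feat_size)    # raw region features
p_att_feats = torch.randn(batch, att_size, att_hid_size)   # pre-projected region features

h2att = nn.Linear(rnn_size, att_hid_size)
alpha_net = nn.Linear(att_hid_size, 1)

# Same additive-attention recipe as Attention.forward above:
att_h = h2att(h).unsqueeze(1)                              # batch x 1 x att_hid_size
dot = torch.tanh(p_att_feats + att_h)                      # batch x att_size x att_hid_size
weight = F.softmax(alpha_net(dot).squeeze(-1), dim=1)      # batch x att_size
att_res = torch.bmm(weight.unsqueeze(1), att_feats).squeeze(1)  # batch x att_feat_size
print(att_res.shape)                                       # torch.Size([2, 6])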
captioning/models/BertCapModel.py
DELETED
@@ -1,104 +0,0 @@

"""
BertCapModel is using huggingface transformer bert model as seq2seq model.

The result is not as good as original transformer.
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import torch
import torch.nn as nn
import torch.nn.functional as F

import copy
import math
import numpy as np

from .CaptionModel import CaptionModel
from .AttModel import sort_pack_padded_sequence, pad_unsort_packed_sequence, pack_wrapper, AttModel
try:
    from transformers import BertModel, BertConfig
except:
    print('Huggingface transformers not installed; please visit https://github.com/huggingface/transformers')
from .TransformerModel import subsequent_mask, TransformerModel, Generator

class EncoderDecoder(nn.Module):
    """
    A standard Encoder-Decoder architecture. Base for this and many
    other models.
    """
    def __init__(self, encoder, decoder, generator):
        super(EncoderDecoder, self).__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.generator = generator

    def forward(self, src, tgt, src_mask, tgt_mask):
        "Take in and process masked src and target sequences."
        return self.decode(self.encode(src, src_mask), src_mask,
                           tgt, tgt_mask)

    def encode(self, src, src_mask):
        return self.encoder(inputs_embeds=src,
                            attention_mask=src_mask)[0]

    def decode(self, memory, src_mask, tgt, tgt_mask):
        return self.decoder(input_ids=tgt,
                            attention_mask=tgt_mask,
                            encoder_hidden_states=memory,
                            encoder_attention_mask=src_mask)[0]


class BertCapModel(TransformerModel):

    def make_model(self, src_vocab, tgt_vocab, N_enc=6, N_dec=6,
                   d_model=512, d_ff=2048, h=8, dropout=0.1):
        "Helper: Construct a model from hyperparameters."
        enc_config = BertConfig(vocab_size=1,
                                hidden_size=d_model,
                                num_hidden_layers=N_enc,
                                num_attention_heads=h,
                                intermediate_size=d_ff,
                                hidden_dropout_prob=dropout,
                                attention_probs_dropout_prob=dropout,
                                max_position_embeddings=1,
                                type_vocab_size=1)
        dec_config = BertConfig(vocab_size=tgt_vocab,
                                hidden_size=d_model,
                                num_hidden_layers=N_dec,
                                num_attention_heads=h,
                                intermediate_size=d_ff,
                                hidden_dropout_prob=dropout,
                                attention_probs_dropout_prob=dropout,
                                max_position_embeddings=17,
                                type_vocab_size=1,
                                is_decoder=True)
        encoder = BertModel(enc_config)
        def return_embeds(*args, **kwargs):
            return kwargs['inputs_embeds']
        del encoder.embeddings; encoder.embeddings = return_embeds
        decoder = BertModel(dec_config)
        model = EncoderDecoder(
            encoder,
            decoder,
            Generator(d_model, tgt_vocab))
        return model

    def __init__(self, opt):
        super(BertCapModel, self).__init__(opt)

    def core(self, it, fc_feats_ph, att_feats_ph, memory, state, mask):
        """
        state = [ys.unsqueeze(0)]
        """
        if len(state) == 0:
            ys = it.unsqueeze(1)
        else:
            ys = torch.cat([state[0][0], it.unsqueeze(1)], dim=1)
        out = self.model.decode(memory, mask,
                                ys,
                                subsequent_mask(ys.size(1))
                                .to(memory.device))
        return out[:, -1], [ys.unsqueeze(0)]
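BertCapModel.core above does incremental decoding by stashing the whole decoded prefix ys in state and re-feeding it to the decoder each step. Here is a minimal sketch of just that bookkeeping, with the BERT decoder call stubbed out and the token ids invented for illustration:

import torch

batch = 2
state = []  # empty at the first step, as in BertCapModel.core

for step, it in enumerate([torch.tensor([0, 0]),        # <bos> for both sequences
                           torch.tensor([7, 3]),
                           torch.tensor([9, 4])]):
    if len(state) == 0:
        ys = it.unsqueeze(1)                             # batch x 1
    else:
        ys = torch.cat([state[0][0], it.unsqueeze(1)], dim=1)  # batch x (step+1)
    # ... the real model would call self.model.decode(memory, mask, ys, subsequent_mask(...)) here ...
    state = [ys.unsqueeze(0)]                            # stash the prefix for the next step
    print(step, ys.tolist())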
captioning/models/CaptionModel.py
DELETED
@@ -1,407 +0,0 @@

# This file contains ShowAttendTell and AllImg model

# ShowAttendTell is from Show, Attend and Tell: Neural Image Caption Generation with Visual Attention
# https://arxiv.org/abs/1502.03044

# AllImg is a model where
# img feature is concatenated with word embedding at every time step as the input of lstm
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import *
from ..utils import misc as utils
from . import utils as model_utils


class CaptionModel(nn.Module):
    def __init__(self):
        super(CaptionModel, self).__init__()

    # implements beam search
    # calls beam_step and returns the final set of beams
    # augments log-probabilities with diversity terms when number of groups > 1

    def forward(self, *args, **kwargs):
        mode = kwargs.get('mode', 'forward')
        if 'mode' in kwargs:
            del kwargs['mode']
        return getattr(self, '_'+mode)(*args, **kwargs)

    def beam_search(self, init_state, init_logprobs, *args, **kwargs):

        # function computes the similarity score to be augmented
        def add_diversity(beam_seq_table, logprobs, t, divm, diversity_lambda, bdash):
            local_time = t - divm
            unaug_logprobs = logprobs.clone()
            batch_size = beam_seq_table[0].shape[0]

            if divm > 0:
                change = logprobs.new_zeros(batch_size, logprobs.shape[-1])
                for prev_choice in range(divm):
                    prev_decisions = beam_seq_table[prev_choice][:, :, local_time] # Nxb
                    for prev_labels in range(bdash):
                        change.scatter_add_(1, prev_decisions[:, prev_labels].unsqueeze(-1), change.new_ones(batch_size, 1))

                if local_time == 0:
                    logprobs = logprobs - change * diversity_lambda
                else:
                    logprobs = logprobs - self.repeat_tensor(bdash, change) * diversity_lambda

            return logprobs, unaug_logprobs


        # does one step of classical beam search

        def beam_step(logprobs, unaug_logprobs, beam_size, t, beam_seq, beam_seq_logprobs, beam_logprobs_sum, state):
            #INPUTS:
            #logprobs: probabilities augmented after diversity N*bxV
            #beam_size: obvious
            #t        : time instant
            #beam_seq : tensor containing the beams
            #beam_seq_logprobs: tensor containing the beam logprobs
            #beam_logprobs_sum: tensor containing joint logprobs
            #OUTPUTS:
            #beam_seq : tensor containing the word indices of the decoded captions Nxbxl
            #beam_seq_logprobs : log-probability of each decision made, NxbxlxV
            #beam_logprobs_sum : joint log-probability of each beam Nxb

            batch_size = beam_logprobs_sum.shape[0]
            vocab_size = logprobs.shape[-1]
            logprobs = logprobs.reshape(batch_size, -1, vocab_size) # NxbxV
            if t == 0:
                assert logprobs.shape[1] == 1
                beam_logprobs_sum = beam_logprobs_sum[:, :1]
            candidate_logprobs = beam_logprobs_sum.unsqueeze(-1) + logprobs # beam_logprobs_sum Nxb logprobs is NxbxV
            ys, ix = torch.sort(candidate_logprobs.reshape(candidate_logprobs.shape[0], -1), -1, True)
            ys, ix = ys[:,:beam_size], ix[:,:beam_size]
            beam_ix = ix // vocab_size # Nxb which beam
            selected_ix = ix % vocab_size # Nxb # which word
            state_ix = (beam_ix + torch.arange(batch_size).type_as(beam_ix).unsqueeze(-1) * logprobs.shape[1]).reshape(-1) # N*b which in Nxb beams


            if t > 0:
                # gather according to beam_ix
                assert (beam_seq.gather(1, beam_ix.unsqueeze(-1).expand_as(beam_seq)) == beam_seq.reshape(-1, beam_seq.shape[-1])[state_ix].view_as(beam_seq)).all()
                beam_seq = beam_seq.gather(1, beam_ix.unsqueeze(-1).expand_as(beam_seq))

                beam_seq_logprobs = beam_seq_logprobs.gather(1, beam_ix.unsqueeze(-1).unsqueeze(-1).expand_as(beam_seq_logprobs))

            beam_seq = torch.cat([beam_seq, selected_ix.unsqueeze(-1)], -1) # beam_seq Nxbxl
            beam_logprobs_sum = beam_logprobs_sum.gather(1, beam_ix) + \
                logprobs.reshape(batch_size, -1).gather(1, ix)
            assert (beam_logprobs_sum == ys).all()
            _tmp_beam_logprobs = unaug_logprobs[state_ix].reshape(batch_size, -1, vocab_size)
            beam_logprobs = unaug_logprobs.reshape(batch_size, -1, vocab_size).gather(1, beam_ix.unsqueeze(-1).expand(-1, -1, vocab_size)) # NxbxV
            assert (_tmp_beam_logprobs == beam_logprobs).all()
            beam_seq_logprobs = torch.cat([
                beam_seq_logprobs,
                beam_logprobs.reshape(batch_size, -1, 1, vocab_size)], 2)

            new_state = [None for _ in state]
            for _ix in range(len(new_state)):
                #  copy over state in previous beam q to new beam at vix
                new_state[_ix] = state[_ix][:, state_ix]
            state = new_state
            return beam_seq,beam_seq_logprobs,beam_logprobs_sum,state

        # Start diverse_beam_search
        opt = kwargs['opt']
        temperature = opt.get('temperature', 1) # This should not affect beam search, but will affect dbs
        beam_size = opt.get('beam_size', 10)
        group_size = opt.get('group_size', 1)
        diversity_lambda = opt.get('diversity_lambda', 0.5)
        decoding_constraint = opt.get('decoding_constraint', 0)
        remove_bad_endings = opt.get('remove_bad_endings', 0)
        suppress_UNK = opt.get('suppress_UNK', 0)
        length_penalty = utils.penalty_builder(opt.get('length_penalty', ''))
        bdash = beam_size // group_size # beam per group

        batch_size = init_logprobs.shape[0]
        device = init_logprobs.device
        # INITIALIZATIONS
        beam_seq_table = [torch.LongTensor(batch_size, bdash, 0).to(device) for _ in range(group_size)]
        beam_seq_logprobs_table = [torch.FloatTensor(batch_size, bdash, 0, self.vocab_size + 1).to(device) for _ in range(group_size)]
        beam_logprobs_sum_table = [torch.zeros(batch_size, bdash).to(device) for _ in range(group_size)]

        # logprobs # logprobs predicted in last time step, shape (beam_size, vocab_size+1)
        done_beams_table = [[[] for __ in range(group_size)] for _ in range(batch_size)]
        # state_table = [list(torch.unbind(_)) for _ in torch.stack(init_state).chunk(group_size, 2)]
        # state_table = list(zip(*[_.reshape(-1, batch_size * bdash, group_size, *_.shape[2:]).chunk(group_size, 2) for _ in init_state]))
        state_table = [[_.clone() for _ in init_state] for _ in range(group_size)]
        # logprobs_table = list(init_logprobs.reshape(batch_size * bdash, group_size, -1).chunk(group_size, 0))
        logprobs_table = [init_logprobs.clone() for _ in range(group_size)]
        # END INIT

        # Chunk elements in the args
        args = list(args)
        args = model_utils.split_tensors(group_size, args) # For each arg, turn (Bbg)x... to (Bb)x(g)x...
        if self.__class__.__name__ == 'AttEnsemble':
            args = [[[args[j][i][k] for i in range(len(self.models))] for j in range(len(args))] for k in range(group_size)] # group_name, arg_name, model_name
        else:
            args = [[args[i][j] for i in range(len(args))] for j in range(group_size)]

        for t in range(self.seq_length + group_size - 1):
            for divm in range(group_size):
                if t >= divm and t <= self.seq_length + divm - 1:
                    # add diversity
                    logprobs = logprobs_table[divm]
                    # suppress previous word
                    if decoding_constraint and t-divm > 0:
                        logprobs.scatter_(1, beam_seq_table[divm][:, :, t-divm-1].reshape(-1, 1).to(device), float('-inf'))
                    if remove_bad_endings and t-divm > 0:
                        logprobs[torch.from_numpy(np.isin(beam_seq_table[divm][:, :, t-divm-1].cpu().numpy(), self.bad_endings_ix)).reshape(-1), 0] = float('-inf')
                    # suppress UNK tokens in the decoding
                    if suppress_UNK and hasattr(self, 'vocab') and self.vocab[str(logprobs.size(1)-1)] == 'UNK':
                        logprobs[:,logprobs.size(1)-1] = logprobs[:, logprobs.size(1)-1] - 1000
                    # diversity is added here
                    # the function directly modifies the logprobs values and hence, we need to return
                    # the unaugmented ones for sorting the candidates in the end. # for historical
                    # reasons :-)
                    logprobs, unaug_logprobs = add_diversity(beam_seq_table,logprobs,t,divm,diversity_lambda,bdash)

                    # infer new beams
                    beam_seq_table[divm],\
                    beam_seq_logprobs_table[divm],\
                    beam_logprobs_sum_table[divm],\
                    state_table[divm] = beam_step(logprobs,
                                                unaug_logprobs,
                                                bdash,
                                                t-divm,
                                                beam_seq_table[divm],
                                                beam_seq_logprobs_table[divm],
                                                beam_logprobs_sum_table[divm],
                                                state_table[divm])

                    # if time's up... or if end token is reached then copy beams
                    for b in range(batch_size):
                        is_end = beam_seq_table[divm][b, :, t-divm] == self.eos_idx
                        assert beam_seq_table[divm].shape[-1] == t-divm+1
                        if t == self.seq_length + divm - 1:
                            is_end.fill_(1)
                        for vix in range(bdash):
                            if is_end[vix]:
                                final_beam = {
                                    'seq': beam_seq_table[divm][b, vix].clone(),
                                    'logps': beam_seq_logprobs_table[divm][b, vix].clone(),
                                    'unaug_p': beam_seq_logprobs_table[divm][b, vix].sum().item(),
                                    'p': beam_logprobs_sum_table[divm][b, vix].item()
                                }
                                final_beam['p'] = length_penalty(t-divm+1, final_beam['p'])
                                done_beams_table[b][divm].append(final_beam)
                        beam_logprobs_sum_table[divm][b, is_end] -= 1000

                    # move the current group one step forward in time

                    it = beam_seq_table[divm][:, :, t-divm].reshape(-1).to(logprobs.device)
                    logprobs_table[divm], state_table[divm] = self.get_logprobs_state(it, *(args[divm] + [state_table[divm]]))
                    logprobs_table[divm] = F.log_softmax(logprobs_table[divm] / temperature, dim=-1)

        # all beams are sorted by their log-probabilities
        done_beams_table = [[sorted(done_beams_table[b][i], key=lambda x: -x['p'])[:bdash] for i in range(group_size)] for b in range(batch_size)]
        done_beams = [sum(_, []) for _ in done_beams_table]
        return done_beams

    def old_beam_search(self, init_state, init_logprobs, *args, **kwargs):

        # function computes the similarity score to be augmented
        def add_diversity(beam_seq_table, logprobsf, t, divm, diversity_lambda, bdash):
            local_time = t - divm
            unaug_logprobsf = logprobsf.clone()
            for prev_choice in range(divm):
                prev_decisions = beam_seq_table[prev_choice][local_time]
                for sub_beam in range(bdash):
                    for prev_labels in range(bdash):
                        logprobsf[sub_beam][prev_decisions[prev_labels]] = logprobsf[sub_beam][prev_decisions[prev_labels]] - diversity_lambda
            return unaug_logprobsf

        # does one step of classical beam search

        def beam_step(logprobsf, unaug_logprobsf, beam_size, t, beam_seq, beam_seq_logprobs, beam_logprobs_sum, state):
            #INPUTS:
            #logprobsf: probabilities augmented after diversity
            #beam_size: obvious
            #t        : time instant
            #beam_seq : tensor containing the beams
            #beam_seq_logprobs: tensor containing the beam logprobs
            #beam_logprobs_sum: tensor containing joint logprobs
            #OUTPUTS:
            #beam_seq : tensor containing the word indices of the decoded captions
            #beam_seq_logprobs : log-probability of each decision made, same size as beam_seq
            #beam_logprobs_sum : joint log-probability of each beam

            ys,ix = torch.sort(logprobsf,1,True)
            candidates = []
            cols = min(beam_size, ys.size(1))
            rows = beam_size
            if t == 0:
                rows = 1
            for c in range(cols): # for each column (word, essentially)
                for q in range(rows): # for each beam expansion
                    #compute logprob of expanding beam q with word in (sorted) position c
                    local_logprob = ys[q,c].item()
                    candidate_logprob = beam_logprobs_sum[q] + local_logprob
                    # local_unaug_logprob = unaug_logprobsf[q,ix[q,c]]
                    candidates.append({'c':ix[q,c], 'q':q, 'p':candidate_logprob, 'r':unaug_logprobsf[q]})
            candidates = sorted(candidates, key=lambda x: -x['p'])

            new_state = [_.clone() for _ in state]
            #beam_seq_prev, beam_seq_logprobs_prev
            if t >= 1:
                #we'll need these as reference when we fork beams around
                beam_seq_prev = beam_seq[:t].clone()
                beam_seq_logprobs_prev = beam_seq_logprobs[:t].clone()
            for vix in range(beam_size):
                v = candidates[vix]
                #fork beam index q into index vix
                if t >= 1:
                    beam_seq[:t, vix] = beam_seq_prev[:, v['q']]
                    beam_seq_logprobs[:t, vix] = beam_seq_logprobs_prev[:, v['q']]
                #rearrange recurrent states
                for state_ix in range(len(new_state)):
                    #  copy over state in previous beam q to new beam at vix
                    new_state[state_ix][:, vix] = state[state_ix][:, v['q']] # dimension one is time step
                #append new end terminal at the end of this beam
                beam_seq[t, vix] = v['c'] # c'th word is the continuation
                beam_seq_logprobs[t, vix] = v['r'] # the raw logprob here
                beam_logprobs_sum[vix] = v['p'] # the new (sum) logprob along this beam
            state = new_state
            return beam_seq,beam_seq_logprobs,beam_logprobs_sum,state,candidates

        # Start diverse_beam_search
        opt = kwargs['opt']
        temperature = opt.get('temperature', 1) # This should not affect beam search, but will affect dbs
        beam_size = opt.get('beam_size', 10)
        group_size = opt.get('group_size', 1)
        diversity_lambda = opt.get('diversity_lambda', 0.5)
        decoding_constraint = opt.get('decoding_constraint', 0)
        remove_bad_endings = opt.get('remove_bad_endings', 0)
        suppress_UNK = opt.get('suppress_UNK', 0)
        length_penalty = utils.penalty_builder(opt.get('length_penalty', ''))
        bdash = beam_size // group_size # beam per group

        # INITIALIZATIONS
        beam_seq_table = [torch.LongTensor(self.seq_length, bdash).zero_() for _ in range(group_size)]
        beam_seq_logprobs_table = [torch.FloatTensor(self.seq_length, bdash, self.vocab_size + 1).zero_() for _ in range(group_size)]
        beam_logprobs_sum_table = [torch.zeros(bdash) for _ in range(group_size)]

        # logprobs # logprobs predicted in last time step, shape (beam_size, vocab_size+1)
        done_beams_table = [[] for _ in range(group_size)]
        # state_table = [list(torch.unbind(_)) for _ in torch.stack(init_state).chunk(group_size, 2)]
        state_table = list(zip(*[_.chunk(group_size, 1) for _ in init_state]))
        logprobs_table = list(init_logprobs.chunk(group_size, 0))
        # END INIT

        # Chunk elements in the args
        args = list(args)
        if self.__class__.__name__ == 'AttEnsemble':
            args = [[_.chunk(group_size) if _ is not None else [None]*group_size for _ in args_] for args_ in args] # arg_name, model_name, group_name
            args = [[[args[j][i][k] for i in range(len(self.models))] for j in range(len(args))] for k in range(group_size)] # group_name, arg_name, model_name
        else:
            args = [_.chunk(group_size) if _ is not None else [None]*group_size for _ in args]
            args = [[args[i][j] for i in range(len(args))] for j in range(group_size)]

        for t in range(self.seq_length + group_size - 1):
            for divm in range(group_size):
                if t >= divm and t <= self.seq_length + divm - 1:
                    # add diversity
                    logprobsf = logprobs_table[divm]
                    # suppress previous word
                    if decoding_constraint and t-divm > 0:
                        logprobsf.scatter_(1, beam_seq_table[divm][t-divm-1].unsqueeze(1).to(logprobsf.device), float('-inf'))
                    if remove_bad_endings and t-divm > 0:
                        logprobsf[torch.from_numpy(np.isin(beam_seq_table[divm][t-divm-1].cpu().numpy(), self.bad_endings_ix)), 0] = float('-inf')
                    # suppress UNK tokens in the decoding
                    if suppress_UNK and hasattr(self, 'vocab') and self.vocab[str(logprobsf.size(1)-1)] == 'UNK':
                        logprobsf[:,logprobsf.size(1)-1] = logprobsf[:, logprobsf.size(1)-1] - 1000
                    # diversity is added here
                    # the function directly modifies the logprobsf values and hence, we need to return
                    # the unaugmented ones for sorting the candidates in the end. # for historical
                    # reasons :-)
                    unaug_logprobsf = add_diversity(beam_seq_table,logprobsf,t,divm,diversity_lambda,bdash)

                    # infer new beams
                    beam_seq_table[divm],\
                    beam_seq_logprobs_table[divm],\
                    beam_logprobs_sum_table[divm],\
                    state_table[divm],\
                    candidates_divm = beam_step(logprobsf,
                                                unaug_logprobsf,
                                                bdash,
                                                t-divm,
                                                beam_seq_table[divm],
                                                beam_seq_logprobs_table[divm],
                                                beam_logprobs_sum_table[divm],
                                                state_table[divm])

                    # if time's up... or if end token is reached then copy beams
                    for vix in range(bdash):
                        if beam_seq_table[divm][t-divm,vix] == self.eos_idx or t == self.seq_length + divm - 1:
                            final_beam = {
                                'seq': beam_seq_table[divm][:, vix].clone(),
                                'logps': beam_seq_logprobs_table[divm][:, vix].clone(),
                                'unaug_p': beam_seq_logprobs_table[divm][:, vix].sum().item(),
                                'p': beam_logprobs_sum_table[divm][vix].item()
                            }
                            final_beam['p'] = length_penalty(t-divm+1, final_beam['p'])
                            done_beams_table[divm].append(final_beam)
                            # don't continue beams from finished sequences
                            beam_logprobs_sum_table[divm][vix] = -1000

                    # move the current group one step forward in time

                    it = beam_seq_table[divm][t-divm].to(logprobsf.device)
                    logprobs_table[divm], state_table[divm] = self.get_logprobs_state(it, *(args[divm] + [state_table[divm]]))
                    logprobs_table[divm] = F.log_softmax(logprobs_table[divm] / temperature, dim=-1)

        # all beams are sorted by their log-probabilities
        done_beams_table = [sorted(done_beams_table[i], key=lambda x: -x['p'])[:bdash] for i in range(group_size)]
        done_beams = sum(done_beams_table, [])
        return done_beams

    def sample_next_word(self, logprobs, sample_method, temperature):
        if sample_method == 'greedy':
            sampleLogprobs, it = torch.max(logprobs.data, 1)
            it = it.view(-1).long()
        elif sample_method == 'gumbel': # gumbel softmax
            # ref: https://gist.github.com/yzh119/fd2146d2aeb329d067568a493b20172f
            def sample_gumbel(shape, eps=1e-20):
                U = torch.rand(shape).to(logprobs.device)
                return -torch.log(-torch.log(U + eps) + eps)
            def gumbel_softmax_sample(logits, temperature):
                y = logits + sample_gumbel(logits.size())
                return F.log_softmax(y / temperature, dim=-1)
            _logprobs = gumbel_softmax_sample(logprobs, temperature)
            _, it = torch.max(_logprobs.data, 1)
            sampleLogprobs = logprobs.gather(1, it.unsqueeze(1)) # gather the logprobs at sampled positions
        else:
            logprobs = logprobs / temperature
            if sample_method.startswith('top'): # topk sampling
                top_num = float(sample_method[3:])
                if 0 < top_num < 1:
                    # nucleus sampling from # The Curious Case of Neural Text Degeneration
                    probs = F.softmax(logprobs, dim=1)
                    sorted_probs, sorted_indices = torch.sort(probs, descending=True, dim=1)
                    _cumsum = sorted_probs.cumsum(1)
                    mask = _cumsum < top_num
                    mask = torch.cat([torch.ones_like(mask[:,:1]), mask[:,:-1]], 1)
                    sorted_probs = sorted_probs * mask.to(sorted_probs)
                    sorted_probs = sorted_probs / sorted_probs.sum(1, keepdim=True)
                    logprobs.scatter_(1, sorted_indices, sorted_probs.log())
                else:
                    the_k = int(top_num)
                    tmp = torch.empty_like(logprobs).fill_(float('-inf'))
                    topk, indices = torch.topk(logprobs, the_k, dim=1)
                    tmp = tmp.scatter(1, indices, topk)
                    logprobs = tmp
            it = torch.distributions.Categorical(logits=logprobs.detach()).sample()
            sampleLogprobs = logprobs.gather(1, it.unsqueeze(1)) # gather the logprobs at sampled positions
        return it, sampleLogprobs


    def decode_sequence(self, seq):
        return utils.decode_sequence(self.vocab, seq)
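sample_next_word's 'topX' branch with 0 < X < 1 is nucleus (top-p) sampling. Below is a stripped-down sketch of that branch on a single dummy distribution; the probabilities and the 0.9 threshold are invented for illustration, and the in-place scatter_ is replaced by the out-of-place scatter.

import torch
import torch.nn.functional as F

logprobs = torch.log(torch.tensor([[0.45, 0.3, 0.12, 0.08, 0.05]]))  # 1 x V dummy distribution
top_p = 0.9

probs = F.softmax(logprobs, dim=1)
sorted_probs, sorted_indices = torch.sort(probs, descending=True, dim=1)
_cumsum = sorted_probs.cumsum(1)
mask = _cumsum < top_p
mask = torch.cat([torch.ones_like(mask[:, :1]), mask[:, :-1]], 1)  # always keep the top token
sorted_probs = sorted_probs * mask.to(sorted_probs)
sorted_probs = sorted_probs / sorted_probs.sum(1, keepdim=True)
logprobs = logprobs.scatter(1, sorted_indices, sorted_probs.log())

it = torch.distributions.Categorical(logits=logprobs).sample()
print(it)  # the low-probability tail is zeroed out before sampling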
captioning/models/FCModel.py
DELETED
@@ -1,204 +0,0 @@

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import *
from . import utils

from .CaptionModel import CaptionModel

class LSTMCore(nn.Module):
    def __init__(self, opt):
        super(LSTMCore, self).__init__()
        self.input_encoding_size = opt.input_encoding_size
        self.rnn_size = opt.rnn_size
        self.drop_prob_lm = opt.drop_prob_lm

        # Build a LSTM
        self.i2h = nn.Linear(self.input_encoding_size, 5 * self.rnn_size)
        self.h2h = nn.Linear(self.rnn_size, 5 * self.rnn_size)
        self.dropout = nn.Dropout(self.drop_prob_lm)

    def forward(self, xt, state):

        all_input_sums = self.i2h(xt) + self.h2h(state[0][-1])
        sigmoid_chunk = all_input_sums.narrow(1, 0, 3 * self.rnn_size)
        sigmoid_chunk = torch.sigmoid(sigmoid_chunk)
        in_gate = sigmoid_chunk.narrow(1, 0, self.rnn_size)
        forget_gate = sigmoid_chunk.narrow(1, self.rnn_size, self.rnn_size)
        out_gate = sigmoid_chunk.narrow(1, self.rnn_size * 2, self.rnn_size)

        in_transform = torch.max(\
            all_input_sums.narrow(1, 3 * self.rnn_size, self.rnn_size),
            all_input_sums.narrow(1, 4 * self.rnn_size, self.rnn_size))
        next_c = forget_gate * state[1][-1] + in_gate * in_transform
        next_h = out_gate * torch.tanh(next_c)

        output = self.dropout(next_h)
        state = (next_h.unsqueeze(0), next_c.unsqueeze(0))
        return output, state

class FCModel(CaptionModel):
    def __init__(self, opt):
        super(FCModel, self).__init__()
        self.vocab_size = opt.vocab_size
        self.input_encoding_size = opt.input_encoding_size
        self.rnn_type = opt.rnn_type
        self.rnn_size = opt.rnn_size
        self.num_layers = opt.num_layers
        self.drop_prob_lm = opt.drop_prob_lm
        self.seq_length = opt.seq_length
        self.fc_feat_size = opt.fc_feat_size

        self.ss_prob = 0.0 # Schedule sampling probability

        self.img_embed = nn.Linear(self.fc_feat_size, self.input_encoding_size)
        self.core = LSTMCore(opt)
        self.embed = nn.Embedding(self.vocab_size + 1, self.input_encoding_size)
        self.logit = nn.Linear(self.rnn_size, self.vocab_size + 1)

        self.init_weights()

    def init_weights(self):
        initrange = 0.1
        self.embed.weight.data.uniform_(-initrange, initrange)
        self.logit.bias.data.fill_(0)
        self.logit.weight.data.uniform_(-initrange, initrange)

    def init_hidden(self, bsz):
        weight = self.logit.weight
        if self.rnn_type == 'lstm':
            return (weight.new_zeros(self.num_layers, bsz, self.rnn_size),
                    weight.new_zeros(self.num_layers, bsz, self.rnn_size))
        else:
            return weight.new_zeros(self.num_layers, bsz, self.rnn_size)

    def _forward(self, fc_feats, att_feats, seq, att_masks=None):
        batch_size = fc_feats.size(0)
        seq_per_img = seq.shape[0] // batch_size
        state = self.init_hidden(batch_size*seq_per_img)
        outputs = []

        if seq_per_img > 1:
            fc_feats = utils.repeat_tensors(seq_per_img, fc_feats)

        for i in range(seq.size(1) + 1):
            if i == 0:
                xt = self.img_embed(fc_feats)
            else:
                if self.training and i >= 2 and self.ss_prob > 0.0: # otherwise no need to sample
                    sample_prob = fc_feats.data.new(batch_size*seq_per_img).uniform_(0, 1)
                    sample_mask = sample_prob < self.ss_prob
                    if sample_mask.sum() == 0:
                        it = seq[:, i-1].clone()
                    else:
                        sample_ind = sample_mask.nonzero().view(-1)
                        it = seq[:, i-1].data.clone()
                        #prob_prev = torch.exp(outputs[-1].data.index_select(0, sample_ind)) # fetch prev distribution: shape Nx(M+1)
                        #it.index_copy_(0, sample_ind, torch.multinomial(prob_prev, 1).view(-1))
                        prob_prev = torch.exp(outputs[-1].data) # fetch prev distribution: shape Nx(M+1)
                        it.index_copy_(0, sample_ind, torch.multinomial(prob_prev, 1).view(-1).index_select(0, sample_ind))
                else:
                    it = seq[:, i-1].clone()
                # break if all the sequences end
                if i >= 2 and seq[:, i-1].sum() == 0:
                    break
                xt = self.embed(it)

            output, state = self.core(xt, state)
            output = F.log_softmax(self.logit(output), dim=1)
            outputs.append(output)

        return torch.cat([_.unsqueeze(1) for _ in outputs[1:]], 1).contiguous()

    def get_logprobs_state(self, it, state):
        # 'it' contains a word index
        xt = self.embed(it)

        output, state = self.core(xt, state)
        logprobs = F.log_softmax(self.logit(output), dim=1)

        return logprobs, state

    def _sample_beam(self, fc_feats, att_feats, att_masks=None, opt={}):
        beam_size = opt.get('beam_size', 10)
        batch_size = fc_feats.size(0)

        assert beam_size <= self.vocab_size + 1, 'lets assume this for now, otherwise this corner case causes a few headaches down the road. can be dealt with in future if needed'
        seq = torch.LongTensor(self.seq_length, batch_size).zero_()
        seqLogprobs = torch.FloatTensor(self.seq_length, batch_size, self.vocab_size + 1)
        # lets process every image independently for now, for simplicity

        self.done_beams = [[] for _ in range(batch_size)]
        for k in range(batch_size):
            state = self.init_hidden(beam_size)
            for t in range(2):
                if t == 0:
                    xt = self.img_embed(fc_feats[k:k+1]).expand(beam_size, self.input_encoding_size)
                elif t == 1: # input <bos>
                    it = fc_feats.data.new(beam_size).long().zero_()
                    xt = self.embed(it)

                output, state = self.core(xt, state)
                logprobs = F.log_softmax(self.logit(output), dim=1)

            self.done_beams[k] = self.beam_search(state, logprobs, opt=opt)
            seq[:, k] = self.done_beams[k][0]['seq'] # the first beam has highest cumulative score
            seqLogprobs[:, k] = self.done_beams[k][0]['logps']
        # return the samples and their log likelihoods
        return seq.transpose(0, 1), seqLogprobs.transpose(0, 1)

    def _sample(self, fc_feats, att_feats, att_masks=None, opt={}):
        sample_method = opt.get('sample_method', 'greedy')
        beam_size = opt.get('beam_size', 1)
        temperature = opt.get('temperature', 1.0)
        if beam_size > 1 and sample_method in ['greedy', 'beam_search']:
            return self._sample_beam(fc_feats, att_feats, opt)

        batch_size = fc_feats.size(0)
        state = self.init_hidden(batch_size)
        seq = fc_feats.new_zeros(batch_size, self.seq_length, dtype=torch.long)
        seqLogprobs = fc_feats.new_zeros(batch_size, self.seq_length, self.vocab_size + 1)
        for t in range(self.seq_length + 2):
            if t == 0:
                xt = self.img_embed(fc_feats)
            else:
                if t == 1: # input <bos>
                    it = fc_feats.data.new(batch_size).long().zero_()
                xt = self.embed(it)

            output, state = self.core(xt, state)
            logprobs = F.log_softmax(self.logit(output), dim=1)

            # sample the next_word
            if t == self.seq_length + 1: # skip if we achieve maximum length
                break
            if sample_method == 'greedy':
                sampleLogprobs, it = torch.max(logprobs.data, 1)
                it = it.view(-1).long()
            else:
                if temperature == 1.0:
                    prob_prev = torch.exp(logprobs.data).cpu() # fetch prev distribution: shape Nx(M+1)
                else:
                    # scale logprobs by temperature
                    prob_prev = torch.exp(torch.div(logprobs.data, temperature)).cpu()
                it = torch.multinomial(prob_prev, 1).to(logprobs.device)
                sampleLogprobs = logprobs.gather(1, it) # gather the logprobs at sampled positions
                it = it.view(-1).long() # and flatten indices for downstream processing

            if t >= 1:
                # stop when all finished
                if t == 1:
                    unfinished = it > 0
                else:
                    unfinished = unfinished & (it > 0)
                it = it * unfinished.type_as(it)
                seq[:,t-1] = it #seq[t] the input of t+2 time step
                seqLogprobs[:,t-1] = sampleLogprobs.view(-1)
                if unfinished.sum() == 0:
                    break

        return seq, seqLogprobs
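LSTMCore above folds all gate pre-activations into a single 5*rnn_size projection and uses an element-wise max of the last two chunks as the cell candidate, i.e. a maxout LSTM rather than the usual tanh candidate. A minimal sketch of a single step, with the sizes (batch 2, input size 4, rnn_size 3) invented for illustration:

import torch
import torch.nn as nn

batch, input_size, rnn_size = 2, 4, 3
i2h = nn.Linear(input_size, 5 * rnn_size)
h2h = nn.Linear(rnn_size, 5 * rnn_size)

xt = torch.randn(batch, input_size)
prev_h = torch.zeros(batch, rnn_size)
prev_c = torch.zeros(batch, rnn_size)

s = i2h(xt) + h2h(prev_h)                        # batch x 5*rnn_size
gates = torch.sigmoid(s.narrow(1, 0, 3 * rnn_size))
in_gate = gates.narrow(1, 0, rnn_size)
forget_gate = gates.narrow(1, rnn_size, rnn_size)
out_gate = gates.narrow(1, 2 * rnn_size, rnn_size)

# maxout candidate: element-wise max over the last two rnn_size chunks
candidate = torch.max(s.narrow(1, 3 * rnn_size, rnn_size),
                      s.narrow(1, 4 * rnn_size, rnn_size))
next_c = forget_gate * prev_c + in_gate * candidate
next_h = out_gate * torch.tanh(next_c)
print(next_h.shape)                              # torch.Size([2, 3])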
captioning/models/M2Transformer.py
DELETED
@@ -1,98 +0,0 @@
-"""
-Instruction to use meshed_memory_transformer (https://arxiv.org/abs/1912.08226)
-
-pip install git+https://github.com/ruotianluo/meshed-memory-transformer.git
-
-Note:
-Currently m2transformer is not performing as well as original transformer. Not sure why? Still investigating.
-"""
-
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
-
-import copy
-import math
-import numpy as np
-
-from .CaptionModel import CaptionModel
-from .AttModel import sort_pack_padded_sequence, pad_unsort_packed_sequence, pack_wrapper, AttModel
-
-try:
-    from m2transformer.models.transformer import Transformer, MemoryAugmentedEncoder, MeshedDecoder, ScaledDotProductAttentionMemory
-except:
-    print('meshed-memory-transformer not installed; please run `pip install git+https://github.com/ruotianluo/meshed-memory-transformer.git`')
-from .TransformerModel import subsequent_mask, TransformerModel
-
-
-class M2TransformerModel(TransformerModel):
-
-    def make_model(self, src_vocab, tgt_vocab, N_enc=6, N_dec=6,
-                   d_model=512, d_ff=2048, h=8, dropout=0.1):
-        "Helper: Construct a model from hyperparameters."
-        encoder = MemoryAugmentedEncoder(N_enc, 0, attention_module=ScaledDotProductAttentionMemory,
-                                         attention_module_kwargs={'m': 40})
-        # Another implementation is to use MultiLevelEncoder + att_embed
-        decoder = MeshedDecoder(tgt_vocab, 54, N_dec, -1) # -1 is padding;
-        model = Transformer(0, encoder, decoder) # 0 is bos
-        return model
-
-    def __init__(self, opt):
-        super(M2TransformerModel, self).__init__(opt)
-        delattr(self, 'att_embed')
-        self.att_embed = lambda x: x # The visual embed is in the MAEncoder
-        # Notes: The dropout in MAEncoder is different from my att_embed, mine is 0.5?
-        # Also the attention mask seems wrong in MAEncoder too...intersting
-
-    def logit(self, x): # unsafe way
-        return x # M2transformer always output logsoftmax
-
-    def _prepare_feature(self, fc_feats, att_feats, att_masks):
-
-        att_feats, seq, att_masks, seq_mask = self._prepare_feature_forward(att_feats, att_masks)
-        memory, att_masks = self.model.encoder(att_feats)
-
-        return fc_feats[...,:0], att_feats[...,:0], memory, att_masks
-
-    def _forward(self, fc_feats, att_feats, seq, att_masks=None):
-        if seq.ndim == 3:  # B * seq_per_img * seq_len
-            seq = seq.reshape(-1, seq.shape[2])
-        att_feats, seq, att_masks, seq_mask = self._prepare_feature_forward(att_feats, att_masks, seq)
-
-        seq = seq.clone()
-        seq[~seq_mask.any(-2)] = -1 # Make padding to be -1 (my dataloader uses 0 as padding)
-        outputs = self.model(att_feats, seq)
-
-        return outputs
-
-    def core(self, it, fc_feats_ph, att_feats_ph, memory, state, mask):
-        """
-        state = [ys.unsqueeze(0)]
-        """
-        if len(state) == 0:
-            ys = it.unsqueeze(1)
-        else:
-            ys = torch.cat([state[0][0], it.unsqueeze(1)], dim=1)
-        out = self.model.decoder(ys, memory, mask)
-        return out[:, -1], [ys.unsqueeze(0)]
-
-    def _sample_beam(self, fc_feats, att_feats, att_masks=None, opt={}):
-        beam_size = opt.get('beam_size', 10)
-        group_size = opt.get('group_size', 1)
-        sample_n = opt.get('sample_n', 10)
-        assert sample_n == 1 or sample_n == beam_size // group_size, 'when beam search, sample_n == 1 or beam search'
-
-        att_feats, _, __, ___ = self._prepare_feature_forward(att_feats, att_masks)
-        seq, logprobs, seqLogprobs = self.model.beam_search(att_feats, self.seq_length, 0,
-                                                            beam_size, return_probs=True, out_size=beam_size)
-        seq = seq.reshape(-1, *seq.shape[2:])
-        seqLogprobs = seqLogprobs.reshape(-1, *seqLogprobs.shape[2:])
-
-        # if not (seqLogprobs.gather(-1, seq.unsqueeze(-1)).squeeze(-1) == logprobs.reshape(-1, logprobs.shape[-1])).all():
-        #     import pudb;pu.db
-        # seqLogprobs = logprobs.reshape(-1, logprobs.shape[-1]).unsqueeze(-1).expand(-1,-1,seqLogprobs.shape[-1])
-        return seq, seqLogprobs
captioning/models/ShowTellModel.py
DELETED
@@ -1,174 +0,0 @@
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
-from torch.autograd import *
-from . import utils
-
-from .CaptionModel import CaptionModel
-
-class ShowTellModel(CaptionModel):
-    def __init__(self, opt):
-        super(ShowTellModel, self).__init__()
-        self.vocab_size = opt.vocab_size
-        self.input_encoding_size = opt.input_encoding_size
-        self.rnn_type = opt.rnn_type
-        self.rnn_size = opt.rnn_size
-        self.num_layers = opt.num_layers
-        self.drop_prob_lm = opt.drop_prob_lm
-        self.seq_length = opt.seq_length
-        self.fc_feat_size = opt.fc_feat_size
-
-        self.ss_prob = 0.0 # Schedule sampling probability
-
-        self.img_embed = nn.Linear(self.fc_feat_size, self.input_encoding_size)
-        self.core = getattr(nn, self.rnn_type.upper())(self.input_encoding_size, self.rnn_size, self.num_layers, bias=False, dropout=self.drop_prob_lm)
-        self.embed = nn.Embedding(self.vocab_size + 1, self.input_encoding_size)
-        self.logit = nn.Linear(self.rnn_size, self.vocab_size + 1)
-        self.dropout = nn.Dropout(self.drop_prob_lm)
-
-        self.init_weights()
-
-    def init_weights(self):
-        initrange = 0.1
-        self.embed.weight.data.uniform_(-initrange, initrange)
-        self.logit.bias.data.fill_(0)
-        self.logit.weight.data.uniform_(-initrange, initrange)
-
-    def init_hidden(self, bsz):
-        weight = self.logit.weight
-        if self.rnn_type == 'lstm':
-            return (weight.new_zeros(self.num_layers, bsz, self.rnn_size),
-                    weight.new_zeros(self.num_layers, bsz, self.rnn_size))
-        else:
-            return weight.new_zeros(self.num_layers, bsz, self.rnn_size)
-
-    def _forward(self, fc_feats, att_feats, seq, att_masks=None):
-        batch_size = fc_feats.size(0)
-        seq_per_img = seq.shape[0] // batch_size
-        state = self.init_hidden(batch_size*seq_per_img)
-        outputs = []
-
-        if seq_per_img > 1:
-            fc_feats = utils.repeat_tensors(seq_per_img, fc_feats)
-
-        for i in range(seq.size(1) + 1):
-            if i == 0:
-                xt = self.img_embed(fc_feats)
-            else:
-                if self.training and i >= 2 and self.ss_prob > 0.0: # otherwiste no need to sample
-                    sample_prob = fc_feats.data.new(batch_size*seq_per_img).uniform_(0, 1)
-                    sample_mask = sample_prob < self.ss_prob
-                    if sample_mask.sum() == 0:
-                        it = seq[:, i-1].clone()
-                    else:
-                        sample_ind = sample_mask.nonzero().view(-1)
-                        it = seq[:, i-1].data.clone()
-                        #prob_prev = torch.exp(outputs[-1].data.index_select(0, sample_ind)) # fetch prev distribution: shape Nx(M+1)
-                        #it.index_copy_(0, sample_ind, torch.multinomial(prob_prev, 1).view(-1))
-                        prob_prev = torch.exp(outputs[-1].data) # fetch prev distribution: shape Nx(M+1)
-                        it.index_copy_(0, sample_ind, torch.multinomial(prob_prev, 1).view(-1).index_select(0, sample_ind))
-                else:
-                    it = seq[:, i-1].clone()
-                # break if all the sequences end
-                if i >= 2 and seq[:, i-1].data.sum() == 0:
-                    break
-                xt = self.embed(it)
-
-            output, state = self.core(xt.unsqueeze(0), state)
-            output = F.log_softmax(self.logit(self.dropout(output.squeeze(0))), dim=1)
-            outputs.append(output)
-
-        return torch.cat([_.unsqueeze(1) for _ in outputs[1:]], 1).contiguous()
-
-    def get_logprobs_state(self, it, state):
-        # 'it' contains a word index
-        xt = self.embed(it)
-
-        output, state = self.core(xt.unsqueeze(0), state)
-        logprobs = F.log_softmax(self.logit(self.dropout(output.squeeze(0))), dim=1)
-
-        return logprobs, state
-
-    def _sample_beam(self, fc_feats, att_feats, att_masks=None, opt={}):
-        beam_size = opt.get('beam_size', 10)
-        batch_size = fc_feats.size(0)
-
-        assert beam_size <= self.vocab_size + 1, 'lets assume this for now, otherwise this corner case causes a few headaches down the road. can be dealt with in future if needed'
-        seq = torch.LongTensor(self.seq_length, batch_size).zero_()
-        seqLogprobs = torch.FloatTensor(self.seq_length, batch_size)
-        # lets process every image independently for now, for simplicity
-
-        self.done_beams = [[] for _ in range(batch_size)]
-        for k in range(batch_size):
-            state = self.init_hidden(beam_size)
-            for t in range(2):
-                if t == 0:
-                    xt = self.img_embed(fc_feats[k:k+1]).expand(beam_size, self.input_encoding_size)
-                elif t == 1: # input <bos>
-                    it = fc_feats.data.new(beam_size).long().zero_()
-                    xt = self.embed(it)
-
-                output, state = self.core(xt.unsqueeze(0), state)
-                logprobs = F.log_softmax(self.logit(self.dropout(output.squeeze(0))), dim=1)
-
-            self.done_beams[k] = self.beam_search(state, logprobs, opt=opt)
-            seq[:, k] = self.done_beams[k][0]['seq'] # the first beam has highest cumulative score
-            seqLogprobs[:, k] = self.done_beams[k][0]['logps']
-        # return the samples and their log likelihoods
-        return seq.transpose(0, 1), seqLogprobs.transpose(0, 1)
-
-    def _sample(self, fc_feats, att_feats, att_masks=None, opt={}):
-        sample_method = opt.get('sample_method', 'greedy')
-        beam_size = opt.get('beam_size', 1)
-        temperature = opt.get('temperature', 1.0)
-        if beam_size > 1 and sample_method in ['greedy', 'beam_search']:
-            return self.sample_beam(fc_feats, att_feats, opt)
-
-        batch_size = fc_feats.size(0)
-        state = self.init_hidden(batch_size)
-        seq = fc_feats.new_zeros(batch_size, self.seq_length, dtype=torch.long)
-        seqLogprobs = fc_feats.new_zeros(batch_size, self.seq_length)
-        for t in range(self.seq_length + 2):
-            if t == 0:
-                xt = self.img_embed(fc_feats)
-            else:
-                if t == 1: # input <bos>
-                    it = fc_feats.data.new(batch_size).long().zero_()
-                xt = self.embed(it)
-
-            output, state = self.core(xt.unsqueeze(0), state)
-            logprobs = F.log_softmax(self.logit(self.dropout(output.squeeze(0))), dim=1)
-
-            # sample the next word
-            if t == self.seq_length + 1: # skip if we achieve maximum length
-                break
-            if sample_method == 'greedy':
-                sampleLogprobs, it = torch.max(logprobs.data, 1)
-                it = it.view(-1).long()
-            else:
-                if temperature == 1.0:
-                    prob_prev = torch.exp(logprobs.data).cpu() # fetch prev distribution: shape Nx(M+1)
-                else:
-                    # scale logprobs by temperature
-                    prob_prev = torch.exp(torch.div(logprobs.data, temperature)).cpu()
-                it = torch.multinomial(prob_prev, 1).to(logprobs.device)
-                sampleLogprobs = logprobs.gather(1, it) # gather the logprobs at sampled positions
-                it = it.view(-1).long() # and flatten indices for downstream processing
-
-            if t >= 1:
-                # stop when all finished
-                if t == 1:
-                    unfinished = it > 0
-                else:
-                    unfinished = unfinished & (it > 0)
-                it = it * unfinished.type_as(it)
-                seq[:,t-1] = it #seq[t] the input of t+2 time step
-                seqLogprobs[:,t-1] = sampleLogprobs.view(-1)
-                if unfinished.sum() == 0:
-                    break
-
-        return seq, seqLogprobs
captioning/models/TransformerModel.py
DELETED
@@ -1,363 +0,0 @@
-# This file contains Transformer network
-# Most of the code is copied from http://nlp.seas.harvard.edu/2018/04/03/attention.html
-
-# The cfg name correspondance:
-# N=num_layers
-# d_model=input_encoding_size
-# d_ff=rnn_size
-# h is always 8
-
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
-from . import utils
-
-import copy
-import math
-import numpy as np
-
-from .CaptionModel import CaptionModel
-from .AttModel import sort_pack_padded_sequence, pad_unsort_packed_sequence, pack_wrapper, AttModel
-
-class EncoderDecoder(nn.Module):
-    """
-    A standard Encoder-Decoder architecture. Base for this and many
-    other models.
-    """
-    def __init__(self, encoder, decoder, src_embed, tgt_embed, generator):
-        super(EncoderDecoder, self).__init__()
-        self.encoder = encoder
-        self.decoder = decoder
-        self.src_embed = src_embed
-        self.tgt_embed = tgt_embed
-        self.generator = generator
-
-    def forward(self, src, tgt, src_mask, tgt_mask):
-        "Take in and process masked src and target sequences."
-        return self.decode(self.encode(src, src_mask), src_mask,
-                           tgt, tgt_mask)
-
-    def encode(self, src, src_mask):
-        return self.encoder(self.src_embed(src), src_mask)
-
-    def decode(self, memory, src_mask, tgt, tgt_mask):
-        return self.decoder(self.tgt_embed(tgt), memory, src_mask, tgt_mask)
-
-class Generator(nn.Module):
-    "Define standard linear + softmax generation step."
-    def __init__(self, d_model, vocab):
-        super(Generator, self).__init__()
-        self.proj = nn.Linear(d_model, vocab)
-
-    def forward(self, x):
-        return F.log_softmax(self.proj(x), dim=-1)
-
-def clones(module, N):
-    "Produce N identical layers."
-    return nn.ModuleList([copy.deepcopy(module) for _ in range(N)])
-
-class Encoder(nn.Module):
-    "Core encoder is a stack of N layers"
-    def __init__(self, layer, N):
-        super(Encoder, self).__init__()
-        self.layers = clones(layer, N)
-        self.norm = LayerNorm(layer.size)
-
-    def forward(self, x, mask):
-        "Pass the input (and mask) through each layer in turn."
-        for layer in self.layers:
-            x = layer(x, mask)
-        return self.norm(x)
-
-class LayerNorm(nn.Module):
-    "Construct a layernorm module (See citation for details)."
-    def __init__(self, features, eps=1e-6):
-        super(LayerNorm, self).__init__()
-        self.a_2 = nn.Parameter(torch.ones(features))
-        self.b_2 = nn.Parameter(torch.zeros(features))
-        self.eps = eps
-
-    def forward(self, x):
-        mean = x.mean(-1, keepdim=True)
-        std = x.std(-1, keepdim=True)
-        return self.a_2 * (x - mean) / (std + self.eps) + self.b_2
-
-class SublayerConnection(nn.Module):
-    """
-    A residual connection followed by a layer norm.
-    Note for code simplicity the norm is first as opposed to last.
-    """
-    def __init__(self, size, dropout):
-        super(SublayerConnection, self).__init__()
-        self.norm = LayerNorm(size)
-        self.dropout = nn.Dropout(dropout)
-
-    def forward(self, x, sublayer):
-        "Apply residual connection to any sublayer with the same size."
-        return x + self.dropout(sublayer(self.norm(x)))
-
-class EncoderLayer(nn.Module):
-    "Encoder is made up of self-attn and feed forward (defined below)"
-    def __init__(self, size, self_attn, feed_forward, dropout):
-        super(EncoderLayer, self).__init__()
-        self.self_attn = self_attn
-        self.feed_forward = feed_forward
-        self.sublayer = clones(SublayerConnection(size, dropout), 2)
-        self.size = size
-
-    def forward(self, x, mask):
-        "Follow Figure 1 (left) for connections."
-        x = self.sublayer[0](x, lambda x: self.self_attn(x, x, x, mask))
-        return self.sublayer[1](x, self.feed_forward)
-
-class Decoder(nn.Module):
-    "Generic N layer decoder with masking."
-    def __init__(self, layer, N):
-        super(Decoder, self).__init__()
-        self.layers = clones(layer, N)
-        self.norm = LayerNorm(layer.size)
-
-    def forward(self, x, memory, src_mask, tgt_mask):
-        for layer in self.layers:
-            x = layer(x, memory, src_mask, tgt_mask)
-        return self.norm(x)
-
-class DecoderLayer(nn.Module):
-    "Decoder is made of self-attn, src-attn, and feed forward (defined below)"
-    def __init__(self, size, self_attn, src_attn, feed_forward, dropout):
-        super(DecoderLayer, self).__init__()
-        self.size = size
-        self.self_attn = self_attn
-        self.src_attn = src_attn
-        self.feed_forward = feed_forward
-        self.sublayer = clones(SublayerConnection(size, dropout), 3)
-
-    def forward(self, x, memory, src_mask, tgt_mask):
-        "Follow Figure 1 (right) for connections."
-        m = memory
-        x = self.sublayer[0](x, lambda x: self.self_attn(x, x, x, tgt_mask))
-        x = self.sublayer[1](x, lambda x: self.src_attn(x, m, m, src_mask))
-        return self.sublayer[2](x, self.feed_forward)
-
-def subsequent_mask(size):
-    "Mask out subsequent positions."
-    attn_shape = (1, size, size)
-    subsequent_mask = np.triu(np.ones(attn_shape), k=1).astype('uint8')
-    return torch.from_numpy(subsequent_mask) == 0
-
-def attention(query, key, value, mask=None, dropout=None):
-    "Compute 'Scaled Dot Product Attention'"
-    d_k = query.size(-1)
-    scores = torch.matmul(query, key.transpose(-2, -1)) \
-             / math.sqrt(d_k)
-    if mask is not None:
-        scores = scores.masked_fill(mask == 0, float('-inf'))
-    p_attn = F.softmax(scores, dim = -1)
-    if dropout is not None:
-        p_attn = dropout(p_attn)
-    return torch.matmul(p_attn, value), p_attn
-
-class MultiHeadedAttention(nn.Module):
-    def __init__(self, h, d_model, dropout=0.1):
-        "Take in model size and number of heads."
-        super(MultiHeadedAttention, self).__init__()
-        assert d_model % h == 0
-        # We assume d_v always equals d_k
-        self.d_k = d_model // h
-        self.h = h
-        self.linears = clones(nn.Linear(d_model, d_model), 4)
-        self.attn = None
-        self.dropout = nn.Dropout(p=dropout)
-
-    def forward(self, query, key, value, mask=None):
-        "Implements Figure 2"
-        if mask is not None:
-            # Same mask applied to all h heads.
-            mask = mask.unsqueeze(1)
-        nbatches = query.size(0)
-
-        # 1) Do all the linear projections in batch from d_model => h x d_k
-        query, key, value = \
-            [l(x).view(nbatches, -1, self.h, self.d_k).transpose(1, 2)
-             for l, x in zip(self.linears, (query, key, value))]
-
-        # 2) Apply attention on all the projected vectors in batch.
-        x, self.attn = attention(query, key, value, mask=mask,
-                                 dropout=self.dropout)
-
-        # 3) "Concat" using a view and apply a final linear.
-        x = x.transpose(1, 2).contiguous() \
-             .view(nbatches, -1, self.h * self.d_k)
-        return self.linears[-1](x)
-
-class PositionwiseFeedForward(nn.Module):
-    "Implements FFN equation."
-    def __init__(self, d_model, d_ff, dropout=0.1):
-        super(PositionwiseFeedForward, self).__init__()
-        self.w_1 = nn.Linear(d_model, d_ff)
-        self.w_2 = nn.Linear(d_ff, d_model)
-        self.dropout = nn.Dropout(dropout)
-
-    def forward(self, x):
-        return self.w_2(self.dropout(F.relu(self.w_1(x))))
-
-class Embeddings(nn.Module):
-    def __init__(self, d_model, vocab):
-        super(Embeddings, self).__init__()
-        self.lut = nn.Embedding(vocab, d_model)
-        self.d_model = d_model
-
-    def forward(self, x):
-        return self.lut(x) * math.sqrt(self.d_model)
-
-class PositionalEncoding(nn.Module):
-    "Implement the PE function."
-    def __init__(self, d_model, dropout, max_len=5000):
-        super(PositionalEncoding, self).__init__()
-        self.dropout = nn.Dropout(p=dropout)
-
-        # Compute the positional encodings once in log space.
-        pe = torch.zeros(max_len, d_model)
-        position = torch.arange(0, max_len).unsqueeze(1).float()
-        div_term = torch.exp(torch.arange(0, d_model, 2).float() *
-                             -(math.log(10000.0) / d_model))
-        pe[:, 0::2] = torch.sin(position * div_term)
-        pe[:, 1::2] = torch.cos(position * div_term)
-        pe = pe.unsqueeze(0)
-        self.register_buffer('pe', pe)
-
-    def forward(self, x):
-        x = x + self.pe[:, :x.size(1)]
-        return self.dropout(x)
-
-class TransformerModel(AttModel):
-
-    def make_model(self, src_vocab, tgt_vocab, N_enc=6, N_dec=6,
-                   d_model=512, d_ff=2048, h=8, dropout=0.1):
-        "Helper: Construct a model from hyperparameters."
-        c = copy.deepcopy
-        attn = MultiHeadedAttention(h, d_model, dropout)
-        ff = PositionwiseFeedForward(d_model, d_ff, dropout)
-        position = PositionalEncoding(d_model, dropout)
-        model = EncoderDecoder(
-            Encoder(EncoderLayer(d_model, c(attn), c(ff), dropout), N_enc),
-            Decoder(DecoderLayer(d_model, c(attn), c(attn),
-                                 c(ff), dropout), N_dec),
-            lambda x:x, # nn.Sequential(Embeddings(d_model, src_vocab), c(position)),
-            nn.Sequential(Embeddings(d_model, tgt_vocab), c(position)),
-            Generator(d_model, tgt_vocab))
-
-        # This was important from their code.
-        # Initialize parameters with Glorot / fan_avg.
-        for p in model.parameters():
-            if p.dim() > 1:
-                nn.init.xavier_uniform_(p)
-        return model
-
-    def __init__(self, opt):
-        super(TransformerModel, self).__init__(opt)
-        self.opt = opt
-        # self.config = yaml.load(open(opt.config_file))
-
-        self.N_enc = getattr(opt, 'N_enc', opt.num_layers)
-        self.N_dec = getattr(opt, 'N_dec', opt.num_layers)
-        self.d_model = getattr(opt, 'd_model', opt.input_encoding_size)
-        self.d_ff = getattr(opt, 'd_ff', opt.rnn_size)
-        self.h = getattr(opt, 'num_att_heads', 8)
-        self.dropout = getattr(opt, 'dropout', 0.1)
-
-        delattr(self, 'att_embed')
-        self.att_embed = nn.Sequential(*(
-            ((nn.BatchNorm1d(self.att_feat_size),) if self.use_bn else ())+
-            (nn.Linear(self.att_feat_size, self.d_model),
-             nn.ReLU(),
-             nn.Dropout(self.drop_prob_lm))+
-            ((nn.BatchNorm1d(self.d_model),) if self.use_bn==2 else ())))
-
-        delattr(self, 'embed')
-        self.embed = lambda x : x
-        delattr(self, 'fc_embed')
-        self.fc_embed = lambda x : x
-        delattr(self, 'logit')
-        del self.ctx2att
-
-        tgt_vocab = self.vocab_size + 1
-
-
-        self.model = self.make_model(0, tgt_vocab,
-                                     N_enc=self.N_enc,
-                                     N_dec=self.N_dec,
-                                     d_model=self.d_model,
-                                     d_ff=self.d_ff,
-                                     h=self.h,
-                                     dropout=self.dropout)
-
-    def logit(self, x): # unsafe way
-        return self.model.generator.proj(x)
-
-    def init_hidden(self, bsz):
-        return []
-
-    def _prepare_feature(self, fc_feats, att_feats, att_masks):
-
-        att_feats, seq, att_masks, seq_mask = self._prepare_feature_forward(att_feats, att_masks)
-        memory = self.model.encode(att_feats, att_masks)
-
-        return fc_feats[...,:0], att_feats[...,:0], memory, att_masks
-
-    def _prepare_feature_forward(self, att_feats, att_masks=None, seq=None):
-        att_feats, att_masks = self.clip_att(att_feats, att_masks)
-
-        att_feats = pack_wrapper(self.att_embed, att_feats, att_masks)
-
-        if att_masks is None:
-            att_masks = att_feats.new_ones(att_feats.shape[:2], dtype=torch.long)
-        att_masks = att_masks.unsqueeze(-2)
-
-        if seq is not None:
-            # crop the last one
-            # seq = seq[:,:-1]
-            seq_mask = (seq.data != self.eos_idx) & (seq.data != self.pad_idx)
-            seq_mask[:,0] = 1 # bos
-
-            seq_mask = seq_mask.unsqueeze(-2)
-            seq_mask = seq_mask & subsequent_mask(seq.size(-1)).to(seq_mask)
-
-            seq_per_img = seq.shape[0] // att_feats.shape[0]
-            if seq_per_img > 1:
-                att_feats, att_masks = utils.repeat_tensors(seq_per_img,
-                    [att_feats, att_masks]
-                )
-        else:
-            seq_mask = None
-
-        return att_feats, seq, att_masks, seq_mask
-
-    def _forward(self, fc_feats, att_feats, seq, att_masks=None):
-        if seq.ndim == 3:  # B * seq_per_img * seq_len
-            seq = seq.reshape(-1, seq.shape[2])
-        att_feats, seq, att_masks, seq_mask = self._prepare_feature_forward(att_feats, att_masks, seq)
-
-        out = self.model(att_feats, seq, att_masks, seq_mask)
-
-        outputs = self.model.generator(out)
-        return outputs
-        # return torch.cat([_.unsqueeze(1) for _ in outputs], 1)
-
-    def core(self, it, fc_feats_ph, att_feats_ph, memory, state, mask):
-        """
-        state = [ys.unsqueeze(0)]
-        """
-        if len(state) == 0:
-            ys = it.unsqueeze(1)
-        else:
-            ys = torch.cat([state[0][0], it.unsqueeze(1)], dim=1)
-        out = self.model.decode(memory, mask,
-                                ys,
-                                subsequent_mask(ys.size(1))
-                                    .to(memory.device))
-        return out[:, -1], [ys.unsqueeze(0)]
captioning/models/__init__.py
DELETED
@@ -1,73 +0,0 @@
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-import os
-import copy
-
-import numpy as np
-import torch
-
-from .ShowTellModel import ShowTellModel
-from .FCModel import FCModel
-from .AttModel import *
-from .TransformerModel import TransformerModel
-from .cachedTransformer import TransformerModel as cachedTransformer
-from .BertCapModel import BertCapModel
-from .M2Transformer import M2TransformerModel
-from .AoAModel import AoAModel
-
-def setup(opt):
-    if opt.caption_model in ['fc', 'show_tell']:
-        print('Warning: %s model is mostly deprecated; many new features are not supported.' %opt.caption_model)
-        if opt.caption_model == 'fc':
-            print('Use newfc instead of fc')
-    if opt.caption_model == 'fc':
-        model = FCModel(opt)
-    elif opt.caption_model == 'language_model':
-        model = LMModel(opt)
-    elif opt.caption_model == 'newfc':
-        model = NewFCModel(opt)
-    elif opt.caption_model == 'show_tell':
-        model = ShowTellModel(opt)
-    # Att2in model in self-critical
-    elif opt.caption_model == 'att2in':
-        model = Att2inModel(opt)
-    # Att2in model with two-layer MLP img embedding and word embedding
-    elif opt.caption_model == 'att2in2':
-        model = Att2in2Model(opt)
-    elif opt.caption_model == 'att2all2':
-        print('Warning: this is not a correct implementation of the att2all model in the original paper.')
-        model = Att2all2Model(opt)
-    # Adaptive Attention model from Knowing when to look
-    elif opt.caption_model == 'adaatt':
-        model = AdaAttModel(opt)
-    # Adaptive Attention with maxout lstm
-    elif opt.caption_model == 'adaattmo':
-        model = AdaAttMOModel(opt)
-    # Top-down attention model
-    elif opt.caption_model in ['topdown', 'updown']:
-        model = UpDownModel(opt)
-    # StackAtt
-    elif opt.caption_model == 'stackatt':
-        model = StackAttModel(opt)
-    # DenseAtt
-    elif opt.caption_model == 'denseatt':
-        model = DenseAttModel(opt)
-    # Transformer
-    elif opt.caption_model == 'transformer':
-        if getattr(opt, 'cached_transformer', False):
-            model = cachedTransformer(opt)
-        else:
-            model = TransformerModel(opt)
-    # AoANet
-    elif opt.caption_model == 'aoa':
-        model = AoAModel(opt)
-    elif opt.caption_model == 'bert':
-        model = BertCapModel(opt)
-    elif opt.caption_model == 'm2transformer':
-        model = M2TransformerModel(opt)
-    else:
-        raise Exception("Caption model not supported: {}".format(opt.caption_model))
-
-    return model
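
The deleted `setup()` above is a plain string-to-constructor dispatch on `opt.caption_model`. The snippet below is a generic, self-contained sketch of that pattern only (not this repo's API; the names in the registry are placeholders), showing the lookup-then-fail-loudly structure.

    # Generic sketch of the dispatch pattern: look up a builder by name, fail loudly otherwise.
    registry = {
        'transformer': lambda opt: ('TransformerModel', opt),
        'aoa': lambda opt: ('AoAModel', opt),
    }

    def build(name, opt):
        if name not in registry:
            raise Exception("Caption model not supported: {}".format(name))
        return registry[name](opt)

    print(build('transformer', {'num_layers': 6})[0])  # -> 'TransformerModel'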
captioning/models/cachedTransformer.py
DELETED
@@ -1,420 +0,0 @@
-# This file contains Transformer network
-# Most of the code is copied from http://nlp.seas.harvard.edu/2018/04/03/attention.html
-
-# The cfg name correspondance:
-# N=num_layers
-# d_model=input_encoding_size
-# d_ff=rnn_size
-# h is always 8
-
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
-from . import utils
-
-import copy
-import math
-import numpy as np
-
-from .CaptionModel import CaptionModel
-from .AttModel import sort_pack_padded_sequence, pad_unsort_packed_sequence, pack_wrapper, AttModel
-
-class EncoderDecoder(nn.Module):
-    """
-    A standard Encoder-Decoder architecture. Base for this and many
-    other models.
-    """
-    def __init__(self, encoder, decoder, src_embed, tgt_embed, generator):
-        super(EncoderDecoder, self).__init__()
-        self.encoder = encoder
-        self.decoder = decoder
-        self.src_embed = src_embed
-        self.tgt_embed = tgt_embed
-        self.generator = generator
-
-    def forward(self, src, tgt, src_mask, tgt_mask):
-        "Take in and process masked src and target sequences."
-        return self.decode(self.encode(src, src_mask), src_mask,
-                           tgt, tgt_mask)
-
-    def encode(self, src, src_mask):
-        return self.encoder(self.src_embed(src), src_mask)
-
-    def decode(self, memory, src_mask, tgt, tgt_mask, past=None):
-        return self.decoder(self.tgt_embed(tgt), memory, src_mask, tgt_mask, past=past)
-
-class Generator(nn.Module):
-    "Define standard linear + softmax generation step."
-    def __init__(self, d_model, vocab):
-        super(Generator, self).__init__()
-        self.proj = nn.Linear(d_model, vocab)
-
-    def forward(self, x):
-        return F.log_softmax(self.proj(x), dim=-1)
-
-def clones(module, N):
-    "Produce N identical layers."
-    return nn.ModuleList([copy.deepcopy(module) for _ in range(N)])
-
-class Encoder(nn.Module):
-    "Core encoder is a stack of N layers"
-    def __init__(self, layer, N):
-        super(Encoder, self).__init__()
-        self.layers = clones(layer, N)
-        self.norm = LayerNorm(layer.size)
-
-    def forward(self, x, mask):
-        "Pass the input (and mask) through each layer in turn."
-        for layer in self.layers:
-            x = layer(x, mask)
-        return self.norm(x)
-
-class LayerNorm(nn.Module):
-    "Construct a layernorm module (See citation for details)."
-    def __init__(self, features, eps=1e-6):
-        super(LayerNorm, self).__init__()
-        self.a_2 = nn.Parameter(torch.ones(features))
-        self.b_2 = nn.Parameter(torch.zeros(features))
-        self.eps = eps
-
-    def forward(self, x):
-        mean = x.mean(-1, keepdim=True)
-        std = x.std(-1, keepdim=True)
-        return self.a_2 * (x - mean) / (std + self.eps) + self.b_2
-
-class SublayerConnection(nn.Module):
-    """
-    A residual connection followed by a layer norm.
-    Note for code simplicity the norm is first as opposed to last.
-    """
-    def __init__(self, size, dropout):
-        super(SublayerConnection, self).__init__()
-        self.norm = LayerNorm(size)
-        self.dropout = nn.Dropout(dropout)
-
-    def forward(self, x, sublayer):
-        "Apply residual connection to any sublayer with the same size."
-        _x = sublayer(self.norm(x))
-        if type(_x) is tuple: # for multi-head attention that returns past
-            return x + self.dropout(_x[0]), _x[1]
-        return x + self.dropout(_x)
-
-class EncoderLayer(nn.Module):
-    "Encoder is made up of self-attn and feed forward (defined below)"
-    def __init__(self, size, self_attn, feed_forward, dropout):
-        super(EncoderLayer, self).__init__()
-        self.self_attn = self_attn
-        self.feed_forward = feed_forward
-        self.sublayer = clones(SublayerConnection(size, dropout), 2)
-        self.size = size
-
-    def forward(self, x, mask):
-        "Follow Figure 1 (left) for connections."
-        x = self.sublayer[0](x, lambda x: self.self_attn(x, x, x, mask))
-        return self.sublayer[1](x, self.feed_forward)
-
-class Decoder(nn.Module):
-    "Generic N layer decoder with masking."
-    def __init__(self, layer, N):
-        super(Decoder, self).__init__()
-        self.layers = clones(layer, N)
-        self.norm = LayerNorm(layer.size)
-
-    def forward(self, x, memory, src_mask, tgt_mask, past=None):
-        if past is not None:
-            present = [[], []]
-            x = x[:, -1:]
-            tgt_mask = tgt_mask[:, -1:] if tgt_mask is not None else None
-            past = list(zip(past[0].split(2, dim=0), past[1].split(2, dim=0)))
-        else:
-            past = [None] * len(self.layers)
-        for i, (layer, layer_past) in enumerate(zip(self.layers, past)):
-            x = layer(x, memory, src_mask, tgt_mask,
-                      layer_past)
-            if layer_past is not None:
-                present[0].append(x[1][0])
-                present[1].append(x[1][1])
-                x = x[0]
-        if past[0] is None:
-            return self.norm(x)
-        else:
-            return self.norm(x), [torch.cat(present[0], 0), torch.cat(present[1], 0)]
-
-
-class DecoderLayer(nn.Module):
-    "Decoder is made of self-attn, src-attn, and feed forward (defined below)"
-    def __init__(self, size, self_attn, src_attn, feed_forward, dropout):
-        super(DecoderLayer, self).__init__()
-        self.size = size
-        self.self_attn = self_attn
-        self.src_attn = src_attn
-        self.feed_forward = feed_forward
-        self.sublayer = clones(SublayerConnection(size, dropout), 3)
-
-    def forward(self, x, memory, src_mask, tgt_mask, layer_past=None):
-        "Follow Figure 1 (right) for connections."
-        m = memory
-        if layer_past is None:
-            x = self.sublayer[0](x, lambda x: self.self_attn(x, x, x, tgt_mask))
-            x = self.sublayer[1](x, lambda x: self.src_attn(x, m, m, src_mask))
-            return self.sublayer[2](x, self.feed_forward)
-        else:
-            present = [None, None]
-            x, present[0] = self.sublayer[0](x, lambda x: self.self_attn(x, x, x, tgt_mask, layer_past[0]))
-            x, present[1] = self.sublayer[1](x, lambda x: self.src_attn(x, m, m, src_mask, layer_past[1]))
-            return self.sublayer[2](x, self.feed_forward), present
-
-def subsequent_mask(size):
-    "Mask out subsequent positions."
-    attn_shape = (1, size, size)
-    subsequent_mask = np.triu(np.ones(attn_shape), k=1).astype('uint8')
-    return torch.from_numpy(subsequent_mask) == 0
-
-def attention(query, key, value, mask=None, dropout=None):
-    "Compute 'Scaled Dot Product Attention'"
-    d_k = query.size(-1)
-    scores = torch.matmul(query, key.transpose(-2, -1)) \
-             / math.sqrt(d_k)
-    if mask is not None:
-        scores = scores.masked_fill(mask == 0, float('-inf'))
-    p_attn = F.softmax(scores, dim = -1)
-    if dropout is not None:
-        p_attn = dropout(p_attn)
-    return torch.matmul(p_attn, value), p_attn
-
-class MultiHeadedAttention(nn.Module):
-    def __init__(self, h, d_model, dropout=0.1):
-        "Take in model size and number of heads."
-        super(MultiHeadedAttention, self).__init__()
-        assert d_model % h == 0
-        # We assume d_v always equals d_k
-        self.d_k = d_model // h
-        self.h = h
-        self.linears = clones(nn.Linear(d_model, d_model), 4)
-        self.attn = None
-        self.dropout = nn.Dropout(p=dropout)
-
-    def forward(self, query, key, value, mask=None, layer_past=None):
-        "Implements Figure 2"
-        if mask is not None:
-            # Same mask applied to all h heads.
-            mask = mask.unsqueeze(1)
-        nbatches = query.size(0)
-
-        # The past works differently here. For self attn, the query and key be updated incrementailly
-        # For src_attn the past is fixed.
-
-        # For src_attn, when the layer past is ready
-        if layer_past is not None and layer_past.shape[2] == key.shape[1] > 1: # suppose memory size always greater than 1
-            query = self.linears[0](query)
-            key, value = layer_past[0], layer_past[1]
-            present = torch.stack([key, value])
-        else:
-            # 1) Do all the linear projections in batch from d_model => h x d_k
-            query, key, value = \
-                [l(x) for l, x in zip(self.linears, (query, key, value))]
-
-            # self attn + past OR the first time step of src attn
-            if layer_past is not None and not (layer_past.shape[2] == key.shape[1] > 1):
-                past_key, past_value = layer_past[0], layer_past[1]
-                key = torch.cat((past_key, key), dim=1)
-                value = torch.cat((past_value, value), dim=1)
-                present = torch.stack([key, value])
-
-        query, key, value = \
-            [x.view(nbatches, -1, self.h, self.d_k).transpose(1, 2)
-             for x in [query, key, value]]
-
-        # 2) Apply attention on all the projected vectors in batch.
-        x, self.attn = attention(query, key, value, mask=mask,
-                                 dropout=self.dropout)
-
-        # 3) "Concat" using a view and apply a final linear.
-        x = x.transpose(1, 2).contiguous() \
-             .view(nbatches, -1, self.h * self.d_k)
-        if layer_past is not None:
-            return self.linears[-1](x), present
-        else:
-            return self.linears[-1](x)
-
-class PositionwiseFeedForward(nn.Module):
-    "Implements FFN equation."
-    def __init__(self, d_model, d_ff, dropout=0.1):
-        super(PositionwiseFeedForward, self).__init__()
-        self.w_1 = nn.Linear(d_model, d_ff)
-        self.w_2 = nn.Linear(d_ff, d_model)
-        self.dropout = nn.Dropout(dropout)
-
-    def forward(self, x):
-        return self.w_2(self.dropout(F.relu(self.w_1(x))))
-
-class Embeddings(nn.Module):
-    def __init__(self, d_model, vocab):
-        super(Embeddings, self).__init__()
-        self.lut = nn.Embedding(vocab, d_model)
-        self.d_model = d_model
-
-    def forward(self, x):
-        return self.lut(x) * math.sqrt(self.d_model)
-
-class PositionalEncoding(nn.Module):
-    "Implement the PE function."
-    def __init__(self, d_model, dropout, max_len=5000):
-        super(PositionalEncoding, self).__init__()
-        self.dropout = nn.Dropout(p=dropout)
-
-        # Compute the positional encodings once in log space.
-        pe = torch.zeros(max_len, d_model)
-        position = torch.arange(0, max_len).unsqueeze(1).float()
-        div_term = torch.exp(torch.arange(0, d_model, 2).float() *
-                             -(math.log(10000.0) / d_model))
-        pe[:, 0::2] = torch.sin(position * div_term)
-        pe[:, 1::2] = torch.cos(position * div_term)
-        pe = pe.unsqueeze(0)
-        self.register_buffer('pe', pe)
-
-    def forward(self, x):
-        x = x + self.pe[:, :x.size(1)]
-        return self.dropout(x)
-
-class TransformerModel(AttModel):
-
-    def make_model(self, src_vocab, tgt_vocab, N_enc=6, N_dec=6,
-                   d_model=512, d_ff=2048, h=8, dropout=0.1):
-        "Helper: Construct a model from hyperparameters."
-        c = copy.deepcopy
-        attn = MultiHeadedAttention(h, d_model, dropout)
-        ff = PositionwiseFeedForward(d_model, d_ff, dropout)
-        position = PositionalEncoding(d_model, dropout)
-        model = EncoderDecoder(
-            Encoder(EncoderLayer(d_model, c(attn), c(ff), dropout), N_enc),
-            Decoder(DecoderLayer(d_model, c(attn), c(attn),
-                                 c(ff), dropout), N_dec),
-            lambda x:x, # nn.Sequential(Embeddings(d_model, src_vocab), c(position)),
-            nn.Sequential(Embeddings(d_model, tgt_vocab), c(position)),
-            Generator(d_model, tgt_vocab))
-
-        # This was important from their code.
-        # Initialize parameters with Glorot / fan_avg.
-        for p in model.parameters():
-            if p.dim() > 1:
-                nn.init.xavier_uniform_(p)
-        return model
-
-    def __init__(self, opt):
-        super(TransformerModel, self).__init__(opt)
-        self.opt = opt
-        # self.config = yaml.load(open(opt.config_file))
-
-        self.N_enc = getattr(opt, 'N_enc', opt.num_layers)
-        self.N_dec = getattr(opt, 'N_dec', opt.num_layers)
-        self.d_model = getattr(opt, 'd_model', opt.input_encoding_size)
-        self.d_ff = getattr(opt, 'd_ff', opt.rnn_size)
-        self.h = getattr(opt, 'num_att_heads', 8)
-        self.dropout = getattr(opt, 'dropout', 0.1)
-
-        delattr(self, 'att_embed')
-        self.att_embed = nn.Sequential(*(
-            ((nn.BatchNorm1d(self.att_feat_size),) if self.use_bn else ())+
-            (nn.Linear(self.att_feat_size, self.d_model),
-             nn.ReLU(),
-             nn.Dropout(self.drop_prob_lm))+
-            ((nn.BatchNorm1d(self.d_model),) if self.use_bn==2 else ())))
-
-        delattr(self, 'embed')
-        self.embed = lambda x : x
-        delattr(self, 'fc_embed')
-        self.fc_embed = lambda x : x
-        delattr(self, 'logit')
-        del self.ctx2att
-
-        tgt_vocab = self.vocab_size + 1
-
-
-        self.model = self.make_model(0, tgt_vocab,
-                                     N_enc=self.N_enc,
-                                     N_dec=self.N_dec,
-                                     d_model=self.d_model,
-                                     d_ff=self.d_ff,
-                                     h=self.h,
-                                     dropout=self.dropout)
-
-    def logit(self, x): # unsafe way
-        return self.model.generator.proj(x)
-
-    def init_hidden(self, bsz):
-        return []
-
-    def _prepare_feature(self, fc_feats, att_feats, att_masks):
-
-        att_feats, seq, att_masks, seq_mask = self._prepare_feature_forward(att_feats, att_masks)
-        memory = self.model.encode(att_feats, att_masks)
-
-        return fc_feats[...,:0], att_feats[...,:0], memory, att_masks
-
-    def _prepare_feature_forward(self, att_feats, att_masks=None, seq=None):
-        att_feats, att_masks = self.clip_att(att_feats, att_masks)
-
-        att_feats = pack_wrapper(self.att_embed, att_feats, att_masks)
-
-        if att_masks is None:
-            att_masks = att_feats.new_ones(att_feats.shape[:2], dtype=torch.long)
-        att_masks = att_masks.unsqueeze(-2)
-
-        if seq is not None:
-            # crop the last one
-            # seq = seq[:,:-1]
-            seq_mask = (seq.data != self.eos_idx) & (seq.data != self.pad_idx)
-            seq_mask[:,0] = 1 # bos
-
-            seq_mask = seq_mask.unsqueeze(-2)
-            seq_mask = seq_mask & subsequent_mask(seq.size(-1)).to(seq_mask)
-
-            seq_per_img = seq.shape[0] // att_feats.shape[0]
-            if seq_per_img > 1:
-                att_feats, att_masks = utils.repeat_tensors(seq_per_img,
-                    [att_feats, att_masks]
-                )
-        else:
-            seq_mask = None
-
-        return att_feats, seq, att_masks, seq_mask
-
-    def _forward(self, fc_feats, att_feats, seq, att_masks=None):
-        if seq.ndim == 3:  # B * seq_per_img * seq_len
-            seq = seq.reshape(-1, seq.shape[2])
-        att_feats, seq, att_masks, seq_mask = self._prepare_feature_forward(att_feats, att_masks, seq)
-
-        out = self.model(att_feats, seq, att_masks, seq_mask)
-
-        outputs = self.model.generator(out)
-        return outputs
-        # return torch.cat([_.unsqueeze(1) for _ in outputs], 1)
-
-    def core(self, it, fc_feats_ph, att_feats_ph, memory, state, mask):
-        """
-        state is the precomputed key/value. N_dec x seq_len x d_model
-        Note: due to the layer norm, it's not equivalant to stateless,
-        but it seems behaving similar
-        """
-        # state is tokens + past
-        if len(state) == 0:
-            ys = it.unsqueeze(1)
-            # basically empty state, just to let it know to return past
-            # The second dim has to be batch_size, for beam search purpose
-            past = [fc_feats_ph.new_zeros(self.N_dec * 2, fc_feats_ph.shape[0], 0, self.d_model), # self
-                    fc_feats_ph.new_zeros(self.N_dec * 2, fc_feats_ph.shape[0], 0, self.d_model)] # src
-            # 2 for self attn, 2 for src attn
-        else:
-            ys = torch.cat([state[0][0], it.unsqueeze(1)], dim=1)
-            past = state[1:]
-        out, past = self.model.decode(memory, mask,
-                                      ys, # We still feed the full past words, because we need it for position embedding to know the position id
-                                      subsequent_mask(ys.size(1))
-                                          .to(memory.device),
-                                      past=past)
-        return out[:, -1], [ys.unsqueeze(0)] + past
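
The cached variant above carries per-layer key/value tensors between decoding steps so that only the newest token is projected at each step. The toy loop below illustrates that caching idea in isolation (it is a concept sketch with made-up shapes, not this repo's API).

    import torch

    # Toy key/value cache: append the newest step's projections instead of recomputing all steps.
    cache_k = torch.empty(1, 0, 8)   # (batch, cached_seq_len, dim)
    cache_v = torch.empty(1, 0, 8)
    for step in range(3):
        new_k = torch.randn(1, 1, 8)  # projection for the newest token only
        new_v = torch.randn(1, 1, 8)
        cache_k = torch.cat([cache_k, new_k], dim=1)
        cache_v = torch.cat([cache_v, new_v], dim=1)
    print(cache_k.shape)  # torch.Size([1, 3, 8]) after three decoding steps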
captioning/models/utils.py
DELETED
@@ -1,25 +0,0 @@
-import torch
-
-def repeat_tensors(n, x):
-    """
-    For a tensor of size Bx..., we repeat it n times, and make it Bnx...
-    For collections, do nested repeat
-    """
-    if torch.is_tensor(x):
-        x = x.unsqueeze(1) # Bx1x...
-        x = x.expand(-1, n, *([-1]*len(x.shape[2:]))) # Bxnx...
-        x = x.reshape(x.shape[0]*n, *x.shape[2:]) # Bnx...
-    elif type(x) is list or type(x) is tuple:
-        x = [repeat_tensors(n, _) for _ in x]
-    return x
-
-
-def split_tensors(n, x):
-    if torch.is_tensor(x):
-        assert x.shape[0] % n == 0
-        x = x.reshape(x.shape[0] // n, n, *x.shape[1:]).unbind(1)
-    elif type(x) is list or type(x) is tuple:
-        x = [split_tensors(n, _) for _ in x]
-    elif x is None:
-        x = [None] * n
-    return x
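
The two helpers above are pure shape utilities. As a quick check of the interleaved-repeat behaviour, here is a standalone sketch that mirrors `repeat_tensors` (copied logic, shown outside the package so it runs on its own):

    import torch

    def repeat_tensors(n, x):
        # Same logic as the deleted helper: B x ... -> (B*n) x ..., recursing into lists/tuples.
        if torch.is_tensor(x):
            x = x.unsqueeze(1).expand(-1, n, *([-1] * len(x.shape[2:])))
            x = x.reshape(x.shape[0] * n, *x.shape[2:])
        elif isinstance(x, (list, tuple)):
            x = [repeat_tensors(n, _) for _ in x]
        return x

    print(repeat_tensors(5, torch.zeros(2, 3)).shape)  # torch.Size([10, 3])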
captioning/modules/loss_wrapper.py
DELETED
@@ -1,127 +0,0 @@
-import torch
-from . import losses
-from ..utils.rewards import init_scorer, get_self_critical_reward, get_self_critical_clipscore_reward
-from ..utils.clipscore import CLIPScore
-import numpy as np
-
-class LossWrapper(torch.nn.Module):
-    def __init__(self, model, opt):
-        super(LossWrapper, self).__init__()
-        self.opt = opt
-        self.model = model
-        if opt.label_smoothing > 0:
-            self.crit = losses.LabelSmoothing(smoothing=opt.label_smoothing)
-        else:
-            self.crit = losses.LanguageModelCriterion()
-        self.rl_crit = losses.RewardCriterion()
-        self.struc_crit = losses.StructureLosses(opt)
-
-        self.clipscore_model = None
-        if self.opt.use_clipscore:
-            use_grammar = getattr(self.opt, 'use_grammar', False)
-            joint_out = getattr(self.opt, 'joint_out', False)
-            self.clipscore_model = CLIPScore(
-                mode=opt.clipscore_mode,
-                use_grammar=use_grammar,
-                joint_out=joint_out,
-            )
-            for p in self.clipscore_model.parameters():
-                p.requires_grad = False
-
-            if use_grammar:
-                state_dict = torch.load(self.opt.clip_load_path, map_location='cpu')
-                self.clipscore_model.load_state_dict(state_dict['state_dict'])
-
-    def forward(self, fc_feats, att_feats, labels, masks, att_masks, gts, gt_indices,
-                sc_flag, struc_flag, clip_vis_feats=None):
-        opt = self.opt
-
-        out = {}
-        if struc_flag:
-            if opt.structure_loss_weight < 1:
-                lm_loss = self.crit(self.model(fc_feats, att_feats, labels[..., :-1], att_masks), labels[..., 1:], masks[..., 1:])
-            else:
-                lm_loss = torch.tensor(0).type_as(fc_feats)
-            if opt.structure_loss_weight > 0:
-                gen_result, sample_logprobs = self.model(fc_feats, att_feats, att_masks,
-                    opt={'sample_method':opt.train_sample_method,
-                        'beam_size':opt.train_beam_size,
-                        'output_logsoftmax': opt.struc_use_logsoftmax or opt.structure_loss_type == 'softmax_margin'\
-                            or not 'margin' in opt.structure_loss_type,
-                        'sample_n': opt.train_sample_n},
-                    mode='sample')
-                gts = [gts[_] for _ in gt_indices.tolist()]
-                struc_loss = self.struc_crit(sample_logprobs, gen_result, gts)
-            else:
-                struc_loss = {'loss': torch.tensor(0).type_as(fc_feats),
-                              'reward': torch.tensor(0).type_as(fc_feats)}
-            loss = (1-opt.structure_loss_weight) * lm_loss + opt.structure_loss_weight * struc_loss['loss']
-            out['lm_loss'] = lm_loss
-            out['struc_loss'] = struc_loss['loss']
-            out['reward'] = struc_loss['reward']
-        elif not sc_flag:
-            loss = self.crit(self.model(fc_feats, att_feats, labels[..., :-1], att_masks), labels[..., 1:], masks[..., 1:])
-        else:
-            self.model.eval()
-            with torch.no_grad():
-                greedy_res, _ = self.model(fc_feats, att_feats, att_masks,
-                    mode='sample',
-                    opt={'sample_method': opt.sc_sample_method,
-                         'beam_size': opt.sc_beam_size})
-            self.model.train()
-            gen_result, sample_logprobs = self.model(fc_feats, att_feats, att_masks,
-                    opt={'sample_method':opt.train_sample_method,
-                        'beam_size':opt.train_beam_size,
-                        'sample_n': opt.train_sample_n},
-                    mode='sample')
-            gts = [gts[_] for _ in gt_indices.tolist()]
-
-            if getattr(self.opt, 'use_multi_rewards', False):
-                assert self.opt.use_clipscore
-                clipscore_reward_normalized, clipscore_unnormalized_mean, grammar_rewards = get_self_critical_clipscore_reward(
-                    greedy_res, gts, gen_result, self.opt, self.clipscore_model, clip_vis_feats, self.model.vocab)
-
-                if self.opt.clipscore_mode == 'clip_s':
-                    out['CLIP-S'] = clipscore_unnormalized_mean
-                elif self.opt.clipscore_mode == 'refclip_s':
-                    out['RefCLIP-S'] = clipscore_unnormalized_mean
-
-                if getattr(self.opt, 'use_grammar', False):
-                    out['grammar_reward'] = grammar_rewards.mean()
-
-                    reward = clipscore_reward_normalized + grammar_rewards
-
-
-                else:
-                    assert grammar_rewards is None
-
-                    cider_reward_normalized, cider_unnormalized_mean = get_self_critical_reward(
-                        greedy_res, gts, gen_result, self.opt)
-                    out['CIDEr'] = cider_unnormalized_mean
-                    if isinstance(cider_reward_normalized, np.ndarray):
-                        cider_reward_normalized = torch.from_numpy(cider_reward_normalized).to(clipscore_reward_normalized.device)
-
-                    reward = clipscore_reward_normalized + cider_reward_normalized
-            else:
-                if self.opt.use_clipscore:
-                    clipscore_reward_normalized, clipscore_unnormalized_mean, _ = get_self_critical_clipscore_reward(
-                        greedy_res, gts, gen_result, self.opt, self.clipscore_model, clip_vis_feats, self.model.vocab)
-                    if self.opt.clipscore_mode == 'clip_s':
-                        out['CLIP-S'] = clipscore_unnormalized_mean
-                    elif self.opt.clipscore_mode == 'refclip_s':
-                        out['RefCLIP-S'] = clipscore_unnormalized_mean
-                    reward = clipscore_reward_normalized
-                else:
-                    cider_reward_normalized, cider_unnormalized_mean = get_self_critical_reward(
-                        greedy_res, gts, gen_result, self.opt)
-                    out['CIDEr'] = cider_unnormalized_mean
-                    reward = cider_reward_normalized
-
-            if isinstance(reward, np.ndarray):
-                reward = torch.from_numpy(reward)
-            reward = reward.to(sample_logprobs)
-            loss = self.rl_crit(sample_logprobs, gen_result.data, reward)
-            out['reward'] = reward[:,0].mean()
-        out['loss'] = loss
-        return out
-
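
The wrapper above drives self-critical training: each sampled caption's reward is its score minus the score of the greedy caption, and that advantage weights the sample's log-probabilities. The numbers below are a tiny made-up numeric sketch of that reward shaping in general (not this repo's scorer or reward normalization).

    import torch

    # Scores for 4 sampled captions of one image, with the greedy caption's score as baseline.
    sample_scores = torch.tensor([0.62, 0.55, 0.71, 0.48])
    greedy_score = torch.tensor(0.60)

    advantage = sample_scores - greedy_score           # positive -> reinforce, negative -> suppress
    logprobs = torch.tensor([-4.1, -3.8, -4.5, -3.9])  # summed log-probs of each sample (illustrative values)
    loss = -(advantage * logprobs).mean()              # REINFORCE with the greedy baseline
    print(advantage, loss)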
captioning/modules/losses.py
DELETED
@@ -1,218 +0,0 @@
-import torch
-import torch.nn as nn
-from ..utils.rewards import get_scores, get_self_cider_scores
-
-class RewardCriterion(nn.Module):
-    def __init__(self):
-        super(RewardCriterion, self).__init__()
-
-    def forward(self, input, seq, reward):
-        input = input.gather(2, seq.unsqueeze(2)).squeeze(2)
-
-        input = input.reshape(-1)
-        reward = reward.reshape(-1)
-        mask = (seq>0).to(input)
-        mask = torch.cat([mask.new(mask.size(0), 1).fill_(1), mask[:, :-1]], 1).reshape(-1)
-        output = - input * reward * mask
-        output = torch.sum(output) / torch.sum(mask)
-
-        return output
-
-class StructureLosses(nn.Module):
-    """
-    This loss is inspired by Classical Structured Prediction Losses for Sequence to Sequence Learning (Edunov et al., 2018).
-    """
-    def __init__(self, opt):
-        super(StructureLosses, self).__init__()
-        self.opt = opt
-        self.loss_type = opt.structure_loss_type
-
-    def forward(self, input, seq, data_gts):
-        """
-        Input is either logits or log softmax
-        """
-        out = {}
-
-        batch_size = input.size(0)# batch_size = sample_size * seq_per_img
-        seq_per_img = batch_size // len(data_gts)
-
-        assert seq_per_img == self.opt.train_sample_n, seq_per_img
-
-        mask = (seq>0).to(input)
-        mask = torch.cat([mask.new_full((mask.size(0), 1), 1), mask[:, :-1]], 1)
-
-        scores = get_scores(data_gts, seq, self.opt)
-        scores = torch.from_numpy(scores).type_as(input).view(-1, seq_per_img)
-        out['reward'] = scores #.mean()
-        if self.opt.entropy_reward_weight > 0:
-            entropy = - (F.softmax(input, dim=2) * F.log_softmax(input, dim=2)).sum(2).data
-            entropy = (entropy * mask).sum(1) / mask.sum(1)
-            print('entropy', entropy.mean().item())
-            scores = scores + self.opt.entropy_reward_weight * entropy.view(-1, seq_per_img)
-        # rescale cost to [0,1]
-        costs = - scores
-        if self.loss_type == 'risk' or self.loss_type == 'softmax_margin':
-            costs = costs - costs.min(1, keepdim=True)[0]
-            costs = costs / costs.max(1, keepdim=True)[0]
-        # in principle
-        # Only risk need such rescale
-        # margin should be alright; Let's try.
-
-        # Gather input: BxTxD -> BxT
-        input = input.gather(2, seq.unsqueeze(2)).squeeze(2)
-
-        if self.loss_type == 'seqnll':
-            # input is logsoftmax
-            input = input * mask
-            input = input.sum(1) / mask.sum(1)
-            input = input.view(-1, seq_per_img)
-
-            target = costs.min(1)[1]
-            output = F.cross_entropy(input, target)
-        elif self.loss_type == 'risk':
-            # input is logsoftmax
-            input = input * mask
-            input = input.sum(1)
-            input = input.view(-1, seq_per_img)
-
-            output = (F.softmax(input.exp()) * costs).sum(1).mean()
-
-            # test
-            # avg_scores = input
-            # probs = F.softmax(avg_scores.exp_())
-            # loss = (probs * costs.type_as(probs)).sum() / input.size(0)
-            # print(output.item(), loss.item())
-
-        elif self.loss_type == 'max_margin':
-            # input is logits
-            input = input * mask
-            input = input.sum(1) / mask.sum(1)
-            input = input.view(-1, seq_per_img)
-            _, __ = costs.min(1, keepdim=True)
-            costs_star = _
-            input_star = input.gather(1, __)
-            output = F.relu(costs - costs_star - input_star + input).max(1)[0] / 2
-            output = output.mean()
-
-            # sanity test
-            # avg_scores = input + costs
-            # scores_with_high_target = avg_scores.clone()
-            # scores_with_high_target.scatter_(1, costs.min(1)[1].view(-1, 1), 1e10)
-
-            # target_and_offender_index = scores_with_high_target.sort(1, True)[1][:, 0:2]
-            # avg_scores = avg_scores.gather(1, target_and_offender_index)
-            # target_index = avg_scores.new_zeros(avg_scores.size(0), dtype=torch.long)
-            # loss = F.multi_margin_loss(avg_scores, target_index, size_average=True, margin=0)
-            # print(loss.item() * 2, output.item())
-
-        elif self.loss_type == 'multi_margin':
-            # input is logits
-            input = input * mask
-            input = input.sum(1) / mask.sum(1)
-            input = input.view(-1, seq_per_img)
-            _, __ = costs.min(1, keepdim=True)
-            costs_star = _
-            input_star = input.gather(1, __)
-            output = F.relu(costs - costs_star - input_star + input)
-            output = output.mean()
-
-            # sanity test
-            # avg_scores = input + costs
-            # loss = F.multi_margin_loss(avg_scores, costs.min(1)[1], margin=0)
-            # print(output, loss)
-
-        elif self.loss_type == 'softmax_margin':
-            # input is logsoftmax
-            input = input * mask
-            input = input.sum(1) / mask.sum(1)
-            input = input.view(-1, seq_per_img)
-
-            input = input + costs
-            target = costs.min(1)[1]
-            output = F.cross_entropy(input, target)
-
-        elif self.loss_type == 'real_softmax_margin':
-            # input is logits
-            # This is what originally defined in Kevin's paper
-            # The result should be equivalent to softmax_margin
-            input = input * mask
-            input = input.sum(1) / mask.sum(1)
-            input = input.view(-1, seq_per_img)
-
-            input = input + costs
-            target = costs.min(1)[1]
-            output = F.cross_entropy(input, target)
-
-        elif self.loss_type == 'new_self_critical':
-            """
-            A different self critical
-            Self critical uses greedy decoding score as baseline;
-            This setting uses the average score of the rest samples as baseline
-            (suppose c1...cn n samples, reward1 = score1 - 1/(n-1)(score2+..+scoren) )
-            """
-            baseline = (scores.sum(1, keepdim=True) - scores) / (scores.shape[1] - 1)
-            scores = scores - baseline
-            # self cider used as reward to promote diversity (not working that much in this way)
-            if getattr(self.opt, 'self_cider_reward_weight', 0) > 0:
-                _scores = get_self_cider_scores(data_gts, seq, self.opt)
-                _scores = torch.from_numpy(_scores).type_as(scores).view(-1, 1)
-                _scores = _scores.expand_as(scores - 1)
-                scores += self.opt.self_cider_reward_weight * _scores
-            output = - input * mask * scores.view(-1, 1)
-            output = torch.sum(output) / torch.sum(mask)
-
-        out['loss'] = output
-        return out
-
-class LanguageModelCriterion(nn.Module):
-    def __init__(self):
-        super(LanguageModelCriterion, self).__init__()
-
-    def forward(self, input, target, mask):
-        if target.ndim == 3:
-            target = target.reshape(-1, target.shape[2])
-            mask = mask.reshape(-1, mask.shape[2])
-        # truncate to the same size
-        target = target[:, :input.size(1)]
-        mask = mask[:, :input.size(1)].to(input)
-
-        output = -input.gather(2, target.unsqueeze(2)).squeeze(2) * mask
-        # Average over each token
-        output = torch.sum(output) / torch.sum(mask)
-
-        return output
-
-class LabelSmoothing(nn.Module):
-    "Implement label smoothing."
-    def __init__(self, size=0, padding_idx=0, smoothing=0.0):
-        super(LabelSmoothing, self).__init__()
-        self.criterion = nn.KLDivLoss(size_average=False, reduce=False)
-        # self.padding_idx = padding_idx
-        self.confidence = 1.0 - smoothing
-        self.smoothing = smoothing
-        # self.size = size
-        self.true_dist = None
-
-    def forward(self, input, target, mask):
-        if target.ndim == 3:
-            target = target.reshape(-1, target.shape[2])
-            mask = mask.reshape(-1, mask.shape[2])
-        # truncate to the same size
-        target = target[:, :input.size(1)]
-        mask = mask[:, :input.size(1)]
-
-        input = input.reshape(-1, input.size(-1))
-        target = target.reshape(-1)
-        mask = mask.reshape(-1).to(input)
-
-        # assert x.size(1) == self.size
-        self.size = input.size(1)
-        # true_dist = x.data.clone()
-        true_dist = input.data.clone()
-        # true_dist.fill_(self.smoothing / (self.size - 2))
-        true_dist.fill_(self.smoothing / (self.size - 1))
-        true_dist.scatter_(1, target.data.unsqueeze(1), self.confidence)
-        # true_dist[:, self.padding_idx] = 0
-        # mask = torch.nonzero(target.data == self.padding_idx)
-        # self.true_dist = true_dist
-        return (self.criterion(input, true_dist).sum(1) * mask).sum() / mask.sum()
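Note: the RewardCriterion removed above is the policy-gradient loss that the loss wrapper's rl_crit resolves to. For reference, a minimal self-contained sketch of the same computation on dummy tensors; the shapes and values are assumptions for illustration only, not taken from the diff.

import torch
import torch.nn.functional as F

# Stand-in tensors: 2 sampled captions of length 5 over a 10-word vocab (assumed shapes).
logprobs = F.log_softmax(torch.randn(2, 5, 10), dim=-1)   # per-token log-probabilities
seq = torch.randint(1, 10, (2, 5))                        # sampled token ids (0 = pad)
reward = torch.full((2, 5), 0.3)                          # sequence-level advantage broadcast over time

# Same steps as the deleted RewardCriterion.forward:
picked = logprobs.gather(2, seq.unsqueeze(2)).squeeze(2).reshape(-1)   # log-prob of each sampled token
mask = (seq > 0).float()
mask = torch.cat([mask.new_ones(mask.size(0), 1), mask[:, :-1]], 1).reshape(-1)  # keep one step past EOS
loss = -(picked * reward.reshape(-1) * mask).sum() / mask.sum()        # reward-weighted NLL over valid tokens
print(loss)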
captioning/utils/__init__.py
DELETED
File without changes
captioning/utils/clipscore.py
DELETED
@@ -1,396 +0,0 @@
-from transformers import CLIPModel, CLIPTokenizer
-import os
-import json
-import argparse
-from random import shuffle, seed
-import string
-# non-standard dependencies:
-import h5py
-from six.moves import cPickle
-import numpy as np
-import torch
-import torchvision.models as models
-import skimage.io
-
-from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize
-from PIL import Image
-from torch import nn
-
-
-class CLIPScore(nn.Module):
-    def __init__(self, clipscore_w=2.5, image_size=224, mode='clip_s', use_grammar=False, joint_out=False):
-        super(CLIPScore, self).__init__()
-        # from transformers import CLIPModel, CLIPTokenizer
-        self.clip_model = CLIPModel.from_pretrained(
-            'openai/clip-vit-base-patch32')
-        self.tokenizer = CLIPTokenizer.from_pretrained(
-            'openai/clip-vit-base-patch32')
-
-        self.clip_model.eval()
-
-        self.clipscore_w = clipscore_w
-
-        self.image_transform = self._transform(image_size)
-
-        self.mode = mode
-        assert mode in ['clip_s', 'refclip_s']
-
-        self.use_grammar = use_grammar
-        self.joint_out = joint_out
-
-        if self.use_grammar and joint_out is False:
-            self.grammar_score_head = nn.Sequential(
-                nn.Linear(self.clip_model.text_embed_dim, self.clip_model.projection_dim, bias=False),
-                nn.ReLU(),
-                nn.Linear(self.clip_model.projection_dim, 2, bias=False)
-            )
-
-    def _transform(self, n_px):
-        return Compose([
-            Resize(n_px, interpolation=Image.BICUBIC),
-            CenterCrop(n_px),
-            lambda image: image.convert("RGB"),
-            ToTensor(),
-            Normalize((0.48145466, 0.4578275, 0.40821073),
-                      (0.26862954, 0.26130258, 0.27577711)),
-        ])
-
-    def load_image(self, image_path):
-        image = Image.open(image_path)
-        return image
-
-    # @torch.no_grad()
-    def image_extract(self, image):
-        if isinstance(image, str):
-            image = self.load_image(image)
-        if not isinstance(image, torch.Tensor):
-            image = self.image_transform(image)
-
-        img_tensor = image.view(-1, 3, 224, 224)
-        device = next(self.clip_model.parameters()).device
-        img_tensor = img_tensor.to(device)
-
-        clip_model = self.clip_model
-
-        img_feat = clip_model.vision_model(img_tensor).pooler_output
-        img_feat = clip_model.visual_projection(img_feat)
-        img_feat = img_feat / img_feat.norm(dim=-1, keepdim=True)
-
-        return img_feat
-
-    # @torch.no_grad()
-    def text_extract(self, text, prompt="A photo depicts", proj_norm=True):
-        if isinstance(text, str):
-            text_batch = [" ".join([prompt, text])]
-        elif isinstance(text, list):
-            text_batch = [" ".join([prompt, txt]) for txt in text]
-
-        if isinstance(text, tuple) and isinstance(text[0], torch.Tensor):
-            input_ids, attention_mask = text
-        else:
-            input_text = text_batch
-
-            tokenized = self.tokenizer(
-                input_text, return_tensors='pt', padding=True, truncation=True)
-
-            input_ids = tokenized.input_ids
-            attention_mask = tokenized.attention_mask
-
-        clip_model = self.clip_model
-        device = next(self.clip_model.parameters()).device
-        input_ids = input_ids.to(device)
-        attention_mask = attention_mask.to(device)
-
-        text_feat = clip_model.text_model(input_ids, attention_mask).pooler_output
-
-        if proj_norm:
-            text_feat = clip_model.text_projection(text_feat)
-            text_feat = text_feat / text_feat.norm(dim=-1, keepdim=True)
-
-        return text_feat
-
-    # @torch.no_grad()
-    def calc_clip_s(self, img_feat, text_feat):
-        return self.clipscore_w * torch.relu((img_feat * text_feat).sum(dim=-1))
-
-    # @torch.no_grad()
-    def calc_refclip_s(self, img_feat=None, text_feat=None, ref_text_feat=None, ref_text_mask=None, clip_s=None):
-
-        if clip_s is None:
-            clip_s = self.calc_clip_s(img_feat, text_feat)
-
-        B, dim = img_feat.size()
-
-        ref_text_feat = ref_text_feat.view(B, -1, dim)
-
-        K = ref_text_feat.size(1)
-
-        text_feat = text_feat.view(B, 1, dim).expand(-1, K, -1)
-        assert ref_text_feat.size() == text_feat.size(
-        ), (ref_text_feat.size(), text_feat.size())
-
-        ref_score = self.calc_clip_s(text_feat, ref_text_feat)
-        if ref_text_mask is not None:
-            if not isinstance(ref_text_mask, torch.Tensor):
-                ref_text_mask = torch.tensor(
-                    ref_text_mask, dtype=ref_score.dtype, device=ref_score.device)
-            ref_score = ref_score.view(B, K) * ref_text_mask.view(B, K)
-
-        ref_score = ref_score.view(B, K).max(dim=1).values
-
-        assert clip_s.size() == (B,)
-        assert clip_s.size() == ref_score.size()
-
-        # harmonic mean
-        refclip_s = 2 / (1 / clip_s + 1 / ref_score)
-        return refclip_s
-
-    @torch.no_grad()
-    def forward(self,
-                images=None, text=None,
-                img_feat=None, text_feat=None,
-                ref_text=None, ref_text_feat=None, ref_text_mask=None,
-                prompt="A photo depicts",
-                mode=None):
-        if img_feat is None:
-            img_feat = self.image_extract(images)
-        img_feat = img_feat.view(-1, 512)
-
-        B = img_feat.size(0)
-
-        if text_feat is None:
-            text_feat = self.text_extract(text, prompt=prompt)
-        text_feat = text_feat.view(-1, 512)
-
-        if mode is None:
-            mode = self.mode
-        assert mode in ['clip_s', 'refclip_s']
-
-        if mode == 'clip_s':
-            clip_s = self.calc_clip_s(img_feat, text_feat)
-            return clip_s
-        elif mode == 'refclip_s':
-            if ref_text_feat is None:
-                ref_text_feat = self.text_extract(ref_text, prompt=prompt)
-            ref_text_feat = ref_text_feat.view(-1, 512)
-
-            refclip_s = self.calc_refclip_s(
-                img_feat, text_feat, ref_text_feat, ref_text_mask=ref_text_mask)
-            return refclip_s
-
-
-    def train_step(self,
-                   images=None, text=None,
-                   img_feat=None, text_feat=None,
-                   neg_text=None, neg_text_feat=None,
-                   # ref_text=None, ref_text_feat=None, ref_text_mask=None,
-                   prompt="A photo depicts",
-                   # return_loss=True,
-                   **kwargs):
-
-        if img_feat is None:
-            img_feat = self.image_extract(images)
-        img_feat = img_feat.view(-1, 512)
-
-        B = img_feat.size(0)
-
-        if text_feat is None:
-            text_feat = self.text_extract(text, prompt=prompt, proj_norm=False)
-
-            text_cont_feat = self.clip_model.text_projection(text_feat)
-            text_cont_feat = text_cont_feat / text_cont_feat.norm(dim=-1, keepdim=True)
-        text_cont_feat = text_cont_feat.view(B, 512)
-
-        # cosine similarity as logits
-        logit_scale = self.clip_model.logit_scale.exp()
-        logits_per_text = torch.matmul(text_cont_feat, img_feat.t()) * logit_scale
-        # logits_per_image = logits_per_text.T
-
-        clip_loss = clip_loss_fn(logits_per_text)
-
-
-        # negative sampling
-        pos_text_feat = text_feat.view(B, 512)
-        neg_text_feat = self.text_extract(neg_text, prompt=prompt, proj_norm=False).view(B, 512)
-
-        grammar_text_feat = torch.cat([pos_text_feat, neg_text_feat], dim=0)
-
-        # 2B, 1
-        grammar_text_logit = self.grammar_score_head(grammar_text_feat)
-        grammar_labels = torch.LongTensor([1] * B + [0] * B).to(grammar_text_logit.device).view(2 * B)
-
-        grammar_loss = torch.nn.functional.cross_entropy(grammar_text_logit, grammar_labels)
-
-        grammar_pred = grammar_text_logit.argmax(dim=1, keepdim=False)
-        grammar_pos_pred = grammar_pred[:B]
-        grammar_neg_pred = grammar_pred[B:]
-        # grammar_acc = (grammar_pred == grammar_labels).float().mean()
-
-        out = {
-            'clip_loss': clip_loss,
-            'grammar_loss': grammar_loss,
-            'img_feat': img_feat,
-            'text_feat': text_cont_feat,
-            'neg_text_feat': neg_text_feat,
-            'grammar_pos_pred': grammar_pos_pred,
-            'grammar_neg_pred': grammar_neg_pred,
-        }
-
-        return out
-
-# contrastive loss function, adapted from
-# https://sachinruk.github.io/blog/pytorch/pytorch%20lightning/loss%20function/gpu/2021/03/07/CLIP.html
-def contrastive_loss(logits: torch.Tensor, dim: int) -> torch.Tensor:
-    neg_ce = torch.diag(nn.functional.log_softmax(logits, dim=dim))
-    return -neg_ce.mean()
-
-
-def clip_loss_fn(similarity: torch.Tensor) -> torch.Tensor:
-    caption_loss = contrastive_loss(similarity, dim=0)
-    image_loss = contrastive_loss(similarity, dim=1)
-    return (caption_loss + image_loss) / 2.0
-
-
-
-# class CLIPScore(nn.Module):
-#     def __init__(self, clipscore_w=2.5, image_size=224, mode='clip_s'):
-#         super(CLIPScore, self).__init__()
-#         # from transformers import CLIPModel, CLIPTokenizer
-#         self.clip_model = CLIPModel.from_pretrained(
-#             'openai/clip-vit-base-patch32')
-#         self.tokenizer = CLIPTokenizer.from_pretrained(
-#             'openai/clip-vit-base-patch32')
-
-#         self.clip_model.eval()
-
-#         self.clipscore_w = clipscore_w
-
-#         self.image_transform = self._transform(image_size)
-
-#         self.mode = mode
-#         assert mode in ['clip_s', 'refclip_s']
-
-#     def _transform(self, n_px):
-#         return Compose([
-#             Resize(n_px, interpolation=Image.BICUBIC),
-#             CenterCrop(n_px),
-#             lambda image: image.convert("RGB"),
-#             ToTensor(),
-#             Normalize((0.48145466, 0.4578275, 0.40821073),
-#                       (0.26862954, 0.26130258, 0.27577711)),
-#         ])
-
-#     def load_image(self, image_path):
-#         image = Image.open(image_path)
-#         return image
-
-#     @torch.no_grad()
-#     def image_extract(self, image):
-#         if isinstance(image, str):
-#             image = self.load_image(image)
-#         if not isinstance(image, torch.Tensor):
-#             image = self.image_transform(image)
-
-#         img_tensor = image.view(-1, 3, 224, 224)
-#         device = next(self.clip_model.parameters()).device
-#         img_tensor = img_tensor.to(device)
-
-#         clip_model = self.clip_model
-
-#         img_feat = clip_model.vision_model(img_tensor).pooler_output
-#         img_feat = clip_model.visual_projection(img_feat)
-#         img_feat = img_feat / img_feat.norm(dim=-1, keepdim=True)
-
-#         return img_feat
-
-#     @torch.no_grad()
-#     def text_extract(self, text, prompt="A photo depicts"):
-#         if isinstance(text, str):
-#             text_batch = [" ".join([prompt, text])]
-#         else:
-#             text_batch = [" ".join([prompt, txt]) for txt in text]
-
-#         input_text = text_batch
-
-#         tokenized = self.tokenizer(
-#             input_text, return_tensors='pt', padding=True)
-
-#         input_ids = tokenized.input_ids
-#         attention_mask = tokenized.attention_mask
-
-#         clip_model = self.clip_model
-#         device = next(self.clip_model.parameters()).device
-#         input_ids = input_ids.to(device)
-#         attention_mask = attention_mask.to(device)
-
-#         text_feat = clip_model.text_model(input_ids, attention_mask).pooler_output
-#         text_feat = clip_model.text_projection(text_feat)
-#         text_feat = text_feat / text_feat.norm(dim=-1, keepdim=True)
-
-#         return text_feat
-
-#     @torch.no_grad()
-#     def calc_clip_s(self, img_feat, text_feat):
-#         return self.clipscore_w * torch.relu((img_feat * text_feat).sum(dim=-1))
-
-#     @torch.no_grad()
-#     def calc_refclip_s(self, img_feat=None, text_feat=None, ref_text_feat=None, ref_text_mask=None, clip_s=None):
-
-#         if clip_s is None:
-#             clip_s = self.calc_clip_s(img_feat, text_feat)
-
-#         B, dim = img_feat.size()
-
-#         ref_text_feat = ref_text_feat.view(B, -1, dim)
-
-#         K = ref_text_feat.size(1)
-
-#         text_feat = text_feat.view(B, 1, dim).expand(-1, K, -1)
-#         assert ref_text_feat.size() == text_feat.size(), (ref_text_feat.size(), text_feat.size())
-
-#         ref_score = self.calc_clip_s(text_feat, ref_text_feat)
-#         if ref_text_mask is not None:
-#             if not isinstance(ref_text_mask, torch.Tensor):
-#                 ref_text_mask = torch.tensor(ref_text_mask, dtype=ref_score.dtype, device=ref_score.device)
-#             ref_score = ref_score.view(B, K) * ref_text_mask.view(B, K)
-
-#         ref_score = ref_score.view(B, K).max(dim=1).values
-
-#         assert clip_s.size() == (B,)
-#         assert clip_s.size() == ref_score.size()
-
-#         # harmonic mean
-#         refclip_s = 2 / (1 / clip_s + 1 / ref_score)
-#         return refclip_s
-
-
-#     @torch.no_grad()
-#     def forward(self,
-#                 images=None, text=None,
-#                 img_feat=None, text_feat=None,
-#                 ref_text=None, ref_text_feat=None, ref_text_mask=None,
-#                 prompt="A photo depicts",
-#                 mode=None):
-#         if img_feat is None:
-#             img_feat = self.image_extract(images)
-#         img_feat = img_feat.view(-1, 512)
-
-#         if text_feat is None:
-#             text_feat = self.text_extract(text, prompt=prompt)
-#         text_feat = text_feat.view(-1, 512)
-
-#         if mode is None:
-#             mode = self.mode
-#         assert mode in ['clip_s', 'refclip_s']
-
-#         if mode == 'clip_s':
-#             clip_s = self.calc_clip_s(img_feat, text_feat)
-#             return clip_s
-#         elif mode == 'refclip_s':
-#             if ref_text_feat is None:
-#                 ref_text_feat = self.text_extract(ref_text, prompt=prompt)
-#             ref_text_feat = ref_text_feat.view(-1, 512)
-
-#             refclip_s = self.calc_refclip_s(img_feat, text_feat, ref_text_feat, ref_text_mask=ref_text_mask)
-#             return refclip_s
-
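Note: the CLIP-S value this deleted module produces is a 2.5-weighted, clipped cosine similarity between CLIP image and text embeddings. A minimal sketch of the same quantity with the same checkpoint and weight, using the standard transformers processor in place of the hand-rolled preprocessing above; the image path and caption are placeholders, not taken from the diff.

import torch
from transformers import CLIPModel, CLIPProcessor
from PIL import Image

model = CLIPModel.from_pretrained('openai/clip-vit-base-patch32').eval()
processor = CLIPProcessor.from_pretrained('openai/clip-vit-base-patch32')

image = Image.open('example.jpg')                              # placeholder path
caption = "A photo depicts a dog running on the beach"         # prompt prefix as in the deleted code
inputs = processor(text=[caption], images=image, return_tensors='pt', padding=True)
with torch.no_grad():
    out = model(**inputs)
img_feat = out.image_embeds / out.image_embeds.norm(dim=-1, keepdim=True)
txt_feat = out.text_embeds / out.text_embeds.norm(dim=-1, keepdim=True)
clip_s = 2.5 * torch.relu((img_feat * txt_feat).sum(dim=-1))   # same formula as calc_clip_s
print(clip_s)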
captioning/utils/config.py
DELETED
@@ -1,153 +0,0 @@
-# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
-# Copy from fvcore
-
-import logging
-import os
-from typing import Any
-import yaml
-from yacs.config import CfgNode as _CfgNode
-
-import io as PathManager
-
-BASE_KEY = "_BASE_"
-
-
-class CfgNode(_CfgNode):
-    """
-    Our own extended version of :class:`yacs.config.CfgNode`.
-    It contains the following extra features:
-
-    1. The :meth:`merge_from_file` method supports the "_BASE_" key,
-       which allows the new CfgNode to inherit all the attributes from the
-       base configuration file.
-    2. Keys that start with "COMPUTED_" are treated as insertion-only
-       "computed" attributes. They can be inserted regardless of whether
-       the CfgNode is frozen or not.
-    3. With "allow_unsafe=True", it supports pyyaml tags that evaluate
-       expressions in config. See examples in
-       https://pyyaml.org/wiki/PyYAMLDocumentation#yaml-tags-and-python-types
-       Note that this may lead to arbitrary code execution: you must not
-       load a config file from untrusted sources before manually inspecting
-       the content of the file.
-    """
-
-    @staticmethod
-    def load_yaml_with_base(filename, allow_unsafe = False):
-        """
-        Just like `yaml.load(open(filename))`, but inherit attributes from its
-            `_BASE_`.
-
-        Args:
-            filename (str): the file name of the current config. Will be used to
-                find the base config file.
-            allow_unsafe (bool): whether to allow loading the config file with
-                `yaml.unsafe_load`.
-
-        Returns:
-            (dict): the loaded yaml
-        """
-        with PathManager.open(filename, "r") as f:
-            try:
-                cfg = yaml.safe_load(f)
-            except yaml.constructor.ConstructorError:
-                if not allow_unsafe:
-                    raise
-                logger = logging.getLogger(__name__)
-                logger.warning(
-                    "Loading config {} with yaml.unsafe_load. Your machine may "
-                    "be at risk if the file contains malicious content.".format(
-                        filename
-                    )
-                )
-                f.close()
-                with open(filename, "r") as f:
-                    cfg = yaml.unsafe_load(f)
-
-        def merge_a_into_b(a, b):
-            # merge dict a into dict b. values in a will overwrite b.
-            for k, v in a.items():
-                if isinstance(v, dict) and k in b:
-                    assert isinstance(
-                        b[k], dict
-                    ), "Cannot inherit key '{}' from base!".format(k)
-                    merge_a_into_b(v, b[k])
-                else:
-                    b[k] = v
-
-        if BASE_KEY in cfg:
-            base_cfg_file = cfg[BASE_KEY]
-            if base_cfg_file.startswith("~"):
-                base_cfg_file = os.path.expanduser(base_cfg_file)
-            if not any(
-                map(base_cfg_file.startswith, ["/", "https://", "http://"])
-            ):
-                # the path to base cfg is relative to the config file itself.
-                base_cfg_file = os.path.join(
-                    os.path.dirname(filename), base_cfg_file
-                )
-            base_cfg = CfgNode.load_yaml_with_base(
-                base_cfg_file, allow_unsafe=allow_unsafe
-            )
-            del cfg[BASE_KEY]
-
-            merge_a_into_b(cfg, base_cfg)
-            return base_cfg
-        return cfg
-
-    def merge_from_file(self, cfg_filename, allow_unsafe = False):
-        """
-        Merge configs from a given yaml file.
-
-        Args:
-            cfg_filename: the file name of the yaml config.
-            allow_unsafe: whether to allow loading the config file with
-                `yaml.unsafe_load`.
-        """
-        loaded_cfg = CfgNode.load_yaml_with_base(
-            cfg_filename, allow_unsafe=allow_unsafe
-        )
-        loaded_cfg = type(self)(loaded_cfg)
-        self.merge_from_other_cfg(loaded_cfg)
-
-    # Forward the following calls to base, but with a check on the BASE_KEY.
-    def merge_from_other_cfg(self, cfg_other):
-        """
-        Args:
-            cfg_other (CfgNode): configs to merge from.
-        """
-        assert (
-            BASE_KEY not in cfg_other
-        ), "The reserved key '{}' can only be used in files!".format(BASE_KEY)
-        return super().merge_from_other_cfg(cfg_other)
-
-    def merge_from_list(self, cfg_list):
-        """
-        Args:
-            cfg_list (list): list of configs to merge from.
-        """
-        keys = set(cfg_list[0::2])
-        assert (
-            BASE_KEY not in keys
-        ), "The reserved key '{}' can only be used in files!".format(BASE_KEY)
-        return super().merge_from_list(cfg_list)
-
-    def __setattr__(self, name, val):
-        if name.startswith("COMPUTED_"):
-            if name in self:
-                old_val = self[name]
-                if old_val == val:
-                    return
-                raise KeyError(
-                    "Computed attributed '{}' already exists "
-                    "with a different value! old={}, new={}.".format(
-                        name, old_val, val
-                    )
-                )
-            self[name] = val
-        else:
-            super().__setattr__(name, val)
-
-
-if __name__ == '__main__':
-    cfg = CfgNode.load_yaml_with_base('configs/updown_long.yml')
-    print(cfg)
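Note: the main behaviour this CfgNode adds over plain yacs is `_BASE_` inheritance, which the configs/phase1 and configs/phase2 YAML files rely on. A small sketch of the merge semantics that load_yaml_with_base() implements; the inline YAML strings and key names are illustrative, not taken from the diff.

import yaml

base = yaml.safe_load("batch_size: 64\nlearning_rate: 0.0005\n")       # stands in for a base .yml
child = yaml.safe_load("_BASE_: base.yml\nlearning_rate: 0.0001\n")    # child config overriding one key

# After resolving the _BASE_ path, load_yaml_with_base() merges the child into the base:
child.pop('_BASE_')
merged = dict(base)
merged.update(child)   # child keys win; the real code merges nested dicts recursively
print(merged)          # {'batch_size': 64, 'learning_rate': 0.0001}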
captioning/utils/dist_utils.py
DELETED
@@ -1,305 +0,0 @@
-# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
-"""
-This file contains primitives for multi-gpu communication.
-This is useful when doing distributed training.
-"""
-
-import functools
-import logging
-import numpy as np
-import pickle
-import torch
-import torch.distributed as dist
-
-import torch
-
-_LOCAL_PROCESS_GROUP = None
-"""
-A torch process group which only includes processes that on the same machine as the current process.
-This variable is set when processes are spawned by `launch()` in "engine/launch.py".
-"""
-
-
-def get_world_size() -> int:
-    if not dist.is_available():
-        return 1
-    if not dist.is_initialized():
-        return 1
-    return dist.get_world_size()
-
-
-def get_rank() -> int:
-    if not dist.is_available():
-        return 0
-    if not dist.is_initialized():
-        return 0
-    return dist.get_rank()
-
-
-def get_local_rank() -> int:
-    """
-    Returns:
-        The rank of the current process within the local (per-machine) process group.
-    """
-    if not dist.is_available():
-        return 0
-    if not dist.is_initialized():
-        return 0
-    assert _LOCAL_PROCESS_GROUP is not None
-    return dist.get_rank(group=_LOCAL_PROCESS_GROUP)
-
-
-def get_local_size() -> int:
-    """
-    Returns:
-        The size of the per-machine process group,
-        i.e. the number of processes per machine.
-    """
-    if not dist.is_available():
-        return 1
-    if not dist.is_initialized():
-        return 1
-    return dist.get_world_size(group=_LOCAL_PROCESS_GROUP)
-
-
-def is_main_process() -> bool:
-    return get_rank() == 0
-
-
-def synchronize():
-    """
-    Helper function to synchronize (barrier) among all processes when
-    using distributed training
-    """
-    if not dist.is_available():
-        return
-    if not dist.is_initialized():
-        return
-    world_size = dist.get_world_size()
-    if world_size == 1:
-        return
-    dist.barrier()
-
-
-@functools.lru_cache()
-def _get_global_gloo_group():
-    """
-    Return a process group based on gloo backend, containing all the ranks
-    The result is cached.
-    """
-    if dist.get_backend() == "nccl":
-        return dist.new_group(backend="gloo")
-    else:
-        return dist.group.WORLD
-
-
-def _serialize_to_tensor(data, group):
-    backend = dist.get_backend(group)
-    assert backend in ["gloo", "nccl"]
-    device = torch.device("cpu" if backend == "gloo" else "cuda")
-
-    buffer = pickle.dumps(data)
-    if len(buffer) > 1024 ** 3:
-        logger = logging.getLogger(__name__)
-        logger.warning(
-            "Rank {} trying to all-gather {:.2f} GB of data on device {}".format(
-                get_rank(), len(buffer) / (1024 ** 3), device
-            )
-        )
-    storage = torch.ByteStorage.from_buffer(buffer)
-    tensor = torch.ByteTensor(storage).to(device=device)
-    return tensor
-
-
-def _pad_to_largest_tensor(tensor, group):
-    """
-    Returns:
-        list[int]: size of the tensor, on each rank
-        Tensor: padded tensor that has the max size
-    """
-    world_size = dist.get_world_size(group=group)
-    assert (
-        world_size >= 1
-    ), "comm.gather/all_gather must be called from ranks within the given group!"
-    local_size = torch.tensor(
-        [tensor.numel()], dtype=torch.int64, device=tensor.device)
-    size_list = [
-        torch.zeros([1], dtype=torch.int64, device=tensor.device)
-        for _ in range(world_size)
-    ]
-    dist.all_gather(size_list, local_size, group=group)
-    size_list = [int(size.item()) for size in size_list]
-
-    max_size = max(size_list)
-
-    # we pad the tensor because torch all_gather does not support
-    # gathering tensors of different shapes
-    if local_size != max_size:
-        padding = torch.zeros(
-            (max_size - local_size,), dtype=torch.uint8, device=tensor.device
-        )
-        tensor = torch.cat((tensor, padding), dim=0)
-    return size_list, tensor
-
-
-def all_gather(data, group=None):
-    """
-    Run all_gather on arbitrary picklable data (not necessarily tensors).
-    Args:
-        data: any picklable object
-        group: a torch process group. By default, will use a group which
-            contains all ranks on gloo backend.
-    Returns:
-        list[data]: list of data gathered from each rank
-    """
-    if get_world_size() == 1:
-        return [data]
-    if group is None:
-        group = _get_global_gloo_group()
-    if dist.get_world_size(group) == 1:
-        return [data]
-
-    tensor = _serialize_to_tensor(data, group)
-
-    size_list, tensor = _pad_to_largest_tensor(tensor, group)
-    max_size = max(size_list)
-
-    # receiving Tensor from all ranks
-    tensor_list = [
-        torch.empty((max_size,), dtype=torch.uint8, device=tensor.device)
-        for _ in size_list
-    ]
-    dist.all_gather(tensor_list, tensor, group=group)
-
-    data_list = []
-    for size, tensor in zip(size_list, tensor_list):
-        buffer = tensor.cpu().numpy().tobytes()[:size]
-        data_list.append(pickle.loads(buffer))
-
-    return data_list
-
-
-def gather(data, dst=0, group=None):
-    """
-    Run gather on arbitrary picklable data (not necessarily tensors).
-    Args:
-        data: any picklable object
-        dst (int): destination rank
-        group: a torch process group. By default, will use a group which
-            contains all ranks on gloo backend.
-    Returns:
-        list[data]: on dst, a list of data gathered from each rank. Otherwise,
-            an empty list.
-    """
-    if get_world_size() == 1:
-        return [data]
-    if group is None:
-        group = _get_global_gloo_group()
-    if dist.get_world_size(group=group) == 1:
-        return [data]
-    rank = dist.get_rank(group=group)
-
-    tensor = _serialize_to_tensor(data, group)
-    size_list, tensor = _pad_to_largest_tensor(tensor, group)
-
-    # receiving Tensor from all ranks
-    if rank == dst:
-        max_size = max(size_list)
-        tensor_list = [
-            torch.empty((max_size,), dtype=torch.uint8, device=tensor.device)
-            for _ in size_list
-        ]
-        dist.gather(tensor, tensor_list, dst=dst, group=group)
-
-        data_list = []
-        for size, tensor in zip(size_list, tensor_list):
-            buffer = tensor.cpu().numpy().tobytes()[:size]
-            data_list.append(pickle.loads(buffer))
-        return data_list
-    else:
-        dist.gather(tensor, [], dst=dst, group=group)
-        return []
-
-
-def shared_random_seed():
-    """
-    Returns:
-        int: a random number that is the same across all workers.
-            If workers need a shared RNG, they can use this shared seed to
-            create one.
-        All workers must call this function, otherwise it will deadlock.
-    """
-    ints = np.random.randint(2 ** 31)
-    all_ints = all_gather(ints)
-    return all_ints[0]
-
-
-# def reduce_dict(input_dict, average=True):
-#     """
-#     Reduce the values in the dictionary from all processes so that process with rank
-#     0 has the reduced results.
-#     Args:
-#         input_dict (dict): inputs to be reduced. All the values must be scalar CUDA Tensor.
-#         average (bool): whether to do average or sum
-#     Returns:
-#         a dict with the same keys as input_dict, after reduction.
-#     """
-#     world_size = get_world_size()
-#     if world_size < 2:
-#         return input_dict
-#     with torch.no_grad():
-#         names = []
-#         values = []
-#         # sort the keys so that they are consistent across processes
-#         for k in sorted(input_dict.keys()):
-#             names.append(k)
-#             values.append(input_dict[k])
-#         values = torch.stack(values, dim=0)
-#         dist.reduce(values, dst=0)
-#         if dist.get_rank() == 0 and average:
-#             # only main process gets accumulated, so only divide by
-#             # world_size in this case
-#             values /= world_size
-#         reduced_dict = {k: v for k, v in zip(names, values)}
-#     return reduced_dict
-
-
-def reduce_dict(input_dict, average=True):
-    """
-    Reduce the values in the dictionary from all processes so that process with rank
-    0 has the reduced results.
-    Args:
-        input_dict (dict): inputs to be reduced. (values not necessarily tensors).
-        average (bool): whether to do average or sum
-    Returns:
-        a dict with the same keys as input_dict, after reduction.
-    """
-
-    world_size = get_world_size()
-    if world_size < 2:
-        return input_dict
-
-    with torch.no_grad():
-
-        # Convert to CUDA Tensor for dist.reduce()
-        input_dict_cuda_vals = {}
-        for k, v in input_dict.items():
-            if type(v) == torch.Tensor:
-                input_dict_cuda_vals[k] = v.to('cuda')
-            else:
-                input_dict_cuda_vals[k] = torch.tensor(v, device='cuda')
-
-        names = []
-        values = []
-        for k, v in sorted(input_dict_cuda_vals.items()):
-            names.append(k)
-            values.append(v)
-        values = torch.stack(values, dim=0)
-        dist.reduce(values, dst=0)  # reduce to gpu 0
-
-        if dist.get_rank() == 0 and average:
-            # only main process gets accumulated, so only divide by
-            # world_size in this case
-            values /= world_size
-        reduced_dict = {k: v for k, v in zip(names, values)}
-    return reduced_dict
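Note: all_gather() and gather() above work by pickling arbitrary Python objects into uint8 tensors so that torch.distributed collectives, which only move tensors, can ship them between ranks. A minimal single-process sketch of that round trip; the example dictionary is made up.

import pickle
import torch

data = {'image_id': 42, 'caption': 'a cat on a sofa'}                # any picklable object
buffer = pickle.dumps(data)                                          # serialize to bytes
tensor = torch.ByteTensor(torch.ByteStorage.from_buffer(buffer))     # wrap bytes in a uint8 tensor
roundtrip = pickle.loads(tensor.numpy().tobytes())                   # what the receiving rank does
print(roundtrip == data)   # True; across ranks these tensors are padded to a common size first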
captioning/utils/div_utils.py
DELETED
@@ -1,38 +0,0 @@
-from random import uniform
-import numpy as np
-from collections import OrderedDict, defaultdict
-from itertools import tee
-import time
-
-# -----------------------------------------------
-def find_ngrams(input_list, n):
-    return zip(*[input_list[i:] for i in range(n)])
-
-def compute_div_n(caps,n=1):
-    aggr_div = []
-    for k in caps:
-        all_ngrams = set()
-        lenT = 0.
-        for c in caps[k]:
-            tkns = c.split()
-            lenT += len(tkns)
-            ng = find_ngrams(tkns, n)
-            all_ngrams.update(ng)
-        aggr_div.append(float(len(all_ngrams))/ (1e-6 + float(lenT)))
-    return np.array(aggr_div).mean(), np.array(aggr_div)
-
-def compute_global_div_n(caps,n=1):
-    aggr_div = []
-    all_ngrams = set()
-    lenT = 0.
-    for k in caps:
-        for c in caps[k]:
-            tkns = c.split()
-            lenT += len(tkns)
-            ng = find_ngrams(tkns, n)
-            all_ngrams.update(ng)
-    if n == 1:
-        aggr_div.append(float(len(all_ngrams)))
-    else:
-        aggr_div.append(float(len(all_ngrams))/ (1e-6 + float(lenT)))
-    return aggr_div[0], np.repeat(np.array(aggr_div),len(caps))
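Note: Div-n here is the number of distinct n-grams divided by the total token count over an image's sampled captions, so higher means more diverse samples. A tiny worked example of the per-image Div-1 computation; the captions are made up.

caps = {'img1': ['a dog runs on the beach', 'a dog is running near the sea']}

def find_ngrams(input_list, n):
    return zip(*[input_list[i:] for i in range(n)])

all_ngrams, lenT = set(), 0.0
for c in caps['img1']:
    tkns = c.split()
    lenT += len(tkns)
    all_ngrams.update(find_ngrams(tkns, 1))
print(len(all_ngrams) / lenT)   # distinct unigrams / total tokens, i.e. Div-1 for this image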
captioning/utils/eval_multi.py
DELETED
@@ -1,218 +0,0 @@
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-import torch
-import torch.nn as nn
-
-import numpy as np
-import json
-from json import encoder
-import random
-import string
-import time
-import os
-import sys
-from . import misc as utils
-from eval_utils import getCOCO
-
-from .div_utils import compute_div_n, compute_global_div_n
-
-import sys
-try:
-    sys.path.append("coco-caption")
-    annFile = 'coco-caption/annotations/captions_val2014.json'
-    from pycocotools.coco import COCO
-    from pycocoevalcap.eval import COCOEvalCap
-    from pycocoevalcap.eval_spice import COCOEvalCapSpice
-    from pycocoevalcap.tokenizer.ptbtokenizer import PTBTokenizer
-    from pycocoevalcap.bleu.bleu import Bleu
-    sys.path.append("cider")
-    from pyciderevalcap.cider.cider import Cider
-except:
-    print('Warning: requirements for eval_multi not satisfied')
-
-
-def eval_allspice(dataset, preds_n, model_id, split):
-    coco = getCOCO(dataset)
-    valids = coco.getImgIds()
-
-    capsById = {}
-    for d in preds_n:
-        capsById[d['image_id']] = capsById.get(d['image_id'], []) + [d]
-
-    # filter results to only those in MSCOCO validation set (will be about a third)
-    preds_filt_n = [p for p in preds_n if p['image_id'] in valids]
-    print('using %d/%d predictions_n' % (len(preds_filt_n), len(preds_n)))
-    cache_path_n = os.path.join('eval_results/', model_id + '_' + split + '_n.json')
-    json.dump(preds_filt_n, open(cache_path_n, 'w')) # serialize to temporary json file. Sigh, COCO API...
-
-    # Eval AllSPICE
-    cocoRes_n = coco.loadRes(cache_path_n)
-    cocoEvalAllSPICE = COCOEvalCapSpice(coco, cocoRes_n)
-    cocoEvalAllSPICE.params['image_id'] = cocoRes_n.getImgIds()
-    cocoEvalAllSPICE.evaluate()
-
-    out = {}
-    for metric, score in cocoEvalAllSPICE.eval.items():
-        out['All'+metric] = score
-
-    imgToEvalAllSPICE = cocoEvalAllSPICE.imgToEval
-    # collect SPICE_sub_score
-    for k in list(imgToEvalAllSPICE.values())[0]['SPICE'].keys():
-        if k != 'All':
-            out['AllSPICE_'+k] = np.array([v['SPICE'][k]['f'] for v in imgToEvalAllSPICE.values()])
-            out['AllSPICE_'+k] = (out['AllSPICE_'+k][out['AllSPICE_'+k]==out['AllSPICE_'+k]]).mean()
-    for p in preds_filt_n:
-        image_id, caption = p['image_id'], p['caption']
-        imgToEvalAllSPICE[image_id]['caption'] = capsById[image_id]
-    return {'overall': out, 'imgToEvalAllSPICE': imgToEvalAllSPICE}
-
-def eval_oracle(dataset, preds_n, model_id, split):
-    cache_path = os.path.join('eval_results/', model_id + '_' + split + '_n.json')
-
-    coco = getCOCO(dataset)
-    valids = coco.getImgIds()
-
-    capsById = {}
-    for d in preds_n:
-        capsById[d['image_id']] = capsById.get(d['image_id'], []) + [d]
-
-    sample_n = capsById[list(capsById.keys())[0]]
-    for i in range(len(capsById[list(capsById.keys())[0]])):
-        preds = [_[i] for _ in capsById.values()]
-
-        json.dump(preds, open(cache_path, 'w')) # serialize to temporary json file. Sigh, COCO API...
-
-        cocoRes = coco.loadRes(cache_path)
-        cocoEval = COCOEvalCap(coco, cocoRes)
-        cocoEval.params['image_id'] = cocoRes.getImgIds()
-        cocoEval.evaluate()
-
-        imgToEval = cocoEval.imgToEval
-        for img_id in capsById.keys():
-            tmp = imgToEval[img_id]
-            for k in tmp['SPICE'].keys():
-                if k != 'All':
-                    tmp['SPICE_'+k] = tmp['SPICE'][k]['f']
-                    if tmp['SPICE_'+k] != tmp['SPICE_'+k]: # nan
-                        tmp['SPICE_'+k] = -100
-            tmp['SPICE'] = tmp['SPICE']['All']['f']
-            if tmp['SPICE'] != tmp['SPICE']: tmp['SPICE'] = -100
-            capsById[img_id][i]['scores'] = imgToEval[img_id]
-
-    out = {'overall': {}, 'ImgToEval': {}}
-    for img_id in capsById.keys():
-        out['ImgToEval'][img_id] = {}
-        for metric in capsById[img_id][0]['scores'].keys():
-            if metric == 'image_id': continue
-            out['ImgToEval'][img_id]['oracle_'+metric] = max([_['scores'][metric] for _ in capsById[img_id]])
-            out['ImgToEval'][img_id]['avg_'+metric] = sum([_['scores'][metric] for _ in capsById[img_id]]) / len(capsById[img_id])
-        out['ImgToEval'][img_id]['captions'] = capsById[img_id]
-    for metric in list(out['ImgToEval'].values())[0].keys():
-        if metric == 'captions':
-            continue
-        tmp = np.array([_[metric] for _ in out['ImgToEval'].values()])
-        tmp = tmp[tmp!=-100]
-        out['overall'][metric] = tmp.mean()
-
-    return out
-
-def eval_div_stats(dataset, preds_n, model_id, split):
-    tokenizer = PTBTokenizer()
-
-    capsById = {}
-    for i, d in enumerate(preds_n):
-        d['id'] = i
-        capsById[d['image_id']] = capsById.get(d['image_id'], []) + [d]
-
-    n_caps_perimg = len(capsById[list(capsById.keys())[0]])
-    print(n_caps_perimg)
-    _capsById = capsById # save the untokenized version
-    capsById = tokenizer.tokenize(capsById)
-
-    div_1, adiv_1 = compute_div_n(capsById,1)
-    div_2, adiv_2 = compute_div_n(capsById,2)
-
-    globdiv_1, _= compute_global_div_n(capsById,1)
-
-    print('Diversity Statistics are as follows: \n Div1: %.2f, Div2: %.2f, gDiv1: %d\n'%(div_1,div_2, globdiv_1))
-
-    # compute mbleu
-    scorer = Bleu(4)
-    all_scrs = []
-    scrperimg = np.zeros((n_caps_perimg, len(capsById)))
-
-    for i in range(n_caps_perimg):
-        tempRefsById = {}
-        candsById = {}
-        for k in capsById:
-            tempRefsById[k] = capsById[k][:i] + capsById[k][i+1:]
-            candsById[k] = [capsById[k][i]]
-
-        score, scores = scorer.compute_score(tempRefsById, candsById)
-        all_scrs.append(score)
-        scrperimg[i,:] = scores[1]
-
-    all_scrs = np.array(all_scrs)
-
-    out = {}
-    out['overall'] = {'Div1': div_1, 'Div2': div_2, 'gDiv1': globdiv_1}
-    for k, score in zip(range(4), all_scrs.mean(axis=0).tolist()):
-        out['overall'].update({'mBLeu_%d'%(k+1): score})
-    imgToEval = {}
-    for i,imgid in enumerate(capsById.keys()):
-        imgToEval[imgid] = {'mBleu_2' : scrperimg[:,i].mean()}
-        imgToEval[imgid]['individuals'] = []
-        for j, d in enumerate(_capsById[imgid]):
-            imgToEval[imgid]['individuals'].append(preds_n[d['id']])
-            imgToEval[imgid]['individuals'][-1]['mBleu_2'] = scrperimg[j,i]
-    out['ImgToEval'] = imgToEval
-
-    print('Mean mutual Bleu scores on this set is:\nmBLeu_1, mBLeu_2, mBLeu_3, mBLeu_4')
-    print(all_scrs.mean(axis=0))
-
-    return out
-
-def eval_self_cider(dataset, preds_n, model_id, split):
-    cache_path = os.path.join('eval_results/', model_id + '_' + split + '_n.json')
-
-    coco = getCOCO(dataset)
-    valids = coco.getImgIds()
-
-    # Get Cider_scorer
-    Cider_scorer = Cider(df='corpus')
-
-    tokenizer = PTBTokenizer()
-    gts = {}
-    for imgId in valids:
-        gts[imgId] = coco.imgToAnns[imgId]
-    gts = tokenizer.tokenize(gts)
-
-    for imgId in valids:
-        Cider_scorer.cider_scorer += (None, gts[imgId])
-    Cider_scorer.cider_scorer.compute_doc_freq()
-    Cider_scorer.cider_scorer.ref_len = np.log(float(len(Cider_scorer.cider_scorer.crefs)))
-
-    # Prepare captions
-    capsById = {}
-    for d in preds_n:
-        capsById[d['image_id']] = capsById.get(d['image_id'], []) + [d]
-
-    capsById = tokenizer.tokenize(capsById)
-    imgIds = list(capsById.keys())
-    scores = Cider_scorer.my_self_cider([capsById[_] for _ in imgIds])
-
-    def get_div(eigvals):
-        eigvals = np.clip(eigvals, 0, None)
-        return -np.log(np.sqrt(eigvals[-1]) / (np.sqrt(eigvals).sum())) / np.log(len(eigvals))
-    sc_scores = [get_div(np.linalg.eigvalsh(_/10)) for _ in scores]
-    score = np.mean(np.array(sc_scores))
-
-    imgToEval = {}
-    for i, image_id in enumerate(imgIds):
-        imgToEval[image_id] = {'self_cider': sc_scores[i], 'self_cider_mat': scores[i].tolist()}
-    return {'overall': {'self_cider': score}, 'imgToEval': imgToEval}
-
-
-    return score
captioning/utils/eval_utils.py
DELETED
@@ -1,281 +0,0 @@
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
-
-import numpy as np
-import json
-from json import encoder
-import random
-import string
-import time
-import os
-import sys
-from . import misc as utils
-
-# load coco-caption if available
-try:
-    sys.path.append("coco-caption")
-    from pycocotools.coco import COCO
-    from pycocoevalcap.eval import COCOEvalCap
-except:
-    print('Warning: coco-caption not available')
-
-bad_endings = ['a','an','the','in','for','at','of','with','before','after','on','upon','near','to','is','are','am']
-bad_endings += ['the']
-
-
-def count_bad(sen):
-    sen = sen.split(' ')
-    if sen[-1] in bad_endings:
-        return 1
-    else:
-        return 0
-
-
-def getCOCO(dataset):
-    if 'coco' in dataset:
-        annFile = 'coco-caption/annotations/captions_val2014.json'
-    elif 'flickr30k' in dataset or 'f30k' in dataset:
-        annFile = 'data/f30k_captions4eval.json'
-    return COCO(annFile)
-
-
-def language_eval(dataset, preds, preds_n, eval_kwargs, split):
-    model_id = eval_kwargs['id']
-    eval_oracle = eval_kwargs.get('eval_oracle', 0)
-
-    # create output dictionary
-    out = {}
-
-    if len(preds_n) > 0:
-        # vocab size and novel sentences
-        if 'coco' in dataset:
-            dataset_file = 'data/dataset_coco.json'
-        elif 'flickr30k' in dataset or 'f30k' in dataset:
-            dataset_file = 'data/dataset_flickr30k.json'
-        training_sentences = set([' '.join(__['tokens']) for _ in json.load(open(dataset_file))['images'] if not _['split'] in ['val', 'test'] for __ in _['sentences']])
-        generated_sentences = set([_['caption'] for _ in preds_n])
-        novels = generated_sentences - training_sentences
-        out['novel_sentences'] = float(len(novels)) / len(preds_n)
-        tmp = [_.split() for _ in generated_sentences]
-        words = []
-        for _ in tmp:
-            words += _
-        out['vocab_size'] = len(set(words))
-
-    # encoder.FLOAT_REPR = lambda o: format(o, '.3f')
-
-    cache_path = os.path.join('eval_results/', '.cache_'+ model_id + '_' + split + '.json')
-
-    coco = getCOCO(dataset)
-    valids = coco.getImgIds()
-
-    # filter results to only those in MSCOCO validation set
-    preds_filt = [p for p in preds if p['image_id'] in valids]
-    mean_perplexity = sum([_['perplexity'] for _ in preds_filt]) / len(preds_filt)
-    mean_entropy = sum([_['entropy'] for _ in preds_filt]) / len(preds_filt)
-    print('using %d/%d predictions' % (len(preds_filt), len(preds)))
-    json.dump(preds_filt, open(cache_path, 'w')) # serialize to temporary json file. Sigh, COCO API...
-
-    cocoRes = coco.loadRes(cache_path)
-    cocoEval = COCOEvalCap(coco, cocoRes)
-    cocoEval.params['image_id'] = cocoRes.getImgIds()
-    cocoEval.evaluate()
-
-    for metric, score in cocoEval.eval.items():
-        out[metric] = score
-    # Add mean perplexity
-    out['perplexity'] = mean_perplexity
-    out['entropy'] = mean_entropy
-
-    imgToEval = cocoEval.imgToEval
-    for k in list(imgToEval.values())[0]['SPICE'].keys():
-        if k != 'All':
-            out['SPICE_'+k] = np.array([v['SPICE'][k]['f'] for v in imgToEval.values()])
-            out['SPICE_'+k] = (out['SPICE_'+k][out['SPICE_'+k]==out['SPICE_'+k]]).mean()
-    for p in preds_filt:
-        image_id, caption = p['image_id'], p['caption']
-        imgToEval[image_id]['caption'] = caption
-
-    if len(preds_n) > 0:
-        from . import eval_multi
-        cache_path_n = os.path.join('eval_results/', '.cache_'+ model_id + '_' + split + '_n.json')
-        allspice = eval_multi.eval_allspice(dataset, preds_n, model_id, split)
-        out.update(allspice['overall'])
-        div_stats = eval_multi.eval_div_stats(dataset, preds_n, model_id, split)
-        out.update(div_stats['overall'])
-        if eval_oracle:
-            oracle = eval_multi.eval_oracle(dataset, preds_n, model_id, split)
-            out.update(oracle['overall'])
|
114 |
-
else:
|
115 |
-
oracle = None
|
116 |
-
self_cider = eval_multi.eval_self_cider(dataset, preds_n, model_id, split)
|
117 |
-
out.update(self_cider['overall'])
|
118 |
-
with open(cache_path_n, 'w') as outfile:
|
119 |
-
json.dump({'allspice': allspice, 'div_stats': div_stats, 'oracle': oracle, 'self_cider': self_cider}, outfile)
|
120 |
-
|
121 |
-
out['bad_count_rate'] = sum([count_bad(_['caption']) for _ in preds_filt]) / float(len(preds_filt))
|
122 |
-
outfile_path = os.path.join('eval_results/', model_id + '_' + split + '.json')
|
123 |
-
with open(outfile_path, 'w') as outfile:
|
124 |
-
json.dump({'overall': out, 'imgToEval': imgToEval}, outfile)
|
125 |
-
|
126 |
-
return out
|
127 |
-
|
128 |
-
def eval_split(model, crit, loader, eval_kwargs={}):
|
129 |
-
verbose = eval_kwargs.get('verbose', True)
|
130 |
-
verbose_beam = eval_kwargs.get('verbose_beam', 0)
|
131 |
-
verbose_loss = eval_kwargs.get('verbose_loss', 1)
|
132 |
-
num_images = eval_kwargs.get('num_images', eval_kwargs.get('val_images_use', -1))
|
133 |
-
split = eval_kwargs.get('split', 'val')
|
134 |
-
lang_eval = eval_kwargs.get('language_eval', 0)
|
135 |
-
dataset = eval_kwargs.get('dataset', 'coco')
|
136 |
-
beam_size = eval_kwargs.get('beam_size', 1)
|
137 |
-
sample_n = eval_kwargs.get('sample_n', 1)
|
138 |
-
remove_bad_endings = eval_kwargs.get('remove_bad_endings', 0)
|
139 |
-
os.environ["REMOVE_BAD_ENDINGS"] = str(remove_bad_endings) # Use this nasty way to make other code clean since it's a global configuration
|
140 |
-
device = eval_kwargs.get('device', 'cuda')
|
141 |
-
|
142 |
-
# Make sure in the evaluation mode
|
143 |
-
model.eval()
|
144 |
-
|
145 |
-
loader.reset_iterator(split)
|
146 |
-
|
147 |
-
n = 0
|
148 |
-
loss = 0
|
149 |
-
loss_sum = 0
|
150 |
-
loss_evals = 1e-8
|
151 |
-
predictions = []
|
152 |
-
n_predictions = [] # when sample_n > 1
|
153 |
-
while True:
|
154 |
-
data = loader.get_batch(split)
|
155 |
-
n = n + len(data['infos'])
|
156 |
-
|
157 |
-
tmp = [data['fc_feats'], data['att_feats'], data['labels'], data['masks'], data['att_masks']]
|
158 |
-
tmp = [_.to(device) if _ is not None else _ for _ in tmp]
|
159 |
-
fc_feats, att_feats, labels, masks, att_masks = tmp
|
160 |
-
if labels is not None and verbose_loss:
|
161 |
-
# forward the model to get loss
|
162 |
-
with torch.no_grad():
|
163 |
-
loss = crit(model(fc_feats, att_feats, labels[..., :-1], att_masks), labels[..., 1:], masks[..., 1:]).item()
|
164 |
-
loss_sum = loss_sum + loss
|
165 |
-
loss_evals = loss_evals + 1
|
166 |
-
|
167 |
-
# forward the model to also get generated samples for each image
|
168 |
-
with torch.no_grad():
|
169 |
-
tmp_eval_kwargs = eval_kwargs.copy()
|
170 |
-
tmp_eval_kwargs.update({'sample_n': 1})
|
171 |
-
seq, seq_logprobs = model(fc_feats, att_feats, att_masks, opt=tmp_eval_kwargs, mode='sample')
|
172 |
-
seq = seq.data
|
173 |
-
entropy = - (F.softmax(seq_logprobs, dim=2) * seq_logprobs).sum(2).sum(1) / ((seq>0).to(seq_logprobs).sum(1)+1)
|
174 |
-
perplexity = - seq_logprobs.gather(2, seq.unsqueeze(2)).squeeze(2).sum(1) / ((seq>0).to(seq_logprobs).sum(1)+1)
|
175 |
-
|
176 |
-
# Print beam search
|
177 |
-
if beam_size > 1 and verbose_beam:
|
178 |
-
for i in range(fc_feats.shape[0]):
|
179 |
-
print('\n'.join([utils.decode_sequence(model.vocab, _['seq'].unsqueeze(0))[0] for _ in model.done_beams[i]]))
|
180 |
-
print('--' * 10)
|
181 |
-
sents = utils.decode_sequence(model.vocab, seq)
|
182 |
-
|
183 |
-
for k, sent in enumerate(sents):
|
184 |
-
entry = {'image_id': data['infos'][k]['id'], 'caption': sent, 'perplexity': perplexity[k].item(), 'entropy': entropy[k].item()}
|
185 |
-
if eval_kwargs.get('dump_path', 0) == 1:
|
186 |
-
entry['file_name'] = data['infos'][k]['file_path']
|
187 |
-
predictions.append(entry)
|
188 |
-
if eval_kwargs.get('dump_images', 0) == 1:
|
189 |
-
# dump the raw image to vis/ folder
|
190 |
-
cmd = 'cp "' + os.path.join(eval_kwargs['image_root'], data['infos'][k]['file_path']) + '" vis/imgs/img' + str(len(predictions)) + '.jpg' # bit gross
|
191 |
-
print(cmd)
|
192 |
-
os.system(cmd)
|
193 |
-
|
194 |
-
if verbose:
|
195 |
-
print('image %s: %s' %(entry['image_id'], entry['caption']))
|
196 |
-
|
197 |
-
if sample_n > 1:
|
198 |
-
eval_split_n(model, n_predictions, [fc_feats, att_feats, att_masks, data], eval_kwargs)
|
199 |
-
|
200 |
-
# ix0 = data['bounds']['it_pos_now']
|
201 |
-
ix1 = data['bounds']['it_max']
|
202 |
-
if num_images != -1:
|
203 |
-
ix1 = min(ix1, num_images)
|
204 |
-
else:
|
205 |
-
num_images = ix1
|
206 |
-
for i in range(n - ix1):
|
207 |
-
predictions.pop()
|
208 |
-
|
209 |
-
if verbose:
|
210 |
-
print('evaluating validation preformance... %d/%d (%f)' %(n, ix1, loss))
|
211 |
-
|
212 |
-
if num_images >= 0 and n >= num_images:
|
213 |
-
break
|
214 |
-
|
215 |
-
lang_stats = None
|
216 |
-
if len(n_predictions) > 0 and 'perplexity' in n_predictions[0]:
|
217 |
-
n_predictions = sorted(n_predictions, key=lambda x: x['perplexity'])
|
218 |
-
if not os.path.isdir('eval_results'):
|
219 |
-
os.mkdir('eval_results')
|
220 |
-
torch.save((predictions, n_predictions), os.path.join('eval_results/', '.saved_pred_'+ eval_kwargs['id'] + '_' + split + '.pth'))
|
221 |
-
if lang_eval == 1:
|
222 |
-
lang_stats = language_eval(dataset, predictions, n_predictions, eval_kwargs, split)
|
223 |
-
|
224 |
-
# Switch back to training mode
|
225 |
-
model.train()
|
226 |
-
return loss_sum/loss_evals, predictions, lang_stats
|
227 |
-
|
228 |
-
|
229 |
-
# Only run when sample_n > 0
|
230 |
-
def eval_split_n(model, n_predictions, input_data, eval_kwargs={}):
|
231 |
-
verbose = eval_kwargs.get('verbose', True)
|
232 |
-
beam_size = eval_kwargs.get('beam_size', 1)
|
233 |
-
sample_n = eval_kwargs.get('sample_n', 1)
|
234 |
-
sample_n_method = eval_kwargs.get('sample_n_method', 'sample')
|
235 |
-
|
236 |
-
fc_feats, att_feats, att_masks, data = input_data
|
237 |
-
|
238 |
-
tmp_eval_kwargs = eval_kwargs.copy()
|
239 |
-
if sample_n_method == 'bs':
|
240 |
-
# case 1 sample_n == beam size
|
241 |
-
tmp_eval_kwargs.update({'sample_n': 1, 'beam_size': sample_n, 'group_size': 1}) # randomness from softmax
|
242 |
-
with torch.no_grad():
|
243 |
-
model(fc_feats, att_feats, att_masks, opt=tmp_eval_kwargs, mode='sample')
|
244 |
-
for k in range(fc_feats.shape[0]):
|
245 |
-
_sents = utils.decode_sequence(model.vocab, torch.stack([model.done_beams[k][_]['seq'] for _ in range(sample_n)]))
|
246 |
-
for sent in _sents:
|
247 |
-
entry = {'image_id': data['infos'][k]['id'], 'caption': sent}
|
248 |
-
n_predictions.append(entry)
|
249 |
-
# case 2 sample / gumbel / topk sampling/ nucleus sampling
|
250 |
-
elif sample_n_method == 'sample' or \
|
251 |
-
sample_n_method == 'gumbel' or \
|
252 |
-
sample_n_method.startswith('top'):
|
253 |
-
tmp_eval_kwargs.update({'sample_n': sample_n, 'sample_method': sample_n_method, 'beam_size': 1}) # randomness from sample
|
254 |
-
with torch.no_grad():
|
255 |
-
_seq, _sampleLogprobs = model(fc_feats, att_feats, att_masks, opt=tmp_eval_kwargs, mode='sample')
|
256 |
-
_sents = utils.decode_sequence(model.vocab, _seq)
|
257 |
-
_perplexity = - _sampleLogprobs.gather(2, _seq.unsqueeze(2)).squeeze(2).sum(1) / ((_seq>0).to(_sampleLogprobs).sum(1)+1)
|
258 |
-
for k, sent in enumerate(_sents):
|
259 |
-
entry = {'image_id': data['infos'][k // sample_n]['id'], 'caption': sent, 'perplexity': _perplexity[k].item()}
|
260 |
-
n_predictions.append(entry)
|
261 |
-
elif sample_n_method == 'dbs':
|
262 |
-
# Use diverse beam search
|
263 |
-
tmp_eval_kwargs.update({'beam_size': sample_n * beam_size, 'group_size': sample_n}) # randomness from softmax
|
264 |
-
with torch.no_grad():
|
265 |
-
model(fc_feats, att_feats, att_masks, opt=tmp_eval_kwargs, mode='sample')
|
266 |
-
for k in range(loader.batch_size):
|
267 |
-
_sents = utils.decode_sequence(model.vocab, torch.stack([model.done_beams[k][_]['seq'] for _ in range(0, sample_n*beam_size, beam_size)]))
|
268 |
-
for sent in _sents:
|
269 |
-
entry = {'image_id': data['infos'][k]['id'], 'caption': sent}
|
270 |
-
n_predictions.append(entry)
|
271 |
-
else:
|
272 |
-
tmp_eval_kwargs.update({'sample_method': sample_n_method[1:], 'group_size': sample_n, 'beam_size':1}) # randomness from softmax
|
273 |
-
with torch.no_grad():
|
274 |
-
_seq, _sampleLogprobs = model(fc_feats, att_feats, att_masks, opt=tmp_eval_kwargs, mode='sample')
|
275 |
-
_sents = utils.decode_sequence(model.vocab, _seq)
|
276 |
-
for k, sent in enumerate(_sents):
|
277 |
-
entry = {'image_id': data['infos'][k // sample_n]['id'], 'caption': sent}
|
278 |
-
n_predictions.append(entry)
|
279 |
-
if verbose:
|
280 |
-
for entry in sorted(n_predictions[-fc_feats.shape[0] * sample_n:], key=lambda x: x['image_id']):
|
281 |
-
print('image %s: %s' %(entry['image_id'], entry['caption']))
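As a reading of the entropy and perplexity lines inside eval_split above (my notation, not from the repo's docs): with p_t the model's softmax distribution at decoding step t, w_t the sampled token, and T the number of non-padding tokens in the caption,

    entropy = \frac{1}{T+1}\sum_t \Big(-\sum_w p_t(w)\,\log p_t(w)\Big), \qquad
    perplexity = \frac{1}{T+1}\sum_t \big(-\log p_t(w_t)\big)

so the value stored under 'perplexity' is a mean negative log-likelihood per token; it is not exponentiated into a true perplexity.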
captioning/utils/misc.py
DELETED
@@ -1,251 +0,0 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import collections
import torch
import torch.nn as nn
import numpy as np
import torch.optim as optim
import os

import torch.nn.functional as F

import six
from six.moves import cPickle

bad_endings = ['with','in','on','of','a','at','to','for','an','this','his','her','that']
bad_endings += ['the']


def pickle_load(f):
    """ Load a pickle.
    Parameters
    ----------
    f: file-like object
    """
    if six.PY3:
        return cPickle.load(f, encoding='latin-1')
    else:
        return cPickle.load(f)


def pickle_dump(obj, f):
    """ Dump a pickle.
    Parameters
    ----------
    obj: pickled object
    f: file-like object
    """
    if six.PY3:
        return cPickle.dump(obj, f, protocol=2)
    else:
        return cPickle.dump(obj, f)


# modified from https://github.com/facebookresearch/detectron2/blob/master/detectron2/utils/comm.py
def serialize_to_tensor(data):
    device = torch.device("cpu")

    buffer = cPickle.dumps(data)
    storage = torch.ByteStorage.from_buffer(buffer)
    tensor = torch.ByteTensor(storage).to(device=device)
    return tensor


def deserialize(tensor):
    buffer = tensor.cpu().numpy().tobytes()
    return cPickle.loads(buffer)


# Input: seq, N*D numpy array, with element 0 .. vocab_size. 0 is END token.
def decode_sequence(ix_to_word, seq):
    # N, D = seq.size()
    N, D = seq.shape
    out = []
    for i in range(N):
        txt = ''
        for j in range(D):
            ix = seq[i,j]
            if ix > 0 :
                if j >= 1:
                    txt = txt + ' '
                txt = txt + ix_to_word[str(ix.item())]
            else:
                break
        if int(os.getenv('REMOVE_BAD_ENDINGS', '0')):
            flag = 0
            words = txt.split(' ')
            for j in range(len(words)):
                if words[-j-1] not in bad_endings:
                    flag = -j
                    break
            txt = ' '.join(words[0:len(words)+flag])
        out.append(txt.replace('@@ ', ''))
    return out


def save_checkpoint(opt, model, infos, optimizer, histories=None, append=''):
    if len(append) > 0:
        append = '-' + append
    # if checkpoint_path doesn't exist
    if not os.path.isdir(opt.checkpoint_path):
        os.makedirs(opt.checkpoint_path)
    checkpoint_path = os.path.join(opt.checkpoint_path, 'model%s.pth' %(append))
    torch.save(model.state_dict(), checkpoint_path)
    print("model saved to {}".format(checkpoint_path))
    optimizer_path = os.path.join(opt.checkpoint_path, 'optimizer%s.pth' %(append))
    torch.save(optimizer.state_dict(), optimizer_path)
    with open(os.path.join(opt.checkpoint_path, 'infos_'+opt.id+'%s.pkl' %(append)), 'wb') as f:
        pickle_dump(infos, f)
    if histories:
        with open(os.path.join(opt.checkpoint_path, 'histories_'+opt.id+'%s.pkl' %(append)), 'wb') as f:
            pickle_dump(histories, f)


def set_lr(optimizer, lr):
    for group in optimizer.param_groups:
        group['lr'] = lr

def get_lr(optimizer):
    for group in optimizer.param_groups:
        return group['lr']


def build_optimizer(params, opt):
    if opt.optim == 'rmsprop':
        return optim.RMSprop(params, opt.learning_rate, opt.optim_alpha, opt.optim_epsilon, weight_decay=opt.weight_decay)
    elif opt.optim == 'adagrad':
        return optim.Adagrad(params, opt.learning_rate, weight_decay=opt.weight_decay)
    elif opt.optim == 'sgd':
        return optim.SGD(params, opt.learning_rate, weight_decay=opt.weight_decay)
    elif opt.optim == 'sgdm':
        return optim.SGD(params, opt.learning_rate, opt.optim_alpha, weight_decay=opt.weight_decay)
    elif opt.optim == 'sgdmom':
        return optim.SGD(params, opt.learning_rate, opt.optim_alpha, weight_decay=opt.weight_decay, nesterov=True)
    elif opt.optim == 'adam':
        return optim.Adam(params, opt.learning_rate, (opt.optim_alpha, opt.optim_beta), opt.optim_epsilon, weight_decay=opt.weight_decay)
    elif opt.optim == 'adamw':
        return optim.AdamW(params, opt.learning_rate, (opt.optim_alpha, opt.optim_beta), opt.optim_epsilon, weight_decay=opt.weight_decay)
    else:
        raise Exception("bad option opt.optim: {}".format(opt.optim))


def penalty_builder(penalty_config):
    if penalty_config == '':
        return lambda x,y: y
    pen_type, alpha = penalty_config.split('_')
    alpha = float(alpha)
    if pen_type == 'wu':
        return lambda x,y: length_wu(x,y,alpha)
    if pen_type == 'avg':
        return lambda x,y: length_average(x,y,alpha)

def length_wu(length, logprobs, alpha=0.):
    """
    NMT length re-ranking score from
    "Google's Neural Machine Translation System" :cite:`wu2016google`.
    """

    modifier = (((5 + length) ** alpha) /
                ((5 + 1) ** alpha))
    return (logprobs / modifier)

def length_average(length, logprobs, alpha=0.):
    """
    Returns the average probability of tokens in a sequence.
    """
    return logprobs / length


class NoamOpt(object):
    "Optim wrapper that implements rate."
    def __init__(self, model_size, factor, warmup, optimizer):
        self.optimizer = optimizer
        self._step = 0
        self.warmup = warmup
        self.factor = factor
        self.model_size = model_size
        self._rate = 0

    def step(self):
        "Update parameters and rate"
        self._step += 1
        rate = self.rate()
        for p in self.optimizer.param_groups:
            p['lr'] = rate
        self._rate = rate
        self.optimizer.step()

    def rate(self, step = None):
        "Implement `lrate` above"
        if step is None:
            step = self._step
        return self.factor * \
            (self.model_size ** (-0.5) *
            min(step ** (-0.5), step * self.warmup ** (-1.5)))

    def __getattr__(self, name):
        return getattr(self.optimizer, name)

    def state_dict(self):
        state_dict = self.optimizer.state_dict()
        state_dict['_step'] = self._step
        return state_dict

    def load_state_dict(self, state_dict):
        if '_step' in state_dict:
            self._step = state_dict['_step']
            del state_dict['_step']
        self.optimizer.load_state_dict(state_dict)

class ReduceLROnPlateau(object):
    "Optim wrapper that implements rate."
    def __init__(self, optimizer, mode='min', factor=0.1, patience=10, verbose=False, threshold=0.0001, threshold_mode='rel', cooldown=0, min_lr=0, eps=1e-08):
        self.scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode, factor, patience, verbose, threshold, threshold_mode, cooldown, min_lr, eps)
        self.optimizer = optimizer
        self.current_lr = get_lr(optimizer)

    def step(self):
        "Update parameters and rate"
        self.optimizer.step()

    def scheduler_step(self, val):
        self.scheduler.step(val)
        self.current_lr = get_lr(self.optimizer)

    def state_dict(self):
        return {'current_lr':self.current_lr,
                'scheduler_state_dict': self.scheduler.state_dict(),
                'optimizer_state_dict': self.optimizer.state_dict()}

    def load_state_dict(self, state_dict):
        if 'current_lr' not in state_dict:
            # it's normal optimizer
            self.optimizer.load_state_dict(state_dict)
            set_lr(self.optimizer, self.current_lr) # use the lr fromt the option
        else:
            # it's a schduler
            self.current_lr = state_dict['current_lr']
            self.scheduler.load_state_dict(state_dict['scheduler_state_dict'])
            self.optimizer.load_state_dict(state_dict['optimizer_state_dict'])
            # current_lr is actually useless in this case

    def rate(self, step = None):
        "Implement `lrate` above"
        if step is None:
            step = self._step
        return self.factor * \
            (self.model_size ** (-0.5) *
            min(step ** (-0.5), step * self.warmup ** (-1.5)))

    def __getattr__(self, name):
        return getattr(self.optimizer, name)

def get_std_opt(model, optim_func='adam', factor=1, warmup=2000):
    # return NoamOpt(model.tgt_embed[0].d_model, 2, 4000,
    #         torch.optim.Adam(model.parameters(), lr=0, betas=(0.9, 0.98), eps=1e-9))
    optim_func = dict(adam=torch.optim.Adam,
                      adamw=torch.optim.AdamW)[optim_func]
    return NoamOpt(model.d_model, factor, warmup,
                   optim_func(model.parameters(), lr=0, betas=(0.9, 0.98), eps=1e-9))
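A minimal sketch of the learning-rate schedule that NoamOpt.rate implements above; the numbers assume model_size=512, factor=1 and warmup=2000 (the defaults get_std_opt uses), and the helper name noam_rate is mine, not part of the repo:

    def noam_rate(step, model_size=512, factor=1.0, warmup=2000):
        # lr = factor * d_model^-0.5 * min(step^-0.5, step * warmup^-1.5)
        return factor * (model_size ** -0.5) * min(step ** -0.5, step * warmup ** -1.5)

    for step in (1, 500, 2000, 8000, 32000):
        print(step, round(noam_rate(step), 6))
    # the rate rises linearly for `warmup` steps, peaks around step == warmup, then decays ~ 1/sqrt(step)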
captioning/utils/opts.py
DELETED
@@ -1,412 +0,0 @@
|
|
1 |
-
from __future__ import print_function
|
2 |
-
import argparse
|
3 |
-
|
4 |
-
|
5 |
-
def if_use_feat(caption_model):
|
6 |
-
# Decide if load attention feature according to caption model
|
7 |
-
if caption_model in ['show_tell', 'all_img', 'fc', 'newfc']:
|
8 |
-
use_att, use_fc = False, True
|
9 |
-
elif caption_model == 'language_model':
|
10 |
-
use_att, use_fc = False, False
|
11 |
-
elif caption_model in ['updown', 'topdown']:
|
12 |
-
use_fc, use_att = True, True
|
13 |
-
else:
|
14 |
-
use_att, use_fc = True, False
|
15 |
-
return use_fc, use_att
|
16 |
-
|
17 |
-
import pprint
|
18 |
-
class Config(object):
|
19 |
-
def __init__(self, **kwargs):
|
20 |
-
"""Configuration Class: set kwargs as class attributes with setattr"""
|
21 |
-
for k, v in kwargs.items():
|
22 |
-
setattr(self, k, v)
|
23 |
-
|
24 |
-
@property
|
25 |
-
def config_str(self):
|
26 |
-
return pprint.pformat(self.__dict__)
|
27 |
-
|
28 |
-
def __repr__(self):
|
29 |
-
"""Pretty-print configurations in alphabetical order"""
|
30 |
-
config_str = 'Configurations\n'
|
31 |
-
config_str += self.config_str
|
32 |
-
return config_str
|
33 |
-
|
34 |
-
|
35 |
-
def parse_opt(parse=True, **optional_kwargs):
|
36 |
-
parser = argparse.ArgumentParser()
|
37 |
-
# Data input settings
|
38 |
-
parser.add_argument('--input_json', type=str, default='data/coco.json',
|
39 |
-
help='path to the json file containing additional info and vocab')
|
40 |
-
parser.add_argument('--input_fc_dir', type=str, default='data/cocotalk_fc',
|
41 |
-
help='path to the directory containing the preprocessed fc feats')
|
42 |
-
parser.add_argument('--input_att_dir', type=str, default='data/cocotalk_att',
|
43 |
-
help='path to the directory containing the preprocessed att feats')
|
44 |
-
parser.add_argument('--input_box_dir', type=str, default='data/cocotalk_box',
|
45 |
-
help='path to the directory containing the boxes of att feats')
|
46 |
-
parser.add_argument('--input_label_h5', type=str, default='data/coco_label.h5',
|
47 |
-
help='path to the h5file containing the preprocessed dataset')
|
48 |
-
parser.add_argument('--data_in_memory', action='store_true',
|
49 |
-
help='True if we want to save the features in memory')
|
50 |
-
parser.add_argument('--start_from', type=str, default=None,
|
51 |
-
help="""continue training from saved model at this path. Path must contain files saved by previous training process:
|
52 |
-
'infos.pkl' : configuration;
|
53 |
-
'model.pth' : weights
|
54 |
-
""")
|
55 |
-
parser.add_argument('--cached_tokens', type=str, default='coco-train-idxs',
|
56 |
-
help='Cached token file for calculating cider score during self critical training.')
|
57 |
-
|
58 |
-
# Model settings
|
59 |
-
parser.add_argument('--caption_model', type=str, default="show_tell",
|
60 |
-
help='show_tell, show_attend_tell, all_img, fc, att2in, att2in2, att2all2, adaatt, adaattmo, updown, stackatt, denseatt, transformer')
|
61 |
-
parser.add_argument('--rnn_size', type=int, default=512,
|
62 |
-
help='size of the rnn in number of hidden nodes in each layer')
|
63 |
-
parser.add_argument('--num_layers', type=int, default=1,
|
64 |
-
help='number of layers in the RNN')
|
65 |
-
parser.add_argument('--rnn_type', type=str, default='lstm',
|
66 |
-
help='rnn, gru, or lstm')
|
67 |
-
parser.add_argument('--input_encoding_size', type=int, default=512,
|
68 |
-
help='the encoding size of each token in the vocabulary, and the image.')
|
69 |
-
parser.add_argument('--att_hid_size', type=int, default=512,
|
70 |
-
help='the hidden size of the attention MLP; only useful in show_attend_tell; 0 if not using hidden layer')
|
71 |
-
parser.add_argument('--fc_feat_size', type=int, default=2048,
|
72 |
-
help='2048 for resnet, 4096 for vgg')
|
73 |
-
parser.add_argument('--att_feat_size', type=int, default=2048,
|
74 |
-
help='2048 for resnet, 512 for vgg')
|
75 |
-
parser.add_argument('--logit_layers', type=int, default=1,
|
76 |
-
help='number of layers in the RNN')
|
77 |
-
|
78 |
-
|
79 |
-
parser.add_argument('--use_bn', type=int, default=0,
|
80 |
-
help='If 1, then do batch_normalization first in att_embed, if 2 then do bn both in the beginning and the end of att_embed')
|
81 |
-
|
82 |
-
# feature manipulation
|
83 |
-
parser.add_argument('--norm_att_feat', type=int, default=0,
|
84 |
-
help='If normalize attention features')
|
85 |
-
parser.add_argument('--use_box', type=int, default=0,
|
86 |
-
help='If use box features')
|
87 |
-
parser.add_argument('--norm_box_feat', type=int, default=0,
|
88 |
-
help='If use box, do we normalize box feature')
|
89 |
-
|
90 |
-
# Optimization: General
|
91 |
-
parser.add_argument('--max_epochs', type=int, default=-1,
|
92 |
-
help='number of epochs')
|
93 |
-
parser.add_argument('--batch_size', type=int, default=16,
|
94 |
-
help='minibatch size')
|
95 |
-
parser.add_argument('--grad_clip_mode', type=str, default='value',
|
96 |
-
help='value or norm')
|
97 |
-
parser.add_argument('--grad_clip_value', type=float, default=0.1,
|
98 |
-
help='clip gradients at this value/max_norm, 0 means no clipping')
|
99 |
-
parser.add_argument('--drop_prob_lm', type=float, default=0.5,
|
100 |
-
help='strength of dropout in the Language Model RNN')
|
101 |
-
parser.add_argument('--self_critical_after', type=int, default=-1,
|
102 |
-
help='After what epoch do we start finetuning the CNN? (-1 = disable; never finetune, 0 = finetune from start)')
|
103 |
-
parser.add_argument('--seq_per_img', type=int, default=5,
|
104 |
-
help='number of captions to sample for each image during training. Done for efficiency since CNN forward pass is expensive. E.g. coco has 5 sents/image')
|
105 |
-
|
106 |
-
parser.add_argument('--verbose', type=int, default=0)
|
107 |
-
|
108 |
-
# Sample related
|
109 |
-
add_eval_sample_opts(parser)
|
110 |
-
|
111 |
-
#Optimization: for the Language Model
|
112 |
-
parser.add_argument('--optim', type=str, default='adam',
|
113 |
-
help='what update to use? rmsprop|sgd|sgdmom|adagrad|adam|adamw')
|
114 |
-
parser.add_argument('--learning_rate', type=float, default=4e-4,
|
115 |
-
help='learning rate')
|
116 |
-
parser.add_argument('--learning_rate_decay_start', type=int, default=-1,
|
117 |
-
help='at what iteration to start decaying learning rate? (-1 = dont) (in epoch)')
|
118 |
-
parser.add_argument('--learning_rate_decay_every', type=int, default=3,
|
119 |
-
help='every how many iterations thereafter to drop LR?(in epoch)')
|
120 |
-
parser.add_argument('--learning_rate_decay_rate', type=float, default=0.8,
|
121 |
-
help='every how many iterations thereafter to drop LR?(in epoch)')
|
122 |
-
parser.add_argument('--optim_alpha', type=float, default=0.9,
|
123 |
-
help='alpha for adam')
|
124 |
-
parser.add_argument('--optim_beta', type=float, default=0.999,
|
125 |
-
help='beta used for adam')
|
126 |
-
parser.add_argument('--optim_epsilon', type=float, default=1e-8,
|
127 |
-
help='epsilon that goes into denominator for smoothing')
|
128 |
-
parser.add_argument('--weight_decay', type=float, default=0,
|
129 |
-
help='weight_decay')
|
130 |
-
# Transformer
|
131 |
-
parser.add_argument('--label_smoothing', type=float, default=0,
|
132 |
-
help='')
|
133 |
-
parser.add_argument('--noamopt', action='store_true',
|
134 |
-
help='')
|
135 |
-
parser.add_argument('--noamopt_warmup', type=int, default=2000,
|
136 |
-
help='')
|
137 |
-
parser.add_argument('--noamopt_factor', type=float, default=1,
|
138 |
-
help='')
|
139 |
-
parser.add_argument('--reduce_on_plateau', action='store_true',
|
140 |
-
help='')
|
141 |
-
parser.add_argument('--reduce_on_plateau_factor', type=float, default=0.5,
|
142 |
-
help='')
|
143 |
-
parser.add_argument('--reduce_on_plateau_patience', type=int, default=3,
|
144 |
-
help='')
|
145 |
-
parser.add_argument('--cached_transformer', action='store_true',
|
146 |
-
help='')
|
147 |
-
|
148 |
-
|
149 |
-
parser.add_argument('--use_warmup', action='store_true',
|
150 |
-
help='warm up the learing rate?')
|
151 |
-
|
152 |
-
parser.add_argument('--scheduled_sampling_start', type=int, default=-1,
|
153 |
-
help='at what iteration to start decay gt probability')
|
154 |
-
parser.add_argument('--scheduled_sampling_increase_every', type=int, default=5,
|
155 |
-
help='every how many iterations thereafter to gt probability')
|
156 |
-
parser.add_argument('--scheduled_sampling_increase_prob', type=float, default=0.05,
|
157 |
-
help='How much to update the prob')
|
158 |
-
parser.add_argument('--scheduled_sampling_max_prob', type=float, default=0.25,
|
159 |
-
help='Maximum scheduled sampling prob.')
|
160 |
-
|
161 |
-
|
162 |
-
# Evaluation/Checkpointing
|
163 |
-
parser.add_argument('--val_images_use', type=int, default=3200,
|
164 |
-
help='how many images to use when periodically evaluating the validation loss? (-1 = all)')
|
165 |
-
parser.add_argument('--save_checkpoint_every', type=int, default=2500,
|
166 |
-
help='how often to save a model checkpoint (in iterations)?')
|
167 |
-
parser.add_argument('--save_every_epoch', action='store_true',
|
168 |
-
help='Save checkpoint every epoch, will overwrite save_checkpoint_every')
|
169 |
-
parser.add_argument('--save_history_ckpt', type=int, default=0,
|
170 |
-
help='If save checkpoints at every save point')
|
171 |
-
parser.add_argument('--checkpoint_path', type=str, default=None,
|
172 |
-
help='directory to store checkpointed models')
|
173 |
-
parser.add_argument('--language_eval', type=int, default=0,
|
174 |
-
help='Evaluate language as well (1 = yes, 0 = no)? BLEU/CIDEr/METEOR/ROUGE_L? requires coco-caption code from Github.')
|
175 |
-
parser.add_argument('--losses_log_every', type=int, default=25,
|
176 |
-
help='How often do we snapshot losses, for inclusion in the progress dump? (0 = disable)')
|
177 |
-
parser.add_argument('--load_best_score', type=int, default=1,
|
178 |
-
help='Do we load previous best score when resuming training.')
|
179 |
-
|
180 |
-
# misc
|
181 |
-
parser.add_argument('--id', type=str, default='',
|
182 |
-
help='an id identifying this run/job. used in cross-val and appended when writing progress files')
|
183 |
-
parser.add_argument('--train_only', type=int, default=0,
|
184 |
-
help='if true then use 80k, else use 110k')
|
185 |
-
|
186 |
-
|
187 |
-
# Reward
|
188 |
-
parser.add_argument('--cider_reward_weight', type=float, default=1,
|
189 |
-
help='The reward weight from cider')
|
190 |
-
parser.add_argument('--bleu_reward_weight', type=float, default=0,
|
191 |
-
help='The reward weight from bleu4')
|
192 |
-
|
193 |
-
# Reward
|
194 |
-
parser.add_argument('--clipscore_reward_weight', type=float, default=1,
|
195 |
-
help='The reward weight from clipscore')
|
196 |
-
parser.add_argument('--use_clipscore', type=float, default=0,
|
197 |
-
help='Use CLIPScore')
|
198 |
-
parser.add_argument('--clipscore_mode', type=str, default='clip_s',
|
199 |
-
help='Which CLIPScore to use: clip_s|refclip_s')
|
200 |
-
|
201 |
-
|
202 |
-
# Structure_loss
|
203 |
-
parser.add_argument('--structure_loss_weight', type=float, default=1,
|
204 |
-
help='')
|
205 |
-
parser.add_argument('--structure_after', type=int, default=-1,
|
206 |
-
help='T')
|
207 |
-
parser.add_argument('--structure_loss_type', type=str, default='seqnll',
|
208 |
-
help='')
|
209 |
-
parser.add_argument('--struc_use_logsoftmax', action='store_true', help='')
|
210 |
-
parser.add_argument('--entropy_reward_weight', type=float, default=0,
|
211 |
-
help='Entropy reward, seems very interesting')
|
212 |
-
parser.add_argument('--self_cider_reward_weight', type=float, default=0,
|
213 |
-
help='self cider reward')
|
214 |
-
|
215 |
-
# Used for self critical or structure. Used when sampling is need during training
|
216 |
-
parser.add_argument('--train_sample_n', type=int, default=16,
|
217 |
-
help='The reward weight from cider')
|
218 |
-
parser.add_argument('--train_sample_method', type=str, default='sample',
|
219 |
-
help='')
|
220 |
-
parser.add_argument('--train_beam_size', type=int, default=1,
|
221 |
-
help='')
|
222 |
-
|
223 |
-
# Used for self critical
|
224 |
-
parser.add_argument('--sc_sample_method', type=str, default='greedy',
|
225 |
-
help='')
|
226 |
-
parser.add_argument('--sc_beam_size', type=int, default=1,
|
227 |
-
help='')
|
228 |
-
|
229 |
-
|
230 |
-
# For diversity evaluation during training
|
231 |
-
add_diversity_opts(parser)
|
232 |
-
|
233 |
-
|
234 |
-
# config
|
235 |
-
parser.add_argument('--cfg', type=str, default=None,
|
236 |
-
help='configuration; similar to what is used in detectron')
|
237 |
-
parser.add_argument(
|
238 |
-
'--set_cfgs', dest='set_cfgs',
|
239 |
-
help='Set config keys. Key value sequence seperate by whitespace.'
|
240 |
-
'e.g. [key] [value] [key] [value]\n This has higher priority'
|
241 |
-
'than cfg file but lower than other args. (You can only overwrite'
|
242 |
-
'arguments that have alerady been defined in config file.)',
|
243 |
-
default=[], nargs='+')
|
244 |
-
# How will config be used
|
245 |
-
# 1) read cfg argument, and load the cfg file if it's not None
|
246 |
-
# 2) Overwrite cfg argument with set_cfgs
|
247 |
-
# 3) parse config argument to args.
|
248 |
-
# 4) in the end, parse command line argument and overwrite args
|
249 |
-
|
250 |
-
# step 1: read cfg_fn
|
251 |
-
# args = parser.parse_args()
|
252 |
-
# Parse the arguments.
|
253 |
-
if parse:
|
254 |
-
args = parser.parse_args()
|
255 |
-
# For interative engironmnet (ex. jupyter)
|
256 |
-
else:
|
257 |
-
args = parser.parse_known_args()[0]
|
258 |
-
# print(args)
|
259 |
-
|
260 |
-
# Namespace => Dictionary
|
261 |
-
kwargs = vars(args)
|
262 |
-
# for k, v in optional_kwargs.items():
|
263 |
-
# setattr(args, k, v)
|
264 |
-
kwargs.update(optional_kwargs)
|
265 |
-
|
266 |
-
args = Config(**kwargs)
|
267 |
-
|
268 |
-
|
269 |
-
if args.cfg is not None or args.set_cfgs is not None:
|
270 |
-
from .config import CfgNode
|
271 |
-
if args.cfg is not None:
|
272 |
-
# print('Read Cfg')
|
273 |
-
cn = CfgNode(CfgNode.load_yaml_with_base(args.cfg))
|
274 |
-
# print(cn)
|
275 |
-
else:
|
276 |
-
cn = CfgNode()
|
277 |
-
if args.set_cfgs is not None:
|
278 |
-
cn.merge_from_list(args.set_cfgs)
|
279 |
-
for k,v in cn.items():
|
280 |
-
if not hasattr(args, k):
|
281 |
-
import os
|
282 |
-
if 'LOCAL_RANK' in os.environ and os.environ['LOCAL_RANK'] != '0':
|
283 |
-
pass
|
284 |
-
else:
|
285 |
-
print('Warning: key %s not in args' % k)
|
286 |
-
|
287 |
-
setattr(args, k, v)
|
288 |
-
|
289 |
-
if parse:
|
290 |
-
args = parser.parse_args(namespace=args)
|
291 |
-
else:
|
292 |
-
args = parser.parse_known_args(namespace=args)[0]
|
293 |
-
|
294 |
-
# Check if args are valid
|
295 |
-
assert args.rnn_size > 0, "rnn_size should be greater than 0"
|
296 |
-
assert args.num_layers > 0, "num_layers should be greater than 0"
|
297 |
-
assert args.input_encoding_size > 0, "input_encoding_size should be greater than 0"
|
298 |
-
assert args.batch_size > 0, "batch_size should be greater than 0"
|
299 |
-
assert args.drop_prob_lm >= 0 and args.drop_prob_lm < 1, "drop_prob_lm should be between 0 and 1"
|
300 |
-
assert args.seq_per_img > 0, "seq_per_img should be greater than 0"
|
301 |
-
assert args.beam_size > 0, "beam_size should be greater than 0"
|
302 |
-
assert args.save_checkpoint_every > 0, "save_checkpoint_every should be greater than 0"
|
303 |
-
assert args.losses_log_every > 0, "losses_log_every should be greater than 0"
|
304 |
-
assert args.language_eval == 0 or args.language_eval == 1, "language_eval should be 0 or 1"
|
305 |
-
assert args.load_best_score == 0 or args.load_best_score == 1, "language_eval should be 0 or 1"
|
306 |
-
assert args.train_only == 0 or args.train_only == 1, "language_eval should be 0 or 1"
|
307 |
-
|
308 |
-
# default value for start_from and checkpoint_path
|
309 |
-
args.checkpoint_path = args.checkpoint_path or './log_%s' %args.id
|
310 |
-
args.start_from = args.start_from or args.checkpoint_path
|
311 |
-
|
312 |
-
# Deal with feature things before anything
|
313 |
-
args.use_fc, args.use_att = if_use_feat(args.caption_model)
|
314 |
-
if args.use_box: args.att_feat_size = args.att_feat_size + 5
|
315 |
-
|
316 |
-
return args
|
317 |
-
|
318 |
-
|
319 |
-
def add_eval_options(parser):
|
320 |
-
# Basic options
|
321 |
-
parser.add_argument('--batch_size', type=int, default=0,
|
322 |
-
help='if > 0 then overrule, otherwise load from checkpoint.')
|
323 |
-
parser.add_argument('--num_images', type=int, default=-1,
|
324 |
-
help='how many images to use when periodically evaluating the loss? (-1 = all)')
|
325 |
-
parser.add_argument('--language_eval', type=int, default=0,
|
326 |
-
help='Evaluate language as well (1 = yes, 0 = no)? BLEU/CIDEr/METEOR/ROUGE_L? requires coco-caption code from Github.')
|
327 |
-
parser.add_argument('--dump_images', type=int, default=1,
|
328 |
-
help='Dump images into vis/imgs folder for vis? (1=yes,0=no)')
|
329 |
-
parser.add_argument('--dump_json', type=int, default=1,
|
330 |
-
help='Dump json with predictions into vis folder? (1=yes,0=no)')
|
331 |
-
parser.add_argument('--dump_path', type=int, default=0,
|
332 |
-
help='Write image paths along with predictions into vis json? (1=yes,0=no)')
|
333 |
-
|
334 |
-
# Sampling options
|
335 |
-
add_eval_sample_opts(parser)
|
336 |
-
|
337 |
-
# For evaluation on a folder of images:
|
338 |
-
parser.add_argument('--image_folder', type=str, default='',
|
339 |
-
help='If this is nonempty then will predict on the images in this folder path')
|
340 |
-
parser.add_argument('--image_root', type=str, default='',
|
341 |
-
help='In case the image paths have to be preprended with a root path to an image folder')
|
342 |
-
# For evaluation on MSCOCO images from some split:
|
343 |
-
parser.add_argument('--input_fc_dir', type=str, default='',
|
344 |
-
help='path to the h5file containing the preprocessed dataset')
|
345 |
-
parser.add_argument('--input_att_dir', type=str, default='',
|
346 |
-
help='path to the h5file containing the preprocessed dataset')
|
347 |
-
parser.add_argument('--input_box_dir', type=str, default='',
|
348 |
-
help='path to the h5file containing the preprocessed dataset')
|
349 |
-
parser.add_argument('--input_label_h5', type=str, default='',
|
350 |
-
help='path to the h5file containing the preprocessed dataset')
|
351 |
-
parser.add_argument('--input_json', type=str, default='',
|
352 |
-
help='path to the json file containing additional info and vocab. empty = fetch from model checkpoint.')
|
353 |
-
parser.add_argument('--split', type=str, default='test',
|
354 |
-
help='if running on MSCOCO images, which split to use: val|test|train')
|
355 |
-
parser.add_argument('--coco_json', type=str, default='',
|
356 |
-
help='if nonempty then use this file in DataLoaderRaw (see docs there). Used only in MSCOCO test evaluation, where we have a specific json file of only test set images.')
|
357 |
-
# misc
|
358 |
-
parser.add_argument('--id', type=str, default='',
|
359 |
-
help='an id identifying this run/job. used only if language_eval = 1 for appending to intermediate files')
|
360 |
-
parser.add_argument('--verbose_beam', type=int, default=1,
|
361 |
-
help='if we need to print out all beam search beams.')
|
362 |
-
parser.add_argument('--verbose_loss', type=int, default=0,
|
363 |
-
help='If calculate loss using ground truth during evaluation')
|
364 |
-
|
365 |
-
def add_diversity_opts(parser):
|
366 |
-
parser.add_argument('--sample_n', type=int, default=1,
|
367 |
-
help='Diverse sampling')
|
368 |
-
parser.add_argument('--sample_n_method', type=str, default='sample',
|
369 |
-
help='sample, bs, dbs, gumbel, topk, dgreedy, dsample, dtopk, dtopp')
|
370 |
-
parser.add_argument('--eval_oracle', type=int, default=1,
|
371 |
-
help='if we need to calculate loss.')
|
372 |
-
|
373 |
-
|
374 |
-
# Sampling related options
|
375 |
-
def add_eval_sample_opts(parser):
|
376 |
-
parser.add_argument('--sample_method', type=str, default='greedy',
|
377 |
-
help='greedy; sample; gumbel; top<int>, top<0-1>')
|
378 |
-
parser.add_argument('--beam_size', type=int, default=1,
|
379 |
-
help='used when sample_method = greedy, indicates number of beams in beam search. Usually 2 or 3 works well. More is not better. Set this to 1 for faster runtime but a bit worse performance.')
|
380 |
-
parser.add_argument('--max_length', type=int, default=20,
|
381 |
-
help='Maximum length during sampling')
|
382 |
-
parser.add_argument('--length_penalty', type=str, default='',
|
383 |
-
help='wu_X or avg_X, X is the alpha')
|
384 |
-
parser.add_argument('--group_size', type=int, default=1,
|
385 |
-
help='used for diverse beam search. if group_size is 1, then it\'s normal beam search')
|
386 |
-
parser.add_argument('--diversity_lambda', type=float, default=0.5,
|
387 |
-
help='used for diverse beam search. Usually from 0.2 to 0.8. Higher value of lambda produces a more diverse list')
|
388 |
-
parser.add_argument('--temperature', type=float, default=1.0,
|
389 |
-
help='temperature when sampling from distributions (i.e. when sample_method = sample). Lower = "safer" predictions.')
|
390 |
-
parser.add_argument('--decoding_constraint', type=int, default=0,
|
391 |
-
help='If 1, not allowing same word in a row')
|
392 |
-
parser.add_argument('--block_trigrams', type=int, default=0,
|
393 |
-
help='block repeated trigram.')
|
394 |
-
parser.add_argument('--remove_bad_endings', type=int, default=0,
|
395 |
-
help='Remove bad endings')
|
396 |
-
parser.add_argument('--suppress_UNK', type=int, default=1,
|
397 |
-
help='Not predicting UNK')
|
398 |
-
|
399 |
-
|
400 |
-
if __name__ == '__main__':
|
401 |
-
import sys
|
402 |
-
sys.argv = [sys.argv[0]]
|
403 |
-
args = parse_opt()
|
404 |
-
print(args)
|
405 |
-
print()
|
406 |
-
sys.argv = [sys.argv[0], '--cfg', 'configs/updown_long.yml']
|
407 |
-
args1 = parse_opt()
|
408 |
-
print(dict(set(vars(args1).items()) - set(vars(args).items())))
|
409 |
-
print()
|
410 |
-
sys.argv = [sys.argv[0], '--cfg', 'configs/updown_long.yml', '--caption_model', 'att2in2']
|
411 |
-
args2 = parse_opt()
|
412 |
-
print(dict(set(vars(args2).items()) - set(vars(args1).items())))
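For reference, a short sketch of the layering parse_opt implements above: argparse defaults are filled first, then keys from the YAML passed via --cfg (optionally patched by --set_cfgs) are written onto the namespace, and the command line is re-parsed last so explicit flags win. The cfg path below is one of the YAML files removed in this same commit; treat the snippet as an illustration of the deleted API, not as something runnable in the Space after this change:

    import sys
    import captioning.utils.opts as opts

    sys.argv = [sys.argv[0], '--cfg', 'configs/phase2/clipRN50_cider.yml', '--batch_size', '8']
    opt = opts.parse_opt()          # CLI --batch_size overrides whatever the YAML sets
    print(opt.caption_model, opt.batch_size)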
captioning/utils/resnet.py
DELETED
@@ -1,71 +0,0 @@
import torch
import torch.nn as nn
import torchvision.models.resnet
from torchvision.models.resnet import BasicBlock, Bottleneck

class ResNet(torchvision.models.resnet.ResNet):
    def __init__(self, block, layers, num_classes=1000):
        super(ResNet, self).__init__(block, layers, num_classes)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=0, ceil_mode=True) # change
        for i in range(2, 5):
            getattr(self, 'layer%d'%i)[0].conv1.stride = (2,2)
            getattr(self, 'layer%d'%i)[0].conv2.stride = (1,1)

def resnet18(pretrained=False):
    """Constructs a ResNet-18 model.

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
    """
    model = ResNet(BasicBlock, [2, 2, 2, 2])
    if pretrained:
        model.load_state_dict(model_zoo.load_url(model_urls['resnet18']))
    return model


def resnet34(pretrained=False):
    """Constructs a ResNet-34 model.

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
    """
    model = ResNet(BasicBlock, [3, 4, 6, 3])
    if pretrained:
        model.load_state_dict(model_zoo.load_url(model_urls['resnet34']))
    return model


def resnet50(pretrained=False):
    """Constructs a ResNet-50 model.

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
    """
    model = ResNet(Bottleneck, [3, 4, 6, 3])
    if pretrained:
        model.load_state_dict(model_zoo.load_url(model_urls['resnet50']))
    return model


def resnet101(pretrained=False):
    """Constructs a ResNet-101 model.

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
    """
    model = ResNet(Bottleneck, [3, 4, 23, 3])
    if pretrained:
        model.load_state_dict(model_zoo.load_url(model_urls['resnet101']))
    return model


def resnet152(pretrained=False):
    """Constructs a ResNet-152 model.

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
    """
    model = ResNet(Bottleneck, [3, 8, 36, 3])
    if pretrained:
        model.load_state_dict(model_zoo.load_url(model_urls['resnet152']))
    return model
captioning/utils/resnet_utils.py
DELETED
@@ -1,27 +0,0 @@
import torch
import torch.nn as nn
import torch.nn.functional as F

class myResnet(nn.Module):
    def __init__(self, resnet):
        super(myResnet, self).__init__()
        self.resnet = resnet

    def forward(self, img, att_size=14):
        x = img.unsqueeze(0)

        x = self.resnet.conv1(x)
        x = self.resnet.bn1(x)
        x = self.resnet.relu(x)
        x = self.resnet.maxpool(x)

        x = self.resnet.layer1(x)
        x = self.resnet.layer2(x)
        x = self.resnet.layer3(x)
        x = self.resnet.layer4(x)

        fc = x.mean(3).mean(2).squeeze()
        att = F.adaptive_avg_pool2d(x,[att_size,att_size]).squeeze().permute(1, 2, 0)

        return fc, att
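A small self-contained sketch of what the deleted myResnet wrapper produces, using a stock torchvision backbone instead of the patched ResNet above (the point is the two output shapes, not the exact values; the weights argument is the newer torchvision spelling, older versions use pretrained=False):

    import torch
    import torch.nn.functional as F
    import torchvision

    cnn = torchvision.models.resnet101(weights=None)
    body = torch.nn.Sequential(*list(cnn.children())[:-2])   # keep the conv stages, drop avgpool + fc

    img = torch.randn(3, 224, 224)                            # one CHW image tensor
    with torch.no_grad():
        x = body(img.unsqueeze(0))                            # [1, 2048, 7, 7]
        fc = x.mean(3).mean(2).squeeze()                      # [2048] global average-pooled feature
        att = F.adaptive_avg_pool2d(x, [14, 14]).squeeze().permute(1, 2, 0)   # [14, 14, 2048] spatial grid
    print(fc.shape, att.shape)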
captioning/utils/rewards.py
DELETED
@@ -1,392 +0,0 @@
|
|
1 |
-
from __future__ import absolute_import
|
2 |
-
from __future__ import division
|
3 |
-
from __future__ import print_function
|
4 |
-
|
5 |
-
import numpy as np
|
6 |
-
import time
|
7 |
-
from collections import OrderedDict
|
8 |
-
import torch
|
9 |
-
|
10 |
-
import sys
|
11 |
-
try:
|
12 |
-
sys.path.append("cider")
|
13 |
-
from pyciderevalcap.ciderD.ciderD import CiderD
|
14 |
-
from pyciderevalcap.cider.cider import Cider
|
15 |
-
sys.path.append("coco-caption")
|
16 |
-
from pycocoevalcap.bleu.bleu import Bleu
|
17 |
-
except:
|
18 |
-
print('cider or coco-caption missing')
|
19 |
-
|
20 |
-
CiderD_scorer = None
|
21 |
-
Cider_scorer = None
|
22 |
-
Bleu_scorer = None
|
23 |
-
#CiderD_scorer = CiderD(df='corpus')
|
24 |
-
|
25 |
-
|
26 |
-
from .misc import decode_sequence
|
27 |
-
|
28 |
-
def init_scorer(cached_tokens):
|
29 |
-
global CiderD_scorer
|
30 |
-
CiderD_scorer = CiderD_scorer or CiderD(df=cached_tokens)
|
31 |
-
global Cider_scorer
|
32 |
-
Cider_scorer = Cider_scorer or Cider(df=cached_tokens)
|
33 |
-
global Bleu_scorer
|
34 |
-
Bleu_scorer = Bleu_scorer or Bleu(4)
|
35 |
-
|
36 |
-
def array_to_str(arr):
|
37 |
-
out = ''
|
38 |
-
for i in range(len(arr)):
|
39 |
-
out += str(arr[i]) + ' '
|
40 |
-
if arr[i] == 0:
|
41 |
-
break
|
42 |
-
return out.strip()
|
43 |
-
|
44 |
-
def get_self_critical_reward(greedy_res, data_gts, gen_result, opt):
|
45 |
-
batch_size = len(data_gts)
|
46 |
-
(Remainder of the deleted captioning/utils/rewards.py, original lines 47-392.) The rest of the CIDEr/BLEU self-critical reward helper converts the sampled and greedy captions to strings, scores them with the CiderD and Bleu scorers, mixes the two with cider_reward_weight and bleu_reward_weight, subtracts each image's greedy (baseline) score from its sampled scores, and repeats the resulting advantage over every timestep before returning it together with the unnormalized mean reward. get_self_critical_clipscore_reward applies the same baseline-subtraction scheme to CLIP-S or RefCLIP-S computed by a CLIPScore model on the decoded captions and precomputed visual features (padding each image's references to a common count for RefCLIP-S, and optionally adding a grammar-probability reward from a classification head on CLIP text features, with or without its own greedy baseline). get_scores returns the weighted CIDEr/BLEU mixture for a batch of samples, and get_self_cider_scores converts each image's self-CIDEr similarity matrix into a diversity score derived from its eigenvalue spectrum. Large portions of the original file were commented-out variants of the same computation.
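For readers reconstructing the Space, here is a minimal sketch of the baseline-subtraction reward shaping that the deleted reward code performs. The function name and the toy scores are illustrative, not part of the repository; only NumPy is assumed, and the shapes follow the code above (batch_size images, seq_per_img samples each, captions of length seq_len).

import numpy as np

def self_critical_rewards(sample_scores, greedy_scores, seq_len):
    """Sketch of SCST reward shaping: subtract each image's greedy score
    from its sampled scores, then broadcast the advantage over timesteps.

    sample_scores: (batch_size * seq_per_img,) scores of sampled captions
    greedy_scores: (batch_size,) scores of greedy (baseline) captions
    returns: (batch_size * seq_per_img, seq_len) per-token rewards
    """
    batch_size = greedy_scores.shape[0]
    seq_per_img = sample_scores.shape[0] // batch_size
    # (batch_size, seq_per_img) - (batch_size, 1) -> advantage per sampled caption
    adv = sample_scores.reshape(batch_size, seq_per_img) - greedy_scores[:, np.newaxis]
    adv = adv.reshape(batch_size * seq_per_img)
    # repeat the scalar advantage across every timestep of the caption
    return np.repeat(adv[:, np.newaxis], seq_len, axis=1)

# toy example: 2 images, 3 samples each, captions of length 4
rewards = self_critical_rewards(
    np.array([0.9, 0.7, 0.5, 0.3, 0.6, 0.8]), np.array([0.6, 0.5]), seq_len=4)
print(rewards.shape)  # (6, 4)

Subtracting the greedy caption's score turns the raw metric into an advantage, so sampled captions are only rewarded for beating the model's own deterministic output.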
captioning/utils/utils.py
DELETED (138 lines): bounding-box geometry helpers (get_area, get_relative_distance, get_iou, xywh_to_xyxy), a LossMeter running-average tracker, count_parameters, a load_state_dict wrapper that strips "module." prefixes from multi-GPU checkpoints, and set_global_logging_level for silencing third-party loggers.
clip/__init__.py
DELETED (1 line):
-from .clip import *
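The removed clip package simply re-exported the loader and tokenizer defined in clip/clip.py below. As a reference, a minimal usage sketch assuming the API shown in the deleted files (clip.load, clip.tokenize, encode_text); the prompts are illustrative, and since this fork's encode_image returns spatial plus pooled features rather than a single embedding, only the text side is shown.

import torch
import clip  # the package removed in this commit

# load() fetches the checkpoint (with SHA256 verification) and returns (model, preprocess)
model, preprocess = clip.load("RN50", device="cpu", jit=False)

# tokenize() wraps each prompt in <|startoftext|>/<|endoftext|> and pads to 77 tokens
tokens = clip.tokenize(["a dog playing in the snow", "a plate of food"])

with torch.no_grad():
    text_feat = model.encode_text(tokens)                          # (2, embed_dim)
    text_feat = text_feat / text_feat.norm(dim=-1, keepdim=True)   # unit-normalize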
clip/bpe_simple_vocab_16e6.txt.gz
DELETED (Git LFS pointer, 3 lines):
-version https://git-lfs.github.com/spec/v1
-oid sha256:924691ac288e54409236115652ad4aa250f48203de50a9e4722a6ecd48d6804a
-size 1356917
clip/clip.py
DELETED (193 lines): the CLIP front end: the _MODELS download table (RN50, RN101, RN50x4, ViT-B/32), _download with SHA256 verification, the _transform preprocessing pipeline (bicubic resize, center crop, RGB conversion, tensor conversion, normalization), available_models(), load(), which accepts either a JIT archive or a plain state dict and patches device and dtype nodes when running on CPU, and tokenize(), which wraps each string in <|startoftext|>/<|endoftext|> tokens and pads it into a 77-token context.
clip/model.py
DELETED (437 lines): the CLIP network definition: Bottleneck and AttentionPool2d blocks; ModifiedResNet with a three-convolution stem, anti-aliased strided convolutions and attention pooling (modified in this fork to return both the spatial grid and the pooled feature); fp16-safe LayerNorm, QuickGELU, ResidualAttentionBlock, Transformer and VisualTransformer (modified to return per-token features before the output projection); the CLIP module itself with encode_image, encode_text, a learned logit scale and its weight initialization; convert_weights for fp16 conversion; and build_model, which reconstructs the architecture from a checkpoint's state dict.
clip/simple_tokenizer.py
DELETED (132 lines): the byte-pair-encoding tokenizer: bytes_to_unicode and get_pairs helpers, ftfy/HTML/whitespace text cleanup, and the SimpleTokenizer class whose bpe(), encode() and decode() methods are built from the gzipped merges file above.
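For reference, a minimal round-trip sketch of the removed tokenizer, assuming the gzipped BPE vocabulary above sits next to the module as in the original layout; the example sentence is illustrative.

from clip.simple_tokenizer import SimpleTokenizer

tok = SimpleTokenizer()                    # loads bpe_simple_vocab_16e6.txt.gz by default
ids = tok.encode("A man riding a horse.")  # cleans, lower-cases, then BPE-encodes to int ids
text = tok.decode(ids)                     # roughly 'a man riding a horse . ' (</w> becomes a space)
print(ids, text)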
configs/phase1/FineCapEval_clipRN50_mle.yml
DELETED (60 lines): evaluation config for the phase-1 MLE captioner on FineCapEval: input_json data/FineCapEval.json with CLIP-RN50 fc/att and clipscore_vis features, seq_per_img 5, batch_size 200, a 6-layer encoder/decoder transformer (d_model 512, d_ff 2048, 8 attention heads, dropout 0.1), checkpoint_path ./save/clipRN50_mle/clipRN50_mle, noamopt and reduce_on_plateau disabled with a fixed learning rate of 5e-6, self_critical_after 15, max_epochs 50, precision 32, use_clipscore false.
configs/phase1/clipRN50_mle.yml
DELETED
@@ -1,52 +0,0 @@
-caption_model: transformer
-noamopt: true
-# noamopt: false
-noamopt_warmup: 20000
-label_smoothing: 0.0
-input_json: data/cocotalk.json
-input_label_h5: data/cocotalk_label.h5
-input_fc_dir: data/cocotalk_clip_RN50_fc
-input_att_dir: data/cocotalk_clip_RN50_att
-input_clipscore_vis_dir: data/cocotalk_clipscore_vis
-seq_per_img: 5
-# batch_size: 600
-batch_size: 200
-
-learning_rate: 0.0005
-
-# checkpoint_path: ./save/trans_clip_rn50_sc_pl
-checkpoint_path: save/clipRN50_mle/clipRN50_mle
-
-# Notice: because I'm to lazy, I reuse the option name for RNNs to set the hyperparameters for transformer:
-# N=num_layers
-# d_model=input_encoding_size
-# d_ff=rnn_size
-
-# will be ignored
-num_layers: 6
-input_encoding_size: 512
-rnn_size: 2048
-
-# Transformer config
-N_enc: 6
-N_dec: 6
-d_model: 512
-d_ff: 2048
-num_att_heads: 8
-dropout: 0.1
-
-
-learning_rate_decay_start: 0
-scheduled_sampling_start: -1
-save_checkpoint_every: 3000
-language_eval: 1
-val_images_use: 5000
-# max_epochs: 15
-max_epochs: 25
-train_sample_n: 5
-
-REFORWARD: false
-
-
-verbose: false
-precision: 16

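clipRN50_mle.yml trains with the Noam schedule (noamopt: true, noamopt_warmup: 20000) rather than a fixed learning rate. A worked sketch of that schedule from Vaswani et al. (2017), using d_model: 512 and the warmup value above; the scaling factor and the optimizer wrapper the repo actually uses are not shown here.

def noam_lr(step, d_model=512, warmup=20000, factor=1.0):
    # lr(step) = factor * d_model**-0.5 * min(step**-0.5, step * warmup**-1.5)
    step = max(step, 1)  # guard against step 0
    return factor * d_model ** -0.5 * min(step ** -0.5, step * warmup ** -1.5)

# Linear warmup for the first 20k steps, then ~1/sqrt(step) decay:
# noam_lr(1)     -> ~1.6e-8
# noam_lr(20000) -> ~3.1e-4   (peak)
# noam_lr(80000) -> ~1.6e-4
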
configs/phase1/transformer.yml
DELETED
@@ -1,41 +0,0 @@
-caption_model: transformer
-noamopt: true
-noamopt_warmup: 20000
-label_smoothing: 0.0
-input_json: data/cocotalk.json
-input_label_h5: data/cocotalk_label.h5
-input_att_dir: data/cocotalk_att
-seq_per_img: 5
-batch_size: 10
-learning_rate: 0.0005
-
-checkpoint_path: ./save/trans_rn50_sc
-
-# Notice: because I'm to lazy, I reuse the option name for RNNs to set the hyperparameters for transformer:
-# N=num_layers
-# d_model=input_encoding_size
-# d_ff=rnn_size
-
-# will be ignored
-num_layers: 6
-input_encoding_size: 512
-rnn_size: 2048
-
-# Transformer config
-N_enc: 6
-N_dec: 6
-d_model: 512
-d_ff: 2048
-num_att_heads: 8
-dropout: 0.1
-
-
-learning_rate_decay_start: 0
-scheduled_sampling_start: -1
-save_checkpoint_every: 3000
-language_eval: 1
-val_images_use: 5000
-max_epochs: 15
-train_sample_n: 5
-
-REFORWARD: false

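The comment block repeated in every config ("N=num_layers", "d_model=input_encoding_size", "d_ff=rnn_size") means the RNN option names double as Transformer hyperparameters. A small sketch of that translation, assuming the explicit N_enc/N_dec/d_model/d_ff keys take precedence when both are present.

def transformer_hparams(opt):
    # Map the reused RNN option names onto the Transformer hyperparameters
    # they stand for; explicit keys win when present (an assumption).
    return {
        'N_enc':   getattr(opt, 'N_enc',   opt.num_layers),
        'N_dec':   getattr(opt, 'N_dec',   opt.num_layers),
        'd_model': getattr(opt, 'd_model', opt.input_encoding_size),
        'd_ff':    getattr(opt, 'd_ff',    opt.rnn_size),
        'h':       getattr(opt, 'num_att_heads', 8),
        'dropout': getattr(opt, 'dropout', 0.1),
    }
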
configs/phase2/FineCapEval_clipRN50_cider.yml
DELETED
@@ -1,61 +0,0 @@
-caption_model: transformer
-noamopt: true
-noamopt_warmup: 20000
-label_smoothing: 0.0
-input_json: data/FineCapEval.json
-input_label_h5: none
-input_fc_dir: data/FineCapEval_clip_RN50_fc
-input_att_dir: data/FineCapEval_clip_RN50_att
-input_clipscore_vis_dir: data/FineCapEval_clipscore_vis
-
-seq_per_img: 5
-batch_size: 200
-learning_rate: 0.0005
-
-checkpoint_path: ./save/clipRN50_cider/clipRN50_cider
-
-# clip_load_path: '/scratch-space/retrieval/save/clip_negative_text/clip_negative_text-epoch=10.ckpt'
-
-# Notice: because I'm to lazy, I reuse the option name for RNNs to set the hyperparameters for transformer:
-# N=num_layers
-# d_model=input_encoding_size
-# d_ff=rnn_size
-
-# will be ignored
-num_layers: 6
-input_encoding_size: 512
-rnn_size: 2048
-
-# Transformer config
-N_enc: 6
-N_dec: 6
-d_model: 512
-d_ff: 2048
-num_att_heads: 8
-dropout: 0.1
-
-
-learning_rate_decay_start: 0
-scheduled_sampling_start: -1
-save_checkpoint_every: 3000
-language_eval: 1
-val_images_use: 5000
-max_epochs: 15
-train_sample_n: 5
-
-REFORWARD: false
-
-# _BASE_: transformer.yml
-reduce_on_plateau: false
-noamopt: false
-learning_rate: 0.000005
-learning_rate_decay_start: -1
-
-self_critical_after: 15
-max_epochs: 50
-
-verbose: false
-precision: 32
-
-# use_clipscore: true
-use_clipscore: false

configs/phase2/FineCapEval_clipRN50_cider_clips.yml
DELETED
@@ -1,65 +0,0 @@
-caption_model: transformer
-noamopt: true
-noamopt_warmup: 20000
-label_smoothing: 0.0
-input_json: data/FineCapEval.json
-input_label_h5: none
-input_fc_dir: data/FineCapEval_clip_RN50_fc
-input_att_dir: data/FineCapEval_clip_RN50_att
-input_clipscore_vis_dir: data/FineCapEval_clipscore_vis
-
-seq_per_img: 5
-batch_size: 200
-learning_rate: 0.0005
-
-checkpoint_path: ./save/clipRN50_cider_clips/clipRN50_cider_clips
-
-# clip_load_path: '/scratch-space/retrieval/save/clip_negative_text/clip_negative_text-epoch=10.ckpt'
-
-# Notice: because I'm to lazy, I reuse the option name for RNNs to set the hyperparameters for transformer:
-# N=num_layers
-# d_model=input_encoding_size
-# d_ff=rnn_size
-
-# will be ignored
-num_layers: 6
-input_encoding_size: 512
-rnn_size: 2048
-
-# Transformer config
-N_enc: 6
-N_dec: 6
-d_model: 512
-d_ff: 2048
-num_att_heads: 8
-dropout: 0.1
-
-
-learning_rate_decay_start: 0
-scheduled_sampling_start: -1
-save_checkpoint_every: 3000
-language_eval: 1
-val_images_use: 5000
-max_epochs: 15
-train_sample_n: 5
-
-REFORWARD: false
-
-# _BASE_: transformer.yml
-reduce_on_plateau: false
-noamopt: false
-learning_rate: 0.000005
-learning_rate_decay_start: -1
-
-self_critical_after: 15
-max_epochs: 50
-
-verbose: false
-precision: 32
-
-# use_clipscore: true
-use_clipscore: false
-clipscore_reward_weight: 2.0
-clipscore_mode: clip_s
-
-use_multi_rewards: true

configs/phase2/FineCapEval_clipRN50_clips.yml
DELETED
@@ -1,64 +0,0 @@
-caption_model: transformer
-noamopt: true
-noamopt_warmup: 20000
-label_smoothing: 0.0
-input_json: data/FineCapEval.json
-input_label_h5: none
-input_fc_dir: data/FineCapEval_clip_RN50_fc
-input_att_dir: data/FineCapEval_clip_RN50_att
-input_clipscore_vis_dir: data/FineCapEval_clipscore_vis
-seq_per_img: 5
-batch_size: 160
-learning_rate: 0.0005
-
-checkpoint_path: ./save/clipRN50_clips/clipRN50_clips
-
-use_multi_rewards: false
-use_grammar: false
-use_grammar_baseline: false
-# clip_load_path: '/scratch-space/retrieval/save/clip_negative_text/clip_negative_text-epoch=10.ckpt'
-
-# Notice: because I'm to lazy, I reuse the option name for RNNs to set the hyperparameters for transformer:
-# N=num_layers
-# d_model=input_encoding_size
-# d_ff=rnn_size
-
-# will be ignored
-num_layers: 6
-input_encoding_size: 512
-rnn_size: 2048
-
-# Transformer config
-N_enc: 6
-N_dec: 6
-d_model: 512
-d_ff: 2048
-num_att_heads: 8
-dropout: 0.1
-
-
-learning_rate_decay_start: 0
-scheduled_sampling_start: -1
-save_checkpoint_every: 3000
-language_eval: 0
-val_images_use: 5000
-max_epochs: 15
-train_sample_n: 5
-
-REFORWARD: false
-
-# _BASE_: transformer.yml
-reduce_on_plateau: false
-noamopt: false
-learning_rate: 0.000005
-learning_rate_decay_start: -1
-
-self_critical_after: 15
-max_epochs: 50
-
-verbose: false
-precision: 32
-
-# use_clipscore: true
-use_clipscore: false
-clipscore_reward_weight: 2.0

configs/phase2/FineCapEval_clipRN50_clips_grammar.yml
DELETED
@@ -1,64 +0,0 @@
-caption_model: transformer
-noamopt: true
-noamopt_warmup: 20000
-label_smoothing: 0.0
-input_json: data/FineCapEval.json
-input_label_h5: none
-input_fc_dir: data/FineCapEval_clip_RN50_fc
-input_att_dir: data/FineCapEval_clip_RN50_att
-input_clipscore_vis_dir: data/FineCapEval_clipscore_vis
-seq_per_img: 5
-batch_size: 160
-learning_rate: 0.0005
-
-checkpoint_path: ./save/clipRN50_clips_grammar/clipRN50_clips_grammar
-
-use_multi_rewards: true
-use_grammar: true
-use_grammar_baseline: true
-# clip_load_path: '/scratch-space/retrieval/save/clip_negative_text/clip_negative_text-epoch=10.ckpt'
-
-# Notice: because I'm to lazy, I reuse the option name for RNNs to set the hyperparameters for transformer:
-# N=num_layers
-# d_model=input_encoding_size
-# d_ff=rnn_size
-
-# will be ignored
-num_layers: 6
-input_encoding_size: 512
-rnn_size: 2048
-
-# Transformer config
-N_enc: 6
-N_dec: 6
-d_model: 512
-d_ff: 2048
-num_att_heads: 8
-dropout: 0.1
-
-
-learning_rate_decay_start: 0
-scheduled_sampling_start: -1
-save_checkpoint_every: 3000
-language_eval: 0
-val_images_use: 5000
-max_epochs: 15
-train_sample_n: 5
-
-REFORWARD: false
-
-# _BASE_: transformer.yml
-reduce_on_plateau: false
-noamopt: false
-learning_rate: 0.000005
-learning_rate_decay_start: -1
-
-self_critical_after: 15
-max_epochs: 50
-
-verbose: false
-precision: 32
-
-# use_clipscore: true
-use_clipscore: false
-clipscore_reward_weight: 2.0

configs/phase2/clipRN50_cider.yml
DELETED
@@ -1,58 +0,0 @@
-caption_model: transformer
-noamopt: true
-noamopt_warmup: 20000
-label_smoothing: 0.0
-input_json: data/cocotalk.json
-input_label_h5: data/cocotalk_label.h5
-input_fc_dir: data/cocotalk_clip_RN50_fc
-input_att_dir: data/cocotalk_clip_RN50_att
-# used only for evaluation
-input_clipscore_vis_dir: data/cocotalk_clipscore_vis
-
-seq_per_img: 5
-batch_size: 200
-learning_rate: 0.0005
-
-# checkpoint_path: ./save/trans_clip_rn50_sc_pl_scst_cider
-checkpoint_path: save/clipRN50_cider/clipRN50_cider
-
-# Notice: because I'm to lazy, I reuse the option name for RNNs to set the hyperparameters for transformer:
-# N=num_layers
-# d_model=input_encoding_size
-# d_ff=rnn_size
-
-# will be ignored
-num_layers: 6
-input_encoding_size: 512
-rnn_size: 2048
-
-# Transformer config
-N_enc: 6
-N_dec: 6
-d_model: 512
-d_ff: 2048
-num_att_heads: 8
-dropout: 0.1
-
-
-learning_rate_decay_start: 0
-scheduled_sampling_start: -1
-save_checkpoint_every: 3000
-language_eval: 1
-val_images_use: 5000
-max_epochs: 15
-train_sample_n: 5
-
-REFORWARD: false
-
-# _BASE_: transformer.yml
-reduce_on_plateau: false
-noamopt: false
-learning_rate: 0.000005
-learning_rate_decay_start: -1
-
-self_critical_after: 15
-max_epochs: 40
-
-verbose: false
-precision: 32

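clipRN50_cider.yml switches to self-critical sequence training after epoch 15 (self_critical_after: 15) with CIDEr as the reward. The core idea (Rennie et al., 2017) is that each sampled caption's reward is baselined by the reward of the greedy caption for the same image; a minimal sketch follows, with the reward values themselves assumed to be precomputed rather than taken from the repo's scorer.

import numpy as np

def self_critical_advantage(sample_rewards, greedy_rewards):
    # sample_rewards: (batch, n_samples) CIDEr of sampled captions
    #                 (n_samples corresponds to train_sample_n: 5 above)
    # greedy_rewards: (batch,) CIDEr of the greedy caption, used as baseline
    sample_rewards = np.asarray(sample_rewards, dtype=np.float32)
    greedy_rewards = np.asarray(greedy_rewards, dtype=np.float32)
    return sample_rewards - greedy_rewards[:, None]

# adv = self_critical_advantage(np.random.rand(8, 5), np.random.rand(8))
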
configs/phase2/clipRN50_cider_clips.yml
DELETED
@@ -1,61 +0,0 @@
-caption_model: transformer
-noamopt: true
-noamopt_warmup: 20000
-label_smoothing: 0.0
-input_json: data/cocotalk.json
-input_label_h5: data/cocotalk_label.h5
-input_fc_dir: data/cocotalk_clip_RN50_fc
-input_att_dir: data/cocotalk_clip_RN50_att
-input_clipscore_vis_dir: data/cocotalk_clipscore_vis
-seq_per_img: 5
-batch_size: 160
-learning_rate: 0.0005
-
-checkpoint_path: save/clipRN50_cider_clips/clipRN50_cider_clips
-
-# Notice: because I'm to lazy, I reuse the option name for RNNs to set the hyperparameters for transformer:
-# N=num_layers
-# d_model=input_encoding_size
-# d_ff=rnn_size
-
-# will be ignored
-num_layers: 6
-input_encoding_size: 512
-rnn_size: 2048
-
-# Transformer config
-N_enc: 6
-N_dec: 6
-d_model: 512
-d_ff: 2048
-num_att_heads: 8
-dropout: 0.1
-
-
-learning_rate_decay_start: 0
-scheduled_sampling_start: -1
-save_checkpoint_every: 3000
-language_eval: 1
-val_images_use: 5000
-max_epochs: 15
-train_sample_n: 5
-
-REFORWARD: false
-
-# _BASE_: transformer.yml
-reduce_on_plateau: false
-noamopt: false
-learning_rate: 0.000005
-learning_rate_decay_start: -1
-
-self_critical_after: 15
-max_epochs: 40
-
-verbose: false
-precision: 32
-
-use_clipscore: true
-clipscore_reward_weight: 2.0
-clipscore_mode: clip_s
-
-use_multi_rewards: true

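clipscore_mode: clip_s refers to the reference-free CLIPScore of Hessel et al. (2021), CLIP-S(c, v) = 2.5 * max(cos(c, v), 0); the 2.5 is the paper's rescaling constant and is separate from clipscore_reward_weight above. A sketch from precomputed CLIP-RN50 embeddings; the repo's batched version lives in captioning/utils/clipscore.py and is not reproduced here.

import numpy as np

def clip_s(text_emb, image_emb, w=2.5):
    # Reference-free CLIPScore: w * max(cosine similarity, 0).
    # text_emb, image_emb: (batch, 1024) CLIP-RN50 caption / image embeddings.
    text_emb = text_emb / np.linalg.norm(text_emb, axis=-1, keepdims=True)
    image_emb = image_emb / np.linalg.norm(image_emb, axis=-1, keepdims=True)
    cos = (text_emb * image_emb).sum(-1)
    return w * np.clip(cos, 0.0, None)
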
configs/phase2/clipRN50_clips.yml
DELETED
@@ -1,58 +0,0 @@
-caption_model: transformer
-noamopt: true
-noamopt_warmup: 20000
-label_smoothing: 0.0
-input_json: data/cocotalk.json
-input_label_h5: data/cocotalk_label.h5
-input_fc_dir: data/cocotalk_clip_RN50_fc
-input_att_dir: data/cocotalk_clip_RN50_att
-input_clipscore_vis_dir: data/cocotalk_clipscore_vis
-seq_per_img: 5
-batch_size: 160
-learning_rate: 0.0005
-
-checkpoint_path: save/clipRN50_clips/clipRN50_clips
-
-# Notice: because I'm to lazy, I reuse the option name for RNNs to set the hyperparameters for transformer:
-# N=num_layers
-# d_model=input_encoding_size
-# d_ff=rnn_size
-
-# will be ignored
-num_layers: 6
-input_encoding_size: 512
-rnn_size: 2048
-
-# Transformer config
-N_enc: 6
-N_dec: 6
-d_model: 512
-d_ff: 2048
-num_att_heads: 8
-dropout: 0.1
-
-
-learning_rate_decay_start: 0
-scheduled_sampling_start: -1
-save_checkpoint_every: 3000
-language_eval: 1
-val_images_use: 5000
-max_epochs: 15
-train_sample_n: 5
-
-REFORWARD: false
-
-# _BASE_: transformer.yml
-reduce_on_plateau: false
-noamopt: false
-learning_rate: 0.000005
-learning_rate_decay_start: -1
-
-self_critical_after: 15
-max_epochs: 40
-
-verbose: false
-precision: 32
-
-use_clipscore: true
-clipscore_reward_weight: 2.0

configs/phase2/clipRN50_clips_grammar.yml
DELETED
@@ -1,64 +0,0 @@
-caption_model: transformer
-noamopt: true
-noamopt_warmup: 20000
-label_smoothing: 0.0
-input_json: data/cocotalk.json
-input_label_h5: data/cocotalk_label.h5
-input_fc_dir: data/cocotalk_clip_RN50_fc
-input_att_dir: data/cocotalk_clip_RN50_att
-input_clipscore_vis_dir: data/cocotalk_clipscore_vis
-seq_per_img: 5
-batch_size: 160
-learning_rate: 0.0005
-
-checkpoint_path: save/clipRN50_clips_grammar/clipRN50_clips_grammar
-
-use_multi_rewards: true
-use_grammar: true
-use_grammar_baseline: true
-# clip_load_path: '/scratch-space/retrieval/save/clip_negative_text/clip_negative_text-epoch=10.ckpt'
-clip_load_path: 'retrieval/save/clip_negative_text/clip_negative_text-epoch=12.ckpt'
-
-# Notice: because I'm to lazy, I reuse the option name for RNNs to set the hyperparameters for transformer:
-# N=num_layers
-# d_model=input_encoding_size
-# d_ff=rnn_size
-
-# will be ignored
-num_layers: 6
-input_encoding_size: 512
-rnn_size: 2048
-
-# Transformer config
-N_enc: 6
-N_dec: 6
-d_model: 512
-d_ff: 2048
-num_att_heads: 8
-dropout: 0.1
-
-
-learning_rate_decay_start: 0
-scheduled_sampling_start: -1
-save_checkpoint_every: 3000
-language_eval: 1
-val_images_use: 5000
-max_epochs: 15
-train_sample_n: 5
-
-REFORWARD: false
-
-# _BASE_: transformer.yml
-reduce_on_plateau: false
-noamopt: false
-learning_rate: 0.000005
-learning_rate_decay_start: -1
-
-self_critical_after: 15
-max_epochs: 40
-
-verbose: false
-precision: 32
-
-use_clipscore: true
-clipscore_reward_weight: 2.0

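clipRN50_clips_grammar.yml turns on use_multi_rewards, use_grammar, and use_grammar_baseline, and points clip_load_path at a CLIP checkpoint fine-tuned with negative (perturbed) text. As a hedged sketch only, one way to read those flags is that each reward stream (CLIP-S, grammar score) gets its own baseline before the weighted streams are summed; the actual combination in captioning/utils/rewards.py may differ.

import numpy as np

def multi_reward_advantage(rewards, baselines, weights=None):
    # rewards / baselines: dicts of equally shaped arrays per reward stream,
    # e.g. {'clip_s': ..., 'grammar': ...}; weights e.g. {'clip_s': 2.0}.
    # Each stream is baselined separately, then weighted and summed.
    weights = weights or {}
    total = 0.0
    for name, r in rewards.items():
        r = np.asarray(r, dtype=np.float32)
        b = np.asarray(baselines[name], dtype=np.float32)
        total = total + weights.get(name, 1.0) * (r - b)
    return total
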
configs/phase2/transformer.yml
DELETED
@@ -1,41 +0,0 @@
-caption_model: transformer
-noamopt: true
-noamopt_warmup: 20000
-label_smoothing: 0.0
-input_json: data/cocotalk.json
-input_label_h5: data/cocotalk_label.h5
-input_att_dir: data/cocotalk_att
-seq_per_img: 5
-batch_size: 10
-learning_rate: 0.0005
-
-checkpoint_path: ./save/trans_rn50_sc
-
-# Notice: because I'm to lazy, I reuse the option name for RNNs to set the hyperparameters for transformer:
-# N=num_layers
-# d_model=input_encoding_size
-# d_ff=rnn_size
-
-# will be ignored
-num_layers: 6
-input_encoding_size: 512
-rnn_size: 2048
-
-# Transformer config
-N_enc: 6
-N_dec: 6
-d_model: 512
-d_ff: 2048
-num_att_heads: 8
-dropout: 0.1
-
-
-learning_rate_decay_start: 0
-scheduled_sampling_start: -1
-save_checkpoint_every: 3000
-language_eval: 1
-val_images_use: 5000
-max_epochs: 15
-train_sample_n: 5
-
-REFORWARD: false

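Several of the deleted phase-2 configs carry a commented-out '# _BASE_: transformer.yml', i.e. they were written to inherit this file's defaults and override only what changes. A sketch of resolving such a _BASE_ key recursively with PyYAML; the key name comes from the configs above, while this particular resolution logic is an assumption rather than the repo's captioning/utils/config.py.

import os
import yaml

def load_with_base(path):
    # Load a YAML config; if it names a _BASE_ file, load that first and let
    # the child's keys override the inherited ones.
    with open(path) as f:
        cfg = yaml.safe_load(f) or {}
    base = cfg.pop('_BASE_', None)
    if base:
        base_cfg = load_with_base(os.path.join(os.path.dirname(path), base))
        base_cfg.update(cfg)
        cfg = base_cfg
    return cfg

# cfg = load_with_base('configs/phase2/clipRN50_cider.yml')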