diff --git a/app.py b/app.py index 3aca2d791cf8f511669598537f8b2fa0ace6255c..6b1dd2c32cd02024cf3d7af36b17d7de334e06dd 100644 --- a/app.py +++ b/app.py @@ -8,6 +8,11 @@ import captioning.utils.misc as utils import pytorch_lightning as pl import gradio as gr +from diffusers import LDMTextToImagePipeline +# import PIL.Image +import random +import os + # Checkpoint class class ModelCheckpoint(pl.callbacks.ModelCheckpoint): @@ -47,7 +52,6 @@ opt.seq_length = seq_length opt.batch_size = 1 opt.vocab = ix_to_word -# opt.use_grammar = False model = models.setup(opt) del opt.vocab @@ -111,55 +115,74 @@ clip_model.visual.attnpool.positional_embedding = pos_embed # End below - -def generate_image(img, steps=100, seed=42, guidance_scale=6.0): - - with torch.no_grad(): - image = preprocess(img) - image = torch.tensor(np.stack([image])).to(device) - image -= image_mean - image /= image_std - - tmp_att, tmp_fc = clip_model.encode_image(image) - tmp_att = tmp_att[0].permute(1, 2, 0) - tmp_fc = tmp_fc[0] - - att_feat = tmp_att - fc_feat = tmp_fc - +print('Loading the model: CompVis/ldm-text2im-large-256') +ldm_pipeline = LDMTextToImagePipeline.from_pretrained("CompVis/ldm-text2im-large-256") + +def generate_image_from_text(prompt, steps=100, seed=42, guidance_scale=6.0): + print('RUN: generate_image_from_text') + torch.cuda.empty_cache() + generator = torch.manual_seed(seed) + images = ldm_pipeline([prompt], generator=generator, num_inference_steps=steps, eta=0.3, guidance_scale=guidance_scale)["sample"] + return images[0] + +def generate_text_from_image(img): + print('RUN: generate_text_from_image') + with torch.no_grad(): + image = preprocess(img) + image = torch.tensor(np.stack([image])).to(device) + image -= image_mean + image /= image_std + + tmp_att, tmp_fc = clip_model.encode_image(image) + tmp_att = tmp_att[0].permute(1, 2, 0) + tmp_fc = tmp_fc[0] + + att_feat = tmp_att + fc_feat = tmp_fc - # Inference configurations - eval_kwargs = {} - eval_kwargs.update(vars(opt)) + # 
Inference configurations + eval_kwargs = {} + eval_kwargs.update(vars(opt)) - verbose = eval_kwargs.get('verbose', True) - verbose_beam = eval_kwargs.get('verbose_beam', 0) - verbose_loss = eval_kwargs.get('verbose_loss', 1) + verbose = eval_kwargs.get('verbose', True) + verbose_beam = eval_kwargs.get('verbose_beam', 0) + verbose_loss = eval_kwargs.get('verbose_loss', 1) - # dataset = eval_kwargs.get('dataset', 'coco') - beam_size = eval_kwargs.get('beam_size', 1) - sample_n = eval_kwargs.get('sample_n', 1) - remove_bad_endings = eval_kwargs.get('remove_bad_endings', 0) + # dataset = eval_kwargs.get('dataset', 'coco') + beam_size = eval_kwargs.get('beam_size', 1) + sample_n = eval_kwargs.get('sample_n', 1) + remove_bad_endings = eval_kwargs.get('remove_bad_endings', 0) - with torch.no_grad(): - fc_feats = torch.zeros((1,0)).to(device) - att_feats = att_feat.view(1, 196, 2048).float().to(device) - att_masks = None - - # forward the model to also get generated samples for each image - # Only leave one feature for each image, in case duplicate sample - tmp_eval_kwargs = eval_kwargs.copy() - tmp_eval_kwargs.update({'sample_n': 1}) - seq, seq_logprobs = model( - fc_feats, att_feats, att_masks, opt=tmp_eval_kwargs, mode='sample') - seq = seq.data - - sents = utils.decode_sequence(model.vocab, seq) + with torch.no_grad(): + fc_feats = torch.zeros((1,0)).to(device) + att_feats = att_feat.view(1, 196, 2048).float().to(device) + att_masks = None + + # forward the model to also get generated samples for each image + # Only leave one feature for each image, in case duplicate sample + tmp_eval_kwargs = eval_kwargs.copy() + tmp_eval_kwargs.update({'sample_n': 1}) + seq, seq_logprobs = model( + fc_feats, att_feats, att_masks, opt=tmp_eval_kwargs, mode='sample') + seq = seq.data + + sents = utils.decode_sequence(model.vocab, seq) + + return sents[0] + + +def generate_drawing_from_image(img, steps=100, seed=42, guidance_scale=6.0): + print('RUN: generate_drawing_from_image') + 
caption = generate_text_from_image(img) + gen_image = generate_image_from_text(caption, steps=steps, seed=seed, guidance_scale=guidance_scale) + return gen_image + + +random_seed = random.randint(0, 2147483647) - return sents[0] gr.Interface( - generate_image, + generate_drawing_from_image, inputs=[ gr.Image(type="pil"), gr.inputs.Slider(1, 100, label='Inference Steps', default=50, step=1), @@ -168,4 +191,4 @@ gr.Interface( ], outputs=gr.Image(shape=[256,256], type="pil", elem_id="output_image"), css="#output_image{width: 256px}", -).launch() \ No newline at end of file +).launch() diff --git a/captioning/__init__.py b/captioning/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/captioning/data/__init__.py b/captioning/data/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/captioning/data/dataloader.py b/captioning/data/dataloader.py new file mode 100644 index 0000000000000000000000000000000000000000..7f2ed0304bd94db21bbc9fbdc6857beccb8bb621 --- /dev/null +++ b/captioning/data/dataloader.py @@ -0,0 +1,425 @@ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import json +import h5py +from lmdbdict import lmdbdict +from lmdbdict.methods import DUMPS_FUNC, LOADS_FUNC +import os +import numpy as np +import numpy.random as npr +import random +from functools import partial + +import torch +import torch.utils.data as data + +import multiprocessing +import six + +class HybridLoader: + """ + If db_path is a director, then use normal file loading + If lmdb, then load from lmdb + The loading method depend on extention. + + in_memory: if in_memory is True, we save all the features in memory + For individual np(y|z)s, we don't need to do that because the system will do this for us. + Should be useful for lmdb or h5. 
+ (Copied this idea from vilbert) + """ + def __init__(self, db_path, ext, in_memory=False): + self.db_path = db_path + self.ext = ext + if self.ext == '.npy': + self.loader = lambda x: np.load(six.BytesIO(x)) + else: + def load_npz(x): + x = np.load(six.BytesIO(x)) + return x['feat'] if 'feat' in x else x['z'] # normally it should be 'feat', but under cocotest_bu, the key is saved to be 'z' mistakenly. + self.loader = load_npz + if db_path.endswith('.lmdb'): + self.db_type = 'lmdb' + self.lmdb = lmdbdict(db_path, unsafe=True) + self.lmdb._key_dumps = DUMPS_FUNC['ascii'] + self.lmdb._value_loads = LOADS_FUNC['identity'] + elif db_path.endswith('.pth'): # Assume a key,value dictionary + self.db_type = 'pth' + self.feat_file = torch.load(db_path) + self.loader = lambda x: x + print('HybridLoader: ext is ignored') + elif db_path.endswith('h5'): + self.db_type = 'h5' + self.loader = lambda x: np.array(x).astype('float32') + else: + self.db_type = 'dir' + + self.in_memory = in_memory + if self.in_memory: + self.features = {} + + def get(self, key): + + if self.in_memory and key in self.features: + # We save f_input because we want to save the + # compressed bytes to save memory + f_input = self.features[key] + elif self.db_type == 'lmdb': + f_input = self.lmdb[key] + elif self.db_type == 'pth': + f_input = self.feat_file[key] + elif self.db_type == 'h5': + f_input = h5py.File(self.db_path, 'r')[key] + else: + f_input = open(os.path.join(self.db_path, key + self.ext), 'rb').read() + + if self.in_memory and key not in self.features: + self.features[key] = f_input + + # load image + feat = self.loader(f_input) + + return feat + +class Dataset(data.Dataset): + + def get_vocab_size(self): + return self.vocab_size + + def get_vocab(self): + return self.ix_to_word + + def get_seq_length(self): + return self.seq_length + + def __init__(self, opt): + self.opt = opt + self.seq_per_img = opt.seq_per_img + + # feature related options + self.use_fc = getattr(opt, 'use_fc', True) + 
self.use_att = getattr(opt, 'use_att', True) + self.use_box = getattr(opt, 'use_box', 0) + self.norm_att_feat = getattr(opt, 'norm_att_feat', 0) + self.norm_box_feat = getattr(opt, 'norm_box_feat', 0) + + # load the json file which contains additional information about the dataset + print('DataLoader loading json file: ', opt.input_json) + self.info = json.load(open(self.opt.input_json)) + if 'ix_to_word' in self.info: + self.ix_to_word = self.info['ix_to_word'] + self.vocab_size = len(self.ix_to_word) + print('vocab size is ', self.vocab_size) + + # open the hdf5 file + print('DataLoader loading h5 file: ', opt.input_fc_dir, opt.input_att_dir, opt.input_box_dir, opt.input_label_h5) + """ + Setting input_label_h5 to none is used when only doing generation. + For example, when you need to test on coco test set. + """ + if self.opt.input_label_h5 != 'none': + self.h5_label_file = h5py.File(self.opt.input_label_h5, 'r', driver='core') + # load in the sequence data + seq_size = self.h5_label_file['labels'].shape + self.label = self.h5_label_file['labels'][:] + self.seq_length = seq_size[1] + print('max sequence length in data is', self.seq_length) + # load the pointers in full to RAM (should be small enough) + self.label_start_ix = self.h5_label_file['label_start_ix'][:] + self.label_end_ix = self.h5_label_file['label_end_ix'][:] + else: + self.seq_length = 1 + + self.data_in_memory = getattr(opt, 'data_in_memory', False) + self.fc_loader = HybridLoader(self.opt.input_fc_dir, '.npy', in_memory=self.data_in_memory) + self.att_loader = HybridLoader(self.opt.input_att_dir, '.npz', in_memory=self.data_in_memory) + self.box_loader = HybridLoader(self.opt.input_box_dir, '.npy', in_memory=self.data_in_memory) + + self.num_images = len(self.info['images']) # self.label_start_ix.shape[0] + print('read %d image features' %(self.num_images)) + + # separate out indexes for each of the provided splits + self.split_ix = {'train': [], 'val': [], 'test': []} + for ix in 
range(len(self.info['images'])): + img = self.info['images'][ix] + if not 'split' in img: + self.split_ix['train'].append(ix) + self.split_ix['val'].append(ix) + self.split_ix['test'].append(ix) + elif img['split'] == 'train': + self.split_ix['train'].append(ix) + elif img['split'] == 'val': + self.split_ix['val'].append(ix) + elif img['split'] == 'test': + self.split_ix['test'].append(ix) + elif opt.train_only == 0: # restval + self.split_ix['train'].append(ix) + + print('assigned %d images to split train' %len(self.split_ix['train'])) + print('assigned %d images to split val' %len(self.split_ix['val'])) + print('assigned %d images to split test' %len(self.split_ix['test'])) + + def get_captions(self, ix, seq_per_img): + # fetch the sequence labels + ix1 = self.label_start_ix[ix] - 1 #label_start_ix starts from 1 + ix2 = self.label_end_ix[ix] - 1 + ncap = ix2 - ix1 + 1 # number of captions available for this image + assert ncap > 0, 'an image does not have any label. this can be handled but right now isn\'t' + + if ncap < seq_per_img: + # we need to subsample (with replacement) + seq = np.zeros([seq_per_img, self.seq_length], dtype = 'int') + for q in range(seq_per_img): + ixl = random.randint(ix1,ix2) + seq[q, :] = self.label[ixl, :self.seq_length] + else: + ixl = random.randint(ix1, ix2 - seq_per_img + 1) + seq = self.label[ixl: ixl + seq_per_img, :self.seq_length] + + return seq + + def collate_func(self, batch, split): + seq_per_img = self.seq_per_img + + fc_batch = [] + att_batch = [] + label_batch = [] + + wrapped = False + + infos = [] + gts = [] + + for sample in batch: + # fetch image + tmp_fc, tmp_att, tmp_seq, \ + ix, it_pos_now, tmp_wrapped = sample + if tmp_wrapped: + wrapped = True + + fc_batch.append(tmp_fc) + att_batch.append(tmp_att) + + tmp_label = np.zeros([seq_per_img, self.seq_length + 2], dtype = 'int') + if hasattr(self, 'h5_label_file'): + # if there is ground truth + tmp_label[:, 1 : self.seq_length + 1] = tmp_seq + 
label_batch.append(tmp_label) + + # Used for reward evaluation + if hasattr(self, 'h5_label_file'): + # if there is ground truth + gts.append(self.label[self.label_start_ix[ix] - 1: self.label_end_ix[ix]]) + else: + gts.append([]) + + # record associated info as well + info_dict = {} + info_dict['ix'] = ix + info_dict['id'] = self.info['images'][ix]['id'] + info_dict['file_path'] = self.info['images'][ix].get('file_path', '') + infos.append(info_dict) + + # #sort by att_feat length + # fc_batch, att_batch, label_batch, gts, infos = \ + # zip(*sorted(zip(fc_batch, att_batch, np.vsplit(label_batch, batch_size), gts, infos), key=lambda x: len(x[1]), reverse=True)) + fc_batch, att_batch, label_batch, gts, infos = \ + zip(*sorted(zip(fc_batch, att_batch, label_batch, gts, infos), key=lambda x: 0, reverse=True)) + data = {} + data['fc_feats'] = np.stack(fc_batch) + # merge att_feats + max_att_len = max([_.shape[0] for _ in att_batch]) + data['att_feats'] = np.zeros([len(att_batch), max_att_len, att_batch[0].shape[1]], dtype = 'float32') + for i in range(len(att_batch)): + data['att_feats'][i, :att_batch[i].shape[0]] = att_batch[i] + data['att_masks'] = np.zeros(data['att_feats'].shape[:2], dtype='float32') + for i in range(len(att_batch)): + data['att_masks'][i, :att_batch[i].shape[0]] = 1 + # set att_masks to None if attention features have same length + if data['att_masks'].sum() == data['att_masks'].size: + data['att_masks'] = None + + data['labels'] = np.vstack(label_batch) + # generate mask + nonzeros = np.array(list(map(lambda x: (x != 0).sum()+2, data['labels']))) + mask_batch = np.zeros([data['labels'].shape[0], self.seq_length + 2], dtype = 'float32') + for ix, row in enumerate(mask_batch): + row[:nonzeros[ix]] = 1 + data['masks'] = mask_batch + data['labels'] = data['labels'].reshape(len(batch), seq_per_img, -1) + data['masks'] = data['masks'].reshape(len(batch), seq_per_img, -1) + + data['gts'] = gts # all ground truth captions of each images + data['bounds'] 
= {'it_pos_now': it_pos_now, # the it_pos_now of the last sample + 'it_max': len(self.split_ix[split]), 'wrapped': wrapped} + data['infos'] = infos + + data = {k:torch.from_numpy(v) if type(v) is np.ndarray else v for k,v in data.items()} # Turn all ndarray to torch tensor + + return data + + def __getitem__(self, index): + """This function returns a tuple that is further passed to collate_fn + """ + ix, it_pos_now, wrapped = index #self.split_ix[index] + if self.use_att: + att_feat = self.att_loader.get(str(self.info['images'][ix]['id'])) + # Reshape to K x C + att_feat = att_feat.reshape(-1, att_feat.shape[-1]) + if self.norm_att_feat: + att_feat = att_feat / np.linalg.norm(att_feat, 2, 1, keepdims=True) + if self.use_box: + box_feat = self.box_loader.get(str(self.info['images'][ix]['id'])) + # devided by image width and height + x1,y1,x2,y2 = np.hsplit(box_feat, 4) + h,w = self.info['images'][ix]['height'], self.info['images'][ix]['width'] + box_feat = np.hstack((x1/w, y1/h, x2/w, y2/h, (x2-x1)*(y2-y1)/(w*h))) # question? x2-x1+1?? 
+ if self.norm_box_feat: + box_feat = box_feat / np.linalg.norm(box_feat, 2, 1, keepdims=True) + att_feat = np.hstack([att_feat, box_feat]) + # sort the features by the size of boxes + att_feat = np.stack(sorted(att_feat, key=lambda x:x[-1], reverse=True)) + else: + att_feat = np.zeros((0,0), dtype='float32') + if self.use_fc: + try: + fc_feat = self.fc_loader.get(str(self.info['images'][ix]['id'])) + except: + # Use average of attention when there is no fc provided (For bottomup feature) + fc_feat = att_feat.mean(0) + else: + fc_feat = np.zeros((0), dtype='float32') + if hasattr(self, 'h5_label_file'): + seq = self.get_captions(ix, self.seq_per_img) + else: + seq = None + return (fc_feat, + att_feat, seq, + ix, it_pos_now, wrapped) + + def __len__(self): + return len(self.info['images']) + +class DataLoader: + def __init__(self, opt): + self.opt = opt + self.batch_size = self.opt.batch_size + self.dataset = Dataset(opt) + + # Initialize loaders and iters + self.loaders, self.iters = {}, {} + for split in ['train', 'val', 'test']: + if split == 'train': + sampler = MySampler(self.dataset.split_ix[split], shuffle=True, wrap=True) + else: + sampler = MySampler(self.dataset.split_ix[split], shuffle=False, wrap=False) + self.loaders[split] = data.DataLoader(dataset=self.dataset, + batch_size=self.batch_size, + sampler=sampler, + pin_memory=True, + num_workers=4, # 4 is usually enough + collate_fn=partial(self.dataset.collate_func, split=split), + drop_last=False) + self.iters[split] = iter(self.loaders[split]) + + def get_batch(self, split): + try: + data = next(self.iters[split]) + except StopIteration: + self.iters[split] = iter(self.loaders[split]) + data = next(self.iters[split]) + return data + + def reset_iterator(self, split): + self.loaders[split].sampler._reset_iter() + self.iters[split] = iter(self.loaders[split]) + + def get_vocab_size(self): + return self.dataset.get_vocab_size() + + @property + def vocab_size(self): + return self.get_vocab_size() + + def 
get_vocab(self): + return self.dataset.get_vocab() + + def get_seq_length(self): + return self.dataset.get_seq_length() + + @property + def seq_length(self): + return self.get_seq_length() + + def state_dict(self): + def get_prefetch_num(split): + if self.loaders[split].num_workers > 0: + return (self.iters[split]._send_idx - self.iters[split]._rcvd_idx) * self.batch_size + else: + return 0 + return {split: loader.sampler.state_dict(get_prefetch_num(split)) \ + for split, loader in self.loaders.items()} + + def load_state_dict(self, state_dict=None): + if state_dict is None: + return + for split in self.loaders.keys(): + self.loaders[split].sampler.load_state_dict(state_dict[split]) + + +class MySampler(data.sampler.Sampler): + def __init__(self, index_list, shuffle, wrap): + self.index_list = index_list + self.shuffle = shuffle + self.wrap = wrap + # if wrap, there will be not stop iteration called + # wrap True used during training, and wrap False used during test. + self._reset_iter() + + def __iter__(self): + return self + + def __next__(self): + wrapped = False + if self.iter_counter == len(self._index_list): + self._reset_iter() + if self.wrap: + wrapped = True + else: + raise StopIteration() + if len(self._index_list) == 0: # overflow when 0 samples + return None + elem = (self._index_list[self.iter_counter], self.iter_counter+1, wrapped) + self.iter_counter += 1 + return elem + + def next(self): + return self.__next__() + + def _reset_iter(self): + if self.shuffle: + rand_perm = npr.permutation(len(self.index_list)) + self._index_list = [self.index_list[_] for _ in rand_perm] + else: + self._index_list = self.index_list + + self.iter_counter = 0 + + def __len__(self): + return len(self.index_list) + + def load_state_dict(self, state_dict=None): + if state_dict is None: + return + self._index_list = state_dict['index_list'] + self.iter_counter = state_dict['iter_counter'] + + def state_dict(self, prefetched_num=None): + prefetched_num = prefetched_num or 0 + 
return { + 'index_list': self._index_list, + 'iter_counter': self.iter_counter - prefetched_num + } + + \ No newline at end of file diff --git a/captioning/data/pth_loader.py b/captioning/data/pth_loader.py new file mode 100644 index 0000000000000000000000000000000000000000..28023699735470daa7e2ab4752a31ea8282c04c5 --- /dev/null +++ b/captioning/data/pth_loader.py @@ -0,0 +1,334 @@ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import json +import h5py +from lmdbdict import lmdbdict +from lmdbdict.methods import DUMPS_FUNC, LOADS_FUNC +import os +import numpy as np +import numpy.random as npr +import random + +import torch +import torch.utils.data as data + +import multiprocessing +import six + +verbose = True +# import torch +# if torch.cuda.current_device() in [0, -1]: +if 'LOCAL_RANK' in os.environ and os.environ['LOCAL_RANK'] != '0': + verbose = False + +class HybridLoader: + """ + If db_path is a director, then use normal file loading + If lmdb, then load from lmdb + The loading method depend on extention. + + in_memory: if in_memory is True, we save all the features in memory + For individual np(y|z)s, we don't need to do that because the system will do this for us. + Should be useful for lmdb or h5. 
+ (Copied this idea from vilbert) + """ + def __init__(self, db_path, ext, in_memory=False): + self.db_path = db_path + self.ext = ext + if self.ext == '.npy': + self.loader = lambda x: np.load(six.BytesIO(x)) + else: + self.loader = lambda x: np.load(six.BytesIO(x))['feat'] + if db_path.endswith('.lmdb'): + self.db_type = 'lmdb' + self.lmdb = lmdbdict(db_path, unsafe=True) + self.lmdb._key_dumps = DUMPS_FUNC['ascii'] + self.lmdb._value_loads = LOADS_FUNC['identity'] + elif db_path.endswith('.pth'): # Assume a key,value dictionary + self.db_type = 'pth' + self.feat_file = torch.load(db_path) + self.loader = lambda x: x + print('HybridLoader: ext is ignored') + elif db_path.endswith('h5'): + self.db_type = 'h5' + self.loader = lambda x: np.array(x).astype('float32') + else: + self.db_type = 'dir' + + self.in_memory = in_memory + if self.in_memory: + self.features = {} + + def get(self, key): + + if self.in_memory and key in self.features: + # We save f_input because we want to save the + # compressed bytes to save memory + f_input = self.features[key] + elif self.db_type == 'lmdb': + f_input = self.lmdb[key] + elif self.db_type == 'pth': + f_input = self.feat_file[key] + elif self.db_type == 'h5': + f_input = h5py.File(self.db_path, 'r')[key] + else: + f_input = open(os.path.join(self.db_path, key + self.ext), 'rb').read() + + if self.in_memory and key not in self.features: + self.features[key] = f_input + + # load image + feat = self.loader(f_input) + + return feat + +class CaptionDataset(data.Dataset): + + def get_vocab_size(self): + return self.vocab_size + + def get_vocab(self): + return self.ix_to_word + + def get_seq_length(self): + return self.seq_length + + def __init__(self, opt): + self.opt = opt + self.seq_per_img = opt.seq_per_img + + # feature related options + self.use_fc = getattr(opt, 'use_fc', True) + self.use_att = getattr(opt, 'use_att', True) + self.use_box = getattr(opt, 'use_box', 0) + self.norm_att_feat = getattr(opt, 'norm_att_feat', 0) + 
self.norm_box_feat = getattr(opt, 'norm_box_feat', 0) + + # load the json file which contains additional information about the dataset + if verbose: + print('DataLoader loading json file: ', opt.input_json) + self.info = json.load(open(self.opt.input_json)) + if 'ix_to_word' in self.info: + self.ix_to_word = self.info['ix_to_word'] + self.vocab_size = len(self.ix_to_word) + if verbose: + print('vocab size is ', self.vocab_size) + + # open the hdf5 file + if verbose: + print('DataLoader loading h5 file: ', opt.input_fc_dir, opt.input_att_dir, opt.input_box_dir, opt.input_label_h5) + """ + Setting input_label_h5 to none is used when only doing generation. + For example, when you need to test on coco test set. + """ + if self.opt.input_label_h5 != 'none': + self.h5_label_file = h5py.File(self.opt.input_label_h5, 'r', driver='core') + # load in the sequence data + seq_size = self.h5_label_file['labels'].shape + self.label = self.h5_label_file['labels'][:] + self.seq_length = seq_size[1] + if verbose: + print('max sequence length in data is', self.seq_length) + # load the pointers in full to RAM (should be small enough) + self.label_start_ix = self.h5_label_file['label_start_ix'][:] + self.label_end_ix = self.h5_label_file['label_end_ix'][:] + else: + self.seq_length = 1 + + self.data_in_memory = getattr(opt, 'data_in_memory', False) + self.fc_loader = HybridLoader(self.opt.input_fc_dir, '.npy', in_memory=self.data_in_memory) + self.att_loader = HybridLoader(self.opt.input_att_dir, '.npz', in_memory=self.data_in_memory) + self.box_loader = HybridLoader(self.opt.input_box_dir, '.npy', in_memory=self.data_in_memory) + + self.use_clipscore = getattr(opt, 'use_clipscore', False) + # if self.use_clipscore: + self.clipscore_loader = HybridLoader(self.opt.input_clipscore_vis_dir, '.npy', in_memory=self.data_in_memory) + + + self.num_images = len(self.info['images']) # self.label_start_ix.shape[0] + if verbose: + print('read %d image features' %(self.num_images)) + + # separate 
out indexes for each of the provided splits + self.split_ix = {'train': [], 'val': [], 'test': []} + for ix in range(len(self.info['images'])): + img = self.info['images'][ix] + if not 'split' in img: + self.split_ix['train'].append(ix) + self.split_ix['val'].append(ix) + self.split_ix['test'].append(ix) + elif img['split'] == 'train': + self.split_ix['train'].append(ix) + elif img['split'] == 'val': + self.split_ix['val'].append(ix) + elif img['split'] == 'test': + self.split_ix['test'].append(ix) + elif opt.train_only == 0: # restval + self.split_ix['train'].append(ix) + + if verbose: + print('assigned %d images to split train' %len(self.split_ix['train'])) + print('assigned %d images to split val' %len(self.split_ix['val'])) + print('assigned %d images to split test' %len(self.split_ix['test'])) + + def get_captions(self, ix, seq_per_img): + # fetch the sequence labels + ix1 = self.label_start_ix[ix] - 1 #label_start_ix starts from 1 + ix2 = self.label_end_ix[ix] - 1 + ncap = ix2 - ix1 + 1 # number of captions available for this image + assert ncap > 0, 'an image does not have any label. 
this can be handled but right now isn\'t' + + if ncap < seq_per_img: + # we need to subsample (with replacement) + seq = np.zeros([seq_per_img, self.seq_length], dtype = 'int') + for q in range(seq_per_img): + ixl = random.randint(ix1,ix2) + seq[q, :] = self.label[ixl, :self.seq_length] + else: + ixl = random.randint(ix1, ix2 - seq_per_img + 1) + seq = self.label[ixl: ixl + seq_per_img, :self.seq_length] + + return seq + + def collate_func(self, batch): + seq_per_img = self.seq_per_img + + fc_batch = [] + att_batch = [] + label_batch = [] + + clip_vis_feat_batch = [] + + wrapped = False + + infos = [] + gts = [] + + for sample in batch: + # fetch image + # if self.use_clipscore: + tmp_fc, tmp_att, tmp_seq, \ + ix, tmp_clip_vis_feat = sample + + clip_vis_feat_batch.append(tmp_clip_vis_feat) + # else: + # tmp_fc, tmp_att, tmp_seq, \ + # ix = sample + + fc_batch.append(tmp_fc) + att_batch.append(tmp_att) + + tmp_label = np.zeros([seq_per_img, self.seq_length + 2], dtype = 'int') + if hasattr(self, 'h5_label_file'): + # if there is ground truth + tmp_label[:, 1 : self.seq_length + 1] = tmp_seq + label_batch.append(tmp_label) + + # Used for reward evaluation + if hasattr(self, 'h5_label_file'): + # if there is ground truth + gts.append(self.label[self.label_start_ix[ix] - 1: self.label_end_ix[ix]]) + else: + gts.append([]) + + # record associated info as well + info_dict = {} + info_dict['ix'] = ix + info_dict['id'] = self.info['images'][ix]['id'] + info_dict['file_path'] = self.info['images'][ix].get('file_path', '') + infos.append(info_dict) + + # #sort by att_feat length + # fc_batch, att_batch, label_batch, gts, infos = \ + # zip(*sorted(zip(fc_batch, att_batch, np.vsplit(label_batch, batch_size), gts, infos), key=lambda x: len(x[1]), reverse=True)) + if self.use_clipscore: + fc_batch, att_batch, label_batch, clip_vis_feat_batch, gts, infos = \ + zip(*sorted(zip(fc_batch, att_batch, label_batch, clip_vis_feat_batch, gts, infos), key=lambda x: 0, reverse=True)) + 
else: + fc_batch, att_batch, label_batch, gts, infos = \ + zip(*sorted(zip(fc_batch, att_batch, label_batch, gts, infos), key=lambda x: 0, reverse=True)) + data = {} + data['fc_feats'] = np.stack(fc_batch) + # merge att_feats + max_att_len = max([_.shape[0] for _ in att_batch]) + data['att_feats'] = np.zeros([len(att_batch), max_att_len, att_batch[0].shape[1]], dtype = 'float32') + for i in range(len(att_batch)): + data['att_feats'][i, :att_batch[i].shape[0]] = att_batch[i] + data['att_masks'] = np.zeros(data['att_feats'].shape[:2], dtype='float32') + for i in range(len(att_batch)): + data['att_masks'][i, :att_batch[i].shape[0]] = 1 + # set att_masks to None if attention features have same length + if data['att_masks'].sum() == data['att_masks'].size: + data['att_masks'] = None + + # if self.use_clipscore: + data['clip_vis_feats'] = np.stack(clip_vis_feat_batch) + + data['labels'] = np.vstack(label_batch) + # generate mask + nonzeros = np.array(list(map(lambda x: (x != 0).sum()+2, data['labels']))) + mask_batch = np.zeros([data['labels'].shape[0], self.seq_length + 2], dtype = 'float32') + for ix, row in enumerate(mask_batch): + row[:nonzeros[ix]] = 1 + data['masks'] = mask_batch + data['labels'] = data['labels'].reshape(len(batch), seq_per_img, -1) + data['masks'] = data['masks'].reshape(len(batch), seq_per_img, -1) + + data['gts'] = gts # all ground truth captions of each images + data['infos'] = infos + + data = {k:torch.from_numpy(v) if type(v) is np.ndarray else v for k,v in data.items()} # Turn all ndarray to torch tensor + + return data + + def __getitem__(self, ix): + """This function returns a tuple that is further passed to collate_fn + """ + if self.use_att: + att_feat = self.att_loader.get(str(self.info['images'][ix]['id'])) + # Reshape to K x C + att_feat = att_feat.reshape(-1, att_feat.shape[-1]) + if self.norm_att_feat: + att_feat = att_feat / np.linalg.norm(att_feat, 2, 1, keepdims=True) + if self.use_box: + box_feat = 
self.box_loader.get(str(self.info['images'][ix]['id'])) + # devided by image width and height + x1,y1,x2,y2 = np.hsplit(box_feat, 4) + h,w = self.info['images'][ix]['height'], self.info['images'][ix]['width'] + box_feat = np.hstack((x1/w, y1/h, x2/w, y2/h, (x2-x1)*(y2-y1)/(w*h))) # question? x2-x1+1?? + if self.norm_box_feat: + box_feat = box_feat / np.linalg.norm(box_feat, 2, 1, keepdims=True) + att_feat = np.hstack([att_feat, box_feat]) + # sort the features by the size of boxes + att_feat = np.stack(sorted(att_feat, key=lambda x:x[-1], reverse=True)) + else: + att_feat = np.zeros((0,0), dtype='float32') + if self.use_fc: + try: + fc_feat = self.fc_loader.get(str(self.info['images'][ix]['id'])) + except: + # Use average of attention when there is no fc provided (For bottomup feature) + fc_feat = att_feat.mean(0) + else: + fc_feat = np.zeros((0), dtype='float32') + if hasattr(self, 'h5_label_file'): + seq = self.get_captions(ix, self.seq_per_img) + else: + seq = None + + # if self.use_clipscore: + clip_vis_feat = self.clipscore_loader.get( + str(self.info['images'][ix]['id'])) + + return (fc_feat, + att_feat, seq, + ix, clip_vis_feat) + + # return (fc_feat, + # att_feat, seq, + # ix) + + def __len__(self): + return len(self.info['images']) diff --git a/captioning/data/pth_loader_FineCapEval.py b/captioning/data/pth_loader_FineCapEval.py new file mode 100644 index 0000000000000000000000000000000000000000..388301edd763d54d95675ca2ed6eb502f77e1eb1 --- /dev/null +++ b/captioning/data/pth_loader_FineCapEval.py @@ -0,0 +1,334 @@ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import json +import h5py +from lmdbdict import lmdbdict +from lmdbdict.methods import DUMPS_FUNC, LOADS_FUNC +import os +import numpy as np +import numpy.random as npr +import random + +import torch +import torch.utils.data as data + +import multiprocessing +import six + +verbose = True +# import torch +# if 
torch.cuda.current_device() in [0, -1]: +if 'LOCAL_RANK' in os.environ and os.environ['LOCAL_RANK'] != '0': + verbose = False + +class HybridLoader: + """ + If db_path is a director, then use normal file loading + If lmdb, then load from lmdb + The loading method depend on extention. + + in_memory: if in_memory is True, we save all the features in memory + For individual np(y|z)s, we don't need to do that because the system will do this for us. + Should be useful for lmdb or h5. + (Copied this idea from vilbert) + """ + def __init__(self, db_path, ext, in_memory=False): + self.db_path = db_path + self.ext = ext + if self.ext == '.npy': + self.loader = lambda x: np.load(six.BytesIO(x)) + else: + self.loader = lambda x: np.load(six.BytesIO(x))['feat'] + if db_path.endswith('.lmdb'): + self.db_type = 'lmdb' + self.lmdb = lmdbdict(db_path, unsafe=True) + self.lmdb._key_dumps = DUMPS_FUNC['ascii'] + self.lmdb._value_loads = LOADS_FUNC['identity'] + elif db_path.endswith('.pth'): # Assume a key,value dictionary + self.db_type = 'pth' + self.feat_file = torch.load(db_path) + self.loader = lambda x: x + print('HybridLoader: ext is ignored') + elif db_path.endswith('h5'): + self.db_type = 'h5' + self.loader = lambda x: np.array(x).astype('float32') + else: + self.db_type = 'dir' + + self.in_memory = in_memory + if self.in_memory: + self.features = {} + + def get(self, key): + + if self.in_memory and key in self.features: + # We save f_input because we want to save the + # compressed bytes to save memory + f_input = self.features[key] + elif self.db_type == 'lmdb': + f_input = self.lmdb[key] + elif self.db_type == 'pth': + f_input = self.feat_file[key] + elif self.db_type == 'h5': + f_input = h5py.File(self.db_path, 'r')[key] + else: + f_input = open(os.path.join(self.db_path, key + self.ext), 'rb').read() + + if self.in_memory and key not in self.features: + self.features[key] = f_input + + # load image + feat = self.loader(f_input) + + return feat + +class 
CaptionDataset(data.Dataset): + + def get_vocab_size(self): + return self.vocab_size + + def get_vocab(self): + return self.ix_to_word + + def get_seq_length(self): + return self.seq_length + + def __init__(self, opt): + self.opt = opt + self.seq_per_img = opt.seq_per_img + + # feature related options + self.use_fc = getattr(opt, 'use_fc', True) + self.use_att = getattr(opt, 'use_att', True) + self.use_box = getattr(opt, 'use_box', 0) + self.norm_att_feat = getattr(opt, 'norm_att_feat', 0) + self.norm_box_feat = getattr(opt, 'norm_box_feat', 0) + + # load the json file which contains additional information about the dataset + if verbose: + print('DataLoader loading json file: ', opt.input_json) + self.info = json.load(open(self.opt.input_json)) + if 'ix_to_word' in self.info: + self.ix_to_word = self.info['ix_to_word'] + self.vocab_size = len(self.ix_to_word) + if verbose: + print('vocab size is ', self.vocab_size) + + # open the hdf5 file + if verbose: + print('DataLoader loading h5 file: ', opt.input_fc_dir, opt.input_att_dir, opt.input_box_dir, opt.input_label_h5) + """ + Setting input_label_h5 to none is used when only doing generation. + For example, when you need to test on coco test set. 
+ """ + if self.opt.input_label_h5 != 'none': + self.h5_label_file = h5py.File(self.opt.input_label_h5, 'r', driver='core') + # load in the sequence data + seq_size = self.h5_label_file['labels'].shape + self.label = self.h5_label_file['labels'][:] + self.seq_length = seq_size[1] + if verbose: + print('max sequence length in data is', self.seq_length) + # load the pointers in full to RAM (should be small enough) + self.label_start_ix = self.h5_label_file['label_start_ix'][:] + self.label_end_ix = self.h5_label_file['label_end_ix'][:] + else: + self.seq_length = 1 + + self.data_in_memory = getattr(opt, 'data_in_memory', False) + self.fc_loader = HybridLoader(self.opt.input_fc_dir, '.npy', in_memory=self.data_in_memory) + self.att_loader = HybridLoader(self.opt.input_att_dir, '.npz', in_memory=self.data_in_memory) + self.box_loader = HybridLoader(self.opt.input_box_dir, '.npy', in_memory=self.data_in_memory) + + self.use_clipscore = getattr(opt, 'use_clipscore', False) + if self.use_clipscore: + self.clipscore_loader = HybridLoader(self.opt.input_clipscore_vis_dir, '.npy', in_memory=self.data_in_memory) + + + self.num_images = len(self.info['images']) # self.label_start_ix.shape[0] + if verbose: + print('read %d image features' %(self.num_images)) + + # separate out indexes for each of the provided splits + self.split_ix = {'train': [], 'val': [], 'test': []} + for ix in range(len(self.info['images'])): + img = self.info['images'][ix] + if not 'split' in img: + self.split_ix['train'].append(ix) + self.split_ix['val'].append(ix) + self.split_ix['test'].append(ix) + elif img['split'] == 'train': + self.split_ix['train'].append(ix) + elif img['split'] == 'val': + self.split_ix['val'].append(ix) + elif img['split'] == 'test': + self.split_ix['test'].append(ix) + elif opt.train_only == 0: # restval + self.split_ix['train'].append(ix) + + if verbose: + print('assigned %d images to split train' %len(self.split_ix['train'])) + print('assigned %d images to split val' 
def get_captions(self, ix, seq_per_img):
    """Pick ``seq_per_img`` caption label rows for image ``ix``.

    ``label_start_ix`` / ``label_end_ix`` hold 1-based, inclusive row ranges
    into ``self.label``.  When the image has fewer captions than requested,
    rows are drawn with replacement; otherwise a random contiguous window of
    ``seq_per_img`` rows is returned.  Rows are truncated to ``seq_length``.
    """
    # convert the 1-based pointers to 0-based row indices
    first = self.label_start_ix[ix] - 1
    last = self.label_end_ix[ix] - 1
    available = last - first + 1  # number of captions available for this image
    assert available > 0, "an image does not have any label. this can be handled but right now isn't"

    if available >= seq_per_img:
        # enough captions: take a contiguous window starting at a random row
        start = random.randint(first, last - seq_per_img + 1)
        return self.label[start: start + seq_per_img, :self.seq_length]

    # not enough captions: subsample with replacement, one draw per slot
    picked = np.zeros([seq_per_img, self.seq_length], dtype='int')
    for slot in range(seq_per_img):
        row = random.randint(first, last)
        picked[slot, :] = self.label[row, :self.seq_length]
    return picked
def __getitem__(self, ix):
    """Return the per-image tuple that is later passed to ``collate_func``.

    The tuple is ``(fc_feat, att_feat, seq, ix)``, with ``clip_vis_feat``
    appended as a fifth element when ``self.use_clipscore`` is set.
    ``seq`` is ``None`` when no label file was loaded.
    """
    image_info = self.info['images'][ix]
    image_id = str(image_info['id'])

    if self.use_att:
        att_feat = self.att_loader.get(image_id)
        # Reshape to K x C
        att_feat = att_feat.reshape(-1, att_feat.shape[-1])
        if self.norm_att_feat:
            att_feat = att_feat / np.linalg.norm(att_feat, 2, 1, keepdims=True)
        if self.use_box:
            box_feat = self.box_loader.get(image_id)
            # divide by image width and height
            x1, y1, x2, y2 = np.hsplit(box_feat, 4)
            h, w = image_info['height'], image_info['width']
            box_feat = np.hstack((x1/w, y1/h, x2/w, y2/h, (x2-x1)*(y2-y1)/(w*h)))  # question? x2-x1+1??
            if self.norm_box_feat:
                box_feat = box_feat / np.linalg.norm(box_feat, 2, 1, keepdims=True)
            att_feat = np.hstack([att_feat, box_feat])
            # sort the features by the size of boxes (last column, largest first)
            att_feat = np.stack(sorted(att_feat, key=lambda x: x[-1], reverse=True))
    else:
        att_feat = np.zeros((0, 0), dtype='float32')

    if self.use_fc:
        try:
            fc_feat = self.fc_loader.get(image_id)
        except Exception:
            # Use average of attention when there is no fc provided (For bottomup feature).
            # Only ordinary errors trigger the fallback; KeyboardInterrupt and
            # SystemExit must still propagate.
            fc_feat = att_feat.mean(0)
    else:
        fc_feat = np.zeros((0), dtype='float32')

    if hasattr(self, 'h5_label_file'):
        seq = self.get_captions(ix, self.seq_per_img)
    else:
        seq = None

    if self.use_clipscore:
        clip_vis_feat = self.clipscore_loader.get(image_id)

        return (fc_feat,
                att_feat, seq,
                ix, clip_vis_feat)

    return (fc_feat,
            att_feat, seq,
            ix)
original author's repo: https://github.com/husthuaan/AoANet/ + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from .AttModel import pack_wrapper, AttModel, Attention +from .TransformerModel import LayerNorm, attention, clones, SublayerConnection, PositionwiseFeedForward + +class MultiHeadedDotAttention(nn.Module): + def __init__(self, h, d_model, dropout=0.1, scale=1, project_k_v=1, use_output_layer=1, do_aoa=0, norm_q=0, dropout_aoa=0.3): + super(MultiHeadedDotAttention, self).__init__() + assert d_model * scale % h == 0 + # We assume d_v always equals d_k + self.d_k = d_model * scale // h + self.h = h + + # Do we need to do linear projections on K and V? + self.project_k_v = project_k_v + + # normalize the query? + if norm_q: + self.norm = LayerNorm(d_model) + else: + self.norm = lambda x:x + self.linears = clones(nn.Linear(d_model, d_model * scale), 1 + 2 * project_k_v) + + # output linear layer after the multi-head attention? + self.output_layer = nn.Linear(d_model * scale, d_model) + + # apply aoa after attention? + self.use_aoa = do_aoa + if self.use_aoa: + self.aoa_layer = nn.Sequential(nn.Linear((1 + scale) * d_model, 2 * d_model), nn.GLU()) + # dropout to the input of AoA layer + if dropout_aoa > 0: + self.dropout_aoa = nn.Dropout(p=dropout_aoa) + else: + self.dropout_aoa = lambda x:x + + if self.use_aoa or not use_output_layer: + # AoA doesn't need the output linear layer + del self.output_layer + self.output_layer = lambda x:x + + self.attn = None + self.dropout = nn.Dropout(p=dropout) + + def forward(self, query, value, key, mask=None): + if mask is not None: + if len(mask.size()) == 2: + mask = mask.unsqueeze(-2) + # Same mask applied to all h heads. 
class AoA_Refiner_Layer(nn.Module):
    """One refiner layer: self-attention, optionally followed by a feed-forward
    block, each applied through its own ``SublayerConnection``."""

    def __init__(self, size, self_attn, feed_forward, dropout):
        super(AoA_Refiner_Layer, self).__init__()
        self.self_attn = self_attn
        self.feed_forward = feed_forward
        # 1 when a feed-forward module was supplied, 0 otherwise
        self.use_ff = 0 if self.feed_forward is None else 1
        # one sublayer for attention, plus one more when feed-forward is used
        n_sublayers = 1 + self.use_ff
        self.sublayer = clones(SublayerConnection(size, dropout), n_sublayers)
        self.size = size

    def forward(self, x, mask):
        def attend(inp):
            # query, value and key are all the same tensor
            return self.self_attn(inp, inp, inp, mask)

        x = self.sublayer[0](x, attend)
        if not self.use_ff:
            return x
        return self.sublayer[-1](x, self.feed_forward)
class AoA_Decoder_Core(nn.Module):
    """Single decoding step: an attention LSTM cell, an attention module over
    the image features, and an ``att2ctx`` layer that produces the output /
    context vector.

    The recurrent ``state`` is a pair ``(h, c)`` of stacked tensors:
    index 0 of each holds the attention-LSTM state, index 1 of ``h`` holds the
    context vector produced at the previous step (and index 1 of ``c`` the
    second cell state, only updated when ``decoder_type == 'LSTM'``).
    """
    def __init__(self, opt):
        super(AoA_Decoder_Core, self).__init__()
        self.drop_prob_lm = opt.drop_prob_lm
        self.d_model = opt.rnn_size
        self.use_multi_head = opt.use_multi_head
        self.multi_head_scale = opt.multi_head_scale
        # optional options: fall back to defaults when absent from opt
        self.use_ctx_drop = getattr(opt, 'ctx_drop', 0)
        self.out_res = getattr(opt, 'out_res', 0)
        self.decoder_type = getattr(opt, 'decoder_type', 'AoA')
        self.att_lstm = nn.LSTMCell(opt.input_encoding_size + opt.rnn_size, opt.rnn_size) # we, fc, h^2_t-1
        self.out_drop = nn.Dropout(self.drop_prob_lm)

        # att2ctx maps [attended feature, h_att] to the output vector
        if self.decoder_type == 'AoA':
            # AoA layer (linear to 2 * rnn_size followed by GLU)
            self.att2ctx = nn.Sequential(nn.Linear(self.d_model * opt.multi_head_scale + opt.rnn_size, 2 * opt.rnn_size), nn.GLU())
        elif self.decoder_type == 'LSTM':
            # LSTM layer
            self.att2ctx = nn.LSTMCell(self.d_model * opt.multi_head_scale + opt.rnn_size, opt.rnn_size)
        else:
            # Base linear layer
            self.att2ctx = nn.Sequential(nn.Linear(self.d_model * opt.multi_head_scale + opt.rnn_size, opt.rnn_size), nn.ReLU())

        # if opt.use_multi_head == 1: # TODO, not implemented for now
        #     self.attention = MultiHeadedAddAttention(opt.num_heads, opt.d_model, scale=opt.multi_head_scale)
        if opt.use_multi_head == 2:
            self.attention = MultiHeadedDotAttention(opt.num_heads, opt.rnn_size, project_k_v=0, scale=opt.multi_head_scale, use_output_layer=0, do_aoa=0, norm_q=1)
        else:
            self.attention = Attention(opt)

        # dropout on the previous context vector, or identity when disabled
        if self.use_ctx_drop:
            self.ctx_drop = nn.Dropout(self.drop_prob_lm)
        else:
            self.ctx_drop = lambda x :x

    def forward(self, xt, mean_feats, att_feats, p_att_feats, state, att_masks=None):
        """Run one decoding step and return ``(output, new_state)``."""
        # state[0][1] is the context vector at the last step
        h_att, c_att = self.att_lstm(torch.cat([xt, mean_feats + self.ctx_drop(state[0][1])], 1), (state[0][0], state[1][0]))

        if self.use_multi_head == 2:
            # p_att_feats is split along dim 2 into two equal halves of width
            # multi_head_scale * d_model, passed as the 2nd and 3rd arguments
            att = self.attention(h_att, p_att_feats.narrow(2, 0, self.multi_head_scale * self.d_model), p_att_feats.narrow(2, self.multi_head_scale * self.d_model, self.multi_head_scale * self.d_model), att_masks)
        else:
            att = self.attention(h_att, att_feats, p_att_feats, att_masks)

        ctx_input = torch.cat([att, h_att], 1)
        if self.decoder_type == 'LSTM':
            output, c_logic = self.att2ctx(ctx_input, (state[0][1], state[1][1]))
            state = (torch.stack((h_att, output)), torch.stack((c_att, c_logic)))
        else:
            output = self.att2ctx(ctx_input)
            # save the context vector to state[0][1]; state[1][1] is carried over unchanged
            state = (torch.stack((h_att, output)), torch.stack((c_att, state[1][1])))

        if self.out_res:
            # add residual connection
            output = output + h_att

        output = self.out_drop(output)
        return output, state
class AttEnsemble(AttModel):
    """Ensemble of caption models evaluated together.

    Per-model quantities (embeddings, features, states) are carried as lists
    with one entry per model; word probabilities are combined as a weighted
    average of each model's softmax output.
    """
    def __init__(self, models, weights=None):
        # Initialise through CaptionModel directly, so AttModel.__init__
        # (which builds its own layers from an opt) is not run.
        CaptionModel.__init__(self)
        # super(AttEnsemble, self).__init__()

        self.models = nn.ModuleList(models)
        # shared settings are taken from the first model
        self.vocab_size = models[0].vocab_size
        self.seq_length = models[0].seq_length
        self.bad_endings_ix = models[0].bad_endings_ix
        self.ss_prob = 0
        # default: every model gets weight 1.0
        weights = weights or [1.0] * len(self.models)
        self.register_buffer('weights', torch.tensor(weights))

    def init_hidden(self, batch_size):
        """Initialise every model's hidden state and return them as one flat list."""
        state = [m.init_hidden(batch_size) for m in self.models]
        return self.pack_state(state)

    def pack_state(self, state):
        """Flatten a list of per-model states, remembering each one's length."""
        self.state_lengths = [len(_) for _ in state]
        return sum([list(_) for _ in state], [])

    def unpack_state(self, state):
        """Split a flat state list back into per-model chunks using the
        lengths recorded by the last ``pack_state`` call."""
        out = []
        for l in self.state_lengths:
            out.append(state[:l])
            state = state[l:]
        return out

    def embed(self, it):
        """Embed ``it`` with every model; returns one embedding per model."""
        return [m.embed(it) for m in self.models]

    def core(self, *args):
        # Each positional argument is a sequence with one entry per model.
        # zip(*args) regroups them into per-model argument tuples, and the outer
        # zip regroups the per-model results by position in the returned tuple.
        return zip(*[m.core(*_) for m, _ in zip(self.models, zip(*args))])

    def get_logprobs_state(self, it, tmp_fc_feats, tmp_att_feats, tmp_p_att_feats, tmp_att_masks, state, output_logsoftmax=1):
        """Run one step on every model and return the log of the weighted
        average of their softmax distributions, plus the packed new state.

        ``output_logsoftmax`` is accepted but not used here.
        """
        # 'it' contains a word index
        xt = self.embed(it)

        state = self.unpack_state(state)
        output, state = self.core(xt, tmp_fc_feats, tmp_att_feats, tmp_p_att_feats, state, tmp_att_masks)
        # stack per-model probabilities on a new last dim, weight, normalise by the weight sum, then take the log
        logprobs = torch.stack([F.softmax(m.logit(output[i]), dim=1) for i,m in enumerate(self.models)], 2).mul(self.weights).div(self.weights.sum()).sum(-1).log()

        return logprobs, self.pack_state(state)

    def _prepare_feature(self, *args):
        # Call every model's _prepare_feature and regroup the results so each
        # returned element is a tuple with one entry per model.
        return tuple(zip(*[m._prepare_feature(*args) for m in self.models]))

    def _old_sample_beam(self, fc_feats, att_feats, att_masks=None, opt={}):
        """Beam search one image at a time; returns ``(seq, seqLogprobs)``
        transposed so that the batch dimension comes first."""
        beam_size = opt.get('beam_size', 10)
        batch_size = fc_feats.size(0)

        # from here on these names hold per-model tuples
        fc_feats, att_feats, p_att_feats, att_masks = self._prepare_feature(fc_feats, att_feats, att_masks)

        assert beam_size <= self.vocab_size + 1, 'lets assume this for now, otherwise this corner case causes a few headaches down the road. can be dealt with in future if needed'
        seq = torch.LongTensor(self.seq_length, batch_size).zero_()
        seqLogprobs = torch.FloatTensor(self.seq_length, batch_size, self.vocab_size + 1)
        # lets process every image independently for now, for simplicity

        self.done_beams = [[] for _ in range(batch_size)]
        for k in range(batch_size):
            state = self.init_hidden(beam_size)
            # expand image k's features to beam_size rows, separately for each model
            tmp_fc_feats = [fc_feats[i][k:k+1].expand(beam_size, fc_feats[i].size(1)) for i,m in enumerate(self.models)]
            tmp_att_feats = [att_feats[i][k:k+1].expand(*((beam_size,)+att_feats[i].size()[1:])).contiguous() for i,m in enumerate(self.models)]
            tmp_p_att_feats = [p_att_feats[i][k:k+1].expand(*((beam_size,)+p_att_feats[i].size()[1:])).contiguous() for i,m in enumerate(self.models)]
            tmp_att_masks = [att_masks[i][k:k+1].expand(*((beam_size,)+att_masks[i].size()[1:])).contiguous() if att_masks[i] is not None else att_masks[i] for i,m in enumerate(self.models)]

            # first input token: index 0 for every beam
            it = fc_feats[0].data.new(beam_size).long().zero_()
            logprobs, state = self.get_logprobs_state(it, tmp_fc_feats, tmp_att_feats, tmp_p_att_feats, tmp_att_masks, state)

            self.done_beams[k] = self.old_beam_search(state, logprobs, tmp_fc_feats, tmp_att_feats, tmp_p_att_feats, tmp_att_masks, opt=opt)
            seq[:, k] = self.done_beams[k][0]['seq'] # the first beam has highest cumulative score
            seqLogprobs[:, k] = self.done_beams[k][0]['logps']
        # return the samples and their log likelihoods
        return seq.transpose(0, 1), seqLogprobs.transpose(0, 1)
def __init__(self, opt):
    """Build the shared embedding, feature-projection and logit layers.

    Required ``opt`` attributes: vocab_size, input_encoding_size, rnn_size,
    num_layers, drop_prob_lm, fc_feat_size, att_feat_size, att_hid_size and
    vocab (plus seq_length when max_length is falsy).  Optional ones, read
    with defaults: max_length (20), bos_idx / eos_idx / pad_idx (0),
    use_bn (0) and logit_layers (1).
    """
    super(AttModel, self).__init__()
    self.vocab_size = opt.vocab_size
    self.input_encoding_size = opt.input_encoding_size
    #self.rnn_type = opt.rnn_type
    self.rnn_size = opt.rnn_size
    self.num_layers = opt.num_layers
    self.drop_prob_lm = opt.drop_prob_lm
    self.seq_length = getattr(opt, 'max_length', 20) or opt.seq_length # maximum sample length
    self.fc_feat_size = opt.fc_feat_size
    self.att_feat_size = opt.att_feat_size
    self.att_hid_size = opt.att_hid_size

    self.bos_idx = getattr(opt, 'bos_idx', 0)
    self.eos_idx = getattr(opt, 'eos_idx', 0)
    self.pad_idx = getattr(opt, 'pad_idx', 0)

    self.use_bn = getattr(opt, 'use_bn', 0)

    self.ss_prob = 0.0 # Schedule sampling probability

    self.embed = nn.Sequential(nn.Embedding(self.vocab_size + 1, self.input_encoding_size),
                               nn.ReLU(),
                               nn.Dropout(self.drop_prob_lm))
    self.fc_embed = nn.Sequential(nn.Linear(self.fc_feat_size, self.rnn_size),
                                  nn.ReLU(),
                                  nn.Dropout(self.drop_prob_lm))
    # Optional BatchNorm before the projection (use_bn truthy) and after it (use_bn == 2).
    self.att_embed = nn.Sequential(*(
        ((nn.BatchNorm1d(self.att_feat_size),) if self.use_bn else ()) +
        (nn.Linear(self.att_feat_size, self.rnn_size),
         nn.ReLU(),
         nn.Dropout(self.drop_prob_lm)) +
        ((nn.BatchNorm1d(self.rnn_size),) if self.use_bn == 2 else ())))

    self.logit_layers = getattr(opt, 'logit_layers', 1)
    if self.logit_layers == 1:
        self.logit = nn.Linear(self.rnn_size, self.vocab_size + 1)
    else:
        # (logit_layers - 1) hidden Linear/ReLU/Dropout groups followed by the
        # vocabulary projection.  Built with a plain loop: `reduce` is not a
        # builtin on Python 3, so the previous reduce-based flattening raised
        # NameError here.
        hidden_layers = []
        for _ in range(self.logit_layers - 1):
            hidden_layers += [nn.Linear(self.rnn_size, self.rnn_size), nn.ReLU(), nn.Dropout(0.5)]
        self.logit = nn.Sequential(*(hidden_layers + [nn.Linear(self.rnn_size, self.vocab_size + 1)]))
    self.ctx2att = nn.Linear(self.rnn_size, self.att_hid_size)

    # For removing bad endings: vocabulary indices of the words in bad_endings
    self.vocab = opt.vocab
    self.bad_endings_ix = [int(k) for k, v in self.vocab.items() if v in bad_endings]
def _forward(self, fc_feats, att_feats, seq, att_masks=None):
    """Teacher-forced pass over ``seq``, with optional scheduled sampling.

    Returns a tensor of size (batch_size * seq_per_img, seq.size(1),
    vocab_size + 1) holding the per-step output of ``get_logprobs_state``;
    steps after an early break are left as zeros.
    """
    batch_size = fc_feats.size(0)
    if seq.ndim == 3:  # B * seq_per_img * seq_len
        seq = seq.reshape(-1, seq.shape[2])
    seq_per_img = seq.shape[0] // batch_size
    state = self.init_hidden(batch_size*seq_per_img)

    outputs = fc_feats.new_zeros(batch_size*seq_per_img, seq.size(1), self.vocab_size+1)

    # Prepare the features
    p_fc_feats, p_att_feats, pp_att_feats, p_att_masks = self._prepare_feature(fc_feats, att_feats, att_masks)
    # pp_att_feats is used for attention, we cache it in advance to reduce computation cost

    # several captions per image: repeat the image features to line up with seq
    if seq_per_img > 1:
        p_fc_feats, p_att_feats, pp_att_feats, p_att_masks = utils.repeat_tensors(seq_per_img,
            [p_fc_feats, p_att_feats, pp_att_feats, p_att_masks]
        )

    for i in range(seq.size(1)):
        if self.training and i >= 1 and self.ss_prob > 0.0: # otherwise no need to sample
            # scheduled sampling: with probability ss_prob per row, feed a token
            # drawn from the previous step's predicted distribution instead of
            # the ground-truth token
            sample_prob = fc_feats.new(batch_size*seq_per_img).uniform_(0, 1)
            sample_mask = sample_prob < self.ss_prob
            if sample_mask.sum() == 0:
                it = seq[:, i].clone()
            else:
                sample_ind = sample_mask.nonzero().view(-1)
                it = seq[:, i].data.clone()
                prob_prev = torch.exp(outputs[:, i-1].detach()) # fetch prev distribution: shape Nx(M+1)
                it.index_copy_(0, sample_ind, torch.multinomial(prob_prev, 1).view(-1).index_select(0, sample_ind))
        else:
            it = seq[:, i].clone()
        # break if all the sequences end (every token in this column is 0)
        if i >= 1 and seq[:, i].sum() == 0:
            break

        output, state = self.get_logprobs_state(it, p_fc_feats, p_att_feats, pp_att_feats, p_att_masks, state)
        outputs[:, i] = output

    return outputs
F.log_softmax(self.logit(output), dim=1) + else: + logprobs = self.logit(output) + + return logprobs, state + + def _old_sample_beam(self, fc_feats, att_feats, att_masks=None, opt={}): + beam_size = opt.get('beam_size', 10) + group_size = opt.get('group_size', 1) + sample_n = opt.get('sample_n', 10) + # when sample_n == beam_size then each beam is a sample. + assert sample_n == 1 or sample_n == beam_size // group_size, 'when beam search, sample_n == 1 or beam search' + batch_size = fc_feats.size(0) + + p_fc_feats, p_att_feats, pp_att_feats, p_att_masks = self._prepare_feature(fc_feats, att_feats, att_masks) + + assert beam_size <= self.vocab_size + 1, 'lets assume this for now, otherwise this corner case causes a few headaches down the road. can be dealt with in future if needed' + seq = fc_feats.new_full((batch_size*sample_n, self.seq_length), self.pad_idx, dtype=torch.long) + seqLogprobs = fc_feats.new_zeros(batch_size*sample_n, self.seq_length, self.vocab_size + 1) + # lets process every image independently for now, for simplicity + + self.done_beams = [[] for _ in range(batch_size)] + for k in range(batch_size): + state = self.init_hidden(beam_size) + tmp_fc_feats, tmp_att_feats, tmp_p_att_feats, tmp_att_masks = utils.repeat_tensors(beam_size, + [p_fc_feats[k:k+1], p_att_feats[k:k+1], pp_att_feats[k:k+1], p_att_masks[k:k+1] if att_masks is not None else None] + ) + + for t in range(1): + if t == 0: # input + it = fc_feats.new_full([beam_size], self.bos_idx, dtype=torch.long) + + logprobs, state = self.get_logprobs_state(it, tmp_fc_feats, tmp_att_feats, tmp_p_att_feats, tmp_att_masks, state) + + self.done_beams[k] = self.old_beam_search(state, logprobs, tmp_fc_feats, tmp_att_feats, tmp_p_att_feats, tmp_att_masks, opt=opt) + if sample_n == beam_size: + for _n in range(sample_n): + seq[k*sample_n+_n, :] = self.done_beams[k][_n]['seq'] + seqLogprobs[k*sample_n+_n, :] = self.done_beams[k][_n]['logps'] + else: + seq[k, :] = self.done_beams[k][0]['seq'] # the first 
def _sample_beam(self, fc_feats, att_feats, att_masks=None, opt={}):
    """Decode with batched beam search.

    Returns ``(seq, seqLogprobs)``.  ``seq`` is pre-filled with ``pad_idx`` and
    ``seqLogprobs`` with zeros; each finished beam is written into the leading
    positions of its row.  When ``sample_n == beam_size`` every beam of an
    image is kept, otherwise only the first (highest-scoring) beam.
    """
    beam_size = opt.get('beam_size', 10)
    group_size = opt.get('group_size', 1)
    sample_n = opt.get('sample_n', 10)
    # when sample_n == beam_size then each beam is a sample.
    assert sample_n == 1 or sample_n == beam_size // group_size, 'when beam search, sample_n == 1 or beam search'
    batch_size = fc_feats.size(0)

    prepared = self._prepare_feature(fc_feats, att_feats, att_masks)
    p_fc_feats, p_att_feats, pp_att_feats, p_att_masks = prepared

    assert beam_size <= self.vocab_size + 1, 'lets assume this for now, otherwise this corner case causes a few headaches down the road. can be dealt with in future if needed'
    n_rows = batch_size * sample_n
    seq = fc_feats.new_full((n_rows, self.seq_length), self.pad_idx, dtype=torch.long)
    seqLogprobs = fc_feats.new_zeros(n_rows, self.seq_length, self.vocab_size + 1)

    self.done_beams = [[] for _ in range(batch_size)]

    state = self.init_hidden(batch_size)

    # first step: feed the bos token to every image
    bos = fc_feats.new_full([batch_size], self.bos_idx, dtype=torch.long)
    logprobs, state = self.get_logprobs_state(bos, p_fc_feats, p_att_feats, pp_att_feats, p_att_masks, state)

    # repeat the prepared features beam_size times before the search
    p_fc_feats, p_att_feats, pp_att_feats, p_att_masks = utils.repeat_tensors(
        beam_size, [p_fc_feats, p_att_feats, pp_att_feats, p_att_masks])
    self.done_beams = self.beam_search(state, logprobs, p_fc_feats, p_att_feats, pp_att_feats, p_att_masks, opt=opt)

    keep_every_beam = sample_n == beam_size
    for k in range(batch_size):
        if keep_every_beam:
            # each beam of image k gets its own output row
            targets = [(k * sample_n + b, b) for b in range(sample_n)]
        else:
            # the first beam has highest cumulative score
            targets = [(k, 0)]
        for row, b in targets:
            length = self.done_beams[k][b]['seq'].shape[0]
            seq[row, :length] = self.done_beams[k][b]['seq']
            seqLogprobs[row, :length] = self.done_beams[k][b]['logps']
    # return the samples and their log likelihoods
    return seq, seqLogprobs
logprobs = logprobs + tmp + + if remove_bad_endings and t > 0: + tmp = logprobs.new_zeros(logprobs.size()) + prev_bad = np.isin(seq[:,t-1].data.cpu().numpy(), self.bad_endings_ix) + # Make it impossible to generate bad_endings + tmp[torch.from_numpy(prev_bad.astype('uint8')), 0] = float('-inf') + logprobs = logprobs + tmp + + # Mess with trigrams + # Copy from https://github.com/lukemelas/image-paragraph-captioning + if block_trigrams and t >= 3: + # Store trigram generated at last step + prev_two_batch = seq[:,t-3:t-1] + for i in range(batch_size): # = seq.size(0) + prev_two = (prev_two_batch[i][0].item(), prev_two_batch[i][1].item()) + current = seq[i][t-1] + if t == 3: # initialize + trigrams.append({prev_two: [current]}) # {LongTensor: list containing 1 int} + elif t > 3: + if prev_two in trigrams[i]: # add to list + trigrams[i][prev_two].append(current) + else: # create list + trigrams[i][prev_two] = [current] + # Block used trigrams at next step + prev_two_batch = seq[:,t-2:t] + mask = torch.zeros(logprobs.size(), requires_grad=False).to(logprobs.device) # batch_size x vocab_size + for i in range(batch_size): + prev_two = (prev_two_batch[i][0].item(), prev_two_batch[i][1].item()) + if prev_two in trigrams[i]: + for j in trigrams[i][prev_two]: + mask[i,j] += 1 + # Apply mask to log probs + #logprobs = logprobs - (mask * 1e9) + alpha = 2.0 # = 4 + logprobs = logprobs + (mask * -0.693 * alpha) # ln(1/2) * alpha (alpha -> infty works best) + + # sample the next word + if t == self.seq_length: # skip if we achieve maximum length + break + it, sampleLogprobs = self.sample_next_word(logprobs, sample_method, temperature) + + # stop when all finished + if t == 0: + unfinished = it != self.eos_idx + else: + it[~unfinished] = self.pad_idx # This allows eos_idx not being overwritten to 0 + logprobs = logprobs * unfinished.unsqueeze(1).to(logprobs) + unfinished = unfinished & (it != self.eos_idx) + seq[:,t] = it + seqLogprobs[:,t] = logprobs + # quit loop if all sequences 
def _diverse_sample(self, fc_feats, att_feats, att_masks=None, opt={}):
    """Decode ``group_size`` captions per image, discouraging repeats across groups.

    Group ``divm`` runs ``divm`` steps behind group 0 (``t = tt - divm``), so
    at its own step ``t`` it can see which token each earlier group picked at
    that same step and subtract ``diversity_lambda`` from that token's
    log-probability.

    Options read from ``opt``: ``sample_method`` ('greedy'), ``temperature``
    (1.0), ``group_size`` (1), ``diversity_lambda`` (0.5),
    ``decoding_constraint`` (0), ``block_trigrams`` (0),
    ``remove_bad_endings`` (0).

    Returns:
        (seq, seqLogprobs), both ``(batch * group_size, seq_length)``; the
        rows of one image are adjacent (image-major, group-minor).
        ``seqLogprobs`` holds the value returned by ``self.sample_next_word``
        for each chosen token.
    """
    sample_method = opt.get('sample_method', 'greedy')
    temperature = opt.get('temperature', 1.0)
    group_size = opt.get('group_size', 1)
    diversity_lambda = opt.get('diversity_lambda', 0.5)
    decoding_constraint = opt.get('decoding_constraint', 0)
    block_trigrams = opt.get('block_trigrams', 0)
    remove_bad_endings = opt.get('remove_bad_endings', 0)

    batch_size = fc_feats.size(0)

    p_fc_feats, p_att_feats, pp_att_feats, p_att_masks = self._prepare_feature(fc_feats, att_feats, att_masks)

    # Per group: one {(w1, w2): [w3, ...]} dict per image, filled from t == 3.
    trigrams_table = [[] for _ in range(group_size)]

    seq_table = [fc_feats.new_full((batch_size, self.seq_length), self.pad_idx, dtype=torch.long) for _ in range(group_size)]
    seqLogprobs_table = [fc_feats.new_zeros(batch_size, self.seq_length) for _ in range(group_size)]
    state_table = [self.init_hidden(batch_size) for _ in range(group_size)]

    for tt in range(self.seq_length + group_size):
        for divm in range(group_size):
            t = tt - divm
            if t < 0 or t > self.seq_length - 1:
                continue  # this group has not started yet / is already done

            seq = seq_table[divm]
            seqLogprobs = seqLogprobs_table[divm]
            trigrams = trigrams_table[divm]

            if t == 0:  # input
                it = fc_feats.new_full([batch_size], self.bos_idx, dtype=torch.long)
            else:
                it = seq[:, t-1]

            logprobs, state_table[divm] = self.get_logprobs_state(
                it, p_fc_feats, p_att_feats, pp_att_feats, p_att_masks, state_table[divm])
            logprobs = F.log_softmax(logprobs / temperature, dim=-1)

            # Add diversity: penalise, row by row, the token every earlier
            # group chose at this step. (The previous column indexing
            # `logprobs[:, prev_decisions]` applied each image's penalty to
            # all images in the batch.)
            for prev_choice in range(divm):
                prev_decisions = seq_table[prev_choice][:, t]
                penalty = logprobs.new_full((prev_decisions.size(0), 1), -diversity_lambda)
                logprobs = logprobs.scatter_add(1, prev_decisions.unsqueeze(1), penalty)

            if decoding_constraint and t > 0:
                # Forbid emitting the same token twice in a row.
                tmp = logprobs.new_zeros(logprobs.size())
                tmp.scatter_(1, seq[:, t-1].data.unsqueeze(1), float('-inf'))
                logprobs = logprobs + tmp

            if remove_bad_endings and t > 0:
                # If the previous word is a "bad ending", make token 0
                # impossible at this step.
                tmp = logprobs.new_zeros(logprobs.size())
                prev_bad = np.isin(seq[:, t-1].data.cpu().numpy(), self.bad_endings_ix)
                bad_rows = torch.from_numpy(prev_bad).to(tmp.device)  # bool mask
                tmp[bad_rows, 0] = float('-inf')
                logprobs = logprobs + tmp

            # Mess with trigrams
            if block_trigrams and t >= 3:
                # Store trigram generated at last step
                prev_two_batch = seq[:, t-3:t-1]
                for i in range(batch_size):
                    prev_two = (prev_two_batch[i][0].item(), prev_two_batch[i][1].item())
                    current = seq[i][t-1].item()
                    if t == 3:  # initialize
                        trigrams.append({prev_two: [current]})
                    else:
                        trigrams[i].setdefault(prev_two, []).append(current)
                # Block used trigrams at next step
                prev_two_batch = seq[:, t-2:t]
                # Built on the device of `logprobs` so CPU decoding works too.
                mask = torch.zeros(logprobs.size(), device=logprobs.device)  # batch_size x vocab_size
                for i in range(batch_size):
                    prev_two = (prev_two_batch[i][0].item(), prev_two_batch[i][1].item())
                    for j in trigrams[i].get(prev_two, ()):
                        mask[i, j] += 1
                # Apply mask to log probs
                alpha = 2.0
                logprobs = logprobs + (mask * -0.693 * alpha)  # ln(1/2) * alpha (alpha -> infty works best)

            it, sampleLogprobs = self.sample_next_word(logprobs, sample_method, 1)

            if t > 0:
                # Rows whose previous token was pad/eos are finished: keep padding.
                prev_tokens = seq[:, t-1]
                finished = (prev_tokens == self.pad_idx) | (prev_tokens == self.eos_idx)
                it[finished] = self.pad_idx
            seq[:, t] = it
            seqLogprobs[:, t] = sampleLogprobs.view(-1)

    return (torch.stack(seq_table, 1).reshape(batch_size * group_size, -1),
            torch.stack(seqLogprobs_table, 1).reshape(batch_size * group_size, -1))
class AdaAtt_attention(nn.Module):
    """Attention over the image regions plus one extra "fake region" slot.

    The LSTM's fake-region vector is embedded and prepended to the region
    features, so the softmax runs over ``att_size + 1`` slots.

    NOTE(review): ``forward`` concatenates the embedded fake region
    (``input_encoding_size`` wide) with ``conv_feat`` (``rnn_size`` wide) and
    adds the attended vector to ``h_out_linear``, so the module appears to
    require ``input_encoding_size == rnn_size`` - confirm against the configs.
    """

    def __init__(self, opt):
        super(AdaAtt_attention, self).__init__()
        self.input_encoding_size = opt.input_encoding_size
        #self.rnn_type = opt.rnn_type
        self.rnn_size = opt.rnn_size
        self.drop_prob_lm = opt.drop_prob_lm
        self.att_hid_size = opt.att_hid_size

        # fake region embed
        self.fr_linear = nn.Sequential(
            nn.Linear(self.rnn_size, self.input_encoding_size),
            nn.ReLU(),
            nn.Dropout(self.drop_prob_lm))
        self.fr_embed = nn.Linear(self.input_encoding_size, self.att_hid_size)

        # h out embed
        self.ho_linear = nn.Sequential(
            nn.Linear(self.rnn_size, self.input_encoding_size),
            nn.Tanh(),
            nn.Dropout(self.drop_prob_lm))
        self.ho_embed = nn.Linear(self.input_encoding_size, self.att_hid_size)

        self.alpha_net = nn.Linear(self.att_hid_size, 1)
        self.att2h = nn.Linear(self.rnn_size, self.rnn_size)

    def forward(self, h_out, fake_region, conv_feat, conv_feat_embed, att_masks=None):
        """Return the attended hidden vector, one row per batch element.

        Args:
            h_out: LSTM output, last dimension ``rnn_size``.
            fake_region: fake-region vector, last dimension ``rnn_size``.
            conv_feat: region features; viewed as ``(-1, att_size, rnn_size)``.
            conv_feat_embed: projected region features; viewed as
                ``(-1, att_size, att_hid_size)``.
            att_masks: optional mask, viewed as ``(-1, att_size)``; its first
                column is reused as the mask of the fake-region slot.
        """
        # View into three dimensions
        att_size = conv_feat.numel() // conv_feat.size(0) // self.rnn_size
        conv_feat = conv_feat.view(-1, att_size, self.rnn_size)
        conv_feat_embed = conv_feat_embed.view(-1, att_size, self.att_hid_size)

        fake_region = self.fr_linear(fake_region)
        fake_region_embed = self.fr_embed(fake_region)

        h_out_linear = self.ho_linear(h_out)
        h_out_embed = self.ho_embed(h_out_linear)

        # Broadcast the text embedding over the att_size + 1 slots.
        txt_replicate = h_out_embed.unsqueeze(1).expand(h_out_embed.size(0), att_size + 1, h_out_embed.size(1))

        img_all = torch.cat([fake_region.view(-1, 1, self.input_encoding_size), conv_feat], 1)
        # fr_embed outputs att_hid_size features, so that is the width to
        # view with (viewing with input_encoding_size only worked when the
        # two sizes happened to be equal).
        img_all_embed = torch.cat([fake_region_embed.view(-1, 1, self.att_hid_size), conv_feat_embed], 1)

        hA = torch.tanh(img_all_embed + txt_replicate)
        hA = F.dropout(hA, self.drop_prob_lm, self.training)

        hAflat = self.alpha_net(hA.view(-1, self.att_hid_size))
        PI = F.softmax(hAflat.view(-1, att_size + 1), dim=1)

        if att_masks is not None:
            att_masks = att_masks.view(-1, att_size).to(PI)
            # assume a one at the first time step, reused for the fake region.
            PI = PI * torch.cat([att_masks[:, :1], att_masks], 1)
            PI = PI / PI.sum(1, keepdim=True)

        visAtt = torch.bmm(PI.unsqueeze(1), img_all)
        visAttdim = visAtt.squeeze(1)

        atten_out = visAttdim + h_out_linear

        h = torch.tanh(self.att2h(atten_out))
        h = F.dropout(h, self.drop_prob_lm, self.training)
        return h
class StackAttCore(nn.Module):
    """Three stacked LSTM cores with an attention step before the 2nd and 3rd.

    Per the note in this file, this is an ad-hoc design that does not
    correspond to a paper.
    """

    def __init__(self, opt, use_maxout=False):
        super(StackAttCore, self).__init__()
        self.drop_prob_lm = opt.drop_prob_lm

        self.att1 = Attention(opt)
        self.att2 = Attention(opt)

        # LSTMCore is configured through opt.input_encoding_size, so the
        # shared `opt` is overridden temporarily. try/finally guarantees the
        # caller's value is restored even if a constructor raises.
        opt_input_encoding_size = opt.input_encoding_size
        try:
            # lstm0 is fed cat([xt, fc_feats]) in forward()
            opt.input_encoding_size = opt_input_encoding_size + opt.rnn_size
            self.lstm0 = LSTMCore(opt)
            # lstm1 / lstm2 are fed cat([hidden, attention result])
            opt.input_encoding_size = opt.rnn_size * 2
            self.lstm1 = LSTMCore(opt)
            self.lstm2 = LSTMCore(opt)
        finally:
            opt.input_encoding_size = opt_input_encoding_size

        self.emb2 = nn.Linear(opt.rnn_size, opt.rnn_size)

    def forward(self, xt, fc_feats, att_feats, p_att_feats, state, att_masks=None):
        """Run one decoding step.

        ``state`` is a pair ``(h, c)`` whose first dimension indexes the three
        layers; layer ``k`` receives the slice ``[k:k+1]``. Returns the top
        layer output and the new state as a two-element list with the three
        layers' entries concatenated along dimension 0.
        """
        h_0, state_0 = self.lstm0(torch.cat([xt, fc_feats], 1), [state[0][0:1], state[1][0:1]])
        att_res_1 = self.att1(h_0, att_feats, p_att_feats, att_masks)
        h_1, state_1 = self.lstm1(torch.cat([h_0, att_res_1], 1), [state[0][1:2], state[1][1:2]])
        # The second attention is queried with h_1 plus a projection of the
        # first attention result.
        att_res_2 = self.att2(h_1 + self.emb2(att_res_1), att_feats, p_att_feats, att_masks)
        h_2, state_2 = self.lstm2(torch.cat([h_1, att_res_2], 1), [state[0][2:3], state[1][2:3]])

        return h_2, [torch.cat(_, 0) for _ in zip(state_0, state_1, state_2)]
class Attention(nn.Module):
    """Additive attention over a set of region features.

    The regions are scored with ``alpha_net(tanh(p_att_feats + h2att(h)))``,
    the scores are soft-maxed over the regions and used to average
    ``att_feats``.
    """

    def __init__(self, opt):
        super(Attention, self).__init__()
        self.rnn_size = opt.rnn_size
        self.att_hid_size = opt.att_hid_size

        self.h2att = nn.Linear(self.rnn_size, self.att_hid_size)
        self.alpha_net = nn.Linear(self.att_hid_size, 1)

    def forward(self, h, att_feats, p_att_feats, att_masks=None):
        """Return the attention-weighted sum of ``att_feats`` (batch x feat).

        ``p_att_feats`` is the already-projected version of ``att_feats``
        (last dimension ``att_hid_size``). When ``att_masks`` is given, the
        soft-max weights are multiplied by it and renormalised to sum to 1.
        """
        feat_dim = att_feats.size(-1)
        num_regions = att_feats.numel() // att_feats.size(0) // feat_dim

        projected = p_att_feats.view(-1, num_regions, self.att_hid_size)
        # batch x att_hid_size -> batch x num_regions x att_hid_size
        query = self.h2att(h).unsqueeze(1).expand_as(projected)

        hidden = torch.tanh(projected + query)
        scores = self.alpha_net(hidden.view(-1, self.att_hid_size))  # (batch * num_regions) x 1
        scores = scores.view(-1, num_regions)                        # batch x num_regions

        weight = F.softmax(scores, dim=1)
        if att_masks is not None:
            weight = weight * att_masks.view(-1, num_regions).to(weight)
            weight = weight / weight.sum(1, keepdim=True)  # normalize to 1

        regions = att_feats.view(-1, num_regions, feat_dim)  # batch x num_regions x feat_dim
        return torch.bmm(weight.unsqueeze(1), regions).squeeze(1)
class Att2inCore(Att2in2Core):
    """``Att2in2Core`` whose ``a2c`` projection reads ``att_feat_size`` inputs.

    The parent builds ``a2c`` as ``Linear(rnn_size, 2 * rnn_size)``; this
    subclass replaces it with ``Linear(att_feat_size, 2 * rnn_size)``.
    """

    def __init__(self, opt):
        super(Att2inCore, self).__init__(opt)
        # Drop the parent's projection before registering the replacement.
        delattr(self, 'a2c')
        self.a2c = nn.Linear(self.att_feat_size, 2 * self.rnn_size)
class Att2inModel(AttModel):
    """``AttModel`` configured with an ``Att2inCore`` decoder.

    The word embedding is rebuilt as a plain ``nn.Embedding``, ``fc_embed``
    and ``att_embed`` become the identity, and ``ctx2att`` is rebuilt to
    project ``att_feat_size`` to ``att_hid_size``.
    """

    def __init__(self, opt):
        super(Att2inModel, self).__init__(opt)

        for name in ('embed', 'fc_embed', 'att_embed'):
            delattr(self, name)
        self.embed = nn.Embedding(self.vocab_size + 1, self.input_encoding_size)

        def _identity(x):
            return x

        # Both attributes share the same pass-through callable.
        self.fc_embed = _identity
        self.att_embed = _identity

        delattr(self, 'ctx2att')
        self.ctx2att = nn.Linear(self.att_feat_size, self.att_hid_size)

        self.core = Att2inCore(opt)
        self.init_weights()

    def init_weights(self):
        """Uniform(-0.1, 0.1) for embedding and logit weights; zero logit bias."""
        bound = 0.1
        self.embed.weight.data.uniform_(-bound, bound)
        self.logit.bias.data.fill_(0)
        self.logit.weight.data.uniform_(-bound, bound)
class LMModel(AttModel):
    """Caption model that ignores the image.

    ``fc_embed`` returns an all-zero ``(batch, input_encoding_size)`` tensor
    and ``_prepare_feature`` returns ``None`` for every attention input, so
    only the word sequence drives the wrapped ``LSTMCore``.
    """

    def __init__(self, opt):
        super(LMModel, self).__init__(opt)

        def _zero_fc(x):
            # input_encoding_size is read at call time, as in a closure over self.
            return x.new_zeros(x.shape[0], self.input_encoding_size)

        def _passthrough(x):
            return x

        delattr(self, 'fc_embed')
        self.fc_embed = _zero_fc
        self.embed = nn.Embedding(self.vocab_size + 1, self.input_encoding_size)
        self._core = LSTMCore(opt)
        delattr(self, 'att_embed')
        self.att_embed = _passthrough
        delattr(self, 'ctx2att')
        self.ctx2att = _passthrough

    def core(self, xt, fc_feats, att_feats, p_att_feats, state, att_masks):
        """Advance the LSTM by one word.

        When ``state[0]`` is entirely zero, ``fc_feats`` is fed through the
        LSTM once before the word input ``xt`` (diverse beam search is not
        handled specially here).
        """
        state_is_blank = (state[0] == 0).all()
        if state_is_blank:
            _, state = self._core(fc_feats, state)
        return self._core(xt, state)

    def _prepare_feature(self, fc_feats, att_feats, att_masks):
        """Return the embedded fc features and ``None`` for the attention inputs."""
        return self.fc_embed(fc_feats), None, None, None
class BertCapModel(TransformerModel):
    """``TransformerModel`` whose encoder and decoder are HuggingFace ``BertModel``s."""

    def make_model(self, src_vocab, tgt_vocab, N_enc=6, N_dec=6,
                   d_model=512, d_ff=2048, h=8, dropout=0.1):
        "Helper: Construct a model from hyperparameters."
        # Settings shared by the encoder and the decoder configuration.
        common = dict(hidden_size=d_model,
                      num_attention_heads=h,
                      intermediate_size=d_ff,
                      hidden_dropout_prob=dropout,
                      attention_probs_dropout_prob=dropout,
                      type_vocab_size=1)
        enc_config = BertConfig(vocab_size=1,
                                num_hidden_layers=N_enc,
                                max_position_embeddings=1,
                                **common)
        # NOTE(review): max_position_embeddings=17 is hard-coded; presumably
        # tied to the maximum caption length - confirm.
        dec_config = BertConfig(vocab_size=tgt_vocab,
                                num_hidden_layers=N_dec,
                                max_position_embeddings=17,
                                is_decoder=True,
                                **common)

        encoder = BertModel(enc_config)

        def return_embeds(*args, **kwargs):
            # The encoder is fed ready-made embeddings; pass them through.
            return kwargs['inputs_embeds']

        del encoder.embeddings
        encoder.embeddings = return_embeds

        decoder = BertModel(dec_config)
        return EncoderDecoder(encoder, decoder, Generator(d_model, tgt_vocab))

    def __init__(self, opt):
        super(BertCapModel, self).__init__(opt)

    def core(self, it, fc_feats_ph, att_feats_ph, memory, state, mask):
        """
        Decode one step; ``state = [ys.unsqueeze(0)]`` carries the tokens so far.

        Returns the decoder output at the last position and the updated state.
        """
        new_token = it.unsqueeze(1)
        if len(state) == 0:
            ys = new_token
        else:
            ys = torch.cat([state[0][0], new_token], dim=1)
        causal_mask = subsequent_mask(ys.size(1)).to(memory.device)
        out = self.model.decode(memory, mask, ys, causal_mask)
        return out[:, -1], [ys.unsqueeze(0)]
class CaptionModel(nn.Module):
    """Base class for caption models: mode dispatch, beam search and sampling.

    The search routines read attributes and methods that this class does not
    define itself, so subclasses must provide them: ``seq_length``,
    ``vocab_size``, ``eos_idx`` and ``get_logprobs_state``; the optional
    decoding constraints additionally read ``bad_endings_ix`` and ``vocab``.
    """

    def __init__(self):
        super(CaptionModel, self).__init__()

    def forward(self, *args, **kwargs):
        """Dispatch to ``self._<mode>``; ``mode`` defaults to ``'forward'``.

        The ``mode`` keyword is consumed here and not passed on.
        """
        mode = kwargs.pop('mode', 'forward')
        return getattr(self, '_' + mode)(*args, **kwargs)

    # implements beam search
    # calls beam_step and returns the final set of beams
    # augments log-probabilities with diversity terms when number of groups > 1
    def beam_search(self, init_state, init_logprobs, *args, **kwargs):
        """Batched (diverse) beam search.

        Args:
            init_state: iterable of state tensors; cloned once per group.
            init_logprobs: log-probabilities of the first step; its first
                dimension is used as the batch size.
            *args: extra inputs, split per group with
                ``model_utils.split_tensors`` and forwarded to
                ``self.get_logprobs_state``.
            **kwargs: must contain ``opt``, a dict of decoding options
                (``beam_size``, ``group_size``, ``diversity_lambda``, ...).

        Returns:
            A list with one entry per batch item; each entry is a list of
            beam dicts (``seq``, ``logps``, ``unaug_p``, ``p``). Beams are
            sorted by ``p`` within each group and the groups are concatenated.
        """

        # function computes the similarity score to be augmented
        def add_diversity(beam_seq_table, logprobs, t, divm, diversity_lambda, bdash):
            local_time = t - divm
            unaug_logprobs = logprobs.clone()
            batch_size = beam_seq_table[0].shape[0]

            if divm > 0:
                # Count how often each word was picked by the earlier groups
                # at this local time step.
                change = logprobs.new_zeros(batch_size, logprobs.shape[-1])
                for prev_choice in range(divm):
                    prev_decisions = beam_seq_table[prev_choice][:, :, local_time]  # Nxb
                    for prev_labels in range(bdash):
                        change.scatter_add_(1, prev_decisions[:, prev_labels].unsqueeze(-1), change.new_ones(batch_size, 1))

                if local_time == 0:
                    logprobs = logprobs - change * diversity_lambda
                else:
                    # After the first local step there are bdash rows per
                    # batch item, so the penalty is repeated per beam.
                    # ``repeat_tensor`` is not defined on this class: use a
                    # subclass-provided one if present, otherwise the shared
                    # helper from the models' utils module.
                    repeat = getattr(self, 'repeat_tensor', None) or model_utils.repeat_tensors
                    logprobs = logprobs - repeat(bdash, change) * diversity_lambda

            return logprobs, unaug_logprobs

        # does one step of classical beam search
        def beam_step(logprobs, unaug_logprobs, beam_size, t, beam_seq, beam_seq_logprobs, beam_logprobs_sum, state):
            #INPUTS:
            #logprobs: probabilities augmented after diversity N*bxV
            #beam_size: obvious
            #t        : time instant
            #beam_seq : tensor containing the beams
            #beam_seq_logprobs: tensor containing the beam logprobs
            #beam_logprobs_sum: tensor containing joint logprobs
            #OUTPUTS:
            #beam_seq : tensor containing the word indices of the decoded captions Nxbxl
            #beam_seq_logprobs : log-probability of each decision made, NxbxlxV
            #beam_logprobs_sum : joint log-probability of each beam Nxb

            batch_size = beam_logprobs_sum.shape[0]
            vocab_size = logprobs.shape[-1]
            logprobs = logprobs.reshape(batch_size, -1, vocab_size)  # NxbxV
            if t == 0:
                # At the first step every batch item has a single live beam.
                assert logprobs.shape[1] == 1
                beam_logprobs_sum = beam_logprobs_sum[:, :1]
            candidate_logprobs = beam_logprobs_sum.unsqueeze(-1) + logprobs  # beam_logprobs_sum Nxb logprobs is NxbxV
            ys, ix = torch.sort(candidate_logprobs.reshape(candidate_logprobs.shape[0], -1), -1, True)
            ys, ix = ys[:,:beam_size], ix[:,:beam_size]
            beam_ix = ix // vocab_size  # Nxb which beam
            selected_ix = ix % vocab_size  # Nxb # which word
            state_ix = (beam_ix + torch.arange(batch_size).type_as(beam_ix).unsqueeze(-1) * logprobs.shape[1]).reshape(-1)  # N*b which in Nxb beams

            if t > 0:
                # gather according to beam_ix
                assert (beam_seq.gather(1, beam_ix.unsqueeze(-1).expand_as(beam_seq)) == beam_seq.reshape(-1, beam_seq.shape[-1])[state_ix].view_as(beam_seq)).all()
                beam_seq = beam_seq.gather(1, beam_ix.unsqueeze(-1).expand_as(beam_seq))

                beam_seq_logprobs = beam_seq_logprobs.gather(1, beam_ix.unsqueeze(-1).unsqueeze(-1).expand_as(beam_seq_logprobs))

            beam_seq = torch.cat([beam_seq, selected_ix.unsqueeze(-1)], -1)  # beam_seq Nxbxl
            beam_logprobs_sum = beam_logprobs_sum.gather(1, beam_ix) + \
                logprobs.reshape(batch_size, -1).gather(1, ix)
            assert (beam_logprobs_sum == ys).all()
            _tmp_beam_logprobs = unaug_logprobs[state_ix].reshape(batch_size, -1, vocab_size)
            beam_logprobs = unaug_logprobs.reshape(batch_size, -1, vocab_size).gather(1, beam_ix.unsqueeze(-1).expand(-1, -1, vocab_size))  # NxbxV
            assert (_tmp_beam_logprobs == beam_logprobs).all()
            beam_seq_logprobs = torch.cat([
                beam_seq_logprobs,
                beam_logprobs.reshape(batch_size, -1, 1, vocab_size)], 2)

            new_state = [None for _ in state]
            for _ix in range(len(new_state)):
                # copy over state in previous beam q to new beam at vix
                new_state[_ix] = state[_ix][:, state_ix]
            state = new_state
            return beam_seq,beam_seq_logprobs,beam_logprobs_sum,state

        # Start diverse_beam_search
        opt = kwargs['opt']
        temperature = opt.get('temperature', 1)  # This should not affect beam search, but will affect dbs
        beam_size = opt.get('beam_size', 10)
        group_size = opt.get('group_size', 1)
        diversity_lambda = opt.get('diversity_lambda', 0.5)
        decoding_constraint = opt.get('decoding_constraint', 0)
        remove_bad_endings = opt.get('remove_bad_endings', 0)
        suppress_UNK = opt.get('suppress_UNK', 0)
        length_penalty = utils.penalty_builder(opt.get('length_penalty', ''))
        bdash = beam_size // group_size  # beam per group

        batch_size = init_logprobs.shape[0]
        device = init_logprobs.device
        # INITIALIZATIONS
        # The third dimension (decoded length) starts at 0 and grows by one
        # per step inside beam_step.
        beam_seq_table = [torch.LongTensor(batch_size, bdash, 0).to(device) for _ in range(group_size)]
        beam_seq_logprobs_table = [torch.FloatTensor(batch_size, bdash, 0, self.vocab_size + 1).to(device) for _ in range(group_size)]
        beam_logprobs_sum_table = [torch.zeros(batch_size, bdash).to(device) for _ in range(group_size)]

        # logprobs # logprobs predicted in last time step, shape (beam_size, vocab_size+1)
        done_beams_table = [[[] for __ in range(group_size)] for _ in range(batch_size)]
        state_table = [[_.clone() for _ in init_state] for _ in range(group_size)]
        logprobs_table = [init_logprobs.clone() for _ in range(group_size)]
        # END INIT

        # Chunk elements in the args
        args = list(args)
        args = model_utils.split_tensors(group_size, args)  # For each arg, turn (Bbg)x... to (Bb)x(g)x...
        if self.__class__.__name__ == 'AttEnsemble':
            args = [[[args[j][i][k] for i in range(len(self.models))] for j in range(len(args))] for k in range(group_size)]  # group_name, arg_name, model_name
        else:
            args = [[args[i][j] for i in range(len(args))] for j in range(group_size)]

        # Group ``divm`` starts ``divm`` steps late, so the outer loop runs
        # ``group_size - 1`` extra iterations to let the last group finish.
        for t in range(self.seq_length + group_size - 1):
            for divm in range(group_size):
                if t >= divm and t <= self.seq_length + divm - 1:
                    # add diversity
                    logprobs = logprobs_table[divm]
                    # suppress previous word
                    if decoding_constraint and t-divm > 0:
                        logprobs.scatter_(1, beam_seq_table[divm][:, :, t-divm-1].reshape(-1, 1).to(device), float('-inf'))
                    if remove_bad_endings and t-divm > 0:
                        logprobs[torch.from_numpy(np.isin(beam_seq_table[divm][:, :, t-divm-1].cpu().numpy(), self.bad_endings_ix)).reshape(-1), 0] = float('-inf')
                    # suppress UNK tokens in the decoding
                    if suppress_UNK and hasattr(self, 'vocab') and self.vocab[str(logprobs.size(1)-1)] == 'UNK':
                        logprobs[:,logprobs.size(1)-1] = logprobs[:, logprobs.size(1)-1] - 1000
                    # diversity is added here
                    # add_diversity returns both the augmented scores (used to
                    # rank candidates) and the unaugmented ones (stored in the
                    # beams as the per-step log-probabilities).
                    logprobs, unaug_logprobs = add_diversity(beam_seq_table,logprobs,t,divm,diversity_lambda,bdash)

                    # infer new beams
                    beam_seq_table[divm],\
                    beam_seq_logprobs_table[divm],\
                    beam_logprobs_sum_table[divm],\
                    state_table[divm] = beam_step(logprobs,
                                                  unaug_logprobs,
                                                  bdash,
                                                  t-divm,
                                                  beam_seq_table[divm],
                                                  beam_seq_logprobs_table[divm],
                                                  beam_logprobs_sum_table[divm],
                                                  state_table[divm])

                    # if time's up... or if end token is reached then copy beams
                    for b in range(batch_size):
                        is_end = beam_seq_table[divm][b, :, t-divm] == self.eos_idx
                        assert beam_seq_table[divm].shape[-1] == t-divm+1
                        if t == self.seq_length + divm - 1:
                            is_end.fill_(1)
                        for vix in range(bdash):
                            if is_end[vix]:
                                final_beam = {
                                    'seq': beam_seq_table[divm][b, vix].clone(),
                                    'logps': beam_seq_logprobs_table[divm][b, vix].clone(),
                                    'unaug_p': beam_seq_logprobs_table[divm][b, vix].sum().item(),
                                    'p': beam_logprobs_sum_table[divm][b, vix].item()
                                }
                                final_beam['p'] = length_penalty(t-divm+1, final_beam['p'])
                                done_beams_table[b][divm].append(final_beam)
                        # Push finished beams far down so they are not extended.
                        beam_logprobs_sum_table[divm][b, is_end] -= 1000

                    # move the current group one step forward in time

                    it = beam_seq_table[divm][:, :, t-divm].reshape(-1).to(logprobs.device)
                    logprobs_table[divm], state_table[divm] = self.get_logprobs_state(it, *(args[divm] + [state_table[divm]]))
                    logprobs_table[divm] = F.log_softmax(logprobs_table[divm] / temperature, dim=-1)

        # all beams are sorted by their log-probabilities
        done_beams_table = [[sorted(done_beams_table[b][i], key=lambda x: -x['p'])[:bdash] for i in range(group_size)] for b in range(batch_size)]
        done_beams = [sum(_, []) for _ in done_beams_table]
        return done_beams

    def old_beam_search(self, init_state, init_logprobs, *args, **kwargs):
        """Single-image (diverse) beam search with time-major beam tables.

        Unlike ``beam_search`` there is no batch dimension: the rows of
        ``init_logprobs`` are the beams of one image. Beam tables are
        pre-allocated with ``self.seq_length`` rows.

        Returns:
            A flat list of beam dicts (``seq``, ``logps``, ``unaug_p``,
            ``p``), sorted by ``p`` within each group, groups concatenated.
        """

        # function computes the similarity score to be augmented
        def add_diversity(beam_seq_table, logprobsf, t, divm, diversity_lambda, bdash):
            local_time = t - divm
            unaug_logprobsf = logprobsf.clone()
            # NOTE: logprobsf is penalised in place; only the untouched copy
            # is returned.
            for prev_choice in range(divm):
                prev_decisions = beam_seq_table[prev_choice][local_time]
                for sub_beam in range(bdash):
                    for prev_labels in range(bdash):
                        logprobsf[sub_beam][prev_decisions[prev_labels]] = logprobsf[sub_beam][prev_decisions[prev_labels]] - diversity_lambda
            return unaug_logprobsf

        # does one step of classical beam search
        def beam_step(logprobsf, unaug_logprobsf, beam_size, t, beam_seq, beam_seq_logprobs, beam_logprobs_sum, state):
            #INPUTS:
            #logprobsf: probabilities augmented after diversity
            #beam_size: obvious
            #t        : time instant
            #beam_seq : tensor containing the beams
            #beam_seq_logprobs: tensor containing the beam logprobs
            #beam_logprobs_sum: tensor containing joint logprobs
            #OUTPUTS:
            #beam_seq : tensor containing the word indices of the decoded captions
            #beam_seq_logprobs : log-probability of each decision made, same size as beam_seq
            #beam_logprobs_sum : joint log-probability of each beam

            ys,ix = torch.sort(logprobsf,1,True)
            candidates = []
            cols = min(beam_size, ys.size(1))
            rows = beam_size
            if t == 0:
                rows = 1
            for c in range(cols):  # for each column (word, essentially)
                for q in range(rows):  # for each beam expansion
                    #compute logprob of expanding beam q with word in (sorted) position c
                    local_logprob = ys[q,c].item()
                    candidate_logprob = beam_logprobs_sum[q] + local_logprob
                    # local_unaug_logprob = unaug_logprobsf[q,ix[q,c]]
                    candidates.append({'c':ix[q,c], 'q':q, 'p':candidate_logprob, 'r':unaug_logprobsf[q]})
            candidates = sorted(candidates, key=lambda x: -x['p'])

            new_state = [_.clone() for _ in state]
            #beam_seq_prev, beam_seq_logprobs_prev
            if t >= 1:
                #we'll need these as reference when we fork beams around
                beam_seq_prev = beam_seq[:t].clone()
                beam_seq_logprobs_prev = beam_seq_logprobs[:t].clone()
            for vix in range(beam_size):
                v = candidates[vix]
                #fork beam index q into index vix
                if t >= 1:
                    beam_seq[:t, vix] = beam_seq_prev[:, v['q']]
                    beam_seq_logprobs[:t, vix] = beam_seq_logprobs_prev[:, v['q']]
                #rearrange recurrent states
                for state_ix in range(len(new_state)):
                    # copy over state in previous beam q to new beam at vix
                    new_state[state_ix][:, vix] = state[state_ix][:, v['q']]  # dimension one is time step
                #append new end terminal at the end of this beam
                beam_seq[t, vix] = v['c']  # c'th word is the continuation
                beam_seq_logprobs[t, vix] = v['r']  # the raw logprob here
                beam_logprobs_sum[vix] = v['p']  # the new (sum) logprob along this beam
            state = new_state
            return beam_seq,beam_seq_logprobs,beam_logprobs_sum,state,candidates

        # Start diverse_beam_search
        opt = kwargs['opt']
        temperature = opt.get('temperature', 1)  # This should not affect beam search, but will affect dbs
        beam_size = opt.get('beam_size', 10)
        group_size = opt.get('group_size', 1)
        diversity_lambda = opt.get('diversity_lambda', 0.5)
        decoding_constraint = opt.get('decoding_constraint', 0)
        remove_bad_endings = opt.get('remove_bad_endings', 0)
        suppress_UNK = opt.get('suppress_UNK', 0)
        length_penalty = utils.penalty_builder(opt.get('length_penalty', ''))
        bdash = beam_size // group_size  # beam per group

        # INITIALIZATIONS
        beam_seq_table = [torch.LongTensor(self.seq_length, bdash).zero_() for _ in range(group_size)]
        beam_seq_logprobs_table = [torch.FloatTensor(self.seq_length, bdash, self.vocab_size + 1).zero_() for _ in range(group_size)]
        beam_logprobs_sum_table = [torch.zeros(bdash) for _ in range(group_size)]

        # logprobs # logprobs predicted in last time step, shape (beam_size, vocab_size+1)
        done_beams_table = [[] for _ in range(group_size)]
        # Split every state tensor (dim 1) and the log-probabilities (dim 0)
        # into one chunk per group.
        state_table = list(zip(*[_.chunk(group_size, 1) for _ in init_state]))
        logprobs_table = list(init_logprobs.chunk(group_size, 0))
        # END INIT

        # Chunk elements in the args
        args = list(args)
        if self.__class__.__name__ == 'AttEnsemble':
            args = [[_.chunk(group_size) if _ is not None else [None]*group_size for _ in args_] for args_ in args]  # arg_name, model_name, group_name
            args = [[[args[j][i][k] for i in range(len(self.models))] for j in range(len(args))] for k in range(group_size)]  # group_name, arg_name, model_name
        else:
            args = [_.chunk(group_size) if _ is not None else [None]*group_size for _ in args]
            args = [[args[i][j] for i in range(len(args))] for j in range(group_size)]

        for t in range(self.seq_length + group_size - 1):
            for divm in range(group_size):
                if t >= divm and t <= self.seq_length + divm - 1:
                    # add diversity
                    logprobsf = logprobs_table[divm]
                    # suppress previous word
                    if decoding_constraint and t-divm > 0:
                        logprobsf.scatter_(1, beam_seq_table[divm][t-divm-1].unsqueeze(1).to(logprobsf.device), float('-inf'))
                    if remove_bad_endings and t-divm > 0:
                        logprobsf[torch.from_numpy(np.isin(beam_seq_table[divm][t-divm-1].cpu().numpy(), self.bad_endings_ix)), 0] = float('-inf')
                    # suppress UNK tokens in the decoding
                    if suppress_UNK and hasattr(self, 'vocab') and self.vocab[str(logprobsf.size(1)-1)] == 'UNK':
                        logprobsf[:,logprobsf.size(1)-1] = logprobsf[:, logprobsf.size(1)-1] - 1000
                    # diversity is added here
                    # the function directly modifies the logprobsf values and hence, we need to return
                    # the unaugmented ones for sorting the candidates in the end. # for historical
                    # reasons :-)
                    unaug_logprobsf = add_diversity(beam_seq_table,logprobsf,t,divm,diversity_lambda,bdash)

                    # infer new beams
                    beam_seq_table[divm],\
                    beam_seq_logprobs_table[divm],\
                    beam_logprobs_sum_table[divm],\
                    state_table[divm],\
                    candidates_divm = beam_step(logprobsf,
                                                unaug_logprobsf,
                                                bdash,
                                                t-divm,
                                                beam_seq_table[divm],
                                                beam_seq_logprobs_table[divm],
                                                beam_logprobs_sum_table[divm],
                                                state_table[divm])

                    # if time's up... or if end token is reached then copy beams
                    for vix in range(bdash):
                        if beam_seq_table[divm][t-divm,vix] == self.eos_idx or t == self.seq_length + divm - 1:
                            final_beam = {
                                'seq': beam_seq_table[divm][:, vix].clone(),
                                'logps': beam_seq_logprobs_table[divm][:, vix].clone(),
                                'unaug_p': beam_seq_logprobs_table[divm][:, vix].sum().item(),
                                'p': beam_logprobs_sum_table[divm][vix].item()
                            }
                            final_beam['p'] = length_penalty(t-divm+1, final_beam['p'])
                            done_beams_table[divm].append(final_beam)
                            # don't continue beams from finished sequences
                            beam_logprobs_sum_table[divm][vix] = -1000

                    # move the current group one step forward in time

                    it = beam_seq_table[divm][t-divm].to(logprobsf.device)
                    logprobs_table[divm], state_table[divm] = self.get_logprobs_state(it, *(args[divm] + [state_table[divm]]))
                    logprobs_table[divm] = F.log_softmax(logprobs_table[divm] / temperature, dim=-1)

        # all beams are sorted by their log-probabilities
        done_beams_table = [sorted(done_beams_table[i], key=lambda x: -x['p'])[:bdash] for i in range(group_size)]
        done_beams = sum(done_beams_table, [])
        return done_beams

    def sample_next_word(self, logprobs, sample_method, temperature):
        """Pick the next word for every row of ``logprobs``.

        ``sample_method`` is ``'greedy'``, ``'gumbel'``, ``'top<k>'`` (top-k
        when the number is >= 1, nucleus sampling when it lies in (0, 1)), or
        anything else for plain temperature sampling.

        Returns:
            ``(it, sampleLogprobs)``: the chosen indices and their
            log-probabilities. ``sampleLogprobs`` is 1-D for ``'greedy'`` and
            has a trailing dimension of 1 for the other methods.
        """
        if sample_method == 'greedy':
            sampleLogprobs, it = torch.max(logprobs.data, 1)
            it = it.view(-1).long()
        elif sample_method == 'gumbel':  # gumbel softmax
            # ref: https://gist.github.com/yzh119/fd2146d2aeb329d067568a493b20172f
            def sample_gumbel(shape, eps=1e-20):
                U = torch.rand(shape).to(logprobs.device)
                return -torch.log(-torch.log(U + eps) + eps)
            def gumbel_softmax_sample(logits, temperature):
                y = logits + sample_gumbel(logits.size())
                return F.log_softmax(y / temperature, dim=-1)
            _logprobs = gumbel_softmax_sample(logprobs, temperature)
            _, it = torch.max(_logprobs.data, 1)
            sampleLogprobs = logprobs.gather(1, it.unsqueeze(1))  # gather the logprobs at sampled positions
        else:
            logprobs = logprobs / temperature
            if sample_method.startswith('top'):  # topk sampling
                top_num = float(sample_method[3:])
                if 0 < top_num < 1:
                    # nucleus sampling from # The Curious Case of Neural Text Degeneration
                    probs = F.softmax(logprobs, dim=1)
                    sorted_probs, sorted_indices = torch.sort(probs, descending=True, dim=1)
                    _cumsum = sorted_probs.cumsum(1)
                    mask = _cumsum < top_num
                    # Shift the mask right by one so the word that crosses the
                    # threshold is kept and the top word is always kept.
                    mask = torch.cat([torch.ones_like(mask[:,:1]), mask[:,:-1]], 1)
                    sorted_probs = sorted_probs * mask.to(sorted_probs)
                    sorted_probs = sorted_probs / sorted_probs.sum(1, keepdim=True)
                    logprobs.scatter_(1, sorted_indices, sorted_probs.log())
                else:
                    the_k = int(top_num)
                    tmp = torch.empty_like(logprobs).fill_(float('-inf'))
                    topk, indices = torch.topk(logprobs, the_k, dim=1)
                    tmp = tmp.scatter(1, indices, topk)
                    logprobs = tmp
            it = torch.distributions.Categorical(logits=logprobs.detach()).sample()
            sampleLogprobs = logprobs.gather(1, it.unsqueeze(1))  # gather the logprobs at sampled positions
        return it, sampleLogprobs

    def decode_sequence(self, seq):
        """Decode ``seq`` with ``self.vocab`` via ``utils.decode_sequence``."""
        return utils.decode_sequence(self.vocab, seq)
class LSTMCore(nn.Module):
    """Single LSTM step whose candidate input is the max of two linear maps.

    Both projections emit ``5 * rnn_size`` features: three gate
    pre-activations (input, forget, output) followed by two candidate slices
    that are merged with an element-wise maximum.
    """

    def __init__(self, opt):
        super(LSTMCore, self).__init__()
        self.input_encoding_size = opt.input_encoding_size
        self.rnn_size = opt.rnn_size
        self.drop_prob_lm = opt.drop_prob_lm

        # Input-to-hidden and hidden-to-hidden projections of the cell.
        self.i2h = nn.Linear(self.input_encoding_size, 5 * self.rnn_size)
        self.h2h = nn.Linear(self.rnn_size, 5 * self.rnn_size)
        self.dropout = nn.Dropout(self.drop_prob_lm)

    def forward(self, xt, state):
        """Advance the cell by one step.

        ``state`` is a pair ``(h, c)``; only the last entry along the first
        dimension of each is read. Returns the dropped-out hidden state and
        the new ``(h, c)`` pair, each with a leading dimension of size 1.
        """
        prev_h = state[0][-1]
        prev_c = state[1][-1]

        pre_activations = self.i2h(xt) + self.h2h(prev_h)
        in_pre, forget_pre, out_pre, cand_a, cand_b = \
            pre_activations.split(self.rnn_size, dim=1)

        in_gate = torch.sigmoid(in_pre)
        forget_gate = torch.sigmoid(forget_pre)
        out_gate = torch.sigmoid(out_pre)
        in_transform = torch.max(cand_a, cand_b)

        next_c = forget_gate * prev_c + in_gate * in_transform
        next_h = out_gate * torch.tanh(next_c)

        return self.dropout(next_h), (next_h.unsqueeze(0), next_c.unsqueeze(0))
class FCModel(CaptionModel):
    """Caption model driven only by the fc image feature.

    The image embedding is fed as the first input of ``LSTMCore``; word
    embeddings follow. Word index 0 is used both as the start input and as
    the end marker (see ``_sample``).
    """

    def __init__(self, opt):
        super(FCModel, self).__init__()
        self.vocab_size = opt.vocab_size
        self.input_encoding_size = opt.input_encoding_size
        self.rnn_type = opt.rnn_type
        self.rnn_size = opt.rnn_size
        self.num_layers = opt.num_layers
        self.drop_prob_lm = opt.drop_prob_lm
        self.seq_length = opt.seq_length
        self.fc_feat_size = opt.fc_feat_size
        # The beam search reads ``self.eos_idx``. Default to 0, the index
        # ``_sample`` already treats as "finished".
        self.eos_idx = getattr(opt, 'eos_idx', 0)

        self.ss_prob = 0.0  # Schedule sampling probability

        self.img_embed = nn.Linear(self.fc_feat_size, self.input_encoding_size)
        self.core = LSTMCore(opt)
        self.embed = nn.Embedding(self.vocab_size + 1, self.input_encoding_size)
        self.logit = nn.Linear(self.rnn_size, self.vocab_size + 1)

        self.init_weights()

    def init_weights(self):
        """Uniformly initialise the embedding and output layer in [-0.1, 0.1]."""
        initrange = 0.1
        self.embed.weight.data.uniform_(-initrange, initrange)
        self.logit.bias.data.fill_(0)
        self.logit.weight.data.uniform_(-initrange, initrange)

    def init_hidden(self, bsz):
        """Return a zero state for ``bsz`` sequences.

        A ``(h, c)`` pair for ``rnn_type == 'lstm'``, a single tensor
        otherwise; tensors are created like ``self.logit.weight``.
        """
        weight = self.logit.weight
        if self.rnn_type == 'lstm':
            return (weight.new_zeros(self.num_layers, bsz, self.rnn_size),
                    weight.new_zeros(self.num_layers, bsz, self.rnn_size))
        else:
            return weight.new_zeros(self.num_layers, bsz, self.rnn_size)

    def _forward(self, fc_feats, att_feats, seq, att_masks=None):
        """Teacher-forced pass; returns per-step log-probabilities.

        ``att_feats`` and ``att_masks`` are accepted for interface
        compatibility and not used. The output of the image step is dropped,
        so the result has one time step per consumed word of ``seq``.
        """
        batch_size = fc_feats.size(0)
        seq_per_img = seq.shape[0] // batch_size
        state = self.init_hidden(batch_size*seq_per_img)
        outputs = []

        if seq_per_img > 1:
            fc_feats = utils.repeat_tensors(seq_per_img, fc_feats)

        for i in range(seq.size(1) + 1):
            if i == 0:
                xt = self.img_embed(fc_feats)
            else:
                if self.training and i >= 2 and self.ss_prob > 0.0:  # otherwise no need to sample
                    sample_prob = fc_feats.data.new(batch_size*seq_per_img).uniform_(0, 1)
                    sample_mask = sample_prob < self.ss_prob
                    if sample_mask.sum() == 0:
                        it = seq[:, i-1].clone()
                    else:
                        # Scheduled sampling: replace the ground-truth word by
                        # a sample from the previous prediction for the
                        # selected rows only.
                        sample_ind = sample_mask.nonzero().view(-1)
                        it = seq[:, i-1].data.clone()
                        prob_prev = torch.exp(outputs[-1].data)  # fetch prev distribution: shape Nx(M+1)
                        it.index_copy_(0, sample_ind, torch.multinomial(prob_prev, 1).view(-1).index_select(0, sample_ind))
                else:
                    it = seq[:, i-1].clone()
                # break if all the sequences end
                if i >= 2 and seq[:, i-1].sum() == 0:
                    break
                xt = self.embed(it)

            output, state = self.core(xt, state)
            output = F.log_softmax(self.logit(output), dim=1)
            outputs.append(output)

        return torch.cat([_.unsqueeze(1) for _ in outputs[1:]], 1).contiguous()

    def get_logprobs_state(self, it, state):
        """Run one decoding step for word indices ``it``; return (logprobs, state)."""
        # 'it' contains a word index
        xt = self.embed(it)

        output, state = self.core(xt, state)
        logprobs = F.log_softmax(self.logit(output), dim=1)

        return logprobs, state

    def _sample_beam(self, fc_feats, att_feats, att_masks=None, opt={}):
        """Beam-search decode every image independently.

        Returns ``(seq, seqLogprobs)`` with the batch dimension first; the
        full beam lists are kept in ``self.done_beams``.
        """
        beam_size = opt.get('beam_size', 10)
        batch_size = fc_feats.size(0)

        assert beam_size <= self.vocab_size + 1, 'lets assume this for now, otherwise this corner case causes a few headaches down the road. can be dealt with in future if needed'
        seq = torch.LongTensor(self.seq_length, batch_size).zero_()
        seqLogprobs = torch.FloatTensor(self.seq_length, batch_size, self.vocab_size + 1)
        # lets process every image independently for now, for simplicity

        self.done_beams = [[] for _ in range(batch_size)]
        for k in range(batch_size):
            state = self.init_hidden(beam_size)
            for t in range(2):
                if t == 0:
                    xt = self.img_embed(fc_feats[k:k+1]).expand(beam_size, self.input_encoding_size)
                elif t == 1:  # input <bos>
                    it = fc_feats.data.new(beam_size).long().zero_()
                    xt = self.embed(it)

                output, state = self.core(xt, state)
                logprobs = F.log_softmax(self.logit(output), dim=1)

            # This per-image, time-major layout is the contract of
            # ``old_beam_search`` (flat list of beams, 'seq' of length
            # seq_length). ``beam_search`` returns one list per batch row and
            # cannot be indexed as ``[0]['seq']``.
            self.done_beams[k] = self.old_beam_search(state, logprobs, opt=opt)
            seq[:, k] = self.done_beams[k][0]['seq']  # the first beam has highest cumulative score
            seqLogprobs[:, k] = self.done_beams[k][0]['logps']
        # return the samples and their log likelihoods
        return seq.transpose(0, 1), seqLogprobs.transpose(0, 1)

    def _sample(self, fc_feats, att_feats, att_masks=None, opt={}):
        """Greedy or multinomial decoding (beam search when ``beam_size > 1``).

        Returns ``(seq, seqLogprobs)``: the sampled word indices
        ``(batch, seq_length)`` and the log-probability distribution each word
        was drawn from ``(batch, seq_length, vocab_size + 1)``.
        """
        sample_method = opt.get('sample_method', 'greedy')
        beam_size = opt.get('beam_size', 1)
        temperature = opt.get('temperature', 1.0)
        if beam_size > 1 and sample_method in ['greedy', 'beam_search']:
            # Pass att_masks explicitly so that ``opt`` lands in the ``opt``
            # parameter instead of the ``att_masks`` slot.
            return self._sample_beam(fc_feats, att_feats, att_masks, opt)

        batch_size = fc_feats.size(0)
        state = self.init_hidden(batch_size)
        seq = fc_feats.new_zeros(batch_size, self.seq_length, dtype=torch.long)
        seqLogprobs = fc_feats.new_zeros(batch_size, self.seq_length, self.vocab_size + 1)
        for t in range(self.seq_length + 2):
            if t == 0:
                xt = self.img_embed(fc_feats)
            else:
                if t == 1:  # input <bos>
                    it = fc_feats.data.new(batch_size).long().zero_()
                xt = self.embed(it)

            output, state = self.core(xt, state)
            logprobs = F.log_softmax(self.logit(output), dim=1)

            # sample the next_word
            if t == self.seq_length + 1:  # skip if we achieve maximum length
                break
            if sample_method == 'greedy':
                _, it = torch.max(logprobs.data, 1)
                it = it.view(-1).long()
            else:
                if temperature == 1.0:
                    prob_prev = torch.exp(logprobs.data).cpu()  # fetch prev distribution: shape Nx(M+1)
                else:
                    # scale logprobs by temperature
                    prob_prev = torch.exp(torch.div(logprobs.data, temperature)).cpu()
                it = torch.multinomial(prob_prev, 1).to(logprobs.device)
                it = it.view(-1).long()  # and flatten indices for downstream processing

            if t >= 1:
                # stop when all finished
                if t == 1:
                    unfinished = it > 0
                else:
                    unfinished = unfinished & (it > 0)
                it = it * unfinished.type_as(it)
                seq[:,t-1] = it  # seq[t] the input of t+2 time step
                # Store the whole distribution: the slice is
                # (batch, vocab_size + 1), so a per-row scalar would not
                # broadcast for batch sizes above one.
                seqLogprobs[:,t-1] = logprobs
                if unfinished.sum() == 0:
                    break

        return seq, seqLogprobs
class M2TransformerModel(TransformerModel):
    """TransformerModel variant backed by the meshed-memory transformer package."""

    def __init__(self, opt):
        super(M2TransformerModel, self).__init__(opt)
        # Drop the att_embed set up by the parent constructor and use an
        # identity mapping instead.
        del self.att_embed
        self.att_embed = lambda x: x  # The visual embed is in the MAEncoder
        # Notes: The dropout in MAEncoder is different from my att_embed, mine is 0.5?
        # Also the attention mask seems wrong in MAEncoder too...interesting

    def make_model(self, src_vocab, tgt_vocab, N_enc=6, N_dec=6,
                   d_model=512, d_ff=2048, h=8, dropout=0.1):
        "Helper: Construct a model from hyperparameters."
        encoder = MemoryAugmentedEncoder(
            N_enc, 0,
            attention_module=ScaledDotProductAttentionMemory,
            attention_module_kwargs={'m': 40})
        # Another implementation is to use MultiLevelEncoder + att_embed
        decoder = MeshedDecoder(tgt_vocab, 54, N_dec, -1)  # -1 is padding;
        return Transformer(0, encoder, decoder)  # 0 is bos

    def logit(self, x):  # unsafe way
        """Identity: M2transformer always output logsoftmax."""
        return x

    def _prepare_feature(self, fc_feats, att_feats, att_masks):
        """Encode the attention features once; fc/att placeholders are emptied."""
        prepared = self._prepare_feature_forward(att_feats, att_masks)
        att_feats, seq, att_masks, seq_mask = prepared
        memory, att_masks = self.model.encoder(att_feats)

        # Zero-width slices keep the leading dimensions but carry no features.
        return fc_feats[..., :0], att_feats[..., :0], memory, att_masks

    def _forward(self, fc_feats, att_feats, seq, att_masks=None):
        """Teacher-forced pass through the underlying transformer."""
        if seq.ndim == 3:  # B * seq_per_img * seq_len
            seq = seq.reshape(-1, seq.size(2))
        att_feats, seq, att_masks, seq_mask = self._prepare_feature_forward(
            att_feats, att_masks, seq)

        seq = seq.clone()
        padding = ~seq_mask.any(-2)
        seq[padding] = -1  # Make padding to be -1 (my dataloader uses 0 as padding)
        return self.model(att_feats, seq)

    def core(self, it, fc_feats_ph, att_feats_ph, memory, state, mask):
        """
        Decode one more step and return (last-step output, new state).

        state = [ys.unsqueeze(0)]
        """
        current = it.unsqueeze(1)
        if len(state) == 0:
            ys = current
        else:
            ys = torch.cat([state[0][0], current], dim=1)
        out = self.model.decoder(ys, memory, mask)
        return out[:, -1], [ys.unsqueeze(0)]

    def _sample_beam(self, fc_feats, att_feats, att_masks=None, opt={}):
        """Beam search delegated to the underlying model's ``beam_search``."""
        beam_size = opt.get('beam_size', 10)
        group_size = opt.get('group_size', 1)
        sample_n = opt.get('sample_n', 10)
        assert sample_n == 1 or sample_n == beam_size // group_size, 'when beam search, sample_n == 1 or beam search'

        att_feats, _seq, _att_masks, _seq_mask = self._prepare_feature_forward(att_feats, att_masks)
        seq, logprobs, seqLogprobs = self.model.beam_search(
            att_feats, self.seq_length, 0,
            beam_size, return_probs=True, out_size=beam_size)
        # Merge the first two dimensions of both results.
        seq = seq.reshape(-1, *seq.shape[2:])
        seqLogprobs = seqLogprobs.reshape(-1, *seqLogprobs.shape[2:])

        return seq, seqLogprobs
weight.new_zeros(self.num_layers, bsz, self.rnn_size) + + def _forward(self, fc_feats, att_feats, seq, att_masks=None): + batch_size = fc_feats.size(0) + seq_per_img = seq.shape[0] // batch_size + state = self.init_hidden(batch_size*seq_per_img) + outputs = [] + + if seq_per_img > 1: + fc_feats = utils.repeat_tensors(seq_per_img, fc_feats) + + for i in range(seq.size(1) + 1): + if i == 0: + xt = self.img_embed(fc_feats) + else: + if self.training and i >= 2 and self.ss_prob > 0.0: # otherwiste no need to sample + sample_prob = fc_feats.data.new(batch_size*seq_per_img).uniform_(0, 1) + sample_mask = sample_prob < self.ss_prob + if sample_mask.sum() == 0: + it = seq[:, i-1].clone() + else: + sample_ind = sample_mask.nonzero().view(-1) + it = seq[:, i-1].data.clone() + #prob_prev = torch.exp(outputs[-1].data.index_select(0, sample_ind)) # fetch prev distribution: shape Nx(M+1) + #it.index_copy_(0, sample_ind, torch.multinomial(prob_prev, 1).view(-1)) + prob_prev = torch.exp(outputs[-1].data) # fetch prev distribution: shape Nx(M+1) + it.index_copy_(0, sample_ind, torch.multinomial(prob_prev, 1).view(-1).index_select(0, sample_ind)) + else: + it = seq[:, i-1].clone() + # break if all the sequences end + if i >= 2 and seq[:, i-1].data.sum() == 0: + break + xt = self.embed(it) + + output, state = self.core(xt.unsqueeze(0), state) + output = F.log_softmax(self.logit(self.dropout(output.squeeze(0))), dim=1) + outputs.append(output) + + return torch.cat([_.unsqueeze(1) for _ in outputs[1:]], 1).contiguous() + + def get_logprobs_state(self, it, state): + # 'it' contains a word index + xt = self.embed(it) + + output, state = self.core(xt.unsqueeze(0), state) + logprobs = F.log_softmax(self.logit(self.dropout(output.squeeze(0))), dim=1) + + return logprobs, state + + def _sample_beam(self, fc_feats, att_feats, att_masks=None, opt={}): + beam_size = opt.get('beam_size', 10) + batch_size = fc_feats.size(0) + + assert beam_size <= self.vocab_size + 1, 'lets assume this for now, 
otherwise this corner case causes a few headaches down the road. can be dealt with in future if needed' + seq = torch.LongTensor(self.seq_length, batch_size).zero_() + seqLogprobs = torch.FloatTensor(self.seq_length, batch_size) + # lets process every image independently for now, for simplicity + + self.done_beams = [[] for _ in range(batch_size)] + for k in range(batch_size): + state = self.init_hidden(beam_size) + for t in range(2): + if t == 0: + xt = self.img_embed(fc_feats[k:k+1]).expand(beam_size, self.input_encoding_size) + elif t == 1: # input + it = fc_feats.data.new(beam_size).long().zero_() + xt = self.embed(it) + + output, state = self.core(xt.unsqueeze(0), state) + logprobs = F.log_softmax(self.logit(self.dropout(output.squeeze(0))), dim=1) + + self.done_beams[k] = self.beam_search(state, logprobs, opt=opt) + seq[:, k] = self.done_beams[k][0]['seq'] # the first beam has highest cumulative score + seqLogprobs[:, k] = self.done_beams[k][0]['logps'] + # return the samples and their log likelihoods + return seq.transpose(0, 1), seqLogprobs.transpose(0, 1) + + def _sample(self, fc_feats, att_feats, att_masks=None, opt={}): + sample_method = opt.get('sample_method', 'greedy') + beam_size = opt.get('beam_size', 1) + temperature = opt.get('temperature', 1.0) + if beam_size > 1 and sample_method in ['greedy', 'beam_search']: + return self.sample_beam(fc_feats, att_feats, opt) + + batch_size = fc_feats.size(0) + state = self.init_hidden(batch_size) + seq = fc_feats.new_zeros(batch_size, self.seq_length, dtype=torch.long) + seqLogprobs = fc_feats.new_zeros(batch_size, self.seq_length) + for t in range(self.seq_length + 2): + if t == 0: + xt = self.img_embed(fc_feats) + else: + if t == 1: # input + it = fc_feats.data.new(batch_size).long().zero_() + xt = self.embed(it) + + output, state = self.core(xt.unsqueeze(0), state) + logprobs = F.log_softmax(self.logit(self.dropout(output.squeeze(0))), dim=1) + + # sample the next word + if t == self.seq_length + 1: # skip 
if we achieve maximum length + break + if sample_method == 'greedy': + sampleLogprobs, it = torch.max(logprobs.data, 1) + it = it.view(-1).long() + else: + if temperature == 1.0: + prob_prev = torch.exp(logprobs.data).cpu() # fetch prev distribution: shape Nx(M+1) + else: + # scale logprobs by temperature + prob_prev = torch.exp(torch.div(logprobs.data, temperature)).cpu() + it = torch.multinomial(prob_prev, 1).to(logprobs.device) + sampleLogprobs = logprobs.gather(1, it) # gather the logprobs at sampled positions + it = it.view(-1).long() # and flatten indices for downstream processing + + if t >= 1: + # stop when all finished + if t == 1: + unfinished = it > 0 + else: + unfinished = unfinished & (it > 0) + it = it * unfinished.type_as(it) + seq[:,t-1] = it #seq[t] the input of t+2 time step + seqLogprobs[:,t-1] = sampleLogprobs.view(-1) + if unfinished.sum() == 0: + break + + return seq, seqLogprobs \ No newline at end of file diff --git a/captioning/models/TransformerModel.py b/captioning/models/TransformerModel.py new file mode 100644 index 0000000000000000000000000000000000000000..70a27a25e968cf906bdde461e054fed77c08f70b --- /dev/null +++ b/captioning/models/TransformerModel.py @@ -0,0 +1,363 @@ +# This file contains Transformer network +# Most of the code is copied from http://nlp.seas.harvard.edu/2018/04/03/attention.html + +# The cfg name correspondance: +# N=num_layers +# d_model=input_encoding_size +# d_ff=rnn_size +# h is always 8 + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import torch +import torch.nn as nn +import torch.nn.functional as F +from . import utils + +import copy +import math +import numpy as np + +from .CaptionModel import CaptionModel +from .AttModel import sort_pack_padded_sequence, pad_unsort_packed_sequence, pack_wrapper, AttModel + +class EncoderDecoder(nn.Module): + """ + A standard Encoder-Decoder architecture. Base for this and many + other models. 
+ """ + def __init__(self, encoder, decoder, src_embed, tgt_embed, generator): + super(EncoderDecoder, self).__init__() + self.encoder = encoder + self.decoder = decoder + self.src_embed = src_embed + self.tgt_embed = tgt_embed + self.generator = generator + + def forward(self, src, tgt, src_mask, tgt_mask): + "Take in and process masked src and target sequences." + return self.decode(self.encode(src, src_mask), src_mask, + tgt, tgt_mask) + + def encode(self, src, src_mask): + return self.encoder(self.src_embed(src), src_mask) + + def decode(self, memory, src_mask, tgt, tgt_mask): + return self.decoder(self.tgt_embed(tgt), memory, src_mask, tgt_mask) + +class Generator(nn.Module): + "Define standard linear + softmax generation step." + def __init__(self, d_model, vocab): + super(Generator, self).__init__() + self.proj = nn.Linear(d_model, vocab) + + def forward(self, x): + return F.log_softmax(self.proj(x), dim=-1) + +def clones(module, N): + "Produce N identical layers." + return nn.ModuleList([copy.deepcopy(module) for _ in range(N)]) + +class Encoder(nn.Module): + "Core encoder is a stack of N layers" + def __init__(self, layer, N): + super(Encoder, self).__init__() + self.layers = clones(layer, N) + self.norm = LayerNorm(layer.size) + + def forward(self, x, mask): + "Pass the input (and mask) through each layer in turn." + for layer in self.layers: + x = layer(x, mask) + return self.norm(x) + +class LayerNorm(nn.Module): + "Construct a layernorm module (See citation for details)." + def __init__(self, features, eps=1e-6): + super(LayerNorm, self).__init__() + self.a_2 = nn.Parameter(torch.ones(features)) + self.b_2 = nn.Parameter(torch.zeros(features)) + self.eps = eps + + def forward(self, x): + mean = x.mean(-1, keepdim=True) + std = x.std(-1, keepdim=True) + return self.a_2 * (x - mean) / (std + self.eps) + self.b_2 + +class SublayerConnection(nn.Module): + """ + A residual connection followed by a layer norm. 
+ Note for code simplicity the norm is first as opposed to last. + """ + def __init__(self, size, dropout): + super(SublayerConnection, self).__init__() + self.norm = LayerNorm(size) + self.dropout = nn.Dropout(dropout) + + def forward(self, x, sublayer): + "Apply residual connection to any sublayer with the same size." + return x + self.dropout(sublayer(self.norm(x))) + +class EncoderLayer(nn.Module): + "Encoder is made up of self-attn and feed forward (defined below)" + def __init__(self, size, self_attn, feed_forward, dropout): + super(EncoderLayer, self).__init__() + self.self_attn = self_attn + self.feed_forward = feed_forward + self.sublayer = clones(SublayerConnection(size, dropout), 2) + self.size = size + + def forward(self, x, mask): + "Follow Figure 1 (left) for connections." + x = self.sublayer[0](x, lambda x: self.self_attn(x, x, x, mask)) + return self.sublayer[1](x, self.feed_forward) + +class Decoder(nn.Module): + "Generic N layer decoder with masking." + def __init__(self, layer, N): + super(Decoder, self).__init__() + self.layers = clones(layer, N) + self.norm = LayerNorm(layer.size) + + def forward(self, x, memory, src_mask, tgt_mask): + for layer in self.layers: + x = layer(x, memory, src_mask, tgt_mask) + return self.norm(x) + +class DecoderLayer(nn.Module): + "Decoder is made of self-attn, src-attn, and feed forward (defined below)" + def __init__(self, size, self_attn, src_attn, feed_forward, dropout): + super(DecoderLayer, self).__init__() + self.size = size + self.self_attn = self_attn + self.src_attn = src_attn + self.feed_forward = feed_forward + self.sublayer = clones(SublayerConnection(size, dropout), 3) + + def forward(self, x, memory, src_mask, tgt_mask): + "Follow Figure 1 (right) for connections." 
def subsequent_mask(size):
    """Mask out subsequent positions.

    Returns a boolean tensor of shape (1, size, size) that is True where
    column j <= row i (a position may look at itself and earlier positions)
    and False strictly above the diagonal.

    Built directly in torch instead of allocating a float64 NumPy array and
    converting it; the returned values, dtype and shape are unchanged.
    """
    upper = torch.triu(torch.ones((1, size, size), dtype=torch.uint8), diagonal=1)
    return upper == 0
class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal position encodings to the input, then apply dropout.

    Args:
        d_model: size of the last dimension of the inputs.
        dropout: dropout probability applied after the addition.
        max_len: number of positions precomputed in the ``pe`` buffer.
    """

    def __init__(self, d_model, dropout, max_len=5000):
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(p=dropout)

        # Compute the positional encodings once in log space.
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len).unsqueeze(1).float()
        div_term = torch.exp(torch.arange(0, d_model, 2).float() *
                             -(math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        # With an odd d_model there is one cosine column fewer than sine
        # columns; slicing keeps even sizes unchanged and stops odd sizes
        # from failing on a shape mismatch.
        pe[:, 1::2] = torch.cos(position * div_term[:d_model // 2])
        pe = pe.unsqueeze(0)  # (1, max_len, d_model) so it broadcasts over the batch
        self.register_buffer('pe', pe)

    def forward(self, x):
        """Return ``dropout(x + pe[:, :x.size(1)])``; dim 1 of ``x`` indexes positions."""
        x = x + self.pe[:, :x.size(1)]
        return self.dropout(x)
+ c = copy.deepcopy + attn = MultiHeadedAttention(h, d_model, dropout) + ff = PositionwiseFeedForward(d_model, d_ff, dropout) + position = PositionalEncoding(d_model, dropout) + model = EncoderDecoder( + Encoder(EncoderLayer(d_model, c(attn), c(ff), dropout), N_enc), + Decoder(DecoderLayer(d_model, c(attn), c(attn), + c(ff), dropout), N_dec), + lambda x:x, # nn.Sequential(Embeddings(d_model, src_vocab), c(position)), + nn.Sequential(Embeddings(d_model, tgt_vocab), c(position)), + Generator(d_model, tgt_vocab)) + + # This was important from their code. + # Initialize parameters with Glorot / fan_avg. + for p in model.parameters(): + if p.dim() > 1: + nn.init.xavier_uniform_(p) + return model + + def __init__(self, opt): + super(TransformerModel, self).__init__(opt) + self.opt = opt + # self.config = yaml.load(open(opt.config_file)) + + self.N_enc = getattr(opt, 'N_enc', opt.num_layers) + self.N_dec = getattr(opt, 'N_dec', opt.num_layers) + self.d_model = getattr(opt, 'd_model', opt.input_encoding_size) + self.d_ff = getattr(opt, 'd_ff', opt.rnn_size) + self.h = getattr(opt, 'num_att_heads', 8) + self.dropout = getattr(opt, 'dropout', 0.1) + + delattr(self, 'att_embed') + self.att_embed = nn.Sequential(*( + ((nn.BatchNorm1d(self.att_feat_size),) if self.use_bn else ())+ + (nn.Linear(self.att_feat_size, self.d_model), + nn.ReLU(), + nn.Dropout(self.drop_prob_lm))+ + ((nn.BatchNorm1d(self.d_model),) if self.use_bn==2 else ()))) + + delattr(self, 'embed') + self.embed = lambda x : x + delattr(self, 'fc_embed') + self.fc_embed = lambda x : x + delattr(self, 'logit') + del self.ctx2att + + tgt_vocab = self.vocab_size + 1 + + + self.model = self.make_model(0, tgt_vocab, + N_enc=self.N_enc, + N_dec=self.N_dec, + d_model=self.d_model, + d_ff=self.d_ff, + h=self.h, + dropout=self.dropout) + + def logit(self, x): # unsafe way + return self.model.generator.proj(x) + + def init_hidden(self, bsz): + return [] + + def _prepare_feature(self, fc_feats, att_feats, att_masks): + + 
att_feats, seq, att_masks, seq_mask = self._prepare_feature_forward(att_feats, att_masks) + memory = self.model.encode(att_feats, att_masks) + + return fc_feats[...,:0], att_feats[...,:0], memory, att_masks + + def _prepare_feature_forward(self, att_feats, att_masks=None, seq=None): + att_feats, att_masks = self.clip_att(att_feats, att_masks) + + att_feats = pack_wrapper(self.att_embed, att_feats, att_masks) + + if att_masks is None: + att_masks = att_feats.new_ones(att_feats.shape[:2], dtype=torch.long) + att_masks = att_masks.unsqueeze(-2) + + if seq is not None: + # crop the last one + # seq = seq[:,:-1] + seq_mask = (seq.data != self.eos_idx) & (seq.data != self.pad_idx) + seq_mask[:,0] = 1 # bos + + seq_mask = seq_mask.unsqueeze(-2) + seq_mask = seq_mask & subsequent_mask(seq.size(-1)).to(seq_mask) + + seq_per_img = seq.shape[0] // att_feats.shape[0] + if seq_per_img > 1: + att_feats, att_masks = utils.repeat_tensors(seq_per_img, + [att_feats, att_masks] + ) + else: + seq_mask = None + + return att_feats, seq, att_masks, seq_mask + + def _forward(self, fc_feats, att_feats, seq, att_masks=None): + if seq.ndim == 3: # B * seq_per_img * seq_len + seq = seq.reshape(-1, seq.shape[2]) + att_feats, seq, att_masks, seq_mask = self._prepare_feature_forward(att_feats, att_masks, seq) + + out = self.model(att_feats, seq, att_masks, seq_mask) + + outputs = self.model.generator(out) + return outputs + # return torch.cat([_.unsqueeze(1) for _ in outputs], 1) + + def core(self, it, fc_feats_ph, att_feats_ph, memory, state, mask): + """ + state = [ys.unsqueeze(0)] + """ + if len(state) == 0: + ys = it.unsqueeze(1) + else: + ys = torch.cat([state[0][0], it.unsqueeze(1)], dim=1) + out = self.model.decode(memory, mask, + ys, + subsequent_mask(ys.size(1)) + .to(memory.device)) + return out[:, -1], [ys.unsqueeze(0)] \ No newline at end of file diff --git a/captioning/models/__init__.py b/captioning/models/__init__.py new file mode 100644 index 
0000000000000000000000000000000000000000..29f7a9cb48b9397ed0b658c15580b43c5ae1300d --- /dev/null +++ b/captioning/models/__init__.py @@ -0,0 +1,73 @@ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os +import copy + +import numpy as np +import torch + +from .ShowTellModel import ShowTellModel +from .FCModel import FCModel +from .AttModel import * +from .TransformerModel import TransformerModel +from .cachedTransformer import TransformerModel as cachedTransformer +from .BertCapModel import BertCapModel +from .M2Transformer import M2TransformerModel +from .AoAModel import AoAModel + +def setup(opt): + if opt.caption_model in ['fc', 'show_tell']: + print('Warning: %s model is mostly deprecated; many new features are not supported.' %opt.caption_model) + if opt.caption_model == 'fc': + print('Use newfc instead of fc') + if opt.caption_model == 'fc': + model = FCModel(opt) + elif opt.caption_model == 'language_model': + model = LMModel(opt) + elif opt.caption_model == 'newfc': + model = NewFCModel(opt) + elif opt.caption_model == 'show_tell': + model = ShowTellModel(opt) + # Att2in model in self-critical + elif opt.caption_model == 'att2in': + model = Att2inModel(opt) + # Att2in model with two-layer MLP img embedding and word embedding + elif opt.caption_model == 'att2in2': + model = Att2in2Model(opt) + elif opt.caption_model == 'att2all2': + print('Warning: this is not a correct implementation of the att2all model in the original paper.') + model = Att2all2Model(opt) + # Adaptive Attention model from Knowing when to look + elif opt.caption_model == 'adaatt': + model = AdaAttModel(opt) + # Adaptive Attention with maxout lstm + elif opt.caption_model == 'adaattmo': + model = AdaAttMOModel(opt) + # Top-down attention model + elif opt.caption_model in ['topdown', 'updown']: + model = UpDownModel(opt) + # StackAtt + elif opt.caption_model == 'stackatt': + model = StackAttModel(opt) + # DenseAtt + elif 
opt.caption_model == 'denseatt': + model = DenseAttModel(opt) + # Transformer + elif opt.caption_model == 'transformer': + if getattr(opt, 'cached_transformer', False): + model = cachedTransformer(opt) + else: + model = TransformerModel(opt) + # AoANet + elif opt.caption_model == 'aoa': + model = AoAModel(opt) + elif opt.caption_model == 'bert': + model = BertCapModel(opt) + elif opt.caption_model == 'm2transformer': + model = M2TransformerModel(opt) + else: + raise Exception("Caption model not supported: {}".format(opt.caption_model)) + + return model diff --git a/captioning/models/cachedTransformer.py b/captioning/models/cachedTransformer.py new file mode 100644 index 0000000000000000000000000000000000000000..719701cb348a11255a36d554ad350dcfc87e5121 --- /dev/null +++ b/captioning/models/cachedTransformer.py @@ -0,0 +1,420 @@ +# This file contains Transformer network +# Most of the code is copied from http://nlp.seas.harvard.edu/2018/04/03/attention.html + +# The cfg name correspondance: +# N=num_layers +# d_model=input_encoding_size +# d_ff=rnn_size +# h is always 8 + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import torch +import torch.nn as nn +import torch.nn.functional as F +from . import utils + +import copy +import math +import numpy as np + +from .CaptionModel import CaptionModel +from .AttModel import sort_pack_padded_sequence, pad_unsort_packed_sequence, pack_wrapper, AttModel + +class EncoderDecoder(nn.Module): + """ + A standard Encoder-Decoder architecture. Base for this and many + other models. + """ + def __init__(self, encoder, decoder, src_embed, tgt_embed, generator): + super(EncoderDecoder, self).__init__() + self.encoder = encoder + self.decoder = decoder + self.src_embed = src_embed + self.tgt_embed = tgt_embed + self.generator = generator + + def forward(self, src, tgt, src_mask, tgt_mask): + "Take in and process masked src and target sequences." 
class LayerNorm(nn.Module):
    """Normalise the last dimension, then apply a learned scale and shift.

    ``a_2`` (initialised to ones) is the scale and ``b_2`` (initialised to
    zeros) is the shift.  ``eps`` is added to the standard deviation, not the
    variance, before dividing.
    """

    def __init__(self, features, eps=1e-6):
        super(LayerNorm, self).__init__()
        self.a_2 = nn.Parameter(torch.ones(features))
        self.b_2 = nn.Parameter(torch.zeros(features))
        self.eps = eps

    def forward(self, x):
        """Return ``a_2 * (x - mean) / (std + eps) + b_2`` over the last dimension."""
        centered = x - x.mean(-1, keepdim=True)
        spread = x.std(-1, keepdim=True) + self.eps
        return self.a_2 * centered / spread + self.b_2
+ _x = sublayer(self.norm(x)) + if type(_x) is tuple: # for multi-head attention that returns past + return x + self.dropout(_x[0]), _x[1] + return x + self.dropout(_x) + +class EncoderLayer(nn.Module): + "Encoder is made up of self-attn and feed forward (defined below)" + def __init__(self, size, self_attn, feed_forward, dropout): + super(EncoderLayer, self).__init__() + self.self_attn = self_attn + self.feed_forward = feed_forward + self.sublayer = clones(SublayerConnection(size, dropout), 2) + self.size = size + + def forward(self, x, mask): + "Follow Figure 1 (left) for connections." + x = self.sublayer[0](x, lambda x: self.self_attn(x, x, x, mask)) + return self.sublayer[1](x, self.feed_forward) + +class Decoder(nn.Module): + "Generic N layer decoder with masking." + def __init__(self, layer, N): + super(Decoder, self).__init__() + self.layers = clones(layer, N) + self.norm = LayerNorm(layer.size) + + def forward(self, x, memory, src_mask, tgt_mask, past=None): + if past is not None: + present = [[], []] + x = x[:, -1:] + tgt_mask = tgt_mask[:, -1:] if tgt_mask is not None else None + past = list(zip(past[0].split(2, dim=0), past[1].split(2, dim=0))) + else: + past = [None] * len(self.layers) + for i, (layer, layer_past) in enumerate(zip(self.layers, past)): + x = layer(x, memory, src_mask, tgt_mask, + layer_past) + if layer_past is not None: + present[0].append(x[1][0]) + present[1].append(x[1][1]) + x = x[0] + if past[0] is None: + return self.norm(x) + else: + return self.norm(x), [torch.cat(present[0], 0), torch.cat(present[1], 0)] + + +class DecoderLayer(nn.Module): + "Decoder is made of self-attn, src-attn, and feed forward (defined below)" + def __init__(self, size, self_attn, src_attn, feed_forward, dropout): + super(DecoderLayer, self).__init__() + self.size = size + self.self_attn = self_attn + self.src_attn = src_attn + self.feed_forward = feed_forward + self.sublayer = clones(SublayerConnection(size, dropout), 3) + + def forward(self, x, memory, 
def attention(query, key, value, mask=None, dropout=None):
    """Compute scaled dot-product attention.

    Scores are ``query @ key^T / sqrt(d_k)`` with ``d_k`` the last dimension
    of ``query``.  Where ``mask == 0`` the score is set to ``-inf`` before the
    softmax.  ``dropout``, when given, is a callable applied to the attention
    weights.

    Returns:
        (weighted values, attention weights) -- the weights are the
        post-dropout ones.
    """
    scale = math.sqrt(query.size(-1))
    logits = torch.matmul(query, key.transpose(-2, -1)) / scale
    if mask is not None:
        logits = logits.masked_fill(mask == 0, float('-inf'))
    weights = F.softmax(logits, dim=-1)
    if dropout is not None:
        weights = dropout(weights)
    attended = torch.matmul(weights, value)
    return attended, weights
class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal position encodings to the input, then apply dropout.

    Args:
        d_model: size of the last dimension of the inputs.
        dropout: dropout probability applied after the addition.
        max_len: number of positions precomputed in the ``pe`` buffer.
    """

    def __init__(self, d_model, dropout, max_len=5000):
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(p=dropout)

        # Compute the positional encodings once in log space.
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len).unsqueeze(1).float()
        div_term = torch.exp(torch.arange(0, d_model, 2).float() *
                             -(math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        # With an odd d_model there is one cosine column fewer than sine
        # columns; slicing keeps even sizes unchanged and stops odd sizes
        # from failing on a shape mismatch.
        pe[:, 1::2] = torch.cos(position * div_term[:d_model // 2])
        pe = pe.unsqueeze(0)  # (1, max_len, d_model) so it broadcasts over the batch
        self.register_buffer('pe', pe)

    def forward(self, x):
        """Return ``dropout(x + pe[:, :x.size(1)])``; dim 1 of ``x`` indexes positions."""
        x = x + self.pe[:, :x.size(1)]
        return self.dropout(x)
+ for p in model.parameters(): + if p.dim() > 1: + nn.init.xavier_uniform_(p) + return model + + def __init__(self, opt): + super(TransformerModel, self).__init__(opt) + self.opt = opt + # self.config = yaml.load(open(opt.config_file)) + + self.N_enc = getattr(opt, 'N_enc', opt.num_layers) + self.N_dec = getattr(opt, 'N_dec', opt.num_layers) + self.d_model = getattr(opt, 'd_model', opt.input_encoding_size) + self.d_ff = getattr(opt, 'd_ff', opt.rnn_size) + self.h = getattr(opt, 'num_att_heads', 8) + self.dropout = getattr(opt, 'dropout', 0.1) + + delattr(self, 'att_embed') + self.att_embed = nn.Sequential(*( + ((nn.BatchNorm1d(self.att_feat_size),) if self.use_bn else ())+ + (nn.Linear(self.att_feat_size, self.d_model), + nn.ReLU(), + nn.Dropout(self.drop_prob_lm))+ + ((nn.BatchNorm1d(self.d_model),) if self.use_bn==2 else ()))) + + delattr(self, 'embed') + self.embed = lambda x : x + delattr(self, 'fc_embed') + self.fc_embed = lambda x : x + delattr(self, 'logit') + del self.ctx2att + + tgt_vocab = self.vocab_size + 1 + + + self.model = self.make_model(0, tgt_vocab, + N_enc=self.N_enc, + N_dec=self.N_dec, + d_model=self.d_model, + d_ff=self.d_ff, + h=self.h, + dropout=self.dropout) + + def logit(self, x): # unsafe way + return self.model.generator.proj(x) + + def init_hidden(self, bsz): + return [] + + def _prepare_feature(self, fc_feats, att_feats, att_masks): + + att_feats, seq, att_masks, seq_mask = self._prepare_feature_forward(att_feats, att_masks) + memory = self.model.encode(att_feats, att_masks) + + return fc_feats[...,:0], att_feats[...,:0], memory, att_masks + + def _prepare_feature_forward(self, att_feats, att_masks=None, seq=None): + att_feats, att_masks = self.clip_att(att_feats, att_masks) + + att_feats = pack_wrapper(self.att_embed, att_feats, att_masks) + + if att_masks is None: + att_masks = att_feats.new_ones(att_feats.shape[:2], dtype=torch.long) + att_masks = att_masks.unsqueeze(-2) + + if seq is not None: + # crop the last one + # seq = 
seq[:,:-1] + seq_mask = (seq.data != self.eos_idx) & (seq.data != self.pad_idx) + seq_mask[:,0] = 1 # bos + + seq_mask = seq_mask.unsqueeze(-2) + seq_mask = seq_mask & subsequent_mask(seq.size(-1)).to(seq_mask) + + seq_per_img = seq.shape[0] // att_feats.shape[0] + if seq_per_img > 1: + att_feats, att_masks = utils.repeat_tensors(seq_per_img, + [att_feats, att_masks] + ) + else: + seq_mask = None + + return att_feats, seq, att_masks, seq_mask + + def _forward(self, fc_feats, att_feats, seq, att_masks=None): + if seq.ndim == 3: # B * seq_per_img * seq_len + seq = seq.reshape(-1, seq.shape[2]) + att_feats, seq, att_masks, seq_mask = self._prepare_feature_forward(att_feats, att_masks, seq) + + out = self.model(att_feats, seq, att_masks, seq_mask) + + outputs = self.model.generator(out) + return outputs + # return torch.cat([_.unsqueeze(1) for _ in outputs], 1) + + def core(self, it, fc_feats_ph, att_feats_ph, memory, state, mask): + """ + state is the precomputed key/value. N_dec x seq_len x d_model + Note: due to the layer norm, it's not equivalant to stateless, + but it seems behaving similar + """ + # state is tokens + past + if len(state) == 0: + ys = it.unsqueeze(1) + # basically empty state, just to let it know to return past + # The second dim has to be batch_size, for beam search purpose + past = [fc_feats_ph.new_zeros(self.N_dec * 2, fc_feats_ph.shape[0], 0, self.d_model), # self + fc_feats_ph.new_zeros(self.N_dec * 2, fc_feats_ph.shape[0], 0, self.d_model)] # src + # 2 for self attn, 2 for src attn + else: + ys = torch.cat([state[0][0], it.unsqueeze(1)], dim=1) + past = state[1:] + out, past = self.model.decode(memory, mask, + ys, # We still feed the full past words, because we need it for position embedding to know the position id + subsequent_mask(ys.size(1)) + .to(memory.device), + past=past) + return out[:, -1], [ys.unsqueeze(0)] + past diff --git a/captioning/models/utils.py b/captioning/models/utils.py new file mode 100644 index 
def repeat_tensors(n, x):
    """
    For a tensor of size Bx..., we repeat it n times, and make it Bnx...
    For collections, do nested repeat

    Each row is repeated ``n`` times consecutively (b0, b0, ..., b1, b1, ...).
    Lists and tuples, including subclasses such as namedtuples, are processed
    element-wise and returned as a plain list.  Anything else (e.g. ``None``)
    is returned unchanged.
    """
    if torch.is_tensor(x):
        x = x.unsqueeze(1)  # Bx1x...
        x = x.expand(-1, n, *([-1]*len(x.shape[2:])))  # Bxnx...
        x = x.reshape(x.shape[0]*n, *x.shape[2:])  # Bnx...
    elif isinstance(x, (list, tuple)):
        x = [repeat_tensors(n, _) for _ in x]
    return x


def split_tensors(n, x):
    """
    Inverse of repeat_tensors: split a Bnx... tensor into n tensors of Bx...

    A tensor is viewed as (B, n, ...) and unbound along dim 1, giving a tuple
    of ``n`` tensors.  Lists and tuples (including subclasses) are processed
    element-wise into a list.  ``None`` becomes a list of ``n`` ``None``.
    """
    if torch.is_tensor(x):
        assert x.shape[0] % n == 0, 'first dimension must be divisible by n'
        x = x.reshape(x.shape[0] // n, n, *x.shape[1:]).unbind(1)
    elif isinstance(x, (list, tuple)):
        x = [split_tensors(n, _) for _ in x]
    elif x is None:
        x = [None] * n
    return x
import losses +from ..utils.rewards import init_scorer, get_self_critical_reward, get_self_critical_clipscore_reward +from ..utils.clipscore import CLIPScore +import numpy as np + +class LossWrapper(torch.nn.Module): + def __init__(self, model, opt): + super(LossWrapper, self).__init__() + self.opt = opt + self.model = model + if opt.label_smoothing > 0: + self.crit = losses.LabelSmoothing(smoothing=opt.label_smoothing) + else: + self.crit = losses.LanguageModelCriterion() + self.rl_crit = losses.RewardCriterion() + self.struc_crit = losses.StructureLosses(opt) + + self.clipscore_model = None + if self.opt.use_clipscore: + use_grammar = getattr(self.opt, 'use_grammar', False) + joint_out = getattr(self.opt, 'joint_out', False) + self.clipscore_model = CLIPScore( + mode=opt.clipscore_mode, + use_grammar=use_grammar, + joint_out=joint_out, + ) + for p in self.clipscore_model.parameters(): + p.requires_grad = False + + if use_grammar: + state_dict = torch.load(self.opt.clip_load_path, map_location='cpu') + self.clipscore_model.load_state_dict(state_dict['state_dict']) + + def forward(self, fc_feats, att_feats, labels, masks, att_masks, gts, gt_indices, + sc_flag, struc_flag, clip_vis_feats=None): + opt = self.opt + + out = {} + if struc_flag: + if opt.structure_loss_weight < 1: + lm_loss = self.crit(self.model(fc_feats, att_feats, labels[..., :-1], att_masks), labels[..., 1:], masks[..., 1:]) + else: + lm_loss = torch.tensor(0).type_as(fc_feats) + if opt.structure_loss_weight > 0: + gen_result, sample_logprobs = self.model(fc_feats, att_feats, att_masks, + opt={'sample_method':opt.train_sample_method, + 'beam_size':opt.train_beam_size, + 'output_logsoftmax': opt.struc_use_logsoftmax or opt.structure_loss_type == 'softmax_margin'\ + or not 'margin' in opt.structure_loss_type, + 'sample_n': opt.train_sample_n}, + mode='sample') + gts = [gts[_] for _ in gt_indices.tolist()] + struc_loss = self.struc_crit(sample_logprobs, gen_result, gts) + else: + struc_loss = {'loss': 
torch.tensor(0).type_as(fc_feats), + 'reward': torch.tensor(0).type_as(fc_feats)} + loss = (1-opt.structure_loss_weight) * lm_loss + opt.structure_loss_weight * struc_loss['loss'] + out['lm_loss'] = lm_loss + out['struc_loss'] = struc_loss['loss'] + out['reward'] = struc_loss['reward'] + elif not sc_flag: + loss = self.crit(self.model(fc_feats, att_feats, labels[..., :-1], att_masks), labels[..., 1:], masks[..., 1:]) + else: + self.model.eval() + with torch.no_grad(): + greedy_res, _ = self.model(fc_feats, att_feats, att_masks, + mode='sample', + opt={'sample_method': opt.sc_sample_method, + 'beam_size': opt.sc_beam_size}) + self.model.train() + gen_result, sample_logprobs = self.model(fc_feats, att_feats, att_masks, + opt={'sample_method':opt.train_sample_method, + 'beam_size':opt.train_beam_size, + 'sample_n': opt.train_sample_n}, + mode='sample') + gts = [gts[_] for _ in gt_indices.tolist()] + + if getattr(self.opt, 'use_multi_rewards', False): + assert self.opt.use_clipscore + clipscore_reward_normalized, clipscore_unnormalized_mean, grammar_rewards = get_self_critical_clipscore_reward( + greedy_res, gts, gen_result, self.opt, self.clipscore_model, clip_vis_feats, self.model.vocab) + + if self.opt.clipscore_mode == 'clip_s': + out['CLIP-S'] = clipscore_unnormalized_mean + elif self.opt.clipscore_mode == 'refclip_s': + out['RefCLIP-S'] = clipscore_unnormalized_mean + + if getattr(self.opt, 'use_grammar', False): + out['grammar_reward'] = grammar_rewards.mean() + + reward = clipscore_reward_normalized + grammar_rewards + + + else: + assert grammar_rewards is None + + cider_reward_normalized, cider_unnormalized_mean = get_self_critical_reward( + greedy_res, gts, gen_result, self.opt) + out['CIDEr'] = cider_unnormalized_mean + if isinstance(cider_reward_normalized, np.ndarray): + cider_reward_normalized = torch.from_numpy(cider_reward_normalized).to(clipscore_reward_normalized.device) + + reward = clipscore_reward_normalized + cider_reward_normalized + else: + if 
self.opt.use_clipscore: + clipscore_reward_normalized, clipscore_unnormalized_mean, _ = get_self_critical_clipscore_reward( + greedy_res, gts, gen_result, self.opt, self.clipscore_model, clip_vis_feats, self.model.vocab) + if self.opt.clipscore_mode == 'clip_s': + out['CLIP-S'] = clipscore_unnormalized_mean + elif self.opt.clipscore_mode == 'refclip_s': + out['RefCLIP-S'] = clipscore_unnormalized_mean + reward = clipscore_reward_normalized + else: + cider_reward_normalized, cider_unnormalized_mean = get_self_critical_reward( + greedy_res, gts, gen_result, self.opt) + out['CIDEr'] = cider_unnormalized_mean + reward = cider_reward_normalized + + if isinstance(reward, np.ndarray): + reward = torch.from_numpy(reward) + reward = reward.to(sample_logprobs) + loss = self.rl_crit(sample_logprobs, gen_result.data, reward) + out['reward'] = reward[:,0].mean() + out['loss'] = loss + return out + diff --git a/captioning/modules/losses.py b/captioning/modules/losses.py new file mode 100644 index 0000000000000000000000000000000000000000..28d6db59dd70a9418a8a074d54402d6b5823520c --- /dev/null +++ b/captioning/modules/losses.py @@ -0,0 +1,218 @@ +import torch +import torch.nn as nn +from ..utils.rewards import get_scores, get_self_cider_scores + +class RewardCriterion(nn.Module): + def __init__(self): + super(RewardCriterion, self).__init__() + + def forward(self, input, seq, reward): + input = input.gather(2, seq.unsqueeze(2)).squeeze(2) + + input = input.reshape(-1) + reward = reward.reshape(-1) + mask = (seq>0).to(input) + mask = torch.cat([mask.new(mask.size(0), 1).fill_(1), mask[:, :-1]], 1).reshape(-1) + output = - input * reward * mask + output = torch.sum(output) / torch.sum(mask) + + return output + +class StructureLosses(nn.Module): + """ + This loss is inspired by Classical Structured Prediction Losses for Sequence to Sequence Learning (Edunov et al., 2018). 
+ """ + def __init__(self, opt): + super(StructureLosses, self).__init__() + self.opt = opt + self.loss_type = opt.structure_loss_type + + def forward(self, input, seq, data_gts): + """ + Input is either logits or log softmax + """ + out = {} + + batch_size = input.size(0)# batch_size = sample_size * seq_per_img + seq_per_img = batch_size // len(data_gts) + + assert seq_per_img == self.opt.train_sample_n, seq_per_img + + mask = (seq>0).to(input) + mask = torch.cat([mask.new_full((mask.size(0), 1), 1), mask[:, :-1]], 1) + + scores = get_scores(data_gts, seq, self.opt) + scores = torch.from_numpy(scores).type_as(input).view(-1, seq_per_img) + out['reward'] = scores #.mean() + if self.opt.entropy_reward_weight > 0: + entropy = - (F.softmax(input, dim=2) * F.log_softmax(input, dim=2)).sum(2).data + entropy = (entropy * mask).sum(1) / mask.sum(1) + print('entropy', entropy.mean().item()) + scores = scores + self.opt.entropy_reward_weight * entropy.view(-1, seq_per_img) + # rescale cost to [0,1] + costs = - scores + if self.loss_type == 'risk' or self.loss_type == 'softmax_margin': + costs = costs - costs.min(1, keepdim=True)[0] + costs = costs / costs.max(1, keepdim=True)[0] + # in principle + # Only risk need such rescale + # margin should be alright; Let's try. 
+ + # Gather input: BxTxD -> BxT + input = input.gather(2, seq.unsqueeze(2)).squeeze(2) + + if self.loss_type == 'seqnll': + # input is logsoftmax + input = input * mask + input = input.sum(1) / mask.sum(1) + input = input.view(-1, seq_per_img) + + target = costs.min(1)[1] + output = F.cross_entropy(input, target) + elif self.loss_type == 'risk': + # input is logsoftmax + input = input * mask + input = input.sum(1) + input = input.view(-1, seq_per_img) + + output = (F.softmax(input.exp()) * costs).sum(1).mean() + + # test + # avg_scores = input + # probs = F.softmax(avg_scores.exp_()) + # loss = (probs * costs.type_as(probs)).sum() / input.size(0) + # print(output.item(), loss.item()) + + elif self.loss_type == 'max_margin': + # input is logits + input = input * mask + input = input.sum(1) / mask.sum(1) + input = input.view(-1, seq_per_img) + _, __ = costs.min(1, keepdim=True) + costs_star = _ + input_star = input.gather(1, __) + output = F.relu(costs - costs_star - input_star + input).max(1)[0] / 2 + output = output.mean() + + # sanity test + # avg_scores = input + costs + # scores_with_high_target = avg_scores.clone() + # scores_with_high_target.scatter_(1, costs.min(1)[1].view(-1, 1), 1e10) + + # target_and_offender_index = scores_with_high_target.sort(1, True)[1][:, 0:2] + # avg_scores = avg_scores.gather(1, target_and_offender_index) + # target_index = avg_scores.new_zeros(avg_scores.size(0), dtype=torch.long) + # loss = F.multi_margin_loss(avg_scores, target_index, size_average=True, margin=0) + # print(loss.item() * 2, output.item()) + + elif self.loss_type == 'multi_margin': + # input is logits + input = input * mask + input = input.sum(1) / mask.sum(1) + input = input.view(-1, seq_per_img) + _, __ = costs.min(1, keepdim=True) + costs_star = _ + input_star = input.gather(1, __) + output = F.relu(costs - costs_star - input_star + input) + output = output.mean() + + # sanity test + # avg_scores = input + costs + # loss = F.multi_margin_loss(avg_scores, 
costs.min(1)[1], margin=0) + # print(output, loss) + + elif self.loss_type == 'softmax_margin': + # input is logsoftmax + input = input * mask + input = input.sum(1) / mask.sum(1) + input = input.view(-1, seq_per_img) + + input = input + costs + target = costs.min(1)[1] + output = F.cross_entropy(input, target) + + elif self.loss_type == 'real_softmax_margin': + # input is logits + # This is what originally defined in Kevin's paper + # The result should be equivalent to softmax_margin + input = input * mask + input = input.sum(1) / mask.sum(1) + input = input.view(-1, seq_per_img) + + input = input + costs + target = costs.min(1)[1] + output = F.cross_entropy(input, target) + + elif self.loss_type == 'new_self_critical': + """ + A different self critical + Self critical uses greedy decoding score as baseline; + This setting uses the average score of the rest samples as baseline + (suppose c1...cn n samples, reward1 = score1 - 1/(n-1)(score2+..+scoren) ) + """ + baseline = (scores.sum(1, keepdim=True) - scores) / (scores.shape[1] - 1) + scores = scores - baseline + # self cider used as reward to promote diversity (not working that much in this way) + if getattr(self.opt, 'self_cider_reward_weight', 0) > 0: + _scores = get_self_cider_scores(data_gts, seq, self.opt) + _scores = torch.from_numpy(_scores).type_as(scores).view(-1, 1) + _scores = _scores.expand_as(scores - 1) + scores += self.opt.self_cider_reward_weight * _scores + output = - input * mask * scores.view(-1, 1) + output = torch.sum(output) / torch.sum(mask) + + out['loss'] = output + return out + +class LanguageModelCriterion(nn.Module): + def __init__(self): + super(LanguageModelCriterion, self).__init__() + + def forward(self, input, target, mask): + if target.ndim == 3: + target = target.reshape(-1, target.shape[2]) + mask = mask.reshape(-1, mask.shape[2]) + # truncate to the same size + target = target[:, :input.size(1)] + mask = mask[:, :input.size(1)].to(input) + + output = -input.gather(2, 
target.unsqueeze(2)).squeeze(2) * mask + # Average over each token + output = torch.sum(output) / torch.sum(mask) + + return output + +class LabelSmoothing(nn.Module): + "Implement label smoothing." + def __init__(self, size=0, padding_idx=0, smoothing=0.0): + super(LabelSmoothing, self).__init__() + self.criterion = nn.KLDivLoss(size_average=False, reduce=False) + # self.padding_idx = padding_idx + self.confidence = 1.0 - smoothing + self.smoothing = smoothing + # self.size = size + self.true_dist = None + + def forward(self, input, target, mask): + if target.ndim == 3: + target = target.reshape(-1, target.shape[2]) + mask = mask.reshape(-1, mask.shape[2]) + # truncate to the same size + target = target[:, :input.size(1)] + mask = mask[:, :input.size(1)] + + input = input.reshape(-1, input.size(-1)) + target = target.reshape(-1) + mask = mask.reshape(-1).to(input) + + # assert x.size(1) == self.size + self.size = input.size(1) + # true_dist = x.data.clone() + true_dist = input.data.clone() + # true_dist.fill_(self.smoothing / (self.size - 2)) + true_dist.fill_(self.smoothing / (self.size - 1)) + true_dist.scatter_(1, target.data.unsqueeze(1), self.confidence) + # true_dist[:, self.padding_idx] = 0 + # mask = torch.nonzero(target.data == self.padding_idx) + # self.true_dist = true_dist + return (self.criterion(input, true_dist).sum(1) * mask).sum() / mask.sum() \ No newline at end of file diff --git a/captioning/utils/__init__.py b/captioning/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/captioning/utils/clipscore.py b/captioning/utils/clipscore.py new file mode 100644 index 0000000000000000000000000000000000000000..0345140d9f7b47e37b3a895915a135e1441c907b --- /dev/null +++ b/captioning/utils/clipscore.py @@ -0,0 +1,396 @@ +from transformers import CLIPModel, CLIPTokenizer +import os +import json +import argparse +from random import shuffle, seed +import string +# 
non-standard dependencies: +import h5py +from six.moves import cPickle +import numpy as np +import torch +import torchvision.models as models +import skimage.io + +from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize +from PIL import Image +from torch import nn + + +class CLIPScore(nn.Module): + def __init__(self, clipscore_w=2.5, image_size=224, mode='clip_s', use_grammar=False, joint_out=False): + super(CLIPScore, self).__init__() + # from transformers import CLIPModel, CLIPTokenizer + self.clip_model = CLIPModel.from_pretrained( + 'openai/clip-vit-base-patch32') + self.tokenizer = CLIPTokenizer.from_pretrained( + 'openai/clip-vit-base-patch32') + + self.clip_model.eval() + + self.clipscore_w = clipscore_w + + self.image_transform = self._transform(image_size) + + self.mode = mode + assert mode in ['clip_s', 'refclip_s'] + + self.use_grammar = use_grammar + self.joint_out = joint_out + + if self.use_grammar and joint_out is False: + self.grammar_score_head = nn.Sequential( + nn.Linear(self.clip_model.text_embed_dim, self.clip_model.projection_dim, bias=False), + nn.ReLU(), + nn.Linear(self.clip_model.projection_dim, 2, bias=False) + ) + + def _transform(self, n_px): + return Compose([ + Resize(n_px, interpolation=Image.BICUBIC), + CenterCrop(n_px), + lambda image: image.convert("RGB"), + ToTensor(), + Normalize((0.48145466, 0.4578275, 0.40821073), + (0.26862954, 0.26130258, 0.27577711)), + ]) + + def load_image(self, image_path): + image = Image.open(image_path) + return image + + # @torch.no_grad() + def image_extract(self, image): + if isinstance(image, str): + image = self.load_image(image) + if not isinstance(image, torch.Tensor): + image = self.image_transform(image) + + img_tensor = image.view(-1, 3, 224, 224) + device = next(self.clip_model.parameters()).device + img_tensor = img_tensor.to(device) + + clip_model = self.clip_model + + img_feat = clip_model.vision_model(img_tensor).pooler_output + img_feat = 
clip_model.visual_projection(img_feat) + img_feat = img_feat / img_feat.norm(dim=-1, keepdim=True) + + return img_feat + + # @torch.no_grad() + def text_extract(self, text, prompt="A photo depicts", proj_norm=True): + if isinstance(text, str): + text_batch = [" ".join([prompt, text])] + elif isinstance(text, list): + text_batch = [" ".join([prompt, txt]) for txt in text] + + if isinstance(text, tuple) and isinstance(text[0], torch.Tensor): + input_ids, attention_mask = text + else: + input_text = text_batch + + tokenized = self.tokenizer( + input_text, return_tensors='pt', padding=True, truncation=True) + + input_ids = tokenized.input_ids + attention_mask = tokenized.attention_mask + + clip_model = self.clip_model + device = next(self.clip_model.parameters()).device + input_ids = input_ids.to(device) + attention_mask = attention_mask.to(device) + + text_feat = clip_model.text_model(input_ids, attention_mask).pooler_output + + if proj_norm: + text_feat = clip_model.text_projection(text_feat) + text_feat = text_feat / text_feat.norm(dim=-1, keepdim=True) + + return text_feat + + # @torch.no_grad() + def calc_clip_s(self, img_feat, text_feat): + return self.clipscore_w * torch.relu((img_feat * text_feat).sum(dim=-1)) + + # @torch.no_grad() + def calc_refclip_s(self, img_feat=None, text_feat=None, ref_text_feat=None, ref_text_mask=None, clip_s=None): + + if clip_s is None: + clip_s = self.calc_clip_s(img_feat, text_feat) + + B, dim = img_feat.size() + + ref_text_feat = ref_text_feat.view(B, -1, dim) + + K = ref_text_feat.size(1) + + text_feat = text_feat.view(B, 1, dim).expand(-1, K, -1) + assert ref_text_feat.size() == text_feat.size( + ), (ref_text_feat.size(), text_feat.size()) + + ref_score = self.calc_clip_s(text_feat, ref_text_feat) + if ref_text_mask is not None: + if not isinstance(ref_text_mask, torch.Tensor): + ref_text_mask = torch.tensor( + ref_text_mask, dtype=ref_score.dtype, device=ref_score.device) + ref_score = ref_score.view(B, K) * 
ref_text_mask.view(B, K) + + ref_score = ref_score.view(B, K).max(dim=1).values + + assert clip_s.size() == (B,) + assert clip_s.size() == ref_score.size() + + # harmonic mean + refclip_s = 2 / (1 / clip_s + 1 / ref_score) + return refclip_s + + @torch.no_grad() + def forward(self, + images=None, text=None, + img_feat=None, text_feat=None, + ref_text=None, ref_text_feat=None, ref_text_mask=None, + prompt="A photo depicts", + mode=None): + if img_feat is None: + img_feat = self.image_extract(images) + img_feat = img_feat.view(-1, 512) + + B = img_feat.size(0) + + if text_feat is None: + text_feat = self.text_extract(text, prompt=prompt) + text_feat = text_feat.view(-1, 512) + + if mode is None: + mode = self.mode + assert mode in ['clip_s', 'refclip_s'] + + if mode == 'clip_s': + clip_s = self.calc_clip_s(img_feat, text_feat) + return clip_s + elif mode == 'refclip_s': + if ref_text_feat is None: + ref_text_feat = self.text_extract(ref_text, prompt=prompt) + ref_text_feat = ref_text_feat.view(-1, 512) + + refclip_s = self.calc_refclip_s( + img_feat, text_feat, ref_text_feat, ref_text_mask=ref_text_mask) + return refclip_s + + + def train_step(self, + images=None, text=None, + img_feat=None, text_feat=None, + neg_text=None, neg_text_feat=None, + # ref_text=None, ref_text_feat=None, ref_text_mask=None, + prompt="A photo depicts", + # return_loss=True, + **kwargs): + + if img_feat is None: + img_feat = self.image_extract(images) + img_feat = img_feat.view(-1, 512) + + B = img_feat.size(0) + + if text_feat is None: + text_feat = self.text_extract(text, prompt=prompt, proj_norm=False) + + text_cont_feat = self.clip_model.text_projection(text_feat) + text_cont_feat = text_cont_feat / text_cont_feat.norm(dim=-1, keepdim=True) + text_cont_feat = text_cont_feat.view(B, 512) + + # cosine similarity as logits + logit_scale = self.clip_model.logit_scale.exp() + logits_per_text = torch.matmul(text_cont_feat, img_feat.t()) * logit_scale + # logits_per_image = logits_per_text.T + 
+ clip_loss = clip_loss_fn(logits_per_text) + + + # negative sampling + pos_text_feat = text_feat.view(B, 512) + neg_text_feat = self.text_extract(neg_text, prompt=prompt, proj_norm=False).view(B, 512) + + grammar_text_feat = torch.cat([pos_text_feat, neg_text_feat], dim=0) + + # 2B, 1 + grammar_text_logit = self.grammar_score_head(grammar_text_feat) + grammar_labels = torch.LongTensor([1] * B + [0] * B).to(grammar_text_logit.device).view(2 * B) + + grammar_loss = torch.nn.functional.cross_entropy(grammar_text_logit, grammar_labels) + + grammar_pred = grammar_text_logit.argmax(dim=1, keepdim=False) + grammar_pos_pred = grammar_pred[:B] + grammar_neg_pred = grammar_pred[B:] + # grammar_acc = (grammar_pred == grammar_labels).float().mean() + + out = { + 'clip_loss': clip_loss, + 'grammar_loss': grammar_loss, + 'img_feat': img_feat, + 'text_feat': text_cont_feat, + 'neg_text_feat': neg_text_feat, + 'grammar_pos_pred': grammar_pos_pred, + 'grammar_neg_pred': grammar_neg_pred, + } + + return out + +# contrastive loss function, adapted from +# https://sachinruk.github.io/blog/pytorch/pytorch%20lightning/loss%20function/gpu/2021/03/07/CLIP.html +def contrastive_loss(logits: torch.Tensor, dim: int) -> torch.Tensor: + neg_ce = torch.diag(nn.functional.log_softmax(logits, dim=dim)) + return -neg_ce.mean() + + +def clip_loss_fn(similarity: torch.Tensor) -> torch.Tensor: + caption_loss = contrastive_loss(similarity, dim=0) + image_loss = contrastive_loss(similarity, dim=1) + return (caption_loss + image_loss) / 2.0 + + + +# class CLIPScore(nn.Module): +# def __init__(self, clipscore_w=2.5, image_size=224, mode='clip_s'): +# super(CLIPScore, self).__init__() +# # from transformers import CLIPModel, CLIPTokenizer +# self.clip_model = CLIPModel.from_pretrained( +# 'openai/clip-vit-base-patch32') +# self.tokenizer = CLIPTokenizer.from_pretrained( +# 'openai/clip-vit-base-patch32') + +# self.clip_model.eval() + +# self.clipscore_w = clipscore_w + +# self.image_transform = 
self._transform(image_size) + +# self.mode = mode +# assert mode in ['clip_s', 'refclip_s'] + +# def _transform(self, n_px): +# return Compose([ +# Resize(n_px, interpolation=Image.BICUBIC), +# CenterCrop(n_px), +# lambda image: image.convert("RGB"), +# ToTensor(), +# Normalize((0.48145466, 0.4578275, 0.40821073), +# (0.26862954, 0.26130258, 0.27577711)), +# ]) + +# def load_image(self, image_path): +# image = Image.open(image_path) +# return image + +# @torch.no_grad() +# def image_extract(self, image): +# if isinstance(image, str): +# image = self.load_image(image) +# if not isinstance(image, torch.Tensor): +# image = self.image_transform(image) + +# img_tensor = image.view(-1, 3, 224, 224) +# device = next(self.clip_model.parameters()).device +# img_tensor = img_tensor.to(device) + +# clip_model = self.clip_model + +# img_feat = clip_model.vision_model(img_tensor).pooler_output +# img_feat = clip_model.visual_projection(img_feat) +# img_feat = img_feat / img_feat.norm(dim=-1, keepdim=True) + +# return img_feat + +# @torch.no_grad() +# def text_extract(self, text, prompt="A photo depicts"): +# if isinstance(text, str): +# text_batch = [" ".join([prompt, text])] +# else: +# text_batch = [" ".join([prompt, txt]) for txt in text] + +# input_text = text_batch + +# tokenized = self.tokenizer( +# input_text, return_tensors='pt', padding=True) + +# input_ids = tokenized.input_ids +# attention_mask = tokenized.attention_mask + +# clip_model = self.clip_model +# device = next(self.clip_model.parameters()).device +# input_ids = input_ids.to(device) +# attention_mask = attention_mask.to(device) + +# text_feat = clip_model.text_model(input_ids, attention_mask).pooler_output +# text_feat = clip_model.text_projection(text_feat) +# text_feat = text_feat / text_feat.norm(dim=-1, keepdim=True) + +# return text_feat + +# @torch.no_grad() +# def calc_clip_s(self, img_feat, text_feat): +# return self.clipscore_w * torch.relu((img_feat * text_feat).sum(dim=-1)) + +# @torch.no_grad() 
+# def calc_refclip_s(self, img_feat=None, text_feat=None, ref_text_feat=None, ref_text_mask=None, clip_s=None): + +# if clip_s is None: +# clip_s = self.calc_clip_s(img_feat, text_feat) + +# B, dim = img_feat.size() + +# ref_text_feat = ref_text_feat.view(B, -1, dim) + +# K = ref_text_feat.size(1) + +# text_feat = text_feat.view(B, 1, dim).expand(-1, K, -1) +# assert ref_text_feat.size() == text_feat.size(), (ref_text_feat.size(), text_feat.size()) + +# ref_score = self.calc_clip_s(text_feat, ref_text_feat) +# if ref_text_mask is not None: +# if not isinstance(ref_text_mask, torch.Tensor): +# ref_text_mask = torch.tensor(ref_text_mask, dtype=ref_score.dtype, device=ref_score.device) +# ref_score = ref_score.view(B, K) * ref_text_mask.view(B, K) + +# ref_score = ref_score.view(B, K).max(dim=1).values + +# assert clip_s.size() == (B,) +# assert clip_s.size() == ref_score.size() + +# # harmonic mean +# refclip_s = 2 / (1 / clip_s + 1 / ref_score) +# return refclip_s + + +# @torch.no_grad() +# def forward(self, +# images=None, text=None, +# img_feat=None, text_feat=None, +# ref_text=None, ref_text_feat=None, ref_text_mask=None, +# prompt="A photo depicts", +# mode=None): +# if img_feat is None: +# img_feat = self.image_extract(images) +# img_feat = img_feat.view(-1, 512) + +# if text_feat is None: +# text_feat = self.text_extract(text, prompt=prompt) +# text_feat = text_feat.view(-1, 512) + +# if mode is None: +# mode = self.mode +# assert mode in ['clip_s', 'refclip_s'] + +# if mode == 'clip_s': +# clip_s = self.calc_clip_s(img_feat, text_feat) +# return clip_s +# elif mode == 'refclip_s': +# if ref_text_feat is None: +# ref_text_feat = self.text_extract(ref_text, prompt=prompt) +# ref_text_feat = ref_text_feat.view(-1, 512) + +# refclip_s = self.calc_refclip_s(img_feat, text_feat, ref_text_feat, ref_text_mask=ref_text_mask) +# return refclip_s + diff --git a/captioning/utils/config.py b/captioning/utils/config.py new file mode 100644 index 
0000000000000000000000000000000000000000..e42704dcba2fb2f751fec413551a5069e63f25c9 --- /dev/null +++ b/captioning/utils/config.py @@ -0,0 +1,153 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. +# Copy from fvcore + +import logging +import os +from typing import Any +import yaml +from yacs.config import CfgNode as _CfgNode + +import io as PathManager + +BASE_KEY = "_BASE_" + + +class CfgNode(_CfgNode): + """ + Our own extended version of :class:`yacs.config.CfgNode`. + It contains the following extra features: + + 1. The :meth:`merge_from_file` method supports the "_BASE_" key, + which allows the new CfgNode to inherit all the attributes from the + base configuration file. + 2. Keys that start with "COMPUTED_" are treated as insertion-only + "computed" attributes. They can be inserted regardless of whether + the CfgNode is frozen or not. + 3. With "allow_unsafe=True", it supports pyyaml tags that evaluate + expressions in config. See examples in + https://pyyaml.org/wiki/PyYAMLDocumentation#yaml-tags-and-python-types + Note that this may lead to arbitrary code execution: you must not + load a config file from untrusted sources before manually inspecting + the content of the file. + """ + + @staticmethod + def load_yaml_with_base(filename, allow_unsafe = False): + """ + Just like `yaml.load(open(filename))`, but inherit attributes from its + `_BASE_`. + + Args: + filename (str): the file name of the current config. Will be used to + find the base config file. + allow_unsafe (bool): whether to allow loading the config file with + `yaml.unsafe_load`. + + Returns: + (dict): the loaded yaml + """ + with PathManager.open(filename, "r") as f: + try: + cfg = yaml.safe_load(f) + except yaml.constructor.ConstructorError: + if not allow_unsafe: + raise + logger = logging.getLogger(__name__) + logger.warning( + "Loading config {} with yaml.unsafe_load. 
Your machine may " + "be at risk if the file contains malicious content.".format( + filename + ) + ) + f.close() + with open(filename, "r") as f: + cfg = yaml.unsafe_load(f) + + def merge_a_into_b(a, b): + # merge dict a into dict b. values in a will overwrite b. + for k, v in a.items(): + if isinstance(v, dict) and k in b: + assert isinstance( + b[k], dict + ), "Cannot inherit key '{}' from base!".format(k) + merge_a_into_b(v, b[k]) + else: + b[k] = v + + if BASE_KEY in cfg: + base_cfg_file = cfg[BASE_KEY] + if base_cfg_file.startswith("~"): + base_cfg_file = os.path.expanduser(base_cfg_file) + if not any( + map(base_cfg_file.startswith, ["/", "https://", "http://"]) + ): + # the path to base cfg is relative to the config file itself. + base_cfg_file = os.path.join( + os.path.dirname(filename), base_cfg_file + ) + base_cfg = CfgNode.load_yaml_with_base( + base_cfg_file, allow_unsafe=allow_unsafe + ) + del cfg[BASE_KEY] + + merge_a_into_b(cfg, base_cfg) + return base_cfg + return cfg + + def merge_from_file(self, cfg_filename, allow_unsafe = False): + """ + Merge configs from a given yaml file. + + Args: + cfg_filename: the file name of the yaml config. + allow_unsafe: whether to allow loading the config file with + `yaml.unsafe_load`. + """ + loaded_cfg = CfgNode.load_yaml_with_base( + cfg_filename, allow_unsafe=allow_unsafe + ) + loaded_cfg = type(self)(loaded_cfg) + self.merge_from_other_cfg(loaded_cfg) + + # Forward the following calls to base, but with a check on the BASE_KEY. + def merge_from_other_cfg(self, cfg_other): + """ + Args: + cfg_other (CfgNode): configs to merge from. + """ + assert ( + BASE_KEY not in cfg_other + ), "The reserved key '{}' can only be used in files!".format(BASE_KEY) + return super().merge_from_other_cfg(cfg_other) + + def merge_from_list(self, cfg_list): + """ + Args: + cfg_list (list): list of configs to merge from. 
+ """ + keys = set(cfg_list[0::2]) + assert ( + BASE_KEY not in keys + ), "The reserved key '{}' can only be used in files!".format(BASE_KEY) + return super().merge_from_list(cfg_list) + + def __setattr__(self, name, val): + if name.startswith("COMPUTED_"): + if name in self: + old_val = self[name] + if old_val == val: + return + raise KeyError( + "Computed attributed '{}' already exists " + "with a different value! old={}, new={}.".format( + name, old_val, val + ) + ) + self[name] = val + else: + super().__setattr__(name, val) + + +if __name__ == '__main__': + cfg = CfgNode.load_yaml_with_base('configs/updown_long.yml') + print(cfg) \ No newline at end of file diff --git a/captioning/utils/dist_utils.py b/captioning/utils/dist_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..53a7c462570edb8f381c65fabf60c729f1607f41 --- /dev/null +++ b/captioning/utils/dist_utils.py @@ -0,0 +1,305 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved +""" +This file contains primitives for multi-gpu communication. +This is useful when doing distributed training. +""" + +import functools +import logging +import numpy as np +import pickle +import torch +import torch.distributed as dist + +import torch + +_LOCAL_PROCESS_GROUP = None +""" +A torch process group which only includes processes that on the same machine as the current process. +This variable is set when processes are spawned by `launch()` in "engine/launch.py". +""" + + +def get_world_size() -> int: + if not dist.is_available(): + return 1 + if not dist.is_initialized(): + return 1 + return dist.get_world_size() + + +def get_rank() -> int: + if not dist.is_available(): + return 0 + if not dist.is_initialized(): + return 0 + return dist.get_rank() + + +def get_local_rank() -> int: + """ + Returns: + The rank of the current process within the local (per-machine) process group. 
+ """ + if not dist.is_available(): + return 0 + if not dist.is_initialized(): + return 0 + assert _LOCAL_PROCESS_GROUP is not None + return dist.get_rank(group=_LOCAL_PROCESS_GROUP) + + +def get_local_size() -> int: + """ + Returns: + The size of the per-machine process group, + i.e. the number of processes per machine. + """ + if not dist.is_available(): + return 1 + if not dist.is_initialized(): + return 1 + return dist.get_world_size(group=_LOCAL_PROCESS_GROUP) + + +def is_main_process() -> bool: + return get_rank() == 0 + + +def synchronize(): + """ + Helper function to synchronize (barrier) among all processes when + using distributed training + """ + if not dist.is_available(): + return + if not dist.is_initialized(): + return + world_size = dist.get_world_size() + if world_size == 1: + return + dist.barrier() + + +@functools.lru_cache() +def _get_global_gloo_group(): + """ + Return a process group based on gloo backend, containing all the ranks + The result is cached. + """ + if dist.get_backend() == "nccl": + return dist.new_group(backend="gloo") + else: + return dist.group.WORLD + + +def _serialize_to_tensor(data, group): + backend = dist.get_backend(group) + assert backend in ["gloo", "nccl"] + device = torch.device("cpu" if backend == "gloo" else "cuda") + + buffer = pickle.dumps(data) + if len(buffer) > 1024 ** 3: + logger = logging.getLogger(__name__) + logger.warning( + "Rank {} trying to all-gather {:.2f} GB of data on device {}".format( + get_rank(), len(buffer) / (1024 ** 3), device + ) + ) + storage = torch.ByteStorage.from_buffer(buffer) + tensor = torch.ByteTensor(storage).to(device=device) + return tensor + + +def _pad_to_largest_tensor(tensor, group): + """ + Returns: + list[int]: size of the tensor, on each rank + Tensor: padded tensor that has the max size + """ + world_size = dist.get_world_size(group=group) + assert ( + world_size >= 1 + ), "comm.gather/all_gather must be called from ranks within the given group!" 
def all_gather(data, group=None):
    """Gather one picklable object from every rank onto every rank.

    Args:
        data: any picklable object.
        group: a torch process group. When None, the group returned by
            ``_get_global_gloo_group()`` is used.

    Returns:
        list: one entry per rank in the group; ``[data]`` when either the
        global world size or the group's world size is 1.

    Note:
        Received bytes are passed to ``pickle.loads``, so this must only be
        used between ranks that trust each other.
    """
    if get_world_size() == 1:
        return [data]
    group = _get_global_gloo_group() if group is None else group
    if dist.get_world_size(group) == 1:
        return [data]

    payload = _serialize_to_tensor(data, group)
    sizes, payload = _pad_to_largest_tensor(payload, group)
    largest = max(sizes)

    # One equally sized receive buffer per rank.
    receive_buffers = [
        torch.empty((largest,), dtype=torch.uint8, device=payload.device)
        for _ in sizes
    ]
    dist.all_gather(receive_buffers, payload, group=group)

    # Cut each buffer back to its true length before unpickling.
    return [
        pickle.loads(received.cpu().numpy().tobytes()[:nbytes])
        for nbytes, received in zip(sizes, receive_buffers)
    ]
def shared_random_seed():
    """Return a random integer that every worker agrees on.

    Each worker draws its own value below ``2 ** 31`` and the values are
    exchanged with ``all_gather``; the first gathered entry is returned.

    Returns:
        int: the shared seed, usable to build a common RNG.

    Every worker has to call this function, otherwise the collective will
    deadlock.
    """
    return all_gather(np.random.randint(2 ** 31))[0]
def reduce_dict(input_dict, average=True):
    """Reduce the values of a dict across processes onto rank 0.

    Args:
        input_dict (dict): values to reduce; they may be tensors (including
            Tensor subclasses) or plain numbers.
        average (bool): if True, rank 0 divides the summed values by the
            world size; otherwise rank 0 keeps the sum.

    Returns:
        dict: same keys as ``input_dict``. With fewer than two processes the
        input dict itself is returned unchanged. Otherwise the values are
        CUDA tensors; only rank 0 is the reduce destination.
    """
    world_size = get_world_size()
    if world_size < 2:
        return input_dict

    with torch.no_grad():
        # dist.reduce() below operates on CUDA tensors, so everything is
        # moved or wrapped there. isinstance (rather than an exact type
        # comparison) keeps Tensor subclasses on the cheap `.to()` path.
        cuda_vals = {
            k: v.to('cuda') if isinstance(v, torch.Tensor)
            else torch.tensor(v, device='cuda')
            for k, v in input_dict.items()
        }

        # Sort the keys so every process stacks the values in the same order.
        names = sorted(cuda_vals)
        values = torch.stack([cuda_vals[k] for k in names], dim=0)
        dist.reduce(values, dst=0)  # reduce onto rank 0

        if dist.get_rank() == 0 and average:
            # Only the destination rank holds the accumulated sum, so only it
            # divides. Out-of-place division also works for integer tensors,
            # where an in-place `/=` raises.
            values = values / world_size
        reduced_dict = dict(zip(names, values))
    return reduced_dict
def find_ngrams(input_list, n):
    """Return an iterator over the consecutive n-grams (tuples) of a list."""
    shifted = (input_list[offset:] for offset in range(n))
    return zip(*shifted)


def compute_div_n(caps, n=1):
    """Per-image n-gram diversity.

    For every key in ``caps`` the number of distinct n-grams over all of its
    captions is divided by the total number of whitespace-separated tokens
    (plus 1e-6 to avoid dividing by zero).

    Args:
        caps: mapping from image id to a list of caption strings.
        n: n-gram order.

    Returns:
        (mean, per_image): the mean ratio and an array with one ratio per key.
    """
    ratios = []
    for image_id in caps:
        distinct = set()
        token_total = 0.
        for caption in caps[image_id]:
            tokens = caption.split()
            token_total += len(tokens)
            distinct.update(find_ngrams(tokens, n))
        ratios.append(float(len(distinct)) / (1e-6 + float(token_total)))
    return np.array(ratios).mean(), np.array(ratios)


def compute_global_div_n(caps, n=1):
    """Corpus-level n-gram diversity over all captions of all images.

    For ``n == 1`` the score is the raw count of distinct unigrams; for any
    other ``n`` it is the distinct n-gram count divided by the total token
    count (plus 1e-6).

    Returns:
        (score, repeated): the score and an array repeating it once per key
        of ``caps``.
    """
    distinct = set()
    token_total = 0.
    for image_id in caps:
        for caption in caps[image_id]:
            tokens = caption.split()
            token_total += len(tokens)
            distinct.update(find_ngrams(tokens, n))

    if n == 1:
        score = float(len(distinct))
    else:
        score = float(len(distinct)) / (1e-6 + float(token_total))
    return score, np.repeat(np.array([score]), len(caps))
def eval_allspice(dataset, preds_n, model_id, split):
    """Evaluate multi-caption predictions with the SPICE-only COCO evaluator.

    Args:
        dataset: dataset name, forwarded to ``getCOCO``.
        preds_n: list of prediction dicts; each needs an ``'image_id'`` key.
        model_id: model identifier used in the cache file name.
        split: split name used in the cache file name.

    Returns:
        dict with ``'overall'`` (every evaluator metric prefixed with
        ``'All'``, plus the NaN-free mean of each SPICE sub-category as
        ``'AllSPICE_<category>'``) and ``'imgToEvalAllSPICE'`` (per-image
        results, each with the image's predictions under ``'caption'``).

    Side effects:
        Writes the filtered predictions to
        ``eval_results/<model_id>_<split>_n.json`` and prints how many
        predictions were kept.
    """
    coco = getCOCO(dataset)
    # A set gives O(1) membership tests in the filter below.
    valids = set(coco.getImgIds())

    # Group predictions by image; append in place instead of rebuilding the
    # list for every prediction.
    capsById = {}
    for d in preds_n:
        capsById.setdefault(d['image_id'], []).append(d)

    # Keep only predictions for images known to the annotation file.
    preds_filt_n = [p for p in preds_n if p['image_id'] in valids]
    print('using %d/%d predictions_n' % (len(preds_filt_n), len(preds_n)))
    cache_path_n = os.path.join('eval_results/', model_id + '_' + split + '_n.json')
    # The COCO API loads results from a file; close (and thereby flush) the
    # file before loadRes() reads it back.
    with open(cache_path_n, 'w') as cache_file:
        json.dump(preds_filt_n, cache_file)

    # Eval AllSPICE
    cocoRes_n = coco.loadRes(cache_path_n)
    cocoEvalAllSPICE = COCOEvalCapSpice(coco, cocoRes_n)
    cocoEvalAllSPICE.params['image_id'] = cocoRes_n.getImgIds()
    cocoEvalAllSPICE.evaluate()

    out = {'All' + metric: score for metric, score in cocoEvalAllSPICE.eval.items()}

    imgToEvalAllSPICE = cocoEvalAllSPICE.imgToEval
    # Collect the SPICE sub-scores; category names come from the first image.
    for k in list(imgToEvalAllSPICE.values())[0]['SPICE'].keys():
        if k != 'All':
            sub_scores = np.array([v['SPICE'][k]['f'] for v in imgToEvalAllSPICE.values()])
            # x == x is False only for NaN, so this drops NaN entries.
            out['AllSPICE_' + k] = sub_scores[sub_scores == sub_scores].mean()
    for p in preds_filt_n:
        image_id = p['image_id']
        imgToEvalAllSPICE[image_id]['caption'] = capsById[image_id]
    return {'overall': out, 'imgToEvalAllSPICE': imgToEvalAllSPICE}
+ + cocoRes = coco.loadRes(cache_path) + cocoEval = COCOEvalCap(coco, cocoRes) + cocoEval.params['image_id'] = cocoRes.getImgIds() + cocoEval.evaluate() + + imgToEval = cocoEval.imgToEval + for img_id in capsById.keys(): + tmp = imgToEval[img_id] + for k in tmp['SPICE'].keys(): + if k != 'All': + tmp['SPICE_'+k] = tmp['SPICE'][k]['f'] + if tmp['SPICE_'+k] != tmp['SPICE_'+k]: # nan + tmp['SPICE_'+k] = -100 + tmp['SPICE'] = tmp['SPICE']['All']['f'] + if tmp['SPICE'] != tmp['SPICE']: tmp['SPICE'] = -100 + capsById[img_id][i]['scores'] = imgToEval[img_id] + + out = {'overall': {}, 'ImgToEval': {}} + for img_id in capsById.keys(): + out['ImgToEval'][img_id] = {} + for metric in capsById[img_id][0]['scores'].keys(): + if metric == 'image_id': continue + out['ImgToEval'][img_id]['oracle_'+metric] = max([_['scores'][metric] for _ in capsById[img_id]]) + out['ImgToEval'][img_id]['avg_'+metric] = sum([_['scores'][metric] for _ in capsById[img_id]]) / len(capsById[img_id]) + out['ImgToEval'][img_id]['captions'] = capsById[img_id] + for metric in list(out['ImgToEval'].values())[0].keys(): + if metric == 'captions': + continue + tmp = np.array([_[metric] for _ in out['ImgToEval'].values()]) + tmp = tmp[tmp!=-100] + out['overall'][metric] = tmp.mean() + + return out + +def eval_div_stats(dataset, preds_n, model_id, split): + tokenizer = PTBTokenizer() + + capsById = {} + for i, d in enumerate(preds_n): + d['id'] = i + capsById[d['image_id']] = capsById.get(d['image_id'], []) + [d] + + n_caps_perimg = len(capsById[list(capsById.keys())[0]]) + print(n_caps_perimg) + _capsById = capsById # save the untokenized version + capsById = tokenizer.tokenize(capsById) + + div_1, adiv_1 = compute_div_n(capsById,1) + div_2, adiv_2 = compute_div_n(capsById,2) + + globdiv_1, _= compute_global_div_n(capsById,1) + + print('Diversity Statistics are as follows: \n Div1: %.2f, Div2: %.2f, gDiv1: %d\n'%(div_1,div_2, globdiv_1)) + + # compute mbleu + scorer = Bleu(4) + all_scrs = [] + scrperimg = 
def eval_self_cider(dataset, preds_n, model_id, split):
    """Compute the Self-CIDEr diversity score of multi-caption predictions.

    Args:
        dataset: dataset name, forwarded to ``getCOCO``.
        preds_n: list of prediction dicts; each needs an ``'image_id'`` key.
        model_id: unused here; kept for a signature parallel to the other
            ``eval_*`` functions.
        split: unused here; kept for the same reason.

    Returns:
        dict with ``'overall'`` (``{'self_cider': mean score}``) and
        ``'imgToEval'`` (per image: its ``'self_cider'`` score and the
        ``'self_cider_mat'`` matrix as nested lists).
    """
    coco = getCOCO(dataset)
    valids = coco.getImgIds()

    # Get Cider_scorer
    Cider_scorer = Cider(df='corpus')

    tokenizer = PTBTokenizer()
    gts = tokenizer.tokenize({imgId: coco.imgToAnns[imgId] for imgId in valids})

    # Feed the references to the scorer so document frequencies come from
    # the annotation corpus.
    for imgId in valids:
        Cider_scorer.cider_scorer += (None, gts[imgId])
    Cider_scorer.cider_scorer.compute_doc_freq()
    Cider_scorer.cider_scorer.ref_len = np.log(float(len(Cider_scorer.cider_scorer.crefs)))

    # Prepare captions: group predictions by image, appending in place.
    capsById = {}
    for d in preds_n:
        capsById.setdefault(d['image_id'], []).append(d)

    capsById = tokenizer.tokenize(capsById)
    imgIds = list(capsById.keys())
    scores = Cider_scorer.my_self_cider([capsById[_] for _ in imgIds])

    def get_div(eigvals):
        # Negative eigenvalues are clipped to zero before taking roots.
        eigvals = np.clip(eigvals, 0, None)
        return -np.log(np.sqrt(eigvals[-1]) / (np.sqrt(eigvals).sum())) / np.log(len(eigvals))

    sc_scores = [get_div(np.linalg.eigvalsh(_ / 10)) for _ in scores]
    score = np.mean(np.array(sc_scores))

    imgToEval = {
        image_id: {'self_cider': sc_scores[i], 'self_cider_mat': scores[i].tolist()}
        for i, image_id in enumerate(imgIds)
    }
    return {'overall': {'self_cider': score}, 'imgToEval': imgToEval}
def getCOCO(dataset):
    """Build the COCO annotation object for the given dataset name.

    Args:
        dataset (str): dataset name. A name containing ``'coco'`` selects the
            COCO val2014 caption annotations; a name containing
            ``'flickr30k'`` or ``'f30k'`` selects the Flickr30k file.

    Returns:
        The ``COCO`` object built from the selected annotation file.

    Raises:
        ValueError: if the name matches none of the known datasets.
    """
    if 'coco' in dataset:
        annFile = 'coco-caption/annotations/captions_val2014.json'
    elif 'flickr30k' in dataset or 'f30k' in dataset:
        annFile = 'data/f30k_captions4eval.json'
    else:
        # Fail with a clear message instead of an UnboundLocalError on annFile.
        raise ValueError(
            "Unknown dataset '{}': expected a name containing 'coco', "
            "'flickr30k' or 'f30k'".format(dataset))
    return COCO(annFile)
len(preds_filt) + mean_entropy = sum([_['entropy'] for _ in preds_filt]) / len(preds_filt) + print('using %d/%d predictions' % (len(preds_filt), len(preds))) + json.dump(preds_filt, open(cache_path, 'w')) # serialize to temporary json file. Sigh, COCO API... + + cocoRes = coco.loadRes(cache_path) + cocoEval = COCOEvalCap(coco, cocoRes) + cocoEval.params['image_id'] = cocoRes.getImgIds() + cocoEval.evaluate() + + for metric, score in cocoEval.eval.items(): + out[metric] = score + # Add mean perplexity + out['perplexity'] = mean_perplexity + out['entropy'] = mean_entropy + + imgToEval = cocoEval.imgToEval + for k in list(imgToEval.values())[0]['SPICE'].keys(): + if k != 'All': + out['SPICE_'+k] = np.array([v['SPICE'][k]['f'] for v in imgToEval.values()]) + out['SPICE_'+k] = (out['SPICE_'+k][out['SPICE_'+k]==out['SPICE_'+k]]).mean() + for p in preds_filt: + image_id, caption = p['image_id'], p['caption'] + imgToEval[image_id]['caption'] = caption + + if len(preds_n) > 0: + from . import eval_multi + cache_path_n = os.path.join('eval_results/', '.cache_'+ model_id + '_' + split + '_n.json') + allspice = eval_multi.eval_allspice(dataset, preds_n, model_id, split) + out.update(allspice['overall']) + div_stats = eval_multi.eval_div_stats(dataset, preds_n, model_id, split) + out.update(div_stats['overall']) + if eval_oracle: + oracle = eval_multi.eval_oracle(dataset, preds_n, model_id, split) + out.update(oracle['overall']) + else: + oracle = None + self_cider = eval_multi.eval_self_cider(dataset, preds_n, model_id, split) + out.update(self_cider['overall']) + with open(cache_path_n, 'w') as outfile: + json.dump({'allspice': allspice, 'div_stats': div_stats, 'oracle': oracle, 'self_cider': self_cider}, outfile) + + out['bad_count_rate'] = sum([count_bad(_['caption']) for _ in preds_filt]) / float(len(preds_filt)) + outfile_path = os.path.join('eval_results/', model_id + '_' + split + '.json') + with open(outfile_path, 'w') as outfile: + json.dump({'overall': out, 
'imgToEval': imgToEval}, outfile) + + return out + +def eval_split(model, crit, loader, eval_kwargs={}): + verbose = eval_kwargs.get('verbose', True) + verbose_beam = eval_kwargs.get('verbose_beam', 0) + verbose_loss = eval_kwargs.get('verbose_loss', 1) + num_images = eval_kwargs.get('num_images', eval_kwargs.get('val_images_use', -1)) + split = eval_kwargs.get('split', 'val') + lang_eval = eval_kwargs.get('language_eval', 0) + dataset = eval_kwargs.get('dataset', 'coco') + beam_size = eval_kwargs.get('beam_size', 1) + sample_n = eval_kwargs.get('sample_n', 1) + remove_bad_endings = eval_kwargs.get('remove_bad_endings', 0) + os.environ["REMOVE_BAD_ENDINGS"] = str(remove_bad_endings) # Use this nasty way to make other code clean since it's a global configuration + device = eval_kwargs.get('device', 'cuda') + + # Make sure in the evaluation mode + model.eval() + + loader.reset_iterator(split) + + n = 0 + loss = 0 + loss_sum = 0 + loss_evals = 1e-8 + predictions = [] + n_predictions = [] # when sample_n > 1 + while True: + data = loader.get_batch(split) + n = n + len(data['infos']) + + tmp = [data['fc_feats'], data['att_feats'], data['labels'], data['masks'], data['att_masks']] + tmp = [_.to(device) if _ is not None else _ for _ in tmp] + fc_feats, att_feats, labels, masks, att_masks = tmp + if labels is not None and verbose_loss: + # forward the model to get loss + with torch.no_grad(): + loss = crit(model(fc_feats, att_feats, labels[..., :-1], att_masks), labels[..., 1:], masks[..., 1:]).item() + loss_sum = loss_sum + loss + loss_evals = loss_evals + 1 + + # forward the model to also get generated samples for each image + with torch.no_grad(): + tmp_eval_kwargs = eval_kwargs.copy() + tmp_eval_kwargs.update({'sample_n': 1}) + seq, seq_logprobs = model(fc_feats, att_feats, att_masks, opt=tmp_eval_kwargs, mode='sample') + seq = seq.data + entropy = - (F.softmax(seq_logprobs, dim=2) * seq_logprobs).sum(2).sum(1) / ((seq>0).to(seq_logprobs).sum(1)+1) + perplexity = - 
seq_logprobs.gather(2, seq.unsqueeze(2)).squeeze(2).sum(1) / ((seq>0).to(seq_logprobs).sum(1)+1) + + # Print beam search + if beam_size > 1 and verbose_beam: + for i in range(fc_feats.shape[0]): + print('\n'.join([utils.decode_sequence(model.vocab, _['seq'].unsqueeze(0))[0] for _ in model.done_beams[i]])) + print('--' * 10) + sents = utils.decode_sequence(model.vocab, seq) + + for k, sent in enumerate(sents): + entry = {'image_id': data['infos'][k]['id'], 'caption': sent, 'perplexity': perplexity[k].item(), 'entropy': entropy[k].item()} + if eval_kwargs.get('dump_path', 0) == 1: + entry['file_name'] = data['infos'][k]['file_path'] + predictions.append(entry) + if eval_kwargs.get('dump_images', 0) == 1: + # dump the raw image to vis/ folder + cmd = 'cp "' + os.path.join(eval_kwargs['image_root'], data['infos'][k]['file_path']) + '" vis/imgs/img' + str(len(predictions)) + '.jpg' # bit gross + print(cmd) + os.system(cmd) + + if verbose: + print('image %s: %s' %(entry['image_id'], entry['caption'])) + + if sample_n > 1: + eval_split_n(model, n_predictions, [fc_feats, att_feats, att_masks, data], eval_kwargs) + + # ix0 = data['bounds']['it_pos_now'] + ix1 = data['bounds']['it_max'] + if num_images != -1: + ix1 = min(ix1, num_images) + else: + num_images = ix1 + for i in range(n - ix1): + predictions.pop() + + if verbose: + print('evaluating validation preformance... 
def eval_split_n(model, n_predictions, input_data, eval_kwargs={}):
    """Generate ``sample_n`` captions per image and append them to ``n_predictions``.

    Args:
        model: captioning model, called with ``mode='sample'``; its ``vocab``
            and (for the beam-search methods) ``done_beams`` are read.
        n_predictions (list): extended in place with one dict per generated
            caption (``'image_id'``, ``'caption'``, and ``'perplexity'`` for
            the sampling methods).
        input_data: ``[fc_feats, att_feats, att_masks, data]`` for one batch.
        eval_kwargs (dict): options. The keys read here are ``verbose``,
            ``beam_size``, ``sample_n`` and ``sample_n_method``; the dict is
            only read and copied, never mutated.

    ``sample_n_method`` selects the strategy: ``'bs'`` (beam search),
    ``'sample'`` / ``'gumbel'`` / ``'top*'`` (sampling), ``'dbs'`` (diverse
    beam search); any other value is forwarded with its first character
    stripped as ``sample_method``.
    """
    verbose = eval_kwargs.get('verbose', True)
    beam_size = eval_kwargs.get('beam_size', 1)
    sample_n = eval_kwargs.get('sample_n', 1)
    sample_n_method = eval_kwargs.get('sample_n_method', 'sample')

    fc_feats, att_feats, att_masks, data = input_data
    batch_size = fc_feats.shape[0]

    tmp_eval_kwargs = eval_kwargs.copy()
    if sample_n_method == 'bs':
        # case 1 sample_n == beam size
        tmp_eval_kwargs.update({'sample_n': 1, 'beam_size': sample_n, 'group_size': 1})  # randomness from softmax
        with torch.no_grad():
            model(fc_feats, att_feats, att_masks, opt=tmp_eval_kwargs, mode='sample')
        for k in range(batch_size):
            _sents = utils.decode_sequence(model.vocab, torch.stack([model.done_beams[k][_]['seq'] for _ in range(sample_n)]))
            for sent in _sents:
                entry = {'image_id': data['infos'][k]['id'], 'caption': sent}
                n_predictions.append(entry)
    # case 2 sample / gumbel / topk sampling/ nucleus sampling
    elif sample_n_method == 'sample' or \
            sample_n_method == 'gumbel' or \
            sample_n_method.startswith('top'):
        tmp_eval_kwargs.update({'sample_n': sample_n, 'sample_method': sample_n_method, 'beam_size': 1})  # randomness from sample
        with torch.no_grad():
            _seq, _sampleLogprobs = model(fc_feats, att_feats, att_masks, opt=tmp_eval_kwargs, mode='sample')
        _sents = utils.decode_sequence(model.vocab, _seq)
        _perplexity = - _sampleLogprobs.gather(2, _seq.unsqueeze(2)).squeeze(2).sum(1) / ((_seq>0).to(_sampleLogprobs).sum(1)+1)
        for k, sent in enumerate(_sents):
            # Captions come out grouped per image, sample_n in a row.
            entry = {'image_id': data['infos'][k // sample_n]['id'], 'caption': sent, 'perplexity': _perplexity[k].item()}
            n_predictions.append(entry)
    elif sample_n_method == 'dbs':
        # Use diverse beam search
        tmp_eval_kwargs.update({'beam_size': sample_n * beam_size, 'group_size': sample_n})  # randomness from softmax
        with torch.no_grad():
            model(fc_feats, att_feats, att_masks, opt=tmp_eval_kwargs, mode='sample')
        # Take the batch size from the features, as the 'bs' branch does;
        # no `loader` exists in this function's scope.
        for k in range(batch_size):
            _sents = utils.decode_sequence(model.vocab, torch.stack([model.done_beams[k][_]['seq'] for _ in range(0, sample_n*beam_size, beam_size)]))
            for sent in _sents:
                entry = {'image_id': data['infos'][k]['id'], 'caption': sent}
                n_predictions.append(entry)
    else:
        tmp_eval_kwargs.update({'sample_method': sample_n_method[1:], 'group_size': sample_n, 'beam_size':1})  # randomness from softmax
        with torch.no_grad():
            _seq, _sampleLogprobs = model(fc_feats, att_feats, att_masks, opt=tmp_eval_kwargs, mode='sample')
        _sents = utils.decode_sequence(model.vocab, _seq)
        for k, sent in enumerate(_sents):
            entry = {'image_id': data['infos'][k // sample_n]['id'], 'caption': sent}
            n_predictions.append(entry)
    if verbose:
        # Print only the entries added for this batch, ordered by image id.
        for entry in sorted(n_predictions[-batch_size * sample_n:], key=lambda x: x['image_id']):
            print('image %s: %s' %(entry['image_id'], entry['caption']))
def pickle_load(f):
    """ Load a pickle.
    Parameters
    ----------
    f: file-like object

    Under Python 3 the pickle is decoded with ``encoding='latin-1'``; under
    Python 2 it is loaded with the defaults.
    """
    load_kwargs = {'encoding': 'latin-1'} if six.PY3 else {}
    return cPickle.load(f, **load_kwargs)
def decode_sequence(ix_to_word, seq):
    """Convert a 2-D batch of token indices into caption strings.

    Args:
        ix_to_word: mapping from the *string* form of an index to its word.
        seq: 2-D array/tensor of indices whose elements support ``.item()``.
            Each row is read left to right up to the first index that is not
            greater than 0.

    Returns:
        list[str]: one caption per row, words joined by single spaces and
        every ``'@@ '`` marker removed. When the environment variable
        ``REMOVE_BAD_ENDINGS`` is a non-zero integer, trailing words listed
        in ``bad_endings`` are dropped first.
    """
    num_rows, num_cols = seq.shape
    captions = []
    for row in range(num_rows):
        pieces = []
        for col in range(num_cols):
            token = seq[row, col]
            if token > 0:
                pieces.append(ix_to_word[str(token.item())])
            else:
                break
        sentence = ' '.join(pieces)

        if int(os.getenv('REMOVE_BAD_ENDINGS', '0')):
            tokens = sentence.split(' ')
            # Walk backwards to the last word that is an acceptable ending;
            # if every word is a bad ending the caption is left untouched.
            keep = len(tokens)
            for back, tok in enumerate(reversed(tokens)):
                if tok not in bad_endings:
                    keep = len(tokens) - back
                    break
            sentence = ' '.join(tokens[:keep])

        captions.append(sentence.replace('@@ ', ''))
    return captions
def penalty_builder(penalty_config):
    """Build a length-penalty function from a config string.

    Args:
        penalty_config (str): ``''`` for no penalty, otherwise
            ``'<type>_<alpha>'`` where ``<type>`` is ``'wu'`` or ``'avg'``
            and ``<alpha>`` parses as a float.

    Returns:
        A callable ``(length, logprobs) -> penalized logprobs``.

    Raises:
        ValueError: if the string is not of the form ``'<type>_<alpha>'``,
            ``<alpha>`` is not a float, or ``<type>`` is unknown.
    """
    if penalty_config == '':
        return lambda x, y: y
    pen_type, alpha = penalty_config.split('_')
    alpha = float(alpha)
    if pen_type == 'wu':
        return lambda x, y: length_wu(x, y, alpha)
    if pen_type == 'avg':
        return lambda x, y: length_average(x, y, alpha)
    # Fail here rather than hand back None and crash later when it is called.
    raise ValueError(
        "Unknown length penalty type '{}' in '{}': expected 'wu' or 'avg'".format(
            pen_type, penalty_config))


def length_wu(length, logprobs, alpha=0.):
    """
    NMT length re-ranking score from
    "Google's Neural Machine Translation System" :cite:`wu2016google`.

    Divides ``logprobs`` by ``((5 + length) / 6) ** alpha``.
    """
    modifier = (((5 + length) ** alpha) /
                ((5 + 1) ** alpha))
    return (logprobs / modifier)


def length_average(length, logprobs, alpha=0.):
    """
    Returns ``logprobs`` divided by ``length``.

    ``alpha`` is accepted for a signature parallel to ``length_wu`` but is
    not used.
    """
    return logprobs / length
class ReduceLROnPlateau(object):
    """Optimizer wrapper that pairs an optimizer with a ReduceLROnPlateau scheduler.

    ``step()`` steps the wrapped optimizer and ``scheduler_step(val)`` feeds a
    metric to the scheduler. Any attribute not defined here is looked up on
    the wrapped optimizer.
    """

    def __init__(self, optimizer, mode='min', factor=0.1, patience=10, verbose=False, threshold=0.0001, threshold_mode='rel', cooldown=0, min_lr=0, eps=1e-08):
        # Pass the options by keyword: the positional order of
        # torch.optim.lr_scheduler.ReduceLROnPlateau has changed between
        # torch releases (``verbose`` moved to the end), so positional
        # passing binds values to the wrong parameters.
        scheduler_kwargs = dict(
            mode=mode, factor=factor, patience=patience,
            threshold=threshold, threshold_mode=threshold_mode,
            cooldown=cooldown, min_lr=min_lr, eps=eps)
        if verbose:
            # Forward ``verbose`` only when requested, since newer torch
            # releases deprecate the argument.
            scheduler_kwargs['verbose'] = verbose
        self.scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, **scheduler_kwargs)
        self.optimizer = optimizer
        self.current_lr = get_lr(optimizer)

    def step(self):
        "Update parameters and rate"
        self.optimizer.step()

    def scheduler_step(self, val):
        """Feed the metric ``val`` to the scheduler and record the resulting lr."""
        self.scheduler.step(val)
        self.current_lr = get_lr(self.optimizer)

    def state_dict(self):
        """Return the current lr together with scheduler and optimizer state."""
        return {'current_lr': self.current_lr,
                'scheduler_state_dict': self.scheduler.state_dict(),
                'optimizer_state_dict': self.optimizer.state_dict()}

    def load_state_dict(self, state_dict):
        """Restore from either a plain optimizer state or this wrapper's state."""
        if 'current_lr' not in state_dict:
            # It is a plain optimizer state dict.
            self.optimizer.load_state_dict(state_dict)
            set_lr(self.optimizer, self.current_lr)  # use the lr from the option
        else:
            # It is a state dict produced by this wrapper.
            self.current_lr = state_dict['current_lr']
            self.scheduler.load_state_dict(state_dict['scheduler_state_dict'])
            self.optimizer.load_state_dict(state_dict['optimizer_state_dict'])
            # current_lr is actually useless in this case

    def rate(self, step=None):
        """Return the current learning rate.

        ``step`` is accepted for a signature parallel to ``NoamOpt.rate`` but
        is ignored. The previous body used the warmup formula on attributes
        (``_step``, ``factor``, ``model_size``, ``warmup``) that this class
        never sets, so every call failed.
        """
        return self.current_lr

    def __getattr__(self, name):
        # __getattr__ runs only after normal lookup fails. If the wrapper's
        # own attributes are missing (e.g. on an instance built without
        # __init__), delegating would look up self.optimizer again and
        # recurse without end.
        if name in ('optimizer', 'scheduler'):
            raise AttributeError(name)
        return getattr(self.optimizer, name)
from __future__ import print_function
import argparse
import pprint


def if_use_feat(caption_model):
    """Return ``(use_fc, use_att)`` for the given caption model name.

    Decides which image features (fc and/or attention) a model consumes.
    Unknown model names fall back to attention features only.
    """
    if caption_model in ['show_tell', 'all_img', 'fc', 'newfc']:
        use_att, use_fc = False, True
    elif caption_model == 'language_model':
        use_att, use_fc = False, False
    elif caption_model in ['updown', 'topdown']:
        use_fc, use_att = True, True
    else:
        use_att, use_fc = True, False
    return use_fc, use_att


class Config(object):
    def __init__(self, **kwargs):
        """Configuration Class: set kwargs as class attributes with setattr"""
        for k, v in kwargs.items():
            setattr(self, k, v)

    @property
    def config_str(self):
        """Pretty-printed dump of every attribute stored on the instance."""
        return pprint.pformat(self.__dict__)

    def __repr__(self):
        """Pretty-print configurations in alphabetical order"""
        config_str = 'Configurations\n'
        config_str += self.config_str
        return config_str


def parse_opt(parse=True, **optional_kwargs):
    """Build the training option set and return it as a ``Config``.

    Args:
        parse: if True, parse ``sys.argv`` strictly; if False, ignore
            unknown arguments (useful in interactive sessions).
        **optional_kwargs: values that overwrite the parsed defaults before
            the config file and the final command-line pass are applied.

    Returns:
        A ``Config`` holding every option plus the derived ``use_fc`` and
        ``use_att`` flags.

    Raises:
        AssertionError: if one of the sanity checks on the options fails.
    """
    import os

    parser = argparse.ArgumentParser()
    # Data input settings
    parser.add_argument('--input_json', type=str, default='data/coco.json',
                        help='path to the json file containing additional info and vocab')
    parser.add_argument('--input_fc_dir', type=str, default='data/cocotalk_fc',
                        help='path to the directory containing the preprocessed fc feats')
    parser.add_argument('--input_att_dir', type=str, default='data/cocotalk_att',
                        help='path to the directory containing the preprocessed att feats')
    parser.add_argument('--input_box_dir', type=str, default='data/cocotalk_box',
                        help='path to the directory containing the boxes of att feats')
    parser.add_argument('--input_label_h5', type=str, default='data/coco_label.h5',
                        help='path to the h5file containing the preprocessed dataset')
    parser.add_argument('--data_in_memory', action='store_true',
                        help='True if we want to save the features in memory')
    parser.add_argument('--start_from', type=str, default=None,
                        help="""continue training from saved model at this path. Path must contain files saved by previous training process:
                            'infos.pkl' : configuration;
                            'model.pth' : weights
                        """)
    parser.add_argument('--cached_tokens', type=str, default='coco-train-idxs',
                        help='Cached token file for calculating cider score during self critical training.')

    # Model settings
    parser.add_argument('--caption_model', type=str, default="show_tell",
                        help='show_tell, show_attend_tell, all_img, fc, att2in, att2in2, att2all2, adaatt, adaattmo, updown, stackatt, denseatt, transformer')
    parser.add_argument('--rnn_size', type=int, default=512,
                        help='size of the rnn in number of hidden nodes in each layer')
    parser.add_argument('--num_layers', type=int, default=1,
                        help='number of layers in the RNN')
    parser.add_argument('--rnn_type', type=str, default='lstm',
                        help='rnn, gru, or lstm')
    parser.add_argument('--input_encoding_size', type=int, default=512,
                        help='the encoding size of each token in the vocabulary, and the image.')
    parser.add_argument('--att_hid_size', type=int, default=512,
                        help='the hidden size of the attention MLP; only useful in show_attend_tell; 0 if not using hidden layer')
    parser.add_argument('--fc_feat_size', type=int, default=2048,
                        help='2048 for resnet, 4096 for vgg')
    parser.add_argument('--att_feat_size', type=int, default=2048,
                        help='2048 for resnet, 512 for vgg')
    parser.add_argument('--logit_layers', type=int, default=1,
                        help='number of layers in the RNN')

    parser.add_argument('--use_bn', type=int, default=0,
                        help='If 1, then do batch_normalization first in att_embed, if 2 then do bn both in the beginning and the end of att_embed')

    # feature manipulation
    parser.add_argument('--norm_att_feat', type=int, default=0,
                        help='If normalize attention features')
    parser.add_argument('--use_box', type=int, default=0,
                        help='If use box features')
    parser.add_argument('--norm_box_feat', type=int, default=0,
                        help='If use box, do we normalize box feature')

    # Optimization: General
    parser.add_argument('--max_epochs', type=int, default=-1,
                        help='number of epochs')
    parser.add_argument('--batch_size', type=int, default=16,
                        help='minibatch size')
    parser.add_argument('--grad_clip_mode', type=str, default='value',
                        help='value or norm')
    parser.add_argument('--grad_clip_value', type=float, default=0.1,
                        help='clip gradients at this value/max_norm, 0 means no clipping')
    parser.add_argument('--drop_prob_lm', type=float, default=0.5,
                        help='strength of dropout in the Language Model RNN')
    parser.add_argument('--self_critical_after', type=int, default=-1,
                        help='After what epoch do we start finetuning the CNN? (-1 = disable; never finetune, 0 = finetune from start)')
    parser.add_argument('--seq_per_img', type=int, default=5,
                        help='number of captions to sample for each image during training. Done for efficiency since CNN forward pass is expensive. E.g. coco has 5 sents/image')

    parser.add_argument('--verbose', type=int, default=0)

    # Sample related
    add_eval_sample_opts(parser)

    # Optimization: for the Language Model
    parser.add_argument('--optim', type=str, default='adam',
                        help='what update to use? rmsprop|sgd|sgdmom|adagrad|adam|adamw')
    parser.add_argument('--learning_rate', type=float, default=4e-4,
                        help='learning rate')
    parser.add_argument('--learning_rate_decay_start', type=int, default=-1,
                        help='at what iteration to start decaying learning rate? (-1 = dont) (in epoch)')
    parser.add_argument('--learning_rate_decay_every', type=int, default=3,
                        help='every how many iterations thereafter to drop LR?(in epoch)')
    # NOTE(review): this help text duplicates the one for
    # --learning_rate_decay_every and looks copy-pasted; confirm the intended
    # wording before changing the user-facing string.
    parser.add_argument('--learning_rate_decay_rate', type=float, default=0.8,
                        help='every how many iterations thereafter to drop LR?(in epoch)')
    parser.add_argument('--optim_alpha', type=float, default=0.9,
                        help='alpha for adam')
    parser.add_argument('--optim_beta', type=float, default=0.999,
                        help='beta used for adam')
    parser.add_argument('--optim_epsilon', type=float, default=1e-8,
                        help='epsilon that goes into denominator for smoothing')
    parser.add_argument('--weight_decay', type=float, default=0,
                        help='weight_decay')
    # Transformer
    parser.add_argument('--label_smoothing', type=float, default=0,
                        help='')
    parser.add_argument('--noamopt', action='store_true',
                        help='')
    parser.add_argument('--noamopt_warmup', type=int, default=2000,
                        help='')
    parser.add_argument('--noamopt_factor', type=float, default=1,
                        help='')
    parser.add_argument('--reduce_on_plateau', action='store_true',
                        help='')
    parser.add_argument('--reduce_on_plateau_factor', type=float, default=0.5,
                        help='')
    parser.add_argument('--reduce_on_plateau_patience', type=int, default=3,
                        help='')
    parser.add_argument('--cached_transformer', action='store_true',
                        help='')

    parser.add_argument('--use_warmup', action='store_true',
                        help='warm up the learing rate?')

    parser.add_argument('--scheduled_sampling_start', type=int, default=-1,
                        help='at what iteration to start decay gt probability')
    parser.add_argument('--scheduled_sampling_increase_every', type=int, default=5,
                        help='every how many iterations thereafter to gt probability')
    parser.add_argument('--scheduled_sampling_increase_prob', type=float, default=0.05,
                        help='How much to update the prob')
    parser.add_argument('--scheduled_sampling_max_prob', type=float, default=0.25,
                        help='Maximum scheduled sampling prob.')

    # Evaluation/Checkpointing
    parser.add_argument('--val_images_use', type=int, default=3200,
                        help='how many images to use when periodically evaluating the validation loss? (-1 = all)')
    parser.add_argument('--save_checkpoint_every', type=int, default=2500,
                        help='how often to save a model checkpoint (in iterations)?')
    parser.add_argument('--save_every_epoch', action='store_true',
                        help='Save checkpoint every epoch, will overwrite save_checkpoint_every')
    parser.add_argument('--save_history_ckpt', type=int, default=0,
                        help='If save checkpoints at every save point')
    parser.add_argument('--checkpoint_path', type=str, default=None,
                        help='directory to store checkpointed models')
    parser.add_argument('--language_eval', type=int, default=0,
                        help='Evaluate language as well (1 = yes, 0 = no)? BLEU/CIDEr/METEOR/ROUGE_L? requires coco-caption code from Github.')
    parser.add_argument('--losses_log_every', type=int, default=25,
                        help='How often do we snapshot losses, for inclusion in the progress dump? (0 = disable)')
    parser.add_argument('--load_best_score', type=int, default=1,
                        help='Do we load previous best score when resuming training.')

    # misc
    parser.add_argument('--id', type=str, default='',
                        help='an id identifying this run/job. used in cross-val and appended when writing progress files')
    parser.add_argument('--train_only', type=int, default=0,
                        help='if true then use 80k, else use 110k')

    # Reward
    parser.add_argument('--cider_reward_weight', type=float, default=1,
                        help='The reward weight from cider')
    parser.add_argument('--bleu_reward_weight', type=float, default=0,
                        help='The reward weight from bleu4')

    # Reward
    parser.add_argument('--clipscore_reward_weight', type=float, default=1,
                        help='The reward weight from clipscore')
    parser.add_argument('--use_clipscore', type=float, default=0,
                        help='Use CLIPScore')
    parser.add_argument('--clipscore_mode', type=str, default='clip_s',
                        help='Which CLIPScore to use: clip_s|refclip_s')

    # Structure_loss
    parser.add_argument('--structure_loss_weight', type=float, default=1,
                        help='')
    parser.add_argument('--structure_after', type=int, default=-1,
                        help='T')
    parser.add_argument('--structure_loss_type', type=str, default='seqnll',
                        help='')
    parser.add_argument('--struc_use_logsoftmax', action='store_true', help='')
    parser.add_argument('--entropy_reward_weight', type=float, default=0,
                        help='Entropy reward, seems very interesting')
    parser.add_argument('--self_cider_reward_weight', type=float, default=0,
                        help='self cider reward')

    # Used for self critical or structure. Used when sampling is needed during training
    parser.add_argument('--train_sample_n', type=int, default=16,
                        help='The reward weight from cider')
    parser.add_argument('--train_sample_method', type=str, default='sample',
                        help='')
    parser.add_argument('--train_beam_size', type=int, default=1,
                        help='')

    # Used for self critical
    parser.add_argument('--sc_sample_method', type=str, default='greedy',
                        help='')
    parser.add_argument('--sc_beam_size', type=int, default=1,
                        help='')

    # For diversity evaluation during training
    add_diversity_opts(parser)

    # config
    parser.add_argument('--cfg', type=str, default=None,
                        help='configuration; similar to what is used in detectron')
    parser.add_argument(
        '--set_cfgs', dest='set_cfgs',
        help='Set config keys. Key value sequence seperate by whitespace.'
             'e.g. [key] [value] [key] [value]\n This has higher priority'
             'than cfg file but lower than other args. (You can only overwrite'
             'arguments that have alerady been defined in config file.)',
        default=[], nargs='+')
    # How the config is used:
    # 1) read the cfg argument, and load the cfg file if it's not None
    # 2) overwrite the cfg values with set_cfgs
    # 3) copy the config values onto args
    # 4) finally, parse the command line again so it overrides everything

    if parse:
        args = parser.parse_args()
    # For interactive environments (e.g. jupyter)
    else:
        args = parser.parse_known_args()[0]

    # Namespace => dictionary, then apply the caller-supplied overrides
    kwargs = vars(args)
    kwargs.update(optional_kwargs)

    args = Config(**kwargs)

    if args.cfg is not None or args.set_cfgs is not None:
        # Only import the config machinery when there is something to merge.
        # NOTE(review): assumes an empty CfgNode merged with an empty list
        # contributes no keys, so skipping it is equivalent - confirm against
        # .config.CfgNode.
        if args.cfg is not None or args.set_cfgs:
            from .config import CfgNode
            if args.cfg is not None:
                cn = CfgNode(CfgNode.load_yaml_with_base(args.cfg))
            else:
                cn = CfgNode()
            if args.set_cfgs is not None:
                cn.merge_from_list(args.set_cfgs)
            # Only the process with LOCAL_RANK unset or '0' prints warnings.
            warn_unknown = os.environ.get('LOCAL_RANK', '0') == '0'
            for k, v in cn.items():
                if not hasattr(args, k) and warn_unknown:
                    print('Warning: key %s not in args' % k)
                setattr(args, k, v)

        # Re-parse so that explicit command-line arguments win over the
        # config file and over optional_kwargs.
        if parse:
            args = parser.parse_args(namespace=args)
        else:
            args = parser.parse_known_args(namespace=args)[0]

    # Check if args are valid
    assert args.rnn_size > 0, "rnn_size should be greater than 0"
    assert args.num_layers > 0, "num_layers should be greater than 0"
    assert args.input_encoding_size > 0, "input_encoding_size should be greater than 0"
    assert args.batch_size > 0, "batch_size should be greater than 0"
    assert args.drop_prob_lm >= 0 and args.drop_prob_lm < 1, "drop_prob_lm should be between 0 and 1"
    assert args.seq_per_img > 0, "seq_per_img should be greater than 0"
    assert args.beam_size > 0, "beam_size should be greater than 0"
    assert args.save_checkpoint_every > 0, "save_checkpoint_every should be greater than 0"
    assert args.losses_log_every > 0, "losses_log_every should be greater than 0"
    assert args.language_eval == 0 or args.language_eval == 1, "language_eval should be 0 or 1"
    assert args.load_best_score == 0 or args.load_best_score == 1, "load_best_score should be 0 or 1"
    assert args.train_only == 0 or args.train_only == 1, "train_only should be 0 or 1"

    # default value for start_from and checkpoint_path
    args.checkpoint_path = args.checkpoint_path or './log_%s' % args.id
    args.start_from = args.start_from or args.checkpoint_path

    # Deal with feature things before anything
    args.use_fc, args.use_att = if_use_feat(args.caption_model)
    if args.use_box:
        # Box features add 5 extra values to every attention feature.
        args.att_feat_size = args.att_feat_size + 5

    return args


def add_eval_options(parser):
    """Register the evaluation-script options on `parser`."""
    # Basic options
    parser.add_argument('--batch_size', type=int, default=0,
                        help='if > 0 then overrule, otherwise load from checkpoint.')
    parser.add_argument('--num_images', type=int, default=-1,
                        help='how many images to use when periodically evaluating the loss? (-1 = all)')
    parser.add_argument('--language_eval', type=int, default=0,
                        help='Evaluate language as well (1 = yes, 0 = no)? BLEU/CIDEr/METEOR/ROUGE_L? requires coco-caption code from Github.')
    parser.add_argument('--dump_images', type=int, default=1,
                        help='Dump images into vis/imgs folder for vis? (1=yes,0=no)')
    parser.add_argument('--dump_json', type=int, default=1,
                        help='Dump json with predictions into vis folder? (1=yes,0=no)')
    parser.add_argument('--dump_path', type=int, default=0,
                        help='Write image paths along with predictions into vis json? (1=yes,0=no)')

    # Sampling options
    add_eval_sample_opts(parser)

    # For evaluation on a folder of images:
    parser.add_argument('--image_folder', type=str, default='',
                        help='If this is nonempty then will predict on the images in this folder path')
    parser.add_argument('--image_root', type=str, default='',
                        help='In case the image paths have to be preprended with a root path to an image folder')
    # For evaluation on MSCOCO images from some split:
    parser.add_argument('--input_fc_dir', type=str, default='',
                        help='path to the h5file containing the preprocessed dataset')
    parser.add_argument('--input_att_dir', type=str, default='',
                        help='path to the h5file containing the preprocessed dataset')
    parser.add_argument('--input_box_dir', type=str, default='',
                        help='path to the h5file containing the preprocessed dataset')
    parser.add_argument('--input_label_h5', type=str, default='',
                        help='path to the h5file containing the preprocessed dataset')
    parser.add_argument('--input_json', type=str, default='',
                        help='path to the json file containing additional info and vocab. empty = fetch from model checkpoint.')
    parser.add_argument('--split', type=str, default='test',
                        help='if running on MSCOCO images, which split to use: val|test|train')
    parser.add_argument('--coco_json', type=str, default='',
                        help='if nonempty then use this file in DataLoaderRaw (see docs there). Used only in MSCOCO test evaluation, where we have a specific json file of only test set images.')
    # misc
    parser.add_argument('--id', type=str, default='',
                        help='an id identifying this run/job. used only if language_eval = 1 for appending to intermediate files')
    parser.add_argument('--verbose_beam', type=int, default=1,
                        help='if we need to print out all beam search beams.')
    parser.add_argument('--verbose_loss', type=int, default=0,
                        help='If calculate loss using ground truth during evaluation')


def add_diversity_opts(parser):
    """Register the diverse-sampling evaluation options on `parser`."""
    parser.add_argument('--sample_n', type=int, default=1,
                        help='Diverse sampling')
    parser.add_argument('--sample_n_method', type=str, default='sample',
                        help='sample, bs, dbs, gumbel, topk, dgreedy, dsample, dtopk, dtopp')
    parser.add_argument('--eval_oracle', type=int, default=1,
                        help='if we need to calculate loss.')


# Sampling related options
def add_eval_sample_opts(parser):
    """Register the decoding/sampling options shared by training and eval."""
    parser.add_argument('--sample_method', type=str, default='greedy',
                        help='greedy; sample; gumbel; top, top<0-1>')
    parser.add_argument('--beam_size', type=int, default=1,
                        help='used when sample_method = greedy, indicates number of beams in beam search. Usually 2 or 3 works well. More is not better. Set this to 1 for faster runtime but a bit worse performance.')
    parser.add_argument('--max_length', type=int, default=20,
                        help='Maximum length during sampling')
    parser.add_argument('--length_penalty', type=str, default='',
                        help='wu_X or avg_X, X is the alpha')
    parser.add_argument('--group_size', type=int, default=1,
                        help='used for diverse beam search. if group_size is 1, then it\'s normal beam search')
    parser.add_argument('--diversity_lambda', type=float, default=0.5,
                        help='used for diverse beam search. Usually from 0.2 to 0.8. Higher value of lambda produces a more diverse list')
    parser.add_argument('--temperature', type=float, default=1.0,
                        help='temperature when sampling from distributions (i.e. when sample_method = sample). Lower = "safer" predictions.')
    parser.add_argument('--decoding_constraint', type=int, default=0,
                        help='If 1, not allowing same word in a row')
    parser.add_argument('--block_trigrams', type=int, default=0,
                        help='block repeated trigram.')
    parser.add_argument('--remove_bad_endings', type=int, default=0,
                        help='Remove bad endings')
    parser.add_argument('--suppress_UNK', type=int, default=1,
                        help='Not predicting UNK')


if __name__ == '__main__':
    # Manual smoke test: show how a cfg file and an extra flag change the
    # resulting options relative to the defaults.
    import sys
    sys.argv = [sys.argv[0]]
    args = parse_opt()
    print(args)
    print()
    sys.argv = [sys.argv[0], '--cfg', 'configs/updown_long.yml']
    args1 = parse_opt()
    print(dict(set(vars(args1).items()) - set(vars(args).items())))
    print()
    sys.argv = [sys.argv[0], '--cfg', 'configs/updown_long.yml', '--caption_model', 'att2in2']
    args2 = parse_opt()
    print(dict(set(vars(args2).items()) - set(vars(args1).items())))
import torch
import torch.nn as nn
import torchvision.models.resnet
from torchvision.models.resnet import BasicBlock, Bottleneck

# Checkpoint locations for `pretrained=True`.  The original code referenced
# `model_urls` and `model_zoo` without ever importing them, so requesting
# pretrained weights raised NameError.
# NOTE(review): these are the classic torchvision ImageNet checkpoints;
# confirm the URLs against the torchvision release this project targets.
_MODEL_URLS = {
    'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
    'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth',
    'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
    'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',
    'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth',
}


class ResNet(torchvision.models.resnet.ResNet):
    """torchvision ResNet with a modified max-pool and block strides.

    The stem max-pool uses padding 0 with ``ceil_mode=True``, and in the
    first block of layer2..layer4 the stride 2 is placed on ``conv1`` while
    ``conv2`` gets stride 1.
    """

    def __init__(self, block, layers, num_classes=1000):
        super(ResNet, self).__init__(block, layers, num_classes)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=0, ceil_mode=True)  # change
        for i in range(2, 5):
            first_block = getattr(self, 'layer%d' % i)[0]
            first_block.conv1.stride = (2, 2)
            first_block.conv2.stride = (1, 1)


def _build(arch, block, layers, pretrained):
    """Create a ResNet of the given layout, optionally loading `arch` weights."""
    model = ResNet(block, layers)
    if pretrained:
        # Imported lazily so that building a randomly initialised model does
        # not depend on the download helper.
        from torch.utils import model_zoo
        model.load_state_dict(model_zoo.load_url(_MODEL_URLS[arch]))
    return model


def resnet18(pretrained=False):
    """Constructs a ResNet-18 model.

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
    """
    return _build('resnet18', BasicBlock, [2, 2, 2, 2], pretrained)


def resnet34(pretrained=False):
    """Constructs a ResNet-34 model.

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
    """
    return _build('resnet34', BasicBlock, [3, 4, 6, 3], pretrained)


def resnet50(pretrained=False):
    """Constructs a ResNet-50 model.

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
    """
    return _build('resnet50', Bottleneck, [3, 4, 6, 3], pretrained)


def resnet101(pretrained=False):
    """Constructs a ResNet-101 model.

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
    """
    return _build('resnet101', Bottleneck, [3, 4, 23, 3], pretrained)


def resnet152(pretrained=False):
    """Constructs a ResNet-152 model.

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
    """
    return _build('resnet152', Bottleneck, [3, 8, 36, 3], pretrained)
class myResnet(nn.Module):
    """Feature extractor that runs the convolutional stages of a given resnet.

    ``forward`` returns a globally averaged feature vector and a spatial
    feature grid pooled to ``att_size x att_size``.
    """

    # Stages of the wrapped network, applied in this order.
    _STAGES = ('conv1', 'bn1', 'relu', 'maxpool',
               'layer1', 'layer2', 'layer3', 'layer4')

    def __init__(self, resnet):
        super(myResnet, self).__init__()
        self.resnet = resnet

    def forward(self, img, att_size=14):
        """Return ``(fc, att)`` for a single image tensor `img`.

        A batch axis is prepended to `img` before the stages run; both
        outputs have singleton axes squeezed away, and `att` is permuted so
        the channel axis comes last.
        """
        backbone = self.resnet
        feats = img.unsqueeze(0)
        for stage in self._STAGES:
            feats = getattr(backbone, stage)(feats)

        pooled = F.adaptive_avg_pool2d(feats, [att_size, att_size])
        fc = feats.mean(3).mean(2).squeeze()
        att = pooled.squeeze().permute(1, 2, 0)
        return fc, att
def init_scorer(cached_tokens):
    """Create the module-level CIDEr-D, CIDEr and BLEU scorers if missing.

    Scorers that already exist are kept, so repeated calls are cheap.
    """
    global CiderD_scorer, Cider_scorer, Bleu_scorer
    CiderD_scorer = CiderD_scorer or CiderD(df=cached_tokens)
    Cider_scorer = Cider_scorer or Cider(df=cached_tokens)
    Bleu_scorer = Bleu_scorer or Bleu(4)


def array_to_str(arr):
    """Join the ids in `arr` with spaces, stopping after the first 0.

    The terminating 0 itself is included in the output.
    """
    tokens = []
    for token in arr:
        tokens.append(str(token))
        if token == 0:
            break
    return ' '.join(tokens)


def _weighted_caption_scores(gts, res_list, res_dict, opt):
    """Return ``cider_weight * CIDEr-D + bleu_weight * BLEU-4`` per caption.

    A metric whose weight is not positive is skipped and contributes the
    scalar 0, so the result is the plain number 0 when both are disabled.
    Averages are printed unless ``opt.verbose`` exists and is falsy.
    """
    verbose = getattr(opt, 'verbose', True)
    if opt.cider_reward_weight > 0:
        cider_avg, cider_scores = CiderD_scorer.compute_score(gts, res_list)
        if verbose:
            print('Cider scores:', cider_avg)
    else:
        cider_scores = 0
    if opt.bleu_reward_weight > 0:
        bleu_avg, bleu_scores = Bleu_scorer.compute_score(gts, res_dict)
        bleu_scores = np.array(bleu_scores[3])
        if verbose:
            print('Bleu scores:', bleu_avg[3])
    else:
        bleu_scores = 0
    return opt.cider_reward_weight * cider_scores + opt.bleu_reward_weight * bleu_scores


def get_self_critical_reward(greedy_res, data_gts, gen_result, opt):
    """Compute self-critical rewards: sampled score minus greedy baseline.

    Args:
        greedy_res: tensor with one greedy caption per image.
        data_gts: per-image collections of reference id sequences.
        gen_result: tensor of sampled captions, ``seq_per_img`` per image.
        opt: options providing ``cider_reward_weight``, ``bleu_reward_weight``
            and optionally ``verbose``.

    Returns:
        ``(rewards, unnormalized_reward_mean)`` where `rewards` repeats each
        caption's advantage along the sequence axis and the mean is taken
        over the raw sampled-caption scores.
    """
    batch_size = len(data_gts)
    gen_result_size = gen_result.shape[0]
    seq_per_img = gen_result_size // len(data_gts)  # gen_result_size = batch_size * seq_per_img
    assert greedy_res.shape[0] == batch_size

    gen_result = gen_result.data.cpu().numpy()
    greedy_res = greedy_res.data.cpu().numpy()

    # Sampled captions first, greedy captions appended after them.
    res = OrderedDict()
    for i in range(gen_result_size):
        res[i] = [array_to_str(gen_result[i])]
    for i in range(batch_size):
        res[gen_result_size + i] = [array_to_str(greedy_res[i])]

    gts = OrderedDict()
    for i, refs in enumerate(data_gts):
        gts[i] = [array_to_str(refs[j]) for j in range(len(refs))]

    res_ = [{'image_id': i, 'caption': res[i]} for i in range(len(res))]
    res__ = {i: res[i] for i in range(len(res_))}
    gts_ = {i: gts[i // seq_per_img] for i in range(gen_result_size)}
    gts_.update({i + gen_result_size: gts[i] for i in range(batch_size)})

    scores = _weighted_caption_scores(gts_, res_, res__, opt)
    if np.ndim(scores) == 0:
        # Both metrics disabled: the combination is a bare scalar, which the
        # slicing below cannot handle.  Use an explicit all-zero score array.
        scores = np.zeros(len(res))

    unnormalized_reward_mean = scores[:gen_result_size].flatten().mean()

    # Advantage of every sample over the greedy caption of the same image.
    scores = scores[:gen_result_size].reshape(batch_size, seq_per_img) - scores[-batch_size:][:, np.newaxis]
    scores = scores.reshape(gen_result_size)

    rewards = np.repeat(scores[:, np.newaxis], gen_result.shape[1], 1)

    return rewards, unnormalized_reward_mean


def _clip_text_features(clipscore_model, captions, with_grammar):
    """Encode `captions`; optionally also return the grammar probability.

    With `with_grammar`, the pre-projection features feed the grammar head
    (probability of class index 1) and are then projected and L2-normalised
    by hand.  Otherwise ``text_extract`` is used directly and the second
    return value is None.
    """
    if not with_grammar:
        return clipscore_model.text_extract(captions), None

    text_pre_feat = clipscore_model.text_extract(captions, proj_norm=False)
    grammar_logit = clipscore_model.grammar_score_head(text_pre_feat.view(-1, 512))
    grammar_prob = torch.softmax(grammar_logit, dim=-1)[:, 1]
    grammar_prob = grammar_prob.view(len(captions)).detach()

    text_feat = clipscore_model.clip_model.text_projection(text_pre_feat)
    text_feat = text_feat / text_feat.norm(dim=-1, keepdim=True)
    return text_feat, grammar_prob


def get_self_critical_clipscore_reward(greedy_res, data_gts, gen_result, opt, clipscore_model, clip_vis_feats, vocab):
    """Self-critical reward based on CLIP-S / RefCLIP-S.

    Returns ``(rewards, unnormalized_mean_reward, grammar_rewards)``.
    `rewards` has shape ``(B*K, L)``: the sampled captions' score minus the
    greedy baseline, scaled by ``opt.clipscore_reward_weight`` and repeated
    along the sequence axis.  `grammar_rewards` is None unless
    ``opt.use_grammar`` is set (and ``opt.joint_out`` is not) and the
    CLIPScore reward is enabled.
    """
    batch_size = len(data_gts)
    gen_result_size = gen_result.shape[0]
    seq_per_img = gen_result_size // len(data_gts)  # gen_result_size = batch_size * seq_per_img
    assert greedy_res.shape[0] == batch_size

    B = batch_size
    K = seq_per_img
    L = gen_result.shape[1]
    assert gen_result.shape == (B * K, L)

    use_grammar = getattr(opt, 'use_grammar', False) and not getattr(opt, 'joint_out', False)
    use_grammar_baseline = use_grammar and getattr(opt, 'use_grammar_baseline', False)
    ref_mode = clipscore_model.mode == 'refclip_s'

    if ref_mode:
        # Decode the references and pad every image to the same count; the
        # mask marks which entries are real references.
        gts = []
        gts_valid_mask = []
        max_n_refs = max([len(_gts) for _gts in data_gts])
        for i in range(len(data_gts)):
            _gts = decode_sequence(vocab, data_gts[i])
            n_ref = len(_gts)
            _gts.extend([''] * (max_n_refs - n_ref))
            gts.extend(_gts)
            gts_valid_mask.extend([1] * n_ref + [0] * (max_n_refs - n_ref))
        assert len(gts) == B * max_n_refs
        assert len(gts_valid_mask) == B * max_n_refs

    if opt.clipscore_reward_weight > 0:
        with torch.no_grad():
            clipscore_model.eval()

            # 1) score the sampled captions
            gen_result = gen_result.data.cpu().numpy()
            res = decode_sequence(vocab, gen_result)
            assert len(res) == B * K, len(res)

            text_feat, grammar_prob = _clip_text_features(clipscore_model, res, use_grammar)

            assert text_feat.size() == (B * K, 512), text_feat.size()
            assert clip_vis_feats.size() == (B, 512), clip_vis_feats.size()

            # Repeat every image feature K times: [B, dim] -> [B * K, dim]
            vis_feat = clip_vis_feats.view(B, 1, -1).expand(-1, K, -1).contiguous().view(B * K, -1)

            clip_s = clipscore_model(text_feat=text_feat, img_feat=vis_feat, mode='clip_s')
            clip_s = clip_s.view(B * K).detach()

            if ref_mode:
                # [B * n_ref, dim]
                ref_text_feat = clipscore_model.text_extract(gts)
                ref_text_mask = torch.tensor(gts_valid_mask, dtype=ref_text_feat.dtype, device=ref_text_feat.device)

                assert ref_text_feat.size() == (B * max_n_refs, 512), ref_text_feat.size()
                assert ref_text_mask.size() == (B * max_n_refs,), ref_text_mask.size()

                # [B * K]
                refclip_s = clipscore_model.calc_refclip_s(
                    text_feat=text_feat, img_feat=vis_feat,
                    ref_text_feat=ref_text_feat.view(B, 1, max_n_refs, -1).expand(-1, K, -1, -1).contiguous().view(B * K * max_n_refs, -1),
                    ref_text_mask=ref_text_mask.view(B, 1, max_n_refs).expand(-1, K, -1).contiguous().view(B * K * max_n_refs),
                    clip_s=clip_s)
                refclip_s = refclip_s.view(B * K).detach()

            # 2) score the greedy captions (baseline)
            greedy_res = greedy_res.data.cpu().numpy()
            res = decode_sequence(vocab, greedy_res)
            assert len(res) == B, len(res)

            text_feat, grammar_prob_baseline = _clip_text_features(clipscore_model, res, use_grammar_baseline)

            assert text_feat.size() == (B, 512), text_feat.size()
            assert clip_vis_feats.size() == (B, 512), clip_vis_feats.size()

            vis_feat = clip_vis_feats.view(B, 512)

            # [B]
            clip_s_baseline = clipscore_model(text_feat=text_feat, img_feat=vis_feat, mode='clip_s')
            clip_s_baseline = clip_s_baseline.view(B).detach()

            if ref_mode:
                # [B]
                refclip_s_baseline = clipscore_model.calc_refclip_s(
                    text_feat=text_feat, img_feat=vis_feat,
                    ref_text_feat=ref_text_feat,
                    ref_text_mask=ref_text_mask,
                    clip_s=clip_s_baseline)
                refclip_s_baseline = refclip_s_baseline.view(B).detach()

            if clipscore_model.mode == 'clip_s':
                rewards = clip_s - clip_s_baseline.view(B, 1).expand(-1, K).contiguous().flatten()
                unnormalized_mean_reward = clip_s.mean()
            elif clipscore_model.mode == 'refclip_s':
                rewards = refclip_s - refclip_s_baseline.view(B, 1).expand(-1, K).contiguous().flatten()
                unnormalized_mean_reward = refclip_s.mean()

            if use_grammar:
                if use_grammar_baseline:
                    grammar_rewards = grammar_prob - grammar_prob_baseline.view(B, 1).expand(-1, K).contiguous().flatten()
                else:
                    grammar_rewards = grammar_prob
            else:
                grammar_rewards = None

            if getattr(opt, 'verbose', True):
                if clipscore_model.mode == 'clip_s':
                    print('CLIP-S:', rewards)
                elif clipscore_model.mode == 'refclip_s':
                    print('RefCLIP-S:', rewards)
    else:
        # One zero reward per sampled caption.  The previous zeros(B, L)
        # expanded to (B*L, L) below instead of (B*K, L).
        rewards = torch.zeros(B * K)
        unnormalized_mean_reward = None
        grammar_rewards = None

    rewards = opt.clipscore_reward_weight * rewards

    # [B*K] -> [B*K, L]
    rewards = rewards.view(-1, 1).expand(-1, L).contiguous()

    # grammar_rewards stays None when the CLIPScore reward is disabled.
    if use_grammar and grammar_rewards is not None:
        grammar_rewards = grammar_rewards.view(-1, 1).expand(-1, L).contiguous()

    return rewards, unnormalized_mean_reward, grammar_rewards


def get_scores(data_gts, gen_result, opt):
    """Return the weighted CIDEr-D / BLEU-4 score of every generated caption.

    Returns the plain number 0 when both reward weights are non-positive.
    """
    batch_size = gen_result.size(0)  # batch_size = sample_size * seq_per_img
    seq_per_img = batch_size // len(data_gts)

    gen_result = gen_result.data.cpu().numpy()
    res = OrderedDict()
    for i in range(batch_size):
        res[i] = [array_to_str(gen_result[i])]

    gts = OrderedDict()
    for i, refs in enumerate(data_gts):
        gts[i] = [array_to_str(refs[j]) for j in range(len(refs))]

    res_ = [{'image_id': i, 'caption': res[i]} for i in range(batch_size)]
    res__ = {i: res[i] for i in range(batch_size)}
    gts = {i: gts[i // seq_per_img] for i in range(batch_size)}

    return _weighted_caption_scores(gts, res_, res__, opt)


def _self_cider_diversity(eigvals):
    """Diversity score from the eigenvalues of a self-CIDEr matrix.

    Negative eigenvalues are clipped to 0 first; the last (largest, as
    returned by ``eigvalsh``) eigenvalue is compared with the sum of all
    square-rooted eigenvalues.
    """
    eigvals = np.clip(eigvals, 0, None)
    return -np.log(np.sqrt(eigvals[-1]) / (np.sqrt(eigvals).sum())) / np.log(len(eigvals))


def get_self_cider_scores(data_gts, gen_result, opt):
    """Return one self-CIDEr diversity score per image as a numpy array."""
    batch_size = gen_result.size(0)  # batch_size = sample_size * seq_per_img
    seq_per_img = batch_size // len(data_gts)

    gen_result = gen_result.data.cpu().numpy()
    res = [array_to_str(gen_result[i]) for i in range(batch_size)]

    scores = []
    for i in range(len(data_gts)):
        tmp = Cider_scorer.my_self_cider([res[i * seq_per_img:(i + 1) * seq_per_img]])
        scores.append(_self_cider_diversity(np.linalg.eigvalsh(tmp[0] / 10)))

    return np.array(scores)
def get_area(pos):
    """Compute the area of every box in a batch.

    Args:
        pos: box coordinates indexed as ``pos[:, :, k]``; documented by the
            file as shape [B, N, 4] laid out as (x1, x2, y1, y2).

    Returns:
        ``height * width`` per box, i.e. ``(y2 - y1) * (x2 - x1)``; shape
        [B, N] for a [B, N, 4] input.
    """
    # Slice the four coordinate planes out of the last axis.
    x1, x2, y1, y2 = (pos[:, :, idx] for idx in range(4))
    return (y2 - y1) * (x2 - x1)
def set_global_logging_level(level=logging.ERROR, prefices=("",)):
    """
    Override logging levels of different modules based on their name as a prefix.
    It needs to be invoked after the modules have been loaded so that their loggers have been initialized.

    Args:
        - level: desired level. e.g. logging.INFO. Optional. Default is logging.ERROR
        - prefices: iterable of one or more str prefices to match (e.g. ["transformers", "torch"]),
          or a single str. Optional. Default matches all active loggers.
          The match is a case-sensitive `module_name.startswith(prefix)`
    """
    # A bare string would otherwise be iterated character by character.
    if isinstance(prefices, str):
        prefices = (prefices,)

    # Escape each prefix so that regex metacharacters (most commonly the "."
    # in dotted module names) are matched literally, as the documented
    # `startswith` contract promises.
    prefix_re = re.compile("^(?:" + "|".join(re.escape(p) for p in prefices) + ")")
    for name in logging.root.manager.loggerDict:
        if prefix_re.match(name):
            logging.getLogger(name).setLevel(level)
sha256:924691ac288e54409236115652ad4aa250f48203de50a9e4722a6ecd48d6804a +size 1356917 diff --git a/clip/clip.py b/clip/clip.py new file mode 100644 index 0000000000000000000000000000000000000000..76f241b053e3a6da06b1165e73e0d54c5b5356b2 --- /dev/null +++ b/clip/clip.py @@ -0,0 +1,193 @@ +import hashlib +import os +import urllib +import warnings +from typing import Union, List + +import torch +from PIL import Image +from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize +from tqdm import tqdm + +from .model import build_model +from .simple_tokenizer import SimpleTokenizer as _Tokenizer + +__all__ = ["available_models", "load", "tokenize"] +_tokenizer = _Tokenizer() + +_MODELS = { + "RN50": "https://openaipublic.azureedge.net/clip/models/afeb0e10f9e5a86da6080e35cf09123aca3b358a0c3e3b6c78a7b63bc04b6762/RN50.pt", + "RN101": "https://openaipublic.azureedge.net/clip/models/8fa8567bab74a42d41c5915025a8e4538c3bdbe8804a470a72f30b0d94fab599/RN101.pt", + "RN50x4": "https://openaipublic.azureedge.net/clip/models/7e526bd135e493cef0776de27d5f42653e6b4c8bf9e0f653bb11773263205fdd/RN50x4.pt", + "ViT-B/32": "https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt", +} + + +def _download(url: str, root: str = os.path.expanduser("~/.cache/clip")): + os.makedirs(root, exist_ok=True) + filename = os.path.basename(url) + + expected_sha256 = url.split("/")[-2] + download_target = os.path.join(root, filename) + + if os.path.exists(download_target) and not os.path.isfile(download_target): + raise RuntimeError(f"{download_target} exists and is not a regular file") + + if os.path.isfile(download_target): + if hashlib.sha256(open(download_target, "rb").read()).hexdigest() == expected_sha256: + return download_target + else: + warnings.warn(f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file") + + with urllib.request.urlopen(url) as source, 
open(download_target, "wb") as output: + with tqdm(total=int(source.info().get("Content-Length")), ncols=80, unit='iB', unit_scale=True) as loop: + while True: + buffer = source.read(8192) + if not buffer: + break + + output.write(buffer) + loop.update(len(buffer)) + + if hashlib.sha256(open(download_target, "rb").read()).hexdigest() != expected_sha256: + raise RuntimeError(f"Model has been downloaded but the SHA256 checksum does not not match") + + return download_target + + +def _transform(n_px): + return Compose([ + Resize(n_px, interpolation=Image.BICUBIC), + CenterCrop(n_px), + lambda image: image.convert("RGB"), + ToTensor(), + Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)), + ]) + + +def available_models() -> List[str]: + """Returns the names of available CLIP models""" + return list(_MODELS.keys()) + + +def load(name: str, device: Union[str, torch.device] = "cuda" if torch.cuda.is_available() else "cpu", jit=True): + """Load a CLIP model + + Parameters + ---------- + name : str + A model name listed by `clip.available_models()`, or the path to a model checkpoint containing the state_dict + + device : Union[str, torch.device] + The device to put the loaded model + + jit : bool + Whether to load the optimized JIT model (default) or more hackable non-JIT model. 
+ + Returns + ------- + model : torch.nn.Module + The CLIP model + + preprocess : Callable[[PIL.Image], torch.Tensor] + A torchvision transform that converts a PIL image into a tensor that the returned model can take as its input + """ + if name in _MODELS: + model_path = _download(_MODELS[name]) + elif os.path.isfile(name): + model_path = name + else: + raise RuntimeError(f"Model {name} not found; available models = {available_models()}") + + try: + # loading JIT archive + model = torch.jit.load(model_path, map_location=device if jit else "cpu").eval() + state_dict = None + except RuntimeError: + # loading saved state dict + if jit: + warnings.warn(f"File {model_path} is not a JIT archive. Loading as a state dict instead") + jit = False + state_dict = torch.load(model_path, map_location="cpu") + + if not jit: + model = build_model(state_dict or model.state_dict()).to(device) + if str(device) == "cpu": + model.float() + return model, _transform(model.visual.input_resolution) + + # patch the device names + device_holder = torch.jit.trace(lambda: torch.ones([]).to(torch.device(device)), example_inputs=[]) + device_node = [n for n in device_holder.graph.findAllNodes("prim::Constant") if "Device" in repr(n)][-1] + + def patch_device(module): + graphs = [module.graph] if hasattr(module, "graph") else [] + if hasattr(module, "forward1"): + graphs.append(module.forward1.graph) + + for graph in graphs: + for node in graph.findAllNodes("prim::Constant"): + if "value" in node.attributeNames() and str(node["value"]).startswith("cuda"): + node.copyAttributes(device_node) + + model.apply(patch_device) + patch_device(model.encode_image) + patch_device(model.encode_text) + + # patch dtype to float32 on CPU + if str(device) == "cpu": + float_holder = torch.jit.trace(lambda: torch.ones([]).float(), example_inputs=[]) + float_input = list(float_holder.graph.findNode("aten::to").inputs())[1] + float_node = float_input.node() + + def patch_float(module): + graphs = [module.graph] if 
def tokenize(texts: Union[str, List[str]], context_length: int = 77) -> torch.LongTensor:
    """
    Tokenize one string or a list of strings into a padded id matrix.

    Parameters
    ----------
    texts : Union[str, List[str]]
        A single input string or a list of input strings

    context_length : int
        Width of the output; defaults to 77

    Returns
    -------
    A long tensor of shape [number of input strings, context_length]; each row is
    start-of-text id, the encoded text, end-of-text id, then zero padding.

    Raises
    ------
    RuntimeError
        If any encoded input (including the two special tokens) exceeds context_length.
    """
    batch = [texts] if isinstance(texts, str) else texts

    start_id = _tokenizer.encoder["<|startoftext|>"]
    end_id = _tokenizer.encoder["<|endoftext|>"]
    encoded = [[start_id, *_tokenizer.encode(entry), end_id] for entry in batch]

    result = torch.zeros(len(encoded), context_length, dtype=torch.long)
    for row, ids in enumerate(encoded):
        length = len(ids)
        if length > context_length:
            raise RuntimeError(f"Input {batch[row]} is too long for context length {context_length}")
        result[row, :length] = torch.tensor(ids)

    return result
expansion = 4 + + def __init__(self, inplanes, planes, stride=1): + super().__init__() + + # all conv layers have stride 1. an avgpool is performed after the second convolution when stride > 1 + self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False) + self.bn1 = nn.BatchNorm2d(planes) + + self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, bias=False) + self.bn2 = nn.BatchNorm2d(planes) + + self.avgpool = nn.AvgPool2d(stride) if stride > 1 else nn.Identity() + + self.conv3 = nn.Conv2d(planes, planes * self.expansion, 1, bias=False) + self.bn3 = nn.BatchNorm2d(planes * self.expansion) + + self.relu = nn.ReLU(inplace=True) + self.downsample = None + self.stride = stride + + if stride > 1 or inplanes != planes * Bottleneck.expansion: + # downsampling layer is prepended with an avgpool, and the subsequent convolution has stride 1 + self.downsample = nn.Sequential(OrderedDict([ + ("-1", nn.AvgPool2d(stride)), + ("0", nn.Conv2d(inplanes, planes * self.expansion, 1, stride=1, bias=False)), + ("1", nn.BatchNorm2d(planes * self.expansion)) + ])) + + def forward(self, x: torch.Tensor): + identity = x + + out = self.relu(self.bn1(self.conv1(x))) + out = self.relu(self.bn2(self.conv2(out))) + out = self.avgpool(out) + out = self.bn3(self.conv3(out)) + + if self.downsample is not None: + identity = self.downsample(x) + + out += identity + out = self.relu(out) + return out + + +class AttentionPool2d(nn.Module): + def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None): + super().__init__() + self.positional_embedding = nn.Parameter(torch.randn(spacial_dim ** 2 + 1, embed_dim) / embed_dim ** 0.5) + self.k_proj = nn.Linear(embed_dim, embed_dim) + self.q_proj = nn.Linear(embed_dim, embed_dim) + self.v_proj = nn.Linear(embed_dim, embed_dim) + self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim) + self.num_heads = num_heads + + def forward(self, x): + x = x.reshape(x.shape[0], x.shape[1], x.shape[2] * x.shape[3]).permute(2, 0, 1) # NCHW 
-> (HW)NC + # print(x.shape, self.positional_embedding.shape) + x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0) # (HW+1)NC + x = x + self.positional_embedding[0, :, None, :].to(x.dtype) # (HW+1)NC + x, _ = F.multi_head_attention_forward( + query=x, key=x, value=x, + embed_dim_to_check=x.shape[-1], + num_heads=self.num_heads, + q_proj_weight=self.q_proj.weight, + k_proj_weight=self.k_proj.weight, + v_proj_weight=self.v_proj.weight, + in_proj_weight=None, + in_proj_bias=torch.cat([self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]), + bias_k=None, + bias_v=None, + add_zero_attn=False, + dropout_p=0, + out_proj_weight=torch.ones_like(self.q_proj.weight), + out_proj_bias=torch.zeros_like(self.q_proj.bias), + # out_proj_weight=self.c_proj.weight, + # out_proj_bias=self.c_proj.bias, + use_separate_proj_weight=True, + training=self.training, + need_weights=False + ) + + return x[0] + + +class ModifiedResNet(nn.Module): + """ + A ResNet class that is similar to torchvision's but contains the following changes: + - There are now 3 "stem" convolutions as opposed to 1, with an average pool instead of a max pool. 
+ - Performs anti-aliasing strided convolutions, where an avgpool is prepended to convolutions with stride > 1 + - The final pooling layer is a QKV attention instead of an average pool + """ + + def __init__(self, layers, output_dim, heads, input_resolution=224, width=64): + super().__init__() + self.output_dim = output_dim + self.input_resolution = input_resolution + + # the 3-layer stem + self.conv1 = nn.Conv2d(3, width // 2, kernel_size=3, stride=2, padding=1, bias=False) + self.bn1 = nn.BatchNorm2d(width // 2) + self.conv2 = nn.Conv2d(width // 2, width // 2, kernel_size=3, padding=1, bias=False) + self.bn2 = nn.BatchNorm2d(width // 2) + self.conv3 = nn.Conv2d(width // 2, width, kernel_size=3, padding=1, bias=False) + self.bn3 = nn.BatchNorm2d(width) + self.avgpool = nn.AvgPool2d(2) + self.relu = nn.ReLU(inplace=True) + + # residual layers + self._inplanes = width # this is a *mutable* variable used during construction + self.layer1 = self._make_layer(width, layers[0]) + self.layer2 = self._make_layer(width * 2, layers[1], stride=2) + self.layer3 = self._make_layer(width * 4, layers[2], stride=2) + self.layer4 = self._make_layer(width * 8, layers[3], stride=2) + + embed_dim = width * 32 # the ResNet feature dimension + self.attnpool = AttentionPool2d(input_resolution // 32, embed_dim, heads, output_dim) + + def _make_layer(self, planes, blocks, stride=1): + layers = [Bottleneck(self._inplanes, planes, stride)] + + self._inplanes = planes * Bottleneck.expansion + for _ in range(1, blocks): + layers.append(Bottleneck(self._inplanes, planes)) + + return nn.Sequential(*layers) + + def forward(self, x): + def stem(x): + for conv, bn in [(self.conv1, self.bn1), (self.conv2, self.bn2), (self.conv3, self.bn3)]: + x = self.relu(bn(conv(x))) + x = self.avgpool(x) + return x + + x = x.type(self.conv1.weight.dtype) + x = stem(x) + x = self.layer1(x) + x = self.layer2(x) + x = self.layer3(x) + x = self.layer4(x) + # print(x.shape) + # x = self.attnpool(x) + attnpool = 
class LayerNorm(nn.LayerNorm):
    """LayerNorm variant that computes in float32 and returns the caller's dtype (fp16-safe)."""

    def forward(self, x: torch.Tensor):
        # Remember the incoming dtype so the result can be cast back after
        # the normalization has been carried out in float32.
        input_dtype = x.dtype
        normalized = super().forward(x.type(torch.float32))
        return normalized.type(input_dtype)
torch.randn(width)) + self.positional_embedding = nn.Parameter(scale * torch.randn((input_resolution // patch_size) ** 2 + 1, width)) + self.ln_pre = LayerNorm(width) + + self.transformer = Transformer(width, layers, heads) + + self.ln_post = LayerNorm(width) + self.proj = nn.Parameter(scale * torch.randn(width, output_dim)) + + def forward(self, x: torch.Tensor): + x = self.conv1(x) # shape = [*, width, grid, grid] + x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2] + x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width] + x = torch.cat([self.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), x], dim=1) # shape = [*, grid ** 2 + 1, width] + x = x + self.positional_embedding.to(x.dtype) + x = self.ln_pre(x) + + x = x.permute(1, 0, 2) # NLD -> LND + x = self.transformer(x) + x = x.permute(1, 0, 2) # LND -> NLD + + # x = self.ln_post(x[:, 0, :]) + + x = self.ln_post(x) + # if self.proj is not None: + # x = x @ self.proj + + return x + + +class CLIP(nn.Module): + def __init__(self, + embed_dim: int, + # vision + image_resolution: int, + vision_layers: Union[Tuple[int, int, int, int], int], + vision_width: int, + vision_patch_size: int, + # text + context_length: int, + vocab_size: int, + transformer_width: int, + transformer_heads: int, + transformer_layers: int + ): + super().__init__() + + self.context_length = context_length + + if isinstance(vision_layers, (tuple, list)): + vision_heads = vision_width * 32 // 64 + self.visual = ModifiedResNet( + layers=vision_layers, + output_dim=embed_dim, + heads=vision_heads, + input_resolution=image_resolution, + width=vision_width + ) + else: + vision_heads = vision_width // 64 + self.visual = VisualTransformer( + input_resolution=image_resolution, + patch_size=vision_patch_size, + width=vision_width, + layers=vision_layers, + heads=vision_heads, + output_dim=embed_dim + ) + + self.transformer = Transformer( + width=transformer_width, + 
layers=transformer_layers, + heads=transformer_heads, + attn_mask=self.build_attention_mask() + ) + + self.vocab_size = vocab_size + self.token_embedding = nn.Embedding(vocab_size, transformer_width) + self.positional_embedding = nn.Parameter(torch.empty(self.context_length, transformer_width)) + self.ln_final = LayerNorm(transformer_width) + + self.text_projection = nn.Parameter(torch.empty(transformer_width, embed_dim)) + self.logit_scale = nn.Parameter(torch.ones([])) + + self.initialize_parameters() + + def initialize_parameters(self): + nn.init.normal_(self.token_embedding.weight, std=0.02) + nn.init.normal_(self.positional_embedding, std=0.01) + + if isinstance(self.visual, ModifiedResNet): + if self.visual.attnpool is not None: + std = self.visual.attnpool.c_proj.in_features ** -0.5 + nn.init.normal_(self.visual.attnpool.q_proj.weight, std=std) + nn.init.normal_(self.visual.attnpool.k_proj.weight, std=std) + nn.init.normal_(self.visual.attnpool.v_proj.weight, std=std) + nn.init.normal_(self.visual.attnpool.c_proj.weight, std=std) + + for resnet_block in [self.visual.layer1, self.visual.layer2, self.visual.layer3, self.visual.layer4]: + for name, param in resnet_block.named_parameters(): + if name.endswith("bn3.weight"): + nn.init.zeros_(param) + + proj_std = (self.transformer.width ** -0.5) * ((2 * self.transformer.layers) ** -0.5) + attn_std = self.transformer.width ** -0.5 + fc_std = (2 * self.transformer.width) ** -0.5 + for block in self.transformer.resblocks: + nn.init.normal_(block.attn.in_proj_weight, std=attn_std) + nn.init.normal_(block.attn.out_proj.weight, std=proj_std) + nn.init.normal_(block.mlp.c_fc.weight, std=fc_std) + nn.init.normal_(block.mlp.c_proj.weight, std=proj_std) + + if self.text_projection is not None: + nn.init.normal_(self.text_projection, std=self.transformer.width ** -0.5) + + def build_attention_mask(self): + # lazily create causal attention mask, with full attention between the vision tokens + # pytorch uses additive 
attention mask; fill with -inf + mask = torch.empty(self.context_length, self.context_length) + mask.fill_(float("-inf")) + mask.triu_(1) # zero out the lower diagonal + return mask + + @property + def dtype(self): + return self.visual.conv1.weight.dtype + + def encode_image(self, image): + return self.visual(image.type(self.dtype)) + + def encode_text(self, text): + x = self.token_embedding(text).type(self.dtype) # [batch_size, n_ctx, d_model] + + x = x + self.positional_embedding.type(self.dtype) + x = x.permute(1, 0, 2) # NLD -> LND + x = self.transformer(x) + x = x.permute(1, 0, 2) # LND -> NLD + x = self.ln_final(x).type(self.dtype) + + # x.shape = [batch_size, n_ctx, transformer.width] + # take features from the eot embedding (eot_token is the highest number in each sequence) + x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection + + return x + + def forward(self, image, text): + image_features = self.encode_image(image) + text_features = self.encode_text(text) + + # normalized features + image_features = image_features / image_features.norm(dim=-1, keepdim=True) + text_features = text_features / text_features.norm(dim=-1, keepdim=True) + + # cosine similarity as logits + logit_scale = self.logit_scale.exp() + logits_per_image = logit_scale * image_features @ text_features.t() + logits_per_text = logit_scale * text_features @ image_features.t() + + # shape = [global_batch_size, global_batch_size] + return logits_per_image, logits_per_text + + +def convert_weights(model: nn.Module): + """Convert applicable model parameters to fp16""" + + def _convert_weights_to_fp16(l): + if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Linear)): + l.weight.data = l.weight.data.half() + if l.bias is not None: + l.bias.data = l.bias.data.half() + + if isinstance(l, nn.MultiheadAttention): + for attr in [*[f"{s}_proj_weight" for s in ["in", "q", "k", "v"]], "in_proj_bias", "bias_k", "bias_v"]: + tensor = getattr(l, attr) + if tensor is not None: + tensor.data = 
def build_model(state_dict: dict):
    """Build a CLIP model whose hyper-parameters are inferred from *state_dict*.

    A checkpoint containing "visual.proj" is treated as a vision-transformer
    variant; anything else as a ResNet variant. The weights are then loaded
    and the model is returned in eval mode.

    Note: the keys "input_resolution", "context_length" and "vocab_size" are
    removed from *state_dict* in place before loading.
    """
    is_vit = "visual.proj" in state_dict

    if is_vit:
        conv1_weight = state_dict["visual.conv1.weight"]
        vision_width = conv1_weight.shape[0]
        # One attention in-projection weight per visual transformer block.
        vision_layers = sum(
            1
            for name in state_dict.keys()
            if name.startswith("visual.") and name.endswith(".attn.in_proj_weight")
        )
        vision_patch_size = conv1_weight.shape[-1]
        n_positions = state_dict["visual.positional_embedding"].shape[0]
        grid_size = round((n_positions - 1) ** 0.5)
        image_resolution = vision_patch_size * grid_size
    else:
        # Count the distinct block indices under each of the four ResNet stages.
        vision_layers = tuple(
            len({name.split(".")[2] for name in state_dict if name.startswith(f"visual.layer{stage}")})
            for stage in (1, 2, 3, 4)
        )
        vision_width = state_dict["visual.layer1.0.conv1.weight"].shape[0]
        n_positions = state_dict["visual.attnpool.positional_embedding"].shape[0]
        output_width = round((n_positions - 1) ** 0.5)
        vision_patch_size = None
        assert output_width ** 2 + 1 == n_positions
        image_resolution = output_width * 32

    embed_dim = state_dict["text_projection"].shape[1]
    context_length = state_dict["positional_embedding"].shape[0]
    vocab_size = state_dict["token_embedding.weight"].shape[0]
    transformer_width = state_dict["ln_final.weight"].shape[0]
    transformer_heads = transformer_width // 64
    transformer_layers = len(
        {name.split(".")[2] for name in state_dict if name.startswith("transformer.resblocks")}
    )

    model = CLIP(
        embed_dim,
        image_resolution, vision_layers, vision_width, vision_patch_size,
        context_length, vocab_size, transformer_width, transformer_heads, transformer_layers
    )

    # Drop metadata entries that are not model parameters.
    for meta_key in ("input_resolution", "context_length", "vocab_size"):
        state_dict.pop(meta_key, None)

    convert_weights(model)
    model.load_state_dict(state_dict)
    return model.eval()
@lru_cache()
def bytes_to_unicode():
    """
    Build the byte -> unicode-character table used by the byte-level BPE.

    Bytes in the ranges "!".."~", "¡".."¬" and "®".."ÿ" map to the character
    with the same code point. Every other byte value is assigned the next
    unused code point starting at 256, in ascending byte order.

    Returns a dict with one entry for each of the 256 byte values.
    """
    printable = (
        list(range(ord("!"), ord("~") + 1))
        + list(range(ord("¡"), ord("¬") + 1))
        + list(range(ord("®"), ord("ÿ") + 1))
    )
    keep_as_is = set(printable)

    byte_values = list(printable)
    code_points = list(printable)
    for byte in range(2 ** 8):
        if byte in keep_as_is:
            continue
        # Number of bytes remapped so far gives the offset past 255.
        remapped_so_far = len(byte_values) - len(printable)
        byte_values.append(byte)
        code_points.append(2 ** 8 + remapped_so_far)

    return {byte: chr(point) for byte, point in zip(byte_values, code_points)}
class SimpleTokenizer(object):
    """Byte-level BPE tokenizer.

    Words are split into byte-mapped characters and merged according to the
    ranked merge list read from the gzip vocabulary file. The last symbol of
    every word carries the end-of-word marker ``</w>``, which `decode` turns
    back into a space.
    """

    def __init__(self, bpe_path: str = None):
        """Load the merge table from *bpe_path* (defaults to `default_bpe()`)."""
        # Resolve the default lazily instead of at class-definition time.
        if bpe_path is None:
            bpe_path = default_bpe()
        self.byte_encoder = bytes_to_unicode()
        self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
        with gzip.open(bpe_path) as bpe_file:
            merges = bpe_file.read().decode("utf-8").split('\n')
        # Skip the header line and keep as many merges as fit the vocabulary:
        # 49152 total - 256 byte symbols - 2 special tokens.
        merges = merges[1:49152-256-2+1]
        merges = [tuple(merge.split()) for merge in merges]
        vocab = list(bytes_to_unicode().values())
        # Every byte symbol exists both mid-word and as a word-final variant.
        vocab = vocab + [v+'</w>' for v in vocab]
        for merge in merges:
            vocab.append(''.join(merge))
        vocab.extend(['<|startoftext|>', '<|endoftext|>'])
        self.encoder = dict(zip(vocab, range(len(vocab))))
        self.decoder = {v: k for k, v in self.encoder.items()}
        self.bpe_ranks = dict(zip(merges, range(len(merges))))
        self.cache = {'<|startoftext|>': '<|startoftext|>', '<|endoftext|>': '<|endoftext|>'}
        self.pat = re.compile(r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", re.IGNORECASE)

    def bpe(self, token):
        """Return *token* as space-separated BPE symbols (memoized in `self.cache`)."""
        if token in self.cache:
            return self.cache[token]
        # Mark the final symbol as word-final.
        word = tuple(token[:-1]) + (token[-1] + '</w>',)
        pairs = get_pairs(word)

        if not pairs:
            return token+'</w>'

        while True:
            # Merge the adjacent pair with the best (lowest) rank first.
            bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float('inf')))
            if bigram not in self.bpe_ranks:
                break
            first, second = bigram
            new_word = []
            i = 0
            while i < len(word):
                try:
                    j = word.index(first, i)
                except ValueError:
                    # No further occurrence of `first`: copy the tail unchanged.
                    new_word.extend(word[i:])
                    break
                new_word.extend(word[i:j])
                i = j

                if word[i] == first and i < len(word)-1 and word[i+1] == second:
                    new_word.append(first+second)
                    i += 2
                else:
                    new_word.append(word[i])
                    i += 1
            new_word = tuple(new_word)
            word = new_word
            if len(word) == 1:
                break
            else:
                pairs = get_pairs(word)
        word = ' '.join(word)
        self.cache[token] = word
        return word

    def encode(self, text):
        """Clean and lower-case *text*, then return its list of BPE token ids."""
        bpe_tokens = []
        text = whitespace_clean(basic_clean(text)).lower()
        for token in re.findall(self.pat, text):
            token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8'))
            bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' '))
        return bpe_tokens

    def decode(self, tokens):
        """Turn token ids back into text; each end-of-word marker becomes a space."""
        text = ''.join([self.decoder[token] for token in tokens])
        text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors="replace").replace('</w>', ' ')
        return text
+train_sample_n: 5 + +REFORWARD: false + +# _BASE_: transformer.yml +reduce_on_plateau: false +noamopt: false +learning_rate: 0.000005 +learning_rate_decay_start: -1 + +self_critical_after: 15 +max_epochs: 50 + +verbose: false +precision: 32 + +use_clipscore: false \ No newline at end of file diff --git a/configs/phase1/clipRN50_mle.yml b/configs/phase1/clipRN50_mle.yml new file mode 100644 index 0000000000000000000000000000000000000000..4756d12c6156724db6f9e7025b28276b86125c5e --- /dev/null +++ b/configs/phase1/clipRN50_mle.yml @@ -0,0 +1,52 @@ +caption_model: transformer +noamopt: true +# noamopt: false +noamopt_warmup: 20000 +label_smoothing: 0.0 +input_json: data/cocotalk.json +input_label_h5: data/cocotalk_label.h5 +input_fc_dir: data/cocotalk_clip_RN50_fc +input_att_dir: data/cocotalk_clip_RN50_att +input_clipscore_vis_dir: data/cocotalk_clipscore_vis +seq_per_img: 5 +# batch_size: 600 +batch_size: 200 + +learning_rate: 0.0005 + +# checkpoint_path: ./save/trans_clip_rn50_sc_pl +checkpoint_path: save/clipRN50_mle/clipRN50_mle + +# Notice: because I'm to lazy, I reuse the option name for RNNs to set the hyperparameters for transformer: +# N=num_layers +# d_model=input_encoding_size +# d_ff=rnn_size + +# will be ignored +num_layers: 6 +input_encoding_size: 512 +rnn_size: 2048 + +# Transformer config +N_enc: 6 +N_dec: 6 +d_model: 512 +d_ff: 2048 +num_att_heads: 8 +dropout: 0.1 + + +learning_rate_decay_start: 0 +scheduled_sampling_start: -1 +save_checkpoint_every: 3000 +language_eval: 1 +val_images_use: 5000 +# max_epochs: 15 +max_epochs: 25 +train_sample_n: 5 + +REFORWARD: false + + +verbose: false +precision: 16 \ No newline at end of file diff --git a/configs/phase1/transformer.yml b/configs/phase1/transformer.yml new file mode 100644 index 0000000000000000000000000000000000000000..3dfa9f78b14a8fbec12a4d1177fa489942f861c7 --- /dev/null +++ b/configs/phase1/transformer.yml @@ -0,0 +1,41 @@ +caption_model: transformer +noamopt: true +noamopt_warmup: 20000 
+label_smoothing: 0.0 +input_json: data/cocotalk.json +input_label_h5: data/cocotalk_label.h5 +input_att_dir: data/cocotalk_att +seq_per_img: 5 +batch_size: 10 +learning_rate: 0.0005 + +checkpoint_path: ./save/trans_rn50_sc + +# Notice: because I'm to lazy, I reuse the option name for RNNs to set the hyperparameters for transformer: +# N=num_layers +# d_model=input_encoding_size +# d_ff=rnn_size + +# will be ignored +num_layers: 6 +input_encoding_size: 512 +rnn_size: 2048 + +# Transformer config +N_enc: 6 +N_dec: 6 +d_model: 512 +d_ff: 2048 +num_att_heads: 8 +dropout: 0.1 + + +learning_rate_decay_start: 0 +scheduled_sampling_start: -1 +save_checkpoint_every: 3000 +language_eval: 1 +val_images_use: 5000 +max_epochs: 15 +train_sample_n: 5 + +REFORWARD: false \ No newline at end of file diff --git a/configs/phase2/FineCapEval_clipRN50_cider.yml b/configs/phase2/FineCapEval_clipRN50_cider.yml new file mode 100644 index 0000000000000000000000000000000000000000..52cac145b854455e92d6ade17be017317907a76a --- /dev/null +++ b/configs/phase2/FineCapEval_clipRN50_cider.yml @@ -0,0 +1,61 @@ +caption_model: transformer +noamopt: true +noamopt_warmup: 20000 +label_smoothing: 0.0 +input_json: data/FineCapEval.json +input_label_h5: none +input_fc_dir: data/FineCapEval_clip_RN50_fc +input_att_dir: data/FineCapEval_clip_RN50_att +input_clipscore_vis_dir: data/FineCapEval_clipscore_vis + +seq_per_img: 5 +batch_size: 200 +learning_rate: 0.0005 + +checkpoint_path: ./save/clipRN50_cider/clipRN50_cider + +# clip_load_path: '/scratch-space/retrieval/save/clip_negative_text/clip_negative_text-epoch=10.ckpt' + +# Notice: because I'm to lazy, I reuse the option name for RNNs to set the hyperparameters for transformer: +# N=num_layers +# d_model=input_encoding_size +# d_ff=rnn_size + +# will be ignored +num_layers: 6 +input_encoding_size: 512 +rnn_size: 2048 + +# Transformer config +N_enc: 6 +N_dec: 6 +d_model: 512 +d_ff: 2048 +num_att_heads: 8 +dropout: 0.1 + + +learning_rate_decay_start: 0 
+scheduled_sampling_start: -1 +save_checkpoint_every: 3000 +language_eval: 1 +val_images_use: 5000 +max_epochs: 15 +train_sample_n: 5 + +REFORWARD: false + +# _BASE_: transformer.yml +reduce_on_plateau: false +noamopt: false +learning_rate: 0.000005 +learning_rate_decay_start: -1 + +self_critical_after: 15 +max_epochs: 50 + +verbose: false +precision: 32 + +# use_clipscore: true +use_clipscore: false \ No newline at end of file diff --git a/configs/phase2/FineCapEval_clipRN50_cider_clips.yml b/configs/phase2/FineCapEval_clipRN50_cider_clips.yml new file mode 100644 index 0000000000000000000000000000000000000000..a74ee8b6d71e3bd260713f77be5ab9d4c8f4ad5d --- /dev/null +++ b/configs/phase2/FineCapEval_clipRN50_cider_clips.yml @@ -0,0 +1,65 @@ +caption_model: transformer +noamopt: true +noamopt_warmup: 20000 +label_smoothing: 0.0 +input_json: data/FineCapEval.json +input_label_h5: none +input_fc_dir: data/FineCapEval_clip_RN50_fc +input_att_dir: data/FineCapEval_clip_RN50_att +input_clipscore_vis_dir: data/FineCapEval_clipscore_vis + +seq_per_img: 5 +batch_size: 200 +learning_rate: 0.0005 + +checkpoint_path: ./save/clipRN50_cider_clips/clipRN50_cider_clips + +# clip_load_path: '/scratch-space/retrieval/save/clip_negative_text/clip_negative_text-epoch=10.ckpt' + +# Notice: because I'm to lazy, I reuse the option name for RNNs to set the hyperparameters for transformer: +# N=num_layers +# d_model=input_encoding_size +# d_ff=rnn_size + +# will be ignored +num_layers: 6 +input_encoding_size: 512 +rnn_size: 2048 + +# Transformer config +N_enc: 6 +N_dec: 6 +d_model: 512 +d_ff: 2048 +num_att_heads: 8 +dropout: 0.1 + + +learning_rate_decay_start: 0 +scheduled_sampling_start: -1 +save_checkpoint_every: 3000 +language_eval: 1 +val_images_use: 5000 +max_epochs: 15 +train_sample_n: 5 + +REFORWARD: false + +# _BASE_: transformer.yml +reduce_on_plateau: false +noamopt: false +learning_rate: 0.000005 +learning_rate_decay_start: -1 + +self_critical_after: 15 +max_epochs: 50 + 
+verbose: false +precision: 32 + +# use_clipscore: true +use_clipscore: false +clipscore_reward_weight: 2.0 +clipscore_mode: clip_s + +use_multi_rewards: true \ No newline at end of file diff --git a/configs/phase2/FineCapEval_clipRN50_clips.yml b/configs/phase2/FineCapEval_clipRN50_clips.yml new file mode 100644 index 0000000000000000000000000000000000000000..5440a45f3196995e2ccfb6e61f88a149fee72b2f --- /dev/null +++ b/configs/phase2/FineCapEval_clipRN50_clips.yml @@ -0,0 +1,64 @@ +caption_model: transformer +noamopt: true +noamopt_warmup: 20000 +label_smoothing: 0.0 +input_json: data/FineCapEval.json +input_label_h5: none +input_fc_dir: data/FineCapEval_clip_RN50_fc +input_att_dir: data/FineCapEval_clip_RN50_att +input_clipscore_vis_dir: data/FineCapEval_clipscore_vis +seq_per_img: 5 +batch_size: 160 +learning_rate: 0.0005 + +checkpoint_path: ./save/clipRN50_clips/clipRN50_clips + +use_multi_rewards: false +use_grammar: false +use_grammar_baseline: false +# clip_load_path: '/scratch-space/retrieval/save/clip_negative_text/clip_negative_text-epoch=10.ckpt' + +# Notice: because I'm to lazy, I reuse the option name for RNNs to set the hyperparameters for transformer: +# N=num_layers +# d_model=input_encoding_size +# d_ff=rnn_size + +# will be ignored +num_layers: 6 +input_encoding_size: 512 +rnn_size: 2048 + +# Transformer config +N_enc: 6 +N_dec: 6 +d_model: 512 +d_ff: 2048 +num_att_heads: 8 +dropout: 0.1 + + +learning_rate_decay_start: 0 +scheduled_sampling_start: -1 +save_checkpoint_every: 3000 +language_eval: 0 +val_images_use: 5000 +max_epochs: 15 +train_sample_n: 5 + +REFORWARD: false + +# _BASE_: transformer.yml +reduce_on_plateau: false +noamopt: false +learning_rate: 0.000005 +learning_rate_decay_start: -1 + +self_critical_after: 15 +max_epochs: 50 + +verbose: false +precision: 32 + +# use_clipscore: true +use_clipscore: false +clipscore_reward_weight: 2.0 \ No newline at end of file diff --git a/configs/phase2/FineCapEval_clipRN50_clips_grammar.yml 
b/configs/phase2/FineCapEval_clipRN50_clips_grammar.yml new file mode 100644 index 0000000000000000000000000000000000000000..854394e9125a81c7351c555dc598eb541eaf20d3 --- /dev/null +++ b/configs/phase2/FineCapEval_clipRN50_clips_grammar.yml @@ -0,0 +1,64 @@ +caption_model: transformer +noamopt: true +noamopt_warmup: 20000 +label_smoothing: 0.0 +input_json: data/FineCapEval.json +input_label_h5: none +input_fc_dir: data/FineCapEval_clip_RN50_fc +input_att_dir: data/FineCapEval_clip_RN50_att +input_clipscore_vis_dir: data/FineCapEval_clipscore_vis +seq_per_img: 5 +batch_size: 160 +learning_rate: 0.0005 + +checkpoint_path: ./save/clipRN50_clips_grammar/clipRN50_clips_grammar + +use_multi_rewards: true +use_grammar: true +use_grammar_baseline: true +# clip_load_path: '/scratch-space/retrieval/save/clip_negative_text/clip_negative_text-epoch=10.ckpt' + +# Notice: because I'm to lazy, I reuse the option name for RNNs to set the hyperparameters for transformer: +# N=num_layers +# d_model=input_encoding_size +# d_ff=rnn_size + +# will be ignored +num_layers: 6 +input_encoding_size: 512 +rnn_size: 2048 + +# Transformer config +N_enc: 6 +N_dec: 6 +d_model: 512 +d_ff: 2048 +num_att_heads: 8 +dropout: 0.1 + + +learning_rate_decay_start: 0 +scheduled_sampling_start: -1 +save_checkpoint_every: 3000 +language_eval: 0 +val_images_use: 5000 +max_epochs: 15 +train_sample_n: 5 + +REFORWARD: false + +# _BASE_: transformer.yml +reduce_on_plateau: false +noamopt: false +learning_rate: 0.000005 +learning_rate_decay_start: -1 + +self_critical_after: 15 +max_epochs: 50 + +verbose: false +precision: 32 + +# use_clipscore: true +use_clipscore: false +clipscore_reward_weight: 2.0 \ No newline at end of file diff --git a/configs/phase2/clipRN50_cider.yml b/configs/phase2/clipRN50_cider.yml new file mode 100644 index 0000000000000000000000000000000000000000..924b2dacecf012f158502136169b0340d37e9a47 --- /dev/null +++ b/configs/phase2/clipRN50_cider.yml @@ -0,0 +1,58 @@ +caption_model: transformer 
+noamopt: true +noamopt_warmup: 20000 +label_smoothing: 0.0 +input_json: data/cocotalk.json +input_label_h5: data/cocotalk_label.h5 +input_fc_dir: data/cocotalk_clip_RN50_fc +input_att_dir: data/cocotalk_clip_RN50_att +# used only for evaluation +input_clipscore_vis_dir: data/cocotalk_clipscore_vis + +seq_per_img: 5 +batch_size: 200 +learning_rate: 0.0005 + +# checkpoint_path: ./save/trans_clip_rn50_sc_pl_scst_cider +checkpoint_path: save/clipRN50_cider/clipRN50_cider + +# Notice: because I'm to lazy, I reuse the option name for RNNs to set the hyperparameters for transformer: +# N=num_layers +# d_model=input_encoding_size +# d_ff=rnn_size + +# will be ignored +num_layers: 6 +input_encoding_size: 512 +rnn_size: 2048 + +# Transformer config +N_enc: 6 +N_dec: 6 +d_model: 512 +d_ff: 2048 +num_att_heads: 8 +dropout: 0.1 + + +learning_rate_decay_start: 0 +scheduled_sampling_start: -1 +save_checkpoint_every: 3000 +language_eval: 1 +val_images_use: 5000 +max_epochs: 15 +train_sample_n: 5 + +REFORWARD: false + +# _BASE_: transformer.yml +reduce_on_plateau: false +noamopt: false +learning_rate: 0.000005 +learning_rate_decay_start: -1 + +self_critical_after: 15 +max_epochs: 40 + +verbose: false +precision: 32 \ No newline at end of file diff --git a/configs/phase2/clipRN50_cider_clips.yml b/configs/phase2/clipRN50_cider_clips.yml new file mode 100644 index 0000000000000000000000000000000000000000..d1b0f3ff7ce92d80fcb1f77b769cfadec471bc45 --- /dev/null +++ b/configs/phase2/clipRN50_cider_clips.yml @@ -0,0 +1,61 @@ +caption_model: transformer +noamopt: true +noamopt_warmup: 20000 +label_smoothing: 0.0 +input_json: data/cocotalk.json +input_label_h5: data/cocotalk_label.h5 +input_fc_dir: data/cocotalk_clip_RN50_fc +input_att_dir: data/cocotalk_clip_RN50_att +input_clipscore_vis_dir: data/cocotalk_clipscore_vis +seq_per_img: 5 +batch_size: 160 +learning_rate: 0.0005 + +checkpoint_path: save/clipRN50_cider_clips/clipRN50_cider_clips + +# Notice: because I'm to lazy, I reuse the 
option name for RNNs to set the hyperparameters for transformer: +# N=num_layers +# d_model=input_encoding_size +# d_ff=rnn_size + +# will be ignored +num_layers: 6 +input_encoding_size: 512 +rnn_size: 2048 + +# Transformer config +N_enc: 6 +N_dec: 6 +d_model: 512 +d_ff: 2048 +num_att_heads: 8 +dropout: 0.1 + + +learning_rate_decay_start: 0 +scheduled_sampling_start: -1 +save_checkpoint_every: 3000 +language_eval: 1 +val_images_use: 5000 +max_epochs: 15 +train_sample_n: 5 + +REFORWARD: false + +# _BASE_: transformer.yml +reduce_on_plateau: false +noamopt: false +learning_rate: 0.000005 +learning_rate_decay_start: -1 + +self_critical_after: 15 +max_epochs: 40 + +verbose: false +precision: 32 + +use_clipscore: true +clipscore_reward_weight: 2.0 +clipscore_mode: clip_s + +use_multi_rewards: true \ No newline at end of file diff --git a/configs/phase2/clipRN50_clips.yml b/configs/phase2/clipRN50_clips.yml new file mode 100644 index 0000000000000000000000000000000000000000..2b62f5c5d5cbc8ab5c8ece8faa87adcf7a0e70fa --- /dev/null +++ b/configs/phase2/clipRN50_clips.yml @@ -0,0 +1,58 @@ +caption_model: transformer +noamopt: true +noamopt_warmup: 20000 +label_smoothing: 0.0 +input_json: data/cocotalk.json +input_label_h5: data/cocotalk_label.h5 +input_fc_dir: data/cocotalk_clip_RN50_fc +input_att_dir: data/cocotalk_clip_RN50_att +input_clipscore_vis_dir: data/cocotalk_clipscore_vis +seq_per_img: 5 +batch_size: 160 +learning_rate: 0.0005 + +checkpoint_path: save/clipRN50_clips/clipRN50_clips + +# Notice: because I'm to lazy, I reuse the option name for RNNs to set the hyperparameters for transformer: +# N=num_layers +# d_model=input_encoding_size +# d_ff=rnn_size + +# will be ignored +num_layers: 6 +input_encoding_size: 512 +rnn_size: 2048 + +# Transformer config +N_enc: 6 +N_dec: 6 +d_model: 512 +d_ff: 2048 +num_att_heads: 8 +dropout: 0.1 + + +learning_rate_decay_start: 0 +scheduled_sampling_start: -1 +save_checkpoint_every: 3000 +language_eval: 1 +val_images_use: 5000 
+max_epochs: 15 +train_sample_n: 5 + +REFORWARD: false + +# _BASE_: transformer.yml +reduce_on_plateau: false +noamopt: false +learning_rate: 0.000005 +learning_rate_decay_start: -1 + +self_critical_after: 15 +max_epochs: 40 + +verbose: false +precision: 32 + +use_clipscore: true +clipscore_reward_weight: 2.0 \ No newline at end of file diff --git a/configs/phase2/clipRN50_clips_grammar.yml b/configs/phase2/clipRN50_clips_grammar.yml new file mode 100644 index 0000000000000000000000000000000000000000..c9db26ff17158568d0f3d2a63837f3925dc007b8 --- /dev/null +++ b/configs/phase2/clipRN50_clips_grammar.yml @@ -0,0 +1,64 @@ +caption_model: transformer +noamopt: true +noamopt_warmup: 20000 +label_smoothing: 0.0 +input_json: data/cocotalk.json +input_label_h5: data/cocotalk_label.h5 +input_fc_dir: data/cocotalk_clip_RN50_fc +input_att_dir: data/cocotalk_clip_RN50_att +input_clipscore_vis_dir: data/cocotalk_clipscore_vis +seq_per_img: 5 +batch_size: 160 +learning_rate: 0.0005 + +checkpoint_path: save/clipRN50_clips_grammar/clipRN50_clips_grammar + +use_multi_rewards: true +use_grammar: true +use_grammar_baseline: true +# clip_load_path: '/scratch-space/retrieval/save/clip_negative_text/clip_negative_text-epoch=10.ckpt' +clip_load_path: 'retrieval/save/clip_negative_text/clip_negative_text-epoch=12.ckpt' + +# Notice: because I'm to lazy, I reuse the option name for RNNs to set the hyperparameters for transformer: +# N=num_layers +# d_model=input_encoding_size +# d_ff=rnn_size + +# will be ignored +num_layers: 6 +input_encoding_size: 512 +rnn_size: 2048 + +# Transformer config +N_enc: 6 +N_dec: 6 +d_model: 512 +d_ff: 2048 +num_att_heads: 8 +dropout: 0.1 + + +learning_rate_decay_start: 0 +scheduled_sampling_start: -1 +save_checkpoint_every: 3000 +language_eval: 1 +val_images_use: 5000 +max_epochs: 15 +train_sample_n: 5 + +REFORWARD: false + +# _BASE_: transformer.yml +reduce_on_plateau: false +noamopt: false +learning_rate: 0.000005 +learning_rate_decay_start: -1 + 
+self_critical_after: 15 +max_epochs: 40 + +verbose: false +precision: 32 + +use_clipscore: true +clipscore_reward_weight: 2.0 \ No newline at end of file diff --git a/configs/phase2/transformer.yml b/configs/phase2/transformer.yml new file mode 100644 index 0000000000000000000000000000000000000000..3dfa9f78b14a8fbec12a4d1177fa489942f861c7 --- /dev/null +++ b/configs/phase2/transformer.yml @@ -0,0 +1,41 @@ +caption_model: transformer +noamopt: true +noamopt_warmup: 20000 +label_smoothing: 0.0 +input_json: data/cocotalk.json +input_label_h5: data/cocotalk_label.h5 +input_att_dir: data/cocotalk_att +seq_per_img: 5 +batch_size: 10 +learning_rate: 0.0005 + +checkpoint_path: ./save/trans_rn50_sc + +# Notice: because I'm to lazy, I reuse the option name for RNNs to set the hyperparameters for transformer: +# N=num_layers +# d_model=input_encoding_size +# d_ff=rnn_size + +# will be ignored +num_layers: 6 +input_encoding_size: 512 +rnn_size: 2048 + +# Transformer config +N_enc: 6 +N_dec: 6 +d_model: 512 +d_ff: 2048 +num_att_heads: 8 +dropout: 0.1 + + +learning_rate_decay_start: 0 +scheduled_sampling_start: -1 +save_checkpoint_every: 3000 +language_eval: 1 +val_images_use: 5000 +max_epochs: 15 +train_sample_n: 5 + +REFORWARD: false \ No newline at end of file diff --git a/data/README.md b/data/README.md new file mode 100644 index 0000000000000000000000000000000000000000..c786a9e85300c02f477a4d977cee587f35162b0d --- /dev/null +++ b/data/README.md @@ -0,0 +1 @@ +directory to store preprocessed files \ No newline at end of file diff --git a/predict.py b/predict.py new file mode 100644 index 0000000000000000000000000000000000000000..23293cc6f92669f5c70ef7ba262784ae29232cf0 --- /dev/null +++ b/predict.py @@ -0,0 +1,182 @@ +import os +import numpy as np +import json +import torch +import torch.nn as nn +import clip +import pytorch_lightning as pl +from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize +from PIL import Image +from 
timm.models.vision_transformer import resize_pos_embed +from cog import BasePredictor, Path, Input + +import captioning.utils.opts as opts +import captioning.models as models +import captioning.utils.misc as utils + + +class Predictor(BasePredictor): + def setup(self): + import __main__ + __main__.ModelCheckpoint = pl.callbacks.ModelCheckpoint + + self.device = torch.device("cuda:0") + self.dict_json = json.load(open("./data/cocotalk.json")) + self.ix_to_word = self.dict_json["ix_to_word"] + self.vocab_size = len(self.ix_to_word) + self.clip_model, self.clip_transform = clip.load( + "RN50", jit=False, device=self.device + ) + + self.preprocess = Compose( + [ + Resize((448, 448), interpolation=Image.BICUBIC), + CenterCrop((448, 448)), + ToTensor(), + ] + ) + + def predict( + self, + image: Path = Input( + description="Input image.", + ), + reward: str = Input( + choices=["mle", "cider", "clips", "cider_clips", "clips_grammar"], + default="clips_grammar", + description="Choose a reward criterion.", + ), + ) -> str: + + self.device = torch.device("cuda:0") + self.dict_json = json.load(open("./data/cocotalk.json")) + self.ix_to_word = self.dict_json["ix_to_word"] + self.vocab_size = len(self.ix_to_word) + self.clip_model, self.clip_transform = clip.load( + "RN50", jit=False, device=self.device + ) + + self.preprocess = Compose( + [ + Resize((448, 448), interpolation=Image.BICUBIC), + CenterCrop((448, 448)), + ToTensor(), + ] + ) + + cfg = ( + f"configs/phase1/clipRN50_{reward}.yml" + if reward == "mle" + else f"configs/phase2/clipRN50_{reward}.yml" + ) + print("Loading cfg from", cfg) + + opt = opts.parse_opt(parse=False, cfg=cfg) + print("vocab size:", self.vocab_size) + + seq_length = 1 + opt.vocab_size = self.vocab_size + opt.seq_length = seq_length + + opt.batch_size = 1 + opt.vocab = self.ix_to_word + print(opt.caption_model) + + model = models.setup(opt) + del opt.vocab + + ckpt_path = opt.checkpoint_path + "-last.ckpt" + print("Loading checkpoint from", 
ckpt_path) + raw_state_dict = torch.load(ckpt_path, map_location=self.device) + + strict = True + state_dict = raw_state_dict["state_dict"] + + if "_vocab" in state_dict: + model.vocab = utils.deserialize(state_dict["_vocab"]) + del state_dict["_vocab"] + elif strict: + raise KeyError + if "_opt" in state_dict: + saved_model_opt = utils.deserialize(state_dict["_opt"]) + del state_dict["_opt"] + # Make sure the saved opt is compatible with the curren topt + need_be_same = ["caption_model", "rnn_type", "rnn_size", "num_layers"] + for checkme in need_be_same: + if ( + getattr(saved_model_opt, checkme) + in [ + "updown", + "topdown", + ] + and getattr(opt, checkme) in ["updown", "topdown"] + ): + continue + assert getattr(saved_model_opt, checkme) == getattr(opt, checkme), ( + "Command line argument and saved model disagree on '%s' " % checkme + ) + elif strict: + raise KeyError + res = model.load_state_dict(state_dict, strict) + print(res) + + model = model.to(self.device) + model.eval() + + image_mean = ( + torch.Tensor([0.48145466, 0.4578275, 0.40821073]) + .to(self.device) + .reshape(3, 1, 1) + ) + image_std = ( + torch.Tensor([0.26862954, 0.26130258, 0.27577711]) + .to(self.device) + .reshape(3, 1, 1) + ) + + num_patches = 196 # 600 * 1000 // 32 // 32 + pos_embed = nn.Parameter( + torch.zeros( + 1, + num_patches + 1, + self.clip_model.visual.attnpool.positional_embedding.shape[-1], + device=self.device, + ), + ) + pos_embed.weight = resize_pos_embed( + self.clip_model.visual.attnpool.positional_embedding.unsqueeze(0), pos_embed + ) + self.clip_model.visual.attnpool.positional_embedding = pos_embed + + with torch.no_grad(): + image = self.preprocess(Image.open(str(image)).convert("RGB")) + image = torch.tensor(np.stack([image])).to(self.device) + image -= image_mean + image /= image_std + + tmp_att, tmp_fc = self.clip_model.encode_image(image) + tmp_att = tmp_att[0].permute(1, 2, 0) + + att_feat = tmp_att + + # Inference configurations + eval_kwargs = {} + 
eval_kwargs.update(vars(opt)) + + with torch.no_grad(): + fc_feats = torch.zeros((1, 0)).to(self.device) + att_feats = att_feat.view(1, 196, 2048).float().to(self.device) + att_masks = None + + # forward the model to also get generated samples for each image + # Only leave one feature for each image, in case duplicate sample + tmp_eval_kwargs = eval_kwargs.copy() + tmp_eval_kwargs.update({"sample_n": 1}) + seq, seq_logprobs = model( + fc_feats, att_feats, att_masks, opt=tmp_eval_kwargs, mode="sample" + ) + seq = seq.data + + sents = utils.decode_sequence(model.vocab, seq) + + return sents[0] diff --git a/retrieval/README.md b/retrieval/README.md new file mode 100644 index 0000000000000000000000000000000000000000..2f5cce9ad9b93234fa5ac7a0e99f05868d883fd0 --- /dev/null +++ b/retrieval/README.md @@ -0,0 +1,5 @@ +# Finetuning CLIP reward model + +```bash +python train_pl.py --cfg clip_negative_text --id clip_negative_text +``` \ No newline at end of file diff --git a/retrieval/caption_data.py b/retrieval/caption_data.py new file mode 100644 index 0000000000000000000000000000000000000000..595a81ae5346937e5d9174401cd8a62e78946864 --- /dev/null +++ b/retrieval/caption_data.py @@ -0,0 +1,500 @@ +from torch.utils.data import DataLoader, Dataset, Sampler +from pathlib import Path +import json +from multiprocessing import Pool +from tqdm import tqdm +from PIL import Image +import random +import numpy as np +import torch +import torchvision +import torchvision.transforms as T + +from torch.utils.data.distributed import DistributedSampler + +from transformers import T5Tokenizer, BertTokenizer, BertTokenizerFast, CLIPTokenizer + +import text_utils + +project_dir = Path(__file__).parent.resolve() +workspace_dir = project_dir.parent.parent +dataset_dir = workspace_dir.joinpath('datasets/').resolve() +# coco_dir = dataset_dir.joinpath('COCO') +# vg_dir = dataset_dir.joinpath('VG') +coco_img_dir = dataset_dir.joinpath('COCO/images/') +coco_data_dir = 
project_dir.parent.joinpath('CLIP-ViL/CLIP-ViL-Direct/caption/data/') +# coco_feature_dir = coco_dir.joinpath('features') + + +class COCORetrievalDataset(Dataset): + def __init__(self, split='karpathy_train', rank=-1, topk=-1, verbose=True, args=None, mode='train'): + super().__init__() + + self.topk = topk + self.verbose = verbose + self.args = args + self.rank = rank + self.mode = mode + + # Loading datasets to data + self.source = split + if self.verbose: + print('Data source: ', self.source) + + # if self.args.tokenizer is None: + # self.args.tokenizer = self.args.decoder_backbone + + # if 'bert' in self.args.tokenizer: + # self.tokenizer = BertTokenizerFast.from_pretrained( + # self.args.tokenizer, + # # max_length=self.args.max_text_length, + # # do_lower_case=self.args.do_lower_case + # ) + # elif 'clip' in self.args.tokenizer: + # self.tokenizer = CLIPTokenizer.from_pretrained( + # self.args.tokenizer, + # # max_length=self.args.max_text_length, + # # do_lower_case=self.args.do_lower_case + # ) + + self.tokenizer = CLIPTokenizer.from_pretrained( + self.args.tokenizer, + # max_length=self.args.max_text_length, + # do_lower_case=self.args.do_lower_case + ) + + with open(coco_data_dir.joinpath('cocotalk.json')) as f: + self.vocab = list(json.load(f)['ix_to_word'].values()) + popped = self.vocab.pop(-1) + assert popped == 'UNK' + if self.verbose: + print('vocab size: ', len(self.vocab)) + + + data_info_path = coco_data_dir.joinpath('dataset_coco.json') + with open(data_info_path) as f: + karpathy_data = json.load(f) + + split_rename = { + 'train': 'train', + 'restval': 'train', + 'val': 'val', + 'test': 'test' + } + + n_images = 0 + + data = [] + # self.vocab = set() + for datum in karpathy_data['images']: + re_split = split_rename[datum['split']] + + # if re_split == 'train': + # for d in datum['sentences']: + # self.vocab = self.vocab.union(set(d['tokens'])) + + if re_split != self.source.split('_')[-1]: + continue + + if re_split == 'train': + # for d in 
datum['sentences']: + # img_id = datum['filename'].split('.')[0] + # new_datum = { + # 'filename': datum['filename'], + # 'img_id': img_id, + # 'sent': d['raw'].strip(), + # 'targets': [d['raw'].strip() for d in datum['sentences']], + # 'is_train': True, + # 'cocoid': datum['cocoid'] + # } + # data.append(new_datum) + img_id = datum['filename'].split('.')[0] + new_datum = { + 'filename': datum['filename'], + 'img_id': img_id, + # 'sent': d['raw'], + # 'targets': [d['raw'].strip() for d in datum['sentences']], + 'targets': [" ".join(d['tokens']) for d in datum['sentences']], + 'is_train': True, + 'cocoid': datum['cocoid'] + } + data.append(new_datum) + + else: + img_id = datum['filename'].split('.')[0] + new_datum = { + 'filename': datum['filename'], + 'img_id': img_id, + # 'sent': d['raw'], + # 'targets': [d['raw'].strip() for d in datum['sentences']], + 'targets': [" ".join(d['tokens']) for d in datum['sentences']], + 'is_train': False, + 'cocoid': datum['cocoid'] + } + data.append(new_datum) + + n_images += 1 + + if self.verbose: + print(f"{self.source} has {n_images} images") + # print(f"Loaded {len(data)} data from", split) + + self.n_gpus = torch.cuda.device_count() + + if self.topk > 0: + data = data[:self.topk] + if self.verbose: + print(f"Use only {self.topk} data") + + self.data = data + + # if self.verbose: + # print("# all sentences:", len(self.data)) + + if self.args.load_feat: + # feat_dir = coco_dir.joinpath('' + # self.feat_loader = HybridLoader('/scratch-space/CLIP-ViL/CLIP-ViL-Direct/caption/data/cocotalk_clipscore_vis', ext='.npy', in_memory=False) + self.feat_loader = HybridLoader( + coco_data_dir.joinpath('cocotalk_clipscore_vis'), + ext='.npy', in_memory=False) + else: + if 'openai/clip' in self.args.encoder_backbone: + # from transformers import CLIPProcessor + # self.processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32", + # size=args.image_size, + # do_resize=True, + # do_center_crop=False, + # ) + # self.img_transform = 
lambda image: self.processor.feature_extractor( + # image, + # return_tensors='pt')['pixel_values'][0] + + self.image_mean = [0.48145466, 0.4578275, 0.40821073] + self.image_std = [0.26862954, 0.26130258, 0.27577711] + + # captioning + # self.img_transform = T.Compose([ + # T.Resize((self.args.image_size, self.args.image_size)) + # ]) + + # retrieval + self.img_transform = T.Compose([ + T.Resize(self.args.image_size, interpolation=T.functional.InterpolationMode.BICUBIC), + T.CenterCrop(self.args.image_size) + ]) + + self.img_tensor_transform = T.Compose([ + # T.RandomCrop(224), + # T.RandomHorizontalFlip(p=0.3), + T.ConvertImageDtype(torch.float), + T.Normalize(self.image_mean, self.image_std) + ] + ) + # elif 'google/vit' in self.args.encoder_backbone: + # self.image_mean = [0.5, 0.5, 0.5] + # self.image_std = [0.5, 0.5, 0.5] + + # self.img_transform = T.Compose([ + # # T.PILToTensor(), + # T.Resize((self.args.image_size, self.args.image_size)) + # ]) + + # self.img_tensor_transform = T.Compose([ + # # T.RandomCrop(224), + # # T.RandomHorizontalFlip(p=0.3), + # T.ConvertImageDtype(torch.float), + # T.Normalize(self.image_mean, self.image_std) + # ] + # ) + + def get_negative_text(self, text): + neg_type = random.choice(['repeat', 'remove', 'insert', 'swap', 'shuffle']) + + if neg_type == 'repeat': + text = text_utils.repeat(text) + elif neg_type == 'remove': + text = text_utils.remove(text) + elif neg_type == 'insert': + text = text_utils.insert(text, self.vocab) + elif neg_type == 'swap': + text = text_utils.swap(text, self.vocab) + elif neg_type == 'shuffle': + text = text_utils.shuffle(text) + + return text, neg_type + + def __len__(self): + return len(self.data) + + def __getitem__(self, idx): + datum = self.data[idx] + return self.process_datum(datum) + + def process_datum(self, datum): + out_dict = {} + + ###### Image ###### + + if self.args.load_feat: + cocoid = datum['cocoid'] + out_dict['cocoid'] = str(cocoid) + img_feat = 
def collate_fn(self, batch):
    """Merge a list of per-sample dicts into a single batch dict.

    Depending on ``self.args.load_feat`` the batch carries either
    pre-extracted features ('coco_ids', 'img_feats') or raw images
    ('img_id', 'img_paths', 'img_tensor' after ``self.img_tensor_transform``).
    When the samples contain 'sent', the positive and negative sentences are
    tokenized into ``(input_ids, attention_mask)`` pairs under 'text' and
    'neg_text'. 'targets' collects the reference captions of every sample
    that provides them.
    """
    use_feats = self.args.load_feat
    n_items = len(batch)

    coco_ids, img_ids, img_paths, targets = [], [], [], []

    # Pre-allocate the output buffer; rows are filled in the loop below.
    if use_feats:
        img_feats = torch.zeros(n_items, 512, dtype=torch.float)
    else:
        side = self.args.image_size
        img_tensor = torch.zeros(n_items, 3, side, side, dtype=torch.uint8)

    for row, item in enumerate(batch):
        if use_feats:
            coco_ids.append(item['cocoid'])
            img_feats[row] = item['img_feat']
        else:
            img_ids.append(item['img_id'])
            img_paths.append(item['img_path'])
            img_tensor[row] = item['img_tensor']

        if 'targets' in item:
            targets.append(item['targets'])

    collated = {}

    if 'sent' in batch[0]:
        def tokenize(sentences):
            return self.tokenizer(
                sentences, truncation=True, padding=True, return_tensors='pt')

        positive = tokenize([item['sent'] for item in batch])
        negative = tokenize([item['neg_sent'] for item in batch])

        collated['text'] = (positive.input_ids, positive.attention_mask)
        collated['neg_text'] = (negative.input_ids, negative.attention_mask)

    if use_feats:
        collated['coco_ids'] = coco_ids
        collated['img_feats'] = img_feats
    else:
        # Dtype conversion and normalization are applied once on the whole batch.
        collated['img_id'] = img_ids
        collated['img_paths'] = img_paths
        collated['img_tensor'] = self.img_tensor_transform(img_tensor)

    collated['targets'] = targets

    return collated
class HybridLoader:
    """Load per-key feature arrays stored as individual files in a directory.

    Each key maps to the file ``<db_path>/<key><ext>``. Files with the
    '.npy' extension are loaded as a plain array; any other extension is
    loaded with ``np.load`` and the array stored under 'feat' is returned.

    in_memory: when True, the raw (still serialized) bytes of every file
    read so far are kept in ``self.features`` and reused on later lookups,
    so each file is read from disk at most once.
    """

    def __init__(self, db_path, ext='.npy', in_memory=False):
        # Local import: keeps the block self-contained without touching the
        # file-level imports.
        import io

        self.db_path = db_path
        self.ext = ext
        if self.ext == '.npy':
            self.loader = lambda x: np.load(io.BytesIO(x))
        else:
            self.loader = lambda x: np.load(io.BytesIO(x))['feat']

        self.in_memory = in_memory
        if self.in_memory:
            # key -> raw file bytes; the bytes are cached rather than the
            # decoded array so compressed files stay compressed in memory.
            self.features = {}

    def get(self, key):
        """Return the feature array stored for *key*.

        Raises whatever ``open`` raises (e.g. FileNotFoundError) when the
        key has no file and is not cached.
        """
        if self.in_memory and key in self.features:
            f_input = self.features[key]
        else:
            path = os.path.join(self.db_path, key + self.ext)
            # Context manager so the handle is closed deterministically.
            with open(path, 'rb') as f:
                f_input = f.read()

            if self.in_memory:
                self.features[key] = f_input

        # Decode the serialized bytes into an array.
        feat = self.loader(f_input)

        return feat
+import string +# non-standard dependencies: +import h5py +from six.moves import cPickle +import numpy as np +import torch +import torchvision.models as models +import skimage.io + +from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize +from PIL import Image +from torch import nn + + +class CLIPScore(nn.Module): + def __init__(self, clipscore_w=2.5, image_size=224, mode='clip_s', use_grammar=False, joint_out=False): + super(CLIPScore, self).__init__() + # from transformers import CLIPModel, CLIPTokenizer + self.clip_model = CLIPModel.from_pretrained( + 'openai/clip-vit-base-patch32') + self.tokenizer = CLIPTokenizer.from_pretrained( + 'openai/clip-vit-base-patch32') + + self.clip_model.eval() + + self.clipscore_w = clipscore_w + + self.image_transform = self._transform(image_size) + + self.mode = mode + assert mode in ['clip_s', 'refclip_s'] + + self.use_grammar = use_grammar + self.joint_out = joint_out + + if self.use_grammar and self.joint_out is False: + self.grammar_score_head = nn.Sequential( + nn.Linear(self.clip_model.text_embed_dim, self.clip_model.projection_dim, bias=False), + nn.ReLU(), + nn.Linear(self.clip_model.projection_dim, 2, bias=False) + ) + + def _transform(self, n_px): + return Compose([ + Resize(n_px, interpolation=Image.BICUBIC), + CenterCrop(n_px), + lambda image: image.convert("RGB"), + ToTensor(), + Normalize((0.48145466, 0.4578275, 0.40821073), + (0.26862954, 0.26130258, 0.27577711)), + ]) + + def load_image(self, image_path): + image = Image.open(image_path) + return image + + # @torch.no_grad() + def image_extract(self, image): + if isinstance(image, str): + image = self.load_image(image) + if not isinstance(image, torch.Tensor): + image = self.image_transform(image) + + img_tensor = image.view(-1, 3, 224, 224) + device = next(self.clip_model.parameters()).device + img_tensor = img_tensor.to(device) + + clip_model = self.clip_model + + img_feat = clip_model.vision_model(img_tensor).pooler_output + 
def text_extract(self, text, prompt="A photo depicts", proj_norm=True):
    """Encode text with the CLIP text model.

    text: a string, a list of strings, or an already tokenized
        ``(input_ids, attention_mask)`` pair of tensors. Strings are
        prefixed with *prompt* and tokenized here; a tokenized pair is used
        as is (the prompt is not applied to it).
    proj_norm: when True, the pooled output is passed through the text
        projection and L2-normalized along the last dimension; when False,
        the pooled output of the text model is returned directly.

    Raises TypeError for any other kind of *text*.
    """
    if isinstance(text, str):
        text = [text]

    if isinstance(text, list):
        input_text = [" ".join([prompt, txt]) for txt in text]

        tokenized = self.tokenizer(
            input_text, return_tensors='pt', padding=True)

        input_ids = tokenized.input_ids
        attention_mask = tokenized.attention_mask
    elif (isinstance(text, tuple) and len(text) == 2
            and isinstance(text[0], torch.Tensor)):
        input_ids, attention_mask = text
    else:
        # Fail with a clear message instead of an unbound-local error
        # further down.
        raise TypeError(
            "text must be a str, a list of str, or an "
            "(input_ids, attention_mask) tuple of tensors, got %s"
            % type(text).__name__)

    clip_model = self.clip_model
    # Run on whatever device the model parameters live on.
    device = next(self.clip_model.parameters()).device
    input_ids = input_ids.to(device)
    attention_mask = attention_mask.to(device)

    text_feat = clip_model.text_model(input_ids, attention_mask).pooler_output

    if proj_norm:
        text_feat = clip_model.text_projection(text_feat)
        text_feat = text_feat / text_feat.norm(dim=-1, keepdim=True)

    return text_feat
+ + ref_score = ref_score.view(B, K).max(dim=1).values + + assert clip_s.size() == (B,) + assert clip_s.size() == ref_score.size() + + # harmonic mean + refclip_s = 2 / (1 / clip_s + 1 / ref_score) + return refclip_s + + # # @torch.no_grad() + # def forward(self, + # images=None, text=None, + # img_feat=None, text_feat=None, + # ref_text=None, ref_text_feat=None, ref_text_mask=None, + # prompt="A photo depicts", + # mode=None): + # if img_feat is None: + # img_feat = self.image_extract(images) + # img_feat = img_feat.view(-1, 512) + + # if text_feat is None: + # text_feat = self.text_extract(text, prompt=prompt) + # text_feat = text_feat.view(-1, 512) + + # if mode is None: + # mode = self.mode + # assert mode in ['clip_s', 'refclip_s'] + + # if mode == 'clip_s': + # clip_s = self.calc_clip_s(img_feat, text_feat) + # return clip_s + # elif mode == 'refclip_s': + # if ref_text_feat is None: + # ref_text_feat = self.text_extract(ref_text, prompt=prompt) + # ref_text_feat = ref_text_feat.view(-1, 512) + + # refclip_s = self.calc_refclip_s( + # img_feat, text_feat, ref_text_feat, ref_text_mask=ref_text_mask) + # return refclip_s + + + def train_step(self, + images=None, text=None, + img_feat=None, text_feat=None, + neg_text=None, neg_text_feat=None, + # ref_text=None, ref_text_feat=None, ref_text_mask=None, + prompt="A photo depicts", + # return_loss=True, + **kwargs): + + if img_feat is None: + img_feat = self.image_extract(images) + img_feat = img_feat.view(-1, 512) + + B = img_feat.size(0) + + if self.joint_out: + pos_text_feat = self.text_extract(text, prompt=prompt, proj_norm=False).view(B, 512) + neg_text_feat = self.text_extract(neg_text, prompt=prompt, proj_norm=False).view(-1, 512) + neg_B = neg_text_feat.size(0) + + # [B+neg_B, 512] + text_feat = torch.cat([pos_text_feat, neg_text_feat], dim=0) + + text_cont_feat = self.clip_model.text_projection(text_feat) + text_cont_feat = text_cont_feat / text_cont_feat.norm(dim=-1, keepdim=True) + + text_cont_feat = 
text_cont_feat.view(B+neg_B, 512) + + logit_scale = self.clip_model.logit_scale.exp() + + # [B+neg_B * B] + logits_per_text = torch.matmul(text_cont_feat, img_feat.t()) * logit_scale + + # image-to-text label: positive text + caption_loss = -torch.diag(nn.functional.log_softmax(logits_per_text, dim=0)[:B]).mean() + + # calculate text-to-image only on positive text + image_loss = -torch.diag(nn.functional.log_softmax(logits_per_text[:B], dim=1)).mean() + + clip_loss = (caption_loss + image_loss) / 2.0 + + out = { + 'clip_loss': clip_loss, + 'img_feat': img_feat, + 'text_feat': text_cont_feat[:B].detach(), + # 'neg_text_feat': neg_text_feat, + } + + return out + + + else: + if text_feat is None: + text_feat = self.text_extract(text, prompt=prompt, proj_norm=False) + + text_cont_feat = self.clip_model.text_projection(text_feat) + text_cont_feat = text_cont_feat / \ + text_cont_feat.norm(dim=-1, keepdim=True) + + text_cont_feat = text_cont_feat.view(B, 512) + + + # cosine similarity as logits + logit_scale = self.clip_model.logit_scale.exp() + logits_per_text = torch.matmul(text_cont_feat, img_feat.t()) * logit_scale + # logits_per_image = logits_per_text.T + + clip_loss = clip_loss_fn(logits_per_text) + + + # negative sampling + pos_text_feat = text_feat.view(B, 512) + neg_text_feat = self.text_extract(neg_text, prompt=prompt, proj_norm=False).view(B, 512) + + grammar_text_feat = torch.cat([pos_text_feat, neg_text_feat], dim=0) + + # 2B, 1 + grammar_text_logit = self.grammar_score_head(grammar_text_feat) + grammar_labels = torch.LongTensor([1] * B + [0] * B).to(grammar_text_logit.device).view(2 * B) + + grammar_loss = torch.nn.functional.cross_entropy(grammar_text_logit, grammar_labels) + + grammar_pred = grammar_text_logit.argmax(dim=1, keepdim=False) + grammar_pos_pred = grammar_pred[:B] + grammar_neg_pred = grammar_pred[B:] + # grammar_acc = (grammar_pred == grammar_labels).float().mean() + + out = { + 'clip_loss': clip_loss, + 'grammar_loss': grammar_loss, + 
'img_feat': img_feat, + 'text_feat': text_cont_feat, + 'neg_text_feat': neg_text_feat, + 'grammar_pos_pred': grammar_pos_pred, + 'grammar_neg_pred': grammar_neg_pred, + } + + return out + + def train_step_old(self, + images=None, text=None, + img_feat=None, text_feat=None, + neg_text=None, neg_text_feat=None, + # ref_text=None, ref_text_feat=None, ref_text_mask=None, + prompt="A photo depicts", + # return_loss=True, + **kwargs): + + if img_feat is None: + img_feat = self.image_extract(images) + img_feat = img_feat.view(-1, 512) + + B = img_feat.size(0) + + + + if text_feat is None: + text_feat = self.text_extract(text, prompt=prompt, proj_norm=False) + + text_cont_feat = self.clip_model.text_projection(text_feat) + text_cont_feat = text_cont_feat / text_cont_feat.norm(dim=-1, keepdim=True) + text_cont_feat = text_cont_feat.view(B, 512) + + # cosine similarity as logits + logit_scale = self.clip_model.logit_scale.exp() + logits_per_text = torch.matmul(text_cont_feat, img_feat.t()) * logit_scale + # logits_per_image = logits_per_text.T + + clip_loss = clip_loss_fn(logits_per_text) + + + # negative sampling + pos_text_feat = text_feat.view(B, 512) + neg_text_feat = self.text_extract(neg_text, prompt=prompt, proj_norm=False).view(B, 512) + + grammar_text_feat = torch.cat([pos_text_feat, neg_text_feat], dim=0) + + # 2B, 1 + grammar_text_logit = self.grammar_score_head(grammar_text_feat) + grammar_labels = torch.LongTensor([1] * B + [0] * B).to(grammar_text_logit.device).view(2 * B) + + grammar_loss = torch.nn.functional.cross_entropy(grammar_text_logit, grammar_labels) + + grammar_pred = grammar_text_logit.argmax(dim=1, keepdim=False) + grammar_pos_pred = grammar_pred[:B] + grammar_neg_pred = grammar_pred[B:] + # grammar_acc = (grammar_pred == grammar_labels).float().mean() + + out = { + 'clip_loss': clip_loss, + 'grammar_loss': grammar_loss, + 'img_feat': img_feat, + 'text_feat': text_cont_feat, + 'neg_text_feat': neg_text_feat, + 'grammar_pos_pred': grammar_pos_pred, 
def get_optimizer(optim, verbose=False):
    """Resolve an optimizer name to the optimizer to use.

    optim: one of 'rms', 'adam', 'adamw', 'adamax' or 'sgd'.
    verbose: when True, print which optimizer was selected.

    Returns the matching ``torch.optim`` class, except for 'adamw', which is
    returned as the string 'adamw'.

    Raises AssertionError for an unknown name. The exception is raised
    explicitly rather than through an ``assert`` statement so that the check
    also runs under ``python -O``.
    """
    # name -> (message printed when verbose, optimizer handed back)
    known = {
        'rms': ("Optimizer: Using RMSProp", torch.optim.RMSprop),
        'adam': ("Optimizer: Using Adam", torch.optim.Adam),
        'adamw': ("Optimizer: Using AdamW", 'adamw'),
        'adamax': ("Optimizer: Using Adamax", torch.optim.Adamax),
        'sgd': ("Optimizer: SGD", torch.optim.SGD),
    }

    try:
        message, optimizer = known[optim]
    except (KeyError, TypeError):
        # TypeError covers unhashable values, which can never be a known name.
        raise AssertionError(
            "Please add your optimizer %s in the list." % optim) from None

    if verbose:
        print(message)

    return optimizer
parser.add_argument('--position_embedding_type', type=str, default='absolute') + + # parser.add_argument('--encoder_transform', action='store_true') + + parser.add_argument('--max_text_length', type=int, default=40) + + # parser.add_argument('--image_size', type=int, default=224) + # parser.add_argument('--patch_size', type=int, default=32) + + # parser.add_argument('--decoder_num_layers', type=int, default=12) + + # Training + parser.add_argument('--batch_size', type=int, default=256) + parser.add_argument('--valid_batch_size', type=int, default=None) + + parser.add_argument('--optim', default='adamw') + + parser.add_argument('--warmup_ratio', type=float, default=0.05) + parser.add_argument('--weight_decay', type=float, default=0.01) + parser.add_argument('--clip_grad_norm', type=float, default=-1.0) + parser.add_argument('--gradient_accumulation_steps', type=int, default=1) + parser.add_argument('--lr', type=float, default=1e-4) + parser.add_argument('--adam_eps', type=float, default=1e-6) + parser.add_argument('--adam_beta1', type=float, default=0.9) + parser.add_argument('--adam_beta2', type=float, default=0.999) + + parser.add_argument('--epochs', type=int, default=20) + # parser.add_argument('--dropout', type=float, default=0.1) + + + # Inference + # parser.add_argument('--num_beams', type=int, default=1) + # parser.add_argument('--gen_max_length', type=int, default=20) + + parser.add_argument('--start_from', type=str, default=None) + + # Data + # parser.add_argument('--do_lower_case', type=str2bool, default=None) + + # parser.add_argument('--prefix', type=str, default=None) + + + # COCO Caption + # parser.add_argument('--no_prefix', action='store_true') + + parser.add_argument('--no_cls', action='store_true') + + parser.add_argument('--cfg', type=str, default=None) + parser.add_argument('--id', type=str, default=None) + + # Etc. 
class Config(object):
    """Plain attribute container built from keyword arguments."""

    def __init__(self, **kwargs):
        """Expose every keyword argument as an instance attribute."""
        for name in kwargs:
            setattr(self, name, kwargs[name])

    @property
    def config_str(self):
        """Pretty-printed view of all stored attributes."""
        return pprint.pformat(vars(self))

    def __repr__(self):
        """Return a header line followed by the pretty-printed attributes."""
        return '\n'.join(['Configurations', self.config_str])
__future__ import print_function + +import json +import h5py +from lmdbdict import lmdbdict +from lmdbdict.methods import DUMPS_FUNC, LOADS_FUNC +import os +import numpy as np +import numpy.random as npr +import random + +import torch +import torch.utils.data as data + +import multiprocessing +import six + +verbose = True +# import torch +# if torch.cuda.current_device() in [0, -1]: +if 'LOCAL_RANK' in os.environ and os.environ['LOCAL_RANK'] != '0': + verbose = False + +class HybridLoader: + """ + If db_path is a director, then use normal file loading + If lmdb, then load from lmdb + The loading method depend on extention. + + in_memory: if in_memory is True, we save all the features in memory + For individual np(y|z)s, we don't need to do that because the system will do this for us. + Should be useful for lmdb or h5. + (Copied this idea from vilbert) + """ + def __init__(self, db_path, ext, in_memory=False): + self.db_path = db_path + self.ext = ext + if self.ext == '.npy': + self.loader = lambda x: np.load(six.BytesIO(x)) + else: + self.loader = lambda x: np.load(six.BytesIO(x))['feat'] + if db_path.endswith('.lmdb'): + self.db_type = 'lmdb' + self.lmdb = lmdbdict(db_path, unsafe=True) + self.lmdb._key_dumps = DUMPS_FUNC['ascii'] + self.lmdb._value_loads = LOADS_FUNC['identity'] + elif db_path.endswith('.pth'): # Assume a key,value dictionary + self.db_type = 'pth' + self.feat_file = torch.load(db_path) + self.loader = lambda x: x + print('HybridLoader: ext is ignored') + elif db_path.endswith('h5'): + self.db_type = 'h5' + self.loader = lambda x: np.array(x).astype('float32') + else: + self.db_type = 'dir' + + self.in_memory = in_memory + if self.in_memory: + self.features = {} + + def get(self, key): + + if self.in_memory and key in self.features: + # We save f_input because we want to save the + # compressed bytes to save memory + f_input = self.features[key] + elif self.db_type == 'lmdb': + f_input = self.lmdb[key] + elif self.db_type == 'pth': + f_input = 
self.feat_file[key] + elif self.db_type == 'h5': + f_input = h5py.File(self.db_path, 'r')[key] + else: + f_input = open(os.path.join(self.db_path, key + self.ext), 'rb').read() + + if self.in_memory and key not in self.features: + self.features[key] = f_input + + # load image + feat = self.loader(f_input) + + return feat + +class CaptionDataset(data.Dataset): + + def get_vocab_size(self): + return self.vocab_size + + def get_vocab(self): + return self.ix_to_word + + def get_seq_length(self): + return self.seq_length + + def __init__(self, opt): + self.opt = opt + self.seq_per_img = opt.seq_per_img + + # feature related options + self.use_fc = getattr(opt, 'use_fc', True) + self.use_att = getattr(opt, 'use_att', True) + self.use_box = getattr(opt, 'use_box', 0) + self.norm_att_feat = getattr(opt, 'norm_att_feat', 0) + self.norm_box_feat = getattr(opt, 'norm_box_feat', 0) + + # load the json file which contains additional information about the dataset + if verbose: + print('DataLoader loading json file: ', opt.input_json) + self.info = json.load(open(self.opt.input_json)) + if 'ix_to_word' in self.info: + self.ix_to_word = self.info['ix_to_word'] + self.vocab_size = len(self.ix_to_word) + if verbose: + print('vocab size is ', self.vocab_size) + + # open the hdf5 file + if verbose: + print('DataLoader loading h5 file: ', opt.input_fc_dir, opt.input_att_dir, opt.input_box_dir, opt.input_label_h5) + """ + Setting input_label_h5 to none is used when only doing generation. + For example, when you need to test on coco test set. 
+ """ + if self.opt.input_label_h5 != 'none': + self.h5_label_file = h5py.File(self.opt.input_label_h5, 'r', driver='core') + # load in the sequence data + seq_size = self.h5_label_file['labels'].shape + self.label = self.h5_label_file['labels'][:] + self.seq_length = seq_size[1] + if verbose: + print('max sequence length in data is', self.seq_length) + # load the pointers in full to RAM (should be small enough) + self.label_start_ix = self.h5_label_file['label_start_ix'][:] + self.label_end_ix = self.h5_label_file['label_end_ix'][:] + else: + self.seq_length = 1 + + self.data_in_memory = getattr(opt, 'data_in_memory', False) + self.fc_loader = HybridLoader(self.opt.input_fc_dir, '.npy', in_memory=self.data_in_memory) + self.att_loader = HybridLoader(self.opt.input_att_dir, '.npz', in_memory=self.data_in_memory) + self.box_loader = HybridLoader(self.opt.input_box_dir, '.npy', in_memory=self.data_in_memory) + + self.use_clipscore = getattr(opt, 'use_clipscore', False) + if self.use_clipscore: + self.clipscore_loader = HybridLoader(self.opt.input_clipscore_vis_dir, '.npy', in_memory=self.data_in_memory) + + + self.num_images = len(self.info['images']) # self.label_start_ix.shape[0] + if verbose: + print('read %d image features' %(self.num_images)) + + # separate out indexes for each of the provided splits + self.split_ix = {'train': [], 'val': [], 'test': []} + for ix in range(len(self.info['images'])): + img = self.info['images'][ix] + if not 'split' in img: + self.split_ix['train'].append(ix) + self.split_ix['val'].append(ix) + self.split_ix['test'].append(ix) + elif img['split'] == 'train': + self.split_ix['train'].append(ix) + elif img['split'] == 'val': + self.split_ix['val'].append(ix) + elif img['split'] == 'test': + self.split_ix['test'].append(ix) + elif opt.train_only == 0: # restval + self.split_ix['train'].append(ix) + + if verbose: + print('assigned %d images to split train' %len(self.split_ix['train'])) + print('assigned %d images to split val' 
def get_captions(self, ix, seq_per_img):
    """Pick ``seq_per_img`` caption label rows for image ``ix``.

    The rows belonging to the image are located through the 1-based
    ``label_start_ix`` / ``label_end_ix`` pointers. If the image has fewer
    captions than requested, rows are drawn at random with replacement;
    otherwise a random window of ``seq_per_img`` consecutive rows is
    returned. Every row is cut to ``self.seq_length`` entries.
    """
    # Pointers are stored 1-based; convert to 0-based inclusive bounds.
    first = self.label_start_ix[ix] - 1
    last = self.label_end_ix[ix] - 1
    available = last - first + 1
    assert available > 0, 'an image does not have any label. this can be handled but right now isn\'t'

    width = self.seq_length

    if available >= seq_per_img:
        # Enough captions: take a contiguous window starting at a random row.
        start = random.randint(first, last - seq_per_img + 1)
        return self.label[start: start + seq_per_img, :width]

    # Too few captions: sample rows with replacement.
    picks = [random.randint(first, last) for _ in range(seq_per_img)]
    seq = np.zeros([seq_per_img, width], dtype='int')
    for row, src in enumerate(picks):
        seq[row, :] = self.label[src, :width]
    return seq
att_batch, label_batch, gts, infos = \ + # zip(*sorted(zip(fc_batch, att_batch, np.vsplit(label_batch, batch_size), gts, infos), key=lambda x: len(x[1]), reverse=True)) + if self.use_clipscore: + fc_batch, att_batch, label_batch, clip_vis_feat_batch, gts, infos = \ + zip(*sorted(zip(fc_batch, att_batch, label_batch, clip_vis_feat_batch, gts, infos), key=lambda x: 0, reverse=True)) + else: + fc_batch, att_batch, label_batch, gts, infos = \ + zip(*sorted(zip(fc_batch, att_batch, label_batch, gts, infos), key=lambda x: 0, reverse=True)) + data = {} + data['fc_feats'] = np.stack(fc_batch) + # merge att_feats + max_att_len = max([_.shape[0] for _ in att_batch]) + data['att_feats'] = np.zeros([len(att_batch), max_att_len, att_batch[0].shape[1]], dtype = 'float32') + for i in range(len(att_batch)): + data['att_feats'][i, :att_batch[i].shape[0]] = att_batch[i] + data['att_masks'] = np.zeros(data['att_feats'].shape[:2], dtype='float32') + for i in range(len(att_batch)): + data['att_masks'][i, :att_batch[i].shape[0]] = 1 + # set att_masks to None if attention features have same length + if data['att_masks'].sum() == data['att_masks'].size: + data['att_masks'] = None + + if self.use_clipscore: + data['clip_vis_feats'] = np.stack(clip_vis_feat_batch) + + data['labels'] = np.vstack(label_batch) + # generate mask + nonzeros = np.array(list(map(lambda x: (x != 0).sum()+2, data['labels']))) + mask_batch = np.zeros([data['labels'].shape[0], self.seq_length + 2], dtype = 'float32') + for ix, row in enumerate(mask_batch): + row[:nonzeros[ix]] = 1 + data['masks'] = mask_batch + data['labels'] = data['labels'].reshape(len(batch), seq_per_img, -1) + data['masks'] = data['masks'].reshape(len(batch), seq_per_img, -1) + + data['gts'] = gts # all ground truth captions of each images + data['infos'] = infos + + data = {k:torch.from_numpy(v) if type(v) is np.ndarray else v for k,v in data.items()} # Turn all ndarray to torch tensor + + return data + + def __getitem__(self, ix): + """This 
function returns a tuple that is further passed to collate_fn + """ + if self.use_att: + att_feat = self.att_loader.get(str(self.info['images'][ix]['id'])) + # Reshape to K x C + att_feat = att_feat.reshape(-1, att_feat.shape[-1]) + if self.norm_att_feat: + att_feat = att_feat / np.linalg.norm(att_feat, 2, 1, keepdims=True) + if self.use_box: + box_feat = self.box_loader.get(str(self.info['images'][ix]['id'])) + # devided by image width and height + x1,y1,x2,y2 = np.hsplit(box_feat, 4) + h,w = self.info['images'][ix]['height'], self.info['images'][ix]['width'] + box_feat = np.hstack((x1/w, y1/h, x2/w, y2/h, (x2-x1)*(y2-y1)/(w*h))) # question? x2-x1+1?? + if self.norm_box_feat: + box_feat = box_feat / np.linalg.norm(box_feat, 2, 1, keepdims=True) + att_feat = np.hstack([att_feat, box_feat]) + # sort the features by the size of boxes + att_feat = np.stack(sorted(att_feat, key=lambda x:x[-1], reverse=True)) + else: + att_feat = np.zeros((0,0), dtype='float32') + if self.use_fc: + try: + fc_feat = self.fc_loader.get(str(self.info['images'][ix]['id'])) + except: + # Use average of attention when there is no fc provided (For bottomup feature) + fc_feat = att_feat.mean(0) + else: + fc_feat = np.zeros((0), dtype='float32') + if hasattr(self, 'h5_label_file'): + seq = self.get_captions(ix, self.seq_per_img) + else: + seq = None + + if self.use_clipscore: + clip_vis_feat = self.clipscore_loader.get( + str(self.info['images'][ix]['id'])) + + return (fc_feat, + att_feat, seq, + ix, clip_vis_feat) + + return (fc_feat, + att_feat, seq, + ix) + + def __len__(self): + return len(self.info['images']) diff --git a/retrieval/text_utils.py b/retrieval/text_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..51f981054b41a945656f9e619c722e09de198bf7 --- /dev/null +++ b/retrieval/text_utils.py @@ -0,0 +1,74 @@ +import random + +def repeat(text, n_max_gram=3, n_max_repeat=3): + """repeat n-grams""" + tokens = text.split() + + n_gram = random.randint(1, 
import random


def _pick_ngram_span(tokens, n_max_gram):
    """Pick a random ``(start, length)`` n-gram span inside ``tokens``.

    The length is clamped to ``len(tokens)`` so short texts no longer make
    ``random.randint`` raise. ``tokens`` must be non-empty.
    """
    n_gram = min(random.randint(1, n_max_gram), len(tokens))
    start = random.randint(0, len(tokens) - n_gram)
    return start, n_gram


def repeat(text, n_max_gram=3, n_max_repeat=3):
    """Return ``text`` with one random n-gram re-inserted 1..n_max_repeat times.

    Text without any token is returned as an empty string.
    """
    tokens = text.split()
    if not tokens:
        return " ".join(tokens)

    start, n_gram = _pick_ngram_span(tokens, n_max_gram)
    repeated_tokens = tokens[start:start + n_gram]

    for _ in range(random.randint(1, n_max_repeat)):
        insert_idx = random.randint(0, len(tokens))
        tokens = tokens[:insert_idx] + repeated_tokens + tokens[insert_idx:]

    return " ".join(tokens)


def remove(text, n_max_gram=3):
    """Return ``text`` with one random n-gram (1..n_max_gram tokens) removed.

    Text without any token is returned as an empty string.
    """
    tokens = text.split()
    if not tokens:
        return " ".join(tokens)

    start, n_gram = _pick_ngram_span(tokens, n_max_gram)
    tokens = tokens[:start] + tokens[start + n_gram:]

    return " ".join(tokens)


def insert(text, vocab, n_max_tokens=3):
    """Return ``text`` with 1..n_max_tokens random ``vocab`` tokens inserted.

    Each token is inserted before a randomly chosen existing token; for an
    empty text it is inserted at position 0.
    """
    tokens = text.split()

    for _ in range(random.randint(1, n_max_tokens)):
        # max(..., 0) keeps the range valid when there is no token yet.
        insert_idx = random.randint(0, max(len(tokens) - 1, 0))
        tokens = tokens[:insert_idx] + [random.choice(vocab)] + tokens[insert_idx:]

    return " ".join(tokens)


def swap(text, vocab, n_max_tokens=3):
    """Return ``text`` with 1..n_max_tokens tokens replaced by ``vocab`` tokens.

    A replacement always differs from the token it replaces. A position is
    left untouched when ``vocab`` offers no different token, and text without
    any token is returned as an empty string.
    """
    tokens = text.split()
    if not tokens:
        return " ".join(tokens)

    for _ in range(random.randint(1, n_max_tokens)):
        swap_idx = random.randint(0, len(tokens) - 1)
        current = tokens[swap_idx]

        # Without this guard the rejection loop below would never terminate.
        if all(word == current for word in vocab):
            continue

        replacement = random.choice(vocab)
        while replacement == current:
            replacement = random.choice(vocab)
        tokens[swap_idx] = replacement

    return " ".join(tokens)


def shuffle(text):
    """Return ``text`` with its whitespace-separated tokens in random order."""
    tokens = text.split()
    random.shuffle(tokens)
    return " ".join(tokens)
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

import numpy as np

import time
import os
from collections import defaultdict

from clip_model import CLIPScore
from caption_data import COCORetrievalDataset

import pytorch_lightning as pl

import detectron2.utils.comm as d2comm
from detectron2.utils.env import seed_all_rng
seed_all_rng(1234)


class LitModel(pl.LightningModule):
    """Lightning module that fine-tunes a ``CLIPScore`` model for retrieval.

    ``opt`` is the parsed argument namespace. It is stored as both
    ``self.opt`` and ``self.args`` because both names are used below.
    """

    def __init__(self, opt):
        super().__init__()
        self.opt = opt
        # Previously read the module-level global `args`; use the argument so
        # the class does not depend on script-level state.
        self.args = opt

        self.model = CLIPScore(use_grammar=opt.use_grammar, joint_out=opt.joint_out)

        # The vision tower and its projection are kept frozen.
        for p in self.model.clip_model.vision_model.parameters():
            p.requires_grad = False
        for p in self.model.clip_model.visual_projection.parameters():
            p.requires_grad = False

    @staticmethod
    def _is_verbose():
        """Return True unless ``LOCAL_RANK`` is set to something other than '0'."""
        return os.environ.get('LOCAL_RANK', '0') == '0'

    def forward(self, *args, **kwargs):
        """
        I hate this design. Never pretend it as a nn.Module
        """
        raise NotImplementedError

    def _build_loader(self, split, mode, batch_size, shuffle):
        """Create the ``COCORetrievalDataset`` for ``split`` and wrap it in a DataLoader."""
        dataset = COCORetrievalDataset(
            split=split, mode=mode,
            args=self.opt,
            verbose=self._is_verbose()
        )
        return torch.utils.data.DataLoader(
            dataset=dataset,
            batch_size=batch_size,
            shuffle=shuffle,
            num_workers=4,
            drop_last=False,
            collate_fn=dataset.collate_fn
        )

    def train_dataloader(self):
        """Return the shuffled loader over the ``karpathy_train`` split."""
        return self._build_loader(
            'karpathy_train', 'train', self.opt.batch_size, shuffle=True)

    def val_dataloader(self, split='karpathy_val'):
        """Return the ordered loader over ``split`` (validation by default)."""
        return self._build_loader(
            split, 'val', self.opt.valid_batch_size, shuffle=False)

    def test_dataloader(self):
        """Return the ordered loader over the ``karpathy_test`` split."""
        return self.val_dataloader('karpathy_test')

    def _run_model(self, batch):
        """Call ``self.model.train_step`` with the fields of ``batch``."""
        return self.model.train_step(
            img_feat=batch['img_feats'],
            text=batch['text'],
            neg_text=batch['neg_text'],
        )

    def training_step(self, data, batch_idx):
        """Compute and log the training loss for one batch.

        The loss is the CLIP loss alone when ``opt.joint_out`` is set,
        otherwise the sum of the CLIP loss and the grammar loss.
        """
        self.model.train()
        model_out = self._run_model(data)

        clip_loss = model_out['clip_loss']
        grammar_loss = None
        if self.opt.joint_out:
            loss = clip_loss
        else:
            grammar_loss = model_out['grammar_loss']
            loss = clip_loss + grammar_loss

        # Duration of the last batch fetch, as recorded by the profiler.
        data_time = torch.tensor(
            self.trainer.profiler.recorded_durations["get_train_batch"][-1])

        self.log('train/loss', loss.detach())
        self.log('train/clip_loss', clip_loss.detach(), prog_bar=True)
        if grammar_loss is not None:
            self.log('train/grammar_loss', grammar_loss.detach(), prog_bar=True)
        self.log('train/data_time', data_time.detach(), prog_bar=True)

        return loss

    def validation_step(self, data, batch_idx):
        """Run the model without gradients and return detached outputs.

        Returns a dict with ``loss``, ``clip_loss``, ``img_feat`` and
        ``text_feat``; when ``opt.joint_out`` is off it also carries
        ``grammar_loss``, ``grammar_pos_pred`` and ``grammar_neg_pred``.
        """
        self.model.eval()
        with torch.no_grad():
            model_out = self._run_model(data)

        clip_loss = model_out['clip_loss']
        output = {
            'clip_loss': clip_loss.detach(),
            'img_feat': model_out['img_feat'].detach(),
            'text_feat': model_out['text_feat'].detach(),
        }

        if self.opt.joint_out:
            loss = clip_loss
        else:
            grammar_loss = model_out['grammar_loss']
            loss = clip_loss + grammar_loss
            output['grammar_loss'] = grammar_loss.detach()
            output['grammar_pos_pred'] = model_out['grammar_pos_pred'].detach()
            output['grammar_neg_pred'] = model_out['grammar_neg_pred'].detach()

        output['loss'] = loss.detach()
        return output

    def test_step(self, *args, **kwargs):
        """Testing reuses the validation step."""
        return self.validation_step(*args, **kwargs)

    @staticmethod
    def _recall_at_k(logits, prefix, ks=(1, 5, 10)):
        """Return ``{f'{prefix}_recall_{k}': percent}``, assuming row i matches column i."""
        recalls = {}
        for k in ks:
            topk = logits.topk(k, dim=1).indices
            n_query = len(topk)
            labels = torch.arange(0, n_query).view(-1, 1)
            n_retrieved = ((topk == labels).sum(dim=1) > 0).sum()
            recall_k = n_retrieved / n_query * 100
            recalls[f'{prefix}_recall_{k}'] = recall_k.item()
            print(f'R@{k}: {recall_k.item():.2f}%')
        return recalls

    @staticmethod
    def _grammar_metrics(grammar_pos_pred, grammar_neg_pred):
        """Return precision/recall/accuracy/F1 (in percent) of the grammar head.

        ``grammar_pos_pred`` holds predictions for the positive texts and
        ``grammar_neg_pred`` for the negative ones; a prediction of 1 means
        "positive". A positive text predicted 0 is therefore a false negative
        and a negative text predicted 1 a false positive (the two were
        swapped before, which swapped precision and recall).
        """
        def pct(num, den):
            # Guard against empty denominators (e.g. nothing predicted positive).
            return num / den * 100 if den else 0.0

        TP = (grammar_pos_pred == 1).sum().item()
        FN = (grammar_pos_pred == 0).sum().item()
        FP = (grammar_neg_pred == 1).sum().item()
        TN = (grammar_neg_pred == 0).sum().item()
        print('Grammar check')
        print(f'TP: {TP} FP: {FP} FN: {FN} TN: {TN}')

        precision = pct(TP, TP + FP)
        recall = pct(TP, TP + FN)
        accuracy = pct(TP + TN, TP + FP + FN + TN)
        f1 = (2 * precision * recall / (precision + recall)
              if precision + recall else 0.0)
        print(f'Precision: {precision:.2f}%')
        print(f'Recall: {recall:.2f}%')
        print(f'Accuracy: {accuracy:.2f}%')
        print(f'F1: {f1:.2f}%')
        print('Total: {}'.format(len(grammar_pos_pred)))

        return {
            'grammar_precision': precision,
            'grammar_recall': recall,
            'grammar_accuracy': accuracy,
            'grammar_f1': f1,
        }

    def _summarize(self, outputs):
        """Aggregate per-batch outputs into the metric dict (main process only)."""
        joint_out = self.opt.joint_out
        n_out = len(outputs)

        val_loss_mean = sum(o['loss'].cpu() for o in outputs) / n_out
        val_clip_loss_mean = sum(o['clip_loss'].cpu() for o in outputs) / n_out
        print('loss', val_loss_mean.item())
        print('clip_loss', val_clip_loss_mean.item())
        if not joint_out:
            val_grammar_loss_mean = sum(o['grammar_loss'].cpu() for o in outputs) / n_out
            print('grammar_loss', val_grammar_loss_mean.item())

        logit_scale = self.model.clip_model.logit_scale.exp().cpu()

        text_feats = torch.cat([o['text_feat'].cpu() for o in outputs], dim=0)
        img_feats = torch.cat([o['img_feat'].cpu() for o in outputs], dim=0)

        # Sanity check kept from the original: 5000 pairs of 512-d features.
        assert text_feats.size() == (5000, 512), text_feats.size()
        assert img_feats.size() == (5000, 512), img_feats.size()

        logits_per_text = torch.matmul(text_feats, img_feats.t()) * logit_scale
        logits_per_image = logits_per_text.T

        out = {}
        print('Text-to-Image retrieval')
        out.update(self._recall_at_k(logits_per_text, 'text_to_image'))
        print('Image-to-Text retrieval')
        out.update(self._recall_at_k(logits_per_image, 'image_to_text'))

        out.update({
            'loss': val_loss_mean.item(),
            'clip_loss': val_clip_loss_mean.item()
        })

        if not joint_out:
            grammar_pos_pred = torch.cat(
                [o['grammar_pos_pred'].cpu() for o in outputs], dim=0)
            grammar_neg_pred = torch.cat(
                [o['grammar_neg_pred'].cpu() for o in outputs], dim=0)
            out['grammar_loss'] = val_grammar_loss_mean
            out.update(self._grammar_metrics(grammar_pos_pred, grammar_neg_pred))

        return out

    def validation_epoch_end(self, outputs, split='val'):
        """Gather outputs from all ranks, compute metrics on the main process, log them."""
        outputs = d2comm.gather(outputs)
        # master node
        if d2comm.is_main_process():
            assert self.trainer.node_rank == 0 and self.trainer.local_rank == 0
            # Flatten the per-rank lists into one list of step outputs.
            outputs = [o for rank_outputs in outputs for o in rank_outputs]
            out = self._summarize(outputs)
        else:
            out = {}

        out = d2comm.all_gather(out)[0]  # Only the one from master node
        assert len(out) > 0  # make sure the head has index 0

        # must all be tensors
        out = {k: v if torch.is_tensor(v) else torch.tensor(v)
               for k, v in out.items()}

        for k, v in out.items():
            self.log(f'{split}/{k}', v)

    def test_epoch_end(self, outputs):
        """Same aggregation as validation, logged under the ``test/`` prefix."""
        self.validation_epoch_end(outputs, 'test')

    def configure_optimizers(self):
        """Build the AdamW optimizer over the trainable parameters.

        Biases and ``LayerNorm.weight`` parameters get no weight decay.

        Raises:
            ValueError: if ``args.optim`` is not ``'adamw'`` (previously the
                ``torch.optim`` module itself was returned as the optimizer).
        """
        if self.args.optim != 'adamw':
            raise ValueError(
                "unsupported optimizer %r; only 'adamw' is implemented"
                % (self.args.optim,))

        no_decay = ["bias", "LayerNorm.weight"]
        decay_params, no_decay_params = [], []
        for name, param in self.model.named_parameters():
            if not param.requires_grad:
                continue  # frozen parameters are not optimized
            if any(nd in name for nd in no_decay):
                no_decay_params.append(param)
            else:
                decay_params.append(param)

        optimizer_grouped_parameters = [
            {"params": decay_params, "weight_decay": self.args.weight_decay},
            {"params": no_decay_params, "weight_decay": 0.0},
        ]

        from transformers.optimization import AdamW
        optimizer = AdamW(optimizer_grouped_parameters,
                          lr=self.args.lr, eps=self.args.adam_eps)

        # No learning-rate scheduler is configured.
        return [optimizer], []

    def optimizer_step(self, epoch, batch_idx, optimizer,
                       optimizer_idx, *args, **kwargs):
        """Delegate to Lightning's default optimizer step (hook kept for warm-up logic)."""
        super().optimizer_step(epoch, batch_idx, optimizer,
                               optimizer_idx, *args, **kwargs)

    def state_dict(self):
        """Return the state dict of the wrapped model only (no ``_vocab``/``_opt`` keys)."""
        state_dict = self.model.state_dict()
        assert '_vocab' not in state_dict and '_opt' not in state_dict, 'Just in case'
        return state_dict

    def load_state_dict(self, state_dict=None, strict=True):
        """Load ``state_dict`` into the wrapped model."""
        self.model.load_state_dict(state_dict, strict)


class OnEpochStartCallback(pl.Callback):
    """Placeholder epoch-start hook; the scheduling logic it once held is disabled."""

    def on_epoch_start(self, trainer, pl_module):
        # Nothing to update: learning-rate decay, scheduled sampling and
        # self-critical switches are not used by this training script.
        pass
class ModelCheckpoint(pl.callbacks.ModelCheckpoint):
    """Checkpoint callback that also saves the model on Ctrl-C."""

    def on_keyboard_interrupt(self, trainer, pl_module):
        # Save model when keyboard interrupt
        filepath = os.path.join(self.dirpath, self.prefix + 'interrupt.ckpt')
        self._save_model(filepath)


from param import parse_args

# `args` and `opt` are the same namespace; both names are read elsewhere in
# this file.
args = parse_args()
opt = args

checkpoint_callback = ModelCheckpoint(
    filepath=opt.checkpoint_dir + '{epoch:02d}',
    save_last=True,
    save_top_k=1,
    verbose=True,
    monitor='val/loss',
    mode='min',
    prefix=opt.id,
)

# Only the process with LOCAL_RANK unset or '0' prints and talks to wandb.
verbose = os.environ.get('LOCAL_RANK', '0') == '0'

# Lightning defines batch size as batch size per gpu.
n_gpus = torch.cuda.device_count()
# On a CPU-only host device_count() is 0, which used to raise
# ZeroDivisionError here; treat it as a single device.
n_devices = max(n_gpus, 1)
if opt.batch_size % n_devices != 0:
    raise ValueError(
        'batch_size (%d) must be divisible by the number of devices (%d)'
        % (opt.batch_size, n_devices))
opt.batch_size = opt.batch_size // n_devices
opt.valid_batch_size = opt.valid_batch_size // n_devices

# Resume from the last checkpoint if it exists.
resume_from = None
if opt.start_from is not None:
    last_ckpt = os.path.join(opt.start_from, f'{opt.id}-last.ckpt')
    if os.path.isfile(last_ckpt):
        resume_from = last_ckpt
        if verbose:
            print('resume from', resume_from)

from pytorch_lightning.loggers import WandbLogger
wandb_logger = WandbLogger(
    project='CLIP-Finetune-COCO',
    name=opt.id,
)

if verbose:
    wandb_logger.experiment.config.update(opt)
    import wandb
    # Upload the python sources and yaml configs next to the run.
    wandb.save(glob_str="*.py", base_path='./')
    wandb.save(glob_str="**/*.yaml", base_path='./')


lit = LitModel(opt)
# warning grad_clip_mode is ignored.
trainer = pl.Trainer(
    callbacks=[
        OnEpochStartCallback(),
        pl.callbacks.LearningRateMonitor()
    ],
    default_root_dir=opt.checkpoint_dir,
    resume_from_checkpoint=resume_from,

    distributed_backend='ddp',
    gpus=n_gpus,

    check_val_every_n_epoch=1,
    max_epochs=opt.epochs,
    gradient_clip_val=opt.clip_grad_norm,

    checkpoint_callback=checkpoint_callback,
    log_gpu_memory='min_max',
    log_every_n_steps=opt.losses_log_every,
    profiler=True,
    flush_logs_every_n_steps=10,
    num_sanity_val_steps=0,
    precision=opt.precision,
    logger=wandb_logger
)

# EVALUATE=1 runs the test loop only; otherwise train.
if os.getenv('EVALUATE', '0') == '1':
    trainer.test(lit)
else:
    trainer.fit(lit)
lowercase, etc.), creates a special UNK token, and encodes everything to arrays + +Output: a json file and an hdf5 file +The hdf5 file contains several fields: +/labels is (M,max_length) uint32 array of encoded labels, zero padded +/label_start_ix and /label_end_ix are (N,) uint32 arrays of pointers to the + first and last indices (in range 1..M) of labels for each image +/label_length stores the length of the sequence for each of the M sequences + +The json file has a dict that contains: +- an 'ix_to_word' field storing the vocab in form {ix:'word'}, where ix is 1-indexed +- an 'images' field that is a list holding auxiliary information for each image, + such as in particular the 'split' it was assigned to. +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os +import json +import argparse +from random import shuffle, seed +import string +# non-standard dependencies: +import h5py +import numpy as np +import torch +import torchvision.models as models +import skimage.io +from PIL import Image + +import codecs +import tempfile +from subword_nmt import learn_bpe, apply_bpe + +# python scripts/build_bpe_subword_nmt.py --input_json data/dataset_coco.json --output_json data/cocotalkbpe.json --output_h5 data/cocotalkbpe + +def build_vocab(imgs, params): + # count up the number of words + captions = [] + for img in imgs: + for sent in img['sentences']: + captions.append(' '.join(sent['tokens'])) + captions='\n'.join(captions) + all_captions = tempfile.NamedTemporaryFile(delete=False) + all_captions.close() + with open(all_captions.name, 'w') as txt_file: + txt_file.write(captions) + + # + codecs_output = tempfile.NamedTemporaryFile(delete=False) + codecs_output.close() + with codecs.open(codecs_output.name, 'w', encoding='UTF-8') as output: + learn_bpe.learn_bpe(codecs.open(all_captions.name, encoding='UTF-8'), output, params['symbol_count']) + + with codecs.open(codecs_output.name, 
encoding='UTF-8') as codes: + bpe = apply_bpe.BPE(codes) + + tmp = tempfile.NamedTemporaryFile(delete=False) + tmp.close() + + tmpout = codecs.open(tmp.name, 'w', encoding='UTF-8') + + for _, img in enumerate(imgs): + img['final_captions'] = [] + for sent in img['sentences']: + txt = ' '.join(sent['tokens']) + txt = bpe.segment(txt).strip() + img['final_captions'].append(txt.split(' ')) + tmpout.write(txt) + tmpout.write('\n') + if _ < 20: + print(txt) + + tmpout.close() + tmpin = codecs.open(tmp.name, encoding='UTF-8') + + vocab = learn_bpe.get_vocabulary(tmpin) + vocab = sorted(vocab.keys(), key=lambda x: vocab[x], reverse=True) + + # Always insert UNK + print('inserting the special UNK token') + vocab.append('UNK') + + print('Vocab size:', len(vocab)) + + os.remove(all_captions.name) + with open(codecs_output.name, 'r') as codes: + bpe = codes.read() + os.remove(codecs_output.name) + os.remove(tmp.name) + + return vocab, bpe + +def encode_captions(imgs, params, wtoi): + """ + encode all captions into one large array, which will be 1-indexed. + also produces label_start_ix and label_end_ix which store 1-indexed + and inclusive (Lua-style) pointers to the first and last caption for + each image in the dataset. 
+ """ + + max_length = params['max_length'] + N = len(imgs) + M = sum(len(img['final_captions']) for img in imgs) # total number of captions + + label_arrays = [] + label_start_ix = np.zeros(N, dtype='uint32') # note: these will be one-indexed + label_end_ix = np.zeros(N, dtype='uint32') + label_length = np.zeros(M, dtype='uint32') + caption_counter = 0 + counter = 1 + for i,img in enumerate(imgs): + n = len(img['final_captions']) + assert n > 0, 'error: some image has no captions' + + Li = np.zeros((n, max_length), dtype='uint32') + for j,s in enumerate(img['final_captions']): + label_length[caption_counter] = min(max_length, len(s)) # record the length of this sequence + caption_counter += 1 + for k,w in enumerate(s): + if k < max_length: + Li[j,k] = wtoi[w] + + # note: word indices are 1-indexed, and captions are padded with zeros + label_arrays.append(Li) + label_start_ix[i] = counter + label_end_ix[i] = counter + n - 1 + + counter += n + + L = np.concatenate(label_arrays, axis=0) # put all the labels together + assert L.shape[0] == M, 'lengths don\'t match? that\'s weird' + assert np.all(label_length > 0), 'error: some caption had no words?' 
+ + print('encoded captions to array of size ', L.shape) + return L, label_start_ix, label_end_ix, label_length + +def main(params): + + imgs = json.load(open(params['input_json'], 'r')) + imgs = imgs['images'] + + seed(123) # make reproducible + + # create the vocab + vocab, bpe = build_vocab(imgs, params) + itow = {i+1:w for i,w in enumerate(vocab)} # a 1-indexed vocab translation table + wtoi = {w:i+1 for i,w in enumerate(vocab)} # inverse table + + # encode captions in large arrays, ready to ship to hdf5 file + L, label_start_ix, label_end_ix, label_length = encode_captions(imgs, params, wtoi) + + # create output h5 file + N = len(imgs) + f_lb = h5py.File(params['output_h5']+'_label.h5', "w") + f_lb.create_dataset("labels", dtype='uint32', data=L) + f_lb.create_dataset("label_start_ix", dtype='uint32', data=label_start_ix) + f_lb.create_dataset("label_end_ix", dtype='uint32', data=label_end_ix) + f_lb.create_dataset("label_length", dtype='uint32', data=label_length) + f_lb.close() + + # create output json file + out = {} + out['ix_to_word'] = itow # encode the (1-indexed) vocab + out['images'] = [] + out['bpe'] = bpe + for i,img in enumerate(imgs): + + jimg = {} + jimg['split'] = img['split'] + if 'filename' in img: jimg['file_path'] = os.path.join(img['filepath'], img['filename']) # copy it over, might need + if 'cocoid' in img: jimg['id'] = img['cocoid'] # copy over & mantain an id, if present (e.g. 
coco ids, useful) + + if params['images_root'] != '': + with Image.open(os.path.join(params['images_root'], img['filepath'], img['filename'])) as _img: + jimg['width'], jimg['height'] = _img.size + + out['images'].append(jimg) + + json.dump(out, open(params['output_json'], 'w')) + print('wrote ', params['output_json']) + +if __name__ == "__main__": + + parser = argparse.ArgumentParser() + + # input json + parser.add_argument('--input_json', required=True, help='input json file to process into hdf5') + parser.add_argument('--output_json', default='data.json', help='output json file') + parser.add_argument('--output_h5', default='data', help='output h5 file') + parser.add_argument('--images_root', default='', help='root location in which images are stored, to be prepended to file_path in input json') + + # options + parser.add_argument('--max_length', default=16, type=int, help='max length of a caption, in number of words. captions longer than this get clipped.') + parser.add_argument('--symbol_count', default=10000, type=int, help='only words that occur more than this number of times will be put in vocab') + + args = parser.parse_args() + params = vars(args) # convert to ordinary dict + print('parsed input parameters:') + print(json.dumps(params, indent = 2)) + main(params) + + diff --git a/scripts/clip_prepro_feats.py b/scripts/clip_prepro_feats.py new file mode 100644 index 0000000000000000000000000000000000000000..b7a45c829fa5c19e36509170135835c6d6bc8d67 --- /dev/null +++ b/scripts/clip_prepro_feats.py @@ -0,0 +1,170 @@ +""" +Preprocess a raw json dataset into features files for use in data_loader.py + +Input: json file that has the form +[{ file_path: 'path/img.jpg', captions: ['a caption', ...] }, ...] +example element in this list would look like +{'captions': [u'A man with a red helmet on a small moped on a dirt road. 
', u'Man riding a motor bike on a dirt road on the countryside.', u'A man riding on the back of a motorcycle.', u'A dirt path with a young person on a motor bike rests to the foreground of a verdant area with a bridge and a background of cloud-wreathed mountains. ', u'A man in a red shirt and a red hat is on a motorcycle on a hill side.'], 'file_path': u'val2014/COCO_val2014_000000391895.jpg', 'id': 391895} + +This script reads this json, does some basic preprocessing on the captions +(e.g. lowercase, etc.), creates a special UNK token, and encodes everything to arrays + +Output: two folders of features +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os +import json +import argparse +from random import shuffle, seed +import string +# non-standard dependencies: +import h5py +from six.moves import cPickle +import numpy as np +import torch +import torchvision.models as models +import skimage.io + +from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize +from PIL import Image +from torch import nn + +preprocess = Compose([ + Resize((448, 448), interpolation=Image.BICUBIC), + CenterCrop((448, 448)), + ToTensor() +]) + + +from clip.clip import load +from timm.models.vision_transformer import resize_pos_embed +import timm + +from captioning.utils.resnet_utils import myResnet +import captioning.utils.resnet as resnet + +from tqdm import tqdm + + +def main(params): + if params["model_type"] != 'vit_base_patch32_224_in21k': + model, transform = load(params["model_type"], jit=False) + else: + model = timm.create_model(params["model_type"], pretrained=True) + model = model.cuda() + + if params["model_type"] != 'vit_base_patch32_224_in21k': + save_model_type = params["model_type"].split("-")[0] + mean = torch.Tensor([0.48145466, 0.4578275, 0.40821073]).to("cuda").reshape(3, 1, 1) + std = torch.Tensor([0.26862954, 0.26130258, 0.27577711]).to("cuda").reshape(3, 1, 1) + + 
if "RN" in params["model_type"]: + num_patches = 196 #600 * 1000 // 32 // 32 + pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, model.visual.attnpool.positional_embedding.shape[-1], device='cuda'),) + pos_embed.weight = resize_pos_embed(model.visual.attnpool.positional_embedding.unsqueeze(0), pos_embed) + model.visual.attnpool.positional_embedding = pos_embed + + else: + save_model_type = 'vit_base' + mean = torch.Tensor([0.5, 0.5, 0.5]).to("cuda").reshape(3, 1, 1) + std = torch.Tensor([0.5, 0.5, 0.5]).to("cuda").reshape(3, 1, 1) + + num_patches = 196 #600 * 1000 // 32 // 32 + pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, 768, device='cuda'),) + pos_embed.weight = resize_pos_embed(model.pos_embed, pos_embed) + model.pos_embed = pos_embed + + if params["model_type"] == "ViT-B/32": + num_patches = 196 #600 * 1000 // 32 // 32 + pos_embed = nn.Parameter(torch.zeros(num_patches + 1, 768, device='cuda'),) + pos_embed.weight = resize_pos_embed(model.visual.positional_embedding.unsqueeze(0), pos_embed.unsqueeze(0)) + model.visual.positional_embedding = pos_embed + imgs = json.load(open(params['input_json'], 'r')) + + imgs = imgs['images'] + + if args.n_jobs > 1: + print('Total imgs:', len(imgs)) + print('Using {} jobs'.format(args.n_jobs)) + print('job id:', args.job_id) + imgs = imgs[args.job_id::args.n_jobs] + + N = len(imgs) + + seed(123) # make reproducible + + dir_fc = params['output_dir']+'_clip_'+save_model_type+'_fc' + dir_att = params['output_dir']+'_clip_'+save_model_type+'_att' + if not os.path.isdir(dir_fc): + os.mkdir(dir_fc) + if not os.path.isdir(dir_att): + os.mkdir(dir_att) + + for i,img in enumerate(tqdm(imgs)): + # load the image + with torch.no_grad(): + + image = preprocess(Image.open(os.path.join(params['images_root'], img['filepath'], img['filename']) ).convert("RGB")) + image = torch.tensor(np.stack([image])).cuda() + image -= mean + image /= std + if "RN" in params["model_type"]: + tmp_att, tmp_fc = model.encode_image(image) + 
tmp_att = tmp_att[0].permute(1, 2, 0) + tmp_fc = tmp_fc[0] + elif params["model_type"] == 'vit_base_patch32_224_in21k': + x = model(image) + tmp_fc = x[0, 0, :] + tmp_att = x[0, 1:, :].reshape( 14, 14, 768 ) + else: + x = model.visual.conv1(image.half()) # shape = [*, width, grid, grid] + x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2] + x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width] + x = torch.cat([model.visual.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), x], dim=1) # shape = [*, grid ** 2 + 1, width] + x = x + model.visual.positional_embedding.to(x.dtype)[:x.shape[1], :] + x = model.visual.ln_pre(x) + + x = x.permute(1, 0, 2) # NLD -> LND + + for layer_idx, layer in enumerate(model.visual.transformer.resblocks): + x = layer(x) + + x = x.permute(1, 0, 2) + tmp_fc = x[0, 0, :] + tmp_att = x[0, 1:, :].reshape( 14, 14, 768 ) + + np.save(os.path.join(dir_fc, str(img['cocoid'])), tmp_fc.data.cpu().float().numpy()) + np.savez_compressed(os.path.join(dir_att, str(img['cocoid'])), feat=tmp_att.data.cpu().float().numpy()) + + + # if i % 1000 == 0: + # print('processing %d/%d (%.2f%% done)' % (i, N, i*100.0/N)) + print('wrote ', dir_fc, dir_att) + +if __name__ == "__main__": + + parser = argparse.ArgumentParser() + + # input json + parser.add_argument('--input_json', required=True, help='input json file to process into hdf5') + parser.add_argument('--output_dir', default='data', help='output h5 file') + + # options + parser.add_argument('--images_root', default='', help='root location in which images are stored, to be prepended to file_path in input json') + parser.add_argument('--att_size', default=14, type=int, help='14x14 or 7x7') + parser.add_argument('--model_type', default='RN50', type=str, help='RN50, RN101, RN50x4, ViT-B/32, vit_base_patch32_224_in21k') + + parser.add_argument('--n_jobs', default=-1, type=int, help='number of jobs to run in parallel') + 
parser.add_argument('--job_id', default=0, type=int, help='job id') + parser.add_argument('--batch_size', default=1, type=int, help='batch size') + + + args = parser.parse_args() + params = vars(args) # convert to ordinary dict + print('parsed input parameters:') + print(json.dumps(params, indent = 2)) + main(params) diff --git a/scripts/clipscore_prepro_feats.py b/scripts/clipscore_prepro_feats.py new file mode 100644 index 0000000000000000000000000000000000000000..72df6a02e55c3828dbc043fa272808acf1ee9f7e --- /dev/null +++ b/scripts/clipscore_prepro_feats.py @@ -0,0 +1,162 @@ +""" +Preprocess a raw json dataset into features files for use in data_loader.py + +Input: json file that has the form +[{ file_path: 'path/img.jpg', captions: ['a caption', ...] }, ...] +example element in this list would look like +{'captions': [u'A man with a red helmet on a small moped on a dirt road. ', u'Man riding a motor bike on a dirt road on the countryside.', u'A man riding on the back of a motorcycle.', u'A dirt path with a young person on a motor bike rests to the foreground of a verdant area with a bridge and a background of cloud-wreathed mountains. ', u'A man in a red shirt and a red hat is on a motorcycle on a hill side.'], 'file_path': u'val2014/COCO_val2014_000000391895.jpg', 'id': 391895} + +This script reads this json, does some basic preprocessing on the captions +(e.g. 
lowercase, etc.), creates a special UNK token, and encodes everything to arrays + +Output: two folders of features +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os +import json +import argparse +from random import shuffle, seed +import string +# non-standard dependencies: +import h5py +from six.moves import cPickle +import numpy as np +import torch +import torchvision.models as models +import skimage.io + +from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize +from PIL import Image +from torch import nn + +# preprocess = Compose([ +# Resize((448, 448), interpolation=Image.BICUBIC), +# CenterCrop((448, 448)), +# ToTensor() +# ]) + + +# from clip.clip import load +# from timm.models.vision_transformer import resize_pos_embed +# import timm + +# from captioning.utils.resnet_utils import myResnet +# import captioning.utils.resnet as resnet + +from captioning.utils.clipscore import CLIPScore + +from tqdm import tqdm + + + +def main(params): + + clipscore_model = CLIPScore() + clipscore_model.to('cuda') + + imgs = json.load(open(params['input_json'], 'r')) + imgs = imgs['images'] + + if args.n_jobs > 1: + print('Total imgs:', len(imgs)) + print('Using {} jobs'.format(args.n_jobs)) + print('job id:', args.job_id) + imgs = imgs[args.job_id::args.n_jobs] + + N = len(imgs) + + seed(123) # make reproducible + + # dir_fc = params['output_dir']+'_clip_'+save_model_type+'_fc' + # dir_att = params['output_dir']+'_clip_'+save_model_type+'_att' + + vis_dir_fc = params['output_dir']+'_clipscore_vis' + if not os.path.isdir(vis_dir_fc): + os.mkdir(vis_dir_fc) + + # text_dir_fc = params['output_dir']+'_clipscore_text' + # if not os.path.isdir(text_dir_fc): + # os.mkdir(text_dir_fc) + + # if not os.path.isdir(dir_att): + # os.mkdir(dir_att) + + for i, img in enumerate(tqdm(imgs)): + # load the image + + img_path = os.path.join(params['images_root'], img['filepath'], 
img['filename']) + img_feat = clipscore_model.image_extract(img_path) + img_feat = img_feat.view(512) + + # for d in img['sentences']: + # text = d['raw'].strip() + # text_feat = clipscore_model.text_extract(text) + + + # with torch.no_grad(): + + # image = preprocess(Image.open(os.path.join(params['images_root'], img['filepath'], img['filename']) ).convert("RGB")) + # image = torch.tensor(np.stack([image])).cuda() + # image -= mean + # image /= std + # if "RN" in params["model_type"]: + # tmp_att, tmp_fc = model.encode_image(image) + # tmp_att = tmp_att[0].permute(1, 2, 0) + # tmp_fc = tmp_fc[0] + # elif params["model_type"] == 'vit_base_patch32_224_in21k': + # x = model(image) + # tmp_fc = x[0, 0, :] + # tmp_att = x[0, 1:, :].reshape( 14, 14, 768 ) + # else: + # x = model.visual.conv1(image.half()) # shape = [*, width, grid, grid] + # x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2] + # x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width] + # x = torch.cat([model.visual.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), x], dim=1) # shape = [*, grid ** 2 + 1, width] + # x = x + model.visual.positional_embedding.to(x.dtype)[:x.shape[1], :] + # x = model.visual.ln_pre(x) + + # x = x.permute(1, 0, 2) # NLD -> LND + + # for layer_idx, layer in enumerate(model.visual.transformer.resblocks): + # x = layer(x) + + # x = x.permute(1, 0, 2) + # tmp_fc = x[0, 0, :] + # tmp_att = x[0, 1:, :].reshape( 14, 14, 768 ) + + np.save(os.path.join(vis_dir_fc, str(img['cocoid'])), img_feat.data.cpu().float().numpy()) + # np.save(os.path.join(text_dir_fc, str(img['cocoid'])), tmp_fc.data.cpu().float().numpy()) + + + # np.savez_compressed(os.path.join(dir_att, str(img['cocoid'])), feat=tmp_att.data.cpu().float().numpy()) + + if i % 1000 == 0: + print('processing %d/%d (%.2f%% done)' % (i, N, i*100.0/N)) + print('wrote ', vis_dir_fc) + +if __name__ == "__main__": + + parser = argparse.ArgumentParser() + + # 
input json + # dataset_coco.json + parser.add_argument('--input_json', required=True, help='input json file to process into hdf5') + parser.add_argument('--output_dir', default='data', help='output h5 file') + + # options + parser.add_argument('--images_root', default='', help='root location in which images are stored, to be prepended to file_path in input json') + # parser.add_argument('--att_size', default=14, type=int, help='14x14 or 7x7') + # parser.add_argument('--model_type', default='RN50', type=str, help='RN50, RN101, RN50x4, ViT-B/32, vit_base_patch32_224_in21k') + + parser.add_argument('--n_jobs', default=-1, type=int, help='number of jobs to run in parallel') + parser.add_argument('--job_id', default=0, type=int, help='job id') + + args = parser.parse_args() + params = vars(args) # convert to ordinary dict + print('parsed input parameters:') + print(json.dumps(params, indent = 2)) + main(params) diff --git a/scripts/copy_model.sh b/scripts/copy_model.sh new file mode 100644 index 0000000000000000000000000000000000000000..3e0f8945ffcc1aff3016812a0f5ab91465677514 --- /dev/null +++ b/scripts/copy_model.sh @@ -0,0 +1,9 @@ +#!/bin/sh + +if [ ! 
-d log_$2 ]; then +cp -r log_$1 log_$2 +cd log_$2 +mv infos_$1-best.pkl infos_$2-best.pkl +mv infos_$1.pkl infos_$2.pkl +cd ../ +fi diff --git a/scripts/dump_to_h5df.py b/scripts/dump_to_h5df.py new file mode 100644 index 0000000000000000000000000000000000000000..b1d3f2c3dfea2c450ad0d1f4e71c1382b66eb095 --- /dev/null +++ b/scripts/dump_to_h5df.py @@ -0,0 +1,56 @@ +import argparse +import h5py +import os +import numpy as np +import json +from tqdm import tqdm + + +def main(params): + + imgs = json.load(open(params['input_json'], 'r')) + imgs = imgs['images'] + N = len(imgs) + + if params['fc_input_dir'] is not None: + print('processing fc') + with h5py.File(params['fc_output']) as file_fc: + for i, img in enumerate(tqdm(imgs)): + npy_fc_path = os.path.join( + params['fc_input_dir'], + str(img['cocoid']) + '.npy') + + d_set_fc = file_fc.create_dataset( + str(img['cocoid']), data=np.load(npy_fc_path)) + file_fc.close() + + if params['att_input_dir'] is not None: + print('processing att') + with h5py.File(params['att_output']) as file_att: + for i, img in enumerate(tqdm(imgs)): + npy_att_path = os.path.join( + params['att_input_dir'], + str(img['cocoid']) + '.npz') + + d_set_att = file_att.create_dataset( + str(img['cocoid']), + data=np.load(npy_att_path)['feat']) + file_att.close() + + +if __name__ == "__main__": + + parser = argparse.ArgumentParser() + + parser.add_argument('--input_json', required=True, help='input json file to process into hdf5') + parser.add_argument('--fc_output', default='data', help='output h5 filename for fc') + parser.add_argument('--att_output', default='data', help='output h5 file for att') + parser.add_argument('--fc_input_dir', default=None, help='input directory for numpy fc files') + parser.add_argument('--att_input_dir', default=None, help='input directory for numpy att files') + + args = parser.parse_args() + params = vars(args) # convert to ordinary dict + print('parsed input parameters:') + print(json.dumps(params, indent=2)) + + 
main(params) \ No newline at end of file diff --git a/scripts/dump_to_lmdb.py b/scripts/dump_to_lmdb.py new file mode 100644 index 0000000000000000000000000000000000000000..483dae7d7f2ec513968f12937a82666727ef2700 --- /dev/null +++ b/scripts/dump_to_lmdb.py @@ -0,0 +1,241 @@ +# copy from https://github.com/Lyken17/Efficient-PyTorch/tools + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os +import os.path as osp +import os, sys +import os.path as osp +from PIL import Image +import six +import string + +from lmdbdict import lmdbdict +from lmdbdict.methods import DUMPS_FUNC, LOADS_FUNC +import pickle +import tqdm +import numpy as np +import argparse +import json + +import torch +import torch.utils.data as data +from torch.utils.data import DataLoader + +import csv +csv.field_size_limit(sys.maxsize) +FIELDNAMES = ['image_id', 'status'] + +class FolderLMDB(data.Dataset): + def __init__(self, db_path, fn_list=None): + self.db_path = db_path + self.lmdb = lmdbdict(db_path, unsafe=True) + self.lmdb._key_dumps = DUMPS_FUNC['ascii'] + self.lmdb._value_loads = LOADS_FUNC['identity'] + if fn_list is not None: + self.length = len(fn_list) + self.keys = fn_list + else: + raise Error + + def __getitem__(self, index): + byteflow = self.lmdb[self.keys[index]] + + # load image + imgbuf = byteflow + buf = six.BytesIO() + buf.write(imgbuf) + buf.seek(0) + try: + if args.extension == '.npz': + feat = np.load(buf)['feat'] + else: + feat = np.load(buf) + except Exception as e: + print(self.keys[index], e) + return None + + return feat + + def __len__(self): + return self.length + + def __repr__(self): + return self.__class__.__name__ + ' (' + self.db_path + ')' + + +def make_dataset(dir, extension): + images = [] + dir = os.path.expanduser(dir) + for root, _, fnames in sorted(os.walk(dir)): + for fname in sorted(fnames): + if has_file_allowed_extension(fname, [extension]): + path = os.path.join(root, fname) + 
images.append(path) + + return images + + +def raw_reader(path): + with open(path, 'rb') as f: + bin_data = f.read() + return bin_data + + +def raw_npz_reader(path): + with open(path, 'rb') as f: + bin_data = f.read() + try: + npz_data = np.load(six.BytesIO(bin_data))['feat'] + except Exception as e: + print(path) + npz_data = None + print(e) + return bin_data, npz_data + + +def raw_npy_reader(path): + with open(path, 'rb') as f: + bin_data = f.read() + try: + npy_data = np.load(six.BytesIO(bin_data)) + except Exception as e: + print(path) + npy_data = None + print(e) + return bin_data, npy_data + + +class Folder(data.Dataset): + + def __init__(self, root, loader, extension, fn_list=None): + super(Folder, self).__init__() + self.root = root + if fn_list: + samples = [os.path.join(root, str(_)+extension) for _ in fn_list] + else: + samples = make_dataset(self.root, extension) + + self.loader = loader + self.extension = extension + self.samples = samples + + def __getitem__(self, index): + """ + Args: + index (int): Index + Returns: + tuple: (sample, target) where target is class_index of the target class. 
+ """ + path = self.samples[index] + sample = self.loader(path) + + return (path.split('/')[-1].split('.')[0],) + sample + + def __len__(self): + return len(self.samples) + + +def folder2lmdb(dpath, fn_list, write_frequency=5000): + directory = osp.expanduser(osp.join(dpath)) + print("Loading dataset from %s" % directory) + if args.extension == '.npz': + dataset = Folder(directory, loader=raw_npz_reader, extension='.npz', + fn_list=fn_list) + else: + dataset = Folder(directory, loader=raw_npy_reader, extension='.npy', + fn_list=fn_list) + data_loader = DataLoader(dataset, num_workers=16, collate_fn=lambda x: x) + + # lmdb_path = osp.join(dpath, "%s.lmdb" % (directory.split('/')[-1])) + lmdb_path = osp.join("%s.lmdb" % (directory)) + isdir = os.path.isdir(lmdb_path) + + print("Generate LMDB to %s" % lmdb_path) + db = lmdbdict(lmdb_path, mode='w', key_method='ascii', value_method='identity') + + tsvfile = open(args.output_file, 'a') + writer = csv.DictWriter(tsvfile, delimiter='\t', fieldnames=FIELDNAMES) + names = [] + all_keys = [] + for idx, data in enumerate(tqdm.tqdm(data_loader)): + # print(type(data), data) + name, byte, npz = data[0] + if npz is not None: + db[name] = byte + all_keys.append(name) + names.append({'image_id': name, 'status': str(npz is not None)}) + if idx % write_frequency == 0: + print("[%d/%d]" % (idx, len(data_loader))) + print('writing') + db.flush() + # write in tsv + for name in names: + writer.writerow(name) + names = [] + tsvfile.flush() + print('writing finished') + # write all keys + # txn.put("keys".encode(), pickle.dumps(all_keys)) + # # finish iterating through dataset + # txn.commit() + for name in names: + writer.writerow(name) + tsvfile.flush() + tsvfile.close() + + print("Flushing database ...") + db.flush() + del db + +def parse_args(): + """ + Parse input arguments + """ + parser = argparse.ArgumentParser(description='Generate bbox output from a Fast R-CNN network') + # parser.add_argument('--json) + 
parser.add_argument('--input_json', default='./data/dataset_coco.json', type=str) + parser.add_argument('--output_file', default='.dump_cache.tsv', type=str) + parser.add_argument('--folder', default='./data/cocobu_att', type=str) + parser.add_argument('--extension', default='.npz', type=str) + + args = parser.parse_args() + return args + +if __name__ == "__main__": + global args + args = parse_args() + + args.output_file += args.folder.split('/')[-1] + if args.folder.find('/') > 0: + args.output_file = args.folder[:args.folder.rfind('/')+1]+args.output_file + print(args.output_file) + + img_list = json.load(open(args.input_json, 'r'))['images'] + fn_list = [str(_['cocoid']) for _ in img_list] + found_ids = set() + try: + with open(args.output_file, 'r') as tsvfile: + reader = csv.DictReader(tsvfile, delimiter='\t', fieldnames=FIELDNAMES) + for item in reader: + if item['status'] == 'True': + found_ids.add(item['image_id']) + except: + pass + fn_list = [_ for _ in fn_list if _ not in found_ids] + folder2lmdb(args.folder, fn_list) + + # Test existing. 
+ found_ids = set() + with open(args.output_file, 'r') as tsvfile: + reader = csv.DictReader(tsvfile, delimiter='\t', fieldnames=FIELDNAMES) + for item in reader: + if item['status'] == 'True': + found_ids.add(item['image_id']) + + folder_dataset = FolderLMDB(args.folder+'.lmdb', list(found_ids)) + data_loader = DataLoader(folder_dataset, num_workers=16, collate_fn=lambda x: x) + for data in tqdm.tqdm(data_loader): + assert data[0] is not None \ No newline at end of file diff --git a/scripts/make_bu_data.py b/scripts/make_bu_data.py new file mode 100644 index 0000000000000000000000000000000000000000..211f3e93dd3df9836e542322b0a19eeb581b2e1a --- /dev/null +++ b/scripts/make_bu_data.py @@ -0,0 +1,52 @@ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os +import base64 +import numpy as np +import csv +import sys +import zlib +import time +import mmap +import argparse + +parser = argparse.ArgumentParser() + +# output_dir +parser.add_argument('--downloaded_feats', default='data/bu_data', help='downloaded feature directory') +parser.add_argument('--output_dir', default='data/cocobu', help='output feature files') + +args = parser.parse_args() + +csv.field_size_limit(sys.maxsize) + + +FIELDNAMES = ['image_id', 'image_w','image_h','num_boxes', 'boxes', 'features'] +infiles = ['trainval/karpathy_test_resnet101_faster_rcnn_genome.tsv', + 'trainval/karpathy_val_resnet101_faster_rcnn_genome.tsv',\ + 'trainval/karpathy_train_resnet101_faster_rcnn_genome.tsv.0', \ + 'trainval/karpathy_train_resnet101_faster_rcnn_genome.tsv.1'] + +os.makedirs(args.output_dir+'_att') +os.makedirs(args.output_dir+'_fc') +os.makedirs(args.output_dir+'_box') + +for infile in infiles: + print('Reading ' + infile) + with open(os.path.join(args.downloaded_feats, infile), "r") as tsv_in_file: + reader = csv.DictReader(tsv_in_file, delimiter='\t', fieldnames = FIELDNAMES) + for item in reader: + item['image_id'] = 
int(item['image_id']) + item['num_boxes'] = int(item['num_boxes']) + for field in ['boxes', 'features']: + item[field] = np.frombuffer(base64.decodestring(item[field].encode('ascii')), + dtype=np.float32).reshape((item['num_boxes'],-1)) + np.savez_compressed(os.path.join(args.output_dir+'_att', str(item['image_id'])), feat=item['features']) + np.save(os.path.join(args.output_dir+'_fc', str(item['image_id'])), item['features'].mean(0)) + np.save(os.path.join(args.output_dir+'_box', str(item['image_id'])), item['boxes']) + + + + diff --git a/scripts/prepro_feats.py b/scripts/prepro_feats.py new file mode 100644 index 0000000000000000000000000000000000000000..2c98d880d6b0b76ddb21f1bd516c4ce90515b8f3 --- /dev/null +++ b/scripts/prepro_feats.py @@ -0,0 +1,103 @@ +""" +Preprocess a raw json dataset into features files for use in data_loader.py + +Input: json file that has the form +[{ file_path: 'path/img.jpg', captions: ['a caption', ...] }, ...] +example element in this list would look like +{'captions': [u'A man with a red helmet on a small moped on a dirt road. ', u'Man riding a motor bike on a dirt road on the countryside.', u'A man riding on the back of a motorcycle.', u'A dirt path with a young person on a motor bike rests to the foreground of a verdant area with a bridge and a background of cloud-wreathed mountains. ', u'A man in a red shirt and a red hat is on a motorcycle on a hill side.'], 'file_path': u'val2014/COCO_val2014_000000391895.jpg', 'id': 391895} + +This script reads this json, does some basic preprocessing on the captions +(e.g. 
lowercase, etc.), creates a special UNK token, and encodes everything to arrays + +Output: two folders of features +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os +import json +import argparse +from random import shuffle, seed +import string +# non-standard dependencies: +import h5py +from six.moves import cPickle +import numpy as np +import torch +import torchvision.models as models +import skimage.io + +from torchvision import transforms as trn +preprocess = trn.Compose([ + #trn.ToTensor(), + trn.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) +]) + +from captioning.utils.resnet_utils import myResnet +import captioning.utils.resnet as resnet + + +def main(params): + net = getattr(resnet, params['model'])() + net.load_state_dict(torch.load(os.path.join(params['model_root'],params['model']+'.pth'))) + my_resnet = myResnet(net) + my_resnet.cuda() + my_resnet.eval() + + imgs = json.load(open(params['input_json'], 'r')) + imgs = imgs['images'] + N = len(imgs) + + seed(123) # make reproducible + + dir_fc = params['output_dir']+'_fc' + dir_att = params['output_dir']+'_att' + if not os.path.isdir(dir_fc): + os.mkdir(dir_fc) + if not os.path.isdir(dir_att): + os.mkdir(dir_att) + + for i,img in enumerate(imgs): + # load the image + I = skimage.io.imread(os.path.join(params['images_root'], img['filepath'], img['filename'])) + # handle grayscale input images + if len(I.shape) == 2: + I = I[:,:,np.newaxis] + I = np.concatenate((I,I,I), axis=2) + + I = I.astype('float32')/255.0 + I = torch.from_numpy(I.transpose([2,0,1])).cuda() + I = preprocess(I) + with torch.no_grad(): + tmp_fc, tmp_att = my_resnet(I, params['att_size']) + # write to pkl + # print(dir_fc, str(img['cocoid']), tmp_fc.shape, tmp_att.shape, dir_att) + # exit() + np.save(os.path.join(dir_fc, str(img['cocoid'])), tmp_fc.data.cpu().float().numpy()) + np.savez_compressed(os.path.join(dir_att, str(img['cocoid'])), 
feat=tmp_att.data.cpu().float().numpy()) + + if i % 1000 == 0: + print('processing %d/%d (%.2f%% done)' % (i, N, i*100.0/N)) + print('wrote ', params['output_dir']) + +if __name__ == "__main__": + + parser = argparse.ArgumentParser() + + # input json + parser.add_argument('--input_json', required=True, help='input json file to process into hdf5') + parser.add_argument('--output_dir', default='data', help='output h5 file') + + # options + parser.add_argument('--images_root', default='', help='root location in which images are stored, to be prepended to file_path in input json') + parser.add_argument('--att_size', default=14, type=int, help='14x14 or 7x7') + parser.add_argument('--model', default='resnet101', type=str, help='resnet101, resnet152') + parser.add_argument('--model_root', default='./data/imagenet_weights', type=str, help='model root') + + args = parser.parse_args() + params = vars(args) # convert to ordinary dict + print('parsed input parameters:') + print(json.dumps(params, indent = 2)) + main(params) diff --git a/scripts/prepro_labels.py b/scripts/prepro_labels.py new file mode 100644 index 0000000000000000000000000000000000000000..57fd82fb5144e51fd7dfe3e159080dbf29a63567 --- /dev/null +++ b/scripts/prepro_labels.py @@ -0,0 +1,206 @@ +""" +Preprocess a raw json dataset into hdf5/json files for use in data_loader.py + +Input: json file that has the form +[{ file_path: 'path/img.jpg', captions: ['a caption', ...] }, ...] +example element in this list would look like +{'captions': [u'A man with a red helmet on a small moped on a dirt road. ', u'Man riding a motor bike on a dirt road on the countryside.', u'A man riding on the back of a motorcycle.', u'A dirt path with a young person on a motor bike rests to the foreground of a verdant area with a bridge and a background of cloud-wreathed mountains. 
', u'A man in a red shirt and a red hat is on a motorcycle on a hill side.'], 'file_path': u'val2014/COCO_val2014_000000391895.jpg', 'id': 391895} + +This script reads this json, does some basic preprocessing on the captions +(e.g. lowercase, etc.), creates a special UNK token, and encodes everything to arrays + +Output: a json file and an hdf5 file +The hdf5 file contains several fields: +/labels is (M,max_length) uint32 array of encoded labels, zero padded +/label_start_ix and /label_end_ix are (N,) uint32 arrays of pointers to the + first and last indices (in range 1..M) of labels for each image +/label_length stores the length of the sequence for each of the M sequences + +The json file has a dict that contains: +- an 'ix_to_word' field storing the vocab in form {ix:'word'}, where ix is 1-indexed +- an 'images' field that is a list holding auxiliary information for each image, + such as in particular the 'split' it was assigned to. +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os +import json +import argparse +from random import shuffle, seed +import string +# non-standard dependencies: +import h5py +import numpy as np +import torch +import torchvision.models as models +import skimage.io +from PIL import Image + + +def build_vocab(imgs, params): + count_thr = params['word_count_threshold'] + + # count up the number of words + counts = {} + for img in imgs: + for sent in img['sentences']: + for w in sent['tokens']: + counts[w] = counts.get(w, 0) + 1 + cw = sorted([(count,w) for w,count in counts.items()], reverse=True) + print('top words and their counts:') + print('\n'.join(map(str,cw[:20]))) + + # print some stats + total_words = sum(counts.values()) + print('total words:', total_words) + bad_words = [w for w,n in counts.items() if n <= count_thr] + vocab = [w for w,n in counts.items() if n > count_thr] + bad_count = sum(counts[w] for w in bad_words) + print('number of bad 
words: %d/%d = %.2f%%' % (len(bad_words), len(counts), len(bad_words)*100.0/len(counts))) + print('number of words in vocab would be %d' % (len(vocab), )) + print('number of UNKs: %d/%d = %.2f%%' % (bad_count, total_words, bad_count*100.0/total_words)) + + # lets look at the distribution of lengths as well + sent_lengths = {} + for img in imgs: + for sent in img['sentences']: + txt = sent['tokens'] + nw = len(txt) + sent_lengths[nw] = sent_lengths.get(nw, 0) + 1 + max_len = max(sent_lengths.keys()) + print('max length sentence in raw data: ', max_len) + print('sentence length distribution (count, number of words):') + sum_len = sum(sent_lengths.values()) + for i in range(max_len+1): + print('%2d: %10d %f%%' % (i, sent_lengths.get(i,0), sent_lengths.get(i,0)*100.0/sum_len)) + + # lets now produce the final annotations + if bad_count > 0: + # additional special UNK token we will use below to map infrequent words to + print('inserting the special UNK token') + vocab.append('UNK') + + for img in imgs: + img['final_captions'] = [] + for sent in img['sentences']: + txt = sent['tokens'] + caption = [w if counts.get(w,0) > count_thr else 'UNK' for w in txt] + img['final_captions'].append(caption) + + return vocab + + +def encode_captions(imgs, params, wtoi): + """ + encode all captions into one large array, which will be 1-indexed. + also produces label_start_ix and label_end_ix which store 1-indexed + and inclusive (Lua-style) pointers to the first and last caption for + each image in the dataset. 
+ """ + + max_length = params['max_length'] + N = len(imgs) + M = sum(len(img['final_captions']) for img in imgs) # total number of captions + + label_arrays = [] + label_start_ix = np.zeros(N, dtype='uint32') # note: these will be one-indexed + label_end_ix = np.zeros(N, dtype='uint32') + label_length = np.zeros(M, dtype='uint32') + caption_counter = 0 + counter = 1 + for i,img in enumerate(imgs): + n = len(img['final_captions']) + assert n > 0, 'error: some image has no captions' + + Li = np.zeros((n, max_length), dtype='uint32') + for j,s in enumerate(img['final_captions']): + label_length[caption_counter] = min(max_length, len(s)) # record the length of this sequence + caption_counter += 1 + for k,w in enumerate(s): + if k < max_length: + Li[j,k] = wtoi[w] + + # note: word indices are 1-indexed, and captions are padded with zeros + label_arrays.append(Li) + label_start_ix[i] = counter + label_end_ix[i] = counter + n - 1 + + counter += n + + L = np.concatenate(label_arrays, axis=0) # put all the labels together + assert L.shape[0] == M, 'lengths don\'t match? that\'s weird' + assert np.all(label_length > 0), 'error: some caption had no words?' 
+ + print('encoded captions to array of size ', L.shape) + return L, label_start_ix, label_end_ix, label_length + + +def main(params): + + imgs = json.load(open(params['input_json'], 'r')) + imgs = imgs['images'] + + seed(123) # make reproducible + + # create the vocab + vocab = build_vocab(imgs, params) + itow = {i+1:w for i,w in enumerate(vocab)} # a 1-indexed vocab translation table + wtoi = {w:i+1 for i,w in enumerate(vocab)} # inverse table + + # encode captions in large arrays, ready to ship to hdf5 file + L, label_start_ix, label_end_ix, label_length = encode_captions(imgs, params, wtoi) + + # create output h5 file + N = len(imgs) + f_lb = h5py.File(params['output_h5']+'_label.h5', "w") + f_lb.create_dataset("labels", dtype='uint32', data=L) + f_lb.create_dataset("label_start_ix", dtype='uint32', data=label_start_ix) + f_lb.create_dataset("label_end_ix", dtype='uint32', data=label_end_ix) + f_lb.create_dataset("label_length", dtype='uint32', data=label_length) + f_lb.close() + + # create output json file + out = {} + out['ix_to_word'] = itow # encode the (1-indexed) vocab + out['images'] = [] + for i,img in enumerate(imgs): + + jimg = {} + jimg['split'] = img['split'] + if 'filename' in img: jimg['file_path'] = os.path.join(img.get('filepath', ''), img['filename']) # copy it over, might need + if 'cocoid' in img: + jimg['id'] = img['cocoid'] # copy over & mantain an id, if present (e.g. 
coco ids, useful) + elif 'imgid' in img: + jimg['id'] = img['imgid'] + + if params['images_root'] != '': + with Image.open(os.path.join(params['images_root'], img['filepath'], img['filename'])) as _img: + jimg['width'], jimg['height'] = _img.size + + out['images'].append(jimg) + + json.dump(out, open(params['output_json'], 'w')) + print('wrote ', params['output_json']) + +if __name__ == "__main__": + + parser = argparse.ArgumentParser() + + # input json + parser.add_argument('--input_json', required=True, help='input json file to process into hdf5') + parser.add_argument('--output_json', default='data.json', help='output json file') + parser.add_argument('--output_h5', default='data', help='output h5 file') + parser.add_argument('--images_root', default='', help='root location in which images are stored, to be prepended to file_path in input json') + + # options + parser.add_argument('--max_length', default=16, type=int, help='max length of a caption, in number of words. captions longer than this get clipped.') + parser.add_argument('--word_count_threshold', default=5, type=int, help='only words that occur more than this number of times will be put in vocab') + + args = parser.parse_args() + params = vars(args) # convert to ordinary dict + print('parsed input parameters:') + print(json.dumps(params, indent = 2)) + main(params) diff --git a/scripts/prepro_ngrams.py b/scripts/prepro_ngrams.py new file mode 100644 index 0000000000000000000000000000000000000000..f7cdce47deddaae19b97b24c4191c99fcf8f9cb8 --- /dev/null +++ b/scripts/prepro_ngrams.py @@ -0,0 +1,94 @@ +""" +Precompute ngram counts of captions, to accelerate cider computation during training time. 
+""" + +import os +import json +import argparse +from six.moves import cPickle +import captioning.utils.misc as utils +from collections import defaultdict + +import sys +sys.path.append("cider") +from pyciderevalcap.ciderD.ciderD_scorer import CiderScorer + + +def get_doc_freq(refs, params): + tmp = CiderScorer(df_mode="corpus") + for ref in refs: + tmp.cook_append(None, ref) + tmp.compute_doc_freq() + return tmp.document_frequency, len(tmp.crefs) + + +def build_dict(imgs, wtoi, params): + wtoi[''] = 0 + + count_imgs = 0 + + refs_words = [] + refs_idxs = [] + for img in imgs: + if (params['split'] == img['split']) or \ + (params['split'] == 'train' and img['split'] == 'restval') or \ + (params['split'] == 'all'): + #(params['split'] == 'val' and img['split'] == 'restval') or \ + ref_words = [] + ref_idxs = [] + for sent in img['sentences']: + if hasattr(params, 'bpe'): + sent['tokens'] = params.bpe.segment(' '.join(sent['tokens'])).strip().split(' ') + tmp_tokens = sent['tokens'] + [''] + tmp_tokens = [_ if _ in wtoi else 'UNK' for _ in tmp_tokens] + ref_words.append(' '.join(tmp_tokens)) + ref_idxs.append(' '.join([str(wtoi[_]) for _ in tmp_tokens])) + refs_words.append(ref_words) + refs_idxs.append(ref_idxs) + count_imgs += 1 + print('total imgs:', count_imgs) + + ngram_words, count_refs = get_doc_freq(refs_words, params) + ngram_idxs, count_refs = get_doc_freq(refs_idxs, params) + print('count_refs:', count_refs) + return ngram_words, ngram_idxs, count_refs + +def main(params): + + imgs = json.load(open(params['input_json'], 'r')) + dict_json = json.load(open(params['dict_json'], 'r')) + itow = dict_json['ix_to_word'] + wtoi = {w:i for i,w in itow.items()} + + # Load bpe + if 'bpe' in dict_json: + import tempfile + import codecs + codes_f = tempfile.NamedTemporaryFile(delete=False) + codes_f.close() + with open(codes_f.name, 'w') as f: + f.write(dict_json['bpe']) + with codecs.open(codes_f.name, encoding='UTF-8') as codes: + bpe = apply_bpe.BPE(codes) + 
params.bpe = bpe + + imgs = imgs['images'] + + ngram_words, ngram_idxs, ref_len = build_dict(imgs, wtoi, params) + + utils.pickle_dump({'document_frequency': ngram_words, 'ref_len': ref_len}, open(params['output_pkl']+'-words.p','wb')) + utils.pickle_dump({'document_frequency': ngram_idxs, 'ref_len': ref_len}, open(params['output_pkl']+'-idxs.p','wb')) + +if __name__ == "__main__": + + parser = argparse.ArgumentParser() + + # input json + parser.add_argument('--input_json', default='data/dataset_coco.json', help='input json file to process into hdf5') + parser.add_argument('--dict_json', default='data/cocotalk.json', help='output json file') + parser.add_argument('--output_pkl', default='data/coco-all', help='output pickle file') + parser.add_argument('--split', default='all', help='test, val, train, all') + args = parser.parse_args() + params = vars(args) # convert to ordinary dict + + main(params) diff --git a/scripts/prepro_reference_json.py b/scripts/prepro_reference_json.py new file mode 100644 index 0000000000000000000000000000000000000000..683b12b03e0ef5768af2b11d359dc1f814a1e39b --- /dev/null +++ b/scripts/prepro_reference_json.py @@ -0,0 +1,69 @@ +# coding: utf-8 +""" +Create a reference json file used for evaluation with `coco-caption` repo. 
+Used when reference json is not provided, (e.g., flickr30k, or you have your own split of train/val/test) +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os +import json +import argparse +import sys +import hashlib +from random import shuffle, seed + + +def main(params): + + imgs = json.load(open(params['input_json'][0], 'r'))['images'] + # tmp = [] + # for k in imgs.keys(): + # for img in imgs[k]: + # img['filename'] = img['image_id'] # k+'/'+img['image_id'] + # img['image_id'] = int( + # int(hashlib.sha256(img['image_id']).hexdigest(), 16) % sys.maxint) + # tmp.append(img) + # imgs = tmp + + # create output json file + out = {'info': {'description': 'This is stable 1.0 version of the 2014 MS COCO dataset.', 'url': 'http://mscoco.org', 'version': '1.0', 'year': 2014, 'contributor': 'Microsoft COCO group', 'date_created': '2015-01-27 09:11:52.357475'}, 'licenses': [{'url': 'http://creativecommons.org/licenses/by-nc-sa/2.0/', 'id': 1, 'name': 'Attribution-NonCommercial-ShareAlike License'}, {'url': 'http://creativecommons.org/licenses/by-nc/2.0/', 'id': 2, 'name': 'Attribution-NonCommercial License'}, {'url': 'http://creativecommons.org/licenses/by-nc-nd/2.0/', 'id': 3, 'name': 'Attribution-NonCommercial-NoDerivs License'}, {'url': 'http://creativecommons.org/licenses/by/2.0/', 'id': 4, 'name': 'Attribution License'}, {'url': 'http://creativecommons.org/licenses/by-sa/2.0/', 'id': 5, 'name': 'Attribution-ShareAlike License'}, {'url': 'http://creativecommons.org/licenses/by-nd/2.0/', 'id': 6, 'name': 'Attribution-NoDerivs License'}, {'url': 'http://flickr.com/commons/usage/', 'id': 7, 'name': 'No known copyright restrictions'}, {'url': 'http://www.usa.gov/copyright.shtml', 'id': 8, 'name': 'United States Government Work'}], 'type': 'captions'} + out.update({'images': [], 'annotations': []}) + + cnt = 0 + empty_cnt = 0 + for i, img in enumerate(imgs): + if img['split'] == 'train': + 
continue + out['images'].append( + {'id': img.get('cocoid', img['imgid'])}) + for j, s in enumerate(img['sentences']): + if len(s) == 0: + continue + s = ' '.join(s['tokens']) + out['annotations'].append( + {'image_id': out['images'][-1]['id'], 'caption': s, 'id': cnt}) + cnt += 1 + + json.dump(out, open(params['output_json'], 'w')) + print('wrote ', params['output_json']) + + +if __name__ == "__main__": + + parser = argparse.ArgumentParser() + + # input json + parser.add_argument('--input_json', nargs='+', required=True, + help='input json file to process into hdf5') + parser.add_argument('--output_json', default='data.json', + help='output json file') + + args = parser.parse_args() + params = vars(args) # convert to ordinary dict + print('parsed input parameters:') + print(json.dumps(params, indent=2)) + main(params) + diff --git a/scripts_FineCapEval/clip_prepro_feats.py b/scripts_FineCapEval/clip_prepro_feats.py new file mode 100644 index 0000000000000000000000000000000000000000..3b2986a54766752ac353e09e064a6e8abb43e0d5 --- /dev/null +++ b/scripts_FineCapEval/clip_prepro_feats.py @@ -0,0 +1,163 @@ +""" +Preprocess a raw json dataset into features files for use in data_loader.py + +Input: json file that has the form +[{ file_path: 'path/img.jpg', captions: ['a caption', ...] }, ...] +example element in this list would look like +{'captions': [u'A man with a red helmet on a small moped on a dirt road. ', u'Man riding a motor bike on a dirt road on the countryside.', u'A man riding on the back of a motorcycle.', u'A dirt path with a young person on a motor bike rests to the foreground of a verdant area with a bridge and a background of cloud-wreathed mountains. ', u'A man in a red shirt and a red hat is on a motorcycle on a hill side.'], 'file_path': u'val2014/COCO_val2014_000000391895.jpg', 'id': 391895} + +This script reads this json, does some basic preprocessing on the captions +(e.g. 
lowercase, etc.), creates a special UNK token, and encodes everything to arrays + +Output: two folders of features +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os +import json +import argparse +from random import shuffle, seed +import string +# non-standard dependencies: +import h5py +from six.moves import cPickle +import numpy as np +import torch +import torchvision.models as models +import skimage.io + +from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize +from PIL import Image +from torch import nn + +preprocess = Compose([ + Resize((448, 448), interpolation=Image.BICUBIC), + CenterCrop((448, 448)), + ToTensor() +]) + + +from clip.clip import load +from timm.models.vision_transformer import resize_pos_embed +import timm + +from captioning.utils.resnet_utils import myResnet +import captioning.utils.resnet as resnet + +from tqdm import tqdm + + +def main(params): + if params["model_type"] != 'vit_base_patch32_224_in21k': + model, transform = load(params["model_type"], jit=False) + else: + model = timm.create_model(params["model_type"], pretrained=True) + model = model.cuda() + + if params["model_type"] != 'vit_base_patch32_224_in21k': + save_model_type = params["model_type"].split("-")[0] + mean = torch.Tensor([0.48145466, 0.4578275, 0.40821073]).to("cuda").reshape(3, 1, 1) + std = torch.Tensor([0.26862954, 0.26130258, 0.27577711]).to("cuda").reshape(3, 1, 1) + + if "RN" in params["model_type"]: + num_patches = 196 #600 * 1000 // 32 // 32 + pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, model.visual.attnpool.positional_embedding.shape[-1], device='cuda'),) + pos_embed.weight = resize_pos_embed(model.visual.attnpool.positional_embedding.unsqueeze(0), pos_embed) + model.visual.attnpool.positional_embedding = pos_embed + + else: + save_model_type = 'vit_base' + mean = torch.Tensor([0.5, 0.5, 0.5]).to("cuda").reshape(3, 1, 1) + std = 
torch.Tensor([0.5, 0.5, 0.5]).to("cuda").reshape(3, 1, 1) + + num_patches = 196 #600 * 1000 // 32 // 32 + pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, 768, device='cuda'),) + pos_embed.weight = resize_pos_embed(model.pos_embed, pos_embed) + model.pos_embed = pos_embed + + if params["model_type"] == "ViT-B/32": + num_patches = 196 #600 * 1000 // 32 // 32 + pos_embed = nn.Parameter(torch.zeros(num_patches + 1, 768, device='cuda'),) + pos_embed.weight = resize_pos_embed(model.visual.positional_embedding.unsqueeze(0), pos_embed.unsqueeze(0)) + model.visual.positional_embedding = pos_embed + imgs = json.load(open(params['input_json'], 'r')) + imgs = imgs['images'] + N = len(imgs) + + seed(123) # make reproducible + + dir_fc = params['output_dir']+'_clip_'+save_model_type+'_fc' + dir_att = params['output_dir']+'_clip_'+save_model_type+'_att' + if not os.path.isdir(dir_fc): + os.mkdir(dir_fc) + if not os.path.isdir(dir_att): + os.mkdir(dir_att) + + for i, img in enumerate(tqdm(imgs)): + with torch.no_grad(): + + # img_path = os.path.join(params['images_root'], img['filepath'], img['filename']) + # img_path = os.path.join(params['images_root'], img['file_name']) + + img_path = os.path.join(params['images_root'], img['file_path']) + + image = preprocess(Image.open( img_path ).convert("RGB")) + image = torch.tensor(np.stack([image])).cuda() + image -= mean + image /= std + if "RN" in params["model_type"]: + tmp_att, tmp_fc = model.encode_image(image) + tmp_att = tmp_att[0].permute(1, 2, 0) + tmp_fc = tmp_fc[0] + elif params["model_type"] == 'vit_base_patch32_224_in21k': + x = model(image) + tmp_fc = x[0, 0, :] + tmp_att = x[0, 1:, :].reshape( 14, 14, 768 ) + else: + x = model.visual.conv1(image.half()) # shape = [*, width, grid, grid] + x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2] + x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width] + x = torch.cat([model.visual.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], 
dtype=x.dtype, device=x.device), x], dim=1) # shape = [*, grid ** 2 + 1, width] + x = x + model.visual.positional_embedding.to(x.dtype)[:x.shape[1], :] + x = model.visual.ln_pre(x) + + x = x.permute(1, 0, 2) # NLD -> LND + + for layer_idx, layer in enumerate(model.visual.transformer.resblocks): + x = layer(x) + + x = x.permute(1, 0, 2) + tmp_fc = x[0, 0, :] + tmp_att = x[0, 1:, :].reshape( 14, 14, 768 ) + + # np.save(os.path.join(dir_fc, str(img['cocoid'])), tmp_fc.data.cpu().float().numpy()) + # np.savez_compressed(os.path.join(dir_att, str(img['cocoid'])), feat=tmp_att.data.cpu().float().numpy()) + np.save(os.path.join(dir_fc, str(img['id'])), tmp_fc.data.cpu().float().numpy()) + np.savez_compressed(os.path.join(dir_att, str(img['id'])), feat=tmp_att.data.cpu().float().numpy()) + + + # if i % 1000 == 0: + # print('processing %d/%d (%.2f%% done)' % (i, N, i*100.0/N)) + print('wrote ', dir_fc, dir_att) + +if __name__ == "__main__": + + parser = argparse.ArgumentParser() + + # input json + parser.add_argument('--input_json', required=True, help='input json file to process into hdf5') + parser.add_argument('--output_dir', default='data', help='output h5 file') + + # options + parser.add_argument('--images_root', default='', help='root location in which images are stored, to be prepended to file_path in input json') + parser.add_argument('--att_size', default=14, type=int, help='14x14 or 7x7') + parser.add_argument('--model_type', default='RN50', type=str, help='RN50, RN101, RN50x4, ViT-B/32, vit_base_patch32_224_in21k') + + args = parser.parse_args() + params = vars(args) # convert to ordinary dict + print('parsed input parameters:') + print(json.dumps(params, indent = 2)) + main(params) diff --git a/scripts_FineCapEval/clipscore_prepro_feats.py b/scripts_FineCapEval/clipscore_prepro_feats.py new file mode 100644 index 0000000000000000000000000000000000000000..5e085078ecd67e4e390bc50b543c14d4934cb260 --- /dev/null +++ b/scripts_FineCapEval/clipscore_prepro_feats.py 
@@ -0,0 +1,154 @@ +""" +Preprocess a raw json dataset into features files for use in data_loader.py + +Input: json file that has the form +[{ file_path: 'path/img.jpg', captions: ['a caption', ...] }, ...] +example element in this list would look like +{'captions': [u'A man with a red helmet on a small moped on a dirt road. ', u'Man riding a motor bike on a dirt road on the countryside.', u'A man riding on the back of a motorcycle.', u'A dirt path with a young person on a motor bike rests to the foreground of a verdant area with a bridge and a background of cloud-wreathed mountains. ', u'A man in a red shirt and a red hat is on a motorcycle on a hill side.'], 'file_path': u'val2014/COCO_val2014_000000391895.jpg', 'id': 391895} + +This script reads this json, does some basic preprocessing on the captions +(e.g. lowercase, etc.), creates a special UNK token, and encodes everything to arrays + +Output: two folders of features +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os +import json +import argparse +from random import shuffle, seed +import string +# non-standard dependencies: +import h5py +from six.moves import cPickle +import numpy as np +import torch +import torchvision.models as models +import skimage.io + +from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize +from PIL import Image +from torch import nn + +# preprocess = Compose([ +# Resize((448, 448), interpolation=Image.BICUBIC), +# CenterCrop((448, 448)), +# ToTensor() +# ]) + + +# from clip.clip import load +# from timm.models.vision_transformer import resize_pos_embed +# import timm + +# from captioning.utils.resnet_utils import myResnet +# import captioning.utils.resnet as resnet + +from captioning.utils.clipscore import CLIPScore + +from tqdm import tqdm + + +def main(params): + + clipscore_model = CLIPScore() + clipscore_model.to('cuda') + + imgs = json.load(open(params['input_json'], 'r')) 
+ imgs = imgs['images'] + N = len(imgs) + + seed(123) # make reproducible + + # dir_fc = params['output_dir']+'_clip_'+save_model_type+'_fc' + # dir_att = params['output_dir']+'_clip_'+save_model_type+'_att' + + vis_dir_fc = params['output_dir']+'_clipscore_vis' + if not os.path.isdir(vis_dir_fc): + os.mkdir(vis_dir_fc) + + # text_dir_fc = params['output_dir']+'_clipscore_text' + # if not os.path.isdir(text_dir_fc): + # os.mkdir(text_dir_fc) + + # if not os.path.isdir(dir_att): + # os.mkdir(dir_att) + + for i,img in enumerate(tqdm(imgs)): + # load the image + + # img_path = os.path.join(params['images_root'], img['filepath'], img['filename']) + # img_path = os.path.join(params['images_root'], img['file_name']) + img_path = os.path.join(params['images_root'], img['file_path']) + + img_feat = clipscore_model.image_extract(img_path) + img_feat = img_feat.view(512) + + # for d in img['sentences']: + # text = d['raw'].strip() + # text_feat = clipscore_model.text_extract(text) + + + # with torch.no_grad(): + + # image = preprocess(Image.open(os.path.join(params['images_root'], img['filepath'], img['filename']) ).convert("RGB")) + # image = torch.tensor(np.stack([image])).cuda() + # image -= mean + # image /= std + # if "RN" in params["model_type"]: + # tmp_att, tmp_fc = model.encode_image(image) + # tmp_att = tmp_att[0].permute(1, 2, 0) + # tmp_fc = tmp_fc[0] + # elif params["model_type"] == 'vit_base_patch32_224_in21k': + # x = model(image) + # tmp_fc = x[0, 0, :] + # tmp_att = x[0, 1:, :].reshape( 14, 14, 768 ) + # else: + # x = model.visual.conv1(image.half()) # shape = [*, width, grid, grid] + # x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2] + # x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width] + # x = torch.cat([model.visual.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), x], dim=1) # shape = [*, grid ** 2 + 1, width] + # x = x + 
model.visual.positional_embedding.to(x.dtype)[:x.shape[1], :] + # x = model.visual.ln_pre(x) + + # x = x.permute(1, 0, 2) # NLD -> LND + + # for layer_idx, layer in enumerate(model.visual.transformer.resblocks): + # x = layer(x) + + # x = x.permute(1, 0, 2) + # tmp_fc = x[0, 0, :] + # tmp_att = x[0, 1:, :].reshape( 14, 14, 768 ) + + np.save(os.path.join(vis_dir_fc, str(img['id'])), img_feat.data.cpu().float().numpy()) + # np.save(os.path.join(text_dir_fc, str(img['cocoid'])), tmp_fc.data.cpu().float().numpy()) + + + # np.savez_compressed(os.path.join(dir_att, str(img['cocoid'])), feat=tmp_att.data.cpu().float().numpy()) + + # if i % 1000 == 0: + # print('processing %d/%d (%.2f%% done)' % (i, N, i*100.0/N)) + print('wrote ', vis_dir_fc) + +if __name__ == "__main__": + + parser = argparse.ArgumentParser() + + # input json + # dataset_coco.json + parser.add_argument('--input_json', required=True, help='input json file to process into hdf5') + parser.add_argument('--output_dir', default='data', help='output h5 file') + + # options + parser.add_argument('--images_root', default='', help='root location in which images are stored, to be prepended to file_path in input json') + # parser.add_argument('--att_size', default=14, type=int, help='14x14 or 7x7') + # parser.add_argument('--model_type', default='RN50', type=str, help='RN50, RN101, RN50x4, ViT-B/32, vit_base_patch32_224_in21k') + + args = parser.parse_args() + params = vars(args) # convert to ordinary dict + print('parsed input parameters:') + print(json.dumps(params, indent = 2)) + main(params) diff --git a/scripts_FineCapEval/prepro_labels.py b/scripts_FineCapEval/prepro_labels.py new file mode 100644 index 0000000000000000000000000000000000000000..48e7d079808760941a78d87435f8f0e2bbcfb280 --- /dev/null +++ b/scripts_FineCapEval/prepro_labels.py @@ -0,0 +1,209 @@ +""" +Preprocess a raw json dataset into hdf5/json files for use in data_loader.py + +Input: json file that has the form +[{ file_path: 'path/img.jpg', 
captions: ['a caption', ...] }, ...] +example element in this list would look like +{'captions': [u'A man with a red helmet on a small moped on a dirt road. ', u'Man riding a motor bike on a dirt road on the countryside.', u'A man riding on the back of a motorcycle.', u'A dirt path with a young person on a motor bike rests to the foreground of a verdant area with a bridge and a background of cloud-wreathed mountains. ', u'A man in a red shirt and a red hat is on a motorcycle on a hill side.'], 'file_path': u'val2014/COCO_val2014_000000391895.jpg', 'id': 391895} + +This script reads this json, does some basic preprocessing on the captions +(e.g. lowercase, etc.), creates a special UNK token, and encodes everything to arrays + +Output: a json file and an hdf5 file +The hdf5 file contains several fields: +/labels is (M,max_length) uint32 array of encoded labels, zero padded +/label_start_ix and /label_end_ix are (N,) uint32 arrays of pointers to the + first and last indices (in range 1..M) of labels for each image +/label_length stores the length of the sequence for each of the M sequences + +The json file has a dict that contains: +- an 'ix_to_word' field storing the vocab in form {ix:'word'}, where ix is 1-indexed +- an 'images' field that is a list holding auxiliary information for each image, + such as in particular the 'split' it was assigned to. 
+""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os +import json +import argparse +from random import shuffle, seed +import string +# non-standard dependencies: +import h5py +import numpy as np +import torch +import torchvision.models as models +import skimage.io +from PIL import Image + + +def build_vocab(imgs, params): + count_thr = params['word_count_threshold'] + + # count up the number of words + counts = {} + for img in imgs: + for sent in img['sentences']: + for w in sent['tokens']: + counts[w] = counts.get(w, 0) + 1 + cw = sorted([(count,w) for w,count in counts.items()], reverse=True) + print('top words and their counts:') + print('\n'.join(map(str,cw[:20]))) + + # print some stats + total_words = sum(counts.values()) + print('total words:', total_words) + bad_words = [w for w,n in counts.items() if n <= count_thr] + vocab = [w for w,n in counts.items() if n > count_thr] + bad_count = sum(counts[w] for w in bad_words) + print('number of bad words: %d/%d = %.2f%%' % (len(bad_words), len(counts), len(bad_words)*100.0/len(counts))) + print('number of words in vocab would be %d' % (len(vocab), )) + print('number of UNKs: %d/%d = %.2f%%' % (bad_count, total_words, bad_count*100.0/total_words)) + + # lets look at the distribution of lengths as well + sent_lengths = {} + for img in imgs: + for sent in img['sentences']: + txt = sent['tokens'] + nw = len(txt) + sent_lengths[nw] = sent_lengths.get(nw, 0) + 1 + max_len = max(sent_lengths.keys()) + print('max length sentence in raw data: ', max_len) + print('sentence length distribution (count, number of words):') + sum_len = sum(sent_lengths.values()) + for i in range(max_len+1): + print('%2d: %10d %f%%' % (i, sent_lengths.get(i,0), sent_lengths.get(i,0)*100.0/sum_len)) + + # lets now produce the final annotations + if bad_count > 0: + # additional special UNK token we will use below to map infrequent words to + print('inserting the 
special UNK token') + vocab.append('UNK') + + for img in imgs: + img['final_captions'] = [] + for sent in img['sentences']: + txt = sent['tokens'] + caption = [w if counts.get(w,0) > count_thr else 'UNK' for w in txt] + img['final_captions'].append(caption) + + return vocab + + +def encode_captions(imgs, params, wtoi): + """ + encode all captions into one large array, which will be 1-indexed. + also produces label_start_ix and label_end_ix which store 1-indexed + and inclusive (Lua-style) pointers to the first and last caption for + each image in the dataset. + """ + + max_length = params['max_length'] + N = len(imgs) + M = sum(len(img['final_captions']) for img in imgs) # total number of captions + + label_arrays = [] + label_start_ix = np.zeros(N, dtype='uint32') # note: these will be one-indexed + label_end_ix = np.zeros(N, dtype='uint32') + label_length = np.zeros(M, dtype='uint32') + caption_counter = 0 + counter = 1 + for i,img in enumerate(imgs): + n = len(img['final_captions']) + assert n > 0, 'error: some image has no captions' + + Li = np.zeros((n, max_length), dtype='uint32') + for j,s in enumerate(img['final_captions']): + label_length[caption_counter] = min(max_length, len(s)) # record the length of this sequence + caption_counter += 1 + for k,w in enumerate(s): + if k < max_length: + Li[j,k] = wtoi[w] + + # note: word indices are 1-indexed, and captions are padded with zeros + label_arrays.append(Li) + label_start_ix[i] = counter + label_end_ix[i] = counter + n - 1 + + counter += n + + L = np.concatenate(label_arrays, axis=0) # put all the labels together + assert L.shape[0] == M, 'lengths don\'t match? that\'s weird' + assert np.all(label_length > 0), 'error: some caption had no words?' 
+ + print('encoded captions to array of size ', L.shape) + return L, label_start_ix, label_end_ix, label_length + + +def main(params): + + imgs = json.load(open(params['input_json'], 'r')) + imgs = imgs['images'] + + seed(123) # make reproducible + + # # create the vocab + # vocab = build_vocab(imgs, params) + # itow = {i+1:w for i,w in enumerate(vocab)} # a 1-indexed vocab translation table + # wtoi = {w:i+1 for i,w in enumerate(vocab)} # inverse table + + itow = imgs['ix_to_word'] + wtoi = {w:i for i, w in itow.items()} + + # encode captions in large arrays, ready to ship to hdf5 file + L, label_start_ix, label_end_ix, label_length = encode_captions(imgs, params, wtoi) + + # create output h5 file + N = len(imgs) + f_lb = h5py.File(params['output_h5']+'_label.h5', "w") + f_lb.create_dataset("labels", dtype='uint32', data=L) + f_lb.create_dataset("label_start_ix", dtype='uint32', data=label_start_ix) + f_lb.create_dataset("label_end_ix", dtype='uint32', data=label_end_ix) + f_lb.create_dataset("label_length", dtype='uint32', data=label_length) + f_lb.close() + + # create output json file + out = {} + out['ix_to_word'] = itow # encode the (1-indexed) vocab + out['images'] = [] + for i,img in enumerate(imgs): + + jimg = {} + jimg['split'] = img['split'] + if 'filename' in img: jimg['file_path'] = os.path.join(img.get('filepath', ''), img['filename']) # copy it over, might need + if 'cocoid' in img: + jimg['id'] = img['cocoid'] # copy over & mantain an id, if present (e.g. 
coco ids, useful) + elif 'imgid' in img: + jimg['id'] = img['imgid'] + + if params['images_root'] != '': + with Image.open(os.path.join(params['images_root'], img['filepath'], img['filename'])) as _img: + jimg['width'], jimg['height'] = _img.size + + out['images'].append(jimg) + + json.dump(out, open(params['output_json'], 'w')) + print('wrote ', params['output_json']) + +if __name__ == "__main__": + + parser = argparse.ArgumentParser() + + # input json + parser.add_argument('--input_json', required=True, help='input json file to process into hdf5') + parser.add_argument('--output_json', default='data.json', help='output json file') + parser.add_argument('--output_h5', default='data', help='output h5 file') + parser.add_argument('--images_root', default='', help='root location in which images are stored, to be prepended to file_path in input json') + + # options + parser.add_argument('--max_length', default=16, type=int, help='max length of a caption, in number of words. captions longer than this get clipped.') + parser.add_argument('--word_count_threshold', default=5, type=int, help='only words that occur more than this number of times will be put in vocab') + + args = parser.parse_args() + params = vars(args) # convert to ordinary dict + print('parsed input parameters:') + print(json.dumps(params, indent = 2)) + main(params) diff --git a/tools/eval.py b/tools/eval.py new file mode 100644 index 0000000000000000000000000000000000000000..881580737fa554344b1b66ab79c4f1de114759ca --- /dev/null +++ b/tools/eval.py @@ -0,0 +1,125 @@ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import json +import numpy as np + +import time +import os +from six.moves import cPickle + +import captioning.utils.opts as opts +import captioning.models as models +from captioning.data.dataloader import * +# from captioning.data.dataloaderraw import * +import captioning.utils.eval_utils as eval_utils +import argparse +import 
captioning.utils.misc as utils +import captioning.modules.losses as losses +import torch + +# Input arguments and options +parser = argparse.ArgumentParser() +# Input paths +parser.add_argument('--model', type=str, default='', + help='path to model to evaluate') +parser.add_argument('--cnn_model', type=str, default='resnet101', + help='resnet101, resnet152') +parser.add_argument('--infos_path', type=str, default='', + help='path to infos to evaluate') +parser.add_argument('--only_lang_eval', type=int, default=0, + help='lang eval on saved results') +parser.add_argument('--force', type=int, default=0, + help='force to evaluate no matter if there are results available') +parser.add_argument('--device', type=str, default='cuda', + help='cpu or cuda') +opts.add_eval_options(parser) +opts.add_diversity_opts(parser) +opt = parser.parse_args() + +# Load infos +with open(opt.infos_path, 'rb') as f: + infos = utils.pickle_load(f) + +# override and collect parameters +replace = ['input_fc_dir', 'input_att_dir', 'input_box_dir', 'input_label_h5', 'input_json', 'batch_size', 'id'] +ignore = ['start_from'] + +for k in vars(infos['opt']).keys(): + if k in replace: + setattr(opt, k, getattr(opt, k) or getattr(infos['opt'], k, '')) + elif k not in ignore: + if not k in vars(opt): + vars(opt).update({k: vars(infos['opt'])[k]}) # copy over options from model + +vocab = infos['vocab'] # ix -> word mapping + +pred_fn = os.path.join('eval_results/', '.saved_pred_'+ opt.id + '_' + opt.split + '.pth') +result_fn = os.path.join('eval_results/', opt.id + '_' + opt.split + '.json') + +if opt.only_lang_eval == 1 or (not opt.force and os.path.isfile(pred_fn)): + # if results existed, then skip, unless force is on + if not opt.force: + try: + if os.path.isfile(result_fn): + print(result_fn) + json.load(open(result_fn, 'r')) + print('already evaluated') + os._exit(0) + except: + pass + + predictions, n_predictions = torch.load(pred_fn) + lang_stats = eval_utils.language_eval(opt.input_json, 
predictions, n_predictions, vars(opt), opt.split) + print(lang_stats) + os._exit(0) + +# At this point only_lang_eval if 0 +if not opt.force: + # Check out if + try: + # if no pred exists, then continue + tmp = torch.load(pred_fn) + # if language_eval == 1, and no pred exists, then continue + if opt.language_eval == 1: + json.load(open(result_fn, 'r')) + print('Result is already there') + os._exit(0) + except: + pass + +# Setup the model +opt.vocab = vocab +model = models.setup(opt) +del opt.vocab +model.load_state_dict(torch.load(opt.model, map_location='cpu')) +model.to(opt.device) +model.eval() +crit = losses.LanguageModelCriterion() + +# Create the Data Loader instance +if len(opt.image_folder) == 0: + loader = DataLoader(opt) +else: + loader = DataLoaderRaw({'folder_path': opt.image_folder, + 'coco_json': opt.coco_json, + 'batch_size': opt.batch_size, + 'cnn_model': opt.cnn_model}) +# When eval using provided pretrained model, the vocab may be different from what you have in your cocotalk.json +# So make sure to use the vocab in infos file. 
+loader.dataset.ix_to_word = infos['vocab'] + + +# Set sample options +opt.dataset = opt.input_json +loss, split_predictions, lang_stats = eval_utils.eval_split(model, crit, loader, + vars(opt)) + +print('loss: ', loss) +if lang_stats: + print(lang_stats) + +if opt.dump_json == 1: + # dump the json + json.dump(split_predictions, open('vis/vis.json', 'w')) diff --git a/tools/eval_clip_retrieval.py b/tools/eval_clip_retrieval.py new file mode 100644 index 0000000000000000000000000000000000000000..639aace1d4558094d93f4bc3a8643e883b26785b --- /dev/null +++ b/tools/eval_clip_retrieval.py @@ -0,0 +1,231 @@ + +from PIL import Image +# import requests + +from transformers import CLIPProcessor, CLIPModel + +import torch +from torch.utils.data import DataLoader, Dataset + +from pathlib import Path +from tqdm import tqdm +import json +import argparse +import numpy as np + +class COCODataset(Dataset): + def __init__(self, + coco_root="/nas-ssd/jmincho/datasets/COCO/", + gen_caption_path=None, + is_gt=True): + super().__init__() + + self.coco_root = Path(coco_root) + + self.image_dir = self.coco_root.joinpath('images/val2014') + + if is_gt: + print("Loading karpathy splits") + data_info_path = self.coco_root.joinpath('dataset_coco.json') + with open(data_info_path) as f: + karpathy_data = json.load(f) + + data = [] + for datum in karpathy_data['images']: + # karpathy test split + if datum['split'] == 'test': + img_id = datum['filename'].split('.')[0] + new_datum = { + 'img_id': img_id, + 'captions': [d['raw'].strip() for d in datum['sentences']], + } + data.append(new_datum) + else: + print("Loading generated captions") + gen_caption_path = Path(gen_caption_path) + with open(gen_caption_path) as f: + # karpathy_data = json.load(f) + imgTogen_results = json.load(f)['imgToEval'] + data = [] + for img_id, img_data in imgTogen_results.items(): + new_datum = { + 'img_id': img_id, + 'captions': [img_data['caption']], + } + data.append(new_datum) + + self.data = data + print('# 
def compute_similarity(image_features, text_features, bs=1000):
    """Compute the dense image-to-text similarity matrix in ``bs``-sized tiles.

    Args:
        image_features: 2-D tensor with one image embedding per row.
        text_features: 2-D tensor with one text embedding per row; must have
            the same embedding width as ``image_features``.
        bs: Tile edge length. Bounds the size of each intermediate matmul.

    Returns:
        A ``(num_images, num_texts)`` tensor holding
        ``image_features @ text_features.T``.  The two inputs may have a
        different number of rows.
    """
    n_images = image_features.shape[0]
    n_texts = text_features.shape[0]
    # Size the result from both inputs so that unequal counts do not overflow
    # the last tile.
    similarity_scores = torch.zeros(n_images, n_texts)
    for v in range(0, n_images, bs):
        batch_visual_emb = image_features[v:v + bs]
        for t in range(0, n_texts, bs):
            batch_caption_emb = text_features[t:t + bs]
            similarity_scores[v:v + bs, t:t + bs] = batch_visual_emb @ batch_caption_emb.t()

    print('Done similarity')
    return similarity_scores


def compute_retrieval(a2b_sims, return_ranks=True):
    """Compute retrieval metrics from a square similarity matrix.

    Row ``i`` is treated as a query whose correct target is column ``i``.

    Args:
        a2b_sims: Array-like of shape (num_datapoints, num_datapoints), the
            result of ``emb1 @ emb2.T``.
        return_ranks: When true, also return the per-query rank of the
            correct target and the index of the top-1 retrieved target.

    Returns:
        ``report_dict`` with recall@{1,5,10,50} in percent, the median and
        mean rank (1-based) and ``sum = r1 + r5 + r10``; or
        ``(report_dict, (ranks, top1))`` when ``return_ranks`` is true.

    Raises:
        ValueError: If ``a2b_sims`` has no rows (the metrics are undefined).
    """
    a2b_sims = np.asarray(a2b_sims)
    npts = a2b_sims.shape[0]
    if npts == 0:
        raise ValueError("compute_retrieval needs at least one datapoint")

    ranks = np.zeros(npts)
    top1 = np.zeros(npts)
    # loop source embedding indices
    for index in range(npts):
        # order target indices from most to least similar
        inds = np.argsort(a2b_sims[index])[::-1]
        # 0-based position of the correct target in that ordering
        ranks[index] = np.where(inds == index)[0][0]
        # save the top1 result as well
        top1[index] = inds[0]

    def recall_at(k):
        # Percentage of queries whose correct target is within the top k.
        return 100.0 * int(np.count_nonzero(ranks < k)) / npts

    r1 = recall_at(1)
    r5 = recall_at(5)
    r10 = recall_at(10)
    r50 = recall_at(50)
    medr = np.floor(np.median(ranks)) + 1
    meanr = ranks.mean() + 1

    report_dict = {"r1": r1, "r5": r5, "r10": r10, "r50": r50,
                   "medr": medr, "meanr": meanr, "sum": r1 + r5 + r10}

    if return_ranks:
        return report_dict, (ranks, top1)
    return report_dict


if __name__ == '__main__':

    parser = argparse.ArgumentParser()
    parser.add_argument('--coco_root', type=str, default="/nas-ssd/jmincho/datasets/COCO/")
    parser.add_argument('--gt', action='store_true')
    parser.add_argument('--gen_caption_path', type=str, default="./eval_results/clipRN50_cider_test.json")
    args = parser.parse_args()

    model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
    # COCODataset reads this module-level ``processor``; keep the name.
    processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")

    # Prefer the GPU, but stay runnable on CPU-only machines.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = model.to(device)
    model.eval()
    print(f"Loaded CLIP at {device}")

    batch_size = 1000

    dataset = COCODataset(
        # Honour --coco_root instead of a hard-coded path.
        coco_root=args.coco_root,
        gen_caption_path=args.gen_caption_path,
        is_gt=args.gt
    )
    data_loader = DataLoader(
        dataset,
        batch_size=batch_size,
        collate_fn=dataset.collate_fn,
        shuffle=False,
        num_workers=8)

    # fwd all samples
    image_features = []
    text_features = []
    for batch in tqdm(data_loader):
        with torch.no_grad():
            images = batch["images"].to(device)
            texts = batch["captions"].to(device)

            # Feed the device-resident inputs explicitly.
            vision_outputs = model.vision_model(**images)
            text_outputs = model.text_model(**texts)

            image_embeds = model.visual_projection(vision_outputs[1])
            text_embeds = model.text_projection(text_outputs[1])

            # normalized features
            image_embeds = image_embeds / image_embeds.norm(dim=-1, keepdim=True)
            text_embeds = text_embeds / text_embeds.norm(dim=-1, keepdim=True)

            text_features.append(text_embeds.detach().cpu())
            image_features.append(image_embeds.detach().cpu())

    # Rows are already unit-normalised per batch, so no second pass is needed.
    image_features = torch.cat(image_features, 0)
    text_features = torch.cat(text_features, 0)
    print('Done forward')

    similarity_scores = compute_similarity(image_features, text_features)
    i2t_dict = compute_retrieval(similarity_scores.numpy())
    t2i_dict = compute_retrieval(similarity_scores.t().numpy())
    print('i2t', i2t_dict)
    print('t2i', t2i_dict)
def nltk_process(text):
    """Tokenize *text*, Porter-stem every token and strip non-alphanumerics.

    Uses the module-level ``word_tokenize`` and ``p_stemmer``.

    Returns:
        The processed tokens joined by single spaces.  A token made only of
        punctuation becomes an empty string but is still joined, so the
        result can contain consecutive spaces.
    """
    stems = [p_stemmer.stem(word) for word in word_tokenize(text)]
    return " ".join(re.sub(r'[^a-zA-Z0-9]', '', stem) for stem in stems)


def _split_phrases(field):
    """Split a ``\\;``-separated annotation cell into processed phrases."""
    # The separator is the two-character string backslash + semicolon.
    return [nltk_process(phrase)
            for phrase in field.lower().split('\\;')
            if len(phrase) > 1]


def _mean_phrase_recall(pred_id2sent, id2phrases):
    """Return the mean per-image phrase recall (0..1) over *id2phrases*.

    For each ground-truth phrase, the score is the fraction of its words
    found in the predicted sentence; phrase scores are averaged per image and
    image scores are averaged over all images.  Returns 0.0 for an empty
    mapping.
    """
    if not id2phrases:
        return 0.0

    total_score = 0
    for image_id, gt_phrases in id2phrases.items():
        pred_sent = pred_id2sent[image_id]

        score = 0
        for gt_phrase in gt_phrases:
            gt_words = gt_phrase.split()
            if gt_words:
                # Substring test against the raw sentence: ground-truth words
                # are stems, the prediction is not stemmed.
                hits = sum(1 for gt_word in gt_words if gt_word in pred_sent)
                score += hits / len(gt_words)

        if gt_phrases:
            score /= len(gt_phrases)

        total_score += score

    return total_score / len(id2phrases)


def calculate_finegrained_scores(pred_id2sent, id2caption, use_coco_eval=False,
                                 id2background=None, id2object=None,
                                 id2relation=None):
    """Print caption metrics and background/object/relation phrase accuracy.

    Args:
        pred_id2sent: Mapping image id -> generated caption string.
        id2caption: Mapping image id -> list of reference captions; only used
            when ``use_coco_eval`` is true.
        use_coco_eval: When true, also run the module-level ``evaluator`` on
            the generated captions and pretty-print its result.
        id2background, id2object, id2relation: Mapping image id -> list of
            ground-truth phrases.  When left as ``None`` the module-level
            variable of the same name is used, as before.

    Raises:
        KeyError: If an image id is missing from ``pred_id2sent``, or a
            phrase mapping is neither passed in nor defined at module level.
    """
    if use_coco_eval:
        refs = list(id2caption.values())
        hyps = [pred_id2sent[image_id] for image_id in id2caption]

        print('caption')
        results = evaluator.run_evaluation(hyps, refs)
        pprint(results)

    module_globals = globals()
    categories = (
        ('background', id2background),
        ('object', id2object),
        ('relation', id2relation),
    )
    for label, id2phrases in categories:
        if id2phrases is None:
            # Backward compatibility with callers relying on module globals.
            id2phrases = module_globals['id2' + label]
        accuracy = _mean_phrase_recall(pred_id2sent, id2phrases)
        print(label)
        print(f'Acc: {accuracy * 100:.2f}')


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--finecapeval_path', type=str, default="data/FineCapEval.csv")
    parser.add_argument('--generated_id2caption', type=str, default="FineCapEval_results/mle.json")
    args = parser.parse_args()

    df = pd.read_csv(args.finecapeval_path)
    assert df.shape == (5000, 5)

    with open(args.generated_id2caption, 'r') as f:
        generated_id2caption = json.load(f)

    print("Preprocessing GT FineCapEval data...")
    id2caption = {}
    id2background = {}
    id2object = {}
    id2relation = {}

    for row in tqdm(df.itertuples(), total=len(df)):

        image_id = row.image.split('.')[0]
        caption = row.caption
        background = row.background
        objects = row.object
        relation = row.relation

        # Skip rows where any annotation cell is not a string.
        if not all(isinstance(cell, str)
                   for cell in (caption, background, objects, relation)):
            continue

        if image_id not in id2caption:
            id2caption[image_id] = []
            id2background[image_id] = []
            id2object[image_id] = []
            id2relation[image_id] = []

        id2caption[image_id].append(caption)
        id2background[image_id].extend(_split_phrases(background))
        id2object[image_id].extend(_split_phrases(objects))
        id2relation[image_id].extend(_split_phrases(relation))

    print("Calculating scores...")
    calculate_finegrained_scores(
        generated_id2caption,
        id2caption,
        use_coco_eval=True,
        id2background=id2background,
        id2object=id2object,
        id2relation=id2relation)
f'configs/phase2/fg_clipRN50_{args.reward}.yml' + + print("Loading cfg from", cfg) + + opt = opts.parse_opt(parse=False, cfg=cfg) + + dataset = CaptionDataset(opt) + + opt.vocab_size = dataset.vocab_size + opt.seq_length = dataset.seq_length + + opt.batch_size = 40 + + opt.vocab = dataset.get_vocab() + + model = models.setup(opt) + del opt.vocab + + ckpt_path = opt.checkpoint_path + '-last.ckpt' + + print("Loading checkpoint from", ckpt_path) + raw_state_dict = torch.load( + ckpt_path, + map_location=device) + + strict = True + + state_dict = raw_state_dict['state_dict'] + + if '_vocab' in state_dict: + model.vocab = utils.deserialize(state_dict['_vocab']) + del state_dict['_vocab'] + elif strict: + raise KeyError + if '_opt' in state_dict: + saved_model_opt = utils.deserialize(state_dict['_opt']) + del state_dict['_opt'] + # Make sure the saved opt is compatible with the curren topt + need_be_same = ["caption_model", + "rnn_type", "rnn_size", "num_layers"] + for checkme in need_be_same: + if getattr(saved_model_opt, checkme) in ['updown', 'topdown'] and \ + getattr(opt, checkme) in ['updown', 'topdown']: + continue + assert getattr(saved_model_opt, checkme) == getattr( + opt, checkme), "Command line argument and saved model disagree on '%s' " % checkme + elif strict: + raise KeyError + res = model.load_state_dict(state_dict, strict) + print(res) + + opt.use_grammar = False + + lw_model = LossWrapper(model, opt) + + split = 'test' + + print("Building dataloader...") + + test_dataset = torch.utils.data.Subset( + dataset, + dataset.split_ix[split] + ) + test_loader = torch.utils.data.DataLoader( + test_dataset, + batch_size=opt.batch_size, + shuffle=False, + num_workers=4, + drop_last=False, + collate_fn=dataset.collate_func + ) + + eval_kwargs = {'dataset': opt.input_json} + eval_kwargs.update(vars(opt)) + + verbose = eval_kwargs.get('verbose', True) + verbose_beam = eval_kwargs.get('verbose_beam', 0) + verbose_loss = eval_kwargs.get('verbose_loss', 1) + # 
num_images = eval_kwargs.get('num_images', eval_kwargs.get('val_images_use', -1)) + # lang_eval = eval_kwargs.get('language_eval', 0) + dataset = eval_kwargs.get('dataset', 'coco') + beam_size = eval_kwargs.get('beam_size', 1) + sample_n = eval_kwargs.get('sample_n', 1) + remove_bad_endings = eval_kwargs.get('remove_bad_endings', 0) + + crit = lw_model.crit + + model = model.to(device) + + from tqdm import tqdm + + test_id2sent = {} + + model.eval() + + print("running inference...") + + for data in tqdm(test_loader): + with torch.no_grad(): + # forward the model to get loss + tmp = [data['fc_feats'], data['att_feats'], + data['labels'], data['masks'], data['att_masks']] + tmp = [d.to(device) if isinstance(d, torch.Tensor) else d for d in tmp] + + fc_feats, att_feats, labels, masks, att_masks = tmp + + loss = crit(model(fc_feats, att_feats, + labels[..., :-1], att_masks), labels[..., 1:], masks[..., 1:]) + + # forward the model to also get generated samples for each image + # Only leave one feature for each image, in case duplicate sample + tmp_eval_kwargs = eval_kwargs.copy() + tmp_eval_kwargs.update({'sample_n': 1}) + seq, seq_logprobs = model( + fc_feats, att_feats, att_masks, opt=tmp_eval_kwargs, mode='sample') + seq = seq.data + entropy = - (F.softmax(seq_logprobs, dim=2) * + seq_logprobs).sum(2).sum(1) / ((seq > 0).to(seq_logprobs).sum(1)+1) + perplexity = - \ + seq_logprobs.gather(2, seq.unsqueeze(2)).squeeze( + 2).sum(1) / ((seq > 0).to(seq_logprobs).sum(1)+1) + + # Print beam search + if beam_size > 1 and verbose_beam: + for i in range(fc_feats.shape[0]): + print('\n'.join([utils.decode_sequence(model.vocab, _[ + 'seq'].unsqueeze(0))[0] for _ in model.done_beams[i]])) + print('--' * 10) + sents = utils.decode_sequence(model.vocab, seq) + + for d, sent in zip(data['infos'], sents): + test_id2sent[d['id']] = sent + + res_path = f'FineCapEval_results/clipRN50_{args.reward}.json' + + print("Results save at {}".format(res_path)) + + with open(res_path, 'w') as 
f: + json.dump(test_id2sent, f) + + diff --git a/tools/train_pl.py b/tools/train_pl.py new file mode 100644 index 0000000000000000000000000000000000000000..48ac2d0cf68466bd0e39f9c994056063a0529f27 --- /dev/null +++ b/tools/train_pl.py @@ -0,0 +1,709 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.optim as optim + +import numpy as np + +import time +import os +from collections import defaultdict + +import captioning.utils.opts as opts +import captioning.models as models +from captioning.data.pth_loader import CaptionDataset +import captioning.utils.eval_utils as eval_utils +import captioning.utils.misc as utils +from captioning.utils.rewards import init_scorer, get_self_critical_reward +from captioning.modules.loss_wrapper import LossWrapper + +import pytorch_lightning as pl + +import detectron2.utils.comm as d2comm +from detectron2.utils.env import seed_all_rng +seed_all_rng(1234) + + +class LitModel(pl.LightningModule): + def __init__(self, opt): + super().__init__() + self.opt = opt + # Intilaize dataset + self.dataset = CaptionDataset(opt) + opt.vocab_size = self.dataset.vocab_size + opt.seq_length = self.dataset.seq_length + self.batch_size = opt.batch_size + + # Build model + opt.vocab = self.dataset.get_vocab() + model = models.setup(opt) + # print(model) + del opt.vocab + + # wrapper with loss in it. 
+ lw_model = LossWrapper(model, opt) + + self.model = model + self.lw_model = lw_model + + self.struc_flag = None + self.sc_flag = None + + # if self.opt.use_clipscore: + # if self.opt.use_clipscore or os.getenv('EVALUATE', '0') == '1': + # if CLIP-S+Grammar is used in reward -> Launch another CLIP-S where parameter is unchanged + if getattr(self.opt, 'use_grammar', False): + from captioning.utils.clipscore import CLIPScore + self.val_clipscore_model = CLIPScore( + mode=opt.clipscore_mode, use_grammar=False) + for p in self.val_clipscore_model.parameters(): + p.requires_grad = False + else: + if self.lw_model.clipscore_model is not None: + self.val_clipscore_model = self.lw_model.clipscore_model + else: + from captioning.utils.clipscore import CLIPScore + self.val_clipscore_model = CLIPScore( + mode=opt.clipscore_mode, use_grammar=False) + for p in self.val_clipscore_model.parameters(): + p.requires_grad = False + self.val_clipscore_model.eval() + + # BERTSCORE + from bert_score import BERTScorer + self.bert_scorer = BERTScorer( + lang="en", + # rescale_with_baseline=True, + rescale_with_baseline=False, + device='cpu' + ) + + def forward(self, *args, **kwargs): + """ + I hate this design. 
Never pretend it as a nn.Module + """ + raise NotImplementedError + + def train_dataloader(self): + train_dataset = torch.utils.data.Subset( + self.dataset, + self.dataset.split_ix['train'] + ) + + train_loader = torch.utils.data.DataLoader( + dataset=train_dataset, + batch_size=self.batch_size, + shuffle=True, + num_workers=4, + collate_fn=self.dataset.collate_func + ) + return train_loader + + def val_dataloader(self, split='val'): + val_dataset = torch.utils.data.Subset( + self.dataset, + self.dataset.split_ix[split] + ) + val_loader = torch.utils.data.DataLoader( + val_dataset, + batch_size=self.batch_size, + shuffle=False, + num_workers=4, + drop_last=False, + collate_fn=self.dataset.collate_func + ) + return val_loader + + def test_dataloader(self): + return self.val_dataloader('test') + + def training_step(self, data, batch_idx): + sc_flag, struc_flag = self.sc_flag, self.struc_flag + + tmp = [data['fc_feats'], data['att_feats'], + data['labels'], data['masks'], data['att_masks']] + fc_feats, att_feats, labels, masks, att_masks = tmp + if int(os.getenv('M2_cider', '0')) != 0: + data['gts'] = data['rawgts'] + + if self.opt.use_clipscore: + clip_vis_feats = data['clip_vis_feats'] + model_out = self.lw_model(fc_feats, att_feats, labels, masks, att_masks, + data['gts'], torch.arange(0, len(data['gts'])), sc_flag, struc_flag, + clip_vis_feats=clip_vis_feats) + else: + model_out = self.lw_model(fc_feats, att_feats, labels, masks, att_masks, + data['gts'], torch.arange(0, len(data['gts'])), sc_flag, struc_flag) + loss = model_out['loss'] + + data_time = self.trainer.profiler.recorded_durations["get_train_batch"][-1] + data_time = torch.tensor(data_time) + + logger_logs = model_out.copy() + # if struc_flag or sc_flag: + # logger_logs['reward'] = model_out['reward'].mean() + # logger_logs['reward_var'] = model_out['reward'].var(1).mean() + if struc_flag or sc_flag: + logger_logs['reward'] = model_out['reward'].mean() + for k in ['CLIP-S', 'RefCLIP-S', 'CIDEr', 
'grammar_reward']: + if k in model_out: + logger_logs[k] = model_out[k] + if struc_flag: + logger_logs['reward_var'] = model_out['reward'].var(1).mean() + + logger_logs['scheduled_sampling_prob'] = torch.tensor( + self.model.ss_prob) + # logger_logs['training_loss'] = loss + logger_logs['loss'] = loss + logger_logs['data_time'] = data_time + + # UserWarning: The {progress_bar:dict keyword} was deprecated in 0.9.1 and will be removed in 1.0.0 + # Please use self.log(...) inside the lightningModule instead. + + # # log on a step or aggregate epoch metric to the logger and/or progress bar + # # (inside LightningModule) + # self.log('train_loss', loss, on_step=True, on_epoch=True, prog_bar=True) + # warnings.warn(*args, **kwargs) + # UserWarning: The {log:dict keyword} was deprecated in 0.9.1 and will be removed in 1.0.0 + # Please use self.log(...) inside the lightningModule instead. + + # output = { + # 'loss': loss, + # 'log': logger_logs, + # 'progress_bar': {'data_time': data_time} + # } + + for k, v in logger_logs.items(): + if k in ['reward', 'reward_var', 'data_time', 'CLIP-S', 'RefCLIP-S', 'CIDEr', 'grammar_reward']: + self.log('train/'+k, v, prog_bar=True) + else: + self.log('train/'+k, v) + + return loss + + def validation_step(self, data, batch_idx): + model = self.model + crit = self.lw_model.crit + + opt = self.opt + eval_kwargs = {'dataset': opt.input_json} + eval_kwargs.update(vars(opt)) + + # CLIPScore + use_grammar = getattr(self.opt, 'use_grammar', False) + joint_out = getattr(self.opt, 'joint_out', False) + + verbose = eval_kwargs.get('verbose', True) + verbose_beam = eval_kwargs.get('verbose_beam', 0) + verbose_loss = eval_kwargs.get('verbose_loss', 1) + # num_images = eval_kwargs.get('num_images', eval_kwargs.get('val_images_use', -1)) + # lang_eval = eval_kwargs.get('language_eval', 0) + dataset = eval_kwargs.get('dataset', 'coco') + beam_size = eval_kwargs.get('beam_size', 1) + sample_n = eval_kwargs.get('sample_n', 1) + remove_bad_endings = 
eval_kwargs.get('remove_bad_endings', 0) + # Use this nasty way to make other code clean since it's a global configuration + os.environ["REMOVE_BAD_ENDINGS"] = str(remove_bad_endings) + + predictions = [] + n_predictions = [] + + loss = torch.tensor(0) + if data.get('labels', None) is not None and verbose_loss: + # forward the model to get loss + tmp = [data['fc_feats'], data['att_feats'], + data['labels'], data['masks'], data['att_masks']] + fc_feats, att_feats, labels, masks, att_masks = tmp + + loss = crit(model(fc_feats, att_feats, + labels[..., :-1], att_masks), labels[..., 1:], masks[..., 1:]) + + # forward the model to also get generated samples for each image + # Only leave one feature for each image, in case duplicate sample + tmp_eval_kwargs = eval_kwargs.copy() + tmp_eval_kwargs.update({'sample_n': 1}) + seq, seq_logprobs = model( + fc_feats, att_feats, att_masks, opt=tmp_eval_kwargs, mode='sample') + seq = seq.data + entropy = - (F.softmax(seq_logprobs, dim=2) * + seq_logprobs).sum(2).sum(1) / ((seq > 0).to(seq_logprobs).sum(1)+1) + perplexity = - \ + seq_logprobs.gather(2, seq.unsqueeze(2)).squeeze( + 2).sum(1) / ((seq > 0).to(seq_logprobs).sum(1)+1) + + # Print beam search + if beam_size > 1 and verbose_beam: + for i in range(fc_feats.shape[0]): + print('\n'.join([utils.decode_sequence(model.vocab, _[ + 'seq'].unsqueeze(0))[0] for _ in model.done_beams[i]])) + print('--' * 10) + sents = utils.decode_sequence(model.vocab, seq) + + # if self.opt.use_clipscore or os.getenv('EVALUATE', '0') == '1': + # text_feat = self.lw_model.clipscore_model.text_extract(sents) + text_feat = self.val_clipscore_model.text_extract(sents, proj_norm=False) + + text_cont_feat = self.val_clipscore_model.clip_model.text_projection(text_feat) + text_cont_feat = text_cont_feat / text_cont_feat.norm(dim=-1, keepdim=True) + + vis_feat = data['clip_vis_feats'] + # if self.opt.clipscore_mode == 'clip_s': + # clip_s = self.val_clipscore_model(text_feat=text_cont_feat, 
img_feat=vis_feat, mode='clip_s') + + # elif self.opt.clipscore_mode == 'refclip_s': + clip_s = self.val_clipscore_model(text_feat=text_cont_feat, img_feat=vis_feat, mode='clip_s') + # ref_text = utils.decode_sequence(model.vocab, data['gts']) + + gt_indices = torch.arange(0, len(data['gts'])) + data_gts = [data['gts'][_] for _ in gt_indices.tolist()] + + B = len(data_gts) + + gts = [] + gts_valid_mask = [] + max_n_refs = max([len(_gts) for _gts in data_gts]) + for i in range(len(data_gts)): + _gts = utils.decode_sequence(model.vocab, data_gts[i]) + # pad references + n_ref = len(_gts) + _gts.extend([''] * (max_n_refs - n_ref)) + gts.extend(_gts) + gts_valid_mask.extend([1] * n_ref + [0] * (max_n_refs - n_ref)) + assert len(gts) == B * max_n_refs + assert len(gts_valid_mask) == B * max_n_refs + + ref_text = gts + ref_text_mask = gts_valid_mask + + refclip_s = self.val_clipscore_model( + text_feat=text_cont_feat, img_feat=vis_feat, + ref_text=ref_text, ref_text_mask=ref_text_mask, mode='refclip_s') + + # use_grammar = getattr(self.opt, 'use_grammar', False) + # joint_out = getattr(self.opt, 'joint_out', False) + if use_grammar and not joint_out: + with torch.no_grad(): + # grammar_logit = self.val_clipscore_model.grammar_score_head(text_feat.view(-1, 512)) + grammar_logit = self.lw_model.clipscore_model.grammar_score_head(text_feat.view(-1, 512)) + grammar_prob = torch.softmax(grammar_logit, dim=-1)[:, 1] + + + # BERTScore + if next(self.bert_scorer._model.parameters()).device != self.device: + self.bert_scorer._model.to(self.device) + self.bert_scorer.device = self.device + + + # [B*K] -> [B, K] + ref_text_per_example = [] + for i in range(B): + ref_text_list_example = [] + for k in range(max_n_refs): + ref = ref_text[i * max_n_refs + k] + if len(ref) > 0: + ref_text_list_example.append(ref) + # assert len(ref_text_list_example) == max_n_refs + ref_text_per_example.append(ref_text_list_example) + assert len(ref_text_per_example) == B + + P, R, F1 = 
# NOTE: the definitions below are methods of ``LitModel``.

def test_step(self, *args, **kwargs):
    """Run one test batch through the validation logic."""
    return self.validation_step(*args, **kwargs)


def validation_epoch_end(self, outputs, split='val'):
    """Gather the per-batch outputs, aggregate them on the main process and log.

    Args:
        outputs: List of dicts returned by ``validation_step`` (keys
            ``loss``, ``predictions``, ``n_predictions``).
        split: Name used in the saved prediction file and as the prefix of
            every logged metric (``'val'`` or ``'test'``).
    """
    outputs = d2comm.gather(outputs)
    # master node
    if d2comm.is_main_process():
        assert self.trainer.node_rank == 0 and self.trainer.local_rank == 0
        outputs = sum(outputs, [])

        opt = self.opt
        val_loss_mean = sum([_['loss'].cpu() for _ in outputs]) / len(outputs)

        predictions = sum([_['predictions'] for _ in outputs], [])
        if len(outputs[0]['n_predictions']) != 0:
            n_predictions = sum([_['n_predictions'] for _ in outputs], [])
        else:
            n_predictions = []

        if len(n_predictions) > 0 and 'perplexity' in n_predictions[0]:
            n_predictions = sorted(
                n_predictions, key=lambda x: x['perplexity'])

        os.makedirs('eval_results', exist_ok=True)
        torch.save((predictions, n_predictions), os.path.join(
            'eval_results/', '.saved_pred_' + opt.id + '_' + split + '.pth'))

        lang_stats = None
        if opt.language_eval:
            lang_stats = eval_utils.language_eval(
                opt.input_json, predictions, n_predictions, vars(opt), split)

        if opt.reduce_on_plateau:
            optimizer = self.trainer.optimizers[0]
            # lang_stats stays None when language_eval is off; fall back to
            # the validation loss instead of raising on `in None`.
            if lang_stats is not None and 'CIDEr' in lang_stats:
                optimizer.scheduler_step(-lang_stats['CIDEr'])
            else:
                optimizer.scheduler_step(val_loss_mean)

        out = {
            'loss': val_loss_mean
        }
        if lang_stats is not None:
            out.update(lang_stats)

        # Average each optional per-caption metric that validation_step
        # attached to the prediction entries; absent metrics are skipped.
        for metric in ('CLIP-S', 'RefCLIP-S', 'grammar_prob', 'BERT-S'):
            if predictions and metric in predictions[0]:
                out[metric] = sum([p[metric] for p in predictions]) / len(predictions)
                print(metric, out[metric])
    else:
        out = {}

    out = d2comm.all_gather(out)[0]  # Only the one from master node
    assert len(out) > 0  # make sure the head has index 0

    # must all be tensors
    out = {k: torch.tensor(v) if not torch.is_tensor(v) else v
           for k, v in out.items()}

    for k, v in out.items():
        self.log(f'{split}/{k}', v)


def test_epoch_end(self, outputs):
    """Aggregate and log test outputs under the ``test/`` prefix."""
    self.validation_epoch_end(outputs, 'test')


def configure_optimizers(self):
    """Build the optimizer selected by ``opt`` for the trainable parameters.

    Returns:
        ``([optimizer], [])``: one optimizer and no schedulers.
    """
    opt = self.opt
    model = self.model

    parameters = [p for p in model.parameters() if p.requires_grad]

    if opt.noamopt:
        optimizer = utils.get_std_opt(
            model, optim_func=opt.optim, factor=opt.noamopt_factor, warmup=opt.noamopt_warmup)
    elif opt.reduce_on_plateau:
        optimizer = utils.build_optimizer(parameters, opt)
        optimizer = utils.ReduceLROnPlateau(optimizer,
                                            factor=opt.reduce_on_plateau_factor,
                                            patience=opt.reduce_on_plateau_patience)
    else:
        optimizer = utils.build_optimizer(parameters, opt)
    return [optimizer], []


def optimizer_step(self, epoch, batch_idx, optimizer,
                   optimizer_idx, *args, **kwargs):
    """Apply linear learning-rate warm-up, then delegate to the parent step."""
    # warm up lr
    opt = self.opt
    iteration = self.trainer.global_step
    if opt.use_warmup and (iteration < opt.noamopt_warmup):
        opt.current_lr = opt.learning_rate * \
            (iteration + 1) / opt.noamopt_warmup
        utils.set_lr(optimizer, opt.current_lr)

    super().optimizer_step(epoch, batch_idx, optimizer,
                           optimizer_idx, *args, **kwargs)


def state_dict(self):
    """
    Save the model state dict as well as opt and vocab
    """
    state_dict = self.model.state_dict()
    device = next(iter(state_dict.values())).device
    assert '_vocab' not in state_dict and '_opt' not in state_dict, 'Just in case'
    state_dict.update({
        '_vocab': utils.serialize_to_tensor(self.model.vocab).to(device),
        '_opt': utils.serialize_to_tensor(self.opt).to(device)
    })
    return state_dict


def load_state_dict(self, state_dict=None, strict=True):
    """Restore weights, plus the vocab and option check stored by ``state_dict``.

    The ``_vocab`` and ``_opt`` entries are consumed here and the remaining
    entries are loaded into ``self.model``.  The caller's mapping is not
    modified.

    Returns:
        Whatever ``self.model.load_state_dict`` returns.
    """
    import copy

    # Shallow copy: keeps the mapping type and attributes, but lets us drop
    # the extra entries without touching the caller's object.
    state_dict = copy.copy(state_dict)

    if '_vocab' in state_dict:
        self.model.vocab = utils.deserialize(state_dict.pop('_vocab'))
    if '_opt' in state_dict:
        saved_model_opt = utils.deserialize(state_dict.pop('_opt'))
        opt = self.opt
        # Make sure the saved opt is compatible with the current opt
        need_be_same = ["caption_model",
                        "rnn_type", "rnn_size", "num_layers"]
        for checkme in need_be_same:
            # 'updown' and 'topdown' are accepted as interchangeable.
            if getattr(saved_model_opt, checkme) in ['updown', 'topdown'] and \
                    getattr(opt, checkme) in ['updown', 'topdown']:
                continue
            assert getattr(saved_model_opt, checkme) == getattr(
                opt, checkme), "Command line argument and saved model disagree on '%s' " % checkme
    return self.model.load_state_dict(state_dict, strict)
trainer, pl_module): + # Update lr/training stage/scheduled sampling prob etc. + opt = pl_module.opt + model = pl_module.model + epoch = trainer.current_epoch + optimizer = trainer.optimizers[0] + + if not opt.noamopt and not opt.reduce_on_plateau: + # Assign the learning rate + if epoch > opt.learning_rate_decay_start and opt.learning_rate_decay_start >= 0: + frac = ( + epoch - opt.learning_rate_decay_start) // opt.learning_rate_decay_every + decay_factor = opt.learning_rate_decay_rate ** frac + opt.current_lr = opt.learning_rate * decay_factor + else: + opt.current_lr = opt.learning_rate + utils.set_lr(optimizer, opt.current_lr) # set the decayed rate + # Assign the scheduled sampling prob + if epoch > opt.scheduled_sampling_start and opt.scheduled_sampling_start >= 0: + frac = ( + epoch - opt.scheduled_sampling_start) // opt.scheduled_sampling_increase_every + opt.ss_prob = min(opt.scheduled_sampling_increase_prob * + frac, opt.scheduled_sampling_max_prob) + model.ss_prob = opt.ss_prob + + # If start self critical training + if opt.self_critical_after != -1 and epoch >= opt.self_critical_after: + sc_flag = True + init_scorer(opt.cached_tokens) + else: + sc_flag = False + + # If start structure loss training + if opt.structure_after != -1 and epoch >= opt.structure_after: + struc_flag = True + init_scorer(opt.cached_tokens) + else: + struc_flag = False + + pl_module.struc_flag = struc_flag + pl_module.sc_flag = sc_flag + + +class ModelCheckpoint(pl.callbacks.ModelCheckpoint): + + def on_keyboard_interrupt(self, trainer, pl_module): + # Save model when keyboard interrupt + filepath = os.path.join(self.dirpath, self.prefix + 'interrupt.ckpt') + self._save_model(filepath) + + +opt = opts.parse_opt() + +checkpoint_callback = ModelCheckpoint( + filepath=opt.checkpoint_path, + # dirpath=opt.checkpoint_path, + save_last=True, + save_top_k=1, + verbose=True, + # monitor='to_monitor', + # monitor='val/to_monitor', + monitor='val/CIDEr', + mode='max', + # 
prefix=opt.id+'_', + prefix=opt.id, + # filename=f'{opt.id}_', +) + +verbose = True +# import torch +# if torch.cuda.current_device() in [0, -1]: +if 'LOCAL_RANK' in os.environ and os.environ['LOCAL_RANK'] != '0': + verbose = False + +if verbose: + print(opt) + print(""" + val_image_use, + save_checkpoint_very + save_every_epoch, + save_history-ckpt will be ignored. + """) + +# Lightning defines batch size as batch size per gpu +assert opt.batch_size % torch.cuda.device_count() == 0 +opt.batch_size = opt.batch_size // torch.cuda.device_count() + +# If resume from last checkpoint +# if opt.start_from is not None and os.path.isfile(os.path.join(opt.start_from, f'{opt.id}_last.ckpt')): +# resume_from = os.path.join(opt.start_from, f'{opt.id}_last.ckpt') +if opt.start_from is not None: + resume_from = os.path.join(opt.start_from, f'{opt.id}-last.ckpt') + if os.path.isfile(resume_from): + if verbose: + print('Loading checkpoint from', resume_from) + else: + print("Checkpoint not found:", resume_from) + resume_from = None +else: + resume_from = None + +from pytorch_lightning.loggers import WandbLogger +wandb_logger = WandbLogger( + project='CLIP-ViL-COCOCaption', + name=opt.id, +) + +if verbose: + wandb_logger.experiment.config.update(opt) + from pathlib import Path + import glob + import wandb + # src_dir = Path(__file__).resolve().parent.parent + glob_str = "**/*.py" + base_path = './' + wandb.save(glob_str=glob_str, base_path=base_path) + + # code = wandb.Artifact('project-source', type='code') + # for path in glob.glob('**/*.py', recursive=True): + # code.add_file(path, name='source/'+path) + # print(path) + # wandb.run.use_artifact(code) + + + + +lit = LitModel(opt) +# warning grad_clip_mode is ignored. 
# Collect the Trainer options in one place so the constructor call stays short.
_trainer_options = dict(
    callbacks=[
        OnEpochStartCallback(),
        pl.callbacks.LearningRateMonitor(),
    ],
    default_root_dir=opt.checkpoint_path,
    resume_from_checkpoint=resume_from,
    distributed_backend='ddp',
    check_val_every_n_epoch=1,
    max_epochs=opt.max_epochs,
    gradient_clip_val=opt.grad_clip_value,
    gpus=torch.cuda.device_count(),
    checkpoint_callback=checkpoint_callback,
    log_gpu_memory='min_max',
    log_every_n_steps=opt.losses_log_every,
    profiler=True,
    flush_logs_every_n_steps=10,
    num_sanity_val_steps=0,
    precision=opt.precision,
    logger=wandb_logger,
)
trainer = pl.Trainer(**_trainer_options)

# EVALUATE=1 in the environment runs the test loop; anything else trains.
_evaluate_only = os.getenv('EVALUATE', '0') == '1'
_run_stage = trainer.test if _evaluate_only else trainer.fit
_run_stage(lit)