from __future__ import absolute_import from __future__ import division from __future__ import print_function import json import h5py from lmdbdict import lmdbdict from lmdbdict.methods import DUMPS_FUNC, LOADS_FUNC import os import numpy as np import numpy.random as npr import random from functools import partial import torch import torch.utils.data as data import multiprocessing import six class HybridLoader: """ If db_path is a director, then use normal file loading If lmdb, then load from lmdb The loading method depend on extention. in_memory: if in_memory is True, we save all the features in memory For individual np(y|z)s, we don't need to do that because the system will do this for us. Should be useful for lmdb or h5. (Copied this idea from vilbert) """ def __init__(self, db_path, ext, in_memory=False): self.db_path = db_path self.ext = ext if self.ext == '.npy': self.loader = lambda x: np.load(six.BytesIO(x)) else: def load_npz(x): x = np.load(six.BytesIO(x)) return x['feat'] if 'feat' in x else x['z'] # normally it should be 'feat', but under cocotest_bu, the key is saved to be 'z' mistakenly. self.loader = load_npz if db_path.endswith('.lmdb'): self.db_type = 'lmdb' self.lmdb = lmdbdict(db_path, unsafe=True) self.lmdb._key_dumps = DUMPS_FUNC['ascii'] self.lmdb._value_loads = LOADS_FUNC['identity'] elif db_path.endswith('.pth'): # Assume a key,value dictionary self.db_type = 'pth' self.feat_file = torch.load(db_path) self.loader = lambda x: x print('HybridLoader: ext is ignored') elif db_path.endswith('h5'): self.db_type = 'h5' self.loader = lambda x: np.array(x).astype('float32') else: self.db_type = 'dir' self.in_memory = in_memory if self.in_memory: self.features = {} def get(self, key): if self.in_memory and key in self.features: # We save f_input because we want to save the # compressed bytes to save memory f_input = self.features[key] elif self.db_type == 'lmdb': f_input = self.lmdb[key] elif self.db_type == 'pth': f_input = self.feat_file[key] elif self.db_type == 'h5': f_input = h5py.File(self.db_path, 'r')[key] else: f_input = open(os.path.join(self.db_path, key + self.ext), 'rb').read() if self.in_memory and key not in self.features: self.features[key] = f_input # load image feat = self.loader(f_input) return feat class Dataset(data.Dataset): def get_vocab_size(self): return self.vocab_size def get_vocab(self): return self.ix_to_word def get_seq_length(self): return self.seq_length def __init__(self, opt): self.opt = opt self.seq_per_img = opt.seq_per_img # feature related options self.use_fc = getattr(opt, 'use_fc', True) self.use_att = getattr(opt, 'use_att', True) self.use_box = getattr(opt, 'use_box', 0) self.norm_att_feat = getattr(opt, 'norm_att_feat', 0) self.norm_box_feat = getattr(opt, 'norm_box_feat', 0) # load the json file which contains additional information about the dataset print('DataLoader loading json file: ', opt.input_json) self.info = json.load(open(self.opt.input_json)) if 'ix_to_word' in self.info: self.ix_to_word = self.info['ix_to_word'] self.vocab_size = len(self.ix_to_word) print('vocab size is ', self.vocab_size) # open the hdf5 file print('DataLoader loading h5 file: ', opt.input_fc_dir, opt.input_att_dir, opt.input_box_dir, opt.input_label_h5) """ Setting input_label_h5 to none is used when only doing generation. For example, when you need to test on coco test set. """ if self.opt.input_label_h5 != 'none': self.h5_label_file = h5py.File(self.opt.input_label_h5, 'r', driver='core') # load in the sequence data seq_size = self.h5_label_file['labels'].shape self.label = self.h5_label_file['labels'][:] self.seq_length = seq_size[1] print('max sequence length in data is', self.seq_length) # load the pointers in full to RAM (should be small enough) self.label_start_ix = self.h5_label_file['label_start_ix'][:] self.label_end_ix = self.h5_label_file['label_end_ix'][:] else: self.seq_length = 1 self.data_in_memory = getattr(opt, 'data_in_memory', False) self.fc_loader = HybridLoader(self.opt.input_fc_dir, '.npy', in_memory=self.data_in_memory) self.att_loader = HybridLoader(self.opt.input_att_dir, '.npz', in_memory=self.data_in_memory) self.box_loader = HybridLoader(self.opt.input_box_dir, '.npy', in_memory=self.data_in_memory) self.num_images = len(self.info['images']) # self.label_start_ix.shape[0] print('read %d image features' %(self.num_images)) # separate out indexes for each of the provided splits self.split_ix = {'train': [], 'val': [], 'test': []} for ix in range(len(self.info['images'])): img = self.info['images'][ix] if not 'split' in img: self.split_ix['train'].append(ix) self.split_ix['val'].append(ix) self.split_ix['test'].append(ix) elif img['split'] == 'train': self.split_ix['train'].append(ix) elif img['split'] == 'val': self.split_ix['val'].append(ix) elif img['split'] == 'test': self.split_ix['test'].append(ix) elif opt.train_only == 0: # restval self.split_ix['train'].append(ix) print('assigned %d images to split train' %len(self.split_ix['train'])) print('assigned %d images to split val' %len(self.split_ix['val'])) print('assigned %d images to split test' %len(self.split_ix['test'])) def get_captions(self, ix, seq_per_img): # fetch the sequence labels ix1 = self.label_start_ix[ix] - 1 #label_start_ix starts from 1 ix2 = self.label_end_ix[ix] - 1 ncap = ix2 - ix1 + 1 # number of captions available for this image assert ncap > 0, 'an image does not have any label. this can be handled but right now isn\'t' random.seed(42) torch.manual_seed(42) if torch.cuda.is_available(): torch.cuda.manual_seed(42) if ncap < seq_per_img: # we need to subsample (with replacement) seq = np.zeros([seq_per_img, self.seq_length], dtype = 'int') for q in range(seq_per_img): ixl = random.randint(ix1,ix2) seq[q, :] = self.label[ixl, :self.seq_length] else: ixl = random.randint(ix1, ix2 - seq_per_img + 1) seq = self.label[ixl: ixl + seq_per_img, :self.seq_length] return seq def collate_func(self, batch, split): seq_per_img = self.seq_per_img fc_batch = [] att_batch = [] label_batch = [] wrapped = False infos = [] gts = [] for sample in batch: # fetch image tmp_fc, tmp_att, tmp_seq, \ ix, it_pos_now, tmp_wrapped = sample if tmp_wrapped: wrapped = True fc_batch.append(tmp_fc) att_batch.append(tmp_att) tmp_label = np.zeros([seq_per_img, self.seq_length + 2], dtype = 'int') if hasattr(self, 'h5_label_file'): # if there is ground truth tmp_label[:, 1 : self.seq_length + 1] = tmp_seq label_batch.append(tmp_label) # Used for reward evaluation if hasattr(self, 'h5_label_file'): # if there is ground truth gts.append(self.label[self.label_start_ix[ix] - 1: self.label_end_ix[ix]]) else: gts.append([]) # record associated info as well info_dict = {} info_dict['ix'] = ix info_dict['id'] = self.info['images'][ix]['id'] info_dict['file_path'] = self.info['images'][ix].get('file_path', '') infos.append(info_dict) # #sort by att_feat length # fc_batch, att_batch, label_batch, gts, infos = \ # zip(*sorted(zip(fc_batch, att_batch, np.vsplit(label_batch, batch_size), gts, infos), key=lambda x: len(x[1]), reverse=True)) fc_batch, att_batch, label_batch, gts, infos = \ zip(*sorted(zip(fc_batch, att_batch, label_batch, gts, infos), key=lambda x: 0, reverse=True)) data = {} data['fc_feats'] = np.stack(fc_batch) # merge att_feats max_att_len = max([_.shape[0] for _ in att_batch]) data['att_feats'] = np.zeros([len(att_batch), max_att_len, att_batch[0].shape[1]], dtype = 'float32') for i in range(len(att_batch)): data['att_feats'][i, :att_batch[i].shape[0]] = att_batch[i] data['att_masks'] = np.zeros(data['att_feats'].shape[:2], dtype='float32') for i in range(len(att_batch)): data['att_masks'][i, :att_batch[i].shape[0]] = 1 # set att_masks to None if attention features have same length if data['att_masks'].sum() == data['att_masks'].size: data['att_masks'] = None data['labels'] = np.vstack(label_batch) # generate mask nonzeros = np.array(list(map(lambda x: (x != 0).sum()+2, data['labels']))) mask_batch = np.zeros([data['labels'].shape[0], self.seq_length + 2], dtype = 'float32') for ix, row in enumerate(mask_batch): row[:nonzeros[ix]] = 1 data['masks'] = mask_batch data['labels'] = data['labels'].reshape(len(batch), seq_per_img, -1) data['masks'] = data['masks'].reshape(len(batch), seq_per_img, -1) data['gts'] = gts # all ground truth captions of each images data['bounds'] = {'it_pos_now': it_pos_now, # the it_pos_now of the last sample 'it_max': len(self.split_ix[split]), 'wrapped': wrapped} data['infos'] = infos data = {k:torch.from_numpy(v) if type(v) is np.ndarray else v for k,v in data.items()} # Turn all ndarray to torch tensor return data def __getitem__(self, index): """This function returns a tuple that is further passed to collate_fn """ ix, it_pos_now, wrapped = index #self.split_ix[index] if self.use_att: att_feat = self.att_loader.get(str(self.info['images'][ix]['id'])) # shape: (14,14,4096) # Reshape to K x C att_feat = att_feat.reshape(-1, att_feat.shape[-1]) # shape:(196,4096) if self.norm_att_feat: att_feat = att_feat / np.linalg.norm(att_feat, 2, 1, keepdims=True) if self.use_box: box_feat = self.box_loader.get(str(self.info['images'][ix]['id'])) # devided by image width and height x1,y1,x2,y2 = np.hsplit(box_feat, 4) h,w = self.info['images'][ix]['height'], self.info['images'][ix]['width'] box_feat = np.hstack((x1/w, y1/h, x2/w, y2/h, (x2-x1)*(y2-y1)/(w*h))) # question? x2-x1+1?? if self.norm_box_feat: box_feat = box_feat / np.linalg.norm(box_feat, 2, 1, keepdims=True) att_feat = np.hstack([att_feat, box_feat]) # sort the features by the size of boxes att_feat = np.stack(sorted(att_feat, key=lambda x:x[-1], reverse=True)) else: att_feat = np.zeros((0,0), dtype='float32') if self.use_fc: try: fc_feat = self.fc_loader.get(str(self.info['images'][ix]['id'])) except: # Use average of attention when there is no fc provided (For bottomup feature) fc_feat = att_feat.mean(0) else: fc_feat = np.zeros((0), dtype='float32') if hasattr(self, 'h5_label_file'): seq = self.get_captions(ix, self.seq_per_img) else: seq = None return (fc_feat, att_feat, seq, ix, it_pos_now, wrapped) def __len__(self): return len(self.info['images']) class DataLoader: def __init__(self, opt): self.opt = opt self.batch_size = self.opt.batch_size self.dataset = Dataset(opt) # Initialize loaders and iters self.loaders, self.iters = {}, {} for split in ['train', 'val', 'test']: if split == 'train': sampler = MySampler(self.dataset.split_ix[split], shuffle=True, wrap=True) else: sampler = MySampler(self.dataset.split_ix[split], shuffle=False, wrap=False) self.loaders[split] = data.DataLoader(dataset=self.dataset, batch_size=self.batch_size, sampler=sampler, pin_memory=True, num_workers=4, # 4 is usually enough collate_fn=partial(self.dataset.collate_func, split=split), drop_last=False) self.iters[split] = iter(self.loaders[split]) def get_batch(self, split): try: data = next(self.iters[split]) except StopIteration: self.iters[split] = iter(self.loaders[split]) data = next(self.iters[split]) return data def reset_iterator(self, split): self.loaders[split].sampler._reset_iter() self.iters[split] = iter(self.loaders[split]) def get_vocab_size(self): return self.dataset.get_vocab_size() @property def vocab_size(self): return self.get_vocab_size() def get_vocab(self): return self.dataset.get_vocab() def get_seq_length(self): return self.dataset.get_seq_length() @property def seq_length(self): return self.get_seq_length() def state_dict(self): def get_prefetch_num(split): if self.loaders[split].num_workers > 0: return (self.iters[split]._send_idx - self.iters[split]._rcvd_idx) * self.batch_size else: return 0 return {split: loader.sampler.state_dict(get_prefetch_num(split)) \ for split, loader in self.loaders.items()} def load_state_dict(self, state_dict=None): if state_dict is None: return for split in self.loaders.keys(): self.loaders[split].sampler.load_state_dict(state_dict[split]) class MySampler(data.sampler.Sampler): def __init__(self, index_list, shuffle, wrap): self.index_list = index_list self.shuffle = shuffle self.wrap = wrap # if wrap, there will be not stop iteration called # wrap True used during training, and wrap False used during test. self._reset_iter() def __iter__(self): return self def __next__(self): wrapped = False if self.iter_counter == len(self._index_list): self._reset_iter() if self.wrap: wrapped = True else: raise StopIteration() if len(self._index_list) == 0: # overflow when 0 samples return None elem = (self._index_list[self.iter_counter], self.iter_counter+1, wrapped) self.iter_counter += 1 return elem def next(self): return self.__next__() def _reset_iter(self): np.random.seed(42) if self.shuffle: rand_perm = npr.permutation(len(self.index_list)) self._index_list = [self.index_list[_] for _ in rand_perm] else: self._index_list = self.index_list self.iter_counter = 0 def __len__(self): return len(self.index_list) def load_state_dict(self, state_dict=None): if state_dict is None: return self._index_list = state_dict['index_list'] self.iter_counter = state_dict['iter_counter'] def state_dict(self, prefetched_num=None): prefetched_num = prefetched_num or 0 return { 'index_list': self._index_list, 'iter_counter': self.iter_counter - prefetched_num }