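"""DataLoaderRaw: reads raw images from a folder (or from a coco-style
json), pushes each image through a pretrained ResNet, and serves the
resulting fc/att features one batch at a time for captioning evaluation."""
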
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import json
import os

import numpy as np
import skimage.io
import torch
from torchvision import transforms as trn

from captioning.utils.resnet_utils import myResnet
from captioning.utils import resnet

# Images are scaled to [0, 1] CHW float tensors by hand in get_batch, so
# only the ImageNet mean/std normalization is needed here (no ToTensor).
preprocess = trn.Compose([
    trn.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])

class DataLoaderRaw():

    def __init__(self, opt):
        self.opt = opt
        self.coco_json = opt.get('coco_json', '')
        self.folder_path = opt.get('folder_path', '')

        self.batch_size = opt.get('batch_size', 1)
        self.seq_per_img = 1

        # Load resnet
        self.cnn_model = opt.get('cnn_model', 'resnet101')
        self.my_resnet = getattr(resnet, self.cnn_model)()
        self.my_resnet.load_state_dict(
            torch.load('./data/imagenet_weights/' + self.cnn_model + '.pth'))
        self.my_resnet = myResnet(self.my_resnet)
        self.my_resnet.cuda()
        self.my_resnet.eval()

        # Load the json file which contains additional information about the dataset.
        print('DataLoaderRaw loading images from folder: ', self.folder_path)

        self.files = []
        self.ids = []

        if len(self.coco_json) > 0:
            print('reading from ' + self.coco_json)
            # Read in filenames from the coco-style json file.
            self.coco_annotation = json.load(open(self.coco_json))
            for k, v in enumerate(self.coco_annotation['images']):
                fullpath = os.path.join(self.folder_path, v['file_name'])
                self.files.append(fullpath)
                self.ids.append(v['id'])
        else:
            # Read in all the filenames from the folder.
            print('listing all images in directory ' + self.folder_path)

            def isImage(f):
                supportedExt = ['.jpg', '.JPG', '.jpeg', '.JPEG',
                                '.png', '.PNG', '.ppm', '.PPM']
                return any(f.endswith(ext) for ext in supportedExt)

            n = 1
            for root, dirs, files in os.walk(self.folder_path, topdown=False):
                for file in files:
                    # Join against root (not folder_path) so files found in
                    # subdirectories resolve to valid paths.
                    fullpath = os.path.join(root, file)
                    if isImage(fullpath):
                        self.files.append(fullpath)
                        self.ids.append(str(n))  # just order them sequentially
                        n = n + 1

        self.N = len(self.files)
        print('DataLoaderRaw found ', self.N, ' images')

        self.iterator = 0

        # The eval code reads loader.dataset, so point it back at self.
        self.dataset = self

    def get_batch(self, split, batch_size=None):
        batch_size = batch_size or self.batch_size

        fc_batch = np.ndarray((batch_size, 2048), dtype='float32')
        att_batch = np.ndarray((batch_size, 14, 14, 2048), dtype='float32')
        max_index = self.N
        wrapped = False
        infos = []

        for i in range(batch_size):
            # Pick the index of the datapoint to load next, wrapping back
            # around to the start of the image list when we run off the end.
            ri = self.iterator
            ri_next = ri + 1
            if ri_next >= max_index:
                ri_next = 0
                wrapped = True
            self.iterator = ri_next

            img = skimage.io.imread(self.files[ri])

            # Grayscale -> 3 channels; drop any alpha channel; scale to [0, 1].
            if len(img.shape) == 2:
                img = img[:, :, np.newaxis]
                img = np.concatenate((img, img, img), axis=2)
            img = img[:, :, :3].astype('float32') / 255.0

            img = torch.from_numpy(img.transpose([2, 0, 1])).cuda()
            img = preprocess(img)
            with torch.no_grad():
                tmp_fc, tmp_att = self.my_resnet(img)

            fc_batch[i] = tmp_fc.data.cpu().float().numpy()
            att_batch[i] = tmp_att.data.cpu().float().numpy()

            info_struct = {}
            info_struct['id'] = self.ids[ri]
            info_struct['file_path'] = self.files[ri]
            infos.append(info_struct)

        data = {}
        data['fc_feats'] = fc_batch
        data['att_feats'] = att_batch.reshape(batch_size, -1, 2048)
        data['labels'] = np.zeros([batch_size, 0])
        data['masks'] = None
        data['att_masks'] = None
        data['bounds'] = {'it_pos_now': self.iterator,
                          'it_max': self.N, 'wrapped': wrapped}
        data['infos'] = infos

        # Turn every ndarray into a torch tensor; leave everything else as-is.
        data = {k: torch.from_numpy(v) if type(v) is np.ndarray else v
                for k, v in data.items()}

        return data

    def reset_iterator(self, split):
        self.iterator = 0

    # ix_to_word is never set in this class; it is expected to be attached
    # externally (e.g. from a trained model's infos) before these are called.
    def get_vocab_size(self):
        return len(self.ix_to_word)

    def get_vocab(self):
        return self.ix_to_word
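
# A minimal usage sketch, assuming a hypothetical image folder and pretrained
# weights at ./data/imagenet_weights/resnet101.pth; a CUDA device is required.
# Shapes follow from the code above: fc is 2048-d per image, att is
# 14*14 = 196 spatial positions of 2048-d each.
if __name__ == '__main__':
    loader = DataLoaderRaw({'folder_path': './vis/imgs',  # hypothetical path
                            'batch_size': 2,
                            'cnn_model': 'resnet101'})
    data = loader.get_batch('test')
    print(data['fc_feats'].shape)   # torch.Size([2, 2048])
    print(data['att_feats'].shape)  # torch.Size([2, 196, 2048])
    print([info['file_path'] for info in data['infos']])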