File size: 5,129 Bytes
d290c84 1162512 d290c84 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 |
import torch
from torch.utils.data import Dataset
from PIL import Image
import os
import json
from build_vocab import Vocabulary, JsonReader
import numpy as np
from torchvision import transforms
import pickle
class ChestXrayDataSet(Dataset):
    """Dataset yielding chest X-ray images with labels and tokenized reports.

    Each non-blank line of ``file_list`` holds an image name (without
    extension) followed by whitespace-separated integer labels.  The report
    text of an image is looked up by its file name (``<name>.png``) in the
    object returned by ``JsonReader(caption_json)``.

    Args:
        image_dir: Directory containing the ``.png`` images.
        caption_json: Path handed to ``JsonReader`` to load the reports.
        file_list: Path to the text file listing image names and labels.
        vocabulary: Callable mapping a token string to its id; it is also
            called with the special tokens ``'<start>'`` and ``'<end>'``.
        s_max: Only the first ``s_max`` ``'. '``-separated pieces of a report
            are considered (pieces that are later skipped still count).
        n_max: Sentences with more than ``n_max`` words are skipped.
        transforms: Optional callable applied to the loaded RGB PIL image.
    """

    def __init__(self,
                 image_dir,
                 caption_json,
                 file_list,
                 vocabulary,
                 s_max=10,
                 n_max=50,
                 transforms=None):
        self.image_dir = image_dir
        self.caption = JsonReader(caption_json)
        self.file_names, self.labels = self.__load_label_list(file_list)
        self.vocab = vocabulary
        self.transform = transforms
        self.s_max = s_max
        self.n_max = n_max

    def __load_label_list(self, file_list):
        """Parse the label file into parallel lists.

        Args:
            file_list: Path to a text file whose lines look like
                ``<image_name> <int> <int> ...``.

        Returns:
            ``(filename_list, labels)`` where ``filename_list[i]`` is
            ``'<image_name>.png'`` and ``labels[i]`` is the list of ints
            found on the same line.

        Raises:
            ValueError: If a label field is not an integer literal.
        """
        labels = []
        filename_list = []
        with open(file_list, 'r') as f:
            for line in f:
                items = line.split()
                if not items:
                    # Blank or whitespace-only line (e.g. a trailing newline):
                    # there is no image name to read, so skip it instead of
                    # failing on items[0].
                    continue
                filename_list.append('{}.png'.format(items[0]))
                labels.append([int(i) for i in items[1:]])
        return filename_list, labels

    @staticmethod
    def _normalize_label(label):
        """Scale ``label`` so its entries sum to one.

        A label whose entries sum to zero is returned as all zeros rather
        than being divided by zero (which would produce NaN values).

        Returns:
            A list of NumPy floats with the same length as ``label``.
        """
        values = np.asarray(label, dtype=float)
        total = values.sum()
        if total == 0:
            return list(np.zeros_like(values))
        return list(values / total)

    def _encode_report(self, text):
        """Convert report text into per-sentence lists of token ids.

        The text is split on ``'. '``.  Splitting stops once the piece index
        reaches ``s_max``; pieces with fewer than two words or more than
        ``n_max`` words are skipped.  Every kept sentence is wrapped in the
        ``'<start>'`` / ``'<end>'`` tokens.

        Returns:
            ``(target, max_word_num)``: the list of encoded sentences and
            the length of the longest one (0 when no sentence was kept).
        """
        target = []
        max_word_num = 0
        for i, sentence in enumerate(text.split('. ')):
            if i >= self.s_max:
                break
            words = sentence.split()
            if len(words) < 2 or len(words) > self.n_max:
                continue
            tokens = [self.vocab('<start>')]
            tokens.extend(self.vocab(word) for word in words)
            tokens.append(self.vocab('<end>'))
            max_word_num = max(max_word_num, len(tokens))
            target.append(tokens)
        return target, max_word_num

    def __getitem__(self, index):
        """Load one sample.

        Returns:
            A tuple ``(image, image_name, label, target, sentence_num,
            max_word_num)`` where ``image`` is the (optionally transformed)
            RGB image, ``label`` is the label vector normalized to sum to
            one, ``target`` is the list of encoded sentences,
            ``sentence_num`` is ``len(target)`` and ``max_word_num`` is the
            length of the longest encoded sentence.
        """
        image_name = self.file_names[index]
        # Use a context manager so the underlying file is closed right after
        # the pixel data has been copied into the converted image.
        with Image.open(os.path.join(self.image_dir, image_name)) as raw:
            image = raw.convert('RGB')
        label = self.labels[index]
        if self.transform is not None:
            image = self.transform(image)
        try:
            text = self.caption[image_name]
        except Exception:
            # Best-effort fallback: images without a usable report get a
            # placeholder text.
            text = 'normal. '
        target, max_word_num = self._encode_report(text)
        sentence_num = len(target)
        return (image, image_name, self._normalize_label(label), target,
                sentence_num, max_word_num)

    def __len__(self):
        """Return the number of images listed in the label file."""
        return len(self.file_names)
def collate_fn(data):
    """Merge a list of dataset samples into one padded batch.

    Args:
        data: Sequence of 6-tuples ``(image, image_id, label, captions,
            sentence_num, max_word_num)``.

    Returns:
        ``(images, image_id, label, targets, prob)`` where ``images`` is the
        stacked image tensor, ``image_id`` is the tuple of ids, ``label`` is
        a ``torch.Tensor`` built from the per-sample labels, ``targets`` is a
        zero-padded NumPy array of shape
        ``(batch, max_sentence_num + 1, max_word_num)`` holding the token
        ids, and ``prob`` is a NumPy array of shape
        ``(batch, max_sentence_num + 1)`` set to 1 at every position that
        holds a non-empty sentence.
    """
    images, image_id, label, captions, sentence_num, max_word_num = zip(*data)
    images = torch.stack(images, 0)

    batch_size = len(captions)
    # One extra sentence slot is kept beyond the longest report in the batch.
    n_rows = max(sentence_num) + 1
    n_cols = max(max_word_num)

    targets = np.zeros((batch_size, n_rows, n_cols))
    prob = np.zeros((batch_size, n_rows))
    for sample_idx, report in enumerate(captions):
        for row, token_ids in enumerate(report):
            width = len(token_ids)
            targets[sample_idx, row, :width] = token_ids
            prob[sample_idx, row] = width > 0

    return images, image_id, torch.Tensor(label), targets, prob
def get_loader(image_dir,
               caption_json,
               file_list,
               vocabulary,
               transform,
               batch_size,
               s_max=10,
               n_max=50,
               shuffle=False):
    """Build a ``DataLoader`` over a ``ChestXrayDataSet``.

    Args:
        image_dir: Directory containing the images.
        caption_json: Path to the caption JSON file.
        file_list: Path to the image/label list file.
        vocabulary: Vocabulary object passed to the dataset.
        transform: Image transform passed to the dataset.
        batch_size: Number of samples per batch.
        s_max: Sentence limit forwarded to the dataset.
        n_max: Word limit forwarded to the dataset.
        shuffle: Whether the loader shuffles the samples.

    Returns:
        A ``torch.utils.data.DataLoader`` that batches with ``collate_fn``.
    """
    xray_dataset = ChestXrayDataSet(
        image_dir=image_dir,
        caption_json=caption_json,
        file_list=file_list,
        vocabulary=vocabulary,
        s_max=s_max,
        n_max=n_max,
        transforms=transform,
    )
    return torch.utils.data.DataLoader(
        dataset=xray_dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        collate_fn=collate_fn,
    )
if __name__ == '__main__':
    # Manual smoke test: build a loader over the debugging split and print
    # the contents of the first batch.
    vocab_path = '../data/vocab.pkl'
    image_dir = '../data/images'
    caption_json = '../data/debugging_captions.json'
    file_list = '../data/debugging.txt'
    batch_size = 6
    resize = 256
    crop_size = 224

    preprocessing_steps = [
        transforms.Resize(resize),
        transforms.RandomCrop(crop_size),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        # NOTE(review): these look like the usual ImageNet channel
        # mean/std values -- confirm against the model being trained.
        transforms.Normalize((0.485, 0.456, 0.406),
                             (0.229, 0.224, 0.225)),
    ]
    transform = transforms.Compose(preprocessing_steps)

    with open(vocab_path, 'rb') as f:
        vocab = pickle.load(f)

    data_loader = get_loader(image_dir=image_dir,
                             caption_json=caption_json,
                             file_list=file_list,
                             vocabulary=vocab,
                             transform=transform,
                             batch_size=batch_size,
                             shuffle=False)

    # Only the first batch is inspected.
    for batch in data_loader:
        image, image_id, label, target, prob = batch
        for item in (image.shape, image_id, label, target, prob):
            print(item)
        break
|