import torch
from torch.utils.data import Dataset
from PIL import Image
import os
import json
from build_vocab import Vocabulary, JsonReader
import numpy as np
from torchvision import transforms
import pickle


class ChestXrayDataSet(Dataset):
    """Dataset pairing chest X-ray images with label vectors and tokenized reports.

    Each item is ``(image, image_name, label, target, sentence_num, max_word_num)``:

    * ``image``: the RGB PIL image, passed through ``transforms`` if given.
    * ``image_name``: the ``'<name>.png'`` file name.
    * ``label``: the integer label row scaled to sum to 1 (left as zeros
      when the row sums to 0).
    * ``target``: one list of vocabulary ids per kept sentence.
    * ``sentence_num``: ``len(target)``.
    * ``max_word_num``: the longest token list in ``target``.
    """

    def __init__(self,
                 image_dir,
                 caption_json,
                 file_list,
                 vocabulary,
                 s_max=10,
                 n_max=50,
                 transforms=None):
        """
        Args:
            image_dir: directory that holds the ``<name>.png`` image files.
            caption_json: path given to ``JsonReader``; reports are looked up
                by image file name (including the ``.png`` suffix).
            file_list: text file with one whitespace-separated record per line:
                an image name (without extension) followed by integer labels.
            vocabulary: callable mapping a token string to its integer id.
            s_max: number of raw ``'. '``-separated pieces inspected per report.
            n_max: sentences with more than this many words are dropped.
            transforms: optional callable applied to the loaded PIL image.
        """
        self.image_dir = image_dir
        self.caption = JsonReader(caption_json)
        self.file_names, self.labels = self.__load_label_list(file_list)
        self.vocab = vocabulary
        self.transform = transforms
        self.s_max = s_max
        self.n_max = n_max

    def __load_label_list(self, file_list):
        """Parse *file_list* into parallel lists of ``'<name>.png'`` names and int labels.

        Blank lines are skipped rather than raising ``IndexError``.
        """
        labels = []
        filename_list = []
        with open(file_list, 'r') as f:
            for line in f:
                items = line.split()
                if not items:
                    continue
                image_name, *label = items
                filename_list.append('{}.png'.format(image_name))
                labels.append([int(i) for i in label])
        return filename_list, labels

    def _encode_caption(self, text):
        """Split a report on ``'. '`` and convert each sentence to vocabulary ids.

        Sentences with at most one word, or more than ``n_max`` words, are
        dropped. At most ``s_max`` raw pieces are inspected; skipped pieces
        still count toward that limit. Each kept sentence is wrapped in
        ``'<start>'`` / ``'<end>'`` ids (assumes the vocabulary defines both
        tokens -- TODO confirm against build_vocab).

        Returns:
            ``(target, max_word_num)``: the per-sentence id lists and the
            length of the longest one (0 when none are kept).
        """
        target = []
        max_word_num = 0
        for i, sentence in enumerate(text.split('. ')):
            if i >= self.s_max:
                break
            words = sentence.split()
            if len(words) <= 1 or len(words) > self.n_max:
                continue
            tokens = [self.vocab('<start>')]
            tokens.extend(self.vocab(token) for token in words)
            tokens.append(self.vocab('<end>'))
            max_word_num = max(max_word_num, len(tokens))
            target.append(tokens)
        return target, max_word_num

    def __getitem__(self, index):
        """Load image *index* with its normalized label row and tokenized report."""
        image_name = self.file_names[index]
        # Context manager releases the file handle as soon as pixels are copied.
        with Image.open(os.path.join(self.image_dir, image_name)) as img:
            image = img.convert('RGB')
        if self.transform is not None:
            image = self.transform(image)

        try:
            text = self.caption[image_name]
        except KeyError:
            # Images with no report entry fall back to a minimal caption.
            text = 'normal. '

        target, max_word_num = self._encode_caption(text)

        label = np.asarray(self.labels[index], dtype=np.float64)
        total = label.sum()
        # Guard the all-zero row: dividing by 0 would produce NaNs.
        if total != 0:
            label = label / total
        return image, image_name, list(label), target, len(target), max_word_num

    def __len__(self):
        return len(self.file_names)


def collate_fn(data):
    """Batch ``ChestXrayDataSet`` items.

    Returns:
        ``(images, image_ids, labels, targets, prob)`` where:

        * ``images`` is ``torch.stack`` of the item images along a new dim 0.
        * ``labels`` is a ``torch.Tensor`` built from the label lists.
        * ``targets`` is a zero-padded numpy array of shape
          ``(batch, max_sentences + 1, max_words)``.
        * ``prob`` is a numpy array of shape ``(batch, max_sentences + 1)``
          holding 1 where a sentence exists and 0 otherwise.
    """
    images, image_id, label, captions, sentence_num, max_word_num = zip(*data)
    images = torch.stack(images, 0)

    max_sentence_num = max(sentence_num)
    max_word_len = max(max_word_num)
    # The extra "+ 1" sentence slot is always empty (prob 0) -- presumably a
    # stop signal for the sentence model; TODO confirm with the consumer.
    targets = np.zeros((len(captions), max_sentence_num + 1, max_word_len))
    prob = np.zeros((len(captions), max_sentence_num + 1))

    for i, caption in enumerate(captions):
        for j, sentence in enumerate(caption):
            targets[i, j, :len(sentence)] = sentence
            prob[i][j] = len(sentence) > 0
    return images, image_id, torch.Tensor(label), targets, prob


def get_loader(image_dir,
               caption_json,
               file_list,
               vocabulary,
               transform,
               batch_size,
               s_max=10,
               n_max=50,
               shuffle=False,
               num_workers=0):
    """Build a ``DataLoader`` over ``ChestXrayDataSet`` using ``collate_fn``.

    ``num_workers`` is forwarded to ``DataLoader``; the default of 0 loads
    batches in the main process.
    """
    dataset = ChestXrayDataSet(image_dir=image_dir,
                               caption_json=caption_json,
                               file_list=file_list,
                               vocabulary=vocabulary,
                               s_max=s_max,
                               n_max=n_max,
                               transforms=transform)
    data_loader = torch.utils.data.DataLoader(dataset=dataset,
                                              batch_size=batch_size,
                                              shuffle=shuffle,
                                              num_workers=num_workers,
                                              collate_fn=collate_fn)
    return data_loader


if __name__ == '__main__':
    # Debug driver: load one batch from the debugging split and print it.
    vocab_path = '../data/vocab.pkl'
    image_dir = '../data/images'
    caption_json = '../data/debugging_captions.json'
    file_list = '../data/debugging.txt'
    batch_size = 6
    resize = 256
    crop_size = 224

    transform = transforms.Compose([
        transforms.Resize(resize),
        transforms.RandomCrop(crop_size),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize((0.485, 0.456, 0.406),
                             (0.229, 0.224, 0.225))])

    # NOTE(review): pickle.load executes arbitrary code; only load trusted vocab files.
    with open(vocab_path, 'rb') as f:
        vocab = pickle.load(f)

    data_loader = get_loader(image_dir=image_dir,
                             caption_json=caption_json,
                             file_list=file_list,
                             vocabulary=vocab,
                             transform=transform,
                             batch_size=batch_size,
                             shuffle=False)

    for i, (image, image_id, label, target, prob) in enumerate(data_loader):
        print(image.shape)
        print(image_id)
        print(label)
        print(target)
        print(prob)
        break