Spaces:
Running
on
CPU Upgrade
Running
on
CPU Upgrade
import torch | |
from torch.utils.data import Dataset | |
from PIL import Image | |
import os | |
import json | |
from build_vocab import Vocabulary, JsonReader | |
import numpy as np | |
from torchvision import transforms | |
import pickle | |
class ChestXrayDataSet(Dataset):
    """Dataset yielding an image, its label distribution and its tokenized caption.

    Each sample is the tuple
    ``(image, image_name, label, target, sentence_num, max_word_num)`` where
    ``label`` is the per-image label vector scaled to sum to 1, ``target`` is a
    list of token-id lists (one per kept sentence), ``sentence_num`` is
    ``len(target)`` and ``max_word_num`` is the longest token list in ``target``.

    Args:
        image_dir: Directory containing the ``<name>.png`` image files.
        caption_json: Path handed to ``JsonReader``; the reader is indexed by
            image file name to obtain the caption text.
        file_list: Text file with one sample per line: an image name (without
            extension) followed by whitespace-separated integer labels.
        vocabulary: Callable mapping a token string to its id; it is invoked
            as ``vocabulary(token)``.
        s_max: Only the first ``s_max`` sentences of a caption are considered.
        n_max: Sentences with more than ``n_max`` words are dropped.
        transforms: Optional callable applied to the loaded RGB image.
    """

    def __init__(self,
                 image_dir,
                 caption_json,
                 file_list,
                 vocabulary,
                 s_max=10,
                 n_max=50,
                 transforms=None):
        self.image_dir = image_dir
        self.caption = JsonReader(caption_json)
        self.file_names, self.labels = self.__load_label_list(file_list)
        self.vocab = vocabulary
        self.transform = transforms
        self.s_max = s_max
        self.n_max = n_max

    def __load_label_list(self, file_list):
        """Parse ``file_list`` into parallel lists of file names and int labels.

        Blank lines are skipped instead of raising ``IndexError``.

        Returns:
            ``(filename_list, labels)`` where every file name has ``.png``
            appended and every label entry is a list of ints.
        """
        labels = []
        filename_list = []
        with open(file_list, 'r') as f:
            for line in f:
                items = line.split()
                if not items:
                    # Tolerate empty / whitespace-only lines (e.g. a trailing newline).
                    continue
                filename_list.append('{}.png'.format(items[0]))
                labels.append([int(i) for i in items[1:]])
        return filename_list, labels

    @staticmethod
    def _normalize_label(label):
        """Scale ``label`` so that it sums to 1 and return it as a list of floats.

        An all-zero (or empty) label vector is returned as zeros rather than
        being divided by zero, which would produce NaN entries.
        """
        weights = np.asarray(label, dtype=np.float64)
        total = weights.sum()
        if total != 0:
            weights = weights / total
        return list(weights)

    def _encode_caption(self, text):
        """Tokenize ``text`` into per-sentence id lists.

        Sentences are split on ``'. '``. Only the first ``s_max`` split
        positions are inspected (skipped sentences still count towards that
        limit). Sentences with fewer than two words or more than ``n_max``
        words are dropped. Every kept sentence is wrapped in ``<start>`` /
        ``<end>`` ids.

        Returns:
            ``(target, max_word_num)``: the list of token-id lists and the
            length of the longest one (0 when nothing was kept).
        """
        target = []
        max_word_num = 0
        for i, sentence in enumerate(text.split('. ')):
            if i >= self.s_max:
                break
            words = sentence.split()
            if len(words) <= 1 or len(words) > self.n_max:
                continue
            tokens = [self.vocab('<start>')]
            tokens.extend(self.vocab(word) for word in words)
            tokens.append(self.vocab('<end>'))
            max_word_num = max(max_word_num, len(tokens))
            target.append(tokens)
        return target, max_word_num

    def __getitem__(self, index):
        """Load and encode the sample at ``index``.

        Returns:
            ``(image, image_name, label, target, sentence_num, max_word_num)``
            as described in the class docstring.
        """
        image_name = self.file_names[index]
        # Image.open is lazy; the context manager closes the underlying file
        # once convert() has produced an independent RGB copy.
        with Image.open(os.path.join(self.image_dir, image_name)) as raw_image:
            image = raw_image.convert('RGB')
        if self.transform is not None:
            image = self.transform(image)

        try:
            text = self.caption[image_name]
        except Exception:
            # Best-effort fallback kept from the original: any lookup failure
            # is treated as a report with no findings.
            # NOTE(review): narrow this to the reader's real "missing key"
            # exception once JsonReader's behaviour is confirmed.
            text = 'normal. '

        target, max_word_num = self._encode_caption(text)
        label = self._normalize_label(self.labels[index])
        return image, image_name, label, target, len(target), max_word_num

    def __len__(self):
        """Return the number of samples listed in ``file_list``."""
        return len(self.file_names)
def collate_fn(data):
    """Merge a list of dataset samples into one padded batch.

    Args:
        data: Sequence of 6-tuples
            ``(image, image_id, label, captions, sentence_num, max_word_num)``.

    Returns:
        ``(images, image_id, label, targets, prob)`` where ``images`` is the
        stacked image tensor, ``image_id`` the tuple of ids, ``label`` a
        ``torch.Tensor`` built from the per-sample labels, ``targets`` a
        zero-padded NumPy array of shape
        ``(batch, max_sentence_num + 1, max_word_num)`` and ``prob`` a NumPy
        array of shape ``(batch, max_sentence_num + 1)`` holding 1 for every
        non-empty sentence slot and 0 elsewhere.
    """
    images, image_id, label, captions, sentence_num, max_word_num = zip(*data)
    stacked_images = torch.stack(images, 0)

    batch_size = len(captions)
    # One extra, always-empty sentence row is reserved per sample
    # (presumably a stop marker for the sentence decoder - confirm with the trainer).
    sentence_slots = max(sentence_num) + 1
    word_slots = max(max_word_num)

    targets = np.zeros((batch_size, sentence_slots, word_slots))
    prob = np.zeros((batch_size, sentence_slots))
    for row, report in enumerate(captions):
        for col, token_ids in enumerate(report):
            length = len(token_ids)
            targets[row, col, :length] = token_ids
            prob[row, col] = length > 0

    return stacked_images, image_id, torch.Tensor(label), targets, prob
def get_loader(image_dir,
               caption_json,
               file_list,
               vocabulary,
               transform,
               batch_size,
               s_max=10,
               n_max=50,
               shuffle=False):
    """Build a ``DataLoader`` over a ``ChestXrayDataSet``.

    The dataset is constructed from the given paths, vocabulary and sentence /
    word limits, with ``transform`` applied to each image. Batches are
    assembled by the module-level ``collate_fn``.

    Returns:
        A ``torch.utils.data.DataLoader`` using ``batch_size`` and ``shuffle``.
    """
    dataset_options = dict(image_dir=image_dir,
                           caption_json=caption_json,
                           file_list=file_list,
                           vocabulary=vocabulary,
                           s_max=s_max,
                           n_max=n_max,
                           transforms=transform)
    xray_dataset = ChestXrayDataSet(**dataset_options)
    return torch.utils.data.DataLoader(dataset=xray_dataset,
                                       batch_size=batch_size,
                                       shuffle=shuffle,
                                       collate_fn=collate_fn)
if __name__ == '__main__':
    # Manual smoke test: build a loader over the debugging split and print the
    # contents of the first batch.
    vocab_path = '../data/vocab.pkl'
    image_dir = '../data/images'
    caption_json = '../data/debugging_captions.json'
    file_list = '../data/debugging.txt'

    batch_size = 6
    resize = 256
    crop_size = 224

    # Resize, random crop / flip, then normalize with the channel statistics below.
    channel_mean = (0.485, 0.456, 0.406)
    channel_std = (0.229, 0.224, 0.225)
    transform = transforms.Compose([
        transforms.Resize(resize),
        transforms.RandomCrop(crop_size),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(channel_mean, channel_std),
    ])

    # NOTE: pickle.load can execute arbitrary code; only load a trusted vocab file.
    with open(vocab_path, 'rb') as f:
        vocab = pickle.load(f)

    data_loader = get_loader(image_dir=image_dir,
                             caption_json=caption_json,
                             file_list=file_list,
                             vocabulary=vocab,
                             transform=transform,
                             batch_size=batch_size,
                             shuffle=False)

    # Only the first batch is inspected; an empty loader prints nothing.
    for image, image_id, label, target, prob in data_loader:
        for item in (image.shape, image_id, label, target, prob):
            print(item)
        break