Jyothirmai committed on
Commit
d290c84
1 Parent(s): 590d88b

Upload 13 files

Files changed (13)
  1. build_tag.py +89 -0
  2. dataset.py +151 -0
  3. extractor.pth.tar +3 -0
  4. loss.py +78 -0
  5. mlc.pth.tar +3 -0
  6. models.py +552 -0
  7. pytorch_model.bin +3 -0
  8. sentence.pth.tar +3 -0
  9. tester.py +283 -0
  10. train_best_loss.pth.tar +3 -0
  11. val_best_loss.pth.tar +3 -0
  12. vocab.pkl +3 -0
  13. word.pth.tar +3 -0
build_tag.py ADDED
@@ -0,0 +1,89 @@
+ class Tag(object):
+     def __init__(self):
+         self.static_tags = self.__load_static_tags()
+         self.id2tags = self.__load_id2tags()
+         self.tags2id = self.__load_tags2id()
+
+     def array2tags(self, array):
+         tags = []
+         for id in array:
+             tags.append(self.id2tags[id])
+         return tags
+
+     def tags2array(self, tags):
+         array = []
+         for tag in self.static_tags:
+             if tag in tags:
+                 array.append(1)
+             else:
+                 array.append(0)
+         return array
+
+     def inv_tags2array(self, array):
+         tags = []
+         for i, value in enumerate(array):
+             if value != 0:
+                 tags.append(self.id2tags[i])
+         return tags
+
+     def __load_id2tags(self):
+         id2tags = {}
+         for i, tag in enumerate(self.static_tags):
+             id2tags[i] = tag
+         return id2tags
+
+     def __load_tags2id(self):
+         tags2id = {}
+         for i, tag in enumerate(self.static_tags):
+             tags2id[tag] = i
+         return tags2id
+
+     def __load_static_tags(self):
+         static_tags_name = ['cardiac monitor', 'lymphatic diseases', 'pulmonary disease', 'osteophytes', 'foreign body',
+                             'dish', 'aorta, thoracic', 'atherosclerosis', 'histoplasmosis', 'hypoventilation',
+                             'catheterization, central venous', 'pleural effusions', 'pleural effusion', 'callus',
+                             'sternotomy', 'lymph nodes', 'tortuous aorta', 'stent', 'interstitial pulmonary edema',
+                             'cholecystectomies', 'neoplasm', 'central venous catheter', 'pneumothorax',
+                             'metastatic disease', 'vena cava, superior', 'cholecystectomy', 'scoliosis',
+                             'subcutaneous emphysema', 'thoracolumbar scoliosis', 'spinal osteophytosis',
+                             'pulmonary fibroses', 'rib fractures', 'sarcoidosis', 'eventration', 'fibrosis', 'spine',
+                             'obstructive lung disease', 'pneumonitis', 'osteopenia', 'air trapping', 'demineralization',
+                             'mass lesion', 'pulmonary hypertension', 'pleural diseases', 'pleural thickening',
+                             'calcifications of the aorta', 'calcinosis', 'cystic fibrosis', 'empyema', 'catheter',
+                             'lymph', 'pericardial effusion', 'lung cancer', 'rib fracture', 'granulomatous disease',
+                             'chronic obstructive pulmonary disease', 'rib', 'clip', 'aortic ectasia', 'shoulder',
+                             'scarring', 'scleroses', 'adenopathy', 'emphysemas', 'pneumonectomy', 'infection',
+                             'aspiration', 'bilateral pleural effusion', 'bulla', 'lumbar vertebrae', 'lung neoplasms',
+                             'lymphadenopathy', 'hyperexpansion', 'ectasia', 'bronchiectasis', 'nodule', 'pneumonia',
+                             'right-sided pleural effusion', 'osteoarthritis', 'thoracic spondylosis', 'picc',
+                             'cervical fusion', 'tracheostomies', 'fusion', 'thoracic vertebrae', 'catheters',
+                             'emphysema', 'trachea', 'surgery', 'cervical spine fusion', 'hypertension, pulmonary',
+                             'pneumoperitoneum', 'scar', 'atheroscleroses', 'aortic calcifications', 'volume overload',
+                             'right upper lobe pneumonia', 'apical granuloma', 'diaphragms', 'copd', 'kyphoses',
+                             'spinal fractures', 'fracture', 'clavicle', 'focal atelectasis', 'collapse',
+                             'thoracotomies', 'congestive heart failure', 'calcified lymph nodes', 'edema',
+                             'degenerative disc diseases', 'cervical vertebrae', 'diaphragm', 'humerus', 'heart failure',
+                             'normal', 'coronary artery bypass', 'pulmonary atelectasis', 'lung diseases, interstitial',
+                             'pulmonary disease, chronic obstructive', 'opacity', 'deformity', 'chronic disease',
+                             'pleura', 'aorta', 'tuberculoses', 'hiatal hernia', 'scolioses', 'pleural fluid',
+                             'malignancy', 'kyphosis', 'bronchiectases', 'congestion', 'discoid atelectasis', 'nipple',
+                             'bronchitis', 'pulmonary artery', 'cardiomegaly', 'thoracic aorta', 'arthritic changes',
+                             'pulmonary edema', 'vascular calcification', 'sclerotic', 'central venous catheters',
+                             'catheterization', 'hydropneumothorax', 'aortic valve', 'hyperinflation', 'prostheses',
+                             'pacemaker, artificial', 'bypass grafts', 'pulmonary fibrosis', 'multiple myeloma',
+                             'postoperative period', 'cabg', 'right lower lobe pneumonia', 'granuloma',
+                             'degenerative change', 'atelectasis', 'inflammation', 'effusion', 'cicatrix',
+                             'tracheostomy', 'aortic diseases', 'sarcoidoses', 'granulomas', 'interstitial lung disease',
+                             'infiltrates', 'displaced fractures', 'chronic lung disease', 'picc line',
+                             'intubation, gastrointestinal', 'lung diseases', 'multiple pulmonary nodules',
+                             'intervertebral disc degeneration', 'pulmonary emphysema', 'spine curvature', 'fibroses',
+                             'chronic granulomatous disease', 'degenerative disease', 'atelectases', 'ribs',
+                             'pulmonary arterial hypertension', 'edemas', 'pectus excavatum', 'lung granuloma',
+                             'plate-like atelectasis', 'enlarged heart', 'hilar calcification', 'heart valve prosthesis',
+                             'tuberculosis', 'old injury', 'patchy atelectasis', 'histoplasmoses', 'exostoses',
+                             'mastectomies', 'right atrium', 'large hiatal hernia', 'hernia, hiatal', 'aortic aneurysm',
+                             'lobectomy', 'spinal fusion', 'spondylosis', 'ascending aorta', 'granulomatous infection',
+                             'fractures, bone', 'calcified granuloma', 'degenerative joint disease',
+                             'intubation, intratracheal', 'others']
+
+         return static_tags_name
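
A minimal usage sketch of the Tag helper defined above (assuming the file is importable as utils.build_tag, the same package path tester.py imports from):

    from utils.build_tag import Tag

    tagger = Tag()
    # Both names below appear in static_tags; tags2array builds a multi-hot list over all static tags.
    multi_hot = tagger.tags2array(['pleural effusion', 'cardiomegaly'])
    print(sum(multi_hot))                    # 2 positions set to 1
    print(tagger.inv_tags2array(multi_hot))  # back to the tag names, in static_tags order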
dataset.py ADDED
@@ -0,0 +1,151 @@
+ import torch
+ from torch.utils.data import Dataset
+ from PIL import Image
+ import os
+ import json
+ from utils.build_vocab import Vocabulary, JsonReader
+ import numpy as np
+ from torchvision import transforms
+ import pickle
+
+
+ class ChestXrayDataSet(Dataset):
+     def __init__(self,
+                  image_dir,
+                  caption_json,
+                  file_list,
+                  vocabulary,
+                  s_max=10,
+                  n_max=50,
+                  transforms=None):
+         self.image_dir = image_dir
+         self.caption = JsonReader(caption_json)
+         self.file_names, self.labels = self.__load_label_list(file_list)
+         self.vocab = vocabulary
+         self.transform = transforms
+         self.s_max = s_max
+         self.n_max = n_max
+
+     def __load_label_list(self, file_list):
+         labels = []
+         filename_list = []
+         with open(file_list, 'r') as f:
+             for line in f:
+                 items = line.split()
+                 image_name = items[0]
+                 label = items[1:]
+                 label = [int(i) for i in label]
+                 image_name = '{}.png'.format(image_name)
+                 filename_list.append(image_name)
+                 labels.append(label)
+         return filename_list, labels
+
+     def __getitem__(self, index):
+         image_name = self.file_names[index]
+         image = Image.open(os.path.join(self.image_dir, image_name)).convert('RGB')
+         label = self.labels[index]
+         if self.transform is not None:
+             image = self.transform(image)
+         try:
+             text = self.caption[image_name]
+         except Exception as err:
+             text = 'normal. '
+
+         target = list()
+         max_word_num = 0
+         for i, sentence in enumerate(text.split('. ')):
+             if i >= self.s_max:
+                 break
+             sentence = sentence.split()
+             if len(sentence) == 0 or len(sentence) == 1 or len(sentence) > self.n_max:
+                 continue
+             tokens = list()
+             tokens.append(self.vocab('<start>'))
+             tokens.extend([self.vocab(token) for token in sentence])
+             tokens.append(self.vocab('<end>'))
+             if max_word_num < len(tokens):
+                 max_word_num = len(tokens)
+             target.append(tokens)
+         sentence_num = len(target)
+         return image, image_name, list(label / np.sum(label)), target, sentence_num, max_word_num
+
+     def __len__(self):
+         return len(self.file_names)
+
+
+ def collate_fn(data):
+     images, image_id, label, captions, sentence_num, max_word_num = zip(*data)
+     images = torch.stack(images, 0)
+
+     max_sentence_num = max(sentence_num)
+     max_word_num = max(max_word_num)
+
+     targets = np.zeros((len(captions), max_sentence_num + 1, max_word_num))
+     prob = np.zeros((len(captions), max_sentence_num + 1))
+
+     for i, caption in enumerate(captions):
+         for j, sentence in enumerate(caption):
+             targets[i, j, :len(sentence)] = sentence[:]
+             prob[i][j] = len(sentence) > 0
+
+     return images, image_id, torch.Tensor(label), targets, prob
+
+
+ def get_loader(image_dir,
+                caption_json,
+                file_list,
+                vocabulary,
+                transform,
+                batch_size,
+                s_max=10,
+                n_max=50,
+                shuffle=False):
+     dataset = ChestXrayDataSet(image_dir=image_dir,
+                                caption_json=caption_json,
+                                file_list=file_list,
+                                vocabulary=vocabulary,
+                                s_max=s_max,
+                                n_max=n_max,
+                                transforms=transform)
+     data_loader = torch.utils.data.DataLoader(dataset=dataset,
+                                               batch_size=batch_size,
+                                               shuffle=shuffle,
+                                               collate_fn=collate_fn)
+     return data_loader
+
+
+ if __name__ == '__main__':
+     vocab_path = '../data/vocab.pkl'
+     image_dir = '../data/images'
+     caption_json = '../data/debugging_captions.json'
+     file_list = '../data/debugging.txt'
+     batch_size = 6
+     resize = 256
+     crop_size = 224
+
+     transform = transforms.Compose([
+         transforms.Resize(resize),
+         transforms.RandomCrop(crop_size),
+         transforms.RandomHorizontalFlip(),
+         transforms.ToTensor(),
+         transforms.Normalize((0.485, 0.456, 0.406),
+                              (0.229, 0.224, 0.225))])
+
+     with open(vocab_path, 'rb') as f:
+         vocab = pickle.load(f)
+
+     data_loader = get_loader(image_dir=image_dir,
+                              caption_json=caption_json,
+                              file_list=file_list,
+                              vocabulary=vocab,
+                              transform=transform,
+                              batch_size=batch_size,
+                              shuffle=False)
+
+     for i, (image, image_id, label, target, prob) in enumerate(data_loader):
+         print(image.shape)
+         print(image_id)
+         print(label)
+         print(target)
+         print(prob)
+         break
extractor.pth.tar ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:2cec084672668d0a2d9c4e2451f1f9c71069fff1ad0bb09759953625b4ec731c
+ size 348017030
loss.py ADDED
@@ -0,0 +1,78 @@
+ import torch
+ import numpy as np
+ from torch.nn.modules import loss
+
+
+ class WARPLoss(loss.Module):
+     def __init__(self, num_labels=204):
+         super(WARPLoss, self).__init__()
+         self.rank_weights = [1.0 / 1]
+         for i in range(1, num_labels):
+             self.rank_weights.append(self.rank_weights[i - 1] + (1.0 / i + 1))
+
+     def forward(self, input, target) -> object:
+         """
+
+         :rtype:
+         :param input: Deep features tensor Variable of size batch x n_attrs.
+         :param target: Ground truth tensor Variable of size batch x n_attrs.
+         :return:
+         """
+         batch_size = target.size()[0]
+         n_labels = target.size()[1]
+         max_num_trials = n_labels - 1
+         loss = 0.0
+
+         for i in range(batch_size):
+
+             for j in range(n_labels):
+                 if target[i, j] == 1:
+
+                     neg_labels_idx = np.array([idx for idx, v in enumerate(target[i, :]) if v == 0])
+                     neg_idx = np.random.choice(neg_labels_idx, replace=False)
+                     sample_score_margin = 1 - input[i, j] + input[i, neg_idx]
+                     num_trials = 0
+
+                     while sample_score_margin < 0 and num_trials < max_num_trials:
+                         neg_idx = np.random.choice(neg_labels_idx, replace=False)
+                         num_trials += 1
+                         sample_score_margin = 1 - input[i, j] + input[i, neg_idx]
+
+                     r_j = np.floor(max_num_trials / num_trials)
+                     weight = self.rank_weights[r_j]
+
+                     for k in range(n_labels):
+                         if target[i, k] == 0:
+                             score_margin = 1 - input[i, j] + input[i, k]
+                             loss += (weight * torch.clamp(score_margin, min=0.0))
+         return loss
+
+
+ class MultiLabelSoftmaxRegressionLoss(loss.Module):
+     def __init__(self):
+         super(MultiLabelSoftmaxRegressionLoss, self).__init__()
+
+     def forward(self, input, target) -> object:
+         return -1 * torch.sum(input * target)
+
+
+ class LossFactory(object):
+     def __init__(self, type, num_labels=156):
+         self.type = type
+         if type == 'BCE':
+             # self.activation_func = torch.nn.Sigmoid()
+             self.loss = torch.nn.BCELoss()
+         elif type == 'CE':
+             self.loss = torch.nn.CrossEntropyLoss()
+         elif type == 'WARP':
+             self.activation_func = torch.nn.Softmax()
+             self.loss = WARPLoss(num_labels=num_labels)
+         elif type == 'MSR':
+             self.activation_func = torch.nn.LogSoftmax()
+             self.loss = MultiLabelSoftmaxRegressionLoss()
+
+     def compute_loss(self, output, target):
+         # output = self.activation_func(output)
+         # if self.type == 'NLL' or self.type == 'WARP' or self.type == 'MSR':
+         #     target /= torch.sum(target, 1).view(-1, 1)
+         return self.loss(output, target)
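
A minimal sketch of driving the LossFactory above for the multi-label tag head; the batch size and label count here are illustrative assumptions, and the import path assumes the file lives at utils/loss.py as tester.py implies:

    import torch
    from utils.loss import LossFactory

    criterion = LossFactory(type='BCE')                 # BCELoss expects probabilities in [0, 1]
    pred_tags = torch.sigmoid(torch.randn(8, 156))      # hypothetical batch of 8 samples, 156 tags
    true_tags = torch.randint(0, 2, (8, 156)).float()
    print(criterion.compute_loss(pred_tags, true_tags))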
mlc.pth.tar ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:b099a01f553ee36e1b4ecedae8b19f1decd6fb35d1aebd3ed42ae9fa6bf61080
+ size 346714524
models.py ADDED
@@ -0,0 +1,552 @@
+ import torch
+ import torch.nn as nn
+ import torchvision
+ import numpy as np
+ from torch.autograd import Variable
+ import torchvision.models as models
+ import transformers
+ import torchvision.transforms
+
+ import torchxrayvision as xrv
+ from transformers import ViTModel, ViTConfig
+
+
+
+ class VisualFeatureExtractor(nn.Module):
+     def __init__(self, model_name='densenet201', pretrained=False):
+         super(VisualFeatureExtractor, self).__init__()
+         self.model_name = 'chexnet'
+         self.pretrained = pretrained
+         self.model, self.out_features, self.avg_func, self.bn, self.linear = self.__get_model()
+         self.activation = nn.ReLU()
+
+     def __get_model(self):
+         model = None
+         out_features = None
+         func = None
+
+         if self.model_name == 'resnet152':
+             resnet = models.resnet152(pretrained=self.pretrained)
+             modules = list(resnet.children())[:-2]
+             model = nn.Sequential(*modules)
+             out_features = resnet.fc.in_features
+             func = torch.nn.AvgPool2d(kernel_size=7, stride=1, padding=0)
+
+
+         elif self.model_name == 'densenet201':
+             densenet = models.densenet201(pretrained=self.pretrained)
+             modules = list(densenet.features)
+             model = nn.Sequential(*modules)
+             func = torch.nn.AvgPool2d(kernel_size=7, stride=1, padding=0)
+             out_features = densenet.classifier.in_features
+
+
+         elif self.model_name == 'chexnet':
+             print("vit chest xray pretrained model loading")
+             # Load the Vision Transformer (ViT) model configuration
+             config = ViTConfig.from_pretrained('nickmuchi/vit-finetuned-chest-xray-pneumonia')
+
+             # Initialize the ViT model with the specific configuration
+             vit_model = ViTModel(config)
+
+             # Load the state dict specifically, excluding 'classifier.bias', 'classifier.weight'
+             state_dict = torch.load('vit-coatten/utils/pytorch_model.bin', map_location=torch.device('cpu'))
+             state_dict = {k: v for k, v in state_dict.items() if not k.startswith('classifier')}
+             vit_model.load_state_dict(state_dict, strict=False)
+
+             model = vit_model
+             out_features = config.hidden_size
+
+         linear = nn.Linear(in_features=out_features, out_features=out_features)
+         bn = nn.BatchNorm1d(num_features=out_features, momentum=0.1)
+
+         return model, out_features, func, bn, linear
+
+     def forward(self, images):
+         """
+         :param images: Input images
+         :return: visual_features, avg_features
+         """
+         model_output = self.model(images)
+
+         # Extract the pooler_output
+
+         pooler_output = model_output.pooler_output
+
+         # Apply the linear layer, batch normalization, and activation
+         avg_features = self.activation(self.bn(self.linear(pooler_output)))
+
+         return model_output.last_hidden_state, avg_features
+
+     # def forward(self, images):
+     #     """
+     #     :param images:
+     #     :return:
+     #     """
+     #     visual_features = self.model(images)
+
+     #     avg_features = self.avg_func(visual_features).squeeze()
+     #     # avg_features = self.activation(self.bn(self.linear(visual_features)))
+
+     #     return visual_features, avg_features
+
+
+ class MLC(nn.Module):
+     def __init__(self,
+                  classes=210,
+                  sementic_features_dim=512,
+                  fc_in_features=2048,
+                  k=10,
+                  ):
+         super(MLC, self).__init__()
+         pretrained_model_name = "nickmuchi/vit-finetuned-chest-xray-pneumonia"
+         vit_config = ViTConfig.from_pretrained(pretrained_model_name)
+         self.vit = ViTModel(vit_config)
+
+         # Adjust the classifier to your number of classes
+         self.classifier = nn.Linear(in_features=vit_config.hidden_size, out_features=classes)
+         self.embed = nn.Embedding(classes, sementic_features_dim)
+         self.k = k
+         self.sigmoid = nn.Sigmoid()
+         self.__init_weight()
+
+     def __init_weight(self):
+         nn.init.xavier_uniform_(self.classifier.weight)
+         if self.classifier.bias is not None:
+             self.classifier.bias.data.fill_(0)
+
+     def forward(self, avg_features):
+
+
+         tags = self.sigmoid(self.classifier(avg_features))
+         semantic_features = self.embed(torch.topk(tags, self.k)[1])
+         return tags, semantic_features
+
+ # class MLC(nn.Module):
+ #     def __init__(self,
+ #                  classes=210,
+ #                  sementic_features_dim=512,
+ #                  fc_in_features=2048,
+ #                  k=10):
+ #         super(MLC, self).__init__()
+ #         self.classifier = nn.Linear(in_features=fc_in_features, out_features=classes)
+ #         self.embed = nn.Embedding(classes, sementic_features_dim)
+ #         self.k = k
+ #         self.sigmoid = nn.Sigmoid()
+ #         self.__init_weight()
+
+ #     def __init_weight(self):
+ #         # Example: Initialize weights with a different strategy
+ #         nn.init.xavier_uniform_(self.classifier.weight)
+ #         if self.classifier.bias is not None:
+ #             self.classifier.bias.data.fill_(0)
+
+ #     def forward(self, avg_features):
+ #         tags = self.sigmoid(self.classifier(avg_features))
+ #         semantic_features = self.embed(torch.topk(tags, self.k)[1])
+ #         return tags, semantic_features
+
+
+ class CoAttention(nn.Module):
+     def __init__(self,
+                  version='v1',
+                  embed_size=512,
+                  hidden_size=512,
+                  visual_size=2048,
+                  k=10,
+                  momentum=0.1):
+         super(CoAttention, self).__init__()
+         self.version = version
+         self.W_v = nn.Linear(in_features=visual_size, out_features=visual_size)
+         self.bn_v = nn.BatchNorm1d(num_features=visual_size, momentum=momentum)
+
+         self.W_v_h = nn.Linear(in_features=hidden_size, out_features=visual_size)
+         self.bn_v_h = nn.BatchNorm1d(num_features=visual_size, momentum=momentum)
+
+         self.W_v_att = nn.Linear(in_features=visual_size, out_features=visual_size)
+         self.bn_v_att = nn.BatchNorm1d(num_features=visual_size, momentum=momentum)
+
+         self.W_a = nn.Linear(in_features=hidden_size, out_features=hidden_size)
+         self.bn_a = nn.BatchNorm1d(num_features=k, momentum=momentum)
+
+         self.W_a_h = nn.Linear(in_features=hidden_size, out_features=hidden_size)
+         self.bn_a_h = nn.BatchNorm1d(num_features=1, momentum=momentum)
+
+         self.W_a_att = nn.Linear(in_features=hidden_size, out_features=hidden_size)
+         self.bn_a_att = nn.BatchNorm1d(num_features=k, momentum=momentum)
+
+         # self.W_fc = nn.Linear(in_features=visual_size, out_features=embed_size)  # for v3
+         self.W_fc = nn.Linear(in_features=visual_size + hidden_size, out_features=embed_size)
+         self.bn_fc = nn.BatchNorm1d(num_features=embed_size, momentum=momentum)
+
+         self.tanh = nn.Tanh()
+         self.softmax = nn.Softmax()
+
+         self.__init_weight()
+
+     def __init_weight(self):
+         self.W_v.weight.data.uniform_(-0.1, 0.1)
+         self.W_v.bias.data.fill_(0)
+
+         self.W_v_h.weight.data.uniform_(-0.1, 0.1)
+         self.W_v_h.bias.data.fill_(0)
+
+         self.W_v_att.weight.data.uniform_(-0.1, 0.1)
+         self.W_v_att.bias.data.fill_(0)
+
+         self.W_a.weight.data.uniform_(-0.1, 0.1)
+         self.W_a.bias.data.fill_(0)
+
+         self.W_a_h.weight.data.uniform_(-0.1, 0.1)
+         self.W_a_h.bias.data.fill_(0)
+
+         self.W_a_att.weight.data.uniform_(-0.1, 0.1)
+         self.W_a_att.bias.data.fill_(0)
+
+         self.W_fc.weight.data.uniform_(-0.1, 0.1)
+         self.W_fc.bias.data.fill_(0)
+
+     def forward(self, avg_features, semantic_features, h_sent):
+         if self.version == 'v1':
+             return self.v1(avg_features, semantic_features, h_sent)
+         elif self.version == 'v2':
+             return self.v2(avg_features, semantic_features, h_sent)
+         elif self.version == 'v3':
+             return self.v3(avg_features, semantic_features, h_sent)
+         elif self.version == 'v4':
+             return self.v4(avg_features, semantic_features, h_sent)
+         elif self.version == 'v5':
+             return self.v5(avg_features, semantic_features, h_sent)
+
+     def v1(self, avg_features, semantic_features, h_sent) -> object:
+         """
+         only training
+         :rtype: object
+         """
+         W_v = self.bn_v(self.W_v(avg_features))
+         W_v_h = self.bn_v_h(self.W_v_h(h_sent.squeeze(1)))
+
+         alpha_v = self.softmax(self.bn_v_att(self.W_v_att(self.tanh(W_v + W_v_h))))
+         v_att = torch.mul(alpha_v, avg_features)
+
+         W_a_h = self.bn_a_h(self.W_a_h(h_sent))
+         W_a = self.bn_a(self.W_a(semantic_features))
+         alpha_a = self.softmax(self.bn_a_att(self.W_a_att(self.tanh(torch.add(W_a_h, W_a)))))
+         a_att = torch.mul(alpha_a, semantic_features).sum(1)
+
+         ctx = self.W_fc(torch.cat([v_att, a_att], dim=1))
+
+         return ctx, alpha_v, alpha_a
+
+     def v2(self, avg_features, semantic_features, h_sent) -> object:
+         """
+         no bn
+         :rtype: object
+         """
+         W_v = self.W_v(avg_features)
+         W_v_h = self.W_v_h(h_sent.squeeze(1))
+
+         alpha_v = self.softmax(self.W_v_att(self.tanh(W_v + W_v_h)))
+         v_att = torch.mul(alpha_v, avg_features)
+
+         W_a_h = self.W_a_h(h_sent)
+         W_a = self.W_a(semantic_features)
+         alpha_a = self.softmax(self.W_a_att(self.tanh(torch.add(W_a_h, W_a))))
+         a_att = torch.mul(alpha_a, semantic_features).sum(1)
+
+         ctx = self.W_fc(torch.cat([v_att, a_att], dim=1))
+
+         return ctx, alpha_v, alpha_a
+
+     def v3(self, avg_features, semantic_features, h_sent) -> object:
+         """
+
+         :rtype: object
+         """
+         W_v = self.bn_v(self.W_v(avg_features))
+         W_v_h = self.bn_v_h(self.W_v_h(h_sent.squeeze(1)))
+
+         alpha_v = self.softmax(self.W_v_att(self.tanh(W_v + W_v_h)))
+         v_att = torch.mul(alpha_v, avg_features)
+
+         W_a_h = self.bn_a_h(self.W_a_h(h_sent))
+         W_a = self.bn_a(self.W_a(semantic_features))
+         alpha_a = self.softmax(self.W_a_att(self.tanh(torch.add(W_a_h, W_a))))
+         a_att = torch.mul(alpha_a, semantic_features).sum(1)
+
+         ctx = self.W_fc(torch.cat([v_att, a_att], dim=1))
+
+         return ctx, alpha_v, alpha_a
+
+     def v4(self, avg_features, semantic_features, h_sent):
+         W_v = self.W_v(avg_features)
+         W_v_h = self.W_v_h(h_sent.squeeze(1))
+
+         alpha_v = self.softmax(self.W_v_att(self.tanh(torch.add(W_v, W_v_h))))
+         v_att = torch.mul(alpha_v, avg_features)
+
+         W_a_h = self.W_a_h(h_sent)
+         W_a = self.W_a(semantic_features)
+         alpha_a = self.softmax(self.W_a_att(self.tanh(torch.add(W_a_h, W_a))))
+         a_att = torch.mul(alpha_a, semantic_features).sum(1)
+
+         ctx = self.W_fc(torch.cat([v_att, a_att], dim=1))
+
+         return ctx, alpha_v, alpha_a
+
+     def v5(self, avg_features, semantic_features, h_sent):
+         W_v = self.W_v(avg_features)
+         W_v_h = self.W_v_h(h_sent.squeeze(1))
+
+         alpha_v = self.softmax(self.W_v_att(self.tanh(self.bn_v(torch.add(W_v, W_v_h)))))
+         v_att = torch.mul(alpha_v, avg_features)
+
+         W_a_h = self.W_a_h(h_sent)
+         W_a = self.W_a(semantic_features)
+         alpha_a = self.softmax(self.W_a_att(self.tanh(self.bn_a(torch.add(W_a_h, W_a)))))
+         a_att = torch.mul(alpha_a, semantic_features).sum(1)
+
+         ctx = self.W_fc(torch.cat([v_att, a_att], dim=1))
+
+         return ctx, alpha_v, alpha_a
+
+
+ class SentenceLSTM(nn.Module):
+     def __init__(self,
+                  version='v1',
+                  embed_size=512,
+                  hidden_size=512,
+                  num_layers=1,
+                  dropout=0.3,
+                  momentum=0.1):
+         super(SentenceLSTM, self).__init__()
+         self.version = version
+
+         self.lstm = nn.LSTM(input_size=embed_size,
+                             hidden_size=hidden_size,
+                             num_layers=num_layers,
+                             dropout=dropout)
+
+         self.W_t_h = nn.Linear(in_features=hidden_size,
+                                out_features=embed_size,
+                                bias=True)
+         self.bn_t_h = nn.BatchNorm1d(num_features=1, momentum=momentum)
+
+         self.W_t_ctx = nn.Linear(in_features=embed_size,
+                                  out_features=embed_size,
+                                  bias=True)
+         self.bn_t_ctx = nn.BatchNorm1d(num_features=1, momentum=momentum)
+
+         self.W_stop_s_1 = nn.Linear(in_features=hidden_size,
+                                     out_features=embed_size,
+                                     bias=True)
+         self.bn_stop_s_1 = nn.BatchNorm1d(num_features=1, momentum=momentum)
+
+         self.W_stop_s = nn.Linear(in_features=hidden_size,
+                                   out_features=embed_size,
+                                   bias=True)
+         self.bn_stop_s = nn.BatchNorm1d(num_features=1, momentum=momentum)
+
+         self.W_stop = nn.Linear(in_features=embed_size,
+                                 out_features=2,
+                                 bias=True)
+         self.bn_stop = nn.BatchNorm1d(num_features=1, momentum=momentum)
+
+         self.W_topic = nn.Linear(in_features=embed_size,
+                                  out_features=embed_size,
+                                  bias=True)
+         self.bn_topic = nn.BatchNorm1d(num_features=1, momentum=momentum)
+
+         self.sigmoid = nn.Sigmoid()
+         self.tanh = nn.Tanh()
+         self.__init_weight()
+
+     def __init_weight(self):
+         self.W_t_h.weight.data.uniform_(-0.1, 0.1)
+         self.W_t_h.bias.data.fill_(0)
+
+         self.W_t_ctx.weight.data.uniform_(-0.1, 0.1)
+         self.W_t_ctx.bias.data.fill_(0)
+
+         self.W_stop_s_1.weight.data.uniform_(-0.1, 0.1)
+         self.W_stop_s_1.bias.data.fill_(0)
+
+         self.W_stop_s.weight.data.uniform_(-0.1, 0.1)
+         self.W_stop_s.bias.data.fill_(0)
+
+         self.W_stop.weight.data.uniform_(-0.1, 0.1)
+         self.W_stop.bias.data.fill_(0)
+
+         self.W_topic.weight.data.uniform_(-0.1, 0.1)
+         self.W_topic.bias.data.fill_(0)
+
+     def forward(self, ctx, prev_hidden_state, states=None) -> object:
+         """
+         :rtype: object
+         """
+         if self.version == 'v1':
+             return self.v1(ctx, prev_hidden_state, states)
+         elif self.version == 'v2':
+             return self.v2(ctx, prev_hidden_state, states)
+         elif self.version == 'v3':
+             return self.v3(ctx, prev_hidden_state, states)
+
+     def v1(self, ctx, prev_hidden_state, states=None):
+         """
+         v1 (only training)
+         :param ctx:
+         :param prev_hidden_state:
+         :param states:
+         :return:
+         """
+         ctx = ctx.unsqueeze(1)
+         hidden_state, states = self.lstm(ctx, states)
+         topic = self.W_topic(self.sigmoid(self.bn_t_h(self.W_t_h(hidden_state))
+                                           + self.bn_t_ctx(self.W_t_ctx(ctx))))
+         p_stop = self.W_stop(self.sigmoid(self.bn_stop_s_1(self.W_stop_s_1(prev_hidden_state))
+                                           + self.bn_stop_s(self.W_stop_s(hidden_state))))
+         return topic, p_stop, hidden_state, states
+
+     def v2(self, ctx, prev_hidden_state, states=None):
+         """
+         v2
+         :rtype: object
+         """
+         ctx = ctx.unsqueeze(1)
+         hidden_state, states = self.lstm(ctx, states)
+         topic = self.bn_topic(self.W_topic(self.tanh(self.bn_t_h(self.W_t_h(hidden_state)
+                                                                  + self.W_t_ctx(ctx)))))
+         p_stop = self.bn_stop(self.W_stop(self.tanh(self.bn_stop_s(self.W_stop_s_1(prev_hidden_state)
+                                                                    + self.W_stop_s(hidden_state)))))
+         return topic, p_stop, hidden_state, states
+
+     def v3(self, ctx, prev_hidden_state, states=None):
+         """
+         v3
+         :rtype: object
+         """
+         ctx = ctx.unsqueeze(1)
+         hidden_state, states = self.lstm(ctx, states)
+         topic = self.W_topic(self.tanh(self.W_t_h(hidden_state) + self.W_t_ctx(ctx)))
+         p_stop = self.W_stop(self.tanh(self.W_stop_s_1(prev_hidden_state) + self.W_stop_s(hidden_state)))
+         return topic, p_stop, hidden_state, states
+
+
+ class WordLSTM(nn.Module):
+     def __init__(self,
+                  embed_size,
+                  hidden_size,
+                  vocab_size,
+                  num_layers,
+                  n_max=50):
+         super(WordLSTM, self).__init__()
+         self.embed = nn.Embedding(vocab_size, embed_size)
+         self.lstm = nn.LSTM(embed_size, hidden_size, num_layers, batch_first=True)
+         self.linear = nn.Linear(hidden_size, vocab_size)
+         self.__init_weights()
+         self.n_max = n_max
+         self.vocab_size = vocab_size
+
+     def __init_weights(self):
+         self.embed.weight.data.uniform_(-0.1, 0.1)
+         self.linear.weight.data.uniform_(-0.1, 0.1)
+         self.linear.bias.data.fill_(0)
+
+     def forward(self, topic_vec, captions):
+         embeddings = self.embed(captions)
+         embeddings = torch.cat((topic_vec, embeddings), 1)
+         hidden, _ = self.lstm(embeddings)
+         outputs = self.linear(hidden[:, -1, :])
+         return outputs
+
+     def sample(self, features, start_tokens):
+         sampled_ids = np.zeros((np.shape(features)[0], self.n_max))
+         sampled_ids[:, 0] = start_tokens.view(-1, )
+         predicted = start_tokens
+         embeddings = features
+         embeddings = embeddings
+
+         for i in range(1, self.n_max):
+             predicted = self.embed(predicted)
+             embeddings = torch.cat([embeddings, predicted], dim=1)
+             hidden_states, _ = self.lstm(embeddings)
+             hidden_states = hidden_states[:, -1, :]
+             outputs = self.linear(hidden_states)
+             predicted = torch.max(outputs, 1)[1]
+             sampled_ids[:, i] = predicted
+             predicted = predicted.unsqueeze(1)
+         return sampled_ids
+
+
+ if __name__ == '__main__':
+     import torchvision.transforms as transforms
+
+     import warnings
+     warnings.filterwarnings("ignore")
+     #
+     extractor = VisualFeatureExtractor(model_name='resnet152')
+     mlc = MLC(fc_in_features=extractor.out_features)
+     co_att = CoAttention(visual_size=extractor.out_features)
+     sent_lstm = SentenceLSTM()
+     word_lstm = WordLSTM(embed_size=512, hidden_size=512, vocab_size=100, num_layers=1)
+
+     images = torch.randn((4, 3, 224, 224))
+     captions = torch.ones((4, 10)).long()
+     hidden_state = torch.randn((4, 1, 512))
+
+     # # image_file = '../data/images/CXR2814_IM-1239-1001.png'
+     # # # images = Image.open(image_file).convert('RGB')
+     # # # captions = torch.ones((1, 10)).long()
+     # # # hidden_state = torch.randn((10, 512))
+     # #
+     # norm = transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
+     #
+     # transform = transforms.Compose([
+     #     transforms.Resize(256),
+     #     transforms.TenCrop(224),
+     #     transforms.Lambda(lambda crops: torch.stack([norm(transforms.ToTensor()(crop)) for crop in crops])),
+     # ])
+
+     # images = transform(images)
+     # images.unsqueeze_(0)
+     #
+     # # bs, ncrops, c, h, w = images.size()
+     # # images = images.view(-1, c, h, w)
+     #
+     print("images:{}".format(images.shape))
+     print("captions:{}".format(captions.shape))
+     print("hidden_states:{}".format(hidden_state.shape))
+
+     visual_features, avg_features = extractor.forward(images)
+
+     print("visual_features:{}".format(visual_features.shape))
+     print("avg features:{}".format(avg_features.shape))
+
+     tags, semantic_features = mlc.forward(avg_features)
+
+     print("tags:{}".format(tags.shape))
+     print("semantic_features:{}".format(semantic_features.shape))
+
+     ctx, alpht_v, alpht_a = co_att.forward(avg_features, semantic_features, hidden_state)
+
+     print("ctx:{}".format(ctx.shape))
+     print("alpht_v:{}".format(alpht_v.shape))
+     print("alpht_a:{}".format(alpht_a.shape))
+
+     topic, p_stop, hidden_state, states = sent_lstm.forward(ctx, hidden_state)
+     # p_stop_avg = p_stop.view(bs, ncrops, -1).mean(1)
+
+     print("Topic:{}".format(topic.shape))
+     print("P_STOP:{}".format(p_stop.shape))
+     # print("P_stop_avg:{}".format(p_stop_avg.shape))
+
+     words = word_lstm.forward(topic, captions)
+     print("words:{}".format(words.shape))
+
+     cam = torch.mul(visual_features, alpht_v.view(alpht_v.shape[0], alpht_v.shape[1], 1, 1)).sum(1)
+     cam.squeeze_()
+     cam = cam.cpu().data.numpy()
+     for i in range(cam.shape[0]):
+         heatmap = cam[i]
+         heatmap = heatmap / np.max(heatmap)
+         print(heatmap.shape)
pytorch_model.bin ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:d4b052d087f705ba06b16aa03c01dfdf37f36f0f8ab7b136cda1524bba8ab09d
+ size 343280753
sentence.pth.tar ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:33aa4fa7339a3958307299a75d72cd1d6d0d144cadb101cfbe5c40af205741bc
+ size 22081920
tester.py ADDED
@@ -0,0 +1,283 @@
+
+
+ import time
+ import pickle
+ import torch
+ import torchvision.transforms as transforms
+ from torch.utils.data import DataLoader
+ from torch.autograd import Variable
+ from PIL import Image
+ import cv2
+
+ from utils.models import *
+ from utils.dataset import *
+ from utils.loss import *
+ from utils.build_tag import *
+
+
+ class CaptionSampler(object):
+     def __init__(self):
+         # Default configuration values
+         self.args = {
+             "model_dir": "/Users/jkottu/Desktop/image-captioning-chest-xrays/vit-coatten",
+             "image_dir": "./data/images",
+             "caption_json": "data/new_data/captions.json",
+             "vocab_path": "/Users/jkottu/Desktop/image-captioning-chest-xrays/vit-coatten/vocab.pkl",
+             "file_lists": "data/new_data/test_data.txt",
+             "load_model_path": "train_best_loss.pth.tar",
+             "resize": 224,
+             "cam_size": 224,
+             "generate_dir": "cam",
+             "result_path": "results",
+             "result_name": "debug",
+             "momentum": 0.1,
+             "visual_model_name": "densenet201",
+             "pretrained": False,
+             "classes": 210,
+             "sementic_features_dim": 512,
+             "k": 10,
+             "attention_version": "v4",
+             "embed_size": 512,
+             "hidden_size": 512,
+             "sent_version": "v1",
+             "sentence_num_layers": 2,
+             "dropout": 0.1,
+             "word_num_layers": 1,
+             "s_max": 10,
+             "n_max": 30,
+             "batch_size": 8,
+             "lambda_tag": 10000,
+             "lambda_stop": 10,
+             "lambda_word": 1,
+             "cuda": False  # Keep CUDA disabled by default
+         }
+
+         self.vocab = self.__init_vocab()
+         self.tagger = self.__init_tagger()
+         self.transform = self.__init_transform()
+         self.model_state_dict = self.__load_mode_state_dict()
+
+         self.extractor = self.__init_visual_extractor()
+         self.mlc = self.__init_mlc()
+         self.co_attention = self.__init_co_attention()
+         self.sentence_model = self.__init_sentence_model()
+         self.word_model = self.__init_word_word()
+
+         self.ce_criterion = self._init_ce_criterion()
+         self.mse_criterion = self._init_mse_criterion()
+
+     @staticmethod
+     def _init_ce_criterion():
+         return nn.CrossEntropyLoss(size_average=False, reduce=False)
+
+     @staticmethod
+     def _init_mse_criterion():
+         return nn.MSELoss()
+
+
+     def sample(self, image_file):
+         self.extractor.eval()
+         self.mlc.eval()
+         self.co_attention.eval()
+         self.sentence_model.eval()
+         self.word_model.eval()
+
+         imageData = Image.open(image_file).convert('RGB')
+         imageData = self.transform(imageData)
+         imageData = imageData.unsqueeze_(0)
+
+         print(imageData.shape)
+
+         image = self.__to_var(imageData, requires_grad=False)
+
+         visual_features, avg_features = self.extractor.forward(image)
+
+         tags, semantic_features = self.mlc(avg_features)
+         sentence_states = None
+         prev_hidden_states = self.__to_var(torch.zeros(image.shape[0], 1, self.args["hidden_size"]))
+
+         pred_sentences = []
+
+         for i in range(self.args["s_max"]):
+             ctx, alpha_v, alpha_a = self.co_attention.forward(avg_features, semantic_features, prev_hidden_states)
+             topic, p_stop, hidden_state, sentence_states = self.sentence_model.forward(ctx,
+                                                                                        prev_hidden_states,
+                                                                                        sentence_states)
+             p_stop = p_stop.squeeze(1)
+             p_stop = torch.max(p_stop, 1)[1].unsqueeze(1)
+
+             start_tokens = np.zeros((topic.shape[0], 1))
+             start_tokens[:, 0] = self.vocab('<start>')
+             start_tokens = self.__to_var(torch.Tensor(start_tokens).long(), requires_grad=False)
+
+             sampled_ids = self.word_model.sample(topic, start_tokens)
+             prev_hidden_states = hidden_state
+
+             sampled_ids = sampled_ids * p_stop.numpy()
+
+
+             pred_sentences.append(self.__vec2sent(sampled_ids[0]))
+
+         return pred_sentences
+
+
+     def __init_cam_path(self, image_file):
+         generate_dir = os.path.join(self.args["model_dir"], self.args["generate_dir"])
+         if not os.path.exists(generate_dir):
+             os.makedirs(generate_dir)
+
+         image_dir = os.path.join(generate_dir, image_file)
+
+         if not os.path.exists(image_dir):
+             os.makedirs(image_dir)
+         return image_dir
+
+     def __save_json(self, result):
+         result_path = os.path.join(self.args["model_dir"], self.args["result_path"])
+         if not os.path.exists(result_path):
+             os.makedirs(result_path)
+         with open(os.path.join(result_path, '{}.json'.format(self.args["result_name"])), 'w') as f:
+             json.dump(result, f)
+
+     def __load_mode_state_dict(self):
+         try:
+             model_state_dict = torch.load(os.path.join(self.args["model_dir"], self.args["load_model_path"]), map_location=torch.device('cpu'))
+             print("[Load Model-{} Succeed!]".format(self.args["load_model_path"]))
+             print("Load From Epoch {}".format(model_state_dict['epoch']))
+             return model_state_dict
+         except Exception as err:
+             print("[Load Model Failed] {}".format(err))
+             raise err
+
+     def __init_tagger(self):
+         return Tag()
+
+     def __vec2sent(self, array):
+         sampled_caption = []
+         for word_id in array:
+             word = self.vocab.get_word_by_id(word_id)
+             if word == '<start>':
+                 continue
+             if word == '<end>' or word == '<pad>':
+                 break
+             sampled_caption.append(word)
+         return ' '.join(sampled_caption)
+
+     def __init_vocab(self):
+         with open(self.args["vocab_path"], 'rb') as f:
+             vocab = pickle.load(f)
+         return vocab
+
+     def __init_data_loader(self, file_list):
+         data_loader = get_loader(image_dir=self.args["image_dir"],
+                                  caption_json=self.args["caption_json"],
+                                  file_list=file_list,
+                                  vocabulary=self.vocab,
+                                  transform=self.transform,
+                                  batch_size=self.args["batch_size"],
+                                  s_max=self.args["s_max"],
+                                  n_max=self.args["n_max"],
+                                  shuffle=False)
+         return data_loader
+
+     def __init_transform(self):
+         transform = transforms.Compose([
+             transforms.Resize((self.args["resize"], self.args["resize"])),
+             transforms.ToTensor(),
+             transforms.Normalize((0.485, 0.456, 0.406),
+                                  (0.229, 0.224, 0.225))])
+         return transform
+
+     def __to_var(self, x, requires_grad=True):
+         if self.args["cuda"]:
+             x = x.cuda()
+         return Variable(x, requires_grad=requires_grad)
+
+     def __init_visual_extractor(self):
+         model = VisualFeatureExtractor(model_name=self.args["visual_model_name"],
+                                        pretrained=self.args["pretrained"])
+
+         if self.model_state_dict is not None:
+             print("Visual Extractor Loaded!")
+             model.load_state_dict(self.model_state_dict['extractor'])
+
+         if self.args["cuda"]:
+             model = model.cuda()
+
+         return model
+
+     def __init_mlc(self):
+         model = MLC(classes=self.args["classes"],
+                     sementic_features_dim=self.args["sementic_features_dim"],
+                     fc_in_features=self.extractor.out_features,
+                     k=self.args["k"])
+
+         if self.model_state_dict is not None:
+             print("MLC Loaded!")
+             model.load_state_dict(self.model_state_dict['mlc'])
+
+         if self.args["cuda"]:
+             model = model.cuda()
+
+         return model
+
+     def __init_co_attention(self):
+         model = CoAttention(version=self.args["attention_version"],
+                             embed_size=self.args["embed_size"],
+                             hidden_size=self.args["hidden_size"],
+                             visual_size=self.extractor.out_features,
+                             k=self.args["k"],
+                             momentum=self.args["momentum"])
+
+         if self.model_state_dict is not None:
+             print("Co-Attention Loaded!")
+             model.load_state_dict(self.model_state_dict['co_attention'])
+
+         if self.args["cuda"]:
+             model = model.cuda()
+
+         return model
+
+     def __init_sentence_model(self):
+         model = SentenceLSTM(version=self.args["sent_version"],
+                              embed_size=self.args["embed_size"],
+                              hidden_size=self.args["hidden_size"],
+                              num_layers=self.args["sentence_num_layers"],
+                              dropout=self.args["dropout"],
+                              momentum=self.args["momentum"])
+
+         if self.model_state_dict is not None:
+             print("Sentence Model Loaded!")
+             model.load_state_dict(self.model_state_dict['sentence_model'])
+
+         if self.args["cuda"]:
+             model = model.cuda()
+
+         return model
+
+     def __init_word_word(self):
+         model = WordLSTM(vocab_size=len(self.vocab),
+                          embed_size=self.args["embed_size"],
+                          hidden_size=self.args["hidden_size"],
+                          num_layers=self.args["word_num_layers"],
+                          n_max=self.args["n_max"])
+
+         if self.model_state_dict is not None:
+             print("Word Model Loaded!")
+             model.load_state_dict(self.model_state_dict['word_model'])
+
+         if self.args["cuda"]:
+             model = model.cuda()
+
+         return model
+
+
+
+ def main(image):
+     sampler = CaptionSampler()
+     # image = 'sample_images/CXR195_IM-0618-1001.png'
+     caption = sampler.sample(image)
+     print(caption[0])
+
+     return caption[0]
+
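
A minimal sketch of calling the sampler above on a single image; the PNG path is the placeholder from the commented-out line in main, the import assumes tester.py is on the Python path, and CaptionSampler() will only construct if the checkpoint, vocab.pkl, and ViT weights referenced in self.args exist at their hard-coded paths:

    from tester import main

    # Hypothetical input; any frontal chest X-ray readable by PIL should work.
    report = main('sample_images/CXR195_IM-0618-1001.png')
    print(report)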
train_best_loss.pth.tar ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:16673cee2882ab65a5f2e4fb23cb1b2d25cd484713ece7cdc910d791b5b79a59
+ size 1535115128
val_best_loss.pth.tar ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:74099c8e559e56a355ed8e9d8d7b1408ad98be7abc07f21122e5e57e93cf6dfc
+ size 1535112572
vocab.pkl ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:7b87d71ea39483d3f3af078210a08d3b685f92ee679aa3d6030904844587d4d8
+ size 31925
word.pth.tar ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:41a13a1ba5b085ab7d3848cb42dc7cbbc866d7f00fad4e8a838c1369e511b949
+ size 13216728