import torch
import torch.nn as nn
import torchvision
import numpy as np
from torch.autograd import Variable
import torchvision.models as models
import transformers
import torchvision.transforms
import torchxrayvision as xrv
from transformers import ViTModel, ViTConfig


class VisualFeatureExtractor(nn.Module):
    def __init__(self, model_name='densenet201', pretrained=False):
        super(VisualFeatureExtractor, self).__init__()
        # NOTE: model_name is hardcoded to 'chexnet' here, so the ViT branch in
        # __get_model() is always used regardless of the argument passed by the caller.
        self.model_name = 'chexnet'
        self.pretrained = pretrained
        self.model, self.out_features, self.avg_func, self.bn, self.linear = self.__get_model()
        self.activation = nn.ReLU()
    def __get_model(self):
        model = None
        out_features = None
        func = None
        if self.model_name == 'resnet152':
            resnet = models.resnet152(pretrained=self.pretrained)
            modules = list(resnet.children())[:-2]
            model = nn.Sequential(*modules)
            out_features = resnet.fc.in_features
            func = torch.nn.AvgPool2d(kernel_size=7, stride=1, padding=0)
        elif self.model_name == 'densenet201':
            densenet = models.densenet201(pretrained=self.pretrained)
            modules = list(densenet.features)
            model = nn.Sequential(*modules)
            func = torch.nn.AvgPool2d(kernel_size=7, stride=1, padding=0)
            out_features = densenet.classifier.in_features
        elif self.model_name == 'chexnet':
            print("vit chest xray pretrained model loading")
            # Load the Vision Transformer (ViT) model configuration
            config = ViTConfig.from_pretrained('nickmuchi/vit-finetuned-chest-xray-pneumonia')
            # Initialize the ViT model with the specific configuration
            vit_model = ViTModel(config)
            # Load the state dict, excluding 'classifier.bias' and 'classifier.weight'
            state_dict = torch.load('model/pytorch_model.bin', map_location=torch.device('cpu'))
            state_dict = {k: v for k, v in state_dict.items() if not k.startswith('classifier')}
            vit_model.load_state_dict(state_dict, strict=False)
            model = vit_model
            out_features = config.hidden_size
        linear = nn.Linear(in_features=out_features, out_features=out_features)
        bn = nn.BatchNorm1d(num_features=out_features, momentum=0.1)
        return model, out_features, func, bn, linear
    def forward(self, images):
        """
        :param images: Input images
        :return: visual_features, avg_features
        """
        model_output = self.model(images)
        # Extract the pooler_output
        pooler_output = model_output.pooler_output
        # Apply the linear layer, batch normalization, and activation
        avg_features = self.activation(self.bn(self.linear(pooler_output)))
        return model_output.last_hidden_state, avg_features
    # def forward(self, images):
    #     """
    #     :param images:
    #     :return:
    #     """
    #     visual_features = self.model(images)
    #     avg_features = self.avg_func(visual_features).squeeze()
    #     # avg_features = self.activation(self.bn(self.linear(visual_features)))
    #     return visual_features, avg_features
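
# Illustrative shape walk-through for VisualFeatureExtractor (a hedged sketch, not part of
# the original pipeline; it assumes the hardcoded ViT backbone above and a toy batch of 4
# RGB images at 224x224):
#
#     extractor = VisualFeatureExtractor()        # model_name is forced to 'chexnet'
#     imgs = torch.randn(4, 3, 224, 224)
#     visual, avg = extractor(imgs)
#     # visual: (4, 197, 768) -- ViT last_hidden_state ([CLS] token + 196 patch tokens)
#     # avg:    (4, 768)      -- pooler_output passed through Linear -> BatchNorm1d -> ReLU
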
class MLC(nn.Module):
    def __init__(self,
                 classes=210,
                 sementic_features_dim=512,
                 fc_in_features=2048,
                 k=10):
        super(MLC, self).__init__()
        pretrained_model_name = "nickmuchi/vit-finetuned-chest-xray-pneumonia"
        vit_config = ViTConfig.from_pretrained(pretrained_model_name)
        # Note: the ViT is instantiated here, but only the linear classifier and the tag
        # embedding below are used in forward(); fc_in_features is likewise unused.
        self.vit = ViTModel(vit_config)
        # Adjust the classifier to your number of classes
        self.classifier = nn.Linear(in_features=vit_config.hidden_size, out_features=classes)
        self.embed = nn.Embedding(classes, sementic_features_dim)
        self.k = k
        self.sigmoid = nn.Sigmoid()
        self.__init_weight()

    def __init_weight(self):
        nn.init.xavier_uniform_(self.classifier.weight)
        if self.classifier.bias is not None:
            self.classifier.bias.data.fill_(0)

    def forward(self, avg_features):
        # Per-tag probabilities, then embeddings of the k most probable tags
        tags = self.sigmoid(self.classifier(avg_features))
        semantic_features = self.embed(torch.topk(tags, self.k)[1])
        return tags, semantic_features
# class MLC(nn.Module):
#     def __init__(self,
#                  classes=210,
#                  sementic_features_dim=512,
#                  fc_in_features=2048,
#                  k=10):
#         super(MLC, self).__init__()
#         self.classifier = nn.Linear(in_features=fc_in_features, out_features=classes)
#         self.embed = nn.Embedding(classes, sementic_features_dim)
#         self.k = k
#         self.sigmoid = nn.Sigmoid()
#         self.__init_weight()
#
#     def __init_weight(self):
#         # Example: Initialize weights with a different strategy
#         nn.init.xavier_uniform_(self.classifier.weight)
#         if self.classifier.bias is not None:
#             self.classifier.bias.data.fill_(0)
#
#     def forward(self, avg_features):
#         tags = self.sigmoid(self.classifier(avg_features))
#         semantic_features = self.embed(torch.topk(tags, self.k)[1])
#         return tags, semantic_features
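
# Illustrative MLC usage (a hedged sketch with assumed toy shapes; not part of the
# original file):
#
#     mlc = MLC(classes=210, k=10)
#     avg = torch.randn(4, 768)                   # matches the ViT hidden size used above
#     tags, sem = mlc(avg)
#     # tags: (4, 210)    -- sigmoid probability per tag
#     # sem:  (4, 10, 512) -- embeddings of the 10 highest-scoring tags per image
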
class CoAttention(nn.Module):
    def __init__(self,
                 version='v1',
                 embed_size=512,
                 hidden_size=512,
                 visual_size=2048,
                 k=10,
                 momentum=0.1):
        super(CoAttention, self).__init__()
        self.version = version
        self.W_v = nn.Linear(in_features=visual_size, out_features=visual_size)
        self.bn_v = nn.BatchNorm1d(num_features=visual_size, momentum=momentum)
        self.W_v_h = nn.Linear(in_features=hidden_size, out_features=visual_size)
        self.bn_v_h = nn.BatchNorm1d(num_features=visual_size, momentum=momentum)
        self.W_v_att = nn.Linear(in_features=visual_size, out_features=visual_size)
        self.bn_v_att = nn.BatchNorm1d(num_features=visual_size, momentum=momentum)
        self.W_a = nn.Linear(in_features=hidden_size, out_features=hidden_size)
        self.bn_a = nn.BatchNorm1d(num_features=k, momentum=momentum)
        self.W_a_h = nn.Linear(in_features=hidden_size, out_features=hidden_size)
        self.bn_a_h = nn.BatchNorm1d(num_features=1, momentum=momentum)
        self.W_a_att = nn.Linear(in_features=hidden_size, out_features=hidden_size)
        self.bn_a_att = nn.BatchNorm1d(num_features=k, momentum=momentum)
        # self.W_fc = nn.Linear(in_features=visual_size, out_features=embed_size)  # for v3
        self.W_fc = nn.Linear(in_features=visual_size + hidden_size, out_features=embed_size)
        self.bn_fc = nn.BatchNorm1d(num_features=embed_size, momentum=momentum)
        self.tanh = nn.Tanh()
        # dim is stated explicitly here; the original nn.Softmax() relied on the deprecated
        # implicit-dim behaviour, and the last dimension is assumed to be the intended axis.
        self.softmax = nn.Softmax(dim=-1)
        self.__init_weight()
    def __init_weight(self):
        self.W_v.weight.data.uniform_(-0.1, 0.1)
        self.W_v.bias.data.fill_(0)
        self.W_v_h.weight.data.uniform_(-0.1, 0.1)
        self.W_v_h.bias.data.fill_(0)
        self.W_v_att.weight.data.uniform_(-0.1, 0.1)
        self.W_v_att.bias.data.fill_(0)
        self.W_a.weight.data.uniform_(-0.1, 0.1)
        self.W_a.bias.data.fill_(0)
        self.W_a_h.weight.data.uniform_(-0.1, 0.1)
        self.W_a_h.bias.data.fill_(0)
        self.W_a_att.weight.data.uniform_(-0.1, 0.1)
        self.W_a_att.bias.data.fill_(0)
        self.W_fc.weight.data.uniform_(-0.1, 0.1)
        self.W_fc.bias.data.fill_(0)

    def forward(self, avg_features, semantic_features, h_sent):
        if self.version == 'v1':
            return self.v1(avg_features, semantic_features, h_sent)
        elif self.version == 'v2':
            return self.v2(avg_features, semantic_features, h_sent)
        elif self.version == 'v3':
            return self.v3(avg_features, semantic_features, h_sent)
        elif self.version == 'v4':
            return self.v4(avg_features, semantic_features, h_sent)
        elif self.version == 'v5':
            return self.v5(avg_features, semantic_features, h_sent)
    def v1(self, avg_features, semantic_features, h_sent) -> object:
        """
        only training
        :rtype: object
        """
        W_v = self.bn_v(self.W_v(avg_features))
        W_v_h = self.bn_v_h(self.W_v_h(h_sent.squeeze(1)))
        alpha_v = self.softmax(self.bn_v_att(self.W_v_att(self.tanh(W_v + W_v_h))))
        v_att = torch.mul(alpha_v, avg_features)
        W_a_h = self.bn_a_h(self.W_a_h(h_sent))
        W_a = self.bn_a(self.W_a(semantic_features))
        alpha_a = self.softmax(self.bn_a_att(self.W_a_att(self.tanh(torch.add(W_a_h, W_a)))))
        a_att = torch.mul(alpha_a, semantic_features).sum(1)
        ctx = self.W_fc(torch.cat([v_att, a_att], dim=1))
        return ctx, alpha_v, alpha_a

    def v2(self, avg_features, semantic_features, h_sent) -> object:
        """
        no bn
        :rtype: object
        """
        W_v = self.W_v(avg_features)
        W_v_h = self.W_v_h(h_sent.squeeze(1))
        alpha_v = self.softmax(self.W_v_att(self.tanh(W_v + W_v_h)))
        v_att = torch.mul(alpha_v, avg_features)
        W_a_h = self.W_a_h(h_sent)
        W_a = self.W_a(semantic_features)
        alpha_a = self.softmax(self.W_a_att(self.tanh(torch.add(W_a_h, W_a))))
        a_att = torch.mul(alpha_a, semantic_features).sum(1)
        ctx = self.W_fc(torch.cat([v_att, a_att], dim=1))
        return ctx, alpha_v, alpha_a

    def v3(self, avg_features, semantic_features, h_sent) -> object:
        """
        :rtype: object
        """
        W_v = self.bn_v(self.W_v(avg_features))
        W_v_h = self.bn_v_h(self.W_v_h(h_sent.squeeze(1)))
        alpha_v = self.softmax(self.W_v_att(self.tanh(W_v + W_v_h)))
        v_att = torch.mul(alpha_v, avg_features)
        W_a_h = self.bn_a_h(self.W_a_h(h_sent))
        W_a = self.bn_a(self.W_a(semantic_features))
        alpha_a = self.softmax(self.W_a_att(self.tanh(torch.add(W_a_h, W_a))))
        a_att = torch.mul(alpha_a, semantic_features).sum(1)
        ctx = self.W_fc(torch.cat([v_att, a_att], dim=1))
        return ctx, alpha_v, alpha_a

    def v4(self, avg_features, semantic_features, h_sent):
        W_v = self.W_v(avg_features)
        W_v_h = self.W_v_h(h_sent.squeeze(1))
        alpha_v = self.softmax(self.W_v_att(self.tanh(torch.add(W_v, W_v_h))))
        v_att = torch.mul(alpha_v, avg_features)
        W_a_h = self.W_a_h(h_sent)
        W_a = self.W_a(semantic_features)
        alpha_a = self.softmax(self.W_a_att(self.tanh(torch.add(W_a_h, W_a))))
        a_att = torch.mul(alpha_a, semantic_features).sum(1)
        ctx = self.W_fc(torch.cat([v_att, a_att], dim=1))
        return ctx, alpha_v, alpha_a

    def v5(self, avg_features, semantic_features, h_sent):
        W_v = self.W_v(avg_features)
        W_v_h = self.W_v_h(h_sent.squeeze(1))
        alpha_v = self.softmax(self.W_v_att(self.tanh(self.bn_v(torch.add(W_v, W_v_h)))))
        v_att = torch.mul(alpha_v, avg_features)
        W_a_h = self.W_a_h(h_sent)
        W_a = self.W_a(semantic_features)
        alpha_a = self.softmax(self.W_a_att(self.tanh(self.bn_a(torch.add(W_a_h, W_a)))))
        a_att = torch.mul(alpha_a, semantic_features).sum(1)
        ctx = self.W_fc(torch.cat([v_att, a_att], dim=1))
        return ctx, alpha_v, alpha_a
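
# Illustrative CoAttention call (a hedged sketch with assumed shapes, consistent with the
# smoke test in the __main__ block below; not part of the original pipeline):
#
#     co_att = CoAttention(visual_size=768)       # visual_size matches the ViT avg features
#     avg = torch.randn(4, 768)
#     sem = torch.randn(4, 10, 512)               # top-k tag embeddings from MLC
#     h_sent = torch.randn(4, 1, 512)             # previous SentenceLSTM hidden state
#     ctx, alpha_v, alpha_a = co_att(avg, sem, h_sent)
#     # ctx:     (4, 512)     -- joint visual/semantic context fed to the SentenceLSTM
#     # alpha_v: (4, 768)     -- visual attention weights
#     # alpha_a: (4, 10, 512) -- semantic attention weights
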
class SentenceLSTM(nn.Module):
    def __init__(self,
                 version='v1',
                 embed_size=512,
                 hidden_size=512,
                 num_layers=1,
                 dropout=0.3,
                 momentum=0.1):
        super(SentenceLSTM, self).__init__()
        self.version = version
        # Note: unlike WordLSTM below, this LSTM is not batch_first, so the unsqueezed
        # context of shape (batch, 1, embed_size) is consumed as (seq, batch, feature).
        self.lstm = nn.LSTM(input_size=embed_size,
                            hidden_size=hidden_size,
                            num_layers=num_layers,
                            dropout=dropout)
        self.W_t_h = nn.Linear(in_features=hidden_size,
                               out_features=embed_size,
                               bias=True)
        self.bn_t_h = nn.BatchNorm1d(num_features=1, momentum=momentum)
        self.W_t_ctx = nn.Linear(in_features=embed_size,
                                 out_features=embed_size,
                                 bias=True)
        self.bn_t_ctx = nn.BatchNorm1d(num_features=1, momentum=momentum)
        self.W_stop_s_1 = nn.Linear(in_features=hidden_size,
                                    out_features=embed_size,
                                    bias=True)
        self.bn_stop_s_1 = nn.BatchNorm1d(num_features=1, momentum=momentum)
        self.W_stop_s = nn.Linear(in_features=hidden_size,
                                  out_features=embed_size,
                                  bias=True)
        self.bn_stop_s = nn.BatchNorm1d(num_features=1, momentum=momentum)
        self.W_stop = nn.Linear(in_features=embed_size,
                                out_features=2,
                                bias=True)
        self.bn_stop = nn.BatchNorm1d(num_features=1, momentum=momentum)
        self.W_topic = nn.Linear(in_features=embed_size,
                                 out_features=embed_size,
                                 bias=True)
        self.bn_topic = nn.BatchNorm1d(num_features=1, momentum=momentum)
        self.sigmoid = nn.Sigmoid()
        self.tanh = nn.Tanh()
        self.__init_weight()
    def __init_weight(self):
        self.W_t_h.weight.data.uniform_(-0.1, 0.1)
        self.W_t_h.bias.data.fill_(0)
        self.W_t_ctx.weight.data.uniform_(-0.1, 0.1)
        self.W_t_ctx.bias.data.fill_(0)
        self.W_stop_s_1.weight.data.uniform_(-0.1, 0.1)
        self.W_stop_s_1.bias.data.fill_(0)
        self.W_stop_s.weight.data.uniform_(-0.1, 0.1)
        self.W_stop_s.bias.data.fill_(0)
        self.W_stop.weight.data.uniform_(-0.1, 0.1)
        self.W_stop.bias.data.fill_(0)
        self.W_topic.weight.data.uniform_(-0.1, 0.1)
        self.W_topic.bias.data.fill_(0)
    def forward(self, ctx, prev_hidden_state, states=None) -> object:
        """
        :rtype: object
        """
        if self.version == 'v1':
            return self.v1(ctx, prev_hidden_state, states)
        elif self.version == 'v2':
            return self.v2(ctx, prev_hidden_state, states)
        elif self.version == 'v3':
            return self.v3(ctx, prev_hidden_state, states)

    def v1(self, ctx, prev_hidden_state, states=None):
        """
        v1 (only training)
        :param ctx:
        :param prev_hidden_state:
        :param states:
        :return:
        """
        ctx = ctx.unsqueeze(1)
        hidden_state, states = self.lstm(ctx, states)
        topic = self.W_topic(self.sigmoid(self.bn_t_h(self.W_t_h(hidden_state))
                                          + self.bn_t_ctx(self.W_t_ctx(ctx))))
        p_stop = self.W_stop(self.sigmoid(self.bn_stop_s_1(self.W_stop_s_1(prev_hidden_state))
                                          + self.bn_stop_s(self.W_stop_s(hidden_state))))
        return topic, p_stop, hidden_state, states

    def v2(self, ctx, prev_hidden_state, states=None):
        """
        v2
        :rtype: object
        """
        ctx = ctx.unsqueeze(1)
        hidden_state, states = self.lstm(ctx, states)
        topic = self.bn_topic(self.W_topic(self.tanh(self.bn_t_h(self.W_t_h(hidden_state)
                                                                 + self.W_t_ctx(ctx)))))
        p_stop = self.bn_stop(self.W_stop(self.tanh(self.bn_stop_s(self.W_stop_s_1(prev_hidden_state)
                                                                   + self.W_stop_s(hidden_state)))))
        return topic, p_stop, hidden_state, states

    def v3(self, ctx, prev_hidden_state, states=None):
        """
        v3
        :rtype: object
        """
        ctx = ctx.unsqueeze(1)
        hidden_state, states = self.lstm(ctx, states)
        topic = self.W_topic(self.tanh(self.W_t_h(hidden_state) + self.W_t_ctx(ctx)))
        p_stop = self.W_stop(self.tanh(self.W_stop_s_1(prev_hidden_state) + self.W_stop_s(hidden_state)))
        return topic, p_stop, hidden_state, states
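
# Illustrative SentenceLSTM step (a hedged sketch with assumed shapes; not part of the
# original file). One call yields a topic vector and a stop/continue distribution:
#
#     sent_lstm = SentenceLSTM()
#     ctx = torch.randn(4, 512)                   # CoAttention context
#     prev_h = torch.randn(4, 1, 512)             # previous hidden state
#     topic, p_stop, h, states = sent_lstm(ctx, prev_h)
#     # topic:  (4, 1, 512) -- topic vector consumed by WordLSTM
#     # p_stop: (4, 1, 2)   -- logits for continuing vs. stopping sentence generation
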
class WordLSTM(nn.Module):
    def __init__(self,
                 embed_size,
                 hidden_size,
                 vocab_size,
                 num_layers,
                 n_max=50):
        super(WordLSTM, self).__init__()
        self.embed = nn.Embedding(vocab_size, embed_size)
        self.lstm = nn.LSTM(embed_size, hidden_size, num_layers, batch_first=True)
        self.linear = nn.Linear(hidden_size, vocab_size)
        self.__init_weights()
        self.n_max = n_max
        self.vocab_size = vocab_size

    def __init_weights(self):
        self.embed.weight.data.uniform_(-0.1, 0.1)
        self.linear.weight.data.uniform_(-0.1, 0.1)
        self.linear.bias.data.fill_(0)

    def forward(self, topic_vec, captions):
        embeddings = self.embed(captions)
        # Prepend the topic vector to the word embeddings along the time axis
        embeddings = torch.cat((topic_vec, embeddings), 1)
        hidden, _ = self.lstm(embeddings)
        outputs = self.linear(hidden[:, -1, :])
        return outputs
    def sample(self, features, start_tokens):
        sampled_ids = np.zeros((np.shape(features)[0], self.n_max))
        sampled_ids[:, 0] = start_tokens.view(-1, )
        predicted = start_tokens
        embeddings = features
        # Greedy decoding: append the embedding of the last prediction at every step
        for i in range(1, self.n_max):
            predicted = self.embed(predicted)
            embeddings = torch.cat([embeddings, predicted], dim=1)
            hidden_states, _ = self.lstm(embeddings)
            hidden_states = hidden_states[:, -1, :]
            outputs = self.linear(hidden_states)
            predicted = torch.max(outputs, 1)[1]
            sampled_ids[:, i] = predicted
            predicted = predicted.unsqueeze(1)
        return sampled_ids
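
# Illustrative WordLSTM usage (a hedged sketch with assumed toy values; not part of the
# original file). Teacher-forced forward pass vs. greedy sampling:
#
#     word_lstm = WordLSTM(embed_size=512, hidden_size=512, vocab_size=100, num_layers=1)
#     topic = torch.randn(4, 1, 512)              # topic vector from SentenceLSTM
#     captions = torch.ones(4, 10).long()         # teacher-forced word ids
#     logits = word_lstm(topic, captions)         # (4, 100): next-word scores
#     start = torch.ones(4, 1).long()             # assumed <start> token id
#     ids = word_lstm.sample(topic, start)        # (4, 50) numpy array of sampled word ids
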
if __name__ == '__main__':
    import torchvision.transforms as transforms
    import warnings
    warnings.filterwarnings("ignore")

    # Note: model_name='resnet152' is effectively ignored because VisualFeatureExtractor
    # hardcodes the 'chexnet' ViT backbone, so extractor.out_features is the ViT hidden size.
    extractor = VisualFeatureExtractor(model_name='resnet152')
    mlc = MLC(fc_in_features=extractor.out_features)
    co_att = CoAttention(visual_size=extractor.out_features)
    sent_lstm = SentenceLSTM()
    word_lstm = WordLSTM(embed_size=512, hidden_size=512, vocab_size=100, num_layers=1)

    images = torch.randn((4, 3, 224, 224))
    captions = torch.ones((4, 10)).long()
    hidden_state = torch.randn((4, 1, 512))
    # # image_file = '../data/images/CXR2814_IM-1239-1001.png'
    # # # images = Image.open(image_file).convert('RGB')
    # # # captions = torch.ones((1, 10)).long()
    # # # hidden_state = torch.randn((10, 512))
    # #
    # norm = transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
    #
    # transform = transforms.Compose([
    #     transforms.Resize(256),
    #     transforms.TenCrop(224),
    #     transforms.Lambda(lambda crops: torch.stack([norm(transforms.ToTensor()(crop)) for crop in crops])),
    # ])
    # images = transform(images)
    # images.unsqueeze_(0)
    #
    # # bs, ncrops, c, h, w = images.size()
    # # images = images.view(-1, c, h, w)
    #
print("images:{}".format(images.shape)) | |
print("captions:{}".format(captions.shape)) | |
print("hidden_states:{}".format(hidden_state.shape)) | |
visual_features, avg_features = extractor.forward(images) | |
print("visual_features:{}".format(visual_features.shape)) | |
print("avg features:{}".format(avg_features.shape)) | |
tags, semantic_features = mlc.forward(avg_features) | |
print("tags:{}".format(tags.shape)) | |
print("semantic_features:{}".format(semantic_features.shape)) | |
ctx, alpht_v, alpht_a = co_att.forward(avg_features, semantic_features, hidden_state) | |
print("ctx:{}".format(ctx.shape)) | |
print("alpht_v:{}".format(alpht_v.shape)) | |
print("alpht_a:{}".format(alpht_a.shape)) | |
topic, p_stop, hidden_state, states = sent_lstm.forward(ctx, hidden_state) | |
# p_stop_avg = p_stop.view(bs, ncrops, -1).mean(1) | |
print("Topic:{}".format(topic.shape)) | |
print("P_STOP:{}".format(p_stop.shape)) | |
# print("P_stop_avg:{}".format(p_stop_avg.shape)) | |
words = word_lstm.forward(topic, captions) | |
print("words:{}".format(words.shape)) | |
    # Class-activation-style heatmap: this assumes CNN grid features of shape (B, C, H, W)
    # (the resnet/densenet branches). With the ViT backbone, last_hidden_state is 3D, so
    # the heatmap computation is skipped to avoid a shape mismatch.
    if visual_features.dim() == 4:
        cam = torch.mul(visual_features, alpht_v.view(alpht_v.shape[0], alpht_v.shape[1], 1, 1)).sum(1)
        cam.squeeze_()
        cam = cam.cpu().data.numpy()
        for i in range(cam.shape[0]):
            heatmap = cam[i]
            heatmap = heatmap / np.max(heatmap)
            print(heatmap.shape)