image-captioning-chest-xrays / models_debugger.py
Jyothirmai's picture
Upload 10 files
26e26de verified
raw
history blame
29.8 kB
import torch
import torch.nn as nn
import torchvision
import numpy as np
from torch.autograd import Variable
from torchvision.models.vgg import model_urls as vgg_model_urls
import torchvision.models as models
from utils.tcn import *
class DenseNet121(nn.Module):
    """DenseNet-121 whose classifier is replaced by a single linear layer.

    The replacement head emits raw scores; no sigmoid is appended.
    """

    def __init__(self, classes=14, pretrained=True):
        """
        :param classes: number of output units of the new classifier head
        :param pretrained: forwarded to ``torchvision.models.densenet121``
        """
        super(DenseNet121, self).__init__()
        self.model = torchvision.models.densenet121(pretrained=pretrained)
        num_in_features = self.model.classifier.in_features
        self.model.classifier = nn.Sequential(
            nn.Linear(in_features=num_in_features, out_features=classes, bias=True),
            # nn.Sigmoid()
        )

    def forward(self, x) -> object:
        """
        Run ``x`` through the wrapped network and return its output.

        :rtype: object
        """
        # The network is stored as ``self.model``; the former lookup of
        # ``self.densenet121`` raised AttributeError on every call.
        return self.model(x)
class DenseNet161(nn.Module):
    """DenseNet-161 wrapper with a re-initialised linear classifier head."""

    def __init__(self, classes=156, pretrained=True):
        """
        :param classes: number of output units of the new head
        :param pretrained: forwarded to ``torchvision.models.densenet161``
        """
        super().__init__()
        backbone = torchvision.models.densenet161(pretrained=pretrained)
        head = self.__init_linear(
            in_features=backbone.classifier.in_features,
            out_features=classes,
        )
        # Raw scores are returned: no sigmoid follows the linear layer.
        backbone.classifier = nn.Sequential(head)
        self.model = backbone

    def __init_linear(self, in_features, out_features):
        """Create a biased linear layer with weights drawn from a normal (mean 0, std 0.1)."""
        layer = nn.Linear(in_features=in_features, out_features=out_features, bias=True)
        layer.weight.data.normal_(0, 0.1)
        return layer

    def forward(self, x) -> object:
        """
        Apply the wrapped network to ``x``.

        :rtype: object
        """
        return self.model(x)
class DenseNet169(nn.Module):
    """DenseNet-169 wrapper with a re-initialised linear classifier head."""

    def __init__(self, classes=156, pretrained=True):
        """
        :param classes: number of output units of the new head
        :param pretrained: forwarded to ``torchvision.models.densenet169``
        """
        super().__init__()
        backbone = torchvision.models.densenet169(pretrained=pretrained)
        head = self.__init_linear(
            in_features=backbone.classifier.in_features,
            out_features=classes,
        )
        # Raw scores are returned: no sigmoid follows the linear layer.
        backbone.classifier = nn.Sequential(head)
        self.model = backbone

    def __init_linear(self, in_features, out_features):
        """Create a biased linear layer with weights drawn from a normal (mean 0, std 0.1)."""
        layer = nn.Linear(in_features=in_features, out_features=out_features, bias=True)
        layer.weight.data.normal_(0, 0.1)
        return layer

    def forward(self, x) -> object:
        """
        Apply the wrapped network to ``x``.

        :rtype: object
        """
        return self.model(x)
class DenseNet201(nn.Module):
    """DenseNet-201 wrapper whose head is a re-initialised linear layer followed by a sigmoid."""

    def __init__(self, classes=156, pretrained=True):
        """
        :param classes: number of output units of the new head
        :param pretrained: forwarded to ``torchvision.models.densenet201``
        """
        super().__init__()
        backbone = torchvision.models.densenet201(pretrained=pretrained)
        head = self.__init_linear(
            in_features=backbone.classifier.in_features,
            out_features=classes,
        )
        # Unlike the sibling DenseNet wrappers, this head ends in a sigmoid.
        backbone.classifier = nn.Sequential(head, nn.Sigmoid())
        self.model = backbone

    def __init_linear(self, in_features, out_features):
        """Create a biased linear layer with weights drawn from a normal (mean 0, std 0.1)."""
        layer = nn.Linear(in_features=in_features, out_features=out_features, bias=True)
        layer.weight.data.normal_(0, 0.1)
        return layer

    def forward(self, x) -> object:
        """
        Apply the wrapped network to ``x``.

        :rtype: object
        """
        return self.model(x)
class ResNet18(nn.Module):
    """ResNet-18 wrapper whose ``fc`` layer is replaced by a re-initialised linear head."""

    def __init__(self, classes=156, pretrained=True):
        """
        :param classes: number of output units of the new head
        :param pretrained: forwarded to ``torchvision.models.resnet18``
        """
        super().__init__()
        backbone = torchvision.models.resnet18(pretrained=pretrained)
        head = self.__init_linear(
            in_features=backbone.fc.in_features,
            out_features=classes,
        )
        # Raw scores are returned: no sigmoid follows the linear layer.
        backbone.fc = nn.Sequential(head)
        self.model = backbone

    def __init_linear(self, in_features, out_features):
        """Create a biased linear layer with weights drawn from a normal (mean 0, std 0.1)."""
        layer = nn.Linear(in_features=in_features, out_features=out_features, bias=True)
        layer.weight.data.normal_(0, 0.1)
        return layer

    def forward(self, x) -> object:
        """
        Apply the wrapped network to ``x``.

        :rtype: object
        """
        return self.model(x)
class ResNet34(nn.Module):
    """ResNet-34 wrapper whose ``fc`` layer is replaced by a re-initialised linear head."""

    def __init__(self, classes=156, pretrained=True):
        """
        :param classes: number of output units of the new head
        :param pretrained: forwarded to ``torchvision.models.resnet34``
        """
        super().__init__()
        backbone = torchvision.models.resnet34(pretrained=pretrained)
        head = self.__init_linear(
            in_features=backbone.fc.in_features,
            out_features=classes,
        )
        # Raw scores are returned: no sigmoid follows the linear layer.
        backbone.fc = nn.Sequential(head)
        self.model = backbone

    def __init_linear(self, in_features, out_features):
        """Create a biased linear layer with weights drawn from a normal (mean 0, std 0.1)."""
        layer = nn.Linear(in_features=in_features, out_features=out_features, bias=True)
        layer.weight.data.normal_(0, 0.1)
        return layer

    def forward(self, x) -> object:
        """
        Apply the wrapped network to ``x``.

        :rtype: object
        """
        return self.model(x)
class ResNet50(nn.Module):
    """ResNet-50 wrapper whose ``fc`` layer is replaced by a re-initialised linear head."""

    def __init__(self, classes=156, pretrained=True):
        """
        :param classes: number of output units of the new head
        :param pretrained: forwarded to ``torchvision.models.resnet50``
        """
        super().__init__()
        backbone = torchvision.models.resnet50(pretrained=pretrained)
        head = self.__init_linear(
            in_features=backbone.fc.in_features,
            out_features=classes,
        )
        # Raw scores are returned: no sigmoid follows the linear layer.
        backbone.fc = nn.Sequential(head)
        self.model = backbone

    def __init_linear(self, in_features, out_features):
        """Create a biased linear layer with weights drawn from a normal (mean 0, std 0.1)."""
        layer = nn.Linear(in_features=in_features, out_features=out_features, bias=True)
        layer.weight.data.normal_(0, 0.1)
        return layer

    def forward(self, x) -> object:
        """
        Apply the wrapped network to ``x``.

        :rtype: object
        """
        return self.model(x)
class ResNet101(nn.Module):
    """ResNet-101 wrapper whose ``fc`` layer is replaced by a re-initialised linear head."""

    def __init__(self, classes=156, pretrained=True):
        """
        :param classes: number of output units of the new head
        :param pretrained: forwarded to ``torchvision.models.resnet101``
        """
        super().__init__()
        backbone = torchvision.models.resnet101(pretrained=pretrained)
        head = self.__init_linear(
            in_features=backbone.fc.in_features,
            out_features=classes,
        )
        # Raw scores are returned: no sigmoid follows the linear layer.
        backbone.fc = nn.Sequential(head)
        self.model = backbone

    def __init_linear(self, in_features, out_features):
        """Create a biased linear layer with weights drawn from a normal (mean 0, std 0.1)."""
        layer = nn.Linear(in_features=in_features, out_features=out_features, bias=True)
        layer.weight.data.normal_(0, 0.1)
        return layer

    def forward(self, x) -> object:
        """
        Apply the wrapped network to ``x``.

        :rtype: object
        """
        return self.model(x)
class ResNet152(nn.Module):
    """ResNet-152 wrapper whose ``fc`` layer is replaced by a re-initialised linear head."""

    def __init__(self, classes=156, pretrained=True):
        """
        :param classes: number of output units of the new head
        :param pretrained: forwarded to ``torchvision.models.resnet152``
        """
        super().__init__()
        backbone = torchvision.models.resnet152(pretrained=pretrained)
        head = self.__init_linear(
            in_features=backbone.fc.in_features,
            out_features=classes,
        )
        # Raw scores are returned: no sigmoid follows the linear layer.
        backbone.fc = nn.Sequential(head)
        self.model = backbone

    def __init_linear(self, in_features, out_features):
        """Create a biased linear layer with weights drawn from a normal (mean 0, std 0.1)."""
        layer = nn.Linear(in_features=in_features, out_features=out_features, bias=True)
        layer.weight.data.normal_(0, 0.1)
        return layer

    def forward(self, x) -> object:
        """
        Apply the wrapped network to ``x``.

        :rtype: object
        """
        return self.model(x)
class VGG19(nn.Module):
    """VGG-19 (batch-norm variant) with a rebuilt three-layer classifier."""

    def __init__(self, classes=14, pretrained=True):
        """
        :param classes: number of output units of the last linear layer
        :param pretrained: forwarded to ``torchvision.models.vgg19_bn``
        """
        super().__init__()
        backbone = torchvision.models.vgg19_bn(pretrained=pretrained)
        # The linear layers are created in classifier order so the random
        # initialisation draws happen in the same sequence as before.
        fc_first = self.__init_linear(in_features=25088, out_features=4096)
        fc_second = self.__init_linear(in_features=4096, out_features=4096)
        fc_scores = self.__init_linear(in_features=4096, out_features=classes)
        backbone.classifier = nn.Sequential(
            fc_first,
            nn.ReLU(),
            nn.Dropout(0.5),
            fc_second,
            nn.ReLU(),
            nn.Dropout(0.5),
            fc_scores,
            # no sigmoid: raw scores are returned
        )
        self.model = backbone

    def __init_linear(self, in_features, out_features):
        """Create a biased linear layer with weights drawn from a normal (mean 0, std 0.1)."""
        layer = nn.Linear(in_features=in_features, out_features=out_features, bias=True)
        layer.weight.data.normal_(0, 0.1)
        return layer

    def forward(self, x) -> object:
        """
        Apply the wrapped network to ``x``.

        :rtype: object
        """
        return self.model(x)
class VGG(nn.Module):
    """Tag classifier built on a pretrained VGG-19.

    The convolutional features of VGG-19 are flattened, passed through the
    original VGG classifier minus its last layer, and then through a new
    linear layer and batch norm that produce ``tags_num`` scores.
    """

    def __init__(self, tags_num):
        """
        :param tags_num: number of output units (one per tag)
        """
        super(VGG, self).__init__()
        # NOTE(review): this rewrites the weight download URL from HTTPS to
        # plain HTTP in the shared torchvision table - confirm it is still needed.
        vgg_model_urls['vgg19'] = vgg_model_urls['vgg19'].replace('https://', 'http://')
        self.vgg19 = models.vgg19(pretrained=True)
        # Keep every layer of the stock classifier except the final one.
        vgg19_classifier = list(self.vgg19.classifier.children())[:-1]
        self.classifier = nn.Sequential(*vgg19_classifier)
        self.fc = nn.Linear(4096, tags_num)
        self.fc.apply(self.init_weights)
        self.bn = nn.BatchNorm1d(tags_num, momentum=0.1)

    def init_weights(self, m):
        """Initialise ``m`` in place if it is a linear layer; ignore other modules.

        Weights are drawn from a normal (mean 0, std 0.1) and the bias, when
        present, is zeroed. The module passed in is the one that is modified
        (previously ``self.fc`` was always touched, whatever ``m`` was).

        :param m: module handed over by ``nn.Module.apply``
        """
        if isinstance(m, nn.Linear):
            m.weight.data.normal_(0, 0.1)
            if m.bias is not None:
                m.bias.data.fill_(0)

    def forward(self, images) -> object:
        """
        Compute tag scores for a batch of images.

        :rtype: object
        """
        visual_feats = self.vgg19.features(images)
        # Flatten everything but the batch dimension for the dense layers.
        tags_classifier = visual_feats.view(visual_feats.size(0), -1)
        tags_classifier = self.bn(self.fc(self.classifier(tags_classifier)))
        return tags_classifier
class InceptionV3(nn.Module):
    """Inception v3 wrapper whose final linear layer is replaced by a new head."""

    def __init__(self, classes=156, pretrained=True):
        """
        :param classes: number of output units of the new head
        :param pretrained: forwarded to ``torchvision.models.inception_v3``
        """
        super(InceptionV3, self).__init__()
        self.model = torchvision.models.inception_v3(pretrained=pretrained)
        # torchvision's Inception v3 names its final layer ``fc``; it has no
        # ``classifier`` attribute, so the old lookup failed at construction.
        num_in_features = self.model.fc.in_features
        self.model.fc = nn.Sequential(
            self.__init_linear(in_features=num_in_features, out_features=classes),
            # nn.Sigmoid()
        )
        # NOTE(review): the auxiliary classifier (if enabled) keeps its
        # original output size, and in training mode torchvision may return a
        # (logits, aux_logits) pair - confirm callers handle that.

    def __init_linear(self, in_features, out_features):
        """Create a biased linear layer with weights drawn from a normal (mean 0, std 0.1)."""
        func = nn.Linear(in_features=in_features, out_features=out_features, bias=True)
        func.weight.data.normal_(0, 0.1)
        return func

    def forward(self, x) -> object:
        """
        Apply the wrapped network to ``x``.

        :rtype: object
        """
        return self.model(x)
class CheXNetDenseNet121(nn.Module):
    """DenseNet-121 whose classifier is a linear layer followed by a sigmoid."""

    def __init__(self, classes=14, pretrained=True):
        """
        :param classes: number of output units of the classifier
        :param pretrained: forwarded to ``torchvision.models.densenet121``
        """
        super().__init__()
        backbone = torchvision.models.densenet121(pretrained=pretrained)
        head = nn.Linear(
            in_features=backbone.classifier.in_features,
            out_features=classes,
            bias=True,
        )
        backbone.classifier = nn.Sequential(head, nn.Sigmoid())
        self.densenet121 = backbone

    def forward(self, x) -> object:
        """
        Apply the wrapped network to ``x``.

        :rtype: object
        """
        return self.densenet121(x)
class CheXNet(nn.Module):
    """CheXNet checkpoint loader with a replaced classification head.

    A 14-class ``CheXNetDenseNet121`` is wrapped in ``DataParallel``, moved to
    CUDA, filled from ``./models/CheXNet.pth.tar`` and then given a new
    linear + sigmoid classifier with ``classes`` outputs.
    """

    def __init__(self, classes=156):
        """
        :param classes: number of output units of the replacement head
        """
        super().__init__()
        wrapped = torch.nn.DataParallel(CheXNetDenseNet121(classes=14)).cuda()
        checkpoint = torch.load('./models/CheXNet.pth.tar')
        wrapped.load_state_dict(checkpoint['state_dict'])
        # Swap the head only after the checkpoint has been loaded.
        wrapped.module.densenet121.classifier = nn.Sequential(
            self.__init_linear(1024, classes),
            nn.Sigmoid(),
        )
        self.densenet121 = wrapped

    def __init_linear(self, in_features, out_features):
        """Create a biased linear layer with weights drawn from a normal (mean 0, std 0.1)."""
        layer = nn.Linear(in_features=in_features, out_features=out_features, bias=True)
        layer.weight.data.normal_(0, 0.1)
        return layer

    def forward(self, x) -> object:
        """
        Apply the wrapped network to ``x``.

        :rtype: object
        """
        return self.densenet121(x)
class ModelFactory(object):
    """Build one of the classifier wrappers in this module from its name."""

    def __init__(self, model_name, pretrained, classes):
        """
        :param model_name: name used to pick the model class
        :param pretrained: passed on to the classes that accept it
        :param classes: number of output units requested from the model
        """
        self.model_name, self.pretrained, self.classes = model_name, pretrained, classes

    def create_model(self):
        """Instantiate the model selected by ``model_name``.

        ``'VGG'`` is built with ``tags_num``; ``'CheXNet'`` and any name that
        is not recognised yield a ``CheXNet``; every other known name is
        built with ``pretrained`` and ``classes``.
        """
        if self.model_name == 'VGG':
            return VGG(tags_num=self.classes)

        # Classes that share the (pretrained, classes) constructor signature.
        configurable = (
            ('VGG19', VGG19),
            ('DenseNet121', DenseNet121),
            ('DenseNet161', DenseNet161),
            ('DenseNet169', DenseNet169),
            ('DenseNet201', DenseNet201),
            ('ResNet18', ResNet18),
            ('ResNet34', ResNet34),
            ('ResNet50', ResNet50),
            ('ResNet101', ResNet101),
            ('ResNet152', ResNet152),
        )
        for known_name, builder in configurable:
            if self.model_name == known_name:
                return builder(pretrained=self.pretrained, classes=self.classes)

        # 'CheXNet' itself and every unknown name end up here.
        return CheXNet(classes=self.classes)
class EncoderCNN(nn.Module):
    """Image encoder: frozen ResNet-152 features followed by a trainable projection.

    The ResNet (minus its final ``fc`` layer) produces features that are cut
    off from autograd, then projected to ``embed_size`` by a linear layer and
    batch-normalised.
    """

    def __init__(self, embed_size, pretrained=True):
        """
        :param embed_size: size of the returned feature vector
        :param pretrained: forwarded to ``torchvision``'s ``resnet152``
        """
        super(EncoderCNN, self).__init__()
        # TODO Extract Image features from CNN based on other models
        resnet = models.resnet152(pretrained=pretrained)
        modules = list(resnet.children())[:-1]  # drop the final fc layer
        self.resnet = nn.Sequential(*modules)
        self.linear = nn.Linear(resnet.fc.in_features, embed_size)
        self.bn = nn.BatchNorm1d(embed_size, momentum=0.1)
        self.__init_weights()

    def __init_weights(self):
        """Draw projection weights from a normal (mean 0, std 0.1) and zero the bias."""
        self.linear.weight.data.normal_(0.0, 0.1)
        self.linear.bias.data.fill_(0)

    def forward(self, images) -> object:
        """
        Encode ``images`` into ``embed_size``-dimensional features.

        :rtype: object
        """
        # No gradient ever reaches the backbone, so skip building its autograd
        # graph instead of building it and cutting it with the deprecated
        # ``Variable(features.data)``.
        with torch.no_grad():
            features = self.resnet(images)
        features = features.view(features.size(0), -1)
        features = self.bn(self.linear(features))
        return features
class DecoderRNN(nn.Module):
    """LSTM caption decoder that is conditioned on an image feature vector."""

    def __init__(self, embed_size, hidden_size, vocab_size, num_layers, n_max=50):
        """
        :param embed_size: word-embedding size (also the LSTM input size)
        :param hidden_size: LSTM hidden size
        :param vocab_size: number of words in the vocabulary
        :param num_layers: number of stacked LSTM layers
        :param n_max: number of tokens generated by ``sample``
        """
        super(DecoderRNN, self).__init__()
        self.embed = nn.Embedding(vocab_size, embed_size)
        self.lstm = nn.LSTM(embed_size, hidden_size, num_layers, batch_first=True)
        self.linear = nn.Linear(hidden_size, vocab_size)
        self.__init_weights()
        self.n_max = n_max

    def __init_weights(self):
        """Uniformly initialise embedding and output weights in [-0.1, 0.1]; zero the output bias."""
        self.embed.weight.data.uniform_(-0.1, 0.1)
        self.linear.weight.data.uniform_(-0.1, 0.1)
        self.linear.bias.data.fill_(0)

    def forward(self, features, captions) -> object:
        """
        Score the next word given image features and the caption so far.

        The features are prepended to the embedded captions as the first
        time step; only the output of the last time step is returned.

        :rtype: object
        """
        embeddings = self.embed(captions)
        embeddings = torch.cat((features.unsqueeze(1), embeddings), 1)
        hidden, _ = self.lstm(embeddings)
        outputs = self.linear(hidden[:, -1, :])
        return outputs

    def sample(self, features, start_tokens):
        """Greedily generate ``n_max`` token ids per sample.

        :param features: image features, one row per sample
        :param start_tokens: first token ids, one column per sample row
        :return: NumPy array of shape ``(batch, n_max)`` with the chosen ids
        """
        sampled_ids = np.zeros((features.shape[0], self.n_max))
        predicted = start_tokens
        embeddings = features.unsqueeze(1)
        for i in range(self.n_max):
            # The whole prefix is re-run through the LSTM at every step.
            embeddings = torch.cat([embeddings, self.embed(predicted)], dim=1)
            hidden_states, _ = self.lstm(embeddings)
            outputs = self.linear(hidden_states[:, -1, :])
            predicted = torch.max(outputs, 1)[1]
            # Explicit host copy: assigning a CUDA tensor to a NumPy slice fails.
            sampled_ids[:, i] = predicted.detach().cpu().numpy()
            predicted = predicted.unsqueeze(1)
        return sampled_ids
class VisualFeatureExtractor(nn.Module):
    """ResNet-152 without its final ``fc`` layer, returning flattened features."""

    def __init__(self, pretrained=False):
        """
        :param pretrained: forwarded to ``torchvision``'s ``resnet152``
        """
        super().__init__()
        backbone = models.resnet152(pretrained=pretrained)
        # Every child except the last one (the fc layer) is kept.
        self.resnet = nn.Sequential(*list(backbone.children())[:-1])
        self.out_features = backbone.fc.in_features

    def forward(self, images) -> object:
        """
        Return one flat feature vector per image.

        :rtype: object
        """
        feature_maps = self.resnet(images)
        batch_size = feature_maps.size(0)
        return feature_maps.view(batch_size, -1)
class MLC(nn.Module):
    """Multi-label tag classifier that also embeds its top-``k`` tags."""

    def __init__(self, classes=156, sementic_features_dim=512, fc_in_features=2048, k=10):
        """
        :param classes: number of tags
        :param sementic_features_dim: embedding size of each tag
        :param fc_in_features: size of the incoming visual feature vector
        :param k: number of highest-scoring tags to embed
        """
        super(MLC, self).__init__()
        self.classifier = nn.Linear(in_features=fc_in_features, out_features=classes)
        self.embed = nn.Embedding(classes, sementic_features_dim)
        self.k = k
        # Normalise over the class axis explicitly. The implicit-dim form is
        # deprecated and would pick axis 0 for some input ranks.
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, visual_features) -> object:
        """
        Return ``(tags, semantic_features)``.

        ``tags`` is the softmax over the classifier output;
        ``semantic_features`` holds the embeddings of the ``k`` tags with the
        highest probability.

        :rtype: object
        """
        tags = self.softmax(self.classifier(visual_features))
        semantic_features = self.embed(torch.topk(tags, self.k)[1])
        return tags, semantic_features
class CoAttention(nn.Module):
    """Co-attention over visual features and tag (semantic) embeddings.

    The layers below expect ``visual_features`` of shape
    ``(batch, visual_size)``, ``semantic_features`` of shape
    ``(batch, k, hidden_size)`` and ``h_sent`` of shape
    ``(batch, 1, hidden_size)``.
    """

    def __init__(self, embed_size=512, hidden_size=512, visual_size=2048, k=10):
        """
        :param embed_size: size of the returned context vector
        :param hidden_size: size of the sentence hidden state and tag embeddings
        :param visual_size: size of the visual feature vector
        :param k: number of tag embeddings per sample (was hard-coded to 10)
        """
        super(CoAttention, self).__init__()
        self.W_v = nn.Linear(in_features=visual_size, out_features=visual_size)
        self.bn_v = nn.BatchNorm1d(num_features=visual_size, momentum=0.1)
        self.W_v_h = nn.Linear(in_features=hidden_size, out_features=visual_size)
        self.bn_v_h = nn.BatchNorm1d(num_features=visual_size, momentum=0.1)
        self.W_v_att = nn.Linear(in_features=visual_size, out_features=visual_size)
        self.bn_v_att = nn.BatchNorm1d(num_features=visual_size, momentum=0.1)
        self.W_a = nn.Linear(in_features=hidden_size, out_features=hidden_size)
        self.bn_a = nn.BatchNorm1d(num_features=k, momentum=0.1)
        self.W_a_h = nn.Linear(in_features=hidden_size, out_features=hidden_size)
        self.bn_a_h = nn.BatchNorm1d(num_features=1, momentum=0.1)
        self.W_a_att = nn.Linear(in_features=hidden_size, out_features=hidden_size, bias=True)
        self.bn_a_att = nn.BatchNorm1d(num_features=k, momentum=0.1)
        self.W_fc = nn.Linear(in_features=visual_size + hidden_size, out_features=embed_size)
        self.bn_fc = nn.BatchNorm1d(num_features=embed_size, momentum=0.1)
        self.tanh = nn.Tanh()
        # dim=1 is the feature axis of the 2-D visual scores and the tag axis
        # of the 3-D semantic scores. The implicit-dim form normalised the 3-D
        # case over the batch axis, mixing samples together.
        # NOTE(review): this changes the numbers produced by weights that were
        # trained with the old batch-axis softmax.
        self.softmax = nn.Softmax(dim=1)

    def forward(self, visual_features, semantic_features, h_sent) -> object:
        """
        Compute the joint context vector and the attended visual features.

        :return: ``(ctx, v_att)``; ``ctx`` is the batch-normalised projection
            of the concatenated visual and semantic attention results.
        :rtype: object
        """
        # Visual attention: element-wise weights over the visual vector.
        W_v = self.bn_v(self.W_v(visual_features))
        W_v_h = self.bn_v_h(self.W_v_h(h_sent.squeeze(1)))
        alpha_v = self.softmax(self.bn_v_att(self.W_v_att(self.tanh(W_v + W_v_h))))
        v_att = torch.mul(alpha_v, visual_features)

        # Semantic attention: weights over the k tag embeddings, then a
        # weighted sum along that same axis.
        W_a_h = self.bn_a_h(self.W_a_h(h_sent))
        W_a = self.bn_a(self.W_a(semantic_features))
        alpha_a = self.softmax(self.bn_a_att(self.W_a_att(self.tanh(torch.add(W_a_h, W_a)))))
        a_att = torch.mul(alpha_a, semantic_features).sum(1)

        ctx = self.bn_fc(self.W_fc(torch.cat([v_att, a_att], dim=1)))
        return ctx, v_att
class SentenceLSTM(nn.Module):
    """Sentence-level LSTM producing a topic vector and a stop score per step.

    ``bn_t_ctx``, ``bn_stop_s_1``, ``W_topic_2`` and ``bn_topic_2`` are not
    used by ``forward``; they are still created so the module's parameter set
    stays the same.
    """

    def __init__(self, embed_size=512, hidden_size=512, num_layers=1):
        """
        :param embed_size: size of the context and topic vectors
        :param hidden_size: LSTM hidden size
        :param num_layers: number of stacked LSTM layers
        """
        super(SentenceLSTM, self).__init__()
        # ``forward`` feeds the LSTM a (batch, 1, embed) tensor, so the batch
        # axis must come first. Without batch_first the batch was consumed as
        # a time sequence and each sample's state depended on the samples
        # before it. This matches the batch_first word-level LSTMs.
        self.lstm = nn.LSTM(input_size=embed_size, hidden_size=hidden_size,
                            num_layers=num_layers, batch_first=True)
        self.W_t_h = nn.Linear(in_features=hidden_size, out_features=embed_size, bias=True)
        self.bn_t_h = nn.BatchNorm1d(num_features=1, momentum=0.1)
        self.W_t_ctx = nn.Linear(in_features=embed_size, out_features=embed_size, bias=True)
        self.bn_t_ctx = nn.BatchNorm1d(num_features=1, momentum=0.1)
        self.W_stop_s_1 = nn.Linear(in_features=hidden_size, out_features=embed_size, bias=True)
        self.bn_stop_s_1 = nn.BatchNorm1d(num_features=1, momentum=0.1)
        self.W_stop_s = nn.Linear(in_features=hidden_size, out_features=embed_size, bias=True)
        self.bn_stop_s = nn.BatchNorm1d(num_features=1, momentum=0.1)
        self.W_stop = nn.Linear(in_features=embed_size, out_features=2, bias=True)
        self.bn_stop = nn.BatchNorm1d(num_features=1, momentum=0.1)
        self.W_topic = nn.Linear(in_features=embed_size, out_features=embed_size, bias=True)
        self.bn_topic = nn.BatchNorm1d(num_features=1, momentum=0.1)
        self.W_topic_2 = nn.Linear(in_features=embed_size, out_features=embed_size, bias=True)
        self.bn_topic_2 = nn.BatchNorm1d(num_features=1, momentum=0.1)
        self.sigmoid = nn.Sigmoid()
        self.tanh = nn.Tanh()

    def forward(self, ctx, prev_hidden_state, states=None) -> object:
        """
        Advance the sentence LSTM by one step (v2: tanh gating).

        :param ctx: context vectors; a length-1 step axis is inserted at dim 1
        :param prev_hidden_state: hidden state returned by the previous step
        :param states: LSTM state tuple from the previous step, or ``None``
        :return: ``(topic, p_stop, hidden_state, states)``
        :rtype: object
        """
        ctx = ctx.unsqueeze(1)
        hidden_state, states = self.lstm(ctx, states)
        topic = self.bn_topic(self.W_topic(self.tanh(self.bn_t_h(self.W_t_h(hidden_state)
                                                                 + self.W_t_ctx(ctx)))))
        p_stop = self.bn_stop(self.W_stop(self.tanh(self.bn_stop_s(self.W_stop_s_1(prev_hidden_state)
                                                                   + self.W_stop_s(hidden_state)))))
        return topic, p_stop, hidden_state, states
class SentenceTCN(nn.Module):
    """Sentence-level model that uses a TCN to produce topic and stop scores.

    ``t_w`` is created but not used by ``forward``.
    """

    def __init__(self,
                 input_channel=10,
                 embed_size=512,
                 output_size=512,
                 nhid=512,
                 levels=8,
                 kernel_size=2,
                 dropout=0):
        """
        :param input_channel: ``input_size`` handed to the TCN
        :param embed_size: size of the topic vector
        :param output_size: ``output_size`` handed to the TCN
        :param nhid: channel count of every TCN level
        :param levels: number of TCN levels
        :param kernel_size: TCN kernel size
        :param dropout: TCN dropout
        """
        super().__init__()
        self.tcn = TCN(input_size=input_channel,
                       output_size=output_size,
                       num_channels=[nhid] * levels,
                       kernel_size=kernel_size,
                       dropout=dropout)

        def projection():
            # All four projections map output_size -> embed_size with a bias.
            return nn.Linear(in_features=output_size, out_features=embed_size, bias=True)

        self.W_t_h = projection()
        self.W_t_ctx = projection()
        self.W_stop_s_1 = projection()
        self.W_stop_s = projection()
        self.W_stop = nn.Linear(in_features=embed_size, out_features=2, bias=True)
        self.t_w = nn.Linear(in_features=5120, out_features=2, bias=True)
        self.tanh = nn.Tanh()

    def forward(self, ctx, prev_output) -> object:
        """
        Return ``(topic, p_stop, output)`` for the given context records.

        :rtype: object
        """
        output = self.tcn.forward(ctx)
        newest_ctx = ctx[:, -1, :]  # last entry along dim 1
        topic_pre = self.W_t_h(output) + self.W_t_ctx(newest_ctx).squeeze(1)
        topic = self.tanh(topic_pre)
        stop_pre = self.W_stop_s_1(prev_output) + self.W_stop_s(output)
        p_stop = self.W_stop(self.tanh(stop_pre))
        return topic, p_stop, output
class WordLSTM(nn.Module):
    """Word-level LSTM decoder that is conditioned on a topic vector."""

    def __init__(self, embed_size, hidden_size, vocab_size, num_layers, n_max=50):
        """
        :param embed_size: word-embedding size (also the LSTM input size)
        :param hidden_size: LSTM hidden size
        :param vocab_size: number of words in the vocabulary
        :param num_layers: number of stacked LSTM layers
        :param n_max: sequence length produced by ``val`` and ``sample``
        """
        super(WordLSTM, self).__init__()
        self.embed = nn.Embedding(vocab_size, embed_size)
        self.lstm = nn.LSTM(embed_size, hidden_size, num_layers, batch_first=True)
        self.linear = nn.Linear(hidden_size, vocab_size)
        self.__init_weights()
        self.n_max = n_max
        self.vocab_size = vocab_size

    def __init_weights(self):
        """Uniformly initialise embedding and output weights in [-0.1, 0.1]; zero the output bias."""
        self.embed.weight.data.uniform_(-0.1, 0.1)
        self.linear.weight.data.uniform_(-0.1, 0.1)
        self.linear.bias.data.fill_(0)

    def forward(self, topic_vec, captions) -> object:
        """
        Score the next word given the topic vector and the words so far.

        ``topic_vec`` is concatenated in front of the embedded captions along
        dim 1; only the output of the last time step is returned.

        :rtype: object
        """
        embeddings = self.embed(captions)
        embeddings = torch.cat((topic_vec, embeddings), 1)
        hidden, _ = self.lstm(embeddings)
        outputs = self.linear(hidden[:, -1, :])
        return outputs

    def val(self, features, start_tokens):
        """Greedily decode and return the per-step word scores.

        :return: tensor of shape ``(batch, n_max, vocab_size)``. Step 0 is a
            one-hot marker at ``start_tokens[0]``; steps 1.. hold the raw
            outputs of the linear layer.
        """
        samples = torch.zeros((features.shape[0], self.n_max, self.vocab_size))
        samples[:, 0, start_tokens[0]] = 1
        predicted = start_tokens
        embeddings = features
        for i in range(1, self.n_max):
            # The whole prefix is re-run through the LSTM at every step.
            embeddings = torch.cat([embeddings, self.embed(predicted)], dim=1)
            hidden_states, _ = self.lstm(embeddings)
            outputs = self.linear(hidden_states[:, -1, :])
            samples[:, i, :] = outputs
            predicted = torch.max(outputs, 1)[1].unsqueeze(1)
        return samples

    def sample(self, features, start_tokens):
        """Greedily decode and return the chosen token ids.

        :return: NumPy array of shape ``(batch, n_max)``; column 0 holds the
            start tokens.
        """
        sampled_ids = np.zeros((features.shape[0], self.n_max))
        # Explicit host copies: assigning a CUDA tensor to a NumPy slice fails.
        sampled_ids[:, 0] = start_tokens.view(-1,).detach().cpu().numpy()
        predicted = start_tokens
        embeddings = features
        for i in range(1, self.n_max):
            embeddings = torch.cat([embeddings, self.embed(predicted)], dim=1)
            hidden_states, _ = self.lstm(embeddings)
            outputs = self.linear(hidden_states[:, -1, :])
            predicted = torch.max(outputs, 1)[1]
            sampled_ids[:, i] = predicted.detach().cpu().numpy()
            predicted = predicted.unsqueeze(1)
        return sampled_ids
class WordTCN(nn.Module):
    """Word-level decoder that runs a TCN over the topic vector and embedded words."""

    def __init__(self,
                 input_channel=11,
                 vocab_size=1000,
                 embed_size=512,
                 output_size=512,
                 nhid=512,
                 levels=8,
                 kernel_size=2,
                 dropout=0,
                 n_max=50):
        """
        :param input_channel: ``input_size`` handed to the TCN
        :param vocab_size: number of words in the vocabulary
        :param embed_size: word-embedding size
        :param output_size: ``output_size`` handed to the TCN
        :param nhid: channel count of every TCN level
        :param levels: number of TCN levels
        :param kernel_size: TCN kernel size
        :param dropout: TCN dropout
        :param n_max: stored on the instance; not used in this class
        """
        super().__init__()
        # Plain configuration values kept on the instance.
        self.vocab_size = vocab_size
        self.embed_size = embed_size
        self.output_size = output_size
        self.kernel_size = kernel_size
        self.dropout = dropout
        self.n_max = n_max

        # Sub-modules are created in the order: embedding, projection, TCN.
        self.embed = nn.Embedding(vocab_size, embed_size)
        self.W_out = nn.Linear(in_features=output_size, out_features=vocab_size, bias=True)
        self.tcn = TCN(input_size=input_channel,
                       output_size=output_size,
                       num_channels=[nhid] * levels,
                       kernel_size=kernel_size,
                       dropout=dropout)

    def forward(self, topic_vec, captions) -> object:
        """
        Return vocabulary scores for the given topic vector and caption ids.

        :rtype: object
        """
        word_vectors = self.embed(captions)
        sequence = torch.cat([topic_vec, word_vectors], dim=1)
        return self.W_out(self.tcn.forward(sequence))
if __name__ == '__main__':
    # Smoke test: push random tensors through the LSTM pipeline
    # (extractor -> MLC -> co-attention -> sentence LSTM -> word LSTM)
    # and print the shape of every intermediate result.
    import warnings

    warnings.filterwarnings("ignore")

    def _report(label, tensor):
        print("{}:{}".format(label, tensor.shape))

    images = torch.randn((4, 3, 224, 224))
    captions = torch.ones((4, 10)).long()
    hidden_state = torch.randn((4, 1, 512))
    _report("images", images)
    _report("captions", captions)
    _report("hidden_states", hidden_state)

    extractor = VisualFeatureExtractor()
    visual_features = extractor.forward(images)
    _report("visual_features", visual_features)

    mlc = MLC()
    tags, semantic_features = mlc.forward(visual_features)
    _report("tags", tags)
    _report("semantic_features", semantic_features)

    co_att = CoAttention()
    ctx, v_att = co_att.forward(visual_features, semantic_features, hidden_state)
    _report("ctx", ctx)
    _report("v_att", v_att)

    sent_lstm = SentenceLSTM()
    topic, p_stop, hidden_state, states = sent_lstm.forward(ctx, hidden_state)
    _report("Topic", topic)
    _report("P_STOP", p_stop)

    word_lstm = WordLSTM(embed_size=512, hidden_size=512, vocab_size=100, num_layers=1)
    words = word_lstm.forward(topic, captions)
    _report("words", words)

    # Expected output
    # images:torch.Size([4, 3, 224, 224])
    # captions:torch.Size([4, 10])
    # hidden_states:torch.Size([4, 1, 512])
    # visual_features:torch.Size([4, 2048])
    # tags:torch.Size([4, 156])
    # semantic_features:torch.Size([4, 10, 512])
    # ctx:torch.Size([4, 512])
    # v_att:torch.Size([4, 2048])
    # Topic:torch.Size([4, 1, 512])
    # P_STOP:torch.Size([4, 1, 2])
    # words:torch.Size([4, 100])
# images = torch.randn((4, 3, 224, 224))
# captions = torch.ones((4, 3, 10)).long()
# prev_outputs = torch.randn((4, 512))
# now_words = torch.ones((4, 1))
#
# ctx_records = torch.zeros((4, 10, 512))
# captions = torch.zeros((4, 10)).long()
#
# print("images:{}".format(images.shape))
# print("captions:{}".format(captions.shape))
# print("hidden_states:{}".format(prev_outputs.shape))
#
# extractor = VisualFeatureExtractor()
# visual_features = extractor.forward(images)
# print("visual_features:{}".format(visual_features.shape))
#
# mlc = MLC()
# tags, semantic_features = mlc.forward(visual_features)
# print("tags:{}".format(tags.shape))
# print("semantic_features:{}".format(semantic_features.shape))
#
# co_att = CoAttention()
# ctx = co_att.forward(visual_features, semantic_features, prev_outputs)
# print("ctx:{}".format(ctx.shape))
#
# ctx_records[:, 0, :] = ctx
#
# sent_tcn = SentenceTCN()
# topic, p_stop, prev_outputs = sent_tcn.forward(ctx_records, prev_outputs)
# print("Topic:{}".format(topic.shape))
# print("P_STOP:{}".format(p_stop.shape))
# print("Prev_Outputs:{}".format(prev_outputs.shape))
#
# captions[:, 0] = now_words.view(-1,)
#
# word_tcn = WordTCN()
# words = word_tcn.forward(topic, captions)
# print("words:{}".format(words.shape))