|
import torch |
|
import torch.nn as nn |
|
import torchvision |
|
import numpy as np |
|
from torch.autograd import Variable |
|
from torchvision.models.vgg import model_urls as vgg_model_urls |
|
import torchvision.models as models |
|
|
|
from utils.tcn import * |
|
|
|
|
|
class DenseNet121(nn.Module):
    """Pretrained DenseNet-121 whose classifier is replaced by a fresh
    `classes`-way linear layer (raw logits, no activation)."""

    def __init__(self, classes=14, pretrained=True):
        super(DenseNet121, self).__init__()
        self.model = torchvision.models.densenet121(pretrained=pretrained)
        num_in_features = self.model.classifier.in_features
        self.model.classifier = nn.Sequential(
            nn.Linear(in_features=num_in_features, out_features=classes, bias=True),
        )

    def forward(self, x) -> object:
        """Return class logits for a batch of images.

        :rtype: object
        """
        # BUG FIX: the original called self.densenet121(x), but the backbone
        # is stored as self.model (no self.densenet121 attribute exists), so
        # every forward pass raised AttributeError.
        x = self.model(x)
        return x
|
|
|
|
|
class DenseNet161(nn.Module):
    """Pretrained DenseNet-161 with its classifier swapped for a freshly
    initialised `classes`-way linear head (logits, no activation)."""

    def __init__(self, classes=156, pretrained=True):
        super(DenseNet161, self).__init__()
        backbone = torchvision.models.densenet161(pretrained=pretrained)
        backbone.classifier = nn.Sequential(
            self.__init_linear(in_features=backbone.classifier.in_features,
                               out_features=classes),
        )
        self.model = backbone

    def __init_linear(self, in_features, out_features):
        # Linear head with N(0, 0.1) weights; the bias keeps PyTorch's
        # default initialisation.
        head = nn.Linear(in_features=in_features, out_features=out_features, bias=True)
        head.weight.data.normal_(0, 0.1)
        return head

    def forward(self, x) -> object:
        """Return class logits for a batch of images.

        :rtype: object
        """
        return self.model(x)
|
|
|
|
|
class DenseNet169(nn.Module):
    """Pretrained DenseNet-169 with its classifier swapped for a freshly
    initialised `classes`-way linear head (logits, no activation)."""

    def __init__(self, classes=156, pretrained=True):
        super(DenseNet169, self).__init__()
        backbone = torchvision.models.densenet169(pretrained=pretrained)
        backbone.classifier = nn.Sequential(
            self.__init_linear(in_features=backbone.classifier.in_features,
                               out_features=classes),
        )
        self.model = backbone

    def __init_linear(self, in_features, out_features):
        # Linear head with N(0, 0.1) weights; the bias keeps PyTorch's
        # default initialisation.
        head = nn.Linear(in_features=in_features, out_features=out_features, bias=True)
        head.weight.data.normal_(0, 0.1)
        return head

    def forward(self, x) -> object:
        """Return class logits for a batch of images.

        :rtype: object
        """
        return self.model(x)
|
|
|
|
|
class DenseNet201(nn.Module):
    """Pretrained DenseNet-201 whose classifier is replaced by a new
    `classes`-way linear head followed by a sigmoid (per-class
    probabilities, multi-label style)."""

    def __init__(self, classes=156, pretrained=True):
        super(DenseNet201, self).__init__()
        backbone = torchvision.models.densenet201(pretrained=pretrained)
        backbone.classifier = nn.Sequential(
            self.__init_linear(in_features=backbone.classifier.in_features,
                               out_features=classes),
            nn.Sigmoid(),
        )
        self.model = backbone

    def __init_linear(self, in_features, out_features):
        # Linear head with N(0, 0.1) weights; the bias keeps PyTorch's
        # default initialisation.
        head = nn.Linear(in_features=in_features, out_features=out_features, bias=True)
        head.weight.data.normal_(0, 0.1)
        return head

    def forward(self, x) -> object:
        """Return per-class probabilities for a batch of images.

        :rtype: object
        """
        return self.model(x)
|
|
|
|
|
class ResNet18(nn.Module):
    """Pretrained ResNet-18 with its final fc layer replaced by a freshly
    initialised `classes`-way linear head (logits, no activation)."""

    def __init__(self, classes=156, pretrained=True):
        super(ResNet18, self).__init__()
        backbone = torchvision.models.resnet18(pretrained=pretrained)
        backbone.fc = nn.Sequential(
            self.__init_linear(in_features=backbone.fc.in_features,
                               out_features=classes),
        )
        self.model = backbone

    def __init_linear(self, in_features, out_features):
        # Linear head with N(0, 0.1) weights; the bias keeps PyTorch's
        # default initialisation.
        head = nn.Linear(in_features=in_features, out_features=out_features, bias=True)
        head.weight.data.normal_(0, 0.1)
        return head

    def forward(self, x) -> object:
        """Return class logits for a batch of images.

        :rtype: object
        """
        return self.model(x)
|
|
|
|
|
class ResNet34(nn.Module):
    """Pretrained ResNet-34 with its final fc layer replaced by a freshly
    initialised `classes`-way linear head (logits, no activation)."""

    def __init__(self, classes=156, pretrained=True):
        super(ResNet34, self).__init__()
        backbone = torchvision.models.resnet34(pretrained=pretrained)
        backbone.fc = nn.Sequential(
            self.__init_linear(in_features=backbone.fc.in_features,
                               out_features=classes),
        )
        self.model = backbone

    def __init_linear(self, in_features, out_features):
        # Linear head with N(0, 0.1) weights; the bias keeps PyTorch's
        # default initialisation.
        head = nn.Linear(in_features=in_features, out_features=out_features, bias=True)
        head.weight.data.normal_(0, 0.1)
        return head

    def forward(self, x) -> object:
        """Return class logits for a batch of images.

        :rtype: object
        """
        return self.model(x)
|
|
|
|
|
class ResNet50(nn.Module):
    """Pretrained ResNet-50 with its final fc layer replaced by a freshly
    initialised `classes`-way linear head (logits, no activation)."""

    def __init__(self, classes=156, pretrained=True):
        super(ResNet50, self).__init__()
        backbone = torchvision.models.resnet50(pretrained=pretrained)
        backbone.fc = nn.Sequential(
            self.__init_linear(in_features=backbone.fc.in_features,
                               out_features=classes),
        )
        self.model = backbone

    def __init_linear(self, in_features, out_features):
        # Linear head with N(0, 0.1) weights; the bias keeps PyTorch's
        # default initialisation.
        head = nn.Linear(in_features=in_features, out_features=out_features, bias=True)
        head.weight.data.normal_(0, 0.1)
        return head

    def forward(self, x) -> object:
        """Return class logits for a batch of images.

        :rtype: object
        """
        return self.model(x)
|
|
|
|
|
class ResNet101(nn.Module):
    """Pretrained ResNet-101 with its final fc layer replaced by a freshly
    initialised `classes`-way linear head (logits, no activation)."""

    def __init__(self, classes=156, pretrained=True):
        super(ResNet101, self).__init__()
        backbone = torchvision.models.resnet101(pretrained=pretrained)
        backbone.fc = nn.Sequential(
            self.__init_linear(in_features=backbone.fc.in_features,
                               out_features=classes),
        )
        self.model = backbone

    def __init_linear(self, in_features, out_features):
        # Linear head with N(0, 0.1) weights; the bias keeps PyTorch's
        # default initialisation.
        head = nn.Linear(in_features=in_features, out_features=out_features, bias=True)
        head.weight.data.normal_(0, 0.1)
        return head

    def forward(self, x) -> object:
        """Return class logits for a batch of images.

        :rtype: object
        """
        return self.model(x)
|
|
|
|
|
class ResNet152(nn.Module):
    """Pretrained ResNet-152 with its final fc layer replaced by a freshly
    initialised `classes`-way linear head (logits, no activation)."""

    def __init__(self, classes=156, pretrained=True):
        super(ResNet152, self).__init__()
        backbone = torchvision.models.resnet152(pretrained=pretrained)
        backbone.fc = nn.Sequential(
            self.__init_linear(in_features=backbone.fc.in_features,
                               out_features=classes),
        )
        self.model = backbone

    def __init_linear(self, in_features, out_features):
        # Linear head with N(0, 0.1) weights; the bias keeps PyTorch's
        # default initialisation.
        head = nn.Linear(in_features=in_features, out_features=out_features, bias=True)
        head.weight.data.normal_(0, 0.1)
        return head

    def forward(self, x) -> object:
        """Return class logits for a batch of images.

        :rtype: object
        """
        return self.model(x)
|
|
|
|
|
class VGG19(nn.Module):
    """Pretrained VGG-19 (batch-norm variant) whose classifier is rebuilt
    as the standard 3-layer VGG head ending in `classes` logits."""

    def __init__(self, classes=14, pretrained=True):
        super(VGG19, self).__init__()
        self.model = torchvision.models.vgg19_bn(pretrained=pretrained)
        # Same layout as torchvision's VGG classifier, but with our own
        # weight initialisation and a `classes`-way output layer.
        head = [
            self.__init_linear(in_features=25088, out_features=4096),
            nn.ReLU(),
            nn.Dropout(0.5),
            self.__init_linear(in_features=4096, out_features=4096),
            nn.ReLU(),
            nn.Dropout(0.5),
            self.__init_linear(in_features=4096, out_features=classes),
        ]
        self.model.classifier = nn.Sequential(*head)

    def __init_linear(self, in_features, out_features):
        # Linear layer with N(0, 0.1) weights; the bias keeps PyTorch's
        # default initialisation.
        layer = nn.Linear(in_features=in_features, out_features=out_features, bias=True)
        layer.weight.data.normal_(0, 0.1)
        return layer

    def forward(self, x) -> object:
        """Return class logits for a batch of images.

        :rtype: object
        """
        return self.model(x)
|
|
|
|
|
class VGG(nn.Module):
    """VGG-19 tagger: torchvision VGG-19 features + its classifier minus
    the final 1000-way layer, followed by a new fc + batch-norm that
    produces `tags_num` outputs."""

    def __init__(self, tags_num):
        super(VGG, self).__init__()
        # Rewrite the pretrained-weight URL to plain http before download.
        # NOTE(review): mutates torchvision's global model_urls dict as a
        # side effect; presumably a workaround for an SSL issue — confirm
        # it is still needed.
        vgg_model_urls['vgg19'] = vgg_model_urls['vgg19'].replace('https://', 'http://')
        self.vgg19 = models.vgg19(pretrained=True)
        # Reuse VGG's classifier up to (not including) its last layer.
        vgg19_classifier = list(self.vgg19.classifier.children())[:-1]
        self.classifier = nn.Sequential(*vgg19_classifier)
        self.fc = nn.Linear(4096, tags_num)
        self.fc.apply(self.init_weights)
        self.bn = nn.BatchNorm1d(tags_num, momentum=0.1)

    def init_weights(self, m):
        # NOTE(review): whatever Linear module `m` triggers this callback,
        # the body always (re-)initialises self.fc rather than `m`.
        # Harmless here because it is only applied to self.fc itself.
        if type(m) == nn.Linear:
            self.fc.weight.data.normal_(0, 0.1)
            self.fc.bias.data.fill_(0)

    def forward(self, images) -> object:
        """Return batch-normalised tag scores for a batch of images.

        :rtype: object
        """
        visual_feats = self.vgg19.features(images)
        # Flatten the conv feature map before the dense classifier.
        tags_classifier = visual_feats.view(visual_feats.size(0), -1)
        tags_classifier = self.bn(self.fc(self.classifier(tags_classifier)))
        return tags_classifier
|
|
|
|
|
class InceptionV3(nn.Module):
    """Pretrained Inception-V3 with its final fully-connected layer
    replaced by a freshly initialised `classes`-way linear head."""

    def __init__(self, classes=156, pretrained=True):
        super(InceptionV3, self).__init__()
        self.model = torchvision.models.inception_v3(pretrained=pretrained)
        # BUG FIX: torchvision's Inception-V3 exposes its head as `fc`,
        # not `classifier`; the original accessed self.model.classifier and
        # raised AttributeError in __init__.
        num_in_features = self.model.fc.in_features
        self.model.fc = nn.Sequential(
            self.__init_linear(in_features=num_in_features, out_features=classes),
        )
        # NOTE(review): with aux_logits enabled (the torchvision default),
        # the model returns (logits, aux_logits) in training mode — confirm
        # callers handle that, or construct with aux_logits=False.

    def __init_linear(self, in_features, out_features):
        # Linear head with N(0, 0.1) weights; the bias keeps PyTorch's
        # default initialisation.
        func = nn.Linear(in_features=in_features, out_features=out_features, bias=True)
        func.weight.data.normal_(0, 0.1)
        return func

    def forward(self, x) -> object:
        """Return class logits for a batch of images.

        :rtype: object
        """
        x = self.model(x)
        return x
|
|
|
|
|
class CheXNetDenseNet121(nn.Module):
    """DenseNet-121 with a sigmoid-activated `classes`-way head — the
    architecture the CheXNet checkpoint was trained with."""

    def __init__(self, classes=14, pretrained=True):
        super(CheXNetDenseNet121, self).__init__()
        self.densenet121 = torchvision.models.densenet121(pretrained=pretrained)
        in_features = self.densenet121.classifier.in_features
        self.densenet121.classifier = nn.Sequential(
            nn.Linear(in_features=in_features, out_features=classes, bias=True),
            nn.Sigmoid(),
        )

    def forward(self, x) -> object:
        """Return per-class probabilities for a batch of images.

        :rtype: object
        """
        return self.densenet121(x)
|
|
|
|
|
class CheXNet(nn.Module):
    """Fine-tuning wrapper around a pretrained CheXNet checkpoint: loads
    the 14-class model, then replaces its head with a fresh `classes`-way
    sigmoid classifier.

    NOTE(review): requires a CUDA device and the file
    './models/CheXNet.pth.tar'. The DataParallel wrapper is applied before
    loading — presumably so the 'module.'-prefixed keys in the checkpoint's
    state_dict line up; confirm against the checkpoint format.
    """

    def __init__(self, classes=156):
        super(CheXNet, self).__init__()
        # Rebuild the exact 14-class architecture the checkpoint expects.
        self.densenet121 = CheXNetDenseNet121(classes=14)
        self.densenet121 = torch.nn.DataParallel(self.densenet121).cuda()
        self.densenet121.load_state_dict(torch.load('./models/CheXNet.pth.tar')['state_dict'])
        # Swap the restored head for a new one sized to this task.
        # 1024 here is the classifier input width hard-coded to match the
        # DenseNet-121 checkpoint.
        self.densenet121.module.densenet121.classifier = nn.Sequential(
            self.__init_linear(1024, classes),
            nn.Sigmoid()
        )

    def __init_linear(self, in_features, out_features):
        # Linear head with N(0, 0.1) weights; the bias keeps PyTorch's
        # default initialisation.
        func = nn.Linear(in_features=in_features, out_features=out_features, bias=True)
        func.weight.data.normal_(0, 0.1)
        return func

    def forward(self, x) -> object:
        """Return per-class probabilities for a batch of images.

        :rtype: object
        """
        x = self.densenet121(x)
        return x
|
|
|
|
|
class ModelFactory(object):
    """Builds one of this module's classification models by name.

    Unknown names fall back to CheXNet, matching the original behaviour.
    """

    def __init__(self, model_name, pretrained, classes):
        self.model_name = model_name
        self.pretrained = pretrained
        self.classes = classes

    def create_model(self):
        """Instantiate and return the model selected by `self.model_name`."""
        # Dispatch table instead of an if/elif ladder; constructors are
        # wrapped in lambdas so nothing is instantiated until chosen.
        builders = {
            'VGG19': lambda: VGG19(pretrained=self.pretrained, classes=self.classes),
            'DenseNet121': lambda: DenseNet121(pretrained=self.pretrained, classes=self.classes),
            'DenseNet161': lambda: DenseNet161(pretrained=self.pretrained, classes=self.classes),
            'DenseNet169': lambda: DenseNet169(pretrained=self.pretrained, classes=self.classes),
            'DenseNet201': lambda: DenseNet201(pretrained=self.pretrained, classes=self.classes),
            'CheXNet': lambda: CheXNet(classes=self.classes),
            'ResNet18': lambda: ResNet18(pretrained=self.pretrained, classes=self.classes),
            'ResNet34': lambda: ResNet34(pretrained=self.pretrained, classes=self.classes),
            'ResNet50': lambda: ResNet50(pretrained=self.pretrained, classes=self.classes),
            'ResNet101': lambda: ResNet101(pretrained=self.pretrained, classes=self.classes),
            'ResNet152': lambda: ResNet152(pretrained=self.pretrained, classes=self.classes),
            'VGG': lambda: VGG(tags_num=self.classes),
        }
        # Default mirrors the original else-branch: CheXNet.
        build = builders.get(self.model_name, lambda: CheXNet(classes=self.classes))
        return build()
|
|
|
|
|
class EncoderCNN(nn.Module):
    """ResNet-152 image encoder producing a batch-normalised embedding.

    The backbone output is re-wrapped via Variable(features.data), which
    cuts the autograd graph — only the projection and batch-norm layers
    receive gradients.
    """

    def __init__(self, embed_size, pretrained=True):
        super(EncoderCNN, self).__init__()

        backbone = models.resnet152(pretrained=pretrained)
        # Keep everything up to (not including) the final fc layer.
        self.resnet = nn.Sequential(*list(backbone.children())[:-1])
        self.linear = nn.Linear(backbone.fc.in_features, embed_size)
        self.bn = nn.BatchNorm1d(embed_size, momentum=0.1)
        self.__init_weights()

    def __init_weights(self):
        # N(0, 0.1) weights and zero bias for the projection layer.
        self.linear.weight.data.normal_(0.0, 0.1)
        self.linear.bias.data.fill_(0)

    def forward(self, images) -> object:
        """Return image embeddings of shape (batch, embed_size).

        :rtype: object
        """
        pooled = self.resnet(images)
        pooled = Variable(pooled.data)  # detach: backbone gets no gradient
        flat = pooled.view(pooled.size(0), -1)
        return self.bn(self.linear(flat))
|
|
|
|
|
class DecoderRNN(nn.Module):
    """Caption decoder: an LSTM conditioned on an image feature vector,
    predicting the next word from the full sequence so far."""

    def __init__(self, embed_size, hidden_size, vocab_size, num_layers, n_max=50):
        super(DecoderRNN, self).__init__()
        self.embed = nn.Embedding(vocab_size, embed_size)
        self.lstm = nn.LSTM(embed_size, hidden_size, num_layers, batch_first=True)
        self.linear = nn.Linear(hidden_size, vocab_size)
        self.__init_weights()
        self.n_max = n_max

    def __init_weights(self):
        # Uniform init for embedding and output layers; zero output bias.
        self.embed.weight.data.uniform_(-0.1, 0.1)
        self.linear.weight.data.uniform_(-0.1, 0.1)
        self.linear.bias.data.fill_(0)

    def forward(self, features, captions) -> object:
        """Score the next word given image features and a caption prefix.

        :rtype: object
        """
        token_vecs = self.embed(captions)
        sequence = torch.cat((features.unsqueeze(1), token_vecs), 1)
        states, _ = self.lstm(sequence)
        return self.linear(states[:, -1, :])

    def sample(self, features, start_tokens):
        """Greedy decoding for n_max steps; returns a (batch, n_max) numpy
        array of token ids."""
        # The growing sequence is re-fed to the LSTM in full every step
        # (no hidden-state caching), exactly as in training.
        sampled_ids = np.zeros((np.shape(features)[0], self.n_max))
        step_input = start_tokens
        sequence = features.unsqueeze(1)

        for t in range(self.n_max):
            sequence = torch.cat([sequence, self.embed(step_input)], dim=1)
            hidden_states, _ = self.lstm(sequence)
            logits = self.linear(hidden_states[:, -1, :])
            best = torch.max(logits, 1)[1]
            sampled_ids[:, t] = best
            step_input = best.unsqueeze(1)
        return sampled_ids
|
|
|
|
|
class VisualFeatureExtractor(nn.Module):
    """ResNet-152 backbone with the classification head removed; emits
    flattened pooled features of width `out_features`."""

    def __init__(self, pretrained=False):
        super(VisualFeatureExtractor, self).__init__()
        backbone = models.resnet152(pretrained=pretrained)
        # Drop the final fc layer; keep everything through the global pool.
        self.resnet = nn.Sequential(*list(backbone.children())[:-1])
        self.out_features = backbone.fc.in_features

    def forward(self, images) -> object:
        """Return features of shape (batch, out_features).

        :rtype: object
        """
        pooled = self.resnet(images)
        return pooled.view(pooled.size(0), -1)
|
|
|
|
|
class MLC(nn.Module):
    """Multi-label classifier that also returns embeddings of the top-k
    predicted tags (the "semantic features")."""

    def __init__(self, classes=156, sementic_features_dim=512, fc_in_features=2048, k=10):
        super(MLC, self).__init__()
        self.classifier = nn.Linear(in_features=fc_in_features, out_features=classes)
        self.embed = nn.Embedding(classes, sementic_features_dim)
        self.k = k
        # NOTE(review): nn.Softmax() without dim uses the deprecated
        # implicit-dimension behaviour; kept as-is to preserve outputs.
        self.softmax = nn.Softmax()

    def forward(self, visual_features) -> object:
        """Return (tag scores, embeddings of the k highest-scoring tags).

        :rtype: object
        """
        scores = self.classifier(visual_features)
        tags = self.softmax(scores)
        top_indices = torch.topk(tags, self.k)[1]
        semantic_features = self.embed(top_indices)
        return tags, semantic_features
|
|
|
|
|
class CoAttention(nn.Module):
    """Co-attention over visual and semantic (tag-embedding) features,
    fused into a single context vector for the sentence decoder.

    NOTE(review): nn.Softmax() without an explicit dim relies on the
    deprecated implicit-dimension behaviour, so the normalised axis depends
    on input rank — confirm before refactoring.
    """

    def __init__(self, embed_size=512, hidden_size=512, visual_size=2048):
        super(CoAttention, self).__init__()
        # Visual attention branch (operates on (batch, visual_size) inputs).
        self.W_v = nn.Linear(in_features=visual_size, out_features=visual_size)
        self.bn_v = nn.BatchNorm1d(num_features=visual_size, momentum=0.1)

        self.W_v_h = nn.Linear(in_features=hidden_size, out_features=visual_size)
        self.bn_v_h = nn.BatchNorm1d(num_features=visual_size, momentum=0.1)

        self.W_v_att = nn.Linear(in_features=visual_size, out_features=visual_size)
        self.bn_v_att = nn.BatchNorm1d(num_features=visual_size, momentum=0.1)

        # Semantic attention branch. The BatchNorm num_features values (10
        # and 1) track the channel dim of 3-D inputs — presumably 10 is
        # MLC's top-k and 1 is h_sent's singleton dim; confirm vs callers.
        self.W_a = nn.Linear(in_features=hidden_size, out_features=hidden_size)
        self.bn_a = nn.BatchNorm1d(num_features=10, momentum=0.1)

        self.W_a_h = nn.Linear(in_features=hidden_size, out_features=hidden_size)
        self.bn_a_h = nn.BatchNorm1d(num_features=1, momentum=0.1)

        self.W_a_att = nn.Linear(in_features=hidden_size, out_features=hidden_size, bias=True)
        self.bn_a_att = nn.BatchNorm1d(num_features=10, momentum=0.1)

        # Fuses attended visual + semantic features into the context vector.
        self.W_fc = nn.Linear(in_features=visual_size + hidden_size, out_features=embed_size)
        self.bn_fc = nn.BatchNorm1d(num_features=embed_size, momentum=0.1)

        self.tanh = nn.Tanh()
        self.softmax = nn.Softmax()

    def forward(self, visual_features, semantic_features, h_sent) -> object:
        """
        only training
        :rtype: object
        """
        # Visual attention: score visual channels against the sentence state.
        W_v = self.bn_v(self.W_v(visual_features))
        W_v_h = self.bn_v_h(self.W_v_h(h_sent.squeeze(1)))

        alpha_v = self.softmax(self.bn_v_att(self.W_v_att(self.tanh(W_v + W_v_h))))
        # Element-wise reweighting of the visual features.
        v_att = torch.mul(alpha_v, visual_features)

        # Semantic attention: weight each tag embedding, then pool over tags.
        W_a_h = self.bn_a_h(self.W_a_h(h_sent))
        W_a = self.bn_a(self.W_a(semantic_features))
        alpha_a = self.softmax(self.bn_a_att(self.W_a_att(self.tanh(torch.add(W_a_h, W_a)))))
        a_att = torch.mul(alpha_a, semantic_features).sum(1)

        # Joint context vector for the sentence decoder.
        ctx = self.bn_fc(self.W_fc(torch.cat([v_att, a_att], dim=1)))

        return ctx, v_att
|
|
|
|
|
class SentenceLSTM(nn.Module):
    """Sentence-level LSTM: consumes one context vector per step and emits
    a topic vector plus a 2-way stop/continue score."""

    def __init__(self, embed_size=512, hidden_size=512, num_layers=1):
        super(SentenceLSTM, self).__init__()
        self.lstm = nn.LSTM(input_size=embed_size, hidden_size=hidden_size, num_layers=num_layers)

        # Topic branch.
        self.W_t_h = nn.Linear(in_features=hidden_size, out_features=embed_size, bias=True)
        self.bn_t_h = nn.BatchNorm1d(num_features=1, momentum=0.1)

        self.W_t_ctx = nn.Linear(in_features=embed_size, out_features=embed_size, bias=True)
        self.bn_t_ctx = nn.BatchNorm1d(num_features=1, momentum=0.1)

        # Stop branch.
        self.W_stop_s_1 = nn.Linear(in_features=hidden_size, out_features=embed_size, bias=True)
        self.bn_stop_s_1 = nn.BatchNorm1d(num_features=1, momentum=0.1)

        self.W_stop_s = nn.Linear(in_features=hidden_size, out_features=embed_size, bias=True)
        self.bn_stop_s = nn.BatchNorm1d(num_features=1, momentum=0.1)

        self.W_stop = nn.Linear(in_features=embed_size, out_features=2, bias=True)
        self.bn_stop = nn.BatchNorm1d(num_features=1, momentum=0.1)

        self.W_topic = nn.Linear(in_features=embed_size, out_features=embed_size, bias=True)
        self.bn_topic = nn.BatchNorm1d(num_features=1, momentum=0.1)

        # NOTE(review): W_topic_2/bn_topic_2 (and bn_t_ctx, bn_stop_s_1)
        # are registered but unused in forward; kept so checkpoints saved
        # against the original class still load.
        self.W_topic_2 = nn.Linear(in_features=embed_size, out_features=embed_size, bias=True)
        self.bn_topic_2 = nn.BatchNorm1d(num_features=1, momentum=0.1)

        self.sigmoid = nn.Sigmoid()
        self.tanh = nn.Tanh()

    def forward(self, ctx, prev_hidden_state, states=None) -> object:
        """
        v2
        :rtype: object
        """
        ctx = ctx.unsqueeze(1)
        hidden_state, states = self.lstm(ctx, states)
        # Topic vector from the new hidden state and the context.
        topic_pre = self.bn_t_h(self.W_t_h(hidden_state) + self.W_t_ctx(ctx))
        topic = self.bn_topic(self.W_topic(self.tanh(topic_pre)))
        # Stop score from the previous and current hidden states.
        stop_pre = self.bn_stop_s(self.W_stop_s_1(prev_hidden_state) + self.W_stop_s(hidden_state))
        p_stop = self.bn_stop(self.W_stop(self.tanh(stop_pre)))
        return topic, p_stop, hidden_state, states
|
|
|
|
|
class SentenceTCN(nn.Module):
    """Sentence-level temporal convolutional network: the TCN counterpart
    of SentenceLSTM, driven by the context history instead of recurrence."""

    def __init__(self,
                 input_channel=10,
                 embed_size=512,
                 output_size=512,
                 nhid=512,
                 levels=8,
                 kernel_size=2,
                 dropout=0):
        super(SentenceTCN, self).__init__()
        self.tcn = TCN(input_size=input_channel,
                       output_size=output_size,
                       num_channels=[nhid] * levels,
                       kernel_size=kernel_size,
                       dropout=dropout)
        self.W_t_h = nn.Linear(in_features=output_size, out_features=embed_size, bias=True)
        self.W_t_ctx = nn.Linear(in_features=output_size, out_features=embed_size, bias=True)
        self.W_stop_s_1 = nn.Linear(in_features=output_size, out_features=embed_size, bias=True)
        self.W_stop_s = nn.Linear(in_features=output_size, out_features=embed_size, bias=True)
        self.W_stop = nn.Linear(in_features=embed_size, out_features=2, bias=True)
        # NOTE(review): t_w is registered but never used in forward; kept
        # for checkpoint compatibility.
        self.t_w = nn.Linear(in_features=5120, out_features=2, bias=True)
        self.tanh = nn.Tanh()

    def forward(self, ctx, prev_output) -> object:
        """One sentence step over the context history `ctx`.

        :rtype: object
        """
        tcn_out = self.tcn.forward(ctx)
        # Topic mixes the TCN output with the most recent context entry.
        latest_ctx = self.W_t_ctx(ctx[:, -1, :]).squeeze(1)
        topic = self.tanh(self.W_t_h(tcn_out) + latest_ctx)
        # Stop score mixes the previous step's output with the current one.
        p_stop = self.W_stop(self.tanh(self.W_stop_s_1(prev_output) + self.W_stop_s(tcn_out)))
        return topic, p_stop, tcn_out
|
|
|
|
|
class WordLSTM(nn.Module):
    """Word-level LSTM decoder: generates one sentence conditioned on a
    topic vector, re-feeding the growing sequence at every step."""

    def __init__(self, embed_size, hidden_size, vocab_size, num_layers, n_max=50):
        super(WordLSTM, self).__init__()
        self.embed = nn.Embedding(vocab_size, embed_size)
        self.lstm = nn.LSTM(embed_size, hidden_size, num_layers, batch_first=True)
        self.linear = nn.Linear(hidden_size, vocab_size)
        self.__init_weights()
        self.n_max = n_max
        self.vocab_size = vocab_size

    def __init_weights(self):
        # Uniform init for embedding and output layers; zero output bias.
        self.embed.weight.data.uniform_(-0.1, 0.1)
        self.linear.weight.data.uniform_(-0.1, 0.1)
        self.linear.bias.data.fill_(0)

    def forward(self, topic_vec, captions) -> object:
        """Score the next word given the topic vector and caption prefix.

        :rtype: object
        """
        token_vecs = self.embed(captions)
        sequence = torch.cat((topic_vec, token_vecs), 1)
        states, _ = self.lstm(sequence)
        return self.linear(states[:, -1, :])

    def val(self, features, start_tokens):
        """Greedy decode keeping the full logit tensor per step; returns a
        (batch, n_max, vocab_size) tensor (step 0 one-hots the start token)."""
        samples = torch.zeros((np.shape(features)[0], self.n_max, self.vocab_size))
        samples[:, 0, start_tokens[0]] = 1
        step_input = start_tokens
        sequence = features

        for t in range(1, self.n_max):
            sequence = torch.cat([sequence, self.embed(step_input)], dim=1)
            hidden_states, _ = self.lstm(sequence)
            logits = self.linear(hidden_states[:, -1, :])
            samples[:, t, :] = logits
            step_input = torch.max(logits, 1)[1].unsqueeze(1)
        return samples

    def sample(self, features, start_tokens):
        """Greedy decode; returns a (batch, n_max) numpy array of token ids
        whose first column is the start tokens."""
        sampled_ids = np.zeros((np.shape(features)[0], self.n_max))
        sampled_ids[:, 0] = start_tokens.view(-1,)
        step_input = start_tokens
        sequence = features

        for t in range(1, self.n_max):
            sequence = torch.cat([sequence, self.embed(step_input)], dim=1)
            hidden_states, _ = self.lstm(sequence)
            logits = self.linear(hidden_states[:, -1, :])
            best = torch.max(logits, 1)[1]
            sampled_ids[:, t] = best
            step_input = best.unsqueeze(1)
        return sampled_ids
|
|
|
|
|
class WordTCN(nn.Module):
    """Word-level TCN decoder: embeds captions, prepends the topic vector,
    and maps the TCN output to vocabulary scores."""

    def __init__(self,
                 input_channel=11,
                 vocab_size=1000,
                 embed_size=512,
                 output_size=512,
                 nhid=512,
                 levels=8,
                 kernel_size=2,
                 dropout=0,
                 n_max=50):
        super(WordTCN, self).__init__()
        self.vocab_size = vocab_size
        self.embed_size = embed_size
        self.output_size = output_size
        channel_sizes = [nhid] * levels
        self.kernel_size = kernel_size
        self.dropout = dropout
        self.n_max = n_max
        self.embed = nn.Embedding(vocab_size, embed_size)
        self.W_out = nn.Linear(in_features=output_size, out_features=vocab_size, bias=True)
        self.tcn = TCN(input_size=input_channel,
                       output_size=output_size,
                       num_channels=channel_sizes,
                       kernel_size=kernel_size,
                       dropout=dropout)

    def forward(self, topic_vec, captions) -> object:
        """Return vocabulary scores for the topic + caption sequence.

        :rtype: object
        """
        token_vecs = self.embed(captions)
        tcn_input = torch.cat([topic_vec, token_vecs], dim=1)
        tcn_output = self.tcn.forward(tcn_input)
        return self.W_out(tcn_output)
|
|
|
|
|
if __name__ == '__main__':
    # Smoke test: push random data through the captioning pipeline
    # (visual extractor -> MLC -> co-attention -> sentence LSTM ->
    # word LSTM) and print the tensor shapes at each stage.
    import warnings
    warnings.filterwarnings("ignore")
    images = torch.randn((4, 3, 224, 224))
    captions = torch.ones((4, 10)).long()
    hidden_state = torch.randn((4, 1, 512))

    print("images:{}".format(images.shape))
    print("captions:{}".format(captions.shape))
    print("hidden_states:{}".format(hidden_state.shape))

    # Visual features from the headless ResNet-152 backbone.
    extractor = VisualFeatureExtractor()
    visual_features = extractor.forward(images)
    print("visual_features:{}".format(visual_features.shape))

    # Tag scores and top-k tag embeddings.
    mlc = MLC()
    tags, semantic_features = mlc.forward(visual_features)
    print("tags:{}".format(tags.shape))
    print("semantic_features:{}".format(semantic_features.shape))

    # Fuse visual + semantic features into a context vector.
    co_att = CoAttention()
    ctx, v_att = co_att.forward(visual_features, semantic_features, hidden_state)
    print("ctx:{}".format(ctx.shape))
    print("v_att:{}".format(v_att.shape))

    # One sentence-LSTM step: topic vector and stop score.
    sent_lstm = SentenceLSTM()
    topic, p_stop, hidden_state, states = sent_lstm.forward(ctx, hidden_state)
    print("Topic:{}".format(topic.shape))
    print("P_STOP:{}".format(p_stop.shape))

    # Word-level decoding conditioned on the topic vector.
    word_lstm = WordLSTM(embed_size=512, hidden_size=512, vocab_size=100, num_layers=1)
    words = word_lstm.forward(topic, captions)
    print("words:{}".format(words.shape))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|