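"""ViTCoAtt.py -- inference-time caption sampler.

Loads a trained visual feature extractor, multi-label classifier (MLC),
co-attention module, sentence LSTM and word LSTM from a single checkpoint
and generates report sentences for an input chest X-ray image.
"""
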
import os
import json
import pickle

import numpy as np
import torch
import torch.nn as nn
import torchvision.transforms as transforms
from torch.autograd import Variable  # no-op wrapper in modern PyTorch; kept for compatibility
from PIL import Image

from models import *
from dataset import *
from loss import *
from build_tag import *
from build_vocab import *


class CaptionSampler(object):
    def __init__(self):
        # Default configuration values.
        self.args = {
            "model_dir": "model/",
            "image_dir": "",
            "caption_json": "",
            "vocab_path": "vocab.pkl",
            "file_lists": "",
            "load_model_path": "train_best_loss.pth.tar",
            "resize": 224,
            "cam_size": 224,
            "generate_dir": "cam",
            "result_path": "results",
            "result_name": "debug",
            "momentum": 0.1,
            "visual_model_name": "densenet201",
            "pretrained": False,
            "classes": 210,
            "sementic_features_dim": 512,
            "k": 10,
            "attention_version": "v4",
            "embed_size": 512,
            "hidden_size": 512,
            "sent_version": "v1",
            "sentence_num_layers": 2,
            "dropout": 0.1,
            "word_num_layers": 1,
            "s_max": 10,
            "n_max": 30,
            "batch_size": 8,
            "lambda_tag": 10000,
            "lambda_stop": 10,
            "lambda_word": 1,
            "cuda": False,  # keep CUDA disabled by default
        }

        self.vocab = self.__init_vocab()
        self.tagger = self.__init_tagger()
        self.transform = self.__init_transform()
        self.model_state_dict = self.__load_model_state_dict()
        self.extractor = self.__init_visual_extractor()
        self.mlc = self.__init_mlc()
        self.co_attention = self.__init_co_attention()
        self.sentence_model = self.__init_sentence_model()
        self.word_model = self.__init_word_model()
        self.ce_criterion = self._init_ce_criterion()
        self.mse_criterion = self._init_mse_criterion()

    @staticmethod
    def _init_ce_criterion():
        # The original used size_average=False, reduce=False;
        # reduction='none' is the equivalent, non-deprecated form.
        return nn.CrossEntropyLoss(reduction='none')

    @staticmethod
    def _init_mse_criterion():
        return nn.MSELoss()

    def sample(self, image_file):
        """Generate report sentences for a single image (path or PIL.Image)."""
        self.extractor.eval()
        self.mlc.eval()
        self.co_attention.eval()
        self.sentence_model.eval()
        self.word_model.eval()

        # Accept either a file path or an already-loaded PIL image.
        if isinstance(image_file, str):
            image_file = Image.open(image_file).convert('RGB')
        image_data = self.transform(image_file)
        image_data = image_data.unsqueeze(0)
        image = self.__to_var(image_data, requires_grad=False)

        visual_features, avg_features = self.extractor.forward(image)
        tags, semantic_features = self.mlc(avg_features)

        sentence_states = None
        prev_hidden_states = self.__to_var(torch.zeros(image.shape[0], 1, self.args["hidden_size"]))
        pred_sentences = []

        for i in range(self.args["s_max"]):
            ctx, alpha_v, alpha_a = self.co_attention.forward(avg_features, semantic_features, prev_hidden_states)
            topic, p_stop, hidden_state, sentence_states = self.sentence_model.forward(ctx,
                                                                                       prev_hidden_states,
                                                                                       sentence_states)
            # p_stop holds {continue, stop} scores; argmax gives a 0/1 gate per image.
            p_stop = p_stop.squeeze(1)
            p_stop = torch.max(p_stop, 1)[1].unsqueeze(1)

            start_tokens = np.zeros((topic.shape[0], 1))
            start_tokens[:, 0] = self.vocab('<start>')
            start_tokens = self.__to_var(torch.Tensor(start_tokens).long(), requires_grad=False)

            sampled_ids = self.word_model.sample(topic, start_tokens)
            prev_hidden_states = hidden_state
            # Zero out the sampled sentence once the stop gate fires.
            sampled_ids = sampled_ids * p_stop.numpy()
            pred_sentences.append(self.__vec2sent(sampled_ids[0]))
        return pred_sentences

    def __init_cam_path(self, image_file):
        generate_dir = os.path.join(self.args["model_dir"], self.args["generate_dir"])
        if not os.path.exists(generate_dir):
            os.makedirs(generate_dir)

        image_dir = os.path.join(generate_dir, image_file)
        if not os.path.exists(image_dir):
            os.makedirs(image_dir)
        return image_dir

    def __save_json(self, result):
        result_path = os.path.join(self.args["model_dir"], self.args["result_path"])
        if not os.path.exists(result_path):
            os.makedirs(result_path)
        with open(os.path.join(result_path, '{}.json'.format(self.args["result_name"])), 'w') as f:
            json.dump(result, f)

    def __load_model_state_dict(self):
        try:
            model_state_dict = torch.load(
                os.path.join(self.args["model_dir"], self.args["load_model_path"]),
                map_location=torch.device('cpu'))
            print("[Load Model-{} Succeed!]".format(self.args["load_model_path"]))
            print("Load From Epoch {}".format(model_state_dict['epoch']))
            return model_state_dict
        except Exception as err:
            print("[Load Model Failed] {}".format(err))
            raise err

    def __init_tagger(self):
        return Tag()

    def __vec2sent(self, array):
        sampled_caption = []
        for word_id in array:
            word = self.vocab.get_word_by_id(word_id)
            if word == '<start>':
                continue
            if word == '<end>' or word == '<pad>':
                break
            sampled_caption.append(word)
        return ' '.join(sampled_caption)

    def __init_vocab(self):
        # Load the vocabulary from the configured path rather than a hard-coded file.
        with open(self.args["vocab_path"], 'rb') as f:
            vocab = pickle.load(f)
        return vocab

    def __init_data_loader(self, file_list):
        # self.args is a dict, so keys must be indexed, not accessed as attributes.
        data_loader = get_loader(image_dir=self.args["image_dir"],
                                 caption_json=self.args["caption_json"],
                                 file_list=file_list,
                                 vocabulary=self.vocab,
                                 transform=self.transform,
                                 batch_size=self.args["batch_size"],
                                 s_max=self.args["s_max"],
                                 n_max=self.args["n_max"],
                                 shuffle=False)
        return data_loader

    def __init_transform(self):
        # Standard ImageNet normalization statistics.
        transform = transforms.Compose([
            transforms.Resize((self.args["resize"], self.args["resize"])),
            transforms.ToTensor(),
            transforms.Normalize((0.485, 0.456, 0.406),
                                 (0.229, 0.224, 0.225))])
        return transform

    def __to_var(self, x, requires_grad=True):
        if self.args["cuda"]:
            x = x.cuda()
        return Variable(x, requires_grad=requires_grad)

    def __init_visual_extractor(self):
        model = VisualFeatureExtractor(model_name=self.args["visual_model_name"],
                                       pretrained=self.args["pretrained"])

        if self.model_state_dict is not None:
            print("Visual Extractor Loaded!")
            model.load_state_dict(self.model_state_dict['extractor'])

        if self.args["cuda"]:
            model = model.cuda()
        return model

    def __init_mlc(self):
        model = MLC(classes=self.args["classes"],
                    sementic_features_dim=self.args["sementic_features_dim"],
                    fc_in_features=self.extractor.out_features,
                    k=self.args["k"])

        if self.model_state_dict is not None:
            print("MLC Loaded!")
            model.load_state_dict(self.model_state_dict['mlc'])

        if self.args["cuda"]:
            model = model.cuda()
        return model

    def __init_co_attention(self):
        model = CoAttention(version=self.args["attention_version"],
                            embed_size=self.args["embed_size"],
                            hidden_size=self.args["hidden_size"],
                            visual_size=self.extractor.out_features,
                            k=self.args["k"],
                            momentum=self.args["momentum"])

        if self.model_state_dict is not None:
            print("Co-Attention Loaded!")
            model.load_state_dict(self.model_state_dict['co_attention'])

        if self.args["cuda"]:
            model = model.cuda()
        return model

    def __init_sentence_model(self):
        model = SentenceLSTM(version=self.args["sent_version"],
                             embed_size=self.args["embed_size"],
                             hidden_size=self.args["hidden_size"],
                             num_layers=self.args["sentence_num_layers"],
                             dropout=self.args["dropout"],
                             momentum=self.args["momentum"])

        if self.model_state_dict is not None:
            print("Sentence Model Loaded!")
            model.load_state_dict(self.model_state_dict['sentence_model'])

        if self.args["cuda"]:
            model = model.cuda()
        return model

    def __init_word_model(self):
        model = WordLSTM(vocab_size=len(self.vocab),
                         embed_size=self.args["embed_size"],
                         hidden_size=self.args["hidden_size"],
                         num_layers=self.args["word_num_layers"],
                         n_max=self.args["n_max"])

        if self.model_state_dict is not None:
            print("Word Model Loaded!")
            model.load_state_dict(self.model_state_dict['word_model'])

        if self.args["cuda"]:
            model = model.cuda()
        return model


def main(image):
    sampler = CaptionSampler()
    # image = 'sample_images/CXR195_IM-0618-1001.png'
    caption = sampler.sample(image)
    print(caption[0])
    return caption[0]
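

# Minimal usage sketch (not part of the original file): assumes `vocab.pkl`
# and the checkpoint at model/train_best_loss.pth.tar are present locally,
# and reuses the sample image path referenced in the comment above.
if __name__ == '__main__':
    main('sample_images/CXR195_IM-0618-1001.png')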