import json
import os
import pickle
import time

import cv2
import numpy as np
import torch
import torch.nn as nn
import torchvision.transforms as transforms
from PIL import Image
from torch.autograd import Variable
from torch.utils.data import DataLoader

from models import *
from dataset import *
from loss import *
from build_tag import *
from build_vocab import *


class CaptionSampler(object):
    """Generate a multi-sentence report for a single image with a trained checkpoint.

    Construction loads the vocabulary (``args["vocab_path"]``) and the checkpoint
    (``args["model_dir"]`` joined with ``args["load_model_path"]``), then builds the
    visual extractor, MLC, co-attention, sentence LSTM and word LSTM and loads
    their weights from the checkpoint.
    """

    # Special vocabulary tokens used to seed word decoding and to trim decoded
    # sentences. NOTE(review): assumed to match the tokens added by build_vocab;
    # confirm against the vocabulary file actually in use.
    START_TOKEN = '<start>'
    END_TOKEN = '<end>'
    PAD_TOKEN = '<pad>'

    def __init__(self, **overrides):
        """Build the sampler and load every sub-model from the checkpoint.

        Args:
            **overrides: optional replacements for entries of the default
                configuration below (e.g. ``model_dir="model/"``, ``cuda=True``).

        Raises:
            TypeError: if an override names an unknown configuration key.
            Exception: whatever ``torch.load`` raises if the checkpoint cannot be read.
        """
        # Default configuration values
        self.args = {
            "model_dir": "model/",
            "image_dir": "",
            "caption_json": "",
            "vocab_path": "vocab.pkl",
            "file_lists": "",
            "load_model_path": "train_best_loss.pth.tar",
            "resize": 224,
            "cam_size": 224,
            "generate_dir": "cam",
            "result_path": "results",
            "result_name": "debug",
            "momentum": 0.1,
            "visual_model_name": "densenet201",
            "pretrained": False,
            "classes": 210,
            "sementic_features_dim": 512,
            "k": 10,
            "attention_version": "v4",
            "embed_size": 512,
            "hidden_size": 512,
            "sent_version": "v1",
            "sentence_num_layers": 2,
            "dropout": 0.1,
            "word_num_layers": 1,
            "s_max": 10,
            "n_max": 30,
            "batch_size": 8,
            "lambda_tag": 10000,
            "lambda_stop": 10,
            "lambda_word": 1,
            "cuda": False  # Keep CUDA disabled by default
        }
        unknown = set(overrides) - set(self.args)
        if unknown:
            raise TypeError("Unknown CaptionSampler option(s): {}".format(", ".join(sorted(unknown))))
        self.args.update(overrides)

        # Order matters: the checkpoint must be loaded before the models, and the
        # extractor must exist before MLC / co-attention read its out_features.
        self.vocab = self.__init_vocab()
        self.tagger = self.__init_tagger()
        self.transform = self.__init_transform()
        self.model_state_dict = self.__load_mode_state_dict()
        self.extractor = self.__init_visual_extractor()
        self.mlc = self.__init_mlc()
        self.co_attention = self.__init_co_attention()
        self.sentence_model = self.__init_sentence_model()
        self.word_model = self.__init_word_word()
        self.ce_criterion = self._init_ce_criterion()
        self.mse_criterion = self._init_mse_criterion()

    @staticmethod
    def _init_ce_criterion():
        """Return an unreduced (per-element) cross-entropy loss."""
        # reduction='none' is the non-deprecated equivalent of
        # size_average=False, reduce=False.
        return nn.CrossEntropyLoss(reduction='none')

    @staticmethod
    def _init_mse_criterion():
        """Return a mean-squared-error loss with default (mean) reduction."""
        return nn.MSELoss()

    def sample(self, image_file):
        """Generate a report for one image.

        Args:
            image_file: a PIL image, or a path (``str`` / ``os.PathLike``) to an
                image file, which is opened and converted to RGB.

        Returns:
            A list of ``args["s_max"]`` decoded sentences. Sentences whose stop
            prediction is 0 have all their word ids zeroed before decoding
            (presumably the pad id, yielding an empty string; confirm against
            the vocabulary).
        """
        for module in (self.extractor, self.mlc, self.co_attention,
                       self.sentence_model, self.word_model):
            module.eval()

        if isinstance(image_file, (str, os.PathLike)):
            # Normalize() below expects three channels, so force RGB.
            with Image.open(image_file) as img:
                image_file = img.convert('RGB')

        # Pure inference: don't build autograd graphs.
        with torch.no_grad():
            image_data = self.transform(image_file).unsqueeze(0)
            image = self.__to_var(image_data, requires_grad=False)

            _visual_features, avg_features = self.extractor.forward(image)
            _tags, semantic_features = self.mlc(avg_features)

            batch_size = image.shape[0]
            start_id = self.vocab(self.START_TOKEN)
            sentence_states = None
            prev_hidden_states = self.__to_var(
                torch.zeros(batch_size, 1, self.args["hidden_size"]), requires_grad=False)

            pred_sentences = []
            for _ in range(self.args["s_max"]):
                ctx, _alpha_v, _alpha_a = self.co_attention.forward(
                    avg_features, semantic_features, prev_hidden_states)
                topic, p_stop, hidden_state, sentence_states = self.sentence_model.forward(
                    ctx, prev_hidden_states, sentence_states)
                # Arg-max over the stop logits -> (batch, 1) of class indices.
                p_stop = torch.max(p_stop.squeeze(1), 1)[1].unsqueeze(1)

                start_tokens = self.__to_var(
                    torch.full((topic.shape[0], 1), start_id, dtype=torch.long),
                    requires_grad=False)
                sampled_ids = self.word_model.sample(topic, start_tokens)
                prev_hidden_states = hidden_state

                # .cpu() so this also works when cuda=True.
                sampled_ids = sampled_ids * p_stop.cpu().numpy()
                pred_sentences.append(self.__vec2sent(sampled_ids[0]))
        return pred_sentences

    def __init_cam_path(self, image_file):
        """Create (if needed) and return ``model_dir/generate_dir/image_file``."""
        generate_dir = os.path.join(self.args["model_dir"], self.args["generate_dir"])
        image_dir = os.path.join(generate_dir, image_file)
        # makedirs creates generate_dir as an intermediate directory too.
        os.makedirs(image_dir, exist_ok=True)
        return image_dir

    def __save_json(self, result):
        """Write ``result`` as JSON to ``model_dir/result_path/<result_name>.json``."""
        result_path = os.path.join(self.args["model_dir"], self.args["result_path"])
        os.makedirs(result_path, exist_ok=True)
        with open(os.path.join(result_path, '{}.json'.format(self.args["result_name"])), 'w') as f:
            json.dump(result, f)

    def __load_mode_state_dict(self):
        """Load the checkpoint dict onto the CPU and return it.

        Raises:
            Exception: re-raises whatever ``torch.load`` raised, after printing it.
        """
        path = os.path.join(self.args["model_dir"], self.args["load_model_path"])
        try:
            # SECURITY: torch.load unpickles; only load trusted checkpoints.
            model_state_dict = torch.load(path, map_location=torch.device('cpu'))
        except Exception as err:
            print("[Load Model Failed] {}".format(err))
            raise
        print("[Load Model-{} Succeed!]".format(self.args["load_model_path"]))
        print("Load From Epoch {}".format(model_state_dict.get('epoch')))
        return model_state_dict

    def __init_tagger(self):
        """Return a ``Tag`` helper (defined in build_tag)."""
        return Tag()

    def __vec2sent(self, array):
        """Decode a sequence of word ids into a space-joined sentence.

        Start tokens are skipped; decoding stops at the first end or pad token.
        """
        words = []
        for word_id in array:
            # Ids arrive as floats (numpy product above); the vocab is keyed by int.
            word = self.vocab.get_word_by_id(int(word_id))
            if word == self.START_TOKEN:
                continue
            if word in (self.END_TOKEN, self.PAD_TOKEN):
                break
            words.append(word)
        return ' '.join(words)

    def __init_vocab(self):
        """Unpickle and return the vocabulary from ``args["vocab_path"]``."""
        # SECURITY: pickle.load can execute arbitrary code; only load trusted files.
        with open(self.args["vocab_path"], 'rb') as f:
            vocab = pickle.load(f)
        print(vocab)
        return vocab

    def __init_data_loader(self, file_list):
        """Build a non-shuffled data loader over ``file_list`` via ``get_loader``."""
        data_loader = get_loader(image_dir=self.args["image_dir"],
                                 caption_json=self.args["caption_json"],
                                 file_list=file_list,
                                 vocabulary=self.vocab,
                                 transform=self.transform,
                                 batch_size=self.args["batch_size"],
                                 s_max=self.args["s_max"],
                                 n_max=self.args["n_max"],
                                 shuffle=False)
        return data_loader

    def __init_transform(self):
        """Resize to a square ``resize`` x ``resize``, convert to tensor, ImageNet-normalize."""
        transform = transforms.Compose([
            transforms.Resize((self.args["resize"], self.args["resize"])),
            transforms.ToTensor(),
            transforms.Normalize((0.485, 0.456, 0.406),
                                 (0.229, 0.224, 0.225))])
        return transform

    def __to_var(self, x, requires_grad=True):
        """Move ``x`` to GPU when ``cuda`` is enabled and wrap it as a Variable."""
        if self.args["cuda"]:
            x = x.cuda()
        return Variable(x, requires_grad=requires_grad)

    def __load_and_place(self, model, key, label):
        """Load ``model_state_dict[key]`` into ``model`` (if present) and move it to GPU if enabled."""
        if self.model_state_dict is not None:
            print("{} Loaded!".format(label))
            model.load_state_dict(self.model_state_dict[key])
        if self.args["cuda"]:
            model = model.cuda()
        return model

    def __init_visual_extractor(self):
        """Build the CNN visual feature extractor."""
        model = VisualFeatureExtractor(model_name=self.args["visual_model_name"],
                                       pretrained=self.args["pretrained"])
        return self.__load_and_place(model, 'extractor', "Visual Extractor")

    def __init_mlc(self):
        """Build the multi-label classifier on top of the extractor's features."""
        model = MLC(classes=self.args["classes"],
                    sementic_features_dim=self.args["sementic_features_dim"],
                    fc_in_features=self.extractor.out_features,
                    k=self.args["k"])
        return self.__load_and_place(model, 'mlc', "MLC")

    def __init_co_attention(self):
        """Build the co-attention module."""
        model = CoAttention(version=self.args["attention_version"],
                            embed_size=self.args["embed_size"],
                            hidden_size=self.args["hidden_size"],
                            visual_size=self.extractor.out_features,
                            k=self.args["k"],
                            momentum=self.args["momentum"])
        return self.__load_and_place(model, 'co_attention', "Co-Attention")

    def __init_sentence_model(self):
        """Build the sentence-level LSTM."""
        model = SentenceLSTM(version=self.args["sent_version"],
                             embed_size=self.args["embed_size"],
                             hidden_size=self.args["hidden_size"],
                             num_layers=self.args["sentence_num_layers"],
                             dropout=self.args["dropout"],
                             momentum=self.args["momentum"])
        return self.__load_and_place(model, 'sentence_model', "Sentence Model")

    def __init_word_word(self):
        """Build the word-level LSTM sized to the loaded vocabulary."""
        model = WordLSTM(vocab_size=len(self.vocab),
                         embed_size=self.args["embed_size"],
                         hidden_size=self.args["hidden_size"],
                         num_layers=self.args["word_num_layers"],
                         n_max=self.args["n_max"])
        return self.__load_and_place(model, 'word_model', "Word Model")


def main(image):
    """Build a sampler, generate a report for ``image``, print and return its first sentence.

    ``image`` may be a PIL image or a path such as
    ``'sample_images/CXR195_IM-0618-1001.png'``.
    """
    sampler = CaptionSampler()
    caption = sampler.sample(image)
    print(caption[0])
    return caption[0]