import torch import torch.nn as nn from fastai.vision import * from .model_vision import BaseVision from .model_language import BCNLanguage from .model_alignment import BaseAlignment class ABINetModel(nn.Module): def __init__(self, config): super().__init__() self.use_alignment = ifnone(config.model_use_alignment, True) self.max_length = config.dataset_max_length + 1 # additional stop token self.vision = BaseVision(config) self.language = BCNLanguage(config) if self.use_alignment: self.alignment = BaseAlignment(config) def forward(self, images, *args): v_res = self.vision(images) v_tokens = torch.softmax(v_res['logits'], dim=-1) v_lengths = v_res['pt_lengths'].clamp_(2, self.max_length) # TODO:move to langauge model l_res = self.language(v_tokens, v_lengths) if not self.use_alignment: return l_res, v_res l_feature, v_feature = l_res['feature'], v_res['feature'] a_res = self.alignment(l_feature, v_feature) return a_res, l_res, v_res