Spaces:
Runtime error
Runtime error
File size: 2,600 Bytes
c80917c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 |
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import copy
import numpy as np
import torch
from .ShowTellModel import ShowTellModel
from .FCModel import FCModel
from .AttModel import *
from .TransformerModel import TransformerModel
from .cachedTransformer import TransformerModel as cachedTransformer
from .BertCapModel import BertCapModel
from .M2Transformer import M2TransformerModel
from .AoAModel import AoAModel
def setup(opt):
if opt.caption_model in ['fc', 'show_tell']:
print('Warning: %s model is mostly deprecated; many new features are not supported.' %opt.caption_model)
if opt.caption_model == 'fc':
print('Use newfc instead of fc')
if opt.caption_model == 'fc':
model = FCModel(opt)
elif opt.caption_model == 'language_model':
model = LMModel(opt)
elif opt.caption_model == 'newfc':
model = NewFCModel(opt)
elif opt.caption_model == 'show_tell':
model = ShowTellModel(opt)
# Att2in model in self-critical
elif opt.caption_model == 'att2in':
model = Att2inModel(opt)
# Att2in model with two-layer MLP img embedding and word embedding
elif opt.caption_model == 'att2in2':
model = Att2in2Model(opt)
elif opt.caption_model == 'att2all2':
print('Warning: this is not a correct implementation of the att2all model in the original paper.')
model = Att2all2Model(opt)
# Adaptive Attention model from Knowing when to look
elif opt.caption_model == 'adaatt':
model = AdaAttModel(opt)
# Adaptive Attention with maxout lstm
elif opt.caption_model == 'adaattmo':
model = AdaAttMOModel(opt)
# Top-down attention model
elif opt.caption_model in ['topdown', 'updown']:
model = UpDownModel(opt)
# StackAtt
elif opt.caption_model == 'stackatt':
model = StackAttModel(opt)
# DenseAtt
elif opt.caption_model == 'denseatt':
model = DenseAttModel(opt)
# Transformer
elif opt.caption_model == 'transformer':
if getattr(opt, 'cached_transformer', False):
model = cachedTransformer(opt)
else:
model = TransformerModel(opt)
# AoANet
elif opt.caption_model == 'aoa':
model = AoAModel(opt)
elif opt.caption_model == 'bert':
model = BertCapModel(opt)
elif opt.caption_model == 'm2transformer':
model = M2TransformerModel(opt)
else:
raise Exception("Caption model not supported: {}".format(opt.caption_model))
return model
|