|
import torch.nn as nn |
|
|
|
from .bayar_conv import BayarConv2d |
|
from .early_fusion_pre_filter import EarlyFusionPreFilter |
|
from .ensemble_model import EnsembleModel |
|
from .main_model import MainModel |
|
from .models import ModelBuilder, SegmentationModule |
|
from .srm_conv import SRMConv2d |
|
|
|
|
|
def get_ensemble_model(opt):
    """Build one single-modality model per entry of ``opt.modality`` and ensemble them.

    Args:
        opt: Options object. This function itself reads ``opt.modality``
            (iterated for modality names) and ``opt.mvc_single_weight``; the
            remaining attributes are consumed by ``get_single_modal_model``.

    Returns:
        An ``EnsembleModel`` constructed with a dict that maps each modality
        name to the model returned by ``get_single_modal_model(opt, name)``
        and with ``mvc_single_weight=opt.mvc_single_weight``.
    """
    # Models are built in the order modalities appear in ``opt.modality``;
    # a repeated name is built again and its later model replaces the earlier one.
    per_modality = {
        name: get_single_modal_model(opt, name) for name in opt.modality
    }
    return EnsembleModel(
        models=per_modality,
        mvc_single_weight=opt.mvc_single_weight,
    )
|
|
|
|
|
def get_single_modal_model(opt, modality):
    """Assemble a ``MainModel`` for a single input modality.

    An encoder and a decoder are built through ``ModelBuilder`` from the
    architecture/weight options in ``opt``. A pre-filter is then chosen from
    ``modality`` (see ``_build_pre_filter``), and everything is passed
    positionally to ``MainModel`` together with several ``opt`` settings.

    Args:
        opt: Options object. Attributes read here include ``encoder``,
            ``encoder_weight``, ``decoder``, ``decoder_weight``, ``fc_dim``,
            ``num_class``, ``dropout``, ``fcn_up``, ``volume_block_idx``,
            ``share_embed_head``, ``gem``, ``gem_coef``, ``gsm``,
            ``map_portion``, ``otsu_sel`` and ``otsu_portion``. Depending on
            ``modality``, ``bayar_magnitude`` and/or ``srm_clip`` are also read.
        modality: Modality name string; it is matched case-insensitively.

    Returns:
        The constructed ``MainModel`` instance.
    """
    encoder = ModelBuilder.build_encoder(
        arch=opt.encoder.lower(),
        fc_dim=opt.fc_dim,
        weights=opt.encoder_weight,
    )
    decoder = ModelBuilder.build_decoder(
        arch=opt.decoder.lower(),
        fc_dim=opt.fc_dim,
        weights=opt.decoder_weight,
        num_class=opt.num_class,
        dropout=opt.dropout,
        fcn_up=opt.fcn_up,
    )
    # Deliberately built after the encoder/decoder. Keeping this order means
    # any parameter initialisation consumes random state in the same sequence.
    pre_filter = _build_pre_filter(opt, modality)

    # MainModel takes these arguments positionally; the order is significant.
    return MainModel(
        encoder,
        decoder,
        opt.fc_dim,
        opt.volume_block_idx,
        opt.share_embed_head,
        pre_filter,
        opt.gem,
        opt.gem_coef,
        opt.gsm,
        opt.map_portion,
        opt.otsu_sel,
        opt.otsu_portion,
    )


def _build_pre_filter(opt, modality):
    """Return the input pre-filter module for ``modality`` (case-insensitive).

    Mapping:
        * ``"bayar"`` -> ``BayarConv2d(3, 3, 5, stride=1, padding=2,
          magnitude=opt.bayar_magnitude)``
        * ``"srm"``   -> ``SRMConv2d(stride=1, padding=2, clip=opt.srm_clip)``
        * ``"rgb"``   -> ``nn.Identity()``
        * any other name -> ``EarlyFusionPreFilter(bayar_magnitude=...,
          srm_clip=...)``
    """
    kind = modality.lower()
    if kind == "bayar":
        return BayarConv2d(
            3, 3, 5, stride=1, padding=2, magnitude=opt.bayar_magnitude
        )
    if kind == "srm":
        return SRMConv2d(stride=1, padding=2, clip=opt.srm_clip)
    if kind == "rgb":
        return nn.Identity()
    # NOTE(review): any unrecognised name, including a typo, silently lands
    # here rather than raising. Confirm this fallback is the intended
    # "early fusion" selector.
    return EarlyFusionPreFilter(
        bayar_magnitude=opt.bayar_magnitude, srm_clip=opt.srm_clip
    )
|
|