|
from .bundled_loss import BundledLoss |
|
from .consisitency_loss import get_consistency_loss, get_volume_seg_map |
|
from .entropy_loss import get_entropy_loss |
|
from .loss import Loss |
|
from .map_label_loss import get_map_label_loss |
|
from .map_mask_loss import get_map_mask_loss |
|
from .multi_view_consistency_loss import ( |
|
get_multi_view_consistency_loss, |
|
get_spixel_tgt_map, |
|
) |
|
from .volume_label_loss import get_volume_label_loss |
|
from .volume_mask_loss import get_volume_mask_loss |
|
|
|
|
|
def get_bundled_loss(opt): |
|
"""Loss function for the overeall training, including the multi-view |
|
consistency loss.""" |
|
single_modality_loss = get_loss(opt) |
|
multi_view_consistency_loss = get_multi_view_consistency_loss(opt) |
|
volume_mask_loss = get_volume_mask_loss(opt) |
|
bundled_loss = BundledLoss( |
|
single_modality_loss, |
|
multi_view_consistency_loss, |
|
volume_mask_loss, |
|
opt.mvc_weight, |
|
opt.mvc_time_dependent, |
|
opt.mvc_steepness, |
|
opt.modality, |
|
opt.consistency_weight, |
|
opt.consistency_source, |
|
) |
|
|
|
return bundled_loss |
|
|
|
|
|
def get_loss(opt): |
|
"""Loss function for a single model, excluding the multi-view consistency |
|
loss.""" |
|
map_label_loss = get_map_label_loss(opt) |
|
volume_label_loss = get_volume_label_loss(opt) |
|
map_mask_loss = get_map_mask_loss(opt) |
|
volume_mask_loss = get_volume_mask_loss(opt) |
|
consisitency_loss = get_consistency_loss(opt) |
|
entropy_loss = get_entropy_loss(opt) |
|
loss = Loss( |
|
map_label_loss, |
|
volume_label_loss, |
|
map_mask_loss, |
|
volume_mask_loss, |
|
consisitency_loss, |
|
entropy_loss, |
|
opt.map_label_weight, |
|
opt.volume_label_weight, |
|
opt.map_mask_weight, |
|
opt.volume_mask_weight, |
|
opt.consistency_weight, |
|
opt.map_entropy_weight, |
|
opt.volume_entropy_weight, |
|
opt.consistency_source, |
|
) |
|
return loss |
|
|