import torch
import torch.nn.functional as F


def clip_bce(output_dict, target_dict):
    """Binary cross-entropy loss between clip-level predictions and targets.

    Args:
        output_dict: Mapping whose ``'clipwise_output'`` entry holds the
            predictions passed to ``F.binary_cross_entropy`` as its input.
            NOTE(review): that function expects probabilities in [0, 1],
            not raw logits -- confirm against the model's final activation.
        target_dict: Mapping whose ``'target'`` entry holds the ground-truth
            tensor, with the same shape as the predictions.

    Returns:
        The loss tensor produced by ``F.binary_cross_entropy`` with its
        default arguments.
    """
    return F.binary_cross_entropy(
        output_dict['clipwise_output'], target_dict['target'])


def get_loss_func(loss_type):
    """Return the loss function registered under ``loss_type``.

    Args:
        loss_type: Name of the loss. Only ``'clip_bce'`` is supported.

    Returns:
        A callable taking ``(output_dict, target_dict)`` and returning the
        loss.

    Raises:
        ValueError: If ``loss_type`` is not a supported loss name.
    """
    if loss_type == 'clip_bce':
        return clip_bce

    # Fail loudly here: previously an unknown name fell through and returned
    # None, which only surfaced later as "'NoneType' object is not callable".
    raise ValueError(
        f"Unsupported loss_type: {loss_type!r} (expected 'clip_bce')")