# Copyright (c) Ye Liu. Licensed under the BSD 3-Clause License.

import math

import torch
import torch.nn as nn
import torch.nn.functional as F

from nncore.nn import LOSSES, Parameter, build_loss


@LOSSES.register()
class SampledNCELoss(nn.Module):
    """Sampled NCE-style contrastive loss between clip and query embeddings.

    For every sample in the batch, the clip at ``pos_clip`` is the positive and
    is contrasted against the other clips. Clips whose saliency is strictly
    higher than the positive clip's, or whose ``video_msk`` entry is 0, are
    excluded from the softmax by adding ``-inf`` to their logits.

    Args:
        temperature (float): Initial temperature. The logit scale is stored in
            log space and initialized to ``log(1 / temperature)``.
        max_scale (float): Upper bound applied to ``exp(scale)``.
        learnable (bool): If True the log-scale is a trainable ``Parameter``;
            otherwise it is registered as a (non-trainable) buffer.
        direction (str | tuple[str]): ``'row'``, ``'col'`` or both. ``'row'``
            applies log-softmax over clips (dim 1 of the ``[batch, clip]``
            similarity matrix); ``'col'`` applies it over the batch.
        loss_weight (float): Multiplier applied to the final loss.
    """

    def __init__(self,
                 temperature=0.07,
                 max_scale=100,
                 learnable=False,
                 direction=('row', 'col'),
                 loss_weight=1.0):
        super(SampledNCELoss, self).__init__()

        scale = torch.Tensor([math.log(1 / temperature)])

        if learnable:
            self.scale = Parameter(scale)
        else:
            self.register_buffer('scale', scale)

        self.temperature = temperature
        self.max_scale = max_scale
        self.learnable = learnable
        # Accept a bare string for convenience, e.g. direction='row'.
        self.direction = (direction, ) if isinstance(direction, str) else direction
        self.loss_weight = loss_weight

    def extra_repr(self):
        return ('temperature={}, max_scale={}, learnable={}, direction={}, loss_weight={}'
                .format(self.temperature, self.max_scale, self.learnable, self.direction,
                        self.loss_weight))

    def forward(self, video_emb, query_emb, video_msk, saliency, pos_clip):
        """Compute the weighted contrastive loss.

        Args:
            video_emb: Clip embeddings. The cosine similarity with
                ``query_emb`` over the last dim must be a 2-D ``[batch, clip]``
                matrix (presumably ``video_emb`` is ``(B, T, C)`` — TODO
                confirm).
            query_emb: Query embeddings broadcastable against ``video_emb``
                (presumably ``(B, 1, C)`` — TODO confirm).
            video_msk: Mask broadcastable to ``(B, T)``; clips where it is 0
                are excluded.
            saliency: ``(B, T)`` saliency scores.
            pos_clip: ``(B,)`` index of the positive clip for every sample.

        Returns:
            The weighted loss. NOTE: if ``direction`` contains neither
            ``'row'`` nor ``'col'`` this is a plain number
            (``0 * loss_weight``), not a tensor.
        """
        batch_inds = torch.arange(video_emb.size(0), device=video_emb.device)

        # Keep only clips that are no more salient than the positive one and
        # that are valid according to video_msk.
        pos_scores = saliency[batch_inds, pos_clip].unsqueeze(-1)
        loss_msk = (saliency <= pos_scores) * video_msk

        scale = self.scale.exp().clamp(max=self.max_scale)
        i_sim = F.cosine_similarity(video_emb, query_emb, dim=-1) * scale
        # Excluded clips get -inf logits so they vanish in the softmax.
        i_sim = i_sim + torch.where(loss_msk > 0, .0, float('-inf'))

        loss = 0

        if 'row' in self.direction:
            # Normalize over the clips of each sample.
            i_met = F.log_softmax(i_sim, dim=1)[batch_inds, pos_clip]
            loss = loss - i_met.sum() / i_met.size(0)

        if 'col' in self.direction:
            # Normalize over the batch for each clip position.
            j_sim = i_sim.t()
            j_met = F.log_softmax(j_sim, dim=1)[pos_clip, batch_inds]
            loss = loss - j_met.sum() / j_met.size(0)

        loss = loss * self.loss_weight
        return loss


@LOSSES.register()
class BundleLoss(nn.Module):
    """Bundle of the optional loss terms used during training.

    Each ``loss_*`` argument is a config passed to ``build_loss``. A term whose
    built loss is ``None`` is skipped in :meth:`forward`. This relies on
    ``build_loss(None)`` returning ``None`` (NOTE(review): confirm against
    nncore).

    Args:
        sample_radius (float): Radius, in units of ``point[:, 3]``, around a
            ground-truth segment center inside which points are treated as
            positive. Values ``<= 0`` fall back to "any point inside the
            segment".
        loss_cls: Config for the classification loss.
        loss_reg: Config for the boundary regression loss.
        loss_sal: Config for the saliency loss.
        loss_video_cal: Config for the calibration loss on embeddings stacked
            along dim 0.
        loss_layer_cal: Config for the calibration loss on embeddings stacked
            along dim 1.
    """

    def __init__(self,
                 sample_radius=1.5,
                 loss_cls=None,
                 loss_reg=None,
                 loss_sal=None,
                 loss_video_cal=None,
                 loss_layer_cal=None):
        super(BundleLoss, self).__init__()

        self._loss_cls = build_loss(loss_cls)
        self._loss_reg = build_loss(loss_reg)
        self._loss_sal = build_loss(loss_sal)
        self._loss_video_cal = build_loss(loss_video_cal)
        self._loss_layer_cal = build_loss(loss_layer_cal)

        self.sample_radius = sample_radius

    def get_target_single(self, point, gt_bnd, gt_cls):
        """Assign classification and regression targets for one sample.

        Args:
            point: ``(num_pts, >=4)`` tensor. Column 0 is the point location.
                Columns 1 and 2 are the lower/upper bounds on the admissible
                regression distance. Column 3 scales both the sampling radius
                and the regression targets (presumably the pyramid stride —
                TODO confirm).
            gt_bnd: ``(num_gts, 2)`` ground-truth ``[start, end]`` boundaries
                in the same units as ``point[:, 0]``.
            gt_cls: ``(num_gts, 1)`` integer labels in ``{0, 1}``.

        Returns:
            tuple:
                ``c_tgt`` of shape ``(num_pts,)``: 1 where at least one
                assigned ground truth has label 1, else 0.

                ``r_tgt`` of shape ``(num_pts, 2)``: distances from the point
                to the start/end of its assigned segment, divided by
                ``point[:, 3]``. For unassigned points ``r_tgt`` is not
                meaningful and is expected to be masked by ``c_tgt``.
        """
        num_pts, num_gts = point.size(0), gt_bnd.size(0)

        # Segment lengths, repeated once per point.
        lens = gt_bnd[:, 1] - gt_bnd[:, 0]
        lens = lens[None, :].repeat(num_pts, 1)

        # Signed distances from every point to every segment's start/end.
        gt_seg = gt_bnd[None].expand(num_pts, num_gts, 2)
        s = point[:, 0, None] - gt_seg[:, :, 0]
        e = gt_seg[:, :, 1] - point[:, 0, None]
        r_tgt = torch.stack((s, e), dim=-1)

        if self.sample_radius > 0:
            # Center sampling: a point is positive only if it lies within
            # sample_radius * point[:, 3] of the segment center AND inside
            # the segment itself.
            center = (gt_seg[:, :, 0] + gt_seg[:, :, 1]) / 2
            t_mins = center - point[:, 3, None] * self.sample_radius
            t_maxs = center + point[:, 3, None] * self.sample_radius
            dist_s = point[:, 0, None] - torch.maximum(t_mins, gt_seg[:, :, 0])
            dist_e = torch.minimum(t_maxs, gt_seg[:, :, 1]) - point[:, 0, None]
            center = torch.stack((dist_s, dist_e), dim=-1)
            cls_msk = center.min(-1)[0] >= 0
        else:
            cls_msk = r_tgt.min(-1)[0] >= 0

        # The farther boundary must fall inside the point's regression range.
        reg_dist = r_tgt.max(-1)[0]
        reg_msk = torch.logical_and((reg_dist >= point[:, 1, None]),
                                    (reg_dist <= point[:, 2, None]))

        # Invalid (point, gt) pairs get infinite length so they never win the
        # shortest-segment assignment below.
        lens.masked_fill_(cls_msk == 0, float('inf'))
        lens.masked_fill_(reg_msk == 0, float('inf'))
        min_len, min_len_inds = lens.min(dim=1)

        # All valid ground truths tied (within 1e-3) for the shortest length.
        min_len_mask = torch.logical_and((lens <= (min_len[:, None] + 1e-3)),
                                         (lens < float('inf'))).to(r_tgt.dtype)

        label = F.one_hot(gt_cls[:, 0], 2).to(r_tgt.dtype)
        c_tgt = torch.matmul(min_len_mask, label).clamp(min=0.0, max=1.0)[:, 1]

        # Pick, for each point, the distances to its shortest valid segment.
        pts_inds = torch.arange(num_pts, device=r_tgt.device)
        r_tgt = r_tgt[pts_inds, min_len_inds] / point[:, 3, None]

        return c_tgt, r_tgt

    def get_target(self, data):
        """Build classification/regression targets for the whole batch.

        ``data['boundary'][i]`` is multiplied by ``data['fps'][i]``, presumably
        to express it in the units of ``data['point'][:, 0]`` (TODO confirm).
        Every ground-truth segment is given label 1.

        Returns:
            tuple: ``cls_tgt`` stacked to ``(B, num_pts)`` and ``reg_tgt``
            stacked to ``(B, num_pts, 2)``.
        """
        cls_tgt, reg_tgt = [], []

        for i in range(data['boundary'].size(0)):
            gt_bnd = data['boundary'][i] * data['fps'][i]
            gt_cls = gt_bnd.new_ones(gt_bnd.size(0), 1).long()

            c_tgt, r_tgt = self.get_target_single(data['point'], gt_bnd, gt_cls)

            cls_tgt.append(c_tgt)
            reg_tgt.append(r_tgt)

        cls_tgt = torch.stack(cls_tgt)
        reg_tgt = torch.stack(reg_tgt)

        return cls_tgt, reg_tgt

    def loss_cls(self, data, output, cls_tgt):
        """Classification loss on ``data['out_class']``.

        The per-position weight is ``data['pymid_msk']`` concatenated along
        dim 1 (presumably one mask per pyramid level — TODO confirm). The
        loss is averaged by the mask sum.
        """
        src = data['out_class'].squeeze(-1)
        msk = torch.cat(data['pymid_msk'], dim=1)

        loss_cls = self._loss_cls(src, cls_tgt, weight=msk, avg_factor=msk.sum())

        output['loss_cls'] = loss_cls
        return output

    def loss_reg(self, data, output, cls_tgt, reg_tgt):
        """Boundary regression loss on ``data['out_coord']``.

        Only positions with a nonzero classification target contribute. The
        mask is broadcast to both (start, end) coordinates.
        """
        src = data['out_coord']
        msk = cls_tgt.unsqueeze(2).repeat(1, 1, 2).bool()

        loss_reg = self._loss_reg(src, reg_tgt, weight=msk, avg_factor=msk.sum())

        output['loss_reg'] = loss_reg
        return output

    def loss_sal(self, data, output):
        """Saliency loss, using the first column of ``data['pos_clip']`` as the positive."""
        video_emb = data['video_emb']
        query_emb = data['query_emb']
        video_msk = data['video_msk']
        saliency = data['saliency']
        pos_clip = data['pos_clip'][:, 0]

        output['loss_sal'] = self._loss_sal(video_emb, query_emb, video_msk, saliency, pos_clip)
        return output

    def loss_cal(self, data, output):
        """Calibration losses on the positive clip across collected embeddings.

        For each pair in ``zip(data['coll_v'], data['coll_q'])``, the video
        embedding at ``pos_clip`` and ``q_emb[:, 0]`` are gathered.

        - Stacking them along dim 0 feeds ``loss_video_cal``.
        - Stacking them along dim 1 feeds ``loss_layer_cal``.

        Each term is computed only if it is configured. ``forward`` enters
        this method when either one is configured, so calling an unconfigured
        (``None``) term would raise ``TypeError``.
        """
        pos_clip = data['pos_clip'][:, 0]
        batch_inds = torch.arange(pos_clip.size(0), device=pos_clip.device)

        coll_v_emb, coll_q_emb = [], []
        for v_emb, q_emb in zip(data['coll_v'], data['coll_q']):
            coll_v_emb.append(v_emb[batch_inds, pos_clip])
            coll_q_emb.append(q_emb[:, 0])

        if self._loss_video_cal is not None:
            v_emb = torch.stack(coll_v_emb)
            q_emb = torch.stack(coll_q_emb)
            output['loss_video_cal'] = self._loss_video_cal(v_emb, q_emb)

        if self._loss_layer_cal is not None:
            v_emb = torch.stack(coll_v_emb, dim=1)
            q_emb = torch.stack(coll_q_emb, dim=1)
            output['loss_layer_cal'] = self._loss_layer_cal(v_emb, q_emb)

        return output

    def forward(self, data, output):
        """Compute every configured loss term and store it in ``output``.

        When regression is configured, classification targets come from
        :meth:`get_target`. Otherwise ``data['saliency']`` is used as the
        classification target.

        Returns:
            ``output`` with the relevant ``loss_*`` entries added.
        """
        if self._loss_reg is not None:
            cls_tgt, reg_tgt = self.get_target(data)
            output = self.loss_reg(data, output, cls_tgt, reg_tgt)
        else:
            cls_tgt = data['saliency']

        if self._loss_cls is not None:
            output = self.loss_cls(data, output, cls_tgt)

        if self._loss_sal is not None:
            output = self.loss_sal(data, output)

        if self._loss_video_cal is not None or self._loss_layer_cal is not None:
            output = self.loss_cal(data, output)

        return output