# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
"""
DETR model and criterion classes.
"""
import torch
import torch.nn.functional as F
from torch import nn

from utils.span_utils import generalized_temporal_iou, span_cxw_to_xx

from model.transformer import build_transformer
from model.matcher import build_matcher
from model.position_encoding import build_position_encoding
def accuracy(output, target, topk=(1,)):
    """Computes the precision@k for the specified values of k
    output: (#items, #classes)
    target: int, a single ground-truth class index shared by all items
    """
    maxk = max(topk)
    num_items = output.size(0)
    _, pred = output.topk(maxk, 1, True, True)
    pred = pred.t()
    correct = pred.eq(target)
    res = []
    for k in topk:
        correct_k = correct[:k].view(-1).float().sum(0)
        res.append(correct_k.mul_(100.0 / num_items))
    return res
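
# Illustrative example (added; not part of the original file): with logits for 4 items
# over 2 classes and a shared ground-truth label of 0,
#   top1_acc, = accuracy(output=torch.randn(4, 2), target=0, topk=(1,))
# `top1_acc` is the percentage (0-100) of items whose argmax prediction equals 0.
# SetCriterion.loss_labels below uses this to report `class_error` on matched queries.
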
class Model(nn.Module):
    """ This is the Moment-DETR module that performs moment localization. """

    def __init__(self, transformer, position_embed, txt_position_embed, txt_dim, vid_dim,
                 num_queries, input_dropout, aux_loss=False,
                 contrastive_align_loss=False, contrastive_hdim=64,
                 max_v_l=75, span_loss_type="l1", use_txt_pos=False, n_input_proj=2):
        """ Initializes the model.
        Parameters:
            transformer: torch module of the transformer architecture. See transformer.py
            position_embed: torch module of the position_embedding. See position_encoding.py
            txt_position_embed: position_embedding for text
            txt_dim: int, text query input dimension
            vid_dim: int, video feature input dimension
            num_queries: int, number of moment queries, i.e., detection slots. This is the maximal
                number of moments Moment-DETR can detect in a single video.
            aux_loss: True if auxiliary decoding losses (loss at each decoder layer) are to be used.
            contrastive_align_loss: if True, perform span-token contrastive learning
            contrastive_hdim: int, dimension used for projecting the embeddings before computing contrastive loss
            max_v_l: int, maximum #clips in videos
            span_loss_type: str, one of [l1, ce]
                l1: (center-x, width) regression.
                ce: (st_idx, ed_idx) classification.
            # foreground_thd: float, intersection over prediction >= foreground_thd: labeled as foreground
            # background_thd: float, intersection over prediction <= background_thd: labeled background
        """
        super().__init__()
        self.num_queries = num_queries
        self.transformer = transformer
        self.position_embed = position_embed
        self.txt_position_embed = txt_position_embed
        hidden_dim = transformer.d_model
        self.span_loss_type = span_loss_type
        self.max_v_l = max_v_l
        span_pred_dim = 2 if span_loss_type == "l1" else max_v_l * 2
        self.span_embed = MLP(hidden_dim, hidden_dim, span_pred_dim, 3)
        self.class_embed = nn.Linear(hidden_dim, 2)  # 0: foreground, 1: background (see SetCriterion labels)
        self.use_txt_pos = use_txt_pos
        self.n_input_proj = n_input_proj
        # self.foreground_thd = foreground_thd
        # self.background_thd = background_thd
        self.query_embed = nn.Embedding(num_queries, hidden_dim)
        relu_args = [True] * 3
        relu_args[n_input_proj-1] = False
        self.input_txt_proj = nn.Sequential(*[
            LinearLayer(txt_dim, hidden_dim, layer_norm=True, dropout=input_dropout, relu=relu_args[0]),
            LinearLayer(hidden_dim, hidden_dim, layer_norm=True, dropout=input_dropout, relu=relu_args[1]),
            LinearLayer(hidden_dim, hidden_dim, layer_norm=True, dropout=input_dropout, relu=relu_args[2])
        ][:n_input_proj])
        self.input_vid_proj = nn.Sequential(*[
            LinearLayer(vid_dim, hidden_dim, layer_norm=True, dropout=input_dropout, relu=relu_args[0]),
            LinearLayer(hidden_dim, hidden_dim, layer_norm=True, dropout=input_dropout, relu=relu_args[1]),
            LinearLayer(hidden_dim, hidden_dim, layer_norm=True, dropout=input_dropout, relu=relu_args[2])
        ][:n_input_proj])
        self.contrastive_align_loss = contrastive_align_loss
        if contrastive_align_loss:
            self.contrastive_align_projection_query = nn.Linear(hidden_dim, contrastive_hdim)
            self.contrastive_align_projection_txt = nn.Linear(hidden_dim, contrastive_hdim)
            self.contrastive_align_projection_vid = nn.Linear(hidden_dim, contrastive_hdim)
        self.saliency_proj = nn.Linear(hidden_dim, 1)
        self.aux_loss = aux_loss
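
    # Note (added for clarity; not in the original file): relu_args above disables the ReLU
    # only on the last of the n_input_proj projection layers, e.g. with n_input_proj=2 each
    # input stack is
    #   LinearLayer(in_dim, hidden_dim, relu=True) -> LinearLayer(hidden_dim, hidden_dim, relu=False)
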
    def forward(self, src_txt, src_txt_mask, src_vid, src_vid_mask):
        """The forward expects four tensors:
            - src_txt: [batch_size, L_txt, D_txt]
            - src_txt_mask: [batch_size, L_txt], containing 0 on padded positions;
                converted to 1 on padding later for the transformer
            - src_vid: [batch_size, L_vid, D_vid]
            - src_vid_mask: [batch_size, L_vid], containing 0 on padded positions;
                converted to 1 on padding later for the transformer

        It returns a dict with the following elements:
            - "pred_spans": The normalized span coordinates for all queries, represented as
                (center_x, width). These values are normalized in [0, 1],
                relative to the duration of each individual video (disregarding possible padding).
                See PostProcess for information on how to retrieve the unnormalized spans.
            - "aux_outputs": Optional, only returned when auxiliary losses are activated. It is a list of
                dictionaries containing the two above keys for each decoder layer.
        """
        src_vid = self.input_vid_proj(src_vid)
        src_txt = self.input_txt_proj(src_txt)
        src = torch.cat([src_vid, src_txt], dim=1)  # (bsz, L_vid+L_txt, d)
        mask = torch.cat([src_vid_mask, src_txt_mask], dim=1).bool()  # (bsz, L_vid+L_txt)
        # TODO should we remove or use different positional embeddings to the src_txt?
        pos_vid = self.position_embed(src_vid, src_vid_mask)  # (bsz, L_vid, d)
        pos_txt = self.txt_position_embed(src_txt) if self.use_txt_pos else torch.zeros_like(src_txt)  # (bsz, L_txt, d)
        # pos_txt = torch.zeros_like(src_txt)
        # pad zeros for txt positions
        pos = torch.cat([pos_vid, pos_txt], dim=1)
        # (#layers, bsz, #queries, d), (bsz, L_vid+L_txt, d)
        hs, memory = self.transformer(src, ~mask, self.query_embed.weight, pos)
        outputs_class = self.class_embed(hs)  # (#layers, batch_size, #queries, #classes)
        outputs_coord = self.span_embed(hs)  # (#layers, bsz, #queries, 2 or max_v_l * 2)
        if self.span_loss_type == "l1":
            outputs_coord = outputs_coord.sigmoid()
        out = {'pred_logits': outputs_class[-1], 'pred_spans': outputs_coord[-1]}

        txt_mem = memory[:, src_vid.shape[1]:]  # (bsz, L_txt, d)
        vid_mem = memory[:, :src_vid.shape[1]]  # (bsz, L_vid, d)
        if self.contrastive_align_loss:
            proj_queries = F.normalize(self.contrastive_align_projection_query(hs), p=2, dim=-1)
            proj_txt_mem = F.normalize(self.contrastive_align_projection_txt(txt_mem), p=2, dim=-1)
            proj_vid_mem = F.normalize(self.contrastive_align_projection_vid(vid_mem), p=2, dim=-1)
            out.update(dict(
                proj_queries=proj_queries[-1],
                proj_txt_mem=proj_txt_mem,
                proj_vid_mem=proj_vid_mem
            ))

        out["saliency_scores"] = self.saliency_proj(vid_mem).squeeze(-1)  # (bsz, L_vid)

        if self.aux_loss:
            # assert proj_queries and proj_txt_mem
            out['aux_outputs'] = [
                {'pred_logits': a, 'pred_spans': b} for a, b in zip(outputs_class[:-1], outputs_coord[:-1])]
            if self.contrastive_align_loss:
                assert proj_queries is not None
                for idx, d in enumerate(proj_queries[:-1]):
                    out['aux_outputs'][idx].update(dict(proj_queries=d, proj_txt_mem=proj_txt_mem))
        return out
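
    # Illustrative usage sketch (added; not part of the original file; shapes follow the
    # docstring above and are assumptions):
    #   out = model(src_txt=torch.randn(2, 20, txt_dim), src_txt_mask=torch.ones(2, 20),
    #               src_vid=torch.randn(2, 75, vid_dim), src_vid_mask=torch.ones(2, 75))
    #   out["pred_logits"]      # (2, num_queries, 2) foreground/background logits
    #   out["pred_spans"]       # (2, num_queries, 2) normalized (center, width) if span_loss_type == "l1"
    #   out["saliency_scores"]  # (2, 75) one saliency score per video clip
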
    # @torch.jit.unused
    # def _set_aux_loss(self, outputs_class, outputs_coord):
    #     # this is a workaround to make torchscript happy, as torchscript
    #     # doesn't support dictionary with non-homogeneous values, such
    #     # as a dict having both a Tensor and a list.
    #     return [{'pred_logits': a, 'pred_spans': b}
    #             for a, b in zip(outputs_class[:-1], outputs_coord[:-1])]
class SetCriterion(nn.Module):
    """ This class computes the loss for DETR.
    The process happens in two steps:
        1) we compute a hungarian assignment between ground-truth spans and the outputs of the model
        2) we supervise each pair of matched ground-truth / prediction (supervise class and span)
    """

    def __init__(self, matcher, weight_dict, eos_coef, losses, temperature, span_loss_type, max_v_l,
                 saliency_margin=1):
        """ Create the criterion.
        Parameters:
            matcher: module able to compute a matching between targets and proposals
            weight_dict: dict containing as key the names of the losses and as values their relative weight.
            eos_coef: relative classification weight applied to the no-object (background) category
            losses: list of all the losses to be applied. See get_loss for list of available losses.
            temperature: float, temperature for NCE loss
            span_loss_type: str, one of [l1, ce]
            max_v_l: int, maximum #clips in videos
            saliency_margin: float, margin for the saliency hinge loss
        """
        super().__init__()
        self.matcher = matcher
        self.weight_dict = weight_dict
        self.losses = losses
        self.temperature = temperature
        self.span_loss_type = span_loss_type
        self.max_v_l = max_v_l
        self.saliency_margin = saliency_margin

        # foreground and background classification
        self.foreground_label = 0
        self.background_label = 1
        self.eos_coef = eos_coef
        empty_weight = torch.ones(2)
        empty_weight[-1] = self.eos_coef  # lower weight for background (index 1, foreground index 0)
        self.register_buffer('empty_weight', empty_weight)
    def loss_spans(self, outputs, targets, indices):
        """Compute the losses related to the spans: the L1 regression loss and the generalized IoU loss.
        targets dicts must contain the key "spans" containing a tensor of dim [nb_tgt_spans, 2]
        The target spans are expected in format (center_x, w), normalized by the video length.
        """
        assert 'pred_spans' in outputs
        targets = targets["span_labels"]
        idx = self._get_src_permutation_idx(indices)
        src_spans = outputs['pred_spans'][idx]  # (#spans, 2) for l1, (#spans, max_v_l * 2) for ce
        tgt_spans = torch.cat([t['spans'][i] for t, (_, i) in zip(targets, indices)], dim=0)  # (#spans, 2)
        if self.span_loss_type == "l1":
            loss_span = F.l1_loss(src_spans, tgt_spans, reduction='none')
            loss_giou = 1 - torch.diag(generalized_temporal_iou(span_cxw_to_xx(src_spans), span_cxw_to_xx(tgt_spans)))
        else:  # ce
            n_spans = src_spans.shape[0]
            src_spans = src_spans.view(n_spans, 2, self.max_v_l).transpose(1, 2)
            loss_span = F.cross_entropy(src_spans, tgt_spans, reduction='none')

            # giou
            # src_span_indices = src_spans.max(1)[1]  # (#spans, 2)
            # src_span_indices[:, 1] += 1  # ed non-inclusive [st, ed)
            #
            # tgt_span_indices = tgt_spans
            # tgt_span_indices[:, 1] += 1
            # loss_giou = 1 - torch.diag(generalized_temporal_iou(src_span_indices, tgt_span_indices))
            loss_giou = loss_span.new_zeros([1])

        losses = {}
        losses['loss_b'] = loss_span.mean()
        losses['loss_g'] = loss_giou.mean()
        return losses
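
    # Worked example (added; not in the original file), assuming span_cxw_to_xx converts
    # normalized (center, width) spans to (start, end):
    #   span_cxw_to_xx(torch.tensor([[0.5, 0.2]]))  # -> tensor([[0.4, 0.6]])
    # generalized_temporal_iou is then computed on these (start, end) pairs for the GIoU term.
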
    def loss_labels(self, outputs, targets, indices, log=True):
        """Classification loss (NLL)
        targets dicts must contain the key "labels" containing a tensor of dim [nb_target_boxes]
        """
        # TODO add foreground and background classifier. use all non-matched as background.
        assert 'pred_logits' in outputs
        src_logits = outputs['pred_logits']  # (batch_size, #queries, #classes=2)
        # idx is a tuple of two 1D tensors (batch_idx, src_idx), of the same length == #objects in batch
        idx = self._get_src_permutation_idx(indices)
        target_classes = torch.full(src_logits.shape[:2], self.background_label,
                                    dtype=torch.int64, device=src_logits.device)  # (batch_size, #queries)
        target_classes[idx] = self.foreground_label

        loss_ce = F.cross_entropy(src_logits.transpose(1, 2), target_classes, self.empty_weight, reduction="none")
        losses = {'loss_f': loss_ce.mean()}

        if log:
            # TODO this should probably be a separate loss, not hacked in this one here
            losses['class_error'] = 100 - accuracy(src_logits[idx], self.foreground_label)[0]
        return losses
    def loss_saliency(self, outputs, targets, indices, log=True):
        """higher scores for positive clips"""
        if "saliency_pos_labels" not in targets:
            return {"loss_s_intra": 0}
        saliency_scores = outputs["saliency_scores"]  # (N, L)
        pos_indices = targets["saliency_pos_labels"]  # (N, #pairs)
        neg_indices = targets["saliency_neg_labels"]  # (N, #pairs)
        num_pairs = pos_indices.shape[1]  # typically 2 or 4
        batch_indices = torch.arange(len(saliency_scores)).to(saliency_scores.device)
        pos_scores = torch.stack(
            [saliency_scores[batch_indices, pos_indices[:, col_idx]] for col_idx in range(num_pairs)], dim=1)
        neg_scores = torch.stack(
            [saliency_scores[batch_indices, neg_indices[:, col_idx]] for col_idx in range(num_pairs)], dim=1)
        loss_saliency = torch.clamp(self.saliency_margin + neg_scores - pos_scores, min=0).sum() \
            / (len(pos_scores) * num_pairs) * 2  # * 2 to keep the loss the same scale
        return {"loss_s_intra": loss_saliency}
    def loss_contrastive_align(self, outputs, targets, indices, log=True):
        """encourage higher scores between matched query span and input text"""
        normalized_text_embed = outputs["proj_txt_mem"]  # (bsz, #tokens, d)  text tokens
        normalized_img_embed = outputs["proj_queries"]  # (bsz, #queries, d)
        logits = torch.einsum(
            "bmd,bnd->bmn", normalized_img_embed, normalized_text_embed)  # (bsz, #queries, #tokens)
        logits = logits.sum(2) / self.temperature  # (bsz, #queries)
        idx = self._get_src_permutation_idx(indices)
        positive_map = torch.zeros_like(logits, dtype=torch.bool)
        positive_map[idx] = True
        positive_logits = logits.masked_fill(~positive_map, 0)

        pos_term = positive_logits.sum(1)  # (bsz, )
        num_pos = positive_map.sum(1)  # (bsz, )
        neg_term = logits.logsumexp(1)  # (bsz, )
        loss_nce = -pos_term / num_pos + neg_term  # (bsz, )
        losses = {"loss_contrastive_align": loss_nce.mean()}
        return losses
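
    # The loss above is an InfoNCE-style objective computed over queries: per sample,
    #   loss = -(1 / |P|) * sum_{q in P} logit_q + logsumexp_q(logit_q)
    # where P is the set of queries matched to ground-truth spans, so matched queries
    # are encouraged to align with the text more strongly than unmatched ones.
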
    def loss_contrastive_align_vid_txt(self, outputs, targets, indices, log=True):
        """encourage higher scores between matched query span and input text"""
        # TODO (1) align vid_mem and txt_mem;
        # TODO (2) change L1 loss as CE loss on 75 labels, similar to soft token prediction in MDETR
        normalized_text_embed = outputs["proj_txt_mem"]  # (bsz, #tokens, d)  text tokens
        normalized_img_embed = outputs["proj_queries"]  # (bsz, #queries, d)
        logits = torch.einsum(
            "bmd,bnd->bmn", normalized_img_embed, normalized_text_embed)  # (bsz, #queries, #tokens)
        logits = logits.sum(2) / self.temperature  # (bsz, #queries)
        idx = self._get_src_permutation_idx(indices)
        positive_map = torch.zeros_like(logits, dtype=torch.bool)
        positive_map[idx] = True
        positive_logits = logits.masked_fill(~positive_map, 0)

        pos_term = positive_logits.sum(1)  # (bsz, )
        num_pos = positive_map.sum(1)  # (bsz, )
        neg_term = logits.logsumexp(1)  # (bsz, )
        loss_nce = -pos_term / num_pos + neg_term  # (bsz, )
        losses = {"loss_contrastive_align": loss_nce.mean()}
        return losses
    def _get_src_permutation_idx(self, indices):
        # permute predictions following indices
        batch_idx = torch.cat([torch.full_like(src, i) for i, (src, _) in enumerate(indices)])
        src_idx = torch.cat([src for (src, _) in indices])
        return batch_idx, src_idx  # two 1D tensors of the same length

    def _get_tgt_permutation_idx(self, indices):
        # permute targets following indices
        batch_idx = torch.cat([torch.full_like(tgt, i) for i, (_, tgt) in enumerate(indices)])
        tgt_idx = torch.cat([tgt for (_, tgt) in indices])
        return batch_idx, tgt_idx

    def get_loss(self, loss, outputs, targets, indices, **kwargs):
        loss_map = {
            "spans": self.loss_spans,
            "labels": self.loss_labels,
            "contrastive_align": self.loss_contrastive_align,
            "saliency": self.loss_saliency,
        }
        assert loss in loss_map, f'do you really want to compute {loss} loss?'
        return loss_map[loss](outputs, targets, indices, **kwargs)
    def forward(self, outputs, targets):
        """ This performs the loss computation.
        Parameters:
            outputs: dict of tensors, see the output specification of the model for the format
            targets: list of dicts, such that len(targets) == batch_size.
                The expected keys in each dict depend on the losses applied, see each loss' doc
        """
        outputs_without_aux = {k: v for k, v in outputs.items() if k != 'aux_outputs'}

        # Retrieve the matching between the outputs of the last layer and the targets
        # list(tuples), each tuple is (pred_span_indices, tgt_span_indices)
        indices = self.matcher(outputs_without_aux, targets)

        # Compute all the requested losses
        losses = {}
        for loss in self.losses:
            losses.update(self.get_loss(loss, outputs, targets, indices))

        # In case of auxiliary losses, we repeat this process with the output of each intermediate layer.
        if 'aux_outputs' in outputs:
            for i, aux_outputs in enumerate(outputs['aux_outputs']):
                indices = self.matcher(aux_outputs, targets)
                for loss in self.losses:
                    if "saliency" == loss:  # skip as it is only computed at the top layer
                        continue
                    kwargs = {}
                    l_dict = self.get_loss(loss, aux_outputs, targets, indices, **kwargs)
                    l_dict = {k + f'_{i}': v for k, v in l_dict.items()}
                    losses.update(l_dict)

        return losses
class MLP(nn.Module):
    """ Very simple multi-layer perceptron (also called FFN)"""

    def __init__(self, input_dim, hidden_dim, output_dim, num_layers):
        super().__init__()
        self.num_layers = num_layers
        h = [hidden_dim] * (num_layers - 1)
        self.layers = nn.ModuleList(nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim]))

    def forward(self, x):
        for i, layer in enumerate(self.layers):
            x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x)
        return x
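
# Illustrative sketch (added; not in the original file): MLP(256, 256, 2, 3), as used for
# `span_embed` with hidden_dim=256 and span_pred_dim=2, builds
#   Linear(256, 256) -> ReLU -> Linear(256, 256) -> ReLU -> Linear(256, 2)
# i.e. ReLU is applied after every layer except the last.
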
class LinearLayer(nn.Module):
    """linear layer configurable with layer normalization, dropout, ReLU."""

    def __init__(self, in_hsz, out_hsz, layer_norm=True, dropout=0.1, relu=True):
        super(LinearLayer, self).__init__()
        self.relu = relu
        self.layer_norm = layer_norm
        if layer_norm:
            self.LayerNorm = nn.LayerNorm(in_hsz)
        layers = [
            nn.Dropout(dropout),
            nn.Linear(in_hsz, out_hsz)
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """(N, L, D)"""
        if self.layer_norm:
            x = self.LayerNorm(x)
        x = self.net(x)
        if self.relu:
            x = F.relu(x, inplace=True)
        return x  # (N, L, D)
def build_model(args):
    # the `num_classes` naming here is somewhat misleading.
    # it indeed corresponds to `max_obj_id + 1`, where max_obj_id
    # is the maximum id for a class in your dataset. For example,
    # COCO has a max_obj_id of 90, so we pass `num_classes` to be 91.
    # As another example, for a dataset that has a single class with id 1,
    # you should pass `num_classes` to be 2 (max_obj_id + 1).
    # For more details on this, check the following discussion
    # https://github.com/facebookresearch/moment_bert/issues/108#issuecomment-650269223
    device = torch.device(args.device)

    transformer = build_transformer(args)
    position_embedding, txt_position_embedding = build_position_encoding(args)

    model = Model(
        transformer,
        position_embedding,
        txt_position_embedding,
        txt_dim=args.t_feat_dim,
        vid_dim=args.v_feat_dim,
        num_queries=args.num_queries,
        input_dropout=args.input_dropout,
        aux_loss=args.aux_loss,
        # contrastive_align_loss=args.contrastive_align_loss,
        # contrastive_hdim=args.contrastive_hdim,
        span_loss_type=args.span_loss_type,
        use_txt_pos=args.use_txt_pos,
        n_input_proj=args.n_input_proj,
    )

    matcher = build_matcher(args)
    weight_dict = {"loss_b": args.b_loss_coef,
                   "loss_g": args.g_loss_coef,
                   "loss_f": args.f_loss_coef,
                   "loss_s_intra": args.s_loss_intra_coef,
                   "loss_s_inter": args.s_loss_inter_coef}
    # if args.contrastive_align_loss:
    #     weight_dict["loss_contrastive_align"] = args.contrastive_align_loss_coef

    # TODO this is a hack
    if args.aux_loss:
        aux_weight_dict = {}
        for i in range(args.dec_layers - 1):
            # saliency losses are only computed at the top decoder layer, so they get no aux weights
            aux_weight_dict.update({k + f'_{i}': v for k, v in weight_dict.items()
                                    if k not in ["loss_s_intra", "loss_s_inter"]})
        weight_dict.update(aux_weight_dict)
    losses = ['spans', 'labels', 'saliency']
    # if args.contrastive_align_loss:
    #     losses += ["contrastive_align"]
    criterion = SetCriterion(
        matcher=matcher, weight_dict=weight_dict, losses=losses,
        eos_coef=args.eos_coef, temperature=args.temperature,
        span_loss_type=args.span_loss_type, max_v_l=args.max_v_l,
        saliency_margin=args.saliency_margin
    )
    criterion.to(device)
    return model, criterion
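
# Illustrative usage sketch (added; not part of the original file). `args` is typically an
# argparse/config namespace; the fields read in this file are listed below, and
# build_transformer / build_position_encoding / build_matcher read additional ones.
#   args.device, args.t_feat_dim, args.v_feat_dim, args.num_queries, args.input_dropout,
#   args.aux_loss, args.span_loss_type, args.use_txt_pos, args.n_input_proj,
#   args.b_loss_coef, args.g_loss_coef, args.f_loss_coef, args.s_loss_intra_coef,
#   args.s_loss_inter_coef, args.dec_layers, args.eos_coef, args.temperature,
#   args.max_v_l, args.saliency_margin
#   model, criterion = build_model(args)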