Spaces:
Build error
Build error
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. | |
""" | |
Miscellaneous utility functions | |
""" | |
import torch | |
def cat(tensors, dim=0):
    """
    Concatenate a list/tuple of tensors along ``dim``.

    Behaves like ``torch.cat`` except that a one-element sequence is returned
    as-is (the very same tensor object), so no copy is made in that case.
    """
    assert isinstance(tensors, (list, tuple))
    if len(tensors) != 1:
        return torch.cat(tensors, dim)
    # Exactly one element: hand it back untouched instead of copying.
    (only_tensor,) = tensors
    return only_tensor
def permute_and_flatten(layer, N, A, C, H, W):
    """
    Reorder a (N, A*C, H, W) prediction map into (N, H*W*A, C).

    The channel axis is split into (anchors, C). The spatial axes are moved in
    front of the anchor axis, and everything except the batch and ``C`` axes is
    flattened. ``A`` is accepted for interface symmetry only; the anchor count is
    inferred from the tensor via ``-1``.
    """
    per_anchor = layer.view(N, -1, C, H, W)           # (N, A, C, H, W)
    spatial_first = per_anchor.permute(0, 3, 4, 1, 2)  # (N, H, W, A, C)
    return spatial_first.reshape(N, -1, C)             # (N, H*W*A, C)
def concat_box_prediction_layers(box_regression, box_cls=None, token_logits=None):
    """
    Flatten per-level head outputs and concatenate them across feature levels.

    For each feature level, the outputs are permuted into the same layout as the
    labels. The labels are computed for all feature levels concatenated, so the
    classification scores, box regression and token logits are kept in that same
    concatenated representation.

    Args:
        box_regression: sequence of tensors, one per feature level, each of
            shape (N, A * 4, H, W).
        box_cls: optional sequence of tensors, one per feature level, each of
            shape (N, A * C, H, W). When given, it is paired level-by-level with
            ``box_regression`` via ``zip``, so the shorter sequence bounds the
            number of levels used. When ``None``, only the regression (and token
            logits) are flattened and ``None`` is returned in its place.
        token_logits: optional sequence of tensors, one per feature level, each
            of shape (N, A * T, H, W).

    Returns:
        Tuple ``(box_regression, box_cls, token_logits_stacked)``:
        - box_regression: tensor reshaped to (-1, 4).
        - box_cls: tensor reshaped to (-1, C), or ``None`` if ``box_cls`` was None.
        - token_logits_stacked: tensor of shape (N, sum over levels of H*W*A, T),
          or ``None`` if ``token_logits`` was None.
    """
    box_regression_flattened = []
    box_cls_flattened = []
    token_logit_flattened = []

    # Pair each regression level with its classification level. Without
    # classification outputs, pair every regression level with None so the
    # regression can still be flattened (zip(None, ...) would raise TypeError).
    if box_cls is None:
        level_pairs = ((None, reg) for reg in box_regression)
    else:
        level_pairs = zip(box_cls, box_regression)

    for box_cls_per_level, box_regression_per_level in level_pairs:
        Ax4 = box_regression_per_level.shape[1]
        A = Ax4 // 4
        if box_cls_per_level is None:
            N, _, H, W = box_regression_per_level.shape
        else:
            # N, H, W are read from the classification map, as before.
            N, AxC, H, W = box_cls_per_level.shape
            C = AxC // A
            box_cls_flattened.append(
                permute_and_flatten(box_cls_per_level, N, A, C, H, W)
            )
        box_regression_flattened.append(
            permute_and_flatten(box_regression_per_level, N, A, 4, H, W)
        )

    if token_logits is not None:
        # NOTE(review): A comes from the last box level processed above; this
        # assumes every level uses the same number of anchors per location.
        for token_logit_per_level in token_logits:
            N, AxT, H, W = token_logit_per_level.shape
            T = AxT // A
            token_logit_flattened.append(
                permute_and_flatten(token_logit_per_level, N, A, T, H, W)
            )

    # Concatenate along dim 1 (the per-image location/anchor axis) so the layout
    # matches labels generated with all feature maps concatenated as well.
    if box_cls is not None:
        # C is taken from the last classification level.
        box_cls = cat(box_cls_flattened, dim=1).reshape(-1, C)
    box_regression = cat(box_regression_flattened, dim=1).reshape(-1, 4)

    token_logits_stacked = None
    if token_logits is not None:
        # Kept batched: (N, total_locations, T).
        token_logits_stacked = cat(token_logit_flattened, dim=1)
    return box_regression, box_cls, token_logits_stacked
def round_channels(channels, divisor=8):
    """
    Round ``channels`` to the nearest multiple of ``divisor`` (at least ``divisor``).

    If rounding would shrink the value below 90% of ``channels``, one extra
    ``divisor`` step is added so the result never drops by more than ~10%.
    """
    nearest_multiple = int(channels + divisor / 2.0) // divisor * divisor
    rounded = max(nearest_multiple, divisor)
    too_small = float(rounded) < 0.9 * channels
    return rounded + divisor if too_small else rounded