Spaces:
Runtime error
Runtime error
# Copyright (c) Facebook, Inc. and its affiliates. | |
"""
Wrappers around some nn functions, mainly to support empty tensors.

Ideally, support for empty tensors should be added directly to those functions in PyTorch.
These wrappers can be removed once https://github.com/pytorch/pytorch/issues/12013
is implemented.
"""
from typing import List | |
import torch | |
from torch.nn import functional as F | |
def cat(tensors: List[torch.Tensor], dim: int = 0):
    """
    Concatenate ``tensors`` along ``dim`` like :func:`torch.cat`, but skip the
    copy when the sequence holds exactly one tensor.

    Args:
        tensors: a list or tuple of tensors.
        dim: dimension to concatenate along.

    Returns:
        ``tensors[0]`` itself (no copy) for a single-element input, otherwise
        the result of ``torch.cat(tensors, dim)``.
    """
    assert isinstance(tensors, (list, tuple))
    # A lone tensor is returned as-is, avoiding torch.cat's allocation.
    return tensors[0] if len(tensors) == 1 else torch.cat(tensors, dim)
def cross_entropy(input, target, *, reduction="mean", **kwargs):
    """
    Same as `torch.nn.functional.cross_entropy`, but returns 0 (instead of nan)
    for empty inputs when ``reduction == "mean"``.

    Args:
        input: prediction tensor, passed through to ``F.cross_entropy``.
        target: target tensor, passed through to ``F.cross_entropy``.
        reduction (str): forwarded to ``F.cross_entropy`` unchanged.
        **kwargs: any other keyword arguments accepted by ``F.cross_entropy``.

    Returns:
        The loss computed by ``F.cross_entropy``. For an empty ``target`` with
        mean reduction, it returns ``input.sum() * 0.0``, a zero that stays
        connected to ``input``'s autograd graph.
    """
    if target.numel() == 0 and reduction == "mean":
        return input.sum() * 0.0  # connect the gradient
    # `reduction` is keyword-only and therefore NOT part of **kwargs; it must be
    # forwarded explicitly or non-"mean" reductions would silently be ignored.
    return F.cross_entropy(input, target, reduction=reduction, **kwargs)
class _NewEmptyTensorOp(torch.autograd.Function): | |
def forward(ctx, x, new_shape): | |
ctx.shape = x.shape | |
return x.new_empty(new_shape) | |
def backward(ctx, grad): | |
shape = ctx.shape | |
return _NewEmptyTensorOp.apply(grad, shape), None | |
class Conv2d(torch.nn.Conv2d):
    """
    :class:`torch.nn.Conv2d` extended with an optional normalization layer and
    activation, plus an eager-mode guard for empty inputs.
    """

    def __init__(self, *args, **kwargs):
        """
        Accepts everything `torch.nn.Conv2d` accepts, plus two extra keywords:

        Args:
            norm (nn.Module, optional): normalization layer applied to the conv output.
            activation (callable(Tensor) -> Tensor, optional): applied after `norm`.

        The order is always conv -> norm -> activation.
        """
        # Strip our own keywords before torch.nn.Conv2d sees (and rejects) them.
        norm_layer = kwargs.pop("norm", None)
        activation_fn = kwargs.pop("activation", None)
        super().__init__(*args, **kwargs)
        self.norm = norm_layer
        self.activation = activation_fn

    def forward(self, x):
        # The empty-input check is eager-only: TorchScript cannot handle SyncBatchNorm
        # (https://github.com/pytorch/pytorch/issues/40507). Scripting is only supported
        # in evaluation mode, and the PyTorch versions able to script this module
        # (1.6+) already accept empty inputs in `Conv2d`.
        if not torch.jit.is_scripting():
            if x.numel() == 0 and self.training:
                # https://github.com/pytorch/pytorch/issues/12013
                assert not isinstance(
                    self.norm, torch.nn.SyncBatchNorm
                ), "SyncBatchNorm does not support empty inputs!"
        # NOTE(review): F.conv2d is called directly with `self.padding`, so a
        # `padding_mode` other than "zeros" does not appear to be applied here — confirm.
        out = F.conv2d(
            x, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups
        )
        if self.norm is not None:
            out = self.norm(out)
        if self.activation is not None:
            out = self.activation(out)
        return out
# Re-export PyTorch's own implementations unchanged under this module's names
# (presumably because no empty-tensor workaround is needed for them here).
ConvTranspose2d, BatchNorm2d = torch.nn.ConvTranspose2d, torch.nn.BatchNorm2d
interpolate = F.interpolate
Linear = torch.nn.Linear
def nonzero_tuple(x):
    """
    Return the indices of the non-zero elements of ``x``, one index tensor per
    dimension, like ``torch.nonzero(x, as_tuple=True)``.

    TorchScript cannot compile ``nonzero(as_tuple=True)``
    (https://github.com/pytorch/pytorch/issues/38718). The scripted path therefore
    emulates it by splitting ``nonzero()``'s index matrix column-wise. A 0-d input is
    first promoted to 1-d so that ``unbind(1)`` has a column dimension to split.
    """
    if not torch.jit.is_scripting():
        return x.nonzero(as_tuple=True)
    if x.dim() == 0:
        x = x.unsqueeze(0)
    return x.nonzero().unbind(1)