rawalkhirodkar's picture
Add initial commit
28c256d
raw
history blame
1.22 kB
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import torch
import torch.nn as nn
from mmengine.registry import MODELS
from mmengine.utils import digit_version
from mmengine.utils.dl_utils import TORCH_VERSION
class HSwish(nn.Module):
    """Hard-swish activation layer.

    Evaluates the following function element-wise:

    .. math::
        Hswish(x) = x * ReLU6(x + 3) / 6

    Args:
        inplace (bool): Passed through to the internal ``nn.ReLU6``. That
            layer is only ever applied to the temporary ``x + 3``, so the
            tensor handed to ``forward`` is not overwritten.
            Default: False.

    Returns:
        Tensor: The activated tensor.
    """

    def __init__(self, inplace: bool = False):
        super().__init__()
        # ReLU6 provides the clamping part of the hard-swish formula.
        self.act = nn.ReLU6(inplace=inplace)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Shift first, clamp with ReLU6, then scale the input by the gate.
        shifted = x + 3
        gate = self.act(shifted)
        return x * gate / 6
# Decide which implementation gets registered in MODELS as 'HSwish'.
if (TORCH_VERSION != 'parrots'
        and digit_version(TORCH_VERSION) >= digit_version('1.7')):
    # PyTorch >= 1.7: register the built-in nn.Hardswish under the
    # 'HSwish' name.
    MODELS.register_module(module=nn.Hardswish, name='HSwish')
else:
    # Hardswish is not supported when PyTorch version < 1.6, and Hardswish
    # in PyTorch 1.6 does not support inplace, so the HSwish class from this
    # file is registered instead. The 'parrots' case also takes this path.
    MODELS.register_module(module=HSwish)