Spaces:
Running
on
Zero
Running
on
Zero
# Copyright (c) Meta Platforms, Inc. and affiliates. | |
# All rights reserved. | |
# | |
# This source code is licensed under the license found in the | |
# LICENSE file in the root directory of this source tree. | |
import torch | |
import torch.nn as nn | |
class Scale(nn.Module):
    """Multiply the input by a single learnable scalar.

    The factor is stored as a zero-dimensional float parameter, so it
    broadcasts against an input of any shape.

    Args:
        scale (float): Initial value of the scale factor. Default: 1.0
    """

    def __init__(self, scale: float = 1.0):
        super().__init__()
        # Build the initial value first, then register it as a parameter so
        # it is picked up by optimizers and ``state_dict`` under ``scale``.
        initial_value = torch.tensor(scale, dtype=torch.float)
        self.scale = nn.Parameter(initial_value)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``x`` multiplied by the learnable factor."""
        scaled = x * self.scale
        return scaled
class LayerScale(nn.Module):
    """LayerScale layer: a learnable per-channel scaling of the input.

    A weight vector of length ``dim`` is reshaped so that it broadcasts
    along the channel axis of the input and is then multiplied with it.

    Args:
        dim (int): Dimension of input features (number of channels).
        inplace (bool): Whether to perform the multiplication in-place on
            the input tensor. Default: ``False``.
        data_format (str): The input data format, either 'channels_last'
            (channels on the last axis, e.g. (B, N, C)) or
            'channels_first' (channels on axis 1, e.g. (B, C, H, W)).
            Default: 'channels_last'.
        scale (float): Initial value of every entry of the scale weight.
            Default: 1e-5

    Raises:
        AssertionError: If ``data_format`` is not one of the two
            supported values.
    """

    def __init__(self,
                 dim: int,
                 inplace: bool = False,
                 data_format: str = 'channels_last',
                 scale: float = 1e-5):
        super().__init__()
        # Raise explicitly instead of using ``assert`` so the check is not
        # stripped under ``python -O``; the exception type and message are
        # kept unchanged for existing callers.
        if data_format not in ('channels_last', 'channels_first'):
            raise AssertionError(
                "'data_format' could only be channels_last or channels_first."
            )
        self.inplace = inplace
        self.data_format = data_format
        self.weight = nn.Parameter(torch.ones(dim) * scale)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Scale ``x`` along its channel axis by the learnable weight.

        Args:
            x (torch.Tensor): Input whose channel axis (axis 1 for
                'channels_first', the last axis otherwise) has size ``dim``.

        Returns:
            torch.Tensor: The scaled tensor. When ``inplace`` is set, this
            is ``x`` itself, modified in place.
        """
        ndim = x.dim()
        # Build a view shape that is 1 everywhere except the channel axis,
        # so the weight broadcasts over every other dimension.
        if self.data_format == 'channels_first':
            shape = (1, -1) + (1,) * (ndim - 2)
        else:
            shape = (1,) * (ndim - 1) + (-1,)
        weight = self.weight.view(*shape)
        if self.inplace:
            return x.mul_(weight)
        return x * weight