File size: 3,266 Bytes
9d0d223
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import torch
import torch.nn as nn
from torch import Tensor
from typing import Union, Callable


class CustomGLU(nn.Module):
    r"""Custom Gated Linear Unit activation.

    Applies a modified gated linear unit :math:`a * f(b)` where :math:`a` is the first half
    of the input matrices, :math:`b` is the second half, and :math:`f` is a provided activation
    function (e.g. sigmoid, swish, etc.).

    Args:
        activation (nn.Module): The custom activation to apply in the Gated Linear Unit
        dim (int): the dimension on which to split the input. Default: -1

    Shape:
        - Input: :math:`(\ast_1, N, \ast_2)` where `*` means any number of additional
          dimensions
        - Output: :math:`(\ast_1, M, \ast_2)` where :math:`M=N/2`

    Raises:
        ValueError: if the size of the input along ``dim`` is odd.

    Examples::
        >>> m = CustomGLU(nn.Sigmoid())
        >>> input = torch.randn(4, 2)
        >>> output = m(input)
    """

    def __init__(self, activation: nn.Module, dim: int = -1):
        super().__init__()
        self.dim = dim
        self.activation = activation

    def forward(self, x: Tensor) -> Tensor:
        """Split ``x`` in half along ``self.dim`` and return ``a * activation(b)``."""
        # Validate explicitly instead of with ``assert``: asserts are stripped under
        # ``python -O``, and for an odd size torch.chunk returns unequal halves
        # (e.g. 2 and 1) that can silently broadcast into a wrong result.
        size = x.shape[self.dim]
        if size % 2 != 0:
            raise ValueError(
                f"CustomGLU expects an even size along dim {self.dim}, got {size}"
            )
        a, b = torch.chunk(x, 2, dim=self.dim)  # each half has M = N / 2
        return a * self.activation(b)


class SwiGLU(CustomGLU):
    """SiLU-gated linear unit activation.

    Computes :math:`a * SiLU(b)`, where :math:`a` and :math:`b` are the first
    and second halves of the input along ``dim``.

    Args:
        dim (int): the dimension on which to split the input. Default: -1
    """

    def __init__(self, dim: int = -1):
        gate = nn.SiLU()
        super().__init__(gate, dim)


class GeGLU(CustomGLU):
    """GELU-gated linear unit activation.

    Computes :math:`a * GELU(b)`, where :math:`a` and :math:`b` are the first
    and second halves of the input along ``dim``.

    Args:
        dim (int): the dimension on which to split the input. Default: -1
    """

    def __init__(self, dim: int = -1):
        gate = nn.GELU()
        super().__init__(gate, dim)


class ReGLU(CustomGLU):
    """ReLU-gated linear unit activation.

    Computes :math:`a * ReLU(b)`, where :math:`a` and :math:`b` are the first
    and second halves of the input along ``dim``.

    Args:
        dim (int): the dimension on which to split the input. Default: -1
    """

    def __init__(self, dim: int = -1):
        gate = nn.ReLU()
        super().__init__(gate, dim)


def get_activation_fn(
    activation: Union[str, Callable[[Tensor], Tensor]]
) -> Union[str, Callable[[Tensor], Tensor]]:
    """Resolve a gated-activation name to a freshly constructed module.

    The exact strings ``"reglu"``, ``"geglu"`` and ``"swiglu"`` produce a new
    ``ReGLU``, ``GeGLU`` or ``SwiGLU`` instance respectively. Any other value,
    including unrecognized strings and callables, is returned unchanged.

    Args:
        activation (str, or Callable[[Tensor], Tensor]): Activation to check
    """
    # Non-string activations are assumed to be usable already; pass them back.
    if not isinstance(activation, str):
        return activation
    if activation == "reglu":
        return ReGLU()
    if activation == "geglu":
        return GeGLU()
    if activation == "swiglu":
        return SwiGLU()
    # Unknown names are left for the caller to interpret.
    return activation