FupBERT / positional_encoding.py
c-dunlap's picture
Upload FupBERT
aae4e29
"""
© Battelle Memorial Institute 2023
Made available under the GNU General Public License v 2.0
BECAUSE THE PROGRAM IS LICENSED FREE OF CHARGE, THERE IS NO WARRANTY
FOR THE PROGRAM, TO THE EXTENT PERMITTED BY APPLICABLE LAW. EXCEPT WHEN
OTHERWISE STATED IN WRITING THE COPYRIGHT HOLDERS AND/OR OTHER PARTIES
PROVIDE THE PROGRAM "AS IS" WITHOUT WARRANTY OF ANY KIND, EITHER EXPRESSED
OR IMPLIED, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF
MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE. THE ENTIRE RISK AS
TO THE QUALITY AND PERFORMANCE OF THE PROGRAM IS WITH YOU. SHOULD THE
PROGRAM PROVE DEFECTIVE, YOU ASSUME THE COST OF ALL NECESSARY SERVICING,
REPAIR OR CORRECTION.
"""
import numpy as np
import torch
import torch.nn as nn
class PositionalEncoding(nn.Module):
    """
    A class that extends torch.nn.Module that applies sinusoidal positional
    encoding for use in the Transformer architecture.

    The encoding table is precomputed once in ``__init__`` and stored as a
    non-trainable buffer named ``pe`` with shape ``(1, max_len, d_model)``.
    Even feature columns hold ``sin(pos * w_i)`` and odd feature columns hold
    ``cos(pos * w_i)``, where ``w_i = 10000 ** (-2i / d_model)``.
    """

    def __init__(self, d_model, dropout=0.1, max_len=5000):
        """
        Initializes a PositionalEncoding object.

        Parameters
        ----------
        d_model : int
            The size of the model's embedding dimension. Both even and odd
            values are supported.
        dropout : float, optional
            The fractional dropout to apply to the embedding. The default is 0.1.
        max_len : int, optional
            The maximum potential input sequence length. The default is 5000.

        Returns
        -------
        None.
        """
        super().__init__()
        # Create the dropout
        self.dropout = nn.Dropout(p=dropout)
        # Create the encoding table: one row per position, one column per
        # embedding feature.
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        # Frequencies 10000^(-2i/d_model), computed in log space. There are
        # ceil(d_model / 2) of them, one per even column.
        div_term = torch.exp(
            torch.arange(0, d_model, 2).float() * (-np.log(10000.0) / d_model)
        )
        pe[:, 0::2] = torch.sin(position * div_term)
        # There are only floor(d_model / 2) odd columns; trim the frequencies
        # so an odd d_model does not cause a shape mismatch.
        pe[:, 1::2] = torch.cos(position * div_term[: d_model // 2])
        # Add a leading broadcast dimension -> (1, max_len, d_model).
        pe = pe.unsqueeze(0)
        # Registered as a buffer (not a trainable parameter) so it is part of
        # the module's state and follows the module across devices.
        self.register_buffer("pe", pe)

    def forward(self, x):
        """
        Perform a forward pass of the module.

        Parameters
        ----------
        x : tensor
            The input tensor to apply the positional encoding to. Because the
            encoding is sliced along dimension 1, ``x`` is expected to be
            batch-first with shape ``(batch, seq_len, d_model)`` and
            ``seq_len <= max_len``.

        Returns
        -------
        tensor
            The resulting tensor after adding the positional encoding to the
            input and applying dropout.
        """
        # Take the first seq_len rows of the table; the leading size-1
        # dimension broadcasts across the batch.
        x = x + self.pe[:, : x.size(1)]
        return self.dropout(x)