import torch | |
class PositionalEmbedding(torch.nn.Module):
    """Sinusoidal embedding of a 1-D tensor of scalar positions.

    Each input value is multiplied by ``num_channels // 2`` frequencies that
    decrease geometrically from ``1`` towards ``1 / max_positions``. The cosines
    of these products are concatenated with their sines, in that order.

    Args:
        num_channels: Target embedding width. Only ``num_channels // 2``
            frequencies are used, so an odd value yields ``num_channels - 1``
            output channels.
        max_positions: Controls the slowest frequency
            (``(1 / max_positions) ** k`` for exponent ``k`` in ``[0, 1]``).
        endpoint: If True, the frequency exponents span ``[0, 1]`` inclusive,
            so the last frequency is exactly ``1 / max_positions``. If False,
            the upper end is excluded.
    """

    def __init__(self, num_channels, max_positions=10000, endpoint=False):
        super().__init__()
        self.num_channels = num_channels
        self.max_positions = max_positions
        self.endpoint = endpoint

    def forward(self, x):
        """Embed ``x``.

        Args:
            x: 1-D tensor of positions (``torch.outer`` requires 1-D input).

        Returns:
            Tensor of shape ``(len(x), 2 * (num_channels // 2))`` with dtype
            ``x.dtype``: ``[cos(x * f), sin(x * f)]`` along dim 1.
        """
        half = self.num_channels // 2
        freqs = torch.arange(start=0, end=half, dtype=torch.float32, device=x.device)
        # Clamp the divisor: with endpoint=True and a single frequency the
        # original expression divided by zero and produced NaN embeddings.
        denom = max(half - (1 if self.endpoint else 0), 1)
        freqs = freqs / denom
        freqs = (1 / self.max_positions) ** freqs
        # torch.outer replaces the deprecated Tensor.ger (same outer product).
        x = torch.outer(x, freqs.to(x.dtype))
        x = torch.cat([x.cos(), x.sin()], dim=1)
        return x