import math

import torch
def encode_single(d_model, value, max_period=10000.0):
    """
    Create a sinusoidal encoding of a single scalar value.

    :param d_model: dimension of the model
    :param value: the value to encode
    :param max_period: controls the minimum frequency of the encoding
    :return: a d_model-dimensional encoding vector
    """
    if d_model % 2 != 0:
        raise ValueError(
            "Cannot use sin/cos positional encoding with "
            "odd dim (got dim={:d})".format(d_model),
        )
    pe = torch.zeros(d_model)
    div_term = torch.exp(
        torch.arange(0, d_model, 2, dtype=torch.float)
        * -(math.log(max_period) / d_model),
    )
    pe[0::2] = torch.sin(value * div_term)
    pe[1::2] = torch.cos(value * div_term)

    return pe
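

# Usage sketch (illustrative only; the value 3.5 is an arbitrary example):
#
#   emb = encode_single(8, 3.5)   # tensor of shape (8,)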
def timestep_embedding(t, dim, max_period=10000):
    """
    Create sinusoidal timestep embeddings.

    :param t: a 1-D Tensor of N indices, one per batch element.
              These may be fractional.
    :param dim: the dimension of the output.
    :param max_period: controls the minimum frequency of the embeddings.
    :return: an (N, dim) Tensor of positional embeddings.
    """
    half = dim // 2
    freqs = torch.exp(
        -math.log(max_period)
        * torch.arange(start=0, end=half, dtype=torch.float32, device=t.device)
        / half,
    )
    args = t[:, None].float() * freqs[None]
    embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
    if dim % 2:
        embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
    return embedding
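

# Usage sketch (illustrative): embed a batch of 4 possibly fractional timesteps
# into 128-dimensional vectors.
#
#   t = torch.tensor([0.0, 10.0, 100.5, 999.0])
#   emb = timestep_embedding(t, 128)   # tensor of shape (4, 128)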
def offset_sequence_embedding(t, dim, max_period=10000):
    """
    Create sinusoidal embeddings for sequences of time offsets.

    :param t: an (N, T) Tensor of sequences of time offsets.
    :param dim: the dimension of the output.
    :param max_period: controls the minimum frequency of the embeddings.
    :return: an (N, T, dim) Tensor of positional embeddings.
    """
    N, T = t.shape
    flattened = torch.flatten(t)
    embedding = timestep_embedding(flattened, dim, max_period)
    return torch.reshape(embedding, (N, T, dim))
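

# Usage sketch (illustrative): embed a batch of 2 sequences of 16 time offsets each.
#
#   offsets = torch.rand(2, 16) * 100
#   emb = offset_sequence_embedding(offsets, 128)   # tensor of shape (2, 16, 128)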
def position_sequence_embedding(t, dim, max_period=10000):
    """
    Create sinusoidal embeddings for sequences of D-dimensional positions,
    encoding each coordinate independently.

    :param t: an (N, T, D) Tensor of sequences of D-dimensional positions.
    :param dim: the dimension of the output per coordinate.
    :param max_period: controls the minimum frequency of the embeddings.
    :return: an (N, T, D * dim) Tensor of positional embeddings.
    """
    N, T, D = t.shape
    flattened = torch.flatten(t)
    embedding = timestep_embedding(flattened, dim, max_period)
    return torch.reshape(embedding, (N, T, D * dim))
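

# Usage sketch (illustrative): embed a batch of 2 sequences of 16 two-dimensional
# positions; each coordinate is embedded separately and the results are concatenated.
#
#   positions = torch.rand(2, 16, 2) * 512
#   emb = position_sequence_embedding(positions, 128)   # tensor of shape (2, 16, 256)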
def positionalencoding(d_model, values, max_period=10000.0):
    """
    Create sinusoidal encodings for a 1-D Tensor of values.

    :param d_model: dimension of the model
    :param values: the values to encode
    :param max_period: controls the minimum frequency of the encoding
    :return: a len(values) x d_model position matrix
    """
    if d_model % 2 != 0:
        raise ValueError(
            "Cannot use sin/cos positional encoding with "
            "odd dim (got dim={:d})".format(d_model),
        )
    pe = torch.zeros(len(values), d_model)
    position = values.unsqueeze(1)
    div_term = torch.exp(
        torch.arange(0, d_model, 2, dtype=torch.float)
        * -(math.log(max_period) / d_model),
    )
    pe[:, 0::2] = torch.sin(position * div_term)
    pe[:, 1::2] = torch.cos(position * div_term)

    return pe
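

# Usage sketch (illustrative): encode several scalar values at once; see also the
# demo under __main__ below.
#
#   emb = positionalencoding(128, torch.tensor([-50.0, 0.0, 50.0]))   # shape (3, 128)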
def positionalencoding1d(d_model, length):
    """
    :param d_model: dimension of the model
    :param length: length of positions
    :return: length*d_model position matrix
    """
    if d_model % 2 != 0:
        raise ValueError(
            "Cannot use sin/cos positional encoding with "
            "odd dim (got dim={:d})".format(d_model),
        )
    pe = torch.zeros(length, d_model)
    position = torch.arange(0, length).unsqueeze(1)
    div_term = torch.exp(
        torch.arange(0, d_model, 2, dtype=torch.float) * -(math.log(10000.0) / d_model),
    )
    pe[:, 0::2] = torch.sin(position.float() * div_term)
    pe[:, 1::2] = torch.cos(position.float() * div_term)

    return pe
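

# Usage sketch (illustrative): encode integer positions 0..19.
#
#   emb = positionalencoding1d(128, 20)   # tensor of shape (20, 128)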
def positionalencoding2d(d_model, height, width):
    """
    :param d_model: dimension of the model
    :param height: height of the positions
    :param width: width of the positions
    :return: d_model*height*width position matrix
    """
    if d_model % 4 != 0:
        raise ValueError(
            "Cannot use sin/cos positional encoding with "
            "a dimension that is not a multiple of 4 (got dim={:d})".format(d_model),
        )
    pe = torch.zeros(d_model, height, width)

    # The first half of the channels encodes the width, the second half the height.
    d_model = int(d_model / 2)
    div_term = torch.exp(torch.arange(0.0, d_model, 2) * -(math.log(10000.0) / d_model))
    pos_w = torch.arange(0.0, width).unsqueeze(1)
    pos_h = torch.arange(0.0, height).unsqueeze(1)
    pe[0:d_model:2, :, :] = (
        torch.sin(pos_w * div_term).transpose(0, 1).unsqueeze(1).repeat(1, height, 1)
    )
    pe[1:d_model:2, :, :] = (
        torch.cos(pos_w * div_term).transpose(0, 1).unsqueeze(1).repeat(1, height, 1)
    )
    pe[d_model::2, :, :] = (
        torch.sin(pos_h * div_term).transpose(0, 1).unsqueeze(2).repeat(1, 1, width)
    )
    pe[d_model + 1 :: 2, :, :] = (
        torch.cos(pos_h * div_term).transpose(0, 1).unsqueeze(2).repeat(1, 1, width)
    )

    return pe
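

# Usage sketch (illustrative): a 2-D encoding for a 32x48 grid; d_model must be a
# multiple of 4.
#
#   emb = positionalencoding2d(64, 32, 48)   # tensor of shape (64, 32, 48)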
if __name__ == "__main__":
    import matplotlib.pyplot as plt

    pe = positionalencoding(128, torch.tensor([-50, 50]))
    plt.imshow(pe)
    plt.show()