# osu_mapper2/osudiffusion/positional_embedding.py
# Uploaded by Tiger14n via huggingface_hub (commit 7ef7abb).
import math
import torch
def encode_single(d_model, value, max_period=10000.0):
    """
    Encode a single scalar as an interleaved sin/cos vector.

    Even indices hold ``sin(value * w_k)`` and odd indices hold
    ``cos(value * w_k)``, where ``w_k = max_period ** (-2k / d_model)`` for
    ``k = 0 .. d_model/2 - 1``.

    :param d_model: dimension of the model; must be even.
    :param value: the scalar to encode (Python number or 0-d tensor).
    :param max_period: base of the geometric frequency progression; the
        frequencies fall from 1 toward ``1 / max_period``. It is not a bound
        on ``value``.
    :return: a 1-D float tensor of shape ``(d_model,)``. When ``value`` is a
        tensor, the result lives on the same device.
    :raises ValueError: if ``d_model`` is odd.
    """
    if d_model % 2 != 0:
        raise ValueError(
            "Cannot use sin/cos positional encoding with "
            "odd dim (got dim={:d})".format(d_model),
        )
    # Build everything on the value's device so a CUDA scalar does not trigger
    # a CPU/CUDA device mismatch; None falls back to the default device.
    device = value.device if torch.is_tensor(value) else None
    pe = torch.zeros(d_model, device=device)
    div_term = torch.exp(
        torch.arange(0, d_model, 2, dtype=torch.float, device=device)
        * -(math.log(max_period) / d_model),
    )
    pe[0::2] = torch.sin(value * div_term)
    pe[1::2] = torch.cos(value * div_term)
    return pe
def timestep_embedding(t, dim, max_period=10000):
    """
    Build sinusoidal embeddings for a batch of (possibly fractional) timesteps.

    Adapted from the GLIDE reference implementation:
    https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py

    :param t: a 1-D tensor holding one timestep value per batch element.
    :param dim: requested width of each embedding.
    :param max_period: sets the lowest frequency used by the embedding.
    :return: a float32 tensor of shape (N, dim). The first ``dim // 2``
        columns are cosines, the next ``dim // 2`` are sines, and an odd
        ``dim`` gets one trailing zero column.
    """
    n_freqs = dim // 2
    steps = torch.arange(start=0, end=n_freqs, dtype=torch.float32, device=t.device)
    scale = -math.log(max_period)
    frequencies = torch.exp(scale * steps / n_freqs)
    # Outer product: one row of phases per timestep, one column per frequency.
    phases = t.unsqueeze(1).float() * frequencies.unsqueeze(0)
    parts = [torch.cos(phases), torch.sin(phases)]
    if dim % 2 != 0:
        # Pad odd widths with a zero column so the output is `dim` wide.
        parts.append(torch.zeros_like(phases[:, :1]))
    return torch.cat(parts, dim=-1)
def offset_sequence_embedding(t, dim, max_period=10000):
    """
    Embed every time offset in a batch of sequences with sinusoidal features.

    The (N, T) input is flattened, passed through ``timestep_embedding``, and
    reshaped back, so each offset gets its own ``dim``-wide embedding.

    :param t: an (N, T) tensor of time offsets.
    :param dim: width of the embedding produced for each offset.
    :param max_period: forwarded to ``timestep_embedding``; controls the
        lowest frequency.
    :return: an (N, T, dim) tensor of embeddings.
    """
    batch, steps = t.shape
    per_offset = timestep_embedding(t.reshape(-1), dim, max_period)
    return per_offset.reshape(batch, steps, dim)
def position_sequence_embedding(t, dim, max_period=10000):
    """
    Embed each coordinate of a batch of D-dimensional position sequences.

    Every scalar coordinate gets its own ``dim``-wide embedding from
    ``timestep_embedding``. The D embeddings of one position are laid out
    back to back along the last axis: coordinate ``d`` occupies columns
    ``d * dim`` to ``(d + 1) * dim``.

    :param t: an (N, T, D) tensor of positions.
    :param dim: width of the embedding produced for each coordinate.
    :param max_period: forwarded to ``timestep_embedding``; controls the
        lowest frequency.
    :return: an (N, T, D * dim) tensor of embeddings.
    """
    batch, steps, coords = t.shape
    flat_embedding = timestep_embedding(t.reshape(-1), dim, max_period)
    return flat_embedding.reshape(batch, steps, coords * dim)
def positionalencoding(d_model, values, max_period=10000.0):
    """
    Encode a sequence of scalar values as interleaved sin/cos rows.

    Row ``i`` encodes ``values[i]``. Even columns hold ``sin(values[i] * w_k)``
    and odd columns hold ``cos(values[i] * w_k)``, where
    ``w_k = max_period ** (-2k / d_model)``.

    :param d_model: dimension of the model; must be even.
    :param values: a 1-D tensor of values to encode, or any sequence accepted
        by ``torch.as_tensor``. Values may be integer or fractional.
    :param max_period: base of the geometric frequency progression; the
        frequencies fall from 1 toward ``1 / max_period``. It is not a bound
        on ``values``.
    :return: a float32 tensor of shape ``(len(values), d_model)`` on the same
        device as ``values``.
    :raises ValueError: if ``d_model`` is odd.
    """
    if d_model % 2 != 0:
        raise ValueError(
            "Cannot use sin/cos positional encoding with "
            "odd dim (got dim={:d})".format(d_model),
        )
    # as_tensor is a no-op for tensors (device preserved) and lets plain
    # Python sequences through.
    values = torch.as_tensor(values)
    device = values.device
    pe = torch.zeros(len(values), d_model, device=device)
    # Column vector so it broadcasts against the row of frequencies.
    position = values.unsqueeze(1)
    div_term = torch.exp(
        torch.arange(0, d_model, 2, dtype=torch.float, device=device)
        * -(math.log(max_period) / d_model),
    )
    pe[:, 0::2] = torch.sin(position * div_term)
    pe[:, 1::2] = torch.cos(position * div_term)
    return pe
def positionalencoding1d(d_model, length, max_period=10000.0):
    """
    Standard 1-D sinusoidal positional encoding for positions 0 .. length - 1.

    Row ``p`` holds ``sin(p * w_k)`` in even columns and ``cos(p * w_k)`` in
    odd columns, where ``w_k = max_period ** (-2k / d_model)``.

    :param d_model: dimension of the model; must be even.
    :param length: number of positions to encode.
    :param max_period: base of the geometric frequency progression
        (defaults to the conventional 10000).
    :return: a float tensor of shape ``(length, d_model)``.
    :raises ValueError: if ``d_model`` is odd.
    """
    if d_model % 2 != 0:
        raise ValueError(
            "Cannot use sin/cos positional encoding with "
            "odd dim (got dim={:d})".format(d_model),
        )
    # Previously this ignored `length` and always returned two copies of the
    # encoding of -50 (arange(-50, 50, 100) == [-50]); now it matches the
    # documented (length, d_model) contract.
    pe = torch.zeros(length, d_model)
    # Column vector of positions so it broadcasts against the frequency row.
    position = torch.arange(0, length, dtype=torch.float).unsqueeze(1)
    div_term = torch.exp(
        torch.arange(0, d_model, 2, dtype=torch.float)
        * -(math.log(max_period) / d_model),
    )
    pe[:, 0::2] = torch.sin(position * div_term)
    pe[:, 1::2] = torch.cos(position * div_term)
    return pe
def positionalencoding2d(d_model, height, width, max_period=10000.0):
    """
    2-D sinusoidal positional encoding over a height x width grid.

    The first half of the channels encodes the column (width) index and the
    second half encodes the row (height) index. Within each half, sin and
    cos channels alternate.

    :param d_model: number of channels; must be divisible by 4.
    :param height: number of rows in the grid.
    :param width: number of columns in the grid.
    :param max_period: base of the geometric frequency progression
        (defaults to the conventional 10000).
    :return: a float tensor of shape ``(d_model, height, width)``.
    :raises ValueError: if ``d_model`` is not divisible by 4.
    """
    if d_model % 4 != 0:
        raise ValueError(
            "Cannot use sin/cos positional encoding with "
            "dimension not divisible by 4 (got dim={:d})".format(d_model),
        )
    pe = torch.zeros(d_model, height, width)
    # Each spatial axis gets half of the channels.
    half = d_model // 2
    div_term = torch.exp(torch.arange(0.0, half, 2) * -(math.log(max_period) / half))
    pos_w = torch.arange(0.0, width).unsqueeze(1)
    pos_h = torch.arange(0.0, height).unsqueeze(1)
    # (width, half/2) -> (half/2, 1, width), then repeated down every row.
    pe[0:half:2, :, :] = (
        torch.sin(pos_w * div_term).transpose(0, 1).unsqueeze(1).repeat(1, height, 1)
    )
    pe[1:half:2, :, :] = (
        torch.cos(pos_w * div_term).transpose(0, 1).unsqueeze(1).repeat(1, height, 1)
    )
    # (height, half/2) -> (half/2, height, 1), then repeated across every column.
    pe[half::2, :, :] = (
        torch.sin(pos_h * div_term).transpose(0, 1).unsqueeze(2).repeat(1, 1, width)
    )
    pe[half + 1 :: 2, :, :] = (
        torch.cos(pos_h * div_term).transpose(0, 1).unsqueeze(2).repeat(1, 1, width)
    )
    return pe
if __name__ == "__main__":
    # Quick visual check: plot the encodings of the values -50 and 50.
    import matplotlib.pyplot as plt

    encoding = positionalencoding(128, torch.tensor([-50, 50]))
    plt.imshow(encoding)
    plt.show()