|
|
|
|
|
|
|
import math |
|
|
|
import torch |
|
|
|
|
|
def centers(start: float, stop, num, dtype=None, device=None):
    """Return the midpoints of ``num`` equal-width bins covering [start, stop].

    The interval is split by ``num + 1`` evenly spaced edges (via
    ``torch.linspace``) and the average of each pair of neighbouring edges
    is returned.

    Args:
        start (float): Lower edge of the first bin.
        stop (float): Upper edge of the last bin.
        num (int): Number of bins, i.e. number of returned points.
        dtype (torch.dtype): Forwarded to ``torch.linspace``.
        device (torch.device): Forwarded to ``torch.linspace``.

    Returns:
        Tensor: Bin midpoints. Shape: (num,).
    """
    boundaries = torch.linspace(start, stop, num + 1, dtype=dtype, device=device)
    # Pair every edge with its right-hand neighbour.
    lower = boundaries[:-1]
    upper = boundaries[1:]
    return (lower + upper) / 2
|
|
|
|
|
|
|
def create_position_matrix(
    T: int,
    pH: int,
    pW: int,
    device: torch.device,
    dtype: torch.dtype,
    *,
    target_area: float = 36864,
) -> torch.Tensor:
    """Build the (t, h, w) coordinate of every token of a patchified video.

    The temporal coordinate is the frame index ``0 .. T - 1``. The spatial
    coordinates are bin centers of a grid that is rescaled so that its total
    extent ``(pH * scale) * (pW * scale)`` equals ``target_area`` regardless
    of the actual ``pH`` / ``pW``, and that is centered on zero.

    Args:
        T: int - Temporal dimension
        pH: int - Height dimension after patchify
        pW: int - Width dimension after patchify
        device: torch.device - Device of the returned tensor
        dtype: torch.dtype - Data type of the returned tensor
        target_area: float - Area the spatial grid is normalised to
            (default 36864 = 192 * 192)

    Returns:
        pos: [T * pH * pW, 3] - position matrix. Rows are ordered with t
            varying slowest and w fastest; columns are (t, h, w).

    Raises:
        ZeroDivisionError: If ``pH * pW == 0``.
    """
    # All three axes must share one dtype: torch.meshgrid rejects mixed
    # dtypes, so previously any `dtype` other than the default one crashed.
    # Work in at least float32 so low-precision targets (fp16 / bf16) get
    # coordinates that are rounded once, at the final cast, instead of being
    # computed in reduced precision.
    work_dtype = torch.promote_types(dtype, torch.float32)

    # Temporal axis: plain frame indices.
    t = torch.arange(T, dtype=work_dtype)

    # Spatial axes: scale so that the grid covers `target_area` in total.
    scale = math.sqrt(target_area / (pW * pH))
    half_w = pW * scale / 2
    half_h = pH * scale / 2
    w = centers(-half_w, half_w, pW, dtype=work_dtype)
    h = centers(-half_h, half_h, pH, dtype=work_dtype)

    # "ij" indexing keeps the (T, pH, pW) axis order of the inputs.
    grid_t, grid_h, grid_w = torch.meshgrid(t, h, w, indexing="ij")

    pos = torch.stack([grid_t, grid_h, grid_w], dim=-1)  # [T, pH, pW, 3]
    pos = pos.view(-1, 3)  # [T * pH * pW, 3]

    # The grid is built on the default device and moved/cast only once here.
    return pos.to(dtype=dtype, device=device)
|
|
|
|
|
def compute_mixed_rotation(
    freqs: torch.Tensor,
    pos: torch.Tensor,
):
    """
    Turn token positions into per-head rotation angles and return their
    cosine and sine.

    Each position vector is contracted with ``freqs`` over the coordinate
    axis, giving one angle per (token, head, frequency).

    Args:
        freqs: [3, num_heads, num_freqs] - learned rotation frequency (for t, row, col) for each head position
        pos: [N, 3] - position of each token

    Returns:
        freqs_cos: [N, num_heads, num_freqs] - cosine components
        freqs_sin: [N, num_heads, num_freqs] - sine components
    """
    assert freqs.dim() == 3

    # Bring the positions to the dtype/device of `freqs` before contracting.
    positions = pos.to(freqs)
    angles = torch.einsum("Nd,dhf->Nhf", positions, freqs)

    cos_part = torch.cos(angles)
    sin_part = torch.sin(angles)
    return cos_part, sin_part
|
|