import torch


def apply_rotary_emb_qk_real(
    xqk: torch.Tensor,
    freqs_cos: torch.Tensor,
    freqs_sin: torch.Tensor,
) -> torch.Tensor:
    """Apply rotary position embeddings using real-valued arithmetic only.

    The last dimension of ``xqk`` is treated as interleaved pairs
    ``(x[2i], x[2i + 1])``. Each pair is rotated by the angle whose cosine and
    sine are given by ``freqs_cos`` / ``freqs_sin``::

        out[2i]     = x[2i] * cos - x[2i + 1] * sin
        out[2i + 1] = x[2i] * sin + x[2i + 1] * cos

    This is the same result as multiplying ``x[2i] + 1j * x[2i + 1]`` by
    ``cos + 1j * sin``, without materialising complex tensors.

    Args:
        xqk: Query and/or key tensor, documented by the original author as
            shape ``(B, S, *, num_heads, D)``. It may hold only queries, only
            keys, or both stacked along a batch-like dimension. ``D`` (the
            last dimension) must be even.
        freqs_cos: Precomputed cosine table. Must be broadcastable against
            ``xqk[..., 0::2]`` (i.e. last dimension ``D // 2`` or 1); the
            caller is responsible for any reshaping needed for broadcasting.
        freqs_sin: Precomputed sine table, with the same broadcasting
            requirement as ``freqs_cos``.

    Returns:
        A tensor with the rotation applied, cast back to the dtype of ``xqk``.
        It has the same shape as ``xqk`` as long as the frequency tables do
        not broadcast to a larger shape.

    Raises:
        ValueError: If the last dimension of ``xqk`` is odd, since the
            elements could not be grouped into (even, odd) pairs.
    """
    head_dim = xqk.shape[-1]
    if head_dim % 2 != 0:
        raise ValueError(
            "apply_rotary_emb_qk_real expects an even last dimension, "
            f"got {head_dim}"
        )

    # Split the interleaved layout into the two components of each pair.
    xqk_even = xqk[..., 0::2]
    xqk_odd = xqk[..., 1::2]

    # The products may be promoted to a wider dtype when the frequency tables
    # are higher precision than xqk; cast each rotated half back afterwards.
    rotated_even = (xqk_even * freqs_cos - xqk_odd * freqs_sin).type_as(xqk)
    rotated_odd = (xqk_even * freqs_sin + xqk_odd * freqs_cos).type_as(xqk)

    # Stack on a new trailing axis and flatten it away to re-interleave the
    # pairs: (..., D/2, 2) -> (..., D).
    return torch.stack([rotated_even, rotated_odd], dim=-1).flatten(-2)