Spaces:
Sleeping
Sleeping
#include "common.comp"

// TODO: use a local size of 32 or more (Metal uses 1024)
// For now every workgroup consists of a single invocation along x.
layout(local_size_x = 1) in;
// Push-constant block, accessed by shader code through the instance name `pcs`.
// The member order and types define the push-constant layout, so they must stay
// in sync with whatever the host side pushes; do not reorder or retype members.
// NOTE(review): the per-field notes below are inferred from the member names and
// from the parameter names of the YaRN helpers in this file; confirm them against
// the host code and the kernels that include this file.
layout (push_constant) uniform parameter {
    uint inAOff;        // presumably offset of the first input within its buffer
    uint inBOff;        // presumably offset of the second input within its buffer
    uint outOff;        // presumably offset of the output within its buffer
    int n_dims;         // same name as the n_dims argument of rope_yarn_corr_dims
    int mode;           // meaning not visible here; TODO confirm
    int n_ctx_orig;     // same name as the n_ctx_orig argument of rope_yarn_corr_dims
    float freq_base;    // same name as the freq_base argument of rope_yarn_corr_dims
    float freq_scale;   // same name as the freq_scale argument of rope_yarn
    float ext_factor;   // same name as the ext_factor argument of rope_yarn
    float attn_factor;  // presumably the magnitude scale fed to rope_yarn; TODO confirm
    float beta_fast;    // same name as the beta_fast argument of rope_yarn_corr_dims
    float beta_slow;    // same name as the beta_slow argument of rope_yarn_corr_dims
    uint nb00;          // nb00..nb03: presumably source strides per dimension; TODO confirm
    uint nb01;
    uint nb02;
    uint nb03;
    int ne0;            // presumably the extent of dimension 0; TODO confirm
    uint nb0;           // nb0..nb3: presumably destination strides per dimension; TODO confirm
    uint nb1;
    uint nb2;
    uint nb3;
} pcs;
// Linear ramp that is 1 at or below `low`, 0 at or above `high`, and falls
// linearly in between, evaluated at position i0 / 2.
//   low, high : start and end of the ramp
//   i0        : index whose half (i0 / 2) is compared against low/high
// Returns a blend weight in [0, 1].
float rope_yarn_ramp(const float low, const float high, const float i0) {
    // Keep the denominator at least 0.001 so that high <= low cannot divide by zero.
    const float span = max(0.001f, high - low);
    const float y = (i0 / 2 - low) / span;
    return 1.0f - clamp(y, 0.0f, 1.0f);
}
// YaRN algorithm based on LlamaYaRNScaledRotaryEmbedding.py from https://github.com/jquesnelle/yarn
// MIT licensed. Copyright (c) 2023 Jeffrey Quesnelle and Bowen Peng.
//
// Computes the scaled cosine and sine of the rotation angle.
//   theta_extrap : unscaled (extrapolated) angle
//   freq_scale   : factor applied to theta_extrap to obtain the interpolated angle
//   corr_dims    : {low, high} bounds handed to rope_yarn_ramp
//   i0           : index handed to rope_yarn_ramp
//   ext_factor   : weight of the extrapolated angle; 0 disables the blending
//                  and the magnitude correction
//   mscale       : magnitude multiplier applied to both outputs
//   cos_theta, sin_theta : outputs, cos/sin of the final angle times the final mscale
void rope_yarn(
    float theta_extrap, float freq_scale, float corr_dims[2], float i0, float ext_factor, float mscale,
    out float cos_theta, out float sin_theta
) {
    // Get n-d rotational scaling corrected for extrapolation
    float theta_interp = freq_scale * theta_extrap;
    float theta = theta_interp;
    if (ext_factor != 0.0f) {
        // Blend weight: the ramp value scaled by ext_factor.
        float ramp_mix = rope_yarn_ramp(corr_dims[0], corr_dims[1], i0) * ext_factor;
        theta = theta_interp * (1 - ramp_mix) + theta_extrap * ramp_mix;

        // Get n-d magnitude scaling corrected for interpolation
        // (mscale is a by-value parameter, so this only changes the local copy)
        mscale *= 1.0f + 0.1f * log(1.0f / freq_scale);
    }
    cos_theta = cos(theta) * mscale;
    sin_theta = sin(theta) * mscale;
}
// Solving `n_rot = max_pos_emb / (2pi * base^((2 * x) / n_dims))` for x, we get
// `corr_fac(n_rot) = n_dims * log(max_pos_emb / (n_rot * 2pi)) / (2 * log(base))`
// Evaluates n_dims * log(n_ctx_orig / (n_rot * 2pi)) / (2 * log(base)).
//   n_dims     : dimension count used as the numerator scale
//   n_ctx_orig : context length that n_rot is measured against
//   n_rot      : rotation count to solve for
//   base       : base of the logarithm in the denominator
// The result is a real-valued (unrounded, unclamped) dimension index.
float rope_yarn_corr_factor(int n_dims, int n_ctx_orig, float n_rot, float base) {
    // Integer arguments are converted to float explicitly before the division.
    const float ratio = float(n_ctx_orig) / (n_rot * TWOPI_F);
    return float(n_dims) * log(ratio) / (2.0f * log(base));
}
// Computes the start and end correction dimensions.
//   dims[0] : floor of the correction factor for beta_fast, clamped to be >= 0
//   dims[1] : ceil of the correction factor for beta_slow, clamped to be <= n_dims - 1
// n_dims, n_ctx_orig and freq_base are forwarded unchanged to rope_yarn_corr_factor.
void rope_yarn_corr_dims(
    int n_dims, int n_ctx_orig, float freq_base, float beta_fast, float beta_slow, out float dims[2]
) {
    // start and end correction dims
    const float start = floor(rope_yarn_corr_factor(n_dims, n_ctx_orig, beta_fast, freq_base));
    const float end   = ceil(rope_yarn_corr_factor(n_dims, n_ctx_orig, beta_slow, freq_base));
    dims[0] = max(0.0f, start);
    dims[1] = min(float(n_dims) - 1.0f, end);
}