|
This repository corresponds to the base Llama 3.1 8B model. The model uses the same weight format as Llama 3 but applies RoPE with per-frequency scaling, so inference requires code changes. |
|
|
|
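Here is a short-term patch to `transformers` that makes the model generate properly. Note that it applies the scaling unconditionally in `LlamaRotaryEmbedding`, so it affects every Llama model loaded with the patched library: |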
|
|
|
``` |
|
diff --git a/src/transformers/models/llama/modeling_llama.py b/src/transformers/models/llama/modeling_llama.py |
|
index 5c0c57f3e..f94a4cb37 100644 |
|
--- a/src/transformers/models/llama/modeling_llama.py |
|
+++ b/src/transformers/models/llama/modeling_llama.py |
|
@@ -73,6 +73,29 @@ class LlamaRMSNorm(nn.Module): |
|
|
|
ALL_LAYERNORM_LAYERS.append(LlamaRMSNorm) |
|
|
|
+def apply_scaling(freqs: torch.Tensor) -> torch.Tensor:  # per-frequency rescaling of RoPE inverse frequencies |
|
+    # Values obtained from grid search |
|
+    scale_factor = 8  # divisor applied to low-frequency (long-wavelength) components |
|
+    low_freq_factor = 1  # sets the long-wavelength cutoff: old_context_len / low_freq_factor |
|
+    high_freq_factor = 4  # sets the short-wavelength cutoff: old_context_len / high_freq_factor |
|
+    old_context_len = 8192 # original llama3 length |
|
+ |
|
+    low_freq_wavelen = old_context_len / low_freq_factor  # wavelengths above this are fully scaled |
|
+    high_freq_wavelen = old_context_len / high_freq_factor  # wavelengths below this are left unchanged |
|
+    new_freqs = [] |
|
+    for freq in freqs:  # element-wise Python loop over the input frequencies |
|
+        wavelen = 2 * math.pi / freq  # period corresponding to this frequency |
|
+        if wavelen < high_freq_wavelen:  # high-frequency band: keep as is |
|
+            new_freqs.append(freq) |
|
+        elif wavelen > low_freq_wavelen:  # low-frequency band: divide by scale_factor |
|
+            new_freqs.append(freq / scale_factor) |
|
+        else:  # transition band: interpolate between scaled and unscaled freq |
|
+            assert low_freq_wavelen != high_freq_wavelen  # i.e. low_freq_factor != high_freq_factor; guards the division below |
|
+            smooth = (old_context_len / wavelen - low_freq_factor) / (  # 0 at low_freq_wavelen, 1 at high_freq_wavelen |
|
+                high_freq_factor - low_freq_factor |
|
+            ) |
|
+            new_freqs.append((1 - smooth) * freq / scale_factor + smooth * freq)  # linear blend weighted by smooth |
|
+    return torch.tensor(new_freqs, dtype=freqs.dtype, device=freqs.device)  # same dtype/device as the input |
|
|
|
class LlamaRotaryEmbedding(nn.Module): |
|
def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0): |
|
@@ -82,6 +105,7 @@ class LlamaRotaryEmbedding(nn.Module): |
|
self.max_position_embeddings = max_position_embeddings |
|
self.base = base |
|
inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(device) / self.dim)) |
|
+ inv_freq = apply_scaling(inv_freq) |
|
self.register_buffer("inv_freq", inv_freq, persistent=False) |
|
# For BC we register cos and sin cached |
|
self.max_seq_len_cached = max_position_embeddings |
|
``` |