Llama-3.1-8B / README.md
osanseviero's picture
Create README.md
e55f7f5 verified
|
raw
history blame
2.2 kB

This repository contains the base Llama 3.1 8B model. The weight format is unchanged, but the model applies RoPE with per-frequency scaling, so inference requires code changes.

Here is a short-term patch to `transformers` that makes the model generate properly:

diff --git a/src/transformers/models/llama/modeling_llama.py b/src/transformers/models/llama/modeling_llama.py
index 5c0c57f3e..f94a4cb37 100644
--- a/src/transformers/models/llama/modeling_llama.py
+++ b/src/transformers/models/llama/modeling_llama.py
@@ -73,6 +73,29 @@ class LlamaRMSNorm(nn.Module):
 
 ALL_LAYERNORM_LAYERS.append(LlamaRMSNorm)
 
+def apply_scaling(freqs: torch.Tensor):
+    # Values obtained from grid search
+    scale_factor = 8
+    low_freq_factor = 1
+    high_freq_factor = 4
+    old_context_len = 8192  # original llama3 length
+
+    low_freq_wavelen = old_context_len / low_freq_factor
+    high_freq_wavelen = old_context_len / high_freq_factor
+    new_freqs = []
+    for freq in freqs:
+        wavelen = 2 * math.pi / freq
+        if wavelen < high_freq_wavelen:
+            new_freqs.append(freq)
+        elif wavelen > low_freq_wavelen:
+            new_freqs.append(freq / scale_factor)
+        else:
+            assert low_freq_wavelen != high_freq_wavelen
+            smooth = (old_context_len / wavelen - low_freq_factor) / (
+                high_freq_factor - low_freq_factor
+            )
+            new_freqs.append((1 - smooth) * freq / scale_factor + smooth * freq)
+    return torch.tensor(new_freqs, dtype=freqs.dtype, device=freqs.device)
 
 class LlamaRotaryEmbedding(nn.Module):
     def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0):
@@ -82,6 +105,7 @@ class LlamaRotaryEmbedding(nn.Module):
         self.max_position_embeddings = max_position_embeddings
         self.base = base
         inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(device) / self.dim))
+        inv_freq = apply_scaling(inv_freq)
         self.register_buffer("inv_freq", inv_freq, persistent=False)
         # For BC we register cos and sin cached
         self.max_seq_len_cached = max_position_embeddings