kaiokendev committed on
Commit 0cad621
Parent(s): 67bf26a

Fix bug, t needs to be scaled if input > 8192

Files changed (1):
  llama_rope_scaled_monkey_patch.py  +1 -0
llama_rope_scaled_monkey_patch.py CHANGED
@@ -42,6 +42,7 @@ class ScaledRotaryEmbedding(torch.nn.Module):
         t = torch.arange(
             self.max_seq_len_cached, device=x.device, dtype=self.inv_freq.dtype
         )
+        t *= self.scale
         freqs = torch.einsum("i,j->ij", t, self.inv_freq)
         # Different from paper, but it uses a different permutation in order to obtain the same calculation
         emb = torch.cat((freqs, freqs), dim=-1).to(x.device)
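For context, below is a minimal, self-contained sketch of the cos/sin cache computation this one-line change lands in, with the interpolated positions applied. The class name, constructor signature, and the example scale of 0.25 (a 2048-token trained context stretched to an 8192-token input) are assumptions for illustration; only the patched hunk above comes from the commit.

# Minimal sketch (not the actual patch file): how the scaled positions feed the
# rotary cos/sin cache. Class name, constructor arguments, and the 0.25 scale
# (2048 trained length / 8192 target length) are illustrative assumptions.
import torch


class ScaledRotaryEmbeddingSketch(torch.nn.Module):
    def __init__(self, dim, max_position_embeddings=8192, base=10000, scale=0.25, device=None):
        super().__init__()
        # Standard RoPE inverse frequencies.
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, device=device).float() / dim))
        self.register_buffer("inv_freq", inv_freq)
        self.scale = scale
        self.max_seq_len_cached = max_position_embeddings

        # Position indices 0 .. max_seq_len_cached - 1 ...
        t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)
        # ... compressed by the scale factor, so positions beyond the trained
        # context are interpolated back into the range the model saw in training.
        t *= self.scale
        freqs = torch.einsum("i,j->ij", t, self.inv_freq)
        emb = torch.cat((freqs, freqs), dim=-1)
        self.register_buffer("cos_cached", emb.cos()[None, None, :, :])
        self.register_buffer("sin_cached", emb.sin()[None, None, :, :])

    def forward(self, x, seq_len=None):
        # Return the cached cos/sin tables truncated to the requested length.
        return (
            self.cos_cached[:, :, :seq_len, ...].to(x.dtype),
            self.sin_cached[:, :, :seq_len, ...].to(x.dtype),
        )

Example use of the sketch: cos, sin = ScaledRotaryEmbeddingSketch(dim=128)(torch.empty(1), seq_len=8192).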