Crystalcareai committed
Commit 6f8d262 · 1 Parent(s): 40fb520

Update modeling_quiet.py

Files changed (1):
  modeling_quiet.py +3 -2
modeling_quiet.py CHANGED
```diff
@@ -109,12 +109,13 @@ class QuietRotaryEmbedding(nn.Module):
     def _set_cos_sin_cache(self, seq_len, device, dtype):
         self.max_seq_len_cached = seq_len
         t = torch.arange(self.max_seq_len_cached, device=device, dtype=torch.int64).type_as(self.inv_freq)
+
         freqs = torch.outer(t, self.inv_freq)
-        emb = torch.cat((freqs, freqs), dim=-1).view(1, 1, seq_len, self.dim * 2)
+        # Different from paper, but it uses a different permutation in order to obtain the same calculation
+        emb = torch.cat((freqs, freqs), dim=-1)
         self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
         self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
 
-
     def forward(self, x, seq_len=None):
         # x: [bs, num_attention_heads, seq_len, head_size]
         if seq_len > self.max_seq_len_cached:
```
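For context on why the reshape was dropped (an inference from the shapes, not stated in the commit message): assuming `inv_freq` holds `dim // 2` inverse frequencies, as in the upstream rotary-embedding code this file follows, `emb = torch.cat((freqs, freqs), dim=-1)` already has shape `[seq_len, dim]`, so the removed `.view(1, 1, seq_len, self.dim * 2)` would request twice as many elements as `emb` contains. A minimal shape sketch under that assumption:

```python
import torch

# Shape sketch, assuming inv_freq holds dim // 2 inverse frequencies
# (the convention in the upstream rotary-embedding code).
dim, seq_len = 8, 4
inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))  # [dim // 2]
t = torch.arange(seq_len).type_as(inv_freq)                          # [seq_len]
freqs = torch.outer(t, inv_freq)                                     # [seq_len, dim // 2]
emb = torch.cat((freqs, freqs), dim=-1)                              # [seq_len, dim]

print(emb.cos().shape)  # torch.Size([4, 8]) -- the shape cached as cos_cached
# The removed reshape needs seq_len * dim * 2 elements, double what emb holds:
# emb.view(1, 1, seq_len, dim * 2)  -> RuntimeError: shape invalid for input size
```

The restored 2-D cache also matches what the reused upstream comment describes: duplicating `freqs` along the last axis is a permutation of the paper's interleaved layout that yields the same rotation once a rotate-half is applied to the query/key halves.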