Crystalcareai
committed on
Commit 6f8d262 • 1 Parent(s): 40fb520
Update modeling_quiet.py
modeling_quiet.py  +3 -2
CHANGED
@@ -109,12 +109,13 @@ class QuietRotaryEmbedding(nn.Module):
     def _set_cos_sin_cache(self, seq_len, device, dtype):
         self.max_seq_len_cached = seq_len
         t = torch.arange(self.max_seq_len_cached, device=device, dtype=torch.int64).type_as(self.inv_freq)
+
         freqs = torch.outer(t, self.inv_freq)
-
+        # Different from paper, but it uses a different permutation in order to obtain the same calculation
+        emb = torch.cat((freqs, freqs), dim=-1)
         self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
         self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)

-
     def forward(self, x, seq_len=None):
         # x: [bs, num_attention_heads, seq_len, head_size]
         if seq_len > self.max_seq_len_cached:
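
For context: the two `register_buffer` calls depend on `emb`, which this commit (re)introduces via `torch.cat((freqs, freqs), dim=-1)` together with the accompanying comment. Below is a minimal sketch of how the class would read after the change, assuming it otherwise mirrors the standard Mistral-style rotary embedding; the `__init__` body and the return statement of `forward` are assumptions for illustration, not part of this diff.

# Sketch (assumption): QuietRotaryEmbedding after this commit, modeled on the
# standard Mistral-style rotary embedding cache.
import torch
import torch.nn as nn

class QuietRotaryEmbedding(nn.Module):
    def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
        super().__init__()
        self.dim = dim
        self.max_position_embeddings = max_position_embeddings
        self.base = base
        # One inverse frequency per pair of rotary dimensions.
        inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(device) / self.dim))
        self.register_buffer("inv_freq", inv_freq, persistent=False)
        self._set_cos_sin_cache(seq_len=max_position_embeddings, device=device, dtype=torch.get_default_dtype())

    def _set_cos_sin_cache(self, seq_len, device, dtype):
        self.max_seq_len_cached = seq_len
        t = torch.arange(self.max_seq_len_cached, device=device, dtype=torch.int64).type_as(self.inv_freq)

        freqs = torch.outer(t, self.inv_freq)
        # Different from paper, but it uses a different permutation in order to obtain the same calculation
        emb = torch.cat((freqs, freqs), dim=-1)  # line restored by this commit
        self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
        self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)

    def forward(self, x, seq_len=None):
        # x: [bs, num_attention_heads, seq_len, head_size]
        if seq_len > self.max_seq_len_cached:
            # Grow the cache on demand for longer sequences.
            self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)
        # Assumed return: slices of the cached tables, cast to the input dtype.
        return (
            self.cos_cached[:seq_len].to(dtype=x.dtype),
            self.sin_cached[:seq_len].to(dtype=x.dtype),
        )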