Update modeling_quiet.py
modeling_quiet.py (+4 -4)
@@ -167,18 +167,18 @@ class QuietRMSNorm(nn.Module):
 
 # Copied from transformers.models.llama.modeling_llama.LlamaRotaryEmbedding with Llama->Quiet
 class QuietRotaryEmbedding(nn.Module):
-    def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
+    def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, max_thought_tokens=2):
         super().__init__()
 
         self.dim = dim
-        self.max_position_embeddings = max_position_embeddings
+        self.max_position_embeddings = max_position_embeddings + max_thought_tokens
         self.base = base
         inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim))
         self.register_buffer("inv_freq", inv_freq, persistent=False)
 
         # Build here to make `torch.jit.trace` work.
         self._set_cos_sin_cache(
-            seq_len=max_position_embeddings, device=self.inv_freq.device, dtype=torch.get_default_dtype()
+            seq_len=max_position_embeddings + max_thought_tokens, device=self.inv_freq.device, dtype=torch.get_default_dtype()
         )
 
     def _set_cos_sin_cache(self, seq_len, device, dtype):
@@ -186,7 +186,6 @@ class QuietRotaryEmbedding(nn.Module):
         t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)
 
         freqs = torch.outer(t, self.inv_freq)
-        # Different from paper, but it uses a different permutation in order to obtain the same calculation
         emb = torch.cat((freqs, freqs), dim=-1)
         self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
         self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
@@ -313,6 +312,7 @@ class QuietAttention(nn.Module):
             self.head_dim,
             max_position_embeddings=self.max_position_embeddings,
             base=self.rope_theta,
+            max_thought_tokens=2,
         )
 
     def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
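For reference, here is a minimal, self-contained sketch of the patched class and its effect: the cos/sin cache is now sized to max_position_embeddings + max_thought_tokens, so rotary positions consumed by appended thought tokens still fall inside the precomputed cache. The _set_cos_sin_cache body below is reconstructed from the unchanged context lines and the upstream Llama implementation this file is copied from (the self.max_seq_len_cached = seq_len line is an assumption, as it is not shown in this diff).

import torch
import torch.nn as nn


class QuietRotaryEmbedding(nn.Module):
    def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, max_thought_tokens=2):
        super().__init__()
        self.dim = dim
        # Reserve extra positions so thought tokens appended during generation
        # stay within the precomputed cos/sin cache.
        self.max_position_embeddings = max_position_embeddings + max_thought_tokens
        self.base = base
        inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim))
        self.register_buffer("inv_freq", inv_freq, persistent=False)

        # Build here to make `torch.jit.trace` work.
        self._set_cos_sin_cache(
            seq_len=max_position_embeddings + max_thought_tokens,
            device=self.inv_freq.device,
            dtype=torch.get_default_dtype(),
        )

    def _set_cos_sin_cache(self, seq_len, device, dtype):
        self.max_seq_len_cached = seq_len  # assumed, following the standard Llama implementation
        t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)
        freqs = torch.outer(t, self.inv_freq)
        emb = torch.cat((freqs, freqs), dim=-1)
        self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
        self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)


# Quick check: with the default max_thought_tokens=2, the cache covers 2050 positions.
rope = QuietRotaryEmbedding(dim=128, max_position_embeddings=2048)
print(rope.cos_cached.shape)  # torch.Size([2050, 128])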