Crystalcareai committed
Commit e7aeafc · verified
1 Parent(s): 03cf46d

Update modeling_quiet.py

Files changed (1): modeling_quiet.py (+4 -4)
modeling_quiet.py CHANGED

@@ -167,18 +167,18 @@ class QuietRMSNorm(nn.Module):
 
 # Copied from transformers.models.llama.modeling_llama.LlamaRotaryEmbedding with Llama->Quiet
 class QuietRotaryEmbedding(nn.Module):
-    def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
+    def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, max_thought_tokens=2):
         super().__init__()
 
         self.dim = dim
-        self.max_position_embeddings = max_position_embeddings
+        self.max_position_embeddings = max_position_embeddings + max_thought_tokens
         self.base = base
         inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim))
         self.register_buffer("inv_freq", inv_freq, persistent=False)
 
         # Build here to make `torch.jit.trace` work.
         self._set_cos_sin_cache(
-            seq_len=max_position_embeddings, device=self.inv_freq.device, dtype=torch.get_default_dtype()
+            seq_len=max_position_embeddings + max_thought_tokens, device=self.inv_freq.device, dtype=torch.get_default_dtype()
         )
 
     def _set_cos_sin_cache(self, seq_len, device, dtype):
@@ -186,7 +186,6 @@ class QuietRotaryEmbedding(nn.Module):
         t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)
 
         freqs = torch.outer(t, self.inv_freq)
-        # Different from paper, but it uses a different permutation in order to obtain the same calculation
         emb = torch.cat((freqs, freqs), dim=-1)
         self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
         self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
@@ -313,6 +312,7 @@ class QuietAttention(nn.Module):
             self.head_dim,
             max_position_embeddings=self.max_position_embeddings,
             base=self.rope_theta,
+            max_thought_tokens=2,
         )
 
     def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
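
The patch widens the rotary-embedding cos/sin cache by max_thought_tokens, presumably so that the extra positions taken up by Quiet-STaR thought tokens still fall inside the precomputed tables rather than past max_position_embeddings. Below is a minimal standalone sketch of the patched QuietRotaryEmbedding for sanity-checking the cache size; the self.max_seq_len_cached = seq_len assignment sits just outside the hunk shown above and is assumed to match the upstream Llama/Mistral code, and the shape check at the end is illustrative only.

import torch
import torch.nn as nn


class QuietRotaryEmbedding(nn.Module):
    # Patched constructor: the cache is sized for max_position_embeddings + max_thought_tokens.
    def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, max_thought_tokens=2):
        super().__init__()
        self.dim = dim
        self.max_position_embeddings = max_position_embeddings + max_thought_tokens
        self.base = base
        inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim))
        self.register_buffer("inv_freq", inv_freq, persistent=False)
        self._set_cos_sin_cache(
            seq_len=max_position_embeddings + max_thought_tokens,
            device=self.inv_freq.device,
            dtype=torch.get_default_dtype(),
        )

    def _set_cos_sin_cache(self, seq_len, device, dtype):
        # Assumed from the upstream implementation (this line is not visible in the hunk above).
        self.max_seq_len_cached = seq_len
        t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)
        freqs = torch.outer(t, self.inv_freq)
        emb = torch.cat((freqs, freqs), dim=-1)
        self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
        self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)


# With the defaults used in QuietAttention (max_thought_tokens=2), the cache now covers
# 2048 + 2 = 2050 positions instead of 2048:
rope = QuietRotaryEmbedding(dim=64)
print(rope.cos_cached.shape)  # torch.Size([2050, 64])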