Crystalcareai committed
Commit 9172f24 · verified · 1 Parent(s): e7aeafc

Update modeling_quiet.py

Files changed (1)
  1. modeling_quiet.py +55 -40
modeling_quiet.py CHANGED
@@ -147,6 +147,8 @@ def _get_unpad_data(attention_mask):
         cu_seqlens,
         max_seqlen_in_batch,
     )
+
+
 # Copied from transformers.models.llama.modeling_llama.LlamaRMSNorm with Llama->Quiet
 class QuietRMSNorm(nn.Module):
     def __init__(self, hidden_size, eps=1e-6):
@@ -167,18 +169,18 @@ class QuietRMSNorm(nn.Module):

 # Copied from transformers.models.llama.modeling_llama.LlamaRotaryEmbedding with Llama->Quiet
 class QuietRotaryEmbedding(nn.Module):
-    def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, max_thought_tokens=2):
+    def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
         super().__init__()

         self.dim = dim
-        self.max_position_embeddings = max_position_embeddings + max_thought_tokens
+        self.max_position_embeddings = max_position_embeddings
         self.base = base
         inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim))
         self.register_buffer("inv_freq", inv_freq, persistent=False)

         # Build here to make `torch.jit.trace` work.
         self._set_cos_sin_cache(
-            seq_len=max_position_embeddings + max_thought_tokens, device=self.inv_freq.device, dtype=torch.get_default_dtype()
+            seq_len=max_position_embeddings, device=self.inv_freq.device, dtype=torch.get_default_dtype()
         )

     def _set_cos_sin_cache(self, seq_len, device, dtype):
@@ -186,6 +188,7 @@ class QuietRotaryEmbedding(nn.Module):
         t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)

         freqs = torch.outer(t, self.inv_freq)
+        # Different from paper, but it uses a different permutation in order to obtain the same calculation
         emb = torch.cat((freqs, freqs), dim=-1)
         self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
         self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
@@ -231,18 +234,13 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
     Returns:
         `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
     """
-    print(f"cos shape: {cos.shape}")
-    print(f"position_ids shape: {position_ids.shape}")
-    print(f"position_ids values: {position_ids}")
-    print(f"unsqueeze_dim: {unsqueeze_dim}")
-    assert torch.all(position_ids >= 0), "position_ids must be non-negative"
-    assert torch.all(position_ids < cos.shape[0]), f"position_ids must be less than the size of cos ({cos.shape[0]})"
     cos = cos[position_ids].unsqueeze(unsqueeze_dim)
     sin = sin[position_ids].unsqueeze(unsqueeze_dim)
     q_embed = (q * cos) + (rotate_half(q) * sin)
     k_embed = (k * cos) + (rotate_half(k) * sin)
     return q_embed, k_embed

+
 class QuietMLP(nn.Module):
     def __init__(self, config):
         super().__init__()
@@ -283,8 +281,8 @@ class QuietAttention(nn.Module):
         self.layer_idx = layer_idx
         if layer_idx is None:
             logger.warning_once(
-                f"Instantiating {self.__class__.__name__} without passing `layer_idx` is not recommended and will "
-                "to errors during the forward call, if caching is used. Please make sure to provide a `layer_idx` "
+                f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will "
+                "lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` "
                 "when creating this class."
             )

@@ -312,7 +310,6 @@ class QuietAttention(nn.Module):
             self.head_dim,
             max_position_embeddings=self.max_position_embeddings,
             base=self.rope_theta,
-            max_thought_tokens=2,
         )

     def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
@@ -370,36 +367,54 @@ class QuietAttention(nn.Module):
                 f" {attn_weights.size()}"
             )

+        print("Before applying attention mask:")
+        print("attention_mask shape:", attention_mask.shape if attention_mask is not None else None)
+        print("attn_weights shape:", attn_weights.shape)
+
         if attention_mask is not None:
+            print("Applying attention mask")
             if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
                 raise ValueError(
                     f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
                 )
-
             attn_weights = attn_weights + attention_mask
-
+
+        print("After applying attention mask:")
+        print("attn_weights shape:", attn_weights.shape)
+
         # upcast attention to fp32
         attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
+
+        print("After softmax:")
+        print("attn_weights shape:", attn_weights.shape)
+
         attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
+
+        print("After dropout:")
+        print("attn_weights shape:", attn_weights.shape)
+
         attn_output = torch.matmul(attn_weights, value_states)
-
+
+        print("After matmul with value states:")
+        print("attn_output shape:", attn_output.shape)
+
         if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
             raise ValueError(
                 f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
                 f" {attn_output.size()}"
             )
-
+
         attn_output = attn_output.transpose(1, 2).contiguous()
         attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
-
         attn_output = self.o_proj(attn_output)
-
+
+        print("Final attn_output shape:", attn_output.shape)
+
         if not output_attentions:
             attn_weights = None
-
+
         return attn_output, attn_weights, past_key_value

-
 class QuietFlashAttention2(QuietAttention):
     """
     Quiet flash attention module. This module inherits from `QuietAttention` as the weights of the module stays
@@ -576,7 +591,7 @@ class QuietFlashAttention2(QuietAttention):
             attention_mask (`torch.Tensor`):
                 The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the
                 position of padding tokens and 1 for the position of non-padding tokens.
-            dropout (`int`, *optional*):
+            dropout (`float`):
                 Attention dropout
             softmax_scale (`float`, *optional*):
                 The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim)
@@ -694,7 +709,8 @@ class QuietFlashAttention2(QuietAttention):
         )


-# Copied from transformers.models.llama.modeling_llama.LlamaSdpaAttention with Llama->Quiet
+# copied from transformers.models.llama.modeling_llama.LlamaSdpaAttention with Llama->Quiet
+# TODO @Arthur no longer copied from LLama after static cache
 class QuietSdpaAttention(QuietAttention):
     """
     Quiet attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from
@@ -768,14 +784,14 @@ class QuietSdpaAttention(QuietAttention):
             query_states,
             key_states,
             value_states,
-            attn_mask=attention_mask.to(query_states.device) if attention_mask is not None else None,
+            attn_mask=attention_mask,
             dropout_p=self.attention_dropout if self.training else 0.0,
             # The q_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case q_len == 1.
             is_causal=self.is_causal and attention_mask is None and q_len > 1,
         )

         attn_output = attn_output.transpose(1, 2).contiguous()
-        attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
+        attn_output = attn_output.view(bsz, q_len, self.hidden_size)

         attn_output = self.o_proj(attn_output)

@@ -1095,7 +1111,7 @@ class QuietModel(QuietPreTrainedModel):
             past_key_values_length,
             sliding_window=self.config.sliding_window,
         )
-        print(f"Prepared 4D causal attention mask. Shape: {attention_mask.shape}")
+
         hidden_states = inputs_embeds

         # decoder layers
@@ -1318,11 +1334,16 @@ class QuietForCausalLM(QuietPreTrainedModel):
         original_input_ids = input_ids.clone()
         original_attention_mask = attention_mask.clone() if attention_mask is not None else None

+        # Append the start thought token to the input sequence
         # Append the start thought token to the input sequence
         start_thought_token_id = self.tokenizer.convert_tokens_to_ids("<|startthought|>")
         input_ids = torch.cat([input_ids, torch.tensor([[start_thought_token_id]] * batch_size).to(input_ids.device)], dim=-1)
         seq_len += 1

+        # Update the position_ids tensor
+        position_ids = position_ids[:, :-1]  # Remove the last position
+        position_ids = torch.cat([position_ids, torch.full((batch_size, 1), seq_len - 1, dtype=torch.long, device=position_ids.device)], dim=-1)
+
         # Update the attention mask
         if attention_mask is not None:
             attention_mask = torch.cat([attention_mask, torch.ones((batch_size, 1)).to(attention_mask.device)], dim=-1)
@@ -1344,7 +1365,6 @@ class QuietForCausalLM(QuietPreTrainedModel):
             output_hidden_states=output_hidden_states,
             return_dict=return_dict,
         )
-        print(f"Passing attention mask to the model. Shape: {attention_mask.shape}")
         new_key_values = outputs.past_key_values

         hidden_states = outputs[0]
@@ -1365,10 +1385,15 @@ class QuietForCausalLM(QuietPreTrainedModel):
             attention_mask = torch.cat([attention_mask, torch.ones((batch_size, 1)).to(attention_mask.device)], dim=-1)

         # Append the end thought token to the input sequence
+        # Append the end thought token to the input sequence
         end_thought_token_id = self.tokenizer.convert_tokens_to_ids("<|endthought|>")
         input_ids = torch.cat([input_ids, torch.tensor([[end_thought_token_id]] * batch_size).to(input_ids.device)], dim=-1)
         seq_len += 1

+        # Update the position_ids tensor
+        position_ids = position_ids[:, :-1]  # Remove the last position
+        position_ids = torch.cat([position_ids, torch.full((batch_size, 1), seq_len - 1, dtype=torch.long, device=position_ids.device)], dim=-1)
+
         # Update the attention mask
         if attention_mask is not None:
             attention_mask = torch.cat([attention_mask, torch.ones((batch_size, 1)).to(attention_mask.device)], dim=-1)
@@ -1603,6 +1628,8 @@ class QuietForCausalLM(QuietPreTrainedModel):
         base_embeddings = self.model.embed_tokens.weight
         if self.train_only_thinking_embedding:
             base_embeddings = base_embeddings.detach()
+        if position_ids is None:
+            position_ids = torch.arange(seq_len, dtype=torch.long, device=input_ids.device).unsqueeze(0).expand(batch_size, -1)
         # # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
         fwd_iters = 1 if self.original_mode else self.n_ahead + self.n_ahead_talk - 1
         for ahead_idx in range(fwd_iters):
@@ -1882,9 +1909,7 @@ class QuietForCausalLM(QuietPreTrainedModel):
                     if len(attention_mask.shape) == 2:
                         breakpoint()
                     else:
-                        original_attention = attention_mask[..., :attention_mask.shape[-2], :attention_mask.shape[-1]]
-                        print(f"Original attention shape: {original_attention.shape}")
-
+                        original_attention = attention_mask[..., :attention_mask.shape[-2]]
                         if self.use_upper_triangular:
                             new_attention = original_attention
                         else:
@@ -1900,20 +1925,10 @@ class QuietForCausalLM(QuietPreTrainedModel):
                            ).to(attention_mask.dtype)

                            new_attention = new_attention.view(1, 1, seq_len, seq_len).repeat(input_ids.shape[0], 1, 1, 1)
-                            print(f"New attention shape: {new_attention.shape}")
-
                            new_attention = new_attention * original_attention
                            new_attention[new_attention == 0] = attention_mask.min()
                            new_attention[new_attention == 1] = attention_mask.max()
-
-                            print(f"Original attention shape before concatenation: {original_attention.shape}")
-                            print(f"New attention shape before concatenation: {new_attention.shape}")
-
-                            if self.use_upper_triangular:
-                                attention_mask = original_attention
-                            else:
-                                attention_mask = new_attention
-                            print(f"Attention mask shape after concatenation: {attention_mask.shape}")
+                            attention_mask = torch.cat([attention_mask, new_attention], dim=-1)
                past_key_values = outputs.past_key_values
                position_ids = position_ids + 1

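Note: the `position_ids` lines added above follow the same append-one-token bookkeeping already used for `input_ids` and the padding mask around the `<|startthought|>` / `<|endthought|>` tokens. Below is a minimal, self-contained sketch of that pattern with toy shapes and a placeholder token id (not the model's real values); the hunk above additionally drops the last existing position before appending.

import torch

# Illustrative shapes and token id; the real id comes from
# self.tokenizer.convert_tokens_to_ids("<|startthought|>").
batch_size, seq_len = 2, 5
thought_token_id = 32001  # placeholder value

input_ids = torch.randint(0, 32000, (batch_size, seq_len))
attention_mask = torch.ones(batch_size, seq_len, dtype=torch.long)
position_ids = torch.arange(seq_len, dtype=torch.long).unsqueeze(0).expand(batch_size, -1)

# Append one special token to every sequence in the batch.
input_ids = torch.cat([input_ids, torch.full((batch_size, 1), thought_token_id, dtype=input_ids.dtype)], dim=-1)
seq_len += 1

# Extend the padding mask with a column of ones for the appended token.
attention_mask = torch.cat([attention_mask, torch.ones(batch_size, 1, dtype=attention_mask.dtype)], dim=-1)

# Give the appended token the next position index (seq_len - 1 after the increment).
position_ids = torch.cat([position_ids, torch.full((batch_size, 1), seq_len - 1, dtype=torch.long)], dim=-1)

assert input_ids.shape == attention_mask.shape == position_ids.shape == (batch_size, seq_len)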
 
 
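Note: the shape prints added in QuietAttention.forward trace the standard eager-attention flow (scores plus an additive mask, fp32 softmax, dropout, value matmul). Here is a minimal sketch of that flow under toy sizes; it uses a plain causal mask rather than the model's prepared sliding-window mask, and is an illustration only, not the module's code.

import torch
import torch.nn as nn

# Toy sizes; the modeling code works with (bsz, num_heads, q_len, kv_seq_len).
bsz, num_heads, q_len, kv_seq_len, head_dim = 1, 2, 4, 4, 8
query = torch.randn(bsz, num_heads, q_len, head_dim)
key = torch.randn(bsz, num_heads, kv_seq_len, head_dim)
value = torch.randn(bsz, num_heads, kv_seq_len, head_dim)

attn_weights = torch.matmul(query, key.transpose(2, 3)) / (head_dim ** 0.5)

# Additive mask: 0 where attending is allowed, a large negative number where it is not.
causal = torch.triu(torch.full((q_len, kv_seq_len), torch.finfo(attn_weights.dtype).min), diagonal=1)
attention_mask = causal.view(1, 1, q_len, kv_seq_len)

attn_weights = attn_weights + attention_mask
# Upcast to fp32 for the softmax, then cast back, mirroring the modeling code.
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
attn_output = torch.matmul(attn_weights, value)
print(attn_output.shape)  # torch.Size([1, 2, 4, 8])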