Crystalcareai committed (verified)
Commit fd5387b · 1 Parent(s): 581b060

Update modeling_quiet.py

Files changed (1): modeling_quiet.py (+12 -64)
modeling_quiet.py CHANGED
@@ -270,22 +270,10 @@ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
 
 
 class QuietAttention(nn.Module):
-    """
-    Multi-headed attention from 'Attention Is All You Need' paper. Modified to use sliding window attention: Longformer
-    and "Generating Long Sequences with Sparse Transformers".
-    """
-
     def __init__(self, config: QuietConfig, layer_idx: Optional[int] = None):
         super().__init__()
         self.config = config
         self.layer_idx = layer_idx
-        if layer_idx is None:
-            logger.warning_once(
-                f"Instantiating {self.__class__.__name__} without passing `layer_idx` is not recommended and will "
-                "to errors during the forward call, if caching is used. Please make sure to provide a `layer_idx` "
-                "when creating this class."
-            )
-
         self.hidden_size = config.hidden_size
         self.num_heads = config.num_attention_heads
         self.head_dim = self.hidden_size // self.num_heads
@@ -296,11 +284,6 @@ class QuietAttention(nn.Module):
         self.is_causal = True
         self.attention_dropout = config.attention_dropout
 
-        if (self.head_dim * self.num_heads) != self.hidden_size:
-            raise ValueError(
-                f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
-                f" and `num_heads`: {self.num_heads})."
-            )
         self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
         self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
         self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
@@ -312,9 +295,6 @@ class QuietAttention(nn.Module):
             base=self.rope_theta,
         )
 
-    def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
-        return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
-
     def forward(
         self,
         hidden_states: torch.Tensor,
@@ -324,11 +304,7 @@ class QuietAttention(nn.Module):
         output_attentions: bool = False,
         use_cache: bool = False,
         **kwargs,
-    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
-        if "padding_mask" in kwargs:
-            warnings.warn(
-                "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`"
-            )
+    ):
         bsz, q_len, _ = hidden_states.size()
 
         query_states = self.q_proj(hidden_states)
@@ -342,50 +318,31 @@ class QuietAttention(nn.Module):
         kv_seq_len = key_states.shape[-2]
         if past_key_value is not None:
             if self.layer_idx is None:
-                raise ValueError(
-                    f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} "
-                    "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
-                    "with a layer index."
-                )
+                raise ValueError("Layer index must be provided when using past key values.")
             kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
         cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
         query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
 
         if past_key_value is not None:
-            cache_kwargs = {"sin": sin, "cos": cos}  # Specific to RoPE models
+            cache_kwargs = {"sin": sin, "cos": cos}
             key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
 
-        # repeat k/v heads if n_kv_heads < n_heads
         key_states = repeat_kv(key_states, self.num_key_value_groups)
         value_states = repeat_kv(value_states, self.num_key_value_groups)
 
-        attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
-
-        if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
-            raise ValueError(
-                f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
-                f" {attn_weights.size()}"
-            )
-
-        if attention_mask is not None:
-            if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
-                raise ValueError(
-                    f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
-                )
-
-            attn_weights = attn_weights + attention_mask
+        if attention_mask is None:
+            attention_mask = torch.ones(bsz, 1, q_len, kv_seq_len, device=hidden_states.device)
+        else:
+            attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
+            attention_mask = attention_mask.repeat(1, 1, q_len, 1)
+            attention_mask = torch.triu(attention_mask, diagonal=1)
 
-        # upcast attention to fp32
+        attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
+        attn_weights = attn_weights + attention_mask
         attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
         attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
         attn_output = torch.matmul(attn_weights, value_states)
 
-        if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
-            raise ValueError(
-                f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
-                f" {attn_output.size()}"
-            )
-
         attn_output = attn_output.transpose(1, 2).contiguous()
         attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
 
@@ -1083,16 +1040,7 @@ class QuietModel(QuietPreTrainedModel):
                 inputs_embeds,
                 past_key_values_length,
             )
-        elif attention_mask is None or attention_mask.dim() == 2:
-            # 4d mask is passed through the layers
-            attention_mask = _prepare_4d_causal_attention_mask(
-                attention_mask,
-                (batch_size, seq_length),
-                inputs_embeds,
-                past_key_values_length,
-                sliding_window=self.config.sliding_window,
-            )
-
+
         hidden_states = inputs_embeds
 
         # decoder layers
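
For readers checking the new mask handling, here is a minimal, self-contained sketch that mirrors the added lines in QuietAttention.forward on dummy tensors, so the resulting shapes are easy to inspect. The sizes (bsz, num_heads, q_len, kv_seq_len, head_dim) and the 2-D all-ones padding mask are illustrative assumptions, not values from the commit or the model config.

# Illustrative sketch only: mirrors the added mask logic with dummy sizes.
import math
import torch

bsz, num_heads, q_len, kv_seq_len, head_dim = 2, 4, 5, 5, 8  # assumed dummy dimensions

# A 2-D padding mask of shape (bsz, kv_seq_len), as a tokenizer typically returns.
attention_mask = torch.ones(bsz, kv_seq_len)

# The added branch for a provided 2-D mask:
attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)   # (bsz, 1, 1, kv_seq_len)
attention_mask = attention_mask.repeat(1, 1, q_len, 1)      # (bsz, 1, q_len, kv_seq_len)
attention_mask = torch.triu(attention_mask, diagonal=1)     # keeps entries strictly above the diagonal

query_states = torch.randn(bsz, num_heads, q_len, head_dim)
key_states = torch.randn(bsz, num_heads, kv_seq_len, head_dim)
value_states = torch.randn(bsz, num_heads, kv_seq_len, head_dim)

attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(head_dim)
attn_weights = attn_weights + attention_mask                # mask broadcasts over the head dimension
attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32)
attn_output = torch.matmul(attn_weights, value_states)

print(attention_mask.shape)  # torch.Size([2, 1, 5, 5])
print(attn_output.shape)     # torch.Size([2, 4, 5, 8])

As the shapes show, the added code expands the 2-D mask to a 4-D upper-triangular 0/1 mask and adds it directly to the attention logits before the fp32 softmax, in place of the removed _prepare_4d_causal_attention_mask path.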