Crystalcareai committed
Commit f5e1b24 · verified · 1 Parent(s): 5454cb2

Update modeling_quiet.py

Files changed (1)
  1. modeling_quiet.py +22 -53
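This commit replaces the hand-rolled resizing of an already-4D attention mask with the standard Transformers mask helpers. As background, here is a minimal, self-contained sketch of what `_prepare_4d_causal_attention_mask` does with a plain 2D padding mask; it is not part of the commit, it assumes a transformers release in the 4.36-4.38 range where this private helper used by modeling_quiet.py exists, and the sliding-window value below is purely illustrative.

import torch
from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask

batch_size, seq_length, hidden_size = 2, 5, 8
past_key_values_length = 0
# 2D mask: 1 = attend, 0 = padding (second sequence is left-padded by two tokens)
attention_mask = torch.tensor([[1, 1, 1, 1, 1],
                               [0, 0, 1, 1, 1]])
inputs_embeds = torch.zeros(batch_size, seq_length, hidden_size)

causal_mask = _prepare_4d_causal_attention_mask(
    attention_mask,
    (batch_size, seq_length),
    inputs_embeds,
    past_key_values_length,
    sliding_window=4096,  # illustrative value; the model reads config.sliding_window
)
print(causal_mask.shape)  # torch.Size([2, 1, 5, 5]): (batch, 1, query_len, key_len)
print(causal_mask[1, 0])  # large negative values on padded/future positions, 0 elsewhere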
modeling_quiet.py CHANGED

@@ -1071,58 +1071,29 @@ class QuietModel(QuietPreTrainedModel):
                     " call `tokenizer.padding_side = 'left'` before tokenizing the input. "
                 )
 
-        if self._attn_implementation == "flash_attention_2":
-            # 2d mask is passed through the layers
-            attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
-        elif self._attn_implementation == "sdpa" and not output_attentions:
-            if attention_mask.dim() == 2:
-                # output_attentions=True can not be supported when using SDPA, and we fall back on
-                # the manual implementation that requires a 4D causal mask in all cases.
-                attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
-                    attention_mask,
-                    (batch_size, seq_length),
-                    inputs_embeds,
-                    past_key_values_length,
-                )
-            else:
-                # Resize the attention mask if necessary
-                if attention_mask.shape[-1] < seq_length:
-                    # Pad the attention mask with ones to match the sequence length
-                    padding = torch.ones(
-                        (attention_mask.shape[0], attention_mask.shape[1], attention_mask.shape[2], seq_length - attention_mask.shape[-1]),
-                        dtype=attention_mask.dtype,
-                        device=attention_mask.device
-                    )
-                    attention_mask = torch.cat([attention_mask, padding], dim=-1)
-                elif attention_mask.shape[-1] > seq_length:
-                    # Truncate the attention mask to match the sequence length
-                    attention_mask = attention_mask[:, :, :, :seq_length]
-        else:
-            if attention_mask is None or attention_mask.dim() == 2:
-                # 4d mask is passed through the layers
-                attention_mask = _prepare_4d_causal_attention_mask(
-                    attention_mask,
-                    (batch_size, seq_length),
-                    inputs_embeds,
-                    past_key_values_length,
-                    sliding_window=self.config.sliding_window,
-                )
-            else:
-                # Resize the attention mask if necessary
-                if attention_mask.shape[-1] < seq_length:
-                    # Pad the attention mask with ones to match the sequence length
-                    padding = torch.ones(
-                        (attention_mask.shape[0], attention_mask.shape[1], attention_mask.shape[2], seq_length - attention_mask.shape[-1]),
-                        dtype=attention_mask.dtype,
-                        device=attention_mask.device
-                    )
-                    attention_mask = torch.cat([attention_mask, padding], dim=-1)
-                elif attention_mask.shape[-1] > seq_length:
-                    # Truncate the attention mask to match the sequence length
-                    attention_mask = attention_mask[:, :, :, :seq_length]
+        if self._attn_implementation == "flash_attention_2":
+            # 2d mask is passed through the layers
+            attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
+        elif self._attn_implementation == "sdpa" and not output_attentions and attention_mask.dim() == 2 and False:
+            # output_attentions=True can not be supported when using SDPA, and we fall back on
+            # the manual implementation that requires a 4D causal mask in all cases.
+            attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
+                attention_mask,
+                (batch_size, seq_length),
+                inputs_embeds,
+                past_key_values_length,
+            )
+        elif attention_mask is None or attention_mask.dim() == 2:
+            # 4d mask is passed through the layers
+            attention_mask = _prepare_4d_causal_attention_mask(
+                attention_mask,
+                (batch_size, seq_length),
+                inputs_embeds,
+                past_key_values_length,
+                sliding_window=self.config.sliding_window,
+            )
 
-        # Assign the value to hidden_states after the attention mask preparation
-        hidden_states = inputs_embeds
+        hidden_states = inputs_embeds
 
         # decoder layers
         all_hidden_states = () if output_hidden_states else None
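Most of the lines deleted in the hunk above were two copies of the same manual resize of an already-4D mask, applied whenever its key dimension drifted away from `seq_length`. Condensed into one place, the removed logic amounted to the following sketch; the helper name is hypothetical and does not exist in the repository.

import torch

def _resize_4d_mask(attention_mask: torch.Tensor, seq_length: int) -> torch.Tensor:
    # Condensed form of the deleted pad/truncate branches, for reference only.
    key_len = attention_mask.shape[-1]
    if key_len < seq_length:
        # Pad the key dimension with ones, as the removed code did.
        padding = torch.ones(
            (*attention_mask.shape[:3], seq_length - key_len),
            dtype=attention_mask.dtype,
            device=attention_mask.device,
        )
        attention_mask = torch.cat([attention_mask, padding], dim=-1)
    elif key_len > seq_length:
        # Truncate the key dimension down to seq_length.
        attention_mask = attention_mask[:, :, :, :seq_length]
    return attention_mask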
@@ -1912,7 +1883,6 @@ class QuietForCausalLM(QuietPreTrainedModel):
                 inputs_embeds = inputs_embeds.view(probabilities.size(0), probabilities.size(1), -1).to(self.model.embed_tokens.weight.dtype)
                 inputs_embeds = inputs_embeds.view(probabilities.size(0), probabilities.size(1), -1).to(self.model.embed_tokens.weight.dtype)
 
-                # Update the attention mask when new tokens are added
                 if len(attention_mask.shape) == 2:
                     breakpoint()
                 else:
@@ -1935,7 +1905,6 @@ class QuietForCausalLM(QuietPreTrainedModel):
                     new_attention = new_attention * original_attention
                     new_attention[new_attention == 0] = attention_mask.min()
                     new_attention[new_attention == 1] = attention_mask.max()
-                    attention_mask = torch.cat([original_attention, new_attention], dim=-1)
                     attention_mask = torch.cat([attention_mask, new_attention], dim=-1)
                     past_key_values = outputs.past_key_values
                     position_ids = position_ids + 1
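The two single-line hunks above drop a stale comment and a duplicated concatenation. With the duplicate `torch.cat([original_attention, new_attention], dim=-1)` gone, the accumulated 4D mask grows by exactly one key column per appended token. A toy sketch of the shape arithmetic, assuming a (batch, heads, query_len, key_len) layout and that these lines run inside the per-token loop suggested by the surrounding `past_key_values` and `position_ids` updates:

import torch

# Toy shapes only; names mirror the diff, values are made up for illustration.
batch, heads, query_len, key_len = 1, 1, 4, 4
attention_mask = torch.zeros(batch, heads, query_len, key_len)   # accumulated additive mask
new_attention = torch.zeros(batch, heads, query_len, 1)          # column for the newly appended token

# Post-commit: append a single key column per step.
attention_mask = torch.cat([attention_mask, new_attention], dim=-1)
print(attention_mask.shape)  # torch.Size([1, 1, 4, 5])

# Pre-commit, `attention_mask = torch.cat([original_attention, new_attention], dim=-1)`
# ran immediately before this cat, so each iteration reset the mask to original_attention
# and then widened it by two columns rather than one.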
 