Crystalcareai committed
Update modeling_quiet.py

modeling_quiet.py (+22 -53) CHANGED
@@ -1071,58 +1071,29 @@ class QuietModel(QuietPreTrainedModel):
                     " call `tokenizer.padding_side = 'left'` before tokenizing the input. "
                 )

-        if self._attn_implementation == "flash_attention_2":
-            # 2d mask is passed through the layers
-            attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
-        elif self._attn_implementation == "sdpa" and not output_attentions:
-            # output_attentions=True can not be supported when using SDPA, and we fall back on
-            # the manual implementation that requires a 4D causal mask in all cases.
-            if attention_mask is None or attention_mask.dim() == 2:
-                attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
-                    attention_mask,
-                    (batch_size, seq_length),
-                    inputs_embeds,
-                    past_key_values_length,
-                )
-            else:
-                # Resize the attention mask if necessary
-                if attention_mask.shape[-1] < seq_length:
-                    # Pad the attention mask with ones to match the sequence length
-                    padding = torch.ones(
-                        (attention_mask.shape[0], attention_mask.shape[1], attention_mask.shape[2], seq_length - attention_mask.shape[-1]),
-                        dtype=attention_mask.dtype,
-                        device=attention_mask.device
-                    )
-                    attention_mask = torch.cat([attention_mask, padding], dim=-1)
-                elif attention_mask.shape[-1] > seq_length:
-                    # Truncate the attention mask to match the sequence length
-                    attention_mask = attention_mask[:, :, :, :seq_length]
-        else:
-            if attention_mask is None or attention_mask.dim() == 2:
-                # 4d mask is passed through the layers
-                attention_mask = _prepare_4d_causal_attention_mask(
-                    attention_mask,
-                    (batch_size, seq_length),
-                    inputs_embeds,
-                    past_key_values_length,
-                    sliding_window=self.config.sliding_window,
-                )
-            else:
-                # Resize the attention mask if necessary
-                if attention_mask.shape[-1] < seq_length:
-                    # Pad the attention mask with ones to match the sequence length
-                    padding = torch.ones(
-                        (attention_mask.shape[0], attention_mask.shape[1], attention_mask.shape[2], seq_length - attention_mask.shape[-1]),
-                        dtype=attention_mask.dtype,
-                        device=attention_mask.device
-                    )
-                    attention_mask = torch.cat([attention_mask, padding], dim=-1)
-                elif attention_mask.shape[-1] > seq_length:
-                    # Truncate the attention mask to match the sequence length
-                    attention_mask = attention_mask[:, :, :, :seq_length]
+        if self._attn_implementation == "flash_attention_2":
+            # 2d mask is passed through the layers
+            attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
+        elif self._attn_implementation == "sdpa" and not output_attentions and attention_mask.dim() == 2 and False:
+            # output_attentions=True can not be supported when using SDPA, and we fall back on
+            # the manual implementation that requires a 4D causal mask in all cases.
+            attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
+                attention_mask,
+                (batch_size, seq_length),
+                inputs_embeds,
+                past_key_values_length,
+            )
+        elif attention_mask is None or attention_mask.dim() == 2:
+            # 4d mask is passed through the layers
+            attention_mask = _prepare_4d_causal_attention_mask(
+                attention_mask,
+                (batch_size, seq_length),
+                inputs_embeds,
+                past_key_values_length,
+                sliding_window=self.config.sliding_window,
+            )

-
-        hidden_states = inputs_embeds
+        hidden_states = inputs_embeds

         # decoder layers
         all_hidden_states = () if output_hidden_states else None
@@ -1912,7 +1883,6 @@ class QuietForCausalLM(QuietPreTrainedModel):
         inputs_embeds = inputs_embeds.view(probabilities.size(0), probabilities.size(1), -1).to(self.model.embed_tokens.weight.dtype)
         inputs_embeds = inputs_embeds.view(probabilities.size(0), probabilities.size(1), -1).to(self.model.embed_tokens.weight.dtype)

-        # Update the attention mask when new tokens are added
         if len(attention_mask.shape) == 2:
             breakpoint()
         else:
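The check kept in this hunk distinguishes a plain 2D padding mask from the 4D mask that the rest of the diff slices as `attention_mask[:, :, :, :seq_length]`. A toy illustration of the two shapes the check separates, with invented dimensions:

    import torch

    batch, seq = 2, 6
    mask_2d = torch.ones(batch, seq)           # padding mask: [batch, seq_len]
    mask_4d = torch.zeros(batch, 1, seq, seq)  # causal mask:  [batch, 1, query_len, key_len]

    print(len(mask_2d.shape) == 2)  # True  -> this case hits the breakpoint() above
    print(len(mask_4d.shape) == 2)  # False -> falls through to the else: branch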
@@ -1935,7 +1905,6 @@ class QuietForCausalLM(QuietPreTrainedModel):
             new_attention = new_attention * original_attention
             new_attention[new_attention == 0] = attention_mask.min()
             new_attention[new_attention == 1] = attention_mask.max()
-            attention_mask = torch.cat([original_attention, new_attention], dim=-1)
             attention_mask = torch.cat([attention_mask, new_attention], dim=-1)
             past_key_values = outputs.past_key_values
             position_ids = position_ids + 1
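The two `torch.cat` calls in this hunk differ only in which tensor they extend: the removed line rebuilt the mask from `original_attention`, while the kept line appends to the running `attention_mask`. A shape-only sketch with invented dimensions (the real tensors come from the surrounding generation code); the names follow the diff but the values are illustrative:

    import torch

    batch, q_len, k_len = 1, 4, 4
    original_attention = torch.ones(batch, 1, q_len, k_len)
    attention_mask = torch.ones(batch, 1, q_len, k_len + 2)  # assumed already extended on an earlier pass
    new_attention = torch.ones(batch, 1, q_len, k_len) * original_attention  # as in the context lines

    removed_result = torch.cat([original_attention, new_attention], dim=-1)  # [1, 1, 4, 8]
    kept_result = torch.cat([attention_mask, new_attention], dim=-1)         # [1, 1, 4, 10]
    print(removed_result.shape, kept_result.shape)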