Crystalcareai committed: Update modeling_quiet.py

modeling_quiet.py CHANGED (+55 -40)
@@ -147,6 +147,8 @@ def _get_unpad_data(attention_mask):
         cu_seqlens,
         max_seqlen_in_batch,
     )
+
+
 # Copied from transformers.models.llama.modeling_llama.LlamaRMSNorm with Llama->Quiet
 class QuietRMSNorm(nn.Module):
     def __init__(self, hidden_size, eps=1e-6):
@@ -167,18 +169,18 @@ class QuietRMSNorm(nn.Module):
 
 # Copied from transformers.models.llama.modeling_llama.LlamaRotaryEmbedding with Llama->Quiet
 class QuietRotaryEmbedding(nn.Module):
-    def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None
+    def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
         super().__init__()
 
         self.dim = dim
-        self.max_position_embeddings = max_position_embeddings
+        self.max_position_embeddings = max_position_embeddings
         self.base = base
         inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim))
         self.register_buffer("inv_freq", inv_freq, persistent=False)
 
         # Build here to make `torch.jit.trace` work.
         self._set_cos_sin_cache(
-            seq_len=max_position_embeddings
+            seq_len=max_position_embeddings, device=self.inv_freq.device, dtype=torch.get_default_dtype()
         )
 
     def _set_cos_sin_cache(self, seq_len, device, dtype):
@@ -186,6 +188,7 @@ class QuietRotaryEmbedding(nn.Module):
         t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)
 
         freqs = torch.outer(t, self.inv_freq)
+        # Different from paper, but it uses a different permutation in order to obtain the same calculation
         emb = torch.cat((freqs, freqs), dim=-1)
         self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
         self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
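Note: the cos/sin tensors cached by `_set_cos_sin_cache` are plain lookup tables over positions. A minimal standalone sketch of the same construction (variable names and sizes below are illustrative, not the module's API):

    import torch

    # Minimal sketch of a rotary cos/sin cache for head dimension `dim`
    # and maximum sequence length `max_len` (illustrative values).
    dim, max_len, base = 64, 2048, 10000
    inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))  # (dim/2,)
    t = torch.arange(max_len, dtype=inv_freq.dtype)                     # positions 0..max_len-1
    freqs = torch.outer(t, inv_freq)                                    # (max_len, dim/2)
    emb = torch.cat((freqs, freqs), dim=-1)                             # (max_len, dim)
    cos_cached, sin_cached = emb.cos(), emb.sin()                       # indexed later by position_ids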
@@ -231,18 +234,13 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
     Returns:
         `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
     """
-    print(f"cos shape: {cos.shape}")
-    print(f"position_ids shape: {position_ids.shape}")
-    print(f"position_ids values: {position_ids}")
-    print(f"unsqueeze_dim: {unsqueeze_dim}")
-    assert torch.all(position_ids >= 0), "position_ids must be non-negative"
-    assert torch.all(position_ids < cos.shape[0]), f"position_ids must be less than the size of cos ({cos.shape[0]})"
     cos = cos[position_ids].unsqueeze(unsqueeze_dim)
     sin = sin[position_ids].unsqueeze(unsqueeze_dim)
     q_embed = (q * cos) + (rotate_half(q) * sin)
     k_embed = (k * cos) + (rotate_half(k) * sin)
     return q_embed, k_embed
 
+
 class QuietMLP(nn.Module):
     def __init__(self, config):
         super().__init__()
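Note: `apply_rotary_pos_emb` only needs `rotate_half` plus the cached tables. A self-contained sketch of the rotation for query/key tensors of shape (batch, heads, seq, head_dim), using assumed toy shapes:

    import torch

    def rotate_half(x):
        # Swap the two halves of the last dimension, negating the second half.
        x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2 :]
        return torch.cat((-x2, x1), dim=-1)

    dim, max_len, base = 64, 32, 10000
    inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
    emb = torch.cat([torch.outer(torch.arange(max_len).float(), inv_freq)] * 2, dim=-1)
    cos_cached, sin_cached = emb.cos(), emb.sin()

    # q, k: (batch, num_heads, seq_len, head_dim); position_ids: (batch, seq_len)
    q, k = torch.randn(2, 4, 5, dim), torch.randn(2, 4, 5, dim)
    position_ids = torch.arange(5).unsqueeze(0).expand(2, -1)
    cos = cos_cached[position_ids].unsqueeze(1)  # unsqueeze_dim=1 broadcasts across heads
    sin = sin_cached[position_ids].unsqueeze(1)
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)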
@@ -283,8 +281,8 @@ class QuietAttention(nn.Module):
         self.layer_idx = layer_idx
         if layer_idx is None:
             logger.warning_once(
-                f"Instantiating {self.__class__.__name__} without passing `layer_idx` is not recommended and will "
-                "to errors during the forward call
+                f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will "
+                "lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` "
                 "when creating this class."
             )
 
@@ -312,7 +310,6 @@ class QuietAttention(nn.Module):
             self.head_dim,
             max_position_embeddings=self.max_position_embeddings,
             base=self.rope_theta,
-            max_thought_tokens=2,
         )
 
     def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
@@ -370,36 +367,54 @@ class QuietAttention(nn.Module):
             f" {attn_weights.size()}"
         )
 
+        print("Before applying attention mask:")
+        print("attention_mask shape:", attention_mask.shape if attention_mask is not None else None)
+        print("attn_weights shape:", attn_weights.shape)
+
         if attention_mask is not None:
+            print("Applying attention mask")
             if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
                 raise ValueError(
                     f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
                 )
-
             attn_weights = attn_weights + attention_mask
-
+
+        print("After applying attention mask:")
+        print("attn_weights shape:", attn_weights.shape)
+
         # upcast attention to fp32
         attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
+
+        print("After softmax:")
+        print("attn_weights shape:", attn_weights.shape)
+
         attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
+
+        print("After dropout:")
+        print("attn_weights shape:", attn_weights.shape)
+
         attn_output = torch.matmul(attn_weights, value_states)
-
+
+        print("After matmul with value states:")
+        print("attn_output shape:", attn_output.shape)
+
         if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
             raise ValueError(
                 f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
                 f" {attn_output.size()}"
             )
-
+
         attn_output = attn_output.transpose(1, 2).contiguous()
         attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
-
         attn_output = self.o_proj(attn_output)
-
+
+        print("Final attn_output shape:", attn_output.shape)
+
         if not output_attentions:
             attn_weights = None
-
+
         return attn_output, attn_weights, past_key_value
 
-
 class QuietFlashAttention2(QuietAttention):
     """
     Quiet flash attention module. This module inherits from `QuietAttention` as the weights of the module stays
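Note: the eager path above is the usual softmax(QK^T / sqrt(d) + mask) V computation. A compact self-contained sketch with illustrative shapes (not the module's code):

    import math
    import torch
    import torch.nn.functional as F

    # Illustrative shapes: (batch, heads, q_len, head_dim) and (batch, heads, kv_len, head_dim)
    bsz, heads, q_len, kv_len, head_dim = 2, 4, 5, 5, 64
    query = torch.randn(bsz, heads, q_len, head_dim)
    key = torch.randn(bsz, heads, kv_len, head_dim)
    value = torch.randn(bsz, heads, kv_len, head_dim)
    mask = torch.zeros(bsz, 1, q_len, kv_len)  # additive mask: 0 = attend, large negative = blocked

    attn_weights = query @ key.transpose(2, 3) / math.sqrt(head_dim)  # (bsz, heads, q_len, kv_len)
    attn_weights = attn_weights + mask                                # broadcast over heads
    attn_weights = F.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
    attn_output = attn_weights @ value                                # (bsz, heads, q_len, head_dim)
    attn_output = attn_output.transpose(1, 2).reshape(bsz, q_len, heads * head_dim)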
@@ -576,7 +591,7 @@ class QuietFlashAttention2(QuietAttention):
         attention_mask (`torch.Tensor`):
             The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the
             position of padding tokens and 1 for the position of non-padding tokens.
-        dropout (`
+        dropout (`float`):
             Attention dropout
         softmax_scale (`float`, *optional*):
             The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim)
@@ -694,7 +709,8 @@ class QuietFlashAttention2(QuietAttention):
         )
 
 
-#
+# copied from transformers.models.llama.modeling_llama.LlamaSdpaAttention with Llama->Quiet
+# TODO @Arthur no longer copied from LLama after static cache
 class QuietSdpaAttention(QuietAttention):
     """
     Quiet attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from
@@ -768,14 +784,14 @@ class QuietSdpaAttention(QuietAttention):
             query_states,
             key_states,
             value_states,
-            attn_mask=attention_mask
+            attn_mask=attention_mask,
             dropout_p=self.attention_dropout if self.training else 0.0,
             # The q_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case q_len == 1.
             is_causal=self.is_causal and attention_mask is None and q_len > 1,
         )
 
         attn_output = attn_output.transpose(1, 2).contiguous()
-        attn_output = attn_output.
+        attn_output = attn_output.view(bsz, q_len, self.hidden_size)
 
         attn_output = self.o_proj(attn_output)
 
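Note: the SDPA path hands the same computation to `torch.nn.functional.scaled_dot_product_attention`. A minimal sketch of the call pattern fixed above, with assumed toy shapes:

    import torch
    import torch.nn.functional as F

    bsz, heads, q_len, head_dim = 2, 4, 5, 64
    query = torch.randn(bsz, heads, q_len, head_dim)
    key = torch.randn(bsz, heads, q_len, head_dim)
    value = torch.randn(bsz, heads, q_len, head_dim)
    attention_mask = torch.zeros(bsz, 1, q_len, q_len)  # additive mask, broadcast over heads

    attn_output = F.scaled_dot_product_attention(
        query,
        key,
        value,
        attn_mask=attention_mask,
        dropout_p=0.0,
        is_causal=False,  # pass True only when no explicit mask is given and q_len > 1
    )
    attn_output = attn_output.transpose(1, 2).contiguous()
    attn_output = attn_output.view(bsz, q_len, heads * head_dim)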
@@ -1095,7 +1111,7 @@ class QuietModel(QuietPreTrainedModel):
             past_key_values_length,
             sliding_window=self.config.sliding_window,
         )
-
+
         hidden_states = inputs_embeds
 
         # decoder layers
@@ -1318,11 +1334,16 @@ class QuietForCausalLM(QuietPreTrainedModel):
         original_input_ids = input_ids.clone()
         original_attention_mask = attention_mask.clone() if attention_mask is not None else None
 
+        # Append the start thought token to the input sequence
         # Append the start thought token to the input sequence
         start_thought_token_id = self.tokenizer.convert_tokens_to_ids("<|startthought|>")
         input_ids = torch.cat([input_ids, torch.tensor([[start_thought_token_id]] * batch_size).to(input_ids.device)], dim=-1)
         seq_len += 1
 
+        # Update the position_ids tensor
+        position_ids = position_ids[:, :-1]  # Remove the last position
+        position_ids = torch.cat([position_ids, torch.full((batch_size, 1), seq_len - 1, dtype=torch.long, device=position_ids.device)], dim=-1)
+
         # Update the attention mask
         if attention_mask is not None:
             attention_mask = torch.cat([attention_mask, torch.ones((batch_size, 1)).to(attention_mask.device)], dim=-1)
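Note: the bookkeeping around the thought tokens follows a general pattern: append one token id per row, then grow the sequence length, position ids, and attention mask together. A rough standalone sketch of that pattern (the token id and shapes are placeholders, not the model's values):

    import torch

    batch_size, seq_len = 2, 6
    input_ids = torch.randint(0, 100, (batch_size, seq_len))
    attention_mask = torch.ones(batch_size, seq_len)
    position_ids = torch.arange(seq_len).unsqueeze(0).expand(batch_size, -1)

    new_token_id = 7  # placeholder for e.g. a start-of-thought token id
    input_ids = torch.cat([input_ids, torch.full((batch_size, 1), new_token_id)], dim=-1)
    seq_len += 1

    # The new token sits at position seq_len - 1 and is always attended to.
    position_ids = torch.cat(
        [position_ids, torch.full((batch_size, 1), seq_len - 1, dtype=torch.long)], dim=-1
    )
    attention_mask = torch.cat([attention_mask, torch.ones(batch_size, 1)], dim=-1)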
@@ -1344,7 +1365,6 @@ class QuietForCausalLM(QuietPreTrainedModel):
             output_hidden_states=output_hidden_states,
             return_dict=return_dict,
         )
-        print(f"Passing attention mask to the model. Shape: {attention_mask.shape}")
         new_key_values = outputs.past_key_values
 
         hidden_states = outputs[0]
@@ -1365,10 +1385,15 @@ class QuietForCausalLM(QuietPreTrainedModel):
             attention_mask = torch.cat([attention_mask, torch.ones((batch_size, 1)).to(attention_mask.device)], dim=-1)
 
         # Append the end thought token to the input sequence
+        # Append the end thought token to the input sequence
         end_thought_token_id = self.tokenizer.convert_tokens_to_ids("<|endthought|>")
         input_ids = torch.cat([input_ids, torch.tensor([[end_thought_token_id]] * batch_size).to(input_ids.device)], dim=-1)
         seq_len += 1
 
+        # Update the position_ids tensor
+        position_ids = position_ids[:, :-1]  # Remove the last position
+        position_ids = torch.cat([position_ids, torch.full((batch_size, 1), seq_len - 1, dtype=torch.long, device=position_ids.device)], dim=-1)
+
         # Update the attention mask
         if attention_mask is not None:
             attention_mask = torch.cat([attention_mask, torch.ones((batch_size, 1)).to(attention_mask.device)], dim=-1)
@@ -1603,6 +1628,8 @@ class QuietForCausalLM(QuietPreTrainedModel):
         base_embeddings = self.model.embed_tokens.weight
         if self.train_only_thinking_embedding:
             base_embeddings = base_embeddings.detach()
+        if position_ids is None:
+            position_ids = torch.arange(seq_len, dtype=torch.long, device=input_ids.device).unsqueeze(0).expand(batch_size, -1)
         # # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
         fwd_iters = 1 if self.original_mode else self.n_ahead + self.n_ahead_talk - 1
         for ahead_idx in range(fwd_iters):
@@ -1882,9 +1909,7 @@ class QuietForCausalLM(QuietPreTrainedModel):
                 if len(attention_mask.shape) == 2:
                     breakpoint()
                 else:
-                    original_attention = attention_mask[..., :attention_mask.shape[-2]
-                    print(f"Original attention shape: {original_attention.shape}")
-
+                    original_attention = attention_mask[..., :attention_mask.shape[-2]]
                     if self.use_upper_triangular:
                         new_attention = original_attention
                     else:
@@ -1900,20 +1925,10 @@ class QuietForCausalLM(QuietPreTrainedModel):
                         ).to(attention_mask.dtype)
 
                         new_attention = new_attention.view(1, 1, seq_len, seq_len).repeat(input_ids.shape[0], 1, 1, 1)
-                        print(f"New attention shape: {new_attention.shape}")
-
                         new_attention = new_attention * original_attention
                         new_attention[new_attention == 0] = attention_mask.min()
                         new_attention[new_attention == 1] = attention_mask.max()
-
-                    print(f"Original attention shape before concatenation: {original_attention.shape}")
-                    print(f"New attention shape before concatenation: {new_attention.shape}")
-
-                    if self.use_upper_triangular:
-                        attention_mask = original_attention
-                    else:
-                        attention_mask = new_attention
-                    print(f"Attention mask shape after concatenation: {attention_mask.shape}")
+                    attention_mask = torch.cat([attention_mask, new_attention], dim=-1)
                 past_key_values = outputs.past_key_values
                 position_ids = position_ids + 1
 
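Note: when the running mask is a 4D additive mask of shape (batch, 1, q_len, kv_len), letting every query row attend to one newly appended key position amounts to concatenating a column of zeros along the last dimension. A hypothetical helper illustrating the shape mechanics (not the model's code):

    import torch

    def extend_additive_mask(mask_4d: torch.Tensor) -> torch.Tensor:
        # mask_4d: (batch, 1, q_len, kv_len) with 0 where attention is allowed
        # and a large negative value where it is blocked.
        bsz, _, q_len, _ = mask_4d.shape
        new_col = torch.zeros(bsz, 1, q_len, 1, dtype=mask_4d.dtype, device=mask_4d.device)
        return torch.cat([mask_4d, new_col], dim=-1)

    mask = torch.zeros(2, 1, 5, 5)
    mask = extend_additive_mask(mask)  # -> shape (2, 1, 5, 6)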