Crystalcareai
committed on
Update modeling_quiet.py
modeling_quiet.py CHANGED (+12 -64)
@@ -270,22 +270,10 @@ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
 
 
 class QuietAttention(nn.Module):
-    """
-    Multi-headed attention from 'Attention Is All You Need' paper. Modified to use sliding window attention: Longformer
-    and "Generating Long Sequences with Sparse Transformers".
-    """
-
     def __init__(self, config: QuietConfig, layer_idx: Optional[int] = None):
         super().__init__()
         self.config = config
         self.layer_idx = layer_idx
-        if layer_idx is None:
-            logger.warning_once(
-                f"Instantiating {self.__class__.__name__} without passing `layer_idx` is not recommended and will lead "
-                "to errors during the forward call, if caching is used. Please make sure to provide a `layer_idx` "
-                "when creating this class."
-            )
-
         self.hidden_size = config.hidden_size
         self.num_heads = config.num_attention_heads
         self.head_dim = self.hidden_size // self.num_heads
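Note: the deleted warning only flagged a missing `layer_idx`; the index is still required whenever a KV cache is passed, because `past_key_value.get_usable_length(...)` and `past_key_value.update(...)` in the forward hunk below are keyed on it. A minimal sketch of those two calls in isolation, assuming the transformers >= 4.36 cache API (`DynamicCache`) that the surrounding code targets; shapes are illustrative only:

import torch
from transformers.cache_utils import DynamicCache

layer_idx = 0                      # the index each QuietAttention layer must receive
cache = DynamicCache()
k = torch.randn(1, 8, 4, 64)       # (batch, num_key_value_heads, seq_len, head_dim)
v = torch.randn(1, 8, 4, 64)

past_len = cache.get_usable_length(k.shape[-2], layer_idx)   # 0 on the first forward pass
k_all, v_all = cache.update(k, v, layer_idx)                 # appends to this layer's slot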
@@ -296,11 +284,6 @@ class QuietAttention(nn.Module):
         self.is_causal = True
         self.attention_dropout = config.attention_dropout
 
-        if (self.head_dim * self.num_heads) != self.hidden_size:
-            raise ValueError(
-                f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
-                f" and `num_heads`: {self.num_heads})."
-            )
         self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
         self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
         self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
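Note: the guard removed here enforced that `hidden_size` is an exact multiple of `num_attention_heads`; without it, `self.head_dim = self.hidden_size // self.num_heads` truncates silently and the projection shapes no longer reconstruct `hidden_size`. A standalone sketch of the same invariant (the helper name is illustrative, not part of the file):

def check_head_divisibility(hidden_size: int, num_heads: int) -> int:
    # Mirrors the deleted check: head_dim * num_heads must equal hidden_size exactly.
    if hidden_size % num_heads != 0:
        raise ValueError(
            f"hidden_size must be divisible by num_heads (got hidden_size={hidden_size}, num_heads={num_heads})."
        )
    return hidden_size // num_heads

check_head_divisibility(4096, 32)   # returns head_dim = 128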
@@ -312,9 +295,6 @@ class QuietAttention(nn.Module):
             base=self.rope_theta,
         )
 
-    def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
-        return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
-
     def forward(
         self,
         hidden_states: torch.Tensor,
@@ -324,11 +304,7 @@ class QuietAttention(nn.Module):
         output_attentions: bool = False,
         use_cache: bool = False,
         **kwargs,
-    )
-        if "padding_mask" in kwargs:
-            warnings.warn(
-                "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`"
-            )
+    ):
         bsz, q_len, _ = hidden_states.size()
 
         query_states = self.q_proj(hidden_states)
@@ -342,50 +318,31 @@ class QuietAttention(nn.Module):
         kv_seq_len = key_states.shape[-2]
         if past_key_value is not None:
             if self.layer_idx is None:
-                raise ValueError(
-                    f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} "
-                    "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
-                    "with a layer index."
-                )
+                raise ValueError("Layer index must be provided when using past key values.")
             kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
         cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
         query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
 
         if past_key_value is not None:
-            cache_kwargs = {"sin": sin, "cos": cos}
+            cache_kwargs = {"sin": sin, "cos": cos}
             key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
 
-        # repeat k/v heads if n_kv_heads < n_heads
         key_states = repeat_kv(key_states, self.num_key_value_groups)
         value_states = repeat_kv(value_states, self.num_key_value_groups)
 
-        attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
-
-        if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
-            raise ValueError(
-                f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
-                f" {attn_weights.size()}"
-            )
-
-        if attention_mask is not None:
-            if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
-                raise ValueError(
-                    f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
-                )
-
-            attn_weights = attn_weights + attention_mask
+        if attention_mask is None:
+            attention_mask = torch.ones(bsz, 1, q_len, kv_seq_len, device=hidden_states.device)
+        else:
+            attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
+            attention_mask = attention_mask.repeat(1, 1, q_len, 1)
+        attention_mask = torch.triu(attention_mask, diagonal=1)
 
-
+        attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
+        attn_weights = attn_weights + attention_mask
         attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
         attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
         attn_output = torch.matmul(attn_weights, value_states)
 
-        if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
-            raise ValueError(
-                f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
-                f" {attn_output.size()}"
-            )
-
         attn_output = attn_output.transpose(1, 2).contiguous()
         attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
 
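Note: the block added in this hunk builds the mask from `torch.ones` and `torch.triu(..., diagonal=1)` and then adds it to the attention scores. For comparison, the additive-mask convention used by `_prepare_4d_causal_attention_mask` (whose call is removed in the `QuietModel` hunk below) puts a large negative value, not 1, at disallowed positions, so softmax drives their weight to zero. A minimal sketch of that convention, assuming square-or-wider causal masking and no padding; the function name is illustrative:

import torch

def additive_causal_mask(q_len: int, kv_len: int, dtype=torch.float32, device="cpu"):
    # 0 where attention is allowed, a very negative number where it is not,
    # so softmax(scores + mask) sends masked weights to ~0.
    mask = torch.full((q_len, kv_len), torch.finfo(dtype).min, dtype=dtype, device=device)
    mask = torch.triu(mask, diagonal=kv_len - q_len + 1)
    return mask[None, None, :, :]          # (1, 1, q_len, kv_len), broadcasts over batch and heads

scores = torch.randn(2, 32, 5, 5)          # (batch, heads, q_len, kv_len)
scores = scores + additive_causal_mask(5, 5, dtype=scores.dtype)
probs = torch.softmax(scores, dim=-1)      # future positions receive ~0 probability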
@@ -1083,16 +1040,7 @@ class QuietModel(QuietPreTrainedModel):
             inputs_embeds,
             past_key_values_length,
         )
-
-        # 4d mask is passed through the layers
-        attention_mask = _prepare_4d_causal_attention_mask(
-            attention_mask,
-            (batch_size, seq_length),
-            inputs_embeds,
-            past_key_values_length,
-            sliding_window=self.config.sliding_window,
-        )
-
+
        hidden_states = inputs_embeds
 
        # decoder layers
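Note: with the `_prepare_4d_causal_attention_mask` call dropped here, `QuietModel` now forwards the raw 2-D `attention_mask` of shape (batch, kv_len) to the attention layers, which reshape it themselves in the hunk above. For reference, a minimal sketch of how a 2-D padding mask is conventionally expanded into the 4-D additive form that helper produced; this covers only the padding part (the causal and sliding-window terms are omitted), and the function name is illustrative:

import torch

def expand_padding_mask(padding_mask: torch.Tensor, q_len: int, dtype=torch.float32):
    # padding_mask: (batch, kv_len) with 1 for real tokens and 0 for padding.
    # Returns (batch, 1, q_len, kv_len): 0 where attending is allowed, a large
    # negative value where the key position is padding.
    expanded = padding_mask[:, None, None, :].expand(-1, 1, q_len, -1)
    additive = torch.zeros(expanded.shape, dtype=dtype)
    return additive.masked_fill(expanded == 0, torch.finfo(dtype).min)

mask = expand_padding_mask(torch.tensor([[1, 1, 1, 0]]), q_len=4)   # last key position is masked out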