Update modeling_qwen.py, fix logn bug
modeling_qwen.py  CHANGED  (+6, -5)
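For context on the commit title: log-n ("logn") attention scaling multiplies queries at positions beyond the training context length by a logarithmic factor so attention entropy stays stable on longer inputs. The snippet below is a minimal sketch of how such a scaling tensor is typically built and shaped; train_len and max_positions are illustrative names and values, not necessarily the attributes used in modeling_qwen.py.

# Hedged sketch of log-n query scaling (not the exact repository code).
import math
import torch

train_len = 2048        # assumed training context length (e.g. config.seq_length)
max_positions = 32768   # assumed upper bound on precomputed positions

logn_list = [
    math.log(i, train_len) if i > train_len else 1.0
    for i in range(1, max_positions + 1)
]
# Shape (1, max_positions, 1, 1) so it broadcasts over
# (batch, seq, num_heads, head_dim) queries.
logn_tensor = torch.tensor(logn_list)[None, :, None, None]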
@@ -177,7 +177,8 @@ class QWenAttention(nn.Module):
             config.hidden_size, self.projection_size, bias=not config.no_bias
         )
 
-        if self.use_flash_attn and flash_attn_unpadded_func is not None:
+        self.is_fp32 = not(config.bf16 or config.fp16)
+        if self.use_flash_attn and flash_attn_unpadded_func is not None and not self.is_fp32:
             self.core_attention_flash = FlashSelfAttention(
                 causal=True, attention_dropout=config.attn_pdrop
             )
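A minimal sketch of the guard this hunk introduces, assuming the usual reason for it: flash-attn kernels only support fp16/bf16 inputs, so the flash path should be skipped when the model runs in fp32. DummyConfig and should_use_flash are hypothetical stand-ins, not part of the repository.

# Hedged sketch of the fp32 fallback decision.
from dataclasses import dataclass

@dataclass
class DummyConfig:          # hypothetical stand-in for the model config
    bf16: bool = False
    fp16: bool = False

def should_use_flash(config: DummyConfig, use_flash_attn: bool, kernel_available: bool) -> bool:
    is_fp32 = not (config.bf16 or config.fp16)
    return use_flash_attn and kernel_available and not is_fp32

assert should_use_flash(DummyConfig(fp16=True), True, True) is True
assert should_use_flash(DummyConfig(), True, True) is False   # fp32 falls back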
@@ -371,12 +372,12 @@ class QWenAttention(nn.Module):
         if self.use_logn_attn and not self.training:
             if self.logn_tensor.device != query.device:
                 self.logn_tensor = self.logn_tensor.to(query.device).type_as(query)
-            seq_start = key.size(
-            seq_end = key.size(
+            seq_start = key.size(1) - query.size(1)
+            seq_end = key.size(1)
             logn_tensor = self.logn_tensor[:, seq_start:seq_end, :, :]
             query = query * logn_tensor.expand_as(query)
 
-        if self.use_flash_attn and flash_attn_unpadded_func is not None:
+        if self.use_flash_attn and flash_attn_unpadded_func is not None and not self.is_fp32:
             q, k, v = query, key, value
             context_layer = self.core_attention_flash(q, k, v)

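A minimal sketch of what the corrected slice does during incremental decoding, assuming the (batch, seq, num_heads, head_dim) layout used on this code path: with a KV cache, key holds the full history while query holds only the new tokens, so dim 1 of key and query gives the absolute positions whose scaling factors are needed. The shapes below are illustrative.

# Hedged sketch of the logn slice with a KV cache.
import torch

batch, heads, head_dim = 1, 2, 4
logn_tensor = torch.ones(1, 8192, 1, 1)          # precomputed scaling factors

past_len, new_len = 100, 1                       # decoding one new token
query = torch.randn(batch, new_len, heads, head_dim)
key = torch.randn(batch, past_len + new_len, heads, head_dim)

seq_start = key.size(1) - query.size(1)          # 100: first new absolute position
seq_end = key.size(1)                            # 101
scale = logn_tensor[:, seq_start:seq_end, :, :]  # factors for the new tokens only
query = query * scale.expand_as(query)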
@@ -397,7 +398,7 @@ class QWenAttention(nn.Module):
         attn_output = self.c_proj(context_layer)
         outputs = (attn_output, present)
         if output_attentions:
-            if self.use_flash_attn and flash_attn_unpadded_func is not None:
+            if self.use_flash_attn and flash_attn_unpadded_func is not None and not self.is_fp32:
                 raise ValueError("Cannot output attentions while using flash-attn")
             else:
                 outputs += (attn_weight,)
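A small sketch of the behaviour the last hunk guards, with hypothetical names: the flash-attn kernel never materialises the attention matrix, so attention weights can only be returned on the non-flash fallback path (which, after this commit, includes fp32 runs).

# Hedged sketch of the output_attentions decision.
def collect_outputs(attn_output, present, attn_weight, *, output_attentions: bool, used_flash: bool):
    outputs = (attn_output, present)
    if output_attentions:
        if used_flash:
            raise ValueError("Cannot output attentions while using flash-attn")
        outputs += (attn_weight,)
    return outputs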