sync with main
modeling_chatglm.py  CHANGED  (+24 -16)
@@ -40,12 +40,6 @@ logger = logging.get_logger(__name__)
 _CHECKPOINT_FOR_DOC = "THUDM/ChatGLM"
 _CONFIG_FOR_DOC = "ChatGLMConfig"
 
-CHATGLM_6B_PRETRAINED_MODEL_ARCHIVE_LIST = [
-    "THUDM/chatglm3-6b",
-    # See all ChatGLM models at https://huggingface.co/models?filter=chatglm
-]
-
-
 def default_init(cls, *args, **kwargs):
     return cls(*args, **kwargs)
 
@@ -253,15 +247,12 @@ class CoreAttention(torch.nn.Module):
         # This is actually dropping out entire tokens to attend to, which might
         # seem a bit unusual, but is taken from the original Transformer paper.
         attention_probs = self.attention_dropout(attention_probs)
-        # =========================
-        # Context layer. [sq, b, hp]
-        # =========================
-
-        # value_layer -> context layer.
-        # [sk, b, np, hn] --> [b, np, sq, hn]
 
+        # query layer shape: [b * np, sq, hn]
+        # value layer shape: [b, np, sk, hn]
+        # attention shape: [b, np, sq, sk]
         # context layer shape: [b, np, sq, hn]
-        output_size = (value_layer.size(
+        output_size = (value_layer.size(0), value_layer.size(1), query_layer.size(1), value_layer.size(3))
         # change view [b * np, sk, hn]
         value_layer = value_layer.view(output_size[0] * output_size[1], value_layer.size(2), -1)
         # change view [b * np, sq, sk]
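
The shape bookkeeping in this hunk is easier to check with concrete tensors. A minimal sketch, assuming the shapes named in the new comments (query [b * np, sq, hn], value [b, np, sk, hn]) and hypothetical sizes:

```python
import torch

# Hypothetical sizes for illustration only.
b, np_, sq, sk, hn = 2, 4, 5, 7, 16
query_layer = torch.randn(b * np_, sq, hn)   # [b * np, sq, hn], per the new comment
value_layer = torch.randn(b, np_, sk, hn)    # [b, np, sk, hn]

# Mirrors the new output_size line: resolves to (b, np, sq, hn).
output_size = (value_layer.size(0), value_layer.size(1), query_layer.size(1), value_layer.size(3))

# change view [b * np, sk, hn], as in the unchanged context line
value_view = value_layer.view(output_size[0] * output_size[1], value_layer.size(2), -1)
print(output_size, value_view.shape)  # (2, 4, 5, 16) torch.Size([8, 7, 16])
```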
@@ -386,7 +377,10 @@ class SelfAttention(torch.nn.Module):
             key_layer = torch.cat((cache_k, key_layer), dim=2)
             value_layer = torch.cat((cache_v, value_layer), dim=2)
         if use_cache:
-            kv_cache = (key_layer, value_layer)
+            if kv_cache is None:
+                kv_cache = torch.cat((key_layer.unsqueeze(0).unsqueeze(0), value_layer.unsqueeze(0).unsqueeze(0)), dim=1)
+            else:
+                kv_cache = (key_layer, value_layer)
         else:
             kv_cache = None
 
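
The new branch keeps two cache formats: a single packed tensor on the prefill step (when no cache exists yet) and the usual (key, value) tuple on later decoding steps. A minimal sketch with dummy tensors, assuming the [b, np, sk, hn] layout and hypothetical sizes:

```python
import torch

b, np_, sk, hn = 2, 4, 8, 16   # hypothetical sizes
key_layer = torch.randn(b, np_, sk, hn)
value_layer = torch.randn(b, np_, sk, hn)

# Prefill path (kv_cache is None): pack K and V into one tensor,
# [1, 1, b, np, sk, hn] each, concatenated to [1, 2, b, np, sk, hn].
packed = torch.cat(
    (key_layer.unsqueeze(0).unsqueeze(0), value_layer.unsqueeze(0).unsqueeze(0)),
    dim=1,
)
print(packed.shape)  # torch.Size([1, 2, 2, 4, 8, 16])

# Decode path (a cache is already present): keep the familiar tuple format.
kv_cache = (key_layer, value_layer)
```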
@@ -605,7 +599,7 @@ class GLMTransformer(torch.nn.Module):
                     hidden_states,
                     attention_mask=attention_mask,
                     rotary_pos_emb=rotary_pos_emb,
-
+                    kv_caches=kv_caches[index],
                     use_cache=use_cache,
                     use_reentrant=False
                 )
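
The surrounding use_reentrant=False suggests this is a torch.utils.checkpoint.checkpoint call; keyword arguments like the added kv_caches=... are only forwarded to the wrapped callable under non-reentrant checkpointing. A hedged sketch with a stand-in function (not the actual GLM block):

```python
import torch
from torch.utils.checkpoint import checkpoint

def layer(hidden_states, scale=1.0):
    # Stand-in for a transformer block; any differentiable callable works.
    return hidden_states * scale

x = torch.randn(2, 4, requires_grad=True)
# Keyword arguments are supported only with use_reentrant=False.
out = checkpoint(layer, x, scale=2.0, use_reentrant=False)
out.sum().backward()
print(x.grad.shape)  # torch.Size([2, 4])
```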
@@ -619,7 +613,15 @@ class GLMTransformer(torch.nn.Module):
                 )
             hidden_states, kv_cache = layer_ret
             if use_cache:
-                presents = presents + (kv_cache,)
+                # token by token decoding, use tuple format
+                if kv_caches[0] is not None:
+                    presents = presents + (kv_cache,)
+                # prefilling in decoding, use tensor format to save cuda memory
+                else:
+                    if len(presents) == 0:
+                        presents = kv_cache
+                    else:
+                        presents = torch.cat((presents, kv_cache.to(presents.device)), dim=0)
 
             if output_hidden_states:
                 all_hidden_states = all_hidden_states + (hidden_states,)
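
On the prefill path (kv_caches[0] is None), each layer's packed [1, 2, ...] cache block is concatenated along dim 0, so presents ends up as a single tensor of shape [num_layers, 2, b, np, sk, hn] rather than a tuple of per-layer tuples. A minimal sketch of that accumulation with dummy tensors and hypothetical sizes:

```python
import torch

num_layers, b, np_, sk, hn = 3, 2, 4, 8, 16   # hypothetical sizes
presents = ()                                  # starts empty, as in the loop
for _ in range(num_layers):
    kv_cache = torch.randn(1, 2, b, np_, sk, hn)   # stand-in for one layer's packed cache
    if len(presents) == 0:
        presents = kv_cache
    else:
        presents = torch.cat((presents, kv_cache.to(presents.device)), dim=0)
print(presents.shape)  # torch.Size([3, 2, 2, 4, 8, 16])
```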
@@ -773,6 +775,12 @@ class ChatGLMModel(ChatGLMPreTrainedModel):
             inputs_embeds, full_attention_mask, rotary_pos_emb=rotary_pos_emb,
             kv_caches=past_key_values, use_cache=use_cache, output_hidden_states=output_hidden_states
         )
+        if presents is not None and type(presents) is torch.Tensor:
+            presents = presents.split(1, dim=0)
+            presents = list(presents)
+            presents = [list(x.squeeze(0).split(1, dim=0)) for x in presents]
+            presents = [tuple([x.squeeze(0) for x in y]) for y in presents]
+            presents = tuple(presents)
 
         if not return_dict:
             return tuple(v for v in [hidden_states, presents, all_hidden_states, all_self_attentions] if v is not None)
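
The new block converts the stacked prefill cache back into the nested tuple layout that downstream generation code expects. A standalone sketch of the same unpacking, assuming a [num_layers, 2, b, np, sk, hn] tensor and hypothetical sizes:

```python
import torch

num_layers, b, np_, sk, hn = 3, 2, 4, 8, 16   # hypothetical sizes
presents = torch.randn(num_layers, 2, b, np_, sk, hn)

if presents is not None and type(presents) is torch.Tensor:
    presents = presents.split(1, dim=0)                                # one chunk per layer
    presents = list(presents)
    presents = [list(x.squeeze(0).split(1, dim=0)) for x in presents]  # split each layer into K and V
    presents = [tuple([x.squeeze(0) for x in y]) for y in presents]    # drop the packing dims
    presents = tuple(presents)

print(len(presents), presents[0][0].shape)  # 3 torch.Size([2, 4, 8, 16])
```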