davidlvxin
commited on
Commit
·
1ed39ad
1
Parent(s):
2990a09
Optimize the storage of KV cache
Browse files- README.md +5 -0
- modeling_chatglm.py +21 -8
README.md
CHANGED
@@ -15,8 +15,13 @@ tags:
|
|
15 |
<p align="center">
|
16 |
👋 Join our <a href="https://join.slack.com/t/chatglm/shared_invite/zt-1y7pqoloy-9b1g6T6JjA8J0KxvUjbwJw" target="_blank">Slack</a> and <a href="https://github.com/THUDM/ChatGLM-6B/blob/main/resources/WECHAT.md" target="_blank">WeChat</a>
|
17 |
</p>
|
|
|
|
|
|
|
|
|
18 |
|
19 |
## 介绍
|
|
|
20 |
ChatGLM**2**-6B-32K在[ChatGLM2-6B](https://huggingface.co/THUDM/chatglm2-6b)的基础上进一步强化了对于长文本的理解能力,能够更好的处理最多32K长度的上下文。具体地,我们基于[位置插值](https://arxiv.org/abs/2306.15595)(Positional Interpolation)的方法对位置编码进行了更新,并在对话阶段使用 32K 的上下文长度训练。在实际的使用中,如果您面临的上下文长度基本在 **8K 以内**,我们推荐使用[ChatGLM2-6B](https://huggingface.co/THUDM/chatglm2-6b);如果您需要处理**超过 8K** 的上下文长度,我们推荐使用ChatGLM2-6B-32K。
|
21 |
|
22 |
ChatGLM**2**-6B-32K是开源中英双语对话模型 [ChatGLM2-6B](https://github.com/THUDM/ChatGLM2-6B) 的加长版本,在保留了初代模型对话流畅、部署门槛较低等众多优秀特性的基础之上,ChatGLM**2**-6B-32k 引入了如下新特性:
|
|
|
15 |
<p align="center">
|
16 |
👋 Join our <a href="https://join.slack.com/t/chatglm/shared_invite/zt-1y7pqoloy-9b1g6T6JjA8J0KxvUjbwJw" target="_blank">Slack</a> and <a href="https://github.com/THUDM/ChatGLM-6B/blob/main/resources/WECHAT.md" target="_blank">WeChat</a>
|
17 |
</p>
|
18 |
+
## 更新/Update
|
19 |
+
|
20 |
+
- 我们优化了KV Cache的存储方式,减少了显存碎片的产生。基于优化后的代码,模型可以在约**20G显存**的情况下处理32K长度的上下文(FP/BF16格式)。
|
21 |
+
- We have optimized the storage method of the KV Cache, reducing the generation of memory fragmentation. Based on the optimized code, the model can process a context length of 32K under approximately **20G** of memory (FP/BF16 format).
|
22 |
|
23 |
## 介绍
|
24 |
+
|
25 |
ChatGLM**2**-6B-32K在[ChatGLM2-6B](https://huggingface.co/THUDM/chatglm2-6b)的基础上进一步强化了对于长文本的理解能力,能够更好的处理最多32K长度的上下文。具体地,我们基于[位置插值](https://arxiv.org/abs/2306.15595)(Positional Interpolation)的方法对位置编码进行了更新,并在对话阶段使用 32K 的上下文长度训练。在实际的使用中,如果您面临的上下文长度基本在 **8K 以内**,我们推荐使用[ChatGLM2-6B](https://huggingface.co/THUDM/chatglm2-6b);如果您需要处理**超过 8K** 的上下文长度,我们推荐使用ChatGLM2-6B-32K。
|
26 |
|
27 |
ChatGLM**2**-6B-32K是开源中英双语对话模型 [ChatGLM2-6B](https://github.com/THUDM/ChatGLM2-6B) 的加长版本,在保留了初代模型对话流畅、部署门槛较低等众多优秀特性的基础之上,ChatGLM**2**-6B-32k 引入了如下新特性:
|
modeling_chatglm.py
CHANGED
@@ -413,7 +413,10 @@ class SelfAttention(torch.nn.Module):
|
|
413 |
key_layer = torch.cat((cache_k, key_layer), dim=0)
|
414 |
value_layer = torch.cat((cache_v, value_layer), dim=0)
|
415 |
if use_cache:
|
416 |
-
kv_cache
|
|
|
|
|
|
|
417 |
else:
|
418 |
kv_cache = None
|
419 |
|
@@ -612,12 +615,8 @@ class GLMTransformer(torch.nn.Module):
|
|
612 |
if not kv_caches:
|
613 |
kv_caches = [None for _ in range(self.num_layers)]
|
614 |
presents = () if use_cache else None
|
615 |
-
if self.
|
616 |
-
|
617 |
-
logger.warning_once(
|
618 |
-
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
|
619 |
-
)
|
620 |
-
use_cache = False
|
621 |
|
622 |
all_self_attentions = None
|
623 |
all_hidden_states = () if output_hidden_states else None
|
@@ -645,7 +644,15 @@ class GLMTransformer(torch.nn.Module):
|
|
645 |
)
|
646 |
hidden_states, kv_cache = layer_ret
|
647 |
if use_cache:
|
648 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
649 |
|
650 |
if output_hidden_states:
|
651 |
all_hidden_states = all_hidden_states + (hidden_states,)
|
@@ -830,6 +837,12 @@ class ChatGLMModel(ChatGLMPreTrainedModel):
|
|
830 |
inputs_embeds, full_attention_mask, rotary_pos_emb=rotary_pos_emb,
|
831 |
kv_caches=past_key_values, use_cache=use_cache, output_hidden_states=output_hidden_states
|
832 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
833 |
|
834 |
if not return_dict:
|
835 |
return tuple(v for v in [hidden_states, presents, all_hidden_states, all_self_attentions] if v is not None)
|
|
|
413 |
key_layer = torch.cat((cache_k, key_layer), dim=0)
|
414 |
value_layer = torch.cat((cache_v, value_layer), dim=0)
|
415 |
if use_cache:
|
416 |
+
if kv_cache is None:
|
417 |
+
kv_cache = torch.cat((key_layer.unsqueeze(0).unsqueeze(0), value_layer.unsqueeze(0).unsqueeze(0)), dim=1)
|
418 |
+
else:
|
419 |
+
kv_cache = (key_layer, value_layer)
|
420 |
else:
|
421 |
kv_cache = None
|
422 |
|
|
|
615 |
if not kv_caches:
|
616 |
kv_caches = [None for _ in range(self.num_layers)]
|
617 |
presents = () if use_cache else None
|
618 |
+
if self.training:
|
619 |
+
use_cache = False
|
|
|
|
|
|
|
|
|
620 |
|
621 |
all_self_attentions = None
|
622 |
all_hidden_states = () if output_hidden_states else None
|
|
|
644 |
)
|
645 |
hidden_states, kv_cache = layer_ret
|
646 |
if use_cache:
|
647 |
+
# token by token decoding, use tuple format
|
648 |
+
if kv_caches[0] is not None:
|
649 |
+
presents = presents + (kv_cache,)
|
650 |
+
# prefilling in decoding, use tensor format to save cuda memory
|
651 |
+
else:
|
652 |
+
if len(presents) == 0:
|
653 |
+
presents = kv_cache
|
654 |
+
else:
|
655 |
+
presents = torch.cat((presents, kv_cache), dim=0)
|
656 |
|
657 |
if output_hidden_states:
|
658 |
all_hidden_states = all_hidden_states + (hidden_states,)
|
|
|
837 |
inputs_embeds, full_attention_mask, rotary_pos_emb=rotary_pos_emb,
|
838 |
kv_caches=past_key_values, use_cache=use_cache, output_hidden_states=output_hidden_states
|
839 |
)
|
840 |
+
if presents is not None and type(presents) is torch.Tensor:
|
841 |
+
presents = presents.split(1, dim=0)
|
842 |
+
presents = list(presents)
|
843 |
+
presents = [list(x.squeeze(0).split(1, dim=0)) for x in presents]
|
844 |
+
presents = [tuple([x.squeeze(0) for x in y]) for y in presents]
|
845 |
+
presents = tuple(presents)
|
846 |
|
847 |
if not return_dict:
|
848 |
return tuple(v for v in [hidden_states, presents, all_hidden_states, all_self_attentions] if v is not None)
|