[WIP] Upload folder using huggingface_hub (multi-commit 8db452b023d5c93b129a158fc21edde110d13c36a45f9ea385e43662f0f9a2ac)

#1
by bys0318 - opened
README.md DELETED
@@ -1,79 +0,0 @@
1
- ---
2
- language:
3
- - en
4
- - zh
5
- library_name: transformers
6
- tags:
7
- - Long Context
8
- - chatglm
9
- datasets:
10
- - THUDM/LongAlign-10k
11
- pipeline_tag: text-generation
12
- license: apache-2.0
13
- ---
14
- # LongAlign-6B-64k
15
-
16
- <p align="center">
17
- 🤗 <a href="https://huggingface.co/datasets/THUDM/LongAlign-10k" target="_blank">[LongAlign Dataset] </a> • 💻 <a href="https://github.com/THUDM/LongAlign" target="_blank">[Github Repo]</a> • 📃 <a href="https://arxiv.org/abs/2401.18058" target="_blank">[LongAlign Paper]</a>
18
- </p>
19
-
20
- **LongAlign** is the first full recipe for LLM alignment on long context. We propose the **LongAlign-10k** dataset, containing 10,000 long instruction data of 8k-64k in length. We investigate on trianing strategies, namely **packing (with loss weighting) and sorted batching**, which are all implemented in our code. For real-world long context evaluation, we introduce **LongBench-Chat** that evaluate the instruction-following capability on queries of 10k-100k length.
21
-
22
- ## All Models
23
-
24
- We open-sourced the following list of models:
25
-
26
- |Model|Huggingface Repo|Description|
27
- |---|---|---|
28
- |**LongAlign-6B-64k-base**| [🤗 Huggingface Repo](https://huggingface.co/THUDM/LongAlign-6B-64k-base) | **ChatGLM3-6B** with an extended 64k context window |
29
- |**LongAlign-6B-64k**| [🤗 Huggingface Repo](https://huggingface.co/THUDM/LongAlign-6B-64k) | Chat model by LongAlign training on LongAlign-6B-64k-base|
30
- |**LongAlign-7B-64k-base**| [🤗 Huggingface Repo](https://huggingface.co/THUDM/LongAlign-7B-64k-base) | **Llama-2-7B** with an extended 64k context window |
31
- |**LongAlign-7B-64k**| [🤗 Huggingface Repo](https://huggingface.co/THUDM/LongAlign-7B-64k) | Chat model by LongAlign training on LongAlign-7B-64k-base|
32
- |**LongAlign-13B-64k-base**| [🤗 Huggingface Repo](https://huggingface.co/THUDM/LongAlign-13B-64k-base) | **Llama-2-13B** with an extended 64k context window |
33
- |**LongAlign-13B-64k**| [🤗 Huggingface Repo](https://huggingface.co/THUDM/LongAlign-13B-64k) | Chat model by LongAlign training on LongAlign-13B-64k-base|
34
- |**ChatGLM3-6B-128k**| [🤗 Huggingface Repo](https://huggingface.co/THUDM/chatglm3-6b-128k) | **ChatGLM3-6B** with a 128k context window|
35
-
36
- ![](assets/leaderboard.png)
37
-
38
- ## Model usage
39
- Chat prompt template for LongAlign-6B-64k:
40
- ```text
41
- [Round 1]
42
-
43
- 问:Hi!
44
-
45
- 答:Hello! What can I assist you today?
46
-
47
- [Round 2]
48
-
49
- 问:What should I do if I can't sleep at night?
50
-
51
- 答:
52
- ```
53
- Chat prompt template for LongAlign-7B-64k and LongAlign-13B-64k:
54
- ```text
55
- [INST]Hi![/INST]Hello! What can I assist you today?
56
-
57
- [INST]What should I do if I can't sleep at night?[/INST]
58
- ```
59
- ChatGLM3-6B-128k uses the same prompt template as [ChatGLM3-6B](https://huggingface.co/THUDM/chatglm3-6b).
60
-
61
- A simple demo for deployment of the model:
62
- ```python
63
- from transformers import AutoTokenizer, AutoModelForCausalLM
64
- import torch
65
- tokenizer = AutoTokenizer.from_pretrained("THUDM/LongAlign-6B-64k", trust_remote_code=True)
66
- model = AutoModelForCausalLM.from_pretrained("THUDM/LongAlign-6B-64k", torch_dtype=torch.bfloat16, trust_remote_code=True, device_map="auto")
67
- model = model.eval()
68
- query = open("assets/paper.txt").read() + "\n\nPlease summarize the paper."
69
- response, history = model.chat(tokenizer, query, history=[], max_new_tokens=512, temperature=1)
70
- print(response)
71
- ```
72
-
73
- ## Citation
74
-
75
- If you find our work useful, please consider citing LongAlign:
76
-
77
- ```
78
-
79
- ```
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
assets/leaderboard.png DELETED
Binary file (414 kB)
 
config.json DELETED
@@ -1,48 +0,0 @@
1
- {
2
- "_name_or_path": "THUDM/LongAlign-6B-64k",
3
- "add_bias_linear": false,
4
- "add_qkv_bias": true,
5
- "apply_query_key_layer_scaling": true,
6
- "apply_residual_connection_post_layernorm": false,
7
- "architectures": [
8
- "ChatGLMForConditionalGeneration"
9
- ],
10
- "attention_dropout": 0.0,
11
- "attention_softmax_in_fp32": true,
12
- "auto_map": {
13
- "AutoConfig": "configuration_chatglm.ChatGLMConfig",
14
- "AutoModel": "modeling_chatglm.ChatGLMForConditionalGeneration",
15
- "AutoModelForCausalLM": "modeling_chatglm.ChatGLMForConditionalGeneration",
16
- "AutoModelForSeq2SeqLM": "modeling_chatglm.ChatGLMForConditionalGeneration",
17
- "AutoModelForSequenceClassification": "modeling_chatglm.ChatGLMForSequenceClassification"
18
- },
19
- "bias_dropout_fusion": true,
20
- "classifier_dropout": null,
21
- "eos_token_id": 2,
22
- "ffn_hidden_size": 13696,
23
- "fp32_residual_connection": false,
24
- "hidden_dropout": 0.0,
25
- "hidden_size": 4096,
26
- "kv_channels": 128,
27
- "layernorm_epsilon": 1e-05,
28
- "model_type": "chatglm",
29
- "multi_query_attention": true,
30
- "multi_query_group_num": 2,
31
- "num_attention_heads": 32,
32
- "num_layers": 28,
33
- "original_rope": true,
34
- "pad_token_id": 0,
35
- "padded_vocab_size": 65024,
36
- "post_layer_norm": true,
37
- "pre_seq_len": null,
38
- "prefix_projection": false,
39
- "quantization_bit": 0,
40
- "rmsnorm": true,
41
- "rope_ratio": 200,
42
- "seq_length": 65536,
43
- "tie_word_embeddings": false,
44
- "torch_dtype": "bfloat16",
45
- "transformers_version": "4.33.0",
46
- "use_cache": true,
47
- "vocab_size": 65024
48
- }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
configuration_chatglm.py DELETED
@@ -1,63 +0,0 @@
1
- from transformers import PretrainedConfig
2
-
3
-
4
- class ChatGLMConfig(PretrainedConfig):
5
- model_type = "chatglm"
6
- def __init__(
7
- self,
8
- num_layers=28,
9
- padded_vocab_size=65024,
10
- hidden_size=4096,
11
- ffn_hidden_size=13696,
12
- kv_channels=128,
13
- num_attention_heads=32,
14
- seq_length=2048,
15
- hidden_dropout=0.0,
16
- classifier_dropout=None,
17
- attention_dropout=0.0,
18
- layernorm_epsilon=1e-5,
19
- rope_ratio=1,
20
- rmsnorm=True,
21
- apply_residual_connection_post_layernorm=False,
22
- post_layer_norm=True,
23
- add_bias_linear=False,
24
- add_qkv_bias=False,
25
- bias_dropout_fusion=True,
26
- multi_query_attention=False,
27
- multi_query_group_num=1,
28
- apply_query_key_layer_scaling=True,
29
- attention_softmax_in_fp32=True,
30
- fp32_residual_connection=False,
31
- quantization_bit=0,
32
- pre_seq_len=None,
33
- prefix_projection=False,
34
- **kwargs
35
- ):
36
- self.num_layers = num_layers
37
- self.vocab_size = padded_vocab_size
38
- self.padded_vocab_size = padded_vocab_size
39
- self.hidden_size = hidden_size
40
- self.ffn_hidden_size = ffn_hidden_size
41
- self.kv_channels = kv_channels
42
- self.num_attention_heads = num_attention_heads
43
- self.seq_length = seq_length
44
- self.hidden_dropout = hidden_dropout
45
- self.classifier_dropout = classifier_dropout
46
- self.attention_dropout = attention_dropout
47
- self.layernorm_epsilon = layernorm_epsilon
48
- self.rope_ratio = rope_ratio
49
- self.rmsnorm = rmsnorm
50
- self.apply_residual_connection_post_layernorm = apply_residual_connection_post_layernorm
51
- self.post_layer_norm = post_layer_norm
52
- self.add_bias_linear = add_bias_linear
53
- self.add_qkv_bias = add_qkv_bias
54
- self.bias_dropout_fusion = bias_dropout_fusion
55
- self.multi_query_attention = multi_query_attention
56
- self.multi_query_group_num = multi_query_group_num
57
- self.apply_query_key_layer_scaling = apply_query_key_layer_scaling
58
- self.attention_softmax_in_fp32 = attention_softmax_in_fp32
59
- self.fp32_residual_connection = fp32_residual_connection
60
- self.quantization_bit = quantization_bit
61
- self.pre_seq_len = pre_seq_len
62
- self.prefix_projection = prefix_projection
63
- super().__init__(**kwargs)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
generation_config.json DELETED
@@ -1,6 +0,0 @@
1
- {
2
- "_from_model_config": true,
3
- "eos_token_id": 2,
4
- "pad_token_id": 0,
5
- "transformers_version": "4.33.0"
6
- }
 
 
 
 
 
 
 
modeling_chatglm.py DELETED
@@ -1,1138 +0,0 @@
1
- """ PyTorch ChatGLM model. """
2
-
3
- import math
4
- import copy
5
- import warnings
6
- import re
7
- import sys
8
-
9
- import torch
10
- import torch.utils.checkpoint
11
- import torch.nn.functional as F
12
- from torch import nn
13
- from torch.nn import CrossEntropyLoss, LayerNorm
14
- from torch.nn.utils import skip_init
15
- from typing import Optional, Tuple, Union, List, Callable, Dict, Any
16
-
17
- from transformers.modeling_outputs import (
18
- BaseModelOutputWithPast,
19
- CausalLMOutputWithPast,
20
- )
21
- from transformers.modeling_utils import PreTrainedModel
22
- from transformers.utils import logging
23
- from transformers.generation.logits_process import LogitsProcessor
24
- from transformers.generation.utils import LogitsProcessorList, StoppingCriteriaList, GenerationConfig, ModelOutput
25
-
26
- from .configuration_chatglm import ChatGLMConfig
27
- from einops import rearrange
28
- try:
29
- from flash_attn.flash_attn_interface import flash_attn_unpadded_func
30
- except ImportError:
31
- try:
32
- # FlashAttention-2
33
- from flash_attn.flash_attn_interface import flash_attn_varlen_func as flash_attn_unpadded_func
34
- except ImportError:
35
- flash_attn_unpadded_func = None
36
-
37
- # flags required to enable jit fusion kernels
38
-
39
- if sys.platform != 'darwin':
40
- torch._C._jit_set_profiling_mode(False)
41
- torch._C._jit_set_profiling_executor(False)
42
- torch._C._jit_override_can_fuse_on_cpu(True)
43
- torch._C._jit_override_can_fuse_on_gpu(True)
44
-
45
- logger = logging.get_logger(__name__)
46
-
47
- _CHECKPOINT_FOR_DOC = "THUDM/ChatGLM2-6B"
48
- _CONFIG_FOR_DOC = "ChatGLM6BConfig"
49
-
50
- CHATGLM_6B_PRETRAINED_MODEL_ARCHIVE_LIST = [
51
- "THUDM/chatglm2-6b",
52
- # See all ChatGLM models at https://huggingface.co/models?filter=chatglm
53
- ]
54
-
55
- def default_init(cls, *args, **kwargs):
56
- return cls(*args, **kwargs)
57
-
58
-
59
- class InvalidScoreLogitsProcessor(LogitsProcessor):
60
- def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
61
- if torch.isnan(scores).any() or torch.isinf(scores).any():
62
- scores.zero_()
63
- scores[..., 5] = 5e4
64
- return scores
65
-
66
-
67
- class PrefixEncoder(torch.nn.Module):
68
- """
69
- The torch.nn model to encode the prefix
70
- Input shape: (batch-size, prefix-length)
71
- Output shape: (batch-size, prefix-length, 2*layers*hidden)
72
- """
73
-
74
- def __init__(self, config):
75
- super().__init__()
76
- self.prefix_projection = config.prefix_projection
77
- if self.prefix_projection:
78
- # Use a two-layer MLP to encode the prefix
79
- self.embedding = torch.nn.Embedding(config.pre_seq_len, config.hidden_size)
80
- self.trans = torch.nn.Sequential(
81
- torch.nn.Linear(config.hidden_size, config.hidden_size),
82
- torch.nn.Tanh(),
83
- torch.nn.Linear(config.hidden_size, config.num_layers * config.hidden_size * 2)
84
- )
85
- else:
86
- self.embedding = torch.nn.Embedding(config.pre_seq_len, config.num_layers * config.hidden_size * 2)
87
-
88
- def forward(self, prefix: torch.Tensor):
89
- if self.prefix_projection:
90
- prefix_tokens = self.embedding(prefix)
91
- past_key_values = self.trans(prefix_tokens)
92
- else:
93
- past_key_values = self.embedding(prefix)
94
- return past_key_values
95
-
96
-
97
- def split_tensor_along_last_dim(
98
- tensor: torch.Tensor,
99
- num_partitions: int,
100
- contiguous_split_chunks: bool = False,
101
- ) -> List[torch.Tensor]:
102
- """Split a tensor along its last dimension.
103
-
104
- Arguments:
105
- tensor: input tensor.
106
- num_partitions: number of partitions to split the tensor
107
- contiguous_split_chunks: If True, make each chunk contiguous
108
- in memory.
109
-
110
- Returns:
111
- A list of Tensors
112
- """
113
- # Get the size and dimension.
114
- last_dim = tensor.dim() - 1
115
- last_dim_size = tensor.size()[last_dim] // num_partitions
116
- # Split.
117
- tensor_list = torch.split(tensor, last_dim_size, dim=last_dim)
118
- # Note: torch.split does not create contiguous tensors by default.
119
- if contiguous_split_chunks:
120
- return tuple(chunk.contiguous() for chunk in tensor_list)
121
-
122
- return tensor_list
123
-
124
-
125
- class RotaryEmbedding(nn.Module):
126
- def __init__(self, dim, original_impl=False, device=None, dtype=None):
127
- super().__init__()
128
- inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2, device=device).to(dtype=dtype) / dim))
129
- self.register_buffer("inv_freq", inv_freq)
130
- self.dim = dim
131
- self.original_impl = original_impl
132
- self.ratio = 200
133
-
134
- def forward_impl(
135
- self, seq_len: int, n_elem: int, dtype: torch.dtype, device: torch.device, base: int = 10000
136
- ):
137
- """Enhanced Transformer with Rotary Position Embedding.
138
-
139
- Derived from: https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/labml_nn/
140
- transformers/rope/__init__.py. MIT License:
141
- https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/license.
142
- """
143
- # $\Theta = {\theta_i = 10000^{\frac{2(i-1)}{d}}, i \in [1, 2, ..., \frac{d}{2}]}$
144
-
145
- base = base * self.ratio
146
- theta = 1.0 / (base ** (torch.arange(0, n_elem, 2, dtype=torch.float, device=device) / n_elem))
147
-
148
- # Create position indexes `[0, 1, ..., seq_len - 1]`
149
- seq_idx = torch.arange(seq_len, dtype=torch.float, device=device)
150
-
151
- # Calculate the product of position index and $\theta_i$
152
- idx_theta = torch.outer(seq_idx, theta).float()
153
-
154
- cache = torch.stack([torch.cos(idx_theta), torch.sin(idx_theta)], dim=-1)
155
-
156
- # this is to mimic the behaviour of complex32, else we will get different results
157
- if dtype in (torch.float16, torch.bfloat16, torch.int8):
158
- cache = cache.bfloat16() if dtype == torch.bfloat16 else cache.half()
159
- return cache
160
-
161
- def forward(self, max_seq_len, offset=0):
162
- return self.forward_impl(
163
- max_seq_len, self.dim, dtype=self.inv_freq.dtype, device=self.inv_freq.device
164
- )
165
-
166
-
167
- @torch.jit.script
168
- def apply_rotary_pos_emb(x: torch.Tensor, rope_cache: torch.Tensor) -> torch.Tensor:
169
- # x: [sq, b, np, hn]
170
- sq, b, np, hn = x.size(0), x.size(1), x.size(2), x.size(3)
171
- rot_dim = rope_cache.shape[-2] * 2
172
- x, x_pass = x[..., :rot_dim], x[..., rot_dim:]
173
- # truncate to support variable sizes
174
- rope_cache = rope_cache[:sq]
175
- xshaped = x.reshape(sq, -1, np, rot_dim // 2, 2)
176
- rope_cache = rope_cache.view(sq, -1, 1, xshaped.size(3), 2)
177
- x_out2 = torch.stack(
178
- [
179
- xshaped[..., 0] * rope_cache[..., 0] - xshaped[..., 1] * rope_cache[..., 1],
180
- xshaped[..., 1] * rope_cache[..., 0] + xshaped[..., 0] * rope_cache[..., 1],
181
- ],
182
- -1,
183
- )
184
- x_out2 = x_out2.flatten(3)
185
- return torch.cat((x_out2, x_pass), dim=-1)
186
-
187
-
188
- class RMSNorm(torch.nn.Module):
189
- def __init__(self, normalized_shape, eps=1e-5, device=None, dtype=None, **kwargs):
190
- super().__init__()
191
- self.weight = torch.nn.Parameter(torch.empty(normalized_shape, device=device, dtype=dtype))
192
- self.eps = eps
193
-
194
- def forward(self, hidden_states: torch.Tensor):
195
- input_dtype = hidden_states.dtype
196
- variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
197
- hidden_states = hidden_states * torch.rsqrt(variance + self.eps)
198
-
199
- return (self.weight * hidden_states).to(input_dtype)
200
-
201
-
202
- class CoreAttention(torch.nn.Module):
203
- def __init__(self, config: ChatGLMConfig, layer_number):
204
- super(CoreAttention, self).__init__()
205
-
206
- self.apply_query_key_layer_scaling = config.apply_query_key_layer_scaling
207
- self.attention_softmax_in_fp32 = config.attention_softmax_in_fp32
208
- if self.apply_query_key_layer_scaling:
209
- self.attention_softmax_in_fp32 = True
210
- self.layer_number = max(1, layer_number)
211
-
212
- projection_size = config.kv_channels * config.num_attention_heads
213
-
214
- # Per attention head and per partition values.
215
- self.hidden_size_per_partition = projection_size
216
- self.hidden_size_per_attention_head = projection_size // config.num_attention_heads
217
- self.num_attention_heads_per_partition = config.num_attention_heads
218
-
219
- self.norm_factor = math.sqrt(self.hidden_size_per_attention_head)
220
- self.attention_dropout = config.attention_dropout
221
-
222
- def forward(self, query_layer, key_layer, value_layer, attention_mask):
223
- seqlen_q, batch_size = query_layer.shape[0], query_layer.shape[1]
224
- seqlen_k = key_layer.shape[0]
225
- query_layer, key_layer, value_layer = [rearrange(x, 's b ... -> (b s) ...') for x in [query_layer, key_layer, value_layer]]
226
- # DO flash_attn_varlen_func
227
- if attention_mask is None or attention_mask.ndim != 1:
228
- cu_seqlens_q = torch.arange(0, (batch_size + 1) * seqlen_q, step=seqlen_q, dtype=torch.int32,
229
- device=query_layer.device)
230
- else:
231
- assert seqlen_q == seqlen_k
232
- cu_seqlens_q = attention_mask
233
- if self.training:
234
- assert seqlen_k == seqlen_q
235
- is_causal = True
236
- cu_seqlens_k = cu_seqlens_q
237
- else:
238
- is_causal = seqlen_q == seqlen_k
239
- cu_seqlens_k = torch.arange(0, (batch_size + 1) * seqlen_k, step=seqlen_k, dtype=torch.int32,
240
- device=query_layer.device) if not is_causal else cu_seqlens_q
241
- self.attention_dropout = 0
242
- context_layer = flash_attn_unpadded_func(
243
- query_layer, key_layer, value_layer, cu_seqlens_q, cu_seqlens_k, seqlen_q, seqlen_k,
244
- self.attention_dropout,
245
- softmax_scale=1.0 / self.norm_factor, causal=is_causal
246
- )
247
- context_layer = rearrange(context_layer, '(b s) ... -> s b ...', b=batch_size)
248
- new_context_layer_shape = context_layer.size()[:-2] + (self.hidden_size_per_partition,)
249
- context_layer = context_layer.reshape(*new_context_layer_shape)
250
- return context_layer
251
-
252
-
253
- class SelfAttention(torch.nn.Module):
254
- """Parallel self-attention layer abstract class.
255
-
256
- Self-attention layer takes input with size [s, b, h]
257
- and returns output of the same size.
258
- """
259
-
260
- def __init__(self, config: ChatGLMConfig, layer_number, device=None):
261
- super(SelfAttention, self).__init__()
262
- self.layer_number = max(1, layer_number)
263
-
264
- self.projection_size = config.kv_channels * config.num_attention_heads
265
-
266
- # Per attention head and per partition values.
267
- self.hidden_size_per_attention_head = self.projection_size // config.num_attention_heads
268
- self.num_attention_heads_per_partition = config.num_attention_heads
269
-
270
- self.multi_query_attention = config.multi_query_attention
271
- self.qkv_hidden_size = 3 * self.projection_size
272
- if self.multi_query_attention:
273
- self.num_multi_query_groups_per_partition = config.multi_query_group_num
274
- self.qkv_hidden_size = (
275
- self.projection_size + 2 * self.hidden_size_per_attention_head * config.multi_query_group_num
276
- )
277
- self.query_key_value = nn.Linear(config.hidden_size, self.qkv_hidden_size,
278
- bias=config.add_bias_linear or config.add_qkv_bias,
279
- device=device, **_config_to_kwargs(config)
280
- )
281
-
282
- self.core_attention = CoreAttention(config, self.layer_number)
283
-
284
- # Output.
285
- self.dense = nn.Linear(self.projection_size, config.hidden_size, bias=config.add_bias_linear,
286
- device=device, **_config_to_kwargs(config)
287
- )
288
-
289
- def _allocate_memory(self, inference_max_sequence_len, batch_size, device=None, dtype=None):
290
- if self.multi_query_attention:
291
- num_attention_heads = self.num_multi_query_groups_per_partition
292
- else:
293
- num_attention_heads = self.num_attention_heads_per_partition
294
- return torch.empty(
295
- inference_max_sequence_len,
296
- batch_size,
297
- num_attention_heads,
298
- self.hidden_size_per_attention_head,
299
- dtype=dtype,
300
- device=device,
301
- )
302
-
303
- def forward(
304
- self, hidden_states, attention_mask, rotary_pos_emb, kv_cache=None, use_cache=True
305
- ):
306
- # hidden_states: [sq, b, h]
307
-
308
- # =================================================
309
- # Pre-allocate memory for key-values for inference.
310
- # =================================================
311
- # =====================
312
- # Query, Key, and Value
313
- # =====================
314
-
315
- # Attention heads [sq, b, h] --> [sq, b, (np * 3 * hn)]
316
- mixed_x_layer = self.query_key_value(hidden_states)
317
-
318
- if self.multi_query_attention:
319
- (query_layer, key_layer, value_layer) = mixed_x_layer.split(
320
- [
321
- self.num_attention_heads_per_partition * self.hidden_size_per_attention_head,
322
- self.num_multi_query_groups_per_partition * self.hidden_size_per_attention_head,
323
- self.num_multi_query_groups_per_partition * self.hidden_size_per_attention_head,
324
- ],
325
- dim=-1,
326
- )
327
- query_layer = query_layer.view(
328
- query_layer.size()[:-1] + (self.num_attention_heads_per_partition, self.hidden_size_per_attention_head)
329
- )
330
- key_layer = key_layer.view(
331
- key_layer.size()[:-1] + (self.num_multi_query_groups_per_partition, self.hidden_size_per_attention_head)
332
- )
333
- value_layer = value_layer.view(
334
- value_layer.size()[:-1]
335
- + (self.num_multi_query_groups_per_partition, self.hidden_size_per_attention_head)
336
- )
337
- else:
338
- new_tensor_shape = mixed_x_layer.size()[:-1] + \
339
- (self.num_attention_heads_per_partition,
340
- 3 * self.hidden_size_per_attention_head)
341
- mixed_x_layer = mixed_x_layer.view(*new_tensor_shape)
342
-
343
- # [sq, b, np, 3 * hn] --> 3 [sq, b, np, hn]
344
- (query_layer, key_layer, value_layer) = split_tensor_along_last_dim(mixed_x_layer, 3)
345
-
346
- # apply relative positional encoding (rotary embedding)
347
- if rotary_pos_emb is not None:
348
- query_layer = apply_rotary_pos_emb(query_layer, rotary_pos_emb)
349
- key_layer = apply_rotary_pos_emb(key_layer, rotary_pos_emb)
350
-
351
- # adjust key and value for inference
352
- if use_cache:
353
- if kv_cache is not None:
354
- cache_k, cache_v = kv_cache
355
- key_layer = torch.cat((cache_k, key_layer), dim=0)
356
- value_layer = torch.cat((cache_v, value_layer), dim=0)
357
- kv_cache = (key_layer, value_layer)
358
- else:
359
- kv_cache = None
360
-
361
-
362
- if self.multi_query_attention:
363
- key_layer = key_layer.unsqueeze(-2)
364
- key_layer = key_layer.expand(
365
- -1, -1, -1, self.num_attention_heads_per_partition // self.num_multi_query_groups_per_partition, -1
366
- )
367
- key_layer = key_layer.contiguous().view(
368
- key_layer.size()[:2] + (self.num_attention_heads_per_partition, self.hidden_size_per_attention_head)
369
- )
370
- value_layer = value_layer.unsqueeze(-2)
371
- value_layer = value_layer.expand(
372
- -1, -1, -1, self.num_attention_heads_per_partition // self.num_multi_query_groups_per_partition, -1
373
- )
374
- value_layer = value_layer.contiguous().view(
375
- value_layer.size()[:2] + (self.num_attention_heads_per_partition, self.hidden_size_per_attention_head)
376
- )
377
-
378
- # ==================================
379
- # core attention computation
380
- # ==================================
381
-
382
- context_layer = self.core_attention(query_layer, key_layer, value_layer, attention_mask)
383
-
384
- # =================
385
- # Output. [sq, b, h]
386
- # =================
387
-
388
- output = self.dense(context_layer)
389
-
390
- return output, kv_cache
391
-
392
-
393
- def _config_to_kwargs(args):
394
- common_kwargs = {
395
- "dtype": args.torch_dtype,
396
- }
397
- return common_kwargs
398
-
399
-
400
- class MLP(torch.nn.Module):
401
- """MLP.
402
-
403
- MLP will take the input with h hidden state, project it to 4*h
404
- hidden dimension, perform nonlinear transformation, and project the
405
- state back into h hidden dimension.
406
- """
407
-
408
- def __init__(self, config: ChatGLMConfig, device=None):
409
- super(MLP, self).__init__()
410
-
411
- self.add_bias = config.add_bias_linear
412
-
413
- # Project to 4h. If using swiglu double the output width, see https://arxiv.org/pdf/2002.05202.pdf
414
- self.dense_h_to_4h = nn.Linear(
415
- config.hidden_size,
416
- config.ffn_hidden_size * 2,
417
- bias=self.add_bias,
418
- device=device,
419
- **_config_to_kwargs(config)
420
- )
421
-
422
- def swiglu(x):
423
- x = torch.chunk(x, 2, dim=-1)
424
- return F.silu(x[0]) * x[1]
425
-
426
- self.activation_func = swiglu
427
-
428
- # Project back to h.
429
- self.dense_4h_to_h = nn.Linear(
430
- config.ffn_hidden_size,
431
- config.hidden_size,
432
- bias=self.add_bias,
433
- device=device,
434
- **_config_to_kwargs(config)
435
- )
436
-
437
- def forward(self, hidden_states):
438
- # [s, b, 4hp]
439
- intermediate_parallel = self.dense_h_to_4h(hidden_states)
440
- intermediate_parallel = self.activation_func(intermediate_parallel)
441
- # [s, b, h]
442
- output = self.dense_4h_to_h(intermediate_parallel)
443
- return output
444
-
445
-
446
- class GLMBlock(torch.nn.Module):
447
- """A single transformer layer.
448
-
449
- Transformer layer takes input with size [s, b, h] and returns an
450
- output of the same size.
451
- """
452
-
453
- def __init__(self, config: ChatGLMConfig, layer_number, device=None):
454
- super(GLMBlock, self).__init__()
455
- self.layer_number = layer_number
456
-
457
- self.apply_residual_connection_post_layernorm = config.apply_residual_connection_post_layernorm
458
-
459
- self.fp32_residual_connection = config.fp32_residual_connection
460
-
461
- LayerNormFunc = RMSNorm if config.rmsnorm else LayerNorm
462
- # Layernorm on the input data.
463
- self.input_layernorm = LayerNormFunc(config.hidden_size, eps=config.layernorm_epsilon, device=device,
464
- dtype=config.torch_dtype)
465
-
466
- # Self attention.
467
- self.self_attention = SelfAttention(config, layer_number, device=device)
468
- self.hidden_dropout = config.hidden_dropout
469
-
470
- # Layernorm on the attention output
471
- self.post_attention_layernorm = LayerNormFunc(config.hidden_size, eps=config.layernorm_epsilon, device=device,
472
- dtype=config.torch_dtype)
473
-
474
- # MLP
475
- self.mlp = MLP(config, device=device)
476
-
477
- def forward(
478
- self, hidden_states, attention_mask, rotary_pos_emb, kv_cache=None, use_cache=True,
479
- ):
480
- # hidden_states: [s, b, h]
481
-
482
- # Layer norm at the beginning of the transformer layer.
483
- layernorm_output = self.input_layernorm(hidden_states)
484
- # Self attention.
485
- attention_output, kv_cache = self.self_attention(
486
- layernorm_output,
487
- attention_mask,
488
- rotary_pos_emb,
489
- kv_cache=kv_cache,
490
- use_cache=use_cache
491
- )
492
-
493
- # Residual connection.
494
- if self.apply_residual_connection_post_layernorm:
495
- residual = layernorm_output
496
- else:
497
- residual = hidden_states
498
-
499
- layernorm_input = torch.nn.functional.dropout(attention_output, p=self.hidden_dropout, training=self.training)
500
- layernorm_input = residual + layernorm_input
501
-
502
- # Layer norm post the self attention.
503
- layernorm_output = self.post_attention_layernorm(layernorm_input)
504
-
505
- # MLP.
506
- mlp_output = self.mlp(layernorm_output)
507
-
508
- # Second residual connection.
509
- if self.apply_residual_connection_post_layernorm:
510
- residual = layernorm_output
511
- else:
512
- residual = layernorm_input
513
-
514
- output = torch.nn.functional.dropout(mlp_output, p=self.hidden_dropout, training=self.training)
515
- output = residual + output
516
-
517
- return output, kv_cache
518
-
519
-
520
- class GLMTransformer(torch.nn.Module):
521
- """Transformer class."""
522
-
523
- def __init__(self, config: ChatGLMConfig, device=None):
524
- super(GLMTransformer, self).__init__()
525
-
526
- self.fp32_residual_connection = config.fp32_residual_connection
527
- self.post_layer_norm = config.post_layer_norm
528
-
529
- # Number of layers.
530
- self.num_layers = config.num_layers
531
-
532
- # Transformer layers.
533
- def build_layer(layer_number):
534
- return GLMBlock(config, layer_number, device=device)
535
-
536
- self.layers = torch.nn.ModuleList([build_layer(i + 1) for i in range(self.num_layers)])
537
-
538
- if self.post_layer_norm:
539
- LayerNormFunc = RMSNorm if config.rmsnorm else LayerNorm
540
- # Final layer norm before output.
541
- self.final_layernorm = LayerNormFunc(config.hidden_size, eps=config.layernorm_epsilon, device=device,
542
- dtype=config.torch_dtype)
543
-
544
- self.gradient_checkpointing = False
545
-
546
- def _get_layer(self, layer_number):
547
- return self.layers[layer_number]
548
-
549
- def forward(
550
- self, hidden_states, attention_mask, rotary_pos_emb, kv_caches=None,
551
- use_cache: Optional[bool] = True,
552
- output_hidden_states: Optional[bool] = False,
553
- ):
554
- if not kv_caches:
555
- kv_caches = [None for _ in range(self.num_layers)]
556
- presents = () if use_cache else None
557
- if self.gradient_checkpointing and self.training:
558
- if use_cache:
559
- logger.warning_once(
560
- "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
561
- )
562
- use_cache = False
563
-
564
- all_self_attentions = None
565
- all_hidden_states = () if output_hidden_states else None
566
- for index in range(self.num_layers):
567
- if output_hidden_states:
568
- all_hidden_states = all_hidden_states + (hidden_states,)
569
-
570
- layer = self._get_layer(index)
571
- if self.gradient_checkpointing and self.training:
572
- layer_ret = torch.utils.checkpoint.checkpoint(
573
- layer,
574
- hidden_states,
575
- attention_mask,
576
- rotary_pos_emb,
577
- kv_caches[index],
578
- use_cache
579
- )
580
- else:
581
- layer_ret = layer(
582
- hidden_states,
583
- attention_mask,
584
- rotary_pos_emb,
585
- kv_cache=kv_caches[index],
586
- use_cache=use_cache
587
- )
588
- hidden_states, kv_cache = layer_ret
589
- if use_cache:
590
- presents = presents + (kv_cache,)
591
-
592
- if output_hidden_states:
593
- all_hidden_states = all_hidden_states + (hidden_states,)
594
-
595
- # Final layer norm.
596
- if self.post_layer_norm:
597
- hidden_states = self.final_layernorm(hidden_states)
598
-
599
- return hidden_states, presents, all_hidden_states, all_self_attentions
600
-
601
-
602
- class ChatGLMPreTrainedModel(PreTrainedModel):
603
- """
604
- An abstract class to handle weights initialization and
605
- a simple interface for downloading and loading pretrained models.
606
- """
607
-
608
- is_parallelizable = False
609
- supports_gradient_checkpointing = True
610
- config_class = ChatGLMConfig
611
- base_model_prefix = "transformer"
612
- _no_split_modules = ["GLMBlock"]
613
-
614
- def _init_weights(self, module: nn.Module):
615
- """Initialize the weights."""
616
- return
617
-
618
- def get_masks(self, input_ids, past_key_values, padding_mask=None):
619
- batch_size, seq_length = input_ids.shape
620
- full_attention_mask = torch.ones(batch_size, seq_length, seq_length, device=input_ids.device)
621
- full_attention_mask.tril_()
622
- past_length = 0
623
- if past_key_values:
624
- past_length = past_key_values[0][0].shape[0]
625
- if past_length:
626
- full_attention_mask = torch.cat((torch.ones(batch_size, seq_length, past_length,
627
- device=input_ids.device), full_attention_mask), dim=-1)
628
- if padding_mask is not None:
629
- full_attention_mask = full_attention_mask * padding_mask.unsqueeze(1)
630
- if not past_length and padding_mask is not None:
631
- full_attention_mask -= padding_mask.unsqueeze(-1) - 1
632
- full_attention_mask = (full_attention_mask < 0.5).bool()
633
- full_attention_mask.unsqueeze_(1)
634
- return full_attention_mask
635
-
636
- def get_position_ids(self, input_ids, device):
637
- batch_size, seq_length = input_ids.shape
638
- position_ids = torch.arange(seq_length, dtype=torch.long, device=device).unsqueeze(0).repeat(batch_size, 1)
639
- return position_ids
640
-
641
- def _set_gradient_checkpointing(self, module, value=False):
642
- if isinstance(module, GLMTransformer):
643
- module.gradient_checkpointing = value
644
-
645
-
646
- class Embedding(torch.nn.Module):
647
- """Language model embeddings."""
648
-
649
- def __init__(self, config: ChatGLMConfig, device=None):
650
- super(Embedding, self).__init__()
651
-
652
- self.hidden_size = config.hidden_size
653
- # Word embeddings (parallel).
654
- self.word_embeddings = nn.Embedding(
655
- config.padded_vocab_size,
656
- self.hidden_size,
657
- dtype=config.torch_dtype,
658
- device=device
659
- )
660
- self.fp32_residual_connection = config.fp32_residual_connection
661
-
662
- def forward(self, input_ids):
663
- # Embeddings.
664
- words_embeddings = self.word_embeddings(input_ids)
665
- embeddings = words_embeddings
666
- # Data format change to avoid explicit tranposes : [b s h] --> [s b h].
667
- embeddings = embeddings.transpose(0, 1).contiguous()
668
- # If the input flag for fp32 residual connection is set, convert for float.
669
- if self.fp32_residual_connection:
670
- embeddings = embeddings.float()
671
- return embeddings
672
-
673
-
674
- class ChatGLMModel(ChatGLMPreTrainedModel):
675
- def __init__(self, config: ChatGLMConfig, device=None, empty_init=True):
676
- super().__init__(config)
677
- if empty_init:
678
- init_method = skip_init
679
- else:
680
- init_method = default_init
681
- init_kwargs = {}
682
- if device is not None:
683
- init_kwargs["device"] = device
684
- self.embedding = init_method(Embedding, config, **init_kwargs)
685
-
686
- # Rotary positional embeddings
687
- self.seq_length = config.seq_length
688
- rotary_dim = (
689
- config.hidden_size // config.num_attention_heads if config.kv_channels is None else config.kv_channels
690
- )
691
-
692
- self.rotary_pos_emb = RotaryEmbedding(rotary_dim // 2, original_impl=config.original_rope, device=device,
693
- dtype=config.torch_dtype)
694
- self.encoder = init_method(GLMTransformer, config, **init_kwargs)
695
- self.output_layer = init_method(nn.Linear, config.hidden_size, config.padded_vocab_size, bias=False,
696
- dtype=config.torch_dtype, **init_kwargs)
697
- self.pre_seq_len = config.pre_seq_len
698
- self.prefix_projection = config.prefix_projection
699
- if self.pre_seq_len is not None:
700
- for param in self.parameters():
701
- param.requires_grad = False
702
- self.prefix_tokens = torch.arange(self.pre_seq_len).long()
703
- self.prefix_encoder = PrefixEncoder(config)
704
- self.dropout = torch.nn.Dropout(0.1)
705
-
706
- def get_input_embeddings(self):
707
- return self.embedding.word_embeddings
708
-
709
- def get_prompt(self, batch_size, device, dtype=torch.half):
710
- prefix_tokens = self.prefix_tokens.unsqueeze(0).expand(batch_size, -1).to(device)
711
- past_key_values = self.prefix_encoder(prefix_tokens).type(dtype)
712
- past_key_values = past_key_values.view(
713
- batch_size,
714
- self.pre_seq_len,
715
- self.num_layers * 2,
716
- self.num_attention_heads,
717
- self.hidden_size // self.num_attention_heads
718
- )
719
- # seq_len, b, nh, hidden_size
720
- past_key_values = self.dropout(past_key_values)
721
- past_key_values = past_key_values.permute([2, 1, 0, 3, 4]).split(2)
722
- return past_key_values
723
-
724
- def forward(
725
- self,
726
- input_ids,
727
- position_ids: Optional[torch.Tensor] = None,
728
- attention_mask: Optional[torch.BoolTensor] = None,
729
- full_attention_mask: Optional[torch.BoolTensor] = None,
730
- past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None,
731
- inputs_embeds: Optional[torch.Tensor] = None,
732
- use_cache: Optional[bool] = None,
733
- output_hidden_states: Optional[bool] = None,
734
- return_dict: Optional[bool] = None,
735
- ):
736
- output_hidden_states = (
737
- output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
738
- )
739
- use_cache = use_cache if use_cache is not None else self.config.use_cache
740
- return_dict = return_dict if return_dict is not None else self.config.use_return_dict
741
-
742
- batch_size, seq_length = input_ids.shape
743
-
744
- if inputs_embeds is None:
745
- inputs_embeds = self.embedding(input_ids)
746
-
747
- # if full_attention_mask is None:
748
- # if (attention_mask is not None and not attention_mask.all()) or (past_key_values and seq_length != 1):
749
- # full_attention_mask = self.get_masks(input_ids, past_key_values, padding_mask=attention_mask)
750
-
751
- # Rotary positional embeddings
752
- rotary_pos_emb = self.rotary_pos_emb(self.seq_length)
753
- if position_ids is not None:
754
- rotary_pos_emb = rotary_pos_emb[position_ids]
755
- else:
756
- rotary_pos_emb = rotary_pos_emb[None, :seq_length]
757
- rotary_pos_emb = rotary_pos_emb.transpose(0, 1).contiguous()
758
-
759
- if past_key_values is None:
760
- if self.pre_seq_len is not None:
761
- past_key_values = self.get_prompt(batch_size=batch_size, device=input_ids.device,
762
- dtype=inputs_embeds.dtype)
763
-
764
- # Run encoder.
765
- hidden_states, presents, all_hidden_states, all_self_attentions = self.encoder(
766
- inputs_embeds, attention_mask, rotary_pos_emb=rotary_pos_emb,
767
- kv_caches=past_key_values, use_cache=use_cache, output_hidden_states=output_hidden_states
768
- )
769
-
770
- if not return_dict:
771
- return tuple(v for v in [hidden_states, presents, all_hidden_states, all_self_attentions] if v is not None)
772
-
773
- return BaseModelOutputWithPast(
774
- last_hidden_state=hidden_states,
775
- past_key_values=presents,
776
- hidden_states=all_hidden_states,
777
- attentions=all_self_attentions,
778
- )
779
-
780
- def quantize(self, weight_bit_width: int):
781
- from .quantization import quantize
782
- quantize(self.encoder, weight_bit_width)
783
- return self
784
-
785
-
786
- class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
787
- def __init__(self, config: ChatGLMConfig, empty_init=True, device=None):
788
- super().__init__(config)
789
-
790
- self.max_sequence_length = config.max_length
791
- self.transformer = ChatGLMModel(config, empty_init=empty_init, device=device)
792
- self.config = config
793
- self.quantized = False
794
- self.pack_loss = False
795
-
796
- if self.config.quantization_bit:
797
- self.quantize(self.config.quantization_bit, empty_init=True)
798
-
799
- def _update_model_kwargs_for_generation(
800
- self,
801
- outputs: ModelOutput,
802
- model_kwargs: Dict[str, Any],
803
- is_encoder_decoder: bool = False,
804
- standardize_cache_format: bool = False,
805
- ) -> Dict[str, Any]:
806
- # update past_key_values
807
- model_kwargs["past_key_values"] = self._extract_past_from_model_output(
808
- outputs, standardize_cache_format=standardize_cache_format
809
- )
810
-
811
- # update attention mask
812
- if "attention_mask" in model_kwargs:
813
- attention_mask = model_kwargs["attention_mask"]
814
- model_kwargs["attention_mask"] = torch.cat(
815
- [attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))], dim=-1
816
- )
817
-
818
- # update position ids
819
- if "position_ids" in model_kwargs:
820
- position_ids = model_kwargs["position_ids"]
821
- new_position_id = position_ids[..., -1:].clone()
822
- new_position_id += 1
823
- model_kwargs["position_ids"] = torch.cat(
824
- [position_ids, new_position_id], dim=-1
825
- )
826
-
827
- model_kwargs["is_first_forward"] = False
828
- return model_kwargs
829
-
830
- def prepare_inputs_for_generation(
831
- self,
832
- input_ids: torch.LongTensor,
833
- past_key_values: Optional[torch.Tensor] = None,
834
- attention_mask: Optional[torch.Tensor] = None,
835
- position_ids: Optional[torch.Tensor] = None,
836
- is_first_forward: bool = True,
837
- **kwargs
838
- ) -> dict:
839
- # only last token for input_ids if past is not None
840
- if position_ids is None:
841
- position_ids = self.get_position_ids(input_ids, device=input_ids.device)
842
- if not is_first_forward:
843
- position_ids = position_ids[..., -1:]
844
- input_ids = input_ids[:, -1:]
845
- return {
846
- "input_ids": input_ids,
847
- "past_key_values": past_key_values,
848
- "position_ids": position_ids,
849
- "attention_mask": attention_mask,
850
- "return_last_logit": True
851
- }
852
-
853
- def forward(
854
- self,
855
- input_ids: Optional[torch.Tensor] = None,
856
- position_ids: Optional[torch.Tensor] = None,
857
- attention_mask: Optional[torch.Tensor] = None,
858
- past_key_values: Optional[Tuple[torch.FloatTensor]] = None,
859
- inputs_embeds: Optional[torch.Tensor] = None,
860
- labels: Optional[Tuple[torch.Tensor]] = None,
861
- use_cache: Optional[bool] = None,
862
- output_attentions: Optional[bool] = None,
863
- output_hidden_states: Optional[bool] = None,
864
- return_dict: Optional[bool] = None,
865
- return_last_logit: Optional[bool] = False,
866
- ):
867
- use_cache = use_cache if use_cache is not None else self.config.use_cache
868
- return_dict = return_dict if return_dict is not None else self.config.use_return_dict
869
-
870
- transformer_outputs = self.transformer(
871
- input_ids=input_ids,
872
- position_ids=position_ids,
873
- attention_mask=attention_mask,
874
- past_key_values=past_key_values,
875
- inputs_embeds=inputs_embeds,
876
- use_cache=use_cache,
877
- output_hidden_states=output_hidden_states,
878
- return_dict=return_dict,
879
- )
880
-
881
- hidden_states = transformer_outputs[0]
882
- if return_last_logit:
883
- hidden_states = hidden_states[-1:]
884
- lm_logits = self.transformer.output_layer(hidden_states)
885
- lm_logits = lm_logits.transpose(0, 1).contiguous()
886
-
887
- loss = None
888
- if labels is not None:
889
- lm_logits = lm_logits.to(torch.float32)
890
- # Shift so that tokens < n predict n
891
- shift_logits = lm_logits[..., :-1, :].contiguous()
892
- if isinstance(labels, tuple) or isinstance(labels, list):
893
- labels, weights = labels
894
- shift_labels = labels[..., 1:].contiguous()
895
- if self.pack_loss:
896
- shift_weights = weights[..., 1:].contiguous()
897
- loss_fct = CrossEntropyLoss(ignore_index=-100, reduction='none')
898
- loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
899
- loss = (loss * shift_weights).sum()
900
- # loss *= weights
901
- else:
902
- loss_fct = CrossEntropyLoss(ignore_index=-100)
903
- loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
904
-
905
- lm_logits = lm_logits.to(hidden_states.dtype)
906
- loss = loss.to(hidden_states.dtype)
907
-
908
- if not return_dict:
909
- output = (lm_logits,) + transformer_outputs[1:]
910
- return ((loss,) + output) if loss is not None else output
911
-
912
- return CausalLMOutputWithPast(
913
- loss=loss,
914
- logits=lm_logits,
915
- past_key_values=transformer_outputs.past_key_values,
916
- hidden_states=transformer_outputs.hidden_states,
917
- attentions=transformer_outputs.attentions,
918
- )
919
-
920
- @staticmethod
921
- def _reorder_cache(
922
- past: Tuple[Tuple[torch.Tensor, torch.Tensor], ...], beam_idx: torch.LongTensor
923
- ) -> Tuple[Tuple[torch.Tensor, torch.Tensor], ...]:
924
- """
925
- This function is used to re-order the `past_key_values` cache if [`~PreTrainedModel.beam_search`] or
926
- [`~PreTrainedModel.beam_sample`] is called. This is required to match `past_key_values` with the correct
927
- beam_idx at every generation step.
928
-
929
- Output shares the same memory storage as `past`.
930
- """
931
- return tuple(
932
- (
933
- layer_past[0].index_select(1, beam_idx.to(layer_past[0].device)),
934
- layer_past[1].index_select(1, beam_idx.to(layer_past[1].device)),
935
- )
936
- for layer_past in past
937
- )
938
-
939
- def process_response(self, response):
940
- response = response.strip()
941
- response = response.replace("[[训练时间]]", "2023年")
942
- return response
943
-
944
- def build_inputs(self, tokenizer, query: str, history: List[Tuple[str, str]] = None):
945
- prompt = tokenizer.build_prompt(query, history=history)
946
- inputs = tokenizer([prompt], return_tensors="pt")
947
- inputs = inputs.to(self.device)
948
- return inputs
949
-
950
- def build_stream_inputs(self, tokenizer, query: str, history: List[Tuple[str, str]] = None):
951
- if history:
952
- prompt = "\n\n[Round {}]\n\n问:{}\n\n答:".format(len(history) + 1, query)
953
- input_ids = tokenizer.encode(prompt, add_special_tokens=False)
954
- input_ids = input_ids[1:]
955
- inputs = tokenizer.batch_encode_plus([(input_ids, None)], return_tensors="pt", add_special_tokens=False)
956
- else:
957
- prompt = "[Round {}]\n\n问:{}\n\n答:".format(len(history) + 1, query)
958
- inputs = tokenizer([prompt], return_tensors="pt")
959
- inputs = inputs.to(self.device)
960
- return inputs
961
-
962
- @torch.no_grad()
963
- def chat(self, tokenizer, query: str, history: List[Tuple[str, str]] = None, max_length: int = 8192, num_beams=1,
964
- do_sample=True, top_p=0.8, temperature=0.8, logits_processor=None, **kwargs):
965
- if history is None:
966
- history = []
967
- if logits_processor is None:
968
- logits_processor = LogitsProcessorList()
969
- logits_processor.append(InvalidScoreLogitsProcessor())
970
- gen_kwargs = {"max_length": max_length, "num_beams": num_beams, "do_sample": do_sample, "top_p": top_p,
971
- "temperature": temperature, "logits_processor": logits_processor, **kwargs}
972
- inputs = self.build_inputs(tokenizer, query, history=history)
973
- outputs = self.generate(**inputs, **gen_kwargs)
974
- outputs = outputs.tolist()[0][len(inputs["input_ids"][0]):]
975
- response = tokenizer.decode(outputs)
976
- response = self.process_response(response)
977
- history = history + [(query, response)]
978
- return response, history
979
-
980
- @torch.no_grad()
981
- def stream_chat(self, tokenizer, query: str, history: List[Tuple[str, str]] = None, past_key_values=None,
982
- max_length: int = 8192, do_sample=True, top_p=0.8, temperature=0.8, logits_processor=None,
983
- return_past_key_values=False, **kwargs):
984
- if history is None:
985
- history = []
986
- if logits_processor is None:
987
- logits_processor = LogitsProcessorList()
988
- logits_processor.append(InvalidScoreLogitsProcessor())
989
- gen_kwargs = {"max_length": max_length, "do_sample": do_sample, "top_p": top_p,
990
- "temperature": temperature, "logits_processor": logits_processor, **kwargs}
991
- if past_key_values is None and not return_past_key_values:
992
- inputs = self.build_inputs(tokenizer, query, history=history)
993
- else:
994
- inputs = self.build_stream_inputs(tokenizer, query, history=history)
995
- if past_key_values is not None:
996
- past_length = past_key_values[0][0].shape[0]
997
- if self.transformer.pre_seq_len is not None:
998
- past_length -= self.transformer.pre_seq_len
999
- inputs.position_ids += past_length
1000
- attention_mask = inputs.attention_mask
1001
- attention_mask = torch.cat((attention_mask.new_ones(1, past_length), attention_mask), dim=1)
1002
- inputs['attention_mask'] = attention_mask
1003
- for outputs in self.stream_generate(**inputs, past_key_values=past_key_values,
1004
- return_past_key_values=return_past_key_values, **gen_kwargs):
1005
- if return_past_key_values:
1006
- outputs, past_key_values = outputs
1007
- outputs = outputs.tolist()[0][len(inputs["input_ids"][0]):]
1008
- response = tokenizer.decode(outputs)
1009
- if response and response[-1] != "�":
1010
- response = self.process_response(response)
1011
- new_history = history + [(query, response)]
1012
- if return_past_key_values:
1013
- yield response, new_history, past_key_values
1014
- else:
1015
- yield response, new_history
1016
-
1017
- @torch.no_grad()
1018
- def stream_generate(
1019
- self,
1020
- input_ids,
1021
- generation_config: Optional[GenerationConfig] = None,
1022
- logits_processor: Optional[LogitsProcessorList] = None,
1023
- stopping_criteria: Optional[StoppingCriteriaList] = None,
1024
- prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], List[int]]] = None,
1025
- return_past_key_values=False,
1026
- **kwargs,
1027
- ):
1028
- batch_size, input_ids_seq_length = input_ids.shape[0], input_ids.shape[-1]
1029
-
1030
- if generation_config is None:
1031
- generation_config = self.generation_config
1032
- generation_config = copy.deepcopy(generation_config)
1033
- model_kwargs = generation_config.update(**kwargs)
1034
- bos_token_id, eos_token_id = generation_config.bos_token_id, generation_config.eos_token_id
1035
-
1036
- if isinstance(eos_token_id, int):
1037
- eos_token_id = [eos_token_id]
1038
-
1039
- has_default_max_length = kwargs.get("max_length") is None and generation_config.max_length is not None
1040
- if has_default_max_length and generation_config.max_new_tokens is None:
1041
- warnings.warn(
1042
- f"Using `max_length`'s default ({generation_config.max_length}) to control the generation length. "
1043
- "This behaviour is deprecated and will be removed from the config in v5 of Transformers -- we"
1044
- " recommend using `max_new_tokens` to control the maximum length of the generation.",
1045
- UserWarning,
1046
- )
1047
- elif generation_config.max_new_tokens is not None:
1048
- generation_config.max_length = generation_config.max_new_tokens + input_ids_seq_length
1049
- if not has_default_max_length:
1050
- logger.warn(
1051
- f"Both `max_new_tokens` (={generation_config.max_new_tokens}) and `max_length`(="
1052
- f"{generation_config.max_length}) seem to have been set. `max_new_tokens` will take precedence. "
1053
- "Please refer to the documentation for more information. "
1054
- "(https://huggingface.co/docs/transformers/main/en/main_classes/text_generation)",
1055
- UserWarning,
1056
- )
1057
-
1058
- if input_ids_seq_length >= generation_config.max_length:
1059
- input_ids_string = "decoder_input_ids" if self.config.is_encoder_decoder else "input_ids"
1060
- logger.warning(
1061
- f"Input length of {input_ids_string} is {input_ids_seq_length}, but `max_length` is set to"
1062
- f" {generation_config.max_length}. This can lead to unexpected behavior. You should consider"
1063
- " increasing `max_new_tokens`."
1064
- )
1065
-
1066
- # 2. Set generation parameters if not already defined
1067
- logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
1068
- stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()
1069
-
1070
- logits_processor = self._get_logits_processor(
1071
- generation_config=generation_config,
1072
- input_ids_seq_length=input_ids_seq_length,
1073
- encoder_input_ids=input_ids,
1074
- prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
1075
- logits_processor=logits_processor,
1076
- )
1077
-
1078
- stopping_criteria = self._get_stopping_criteria(
1079
- generation_config=generation_config, stopping_criteria=stopping_criteria
1080
- )
1081
- logits_warper = self._get_logits_warper(generation_config)
1082
-
1083
- unfinished_sequences = input_ids.new(input_ids.shape[0]).fill_(1)
1084
- scores = None
1085
- while True:
1086
- model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
1087
- # forward pass to get next token
1088
- outputs = self(
1089
- **model_inputs,
1090
- return_dict=True,
1091
- output_attentions=False,
1092
- output_hidden_states=False,
1093
- )
1094
-
1095
- next_token_logits = outputs.logits[:, -1, :]
1096
-
1097
- # pre-process distribution
1098
- next_token_scores = logits_processor(input_ids, next_token_logits)
1099
- next_token_scores = logits_warper(input_ids, next_token_scores)
1100
-
1101
- # sample
1102
- probs = nn.functional.softmax(next_token_scores, dim=-1)
1103
- if generation_config.do_sample:
1104
- next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)
1105
- else:
1106
- next_tokens = torch.argmax(probs, dim=-1)
1107
-
1108
- # update generated ids, model inputs, and length for next step
1109
- input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
1110
- model_kwargs = self._update_model_kwargs_for_generation(
1111
- outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder
1112
- )
1113
- unfinished_sequences = unfinished_sequences.mul((sum(next_tokens != i for i in eos_token_id)).long())
1114
- if return_past_key_values:
1115
- yield input_ids, outputs.past_key_values
1116
- else:
1117
- yield input_ids
1118
- # stop when each sentence is finished, or if we exceed the maximum length
1119
- if unfinished_sequences.max() == 0 or stopping_criteria(input_ids, scores):
1120
- break
1121
-
1122
- def quantize(self, bits: int, empty_init=False, device=None, **kwargs):
1123
- if bits == 0:
1124
- return
1125
-
1126
- from .quantization import quantize
1127
-
1128
- if self.quantized:
1129
- logger.info("Already quantized.")
1130
- return self
1131
-
1132
- self.quantized = True
1133
-
1134
- self.config.quantization_bit = bits
1135
-
1136
- self.transformer.encoder = quantize(self.transformer.encoder, bits, empty_init=empty_init, device=device,
1137
- **kwargs)
1138
- return self
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
pytorch_model-00001-of-00002.bin DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:25be0c70efa1bc74d2f5881f7de92750728bb2d7858b5289549575ba32a74570
3
- size 9986264643
 
 
 
 
pytorch_model-00002-of-00002.bin DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:b845b597dce850a33f466186b6aa48b5055c0df5ecb33df675a6ebc4ce4d073f
3
- size 2500975567
 
 
 
 
pytorch_model.bin.index.json DELETED
@@ -1,207 +0,0 @@
1
- {
2
- "metadata": {
3
- "total_size": 12487168064
4
- },
5
- "weight_map": {
6
- "transformer.embedding.word_embeddings.weight": "pytorch_model-00001-of-00002.bin",
7
- "transformer.encoder.final_layernorm.weight": "pytorch_model-00002-of-00002.bin",
8
- "transformer.encoder.layers.0.input_layernorm.weight": "pytorch_model-00001-of-00002.bin",
9
- "transformer.encoder.layers.0.mlp.dense_4h_to_h.weight": "pytorch_model-00001-of-00002.bin",
10
- "transformer.encoder.layers.0.mlp.dense_h_to_4h.weight": "pytorch_model-00001-of-00002.bin",
11
- "transformer.encoder.layers.0.post_attention_layernorm.weight": "pytorch_model-00001-of-00002.bin",
12
- "transformer.encoder.layers.0.self_attention.dense.weight": "pytorch_model-00001-of-00002.bin",
13
- "transformer.encoder.layers.0.self_attention.query_key_value.bias": "pytorch_model-00001-of-00002.bin",
14
- "transformer.encoder.layers.0.self_attention.query_key_value.weight": "pytorch_model-00001-of-00002.bin",
15
- "transformer.encoder.layers.1.input_layernorm.weight": "pytorch_model-00001-of-00002.bin",
16
- "transformer.encoder.layers.1.mlp.dense_4h_to_h.weight": "pytorch_model-00001-of-00002.bin",
17
- "transformer.encoder.layers.1.mlp.dense_h_to_4h.weight": "pytorch_model-00001-of-00002.bin",
18
- "transformer.encoder.layers.1.post_attention_layernorm.weight": "pytorch_model-00001-of-00002.bin",
19
- "transformer.encoder.layers.1.self_attention.dense.weight": "pytorch_model-00001-of-00002.bin",
20
- "transformer.encoder.layers.1.self_attention.query_key_value.bias": "pytorch_model-00001-of-00002.bin",
21
- "transformer.encoder.layers.1.self_attention.query_key_value.weight": "pytorch_model-00001-of-00002.bin",
22
- "transformer.encoder.layers.10.input_layernorm.weight": "pytorch_model-00001-of-00002.bin",
23
- "transformer.encoder.layers.10.mlp.dense_4h_to_h.weight": "pytorch_model-00001-of-00002.bin",
24
- "transformer.encoder.layers.10.mlp.dense_h_to_4h.weight": "pytorch_model-00001-of-00002.bin",
25
- "transformer.encoder.layers.10.post_attention_layernorm.weight": "pytorch_model-00001-of-00002.bin",
26
- "transformer.encoder.layers.10.self_attention.dense.weight": "pytorch_model-00001-of-00002.bin",
27
- "transformer.encoder.layers.10.self_attention.query_key_value.bias": "pytorch_model-00001-of-00002.bin",
28
- "transformer.encoder.layers.10.self_attention.query_key_value.weight": "pytorch_model-00001-of-00002.bin",
29
- "transformer.encoder.layers.11.input_layernorm.weight": "pytorch_model-00001-of-00002.bin",
30
- "transformer.encoder.layers.11.mlp.dense_4h_to_h.weight": "pytorch_model-00001-of-00002.bin",
31
- "transformer.encoder.layers.11.mlp.dense_h_to_4h.weight": "pytorch_model-00001-of-00002.bin",
32
- "transformer.encoder.layers.11.post_attention_layernorm.weight": "pytorch_model-00001-of-00002.bin",
33
- "transformer.encoder.layers.11.self_attention.dense.weight": "pytorch_model-00001-of-00002.bin",
34
- "transformer.encoder.layers.11.self_attention.query_key_value.bias": "pytorch_model-00001-of-00002.bin",
35
- "transformer.encoder.layers.11.self_attention.query_key_value.weight": "pytorch_model-00001-of-00002.bin",
36
- "transformer.encoder.layers.12.input_layernorm.weight": "pytorch_model-00001-of-00002.bin",
37
- "transformer.encoder.layers.12.mlp.dense_4h_to_h.weight": "pytorch_model-00001-of-00002.bin",
38
- "transformer.encoder.layers.12.mlp.dense_h_to_4h.weight": "pytorch_model-00001-of-00002.bin",
39
- "transformer.encoder.layers.12.post_attention_layernorm.weight": "pytorch_model-00001-of-00002.bin",
40
- "transformer.encoder.layers.12.self_attention.dense.weight": "pytorch_model-00001-of-00002.bin",
41
- "transformer.encoder.layers.12.self_attention.query_key_value.bias": "pytorch_model-00001-of-00002.bin",
42
- "transformer.encoder.layers.12.self_attention.query_key_value.weight": "pytorch_model-00001-of-00002.bin",
43
- "transformer.encoder.layers.13.input_layernorm.weight": "pytorch_model-00001-of-00002.bin",
44
- "transformer.encoder.layers.13.mlp.dense_4h_to_h.weight": "pytorch_model-00001-of-00002.bin",
45
- "transformer.encoder.layers.13.mlp.dense_h_to_4h.weight": "pytorch_model-00001-of-00002.bin",
46
- "transformer.encoder.layers.13.post_attention_layernorm.weight": "pytorch_model-00001-of-00002.bin",
47
- "transformer.encoder.layers.13.self_attention.dense.weight": "pytorch_model-00001-of-00002.bin",
48
- "transformer.encoder.layers.13.self_attention.query_key_value.bias": "pytorch_model-00001-of-00002.bin",
49
- "transformer.encoder.layers.13.self_attention.query_key_value.weight": "pytorch_model-00001-of-00002.bin",
50
- "transformer.encoder.layers.14.input_layernorm.weight": "pytorch_model-00001-of-00002.bin",
51
- "transformer.encoder.layers.14.mlp.dense_4h_to_h.weight": "pytorch_model-00001-of-00002.bin",
52
- "transformer.encoder.layers.14.mlp.dense_h_to_4h.weight": "pytorch_model-00001-of-00002.bin",
53
- "transformer.encoder.layers.14.post_attention_layernorm.weight": "pytorch_model-00001-of-00002.bin",
54
- "transformer.encoder.layers.14.self_attention.dense.weight": "pytorch_model-00001-of-00002.bin",
55
- "transformer.encoder.layers.14.self_attention.query_key_value.bias": "pytorch_model-00001-of-00002.bin",
56
- "transformer.encoder.layers.14.self_attention.query_key_value.weight": "pytorch_model-00001-of-00002.bin",
57
- "transformer.encoder.layers.15.input_layernorm.weight": "pytorch_model-00001-of-00002.bin",
58
- "transformer.encoder.layers.15.mlp.dense_4h_to_h.weight": "pytorch_model-00001-of-00002.bin",
59
- "transformer.encoder.layers.15.mlp.dense_h_to_4h.weight": "pytorch_model-00001-of-00002.bin",
60
- "transformer.encoder.layers.15.post_attention_layernorm.weight": "pytorch_model-00001-of-00002.bin",
61
- "transformer.encoder.layers.15.self_attention.dense.weight": "pytorch_model-00001-of-00002.bin",
62
- "transformer.encoder.layers.15.self_attention.query_key_value.bias": "pytorch_model-00001-of-00002.bin",
63
- "transformer.encoder.layers.15.self_attention.query_key_value.weight": "pytorch_model-00001-of-00002.bin",
64
- "transformer.encoder.layers.16.input_layernorm.weight": "pytorch_model-00001-of-00002.bin",
65
- "transformer.encoder.layers.16.mlp.dense_4h_to_h.weight": "pytorch_model-00001-of-00002.bin",
66
- "transformer.encoder.layers.16.mlp.dense_h_to_4h.weight": "pytorch_model-00001-of-00002.bin",
67
- "transformer.encoder.layers.16.post_attention_layernorm.weight": "pytorch_model-00001-of-00002.bin",
68
- "transformer.encoder.layers.16.self_attention.dense.weight": "pytorch_model-00001-of-00002.bin",
69
- "transformer.encoder.layers.16.self_attention.query_key_value.bias": "pytorch_model-00001-of-00002.bin",
70
- "transformer.encoder.layers.16.self_attention.query_key_value.weight": "pytorch_model-00001-of-00002.bin",
71
- "transformer.encoder.layers.17.input_layernorm.weight": "pytorch_model-00001-of-00002.bin",
72
- "transformer.encoder.layers.17.mlp.dense_4h_to_h.weight": "pytorch_model-00001-of-00002.bin",
73
- "transformer.encoder.layers.17.mlp.dense_h_to_4h.weight": "pytorch_model-00001-of-00002.bin",
74
- "transformer.encoder.layers.17.post_attention_layernorm.weight": "pytorch_model-00001-of-00002.bin",
75
- "transformer.encoder.layers.17.self_attention.dense.weight": "pytorch_model-00001-of-00002.bin",
76
- "transformer.encoder.layers.17.self_attention.query_key_value.bias": "pytorch_model-00001-of-00002.bin",
77
- "transformer.encoder.layers.17.self_attention.query_key_value.weight": "pytorch_model-00001-of-00002.bin",
78
- "transformer.encoder.layers.18.input_layernorm.weight": "pytorch_model-00001-of-00002.bin",
79
- "transformer.encoder.layers.18.mlp.dense_4h_to_h.weight": "pytorch_model-00001-of-00002.bin",
80
- "transformer.encoder.layers.18.mlp.dense_h_to_4h.weight": "pytorch_model-00001-of-00002.bin",
81
- "transformer.encoder.layers.18.post_attention_layernorm.weight": "pytorch_model-00001-of-00002.bin",
82
- "transformer.encoder.layers.18.self_attention.dense.weight": "pytorch_model-00001-of-00002.bin",
83
- "transformer.encoder.layers.18.self_attention.query_key_value.bias": "pytorch_model-00001-of-00002.bin",
84
- "transformer.encoder.layers.18.self_attention.query_key_value.weight": "pytorch_model-00001-of-00002.bin",
85
- "transformer.encoder.layers.19.input_layernorm.weight": "pytorch_model-00001-of-00002.bin",
86
- "transformer.encoder.layers.19.mlp.dense_4h_to_h.weight": "pytorch_model-00001-of-00002.bin",
87
- "transformer.encoder.layers.19.mlp.dense_h_to_4h.weight": "pytorch_model-00001-of-00002.bin",
88
- "transformer.encoder.layers.19.post_attention_layernorm.weight": "pytorch_model-00001-of-00002.bin",
89
- "transformer.encoder.layers.19.self_attention.dense.weight": "pytorch_model-00001-of-00002.bin",
90
- "transformer.encoder.layers.19.self_attention.query_key_value.bias": "pytorch_model-00001-of-00002.bin",
91
- "transformer.encoder.layers.19.self_attention.query_key_value.weight": "pytorch_model-00001-of-00002.bin",
92
- "transformer.encoder.layers.2.input_layernorm.weight": "pytorch_model-00001-of-00002.bin",
93
- "transformer.encoder.layers.2.mlp.dense_4h_to_h.weight": "pytorch_model-00001-of-00002.bin",
94
- "transformer.encoder.layers.2.mlp.dense_h_to_4h.weight": "pytorch_model-00001-of-00002.bin",
95
- "transformer.encoder.layers.2.post_attention_layernorm.weight": "pytorch_model-00001-of-00002.bin",
96
- "transformer.encoder.layers.2.self_attention.dense.weight": "pytorch_model-00001-of-00002.bin",
97
- "transformer.encoder.layers.2.self_attention.query_key_value.bias": "pytorch_model-00001-of-00002.bin",
98
- "transformer.encoder.layers.2.self_attention.query_key_value.weight": "pytorch_model-00001-of-00002.bin",
99
- "transformer.encoder.layers.20.input_layernorm.weight": "pytorch_model-00001-of-00002.bin",
100
- "transformer.encoder.layers.20.mlp.dense_4h_to_h.weight": "pytorch_model-00001-of-00002.bin",
101
- "transformer.encoder.layers.20.mlp.dense_h_to_4h.weight": "pytorch_model-00001-of-00002.bin",
102
- "transformer.encoder.layers.20.post_attention_layernorm.weight": "pytorch_model-00001-of-00002.bin",
103
- "transformer.encoder.layers.20.self_attention.dense.weight": "pytorch_model-00001-of-00002.bin",
104
- "transformer.encoder.layers.20.self_attention.query_key_value.bias": "pytorch_model-00001-of-00002.bin",
105
- "transformer.encoder.layers.20.self_attention.query_key_value.weight": "pytorch_model-00001-of-00002.bin",
106
- "transformer.encoder.layers.21.input_layernorm.weight": "pytorch_model-00001-of-00002.bin",
107
- "transformer.encoder.layers.21.mlp.dense_4h_to_h.weight": "pytorch_model-00001-of-00002.bin",
108
- "transformer.encoder.layers.21.mlp.dense_h_to_4h.weight": "pytorch_model-00001-of-00002.bin",
109
- "transformer.encoder.layers.21.post_attention_layernorm.weight": "pytorch_model-00001-of-00002.bin",
110
- "transformer.encoder.layers.21.self_attention.dense.weight": "pytorch_model-00001-of-00002.bin",
111
- "transformer.encoder.layers.21.self_attention.query_key_value.bias": "pytorch_model-00001-of-00002.bin",
112
- "transformer.encoder.layers.21.self_attention.query_key_value.weight": "pytorch_model-00001-of-00002.bin",
113
- "transformer.encoder.layers.22.input_layernorm.weight": "pytorch_model-00001-of-00002.bin",
114
- "transformer.encoder.layers.22.mlp.dense_4h_to_h.weight": "pytorch_model-00001-of-00002.bin",
115
- "transformer.encoder.layers.22.mlp.dense_h_to_4h.weight": "pytorch_model-00001-of-00002.bin",
116
- "transformer.encoder.layers.22.post_attention_layernorm.weight": "pytorch_model-00001-of-00002.bin",
117
- "transformer.encoder.layers.22.self_attention.dense.weight": "pytorch_model-00001-of-00002.bin",
118
- "transformer.encoder.layers.22.self_attention.query_key_value.bias": "pytorch_model-00001-of-00002.bin",
119
- "transformer.encoder.layers.22.self_attention.query_key_value.weight": "pytorch_model-00001-of-00002.bin",
120
- "transformer.encoder.layers.23.input_layernorm.weight": "pytorch_model-00001-of-00002.bin",
121
- "transformer.encoder.layers.23.mlp.dense_4h_to_h.weight": "pytorch_model-00002-of-00002.bin",
122
- "transformer.encoder.layers.23.mlp.dense_h_to_4h.weight": "pytorch_model-00002-of-00002.bin",
123
- "transformer.encoder.layers.23.post_attention_layernorm.weight": "pytorch_model-00001-of-00002.bin",
124
- "transformer.encoder.layers.23.self_attention.dense.weight": "pytorch_model-00001-of-00002.bin",
125
- "transformer.encoder.layers.23.self_attention.query_key_value.bias": "pytorch_model-00001-of-00002.bin",
126
- "transformer.encoder.layers.23.self_attention.query_key_value.weight": "pytorch_model-00001-of-00002.bin",
127
- "transformer.encoder.layers.24.input_layernorm.weight": "pytorch_model-00002-of-00002.bin",
128
- "transformer.encoder.layers.24.mlp.dense_4h_to_h.weight": "pytorch_model-00002-of-00002.bin",
129
- "transformer.encoder.layers.24.mlp.dense_h_to_4h.weight": "pytorch_model-00002-of-00002.bin",
130
- "transformer.encoder.layers.24.post_attention_layernorm.weight": "pytorch_model-00002-of-00002.bin",
131
- "transformer.encoder.layers.24.self_attention.dense.weight": "pytorch_model-00002-of-00002.bin",
132
- "transformer.encoder.layers.24.self_attention.query_key_value.bias": "pytorch_model-00002-of-00002.bin",
133
- "transformer.encoder.layers.24.self_attention.query_key_value.weight": "pytorch_model-00002-of-00002.bin",
134
- "transformer.encoder.layers.25.input_layernorm.weight": "pytorch_model-00002-of-00002.bin",
135
- "transformer.encoder.layers.25.mlp.dense_4h_to_h.weight": "pytorch_model-00002-of-00002.bin",
136
- "transformer.encoder.layers.25.mlp.dense_h_to_4h.weight": "pytorch_model-00002-of-00002.bin",
137
- "transformer.encoder.layers.25.post_attention_layernorm.weight": "pytorch_model-00002-of-00002.bin",
138
- "transformer.encoder.layers.25.self_attention.dense.weight": "pytorch_model-00002-of-00002.bin",
139
- "transformer.encoder.layers.25.self_attention.query_key_value.bias": "pytorch_model-00002-of-00002.bin",
140
- "transformer.encoder.layers.25.self_attention.query_key_value.weight": "pytorch_model-00002-of-00002.bin",
141
- "transformer.encoder.layers.26.input_layernorm.weight": "pytorch_model-00002-of-00002.bin",
142
- "transformer.encoder.layers.26.mlp.dense_4h_to_h.weight": "pytorch_model-00002-of-00002.bin",
143
- "transformer.encoder.layers.26.mlp.dense_h_to_4h.weight": "pytorch_model-00002-of-00002.bin",
144
- "transformer.encoder.layers.26.post_attention_layernorm.weight": "pytorch_model-00002-of-00002.bin",
145
- "transformer.encoder.layers.26.self_attention.dense.weight": "pytorch_model-00002-of-00002.bin",
146
- "transformer.encoder.layers.26.self_attention.query_key_value.bias": "pytorch_model-00002-of-00002.bin",
147
- "transformer.encoder.layers.26.self_attention.query_key_value.weight": "pytorch_model-00002-of-00002.bin",
148
- "transformer.encoder.layers.27.input_layernorm.weight": "pytorch_model-00002-of-00002.bin",
149
- "transformer.encoder.layers.27.mlp.dense_4h_to_h.weight": "pytorch_model-00002-of-00002.bin",
150
- "transformer.encoder.layers.27.mlp.dense_h_to_4h.weight": "pytorch_model-00002-of-00002.bin",
151
- "transformer.encoder.layers.27.post_attention_layernorm.weight": "pytorch_model-00002-of-00002.bin",
152
- "transformer.encoder.layers.27.self_attention.dense.weight": "pytorch_model-00002-of-00002.bin",
153
- "transformer.encoder.layers.27.self_attention.query_key_value.bias": "pytorch_model-00002-of-00002.bin",
154
- "transformer.encoder.layers.27.self_attention.query_key_value.weight": "pytorch_model-00002-of-00002.bin",
155
- "transformer.encoder.layers.3.input_layernorm.weight": "pytorch_model-00001-of-00002.bin",
156
- "transformer.encoder.layers.3.mlp.dense_4h_to_h.weight": "pytorch_model-00001-of-00002.bin",
157
- "transformer.encoder.layers.3.mlp.dense_h_to_4h.weight": "pytorch_model-00001-of-00002.bin",
158
- "transformer.encoder.layers.3.post_attention_layernorm.weight": "pytorch_model-00001-of-00002.bin",
159
- "transformer.encoder.layers.3.self_attention.dense.weight": "pytorch_model-00001-of-00002.bin",
160
- "transformer.encoder.layers.3.self_attention.query_key_value.bias": "pytorch_model-00001-of-00002.bin",
161
- "transformer.encoder.layers.3.self_attention.query_key_value.weight": "pytorch_model-00001-of-00002.bin",
162
- "transformer.encoder.layers.4.input_layernorm.weight": "pytorch_model-00001-of-00002.bin",
163
- "transformer.encoder.layers.4.mlp.dense_4h_to_h.weight": "pytorch_model-00001-of-00002.bin",
164
- "transformer.encoder.layers.4.mlp.dense_h_to_4h.weight": "pytorch_model-00001-of-00002.bin",
165
- "transformer.encoder.layers.4.post_attention_layernorm.weight": "pytorch_model-00001-of-00002.bin",
166
- "transformer.encoder.layers.4.self_attention.dense.weight": "pytorch_model-00001-of-00002.bin",
167
- "transformer.encoder.layers.4.self_attention.query_key_value.bias": "pytorch_model-00001-of-00002.bin",
168
- "transformer.encoder.layers.4.self_attention.query_key_value.weight": "pytorch_model-00001-of-00002.bin",
169
- "transformer.encoder.layers.5.input_layernorm.weight": "pytorch_model-00001-of-00002.bin",
170
- "transformer.encoder.layers.5.mlp.dense_4h_to_h.weight": "pytorch_model-00001-of-00002.bin",
171
- "transformer.encoder.layers.5.mlp.dense_h_to_4h.weight": "pytorch_model-00001-of-00002.bin",
172
- "transformer.encoder.layers.5.post_attention_layernorm.weight": "pytorch_model-00001-of-00002.bin",
173
- "transformer.encoder.layers.5.self_attention.dense.weight": "pytorch_model-00001-of-00002.bin",
174
- "transformer.encoder.layers.5.self_attention.query_key_value.bias": "pytorch_model-00001-of-00002.bin",
175
- "transformer.encoder.layers.5.self_attention.query_key_value.weight": "pytorch_model-00001-of-00002.bin",
176
- "transformer.encoder.layers.6.input_layernorm.weight": "pytorch_model-00001-of-00002.bin",
177
- "transformer.encoder.layers.6.mlp.dense_4h_to_h.weight": "pytorch_model-00001-of-00002.bin",
178
- "transformer.encoder.layers.6.mlp.dense_h_to_4h.weight": "pytorch_model-00001-of-00002.bin",
179
- "transformer.encoder.layers.6.post_attention_layernorm.weight": "pytorch_model-00001-of-00002.bin",
180
- "transformer.encoder.layers.6.self_attention.dense.weight": "pytorch_model-00001-of-00002.bin",
181
- "transformer.encoder.layers.6.self_attention.query_key_value.bias": "pytorch_model-00001-of-00002.bin",
182
- "transformer.encoder.layers.6.self_attention.query_key_value.weight": "pytorch_model-00001-of-00002.bin",
183
- "transformer.encoder.layers.7.input_layernorm.weight": "pytorch_model-00001-of-00002.bin",
184
- "transformer.encoder.layers.7.mlp.dense_4h_to_h.weight": "pytorch_model-00001-of-00002.bin",
185
- "transformer.encoder.layers.7.mlp.dense_h_to_4h.weight": "pytorch_model-00001-of-00002.bin",
186
- "transformer.encoder.layers.7.post_attention_layernorm.weight": "pytorch_model-00001-of-00002.bin",
187
- "transformer.encoder.layers.7.self_attention.dense.weight": "pytorch_model-00001-of-00002.bin",
188
- "transformer.encoder.layers.7.self_attention.query_key_value.bias": "pytorch_model-00001-of-00002.bin",
189
- "transformer.encoder.layers.7.self_attention.query_key_value.weight": "pytorch_model-00001-of-00002.bin",
190
- "transformer.encoder.layers.8.input_layernorm.weight": "pytorch_model-00001-of-00002.bin",
191
- "transformer.encoder.layers.8.mlp.dense_4h_to_h.weight": "pytorch_model-00001-of-00002.bin",
192
- "transformer.encoder.layers.8.mlp.dense_h_to_4h.weight": "pytorch_model-00001-of-00002.bin",
193
- "transformer.encoder.layers.8.post_attention_layernorm.weight": "pytorch_model-00001-of-00002.bin",
194
- "transformer.encoder.layers.8.self_attention.dense.weight": "pytorch_model-00001-of-00002.bin",
195
- "transformer.encoder.layers.8.self_attention.query_key_value.bias": "pytorch_model-00001-of-00002.bin",
196
- "transformer.encoder.layers.8.self_attention.query_key_value.weight": "pytorch_model-00001-of-00002.bin",
197
- "transformer.encoder.layers.9.input_layernorm.weight": "pytorch_model-00001-of-00002.bin",
198
- "transformer.encoder.layers.9.mlp.dense_4h_to_h.weight": "pytorch_model-00001-of-00002.bin",
199
- "transformer.encoder.layers.9.mlp.dense_h_to_4h.weight": "pytorch_model-00001-of-00002.bin",
200
- "transformer.encoder.layers.9.post_attention_layernorm.weight": "pytorch_model-00001-of-00002.bin",
201
- "transformer.encoder.layers.9.self_attention.dense.weight": "pytorch_model-00001-of-00002.bin",
202
- "transformer.encoder.layers.9.self_attention.query_key_value.bias": "pytorch_model-00001-of-00002.bin",
203
- "transformer.encoder.layers.9.self_attention.query_key_value.weight": "pytorch_model-00001-of-00002.bin",
204
- "transformer.output_layer.weight": "pytorch_model-00002-of-00002.bin",
205
- "transformer.rotary_pos_emb.inv_freq": "pytorch_model-00001-of-00002.bin"
206
- }
207
- }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
quantization.py DELETED
@@ -1,188 +0,0 @@
1
- from torch.nn import Linear
2
- from torch.nn.parameter import Parameter
3
-
4
- import bz2
5
- import torch
6
- import base64
7
- import ctypes
8
- from transformers.utils import logging
9
-
10
- from typing import List
11
- from functools import partial
12
-
13
- logger = logging.get_logger(__name__)
14
-
15
- try:
16
- from cpm_kernels.kernels.base import LazyKernelCModule, KernelFunction, round_up
17
-
18
- class Kernel:
19
- def __init__(self, code: bytes, function_names: List[str]):
20
- self.code = code
21
- self._function_names = function_names
22
- self._cmodule = LazyKernelCModule(self.code)
23
-
24
- for name in self._function_names:
25
- setattr(self, name, KernelFunction(self._cmodule, name))
26
-
27
- quantization_code = "$QlpoOTFBWSZTWU9yuJUAQHN//////////f/n/8/n///n//bt4dTidcVx8X3V9FV/92/v4B7/AD5FBQFAAAChSgKpFCFAFVSigUAAAEKhSgUUqgFBKigqVREQAABQBQIANDTTIGI00BkZBkNGE0A0BkBkGQGRkaNAaAGQNBoGgDIAAYIGTI0DQAQAaGmmQMRpoDIyDIaMJoBoDIDIMgMjI0aA0AMgaDQNAGQAAwQMmRoGgAgA0NNMgYjTQGRkGQ0YTQDQGQGQZAZGRo0BoAZA0GgaAMgABggZMjQNABABoaaZAxGmgMjIMhowmgGgMgMgyAyMjRoDQAyBoNA0AZAADBAyZGgaAAmqU1NEgJqnptU/Sn4jRR6J6epk2pqb1Q/SgAPUGgyNNGjQ2SBpoAZAAGg0NB6mgDIAAAAA2oaApSREBNAARhGiYEaEwU8pvImlP0k2aam1GaGqbFNM1MHpTwmkepmyU9R6nqPKekHqNNPUxNGhp6n6p6QaZ6o9TG1GMqcoV9ly6nRanHlq6zPNbnGZNi6HSug+2nPiZ13XcnFYZW+45W11CumhzYhchOJ2GLLV1OBjBjGf4TptOddTSOcVxhqYZMYwZXZZY00zI1paX5X9J+b+f4e+x43RXSxXPOdquiGpduatGyXneN696M9t4HU2eR5XX/kPhP261NTx3JO1Ow7LyuDmeo9a7d351T1ZxnvnrvYnrXv/hXxPCeuYx2XsNmO003eg9J3Z6U7b23meJ4ri01OdzTk9BNO96brz+qT5nuvvH3ds/G+m/JcG/F2XYuhXlvO+jP7U3XgrzPN/lr8Sf1n6j4j7jZs+s/T0tNaNNYzTs12rxjwztHlnire3Nzc3N1wuBwOBwXBvZfoHpD7rFmR99V5vj3aXza3xdBbXMalubTg/jIv5dfAi54Pdc75j4z412n3Npj3Ld/ENm7a3b/Cod6h/ret1/5vn/C+l+gdslMvgPSLJ8d8q+U66fevYn/tW1chleEtNTGlcHCbLRlq0tHzF5tsbbZZfHjjLgZu42XCuC3NrdjTasZGNzgxPIrGqp7r3p7L2p5XjnpPSmTd5XtzqnB6U87zzg1Ol0zd0zsLszxR6lkxp35u6/teL0L0W922cR7Lu1lpL9CsHirzuM2T+BgsyViT6LHcm0/Vr6U/7LGGyJeqTEjt0PHWhF5mCT7R9mtlDwriYv0Tyr/OxYt6qp5r0mPVT0608TqnqMZaarU2nFwrTzzlrs1ed7z1ux60wyr4ydCaTi3enW8x68x0zU7tXSlcmPSW1mGpWJMg4zmPC2lK96tp0OE80y4MfEvnZj8zGluR6b22ki1Ou9V2nCd9xovcPvcYMZYy0lvN60ScZ45vN6yeCeeXFb1lVjnnCar5fwXwE2bzJ4HI1XVPXfXZMm44GUsMpYsmLB65TuVdm0cl0b+i/wGNN66XjeV7zuPpHcnK/juhhjdfId5jMdE5nN0dGmmm2zZs2cexD5n9p/dY352XsvXHaZNWWsmmS1atjR452nYudzvqv2HMRyvNNnlMcDl3R2+yx2uVrBubTW9icHDVtbNXlZm7jma1rM4VurZZd2y6nUau7ZXZ7bVU+mnoOVxZGMrVmvX60605JwmzGZhhhjTWtaaaMaaGTGmNMZasY0iX8VMUl8eepaIrzGSpemWOQyZORk2bNpjUybMmxqYmknCGCFynutfksaZpjTNMaaatM0xsxcGR0sociNqxNSmhhR1ZJPbsn8qyF0t2qH6iYBclclalbtTTcHTDsPaX6rlnElph2Jyumumtynv2Kk8GI7rsvXbIcJgHJOSaSXnnGaI3m87RtVXJOZ/YtgdTE6Wpha6ZlE8ayXkef1fh602r2WwvfMXtMdLlkfnLFdYYwYso+bWqm7yJqHXZGw2nrS5ZanSYnWlxBxMF1V940K2wdrI7R6OYf7DGGamMmTSbRhlS45xmVOumF1EyPCmHrrN8wwZOOrdNtLeMtzFzDlWnfTBxMk2NaXIZHBYxYLD4w8yju0ao65Vz1OIXoS9dLanwCe1PWrYuWMqf1if1z2k2yYfKJ741PDgno1ZQ8DRqvUny3mNoWTzGO6m1DkrJI8JiR5cSd+vZdGOO8nrMoc5+NDUFsMSXaZJeNlMmGLtJsovOsUp7I9S5VojKxF6bTVEelXqlfJobQr3LozSh2Jk7VcrVMfhXqszGWMzNqGhqZY0OadxkyyMssKugZR0KNFXBHlqwmJgTE/BNVMk6ItJXZMR0H47GpXv/DMOvNkmVuaV1PRfEdxuqc7Hcd+ZV/zTLaRxWk0nl9CdCeM6mn5rstHIBcpiuwmUZXeq81DacHI2rmrZ5SuE5mOZd6LQrZg9mx32TprA8BMo5jKN6yLTCi3WzQaZSuhzTtM1fUTGVpG8Tw+KXI0tjEpiWxtLYynOlktSbVlaI5kxP8TDH8kx50xoxi5KcA4pcja8KWLRlO/Ks6q06ergnvm1ca3Tq8Uw7LTUsmWyctXPWmpitl/uvGcWTGXGuAXDfhqazGmjkxcJW5hMMMMpYsXl2TZYtVOddG3XCarUt6Ptq9CZXSNzyuRzqRZOjsxdBbFVz6OA5HI43r1jityVlVpVkxmOsyaYWE1NTGq1sOVh36mHMcxtSvcy70edG0ZGR3I1Go1GRlV7mWWo1G0ZGRqlvH40l7o4m5xMWLLLYyNjnqc8556mdPqLJ31n/1nWOncxzG1tizrHs/Z+d2vP/B/l8wdJ6rHUn2nbbDq4p6htFtYzMMMTaZis1K5GKzGNmxhmUx2DDlZ/qNnIx41xnaMfCZWYaZWtNLTNW8ND4Fw1MyZOCdM428suKG1ehW8TesOydg7J+YYcD4cYR+8dFK6M4E3HM9ZfRNNL+Sn6rsl4DsrDl2HpPCnfxjGXtbZtYys1ttlyJ4T+BvexjGWRjMszK4Jpc77D3GyuVD7q0+G8m9G+2+rGm7cOR2y7FdtY2XUYx/oNlfRYxhMYyYZkyyg55enna9Kt/FFi6GMMwYwdwxWgxGMLKYmUyGExTKMZkMFhkymKuh0NOBNnBu+23LdwDoZYYzGGMxtORaTU1pjTGWTTGGtMrNWUsyyTTLLG1qy2ZjbK2DBllWqxMtBMaYZQmcE7zvvRcTkclUwdkxTaSdyySt/7fpL+T1v516Ji97fwr5JbLu305zMn5+GMTTZ9F+y7ExwmGVfG44yxn3dLv6l5i+Wth1jCrDq21nW9LqvvDzz3Vf3LLH/O/32TJ/erx3bXftO4eF+G956D952K/An4NfvOpjFjExjevP/UmE0fIoZXx6/w6lX/no3D0bLt+ixjieBM6ksRd0yB4Lt2SwYNE+gd1detlZWUnpiZfGfFaK+4PyCa/v18V8X75pe9fLXzp7l3VjF76vWZmHwGz1IZNWT7b8yddJ4q5kyrVdfru6atWc7bVYztL9Jf4GXvT+Y8m9/YsXP6H018a8D4XVOqvfzqeR+6yZOD8dPv0+U7/q5Pl+2dNb0MjzGVH5p6MNQ7cOWvw62U9aHE8DprDek+McLyvDz+te+9Zhq5+YTruufMcWMabqysTmZVWjKPfnK0wyVcrsuhjZRdLkHNvD72b9abriOSGIxiLixMOoalNPXzy+wT/tf+U6HHONfsz+xe8ufHBdQWWGWLA9if0rsnmrxK5LvRZQeWsTCsrmOYy8VteVfuRfcVTtDLItLIsMYxZLdU/DbtSemxF6Z6Zo5WBXE4tFdCyVMMXMTEMZXVlS6Xec2T4e0tHsRcEuWshcJ2YsNF5rUx1E8ifCq6Z+ZP7qdCeu/aTwFd53l16/o0NOw6O3dLavP4Hbi4RdmuDk6DoYaninC0+o4uZjbJ7Rxeu0/FbuFg+q7DVS6fQe0rZ6NDGUNNU6DEqOaLTicKnYZMnBWruljQxoaS3dZhocDge0bSTyOvdAbG5hxe2xji7E/L55xX13wWNDi6HCekcFxfCPGxY0MXC+s7afWaMdDyjyr+o8Rudm/NabOZvdl274zH4f5XK9z6On1Pe/K5TdPAslg77BjuO6Y3eO7GqvOPG/stknp1leyvLL0Z7bl9I4noMvLkzytLhWYzrOZzLXCORe028rORzOg4N/L0HlMOQ3Pgmnbb6KczlabORpu980q37TBqRu0/p3PO6234Bl03Ynuz+9W7gnsEcmvYaYY3aMYY0wx3pYd+ujsXauWdaY5Xkbtl23fPzFHiDB/QMo0yFjBllYxTQYYyxkrwn7JufwJ/PfgJ+C83X69ni6zvXcnyXabv0ncbLwsceS+RNlyN2mnneJtX0ngYO0+e+0+UnA+Wch3ji8hj5an4h+i6XBySU4n+R0roVcbw5yvHrmr4Yw8Y7x6c+9POPYHI5HI5HI5HI5HGXGww4nE4nrVyOR8XeqPEO7PLOiukYa3Novk5hV4cdtYZLI93e+uxff2jRo0aNGjRo0aNG1bVtW1dy3m83m8+tQ5ZzHw3nObwOu8La9Rc1dtkdS8A3eTk823tnktXWlxN6Oixe06zrN70Isd9jiOgZFq9yfkPqP/SLhN2Myl8jDM43bl1nbcb4cO57jlh8Jow6pzXZdL4dyODTuuhu77FyO27DdwdRxmvO+O+3N2+BdqyTwLHVczDVY4UPE4O66/ZO2cx1LFzVdSXtF7G4HMbrauOHRw6c8FdZ5m9fHZHYZXfTlZquyynSyTTKke6vcffSD9pzPA/G7n7jxPmuhc1DHMynPMrGL6AdewYmwu5ko+UUyTwrMv27rPH1v1nGqd87+p6N6LU8k3NEng53xXyHS97+44OSg/sy/hn+Se6yfYNjW0/uTgP+PvWYzLMmjhcLB/gGpri6H83/84eUXWT6T9Hsv7785z/7z4icpW+zfXypuR7rx/gMdZb1/wC678pcs8/2a3mDitGHxl9mfPlll5MafWWqxk/eYuTDgcNMzDGWLWvsuglNxs53GtN6uWpktlW1tZZYcuinMMWmnNnJydze3b2Y1McBxrBkXw799izLMZZYyy0TkbsGM4p03S2uVu5s/XXUdSdec6smVxZYYGpVmT8A+8ajuEyV5FatkvVru2x6uxGXXbH4A+jvgP4GMYy3iPLXzq/6z65+E005ey+cwMZD3fZcqc6xpjTFjQ0P3U+e++cPYmTIwj0nrK5NPTfl3WvpfLtXDcb2HQMudYOxFXQBor4L4T6vrOauFctYXJQ++NUWmJe5bmx1jDiZS1dTqWxo4GR8jm3fttpmPHppk9PEyv4/y8/sO07XacOmcqc0x2Vi9BvNJvN5oW8x4mOsydpidRxMYJPx06m1bqPzq9KtK8sxXNXFodD/+MYYaJTLwOhc9brCsV18oOR1i4tXChyTkq4lf4y1Ke+9axjDHqs1mfBbMXuP4Hzi+X7t8vzv7bHerrUPgPCxhjre4fXdfLNtNM+Jd+Zdh8xd8wP87uNPoPgv4W7/5P2BuxfsMabNnMnza+54Pdi5U671GPZY8CehX8Voeoo7FHpkeEc6715FwHZrIrUrHaviPUbPZHND+IhczrP6FcYvhOZ0Di/ETt0OI+YwNWR9r7tpf6WDeZKZDB1+z2IthOl1mPyb5FluvEx9h9d0NnM0Y1XPFkWIsk1WotJ0PBMmkvjvQTd0e71tfeV+8r8lQ/tpzpsmxJ+InrI/dj2UajUajVTUajatRqNRtGo1Go1Go4wjeMpZFMVV9CHbofPraLsJ3JpWV2XOoanCuFky4y3PPNxucK2uKC1Lbdb1eo+m5XomN6HfeZsabHLHRX/K+offtNGGmHWctcVcG44MdSqsOLY9VzX+Zxfxn2HPdWTpzWvkrtJ8M5zorrKcquRytJ5N5DZmcaW02l76nWO+BqPXm1A2Ry/0q71dH/mqrqeFjkYxjEXtsX8qubTk67rGycyqsdm4tZx5D6D5hhi0waaWmiaMP81Yjii5qxPlPuU/GfTL1Y5E6Jyfiq63qTa39A4J0sOGDgO9WF9bOXl0XfPRbsY2bPNKPy1YrFYrFYmRhhlTIyMjJWJYZHXuCXI8OoXsvfljGLFicNifpp2XunoPiG1wtx3p1Tah+/DD66OnVtVXP9rKbVxOnL0tR/rHtqB5UDErUVcl11D4qqvjpOcxX7armUNJB3LpW6bxVvD08e8h3odKKvyCFZBdSh2FVcST9xV3n3T8t1j7Kr9qgrqXg+13Pt5U7JCvFXVIV1YG5lRhkVYZJYYDDD4KOIMoHCp26WS8GB7uBh2zIdgq/PKyInjV2STShuoapUdCpX1yTwqq/z1VvET7Kh5nVPkO8YyxjLt2MaaMmWTLQvx3qnzltnXW0p2jxgbEtSny/Osv8Y9pLMXYoHVPAhkVdWVeODhR6q9/Sxe2liwwZWMVvFXfRkeIDxAePUPIrdJ4ey6yquzH+PD/bUOWAu05qVHtFd8rrKHSoeNIOUqrYr3FXyToqfYJgwmJdKpXXOwYYegNNGMzfZPp/t3t/DVs4zjNTN61rRqaWaa4NYbRjTa0tWwy2Y2tGN8ZO8ofNKq4j9SL7I+cSm4/6ovLV5HNXLI0jJidwrtk6ynCaP6Z++GjRlWS3tLeW129Mi9evxU9mtz6s5J3Z7M2ngTgnKvmpomxpaLCzPfmx0JWE+m3NLDDGOX47RctdYYNK5jakdqLkRlI39n590T5zctGSwwZZDJj6kW8XSi6ot2MmWWJ0DUT3nuvebBudScjZ79g8cWJ8av0k+/bE5WKd5MdbFpbDVMxu1DVMmtNZGJvq1mtRbn6M+g/kP0FwDwr7quZs7xosNGpbscyxhhd9TyJyFwbLcxlTasg75vW7TsV5K7ji44XPMMrdoj+Y3rT0Hie62nlYV/pwczzOmdLqLhYkzGMzCZWGMQzGMSsZYY6Di1t4nlJ+Em63mJxrVLxPbYxNEdgc1dU2iOKyoYYWjNrEeHTYybVk0atSa7ehuwsWMWTqn1TrnS6hYsi71d1+s+k+ic70e20fzE/VaTdxT9ZtU4GIXdeNx3X77guYYfpHeTQjaMX6brOu4OY4K7Y2d9mbHarI5ox3p4GpJ2Vd/Tst60f7j999pppjR+Q/Qf8J/VaORs3cji7FfFuN61+ui9s8hix1OCh5KGVV23BPXvZfz3CLyHpix+exi8z/KnCnosY2eunor+cxyPO/xJ0vKey9OvE9VjqaYu0x3Z3jd6o2b1T12D+F8l232lwaaacD5LE8LBxu7WTlbWraWpew8Xexjel3E+wWD4APITdNqR8F3R3T0lunCQ4GaE9R37DxeCYfcHi4xci5ovKfxVs55y2hf+65E/Xdp6jR5nrebTmi5incpkyOjs50JvrZwstbbW6kfuuQw+2mykf/EXNFzxfKTrxew929TR6bWnGL//F3JFOFCQT3K4lQ"
28
-
29
- kernels = Kernel(
30
- bz2.decompress(base64.b64decode(quantization_code)),
31
- [
32
- "int4WeightCompression",
33
- "int4WeightExtractionFloat",
34
- "int4WeightExtractionHalf",
35
- "int8WeightExtractionFloat",
36
- "int8WeightExtractionHalf",
37
- ],
38
- )
39
- except Exception as exception:
40
- kernels = None
41
- logger.warning("Failed to load cpm_kernels:" + str(exception))
42
-
43
-
44
- class W8A16Linear(torch.autograd.Function):
45
- @staticmethod
46
- def forward(ctx, inp: torch.Tensor, quant_w: torch.Tensor, scale_w: torch.Tensor, weight_bit_width):
47
- ctx.inp_shape = inp.size()
48
- ctx.weight_bit_width = weight_bit_width
49
- out_features = quant_w.size(0)
50
- inp = inp.contiguous().view(-1, inp.size(-1))
51
- weight = extract_weight_to_half(quant_w, scale_w, weight_bit_width)
52
- ctx.weight_shape = weight.size()
53
- output = inp.mm(weight.t())
54
- ctx.save_for_backward(inp, quant_w, scale_w)
55
- return output.view(*(ctx.inp_shape[:-1] + (out_features,)))
56
-
57
- @staticmethod
58
- def backward(ctx, grad_output: torch.Tensor):
59
- inp, quant_w, scale_w = ctx.saved_tensors
60
- weight = extract_weight_to_half(quant_w, scale_w, ctx.weight_bit_width)
61
- grad_output = grad_output.contiguous().view(-1, weight.size(0))
62
- grad_input = grad_output.mm(weight)
63
- grad_weight = grad_output.t().mm(inp)
64
- return grad_input.view(ctx.inp_shape), grad_weight.view(ctx.weight_shape), None, None
65
-
66
-
67
- def compress_int4_weight(weight: torch.Tensor): # (n, m)
68
- with torch.cuda.device(weight.device):
69
- n, m = weight.size(0), weight.size(1)
70
- assert m % 2 == 0
71
- m = m // 2
72
- out = torch.empty(n, m, dtype=torch.int8, device="cuda")
73
- stream = torch.cuda.current_stream()
74
-
75
- gridDim = (n, 1, 1)
76
- blockDim = (min(round_up(m, 32), 1024), 1, 1)
77
-
78
- kernels.int4WeightCompression(
79
- gridDim,
80
- blockDim,
81
- 0,
82
- stream,
83
- [ctypes.c_void_p(weight.data_ptr()), ctypes.c_void_p(out.data_ptr()), ctypes.c_int32(n), ctypes.c_int32(m)],
84
- )
85
- return out
86
-
87
-
88
- def extract_weight_to_half(weight: torch.Tensor, scale_list: torch.Tensor, source_bit_width: int):
89
- assert scale_list.dtype in [torch.half, torch.bfloat16]
90
- assert weight.dtype in [torch.int8]
91
- if source_bit_width == 8:
92
- return weight.to(scale_list.dtype) * scale_list[:, None]
93
- elif source_bit_width == 4:
94
- func = (
95
- kernels.int4WeightExtractionHalf if scale_list.dtype == torch.half else kernels.int4WeightExtractionBFloat16
96
- )
97
- else:
98
- assert False, "Unsupported bit-width"
99
-
100
- with torch.cuda.device(weight.device):
101
- n, m = weight.size(0), weight.size(1)
102
- out = torch.empty(n, m * (8 // source_bit_width), dtype=scale_list.dtype, device="cuda")
103
- stream = torch.cuda.current_stream()
104
-
105
- gridDim = (n, 1, 1)
106
- blockDim = (min(round_up(m, 32), 1024), 1, 1)
107
-
108
- func(
109
- gridDim,
110
- blockDim,
111
- 0,
112
- stream,
113
- [
114
- ctypes.c_void_p(weight.data_ptr()),
115
- ctypes.c_void_p(scale_list.data_ptr()),
116
- ctypes.c_void_p(out.data_ptr()),
117
- ctypes.c_int32(n),
118
- ctypes.c_int32(m),
119
- ],
120
- )
121
- return out
122
-
123
-
124
- class QuantizedLinear(torch.nn.Module):
125
- def __init__(self, weight_bit_width: int, weight, bias=None, device="cpu", dtype=None, empty_init=False, *args,
126
- **kwargs):
127
- super().__init__()
128
- self.weight_bit_width = weight_bit_width
129
-
130
- shape = weight.shape
131
-
132
- if weight is None or empty_init:
133
- self.weight = torch.empty(shape[0], shape[1] * weight_bit_width // 8, dtype=torch.int8, device=device)
134
- self.weight_scale = torch.empty(shape[0], dtype=dtype, device=device)
135
- else:
136
- self.weight_scale = weight.abs().max(dim=-1).values / ((2 ** (weight_bit_width - 1)) - 1)
137
- self.weight = torch.round(weight / self.weight_scale[:, None]).to(torch.int8)
138
- if weight_bit_width == 4:
139
- self.weight = compress_int4_weight(self.weight)
140
-
141
- self.weight = Parameter(self.weight.to(device), requires_grad=False)
142
- self.weight_scale = Parameter(self.weight_scale.to(device), requires_grad=False)
143
- self.bias = Parameter(bias.to(device), requires_grad=False) if bias is not None else None
144
-
145
- def forward(self, input):
146
- output = W8A16Linear.apply(input, self.weight, self.weight_scale, self.weight_bit_width)
147
- if self.bias is not None:
148
- output = output + self.bias
149
- return output
150
-
151
-
152
- def quantize(model, weight_bit_width, empty_init=False, device=None):
153
- """Replace fp16 linear with quantized linear"""
154
- for layer in model.layers:
155
- layer.self_attention.query_key_value = QuantizedLinear(
156
- weight_bit_width=weight_bit_width,
157
- weight=layer.self_attention.query_key_value.weight.to(torch.cuda.current_device()),
158
- bias=layer.self_attention.query_key_value.bias,
159
- dtype=layer.self_attention.query_key_value.weight.dtype,
160
- device=layer.self_attention.query_key_value.weight.device if device is None else device,
161
- empty_init=empty_init
162
- )
163
- layer.self_attention.dense = QuantizedLinear(
164
- weight_bit_width=weight_bit_width,
165
- weight=layer.self_attention.dense.weight.to(torch.cuda.current_device()),
166
- bias=layer.self_attention.dense.bias,
167
- dtype=layer.self_attention.dense.weight.dtype,
168
- device=layer.self_attention.dense.weight.device if device is None else device,
169
- empty_init=empty_init
170
- )
171
- layer.mlp.dense_h_to_4h = QuantizedLinear(
172
- weight_bit_width=weight_bit_width,
173
- weight=layer.mlp.dense_h_to_4h.weight.to(torch.cuda.current_device()),
174
- bias=layer.mlp.dense_h_to_4h.bias,
175
- dtype=layer.mlp.dense_h_to_4h.weight.dtype,
176
- device=layer.mlp.dense_h_to_4h.weight.device if device is None else device,
177
- empty_init=empty_init
178
- )
179
- layer.mlp.dense_4h_to_h = QuantizedLinear(
180
- weight_bit_width=weight_bit_width,
181
- weight=layer.mlp.dense_4h_to_h.weight.to(torch.cuda.current_device()),
182
- bias=layer.mlp.dense_4h_to_h.bias,
183
- dtype=layer.mlp.dense_4h_to_h.weight.dtype,
184
- device=layer.mlp.dense_4h_to_h.weight.device if device is None else device,
185
- empty_init=empty_init
186
- )
187
-
188
- return model
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
special_tokens_map.json DELETED
@@ -1 +0,0 @@
1
- {}
 
 
tokenization_chatglm.py DELETED
@@ -1,277 +0,0 @@
1
- import json
2
- import os
3
- import re
4
- from typing import List, Optional, Union, Dict
5
- from sentencepiece import SentencePieceProcessor
6
- from transformers import PreTrainedTokenizer
7
- from transformers.utils import logging, PaddingStrategy
8
- from transformers.tokenization_utils_base import EncodedInput, BatchEncoding
9
-
10
-
11
- class SPTokenizer:
12
- def __init__(self, model_path: str):
13
- # reload tokenizer
14
- assert os.path.isfile(model_path), model_path
15
- self.sp_model = SentencePieceProcessor(model_file=model_path)
16
-
17
- # BOS / EOS token IDs
18
- self.n_words: int = self.sp_model.vocab_size()
19
- self.bos_id: int = self.sp_model.bos_id()
20
- self.eos_id: int = self.sp_model.eos_id()
21
- self.pad_id: int = self.sp_model.unk_id()
22
- assert self.sp_model.vocab_size() == self.sp_model.get_piece_size()
23
-
24
- special_tokens = ["[MASK]", "[gMASK]", "[sMASK]", "sop", "eop"]
25
- self.special_tokens = {}
26
- self.index_special_tokens = {}
27
- for token in special_tokens:
28
- self.special_tokens[token] = self.n_words
29
- self.index_special_tokens[self.n_words] = token
30
- self.n_words += 1
31
-
32
- def tokenize(self, s: str):
33
- return self.sp_model.EncodeAsPieces(s)
34
-
35
-
36
- def encode(self, s: str, bos: bool = False, eos: bool = False) -> List[int]:
37
- assert type(s) is str
38
- t = self.sp_model.encode(s)
39
- if bos:
40
- t = [self.bos_id] + t
41
- if eos:
42
- t = t + [self.eos_id]
43
- return t
44
-
45
- def decode(self, t: List[int]) -> str:
46
- text, buffer = "", []
47
- for token in t:
48
- if token in self.index_special_tokens:
49
- if buffer:
50
- text += self.sp_model.decode(buffer)
51
- buffer = []
52
- text += self.index_special_tokens[token]
53
- else:
54
- buffer.append(token)
55
- if buffer:
56
- text += self.sp_model.decode(buffer)
57
- return text
58
-
59
- def decode_tokens(self, tokens: List[str]) -> str:
60
- text = self.sp_model.DecodePieces(tokens)
61
- return text
62
-
63
- def convert_token_to_id(self, token):
64
- """ Converts a token (str) in an id using the vocab. """
65
- if token in self.special_tokens:
66
- return self.special_tokens[token]
67
- return self.sp_model.PieceToId(token)
68
-
69
- def convert_id_to_token(self, index):
70
- """Converts an index (integer) in a token (str) using the vocab."""
71
- if index in self.index_special_tokens:
72
- return self.index_special_tokens[index]
73
- if index in [self.eos_id, self.bos_id, self.pad_id] or index < 0 or index > self.sp_model.vocab_size():
74
- return ""
75
- return self.sp_model.IdToPiece(index)
76
-
77
-
78
- class ChatGLMTokenizer(PreTrainedTokenizer):
79
- vocab_files_names = {"vocab_file": "tokenizer.model"}
80
-
81
- model_input_names = ["input_ids", "attention_mask", "position_ids"]
82
-
83
- def __init__(self, vocab_file, padding_side="left", clean_up_tokenization_spaces=False, encode_special_tokens=False,
84
- **kwargs):
85
- self.name = "GLMTokenizer"
86
-
87
- self.vocab_file = vocab_file
88
- self.tokenizer = SPTokenizer(vocab_file)
89
- self.special_tokens = {
90
- "<bos>": self.tokenizer.bos_id,
91
- "<eos>": self.tokenizer.eos_id,
92
- "<pad>": self.tokenizer.pad_id
93
- }
94
- self.encode_special_tokens = encode_special_tokens
95
- super().__init__(padding_side=padding_side, clean_up_tokenization_spaces=clean_up_tokenization_spaces,
96
- encode_special_tokens=encode_special_tokens,
97
- **kwargs)
98
-
99
- def get_command(self, token):
100
- if token in self.special_tokens:
101
- return self.special_tokens[token]
102
- assert token in self.tokenizer.special_tokens, f"{token} is not a special token for {self.name}"
103
- return self.tokenizer.special_tokens[token]
104
-
105
- @property
106
- def unk_token(self) -> str:
107
- return "<unk>"
108
-
109
- @property
110
- def pad_token(self) -> str:
111
- return "<unk>"
112
-
113
- @property
114
- def pad_token_id(self):
115
- return self.get_command("<pad>")
116
-
117
- @property
118
- def eos_token(self) -> str:
119
- return "</s>"
120
-
121
- @property
122
- def eos_token_id(self):
123
- return self.get_command("<eos>")
124
-
125
- @property
126
- def vocab_size(self):
127
- return self.tokenizer.n_words
128
-
129
- def get_vocab(self):
130
- """ Returns vocab as a dict """
131
- vocab = {self._convert_id_to_token(i): i for i in range(self.vocab_size)}
132
- vocab.update(self.added_tokens_encoder)
133
- return vocab
134
-
135
- def _tokenize(self, text, **kwargs):
136
- return self.tokenizer.tokenize(text)
137
-
138
- def _convert_token_to_id(self, token):
139
- """ Converts a token (str) in an id using the vocab. """
140
- return self.tokenizer.convert_token_to_id(token)
141
-
142
- def _convert_id_to_token(self, index):
143
- """Converts an index (integer) in a token (str) using the vocab."""
144
- return self.tokenizer.convert_id_to_token(index)
145
-
146
- def convert_tokens_to_string(self, tokens: List[str]) -> str:
147
- return self.tokenizer.decode_tokens(tokens)
148
-
149
- def save_vocabulary(self, save_directory, filename_prefix=None):
150
- """
151
- Save the vocabulary and special tokens file to a directory.
152
-
153
- Args:
154
- save_directory (`str`):
155
- The directory in which to save the vocabulary.
156
- filename_prefix (`str`, *optional*):
157
- An optional prefix to add to the named of the saved files.
158
-
159
- Returns:
160
- `Tuple(str)`: Paths to the files saved.
161
- """
162
- if os.path.isdir(save_directory):
163
- vocab_file = os.path.join(
164
- save_directory, self.vocab_files_names["vocab_file"]
165
- )
166
- else:
167
- vocab_file = save_directory
168
-
169
- with open(self.vocab_file, 'rb') as fin:
170
- proto_str = fin.read()
171
-
172
- with open(vocab_file, "wb") as writer:
173
- writer.write(proto_str)
174
-
175
- return (vocab_file,)
176
-
177
- def get_prefix_tokens(self):
178
- prefix_tokens = [self.get_command("[gMASK]"), self.get_command("sop")]
179
- return prefix_tokens
180
-
181
- def build_prompt(self, query, history=None):
182
- if history is None:
183
- history = []
184
- prompt = ""
185
- for i, (old_query, response) in enumerate(history):
186
- prompt += "[Round {}]\n\n问:{}\n\n答:{}\n\n".format(i + 1, old_query, response)
187
- prompt += "[Round {}]\n\n问:{}\n\n答:".format(len(history) + 1, query)
188
- return prompt
189
-
190
-
191
- def build_inputs_with_special_tokens(
192
- self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
193
- ) -> List[int]:
194
- """
195
- Build model inputs from a sequence or a pair of sequence for sequence classification tasks by concatenating and
196
- adding special tokens. A BERT sequence has the following format:
197
-
198
- - single sequence: `[CLS] X [SEP]`
199
- - pair of sequences: `[CLS] A [SEP] B [SEP]`
200
-
201
- Args:
202
- token_ids_0 (`List[int]`):
203
- List of IDs to which the special tokens will be added.
204
- token_ids_1 (`List[int]`, *optional*):
205
- Optional second list of IDs for sequence pairs.
206
-
207
- Returns:
208
- `List[int]`: List of [input IDs](../glossary#input-ids) with the appropriate special tokens.
209
- """
210
- prefix_tokens = self.get_prefix_tokens()
211
- token_ids_0 = prefix_tokens + token_ids_0
212
- if token_ids_1 is not None:
213
- token_ids_0 = token_ids_0 + token_ids_1 + [self.get_command("<eos>")]
214
- return token_ids_0
215
-
216
- def _pad(
217
- self,
218
- encoded_inputs: Union[Dict[str, EncodedInput], BatchEncoding],
219
- max_length: Optional[int] = None,
220
- padding_strategy: PaddingStrategy = PaddingStrategy.DO_NOT_PAD,
221
- pad_to_multiple_of: Optional[int] = None,
222
- return_attention_mask: Optional[bool] = None,
223
- ) -> dict:
224
- """
225
- Pad encoded inputs (on left/right and up to predefined length or max length in the batch)
226
-
227
- Args:
228
- encoded_inputs:
229
- Dictionary of tokenized inputs (`List[int]`) or batch of tokenized inputs (`List[List[int]]`).
230
- max_length: maximum length of the returned list and optionally padding length (see below).
231
- Will truncate by taking into account the special tokens.
232
- padding_strategy: PaddingStrategy to use for padding.
233
-
234
- - PaddingStrategy.LONGEST Pad to the longest sequence in the batch
235
- - PaddingStrategy.MAX_LENGTH: Pad to the max length (default)
236
- - PaddingStrategy.DO_NOT_PAD: Do not pad
237
- The tokenizer padding sides are defined in self.padding_side:
238
-
239
- - 'left': pads on the left of the sequences
240
- - 'right': pads on the right of the sequences
241
- pad_to_multiple_of: (optional) Integer if set will pad the sequence to a multiple of the provided value.
242
- This is especially useful to enable the use of Tensor Core on NVIDIA hardware with compute capability
243
- `>= 7.5` (Volta).
244
- return_attention_mask:
245
- (optional) Set to False to avoid returning attention mask (default: set to model specifics)
246
- """
247
- # Load from model defaults
248
- assert self.padding_side == "left"
249
-
250
- required_input = encoded_inputs[self.model_input_names[0]]
251
- seq_length = len(required_input)
252
-
253
- if padding_strategy == PaddingStrategy.LONGEST:
254
- max_length = len(required_input)
255
-
256
- if max_length is not None and pad_to_multiple_of is not None and (max_length % pad_to_multiple_of != 0):
257
- max_length = ((max_length // pad_to_multiple_of) + 1) * pad_to_multiple_of
258
-
259
- needs_to_be_padded = padding_strategy != PaddingStrategy.DO_NOT_PAD and len(required_input) != max_length
260
-
261
- # Initialize attention mask if not present.
262
- if "attention_mask" not in encoded_inputs:
263
- encoded_inputs["attention_mask"] = [1] * seq_length
264
-
265
- if "position_ids" not in encoded_inputs:
266
- encoded_inputs["position_ids"] = list(range(seq_length))
267
-
268
- if needs_to_be_padded:
269
- difference = max_length - len(required_input)
270
-
271
- if "attention_mask" in encoded_inputs:
272
- encoded_inputs["attention_mask"] = [0] * difference + encoded_inputs["attention_mask"]
273
- if "position_ids" in encoded_inputs:
274
- encoded_inputs["position_ids"] = [0] * difference + encoded_inputs["position_ids"]
275
- encoded_inputs[self.model_input_names[0]] = [self.pad_token_id] * difference + required_input
276
-
277
- return encoded_inputs
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
tokenizer.model DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:e7dc4c393423b76e4373e5157ddc34803a0189ba96b21ddbb40269d31468a6f2
3
- size 1018370
 
 
 
 
tokenizer_config.json DELETED
@@ -1,14 +0,0 @@
1
- {
2
- "auto_map": {
3
- "AutoTokenizer": [
4
- "tokenization_chatglm.ChatGLMTokenizer",
5
- null
6
- ]
7
- },
8
- "clean_up_tokenization_spaces": true,
9
- "do_lower_case": false,
10
- "model_max_length": 1000000000000000019884624838656,
11
- "padding_side": "left",
12
- "remove_space": false,
13
- "tokenizer_class": "ChatGLMTokenizer"
14
- }