Merge branch 'main' into dev_pt
- README.md +4 -0
- modeling_chatglm.py +154 -63
- tokenization_chatglm.py +1 -1
README.md
CHANGED
@@ -11,6 +11,8 @@ tags:
## 介绍
ChatGLM-6B 是一个开源的、支持中英双语问答的对话语言模型,基于 [General Language Model (GLM)](https://github.com/THUDM/GLM) 架构,具有 62 亿参数。结合模型量化技术,用户可以在消费级的显卡上进行本地部署(INT4 量化级别下最低只需 6GB 显存)。ChatGLM-6B 使用了和 [ChatGLM](https://chatglm.cn) 相同的技术,针对中文问答和对话进行了优化。经过约 1T 标识符的中英双语训练,辅以监督微调、反馈自助、人类反馈强化学习等技术的加持,62 亿参数的 ChatGLM-6B 已经能生成相当符合人类偏好的回答。

+ChatGLM-6B is an open bilingual language model based on [General Language Model (GLM)](https://github.com/THUDM/GLM) framework, with 6.2 billion parameters. With the quantization technique, users can deploy locally on consumer-grade graphics cards (only 6GB of GPU memory is required at the INT4 quantization level). ChatGLM-6B uses technology similar to ChatGPT, optimized for Chinese QA and dialogue. The model is trained for about 1T tokens of Chinese and English corpus, supplemented by supervised fine-tuning, feedback bootstrap, and reinforcement learning with human feedback. With only about 6.2 billion parameters, the model is able to generate answers that are in line with human preference.
+
## 软件依赖

```shell
@@ -44,6 +46,8 @@ pip install protobuf==3.20.0 transformers==4.26.1 icetk cpm_kernels

关于更多的使用说明,包括如何运行命令行和网页版本的 DEMO,以及使用模型量化以节省显存,请参考我们的 [Github Repo](https://github.com/THUDM/ChatGLM-6B)。

+For more instructions, including how to run CLI and web demos, and model quantization, please refer to our [Github Repo](https://github.com/THUDM/ChatGLM-6B).
+
## 协议

本仓库的代码依照 [Apache-2.0](LICENSE) 协议开源,ChatGLM-6B 模型的权重的使用则需要遵循 [Model License](MODEL_LICENSE)。
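The installation line above covers the Python dependencies; for actually loading the model, the Github Repo documents the standard `transformers` remote-code pattern. A minimal sketch of that pattern (the model id `THUDM/chatglm-6b` and the `.half().cuda()` placement follow the repo's documentation; treat it as illustrative rather than authoritative):

```python
# Minimal usage sketch; see the Github Repo for the authoritative demo code.
from transformers import AutoModel, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True)
model = AutoModel.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True).half().cuda()
model = model.eval()

# chat() returns the reply together with the updated dialogue history.
response, history = model.chat(tokenizer, "你好", history=[])
print(response)
```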
modeling_chatglm.py
CHANGED
@@ -3,7 +3,9 @@
import math
import copy
import os
+import warnings
+import re
+import sys

import torch
import torch.utils.checkpoint
@@ -11,7 +13,7 @@ import torch.nn.functional as F
from torch import nn
from torch.nn import CrossEntropyLoss, LayerNorm
from torch.nn.utils import skip_init
+from typing import Optional, Tuple, Union, List, Callable

from transformers.utils import (
    add_code_sample_docstrings,
@@ -26,15 +28,17 @@ from transformers.modeling_outputs import (
from transformers.modeling_utils import PreTrainedModel
from transformers.utils import logging
from transformers.generation.logits_process import LogitsProcessor
+from transformers.generation.utils import LogitsProcessorList, StoppingCriteriaList, GenerationConfig

from .configuration_chatglm import ChatGLMConfig

# flags required to enable jit fusion kernels
+
+if sys.platform != 'darwin':
+    torch._C._jit_set_profiling_mode(False)
+    torch._C._jit_set_profiling_executor(False)
+    torch._C._jit_override_can_fuse_on_cpu(True)
+    torch._C._jit_override_can_fuse_on_gpu(True)

logger = logging.get_logger(__name__)

@@ -294,7 +298,7 @@ def attention_fn(
    if not (attention_mask == 0).all():
        # if auto-regressive, skip
        attention_scores.masked_fill_(attention_mask, -10000.0)
+    dtype = attention_scores.dtype
    attention_scores = attention_scores.float()
    attention_scores = attention_scores * query_key_layer_scaling_coeff

@@ -814,8 +818,8 @@ class ChatGLMModel(ChatGLMPreTrainedModel):
        return past_key_values

    @staticmethod
+    def get_masks(self, seq, device):
+        context_length = seq.index(self.config.bos_token_id) + 1

        attention_mask = torch.ones((1, len(seq), len(seq)), device=device)
        attention_mask.tril_()
@@ -826,9 +830,9 @@ class ChatGLMModel(ChatGLMPreTrainedModel):
        return attention_mask

    def get_position_ids(self, seq, mask_position, device, gmask=False):
+        context_length = len(seq)
        if self.position_encoding_2d:
+            seq_length = seq.index(self.config.bos_token_id)
            position_ids = torch.arange(context_length, dtype=torch.long, device=device)
            if not gmask:
                position_ids[seq_length:] = mask_position
@@ -886,14 +890,8 @@ class ChatGLMModel(ChatGLMPreTrainedModel):
                past_key_values = self.get_prompt(batch_size=input_ids.shape[0], device=input_ids.device)
            else:
                past_key_values = tuple([None] * len(self.layers))
        seq = input_ids[0].tolist()

        if attention_mask is None:
            attention_mask = self.get_masks(
                seq=seq,
@@ -906,6 +904,11 @@ class ChatGLMModel(ChatGLMPreTrainedModel):
            attention_mask = torch.cat((prefix_attention_mask, attention_mask), dim=3)

        if position_ids is None:
+            MASK, gMASK = 150000, 150001
+            mask_token = MASK if MASK in input_ids else gMASK
+            use_gmask = False if MASK in input_ids else gMASK
+
+            mask_position = seq.index(mask_token)
            position_ids = self.get_position_ids(
                seq=seq,
                mask_position=mask_position,
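In the ChatGLMModel hunks above, hard-coded sequence handling is replaced by lookups against `config.bos_token_id`, and the [MASK]/[gMASK] bookkeeping moves into the `position_ids is None` branch of `forward`. A toy sketch of that index arithmetic (the special ids 150000/150001 come from the diff; the BOS id value and the example token ids are assumptions used only for illustration):

```python
# Toy illustration of the index bookkeeping above; the BOS id value is assumed.
MASK, gMASK = 150000, 150001            # special token ids taken from the diff
bos_token_id = 150004                   # assumed stand-in for config.bos_token_id

seq = [20005, 94874, 150001, 150004]    # hypothetical prompt: text tokens, gMASK, BOS
mask_token = MASK if MASK in seq else gMASK
mask_position = seq.index(mask_token)           # 2: index of the [gMASK] token
context_length = seq.index(bos_token_id) + 1    # 4: prompt length including BOS (get_masks)
seq_length = seq.index(bos_token_id)            # 3: split point for the 2D position ids
print(mask_position, context_length, seq_length)
```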
@@ -1009,7 +1012,7 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
        attention_mask = (attention_mask < 0.5).bool()

        if self.position_encoding_2d:
+            seq_length = seq.index(self.config.bos_token_id)
            position_ids = torch.arange(context_length, dtype=torch.long, device=device)
            if not gmask:
                position_ids[seq_length:] = mask_position
@@ -1047,7 +1050,7 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):

        # only last token for input_ids if past is not None
        if past is not None or past_key_values is not None:
+            context_length = seq.index(self.config.bos_token_id)
            last_token = input_ids[:, -1].unsqueeze(-1)
            if self.position_encoding_2d:
                position_ids = torch.tensor([[[mask_position], [len(seq) - context_length]]], dtype=torch.long,
@@ -1155,6 +1158,21 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
            for layer_past in past
        )

+    def process_response(self, response):
+        response = response.strip()
+        response = response.replace("[[训练时间]]", "2023年")
+        punkts = [
+            [",", ","],
+            ["!", "!"],
+            [":", ":"],
+            [";", ";"],
+            ["\?", "?"],
+        ]
+        for item in punkts:
+            response = re.sub(r"([\u4e00-\u9fff])%s" % item[0], r"\1%s" % item[1], response)
+            response = re.sub(r"%s([\u4e00-\u9fff])" % item[0], r"%s\1" % item[1], response)
+        return response
+
    @torch.no_grad()
    def chat(self, tokenizer, query: str, history: List[Tuple[str, str]] = None, max_length: int = 2048, num_beams=1,
             do_sample=True, top_p=0.7, temperature=0.95, logits_processor=None, **kwargs):
@@ -1175,66 +1193,139 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
        input_ids = tokenizer([prompt], return_tensors="pt", padding=True)
        input_ids = input_ids.to(self.device)
        outputs = self.generate(**input_ids, **gen_kwargs)
+        outputs = outputs.tolist()[0][len(input_ids["input_ids"][0]):]
        response = tokenizer.decode(outputs)
+        response = self.process_response(response)
        history = history + [(query, response)]
        return response, history

    @torch.no_grad()
+    def stream_chat(self, tokenizer, query: str, history: List[Tuple[str, str]] = None, max_length: int = 2048,
+                    do_sample=True, top_p=0.7, temperature=0.95, logits_processor=None, **kwargs):
+        if history is None:
+            history = []
+        if logits_processor is None:
+            logits_processor = LogitsProcessorList()
+        logits_processor.append(InvalidScoreLogitsProcessor())
+        gen_kwargs = {"max_length": max_length, "do_sample": do_sample, "top_p": top_p,
+                      "temperature": temperature, "logits_processor": logits_processor, **kwargs}
+        if not history:
+            prompt = query
+        else:
+            prompt = ""
+            for i, (old_query, response) in enumerate(history):
+                prompt += "[Round {}]\n问:{}\n答:{}\n".format(i, old_query, response)
+            prompt += "[Round {}]\n问:{}\n答:".format(len(history), query)
+        input_ids = tokenizer([prompt], return_tensors="pt", padding=True)
+        input_ids = input_ids.to(self.device)
+        for outputs in self.stream_generate(**input_ids, **gen_kwargs):
+            outputs = outputs.tolist()[0][len(input_ids["input_ids"][0]):]
+            response = tokenizer.decode(outputs)
+            response = self.process_response(response)
+            new_history = history + [(query, response)]
+            yield response, new_history
+
+    @torch.no_grad()
+    def stream_generate(
            self,
+            input_ids,
+            generation_config: Optional[GenerationConfig] = None,
+            logits_processor: Optional[LogitsProcessorList] = None,
+            stopping_criteria: Optional[StoppingCriteriaList] = None,
+            prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], List[int]]] = None,
            **kwargs,
    ):
+        batch_size, input_ids_seq_length = input_ids.shape[0], input_ids.shape[-1]
+
+        if generation_config is None:
+            generation_config = self.generation_config
+        generation_config = copy.deepcopy(generation_config)
+        model_kwargs = generation_config.update(**kwargs)
+        bos_token_id, eos_token_id = generation_config.bos_token_id, generation_config.eos_token_id
+
+        if isinstance(eos_token_id, int):
+            eos_token_id = [eos_token_id]
+
+        has_default_max_length = kwargs.get("max_length") is None and generation_config.max_length is not None
+        if has_default_max_length and generation_config.max_new_tokens is None:
+            warnings.warn(
+                f"Using `max_length`'s default ({generation_config.max_length}) to control the generation length. "
+                "This behaviour is deprecated and will be removed from the config in v5 of Transformers -- we"
+                " recommend using `max_new_tokens` to control the maximum length of the generation.",
+                UserWarning,
+            )
+        elif generation_config.max_new_tokens is not None:
+            generation_config.max_length = generation_config.max_new_tokens + input_ids_seq_length
+            if not has_default_max_length:
+                logger.warn(
+                    f"Both `max_new_tokens` (={generation_config.max_new_tokens}) and `max_length`(="
+                    f"{generation_config.max_length}) seem to have been set. `max_new_tokens` will take precedence. "
+                    "Please refer to the documentation for more information. "
+                    "(https://huggingface.co/docs/transformers/main/en/main_classes/text_generation)",
+                    UserWarning,
+                )

+        if input_ids_seq_length >= generation_config.max_length:
+            input_ids_string = "decoder_input_ids" if self.config.is_encoder_decoder else "input_ids"
+            logger.warning(
+                f"Input length of {input_ids_string} is {input_ids_seq_length}, but `max_length` is set to"
+                f" {generation_config.max_length}. This can lead to unexpected behavior. You should consider"
+                " increasing `max_new_tokens`."
+            )

+        # 2. Set generation parameters if not already defined
+        logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
+        stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()

+        logits_processor = self._get_logits_processor(
+            generation_config=generation_config,
+            input_ids_seq_length=input_ids_seq_length,
+            encoder_input_ids=input_ids,
+            prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
+            logits_processor=logits_processor,
+        )

+        stopping_criteria = self._get_stopping_criteria(
+            generation_config=generation_config, stopping_criteria=stopping_criteria
+        )
+        logits_warper = self._get_logits_warper(generation_config)

+        unfinished_sequences = input_ids.new(input_ids.shape[0]).fill_(1)
+        scores = None
        while True:
+            model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
+            # forward pass to get next token
+            outputs = self(
+                **model_inputs,
+                return_dict=True,
+                output_attentions=False,
+                output_hidden_states=False,
+            )

+            next_token_logits = outputs.logits[:, -1, :]

+            # pre-process distribution
+            next_token_scores = logits_processor(input_ids, next_token_logits)
+            next_token_scores = logits_warper(input_ids, next_token_scores)

+            # sample
+            probs = nn.functional.softmax(next_token_scores, dim=-1)
+            if generation_config.do_sample:
+                next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)
+            else:
+                next_tokens = torch.argmax(probs, dim=-1)
+
+            # update generated ids, model inputs, and length for next step
+            input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
+            model_kwargs = self._update_model_kwargs_for_generation(
+                outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder
+            )
+            unfinished_sequences = unfinished_sequences.mul((sum(next_tokens != i for i in eos_token_id)).long())
+
+            # stop when each sentence is finished, or if we exceed the maximum length
+            if unfinished_sequences.max() == 0 or stopping_criteria(input_ids, scores):
+                break
+            yield input_ids

    def quantize(self, bits: int):
        from .quantization import quantize
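The headline addition in this file is the streaming API: `stream_chat` builds the same `[Round i]` prompt as `chat`, but drives `stream_generate`, a trimmed-down sampling loop that yields the growing `input_ids` after every generated token. A minimal consumption sketch, assuming `model` and `tokenizer` are loaded as in the earlier README sketch:

```python
# Illustrative only: consuming the stream_chat() generator added in this commit.
# Assumes `model` and `tokenizer` were loaded as in the earlier sketch.
history = []
for response, history in model.stream_chat(tokenizer, "你好", history=history):
    pass  # each yield carries the partially decoded reply and updated history; redraw it in a real UI
print(response)  # final answer once generation stops
```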
tokenization_chatglm.py
CHANGED
@@ -299,7 +299,7 @@ class ChatGLMTokenizer(PreTrainedTokenizer):
        """
        if os.path.isdir(save_directory):
            vocab_file = os.path.join(
+                save_directory, self.vocab_files_names["vocab_file"]
            )
        else:
            vocab_file = save_directory
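The one-line tokenizer fix makes `save_vocabulary` join the target directory with the tokenizer's declared vocab file name rather than using the bare directory path. A hedged sanity check, assuming a `tokenizer` loaded as in the earlier sketch (for this tokenizer, `vocab_files_names["vocab_file"]` is the sentencepiece model file, e.g. `ice_text.model`):

```python
import os

# Illustrative check of the save_vocabulary fix; the directory name is arbitrary.
save_directory = "./chatglm-6b-local"
os.makedirs(save_directory, exist_ok=True)
tokenizer.save_pretrained(save_directory)
# Expect the sentencepiece vocab file alongside the tokenizer config files.
print(sorted(os.listdir(save_directory)))
```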