zxdu20 commited on
Commit
fbda120
·
2 Parent(s): 812f43f 096f3de

Merge branch 'main' into dev_pt

Browse files
Files changed (3) hide show
  1. README.md +4 -0
  2. modeling_chatglm.py +154 -63
  3. tokenization_chatglm.py +1 -1
README.md CHANGED
@@ -11,6 +11,8 @@ tags:
11
  ## 介绍
12
  ChatGLM-6B 是一个开源的、支持中英双语问答的对话语言模型,基于 [General Language Model (GLM)](https://github.com/THUDM/GLM) 架构,具有 62 亿参数。结合模型量化技术,用户可以在消费级的显卡上进行本地部署(INT4 量化级别下最低只需 6GB 显存)。ChatGLM-6B 使用了和 [ChatGLM](https://chatglm.cn) 相同的技术,针对中文问答和对话进行了优化。经过约 1T 标识符的中英双语训练,辅以监督微调、反馈自助、人类反馈强化学习等技术的加持,62 亿参数的 ChatGLM-6B 已经能生成相当符合人类偏好的回答。
13
 
 
 
14
  ## 软件依赖
15
 
16
  ```shell
@@ -44,6 +46,8 @@ pip install protobuf==3.20.0 transformers==4.26.1 icetk cpm_kernels
44
 
45
  关于更多的使用说明,包括如何运行命令行和网页版本的 DEMO,以及使用模型量化以节省显存,请参考我们的 [Github Repo](https://github.com/THUDM/ChatGLM-6B)。
46
 
 
 
47
  ## 协议
48
 
49
  本仓库的代码依照 [Apache-2.0](LICENSE) 协议开源,ChatGLM-6B 模型的权重的使用则需要遵循 [Model License](MODEL_LICENSE)。
 
11
  ## 介绍
12
  ChatGLM-6B 是一个开源的、支持中英双语问答的对话语言模型,基于 [General Language Model (GLM)](https://github.com/THUDM/GLM) 架构,具有 62 亿参数。结合模型量化技术,用户可以在消费级的显卡上进行本地部署(INT4 量化级别下最低只需 6GB 显存)。ChatGLM-6B 使用了和 [ChatGLM](https://chatglm.cn) 相同的技术,针对中文问答和对话进行了优化。经过约 1T 标识符的中英双语训练,辅以监督微调、反馈自助、人类反馈强化学习等技术的加持,62 亿参数的 ChatGLM-6B 已经能生成相当符合人类偏好的回答。
13
 
14
+ ChatGLM-6B is an open bilingual language model based on [General Language Model (GLM)](https://github.com/THUDM/GLM) framework, with 6.2 billion parameters. With the quantization technique, users can deploy locally on consumer-grade graphics cards (only 6GB of GPU memory is required at the INT4 quantization level). ChatGLM-6B uses technology similar to ChatGPT, optimized for Chinese QA and dialogue. The model is trained for about 1T tokens of Chinese and English corpus, supplemented by supervised fine-tuning, feedback bootstrap, and reinforcement learning wit human feedback. With only about 6.2 billion parameters, the model is able to generate answers that are in line with human preference.
15
+
16
  ## 软件依赖
17
 
18
  ```shell
 
46
 
47
  关于更多的使用说明,包括如何运行命令行和网页版本的 DEMO,以及使用模型量化以节省显存,请参考我们的 [Github Repo](https://github.com/THUDM/ChatGLM-6B)。
48
 
49
+ For more instructions, including how to run CLI and web demos, and model quantization, please refer to our [Github Repo](https://github.com/THUDM/ChatGLM-6B).
50
+
51
  ## 协议
52
 
53
  本仓库的代码依照 [Apache-2.0](LICENSE) 协议开源,ChatGLM-6B 模型的权重的使用则需要遵循 [Model License](MODEL_LICENSE)。
modeling_chatglm.py CHANGED
@@ -3,7 +3,9 @@
3
  import math
4
  import copy
5
  import os
6
- import time
 
 
7
 
8
  import torch
9
  import torch.utils.checkpoint
@@ -11,7 +13,7 @@ import torch.nn.functional as F
11
  from torch import nn
12
  from torch.nn import CrossEntropyLoss, LayerNorm
13
  from torch.nn.utils import skip_init
14
- from typing import Optional, Tuple, Union, List
15
 
16
  from transformers.utils import (
17
  add_code_sample_docstrings,
@@ -26,15 +28,17 @@ from transformers.modeling_outputs import (
26
  from transformers.modeling_utils import PreTrainedModel
27
  from transformers.utils import logging
28
  from transformers.generation.logits_process import LogitsProcessor
29
- from transformers.generation.utils import LogitsProcessorList
30
 
31
  from .configuration_chatglm import ChatGLMConfig
32
 
33
  # flags required to enable jit fusion kernels
34
- torch._C._jit_set_profiling_mode(False)
35
- torch._C._jit_set_profiling_executor(False)
36
- torch._C._jit_override_can_fuse_on_cpu(True)
37
- torch._C._jit_override_can_fuse_on_gpu(True)
 
 
38
 
39
  logger = logging.get_logger(__name__)
40
 
@@ -294,7 +298,7 @@ def attention_fn(
294
  if not (attention_mask == 0).all():
295
  # if auto-regressive, skip
296
  attention_scores.masked_fill_(attention_mask, -10000.0)
297
- dtype = attention_scores.type()
298
  attention_scores = attention_scores.float()
299
  attention_scores = attention_scores * query_key_layer_scaling_coeff
300
 
@@ -814,8 +818,8 @@ class ChatGLMModel(ChatGLMPreTrainedModel):
814
  return past_key_values
815
 
816
  @staticmethod
817
- def get_masks(seq, device):
818
- context_length = seq.index(150004) + 1
819
 
820
  attention_mask = torch.ones((1, len(seq), len(seq)), device=device)
821
  attention_mask.tril_()
@@ -826,9 +830,9 @@ class ChatGLMModel(ChatGLMPreTrainedModel):
826
  return attention_mask
827
 
828
  def get_position_ids(self, seq, mask_position, device, gmask=False):
829
- context_length = seq.index(150004) + 1
830
  if self.position_encoding_2d:
831
- seq_length = seq.index(150004)
832
  position_ids = torch.arange(context_length, dtype=torch.long, device=device)
833
  if not gmask:
834
  position_ids[seq_length:] = mask_position
@@ -886,14 +890,8 @@ class ChatGLMModel(ChatGLMPreTrainedModel):
886
  past_key_values = self.get_prompt(batch_size=input_ids.shape[0], device=input_ids.device)
887
  else:
888
  past_key_values = tuple([None] * len(self.layers))
889
-
890
- MASK, gMASK = 150000, 150001
891
- mask_token = MASK if MASK in input_ids else gMASK
892
- use_gmask = False if MASK in input_ids else gMASK
893
  seq = input_ids[0].tolist()
894
 
895
- mask_position = seq.index(mask_token)
896
-
897
  if attention_mask is None:
898
  attention_mask = self.get_masks(
899
  seq=seq,
@@ -906,6 +904,11 @@ class ChatGLMModel(ChatGLMPreTrainedModel):
906
  attention_mask = torch.cat((prefix_attention_mask, attention_mask), dim=3)
907
 
908
  if position_ids is None:
 
 
 
 
 
909
  position_ids = self.get_position_ids(
910
  seq=seq,
911
  mask_position=mask_position,
@@ -1009,7 +1012,7 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
1009
  attention_mask = (attention_mask < 0.5).bool()
1010
 
1011
  if self.position_encoding_2d:
1012
- seq_length = seq.index(150004)
1013
  position_ids = torch.arange(context_length, dtype=torch.long, device=device)
1014
  if not gmask:
1015
  position_ids[seq_length:] = mask_position
@@ -1047,7 +1050,7 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
1047
 
1048
  # only last token for input_ids if past is not None
1049
  if past is not None or past_key_values is not None:
1050
- context_length = seq.index(150004)
1051
  last_token = input_ids[:, -1].unsqueeze(-1)
1052
  if self.position_encoding_2d:
1053
  position_ids = torch.tensor([[[mask_position], [len(seq) - context_length]]], dtype=torch.long,
@@ -1155,6 +1158,21 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
1155
  for layer_past in past
1156
  )
1157
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1158
  @torch.no_grad()
1159
  def chat(self, tokenizer, query: str, history: List[Tuple[str, str]] = None, max_length: int = 2048, num_beams=1,
1160
  do_sample=True, top_p=0.7, temperature=0.95, logits_processor=None, **kwargs):
@@ -1175,66 +1193,139 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
1175
  input_ids = tokenizer([prompt], return_tensors="pt", padding=True)
1176
  input_ids = input_ids.to(self.device)
1177
  outputs = self.generate(**input_ids, **gen_kwargs)
1178
- outputs = outputs.tolist()[0][len(input_ids["input_ids"][0]) - 2:]
1179
  response = tokenizer.decode(outputs)
1180
- response = response.strip()
1181
- response = response.replace("[[训练时间]]", "2023年")
1182
  history = history + [(query, response)]
1183
  return response, history
1184
 
1185
  @torch.no_grad()
1186
- def generate(
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1187
  self,
 
 
 
 
 
1188
  **kwargs,
1189
  ):
1190
- MASK, gMASK = 150000, 150001
1191
- bos, eos = 150004, 150005
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1192
 
1193
- if "eos_token_id" not in kwargs:
1194
- kwargs["eos_token_id"] = eos
 
 
 
 
 
1195
 
1196
- truncate = kwargs.pop("truncate") if "truncate" in kwargs else False
 
 
1197
 
1198
- stop = False
 
 
 
 
 
 
1199
 
1200
- return_seqs = []
 
 
 
1201
 
 
 
1202
  while True:
1203
- output_ids = super().generate(**kwargs)
1204
- return_seqs = []
1205
- max_length = 0
1206
-
1207
- for i in range(output_ids.shape[0]):
1208
- output_seq = output_ids[i].tolist()
1209
- if truncate:
1210
- output_seq = output_seq[len(kwargs["input_ids"][i]) - 2:]
1211
- mask_token = MASK if MASK in output_seq else gMASK
1212
- mask_position = output_seq.index(mask_token)
1213
- bos_position = output_seq.index(bos)
1214
- if eos in output_seq:
1215
- eos_position = output_seq.index(eos)
1216
- else:
1217
- eos_position = len(output_seq)
1218
-
1219
- return_seq = output_seq[:mask_position] + output_seq[bos_position + 1:eos_position] + output_seq[
1220
- mask_position + 1:bos_position]
1221
- max_length = max(max_length, len(return_seq))
1222
- return_seqs.append(return_seq)
1223
-
1224
- for i in range(output_ids.shape[0]):
1225
- return_seqs[i] = [0] * (max_length - len(return_seqs[i])) + return_seqs[i] # padding
1226
- if mask_token not in return_seqs[i]:
1227
- stop = True
1228
-
1229
- if stop:
1230
- break
1231
 
1232
- for return_seq in return_seqs:
1233
- return_seq += [bos]
1234
 
1235
- kwargs['input_ids'] = torch.tensor(return_seqs, dtype=torch.long, device=kwargs['input_ids'].device)
 
 
1236
 
1237
- return torch.tensor(return_seqs, dtype=torch.long, device=kwargs['input_ids'].device)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1238
 
1239
  def quantize(self, bits: int):
1240
  from .quantization import quantize
 
3
  import math
4
  import copy
5
  import os
6
+ import warnings
7
+ import re
8
+ import sys
9
 
10
  import torch
11
  import torch.utils.checkpoint
 
13
  from torch import nn
14
  from torch.nn import CrossEntropyLoss, LayerNorm
15
  from torch.nn.utils import skip_init
16
+ from typing import Optional, Tuple, Union, List, Callable
17
 
18
  from transformers.utils import (
19
  add_code_sample_docstrings,
 
28
  from transformers.modeling_utils import PreTrainedModel
29
  from transformers.utils import logging
30
  from transformers.generation.logits_process import LogitsProcessor
31
+ from transformers.generation.utils import LogitsProcessorList, StoppingCriteriaList, GenerationConfig
32
 
33
  from .configuration_chatglm import ChatGLMConfig
34
 
35
  # flags required to enable jit fusion kernels
36
+
37
+ if sys.platform != 'darwin':
38
+ torch._C._jit_set_profiling_mode(False)
39
+ torch._C._jit_set_profiling_executor(False)
40
+ torch._C._jit_override_can_fuse_on_cpu(True)
41
+ torch._C._jit_override_can_fuse_on_gpu(True)
42
 
43
  logger = logging.get_logger(__name__)
44
 
 
298
  if not (attention_mask == 0).all():
299
  # if auto-regressive, skip
300
  attention_scores.masked_fill_(attention_mask, -10000.0)
301
+ dtype = attention_scores.dtype
302
  attention_scores = attention_scores.float()
303
  attention_scores = attention_scores * query_key_layer_scaling_coeff
304
 
 
818
  return past_key_values
819
 
820
  @staticmethod
821
+ def get_masks(self, seq, device):
822
+ context_length = seq.index(self.config.bos_token_id) + 1
823
 
824
  attention_mask = torch.ones((1, len(seq), len(seq)), device=device)
825
  attention_mask.tril_()
 
830
  return attention_mask
831
 
832
  def get_position_ids(self, seq, mask_position, device, gmask=False):
833
+ context_length = len(seq)
834
  if self.position_encoding_2d:
835
+ seq_length = seq.index(self.config.bos_token_id)
836
  position_ids = torch.arange(context_length, dtype=torch.long, device=device)
837
  if not gmask:
838
  position_ids[seq_length:] = mask_position
 
890
  past_key_values = self.get_prompt(batch_size=input_ids.shape[0], device=input_ids.device)
891
  else:
892
  past_key_values = tuple([None] * len(self.layers))
 
 
 
 
893
  seq = input_ids[0].tolist()
894
 
 
 
895
  if attention_mask is None:
896
  attention_mask = self.get_masks(
897
  seq=seq,
 
904
  attention_mask = torch.cat((prefix_attention_mask, attention_mask), dim=3)
905
 
906
  if position_ids is None:
907
+ MASK, gMASK = 150000, 150001
908
+ mask_token = MASK if MASK in input_ids else gMASK
909
+ use_gmask = False if MASK in input_ids else gMASK
910
+
911
+ mask_position = seq.index(mask_token)
912
  position_ids = self.get_position_ids(
913
  seq=seq,
914
  mask_position=mask_position,
 
1012
  attention_mask = (attention_mask < 0.5).bool()
1013
 
1014
  if self.position_encoding_2d:
1015
+ seq_length = seq.index(self.config.bos_token_id)
1016
  position_ids = torch.arange(context_length, dtype=torch.long, device=device)
1017
  if not gmask:
1018
  position_ids[seq_length:] = mask_position
 
1050
 
1051
  # only last token for input_ids if past is not None
1052
  if past is not None or past_key_values is not None:
1053
+ context_length = seq.index(self.config.bos_token_id)
1054
  last_token = input_ids[:, -1].unsqueeze(-1)
1055
  if self.position_encoding_2d:
1056
  position_ids = torch.tensor([[[mask_position], [len(seq) - context_length]]], dtype=torch.long,
 
1158
  for layer_past in past
1159
  )
1160
 
1161
+ def process_response(self, response):
1162
+ response = response.strip()
1163
+ response = response.replace("[[训练时间]]", "2023年")
1164
+ punkts = [
1165
+ [",", ","],
1166
+ ["!", "!"],
1167
+ [":", ":"],
1168
+ [";", ";"],
1169
+ ["\?", "?"],
1170
+ ]
1171
+ for item in punkts:
1172
+ response = re.sub(r"([\u4e00-\u9fff])%s" % item[0], r"\1%s" % item[1], response)
1173
+ response = re.sub(r"%s([\u4e00-\u9fff])" % item[0], r"%s\1" % item[1], response)
1174
+ return response
1175
+
1176
  @torch.no_grad()
1177
  def chat(self, tokenizer, query: str, history: List[Tuple[str, str]] = None, max_length: int = 2048, num_beams=1,
1178
  do_sample=True, top_p=0.7, temperature=0.95, logits_processor=None, **kwargs):
 
1193
  input_ids = tokenizer([prompt], return_tensors="pt", padding=True)
1194
  input_ids = input_ids.to(self.device)
1195
  outputs = self.generate(**input_ids, **gen_kwargs)
1196
+ outputs = outputs.tolist()[0][len(input_ids["input_ids"][0]):]
1197
  response = tokenizer.decode(outputs)
1198
+ response = self.process_response(response)
 
1199
  history = history + [(query, response)]
1200
  return response, history
1201
 
1202
  @torch.no_grad()
1203
+ def stream_chat(self, tokenizer, query: str, history: List[Tuple[str, str]] = None, max_length: int = 2048,
1204
+ do_sample=True, top_p=0.7, temperature=0.95, logits_processor=None, **kwargs):
1205
+ if history is None:
1206
+ history = []
1207
+ if logits_processor is None:
1208
+ logits_processor = LogitsProcessorList()
1209
+ logits_processor.append(InvalidScoreLogitsProcessor())
1210
+ gen_kwargs = {"max_length": max_length, "do_sample": do_sample, "top_p": top_p,
1211
+ "temperature": temperature, "logits_processor": logits_processor, **kwargs}
1212
+ if not history:
1213
+ prompt = query
1214
+ else:
1215
+ prompt = ""
1216
+ for i, (old_query, response) in enumerate(history):
1217
+ prompt += "[Round {}]\n问:{}\n答:{}\n".format(i, old_query, response)
1218
+ prompt += "[Round {}]\n问:{}\n答:".format(len(history), query)
1219
+ input_ids = tokenizer([prompt], return_tensors="pt", padding=True)
1220
+ input_ids = input_ids.to(self.device)
1221
+ for outputs in self.stream_generate(**input_ids, **gen_kwargs):
1222
+ outputs = outputs.tolist()[0][len(input_ids["input_ids"][0]):]
1223
+ response = tokenizer.decode(outputs)
1224
+ response = self.process_response(response)
1225
+ new_history = history + [(query, response)]
1226
+ yield response, new_history
1227
+
1228
+ @torch.no_grad()
1229
+ def stream_generate(
1230
  self,
1231
+ input_ids,
1232
+ generation_config: Optional[GenerationConfig] = None,
1233
+ logits_processor: Optional[LogitsProcessorList] = None,
1234
+ stopping_criteria: Optional[StoppingCriteriaList] = None,
1235
+ prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], List[int]]] = None,
1236
  **kwargs,
1237
  ):
1238
+ batch_size, input_ids_seq_length = input_ids.shape[0], input_ids.shape[-1]
1239
+
1240
+ if generation_config is None:
1241
+ generation_config = self.generation_config
1242
+ generation_config = copy.deepcopy(generation_config)
1243
+ model_kwargs = generation_config.update(**kwargs)
1244
+ bos_token_id, eos_token_id = generation_config.bos_token_id, generation_config.eos_token_id
1245
+
1246
+ if isinstance(eos_token_id, int):
1247
+ eos_token_id = [eos_token_id]
1248
+
1249
+ has_default_max_length = kwargs.get("max_length") is None and generation_config.max_length is not None
1250
+ if has_default_max_length and generation_config.max_new_tokens is None:
1251
+ warnings.warn(
1252
+ f"Using `max_length`'s default ({generation_config.max_length}) to control the generation length. "
1253
+ "This behaviour is deprecated and will be removed from the config in v5 of Transformers -- we"
1254
+ " recommend using `max_new_tokens` to control the maximum length of the generation.",
1255
+ UserWarning,
1256
+ )
1257
+ elif generation_config.max_new_tokens is not None:
1258
+ generation_config.max_length = generation_config.max_new_tokens + input_ids_seq_length
1259
+ if not has_default_max_length:
1260
+ logger.warn(
1261
+ f"Both `max_new_tokens` (={generation_config.max_new_tokens}) and `max_length`(="
1262
+ f"{generation_config.max_length}) seem to have been set. `max_new_tokens` will take precedence. "
1263
+ "Please refer to the documentation for more information. "
1264
+ "(https://huggingface.co/docs/transformers/main/en/main_classes/text_generation)",
1265
+ UserWarning,
1266
+ )
1267
 
1268
+ if input_ids_seq_length >= generation_config.max_length:
1269
+ input_ids_string = "decoder_input_ids" if self.config.is_encoder_decoder else "input_ids"
1270
+ logger.warning(
1271
+ f"Input length of {input_ids_string} is {input_ids_seq_length}, but `max_length` is set to"
1272
+ f" {generation_config.max_length}. This can lead to unexpected behavior. You should consider"
1273
+ " increasing `max_new_tokens`."
1274
+ )
1275
 
1276
+ # 2. Set generation parameters if not already defined
1277
+ logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
1278
+ stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()
1279
 
1280
+ logits_processor = self._get_logits_processor(
1281
+ generation_config=generation_config,
1282
+ input_ids_seq_length=input_ids_seq_length,
1283
+ encoder_input_ids=input_ids,
1284
+ prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
1285
+ logits_processor=logits_processor,
1286
+ )
1287
 
1288
+ stopping_criteria = self._get_stopping_criteria(
1289
+ generation_config=generation_config, stopping_criteria=stopping_criteria
1290
+ )
1291
+ logits_warper = self._get_logits_warper(generation_config)
1292
 
1293
+ unfinished_sequences = input_ids.new(input_ids.shape[0]).fill_(1)
1294
+ scores = None
1295
  while True:
1296
+ model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
1297
+ # forward pass to get next token
1298
+ outputs = self(
1299
+ **model_inputs,
1300
+ return_dict=True,
1301
+ output_attentions=False,
1302
+ output_hidden_states=False,
1303
+ )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1304
 
1305
+ next_token_logits = outputs.logits[:, -1, :]
 
1306
 
1307
+ # pre-process distribution
1308
+ next_token_scores = logits_processor(input_ids, next_token_logits)
1309
+ next_token_scores = logits_warper(input_ids, next_token_scores)
1310
 
1311
+ # sample
1312
+ probs = nn.functional.softmax(next_token_scores, dim=-1)
1313
+ if generation_config.do_sample:
1314
+ next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)
1315
+ else:
1316
+ next_tokens = torch.argmax(probs, dim=-1)
1317
+
1318
+ # update generated ids, model inputs, and length for next step
1319
+ input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
1320
+ model_kwargs = self._update_model_kwargs_for_generation(
1321
+ outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder
1322
+ )
1323
+ unfinished_sequences = unfinished_sequences.mul((sum(next_tokens != i for i in eos_token_id)).long())
1324
+
1325
+ # stop when each sentence is finished, or if we exceed the maximum length
1326
+ if unfinished_sequences.max() == 0 or stopping_criteria(input_ids, scores):
1327
+ break
1328
+ yield input_ids
1329
 
1330
  def quantize(self, bits: int):
1331
  from .quantization import quantize
tokenization_chatglm.py CHANGED
@@ -299,7 +299,7 @@ class ChatGLMTokenizer(PreTrainedTokenizer):
299
  """
300
  if os.path.isdir(save_directory):
301
  vocab_file = os.path.join(
302
- save_directory, VOCAB_FILES_NAMES["vocab_file"]
303
  )
304
  else:
305
  vocab_file = save_directory
 
299
  """
300
  if os.path.isdir(save_directory):
301
  vocab_file = os.path.join(
302
+ save_directory, self.vocab_files_names["vocab_file"]
303
  )
304
  else:
305
  vocab_file = save_directory