Merge remote-tracking branch 'thu/main'
- README.md +4 -0
- modeling_chatglm.py +18 -4
README.md
CHANGED
@@ -14,6 +14,8 @@ ChatGLM-6B-Slim是在ChatGLM-6B的基础上通过裁剪词表构建的。因为C
 
 ChatGLM-6B 是一个开源的、支持中英双语问答的对话语言模型,基于 [General Language Model (GLM)](https://github.com/THUDM/GLM) 架构,具有 62 亿参数。结合模型量化技术,用户可以在消费级的显卡上进行本地部署(INT4 量化级别下最低只需 6GB 显存)。ChatGLM-6B 使用了和 [ChatGLM](https://chatglm.cn) 相同的技术,针对中文问答和对话进行了优化。经过约 1T 标识符的中英双语训练,辅以监督微调、反馈自助、人类反馈强化学习等技术的加持,62 亿参数的 ChatGLM-6B 已经能生成相当符合人类偏好的回答。
 
+ChatGLM-6B is an open bilingual language model based on the [General Language Model (GLM)](https://github.com/THUDM/GLM) framework, with 6.2 billion parameters. With the quantization technique, users can deploy locally on consumer-grade graphics cards (only 6GB of GPU memory is required at the INT4 quantization level). ChatGLM-6B uses technology similar to ChatGPT, optimized for Chinese QA and dialogue. The model is trained for about 1T tokens of Chinese and English corpus, supplemented by supervised fine-tuning, feedback bootstrap, and reinforcement learning with human feedback. With only about 6.2 billion parameters, the model is able to generate answers that are in line with human preference.
+
 ## 软件依赖
 
 ```shell
@@ -47,6 +49,8 @@ pip install protobuf==3.20.0 transformers==4.26.1 icetk cpm_kernels
 
 关于更多的使用说明,包括如何运行命令行和网页版本的 DEMO,以及使用模型量化以节省显存,请参考我们的 [Github Repo](https://github.com/THUDM/ChatGLM-6B)。
 
+For more instructions, including how to run CLI and web demos, and model quantization, please refer to our [Github Repo](https://github.com/THUDM/ChatGLM-6B).
+
 ## 协议
 
 本仓库的代码依照 [Apache-2.0](LICENSE) 协议开源,ChatGLM-6B 模型的权重的使用则需要遵循 [Model License](MODEL_LICENSE)。
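As a point of reference for the deployment claim in the README excerpt above (INT4 quantization needs roughly 6GB of GPU memory), here is a minimal loading sketch in the style of the upstream ChatGLM-6B README. The model path is a placeholder rather than this repository's actual id, and the `.quantize(4)` call is the upstream INT4 loading API, assumed here to carry over to this slimmed checkpoint.

```python
# Hedged sketch only: standard ChatGLM-6B usage per the upstream README.
# "path/to/chatglm-6b-slim" is a placeholder -- substitute this repository's id.
from transformers import AutoTokenizer, AutoModel

model_path = "path/to/chatglm-6b-slim"  # placeholder, not the actual repo id
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
# .quantize(4) loads INT4 weights (the ~6GB figure quoted above); drop it to run at FP16.
model = AutoModel.from_pretrained(model_path, trust_remote_code=True).half().quantize(4).cuda()
model = model.eval()

# chat() matches the signature in the modeling_chatglm.py diff below:
# it returns the decoded reply plus the updated conversation history.
response, history = model.chat(tokenizer, "你好", history=[])
print(response)
```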
modeling_chatglm.py
CHANGED
@@ -4,6 +4,7 @@ import math
 import copy
 import os
 import warnings
+import re
 
 import torch
 import torch.utils.checkpoint
@@ -1086,6 +1087,21 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
             for layer_past in past
         )
 
+    def process_response(self, response):
+        response = response.strip()
+        response = response.replace("[[训练时间]]", "2023年")
+        punkts = [
+            [",", ","],
+            ["!", "!"],
+            [":", ":"],
+            [";", ";"],
+            ["\?", "?"],
+        ]
+        for item in punkts:
+            response = re.sub(r"([\u4e00-\u9fff])%s" % item[0], r"\1%s" % item[1], response)
+            response = re.sub(r"%s([\u4e00-\u9fff])" % item[0], r"%s\1" % item[1], response)
+        return response
+
     @torch.no_grad()
     def chat(self, tokenizer, query: str, history: List[Tuple[str, str]] = None, max_length: int = 2048, num_beams=1,
              do_sample=True, top_p=0.7, temperature=0.95, logits_processor=None, **kwargs):
@@ -1108,8 +1124,7 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
         outputs = self.generate(**input_ids, **gen_kwargs)
         outputs = outputs.tolist()[0][len(input_ids["input_ids"][0]):]
         response = tokenizer.decode(outputs)
-        response = response.strip()
-        response = response.replace("[[训练时间]]", "2023年")
+        response = self.process_response(response)
         history = history + [(query, response)]
         return response, history
 
@@ -1135,8 +1150,7 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
         for outputs in self.stream_generate(**input_ids, **gen_kwargs):
             outputs = outputs.tolist()[0][len(input_ids["input_ids"][0]):]
             response = tokenizer.decode(outputs)
-            response = response.strip()
-            response = response.replace("[[训练时间]]", "2023年")
+            response = self.process_response(response)
             new_history = history + [(query, response)]
             yield response, new_history
 
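For reference, the new `process_response` helper strips the decoded text, substitutes the `[[训练时间]]` placeholder, and converts half-width punctuation to full-width only where it borders a CJK ideograph, so English text is left untouched. Below is a self-contained sketch of the same logic; the sample string and expected output are illustrative, not taken from the repo.

```python
import re

# Standalone re-implementation of the process_response logic added in this
# commit, for illustration outside the model class.
def process_response(response: str) -> str:
    response = response.strip()
    response = response.replace("[[训练时间]]", "2023年")
    # Half-width pattern -> full-width replacement; applied only next to a
    # CJK ideograph (\u4e00-\u9fff), so "ChatGLM-6B, a model" stays as-is.
    punkts = [
        [",", ","],
        ["!", "!"],
        [":", ":"],
        [";", ";"],
        [r"\?", "?"],
    ]
    for half, full in punkts:
        response = re.sub(r"([\u4e00-\u9fff])%s" % half, r"\1%s" % full, response)
        response = re.sub(r"%s([\u4e00-\u9fff])" % half, r"%s\1" % full, response)
    return response

print(process_response("  你好!我是[[训练时间]]训练的模型,很高兴见到你  "))
# -> 你好!我是2023年训练的模型,很高兴见到你
```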