jackbond2024 committed on
Commit 03f3fb4 (1 parent: 0f162b7)
Upload folder using huggingface_hub
- README.md +139 -7
- README_en.md +141 -0
- openai_api_request.py +127 -0
- openai_api_server.py +635 -0
- requirements.txt +29 -0
- trans_batch_demo.py +90 -0
- trans_cli_demo.py +119 -0
- trans_cli_vision_demo.py +123 -0
- trans_stress_test.py +135 -0
- trans_web_demo.py +179 -0
- trans_web_vision_demo.py +121 -0
- vllm_cli_demo.py +111 -0
README.md
CHANGED
@@ -1,12 +1,144 @@
 ---
-title:
-
-colorFrom: green
-colorTo: pink
 sdk: gradio
 sdk_version: 4.37.2
-app_file: app.py
-pinned: false
 ---

-Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 ---
+title: glm4
+app_file: trans_web_demo.py
 sdk: gradio
 sdk_version: 4.37.2
 ---
+# Basic Demo
+
+Read this in [English](README_en.md).
+
+In this demo, you will learn how to use the open-source GLM-4-9B model for basic tasks.
+
+Please follow the steps in this document exactly to avoid unnecessary errors.
+
+## Device and dependency check
+
+### Inference test data
+
+**All figures in this document were measured on the hardware listed below. Actual environment requirements and GPU memory usage may differ slightly; treat your own environment as the reference.**
+
+Test hardware:
+
++ OS: Ubuntu 22.04
++ Memory: 512GB
++ Python: 3.10.12 (recommended) / 3.12.3, both tested
++ CUDA Version: 12.3
++ GPU Driver: 535.104.05
++ GPU: NVIDIA A100-SXM4-80GB * 8
+
+Stress-test results for inference are as follows:
+
+**All tests were run on a single GPU, and all GPU memory figures are approximate peak values.**
+
+#### GLM-4-9B-Chat
+
+| Dtype | GPU Memory | Prefilling | Decode Speed  | Remarks                |
+|-------|------------|------------|---------------|------------------------|
+| BF16  | 19 GB      | 0.2s       | 27.8 tokens/s | Input length is 1000   |
+| BF16  | 21 GB      | 0.8s       | 31.8 tokens/s | Input length is 8000   |
+| BF16  | 28 GB      | 4.3s       | 14.4 tokens/s | Input length is 32000  |
+| BF16  | 58 GB      | 38.1s      | 3.4 tokens/s  | Input length is 128000 |
+
+| Dtype | GPU Memory | Prefilling | Decode Speed  | Remarks               |
+|-------|------------|------------|---------------|-----------------------|
+| INT4  | 8 GB       | 0.2s       | 23.3 tokens/s | Input length is 1000  |
+| INT4  | 10 GB      | 0.8s       | 23.4 tokens/s | Input length is 8000  |
+| INT4  | 17 GB      | 4.3s       | 14.6 tokens/s | Input length is 32000 |
+
+#### GLM-4-9B-Chat-1M
+
+| Dtype | GPU Memory | Prefilling | Decode Speed | Remarks                |
+|-------|------------|------------|--------------|------------------------|
+| BF16  | 75 GB      | 98.4s      | 2.3 tokens/s | Input length is 200000 |
+
+If your input exceeds 200K, we recommend multi-GPU inference with the vLLM backend for better performance.
+
+#### GLM-4V-9B
+
+| Dtype | GPU Memory | Prefilling | Decode Speed  | Remarks              |
+|-------|------------|------------|---------------|----------------------|
+| BF16  | 28 GB      | 0.1s       | 33.4 tokens/s | Input length is 1000 |
+| BF16  | 33 GB      | 0.7s       | 39.2 tokens/s | Input length is 8000 |
+
+| Dtype | GPU Memory | Prefilling | Decode Speed  | Remarks              |
+|-------|------------|------------|---------------|----------------------|
+| INT4  | 10 GB      | 0.1s       | 28.7 tokens/s | Input length is 1000 |
+| INT4  | 15 GB      | 0.8s       | 24.2 tokens/s | Input length is 8000 |
+
+### Minimum hardware requirements
+
+To run the most basic official code (transformers backend), you need:
+
++ Python >= 3.10
++ At least 32 GB of memory
+
+To run all of the official code in this folder, you also need:
+
++ Linux operating system (Debian family preferred)
++ A GPU with more than 8GB of memory that supports CUDA or ROCM and `BF16` inference. (`FP16` precision cannot be used for training, and inference has a small chance of problems.)
+
+Install dependencies
+
+```shell
+pip install -r requirements.txt
+```
+
+## Basic function calls
+
+**Unless otherwise noted, the demos in this folder do not support advanced usage such as Function Call and All Tools.**
+
+### Using the transformers backend
+
++ Chat with the GLM-4-9B model from the command line.
+
+```shell
+python trans_cli_demo.py # GLM-4-9B-Chat
+python trans_cli_vision_demo.py # GLM-4V-9B
+```
+
++ Chat with the GLM-4-9B model through the Gradio web UI.
+
+```shell
+python trans_web_demo.py # GLM-4-9B-Chat
+python trans_web_vision_demo.py # GLM-4V-9B
+```
+
++ Run batch inference.
+
+```shell
+python trans_batch_demo.py
+```
+
+### Using the vLLM backend
+
++ Chat with the GLM-4-9B-Chat model from the command line.
+
+```shell
+python vllm_cli_demo.py
+```
+
++ Build your own server and chat with the GLM-4-9B-Chat model using the `OpenAI API` request format. This demo supports Function Call and All Tools.
+
+Start the server:
+
+```shell
+python openai_api_server.py
+```
+
+Client request:
+
+```shell
+python openai_api_request.py
+```
+
+## Stress test
+
+You can use this script to measure the model's generation speed with the transformers backend on your own device:
+
+```shell
+python trans_stress_test.py
+```
+
+

README_en.md
ADDED
@@ -0,0 +1,141 @@
# Basic Demo

In this demo, you will learn how to use the open-source GLM-4-9B model to perform basic tasks.

Please follow the steps in this document strictly to avoid unnecessary errors.

## Device and dependency check

### Related inference test data

**The data in this document were measured in the following hardware environment. Actual environment requirements and
GPU memory usage may differ slightly; please refer to your actual operating environment.**

Test hardware information:

+ OS: Ubuntu 22.04
+ Memory: 512GB
+ Python: 3.10.12 (recommended) / 3.12.3, both tested
+ CUDA Version: 12.3
+ GPU Driver: 535.104.05
+ GPU: NVIDIA A100-SXM4-80GB * 8

The stress test data for inference are as follows:

**All tests were performed on a single GPU, and all GPU memory figures are approximate peak values.**

#### GLM-4-9B-Chat

| Dtype | GPU Memory | Prefilling | Decode Speed  | Remarks                |
|-------|------------|------------|---------------|------------------------|
| BF16  | 19 GB      | 0.2s       | 27.8 tokens/s | Input length is 1000   |
| BF16  | 21 GB      | 0.8s       | 31.8 tokens/s | Input length is 8000   |
| BF16  | 28 GB      | 4.3s       | 14.4 tokens/s | Input length is 32000  |
| BF16  | 58 GB      | 38.1s      | 3.4 tokens/s  | Input length is 128000 |

| Dtype | GPU Memory | Prefilling | Decode Speed  | Remarks               |
|-------|------------|------------|---------------|-----------------------|
| INT4  | 8 GB       | 0.2s       | 23.3 tokens/s | Input length is 1000  |
| INT4  | 10 GB      | 0.8s       | 23.4 tokens/s | Input length is 8000  |
| INT4  | 17 GB      | 4.3s       | 14.6 tokens/s | Input length is 32000 |

#### GLM-4-9B-Chat-1M

| Dtype | GPU Memory | Prefilling | Decode Speed    | Remarks                |
|-------|------------|------------|-----------------|------------------------|
| BF16  | 74497MiB   | 98.4s      | 2.3653 tokens/s | Input length is 200000 |

If your input exceeds 200K, we recommend using the vLLM backend with multiple GPUs for inference to get better
performance.

#### GLM-4V-9B

| Dtype | GPU Memory | Prefilling | Decode Speed  | Remarks              |
|-------|------------|------------|---------------|----------------------|
| BF16  | 28 GB      | 0.1s       | 33.4 tokens/s | Input length is 1000 |
| BF16  | 33 GB      | 0.7s       | 39.2 tokens/s | Input length is 8000 |

| Dtype | GPU Memory | Prefilling | Decode Speed  | Remarks              |
|-------|------------|------------|---------------|----------------------|
| INT4  | 10 GB      | 0.1s       | 28.7 tokens/s | Input length is 1000 |
| INT4  | 15 GB      | 0.8s       | 24.2 tokens/s | Input length is 8000 |
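
As a rough way to read the tables above, the latency of a single request can be approximated as the prefilling time plus the number of output tokens divided by the decode speed. The sketch below applies this to the GLM-4-9B-Chat BF16 row with an input length of 8000. The helper name `estimate_latency_s` and the 256-token output length are illustrative assumptions, and real timings will vary with hardware and load.

```python
def estimate_latency_s(prefill_s: float, decode_tokens_per_s: float, output_tokens: int) -> float:
    """Approximate request latency as prefilling time plus decode time."""
    if decode_tokens_per_s <= 0:
        raise ValueError("decode speed must be positive")
    return prefill_s + output_tokens / decode_tokens_per_s


# GLM-4-9B-Chat, BF16, input length 8000 (prefilling and decode speed from the table).
# The output length of 256 tokens is an assumed value for illustration.
latency = estimate_latency_s(prefill_s=0.8, decode_tokens_per_s=31.8, output_tokens=256)
print(f"Estimated latency: {latency:.1f} s")
```

The same helper can be pointed at any other row to compare configurations, for example BF16 against INT4 at the same input length.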

### Minimum hardware requirements

To run the most basic official code (transformers backend), you need:

+ Python >= 3.10
+ Memory of at least 32 GB

To run all of the official code in this folder, you also need:

+ Linux operating system (Debian series is best)
+ GPU device with more than 8GB of GPU memory that supports CUDA or ROCM and `BF16` inference (`FP16` precision
  cannot be used for fine-tuning, and there is a small probability of problems during inference)
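
The requirements above can be checked before installing anything heavy. The following is a minimal sketch, not part of the official demo: `check_minimum_requirements` and `probe_gpu` are hypothetical helper names, the RAM value is passed in by hand, and the GPU probe assumes PyTorch is installed (it falls back gracefully when it is not).

```python
import sys


def check_minimum_requirements(python_version: tuple, ram_gb: float,
                               gpu_mem_gb: float = 0.0, bf16_supported: bool = False) -> list:
    """Return a list of human-readable problems; an empty list means all checks passed."""
    problems = []
    if tuple(python_version[:2]) < (3, 10):
        problems.append("Python >= 3.10 is required")
    if ram_gb < 32:
        problems.append("at least 32 GB of memory is required")
    if gpu_mem_gb <= 8:
        problems.append("a GPU with more than 8 GB of memory is needed to run all demos")
    if not bf16_supported:
        problems.append("the GPU should support BF16 inference")
    return problems


def probe_gpu() -> tuple:
    """Best-effort GPU probe via PyTorch; returns (memory_gb, bf16_supported)."""
    try:
        import torch  # optional dependency, only needed for the GPU checks
    except ImportError:
        return 0.0, False
    if not torch.cuda.is_available():
        return 0.0, False
    props = torch.cuda.get_device_properties(0)
    return props.total_memory / 1024 ** 3, torch.cuda.is_bf16_supported()


if __name__ == "__main__":
    gpu_mem_gb, bf16 = probe_gpu()
    # ram_gb=32 is a placeholder; replace it with the memory of your machine.
    for problem in check_minimum_requirements(sys.version_info, 32, gpu_mem_gb, bf16):
        print("WARNING:", problem)
```

Keeping the comparison logic separate from the hardware probe makes the check easy to run on a machine without a GPU.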

Install dependencies

```shell
pip install -r requirements.txt
```

## Basic function calls

**Unless otherwise specified, the demos in this folder do not support advanced usage such as Function Call and All Tools.**

### Use transformers backend code

+ Use the command line to communicate with the GLM-4-9B model.

```shell
python trans_cli_demo.py # GLM-4-9B-Chat
python trans_cli_vision_demo.py # GLM-4V-9B
```

+ Use the Gradio web client to communicate with the GLM-4-9B model.

```shell
python trans_web_demo.py # GLM-4-9B-Chat
python trans_web_vision_demo.py # GLM-4V-9B
```

+ Use batch inference.

```shell
python trans_batch_demo.py
```

### Use vLLM backend code

+ Use the command line to communicate with the GLM-4-9B-Chat model.

```shell
python vllm_cli_demo.py
```

+ Build the server yourself and use the `OpenAI API` request format to communicate with the glm-4-9b model. This
  demo supports Function Call and All Tools.

Start the server:

```shell
python openai_api_server.py
```

Client request:

```shell
python openai_api_request.py
```
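
For orientation, `openai_api_server.py` in this commit treats model output as a tool call when the first line is a known tool name and the second line starts with `{`. The following is a simplified, standalone sketch of that check, not the server code itself: `parse_tool_call` is a hypothetical helper name, and the special handling of `cogview` and `simple_browser` is left out.

```python
import json
from typing import Optional


def parse_tool_call(output: str, tool_names: set) -> Optional[dict]:
    """Return {"name", "arguments"} if the output looks like a tool call, else None."""
    lines = output.strip().split("\n")
    # A tool call is expected as: first line = tool name, remaining lines = JSON arguments.
    if len(lines) < 2 or not lines[1].startswith("{"):
        return None
    name = lines[0].strip()
    if name not in tool_names:
        return None
    try:
        arguments = json.loads("\n".join(lines[1:]))
    except json.JSONDecodeError:
        return None
    return {"name": name, "arguments": arguments}


raw = 'get_current_weather\n{"location": "San Francisco, CA", "format": "celsius"}'
print(parse_tool_call(raw, {"get_current_weather"}))
print(parse_tool_call("It is sunny today.", {"get_current_weather"}))
```

Plain-text answers fall through and return `None`, which corresponds to the server returning the text as a normal assistant message.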

## Stress test

You can use this script to test the model's generation speed with the transformers backend on your own device:

```shell
python trans_stress_test.py
```
openai_api_request.py
ADDED
@@ -0,0 +1,127 @@
"""
This script is an OpenAI request demo for the glm-4-9b model; it uses the OpenAI API to interact with the model.
"""

from openai import OpenAI

base_url = "http://127.0.0.1:8000/v1/"
client = OpenAI(api_key="EMPTY", base_url=base_url)


def function_chat(use_stream=False):
    messages = [
        {
            "role": "user", "content": "What's the Celsius temperature in San Francisco?"
        },

        # Give Observations
        # {
        #     "role": "assistant",
        #     "content": None,
        #     "function_call": None,
        #     "tool_calls": [
        #         {
        #             "id": "call_1717912616815",
        #             "function": {
        #                 "name": "get_current_weather",
        #                 "arguments": "{\"location\": \"San Francisco, CA\", \"format\": \"celsius\"}"
        #             },
        #             "type": "function"
        #         }
        #     ]
        # },
        # {
        #     "tool_call_id": "call_1717912616815",
        #     "role": "tool",
        #     "name": "get_current_weather",
        #     "content": "23°C",
        # }
    ]
    tools = [
        {
            "type": "function",
            "function": {
                "name": "get_current_weather",
                "description": "Get the current weather",
                "parameters": {
                    "type": "object",
                    "properties": {
                        "location": {
                            "type": "string",
                            "description": "The city and state, e.g. San Francisco, CA",
                        },
                        "format": {
                            "type": "string",
                            "enum": ["celsius", "fahrenheit"],
                            "description": "The temperature unit to use. Infer this from the users location.",
                        },
                    },
                    "required": ["location", "format"],
                },
            }
        },
    ]

    # All Tools: CogView
    # messages = [{"role": "user", "content": "Please draw me a picture of the sky"}]
    # tools = [{"type": "cogview"}]

    # All Tools: Searching
    # messages = [{"role": "user", "content": "Today's gold price"}]
    # tools = [{"type": "simple_browser"}]

    response = client.chat.completions.create(
        model="glm-4",
        messages=messages,
        tools=tools,
        stream=use_stream,
        max_tokens=256,
        temperature=0.9,
        presence_penalty=1.2,
        top_p=0.1,
        tool_choice="auto"
    )
    if response:
        if use_stream:
            for chunk in response:
                print(chunk)
        else:
            print(response)
    else:
        # The client response object has no status_code attribute, so report the empty result directly.
        print("Error: empty response")


def simple_chat(use_stream=False):
    messages = [
        {
            "role": "system",
            "content": "Always start your reply with the three words \"meow meow meow\".",
        },
        {
            "role": "user",
            "content": "Who are you?"
        }
    ]
    response = client.chat.completions.create(
        model="glm-4",
        messages=messages,
        stream=use_stream,
        max_tokens=256,
        temperature=0.4,
        presence_penalty=1.2,
        top_p=0.8,
    )
    if response:
        if use_stream:
            for chunk in response:
                print(chunk)
        else:
            print(response)
    else:
        # The client response object has no status_code attribute, so report the empty result directly.
        print("Error: empty response")


if __name__ == "__main__":
    # simple_chat(use_stream=False)
    function_chat(use_stream=False)

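
The commented-out "Give Observations" messages in `function_chat` suggest a second request that feeds the tool result back to the model. A minimal sketch of building that message list is shown below. `append_tool_result` is a hypothetical helper, and the call ID and `23°C` value are simply the placeholders used in those comments.

```python
import json


def append_tool_result(messages: list, call_id: str, name: str, arguments: dict, result: str) -> list:
    """Return a new message list with the assistant tool call and its observation appended."""
    assistant_msg = {
        "role": "assistant",
        "content": None,
        "tool_calls": [
            {
                "id": call_id,
                "type": "function",
                # Arguments are sent as a JSON string, as in the commented example.
                "function": {"name": name, "arguments": json.dumps(arguments, ensure_ascii=False)},
            }
        ],
    }
    tool_msg = {"tool_call_id": call_id, "role": "tool", "name": name, "content": result}
    return messages + [assistant_msg, tool_msg]


history = [{"role": "user", "content": "What's the Celsius temperature in San Francisco?"}]
followup = append_tool_result(
    history,
    call_id="call_1717912616815",
    name="get_current_weather",
    arguments={"location": "San Francisco, CA", "format": "celsius"},
    result="23°C",
)
print(json.dumps(followup, ensure_ascii=False, indent=2))
```

The resulting list could then be passed as `messages` to `client.chat.completions.create(...)` in a second request, so the model can turn the observation into a final answer.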
openai_api_server.py
ADDED
@@ -0,0 +1,635 @@
import time
from asyncio.log import logger
import re
import uvicorn
import gc
import json
import torch
import random
import string

from vllm import SamplingParams, AsyncEngineArgs, AsyncLLMEngine
from fastapi import FastAPI, HTTPException, Response
from fastapi.middleware.cors import CORSMiddleware
from contextlib import asynccontextmanager
from typing import List, Literal, Optional, Union
from pydantic import BaseModel, Field
from transformers import AutoTokenizer, LogitsProcessor
from sse_starlette.sse import EventSourceResponse

EventSourceResponse.DEFAULT_PING_INTERVAL = 1000
import os

MODEL_PATH = os.environ.get('MODEL_PATH', 'THUDM/glm-4-9b-chat')
MAX_MODEL_LENGTH = 8192


@asynccontextmanager
async def lifespan(app: FastAPI):
    yield
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        torch.cuda.ipc_collect()


app = FastAPI(lifespan=lifespan)

app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)


def generate_id(prefix: str, k=29) -> str:
    suffix = ''.join(random.choices(string.ascii_letters + string.digits, k=k))
    return f"{prefix}{suffix}"


class ModelCard(BaseModel):
    id: str = ""
    object: str = "model"
    created: int = Field(default_factory=lambda: int(time.time()))
    owned_by: str = "owner"
    root: Optional[str] = None
    parent: Optional[str] = None
    permission: Optional[list] = None


class ModelList(BaseModel):
    object: str = "list"
    data: List[ModelCard] = ["glm-4"]


class FunctionCall(BaseModel):
    name: Optional[str] = None
    arguments: Optional[str] = None


class ChoiceDeltaToolCallFunction(BaseModel):
    name: Optional[str] = None
    arguments: Optional[str] = None


class UsageInfo(BaseModel):
    prompt_tokens: int = 0
    total_tokens: int = 0
    completion_tokens: Optional[int] = 0


class ChatCompletionMessageToolCall(BaseModel):
    index: Optional[int] = 0
    id: Optional[str] = None
    function: FunctionCall
    type: Optional[Literal["function"]] = 'function'


class ChatMessage(BaseModel):
    # About the "function" role:
    # when using older OpenAI API versions, add the "function" role here and add the matching
    # role-conversion logic (to "observation") in the process_messages function.

    role: Literal["user", "assistant", "system", "tool"]
    content: Optional[str] = None
    function_call: Optional[ChoiceDeltaToolCallFunction] = None
    tool_calls: Optional[List[ChatCompletionMessageToolCall]] = None


class DeltaMessage(BaseModel):
    role: Optional[Literal["user", "assistant", "system"]] = None
    content: Optional[str] = None
    function_call: Optional[ChoiceDeltaToolCallFunction] = None
    tool_calls: Optional[List[ChatCompletionMessageToolCall]] = None


class ChatCompletionResponseChoice(BaseModel):
    index: int
    message: ChatMessage
    finish_reason: Literal["stop", "length", "tool_calls"]


class ChatCompletionResponseStreamChoice(BaseModel):
    delta: DeltaMessage
    finish_reason: Optional[Literal["stop", "length", "tool_calls"]]
    index: int


class ChatCompletionResponse(BaseModel):
    model: str
    id: Optional[str] = Field(default_factory=lambda: generate_id('chatcmpl-', 29))
    object: Literal["chat.completion", "chat.completion.chunk"]
    choices: List[Union[ChatCompletionResponseChoice, ChatCompletionResponseStreamChoice]]
    created: Optional[int] = Field(default_factory=lambda: int(time.time()))
    system_fingerprint: Optional[str] = Field(default_factory=lambda: generate_id('fp_', 9))
    usage: Optional[UsageInfo] = None


class ChatCompletionRequest(BaseModel):
    model: str
    messages: List[ChatMessage]
    temperature: Optional[float] = 0.8
    top_p: Optional[float] = 0.8
    max_tokens: Optional[int] = None
    stream: Optional[bool] = False
    tools: Optional[Union[dict, List[dict]]] = None
    tool_choice: Optional[Union[str, dict]] = None
    repetition_penalty: Optional[float] = 1.1


class InvalidScoreLogitsProcessor(LogitsProcessor):
    def __call__(
            self, input_ids: torch.LongTensor, scores: torch.FloatTensor
    ) -> torch.FloatTensor:
        if torch.isnan(scores).any() or torch.isinf(scores).any():
            scores.zero_()
            scores[..., 5] = 5e4
        return scores


def process_response(output: str, tools: dict | List[dict] = None, use_tool: bool = False) -> Union[str, dict]:
    lines = output.strip().split("\n")
    arguments_json = None
    special_tools = ["cogview", "simple_browser"]
    tools = {tool['function']['name'] for tool in tools} if tools else {}

    # This is a simple tool-matching check. It cannot guarantee that every non-tool output is
    # intercepted, for example in special cases such as misaligned arguments.
    # TODO: extend the logic here if you want stricter checks.

    if len(lines) >= 2 and lines[1].startswith("{"):
        function_name = lines[0].strip()
        arguments = "\n".join(lines[1:]).strip()
        if function_name in tools or function_name in special_tools:
            try:
                arguments_json = json.loads(arguments)
                is_tool_call = True
            except json.JSONDecodeError:
                is_tool_call = function_name in special_tools

            if is_tool_call and use_tool:
                content = {
                    "name": function_name,
                    "arguments": json.dumps(arguments_json if isinstance(arguments_json, dict) else arguments,
                                            ensure_ascii=False)
                }
                if function_name == "simple_browser":
                    search_pattern = re.compile(r'search\("(.+?)"\s*,\s*recency_days\s*=\s*(\d+)\)')
                    match = search_pattern.match(arguments)
                    if match:
                        content["arguments"] = json.dumps({
                            "query": match.group(1),
                            "recency_days": int(match.group(2))
                        }, ensure_ascii=False)
                elif function_name == "cogview":
                    content["arguments"] = json.dumps({
                        "prompt": arguments
                    }, ensure_ascii=False)

                return content
    return output.strip()


@torch.inference_mode()
async def generate_stream_glm4(params):
    messages = params["messages"]
    tools = params["tools"]
    tool_choice = params["tool_choice"]
    temperature = float(params.get("temperature", 1.0))
    repetition_penalty = float(params.get("repetition_penalty", 1.0))
    top_p = float(params.get("top_p", 1.0))
    max_new_tokens = int(params.get("max_tokens", 8192))

    messages = process_messages(messages, tools=tools, tool_choice=tool_choice)
    inputs = tokenizer.apply_chat_template(messages, add_generation_prompt=True, tokenize=False)
    params_dict = {
        "n": 1,
        "best_of": 1,
        "presence_penalty": 1.0,
        "frequency_penalty": 0.0,
        "temperature": temperature,
        "top_p": top_p,
        "top_k": -1,
        "repetition_penalty": repetition_penalty,
        "use_beam_search": False,
        "length_penalty": 1,
        "early_stopping": False,
        "stop_token_ids": [151329, 151336, 151338],
        "ignore_eos": False,
        "max_tokens": max_new_tokens,
        "logprobs": None,
        "prompt_logprobs": None,
        "skip_special_tokens": True,
    }
    sampling_params = SamplingParams(**params_dict)
    async for output in engine.generate(inputs=inputs, sampling_params=sampling_params, request_id=f"{time.time()}"):
        output_len = len(output.outputs[0].token_ids)
        input_len = len(output.prompt_token_ids)
        ret = {
            "text": output.outputs[0].text,
            "usage": {
                "prompt_tokens": input_len,
                "completion_tokens": output_len,
                "total_tokens": output_len + input_len
            },
            "finish_reason": output.outputs[0].finish_reason,
        }
        yield ret
    gc.collect()
    torch.cuda.empty_cache()


def process_messages(messages, tools=None, tool_choice="none"):
    _messages = messages
    processed_messages = []
    msg_has_sys = False

    def filter_tools(tool_choice, tools):
        function_name = tool_choice.get('function', {}).get('name', None)
        if not function_name:
            return []
        filtered_tools = [
            tool for tool in tools
            if tool.get('function', {}).get('name') == function_name
        ]
        return filtered_tools

    if tool_choice != "none":
        if isinstance(tool_choice, dict):
            tools = filter_tools(tool_choice, tools)
        if tools:
            processed_messages.append(
                {
                    "role": "system",
                    "content": None,
                    "tools": tools
                }
            )
            msg_has_sys = True

    if isinstance(tool_choice, dict) and tools:
        processed_messages.append(
            {
                "role": "assistant",
                "metadata": tool_choice["function"]["name"],
                "content": ""
            }
        )

    for m in _messages:
        role, content, func_call = m.role, m.content, m.function_call
        tool_calls = getattr(m, 'tool_calls', None)

        if role == "function":
            processed_messages.append(
                {
                    "role": "observation",
                    "content": content
                }
            )
        elif role == "tool":
            processed_messages.append(
                {
                    "role": "observation",
                    "content": content,
                    "function_call": True
                }
            )
        elif role == "assistant":
            if tool_calls:
                for tool_call in tool_calls:
                    processed_messages.append(
                        {
                            "role": "assistant",
                            "metadata": tool_call.function.name,
                            "content": tool_call.function.arguments
                        }
                    )
            else:
                for response in content.split("\n"):
                    if "\n" in response:
                        metadata, sub_content = response.split("\n", maxsplit=1)
                    else:
                        metadata, sub_content = "", response
                    processed_messages.append(
                        {
                            "role": role,
                            "metadata": metadata,
                            "content": sub_content.strip()
                        }
                    )
        else:
            if role == "system" and msg_has_sys:
                msg_has_sys = False
                continue
            processed_messages.append({"role": role, "content": content})

    if not tools or tool_choice == "none":
        for m in _messages:
            if m.role == 'system':
                processed_messages.insert(0, {"role": m.role, "content": m.content})
                break
    return processed_messages


334 |
+
@app.get("/health")
|
335 |
+
async def health() -> Response:
|
336 |
+
"""Health check."""
|
337 |
+
return Response(status_code=200)
|
338 |
+
|
339 |
+
|
340 |
+
@app.get("/v1/models", response_model=ModelList)
|
341 |
+
async def list_models():
|
342 |
+
model_card = ModelCard(id="glm-4")
|
343 |
+
return ModelList(data=[model_card])
|
+
+
+@app.post("/v1/chat/completions", response_model=ChatCompletionResponse)
+async def create_chat_completion(request: ChatCompletionRequest):
+    if len(request.messages) < 1 or request.messages[-1].role == "assistant":
+        raise HTTPException(status_code=400, detail="Invalid request")
+
+    gen_params = dict(
+        messages=request.messages,
+        temperature=request.temperature,
+        top_p=request.top_p,
+        max_tokens=request.max_tokens or 1024,
+        echo=False,
+        stream=request.stream,
+        repetition_penalty=request.repetition_penalty,
+        tools=request.tools,
+        tool_choice=request.tool_choice,
+    )
+    logger.debug(f"==== request ====\n{gen_params}")
+
+    if request.stream:
+        predict_stream_generator = predict_stream(request.model, gen_params)
+        output = await anext(predict_stream_generator)
+        if output:
+            return EventSourceResponse(predict_stream_generator, media_type="text/event-stream")
+        logger.debug(f"First result output:\n{output}")
+
+        function_call = None
+        if output and request.tools:
+            try:
+                function_call = process_response(output, request.tools, use_tool=True)
+            except Exception:
+                logger.warning("Failed to parse tool call")
+
+        if isinstance(function_call, dict):
+            function_call = ChoiceDeltaToolCallFunction(**function_call)
+            generate = parse_output_text(request.model, output, function_call=function_call)
+            return EventSourceResponse(generate, media_type="text/event-stream")
+        else:
+            return EventSourceResponse(predict_stream_generator, media_type="text/event-stream")
+    response = ""
+    async for response in generate_stream_glm4(gen_params):
+        pass
+
+    if response["text"].startswith("\n"):
+        response["text"] = response["text"][1:]
+    response["text"] = response["text"].strip()
+
+    usage = UsageInfo()
+
+    function_call, finish_reason = None, "stop"
+    tool_calls = None
+    if request.tools:
+        try:
+            function_call = process_response(response["text"], request.tools, use_tool=True)
+        except Exception as e:
+            logger.warning(f"Failed to parse tool call: {e}")
+    if isinstance(function_call, dict):
+        finish_reason = "tool_calls"
+        function_call_response = ChoiceDeltaToolCallFunction(**function_call)
+        function_call_instance = FunctionCall(
+            name=function_call_response.name,
+            arguments=function_call_response.arguments
+        )
+        tool_calls = [
+            ChatCompletionMessageToolCall(
+                id=generate_id('call_', 24),
+                function=function_call_instance,
+                type="function")]
+
+    message = ChatMessage(
+        role="assistant",
+        content=None if tool_calls else response["text"],
+        function_call=None,
+        tool_calls=tool_calls,
+    )
+
+    logger.debug(f"==== message ====\n{message}")
+
+    choice_data = ChatCompletionResponseChoice(
+        index=0,
+        message=message,
+        finish_reason=finish_reason,
+    )
+    task_usage = UsageInfo.model_validate(response["usage"])
+    for usage_key, usage_value in task_usage.model_dump().items():
+        setattr(usage, usage_key, getattr(usage, usage_key) + usage_value)
+
+    return ChatCompletionResponse(
+        model=request.model,
+        choices=[choice_data],
+        object="chat.completion",
+        usage=usage
+    )
+
+
+async def predict_stream(model_id, gen_params):
+    output = ""
+    is_function_call = False
+    has_send_first_chunk = False
+    created_time = int(time.time())
+    function_name = None
+    response_id = generate_id('chatcmpl-', 29)
+    system_fingerprint = generate_id('fp_', 9)
+    tools = {tool['function']['name'] for tool in gen_params['tools']} if gen_params['tools'] else {}
+    async for new_response in generate_stream_glm4(gen_params):
+        decoded_unicode = new_response["text"]
+        delta_text = decoded_unicode[len(output):]
+        output = decoded_unicode
+        lines = output.strip().split("\n")
+
+        # Check whether the output is a tool call.
+        # This is a simple tool-name comparison. It cannot guarantee that every non-tool output is intercepted, for example in special cases such as misaligned arguments.
+        ## TODO: if you want to do more processing, refine the logic here.
+
+        if not is_function_call and len(lines) >= 2:
+            first_line = lines[0].strip()
+            if first_line in tools:
+                is_function_call = True
+                function_name = first_line
+
+        # Tool call response
+        if is_function_call:
+            if not has_send_first_chunk:
+                function_call = {"name": function_name, "arguments": ""}
+                tool_call = ChatCompletionMessageToolCall(
+                    index=0,
+                    id=generate_id('call_', 24),
+                    function=FunctionCall(**function_call),
+                    type="function"
+                )
+                message = DeltaMessage(
+                    content=None,
+                    role="assistant",
+                    function_call=None,
+                    tool_calls=[tool_call]
+                )
+                choice_data = ChatCompletionResponseStreamChoice(
+                    index=0,
+                    delta=message,
+                    finish_reason=None
+                )
+                chunk = ChatCompletionResponse(
+                    model=model_id,
+                    id=response_id,
+                    choices=[choice_data],
+                    created=created_time,
+                    system_fingerprint=system_fingerprint,
+                    object="chat.completion.chunk"
+                )
+                yield ""
+                yield chunk.model_dump_json(exclude_unset=True)
+                has_send_first_chunk = True
+
+            function_call = {"name": None, "arguments": delta_text}
+            tool_call = ChatCompletionMessageToolCall(
+                index=0,
+                id=None,
+                function=FunctionCall(**function_call),
+                type="function"
+            )
+            message = DeltaMessage(
+                content=None,
+                role=None,
+                function_call=None,
+                tool_calls=[tool_call]
+            )
+            choice_data = ChatCompletionResponseStreamChoice(
+                index=0,
+                delta=message,
+                finish_reason=None
+            )
+            chunk = ChatCompletionResponse(
+                model=model_id,
+                id=response_id,
+                choices=[choice_data],
+                created=created_time,
+                system_fingerprint=system_fingerprint,
+                object="chat.completion.chunk"
+            )
+            yield chunk.model_dump_json(exclude_unset=True)
+
+        # The user requested a function call, but it is not yet clear whether this output is one
+        elif (gen_params["tools"] and gen_params["tool_choice"] != "none") or is_function_call:
+            continue
+
+        # Regular response
+        else:
+            finish_reason = new_response.get("finish_reason", None)
+            if not has_send_first_chunk:
+                message = DeltaMessage(
+                    content="",
+                    role="assistant",
+                    function_call=None,
+                )
+                choice_data = ChatCompletionResponseStreamChoice(
+                    index=0,
+                    delta=message,
+                    finish_reason=finish_reason
+                )
+                chunk = ChatCompletionResponse(
+                    model=model_id,
+                    id=response_id,
+                    choices=[choice_data],
+                    created=created_time,
+                    system_fingerprint=system_fingerprint,
+                    object="chat.completion.chunk"
+                )
+                yield chunk.model_dump_json(exclude_unset=True)
+                has_send_first_chunk = True
+
+            message = DeltaMessage(
+                content=delta_text,
+                role="assistant",
+                function_call=None,
+            )
+            choice_data = ChatCompletionResponseStreamChoice(
+                index=0,
+                delta=message,
+                finish_reason=finish_reason
+            )
+            chunk = ChatCompletionResponse(
+                model=model_id,
+                id=response_id,
+                choices=[choice_data],
+                created=created_time,
+                system_fingerprint=system_fingerprint,
+                object="chat.completion.chunk"
+            )
+            yield chunk.model_dump_json(exclude_unset=True)
+
+    # A tool call needs one extra chunk to align with the OpenAI interface
+    if is_function_call:
+        yield ChatCompletionResponse(
+            model=model_id,
+            id=response_id,
+            system_fingerprint=system_fingerprint,
+            choices=[
+                ChatCompletionResponseStreamChoice(
+                    index=0,
+                    delta=DeltaMessage(
+                        content=None,
+                        role=None,
+                        function_call=None,
+                    ),
+                    finish_reason="tool_calls"
+                )],
+            created=created_time,
+            object="chat.completion.chunk",
+            usage=None
+        ).model_dump_json(exclude_unset=True)
+    yield '[DONE]'
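The tool-call check in `predict_stream` compares the first line of the accumulated output against the set of tool names. A minimal standalone sketch of that heuristic follows; `detect_tool_call` and the tool name `get_weather` are hypothetical and not part of the server code.

```python
from typing import Optional, Set


def detect_tool_call(output: str, tool_names: Set[str]) -> Optional[str]:
    """Return the tool name if the first line of `output` names a known tool."""
    lines = output.strip().split("\n")
    # As in predict_stream, wait for at least two lines so that the
    # first line is complete before it is compared.
    if len(lines) >= 2:
        first_line = lines[0].strip()
        if first_line in tool_names:
            return first_line
    return None


tool_names = {"get_weather"}  # hypothetical tool name
print(detect_tool_call('get_weather\n{"city": "Beijing"}', tool_names))
print(detect_tool_call("The weather is fine.\nAnything else?", tool_names))
print(detect_tool_call("get_weather", tool_names))  # only one line so far
```

As the original comment warns, this only matches on the name line. It does not validate the arguments, so a plain-text answer whose first line happens to equal a tool name would be misclassified.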
+
+
+async def parse_output_text(model_id: str, value: str, function_call: ChoiceDeltaToolCallFunction = None):
+    delta = DeltaMessage(role="assistant", content=value)
+    if function_call is not None:
+        delta.function_call = function_call
+
+    choice_data = ChatCompletionResponseStreamChoice(
+        index=0,
+        delta=delta,
+        finish_reason=None
+    )
+    chunk = ChatCompletionResponse(
+        model=model_id,
+        choices=[choice_data],
+        object="chat.completion.chunk"
+    )
+    yield "{}".format(chunk.model_dump_json(exclude_unset=True))
+    yield '[DONE]'
+
+
+if __name__ == "__main__":
+    tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH, trust_remote_code=True)
+    engine_args = AsyncEngineArgs(
+        model=MODEL_PATH,
+        tokenizer=MODEL_PATH,
+        # If you have multiple GPUs, set this to the number of GPUs
+        tensor_parallel_size=1,
+        dtype="bfloat16",
+        trust_remote_code=True,
+        # Fraction of GPU memory to use. Pick a value that fits your GPU: for example, to use only 24 GB of an 80 GB card, set 24/80 = 0.3
+        gpu_memory_utilization=0.9,
+        enforce_eager=True,
+        worker_use_ray=False,
+        engine_use_ray=False,
+        disable_log_requests=True,
+        max_model_len=MAX_MODEL_LENGTH,
+    )
+    engine = AsyncLLMEngine.from_engine_args(engine_args)
+    uvicorn.run(app, host='0.0.0.0', port=8000, workers=1)
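On the client side, the streamed chunks have to be folded back into one message. The sketch below assumes the caller has already extracted the `data` payload of each server-sent event as a string, and that chunks follow the `choices[0].delta` shape built in `predict_stream`. The function `collect_stream` and the sample payloads are illustrative, not part of the repository.

```python
import json
from typing import Iterable


def collect_stream(payloads: Iterable[str]) -> dict:
    """Fold streamed chunk payloads into the final text or tool call."""
    result = {"content": "", "tool_name": None, "tool_arguments": "", "finish_reason": None}
    for payload in payloads:
        if payload == "[DONE]":
            break
        if not payload:
            continue  # predict_stream can yield an empty string before a tool call
        choice = json.loads(payload)["choices"][0]
        delta = choice.get("delta", {})
        result["content"] += delta.get("content") or ""
        for tool_call in delta.get("tool_calls") or []:
            function = tool_call.get("function") or {}
            if function.get("name"):
                result["tool_name"] = function["name"]
            result["tool_arguments"] += function.get("arguments") or ""
        if choice.get("finish_reason"):
            result["finish_reason"] = choice["finish_reason"]
    return result


# Hand-written sample payloads that imitate a plain text stream.
sample = [
    json.dumps({"choices": [{"index": 0, "delta": {"role": "assistant", "content": ""}}]}),
    json.dumps({"choices": [{"index": 0, "delta": {"role": "assistant", "content": "Hel"}}]}),
    json.dumps({"choices": [{"index": 0, "delta": {"role": "assistant", "content": "lo"},
                             "finish_reason": "stop"}]}),
    "[DONE]",
]
print(collect_stream(sample))
```

Tool-call arguments arrive as string fragments, so they are concatenated and should only be parsed as JSON once the stream has finished.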
requirements.txt
ADDED
@@ -0,0 +1,29 @@
+# use vllm
+# vllm>=0.5.0
+
+torch>=2.3.0
+torchvision>=0.18.0
+transformers==4.40.0
+huggingface-hub>=0.23.1
+sentencepiece>=0.2.0
+pydantic>=2.7.1
+timm>=0.9.16
+tiktoken>=0.7.0
+accelerate>=0.30.1
+sentence_transformers>=2.7.0
+
+# web demo
+gradio>=4.33.0
+
+# openai demo
+openai>=1.34.0
+einops>=0.7.0
+sse-starlette>=2.1.0
+
+# INT4
+# bitsandbytes>=0.43.1
+bitsandbytes>=0.42.0
+
+
+# PEFT, not needed if you don't use a PEFT fine-tuned model.
+peft>=0.11.0
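The pins above mix `>=` minimums with one exact pin (`transformers==4.40.0`). As a rough illustration of how such lines are evaluated, here is a simplified checker. `parse_version` and `satisfies` are hypothetical helpers that handle only plain numeric versions and the two operators used in this file; for real checks, a full implementation such as the `packaging` library is the safer choice.

```python
import re


def parse_version(text: str) -> tuple:
    """Turn '2.3.0' or '2.3.1+cu121' into a tuple of leading integers."""
    numbers = []
    for part in text.split("+")[0].split("."):
        match = re.match(r"\d+", part)
        if match is None:
            break
        numbers.append(int(match.group()))
    return tuple(numbers)


def satisfies(installed: str, requirement: str) -> bool:
    """Check an installed version against one 'name>=x.y.z' or 'name==x.y.z' line."""
    _name, operator, wanted = re.match(r"([A-Za-z0-9_.-]+)(>=|==)(.+)", requirement).groups()
    have, want = parse_version(installed), parse_version(wanted.strip())
    # Pad with zeros so that '2.3' compares equal to '2.3.0'.
    width = max(len(have), len(want))
    have += (0,) * (width - len(have))
    want += (0,) * (width - len(want))
    return have >= want if operator == ">=" else have == want


print(satisfies("2.3.1+cu121", "torch>=2.3.0"))
print(satisfies("4.41.2", "transformers==4.40.0"))
```

The second call shows why the exact pin matters: a newer `transformers` release does not satisfy `==4.40.0`, even though it would satisfy a `>=` constraint.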
trans_batch_demo.py
ADDED
@@ -0,0 +1,90 @@
+"""
+
+This example sends batched requests to glm-4-9b.
+You need to build the conversation format yourself and then call the batch function to make batch requests.
+Please note that memory consumption is significantly higher in this demo.
+
+"""
+
+from typing import Optional, Union
+from transformers import AutoModel, AutoTokenizer, LogitsProcessorList
+
+MODEL_PATH = 'THUDM/glm-4-9b-chat'
+
+tokenizer = AutoTokenizer.from_pretrained(
+    MODEL_PATH,
+    trust_remote_code=True,
+    encode_special_tokens=True)
+model = AutoModel.from_pretrained(MODEL_PATH, trust_remote_code=True, device_map="auto").eval()
+
+
+def process_model_outputs(inputs, outputs, tokenizer):
+    responses = []
+    for input_ids, output_ids in zip(inputs.input_ids, outputs):
+        response = tokenizer.decode(output_ids[len(input_ids):], skip_special_tokens=True).strip()
+        responses.append(response)
+    return responses
+
+
+def batch(
+        model,
+        tokenizer,
+        messages: Union[str, list[str]],
+        max_input_tokens: int = 8192,
+        max_new_tokens: int = 8192,
+        num_beams: int = 1,
+        do_sample: bool = True,
+        top_p: float = 0.8,
+        temperature: float = 0.8,
+        logits_processor: Optional[LogitsProcessorList] = LogitsProcessorList(),
+):
+    messages = [messages] if isinstance(messages, str) else messages
+    batched_inputs = tokenizer(messages, return_tensors="pt", padding="max_length", truncation=True,
+                               max_length=max_input_tokens).to(model.device)
+
+    gen_kwargs = {
+        "max_new_tokens": max_new_tokens,
+        "num_beams": num_beams,
+        "do_sample": do_sample,
+        "top_p": top_p,
+        "temperature": temperature,
+        "logits_processor": logits_processor,
+        "eos_token_id": model.config.eos_token_id
+    }
+    batched_outputs = model.generate(**batched_inputs, **gen_kwargs)
+    batched_response = process_model_outputs(batched_inputs, batched_outputs, tokenizer)
+    return batched_response
+
+
+if __name__ == "__main__":
+
+    batch_message = [
+        [
+            {"role": "user", "content": "Why couldn't my mom and dad take me to their wedding?"},
+            {"role": "assistant", "content": "Because you had not been born yet when they got married."},
+            {"role": "user", "content": "What was the question I just asked?"}
+        ],
+        [
+            {"role": "user", "content": "Hello, who are you?"}
+        ]
+    ]
+
+    batch_inputs = []
+    max_input_tokens = 1024
+    for i, messages in enumerate(batch_message):
+        new_batch_input = tokenizer.apply_chat_template(messages, add_generation_prompt=True, tokenize=False)
+        max_input_tokens = max(max_input_tokens, len(new_batch_input))
+        batch_inputs.append(new_batch_input)
+    gen_kwargs = {
+        "max_input_tokens": max_input_tokens,
+        "max_new_tokens": 8192,
+        "do_sample": True,
+        "top_p": 0.8,
+        "temperature": 0.8,
+        "num_beams": 1,
+    }
+
+    batch_responses = batch(model, tokenizer, batch_inputs, **gen_kwargs)
+    for response in batch_responses:
+        print("=" * 10)
+        print(response)
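The batch demo pads every prompt to a common length and later slices the prompt off each output with `output_ids[len(input_ids):]`. The sketch below shows why that slice works, using plain Python lists in place of tensors. It assumes left padding (the side the stress test in this commit sets explicitly); `pad_left`, `strip_prompt`, and all token IDs are made up for illustration.

```python
def pad_left(batch, pad_id, length=None):
    """Left-pad token ID lists to a common width and build attention masks."""
    width = length if length is not None else max(len(ids) for ids in batch)
    input_ids = [[pad_id] * (width - len(ids)) + ids for ids in batch]
    attention_mask = [[0] * (width - len(ids)) + [1] * len(ids) for ids in batch]
    return input_ids, attention_mask


def strip_prompt(padded_inputs, outputs):
    """Drop the padded prompt from each output, as process_model_outputs does."""
    return [out[len(inp):] for inp, out in zip(padded_inputs, outputs)]


prompts = [[5, 6, 7], [8]]                      # made-up token IDs
input_ids, attention_mask = pad_left(prompts, pad_id=0)
print(input_ids)        # [[5, 6, 7], [0, 0, 8]]
print(attention_mask)   # [[1, 1, 1], [0, 0, 1]]

# Pretend the model appended new tokens to each padded row.
outputs = [input_ids[0] + [100, 101], input_ids[1] + [200]]
print(strip_prompt(input_ids, outputs))         # [[100, 101], [200]]
```

Because every row has the same padded width, one slice offset removes the prompt from all rows. Note that the demo derives `max_input_tokens` from the character length of the templated prompt, with a floor of 1024, so rows are usually padded well beyond the longest prompt.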
trans_cli_demo.py
ADDED
@@ -0,0 +1,119 @@
+"""
+This script creates a CLI demo with transformers backend for the glm-4-9b model,
+allowing users to interact with the model through a command-line interface.
+
+Usage:
+- Run the script to start the CLI demo.
+- Interact with the model by typing questions and receiving responses.
+
+Note: The script includes a modification to handle markdown to plain text conversion,
+ensuring that the CLI interface displays formatted text correctly.
+
+To use flash attention, install flash-attn and add attn_implementation="flash_attention_2" when loading the model.
+"""
+
+import os
+import torch
+from threading import Thread
+from transformers import AutoTokenizer, StoppingCriteria, StoppingCriteriaList, TextIteratorStreamer, AutoModel
+
+MODEL_PATH = os.environ.get('MODEL_PATH', 'THUDM/glm-4-9b-chat')
+
+print("MODEL_PATH: " + MODEL_PATH)
+
+## If you use a PEFT model:
+# def load_model_and_tokenizer(model_dir, trust_remote_code: bool = True):
+#     if (model_dir / 'adapter_config.json').exists():
+#         model = AutoModel.from_pretrained(
+#             model_dir, trust_remote_code=trust_remote_code, device_map='auto'
+#         )
+#         tokenizer_dir = model.peft_config['default'].base_model_name_or_path
+#     else:
+#         model = AutoModel.from_pretrained(
+#             model_dir, trust_remote_code=trust_remote_code, device_map='auto'
+#         )
+#         tokenizer_dir = model_dir
+#     tokenizer = AutoTokenizer.from_pretrained(
+#         tokenizer_dir, trust_remote_code=trust_remote_code, use_fast=False
+#     )
+#     return model, tokenizer
+
+
+tokenizer = AutoTokenizer.from_pretrained(
+    MODEL_PATH,
+    trust_remote_code=True,
+    encode_special_tokens=True
+)
+
+model = AutoModel.from_pretrained(
+    MODEL_PATH,
+    trust_remote_code=True,
+    # attn_implementation="flash_attention_2", # Use Flash Attention
+    # torch_dtype=torch.bfloat16, #using flash-attn must use bfloat16 or float16
+    device_map="auto").eval()
+
+
+class StopOnTokens(StoppingCriteria):
+    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
+        stop_ids = model.config.eos_token_id
+        for stop_id in stop_ids:
+            if input_ids[0][-1] == stop_id:
+                return True
+        return False
+
+
+if __name__ == "__main__":
+    history = []
+    max_length = 8192
+    top_p = 0.8
+    temperature = 0.6
+    stop = StopOnTokens()
+
+    print("Welcome to the GLM-4-9B CLI chat. Type your messages below.")
+    while True:
+        user_input = input("\nYou: ")
+        if user_input.lower() in ["exit", "quit"]:
+            break
+        history.append([user_input, ""])
+
+        messages = []
+        for idx, (user_msg, model_msg) in enumerate(history):
+            if idx == len(history) - 1 and not model_msg:
+                messages.append({"role": "user", "content": user_msg})
+                break
+            if user_msg:
+                messages.append({"role": "user", "content": user_msg})
+            if model_msg:
+                messages.append({"role": "assistant", "content": model_msg})
+        model_inputs = tokenizer.apply_chat_template(
+            messages,
+            add_generation_prompt=True,
+            tokenize=True,
+            return_tensors="pt"
+        ).to(model.device)
+        streamer = TextIteratorStreamer(
+            tokenizer=tokenizer,
+            timeout=60,
+            skip_prompt=True,
+            skip_special_tokens=True
+        )
+        generate_kwargs = {
+            "input_ids": model_inputs,
+            "streamer": streamer,
+            "max_new_tokens": max_length,
+            "do_sample": True,
+            "top_p": top_p,
+            "temperature": temperature,
+            "stopping_criteria": StoppingCriteriaList([stop]),
+            "repetition_penalty": 1.2,
+            "eos_token_id": model.config.eos_token_id,
+        }
+        t = Thread(target=model.generate, kwargs=generate_kwargs)
+        t.start()
+        print("GLM-4:", end="", flush=True)
+        for new_token in streamer:
+            if new_token:
+                print(new_token, end="", flush=True)
+                history[-1][1] += new_token
+
+        history[-1][1] = history[-1][1].strip()
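The CLI loop rebuilds the chat messages from `history` on every turn. The same logic can be pulled into a small pure function, sketched here for illustration; `history_to_messages` is a hypothetical name and does not appear in the demo.

```python
def history_to_messages(history):
    """Convert [[user_msg, model_msg], ...] into chat-template message dicts."""
    messages = []
    for idx, (user_msg, model_msg) in enumerate(history):
        # The last pair holds the pending question; its reply is still empty.
        if idx == len(history) - 1 and not model_msg:
            messages.append({"role": "user", "content": user_msg})
            break
        if user_msg:
            messages.append({"role": "user", "content": user_msg})
        if model_msg:
            messages.append({"role": "assistant", "content": model_msg})
    return messages


history = [["Hi", "Hello! How can I help?"], ["What is GLM-4?", ""]]
for message in history_to_messages(history):
    print(message)
```

Keeping this step separate from model calls makes it easy to test without loading any weights, and the resulting list is what `tokenizer.apply_chat_template` receives in the demo.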
trans_cli_vision_demo.py
ADDED
@@ -0,0 +1,123 @@
+"""
+This script creates a CLI demo with transformers backend for the glm-4v-9b model,
+allowing users to interact with the model through a command-line interface.
+
+Usage:
+- Run the script to start the CLI demo.
+- Interact with the model by typing questions and receiving responses.
+
+Note: The script includes a modification to handle markdown to plain text conversion,
+ensuring that the CLI interface displays formatted text correctly.
+"""
+
+import os
+import torch
+from threading import Thread
+from transformers import (
+    AutoTokenizer,
+    StoppingCriteria,
+    StoppingCriteriaList,
+    TextIteratorStreamer, AutoModel, BitsAndBytesConfig
+)
+
+from PIL import Image
+
+MODEL_PATH = os.environ.get('MODEL_PATH', 'THUDM/glm-4v-9b')
+
+tokenizer = AutoTokenizer.from_pretrained(
+    MODEL_PATH,
+    trust_remote_code=True,
+    encode_special_tokens=True
+)
+model = AutoModel.from_pretrained(
+    MODEL_PATH,
+    trust_remote_code=True,
+    # attn_implementation="flash_attention_2", # Use Flash Attention
+    # torch_dtype=torch.bfloat16, # using flash-attn must use bfloat16 or float16,
+    device_map="auto",
+).eval()
+
+
+## For INT4 inference
+# model = AutoModel.from_pretrained(
+#     MODEL_PATH,
+#     trust_remote_code=True,
+#     quantization_config=BitsAndBytesConfig(load_in_4bit=True),
+#     torch_dtype=torch.bfloat16,
+#     low_cpu_mem_usage=True
+# ).eval()
+
+class StopOnTokens(StoppingCriteria):
+    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
+        stop_ids = model.config.eos_token_id
+        for stop_id in stop_ids:
+            if input_ids[0][-1] == stop_id:
+                return True
+        return False
+
+
+if __name__ == "__main__":
+    history = []
+    max_length = 1024
+    top_p = 0.8
+    temperature = 0.6
+    stop = StopOnTokens()
+    uploaded = False
+    image = None
+    print("Welcome to the GLM-4-9B CLI chat. Type your messages below.")
+    image_path = input("Image Path:")
+    try:
+        image = Image.open(image_path).convert("RGB")
+    except Exception:
+        print("Invalid image path. Continuing with text conversation.")
+    while True:
+        user_input = input("\nYou: ")
+        if user_input.lower() in ["exit", "quit"]:
+            break
+        history.append([user_input, ""])
+
+        messages = []
+        for idx, (user_msg, model_msg) in enumerate(history):
+            if idx == len(history) - 1 and not model_msg:
+                messages.append({"role": "user", "content": user_msg})
+                if image and not uploaded:
+                    messages[-1].update({"image": image})
+                    uploaded = True
+                break
+            if user_msg:
+                messages.append({"role": "user", "content": user_msg})
+            if model_msg:
+                messages.append({"role": "assistant", "content": model_msg})
+        model_inputs = tokenizer.apply_chat_template(
+            messages,
+            add_generation_prompt=True,
+            tokenize=True,
+            return_tensors="pt",
+            return_dict=True
+        ).to(next(model.parameters()).device)
+        streamer = TextIteratorStreamer(
+            tokenizer=tokenizer,
+            timeout=60,
+            skip_prompt=True,
+            skip_special_tokens=True
+        )
+        generate_kwargs = {
+            **model_inputs,
+            "streamer": streamer,
+            "max_new_tokens": max_length,
+            "do_sample": True,
+            "top_p": top_p,
+            "temperature": temperature,
+            "stopping_criteria": StoppingCriteriaList([stop]),
+            "repetition_penalty": 1.2,
+            "eos_token_id": [151329, 151336, 151338],
+        }
+        t = Thread(target=model.generate, kwargs=generate_kwargs)
+        t.start()
+        print("GLM-4V:", end="", flush=True)
+        for new_token in streamer:
+            if new_token:
+                print(new_token, end="", flush=True)
+                history[-1][1] += new_token
+
+        history[-1][1] = history[-1][1].strip()
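`StopOnTokens` ends generation when the most recent token is one of the end-of-sequence IDs. The check itself is simple enough to sketch with plain lists, as below. `should_stop` is a hypothetical helper; the ID list is copied from the `eos_token_id` value the vision demo passes to `generate`, and the other token IDs are made up.

```python
EOS_TOKEN_IDS = [151329, 151336, 151338]  # as passed to generate() in the vision demo


def should_stop(generated_ids, stop_ids=EOS_TOKEN_IDS):
    """Return True when the last generated token is one of the stop IDs."""
    if isinstance(stop_ids, int):
        stop_ids = [stop_ids]  # tolerate a config that stores a single ID
    return bool(generated_ids) and generated_ids[-1] in stop_ids


print(should_stop([101, 202, 151336]))  # True: ends on a stop token
print(should_stop([101, 151336, 202]))  # False: stop token is not last
print(should_stop([]))                  # False: nothing generated yet
```

The `isinstance` guard is an addition for illustration. The demo class iterates over `model.config.eos_token_id` directly, which works only when that value is a list.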
trans_stress_test.py
ADDED
@@ -0,0 +1,135 @@
+import argparse
+import time
+from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer, BitsAndBytesConfig
+import torch
+from threading import Thread
+
+MODEL_PATH = 'THUDM/glm-4-9b-chat'
+
+
+def stress_test(token_len, n, num_gpu):
+    device = torch.device(f"cuda:{num_gpu - 1}" if torch.cuda.is_available() and num_gpu > 0 else "cpu")
+    tokenizer = AutoTokenizer.from_pretrained(
+        MODEL_PATH,
+        trust_remote_code=True,
+        padding_side="left"
+    )
+    model = AutoModelForCausalLM.from_pretrained(
+        MODEL_PATH,
+        trust_remote_code=True,
+        torch_dtype=torch.bfloat16
+    ).to(device).eval()
+
+    # Use INT4 weights for inference
+    # model = AutoModelForCausalLM.from_pretrained(
+    #     MODEL_PATH,
+    #     trust_remote_code=True,
+    #     quantization_config=BitsAndBytesConfig(load_in_4bit=True),
+    #     low_cpu_mem_usage=True,
+    # ).eval()
+
+    times = []
+    decode_times = []
+
+    print("Warming up...")
+    vocab_size = tokenizer.vocab_size
+    warmup_token_len = 20
+    random_token_ids = torch.randint(3, vocab_size - 200, (warmup_token_len - 5,), dtype=torch.long)
+    start_tokens = [151331, 151333, 151336, 198]
+    end_tokens = [151337]
+    input_ids = torch.tensor(start_tokens + random_token_ids.tolist() + end_tokens, dtype=torch.long).unsqueeze(0).to(
+        device)
+    attention_mask = torch.ones_like(input_ids, dtype=torch.bfloat16).to(device)
+    position_ids = torch.arange(len(input_ids[0]), dtype=torch.bfloat16).unsqueeze(0).to(device)
+    warmup_inputs = {
+        'input_ids': input_ids,
+        'attention_mask': attention_mask,
+        'position_ids': position_ids
+    }
+    with torch.no_grad():
+        _ = model.generate(
+            input_ids=warmup_inputs['input_ids'],
+            attention_mask=warmup_inputs['attention_mask'],
+            max_new_tokens=2048,
+            do_sample=False,
+            repetition_penalty=1.0,
+            eos_token_id=[151329, 151336, 151338]
+        )
+    print("Warming up complete. Starting stress test...")
+
+    for i in range(n):
+        random_token_ids = torch.randint(3, vocab_size - 200, (token_len - 5,), dtype=torch.long)
+        input_ids = torch.tensor(start_tokens + random_token_ids.tolist() + end_tokens, dtype=torch.long).unsqueeze(
+            0).to(device)
+        attention_mask = torch.ones_like(input_ids, dtype=torch.bfloat16).to(device)
+        position_ids = torch.arange(len(input_ids[0]), dtype=torch.bfloat16).unsqueeze(0).to(device)
+        test_inputs = {
+            'input_ids': input_ids,
+            'attention_mask': attention_mask,
+            'position_ids': position_ids
+        }
+
+        streamer = TextIteratorStreamer(
+            tokenizer=tokenizer,
+            timeout=36000,
+            skip_prompt=True,
+            skip_special_tokens=True
+        )
+
+        generate_kwargs = {
+            "input_ids": test_inputs['input_ids'],
+            "attention_mask": test_inputs['attention_mask'],
+            "max_new_tokens": 512,
+            "do_sample": False,
+            "repetition_penalty": 1.0,
+            "eos_token_id": [151329, 151336, 151338],
+            "streamer": streamer
+        }
+
+        start_time = time.time()
+        t = Thread(target=model.generate, kwargs=generate_kwargs)
+        t.start()
+
+        first_token_time = None
+        all_token_times = []
+
+        for token in streamer:
+            current_time = time.time()
+            if first_token_time is None:
+                first_token_time = current_time
+                times.append(first_token_time - start_time)
+            all_token_times.append(current_time)
+
+        t.join()
+        end_time = time.time()
+
+        avg_decode_time_per_token = len(all_token_times) / (end_time - first_token_time) if all_token_times else 0
+        decode_times.append(avg_decode_time_per_token)
+        print(
+            f"Iteration {i + 1}/{n} - Prefilling Time: {times[-1]:.4f} seconds - Average Decode Speed: {avg_decode_time_per_token:.4f} tokens/second")
+
+        torch.cuda.empty_cache()
+
+    avg_first_token_time = sum(times) / n
+    avg_decode_time = sum(decode_times) / n
+    print(f"\nAverage First Token Time over {n} iterations: {avg_first_token_time:.4f} seconds")
+    print(f"Average Decode Speed over {n} iterations: {avg_decode_time:.4f} tokens/second")
+    return times, avg_first_token_time, decode_times, avg_decode_time
+
+
+def main():
+    parser = argparse.ArgumentParser(description="Stress test for model inference")
+    parser.add_argument('--token_len', type=int, default=1000, help='Number of tokens for each test')
+    parser.add_argument('--n', type=int, default=3, help='Number of iterations for the stress test')
+    parser.add_argument('--num_gpu', type=int, default=1, help='Number of GPUs to use for inference')
+    args = parser.parse_args()
+
+    token_len = args.token_len
+    n = args.n
+    num_gpu = args.num_gpu
+
+    stress_test(token_len, n, num_gpu)
+
+
+if __name__ == "__main__":
+    main()
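The stress test reports two numbers per iteration: the time until the first streamed chunk arrives (prefill) and the decode rate afterwards. The arithmetic can be sketched on plain timestamps as follows; `summarize_timings` is a hypothetical helper and the timestamps are invented.

```python
def summarize_timings(start_time, token_times, end_time):
    """Return (time_to_first_token_seconds, chunks_per_second)."""
    if not token_times:
        return None, 0.0
    first_token_time = token_times[0]
    time_to_first_token = first_token_time - start_time
    decode_window = end_time - first_token_time
    # Same ratio as the script: streamed chunks divided by time since the first one.
    rate = len(token_times) / decode_window if decode_window > 0 else 0.0
    return time_to_first_token, rate


ttft, rate = summarize_timings(
    start_time=10.0,
    token_times=[10.5, 10.75, 11.0, 11.25],
    end_time=11.5,
)
print(f"time to first token: {ttft:.2f} s, decode rate: {rate:.2f} per second")
```

Note that the value is a rate (items per second), not a time per token, and that a streamer may emit text in chunks that do not map one-to-one to tokens, so the figure is best read as an approximation.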
trans_web_demo.py
ADDED
@@ -0,0 +1,179 @@
1 |
+
"""
|
2 |
+
This script creates an interactive web demo for the GLM-4-9B model using Gradio,
|
3 |
+
a Python library for building quick and easy UI components for machine learning models.
|
4 |
+
It's designed to showcase the capabilities of the GLM-4-9B model in a user-friendly interface,
|
5 |
+
allowing users to interact with the model through a chat-like interface.
|
6 |
+
"""
|
7 |
+
|
8 |
+
import os
|
9 |
+
from pathlib import Path
|
10 |
+
from threading import Thread
|
11 |
+
from typing import Union
|
12 |
+
|
13 |
+
import gradio as gr
|
14 |
+
import torch
|
15 |
+
from peft import AutoPeftModelForCausalLM, PeftModelForCausalLM
|
16 |
+
from transformers import (
|
17 |
+
AutoModelForCausalLM,
|
18 |
+
AutoTokenizer,
|
19 |
+
PreTrainedModel,
|
20 |
+
PreTrainedTokenizer,
|
21 |
+
PreTrainedTokenizerFast,
|
22 |
+
StoppingCriteria,
|
23 |
+
StoppingCriteriaList,
|
24 |
+
TextIteratorStreamer
|
25 |
+
)
|
26 |
+
|
27 |
+
ModelType = Union[PreTrainedModel, PeftModelForCausalLM]
|
28 |
+
TokenizerType = Union[PreTrainedTokenizer, PreTrainedTokenizerFast]
|
29 |
+
|
30 |
+
MODEL_PATH = os.environ.get('MODEL_PATH', 'THUDM/glm-4-9b-chat')
|
31 |
+
MODEL_PATH = "/Users/zmac/.cache/huggingface/hub/models--THUDM--glm-4-9b-chat/snapshots/04419001bc63e05e70991ade6da1f91c4aeec278"
|
32 |
+
TOKENIZER_PATH = os.environ.get("TOKENIZER_PATH", MODEL_PATH)
|
33 |
+
|
34 |
+
|
+def _resolve_path(path: Union[str, Path]) -> Path:
+    return Path(path).expanduser().resolve()
+
+
+def load_model_and_tokenizer(
+        model_dir: Union[str, Path], trust_remote_code: bool = True
+) -> tuple[ModelType, TokenizerType]:
+    model_dir = _resolve_path(model_dir)
+    if (model_dir / 'adapter_config.json').exists():
+        model = AutoPeftModelForCausalLM.from_pretrained(
+            model_dir, trust_remote_code=trust_remote_code, device_map='auto'
+        )
+        tokenizer_dir = model.peft_config['default'].base_model_name_or_path
+    else:
+        model = AutoModelForCausalLM.from_pretrained(
+            model_dir, trust_remote_code=trust_remote_code, device_map='auto'
+        )
+        tokenizer_dir = model_dir
+    tokenizer = AutoTokenizer.from_pretrained(
+        tokenizer_dir, trust_remote_code=trust_remote_code, use_fast=False
+    )
+    return model, tokenizer
+
+
+model, tokenizer = load_model_and_tokenizer(MODEL_PATH, trust_remote_code=True)
+
+
+class StopOnTokens(StoppingCriteria):
+    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
+        stop_ids = model.config.eos_token_id
+        for stop_id in stop_ids:
+            if input_ids[0][-1] == stop_id:
+                return True
+        return False
+
+
+def parse_text(text):
+    # Convert Markdown-style text into HTML for the chat window.
+    lines = text.split("\n")
+    lines = [line for line in lines if line != ""]
+    count = 0
+    for i, line in enumerate(lines):
+        if "```" in line:
+            count += 1
+            items = line.split('`')
+            if count % 2 == 1:
+                lines[i] = f'<pre><code class="language-{items[-1]}">'
+            else:
+                lines[i] = '<br></code></pre>'
+        else:
+            if i > 0:
+                if count % 2 == 1:
+                    # Inside a code fence: escape characters that HTML or Markdown would interpret.
+                    line = line.replace("`", "\\`")
+                    line = line.replace("<", "&lt;")
+                    line = line.replace(">", "&gt;")
+                    line = line.replace(" ", "&nbsp;")
+                    line = line.replace("*", "&ast;")
+                    line = line.replace("_", "&lowbar;")
+                    line = line.replace("-", "&#45;")
+                    line = line.replace(".", "&#46;")
+                    line = line.replace("!", "&#33;")
+                    line = line.replace("(", "&#40;")
+                    line = line.replace(")", "&#41;")
+                    line = line.replace("$", "&#36;")
+                lines[i] = "<br>" + line
+    text = "".join(lines)
+    return text
+
+
+def predict(history, prompt, max_length, top_p, temperature):
+    stop = StopOnTokens()
+    messages = []
+    if prompt:
+        messages.append({"role": "system", "content": prompt})
+    for idx, (user_msg, model_msg) in enumerate(history):
+        if prompt and idx == 0:
+            continue
+        if idx == len(history) - 1 and not model_msg:
+            messages.append({"role": "user", "content": user_msg})
+            break
+        if user_msg:
+            messages.append({"role": "user", "content": user_msg})
+        if model_msg:
+            messages.append({"role": "assistant", "content": model_msg})
+
+    model_inputs = tokenizer.apply_chat_template(messages,
+                                                 add_generation_prompt=True,
+                                                 tokenize=True,
+                                                 return_tensors="pt").to(next(model.parameters()).device)
+    streamer = TextIteratorStreamer(tokenizer, timeout=60, skip_prompt=True, skip_special_tokens=True)
+    generate_kwargs = {
+        "input_ids": model_inputs,
+        "streamer": streamer,
+        "max_new_tokens": max_length,
+        "do_sample": True,
+        "top_p": top_p,
+        "temperature": temperature,
+        "stopping_criteria": StoppingCriteriaList([stop]),
+        "repetition_penalty": 1.2,
+        "eos_token_id": model.config.eos_token_id,
+    }
+    t = Thread(target=model.generate, kwargs=generate_kwargs)
+    t.start()
+    for new_token in streamer:
+        if new_token:
+            history[-1][1] += new_token
+        yield history
+
+
+with gr.Blocks() as demo:
+    gr.HTML("""<h1 align="center">GLM-4-9B Gradio Simple Chat Demo</h1>""")
+    chatbot = gr.Chatbot()
+
+    with gr.Row():
+        with gr.Column(scale=3):
+            with gr.Column(scale=12):
+                user_input = gr.Textbox(show_label=False, placeholder="Input...", lines=10, container=False)
+            with gr.Column(min_width=32, scale=1):
+                submitBtn = gr.Button("Submit")
+        with gr.Column(scale=1):
+            prompt_input = gr.Textbox(show_label=False, placeholder="Prompt", lines=10, container=False)
+            pBtn = gr.Button("Set Prompt")
+        with gr.Column(scale=1):
+            emptyBtn = gr.Button("Clear History")
+            max_length = gr.Slider(0, 32768, value=8192, step=1.0, label="Maximum length", interactive=True)
+            top_p = gr.Slider(0, 1, value=0.8, step=0.01, label="Top P", interactive=True)
+            temperature = gr.Slider(0.01, 1, value=0.6, step=0.01, label="Temperature", interactive=True)
+
+
+    def user(query, history):
+        return "", history + [[parse_text(query), ""]]
+
+
+    def set_prompt(prompt_text):
+        return [[parse_text(prompt_text), "成功设置prompt"]]  # "Prompt set successfully"
+
+
+    pBtn.click(set_prompt, inputs=[prompt_input], outputs=chatbot)
+
+    submitBtn.click(user, [user_input, chatbot], [user_input, chatbot], queue=False).then(
+        predict, [chatbot, prompt_input, max_length, top_p, temperature], chatbot
+    )
+    emptyBtn.click(lambda: (None, None), None, [chatbot, prompt_input], queue=False)
+
+demo.queue()
+demo.launch(server_name="127.0.0.1", server_port=8000, inbrowser=True, share=True)
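The history handling in `predict` (trans_web_demo.py) can be looked at in isolation. The following is a minimal sketch, not part of the commit: `build_messages` and the `Turn` alias are hypothetical names introduced here, and the function only mirrors the loop that turns Gradio's `[user, reply]` pairs into chat-template messages.

```python
from typing import Dict, List, Optional, Sequence

# Each turn is [user message, model reply]; the reply is "" while generation is pending.
Turn = Sequence[str]


def build_messages(history: Sequence[Turn], prompt: Optional[str] = None) -> List[Dict[str, str]]:
    """Hypothetical helper mirroring the message-building loop in predict()."""
    messages: List[Dict[str, str]] = []
    if prompt:
        messages.append({"role": "system", "content": prompt})
    for idx, (user_msg, model_msg) in enumerate(history):
        # With a prompt set, row 0 is assumed to be the "Set Prompt" row and is skipped.
        if prompt and idx == 0:
            continue
        # The last row has no reply yet: send only the user side and stop.
        if idx == len(history) - 1 and not model_msg:
            messages.append({"role": "user", "content": user_msg})
            break
        if user_msg:
            messages.append({"role": "user", "content": user_msg})
        if model_msg:
            messages.append({"role": "assistant", "content": model_msg})
    return messages


history = [["Hi", "Hello!"], ["What is GLM-4?", ""]]
for message in build_messages(history):
    print(message)
```

One consequence the sketch makes visible: when a prompt is typed but "Set Prompt" was never clicked, row 0 is a real user turn and is still skipped, so the model may receive only the system message. Whether that matters depends on how the UI is used.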
trans_web_vision_demo.py
ADDED
@@ -0,0 +1,121 @@
+"""
+This script creates a Gradio demo with a Transformers backend for the glm-4v-9b model, allowing users to interact with the model through a Gradio web UI.
+
+Usage:
+- Run the script to start the Gradio server.
+- Interact with the model via the web UI.
+
+Requirements:
+- Gradio package
+  - Type `pip install gradio` to install Gradio.
+"""
+
+import os
+import torch
+import gradio as gr
+from threading import Thread
+from transformers import (
+    AutoTokenizer,
+    StoppingCriteria,
+    StoppingCriteriaList,
+    TextIteratorStreamer, AutoModel, BitsAndBytesConfig
+)
+from PIL import Image
+import requests
+from io import BytesIO
+
+MODEL_PATH = os.environ.get('MODEL_PATH', 'THUDM/glm-4v-9b')
+
+tokenizer = AutoTokenizer.from_pretrained(
+    MODEL_PATH,
+    trust_remote_code=True,
+    encode_special_tokens=True
+)
+model = AutoModel.from_pretrained(
+    MODEL_PATH,
+    trust_remote_code=True,
+    device_map="auto",
+    torch_dtype=torch.bfloat16
+).eval()
+
+class StopOnTokens(StoppingCriteria):
+    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
+        stop_ids = model.config.eos_token_id
+        for stop_id in stop_ids:
+            if input_ids[0][-1] == stop_id:
+                return True
+        return False
+
+def get_image(image_path=None, image_url=None):
+    if image_path:
+        return Image.open(image_path).convert("RGB")
+    elif image_url:
+        response = requests.get(image_url)
+        return Image.open(BytesIO(response.content)).convert("RGB")
+    return None
+
+def chatbot(image_path=None, image_url=None, assistant_prompt=""):
+    image = get_image(image_path, image_url)
+
+    messages = [
+        {"role": "assistant", "content": assistant_prompt},
+        {"role": "user", "content": "", "image": image}
+    ]
+
+    model_inputs = tokenizer.apply_chat_template(
+        messages,
+        add_generation_prompt=True,
+        tokenize=True,
+        return_tensors="pt",
+        return_dict=True
+    ).to(next(model.parameters()).device)
+
+    streamer = TextIteratorStreamer(
+        tokenizer=tokenizer,
+        timeout=60,
+        skip_prompt=True,
+        skip_special_tokens=True
+    )
+
+    generate_kwargs = {
+        **model_inputs,
+        "streamer": streamer,
+        "max_new_tokens": 1024,
+        "do_sample": True,
+        "top_p": 0.8,
+        "temperature": 0.6,
+        "stopping_criteria": StoppingCriteriaList([StopOnTokens()]),
+        "repetition_penalty": 1.2,
+        "eos_token_id": [151329, 151336, 151338],
+    }
+
+    t = Thread(target=model.generate, kwargs=generate_kwargs)
+    t.start()
+
+    response = ""
+    for new_token in streamer:
+        if new_token:
+            response += new_token
+
+    return image, response.strip()
+
+with gr.Blocks() as demo:
+    demo.title = "GLM-4V-9B Image Recognition Demo"
+    demo.description = """
+    This demo uses the GLM-4V-9B model to get image information.
+    """
+    with gr.Row():
+        with gr.Column():
+            image_path_input = gr.File(label="Upload Image (High-Priority)", type="filepath")
+            image_url_input = gr.Textbox(label="Image URL (Low-Priority)")
+            assistant_prompt_input = gr.Textbox(label="Assistant Prompt (You Can Change It)", value="这是什么?")  # "What is this?"
+            submit_button = gr.Button("Submit")
+        with gr.Column():
+            chatbot_output = gr.Textbox(label="GLM-4V-9B Model Response")
+            image_output = gr.Image(label="Image Preview")
+
+    submit_button.click(chatbot,
+                        inputs=[image_path_input, image_url_input, assistant_prompt_input],
+                        outputs=[image_output, chatbot_output])
+
+demo.launch(server_name="127.0.0.1", server_port=8911, inbrowser=True, share=False)
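Both Gradio demos start `model.generate` in a background `Thread` and read text from a `TextIteratorStreamer` on the calling thread. The sketch below illustrates that producer/consumer shape using only the standard library. `QueueStreamer`, `fake_generate`, and `stream_response` are hypothetical stand-ins introduced here, not Transformers APIs, and no model is involved.

```python
import queue
import threading
from typing import Iterator, List

_SENTINEL = object()  # marks the end of the stream


class QueueStreamer:
    """Hypothetical stand-in for a text streamer: the producer puts chunks, the consumer iterates."""

    def __init__(self, timeout: float = 5.0) -> None:
        self._queue = queue.Queue()
        self._timeout = timeout

    def put(self, chunk: str) -> None:
        self._queue.put(chunk)

    def end(self) -> None:
        self._queue.put(_SENTINEL)

    def __iter__(self) -> Iterator[str]:
        while True:
            # Raises queue.Empty if the producer stalls longer than the timeout.
            item = self._queue.get(timeout=self._timeout)
            if item is _SENTINEL:
                return
            yield item


def fake_generate(chunks: List[str], streamer: QueueStreamer) -> None:
    """Plays the role of model.generate(): emits chunks, then always signals completion."""
    try:
        for chunk in chunks:
            streamer.put(chunk)
    finally:
        streamer.end()


def stream_response(chunks: List[str]) -> str:
    streamer = QueueStreamer()
    worker = threading.Thread(target=fake_generate, kwargs={"chunks": chunks, "streamer": streamer})
    worker.start()
    response = ""
    for new_text in streamer:  # consume while the worker is still producing
        if new_text:
            response += new_text
    worker.join()
    return response.strip()


print(stream_response(["A ", "small ", "cat. "]))
```

The `finally` in the producer is the design point worth noting: if generation fails partway, the consumer still sees the end marker instead of waiting for the timeout.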
vllm_cli_demo.py
ADDED
@@ -0,0 +1,111 @@
+"""
+This script creates a CLI demo with a vLLM backend for the glm-4-9b model,
+allowing users to interact with the model through a command-line interface.
+
+Usage:
+- Run the script to start the CLI demo.
+- Interact with the model by typing questions and receiving responses.
+
+Note: The script includes a modification to handle markdown to plain text conversion,
+ensuring that the CLI interface displays formatted text correctly.
+"""
+import time
+import asyncio
+from transformers import AutoTokenizer
+from vllm import SamplingParams, AsyncEngineArgs, AsyncLLMEngine
+from typing import List, Dict
+
+MODEL_PATH = 'THUDM/glm-4-9b'
+
+
+def load_model_and_tokenizer(model_dir: str):
+    engine_args = AsyncEngineArgs(
+        model=model_dir,
+        tokenizer=model_dir,
+        tensor_parallel_size=1,
+        dtype="bfloat16",
+        trust_remote_code=True,
+        gpu_memory_utilization=0.3,
+        enforce_eager=True,
+        worker_use_ray=True,
+        engine_use_ray=False,
+        disable_log_requests=True,
+        # If you run into out-of-memory (OOM) errors, consider enabling the options below.
+        # enable_chunked_prefill=True,
+        # max_num_batched_tokens=8192
+    )
+    tokenizer = AutoTokenizer.from_pretrained(
+        model_dir,
+        trust_remote_code=True,
+        encode_special_tokens=True
+    )
+    engine = AsyncLLMEngine.from_engine_args(engine_args)
+    return engine, tokenizer
+
+
+engine, tokenizer = load_model_and_tokenizer(MODEL_PATH)
+
+
+async def vllm_gen(messages: List[Dict[str, str]], top_p: float, temperature: float, max_dec_len: int):
+    inputs = tokenizer.apply_chat_template(
+        messages,
+        add_generation_prompt=True,
+        tokenize=False
+    )
+    params_dict = {
+        "n": 1,
+        "best_of": 1,
+        "presence_penalty": 1.0,
+        "frequency_penalty": 0.0,
+        "temperature": temperature,
+        "top_p": top_p,
+        "top_k": -1,
+        "use_beam_search": False,
+        "length_penalty": 1,
+        "early_stopping": False,
+        "stop_token_ids": [151329, 151336, 151338],
+        "ignore_eos": False,
+        "max_tokens": max_dec_len,
+        "logprobs": None,
+        "prompt_logprobs": None,
+        "skip_special_tokens": True,
+    }
+    sampling_params = SamplingParams(**params_dict)
+    async for output in engine.generate(inputs=inputs, sampling_params=sampling_params, request_id=f"{time.time()}"):
+        yield output.outputs[0].text
+
+
+async def chat():
+    history = []
+    max_length = 8192
+    top_p = 0.8
+    temperature = 0.6
+
+    print("Welcome to the GLM-4-9B CLI chat. Type your messages below.")
+    while True:
+        user_input = input("\nYou: ")
+        if user_input.lower() in ["exit", "quit"]:
+            break
+        history.append([user_input, ""])
+
+        messages = []
+        for idx, (user_msg, model_msg) in enumerate(history):
+            if idx == len(history) - 1 and not model_msg:
+                messages.append({"role": "user", "content": user_msg})
+                break
+            if user_msg:
+                messages.append({"role": "user", "content": user_msg})
+            if model_msg:
+                messages.append({"role": "assistant", "content": model_msg})
+
+        print("\nGLM-4: ", end="")
+        current_length = 0
+        output = ""
+        async for output in vllm_gen(messages, top_p, temperature, max_length):
+            print(output[current_length:], end="", flush=True)
+            current_length = len(output)
+        history[-1][1] = output
+
+
+if __name__ == "__main__":
+    asyncio.run(chat())
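In `chat()` (vllm_cli_demo.py), each value yielded by `vllm_gen` is treated as the full text generated so far, which is why the loop prints only `output[current_length:]`. A minimal sketch of that delta-printing idea follows. It assumes cumulative snapshots, and `fake_cumulative_stream` and `collect_deltas` are hypothetical names that stand in for the engine, so nothing here calls vLLM.

```python
import asyncio
from typing import AsyncIterator, List


async def fake_cumulative_stream(pieces: List[str]) -> AsyncIterator[str]:
    """Hypothetical stand-in for vllm_gen(): yields the full text so far on every step."""
    text = ""
    for piece in pieces:
        text += piece
        await asyncio.sleep(0)  # hand control back to the event loop between steps
        yield text


async def collect_deltas(pieces: List[str]) -> List[str]:
    """Return only the newly generated suffix of each cumulative snapshot."""
    deltas: List[str] = []
    printed = 0  # number of characters already shown to the user
    async for snapshot in fake_cumulative_stream(pieces):
        deltas.append(snapshot[printed:])
        printed = len(snapshot)
    return deltas


deltas = asyncio.run(collect_deltas(["Hel", "lo", ", world"]))
print(deltas)  # ['Hel', 'lo', ', world']
```

Tracking a character offset keeps the terminal output append-only. If a stream yielded incremental chunks instead of snapshots, the slice would drop text, so the assumption about the stream's shape is worth checking against the engine version in use.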