jackbond2024 commited on
Commit
03f3fb4
1 Parent(s): 0f162b7

Upload folder using huggingface_hub

Browse files
README.md CHANGED
@@ -1,12 +1,144 @@
1
  ---
2
- title: Glm4
3
- emoji: 🌖
4
- colorFrom: green
5
- colorTo: pink
6
  sdk: gradio
7
  sdk_version: 4.37.2
8
- app_file: app.py
9
- pinned: false
10
  ---
11
 
12
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
1
  ---
2
+ title: glm4
3
+ app_file: trans_web_demo.py
 
 
4
  sdk: gradio
5
  sdk_version: 4.37.2
 
 
6
  ---
7
+ # Basic Demo
8
+
9
+ Read this in [English](README_en.md).
10
+
11
+ 本 demo 中,你将体验到如何使用 GLM-4-9B 开源模型进行基本的任务。
12
+
13
+ 请严格按照文档的步骤进行操作,以避免不必要的错误。
14
+
15
+ ## 设备和依赖检查
16
+
17
+ ### 相关推理测试数据
18
+
19
+ **本文档的数据均在以下硬件环境测试,实际运行环境需求和运行占用的显存略有不同,请以实际运行环境为准。**
20
+
21
+ 测试硬件信息:
22
+
23
+ + OS: Ubuntu 22.04
24
+ + Memory: 512GB
25
+ + Python: 3.10.12 (推荐) / 3.12.3 均已测试
26
+ + CUDA Version: 12.3
27
+ + GPU Driver: 535.104.05
28
+ + GPU: NVIDIA A100-SXM4-80GB * 8
29
+
30
+ 相关推理的压力测试数据如下:
31
+
32
+ **所有测试均在单张GPU上进行测试,所有显存消耗都按照峰值左右进行测算**
33
+
34
+ #### GLM-4-9B-Chat
35
+
36
+ | 精度 | 显存占用 | Prefilling | Decode Speed | Remarks |
37
+ |------|-------|------------|---------------|--------------|
38
+ | BF16 | 19 GB | 0.2s | 27.8 tokens/s | 输入长度为 1000 |
39
+ | BF16 | 21 GB | 0.8s | 31.8 tokens/s | 输入长度为 8000 |
40
+ | BF16 | 28 GB | 4.3s | 14.4 tokens/s | 输入长度为 32000 |
41
+ | BF16 | 58 GB | 38.1s | 3.4 tokens/s | 输入长度为 128000 |
42
+
43
+ | 精度 | 显存占用 | Prefilling | Decode Speed | Remarks |
44
+ |------|-------|------------|---------------|-------------|
45
+ | INT4 | 8 GB | 0.2s | 23.3 tokens/s | 输入长度为 1000 |
46
+ | INT4 | 10 GB | 0.8s | 23.4 tokens/s | 输入长度为 8000 |
47
+ | INT4 | 17 GB | 4.3s | 14.6 tokens/s | 输入长度为 32000 |
48
+
49
+ #### GLM-4-9B-Chat-1M
50
+
51
+ | 精度 | 显存占用 | Prefilling | Decode Speed | Remarks |
52
+ |------|-------|------------|--------------|--------------|
53
+ | BF16 | 75 GB | 98.4s | 2.3 tokens/s | 输入长度为 200000 |
54
+
55
+ 如果您的输入超过200K,我们建议您使用vLLM后端进行多卡推理,以获得更好的性能。
56
+
57
+ #### GLM-4V-9B
58
+
59
+ | 精度 | 显存占用 | Prefilling | Decode Speed | Remarks |
60
+ |------|-------|------------|---------------|------------|
61
+ | BF16 | 28 GB | 0.1s | 33.4 tokens/s | 输入长度为 1000 |
62
+ | BF16 | 33 GB | 0.7s | 39.2 tokens/s | 输入长度为 8000 |
63
+
64
+ | 精度 | 显存占用 | Prefilling | Decode Speed | Remarks |
65
+ |------|-------|------------|---------------|------------|
66
+ | INT4 | 10 GB | 0.1s | 28.7 tokens/s | 输入长度为 1000 |
67
+ | INT4 | 15 GB | 0.8s | 24.2 tokens/s | 输入长度为 8000 |
68
+
69
+ ### 最低硬件要求
70
+
71
+ 如果您希望运行官方提供的最基础代码 (transformers 后端) 您需要:
72
+
73
+ + Python >= 3.10
74
+ + 内存不少于 32 GB
75
+
76
+ 如果您希望运行官方提供的本文件夹的所有代码,您还需要:
77
+
78
+ + Linux 操作系统 (Debian 系列最佳)
79
+ + 大于 8GB 显存的,支持 CUDA 或者 ROCM 并且支持 `BF16` 推理的 GPU 设备。(`FP16` 精度无法训练,推理有小概率出现问题)
80
+
81
+ 安装依赖
82
+
83
+ ```shell
84
+ pip install -r requirements.txt
85
+ ```
86
+
87
+ ## 基础功能调用
88
+
89
+ **除非特殊说明,本文件夹所有 demo 并不支持 Function Call 和 All Tools 等进阶用法**
90
+
91
+ ### 使用 transformers 后端代码
92
+
93
+ + 使用命令行与 GLM-4-9B 模型进行对话。
94
+
95
+ ```shell
96
+ python trans_cli_demo.py # GLM-4-9B-Chat
97
+ python trans_cli_vision_demo.py # GLM-4V-9B
98
+ ```
99
+
100
+ + 使用 Gradio 网页端与 GLM-4-9B 模型进行对话。
101
+
102
+ ```shell
103
+ python trans_web_demo.py # GLM-4-9B-Chat
104
+ python trans_web_vision_demo.py # GLM-4V-9B
105
+ ```
106
+
107
+ + 使用 Batch 推理。
108
+
109
+ ```shell
110
+ python trans_batch_demo.py
111
+ ```
112
+
113
+ ### 使用 vLLM 后端代码
114
+
115
+ + 使用命令行与 GLM-4-9B-Chat 模型进行对话。
116
+
117
+ ```shell
118
+ python vllm_cli_demo.py
119
+ ```
120
+
121
+ + 自行构建服务端,并使用 `OpenAI API` 的请求格式与 GLM-4-9B-Chat 模型进行对话。本 demo 支持 Function Call 和 All Tools 功能。
122
+
123
+ 启动服务端:
124
+
125
+ ```shell
126
+ python openai_api_server.py
127
+ ```
128
+
129
+ 客户端请求:
130
+
131
+ ```shell
132
+ python openai_api_request.py
133
+ ```
134
+
135
+ ## 压力测试
136
+
137
+ 用户可以在自己的设备上使用本代码测试模型在 transformers 后端的生成速度:
138
+
139
+ ```shell
140
+ python trans_stress_test.py
141
+ ```
142
+
143
+
144
 
 
README_en.md ADDED
@@ -0,0 +1,141 @@
1
+ # Basic Demo
2
+
3
+ In this demo, you will experience how to use the GLM-4-9B open source model to perform basic tasks.
4
+
5
+ Please follow the steps in the document strictly to avoid unnecessary errors.
6
+
7
+ ## Device and dependency check
8
+
9
+ ### Related inference test data
10
+
11
+ **The data in this document were measured in the hardware environment below. Actual requirements and GPU memory usage will vary slightly; please use your actual running environment as the reference.**
14
+
15
+ Test hardware information:
16
+
17
+ + OS: Ubuntu 22.04
18
+ + Memory: 512GB
19
+ + Python: 3.10.12 (recommended) / 3.12.3, both tested
20
+ + CUDA Version: 12.3
21
+ + GPU Driver: 535.104.05
22
+ + GPU: NVIDIA A100-SXM4-80GB * 8
23
+
24
+ The inference stress-test data are as follows:
25
+
26
+ **All tests are performed on a single GPU, and all GPU memory figures are approximate peak consumption**
27
+
28
+ #### GLM-4-9B-Chat
31
+
32
+ | Dtype | GPU Memory | Prefilling | Decode Speed | Remarks |
33
+ |-------|------------|------------|---------------|------------------------|
34
+ | BF16 | 19 GB | 0.2s | 27.8 tokens/s | Input length is 1000 |
35
+ | BF16 | 21 GB | 0.8s | 31.8 tokens/s | Input length is 8000 |
36
+ | BF16 | 28 GB | 4.3s | 14.4 tokens/s | Input length is 32000 |
37
+ | BF16 | 58 GB | 38.1s | 3.4 tokens/s | Input length is 128000 |
38
+
39
+ | Dtype | GPU Memory | Prefilling | Decode Speed | Remarks |
40
+ |-------|------------|------------|---------------|-----------------------|
41
+ | INT4 | 8 GB | 0.2s | 23.3 tokens/s | Input length is 1000 |
42
+ | INT4 | 10 GB | 0.8s | 23.4 tokens/s | Input length is 8000 |
43
+ | INT4 | 17 GB | 4.3s | 14.6 tokens/s | Input length is 32000 |
44
+
45
+ #### GLM-4-9B-Chat-1M
46
+
47
+ | Dtype | GPU Memory | Prefilling | Decode Speed | Remarks |
48
+ |-------|------------|------------|------------------|------------------------|
49
+ | BF16 | 74497MiB | 98.4s | 2.3653 tokens/s | Input length is 200000 |
50
+
51
+ If your input exceeds 200K tokens, we recommend using the vLLM backend with multiple GPUs for inference to get better performance.
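+
+ For multi-GPU inference with vLLM, the main knobs are `tensor_parallel_size` and `max_model_len` in `openai_api_server.py`. A rough sketch follows; the values are illustrative and the 1M model id is an assumption for long-context use, so adjust both to your hardware:
+
+ ```python
+ from vllm import AsyncEngineArgs, AsyncLLMEngine
+
+ # Illustrative values only; tune them to your GPUs and the context length you need.
+ engine_args = AsyncEngineArgs(
+     model="THUDM/glm-4-9b-chat-1m",  # assumed long-context model id
+     tensor_parallel_size=4,          # shard the model across 4 GPUs
+     max_model_len=262144,            # long inputs need enough KV-cache memory
+     trust_remote_code=True,
+     enforce_eager=True,
+ )
+ engine = AsyncLLMEngine.from_engine_args(engine_args)
+ ```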
53
+
54
+ #### GLM-4V-9B
55
+
56
+ | Dtype | GPU Memory | Prefilling | Decode Speed | Remarks |
57
+ |-------|------------|------------|---------------|----------------------|
58
+ | BF16 | 28 GB | 0.1s | 33.4 tokens/s | Input length is 1000 |
59
+ | BF16 | 33 GB | 0.7s | 39.2 tokens/s | Input length is 8000 |
60
+
61
+ | Dtype | GPU Memory | Prefilling | Decode Speed | Remarks |
62
+ |-------|------------|------------|---------------|----------------------|
63
+ | INT4 | 10 GB | 0.1s | 28.7 tokens/s | Input length is 1000 |
64
+ | INT4 | 15 GB | 0.8s | 24.2 tokens/s | Input length is 8000 |
65
+
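+ The INT4 rows above correspond to 4-bit quantized loading. The scripts in this folder include commented-out bitsandbytes examples for this (and `bitsandbytes` is listed in requirements.txt); a minimal sketch based on the commented INT4 path in trans_stress_test.py is:
+
+ ```python
+ from transformers import AutoModelForCausalLM, BitsAndBytesConfig
+
+ # 4-bit loading sketch; mirrors the commented-out INT4 code in trans_stress_test.py.
+ model = AutoModelForCausalLM.from_pretrained(
+     "THUDM/glm-4-9b-chat",
+     trust_remote_code=True,
+     quantization_config=BitsAndBytesConfig(load_in_4bit=True),
+     low_cpu_mem_usage=True,
+ ).eval()
+ ```
+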
66
+ ### Minimum hardware requirements
67
+
68
+ If you want to run the most basic officially provided code (transformers backend), you need:
69
+
70
+ + Python >= 3.10
71
+ + Memory of at least 32 GB
72
+
73
+ If you want to run all of the officially provided code in this folder, you also need:
74
+
75
+ + Linux operating system (Debian series is best)
76
+ + A GPU with more than 8 GB of memory that supports CUDA or ROCm and `BF16` inference (`FP16` precision cannot be used for fine-tuning, and there is a small chance of issues during inference)
78
+
79
+ Install dependencies
80
+
81
+ ```shell
82
+ pip install -r requirements.txt
83
+ ```
84
+
85
+ ## Basic function calls
86
+
87
+ **Unless otherwise specified, the demos in this folder do not support advanced usage such as Function Call and All Tools.**
89
+
90
+ ### Use transformers backend code
91
+
92
+ + Use the command line to communicate with the GLM-4-9B model.
93
+
94
+ ```shell
95
+ python trans_cli_demo.py # GLM-4-9B-Chat
96
+ python trans_cli_vision_demo.py # GLM-4V-9B
97
+ ```
98
+
99
+ + Use the Gradio web client to communicate with the GLM-4-9B model.
100
+
101
+ ```shell
102
+ python trans_web_demo.py # GLM-4-9B-Chat
103
+ python trans_web_vision_demo.py # GLM-4V-9B
104
+ ```
105
+
106
+ + Use Batch inference.
107
+
108
+ ```shell
109
+ python trans_batch_demo.py
110
+ ```
111
+
112
+ ### Use vLLM backend code
113
+
114
+ + Use the command line to communicate with the GLM-4-9B-Chat model.
115
+
116
+ ```shell
117
+ python vllm_cli_demo.py
118
+ ```
119
+
120
+ + Build the server yourself and use the `OpenAI API` request format to communicate with the GLM-4-9B-Chat model. This demo supports Function Call and All Tools.
122
+
123
+ Start the server:
124
+
125
+ ```shell
126
+ python openai_api_server.py
127
+ ```
128
+
129
+ Client request:
130
+
131
+ ```shell
132
+ python openai_api_request.py
133
+ ```
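+
+ For reference, a minimal Python client (mirroring openai_api_request.py and assuming the server is running on its default port 8000) looks like this:
+
+ ```python
+ from openai import OpenAI
+
+ # Point the OpenAI client at the local server started by openai_api_server.py.
+ client = OpenAI(api_key="EMPTY", base_url="http://127.0.0.1:8000/v1/")
+ response = client.chat.completions.create(
+     model="glm-4",
+     messages=[{"role": "user", "content": "Hello, who are you?"}],
+     max_tokens=256,
+ )
+ print(response.choices[0].message.content)
+ ```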
134
+
135
+ ## Stress test
136
+
137
+ Use this script to test the model's generation speed with the transformers backend on your own device:
138
+
139
+ ```shell
140
+ python trans_stress_test.py
141
+ ```
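+
+ The script also accepts `--token_len`, `--n` and `--num_gpu` arguments (see the argparse section of trans_stress_test.py) to control the input length, the number of iterations and which GPU is used.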
openai_api_request.py ADDED
@@ -0,0 +1,127 @@
1
+ """
2
+ This script creates an OpenAI-style request demo for the glm-4-9b model; it simply uses the OpenAI API client to interact with the model.
3
+ """
4
+
5
+ from openai import OpenAI
6
+
7
+ base_url = "http://127.0.0.1:8000/v1/"
8
+ client = OpenAI(api_key="EMPTY", base_url=base_url)
9
+
10
+
11
+ def function_chat(use_stream=False):
12
+ messages = [
13
+ {
14
+ "role": "user", "content": "What's the Celsius temperature in San Francisco?"
15
+ },
16
+
17
+ # Give Observations
18
+ # {
19
+ # "role": "assistant",
20
+ # "content": None,
21
+ # "function_call": None,
22
+ # "tool_calls": [
23
+ # {
24
+ # "id": "call_1717912616815",
25
+ # "function": {
26
+ # "name": "get_current_weather",
27
+ # "arguments": "{\"location\": \"San Francisco, CA\", \"format\": \"celsius\"}"
28
+ # },
29
+ # "type": "function"
30
+ # }
31
+ # ]
32
+ # },
33
+ # {
34
+ # "tool_call_id": "call_1717912616815",
35
+ # "role": "tool",
36
+ # "name": "get_current_weather",
37
+ # "content": "23°C",
38
+ # }
39
+ ]
40
+ tools = [
41
+ {
42
+ "type": "function",
43
+ "function": {
44
+ "name": "get_current_weather",
45
+ "description": "Get the current weather",
46
+ "parameters": {
47
+ "type": "object",
48
+ "properties": {
49
+ "location": {
50
+ "type": "string",
51
+ "description": "The city and state, e.g. San Francisco, CA",
52
+ },
53
+ "format": {
54
+ "type": "string",
55
+ "enum": ["celsius", "fahrenheit"],
56
+ "description": "The temperature unit to use. Infer this from the users location.",
57
+ },
58
+ },
59
+ "required": ["location", "format"],
60
+ },
61
+ }
62
+ },
63
+ ]
64
+
65
+ # All Tools: CogView
66
+ # messages = [{"role": "user", "content": "帮我画一张天空的画画吧"}]
67
+ # tools = [{"type": "cogview"}]
68
+
69
+ # All Tools: Searching
70
+ # messages = [{"role": "user", "content": "今天黄金的价格"}]
71
+ # tools = [{"type": "simple_browser"}]
72
+
73
+ response = client.chat.completions.create(
74
+ model="glm-4",
75
+ messages=messages,
76
+ tools=tools,
77
+ stream=use_stream,
78
+ max_tokens=256,
79
+ temperature=0.9,
80
+ presence_penalty=1.2,
81
+ top_p=0.1,
82
+ tool_choice="auto"
83
+ )
84
+ if response:
85
+ if use_stream:
86
+ for chunk in response:
87
+ print(chunk)
88
+ else:
89
+ print(response)
90
+ else:
91
+ print("Error: empty response")
92
+
93
+
94
+ def simple_chat(use_stream=False):
95
+ messages = [
96
+ {
97
+ "role": "system",
98
+ "content": "请在你输出的时候都带上“喵喵喵”三个字,放在开头。",
99
+ },
100
+ {
101
+ "role": "user",
102
+ "content": "你是谁"
103
+ }
104
+ ]
105
+ response = client.chat.completions.create(
106
+ model="glm-4",
107
+ messages=messages,
108
+ stream=use_stream,
109
+ max_tokens=256,
110
+ temperature=0.4,
111
+ presence_penalty=1.2,
112
+ top_p=0.8,
113
+ )
114
+ if response:
115
+ if use_stream:
116
+ for chunk in response:
117
+ print(chunk)
118
+ else:
119
+ print(response)
120
+ else:
121
+ print("Error: empty response")
122
+
123
+
124
+ if __name__ == "__main__":
125
+ # simple_chat(use_stream=False)
126
+ function_chat(use_stream=False)
127
+
openai_api_server.py ADDED
@@ -0,0 +1,635 @@
1
+ import time
2
+ from asyncio.log import logger
3
+ import re
4
+ import uvicorn
5
+ import gc
6
+ import json
7
+ import torch
8
+ import random
9
+ import string
10
+
11
+ from vllm import SamplingParams, AsyncEngineArgs, AsyncLLMEngine
12
+ from fastapi import FastAPI, HTTPException, Response
13
+ from fastapi.middleware.cors import CORSMiddleware
14
+ from contextlib import asynccontextmanager
15
+ from typing import List, Literal, Optional, Union
16
+ from pydantic import BaseModel, Field
17
+ from transformers import AutoTokenizer, LogitsProcessor
18
+ from sse_starlette.sse import EventSourceResponse
19
+
20
+ EventSourceResponse.DEFAULT_PING_INTERVAL = 1000
21
+ import os
22
+
23
+ MODEL_PATH = os.environ.get('MODEL_PATH', 'THUDM/glm-4-9b-chat')
24
+ MAX_MODEL_LENGTH = 8192
25
+
26
+
27
+ @asynccontextmanager
28
+ async def lifespan(app: FastAPI):
29
+ yield
30
+ if torch.cuda.is_available():
31
+ torch.cuda.empty_cache()
32
+ torch.cuda.ipc_collect()
33
+
34
+
35
+ app = FastAPI(lifespan=lifespan)
36
+
37
+ app.add_middleware(
38
+ CORSMiddleware,
39
+ allow_origins=["*"],
40
+ allow_credentials=True,
41
+ allow_methods=["*"],
42
+ allow_headers=["*"],
43
+ )
44
+
45
+
46
+ def generate_id(prefix: str, k=29) -> str:
47
+ suffix = ''.join(random.choices(string.ascii_letters + string.digits, k=k))
48
+ return f"{prefix}{suffix}"
49
+
50
+
51
+ class ModelCard(BaseModel):
52
+ id: str = ""
53
+ object: str = "model"
54
+ created: int = Field(default_factory=lambda: int(time.time()))
55
+ owned_by: str = "owner"
56
+ root: Optional[str] = None
57
+ parent: Optional[str] = None
58
+ permission: Optional[list] = None
59
+
60
+
61
+ class ModelList(BaseModel):
62
+ object: str = "list"
63
+ data: List[ModelCard] = ["glm-4"]
64
+
65
+
66
+ class FunctionCall(BaseModel):
67
+ name: Optional[str] = None
68
+ arguments: Optional[str] = None
69
+
70
+
71
+ class ChoiceDeltaToolCallFunction(BaseModel):
72
+ name: Optional[str] = None
73
+ arguments: Optional[str] = None
74
+
75
+
76
+ class UsageInfo(BaseModel):
77
+ prompt_tokens: int = 0
78
+ total_tokens: int = 0
79
+ completion_tokens: Optional[int] = 0
80
+
81
+
82
+ class ChatCompletionMessageToolCall(BaseModel):
83
+ index: Optional[int] = 0
84
+ id: Optional[str] = None
85
+ function: FunctionCall
86
+ type: Optional[Literal["function"]] = 'function'
87
+
88
+
89
+ class ChatMessage(BaseModel):
90
+ # Note on the "function" field:
+ # If you use an older version of the OpenAI API, add a "function" role here and add the corresponding
+ # role-conversion logic (mapping it to "observation") in the process_messages function.
92
+
93
+ role: Literal["user", "assistant", "system", "tool"]
94
+ content: Optional[str] = None
95
+ function_call: Optional[ChoiceDeltaToolCallFunction] = None
96
+ tool_calls: Optional[List[ChatCompletionMessageToolCall]] = None
97
+
98
+
99
+ class DeltaMessage(BaseModel):
100
+ role: Optional[Literal["user", "assistant", "system"]] = None
101
+ content: Optional[str] = None
102
+ function_call: Optional[ChoiceDeltaToolCallFunction] = None
103
+ tool_calls: Optional[List[ChatCompletionMessageToolCall]] = None
104
+
105
+
106
+ class ChatCompletionResponseChoice(BaseModel):
107
+ index: int
108
+ message: ChatMessage
109
+ finish_reason: Literal["stop", "length", "tool_calls"]
110
+
111
+
112
+ class ChatCompletionResponseStreamChoice(BaseModel):
113
+ delta: DeltaMessage
114
+ finish_reason: Optional[Literal["stop", "length", "tool_calls"]]
115
+ index: int
116
+
117
+
118
+ class ChatCompletionResponse(BaseModel):
119
+ model: str
120
+ id: Optional[str] = Field(default_factory=lambda: generate_id('chatcmpl-', 29))
121
+ object: Literal["chat.completion", "chat.completion.chunk"]
122
+ choices: List[Union[ChatCompletionResponseChoice, ChatCompletionResponseStreamChoice]]
123
+ created: Optional[int] = Field(default_factory=lambda: int(time.time()))
124
+ system_fingerprint: Optional[str] = Field(default_factory=lambda: generate_id('fp_', 9))
125
+ usage: Optional[UsageInfo] = None
126
+
127
+
128
+ class ChatCompletionRequest(BaseModel):
129
+ model: str
130
+ messages: List[ChatMessage]
131
+ temperature: Optional[float] = 0.8
132
+ top_p: Optional[float] = 0.8
133
+ max_tokens: Optional[int] = None
134
+ stream: Optional[bool] = False
135
+ tools: Optional[Union[dict, List[dict]]] = None
136
+ tool_choice: Optional[Union[str, dict]] = None
137
+ repetition_penalty: Optional[float] = 1.1
138
+
139
+
140
+ class InvalidScoreLogitsProcessor(LogitsProcessor):
141
+ def __call__(
142
+ self, input_ids: torch.LongTensor, scores: torch.FloatTensor
143
+ ) -> torch.FloatTensor:
144
+ if torch.isnan(scores).any() or torch.isinf(scores).any():
145
+ scores.zero_()
146
+ scores[..., 5] = 5e4
147
+ return scores
148
+
149
+
150
+ def process_response(output: str, tools: dict | List[dict] = None, use_tool: bool = False) -> Union[str, dict]:
151
+ lines = output.strip().split("\n")
152
+ arguments_json = None
153
+ special_tools = ["cogview", "simple_browser"]
154
+ tools = {tool['function']['name'] for tool in tools} if tools else {}
155
+
156
+ # This is a simple tool-name comparison; it cannot intercept every non-tool output (e.g. misaligned arguments or other edge cases).
+ # TODO: refine the logic here if you want stricter checks.
158
+
159
+ if len(lines) >= 2 and lines[1].startswith("{"):
160
+ function_name = lines[0].strip()
161
+ arguments = "\n".join(lines[1:]).strip()
162
+ if function_name in tools or function_name in special_tools:
163
+ try:
164
+ arguments_json = json.loads(arguments)
165
+ is_tool_call = True
166
+ except json.JSONDecodeError:
167
+ is_tool_call = function_name in special_tools
168
+
169
+ if is_tool_call and use_tool:
170
+ content = {
171
+ "name": function_name,
172
+ "arguments": json.dumps(arguments_json if isinstance(arguments_json, dict) else arguments,
173
+ ensure_ascii=False)
174
+ }
175
+ if function_name == "simple_browser":
176
+ search_pattern = re.compile(r'search\("(.+?)"\s*,\s*recency_days\s*=\s*(\d+)\)')
177
+ match = search_pattern.match(arguments)
178
+ if match:
179
+ content["arguments"] = json.dumps({
180
+ "query": match.group(1),
181
+ "recency_days": int(match.group(2))
182
+ }, ensure_ascii=False)
183
+ elif function_name == "cogview":
184
+ content["arguments"] = json.dumps({
185
+ "prompt": arguments
186
+ }, ensure_ascii=False)
187
+
188
+ return content
189
+ return output.strip()
190
+
191
+
192
+ @torch.inference_mode()
193
+ async def generate_stream_glm4(params):
194
+ messages = params["messages"]
195
+ tools = params["tools"]
196
+ tool_choice = params["tool_choice"]
197
+ temperature = float(params.get("temperature", 1.0))
198
+ repetition_penalty = float(params.get("repetition_penalty", 1.0))
199
+ top_p = float(params.get("top_p", 1.0))
200
+ max_new_tokens = int(params.get("max_tokens", 8192))
201
+
202
+ messages = process_messages(messages, tools=tools, tool_choice=tool_choice)
203
+ inputs = tokenizer.apply_chat_template(messages, add_generation_prompt=True, tokenize=False)
204
+ params_dict = {
205
+ "n": 1,
206
+ "best_of": 1,
207
+ "presence_penalty": 1.0,
208
+ "frequency_penalty": 0.0,
209
+ "temperature": temperature,
210
+ "top_p": top_p,
211
+ "top_k": -1,
212
+ "repetition_penalty": repetition_penalty,
213
+ "use_beam_search": False,
214
+ "length_penalty": 1,
215
+ "early_stopping": False,
216
+ "stop_token_ids": [151329, 151336, 151338],
217
+ "ignore_eos": False,
218
+ "max_tokens": max_new_tokens,
219
+ "logprobs": None,
220
+ "prompt_logprobs": None,
221
+ "skip_special_tokens": True,
222
+ }
223
+ sampling_params = SamplingParams(**params_dict)
224
+ async for output in engine.generate(inputs=inputs, sampling_params=sampling_params, request_id=f"{time.time()}"):
225
+ output_len = len(output.outputs[0].token_ids)
226
+ input_len = len(output.prompt_token_ids)
227
+ ret = {
228
+ "text": output.outputs[0].text,
229
+ "usage": {
230
+ "prompt_tokens": input_len,
231
+ "completion_tokens": output_len,
232
+ "total_tokens": output_len + input_len
233
+ },
234
+ "finish_reason": output.outputs[0].finish_reason,
235
+ }
236
+ yield ret
237
+ gc.collect()
238
+ torch.cuda.empty_cache()
239
+
240
+
241
+ def process_messages(messages, tools=None, tool_choice="none"):
242
+ _messages = messages
243
+ processed_messages = []
244
+ msg_has_sys = False
245
+
246
+ def filter_tools(tool_choice, tools):
247
+ function_name = tool_choice.get('function', {}).get('name', None)
248
+ if not function_name:
249
+ return []
250
+ filtered_tools = [
251
+ tool for tool in tools
252
+ if tool.get('function', {}).get('name') == function_name
253
+ ]
254
+ return filtered_tools
255
+
256
+ if tool_choice != "none":
257
+ if isinstance(tool_choice, dict):
258
+ tools = filter_tools(tool_choice, tools)
259
+ if tools:
260
+ processed_messages.append(
261
+ {
262
+ "role": "system",
263
+ "content": None,
264
+ "tools": tools
265
+ }
266
+ )
267
+ msg_has_sys = True
268
+
269
+ if isinstance(tool_choice, dict) and tools:
270
+ processed_messages.append(
271
+ {
272
+ "role": "assistant",
273
+ "metadata": tool_choice["function"]["name"],
274
+ "content": ""
275
+ }
276
+ )
277
+
278
+ for m in _messages:
279
+ role, content, func_call = m.role, m.content, m.function_call
280
+ tool_calls = getattr(m, 'tool_calls', None)
281
+
282
+ if role == "function":
283
+ processed_messages.append(
284
+ {
285
+ "role": "observation",
286
+ "content": content
287
+ }
288
+ )
289
+ elif role == "tool":
290
+ processed_messages.append(
291
+ {
292
+ "role": "observation",
293
+ "content": content,
294
+ "function_call": True
295
+ }
296
+ )
297
+ elif role == "assistant":
298
+ if tool_calls:
299
+ for tool_call in tool_calls:
300
+ processed_messages.append(
301
+ {
302
+ "role": "assistant",
303
+ "metadata": tool_call.function.name,
304
+ "content": tool_call.function.arguments
305
+ }
306
+ )
307
+ else:
308
+ for response in content.split("\n"):
309
+ if "\n" in response:
310
+ metadata, sub_content = response.split("\n", maxsplit=1)
311
+ else:
312
+ metadata, sub_content = "", response
313
+ processed_messages.append(
314
+ {
315
+ "role": role,
316
+ "metadata": metadata,
317
+ "content": sub_content.strip()
318
+ }
319
+ )
320
+ else:
321
+ if role == "system" and msg_has_sys:
322
+ msg_has_sys = False
323
+ continue
324
+ processed_messages.append({"role": role, "content": content})
325
+
326
+ if not tools or tool_choice == "none":
327
+ for m in _messages:
328
+ if m.role == 'system':
329
+ processed_messages.insert(0, {"role": m.role, "content": m.content})
330
+ break
331
+ return processed_messages
332
+
333
+
334
+ @app.get("/health")
335
+ async def health() -> Response:
336
+ """Health check."""
337
+ return Response(status_code=200)
338
+
339
+
340
+ @app.get("/v1/models", response_model=ModelList)
341
+ async def list_models():
342
+ model_card = ModelCard(id="glm-4")
343
+ return ModelList(data=[model_card])
344
+
345
+
346
+ @app.post("/v1/chat/completions", response_model=ChatCompletionResponse)
347
+ async def create_chat_completion(request: ChatCompletionRequest):
348
+ if len(request.messages) < 1 or request.messages[-1].role == "assistant":
349
+ raise HTTPException(status_code=400, detail="Invalid request")
350
+
351
+ gen_params = dict(
352
+ messages=request.messages,
353
+ temperature=request.temperature,
354
+ top_p=request.top_p,
355
+ max_tokens=request.max_tokens or 1024,
356
+ echo=False,
357
+ stream=request.stream,
358
+ repetition_penalty=request.repetition_penalty,
359
+ tools=request.tools,
360
+ tool_choice=request.tool_choice,
361
+ )
362
+ logger.debug(f"==== request ====\n{gen_params}")
363
+
364
+ if request.stream:
365
+ predict_stream_generator = predict_stream(request.model, gen_params)
366
+ output = await anext(predict_stream_generator)
367
+ if output:
368
+ return EventSourceResponse(predict_stream_generator, media_type="text/event-stream")
369
+ logger.debug(f"First result output:\n{output}")
370
+
371
+ function_call = None
372
+ if output and request.tools:
373
+ try:
374
+ function_call = process_response(output, request.tools, use_tool=True)
375
+ except:
376
+ logger.warning("Failed to parse tool call")
377
+
378
+ if isinstance(function_call, dict):
379
+ function_call = ChoiceDeltaToolCallFunction(**function_call)
380
+ generate = parse_output_text(request.model, output, function_call=function_call)
381
+ return EventSourceResponse(generate, media_type="text/event-stream")
382
+ else:
383
+ return EventSourceResponse(predict_stream_generator, media_type="text/event-stream")
384
+ response = ""
385
+ async for response in generate_stream_glm4(gen_params):
386
+ pass
387
+
388
+ if response["text"].startswith("\n"):
389
+ response["text"] = response["text"][1:]
390
+ response["text"] = response["text"].strip()
391
+
392
+ usage = UsageInfo()
393
+
394
+ function_call, finish_reason = None, "stop"
395
+ tool_calls = None
396
+ if request.tools:
397
+ try:
398
+ function_call = process_response(response["text"], request.tools, use_tool=True)
399
+ except Exception as e:
400
+ logger.warning(f"Failed to parse tool call: {e}")
401
+ if isinstance(function_call, dict):
402
+ finish_reason = "tool_calls"
403
+ function_call_response = ChoiceDeltaToolCallFunction(**function_call)
404
+ function_call_instance = FunctionCall(
405
+ name=function_call_response.name,
406
+ arguments=function_call_response.arguments
407
+ )
408
+ tool_calls = [
409
+ ChatCompletionMessageToolCall(
410
+ id=generate_id('call_', 24),
411
+ function=function_call_instance,
412
+ type="function")]
413
+
414
+ message = ChatMessage(
415
+ role="assistant",
416
+ content=None if tool_calls else response["text"],
417
+ function_call=None,
418
+ tool_calls=tool_calls,
419
+ )
420
+
421
+ logger.debug(f"==== message ====\n{message}")
422
+
423
+ choice_data = ChatCompletionResponseChoice(
424
+ index=0,
425
+ message=message,
426
+ finish_reason=finish_reason,
427
+ )
428
+ task_usage = UsageInfo.model_validate(response["usage"])
429
+ for usage_key, usage_value in task_usage.model_dump().items():
430
+ setattr(usage, usage_key, getattr(usage, usage_key) + usage_value)
431
+
432
+ return ChatCompletionResponse(
433
+ model=request.model,
434
+ choices=[choice_data],
435
+ object="chat.completion",
436
+ usage=usage
437
+ )
438
+
439
+
440
+ async def predict_stream(model_id, gen_params):
441
+ output = ""
442
+ is_function_call = False
443
+ has_send_first_chunk = False
444
+ created_time = int(time.time())
445
+ function_name = None
446
+ response_id = generate_id('chatcmpl-', 29)
447
+ system_fingerprint = generate_id('fp_', 9)
448
+ tools = {tool['function']['name'] for tool in gen_params['tools']} if gen_params['tools'] else {}
449
+ async for new_response in generate_stream_glm4(gen_params):
450
+ decoded_unicode = new_response["text"]
451
+ delta_text = decoded_unicode[len(output):]
452
+ output = decoded_unicode
453
+ lines = output.strip().split("\n")
454
+
455
+ # Check whether the output is a tool call.
+ # This is a simple tool-name comparison; it cannot intercept every non-tool output (e.g. misaligned arguments or other edge cases).
+ # TODO: refine the logic here if you want more thorough handling.
458
+
459
+ if not is_function_call and len(lines) >= 2:
460
+ first_line = lines[0].strip()
461
+ if first_line in tools:
462
+ is_function_call = True
463
+ function_name = first_line
464
+
465
+ # Return tool-call chunks
466
+ if is_function_call:
467
+ if not has_send_first_chunk:
468
+ function_call = {"name": function_name, "arguments": ""}
469
+ tool_call = ChatCompletionMessageToolCall(
470
+ index=0,
471
+ id=generate_id('call_', 24),
472
+ function=FunctionCall(**function_call),
473
+ type="function"
474
+ )
475
+ message = DeltaMessage(
476
+ content=None,
477
+ role="assistant",
478
+ function_call=None,
479
+ tool_calls=[tool_call]
480
+ )
481
+ choice_data = ChatCompletionResponseStreamChoice(
482
+ index=0,
483
+ delta=message,
484
+ finish_reason=None
485
+ )
486
+ chunk = ChatCompletionResponse(
487
+ model=model_id,
488
+ id=response_id,
489
+ choices=[choice_data],
490
+ created=created_time,
491
+ system_fingerprint=system_fingerprint,
492
+ object="chat.completion.chunk"
493
+ )
494
+ yield ""
495
+ yield chunk.model_dump_json(exclude_unset=True)
496
+ has_send_first_chunk = True
497
+
498
+ function_call = {"name": None, "arguments": delta_text}
499
+ tool_call = ChatCompletionMessageToolCall(
500
+ index=0,
501
+ id=None,
502
+ function=FunctionCall(**function_call),
503
+ type="function"
504
+ )
505
+ message = DeltaMessage(
506
+ content=None,
507
+ role=None,
508
+ function_call=None,
509
+ tool_calls=[tool_call]
510
+ )
511
+ choice_data = ChatCompletionResponseStreamChoice(
512
+ index=0,
513
+ delta=message,
514
+ finish_reason=None
515
+ )
516
+ chunk = ChatCompletionResponse(
517
+ model=model_id,
518
+ id=response_id,
519
+ choices=[choice_data],
520
+ created=created_time,
521
+ system_fingerprint=system_fingerprint,
522
+ object="chat.completion.chunk"
523
+ )
524
+ yield chunk.model_dump_json(exclude_unset=True)
525
+
526
+ # The user requested Function Call, but the framework has not yet determined whether this output is a Function Call
527
+ elif (gen_params["tools"] and gen_params["tool_choice"] != "none") or is_function_call:
528
+ continue
529
+
530
+ # Regular (non-tool) response
531
+ else:
532
+ finish_reason = new_response.get("finish_reason", None)
533
+ if not has_send_first_chunk:
534
+ message = DeltaMessage(
535
+ content="",
536
+ role="assistant",
537
+ function_call=None,
538
+ )
539
+ choice_data = ChatCompletionResponseStreamChoice(
540
+ index=0,
541
+ delta=message,
542
+ finish_reason=finish_reason
543
+ )
544
+ chunk = ChatCompletionResponse(
545
+ model=model_id,
546
+ id=response_id,
547
+ choices=[choice_data],
548
+ created=created_time,
549
+ system_fingerprint=system_fingerprint,
550
+ object="chat.completion.chunk"
551
+ )
552
+ yield chunk.model_dump_json(exclude_unset=True)
553
+ has_send_first_chunk = True
554
+
555
+ message = DeltaMessage(
556
+ content=delta_text,
557
+ role="assistant",
558
+ function_call=None,
559
+ )
560
+ choice_data = ChatCompletionResponseStreamChoice(
561
+ index=0,
562
+ delta=message,
563
+ finish_reason=finish_reason
564
+ )
565
+ chunk = ChatCompletionResponse(
566
+ model=model_id,
567
+ id=response_id,
568
+ choices=[choice_data],
569
+ created=created_time,
570
+ system_fingerprint=system_fingerprint,
571
+ object="chat.completion.chunk"
572
+ )
573
+ yield chunk.model_dump_json(exclude_unset=True)
574
+
575
+ # A tool call needs one extra final chunk to align with the OpenAI API
576
+ if is_function_call:
577
+ yield ChatCompletionResponse(
578
+ model=model_id,
579
+ id=response_id,
580
+ system_fingerprint=system_fingerprint,
581
+ choices=[
582
+ ChatCompletionResponseStreamChoice(
583
+ index=0,
584
+ delta=DeltaMessage(
585
+ content=None,
586
+ role=None,
587
+ function_call=None,
588
+ ),
589
+ finish_reason="tool_calls"
590
+ )],
591
+ created=created_time,
592
+ object="chat.completion.chunk",
593
+ usage=None
594
+ ).model_dump_json(exclude_unset=True)
595
+ yield '[DONE]'
596
+
597
+
598
+ async def parse_output_text(model_id: str, value: str, function_call: ChoiceDeltaToolCallFunction = None):
599
+ delta = DeltaMessage(role="assistant", content=value)
600
+ if function_call is not None:
601
+ delta.function_call = function_call
602
+
603
+ choice_data = ChatCompletionResponseStreamChoice(
604
+ index=0,
605
+ delta=delta,
606
+ finish_reason=None
607
+ )
608
+ chunk = ChatCompletionResponse(
609
+ model=model_id,
610
+ choices=[choice_data],
611
+ object="chat.completion.chunk"
612
+ )
613
+ yield "{}".format(chunk.model_dump_json(exclude_unset=True))
614
+ yield '[DONE]'
615
+
616
+
617
+ if __name__ == "__main__":
618
+ tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH, trust_remote_code=True)
619
+ engine_args = AsyncEngineArgs(
620
+ model=MODEL_PATH,
621
+ tokenizer=MODEL_PATH,
622
+ # If you have multiple GPUs, set this to the number of GPUs you want to use
623
+ tensor_parallel_size=1,
624
+ dtype="bfloat16",
625
+ trust_remote_code=True,
626
+ # Fraction of GPU memory to use; set an appropriate value for your GPU. For example, on an 80G GPU, to use only 24G set 24/80 = 0.3
627
+ gpu_memory_utilization=0.9,
628
+ enforce_eager=True,
629
+ worker_use_ray=False,
630
+ engine_use_ray=False,
631
+ disable_log_requests=True,
632
+ max_model_len=MAX_MODEL_LENGTH,
633
+ )
634
+ engine = AsyncLLMEngine.from_engine_args(engine_args)
635
+ uvicorn.run(app, host='0.0.0.0', port=8000, workers=1)
requirements.txt ADDED
@@ -0,0 +1,29 @@
1
+ # use vllm
2
+ # vllm>=0.5.0
3
+
4
+ torch>=2.3.0
5
+ torchvision>=0.18.0
6
+ transformers==4.40.0
7
+ huggingface-hub>=0.23.1
8
+ sentencepiece>=0.2.0
9
+ pydantic>=2.7.1
10
+ timm>=0.9.16
11
+ tiktoken>=0.7.0
12
+ accelerate>=0.30.1
13
+ sentence_transformers>=2.7.0
14
+
15
+ # web demo
16
+ gradio>=4.33.0
17
+
18
+ # openai demo
19
+ openai>=1.34.0
20
+ einops>=0.7.0
21
+ sse-starlette>=2.1.0
22
+
23
+ # INT4
24
+ # bitsandbytes>=0.43.1
25
+ bitsandbytes>=0.42.0
26
+
27
+
28
+ # PEFT support; not needed if you don't use a PEFT fine-tuned model.
29
+ peft>=0.11.0
trans_batch_demo.py ADDED
@@ -0,0 +1,90 @@
1
+ """
2
+
3
+ Here is an example of making batched requests to glm-4-9b:
+ build the conversation format yourself, then call the batch function to run batched inference.
+ Please note that memory consumption is significantly higher in this demo.
6
+
7
+ """
8
+
9
+ from typing import Optional, Union
10
+ from transformers import AutoModel, AutoTokenizer, LogitsProcessorList
11
+
12
+ MODEL_PATH = 'THUDM/glm-4-9b-chat'
13
+
14
+ tokenizer = AutoTokenizer.from_pretrained(
15
+ MODEL_PATH,
16
+ trust_remote_code=True,
17
+ encode_special_tokens=True)
18
+ model = AutoModel.from_pretrained(MODEL_PATH, trust_remote_code=True, device_map="auto").eval()
19
+
20
+
21
+ def process_model_outputs(inputs, outputs, tokenizer):
22
+ responses = []
23
+ for input_ids, output_ids in zip(inputs.input_ids, outputs):
24
+ response = tokenizer.decode(output_ids[len(input_ids):], skip_special_tokens=True).strip()
25
+ responses.append(response)
26
+ return responses
27
+
28
+
29
+ def batch(
30
+ model,
31
+ tokenizer,
32
+ messages: Union[str, list[str]],
33
+ max_input_tokens: int = 8192,
34
+ max_new_tokens: int = 8192,
35
+ num_beams: int = 1,
36
+ do_sample: bool = True,
37
+ top_p: float = 0.8,
38
+ temperature: float = 0.8,
39
+ logits_processor: Optional[LogitsProcessorList] = LogitsProcessorList(),
40
+ ):
41
+ messages = [messages] if isinstance(messages, str) else messages
42
+ batched_inputs = tokenizer(messages, return_tensors="pt", padding="max_length", truncation=True,
43
+ max_length=max_input_tokens).to(model.device)
44
+
45
+ gen_kwargs = {
46
+ "max_new_tokens": max_new_tokens,
47
+ "num_beams": num_beams,
48
+ "do_sample": do_sample,
49
+ "top_p": top_p,
50
+ "temperature": temperature,
51
+ "logits_processor": logits_processor,
52
+ "eos_token_id": model.config.eos_token_id
53
+ }
54
+ batched_outputs = model.generate(**batched_inputs, **gen_kwargs)
55
+ batched_response = process_model_outputs(batched_inputs, batched_outputs, tokenizer)
56
+ return batched_response
57
+
58
+
59
+ if __name__ == "__main__":
60
+
61
+ batch_message = [
62
+ [
63
+ {"role": "user", "content": "我的爸爸和妈妈结婚为什么不能带我去"},
64
+ {"role": "assistant", "content": "因为他们结婚时你还没有出生"},
65
+ {"role": "user", "content": "我刚才的提问是"}
66
+ ],
67
+ [
68
+ {"role": "user", "content": "你好,你是谁"}
69
+ ]
70
+ ]
71
+
72
+ batch_inputs = []
73
+ max_input_tokens = 1024
74
+ for i, messages in enumerate(batch_message):
75
+ new_batch_input = tokenizer.apply_chat_template(messages, add_generation_prompt=True, tokenize=False)
76
+ max_input_tokens = max(max_input_tokens, len(new_batch_input))
77
+ batch_inputs.append(new_batch_input)
78
+ gen_kwargs = {
79
+ "max_input_tokens": max_input_tokens,
80
+ "max_new_tokens": 8192,
81
+ "do_sample": True,
82
+ "top_p": 0.8,
83
+ "temperature": 0.8,
84
+ "num_beams": 1,
85
+ }
86
+
87
+ batch_responses = batch(model, tokenizer, batch_inputs, **gen_kwargs)
88
+ for response in batch_responses:
89
+ print("=" * 10)
90
+ print(response)
trans_cli_demo.py ADDED
@@ -0,0 +1,119 @@
1
+ """
2
+ This script creates a CLI demo with transformers backend for the glm-4-9b model,
3
+ allowing users to interact with the model through a command-line interface.
4
+
5
+ Usage:
6
+ - Run the script to start the CLI demo.
7
+ - Interact with the model by typing questions and receiving responses.
8
+
9
+ Note: The script includes a modification to handle markdown to plain text conversion,
10
+ ensuring that the CLI interface displays formatted text correctly.
11
+
12
+ If you use flash attention, install flash-attn and pass attn_implementation="flash_attention_2" when loading the model.
13
+ """
14
+
15
+ import os
16
+ import torch
17
+ from threading import Thread
18
+ from transformers import AutoTokenizer, StoppingCriteria, StoppingCriteriaList, TextIteratorStreamer, AutoModel
19
+
20
+ MODEL_PATH = os.environ.get('MODEL_PATH', 'THUDM/glm-4-9b-chat')
21
+
22
+ print("MODEL_PATH: " + MODEL_PATH)
23
+
24
+ ## If use peft model.
25
+ # def load_model_and_tokenizer(model_dir, trust_remote_code: bool = True):
26
+ # if (model_dir / 'adapter_config.json').exists():
27
+ # model = AutoModel.from_pretrained(
28
+ # model_dir, trust_remote_code=trust_remote_code, device_map='auto'
29
+ # )
30
+ # tokenizer_dir = model.peft_config['default'].base_model_name_or_path
31
+ # else:
32
+ # model = AutoModel.from_pretrained(
33
+ # model_dir, trust_remote_code=trust_remote_code, device_map='auto'
34
+ # )
35
+ # tokenizer_dir = model_dir
36
+ # tokenizer = AutoTokenizer.from_pretrained(
37
+ # tokenizer_dir, trust_remote_code=trust_remote_code, use_fast=False
38
+ # )
39
+ # return model, tokenizer
40
+
41
+
42
+ tokenizer = AutoTokenizer.from_pretrained(
43
+ MODEL_PATH,
44
+ trust_remote_code=True,
45
+ encode_special_tokens=True
46
+ )
47
+
48
+ model = AutoModel.from_pretrained(
49
+ MODEL_PATH,
50
+ trust_remote_code=True,
51
+ # attn_implementation="flash_attention_2", # Use Flash Attention
52
+ # torch_dtype=torch.bfloat16, #using flash-attn must use bfloat16 or float16
53
+ device_map="auto").eval()
54
+
55
+
56
+ class StopOnTokens(StoppingCriteria):
57
+ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
58
+ stop_ids = model.config.eos_token_id
59
+ for stop_id in stop_ids:
60
+ if input_ids[0][-1] == stop_id:
61
+ return True
62
+ return False
63
+
64
+
65
+ if __name__ == "__main__":
66
+ history = []
67
+ max_length = 8192
68
+ top_p = 0.8
69
+ temperature = 0.6
70
+ stop = StopOnTokens()
71
+
72
+ print("Welcome to the GLM-4-9B CLI chat. Type your messages below.")
73
+ while True:
74
+ user_input = input("\nYou: ")
75
+ if user_input.lower() in ["exit", "quit"]:
76
+ break
77
+ history.append([user_input, ""])
78
+
79
+ messages = []
80
+ for idx, (user_msg, model_msg) in enumerate(history):
81
+ if idx == len(history) - 1 and not model_msg:
82
+ messages.append({"role": "user", "content": user_msg})
83
+ break
84
+ if user_msg:
85
+ messages.append({"role": "user", "content": user_msg})
86
+ if model_msg:
87
+ messages.append({"role": "assistant", "content": model_msg})
88
+ model_inputs = tokenizer.apply_chat_template(
89
+ messages,
90
+ add_generation_prompt=True,
91
+ tokenize=True,
92
+ return_tensors="pt"
93
+ ).to(model.device)
94
+ streamer = TextIteratorStreamer(
95
+ tokenizer=tokenizer,
96
+ timeout=60,
97
+ skip_prompt=True,
98
+ skip_special_tokens=True
99
+ )
100
+ generate_kwargs = {
101
+ "input_ids": model_inputs,
102
+ "streamer": streamer,
103
+ "max_new_tokens": max_length,
104
+ "do_sample": True,
105
+ "top_p": top_p,
106
+ "temperature": temperature,
107
+ "stopping_criteria": StoppingCriteriaList([stop]),
108
+ "repetition_penalty": 1.2,
109
+ "eos_token_id": model.config.eos_token_id,
110
+ }
111
+ t = Thread(target=model.generate, kwargs=generate_kwargs)
112
+ t.start()
113
+ print("GLM-4:", end="", flush=True)
114
+ for new_token in streamer:
115
+ if new_token:
116
+ print(new_token, end="", flush=True)
117
+ history[-1][1] += new_token
118
+
119
+ history[-1][1] = history[-1][1].strip()
trans_cli_vision_demo.py ADDED
@@ -0,0 +1,123 @@
1
+ """
2
+ This script creates a CLI demo with transformers backend for the glm-4v-9b model,
3
+ allowing users to interact with the model through a command-line interface.
4
+
5
+ Usage:
6
+ - Run the script to start the CLI demo.
7
+ - Interact with the model by typing questions and receiving responses.
8
+
9
+ Note: The script includes a modification to handle markdown to plain text conversion,
10
+ ensuring that the CLI interface displays formatted text correctly.
11
+ """
12
+
13
+ import os
14
+ import torch
15
+ from threading import Thread
16
+ from transformers import (
17
+ AutoTokenizer,
18
+ StoppingCriteria,
19
+ StoppingCriteriaList,
20
+ TextIteratorStreamer, AutoModel, BitsAndBytesConfig
21
+ )
22
+
23
+ from PIL import Image
24
+
25
+ MODEL_PATH = os.environ.get('MODEL_PATH', 'THUDM/glm-4v-9b')
26
+
27
+ tokenizer = AutoTokenizer.from_pretrained(
28
+ MODEL_PATH,
29
+ trust_remote_code=True,
30
+ encode_special_tokens=True
31
+ )
32
+ model = AutoModel.from_pretrained(
33
+ MODEL_PATH,
34
+ trust_remote_code=True,
35
+ # attn_implementation="flash_attention_2", # Use Flash Attention
36
+ # torch_dtype=torch.bfloat16, # using flash-attn must use bfloat16 or float16,
37
+ device_map="auto",
38
+ ).eval()
39
+
40
+
41
+ ## For INT4 inference
42
+ # model = AutoModel.from_pretrained(
43
+ # MODEL_PATH,
44
+ # trust_remote_code=True,
45
+ # quantization_config=BitsAndBytesConfig(load_in_4bit=True),
46
+ # torch_dtype=torch.bfloat16,
47
+ # low_cpu_mem_usage=True
48
+ # ).eval()
49
+
50
+ class StopOnTokens(StoppingCriteria):
51
+ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
52
+ stop_ids = model.config.eos_token_id
53
+ for stop_id in stop_ids:
54
+ if input_ids[0][-1] == stop_id:
55
+ return True
56
+ return False
57
+
58
+
59
+ if __name__ == "__main__":
60
+ history = []
61
+ max_length = 1024
62
+ top_p = 0.8
63
+ temperature = 0.6
64
+ stop = StopOnTokens()
65
+ uploaded = False
66
+ image = None
67
+ print("Welcome to the GLM-4V-9B CLI chat. Type your messages below.")
68
+ image_path = input("Image Path:")
69
+ try:
70
+ image = Image.open(image_path).convert("RGB")
71
+ except:
72
+ print("Invalid image path. Continuing with text conversation.")
73
+ while True:
74
+ user_input = input("\nYou: ")
75
+ if user_input.lower() in ["exit", "quit"]:
76
+ break
77
+ history.append([user_input, ""])
78
+
79
+ messages = []
80
+ for idx, (user_msg, model_msg) in enumerate(history):
81
+ if idx == len(history) - 1 and not model_msg:
82
+ messages.append({"role": "user", "content": user_msg})
83
+ if image and not uploaded:
84
+ messages[-1].update({"image": image})
85
+ uploaded = True
86
+ break
87
+ if user_msg:
88
+ messages.append({"role": "user", "content": user_msg})
89
+ if model_msg:
90
+ messages.append({"role": "assistant", "content": model_msg})
91
+ model_inputs = tokenizer.apply_chat_template(
92
+ messages,
93
+ add_generation_prompt=True,
94
+ tokenize=True,
95
+ return_tensors="pt",
96
+ return_dict=True
97
+ ).to(next(model.parameters()).device)
98
+ streamer = TextIteratorStreamer(
99
+ tokenizer=tokenizer,
100
+ timeout=60,
101
+ skip_prompt=True,
102
+ skip_special_tokens=True
103
+ )
104
+ generate_kwargs = {
105
+ **model_inputs,
106
+ "streamer": streamer,
107
+ "max_new_tokens": max_length,
108
+ "do_sample": True,
109
+ "top_p": top_p,
110
+ "temperature": temperature,
111
+ "stopping_criteria": StoppingCriteriaList([stop]),
112
+ "repetition_penalty": 1.2,
113
+ "eos_token_id": [151329, 151336, 151338],
114
+ }
115
+ t = Thread(target=model.generate, kwargs=generate_kwargs)
116
+ t.start()
117
+ print("GLM-4V:", end="", flush=True)
118
+ for new_token in streamer:
119
+ if new_token:
120
+ print(new_token, end="", flush=True)
121
+ history[-1][1] += new_token
122
+
123
+ history[-1][1] = history[-1][1].strip()
trans_stress_test.py ADDED
@@ -0,0 +1,135 @@
1
+ import argparse
2
+ import time
3
+ from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer, BitsAndBytesConfig
4
+ import torch
5
+ from threading import Thread
6
+
7
+ MODEL_PATH = 'THUDM/glm-4-9b-chat'
8
+
9
+
10
+ def stress_test(token_len, n, num_gpu):
11
+ device = torch.device(f"cuda:{num_gpu - 1}" if torch.cuda.is_available() and num_gpu > 0 else "cpu")
12
+ tokenizer = AutoTokenizer.from_pretrained(
13
+ MODEL_PATH,
14
+ trust_remote_code=True,
15
+ padding_side="left"
16
+ )
17
+ model = AutoModelForCausalLM.from_pretrained(
18
+ MODEL_PATH,
19
+ trust_remote_code=True,
20
+ torch_dtype=torch.bfloat16
21
+ ).to(device).eval()
22
+
23
+ # Use INT4 weight infer
24
+ # model = AutoModelForCausalLM.from_pretrained(
25
+ # MODEL_PATH,
26
+ # trust_remote_code=True,
27
+ # quantization_config=BitsAndBytesConfig(load_in_4bit=True),
28
+ # low_cpu_mem_usage=True,
29
+ # ).eval()
30
+
31
+ times = []
32
+ decode_times = []
33
+
34
+ print("Warming up...")
35
+ vocab_size = tokenizer.vocab_size
36
+ warmup_token_len = 20
37
+ random_token_ids = torch.randint(3, vocab_size - 200, (warmup_token_len - 5,), dtype=torch.long)
38
+ start_tokens = [151331, 151333, 151336, 198]
39
+ end_tokens = [151337]
40
+ input_ids = torch.tensor(start_tokens + random_token_ids.tolist() + end_tokens, dtype=torch.long).unsqueeze(0).to(
41
+ device)
42
+ attention_mask = torch.ones_like(input_ids, dtype=torch.bfloat16).to(device)
43
+ position_ids = torch.arange(len(input_ids[0]), dtype=torch.bfloat16).unsqueeze(0).to(device)
44
+ warmup_inputs = {
45
+ 'input_ids': input_ids,
46
+ 'attention_mask': attention_mask,
47
+ 'position_ids': position_ids
48
+ }
49
+ with torch.no_grad():
50
+ _ = model.generate(
51
+ input_ids=warmup_inputs['input_ids'],
52
+ attention_mask=warmup_inputs['attention_mask'],
53
+ max_new_tokens=2048,
54
+ do_sample=False,
55
+ repetition_penalty=1.0,
56
+ eos_token_id=[151329, 151336, 151338]
57
+ )
58
+ print("Warming up complete. Starting stress test...")
59
+
60
+ for i in range(n):
61
+ random_token_ids = torch.randint(3, vocab_size - 200, (token_len - 5,), dtype=torch.long)
62
+ input_ids = torch.tensor(start_tokens + random_token_ids.tolist() + end_tokens, dtype=torch.long).unsqueeze(
63
+ 0).to(device)
64
+ attention_mask = torch.ones_like(input_ids, dtype=torch.bfloat16).to(device)
65
+ position_ids = torch.arange(len(input_ids[0]), dtype=torch.bfloat16).unsqueeze(0).to(device)
66
+ test_inputs = {
67
+ 'input_ids': input_ids,
68
+ 'attention_mask': attention_mask,
69
+ 'position_ids': position_ids
70
+ }
71
+
72
+ streamer = TextIteratorStreamer(
73
+ tokenizer=tokenizer,
74
+ timeout=36000,
75
+ skip_prompt=True,
76
+ skip_special_tokens=True
77
+ )
78
+
79
+ generate_kwargs = {
80
+ "input_ids": test_inputs['input_ids'],
81
+ "attention_mask": test_inputs['attention_mask'],
82
+ "max_new_tokens": 512,
83
+ "do_sample": False,
84
+ "repetition_penalty": 1.0,
85
+ "eos_token_id": [151329, 151336, 151338],
86
+ "streamer": streamer
87
+ }
88
+
89
+ start_time = time.time()
90
+ t = Thread(target=model.generate, kwargs=generate_kwargs)
91
+ t.start()
92
+
93
+ first_token_time = None
94
+ all_token_times = []
95
+
96
+ for token in streamer:
97
+ current_time = time.time()
98
+ if first_token_time is None:
99
+ first_token_time = current_time
100
+ times.append(first_token_time - start_time)
101
+ all_token_times.append(current_time)
102
+
103
+ t.join()
104
+ end_time = time.time()
105
+
106
+ avg_decode_time_per_token = len(all_token_times) / (end_time - first_token_time) if all_token_times else 0
107
+ decode_times.append(avg_decode_time_per_token)
108
+ print(
109
+ f"Iteration {i + 1}/{n} - Prefilling Time: {times[-1]:.4f} seconds - Average Decode Time: {avg_decode_time_per_token:.4f} tokens/second")
110
+
111
+ torch.cuda.empty_cache()
112
+
113
+ avg_first_token_time = sum(times) / n
114
+ avg_decode_time = sum(decode_times) / n
115
+ print(f"\nAverage First Token Time over {n} iterations: {avg_first_token_time:.4f} seconds")
116
+ print(f"Average Decode Time per Token over {n} iterations: {avg_decode_time:.4f} tokens/second")
117
+ return times, avg_first_token_time, decode_times, avg_decode_time
118
+
119
+
120
+ def main():
121
+ parser = argparse.ArgumentParser(description="Stress test for model inference")
122
+ parser.add_argument('--token_len', type=int, default=1000, help='Number of tokens for each test')
123
+ parser.add_argument('--n', type=int, default=3, help='Number of iterations for the stress test')
124
+ parser.add_argument('--num_gpu', type=int, default=1, help='Number of GPUs to use for inference')
125
+ args = parser.parse_args()
126
+
127
+ token_len = args.token_len
128
+ n = args.n
129
+ num_gpu = args.num_gpu
130
+
131
+ stress_test(token_len, n, num_gpu)
132
+
133
+
134
+ if __name__ == "__main__":
135
+ main()
trans_web_demo.py ADDED
@@ -0,0 +1,179 @@
1
+ """
2
+ This script creates an interactive web demo for the GLM-4-9B model using Gradio,
3
+ a Python library for building quick and easy UI components for machine learning models.
4
+ It's designed to showcase the capabilities of the GLM-4-9B model in a user-friendly interface,
5
+ allowing users to interact with the model through a chat-like interface.
6
+ """
7
+
8
+ import os
9
+ from pathlib import Path
10
+ from threading import Thread
11
+ from typing import Union
12
+
13
+ import gradio as gr
14
+ import torch
15
+ from peft import AutoPeftModelForCausalLM, PeftModelForCausalLM
16
+ from transformers import (
17
+ AutoModelForCausalLM,
18
+ AutoTokenizer,
19
+ PreTrainedModel,
20
+ PreTrainedTokenizer,
21
+ PreTrainedTokenizerFast,
22
+ StoppingCriteria,
23
+ StoppingCriteriaList,
24
+ TextIteratorStreamer
25
+ )
26
+
27
+ ModelType = Union[PreTrainedModel, PeftModelForCausalLM]
28
+ TokenizerType = Union[PreTrainedTokenizer, PreTrainedTokenizerFast]
29
+
30
+ MODEL_PATH = os.environ.get('MODEL_PATH', 'THUDM/glm-4-9b-chat')
31
+ # Local snapshot override (keep commented out so the MODEL_PATH environment variable above takes effect):
+ # MODEL_PATH = "/Users/zmac/.cache/huggingface/hub/models--THUDM--glm-4-9b-chat/snapshots/04419001bc63e05e70991ade6da1f91c4aeec278"
32
+ TOKENIZER_PATH = os.environ.get("TOKENIZER_PATH", MODEL_PATH)
33
+
34
+
35
+ def _resolve_path(path: Union[str, Path]) -> Path:
36
+ return Path(path).expanduser().resolve()
37
+
38
+
39
+ def load_model_and_tokenizer(
40
+ model_dir: Union[str, Path], trust_remote_code: bool = True
41
+ ) -> tuple[ModelType, TokenizerType]:
42
+ model_dir = _resolve_path(model_dir)
43
+ if (model_dir / 'adapter_config.json').exists():
44
+ model = AutoPeftModelForCausalLM.from_pretrained(
45
+ model_dir, trust_remote_code=trust_remote_code, device_map='auto'
46
+ )
47
+ tokenizer_dir = model.peft_config['default'].base_model_name_or_path
48
+ else:
49
+ model = AutoModelForCausalLM.from_pretrained(
50
+ model_dir, trust_remote_code=trust_remote_code, device_map='auto'
51
+ )
52
+ tokenizer_dir = model_dir
53
+ tokenizer = AutoTokenizer.from_pretrained(
54
+ tokenizer_dir, trust_remote_code=trust_remote_code, use_fast=False
55
+ )
56
+ return model, tokenizer
57
+
58
+
59
+ model, tokenizer = load_model_and_tokenizer(MODEL_PATH, trust_remote_code=True)
60
+
61
+
62
+ class StopOnTokens(StoppingCriteria):
63
+ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
64
+ stop_ids = model.config.eos_token_id
65
+ for stop_id in stop_ids:
66
+ if input_ids[0][-1] == stop_id:
67
+ return True
68
+ return False
69
+
70
+
71
+ def parse_text(text):
72
+ lines = text.split("\n")
73
+ lines = [line for line in lines if line != ""]
74
+ count = 0
75
+ for i, line in enumerate(lines):
76
+ if "```" in line:
77
+ count += 1
78
+ items = line.split('`')
79
+ if count % 2 == 1:
80
+ lines[i] = f'<pre><code class="language-{items[-1]}">'
81
+ else:
82
+ lines[i] = f'<br></code></pre>'
83
+ else:
84
+ if i > 0:
85
+ if count % 2 == 1:
86
+ line = line.replace("`", "\`")
87
+ line = line.replace("<", "&lt;")
88
+ line = line.replace(">", "&gt;")
89
+ line = line.replace(" ", "&nbsp;")
90
+ line = line.replace("*", "&ast;")
91
+ line = line.replace("_", "&lowbar;")
92
+ line = line.replace("-", "&#45;")
93
+ line = line.replace(".", "&#46;")
94
+ line = line.replace("!", "&#33;")
95
+ line = line.replace("(", "&#40;")
96
+ line = line.replace(")", "&#41;")
97
+ line = line.replace("$", "&#36;")
98
+ lines[i] = "<br>" + line
99
+ text = "".join(lines)
100
+ return text
101
+
102
+
103
+ def predict(history, prompt, max_length, top_p, temperature):
104
+ stop = StopOnTokens()
105
+ messages = []
106
+ if prompt:
107
+ messages.append({"role": "system", "content": prompt})
108
+ for idx, (user_msg, model_msg) in enumerate(history):
109
+ if prompt and idx == 0:
110
+ continue
111
+ if idx == len(history) - 1 and not model_msg:
112
+ messages.append({"role": "user", "content": user_msg})
113
+ break
114
+ if user_msg:
115
+ messages.append({"role": "user", "content": user_msg})
116
+ if model_msg:
117
+ messages.append({"role": "assistant", "content": model_msg})
118
+
119
+ model_inputs = tokenizer.apply_chat_template(messages,
120
+ add_generation_prompt=True,
121
+ tokenize=True,
122
+ return_tensors="pt").to(next(model.parameters()).device)
123
+ streamer = TextIteratorStreamer(tokenizer, timeout=60, skip_prompt=True, skip_special_tokens=True)
124
+ generate_kwargs = {
125
+ "input_ids": model_inputs,
126
+ "streamer": streamer,
127
+ "max_new_tokens": max_length,
128
+ "do_sample": True,
129
+ "top_p": top_p,
130
+ "temperature": temperature,
131
+ "stopping_criteria": StoppingCriteriaList([stop]),
132
+ "repetition_penalty": 1.2,
133
+ "eos_token_id": model.config.eos_token_id,
134
+ }
135
+ t = Thread(target=model.generate, kwargs=generate_kwargs)
136
+ t.start()
137
+ for new_token in streamer:
138
+ if new_token:
139
+ history[-1][1] += new_token
140
+ yield history
141
+
142
+
143
+ with gr.Blocks() as demo:
144
+ gr.HTML("""<h1 align="center">GLM-4-9B Gradio Simple Chat Demo</h1>""")
145
+ chatbot = gr.Chatbot()
146
+
147
+ with gr.Row():
148
+ with gr.Column(scale=3):
149
+ with gr.Column(scale=12):
150
+ user_input = gr.Textbox(show_label=False, placeholder="Input...", lines=10, container=False)
151
+ with gr.Column(min_width=32, scale=1):
152
+ submitBtn = gr.Button("Submit")
153
+ with gr.Column(scale=1):
154
+ prompt_input = gr.Textbox(show_label=False, placeholder="Prompt", lines=10, container=False)
155
+ pBtn = gr.Button("Set Prompt")
156
+ with gr.Column(scale=1):
157
+ emptyBtn = gr.Button("Clear History")
158
+ max_length = gr.Slider(0, 32768, value=8192, step=1.0, label="Maximum length", interactive=True)
159
+ top_p = gr.Slider(0, 1, value=0.8, step=0.01, label="Top P", interactive=True)
160
+ temperature = gr.Slider(0.01, 1, value=0.6, step=0.01, label="Temperature", interactive=True)
161
+
162
+
163
+ def user(query, history):
164
+ return "", history + [[parse_text(query), ""]]
165
+
166
+
167
+ def set_prompt(prompt_text):
168
+ return [[parse_text(prompt_text), "成功设置prompt"]]
169
+
170
+
171
+ pBtn.click(set_prompt, inputs=[prompt_input], outputs=chatbot)
172
+
173
+ submitBtn.click(user, [user_input, chatbot], [user_input, chatbot], queue=False).then(
174
+ predict, [chatbot, prompt_input, max_length, top_p, temperature], chatbot
175
+ )
176
+ emptyBtn.click(lambda: (None, None), None, [chatbot, prompt_input], queue=False)
177
+
178
+ demo.queue()
179
+ demo.launch(server_name="127.0.0.1", server_port=8000, inbrowser=True, share=True)
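
Note on the chat demo above: `MODEL_PATH` is hardcoded to a snapshot inside one machine's local Hugging Face cache, so the script only starts where that exact path exists (only `TOKENIZER_PATH` is read from the environment). A minimal, hypothetical alternative is sketched below, assuming you want the environment-variable pattern the other demos in this folder use; the `THUDM/glm-4-9b-chat` fallback is an assumption and not part of the committed file.

```python
# Hypothetical sketch (not part of the commit): resolve the checkpoint from the
# environment with a Hub id fallback, as trans_web_vision_demo.py does.
import os

MODEL_PATH = os.environ.get("MODEL_PATH", "THUDM/glm-4-9b-chat")
TOKENIZER_PATH = os.environ.get("TOKENIZER_PATH", MODEL_PATH)
```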
trans_web_vision_demo.py ADDED
@@ -0,0 +1,121 @@
+ """
+ This script creates a Gradio demo with a Transformers backend for the glm-4v-9b model, allowing users to interact with the model through a Gradio web UI.
+
+ Usage:
+ - Run the script to start the Gradio server.
+ - Interact with the model via the web UI.
+
+ Requirements:
+ - Gradio package
+ - Type `pip install gradio` to install Gradio.
+ """
+
+ import os
+ import torch
+ import gradio as gr
+ from threading import Thread
+ from transformers import (
+     AutoTokenizer,
+     StoppingCriteria,
+     StoppingCriteriaList,
+     TextIteratorStreamer, AutoModel, BitsAndBytesConfig
+ )
+ from PIL import Image
+ import requests
+ from io import BytesIO
+
+ MODEL_PATH = os.environ.get('MODEL_PATH', 'THUDM/glm-4v-9b')
+
+ tokenizer = AutoTokenizer.from_pretrained(
+     MODEL_PATH,
+     trust_remote_code=True,
+     encode_special_tokens=True
+ )
+ model = AutoModel.from_pretrained(
+     MODEL_PATH,
+     trust_remote_code=True,
+     device_map="auto",
+     torch_dtype=torch.bfloat16
+ ).eval()
+
+ class StopOnTokens(StoppingCriteria):
+     def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
+         stop_ids = model.config.eos_token_id
+         for stop_id in stop_ids:
+             if input_ids[0][-1] == stop_id:
+                 return True
+         return False
+
+ def get_image(image_path=None, image_url=None):
+     if image_path:
+         return Image.open(image_path).convert("RGB")
+     elif image_url:
+         response = requests.get(image_url)
+         return Image.open(BytesIO(response.content)).convert("RGB")
+     return None
+
+ def chatbot(image_path=None, image_url=None, assistant_prompt=""):
+     image = get_image(image_path, image_url)
+
+     messages = [
+         {"role": "assistant", "content": assistant_prompt},
+         {"role": "user", "content": "", "image": image}
+     ]
+
+     model_inputs = tokenizer.apply_chat_template(
+         messages,
+         add_generation_prompt=True,
+         tokenize=True,
+         return_tensors="pt",
+         return_dict=True
+     ).to(next(model.parameters()).device)
+
+     streamer = TextIteratorStreamer(
+         tokenizer=tokenizer,
+         timeout=60,
+         skip_prompt=True,
+         skip_special_tokens=True
+     )
+
+     generate_kwargs = {
+         **model_inputs,
+         "streamer": streamer,
+         "max_new_tokens": 1024,
+         "do_sample": True,
+         "top_p": 0.8,
+         "temperature": 0.6,
+         "stopping_criteria": StoppingCriteriaList([StopOnTokens()]),
+         "repetition_penalty": 1.2,
+         "eos_token_id": [151329, 151336, 151338],
+     }
+
+     t = Thread(target=model.generate, kwargs=generate_kwargs)
+     t.start()
+
+     response = ""
+     for new_token in streamer:
+         if new_token:
+             response += new_token
+
+     return image, response.strip()
+
+ with gr.Blocks() as demo:
+     demo.title = "GLM-4V-9B Image Recognition Demo"
+     demo.description = """
+     This demo uses the GLM-4V-9B model to get image information.
+     """
+     with gr.Row():
+         with gr.Column():
+             image_path_input = gr.File(label="Upload Image (High-Priority)", type="filepath")
+             image_url_input = gr.Textbox(label="Image URL (Low-Priority)")
+             assistant_prompt_input = gr.Textbox(label="Assistant Prompt (You Can Change It)", value="What is this?")
+             submit_button = gr.Button("Submit")
+         with gr.Column():
+             chatbot_output = gr.Textbox(label="GLM-4V-9B Model Response")
+             image_output = gr.Image(label="Image Preview")
+
+     submit_button.click(chatbot,
+                         inputs=[image_path_input, image_url_input, assistant_prompt_input],
+                         outputs=[image_output, chatbot_output])
+
+ demo.launch(server_name="127.0.0.1", server_port=8911, inbrowser=True, share=False)
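
For reference, the `chatbot()` function above attaches the image to the user turn and passes the editable prompt as an assistant turn before tokenizing with `return_dict=True`. The sketch below only illustrates that message structure; the file name is hypothetical and no generation is run.

```python
# Sketch of the message payload built by chatbot() above (file name is hypothetical).
from PIL import Image

image = Image.open("example.jpg").convert("RGB")  # any local RGB image
messages = [
    {"role": "assistant", "content": "What is this?"},  # contents of the prompt textbox
    {"role": "user", "content": "", "image": image},    # uploaded file or downloaded URL
]
# chatbot() then calls tokenizer.apply_chat_template(messages, add_generation_prompt=True,
# tokenize=True, return_tensors="pt", return_dict=True) and streams model.generate().
```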
vllm_cli_demo.py ADDED
@@ -0,0 +1,111 @@
+ """
+ This script creates a CLI demo with a vLLM backend for the glm-4-9b model,
+ allowing users to interact with the model through a command-line interface.
+
+ Usage:
+ - Run the script to start the CLI demo.
+ - Interact with the model by typing questions and receiving responses.
+
+ Note: The script includes a modification to handle markdown to plain text conversion,
+ ensuring that the CLI interface displays formatted text correctly.
+ """
+ import time
+ import asyncio
+ from transformers import AutoTokenizer
+ from vllm import SamplingParams, AsyncEngineArgs, AsyncLLMEngine
+ from typing import List, Dict
+
+ MODEL_PATH = 'THUDM/glm-4-9b'
+
+
+ def load_model_and_tokenizer(model_dir: str):
+     engine_args = AsyncEngineArgs(
+         model=model_dir,
+         tokenizer=model_dir,
+         tensor_parallel_size=1,
+         dtype="bfloat16",
+         trust_remote_code=True,
+         gpu_memory_utilization=0.3,
+         enforce_eager=True,
+         worker_use_ray=True,
+         engine_use_ray=False,
+         disable_log_requests=True
+         # If you run into OOM, it is recommended to enable the parameters below
+         # enable_chunked_prefill=True,
+         # max_num_batched_tokens=8192
+     )
+     tokenizer = AutoTokenizer.from_pretrained(
+         model_dir,
+         trust_remote_code=True,
+         encode_special_tokens=True
+     )
+     engine = AsyncLLMEngine.from_engine_args(engine_args)
+     return engine, tokenizer
+
+
+ engine, tokenizer = load_model_and_tokenizer(MODEL_PATH)
+
+
+ async def vllm_gen(messages: List[Dict[str, str]], top_p: float, temperature: float, max_dec_len: int):
+     inputs = tokenizer.apply_chat_template(
+         messages,
+         add_generation_prompt=True,
+         tokenize=False
+     )
+     params_dict = {
+         "n": 1,
+         "best_of": 1,
+         "presence_penalty": 1.0,
+         "frequency_penalty": 0.0,
+         "temperature": temperature,
+         "top_p": top_p,
+         "top_k": -1,
+         "use_beam_search": False,
+         "length_penalty": 1,
+         "early_stopping": False,
+         "stop_token_ids": [151329, 151336, 151338],
+         "ignore_eos": False,
+         "max_tokens": max_dec_len,
+         "logprobs": None,
+         "prompt_logprobs": None,
+         "skip_special_tokens": True,
+     }
+     sampling_params = SamplingParams(**params_dict)
+     async for output in engine.generate(inputs=inputs, sampling_params=sampling_params, request_id=f"{time.time()}"):
+         yield output.outputs[0].text
+
+
+ async def chat():
+     history = []
+     max_length = 8192
+     top_p = 0.8
+     temperature = 0.6
+
+     print("Welcome to the GLM-4-9B CLI chat. Type your messages below.")
+     while True:
+         user_input = input("\nYou: ")
+         if user_input.lower() in ["exit", "quit"]:
+             break
+         history.append([user_input, ""])
+
+         messages = []
+         for idx, (user_msg, model_msg) in enumerate(history):
+             if idx == len(history) - 1 and not model_msg:
+                 messages.append({"role": "user", "content": user_msg})
+                 break
+             if user_msg:
+                 messages.append({"role": "user", "content": user_msg})
+             if model_msg:
+                 messages.append({"role": "assistant", "content": model_msg})
+
+         print("\nGLM-4: ", end="")
+         current_length = 0
+         output = ""
+         async for output in vllm_gen(messages, top_p, temperature, max_length):
+             print(output[current_length:], end="", flush=True)
+             current_length = len(output)
+         history[-1][1] = output
+
+
+ if __name__ == "__main__":
+     asyncio.run(chat())
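
If an interactive loop is not needed, the same stop-token ids and sampling settings can be exercised with vLLM's synchronous `LLM` API. The sketch below is a minimal, assumption-laden example and not part of the commit: it uses the chat checkpoint `THUDM/glm-4-9b-chat` (the script above points at `THUDM/glm-4-9b`) and a made-up prompt.

```python
# Minimal single-shot sketch with vLLM's synchronous API (not part of the commit).
from transformers import AutoTokenizer
from vllm import LLM, SamplingParams

MODEL_PATH = "THUDM/glm-4-9b-chat"  # assumption: chat checkpoint

tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH, trust_remote_code=True)
llm = LLM(model=MODEL_PATH, trust_remote_code=True, dtype="bfloat16", enforce_eager=True)

prompt = tokenizer.apply_chat_template(
    [{"role": "user", "content": "Hello, please introduce yourself briefly."}],
    add_generation_prompt=True,
    tokenize=False,
)
outputs = llm.generate(
    [prompt],
    SamplingParams(temperature=0.6, top_p=0.8, max_tokens=256,
                   stop_token_ids=[151329, 151336, 151338]),
)
print(outputs[0].outputs[0].text)
```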