JoshuaChak committed
Commit • 7c071a8 • 1 Parent(s): ddb8425
Upload folder using huggingface_hub
This view is limited to 50 files because it contains too many changes. See raw diff.
- .gitattributes +45 -0
- Baichuan2/README.md +182 -0
- Baichuan2/compile/compile.sh +186 -0
- Baichuan2/compile/export_onnx.py +182 -0
- Baichuan2/compile/files/Baichuan2-7B/config.json +29 -0
- Baichuan2/compile/files/Baichuan2-7B/modeling_baichuan.py +792 -0
- Baichuan2/compile/torch_inference.py +16 -0
- Baichuan2/demo/CMakeLists.txt +38 -0
- Baichuan2/demo/demo.cpp +472 -0
- Baichuan2/model/tokenizer.model +3 -0
- Baichuan2/requirements.txt +7 -0
- Baichuan2/src/include/bmdef.h +129 -0
- Baichuan2/src/include/bmlib_runtime.h +2581 -0
- Baichuan2/src/include/bmruntime_interface.h +404 -0
- Baichuan2/src/include/sentencepiece/sentencepiece_processor.h +727 -0
- Baichuan2/src/lib_pcie/libbmlib.so +0 -0
- Baichuan2/src/lib_pcie/libbmrt.so +3 -0
- Baichuan2/src/lib_pcie/libbmrt.so.1.0 +3 -0
- Baichuan2/src/lib_pcie/libsentencepiece.a +3 -0
- Baichuan2/src/lib_soc/libbmlib.so +0 -0
- Baichuan2/src/lib_soc/libbmrt.so +3 -0
- Baichuan2/src/lib_soc/libbmrt.so.1.0 +3 -0
- Baichuan2/src/lib_soc/libsentencepiece.a +3 -0
- Baichuan2/web_demo/CMakeLists.txt +36 -0
- Baichuan2/web_demo/chat.cpp +419 -0
- Baichuan2/web_demo/chat.py +97 -0
- Baichuan2/web_demo/web_demo.py +108 -0
- BaseModel/base_model.py +184 -0
- ChatGLM2/README.md +160 -0
- ChatGLM2/compile/compile.sh +179 -0
- ChatGLM2/compile/export_onnx.py +176 -0
- ChatGLM2/compile/files/chatglm2-6b/config.json +42 -0
- ChatGLM2/compile/files/chatglm2-6b/modeling_chatglm.py +1285 -0
- ChatGLM2/demo/CMakeLists.txt +33 -0
- ChatGLM2/demo/demo.cpp +609 -0
- ChatGLM2/run_demo.sh +27 -0
- ChatGLM2/support/include/bmdef.h +129 -0
- ChatGLM2/support/include/bmlib_runtime.h +2581 -0
- ChatGLM2/support/include/bmruntime_interface.h +404 -0
- ChatGLM2/support/include/sentencepiece/sentencepiece_processor.h +727 -0
- ChatGLM2/support/lib_pcie/libbmlib.so +0 -0
- ChatGLM2/support/lib_pcie/libbmrt.so +3 -0
- ChatGLM2/support/lib_pcie/libbmrt.so.1.0 +3 -0
- ChatGLM2/support/lib_pcie/libsentencepiece.a +3 -0
- ChatGLM2/support/lib_soc/libbmlib.so +0 -0
- ChatGLM2/support/lib_soc/libbmrt.so +3 -0
- ChatGLM2/support/lib_soc/libbmrt.so.1.0 +3 -0
- ChatGLM2/support/lib_soc/libsentencepiece.a +3 -0
- ChatGLM2/support/tokenizer/tokenization_chatglm.py +257 -0
- ChatGLM2/support/tokenizer/tokenizer.model +3 -0
.gitattributes
CHANGED
@@ -34,3 +34,48 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
 qwen1.5-1.8b_int4_1dev.bmodel filter=lfs diff=lfs merge=lfs -text
+Baichuan2/src/lib_pcie/libbmrt.so filter=lfs diff=lfs merge=lfs -text
+Baichuan2/src/lib_pcie/libbmrt.so.1.0 filter=lfs diff=lfs merge=lfs -text
+Baichuan2/src/lib_pcie/libsentencepiece.a filter=lfs diff=lfs merge=lfs -text
+Baichuan2/src/lib_soc/libbmrt.so filter=lfs diff=lfs merge=lfs -text
+Baichuan2/src/lib_soc/libbmrt.so.1.0 filter=lfs diff=lfs merge=lfs -text
+Baichuan2/src/lib_soc/libsentencepiece.a filter=lfs diff=lfs merge=lfs -text
+ChatGLM2/support/lib_pcie/libbmrt.so filter=lfs diff=lfs merge=lfs -text
+ChatGLM2/support/lib_pcie/libbmrt.so.1.0 filter=lfs diff=lfs merge=lfs -text
+ChatGLM2/support/lib_pcie/libsentencepiece.a filter=lfs diff=lfs merge=lfs -text
+ChatGLM2/support/lib_soc/libbmrt.so filter=lfs diff=lfs merge=lfs -text
+ChatGLM2/support/lib_soc/libbmrt.so.1.0 filter=lfs diff=lfs merge=lfs -text
+ChatGLM2/support/lib_soc/libsentencepiece.a filter=lfs diff=lfs merge=lfs -text
+ChatGLM3/support/lib_pcie/libbmrt.so filter=lfs diff=lfs merge=lfs -text
+ChatGLM3/support/lib_pcie/libbmrt.so.1.0 filter=lfs diff=lfs merge=lfs -text
+ChatGLM3/support/lib_pcie/libsentencepiece.a filter=lfs diff=lfs merge=lfs -text
+ChatGLM3/support/lib_soc/libbmrt.so filter=lfs diff=lfs merge=lfs -text
+ChatGLM3/support/lib_soc/libbmrt.so.1.0 filter=lfs diff=lfs merge=lfs -text
+ChatGLM3/support/lib_soc/libsentencepiece.a filter=lfs diff=lfs merge=lfs -text
+DeepSeek/requirements/sophon-3.7.0-py3-none-any.whl filter=lfs diff=lfs merge=lfs -text
+LWM/support/lib_pcie/libbmrt.so filter=lfs diff=lfs merge=lfs -text
+LWM/support/lib_pcie/libbmrt.so.1.0 filter=lfs diff=lfs merge=lfs -text
+LWM/support/lib_pcie/libsentencepiece.a filter=lfs diff=lfs merge=lfs -text
+LWM/support/lib_soc/libbmrt.so filter=lfs diff=lfs merge=lfs -text
+LWM/support/lib_soc/libbmrt.so.1.0 filter=lfs diff=lfs merge=lfs -text
+LWM/support/lib_soc/libsentencepiece.a filter=lfs diff=lfs merge=lfs -text
+Llama2/assets/llama2_pcie filter=lfs diff=lfs merge=lfs -text
+Llama2/demo_parallel/lib/libsentencepiece.a filter=lfs diff=lfs merge=lfs -text
+Llama2/support/lib_pcie/libbmrt.so.1.0 filter=lfs diff=lfs merge=lfs -text
+Llama2/support/lib_pcie/libsentencepiece.a filter=lfs diff=lfs merge=lfs -text
+Llama2/support/lib_soc/libbmrt.so.1.0 filter=lfs diff=lfs merge=lfs -text
+Llama2/support/lib_soc/libsentencepiece.a filter=lfs diff=lfs merge=lfs -text
+Llama3/python_demo/build/CMakeFiles/chat.dir/chat.cpp.o filter=lfs diff=lfs merge=lfs -text
+Llama3/python_demo/build/chat.cpython-38-aarch64-linux-gnu.so filter=lfs diff=lfs merge=lfs -text
+Llama3/python_demo/chat.cpython-38-aarch64-linux-gnu.so filter=lfs diff=lfs merge=lfs -text
+Qwen1_5/python_demo/build/CMakeFiles/chat.dir/chat.cpp.o filter=lfs diff=lfs merge=lfs -text
+Qwen1_5/python_demo/build/CMakeFiles/chat_parallel.dir/chat_parallel.cpp.o filter=lfs diff=lfs merge=lfs -text
+Qwen1_5/python_demo/build/chat.cpython-38-aarch64-linux-gnu.so filter=lfs diff=lfs merge=lfs -text
+Qwen1_5/python_demo/build/chat_parallel.cpython-38-aarch64-linux-gnu.so filter=lfs diff=lfs merge=lfs -text
+Qwen1_5/python_demo/chat.cpython-38-aarch64-linux-gnu.so filter=lfs diff=lfs merge=lfs -text
+Qwen1_5/python_demo/chat_parallel.cpython-38-aarch64-linux-gnu.so filter=lfs diff=lfs merge=lfs -text
+WizardCoder/demo/lib_pcie/lib/libbmrt.so filter=lfs diff=lfs merge=lfs -text
+WizardCoder/demo/lib_pcie/lib/libbmrt.so.1.0 filter=lfs diff=lfs merge=lfs -text
+WizardCoder/demo/lib_soc/lib/libbmrt.so filter=lfs diff=lfs merge=lfs -text
+WizardCoder/demo/lib_soc/lib/libbmrt.so.1.0 filter=lfs diff=lfs merge=lfs -text
+Yi34B/demo_parallel/lib/libsentencepiece.a filter=lfs diff=lfs merge=lfs -text
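The rules added above route the listed binaries through Git LFS. As a minimal sketch (not part of the commit), a path can be checked against a hand-copied subset of these patterns; `fnmatch` is only an approximation of gitattributes glob semantics, since it does not treat `/` specially:

```python
# Check whether a path matches one of the LFS patterns added above.
# The pattern list is a hand-copied subset, not read from .gitattributes.
from fnmatch import fnmatch

LFS_PATTERNS = [
    "*.zst",
    "*tfevents*",
    "Baichuan2/src/lib_pcie/libbmrt.so",
    "Baichuan2/src/lib_soc/libsentencepiece.a",
]

def is_lfs_tracked(path: str) -> bool:
    return any(fnmatch(path, pattern) for pattern in LFS_PATTERNS)

print(is_lfs_tracked("Baichuan2/src/lib_pcie/libbmrt.so"))  # True
print(is_lfs_tracked("Baichuan2/demo/demo.cpp"))            # False
```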
Baichuan2/README.md
ADDED
@@ -0,0 +1,182 @@
![image](../../assets/sophgo_chip.png)

# Baichuan2-TPU

This project deploys the large language model [Baichuan2-7B](https://huggingface.co/baichuan-inc/Baichuan2-7B-Chat) on the BM1684X. The model is converted into a bmodel with the [TPU-MLIR](https://github.com/sophgo/tpu-mlir) compiler and then deployed with C++ code to a BM1684X PCIe environment or SoC environment.

The text below assumes a PCIe environment by default; for a SoC environment, follow the hints where noted.

# Directory layout
```
.
├── README.md                # usage instructions
├── requirements.txt         # required Python wheel packages
├── compile
│   ├── compile.sh           # script that compiles the TPU model
│   ├── export_onnx.py       # script that exports ONNX
│   ├── torch_inference.py   # torch inference script
│   └── files
│       └── Baichuan2-7B     # backup of the files that replace their Baichuan2-7B-Chat counterparts
│           ├── config.json
│           └── modeling_baichuan.py
├── demo                     # Baichuan2 C++ sources
│   ├── CMakeLists.txt
│   └── demo.cpp             # main program
├── src                      # build dependencies
│   ├── include
│   ├── lib_pcie
│   └── lib_soc
├── model                    # model files (the bmodel must be downloaded)
│   ├── baichuan2-7b-test_int8.bmodel
│   └── tokenizer.model
└── web_demo                 # web demo, a browser chat example
    ├── chat.cpp
    ├── chat.py
    ├── CMakeLists.txt
    └── web_demo.py
```
----------------------------

# [Stage 1] Model compilation

## Notes
* Model compilation must be done inside the docker container; it cannot be done outside of docker.

### Step 1: Download the model
The Baichuan2 model is fully open source on Hugging Face. Follow the official download steps to fetch the model and its weights.
```bash
# Make sure you have git-lfs installed (https://git-lfs.com)
git lfs install
git clone https://huggingface.co/baichuan-inc/Baichuan2-7B-Chat
```

### Step 2: Download docker

Pull the docker image and start a container as follows:

``` shell
docker pull sophgo/tpuc_dev:latest

# myname1234 is just an example, you can set your own name
docker run --privileged --name myname1234 -v $PWD:/workspace -it sophgo/tpuc_dev:latest
```

### Step 3: Download and build the TPU-MLIR code

``` shell
git clone git@github.com:sophgo/tpu-mlir.git
cd tpu-mlir
source ./envsetup.sh
./build.sh
```
* PS: Whenever you re-enter the docker environment and need to compile a model, you must run `source ./envsetup.sh` and `./build.sh` from this directory again before any further model compilation.

### Step 4: Download this project and install requirements.txt
Download transformers, sentencepiece, Baichuan2-TPU, and the .bin model from the Baidu Netdisk share, and replace modeling_baichuan.py inside transformers.

``` shell
git clone https://github.com/sophgo/Baichuan2-TPU.git
cd Baichuan2
pip3 install -r requirements.txt
```

### Step 5: Replace modeling_baichuan.py, modify config.json, and generate the ONNX files
In the config.json of the Baichuan2-7B-Chat project, change max_position_embeddings and model_max_length from 4096 to 512.

``` shell
cd compile
cp files/Baichuan2-7B/modeling_baichuan.py $BAICHUAN2_PATH
cp files/Baichuan2-7B/config.json $BAICHUAN2_PATH
python3 export_onnx.py --model_path $BAICHUAN2_PATH
```

* PS1: `$BAICHUAN2_PATH` is the path where the original model was downloaded, e.g. "../../torch2onnx/Baichuan2-7B-Chat"; you can choose the 7B or the 13B model as needed.
* PS2: If you want to debug rather than generate all the ONNX models in one go, you can change num_layers (line 240) to 1 and check whether a single block matches the torch results (see the sketch below).

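A minimal sketch of the single-block comparison mentioned in PS2, assuming the export wrote `block_0.onnx` into `./tmp` (the default folder in `export_onnx.py`) and that the torch reference is computed in a context where that script's `Block` wrapper is importable; the tolerances are illustrative:

```python
# Single-block sanity check: run the exported block_0.onnx with onnxruntime
# and compare against the torch Block(0) from export_onnx.py. SEQ_LENGTH and
# HIDDEN_SIZE match the 7B defaults; the onnx path is an assumption.
import numpy as np
import onnxruntime as ort
import torch

SEQ_LENGTH, HIDDEN_SIZE = 512, 4096

hidden_states = torch.randn(1, SEQ_LENGTH, HIDDEN_SIZE)
position_ids = torch.tensor([range(SEQ_LENGTH)], dtype=torch.long)
attention_mask = torch.randn(1, 1, SEQ_LENGTH, SEQ_LENGTH)

# torch reference (run inside export_onnx.py, where Block is defined):
# torch_out, _, _ = Block(0)(hidden_states, position_ids, attention_mask)

session = ort.InferenceSession("./tmp/block_0.onnx")
onnx_out = session.run(
    ["hidden_states"],
    {
        "input_states": hidden_states.numpy(),
        "position_ids": position_ids.numpy(),
        "attention_mask": attention_mask.numpy(),
    },
)[0]
print(onnx_out.shape)  # expected (1, 512, 4096)
# np.testing.assert_allclose(torch_out.numpy(), onnx_out, rtol=1e-2, atol=1e-2)
```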
### Step 6: Generate the bmodel file

Generate the model:

``` shell
./compile.sh --mode int8
mv baichuan2-7b_int8_1dev.bmodel ../model
```

* PS1: After compilation, a file named baichuan2-{X}b_{Y}_{Z}dev.bmodel is produced under the Baichuan2-TPU/compile path, where X is 7 or 13, Y is the data type selected via the `mode` option of `compile.sh`, and Z is the number of chips used for inference (if num_device is not specified, the {Z}dev part is omitted).
* PS2: Generating the bmodel takes roughly 3 hours or more. 64 GB of RAM and more than 200 GB of disk space are recommended, otherwise OOM or "no space left" errors are very likely.
* PS3: The provided lib_pcie and lib_soc currently contain only the single-chip dynamic libraries; the multi-chip parts will be added in a later update.

----------------------------

# Stage 2: Build the executable (can be skipped)

## Preparation
* bmodel preparation: Stage 1 yields the compiled bmodel file (you can also use the precompiled bmodel we provide), which can be downloaded as follows:
```shell
cd Baichuan2-TPU/model
pip3 install dfss
# baichuan2-7B
python3 -m dfss --url=open@sophgo.com:sophon-demo/baichuan2/baichuan2-7b-test_int8.bmodel
```
This yields the compiled single-chip int8 bmodel file.

## Building the program (C++ version)

Build as follows; the default is the PCIe version:

```shell
cd Baichuan2-TPU/demo
mkdir build
cd build
cmake ..
make
```

For the SoC version there are two ways to build:

Method 1: copy the demo directory to the SoC environment and build it with the steps above (recommended).

Method 2: cross-compile in docker as follows:

```shell
wget https://releases.linaro.org/components/toolchain/binaries/7.5-2019.12/aarch64-linux-gnu/gcc-linaro-7.5.0-2019.12-x86_64_aarch64-linux-gnu.tar.xz
tar -xvf gcc-linaro-7.5.0-2019.12-x86_64_aarch64-linux-gnu.tar.xz
mv gcc-linaro-7.5.0-2019.12-x86_64_aarch64-linux-gnu /opt/aarch64-linux-gnu-7.5.0
cd Baichuan2-TPU/demo
mkdir build
cd build
cmake .. -DTARGET_ARCH=soc # the SoC has a single chip, so multi-chip builds are not supported
make -j
```

This produces the Baichuan2 executable.

Run `baichuan2`:
```shell
./baichuan2 --model ../model/baichuan2-7b-test_int8.bmodel --dev dev_id
```

## Building the program (Python web version) [single chip]

```shell
pip3 install gradio==3.39.0
cd Baichuan2-TPU/web_demo
mkdir build
cd build
cmake ..
make -j
```

A successful build produces `libtpuchat.so*` under the `build` folder. You can then set bmodel\_path, token\_path, dev\_id, and lib\_path (the `libtpuchat.so*` file produced by the build; the default path is under `./build`) in web_demo.py.
```shell
python3 web_demo.py
```
then runs the web demo.
* PS: As long as you do not move the token\_path and lib\_path locations above, specifying bmodel\_path alone is enough to run the program.

For a SoC environment, refer to the C++ version.

* PS: Make sure to install gradio==3.39.0; other versions cause all kinds of problems!

# FAQ
* Adjust NUM_LAYERS in `demo/chat` or in `web_demo/chat.cpp` according to the actual number of blocks; the default targets Baichuan2-7B (NUM_LAYERS=32). The sketch below reads this value from the model's config.json.
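As referenced in the FAQ item above, the block count is also recorded in the downloaded checkpoint's config.json; a minimal sketch for reading it, where the checkout path is an assumption:

```python
# Read num_hidden_layers from the downloaded checkpoint so NUM_LAYERS in
# demo or web_demo/chat.cpp can be set to match (32 for Baichuan2-7B,
# 40 for Baichuan2-13B).
import json

with open("Baichuan2-7B-Chat/config.json") as f:  # assumed checkout location
    cfg = json.load(f)

print("NUM_LAYERS =", cfg["num_hidden_layers"])
```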
Baichuan2/compile/compile.sh
ADDED
@@ -0,0 +1,186 @@
#!/bin/bash
set -ex
models=
mode="f16"
folder="tmp"
num_device=1
mode_args=""
device_args=""
quantize_args="--quantize F16"
name=""
num_layers=
out_model=$name.bmodel

if [ -z "$name" ]; then
    name="baichuan2-7b"
    echo "Compile Baichuan2-7B"
else
    name="baichuan2-13b"
    echo "Compile Baichuan2-13B"
fi

while [[ $# -gt 0 ]]; do
    key="$1"

    case $key in
    --mode)
        mode="$2"
        shift 2
        ;;
    --num_device)
        num_device="$2"
        shift 2
        ;;
    --name)
        name="$2"
        shift 2
        ;;
    *)
        echo "Invalid option: $key" >&2
        exit 1
        ;;
    :)
        echo "Option -$OPTARG requires an argument." >&2
        exit 1
        ;;
    esac
done

if [ x$mode == x"int8" ] || [ x$mode == x"int4" ]; then
    if [ x$mode == x"int8" ]; then
        quantize_args="--quantize W8F16"
    else
        quantize_args="--quantize W4BF16 --q_group_size 64"
    fi
    out_model=$name'_'$mode'.bmodel'
fi

if [ x$name == x"baichuan2-7b" ] || [ x$name == x"baichuan2-13b" ]; then
    if [ x$name == x"baichuan2-7b" ]; then
        num_layers=32
    else
        num_layers=40
    fi
fi

if [ x$num_device != x1 ]; then
    device_args="--num_device $num_device"
    out_model=$name'_'$mode'_'$num_device'dev.bmodel'
else
    out_model=$name'_'$mode'_1dev.bmodel'
fi

outdir=${folder}/embedding
mkdir -p $outdir
pushd $outdir

seqlen=512
model_transform.py \
    --model_name embedding \
    --model_def ../embedding.onnx \
    --input_shapes [[1,$seqlen]] \
    --mlir embedding_${seqlen}.mlir


model_deploy.py \
    --mlir embedding_$seqlen.mlir \
    --quantize F16 \
    --chip bm1684x \
    $device_args \
    --model embedding_${seqlen}_f16.bmodel

model_transform.py \
    --model_name embedding_cache \
    --model_def ../embedding.onnx \
    --input_shapes [[1,1]] \
    --mlir embedding_1.mlir


model_deploy.py \
    --mlir embedding_1.mlir \
    --quantize F16 \
    --chip bm1684x \
    $device_args \
    --model embedding_1_f16.bmodel

rm *.npz

models=$models' '$outdir'/embedding_1_f16.bmodel '$outdir'/embedding_'$seqlen'_f16.bmodel '

popd

echo $models

outdir=${folder}/$mode"_"$num_device"dev"/lm_head
mkdir -p $outdir
pushd $outdir

model_transform.py \
    --model_name lm_head \
    --model_def ../../lm_head.onnx \
    --mlir lm_head.mlir


model_deploy.py \
    --mlir lm_head.mlir \
    --quantize F16 \
    --chip bm1684x \
    --model lm_head.bmodel

rm *.npz

models=${models}${outdir}'/lm_head.bmodel '
popd

echo $models

outdir=${folder}/$mode"_"$num_device"dev"/block
mkdir -p $outdir

pushd $outdir
mkdir -p $outdir

for ((i=0; i<$num_layers; i++))
do

    model_transform.py \
        --model_name block_$i \
        --model_def ../../block_$i.onnx \
        --mlir block_$i.mlir

    model_deploy.py \
        --mlir block_$i.mlir \
        $quantize_args \
        --chip bm1684x \
        --quant_output \
        --quant_output_list 2,3 \
        $device_args \
        --model block_$i.bmodel

    model_transform.py \
        --model_name block_cache_$i \
        --model_def ../../block_cache_${i}.onnx \
        --mlir block_cache_$i.mlir

    model_deploy.py \
        --mlir block_cache_$i.mlir \
        $quantize_args \
        --chip bm1684x \
        --quant_input \
        --quant_output \
        --quant_input_list 4,5 \
        --quant_output_list 2,3 \
        $device_args \
        --model block_cache_$i.bmodel

    rm *.npz
    # rm ../../block_$i.onnx
    # rm ../../block_cache_$i.onnx

    models=${models}${outdir}'/block_'$i'.bmodel '$outdir'/block_cache_'$i'.bmodel '

done
popd
echo $models

model_tool --combine $models -o $out_model
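For orientation, the `out_model` naming that the script builds up through its branches can be mirrored in a few lines; a sketch under the script's own conventions, not part of the commit:

```python
# Mirror compile.sh's final output naming: {name}_{mode}_{num_device}dev.bmodel.
# With num_device=1 the script still appends the _1dev suffix.
def out_model_name(name: str = "baichuan2-7b", mode: str = "f16", num_device: int = 1) -> str:
    return f"{name}_{mode}_{num_device}dev.bmodel"

assert out_model_name(mode="int8") == "baichuan2-7b_int8_1dev.bmodel"
assert out_model_name("baichuan2-13b", "int4", 2) == "baichuan2-13b_int4_2dev.bmodel"
```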
Baichuan2/compile/export_onnx.py
ADDED
@@ -0,0 +1,182 @@
#!/usr/bin/env python3
# ==============================================================================
#
# Copyright (C) 2023 Sophgo Technologies Inc. All rights reserved.
#
# TPU-MLIR is licensed under the 2-Clause BSD License except for the
# third-party components.
#
# ==============================================================================

import os
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.generation.utils import GenerationConfig
import numpy as np
import argparse

folder = f"./tmp/onnx"
parser = argparse.ArgumentParser(description='export onnx.')
parser.add_argument('--model_path', type=str, help='path to the torch model.')
parser.add_argument('--seq_length', type=int, default=512, help="sequence length")

args = parser.parse_args()

model_path = args.model_path
folder = "./tmp"  # note: overrides the earlier "./tmp/onnx" value; exports land in ./tmp

origin_model = AutoModelForCausalLM.from_pretrained(model_path, trust_remote_code=True).eval()
origin_model.generation_config = GenerationConfig.from_pretrained(model_path)
config = origin_model.config
transformer = origin_model.model
layers = transformer.layers

SEQ_LENGTH = args.seq_length
NUM_LAYERS = config.num_hidden_layers
HIDDEN_SIZE = config.hidden_size
NUM_ATTENTION_HEADS = config.num_attention_heads
HEAD_DIM = HIDDEN_SIZE // NUM_ATTENTION_HEADS

for param in origin_model.parameters():
    param.requires_grad = False

tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False, trust_remote_code=True)


class Embedding(torch.nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, input_ids):
        return transformer.embed_tokens(input_ids)


class Block(torch.nn.Module):

    def __init__(self, layer_id):
        super().__init__()
        # params
        self.layer_id = layer_id
        self.layer = layers[layer_id]

    def forward(self, hidden_states, position_ids, attention_mask):
        hidden_states, past_kv = self.layer(hidden_states,
                                            attention_mask,
                                            position_ids,
                                            use_cache=True)
        present_k, present_v = past_kv
        return hidden_states, present_k, present_v


class BlockCache(torch.nn.Module):

    def __init__(self, layer_id):
        super().__init__()
        # params
        self.layer_id = layer_id
        self.layer = layers[layer_id]

    def forward(self, hidden_states, position_ids, attention_mask, past_k,
                past_v):
        hidden_states, past_kv = self.layer(hidden_states,
                                            attention_mask,
                                            position_ids=position_ids,
                                            past_key_value=(past_k, past_v),
                                            use_cache=True)
        present_k, present_v = past_kv
        return hidden_states, present_k, present_v


class LmHead(torch.nn.Module):

    def __init__(self):
        super().__init__()

    def forward(self, hidden_states):
        hidden_states = transformer.norm(hidden_states)
        m_logits = origin_model.lm_head(hidden_states)
        _, token = torch.topk(m_logits, 1)
        return token


def convert_block(layer_id):
    # input
    hidden_states = torch.randn((1, SEQ_LENGTH, HIDDEN_SIZE))
    position_ids = torch.tensor([range(SEQ_LENGTH)], dtype=torch.long)
    attention_mask = torch.randn((1, 1, SEQ_LENGTH, SEQ_LENGTH))
    model = Block(layer_id)

    torch.onnx.export(
        model, (hidden_states, position_ids, attention_mask),
        f'{folder}/block_{layer_id}.onnx',
        verbose=False,
        input_names=['input_states', 'position_ids', 'attention_mask'],
        output_names=['hidden_states', 'past_k', 'past_v'],
        do_constant_folding=True,
        opset_version=15)


def convert_block_cache(layer_id):
    # input
    np.random.seed(42)
    hidden_states = torch.randn((1, 1, HIDDEN_SIZE))
    position_ids = torch.tensor([range(1)], dtype=torch.long)
    attention_mask = torch.randn((1, 1, 1, SEQ_LENGTH + 1))
    past_k = torch.randn((1, SEQ_LENGTH, NUM_ATTENTION_HEADS, HEAD_DIM))
    past_v = torch.randn((1, SEQ_LENGTH, NUM_ATTENTION_HEADS, HEAD_DIM))
    model = BlockCache(layer_id)

    torch.onnx.export(
        model, (hidden_states, position_ids, attention_mask, past_k, past_v),
        f'{folder}/block_cache_{layer_id}.onnx',
        verbose=False,
        input_names=[
            'input_states', 'position_ids', 'attention_mask', 'history_k',
            'history_v'
        ],
        output_names=['hidden_states', 'past_k', 'past_v'],
        do_constant_folding=True,
        opset_version=15)


def convert_embedding():
    model = Embedding()
    input = torch.tensor([range(SEQ_LENGTH)])
    torch.onnx.export(model, (input),
                      f'{folder}/embedding.onnx',
                      verbose=False,
                      input_names=['input_ids'],
                      output_names=['input_embed'],
                      dynamic_axes={"input_ids": {
                          0: "length"
                      }},
                      do_constant_folding=True,
                      opset_version=15)


def convert_lm_head():
    model = LmHead()
    input = torch.randn(1, HIDDEN_SIZE)
    torch.onnx.export(model, (input),
                      f'{folder}/lm_head.onnx',
                      verbose=False,
                      input_names=['hidden_states'],
                      output_names=['token'],
                      do_constant_folding=True,
                      opset_version=15)

# create folder to store onnx
if not os.path.exists(folder):
    os.makedirs(folder)

# export models
for i in range(NUM_LAYERS):
    print("convert_block_{}".format(i))
    convert_block_cache(i)
    convert_block(i)

print("convert_embedding")
convert_embedding()

print("convert_lm_head")
convert_lm_head()
Baichuan2/compile/files/Baichuan2-7B/config.json
ADDED
@@ -0,0 +1,29 @@
{
  "architectures": [
    "BaichuanForCausalLM"
  ],
  "auto_map": {
    "AutoConfig": "configuration_baichuan.BaichuanConfig",
    "AutoModelForCausalLM": "modeling_baichuan.BaichuanForCausalLM"
  },
  "tokenizer_class": "BaichuanTokenizer",
  "bos_token_id": 1,
  "eos_token_id": 2,
  "hidden_act": "silu",
  "hidden_size": 4096,
  "initializer_range": 0.02,
  "intermediate_size": 11008,
  "max_position_embeddings": 4096,
  "model_max_length": 4096,
  "model_type": "baichuan",
  "num_attention_heads": 32,
  "num_hidden_layers": 32,
  "pad_token_id": 0,
  "rms_norm_eps": 1e-06,
  "_from_model_config": true,
  "tie_word_embeddings": false,
  "torch_dtype": "bfloat16",
  "transformers_version": "4.29.2",
  "use_cache": true,
  "vocab_size": 125696
}
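Step 5 of the README asks to change max_position_embeddings and model_max_length from 4096 (as shown above) to 512 before exporting ONNX; a minimal sketch that applies that edit, with the checkout path as an assumption:

```python
# Patch the downloaded model's config.json as README step 5 instructs:
# max_position_embeddings and model_max_length go from 4096 to 512.
import json

path = "Baichuan2-7B-Chat/config.json"  # assumed checkout location
with open(path) as f:
    cfg = json.load(f)

cfg["max_position_embeddings"] = 512
cfg["model_max_length"] = 512

with open(path, "w") as f:
    json.dump(cfg, f, indent=2, ensure_ascii=False)
```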
Baichuan2/compile/files/Baichuan2-7B/modeling_baichuan.py
ADDED
@@ -0,0 +1,792 @@
1 |
+
# Copyright 2023 Baichuan Inc. All Rights Reserved.
|
2 |
+
|
3 |
+
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
|
4 |
+
#
|
5 |
+
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
|
6 |
+
# and OPT implementations in this library. It has been modified from its
|
7 |
+
# original forms to accommodate minor architectural differences compared
|
8 |
+
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
|
9 |
+
#
|
10 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
11 |
+
# you may not use this file except in compliance with the License.
|
12 |
+
# You may obtain a copy of the License at
|
13 |
+
#
|
14 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
15 |
+
#
|
16 |
+
# Unless required by applicable law or agreed to in writing, software
|
17 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
18 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
19 |
+
# See the License for the specific language governing permissions and
|
20 |
+
# limitations under the License.
|
21 |
+
|
22 |
+
|
23 |
+
from .configuration_baichuan import BaichuanConfig
|
24 |
+
from .generation_utils import build_chat_input, TextIterStreamer
|
25 |
+
|
26 |
+
import math
|
27 |
+
from typing import List, Optional, Tuple, Union
|
28 |
+
from threading import Thread
|
29 |
+
|
30 |
+
import torch
|
31 |
+
import torch.utils.checkpoint
|
32 |
+
from torch import nn
|
33 |
+
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
|
34 |
+
from torch.nn import functional as F
|
35 |
+
from transformers import PreTrainedModel, PretrainedConfig
|
36 |
+
from transformers.activations import ACT2FN
|
37 |
+
from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
|
38 |
+
from transformers.generation.utils import GenerationConfig
|
39 |
+
from transformers.utils import logging, ContextManagers
|
40 |
+
|
41 |
+
import os
|
42 |
+
from contextlib import contextmanager
|
43 |
+
logger = logging.get_logger(__name__)
|
44 |
+
|
45 |
+
try:
|
46 |
+
from xformers import ops as xops
|
47 |
+
except ImportError:
|
48 |
+
xops = None
|
49 |
+
logger.warning(
|
50 |
+
"Xformers is not installed correctly. If you want to use memory_efficient_attention to accelerate training use the following command to install Xformers\npip install xformers."
|
51 |
+
)
|
52 |
+
|
53 |
+
|
54 |
+
# Copied from transformers.models.bart.modeling_bart._make_causal_mask
|
55 |
+
def _make_causal_mask(
|
56 |
+
input_ids_shape: torch.Size, dtype: torch.dtype, device: torch.device, past_key_values_length: int = 0
|
57 |
+
):
|
58 |
+
"""
|
59 |
+
Make causal mask used for bi-directional self-attention.
|
60 |
+
"""
|
61 |
+
bsz, tgt_len = input_ids_shape
|
62 |
+
mask = torch.full((tgt_len, tgt_len), torch.tensor(torch.finfo(dtype).min, device=device), device=device)
|
63 |
+
mask_cond = torch.arange(mask.size(-1), device=device)
|
64 |
+
mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
|
65 |
+
mask = mask.to(dtype)
|
66 |
+
|
67 |
+
if past_key_values_length > 0:
|
68 |
+
mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device), mask], dim=-1)
|
69 |
+
return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length)
|
70 |
+
|
71 |
+
def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
|
72 |
+
"""
|
73 |
+
Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
|
74 |
+
"""
|
75 |
+
if len(mask.size()) == 3:
|
76 |
+
bsz, src_len, _ = mask.size()
|
77 |
+
tgt_len = tgt_len if tgt_len is not None else src_len
|
78 |
+
expanded_mask = mask[:,None,:,:].expand(bsz, 1, tgt_len, src_len).to(dtype)
|
79 |
+
else:
|
80 |
+
bsz, src_len = mask.size()
|
81 |
+
tgt_len = tgt_len if tgt_len is not None else src_len
|
82 |
+
expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)
|
83 |
+
|
84 |
+
inverted_mask = 1.0 - expanded_mask
|
85 |
+
|
86 |
+
return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)
|
87 |
+
|
88 |
+
|
89 |
+
class RMSNorm(nn.Module):
|
90 |
+
def __init__(self, hidden_size, eps=1e-6):
|
91 |
+
"""
|
92 |
+
RMSNorm is equivalent to T5LayerNorm
|
93 |
+
"""
|
94 |
+
super().__init__()
|
95 |
+
self.weight = nn.Parameter(torch.ones(hidden_size))
|
96 |
+
self.variance_epsilon = eps
|
97 |
+
|
98 |
+
def forward(self, hidden_states):
|
99 |
+
variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
|
100 |
+
hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
|
101 |
+
|
102 |
+
# convert into half-precision if necessary
|
103 |
+
if self.weight.dtype in [torch.float16, torch.bfloat16]:
|
104 |
+
hidden_states = hidden_states.to(self.weight.dtype)
|
105 |
+
|
106 |
+
return self.weight * hidden_states
|
107 |
+
|
108 |
+
|
109 |
+
class RotaryEmbedding(torch.nn.Module):
|
110 |
+
def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
|
111 |
+
super().__init__()
|
112 |
+
self.inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float().to(device) / dim))
|
113 |
+
self.max_seq_len_cached = max_position_embeddings
|
114 |
+
t = torch.arange(self.max_seq_len_cached, device=self.inv_freq.device, dtype=torch.float32)
|
115 |
+
freqs = torch.outer(t, self.inv_freq)
|
116 |
+
emb = torch.cat((freqs, freqs), dim=-1)
|
117 |
+
self.cos_cached = emb.cos()[None, None, :, :].to(torch.float32)
|
118 |
+
self.sin_cached = emb.sin()[None, None, :, :].to(torch.float32)
|
119 |
+
def forward(self, x, seq_len=None):
|
120 |
+
# x: [bs, num_attention_heads, seq_len, head_size]
|
121 |
+
# This `if` block is unlikely to be run after we build sin/cos in `__init__`. Keep the logic here just in case.
|
122 |
+
if seq_len > self.max_seq_len_cached:
|
123 |
+
self.max_seq_len_cached = seq_len
|
124 |
+
t = torch.arange(self.max_seq_len_cached, device=self.inv_freq.device, dtype=torch.float32)
|
125 |
+
freqs = torch.outer(t, self.inv_freq)
|
126 |
+
emb = torch.cat((freqs, freqs), dim=-1)
|
127 |
+
self.cos_cached = emb.cos()[None, None, :, :].to(torch.float32).to(x.device)
|
128 |
+
self.sin_cached = emb.sin()[None, None, :, :].to(torch.float32).to(x.device)
|
129 |
+
elif self.cos_cached.device != x.device:
|
130 |
+
self.cos_cached = self.cos_cached.to(x.device)
|
131 |
+
self.sin_cached = self.sin_cached.to(x.device)
|
132 |
+
return (
|
133 |
+
self.cos_cached[:, :, :seq_len, ...],
|
134 |
+
self.sin_cached[:, :, :seq_len, ...],
|
135 |
+
)
|
136 |
+
|
137 |
+
|
138 |
+
def rotate_half(x):
|
139 |
+
"""Rotates half the hidden dims of the input."""
|
140 |
+
x1 = x[..., : x.shape[-1] // 2]
|
141 |
+
x2 = x[..., x.shape[-1] // 2:]
|
142 |
+
return torch.cat((-x2, x1), dim=-1)
|
143 |
+
|
144 |
+
|
145 |
+
def apply_rotary_pos_emb(q, k, cos_, sin_, position_ids):
|
146 |
+
cos = cos_.squeeze(1).squeeze(0) # [seq_len, dim]
|
147 |
+
sin = sin_.squeeze(1).squeeze(0) # [seq_len, dim]
|
148 |
+
cos = cos[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim]
|
149 |
+
sin = sin[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim]
|
150 |
+
cos = cos.transpose(1, 2)
|
151 |
+
sin = sin.transpose(1, 2)
|
152 |
+
q_embed = (q.float() * cos) + (rotate_half(q.float()) * sin)
|
153 |
+
k_embed = (k.float() * cos) + (rotate_half(k.float()) * sin)
|
154 |
+
return q_embed.to(q.dtype), k_embed.to(k.dtype)
|
155 |
+
|
156 |
+
|
157 |
+
class MLP(nn.Module):
|
158 |
+
def __init__(
|
159 |
+
self,
|
160 |
+
hidden_size: int,
|
161 |
+
intermediate_size: int,
|
162 |
+
hidden_act: str,
|
163 |
+
):
|
164 |
+
super().__init__()
|
165 |
+
self.gate_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
|
166 |
+
self.down_proj = nn.Linear(intermediate_size, hidden_size, bias=False)
|
167 |
+
self.up_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
|
168 |
+
self.act_fn = ACT2FN[hidden_act]
|
169 |
+
|
170 |
+
def forward(self, x):
|
171 |
+
return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
|
172 |
+
|
173 |
+
|
174 |
+
class Attention(nn.Module):
|
175 |
+
"""Multi-headed attention from 'Attention Is All You Need' paper"""
|
176 |
+
def __init__(self, config: BaichuanConfig):
|
177 |
+
super().__init__()
|
178 |
+
self.config = config
|
179 |
+
self.hidden_size = config.hidden_size
|
180 |
+
self.num_heads = config.num_attention_heads
|
181 |
+
self.head_dim = self.hidden_size // self.num_heads
|
182 |
+
self.max_position_embeddings = config.max_position_embeddings
|
183 |
+
|
184 |
+
if (self.head_dim * self.num_heads) != self.hidden_size:
|
185 |
+
raise ValueError(
|
186 |
+
f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
|
187 |
+
f" and `num_heads`: {self.num_heads})."
|
188 |
+
)
|
189 |
+
self.W_pack = nn.Linear(self.hidden_size, 3 * self.hidden_size, bias=False)
|
190 |
+
self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)
|
191 |
+
self.rotary_emb = RotaryEmbedding(self.head_dim, max_position_embeddings=self.max_position_embeddings)
|
192 |
+
|
193 |
+
def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
|
194 |
+
return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
|
195 |
+
|
196 |
+
def forward(
|
197 |
+
self,
|
198 |
+
hidden_states: torch.Tensor,
|
199 |
+
attention_mask: Optional[torch.Tensor] = None,
|
200 |
+
position_ids: Optional[torch.LongTensor] = None,
|
201 |
+
past_key_value: Optional[Tuple[torch.Tensor]] = None,
|
202 |
+
output_attentions: bool = False,
|
203 |
+
use_cache: bool = False,
|
204 |
+
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
205 |
+
bsz, q_len, _ = hidden_states.size()
|
206 |
+
|
207 |
+
proj = self.W_pack(hidden_states)
|
208 |
+
proj = proj.unflatten(-1, (3, self.hidden_size)).unsqueeze(0).transpose(0, -2).squeeze(-2)
|
209 |
+
query_states = proj[0].view(bsz, q_len, self.num_heads, self.head_dim)
|
210 |
+
key_states = proj[1].view(bsz, q_len, self.num_heads, self.head_dim)
|
211 |
+
value_states = proj[2].view(bsz, q_len, self.num_heads, self.head_dim)
|
212 |
+
|
213 |
+
kv_seq_len = key_states.shape[-3]
|
214 |
+
if past_key_value is not None:
|
215 |
+
kv_seq_len = kv_seq_len + past_key_value[0].shape[-3]
|
216 |
+
if past_key_value is not None:
|
217 |
+
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len-1)
|
218 |
+
else:
|
219 |
+
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
|
220 |
+
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
|
221 |
+
# [bsz, nh, t, hd]
|
222 |
+
past_kv = (key_states, value_states) if use_cache else None
|
223 |
+
if past_key_value is not None:
|
224 |
+
# reuse k, v, self_attention
|
225 |
+
key_states = torch.cat([past_key_value[0], key_states], dim=1)
|
226 |
+
value_states = torch.cat([past_key_value[1], value_states], dim=1)
|
227 |
+
|
228 |
+
|
229 |
+
if xops is not None and self.training:
|
230 |
+
attn_weights = None
|
231 |
+
query_states = query_states.transpose(1, 2)
|
232 |
+
key_states = key_states.transpose(1, 2)
|
233 |
+
value_states = value_states.transpose(1, 2)
|
234 |
+
attn_output = xops.memory_efficient_attention(
|
235 |
+
query_states, key_states, value_states, attn_bias=xops.LowerTriangularMask()
|
236 |
+
)
|
237 |
+
else:
|
238 |
+
with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=True, enable_mem_efficient=True):
|
239 |
+
query_states = query_states.transpose(1, 2)
|
240 |
+
key_states = key_states.transpose(1, 2)
|
241 |
+
value_states = value_states.transpose(1, 2)
|
242 |
+
attn_output = F.scaled_dot_product_attention(query_states, key_states, value_states, attn_mask = attention_mask)
|
243 |
+
attn_output = attn_output.transpose(1, 2)
|
244 |
+
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
|
245 |
+
attn_output = self.o_proj(attn_output)
|
246 |
+
|
247 |
+
if not output_attentions:
|
248 |
+
attn_weights = None
|
249 |
+
return attn_output, attn_weights, past_kv
|
250 |
+
|
251 |
+
|
252 |
+
class DecoderLayer(nn.Module):
|
253 |
+
def __init__(self, config: BaichuanConfig):
|
254 |
+
super().__init__()
|
255 |
+
self.hidden_size = config.hidden_size
|
256 |
+
self.self_attn = Attention(config=config)
|
257 |
+
self.mlp = MLP(
|
258 |
+
hidden_size=self.hidden_size,
|
259 |
+
intermediate_size=config.intermediate_size,
|
260 |
+
hidden_act=config.hidden_act,
|
261 |
+
)
|
262 |
+
self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
263 |
+
self.post_attention_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
264 |
+
|
265 |
+
def forward(
|
266 |
+
self,
|
267 |
+
hidden_states: torch.Tensor,
|
268 |
+
attention_mask: Optional[torch.Tensor] = None,
|
269 |
+
position_ids: Optional[torch.LongTensor] = None,
|
270 |
+
past_key_value: Optional[Tuple[torch.Tensor]] = None,
|
271 |
+
output_attentions: Optional[bool] = False,
|
272 |
+
use_cache: Optional[bool] = False,
|
273 |
+
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
|
274 |
+
|
275 |
+
residual = hidden_states
|
276 |
+
|
277 |
+
hidden_states = self.input_layernorm(hidden_states)
|
278 |
+
|
279 |
+
# Self Attention
|
280 |
+
hidden_states, self_attn_weights, present_key_value = self.self_attn(
|
281 |
+
hidden_states=hidden_states,
|
282 |
+
attention_mask=attention_mask,
|
283 |
+
position_ids=position_ids,
|
284 |
+
past_key_value=past_key_value,
|
285 |
+
output_attentions=output_attentions,
|
286 |
+
use_cache=use_cache,
|
287 |
+
)
|
288 |
+
hidden_states = residual + hidden_states
|
289 |
+
|
290 |
+
# Fully Connected
|
291 |
+
residual = hidden_states
|
292 |
+
hidden_states = self.post_attention_layernorm(hidden_states)
|
293 |
+
hidden_states = self.mlp(hidden_states)
|
294 |
+
hidden_states = residual + hidden_states
|
295 |
+
|
296 |
+
outputs = (hidden_states,)
|
297 |
+
|
298 |
+
if output_attentions:
|
299 |
+
outputs += (self_attn_weights,)
|
300 |
+
|
301 |
+
if use_cache:
|
302 |
+
outputs += (present_key_value,)
|
303 |
+
|
304 |
+
return outputs
|
305 |
+
|
306 |
+
|
307 |
+
class BaichuanPreTrainedModel(PreTrainedModel):
|
308 |
+
config_class = BaichuanConfig
|
309 |
+
base_model_prefix = "model"
|
310 |
+
supports_gradient_checkpointing = True
|
311 |
+
_no_split_modules = ["DecoderLayer"]
|
312 |
+
_keys_to_ignore_on_load_unexpected = [r"decoder\.version"]
|
313 |
+
|
314 |
+
def _init_weights(self, module):
|
315 |
+
std = self.config.initializer_range
|
316 |
+
if isinstance(module, nn.Linear):
|
317 |
+
module.weight.data.normal_(mean=0.0, std=std)
|
318 |
+
if module.bias is not None:
|
319 |
+
module.bias.data.zero_()
|
320 |
+
elif isinstance(module, nn.Embedding):
|
321 |
+
module.weight.data.normal_(mean=0.0, std=std)
|
322 |
+
if module.padding_idx is not None:
|
323 |
+
module.weight.data[module.padding_idx].zero_()
|
324 |
+
|
325 |
+
def _set_gradient_checkpointing(self, module, value=False):
|
326 |
+
if isinstance(module, BaichuanModel):
|
327 |
+
module.gradient_checkpointing = value
|
328 |
+
|
329 |
+
|
330 |
+
class BaichuanModel(BaichuanPreTrainedModel):
|
331 |
+
def __init__(self, config: BaichuanConfig):
|
332 |
+
super().__init__(config)
|
333 |
+
self.padding_idx = config.pad_token_id
|
334 |
+
self.vocab_size = config.vocab_size
|
335 |
+
|
336 |
+
self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
|
337 |
+
self.layers = nn.ModuleList([DecoderLayer(config) for _ in range(config.num_hidden_layers)])
|
338 |
+
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
339 |
+
|
340 |
+
self.gradient_checkpointing = False
|
341 |
+
# Initialize weights and apply final processing
|
342 |
+
self.post_init()
|
343 |
+
|
344 |
+
def get_input_embeddings(self):
|
345 |
+
return self.embed_tokens
|
346 |
+
|
347 |
+
def set_input_embeddings(self, value):
|
348 |
+
self.embed_tokens = value
|
349 |
+
|
350 |
+
# Copied from transformers.models.bart.modeling_bart.BartDecoder._prepare_decoder_attention_mask
|
351 |
+
def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_embeds, past_key_values_length):
|
352 |
+
# create causal mask
|
353 |
+
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
|
354 |
+
combined_attention_mask = None
|
355 |
+
if input_shape[-1] > 1:
|
356 |
+
combined_attention_mask = _make_causal_mask(
|
357 |
+
input_shape,
|
358 |
+
inputs_embeds.dtype,
|
359 |
+
device=inputs_embeds.device,
|
360 |
+
past_key_values_length=past_key_values_length,
|
361 |
+
)
|
362 |
+
|
363 |
+
if attention_mask is not None:
|
364 |
+
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
|
365 |
+
expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]).to(
|
366 |
+
inputs_embeds.device
|
367 |
+
)
|
368 |
+
combined_attention_mask = (
|
369 |
+
expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask
|
370 |
+
)
|
371 |
+
|
372 |
+
return combined_attention_mask
|
373 |
+
|
374 |
+
def forward(
|
375 |
+
self,
|
376 |
+
input_ids: torch.LongTensor = None,
|
377 |
+
attention_mask: Optional[torch.Tensor] = None,
|
378 |
+
position_ids: Optional[torch.LongTensor] = None,
|
379 |
+
past_key_values: Optional[List[torch.FloatTensor]] = None,
|
380 |
+
inputs_embeds: Optional[torch.FloatTensor] = None,
|
381 |
+
use_cache: Optional[bool] = None,
|
382 |
+
output_attentions: Optional[bool] = None,
|
383 |
+
output_hidden_states: Optional[bool] = None,
|
384 |
+
return_dict: Optional[bool] = None,
|
385 |
+
) -> Union[Tuple, BaseModelOutputWithPast]:
|
386 |
+
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
387 |
+
output_hidden_states = (
|
388 |
+
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
389 |
+
)
|
390 |
+
use_cache = use_cache if use_cache is not None else self.config.use_cache
|
391 |
+
|
392 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
393 |
+
|
394 |
+
# retrieve input_ids and inputs_embeds
|
395 |
+
if input_ids is not None and inputs_embeds is not None:
|
396 |
+
raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
|
397 |
+
elif input_ids is not None:
|
398 |
+
batch_size, seq_length = input_ids.shape
|
399 |
+
elif inputs_embeds is not None:
|
400 |
+
batch_size, seq_length, _ = inputs_embeds.shape
|
401 |
+
else:
|
402 |
+
raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")
|
403 |
+
|
404 |
+
seq_length_with_past = seq_length
|
405 |
+
past_key_values_length = 0
|
406 |
+
|
407 |
+
if past_key_values is not None:
|
408 |
+
past_key_values_length = past_key_values[0][0].shape[2]
|
409 |
+
seq_length_with_past = seq_length_with_past + past_key_values_length
|
410 |
+
|
411 |
+
if position_ids is None:
|
412 |
+
device = input_ids.device if input_ids is not None else inputs_embeds.device
|
413 |
+
position_ids = torch.arange(
|
414 |
+
past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
|
415 |
+
)
|
416 |
+
position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
|
417 |
+
else:
|
418 |
+
position_ids = position_ids.view(-1, seq_length).long()
|
419 |
+
|
420 |
+
if inputs_embeds is None:
|
421 |
+
inputs_embeds = self.embed_tokens(input_ids)
|
422 |
+
# embed positions
|
423 |
+
if attention_mask is None:
|
424 |
+
attention_mask = torch.ones(
|
425 |
+
(batch_size, seq_length_with_past), dtype=torch.bool, device=inputs_embeds.device
|
426 |
+
)
|
427 |
+
attention_mask = self._prepare_decoder_attention_mask(
|
428 |
+
attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length
|
429 |
+
)
|
430 |
+
|
431 |
+
hidden_states = inputs_embeds
|
432 |
+
|
433 |
+
if self.gradient_checkpointing and self.training:
|
434 |
+
if use_cache:
|
435 |
+
logger.warning_once(
|
436 |
+
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
|
437 |
+
)
|
438 |
+
use_cache = False
|
439 |
+
|
440 |
+
# decoder layers
|
441 |
+
all_hidden_states = () if output_hidden_states else None
|
442 |
+
all_self_attns = () if output_attentions else None
|
443 |
+
next_decoder_cache = () if use_cache else None
|
444 |
+
|
445 |
+
for idx, decoder_layer in enumerate(self.layers):
|
446 |
+
if output_hidden_states:
|
447 |
+
all_hidden_states += (hidden_states,)
|
448 |
+
|
449 |
+
past_key_value = past_key_values[idx] if past_key_values is not None else None
|
450 |
+
|
451 |
+
if self.gradient_checkpointing and self.training:
|
452 |
+
|
453 |
+
def create_custom_forward(module):
|
454 |
+
def custom_forward(*inputs):
|
455 |
+
# None for past_key_value
|
456 |
+
return module(*inputs, output_attentions, None)
|
457 |
+
|
458 |
+
return custom_forward
|
459 |
+
|
460 |
+
layer_outputs = torch.utils.checkpoint.checkpoint(
|
461 |
+
create_custom_forward(decoder_layer),
|
462 |
+
hidden_states,
|
463 |
+
attention_mask,
|
464 |
+
position_ids,
|
465 |
+
None,
|
466 |
+
)
|
467 |
+
else:
|
468 |
+
layer_outputs = decoder_layer(
|
469 |
+
hidden_states,
|
470 |
+
attention_mask=attention_mask,
|
471 |
+
position_ids=position_ids,
|
472 |
+
past_key_value=past_key_value,
|
473 |
+
output_attentions=output_attentions,
|
474 |
+
use_cache=use_cache,
|
475 |
+
)
|
476 |
+
|
477 |
+
hidden_states = layer_outputs[0]
|
478 |
+
|
479 |
+
if use_cache:
|
480 |
+
next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)
|
481 |
+
|
482 |
+
if output_attentions:
|
483 |
+
all_self_attns += (layer_outputs[1],)
|
484 |
+
|
485 |
+
hidden_states = self.norm(hidden_states)
|
486 |
+
|
487 |
+
# add hidden states from the last decoder layer
|
488 |
+
if output_hidden_states:
|
489 |
+
all_hidden_states += (hidden_states,)
|
490 |
+
|
491 |
+
next_cache = next_decoder_cache if use_cache else None
|
492 |
+
if not return_dict:
|
493 |
+
return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
|
494 |
+
return BaseModelOutputWithPast(
|
495 |
+
last_hidden_state=hidden_states,
|
496 |
+
past_key_values=next_cache,
|
497 |
+
hidden_states=all_hidden_states,
|
498 |
+
attentions=all_self_attns,
|
499 |
+
)
|
500 |
+
|
501 |
+
|
502 |
+
class NormHead(nn.Module):
|
503 |
+
def __init__(self, hidden_size, vocab_size, bias=False):
|
504 |
+
super().__init__()
|
505 |
+
self.weight = nn.Parameter(torch.empty((vocab_size, hidden_size)))
|
506 |
+
nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))
|
507 |
+
self.first_flag = True
|
508 |
+
|
509 |
+
def forward(self, hidden_states):
|
510 |
+
if self.training:
|
511 |
+
norm_weight = nn.functional.normalize(self.weight)
|
512 |
+
self.first_flag = True
|
513 |
+
elif self.first_flag:
|
514 |
+
self.first_flag = False
|
515 |
+
self.weight.data = nn.functional.normalize(self.weight)
|
516 |
+
norm_weight = self.weight
|
517 |
+
else:
|
518 |
+
norm_weight = self.weight
|
519 |
+
return nn.functional.linear(hidden_states, norm_weight)
|
520 |
+
|
521 |
+
_init_weights = True
|
522 |
+
@contextmanager
|
523 |
+
def no_init_weights(_enable=True):
|
524 |
+
global _init_weights
|
525 |
+
old_init_weights = _init_weights
|
526 |
+
if _enable:
|
527 |
+
_init_weights = False
|
528 |
+
try:
|
529 |
+
yield
|
530 |
+
finally:
|
531 |
+
_init_weights = old_init_weights
|
532 |
+
|
533 |
+
class BaichuanForCausalLM(BaichuanPreTrainedModel):
|
534 |
+
def __init__(self, config, *model_args, **model_kwargs):
|
535 |
+
super().__init__(config, *model_args, **model_kwargs)
|
536 |
+
self.model = BaichuanModel(config)
|
537 |
+
|
538 |
+
self.lm_head = NormHead(config.hidden_size, config.vocab_size, bias=False)
|
539 |
+
if hasattr(config, "quantization_config") and isinstance(config.quantization_config, dict) and config.quantization_config.get('load_in_4bit', False):
|
540 |
+
try:
|
541 |
+
from .quantizer import quantize_offline, init_model_weight_int4
|
542 |
+
except ImportError:
|
543 |
+
raise ImportError(f"Needs QLinear to run quantize.")
|
544 |
+
quantize_offline(self, 4)
|
545 |
+
# Initialize weights and apply final processing
|
546 |
+
self.post_init()
|
547 |
+
|
548 |
+
def get_input_embeddings(self):
|
549 |
+
return self.model.embed_tokens
|
550 |
+
|
551 |
+
def set_input_embeddings(self, value):
|
552 |
+
self.model.embed_tokens = value
|
553 |
+
|
554 |
+
def get_output_embeddings(self):
|
555 |
+
return self.lm_head
|
556 |
+
|
557 |
+
def set_output_embeddings(self, new_embeddings):
|
558 |
+
self.lm_head = new_embeddings
|
559 |
+
|
560 |
+
def set_decoder(self, decoder):
|
561 |
+
self.model = decoder
|
562 |
+
|
563 |
+
def get_decoder(self):
|
564 |
+
return self.model
|
565 |
+
|
566 |
+
@classmethod
|
567 |
+
def from_pretrained(
|
568 |
+
cls,
|
569 |
+
pretrained_model_name_or_path: Optional[Union[str, os.PathLike]],
|
570 |
+
*model_args,
|
571 |
+
config: Optional[Union[PretrainedConfig, str, os.PathLike]] = None,
|
572 |
+
cache_dir: Optional[Union[str, os.PathLike]] = None,
|
573 |
+
ignore_mismatched_sizes: bool = False,
|
574 |
+
force_download: bool = False,
|
575 |
+
local_files_only: bool = False,
|
576 |
+
token: Optional[Union[str, bool]] = None,
|
577 |
+
revision: str = "main",
|
578 |
+
use_safetensors: bool = None,
|
579 |
+
**kwargs,
|
580 |
+
):
|
581 |
+
# Load config if we don't provide a configuration
|
582 |
+
if not isinstance(config, PretrainedConfig):
|
583 |
+
config_path = config if config is not None else pretrained_model_name_or_path
|
584 |
+
config, model_kwargs = cls.config_class.from_pretrained(
|
585 |
+
config_path,
|
586 |
+
cache_dir=cache_dir,
|
587 |
+
return_unused_kwargs=True,
|
588 |
+
force_download=force_download,
|
589 |
+
resume_download=False,
|
590 |
+
proxies=None,
|
591 |
+
local_files_only=local_files_only,
|
592 |
+
token=token,
|
593 |
+
revision=revision,
|
594 |
+
subfolder="",
|
595 |
+
_from_auto=False,
|
596 |
+
_from_pipeline=None,
|
597 |
+
**kwargs,
|
598 |
+
)
|
599 |
+
else:
|
600 |
+
model_kwargs = kwargs
|
601 |
+
|
602 |
+
if hasattr(config, "quantization_config") and config.quantization_config['load_in_4bit']:
|
603 |
+
try:
|
604 |
+
from .quantizer import init_model_weight_int4
|
605 |
+
from accelerate import init_empty_weights, dispatch_model, infer_auto_device_map
|
606 |
+
from accelerate.utils import CustomDtype
|
607 |
+
from accelerate.utils import get_balanced_memory
|
608 |
+
except ImportError:
|
609 |
+
raise ImportError(f"Needs import model weight init func to run quantize.")
|
610 |
+
# Instantiate model.
|
611 |
+
init_contexts = [no_init_weights(_enable=True)]
|
612 |
+
init_contexts.append(init_empty_weights())
|
613 |
+
with ContextManagers(init_contexts):
|
614 |
+
model = cls(config)
|
615 |
+
|
616 |
+
model_file = os.path.join(pretrained_model_name_or_path, 'pytorch_model.bin')
|
617 |
+
state_dict = torch.load(model_file, map_location="cpu")
|
618 |
+
model.is_quantized = True
|
619 |
+
|
620 |
+
device_map = kwargs.pop("device_map", None)
|
621 |
+
torch_dtype = kwargs.pop("torch_dtype", None)
|
622 |
+
|
623 |
+
if device_map is not None:
|
624 |
+
kwargs = {"no_split_module_classes": model._no_split_modules}
|
625 |
+
target_dtype = CustomDtype.INT4
|
626 |
+
max_memory = get_balanced_memory(
|
627 |
+
model,
|
628 |
+
dtype=target_dtype,
|
629 |
+
low_zero=(device_map == "balanced_low_0"),
|
630 |
+
max_memory=None,
|
631 |
+
**kwargs,
|
632 |
+
)
|
633 |
+
kwargs["max_memory"] = max_memory
|
634 |
+
device_map = infer_auto_device_map(model, dtype=target_dtype, **kwargs)
|
635 |
+
|
636 |
+
model = init_model_weight_int4(config, model, state_dict)
|
637 |
+
|
638 |
+
# Set model in evaluation mode to deactivate DropOut modules by default
|
639 |
+
model.eval()
|
640 |
+
# If it is a model with generation capabilities, attempt to load the generation config
|
641 |
+
if model.can_generate():
|
642 |
+
try:
|
643 |
+
model.generation_config = GenerationConfig.from_pretrained(
|
644 |
+
pretrained_model_name_or_path,
|
645 |
+
cache_dir=cache_dir,
|
646 |
+
force_download=force_download,
|
647 |
+
resume_download=False,
|
648 |
+
proxies=None,
|
649 |
+
local_files_only=local_files_only,
|
650 |
+
token=token,
|
651 |
+
revision=revision,
|
652 |
+
subfolder="",
|
653 |
+
_from_auto=False,
|
654 |
+
_from_pipeline=None,
|
655 |
+
**kwargs,
|
656 |
+
)
|
657 |
+
except (OSError, TypeError):
|
658 |
+
logger.info(
|
659 |
+
"Generation config file not found, using a generation config created from the model config."
|
660 |
+
)
|
661 |
+
pass
|
662 |
+
|
663 |
+
if device_map is not None:
|
664 |
+
dispatch_model(model, device_map=device_map)
|
665 |
+
|
666 |
+
return model
|
667 |
+
return super(BaichuanForCausalLM, cls).from_pretrained(pretrained_model_name_or_path, *model_args,
|
668 |
+
config=config, cache_dir=cache_dir, ignore_mismatched_sizes=ignore_mismatched_sizes,
|
669 |
+
force_download=force_download, local_files_only=local_files_only, token=token, revision=revision,
|
670 |
+
use_safetensors=use_safetensors, **kwargs)
|
671 |
+
|
672 |
+
def forward(
|
673 |
+
self,
|
674 |
+
input_ids: torch.LongTensor = None,
|
675 |
+
attention_mask: Optional[torch.Tensor] = None,
|
676 |
+
position_ids: Optional[torch.LongTensor] = None,
|
677 |
+
past_key_values: Optional[List[torch.FloatTensor]] = None,
|
678 |
+
inputs_embeds: Optional[torch.FloatTensor] = None,
|
679 |
+
labels: Optional[torch.LongTensor] = None,
|
680 |
+
use_cache: Optional[bool] = None,
|
681 |
+
output_attentions: Optional[bool] = None,
|
682 |
+
output_hidden_states: Optional[bool] = None,
|
683 |
+
return_dict: Optional[bool] = None,
|
684 |
+
) -> Union[Tuple, CausalLMOutputWithPast]:
|
685 |
+
|
686 |
+
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
687 |
+
output_hidden_states = (
|
688 |
+
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
689 |
+
)
|
690 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
691 |
+
|
692 |
+
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
|
693 |
+
outputs = self.model(
|
694 |
+
input_ids=input_ids,
|
695 |
+
attention_mask=attention_mask,
|
696 |
+
position_ids=position_ids,
|
697 |
+
past_key_values=past_key_values,
|
698 |
+
inputs_embeds=inputs_embeds,
|
699 |
+
use_cache=use_cache,
|
700 |
+
output_attentions=output_attentions,
|
701 |
+
output_hidden_states=output_hidden_states,
|
702 |
+
return_dict=return_dict,
|
703 |
+
)
|
704 |
+
|
705 |
+
hidden_states = outputs[0]
|
706 |
+
logits = self.lm_head(hidden_states)
|
707 |
+
loss = None
|
708 |
+
if labels is not None:
|
709 |
+
# Shift so that tokens < n predict n
|
710 |
+
shift_logits = logits[..., :-1, :].contiguous()
|
711 |
+
shift_labels = labels[..., 1:].contiguous()
|
712 |
+
# Flatten the tokens
|
713 |
+
loss_fct = CrossEntropyLoss()
|
714 |
+
shift_logits = shift_logits.view(-1, self.config.vocab_size)
|
715 |
+
shift_labels = shift_labels.view(-1)
|
716 |
+
softmax_normalizer = shift_logits.max(-1).values ** 2
|
717 |
+
z_loss = self.config.z_loss_weight * softmax_normalizer.mean()
|
718 |
+
# Enable model parallelism
|
719 |
+
shift_labels = shift_labels.to(shift_logits.device)
|
720 |
+
loss = loss_fct(shift_logits, shift_labels) + z_loss
|
721 |
+
|
722 |
+
if not return_dict:
|
723 |
+
output = (logits,) + outputs[1:]
|
724 |
+
return (loss,) + output if loss is not None else output
|
725 |
+
|
726 |
+
return CausalLMOutputWithPast(
|
727 |
+
loss=loss,
|
728 |
+
logits=logits,
|
729 |
+
past_key_values=outputs.past_key_values,
|
730 |
+
hidden_states=outputs.hidden_states,
|
731 |
+
attentions=outputs.attentions,
|
732 |
+
)
|
733 |
+
|
734 |
+
def prepare_inputs_for_generation(
|
735 |
+
self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
|
736 |
+
):
|
737 |
+
if past_key_values:
|
738 |
+
input_ids = input_ids[:, -1:]
|
739 |
+
|
740 |
+
position_ids = kwargs.get("position_ids", None)
|
741 |
+
if attention_mask is not None and position_ids is None:
|
742 |
+
# create position_ids on the fly for batch generation
|
743 |
+
position_ids = attention_mask.long().cumsum(-1) - 1
|
744 |
+
position_ids.masked_fill_(attention_mask == 0, 1)
|
745 |
+
if past_key_values:
|
746 |
+
position_ids = position_ids[:, -1].unsqueeze(-1)
|
747 |
+
|
748 |
+
# if `inputs_embeds` are passed, we only want to use them in the 1st generation step
|
749 |
+
if inputs_embeds is not None and past_key_values is None:
|
750 |
+
model_inputs = {"inputs_embeds": inputs_embeds}
|
751 |
+
else:
|
752 |
+
model_inputs = {"input_ids": input_ids}
|
753 |
+
|
754 |
+
model_inputs.update(
|
755 |
+
{
|
756 |
+
"position_ids": position_ids,
|
757 |
+
"past_key_values": past_key_values,
|
758 |
+
"use_cache": kwargs.get("use_cache"),
|
759 |
+
"attention_mask": attention_mask,
|
760 |
+
}
|
761 |
+
)
|
762 |
+
return model_inputs
|
763 |
+
|
764 |
+
@staticmethod
|
765 |
+
def _reorder_cache(past_key_values, beam_idx):
|
766 |
+
reordered_past = ()
|
767 |
+
for layer_past in past_key_values:
|
768 |
+
reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),)
|
769 |
+
return reordered_past
|
770 |
+
|
771 |
+
def quantize(self, bits: int):
|
772 |
+
try:
|
773 |
+
from .quantizer import quantize_online
|
774 |
+
except ImportError:
|
775 |
+
raise ImportError(f"Needs QLinear to run quantize.")
|
776 |
+
return quantize_online(self, bits)
|
777 |
+
|
778 |
+
def chat(self, tokenizer, messages: List[dict], stream=False,
|
779 |
+
generation_config: Optional[GenerationConfig]=None):
|
780 |
+
generation_config = generation_config or self.generation_config
|
781 |
+
input_ids = build_chat_input(self, tokenizer, messages, generation_config.max_new_tokens)
|
782 |
+
if stream:
|
783 |
+
streamer = TextIterStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
|
784 |
+
Thread(target=self.generate, kwargs=dict(
|
785 |
+
inputs=input_ids, streamer=streamer,
|
786 |
+
generation_config=generation_config,
|
787 |
+
)).start()
|
788 |
+
return streamer
|
789 |
+
else:
|
790 |
+
outputs = self.generate(input_ids, generation_config=generation_config)
|
791 |
+
response = tokenizer.decode(outputs[0][len(input_ids[0]):], skip_special_tokens=True)
|
792 |
+
return response
|
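Note on the head above: NormHead differs from a plain lm_head in that each vocabulary row of the weight is L2-normalized before the projection, and in eval mode the normalized weight is written back once (the first_flag branch) so later decode steps skip the renormalization. A minimal PyTorch sketch of the computation, with toy dimensions (not the shipped module):

    import torch
    import torch.nn.functional as F

    hidden_size, vocab_size = 8, 16
    weight = torch.randn(vocab_size, hidden_size)  # lm_head weight, one row per vocab entry
    hidden = torch.randn(2, hidden_size)           # [batch, hidden]

    norm_weight = F.normalize(weight)              # L2-normalize each row (dim=1 default)
    logits = F.linear(hidden, norm_weight)         # [batch, vocab_size]
    print(logits.shape)                            # torch.Size([2, 16])

The training path above additionally adds a z-loss term, z_loss_weight * mean(max_logit ** 2), on top of the cross entropy, which penalizes large softmax normalizers.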
Baichuan2/compile/torch_inference.py
ADDED
@@ -0,0 +1,16 @@
1 |
+
import torch
|
2 |
+
from transformers import AutoModelForCausalLM, AutoTokenizer
|
3 |
+
from transformers.generation.utils import GenerationConfig
|
4 |
+
import argparse
|
5 |
+
|
6 |
+
parser = argparse.ArgumentParser()
|
7 |
+
parser.add_argument('model_path', help='absolute path to the downloaded model')
|
8 |
+
args = parser.parse_args()
|
9 |
+
model_path = args.model_path
|
10 |
+
tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False, trust_remote_code=True)
|
11 |
+
model = AutoModelForCausalLM.from_pretrained(model_path, device_map="auto", torch_dtype=torch.float32, trust_remote_code=True)
|
12 |
+
model.generation_config = GenerationConfig.from_pretrained(model_path)
|
13 |
+
messages = []
|
14 |
+
messages.append({"role": "user", "content": "解释一下“温故而知新”"})
|
15 |
+
response = model.chat(tokenizer, messages)
|
16 |
+
print(response)
|
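torch_inference.py runs a single-turn query (python torch_inference.py /path/to/model). A multi-turn loop on the same chat() API might look like the sketch below, assuming the model and tokenizer objects created above and that build_chat_input accepts alternating user/assistant roles, as in the upstream Baichuan2 chat utilities:

    messages = []
    while True:
        text = input("Question: ")
        if text in ("exit", "quit"):
            break
        messages.append({"role": "user", "content": text})
        response = model.chat(tokenizer, messages)
        messages.append({"role": "assistant", "content": response})
        print("Answer:", response)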
Baichuan2/demo/CMakeLists.txt
ADDED
@@ -0,0 +1,38 @@
1 |
+
cmake_minimum_required(VERSION 2.8)
|
2 |
+
project(baichuan2)
|
3 |
+
|
4 |
+
if (NOT DEFINED TARGET_ARCH)
|
5 |
+
set(TARGET_ARCH pcie)
|
6 |
+
endif()
|
7 |
+
|
8 |
+
set(CMAKE_INSTALL_PREFIX install)
|
9 |
+
|
10 |
+
if (${CMAKE_HOST_SYSTEM_PROCESSOR} STREQUAL "aarch64")
|
11 |
+
add_definitions(-DSOC_TARGET)
|
12 |
+
link_directories(${PROJECT_SOURCE_DIR}/../src/lib_soc)
|
13 |
+
message("SoC mode, starting......")
|
14 |
+
elseif (${TARGET_ARCH} STREQUAL "pcie")
|
15 |
+
add_definitions(-DPCIE_TARGET)
|
16 |
+
link_directories(${PROJECT_SOURCE_DIR}/../src/lib_pcie)
|
17 |
+
message("Pcie mode, starting......")
|
18 |
+
elseif (${TARGET_ARCH} STREQUAL "soc")
|
19 |
+
add_definitions(-DSOC_TARGET)
|
20 |
+
set(CMAKE_C_COMPILER aarch64-linux-gnu-gcc)
|
21 |
+
set(CMAKE_ASM_COMPILER aarch64-linux-gnu-gcc)
|
22 |
+
set(CMAKE_CXX_COMPILER aarch64-linux-gnu-g++)
|
23 |
+
link_directories(${PROJECT_SOURCE_DIR}/lib_soc)
|
24 |
+
message("SoC mode, starting......")
|
25 |
+
endif()
|
26 |
+
|
27 |
+
|
28 |
+
|
29 |
+
|
30 |
+
include_directories(${PROJECT_SOURCE_DIR}/../src/include)
|
31 |
+
|
32 |
+
add_definitions(-DDEBUG --std=c++17 -fPIC -Wall -Werror)
|
33 |
+
set(CMAKE_BUILD_TYPE "Debug")
|
34 |
+
|
35 |
+
set(CMAKE_RUNTIME_OUTPUT_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR})
|
36 |
+
add_executable(baichuan2 demo.cpp)
|
37 |
+
target_link_libraries(baichuan2 bmrt bmlib sentencepiece)
|
38 |
+
|
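Build note: from Baichuan2/demo the usual flow would be mkdir build && cd build && cmake .. && make on a PCIe host, or cmake -DTARGET_ARCH=soc .. to cross-compile with the aarch64-linux-gnu toolchain selected above; the baichuan2 binary lands back in the demo directory because CMAKE_RUNTIME_OUTPUT_DIRECTORY points at the source dir.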
Baichuan2/demo/demo.cpp
ADDED
@@ -0,0 +1,472 @@
1 |
+
//===----------------------------------------------------------------------===//
|
2 |
+
//
|
3 |
+
// Copyright (C) 2023 Sophgo Technologies Inc. All rights reserved.
|
4 |
+
//
|
5 |
+
// TPU-MLIR is licensed under the 2-Clause BSD License except for the
|
6 |
+
// third-party components.
|
7 |
+
//
|
8 |
+
//===----------------------------------------------------------------------===//
|
9 |
+
|
10 |
+
#include <iostream>
|
11 |
+
#include <cstdlib>
|
12 |
+
#include <vector>
|
13 |
+
#include <assert.h>
|
14 |
+
#include <chrono>
|
15 |
+
#include <algorithm>
|
16 |
+
#include "memory.h"
|
17 |
+
#include "sentencepiece/sentencepiece_processor.h"
|
18 |
+
#include "bmruntime_interface.h"
|
19 |
+
#include <getopt.h>
|
20 |
+
#include <numeric>
|
21 |
+
|
22 |
+
static const int NUM_LAYERS = 32;
|
23 |
+
static const int MAX_LEN = 512;
|
24 |
+
static const float ATTENTION_MASK = -1000.;
|
25 |
+
|
26 |
+
static const std::string TOKENIZER_MODEL = "../model/tokenizer.model";
|
27 |
+
|
28 |
+
// #define EXPORT_RESULTS
|
29 |
+
#ifdef EXPORT_RESULTS
|
30 |
+
#include "cnpy.h"
|
31 |
+
static cnpy::npz_t map;
|
32 |
+
|
33 |
+
template <typename T>
|
34 |
+
static void add_array(std::string name, bm_handle_t bm_handle,
|
35 |
+
const bm_device_mem_t &dst) {
|
36 |
+
std::vector<T> data(dst.size / sizeof(T));
|
37 |
+
bm_memcpy_d2s(bm_handle, data.data(), dst);
|
38 |
+
cnpy::npz_add_array(map, name, data);
|
39 |
+
}
|
40 |
+
|
41 |
+
static void save_array(std::string filename) {
|
42 |
+
cnpy::npz_save_all(filename, map);
|
43 |
+
}
|
44 |
+
#endif
|
45 |
+
|
46 |
+
class Baichuan2 {
|
47 |
+
public:
|
48 |
+
void init(const std::vector<int> &devid, std::string model);
|
49 |
+
void chat();
|
50 |
+
void deinit();
|
51 |
+
|
52 |
+
private:
|
53 |
+
void answer(const std::string &input_str);
|
54 |
+
int forward_first(std::vector<int> &tokens);
|
55 |
+
int forward_next();
|
56 |
+
void load_sentencepiece();
|
57 |
+
|
58 |
+
private:
|
59 |
+
std::vector<bm_handle_t> handles;
|
60 |
+
bm_handle_t bm_handle;
|
61 |
+
void *p_bmrt;
|
62 |
+
sentencepiece::SentencePieceProcessor sentencepiece;
|
63 |
+
const bm_net_info_t *net_blocks[NUM_LAYERS];
|
64 |
+
const bm_net_info_t *net_blocks_cache[NUM_LAYERS];
|
65 |
+
const bm_net_info_t *net_embed;
|
66 |
+
const bm_net_info_t *net_embed_cache;
|
67 |
+
const bm_net_info_t *net_lm;
|
68 |
+
bm_tensor_t inputs_embed_512, outputs_embed_512;
|
69 |
+
bm_tensor_t inputs_lm, outputs_lm;
|
70 |
+
bm_tensor_t inputs_pid, next_pid, inputs_attention, next_attention;
|
71 |
+
bm_tensor_t past_key[NUM_LAYERS], past_value[NUM_LAYERS];
|
72 |
+
bm_tensor_t present_key[NUM_LAYERS], present_value[NUM_LAYERS];
|
73 |
+
bm_tensor_t present_key_cache, present_value_cache;
|
74 |
+
std::string name_embed;
|
75 |
+
std::string name_embed_cache;
|
76 |
+
std::string name_lm;
|
77 |
+
std::string name_blocks[NUM_LAYERS];
|
78 |
+
std::string name_blocks_cache[NUM_LAYERS];
|
79 |
+
int round = 0;
|
80 |
+
int token_length;
|
81 |
+
int EOS;
|
82 |
+
std::vector<std::string> history;
|
83 |
+
};
|
84 |
+
|
85 |
+
void Baichuan2::load_sentencepiece() {
|
86 |
+
printf("Load %s ... ", TOKENIZER_MODEL.c_str());
|
87 |
+
auto status = sentencepiece.Load(TOKENIZER_MODEL);
|
88 |
+
if (!status.ok()) {
|
89 |
+
std::cout << status.ToString() << std::endl;
|
90 |
+
exit(-1);
|
91 |
+
}
|
92 |
+
EOS = sentencepiece.eos_id();
|
93 |
+
printf("Done!\n");
|
94 |
+
}
|
95 |
+
|
96 |
+
void Baichuan2::init(const std::vector<int> &devices, std::string model) {
|
97 |
+
load_sentencepiece();
|
98 |
+
// request bm_handle
|
99 |
+
std::cout << "Device [ ";
|
100 |
+
for (auto d : devices) {
|
101 |
+
std::cout << d << " ";
|
102 |
+
}
|
103 |
+
std::cout << "] loading ....\n";
|
104 |
+
// int device_num = devices.size();
|
105 |
+
for (auto d : devices) {
|
106 |
+
bm_handle_t h;
|
107 |
+
bm_status_t status = bm_dev_request(&h, d);
|
108 |
+
assert(BM_SUCCESS == status);
|
109 |
+
handles.push_back(h);
|
110 |
+
}
|
111 |
+
bm_handle = handles[0];
|
112 |
+
// create bmruntime
|
113 |
+
p_bmrt = bmrt_create(bm_handle);
|
114 |
+
assert(NULL != p_bmrt);
|
115 |
+
|
116 |
+
// load bmodel by file
|
117 |
+
printf("Model[%s] loading ....\n", model.c_str());
|
118 |
+
bool ret = bmrt_load_bmodel(p_bmrt, model.c_str());
|
119 |
+
assert(true == ret);
|
120 |
+
printf("Done!\n");
|
121 |
+
// net names
|
122 |
+
name_embed = "embedding";
|
123 |
+
name_embed_cache = "embedding_cache";
|
124 |
+
name_lm = "lm_head";
|
125 |
+
for (int i = 0; i < NUM_LAYERS; i++) {
|
126 |
+
name_blocks[i] = "block_" + std::to_string(i);
|
127 |
+
name_blocks_cache[i] = "block_cache_" + std::to_string(i);
|
128 |
+
}
|
129 |
+
|
130 |
+
// net infos
|
131 |
+
net_embed = bmrt_get_network_info(p_bmrt, name_embed.c_str());
|
132 |
+
net_embed_cache = bmrt_get_network_info(p_bmrt, name_embed_cache.c_str());
|
133 |
+
net_lm = bmrt_get_network_info(p_bmrt, name_lm.c_str());
|
134 |
+
for (int i = 0; i < NUM_LAYERS; i++) {
|
135 |
+
net_blocks[i] = bmrt_get_network_info(p_bmrt, name_blocks[i].c_str());
|
136 |
+
net_blocks_cache[i] =
|
137 |
+
bmrt_get_network_info(p_bmrt, name_blocks_cache[i].c_str());
|
138 |
+
}
|
139 |
+
|
140 |
+
// net device mem
|
141 |
+
ret = bmrt_tensor(&inputs_embed_512, p_bmrt, net_embed->input_dtypes[0],
|
142 |
+
net_embed->stages[0].input_shapes[0]);
|
143 |
+
assert(true == ret);
|
144 |
+
|
145 |
+
ret = bmrt_tensor(&outputs_embed_512, p_bmrt, net_embed->output_dtypes[0],
|
146 |
+
net_embed->stages[0].output_shapes[0]);
|
147 |
+
assert(true == ret);
|
148 |
+
|
149 |
+
ret = bmrt_tensor(&inputs_pid, p_bmrt, net_blocks[0]->input_dtypes[1],
|
150 |
+
net_blocks[0]->stages[0].input_shapes[1]);
|
151 |
+
assert(true == ret);
|
152 |
+
|
153 |
+
ret = bmrt_tensor(&inputs_attention, p_bmrt, net_blocks[0]->input_dtypes[2],
|
154 |
+
net_blocks[0]->stages[0].input_shapes[2]);
|
155 |
+
assert(true == ret);
|
156 |
+
|
157 |
+
ret = bmrt_tensor(&next_pid, p_bmrt, net_blocks_cache[0]->input_dtypes[1],
|
158 |
+
net_blocks_cache[0]->stages[0].input_shapes[1]);
|
159 |
+
assert(true == ret);
|
160 |
+
|
161 |
+
ret =
|
162 |
+
bmrt_tensor(&next_attention, p_bmrt, net_blocks_cache[0]->input_dtypes[2],
|
163 |
+
net_blocks_cache[0]->stages[0].input_shapes[2]);
|
164 |
+
assert(true == ret);
|
165 |
+
|
166 |
+
for (int i = 0; i < NUM_LAYERS; i++) {
|
167 |
+
ret = bmrt_tensor(&past_key[i], p_bmrt, net_blocks[0]->output_dtypes[1],
|
168 |
+
net_blocks[0]->stages[0].output_shapes[1]);
|
169 |
+
assert(true == ret);
|
170 |
+
ret = bmrt_tensor(&past_value[i], p_bmrt, net_blocks[0]->output_dtypes[2],
|
171 |
+
net_blocks[0]->stages[0].output_shapes[2]);
|
172 |
+
assert(true == ret);
|
173 |
+
ret = bmrt_tensor(&present_key[i], p_bmrt, net_blocks[0]->output_dtypes[1],
|
174 |
+
net_blocks[0]->stages[0].output_shapes[1]);
|
175 |
+
assert(true == ret);
|
176 |
+
ret = bmrt_tensor(&present_value[i], p_bmrt, net_blocks[0]->output_dtypes[2],
|
177 |
+
net_blocks[0]->stages[0].output_shapes[2]);
|
178 |
+
assert(true == ret);
|
179 |
+
}
|
180 |
+
ret = bmrt_tensor(&present_key_cache, p_bmrt, net_blocks_cache[0]->output_dtypes[1],
|
181 |
+
net_blocks_cache[0]->stages[0].output_shapes[1]);
|
182 |
+
assert(true == ret);
|
183 |
+
ret = bmrt_tensor(&present_value_cache, p_bmrt, net_blocks_cache[0]->output_dtypes[2],
|
184 |
+
net_blocks_cache[0]->stages[0].output_shapes[2]);
|
185 |
+
assert(true == ret);
|
186 |
+
|
187 |
+
ret = bmrt_tensor(&inputs_lm, p_bmrt, net_lm->input_dtypes[0],
|
188 |
+
net_lm->stages[0].input_shapes[0]);
|
189 |
+
assert(true == ret);
|
190 |
+
ret = bmrt_tensor(&outputs_lm, p_bmrt, net_lm->output_dtypes[0],
|
191 |
+
net_lm->stages[0].output_shapes[0]);
|
192 |
+
assert(true == ret);
|
193 |
+
}
|
194 |
+
|
195 |
+
void Baichuan2::deinit() {
|
196 |
+
bm_free_device(bm_handle, inputs_embed_512.device_mem);
|
197 |
+
bm_free_device(bm_handle, outputs_embed_512.device_mem);
|
198 |
+
bm_free_device(bm_handle, inputs_lm.device_mem);
|
199 |
+
bm_free_device(bm_handle, outputs_lm.device_mem);
|
200 |
+
bm_free_device(bm_handle, inputs_pid.device_mem);
|
201 |
+
bm_free_device(bm_handle, next_pid.device_mem);
|
202 |
+
bm_free_device(bm_handle, inputs_attention.device_mem);
|
203 |
+
bm_free_device(bm_handle, next_attention.device_mem);
|
204 |
+
bm_free_device(bm_handle, present_key_cache.device_mem);
|
205 |
+
bm_free_device(bm_handle, present_value_cache.device_mem);
|
206 |
+
for (int i = 0; i < NUM_LAYERS; i++) {
|
207 |
+
bm_free_device(bm_handle, past_key[i].device_mem);
|
208 |
+
bm_free_device(bm_handle, past_value[i].device_mem);
|
209 |
+
bm_free_device(bm_handle, present_key[i].device_mem);
|
210 |
+
bm_free_device(bm_handle, present_value[i].device_mem);
|
211 |
+
}
|
212 |
+
bmrt_destroy(p_bmrt);
|
213 |
+
for (auto h : handles) {
|
214 |
+
bm_dev_free(h);
|
215 |
+
}
|
216 |
+
}
|
217 |
+
|
218 |
+
int Baichuan2::forward_first(std::vector<int> &tokens) {
|
219 |
+
int input_ids[MAX_LEN] = {0}; // start token
|
220 |
+
int position_id[MAX_LEN] = {0};
|
221 |
+
float attention_mask[MAX_LEN * MAX_LEN] = {0};
|
222 |
+
token_length = tokens.size();
|
223 |
+
|
224 |
+
std::copy(tokens.begin(), tokens.end(), input_ids);
|
225 |
+
for (int i = 0; i < token_length; i++) {
|
226 |
+
position_id[i] = i;
|
227 |
+
}
|
228 |
+
|
229 |
+
for (int i = 0; i < MAX_LEN; i++) {
|
230 |
+
for (int j = 0; j < MAX_LEN; j++) {
|
231 |
+
if (j <= i && i < token_length) {
|
232 |
+
} else {
|
233 |
+
attention_mask[i * MAX_LEN + j] = ATTENTION_MASK;
|
234 |
+
}
|
235 |
+
}
|
236 |
+
}
|
237 |
+
|
238 |
+
// forward embedding
|
239 |
+
bm_memcpy_s2d(bm_handle, inputs_embed_512.device_mem, (void *)input_ids);
|
240 |
+
auto ret =
|
241 |
+
bmrt_launch_tensor_ex(p_bmrt, name_embed.c_str(), &inputs_embed_512, 1,
|
242 |
+
&outputs_embed_512, 1, true, false);
|
243 |
+
assert(ret);
|
244 |
+
// float test_embed[MAX_LEN] = {0};
|
245 |
+
// bm_memcpy_d2s(bm_handle, (void *)&test_embed, outputs_embed_512.device_mem);
|
246 |
+
bm_thread_sync(bm_handle);
|
247 |
+
|
248 |
+
// forward blocks
|
249 |
+
bm_memcpy_s2d(bm_handle, inputs_pid.device_mem, (void *)position_id);
|
250 |
+
bm_memcpy_s2d(bm_handle, inputs_attention.device_mem, (void *)attention_mask);
|
251 |
+
auto inputs_embed = outputs_embed_512;
|
252 |
+
inputs_embed.shape = net_blocks[0]->stages[0].input_shapes[0];
|
253 |
+
bm_tensor_t inputs_block[3] = {inputs_embed, inputs_pid, inputs_attention};
|
254 |
+
for (int i = 0; i < NUM_LAYERS; i++) {
|
255 |
+
bm_tensor_t outputs_block[3] = {inputs_embed, past_key[i], past_value[i]};
|
256 |
+
ret = bmrt_launch_tensor_ex(p_bmrt, name_blocks[i].c_str(), inputs_block, 3,
|
257 |
+
outputs_block, 3, true, false);
|
258 |
+
assert(ret);
|
259 |
+
bm_thread_sync(bm_handle);
|
260 |
+
}
|
261 |
+
int bytes = inputs_embed.device_mem.size / MAX_LEN;
|
262 |
+
bm_memcpy_d2d_byte(bm_handle, inputs_lm.device_mem, 0,
|
263 |
+
inputs_embed.device_mem, (token_length - 1) * bytes,
|
264 |
+
bytes);
|
265 |
+
ret = bmrt_launch_tensor_ex(p_bmrt, name_lm.c_str(), &inputs_lm, 1,
|
266 |
+
&outputs_lm, 1, true, false);
|
267 |
+
bm_thread_sync(bm_handle);
|
268 |
+
|
269 |
+
int token = 0;
|
270 |
+
bm_memcpy_d2s(bm_handle, (void *)&token, outputs_lm.device_mem);
|
271 |
+
return token;
|
272 |
+
}
|
273 |
+
|
274 |
+
int Baichuan2::forward_next() {
|
275 |
+
float attention_mask[MAX_LEN + 1] = {0};
|
276 |
+
for (int i = token_length - 1; i < MAX_LEN; i++) {
|
277 |
+
attention_mask[i] = ATTENTION_MASK;
|
278 |
+
}
|
279 |
+
int32_t position_id = token_length - 1;
|
280 |
+
// embedding
|
281 |
+
outputs_lm.shape = net_embed_cache->stages[0].input_shapes[0];
|
282 |
+
auto ret = bmrt_launch_tensor_ex(p_bmrt, name_embed_cache.c_str(), &outputs_lm, 1,
|
283 |
+
&inputs_lm, 1, true, false);
|
284 |
+
assert(ret);
|
285 |
+
bm_thread_sync(bm_handle);
|
286 |
+
|
287 |
+
// blocks
|
288 |
+
bm_memcpy_s2d(bm_handle, next_attention.device_mem, (void *)attention_mask);
|
289 |
+
bm_memcpy_s2d(bm_handle, next_pid.device_mem, (void *)&position_id);
|
290 |
+
auto inputs_embed = inputs_lm;
|
291 |
+
inputs_embed.shape = net_blocks_cache[0]->stages[0].input_shapes[0];
|
292 |
+
int bytes = bm_mem_get_device_size(present_key_cache.device_mem);
|
293 |
+
int token_offset = (token_length - 1) * bytes;
|
294 |
+
for (int i = 0; i < NUM_LAYERS; i++) {
|
295 |
+
bm_tensor_t inputs_block[5] = {inputs_embed, next_pid, next_attention,
|
296 |
+
past_key[i], past_value[i]};
|
297 |
+
bm_tensor_t outputs_block[3] = {inputs_embed, present_key_cache, present_value_cache};
|
298 |
+
ret = bmrt_launch_tensor_ex(p_bmrt, name_blocks_cache[i].c_str(),
|
299 |
+
inputs_block, 5, outputs_block, 3, true, false);
|
300 |
+
assert(ret);
|
301 |
+
bm_thread_sync(bm_handle);
|
302 |
+
bm_memcpy_d2d_byte(bm_handle, past_key[i].device_mem, token_offset,
|
303 |
+
present_key_cache.device_mem, 0,
|
304 |
+
bytes);
|
305 |
+
bm_memcpy_d2d_byte(bm_handle, past_value[i].device_mem, token_offset,
|
306 |
+
present_value_cache.device_mem, 0,
|
307 |
+
bytes);
|
308 |
+
}
|
309 |
+
outputs_lm.shape = net_lm->stages[0].output_shapes[0];
|
310 |
+
ret = bmrt_launch_tensor_ex(p_bmrt, name_lm.c_str(), &inputs_lm, 1,
|
311 |
+
&outputs_lm, 1, true, false);
|
312 |
+
bm_thread_sync(bm_handle);
|
313 |
+
|
314 |
+
int token = 0;
|
315 |
+
bm_memcpy_d2s(bm_handle, (void *)&token, outputs_lm.device_mem);
|
316 |
+
return token;
|
317 |
+
}
|
318 |
+
|
319 |
+
void Baichuan2::chat() {
|
320 |
+
while (true) {
|
321 |
+
std::cout << "\nQuestion: ";
|
322 |
+
std::string input_str;
|
323 |
+
std::getline(std::cin, input_str);
|
324 |
+
std::string user_token = "<reserved_106>"; // user token id 195
|
325 |
+
std::string assistant_token = "<reserved_107>"; // assistant token id 196
|
326 |
+
if (input_str == "exit") {
|
327 |
+
break;
|
328 |
+
}
|
329 |
+
if (input_str == "clear") {
|
330 |
+
history.clear();
|
331 |
+
continue;
|
332 |
+
}
|
333 |
+
|
334 |
+
input_str = user_token + input_str + assistant_token;
|
335 |
+
|
336 |
+
std::cout << "\nAnswer: " << std::flush;
|
337 |
+
answer(input_str);
|
338 |
+
std::cout << std::endl;
|
339 |
+
}
|
340 |
+
}
|
341 |
+
|
342 |
+
void Baichuan2::answer(const std::string &input_str) {
|
343 |
+
int tok_num = 0;
|
344 |
+
history.emplace_back(input_str);
|
345 |
+
|
346 |
+
std::vector<int> tokens;
|
347 |
+
|
348 |
+
std::string history_input = std::accumulate(history.begin(), history.end(), std::string());
|
349 |
+
sentencepiece.Encode(history_input, &tokens);
|
350 |
+
|
351 |
+
if (tokens.empty()) {
|
352 |
+
printf("Sorry: your question is too wierd!!\n");
|
353 |
+
history.clear();
|
354 |
+
round = 0;
|
355 |
+
return;
|
356 |
+
}
|
357 |
+
// make sure the token count is not too large
|
358 |
+
if (tokens.size() > MAX_LEN - 10) {
|
359 |
+
// reset
|
360 |
+
if (round == 0) {
|
361 |
+
printf("Error: your question is too large!\n");
|
362 |
+
return;
|
363 |
+
}
|
364 |
+
round = 0;
|
365 |
+
history.clear();
|
366 |
+
answer(input_str);
|
367 |
+
return;
|
368 |
+
}
|
369 |
+
auto time_1 = std::chrono::system_clock::now();
|
370 |
+
int pre_token = 0;
|
371 |
+
int token = forward_first(tokens);
|
372 |
+
auto time_2 = std::chrono::system_clock::now();
|
373 |
+
std::string result;
|
374 |
+
while (token != EOS && token_length < MAX_LEN) {
|
375 |
+
std::string pre_word;
|
376 |
+
std::string word;
|
377 |
+
std::vector<int> pre_ids = {pre_token};
|
378 |
+
std::vector<int> ids = {pre_token, token};
|
379 |
+
sentencepiece.Decode(pre_ids, &pre_word);
|
380 |
+
sentencepiece.Decode(ids, &word);
|
381 |
+
std::string diff = word.substr(pre_word.size());
|
382 |
+
result += diff;
|
383 |
+
std::cout << diff << std::flush;
|
384 |
+
if (token_length < MAX_LEN) {
|
385 |
+
token_length++;
|
386 |
+
}
|
387 |
+
tok_num++;
|
388 |
+
token = forward_next();
|
389 |
+
}
|
390 |
+
auto time_3 = std::chrono::system_clock::now();
|
391 |
+
auto ftl_dur =
|
392 |
+
std::chrono::duration_cast<std::chrono::microseconds>(time_2 - time_1);
|
393 |
+
auto tps_dur =
|
394 |
+
std::chrono::duration_cast<std::chrono::microseconds>(time_3 - time_2);
|
395 |
+
double tps = tok_num / (tps_dur.count() * 1e-6);
|
396 |
+
if (token_length >= MAX_LEN) {
|
397 |
+
printf(" ......\nWarning: cleanup early history\n");
|
398 |
+
}
|
399 |
+
// double tht = tokens.size() / (tht_dur.count() * 1e-6);
|
400 |
+
printf("\nFTL:%f s, TPS: %f tokens/s\n", ftl_dur.count() * 1e-6, tps);
|
401 |
+
history.emplace_back(result);
|
402 |
+
if (token_length + 128 >= MAX_LEN) {
|
403 |
+
int num = (history.size() + 3) / 4 * 2;
|
404 |
+
history.erase(history.begin(), history.begin() + num);
|
405 |
+
}
|
406 |
+
}
|
407 |
+
|
408 |
+
static void split(const std::string &s, const std::string &delim,
|
409 |
+
std::vector<std::string> &ret) {
|
410 |
+
size_t last = 0;
|
411 |
+
size_t index = s.find_first_of(delim, last);
|
412 |
+
while (index != std::string::npos) {
|
413 |
+
ret.push_back(s.substr(last, index - last));
|
414 |
+
last = index + 1;
|
415 |
+
index = s.find_first_of(delim, last);
|
416 |
+
}
|
417 |
+
if (last < s.length()) {
|
418 |
+
ret.push_back(s.substr(last));
|
419 |
+
}
|
420 |
+
}
|
421 |
+
|
422 |
+
static std::vector<int> parseCascadeDevices(const std::string &str) {
|
423 |
+
std::vector<int> devices;
|
424 |
+
std::vector<std::string> sub_str;
|
425 |
+
split(str, ",", sub_str);
|
426 |
+
for (auto &s : sub_str) {
|
427 |
+
devices.push_back(std::atoi(s.c_str()));
|
428 |
+
}
|
429 |
+
return devices;
|
430 |
+
}
|
431 |
+
|
432 |
+
void processArguments(int argc, char *argv[], std::string &baichuan_model,
|
433 |
+
std::vector<int> &devices) {
|
434 |
+
struct option longOptions[] = {{"model", required_argument, nullptr, 'm'},
|
435 |
+
{"dev_id", required_argument, nullptr, 'd'},
|
436 |
+
{nullptr, 0, nullptr, 0}};
|
437 |
+
|
438 |
+
int optionIndex = 0;
|
439 |
+
int option;
|
440 |
+
|
441 |
+
while ((option = getopt_long(argc, argv, "m:d:", longOptions,
|
442 |
+
&optionIndex)) != -1) {
|
443 |
+
switch (option) {
|
444 |
+
case 'm':
|
445 |
+
baichuan_model = optarg;
|
446 |
+
break;
|
447 |
+
case 'd':
|
448 |
+
devices = parseCascadeDevices(optarg);
|
449 |
+
break;
|
450 |
+
case '?':
|
451 |
+
exit(EXIT_FAILURE);
|
452 |
+
default:
|
453 |
+
exit(EXIT_FAILURE);
|
454 |
+
}
|
455 |
+
}
|
456 |
+
}
|
457 |
+
|
458 |
+
int main(int argc, char **argv) {
|
459 |
+
// set your bmodel path here
|
460 |
+
printf("Demo for Baichuan2-7B in BM1684X\n");
|
461 |
+
std::string baichuan_model = "baichuan2-7b-test.bmodel";
|
462 |
+
std::vector<int> devices = {0};
|
463 |
+
processArguments(argc, argv, baichuan_model, devices);
|
464 |
+
|
465 |
+
Baichuan2 baichuan;
|
466 |
+
printf("Init Environment ...\n");
|
467 |
+
baichuan.init(devices, baichuan_model);
|
468 |
+
printf("==========================\n");
|
469 |
+
baichuan.chat();
|
470 |
+
baichuan.deinit();
|
471 |
+
return 0;
|
472 |
+
}
|
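The two forward paths above differ mainly in the attention mask layout: forward_first builds a full MAX_LEN x MAX_LEN causal mask over the padded prompt, while forward_next uses a single row of MAX_LEN + 1 slots in which only the cached prefix and the final slot (the token being decoded) stay unmasked. A NumPy sketch of the same masks, using the demo's constants:

    import numpy as np

    MAX_LEN, ATTENTION_MASK = 512, -1000.0
    token_length = 5  # example prompt length

    # forward_first: causal mask over the padded prompt
    first = np.full((MAX_LEN, MAX_LEN), ATTENTION_MASK, dtype=np.float32)
    for i in range(token_length):
        first[i, : i + 1] = 0.0  # row i attends to positions 0..i

    # forward_next: one row; the cache slots 0..token_length-2 and the
    # extra slot MAX_LEN (the new token itself) stay visible
    nxt = np.zeros(MAX_LEN + 1, dtype=np.float32)
    nxt[token_length - 1 : MAX_LEN] = ATTENTION_MASK

After each decode step the one-token present_key/present_value slice is copied back into past_key/past_value at byte offset (token_length - 1) * bytes, which is how the KV cache grows in place without reallocation.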
Baichuan2/model/tokenizer.model
ADDED
@@ -0,0 +1,3 @@
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:79452955be6b419a65984273a9f08af86042e1c2a75ee3ba989cbf620a133cc2
|
3 |
+
size 2001107
|
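The same tokenizer.model that the demo loads through sentencepiece can be sanity-checked from Python with the sentencepiece pin from requirements.txt (a quick check, not part of the pipeline; adjust the path to your checkout):

    import sentencepiece as spm

    sp = spm.SentencePieceProcessor(model_file="Baichuan2/model/tokenizer.model")
    ids = sp.encode("hello world")
    print(ids)             # token ids of the kind fed to forward_first()
    print(sp.decode(ids))  # round-trips to the input text
    print(sp.eos_id())     # the EOS id the demo stops generation on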
Baichuan2/requirements.txt
ADDED
@@ -0,0 +1,7 @@
1 |
+
torch==2.1.2
|
2 |
+
transformers==4.36.2
|
3 |
+
sentencepiece==0.1.99
|
4 |
+
gradio==3.39.0
|
5 |
+
mdtex2html==1.2.0
|
6 |
+
accelerate
|
7 |
+
onnx
|
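Note: torch, transformers and sentencepiece are pinned (presumably the combination the export and demo scripts were validated against), gradio and mdtex2html serve the web demo, and accelerate/onnx are left unpinned; pip install -r Baichuan2/requirements.txt installs the set.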
Baichuan2/src/include/bmdef.h
ADDED
@@ -0,0 +1,129 @@
1 |
+
/*****************************************************************************
|
2 |
+
*
|
3 |
+
* Copyright (c) 2016-2026 by Sophgo Technologies Inc. All rights reserved.
|
4 |
+
*
|
5 |
+
* The material in this file is confidential and contains trade secrets
|
6 |
+
* of Sophgo Technologies Inc. This is proprietary information owned by
|
7 |
+
* Sophgo Technologies Inc. No part of this work may be disclosed,
|
8 |
+
* reproduced, copied, transmitted, or used in any way for any purpose,
|
9 |
+
* without the express written permission of Sophgo Technologies Inc.
|
10 |
+
*
|
11 |
+
*****************************************************************************/
|
12 |
+
|
13 |
+
#ifndef __BMRUNTIME_DEFINE_H__
|
14 |
+
#define __BMRUNTIME_DEFINE_H__
|
15 |
+
|
16 |
+
#include "bmlib_runtime.h"
|
17 |
+
#include <stddef.h>
|
18 |
+
#include <stdint.h>
|
19 |
+
|
20 |
+
#if defined(__cplusplus)
|
21 |
+
extern "C" {
|
22 |
+
#endif
|
23 |
+
|
24 |
+
/* --------------------------------------------------------------------------*/
|
25 |
+
/* basic definitions */
|
26 |
+
|
27 |
+
/* bm_data_type_t holds the type for a scalar value */
|
28 |
+
typedef enum bm_data_type_e {
|
29 |
+
BM_FLOAT32 = 0,
|
30 |
+
BM_FLOAT16 = 1,
|
31 |
+
BM_INT8 = 2,
|
32 |
+
BM_UINT8 = 3,
|
33 |
+
BM_INT16 = 4,
|
34 |
+
BM_UINT16 = 5,
|
35 |
+
BM_INT32 = 6,
|
36 |
+
BM_UINT32 = 7,
|
37 |
+
BM_BFLOAT16 = 8,
|
38 |
+
BM_INT4 = 9,
|
39 |
+
BM_UINT4 = 10,
|
40 |
+
} bm_data_type_t;
|
41 |
+
|
42 |
+
/* store mode definitions */
|
43 |
+
typedef enum bm_store_mode_e {
|
44 |
+
BM_STORE_1N = 0, /* default, if not sure, use 0 */
|
45 |
+
BM_STORE_2N = 1,
|
46 |
+
BM_STORE_4N = 2,
|
47 |
+
} bm_store_mode_t;
|
48 |
+
|
49 |
+
/* bm_shape_t holds the shape info */
|
50 |
+
#define BM_MAX_DIMS_NUM 8
|
51 |
+
typedef struct bm_shape_s {
|
52 |
+
int num_dims;
|
53 |
+
int dims[BM_MAX_DIMS_NUM];
|
54 |
+
} bm_shape_t;
|
55 |
+
|
56 |
+
typedef struct bm_shape_ex_s {
|
57 |
+
bm_shape_t shape;
|
58 |
+
int elem_num;
|
59 |
+
} bm_shape_ex_t;
|
60 |
+
|
61 |
+
/*
|
62 |
+
bm_tensor_t holds a multi-dimensional array of elements of a single data type
|
63 |
+
and the tensor lives in device memory */
|
64 |
+
typedef struct bm_tensor_s {
|
65 |
+
bm_data_type_t dtype;
|
66 |
+
bm_shape_t shape;
|
67 |
+
bm_device_mem_t device_mem;
|
68 |
+
bm_store_mode_t st_mode; /* user can set 0 as default store mode */
|
69 |
+
} bm_tensor_t;
|
70 |
+
|
71 |
+
/* --------------------------------------------------------------------------*/
|
72 |
+
/* network information structure */
|
73 |
+
|
74 |
+
/* bm_stage_info_t holds input/output shapes and device mems; every network can contain one or more
|
75 |
+
* stages */
|
76 |
+
typedef struct bm_stage_info_s {
|
77 |
+
bm_shape_t *input_shapes; /* input_shapes[0] / [1] / ... / [input_num-1] */
|
78 |
+
bm_shape_t *output_shapes; /* output_shapes[0] / [1] / ... / [output_num-1] */
|
79 |
+
bm_device_mem_t *input_mems; /* input_mems[0] / [1] / ... / [input_num-1] */
|
80 |
+
bm_device_mem_t *output_mems; /* output_mems[0] / [1] / ... / [output_num-1] */
|
81 |
+
} bm_stage_info_t;
|
82 |
+
|
83 |
+
/* bm_net_info_t holds all information of one net.
|
84 |
+
* scale for float type is 1.0 as default */
|
85 |
+
typedef struct bm_net_info_s {
|
86 |
+
const char* name; /* net name */
|
87 |
+
bool is_dynamic; /* dynamic or static */
|
88 |
+
int input_num; /* number of inputs */
|
89 |
+
char const** input_names; /* input_names[0] / [1] / .../ [input_num-1] */
|
90 |
+
bm_data_type_t* input_dtypes; /* input_dtypes[0] / [1] / .../ [input_num-1] */
|
91 |
+
float* input_scales; /* input_scales[0] / [1] / .../ [input_num-1] */
|
92 |
+
int output_num; /* number of outputs */
|
93 |
+
char const** output_names; /* output_names[0] / [1] / .../ [output_num-1] */
|
94 |
+
bm_data_type_t* output_dtypes; /* output_dtypes[0] / [1] / .../ [output_num-1] */
|
95 |
+
float* output_scales; /* output_scales[0] / [1] / .../ [output_num-1] */
|
96 |
+
int stage_num; /* number of stages */
|
97 |
+
bm_stage_info_t* stages; /* stages[0] / [1] / ... / [stage_num-1] */
|
98 |
+
size_t* max_input_bytes; /* max_input_bytes[0]/ [1] / ... / [input_num-1] */
|
99 |
+
size_t* max_output_bytes; /* max_output_bytes[0] / [1] / ... / [output_num-1] */
|
100 |
+
int* input_zero_point; /* input_zero_point[0] / [1] / .../ [input_num-1] */
|
101 |
+
int* output_zero_point; /* output_zero_point[0] / [1] / .../ [output_num-1] */
|
102 |
+
int *input_loc_devices; /* input_loc_device[0] / [1] / .../ [input_num-1] */
|
103 |
+
int *output_loc_devices; /* output_loc_device[0] / [1] / .../ [output_num-1] */
|
104 |
+
} bm_net_info_t;
|
105 |
+
|
106 |
+
typedef struct api_info_s {
|
107 |
+
/// @brief api_id to be sent to driver
|
108 |
+
int32_t api_id;
|
109 |
+
/// @brief api data to be sent to driver
|
110 |
+
uint8_t **api_data;
|
111 |
+
/// @brief size of the api data to be sent to driver
|
112 |
+
size_t api_data_size;
|
113 |
+
/// @brief subsize of the api data to be sent to driver
|
114 |
+
size_t *api_data_subsize;
|
115 |
+
/// @brief offset of input tensors' addr in api_data
|
116 |
+
uint32_t *input_addr_offset;
|
117 |
+
/// @brief number of the offset of input tensors' addr in api_data
|
118 |
+
size_t input_addr_offset_number;
|
119 |
+
/// @brief offset of output tensors' addr in api_data
|
120 |
+
uint32_t *output_addr_offset;
|
121 |
+
/// @brief number of the offset of output tensors' addr in api_data
|
122 |
+
size_t output_addr_offset_number;
|
123 |
+
} api_info_c;
|
124 |
+
|
125 |
+
#if defined(__cplusplus)
|
126 |
+
}
|
127 |
+
#endif
|
128 |
+
|
129 |
+
#endif /* __BM_NET_H__ */
|
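These are the structures the demo consumes: Baichuan2::init above reads each network's bm_net_info_t (input/output dtypes plus stages[0] shapes) to size its bm_tensor_t buffers via bmrt_tensor, and forward_first/forward_next then bind those tensors by position when launching with bmrt_launch_tensor_ex.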
Baichuan2/src/include/bmlib_runtime.h
ADDED
@@ -0,0 +1,2581 @@
1 |
+
/*****************************************************************************
|
2 |
+
*
|
3 |
+
* Copyright (c) 2016-2026 by Bitmain Technologies Inc. All rights reserved.
|
4 |
+
*
|
5 |
+
* The material in this file is confidential and contains trade secrets
|
6 |
+
* of Bitmain Technologies Inc. This is proprietary information owned by
|
7 |
+
* Bitmain Technologies Inc. No part of this work may be disclosed,
|
8 |
+
* reproduced, copied, transmitted, or used in any way for any purpose,
|
9 |
+
* without the express written permission of Bitmain Technologies Inc.
|
10 |
+
*
|
11 |
+
*****************************************************************************/
|
12 |
+
|
13 |
+
/**************************************************************************
|
14 |
+
* bmlib_runtime defines the interfaces for operating TPU devices.
|
15 |
+
* The functions can be divided into several categories.
|
16 |
+
* 1) device handle creation and destroy
|
17 |
+
* 2) memory help functions
|
18 |
+
* 3) global memory allocation and free
|
19 |
+
* 4) data transfer between host and device
|
20 |
+
* 5) data transfer within device memory
|
21 |
+
* 6) api send and synchronization
|
22 |
+
* 7) global memory map and coherence
|
23 |
+
* 8) trace and profile
|
24 |
+
* 9) power management
|
25 |
+
* 10) miscellaneous functions
|
26 |
+
*************************************************************************/
|
27 |
+
|
28 |
+
#ifndef BMLIB_RUNTIME_H_
|
29 |
+
#define BMLIB_RUNTIME_H_
|
30 |
+
#if defined(_WIN32) && !defined(__MINGW32__)
|
31 |
+
#include <vadefs.h>
|
32 |
+
#define DECL_EXPORT __declspec(dllexport)
|
33 |
+
#define DECL_IMPORT __declspec(dllimport)
|
34 |
+
#else
|
35 |
+
#include <stdbool.h>
|
36 |
+
#include <stddef.h>
|
37 |
+
#include <stdarg.h>
|
38 |
+
#define DECL_EXPORT
|
39 |
+
#define DECL_IMPORT
|
40 |
+
#endif
|
41 |
+
|
42 |
+
#if defined(__cplusplus)
|
43 |
+
extern "C" {
|
44 |
+
#endif
|
45 |
+
|
46 |
+
typedef enum {
|
47 |
+
MODULE_CDMA = 0,
|
48 |
+
MODULE_GDMA = 1,
|
49 |
+
MODULE_TPU = 2,
|
50 |
+
MODULE_SMMU = 3,
|
51 |
+
MODULE_SRAM = 4,
|
52 |
+
MODULE_END = 5
|
53 |
+
} MODULE_ID;
|
54 |
+
|
55 |
+
#define BM_MEM_ADDR_NULL (0xfffffffff)
|
56 |
+
|
57 |
+
#ifndef BM_MEM_DESC_T_
|
58 |
+
#define BM_MEM_DESC_T_
|
59 |
+
/* BM function return code definitions */
|
60 |
+
typedef enum {
|
61 |
+
BM_SUCCESS = 0,
|
62 |
+
BM_ERR_DEVNOTREADY = 1, /* Device not ready yet */
|
63 |
+
BM_ERR_FAILURE = 2, /* General failure */
|
64 |
+
BM_ERR_TIMEOUT = 3, /* Timeout */
|
65 |
+
BM_ERR_PARAM = 4, /* Parameters invalid */
|
66 |
+
BM_ERR_NOMEM = 5, /* Not enough memory */
|
67 |
+
BM_ERR_DATA = 6, /* Data error */
|
68 |
+
BM_ERR_BUSY = 7, /* Busy */
|
69 |
+
BM_ERR_NOFEATURE = 8, /* Not supported yet */
|
70 |
+
BM_NOT_SUPPORTED = 9
|
71 |
+
} bm_status_t;
|
72 |
+
|
73 |
+
/* BM memory type definitions */
|
74 |
+
typedef enum {
|
75 |
+
BM_MEM_TYPE_DEVICE = 0,
|
76 |
+
BM_MEM_TYPE_HOST = 1,
|
77 |
+
BM_MEM_TYPE_SYSTEM = 2,
|
78 |
+
BM_MEM_TYPE_INT8_DEVICE = 3,
|
79 |
+
BM_MEM_TYPE_INVALID = 4
|
80 |
+
} bm_mem_type_t;
|
81 |
+
|
82 |
+
typedef enum {
|
83 |
+
PERF_MONITOR_GDMA = 0,
|
84 |
+
PERF_MONITOR_TPU = 1
|
85 |
+
} PERF_MONITOR_ID;
|
86 |
+
|
87 |
+
typedef enum {
|
88 |
+
BMCPU_IDLE = 0,
|
89 |
+
BMCPU_RUNNING = 1,
|
90 |
+
BMCPU_FAULT = 2
|
91 |
+
} bm_cpu_status_t;
|
92 |
+
|
93 |
+
/*
|
94 |
+
* bm performance monitor
|
95 |
+
*/
|
96 |
+
typedef struct bm_perf_monitor {
|
97 |
+
long long buffer_start_addr; /*buffer address to store perf data*/
|
98 |
+
int buffer_size; /*buffer size*/
|
99 |
+
PERF_MONITOR_ID monitor_id; /*PERF_MONITOR_GDMA or PERF_MONITOR_TPU*/
|
100 |
+
} bm_perf_monitor_t;
|
101 |
+
|
102 |
+
typedef union {
|
103 |
+
struct {
|
104 |
+
bm_mem_type_t mem_type : 3;
|
105 |
+
unsigned int gmem_heapid : 3;
|
106 |
+
unsigned int reserved : 26;
|
107 |
+
} u;
|
108 |
+
unsigned int rawflags;
|
109 |
+
} bm_mem_flags_t;
|
110 |
+
|
111 |
+
/* BM memory descriptor definition*/
|
112 |
+
typedef struct bm_mem_desc {
|
113 |
+
union {
|
114 |
+
struct {
|
115 |
+
#ifdef __linux__
|
116 |
+
unsigned long device_addr;
|
117 |
+
#else
|
118 |
+
unsigned long long device_addr;
|
119 |
+
#endif
|
120 |
+
unsigned int reserved;
|
121 |
+
int dmabuf_fd;
|
122 |
+
} device;
|
123 |
+
|
124 |
+
struct {
|
125 |
+
void *system_addr;
|
126 |
+
unsigned int reserved0;
|
127 |
+
int reserved1;
|
128 |
+
} system;
|
129 |
+
} u;
|
130 |
+
|
131 |
+
bm_mem_flags_t flags;
|
132 |
+
unsigned int size;
|
133 |
+
} bm_mem_desc_t;
|
134 |
+
|
135 |
+
typedef struct bm_mem_desc bm_device_mem_t;
|
136 |
+
typedef struct bm_mem_desc bm_system_mem_t;
|
137 |
+
|
138 |
+
typedef struct sg_mem_desc {
|
139 |
+
union {
|
140 |
+
struct {
|
141 |
+
#ifdef __linux__
|
142 |
+
unsigned long device_addr;
|
143 |
+
#else
|
144 |
+
unsigned long long device_addr;
|
145 |
+
#endif
|
146 |
+
unsigned int reserved;
|
147 |
+
int dmabuf_fd;
|
148 |
+
} device;
|
149 |
+
|
150 |
+
struct {
|
151 |
+
void *system_addr;
|
152 |
+
unsigned int reserved0;
|
153 |
+
int reserved1;
|
154 |
+
} system;
|
155 |
+
} u;
|
156 |
+
|
157 |
+
bm_mem_flags_t flags;
|
158 |
+
unsigned long long size;
|
159 |
+
} sg_mem_desc_t;
|
160 |
+
|
161 |
+
typedef struct sg_mem_desc sg_device_mem_t;
|
162 |
+
typedef struct sg_mem_desc sg_system_mem_t;
|
163 |
+
#endif
|
164 |
+
|
165 |
+
struct bm_context;
|
166 |
+
typedef struct bm_context *bm_handle_t;
|
167 |
+
|
168 |
+
#define MD5SUM_LEN 16
|
169 |
+
#define LIB_MAX_NAME_LEN 64
|
170 |
+
#define FUNC_MAX_NAME_LEN 64
|
171 |
+
|
172 |
+
typedef struct bm_module
|
173 |
+
{
|
174 |
+
// void *lib_handle;
|
175 |
+
char lib_name[LIB_MAX_NAME_LEN];
|
176 |
+
unsigned char md5[MD5SUM_LEN];
|
177 |
+
}bm_module;
|
178 |
+
|
179 |
+
typedef struct bm_module *tpu_kernel_module_t;
|
180 |
+
typedef int tpu_kernel_function_t;
|
181 |
+
|
182 |
+
/**
|
183 |
+
* @name tpu_kernel_load_module_file
|
184 |
+
* @brief To load dyn file
|
185 |
+
* @ingroup bmlib_runtime
|
186 |
+
*
|
187 |
+
* @param [in] handle The device handle
|
188 |
+
* @param [in] module_file dyn file
|
189 |
+
* @retval dyn lib ptr
|
190 |
+
*/
|
191 |
+
tpu_kernel_module_t tpu_kernel_load_module_file(bm_handle_t handle, const char *module_file);
|
192 |
+
|
193 |
+
/**
|
194 |
+
* @name tpu_kernel_load_module_file_key
|
195 |
+
* @brief To load dyn file with key
|
196 |
+
* @ingroup bmlib_runtime
|
197 |
+
*
|
198 |
+
* @param [in] handle The device handle
|
199 |
+
* @param [in] module_file dyn file
|
200 |
+
* @param [in] key identification str
|
201 |
+
* @param [in] size key size
|
202 |
+
* @retval dyn lib ptr
|
203 |
+
*/
|
204 |
+
tpu_kernel_module_t tpu_kernel_load_module_file_key(bm_handle_t handle, const char *module_file, const char *key, int size);
|
205 |
+
|
206 |
+
/**
|
207 |
+
* @name tpu_kernel_unload_module
|
208 |
+
* @brief To unload dyn file
|
209 |
+
* @ingroup bmlib_runtime
|
210 |
+
*
|
211 |
+
* @param [in] handle The device handle
|
212 |
+
* @param [in] p_module dyn lib ptr
|
213 |
+
* @retval BM_SUCCESS Succeeds.
|
214 |
+
* Other code Fails.
|
215 |
+
*/
|
216 |
+
bm_status_t tpu_kernel_unload_module(bm_handle_t handle, tpu_kernel_module_t p_module);
|
217 |
+
|
218 |
+
/**
|
219 |
+
* @name tpu_kernel_free_module
|
220 |
+
* @brief To free p_module when it is no longer used
|
221 |
+
* @ingroup bmlib_runtime
|
222 |
+
*
|
223 |
+
* @param [in] handle The device handle
|
224 |
+
* @param [in] p_module dyn lib ptr
|
225 |
+
* @retval BM_SUCCESS Succeeds.
|
226 |
+
* Other code Fails.
|
227 |
+
*/
|
228 |
+
bm_status_t tpu_kernel_free_module(bm_handle_t handle, tpu_kernel_module_t p_module);
|
229 |
+
|
230 |
+
/**
|
231 |
+
* @name tpu_kernel_load_module
|
232 |
+
* @brief To load dyn module
|
233 |
+
* @ingroup bmlib_runtime
|
234 |
+
*
|
235 |
+
* @param [in] handle The device handle
|
236 |
+
* @param [in] data dyn module
|
237 |
+
* @param [in] length dyn module size
|
238 |
+
* @retval dyn lib ptr
|
239 |
+
*/
|
240 |
+
tpu_kernel_module_t tpu_kernel_load_module(bm_handle_t handle, const char *data, size_t length);
|
241 |
+
|
242 |
+
/**
|
243 |
+
* @name tpu_kernel_get_function
|
244 |
+
* @brief To get function from lib
|
245 |
+
* @ingroup bmlib_runtime
|
246 |
+
*
|
247 |
+
* @param [in] handle The device handle
|
248 |
+
* @param [in] module dyn module
|
249 |
+
* @param [in] function function name
|
250 |
+
* @retval function id
|
251 |
+
*/
|
252 |
+
tpu_kernel_function_t tpu_kernel_get_function(bm_handle_t handle, tpu_kernel_module_t module, const char *function);
|
253 |
+
|
254 |
+
/**
|
255 |
+
* @name tpu_kernel_launch
|
256 |
+
* @brief To launch a function synchronously
|
257 |
+
* @ingroup bmlib_runtime
|
258 |
+
*
|
259 |
+
* @param [in] handle The device handle
|
260 |
+
* @param [in] function function id
|
261 |
+
* @param [in] args funtion args
|
262 |
+
* @param [in] size args size
|
263 |
+
* @retval BM_SUCCESS Succeeds.
|
264 |
+
* Other code Fails.
|
265 |
+
*/
|
266 |
+
bm_status_t tpu_kernel_launch(bm_handle_t handle, tpu_kernel_function_t function, void *args, size_t size);
|
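/*
 * Example (non-normative sketch): load a kernel module, resolve a function and
 * launch it synchronously. The file name "my_kernel.so", the function name
 * "my_func" and the my_args_t layout are hypothetical, not part of this API.
 *
 *   typedef struct { int n; unsigned long long input_addr; } my_args_t;
 *   my_args_t args = { 1, bm_mem_get_device_addr(input_mem) };
 *   tpu_kernel_module_t mod = tpu_kernel_load_module_file(handle, "my_kernel.so");
 *   tpu_kernel_function_t fid = tpu_kernel_get_function(handle, mod, "my_func");
 *   bm_status_t ret = tpu_kernel_launch(handle, fid, &args, sizeof(args));
 *   tpu_kernel_unload_module(handle, mod);
 */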
267 |
+
|
268 |
+
/**
|
269 |
+
* @name tpu_kernel_launch_async
|
270 |
+
* @brief To launch a function asynchronously
|
271 |
+
* @ingroup bmlib_runtime
|
272 |
+
*
|
273 |
+
* @param [in] handle The device handle
|
274 |
+
* @param [in] function function id
|
275 |
+
* @param [in] args function args
|
276 |
+
* @param [in] size args size
|
277 |
+
* @retval BM_SUCCESS Succeeds.
|
278 |
+
* Other code Fails.
|
279 |
+
*/
|
280 |
+
bm_status_t tpu_kernel_launch_async(bm_handle_t handle, tpu_kernel_function_t function, void *args, size_t size);
|
281 |
+
|
282 |
+
/**
|
283 |
+
* @name tpu_kernel_launch_async_multi_cores
|
284 |
+
* @brief To launch a function asynchronously on multiple cores
|
285 |
+
* @ingroup bmlib_runtime
|
286 |
+
*
|
287 |
+
* @param [in] handle The device handle
|
288 |
+
* @param [in] func_name function name
|
289 |
+
* @param [in] api_param function params
|
290 |
+
* @param [in] api_size params size
|
291 |
+
* @param [in] core_list list of core ids
|
292 |
+
* @param [in] core_num number of cores
|
293 |
+
* @retval BM_SUCCESS Succeeds.
|
294 |
+
* Other code Fails.
|
295 |
+
*/
|
296 |
+
bm_status_t tpu_kernel_launch_async_multi_cores(bm_handle_t handle, const char *func_name, const void *api_param,
|
297 |
+
size_t api_size, const int* core_list, const int core_num);
|
298 |
+
|
299 |
+
/**
|
300 |
+
* @name tpu_kernel_launch_sync_multi_cores
|
301 |
+
* @brief To launch a function synchronously on multiple cores
|
302 |
+
* @ingroup bmlib_runtime
|
303 |
+
*
|
304 |
+
* @param [in] handle The device handle
|
305 |
+
* @param [in] func_name function name
|
306 |
+
* @param [in] api_param function params
|
307 |
+
* @param [in] api_size params size
|
308 |
+
* @param [in] core_list list of core ids
|
309 |
+
* @param [in] core_num number of cores
|
310 |
+
* @retval BM_SUCCESS Succeeds.
|
311 |
+
* Other code Fails.
|
312 |
+
*/
|
313 |
+
bm_status_t tpu_kernel_launch_sync_multi_cores(bm_handle_t handle, const char *func_name, const void *api_param,
|
314 |
+
size_t api_size, const int* core_list, const int core_num);
|
315 |
+
|
316 |
+
/**
|
317 |
+
* @name tpu_kernel_sync
|
318 |
+
* @brief To sync
|
319 |
+
* @ingroup bmlib_runtime
|
320 |
+
*
|
321 |
+
* @param [in] handle The device handle
|
322 |
+
* @retval BM_SUCCESS Succeeds.
|
323 |
+
* Other code Fails.
|
324 |
+
*/
|
325 |
+
bm_status_t tpu_kernel_sync(bm_handle_t handle);
|
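/*
 * Example (non-normative sketch): the asynchronous launch pairs with
 * tpu_kernel_sync. "fid" and "args" are assumed to come from the synchronous
 * launch example above.
 *
 *   tpu_kernel_launch_async(handle, fid, &args, sizeof(args));
 *   // ... overlap host-side work here ...
 *   tpu_kernel_sync(handle);   // block until the async launch has finished
 */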
326 |
+
void show_md5(unsigned char md5[]);
|
327 |
+
|
328 |
+
DECL_EXPORT void bmlib_log(const char *tag, int level, const char *fmt, ...);
|
329 |
+
|
330 |
+
#ifndef USING_CMODEL
|
331 |
+
#define BM_CHECK_RET(call) \
|
332 |
+
do { \
|
333 |
+
bm_status_t ret = (bm_status_t)call; \
|
334 |
+
if (ret != BM_SUCCESS) { \
|
335 |
+
bmlib_log("BM_CHECK",16,"BM_CHECK_RET fail %s: %s: %d\n", __FILE__, __func__, __LINE__); \
|
336 |
+
return ret; \
|
337 |
+
} \
|
338 |
+
} while (0)
|
339 |
+
#else
|
340 |
+
#define BM_CHECK_RET(call) \
|
341 |
+
do { \
|
342 |
+
bm_status_t ret = call; \
|
343 |
+
if (ret != BM_SUCCESS) { \
|
344 |
+
bmlib_log("BM_CHECK",16,"BM_CHECK_RET failed %d\n", ret);\
|
345 |
+
ASSERT(0); \
|
346 |
+
exit(-ret); \
|
347 |
+
} \
|
348 |
+
} while (0)
|
349 |
+
#endif
|
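/*
 * Example (non-normative sketch): BM_CHECK_RET is intended for functions that
 * themselves return bm_status_t; on failure it logs and propagates the error
 * (or asserts and exits in cmodel builds).
 *
 *   bm_status_t upload(bm_handle_t handle, bm_device_mem_t dst, void *src) {
 *       BM_CHECK_RET(bm_memcpy_s2d(handle, dst, src));
 *       return BM_SUCCESS;
 *   }
 */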
350 |
+
|
351 |
+
/*******************handle related functions **********************************/
|
352 |
+
/**
|
353 |
+
* @name bm_dev_getcount
|
354 |
+
* @brief To get the number of sophon devices in the system.
|
355 |
+
* If N is returned, valid devids are [0, N-1]
|
356 |
+
* @ingroup bmlib_runtime
|
357 |
+
*
|
358 |
+
* @param [out] count The result number of sophon devices
|
359 |
+
* @retval BM_SUCCESS Succeeds.
|
360 |
+
* Other code Fails.
|
361 |
+
*/
|
362 |
+
DECL_EXPORT bm_status_t bm_dev_getcount(int *count);
|
363 |
+
|
364 |
+
/**
|
365 |
+
* @name bm_dev_query
|
366 |
+
* @brief To query if a device is present
|
367 |
+
* @ingroup bmlib_runtime
|
368 |
+
*
|
369 |
+
* @param [in] devid The id of the device to query
|
370 |
+
* @retval BM_SUCCESS Device is present
|
371 |
+
* Other code Device is not present
|
372 |
+
*/
|
373 |
+
DECL_EXPORT bm_status_t bm_dev_query(int devid);
|
374 |
+
|
375 |
+
/**
|
376 |
+
* @name bm_dev_request
|
377 |
+
* @brief To create a handle for the given device
|
378 |
+
* @ingroup bmlib_runtime
|
379 |
+
*
|
380 |
+
* @param [out] handle The created handle
|
381 |
+
* @param [in] devid Specify on which device to create handle
|
382 |
+
* @retval BM_SUCCESS Succeeds.
|
383 |
+
* Other code Fails.
|
384 |
+
*/
|
385 |
+
DECL_EXPORT bm_status_t bm_dev_request(bm_handle_t *handle, int devid);
|
386 |
+
|
387 |
+
/**
|
388 |
+
* @name bm_get_devid
|
389 |
+
* @brief To get device index for the given handle
|
390 |
+
* @ingroup bmlib_runtime
|
391 |
+
*
|
392 |
+
* @param [in] handle The given handle
|
393 |
+
* @retval int device index that the handle points to.
|
394 |
+
*/
|
395 |
+
DECL_EXPORT int bm_get_devid(bm_handle_t handle);
|
396 |
+
|
397 |
+
/**
|
398 |
+
* @name bm_dev_free
|
399 |
+
* @brief To free a handle
|
400 |
+
* @ingroup bmlib_runtime
|
401 |
+
*
|
402 |
+
* @param [in] handle The handle to free
|
403 |
+
*/
|
404 |
+
DECL_EXPORT void bm_dev_free(bm_handle_t handle);
|
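/*
 * Example (non-normative sketch): enumerate devices, then acquire and release
 * a handle on device 0.
 *
 *   int count = 0;
 *   bm_handle_t handle = NULL;
 *   if (bm_dev_getcount(&count) == BM_SUCCESS && count > 0 &&
 *       bm_dev_request(&handle, 0) == BM_SUCCESS) {
 *       // ... use the handle ...
 *       bm_dev_free(handle);
 *   }
 */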
405 |
+
|
406 |
+
/*******************memory helper functions ***********************************/
|
407 |
+
/**
|
408 |
+
* @name bm_mem_get_type
|
409 |
+
* @brief To get a memory descriptor's type
|
410 |
+
* @ingroup bmlib_runtime
|
411 |
+
*
|
412 |
+
* @param [in] mem The memory descriptor queried
|
413 |
+
* @retval BM_MEM_TYPE_DEVICE Device global memory
|
414 |
+
* @retval BM_MEM_TYPE_SYSTEM Host user memory
|
415 |
+
*/
|
416 |
+
DECL_EXPORT bm_mem_type_t bm_mem_get_type(struct bm_mem_desc mem);
|
417 |
+
|
418 |
+
/**
|
419 |
+
* @name sg_mem_get_type
|
420 |
+
* @brief To get a memory descriptor's type
|
421 |
+
* @ingroup bmlib_runtime
|
422 |
+
*
|
423 |
+
* @param [in] mem The memory descriptor queried
|
424 |
+
* @retval BM_MEM_TYPE_DEVICE Device global memory
|
425 |
+
* @retval BM_MEM_TYPE_SYSTEM Host user memory
|
426 |
+
*/
|
427 |
+
DECL_EXPORT bm_mem_type_t sg_mem_get_type(struct sg_mem_desc mem);
|
428 |
+
|
429 |
+
/**
|
430 |
+
* @name bm_mem_get_device_addr
|
431 |
+
* @brief To get a device memory descriptor's address
|
432 |
+
* @ingroup bmlib_runtime
|
433 |
+
*
|
434 |
+
* @param [in] mem The device memory descriptor queried
|
435 |
+
* @retval unsigned long long The device memory address
|
436 |
+
*/
|
437 |
+
DECL_EXPORT unsigned long long bm_mem_get_device_addr(struct bm_mem_desc mem);
|
438 |
+
|
439 |
+
/**
|
440 |
+
* @name sg_mem_get_device_addr
|
441 |
+
* @brief To get a device memory descriptor's address
|
442 |
+
* @ingroup bmlib_runtime
|
443 |
+
*
|
444 |
+
* @param [in] mem The device memory descriptor queried
|
445 |
+
* @retval unsigned long long The device memory address
|
446 |
+
*/
|
447 |
+
DECL_EXPORT unsigned long long sg_mem_get_device_addr(struct sg_mem_desc mem);
|
448 |
+
|
449 |
+
/**
|
450 |
+
* @name bm_mem_set_device_addr
|
451 |
+
* @brief To set a device memory descriptor's address
|
452 |
+
* @ingroup bmlib_runtime
|
453 |
+
*
|
454 |
+
* @param [in] pmem The device memory descriptor pointer
|
455 |
+
* @param [in] addr The new device address of the device memory
|
456 |
+
*/
|
457 |
+
DECL_EXPORT void bm_mem_set_device_addr(struct bm_mem_desc* pmem, unsigned long long addr);
|
458 |
+
|
459 |
+
/**
|
460 |
+
* @name sg_mem_set_device_addr
|
461 |
+
* @brief To set a device memory descriptor's address
|
462 |
+
* @ingroup bmlib_runtime
|
463 |
+
*
|
464 |
+
* @param [in] pmem The device memory descriptor pointer
|
465 |
+
* @param [in] addr The new device address of the device memory
|
466 |
+
*/
|
467 |
+
DECL_EXPORT void sg_mem_set_device_addr(struct sg_mem_desc* pmem, unsigned long long addr);
|
468 |
+
|
469 |
+
/**
|
470 |
+
* @name bm_mem_get_device_size
|
471 |
+
* @brief To get a device memory descriptor's size
|
472 |
+
* @ingroup bmlib_runtime
|
473 |
+
*
|
474 |
+
* @param [in] mem The device memory descriptor queried
|
475 |
+
* @retval unsigned int The device memory's size in bytes
|
476 |
+
*/
|
477 |
+
DECL_EXPORT unsigned int bm_mem_get_device_size(struct bm_mem_desc mem);
|
478 |
+
|
479 |
+
/**
|
480 |
+
* @name sg_mem_get_device_size
|
481 |
+
* @brief To get a device memory descriptor's size
|
482 |
+
* @ingroup bmlib_runtime
|
483 |
+
*
|
484 |
+
* @param [in] mem The device memory descriptor queried
|
485 |
+
* @retval unsigned long long The device memory's size in bytes
|
486 |
+
*/
|
487 |
+
DECL_EXPORT unsigned long long sg_mem_get_device_size(struct sg_mem_desc mem);
|
488 |
+
|
489 |
+
/**
|
490 |
+
* @name bm_mem_set_device_size
|
491 |
+
* @brief To set a device memory descriptor's size
|
492 |
+
* @ingroup bmlib_runtime
|
493 |
+
*
|
494 |
+
* @param [out] pmem The device memory descriptor pointer
|
495 |
+
* @param [in] size The new device memory size (in bytes) of the device memory
|
496 |
+
*/
|
497 |
+
DECL_EXPORT void bm_mem_set_device_size(struct bm_mem_desc* pmem, unsigned int size);
|
498 |
+
|
499 |
+
/**
|
500 |
+
* @name sg_mem_set_device_size
|
501 |
+
* @brief To set a device memory descriptor's size
|
502 |
+
* @ingroup bmlib_runtime
|
503 |
+
*
|
504 |
+
* @param [out] pmem The device memory descriptor pointer
|
505 |
+
* @param [in] size The new device memory size (in bytes) of the device memory
|
506 |
+
*/
|
507 |
+
DECL_EXPORT void sg_mem_set_device_size(struct sg_mem_desc* pmem, unsigned long long size);
|
508 |
+
|
509 |
+
/**
|
510 |
+
* @name bm_set_device_mem
|
511 |
+
* @brief To fill in a device memory descriptor with size and address
|
512 |
+
* @ingroup bmlib_runtime
|
513 |
+
*
|
514 |
+
* @param [in] pmem The device memory descriptor pointer
|
515 |
+
* @param [in] size The device memory descriptor's size
|
516 |
+
* @param [in] addr The device memory descriptor's address
|
517 |
+
*/
|
518 |
+
DECL_EXPORT void bm_set_device_mem(bm_device_mem_t* pmem, unsigned int size,
|
519 |
+
unsigned long long addr);
|
520 |
+
|
521 |
+
/**
|
522 |
+
* @name sg_set_device_mem
|
523 |
+
* @brief To fill in a device memory descriptor with size and address
|
524 |
+
* @ingroup bmlib_runtime
|
525 |
+
*
|
526 |
+
* @param [in] pmem The device memory descriptor pointer
|
527 |
+
* @param [in] size The device memory descriptor's size
|
528 |
+
* @param [in] addr The device memory descriptor's address
|
529 |
+
*/
|
530 |
+
DECL_EXPORT void sg_set_device_mem(sg_device_mem_t* pmem, unsigned long long size,
|
531 |
+
unsigned long long addr);
|
532 |
+
|
533 |
+
/**
|
534 |
+
* @name bm_mem_from_device
|
535 |
+
* @brief To create a device memory descriptor from address and size
|
536 |
+
* @ingroup bmlib_runtime
|
537 |
+
*
|
538 |
+
* @param [in] device_addr The device memory address
|
539 |
+
* @param [in] len The device memory size
|
540 |
+
* @retval bm_device_mem_t The device memory descriptor created
|
541 |
+
*/
|
542 |
+
DECL_EXPORT bm_device_mem_t bm_mem_from_device(unsigned long long device_addr,
|
543 |
+
unsigned int len);
|
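/*
 * Example (non-normative sketch): wrap an already-known device address in a
 * descriptor, e.g. to read back a sub-buffer; the address 0x100000000ULL and
 * host_buf (a host buffer of at least 4096 bytes) are purely illustrative.
 *
 *   bm_device_mem_t view = bm_mem_from_device(0x100000000ULL, 4096);
 *   bm_memcpy_d2s(handle, host_buf, view);
 */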
544 |
+
|
545 |
+
/**
|
546 |
+
* @name sg_mem_from_device
|
547 |
+
* @brief To create a device memory descriptor from address and size
|
548 |
+
* @ingroup bmlib_runtime
|
549 |
+
*
|
550 |
+
* @param [in] device_addr The device memory address
|
551 |
+
* @param [in] len The device memory size
|
552 |
+
* @retval sg_device_mem_t The device memory descriptor created
|
553 |
+
*/
|
554 |
+
DECL_EXPORT sg_device_mem_t sg_mem_from_device(unsigned long long device_addr,
|
555 |
+
unsigned long long len);
|
556 |
+
|
557 |
+
/**
|
558 |
+
* @name bm_mem_get_system_addr
|
559 |
+
* @brief To get a system memory descriptor's address
|
560 |
+
* @ingroup bmlib_runtime
|
561 |
+
*
|
562 |
+
* @param [in] mem The system memory descriptor
|
563 |
+
* @retval void * The system memory descriptor's address
|
564 |
+
*/
|
565 |
+
DECL_EXPORT void *bm_mem_get_system_addr(struct bm_mem_desc mem);
|
566 |
+
|
567 |
+
/**
|
568 |
+
* @name sg_mem_get_system_addr
|
569 |
+
* @brief To get a system memory descriptor's address
|
570 |
+
* @ingroup bmlib_runtime
|
571 |
+
*
|
572 |
+
* @param [in] mem The system memory descriptor
|
573 |
+
* @retval void * The system memory descriptor's address
|
574 |
+
*/
|
575 |
+
DECL_EXPORT void *sg_mem_get_system_addr(struct sg_mem_desc mem);
|
576 |
+
|
577 |
+
/**
|
578 |
+
* @name bm_mem_set_system_addr
|
579 |
+
* @brief To set a system memory descriptor's address
|
580 |
+
* @ingroup bmlib_runtime
|
581 |
+
*
|
582 |
+
* @param [in] pmem The system memory descriptor pointer
|
583 |
+
* @param [in] addr The system memory address
|
584 |
+
*/
|
585 |
+
DECL_EXPORT void bm_mem_set_system_addr(struct bm_mem_desc* pmem, void *addr);
|
586 |
+
|
587 |
+
/**
|
588 |
+
* @name sg_mem_set_system_addr
|
589 |
+
* @brief To set a system memory descriptor's address
|
590 |
+
* @ingroup bmlib_runtime
|
591 |
+
*
|
592 |
+
* @param [in] pmem The system memory descriptor pointer
|
593 |
+
* @param [in] addr The system memory address
|
594 |
+
*/
|
595 |
+
DECL_EXPORT void sg_mem_set_system_addr(struct sg_mem_desc* pmem, void *addr);
|
596 |
+
|
597 |
+
/**
|
598 |
+
* @name bm_mem_from_system
|
599 |
+
* @brief To create a system memory descriptor with the given system address
|
600 |
+
* @ingroup bmlib_runtime
|
601 |
+
*
|
602 |
+
* @param [in] system_addr The system address in the descriptor
|
603 |
+
* @retval bm_system_mem_t The system memory descriptor created
|
604 |
+
*/
|
605 |
+
DECL_EXPORT bm_system_mem_t bm_mem_from_system(void *system_addr);
|
606 |
+
|
607 |
+
/*******************memory alloc and free functions ***************************/
|
608 |
+
/**
|
609 |
+
* @name bm_mem_null
|
610 |
+
* @brief Return an invalid device memory descriptor
|
611 |
+
* @ingroup bmlib_runtime
|
612 |
+
*
|
613 |
+
* @retval bm_device_mem_t An invalid device memory descriptor
|
614 |
+
*/
|
615 |
+
DECL_EXPORT bm_device_mem_t bm_mem_null(void);
|
616 |
+
#define BM_MEM_NULL (bm_mem_null())
|
617 |
+
|
618 |
+
/**
|
619 |
+
* @name bm_malloc_neuron_device
|
620 |
+
* @brief To malloc device memory according to a tensor shape
|
621 |
+
* (each neuron is 32 bits)
|
622 |
+
* @ingroup bmlib_runtime
|
623 |
+
*
|
624 |
+
* @param [in] handle The device handle
|
625 |
+
* @param [out] pmem The result device memory descriptor
|
626 |
+
* @param [in] n, c, h, w The shape of the input tensor
|
627 |
+
* @retval BM_SUCCESS Succeeds.
|
628 |
+
* Other code Fails.
|
629 |
+
*/
|
630 |
+
DECL_EXPORT bm_status_t bm_malloc_neuron_device(bm_handle_t handle, bm_device_mem_t *pmem,
|
631 |
+
int n, int c, int h, int w);
|
632 |
+
|
633 |
+
/**
|
634 |
+
* @name sg_malloc_neuron_device
|
635 |
+
* @brief To malloc device memory according to a tensor shape
|
636 |
+
* (each neuron is 32 bits)
|
637 |
+
* @ingroup bmlib_runtime
|
638 |
+
*
|
639 |
+
* @param [in] handle The device handle
|
640 |
+
* @param [out] pmem The result device memory descriptor
|
641 |
+
* @param [in] n, c, h, w The shape of the input tensor
|
642 |
+
* @retval BM_SUCCESS Succeeds.
|
643 |
+
* Other code Fails.
|
644 |
+
*/
|
645 |
+
DECL_EXPORT bm_status_t sg_malloc_neuron_device(bm_handle_t handle, sg_device_mem_t *pmem,
|
646 |
+
unsigned long long n, unsigned long long c,
|
647 |
+
unsigned long long h, unsigned long long w);
|
648 |
+
|
649 |
+
/**
|
650 |
+
* @name bm_malloc_device_dword
|
651 |
+
* @brief To malloc device memory in size of dword (32 bits)
|
652 |
+
* @ingroup bmlib_runtime
|
653 |
+
*
|
654 |
+
* @param [in] handle The device handle
|
655 |
+
* @param [out] pmem The result device memory descriptor
|
656 |
+
* @param [in] count The number of dwords (32 bits) to allocate
|
657 |
+
* @retval BM_SUCCESS Succeeds.
|
658 |
+
* Other code Fails.
|
659 |
+
*/
|
660 |
+
DECL_EXPORT bm_status_t bm_malloc_device_dword(bm_handle_t handle, bm_device_mem_t *pmem,
|
661 |
+
int count);
|
662 |
+
|
663 |
+
/**
|
664 |
+
* @name sg_malloc_device_dword
|
665 |
+
* @brief To malloc device memory in size of dword (32 bits)
|
666 |
+
* @ingroup bmlib_runtime
|
667 |
+
*
|
668 |
+
* @param [in] handle The device handle
|
669 |
+
* @param [out] pmem The result device memory descriptor
|
670 |
+
* @param [in] count The number of dwords (32 bits) to allocate
|
671 |
+
* @retval BM_SUCCESS Succeeds.
|
672 |
+
* Other code Fails.
|
673 |
+
*/
|
674 |
+
DECL_EXPORT bm_status_t sg_malloc_device_dword(bm_handle_t handle, sg_device_mem_t *pmem,
|
675 |
+
unsigned long long count);
|
676 |
+
|
677 |
+
/**
|
678 |
+
* @name bm_malloc_device_byte
|
679 |
+
* @brief To malloc device memory in size of byte
|
680 |
+
* @ingroup bmlib_runtime
|
681 |
+
*
|
682 |
+
* @param [in] handle The device handle
|
683 |
+
* @param [out] pmem The result device memory descriptor
|
684 |
+
* @param [in] size The number of bytes to allocate
|
685 |
+
* @retval BM_SUCCESS Succeeds.
|
686 |
+
* Other code Fails.
|
687 |
+
*/
|
688 |
+
DECL_EXPORT bm_status_t bm_malloc_device_byte(bm_handle_t handle, bm_device_mem_t *pmem,
|
689 |
+
unsigned int size);
|
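/*
 * Example (non-normative sketch): allocate and release a 1 MiB device buffer
 * (bm_free_device is declared further below).
 *
 *   bm_device_mem_t mem;
 *   if (bm_malloc_device_byte(handle, &mem, 1 << 20) == BM_SUCCESS) {
 *       // ... use mem ...
 *       bm_free_device(handle, mem);
 *   }
 */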
690 |
+
|
691 |
+
/**
|
692 |
+
* @name sg_malloc_device_byte
|
693 |
+
* @brief To malloc device memory in size of byte
|
694 |
+
* @ingroup bmlib_runtime
|
695 |
+
*
|
696 |
+
* @param [in] handle The device handle
|
697 |
+
* @param [out] pmem The result device memory descriptor
|
698 |
+
* @param [in] size The number of bytes to allocate
|
699 |
+
* @retval BM_SUCCESS Succeeds.
|
700 |
+
* Other code Fails.
|
701 |
+
*/
|
702 |
+
DECL_EXPORT bm_status_t sg_malloc_device_byte(bm_handle_t handle, sg_device_mem_t *pmem,
|
703 |
+
unsigned long long size);
|
704 |
+
|
705 |
+
/**
|
706 |
+
* @name bm_malloc_device_byte_heap
|
707 |
+
* @brief To malloc device memory in size of byte within the specified heap
|
708 |
+
* @ingroup bmlib_runtime
|
709 |
+
*
|
710 |
+
* @param [in] handle The device handle
|
711 |
+
* @param [out] pmem The result device memory descriptor
|
712 |
+
* @param [in] heap_id The heap id to allocate from (0/1/2)
|
713 |
+
* @param [in] size The number of bytes to allocate
|
714 |
+
* @retval BM_SUCCESS Succeeds.
|
715 |
+
* Other code Fails.
|
716 |
+
*/
|
717 |
+
DECL_EXPORT bm_status_t bm_malloc_device_byte_heap(bm_handle_t handle, bm_device_mem_t *pmem,
|
718 |
+
int heap_id, unsigned int size);
|
719 |
+
|
720 |
+
/**
|
721 |
+
* @name sg_malloc_device_byte_heap
|
722 |
+
* @brief To malloc device memory in size of byte within the specified heap
|
723 |
+
* @ingroup bmlib_runtime
|
724 |
+
*
|
725 |
+
* @param [in] handle The device handle
|
726 |
+
* @param [out] pmem The result device memory descriptor
|
727 |
+
* @param [in] heap_id The heap id to allocate from (0/1/2)
|
728 |
+
* @param [in] size The number of bytes to allocate
|
729 |
+
* @retval BM_SUCCESS Succeeds.
|
730 |
+
* Other code Fails.
|
731 |
+
*/
|
732 |
+
DECL_EXPORT bm_status_t sg_malloc_device_byte_heap(bm_handle_t handle, sg_device_mem_t *pmem,
|
733 |
+
int heap_id, unsigned long long size);
|
734 |
+
|
735 |
+
/**
|
736 |
+
* @name bm_malloc_device_byte_heap_mask
|
737 |
+
* @brief To malloc device memory in size of byte within the specified heaps
|
738 |
+
* @ingroup bmlib_runtime
|
739 |
+
*
|
740 |
+
* @param [in] handle The device handle
|
741 |
+
* @param [out] pmem The result device memory descriptor
|
742 |
+
* @param [in] heap_id_mask The mask of heaps to allocate from; each bit indicates one heap
|
743 |
+
* @param [in] size The number of bytes to allocate
|
744 |
+
* @retval BM_SUCCESS Succeeds.
|
745 |
+
* Other code Fails.
|
746 |
+
*/
|
747 |
+
DECL_EXPORT bm_status_t bm_malloc_device_byte_heap_mask(bm_handle_t handle, bm_device_mem_t *pmem,
|
748 |
+
int heap_id_mask, unsigned int size);
|
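/*
 * Example (non-normative sketch): restrict an allocation to heaps 0 and 1.
 * Bit i of heap_id_mask selects heap i, so 0x3 == (1 << 0) | (1 << 1).
 *
 *   bm_device_mem_t mem;
 *   bm_malloc_device_byte_heap_mask(handle, &mem, 0x3, 4096);
 */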
749 |
+
|
750 |
+
/**
|
751 |
+
* @name sg_malloc_device_byte_heap_mask
|
752 |
+
* @brief To malloc device memory in size of byte within the specified heaps
|
753 |
+
* @ingroup bmlib_runtime
|
754 |
+
*
|
755 |
+
* @param [in] handle The device handle
|
756 |
+
* @param [out] pmem The result device memory descriptor
|
757 |
+
* @param [in] heap_id_mask The mask of heaps to allocate from; each bit indicates one heap
|
758 |
+
* @param [in] size The number of bytes to allocate
|
759 |
+
* @retval BM_SUCCESS Succeeds.
|
760 |
+
* Other code Fails.
|
761 |
+
*/
|
762 |
+
DECL_EXPORT bm_status_t sg_malloc_device_byte_heap_mask(bm_handle_t handle, sg_device_mem_t *pmem,
|
763 |
+
int heap_id_mask, unsigned long long size);
|
764 |
+
|
765 |
+
/**
|
766 |
+
* @name bm_free_device
|
767 |
+
* @brief To free device memory
|
768 |
+
* @ingroup bmlib_runtime
|
769 |
+
*
|
770 |
+
* @param [in] handle The device handle
|
771 |
+
* @param [in] mem The device memory descriptor to free
|
772 |
+
*/
|
773 |
+
DECL_EXPORT void bm_free_device(bm_handle_t handle, bm_device_mem_t mem);
|
774 |
+
|
775 |
+
/**
|
776 |
+
* @name sg_free_device
|
777 |
+
* @brief To free device memory
|
778 |
+
* @ingroup bmlib_runtime
|
779 |
+
*
|
780 |
+
* @param [in] handle The device handle
|
781 |
+
* @param [in] mem The device memory descriptor to free
|
782 |
+
*/
|
783 |
+
DECL_EXPORT void sg_free_device(bm_handle_t handle, sg_device_mem_t mem);
|
784 |
+
|
785 |
+
/**
|
786 |
+
* @name bm_gmem_arm_reserved_request
|
787 |
+
* @brief To obtain the address of global memory reserved for arm926
|
788 |
+
* @param [in] handle The device handle
|
789 |
+
*
|
790 |
+
* @retval unsigned long long The absolute address of gmem reserved for arm926
|
791 |
+
*/
|
792 |
+
DECL_EXPORT unsigned long long bm_gmem_arm_reserved_request(bm_handle_t handle);
|
793 |
+
|
794 |
+
/**
|
795 |
+
* @name bm_gmem_arm_reserved_release
|
796 |
+
* @brief To release the global memory reserved for arm926
|
797 |
+
* @ingroup bmlib_runtime
|
798 |
+
*
|
799 |
+
* @param [in] handle The device handle
|
800 |
+
*/
|
801 |
+
DECL_EXPORT void bm_gmem_arm_reserved_release(bm_handle_t handle);
|
802 |
+
|
803 |
+
/*******************memory copy functions *************************************/
|
804 |
+
/**
|
805 |
+
* @name bm_memcpy_s2d
|
806 |
+
* @brief To copy data from system memory to device memory
|
807 |
+
* @ingroup bmlib_runtime
|
808 |
+
*
|
809 |
+
* @param [in] handle The device handle
|
810 |
+
* @param [in] dst The destination memory (device memory descriptor)
|
811 |
+
* @param [in] src The source memory (system memory, a void* pointer)
|
812 |
+
*
|
813 |
+
* @retval BM_SUCCESS Succeeds.
|
814 |
+
* Other code Fails.
|
815 |
+
*/
|
816 |
+
DECL_EXPORT bm_status_t bm_memcpy_s2d(bm_handle_t handle, bm_device_mem_t dst, void *src);
|
817 |
+
|
818 |
+
/**
|
819 |
+
* @name bm_memcpy_p2p
|
820 |
+
* @brief To copy data from one chip to another chip
|
821 |
+
* @ingroup bmlib_runtime
|
822 |
+
*
|
823 |
+
* @param [in] handle_src The source device handle
|
824 |
+
* @param [in] src The source memory (device memory descriptor )
|
825 |
+
* @param [in] handle_dst The destination device handle
|
826 |
+
* @param [in] dst The destination memory (device memory descriptor )
|
827 |
+
*
|
828 |
+
* @retval BM_SUCCESS Succeeds.
|
829 |
+
* Other code Fails.
|
830 |
+
*/
|
831 |
+
DECL_EXPORT bm_status_t bm_memcpy_p2p(bm_handle_t handle_src, bm_device_mem_t src, bm_handle_t handle_dst,bm_device_mem_t dst);
|
832 |
+
|
833 |
+
/**
|
834 |
+
* @name sg_memcpy_s2d
|
835 |
+
* @brief To copy data from system memory to device memory
|
836 |
+
* @ingroup bmlib_runtime
|
837 |
+
*
|
838 |
+
* @param [in] handle The device handle
|
839 |
+
* @param [in] dst The destination memory (device memory descriptor)
|
840 |
+
* @param [in] src The source memory (system memory, a void* pointer)
|
841 |
+
*
|
842 |
+
* @retval BM_SUCCESS Succeeds.
|
843 |
+
* Other code Fails.
|
844 |
+
*/
|
845 |
+
DECL_EXPORT bm_status_t sg_memcpy_s2d(bm_handle_t handle, sg_device_mem_t dst, void *src);
|
846 |
+
|
847 |
+
/**
|
848 |
+
* @name bm_memcpy_s2d_partial_offset
|
849 |
+
* @brief To copy specified bytes of data from system memory to device memory
|
850 |
+
* with an offset in device memory address.
|
851 |
+
* @ingroup bmlib_runtime
|
852 |
+
*
|
853 |
+
* @param [in] handle The device handle
|
854 |
+
* @param [in] dst The destination memory (device memory descriptor)
|
855 |
+
* @param [in] src The source memory (system memory, a void* pointer)
|
856 |
+
* @param [in] size The size of data to copy (in bytes)
|
857 |
+
* @param [in] offset The offset of the device memory address
|
858 |
+
*
|
859 |
+
* @retval BM_SUCCESS Succeeds.
|
860 |
+
* Other code Fails.
|
861 |
+
*/
|
862 |
+
DECL_EXPORT bm_status_t bm_memcpy_s2d_partial_offset(bm_handle_t handle,
|
863 |
+
bm_device_mem_t dst, void *src,
|
864 |
+
unsigned int size,
|
865 |
+
unsigned int offset);
|
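/*
 * Example (non-normative sketch): update only bytes [256, 256 + 128) of a
 * previously allocated device buffer "mem" from a 128-byte host staging
 * buffer.
 *
 *   unsigned char staging[128];
 *   bm_memcpy_s2d_partial_offset(handle, mem, staging, sizeof(staging), 256);
 */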
866 |
+
|
867 |
+
/**
|
868 |
+
* @name sg_memcpy_s2d_partial_offset
|
869 |
+
* @brief To copy specified bytes of data from system memory to device memory
|
870 |
+
* with an offset in device memory address.
|
871 |
+
* @ingroup bmlib_runtime
|
872 |
+
*
|
873 |
+
* @param [in] handle The device handle
|
874 |
+
* @param [in] dst The destination memory (device memory descriptor)
|
875 |
+
* @param [in] src The source memory (system memory, a void* pointer)
|
876 |
+
* @param [in] size The size of data to copy (in bytes)
|
877 |
+
* @param [in] offset The offset of the device memory address
|
878 |
+
*
|
879 |
+
* @retval BM_SUCCESS Succeeds.
|
880 |
+
* Other code Fails.
|
881 |
+
*/
|
882 |
+
DECL_EXPORT bm_status_t sg_memcpy_s2d_partial_offset(bm_handle_t handle,
|
883 |
+
sg_device_mem_t dst, void *src,
|
884 |
+
unsigned long long size,
|
885 |
+
unsigned long long offset);
|
886 |
+
|
887 |
+
/**
|
888 |
+
* @name bm_memcpy_s2d_partial
|
889 |
+
* @brief To copy specified bytes of data from system memory to device memory
|
890 |
+
* @ingroup bmlib_runtime
|
891 |
+
*
|
892 |
+
* @param [in] handle The device handle
|
893 |
+
* @param [in] dst The destination memory (device memory descriptor)
|
894 |
+
* @param [in] src The source memory (system memory, a void* pointer)
|
895 |
+
* @param [in] size The size of data to copy (in bytes)
|
896 |
+
*
|
897 |
+
* @retval BM_SUCCESS Succeeds.
|
898 |
+
* Other code Fails.
|
899 |
+
*/
|
900 |
+
DECL_EXPORT bm_status_t bm_memcpy_s2d_partial(bm_handle_t handle, bm_device_mem_t dst,
|
901 |
+
void *src, unsigned int size);
|
902 |
+
|
903 |
+
/**
|
904 |
+
* @name sg_memcpy_s2d_partial
|
905 |
+
* @brief To copy specified bytes of data from system memory to device memory
|
906 |
+
* @ingroup bmlib_runtime
|
907 |
+
*
|
908 |
+
* @param [in] handle The device handle
|
909 |
+
* @param [in] dst The destination memory (device memory descriptor)
|
910 |
+
* @param [in] src The source memory (system memory, a void* pointer)
|
911 |
+
* @param [in] size The size of data to copy (in bytes)
|
912 |
+
*
|
913 |
+
* @retval BM_SUCCESS Succeeds.
|
914 |
+
* Other code Fails.
|
915 |
+
*/
|
916 |
+
DECL_EXPORT bm_status_t sg_memcpy_s2d_partial(bm_handle_t handle, sg_device_mem_t dst,
|
917 |
+
void *src, unsigned long long size);
|
918 |
+
|
919 |
+
/**
|
920 |
+
* @name bm_memcpy_d2s
|
921 |
+
* @brief To copy data from device memory to system memory
|
922 |
+
* @ingroup bmlib_runtime
|
923 |
+
*
|
924 |
+
* @param [in] handle The device handle
|
925 |
+
* @param [in] dst The destination memory (system memory, a void* pointer)
|
926 |
+
* @param [in] src The source memory (device memory descriptor)
|
927 |
+
*
|
928 |
+
* @retval BM_SUCCESS Succeeds.
|
929 |
+
* Other code Fails.
|
930 |
+
*/
|
931 |
+
DECL_EXPORT bm_status_t bm_memcpy_d2s(bm_handle_t handle, void *dst, bm_device_mem_t src);
|
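/*
 * Example (non-normative sketch): host -> device -> host round trip; the copy
 * size is taken from the descriptor filled in by the allocator.
 *
 *   float in[1024], out[1024];
 *   bm_device_mem_t mem;
 *   bm_malloc_device_byte(handle, &mem, sizeof(in));
 *   bm_memcpy_s2d(handle, mem, in);
 *   // ... run device-side work on mem here ...
 *   bm_memcpy_d2s(handle, out, mem);
 *   bm_free_device(handle, mem);
 */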
932 |
+
|
933 |
+
/**
|
934 |
+
* @name sg_memcpy_d2s
|
935 |
+
* @brief To copy data from device memory to system memory
|
936 |
+
* @ingroup bmlib_runtime
|
937 |
+
*
|
938 |
+
* @param [in] handle The device handle
|
939 |
+
* @param [in] dst The destination memory (system memory, a void* pointer)
|
940 |
+
* @param [in] src The source memory (device memory descriptor)
|
941 |
+
*
|
942 |
+
* @retval BM_SUCCESS Succeeds.
|
943 |
+
* Other code Fails.
|
944 |
+
*/
|
945 |
+
DECL_EXPORT bm_status_t sg_memcpy_d2s(bm_handle_t handle, void *dst, sg_device_mem_t src);
|
946 |
+
|
947 |
+
/**
|
948 |
+
* @name bm_memcpy_d2s_partial_offset
|
949 |
+
* @brief To copy specified bytes of data from device memory to system memory
|
950 |
+
* with an offset in device memory address.
|
951 |
+
* @ingroup bmlib_runtime
|
952 |
+
*
|
953 |
+
* @param [in] handle The device handle
|
954 |
+
* @param [in] dst The destination memory (system memory, a void* pointer)
|
955 |
+
* @param [in] src The source memory (device memory descriptor)
|
956 |
+
* @param [in] size The size of data to copy (in bytes)
|
957 |
+
* @param [in] offset The offset of the device memory address
|
958 |
+
*
|
959 |
+
* @retval BM_SUCCESS Succeeds.
|
960 |
+
* Other code Fails.
|
961 |
+
*/
|
962 |
+
DECL_EXPORT bm_status_t bm_memcpy_d2s_partial_offset(bm_handle_t handle, void *dst,
|
963 |
+
bm_device_mem_t src, unsigned int size,
|
964 |
+
unsigned int offset);
|
965 |
+
|
966 |
+
/**
|
967 |
+
* @name sg_memcpy_d2s_partial_offset
|
968 |
+
* @brief To copy specified bytes of data from device memory to system memory
|
969 |
+
* with an offset in device memory address.
|
970 |
+
* @ingroup bmlib_runtime
|
971 |
+
*
|
972 |
+
* @param [in] handle The device handle
|
973 |
+
* @param [in] dst The destination memory (system memory, a void* pointer)
|
974 |
+
* @param [in] src The source memory (device memory descriptor)
|
975 |
+
* @param [in] size The size of data to copy (in bytes)
|
976 |
+
* @param [in] offset The offset of the device memory address
|
977 |
+
*
|
978 |
+
* @retval BM_SUCCESS Succeeds.
|
979 |
+
* Other code Fails.
|
980 |
+
*/
|
981 |
+
DECL_EXPORT bm_status_t sg_memcpy_d2s_partial_offset(bm_handle_t handle, void *dst,
|
982 |
+
sg_device_mem_t src, unsigned long long size,
|
983 |
+
unsigned long long offset);
|
984 |
+
|
985 |
+
/**
|
986 |
+
* @name bm_memcpy_d2s_partial
|
987 |
+
* @brief To copy specified bytes of data from device memory to system memory
|
988 |
+
* @ingroup bmlib_runtime
|
989 |
+
*
|
990 |
+
* @param [in] handle The device handle
|
991 |
+
* @param [in] dst The destination memory (system memory, a void* pointer)
|
992 |
+
* @param [in] src The source memory (device memory descriptor)
|
993 |
+
* @param [in] size The size of data to copy (in bytes)
|
994 |
+
*
|
995 |
+
* @retval BM_SUCCESS Data transfer succeeds.
|
996 |
+
* Other code Data transfer fails.
|
997 |
+
*/
|
998 |
+
DECL_EXPORT bm_status_t bm_memcpy_d2s_partial(bm_handle_t handle, void *dst,
|
999 |
+
bm_device_mem_t src, unsigned int size);
|
1000 |
+
|
1001 |
+
/**
|
1002 |
+
* @name sg_memcpy_d2s_partial
|
1003 |
+
* @brief To copy specified bytes of data from device memory to system memory
|
1004 |
+
* @ingroup bmlib_runtime
|
1005 |
+
*
|
1006 |
+
* @param [in] handle The device handle
|
1007 |
+
* @param [in] dst The destination memory (system memory, a void* pointer)
|
1008 |
+
* @param [in] src The source memory (device memory descriptor)
|
1009 |
+
* @param [in] size The size of data to copy (in bytes)
|
1010 |
+
*
|
1011 |
+
* @retval BM_SUCCESS Data transfer succeeds.
|
1012 |
+
* Other code Data transfer fails.
|
1013 |
+
*/
|
1014 |
+
DECL_EXPORT bm_status_t sg_memcpy_d2s_partial(bm_handle_t handle, void *dst,
|
1015 |
+
sg_device_mem_t src, unsigned long long size);
|
1016 |
+
|
1017 |
+
/**
|
1018 |
+
* @name bm_memcpy_d2d
|
1019 |
+
* @brief To copy specified dwords of data from one piece of device memory
|
1020 |
+
* to another piece of device memory within one device. Both source
|
1021 |
+
* and destination offsets can be specified.
|
1022 |
+
* @ingroup bmlib_runtime
|
1023 |
+
*
|
1024 |
+
* @param [in] handle The device handle
|
1025 |
+
* @param [in] dst The destination device memory
|
1026 |
+
* @param [in] dst_offset The offset of destination device memory address
|
1027 |
+
* @param [in] src The source device memory
|
1028 |
+
* @param [in] src_offset The offset of source device memory address
|
1029 |
+
* @param [in] len Length of data to copy (in DWORD 4 bytes)
|
1030 |
+
*
|
1031 |
+
* @retval BM_SUCCESS Succeeds.
|
1032 |
+
* Other code Fails.
|
1033 |
+
*/
|
1034 |
+
DECL_EXPORT bm_status_t bm_memcpy_d2d(bm_handle_t handle, bm_device_mem_t dst,
|
1035 |
+
int dst_offset, bm_device_mem_t src, int src_offset,
|
1036 |
+
int len);
|
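/*
 * Example (non-normative sketch): copy 1024 bytes between two device buffers.
 * Note that len is counted in DWORDs (4 bytes), so 1024 bytes == 256 dwords;
 * bm_memcpy_d2d_byte below takes byte offsets and sizes instead.
 *
 *   bm_memcpy_d2d(handle, dst_mem, 0, src_mem, 0, 256);
 */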
1037 |
+
|
1038 |
+
/**
|
1039 |
+
* @name bm_memcpy_d2d_with_core
|
1040 |
+
* @brief To copy specified dwords of data from one piece of device memory
|
1041 |
+
* to another piece of device memory within one device. Both source
|
1042 |
+
* and destination offsets can be specified.
|
1043 |
+
* @ingroup bmlib_runtime
|
1044 |
+
*
|
1045 |
+
* @param [in] handle The device handle
|
1046 |
+
* @param [in] dst The destination device memory
|
1047 |
+
* @param [in] dst_offset The offset of destination device memory address
|
1048 |
+
* @param [in] src The source device memory
|
1049 |
+
* @param [in] src_offset The offset of source device memory address
|
1050 |
+
* @param [in] len Length of data to copy (in DWORD 4 bytes)
|
1051 |
+
* @param [in] core_id The core id to copy
|
1052 |
+
*
|
1053 |
+
* @retval BM_SUCCESS Succeeds.
|
1054 |
+
* Other code Fails.
|
1055 |
+
*/
|
1056 |
+
DECL_EXPORT bm_status_t bm_memcpy_d2d_with_core(bm_handle_t handle, bm_device_mem_t dst,
|
1057 |
+
int dst_offset, bm_device_mem_t src, int src_offset,
|
1058 |
+
int len, int core_id);
|
1059 |
+
|
1060 |
+
/**
|
1061 |
+
* @name bm_memcpy_d2d_byte
|
1062 |
+
* @brief To copy specified bytes of data from one piece of device memory
|
1063 |
+
* to another piece of device memory within one device. Both source
|
1064 |
+
* and destination offsets can be specified.
|
1065 |
+
* @ingroup bmlib_runtime
|
1066 |
+
*
|
1067 |
+
* @param [in] handle The device handle
|
1068 |
+
* @param [in] dst The destination device memory
|
1069 |
+
* @param [in] dst_offset The offset of destination device memory address (in bytes)
|
1070 |
+
* @param [in] src The source device memory
|
1071 |
+
* @param [in] src_offset The offset of source device memory address (in bytes)
|
1072 |
+
* @param [in] size Size of data to copy (in bytes)
|
1073 |
+
*
|
1074 |
+
* @retval BM_SUCCESS Succeeds.
|
1075 |
+
* Other code Fails.
|
1076 |
+
*/
|
1077 |
+
DECL_EXPORT bm_status_t bm_memcpy_d2d_byte(bm_handle_t handle, bm_device_mem_t dst,
|
1078 |
+
size_t dst_offset, bm_device_mem_t src,
|
1079 |
+
size_t src_offset, size_t size);
|
1080 |
+
|
1081 |
+
/**
|
1082 |
+
* @name bm_memcpy_d2d_byte_with_core
|
1083 |
+
* @brief To copy specified bytes of data from one piece of device memory
|
1084 |
+
* to another piece of device memory within one device. Both source
|
1085 |
+
* and destination offsets can be specified.
|
1086 |
+
* @ingroup bmlib_runtime
|
1087 |
+
*
|
1088 |
+
* @param [in] handle The device handle
|
1089 |
+
* @param [in] dst The destination device memory
|
1090 |
+
* @param [in] dst_offset The offset of destination device memory address (in bytes)
|
1091 |
+
* @param [in] src The source device memory
|
1092 |
+
* @param [in] src_offset The offset of source device memory address (in bytes)
|
1093 |
+
* @param [in] size Size of data to copy (in bytes)
|
1094 |
+
* @param [in] core_id The core id to copy
|
1095 |
+
*
|
1096 |
+
* @retval BM_SUCCESS Succeeds.
|
1097 |
+
* Other code Fails.
|
1098 |
+
*/
|
1099 |
+
DECL_EXPORT bm_status_t bm_memcpy_d2d_byte_with_core(bm_handle_t handle, bm_device_mem_t dst,
|
1100 |
+
size_t dst_offset, bm_device_mem_t src,
|
1101 |
+
size_t src_offset, size_t size, int core_id);
|
1102 |
+
|
1103 |
+
/**
|
1104 |
+
* @name bm_memcpy_d2d_stride
|
1105 |
+
* @brief To copy specified data from one piece of device memory
|
1106 |
+
* to another piece of device memory within one device. Both source
|
1107 |
+
* and destination offsets can be specified.
|
1108 |
+
* @ingroup bmlib_runtime
|
1109 |
+
*
|
1110 |
+
* @param [in] handle The device handle
|
1111 |
+
* @param [in] dst The destination device memory
|
1112 |
+
* @param [in] dst_stride The data stride of destination data
|
1113 |
+
* @param [in] src The source device memory
|
1114 |
+
* @param [in] src_stride The data stride of source data
|
1115 |
+
* @param [in] count Count of data to copy
|
1116 |
+
* @param [in] format_size Data format byte size, such as sizeof(uint8_t), sizeof(float), etc.
|
1117 |
+
* format_size only supports 1/2/4.
|
1118 |
+
*
|
1119 |
+
* dst_stride MUST be 1, EXCEPT: dst_stride == 4 && src_stride == 1 && format_size == 1
|
1120 |
+
*
|
1121 |
+
* @retval BM_SUCCESS Succeeds.
|
1122 |
+
* Other code Fails.
|
1123 |
+
*/
|
1124 |
+
DECL_EXPORT bm_status_t bm_memcpy_d2d_stride(bm_handle_t handle,
|
1125 |
+
bm_device_mem_t dst,
|
1126 |
+
int dst_stride,
|
1127 |
+
bm_device_mem_t src,
|
1128 |
+
int src_stride,
|
1129 |
+
int count,
|
1130 |
+
int format_size);
|
1131 |
+
|
1132 |
+
/**
|
1133 |
+
* @name bm_memcpy_d2d_stride_with_core
|
1134 |
+
* @brief To copy specified data from one piece of device memory
|
1135 |
+
* to another piece of device memory within one device. Both source
|
1136 |
+
* and destination offsets can be specified.
|
1137 |
+
* @ingroup bmlib_runtime
|
1138 |
+
*
|
1139 |
+
* @param [in] handle The device handle
|
1140 |
+
* @param [in] dst The destination device memory
|
1141 |
+
* @param [in] dst_stride The data stride of destination data
|
1142 |
+
* @param [in] src The source device memory
|
1143 |
+
* @param [in] src_stride The data stride of source data
|
1144 |
+
* @param [in] count Count of data to copy
|
1145 |
+
* @param [in] format_size Data format byte size, such as sizeof(uint8_t), sizeof(float), etc.
|
1146 |
+
* format_size only supports 1/2/4.
|
1147 |
+
* @param [in] core_id The core id to copy.
|
1148 |
+
*
|
1149 |
+
* dst_stride MUST be 1, EXCEPT: dst_stride == 4 && src_stride == 1 && format_size == 1
|
1150 |
+
*
|
1151 |
+
* @retval BM_SUCCESS Succeeds.
|
1152 |
+
* Other code Fails.
|
1153 |
+
*/
|
1154 |
+
DECL_EXPORT bm_status_t bm_memcpy_d2d_stride_with_core(bm_handle_t handle,
|
1155 |
+
bm_device_mem_t dst,
|
1156 |
+
int dst_stride,
|
1157 |
+
bm_device_mem_t src,
|
1158 |
+
int src_stride,
|
1159 |
+
int count,
|
1160 |
+
int format_size,
|
1161 |
+
int core_id);
|
1162 |
+
|
1163 |
+
/**
|
1164 |
+
* @name bm_memcpy_c2c
|
1165 |
+
* @brief To copy data from one chip to another chip.
|
1166 |
+
* (Used in multi-chip card scenario)
|
1167 |
+
* @ingroup bmlib_runtime
|
1168 |
+
*
|
1169 |
+
* @param [in] src_handle The source device handle
|
1170 |
+
* @param [in] dst_handle The destination device handle
|
1171 |
+
* @param [in] src The source device memory descriptor
|
1172 |
+
* @param [in] dst The destination device memory descriptor
|
1173 |
+
* @param [in] force_dst_cdma Whether to use the CDMA engine of the destination device
|
1174 |
+
* @retval BM_SUCCESS Succeeds.
|
1175 |
+
* Other code Fails.
|
1176 |
+
*/
|
1177 |
+
DECL_EXPORT bm_status_t bm_memcpy_c2c(bm_handle_t src_handle, bm_handle_t dst_handle,
|
1178 |
+
bm_device_mem_t src, bm_device_mem_t dst,
|
1179 |
+
bool force_dst_cdma);
|
1180 |
+
|
1181 |
+
/**
|
1182 |
+
* @name bm_memset_device
|
1183 |
+
* @brief To fill in specified device memory with the given value
|
1184 |
+
* @ingroup bmlib_runtime
|
1185 |
+
*
|
1186 |
+
* @param [in] handle The device handle
|
1187 |
+
* @param [in] value The value used to fill. (int type)
|
1188 |
+
* @param [in] mem The device memory which will be filled in
|
1189 |
+
* @retval BM_SUCCESS Succeeds.
|
1190 |
+
* Other code Fails.
|
1191 |
+
*/
|
1192 |
+
DECL_EXPORT bm_status_t bm_memset_device(bm_handle_t handle, const int value,
|
1193 |
+
bm_device_mem_t mem);
|
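/*
 * Example (non-normative sketch): zero-fill a device buffer, then fill it with
 * a 16-bit pattern via bm_memset_device_ext (declared just below), where mode
 * gives the number of valid bytes of *value (assumed here to be 1/2/4).
 *
 *   bm_memset_device(handle, 0, mem);
 *   unsigned short pattern = 0xABCD;
 *   bm_memset_device_ext(handle, &pattern, 2, mem);
 */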
1194 |
+
|
1195 |
+
/**
|
1196 |
+
* @name bm_memset_device_ext
|
1197 |
+
* @brief To fill in specified device memory with the given value and mode
|
1198 |
+
* @ingroup bmlib_runtime
|
1199 |
+
*
|
1200 |
+
* @param [in] handle The device handle
|
1201 |
+
* @param [in] value The pointer to the value used to fill
|
1202 |
+
* @param [in] mode The valid bytes of *value
|
1203 |
+
* @param [in] mem The device memory which will be filled in
|
1204 |
+
* @retval BM_SUCCESS Succeeds.
|
1205 |
+
* Other code Fails.
|
1206 |
+
*/
|
1207 |
+
DECL_EXPORT bm_status_t bm_memset_device_ext(bm_handle_t handle, void* value, int mode,
|
1208 |
+
bm_device_mem_t mem);
|
1209 |
+
|
1210 |
+
/**
|
1211 |
+
* @name bm_mem_convert_system_to_device_neuron
|
1212 |
+
* @brief To malloc a piece of device memory according to the shape of
|
1213 |
+
* neuron (in DWORDs, 4 bytes each); copy neuron from system memory to
|
1214 |
+
* device memory if need_copy is true.
|
1215 |
+
* @ingroup bmlib_runtime
|
1216 |
+
*
|
1217 |
+
* @param [in] handle The device handle
|
1218 |
+
* @param [in] dev_mem The device memory descriptor
|
1219 |
+
* @param [in] sys_mem The system memory descriptor
|
1220 |
+
* @param [in] need_copy If copy from system to device is needed
|
1221 |
+
* @param [in] n,c,h,w Neuron shape size
|
1222 |
+
*
|
1223 |
+
* @retval BM_SUCCESS Succeeds.
|
1224 |
+
* Other code Fails.
|
1225 |
+
*/
|
1226 |
+
DECL_EXPORT bm_status_t bm_mem_convert_system_to_device_neuron(bm_handle_t handle,
|
1227 |
+
struct bm_mem_desc *dev_mem,
|
1228 |
+
struct bm_mem_desc sys_mem,
|
1229 |
+
bool need_copy, int n, int c,
|
1230 |
+
int h, int w);
|
1231 |
+
|
1232 |
+
/**
|
1233 |
+
* @name bm_mem_convert_system_to_device_neuron_byte
|
1234 |
+
* @brief To malloc a piece of device memory according to the shape of
|
1235 |
+
* neuron (in bytes); copy neuron from system memory to
|
1236 |
+
* device memory if need_copy is true.
|
1237 |
+
* @ingroup bmlib_runtime
|
1238 |
+
*
|
1239 |
+
* @param [in] handle The device handle
|
1240 |
+
* @param [in] dev_mem The device memory descriptor
|
1241 |
+
* @param [in] sys_mem The system memory descriptor
|
1242 |
+
* @param [in] need_copy If copy from system to device is needed
|
1243 |
+
* @param [in] n,c,h,w Neuron shape size
|
1244 |
+
*
|
1245 |
+
* @retval BM_SUCCESS Succeeds.
|
1246 |
+
* Other code Fails.
|
1247 |
+
*/
|
1248 |
+
DECL_EXPORT bm_status_t bm_mem_convert_system_to_device_neuron_byte(
|
1249 |
+
bm_handle_t handle, struct bm_mem_desc *dev_mem, struct bm_mem_desc sys_mem,
|
1250 |
+
bool need_copy, int n, int c, int h, int w);
|
1251 |
+
|
1252 |
+
/**
|
1253 |
+
* @name bm_mem_convert_system_to_device_coeff
|
1254 |
+
* @brief To malloc a piece of device memory according to the size of
|
1255 |
+
* coefficient (in DWORD 4 bytes); copy coefficient from system
|
1256 |
+
* memory to device memory if need_copy is true.
|
1257 |
+
* @ingroup bmlib_runtime
|
1258 |
+
*
|
1259 |
+
* @param [in] handle The device handle
|
1260 |
+
* @param [in] dev_mem The device memory descriptor
|
1261 |
+
* @param [in] sys_mem The system memory descriptor
|
1262 |
+
* @param [in] need_copy If copy from system to device is needed
|
1263 |
+
* @param [in] coeff_count Coefficient size
|
1264 |
+
*
|
1265 |
+
* @retval BM_SUCCESS Succeeds.
|
1266 |
+
* Other code Fails.
|
1267 |
+
*/
|
1268 |
+
DECL_EXPORT bm_status_t bm_mem_convert_system_to_device_coeff(bm_handle_t handle,
|
1269 |
+
struct bm_mem_desc *dev_mem,
|
1270 |
+
struct bm_mem_desc sys_mem,
|
1271 |
+
bool need_copy,
|
1272 |
+
int coeff_count);
|
1273 |
+
/**
|
1274 |
+
* @name bm_mem_convert_system_to_device_coeff_byte
|
1275 |
+
* @brief To malloc a piece of device memory according to the size of
|
1276 |
+
* coefficient (in bytes); copy coefficient from system
|
1277 |
+
* memory to device memory if need_copy is true.
|
1278 |
+
* @ingroup bmlib_runtime
|
1279 |
+
*
|
1280 |
+
* @param [in] handle The device handle
|
1281 |
+
* @param [in] dev_mem The device memory descriptor
|
1282 |
+
* @param [in] sys_mem The system memory descriptor
|
1283 |
+
* @param [in] need_copy If copy from system to device is needed
|
1284 |
+
* @param [in] coeff_count Coefficient size
|
1285 |
+
*
|
1286 |
+
* @retval BM_SUCCESS Succeeds.
|
1287 |
+
* Other code Fails.
|
1288 |
+
*/
|
1289 |
+
DECL_EXPORT bm_status_t bm_mem_convert_system_to_device_coeff_byte(
|
1290 |
+
bm_handle_t handle, struct bm_mem_desc *dev_mem, struct bm_mem_desc sys_mem,
|
1291 |
+
bool need_copy, int coeff_count);
|
1292 |
+
|
1293 |
+
/*******************memory map functions *************************************/
|
1294 |
+
/**
|
1295 |
+
* @name bm_mem_mmap_device_mem
|
1296 |
+
* @brief To map a piece of device memory to user space with cache enabled.
|
1297 |
+
* (only valid in SoC mode; Not supported in PCIE mode).
|
1298 |
+
* @ingroup bmlib_runtime
|
1299 |
+
*
|
1300 |
+
* @param [in] handle The device handle
|
1301 |
+
* @param [in] dev_mem The device memory to map
|
1302 |
+
* @param [out] vmem The virtual address of the mapped device memory
|
1303 |
+
*
|
1304 |
+
* @retval BM_SUCCESS Succeeds.
|
1305 |
+
* Other code Fails.
|
1306 |
+
*/
|
1307 |
+
DECL_EXPORT bm_status_t bm_mem_mmap_device_mem(bm_handle_t handle, bm_device_mem_t *dmem,
|
1308 |
+
|
1309 |
+
unsigned long long *vmem);
|
1310 |
+
|
1311 |
+
/**
|
1312 |
+
* @name sg_mem_mmap_device_mem
|
1313 |
+
* @brief To map a piece of device memory to user space with cache enabled.
|
1314 |
+
* (only valid in SoC mode; Not supported in PCIE mode).
|
1315 |
+
* @ingroup bmlib_runtime
|
1316 |
+
*
|
1317 |
+
* @param [in] handle The device handle
|
1318 |
+
* @param [in] dev_mem The device memory to map
|
1319 |
+
* @param [out] vmem The virtual address of the mapped device memory
|
1320 |
+
*
|
1321 |
+
* @retval BM_SUCCESS Succeeds.
|
1322 |
+
* Other code Fails.
|
1323 |
+
*/
|
1324 |
+
DECL_EXPORT bm_status_t sg_mem_mmap_device_mem(bm_handle_t handle, sg_device_mem_t *dmem,
|
1325 |
+
unsigned long long *vmem);
|
1326 |
+
|
1327 |
+
|
1328 |
+
/**
|
1329 |
+
* @name bm_mem_mmap_device_mem_no_cache
|
1330 |
+
* @brief To map a piece of device memory to user space with cache disabled.
|
1331 |
+
* (only valid in SoC mode; Not supported in PCIE mode).
|
1332 |
+
* @ingroup bmlib_runtime
|
1333 |
+
*
|
1334 |
+
* @param [in] handle The device handle
|
1335 |
+
* @param [in] dev_mem The device memory to map
|
1336 |
+
* @param [out] vmem The virtual address of the mapped device memory
|
1337 |
+
*
|
1338 |
+
* @retval BM_SUCCESS Succeeds.
|
1339 |
+
* Other code Fails.
|
1340 |
+
*/
|
1341 |
+
DECL_EXPORT bm_status_t bm_mem_mmap_device_mem_no_cache(bm_handle_t handle, bm_device_mem_t *dmem,
|
1342 |
+
|
1343 |
+
unsigned long long *vmem);
|
1344 |
+
|
1345 |
+
/**
|
1346 |
+
* @name sg_mem_mmap_device_mem_no_cache
|
1347 |
+
* @brief To map a piece of device memory to user space with cache disabled.
|
1348 |
+
* (only valid in SoC mode; Not supported in PCIE mode).
|
1349 |
+
* @ingroup bmlib_runtime
|
1350 |
+
*
|
1351 |
+
* @param [in] handle The device handle
|
1352 |
+
* @param [in] dev_mem The device memory to map
|
1353 |
+
* @param [out] vmem The virtual address of the mapped device memory
|
1354 |
+
*
|
1355 |
+
* @retval BM_SUCCESS Succeeds.
|
1356 |
+
* Other code Fails.
|
1357 |
+
*/
|
1358 |
+
DECL_EXPORT bm_status_t sg_mem_mmap_device_mem_no_cache(bm_handle_t handle, sg_device_mem_t *dmem,
|
1359 |
+
unsigned long long *vmem);
|
1360 |
+
|
1361 |
+
/**
|
1362 |
+
* @name bm_mem_vir_to_phy
|
1363 |
+
* @brief To get the device memory address through the mapped virtual address.
|
1364 |
+
* (only valid in SoC mode; Not supported in PCIE mode).
|
1365 |
+
* @ingroup bmlib_runtime
|
1366 |
+
*
|
1367 |
+
* @param [in] handle The device handle
|
1368 |
+
* @param [in] vmem The virtual address of the mapped device memory
|
1369 |
+
* @param [out] dev_mem The device memory address
|
1370 |
+
*
|
1371 |
+
* @retval BM_SUCCESS Succeeds.
|
1372 |
+
* Other code Fails.
|
1373 |
+
*/
|
1374 |
+
DECL_EXPORT bm_status_t bm_mem_vir_to_phy(bm_handle_t handle, unsigned long long vmem,
|
1375 |
+
unsigned long long *device_mem);
|
1376 |
+
/**
|
1377 |
+
* @name bm_mem_invalidate_device_mem
|
1378 |
+
* @brief To invalidate a piece of mapped device memory to maintain
|
1379 |
+
* cache coherence
|
1380 |
+
* (only valid in SoC mode; Not supported in PCIE mode).
|
1381 |
+
* @ingroup bmlib_runtime
|
1382 |
+
*
|
1383 |
+
* @param [in] handle The device handle
|
1384 |
+
* @param [in] dmem The device memory to invalidate
|
1385 |
+
*
|
1386 |
+
* @retval BM_SUCCESS Succeeds.
|
1387 |
+
* Other code Fails.
|
1388 |
+
*/
|
1389 |
+
|
1390 |
+
DECL_EXPORT bm_status_t bm_mem_invalidate_device_mem(bm_handle_t handle,
|
1391 |
+
bm_device_mem_t *dmem);
|
1392 |
+
|
1393 |
+
/**
|
1394 |
+
* @name sg_mem_invalidate_device_mem
|
1395 |
+
* @brief To invalidate a piece of mapped device memory to maintain
|
1396 |
+
* cache coherence
|
1397 |
+
* (only valid in SoC mode; Not supported in PCIE mode).
|
1398 |
+
* @ingroup bmlib_runtime
|
1399 |
+
*
|
1400 |
+
* @param [in] handle The device handle
|
1401 |
+
* @param [in] dmem The device memory to invalidate
|
1402 |
+
*
|
1403 |
+
* @retval BM_SUCCESS Succeeds.
|
1404 |
+
* Other code Fails.
|
1405 |
+
*/
|
1406 |
+
|
1407 |
+
DECL_EXPORT bm_status_t sg_mem_invalidate_device_mem(bm_handle_t handle,
|
1408 |
+
sg_device_mem_t *dmem);
|
1409 |
+
|
1410 |
+
/**
|
1411 |
+
* @name bm_mem_invalidate_partial_device_mem
|
1412 |
+
* @brief To invalidate part of mapped device memory to maintain
|
1413 |
+
* cache coherence
|
1414 |
+
* (only valid in SoC mode; Not supported in PCIE mode).
|
1415 |
+
* @ingroup bmlib_runtime
|
1416 |
+
*
|
1417 |
+
* @param [in] handle The device handle
|
1418 |
+
* @param [in] dmem The device memory to invalidate
|
1419 |
+
* @param [in] offset The offset of device memory address
|
1420 |
+
* @param [in] len The length of memory to invalidate in bytes
|
1421 |
+
*
|
1422 |
+
* @retval BM_SUCCESS Succeeds.
|
1423 |
+
* Other code Fails.
|
1424 |
+
*/
|
1425 |
+
DECL_EXPORT bm_status_t bm_mem_invalidate_partial_device_mem(bm_handle_t handle,
|
1426 |
+
bm_device_mem_t *dmem,
|
1427 |
+
unsigned int offset,
|
1428 |
+
unsigned int len);
|
1429 |
+
|
1430 |
+
/**
|
1431 |
+
* @name sg_mem_invalidate_partial_device_mem
|
1432 |
+
* @brief To invalidate part of mapped device memory to maintain
|
1433 |
+
* cache coherence
|
1434 |
+
* (only valid in SoC mode; Not supported in PCIE mode).
|
1435 |
+
* @ingroup bmlib_runtime
|
1436 |
+
*
|
1437 |
+
* @param [in] handle The device handle
|
1438 |
+
* @param [in] dmem The device memory to invalidate
|
1439 |
+
* @param [in] offset The offset of device memory address
|
1440 |
+
* @param [in] len The length of memory to invalidate in bytes
|
1441 |
+
*
|
1442 |
+
* @retval BM_SUCCESS Succeeds.
|
1443 |
+
* Other code Fails.
|
1444 |
+
*/
|
1445 |
+
DECL_EXPORT bm_status_t sg_mem_invalidate_partial_device_mem(bm_handle_t handle,
|
1446 |
+
sg_device_mem_t *dmem,
|
1447 |
+
unsigned long long offset,
|
1448 |
+
unsigned long long len);
|
1449 |
+
|
1450 |
+
/**
|
1451 |
+
* @name bm_mem_flush_device_mem
|
1452 |
+
* @brief To flush a piece of mapped device memory to maintain
|
1453 |
+
* cache coherence
|
1454 |
+
* (only valid in SoC mode; Not supported in PCIE mode).
|
1455 |
+
* @ingroup bmlib_runtime
|
1456 |
+
*
|
1457 |
+
* @param [in] handle The device handle
|
1458 |
+
* @param [in] dmem The device memory to flush
|
1459 |
+
*
|
1460 |
+
* @retval BM_SUCCESS Succeeds.
|
1461 |
+
* Other code Fails.
|
1462 |
+
*/
|
1463 |
+
DECL_EXPORT bm_status_t bm_mem_flush_device_mem(bm_handle_t handle, bm_device_mem_t *dmem);
|
1464 |
+
|
1465 |
+
/**
|
1466 |
+
* @name sg_mem_flush_device_mem
|
1467 |
+
* @brief To flush a piece of mapped device memory to maintain
|
1468 |
+
* cache coherence
|
1469 |
+
* (only valid in SoC mode; Not supported in PCIE mode).
|
1470 |
+
* @ingroup bmlib_runtime
|
1471 |
+
*
|
1472 |
+
* @param [in] handle The device handle
|
1473 |
+
* @param [in] dmem The device memory to flush
|
1474 |
+
*
|
1475 |
+
* @retval BM_SUCCESS Succeeds.
|
1476 |
+
* Other code Fails.
|
1477 |
+
*/
|
1478 |
+
DECL_EXPORT bm_status_t sg_mem_flush_device_mem(bm_handle_t handle, sg_device_mem_t *dmem);
|
1479 |
+
|
1480 |
+
/**
|
1481 |
+
* @name bm_mem_flush_partial_device_mem
|
1482 |
+
* @brief To flush part of mapped device memory to maintain
|
1483 |
+
* cache coherence
|
1484 |
+
* (only valid in SoC mode; Not supported in PCIE mode).
|
1485 |
+
* @ingroup bmlib_runtime
|
1486 |
+
*
|
1487 |
+
* @param [in] handle The device handle
|
1488 |
+
* @param [in] dmem The device memory to flush
|
1489 |
+
* @param [in] offset The offset of device memory address
|
1490 |
+
* @param [in] len The length of memory to flush in bytes
|
1491 |
+
*
|
1492 |
+
* @retval BM_SUCCESS Succeeds.
|
1493 |
+
* Other code Fails.
|
1494 |
+
*/
|
1495 |
+
DECL_EXPORT bm_status_t bm_mem_flush_partial_device_mem(bm_handle_t handle,
|
1496 |
+
bm_device_mem_t *dmem,
|
1497 |
+
unsigned int offset,
|
1498 |
+
unsigned int len);
|
1499 |
+
|
1500 |
+
/**
|
1501 |
+
* @name sg_mem_flush_partial_device_mem
|
1502 |
+
* @brief To flush part of mapped device memory to maintain
|
1503 |
+
* cache coherence
|
1504 |
+
* (only valid in SoC mode; Not supported in PCIE mode).
|
1505 |
+
* @ingroup bmlib_runtime
|
1506 |
+
*
|
1507 |
+
* @param [in] handle The device handle
|
1508 |
+
* @param [in] dmem The device memory to flush
|
1509 |
+
* @param [in] offset The offset of device memory address
|
1510 |
+
* @param [in] len The length of memory to flush in bytes
|
1511 |
+
*
|
1512 |
+
* @retval BM_SUCCESS Succeeds.
|
1513 |
+
* Other code Fails.
|
1514 |
+
*/
|
1515 |
+
DECL_EXPORT bm_status_t sg_mem_flush_partial_device_mem(bm_handle_t handle,
|
1516 |
+
sg_device_mem_t *dmem,
|
1517 |
+
unsigned long long offset,
|
1518 |
+
unsigned long long len);
|
1519 |
+
|
1520 |
+
/**
|
1521 |
+
* @name bm_mem_unmap_device_mem
|
1522 |
+
* @brief To unmap a piece of mapped device memory
|
1523 |
+
* (only valid in SoC mode; Not supported in PCIE mode).
|
1524 |
+
* @ingroup bmlib_runtime
|
1525 |
+
*
|
1526 |
+
* @param [in] handle The device handle
|
1527 |
+
* @param [in] vmem The virtual address of the mapped device memory
|
1528 |
+
* @param [in] size The size of unmapped memory
|
1529 |
+
*
|
1530 |
+
* @retval BM_SUCCESS Succeeds.
|
1531 |
+
* Other code Fails.
|
1532 |
+
*/
|
1533 |
+
DECL_EXPORT bm_status_t bm_mem_unmap_device_mem(bm_handle_t handle, void *vmem, int size);
|
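/*
 * Example (non-normative sketch, SoC mode only): map a device buffer, write it
 * through the CPU, flush the cache so the device sees the data, then unmap.
 * Needs <string.h> and <stdint.h>; error checking omitted for brevity.
 *
 *   unsigned long long va = 0;
 *   bm_mem_mmap_device_mem(handle, &mem, &va);
 *   memset((void *)(uintptr_t)va, 0, bm_mem_get_device_size(mem));
 *   bm_mem_flush_device_mem(handle, &mem);     // CPU wrote: flush to device
 *   // after the device writes, invalidate before the CPU reads:
 *   // bm_mem_invalidate_device_mem(handle, &mem);
 *   bm_mem_unmap_device_mem(handle, (void *)(uintptr_t)va,
 *                           bm_mem_get_device_size(mem));
 */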
1534 |
+
|
1535 |
+
/**
|
1536 |
+
* @name sg_mem_unmap_device_mem
|
1537 |
+
* @brief To unmap a piece of mapped device memory
|
1538 |
+
* (only valid in SoC mode; Not supported in PCIE mode).
|
1539 |
+
* @ingroup bmlib_runtime
|
1540 |
+
*
|
1541 |
+
* @param [in] handle The device handle
|
1542 |
+
* @param [in] vmem The virtual address of the mapped device memory
|
1543 |
+
* @param [in] size The size of unmapped memory
|
1544 |
+
*
|
1545 |
+
* @retval BM_SUCCESS Succeeds.
|
1546 |
+
* Other code Fails.
|
1547 |
+
*/
|
1548 |
+
DECL_EXPORT bm_status_t sg_mem_unmap_device_mem(bm_handle_t handle, void *vmem, unsigned long long size);

/*******************api(kernel) functions *************************************/
/**
 * @name bm_flush
 * @brief To synchronize APIs of the current thread. The thread will block
 *        until all the outstanding APIs of the current thread are finished.
 * @ingroup bmlib_runtime
 *
 * @param [in] handle The device handle
 */
DECL_EXPORT void bm_flush(bm_handle_t handle);

/**
 * @name bm_device_sync
 * @brief To synchronize APIs of the device. The thread will block
 *        until all the outstanding APIs of the device are finished.
 * @ingroup bmlib_runtime
 *
 * @param [in] handle The device handle
 * @retval BM_SUCCESS Succeeds.
 *         Other code Fails.
 */
DECL_EXPORT bm_status_t bm_device_sync(bm_handle_t handle);

/**
 * @name bm_handle_sync
 * @brief To synchronize APIs of the handle. The thread will block
 *        until all the outstanding APIs of the handle are finished.
 * @ingroup bmlib_runtime
 *
 * @param [in] handle The device handle
 * @retval BM_SUCCESS Succeeds.
 *         Other code Fails.
 */
DECL_EXPORT bm_status_t bm_handle_sync(bm_handle_t handle);

/**
 * @name bm_handle_sync_from_core
 * @brief To synchronize APIs of the handle. The thread will block
 *        until all the outstanding APIs of the handle are finished.
 * @ingroup bmlib_runtime
 *
 * @param [in] handle The device handle
 * @param [in] core_id The core id
 * @retval BM_SUCCESS Succeeds.
 *         Other code Fails.
 */
DECL_EXPORT bm_status_t bm_handle_sync_from_core(bm_handle_t handle, int core_id);

/**
 * @name bm_thread_sync
 * @brief To synchronize APIs of the current thread. The thread will block
 *        until all the outstanding APIs of the current thread are finished.
 * @ingroup bmlib_runtime
 *
 * @param [in] handle The device handle
 * @retval BM_SUCCESS Succeeds.
 *         Other code Fails.
 */
DECL_EXPORT bm_status_t bm_thread_sync(bm_handle_t handle);

/**
 * @name bm_thread_sync_from_core
 * @brief To synchronize APIs of the current thread. The thread will block
 *        until all the outstanding APIs of the current thread are finished.
 * @ingroup bmlib_runtime
 *
 * @param [in] handle The device handle
 * @param [in] core_id The core id
 * @retval BM_SUCCESS Succeeds.
 *         Other code Fails.
 */
DECL_EXPORT bm_status_t bm_thread_sync_from_core(bm_handle_t handle, int core_id);
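
/*
 * Illustrative sketch (an assumption, not from this header): after launching
 * work asynchronously, a caller typically blocks on one of the sync calls
 * above -- per-thread sync on a single-core device, or per-core sync when a
 * specific core was used.
 *
 *   bm_thread_sync(handle);               // wait for this thread's APIs
 *   bm_thread_sync_from_core(handle, 0);  // or wait on core 0 only
 */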

/*******************trace and profile related functions ***********************/
typedef struct bm_profile {
#ifdef __linux__
  unsigned long cdma_in_time;
  unsigned long cdma_in_counter;
  unsigned long cdma_out_time;
  unsigned long cdma_out_counter;
  unsigned long tpu_process_time;
  unsigned long tpu1_process_time;
  unsigned long sent_api_counter;
  unsigned long completed_api_counter;
#else
  unsigned long long cdma_in_time;
  unsigned long long cdma_in_counter;
  unsigned long long cdma_out_time;
  unsigned long long cdma_out_counter;
  unsigned long long tpu_process_time;
  unsigned long long tpu1_process_time;
  unsigned long long sent_api_counter;
  unsigned long long completed_api_counter;
#endif
} bm_profile_t;
/**
 * @name bm_get_profile
 * @brief To get the profile data at the moment
 * @ingroup bmlib_runtime
 *
 * @param [in] handle The device handle
 * @param [out] profile The result profile data
 * @retval BM_SUCCESS Succeeds.
 *         Other code Fails.
 */
DECL_EXPORT bm_status_t bm_get_profile(bm_handle_t handle, bm_profile_t *profile);
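
/*
 * Illustrative sketch: reading the profile counters. Field meanings follow
 * the struct above; the printf format is an assumption for the example.
 *
 *   bm_profile_t profile;
 *   if (bm_get_profile(handle, &profile) == BM_SUCCESS) {
 *     printf("sent=%llu completed=%llu\n",
 *            (unsigned long long)profile.sent_api_counter,
 *            (unsigned long long)profile.completed_api_counter);
 *   }
 */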

typedef struct bootloader_version {
  char *bl1_version;
  char *bl2_version;
  char *bl31_version;
  char *uboot_version;
} boot_loader_version;

/**
 * @name bm_get_boot_loader_version
 * @brief To get the boot_loader_version
 * @ingroup bmlib_runtime
 *
 * @param [in] handle The device handle
 * @param [out] version The result version data
 * @retval BM_SUCCESS Succeeds.
 *         Other code Fails.
 */
DECL_EXPORT bm_status_t bm_get_boot_loader_version(bm_handle_t handle, boot_loader_version *version);

/**
 * @name bm_get_vpu_instant_usage
 * @brief To get vpu usage
 * @ingroup bmlib_runtime
 *
 * @param [in] handle The device handle
 * @param [out] vpu_usage The result vpu usage
 * @retval BM_SUCCESS Succeeds.
 *         Other code Fails.
 */
DECL_EXPORT bm_status_t bm_get_vpu_instant_usage(bm_handle_t handle, int *vpu_usage);

/**
 * @name bm_get_jpu_core_usage
 * @brief To get the jpu usage
 * @ingroup bmlib_runtime
 *
 * @param [in] handle The device handle
 * @param [out] jpu_usage The result jpu usage
 * @retval BM_SUCCESS Succeeds.
 *         Other code Fails.
 */
DECL_EXPORT bm_status_t bm_get_jpu_core_usage(bm_handle_t handle, int *jpu_usage);

/**
 * @name bm_get_vpp_instant_usage
 * @brief To get the vpp usage
 * @ingroup bmlib_runtime
 *
 * @param [in] handle The device handle
 * @param [out] vpp_usage The result vpp usage
 * @retval BM_SUCCESS Succeeds.
 *         Other code Fails.
 */
DECL_EXPORT bm_status_t bm_get_vpp_instant_usage(bm_handle_t handle, int *vpp_usage);
/**
 * @name bm_get_last_api_process_time_us
 * @brief This function is deprecated.
 */
#ifdef __linux__
DECL_EXPORT bm_status_t bm_get_last_api_process_time_us(bm_handle_t handle,
                                                        unsigned long *time_us);
#else
DECL_EXPORT bm_status_t bm_get_last_api_process_time_us(bm_handle_t handle,
                                                        unsigned long long *time_us);
#endif
/*******************tpu clock and module reset related functions **************/

/**
 * @name bm_set_clk_tpu_freq
 * @brief To set the clock frequency of TPU (only valid in PCIE mode).
 * @ingroup bmlib_runtime
 *
 * @param [in] handle The device handle
 * @param [in] freq The TPU target frequency
 * @retval BM_SUCCESS Succeeds.
 *         Other code Fails.
 */
DECL_EXPORT bm_status_t bm_set_clk_tpu_freq(bm_handle_t handle, int freq);

/**
 * @name bm_get_clk_tpu_freq
 * @brief To get the clock frequency of TPU
 * @ingroup bmlib_runtime
 *
 * @param [in] handle The device handle
 * @param [out] freq The current TPU frequency
 * @retval BM_SUCCESS Succeeds.
 *         Other code Fails.
 */
DECL_EXPORT bm_status_t bm_get_clk_tpu_freq(bm_handle_t handle, int *freq);
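
/*
 * Illustrative sketch: query the current TPU clock, then request a new one
 * (the setter is PCIE mode only). The 550 value and the MHz unit are
 * assumptions for the example.
 *
 *   int freq = 0;
 *   if (bm_get_clk_tpu_freq(handle, &freq) == BM_SUCCESS)
 *     printf("current tpu freq: %d\n", freq);
 *   bm_set_clk_tpu_freq(handle, 550);
 */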

/*******************misc functions ********************************************/
struct bm_misc_info {
  int pcie_soc_mode;  /* 0---pcie; 1---soc */
  int ddr_ecc_enable; /* 0---disable; 1---enable */
  long long ddr0a_size;
  long long ddr0b_size;
  long long ddr1_size;
  long long ddr2_size;
  unsigned int chipid;
#define BM1682_CHIPID_BIT_MASK (0X1 << 0)
#define BM1684_CHIPID_BIT_MASK (0X1 << 1)
#define BM1686_CHIPID_BIT_MASK (0X1 << 2)
#ifdef __linux__
  unsigned long chipid_bit_mask;
#else
  unsigned long long chipid_bit_mask;
#endif
  unsigned int driver_version;
  int domain_bdf;
  int board_version; /* hardware board version: [23:16]-mcu sw version, [15:8]-board type, [7:0]-hw version */
  int a53_enable;
  int dyn_enable;
};

/**
 * @name bm_get_misc_info
 * @brief To get miscellaneous information of the device
 * @ingroup bmlib_runtime
 *
 * @param [in] handle The device handle
 * @param [out] pmisc_info The fetched misc info
 * @retval BM_SUCCESS Succeeds.
 *         Other code Fails.
 */
DECL_EXPORT bm_status_t bm_get_misc_info(bm_handle_t handle, struct bm_misc_info *pmisc_info);
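
/*
 * Illustrative sketch: distinguishing SoC and PCIE deployments at runtime
 * using the misc info fields and chipid masks defined above.
 *
 *   struct bm_misc_info info;
 *   if (bm_get_misc_info(handle, &info) == BM_SUCCESS) {
 *     if (info.pcie_soc_mode == 1)
 *       printf("running in SoC mode\n");
 *     if (info.chipid_bit_mask & BM1684_CHIPID_BIT_MASK)
 *       printf("BM1684 family chip\n");
 *   }
 */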

/**
 * @name bm_get_chipid
 * @brief To get the chipid of the device. (0x1682 / 0x1684 / 0x168?)
 * @ingroup bmlib_runtime
 *
 * @param [in] handle The device handle
 * @param [out] p_chipid The chip id of the device
 * @retval BM_SUCCESS Succeeds.
 *         Other code Fails.
 */
DECL_EXPORT bm_status_t bm_get_chipid(bm_handle_t handle, unsigned int *p_chipid);

#define BMLIB_LOG_QUIET   -8
#define BMLIB_LOG_PANIC    0
#define BMLIB_LOG_FATAL    8
#define BMLIB_LOG_ERROR   16
#define BMLIB_LOG_WARNING 24
#define BMLIB_LOG_INFO    32
#define BMLIB_LOG_VERBOSE 40
#define BMLIB_LOG_DEBUG   48
#define BMLIB_LOG_TRACE   56

/**
 * @name bmlib_log_get_level
 * @brief To get the bmlib log level
 * @ingroup bmlib_log
 *
 * @param void
 * @retval The current bmlib log level
 */
DECL_EXPORT int bmlib_log_get_level(void);

/**
 * @name bmlib_log_set_level
 * @brief To set the bmlib log level
 * @ingroup bmlib_log
 *
 * @param [in] level The bmlib log level to set
 * @retval void
 */
DECL_EXPORT void bmlib_log_set_level(int level);
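
/*
 * Illustrative sketch: raising bmlib verbosity for a debugging session and
 * restoring it afterwards, using the level macros defined above.
 *
 *   int old_level = bmlib_log_get_level();
 *   bmlib_log_set_level(BMLIB_LOG_DEBUG);
 *   // ... reproduce the problem ...
 *   bmlib_log_set_level(old_level);
 */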

/**
 * @name bmlib_log_set_callback
 * @brief To set a callback to get bmlib log
 * @ingroup bmlib_log
 *
 * @param [in] callback The callback function to get bmlib log
 * @retval void
 */
DECL_EXPORT void bmlib_log_set_callback(void (*callback)(const char*, int, const char*, va_list args));

/**
 * @name bm_set_debug_mode
 * @brief To set the debug mode of the TPU firmware log
 * @ingroup bmlib_log
 *
 * @param [in] handle The device handle
 * @param [in] mode The debug mode of the fw log, 0/1 for disable/enable log
 * @retval void
 */
DECL_EXPORT void bm_set_debug_mode(bm_handle_t handle, int mode);

/**
 * @name bmlib_api_dbg_callback
 * @brief To set a debug callback to get the firmware log
 * @ingroup bmlib_log
 *
 * @param [in] bmlib_api_dbg_callback callback to get the firmware log
 * @retval void
 */
typedef void (*bmlib_api_dbg_callback)(int, int, int, const char*);
// api, result, duration, log; the third int is reserved for api duration
DECL_EXPORT void bmlib_set_api_dbg_callback(bmlib_api_dbg_callback callback);

/**
 * @name bmcpu_get_cpu_status
 * @brief Get bmcpu status
 * @ingroup bmlib_log
 *
 * @param [in] handle The device handle
 * @retval BMCPU_RUNNING bmcpu is running.
 *         Other code Fails.
 */
DECL_EXPORT bm_cpu_status_t bmcpu_get_cpu_status(bm_handle_t handle);

/**
 * @name bmcpu_start_cpu
 * @brief Start cpu in pcie mode
 * @ingroup bmlib_log
 *
 * @param [in] handle The device handle
 * @param [in] boot_file Fip file
 * @param [in] core_file Itb file
 * @retval BM_SUCCESS Succeeds.
 *         Other code Fails.
 */
DECL_EXPORT bm_status_t bmcpu_start_cpu(bm_handle_t handle, char *boot_file, char *core_file);

/**
 * @name bmcpu_open_process
 * @brief Open a process to do some work
 * @ingroup bmlib_log
 *
 * @param [in] handle The device handle
 * @param [in] flags Process flags
 * @param [in] timeout Timeout value in millisecond, -1 means default value of this device
 * @retval >= 0 process handle
 *         < 0 Fails.
 */
DECL_EXPORT int bmcpu_open_process(bm_handle_t handle, unsigned int flags, int timeout);

/**
 * @name bmcpu_load_library
 * @brief Load a shared library (.so) into a specific process
 * @ingroup bmlib_log
 *
 * @param [in] handle The device handle
 * @param [in] process_handle Process handle
 * @param [in] library_file Library file path
 * @param [in] timeout Timeout value in millisecond, -1 means default value of this device
 * @retval BM_SUCCESS Succeeds.
 *         Other code Fails.
 */
DECL_EXPORT bm_status_t bmcpu_load_library(bm_handle_t handle, int process_handle, char *library_file, int timeout);

/**
 * @name bmcpu_unload_library
 * @brief Unload a shared library (.so) from a specific process
 * @ingroup bmlib_log
 *
 * @param [in] handle The device handle
 * @param [in] process_handle Process handle
 * @param [in] library_file Library file path
 * @param [in] timeout Timeout value in millisecond, -1 means default value of this device
 * @retval BM_SUCCESS Succeeds.
 *         Other code Fails.
 */
DECL_EXPORT bm_status_t bmcpu_unload_library(bm_handle_t handle, int process_handle, char *library_file, int timeout);

/**
 * @name bmcpu_exec_function
 * @brief Execute a specific function in a specific process
 * @ingroup bmlib_log
 *
 * @param [in] handle The device handle
 * @param [in] process_handle Process handle
 * @param [in] function_name Function name
 * @param [in] function_param Function parameters
 * @param [in] param_size Parameters size in bytes
 * @param [in] timeout Timeout value in millisecond, -1 means default value of this device
 * @retval 0 success.
 *         >0 code fails from bmlib
 *         <0 code fails from function
 */
DECL_EXPORT int bmcpu_exec_function(bm_handle_t handle,
                                    int process_handle,
                                    char *function_name,
                                    void *function_param,
                                    unsigned int param_size,
                                    int timeout);

#define BMCPU_EXEC_OPT_NO_FLUSH_CACHE 1
/**
 * @name bmcpu_exec_function_ext
 * @brief Execute a specific function in a specific process
 * @ingroup bmlib_log
 *
 * @param [in] handle The device handle
 * @param [in] process_handle Process handle
 * @param [in] function_name Function name
 * @param [in] function_param Function parameters
 * @param [in] param_size Parameters size in bytes
 * @param [in] opt exec options
 * @param [in] timeout Timeout value in millisecond, -1 means default value of this device
 * @retval 0 success.
 *         >0 code fails from bmlib
 *         <0 code fails from function
 */
DECL_EXPORT int bmcpu_exec_function_ext(bm_handle_t handle,
                                        int process_handle,
                                        char *function_name,
                                        void *function_param,
                                        unsigned int param_size,
                                        unsigned int opt,
                                        int timeout);

/**
 * @name bmcpu_exec_function_async
 * @brief Execute a specific function in a specific process asynchronously;
 *        the user should use bmcpu_query_exec_function_result to query the result
 * @ingroup bmlib_log
 *
 * @param [in] handle The device handle
 * @param [in] process_handle Process handle
 * @param [in] function_name Function name
 * @param [in] function_param Function param
 * @param [in] param_size Param size in bytes
 * @retval BM_SUCCESS Succeeds.
 *         Other code Fails.
 */
DECL_EXPORT bm_status_t bmcpu_exec_function_async(bm_handle_t handle,
                                                  int process_handle,
                                                  char *function_name,
                                                  void *function_param,
                                                  unsigned int param_size,
                                                  unsigned long long *api_handle);

/**
 * @name bmcpu_exec_function_async_ext
 * @brief Execute a specific function in a specific process asynchronously;
 *        the user should use bmcpu_query_exec_function_result to query the result
 * @ingroup bmlib_log
 *
 * @param [in] handle The device handle
 * @param [in] process_handle Process handle
 * @param [in] function_name Function name
 * @param [in] function_param Function param
 * @param [in] param_size Param size in bytes
 * @param [in] opt exec options
 * @retval BM_SUCCESS Succeeds.
 *         Other code Fails.
 */
DECL_EXPORT bm_status_t bmcpu_exec_function_async_ext(bm_handle_t handle,
                                                      int process_handle,
                                                      char *function_name,
                                                      void *function_param,
                                                      unsigned int param_size,
                                                      unsigned int opt,
                                                      unsigned long long *api_handle);

/**
 * @name bmcpu_query_exec_function_result
 * @brief Query the result of a function called by bmcpu_exec_function_async
 * @ingroup bmlib_log
 *
 * @param [in] handle The device handle
 * @param [in] api_handle Api handle returned by bmcpu_exec_function_async
 * @param [in] timeout Timeout value in millisecond, -1 means default value of this device
 * @retval 0 success.
 *         >0 code fails from bmlib
 *         <0 code fails from function
 */
DECL_EXPORT int bmcpu_query_exec_function_result(bm_handle_t handle, unsigned long long api_handle, int timeout);
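
/*
 * Illustrative end-to-end sketch of the bmcpu flow declared above. The file
 * paths, library name, and function symbol are placeholders (assumptions),
 * and error handling is omitted for brevity.
 *
 *   bmcpu_start_cpu(handle, "fip.bin", "ramboot.itb");     // example paths
 *   int proc = bmcpu_open_process(handle, 0, -1);
 *   bmcpu_load_library(handle, proc, "libmywork.so", -1);  // hypothetical .so
 *   int my_param = 42;
 *   int ret = bmcpu_exec_function(handle, proc, "my_func", // hypothetical symbol
 *                                 &my_param, sizeof(my_param), -1);
 *   bmcpu_close_process(handle, proc, -1);
 */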

/**
 * @name bmcpu_map_phys_addr
 * @brief Map a physical address in a specific process
 * @ingroup bmlib_log
 *
 * @param [in] handle The device handle
 * @param [in] process_handle Process handle
 * @param [in] phys_addr Physical address
 * @param [in] size Map size in bytes
 * @param [in] timeout Timeout value in millisecond, -1 means default value of this device
 * @retval >0 virtual address
 *         0 fails
 */
DECL_EXPORT void *bmcpu_map_phys_addr(bm_handle_t handle, int process_handle, void *phys_addr, unsigned int size, int timeout);

/**
 * @name bmcpu_unmap_phys_addr
 * @brief Unmap a physical address in a specific process
 * @ingroup bmlib_log
 *
 * @param [in] handle The device handle
 * @param [in] process_handle Process handle
 * @param [in] phys_addr Physical address
 * @param [in] timeout Timeout value in millisecond, -1 means default value of this device
 * @retval <0 fail
 *         0 success
 */
DECL_EXPORT bm_status_t bmcpu_unmap_phys_addr(bm_handle_t handle, int process_handle, void *phys_addr, int timeout);

/**
 * @name bmcpu_close_process
 * @brief Close a process
 * @ingroup bmlib_log
 *
 * @param [in] handle The device handle
 * @param [in] process_handle Process handle
 * @param [in] timeout Timeout value in millisecond, -1 means default value of this device
 * @retval BM_SUCCESS Succeeds.
 *         Other code Fails.
 */
DECL_EXPORT bm_status_t bmcpu_close_process(bm_handle_t handle, int process_handle, int timeout);

/**
 * @name bmcpu_reset_cpu
 * @brief Reset cpu in pcie mode
 * @ingroup bmlib_log
 *
 * @param [in] handle The device handle
 * @retval BM_SUCCESS Succeeds.
 *         Other code Fails.
 */
DECL_EXPORT bm_status_t bmcpu_reset_cpu(bm_handle_t handle);

/**
 * @name bm_enable_perf_monitor
 * @brief enable the perf monitor to get gdma and tpu performance data
 * @ingroup bmlib_perf
 *
 * @param [in] handle The device handle
 * @param [in] perf_monitor The perf monitor to use
 * @retval BM_SUCCESS Succeeds.
 *         Other code Fails.
 */
DECL_EXPORT bm_status_t bm_enable_perf_monitor(bm_handle_t handle, bm_perf_monitor_t *perf_monitor);

/**
 * @name bm_disable_perf_monitor
 * @brief disable the perf monitor that gets gdma and tpu performance data
 * @ingroup bmlib_perf
 *
 * @param [in] handle The device handle
 * @param [in] perf_monitor The perf monitor to use
 * @retval BM_SUCCESS Succeeds.
 *         Other code Fails.
 */
DECL_EXPORT bm_status_t bm_disable_perf_monitor(bm_handle_t handle, bm_perf_monitor_t *perf_monitor);

/**
 * @name bmcpu_set_log
 * @brief Set cpu log options
 * @ingroup bmlib_log
 *
 * @param [in] handle The device handle
 * @param [in] log_level 0: DEBUG 1: INFO 2: WARN 3: ERROR 4: FATAL
 * @param [in] log_to_console 1: YES 0: NO
 * @param [in] timeout Timeout value in millisecond, -1 means default value of this device
 * @retval BM_SUCCESS Succeeds.
 *         Other code Fails.
 */
DECL_EXPORT bm_status_t bmcpu_set_log(bm_handle_t handle, unsigned int log_level, unsigned int log_to_console, int timeout);

/**
 * @name bmcpu_get_log
 * @brief Get the cpu log file
 * @ingroup bmlib_log
 *
 * @param [in] handle The device handle
 * @param [in] process_handle Process handle
 * @param [in] log_file File path to save the log
 * @param [in] timeout Timeout value in millisecond, -1 means default value of this device
 * @retval BM_SUCCESS Succeeds.
 *         Other code Fails.
 */
DECL_EXPORT bm_status_t bmcpu_get_log(bm_handle_t handle, int process_handle, char *log_file, int timeout);

/**
 * @name bmcpu_sync_time
 * @brief Sync device cpu time with host
 * @ingroup bmlib_log
 *
 * @param [in] handle The device handle
 * @retval BM_SUCCESS Succeeds.
 *         Other code Fails.
 */
DECL_EXPORT bm_status_t bmcpu_sync_time(bm_handle_t handle);

/*******************trace and profile related functions ***********************/
struct bm_heap_stat {
  unsigned int mem_total;
  unsigned int mem_avail;
  unsigned int mem_used;
};

typedef struct bm_heap_stat_byte {
  unsigned int heap_id;
  unsigned long long mem_total;
  unsigned long long mem_avail;
  unsigned long long mem_used;
  unsigned long long mem_start_addr;
} bm_heap_stat_byte_t;

typedef struct bm_dev_stat {
  int mem_total;
  int mem_used;
  int tpu_util;
  int heap_num;
  struct bm_heap_stat heap_stat[4];
} bm_dev_stat_t;

/**
 * @name bm_get_stat
 * @brief To get the stat data at the moment
 * @ingroup bmlib_runtime
 *
 * @param [in] handle The device handle
 * @param [out] stat The result stat data
 * @retval BM_SUCCESS Succeeds.
 *         Other code Fails.
 */
DECL_EXPORT bm_status_t bm_get_stat(bm_handle_t handle, bm_dev_stat_t *stat);
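
/*
 * Illustrative sketch: printing overall memory usage and per-heap stats from
 * the structs defined above. The printf formats are assumptions.
 *
 *   bm_dev_stat_t stat;
 *   if (bm_get_stat(handle, &stat) == BM_SUCCESS) {
 *     printf("mem used %d / total %d, tpu util %d%%\n",
 *            stat.mem_used, stat.mem_total, stat.tpu_util);
 *     for (int i = 0; i < stat.heap_num && i < 4; ++i)
 *       printf("heap %d: avail %u\n", i, stat.heap_stat[i].mem_avail);
 *   }
 */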

/**
 * @name bm_get_gmem_heap_id
 * @brief To get the heap id of allocated global memory
 * @ingroup bmlib_runtime
 *
 * @param [in] handle The device handle
 * @param [in] pmem The allocated global memory
 * @param [out] heapid The returned heap id
 * @retval BM_SUCCESS Succeeds.
 *         Other code Fails.
 */

DECL_EXPORT bm_status_t bm_get_gmem_heap_id(bm_handle_t handle, bm_device_mem_t *pmem, unsigned int *heapid);

/**
 * @name sg_get_gmem_heap_id
 * @brief To get the heap id of allocated global memory
 * @ingroup bmlib_runtime
 *
 * @param [in] handle The device handle
 * @param [in] pmem The allocated global memory
 * @param [out] heapid The returned heap id
 * @retval BM_SUCCESS Succeeds.
 *         Other code Fails.
 */

DECL_EXPORT bm_status_t sg_get_gmem_heap_id(bm_handle_t handle, sg_device_mem_t *pmem, unsigned int *heapid);

/**
 * @name bm_get_gmem_total_heap_num
 * @brief To get the total heap num of global memory
 * @ingroup bmlib_runtime
 *
 * @param [out] heap_num The returned total heap number
 * @param [in] handle The device handle
 * @retval BM_SUCCESS Succeeds.
 *         Other code Fails.
 */
DECL_EXPORT bm_status_t bm_get_gmem_total_heap_num(bm_handle_t handle, unsigned int *heap_num);

/**
 * @name bm_get_gmem_heap_stat_byte_by_id
 * @brief To get the heap stat by heap id
 * @ingroup bmlib_runtime
 *
 * @param [in] handle The device handle
 * @param [in] heap_id The heap index to get heap status
 * @param [out] pheap_byte The returned heap status
 * @retval BM_SUCCESS Succeeds.
 *         Other code Fails.
 */
DECL_EXPORT bm_status_t bm_get_gmem_heap_stat_byte_by_id(bm_handle_t handle, bm_heap_stat_byte_t *pheap_byte, unsigned int heap_id);

DECL_EXPORT bm_status_t bm_load_firmware(
    bm_handle_t handle,
    const char *firmware_tcm,
    const char *firmware_ddr);

#define bmkernel_load_firmware okkernel_load_firmware
DECL_EXPORT bm_status_t okkernel_load_firmware(
    bm_handle_t handle,
    const char *firmware_tcm,
    const char *firmware_ddr);

DECL_EXPORT bm_status_t okkernel_launch_async(
    bm_handle_t handle,
    const char *func_name,
    const void *args,
    unsigned int size);

DECL_EXPORT bm_status_t okkernel_launch_sync(
    bm_handle_t handle,
    const char *func_name,
    const void *args,
    unsigned int size);

DECL_EXPORT bm_status_t tpu_kernel_launch_sync(
    bm_handle_t handle,
    const char *func_name,
    const void *args,
    unsigned int size);

DECL_EXPORT bm_status_t okkernel_sync(bm_handle_t handle);
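
/*
 * Illustrative sketch (an assumption, not from this header): loading TPU
 * kernel firmware and then launching a named kernel function synchronously.
 * The firmware paths, function name, and argument struct are placeholders.
 *
 *   okkernel_load_firmware(handle, "tcm_firmware.bin", "ddr_firmware.bin");
 *   struct { int n; } args = { 16 };              // hypothetical kernel args
 *   tpu_kernel_launch_sync(handle, "my_kernel", &args, sizeof(args));
 */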

/**
 * @name bmkernel_launch
 * @brief send an api to the device and launch a function
 * @ingroup bmlib_runtime
 *
 * @param [in] handle The device handle
 * @param [in] args The api cmd struct pointer
 * @param [in] size The api cmd length
 * @retval BM_SUCCESS Succeeds.
 *         Other code Fails.
 */
DECL_EXPORT bm_status_t bmkernel_launch(bm_handle_t handle, const void *args,
                                        unsigned int size);

/**
 * @name bmkernel_load_lookup_table
 * @brief load a lookup table into l2-sram
 * @ingroup bmlib_runtime
 *
 * @param [in] handle The device handle
 * @param [in] table The lookup table to load into l2-sram
 * @param [in] size The table size in bytes
 * @retval BM_SUCCESS Succeeds.
 *         Other code Fails.
 */
DECL_EXPORT bm_status_t bmkernel_load_lookup_table(bm_handle_t handle, const void* table, unsigned int size);

/*******************device management api functions ***************************/
/**
 * @name bm_get_tpu_current
 * @brief get tpu current
 * @ingroup bmlib_runtime
 *
 * @param [in] handle The device handle
 * @param [out] tpuc The pointer for tpu current (mA)
 * @retval BM_SUCCESS Succeeds.
 *         Other code Fails.
 */
DECL_EXPORT bm_status_t bm_get_tpu_current(bm_handle_t handle, unsigned int *tpuc);

/**
 * @name bm_get_board_max_power
 * @brief get the max power supported by the board
 * @ingroup bmlib_runtime
 *
 * @param [in] handle The device handle
 * @param [out] maxp The pointer for max power
 * @retval BM_SUCCESS Succeeds.
 *         Other code Fails.
 */
DECL_EXPORT bm_status_t bm_get_board_max_power(bm_handle_t handle, unsigned int *maxp);

/**
 * @name bm_get_board_power
 * @brief get board power
 * @ingroup bmlib_runtime
 *
 * @param [in] handle The device handle
 * @param [out] boardp The pointer for board power
 * @retval BM_SUCCESS Succeeds.
 *         Other code Fails.
 */
DECL_EXPORT bm_status_t bm_get_board_power(bm_handle_t handle, unsigned int *boardp);

/**
 * @name bm_get_fan_speed
 * @brief get board fan speed
 * @ingroup bmlib_runtime
 *
 * @param [in] handle The device handle
 * @param [out] fan The pointer for fan speed
 * @retval BM_SUCCESS Succeeds.
 *         Other code Fails.
 */
DECL_EXPORT bm_status_t bm_get_fan_speed(bm_handle_t handle, unsigned int *fan);

/**
 * @name bm_get_ecc_correct_num
 * @brief get ecc_correct_num
 * @ingroup device management api
 *
 * @param [in] handle The device handle
 * @param [out] ecc_correct_num
 * @retval BM_SUCCESS Succeeds.
 *         Other code Fails.
 */
#ifdef __linux__
DECL_EXPORT bm_status_t bm_get_ecc_correct_num(bm_handle_t handle, unsigned long *ecc_correct_num);
#else
DECL_EXPORT bm_status_t bm_get_ecc_correct_num(bm_handle_t handle, unsigned long long *ecc_correct_num);
#endif
/**
 * @name bm_get_12v_atx
 * @brief get atx_12v
 * @ingroup device management api
 *
 * @param [in] handle The device handle
 * @param [out] atx_12v
 * @retval BM_SUCCESS Succeeds.
 *         Other code Fails.
 */
DECL_EXPORT bm_status_t bm_get_12v_atx(bm_handle_t handle, int *atx_12v);

/**
 * @name bm_get_product_sn
 * @brief get SE5 sn
 * @ingroup device management api
 *
 * @param [out] product_sn
 * @retval BM_SUCCESS Succeeds.
 *         Other code Fails.
 */
DECL_EXPORT bm_status_t bm_get_product_sn(char *product_sn);

/**
 * @name bm_get_sn
 * @brief get sn
 * @ingroup device management api
 *
 * @param [in] handle The device handle
 * @param [out] sn
 * @retval BM_SUCCESS Succeeds.
 *         Other code Fails.
 */
DECL_EXPORT bm_status_t bm_get_sn(bm_handle_t handle, char *sn);

/**
 * @name bm_get_status
 * @brief get chip status
 * @ingroup device management api
 *
 * @param [in] handle The device handle
 * @param [out] status The board error status; each bit represents an error state:
 *              status == 0x0, board is normal; status > 0, board is abnormal;
 *              bit0 == 1, tpu is hung
 *              bit1 == 1, pcie link abnormal
 *              bit2 == 1, board temperature is too high
 * @retval BM_SUCCESS Succeeds.
 *         Other code Fails.
 */
DECL_EXPORT bm_status_t bm_get_status(bm_handle_t handle, int *status);

/**
 * @name bm_get_tpu_maxclk
 * @brief get tpu_maxclk
 * @ingroup device management api
 *
 * @param [in] handle The device handle
 * @param [out] tpu_maxclk
 * @retval BM_SUCCESS Succeeds.
 *         Other code Fails.
 */
DECL_EXPORT bm_status_t bm_get_tpu_maxclk(bm_handle_t handle, unsigned int *tpu_maxclk);

/**
 * @name bm_get_tpu_minclk
 * @brief get tpu_minclk
 * @ingroup device management api
 *
 * @param [in] handle The device handle
 * @param [out] tpu_minclk
 * @retval BM_SUCCESS Succeeds.
 *         Other code Fails.
 */
DECL_EXPORT bm_status_t bm_get_tpu_minclk(bm_handle_t handle, unsigned int *tpu_minclk);

/**
 * @name bm_get_driver_version
 * @brief get driver version
 * @ingroup device management api
 *
 * @param [in] handle The device handle
 * @param [out] driver_version
 * @retval BM_SUCCESS Succeeds.
 *         Other code Fails.
 */
DECL_EXPORT bm_status_t bm_get_driver_version(bm_handle_t handle, int *driver_version);

/**
 * @name bm_get_board_name
 * @brief get device board name
 * @ingroup device management api
 *
 * @param [in] handle The device handle
 * @param [out] name The device board name
 * @retval BM_SUCCESS Succeeds.
 *         Other code Fails.
 */
DECL_EXPORT bm_status_t bm_get_board_name(bm_handle_t handle, char *name);

/**
 * @name bm_get_board_temp
 * @brief get board temperature
 * @ingroup device management api
 *
 * @param [in] handle The device handle
 * @param [out] board_temp
 * @retval BM_SUCCESS Succeeds.
 *         Other code Fails.
 */
DECL_EXPORT bm_status_t bm_get_board_temp(bm_handle_t handle, unsigned int *board_temp);

/**
 * @name bm_get_chip_temp
 * @brief get chip temperature
 * @ingroup device management api
 *
 * @param [in] handle The device handle
 * @param [out] chip_temp
 * @retval BM_SUCCESS Succeeds.
 *         Other code Fails.
 */
DECL_EXPORT bm_status_t bm_get_chip_temp(bm_handle_t handle, unsigned int *chip_temp);

/**
 * @name bm_get_tpu_power
 * @brief get TPU power
 * @ingroup device management api
 *
 * @param [in] handle The device handle
 * @param [out] tpu_power
 * @retval BM_SUCCESS Succeeds.
 *         Other code Fails.
 */
DECL_EXPORT bm_status_t bm_get_tpu_power(bm_handle_t handle, float *tpu_power);

/**
 * @name bm_get_tpu_volt
 * @brief get TPU voltage
 * @ingroup device management api
 *
 * @param [in] handle The device handle
 * @param [out] tpu_volt
 * @retval BM_SUCCESS Succeeds.
 *         Other code Fails.
 */
DECL_EXPORT bm_status_t bm_get_tpu_volt(bm_handle_t handle, unsigned int *tpu_volt);

/**
 * @name bm_get_card_id
 * @brief get card id
 * @ingroup device management api
 *
 * @param [in] handle The device handle
 * @param [out] card_id
 * @retval BM_SUCCESS Succeeds.
 *         Other code Fails.
 */
DECL_EXPORT bm_status_t bm_get_card_id(bm_handle_t handle, unsigned int *card_id);

/**
 * @name bm_get_card_num
 * @brief get card number
 * @ingroup device management api
 *
 * @param [out] card_num
 * @retval BM_SUCCESS Succeeds.
 *         Other code Fails.
 */
DECL_EXPORT bm_status_t bm_get_card_num(unsigned int *card_num);

/**
 * @name bm_get_chip_num_from_card
 * @brief get chip number and start chip id from card
 * @ingroup device management api
 *
 * @param [in] card_id The card id
 * @param [out] chip_num
 * @param [out] dev_start_index
 * @retval BM_SUCCESS Succeeds.
 *         Other code Fails.
 */
DECL_EXPORT bm_status_t bm_get_chip_num_from_card(unsigned int card_id, unsigned int *chip_num, unsigned int *dev_start_index);

/**
 * @name bm_get_dynfreq_status
 * @brief get chip dynamic freq status
 * @ingroup device management api
 *
 * @param [in] handle The device handle
 * @param [out] dynfreq_status
 * @retval BM_SUCCESS Succeeds.
 *         Other code Fails.
 */
DECL_EXPORT bm_status_t bm_get_dynfreq_status(bm_handle_t handle, int *dynfreq_status);

/**
 * @name bm_change_dynfreq_status
 * @brief change (enable/disable) chip dynamic freq status
 * @ingroup device management api
 *
 * @param [in] handle The device handle
 * @param [in] new_status
 * @retval BM_SUCCESS Succeeds.
 *         Other code Fails.
 */
DECL_EXPORT bm_status_t bm_change_dynfreq_status(bm_handle_t handle, int new_status);

/**
 * @name bm_get_tpu_scalar_num
 * @brief To get the core number of TPU scalar
 * @ingroup bmlib_runtime
 *
 * @param [in] handle The device handle
 * @param [out] core_num The core number of TPU scalar
 * @retval BM_SUCCESS Succeeds.
 *         Other code Fails.
 */
DECL_EXPORT bm_status_t bm_get_tpu_scalar_num(bm_handle_t handle, unsigned int *core_num);

#define bm_get_tpu_core_num bm_get_tpu_scalar_num

#if defined(__cplusplus)
}
#endif

#endif /* BM_RUNTIME_H_ */
Baichuan2/src/include/bmruntime_interface.h
ADDED
@@ -0,0 +1,404 @@
/*****************************************************************************
 *
 * Copyright (c) 2016-2026 by Sophgo Technologies Inc. All rights reserved.
 *
 * The material in this file is confidential and contains trade secrets
 * of Sophgo Technologies Inc. This is proprietary information owned by
 * Sophgo Technologies Inc. No part of this work may be disclosed,
 * reproduced, copied, transmitted, or used in any way for any purpose,
 * without the express written permission of Sophgo Technologies Inc.
 *
 *****************************************************************************/

/*****************************************************************************
 * BMRuntime Interface is mainly for inference.
 * We can also use it for device computation from BMLang programming.
 * Note: please use the interface from bmlib_runtime.h for device memory operation.
 ****************************************************************************/

#ifndef BMRUNTIME_INTERFACE_H_
#define BMRUNTIME_INTERFACE_H_

#include "bmdef.h"

#ifdef _WIN32
#define DECL_EXPORT _declspec(dllexport)
#define DECL_IMPORT _declspec(dllimport)
#else
#define DECL_EXPORT
#define DECL_IMPORT
#endif

#if defined(__cplusplus)
extern "C" {
#endif

/* --------------------------------------------------------------------------*/
/* interface for basic data type */

/* get data type byte size */
DECL_EXPORT size_t bmrt_data_type_size(bm_data_type_t dtype);

/*
dims array to bm_shape_t;
shape and dims should not be NULL, num_dims should not be larger than BM_MAX_DIMS_NUM */
DECL_EXPORT void bmrt_shape(bm_shape_t* shape, const int* dims, int num_dims);

/*
number of shape elements; shape should not be NULL and num_dims should not be larger than
BM_MAX_DIMS_NUM */
DECL_EXPORT uint64_t bmrt_shape_count(const bm_shape_t* shape);
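
/*
 * Illustrative sketch: building a 4-D shape and counting its elements with
 * the two helpers above.
 *
 *   int dims[4] = {1, 3, 224, 224};
 *   bm_shape_t shape;
 *   bmrt_shape(&shape, dims, 4);
 *   uint64_t count = bmrt_shape_count(&shape);  // 1*3*224*224 = 150528
 */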

/* compare whether two shapes are the same */
DECL_EXPORT bool bmrt_shape_is_same(const bm_shape_t* left, const bm_shape_t* right);

/*
fill a tensor with data type and shape, with st_mode = 0 as default.
tensor and p_bmrt should not be NULL, shape count should not be 0.
it will alloc device mem to tensor->device_mem, so the user should call
bmrt_free_device(p_bmrt, tensor->device_mem) to free it.*/
DECL_EXPORT bool bmrt_tensor(bm_tensor_t* tensor, void* p_bmrt, bm_data_type_t dtype, bm_shape_t shape);

/*
fill a tensor with data type and shape, with st_mode = 0 as default.
tensor and p_bmrt should not be NULL, shape count should not be 0.
it will alloc device mem to tensor->device_mem on the devid-th device.*/
DECL_EXPORT bool bmrt_tensor_ex(bm_tensor_t* tensor, void* p_bmrt, int devid, bm_data_type_t dtype, bm_shape_t shape);

/* fill a tensor with existing device mem; tensor byte size should not be larger than device mem size */
DECL_EXPORT void bmrt_tensor_with_device(bm_tensor_t* tensor, bm_device_mem_t device_mem,
                                         bm_data_type_t dtype, bm_shape_t shape);

/* get tensor byte size, tensor should not be NULL */
DECL_EXPORT size_t bmrt_tensor_bytesize(const bm_tensor_t* tensor);

/* get tensor mem size allocated in device mem, tensor should not be NULL */
DECL_EXPORT size_t bmrt_tensor_device_size(const bm_tensor_t* tensor);

/* print net info for debug */
DECL_EXPORT void bmrt_print_network_info(const bm_net_info_t* net_info);

/* --------------------------------------------------------------------------*/
/**
 * @name bmrt_create
 * @brief To create the bmruntime with bm_handle.
 * @ingroup bmruntime
 *
 * This API creates the bmruntime. It returns a void* pointer, which is the
 * pointer of the bmruntime. The device id is set when getting bm_handle.
 *
 * @param [in] bm_handle bm handle. It must be initialized by using bmlib.
 *
 * @retval void* the pointer of the bmruntime
 */
DECL_EXPORT void* bmrt_create(bm_handle_t bm_handle);

/* --------------------------------------------------------------------------*/
/**
 * @name bmrt_create_ex
 * @brief To create the bmruntime with one or more bm_handle.
 * @ingroup bmruntime
 *
 * This API creates the bmruntime. It returns a void* pointer, which is the
 * pointer of the bmruntime.
 *
 * @param [in] bm_handles bm handles. They must be initialized by using bmlib.
 * @param [in] num_handles number of bm_handles.
 *
 * @retval void* the pointer of the bmruntime
 */
DECL_EXPORT void *bmrt_create_ex(bm_handle_t *bm_handles, int num_handles);

/**
 * @name bmrt_destroy
 * @brief To destroy the bmruntime pointer
 * @ingroup bmruntime
 *
 * This API destroys the bmruntime.
 *
 * @param [in] p_bmrt Bmruntime that had been created
 */
DECL_EXPORT void bmrt_destroy(void* p_bmrt);

/**
 * @name bmrt_get_bm_handle
 * @brief To get the BM runtime context.
 * @ingroup bmruntime
 *
 * This API gets the BM runtime context for using BMDNN, BMCV or BMLIB.
 *
 * @param [in] p_bmrt Bmruntime that had been created
 */
DECL_EXPORT void * bmrt_get_bm_handle(void* p_bmrt);

/**
 * @name bmrt_load_bmodel
 * @brief To load the bmodel which is created by the BM compiler
 * @ingroup bmruntime
 *
 * This API is to load a bmodel created by the BM compiler.
 * After loading the bmodel, we can run the inference of the neuron network.
 *
 * @param [in] p_bmrt Bmruntime that had been created
 * @param [in] bmodel_path Bmodel file path.
 *
 * @retval true Load context success.
 * @retval false Load context failed.
 */
DECL_EXPORT bool bmrt_load_bmodel(void* p_bmrt, const char *bmodel_path);

/**
 * @name bmrt_load_bmodel_data
 * @brief To load the bmodel which is created by the BM compiler from a buffer
 * @ingroup bmruntime
 *
 * This API is to load a bmodel created by the BM compiler.
 * After loading the bmodel, we can run the inference of the neuron network.
 * Different from bmrt_load_bmodel, the bmodel is data in host memory.
 *
 * @param [in] p_bmrt Bmruntime that had been created
 * @param [in] bmodel_data Bmodel data pointer to buffer
 * @param [in] size Bmodel data size
 *
 * @retval true Load context success.
 * @retval false Load context failed.
 */
DECL_EXPORT bool bmrt_load_bmodel_data(void* p_bmrt, const void * bmodel_data, size_t size);

/**
 * @name bmrt_show_neuron_network
 * @brief To print the names of all neuron networks
 * @ingroup bmruntime
 *
 * @param [in] p_bmrt Bmruntime that had been created
 */
DECL_EXPORT void bmrt_show_neuron_network(void* p_bmrt);

/**
 * @name bmrt_get_network_number
 * @brief To get the number of neuron networks in the bmruntime
 * @ingroup bmruntime
 *
 * @param [in] p_bmrt Bmruntime that had been created
 *
 * @retval int value The number of neuron networks.
 */
DECL_EXPORT int bmrt_get_network_number(void* p_bmrt);

/**
 * @name bmrt_get_network_names
 * @brief To get the names of all neuron networks in the bmruntime
 * @ingroup bmruntime
 *
 * @param [in] p_bmrt Bmruntime that had been created
 * @param [out] network_names The names of all neuron networks. It should be declared as
 *              (const char** networks_ = NULL) and used as the param &networks_. After
 *              this API, the user needs to free(networks_) when it is no longer needed.
 */
DECL_EXPORT void bmrt_get_network_names(void* p_bmrt, const char*** network_names);
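
/*
 * Illustrative sketch: enumerating the networks in a loaded bmodel, following
 * the ownership rule described above (the caller frees the name array).
 *
 *   const char **names = NULL;
 *   int num = bmrt_get_network_number(p_bmrt);
 *   bmrt_get_network_names(p_bmrt, &names);
 *   for (int i = 0; i < num; ++i)
 *     printf("net[%d] = %s\n", i, names[i]);
 *   free(names);
 */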

/**
 * @name bmrt_get_network_info
 * @brief To get network info by net name
 * @ingroup bmruntime
 *
 * @param [in] p_bmrt Bmruntime that had been created
 * @param [in] net_name Network name
 *
 * @retval bm_net_info_t* Pointer to net info; it need not be freed by the user.
 *         If the net name is not found, NULL is returned.
 */
DECL_EXPORT const bm_net_info_t* bmrt_get_network_info(void* p_bmrt, const char* net_name);

/**
 * @name bmrt_launch_tensor
 * @brief To launch the inference of the neuron network with setting input tensors
 * @ingroup bmruntime
 *
 * This API supports neuron networks that are static-compiled or dynamic-compiled.
 * After calling this API, inference on the TPU is launched and the CPU program will
 * not be blocked. bm_thread_sync should be called to make sure inference is finished.
 * This API supports multiple inputs and is thread-safe.
 *
 * @param [in] p_bmrt Bmruntime that had been created
 * @param [in] net_name The name of the neuron network
 * @param [in] input_tensors Array of input tensors, defined like bm_tensor_t input_tensors[input_num].
 *             The user should initialize each input tensor.
 * @param [in] input_num Input number
 * @param [out] output_tensors Array of output tensors, defined like bm_tensor_t output_tensors[output_num].
 *              This interface will alloc device mem to store the output data. The user should
 *              free each device mem by bm_free_device after the result data is no longer used.
 * @param [in] output_num Output number
 *
 * @retval true Launch success.
 * @retval false Launch failed.
 */
DECL_EXPORT bool bmrt_launch_tensor(void* p_bmrt, const char * net_name, const bm_tensor_t input_tensors[], int input_num,
                                    bm_tensor_t output_tensors[], int output_num);
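
/*
 * Illustrative end-to-end inference sketch using the APIs above together with
 * bmlib. The bmodel path and net name are placeholders, the net is assumed to
 * have one input and one output, and error handling is omitted. The
 * bm_net_info_t field accesses follow the definitions assumed from bmdef.h.
 *
 *   bm_handle_t handle;
 *   bm_dev_request(&handle, 0);
 *   void *p_bmrt = bmrt_create(handle);
 *   bmrt_load_bmodel(p_bmrt, "model.bmodel");                // placeholder path
 *   const bm_net_info_t *info = bmrt_get_network_info(p_bmrt, "net0"); // placeholder name
 *   bm_tensor_t in, out;
 *   bmrt_tensor(&in, p_bmrt, info->input_dtypes[0], info->stages[0].input_shapes[0]);
 *   bm_memcpy_s2d(handle, in.device_mem, host_input);        // host_input prepared by caller
 *   bmrt_launch_tensor(p_bmrt, "net0", &in, 1, &out, 1);
 *   bm_thread_sync(handle);                                  // wait for inference
 *   bm_memcpy_d2s(handle, host_output, out.device_mem);
 *   bm_free_device(handle, in.device_mem);
 *   bm_free_device(handle, out.device_mem);
 *   bmrt_destroy(p_bmrt);
 *   bm_dev_free(handle);
 */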

/**
 * @name bmrt_launch_tensor_ex
 * @brief To launch the inference of the neuron network with setting input tensors
 * @ingroup bmruntime
 *
 * This API supports neuron networks that are static-compiled or dynamic-compiled.
 * After calling this API, inference on the TPU is launched and the CPU program will
 * not be blocked. bm_thread_sync should be called to make sure inference is finished.
 * This API supports multiple inputs and is thread-safe.
 *
 * @param [in] p_bmrt Bmruntime that had been created
 * @param [in] net_name The name of the neuron network
 * @param [in] input_tensors Array of input tensors, defined like bm_tensor_t input_tensors[input_num].
 *             The user should initialize each input tensor.
 * @param [in] input_num Input number
 * @param [out] output_tensors Array of output tensors, defined like bm_tensor_t output_tensors[output_num].
 *              The user can set device_mem or stmode of the output tensors. If user_mem is true,
 *              this interface will use the device mem of output_tensors to store output data and
 *              will not alloc device mem; otherwise it will alloc device mem to store the output.
 *              If user_stmode is true, it will use the stmode in each output tensor; otherwise
 *              stmode will be BM_STORE_1N as default.
 * @param [in] output_num Output number
 * @param [in] user_mem whether device_mem of the output tensors is set
 * @param [in] user_stmode whether stmode of the output tensors is set
 *
 * @retval true Launch success.
 * @retval false Launch failed.
 */
DECL_EXPORT bool bmrt_launch_tensor_ex(void* p_bmrt, const char * net_name, const bm_tensor_t input_tensors[], int input_num,
                                       bm_tensor_t output_tensors[], int output_num, bool user_mem, bool user_stmode);

/**
 * @name bmrt_launch_data
 * @brief To launch the inference of the neuron network with setting input data in system memory
 * @ingroup bmruntime
 *
 * This API supports neuron networks that are static-compiled or dynamic-compiled.
 * After calling this API, inference on the TPU is launched and the CPU program will be blocked.
 * This API supports multiple inputs and is thread-safe.
 *
 * @param [in] p_bmrt Bmruntime that had been created
 * @param [in] net_name The name of the neuron network
 * @param [in] input_datas Array of input data, defined like void * input_datas[input_num].
 *             The user should initialize each data pointer as input.
 * @param [in] input_shapes Array of input shapes, defined like bm_shape_t input_shapes[input_num].
 *             The user should set each input shape.
 * @param [in] input_num Input number
 * @param [out] output_datas Array of output data, defined like void * output_datas[output_num].
 *              If the user doesn't alloc each output data, set user_mem to false and this api
 *              will alloc output mem; the user should free each output mem when the output data
 *              is no longer used. Alternatively, the user can alloc system memory for each
 *              output data and set user_mem = true.
 * @param [out] output_shapes Array of output shapes, defined like bm_shape_t output_shapes[output_num].
 *              It will store each output shape.
 * @param [in] output_num Output number
 * @param [in] user_mem whether output_datas[i] have allocated memory
 *
 * @retval true Launch success.
 * @retval false Launch failed.
 */
DECL_EXPORT bool bmrt_launch_data(void* p_bmrt, const char* net_name, void* const input_datas[],
                                  const bm_shape_t input_shapes[], int input_num, void * output_datas[],
                                  bm_shape_t output_shapes[], int output_num, bool user_mem);

/**
 * @name bmrt_trace
 * @brief To check the runtime environment and collect info for DEBUG
 * @ingroup bmruntime
 *
 * This API is to collect runtime info for DEBUG. Especially when a launch result suddenly goes
 * wrong, calling bmrt_trace will show whether device mems are broken, along with other check info.
 *
 * @param [in] p_bmrt Bmruntime that had been created
 */
DECL_EXPORT void bmrt_trace(void* p_bmrt);

/**
 * @name bmrt_launch_tensor_multi_cores
 * @brief To launch the inference of the neuron network with setting input tensors, with support for multi-core inference.
 * @ingroup bmruntime
 *
 * This API supports neuron networks that are static-compiled or dynamic-compiled.
 * After calling this API, inference on the TPU is launched and the CPU program will not
 * be blocked. bm_thread_sync_from_core should be called to make sure inference is finished.
 * This API supports multiple inputs and is thread-safe.
 *
 * @param [in] p_bmrt Bmruntime that had been created
 * @param [in] net_name The name of the neuron network
 * @param [in] input_tensors Array of input tensors, defined like bm_tensor_t input_tensors[input_num].
 *             The user should initialize each input tensor.
 * @param [in] input_num Input number
 * @param [out] output_tensors Array of output tensors, defined like bm_tensor_t output_tensors[output_num].
 *              The user can set device_mem or stmode of the output tensors. If user_mem is true,
 *              this interface will use the device mem of output_tensors to store output data and
 *              will not alloc device mem; otherwise it will alloc device mem to store the output.
 *              If user_stmode is true, it will use the stmode in each output tensor; otherwise
 *              stmode will be BM_STORE_1N as default.
 * @param [in] output_num Output number
 * @param [in] user_mem whether device_mem of the output tensors is set
 * @param [in] user_stmode whether stmode of the output tensors is set
 * @param [in] core_list list of core ids that will be used for inference
 * @param [in] core_num number of cores in the core list
 *
 * @retval true Launch success.
 * @retval false Launch failed.
 */
DECL_EXPORT bool bmrt_launch_tensor_multi_cores(
    void *p_bmrt,
    const char *net_name,
    const bm_tensor_t input_tensors[],
    int input_num,
    bm_tensor_t output_tensors[],
    int output_num,
    bool user_mem,
    bool user_stmode,
    const int *core_list,
    int core_num);

/**
 * @name bmrt_memcpy_s2d_parallel
 * @brief To copy data from system memory to multi-device memory in parallel
 * @ingroup bmruntime
 *
 * This API can only be used when p_bmrt was created with bmrt_create_ex on multiple devices.
 * After calling this API, datas[:tensor_num[0]] will be copied to the first device,
 * datas[tensor_num[0]:tensor_num[0]+tensor_num[1]] will be copied to the second device, and so on.
 * Copying data to different devices is done in parallel; copying to the same device is done in sequence.
 *
 * @param [in] p_bmrt Bmruntime that had been created with multi bm_handles
 * @param [in] tensors Array of tensors that will be copied to devices
 * @param [in] datas Array of data buffers allocated in system memory
 * @param [in] tensor_num Array of tensor numbers that will be copied to each device
 * @param [in] device_num Device number
 */
DECL_EXPORT bool bmrt_memcpy_s2d_parallel(
    void *p_bmrt,
    bm_tensor_t tensors[],
    void *datas[],
|
374 |
+
int tensor_num[],
|
375 |
+
int device_num);
|
376 |
+
|
377 |
+
/**
|
378 |
+
* @name bmrt_memcpy_d2s_parallel
|
379 |
+
* @brief To copy data from muti-devices memory to system memory in parallel
|
380 |
+
* @ingroup bmruntime
|
381 |
+
*
|
382 |
+
* This API only could be used when the p_bmrt is created with bmrt_create_ex on multi devices.
|
383 |
+
* After calling this API, tensors on the first device will be copied to datas[:tensor_num[0]] , and
|
384 |
+
* tensors on the second device will be copied to datas[tensor_num[0]:tensor_num[0]+tensor_num[1]] and so on.
|
385 |
+
* The process of copying data from different devices is done in parallel and from the same device is in sequence.
|
386 |
+
*
|
387 |
+
* @param [in] p_bmrt Bmruntime that had been created with multi bm_handles
|
388 |
+
* @param [in] datas Array of satas allocated in system memory
|
389 |
+
* @param [in] tensors Array of tensors that will be copied from devices
|
390 |
+
* @param [in] tensor_num Array of tensor_num that will be copied from each device
|
391 |
+
* @param [in] device_num Device number
|
392 |
+
*/
|
393 |
+
DECL_EXPORT bool bmrt_memcpy_d2s_parallel(
|
394 |
+
void *p_bmrt,
|
395 |
+
void *datas[],
|
396 |
+
bm_tensor_t tensors[],
|
397 |
+
int tensor_num[],
|
398 |
+
int device_num);
|
399 |
+
|
400 |
+
#if defined (__cplusplus)
|
401 |
+
}
|
402 |
+
#endif
|
403 |
+
|
404 |
+
#endif
|
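The launch APIs above are easiest to read next to a concrete call sequence. The sketch below shows the synchronous single-net flow that the demos in this repo follow (request device, create runtime, load bmodel, copy in, launch, sync, copy out). It is a minimal sketch, assuming a bmodel that contains a net named "demo_net" with one input and one output; that net name and layout are illustrative assumptions, not part of the API.

// Minimal sketch of the synchronous launch flow; "demo_net" is hypothetical.
#include <assert.h>
#include "bmruntime_interface.h"

void run_once(int devid, const char *bmodel, void *in_sys, void *out_sys) {
  bm_handle_t handle;
  assert(BM_SUCCESS == bm_dev_request(&handle, devid));
  void *p_bmrt = bmrt_create(handle);
  assert(bmrt_load_bmodel(p_bmrt, bmodel));

  const bm_net_info_t *net = bmrt_get_network_info(p_bmrt, "demo_net");
  bm_tensor_t in, out;
  assert(bmrt_tensor(&in, p_bmrt, net->input_dtypes[0],
                     net->stages[0].input_shapes[0]));
  assert(bmrt_tensor(&out, p_bmrt, net->output_dtypes[0],
                     net->stages[0].output_shapes[0]));

  bm_memcpy_s2d(handle, in.device_mem, in_sys);   // host -> device
  assert(bmrt_launch_tensor_ex(p_bmrt, "demo_net", &in, 1, &out, 1,
                               true /*user_mem*/, false /*user_stmode*/));
  bm_thread_sync(handle);                         // launch is asynchronous
  bm_memcpy_d2s(handle, out_sys, out.device_mem); // device -> host

  bm_free_device(handle, in.device_mem);
  bm_free_device(handle, out.device_mem);
  bmrt_destroy(p_bmrt);
  bm_dev_free(handle);
}

The same pattern, with per-layer nets and a KV cache, is what chat.cpp below builds on.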
Baichuan2/src/include/sentencepiece/sentencepiece_processor.h
ADDED
@@ -0,0 +1,727 @@
// Copyright 2016 Google Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#ifndef SENTENCEPIECE_PROCESSOR_H_
#define SENTENCEPIECE_PROCESSOR_H_

#include <cstring>
#include <memory>
#include <string>
#include <string_view>
#include <utility>
#include <vector>

#ifndef SWIG
namespace absl {
using std::string_view;
}  // namespace absl
#endif  // SWIG

namespace sentencepiece {
namespace util {

enum class StatusCode : int {
  kOk = 0,
  kCancelled = 1,
  kUnknown = 2,
  kInvalidArgument = 3,
  kDeadlineExceeded = 4,
  kNotFound = 5,
  kAlreadyExists = 6,
  kPermissionDenied = 7,
  kResourceExhausted = 8,
  kFailedPrecondition = 9,
  kAborted = 10,
  kOutOfRange = 11,
  kUnimplemented = 12,
  kInternal = 13,
  kUnavailable = 14,
  kDataLoss = 15,
  kUnauthenticated = 16,
};

class Status {
 public:
  Status();
  ~Status();
  Status(StatusCode code, absl::string_view error_message);
  Status(const Status &s);
  void operator=(const Status &s);
  bool operator==(const Status &s) const;
  bool operator!=(const Status &s) const;
  inline bool ok() const { return rep_ == nullptr; }

  void set_error_message(const char *str);
  const char *error_message() const;
  const char *message() const { return error_message(); }
  StatusCode code() const;
  std::string ToString() const;

  void IgnoreError();

 private:
  struct Rep;
  std::unique_ptr<Rep> rep_;
};
}  // namespace util

// SentencePieceProcessor:
// Simple and language independent tokenizer and de-tokenizer for
// Neural Network Machine Translation.
//
// SentencePieceProcessor provides Encode() and Decode() methods,
// which correspond to tokenization and de-tokenization respectively.
//
// - Encode:
//   Given a raw source sentence, encode it into a sequence
//   of pieces or vocabulary ids.
//
// - Decode:
//   Given a sequence of pieces or vocabulary ids, decode it
//   into a de-tokenized raw sentence.
//
// SentencePieceProcessor provides a lossless data conversion
// that allows the original raw sentence to be perfectly reconstructed
// from the encoded data, i.e., Decode(Encode(input)) == input.
// This characteristic is useful, as it makes the de-tokenization
// completely language independent.
//
// Usage:
//   SentencePieceProcessor sp;
//   sp.Load("//path/to/model");
//
//   vector<string> sps;
//   sp.Encode("hello world.", &sps).IgnoreError();
//
//   vector<int> ids;
//   sp.Encode("hello world.", &ids).IgnoreError();
//
//   string detok;
//   sp.Decode(sps, &detok);
//   CHECK_EQ("hello world.", detok).IgnoreError();
//
//   sp.Decode(ids, &detok);
//   CHECK_EQ("hello world.", detok).IgnoreError();
//
// We can also use SentencePieceText which manages the byte-offsets
// between user input (output) and internal sentence pieces.
//
//   SentencePieceText spt;
//   sp.Encode("hello world.", &spt);
//   // Emits the byte range of each piece.
//   for (const auto &piece : spt.pieces()) {
//     LOG(INFO) << piece.begin() << " " << piece.end();
//   }
//
//   sp.Decode({0, 1, 2, 3..}, &spt);
//   for (const auto &piece : spt.pieces()) {
//     LOG(INFO) << piece.begin() << " " << piece.end();
//   }
//

class NBestSentencePieceText;
class ModelInterface;
class SentencePieceText;
class ModelProto;

namespace normalizer {
class Normalizer;
}  // namespace normalizer

#ifndef SWIGGO
namespace util {
// Redefine std::string for serialized_proto interface as Python's string is
// a Unicode string. We can enforce the return value to be a raw byte sequence
// with SWIG's typemap.
using bytes = std::string;
}  // namespace util
#endif  // SWIGGO

class NBestSentencePieceText;
class ModelInterface;
class SentencePieceText;
class SentencePieceText_SentencePiece;

// Wrapper class of SentencePieceText
// This wrapper only allows an immutable access to the proto and
// hides the actual implementation of protobuf.
// See sentencepiece.proto for the details of this class.
class ImmutableSentencePieceText_ImmutableSentencePiece {
 public:
  ImmutableSentencePieceText_ImmutableSentencePiece();
  ~ImmutableSentencePieceText_ImmutableSentencePiece() = default;

  const std::string &piece() const;
  const std::string &surface() const;
  uint32_t id() const;
  uint32_t begin() const;
  uint32_t end() const;

  friend class ImmutableSentencePieceText;

 private:
  explicit ImmutableSentencePieceText_ImmutableSentencePiece(
      const SentencePieceText_SentencePiece &sp);
  const SentencePieceText_SentencePiece *sp_ = nullptr;
};

class ImmutableSentencePieceText {
 public:
  ImmutableSentencePieceText();
  virtual ~ImmutableSentencePieceText();

  std::vector<ImmutableSentencePieceText_ImmutableSentencePiece> pieces() const;

  size_t pieces_size() const;
  ImmutableSentencePieceText_ImmutableSentencePiece pieces(int index) const;

  const std::string &text() const;
  float score() const;

  util::bytes SerializeAsString() const;

  // Returns the actual mutable proto.
  // Do not use this outside of SentencePieceProcessor, as
  // it returns the raw pointer managed by the shared_ptr.
  SentencePieceText *mutable_proto();

  // Converts the utf8 byte spans into Unicode char spans.
  void ConvertToUnicodeSpans();

  friend class ImmutableNBestSentencePieceText;

 private:
  explicit ImmutableSentencePieceText(const SentencePieceText &spt);
  const SentencePieceText *spt_ = nullptr;
  std::shared_ptr<SentencePieceText> rep_;
};

// Wrapper class of SentencePieceText
// This wrapper only allows an immutable access to the proto and
// hides the actual implementation of protobuf.
// See sentencepiece.proto for the details of this class.
class ImmutableNBestSentencePieceText {
 public:
  ImmutableNBestSentencePieceText();
  virtual ~ImmutableNBestSentencePieceText();

  std::vector<ImmutableSentencePieceText> nbests() const;

  size_t nbests_size() const;
  ImmutableSentencePieceText nbests(int index) const;

  util::bytes SerializeAsString() const;

  // Returns the actual mutable proto.
  // Do not use this outside of SentencePieceProcessor, as
  // it returns the raw pointer managed by the shared_ptr.
  NBestSentencePieceText *mutable_proto();

  void ConvertToUnicodeSpans();

 private:
  std::shared_ptr<NBestSentencePieceText> rep_;
};

class SentencePieceProcessor {
 public:
  SentencePieceProcessor();
  virtual ~SentencePieceProcessor();

  // Loads model from `filename`.
  // Returns false if `filename` cannot be loaded.
  virtual util::Status Load(absl::string_view filename);

  // Loads model from `filename`.
  // Crashes if `filename` cannot be loaded.
  virtual void LoadOrDie(absl::string_view filename);

  // Loads model from `model_proto`.
  // `model_proto` is copied.
  virtual util::Status Load(const ModelProto &model_proto);

  // Loads model from `model_proto`.
  // `model_proto` is moved.
  virtual util::Status Load(std::unique_ptr<ModelProto> model_proto);

  // Loads model from `serialized`, which is a string-serialized model proto.
  // Useful to load the model from a platform independent blob object.
  virtual util::Status LoadFromSerializedProto(absl::string_view serialized);

  // Returns the status. Encode/Decode methods are valid when status is OK.
  virtual util::Status status() const;

  // Sets encode extra_option sequence.
  virtual util::Status SetEncodeExtraOptions(absl::string_view extra_option);

  // Sets decode extra_option sequence.
  virtual util::Status SetDecodeExtraOptions(absl::string_view extra_option);

  //////////////////////////////////////////////////////////////
  // Vocabulary restriction.
  // Background:
  // https://github.com/rsennrich/subword-nmt#best-practice-advice-for-byte-pair-encoding-in-nmt

  // Restricts the vocabulary set.
  // The input sentences are encoded into the tokens in `valid_vocab`.
  virtual util::Status SetVocabulary(
      const std::vector<absl::string_view> &valid_vocab);

  // Reverts the vocabulary restriction.
  virtual util::Status ResetVocabulary();

  // Loads the valid vocabulary set from `filename` in TSV format.
  // Format: <token> <tab> <freq>.
  // Any token with frequency < threshold will be treated as OOV.
  virtual util::Status LoadVocabulary(absl::string_view filename,
                                      int threshold);

  //////////////////////////////////////////////////////////////
  // Simple Encode and Decode API.
  //
  // Given a UTF8 input, encodes it into a sequence of sentence pieces.
  virtual util::Status Encode(absl::string_view input,
                              std::vector<std::string> *pieces) const;

  // Given a UTF8 input, encodes it into a sequence of ids.
  virtual util::Status Encode(absl::string_view input,
                              std::vector<int> *ids) const;

  // Given a sequence of pieces, decodes it into a detokenized output.
  virtual util::Status Decode(const std::vector<std::string> &pieces,
                              std::string *detokenized) const;

  // Given a sequence of pieces, decodes it into a detokenized output.
  virtual util::Status Decode(const std::vector<absl::string_view> &pieces,
                              std::string *detokenized) const;

  // Given a sequence of ids, decodes it into a detokenized output.
  virtual util::Status Decode(const std::vector<int> &ids,
                              std::string *detokenized) const;

  //////////////////////////////////////////////////////////////
  // NBest API.
  //
  // Same as Encode, but returns nbest results.
  virtual util::Status NBestEncode(
      absl::string_view input, int nbest_size,
      std::vector<std::vector<std::string>> *pieces) const;

  // Same as Encode, but returns nbest results.
  virtual util::Status NBestEncode(absl::string_view input, int nbest_size,
                                   std::vector<std::vector<int>> *ids) const;

  //////////////////////////////////////////////////////////////
  // Sampling API.
  //
  // Unigram and BPE support sampling mode.
  // - Unigram (--model_type=unigram):
  //   `nbest_size`: When `nbest_size` is a positive value, approximately samples
  //   one segmentation from the nbest candidates. When `nbest_size` is a negative
  //   value, samples one segmentation from the hypotheses (Lattice) according to
  //   the generation probabilities using the forward-filtering and
  //   backward-sampling algorithm.
  //   `alpha`: Smoothing parameter (inverse temperature). The best segmentation
  //   (Viterbi segmentation) is more likely sampled when setting a larger alpha.
  //   When alpha is 0.0, one segmentation is uniformly sampled from the nbest or
  //   lattice. `nbest_size` and `alpha` correspond to parameters `l` and `alpha`
  //   in https://arxiv.org/abs/1804.10959 (nbest_size < 0 means l = infinity)
  //
  // - BPE (--model_type=bpe):
  //   `alpha`: The dropout probability `p` of bpe merge operations in
  //   https://arxiv.org/abs/1910.13267 Nbest-based sampling is not supported so
  //   the nbest_size parameter is ignored in BPE.
  virtual util::Status SampleEncode(absl::string_view input, int nbest_size,
                                    float alpha,
                                    std::vector<std::string> *pieces) const;

  // Same as above, but returns a sequence of ids.
  virtual util::Status SampleEncode(absl::string_view input, int nbest_size,
                                    float alpha, std::vector<int> *ids) const;

  //////////////////////////////////////////////////////////////
  // SampleEncodeAndScore API.
  //
  // Sample `samples` many tokenisations from the segmentation lattice.
  // These methods are only available in model_type=unigram.
  //
  // `alpha`: smoothing parameter (inverse temperature). The same as `alpha` in
  // the `Sample` method.
  // `wor`: If `wor` is true, the samples are taken without replacement, and the
  // scores are the inclusion probabilities of the elements in the sample;
  // otherwise the samples are taken with replacement and the scores are the
  // log-probs of the sample elements.
  // `include_best`: If `include_best` is true, the best tokenisation is always
  // included in the sample, and the remaining elements are sampled excluding
  // the best.
  virtual util::Status SampleEncodeAndScore(
      absl::string_view input, int num_samples, float alpha, bool wor,
      bool include_best,
      std::vector<std::pair<std::vector<std::string>, float>> *pieces) const;

  // Same as above, but returns a sequence of ids.
  virtual util::Status SampleEncodeAndScore(
      absl::string_view input, int num_samples, float alpha, bool wor,
      bool include_best,
      std::vector<std::pair<std::vector<int>, float>> *ids) const;

  //////////////////////////////////////////////////////////////
  // Entropy API.
  //
  // This is only available in model_type=unigram.
  // Calculates the entropy of the possible tokenisations.
  virtual util::Status CalculateEntropy(absl::string_view input, float alpha,
                                        float *entropy) const;

  //////////////////////////////////////////////////////////////
  // Advanced API returning SentencePieceText, which manages
  // utf8-byte alignments between user-input/detokenized text
  // and the internal sentencepiece sequence.
  //
  // Given a UTF8 input, encodes it into SentencePieceText.
  //
  // When using these APIs, the sentencepiece.pb.h header file must be included.
  // We can also use ImmutableSentencePieceText as follows.
  //
  // ImmutableSentencePieceText spt;
  // Encode("hello", spt.mutable_proto()).IgnoreError();
  // std::cout << spt.pieces_size() << std::endl;
  virtual util::Status Encode(absl::string_view input,
                              SentencePieceText *spt) const;

  virtual util::Status NBestEncode(absl::string_view input, int nbest_size,
                                   NBestSentencePieceText *nbest_spt) const;

  virtual util::Status SampleEncode(absl::string_view input, int nbest_size,
                                    float alpha, SentencePieceText *spt) const;

  virtual util::Status SampleEncodeAndScore(
      absl::string_view input, int num_samples, float alpha, bool wor,
      bool include_best, NBestSentencePieceText *samples_spt) const;

  // DEPRECATED: Remove this API and use std::vector<std::string_view>
  virtual util::Status Decode(const std::vector<std::string> &pieces,
                              SentencePieceText *spt) const;

  virtual util::Status Decode(const std::vector<absl::string_view> &pieces,
                              SentencePieceText *spt) const;

  virtual util::Status Decode(const std::vector<int> &ids,
                              SentencePieceText *spt) const;
#ifdef SWIG
#define SPP_SWIG_CHECK_AND_THROW \
  if (!status.ok()) throw status;
#else
#define SPP_SWIG_CHECK_AND_THROW \
  if (!status.ok()) {            \
  }
#endif  // SWIG

#define DEFINE_SPP_DIRECT_FUNC_IMPL(FuncName, OutType, ...) \
  OutType output;                                           \
  const auto status = FuncName(__VA_ARGS__, &output);       \
  SPP_SWIG_CHECK_AND_THROW;                                 \
  return output;

#define DEFINE_SPP_SERIALIZED_PROTO_IMPL(FuncName, OutType, ...)     \
  OutType output;                                                    \
  const auto status = FuncName(__VA_ARGS__, output.mutable_proto()); \
  SPP_SWIG_CHECK_AND_THROW;                                          \
  return output.SerializeAsString();

#define DEFINE_SPP_IMMUTABLE_PROTO_IMPL(FuncName, OutType, ...)      \
  OutType output;                                                    \
  const auto status = FuncName(__VA_ARGS__, output.mutable_proto()); \
  SPP_SWIG_CHECK_AND_THROW;                                          \
  return output;

  //////////////////////////////////////////////////////////////
  // Handy methods that return the result directly.
  // These functions ignore internal errors.
  virtual std::vector<std::string> EncodeAsPieces(
      absl::string_view input) const {
    DEFINE_SPP_DIRECT_FUNC_IMPL(Encode, std::vector<std::string>, input);
  }

  virtual std::vector<int> EncodeAsIds(absl::string_view input) const {
    DEFINE_SPP_DIRECT_FUNC_IMPL(Encode, std::vector<int>, input);
  }

  virtual std::vector<std::vector<std::string>> NBestEncodeAsPieces(
      absl::string_view input, int nbest_size) const {
    DEFINE_SPP_DIRECT_FUNC_IMPL(
        NBestEncode, std::vector<std::vector<std::string>>, input, nbest_size);
  }

  virtual std::vector<std::vector<int>> NBestEncodeAsIds(
      absl::string_view input, int nbest_size) const {
    DEFINE_SPP_DIRECT_FUNC_IMPL(NBestEncode, std::vector<std::vector<int>>,
                                input, nbest_size);
  }

  virtual std::vector<std::string> SampleEncodeAsPieces(absl::string_view input,
                                                        int nbest_size,
                                                        float alpha) const {
    DEFINE_SPP_DIRECT_FUNC_IMPL(SampleEncode, std::vector<std::string>, input,
                                nbest_size, alpha);
  }

  virtual std::vector<int> SampleEncodeAsIds(absl::string_view input,
                                             int nbest_size,
                                             float alpha) const {
    DEFINE_SPP_DIRECT_FUNC_IMPL(SampleEncode, std::vector<int>, input,
                                nbest_size, alpha);
  }

  virtual std::vector<std::pair<std::vector<std::string>, float>>
  SampleEncodeAndScoreAsPieces(absl::string_view input, int num_samples,
                               float alpha, bool wor, bool include_best) const {
    using _T = std::vector<std::pair<std::vector<std::string>, float>>;
    DEFINE_SPP_DIRECT_FUNC_IMPL(SampleEncodeAndScore, _T, input, num_samples,
                                alpha, wor, include_best);
  }

  virtual std::vector<std::pair<std::vector<int>, float>>
  SampleEncodeAndScoreAsIds(absl::string_view input, int num_samples,
                            float alpha, bool wor, bool include_best) const {
    using _T = std::vector<std::pair<std::vector<int>, float>>;
    DEFINE_SPP_DIRECT_FUNC_IMPL(SampleEncodeAndScore, _T, input, num_samples,
                                alpha, wor, include_best);
  }

  // DEPRECATED: Remove this API and use std::vector<std::string_view>
  virtual std::string DecodePieces(
      const std::vector<std::string> &pieces) const {
    DEFINE_SPP_DIRECT_FUNC_IMPL(Decode, std::string, pieces);
  }

  virtual std::string DecodePieces(
      const std::vector<absl::string_view> &pieces) const {
    DEFINE_SPP_DIRECT_FUNC_IMPL(Decode, std::string, pieces);
  }

  virtual std::string DecodeIds(const std::vector<int> &ids) const {
    DEFINE_SPP_DIRECT_FUNC_IMPL(Decode, std::string, ids);
  }

  virtual float CalculateEntropy(absl::string_view text, float alpha) const {
    DEFINE_SPP_DIRECT_FUNC_IMPL(CalculateEntropy, float, text, alpha);
  }

  //////////////////////////////////////////////////////////////
  // SerializedProto API. (DEPRECATED). Use the ImmutableProto API.
  // They are used in the Python interface. Returns a serialized proto.
  // In the python module, we can get access to the full Proto after
  // deserializing the returned byte sequence.
  virtual util::bytes EncodeAsSerializedProto(absl::string_view input) const {
    DEFINE_SPP_SERIALIZED_PROTO_IMPL(Encode, ImmutableSentencePieceText, input);
  }

  virtual util::bytes SampleEncodeAsSerializedProto(absl::string_view input,
                                                    int nbest_size,
                                                    float alpha) const {
    DEFINE_SPP_SERIALIZED_PROTO_IMPL(SampleEncode, ImmutableSentencePieceText,
                                     input, nbest_size, alpha);
  }

  virtual util::bytes NBestEncodeAsSerializedProto(absl::string_view input,
                                                   int nbest_size) const {
    DEFINE_SPP_SERIALIZED_PROTO_IMPL(
        NBestEncode, ImmutableNBestSentencePieceText, input, nbest_size);
  }

  virtual util::bytes SampleEncodeAndScoreAsSerializedProto(
      absl::string_view input, int num_samples, float alpha, bool wor,
      bool include_best) const {
    DEFINE_SPP_SERIALIZED_PROTO_IMPL(SampleEncodeAndScore,
                                     ImmutableNBestSentencePieceText, input,
                                     num_samples, alpha, wor, include_best);
  }

  // TODO(taku): Remove this API and use std::vector<std::string_view>
  virtual util::bytes DecodePiecesAsSerializedProto(
      const std::vector<std::string> &pieces) const {
    DEFINE_SPP_SERIALIZED_PROTO_IMPL(Decode, ImmutableSentencePieceText,
                                     pieces);
  }

  virtual util::bytes DecodePiecesAsSerializedProto(
      const std::vector<absl::string_view> &pieces) const {
    DEFINE_SPP_SERIALIZED_PROTO_IMPL(Decode, ImmutableSentencePieceText,
                                     pieces);
  }

  virtual util::bytes DecodeIdsAsSerializedProto(
      const std::vector<int> &ids) const {
    DEFINE_SPP_SERIALIZED_PROTO_IMPL(Decode, ImmutableSentencePieceText, ids);
  }

  //////////////////////////////////////////////////////////////
  // ImmutableProto API.
  virtual ImmutableSentencePieceText EncodeAsImmutableProto(
      absl::string_view input) const {
    DEFINE_SPP_IMMUTABLE_PROTO_IMPL(Encode, ImmutableSentencePieceText, input);
  }

  virtual ImmutableSentencePieceText SampleEncodeAsImmutableProto(
      absl::string_view input, int nbest_size, float alpha) const {
    DEFINE_SPP_IMMUTABLE_PROTO_IMPL(SampleEncode, ImmutableSentencePieceText,
                                    input, nbest_size, alpha);
  }

  virtual ImmutableNBestSentencePieceText NBestEncodeAsImmutableProto(
      absl::string_view input, int nbest_size) const {
    DEFINE_SPP_IMMUTABLE_PROTO_IMPL(
        NBestEncode, ImmutableNBestSentencePieceText, input, nbest_size);
  }

  virtual ImmutableNBestSentencePieceText SampleEncodeAndScoreAsImmutableProto(
      absl::string_view input, int num_samples, float alpha, bool wor,
      bool include_best) const {
    DEFINE_SPP_IMMUTABLE_PROTO_IMPL(SampleEncodeAndScore,
                                    ImmutableNBestSentencePieceText, input,
                                    num_samples, alpha, wor, include_best);
  }

  // TODO(taku): Remove this API and use std::vector<std::string_view>
  virtual ImmutableSentencePieceText DecodePiecesAsImmutableProto(
      const std::vector<std::string> &pieces) const {
    DEFINE_SPP_IMMUTABLE_PROTO_IMPL(Decode, ImmutableSentencePieceText, pieces);
  }

  virtual ImmutableSentencePieceText DecodePiecesAsImmutableProto(
      const std::vector<absl::string_view> &pieces) const {
    DEFINE_SPP_IMMUTABLE_PROTO_IMPL(Decode, ImmutableSentencePieceText, pieces);
  }

  virtual ImmutableSentencePieceText DecodeIdsAsImmutableProto(
      const std::vector<int> &ids) const {
    DEFINE_SPP_IMMUTABLE_PROTO_IMPL(Decode, ImmutableSentencePieceText, ids);
  }

#undef DEFINE_SPP_DIRECT_FUNC_IMPL
#undef DEFINE_SPP_SERIALIZED_PROTO_IMPL
#undef DEFINE_SPP_IMMUTABLE_PROTO_IMPL

  //////////////////////////////////////////////////////////////
  // Vocabulary management methods.
  //
  // Returns the size of sentence pieces, which is the same as
  // the size of the vocabulary for NMT.
  virtual int GetPieceSize() const;

  // Returns the vocab id of `piece`.
  // Returns UNK(0) if `piece` is unknown.
  virtual int PieceToId(absl::string_view piece) const;

  // Returns the string representation of the vocab with `id`.
  virtual const std::string &IdToPiece(int id) const;

  // Returns the score of `id`.
  // Usually the score is an emission log probability of a unigram language
  // model.
  virtual float GetScore(int id) const;

  // Returns true if `id` is an unknown symbol.
  virtual bool IsUnknown(int id) const;

  // Returns true if `id` is a control symbol.
  virtual bool IsControl(int id) const;

  // Returns true if `id` is an unused symbol.
  virtual bool IsUnused(int id) const;

  // Returns true if `id` is a byte symbol.
  virtual bool IsByte(int id) const;

  // Returns the reserved id.
  // Returns -1 if not defined.

  // Returns unknown (<unk>) id.
  virtual int unk_id() const;

  // Returns BOS (<s>) id.
  virtual int bos_id() const;

  // Returns EOS (</s>) id.
  virtual int eos_id() const;

  // Returns PAD (<pad>) id.
  virtual int pad_id() const;

  //////////////////////////////////////////////////////////////
  // Model management.
  //
  // Allows injection of a mock model instance. `model` is moved.
  void SetModel(std::unique_ptr<ModelInterface> &&model);

  // Allows injection of a normalizer instance. `normalizer` is moved.
  void SetNormalizer(std::unique_ptr<normalizer::Normalizer> &&normalizer);

  // Returns the immutable model proto. Useful to obtain extended
  // or experimental parameters encoded in model_proto.
  const ModelProto &model_proto() const;

  // Returns the immutable model proto as std::string.
  // Useful to save the state of this instance via Python's pickle object.
  util::bytes serialized_model_proto() const;

 private:
  enum ExtraOption { REVERSE, BOS, EOS, UNK_PIECE };

  util::Status ParseExtraOptions(absl::string_view extra_option,
                                 std::vector<ExtraOption> *extra_options) const;

  util::Status ApplyExtraOptions(const std::vector<ExtraOption> &extra_options,
                                 SentencePieceText *spt) const;

  util::Status PopulateSentencePieceText(
      absl::string_view input, absl::string_view normalized,
      const std::vector<size_t> &norm_to_orig,
      const std::vector<std::pair<absl::string_view, int>> &result,
      SentencePieceText *spt) const;

  std::unique_ptr<ModelInterface> model_;
  std::unique_ptr<normalizer::Normalizer> normalizer_;
  std::unique_ptr<normalizer::Normalizer> denormalizer_;

  // Underlying model protocol buffer. The same lifetime as model_.
  std::unique_ptr<ModelProto> model_proto_;

  std::vector<ExtraOption> encode_extra_options_;
  std::vector<ExtraOption> decode_extra_options_;
};

// Sets the seed value of the random generator.
// Do not set static_cast<unsigned int>(-1),
// as this seed is reserved for initializing from
// std::random_device.
void SetRandomGeneratorSeed(unsigned int seed);

// IO related functions to absorb model formats.
namespace io {
// Loads `model_proto` from `filename`.
// We can instantiate SentencePieceProcessor as follows:
//
//   auto model_proto = absl::make_unique<ModelProto>();
//   io::LoadModelProto("//path/spm.model", model_proto.get());
//   SentencePieceProcessor sp;
//   CHECK_OK(sp.Load(std::move(model_proto)));
util::Status LoadModelProto(absl::string_view, ModelProto *model_proto);

// Saves `model_proto` as `filename`.
util::Status SaveModelProto(absl::string_view, const ModelProto &model_proto);
}  // namespace io
}  // namespace sentencepiece
#endif  // SENTENCEPIECE_PROCESSOR_H_
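Before the prebuilt libraries that consume this header, a short round-trip sketch with explicit status checks may help; it mirrors how chat.cpp below loads its tokenizer. This is a minimal sketch, assuming a SentencePiece model file named "tokenizer.model" in the working directory (the Baichuan2 demo ships one under Baichuan2/model/).

// Minimal Encode/Decode round trip; the model path is a placeholder.
#include <iostream>
#include <string>
#include <vector>
#include "sentencepiece/sentencepiece_processor.h"

int main() {
  sentencepiece::SentencePieceProcessor sp;
  auto status = sp.Load("tokenizer.model");
  if (!status.ok()) {
    std::cerr << status.ToString() << std::endl;  // e.g. model not found
    return 1;
  }
  std::vector<int> ids;
  sp.Encode("hello world.", &ids).IgnoreError();  // text -> vocab ids
  std::string detok;
  sp.Decode(ids, &detok).IgnoreError();           // ids -> text, losslessly
  std::cout << detok << std::endl;                // prints "hello world."
  return 0;
}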
Baichuan2/src/lib_pcie/libbmlib.so
ADDED
Binary file (195 kB).

Baichuan2/src/lib_pcie/libbmrt.so
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:621e33823dca470275e09570324a567ce4a30fa6100ac9e52742bb9e1ee02f45
size 2966400

Baichuan2/src/lib_pcie/libbmrt.so.1.0
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:621e33823dca470275e09570324a567ce4a30fa6100ac9e52742bb9e1ee02f45
size 2966400

Baichuan2/src/lib_pcie/libsentencepiece.a
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:68811cd99e6e1a58572372f14f3b7a02cf98bc98f5d46d24c406be65a94b53e8
size 2858304

Baichuan2/src/lib_soc/libbmlib.so
ADDED
Binary file (191 kB).

Baichuan2/src/lib_soc/libbmrt.so
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:cff807807fcc8c6a9d16353e389422d434ae2b79c8bc191266d0eb5a69b3d97d
size 2915352

Baichuan2/src/lib_soc/libbmrt.so.1.0
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:cff807807fcc8c6a9d16353e389422d434ae2b79c8bc191266d0eb5a69b3d97d
size 2915352

Baichuan2/src/lib_soc/libsentencepiece.a
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:5b1c1ece6c62265ee879cf5876d31e82580c3ee88c2cb627b8ac3eaf35695bde
size 3032062
Baichuan2/web_demo/CMakeLists.txt
ADDED
@@ -0,0 +1,36 @@
cmake_minimum_required(VERSION 2.8)
project(baichuan2)

if (NOT DEFINED TARGET_ARCH)
    set(TARGET_ARCH pcie)
endif()

set(CMAKE_INSTALL_PREFIX install)

if (${CMAKE_HOST_SYSTEM_PROCESSOR} STREQUAL "aarch64")
    add_definitions(-DSOC_TARGET)
    link_directories(${PROJECT_SOURCE_DIR}/../src/lib_soc)
    message("SoC mode, starting......")
elseif (${TARGET_ARCH} STREQUAL "pcie")
    add_definitions(-DPCIE_TARGET)
    link_directories(${PROJECT_SOURCE_DIR}/../src/lib_pcie)
    message("Pcie mode, starting......")
elseif (${TARGET_ARCH} STREQUAL "soc")
    add_definitions(-DSOC_TARGET)
    set(CMAKE_C_COMPILER aarch64-linux-gnu-gcc)
    set(CMAKE_ASM_COMPILER aarch64-linux-gnu-gcc)
    set(CMAKE_CXX_COMPILER aarch64-linux-gnu-g++)
    link_directories(${PROJECT_SOURCE_DIR}/lib_soc)
    message("SoC mode, starting......")
endif()

include_directories(${PROJECT_SOURCE_DIR}/../src/include)

add_definitions(-DDEBUG --std=c++17 -fPIC -Wall -Werror)
set(CMAKE_BUILD_TYPE "Debug")

add_library(tpuchat SHARED chat.cpp)
target_link_libraries(tpuchat bmrt bmlib sentencepiece)
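Note that this CMakeLists builds chat.cpp into a shared library (libtpuchat.so) rather than an executable, so a host process can drive it through the extern "C" entry points defined in chat.cpp below. As a hedged sketch of that consumption path (the library and model paths and the device id are illustrative assumptions; the repo's own web demo drives the library from Python instead):

// Sketch of loading libtpuchat.so at runtime via dlopen/dlsym.
#include <dlfcn.h>
#include <cstdio>

int main() {
  void *lib = dlopen("./build/libtpuchat.so", RTLD_LAZY);
  if (!lib) { std::fprintf(stderr, "%s\n", dlerror()); return 1; }

  // Entry points exported by chat.cpp; the chat object is treated as opaque.
  typedef void *(*create_fn)(int, const char *, const char *);
  typedef const char *(*predict_fn)(void *, const char *);
  auto create = reinterpret_cast<create_fn>(
      dlsym(lib, "Baichuan2_with_devid_and_model"));
  auto first = reinterpret_cast<predict_fn>(
      dlsym(lib, "Baichuan2_predict_first_token"));
  if (!create || !first) { std::fprintf(stderr, "%s\n", dlerror()); return 1; }

  void *chat = create(0, "baichuan2-7b.bmodel", "tokenizer.model");
  std::printf("%s\n", first(chat, "hello"));  // first generated piece
  dlclose(lib);
  return 0;
}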
Baichuan2/web_demo/chat.cpp
ADDED
@@ -0,0 +1,419 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
//===----------------------------------------------------------------------===//
|
2 |
+
//
|
3 |
+
// Copyright (C) 2023 Sophgo Technologies Inc. All rights reserved.
|
4 |
+
//
|
5 |
+
// TPU-MLIR is licensed under the 2-Clause BSD License except for the
|
6 |
+
// third-party components.
|
7 |
+
//
|
8 |
+
//===----------------------------------------------------------------------===//
|
9 |
+
|
10 |
+
#include <iostream>
|
11 |
+
#include <cstdlib>
|
12 |
+
#include <vector>
|
13 |
+
#include <assert.h>
|
14 |
+
#include <chrono>
|
15 |
+
#include <algorithm>
|
16 |
+
#include "memory.h"
|
17 |
+
#include "sentencepiece/sentencepiece_processor.h"
|
18 |
+
#include "bmruntime_interface.h"
|
19 |
+
#include <getopt.h>
|
20 |
+
|
21 |
+
static const int NUM_LAYERS = 32;
|
22 |
+
static const int MAX_LEN = 512;
|
23 |
+
static const float ATTENTION_MASK = -1000.;
|
24 |
+
|
25 |
+
static const std::string TOKENIZER_MODEL = "tokenizer.model";
|
26 |
+
|
27 |
+
// #define EXPORT_RESULTS
|
28 |
+
#ifdef EXPORT_RESULTS
|
29 |
+
#include "cnpy.h"
|
30 |
+
static cnpy::npz_t map;
|
31 |
+
|
32 |
+
template <typename T>
|
33 |
+
static void add_array(std::string name, bm_handle_t bm_handle,
|
34 |
+
const bm_device_mem_t &dst) {
|
35 |
+
std::vector<T> data(dst.size / sizeof(T));
|
36 |
+
bm_memcpy_d2s(bm_handle, data.data(), dst);
|
37 |
+
cnpy::npz_add_array(map, name, data);
|
38 |
+
}
|
39 |
+
|
40 |
+
static void save_array(std::string filename) {
|
41 |
+
cnpy::npz_save_all(filename, map);
|
42 |
+
}
|
43 |
+
#endif
|
44 |
+
|
45 |
+
class Baichuan2 {
|
46 |
+
public:
|
47 |
+
void init(int devid, const std::string model, const std::string tokenizer_path);
|
48 |
+
void chat();
|
49 |
+
void deinit();
|
50 |
+
std::string name;
|
51 |
+
std::string history = "";
|
52 |
+
int round = 0;
|
53 |
+
int token_length;
|
54 |
+
int EOS;
|
55 |
+
std::string predict_next_token();
|
56 |
+
std::string predict_first_token(const std::string &input_str);
|
57 |
+
|
58 |
+
private:
|
59 |
+
int forward_first(std::vector<int> &tokens);
|
60 |
+
int forward_next();
|
61 |
+
void load_sentencepiece(const std::string &tokenizer_path);
|
62 |
+
|
63 |
+
private:
|
64 |
+
std::vector<bm_handle_t> handles;
|
65 |
+
bm_handle_t bm_handle;
|
66 |
+
void *p_bmrt;
|
67 |
+
sentencepiece::SentencePieceProcessor sentencepiece;
|
68 |
+
const bm_net_info_t *net_blocks[NUM_LAYERS];
|
69 |
+
const bm_net_info_t *net_blocks_cache[NUM_LAYERS];
|
70 |
+
const bm_net_info_t *net_embed;
|
71 |
+
const bm_net_info_t *net_lm;
|
72 |
+
bm_tensor_t inputs_embed_512, outputs_embed_512;
|
73 |
+
bm_tensor_t inputs_lm, outputs_lm;
|
74 |
+
bm_tensor_t inputs_pid, next_pid, inputs_attention, next_attention;
|
75 |
+
bm_tensor_t past_key[NUM_LAYERS], past_value[NUM_LAYERS];
|
76 |
+
bm_tensor_t present_key[NUM_LAYERS], present_value[NUM_LAYERS];
|
77 |
+
bm_tensor_t present_key_cache, present_value_cache;
|
78 |
+
std::string name_embed;
|
79 |
+
std::string name_lm;
|
80 |
+
std::string name_blocks[NUM_LAYERS];
|
81 |
+
std::string name_blocks_cache[NUM_LAYERS];
|
82 |
+
};
|
83 |
+
|
84 |
+
void Baichuan2::load_sentencepiece(const std::string &model) {
|
85 |
+
printf("Load %s ... ", model.c_str());
|
86 |
+
auto status = sentencepiece.Load(model);
|
87 |
+
if (!status.ok()) {
|
88 |
+
std::cout << status.ToString() << std::endl;
|
89 |
+
exit(-1);
|
90 |
+
}
|
91 |
+
EOS = sentencepiece.eos_id();
|
92 |
+
printf("Done!\n");
|
93 |
+
}
|
94 |
+
|
95 |
+
void Baichuan2::init(int devid, const std::string model, const std::string tokenizer_path) {
|
96 |
+
load_sentencepiece(tokenizer_path);
|
97 |
+
// request bm_handle
|
98 |
+
bm_status_t status = bm_dev_request(&bm_handle, devid);
|
99 |
+
assert(BM_SUCCESS == status);
|
100 |
+
|
101 |
+
// create bmruntime
|
102 |
+
p_bmrt = bmrt_create(bm_handle);
|
103 |
+
assert(NULL != p_bmrt);
|
104 |
+
|
105 |
+
// load bmodel by file
|
106 |
+
printf("Model[%s] loading ....\n", model.c_str());
|
107 |
+
bool ret = bmrt_load_bmodel(p_bmrt, model.c_str());
|
108 |
+
assert(true == ret);
|
109 |
+
printf("Done!\n");
|
110 |
+
// net names
|
111 |
+
name_embed = "embedding";
|
112 |
+
name_lm = "lm_head";
|
113 |
+
for (int i = 0; i < NUM_LAYERS; i++) {
|
114 |
+
name_blocks[i] = "block_" + std::to_string(i);
|
115 |
+
name_blocks_cache[i] = "block_cache_" + std::to_string(i);
|
116 |
+
}
|
117 |
+
|
118 |
+
// net infos
|
119 |
+
net_embed = bmrt_get_network_info(p_bmrt, name_embed.c_str());
|
120 |
+
net_lm = bmrt_get_network_info(p_bmrt, name_lm.c_str());
|
121 |
+
for (int i = 0; i < NUM_LAYERS; i++) {
|
122 |
+
net_blocks[i] = bmrt_get_network_info(p_bmrt, name_blocks[i].c_str());
|
123 |
+
net_blocks_cache[i] =
|
124 |
+
bmrt_get_network_info(p_bmrt, name_blocks_cache[i].c_str());
|
125 |
+
}
|
126 |
+
|
127 |
+
// net device mem
|
128 |
+
ret = bmrt_tensor(&inputs_embed_512, p_bmrt, net_embed->input_dtypes[0],
|
129 |
+
net_embed->stages[1].input_shapes[0]);
|
130 |
+
assert(true == ret);
|
131 |
+
|
132 |
+
ret = bmrt_tensor(&outputs_embed_512, p_bmrt, net_embed->output_dtypes[0],
|
133 |
+
net_embed->stages[1].output_shapes[0]);
|
134 |
+
assert(true == ret);
|
135 |
+
|
136 |
+
ret = bmrt_tensor(&inputs_pid, p_bmrt, net_blocks[0]->input_dtypes[1],
|
137 |
+
net_blocks[0]->stages[0].input_shapes[1]);
|
138 |
+
assert(true == ret);
|
139 |
+
|
140 |
+
ret = bmrt_tensor(&inputs_attention, p_bmrt, net_blocks[0]->input_dtypes[2],
|
141 |
+
net_blocks[0]->stages[0].input_shapes[2]);
|
142 |
+
assert(true == ret);
|
143 |
+
|
144 |
+
ret = bmrt_tensor(&next_pid, p_bmrt, net_blocks_cache[0]->input_dtypes[1],
|
145 |
+
net_blocks_cache[0]->stages[0].input_shapes[1]);
|
146 |
+
assert(true == ret);
|
147 |
+
|
148 |
+
ret =
|
149 |
+
bmrt_tensor(&next_attention, p_bmrt, net_blocks_cache[0]->input_dtypes[2],
|
150 |
+
net_blocks_cache[0]->stages[0].input_shapes[2]);
|
151 |
+
assert(true == ret);
|
152 |
+
|
153 |
+
for (int i = 0; i < NUM_LAYERS; i++) {
|
154 |
+
ret = bmrt_tensor(&past_key[i], p_bmrt, net_blocks[0]->output_dtypes[1],
|
155 |
+
net_blocks[0]->stages[0].output_shapes[1]);
|
156 |
+
assert(true == ret);
|
157 |
+
ret = bmrt_tensor(&past_value[i], p_bmrt, net_blocks[0]->output_dtypes[2],
|
158 |
+
net_blocks[0]->stages[0].output_shapes[2]);
|
159 |
+
assert(true == ret);
|
160 |
+
ret = bmrt_tensor(&present_key[i], p_bmrt, net_blocks[0]->output_dtypes[1],
|
161 |
+
net_blocks[0]->stages[0].output_shapes[1]);
|
162 |
+
assert(true == ret);
|
163 |
+
ret = bmrt_tensor(&present_value[i], p_bmrt, net_blocks[0]->output_dtypes[2],
|
164 |
+
net_blocks[0]->stages[0].output_shapes[2]);
|
165 |
+
assert(true == ret);
|
166 |
+
}
|
167 |
+
ret = bmrt_tensor(&present_key_cache, p_bmrt, net_blocks_cache[0]->output_dtypes[1],
|
168 |
+
net_blocks_cache[0]->stages[0].output_shapes[1]);
|
169 |
+
assert(true == ret);
|
170 |
+
ret = bmrt_tensor(&present_value_cache, p_bmrt, net_blocks_cache[0]->output_dtypes[2],
|
171 |
+
net_blocks_cache[0]->stages[0].output_shapes[2]);
|
172 |
+
assert(true == ret);
|
173 |
+
|
174 |
+
ret = bmrt_tensor(&inputs_lm, p_bmrt, net_lm->input_dtypes[0],
|
175 |
+
net_lm->stages[0].input_shapes[0]);
|
176 |
+
assert(true == ret);
|
177 |
+
ret = bmrt_tensor(&outputs_lm, p_bmrt, net_lm->output_dtypes[0],
|
178 |
+
net_lm->stages[0].output_shapes[0]);
|
179 |
+
assert(true == ret);
|
180 |
+
}
|
181 |
+
|
182 |
+
void Baichuan2::deinit() {
|
183 |
+
bm_free_device(bm_handle, inputs_embed_512.device_mem);
|
184 |
+
bm_free_device(bm_handle, outputs_embed_512.device_mem);
|
185 |
+
bm_free_device(bm_handle, inputs_lm.device_mem);
|
186 |
+
bm_free_device(bm_handle, outputs_lm.device_mem);
|
187 |
+
bm_free_device(bm_handle, inputs_pid.device_mem);
|
188 |
+
bm_free_device(bm_handle, next_pid.device_mem);
|
189 |
+
bm_free_device(bm_handle, inputs_attention.device_mem);
|
190 |
+
bm_free_device(bm_handle, next_attention.device_mem);
|
191 |
+
bm_free_device(bm_handle, present_key_cache.device_mem);
|
192 |
+
bm_free_device(bm_handle, present_value_cache.device_mem);
|
193 |
+
for (int i = 0; i < NUM_LAYERS; i++) {
|
194 |
+
bm_free_device(bm_handle, past_key[i].device_mem);
|
195 |
+
bm_free_device(bm_handle, past_value[i].device_mem);
|
196 |
+
bm_free_device(bm_handle, present_key[i].device_mem);
|
197 |
+
bm_free_device(bm_handle, present_value[i].device_mem);
|
198 |
+
}
|
199 |
+
bmrt_destroy(p_bmrt);
|
200 |
+
for (auto h : handles) {
|
201 |
+
bm_dev_free(h);
|
202 |
+
}
|
203 |
+
}
|
204 |
+
|
205 |
+
|
206 |
+
|
207 |
+
int Baichuan2::forward_first(std::vector<int> &tokens) {
|
208 |
+
int input_ids[MAX_LEN] = {0}; // start token
|
209 |
+
int position_id[MAX_LEN] = {0};
|
210 |
+
float attention_mask[MAX_LEN * MAX_LEN] = {0};
|
211 |
+
token_length = tokens.size();
|
212 |
+
|
213 |
+
std::copy(tokens.begin(), tokens.end(), input_ids);
|
214 |
+
for (int i = 0; i < token_length; i++) {
|
215 |
+
position_id[i] = i;
|
216 |
+
}
|
217 |
+
|
218 |
+
for (int i = 0; i < MAX_LEN; i++) {
|
219 |
+
for (int j = 0; j < MAX_LEN; j++) {
|
220 |
+
if (j <= i && i < token_length) {
|
221 |
+
} else {
|
222 |
+
attention_mask[i * MAX_LEN + j] = ATTENTION_MASK;
|
223 |
+
}
|
224 |
+
}
|
225 |
+
}
|
226 |
+
|
227 |
+
// forward embeding
|
228 |
+
bm_memcpy_s2d(bm_handle, inputs_embed_512.device_mem, (void *)input_ids);
|
229 |
+
auto ret =
|
230 |
+
bmrt_launch_tensor_ex(p_bmrt, name_embed.c_str(), &inputs_embed_512, 1,
|
231 |
+
&outputs_embed_512, 1, true, false);
|
232 |
+
assert(ret);
|
233 |
+
bm_thread_sync(bm_handle);
|
234 |
+
|
235 |
+
// forward blocks
|
236 |
+
bm_memcpy_s2d(bm_handle, inputs_pid.device_mem, (void *)position_id);
|
237 |
+
bm_memcpy_s2d(bm_handle, inputs_attention.device_mem, (void *)attention_mask);
|
238 |
+
auto inputs_embed = outputs_embed_512;
|
239 |
+
inputs_embed.shape = net_blocks[0]->stages[0].input_shapes[0];
|
240 |
+
bm_tensor_t inputs_block[3] = {inputs_embed, inputs_pid, inputs_attention};
|
241 |
+
for (int i = 0; i < NUM_LAYERS; i++) {
|
242 |
+
bm_tensor_t outputs_block[3] = {inputs_embed, past_key[i], past_value[i]};
|
243 |
+
ret = bmrt_launch_tensor_ex(p_bmrt, name_blocks[i].c_str(), inputs_block, 3,
|
244 |
+
outputs_block, 3, true, false);
|
245 |
+
assert(ret);
|
246 |
+
bm_thread_sync(bm_handle);
|
247 |
+
}
|
248 |
+
int bytes = inputs_embed.device_mem.size / MAX_LEN;
|
249 |
+
bm_memcpy_d2d_byte(bm_handle, inputs_lm.device_mem, 0,
|
250 |
+
inputs_embed.device_mem, (token_length - 1) * bytes,
|
251 |
+
bytes);
|
252 |
+
ret = bmrt_launch_tensor_ex(p_bmrt, name_lm.c_str(), &inputs_lm, 1,
|
253 |
+
&outputs_lm, 1, true, false);
|
254 |
+
bm_thread_sync(bm_handle);
|
255 |
+
|
256 |
+
int token = 0;
|
257 |
+
bm_memcpy_d2s(bm_handle, (void *)&token, outputs_lm.device_mem);
|
258 |
+
return token;
|
259 |
+
}
|
260 |
+
|
261 |
+
int Baichuan2::forward_next() {
|
262 |
+
float attention_mask[MAX_LEN + 1] = {0};
|
263 |
+
for (int i = token_length - 1; i < MAX_LEN; i++) {
|
264 |
+
attention_mask[i] = ATTENTION_MASK;
|
265 |
+
}
|
266 |
+
int32_t position_id = token_length - 1;
|
267 |
+
// embedding
|
268 |
+
outputs_lm.shape = net_embed->stages[0].input_shapes[0];
|
269 |
+
auto ret = bmrt_launch_tensor_ex(p_bmrt, name_embed.c_str(), &outputs_lm, 1,
|
270 |
+
&inputs_lm, 1, true, false);
|
271 |
+
assert(ret);
|
272 |
+
bm_thread_sync(bm_handle);
|
273 |
+
|
274 |
+
// blocks
|
275 |
+
bm_memcpy_s2d(bm_handle, next_attention.device_mem, (void *)attention_mask);
|
276 |
+
bm_memcpy_s2d(bm_handle, next_pid.device_mem, (void *)&position_id);
|
277 |
+
auto inputs_embed = inputs_lm;
|
278 |
+
inputs_embed.shape = net_blocks_cache[0]->stages[0].input_shapes[0];
|
279 |
+
int bytes = bm_mem_get_device_size(present_key_cache.device_mem);
|
280 |
+
int token_offset = (token_length - 1) * bytes;
|
281 |
+
for (int i = 0; i < NUM_LAYERS; i++) {
|
282 |
+
bm_tensor_t inputs_block[5] = {inputs_embed, next_pid, next_attention,
|
283 |
+
past_key[i], past_value[i]};
|
284 |
+
bm_tensor_t outputs_block[3] = {inputs_embed, present_key_cache, present_value_cache};
|
285 |
+
ret = bmrt_launch_tensor_ex(p_bmrt, name_blocks_cache[i].c_str(),
|
286 |
+
inputs_block, 5, outputs_block, 3, true, false);
|
287 |
+
assert(ret);
|
288 |
+
bm_thread_sync(bm_handle);
|
289 |
+
bm_memcpy_d2d_byte(bm_handle, past_key[i].device_mem, token_offset,
|
290 |
+
present_key_cache.device_mem, 0,
|
291 |
+
bytes);
|
292 |
+
bm_memcpy_d2d_byte(bm_handle, past_value[i].device_mem, token_offset,
|
293 |
+
present_value_cache.device_mem, 0,
|
294 |
+
bytes);
|
295 |
+
}
|
296 |
+
outputs_lm.shape = net_lm->stages[0].output_shapes[0];
|
297 |
+
ret = bmrt_launch_tensor_ex(p_bmrt, name_lm.c_str(), &inputs_lm, 1,
|
298 |
+
&outputs_lm, 1, true, false);
assert(ret);
|
299 |
+
bm_thread_sync(bm_handle);
|
300 |
+
|
301 |
+
int token = 0;
|
302 |
+
bm_memcpy_d2s(bm_handle, (void *)&token, outputs_lm.device_mem);
|
303 |
+
return token;
|
304 |
+
}
|
305 |
+
|
306 |
+
|
307 |
+
std::string Baichuan2::predict_first_token(const std::string &input_str) {
|
308 |
+
history = input_str;
|
309 |
+
//int tok_num = 1;
|
310 |
+
std::vector<int> tokens;
|
311 |
+
sentencepiece.Encode(history, &tokens);
|
312 |
+
tokens.insert(tokens.begin(), 1);
|
313 |
+
if (tokens.empty()) {
|
314 |
+
round = 0;
|
315 |
+
history = "Sorry: your question is too wierd!!\n";
|
316 |
+
return history;
|
317 |
+
}
|
318 |
+
// make sure token not too large
|
319 |
+
if (tokens.size() > MAX_LEN - 10) {
|
320 |
+
// reset
|
321 |
+
if (round == 0) {
|
322 |
+
history = "Error: your question is too large!\n";
|
323 |
+
return history;
|
324 |
+
}
|
325 |
+
round = 0;
|
326 |
+
history = "";
|
327 |
+
return predict_first_token(input_str);
|
328 |
+
}
|
329 |
+
int token = forward_first(tokens);
|
330 |
+
int pre_token = 0;
|
331 |
+
std::string pre_word;
|
332 |
+
std::string word;
|
333 |
+
std::vector<int> pre_ids = {pre_token};
|
334 |
+
std::vector<int> ids = {pre_token,token};
|
335 |
+
sentencepiece.Decode(pre_ids, &pre_word);
|
336 |
+
sentencepiece.Decode(ids, &word);
|
337 |
+
std::string diff = word.substr(pre_word.size());
|
338 |
+
#ifdef PRINT
|
339 |
+
printf("token %d",token);
|
340 |
+
printf("diff %s",diff.c_str());
|
341 |
+
#endif
|
342 |
+
history += diff;
|
343 |
+
if (token_length < MAX_LEN) {
|
344 |
+
token_length++;
|
345 |
+
}
|
346 |
+
return diff;
|
347 |
+
}
|
348 |
+
|
349 |
+
std::string Baichuan2::predict_next_token() {
|
350 |
+
int pre_token;
|
351 |
+
pre_token = 0;
|
352 |
+
int token = forward_next();
|
353 |
+
if(token == EOS){
|
354 |
+
round = 0;
|
355 |
+
history = history.substr(history.size()/2);
|
356 |
+
return "_GETEOS_";
|
357 |
+
}
|
358 |
+
std::string pre_word;
|
359 |
+
std::string word;
|
360 |
+
std::vector<int> pre_ids = {pre_token};
|
361 |
+
std::vector<int> ids = {pre_token, token};
|
362 |
+
sentencepiece.Decode(pre_ids, &pre_word);
|
363 |
+
sentencepiece.Decode(ids, &word);
|
364 |
+
std::string diff = word.substr(pre_word.size());
|
365 |
+
#ifdef PRINT
|
366 |
+
printf("token %d",token);
|
367 |
+
printf("diff %s",diff.c_str());
|
368 |
+
#endif
|
369 |
+
history += diff;
|
370 |
+
if (token_length < MAX_LEN) {
|
371 |
+
token_length++;
|
372 |
+
}else{
|
373 |
+
round = 0;
|
374 |
+
return "_GETMAX_";
|
375 |
+
}
|
376 |
+
return diff;
|
377 |
+
}
|
378 |
+
|
379 |
+
|
380 |
+
extern "C" {
|
381 |
+
|
382 |
+
|
383 |
+
Baichuan2 *Baichuan2_with_devid_and_model(int devid, const char *bmodel_path, const char *tokenizer_path) {
|
384 |
+
Baichuan2 *chat = new Baichuan2();
|
385 |
+
chat->init(devid, bmodel_path, tokenizer_path);
|
386 |
+
return chat;
|
387 |
+
}
|
388 |
+
|
389 |
+
void Baichuan2_delete(Baichuan2 *chat) { delete chat; }
|
390 |
+
|
391 |
+
void Baichuan2_deinit(Baichuan2 *chat) {
|
392 |
+
chat->deinit();
|
393 |
+
}
|
394 |
+
|
395 |
+
const char *get_history(Baichuan2 *chat) {
|
396 |
+
std::string str = chat->history;
|
397 |
+
return strdup(str.c_str());
|
398 |
+
}
|
399 |
+
|
400 |
+
const char *set_history(Baichuan2 *chat, const char *history) {
|
401 |
+
chat->history = history;
|
402 |
+
return strdup(history);
|
403 |
+
}
|
404 |
+
|
405 |
+
const char *Baichuan2_predict_first_token(Baichuan2 *chat, const char *input_str) {
|
406 |
+
std::string str = chat->predict_first_token(input_str);
|
407 |
+
return strdup(str.c_str());
|
408 |
+
}
|
409 |
+
|
410 |
+
const char *Baichuan2_predict_next_token(Baichuan2 *chat) {
|
411 |
+
std::string str = chat->predict_next_token();
|
412 |
+
return strdup(str.c_str());
|
413 |
+
}
|
414 |
+
|
415 |
+
const int get_eos(Baichuan2 *chat){
|
416 |
+
const int res = chat->EOS;
|
417 |
+
return res;
|
418 |
+
}
|
419 |
+
}
|
Baichuan2/web_demo/chat.py
ADDED
@@ -0,0 +1,97 @@
1 |
+
# coding=utf-8
|
2 |
+
|
3 |
+
import ctypes
|
4 |
+
|
5 |
+
|
6 |
+
class TokenWord(ctypes.Structure):
|
7 |
+
_fields_ = [
|
8 |
+
("token", ctypes.c_int),
|
9 |
+
("word", ctypes.c_char * 2048) # 假设最大长度为 100,你可以根据实际情况调整
|
10 |
+
]
|
11 |
+
|
12 |
+
|
13 |
+
class TPUChatglm:
|
14 |
+
def __init__(self):
|
15 |
+
self.lib = ctypes.cdll.LoadLibrary('./build/libtpuchat.so')
|
16 |
+
device_id = 3
|
17 |
+
bmodel_path = "../model/baichuan2-7b-test_int8.bmodel"
|
18 |
+
token_path = "../model/tokenizer.model"
|
19 |
+
self.device_id = device_id
|
20 |
+
self.bmodel_path = bmodel_path
|
21 |
+
self.token_path = token_path
|
22 |
+
self.libset()
|
23 |
+
self.init()
|
24 |
+
|
25 |
+
def libset(self):
|
26 |
+
self.lib.Baichuan2_with_devid_and_model.argtypes = [ctypes.c_int, ctypes.c_char_p, ctypes.c_char_p]
|
27 |
+
self.lib.Baichuan2_with_devid_and_model.restype = ctypes.c_void_p
|
28 |
+
|
29 |
+
self.lib.Baichuan2_delete.argtypes = [ctypes.c_void_p]
|
30 |
+
|
31 |
+
# deinit
|
32 |
+
self.lib.Baichuan2_deinit.argtypes = [ctypes.c_void_p]
|
33 |
+
|
34 |
+
# Baichuan2_predict_first_token
|
35 |
+
self.lib.Baichuan2_predict_first_token.argtypes = [ctypes.c_void_p, ctypes.c_char_p]
|
36 |
+
self.lib.Baichuan2_predict_first_token.restype = ctypes.c_char_p
|
37 |
+
|
38 |
+
# Baichuan2_predict_next_token
|
39 |
+
self.lib.Baichuan2_predict_next_token.argtypes = [ctypes.c_void_p]
|
40 |
+
self.lib.Baichuan2_predict_next_token.restype = ctypes.c_char_p
|
41 |
+
|
42 |
+
# get_eos
|
43 |
+
self.lib.get_eos.argtypes = [ctypes.c_void_p]
|
44 |
+
self.lib.get_eos.restype = ctypes.c_int
|
45 |
+
# get_history
|
46 |
+
self.lib.get_history.argtypes = [ctypes.c_void_p]
|
47 |
+
self.lib.get_history.restype = ctypes.c_char_p
|
48 |
+
# set history
|
49 |
+
self.lib.set_history.argtypes = [ctypes.c_void_p, ctypes.c_char_p]
|
50 |
+
|
51 |
+
def init(self):
|
52 |
+
self.obj = self.lib.Baichuan2_with_devid_and_model(self.device_id, self.bmodel_path.encode('utf-8'),
|
53 |
+
self.token_path.encode('utf-8'))
|
54 |
+
|
55 |
+
def predict_first_token(self, context):
|
56 |
+
return self.lib.Baichuan2_predict_first_token(self.obj, context.encode('utf-8')).decode('utf-8')
|
57 |
+
|
58 |
+
def predict_next_token(self):
|
59 |
+
return self.lib.Baichuan2_predict_next_token(self.obj).decode('utf-8')
|
60 |
+
|
61 |
+
def predict(self, context):
|
62 |
+
|
63 |
+
first_token = self.predict_first_token(context)
|
64 |
+
# print(first_token, end='')
|
65 |
+
res = ''
|
66 |
+
while True:
|
67 |
+
next_token = self.predict_next_token()
|
68 |
+
if next_token == '_GETMAX_' or next_token == '_GETEOS_':
|
69 |
+
# print(next_token)
|
70 |
+
break
|
71 |
+
# print(next_token, end='')
|
72 |
+
res += next_token
|
73 |
+
return res
|
74 |
+
|
75 |
+
def stream_predict(self, query, history):
|
76 |
+
history.append((query, ''))
|
77 |
+
|
78 |
+
prompt = ''
|
79 |
+
# for i, (old_query, response) in enumerate(history):
|
80 |
+
# prompt += "[Round {}]\n\n问:{}\n\n答:{}\n\n".format(i + 1, old_query, response)
|
81 |
+
# prompt += "[Round {}]\n\n问:{}\n\n答:".format(len(history) + 1, query)
|
82 |
+
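# Baichuan2 chat template (assumed from Baichuan2's special tokens): <reserved_106> opens the user turn, <reserved_107> cues the assistant reply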
prompt = "<reserved_106>" + query + "<reserved_107>"
|
83 |
+
|
84 |
+
res = ''
|
85 |
+
first_token = self.predict_first_token(prompt)
|
86 |
+
res += first_token
|
87 |
+
|
88 |
+
while True:
|
89 |
+
next_token = self.predict_next_token()
|
90 |
+
if next_token == '_GETMAX_' or next_token == '_GETEOS_':
|
91 |
+
break
|
92 |
+
res += next_token
|
93 |
+
history[-1] = (query, res)
|
94 |
+
yield res, history
|
95 |
+
|
96 |
+
def get_config(self):
|
97 |
+
pass
|
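For orientation, a minimal usage sketch of the wrapper above (hypothetical; it assumes the `./build/libtpuchat.so` and model paths hard-coded in `__init__` exist on your machine):

```python
# hypothetical usage sketch for the TPUChatglm ctypes wrapper above
from chat import TPUChatglm

chat = TPUChatglm()                # loads ./build/libtpuchat.so and the bmodel
print(chat.predict("Hello"))       # blocking call: returns the full answer

history = []
for partial, history in chat.stream_predict("What is a TPU?", history):
    print(partial)                 # each yield is the answer accumulated so far
```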
Baichuan2/web_demo/web_demo.py
ADDED
@@ -0,0 +1,108 @@
1 |
+
import time
|
2 |
+
import gradio as gr
|
3 |
+
import mdtex2html
|
4 |
+
from chat import TPUChatglm
|
5 |
+
|
6 |
+
|
7 |
+
def postprocess(self, y):
|
8 |
+
if y is None:
|
9 |
+
return []
|
10 |
+
for i, (message, response) in enumerate(y):
|
11 |
+
y[i] = (
|
12 |
+
None if message is None else mdtex2html.convert(message),
|
13 |
+
None if response is None else mdtex2html.convert(response),
|
14 |
+
)
|
15 |
+
return y
|
16 |
+
|
17 |
+
|
18 |
+
gr.Chatbot.postprocess = postprocess
|
19 |
+
|
20 |
+
glm = TPUChatglm()
|
21 |
+
|
22 |
+
def parse_text(text):
|
23 |
+
"""copy from https://github.com/GaiZhenbiao/ChuanhuChatGPT/"""
|
24 |
+
lines = text.split("\n")
|
25 |
+
lines = [line for line in lines if line != ""]
|
26 |
+
count = 0
|
27 |
+
for i, line in enumerate(lines):
|
28 |
+
if "```" in line:
|
29 |
+
count += 1
|
30 |
+
items = line.split('`')
|
31 |
+
if count % 2 == 1:
|
32 |
+
lines[i] = f'<pre><code class="language-{items[-1]}">'
|
33 |
+
else:
|
34 |
+
lines[i] = f'<br></code></pre>'
|
35 |
+
else:
|
36 |
+
if i > 0:
|
37 |
+
if count % 2 == 1:
|
38 |
+
line = line.replace("`", "\`")
|
39 |
+
line = line.replace("<", "<")
|
40 |
+
line = line.replace(">", ">")
|
41 |
+
line = line.replace(" ", " ")
|
42 |
+
line = line.replace("*", "*")
|
43 |
+
line = line.replace("_", "_")
|
44 |
+
line = line.replace("-", "-")
|
45 |
+
line = line.replace(".", ".")
|
46 |
+
line = line.replace("!", "!")
|
47 |
+
line = line.replace("(", "(")
|
48 |
+
line = line.replace(")", ")")
|
49 |
+
line = line.replace("$", "$")
|
50 |
+
lines[i] = "<br>" + line
|
51 |
+
text = "".join(lines)
|
52 |
+
return text
|
53 |
+
|
54 |
+
|
55 |
+
def gen(input, history):
|
56 |
+
i = 0
|
57 |
+
history.append((input, ''))
|
58 |
+
res = ''
|
59 |
+
while i < 10:
|
60 |
+
i += 1
|
61 |
+
res += str(i)
|
62 |
+
time.sleep(0.05)
|
63 |
+
history[-1] = (input, res)
|
64 |
+
yield res, history
|
65 |
+
|
66 |
+
|
67 |
+
def predict(input, chatbot, max_length, top_p, temperature, history):
|
68 |
+
|
69 |
+
chatbot.append((parse_text(input), ""))
|
70 |
+
for response, history in glm.stream_predict(input, history):
|
71 |
+
chatbot[-1] = (parse_text(input), parse_text(response))
|
72 |
+
yield chatbot, history
|
73 |
+
|
74 |
+
|
75 |
+
def reset_user_input():
|
76 |
+
return gr.update(value='')
|
77 |
+
|
78 |
+
|
79 |
+
def reset_state():
|
80 |
+
return [], [], None
|
81 |
+
|
82 |
+
|
83 |
+
with gr.Blocks() as demo:
|
84 |
+
gr.HTML("""<h1 align="center">Baichuan2-7B TPU</h1>""")
|
85 |
+
|
86 |
+
chatbot = gr.Chatbot()
|
87 |
+
with gr.Row():
|
88 |
+
with gr.Column(scale=4):
|
89 |
+
with gr.Column(scale=12):
|
90 |
+
user_input = gr.Textbox(show_label=False, placeholder="Input...", lines=10).style(
|
91 |
+
container=False)
|
92 |
+
with gr.Column(min_width=32, scale=1):
|
93 |
+
submitBtn = gr.Button("Submit", variant="primary")
|
94 |
+
with gr.Column(scale=1):
|
95 |
+
emptyBtn = gr.Button("Clear History")
|
96 |
+
max_length = gr.Slider(0, 32768, value=8192, step=1.0, label="Maximum length", interactive=True)
|
97 |
+
top_p = gr.Slider(0, 1, value=0.8, step=0.01, label="Top P", interactive=True)
|
98 |
+
temperature = gr.Slider(0, 1, value=0.95, step=0.01, label="Temperature", interactive=True)
|
99 |
+
|
100 |
+
history = gr.State([])
|
101 |
+
|
102 |
+
submitBtn.click(predict, [user_input, chatbot, max_length, top_p, temperature, history],
|
103 |
+
[chatbot, history], show_progress=True)
|
104 |
+
submitBtn.click(reset_user_input, [], [user_input])
|
105 |
+
|
106 |
+
emptyBtn.click(reset_state, outputs=[chatbot, history], show_progress=True)
|
107 |
+
|
108 |
+
demo.queue().launch(share=True, server_name="0.0.0.0", inbrowser=True)
|
BaseModel/base_model.py
ADDED
@@ -0,0 +1,184 @@
1 |
+
import time
|
2 |
+
from transformers import AutoTokenizer
|
3 |
+
|
4 |
+
|
5 |
+
class BaseModel:
|
6 |
+
def __init__(self, args):
|
7 |
+
# parameters
|
8 |
+
self.EOS = None
|
9 |
+
self.SEQLEN = None
|
10 |
+
self.input_str = ""
|
11 |
+
self.system_prompt = ""
|
12 |
+
self.history = []
|
13 |
+
|
14 |
+
# devid
|
15 |
+
self.devices = [int(d) for d in args.devid.split(",")]
|
16 |
+
|
17 |
+
# load tokenizer
|
18 |
+
print("Load " + args.tokenizer_path + " ...")
|
19 |
+
self.tokenizer = AutoTokenizer.from_pretrained(
|
20 |
+
args.tokenizer_path, trust_remote_code=True
|
21 |
+
)
|
22 |
+
|
23 |
+
# warm up
|
24 |
+
self.tokenizer.decode([0])
|
25 |
+
print("Done!")
|
26 |
+
|
27 |
+
def chat(self):
|
28 |
+
"""
|
29 |
+
Start a chat session.
|
30 |
+
"""
|
31 |
+
# check
|
32 |
+
if not self.EOS:
|
33 |
+
raise NotImplementedError("Forget to set End of Sentence Token Id(EOS)")
|
34 |
+
if not self.SEQLEN:
|
35 |
+
raise NotImplementedError("Forget to set End of Sentence Token Id")
|
36 |
+
|
37 |
+
# Instruct
|
38 |
+
print(
|
39 |
+
"""\n===========================================================
|
40 |
+
1. If you want to quit, please enter one of [q, quit, exit]
|
41 |
+
2. To create a new chat session, please enter one of [clear, new]
|
42 |
+
==========================================================="""
|
43 |
+
)
|
44 |
+
# Stop Chatting with "exit" input
|
45 |
+
while True:
|
46 |
+
self.input_str = input("\nQuestion: ")
|
47 |
+
# Quit
|
48 |
+
if self.input_str in ["exit", "q", "quit"]:
|
49 |
+
break
|
50 |
+
# New Chat
|
51 |
+
elif self.input_str in ["clear", "new"]:
|
52 |
+
self.clear()
|
53 |
+
# Chat
|
54 |
+
else:
|
55 |
+
tokens = self.encode_tokens()
|
56 |
+
|
57 |
+
# check tokens
|
58 |
+
if not tokens:
|
59 |
+
print("Sorry: your question is empty!!")
|
60 |
+
return
|
61 |
+
if len(tokens) > self.SEQLEN:
|
62 |
+
print(
|
63 |
+
"The maximum question length should be shorter than {} but we get {} instead.".format(
|
64 |
+
self.SEQLEN, len(tokens)
|
65 |
+
)
|
66 |
+
)
|
67 |
+
return
|
68 |
+
|
69 |
+
print("\nAnswer: ", end="")
|
70 |
+
self.stream_answer(tokens)
|
71 |
+
|
72 |
+
def stream_answer(self, tokens):
|
73 |
+
"""
|
74 |
+
Stream the answer for the given tokens.
|
75 |
+
"""
|
76 |
+
tok_num = 0
|
77 |
+
self.answer_cur = ""
|
78 |
+
self.answer_token = []
|
79 |
+
|
80 |
+
# First token
|
81 |
+
first_start = time.time()
|
82 |
+
token = self.forward_first(tokens)
|
83 |
+
first_end = time.time()
|
84 |
+
# Following tokens
|
85 |
+
while token != self.EOS and self.model.token_length < self.SEQLEN:
|
86 |
+
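# decode [token] and [token, token], then strip the first copy so a leading space survives detokenization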
pre_word = self.decode_tokens([token])
|
87 |
+
word = self.decode_tokens([token, token])[len(pre_word):]
|
88 |
+
self.answer_token += [token]
|
89 |
+
print(word, flush=True, end="")
|
90 |
+
tok_num += 1
|
91 |
+
token = self.forward_next()
|
92 |
+
self.answer_cur = self.tokenizer.decode(self.answer_token)
|
93 |
+
|
94 |
+
# counting time
|
95 |
+
next_end = time.time()
|
96 |
+
first_duration = first_end - first_start
|
97 |
+
next_duration = next_end - first_end
|
98 |
+
tps = tok_num / next_duration
|
99 |
+
|
100 |
+
self.update_history()
|
101 |
+
|
102 |
+
print()
|
103 |
+
print(f"FTL: {first_duration:.3f} s")
|
104 |
+
print(f"TPS: {tps:.3f} token/s")
|
105 |
+
|
106 |
+
def stream_predict(self, query):
|
107 |
+
"""
|
108 |
+
Stream the prediction for the given query.
|
109 |
+
"""
|
110 |
+
self.answer_cur = ""
|
111 |
+
self.input_str = query
|
112 |
+
tokens = self.encode_tokens()
|
113 |
+
|
114 |
+
for answer_cur, history in self._generate_predictions(tokens):
|
115 |
+
yield answer_cur, history
|
116 |
+
|
117 |
+
def _generate_predictions(self, tokens):
|
118 |
+
"""
|
119 |
+
Generate predictions for the given tokens.
|
120 |
+
"""
|
121 |
+
# First token
|
122 |
+
next_token = self.forward_first(tokens)
|
123 |
+
output_tokens = [next_token]
|
124 |
+
|
125 |
+
# Following tokens
|
126 |
+
while True:
|
127 |
+
next_token = self.forward_next()
|
128 |
+
if next_token == self.EOS:
|
129 |
+
break
|
130 |
+
output_tokens += [next_token]
|
131 |
+
self.answer_cur = self.tokenizer.decode(output_tokens)
|
132 |
+
if self.model.token_length >= self.SEQLEN:
|
133 |
+
self.update_history()
|
134 |
+
yield self.answer_cur + "\n\n\nReached the maximum length; The history context has been cleared.", self.history
|
135 |
+
break
|
136 |
+
else:
|
137 |
+
yield self.answer_cur, self.history
|
138 |
+
|
139 |
+
self.update_history()
|
140 |
+
|
141 |
+
def forward_first(self, tokens):
|
142 |
+
"""
|
143 |
+
Forward the first token.
|
144 |
+
"""
|
145 |
+
token = self.model.forward_first(tokens)
|
146 |
+
return token
|
147 |
+
|
148 |
+
def forward_next(self):
|
149 |
+
"""
|
150 |
+
Forward the next token.
|
151 |
+
"""
|
152 |
+
token = self.model.forward_next()
|
153 |
+
return token
|
154 |
+
|
155 |
+
def decode_tokens(self, token):
|
156 |
+
"""
|
157 |
+
Decode the given token.
|
158 |
+
"""
|
159 |
+
word = self.tokenizer.decode(token, skip_special_tokens=True)
|
160 |
+
return word
|
161 |
+
|
162 |
+
def encode_tokens(self):
|
163 |
+
"""
|
164 |
+
Encode the input string to tokens.
|
165 |
+
"""
|
166 |
+
raise NotImplementedError
|
167 |
+
|
168 |
+
def load_model(self):
|
169 |
+
"""
|
170 |
+
Load the model.
|
171 |
+
"""
|
172 |
+
raise NotImplementedError
|
173 |
+
|
174 |
+
def clear(self):
|
175 |
+
"""
|
176 |
+
Clear the chat session.
|
177 |
+
"""
|
178 |
+
raise NotImplementedError
|
179 |
+
|
180 |
+
def update_history(self):
|
181 |
+
"""
|
182 |
+
Update chat history.
|
183 |
+
"""
|
184 |
+
raise NotImplementedError
|
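BaseModel leaves `encode_tokens`, `load_model`, `clear`, and `update_history` abstract. A minimal sketch of a concrete subclass, for orientation only — `DemoModel`, `chat.Model`, and `args.model_path` are hypothetical names, not part of this repository:

```python
# hypothetical subclass sketch; `chat.Model` and `args.model_path` are illustrative only
from base_model import BaseModel


class DemoModel(BaseModel):
    def __init__(self, args):
        super().__init__(args)
        self.EOS = self.tokenizer.eos_token_id   # must be set before chat()
        self.SEQLEN = 512                        # must match the compiled bmodel
        self.load_model(args)

    def load_model(self, args):
        import chat                              # assumed C++ runtime binding
        self.model = chat.Model()
        self.model.init(self.devices, args.model_path)

    def encode_tokens(self):
        return self.tokenizer.encode(self.system_prompt + self.input_str)

    def clear(self):
        self.history = []

    def update_history(self):
        self.history.append((self.input_str, self.answer_cur))
```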
ChatGLM2/README.md
ADDED
@@ -0,0 +1,160 @@
1 |
+
![](./assets/sophgo_chip.png)
|
2 |
+
|
3 |
+
# ChatGLM2
|
4 |
+
|
5 |
+
This project deploys the large language model [ChatGLM2-6B](https://huggingface.co/THUDM/chatglm2-6b) on BM1684X. The model is converted to a bmodel with the [TPU-MLIR](https://github.com/sophgo/tpu-mlir) compiler and deployed with C++ code to a BM1684X PCIE or SoC environment.
|
6 |
+
|
7 |
+
|
8 |
+
A walkthrough of `ChatGLM` was published on Zhihu to help readers understand the source code:
|
9 |
+
|
10 |
+
[ChatGLM2 pipeline analysis and TPU-MLIR deployment](https://zhuanlan.zhihu.com/p/641975976)
|
11 |
+
|
12 |
+
|
13 |
+
## Development Environment
|
14 |
+
|
15 |
+
|
16 |
+
1. Download Docker and start a container, as follows:
|
17 |
+
|
18 |
+
``` shell
|
19 |
+
docker pull sophgo/tpuc_dev:latest
|
20 |
+
|
21 |
+
# myname1234 is just an example, you can set your own name
|
22 |
+
docker run --privileged --name myname1234 -v $PWD:/workspace -it sophgo/tpuc_dev:latest
|
23 |
+
```
|
24 |
+
The rest of this document assumes the environment lives in the Docker container's `/workspace` directory.
|
25 |
+
|
26 |
+
|
27 |
+
2. Download `ChatGLM2-6B` from Huggingface; it is large, so this may take a while
|
28 |
+
|
29 |
+
``` shell
|
30 |
+
git lfs install
|
31 |
+
git clone git@hf.co:THUDM/chatglm2-6b
|
32 |
+
```
|
33 |
+
Then copy config.json and modeling_chatglm.py from ./models/ChatGLM2/compile/files/chatglm2-6b in this project into the downloaded folder, replacing the files of the same name (users who need a different sequence length should refer to the [FAQ](#faq); the default sequence length is 512)
|
34 |
+
|
35 |
+
3. Download the `TPU-MLIR` source and build it (you can also download a prebuilt release package and extract it directly)
|
36 |
+
|
37 |
+
Since mlir is still under maintenance at the moment, users compiling GLM-series models should download:
|
38 |
+
``` shell
|
39 |
+
pip3 install dfss
|
40 |
+
python3 -m dfss --url=open@sophgo.com:/ext_model_information/LLM/mlir_club/glm_mlir.tar.gz
|
41 |
+
tar -xf glm_mlir.tar.gz
|
42 |
+
source tpu-mlir_v1.6.45-gdc3e9f6b-20231220/envsetup.sh
|
43 |
+
```
|
44 |
+
|
45 |
+
Once mlir maintenance is complete, the following approach can be used instead:
|
46 |
+
``` shell
|
47 |
+
git clone git@github.com:sophgo/tpu-mlir.git
|
48 |
+
cd tpu-mlir
|
49 |
+
source ./envsetup.sh
|
50 |
+
./build.sh
|
51 |
+
```
|
52 |
+
|
53 |
+
## Compiling the Model
|
54 |
+
|
55 |
+
1. Export all onnx models. If a missing package is reported during the process, simply `pip3 install` that package
|
56 |
+
|
57 |
+
``` shell
|
58 |
+
cd compile
|
59 |
+
python3 export_onnx.py --model_path your_chatglm2-6b_path
|
60 |
+
```
|
61 |
+
At this point a large number of onnx models have been exported to the tmp directory.
|
62 |
+
|
63 |
+
2. Compile the onnx models
|
64 |
+
|
65 |
+
TPU-MLIR currently supports F16, INT8, and INT4 quantization of ChatGLM2, as well as multi-chip distributed inference. By default, F16 quantization and single-chip inference are used, producing the file `chatglm2-6b_f16_1dev.bmodel`
|
66 |
+
|
67 |
+
```shell
|
68 |
+
./compile.sh --name chatglm2-6b --mode inference_mode --num_device device_number
|
69 |
+
```
|
70 |
+
|
71 |
+
Where:
|
72 |
+
`--name` is the model name, here `chatglm2-6b`;
|
73 |
+
`--mode` is the data type used for inference; any of `f16, int8, int4` may be chosen, defaulting to `f16`;
|
74 |
+
`--num_device` is the number of chips used for inference; set it according to the devices you actually use, defaulting to `--num_device 1`. An example invocation is shown below.
|
75 |
+
|
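For example, a hypothetical INT8 two-chip build — the output name follows the `<name>_<mode>_<num_device>dev.bmodel` pattern used by compile.sh:

```shell
./compile.sh --name chatglm2-6b --mode int8 --num_device 2
# -> chatglm2-6b_int8_2dev.bmodel
```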
76 |
+
## Building the Program (C++ Version)
|
77 |
+
|
78 |
+
执行如下编译,(PCIE与SOC相同):
|
79 |
+
|
80 |
+
```shell
|
81 |
+
cd demo
|
82 |
+
mkdir build
|
83 |
+
cd build
|
84 |
+
cmake ..
|
85 |
+
make
|
86 |
+
```
|
87 |
+
|
88 |
+
The build produces the `chatglm` executable. Place `chatglm` in the demo directory, then specify the chip count and the bmodel path as shown below.
|
89 |
+
Run `chatglm`; by default it runs `chatglm2-6b_f16_1dev.bmodel` on a single chip:
|
90 |
+
```shell
|
91 |
+
./chatglm --model chatglm2-6b_f16_1dev.bmodel --tokenizer ../support/tokenizer/tokenizer.model --devid your_devid
|
92 |
+
```
|
93 |
+
Here `--devid` is the id of the TPU used for inference, defaulting to 0. For multi-chip inference (the compiled bmodel must also target multiple chips), add chips with `,`; for example, `--devid 2,3` runs inference on TPU2 and TPU3.
|
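A complete two-chip invocation might look like the following (assuming the bmodel was compiled with `--num_device 2`):

```shell
./chatglm --model chatglm2-6b_f16_2dev.bmodel --tokenizer ../support/tokenizer/tokenizer.model --devid 2,3
```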
94 |
+
|
95 |
+
## Sample Run
|
96 |
+
|
97 |
+
Below is a sample run with INT8 quantization on a single chip:
|
98 |
+
|
99 |
+
![](./assets/chatglm.jpg)
|
100 |
+
|
101 |
+
## FAQ
|
102 |
+
|
103 |
+
#### Where sentencepiece comes from
|
104 |
+
|
105 |
+
A prebuilt copy ships with the project, so there is no need to build it; for the curious, the steps are as follows.
|
106 |
+
|
107 |
+
Download [sentencepiece](https://github.com/google/sentencepiece) and build it to obtain `libsentencepiece.a`
|
108 |
+
|
109 |
+
```shell
|
110 |
+
git clone git@github.com:google/sentencepiece.git
|
111 |
+
cd sentencepiece
|
112 |
+
mkdir build
|
113 |
+
cd build
|
114 |
+
cmake ..
|
115 |
+
make -j
|
116 |
+
```
|
117 |
+
|
118 |
+
To build for the SoC environment, follow the demo's build approach and specify the cross compiler in the makefile
|
119 |
+
|
120 |
+
#### The demo program does not run
|
121 |
+
|
122 |
+
If the demo program fails after being copied to the runtime environment, e.g. with errors such as missing interfaces,
|
123 |
+
the cause is that the runtime environment ships different library versions. Copy the .so files under `./support/lib_pcie` (PCIE) or `./support/lib_soc` (SoC) in the demo to the runtime environment and link against those .so files.
|
124 |
+
|
125 |
+
|
126 |
+
#### Changes made to the source code:
|
127 |
+
|
128 |
+
Three changes were made in total:
|
129 |
+
- `seq_length` in `config.json` was set to 512;
|
130 |
+
- The following code in `modeling_chatglm.py`:
|
131 |
+
|
132 |
+
```python
|
133 |
+
if attention_mask is not None:
|
134 |
+
attention_scores = attention_scores.masked_fill(attention_mask, float("-inf"))
|
135 |
+
```
|
136 |
+
|
137 |
+
was changed to:
|
138 |
+
|
139 |
+
```python
|
140 |
+
if attention_mask is not None:
|
141 |
+
attention_scores = attention_scores + (attention_mask * -10000.0)
|
142 |
+
```
|
143 |
+
|
144 |
+
This change improves efficiency: `masked_fill` is slow, and on top of that its ONNX export has some bugs. A sketch of why the two forms are equivalent follows.
|
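A minimal standalone sketch (illustration only, not part of the model code) showing that the additive mask behaves like `masked_fill` once softmax is applied:

```python
import torch

scores = torch.randn(1, 1, 4, 4)
mask = torch.ones(4, 4).triu(diagonal=1).bool()  # True above the diagonal

a = scores.masked_fill(mask, float("-inf")).softmax(dim=-1)
b = (scores + mask.float() * -10000.0).softmax(dim=-1)
print(torch.allclose(a, b, atol=1e-4))  # True: -10000 acts like -inf after softmax
```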
145 |
+
|
146 |
+
- And the following code in `modeling_chatglm.py`:
|
147 |
+
|
148 |
+
```python
|
149 |
+
pytorch_major_version = int(torch.__version__.split('.')[0])
|
150 |
+
if pytorch_major_version >= 2:
|
151 |
+
```
|
152 |
+
|
153 |
+
was changed to:
|
154 |
+
|
155 |
+
```python
|
156 |
+
pytorch_major_version = int(torch.__version__.split('.')[0])
|
157 |
+
if False:
|
158 |
+
```
|
159 |
+
|
160 |
+
The reason is that ONNX export cannot handle the `torch.nn.functional.scaled_dot_product_attention` operator.
|
ChatGLM2/compile/compile.sh
ADDED
@@ -0,0 +1,179 @@
1 |
+
#!/bin/bash
|
2 |
+
set -ex
|
3 |
+
models=
|
4 |
+
mode="f16"
|
5 |
+
folder="tmp"
|
6 |
+
num_device=1
|
7 |
+
mode_args=""
|
8 |
+
device_args=""
|
9 |
+
quantize_args="--quantize F16"
|
10 |
+
name=""
|
11 |
+
num_layers=
|
12 |
+
out_model=$name.bmodel
|
13 |
+
|
14 |
+
while [[ $# -gt 0 ]]; do
|
15 |
+
key="$1"
|
16 |
+
|
17 |
+
case $key in
|
18 |
+
--mode)
|
19 |
+
mode="$2"
|
20 |
+
shift 2
|
21 |
+
;;
|
22 |
+
--num_device)
|
23 |
+
num_device="$2"
|
24 |
+
shift 2
|
25 |
+
;;
|
26 |
+
--name)
|
27 |
+
name="$2"
|
28 |
+
shift 2
|
29 |
+
;;
|
30 |
+
*)
|
31 |
+
echo "Invalid option: $key" >&2
|
32 |
+
exit 1
|
33 |
+
;;
|
34 |
+
:)
|
35 |
+
echo "Option -$OPTARG requires an argument." >&2
|
36 |
+
exit 1
|
37 |
+
;;
|
38 |
+
esac
|
39 |
+
done
|
40 |
+
|
41 |
+
if [ "$name" = "chatglm2-6b" ]; then
|
42 |
+
num_layers=27
|
43 |
+
echo "Compile ChatGLM2-6B"
|
44 |
+
else
|
45 |
+
>&2 echo -e "Error: Invalid name $name, the input name must be \033[31mchatglm2-6b\033[0m"
|
46 |
+
exit 1
|
47 |
+
fi
|
48 |
+
|
49 |
+
if [ x$mode == x"int8" ]; then
|
50 |
+
quantize_args="--quantize W8F16"
|
51 |
+
elif [ x$mode == x"f16" ]; then
|
52 |
+
quantize_args="--quantize F16"
|
53 |
+
elif [ x$mode == x"int4" ]; then
|
54 |
+
quantize_args="--quantize W4F16 --q_group_size 64"
|
55 |
+
else
|
56 |
+
echo "Error, unknown quantize mode"
|
57 |
+
exit 1
|
58 |
+
fi
|
59 |
+
|
60 |
+
if [ x$num_device != x1 ]; then
|
61 |
+
device_args="--num_device $num_device"
|
62 |
+
out_model=$name'_'$mode'_'$num_device'dev.bmodel'
|
63 |
+
else
|
64 |
+
out_model=$name'_'$mode'_1dev.bmodel'
|
65 |
+
fi
|
66 |
+
|
67 |
+
outdir=${folder}/embedding
|
68 |
+
mkdir -p $outdir
|
69 |
+
pushd $outdir
|
70 |
+
|
71 |
+
model_transform.py \
|
72 |
+
--model_name embedding \
|
73 |
+
--model_def ../onnx/embedding.onnx \
|
74 |
+
--mlir embedding.mlir
|
75 |
+
|
76 |
+
|
77 |
+
model_deploy.py \
|
78 |
+
--mlir embedding.mlir \
|
79 |
+
--quantize F16 \
|
80 |
+
--quant_input \
|
81 |
+
--quant_output \
|
82 |
+
--chip bm1684x \
|
83 |
+
$device_args \
|
84 |
+
--model embedding.bmodel
|
85 |
+
|
86 |
+
model_transform.py \
|
87 |
+
--model_name embedding_cache \
|
88 |
+
--model_def ../onnx/embedding.onnx \
|
89 |
+
--input_shapes [[1,1]] \
|
90 |
+
--mlir embedding_cache.mlir
|
91 |
+
|
92 |
+
|
93 |
+
model_deploy.py \
|
94 |
+
--mlir embedding_cache.mlir \
|
95 |
+
--quantize F16 \
|
96 |
+
--quant_input \
|
97 |
+
--quant_output \
|
98 |
+
--chip bm1684x \
|
99 |
+
$device_args \
|
100 |
+
--model embedding_cache.bmodel
|
101 |
+
|
102 |
+
rm *.npz
|
103 |
+
|
104 |
+
models=$models' '$outdir'/embedding.bmodel '$outdir'/embedding_cache.bmodel '
|
105 |
+
|
106 |
+
popd
|
107 |
+
|
108 |
+
echo $models
|
109 |
+
|
110 |
+
outdir=tmp/$mode"_"$num_device"dev"/lm_head
|
111 |
+
mkdir -p $outdir
|
112 |
+
pushd $outdir
|
113 |
+
|
114 |
+
model_transform.py \
|
115 |
+
--model_name lm_head \
|
116 |
+
--model_def ../../onnx/lm_head.onnx \
|
117 |
+
--mlir lm_head.mlir
|
118 |
+
|
119 |
+
model_deploy.py \
|
120 |
+
--mlir lm_head.mlir \
|
121 |
+
$quantize_args \
|
122 |
+
--quant_input \
|
123 |
+
--quant_output \
|
124 |
+
--chip bm1684x \
|
125 |
+
$device_args \
|
126 |
+
--model lm_head.bmodel
|
127 |
+
|
128 |
+
rm *.npz
|
129 |
+
|
130 |
+
models=${models}${outdir}'/lm_head.bmodel '
|
131 |
+
popd
|
132 |
+
|
133 |
+
echo $models
|
134 |
+
|
135 |
+
outdir=tmp/$mode"_"$num_device"dev"/block
|
136 |
+
mkdir -p $outdir
|
137 |
+
|
138 |
+
pushd $outdir
|
140 |
+
|
141 |
+
for ((i=0; i<=$num_layers; i++)); do
|
142 |
+
|
143 |
+
model_transform.py \
|
144 |
+
--model_name block_$i \
|
145 |
+
--model_def ../../onnx/block_$i.onnx \
|
146 |
+
--mlir block_$i.mlir
|
147 |
+
|
148 |
+
model_deploy.py \
|
149 |
+
--mlir block_$i.mlir \
|
150 |
+
$quantize_args \
|
151 |
+
--quant_input \
|
152 |
+
--quant_output \
|
153 |
+
--chip bm1684x \
|
154 |
+
$device_args \
|
155 |
+
--model block_$i.bmodel
|
156 |
+
|
157 |
+
model_transform.py \
|
158 |
+
--model_name block_cache_$i \
|
159 |
+
--model_def ../../onnx/block_cache_$i.onnx \
|
160 |
+
--mlir block_cache_$i.mlir
|
161 |
+
|
162 |
+
model_deploy.py \
|
163 |
+
--mlir block_cache_$i.mlir \
|
164 |
+
$quantize_args \
|
165 |
+
--quant_input \
|
166 |
+
--quant_output \
|
167 |
+
--chip bm1684x \
|
168 |
+
$device_args \
|
169 |
+
--model block_cache_$i.bmodel
|
170 |
+
|
171 |
+
rm *.npz
|
172 |
+
|
173 |
+
models=${models}${outdir}'/block_'$i'.bmodel '$outdir'/block_cache_'$i'.bmodel '
|
174 |
+
|
175 |
+
done
|
176 |
+
popd
|
177 |
+
echo $models
|
178 |
+
|
179 |
+
model_tool --combine $models -o $out_model
|
ChatGLM2/compile/export_onnx.py
ADDED
@@ -0,0 +1,176 @@
1 |
+
#!/usr/bin/env python3
|
2 |
+
# ==============================================================================
|
3 |
+
#
|
4 |
+
# Copyright (C) 2023 Sophgo Technologies Inc. All rights reserved.
|
5 |
+
#
|
6 |
+
# TPU-MLIR is licensed under the 2-Clause BSD License except for the
|
7 |
+
# third-party components.
|
8 |
+
#
|
9 |
+
# ==============================================================================
|
10 |
+
|
11 |
+
import os
|
12 |
+
import torch
|
13 |
+
import argparse
|
14 |
+
from tqdm import tqdm
|
15 |
+
from transformers import AutoModel, AutoTokenizer
|
16 |
+
|
17 |
+
parser = argparse.ArgumentParser(description='export onnx.')
|
18 |
+
parser.add_argument('--model_path', type=str, help='path to the torch model.')
|
19 |
+
|
20 |
+
args = parser.parse_args()
|
21 |
+
|
22 |
+
model_path = args.model_path
|
23 |
+
folder = f"./tmp/onnx"
|
24 |
+
|
25 |
+
origin_model = AutoModel.from_pretrained(
|
26 |
+
model_path, trust_remote_code=True).float().eval()
|
27 |
+
|
28 |
+
for param in origin_model.parameters():
|
29 |
+
param.requires_grad = False
|
30 |
+
|
31 |
+
config = origin_model.config
|
32 |
+
transformer = origin_model.transformer
|
33 |
+
layers = transformer.encoder.layers
|
34 |
+
|
35 |
+
SEQ_LENGTH = transformer.seq_length
|
36 |
+
NUM_LAYERS = config.num_layers
|
37 |
+
HIDDEN_SIZE = config.hidden_size
|
38 |
+
NUM_ATTENTION_HEADS = config.num_attention_heads
|
39 |
+
HEAD_DIM = HIDDEN_SIZE // NUM_ATTENTION_HEADS
|
40 |
+
|
41 |
+
print(f'Layers: {NUM_LAYERS}\nHidden size: {HIDDEN_SIZE}\n')
|
42 |
+
|
43 |
+
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
|
44 |
+
|
45 |
+
class Embedding(torch.nn.Module):
|
46 |
+
|
47 |
+
def __init__(self):
|
48 |
+
super().__init__()
|
49 |
+
|
50 |
+
def forward(self, input_ids):
|
51 |
+
return transformer.embedding.word_embeddings(input_ids)
|
52 |
+
|
53 |
+
|
54 |
+
class Block(torch.nn.Module):
|
55 |
+
|
56 |
+
def __init__(self, layer_id):
|
57 |
+
super().__init__()
|
58 |
+
self.layer_id = layer_id
|
59 |
+
self.layer = layers[layer_id]
|
60 |
+
|
61 |
+
def forward(self, hidden_states, position_ids, attention_mask):
|
62 |
+
rotary_pos_emb = transformer.rotary_pos_emb(SEQ_LENGTH)[position_ids]
|
63 |
+
rotary_pos_emb = rotary_pos_emb.transpose(0, 1).contiguous()
|
64 |
+
hidden_states, past_kv = self.layer(hidden_states,
|
65 |
+
attention_mask,
|
66 |
+
rotary_pos_emb=rotary_pos_emb)
|
67 |
+
return hidden_states, past_kv
|
68 |
+
|
69 |
+
|
70 |
+
class BlockCache(torch.nn.Module):
|
71 |
+
|
72 |
+
def __init__(self, layer_id):
|
73 |
+
super().__init__()
|
74 |
+
self.layer_id = layer_id
|
75 |
+
self.layer = layers[layer_id]
|
76 |
+
|
77 |
+
def forward(self, hidden_states, position_ids, attention_mask, past_k,
|
78 |
+
past_v):
|
79 |
+
rotary_pos_emb = transformer.rotary_pos_emb(SEQ_LENGTH)[position_ids]
|
80 |
+
rotary_pos_emb = rotary_pos_emb.transpose(0, 1).contiguous()
|
81 |
+
hidden_states, past_kv = self.layer(hidden_states,
|
82 |
+
attention_mask,
|
83 |
+
kv_cache=(past_k, past_v),
|
84 |
+
rotary_pos_emb=rotary_pos_emb)
|
85 |
+
present_k, present_v = past_kv
|
86 |
+
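# keep the returned K/V at a fixed SEQ_LENGTH (static shapes for compilation) by dropping the first cache position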
return hidden_states, present_k[1:], present_v[1:]
|
87 |
+
|
88 |
+
|
89 |
+
class LmHead(torch.nn.Module):
|
90 |
+
|
91 |
+
def __init__(self):
|
92 |
+
super().__init__()
|
93 |
+
|
94 |
+
def forward(self, hidden_states):
|
95 |
+
hidden_states = transformer.encoder.final_layernorm(hidden_states)
|
96 |
+
m_logits = transformer.output_layer(hidden_states)
|
97 |
+
_, token = torch.topk(m_logits, 1)
|
98 |
+
return token
|
99 |
+
|
100 |
+
|
101 |
+
def convert_block(layer_id):
|
102 |
+
model = Block(layer_id)
|
103 |
+
hidden_states = torch.randn((SEQ_LENGTH, 1, HIDDEN_SIZE))
|
104 |
+
position_ids = torch.tensor([range(SEQ_LENGTH)], dtype=torch.long)
|
105 |
+
attention_mask = -1000 * torch.ones((1, 1, SEQ_LENGTH, SEQ_LENGTH), dtype=torch.float32).triu(diagonal=1)
|
106 |
+
torch.onnx.export(
|
107 |
+
model, (hidden_states, position_ids, attention_mask),
|
108 |
+
f'{folder}/block_{layer_id}.onnx',
|
109 |
+
verbose=False,
|
110 |
+
input_names=['input_states', 'position_ids', 'attention_mask'],
|
111 |
+
output_names=['hidden_states', 'past_k', 'past_v'],
|
112 |
+
do_constant_folding=True,
|
113 |
+
opset_version=15)
|
114 |
+
|
115 |
+
|
116 |
+
def convert_block_cache(layer_id):
|
117 |
+
model = BlockCache(layer_id)
|
118 |
+
hidden_states = torch.randn((1, 1, HIDDEN_SIZE))
|
119 |
+
position_ids = torch.tensor([range(1)], dtype=torch.long)
|
120 |
+
attention_mask = -1000 * torch.ones((1, 1, 1, SEQ_LENGTH + 1), dtype=torch.float32).triu(diagonal=1)
|
121 |
+
past_k = torch.randn((SEQ_LENGTH, 1, 2, HEAD_DIM))
|
122 |
+
past_v = torch.randn((SEQ_LENGTH, 1, 2, HEAD_DIM))
|
123 |
+
|
124 |
+
torch.onnx.export(
|
125 |
+
model, (hidden_states, position_ids, attention_mask, past_k, past_v),
|
126 |
+
f'{folder}/block_cache_{layer_id}.onnx',
|
127 |
+
verbose=False,
|
128 |
+
input_names=[
|
129 |
+
'input_states', 'position_ids', 'attention_mask', 'history_k',
|
130 |
+
'history_v'
|
131 |
+
],
|
132 |
+
output_names=['hidden_states', 'past_k', 'past_v'],
|
133 |
+
do_constant_folding=True,
|
134 |
+
opset_version=15)
|
135 |
+
|
136 |
+
|
137 |
+
def convert_embedding():
|
138 |
+
model = Embedding()
|
139 |
+
input_ids = torch.tensor([range(SEQ_LENGTH)])
|
140 |
+
|
141 |
+
torch.onnx.export(model, (input_ids),
|
142 |
+
f'{folder}/embedding.onnx',
|
143 |
+
verbose=False,
|
144 |
+
input_names=['input_ids'],
|
145 |
+
output_names=['input_embed'],
|
146 |
+
do_constant_folding=True,
|
147 |
+
opset_version=15)
|
148 |
+
|
149 |
+
|
150 |
+
def convert_lm_head():
|
151 |
+
model = LmHead()
|
152 |
+
input = torch.randn(1, HIDDEN_SIZE)
|
153 |
+
|
154 |
+
torch.onnx.export(model, (input),
|
155 |
+
f'{folder}/lm_head.onnx',
|
156 |
+
verbose=False,
|
157 |
+
input_names=['hidden_states'],
|
158 |
+
output_names=['token'],
|
159 |
+
do_constant_folding=True,
|
160 |
+
opset_version=15)
|
161 |
+
|
162 |
+
# create folder to store onnx
|
163 |
+
if not os.path.exists(folder):
|
164 |
+
os.makedirs(folder)
|
165 |
+
|
166 |
+
# export models
|
167 |
+
print(f'Convert block & block_cache')
|
168 |
+
for i in tqdm(range(NUM_LAYERS)):
|
169 |
+
convert_block(i)
|
170 |
+
convert_block_cache(i)
|
171 |
+
|
172 |
+
print(f'Convert embedding')
|
173 |
+
convert_embedding()
|
174 |
+
|
175 |
+
print(f'Convert lm_head')
|
176 |
+
convert_lm_head()
|
ChatGLM2/compile/files/chatglm2-6b/config.json
ADDED
@@ -0,0 +1,42 @@
1 |
+
{
|
2 |
+
"_name_or_path": "THUDM/chatglm2-6b",
|
3 |
+
"model_type": "chatglm",
|
4 |
+
"architectures": [
|
5 |
+
"ChatGLMModel"
|
6 |
+
],
|
7 |
+
"auto_map": {
|
8 |
+
"AutoConfig": "configuration_chatglm.ChatGLMConfig",
|
9 |
+
"AutoModel": "modeling_chatglm.ChatGLMForConditionalGeneration",
|
10 |
+
"AutoModelForCausalLM": "modeling_chatglm.ChatGLMForConditionalGeneration",
|
11 |
+
"AutoModelForSeq2SeqLM": "modeling_chatglm.ChatGLMForConditionalGeneration",
|
12 |
+
"AutoModelForSequenceClassification": "modeling_chatglm.ChatGLMForSequenceClassification"
|
13 |
+
},
|
14 |
+
"add_bias_linear": false,
|
15 |
+
"add_qkv_bias": true,
|
16 |
+
"apply_query_key_layer_scaling": true,
|
17 |
+
"apply_residual_connection_post_layernorm": false,
|
18 |
+
"attention_dropout": 0.0,
|
19 |
+
"attention_softmax_in_fp32": true,
|
20 |
+
"bias_dropout_fusion": true,
|
21 |
+
"ffn_hidden_size": 13696,
|
22 |
+
"fp32_residual_connection": false,
|
23 |
+
"hidden_dropout": 0.0,
|
24 |
+
"hidden_size": 4096,
|
25 |
+
"kv_channels": 128,
|
26 |
+
"layernorm_epsilon": 1e-05,
|
27 |
+
"multi_query_attention": true,
|
28 |
+
"multi_query_group_num": 2,
|
29 |
+
"num_attention_heads": 32,
|
30 |
+
"num_layers": 28,
|
31 |
+
"original_rope": true,
|
32 |
+
"padded_vocab_size": 65024,
|
33 |
+
"post_layer_norm": true,
|
34 |
+
"rmsnorm": true,
|
35 |
+
"seq_length": 512,
|
36 |
+
"use_cache": true,
|
37 |
+
"torch_dtype": "float16",
|
38 |
+
"transformers_version": "4.27.1",
|
39 |
+
"tie_word_embeddings": false,
|
40 |
+
"eos_token_id": 2,
|
41 |
+
"pad_token_id": 0
|
42 |
+
}
|
ChatGLM2/compile/files/chatglm2-6b/modeling_chatglm.py
ADDED
@@ -0,0 +1,1285 @@
1 |
+
""" PyTorch ChatGLM model. """
|
2 |
+
|
3 |
+
import math
|
4 |
+
import copy
|
5 |
+
import warnings
|
6 |
+
import re
|
7 |
+
import sys
|
8 |
+
|
9 |
+
import torch
|
10 |
+
import torch.utils.checkpoint
|
11 |
+
import torch.nn.functional as F
|
12 |
+
from torch import nn
|
13 |
+
from torch.nn import CrossEntropyLoss, LayerNorm
|
14 |
+
from torch.nn import CrossEntropyLoss, LayerNorm, MSELoss, BCEWithLogitsLoss
|
15 |
+
from torch.nn.utils import skip_init
|
16 |
+
from typing import Optional, Tuple, Union, List, Callable, Dict, Any
|
17 |
+
|
18 |
+
from transformers.modeling_outputs import (
|
19 |
+
BaseModelOutputWithPast,
|
20 |
+
CausalLMOutputWithPast,
|
21 |
+
SequenceClassifierOutputWithPast,
|
22 |
+
)
|
23 |
+
from transformers.modeling_utils import PreTrainedModel
|
24 |
+
from transformers.utils import logging
|
25 |
+
from transformers.generation.logits_process import LogitsProcessor
|
26 |
+
from transformers.generation.utils import LogitsProcessorList, StoppingCriteriaList, GenerationConfig, ModelOutput
|
27 |
+
|
28 |
+
from .configuration_chatglm import ChatGLMConfig
|
29 |
+
|
30 |
+
# flags required to enable jit fusion kernels
|
31 |
+
|
32 |
+
if sys.platform != 'darwin':
|
33 |
+
torch._C._jit_set_profiling_mode(False)
|
34 |
+
torch._C._jit_set_profiling_executor(False)
|
35 |
+
torch._C._jit_override_can_fuse_on_cpu(True)
|
36 |
+
torch._C._jit_override_can_fuse_on_gpu(True)
|
37 |
+
|
38 |
+
logger = logging.get_logger(__name__)
|
39 |
+
|
40 |
+
_CHECKPOINT_FOR_DOC = "THUDM/ChatGLM2-6B"
|
41 |
+
_CONFIG_FOR_DOC = "ChatGLM6BConfig"
|
42 |
+
|
43 |
+
CHATGLM_6B_PRETRAINED_MODEL_ARCHIVE_LIST = [
|
44 |
+
"THUDM/chatglm2-6b",
|
45 |
+
# See all ChatGLM models at https://huggingface.co/models?filter=chatglm
|
46 |
+
]
|
47 |
+
|
48 |
+
|
49 |
+
def default_init(cls, *args, **kwargs):
|
50 |
+
return cls(*args, **kwargs)
|
51 |
+
|
52 |
+
|
53 |
+
class InvalidScoreLogitsProcessor(LogitsProcessor):
|
54 |
+
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
|
55 |
+
if torch.isnan(scores).any() or torch.isinf(scores).any():
|
56 |
+
scores.zero_()
|
57 |
+
scores[..., 5] = 5e4
|
58 |
+
return scores
|
59 |
+
|
60 |
+
|
61 |
+
class PrefixEncoder(torch.nn.Module):
|
62 |
+
"""
|
63 |
+
The torch.nn model to encode the prefix
|
64 |
+
Input shape: (batch-size, prefix-length)
|
65 |
+
Output shape: (batch-size, prefix-length, 2*layers*hidden)
|
66 |
+
"""
|
67 |
+
|
68 |
+
def __init__(self, config: ChatGLMConfig):
|
69 |
+
super().__init__()
|
70 |
+
self.prefix_projection = config.prefix_projection
|
71 |
+
if self.prefix_projection:
|
72 |
+
# Use a two-layer MLP to encode the prefix
|
73 |
+
kv_size = config.num_layers * config.kv_channels * config.multi_query_group_num * 2
|
74 |
+
self.embedding = torch.nn.Embedding(config.pre_seq_len, kv_size)
|
75 |
+
self.trans = torch.nn.Sequential(
|
76 |
+
torch.nn.Linear(kv_size, config.hidden_size),
|
77 |
+
torch.nn.Tanh(),
|
78 |
+
torch.nn.Linear(config.hidden_size, kv_size)
|
79 |
+
)
|
80 |
+
else:
|
81 |
+
self.embedding = torch.nn.Embedding(config.pre_seq_len,
|
82 |
+
config.num_layers * config.kv_channels * config.multi_query_group_num * 2)
|
83 |
+
|
84 |
+
def forward(self, prefix: torch.Tensor):
|
85 |
+
if self.prefix_projection:
|
86 |
+
prefix_tokens = self.embedding(prefix)
|
87 |
+
past_key_values = self.trans(prefix_tokens)
|
88 |
+
else:
|
89 |
+
past_key_values = self.embedding(prefix)
|
90 |
+
return past_key_values
|
91 |
+
|
92 |
+
|
93 |
+
def split_tensor_along_last_dim(
|
94 |
+
tensor: torch.Tensor,
|
95 |
+
num_partitions: int,
|
96 |
+
contiguous_split_chunks: bool = False,
|
97 |
+
) -> List[torch.Tensor]:
|
98 |
+
"""Split a tensor along its last dimension.
|
99 |
+
|
100 |
+
Arguments:
|
101 |
+
tensor: input tensor.
|
102 |
+
num_partitions: number of partitions to split the tensor
|
103 |
+
contiguous_split_chunks: If True, make each chunk contiguous
|
104 |
+
in memory.
|
105 |
+
|
106 |
+
Returns:
|
107 |
+
A list of Tensors
|
108 |
+
"""
|
109 |
+
# Get the size and dimension.
|
110 |
+
last_dim = tensor.dim() - 1
|
111 |
+
last_dim_size = tensor.size()[last_dim] // num_partitions
|
112 |
+
# Split.
|
113 |
+
tensor_list = torch.split(tensor, last_dim_size, dim=last_dim)
|
114 |
+
# Note: torch.split does not create contiguous tensors by default.
|
115 |
+
if contiguous_split_chunks:
|
116 |
+
return tuple(chunk.contiguous() for chunk in tensor_list)
|
117 |
+
|
118 |
+
return tensor_list
|
119 |
+
|
120 |
+
|
121 |
+
class RotaryEmbedding(nn.Module):
|
122 |
+
def __init__(self, dim, original_impl=False, device=None, dtype=None):
|
123 |
+
super().__init__()
|
124 |
+
inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2, device=device).to(dtype=dtype) / dim))
|
125 |
+
self.register_buffer("inv_freq", inv_freq)
|
126 |
+
self.dim = dim
|
127 |
+
self.original_impl = original_impl
|
128 |
+
|
129 |
+
def forward_impl(
|
130 |
+
self, seq_len: int, n_elem: int, dtype: torch.dtype, device: torch.device, base: int = 10000
|
131 |
+
):
|
132 |
+
"""Enhanced Transformer with Rotary Position Embedding.
|
133 |
+
|
134 |
+
Derived from: https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/labml_nn/
|
135 |
+
transformers/rope/__init__.py. MIT License:
|
136 |
+
https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/license.
|
137 |
+
"""
|
138 |
+
# $\Theta = {\theta_i = 10000^{\frac{2(i-1)}{d}}, i \in [1, 2, ..., \frac{d}{2}]}$
|
139 |
+
theta = 1.0 / (base ** (torch.arange(0, n_elem, 2, dtype=dtype, device=device) / n_elem))
|
140 |
+
|
141 |
+
# Create position indexes `[0, 1, ..., seq_len - 1]`
|
142 |
+
seq_idx = torch.arange(seq_len, dtype=dtype, device=device)
|
143 |
+
|
144 |
+
# Calculate the product of position index and $\theta_i$
|
145 |
+
idx_theta = torch.outer(seq_idx, theta).float()
|
146 |
+
|
147 |
+
cache = torch.stack([torch.cos(idx_theta), torch.sin(idx_theta)], dim=-1)
|
148 |
+
|
149 |
+
# this is to mimic the behaviour of complex32, else we will get different results
|
150 |
+
if dtype in (torch.float16, torch.bfloat16, torch.int8):
|
151 |
+
cache = cache.bfloat16() if dtype == torch.bfloat16 else cache.half()
|
152 |
+
return cache
|
153 |
+
|
154 |
+
def forward(self, max_seq_len, offset=0):
|
155 |
+
return self.forward_impl(
|
156 |
+
max_seq_len, self.dim, dtype=self.inv_freq.dtype, device=self.inv_freq.device
|
157 |
+
)
|
158 |
+
|
159 |
+
|
160 |
+
@torch.jit.script
|
161 |
+
def apply_rotary_pos_emb(x: torch.Tensor, rope_cache: torch.Tensor) -> torch.Tensor:
|
162 |
+
# x: [sq, b, np, hn]
|
163 |
+
sq, b, np, hn = x.size(0), x.size(1), x.size(2), x.size(3)
|
164 |
+
rot_dim = rope_cache.shape[-2] * 2
|
165 |
+
x, x_pass = x[..., :rot_dim], x[..., rot_dim:]
|
166 |
+
# truncate to support variable sizes
|
167 |
+
rope_cache = rope_cache[:sq]
|
168 |
+
xshaped = x.reshape(sq, -1, np, rot_dim // 2, 2)
|
169 |
+
rope_cache = rope_cache.view(sq, -1, 1, xshaped.size(3), 2)
|
170 |
+
x_out2 = torch.stack(
|
171 |
+
[
|
172 |
+
xshaped[..., 0] * rope_cache[..., 0] - xshaped[..., 1] * rope_cache[..., 1],
|
173 |
+
xshaped[..., 1] * rope_cache[..., 0] + xshaped[..., 0] * rope_cache[..., 1],
|
174 |
+
],
|
175 |
+
-1,
|
176 |
+
)
|
177 |
+
x_out2 = x_out2.flatten(3)
|
178 |
+
return torch.cat((x_out2, x_pass), dim=-1)
|
179 |
+
|
180 |
+
|
181 |
+
class RMSNorm(torch.nn.Module):
|
182 |
+
def __init__(self, normalized_shape, eps=1e-5, device=None, dtype=None, **kwargs):
|
183 |
+
super().__init__()
|
184 |
+
self.weight = torch.nn.Parameter(torch.empty(normalized_shape, device=device, dtype=dtype))
|
185 |
+
self.eps = eps
|
186 |
+
|
187 |
+
def forward(self, hidden_states: torch.Tensor):
|
188 |
+
input_dtype = hidden_states.dtype
|
189 |
+
variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
|
190 |
+
hidden_states = hidden_states * torch.rsqrt(variance + self.eps)
|
191 |
+
|
192 |
+
return (self.weight * hidden_states).to(input_dtype)
|
193 |
+
|
194 |
+
|
195 |
+
class CoreAttention(torch.nn.Module):
|
196 |
+
def __init__(self, config: ChatGLMConfig, layer_number):
|
197 |
+
super(CoreAttention, self).__init__()
|
198 |
+
|
199 |
+
self.apply_query_key_layer_scaling = config.apply_query_key_layer_scaling
|
200 |
+
self.attention_softmax_in_fp32 = config.attention_softmax_in_fp32
|
201 |
+
if self.apply_query_key_layer_scaling:
|
202 |
+
self.attention_softmax_in_fp32 = True
|
203 |
+
self.layer_number = max(1, layer_number)
|
204 |
+
|
205 |
+
projection_size = config.kv_channels * config.num_attention_heads
|
206 |
+
|
207 |
+
# Per attention head and per partition values.
|
208 |
+
self.hidden_size_per_partition = projection_size
|
209 |
+
self.hidden_size_per_attention_head = projection_size // config.num_attention_heads
|
210 |
+
self.num_attention_heads_per_partition = config.num_attention_heads
|
211 |
+
|
212 |
+
coeff = None
|
213 |
+
self.norm_factor = math.sqrt(self.hidden_size_per_attention_head)
|
214 |
+
if self.apply_query_key_layer_scaling:
|
215 |
+
coeff = self.layer_number
|
216 |
+
self.norm_factor *= coeff
|
217 |
+
self.coeff = coeff
|
218 |
+
|
219 |
+
self.attention_dropout = torch.nn.Dropout(config.attention_dropout)
|
220 |
+
|
221 |
+
def forward(self, query_layer, key_layer, value_layer, attention_mask):
|
222 |
+
pytorch_major_version = int(torch.__version__.split('.')[0])
|
223 |
+
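# deliberately disabled: the scaled_dot_product_attention branch cannot be exported to ONNX (see the README notes on source changes)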
if False:
|
224 |
+
query_layer, key_layer, value_layer = [k.permute(1, 2, 0, 3) for k in [query_layer, key_layer, value_layer]]
|
225 |
+
if attention_mask is None and query_layer.shape[2] == key_layer.shape[2]:
|
226 |
+
context_layer = torch.nn.functional.scaled_dot_product_attention(query_layer, key_layer, value_layer,
|
227 |
+
is_causal=True)
|
228 |
+
else:
|
229 |
+
if attention_mask is not None:
|
230 |
+
attention_mask = ~attention_mask
|
231 |
+
context_layer = torch.nn.functional.scaled_dot_product_attention(query_layer, key_layer, value_layer,
|
232 |
+
attention_mask)
|
233 |
+
context_layer = context_layer.permute(2, 0, 1, 3)
|
234 |
+
new_context_layer_shape = context_layer.size()[:-2] + (self.hidden_size_per_partition,)
|
235 |
+
context_layer = context_layer.reshape(*new_context_layer_shape)
|
236 |
+
else:
|
237 |
+
# Raw attention scores
|
238 |
+
|
239 |
+
# [b, np, sq, sk]
|
240 |
+
output_size = (query_layer.size(1), query_layer.size(2), query_layer.size(0), key_layer.size(0))
|
241 |
+
|
242 |
+
# [sq, b, np, hn] -> [sq, b * np, hn]
|
243 |
+
query_layer = query_layer.view(output_size[2], output_size[0] * output_size[1], -1)
|
244 |
+
# [sk, b, np, hn] -> [sk, b * np, hn]
|
245 |
+
key_layer = key_layer.view(output_size[3], output_size[0] * output_size[1], -1)
|
246 |
+
|
247 |
+
# preallocating input tensor: [b * np, sq, sk]
|
248 |
+
matmul_input_buffer = torch.empty(
|
249 |
+
output_size[0] * output_size[1], output_size[2], output_size[3], dtype=query_layer.dtype,
|
250 |
+
device=query_layer.device
|
251 |
+
)
|
252 |
+
|
253 |
+
# Raw attention scores. [b * np, sq, sk]
|
254 |
+
matmul_result = torch.baddbmm(
|
255 |
+
matmul_input_buffer,
|
256 |
+
query_layer.transpose(0, 1), # [b * np, sq, hn]
|
257 |
+
key_layer.transpose(0, 1).transpose(1, 2), # [b * np, hn, sk]
|
258 |
+
beta=0.0,
|
259 |
+
alpha=(1.0 / self.norm_factor),
|
260 |
+
)
|
261 |
+
|
262 |
+
# change view to [b, np, sq, sk]
|
263 |
+
attention_scores = matmul_result.view(*output_size)
|
264 |
+
|
265 |
+
# ===========================
|
266 |
+
# Attention probs and dropout
|
267 |
+
# ===========================
|
268 |
+
|
269 |
+
# attention scores and attention mask [b, np, sq, sk]
|
270 |
+
if self.attention_softmax_in_fp32:
|
271 |
+
attention_scores = attention_scores.float()
|
272 |
+
if self.coeff is not None:
|
273 |
+
attention_scores = attention_scores * self.coeff
|
274 |
+
if attention_mask is None and attention_scores.shape[2] == attention_scores.shape[3]:
|
275 |
+
attention_mask = torch.ones(output_size[0], 1, output_size[2], output_size[3],
|
276 |
+
device=attention_scores.device, dtype=torch.bool)
|
277 |
+
attention_mask.tril_()
|
278 |
+
attention_mask = ~attention_mask
|
279 |
+
if attention_mask is not None:
|
280 |
+
attention_scores = attention_scores + attention_mask
|
281 |
+
attention_probs = F.softmax(attention_scores, dim=-1)
|
282 |
+
attention_probs = attention_probs.type_as(value_layer)
|
283 |
+
|
284 |
+
# This is actually dropping out entire tokens to attend to, which might
|
285 |
+
# seem a bit unusual, but is taken from the original Transformer paper.
|
286 |
+
attention_probs = self.attention_dropout(attention_probs)
|
287 |
+
# =========================
|
288 |
+
# Context layer. [sq, b, hp]
|
289 |
+
# =========================
|
290 |
+
|
291 |
+
# value_layer -> context layer.
|
292 |
+
# [sk, b, np, hn] --> [b, np, sq, hn]
|
293 |
+
|
294 |
+
# context layer shape: [b, np, sq, hn]
|
295 |
+
output_size = (value_layer.size(1), value_layer.size(2), query_layer.size(0), value_layer.size(3))
|
296 |
+
# change view [sk, b * np, hn]
|
297 |
+
value_layer = value_layer.view(value_layer.size(0), output_size[0] * output_size[1], -1)
|
298 |
+
# change view [b * np, sq, sk]
|
299 |
+
attention_probs = attention_probs.view(output_size[0] * output_size[1], output_size[2], -1)
|
300 |
+
# matmul: [b * np, sq, hn]
|
301 |
+
context_layer = torch.bmm(attention_probs, value_layer.transpose(0, 1))
|
302 |
+
# change view [b, np, sq, hn]
|
303 |
+
context_layer = context_layer.view(*output_size)
|
304 |
+
# [b, np, sq, hn] --> [sq, b, np, hn]
|
305 |
+
context_layer = context_layer.permute(2, 0, 1, 3).contiguous()
|
306 |
+
# [sq, b, np, hn] --> [sq, b, hp]
|
307 |
+
new_context_layer_shape = context_layer.size()[:-2] + (self.hidden_size_per_partition,)
|
308 |
+
context_layer = context_layer.view(*new_context_layer_shape)
|
309 |
+
|
310 |
+
return context_layer
|
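
The manual branch above is plain scaled dot-product attention spelled out with baddbmm/bmm. A minimal standalone sketch of the same computation (all shapes illustrative; norm_factor is taken as sqrt of the head size, matching the alpha=1.0/self.norm_factor scaling above):

import math
import torch

sq, b, np_, hn = 4, 1, 2, 8                      # toy [sq, b, np, hn] sizes
q, k, v = (torch.randn(sq, b, np_, hn) for _ in range(3))
norm_factor = math.sqrt(hn)

q2 = q.view(sq, b * np_, hn).transpose(0, 1)     # [b*np, sq, hn]
k2 = k.view(sq, b * np_, hn).transpose(0, 1)
scores = torch.bmm(q2, k2.transpose(1, 2)) / norm_factor   # [b*np, sq, sk]

causal = torch.ones(sq, sq, dtype=torch.bool).tril()
scores = scores.masked_fill(~causal, float("-inf"))        # mask future keys

probs = torch.softmax(scores, dim=-1)
v2 = v.view(sq, b * np_, hn).transpose(0, 1)
ctx = torch.bmm(probs, v2).view(b, np_, sq, hn).permute(2, 0, 1, 3)  # [sq, b, np, hn]
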
311 |
+
|
312 |
+
|
313 |
+
class SelfAttention(torch.nn.Module):
|
314 |
+
"""Parallel self-attention layer abstract class.
|
315 |
+
|
316 |
+
Self-attention layer takes input with size [s, b, h]
|
317 |
+
and returns output of the same size.
|
318 |
+
"""
|
319 |
+
|
320 |
+
def __init__(self, config: ChatGLMConfig, layer_number, device=None):
|
321 |
+
super(SelfAttention, self).__init__()
|
322 |
+
self.layer_number = max(1, layer_number)
|
323 |
+
|
324 |
+
self.projection_size = config.kv_channels * config.num_attention_heads
|
325 |
+
|
326 |
+
# Per attention head and per partition values.
|
327 |
+
self.hidden_size_per_attention_head = self.projection_size // config.num_attention_heads
|
328 |
+
self.num_attention_heads_per_partition = config.num_attention_heads
|
329 |
+
|
330 |
+
self.multi_query_attention = config.multi_query_attention
|
331 |
+
self.qkv_hidden_size = 3 * self.projection_size
|
332 |
+
if self.multi_query_attention:
|
333 |
+
self.num_multi_query_groups_per_partition = config.multi_query_group_num
|
334 |
+
self.qkv_hidden_size = (
|
335 |
+
self.projection_size + 2 * self.hidden_size_per_attention_head * config.multi_query_group_num
|
336 |
+
)
|
337 |
+
self.query_key_value = nn.Linear(config.hidden_size, self.qkv_hidden_size,
|
338 |
+
bias=config.add_bias_linear or config.add_qkv_bias,
|
339 |
+
device=device, **_config_to_kwargs(config)
|
340 |
+
)
|
341 |
+
|
342 |
+
self.core_attention = CoreAttention(config, self.layer_number)
|
343 |
+
|
344 |
+
# Output.
|
345 |
+
self.dense = nn.Linear(self.projection_size, config.hidden_size, bias=config.add_bias_linear,
|
346 |
+
device=device, **_config_to_kwargs(config)
|
347 |
+
)
|
348 |
+
|
349 |
+
def _allocate_memory(self, inference_max_sequence_len, batch_size, device=None, dtype=None):
|
350 |
+
if self.multi_query_attention:
|
351 |
+
num_attention_heads = self.num_multi_query_groups_per_partition
|
352 |
+
else:
|
353 |
+
num_attention_heads = self.num_attention_heads_per_partition
|
354 |
+
return torch.empty(
|
355 |
+
inference_max_sequence_len,
|
356 |
+
batch_size,
|
357 |
+
num_attention_heads,
|
358 |
+
self.hidden_size_per_attention_head,
|
359 |
+
dtype=dtype,
|
360 |
+
device=device,
|
361 |
+
)
|
362 |
+
|
363 |
+
def forward(
|
364 |
+
self, hidden_states, attention_mask, rotary_pos_emb, kv_cache=None, use_cache=True
|
365 |
+
):
|
366 |
+
# hidden_states: [sq, b, h]
|
367 |
+
|
368 |
+
# =================================================
|
369 |
+
# Pre-allocate memory for key-values for inference.
|
370 |
+
# =================================================
|
371 |
+
# =====================
|
372 |
+
# Query, Key, and Value
|
373 |
+
# =====================
|
374 |
+
|
375 |
+
# Attention heads [sq, b, h] --> [sq, b, (np * 3 * hn)]
|
376 |
+
mixed_x_layer = self.query_key_value(hidden_states)
|
377 |
+
|
378 |
+
if self.multi_query_attention:
|
379 |
+
(query_layer, key_layer, value_layer) = mixed_x_layer.split(
|
380 |
+
[
|
381 |
+
self.num_attention_heads_per_partition * self.hidden_size_per_attention_head,
|
382 |
+
self.num_multi_query_groups_per_partition * self.hidden_size_per_attention_head,
|
383 |
+
self.num_multi_query_groups_per_partition * self.hidden_size_per_attention_head,
|
384 |
+
],
|
385 |
+
dim=-1,
|
386 |
+
)
|
387 |
+
query_layer = query_layer.view(
|
388 |
+
query_layer.size()[:-1] + (self.num_attention_heads_per_partition, self.hidden_size_per_attention_head)
|
389 |
+
)
|
390 |
+
key_layer = key_layer.view(
|
391 |
+
key_layer.size()[:-1] + (self.num_multi_query_groups_per_partition, self.hidden_size_per_attention_head)
|
392 |
+
)
|
393 |
+
value_layer = value_layer.view(
|
394 |
+
value_layer.size()[:-1]
|
395 |
+
+ (self.num_multi_query_groups_per_partition, self.hidden_size_per_attention_head)
|
396 |
+
)
|
397 |
+
else:
|
398 |
+
new_tensor_shape = mixed_x_layer.size()[:-1] + \
|
399 |
+
(self.num_attention_heads_per_partition,
|
400 |
+
3 * self.hidden_size_per_attention_head)
|
401 |
+
mixed_x_layer = mixed_x_layer.view(*new_tensor_shape)
|
402 |
+
|
403 |
+
# [sq, b, np, 3 * hn] --> 3 [sq, b, np, hn]
|
404 |
+
(query_layer, key_layer, value_layer) = split_tensor_along_last_dim(mixed_x_layer, 3)
|
405 |
+
|
406 |
+
# apply relative positional encoding (rotary embedding)
|
407 |
+
if rotary_pos_emb is not None:
|
408 |
+
query_layer = apply_rotary_pos_emb(query_layer, rotary_pos_emb)
|
409 |
+
key_layer = apply_rotary_pos_emb(key_layer, rotary_pos_emb)
|
410 |
+
|
411 |
+
# adjust key and value for inference
|
412 |
+
if kv_cache is not None:
|
413 |
+
cache_k, cache_v = kv_cache
|
414 |
+
key_layer = torch.cat((cache_k, key_layer), dim=0)
|
415 |
+
value_layer = torch.cat((cache_v, value_layer), dim=0)
|
416 |
+
if use_cache:
|
417 |
+
kv_cache = (key_layer, value_layer)
|
418 |
+
else:
|
419 |
+
kv_cache = None
|
420 |
+
|
421 |
+
if self.multi_query_attention:
|
422 |
+
key_layer = key_layer.unsqueeze(-2)
|
423 |
+
key_layer = key_layer.expand(
|
424 |
+
-1, -1, -1, self.num_attention_heads_per_partition // self.num_multi_query_groups_per_partition, -1
|
425 |
+
)
|
426 |
+
key_layer = key_layer.contiguous().view(
|
427 |
+
key_layer.size()[:2] + (self.num_attention_heads_per_partition, self.hidden_size_per_attention_head)
|
428 |
+
)
|
429 |
+
value_layer = value_layer.unsqueeze(-2)
|
430 |
+
value_layer = value_layer.expand(
|
431 |
+
-1, -1, -1, self.num_attention_heads_per_partition // self.num_multi_query_groups_per_partition, -1
|
432 |
+
)
|
433 |
+
value_layer = value_layer.contiguous().view(
|
434 |
+
value_layer.size()[:2] + (self.num_attention_heads_per_partition, self.hidden_size_per_attention_head)
|
435 |
+
)
|
436 |
+
|
437 |
+
# ==================================
|
438 |
+
# core attention computation
|
439 |
+
# ==================================
|
440 |
+
|
441 |
+
context_layer = self.core_attention(query_layer, key_layer, value_layer, attention_mask)
|
442 |
+
|
443 |
+
# =================
|
444 |
+
# Output. [sq, b, h]
|
445 |
+
# =================
|
446 |
+
|
447 |
+
output = self.dense(context_layer)
|
448 |
+
|
449 |
+
return output, kv_cache
|
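
When multi_query_attention is on, the forward above broadcasts each of the few KV heads across the query heads that share it before core attention runs. The same expand/view in isolation (sizes hypothetical: 8 query heads sharing 2 KV groups):

import torch

sq, b, num_heads, num_groups, hn = 4, 1, 8, 2, 16
key = torch.randn(sq, b, num_groups, hn)              # [sq, b, group, hn]
key = key.unsqueeze(-2)                               # [sq, b, group, 1, hn]
key = key.expand(-1, -1, -1, num_heads // num_groups, -1)
key = key.contiguous().view(sq, b, num_heads, hn)     # [sq, b, np, hn]
assert key.shape == (sq, b, num_heads, hn)
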
450 |
+
|
451 |
+
|
452 |
+
def _config_to_kwargs(args):
|
453 |
+
common_kwargs = {
|
454 |
+
"dtype": args.torch_dtype,
|
455 |
+
}
|
456 |
+
return common_kwargs
|
457 |
+
|
458 |
+
|
459 |
+
class MLP(torch.nn.Module):
|
460 |
+
"""MLP.
|
461 |
+
|
462 |
+
MLP takes an input with hidden size h, projects it to a 4*h
|
463 |
+
hidden dimension, applies a nonlinear transformation, and projects the
|
464 |
+
state back to hidden size h.
|
465 |
+
"""
|
466 |
+
|
467 |
+
def __init__(self, config: ChatGLMConfig, device=None):
|
468 |
+
super(MLP, self).__init__()
|
469 |
+
|
470 |
+
self.add_bias = config.add_bias_linear
|
471 |
+
|
472 |
+
# Project to 4h. If using swiglu, double the output width; see https://arxiv.org/pdf/2002.05202.pdf
|
473 |
+
self.dense_h_to_4h = nn.Linear(
|
474 |
+
config.hidden_size,
|
475 |
+
config.ffn_hidden_size * 2,
|
476 |
+
bias=self.add_bias,
|
477 |
+
device=device,
|
478 |
+
**_config_to_kwargs(config)
|
479 |
+
)
|
480 |
+
|
481 |
+
def swiglu(x):
|
482 |
+
x = torch.chunk(x, 2, dim=-1)
|
483 |
+
return F.silu(x[0]) * x[1]
|
484 |
+
|
485 |
+
self.activation_func = swiglu
|
486 |
+
|
487 |
+
# Project back to h.
|
488 |
+
self.dense_4h_to_h = nn.Linear(
|
489 |
+
config.ffn_hidden_size,
|
490 |
+
config.hidden_size,
|
491 |
+
bias=self.add_bias,
|
492 |
+
device=device,
|
493 |
+
**_config_to_kwargs(config)
|
494 |
+
)
|
495 |
+
|
496 |
+
def forward(self, hidden_states):
|
497 |
+
# [s, b, 4hp]
|
498 |
+
intermediate_parallel = self.dense_h_to_4h(hidden_states)
|
499 |
+
intermediate_parallel = self.activation_func(intermediate_parallel)
|
500 |
+
# [s, b, h]
|
501 |
+
output = self.dense_4h_to_h(intermediate_parallel)
|
502 |
+
return output
|
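
To make the SwiGLU wiring concrete: dense_h_to_4h emits 2*ffn_hidden_size channels, swiglu splits them into a gate half and a value half, and dense_4h_to_h maps the product back to hidden_size. A shape walk-through with illustrative sizes:

import torch
import torch.nn.functional as F

hidden, ffn = 64, 128                              # illustrative sizes
x = torch.randn(3, 1, hidden)                      # [s, b, h]
w_in = torch.nn.Linear(hidden, ffn * 2, bias=False)
w_out = torch.nn.Linear(ffn, hidden, bias=False)

gate, value = torch.chunk(w_in(x), 2, dim=-1)      # two [s, b, ffn] halves
y = w_out(F.silu(gate) * value)                    # back to [s, b, h]
assert y.shape == x.shape
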
503 |
+
|
504 |
+
|
505 |
+
class GLMBlock(torch.nn.Module):
|
506 |
+
"""A single transformer layer.
|
507 |
+
|
508 |
+
Transformer layer takes input with size [s, b, h] and returns an
|
509 |
+
output of the same size.
|
510 |
+
"""
|
511 |
+
|
512 |
+
def __init__(self, config: ChatGLMConfig, layer_number, device=None):
|
513 |
+
super(GLMBlock, self).__init__()
|
514 |
+
self.layer_number = layer_number
|
515 |
+
|
516 |
+
self.apply_residual_connection_post_layernorm = config.apply_residual_connection_post_layernorm
|
517 |
+
|
518 |
+
self.fp32_residual_connection = config.fp32_residual_connection
|
519 |
+
|
520 |
+
LayerNormFunc = RMSNorm if config.rmsnorm else LayerNorm
|
521 |
+
# Layernorm on the input data.
|
522 |
+
self.input_layernorm = LayerNormFunc(config.hidden_size, eps=config.layernorm_epsilon, device=device,
|
523 |
+
dtype=config.torch_dtype)
|
524 |
+
|
525 |
+
# Self attention.
|
526 |
+
self.self_attention = SelfAttention(config, layer_number, device=device)
|
527 |
+
self.hidden_dropout = config.hidden_dropout
|
528 |
+
|
529 |
+
# Layernorm on the attention output
|
530 |
+
self.post_attention_layernorm = LayerNormFunc(config.hidden_size, eps=config.layernorm_epsilon, device=device,
|
531 |
+
dtype=config.torch_dtype)
|
532 |
+
|
533 |
+
# MLP
|
534 |
+
self.mlp = MLP(config, device=device)
|
535 |
+
|
536 |
+
def forward(
|
537 |
+
self, hidden_states, attention_mask, rotary_pos_emb, kv_cache=None, use_cache=True,
|
538 |
+
):
|
539 |
+
# hidden_states: [s, b, h]
|
540 |
+
|
541 |
+
# Layer norm at the beginning of the transformer layer.
|
542 |
+
layernorm_output = self.input_layernorm(hidden_states)
|
543 |
+
# Self attention.
|
544 |
+
attention_output, kv_cache = self.self_attention(
|
545 |
+
layernorm_output,
|
546 |
+
attention_mask,
|
547 |
+
rotary_pos_emb,
|
548 |
+
kv_cache=kv_cache,
|
549 |
+
use_cache=use_cache
|
550 |
+
)
|
551 |
+
|
552 |
+
# Residual connection.
|
553 |
+
if self.apply_residual_connection_post_layernorm:
|
554 |
+
residual = layernorm_output
|
555 |
+
else:
|
556 |
+
residual = hidden_states
|
557 |
+
|
558 |
+
layernorm_input = torch.nn.functional.dropout(attention_output, p=self.hidden_dropout, training=self.training)
|
559 |
+
layernorm_input = residual + layernorm_input
|
560 |
+
|
561 |
+
# Layer norm post the self attention.
|
562 |
+
layernorm_output = self.post_attention_layernorm(layernorm_input)
|
563 |
+
|
564 |
+
# MLP.
|
565 |
+
mlp_output = self.mlp(layernorm_output)
|
566 |
+
|
567 |
+
# Second residual connection.
|
568 |
+
if self.apply_residual_connection_post_layernorm:
|
569 |
+
residual = layernorm_output
|
570 |
+
else:
|
571 |
+
residual = layernorm_input
|
572 |
+
|
573 |
+
output = torch.nn.functional.dropout(mlp_output, p=self.hidden_dropout, training=self.training)
|
574 |
+
output = residual + output
|
575 |
+
|
576 |
+
return output, kv_cache
|
577 |
+
|
578 |
+
|
579 |
+
class GLMTransformer(torch.nn.Module):
|
580 |
+
"""Transformer class."""
|
581 |
+
|
582 |
+
def __init__(self, config: ChatGLMConfig, device=None):
|
583 |
+
super(GLMTransformer, self).__init__()
|
584 |
+
|
585 |
+
self.fp32_residual_connection = config.fp32_residual_connection
|
586 |
+
self.post_layer_norm = config.post_layer_norm
|
587 |
+
|
588 |
+
# Number of layers.
|
589 |
+
self.num_layers = config.num_layers
|
590 |
+
|
591 |
+
# Transformer layers.
|
592 |
+
def build_layer(layer_number):
|
593 |
+
return GLMBlock(config, layer_number, device=device)
|
594 |
+
|
595 |
+
self.layers = torch.nn.ModuleList([build_layer(i + 1) for i in range(self.num_layers)])
|
596 |
+
|
597 |
+
if self.post_layer_norm:
|
598 |
+
LayerNormFunc = RMSNorm if config.rmsnorm else LayerNorm
|
599 |
+
# Final layer norm before output.
|
600 |
+
self.final_layernorm = LayerNormFunc(config.hidden_size, eps=config.layernorm_epsilon, device=device,
|
601 |
+
dtype=config.torch_dtype)
|
602 |
+
|
603 |
+
self.gradient_checkpointing = False
|
604 |
+
|
605 |
+
def _get_layer(self, layer_number):
|
606 |
+
return self.layers[layer_number]
|
607 |
+
|
608 |
+
def forward(
|
609 |
+
self, hidden_states, attention_mask, rotary_pos_emb, kv_caches=None,
|
610 |
+
use_cache: Optional[bool] = True,
|
611 |
+
output_hidden_states: Optional[bool] = False,
|
612 |
+
):
|
613 |
+
if not kv_caches:
|
614 |
+
kv_caches = [None for _ in range(self.num_layers)]
|
615 |
+
presents = () if use_cache else None
|
616 |
+
if self.gradient_checkpointing and self.training:
|
617 |
+
if use_cache:
|
618 |
+
logger.warning_once(
|
619 |
+
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
|
620 |
+
)
|
621 |
+
use_cache = False
|
622 |
+
|
623 |
+
all_self_attentions = None
|
624 |
+
all_hidden_states = () if output_hidden_states else None
|
625 |
+
for index in range(self.num_layers):
|
626 |
+
if output_hidden_states:
|
627 |
+
all_hidden_states = all_hidden_states + (hidden_states,)
|
628 |
+
|
629 |
+
layer = self._get_layer(index)
|
630 |
+
if self.gradient_checkpointing and self.training:
|
631 |
+
layer_ret = torch.utils.checkpoint.checkpoint(
|
632 |
+
layer,
|
633 |
+
hidden_states,
|
634 |
+
attention_mask,
|
635 |
+
rotary_pos_emb,
|
636 |
+
kv_caches[index],
|
637 |
+
use_cache
|
638 |
+
)
|
639 |
+
else:
|
640 |
+
layer_ret = layer(
|
641 |
+
hidden_states,
|
642 |
+
attention_mask,
|
643 |
+
rotary_pos_emb,
|
644 |
+
kv_cache=kv_caches[index],
|
645 |
+
use_cache=use_cache
|
646 |
+
)
|
647 |
+
hidden_states, kv_cache = layer_ret
|
648 |
+
if use_cache:
|
649 |
+
presents = presents + (kv_cache,)
|
650 |
+
|
651 |
+
if output_hidden_states:
|
652 |
+
all_hidden_states = all_hidden_states + (hidden_states,)
|
653 |
+
|
654 |
+
# Final layer norm.
|
655 |
+
if self.post_layer_norm:
|
656 |
+
hidden_states = self.final_layernorm(hidden_states)
|
657 |
+
|
658 |
+
return hidden_states, presents, all_hidden_states, all_self_attentions
|
659 |
+
|
660 |
+
|
661 |
+
class ChatGLMPreTrainedModel(PreTrainedModel):
|
662 |
+
"""
|
663 |
+
An abstract class to handle weights initialization and
|
664 |
+
a simple interface for downloading and loading pretrained models.
|
665 |
+
"""
|
666 |
+
|
667 |
+
is_parallelizable = False
|
668 |
+
supports_gradient_checkpointing = True
|
669 |
+
config_class = ChatGLMConfig
|
670 |
+
base_model_prefix = "transformer"
|
671 |
+
_no_split_modules = ["GLMBlock"]
|
672 |
+
|
673 |
+
def _init_weights(self, module: nn.Module):
|
674 |
+
"""Initialize the weights."""
|
675 |
+
return
|
676 |
+
|
677 |
+
def get_masks(self, input_ids, past_key_values, padding_mask=None):
|
678 |
+
batch_size, seq_length = input_ids.shape
|
679 |
+
full_attention_mask = torch.ones(batch_size, seq_length, seq_length, device=input_ids.device)
|
680 |
+
full_attention_mask.tril_()
|
681 |
+
past_length = 0
|
682 |
+
if past_key_values:
|
683 |
+
past_length = past_key_values[0][0].shape[0]
|
684 |
+
if past_length:
|
685 |
+
full_attention_mask = torch.cat((torch.ones(batch_size, seq_length, past_length,
|
686 |
+
device=input_ids.device), full_attention_mask), dim=-1)
|
687 |
+
if padding_mask is not None:
|
688 |
+
full_attention_mask = full_attention_mask * padding_mask.unsqueeze(1)
|
689 |
+
if not past_length and padding_mask is not None:
|
690 |
+
full_attention_mask -= padding_mask.unsqueeze(-1) - 1
|
691 |
+
full_attention_mask = (full_attention_mask < 0.5).bool()
|
692 |
+
full_attention_mask.unsqueeze_(1)
|
693 |
+
return full_attention_mask
|
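
Note the inversion at the end of get_masks: because of the < 0.5 comparison, the returned mask is True where attention is blocked. With batch 1, seq_length 3, no cache and no padding, it effectively yields the strict upper triangle:

import torch

full = torch.ones(1, 3, 3).tril_()
full = (full < 0.5).unsqueeze(1)   # True = position is masked out
print(full[0, 0])
# tensor([[False,  True,  True],
#         [False, False,  True],
#         [False, False, False]])
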
694 |
+
|
695 |
+
def get_position_ids(self, input_ids, device):
|
696 |
+
batch_size, seq_length = input_ids.shape
|
697 |
+
position_ids = torch.arange(seq_length, dtype=torch.long, device=device).unsqueeze(0).repeat(batch_size, 1)
|
698 |
+
return position_ids
|
699 |
+
|
700 |
+
def _set_gradient_checkpointing(self, module, value=False):
|
701 |
+
if isinstance(module, GLMTransformer):
|
702 |
+
module.gradient_checkpointing = value
|
703 |
+
|
704 |
+
|
705 |
+
class Embedding(torch.nn.Module):
|
706 |
+
"""Language model embeddings."""
|
707 |
+
|
708 |
+
def __init__(self, config: ChatGLMConfig, device=None):
|
709 |
+
super(Embedding, self).__init__()
|
710 |
+
|
711 |
+
self.hidden_size = config.hidden_size
|
712 |
+
# Word embeddings (parallel).
|
713 |
+
self.word_embeddings = nn.Embedding(
|
714 |
+
config.padded_vocab_size,
|
715 |
+
self.hidden_size,
|
716 |
+
dtype=config.torch_dtype,
|
717 |
+
device=device
|
718 |
+
)
|
719 |
+
self.fp32_residual_connection = config.fp32_residual_connection
|
720 |
+
|
721 |
+
def forward(self, input_ids):
|
722 |
+
# Embeddings.
|
723 |
+
words_embeddings = self.word_embeddings(input_ids)
|
724 |
+
embeddings = words_embeddings
|
725 |
+
# Data format change to avoid explicit transposes: [b s h] --> [s b h].
|
726 |
+
embeddings = embeddings.transpose(0, 1).contiguous()
|
727 |
+
# If the fp32 residual connection flag is set, convert to float.
|
728 |
+
if self.fp32_residual_connection:
|
729 |
+
embeddings = embeddings.float()
|
730 |
+
return embeddings
|
731 |
+
|
732 |
+
|
733 |
+
class ChatGLMModel(ChatGLMPreTrainedModel):
|
734 |
+
def __init__(self, config: ChatGLMConfig, device=None, empty_init=True):
|
735 |
+
super().__init__(config)
|
736 |
+
if empty_init:
|
737 |
+
init_method = skip_init
|
738 |
+
else:
|
739 |
+
init_method = default_init
|
740 |
+
init_kwargs = {}
|
741 |
+
if device is not None:
|
742 |
+
init_kwargs["device"] = device
|
743 |
+
self.embedding = init_method(Embedding, config, **init_kwargs)
|
744 |
+
self.num_layers = config.num_layers
|
745 |
+
self.multi_query_group_num = config.multi_query_group_num
|
746 |
+
self.kv_channels = config.kv_channels
|
747 |
+
|
748 |
+
# Rotary positional embeddings
|
749 |
+
self.seq_length = config.seq_length
|
750 |
+
rotary_dim = (
|
751 |
+
config.hidden_size // config.num_attention_heads if config.kv_channels is None else config.kv_channels
|
752 |
+
)
|
753 |
+
|
754 |
+
self.rotary_pos_emb = RotaryEmbedding(rotary_dim // 2, original_impl=config.original_rope, device=device,
|
755 |
+
dtype=config.torch_dtype)
|
756 |
+
self.encoder = init_method(GLMTransformer, config, **init_kwargs)
|
757 |
+
self.output_layer = init_method(nn.Linear, config.hidden_size, config.padded_vocab_size, bias=False,
|
758 |
+
dtype=config.torch_dtype, **init_kwargs)
|
759 |
+
self.pre_seq_len = config.pre_seq_len
|
760 |
+
self.prefix_projection = config.prefix_projection
|
761 |
+
if self.pre_seq_len is not None:
|
762 |
+
for param in self.parameters():
|
763 |
+
param.requires_grad = False
|
764 |
+
self.prefix_tokens = torch.arange(self.pre_seq_len).long()
|
765 |
+
self.prefix_encoder = PrefixEncoder(config)
|
766 |
+
self.dropout = torch.nn.Dropout(0.1)
|
767 |
+
|
768 |
+
def get_input_embeddings(self):
|
769 |
+
return self.embedding.word_embeddings
|
770 |
+
|
771 |
+
def get_prompt(self, batch_size, device, dtype=torch.half):
|
772 |
+
prefix_tokens = self.prefix_tokens.unsqueeze(0).expand(batch_size, -1).to(device)
|
773 |
+
past_key_values = self.prefix_encoder(prefix_tokens).type(dtype)
|
774 |
+
past_key_values = past_key_values.view(
|
775 |
+
batch_size,
|
776 |
+
self.pre_seq_len,
|
777 |
+
self.num_layers * 2,
|
778 |
+
self.multi_query_group_num,
|
779 |
+
self.kv_channels
|
780 |
+
)
|
781 |
+
# seq_len, b, nh, hidden_size
|
782 |
+
past_key_values = self.dropout(past_key_values)
|
783 |
+
past_key_values = past_key_values.permute([2, 1, 0, 3, 4]).split(2)
|
784 |
+
return past_key_values
|
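
get_prompt reshapes the learned prefix into num_layers (key, value) pairs laid out like a real KV cache. A shape sketch with toy sizes (all hypothetical):

import torch

batch, pre_seq_len, num_layers, groups, kv_channels = 1, 4, 2, 2, 8
flat = torch.randn(batch, pre_seq_len, num_layers * 2 * groups * kv_channels)
kv = flat.view(batch, pre_seq_len, num_layers * 2, groups, kv_channels)
kv = kv.permute([2, 1, 0, 3, 4]).split(2)   # num_layers chunks of (k, v)
assert len(kv) == num_layers
assert kv[0].shape == (2, pre_seq_len, batch, groups, kv_channels)
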
785 |
+
|
786 |
+
def forward(
|
787 |
+
self,
|
788 |
+
input_ids,
|
789 |
+
position_ids: Optional[torch.Tensor] = None,
|
790 |
+
attention_mask: Optional[torch.BoolTensor] = None,
|
791 |
+
full_attention_mask: Optional[torch.BoolTensor] = None,
|
792 |
+
past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None,
|
793 |
+
inputs_embeds: Optional[torch.Tensor] = None,
|
794 |
+
use_cache: Optional[bool] = None,
|
795 |
+
output_hidden_states: Optional[bool] = None,
|
796 |
+
return_dict: Optional[bool] = None,
|
797 |
+
):
|
798 |
+
output_hidden_states = (
|
799 |
+
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
800 |
+
)
|
801 |
+
use_cache = use_cache if use_cache is not None else self.config.use_cache
|
802 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
803 |
+
|
804 |
+
batch_size, seq_length = input_ids.shape
|
805 |
+
|
806 |
+
if inputs_embeds is None:
|
807 |
+
inputs_embeds = self.embedding(input_ids)
|
808 |
+
|
809 |
+
if self.pre_seq_len is not None:
|
810 |
+
if past_key_values is None:
|
811 |
+
past_key_values = self.get_prompt(batch_size=batch_size, device=input_ids.device,
|
812 |
+
dtype=inputs_embeds.dtype)
|
813 |
+
if attention_mask is not None:
|
814 |
+
attention_mask = torch.cat([attention_mask.new_ones((batch_size, self.pre_seq_len)),
|
815 |
+
attention_mask], dim=-1)
|
816 |
+
|
817 |
+
if full_attention_mask is None:
|
818 |
+
if (attention_mask is not None and not attention_mask.all()) or (past_key_values and seq_length != 1):
|
819 |
+
full_attention_mask = self.get_masks(input_ids, past_key_values, padding_mask=attention_mask)
|
820 |
+
|
821 |
+
# Rotary positional embeddings
|
822 |
+
rotary_pos_emb = self.rotary_pos_emb(self.seq_length)
|
823 |
+
if position_ids is not None:
|
824 |
+
rotary_pos_emb = rotary_pos_emb[position_ids]
|
825 |
+
else:
|
826 |
+
rotary_pos_emb = rotary_pos_emb[None, :seq_length]
|
827 |
+
rotary_pos_emb = rotary_pos_emb.transpose(0, 1).contiguous()
|
828 |
+
|
829 |
+
# Run encoder.
|
830 |
+
hidden_states, presents, all_hidden_states, all_self_attentions = self.encoder(
|
831 |
+
inputs_embeds, full_attention_mask, rotary_pos_emb=rotary_pos_emb,
|
832 |
+
kv_caches=past_key_values, use_cache=use_cache, output_hidden_states=output_hidden_states
|
833 |
+
)
|
834 |
+
|
835 |
+
if not return_dict:
|
836 |
+
return tuple(v for v in [hidden_states, presents, all_hidden_states, all_self_attentions] if v is not None)
|
837 |
+
|
838 |
+
return BaseModelOutputWithPast(
|
839 |
+
last_hidden_state=hidden_states,
|
840 |
+
past_key_values=presents,
|
841 |
+
hidden_states=all_hidden_states,
|
842 |
+
attentions=all_self_attentions,
|
843 |
+
)
|
844 |
+
|
845 |
+
def quantize(self, weight_bit_width: int):
|
846 |
+
from .quantization import quantize
|
847 |
+
quantize(self.encoder, weight_bit_width)
|
848 |
+
return self
|
849 |
+
|
850 |
+
|
851 |
+
class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
|
852 |
+
def __init__(self, config: ChatGLMConfig, empty_init=True, device=None):
|
853 |
+
super().__init__(config)
|
854 |
+
|
855 |
+
self.max_sequence_length = config.max_length
|
856 |
+
self.transformer = ChatGLMModel(config, empty_init=empty_init, device=device)
|
857 |
+
self.config = config
|
858 |
+
self.quantized = False
|
859 |
+
|
860 |
+
if self.config.quantization_bit:
|
861 |
+
self.quantize(self.config.quantization_bit, empty_init=True)
|
862 |
+
|
863 |
+
def _update_model_kwargs_for_generation(
|
864 |
+
self,
|
865 |
+
outputs: ModelOutput,
|
866 |
+
model_kwargs: Dict[str, Any],
|
867 |
+
is_encoder_decoder: bool = False,
|
868 |
+
standardize_cache_format: bool = False,
|
869 |
+
) -> Dict[str, Any]:
|
870 |
+
# update past_key_values
|
871 |
+
model_kwargs["past_key_values"] = self._extract_past_from_model_output(
|
872 |
+
outputs, standardize_cache_format=standardize_cache_format
|
873 |
+
)
|
874 |
+
|
875 |
+
# update attention mask
|
876 |
+
if "attention_mask" in model_kwargs:
|
877 |
+
attention_mask = model_kwargs["attention_mask"]
|
878 |
+
model_kwargs["attention_mask"] = torch.cat(
|
879 |
+
[attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))], dim=-1
|
880 |
+
)
|
881 |
+
|
882 |
+
# update position ids
|
883 |
+
if "position_ids" in model_kwargs:
|
884 |
+
position_ids = model_kwargs["position_ids"]
|
885 |
+
new_position_id = position_ids[..., -1:].clone()
|
886 |
+
new_position_id += 1
|
887 |
+
model_kwargs["position_ids"] = torch.cat(
|
888 |
+
[position_ids, new_position_id], dim=-1
|
889 |
+
)
|
890 |
+
|
891 |
+
model_kwargs["is_first_forward"] = False
|
892 |
+
return model_kwargs
|
893 |
+
|
894 |
+
def prepare_inputs_for_generation(
|
895 |
+
self,
|
896 |
+
input_ids: torch.LongTensor,
|
897 |
+
past_key_values: Optional[torch.Tensor] = None,
|
898 |
+
attention_mask: Optional[torch.Tensor] = None,
|
899 |
+
position_ids: Optional[torch.Tensor] = None,
|
900 |
+
use_cache: Optional[bool] = None,
|
901 |
+
is_first_forward: bool = True,
|
902 |
+
**kwargs
|
903 |
+
) -> dict:
|
904 |
+
# only last token for input_ids if past is not None
|
905 |
+
if position_ids is None:
|
906 |
+
position_ids = self.get_position_ids(input_ids, device=input_ids.device)
|
907 |
+
if not is_first_forward:
|
908 |
+
if past_key_values is not None:
|
909 |
+
position_ids = position_ids[..., -1:]
|
910 |
+
input_ids = input_ids[:, -1:]
|
911 |
+
return {
|
912 |
+
"input_ids": input_ids,
|
913 |
+
"past_key_values": past_key_values,
|
914 |
+
"position_ids": position_ids,
|
915 |
+
"attention_mask": attention_mask,
|
916 |
+
"return_last_logit": True,
|
917 |
+
"use_cache": use_cache
|
918 |
+
}
|
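
After the first forward, only the newest token and its position are fed in; everything earlier is carried by past_key_values. A sketch of the per-step bookkeeping that prepare_inputs_for_generation and _update_model_kwargs_for_generation perform together (token ids hypothetical):

import torch

input_ids = torch.tensor([[64790, 64792, 30910]])    # hypothetical ids
position_ids = torch.arange(input_ids.shape[1]).unsqueeze(0)

step_ids = input_ids[:, -1:]        # decode step: last token only
step_pos = position_ids[..., -1:]   # and its position

next_token = torch.tensor([[13]])   # whatever the model samples
input_ids = torch.cat([input_ids, next_token], dim=-1)
position_ids = torch.cat([position_ids, position_ids[..., -1:] + 1], dim=-1)
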
919 |
+
|
920 |
+
def forward(
|
921 |
+
self,
|
922 |
+
input_ids: Optional[torch.Tensor] = None,
|
923 |
+
position_ids: Optional[torch.Tensor] = None,
|
924 |
+
attention_mask: Optional[torch.Tensor] = None,
|
925 |
+
past_key_values: Optional[Tuple[torch.FloatTensor]] = None,
|
926 |
+
inputs_embeds: Optional[torch.Tensor] = None,
|
927 |
+
labels: Optional[torch.Tensor] = None,
|
928 |
+
use_cache: Optional[bool] = None,
|
929 |
+
output_attentions: Optional[bool] = None,
|
930 |
+
output_hidden_states: Optional[bool] = None,
|
931 |
+
return_dict: Optional[bool] = None,
|
932 |
+
return_last_logit: Optional[bool] = False,
|
933 |
+
):
|
934 |
+
use_cache = use_cache if use_cache is not None else self.config.use_cache
|
935 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
936 |
+
|
937 |
+
transformer_outputs = self.transformer(
|
938 |
+
input_ids=input_ids,
|
939 |
+
position_ids=position_ids,
|
940 |
+
attention_mask=attention_mask,
|
941 |
+
past_key_values=past_key_values,
|
942 |
+
inputs_embeds=inputs_embeds,
|
943 |
+
use_cache=use_cache,
|
944 |
+
output_hidden_states=output_hidden_states,
|
945 |
+
return_dict=return_dict,
|
946 |
+
)
|
947 |
+
|
948 |
+
hidden_states = transformer_outputs[0]
|
949 |
+
if return_last_logit:
|
950 |
+
hidden_states = hidden_states[-1:]
|
951 |
+
lm_logits = self.transformer.output_layer(hidden_states)
|
952 |
+
lm_logits = lm_logits.transpose(0, 1).contiguous()
|
953 |
+
|
954 |
+
loss = None
|
955 |
+
if labels is not None:
|
956 |
+
lm_logits = lm_logits.to(torch.float32)
|
957 |
+
|
958 |
+
# Shift so that tokens < n predict n
|
959 |
+
shift_logits = lm_logits[..., :-1, :].contiguous()
|
960 |
+
shift_labels = labels[..., 1:].contiguous()
|
961 |
+
# Flatten the tokens
|
962 |
+
loss_fct = CrossEntropyLoss(ignore_index=-100)
|
963 |
+
loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
|
964 |
+
|
965 |
+
lm_logits = lm_logits.to(hidden_states.dtype)
|
966 |
+
loss = loss.to(hidden_states.dtype)
|
967 |
+
|
968 |
+
if not return_dict:
|
969 |
+
output = (lm_logits,) + transformer_outputs[1:]
|
970 |
+
return ((loss,) + output) if loss is not None else output
|
971 |
+
|
972 |
+
return CausalLMOutputWithPast(
|
973 |
+
loss=loss,
|
974 |
+
logits=lm_logits,
|
975 |
+
past_key_values=transformer_outputs.past_key_values,
|
976 |
+
hidden_states=transformer_outputs.hidden_states,
|
977 |
+
attentions=transformer_outputs.attentions,
|
978 |
+
)
|
979 |
+
|
980 |
+
@staticmethod
|
981 |
+
def _reorder_cache(
|
982 |
+
past: Tuple[Tuple[torch.Tensor, torch.Tensor], ...], beam_idx: torch.LongTensor
|
983 |
+
) -> Tuple[Tuple[torch.Tensor, torch.Tensor], ...]:
|
984 |
+
"""
|
985 |
+
This function is used to re-order the `past_key_values` cache if [`~PreTrainedModel.beam_search`] or
|
986 |
+
[`~PreTrainedModel.beam_sample`] is called. This is required to match `past_key_values` with the correct
|
987 |
+
beam_idx at every generation step.
|
988 |
+
|
989 |
+
Output shares the same memory storage as `past`.
|
990 |
+
"""
|
991 |
+
return tuple(
|
992 |
+
(
|
993 |
+
layer_past[0].index_select(1, beam_idx.to(layer_past[0].device)),
|
994 |
+
layer_past[1].index_select(1, beam_idx.to(layer_past[1].device)),
|
995 |
+
)
|
996 |
+
for layer_past in past
|
997 |
+
)
|
998 |
+
|
999 |
+
def process_response(self, response):
|
1000 |
+
response = response.strip()
|
1001 |
+
response = response.replace("[[训练时间]]", "2023年")  # fill the "[[训练时间]]" ("training time") placeholder with "2023年" ("2023")
|
1002 |
+
return response
|
1003 |
+
|
1004 |
+
def build_inputs(self, tokenizer, query: str, history: List[Tuple[str, str]] = None):
|
1005 |
+
prompt = tokenizer.build_prompt(query, history=history)
|
1006 |
+
inputs = tokenizer([prompt], return_tensors="pt")
|
1007 |
+
inputs = inputs.to(self.device)
|
1008 |
+
return inputs
|
1009 |
+
|
1010 |
+
def build_stream_inputs(self, tokenizer, query: str, history: List[Tuple[str, str]] = None):
|
1011 |
+
if history:
|
1012 |
+
prompt = "\n\n[Round {}]\n\n问:{}\n\n答:".format(len(history) + 1, query)  # "问" = question, "答" = answer; template kept verbatim
|
1013 |
+
input_ids = tokenizer.encode(prompt, add_special_tokens=False)
|
1014 |
+
input_ids = input_ids[1:]
|
1015 |
+
inputs = tokenizer.batch_encode_plus([(input_ids, None)], return_tensors="pt", add_special_tokens=False)
|
1016 |
+
else:
|
1017 |
+
prompt = "[Round {}]\n\n问:{}\n\n答:".format(len(history) + 1, query)
|
1018 |
+
inputs = tokenizer([prompt], return_tensors="pt")
|
1019 |
+
inputs = inputs.to(self.device)
|
1020 |
+
return inputs
|
1021 |
+
|
1022 |
+
@torch.inference_mode()
|
1023 |
+
def chat(self, tokenizer, query: str, history: List[Tuple[str, str]] = None, max_length: int = 8192, num_beams=1,
|
1024 |
+
do_sample=True, top_p=0.8, temperature=0.8, logits_processor=None, **kwargs):
|
1025 |
+
if history is None:
|
1026 |
+
history = []
|
1027 |
+
if logits_processor is None:
|
1028 |
+
logits_processor = LogitsProcessorList()
|
1029 |
+
logits_processor.append(InvalidScoreLogitsProcessor())
|
1030 |
+
gen_kwargs = {"max_length": max_length, "num_beams": num_beams, "do_sample": do_sample, "top_p": top_p,
|
1031 |
+
"temperature": temperature, "logits_processor": logits_processor, **kwargs}
|
1032 |
+
inputs = self.build_inputs(tokenizer, query, history=history)
|
1033 |
+
outputs = self.generate(**inputs, **gen_kwargs)
|
1034 |
+
outputs = outputs.tolist()[0][len(inputs["input_ids"][0]):]
|
1035 |
+
response = tokenizer.decode(outputs)
|
1036 |
+
response = self.process_response(response)
|
1037 |
+
history = history + [(query, response)]
|
1038 |
+
return response, history
|
1039 |
+
|
1040 |
+
@torch.inference_mode()
|
1041 |
+
def stream_chat(self, tokenizer, query: str, history: List[Tuple[str, str]] = None, past_key_values=None,
|
1042 |
+
max_length: int = 8192, do_sample=True, top_p=0.8, temperature=0.8, logits_processor=None,
|
1043 |
+
return_past_key_values=False, **kwargs):
|
1044 |
+
if history is None:
|
1045 |
+
history = []
|
1046 |
+
if logits_processor is None:
|
1047 |
+
logits_processor = LogitsProcessorList()
|
1048 |
+
logits_processor.append(InvalidScoreLogitsProcessor())
|
1049 |
+
gen_kwargs = {"max_length": max_length, "do_sample": do_sample, "top_p": top_p,
|
1050 |
+
"temperature": temperature, "logits_processor": logits_processor, **kwargs}
|
1051 |
+
if past_key_values is None and not return_past_key_values:
|
1052 |
+
inputs = self.build_inputs(tokenizer, query, history=history)
|
1053 |
+
else:
|
1054 |
+
inputs = self.build_stream_inputs(tokenizer, query, history=history)
|
1055 |
+
if past_key_values is not None:
|
1056 |
+
past_length = past_key_values[0][0].shape[0]
|
1057 |
+
if self.transformer.pre_seq_len is not None:
|
1058 |
+
past_length -= self.transformer.pre_seq_len
|
1059 |
+
inputs.position_ids += past_length
|
1060 |
+
attention_mask = inputs.attention_mask
|
1061 |
+
attention_mask = torch.cat((attention_mask.new_ones(1, past_length), attention_mask), dim=1)
|
1062 |
+
inputs['attention_mask'] = attention_mask
|
1063 |
+
for outputs in self.stream_generate(**inputs, past_key_values=past_key_values,
|
1064 |
+
return_past_key_values=return_past_key_values, **gen_kwargs):
|
1065 |
+
if return_past_key_values:
|
1066 |
+
outputs, past_key_values = outputs
|
1067 |
+
outputs = outputs.tolist()[0][len(inputs["input_ids"][0]):]
|
1068 |
+
response = tokenizer.decode(outputs)
|
1069 |
+
if response and response[-1] != "�":  # skip partial yields that end in an incomplete UTF-8 sequence
|
1070 |
+
response = self.process_response(response)
|
1071 |
+
new_history = history + [(query, response)]
|
1072 |
+
if return_past_key_values:
|
1073 |
+
yield response, new_history, past_key_values
|
1074 |
+
else:
|
1075 |
+
yield response, new_history
|
1076 |
+
|
1077 |
+
@torch.inference_mode()
|
1078 |
+
def stream_generate(
|
1079 |
+
self,
|
1080 |
+
input_ids,
|
1081 |
+
generation_config: Optional[GenerationConfig] = None,
|
1082 |
+
logits_processor: Optional[LogitsProcessorList] = None,
|
1083 |
+
stopping_criteria: Optional[StoppingCriteriaList] = None,
|
1084 |
+
prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], List[int]]] = None,
|
1085 |
+
return_past_key_values=False,
|
1086 |
+
**kwargs,
|
1087 |
+
):
|
1088 |
+
batch_size, input_ids_seq_length = input_ids.shape[0], input_ids.shape[-1]
|
1089 |
+
|
1090 |
+
if generation_config is None:
|
1091 |
+
generation_config = self.generation_config
|
1092 |
+
generation_config = copy.deepcopy(generation_config)
|
1093 |
+
model_kwargs = generation_config.update(**kwargs)
|
1094 |
+
model_kwargs["use_cache"] = generation_config.use_cache
|
1095 |
+
bos_token_id, eos_token_id = generation_config.bos_token_id, generation_config.eos_token_id
|
1096 |
+
|
1097 |
+
if isinstance(eos_token_id, int):
|
1098 |
+
eos_token_id = [eos_token_id]
|
1099 |
+
|
1100 |
+
has_default_max_length = kwargs.get("max_length") is None and generation_config.max_length is not None
|
1101 |
+
if has_default_max_length and generation_config.max_new_tokens is None:
|
1102 |
+
warnings.warn(
|
1103 |
+
f"Using `max_length`'s default ({generation_config.max_length}) to control the generation length. "
|
1104 |
+
"This behaviour is deprecated and will be removed from the config in v5 of Transformers -- we"
|
1105 |
+
" recommend using `max_new_tokens` to control the maximum length of the generation.",
|
1106 |
+
UserWarning,
|
1107 |
+
)
|
1108 |
+
elif generation_config.max_new_tokens is not None:
|
1109 |
+
generation_config.max_length = generation_config.max_new_tokens + input_ids_seq_length
|
1110 |
+
if not has_default_max_length:
|
1111 |
+
logger.warning(
|
1112 |
+
f"Both `max_new_tokens` (={generation_config.max_new_tokens}) and `max_length`(="
|
1113 |
+
f"{generation_config.max_length}) seem to have been set. `max_new_tokens` will take precedence. "
|
1114 |
+
"Please refer to the documentation for more information. "
|
1115 |
+
"(https://huggingface.co/docs/transformers/main/en/main_classes/text_generation)",
|
1116 |
+
UserWarning,
|
1117 |
+
)
|
1118 |
+
|
1119 |
+
if input_ids_seq_length >= generation_config.max_length:
|
1120 |
+
input_ids_string = "decoder_input_ids" if self.config.is_encoder_decoder else "input_ids"
|
1121 |
+
logger.warning(
|
1122 |
+
f"Input length of {input_ids_string} is {input_ids_seq_length}, but `max_length` is set to"
|
1123 |
+
f" {generation_config.max_length}. This can lead to unexpected behavior. You should consider"
|
1124 |
+
" increasing `max_new_tokens`."
|
1125 |
+
)
|
1126 |
+
|
1127 |
+
# 2. Set generation parameters if not already defined
|
1128 |
+
logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
|
1129 |
+
stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()
|
1130 |
+
|
1131 |
+
logits_processor = self._get_logits_processor(
|
1132 |
+
generation_config=generation_config,
|
1133 |
+
input_ids_seq_length=input_ids_seq_length,
|
1134 |
+
encoder_input_ids=input_ids,
|
1135 |
+
prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
|
1136 |
+
logits_processor=logits_processor,
|
1137 |
+
)
|
1138 |
+
|
1139 |
+
stopping_criteria = self._get_stopping_criteria(
|
1140 |
+
generation_config=generation_config, stopping_criteria=stopping_criteria
|
1141 |
+
)
|
1142 |
+
logits_warper = self._get_logits_warper(generation_config)
|
1143 |
+
|
1144 |
+
unfinished_sequences = input_ids.new(input_ids.shape[0]).fill_(1)
|
1145 |
+
scores = None
|
1146 |
+
while True:
|
1147 |
+
model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
|
1148 |
+
# forward pass to get next token
|
1149 |
+
outputs = self(
|
1150 |
+
**model_inputs,
|
1151 |
+
return_dict=True,
|
1152 |
+
output_attentions=False,
|
1153 |
+
output_hidden_states=False,
|
1154 |
+
)
|
1155 |
+
|
1156 |
+
next_token_logits = outputs.logits[:, -1, :]
|
1157 |
+
|
1158 |
+
# pre-process distribution
|
1159 |
+
next_token_scores = logits_processor(input_ids, next_token_logits)
|
1160 |
+
next_token_scores = logits_warper(input_ids, next_token_scores)
|
1161 |
+
|
1162 |
+
# sample
|
1163 |
+
probs = nn.functional.softmax(next_token_scores, dim=-1)
|
1164 |
+
if generation_config.do_sample:
|
1165 |
+
next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)
|
1166 |
+
else:
|
1167 |
+
next_tokens = torch.argmax(probs, dim=-1)
|
1168 |
+
|
1169 |
+
# update generated ids, model inputs, and length for next step
|
1170 |
+
input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
|
1171 |
+
model_kwargs = self._update_model_kwargs_for_generation(
|
1172 |
+
outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder
|
1173 |
+
)
|
1174 |
+
unfinished_sequences = unfinished_sequences.mul((sum(next_tokens != i for i in eos_token_id)).long())
|
1175 |
+
if return_past_key_values:
|
1176 |
+
yield input_ids, outputs.past_key_values
|
1177 |
+
else:
|
1178 |
+
yield input_ids
|
1179 |
+
# stop when each sentence is finished, or if we exceed the maximum length
|
1180 |
+
if unfinished_sequences.max() == 0 or stopping_criteria(input_ids, scores):
|
1181 |
+
break
|
1182 |
+
|
1183 |
+
def quantize(self, bits: int, empty_init=False, device=None, **kwargs):
|
1184 |
+
if bits == 0:
|
1185 |
+
return
|
1186 |
+
|
1187 |
+
from .quantization import quantize
|
1188 |
+
|
1189 |
+
if self.quantized:
|
1190 |
+
logger.info("Already quantized.")
|
1191 |
+
return self
|
1192 |
+
|
1193 |
+
self.quantized = True
|
1194 |
+
|
1195 |
+
self.config.quantization_bit = bits
|
1196 |
+
|
1197 |
+
self.transformer.encoder = quantize(self.transformer.encoder, bits, empty_init=empty_init, device=device,
|
1198 |
+
**kwargs)
|
1199 |
+
return self
|
1200 |
+
|
1201 |
+
|
1202 |
+
class ChatGLMForSequenceClassification(ChatGLMPreTrainedModel):
|
1203 |
+
def __init__(self, config: ChatGLMConfig, empty_init=True, device=None):
|
1204 |
+
super().__init__(config)
|
1205 |
+
|
1206 |
+
self.num_labels = config.num_labels
|
1207 |
+
self.transformer = ChatGLMModel(config, empty_init=empty_init, device=device)
|
1208 |
+
|
1209 |
+
self.classifier_head = nn.Linear(config.hidden_size, config.num_labels, bias=True, dtype=torch.half)
|
1210 |
+
if config.classifier_dropout is not None:
|
1211 |
+
self.dropout = nn.Dropout(config.classifier_dropout)
|
1212 |
+
else:
|
1213 |
+
self.dropout = None
|
1214 |
+
self.config = config
|
1215 |
+
|
1216 |
+
if self.config.quantization_bit:
|
1217 |
+
self.quantize(self.config.quantization_bit, empty_init=True)
|
1218 |
+
|
1219 |
+
def forward(
|
1220 |
+
self,
|
1221 |
+
input_ids: Optional[torch.LongTensor] = None,
|
1222 |
+
position_ids: Optional[torch.LongTensor] = None,
|
1223 |
+
attention_mask: Optional[torch.Tensor] = None,
|
1224 |
+
full_attention_mask: Optional[torch.Tensor] = None,
|
1225 |
+
past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None,
|
1226 |
+
inputs_embeds: Optional[torch.LongTensor] = None,
|
1227 |
+
labels: Optional[torch.LongTensor] = None,
|
1228 |
+
use_cache: Optional[bool] = None,
|
1229 |
+
output_hidden_states: Optional[bool] = None,
|
1230 |
+
return_dict: Optional[bool] = None,
|
1231 |
+
) -> Union[Tuple[torch.Tensor, ...], SequenceClassifierOutputWithPast]:
|
1232 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
1233 |
+
|
1234 |
+
transformer_outputs = self.transformer(
|
1235 |
+
input_ids=input_ids,
|
1236 |
+
position_ids=position_ids,
|
1237 |
+
attention_mask=attention_mask,
|
1238 |
+
full_attention_mask=full_attention_mask,
|
1239 |
+
past_key_values=past_key_values,
|
1240 |
+
inputs_embeds=inputs_embeds,
|
1241 |
+
use_cache=use_cache,
|
1242 |
+
output_hidden_states=output_hidden_states,
|
1243 |
+
return_dict=return_dict,
|
1244 |
+
)
|
1245 |
+
|
1246 |
+
hidden_states = transformer_outputs[0]
|
1247 |
+
pooled_hidden_states = hidden_states[-1]
|
1248 |
+
if self.dropout is not None:
|
1249 |
+
pooled_hidden_states = self.dropout(pooled_hidden_states)
|
1250 |
+
logits = self.classifier_head(pooled_hidden_states)
|
1251 |
+
|
1252 |
+
loss = None
|
1253 |
+
if labels is not None:
|
1254 |
+
if self.config.problem_type is None:
|
1255 |
+
if self.num_labels == 1:
|
1256 |
+
self.config.problem_type = "regression"
|
1257 |
+
elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
|
1258 |
+
self.config.problem_type = "single_label_classification"
|
1259 |
+
else:
|
1260 |
+
self.config.problem_type = "multi_label_classification"
|
1261 |
+
|
1262 |
+
if self.config.problem_type == "regression":
|
1263 |
+
loss_fct = MSELoss()
|
1264 |
+
if self.num_labels == 1:
|
1265 |
+
loss = loss_fct(logits.squeeze().float(), labels.squeeze())
|
1266 |
+
else:
|
1267 |
+
loss = loss_fct(logits.float(), labels)
|
1268 |
+
elif self.config.problem_type == "single_label_classification":
|
1269 |
+
loss_fct = CrossEntropyLoss()
|
1270 |
+
loss = loss_fct(logits.view(-1, self.num_labels).float(), labels.view(-1))
|
1271 |
+
elif self.config.problem_type == "multi_label_classification":
|
1272 |
+
loss_fct = BCEWithLogitsLoss()
|
1273 |
+
loss = loss_fct(logits.float(), labels.view(-1, self.num_labels))
|
1274 |
+
|
1275 |
+
if not return_dict:
|
1276 |
+
output = (logits,) + transformer_outputs[1:]
|
1277 |
+
return ((loss,) + output) if loss is not None else output
|
1278 |
+
|
1279 |
+
return SequenceClassifierOutputWithPast(
|
1280 |
+
loss=loss,
|
1281 |
+
logits=logits,
|
1282 |
+
past_key_values=transformer_outputs.past_key_values,
|
1283 |
+
hidden_states=transformer_outputs.hidden_states,
|
1284 |
+
attentions=transformer_outputs.attentions,
|
1285 |
+
)
|
ChatGLM2/demo/CMakeLists.txt
ADDED
@@ -0,0 +1,33 @@
1 |
+
cmake_minimum_required(VERSION 2.8)
|
2 |
+
project(chatglm)
|
3 |
+
|
4 |
+
set(CMAKE_EXPORT_COMPILE_COMMANDS ON CACHE INTERNAL "")
|
5 |
+
|
6 |
+
if (NOT DEFINED TARGET_ARCH)
|
7 |
+
set(TARGET_ARCH pcie)
|
8 |
+
endif()
|
9 |
+
|
10 |
+
include_directories(${PROJECT_SOURCE_DIR}/../support/include)
|
11 |
+
|
12 |
+
if (${CMAKE_HOST_SYSTEM_PROCESSOR} STREQUAL "aarch64")
|
13 |
+
add_definitions(-DSOC_TARGET)
|
14 |
+
link_directories(${PROJECT_SOURCE_DIR}/../support/lib_soc)
|
15 |
+
message("SoC mode, starting......")
|
16 |
+
elseif (${TARGET_ARCH} STREQUAL "pcie")
|
17 |
+
add_definitions(-DPCIE_TARGET)
|
18 |
+
link_directories(${PROJECT_SOURCE_DIR}/../support/lib_pcie)
|
19 |
+
message("PCIE mode, starting......")
|
20 |
+
elseif (${TARGET_ARCH} STREQUAL "soc")
|
21 |
+
add_definitions(-DSOC_TARGET)
|
22 |
+
set(CMAKE_C_COMPILER /opt/aarch64-linux-gnu-7.5.0/bin/aarch64-linux-gnu-gcc)
|
23 |
+
set(CMAKE_ASM_COMPILER /opt/aarch64-linux-gnu-7.5.0/bin/aarch64-linux-gnu-gcc)
|
24 |
+
set(CMAKE_CXX_COMPILER /opt/aarch64-linux-gnu-7.5.0/bin/aarch64-linux-gnu-g++)
|
25 |
+
link_directories(${PROJECT_SOURCE_DIR}/../support/lib_soc)
|
26 |
+
message("SoC mode, starting......")
|
27 |
+
endif()
|
28 |
+
|
29 |
+
add_definitions(-DDEBUG --std=c++17 -fPIC -Wall -Werror)
|
30 |
+
set(CMAKE_BUILD_TYPE "Debug")
|
31 |
+
|
32 |
+
add_executable(chatglm demo.cpp)
|
33 |
+
target_link_libraries(chatglm bmrt bmlib sentencepiece)
|
ChatGLM2/demo/demo.cpp
ADDED
@@ -0,0 +1,609 @@
1 |
+
//===----------------------------------------------------------------------===//
|
2 |
+
//
|
3 |
+
// Copyright (C) 2023 Sophgo Technologies Inc. All rights reserved.
|
4 |
+
//
|
5 |
+
// TPU-MLIR is licensed under the 2-Clause BSD License except for the
|
6 |
+
// third-party components.
|
7 |
+
//
|
8 |
+
//===----------------------------------------------------------------------===//
|
9 |
+
|
10 |
+
#include <iostream>
|
11 |
+
#include <cstdlib>
|
12 |
+
#include <vector>
|
13 |
+
#include <assert.h>
|
14 |
+
#include <chrono>
|
15 |
+
#include <algorithm>
|
16 |
+
#include "memory.h"
|
17 |
+
#include "sentencepiece/sentencepiece_processor.h"
|
18 |
+
#include "bmruntime_interface.h"
|
19 |
+
#include <getopt.h>
|
20 |
+
#include <stdio.h>
|
21 |
+
#include <inttypes.h>
|
22 |
+
|
23 |
+
static const uint16_t ATTENTION_MASK = 0xF0E2;  // fp16 bit pattern, roughly -1e4; used to mask out attention scores
|
24 |
+
|
25 |
+
class ChatGLM {
|
26 |
+
public:
|
27 |
+
void init(const std::vector<int> &devid, std::string model_path, std::string tokenizer_path);
|
28 |
+
void chat();
|
29 |
+
void deinit();
|
30 |
+
|
31 |
+
private:
|
32 |
+
void answer(const std::string &input_str);
|
33 |
+
void tokenizer_encode(const std::string &input_str, std::vector<int> &tokens);
|
34 |
+
int forward_first(std::vector<int> &tokens);
|
35 |
+
int forward_next(int cur_token);
|
36 |
+
void move2end(const bm_tensor_t &kv);
|
37 |
+
void load_sentencepiece(std::string tokenizer_path);
|
38 |
+
|
39 |
+
private:
|
40 |
+
std::vector<bm_handle_t> handles;
|
41 |
+
bm_handle_t bm_handle;
|
42 |
+
void *p_bmrt;
|
43 |
+
sentencepiece::SentencePieceProcessor sentencepiece;
|
44 |
+
const bm_net_info_t *net_embed;
|
45 |
+
const bm_net_info_t *net_embed_cache;
|
46 |
+
const bm_net_info_t *net_lm;
|
47 |
+
std::vector<const bm_net_info_t *> net_blocks;
|
48 |
+
std::vector<const bm_net_info_t *> net_blocks_cache;
|
49 |
+
std::vector<bm_tensor_t> inputs_embed_512, outputs_embed_512;
|
50 |
+
std::vector<bm_tensor_t> inputs_pid, next_pid, inputs_attention, next_attention;
|
51 |
+
std::vector<std::vector<bm_tensor_t>> past_key, past_value;
|
52 |
+
std::vector<bm_tensor_t> inputs_lm, outputs_lm;
|
53 |
+
std::string name_embed;
|
54 |
+
std::string name_embed_cache;
|
55 |
+
std::string name_lm;
|
56 |
+
std::vector<std::string> name_blocks;
|
57 |
+
std::vector<std::string> name_blocks_cache;
|
58 |
+
std::vector<std::pair<std::string, std::string>> history_vector;
|
59 |
+
std::vector<int> history_tokens;
|
60 |
+
std::string cur_answer = "";
|
61 |
+
|
62 |
+
int device_num;
|
63 |
+
int round = 0;
|
64 |
+
int token_length;
|
65 |
+
int EOS;
|
66 |
+
int SEQLEN;
|
67 |
+
int NUM_LAYERS;
|
68 |
+
};
|
69 |
+
|
70 |
+
void ChatGLM::load_sentencepiece(std::string tokenizer_path) {
|
71 |
+
printf("Load %s ... ", tokenizer_path.c_str());
|
72 |
+
auto status = sentencepiece.Load(tokenizer_path);
|
73 |
+
if (!status.ok()) {
|
74 |
+
std::cout << status.ToString() << std::endl;
|
75 |
+
exit(-1);
|
76 |
+
}
|
77 |
+
EOS = sentencepiece.eos_id();
|
78 |
+
printf("Done!\n");
|
79 |
+
}
|
80 |
+
|
81 |
+
void ChatGLM::init(const std::vector<int> &devices, std::string model_path, std::string tokenizer_path) {
|
82 |
+
device_num = devices.size();
|
83 |
+
load_sentencepiece(tokenizer_path);
|
84 |
+
// request bm_handle
|
85 |
+
std::cout << "Device [ ";
|
86 |
+
for (auto d : devices) {
|
87 |
+
std::cout << d << " ";
|
88 |
+
}
|
89 |
+
std::cout << "] loading ....\n";
|
90 |
+
for (auto d : devices) {
|
91 |
+
bm_handle_t h;
|
92 |
+
bm_status_t status = bm_dev_request(&h, d);
|
93 |
+
assert(BM_SUCCESS == status);
|
94 |
+
handles.push_back(h);
|
95 |
+
}
|
96 |
+
bm_handle = handles[0];
|
97 |
+
|
98 |
+
// create bmruntime
|
99 |
+
#ifdef SOC_TARGET
|
100 |
+
p_bmrt = bmrt_create(handles[0]);
|
101 |
+
#else
|
102 |
+
p_bmrt = bmrt_create_ex(handles.data(), handles.size());
|
103 |
+
#endif
|
104 |
+
assert(NULL != p_bmrt);
|
105 |
+
|
106 |
+
// load bmodel by file
|
107 |
+
printf("Model[%s] loading ....\n", model_path.c_str());
|
108 |
+
bool ret = bmrt_load_bmodel(p_bmrt, model_path.c_str());
|
109 |
+
assert(true == ret);
|
110 |
+
printf("Done!\n");
|
111 |
+
|
112 |
+
// set NUM_LAYERS
|
113 |
+
auto num_nets = bmrt_get_network_number(p_bmrt);
|
114 |
+
NUM_LAYERS = (num_nets - 2) / 2; // 3 shared nets + 2 nets per layer; integer division makes (2L+1)/2 == L
|
115 |
+
|
116 |
+
// net names
|
117 |
+
name_embed = "embedding";
|
118 |
+
name_embed_cache = "embedding_cache";
|
119 |
+
name_lm = "lm_head";
|
120 |
+
for (int i = 0; i < NUM_LAYERS; i++) {
|
121 |
+
name_blocks.emplace_back("block_" + std::to_string(i));
|
122 |
+
name_blocks_cache.emplace_back("block_cache_" + std::to_string(i));
|
123 |
+
}
|
124 |
+
|
125 |
+
// net infos
|
126 |
+
net_embed = bmrt_get_network_info(p_bmrt, name_embed.c_str());
|
127 |
+
net_embed_cache = bmrt_get_network_info(p_bmrt, name_embed_cache.c_str());
|
128 |
+
net_lm = bmrt_get_network_info(p_bmrt, name_lm.c_str());
|
129 |
+
for (int i = 0; i < NUM_LAYERS; i++) {
|
130 |
+
net_blocks.emplace_back(
|
131 |
+
bmrt_get_network_info(p_bmrt, name_blocks[i].c_str()));
|
132 |
+
net_blocks_cache.emplace_back(
|
133 |
+
bmrt_get_network_info(p_bmrt, name_blocks_cache[i].c_str()));
|
134 |
+
}
|
135 |
+
|
136 |
+
// set SEQLEN
|
137 |
+
SEQLEN = net_embed->stages[0].input_shapes[0].dims[1];
|
138 |
+
|
139 |
+
// resize (net_blocks and net_blocks_cache already hold NUM_LAYERS entries, so those two calls are no-ops)
|
140 |
+
net_blocks.resize(NUM_LAYERS);
|
141 |
+
net_blocks_cache.resize(NUM_LAYERS);
|
142 |
+
past_key.resize(NUM_LAYERS);
|
143 |
+
past_value.resize(NUM_LAYERS);
|
144 |
+
|
145 |
+
// net device mem
|
146 |
+
  inputs_embed_512.resize(net_embed->input_num);
  for (int i = 0; i < device_num; ++i) {
    ret = bmrt_tensor_ex(&inputs_embed_512[i], p_bmrt,
                         net_embed->input_loc_devices[i],
                         net_embed->input_dtypes[i],
                         net_embed->stages[0].input_shapes[i]);
    assert(true == ret);
  }

  outputs_embed_512.resize(net_embed->output_num);
  for (int i = 0; i < device_num; ++i) {
    ret = bmrt_tensor_ex(&outputs_embed_512[i], p_bmrt,
                         net_embed->output_loc_devices[i],
                         net_embed->output_dtypes[i],
                         net_embed->stages[0].output_shapes[i]);
    assert(true == ret);
  }

  inputs_pid.resize(device_num);
  inputs_attention.resize(device_num);
  int in_num = net_blocks[0]->input_num / device_num;
  for (int i = 0; i < device_num; ++i) {
    ret = bmrt_tensor_ex(&inputs_pid[i], p_bmrt,
                         net_blocks[0]->input_loc_devices[1 + i * in_num],
                         net_blocks[0]->input_dtypes[1 + i * in_num],
                         net_blocks[0]->stages[0].input_shapes[1 + i * in_num]);
    assert(true == ret);

    ret = bmrt_tensor_ex(&inputs_attention[i], p_bmrt,
                         net_blocks[0]->input_loc_devices[2 + i * in_num],
                         net_blocks[0]->input_dtypes[2 + i * in_num],
                         net_blocks[0]->stages[0].input_shapes[2 + i * in_num]);
    assert(true == ret);
  }

  next_pid.resize(device_num);
  next_attention.resize(device_num);
  int in_num_cache = net_blocks_cache[0]->input_num / device_num;
  for (int i = 0; i < device_num; ++i) {
    ret = bmrt_tensor_ex(&next_pid[i], p_bmrt,
                         net_blocks_cache[0]->input_loc_devices[1 + i * in_num_cache],
                         net_blocks_cache[0]->input_dtypes[1 + i * in_num_cache],
                         net_blocks_cache[0]->stages[0].input_shapes[1 + i * in_num_cache]);
    assert(true == ret);

    ret = bmrt_tensor_ex(&next_attention[i], p_bmrt,
                         net_blocks_cache[0]->input_loc_devices[2 + i * in_num_cache],
                         net_blocks_cache[0]->input_dtypes[2 + i * in_num_cache],
                         net_blocks_cache[0]->stages[0].input_shapes[2 + i * in_num_cache]);
    assert(true == ret);
  }

  int out_num = net_blocks[0]->output_num / device_num;
  for (int i = 0; i < NUM_LAYERS; i++) {
    past_key[i].resize(device_num);
    past_value[i].resize(device_num);
    for (int j = 0; j < device_num; j++) {
      ret = bmrt_tensor_ex(&past_key[i][j], p_bmrt,
                           net_blocks[0]->output_loc_devices[1 + j * out_num],
                           net_blocks[0]->output_dtypes[1 + j * out_num],
                           net_blocks[0]->stages[0].output_shapes[1 + j * out_num]);
      assert(true == ret);
      ret = bmrt_tensor_ex(&past_value[i][j], p_bmrt,
                           net_blocks[0]->output_loc_devices[2 + j * out_num],
                           net_blocks[0]->output_dtypes[2 + j * out_num],
                           net_blocks[0]->stages[0].output_shapes[2 + j * out_num]);
      assert(true == ret);
    }
  }

  inputs_lm.resize(device_num);
  outputs_lm.resize(device_num);
  for (int i = 0; i < device_num; ++i) {
    ret = bmrt_tensor_ex(&inputs_lm[i], p_bmrt, i, net_lm->input_dtypes[0],
                         net_lm->stages[0].input_shapes[0]);
    assert(true == ret);
    ret = bmrt_tensor_ex(&outputs_lm[i], p_bmrt, i, net_lm->output_dtypes[0],
                         net_lm->stages[0].output_shapes[0]);
    assert(true == ret);
  }
}

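// Note on the tensor indexing in init() above. This is a reading of this
// code, not vendor documentation: each block network packs its inputs per
// device, so with in_num = input_num / device_num, device i's slots are
// [0 + i * in_num] hidden states, [1 + i * in_num] position ids and
// [2 + i * in_num] attention mask; block outputs follow the same pattern,
// with [1 + j * out_num] present key and [2 + j * out_num] present value.
// A minimal sketch of a helper built on that assumption (hypothetical,
// shown for illustration only):
//
//   static void alloc_input_slot(bm_tensor_t &t, void *p_bmrt,
//                                const bm_net_info_t *net, int dev, int k,
//                                int devs) {
//     int idx = k + dev * (net->input_num / devs); // slot k of device dev
//     bool ok = bmrt_tensor_ex(&t, p_bmrt, net->input_loc_devices[idx],
//                              net->input_dtypes[idx],
//                              net->stages[0].input_shapes[idx]);
//     assert(ok);
//   }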
void ChatGLM::deinit() {
  for (int i = 0; i < device_num; ++i) {
    bm_free_device(handles[i], inputs_embed_512[i].device_mem);
    bm_free_device(handles[i], outputs_embed_512[i].device_mem);
    bm_free_device(handles[i], inputs_pid[i].device_mem);
    bm_free_device(handles[i], next_pid[i].device_mem);
    bm_free_device(handles[i], inputs_attention[i].device_mem);
    bm_free_device(handles[i], next_attention[i].device_mem);
    bm_free_device(handles[i], inputs_lm[i].device_mem);
    bm_free_device(handles[i], outputs_lm[i].device_mem);
  }
  for (int i = 0; i < NUM_LAYERS; i++) {
    for (int j = 0; j < device_num; j++) {
      bm_free_device(handles[j], past_key[i][j].device_mem);
      bm_free_device(handles[j], past_value[i][j].device_mem);
    }
  }
  bmrt_destroy(p_bmrt);
  for (auto h : handles) {
    bm_dev_free(h);
  }
}

// after each block's first forward pass, move the real result to the end of
// the cache memory
void ChatGLM::move2end(const bm_tensor_t &kv) {
  if (token_length >= SEQLEN) {
    return;
  }
  auto total_size = bm_mem_get_device_size(kv.device_mem);
  auto bytes = total_size / SEQLEN;      // bytes per token slot
  auto real_size = token_length * bytes; // bytes actually written so far
  auto mem =
      bm_mem_from_device(bm_mem_get_device_addr(kv.device_mem), real_size);
  auto buffer = new uint8_t[real_size];
  auto dst = new uint8_t[total_size];
  bm_memcpy_d2s(bm_handle, (void *)buffer, mem);            // fetch valid head
  memset(dst, 0, total_size - real_size);                   // zero the new head
  memcpy(dst + total_size - real_size, buffer, real_size);  // valid data at tail
  bm_memcpy_s2d(bm_handle, kv.device_mem, (void *)dst);
  delete[] buffer;
  delete[] dst;
}

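// Worked example of move2end() (illustrative): with SEQLEN = 8 and
// token_length = 3, a per-layer KV buffer laid out as
//   [k0 k1 k2  x  x  x  x  x]
// is rewritten to
//   [ 0  0  0  0  0 k0 k1 k2]
// i.e. the token_length valid entries are copied to the tail and the head is
// zero-filled. forward_next() masks exactly that zeroed head, so the
// right-aligned cache and the decode-phase attention mask stay consistent.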
int ChatGLM::forward_first(std::vector<int> &tokens) {
  std::vector<int> input_ids(SEQLEN, 0);
  std::vector<int> position_id(SEQLEN, 0);
  std::vector<uint16_t> attention_mask(SEQLEN * SEQLEN, 0);

  std::copy(tokens.begin(), tokens.end(), input_ids.data());

  token_length = tokens.size();
  for (int i = 0; i < token_length; i++) {
    position_id[i] = i;
  }
  // causal mask: (i, j) stays 0 (attend) only when j <= i and i < token_length
  for (int i = 0; i < SEQLEN; i++) {
    for (int j = 0; j < SEQLEN; j++) {
      if (!(j <= i && i < token_length)) {
        attention_mask[i * SEQLEN + j] = ATTENTION_MASK;
      }
    }
  }

  // forward embedding
  std::vector<int> input_nums(device_num, 1);
  std::vector<void *> datas(device_num, (void *)input_ids.data());
  bmrt_memcpy_s2d_parallel(p_bmrt, inputs_embed_512.data(), datas.data(),
                           input_nums.data(), device_num);
  auto ret =
      bmrt_launch_tensor_ex(p_bmrt, name_embed.c_str(),
                            inputs_embed_512.data(), inputs_embed_512.size(),
                            outputs_embed_512.data(), outputs_embed_512.size(),
                            true, false);
  assert(ret);
  bm_thread_sync(bm_handle);

  // forward blocks
  std::vector<void *> pos_id_datas(device_num, position_id.data());
  std::vector<void *> in_attn_datas(device_num, attention_mask.data());
  bmrt_memcpy_s2d_parallel(p_bmrt, inputs_pid.data(), pos_id_datas.data(),
                           input_nums.data(), device_num);
  bmrt_memcpy_s2d_parallel(p_bmrt, inputs_attention.data(), in_attn_datas.data(),
                           input_nums.data(), device_num);
  auto embed_512 = outputs_embed_512;
  std::vector<bm_tensor_t> inputs_block;
  std::vector<bm_tensor_t> outputs_block;
  for (int i = 0; i < device_num; ++i) {
    embed_512[i].shape = net_blocks[0]->stages[0].input_shapes[0];
    inputs_block.push_back(embed_512[i]);
    inputs_block.push_back(inputs_pid[i]);
    inputs_block.push_back(inputs_attention[i]);
    outputs_block.push_back(embed_512[i]);
    outputs_block.push_back(past_key[0][i]);
    outputs_block.push_back(past_value[0][i]);
  }

  for (int i = 0; i < NUM_LAYERS; i++) {
    for (int j = 0; j < device_num; ++j) {
      outputs_block[1 + j * 3] = past_key[i][j];
      outputs_block[2 + j * 3] = past_value[i][j];
    }
    ret = bmrt_launch_tensor_ex(p_bmrt, name_blocks[i].c_str(),
                                inputs_block.data(), inputs_block.size(),
                                outputs_block.data(), outputs_block.size(),
                                true, false);
    assert(ret);
    bm_thread_sync(bm_handle);
    for (int j = 0; j < device_num; ++j) {
      move2end(past_key[i][j]);
      move2end(past_value[i][j]);
    }
  }

  // forward lmhead
  int bytes = embed_512[0].device_mem.size / SEQLEN;
  bm_memcpy_d2d_byte(bm_handle, inputs_lm[0].device_mem, 0,
                     embed_512[0].device_mem, (token_length - 1) * bytes,
                     bytes);
  ret = bmrt_launch_tensor_ex(p_bmrt, name_lm.c_str(), &inputs_lm[0], 1,
                              &outputs_lm[0], 1, true, false);
  assert(ret);
  bm_thread_sync(bm_handle);

  int token = 0;
  bm_memcpy_d2s(bm_handle, (void *)&token, outputs_lm[0].device_mem);
  return token;
}

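// Worked example of the prefill mask built in forward_first() (illustrative):
// with SEQLEN = 4 and token_length = 2, position (i, j) stays 0 (attend) only
// when j <= i and i < token_length; everything else becomes ATTENTION_MASK:
//
//   row 0: [ .  M  M  M ]
//   row 1: [ .  .  M  M ]
//   row 2: [ M  M  M  M ]   <- rows at or beyond token_length are fully masked
//   row 3: [ M  M  M  M ]
//
// where M = ATTENTION_MASK and '.' = 0.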
int ChatGLM::forward_next(int cur_token) {
  std::vector<uint16_t> attention_mask(SEQLEN + 1, 0);
  for (int i = 0; i <= SEQLEN - token_length; i++) {
    attention_mask[i] = ATTENTION_MASK;
  }
  int32_t position_id = token_length - 1;

  // forward embedding
  std::vector<bm_tensor_t> inputs_embed;
  std::vector<void *> input_datas;
  std::vector<int> input_nums(device_num, 1);
  for (int i = 0; i < device_num; ++i) {
    inputs_embed.push_back(outputs_lm[i]); // token_id
    inputs_embed[i].shape = net_embed_cache->stages[0].input_shapes[0];
    input_datas.push_back((void *)(&cur_token));
  }
  bmrt_memcpy_s2d_parallel(p_bmrt, inputs_embed.data(), input_datas.data(),
                           input_nums.data(), device_num);
  auto ret = bmrt_launch_tensor_ex(p_bmrt, name_embed_cache.c_str(),
                                   inputs_embed.data(), inputs_embed.size(),
                                   inputs_lm.data(), inputs_lm.size(), true, false);
  assert(ret);
  bm_thread_sync(bm_handle);

  // forward blocks
  std::vector<void *> attn_datas(device_num, attention_mask.data());
  std::vector<void *> pid_datas(device_num, &position_id);
  bmrt_memcpy_s2d_parallel(p_bmrt, next_attention.data(), attn_datas.data(),
                           input_nums.data(), device_num);
  bmrt_memcpy_s2d_parallel(p_bmrt, next_pid.data(), pid_datas.data(),
                           input_nums.data(), device_num);

  // WARNING: make inputs_lm device_num
  std::vector<bm_tensor_t> embed_1 = inputs_lm;
  for (int i = 0; i < device_num; ++i) {
    embed_1[i].shape = net_blocks_cache[0]->stages[0].input_shapes[0];
  }
  std::vector<bm_tensor_t> inputs_block;
  std::vector<bm_tensor_t> outputs_block;
  for (int i = 0; i < device_num; ++i) {
    inputs_block.push_back(embed_1[i]);
    inputs_block.push_back(next_pid[i]);
    inputs_block.push_back(next_attention[i]);
    inputs_block.push_back(past_key[0][i]);
    inputs_block.push_back(past_value[0][i]);
    outputs_block.push_back(embed_1[i]);
    outputs_block.push_back(past_key[0][i]);
    outputs_block.push_back(past_value[0][i]);
  }

  for (int i = 0; i < NUM_LAYERS; i++) {
    for (int j = 0; j < device_num; ++j) {
      inputs_block[3 + j * 5] = past_key[i][j];
      inputs_block[4 + j * 5] = past_value[i][j];
      outputs_block[1 + j * 3] = past_key[i][j];
      outputs_block[2 + j * 3] = past_value[i][j];
    }
    ret = bmrt_launch_tensor_ex(p_bmrt, name_blocks_cache[i].c_str(),
                                inputs_block.data(), inputs_block.size(),
                                outputs_block.data(), outputs_block.size(),
                                true, false);
    assert(ret);
    bm_thread_sync(bm_handle);
  }

  // forward lmhead
  ret = bmrt_launch_tensor_ex(p_bmrt, name_lm.c_str(), &inputs_lm[0], 1,
                              &outputs_lm[0], 1, true, false);
  assert(ret);
  bm_thread_sync(bm_handle);

  int token = 0;
  bm_memcpy_d2s(bm_handle, (void *)&token, outputs_lm[0].device_mem);
  return token;
}

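// In forward_next() above, the decode mask has SEQLEN + 1 entries: indices
// 0 .. SEQLEN - token_length cover the zero-filled head of the right-aligned
// KV cache and are masked, while the tail covers the cached tokens plus the
// current one. A minimal sketch of the resulting generation loop (mirroring
// answer() further below):
//
//   int token = forward_first(tokens);             // prefill
//   while (token != EOS && token_length < SEQLEN) {
//     token_length++;                              // one new position
//     token = forward_next(token);                 // decode one token
//   }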
std::string build_prompt(std::string query,
                         std::vector<std::pair<std::string, std::string>> history = {}) {
  std::string prompt = "";
  int round_number = 1;
  for (const auto &item : history) {
    prompt += "[Round " + std::to_string(round_number) + "]\n\n问:" +
              item.first + "\n\n答:" + item.second + "\n\n";
    round_number++;
  }
  prompt += "[Round " + std::to_string(history.size() + 1) + "]\n\n问:" +
            query + "\n\n答:";
  return prompt;
}

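// Example output of build_prompt() (illustrative): with one history turn
// {"你好", "你好!"} and query "1+1=?", the returned prompt is:
//
//   [Round 1]
//
//   问:你好
//
//   答:你好!
//
//   [Round 2]
//
//   问:1+1=?
//
//   答: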
void ChatGLM::chat() {
  while (true) {
    std::cout << "\nQuestion: ";
    std::string input_str;
    std::getline(std::cin, input_str);
    if (input_str == "exit") {
      break;
    }
    std::cout << "\nAnswer: " << std::flush;
    answer(input_str);
    std::cout << std::endl;
  }
}

void ChatGLM::answer(const std::string &input_str) {
  // auto time_0 = std::chrono::system_clock::now();
  std::string query = build_prompt(input_str, history_vector);
  int tok_num = 0;
  std::vector<int> tokens;
  std::vector<int> prompt{64790, 64792};
  sentencepiece.Encode(query, &tokens);

  if (tokens.empty()) {
    printf("Sorry: your question is too weird!\n");
    return;
  }
  // tokens is not empty
  tokens.insert(tokens.begin(), prompt.begin(), prompt.end());

  // make sure the token count is not too large
  if ((int)tokens.size() > SEQLEN - 10) {
    // reset
    tokens.clear();
    cur_answer.clear();
    printf("Error: your question is too long!\n");
    return;
  }

  int pre_token = 0;
  auto t0 = std::chrono::system_clock::now();
  int token = forward_first(tokens);
  auto t1 = std::chrono::system_clock::now();
  while (token != EOS && token_length < SEQLEN) {
    std::string pre_word;
    std::string word;
    std::vector<int> pre_ids = {pre_token};
    std::vector<int> ids = {pre_token, token};
    sentencepiece.Decode(pre_ids, &pre_word);
    sentencepiece.Decode(ids, &word);
    std::string diff = word.substr(pre_word.size());
    cur_answer += diff;
    tokens.emplace_back(token);
    std::cout << diff << std::flush;
    if (token_length < SEQLEN) {
      token_length++;
    }
    tok_num++;
    token = forward_next(token);
  }
  auto t2 = std::chrono::system_clock::now();
  auto use0 = std::chrono::duration_cast<std::chrono::microseconds>(t1 - t0);
  auto use1 = std::chrono::duration_cast<std::chrono::microseconds>(t2 - t1);
  printf("\n\nfirst token latency: %f s", (use0.count() * 1e-6));
  printf("\nspeed: %f token/s\n", tok_num / (use1.count() * 1e-6));

  if (token_length >= SEQLEN) {
    printf("Warning: reached the max sequence length!\n");
    history_vector.push_back({input_str, cur_answer});
    cur_answer.clear();

    // drop the first half of the history
    size_t half_size = history_vector.size() / 2;
    history_vector.erase(history_vector.begin(), history_vector.begin() + half_size);
  } else {
    history_vector.push_back({input_str, cur_answer});
    cur_answer.clear();
  }
}

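// Why answer() decodes token pairs: SentencePiece derives whitespace from
// context, so decoding a new token in isolation can drop its leading space.
// Decoding {pre_token} and {pre_token, token} and printing only the suffix
// recovers exactly the text the new token contributes. Here pre_token stays
// 0, so each step decodes {0, token} against {0}; the general form of the
// trick pairs the previous token with the current one (a minimal sketch):
//
//   std::string pre_word, word;
//   sentencepiece.Decode(std::vector<int>{prev}, &pre_word);
//   sentencepiece.Decode(std::vector<int>{prev, cur}, &word);
//   std::cout << word.substr(pre_word.size()) << std::flush;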
static void split(const std::string &s, const std::string &delim,
                  std::vector<std::string> &ret) {
  size_t last = 0;
  size_t index = s.find_first_of(delim, last);
  while (index != std::string::npos) {
    ret.push_back(s.substr(last, index - last));
    last = index + 1;
    index = s.find_first_of(delim, last);
  }
  if (last < s.length()) {
    ret.push_back(s.substr(last));
  }
}

static std::vector<int> parseCascadeDevices(const std::string &str) {
  std::vector<int> devices;
  std::vector<std::string> sub_str;
  split(str, ",", sub_str);
  for (auto &s : sub_str) {
    devices.push_back(std::atoi(s.c_str()));
  }
  return devices;
}

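// Usage example (illustrative): parseCascadeDevices("0,1") returns {0, 1},
// matching the "--devid 0,1" form parsed in processArguments() below. Note
// that std::atoi maps a malformed entry such as "x" to 0, silently selecting
// device 0.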
void Usage() {
  printf("Usage:\n"
         "  --help      : Show help info.\n"
         "  --model     : Set the model path.\n"
         "  --tokenizer : Set the tokenizer path.\n"
         "  --devid     : Set the devices to run the model on, e.g. 1,2;\n"
         "                defaults to 0 if not set.\n");
}

void processArguments(int argc, char *argv[], std::string &model_path,
                      std::string &tokenizer_path, std::vector<int> &devices) {
  struct option longOptions[] = {{"model", required_argument, nullptr, 'm'},
                                 {"tokenizer", required_argument, nullptr, 't'},
                                 {"devid", required_argument, nullptr, 'd'},
                                 {"help", no_argument, nullptr, 'h'},
                                 {nullptr, 0, nullptr, 0}};

  int optionIndex = 0;
  int option;

  while ((option = getopt_long(argc, argv, "m:t:d:h", longOptions,
                               &optionIndex)) != -1) {
    switch (option) {
    case 'm':
      model_path = optarg;
      break;
    case 't':
      tokenizer_path = optarg;
      break;
    case 'd':
      devices = parseCascadeDevices(optarg);
      break;
    case 'h':
      Usage();
      exit(EXIT_SUCCESS);
    case '?':
      Usage();
      exit(EXIT_FAILURE);
    default:
      exit(EXIT_FAILURE);
    }
  }
}

int main(int argc, char **argv) {
  // model and tokenizer paths are taken from the command line
  printf("Demo for ChatGLM on BM1684X, supporting ChatGLM1/2/3\n");
  std::string model_path;
  std::string tokenizer_path;
  std::vector<int> devices = {0};
  processArguments(argc, argv, model_path, tokenizer_path, devices);
  if (model_path.empty()) {
    Usage();
    exit(EXIT_FAILURE);
  }

  ChatGLM glm;
  printf("Init Environment ...\n");
  glm.init(devices, model_path, tokenizer_path);
  printf("==========================\n");
  glm.chat();
  glm.deinit();
  return 0;
}
ChatGLM2/run_demo.sh
ADDED
@@ -0,0 +1,27 @@
#!/bin/bash
set -ex

# download bmodel
if [ ! -d "../../bmodels" ]; then
  mkdir ../../bmodels
fi

if [ ! -f "../../bmodels/chatglm2-6b_int4_1dev.bmodel" ]; then
  pip3 install dfss
  python3 -m dfss --url=open@sophgo.com:/ext_model_information/LLM/LLM-TPU/chatglm2-6b_int4_1dev.bmodel
  mv chatglm2-6b_int4_1dev.bmodel ../../bmodels
else
  echo "Bmodel exists!"
fi

# build the demo binary if it is not built yet
if [ ! -f "./demo/chatglm" ]; then
  cd demo && rm -rf build && mkdir build && cd build
  cmake .. && make -j
  cp chatglm .. && cd ../..
else
  echo "chatglm binary exists!"
fi

# run demo
./demo/chatglm --model ../../bmodels/chatglm2-6b_int4_1dev.bmodel --tokenizer ./support/tokenizer/tokenizer.model --devid 0
ChatGLM2/support/include/bmdef.h
ADDED
@@ -0,0 +1,129 @@
/*****************************************************************************
 *
 * Copyright (c) 2016-2026 by Sophgo Technologies Inc. All rights reserved.
 *
 * The material in this file is confidential and contains trade secrets
 * of Sophgo Technologies Inc. This is proprietary information owned by
 * Sophgo Technologies Inc. No part of this work may be disclosed,
 * reproduced, copied, transmitted, or used in any way for any purpose,
 * without the express written permission of Sophgo Technologies Inc.
 *
 *****************************************************************************/

#ifndef __BMRUNTIME_DEFINE_H__
#define __BMRUNTIME_DEFINE_H__

#include "bmlib_runtime.h"
#include <stddef.h>
#include <stdint.h>

#if defined(__cplusplus)
extern "C" {
#endif

/* --------------------------------------------------------------------------*/
/* basic definitions */

/* bm_data_type_t holds the type for a scalar value */
typedef enum bm_data_type_e {
  BM_FLOAT32 = 0,
  BM_FLOAT16 = 1,
  BM_INT8 = 2,
  BM_UINT8 = 3,
  BM_INT16 = 4,
  BM_UINT16 = 5,
  BM_INT32 = 6,
  BM_UINT32 = 7,
  BM_BFLOAT16 = 8,
  BM_INT4 = 9,
  BM_UINT4 = 10,
} bm_data_type_t;

/* store mode definitions */
typedef enum bm_store_mode_e {
  BM_STORE_1N = 0, /* default, if not sure, use 0 */
  BM_STORE_2N = 1,
  BM_STORE_4N = 2,
} bm_store_mode_t;

/* bm_shape_t holds the shape info */
#define BM_MAX_DIMS_NUM 8
typedef struct bm_shape_s {
  int num_dims;
  int dims[BM_MAX_DIMS_NUM];
} bm_shape_t;

typedef struct bm_shape_ex_s {
  bm_shape_t shape;
  int elem_num;
} bm_shape_ex_t;

/*
bm_tensor_t holds a multi-dimensional array of elements of a single data type,
and the tensor is in device memory */
typedef struct bm_tensor_s {
  bm_data_type_t dtype;
  bm_shape_t shape;
  bm_device_mem_t device_mem;
  bm_store_mode_t st_mode; /* user can set 0 as default store mode */
} bm_tensor_t;

/* --------------------------------------------------------------------------*/
/* network information structure */

/* bm_stage_info_t holds input/output shapes and device mems; every network can contain one or more
 * stages */
typedef struct bm_stage_info_s {
  bm_shape_t *input_shapes;     /* input_shapes[0] / [1] / ... / [input_num-1] */
  bm_shape_t *output_shapes;    /* output_shapes[0] / [1] / ... / [output_num-1] */
  bm_device_mem_t *input_mems;  /* input_mems[0] / [1] / ... / [input_num-1] */
  bm_device_mem_t *output_mems; /* output_mems[0] / [1] / ... / [output_num-1] */
} bm_stage_info_t;

/* bm_net_info_t holds all information of one net.
 * scale for float type is 1.0 as default */
typedef struct bm_net_info_s {
  const char* name;              /* net name */
  bool is_dynamic;               /* dynamic or static */
  int input_num;                 /* number of inputs */
  char const** input_names;      /* input_names[0] / [1] / .../ [input_num-1] */
  bm_data_type_t* input_dtypes;  /* input_dtypes[0] / [1] / .../ [input_num-1] */
  float* input_scales;           /* input_scales[0] / [1] / .../ [input_num-1] */
  int output_num;                /* number of outputs */
  char const** output_names;     /* output_names[0] / [1] / .../ [output_num-1] */
  bm_data_type_t* output_dtypes; /* output_dtypes[0] / [1] / .../ [output_num-1] */
  float* output_scales;          /* output_scales[0] / [1] / .../ [output_num-1] */
  int stage_num;                 /* number of stages */
  bm_stage_info_t* stages;       /* stages[0] / [1] / ... / [stage_num-1] */
  size_t* max_input_bytes;       /* max_input_bytes[0] / [1] / ... / [input_num-1] */
  size_t* max_output_bytes;      /* max_output_bytes[0] / [1] / ... / [output_num-1] */
  int* input_zero_point;         /* input_zero_point[0] / [1] / .../ [input_num-1] */
  int* output_zero_point;        /* output_zero_point[0] / [1] / .../ [output_num-1] */
  int *input_loc_devices;        /* input_loc_devices[0] / [1] / .../ [input_num-1] */
  int *output_loc_devices;       /* output_loc_devices[0] / [1] / .../ [output_num-1] */
} bm_net_info_t;

typedef struct api_info_s {
  /// @brief api_id to be sent to driver
  int32_t api_id;
  /// @brief api data to be sent to driver
  uint8_t **api_data;
  /// @brief size of the api data to be sent to driver
  size_t api_data_size;
  /// @brief subsize of the api data to be sent to driver
  size_t *api_data_subsize;
  /// @brief offset of input tensors' addr in api_data
  uint32_t *input_addr_offset;
  /// @brief number of the offset of input tensors' addr in api_data
  size_t input_addr_offset_number;
  /// @brief offset of output tensors' addr in api_data
  uint32_t *output_addr_offset;
  /// @brief number of the offset of output tensors' addr in api_data
  size_t output_addr_offset_number;
} api_info_c;

#if defined(__cplusplus)
}
#endif

#endif /* __BMRUNTIME_DEFINE_H__ */
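These structures are what the demo's C++ code walks when sizing its tensors. A minimal sketch of dumping a network's first-stage input shapes, assuming the `bmrt_get_network_info` accessor declared in the accompanying bmruntime_interface.h (the net name passed in is whatever the bmodel exposes):

    #include "bmdef.h"
    #include "bmruntime_interface.h"
    #include <stdio.h>

    /* Print every input of one network's stage 0: index, name, dtype, shape. */
    void dump_inputs(void *p_bmrt, const char *net_name) {
      const bm_net_info_t *net = bmrt_get_network_info(p_bmrt, net_name);
      for (int i = 0; i < net->input_num; ++i) {
        const bm_shape_t *s = &net->stages[0].input_shapes[i];
        printf("input %d (%s) dtype=%d shape=[", i, net->input_names[i],
               net->input_dtypes[i]);
        for (int d = 0; d < s->num_dims; ++d)
          printf("%d%s", s->dims[d], d + 1 < s->num_dims ? "," : "");
        printf("]\n");
      }
    }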
ChatGLM2/support/include/bmlib_runtime.h
ADDED
@@ -0,0 +1,2581 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
/*****************************************************************************
|
2 |
+
*
|
3 |
+
* Copyright (c) 2016-2026 by Bitmain Technologies Inc. All rights reserved.
|
4 |
+
*
|
5 |
+
* The material in this file is confidential and contains trade secrets
|
6 |
+
* of Bitmain Technologies Inc. This is proprietary information owned by
|
7 |
+
* Bitmain Technologies Inc. No part of this work may be disclosed,
|
8 |
+
* reproduced, copied, transmitted, or used in any way for any purpose,
|
9 |
+
* without the express written permission of Bitmain Technologies Inc.
|
10 |
+
*
|
11 |
+
*****************************************************************************/
|
12 |
+
|
13 |
+
/**************************************************************************
|
14 |
+
* bmlib_runtime defines interfaces that operate TPU devices.
|
15 |
+
* The functions can be divided into serveral categories.
|
16 |
+
* 1) device handle creation and destroy
|
17 |
+
* 2) memory help functions
|
18 |
+
* 3) global memory allocation and free
|
19 |
+
* 4) data transfer between host and device
|
20 |
+
* 5) data transfer within device memory
|
21 |
+
* 6) api send and synchronization
|
22 |
+
* 7) global memory map and coherence
|
23 |
+
* 8) trace and profile
|
24 |
+
* 9) power management
|
25 |
+
* 10) miscellaneous functions
|
26 |
+
*************************************************************************/
|
27 |
+
|
28 |
+
#ifndef BMLIB_RUNTIME_H_
|
29 |
+
#define BMLIB_RUNTIME_H_
|
30 |
+
#if defined(_WIN32) && !defined(__MINGW32__)
|
31 |
+
#include <vadefs.h>
|
32 |
+
#define DECL_EXPORT __declspec(dllexport)
|
33 |
+
#define DECL_IMPORT __declspec(dllimport)
|
34 |
+
#else
|
35 |
+
#include <stdbool.h>
|
36 |
+
#include <stddef.h>
|
37 |
+
#include <stdarg.h>
|
38 |
+
#define DECL_EXPORT
|
39 |
+
#define DECL_IMPORT
|
40 |
+
#endif
|
41 |
+
|
42 |
+
#if defined(__cplusplus)
|
43 |
+
extern "C" {
|
44 |
+
#endif
|
45 |
+
|
46 |
+
typedef enum {
|
47 |
+
MODULE_CDMA = 0,
|
48 |
+
MODULE_GDMA = 1,
|
49 |
+
MODULE_TPU = 2,
|
50 |
+
MODULE_SMMU = 3,
|
51 |
+
MODULE_SRAM = 4,
|
52 |
+
MODULE_END = 5
|
53 |
+
} MODULE_ID;
|
54 |
+
|
55 |
+
#define BM_MEM_ADDR_NULL (0xfffffffff)
|
56 |
+
|
57 |
+
#ifndef BM_MEM_DESC_T_
|
58 |
+
#define BM_MEM_DESC_T_
|
59 |
+
/* BM function return code definitions */
|
60 |
+
typedef enum {
|
61 |
+
BM_SUCCESS = 0,
|
62 |
+
BM_ERR_DEVNOTREADY = 1, /* Device not ready yet */
|
63 |
+
BM_ERR_FAILURE = 2, /* General failure */
|
64 |
+
BM_ERR_TIMEOUT = 3, /* Timeout */
|
65 |
+
BM_ERR_PARAM = 4, /* Parameters invalid */
|
66 |
+
BM_ERR_NOMEM = 5, /* Not enough memory */
|
67 |
+
BM_ERR_DATA = 6, /* Data error */
|
68 |
+
BM_ERR_BUSY = 7, /* Busy */
|
69 |
+
BM_ERR_NOFEATURE = 8, /* Not supported yet */
|
70 |
+
BM_NOT_SUPPORTED = 9
|
71 |
+
} bm_status_t;
|
72 |
+
|
73 |
+
/* BM memory type definitions */
|
74 |
+
typedef enum {
|
75 |
+
BM_MEM_TYPE_DEVICE = 0,
|
76 |
+
BM_MEM_TYPE_HOST = 1,
|
77 |
+
BM_MEM_TYPE_SYSTEM = 2,
|
78 |
+
BM_MEM_TYPE_INT8_DEVICE = 3,
|
79 |
+
BM_MEM_TYPE_INVALID = 4
|
80 |
+
} bm_mem_type_t;
|
81 |
+
|
82 |
+
typedef enum {
|
83 |
+
PERF_MONITOR_GDMA = 0,
|
84 |
+
PERF_MONITOR_TPU = 1
|
85 |
+
} PERF_MONITOR_ID;
|
86 |
+
|
87 |
+
typedef enum {
|
88 |
+
BMCPU_IDLE = 0,
|
89 |
+
BMCPU_RUNNING = 1,
|
90 |
+
BMCPU_FAULT = 2
|
91 |
+
} bm_cpu_status_t;
|
92 |
+
|
93 |
+
/*
|
94 |
+
* bm performace monitor
|
95 |
+
*/
|
96 |
+
typedef struct bm_perf_monitor {
|
97 |
+
long long buffer_start_addr; /*buffer address to store perf data*/
|
98 |
+
int buffer_size; /*buffer size*/
|
99 |
+
PERF_MONITOR_ID monitor_id; /*PERF_MONITOR_GDMA or PERF_MONITOR_TPU*/
|
100 |
+
} bm_perf_monitor_t;
|
101 |
+
|
102 |
+
typedef union {
|
103 |
+
struct {
|
104 |
+
bm_mem_type_t mem_type : 3;
|
105 |
+
unsigned int gmem_heapid : 3;
|
106 |
+
unsigned int reserved : 26;
|
107 |
+
} u;
|
108 |
+
unsigned int rawflags;
|
109 |
+
} bm_mem_flags_t;
|
110 |
+
|
111 |
+
/* BM memory descriptor definition*/
|
112 |
+
typedef struct bm_mem_desc {
|
113 |
+
union {
|
114 |
+
struct {
|
115 |
+
#ifdef __linux__
|
116 |
+
unsigned long device_addr;
|
117 |
+
#else
|
118 |
+
unsigned long long device_addr;
|
119 |
+
#endif
|
120 |
+
unsigned int reserved;
|
121 |
+
int dmabuf_fd;
|
122 |
+
} device;
|
123 |
+
|
124 |
+
struct {
|
125 |
+
void *system_addr;
|
126 |
+
unsigned int reserved0;
|
127 |
+
int reserved1;
|
128 |
+
} system;
|
129 |
+
} u;
|
130 |
+
|
131 |
+
bm_mem_flags_t flags;
|
132 |
+
unsigned int size;
|
133 |
+
} bm_mem_desc_t;
|
134 |
+
|
135 |
+
typedef struct bm_mem_desc bm_device_mem_t;
|
136 |
+
typedef struct bm_mem_desc bm_system_mem_t;
|
137 |
+
|
138 |
+
typedef struct sg_mem_desc {
|
139 |
+
union {
|
140 |
+
struct {
|
141 |
+
#ifdef __linux__
|
142 |
+
unsigned long device_addr;
|
143 |
+
#else
|
144 |
+
unsigned long long device_addr;
|
145 |
+
#endif
|
146 |
+
unsigned int reserved;
|
147 |
+
int dmabuf_fd;
|
148 |
+
} device;
|
149 |
+
|
150 |
+
struct {
|
151 |
+
void *system_addr;
|
152 |
+
unsigned int reserved0;
|
153 |
+
int reserved1;
|
154 |
+
} system;
|
155 |
+
} u;
|
156 |
+
|
157 |
+
bm_mem_flags_t flags;
|
158 |
+
unsigned long long size;
|
159 |
+
} sg_mem_desc_t;
|
160 |
+
|
161 |
+
typedef struct sg_mem_desc sg_device_mem_t;
|
162 |
+
typedef struct sg_mem_desc sg_system_mem_t;
|
163 |
+
#endif
|
164 |
+
|
165 |
+
struct bm_context;
|
166 |
+
typedef struct bm_context *bm_handle_t;
|
167 |
+
|
168 |
+
#define MD5SUM_LEN 16
|
169 |
+
#define LIB_MAX_NAME_LEN 64
|
170 |
+
#define FUNC_MAX_NAME_LEN 64
|
171 |
+
|
172 |
+
typedef struct bm_module
|
173 |
+
{
|
174 |
+
// void *lib_handle;
|
175 |
+
char lib_name[LIB_MAX_NAME_LEN];
|
176 |
+
unsigned char md5[MD5SUM_LEN];
|
177 |
+
}bm_module;
|
178 |
+
|
179 |
+
typedef struct bm_module *tpu_kernel_module_t;
|
180 |
+
typedef int tpu_kernel_function_t;
|
181 |
+
|
182 |
+
/**
|
183 |
+
* @name tpu_kernel_load_module_file
|
184 |
+
* @brief To load dyn file
|
185 |
+
* @ingroup bmlib_runtime
|
186 |
+
*
|
187 |
+
* @param [in] handle The device handle
|
188 |
+
* @param [in] module_file dyn file
|
189 |
+
* @retval dyn lib ptr
|
190 |
+
*/
|
191 |
+
tpu_kernel_module_t tpu_kernel_load_module_file(bm_handle_t handle, const char *module_file);
|
192 |
+
|
193 |
+
/**
|
194 |
+
* @name tpu_kernel_load_module_file_key
|
195 |
+
* @brief To load dyn file with key
|
196 |
+
* @ingroup bmlib_runtime
|
197 |
+
*
|
198 |
+
* @param [in] handle The device handle
|
199 |
+
* @param [in] module_file dyn file
|
200 |
+
* @param [in] key identification str
|
201 |
+
* @param [in] size key size
|
202 |
+
* @retval dyn lib ptr
|
203 |
+
*/
|
204 |
+
tpu_kernel_module_t tpu_kernel_load_module_file_key(bm_handle_t handle, const char *module_file, const char *key, int size);
|
205 |
+
|
206 |
+
/**
|
207 |
+
* @name tpu_kernel_unload_module
|
208 |
+
* @brief To unload dyn file
|
209 |
+
* @ingroup bmlib_runtime
|
210 |
+
*
|
211 |
+
* @param [in] handle The device handle
|
212 |
+
* @param [in] p_module dyn lib ptr
|
213 |
+
* @retval BM_SUCCESS Succeeds.
|
214 |
+
* Other code Fails.
|
215 |
+
*/
|
216 |
+
bm_status_t tpu_kernel_unload_module(bm_handle_t handle, tpu_kernel_module_t p_module);
|
217 |
+
|
218 |
+
/**
|
219 |
+
* @name tpu_kernel_free_module
|
220 |
+
* @brief To free p_module when not use
|
221 |
+
* @ingroup bmlib_runtime
|
222 |
+
*
|
223 |
+
* @param [in] handle The device handle
|
224 |
+
* @param [in] p_module dyn lib ptr
|
225 |
+
* @retval BM_SUCCESS Succeeds.
|
226 |
+
* Other code Fails.
|
227 |
+
*/
|
228 |
+
bm_status_t tpu_kernel_free_module(bm_handle_t handle, tpu_kernel_module_t p_module);
|
229 |
+
|
230 |
+
/**
|
231 |
+
* @name tpu_kernel_load_module
|
232 |
+
* @brief To load dyn module
|
233 |
+
* @ingroup bmlib_runtime
|
234 |
+
*
|
235 |
+
* @param [in] handle The device handle
|
236 |
+
* @param [in] data dyn module
|
237 |
+
* @param [in] length dyn module size
|
238 |
+
* @retval dyn lib ptr
|
239 |
+
*/
|
240 |
+
tpu_kernel_module_t tpu_kernel_load_module(bm_handle_t handle, const char *data, size_t length);
|
241 |
+
|
242 |
+
/**
|
243 |
+
* @name tpu_kernel_get_function
|
244 |
+
* @brief To get function from lib
|
245 |
+
* @ingroup bmlib_runtime
|
246 |
+
*
|
247 |
+
* @param [in] handle The device handle
|
248 |
+
* @param [in] module dyn module
|
249 |
+
* @param [in] function funtion name
|
250 |
+
* @retval function id
|
251 |
+
*/
|
252 |
+
tpu_kernel_function_t tpu_kernel_get_function(bm_handle_t handle, tpu_kernel_module_t module, const char *function);
|
253 |
+
|
254 |
+
/**
|
255 |
+
* @name tpu_kernel_launch
|
256 |
+
* @brief To launch function with sync
|
257 |
+
* @ingroup bmlib_runtime
|
258 |
+
*
|
259 |
+
* @param [in] handle The device handle
|
260 |
+
* @param [in] function function id
|
261 |
+
* @param [in] args funtion args
|
262 |
+
* @param [in] size args size
|
263 |
+
* @retval BM_SUCCESS Succeeds.
|
264 |
+
* Other code Fails.
|
265 |
+
*/
|
266 |
+
bm_status_t tpu_kernel_launch(bm_handle_t handle, tpu_kernel_function_t function, void *args, size_t size);
|
267 |
+
|
268 |
+
/**
|
269 |
+
* @name tpu_kernel_launch_async
|
270 |
+
* @brief To launch function with async
|
271 |
+
* @ingroup bmlib_runtime
|
272 |
+
*
|
273 |
+
* @param [in] handle The device handle
|
274 |
+
* @param [in] function function id
|
275 |
+
* @param [in] args funtion args
|
276 |
+
* @param [in] size args size
|
277 |
+
* @retval BM_SUCCESS Succeeds.
|
278 |
+
* Other code Fails.
|
279 |
+
*/
|
280 |
+
bm_status_t tpu_kernel_launch_async(bm_handle_t handle, tpu_kernel_function_t function, void *args, size_t size);
|
281 |
+
|
282 |
+
/**
|
283 |
+
* @name tpu_kernel_launch_async_multi_cores
|
284 |
+
* @brief To launch function with async for multi cores
|
285 |
+
* @ingroup bmlib_runtime
|
286 |
+
*
|
287 |
+
* @param [in] handle The device handle
|
288 |
+
* @param [in] func_name function name
|
289 |
+
* @param [in] api_param funtion params
|
290 |
+
* @param [in] api_size params size
|
291 |
+
* @param [in] core_list list of core ids
|
292 |
+
* @param [in] core_num number of cores
|
293 |
+
* @retval BM_SUCCESS Succeeds.
|
294 |
+
* Other code Fails.
|
295 |
+
*/
|
296 |
+
bm_status_t tpu_kernel_launch_async_multi_cores(bm_handle_t handle, const char *func_name, const void *api_param,
|
297 |
+
size_t api_size, const int* core_list, const int core_num);
|
298 |
+
|
299 |
+
/**
|
300 |
+
* @name tpu_kernel_launch_sync_multi_cores
|
301 |
+
* @brief To launch function with sync for multi cores
|
302 |
+
* @ingroup bmlib_runtime
|
303 |
+
*
|
304 |
+
* @param [in] handle The device handle
|
305 |
+
* @param [in] func_name function name
|
306 |
+
* @param [in] api_param funtion params
|
307 |
+
* @param [in] api_size params size
|
308 |
+
* @param [in] core_list list of core ids
|
309 |
+
* @param [in] core_num number of cores
|
310 |
+
* @retval BM_SUCCESS Succeeds.
|
311 |
+
* Other code Fails.
|
312 |
+
*/
|
313 |
+
bm_status_t tpu_kernel_launch_sync_multi_cores(bm_handle_t handle, const char *func_name, const void *api_param,
|
314 |
+
size_t api_size, const int* core_list, const int core_num);
|
315 |
+
|
316 |
+
/**
|
317 |
+
* @name tpu_kernel_sync
|
318 |
+
* @brief To sync
|
319 |
+
* @ingroup bmlib_runtime
|
320 |
+
*
|
321 |
+
* @param [in] handle The device handle
|
322 |
+
* @retval BM_SUCCESS Succeeds.
|
323 |
+
* Other code Fails.
|
324 |
+
*/
|
325 |
+
bm_status_t tpu_kernel_sync(bm_handle_t handle);
|
326 |
+
void show_md5(unsigned char md5[]);
|
327 |
+
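/*
 * Usage sketch for the TPU-kernel API above (illustrative; the module path
 * and function name are placeholders, not part of this header):
 *
 *   tpu_kernel_module_t mod =
 *       tpu_kernel_load_module_file(handle, "libbm1684x_kernel_module.so");
 *   tpu_kernel_function_t func =
 *       tpu_kernel_get_function(handle, mod, "my_kernel_func");
 *   bm_status_t st = tpu_kernel_launch(handle, func, &args, sizeof(args));
 *   // BM_SUCCESS expected; tpu_kernel_launch_async followed by
 *   // tpu_kernel_sync is the non-blocking variant
 *   tpu_kernel_unload_module(handle, mod);
 */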
DECL_EXPORT void bmlib_log(const char *tag, int level, const char *fmt, ...);

#ifndef USING_CMODEL
#define BM_CHECK_RET(call)                                                    \
  do {                                                                        \
    bm_status_t ret = (bm_status_t)call;                                      \
    if (ret != BM_SUCCESS) {                                                  \
      bmlib_log("BM_CHECK", 16, "BM_CHECK_RET fail %s: %s: %d\n", __FILE__,   \
                __func__, __LINE__);                                          \
      return ret;                                                             \
    }                                                                         \
  } while (0)
#else
#define BM_CHECK_RET(call)                                                    \
  do {                                                                        \
    bm_status_t ret = call;                                                   \
    if (ret != BM_SUCCESS) {                                                  \
      bmlib_log("BM_CHECK", 16, "BM_CHECK_RET failed %d\n", ret);             \
      ASSERT(0);                                                              \
      exit(-ret);                                                             \
    }                                                                         \
  } while (0)
#endif

/******************* handle related functions *********************************/
/**
 * @name bm_dev_getcount
 * @brief To get the number of sophon devices in the system.
 *        If N is returned, the valid devids are [0, N-1].
 * @ingroup bmlib_runtime
 *
 * @param [out] count The resulting number of sophon devices
 * @retval BM_SUCCESS Succeeds.
 *         Other code Fails.
 */
DECL_EXPORT bm_status_t bm_dev_getcount(int *count);

/**
 * @name bm_dev_query
 * @brief To query if a device is present
 * @ingroup bmlib_runtime
 *
 * @param [in] devid The id of the device to query
 * @retval BM_SUCCESS Device is present
 *         Other code Device is not present
 */
DECL_EXPORT bm_status_t bm_dev_query(int devid);

/**
 * @name bm_dev_request
 * @brief To create a handle for the given device
 * @ingroup bmlib_runtime
 *
 * @param [out] handle The created handle
 * @param [in] devid Specify on which device to create the handle
 * @retval BM_SUCCESS Succeeds.
 *         Other code Fails.
 */
DECL_EXPORT bm_status_t bm_dev_request(bm_handle_t *handle, int devid);

/**
 * @name bm_get_devid
 * @brief To get the device index for the given handle
 * @ingroup bmlib_runtime
 *
 * @param [in] handle The given handle
 * @retval int device index that the handle points to.
 */
DECL_EXPORT int bm_get_devid(bm_handle_t handle);

/**
 * @name bm_dev_free
 * @brief To free a handle
 * @ingroup bmlib_runtime
 *
 * @param [in] handle The handle to free
 */
DECL_EXPORT void bm_dev_free(bm_handle_t handle);

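/*
 * Usage sketch for the handle functions above (illustrative):
 *
 *   int count = 0;
 *   bm_dev_getcount(&count);              // number of devices present
 *   bm_handle_t handle;
 *   if (count > 0 && bm_dev_request(&handle, 0) == BM_SUCCESS) {
 *     // ... use the device through this handle ...
 *     bm_dev_free(handle);
 *   }
 */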
/******************* memory help functions ************************************/
/**
 * @name bm_mem_get_type
 * @brief To get a memory descriptor's type
 * @ingroup bmlib_runtime
 *
 * @param [in] mem The memory descriptor queried
 * @retval BM_MEM_TYPE_DEVICE Device global memory
 * @retval BM_MEM_TYPE_SYSTEM Host user memory
 */
DECL_EXPORT bm_mem_type_t bm_mem_get_type(struct bm_mem_desc mem);

/**
 * @name sg_mem_get_type
 * @brief To get a memory descriptor's type
 * @ingroup bmlib_runtime
 *
 * @param [in] mem The memory descriptor queried
 * @retval BM_MEM_TYPE_DEVICE Device global memory
 * @retval BM_MEM_TYPE_SYSTEM Host user memory
 */
DECL_EXPORT bm_mem_type_t sg_mem_get_type(struct sg_mem_desc mem);

/**
 * @name bm_mem_get_device_addr
 * @brief To get a device memory descriptor's address
 * @ingroup bmlib_runtime
 *
 * @param [in] mem The device memory descriptor queried
 * @retval unsigned long long The device memory address
 */
DECL_EXPORT unsigned long long bm_mem_get_device_addr(struct bm_mem_desc mem);

/**
 * @name sg_mem_get_device_addr
 * @brief To get a device memory descriptor's address
 * @ingroup bmlib_runtime
 *
 * @param [in] mem The device memory descriptor queried
 * @retval unsigned long long The device memory address
 */
DECL_EXPORT unsigned long long sg_mem_get_device_addr(struct sg_mem_desc mem);

/**
 * @name bm_mem_set_device_addr
 * @brief To set a device memory descriptor's address
 * @ingroup bmlib_runtime
 *
 * @param [in] pmem The device memory descriptor pointer
 * @param [in] addr The new device address of the device memory
 */
DECL_EXPORT void bm_mem_set_device_addr(struct bm_mem_desc* pmem, unsigned long long addr);

/**
 * @name sg_mem_set_device_addr
 * @brief To set a device memory descriptor's address
 * @ingroup bmlib_runtime
 *
 * @param [in] pmem The device memory descriptor pointer
 * @param [in] addr The new device address of the device memory
 */
DECL_EXPORT void sg_mem_set_device_addr(struct sg_mem_desc* pmem, unsigned long long addr);

/**
 * @name bm_mem_get_device_size
 * @brief To get a device memory descriptor's size
 * @ingroup bmlib_runtime
 *
 * @param [in] mem The device memory descriptor queried
 * @retval unsigned int The device memory's size in bytes
 */
DECL_EXPORT unsigned int bm_mem_get_device_size(struct bm_mem_desc mem);

/**
 * @name sg_mem_get_device_size
 * @brief To get a device memory descriptor's size
 * @ingroup bmlib_runtime
 *
 * @param [in] mem The device memory descriptor queried
 * @retval unsigned long long The device memory's size in bytes
 */
DECL_EXPORT unsigned long long sg_mem_get_device_size(struct sg_mem_desc mem);

/**
 * @name bm_mem_set_device_size
 * @brief To set a device memory descriptor's size
 * @ingroup bmlib_runtime
 *
 * @param [out] pmem The device memory descriptor pointer
 * @param [in] size The new device memory size (in bytes) of the device memory
 */
DECL_EXPORT void bm_mem_set_device_size(struct bm_mem_desc* pmem, unsigned int size);

/**
 * @name sg_mem_set_device_size
 * @brief To set a device memory descriptor's size
 * @ingroup bmlib_runtime
 *
 * @param [out] pmem The device memory descriptor pointer
 * @param [in] size The new device memory size (in bytes) of the device memory
 */
DECL_EXPORT void sg_mem_set_device_size(struct sg_mem_desc* pmem, unsigned long long size);

/**
 * @name bm_set_device_mem
 * @brief To fill in a device memory descriptor with size and address
 * @ingroup bmlib_runtime
 *
 * @param [in] pmem The device memory descriptor pointer
 * @param [in] size The device memory descriptor's size
 * @param [in] addr The device memory descriptor's address
 */
DECL_EXPORT void bm_set_device_mem(bm_device_mem_t* pmem, unsigned int size,
                                   unsigned long long addr);

/**
 * @name sg_set_device_mem
 * @brief To fill in a device memory descriptor with size and address
 * @ingroup bmlib_runtime
 *
 * @param [in] pmem The device memory descriptor pointer
 * @param [in] size The device memory descriptor's size
 * @param [in] addr The device memory descriptor's address
 */
DECL_EXPORT void sg_set_device_mem(sg_device_mem_t* pmem, unsigned long long size,
                                   unsigned long long addr);

/**
 * @name bm_mem_from_device
 * @brief To create a device memory descriptor from address and size
 * @ingroup bmlib_runtime
 *
 * @param [in] device_addr The device memory address
 * @param [in] len The device memory size
 * @retval bm_device_mem_t The device memory descriptor created
 */
DECL_EXPORT bm_device_mem_t bm_mem_from_device(unsigned long long device_addr,
                                               unsigned int len);

/**
 * @name sg_mem_from_device
 * @brief To create a device memory descriptor from address and size
 * @ingroup bmlib_runtime
 *
 * @param [in] device_addr The device memory address
 * @param [in] len The device memory size
 * @retval sg_device_mem_t The device memory descriptor created
 */
DECL_EXPORT sg_device_mem_t sg_mem_from_device(unsigned long long device_addr,
                                               unsigned long long len);

/**
 * @name bm_mem_get_system_addr
 * @brief To get a system memory descriptor's address
 * @ingroup bmlib_runtime
 *
 * @param [in] mem The system memory descriptor
 * @retval void * The system memory descriptor's address
 */
DECL_EXPORT void *bm_mem_get_system_addr(struct bm_mem_desc mem);

/**
 * @name sg_mem_get_system_addr
 * @brief To get a system memory descriptor's address
 * @ingroup bmlib_runtime
 *
 * @param [in] mem The system memory descriptor
 * @retval void * The system memory descriptor's address
 */
DECL_EXPORT void *sg_mem_get_system_addr(struct sg_mem_desc mem);

/**
 * @name bm_mem_set_system_addr
 * @brief To set a system memory descriptor's address
 * @ingroup bmlib_runtime
 *
 * @param [in] pmem The system memory descriptor pointer
 * @param [in] addr The system memory address
 */
DECL_EXPORT void bm_mem_set_system_addr(struct bm_mem_desc* pmem, void *addr);

/**
 * @name sg_mem_set_system_addr
 * @brief To set a system memory descriptor's address
 * @ingroup bmlib_runtime
 *
 * @param [in] pmem The system memory descriptor pointer
 * @param [in] addr The system memory address
 */
DECL_EXPORT void sg_mem_set_system_addr(struct sg_mem_desc* pmem, void *addr);

/**
 * @name bm_mem_from_system
 * @brief To create a system memory descriptor with the given system address
 * @ingroup bmlib_runtime
 *
 * @param [in] system_addr The system address in the descriptor
 * @retval bm_system_mem_t The system memory descriptor created
 */
DECL_EXPORT bm_system_mem_t bm_mem_from_system(void *system_addr);

/******************* memory alloc and free functions **************************/
/**
 * @name bm_mem_null
 * @brief Return an illegal device memory descriptor
 * @ingroup bmlib_runtime
 *
 * @retval bm_device_mem_t An invalid device memory descriptor
 */
DECL_EXPORT bm_device_mem_t bm_mem_null(void);
#define BM_MEM_NULL (bm_mem_null())

/**
 * @name bm_malloc_neuron_device
 * @brief To malloc device memory according to a tensor shape
 *        (each neuron is 32 bits)
 * @ingroup bmlib_runtime
 *
 * @param [in] handle The device handle
 * @param [out] pmem The resulting device memory descriptor
 * @param [in] n, c, h, w The shape of the input tensor
 * @retval BM_SUCCESS Succeeds.
 *         Other code Fails.
 */
DECL_EXPORT bm_status_t bm_malloc_neuron_device(bm_handle_t handle, bm_device_mem_t *pmem,
                                                int n, int c, int h, int w);

/**
 * @name sg_malloc_neuron_device
 * @brief To malloc device memory according to a tensor shape
 *        (each neuron is 32 bits)
 * @ingroup bmlib_runtime
 *
 * @param [in] handle The device handle
 * @param [out] pmem The resulting device memory descriptor
 * @param [in] n, c, h, w The shape of the input tensor
 * @retval BM_SUCCESS Succeeds.
 *         Other code Fails.
 */
DECL_EXPORT bm_status_t sg_malloc_neuron_device(bm_handle_t handle, sg_device_mem_t *pmem,
                                                unsigned long long n, unsigned long long c,
                                                unsigned long long h, unsigned long long w);

/**
 * @name bm_malloc_device_dword
 * @brief To malloc device memory in size of dword (32 bits)
 * @ingroup bmlib_runtime
 *
 * @param [in] handle The device handle
 * @param [out] pmem The resulting device memory descriptor
 * @param [in] count The number of dwords (32 bits) to allocate
 * @retval BM_SUCCESS Succeeds.
 *         Other code Fails.
 */
DECL_EXPORT bm_status_t bm_malloc_device_dword(bm_handle_t handle, bm_device_mem_t *pmem,
                                               int count);

/**
 * @name sg_malloc_device_dword
 * @brief To malloc device memory in size of dword (32 bits)
 * @ingroup bmlib_runtime
 *
 * @param [in] handle The device handle
 * @param [out] pmem The resulting device memory descriptor
 * @param [in] count The number of dwords (32 bits) to allocate
 * @retval BM_SUCCESS Succeeds.
 *         Other code Fails.
 */
DECL_EXPORT bm_status_t sg_malloc_device_dword(bm_handle_t handle, sg_device_mem_t *pmem,
                                               unsigned long long count);

/**
 * @name bm_malloc_device_byte
 * @brief To malloc device memory in size of byte
 * @ingroup bmlib_runtime
 *
 * @param [in] handle The device handle
 * @param [out] pmem The resulting device memory descriptor
 * @param [in] size The number of bytes to allocate
 * @retval BM_SUCCESS Succeeds.
 *         Other code Fails.
 */
DECL_EXPORT bm_status_t bm_malloc_device_byte(bm_handle_t handle, bm_device_mem_t *pmem,
                                              unsigned int size);

|
691 |
+
/**
|
692 |
+
* @name sg_malloc_device_byte
|
693 |
+
* @brief To malloc device memory in size of byte
|
694 |
+
* @ingroup bmlib_runtime
|
695 |
+
*
|
696 |
+
* @param [in] handle The device handle
|
697 |
+
* @param [out] pmem The result device memory descriptor
|
698 |
+
* @param [in] size The number of bytes to allocate
|
699 |
+
* @retval BM_SUCCESS Succeeds.
|
700 |
+
* Other code Fails.
|
701 |
+
*/
|
702 |
+
DECL_EXPORT bm_status_t sg_malloc_device_byte(bm_handle_t handle, sg_device_mem_t *pmem,
|
703 |
+
unsigned long long size);
|
704 |
+
|
705 |
+
/**
|
706 |
+
* @name bm_malloc_device_byte_heap
|
707 |
+
* @brief To malloc device memory in size of byte within the specified heap
|
708 |
+
* @ingroup bmlib_runtime
|
709 |
+
*
|
710 |
+
* @param [in] handle The device handle
|
711 |
+
* @param [out] pmem The result device memory descriptor
|
712 |
+
* @param [in] heap_id The heap where to allocate 0/1/2
|
713 |
+
* @param [in] size The number of bytes to allocate
|
714 |
+
* @retval BM_SUCCESS Succeeds.
|
715 |
+
* Other code Fails.
|
716 |
+
*/
|
717 |
+
DECL_EXPORT bm_status_t bm_malloc_device_byte_heap(bm_handle_t handle, bm_device_mem_t *pmem,
|
718 |
+
int heap_id, unsigned int size);
|
719 |
+
|
720 |
+
/**
|
721 |
+
* @name sg_malloc_device_byte_heap
|
722 |
+
* @brief To malloc device memory in size of byte within the specified heap
|
723 |
+
* @ingroup bmlib_runtime
|
724 |
+
*
|
725 |
+
* @param [in] handle The device handle
|
726 |
+
* @param [out] pmem The result device memory descriptor
|
727 |
+
* @param [in] heap_id The heap where to allocate 0/1/2
|
728 |
+
* @param [in] size The number of bytes to allocate
|
729 |
+
* @retval BM_SUCCESS Succeeds.
|
730 |
+
* Other code Fails.
|
731 |
+
*/
|
732 |
+
DECL_EXPORT bm_status_t sg_malloc_device_byte_heap(bm_handle_t handle, sg_device_mem_t *pmem,
|
733 |
+
int heap_id, unsigned long long size);
|
734 |
+
|
735 |
+
/**
|
736 |
+
* @name bm_malloc_device_byte_heap_mask
|
737 |
+
* @brief To malloc device memory in size of byte within the specified heaps
|
738 |
+
* @ingroup bmlib_runtime
|
739 |
+
*
|
740 |
+
* @param [in] handle The device handle
|
741 |
+
* @param [out] pmem The result device memory descriptor
|
742 |
+
* @param [in] heap_id_mask The mask which heaps allocate from. each bit indicate one heap
|
743 |
+
* @param [in] size The number of bytes to allocate
|
744 |
+
* @retval BM_SUCCESS Succeeds.
|
745 |
+
* Other code Fails.
|
746 |
+
*/
|
747 |
+
DECL_EXPORT bm_status_t bm_malloc_device_byte_heap_mask(bm_handle_t handle, bm_device_mem_t *pmem,
|
748 |
+
int heap_id_mask, unsigned int size);
|
749 |
+
|
750 |
+
/**
|
751 |
+
* @name sg_malloc_device_byte_heap_mask
|
752 |
+
* @brief To malloc device memory in size of byte within the specified heaps
|
753 |
+
* @ingroup bmlib_runtime
|
754 |
+
*
|
755 |
+
* @param [in] handle The device handle
|
756 |
+
* @param [out] pmem The result device memory descriptor
|
757 |
+
* @param [in] heap_id_mask The mask which heaps allocate from. each bit indicate one heap
|
758 |
+
* @param [in] size The number of bytes to allocate
|
759 |
+
* @retval BM_SUCCESS Succeeds.
|
760 |
+
* Other code Fails.
|
761 |
+
*/
|
762 |
+
DECL_EXPORT bm_status_t sg_malloc_device_byte_heap_mask(bm_handle_t handle, sg_device_mem_t *pmem,
|
763 |
+
int heap_id_mask, unsigned long long size);
|
764 |
+
|
765 |
+
/**
|
766 |
+
* @name bm_free_device
|
767 |
+
* @brief To free device memory
|
768 |
+
* @ingroup bmlib_runtime
|
769 |
+
*
|
770 |
+
* @param [in] handle The device handle
|
771 |
+
* @param [in] mem The device memory descriptor to free
|
772 |
+
*/
|
773 |
+
DECL_EXPORT void bm_free_device(bm_handle_t handle, bm_device_mem_t mem);
|
774 |
+
|
775 |
+
/**
|
776 |
+
* @name sg_free_device
|
777 |
+
* @brief To free device memory
|
778 |
+
* @ingroup bmlib_runtime
|
779 |
+
*
|
780 |
+
* @param [in] handle The device handle
|
781 |
+
* @param [in] mem The device memory descriptor to free
|
782 |
+
*/
|
783 |
+
DECL_EXPORT void sg_free_device(bm_handle_t handle, sg_device_mem_t mem);
|
784 |
+
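
A short allocate/free sketch; `bm_dev_request` and `bm_dev_free` are the handle functions declared earlier in this header:

```c
#include <stdio.h>
#include "bmlib_runtime.h"

int main(void) {
    bm_handle_t handle;
    if (bm_dev_request(&handle, 0) != BM_SUCCESS)
        return 1;

    /* Allocate 4 MiB of device memory from the default heap, then free it. */
    bm_device_mem_t mem;
    if (bm_malloc_device_byte(handle, &mem, 4 * 1024 * 1024) == BM_SUCCESS)
        bm_free_device(handle, mem);
    else
        fprintf(stderr, "device malloc failed\n");

    bm_dev_free(handle);
    return 0;
}
```
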

/**
 * @name bm_gmem_arm_reserved_request
 * @brief To obtain the address of global memory reserved for arm926
 * @ingroup bmlib_runtime
 *
 * @param [in] handle The device handle
 * @retval unsigned long long The absolute address of gmem reserved for arm926
 */
DECL_EXPORT unsigned long long bm_gmem_arm_reserved_request(bm_handle_t handle);

/**
 * @name bm_gmem_arm_reserved_release
 * @brief To release the global memory reserved for arm926
 * @ingroup bmlib_runtime
 *
 * @param [in] handle The device handle
 */
DECL_EXPORT void bm_gmem_arm_reserved_release(bm_handle_t handle);

/*******************memory copy functions *************************************/
/**
 * @name bm_memcpy_s2d
 * @brief To copy data from system memory to device memory
 * @ingroup bmlib_runtime
 *
 * @param [in] handle The device handle
 * @param [in] dst The destination memory (device memory descriptor)
 * @param [in] src The source memory (system memory, a void* pointer)
 *
 * @retval BM_SUCCESS Succeeds.
 *         Other code Fails.
 */
DECL_EXPORT bm_status_t bm_memcpy_s2d(bm_handle_t handle, bm_device_mem_t dst, void *src);

/**
 * @name bm_memcpy_p2p
 * @brief To copy data from one chip to another chip
 * @ingroup bmlib_runtime
 *
 * @param [in] handle_src The source device handle
 * @param [in] src The source memory (device memory descriptor)
 * @param [in] handle_dst The destination device handle
 * @param [in] dst The destination memory (device memory descriptor)
 *
 * @retval BM_SUCCESS Succeeds.
 *         Other code Fails.
 */
DECL_EXPORT bm_status_t bm_memcpy_p2p(bm_handle_t handle_src, bm_device_mem_t src,
                                      bm_handle_t handle_dst, bm_device_mem_t dst);

/**
 * @name sg_memcpy_s2d
 * @brief To copy data from system memory to device memory
 * @ingroup bmlib_runtime
 *
 * @param [in] handle The device handle
 * @param [in] dst The destination memory (device memory descriptor)
 * @param [in] src The source memory (system memory, a void* pointer)
 *
 * @retval BM_SUCCESS Succeeds.
 *         Other code Fails.
 */
DECL_EXPORT bm_status_t sg_memcpy_s2d(bm_handle_t handle, sg_device_mem_t dst, void *src);

/**
 * @name bm_memcpy_s2d_partial_offset
 * @brief To copy specified bytes of data from system memory to device memory
 *        with an offset in device memory address.
 * @ingroup bmlib_runtime
 *
 * @param [in] handle The device handle
 * @param [in] dst The destination memory (device memory descriptor)
 * @param [in] src The source memory (system memory, a void* pointer)
 * @param [in] size The size of data to copy (in bytes)
 * @param [in] offset The offset of the device memory address
 *
 * @retval BM_SUCCESS Succeeds.
 *         Other code Fails.
 */
DECL_EXPORT bm_status_t bm_memcpy_s2d_partial_offset(bm_handle_t handle,
                                                     bm_device_mem_t dst, void *src,
                                                     unsigned int size,
                                                     unsigned int offset);

/**
 * @name sg_memcpy_s2d_partial_offset
 * @brief To copy specified bytes of data from system memory to device memory
 *        with an offset in device memory address.
 * @ingroup bmlib_runtime
 *
 * @param [in] handle The device handle
 * @param [in] dst The destination memory (device memory descriptor)
 * @param [in] src The source memory (system memory, a void* pointer)
 * @param [in] size The size of data to copy (in bytes)
 * @param [in] offset The offset of the device memory address
 *
 * @retval BM_SUCCESS Succeeds.
 *         Other code Fails.
 */
DECL_EXPORT bm_status_t sg_memcpy_s2d_partial_offset(bm_handle_t handle,
                                                     sg_device_mem_t dst, void *src,
                                                     unsigned long long size,
                                                     unsigned long long offset);

/**
 * @name bm_memcpy_s2d_partial
 * @brief To copy specified bytes of data from system memory to device memory
 * @ingroup bmlib_runtime
 *
 * @param [in] handle The device handle
 * @param [in] dst The destination memory (device memory descriptor)
 * @param [in] src The source memory (system memory, a void* pointer)
 * @param [in] size The size of data to copy (in bytes)
 *
 * @retval BM_SUCCESS Succeeds.
 *         Other code Fails.
 */
DECL_EXPORT bm_status_t bm_memcpy_s2d_partial(bm_handle_t handle, bm_device_mem_t dst,
                                              void *src, unsigned int size);

/**
 * @name sg_memcpy_s2d_partial
 * @brief To copy specified bytes of data from system memory to device memory
 * @ingroup bmlib_runtime
 *
 * @param [in] handle The device handle
 * @param [in] dst The destination memory (device memory descriptor)
 * @param [in] src The source memory (system memory, a void* pointer)
 * @param [in] size The size of data to copy (in bytes)
 *
 * @retval BM_SUCCESS Succeeds.
 *         Other code Fails.
 */
DECL_EXPORT bm_status_t sg_memcpy_s2d_partial(bm_handle_t handle, sg_device_mem_t dst,
                                              void *src, unsigned long long size);

/**
 * @name bm_memcpy_d2s
 * @brief To copy data from device memory to system memory
 * @ingroup bmlib_runtime
 *
 * @param [in] handle The device handle
 * @param [in] dst The destination memory (system memory, a void* pointer)
 * @param [in] src The source memory (device memory descriptor)
 *
 * @retval BM_SUCCESS Succeeds.
 *         Other code Fails.
 */
DECL_EXPORT bm_status_t bm_memcpy_d2s(bm_handle_t handle, void *dst, bm_device_mem_t src);

/**
 * @name sg_memcpy_d2s
 * @brief To copy data from device memory to system memory
 * @ingroup bmlib_runtime
 *
 * @param [in] handle The device handle
 * @param [in] dst The destination memory (system memory, a void* pointer)
 * @param [in] src The source memory (device memory descriptor)
 *
 * @retval BM_SUCCESS Succeeds.
 *         Other code Fails.
 */
DECL_EXPORT bm_status_t sg_memcpy_d2s(bm_handle_t handle, void *dst, sg_device_mem_t src);

/**
 * @name bm_memcpy_d2s_partial_offset
 * @brief To copy specified bytes of data from device memory to system memory
 *        with an offset in device memory address.
 * @ingroup bmlib_runtime
 *
 * @param [in] handle The device handle
 * @param [in] dst The destination memory (system memory, a void* pointer)
 * @param [in] src The source memory (device memory descriptor)
 * @param [in] size The size of data to copy (in bytes)
 * @param [in] offset The offset of the device memory address
 *
 * @retval BM_SUCCESS Succeeds.
 *         Other code Fails.
 */
DECL_EXPORT bm_status_t bm_memcpy_d2s_partial_offset(bm_handle_t handle, void *dst,
                                                     bm_device_mem_t src, unsigned int size,
                                                     unsigned int offset);

/**
 * @name sg_memcpy_d2s_partial_offset
 * @brief To copy specified bytes of data from device memory to system memory
 *        with an offset in device memory address.
 * @ingroup bmlib_runtime
 *
 * @param [in] handle The device handle
 * @param [in] dst The destination memory (system memory, a void* pointer)
 * @param [in] src The source memory (device memory descriptor)
 * @param [in] size The size of data to copy (in bytes)
 * @param [in] offset The offset of the device memory address
 *
 * @retval BM_SUCCESS Succeeds.
 *         Other code Fails.
 */
DECL_EXPORT bm_status_t sg_memcpy_d2s_partial_offset(bm_handle_t handle, void *dst,
                                                     sg_device_mem_t src, unsigned long long size,
                                                     unsigned long long offset);

/**
 * @name bm_memcpy_d2s_partial
 * @brief To copy specified bytes of data from device memory to system memory
 * @ingroup bmlib_runtime
 *
 * @param [in] handle The device handle
 * @param [in] dst The destination memory (system memory, a void* pointer)
 * @param [in] src The source memory (device memory descriptor)
 * @param [in] size The size of data to copy (in bytes)
 *
 * @retval BM_SUCCESS Data transfer succeeds.
 *         Other code Data transfer fails.
 */
DECL_EXPORT bm_status_t bm_memcpy_d2s_partial(bm_handle_t handle, void *dst,
                                              bm_device_mem_t src, unsigned int size);

/**
 * @name sg_memcpy_d2s_partial
 * @brief To copy specified bytes of data from device memory to system memory
 * @ingroup bmlib_runtime
 *
 * @param [in] handle The device handle
 * @param [in] dst The destination memory (system memory, a void* pointer)
 * @param [in] src The source memory (device memory descriptor)
 * @param [in] size The size of data to copy (in bytes)
 *
 * @retval BM_SUCCESS Data transfer succeeds.
 *         Other code Data transfer fails.
 */
DECL_EXPORT bm_status_t sg_memcpy_d2s_partial(bm_handle_t handle, void *dst,
                                              sg_device_mem_t src, unsigned long long size);
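
A round-trip sketch combining the copy functions above with the byte allocator (host to device, back to host, then verify):

```c
#include <string.h>
#include "bmlib_runtime.h"

/* Round-trip sketch: host -> device -> host, then verify the data. */
int roundtrip(bm_handle_t handle) {
    int src[256], dst[256] = {0};
    for (int i = 0; i < 256; ++i) src[i] = i;

    bm_device_mem_t dev;
    if (bm_malloc_device_byte(handle, &dev, sizeof(src)) != BM_SUCCESS)
        return -1;

    int ok = bm_memcpy_s2d(handle, dev, src) == BM_SUCCESS &&
             bm_memcpy_d2s(handle, dst, dev) == BM_SUCCESS &&
             memcmp(src, dst, sizeof(src)) == 0;

    bm_free_device(handle, dev);
    return ok ? 0 : -1;
}
```
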

/**
 * @name bm_memcpy_d2d
 * @brief To copy specified dwords of data from one piece of device memory
 *        to another piece of device memory within one device. Both source
 *        and destination offsets can be specified.
 * @ingroup bmlib_runtime
 *
 * @param [in] handle The device handle
 * @param [in] dst The destination device memory
 * @param [in] dst_offset The offset of destination device memory address
 * @param [in] src The source device memory
 * @param [in] src_offset The offset of source device memory address
 * @param [in] len Length of data to copy (in DWORDs, 4 bytes each)
 *
 * @retval BM_SUCCESS Succeeds.
 *         Other code Fails.
 */
DECL_EXPORT bm_status_t bm_memcpy_d2d(bm_handle_t handle, bm_device_mem_t dst,
                                      int dst_offset, bm_device_mem_t src, int src_offset,
                                      int len);

/**
 * @name bm_memcpy_d2d_with_core
 * @brief To copy specified dwords of data from one piece of device memory
 *        to another piece of device memory within one device. Both source
 *        and destination offsets can be specified.
 * @ingroup bmlib_runtime
 *
 * @param [in] handle The device handle
 * @param [in] dst The destination device memory
 * @param [in] dst_offset The offset of destination device memory address
 * @param [in] src The source device memory
 * @param [in] src_offset The offset of source device memory address
 * @param [in] len Length of data to copy (in DWORDs, 4 bytes each)
 * @param [in] core_id The core id to copy
 *
 * @retval BM_SUCCESS Succeeds.
 *         Other code Fails.
 */
DECL_EXPORT bm_status_t bm_memcpy_d2d_with_core(bm_handle_t handle, bm_device_mem_t dst,
                                                int dst_offset, bm_device_mem_t src, int src_offset,
                                                int len, int core_id);

/**
 * @name bm_memcpy_d2d_byte
 * @brief To copy specified bytes of data from one piece of device memory
 *        to another piece of device memory within one device. Both source
 *        and destination offsets can be specified.
 * @ingroup bmlib_runtime
 *
 * @param [in] handle The device handle
 * @param [in] dst The destination device memory
 * @param [in] dst_offset The offset of destination device memory address (in bytes)
 * @param [in] src The source device memory
 * @param [in] src_offset The offset of source device memory address (in bytes)
 * @param [in] size Size of data to copy (in bytes)
 *
 * @retval BM_SUCCESS Succeeds.
 *         Other code Fails.
 */
DECL_EXPORT bm_status_t bm_memcpy_d2d_byte(bm_handle_t handle, bm_device_mem_t dst,
                                           size_t dst_offset, bm_device_mem_t src,
                                           size_t src_offset, size_t size);

/**
 * @name bm_memcpy_d2d_byte_with_core
 * @brief To copy specified bytes of data from one piece of device memory
 *        to another piece of device memory within one device. Both source
 *        and destination offsets can be specified.
 * @ingroup bmlib_runtime
 *
 * @param [in] handle The device handle
 * @param [in] dst The destination device memory
 * @param [in] dst_offset The offset of destination device memory address (in bytes)
 * @param [in] src The source device memory
 * @param [in] src_offset The offset of source device memory address (in bytes)
 * @param [in] size Size of data to copy (in bytes)
 * @param [in] core_id The core id to copy
 *
 * @retval BM_SUCCESS Succeeds.
 *         Other code Fails.
 */
DECL_EXPORT bm_status_t bm_memcpy_d2d_byte_with_core(bm_handle_t handle, bm_device_mem_t dst,
                                                     size_t dst_offset, bm_device_mem_t src,
                                                     size_t src_offset, size_t size, int core_id);

/**
 * @name bm_memcpy_d2d_stride
 * @brief To copy specified data from one piece of device memory
 *        to another piece of device memory within one device. Both source
 *        and destination strides can be specified.
 * @ingroup bmlib_runtime
 *
 * @param [in] handle The device handle
 * @param [in] dst The destination device memory
 * @param [in] dst_stride The data stride of destination data
 * @param [in] src The source device memory
 * @param [in] src_stride The data stride of source data
 * @param [in] count Count of data to copy
 * @param [in] format_size Data format byte size, such as sizeof(uint8_t), sizeof(float), etc.
 *             format_size only supports 1/2/4.
 *
 * dst_stride MUST be 1, EXCEPT: dst_stride == 4 && src_stride == 1 && format_size == 1
 *
 * @retval BM_SUCCESS Succeeds.
 *         Other code Fails.
 */
DECL_EXPORT bm_status_t bm_memcpy_d2d_stride(bm_handle_t handle,
                                             bm_device_mem_t dst,
                                             int dst_stride,
                                             bm_device_mem_t src,
                                             int src_stride,
                                             int count,
                                             int format_size);

/**
 * @name bm_memcpy_d2d_stride_with_core
 * @brief To copy specified data from one piece of device memory
 *        to another piece of device memory within one device. Both source
 *        and destination strides can be specified.
 * @ingroup bmlib_runtime
 *
 * @param [in] handle The device handle
 * @param [in] dst The destination device memory
 * @param [in] dst_stride The data stride of destination data
 * @param [in] src The source device memory
 * @param [in] src_stride The data stride of source data
 * @param [in] count Count of data to copy
 * @param [in] format_size Data format byte size, such as sizeof(uint8_t), sizeof(float), etc.
 *             format_size only supports 1/2/4.
 * @param [in] core_id The core id to copy.
 *
 * dst_stride MUST be 1, EXCEPT: dst_stride == 4 && src_stride == 1 && format_size == 1
 *
 * @retval BM_SUCCESS Succeeds.
 *         Other code Fails.
 */
DECL_EXPORT bm_status_t bm_memcpy_d2d_stride_with_core(bm_handle_t handle,
                                                       bm_device_mem_t dst,
                                                       int dst_stride,
                                                       bm_device_mem_t src,
                                                       int src_stride,
                                                       int count,
                                                       int format_size,
                                                       int core_id);

/**
 * @name bm_memcpy_c2c
 * @brief To copy data from one chip to another chip.
 *        (Used in multi-chip card scenario)
 * @ingroup bmlib_runtime
 *
 * @param [in] src_handle The source device handle
 * @param [in] dst_handle The destination device handle
 * @param [in] src The source device memory descriptor
 * @param [in] dst The destination device memory descriptor
 * @param [in] force_dst_cdma Whether to use the CDMA engine of the destination device
 * @retval BM_SUCCESS Succeeds.
 *         Other code Fails.
 */
DECL_EXPORT bm_status_t bm_memcpy_c2c(bm_handle_t src_handle, bm_handle_t dst_handle,
                                      bm_device_mem_t src, bm_device_mem_t dst,
                                      bool force_dst_cdma);
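
A small sketch of a pure device-side copy with `bm_memcpy_d2d_byte`, staying entirely in device memory:

```c
#include <stddef.h>
#include "bmlib_runtime.h"

/* Device-to-device sketch: duplicate the first `n_bytes` of `src` into
 * `dst` without staging through host memory. */
bm_status_t copy_on_device(bm_handle_t handle,
                           bm_device_mem_t dst, bm_device_mem_t src,
                           size_t n_bytes) {
    /* dst_offset = 0, src_offset = 0, both in bytes */
    return bm_memcpy_d2d_byte(handle, dst, 0, src, 0, n_bytes);
}
```
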

/**
 * @name bm_memset_device
 * @brief To fill the specified device memory with the given value
 * @ingroup bmlib_runtime
 *
 * @param [in] handle The device handle
 * @param [in] value The value used to fill (int type)
 * @param [in] mem The device memory which will be filled
 * @retval BM_SUCCESS Succeeds.
 *         Other code Fails.
 */
DECL_EXPORT bm_status_t bm_memset_device(bm_handle_t handle, const int value,
                                         bm_device_mem_t mem);

/**
 * @name bm_memset_device_ext
 * @brief To fill the specified device memory with the given value and mode
 * @ingroup bmlib_runtime
 *
 * @param [in] handle The device handle
 * @param [in] value The pointer to the value used to fill
 * @param [in] mode The valid bytes of *value
 * @param [in] mem The device memory which will be filled
 * @retval BM_SUCCESS Succeeds.
 *         Other code Fails.
 */
DECL_EXPORT bm_status_t bm_memset_device_ext(bm_handle_t handle, void* value, int mode,
                                             bm_device_mem_t mem);
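
For example, zero-filling a buffer (a sketch; per the description above, the 32-bit `value` pattern is replicated across the descriptor):

```c
#include "bmlib_runtime.h"

/* Zero-fill a device buffer: the 32-bit `value` pattern is replicated
 * across the whole descriptor. */
bm_status_t zero_device_buffer(bm_handle_t handle, bm_device_mem_t mem) {
    return bm_memset_device(handle, 0, mem);
}
```
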

/**
 * @name bm_mem_convert_system_to_device_neuron
 * @brief To malloc a piece of device memory according to the shape of
 *        neuron (in DWORDs, 4 bytes each); copy the neuron from system memory
 *        to device memory if need_copy is true.
 * @ingroup bmlib_runtime
 *
 * @param [in] handle The device handle
 * @param [in] dev_mem The device memory descriptor
 * @param [in] sys_mem The system memory descriptor
 * @param [in] need_copy Whether copying from system to device is needed
 * @param [in] n, c, h, w Neuron shape size
 *
 * @retval BM_SUCCESS Succeeds.
 *         Other code Fails.
 */
DECL_EXPORT bm_status_t bm_mem_convert_system_to_device_neuron(bm_handle_t handle,
                                                               struct bm_mem_desc *dev_mem,
                                                               struct bm_mem_desc sys_mem,
                                                               bool need_copy, int n, int c,
                                                               int h, int w);

/**
 * @name bm_mem_convert_system_to_device_neuron_byte
 * @brief To malloc a piece of device memory according to the shape of
 *        neuron (in bytes); copy the neuron from system memory to
 *        device memory if need_copy is true.
 * @ingroup bmlib_runtime
 *
 * @param [in] handle The device handle
 * @param [in] dev_mem The device memory descriptor
 * @param [in] sys_mem The system memory descriptor
 * @param [in] need_copy Whether copying from system to device is needed
 * @param [in] n, c, h, w Neuron shape size
 *
 * @retval BM_SUCCESS Succeeds.
 *         Other code Fails.
 */
DECL_EXPORT bm_status_t bm_mem_convert_system_to_device_neuron_byte(
    bm_handle_t handle, struct bm_mem_desc *dev_mem, struct bm_mem_desc sys_mem,
    bool need_copy, int n, int c, int h, int w);

/**
 * @name bm_mem_convert_system_to_device_coeff
 * @brief To malloc a piece of device memory according to the size of
 *        the coefficient (in DWORDs, 4 bytes each); copy the coefficient from
 *        system memory to device memory if need_copy is true.
 * @ingroup bmlib_runtime
 *
 * @param [in] handle The device handle
 * @param [in] dev_mem The device memory descriptor
 * @param [in] sys_mem The system memory descriptor
 * @param [in] need_copy Whether copying from system to device is needed
 * @param [in] coeff_count Coefficient size
 *
 * @retval BM_SUCCESS Succeeds.
 *         Other code Fails.
 */
DECL_EXPORT bm_status_t bm_mem_convert_system_to_device_coeff(bm_handle_t handle,
                                                              struct bm_mem_desc *dev_mem,
                                                              struct bm_mem_desc sys_mem,
                                                              bool need_copy,
                                                              int coeff_count);

/**
 * @name bm_mem_convert_system_to_device_coeff_byte
 * @brief To malloc a piece of device memory according to the size of
 *        the coefficient (in bytes); copy the coefficient from system
 *        memory to device memory if need_copy is true.
 * @ingroup bmlib_runtime
 *
 * @param [in] handle The device handle
 * @param [in] dev_mem The device memory descriptor
 * @param [in] sys_mem The system memory descriptor
 * @param [in] need_copy Whether copying from system to device is needed
 * @param [in] coeff_count Coefficient size
 *
 * @retval BM_SUCCESS Succeeds.
 *         Other code Fails.
 */
DECL_EXPORT bm_status_t bm_mem_convert_system_to_device_coeff_byte(
    bm_handle_t handle, struct bm_mem_desc *dev_mem, struct bm_mem_desc sys_mem,
    bool need_copy, int coeff_count);

/*******************memory map functions *************************************/
/**
 * @name bm_mem_mmap_device_mem
 * @brief To map a piece of device memory to user space with cache enabled.
 *        (only valid in SoC mode; Not supported in PCIE mode).
 * @ingroup bmlib_runtime
 *
 * @param [in] handle The device handle
 * @param [in] dmem The device memory to map
 * @param [out] vmem The virtual address of the mapped device memory
 *
 * @retval BM_SUCCESS Succeeds.
 *         Other code Fails.
 */
DECL_EXPORT bm_status_t bm_mem_mmap_device_mem(bm_handle_t handle, bm_device_mem_t *dmem,
                                               unsigned long long *vmem);

/**
 * @name sg_mem_mmap_device_mem
 * @brief To map a piece of device memory to user space with cache enabled.
 *        (only valid in SoC mode; Not supported in PCIE mode).
 * @ingroup bmlib_runtime
 *
 * @param [in] handle The device handle
 * @param [in] dmem The device memory to map
 * @param [out] vmem The virtual address of the mapped device memory
 *
 * @retval BM_SUCCESS Succeeds.
 *         Other code Fails.
 */
DECL_EXPORT bm_status_t sg_mem_mmap_device_mem(bm_handle_t handle, sg_device_mem_t *dmem,
                                               unsigned long long *vmem);

/**
 * @name bm_mem_mmap_device_mem_no_cache
 * @brief To map a piece of device memory to user space with cache disabled.
 *        (only valid in SoC mode; Not supported in PCIE mode).
 * @ingroup bmlib_runtime
 *
 * @param [in] handle The device handle
 * @param [in] dmem The device memory to map
 * @param [out] vmem The virtual address of the mapped device memory
 *
 * @retval BM_SUCCESS Succeeds.
 *         Other code Fails.
 */
DECL_EXPORT bm_status_t bm_mem_mmap_device_mem_no_cache(bm_handle_t handle, bm_device_mem_t *dmem,
                                                        unsigned long long *vmem);

/**
 * @name sg_mem_mmap_device_mem_no_cache
 * @brief To map a piece of device memory to user space with cache disabled.
 *        (only valid in SoC mode; Not supported in PCIE mode).
 * @ingroup bmlib_runtime
 *
 * @param [in] handle The device handle
 * @param [in] dmem The device memory to map
 * @param [out] vmem The virtual address of the mapped device memory
 *
 * @retval BM_SUCCESS Succeeds.
 *         Other code Fails.
 */
DECL_EXPORT bm_status_t sg_mem_mmap_device_mem_no_cache(bm_handle_t handle, sg_device_mem_t *dmem,
                                                        unsigned long long *vmem);

/**
 * @name bm_mem_vir_to_phy
 * @brief To get the device memory address through the mapped virtual address.
 *        (only valid in SoC mode; Not supported in PCIE mode).
 * @ingroup bmlib_runtime
 *
 * @param [in] handle The device handle
 * @param [in] vmem The virtual address of the mapped device memory
 * @param [out] device_mem The device memory address
 *
 * @retval BM_SUCCESS Succeeds.
 *         Other code Fails.
 */
DECL_EXPORT bm_status_t bm_mem_vir_to_phy(bm_handle_t handle, unsigned long long vmem,
                                          unsigned long long *device_mem);
/**
 * @name bm_mem_invalidate_device_mem
 * @brief To invalidate a piece of mapped device memory to maintain
 *        cache coherence
 *        (only valid in SoC mode; Not supported in PCIE mode).
 * @ingroup bmlib_runtime
 *
 * @param [in] handle The device handle
 * @param [in] dmem The device memory to invalidate
 *
 * @retval BM_SUCCESS Succeeds.
 *         Other code Fails.
 */
DECL_EXPORT bm_status_t bm_mem_invalidate_device_mem(bm_handle_t handle,
                                                     bm_device_mem_t *dmem);

/**
 * @name sg_mem_invalidate_device_mem
 * @brief To invalidate a piece of mapped device memory to maintain
 *        cache coherence
 *        (only valid in SoC mode; Not supported in PCIE mode).
 * @ingroup bmlib_runtime
 *
 * @param [in] handle The device handle
 * @param [in] dmem The device memory to invalidate
 *
 * @retval BM_SUCCESS Succeeds.
 *         Other code Fails.
 */
DECL_EXPORT bm_status_t sg_mem_invalidate_device_mem(bm_handle_t handle,
                                                     sg_device_mem_t *dmem);

/**
 * @name bm_mem_invalidate_partial_device_mem
 * @brief To invalidate part of mapped device memory to maintain
 *        cache coherence
 *        (only valid in SoC mode; Not supported in PCIE mode).
 * @ingroup bmlib_runtime
 *
 * @param [in] handle The device handle
 * @param [in] dmem The device memory to invalidate
 * @param [in] offset The offset of device memory address
 * @param [in] len The length of memory to invalidate in bytes
 *
 * @retval BM_SUCCESS Succeeds.
 *         Other code Fails.
 */
DECL_EXPORT bm_status_t bm_mem_invalidate_partial_device_mem(bm_handle_t handle,
                                                             bm_device_mem_t *dmem,
                                                             unsigned int offset,
                                                             unsigned int len);

/**
 * @name sg_mem_invalidate_partial_device_mem
 * @brief To invalidate part of mapped device memory to maintain
 *        cache coherence
 *        (only valid in SoC mode; Not supported in PCIE mode).
 * @ingroup bmlib_runtime
 *
 * @param [in] handle The device handle
 * @param [in] dmem The device memory to invalidate
 * @param [in] offset The offset of device memory address
 * @param [in] len The length of memory to invalidate in bytes
 *
 * @retval BM_SUCCESS Succeeds.
 *         Other code Fails.
 */
DECL_EXPORT bm_status_t sg_mem_invalidate_partial_device_mem(bm_handle_t handle,
                                                             sg_device_mem_t *dmem,
                                                             unsigned long long offset,
                                                             unsigned long long len);

/**
 * @name bm_mem_flush_device_mem
 * @brief To flush a piece of mapped device memory to maintain
 *        cache coherence
 *        (only valid in SoC mode; Not supported in PCIE mode).
 * @ingroup bmlib_runtime
 *
 * @param [in] handle The device handle
 * @param [in] dmem The device memory to flush
 *
 * @retval BM_SUCCESS Succeeds.
 *         Other code Fails.
 */
DECL_EXPORT bm_status_t bm_mem_flush_device_mem(bm_handle_t handle, bm_device_mem_t *dmem);

/**
 * @name sg_mem_flush_device_mem
 * @brief To flush a piece of mapped device memory to maintain
 *        cache coherence
 *        (only valid in SoC mode; Not supported in PCIE mode).
 * @ingroup bmlib_runtime
 *
 * @param [in] handle The device handle
 * @param [in] dmem The device memory to flush
 *
 * @retval BM_SUCCESS Succeeds.
 *         Other code Fails.
 */
DECL_EXPORT bm_status_t sg_mem_flush_device_mem(bm_handle_t handle, sg_device_mem_t *dmem);

/**
 * @name bm_mem_flush_partial_device_mem
 * @brief To flush part of mapped device memory to maintain
 *        cache coherence
 *        (only valid in SoC mode; Not supported in PCIE mode).
 * @ingroup bmlib_runtime
 *
 * @param [in] handle The device handle
 * @param [in] dmem The device memory to flush
 * @param [in] offset The offset of device memory address
 * @param [in] len The length of memory to flush in bytes
 *
 * @retval BM_SUCCESS Succeeds.
 *         Other code Fails.
 */
DECL_EXPORT bm_status_t bm_mem_flush_partial_device_mem(bm_handle_t handle,
                                                        bm_device_mem_t *dmem,
                                                        unsigned int offset,
                                                        unsigned int len);

/**
 * @name sg_mem_flush_partial_device_mem
 * @brief To flush part of mapped device memory to maintain
 *        cache coherence
 *        (only valid in SoC mode; Not supported in PCIE mode).
 * @ingroup bmlib_runtime
 *
 * @param [in] handle The device handle
 * @param [in] dmem The device memory to flush
 * @param [in] offset The offset of device memory address
 * @param [in] len The length of memory to flush in bytes
 *
 * @retval BM_SUCCESS Succeeds.
 *         Other code Fails.
 */
DECL_EXPORT bm_status_t sg_mem_flush_partial_device_mem(bm_handle_t handle,
                                                        sg_device_mem_t *dmem,
                                                        unsigned long long offset,
                                                        unsigned long long len);

/**
 * @name bm_mem_unmap_device_mem
 * @brief To unmap a piece of mapped device memory
 *        (only valid in SoC mode; Not supported in PCIE mode).
 * @ingroup bmlib_runtime
 *
 * @param [in] handle The device handle
 * @param [in] vmem The virtual address of the mapped device memory
 * @param [in] size The size of the memory to unmap
 *
 * @retval BM_SUCCESS Succeeds.
 *         Other code Fails.
 */
DECL_EXPORT bm_status_t bm_mem_unmap_device_mem(bm_handle_t handle, void *vmem, int size);

/**
 * @name sg_mem_unmap_device_mem
 * @brief To unmap a piece of mapped device memory
 *        (only valid in SoC mode; Not supported in PCIE mode).
 * @ingroup bmlib_runtime
 *
 * @param [in] handle The device handle
 * @param [in] vmem The virtual address of the mapped device memory
 * @param [in] size The size of the memory to unmap
 *
 * @retval BM_SUCCESS Succeeds.
 *         Other code Fails.
 */
DECL_EXPORT bm_status_t sg_mem_unmap_device_mem(bm_handle_t handle, void *vmem, unsigned long long size);
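
A SoC-mode sketch tying the map/flush/unmap calls together; the `uintptr_t` cast of the returned address is an assumption about how the mapping is meant to be consumed:

```c
#include <stdint.h>
#include <string.h>
#include "bmlib_runtime.h"

/* SoC mode only: map a device buffer, write through the cached mapping,
 * flush so the device sees the data, then unmap. */
bm_status_t fill_via_mmap(bm_handle_t handle, bm_device_mem_t *dmem,
                          const void *data, unsigned int size) {
    unsigned long long vmem = 0;
    bm_status_t ret = bm_mem_mmap_device_mem(handle, dmem, &vmem);
    if (ret != BM_SUCCESS)
        return ret;

    memcpy((void *)(uintptr_t)vmem, data, size);   /* write via CPU cache */
    ret = bm_mem_flush_device_mem(handle, dmem);   /* write back to device */

    bm_mem_unmap_device_mem(handle, (void *)(uintptr_t)vmem, (int)size);
    return ret;
}
```
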

/*******************api(kernel) functions *************************************/
/**
 * @name bm_flush
 * @brief To synchronize APIs of the current thread. The thread will block
 *        until all the outstanding APIs of the current thread are finished.
 * @ingroup bmlib_runtime
 *
 * @param [in] handle The device handle
 */
DECL_EXPORT void bm_flush(bm_handle_t handle);

/**
 * @name bm_device_sync
 * @brief To synchronize APIs of the device. The thread will block
 *        until all the outstanding APIs of the device are finished.
 * @ingroup bmlib_runtime
 *
 * @param [in] handle The device handle
 * @retval BM_SUCCESS Succeeds.
 *         Other code Fails.
 */
DECL_EXPORT bm_status_t bm_device_sync(bm_handle_t handle);

/**
 * @name bm_handle_sync
 * @brief To synchronize APIs of the handle. The thread will block
 *        until all the outstanding APIs of the handle are finished.
 * @ingroup bmlib_runtime
 *
 * @param [in] handle The device handle
 * @retval BM_SUCCESS Succeeds.
 *         Other code Fails.
 */
DECL_EXPORT bm_status_t bm_handle_sync(bm_handle_t handle);

/**
 * @name bm_handle_sync_from_core
 * @brief To synchronize APIs of the handle. The thread will block
 *        until all the outstanding APIs of the handle are finished.
 * @ingroup bmlib_runtime
 *
 * @param [in] handle The device handle
 * @param [in] core_id The core id
 * @retval BM_SUCCESS Succeeds.
 *         Other code Fails.
 */
DECL_EXPORT bm_status_t bm_handle_sync_from_core(bm_handle_t handle, int core_id);

/**
 * @name bm_thread_sync
 * @brief To synchronize APIs of the current thread. The thread will block
 *        until all the outstanding APIs of the current thread are finished.
 * @ingroup bmlib_runtime
 *
 * @param [in] handle The device handle
 * @retval BM_SUCCESS Succeeds.
 *         Other code Fails.
 */
DECL_EXPORT bm_status_t bm_thread_sync(bm_handle_t handle);

/**
 * @name bm_thread_sync_from_core
 * @brief To synchronize APIs of the current thread. The thread will block
 *        until all the outstanding APIs of the current thread are finished.
 * @ingroup bmlib_runtime
 *
 * @param [in] handle The device handle
 * @param [in] core_id The core id
 * @retval BM_SUCCESS Succeeds.
 *         Other code Fails.
 */
DECL_EXPORT bm_status_t bm_thread_sync_from_core(bm_handle_t handle, int core_id);

/*******************trace and profile related functions ***********************/
typedef struct bm_profile {
#ifdef __linux__
  unsigned long cdma_in_time;
  unsigned long cdma_in_counter;
  unsigned long cdma_out_time;
  unsigned long cdma_out_counter;
  unsigned long tpu_process_time;
  unsigned long tpu1_process_time;
  unsigned long sent_api_counter;
  unsigned long completed_api_counter;
#else
  unsigned long long cdma_in_time;
  unsigned long long cdma_in_counter;
  unsigned long long cdma_out_time;
  unsigned long long cdma_out_counter;
  unsigned long long tpu_process_time;
  unsigned long long tpu1_process_time;
  unsigned long long sent_api_counter;
  unsigned long long completed_api_counter;
#endif
} bm_profile_t;
/**
 * @name bm_get_profile
 * @brief To get the profile data at the moment
 * @ingroup bmlib_runtime
 *
 * @param [in] handle The device handle
 * @param [out] profile The result profile data
 * @retval BM_SUCCESS Succeeds.
 *         Other code Fails.
 */
DECL_EXPORT bm_status_t bm_get_profile(bm_handle_t handle, bm_profile_t *profile);
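
A sketch reading the API counters from a `bm_profile_t` snapshot:

```c
#include <stdio.h>
#include "bmlib_runtime.h"

/* Print the API counters from the current profile snapshot. */
void dump_profile(bm_handle_t handle) {
    bm_profile_t p;
    if (bm_get_profile(handle, &p) == BM_SUCCESS)
        printf("APIs sent: %llu, completed: %llu\n",
               (unsigned long long)p.sent_api_counter,
               (unsigned long long)p.completed_api_counter);
}
```
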

typedef struct bootloader_version {
  char *bl1_version;
  char *bl2_version;
  char *bl31_version;
  char *uboot_version;
} boot_loader_version;

/**
 * @name bm_get_boot_loader_version
 * @brief To get the boot loader version
 * @ingroup bmlib_runtime
 *
 * @param [in] handle The device handle
 * @param [out] version The result version data
 * @retval BM_SUCCESS Succeeds.
 *         Other code Fails.
 */
DECL_EXPORT bm_status_t bm_get_boot_loader_version(bm_handle_t handle, boot_loader_version *version);

/**
 * @name bm_get_vpu_instant_usage
 * @brief To get the vpu usage
 * @ingroup bmlib_runtime
 *
 * @param [in] handle The device handle
 * @param [out] vpu_usage The result vpu usage
 * @retval BM_SUCCESS Succeeds.
 *         Other code Fails.
 */
DECL_EXPORT bm_status_t bm_get_vpu_instant_usage(bm_handle_t handle, int *vpu_usage);

/**
 * @name bm_get_jpu_core_usage
 * @brief To get the jpu usage
 * @ingroup bmlib_runtime
 *
 * @param [in] handle The device handle
 * @param [out] jpu_usage The result jpu usage
 * @retval BM_SUCCESS Succeeds.
 *         Other code Fails.
 */
DECL_EXPORT bm_status_t bm_get_jpu_core_usage(bm_handle_t handle, int *jpu_usage);

/**
 * @name bm_get_vpp_instant_usage
 * @brief To get the vpp usage
 * @ingroup bmlib_runtime
 *
 * @param [in] handle The device handle
 * @param [out] vpp_usage The result vpp usage
 * @retval BM_SUCCESS Succeeds.
 *         Other code Fails.
 */
DECL_EXPORT bm_status_t bm_get_vpp_instant_usage(bm_handle_t handle, int *vpp_usage);
/**
 * @name bm_get_last_api_process_time_us
 * @brief This function is deprecated.
 */
#ifdef __linux__
DECL_EXPORT bm_status_t bm_get_last_api_process_time_us(bm_handle_t handle,
                                                        unsigned long *time_us);
#else
DECL_EXPORT bm_status_t bm_get_last_api_process_time_us(bm_handle_t handle,
                                                        unsigned long long *time_us);
#endif
/*******************tpu clock and module reset related functions **************/

/**
 * @name bm_set_clk_tpu_freq
 * @brief To set the clock frequency of TPU (only valid in PCIE mode).
 * @ingroup bmlib_runtime
 *
 * @param [in] handle The device handle
 * @param [in] freq The TPU target frequency
 * @retval BM_SUCCESS Succeeds.
 *         Other code Fails.
 */
DECL_EXPORT bm_status_t bm_set_clk_tpu_freq(bm_handle_t handle, int freq);

/**
 * @name bm_get_clk_tpu_freq
 * @brief To get the clock frequency of TPU
 * @ingroup bmlib_runtime
 *
 * @param [in] handle The device handle
 * @param [out] freq The current TPU frequency
 * @retval BM_SUCCESS Succeeds.
 *         Other code Fails.
 */
DECL_EXPORT bm_status_t bm_get_clk_tpu_freq(bm_handle_t handle, int *freq);
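
A sketch reading the current TPU clock; the unit of `freq` is not stated here, so it is printed raw:

```c
#include <stdio.h>
#include "bmlib_runtime.h"

/* Read the current TPU clock; setting it (bm_set_clk_tpu_freq) is
 * PCIE-mode only per the note above. */
void show_tpu_freq(bm_handle_t handle) {
    int freq = 0;
    if (bm_get_clk_tpu_freq(handle, &freq) == BM_SUCCESS)
        printf("TPU clock: %d\n", freq);
}
```
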

/*******************misc functions ********************************************/
struct bm_misc_info {
  int pcie_soc_mode;  /* 0---pcie; 1---soc */
  int ddr_ecc_enable; /* 0---disable; 1---enable */
  long long ddr0a_size;
  long long ddr0b_size;
  long long ddr1_size;
  long long ddr2_size;
  unsigned int chipid;
#define BM1682_CHIPID_BIT_MASK (0X1 << 0)
#define BM1684_CHIPID_BIT_MASK (0X1 << 1)
#define BM1686_CHIPID_BIT_MASK (0X1 << 2)
#ifdef __linux__
  unsigned long chipid_bit_mask;
#else
  unsigned long long chipid_bit_mask;
#endif
  unsigned int driver_version;
  int domain_bdf;
  int board_version; /* hardware board version: [23:16]-mcu sw version, [15:8]-board type, [7:0]-hw version */
  int a53_enable;
  int dyn_enable;
};

/**
 * @name bm_get_misc_info
 * @brief To get miscellaneous information of the device
 * @ingroup bmlib_runtime
 *
 * @param [in] handle The device handle
 * @param [out] pmisc_info The fetched misc info
 * @retval BM_SUCCESS Succeeds.
 *         Other code Fails.
 */
DECL_EXPORT bm_status_t bm_get_misc_info(bm_handle_t handle, struct bm_misc_info *pmisc_info);

/**
 * @name bm_get_chipid
 * @brief To get the chipid of the device. (0x1682 / 0x1684 / 0x168?)
 * @ingroup bmlib_runtime
 *
 * @param [in] handle The device handle
 * @param [out] p_chipid The chip id of the device
 * @retval BM_SUCCESS Succeeds.
 *         Other code Fails.
 */
DECL_EXPORT bm_status_t bm_get_chipid(bm_handle_t handle, unsigned int *p_chipid);
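
A sketch querying the chip id and the PCIE/SoC running mode:

```c
#include <stdio.h>
#include "bmlib_runtime.h"

/* Query the chip id and whether the device runs in PCIE or SoC mode. */
void show_device_info(bm_handle_t handle) {
    unsigned int chipid = 0;
    struct bm_misc_info info;

    if (bm_get_chipid(handle, &chipid) == BM_SUCCESS)
        printf("chipid: 0x%x\n", chipid);
    if (bm_get_misc_info(handle, &info) == BM_SUCCESS)
        printf("mode: %s\n", info.pcie_soc_mode ? "soc" : "pcie");
}
```
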

#define BMLIB_LOG_QUIET    -8
#define BMLIB_LOG_PANIC     0
#define BMLIB_LOG_FATAL     8
#define BMLIB_LOG_ERROR    16
#define BMLIB_LOG_WARNING  24
#define BMLIB_LOG_INFO     32
#define BMLIB_LOG_VERBOSE  40
#define BMLIB_LOG_DEBUG    48
#define BMLIB_LOG_TRACE    56

/**
 * @name bmlib_log_get_level
 * @brief To get the bmlib log level
 * @ingroup bmlib_log
 *
 * @param void
 * @retval The current bmlib log level
 */
DECL_EXPORT int bmlib_log_get_level(void);

/**
 * @name bmlib_log_set_level
 * @brief To set the bmlib log level
 * @ingroup bmlib_log
 *
 * @param [in] level The bmlib log level to set
 * @retval void
 */
DECL_EXPORT void bmlib_log_set_level(int level);

/**
 * @name bmlib_log_set_callback
 * @brief To set a callback to receive the bmlib log
 * @ingroup bmlib_log
 *
 * @param [in] callback The callback function to get the bmlib log
 * @retval void
 */
DECL_EXPORT void bmlib_log_set_callback(void (*callback)(const char*, int, const char*, va_list args));

/**
 * @name bm_set_debug_mode
 * @brief To set the debug mode for the TPU firmware log
 * @ingroup bmlib_log
 *
 * @param [in] handle The device handle
 * @param [in] mode The debug mode of the firmware log, 0/1 for disable/enable
 * @retval void
 */
DECL_EXPORT void bm_set_debug_mode(bm_handle_t handle, int mode);

/**
 * @name bmlib_api_dbg_callback
 * @brief To set a debug callback to receive the firmware log
 * @ingroup bmlib_log
 *
 * @param [in] callback The callback to get the firmware log
 * @retval void
 */
typedef void (*bmlib_api_dbg_callback)(int, int, int, const char*);
// api, result, duration, log; the third int is reserved for api duration in the future
DECL_EXPORT void bmlib_set_api_dbg_callback(bmlib_api_dbg_callback callback);
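
A sketch wiring up the log level and callback; the meaning of the callback's first two arguments (tag and level) is an assumption, since the header does not document them:

```c
#include <stdarg.h>
#include <stdio.h>
#include "bmlib_runtime.h"

/* Assumed argument order: tag/module, level, printf-style format + args. */
static void my_log_sink(const char *tag, int level, const char *fmt, va_list args) {
    fprintf(stderr, "[bmlib %s %d] ", tag, level);
    vfprintf(stderr, fmt, args);
}

void setup_bmlib_logging(void) {
    bmlib_log_set_level(BMLIB_LOG_WARNING);  /* keep WARNING and worse */
    bmlib_log_set_callback(my_log_sink);
}
```
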

/**
 * @name bmcpu_get_cpu_status
 * @brief Get the bmcpu status
 * @ingroup bmlib_log
 *
 * @param [in] handle The device handle
 * @retval BMCPU_RUNNING bmcpu is running.
 *         Other code Fails.
 */
DECL_EXPORT bm_cpu_status_t bmcpu_get_cpu_status(bm_handle_t handle);

/**
 * @name bmcpu_start_cpu
 * @brief Start the CPU in PCIE mode
 * @ingroup bmlib_log
 *
 * @param [in] handle The device handle
 * @param [in] boot_file Fip file
 * @param [in] core_file Itb file
 * @retval BM_SUCCESS Succeeds.
 *         Other code Fails.
 */
DECL_EXPORT bm_status_t bmcpu_start_cpu(bm_handle_t handle, char *boot_file, char *core_file);

/**
 * @name bmcpu_open_process
 * @brief Open a process to do some work
 * @ingroup bmlib_log
 *
 * @param [in] handle The device handle
 * @param [in] flags Process flags
 * @param [in] timeout Timeout value in milliseconds, -1 means the default value of this device
 * @retval >= 0 Process handle
 *         < 0  Fails.
 */
DECL_EXPORT int bmcpu_open_process(bm_handle_t handle, unsigned int flags, int timeout);

/**
 * @name bmcpu_load_library
 * @brief Load a shared library (.so) into a specific process
 * @ingroup bmlib_log
 *
 * @param [in] handle The device handle
 * @param [in] process_handle Process handle
 * @param [in] library_file Library file path
 * @param [in] timeout Timeout value in milliseconds, -1 means the default value of this device
 * @retval BM_SUCCESS Succeeds.
 *         Other code Fails.
 */
DECL_EXPORT bm_status_t bmcpu_load_library(bm_handle_t handle, int process_handle, char *library_file, int timeout);

/**
 * @name bmcpu_unload_library
 * @brief Unload a shared library (.so) from a specific process
 * @ingroup bmlib_log
 *
 * @param [in] handle The device handle
 * @param [in] process_handle Process handle
 * @param [in] library_file Library file path
 * @param [in] timeout Timeout value in milliseconds, -1 means the default value of this device
 * @retval BM_SUCCESS Succeeds.
 *         Other code Fails.
 */
DECL_EXPORT bm_status_t bmcpu_unload_library(bm_handle_t handle, int process_handle, char *library_file, int timeout);

/**
 * @name bmcpu_exec_function
 * @brief Execute a specific function in a specific process
 * @ingroup bmlib_log
 *
 * @param [in] handle The device handle
 * @param [in] process_handle Process handle
 * @param [in] function_name Function name
 * @param [in] function_param Function parameters
 * @param [in] param_size Parameters size in bytes
 * @param [in] timeout Timeout value in milliseconds, -1 means the default value of this device
 * @retval 0 Success.
 *         >0 Code fails from bmlib
 *         <0 Code fails from the function
 */
DECL_EXPORT int bmcpu_exec_function(bm_handle_t handle,
                                    int process_handle,
                                    char *function_name,
                                    void *function_param,
                                    unsigned int param_size,
                                    int timeout);
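
A hedged end-to-end sketch combining the process and library calls above with `bmcpu_exec_function`; the library path and entry name are illustrative only:

```c
#include "bmlib_runtime.h"

/* Open an on-chip CPU process, load a user library into it, and call an
 * entry point. "libmywork.so" and "my_entry" are illustrative names. */
int run_on_bmcpu(bm_handle_t handle) {
    char lib[] = "libmywork.so";
    char entry[] = "my_entry";

    int proc = bmcpu_open_process(handle, 0, -1);   /* -1: default timeout */
    if (proc < 0)
        return proc;

    if (bmcpu_load_library(handle, proc, lib, -1) != BM_SUCCESS)
        return -1;

    int arg = 42;   /* copied to the remote function as its parameter block */
    return bmcpu_exec_function(handle, proc, entry, &arg, sizeof(arg), -1);
}
```
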
|
1946 |
+
#define BMCPU_EXEC_OPT_NO_FLUSH_CACHE 1
|
1947 |
+
/**
|
1948 |
+
* @name bmcpu_exec_function_ext
|
1949 |
+
* @brief Execute specific function in specific process
|
1950 |
+
* @ingroup bmlib_log
|
1951 |
+
*
|
1952 |
+
* @param [in] handle The device handle
|
1953 |
+
* @param [in] process_handle Process handle
|
1954 |
+
* @param [in] function_name Function name
|
1955 |
+
* @param [in] function_param Function parameters
|
1956 |
+
* @param [in] param_size Parameters size in bytes
|
1957 |
+
* @param [in] opt exec options
|
1958 |
+
* @param [in] timeout Timeout value in millisecond, -1 means default value of this device
|
1959 |
+
* @retval 0 success.
|
1960 |
+
* >0 code fails from bmlib
|
1961 |
+
* <0 code fails from function
|
1962 |
+
*/
|
1963 |
+
DECL_EXPORT int bmcpu_exec_function_ext(bm_handle_t handle,
|
1964 |
+
int process_handle,
|
1965 |
+
char *function_name,
|
1966 |
+
void *function_param,
|
1967 |
+
unsigned int param_size,
|
1968 |
+
unsigned int opt,
|
1969 |
+
int timeout);
|
1970 |
+
|
1971 |
+
/**
|
1972 |
+
* @name bmcpu_exec_function_async
|
1973 |
+
* @brief Execute specific function in specific process asynchronous
|
1974 |
+
* user should use bm_query_exec_function_result to query result
|
1975 |
+
* @ingroup bmlib_log
|
1976 |
+
*
|
1977 |
+
* @param [in] handle The device handle
|
1978 |
+
* @param [in] process_handle Process handle
|
1979 |
+
* @param [in] function_name Function name
|
1980 |
+
* @param [in] function_param Function param
|
1981 |
+
* @param [in] param_size Param size in bytes
|
1982 |
+
* @retval BM_SUCCESS Succeeds.
|
1983 |
+
* Other code Fails.
|
1984 |
+
*/
|
1985 |
+
DECL_EXPORT bm_status_t bmcpu_exec_function_async(bm_handle_t handle,
|
1986 |
+
int process_handle,
|
1987 |
+
char *function_name,
|
1988 |
+
void *function_param,
|
1989 |
+
unsigned int param_size,
|
1990 |
+
unsigned long long *api_handle);
|
1991 |
+
|
1992 |
+
/**
|
1993 |
+
* @name bmcpu_exec_function_async_ext
|
1994 |
+
* @brief Execute specific function in specific process asynchronous
|
1995 |
+
* user should use bm_query_exec_function_result to query result
|
1996 |
+
* @ingroup bmlib_log
|
1997 |
+
*
|
1998 |
+
* @param [in] handle The device handle
|
1999 |
+
* @param [in] process_handle Process handle
|
2000 |
+
* @param [in] function_name Function name
|
2001 |
+
* @param [in] function_param Function param
|
2002 |
+
* @param [in] param_size Param size in bytes
|
2003 |
+
* @param [in] opt exec options
|
2004 |
+
* @retval BM_SUCCESS Succeeds.
|
2005 |
+
* Other code Fails.
|
2006 |
+
*/
|
2007 |
+
DECL_EXPORT bm_status_t bmcpu_exec_function_async_ext(bm_handle_t handle,
|
2008 |
+
int process_handle,
|
2009 |
+
char *function_name,
|
2010 |
+
void *function_param,
|
2011 |
+
unsigned int param_size,
|
2012 |
+
unsigned int opt,
|
2013 |
+
unsigned long long *api_handle);
|
2014 |
+
|
2015 |
+
/**
|
2016 |
+
* @name bmcpu_query_exec_function_result
|
2017 |
+
* @brief Query result from function called by bm_exec_function
|
2018 |
+
* @ingroup bmlib_log
|
2019 |
+
*
|
2020 |
+
* @param [in] handle The device handle
|
2021 |
+
* @param [in] api_handle Api handle return by bm_exec_function_async
|
2022 |
+
* @param [in] timeout Timeout value in millisecond, -1 means default value of this device
|
2023 |
+
* @retval 0 success.
|
2024 |
+
* >0 code fails from bmlib
|
2025 |
+
* <0 code fails from function
|
2026 |
+
*/
|
2027 |
+
DECL_EXPORT int bmcpu_query_exec_function_result(bm_handle_t handle, unsigned long long api_handle, int timeout);
|
2028 |
+
|
2029 |
+
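A minimal sketch of the asynchronous pattern above, assuming a device handle and process handle obtained earlier; the function name "my_func" and the parameter layout are hypothetical and must match the device-side implementation:

unsigned long long api_handle = 0;
int param = 42;  /* hypothetical parameter block */
if (bmcpu_exec_function_async(handle, process_handle, (char *)"my_func",
                              &param, sizeof(param), &api_handle) == BM_SUCCESS) {
    /* ... other host-side work runs while the device executes ... */
    int ret = bmcpu_query_exec_function_result(handle, api_handle, -1);
    /* ret == 0: success; ret > 0: bmlib error; ret < 0: error code from the device function */
}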
/**
 * @name    bmcpu_map_phys_addr
 * @brief   Map a physical address in a specific process
 * @ingroup bmlib_log
 *
 * @param [in] handle          The device handle
 * @param [in] process_handle  Process handle
 * @param [in] phys_addr       Physical address
 * @param [in] size            Map size in bytes
 * @param [in] timeout         Timeout value in millisecond, -1 means default value of this device
 * @retval non-NULL  virtual address
 *         NULL      fails
 */
DECL_EXPORT void *bmcpu_map_phys_addr(bm_handle_t handle, int process_handle, void *phys_addr, unsigned int size, int timeout);

/**
 * @name    bmcpu_unmap_phys_addr
 * @brief   Unmap a physical address in a specific process
 * @ingroup bmlib_log
 *
 * @param [in] handle          The device handle
 * @param [in] process_handle  Process handle
 * @param [in] phys_addr       Physical address
 * @param [in] timeout         Timeout value in millisecond, -1 means default value of this device
 * @retval 0   success
 *         <0  fail
 */
DECL_EXPORT bm_status_t bmcpu_unmap_phys_addr(bm_handle_t handle, int process_handle, void *phys_addr, int timeout);

/**
 * @name    bmcpu_close_process
 * @brief   Close a process
 * @ingroup bmlib_log
 *
 * @param [in] handle          The device handle
 * @param [in] process_handle  Process handle
 * @param [in] timeout         Timeout value in millisecond, -1 means default value of this device
 * @retval BM_SUCCESS  Succeeds.
 *         Other code  Fails.
 */
DECL_EXPORT bm_status_t bmcpu_close_process(bm_handle_t handle, int process_handle, int timeout);
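Taken together, the process APIs above form a simple lifecycle: open a process, load a library into it, execute a function from that library, and close the process. A minimal sketch, where the library path and entry name are hypothetical placeholders:

int process_handle = bmcpu_open_process(handle, 0, -1);
if (process_handle >= 0) {
    bmcpu_load_library(handle, process_handle, (char *)"/data/libmykernel.so", -1);  /* hypothetical path */
    int arg = 0;
    int ret = bmcpu_exec_function(handle, process_handle, (char *)"my_entry", &arg, sizeof(arg), -1);
    (void)ret;  /* 0 on success, >0 bmlib error, <0 error from the function */
    bmcpu_close_process(handle, process_handle, -1);
}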
/**
 * @name    bmcpu_reset_cpu
 * @brief   Reset the cpu in pcie mode
 * @ingroup bmlib_log
 *
 * @param [in] handle  The device handle
 * @retval BM_SUCCESS  Succeeds.
 *         Other code  Fails.
 */
DECL_EXPORT bm_status_t bmcpu_reset_cpu(bm_handle_t handle);

/**
 * @name    bm_enable_perf_monitor
 * @brief   Enable the perf monitor to get gdma and tpu performance data
 * @ingroup bmlib_perf
 *
 * @param [in] handle        The device handle
 * @param [in] perf_monitor  The perf monitor to be enabled
 * @retval BM_SUCCESS  Succeeds.
 *         Other code  Fails.
 */
DECL_EXPORT bm_status_t bm_enable_perf_monitor(bm_handle_t handle, bm_perf_monitor_t *perf_monitor);

/**
 * @name    bm_disable_perf_monitor
 * @brief   Disable the perf monitor to get gdma and tpu performance data
 * @ingroup bmlib_perf
 *
 * @param [in] handle        The device handle
 * @param [in] perf_monitor  The perf monitor to be disabled
 * @retval BM_SUCCESS  Succeeds.
 *         Other code  Fails.
 */
DECL_EXPORT bm_status_t bm_disable_perf_monitor(bm_handle_t handle, bm_perf_monitor_t *perf_monitor);

/**
 * @name    bmcpu_set_log
 * @brief   Set cpu log options
 * @ingroup bmlib_log
 *
 * @param [in] handle          The device handle
 * @param [in] log_level       0: DEBUG  1: INFO  2: WARN  3: ERROR  4: FATAL
 * @param [in] log_to_console  1: YES  0: NO
 * @param [in] timeout         Timeout value in millisecond, -1 means default value of this device
 * @retval BM_SUCCESS  Succeeds.
 *         Other code  Fails.
 */
DECL_EXPORT bm_status_t bmcpu_set_log(bm_handle_t handle, unsigned int log_level, unsigned int log_to_console, int timeout);

/**
 * @name    bmcpu_get_log
 * @brief   Get the cpu log file
 * @ingroup bmlib_log
 *
 * @param [in] handle          The device handle
 * @param [in] process_handle  Process handle
 * @param [in] log_file        Path where the log is saved as a file
 * @param [in] timeout         Timeout value in millisecond, -1 means default value of this device
 * @retval BM_SUCCESS  Succeeds.
 *         Other code  Fails.
 */
DECL_EXPORT bm_status_t bmcpu_get_log(bm_handle_t handle, int process_handle, char *log_file, int timeout);

/**
 * @name    bmcpu_sync_time
 * @brief   Sync the device cpu time with the host
 * @ingroup bmlib_log
 *
 * @param [in] handle  The device handle
 * @retval BM_SUCCESS  Succeeds.
 *         Other code  Fails.
 */
DECL_EXPORT bm_status_t bmcpu_sync_time(bm_handle_t handle);

/******************* trace and profile related functions **********************/
struct bm_heap_stat {
  unsigned int mem_total;
  unsigned int mem_avail;
  unsigned int mem_used;
};

typedef struct bm_heap_stat_byte {
  unsigned int heap_id;
  unsigned long long mem_total;
  unsigned long long mem_avail;
  unsigned long long mem_used;
  unsigned long long mem_start_addr;
} bm_heap_stat_byte_t;

typedef struct bm_dev_stat {
  int mem_total;
  int mem_used;
  int tpu_util;
  int heap_num;
  struct bm_heap_stat heap_stat[4];
} bm_dev_stat_t;
/**
 * @name    bm_get_stat
 * @brief   To get the stat data at the moment
 * @ingroup bmlib_runtime
 *
 * @param [in]  handle  The device handle
 * @param [out] stat    The result stat data
 * @retval BM_SUCCESS  Succeeds.
 *         Other code  Fails.
 */
DECL_EXPORT bm_status_t bm_get_stat(bm_handle_t handle, bm_dev_stat_t *stat);
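A short sketch of polling device statistics with bm_get_stat and the structs above (assumes <stdio.h>; the memory units are whatever the driver reports, which is an assumption here):

bm_dev_stat_t stat;
if (bm_get_stat(handle, &stat) == BM_SUCCESS) {
    printf("mem used/total: %d/%d, tpu util: %d%%\n", stat.mem_used, stat.mem_total, stat.tpu_util);
    for (int i = 0; i < stat.heap_num && i < 4; ++i)
        printf("heap %d: total=%u avail=%u used=%u\n", i,
               stat.heap_stat[i].mem_total, stat.heap_stat[i].mem_avail, stat.heap_stat[i].mem_used);
}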
/**
 * @name    bm_get_gmem_heap_id
 * @brief   To get the heap id of allocated global memory
 * @ingroup bmlib_runtime
 *
 * @param [in]  handle  The device handle
 * @param [in]  pmem    The allocated global memory
 * @param [out] heapid  The resulting heap id
 * @retval BM_SUCCESS  Succeeds.
 *         Other code  Fails.
 */
DECL_EXPORT bm_status_t bm_get_gmem_heap_id(bm_handle_t handle, bm_device_mem_t *pmem, unsigned int *heapid);

/**
 * @name    sg_get_gmem_heap_id
 * @brief   To get the heap id of allocated global memory
 * @ingroup bmlib_runtime
 *
 * @param [in]  handle  The device handle
 * @param [in]  pmem    The allocated global memory
 * @param [out] heapid  The resulting heap id
 * @retval BM_SUCCESS  Succeeds.
 *         Other code  Fails.
 */
DECL_EXPORT bm_status_t sg_get_gmem_heap_id(bm_handle_t handle, sg_device_mem_t *pmem, unsigned int *heapid);

/**
 * @name    bm_get_gmem_total_heap_num
 * @brief   To get the total heap num of global memory
 * @ingroup bmlib_runtime
 *
 * @param [in]  handle    The device handle
 * @param [out] heap_num  The resulting total heap num
 * @retval BM_SUCCESS  Succeeds.
 *         Other code  Fails.
 */
DECL_EXPORT bm_status_t bm_get_gmem_total_heap_num(bm_handle_t handle, unsigned int *heap_num);

/**
 * @name    bm_get_gmem_heap_stat_byte_by_id
 * @brief   To get the heap stat by heap id
 * @ingroup bmlib_runtime
 *
 * @param [in]  handle      The device handle
 * @param [in]  heap_id     The heap index to get the heap status for
 * @param [out] pheap_byte  The resulting heap status
 * @retval BM_SUCCESS  Succeeds.
 *         Other code  Fails.
 */
DECL_EXPORT bm_status_t bm_get_gmem_heap_stat_byte_by_id(bm_handle_t handle, bm_heap_stat_byte_t *pheap_byte, unsigned int heap_id);
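The two heap queries above combine naturally into an enumeration loop, for example (assumes <stdio.h>):

unsigned int heap_num = 0;
if (bm_get_gmem_total_heap_num(handle, &heap_num) == BM_SUCCESS) {
    for (unsigned int i = 0; i < heap_num; ++i) {
        bm_heap_stat_byte_t hs;
        if (bm_get_gmem_heap_stat_byte_by_id(handle, &hs, i) == BM_SUCCESS)
            printf("heap %u: total=%llu avail=%llu used=%llu start=0x%llx\n",
                   hs.heap_id, hs.mem_total, hs.mem_avail, hs.mem_used, hs.mem_start_addr);
    }
}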
DECL_EXPORT bm_status_t bm_load_firmware(
    bm_handle_t  handle,
    const char  *firmware_tcm,
    const char  *firmware_ddr);

#define bmkernel_load_firmware okkernel_load_firmware
DECL_EXPORT bm_status_t okkernel_load_firmware(
    bm_handle_t  handle,
    const char  *firmware_tcm,
    const char  *firmware_ddr);

DECL_EXPORT bm_status_t okkernel_launch_async(
    bm_handle_t  handle,
    const char  *func_name,
    const void  *args,
    unsigned int size);

DECL_EXPORT bm_status_t okkernel_launch_sync(
    bm_handle_t  handle,
    const char  *func_name,
    const void  *args,
    unsigned int size);

DECL_EXPORT bm_status_t tpu_kernel_launch_sync(
    bm_handle_t  handle,
    const char  *func_name,
    const void  *args,
    unsigned int size);

DECL_EXPORT bm_status_t okkernel_sync(bm_handle_t handle);
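A hedged sketch of the kernel-launch flow declared above: load the firmware once, then launch a named device function synchronously. The firmware file names, the kernel name "my_kernel", and the argument layout are all hypothetical; the argument struct must match what the device-side kernel expects:

if (okkernel_load_firmware(handle, "firmware_tcm.bin", "firmware_ddr.bin") == BM_SUCCESS) {  /* hypothetical files */
    struct { unsigned long long input_addr, output_addr; int len; } args = {0, 0, 0};  /* hypothetical layout */
    tpu_kernel_launch_sync(handle, "my_kernel", &args, sizeof(args));
}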
/**
 * @name    bmkernel_launch
 * @brief   Send an api to the device and launch a function
 * @ingroup bmlib_runtime
 *
 * @param [in] handle  The device handle
 * @param [in] args    Api command struct pointer
 * @param [in] size    Api command length in bytes
 * @retval BM_SUCCESS  Succeeds.
 *         Other code  Fails.
 */
DECL_EXPORT bm_status_t bmkernel_launch(bm_handle_t handle, const void *args,
                                        unsigned int size);

/**
 * @name    bmkernel_load_lookup_table
 * @brief   Load a lookup table into l2-sram
 * @ingroup bmlib_runtime
 *
 * @param [in] handle  The device handle
 * @param [in] table   The table to be loaded into l2-sram
 * @param [in] size    Table size in bytes
 * @retval BM_SUCCESS  Succeeds.
 *         Other code  Fails.
 */
DECL_EXPORT bm_status_t bmkernel_load_lookup_table(bm_handle_t handle, const void* table, unsigned int size);

/******************* device management api functions ********************************************/
/**
 * @name    bm_get_tpu_current
 * @brief   Get the tpu current
 * @ingroup bmlib_runtime
 *
 * @param [in]  handle  The device handle
 * @param [out] tpuc    The pointer for the tpu current (mA)
 * @retval BM_SUCCESS  Succeeds.
 *         Other code  Fails.
 */
DECL_EXPORT bm_status_t bm_get_tpu_current(bm_handle_t handle, unsigned int *tpuc);

/**
 * @name    bm_get_board_max_power
 * @brief   Get the max power supported by the board
 * @ingroup bmlib_runtime
 *
 * @param [in]  handle  The device handle
 * @param [out] maxp    The pointer for maxp
 * @retval BM_SUCCESS  Succeeds.
 *         Other code  Fails.
 */
DECL_EXPORT bm_status_t bm_get_board_max_power(bm_handle_t handle, unsigned int *maxp);

/**
 * @name    bm_get_board_power
 * @brief   Get the board power
 * @ingroup bmlib_runtime
 *
 * @param [in]  handle  The device handle
 * @param [out] boardp  The pointer for boardp
 * @retval BM_SUCCESS  Succeeds.
 *         Other code  Fails.
 */
DECL_EXPORT bm_status_t bm_get_board_power(bm_handle_t handle, unsigned int *boardp);

/**
 * @name    bm_get_fan_speed
 * @brief   Get the board fan speed
 * @ingroup bmlib_runtime
 *
 * @param [in]  handle  The device handle
 * @param [out] fan     The pointer for the fan speed
 * @retval BM_SUCCESS  Succeeds.
 *         Other code  Fails.
 */
DECL_EXPORT bm_status_t bm_get_fan_speed(bm_handle_t handle, unsigned int *fan);

/**
 * @name    bm_get_ecc_correct_num
 * @brief   Get ecc_correct_num
 * @ingroup device management api
 *
 * @param [in]  handle           The device handle
 * @param [out] ecc_correct_num
 * @retval BM_SUCCESS  Succeeds.
 *         Other code  Fails.
 */
#ifdef __linux__
DECL_EXPORT bm_status_t bm_get_ecc_correct_num(bm_handle_t handle, unsigned long *ecc_correct_num);
#else
DECL_EXPORT bm_status_t bm_get_ecc_correct_num(bm_handle_t handle, unsigned long long *ecc_correct_num);
#endif
/**
 * @name    bm_get_12v_atx
 * @brief   Get atx_12v
 * @ingroup device management api
 *
 * @param [in]  handle   The device handle
 * @param [out] atx_12v
 * @retval BM_SUCCESS  Succeeds.
 *         Other code  Fails.
 */
DECL_EXPORT bm_status_t bm_get_12v_atx(bm_handle_t handle, int *atx_12v);

/**
 * @name    bm_get_product_sn
 * @brief   Get the SE5 sn
 * @ingroup device management api
 *
 * @param [out] product_sn
 * @retval BM_SUCCESS  Succeeds.
 *         Other code  Fails.
 */
DECL_EXPORT bm_status_t bm_get_product_sn(char *product_sn);

/**
 * @name    bm_get_sn
 * @brief   Get the sn
 * @ingroup device management api
 *
 * @param [in]  handle  The device handle
 * @param [out] sn
 * @retval BM_SUCCESS  Succeeds.
 *         Other code  Fails.
 */
DECL_EXPORT bm_status_t bm_get_sn(bm_handle_t handle, char *sn);

/**
 * @name    bm_get_status
 * @brief   Get the chip status
 * @ingroup device management api
 *
 * @param [in]  handle  The device handle
 * @param [out] status  The board error status, each bit represents an error state:
 *                      status == 0x0, the board is normal; status > 0, the board is abnormal;
 *                      bit0 == 1, the tpu is hung
 *                      bit1 == 1, the pcie link is abnormal
 *                      bit2 == 1, the board temperature is too high
 * @retval BM_SUCCESS  Succeeds.
 *         Other code  Fails.
 */
DECL_EXPORT bm_status_t bm_get_status(bm_handle_t handle, int *status);

/**
 * @name    bm_get_tpu_maxclk
 * @brief   Get tpu_maxclk
 * @ingroup device management api
 *
 * @param [in]  handle      The device handle
 * @param [out] tpu_maxclk
 * @retval BM_SUCCESS  Succeeds.
 *         Other code  Fails.
 */
DECL_EXPORT bm_status_t bm_get_tpu_maxclk(bm_handle_t handle, unsigned int *tpu_maxclk);

/**
 * @name    bm_get_tpu_minclk
 * @brief   Get tpu_minclk
 * @ingroup device management api
 *
 * @param [in]  handle      The device handle
 * @param [out] tpu_minclk
 * @retval BM_SUCCESS  Succeeds.
 *         Other code  Fails.
 */
DECL_EXPORT bm_status_t bm_get_tpu_minclk(bm_handle_t handle, unsigned int *tpu_minclk);

/**
 * @name    bm_get_driver_version
 * @brief   Get the driver version
 * @ingroup device management api
 *
 * @param [in]  handle          The device handle
 * @param [out] driver_version
 * @retval BM_SUCCESS  Succeeds.
 *         Other code  Fails.
 */
DECL_EXPORT bm_status_t bm_get_driver_version(bm_handle_t handle, int *driver_version);

/**
 * @name    bm_get_board_name
 * @brief   Get the device board name
 * @ingroup device management api
 *
 * @param [in]  handle  The device handle
 * @param [out] name    The board name
 * @retval BM_SUCCESS  Succeeds.
 *         Other code  Fails.
 */
DECL_EXPORT bm_status_t bm_get_board_name(bm_handle_t handle, char *name);

/**
 * @name    bm_get_board_temp
 * @brief   Get the board temperature
 * @ingroup device management api
 *
 * @param [in]  handle      The device handle
 * @param [out] board_temp
 * @retval BM_SUCCESS  Succeeds.
 *         Other code  Fails.
 */
DECL_EXPORT bm_status_t bm_get_board_temp(bm_handle_t handle, unsigned int *board_temp);

/**
 * @name    bm_get_chip_temp
 * @brief   Get the chip temperature
 * @ingroup device management api
 *
 * @param [in]  handle     The device handle
 * @param [out] chip_temp
 * @retval BM_SUCCESS  Succeeds.
 *         Other code  Fails.
 */
DECL_EXPORT bm_status_t bm_get_chip_temp(bm_handle_t handle, unsigned int *chip_temp);

/**
 * @name    bm_get_tpu_power
 * @brief   Get the TPU power
 * @ingroup device management api
 *
 * @param [in]  handle     The device handle
 * @param [out] tpu_power
 * @retval BM_SUCCESS  Succeeds.
 *         Other code  Fails.
 */
DECL_EXPORT bm_status_t bm_get_tpu_power(bm_handle_t handle, float *tpu_power);

/**
 * @name    bm_get_tpu_volt
 * @brief   Get the TPU voltage
 * @ingroup device management api
 *
 * @param [in]  handle    The device handle
 * @param [out] tpu_volt
 * @retval BM_SUCCESS  Succeeds.
 *         Other code  Fails.
 */
DECL_EXPORT bm_status_t bm_get_tpu_volt(bm_handle_t handle, unsigned int *tpu_volt);
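A telemetry sketch pulling several of the sensors declared above in one pass (the units in the comments are assumptions; the header itself only states mA for the tpu current):

unsigned int board_temp = 0, chip_temp = 0, tpu_volt = 0, boardp = 0;
float tpu_power = 0.f;
bm_get_board_temp(handle, &board_temp);  /* temperature, presumably degrees Celsius */
bm_get_chip_temp(handle, &chip_temp);
bm_get_board_power(handle, &boardp);
bm_get_tpu_power(handle, &tpu_power);
bm_get_tpu_volt(handle, &tpu_volt);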
/**
 * @name    bm_get_card_id
 * @brief   Get the card id
 * @ingroup device management api
 *
 * @param [in]  handle   The device handle
 * @param [out] card_id
 * @retval BM_SUCCESS  Succeeds.
 *         Other code  Fails.
 */
DECL_EXPORT bm_status_t bm_get_card_id(bm_handle_t handle, unsigned int *card_id);

/**
 * @name    bm_get_card_num
 * @brief   Get the number of cards
 * @ingroup device management api
 *
 * @param [out] card_num
 * @retval BM_SUCCESS  Succeeds.
 *         Other code  Fails.
 */
DECL_EXPORT bm_status_t bm_get_card_num(unsigned int *card_num);

/**
 * @name    bm_get_chip_num_from_card
 * @brief   Get the chip number and the start chip id of a card
 * @ingroup device management api
 *
 * @param [in]  card_id          The card id
 * @param [out] chip_num
 * @param [out] dev_start_index
 * @retval BM_SUCCESS  Succeeds.
 *         Other code  Fails.
 */
DECL_EXPORT bm_status_t bm_get_chip_num_from_card(unsigned int card_id, unsigned int *chip_num, unsigned int *dev_start_index);
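Since bm_get_card_num and bm_get_chip_num_from_card take no device handle, a host can map out the whole card/chip topology before opening any device:

unsigned int card_num = 0;
if (bm_get_card_num(&card_num) == BM_SUCCESS) {
    for (unsigned int c = 0; c < card_num; ++c) {
        unsigned int chip_num = 0, start = 0;
        if (bm_get_chip_num_from_card(c, &chip_num, &start) == BM_SUCCESS) {
            /* devices [start, start + chip_num) belong to card c */
        }
    }
}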
/**
 * @name    bm_get_dynfreq_status
 * @brief   Get the chip dynamic freq status
 * @ingroup device management api
 *
 * @param [in]  handle          The device handle
 * @param [out] dynfreq_status
 * @retval BM_SUCCESS  Succeeds.
 *         Other code  Fails.
 */
DECL_EXPORT bm_status_t bm_get_dynfreq_status(bm_handle_t handle, int *dynfreq_status);

/**
 * @name    bm_change_dynfreq_status
 * @brief   Change (enable/disable) the chip dynamic freq status
 * @ingroup device management api
 *
 * @param [in] handle      The device handle
 * @param [in] new_status
 * @retval BM_SUCCESS  Succeeds.
 *         Other code  Fails.
 */
DECL_EXPORT bm_status_t bm_change_dynfreq_status(bm_handle_t handle, int new_status);

/**
 * @name    bm_get_tpu_scalar_num
 * @brief   To get the core number of the TPU scalar
 * @ingroup bmlib_runtime
 *
 * @param [in]  handle    The device handle
 * @param [out] core_num  The core number of the TPU scalar
 * @retval BM_SUCCESS  Succeeds.
 *         Other code  Fails.
 */
DECL_EXPORT bm_status_t bm_get_tpu_scalar_num(bm_handle_t handle, unsigned int *core_num);

#define bm_get_tpu_core_num bm_get_tpu_scalar_num

#if defined(__cplusplus)
}
#endif

#endif /* BM_RUNTIME_H_ */
ChatGLM2/support/include/bmruntime_interface.h
ADDED
@@ -0,0 +1,404 @@
/*****************************************************************************
 *
 *    Copyright (c) 2016-2026 by Sophgo Technologies Inc. All rights reserved.
 *
 *    The material in this file is confidential and contains trade secrets
 *    of Sophgo Technologies Inc. This is proprietary information owned by
 *    Sophgo Technologies Inc. No part of this work may be disclosed,
 *    reproduced, copied, transmitted, or used in any way for any purpose,
 *    without the express written permission of Sophgo Technologies Inc.
 *
 *****************************************************************************/

/*****************************************************************************
 * BMRuntime Interface is mainly for inference.
 * We can also use it for device computation from BMLang programming.
 * Note: please use the interfaces from bmlib_runtime.h for device memory operations.
 ****************************************************************************/

#ifndef BMRUNTIME_INTERFACE_H_
#define BMRUNTIME_INTERFACE_H_

#include "bmdef.h"

#ifdef _WIN32
#define DECL_EXPORT _declspec(dllexport)
#define DECL_IMPORT _declspec(dllimport)
#else
#define DECL_EXPORT
#define DECL_IMPORT
#endif

#if defined(__cplusplus)
extern "C" {
#endif

/* --------------------------------------------------------------------------*/
/* interface for basic data types */

/* get data type byte size */
DECL_EXPORT size_t bmrt_data_type_size(bm_data_type_t dtype);

/*
dims array to bm_shape_t,
shape and dims should not be NULL, num_dims should not be larger than BM_MAX_DIMS_NUM */
DECL_EXPORT void bmrt_shape(bm_shape_t* shape, const int* dims, int num_dims);

/*
number of shape elements, shape should not be NULL and num_dims should not be larger than
BM_MAX_DIMS_NUM */
DECL_EXPORT uint64_t bmrt_shape_count(const bm_shape_t* shape);

/* compare whether two shapes are the same */
DECL_EXPORT bool bmrt_shape_is_same(const bm_shape_t* left, const bm_shape_t* right);

/*
fill a tensor with data type and shape, and st_mode = 0 as default.
tensor and p_bmrt should not be NULL, shape count should not be 0.
it will alloc device mem for tensor->device_mem, so the user should call bmrt_free_device(p_bmrt,
tensor->device_mem) to free it.*/
DECL_EXPORT bool bmrt_tensor(bm_tensor_t* tensor, void* p_bmrt, bm_data_type_t dtype, bm_shape_t shape);

/*
fill a tensor with data type and shape, and st_mode = 0 as default.
tensor and p_bmrt should not be NULL, shape count should not be 0.
it will alloc device mem for tensor->device_mem on the devid-th device.*/
DECL_EXPORT bool bmrt_tensor_ex(bm_tensor_t* tensor, void* p_bmrt, int devid, bm_data_type_t dtype, bm_shape_t shape);

/* fill a tensor with an existing device mem, the tensor byte size should not be larger than the device mem size */
DECL_EXPORT void bmrt_tensor_with_device(bm_tensor_t* tensor, bm_device_mem_t device_mem,
                                         bm_data_type_t dtype, bm_shape_t shape);

/* get tensor byte size, tensor should not be NULL */
DECL_EXPORT size_t bmrt_tensor_bytesize(const bm_tensor_t* tensor);

/* get tensor mem size allocated in device mem, tensor should not be NULL */
DECL_EXPORT size_t bmrt_tensor_device_size(const bm_tensor_t* tensor);

/* print net info for debug */
DECL_EXPORT void bmrt_print_network_info(const bm_net_info_t* net_info);
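A quick sketch of the shape helpers above; BM_FLOAT32 is assumed to be the fp32 entry of bm_data_type_t from bmdef.h:

int dims[4] = {1, 3, 224, 224};
bm_shape_t shape;
bmrt_shape(&shape, dims, 4);
uint64_t elems = bmrt_shape_count(&shape);             /* 1*3*224*224 = 150528 */
size_t bytes = elems * bmrt_data_type_size(BM_FLOAT32);  /* byte size of an fp32 tensor */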
/* --------------------------------------------------------------------------*/
/**
 * @name    bmrt_create
 * @brief   To create the bmruntime with a bm_handle.
 * @ingroup bmruntime
 *
 * This API creates the bmruntime. It returns a void* pointer which is the pointer
 * of the bmruntime. The device id is set when the bm_handle is obtained.
 *
 * @param [in] bm_handle  bm handle. It must be initialized by using bmlib.
 *
 * @retval void*  the pointer of the bmruntime
 */
DECL_EXPORT void* bmrt_create(bm_handle_t bm_handle);

/* --------------------------------------------------------------------------*/
/**
 * @name    bmrt_create_ex
 * @brief   To create the bmruntime with one or more bm_handles.
 * @ingroup bmruntime
 *
 * This API creates the bmruntime. It returns a void* pointer which is the pointer
 * of the bmruntime.
 *
 * @param [in] bm_handles   bm handles. They must be initialized by using bmlib.
 * @param [in] num_handles  number of bm_handles.
 *
 * @retval void*  the pointer of the bmruntime
 */
DECL_EXPORT void *bmrt_create_ex(bm_handle_t *bm_handles, int num_handles);

/**
 * @name    bmrt_destroy
 * @brief   To destroy the bmruntime pointer
 * @ingroup bmruntime
 *
 * This API destroys the bmruntime.
 *
 * @param [in] p_bmrt  Bmruntime that had been created
 */
DECL_EXPORT void bmrt_destroy(void* p_bmrt);

/**
 * @name    bmrt_get_bm_handle
 * @brief   To get the BM runtime context.
 * @ingroup bmruntime
 *
 * This API gets the BM runtime context for using BMDNN, BMCV or BMLIB.
 *
 * @param [in] p_bmrt  Bmruntime that had been created
 */
DECL_EXPORT void * bmrt_get_bm_handle(void* p_bmrt);

/**
 * @name    bmrt_load_bmodel
 * @brief   To load a bmodel which is created by the BM compiler
 * @ingroup bmruntime
 *
 * This API loads a bmodel created by the BM compiler.
 * After loading the bmodel, we can run inference of the neural network.
 *
 * @param [in] p_bmrt       Bmruntime that had been created
 * @param [in] bmodel_path  Bmodel file path.
 *
 * @retval true   Load context success.
 * @retval false  Load context failed.
 */
DECL_EXPORT bool bmrt_load_bmodel(void* p_bmrt, const char *bmodel_path);

/**
 * @name    bmrt_load_bmodel_data
 * @brief   To load a bmodel which is created by the BM compiler from a buffer
 * @ingroup bmruntime
 *
 * This API loads a bmodel created by the BM compiler.
 * After loading the bmodel, we can run inference of the neural network.
 * Different from bmrt_load_bmodel, the bmodel here is data in host memory.
 *
 * @param [in] p_bmrt       Bmruntime that had been created
 * @param [in] bmodel_data  Pointer to the bmodel data buffer
 * @param [in] size         Bmodel data size
 *
 * @retval true   Load context success.
 * @retval false  Load context failed.
 */
DECL_EXPORT bool bmrt_load_bmodel_data(void* p_bmrt, const void * bmodel_data, size_t size);
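The typical bring-up sequence for the APIs above, sketched under the assumption that bm_dev_request/bm_dev_free from bmlib_runtime.h are used to obtain and release the device handle; the bmodel path is a placeholder:

bm_handle_t bm_handle;
if (bm_dev_request(&bm_handle, 0) == BM_SUCCESS) {       /* device 0 */
    void *p_bmrt = bmrt_create(bm_handle);
    if (p_bmrt && bmrt_load_bmodel(p_bmrt, "my_model.bmodel")) {  /* hypothetical path */
        /* ... run inference here ... */
    }
    bmrt_destroy(p_bmrt);
    bm_dev_free(bm_handle);
}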
/**
 * @name    bmrt_show_neuron_network
 * @brief   To print the names of all neural networks
 * @ingroup bmruntime
 *
 * @param [in] p_bmrt  Bmruntime that had been created
 */
DECL_EXPORT void bmrt_show_neuron_network(void* p_bmrt);

/**
 * @name    bmrt_get_network_number
 * @brief   To get the number of neural networks in the bmruntime
 * @ingroup bmruntime
 *
 * @param [in] p_bmrt  Bmruntime that had been created
 *
 * @retval int value  The number of neural networks.
 */
DECL_EXPORT int bmrt_get_network_number(void* p_bmrt);

/**
 * @name    bmrt_get_network_names
 * @brief   To get the names of all neural networks in the bmruntime
 * @ingroup bmruntime
 *
 * @param [in]  p_bmrt         Bmruntime that had been created
 * @param [out] network_names  The names of all neural networks. It should be declared as (const char** networks_ = NULL),
 *                             and passed as &networks_. After this API, the user needs to free(networks_) when it is
 *                             no longer needed.
 */
DECL_EXPORT void bmrt_get_network_names(void* p_bmrt, const char*** network_names);
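Enumerating the networks and releasing the name array exactly as the comment above requires (assumes <stdio.h> and <stdlib.h>):

const char **networks_ = NULL;
int num = bmrt_get_network_number(p_bmrt);
bmrt_get_network_names(p_bmrt, &networks_);
for (int i = 0; i < num; ++i)
    printf("net %d: %s\n", i, networks_[i]);
free(networks_);  /* per the doc above, only the array itself is freed */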
/**
 * @name    bmrt_get_network_info
 * @brief   To get network info by net name
 * @ingroup bmruntime
 *
 * @param [in] p_bmrt    Bmruntime that had been created
 * @param [in] net_name  Network name
 *
 * @retval bm_net_info_t*  Pointer to the net info; it need not be freed by the user. If the net name is not found, NULL is returned.
 */
DECL_EXPORT const bm_net_info_t* bmrt_get_network_info(void* p_bmrt, const char* net_name);

/**
 * @name    bmrt_launch_tensor
 * @brief   To launch inference of the neural network with the given input tensors
 * @ingroup bmruntime
 *
 * This API supports neural networks that are static-compiled or dynamic-compiled.
 * After calling this API, inference on the TPU is launched and the CPU program is not
 * blocked. bm_thread_sync should be called to make sure inference is finished.
 * This API supports multiple inputs and is multi-thread safe.
 *
 * @param [in]  p_bmrt          Bmruntime that had been created
 * @param [in]  net_name        The name of the neural network
 * @param [in]  input_tensors   Array of input tensors, defined like bm_tensor_t input_tensors[input_num].
 *                              The user should initialize each input tensor.
 * @param [in]  input_num       Input number
 * @param [out] output_tensors  Array of output tensors, defined like bm_tensor_t output_tensors[output_num].
 *                              This interface allocs device mem to store the output data. The user should free each
 *                              device mem with bm_free_device after the result data is no longer used.
 * @param [in]  output_num      Output number
 *
 * @retval true   Launch success.
 * @retval false  Launch failed.
 */
DECL_EXPORT bool bmrt_launch_tensor(void* p_bmrt, const char * net_name, const bm_tensor_t input_tensors[], int input_num,
                                    bm_tensor_t output_tensors[], int output_num);
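A minimal non-blocking launch sketch. The field names input_dtypes and stages[0].input_shapes are assumed from bm_net_info_t as defined in bmdef.h; bm_memcpy_s2d, bm_thread_sync, bm_memcpy_d2s and bm_free_device come from bmlib_runtime.h; "my_net" and the host buffers are placeholders sized for the net's tensors:

const bm_net_info_t *info = bmrt_get_network_info(p_bmrt, "my_net");
if (info) {
    bm_tensor_t in, out;
    bmrt_tensor(&in, p_bmrt, info->input_dtypes[0], info->stages[0].input_shapes[0]);
    bm_memcpy_s2d(bm_handle, in.device_mem, host_input);        /* host -> device */
    if (bmrt_launch_tensor(p_bmrt, "my_net", &in, 1, &out, 1)) {
        bm_thread_sync(bm_handle);                              /* wait for inference to finish */
        bm_memcpy_d2s(bm_handle, host_output, out.device_mem);  /* device -> host */
        bm_free_device(bm_handle, out.device_mem);              /* runtime-allocated output */
    }
    bm_free_device(bm_handle, in.device_mem);
}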
/**
 * @name    bmrt_launch_tensor_ex
 * @brief   To launch inference of the neural network with the given input tensors
 * @ingroup bmruntime
 *
 * This API supports neural networks that are static-compiled or dynamic-compiled.
 * After calling this API, inference on the TPU is launched and the CPU program is not
 * blocked. bm_thread_sync should be called to make sure inference is finished.
 * This API supports multiple inputs and is multi-thread safe.
 *
 * @param [in]  p_bmrt          Bmruntime that had been created
 * @param [in]  net_name        The name of the neural network
 * @param [in]  input_tensors   Array of input tensors, defined like bm_tensor_t input_tensors[input_num].
 *                              The user should initialize each input tensor.
 * @param [in]  input_num       Input number
 * @param [out] output_tensors  Array of output tensors, defined like bm_tensor_t output_tensors[output_num].
 *                              The user can set device_mem or stmode of the output tensors. If user_mem is true, this
 *                              interface uses the device mem of output_tensors to store the output data and does not
 *                              alloc device mem; otherwise it allocs device mem to store the output. If user_stmode
 *                              is true, the stmode in each output tensor is used; otherwise stmode is BM_STORE_1N
 *                              as default.
 * @param [in]  output_num      Output number
 * @param [in]  user_mem        whether the device_mem of the output tensors is set
 * @param [in]  user_stmode     whether the stmode of the output tensors is set
 *
 * @retval true   Launch success.
 * @retval false  Launch failed.
 */
DECL_EXPORT bool bmrt_launch_tensor_ex(void* p_bmrt, const char * net_name, const bm_tensor_t input_tensors[], int input_num,
                                       bm_tensor_t output_tensors[], int output_num, bool user_mem, bool user_stmode);

/**
 * @name    bmrt_launch_data
 * @brief   To launch inference of the neural network with input data in system memory
 * @ingroup bmruntime
 *
 * This API supports neural networks that are static-compiled or dynamic-compiled.
 * After calling this API, inference on the TPU is launched and the CPU
 * program is blocked.
 * This API supports multiple inputs and is multi-thread safe.
 *
 * @param [in]  p_bmrt         Bmruntime that had been created
 * @param [in]  net_name       The name of the neural network
 * @param [in]  input_datas    Array of input data, defined like void * input_datas[input_num]. The user should
 *                             initialize each data pointer as input.
 * @param [in]  input_shapes   Array of input shapes, defined like bm_shape_t input_shapes[input_num].
 *                             The user should set each input shape.
 * @param [in]  input_num      Input number
 * @param [out] output_datas   Array of output data, defined like void * output_datas[output_num].
 *                             If the user doesn't alloc each output data, set user_mem to false and this api allocs
 *                             the output mem; the user should free each output mem when the output data is no longer
 *                             used. The user can also alloc system memory for each output data and set user_mem = true.
 * @param [out] output_shapes  Array of output shapes, defined like bm_shape_t output_shapes[output_num].
 *                             It stores each output shape.
 * @param [in]  output_num     Output number
 * @param [in]  user_mem       whether output_datas[i] has allocated memory
 *
 * @retval true   Launch success.
 * @retval false  Launch failed.
 */
DECL_EXPORT bool bmrt_launch_data(void* p_bmrt, const char* net_name, void* const input_datas[],
                                  const bm_shape_t input_shapes[], int input_num, void * output_datas[],
                                  bm_shape_t output_shapes[], int output_num, bool user_mem);
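When everything lives in system memory, bmrt_launch_data is the blocking one-call alternative. A sketch with a single input and a single runtime-allocated output; whether the runtime-allocated output buffer is released with free() is an assumption to verify against the runtime documentation:

int dims[4] = {1, 3, 224, 224};
bm_shape_t in_shape;
bmrt_shape(&in_shape, dims, 4);
void *input_datas[1]  = { host_input };   /* placeholder host buffer */
void *output_datas[1] = { NULL };         /* user_mem = false: runtime allocates */
bm_shape_t output_shapes[1];
if (bmrt_launch_data(p_bmrt, "my_net", input_datas, &in_shape, 1,
                     output_datas, output_shapes, 1, false)) {
    /* consume output_datas[0], shaped as output_shapes[0] */
    free(output_datas[0]);                /* assumption: caller releases with free() */
}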
/**
 * @name    bmrt_trace
 * @brief   To check the runtime environment and collect info for DEBUG
 * @ingroup bmruntime
 *
 * This API collects runtime info for DEBUG. Especially when a launch suddenly produces wrong
 * results, calling bmrt_trace shows whether device mems are broken, along with other check info.
 *
 * @param [in] p_bmrt  Bmruntime that had been created
 */
DECL_EXPORT void bmrt_trace(void* p_bmrt);

/**
 * @name    bmrt_launch_tensor_multi_cores
 * @brief   To launch inference of the neural network with the given input tensors, with support for multi-core inference.
 * @ingroup bmruntime
 *
 * This API supports neural networks that are static-compiled or dynamic-compiled.
 * After calling this API, inference on the TPU is launched and the CPU program is not
 * blocked. bm_thread_sync_from_core should be called to make sure inference is finished.
 * This API supports multiple inputs and is multi-thread safe.
 *
 * @param [in]  p_bmrt          Bmruntime that had been created
 * @param [in]  net_name        The name of the neural network
 * @param [in]  input_tensors   Array of input tensors, defined like bm_tensor_t input_tensors[input_num].
 *                              The user should initialize each input tensor.
 * @param [in]  input_num       Input number
 * @param [out] output_tensors  Array of output tensors, defined like bm_tensor_t output_tensors[output_num].
 *                              The user can set device_mem or stmode of the output tensors. If user_mem is true, this
 *                              interface uses the device mem of output_tensors to store the output data and does not
 *                              alloc device mem; otherwise it allocs device mem to store the output. If user_stmode
 *                              is true, the stmode in each output tensor is used; otherwise stmode is BM_STORE_1N
 *                              as default.
 * @param [in]  output_num      Output number
 * @param [in]  user_mem        whether the device_mem of the output tensors is set
 * @param [in]  user_stmode     whether the stmode of the output tensors is set
 * @param [in]  core_list       list of core ids to be used for inference
 * @param [in]  core_num        number of cores in the core list
 *
 * @retval true   Launch success.
 * @retval false  Launch failed.
 */
DECL_EXPORT bool bmrt_launch_tensor_multi_cores(
    void *p_bmrt,
    const char *net_name,
    const bm_tensor_t input_tensors[],
    int input_num,
    bm_tensor_t output_tensors[],
    int output_num,
    bool user_mem,
    bool user_stmode,
    const int *core_list,
    int core_num);
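A multi-core variant of the launch sketch earlier; the core ids {0, 1} are placeholders, and the exact signature of bm_thread_sync_from_core is an assumption based only on the comment above (handle plus core id):

int core_list[2] = {0, 1};
if (bmrt_launch_tensor_multi_cores(p_bmrt, "my_net", &in, 1, &out, 1,
                                   false, false, core_list, 2)) {
    bm_thread_sync_from_core(bm_handle, core_list[0]);  /* assumed signature */
}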
/**
 * @name    bmrt_memcpy_s2d_parallel
 * @brief   To copy data from system memory to multi-device memory in parallel
 * @ingroup bmruntime
 *
 * This API can only be used when p_bmrt was created with bmrt_create_ex on multiple devices.
 * After calling this API, datas[:tensor_num[0]] will be copied to the first device,
 * datas[tensor_num[0]:tensor_num[0]+tensor_num[1]] will be copied to the second device, and so on.
 * Copying data to different devices is done in parallel; copying to the same device is done in sequence.
 *
 * @param [in] p_bmrt      Bmruntime that had been created with multiple bm_handles
 * @param [in] tensors     Array of tensors that will be copied to the devices
 * @param [in] datas       Array of data buffers allocated in system memory
 * @param [in] tensor_num  Array of the number of tensors to be copied to each device
 * @param [in] device_num  Device number
 */
DECL_EXPORT bool bmrt_memcpy_s2d_parallel(
    void *p_bmrt,
    bm_tensor_t tensors[],
    void *datas[],
    int tensor_num[],
    int device_num);

/**
 * @name    bmrt_memcpy_d2s_parallel
 * @brief   To copy data from multi-device memory to system memory in parallel
 * @ingroup bmruntime
 *
 * This API can only be used when p_bmrt was created with bmrt_create_ex on multiple devices.
 * After calling this API, tensors on the first device will be copied to datas[:tensor_num[0]],
 * tensors on the second device will be copied to datas[tensor_num[0]:tensor_num[0]+tensor_num[1]], and so on.
 * Copying data from different devices is done in parallel; copying from the same device is done in sequence.
 *
 * @param [in] p_bmrt      Bmruntime that had been created with multiple bm_handles
 * @param [in] datas       Array of data buffers allocated in system memory
 * @param [in] tensors     Array of tensors that will be copied from the devices
 * @param [in] tensor_num  Array of the number of tensors to be copied from each device
 * @param [in] device_num  Device number
 */
DECL_EXPORT bool bmrt_memcpy_d2s_parallel(
    void *p_bmrt,
    void *datas[],
    bm_tensor_t tensors[],
    int tensor_num[],
    int device_num);
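A sketch of feeding two devices in one call with the parallel copy above, assuming p_bmrt was created with bmrt_create_ex over two handles and the tensors were created with bmrt_tensor_ex on devices 0 and 1; the host buffers are placeholders:

bm_tensor_t tensors[2];  /* tensors[0] on device 0, tensors[1] on device 1 (via bmrt_tensor_ex) */
void *datas[2]    = { host_buf0, host_buf1 };
int tensor_num[2] = { 1, 1 };                 /* one tensor per device */
bmrt_memcpy_s2d_parallel(p_bmrt, tensors, datas, tensor_num, 2);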
#if defined (__cplusplus)
}
#endif

#endif
ChatGLM2/support/include/sentencepiece/sentencepiece_processor.h
ADDED
@@ -0,0 +1,727 @@
// Copyright 2016 Google Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.!

#ifndef SENTENCEPIECE_PROCESSOR_H_
#define SENTENCEPIECE_PROCESSOR_H_

#include <cstring>
#include <memory>
#include <string>
#include <string_view>
#include <utility>
#include <vector>

#ifndef SWIG
namespace absl {
using std::string_view;
}  // namespace absl
#endif  // SWIG

namespace sentencepiece {
namespace util {

enum class StatusCode : int {
  kOk = 0,
  kCancelled = 1,
  kUnknown = 2,
  kInvalidArgument = 3,
  kDeadlineExceeded = 4,
  kNotFound = 5,
  kAlreadyExists = 6,
  kPermissionDenied = 7,
  kResourceExhausted = 8,
  kFailedPrecondition = 9,
  kAborted = 10,
  kOutOfRange = 11,
  kUnimplemented = 12,
  kInternal = 13,
  kUnavailable = 14,
  kDataLoss = 15,
  kUnauthenticated = 16,
};

class Status {
 public:
  Status();
  ~Status();
  Status(StatusCode code, absl::string_view error_message);
  Status(const Status &s);
  void operator=(const Status &s);
  bool operator==(const Status &s) const;
  bool operator!=(const Status &s) const;
  inline bool ok() const { return rep_ == nullptr; }

  void set_error_message(const char *str);
  const char *error_message() const;
  const char *message() const { return error_message(); }
  StatusCode code() const;
  std::string ToString() const;

  void IgnoreError();

 private:
  struct Rep;
  std::unique_ptr<Rep> rep_;
};
}  // namespace util
// SentencePieceProcessor:
// Simple and language independent tokenizer and de-tokenizer for
// Neural Network Machine Translation.
//
// SentencePieceProcessor provides Encode() and Decode() methods,
// which correspond to tokenization and de-tokenization respectively.
//
// - Encode:
//   Given a raw source sentence, encode it into a sequence
//   of pieces or vocabulary ids.
//
// - Decode:
//   Given a sequence of pieces or vocabulary ids, decode it
//   into a de-tokenized raw sentence.
//
// SentencePieceProcessor provides a lossless data conversion
// that allows the original raw sentence to be perfectly reconstructed
// from the encoded data, i.e., Decode(Encode(input)) == input.
// This characteristic is useful, as we can make the de-tokenization
// completely language independent.
//
// Usage:
//   SentencePieceProcessor sp;
//   sp.Load("//path/to/model");
//
//   vector<string> sps;
//   sp.Encode("hello world.", &sps).IgnoreError();
//
//   vector<int> ids;
//   sp.Encode("hello world.", &ids).IgnoreError();
//
//   string detok;
//   sp.Decode(sps, &detok);
//   CHECK_EQ("hello world.", detok).IgnoreError();
//
//   sp.Decode(ids, &detok);
//   CHECK_EQ("hello world.", detok).IgnoreError();
//
// We can also use SentencePieceText which manages the byte-offsets
// between user input (output) and internal sentence pieces.
//
//   SentencePieceText spt;
//   sp.Encode("hello world.", &spt);
//   // Emits the byte range of each piece.
//   for (const auto &piece : spt.pieces()) {
//     LOG(INFO) << piece.begin() << " " << piece.end();
//   }
//
//   sp.Decode({0, 1, 2, 3..}, &spt);
//   for (const auto &piece : spt.pieces()) {
//     LOG(INFO) << piece.begin() << " " << piece.end();
//   }
//
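Beyond the comment block above, here is a compact status-checked round trip as a standalone program; the model path is a placeholder, and decoding the ids should reproduce the input:

#include <iostream>
#include <string>
#include <vector>
#include "sentencepiece_processor.h"

int main() {
  sentencepiece::SentencePieceProcessor sp;
  const auto status = sp.Load("tokenizer.model");  // placeholder model path
  if (!status.ok()) {
    std::cerr << status.ToString() << std::endl;   // e.g. model file not found
    return 1;
  }
  std::vector<int> ids;
  sp.Encode("hello world.", &ids).IgnoreError();   // text -> ids
  std::string detok;
  sp.Decode(ids, &detok).IgnoreError();            // ids -> text
  std::cout << detok << std::endl;                 // prints "hello world."
  return 0;
}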
class NBestSentencePieceText;
class ModelInterface;
class SentencePieceText;
class ModelProto;

namespace normalizer {
class Normalizer;
}  // namespace normalizer

#ifndef SWIGGO
namespace util {
// Redefine std::string for the serialized_proto interface, as Python's string is
// a Unicode string. We can enforce the return value to be a raw byte sequence
// with SWIG's typemap.
using bytes = std::string;
}  // namespace util
#endif  // SWIGGO

class NBestSentencePieceText;
class ModelInterface;
class SentencePieceText;
class SentencePieceText_SentencePiece;

// Wrapper class of SentencePieceText
// This wrapper only allows an immutable access to the proto and
// hides the actual implementation of protobuf.
// See sentencepiece.proto for the details of this class.
class ImmutableSentencePieceText_ImmutableSentencePiece {
 public:
  ImmutableSentencePieceText_ImmutableSentencePiece();
  ~ImmutableSentencePieceText_ImmutableSentencePiece() = default;

  const std::string &piece() const;
  const std::string &surface() const;
  uint32_t id() const;
  uint32_t begin() const;
  uint32_t end() const;

  friend class ImmutableSentencePieceText;

 private:
  explicit ImmutableSentencePieceText_ImmutableSentencePiece(
      const SentencePieceText_SentencePiece &sp);
  const SentencePieceText_SentencePiece *sp_ = nullptr;
};

class ImmutableSentencePieceText {
 public:
  ImmutableSentencePieceText();
  virtual ~ImmutableSentencePieceText();

  std::vector<ImmutableSentencePieceText_ImmutableSentencePiece> pieces() const;

  size_t pieces_size() const;
  ImmutableSentencePieceText_ImmutableSentencePiece pieces(int index) const;

  const std::string &text() const;
  float score() const;

  util::bytes SerializeAsString() const;

  // Returns the actual mutable proto.
  // Do not use this outside of SentencePieceProcessor, as
  // it returns the raw pointer managed by the shared_ptr.
  SentencePieceText *mutable_proto();

  // Converts the utf8 byte spans into Unicode char spans.
  void ConvertToUnicodeSpans();

  friend class ImmutableNBestSentencePieceText;

 private:
  explicit ImmutableSentencePieceText(const SentencePieceText &spt);
  const SentencePieceText *spt_ = nullptr;
  std::shared_ptr<SentencePieceText> rep_;
};

// Wrapper class of NBestSentencePieceText
// This wrapper only allows an immutable access to the proto and
// hides the actual implementation of protobuf.
// See sentencepiece.proto for the details of this class.
class ImmutableNBestSentencePieceText {
 public:
  ImmutableNBestSentencePieceText();
  virtual ~ImmutableNBestSentencePieceText();

  std::vector<ImmutableSentencePieceText> nbests() const;

  size_t nbests_size() const;
  ImmutableSentencePieceText nbests(int index) const;

  util::bytes SerializeAsString() const;

  // Returns the actual mutable proto.
  // Do not use this outside of SentencePieceProcessor, as
  // it returns the raw pointer managed by the shared_ptr.
  NBestSentencePieceText *mutable_proto();

  void ConvertToUnicodeSpans();

 private:
  std::shared_ptr<NBestSentencePieceText> rep_;
};

class SentencePieceProcessor {
 public:
  SentencePieceProcessor();
  virtual ~SentencePieceProcessor();

  // Loads model from `filename`.
  // Returns an error status if `filename` cannot be loaded.
  virtual util::Status Load(absl::string_view filename);

  // Loads model from `filename`.
  // Crashes if `filename` cannot be loaded.
  virtual void LoadOrDie(absl::string_view filename);

  // Loads model from `model_proto`.
  // `model_proto` is copied.
  virtual util::Status Load(const ModelProto &model_proto);

  // Loads model from `model_proto`.
  // `model_proto` is moved.
  virtual util::Status Load(std::unique_ptr<ModelProto> model_proto);

  // Loads model from `serialized`, which is a string-serialized model proto.
  // Useful to load the model from a platform independent blob object.
  virtual util::Status LoadFromSerializedProto(absl::string_view serialized);

  // Returns the status. Encode/Decode methods are valid when status is OK.
  virtual util::Status status() const;

  // Sets encode extra_option sequence.
  virtual util::Status SetEncodeExtraOptions(absl::string_view extra_option);

  // Sets decode extra_option sequence.
  virtual util::Status SetDecodeExtraOptions(absl::string_view extra_option);

  //////////////////////////////////////////////////////////////
  // Vocabulary restriction.
  // Background:
  // https://github.com/rsennrich/subword-nmt#best-practice-advice-for-byte-pair-encoding-in-nmt

  // Restricts the vocabulary set.
  // The input sentences are encoded into the tokens in `valid_vocab`.
  virtual util::Status SetVocabulary(
      const std::vector<absl::string_view> &valid_vocab);

  // Reverts the vocabulary restriction.
  virtual util::Status ResetVocabulary();

  // Loads the valid vocabulary set from `filename` in TSV format.
  // Format: <token> <tab> <freq>.
  // Any token with frequency < threshold will be treated as OOV.
  virtual util::Status LoadVocabulary(absl::string_view filename,
                                      int threshold);

  //////////////////////////////////////////////////////////////
  // Simple Encode and Decode API.
  //
  // Given a UTF8 input, encodes it into a sequence of sentence pieces.
  virtual util::Status Encode(absl::string_view input,
                              std::vector<std::string> *pieces) const;

  // Given a UTF8 input, encodes it into a sequence of ids.
  virtual util::Status Encode(absl::string_view input,
                              std::vector<int> *ids) const;

  // Given a sequence of pieces, decodes it into a detokenized output.
  virtual util::Status Decode(const std::vector<std::string> &pieces,
                              std::string *detokenized) const;

  // Given a sequence of pieces, decodes it into a detokenized output.
  virtual util::Status Decode(const std::vector<absl::string_view> &pieces,
                              std::string *detokenized) const;

  // Given a sequence of ids, decodes it into a detokenized output.
|
310 |
+
virtual util::Status Decode(const std::vector<int> &ids,
|
311 |
+
std::string *detokenized) const;
|
312 |
+
|
313 |
+
//////////////////////////////////////////////////////////////
|
314 |
+
// NBest API.
|
315 |
+
//
|
316 |
+
// Same as Encode, but returns nbest results.
|
317 |
+
virtual util::Status NBestEncode(
|
318 |
+
absl::string_view input, int nbest_size,
|
319 |
+
std::vector<std::vector<std::string>> *pieces) const;
|
320 |
+
|
321 |
+
// Same as Encode, but returns nbest results.
|
322 |
+
virtual util::Status NBestEncode(absl::string_view input, int nbest_size,
|
323 |
+
std::vector<std::vector<int>> *ids) const;
|
324 |
+
|
325 |
+
//////////////////////////////////////////////////////////////
|
326 |
+
// Sampling API.
|
327 |
+
//
|
328 |
+
// Unigram and BPE support sampling mode.
|
329 |
+
// - Unigram (--model_type=unigram):
|
330 |
+
// `nbest_size`: When `nbest_size` is positive value, approximately samples
|
331 |
+
// one segmentation from nbest candidates. When `nbest_size` is negative
|
332 |
+
// value, samples one segmentation from the hypotheses (Lattice) according to
|
333 |
+
// the generation probabilities using forward-filtering and backward-sampling
|
334 |
+
// algorithm.
|
335 |
+
// `alpha`: Smoothing parameter (inverse temperature). The best segmentation
|
336 |
+
// (Viterbi segmentation) is more likely sampled when setting larger alpha.
|
337 |
+
// When alpha is 0.0, one segmentation is uniformly sampled from the nbest or
|
338 |
+
// lattice. `nbest_size` and `alpha` correspond to parameters `l` and `alpha`
|
339 |
+
// in https://arxiv.org/abs/1804.10959 (nbest_size < 0 means l = infinity)
|
340 |
+
//
|
341 |
+
// - BPE (--model_type=bpe):
|
342 |
+
// `alpha`: The dropout probability `p` of bpe merge operations in
|
343 |
+
// https://arxiv.org/abs/1910.13267 Nbest-based sampling is not supported so
|
344 |
+
// nbest_size parameter is ignored in BPE.
|
345 |
+
virtual util::Status SampleEncode(absl::string_view input, int nbest_size,
|
346 |
+
float alpha,
|
347 |
+
std::vector<std::string> *pieces) const;
|
348 |
+
|
349 |
+
// Same as above, but returns a sequence of ids.
|
350 |
+
virtual util::Status SampleEncode(absl::string_view input, int nbest_size,
|
351 |
+
float alpha, std::vector<int> *ids) const;
|
352 |
+
|
353 |
+
//////////////////////////////////////////////////////////////
|
354 |
+
// SampleEncodeAndScore API.
|
355 |
+
//
|
356 |
+
// Sample `samples` many tokenisations from the segmentation lattice.
|
357 |
+
// These methods are only available in model_type=unigram.
|
358 |
+
//
|
359 |
+
// `alpha`: smoothing parameter (inverse temperature). The same as `alpha` in
|
360 |
+
// `Sample` method.
|
361 |
+
// 'wor`: If `wor` is true, the samples are taken without replacement, and the
|
362 |
+
// scores are the inclusion probabilities of the elements in the sample;
|
363 |
+
// otherwise the samples are taken with replacement and the scores are the
|
364 |
+
// log-probs of sample elements
|
365 |
+
// `include_best`: If `include_best` is true, the best tokenisation is always
|
366 |
+
// included in the sample, and the remaining elements are sampled excluding
|
367 |
+
// the best.
|
368 |
+
virtual util::Status SampleEncodeAndScore(
|
369 |
+
absl::string_view input, int num_samples, float alpha, bool wor,
|
370 |
+
bool include_best,
|
371 |
+
std::vector<std::pair<std::vector<std::string>, float>> *pieces) const;
|
372 |
+
|
373 |
+
// Same as above, but returns a sequence of ids.
|
374 |
+
virtual util::Status SampleEncodeAndScore(
|
375 |
+
absl::string_view input, int num_samples, float alpha, bool wor,
|
376 |
+
bool include_best,
|
377 |
+
std::vector<std::pair<std::vector<int>, float>> *ids) const;
|
378 |
+
|
379 |
+
//////////////////////////////////////////////////////////////
|
380 |
+
// Entropy API.
|
381 |
+
//
|
382 |
+
// This only available in model_type=unigram.
|
383 |
+
// Calculate entropy of possible tokenisations
|
384 |
+
virtual util::Status CalculateEntropy(absl::string_view input, float alpha,
|
385 |
+
float *entropy) const;
|
386 |
+
|
387 |
+
//////////////////////////////////////////////////////////////
|
388 |
+
// Advanced API returning SentencePieceText, which manages
|
389 |
+
// utf8-byte alignments between user-input/detokenized text
|
390 |
+
// and internal sentencepiece sequence.
|
391 |
+
//
|
392 |
+
// Given a UTF8 input, encodes it into SentencePieceText.
|
393 |
+
//
|
394 |
+
// When using these APIs, sentencepiece.pb.h header files must be included.
|
395 |
+
// We can also use ImutableSentencePieceText as follows.
|
396 |
+
//
|
397 |
+
// ImmutableSentencePieceText spt;
|
398 |
+
// Encode("hello", spt.mutable_proto()).IgnoreError();
|
399 |
+
// std::cout << spt.pieces_size() << std::endl;
|
400 |
+
virtual util::Status Encode(absl::string_view input,
|
401 |
+
SentencePieceText *spt) const;
|
402 |
+
|
403 |
+
virtual util::Status NBestEncode(absl::string_view input, int nbest_size,
|
404 |
+
NBestSentencePieceText *nbest_spt) const;
|
405 |
+
|
406 |
+
virtual util::Status SampleEncode(absl::string_view input, int nbest_size,
|
407 |
+
float alpha, SentencePieceText *spt) const;
|
408 |
+
|
409 |
+
virtual util::Status SampleEncodeAndScore(
|
410 |
+
absl::string_view input, int num_samples, float alpha, bool wor,
|
411 |
+
bool include_best, NBestSentencePieceText *samples_spt) const;
|
412 |
+
|
413 |
+
// DEPRECATED: Remove this API and use std::vector<std::string_view>
|
414 |
+
virtual util::Status Decode(const std::vector<std::string> &pieces,
|
415 |
+
SentencePieceText *spt) const;
|
416 |
+
|
417 |
+
virtual util::Status Decode(const std::vector<absl::string_view> &pieces,
|
418 |
+
SentencePieceText *spt) const;
|
419 |
+
|
420 |
+
virtual util::Status Decode(const std::vector<int> &ids,
|
421 |
+
SentencePieceText *spt) const;
|
422 |
+
#ifdef SWIG
|
423 |
+
#define SPP_SWIG_CHECK_AND_THROW \
|
424 |
+
if (!status.ok()) throw status;
|
425 |
+
#else
|
426 |
+
#define SPP_SWIG_CHECK_AND_THROW \
|
427 |
+
if (!status.ok()) { \
|
428 |
+
}
|
429 |
+
#endif // SWIG
|
430 |
+
|
431 |
+
#define DEFINE_SPP_DIRECT_FUNC_IMPL(FuncName, OutType, ...) \
|
432 |
+
OutType output; \
|
433 |
+
const auto status = FuncName(__VA_ARGS__, &output); \
|
434 |
+
SPP_SWIG_CHECK_AND_THROW; \
|
435 |
+
return output;
|
436 |
+
|
437 |
+
#define DEFINE_SPP_SERIALIZED_PROTO_IMPL(FuncName, OutType, ...) \
|
438 |
+
OutType output; \
|
439 |
+
const auto status = FuncName(__VA_ARGS__, output.mutable_proto()); \
|
440 |
+
SPP_SWIG_CHECK_AND_THROW; \
|
441 |
+
return output.SerializeAsString();
|
442 |
+
|
443 |
+
#define DEFINE_SPP_IMMUTABLE_PROTO_IMPL(FuncName, OutType, ...) \
|
444 |
+
OutType output; \
|
445 |
+
const auto status = FuncName(__VA_ARGS__, output.mutable_proto()); \
|
446 |
+
SPP_SWIG_CHECK_AND_THROW; \
|
447 |
+
return output;
|
448 |
+
|
449 |
+
//////////////////////////////////////////////////////////////
|
450 |
+
// Handy methods that return the result directly.
|
451 |
+
// These functions ignore internal errors.
|
452 |
+
virtual std::vector<std::string> EncodeAsPieces(
|
453 |
+
absl::string_view input) const {
|
454 |
+
DEFINE_SPP_DIRECT_FUNC_IMPL(Encode, std::vector<std::string>, input);
|
455 |
+
}
|
456 |
+
|
457 |
+
virtual std::vector<int> EncodeAsIds(absl::string_view input) const {
|
458 |
+
DEFINE_SPP_DIRECT_FUNC_IMPL(Encode, std::vector<int>, input);
|
459 |
+
}
|
460 |
+
|
461 |
+
virtual std::vector<std::vector<std::string>> NBestEncodeAsPieces(
|
462 |
+
absl::string_view input, int nbest_size) const {
|
463 |
+
DEFINE_SPP_DIRECT_FUNC_IMPL(
|
464 |
+
NBestEncode, std::vector<std::vector<std::string>>, input, nbest_size);
|
465 |
+
}
|
466 |
+
|
467 |
+
virtual std::vector<std::vector<int>> NBestEncodeAsIds(
|
468 |
+
absl::string_view input, int nbest_size) const {
|
469 |
+
DEFINE_SPP_DIRECT_FUNC_IMPL(NBestEncode, std::vector<std::vector<int>>,
|
470 |
+
input, nbest_size);
|
471 |
+
}
|
472 |
+
|
473 |
+
virtual std::vector<std::string> SampleEncodeAsPieces(absl::string_view input,
|
474 |
+
int nbest_size,
|
475 |
+
float alpha) const {
|
476 |
+
DEFINE_SPP_DIRECT_FUNC_IMPL(SampleEncode, std::vector<std::string>, input,
|
477 |
+
nbest_size, alpha);
|
478 |
+
}
|
479 |
+
|
480 |
+
virtual std::vector<int> SampleEncodeAsIds(absl::string_view input,
|
481 |
+
int nbest_size,
|
482 |
+
float alpha) const {
|
483 |
+
DEFINE_SPP_DIRECT_FUNC_IMPL(SampleEncode, std::vector<int>, input,
|
484 |
+
nbest_size, alpha);
|
485 |
+
}
|
486 |
+
|
487 |
+
virtual std::vector<std::pair<std::vector<std::string>, float>>
|
488 |
+
SampleEncodeAndScoreAsPieces(absl::string_view input, int num_samples,
|
489 |
+
float alpha, bool wor, bool include_best) const {
|
490 |
+
using _T = std::vector<std::pair<std::vector<std::string>, float>>;
|
491 |
+
DEFINE_SPP_DIRECT_FUNC_IMPL(SampleEncodeAndScore, _T, input, num_samples,
|
492 |
+
alpha, wor, include_best);
|
493 |
+
}
|
494 |
+
|
495 |
+
virtual std::vector<std::pair<std::vector<int>, float>>
|
496 |
+
SampleEncodeAndScoreAsIds(absl::string_view input, int num_samples,
|
497 |
+
float alpha, bool wor, bool include_best) const {
|
498 |
+
using _T = std::vector<std::pair<std::vector<int>, float>>;
|
499 |
+
DEFINE_SPP_DIRECT_FUNC_IMPL(SampleEncodeAndScore, _T, input, num_samples,
|
500 |
+
alpha, wor, include_best);
|
501 |
+
}
|
502 |
+
|
503 |
+
// DEPRECATED: Remove this API and use std::vector<std::string_view>
|
504 |
+
virtual std::string DecodePieces(
|
505 |
+
const std::vector<std::string> &pieces) const {
|
506 |
+
DEFINE_SPP_DIRECT_FUNC_IMPL(Decode, std::string, pieces);
|
507 |
+
}
|
508 |
+
|
509 |
+
virtual std::string DecodePieces(
|
510 |
+
const std::vector<absl::string_view> &pieces) const {
|
511 |
+
DEFINE_SPP_DIRECT_FUNC_IMPL(Decode, std::string, pieces);
|
512 |
+
}
|
513 |
+
|
514 |
+
virtual std::string DecodeIds(const std::vector<int> &ids) const {
|
515 |
+
DEFINE_SPP_DIRECT_FUNC_IMPL(Decode, std::string, ids);
|
516 |
+
}
|
517 |
+
|
518 |
+
virtual float CalculateEntropy(absl::string_view text, float alpha) const {
|
519 |
+
DEFINE_SPP_DIRECT_FUNC_IMPL(CalculateEntropy, float, text, alpha);
|
520 |
+
}
|
521 |
+
|
522 |
+
//////////////////////////////////////////////////////////////
|
523 |
+
// SerializedProto API. (DEPRECATED). Use ImmutableProto API.
|
524 |
+
// They are used in Python interface. Returns serialized proto.
|
525 |
+
// In python module, we can get access to the full Proto after
|
526 |
+
// deserialzing the returned byte sequence.
|
527 |
+
virtual util::bytes EncodeAsSerializedProto(absl::string_view input) const {
|
528 |
+
DEFINE_SPP_SERIALIZED_PROTO_IMPL(Encode, ImmutableSentencePieceText, input);
|
529 |
+
}
|
530 |
+
|
531 |
+
virtual util::bytes SampleEncodeAsSerializedProto(absl::string_view input,
|
532 |
+
int nbest_size,
|
533 |
+
float alpha) const {
|
534 |
+
DEFINE_SPP_SERIALIZED_PROTO_IMPL(SampleEncode, ImmutableSentencePieceText,
|
535 |
+
input, nbest_size, alpha);
|
536 |
+
}
|
537 |
+
|
538 |
+
virtual util::bytes NBestEncodeAsSerializedProto(absl::string_view input,
|
539 |
+
int nbest_size) const {
|
540 |
+
DEFINE_SPP_SERIALIZED_PROTO_IMPL(
|
541 |
+
NBestEncode, ImmutableNBestSentencePieceText, input, nbest_size);
|
542 |
+
}
|
543 |
+
|
544 |
+
virtual util::bytes SampleEncodeAndScoreAsSerializedProto(
|
545 |
+
absl::string_view input, int num_samples, float alpha, bool wor,
|
546 |
+
bool include_best) const {
|
547 |
+
DEFINE_SPP_SERIALIZED_PROTO_IMPL(SampleEncodeAndScore,
|
548 |
+
ImmutableNBestSentencePieceText, input,
|
549 |
+
num_samples, alpha, wor, include_best);
|
550 |
+
}
|
551 |
+
|
552 |
+
// TODO(taku): Remove this API and use std::vector<std::string_view>
|
553 |
+
virtual util::bytes DecodePiecesAsSerializedProto(
|
554 |
+
const std::vector<std::string> &pieces) const {
|
555 |
+
DEFINE_SPP_SERIALIZED_PROTO_IMPL(Decode, ImmutableSentencePieceText,
|
556 |
+
pieces);
|
557 |
+
}
|
558 |
+
|
559 |
+
virtual util::bytes DecodePiecesAsSerializedProto(
|
560 |
+
const std::vector<absl::string_view> &pieces) const {
|
561 |
+
DEFINE_SPP_SERIALIZED_PROTO_IMPL(Decode, ImmutableSentencePieceText,
|
562 |
+
pieces);
|
563 |
+
}
|
564 |
+
|
565 |
+
virtual util::bytes DecodeIdsAsSerializedProto(
|
566 |
+
const std::vector<int> &ids) const {
|
567 |
+
DEFINE_SPP_SERIALIZED_PROTO_IMPL(Decode, ImmutableSentencePieceText, ids);
|
568 |
+
}
|
569 |
+
|
570 |
+
//////////////////////////////////////////////////////////////
|
571 |
+
// ImmutableProto API.
|
572 |
+
virtual ImmutableSentencePieceText EncodeAsImmutableProto(
|
573 |
+
absl::string_view input) const {
|
574 |
+
DEFINE_SPP_IMMUTABLE_PROTO_IMPL(Encode, ImmutableSentencePieceText, input);
|
575 |
+
}
|
576 |
+
|
577 |
+
virtual ImmutableSentencePieceText SampleEncodeAsImmutableProto(
|
578 |
+
absl::string_view input, int nbest_size, float alpha) const {
|
579 |
+
DEFINE_SPP_IMMUTABLE_PROTO_IMPL(SampleEncode, ImmutableSentencePieceText,
|
580 |
+
input, nbest_size, alpha);
|
581 |
+
}
|
582 |
+
|
583 |
+
virtual ImmutableNBestSentencePieceText NBestEncodeAsImmutableProto(
|
584 |
+
absl::string_view input, int nbest_size) const {
|
585 |
+
DEFINE_SPP_IMMUTABLE_PROTO_IMPL(
|
586 |
+
NBestEncode, ImmutableNBestSentencePieceText, input, nbest_size);
|
587 |
+
}
|
588 |
+
|
589 |
+
virtual ImmutableNBestSentencePieceText SampleEncodeAndScoreAsImmutableProto(
|
590 |
+
absl::string_view input, int num_samples, float alpha, bool wor,
|
591 |
+
bool include_best) const {
|
592 |
+
DEFINE_SPP_IMMUTABLE_PROTO_IMPL(SampleEncodeAndScore,
|
593 |
+
ImmutableNBestSentencePieceText, input,
|
594 |
+
num_samples, alpha, wor, include_best);
|
595 |
+
}
|
596 |
+
|
597 |
+
// TODO(taku): Remove this API and use std::vector<std::string_view>
|
598 |
+
virtual ImmutableSentencePieceText DecodePiecesAsImmutableProto(
|
599 |
+
const std::vector<std::string> &pieces) const {
|
600 |
+
DEFINE_SPP_IMMUTABLE_PROTO_IMPL(Decode, ImmutableSentencePieceText, pieces);
|
601 |
+
}
|
602 |
+
|
603 |
+
virtual ImmutableSentencePieceText DecodePiecesAsImmutableProto(
|
604 |
+
const std::vector<absl::string_view> &pieces) const {
|
605 |
+
DEFINE_SPP_IMMUTABLE_PROTO_IMPL(Decode, ImmutableSentencePieceText, pieces);
|
606 |
+
}
|
607 |
+
|
608 |
+
virtual ImmutableSentencePieceText DecodeIdsAsImmutableProto(
|
609 |
+
const std::vector<int> &ids) const {
|
610 |
+
DEFINE_SPP_IMMUTABLE_PROTO_IMPL(Decode, ImmutableSentencePieceText, ids);
|
611 |
+
}
|
612 |
+
|
613 |
+
#undef DEFINE_SPP_DIRECT_FUNC_IMPL
|
614 |
+
#undef DEFINE_SPP_SERIALIZED_PROTO_IMPL
|
615 |
+
#undef DEFINE_SPP_IMMUTABLE_PROTO_IMPL
|
616 |
+
|
617 |
+
//////////////////////////////////////////////////////////////
|
618 |
+
// Vocabulary management methods.
|
619 |
+
//
|
620 |
+
// Returns the size of sentence pieces, which is the same as
|
621 |
+
// the size of vocabulary for NMT.
|
622 |
+
virtual int GetPieceSize() const;
|
623 |
+
|
624 |
+
// Returns the vocab id of `piece`.
|
625 |
+
// Returns UNK(0) if `piece` is unknown.
|
626 |
+
virtual int PieceToId(absl::string_view piece) const;
|
627 |
+
|
628 |
+
// Returns the string representation of vocab with `id`.
|
629 |
+
virtual const std::string &IdToPiece(int id) const;
|
630 |
+
|
631 |
+
// Returns the score of `id`.
|
632 |
+
// Usually score is an emission log probability of unigram language
|
633 |
+
// model.
|
634 |
+
virtual float GetScore(int id) const;
|
635 |
+
|
636 |
+
// Returns true if `id` is unknown symbol.
|
637 |
+
virtual bool IsUnknown(int id) const;
|
638 |
+
|
639 |
+
// Returns true if `id` is control symbol.
|
640 |
+
virtual bool IsControl(int id) const;
|
641 |
+
|
642 |
+
// Returns true if `id` is unused symbol.
|
643 |
+
virtual bool IsUnused(int id) const;
|
644 |
+
|
645 |
+
// Returns true if `id` is byte symbol.
|
646 |
+
virtual bool IsByte(int id) const;
|
647 |
+
|
648 |
+
// Returns the reserved id.
|
649 |
+
// Returns -1 if not defined.
|
650 |
+
|
651 |
+
// Returns unknown (<unk>) id.
|
652 |
+
virtual int unk_id() const;
|
653 |
+
|
654 |
+
// Returns BOS (<s>) id.
|
655 |
+
virtual int bos_id() const;
|
656 |
+
|
657 |
+
// Returns EOS (</s>) id.
|
658 |
+
virtual int eos_id() const;
|
659 |
+
|
660 |
+
// Returns PAD (<pad>) id.
|
661 |
+
virtual int pad_id() const;
|
662 |
+
|
663 |
+
//////////////////////////////////////////////////////////////
|
664 |
+
// Model management.
|
665 |
+
//
|
666 |
+
// Allows injection of a mock model instance. `model` is moved.
|
667 |
+
void SetModel(std::unique_ptr<ModelInterface> &&model);
|
668 |
+
|
669 |
+
// Allows injection of a normalizer instance. `normalizer` is moved.
|
670 |
+
void SetNormalizer(std::unique_ptr<normalizer::Normalizer> &&normalizer);
|
671 |
+
|
672 |
+
// Returns immutable model proto. Useful to obtain extended
|
673 |
+
// or experimental parameters encoded in model_proto.
|
674 |
+
const ModelProto &model_proto() const;
|
675 |
+
|
676 |
+
// returns immutable model proto as std::string.
|
677 |
+
// Useful to save the state of this instance via Python's pickle object.
|
678 |
+
util::bytes serialized_model_proto() const;
|
679 |
+
|
680 |
+
private:
|
681 |
+
enum ExtraOption { REVERSE, BOS, EOS, UNK_PIECE };
|
682 |
+
|
683 |
+
util::Status ParseExtraOptions(absl::string_view extra_option,
|
684 |
+
std::vector<ExtraOption> *extra_options) const;
|
685 |
+
|
686 |
+
util::Status ApplyExtraOptions(const std::vector<ExtraOption> &extra_options,
|
687 |
+
SentencePieceText *spt) const;
|
688 |
+
|
689 |
+
util::Status PopulateSentencePieceText(
|
690 |
+
absl::string_view input, absl::string_view normalized,
|
691 |
+
const std::vector<size_t> &norm_to_orig,
|
692 |
+
const std::vector<std::pair<absl::string_view, int>> &result,
|
693 |
+
SentencePieceText *spt) const;
|
694 |
+
|
695 |
+
std::unique_ptr<ModelInterface> model_;
|
696 |
+
std::unique_ptr<normalizer::Normalizer> normalizer_;
|
697 |
+
std::unique_ptr<normalizer::Normalizer> denormalizer_;
|
698 |
+
|
699 |
+
// Underlying model protocol buffer. The same lifetime as model_.
|
700 |
+
std::unique_ptr<ModelProto> model_proto_;
|
701 |
+
|
702 |
+
std::vector<ExtraOption> encode_extra_options_;
|
703 |
+
std::vector<ExtraOption> decode_extra_options_;
|
704 |
+
};
|
705 |
+
|
706 |
+
// Set seed value of random generator.
|
707 |
+
// Do not set static_cast<unique_int>(-1),
|
708 |
+
// as this seed is reserved for initializing from
|
709 |
+
// std::random_device.
|
710 |
+
void SetRandomGeneratorSeed(unsigned int seed);
|
711 |
+
|
712 |
+
// IO related functions to absorb model formats.
|
713 |
+
namespace io {
|
714 |
+
// Loads `model_proto` from `filename`.
|
715 |
+
// We can instantiate SentencePieceProcessor as follows:
|
716 |
+
//
|
717 |
+
// auto model_proto = absl::make_unique<ModelProto>();
|
718 |
+
// io::LoadModelProto("//path/spm.model", model_proto.get());
|
719 |
+
// SentencePieceProcessor sp;
|
720 |
+
// CHECK_OK(sp.Load(std::move(model_proto)));
|
721 |
+
util::Status LoadModelProto(absl::string_view, ModelProto *model_proto);
|
722 |
+
|
723 |
+
// Saves `model_proto` as `filename`.
|
724 |
+
util::Status SaveModelProto(absl::string_view, const ModelProto &model_proto);
|
725 |
+
} // namespace io
|
726 |
+
} // namespace sentencepiece
|
727 |
+
#endif // SENTENCEPIECE_PROCESSOR_H_
|
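The declarations above are the full public surface that the bundled libsentencepiece.a exposes to the ChatGLM2 demo, and, as the SWIG guards suggest, the same surface is what the `sentencepiece` pip package wraps for Python. A minimal sketch of the Encode/Decode, sampling, and vocabulary-management methods, assuming the pip package is installed and pointing it at the tokenizer.model shipped under ChatGLM2/support/tokenizer/ (not an official part of this repo's tooling):

import sentencepiece as spm

# Load the model; this reports an error if the file cannot be loaded
# (the C++ LoadOrDie variant crashes instead).
sp = spm.SentencePieceProcessor()
sp.Load("ChatGLM2/support/tokenizer/tokenizer.model")

# Simple Encode and Decode API.
pieces = sp.EncodeAsPieces("Hello world")  # e.g. ['▁Hello', '▁world']
ids = sp.EncodeAsIds("Hello world")
print(sp.DecodeIds(ids))                   # detokenizes back for plain text

# Sampling API: nbest_size/alpha behave as documented above (nbest or
# lattice sampling for unigram models, merge dropout for BPE models).
sampled = sp.SampleEncodeAsPieces("Hello world", -1, 0.1)

# Vocabulary management.
print(sp.GetPieceSize(), sp.unk_id(), sp.bos_id(), sp.eos_id(), sp.pad_id())
print(sp.PieceToId(pieces[0]), sp.IdToPiece(ids[0]))

The `*AsPieces`/`*AsIds` wrappers are exactly the `DEFINE_SPP_DIRECT_FUNC_IMPL` expansions: each calls the `util::Status`-returning overload, checks (or, under SWIG, rethrows) the status, and returns the output by value.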
ChatGLM2/support/lib_pcie/libbmlib.so
ADDED
Binary file (195 kB).

ChatGLM2/support/lib_pcie/libbmrt.so
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:621e33823dca470275e09570324a567ce4a30fa6100ac9e52742bb9e1ee02f45
size 2966400

ChatGLM2/support/lib_pcie/libbmrt.so.1.0
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:621e33823dca470275e09570324a567ce4a30fa6100ac9e52742bb9e1ee02f45
size 2966400

ChatGLM2/support/lib_pcie/libsentencepiece.a
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:68811cd99e6e1a58572372f14f3b7a02cf98bc98f5d46d24c406be65a94b53e8
size 2858304

ChatGLM2/support/lib_soc/libbmlib.so
ADDED
Binary file (191 kB).

ChatGLM2/support/lib_soc/libbmrt.so
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:cff807807fcc8c6a9d16353e389422d434ae2b79c8bc191266d0eb5a69b3d97d
size 2915352

ChatGLM2/support/lib_soc/libbmrt.so.1.0
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:cff807807fcc8c6a9d16353e389422d434ae2b79c8bc191266d0eb5a69b3d97d
size 2915352

ChatGLM2/support/lib_soc/libsentencepiece.a
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:5b1c1ece6c62265ee879cf5876d31e82580c3ee88c2cb627b8ac3eaf35695bde
size 3032062
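Each of the library stubs above is a Git LFS pointer rather than the binary itself: three text lines carrying the spec version, the SHA-256 of the real blob, and its size in bytes; `git lfs pull` replaces them with the actual files. As a small sketch of inspecting such a pointer (the `read_lfs_pointer` helper is hypothetical, not part of this repo):

def read_lfs_pointer(path: str) -> dict:
    """Parse a Git LFS pointer file into its key/value fields."""
    fields = {}
    with open(path, "r", encoding="utf-8") as f:
        for line in f:
            key, _, value = line.strip().partition(" ")
            fields[key] = value
    fields["size"] = int(fields["size"])  # size of the real blob, in bytes
    return fields

# read_lfs_pointer("ChatGLM2/support/lib_pcie/libbmrt.so") ->
# {'version': 'https://git-lfs.github.com/spec/v1',
#  'oid': 'sha256:621e33823dca470275e09570324a567ce4a30fa6100ac9e52742bb9e1ee02f45',
#  'size': 2966400}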
ChatGLM2/support/tokenizer/tokenization_chatglm.py
ADDED
@@ -0,0 +1,257 @@
import os
import torch
from typing import List, Optional, Union, Dict
from sentencepiece import SentencePieceProcessor
from transformers import PreTrainedTokenizer
from transformers.utils import logging, PaddingStrategy
from transformers.tokenization_utils_base import EncodedInput, BatchEncoding


class SPTokenizer:
    def __init__(self, model_path: str):
        # reload tokenizer
        assert os.path.isfile(model_path), model_path
        self.sp_model = SentencePieceProcessor(model_file=model_path)

        # BOS / EOS token IDs
        self.n_words: int = self.sp_model.vocab_size()
        self.bos_id: int = self.sp_model.bos_id()
        self.eos_id: int = self.sp_model.eos_id()
        self.pad_id: int = self.sp_model.unk_id()
        assert self.sp_model.vocab_size() == self.sp_model.get_piece_size()

        special_tokens = ["[MASK]", "[gMASK]", "[sMASK]", "sop", "eop"]
        self.special_tokens = {}
        self.index_special_tokens = {}
        for token in special_tokens:
            self.special_tokens[token] = self.n_words
            self.index_special_tokens[self.n_words] = token
            self.n_words += 1

    def tokenize(self, s: str):
        return self.sp_model.EncodeAsPieces(s)

    def encode(self, s: str, bos: bool = False, eos: bool = False) -> List[int]:
        assert type(s) is str
        t = self.sp_model.encode(s)
        if bos:
            t = [self.bos_id] + t
        if eos:
            t = t + [self.eos_id]
        return t

    def decode(self, t: List[int]) -> str:
        return self.sp_model.decode(t)

    def decode_tokens(self, tokens: List[str]) -> str:
        text = self.sp_model.DecodePieces(tokens)
        return text

    def convert_token_to_id(self, token):
        """ Converts a token (str) to an id using the vocab. """
        if token in self.special_tokens:
            return self.special_tokens[token]
        return self.sp_model.PieceToId(token)

    def convert_id_to_token(self, index):
        """Converts an index (integer) to a token (str) using the vocab."""
        if index in self.index_special_tokens or index in [self.eos_id, self.bos_id, self.pad_id] or index < 0:
            return ""
        return self.sp_model.IdToPiece(index)


class ChatGLMTokenizer(PreTrainedTokenizer):
    vocab_files_names = {"vocab_file": "tokenizer.model"}

    model_input_names = ["input_ids", "attention_mask", "position_ids"]

    def __init__(self, vocab_file, padding_side="left", clean_up_tokenization_spaces=False, **kwargs):
        self.name = "GLMTokenizer"

        self.vocab_file = vocab_file
        self.tokenizer = SPTokenizer(vocab_file)
        self.special_tokens = {
            "<bos>": self.tokenizer.bos_id,
            "<eos>": self.tokenizer.eos_id,
            "<pad>": self.tokenizer.pad_id
        }
        super().__init__(padding_side=padding_side, clean_up_tokenization_spaces=clean_up_tokenization_spaces, **kwargs)

    def get_command(self, token):
        if token in self.special_tokens:
            return self.special_tokens[token]
        assert token in self.tokenizer.special_tokens, f"{token} is not a special token for {self.name}"
        return self.tokenizer.special_tokens[token]

    @property
    def unk_token(self) -> str:
        return "<unk>"

    @property
    def pad_token(self) -> str:
        return "<unk>"

    @property
    def pad_token_id(self):
        return self.get_command("<pad>")

    @property
    def eos_token(self) -> str:
        return "</s>"

    @property
    def eos_token_id(self):
        return self.get_command("<eos>")

    @property
    def vocab_size(self):
        return self.tokenizer.n_words

    def get_vocab(self):
        """ Returns vocab as a dict """
        vocab = {self._convert_id_to_token(i): i for i in range(self.vocab_size)}
        vocab.update(self.added_tokens_encoder)
        return vocab

    def _tokenize(self, text, **kwargs):
        return self.tokenizer.tokenize(text)

    def _convert_token_to_id(self, token):
        """ Converts a token (str) to an id using the vocab. """
        return self.tokenizer.convert_token_to_id(token)

    def _convert_id_to_token(self, index):
        """Converts an index (integer) to a token (str) using the vocab."""
        return self.tokenizer.convert_id_to_token(index)

    def convert_tokens_to_string(self, tokens: List[str]) -> str:
        return self.tokenizer.decode_tokens(tokens)

    def save_vocabulary(self, save_directory, filename_prefix=None):
        """
        Save the vocabulary and special tokens file to a directory.

        Args:
            save_directory (`str`):
                The directory in which to save the vocabulary.
            filename_prefix (`str`, *optional*):
                An optional prefix to add to the name of the saved files.

        Returns:
            `Tuple(str)`: Paths to the files saved.
        """
        if os.path.isdir(save_directory):
            vocab_file = os.path.join(
                save_directory, self.vocab_files_names["vocab_file"]
            )
        else:
            vocab_file = save_directory

        with open(self.vocab_file, 'rb') as fin:
            proto_str = fin.read()

        with open(vocab_file, "wb") as writer:
            writer.write(proto_str)

        return (vocab_file,)

    def get_prefix_tokens(self):
        prefix_tokens = [self.get_command("[gMASK]"), self.get_command("sop")]
        return prefix_tokens

    def build_prompt(self, query, history=None):
        if history is None:
            history = []
        prompt = ""
        for i, (old_query, response) in enumerate(history):
            prompt += "[Round {}]\n\n问:{}\n\n答:{}\n\n".format(i + 1, old_query, response)
        prompt += "[Round {}]\n\n问:{}\n\n答:".format(len(history) + 1, query)
        return prompt

    def build_inputs_with_special_tokens(
            self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
    ) -> List[int]:
        """
        Build model inputs from a sequence or a pair of sequences for sequence classification tasks by concatenating
        and adding special tokens. A BERT sequence has the following format:

        - single sequence: `[CLS] X [SEP]`
        - pair of sequences: `[CLS] A [SEP] B [SEP]`

        Args:
            token_ids_0 (`List[int]`):
                List of IDs to which the special tokens will be added.
            token_ids_1 (`List[int]`, *optional*):
                Optional second list of IDs for sequence pairs.

        Returns:
            `List[int]`: List of [input IDs](../glossary#input-ids) with the appropriate special tokens.
        """
        prefix_tokens = self.get_prefix_tokens()
        token_ids_0 = prefix_tokens + token_ids_0
        if token_ids_1 is not None:
            token_ids_0 = token_ids_0 + token_ids_1 + [self.get_command("<eos>")]
        return token_ids_0

    def _pad(
            self,
            encoded_inputs: Union[Dict[str, EncodedInput], BatchEncoding],
            max_length: Optional[int] = None,
            padding_strategy: PaddingStrategy = PaddingStrategy.DO_NOT_PAD,
            pad_to_multiple_of: Optional[int] = None,
            return_attention_mask: Optional[bool] = None,
    ) -> dict:
        """
        Pad encoded inputs (on left/right and up to predefined length or max length in the batch).

        Args:
            encoded_inputs:
                Dictionary of tokenized inputs (`List[int]`) or batch of tokenized inputs (`List[List[int]]`).
            max_length: maximum length of the returned list and optionally padding length (see below).
                Will truncate by taking into account the special tokens.
            padding_strategy: PaddingStrategy to use for padding.

                - PaddingStrategy.LONGEST: Pad to the longest sequence in the batch
                - PaddingStrategy.MAX_LENGTH: Pad to the max length (default)
                - PaddingStrategy.DO_NOT_PAD: Do not pad
                The tokenizer padding sides are defined in self.padding_side:

                    - 'left': pads on the left of the sequences
                    - 'right': pads on the right of the sequences
            pad_to_multiple_of: (optional) Integer if set will pad the sequence to a multiple of the provided value.
                This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute capability
                `>= 7.5` (Volta).
            return_attention_mask:
                (optional) Set to False to avoid returning attention mask (default: set to model specifics)
        """
        # Load from model defaults
        assert self.padding_side == "left"

        required_input = encoded_inputs[self.model_input_names[0]]
        seq_length = len(required_input)

        if padding_strategy == PaddingStrategy.LONGEST:
            max_length = len(required_input)

        if max_length is not None and pad_to_multiple_of is not None and (max_length % pad_to_multiple_of != 0):
            max_length = ((max_length // pad_to_multiple_of) + 1) * pad_to_multiple_of

        needs_to_be_padded = padding_strategy != PaddingStrategy.DO_NOT_PAD and len(required_input) != max_length

        # Initialize attention mask if not present.
        if "attention_mask" not in encoded_inputs:
            encoded_inputs["attention_mask"] = [1] * seq_length

        if "position_ids" not in encoded_inputs:
            encoded_inputs["position_ids"] = list(range(seq_length))

        if needs_to_be_padded:
            difference = max_length - len(required_input)

            if "attention_mask" in encoded_inputs:
                encoded_inputs["attention_mask"] = [0] * difference + encoded_inputs["attention_mask"]
            if "position_ids" in encoded_inputs:
                encoded_inputs["position_ids"] = [0] * difference + encoded_inputs["position_ids"]
            encoded_inputs[self.model_input_names[0]] = [self.pad_token_id] * difference + required_input

        return encoded_inputs
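Read together, these methods define the demo's input pipeline: `build_prompt` renders the chat history into the `[Round N]` template, encoding runs `_tokenize` followed by `build_inputs_with_special_tokens` (which prepends the `[gMASK]` and `sop` ids), and `_pad` pads on the left so that, in a batch, every prompt ends exactly where generation starts. A minimal sketch, assuming this file and the `tokenizer.model` below sit in the working directory and a transformers version contemporary with this code:

from tokenization_chatglm import ChatGLMTokenizer

tokenizer = ChatGLMTokenizer(vocab_file="tokenizer.model")

history = [("What is the capital of France?", "Paris.")]
prompt = tokenizer.build_prompt("How large is it?", history=history)
# prompt == "[Round 1]\n\n问:What is the capital of France?\n\n答:Paris.\n\n" \
#           "[Round 2]\n\n问:How large is it?\n\n答:"

# Encoding prepends the prefix tokens, so input_ids start with [gMASK], sop:
inputs = tokenizer(prompt)
assert inputs["input_ids"][:2] == tokenizer.get_prefix_tokens()

# Left padding: pad ids are prepended to input_ids, and zeros to the
# attention_mask and position_ids; nothing is appended on the right.
batch = tokenizer([prompt], padding="max_length", max_length=128)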
ChatGLM2/support/tokenizer/tokenizer.model
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:e7dc4c393423b76e4373e5157ddc34803a0189ba96b21ddbb40269d31468a6f2
size 1018370