JoshuaChak committed on
Commit
7c071a8
1 Parent(s): ddb8425

Upload folder using huggingface_hub

This view is limited to 50 files because it contains too many changes. See raw diff.
Files changed (50):
  1. .gitattributes +45 -0
  2. Baichuan2/README.md +182 -0
  3. Baichuan2/compile/compile.sh +186 -0
  4. Baichuan2/compile/export_onnx.py +182 -0
  5. Baichuan2/compile/files/Baichuan2-7B/config.json +29 -0
  6. Baichuan2/compile/files/Baichuan2-7B/modeling_baichuan.py +792 -0
  7. Baichuan2/compile/torch_inference.py +16 -0
  8. Baichuan2/demo/CMakeLists.txt +38 -0
  9. Baichuan2/demo/demo.cpp +472 -0
  10. Baichuan2/model/tokenizer.model +3 -0
  11. Baichuan2/requirements.txt +7 -0
  12. Baichuan2/src/include/bmdef.h +129 -0
  13. Baichuan2/src/include/bmlib_runtime.h +2581 -0
  14. Baichuan2/src/include/bmruntime_interface.h +404 -0
  15. Baichuan2/src/include/sentencepiece/sentencepiece_processor.h +727 -0
  16. Baichuan2/src/lib_pcie/libbmlib.so +0 -0
  17. Baichuan2/src/lib_pcie/libbmrt.so +3 -0
  18. Baichuan2/src/lib_pcie/libbmrt.so.1.0 +3 -0
  19. Baichuan2/src/lib_pcie/libsentencepiece.a +3 -0
  20. Baichuan2/src/lib_soc/libbmlib.so +0 -0
  21. Baichuan2/src/lib_soc/libbmrt.so +3 -0
  22. Baichuan2/src/lib_soc/libbmrt.so.1.0 +3 -0
  23. Baichuan2/src/lib_soc/libsentencepiece.a +3 -0
  24. Baichuan2/web_demo/CMakeLists.txt +36 -0
  25. Baichuan2/web_demo/chat.cpp +419 -0
  26. Baichuan2/web_demo/chat.py +97 -0
  27. Baichuan2/web_demo/web_demo.py +108 -0
  28. BaseModel/base_model.py +184 -0
  29. ChatGLM2/README.md +160 -0
  30. ChatGLM2/compile/compile.sh +179 -0
  31. ChatGLM2/compile/export_onnx.py +176 -0
  32. ChatGLM2/compile/files/chatglm2-6b/config.json +42 -0
  33. ChatGLM2/compile/files/chatglm2-6b/modeling_chatglm.py +1285 -0
  34. ChatGLM2/demo/CMakeLists.txt +33 -0
  35. ChatGLM2/demo/demo.cpp +609 -0
  36. ChatGLM2/run_demo.sh +27 -0
  37. ChatGLM2/support/include/bmdef.h +129 -0
  38. ChatGLM2/support/include/bmlib_runtime.h +2581 -0
  39. ChatGLM2/support/include/bmruntime_interface.h +404 -0
  40. ChatGLM2/support/include/sentencepiece/sentencepiece_processor.h +727 -0
  41. ChatGLM2/support/lib_pcie/libbmlib.so +0 -0
  42. ChatGLM2/support/lib_pcie/libbmrt.so +3 -0
  43. ChatGLM2/support/lib_pcie/libbmrt.so.1.0 +3 -0
  44. ChatGLM2/support/lib_pcie/libsentencepiece.a +3 -0
  45. ChatGLM2/support/lib_soc/libbmlib.so +0 -0
  46. ChatGLM2/support/lib_soc/libbmrt.so +3 -0
  47. ChatGLM2/support/lib_soc/libbmrt.so.1.0 +3 -0
  48. ChatGLM2/support/lib_soc/libsentencepiece.a +3 -0
  49. ChatGLM2/support/tokenizer/tokenization_chatglm.py +257 -0
  50. ChatGLM2/support/tokenizer/tokenizer.model +3 -0
.gitattributes CHANGED
@@ -34,3 +34,48 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
  qwen1.5-1.8b_int4_1dev.bmodel filter=lfs diff=lfs merge=lfs -text
37
+ Baichuan2/src/lib_pcie/libbmrt.so filter=lfs diff=lfs merge=lfs -text
38
+ Baichuan2/src/lib_pcie/libbmrt.so.1.0 filter=lfs diff=lfs merge=lfs -text
39
+ Baichuan2/src/lib_pcie/libsentencepiece.a filter=lfs diff=lfs merge=lfs -text
40
+ Baichuan2/src/lib_soc/libbmrt.so filter=lfs diff=lfs merge=lfs -text
41
+ Baichuan2/src/lib_soc/libbmrt.so.1.0 filter=lfs diff=lfs merge=lfs -text
42
+ Baichuan2/src/lib_soc/libsentencepiece.a filter=lfs diff=lfs merge=lfs -text
43
+ ChatGLM2/support/lib_pcie/libbmrt.so filter=lfs diff=lfs merge=lfs -text
44
+ ChatGLM2/support/lib_pcie/libbmrt.so.1.0 filter=lfs diff=lfs merge=lfs -text
45
+ ChatGLM2/support/lib_pcie/libsentencepiece.a filter=lfs diff=lfs merge=lfs -text
46
+ ChatGLM2/support/lib_soc/libbmrt.so filter=lfs diff=lfs merge=lfs -text
47
+ ChatGLM2/support/lib_soc/libbmrt.so.1.0 filter=lfs diff=lfs merge=lfs -text
48
+ ChatGLM2/support/lib_soc/libsentencepiece.a filter=lfs diff=lfs merge=lfs -text
49
+ ChatGLM3/support/lib_pcie/libbmrt.so filter=lfs diff=lfs merge=lfs -text
50
+ ChatGLM3/support/lib_pcie/libbmrt.so.1.0 filter=lfs diff=lfs merge=lfs -text
51
+ ChatGLM3/support/lib_pcie/libsentencepiece.a filter=lfs diff=lfs merge=lfs -text
52
+ ChatGLM3/support/lib_soc/libbmrt.so filter=lfs diff=lfs merge=lfs -text
53
+ ChatGLM3/support/lib_soc/libbmrt.so.1.0 filter=lfs diff=lfs merge=lfs -text
54
+ ChatGLM3/support/lib_soc/libsentencepiece.a filter=lfs diff=lfs merge=lfs -text
55
+ DeepSeek/requirements/sophon-3.7.0-py3-none-any.whl filter=lfs diff=lfs merge=lfs -text
56
+ LWM/support/lib_pcie/libbmrt.so filter=lfs diff=lfs merge=lfs -text
57
+ LWM/support/lib_pcie/libbmrt.so.1.0 filter=lfs diff=lfs merge=lfs -text
58
+ LWM/support/lib_pcie/libsentencepiece.a filter=lfs diff=lfs merge=lfs -text
59
+ LWM/support/lib_soc/libbmrt.so filter=lfs diff=lfs merge=lfs -text
60
+ LWM/support/lib_soc/libbmrt.so.1.0 filter=lfs diff=lfs merge=lfs -text
61
+ LWM/support/lib_soc/libsentencepiece.a filter=lfs diff=lfs merge=lfs -text
62
+ Llama2/assets/llama2_pcie filter=lfs diff=lfs merge=lfs -text
63
+ Llama2/demo_parallel/lib/libsentencepiece.a filter=lfs diff=lfs merge=lfs -text
64
+ Llama2/support/lib_pcie/libbmrt.so.1.0 filter=lfs diff=lfs merge=lfs -text
65
+ Llama2/support/lib_pcie/libsentencepiece.a filter=lfs diff=lfs merge=lfs -text
66
+ Llama2/support/lib_soc/libbmrt.so.1.0 filter=lfs diff=lfs merge=lfs -text
67
+ Llama2/support/lib_soc/libsentencepiece.a filter=lfs diff=lfs merge=lfs -text
68
+ Llama3/python_demo/build/CMakeFiles/chat.dir/chat.cpp.o filter=lfs diff=lfs merge=lfs -text
69
+ Llama3/python_demo/build/chat.cpython-38-aarch64-linux-gnu.so filter=lfs diff=lfs merge=lfs -text
70
+ Llama3/python_demo/chat.cpython-38-aarch64-linux-gnu.so filter=lfs diff=lfs merge=lfs -text
71
+ Qwen1_5/python_demo/build/CMakeFiles/chat.dir/chat.cpp.o filter=lfs diff=lfs merge=lfs -text
72
+ Qwen1_5/python_demo/build/CMakeFiles/chat_parallel.dir/chat_parallel.cpp.o filter=lfs diff=lfs merge=lfs -text
73
+ Qwen1_5/python_demo/build/chat.cpython-38-aarch64-linux-gnu.so filter=lfs diff=lfs merge=lfs -text
74
+ Qwen1_5/python_demo/build/chat_parallel.cpython-38-aarch64-linux-gnu.so filter=lfs diff=lfs merge=lfs -text
75
+ Qwen1_5/python_demo/chat.cpython-38-aarch64-linux-gnu.so filter=lfs diff=lfs merge=lfs -text
76
+ Qwen1_5/python_demo/chat_parallel.cpython-38-aarch64-linux-gnu.so filter=lfs diff=lfs merge=lfs -text
77
+ WizardCoder/demo/lib_pcie/lib/libbmrt.so filter=lfs diff=lfs merge=lfs -text
78
+ WizardCoder/demo/lib_pcie/lib/libbmrt.so.1.0 filter=lfs diff=lfs merge=lfs -text
79
+ WizardCoder/demo/lib_soc/lib/libbmrt.so filter=lfs diff=lfs merge=lfs -text
80
+ WizardCoder/demo/lib_soc/lib/libbmrt.so.1.0 filter=lfs diff=lfs merge=lfs -text
81
+ Yi34B/demo_parallel/lib/libsentencepiece.a filter=lfs diff=lfs merge=lfs -text
Baichuan2/README.md ADDED
@@ -0,0 +1,182 @@
![image](../../assets/sophgo_chip.png)

# Baichuan2-TPU

This project deploys the large language model [Baichuan2-7B](https://huggingface.co/baichuan-inc/Baichuan2-7B-Chat) on the BM1684X. The model is converted into a bmodel with the [TPU-MLIR](https://github.com/sophgo/tpu-mlir) compiler and then deployed with C++ code to a BM1684X PCIe or SoC environment.

The instructions below assume a PCIe environment; for a SoC environment, follow the notes where indicated.

# Directory layout
```
.
├── README.md                      # this guide
├── requirements.txt               # required Python wheel packages
├── compile
│   ├── compile.sh                 # script that compiles the TPU model
│   ├── export_onnx.py             # script that exports the ONNX models
│   ├── torch_inference.py         # torch inference script
│   └── files
│       └── Baichuan2-7B           # backup of the files that replace those in Baichuan2-7B-Chat
│           ├── config.json
│           └── modeling_baichuan.py
├── demo                           # Baichuan2 C++ demo
│   ├── CMakeLists.txt
│   └── demo.cpp                   # main program
├── src                            # build dependencies
│   ├── include
│   ├── lib_pcie
│   └── lib_soc
├── model                          # model files (the bmodel must be downloaded)
│   ├── baichuan2-7b-test_int8.bmodel
│   └── tokenizer.model
└── web_demo                       # web demo providing a browser chat example
    ├── chat.cpp
    ├── chat.py
    ├── CMakeLists.txt
    └── web_demo.py
```
----------------------------

# Stage 1: model compilation

## Notes
* Model compilation must be done inside the docker container; it cannot be done outside of docker.

### Step 1: download the model
The Baichuan2 model is fully open-sourced on Hugging Face. Download the model and weights following the official instructions.
```bash
# Make sure you have git-lfs installed (https://git-lfs.com)
git lfs install
git clone https://huggingface.co/baichuan-inc/Baichuan2-7B-Chat
```

### Step 2: download docker

Pull the docker image and start a container:

``` shell
docker pull sophgo/tpuc_dev:latest

# myname1234 is just an example, you can set your own name
docker run --privileged --name myname1234 -v $PWD:/workspace -it sophgo/tpuc_dev:latest
```

### Step 3: download and build TPU-MLIR

``` shell
git clone git@github.com:sophgo/tpu-mlir.git
cd tpu-mlir
source ./envsetup.sh
./build.sh
```
* Note: whenever you re-enter the docker environment and need to compile a model, you must run `source ./envsetup.sh` and `./build.sh` again from this directory before any further model compilation.

### Step 4: download this project and install requirements.txt
Download transformers, sentencepiece, Baichuan2-TPU and the .bin weights from the Baidu Netdisk share, then replace the modeling_baichuan.py used by transformers as described in the next step.

``` shell
git clone https://github.com/sophgo/Baichuan2-TPU.git
cd Baichuan2
pip3 install -r requirements.txt
```

### Step 5: replace modeling_baichuan.py, modify config.json, and export the ONNX files
In the Baichuan2-7B-Chat project, change max_position_embeddings and model_max_length in config.json from 4096 to 512.

``` shell
cd compile
cp files/Baichuan2-7B/modeling_baichuan.py $BAICHUAN2_PATH
cp files/Baichuan2-7B/config.json $BAICHUAN2_PATH
python3 export_onnx.py --model_path $BAICHUAN2_PATH
```

* Note 1: `$BAICHUAN2_PATH` is the path where the original model was downloaded, e.g. "../../torch2onnx/Baichuan2-7B-Chat"; you can use either the 7B or the 13B model as needed.
* Note 2: if you want to debug instead of exporting all the ONNX models at once, you can change num_layers (line 240) to 1 and check whether a single exported block still matches the original torch model, as in the sketch below.

+ ### 步骤六:生成bmodel文件
97
+
98
+ 生成模型
99
+
100
+ ``` shell
101
+ ./compile.sh --mode int8
102
+ mv baichuan2-7b_int8_1dev.bmodel ../model
103
+ ```
104
+
105
+ * PS1:编译完成后最终会在Baichuan2-TPU/compile路径下生成名为baichuan2-{X}b_{Y}_{Z}dev.bmodel,其中X为7或13,Y为`compile.sh`时选择的`mode`的数据类型,Z为推理的芯片数量(如果不指定num_device, 会省略{Z}dev的部分)
106
+ * PS2:生成bmodel耗时大概3小时以上,建议64G内存以及200G以上硬盘空间,不然很可能OOM或者no space left
107
+ * PS3:目前给定的lib_pcie和lib_soc部分仅包含单芯的动态库,多芯部分会在后续更新
108
+
109
+ ----------------------------
110
+
111
+ # 阶段二:可执行文件生成(可以跳过)
112
+
113
+ ## 准备
114
+ * bmodel模型准备:经过阶段一后将得到编译好的bmodel文件【也可以使用我们提供的现成编译好的bmodel文件】,下载方式为:
115
+ ```shell
116
+ cd Baichuan2-TPU/model
117
+ pip3 install dfss
118
+ # baichuan2-7B
119
+ python3 -m dfss --url=open@sophgo.com:sophon-demo/baichuan2/baichuan2-7b-test_int8.bmodel
120
+ ```
121
+ 将得到编译好的int8单芯bmodel模型文件。
122
+
123
+ ## 编译程序(C++版本)
124
+
125
+ 执行如下编译,默认是PCIE版本:
126
+
127
+ ```shell
128
+ cd Baichuan2-TPU/demo
129
+ mkdir build
130
+ cd build
131
+ cmake ..
132
+ make
133
+ ```
134
+
135
+ 如果是SoC版本,有两种编译方法:
136
+
137
+ 方法1:直接将demo目录拷贝到SoC环境,按以上步骤编译(推荐)
138
+
139
+ 方法2:docker中交叉编译,如下操作
140
+
141
+ ```shell
142
+ wget https://releases.linaro.org/components/toolchain/binaries/7.5-2019.12/aarch64-linux-gnu/gcc-linaro-7.5.0-2019.12-x86_64_aarch64-linux-gnu.tar.xz
143
+ tar -xvf gcc-linaro-7.5.0-2019.12-x86_64_aarch64-linux-gnu.tar.xz
144
+ mv gcc-linaro-7.5.0-2019.12-x86_64_aarch64-linux-gnu /opt/aarch64-linux-gnu-7.5.0
145
+ cd Baichuan2-TPU/demo
146
+ mkdir build
147
+ cd build
148
+ cmake .. -DTARGET_ARCH=soc # soc 只有一颗芯片,因此不支持多芯编译
149
+ make -j
150
+ ```
151
+
152
+ 编译生成Baichuan2可执行程序。
153
+
154
+ 运行`baichuan2`:
155
+ ```shell
156
+ ./baichuan2 --model ../model/baichuan2-7b-test_int8.bmodel --dev dev_id
157
+ ```
158
+
159
+ ## 编译程序(Python Web版本)【单芯】
160
+
161
+ ```shell
162
+ pip3 install gradio==3.39.0
163
+ cd Baichuan2-TPU/web_demo
164
+ mkdir build
165
+ cd build
166
+ cmake ..
167
+ make -j
168
+ ```
169
+
170
+ 编译成功会在`build`文件夹下生成`libtpuchat.so*`, 此时可以在web_demo.py中指定bmodel\_path token\_path device\_id, lib_path(编译生产的`libtpuchat.so*`文件, 默认路径是`./build`下), 以及dev_id。
171
+ ```python
172
+ python3 web_demo.py
173
+ ```
174
+ 即可成功运行web的demo。
175
+ * PS:在用户不修改上述token\_path的lib\_path的存放路径前提下只需指定bmodel\_path即可运行程序。
176
+
177
+ 如果是SoC环境,参考C++版本
178
+
179
+ * PS:尽量下载gradio==3.39.0版本,不然会出现各种问题!!
180
+
181
+ # 常见问题
182
+ * 请根据实际block数目调整`demo/chat`中或者`web_demo/chat.cpp`中的NUM_LAYERS,默认是使用Baichuan2-7B(NUM_LAYERS=32)
Baichuan2/compile/compile.sh ADDED
@@ -0,0 +1,186 @@
1
+ #!/bin/bash
2
+ set -ex
3
+ models=
4
+ mode="f16"
5
+ folder="tmp"
6
+ num_device=1
7
+ mode_args=""
8
+ device_args=""
9
+ quantize_args="--quantize F16"
10
+ name=""
11
+ num_layers=
12
+ out_model=$name.bmodel
13
+
14
+ if [ -z "$name" ]; then
15
+ name="baichuan2-7b"
16
+ echo "Compile Baichuan2-7B"
17
+ else
18
+ name="baichuan2-13b"
19
+ echo "Compile Baichuan2-13B"
20
+ fi
21
+
22
+ while [[ $# -gt 0 ]]; do
23
+ key="$1"
24
+
25
+ case $key in
26
+ --mode)
27
+ mode="$2"
28
+ shift 2
29
+ ;;
30
+ --num_device)
31
+ num_device="$2"
32
+ shift 2
33
+ ;;
34
+ --name)
35
+ name="$2"
36
+ shift 2
37
+ ;;
38
+ *)
39
+ echo "Invalid option: $key" >&2
40
+ exit 1
41
+ ;;
42
+ :)
43
+ echo "Option -$OPTARG requires an argument." >&2
44
+ exit 1
45
+ ;;
46
+ esac
47
+ done
48
+
49
+ if [ x$mode == x"int8" ] || [ x$mode == x"int4" ]; then
50
+ if [ x$mode == x"int8" ]; then
51
+ quantize_args="--quantize W8F16"
52
+ else
53
+ quantize_args="--quantize W4BF16 --q_group_size 64"
54
+ fi
55
+ out_model=$name'_'$mode'.bmodel'
56
+ fi
57
+
58
+ if [ x$name == x"baichuan2-7b" ] || [ x$name == x"baichuan2-13b" ]; then
59
+ if [ x$name == x"baichuan2-7b" ]; then
60
+ num_layers=32
61
+ else
62
+ num_layers=40
63
+ fi
64
+ fi
65
+
66
+ if [ x$num_device != x1 ]; then
67
+ device_args="--num_device $num_device"
68
+ out_model=$name'_'$mode'_'$num_device'dev.bmodel'
69
+ else
70
+ out_model=$name'_'$mode'_1dev.bmodel'
71
+ fi
72
+
73
+ outdir=${folder}/embedding
74
+ mkdir -p $outdir
75
+ pushd $outdir
76
+
77
+ seqlen=512
78
+ model_transform.py \
79
+ --model_name embedding \
80
+ --model_def ../embedding.onnx \
81
+ --input_shapes [[1,$seqlen]] \
82
+ --mlir embedding_${seqlen}.mlir
83
+
84
+
85
+ model_deploy.py \
86
+ --mlir embedding_$seqlen.mlir \
87
+ --quantize F16 \
88
+ --chip bm1684x \
89
+ $device_args \
90
+ --model embedding_${seqlen}_f16.bmodel
91
+
92
+ model_transform.py \
93
+ --model_name embedding_cache \
94
+ --model_def ../embedding.onnx \
95
+ --input_shapes [[1,1]] \
96
+ --mlir embedding_1.mlir
97
+
98
+
99
+ model_deploy.py \
100
+ --mlir embedding_1.mlir \
101
+ --quantize F16 \
102
+ --chip bm1684x \
103
+ $device_args \
104
+ --model embedding_1_f16.bmodel
105
+
106
+ rm *.npz
107
+
108
+ models=$models' '$outdir'/embedding_1_f16.bmodel '$outdir'/embedding_'$seqlen'_f16.bmodel '
109
+
110
+ popd
111
+
112
+ echo $models
113
+
114
+ outdir=${folder}/$mode"_"$num_device"dev"/lm_head
115
+ mkdir -p $outdir
116
+ pushd $outdir
117
+
118
+ model_transform.py \
119
+ --model_name lm_head \
120
+ --model_def ../../lm_head.onnx \
121
+ --mlir lm_head.mlir
122
+
123
+
124
+ model_deploy.py \
125
+ --mlir lm_head.mlir \
126
+ --quantize F16 \
127
+ --chip bm1684x \
128
+ --model lm_head.bmodel
129
+
130
+ rm *.npz
131
+
132
+ models=${models}${outdir}'/lm_head.bmodel '
133
+ popd
134
+
135
+ echo $models
136
+
137
+ outdir=${folder}/$mode"_"$num_device"dev"/block
138
+ mkdir -p $outdir
139
+
140
+ pushd $outdir
141
+ mkdir -p $outdir
142
+
143
+ for ((i=0; i<$num_layers; i++))
144
+ do
145
+
146
+ model_transform.py \
147
+ --model_name block_$i \
148
+ --model_def ../../block_$i.onnx \
149
+ --mlir block_$i.mlir
150
+
151
+ model_deploy.py \
152
+ --mlir block_$i.mlir \
153
+ $quantize_args \
154
+ --chip bm1684x \
155
+ --quant_output \
156
+ --quant_output_list 2,3 \
157
+ $device_args \
158
+ --model block_$i.bmodel
159
+
160
+ model_transform.py \
161
+ --model_name block_cache_$i \
162
+ --model_def ../../block_cache_${i}.onnx \
163
+ --mlir block_cache_$i.mlir
164
+
165
+ model_deploy.py \
166
+ --mlir block_cache_$i.mlir \
167
+ $quantize_args \
168
+ --chip bm1684x \
169
+ --quant_input \
170
+ --quant_output \
171
+ --quant_input_list 4,5 \
172
+ --quant_output_list 2,3 \
173
+ $device_args \
174
+ --model block_cache_$i.bmodel
175
+
176
+ rm *.npz
177
+ # rm ../../block_$i.onnx
178
+ # rm ../../block_cache_$i.onnx
179
+
180
+ models=${models}${outdir}'/block_'$i'.bmodel '$outdir'/block_cache_'$i'.bmodel '
181
+
182
+ done
183
+ popd
184
+ echo $models
185
+
186
+ model_tool --combine $models -o $out_model
Baichuan2/compile/export_onnx.py ADDED
@@ -0,0 +1,182 @@
1
+ #!/usr/bin/env python3
2
+ # ==============================================================================
3
+ #
4
+ # Copyright (C) 2023 Sophgo Technologies Inc. All rights reserved.
5
+ #
6
+ # TPU-MLIR is licensed under the 2-Clause BSD License except for the
7
+ # third-party components.
8
+ #
9
+ # ==============================================================================
10
+
11
+ import os
12
+ import torch
13
+ from transformers import AutoModelForCausalLM, AutoTokenizer
14
+ from transformers.generation.utils import GenerationConfig
15
+ import numpy as np
16
+ import argparse
17
+
18
+ folder = f"./tmp/onnx"
19
+ parser = argparse.ArgumentParser(description='export onnx.')
20
+ parser.add_argument('--model_path', type=str, help='path to the torch model.')
21
+ parser.add_argument('--seq_length', type=int, default=512, help="sequence length")
22
+
23
+ args = parser.parse_args()
24
+
25
+ model_path = args.model_path
26
+ folder = "./tmp"
27
+
28
+ origin_model = AutoModelForCausalLM.from_pretrained(model_path, trust_remote_code=True).eval()
29
+ origin_model.generation_config = GenerationConfig.from_pretrained(model_path)
30
+ config = origin_model.config
31
+ transformer = origin_model.model
32
+ layers = transformer.layers
33
+
34
+ SEQ_LENGTH = args.seq_length
35
+ NUM_LAYERS = config.num_hidden_layers
36
+ HIDDEN_SIZE = config.hidden_size
37
+ NUM_ATTENTION_HEADS = config.num_attention_heads
38
+ HEAD_DIM = HIDDEN_SIZE // NUM_ATTENTION_HEADS
39
+
40
+ for param in origin_model.parameters():
41
+ param.requires_grad = False
42
+
43
+ tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False, trust_remote_code=True)
44
+
45
+ class Embedding(torch.nn.Module):
46
+ def __init__(self):
47
+ super().__init__()
48
+
49
+ def forward(self, input_ids):
50
+ return transformer.embed_tokens(input_ids)
51
+
52
+
53
+ class Block(torch.nn.Module):
54
+
55
+ def __init__(self, layer_id):
56
+ super().__init__()
57
+ # params
58
+ self.layer_id = layer_id
59
+ self.layer = layers[layer_id]
60
+
61
+ def forward(self, hidden_states, position_ids, attention_mask):
62
+ hidden_states, past_kv = self.layer(hidden_states,
63
+ attention_mask,
64
+ position_ids,
65
+ use_cache=True)
66
+ present_k, present_v = past_kv
67
+ return hidden_states, present_k, present_v
68
+
69
+
70
+ class BlockCache(torch.nn.Module):
71
+
72
+ def __init__(self, layer_id):
73
+ super().__init__()
74
+ # params
75
+ self.layer_id = layer_id
76
+ self.layer = layers[layer_id]
77
+
78
+ def forward(self, hidden_states, position_ids, attention_mask, past_k,
79
+ past_v):
80
+ hidden_states, past_kv = self.layer(hidden_states,
81
+ attention_mask,
82
+ position_ids=position_ids,
83
+ past_key_value=(past_k, past_v),
84
+ use_cache=True)
85
+ present_k, present_v = past_kv
86
+ return hidden_states, present_k, present_v
87
+
88
+
89
+ class LmHead(torch.nn.Module):
90
+
91
+ def __init__(self):
92
+ super().__init__()
93
+
94
+ def forward(self, hidden_states):
95
+ hidden_states = transformer.norm(hidden_states)
96
+ m_logits = origin_model.lm_head(hidden_states)
97
+ _, token = torch.topk(m_logits, 1)
98
+ return token
99
+
100
+
101
+ def convert_block(layer_id):
102
+ # input
103
+ hidden_states = torch.randn((1, SEQ_LENGTH, HIDDEN_SIZE))
104
+ position_ids = torch.tensor([range(SEQ_LENGTH)], dtype=torch.long)
105
+ attention_mask = torch.randn((1, 1, SEQ_LENGTH, SEQ_LENGTH))
106
+ model = Block(layer_id)
107
+
108
+ torch.onnx.export(
109
+ model, (hidden_states, position_ids, attention_mask),
110
+ f'{folder}/block_{layer_id}.onnx',
111
+ verbose=False,
112
+ input_names=['input_states', 'position_ids', 'attention_mask'],
113
+ output_names=['hidden_states', 'past_k', 'past_v'],
114
+ do_constant_folding=True,
115
+ opset_version=15)
116
+
117
+
118
+ def convert_block_cache(layer_id):
119
+ # input
120
+ np.random.seed(42)
121
+ hidden_states = torch.randn((1, 1, HIDDEN_SIZE))
122
+ position_ids = torch.tensor([range(1)], dtype=torch.long)
123
+ attention_mask = torch.randn((1, 1, 1, SEQ_LENGTH + 1))
124
+ past_k = torch.randn((1, SEQ_LENGTH, NUM_ATTENTION_HEADS, HEAD_DIM))
125
+ past_v = torch.randn((1, SEQ_LENGTH, NUM_ATTENTION_HEADS, HEAD_DIM))
126
+ model = BlockCache(layer_id)
127
+
128
+ torch.onnx.export(
129
+ model, (hidden_states, position_ids, attention_mask, past_k, past_v),
130
+ f'{folder}/block_cache_{layer_id}.onnx',
131
+ verbose=False,
132
+ input_names=[
133
+ 'input_states', 'position_ids', 'attention_mask', 'history_k',
134
+ 'history_v'
135
+ ],
136
+ output_names=['hidden_states', 'past_k', 'past_v'],
137
+ do_constant_folding=True,
138
+ opset_version=15)
139
+
140
+
141
+ def convert_embedding():
142
+ model = Embedding()
143
+ input = torch.tensor([range(SEQ_LENGTH)])
144
+ torch.onnx.export(model, (input),
145
+ f'{folder}/embedding.onnx',
146
+ verbose=False,
147
+ input_names=['input_ids'],
148
+ output_names=['input_embed'],
149
+ dynamic_axes={"input_ids": {
150
+ 0: "length"
151
+ }},
152
+ do_constant_folding=True,
153
+ opset_version=15)
154
+
155
+
156
+ def convert_lm_head():
157
+ model = LmHead()
158
+ input = torch.randn(1, HIDDEN_SIZE)
159
+ torch.onnx.export(model, (input),
160
+ f'{folder}/lm_head.onnx',
161
+ verbose=False,
162
+ input_names=['hidden_states'],
163
+ output_names=['token'],
164
+ do_constant_folding=True,
165
+ opset_version=15)
166
+
167
+ # create folder to store onnx
168
+ if not os.path.exists(folder):
169
+ os.makedirs(folder)
170
+
171
+ # export models
172
+ for i in range(NUM_LAYERS):
173
+ print("convert_block_{}".format(i))
174
+ convert_block_cache(i)
175
+ convert_block(i)
176
+
177
+ print("convert_embedding")
178
+ convert_embedding()
179
+
180
+ print("convert_lm_head")
181
+ convert_lm_head()
182
+
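The exported heads can be sanity-checked the same way with onnxruntime (a sketch, assuming onnxruntime is installed and the script above has been run with its defaults):

```python
import numpy as np
import onnxruntime as ort

HIDDEN_SIZE = 4096  # Baichuan2-7B

# lm_head.onnx maps a single hidden state to the top-1 token id.
lm_head = ort.InferenceSession("./tmp/lm_head.onnx")
hidden = np.random.randn(1, HIDDEN_SIZE).astype(np.float32)
token = lm_head.run(None, {"hidden_states": hidden})[0]
print(token.shape, token.dtype)  # a single token index
```
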
Baichuan2/compile/files/Baichuan2-7B/config.json ADDED
@@ -0,0 +1,29 @@
1
+ {
2
+ "architectures": [
3
+ "BaichuanForCausalLM"
4
+ ],
5
+ "auto_map": {
6
+ "AutoConfig": "configuration_baichuan.BaichuanConfig",
7
+ "AutoModelForCausalLM": "modeling_baichuan.BaichuanForCausalLM"
8
+ },
9
+ "tokenizer_class": "BaichuanTokenizer",
10
+ "bos_token_id": 1,
11
+ "eos_token_id": 2,
12
+ "hidden_act": "silu",
13
+ "hidden_size": 4096,
14
+ "initializer_range": 0.02,
15
+ "intermediate_size": 11008,
16
+ "max_position_embeddings": 4096,
17
+ "model_max_length": 4096,
18
+ "model_type": "baichuan",
19
+ "num_attention_heads": 32,
20
+ "num_hidden_layers": 32,
21
+ "pad_token_id": 0,
22
+ "rms_norm_eps": 1e-06,
23
+ "_from_model_config": true,
24
+ "tie_word_embeddings": false,
25
+ "torch_dtype": "bfloat16",
26
+ "transformers_version": "4.29.2",
27
+ "use_cache": true,
28
+ "vocab_size": 125696
29
+ }
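README step 5 changes max_position_embeddings and model_max_length from 4096 to 512 before export, while this backed-up copy still lists 4096, so the values may need to be edited after copying. A minimal sketch of patching those two fields (the path below is an example; point it at your own $BAICHUAN2_PATH):

```python
import json

# Patch the downloaded Baichuan2-7B-Chat config for the 512-token TPU export.
config_path = "Baichuan2-7B-Chat/config.json"  # example path
with open(config_path) as f:
    cfg = json.load(f)
cfg["max_position_embeddings"] = 512
cfg["model_max_length"] = 512
with open(config_path, "w") as f:
    json.dump(cfg, f, indent=2, ensure_ascii=False)
```
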
Baichuan2/compile/files/Baichuan2-7B/modeling_baichuan.py ADDED
@@ -0,0 +1,792 @@
1
+ # Copyright 2023 Baichuan Inc. All Rights Reserved.
2
+
3
+ # Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
4
+ #
5
+ # This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
6
+ # and OPT implementations in this library. It has been modified from its
7
+ # original forms to accommodate minor architectural differences compared
8
+ # to GPT-NeoX and OPT used by the Meta AI team that trained the model.
9
+ #
10
+ # Licensed under the Apache License, Version 2.0 (the "License");
11
+ # you may not use this file except in compliance with the License.
12
+ # You may obtain a copy of the License at
13
+ #
14
+ # http://www.apache.org/licenses/LICENSE-2.0
15
+ #
16
+ # Unless required by applicable law or agreed to in writing, software
17
+ # distributed under the License is distributed on an "AS IS" BASIS,
18
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
19
+ # See the License for the specific language governing permissions and
20
+ # limitations under the License.
21
+
22
+
23
+ from .configuration_baichuan import BaichuanConfig
24
+ from .generation_utils import build_chat_input, TextIterStreamer
25
+
26
+ import math
27
+ from typing import List, Optional, Tuple, Union
28
+ from threading import Thread
29
+
30
+ import torch
31
+ import torch.utils.checkpoint
32
+ from torch import nn
33
+ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
34
+ from torch.nn import functional as F
35
+ from transformers import PreTrainedModel, PretrainedConfig
36
+ from transformers.activations import ACT2FN
37
+ from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
38
+ from transformers.generation.utils import GenerationConfig
39
+ from transformers.utils import logging, ContextManagers
40
+
41
+ import os
42
+ from contextlib import contextmanager
43
+ logger = logging.get_logger(__name__)
44
+
45
+ try:
46
+ from xformers import ops as xops
47
+ except ImportError:
48
+ xops = None
49
+ logger.warning(
50
+ "Xformers is not installed correctly. If you want to use memory_efficient_attention to accelerate training use the following command to install Xformers\npip install xformers."
51
+ )
52
+
53
+
54
+ # Copied from transformers.models.bart.modeling_bart._make_causal_mask
55
+ def _make_causal_mask(
56
+ input_ids_shape: torch.Size, dtype: torch.dtype, device: torch.device, past_key_values_length: int = 0
57
+ ):
58
+ """
59
+ Make causal mask used for bi-directional self-attention.
60
+ """
61
+ bsz, tgt_len = input_ids_shape
62
+ mask = torch.full((tgt_len, tgt_len), torch.tensor(torch.finfo(dtype).min, device=device), device=device)
63
+ mask_cond = torch.arange(mask.size(-1), device=device)
64
+ mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
65
+ mask = mask.to(dtype)
66
+
67
+ if past_key_values_length > 0:
68
+ mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device), mask], dim=-1)
69
+ return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length)
70
+
71
+ def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
72
+ """
73
+ Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
74
+ """
75
+ if len(mask.size()) == 3:
76
+ bsz, src_len, _ = mask.size()
77
+ tgt_len = tgt_len if tgt_len is not None else src_len
78
+ expanded_mask = mask[:,None,:,:].expand(bsz, 1, tgt_len, src_len).to(dtype)
79
+ else:
80
+ bsz, src_len = mask.size()
81
+ tgt_len = tgt_len if tgt_len is not None else src_len
82
+ expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)
83
+
84
+ inverted_mask = 1.0 - expanded_mask
85
+
86
+ return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)
87
+
88
+
89
+ class RMSNorm(nn.Module):
90
+ def __init__(self, hidden_size, eps=1e-6):
91
+ """
92
+ RMSNorm is equivalent to T5LayerNorm
93
+ """
94
+ super().__init__()
95
+ self.weight = nn.Parameter(torch.ones(hidden_size))
96
+ self.variance_epsilon = eps
97
+
98
+ def forward(self, hidden_states):
99
+ variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
100
+ hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
101
+
102
+ # convert into half-precision if necessary
103
+ if self.weight.dtype in [torch.float16, torch.bfloat16]:
104
+ hidden_states = hidden_states.to(self.weight.dtype)
105
+
106
+ return self.weight * hidden_states
107
+
108
+
109
+ class RotaryEmbedding(torch.nn.Module):
110
+ def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
111
+ super().__init__()
112
+ self.inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float().to(device) / dim))
113
+ self.max_seq_len_cached = max_position_embeddings
114
+ t = torch.arange(self.max_seq_len_cached, device=self.inv_freq.device, dtype=torch.float32)
115
+ freqs = torch.outer(t, self.inv_freq)
116
+ emb = torch.cat((freqs, freqs), dim=-1)
117
+ self.cos_cached = emb.cos()[None, None, :, :].to(torch.float32)
118
+ self.sin_cached = emb.sin()[None, None, :, :].to(torch.float32)
119
+ def forward(self, x, seq_len=None):
120
+ # x: [bs, num_attention_heads, seq_len, head_size]
121
+ # This `if` block is unlikely to be run after we build sin/cos in `__init__`. Keep the logic here just in case.
122
+ if seq_len > self.max_seq_len_cached:
123
+ self.max_seq_len_cached = seq_len
124
+ t = torch.arange(self.max_seq_len_cached, device=self.inv_freq.device, dtype=torch.float32)
125
+ freqs = torch.outer(t, self.inv_freq)
126
+ emb = torch.cat((freqs, freqs), dim=-1)
127
+ self.cos_cached = emb.cos()[None, None, :, :].to(torch.float32).to(x.device)
128
+ self.sin_cached = emb.sin()[None, None, :, :].to(torch.float32).to(x.device)
129
+ elif self.cos_cached.device != x.device:
130
+ self.cos_cached = self.cos_cached.to(x.device)
131
+ self.sin_cached = self.sin_cached.to(x.device)
132
+ return (
133
+ self.cos_cached[:, :, :seq_len, ...],
134
+ self.sin_cached[:, :, :seq_len, ...],
135
+ )
136
+
137
+
138
+ def rotate_half(x):
139
+ """Rotates half the hidden dims of the input."""
140
+ x1 = x[..., : x.shape[-1] // 2]
141
+ x2 = x[..., x.shape[-1] // 2:]
142
+ return torch.cat((-x2, x1), dim=-1)
143
+
144
+
145
+ def apply_rotary_pos_emb(q, k, cos_, sin_, position_ids):
146
+ cos = cos_.squeeze(1).squeeze(0) # [seq_len, dim]
147
+ sin = sin_.squeeze(1).squeeze(0) # [seq_len, dim]
148
+ cos = cos[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim]
149
+ sin = sin[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim]
150
+ cos = cos.transpose(1, 2)
151
+ sin = sin.transpose(1, 2)
152
+ q_embed = (q.float() * cos) + (rotate_half(q.float()) * sin)
153
+ k_embed = (k.float() * cos) + (rotate_half(k.float()) * sin)
154
+ return q_embed.to(q.dtype), k_embed.to(k.dtype)
155
+
156
+
157
+ class MLP(nn.Module):
158
+ def __init__(
159
+ self,
160
+ hidden_size: int,
161
+ intermediate_size: int,
162
+ hidden_act: str,
163
+ ):
164
+ super().__init__()
165
+ self.gate_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
166
+ self.down_proj = nn.Linear(intermediate_size, hidden_size, bias=False)
167
+ self.up_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
168
+ self.act_fn = ACT2FN[hidden_act]
169
+
170
+ def forward(self, x):
171
+ return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
172
+
173
+
174
+ class Attention(nn.Module):
175
+ """Multi-headed attention from 'Attention Is All You Need' paper"""
176
+ def __init__(self, config: BaichuanConfig):
177
+ super().__init__()
178
+ self.config = config
179
+ self.hidden_size = config.hidden_size
180
+ self.num_heads = config.num_attention_heads
181
+ self.head_dim = self.hidden_size // self.num_heads
182
+ self.max_position_embeddings = config.max_position_embeddings
183
+
184
+ if (self.head_dim * self.num_heads) != self.hidden_size:
185
+ raise ValueError(
186
+ f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
187
+ f" and `num_heads`: {self.num_heads})."
188
+ )
189
+ self.W_pack = nn.Linear(self.hidden_size, 3 * self.hidden_size, bias=False)
190
+ self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)
191
+ self.rotary_emb = RotaryEmbedding(self.head_dim, max_position_embeddings=self.max_position_embeddings)
192
+
193
+ def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
194
+ return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
195
+
196
+ def forward(
197
+ self,
198
+ hidden_states: torch.Tensor,
199
+ attention_mask: Optional[torch.Tensor] = None,
200
+ position_ids: Optional[torch.LongTensor] = None,
201
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
202
+ output_attentions: bool = False,
203
+ use_cache: bool = False,
204
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
205
+ bsz, q_len, _ = hidden_states.size()
206
+
207
+ proj = self.W_pack(hidden_states)
208
+ proj = proj.unflatten(-1, (3, self.hidden_size)).unsqueeze(0).transpose(0, -2).squeeze(-2)
209
+ query_states = proj[0].view(bsz, q_len, self.num_heads, self.head_dim)
210
+ key_states = proj[1].view(bsz, q_len, self.num_heads, self.head_dim)
211
+ value_states = proj[2].view(bsz, q_len, self.num_heads, self.head_dim)
212
+
213
+ kv_seq_len = key_states.shape[-3]
214
+ if past_key_value is not None:
215
+ kv_seq_len = kv_seq_len + past_key_value[0].shape[-3]
216
+ if past_key_value is not None:
217
+ cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len-1)
218
+ else:
219
+ cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
220
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
221
+ # [bsz, nh, t, hd]
222
+ past_kv = (key_states, value_states) if use_cache else None
223
+ if past_key_value is not None:
224
+ # reuse k, v, self_attention
225
+ key_states = torch.cat([past_key_value[0], key_states], dim=1)
226
+ value_states = torch.cat([past_key_value[1], value_states], dim=1)
227
+
228
+
229
+ if xops is not None and self.training:
230
+ attn_weights = None
231
+ query_states = query_states.transpose(1, 2)
232
+ key_states = key_states.transpose(1, 2)
233
+ value_states = value_states.transpose(1, 2)
234
+ attn_output = xops.memory_efficient_attention(
235
+ query_states, key_states, value_states, attn_bias=xops.LowerTriangularMask()
236
+ )
237
+ else:
238
+ with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=True, enable_mem_efficient=True):
239
+ query_states = query_states.transpose(1, 2)
240
+ key_states = key_states.transpose(1, 2)
241
+ value_states = value_states.transpose(1, 2)
242
+ attn_output = F.scaled_dot_product_attention(query_states, key_states, value_states, attn_mask = attention_mask)
243
+ attn_output = attn_output.transpose(1, 2)
244
+ attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
245
+ attn_output = self.o_proj(attn_output)
246
+
247
+ if not output_attentions:
248
+ attn_weights = None
249
+ return attn_output, attn_weights, past_kv
250
+
251
+
252
+ class DecoderLayer(nn.Module):
253
+ def __init__(self, config: BaichuanConfig):
254
+ super().__init__()
255
+ self.hidden_size = config.hidden_size
256
+ self.self_attn = Attention(config=config)
257
+ self.mlp = MLP(
258
+ hidden_size=self.hidden_size,
259
+ intermediate_size=config.intermediate_size,
260
+ hidden_act=config.hidden_act,
261
+ )
262
+ self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
263
+ self.post_attention_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
264
+
265
+ def forward(
266
+ self,
267
+ hidden_states: torch.Tensor,
268
+ attention_mask: Optional[torch.Tensor] = None,
269
+ position_ids: Optional[torch.LongTensor] = None,
270
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
271
+ output_attentions: Optional[bool] = False,
272
+ use_cache: Optional[bool] = False,
273
+ ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
274
+
275
+ residual = hidden_states
276
+
277
+ hidden_states = self.input_layernorm(hidden_states)
278
+
279
+ # Self Attention
280
+ hidden_states, self_attn_weights, present_key_value = self.self_attn(
281
+ hidden_states=hidden_states,
282
+ attention_mask=attention_mask,
283
+ position_ids=position_ids,
284
+ past_key_value=past_key_value,
285
+ output_attentions=output_attentions,
286
+ use_cache=use_cache,
287
+ )
288
+ hidden_states = residual + hidden_states
289
+
290
+ # Fully Connected
291
+ residual = hidden_states
292
+ hidden_states = self.post_attention_layernorm(hidden_states)
293
+ hidden_states = self.mlp(hidden_states)
294
+ hidden_states = residual + hidden_states
295
+
296
+ outputs = (hidden_states,)
297
+
298
+ if output_attentions:
299
+ outputs += (self_attn_weights,)
300
+
301
+ if use_cache:
302
+ outputs += (present_key_value,)
303
+
304
+ return outputs
305
+
306
+
307
+ class BaichuanPreTrainedModel(PreTrainedModel):
308
+ config_class = BaichuanConfig
309
+ base_model_prefix = "model"
310
+ supports_gradient_checkpointing = True
311
+ _no_split_modules = ["DecoderLayer"]
312
+ _keys_to_ignore_on_load_unexpected = [r"decoder\.version"]
313
+
314
+ def _init_weights(self, module):
315
+ std = self.config.initializer_range
316
+ if isinstance(module, nn.Linear):
317
+ module.weight.data.normal_(mean=0.0, std=std)
318
+ if module.bias is not None:
319
+ module.bias.data.zero_()
320
+ elif isinstance(module, nn.Embedding):
321
+ module.weight.data.normal_(mean=0.0, std=std)
322
+ if module.padding_idx is not None:
323
+ module.weight.data[module.padding_idx].zero_()
324
+
325
+ def _set_gradient_checkpointing(self, module, value=False):
326
+ if isinstance(module, BaichuanModel):
327
+ module.gradient_checkpointing = value
328
+
329
+
330
+ class BaichuanModel(BaichuanPreTrainedModel):
331
+ def __init__(self, config: BaichuanConfig):
332
+ super().__init__(config)
333
+ self.padding_idx = config.pad_token_id
334
+ self.vocab_size = config.vocab_size
335
+
336
+ self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
337
+ self.layers = nn.ModuleList([DecoderLayer(config) for _ in range(config.num_hidden_layers)])
338
+ self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
339
+
340
+ self.gradient_checkpointing = False
341
+ # Initialize weights and apply final processing
342
+ self.post_init()
343
+
344
+ def get_input_embeddings(self):
345
+ return self.embed_tokens
346
+
347
+ def set_input_embeddings(self, value):
348
+ self.embed_tokens = value
349
+
350
+ # Copied from transformers.models.bart.modeling_bart.BartDecoder._prepare_decoder_attention_mask
351
+ def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_embeds, past_key_values_length):
352
+ # create causal mask
353
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
354
+ combined_attention_mask = None
355
+ if input_shape[-1] > 1:
356
+ combined_attention_mask = _make_causal_mask(
357
+ input_shape,
358
+ inputs_embeds.dtype,
359
+ device=inputs_embeds.device,
360
+ past_key_values_length=past_key_values_length,
361
+ )
362
+
363
+ if attention_mask is not None:
364
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
365
+ expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]).to(
366
+ inputs_embeds.device
367
+ )
368
+ combined_attention_mask = (
369
+ expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask
370
+ )
371
+
372
+ return combined_attention_mask
373
+
374
+ def forward(
375
+ self,
376
+ input_ids: torch.LongTensor = None,
377
+ attention_mask: Optional[torch.Tensor] = None,
378
+ position_ids: Optional[torch.LongTensor] = None,
379
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
380
+ inputs_embeds: Optional[torch.FloatTensor] = None,
381
+ use_cache: Optional[bool] = None,
382
+ output_attentions: Optional[bool] = None,
383
+ output_hidden_states: Optional[bool] = None,
384
+ return_dict: Optional[bool] = None,
385
+ ) -> Union[Tuple, BaseModelOutputWithPast]:
386
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
387
+ output_hidden_states = (
388
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
389
+ )
390
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
391
+
392
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
393
+
394
+ # retrieve input_ids and inputs_embeds
395
+ if input_ids is not None and inputs_embeds is not None:
396
+ raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
397
+ elif input_ids is not None:
398
+ batch_size, seq_length = input_ids.shape
399
+ elif inputs_embeds is not None:
400
+ batch_size, seq_length, _ = inputs_embeds.shape
401
+ else:
402
+ raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")
403
+
404
+ seq_length_with_past = seq_length
405
+ past_key_values_length = 0
406
+
407
+ if past_key_values is not None:
408
+ past_key_values_length = past_key_values[0][0].shape[2]
409
+ seq_length_with_past = seq_length_with_past + past_key_values_length
410
+
411
+ if position_ids is None:
412
+ device = input_ids.device if input_ids is not None else inputs_embeds.device
413
+ position_ids = torch.arange(
414
+ past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
415
+ )
416
+ position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
417
+ else:
418
+ position_ids = position_ids.view(-1, seq_length).long()
419
+
420
+ if inputs_embeds is None:
421
+ inputs_embeds = self.embed_tokens(input_ids)
422
+ # embed positions
423
+ if attention_mask is None:
424
+ attention_mask = torch.ones(
425
+ (batch_size, seq_length_with_past), dtype=torch.bool, device=inputs_embeds.device
426
+ )
427
+ attention_mask = self._prepare_decoder_attention_mask(
428
+ attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length
429
+ )
430
+
431
+ hidden_states = inputs_embeds
432
+
433
+ if self.gradient_checkpointing and self.training:
434
+ if use_cache:
435
+ logger.warning_once(
436
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
437
+ )
438
+ use_cache = False
439
+
440
+ # decoder layers
441
+ all_hidden_states = () if output_hidden_states else None
442
+ all_self_attns = () if output_attentions else None
443
+ next_decoder_cache = () if use_cache else None
444
+
445
+ for idx, decoder_layer in enumerate(self.layers):
446
+ if output_hidden_states:
447
+ all_hidden_states += (hidden_states,)
448
+
449
+ past_key_value = past_key_values[idx] if past_key_values is not None else None
450
+
451
+ if self.gradient_checkpointing and self.training:
452
+
453
+ def create_custom_forward(module):
454
+ def custom_forward(*inputs):
455
+ # None for past_key_value
456
+ return module(*inputs, output_attentions, None)
457
+
458
+ return custom_forward
459
+
460
+ layer_outputs = torch.utils.checkpoint.checkpoint(
461
+ create_custom_forward(decoder_layer),
462
+ hidden_states,
463
+ attention_mask,
464
+ position_ids,
465
+ None,
466
+ )
467
+ else:
468
+ layer_outputs = decoder_layer(
469
+ hidden_states,
470
+ attention_mask=attention_mask,
471
+ position_ids=position_ids,
472
+ past_key_value=past_key_value,
473
+ output_attentions=output_attentions,
474
+ use_cache=use_cache,
475
+ )
476
+
477
+ hidden_states = layer_outputs[0]
478
+
479
+ if use_cache:
480
+ next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)
481
+
482
+ if output_attentions:
483
+ all_self_attns += (layer_outputs[1],)
484
+
485
+ hidden_states = self.norm(hidden_states)
486
+
487
+ # add hidden states from the last decoder layer
488
+ if output_hidden_states:
489
+ all_hidden_states += (hidden_states,)
490
+
491
+ next_cache = next_decoder_cache if use_cache else None
492
+ if not return_dict:
493
+ return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
494
+ return BaseModelOutputWithPast(
495
+ last_hidden_state=hidden_states,
496
+ past_key_values=next_cache,
497
+ hidden_states=all_hidden_states,
498
+ attentions=all_self_attns,
499
+ )
500
+
501
+
502
+ class NormHead(nn.Module):
503
+ def __init__(self, hidden_size, vocab_size, bias=False):
504
+ super().__init__()
505
+ self.weight = nn.Parameter(torch.empty((vocab_size, hidden_size)))
506
+ nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))
507
+ self.first_flag = True
508
+
509
+ def forward(self, hidden_states):
510
+ if self.training:
511
+ norm_weight = nn.functional.normalize(self.weight)
512
+ self.first_flag = True
513
+ elif self.first_flag:
514
+ self.first_flag = False
515
+ self.weight.data = nn.functional.normalize(self.weight)
516
+ norm_weight = self.weight
517
+ else:
518
+ norm_weight = self.weight
519
+ return nn.functional.linear(hidden_states, norm_weight)
520
+
521
+ _init_weights = True
522
+ @contextmanager
523
+ def no_init_weights(_enable=True):
524
+ global _init_weights
525
+ old_init_weights = _init_weights
526
+ if _enable:
527
+ _init_weights = False
528
+ try:
529
+ yield
530
+ finally:
531
+ _init_weights = old_init_weights
532
+
533
+ class BaichuanForCausalLM(BaichuanPreTrainedModel):
534
+ def __init__(self, config, *model_args, **model_kwargs):
535
+ super().__init__(config, *model_args, **model_kwargs)
536
+ self.model = BaichuanModel(config)
537
+
538
+ self.lm_head = NormHead(config.hidden_size, config.vocab_size, bias=False)
539
+ if hasattr(config, "quantization_config") and isinstance(config.quantization_config, dict) and config.quantization_config.get('load_in_4bit', False):
540
+ try:
541
+ from .quantizer import quantize_offline, init_model_weight_int4
542
+ except ImportError:
543
+ raise ImportError(f"Needs QLinear to run quantize.")
544
+ quantize_offline(self, 4)
545
+ # Initialize weights and apply final processing
546
+ self.post_init()
547
+
548
+ def get_input_embeddings(self):
549
+ return self.model.embed_tokens
550
+
551
+ def set_input_embeddings(self, value):
552
+ self.model.embed_tokens = value
553
+
554
+ def get_output_embeddings(self):
555
+ return self.lm_head
556
+
557
+ def set_output_embeddings(self, new_embeddings):
558
+ self.lm_head = new_embeddings
559
+
560
+ def set_decoder(self, decoder):
561
+ self.model = decoder
562
+
563
+ def get_decoder(self):
564
+ return self.model
565
+
566
+ @classmethod
567
+ def from_pretrained(
568
+ cls,
569
+ pretrained_model_name_or_path: Optional[Union[str, os.PathLike]],
570
+ *model_args,
571
+ config: Optional[Union[PretrainedConfig, str, os.PathLike]] = None,
572
+ cache_dir: Optional[Union[str, os.PathLike]] = None,
573
+ ignore_mismatched_sizes: bool = False,
574
+ force_download: bool = False,
575
+ local_files_only: bool = False,
576
+ token: Optional[Union[str, bool]] = None,
577
+ revision: str = "main",
578
+ use_safetensors: bool = None,
579
+ **kwargs,
580
+ ):
581
+ # Load config if we don't provide a configuration
582
+ if not isinstance(config, PretrainedConfig):
583
+ config_path = config if config is not None else pretrained_model_name_or_path
584
+ config, model_kwargs = cls.config_class.from_pretrained(
585
+ config_path,
586
+ cache_dir=cache_dir,
587
+ return_unused_kwargs=True,
588
+ force_download=force_download,
589
+ resume_download=False,
590
+ proxies=None,
591
+ local_files_only=local_files_only,
592
+ token=token,
593
+ revision=revision,
594
+ subfolder="",
595
+ _from_auto=False,
596
+ _from_pipeline=None,
597
+ **kwargs,
598
+ )
599
+ else:
600
+ model_kwargs = kwargs
601
+
602
+ if hasattr(config, "quantization_config") and config.quantization_config['load_in_4bit']:
603
+ try:
604
+ from .quantizer import init_model_weight_int4
605
+ from accelerate import init_empty_weights, dispatch_model, infer_auto_device_map
606
+ from accelerate.utils import CustomDtype
607
+ from accelerate.utils import get_balanced_memory
608
+ except ImportError:
609
+ raise ImportError(f"Needs import model weight init func to run quantize.")
610
+ # Instantiate model.
611
+ init_contexts = [no_init_weights(_enable=True)]
612
+ init_contexts.append(init_empty_weights())
613
+ with ContextManagers(init_contexts):
614
+ model = cls(config)
615
+
616
+ model_file = os.path.join(pretrained_model_name_or_path, 'pytorch_model.bin')
617
+ state_dict = torch.load(model_file, map_location="cpu")
618
+ model.is_quantized = True
619
+
620
+ device_map = kwargs.pop("device_map", None)
621
+ torch_dtype = kwargs.pop("torch_dtype", None)
622
+
623
+ if device_map is not None:
624
+ kwargs = {"no_split_module_classes": model._no_split_modules}
625
+ target_dtype = CustomDtype.INT4
626
+ max_memory = get_balanced_memory(
627
+ model,
628
+ dtype=target_dtype,
629
+ low_zero=(device_map == "balanced_low_0"),
630
+ max_memory=None,
631
+ **kwargs,
632
+ )
633
+ kwargs["max_memory"] = max_memory
634
+ device_map = infer_auto_device_map(model, dtype=target_dtype, **kwargs)
635
+
636
+ model = init_model_weight_int4(config, model, state_dict)
637
+
638
+ # Set model in evaluation mode to deactivate DropOut modules by default
639
+ model.eval()
640
+ # If it is a model with generation capabilities, attempt to load the generation config
641
+ if model.can_generate():
642
+ try:
643
+ model.generation_config = GenerationConfig.from_pretrained(
644
+ pretrained_model_name_or_path,
645
+ cache_dir=cache_dir,
646
+ force_download=force_download,
647
+ resume_download=False,
648
+ proxies=None,
649
+ local_files_only=local_files_only,
650
+ token=token,
651
+ revision=revision,
652
+ subfolder="",
653
+ _from_auto=False,
654
+ _from_pipeline=None,
655
+ **kwargs,
656
+ )
657
+ except (OSError, TypeError):
658
+ logger.info(
659
+ "Generation config file not found, using a generation config created from the model config."
660
+ )
661
+ pass
662
+
663
+ if device_map is not None:
664
+ dispatch_model(model, device_map=device_map)
665
+
666
+ return model
667
+ return super(BaichuanForCausalLM, cls).from_pretrained(pretrained_model_name_or_path, *model_args,
668
+ config=config, cache_dir=cache_dir, ignore_mismatched_sizes=ignore_mismatched_sizes,
669
+ force_download=force_download, local_files_only=local_files_only, token=token, revision=revision,
670
+ use_safetensors=use_safetensors, **kwargs)
671
+
672
+ def forward(
673
+ self,
674
+ input_ids: torch.LongTensor = None,
675
+ attention_mask: Optional[torch.Tensor] = None,
676
+ position_ids: Optional[torch.LongTensor] = None,
677
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
678
+ inputs_embeds: Optional[torch.FloatTensor] = None,
679
+ labels: Optional[torch.LongTensor] = None,
680
+ use_cache: Optional[bool] = None,
681
+ output_attentions: Optional[bool] = None,
682
+ output_hidden_states: Optional[bool] = None,
683
+ return_dict: Optional[bool] = None,
684
+ ) -> Union[Tuple, CausalLMOutputWithPast]:
685
+
686
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
687
+ output_hidden_states = (
688
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
689
+ )
690
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
691
+
692
+ # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
693
+ outputs = self.model(
694
+ input_ids=input_ids,
695
+ attention_mask=attention_mask,
696
+ position_ids=position_ids,
697
+ past_key_values=past_key_values,
698
+ inputs_embeds=inputs_embeds,
699
+ use_cache=use_cache,
700
+ output_attentions=output_attentions,
701
+ output_hidden_states=output_hidden_states,
702
+ return_dict=return_dict,
703
+ )
704
+
705
+ hidden_states = outputs[0]
706
+ logits = self.lm_head(hidden_states)
707
+ loss = None
708
+ if labels is not None:
709
+ # Shift so that tokens < n predict n
710
+ shift_logits = logits[..., :-1, :].contiguous()
711
+ shift_labels = labels[..., 1:].contiguous()
712
+ # Flatten the tokens
713
+ loss_fct = CrossEntropyLoss()
714
+ shift_logits = shift_logits.view(-1, self.config.vocab_size)
715
+ shift_labels = shift_labels.view(-1)
716
+ softmax_normalizer = shift_logits.max(-1).values ** 2
717
+ z_loss = self.config.z_loss_weight * softmax_normalizer.mean()
718
+ # Enable model parallelism
719
+ shift_labels = shift_labels.to(shift_logits.device)
720
+ loss = loss_fct(shift_logits, shift_labels) + z_loss
721
+
722
+ if not return_dict:
723
+ output = (logits,) + outputs[1:]
724
+ return (loss,) + output if loss is not None else output
725
+
726
+ return CausalLMOutputWithPast(
727
+ loss=loss,
728
+ logits=logits,
729
+ past_key_values=outputs.past_key_values,
730
+ hidden_states=outputs.hidden_states,
731
+ attentions=outputs.attentions,
732
+ )
733
+
734
+ def prepare_inputs_for_generation(
735
+ self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
736
+ ):
737
+ if past_key_values:
738
+ input_ids = input_ids[:, -1:]
739
+
740
+ position_ids = kwargs.get("position_ids", None)
741
+ if attention_mask is not None and position_ids is None:
742
+ # create position_ids on the fly for batch generation
743
+ position_ids = attention_mask.long().cumsum(-1) - 1
744
+ position_ids.masked_fill_(attention_mask == 0, 1)
745
+ if past_key_values:
746
+ position_ids = position_ids[:, -1].unsqueeze(-1)
747
+
748
+ # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
749
+ if inputs_embeds is not None and past_key_values is None:
750
+ model_inputs = {"inputs_embeds": inputs_embeds}
751
+ else:
752
+ model_inputs = {"input_ids": input_ids}
753
+
754
+ model_inputs.update(
755
+ {
756
+ "position_ids": position_ids,
757
+ "past_key_values": past_key_values,
758
+ "use_cache": kwargs.get("use_cache"),
759
+ "attention_mask": attention_mask,
760
+ }
761
+ )
762
+ return model_inputs
763
+
764
+ @staticmethod
765
+ def _reorder_cache(past_key_values, beam_idx):
766
+ reordered_past = ()
767
+ for layer_past in past_key_values:
768
+ reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),)
769
+ return reordered_past
770
+
771
+ def quantize(self, bits: int):
772
+ try:
773
+ from .quantizer import quantize_online
774
+ except ImportError:
775
+ raise ImportError(f"Needs QLinear to run quantize.")
776
+ return quantize_online(self, bits)
777
+
778
+ def chat(self, tokenizer, messages: List[dict], stream=False,
779
+ generation_config: Optional[GenerationConfig]=None):
780
+ generation_config = generation_config or self.generation_config
781
+ input_ids = build_chat_input(self, tokenizer, messages, generation_config.max_new_tokens)
782
+ if stream:
783
+ streamer = TextIterStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
784
+ Thread(target=self.generate, kwargs=dict(
785
+ inputs=input_ids, streamer=streamer,
786
+ generation_config=generation_config,
787
+ )).start()
788
+ return streamer
789
+ else:
790
+ outputs = self.generate(input_ids, generation_config=generation_config)
791
+ response = tokenizer.decode(outputs[0][len(input_ids[0]):], skip_special_tokens=True)
792
+ return response
Baichuan2/compile/torch_inference.py ADDED
@@ -0,0 +1,16 @@
1
+ import torch
2
+ from transformers import AutoModelForCausalLM, AutoTokenizer
3
+ from transformers.generation.utils import GenerationConfig
4
+ import argparse
5
+
6
+ parser = argparse.ArgumentParser()
7
+ parser.add_argument('model_path', help='absolute path to the downloaded model')
8
+ args = parser.parse_args()
9
+ model_path = args.model_path
10
+ tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False, trust_remote_code=True)
11
+ model = AutoModelForCausalLM.from_pretrained(model_path, device_map="auto", torch_dtype=torch.float32, trust_remote_code=True)
12
+ model.generation_config = GenerationConfig.from_pretrained(model_path)
13
+ messages = []
14
+ messages.append({"role": "user", "content": "解释一下“温故而知新”"})
15
+ response = model.chat(tokenizer, messages)
16
+ print(response)
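For reference, the script above exercises only the blocking call path; the `chat()` method added earlier also accepts `stream=True` and then returns a `TextIterStreamer`. A minimal streaming sketch, assuming the same `model` and `tokenizer` objects created above (this snippet is illustrative and not part of the uploaded file), could look like:

    # Hypothetical streaming variant, reusing the model/tokenizer objects created above.
    messages = [{"role": "user", "content": "解释一下“温故而知新”"}]
    streamer = model.chat(tokenizer, messages, stream=True)  # returns a TextIterStreamer
    printed = 0
    for text_so_far in streamer:      # each item is the decoded response generated so far
        print(text_so_far[printed:], end="", flush=True)
        printed = len(text_so_far)
    print()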
Baichuan2/demo/CMakeLists.txt ADDED
@@ -0,0 +1,38 @@
1
+ cmake_minimum_required(VERSION 2.8)
2
+ project(baichuan2)
3
+
4
+ if (NOT DEFINED TARGET_ARCH)
5
+ set(TARGET_ARCH pcie)
6
+ endif()
7
+
8
+ set(CMAKE_INSTALL_PREFIX install)
9
+
10
+ if (${CMAKE_HOST_SYSTEM_PROCESSOR} STREQUAL "aarch64")
11
+ add_definitions(-DSOC_TARGET)
12
+ link_directories(${PROJECT_SOURCE_DIR}/../src/lib_soc)
13
+ message("SoC mode, starting......")
14
+ elseif (${TARGET_ARCH} STREQUAL "pcie")
15
+ add_definitions(-DPCIE_TARGET)
16
+ link_directories(${PROJECT_SOURCE_DIR}/../src/lib_pcie)
17
+ message("Pcie mode, starting......")
18
+ elseif (${TARGET_ARCH} STREQUAL "soc")
19
+ add_definitions(-DSOC_TARGET)
20
+ set(CMAKE_C_COMPILER aarch64-linux-gnu-gcc)
21
+ set(CMAKE_ASM_COMPILER aarch64-linux-gnu-gcc)
22
+ set(CMAKE_CXX_COMPILER aarch64-linux-gnu-g++)
23
+ link_directories(${PROJECT_SOURCE_DIR}/../src/lib_soc)
24
+ message("SoC mode, starting......")
25
+ endif()
26
+
27
+
28
+
29
+
30
+ include_directories(${PROJECT_SOURCE_DIR}/../src/include)
31
+
32
+ add_definitions(-DDEBUG --std=c++17 -fPIC -Wall -Werror)
33
+ set(CMAKE_BUILD_TYPE "Debug")
34
+
35
+ set(CMAKE_RUNTIME_OUTPUT_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR})
36
+ add_executable(baichuan2 demo.cpp)
37
+ target_link_libraries(baichuan2 bmrt bmlib sentencepiece)
38
+
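As a rough usage note (the project's own build scripts remain authoritative): this CMakeLists is normally configured out-of-tree, for example with `cmake -DTARGET_ARCH=pcie <path-to-demo>` followed by `make`, or with `-DTARGET_ARCH=soc` to cross-compile via the aarch64-linux-gnu toolchain; the resulting `baichuan2` binary lands back in the demo directory because CMAKE_RUNTIME_OUTPUT_DIRECTORY points there.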
Baichuan2/demo/demo.cpp ADDED
@@ -0,0 +1,472 @@
1
+ //===----------------------------------------------------------------------===//
2
+ //
3
+ // Copyright (C) 2023 Sophgo Technologies Inc. All rights reserved.
4
+ //
5
+ // TPU-MLIR is licensed under the 2-Clause BSD License except for the
6
+ // third-party components.
7
+ //
8
+ //===----------------------------------------------------------------------===//
9
+
10
+ #include <iostream>
11
+ #include <cstdlib>
12
+ #include <vector>
13
+ #include <assert.h>
14
+ #include <chrono>
15
+ #include <algorithm>
16
+ #include "memory.h"
17
+ #include "sentencepiece/sentencepiece_processor.h"
18
+ #include "bmruntime_interface.h"
19
+ #include <getopt.h>
20
+ #include <numeric>
21
+
22
+ static const int NUM_LAYERS = 32;
23
+ static const int MAX_LEN = 512;
24
+ static const float ATTENTION_MASK = -1000.;
25
+
26
+ static const std::string TOKENIZER_MODEL = "../model/tokenizer.model";
27
+
28
+ // #define EXPORT_RESULTS
29
+ #ifdef EXPORT_RESULTS
30
+ #include "cnpy.h"
31
+ static cnpy::npz_t map;
32
+
33
+ template <typename T>
34
+ static void add_array(std::string name, bm_handle_t bm_handle,
35
+ const bm_device_mem_t &dst) {
36
+ std::vector<T> data(dst.size / sizeof(T));
37
+ bm_memcpy_d2s(bm_handle, data.data(), dst);
38
+ cnpy::npz_add_array(map, name, data);
39
+ }
40
+
41
+ static void save_array(std::string filename) {
42
+ cnpy::npz_save_all(filename, map);
43
+ }
44
+ #endif
45
+
46
+ class Baichuan2 {
47
+ public:
48
+ void init(const std::vector<int> &devid, std::string model);
49
+ void chat();
50
+ void deinit();
51
+
52
+ private:
53
+ void answer(const std::string &input_str);
54
+ int forward_first(std::vector<int> &tokens);
55
+ int forward_next();
56
+ void load_sentencepiece();
57
+
58
+ private:
59
+ std::vector<bm_handle_t> handles;
60
+ bm_handle_t bm_handle;
61
+ void *p_bmrt;
62
+ sentencepiece::SentencePieceProcessor sentencepiece;
63
+ const bm_net_info_t *net_blocks[NUM_LAYERS];
64
+ const bm_net_info_t *net_blocks_cache[NUM_LAYERS];
65
+ const bm_net_info_t *net_embed;
66
+ const bm_net_info_t *net_embed_cache;
67
+ const bm_net_info_t *net_lm;
68
+ bm_tensor_t inputs_embed_512, outputs_embed_512;
69
+ bm_tensor_t inputs_lm, outputs_lm;
70
+ bm_tensor_t inputs_pid, next_pid, inputs_attention, next_attention;
71
+ bm_tensor_t past_key[NUM_LAYERS], past_value[NUM_LAYERS];
72
+ bm_tensor_t present_key[NUM_LAYERS], present_value[NUM_LAYERS];
73
+ bm_tensor_t present_key_cache, present_value_cache;
74
+ std::string name_embed;
75
+ std::string name_embed_cache;
76
+ std::string name_lm;
77
+ std::string name_blocks[NUM_LAYERS];
78
+ std::string name_blocks_cache[NUM_LAYERS];
79
+ int round = 0;
80
+ int token_length;
81
+ int EOS;
82
+ std::vector<std::string> history;
83
+ };
84
+
85
+ void Baichuan2::load_sentencepiece() {
86
+ printf("Load %s ... ", TOKENIZER_MODEL.c_str());
87
+ auto status = sentencepiece.Load(TOKENIZER_MODEL);
88
+ if (!status.ok()) {
89
+ std::cout << status.ToString() << std::endl;
90
+ exit(-1);
91
+ }
92
+ EOS = sentencepiece.eos_id();
93
+ printf("Done!\n");
94
+ }
95
+
96
+ void Baichuan2::init(const std::vector<int> &devices, std::string model) {
97
+ load_sentencepiece();
98
+ // request bm_handle
99
+ std::cout << "Device [ ";
100
+ for (auto d : devices) {
101
+ std::cout << d << " ";
102
+ }
103
+ std::cout << "] loading ....\n";
104
+ // int device_num = devices.size();
105
+ for (auto d : devices) {
106
+ bm_handle_t h;
107
+ bm_status_t status = bm_dev_request(&h, d);
108
+ assert(BM_SUCCESS == status);
109
+ handles.push_back(h);
110
+ }
111
+ bm_handle = handles[0];
112
+ // create bmruntime
113
+ p_bmrt = bmrt_create(bm_handle);
114
+ assert(NULL != p_bmrt);
115
+
116
+ // load bmodel by file
117
+ printf("Model[%s] loading ....\n", model.c_str());
118
+ bool ret = bmrt_load_bmodel(p_bmrt, model.c_str());
119
+ assert(true == ret);
120
+ printf("Done!\n");
121
+ // net names
122
+ name_embed = "embedding";
123
+ name_embed_cache = "embedding_cache";
124
+ name_lm = "lm_head";
125
+ for (int i = 0; i < NUM_LAYERS; i++) {
126
+ name_blocks[i] = "block_" + std::to_string(i);
127
+ name_blocks_cache[i] = "block_cache_" + std::to_string(i);
128
+ }
129
+
130
+ // net infos
131
+ net_embed = bmrt_get_network_info(p_bmrt, name_embed.c_str());
132
+ net_embed_cache = bmrt_get_network_info(p_bmrt, name_embed_cache.c_str());
133
+ net_lm = bmrt_get_network_info(p_bmrt, name_lm.c_str());
134
+ for (int i = 0; i < NUM_LAYERS; i++) {
135
+ net_blocks[i] = bmrt_get_network_info(p_bmrt, name_blocks[i].c_str());
136
+ net_blocks_cache[i] =
137
+ bmrt_get_network_info(p_bmrt, name_blocks_cache[i].c_str());
138
+ }
139
+
140
+ // net device mem
141
+ ret = bmrt_tensor(&inputs_embed_512, p_bmrt, net_embed->input_dtypes[0],
142
+ net_embed->stages[0].input_shapes[0]);
143
+ assert(true == ret);
144
+
145
+ ret = bmrt_tensor(&outputs_embed_512, p_bmrt, net_embed->output_dtypes[0],
146
+ net_embed->stages[0].output_shapes[0]);
147
+ assert(true == ret);
148
+
149
+ ret = bmrt_tensor(&inputs_pid, p_bmrt, net_blocks[0]->input_dtypes[1],
150
+ net_blocks[0]->stages[0].input_shapes[1]);
151
+ assert(true == ret);
152
+
153
+ ret = bmrt_tensor(&inputs_attention, p_bmrt, net_blocks[0]->input_dtypes[2],
154
+ net_blocks[0]->stages[0].input_shapes[2]);
155
+ assert(true == ret);
156
+
157
+ ret = bmrt_tensor(&next_pid, p_bmrt, net_blocks_cache[0]->input_dtypes[1],
158
+ net_blocks_cache[0]->stages[0].input_shapes[1]);
159
+ assert(true == ret);
160
+
161
+ ret =
162
+ bmrt_tensor(&next_attention, p_bmrt, net_blocks_cache[0]->input_dtypes[2],
163
+ net_blocks_cache[0]->stages[0].input_shapes[2]);
164
+ assert(true == ret);
165
+
166
+ for (int i = 0; i < NUM_LAYERS; i++) {
167
+ ret = bmrt_tensor(&past_key[i], p_bmrt, net_blocks[0]->output_dtypes[1],
168
+ net_blocks[0]->stages[0].output_shapes[1]);
169
+ assert(true == ret);
170
+ ret = bmrt_tensor(&past_value[i], p_bmrt, net_blocks[0]->output_dtypes[2],
171
+ net_blocks[0]->stages[0].output_shapes[2]);
172
+ assert(true == ret);
173
+ ret = bmrt_tensor(&present_key[i], p_bmrt, net_blocks[0]->output_dtypes[1],
174
+ net_blocks[0]->stages[0].output_shapes[1]);
175
+ assert(true == ret);
176
+ ret = bmrt_tensor(&present_value[i], p_bmrt, net_blocks[0]->output_dtypes[2],
177
+ net_blocks[0]->stages[0].output_shapes[2]);
178
+ assert(true == ret);
179
+ }
180
+ ret = bmrt_tensor(&present_key_cache, p_bmrt, net_blocks_cache[0]->output_dtypes[1],
181
+ net_blocks_cache[0]->stages[0].output_shapes[1]);
182
+ assert(true == ret);
183
+ ret = bmrt_tensor(&present_value_cache, p_bmrt, net_blocks_cache[0]->output_dtypes[2],
184
+ net_blocks_cache[0]->stages[0].output_shapes[2]);
185
+ assert(true == ret);
186
+
187
+ ret = bmrt_tensor(&inputs_lm, p_bmrt, net_lm->input_dtypes[0],
188
+ net_lm->stages[0].input_shapes[0]);
189
+ assert(true == ret);
190
+ ret = bmrt_tensor(&outputs_lm, p_bmrt, net_lm->output_dtypes[0],
191
+ net_lm->stages[0].output_shapes[0]);
192
+ assert(true == ret);
193
+ }
194
+
195
+ void Baichuan2::deinit() {
196
+ bm_free_device(bm_handle, inputs_embed_512.device_mem);
197
+ bm_free_device(bm_handle, outputs_embed_512.device_mem);
198
+ bm_free_device(bm_handle, inputs_lm.device_mem);
199
+ bm_free_device(bm_handle, outputs_lm.device_mem);
200
+ bm_free_device(bm_handle, inputs_pid.device_mem);
201
+ bm_free_device(bm_handle, next_pid.device_mem);
202
+ bm_free_device(bm_handle, inputs_attention.device_mem);
203
+ bm_free_device(bm_handle, next_attention.device_mem);
204
+ bm_free_device(bm_handle, present_key_cache.device_mem);
205
+ bm_free_device(bm_handle, present_value_cache.device_mem);
206
+ for (int i = 0; i < NUM_LAYERS; i++) {
207
+ bm_free_device(bm_handle, past_key[i].device_mem);
208
+ bm_free_device(bm_handle, past_value[i].device_mem);
209
+ bm_free_device(bm_handle, present_key[i].device_mem);
210
+ bm_free_device(bm_handle, present_value[i].device_mem);
211
+ }
212
+ bmrt_destroy(p_bmrt);
213
+ for (auto h : handles) {
214
+ bm_dev_free(h);
215
+ }
216
+ }
217
+
218
+ int Baichuan2::forward_first(std::vector<int> &tokens) {
219
+ int input_ids[MAX_LEN] = {0}; // start token
220
+ int position_id[MAX_LEN] = {0};
221
+ float attention_mask[MAX_LEN * MAX_LEN] = {0};
222
+ token_length = tokens.size();
223
+
224
+ std::copy(tokens.begin(), tokens.end(), input_ids);
225
+ for (int i = 0; i < token_length; i++) {
226
+ position_id[i] = i;
227
+ }
228
+
229
+ for (int i = 0; i < MAX_LEN; i++) {
230
+ for (int j = 0; j < MAX_LEN; j++) {
231
+ if (j <= i && i < token_length) {
232
+ } else {
233
+ attention_mask[i * MAX_LEN + j] = ATTENTION_MASK;
234
+ }
235
+ }
236
+ }
237
+
238
+ // forward embedding
239
+ bm_memcpy_s2d(bm_handle, inputs_embed_512.device_mem, (void *)input_ids);
240
+ auto ret =
241
+ bmrt_launch_tensor_ex(p_bmrt, name_embed.c_str(), &inputs_embed_512, 1,
242
+ &outputs_embed_512, 1, true, false);
243
+ assert(ret);
244
+ // float test_embed[MAX_LEN] = {0};
245
+ // bm_memcpy_d2s(bm_handle, (void *)&test_embed, outputs_embed_512.device_mem);
246
+ bm_thread_sync(bm_handle);
247
+
248
+ // forward blocks
249
+ bm_memcpy_s2d(bm_handle, inputs_pid.device_mem, (void *)position_id);
250
+ bm_memcpy_s2d(bm_handle, inputs_attention.device_mem, (void *)attention_mask);
251
+ auto inputs_embed = outputs_embed_512;
252
+ inputs_embed.shape = net_blocks[0]->stages[0].input_shapes[0];
253
+ bm_tensor_t inputs_block[3] = {inputs_embed, inputs_pid, inputs_attention};
254
+ for (int i = 0; i < NUM_LAYERS; i++) {
255
+ bm_tensor_t outputs_block[3] = {inputs_embed, past_key[i], past_value[i]};
256
+ ret = bmrt_launch_tensor_ex(p_bmrt, name_blocks[i].c_str(), inputs_block, 3,
257
+ outputs_block, 3, true, false);
258
+ assert(ret);
259
+ bm_thread_sync(bm_handle);
260
+ }
261
+ int bytes = inputs_embed.device_mem.size / MAX_LEN;
262
+ bm_memcpy_d2d_byte(bm_handle, inputs_lm.device_mem, 0,
263
+ inputs_embed.device_mem, (token_length - 1) * bytes,
264
+ bytes);
265
+ ret = bmrt_launch_tensor_ex(p_bmrt, name_lm.c_str(), &inputs_lm, 1,
266
+ &outputs_lm, 1, true, false);
267
+ bm_thread_sync(bm_handle);
268
+
269
+ int token = 0;
270
+ bm_memcpy_d2s(bm_handle, (void *)&token, outputs_lm.device_mem);
271
+ return token;
272
+ }
273
+
274
+ int Baichuan2::forward_next() {
275
+ float attention_mask[MAX_LEN + 1] = {0};
276
+ for (int i = token_length - 1; i < MAX_LEN; i++) {
277
+ attention_mask[i] = ATTENTION_MASK;
278
+ }
279
+ int32_t position_id = token_length - 1;
280
+ // embedding
281
+ outputs_lm.shape = net_embed_cache->stages[0].input_shapes[0];
282
+ auto ret = bmrt_launch_tensor_ex(p_bmrt, name_embed_cache.c_str(), &outputs_lm, 1,
283
+ &inputs_lm, 1, true, false);
284
+ assert(ret);
285
+ bm_thread_sync(bm_handle);
286
+
287
+ // blocks
288
+ bm_memcpy_s2d(bm_handle, next_attention.device_mem, (void *)attention_mask);
289
+ bm_memcpy_s2d(bm_handle, next_pid.device_mem, (void *)&position_id);
290
+ auto inputs_embed = inputs_lm;
291
+ inputs_embed.shape = net_blocks_cache[0]->stages[0].input_shapes[0];
292
+ int bytes = bm_mem_get_device_size(present_key_cache.device_mem);
293
+ int token_offset = (token_length - 1) * bytes;
294
+ for (int i = 0; i < NUM_LAYERS; i++) {
295
+ bm_tensor_t inputs_block[5] = {inputs_embed, next_pid, next_attention,
296
+ past_key[i], past_value[i]};
297
+ bm_tensor_t outputs_block[3] = {inputs_embed, present_key_cache, present_value_cache};
298
+ ret = bmrt_launch_tensor_ex(p_bmrt, name_blocks_cache[i].c_str(),
299
+ inputs_block, 5, outputs_block, 3, true, false);
300
+ assert(ret);
301
+ bm_thread_sync(bm_handle);
302
+ bm_memcpy_d2d_byte(bm_handle, past_key[i].device_mem, token_offset,
303
+ present_key_cache.device_mem, 0,
304
+ bytes);
305
+ bm_memcpy_d2d_byte(bm_handle, past_value[i].device_mem, token_offset,
306
+ present_value_cache.device_mem, 0,
307
+ bytes);
308
+ }
309
+ outputs_lm.shape = net_lm->stages[0].output_shapes[0];
310
+ ret = bmrt_launch_tensor_ex(p_bmrt, name_lm.c_str(), &inputs_lm, 1,
311
+ &outputs_lm, 1, true, false);
312
+ bm_thread_sync(bm_handle);
313
+
314
+ int token = 0;
315
+ bm_memcpy_d2s(bm_handle, (void *)&token, outputs_lm.device_mem);
316
+ return token;
317
+ }
318
+
319
+ void Baichuan2::chat() {
320
+ while (true) {
321
+ std::cout << "\nQuestion: ";
322
+ std::string input_str;
323
+ std::getline(std::cin, input_str);
324
+ std::string user_token = "<reserved_106>"; //user token id 195
325
+ std::string assistant_token = "<reserved_107>"; //assistant token id 196
326
+ if (input_str == "exit") {
327
+ break;
328
+ }
329
+ if (input_str == "clear") {
330
+ history.clear();
331
+ continue;
332
+ }
333
+
334
+ input_str = user_token + input_str + assistant_token;
335
+
336
+ std::cout << "\nAnswer: " << std::flush;
337
+ answer(input_str);
338
+ std::cout << std::endl;
339
+ }
340
+ }
341
+
342
+ void Baichuan2::answer(const std::string &input_str) {
343
+ int tok_num = 0;
344
+ history.emplace_back(std::move(input_str));
345
+
346
+ std::vector<int> tokens;
347
+
348
+ std::string history_input = std::accumulate(history.begin(), history.end(), std::string());
349
+ sentencepiece.Encode(history_input, &tokens);
350
+
351
+ if (tokens.empty()) {
352
+ printf("Sorry: your question is too wierd!!\n");
353
+ history.clear();
354
+ round = 0;
355
+ return;
356
+ }
357
+ // make sure token not too large
358
+ if (tokens.size() > MAX_LEN - 10) {
359
+ // reset
360
+ if (round == 0) {
361
+ printf("Error: your question is too large!\n");
362
+ return;
363
+ }
364
+ round = 0;
365
+ history.clear();
366
+ answer(input_str);
367
+ return;
368
+ }
369
+ auto time_1 = std::chrono::system_clock::now();
370
+ int pre_token = 0;
371
+ int token = forward_first(tokens);
372
+ auto time_2 = std::chrono::system_clock::now();
373
+ std::string result;
374
+ while (token != EOS && token_length < MAX_LEN) {
375
+ std::string pre_word;
376
+ std::string word;
377
+ std::vector<int> pre_ids = {pre_token};
378
+ std::vector<int> ids = {pre_token, token};
379
+ sentencepiece.Decode(pre_ids, &pre_word);
380
+ sentencepiece.Decode(ids, &word);
381
+ std::string diff = word.substr(pre_word.size());
382
+ result += diff;
383
+ std::cout << diff << std::flush;
384
+ if (token_length < MAX_LEN) {
385
+ token_length++;
386
+ }
387
+ tok_num++;
388
+ token = forward_next();
389
+ }
390
+ auto time_3 = std::chrono::system_clock::now();
391
+ auto ftl_dur =
392
+ std::chrono::duration_cast<std::chrono::microseconds>(time_2 - time_1);
393
+ auto tps_dur =
394
+ std::chrono::duration_cast<std::chrono::microseconds>(time_3 - time_2);
395
+ double tps = tok_num / (tps_dur.count() * 1e-6);
396
+ if (token_length >= MAX_LEN) {
397
+ printf(" ......\nWarning: cleanup early history\n");
398
+ }
399
+ // double tht = tokens.size() / (tht_dur.count() * 1e-6);
400
+ printf("\nFTL:%f s, TPS: %f tokens/s\n", ftl_dur.count() * 1e-6, tps);
401
+ history.emplace_back(result);
402
+ if (token_length + 128 >= MAX_LEN) {
403
+ int num = (history.size() + 3) / 4 * 2;
404
+ history.erase(history.begin(), history.begin() + num);
405
+ }
406
+ }
407
+
408
+ static void split(const std::string &s, const std::string &delim,
409
+ std::vector<std::string> &ret) {
410
+ size_t last = 0;
411
+ size_t index = s.find_first_of(delim, last);
412
+ while (index != std::string::npos) {
413
+ ret.push_back(s.substr(last, index - last));
414
+ last = index + 1;
415
+ index = s.find_first_of(delim, last);
416
+ }
417
+ if (last < s.length()) {
418
+ ret.push_back(s.substr(last));
419
+ }
420
+ }
421
+
422
+ static std::vector<int> parseCascadeDevices(const std::string &str) {
423
+ std::vector<int> devices;
424
+ std::vector<std::string> sub_str;
425
+ split(str, ",", sub_str);
426
+ for (auto &s : sub_str) {
427
+ devices.push_back(std::atoi(s.c_str()));
428
+ }
429
+ return devices;
430
+ }
431
+
432
+ void processArguments(int argc, char *argv[], std::string &baichuan_model,
433
+ std::vector<int> &devices) {
434
+ struct option longOptions[] = {{"model", required_argument, nullptr, 'm'},
435
+ {"dev_id", required_argument, nullptr, 'd'},
436
+ {nullptr, 0, nullptr, 0}};
437
+
438
+ int optionIndex = 0;
439
+ int option;
440
+
441
+ while ((option = getopt_long(argc, argv, "m:d:", longOptions,
442
+ &optionIndex)) != -1) {
443
+ switch (option) {
444
+ case 'm':
445
+ baichuan_model = optarg;
446
+ break;
447
+ case 'd':
448
+ devices = parseCascadeDevices(optarg);
449
+ break;
450
+ case '?':
451
+ exit(EXIT_FAILURE);
452
+ default:
453
+ exit(EXIT_FAILURE);
454
+ }
455
+ }
456
+ }
457
+
458
+ int main(int argc, char **argv) {
459
+ // set your bmodel path here
460
+ printf("Demo for Baichuan2-7B in BM1684X\n");
461
+ std::string baichuan_model = "baichuan2-7b-test.bmodel";
462
+ std::vector<int> devices = {0};
463
+ processArguments(argc, argv, baichuan_model, devices);
464
+
465
+ Baichuan2 baichuan;
466
+ printf("Init Environment ...\n");
467
+ baichuan.init(devices, baichuan_model);
468
+ printf("==========================\n");
469
+ baichuan.chat();
470
+ baichuan.deinit();
471
+ return 0;
472
+ }
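For reference, main() accepts `-m/--model` for the bmodel path and `-d/--dev_id` for a comma-separated list of device ids, so a typical run from the demo directory would look something like `./baichuan2 --model baichuan2-7b-test.bmodel --dev_id 0`; note that the tokenizer is loaded from the relative path `../model/tokenizer.model`, so the working directory matters.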
Baichuan2/model/tokenizer.model ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:79452955be6b419a65984273a9f08af86042e1c2a75ee3ba989cbf620a133cc2
3
+ size 2001107
Baichuan2/requirements.txt ADDED
@@ -0,0 +1,7 @@
1
+ torch==2.1.2
2
+ transformers==4.36.2
3
+ sentencepiece==0.1.99
4
+ gradio==3.39.0
5
+ mdtex2html==1.2.0
6
+ accelerate
7
+ onnx
Baichuan2/src/include/bmdef.h ADDED
@@ -0,0 +1,129 @@
1
+ /*****************************************************************************
2
+ *
3
+ * Copyright (c) 2016-2026 by Sophgo Technologies Inc. All rights reserved.
4
+ *
5
+ * The material in this file is confidential and contains trade secrets
6
+ * of Sophgo Technologies Inc. This is proprietary information owned by
7
+ * Sophgo Technologies Inc. No part of this work may be disclosed,
8
+ * reproduced, copied, transmitted, or used in any way for any purpose,
9
+ * without the express written permission of Sophgo Technologies Inc.
10
+ *
11
+ *****************************************************************************/
12
+
13
+ #ifndef __BMRUNTIME_DEFINE_H__
14
+ #define __BMRUNTIME_DEFINE_H__
15
+
16
+ #include "bmlib_runtime.h"
17
+ #include <stddef.h>
18
+ #include <stdint.h>
19
+
20
+ #if defined(__cplusplus)
21
+ extern "C" {
22
+ #endif
23
+
24
+ /* --------------------------------------------------------------------------*/
25
+ /* basic definitions */
26
+
27
+ /* bm_data_type_t holds the type for a scalar value */
28
+ typedef enum bm_data_type_e {
29
+ BM_FLOAT32 = 0,
30
+ BM_FLOAT16 = 1,
31
+ BM_INT8 = 2,
32
+ BM_UINT8 = 3,
33
+ BM_INT16 = 4,
34
+ BM_UINT16 = 5,
35
+ BM_INT32 = 6,
36
+ BM_UINT32 = 7,
37
+ BM_BFLOAT16 = 8,
38
+ BM_INT4 = 9,
39
+ BM_UINT4 = 10,
40
+ } bm_data_type_t;
41
+
42
+ /* store mode definitions */
43
+ typedef enum bm_store_mode_e {
44
+ BM_STORE_1N = 0, /* default, if not sure, use 0 */
45
+ BM_STORE_2N = 1,
46
+ BM_STORE_4N = 2,
47
+ } bm_store_mode_t;
48
+
49
+ /* bm_shape_t holds the shape info */
50
+ #define BM_MAX_DIMS_NUM 8
51
+ typedef struct bm_shape_s {
52
+ int num_dims;
53
+ int dims[BM_MAX_DIMS_NUM];
54
+ } bm_shape_t;
55
+
56
+ typedef struct bm_shape_ex_s {
57
+ bm_shape_t shape;
58
+ int elem_num;
59
+ } bm_shape_ex_t;
60
+
61
+ /*
62
+ bm_tensor_t holds a multi-dimensional array of elements of a single data type
63
+ and the tensor data resides in device memory */
64
+ typedef struct bm_tensor_s {
65
+ bm_data_type_t dtype;
66
+ bm_shape_t shape;
67
+ bm_device_mem_t device_mem;
68
+ bm_store_mode_t st_mode; /* user can set 0 as default store mode */
69
+ } bm_tensor_t;
70
+
71
+ /* --------------------------------------------------------------------------*/
72
+ /* network information structure */
73
+
74
+ /* bm_stage_info_t holds input/output shapes and device mems; every network can contain one or more
75
+ * stages */
76
+ typedef struct bm_stage_info_s {
77
+ bm_shape_t *input_shapes; /* input_shapes[0] / [1] / ... / [input_num-1] */
78
+ bm_shape_t *output_shapes; /* output_shapes[0] / [1] / ... / [output_num-1] */
79
+ bm_device_mem_t *input_mems; /* input_mems[0] / [1] / ... / [input_num-1] */
80
+ bm_device_mem_t *output_mems; /* output_mems[0] / [1] / ... / [output_num-1] */
81
+ } bm_stage_info_t;
82
+
83
+ /* bm_net_info_t holds all information of one net.
84
+ * scale for float type is 1.0 as default */
85
+ typedef struct bm_net_info_s {
86
+ const char* name; /* net name */
87
+ bool is_dynamic; /* dynamic or static */
88
+ int input_num; /* number of inputs */
89
+ char const** input_names; /* input_names[0] / [1] / .../ [input_num-1] */
90
+ bm_data_type_t* input_dtypes; /* input_dtypes[0] / [1] / .../ [input_num-1] */
91
+ float* input_scales; /* input_scales[0] / [1] / .../ [input_num-1] */
92
+ int output_num; /* number of outputs */
93
+ char const** output_names; /* output_names[0] / [1] / .../ [output_num-1] */
94
+ bm_data_type_t* output_dtypes; /* output_dtypes[0] / [1] / .../ [output_num-1] */
95
+ float* output_scales; /* output_scales[0] / [1] / .../ [output_num-1] */
96
+ int stage_num; /* number of stages */
97
+ bm_stage_info_t* stages; /* stages[0] / [1] / ... / [stage_num-1] */
98
+ size_t* max_input_bytes; /* max_input_bytes[0]/ [1] / ... / [input_num-1] */
99
+ size_t* max_output_bytes; /* max_output_bytes[0] / [1] / ... / [output_num-1] */
100
+ int* input_zero_point; /* input_zero_point[0] / [1] / .../ [input_num-1] */
101
+ int* output_zero_point; /* output_zero_point[0] / [1] / .../ [output_num-1] */
102
+ int *input_loc_devices; /* input_loc_device[0] / [1] / .../ [input_num-1] */
103
+ int *output_loc_devices; /* output_loc_device[0] / [1] / .../ [output_num-1] */
104
+ } bm_net_info_t;
105
+
106
+ typedef struct api_info_s {
107
+ /// @brief api_id to be sent to driver
108
+ int32_t api_id;
109
+ /// @brief api data to be sent to driver
110
+ uint8_t **api_data;
111
+ /// @brief size of the api data to be sent to driver
112
+ size_t api_data_size;
113
+ /// @brief subsize of the api data to be sent to driver
114
+ size_t *api_data_subsize;
115
+ /// @brief offset of input tensors' addr in api_data
116
+ uint32_t *input_addr_offset;
117
+ /// @brief number of the offset of input tensors' addr in api_data
118
+ size_t input_addr_offset_number;
119
+ /// @brief offset of output tensors' addr in api_data
120
+ uint32_t *output_addr_offset;
121
+ /// @brief number of the offset of output tensors' addr in api_data
122
+ size_t output_addr_offset_number;
123
+ } api_info_c;
124
+
125
+ #if defined(__cplusplus)
126
+ }
127
+ #endif
128
+
129
+ #endif /* __BM_NET_H__ */
Baichuan2/src/include/bmlib_runtime.h ADDED
@@ -0,0 +1,2581 @@
1
+ /*****************************************************************************
2
+ *
3
+ * Copyright (c) 2016-2026 by Bitmain Technologies Inc. All rights reserved.
4
+ *
5
+ * The material in this file is confidential and contains trade secrets
6
+ * of Bitmain Technologies Inc. This is proprietary information owned by
7
+ * Bitmain Technologies Inc. No part of this work may be disclosed,
8
+ * reproduced, copied, transmitted, or used in any way for any purpose,
9
+ * without the express written permission of Bitmain Technologies Inc.
10
+ *
11
+ *****************************************************************************/
12
+
13
+ /**************************************************************************
14
+ * bmlib_runtime defines interfaces that operate TPU devices.
15
+ * The functions can be divided into several categories.
16
+ * 1) device handle creation and destroy
17
+ * 2) memory help functions
18
+ * 3) global memory allocation and free
19
+ * 4) data transfer between host and device
20
+ * 5) data transfer within device memory
21
+ * 6) api send and synchronization
22
+ * 7) global memory map and coherence
23
+ * 8) trace and profile
24
+ * 9) power management
25
+ * 10) miscellaneous functions
26
+ *************************************************************************/
27
+
28
+ #ifndef BMLIB_RUNTIME_H_
29
+ #define BMLIB_RUNTIME_H_
30
+ #if defined(_WIN32) && !defined(__MINGW32__)
31
+ #include <vadefs.h>
32
+ #define DECL_EXPORT __declspec(dllexport)
33
+ #define DECL_IMPORT __declspec(dllimport)
34
+ #else
35
+ #include <stdbool.h>
36
+ #include <stddef.h>
37
+ #include <stdarg.h>
38
+ #define DECL_EXPORT
39
+ #define DECL_IMPORT
40
+ #endif
41
+
42
+ #if defined(__cplusplus)
43
+ extern "C" {
44
+ #endif
45
+
46
+ typedef enum {
47
+ MODULE_CDMA = 0,
48
+ MODULE_GDMA = 1,
49
+ MODULE_TPU = 2,
50
+ MODULE_SMMU = 3,
51
+ MODULE_SRAM = 4,
52
+ MODULE_END = 5
53
+ } MODULE_ID;
54
+
55
+ #define BM_MEM_ADDR_NULL (0xfffffffff)
56
+
57
+ #ifndef BM_MEM_DESC_T_
58
+ #define BM_MEM_DESC_T_
59
+ /* BM function return code definitions */
60
+ typedef enum {
61
+ BM_SUCCESS = 0,
62
+ BM_ERR_DEVNOTREADY = 1, /* Device not ready yet */
63
+ BM_ERR_FAILURE = 2, /* General failure */
64
+ BM_ERR_TIMEOUT = 3, /* Timeout */
65
+ BM_ERR_PARAM = 4, /* Parameters invalid */
66
+ BM_ERR_NOMEM = 5, /* Not enough memory */
67
+ BM_ERR_DATA = 6, /* Data error */
68
+ BM_ERR_BUSY = 7, /* Busy */
69
+ BM_ERR_NOFEATURE = 8, /* Not supported yet */
70
+ BM_NOT_SUPPORTED = 9
71
+ } bm_status_t;
72
+
73
+ /* BM memory type definitions */
74
+ typedef enum {
75
+ BM_MEM_TYPE_DEVICE = 0,
76
+ BM_MEM_TYPE_HOST = 1,
77
+ BM_MEM_TYPE_SYSTEM = 2,
78
+ BM_MEM_TYPE_INT8_DEVICE = 3,
79
+ BM_MEM_TYPE_INVALID = 4
80
+ } bm_mem_type_t;
81
+
82
+ typedef enum {
83
+ PERF_MONITOR_GDMA = 0,
84
+ PERF_MONITOR_TPU = 1
85
+ } PERF_MONITOR_ID;
86
+
87
+ typedef enum {
88
+ BMCPU_IDLE = 0,
89
+ BMCPU_RUNNING = 1,
90
+ BMCPU_FAULT = 2
91
+ } bm_cpu_status_t;
92
+
93
+ /*
94
+ * bm performance monitor
95
+ */
96
+ typedef struct bm_perf_monitor {
97
+ long long buffer_start_addr; /*buffer address to store perf data*/
98
+ int buffer_size; /*buffer size*/
99
+ PERF_MONITOR_ID monitor_id; /*PERF_MONITOR_GDMA or PERF_MONITOR_TPU*/
100
+ } bm_perf_monitor_t;
101
+
102
+ typedef union {
103
+ struct {
104
+ bm_mem_type_t mem_type : 3;
105
+ unsigned int gmem_heapid : 3;
106
+ unsigned int reserved : 26;
107
+ } u;
108
+ unsigned int rawflags;
109
+ } bm_mem_flags_t;
110
+
111
+ /* BM memory descriptor definition*/
112
+ typedef struct bm_mem_desc {
113
+ union {
114
+ struct {
115
+ #ifdef __linux__
116
+ unsigned long device_addr;
117
+ #else
118
+ unsigned long long device_addr;
119
+ #endif
120
+ unsigned int reserved;
121
+ int dmabuf_fd;
122
+ } device;
123
+
124
+ struct {
125
+ void *system_addr;
126
+ unsigned int reserved0;
127
+ int reserved1;
128
+ } system;
129
+ } u;
130
+
131
+ bm_mem_flags_t flags;
132
+ unsigned int size;
133
+ } bm_mem_desc_t;
134
+
135
+ typedef struct bm_mem_desc bm_device_mem_t;
136
+ typedef struct bm_mem_desc bm_system_mem_t;
137
+
138
+ typedef struct sg_mem_desc {
139
+ union {
140
+ struct {
141
+ #ifdef __linux__
142
+ unsigned long device_addr;
143
+ #else
144
+ unsigned long long device_addr;
145
+ #endif
146
+ unsigned int reserved;
147
+ int dmabuf_fd;
148
+ } device;
149
+
150
+ struct {
151
+ void *system_addr;
152
+ unsigned int reserved0;
153
+ int reserved1;
154
+ } system;
155
+ } u;
156
+
157
+ bm_mem_flags_t flags;
158
+ unsigned long long size;
159
+ } sg_mem_desc_t;
160
+
161
+ typedef struct sg_mem_desc sg_device_mem_t;
162
+ typedef struct sg_mem_desc sg_system_mem_t;
163
+ #endif
164
+
165
+ struct bm_context;
166
+ typedef struct bm_context *bm_handle_t;
167
+
168
+ #define MD5SUM_LEN 16
169
+ #define LIB_MAX_NAME_LEN 64
170
+ #define FUNC_MAX_NAME_LEN 64
171
+
172
+ typedef struct bm_module
173
+ {
174
+ // void *lib_handle;
175
+ char lib_name[LIB_MAX_NAME_LEN];
176
+ unsigned char md5[MD5SUM_LEN];
177
+ }bm_module;
178
+
179
+ typedef struct bm_module *tpu_kernel_module_t;
180
+ typedef int tpu_kernel_function_t;
181
+
182
+ /**
183
+ * @name tpu_kernel_load_module_file
184
+ * @brief To load dyn file
185
+ * @ingroup bmlib_runtime
186
+ *
187
+ * @param [in] handle The device handle
188
+ * @param [in] module_file dyn file
189
+ * @retval dyn lib ptr
190
+ */
191
+ tpu_kernel_module_t tpu_kernel_load_module_file(bm_handle_t handle, const char *module_file);
192
+
193
+ /**
194
+ * @name tpu_kernel_load_module_file_key
195
+ * @brief To load dyn file with key
196
+ * @ingroup bmlib_runtime
197
+ *
198
+ * @param [in] handle The device handle
199
+ * @param [in] module_file dyn file
200
+ * @param [in] key identification str
201
+ * @param [in] size key size
202
+ * @retval dyn lib ptr
203
+ */
204
+ tpu_kernel_module_t tpu_kernel_load_module_file_key(bm_handle_t handle, const char *module_file, const char *key, int size);
205
+
206
+ /**
207
+ * @name tpu_kernel_unload_module
208
+ * @brief To unload dyn file
209
+ * @ingroup bmlib_runtime
210
+ *
211
+ * @param [in] handle The device handle
212
+ * @param [in] p_module dyn lib ptr
213
+ * @retval BM_SUCCESS Succeeds.
214
+ * Other code Fails.
215
+ */
216
+ bm_status_t tpu_kernel_unload_module(bm_handle_t handle, tpu_kernel_module_t p_module);
217
+
218
+ /**
219
+ * @name tpu_kernel_free_module
220
+ * @brief To free p_module when it is no longer in use
221
+ * @ingroup bmlib_runtime
222
+ *
223
+ * @param [in] handle The device handle
224
+ * @param [in] p_module dyn lib ptr
225
+ * @retval BM_SUCCESS Succeeds.
226
+ * Other code Fails.
227
+ */
228
+ bm_status_t tpu_kernel_free_module(bm_handle_t handle, tpu_kernel_module_t p_module);
229
+
230
+ /**
231
+ * @name tpu_kernel_load_module
232
+ * @brief To load dyn module
233
+ * @ingroup bmlib_runtime
234
+ *
235
+ * @param [in] handle The device handle
236
+ * @param [in] data dyn module
237
+ * @param [in] length dyn module size
238
+ * @retval dyn lib ptr
239
+ */
240
+ tpu_kernel_module_t tpu_kernel_load_module(bm_handle_t handle, const char *data, size_t length);
241
+
242
+ /**
243
+ * @name tpu_kernel_get_function
244
+ * @brief To get function from lib
245
+ * @ingroup bmlib_runtime
246
+ *
247
+ * @param [in] handle The device handle
248
+ * @param [in] module dyn module
249
+ * @param [in] function function name
250
+ * @retval function id
251
+ */
252
+ tpu_kernel_function_t tpu_kernel_get_function(bm_handle_t handle, tpu_kernel_module_t module, const char *function);
253
+
254
+ /**
255
+ * @name tpu_kernel_launch
256
+ * @brief To launch function with sync
257
+ * @ingroup bmlib_runtime
258
+ *
259
+ * @param [in] handle The device handle
260
+ * @param [in] function function id
261
+ * @param [in] args function args
262
+ * @param [in] size args size
263
+ * @retval BM_SUCCESS Succeeds.
264
+ * Other code Fails.
265
+ */
266
+ bm_status_t tpu_kernel_launch(bm_handle_t handle, tpu_kernel_function_t function, void *args, size_t size);
267
+
268
+ /**
269
+ * @name tpu_kernel_launch_async
270
+ * @brief To launch function with async
271
+ * @ingroup bmlib_runtime
272
+ *
273
+ * @param [in] handle The device handle
274
+ * @param [in] function function id
275
+ * @param [in] args function args
276
+ * @param [in] size args size
277
+ * @retval BM_SUCCESS Succeeds.
278
+ * Other code Fails.
279
+ */
280
+ bm_status_t tpu_kernel_launch_async(bm_handle_t handle, tpu_kernel_function_t function, void *args, size_t size);
281
+
282
+ /**
283
+ * @name tpu_kernel_launch_async_multi_cores
284
+ * @brief To launch function with async for multi cores
285
+ * @ingroup bmlib_runtime
286
+ *
287
+ * @param [in] handle The device handle
288
+ * @param [in] func_name function name
289
+ * @param [in] api_param function params
290
+ * @param [in] api_size params size
291
+ * @param [in] core_list list of core ids
292
+ * @param [in] core_num number of cores
293
+ * @retval BM_SUCCESS Succeeds.
294
+ * Other code Fails.
295
+ */
296
+ bm_status_t tpu_kernel_launch_async_multi_cores(bm_handle_t handle, const char *func_name, const void *api_param,
297
+ size_t api_size, const int* core_list, const int core_num);
298
+
299
+ /**
300
+ * @name tpu_kernel_launch_sync_multi_cores
301
+ * @brief To launch function with sync for multi cores
302
+ * @ingroup bmlib_runtime
303
+ *
304
+ * @param [in] handle The device handle
305
+ * @param [in] func_name function name
306
+ * @param [in] api_param function params
307
+ * @param [in] api_size params size
308
+ * @param [in] core_list list of core ids
309
+ * @param [in] core_num number of cores
310
+ * @retval BM_SUCCESS Succeeds.
311
+ * Other code Fails.
312
+ */
313
+ bm_status_t tpu_kernel_launch_sync_multi_cores(bm_handle_t handle, const char *func_name, const void *api_param,
314
+ size_t api_size, const int* core_list, const int core_num);
315
+
316
+ /**
317
+ * @name tpu_kernel_sync
318
+ * @brief To sync
319
+ * @ingroup bmlib_runtime
320
+ *
321
+ * @param [in] handle The device handle
322
+ * @retval BM_SUCCESS Succeeds.
323
+ * Other code Fails.
324
+ */
325
+ bm_status_t tpu_kernel_sync(bm_handle_t handle);
326
+ void show_md5(unsigned char md5[]);
327
+
328
+ DECL_EXPORT void bmlib_log(const char *tag, int level, const char *fmt, ...);
329
+
330
+ #ifndef USING_CMODEL
331
+ #define BM_CHECK_RET(call) \
332
+ do { \
333
+ bm_status_t ret = (bm_status_t)call; \
334
+ if (ret != BM_SUCCESS) { \
335
+ bmlib_log("BM_CHECK",16,"BM_CHECK_RET fail %s: %s: %d\n", __FILE__, __func__, __LINE__); \
336
+ return ret; \
337
+ } \
338
+ } while (0)
339
+ #else
340
+ #define BM_CHECK_RET(call) \
341
+ do { \
342
+ bm_status_t ret = call; \
343
+ if (ret != BM_SUCCESS) { \
344
+ bmlib_log("BM_CHECK",16,"BM_CHECK_RET failed %d\n", ret);\
345
+ ASSERT(0); \
346
+ exit(-ret); \
347
+ } \
348
+ } while (0)
349
+ #endif
350
+
351
+ /*******************handle related functions *********************************/
352
+ /**
353
+ * @name bm_dev_getcount
354
+ * @brief To get the number of sophon devices in system.
355
+ * If N is got, valid devid is [0, N-1]
356
+ * @ingroup bmlib_runtime
357
+ *
358
+ * @param [out] count The result number of sophon devices
359
+ * @retval BM_SUCCESS Succeeds.
360
+ * Other code Fails.
361
+ */
362
+ DECL_EXPORT bm_status_t bm_dev_getcount(int *count);
363
+
364
+ /**
365
+ * @name bm_dev_query
366
+ * @brief To query if a device is present
367
+ * @ingroup bmlib_runtime
368
+ *
369
+ * @param [in] devid The id of the device to query
370
+ * @retval BM_SUCCESS Device is present
371
+ * Other code Device is not present
372
+ */
373
+ DECL_EXPORT bm_status_t bm_dev_query(int devid);
374
+
375
+ /**
376
+ * @name bm_dev_request
377
+ * @brief To create a handle for the given device
378
+ * @ingroup bmlib_runtime
379
+ *
380
+ * @param [out] handle The created handle
381
+ * @param [in] devid Specify on which device to create handle
382
+ * @retval BM_SUCCESS Succeeds.
383
+ * Other code Fails.
384
+ */
385
+ DECL_EXPORT bm_status_t bm_dev_request(bm_handle_t *handle, int devid);
386
+
387
+ /**
388
+ * @name bm_get_devid
389
+ * @brief To get device index for the given handle
390
+ * @ingroup bmlib_runtime
391
+ *
392
+ * @param [in] handle The given handle
393
+ * @retval int device index that the handle points to.
394
+ */
395
+ DECL_EXPORT int bm_get_devid(bm_handle_t handle);
396
+
397
+ /**
398
+ * @name bm_dev_free
399
+ * @brief To free a handle
400
+ * @ingroup bmlib_runtime
401
+ *
402
+ * @param [in] handle The handle to free
403
+ */
404
+ DECL_EXPORT void bm_dev_free(bm_handle_t handle);
405
+
406
+ /*******************memory help functions ************************************/
407
+ /**
408
+ * @name bm_mem_get_type
409
+ * @brief To get a memory descriptor's type
410
+ * @ingroup bmlib_runtime
411
+ *
412
+ * @param [in] mem The memory descriptor queried
413
+ * @retval BM_MEM_TYPE_DEVICE Device global memory
414
+ * @retval BM_MEM_TYPE_SYSTEM Host user memory
415
+ */
416
+ DECL_EXPORT bm_mem_type_t bm_mem_get_type(struct bm_mem_desc mem);
417
+
418
+ /**
419
+ * @name sg_mem_get_type
420
+ * @brief To get a memory descriptor's type
421
+ * @ingroup bmlib_runtime
422
+ *
423
+ * @param [in] mem The memory descriptor queried
424
+ * @retval BM_MEM_TYPE_DEVICE Device global memory
425
+ * @retval BM_MEM_TYPE_SYSTEM Host user memory
426
+ */
427
+ DECL_EXPORT bm_mem_type_t sg_mem_get_type(struct sg_mem_desc mem);
428
+
429
+ /**
430
+ * @name bm_mem_get_device_addr
431
+ * @brief To get a device memory descriptor's address
432
+ * @ingroup bmlib_runtime
433
+ *
434
+ * @param [in] mem The device memory descriptor queried
435
+ * @retval unsigned long long The device memory address
436
+ */
437
+ DECL_EXPORT unsigned long long bm_mem_get_device_addr(struct bm_mem_desc mem);
438
+
439
+ /**
440
+ * @name sg_mem_get_device_addr
441
+ * @brief To get a device memory descriptor's address
442
+ * @ingroup bmlib_runtime
443
+ *
444
+ * @param [in] mem The device memory descriptor queried
445
+ * @retval unsigned long long The device memory address
446
+ */
447
+ DECL_EXPORT unsigned long long sg_mem_get_device_addr(struct sg_mem_desc mem);
448
+
449
+ /**
450
+ * @name bm_mem_set_device_addr
451
+ * @brief To set a device memory descriptor's address
452
+ * @ingroup bmlib_runtime
453
+ *
454
+ * @param [in] pmem The device memory descriptor pointer
455
+ * @param [in] addr The new device address of the device memory
456
+ */
457
+ DECL_EXPORT void bm_mem_set_device_addr(struct bm_mem_desc* pmem, unsigned long long addr);
458
+
459
+ /**
460
+ * @name sg_mem_set_device_addr
461
+ * @brief To set a device memory descriptor's address
462
+ * @ingroup bmlib_runtime
463
+ *
464
+ * @param [in] pmem The device memory descriptor pointer
465
+ * @param [in] addr The new device address of the device memory
466
+ */
467
+ DECL_EXPORT void sg_mem_set_device_addr(struct sg_mem_desc* pmem, unsigned long long addr);
468
+
469
+ /**
470
+ * @name bm_mem_get_device_size
471
+ * @brief To get a device memory descriptor's size
472
+ * @ingroup bmlib_runtime
473
+ *
474
+ * @param [in] mem The device memory descriptor queried
475
+ * @retval unsigned int The device memory's size in bytes
476
+ */
477
+ DECL_EXPORT unsigned int bm_mem_get_device_size(struct bm_mem_desc mem);
478
+
479
+ /**
480
+ * @name sg_mem_get_device_size
481
+ * @brief To get a device memory descriptor's size
482
+ * @ingroup bmlib_runtime
483
+ *
484
+ * @param [in] mem The device memory descriptor queried
485
+ * @retval unsigned int The device memory's size in bytes
486
+ */
487
+ DECL_EXPORT unsigned long long sg_mem_get_device_size(struct sg_mem_desc mem);
488
+
489
+ /**
490
+ * @name bm_mem_set_device_size
491
+ * @brief To set a device memory descriptor's size
492
+ * @ingroup bmlib_runtime
493
+ *
494
+ * @param [out] pmem The device memory descriptor pointer
495
+ * @param [in] size The new device memory size (in bytes) of the device memory
496
+ */
497
+ DECL_EXPORT void bm_mem_set_device_size(struct bm_mem_desc* pmem, unsigned int size);
498
+
499
+ /**
500
+ * @name sg_mem_set_device_size
501
+ * @brief To set a device memory descriptor's size
502
+ * @ingroup bmlib_runtime
503
+ *
504
+ * @param [out] pmem The device memory descriptor pointer
505
+ * @param [in] size The new device memory size (in bytes) of the device memory
506
+ */
507
+ DECL_EXPORT void sg_mem_set_device_size(struct sg_mem_desc* pmem, unsigned long long size);
508
+
509
+ /**
510
+ * @name bm_set_device_mem
511
+ * @brief To fill in a device memory descriptor with size and address
512
+ * @ingroup bmlib_runtime
513
+ *
514
+ * @param [in] pmem The device memory descriptor pointer
515
+ * @param [in] size The device memory descriptor's size
516
+ * @param [in] addr The device memory descriptor's address
517
+ */
518
+ DECL_EXPORT void bm_set_device_mem(bm_device_mem_t* pmem, unsigned int size,
519
+ unsigned long long addr);
520
+
521
+ /**
522
+ * @name sg_set_device_mem
523
+ * @brief To fill in a device memory descriptor with size and address
524
+ * @ingroup bmlib_runtime
525
+ *
526
+ * @param [in] pmem The device memory descriptor pointer
527
+ * @param [in] size The device memory descriptor's size
528
+ * @param [in] addr The device memory descriptor's address
529
+ */
530
+ DECL_EXPORT void sg_set_device_mem(sg_device_mem_t* pmem, unsigned long long size,
531
+ unsigned long long addr);
532
+
533
+ /**
534
+ * @name bm_mem_from_device
535
+ * @brief To create a device memory descriptor from address and size
536
+ * @ingroup bmlib_runtime
537
+ *
538
+ * @param [in] device_addr The device memory address
539
+ * @param [in] len The device memory size
540
+ * @retval bm_device_mem_t The device memory descriptor created
541
+ */
542
+ DECL_EXPORT bm_device_mem_t bm_mem_from_device(unsigned long long device_addr,
543
+ unsigned int len);
544
+
545
+ /**
546
+ * @name sg_mem_from_device
547
+ * @brief To create a device memory descriptor from address and size
548
+ * @ingroup bmlib_runtime
549
+ *
550
+ * @param [in] device_addr The device memory address
551
+ * @param [in] len The device memory size
552
+ * @retval bm_device_mem_t The device memory descriptor created
553
+ */
554
+ DECL_EXPORT sg_device_mem_t sg_mem_from_device(unsigned long long device_addr,
555
+ unsigned long long len);
556
+
557
+ /**
558
+ * @name bm_mem_get_system_addr
559
+ * @brief To get a system memory descriptor's address
560
+ * @ingroup bmlib_runtime
561
+ *
562
+ * @param [in] mem The system memory descriptor
563
+ * @retval void * The system memory descriptor's address
564
+ */
565
+ DECL_EXPORT void *bm_mem_get_system_addr(struct bm_mem_desc mem);
566
+
567
+ /**
568
+ * @name sg_mem_get_system_addr
569
+ * @brief To get a system memory descriptor's address
570
+ * @ingroup bmlib_runtime
571
+ *
572
+ * @param [in] mem The system memory descriptor
573
+ * @retval void * The system memory descriptor's address
574
+ */
575
+ DECL_EXPORT void *sg_mem_get_system_addr(struct sg_mem_desc mem);
576
+
577
+ /**
578
+ * @name bm_mem_set_system_addr
579
+ * @brief To set a system memory descriptor's address
580
+ * @ingroup bmlib_runtime
581
+ *
582
+ * @param [in] pmem The system memory descriptor pointer
583
+ * @param [in] addr The system memory address
584
+ */
585
+ DECL_EXPORT void bm_mem_set_system_addr(struct bm_mem_desc* pmem, void *addr);
586
+
587
+ /**
588
+ * @name sg_mem_set_system_addr
589
+ * @brief To set a system memory descriptor's address
590
+ * @ingroup bmlib_runtime
591
+ *
592
+ * @param [in] pmem The system memory descriptor pointer
593
+ * @param [in] addr The system memory address
594
+ */
595
+ DECL_EXPORT void sg_mem_set_system_addr(struct sg_mem_desc* pmem, void *addr);
596
+
597
+ /**
598
+ * @name bm_mem_from_system
599
+ * @brief To create a system memory descriptor with the given system address
600
+ * @ingroup bmlib_runtime
601
+ *
602
+ * @param [in] system_addr The system address in the descriptor
603
+ * @retval bm_system_mem_t The system memory descriptor created
604
+ */
605
+ DECL_EXPORT bm_system_mem_t bm_mem_from_system(void *system_addr);
606
+
607
+ /*******************memory alloc and free functions ***************************/
608
+ /**
609
+ * @name bm_mem_null
610
+ * @brief Return an illegal device memory descriptor
611
+ * @ingroup bmlib_runtime
612
+ *
613
+ * @retval bm_device_mem_t An invalid device memory descriptor
614
+ */
615
+ DECL_EXPORT bm_device_mem_t bm_mem_null(void);
616
+ #define BM_MEM_NULL (bm_mem_null())
617
+
618
+ /**
619
+ * @name bm_malloc_neuron_device
620
+ * @brief To malloc device memory according to a tensor shape
621
+ * (each neuron is 32 bits)
622
+ * @ingroup bmlib_runtime
623
+ *
624
+ * @param [in] handle The device handle
625
+ * @param [out] pmem The result device memory descriptor
626
+ * @param [in] n, c, h, w The shape of the input tensor
627
+ * @retval BM_SUCCESS Succeeds.
628
+ * Other code Fails.
629
+ */
630
+ DECL_EXPORT bm_status_t bm_malloc_neuron_device(bm_handle_t handle, bm_device_mem_t *pmem,
631
+ int n, int c, int h, int w);
632
+
633
+ /**
634
+ * @name sg_malloc_neuron_device
635
+ * @brief To malloc device memory according to a tensor shape
636
+ * (each neuron is 32 bits)
637
+ * @ingroup bmlib_runtime
638
+ *
639
+ * @param [in] handle The device handle
640
+ * @param [out] pmem The result device memory descriptor
641
+ * @param [in] n, c, h, w The shape of the input tensor
642
+ * @retval BM_SUCCESS Succeeds.
643
+ * Other code Fails.
644
+ */
645
+ DECL_EXPORT bm_status_t sg_malloc_neuron_device(bm_handle_t handle, sg_device_mem_t *pmem,
646
+ unsigned long long n, unsigned long long c,
647
+ unsigned long long h, unsigned long long w);
648
+
649
+ /**
650
+ * @name bm_malloc_device_dword
651
+ * @brief To malloc device memory in size of dword (32 bits)
652
+ * @ingroup bmlib_runtime
653
+ *
654
+ * @param [in] handle The device handle
655
+ * @param [out] pmem The result device memory descriptor
656
+ * @param [in] count The number of dwords(32bits) to allocate
657
+ * @retval BM_SUCCESS Succeeds.
658
+ * Other code Fails.
659
+ */
660
+ DECL_EXPORT bm_status_t bm_malloc_device_dword(bm_handle_t handle, bm_device_mem_t *pmem,
661
+ int count);
662
+
663
+ /**
664
+ * @name sg_malloc_device_dword
665
+ * @brief To malloc device memory in size of dword (32 bits)
666
+ * @ingroup bmlib_runtime
667
+ *
668
+ * @param [in] handle The device handle
669
+ * @param [out] pmem The result device memory descriptor
670
+ * @param [in] count The number of dwords(32bits) to allocate
671
+ * @retval BM_SUCCESS Succeeds.
672
+ * Other code Fails.
673
+ */
674
+ DECL_EXPORT bm_status_t sg_malloc_device_dword(bm_handle_t handle, sg_device_mem_t *pmem,
675
+ unsigned long long count);
676
+
677
+ /**
678
+ * @name bm_malloc_device_byte
679
+ * @brief To malloc device memory in size of byte
680
+ * @ingroup bmlib_runtime
681
+ *
682
+ * @param [in] handle The device handle
683
+ * @param [out] pmem The result device memory descriptor
684
+ * @param [in] size The number of bytes to allocate
685
+ * @retval BM_SUCCESS Succeeds.
686
+ * Other code Fails.
687
+ */
688
+ DECL_EXPORT bm_status_t bm_malloc_device_byte(bm_handle_t handle, bm_device_mem_t *pmem,
689
+ unsigned int size);
690
+
691
+ /**
692
+ * @name sg_malloc_device_byte
693
+ * @brief To malloc device memory in size of byte
694
+ * @ingroup bmlib_runtime
695
+ *
696
+ * @param [in] handle The device handle
697
+ * @param [out] pmem The result device memory descriptor
698
+ * @param [in] size The number of bytes to allocate
699
+ * @retval BM_SUCCESS Succeeds.
700
+ * Other code Fails.
701
+ */
702
+ DECL_EXPORT bm_status_t sg_malloc_device_byte(bm_handle_t handle, sg_device_mem_t *pmem,
703
+ unsigned long long size);
704
+
705
+ /**
706
+ * @name bm_malloc_device_byte_heap
707
+ * @brief To malloc device memory in size of byte within the specified heap
708
+ * @ingroup bmlib_runtime
709
+ *
710
+ * @param [in] handle The device handle
711
+ * @param [out] pmem The result device memory descriptor
712
+ * @param [in] heap_id The heap (0/1/2) to allocate from
713
+ * @param [in] size The number of bytes to allocate
714
+ * @retval BM_SUCCESS Succeeds.
715
+ * Other code Fails.
716
+ */
717
+ DECL_EXPORT bm_status_t bm_malloc_device_byte_heap(bm_handle_t handle, bm_device_mem_t *pmem,
718
+ int heap_id, unsigned int size);
719
+
720
+ /**
721
+ * @name sg_malloc_device_byte_heap
722
+ * @brief To malloc device memory in size of byte within the specified heap
723
+ * @ingroup bmlib_runtime
724
+ *
725
+ * @param [in] handle The device handle
726
+ * @param [out] pmem The result device memory descriptor
727
+ * @param [in] heap_id The heap where to allocate 0/1/2
728
+ * @param [in] size The number of bytes to allocate
729
+ * @retval BM_SUCCESS Succeeds.
730
+ * Other code Fails.
731
+ */
732
+ DECL_EXPORT bm_status_t sg_malloc_device_byte_heap(bm_handle_t handle, sg_device_mem_t *pmem,
733
+ int heap_id, unsigned long long size);
734
+
735
+ /**
736
+ * @name bm_malloc_device_byte_heap_mask
737
+ * @brief To malloc device memory in units of bytes within the specified heaps
738
+ * @ingroup bmlib_runtime
739
+ *
740
+ * @param [in] handle The device handle
741
+ * @param [out] pmem The result device memory descriptor
742
+ * @param [in] heap_id_mask The mask of heaps to allocate from. Each bit indicates one heap
743
+ * @param [in] size The number of bytes to allocate
744
+ * @retval BM_SUCCESS Succeeds.
745
+ * Other code Fails.
746
+ */
747
+ DECL_EXPORT bm_status_t bm_malloc_device_byte_heap_mask(bm_handle_t handle, bm_device_mem_t *pmem,
748
+ int heap_id_mask, unsigned int size);
749
+
750
+ /**
751
+ * @name sg_malloc_device_byte_heap_mask
752
+ * @brief To malloc device memory in units of bytes within the specified heaps
753
+ * @ingroup bmlib_runtime
754
+ *
755
+ * @param [in] handle The device handle
756
+ * @param [out] pmem The result device memory descriptor
757
+ * @param [in] heap_id_mask The mask of heaps to allocate from. Each bit indicates one heap
758
+ * @param [in] size The number of bytes to allocate
759
+ * @retval BM_SUCCESS Succeeds.
760
+ * Other code Fails.
761
+ */
762
+ DECL_EXPORT bm_status_t sg_malloc_device_byte_heap_mask(bm_handle_t handle, sg_device_mem_t *pmem,
763
+ int heap_id_mask, unsigned long long size);
764
+
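To illustrate the heap-mask variants, the sketch below (illustrative, not from the header) restricts an allocation to heaps 0 and 1 by setting the two low bits of heap_id_mask; a valid handle is assumed to exist already.

// Bit i of heap_id_mask selects heap i, so 0x3 allows heaps 0 and 1 only.
bm_device_mem_t mem;
bm_status_t ret = bm_malloc_device_byte_heap_mask(handle, &mem, 0x3, 1 << 20 /* 1 MiB */);
if (ret == BM_SUCCESS)
    bm_free_device(handle, mem);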
765
+ /**
766
+ * @name bm_free_device
767
+ * @brief To free device memory
768
+ * @ingroup bmlib_runtime
769
+ *
770
+ * @param [in] handle The device handle
771
+ * @param [in] mem The device memory descriptor to free
772
+ */
773
+ DECL_EXPORT void bm_free_device(bm_handle_t handle, bm_device_mem_t mem);
774
+
775
+ /**
776
+ * @name sg_free_device
777
+ * @brief To free device memory
778
+ * @ingroup bmlib_runtime
779
+ *
780
+ * @param [in] handle The device handle
781
+ * @param [in] mem The device memory descriptor to free
782
+ */
783
+ DECL_EXPORT void sg_free_device(bm_handle_t handle, sg_device_mem_t mem);
784
+
785
+ /**
786
+ * @name bm_gmem_arm_reserved_request
787
+ * @brief To obtain the address of global memory reserved for arm926
788
+ * @param [in] handle The device handle
789
+ *
790
+ * @retval unsigned long long The absolute address of gmem reserved for arm926
791
+ */
792
+ DECL_EXPORT unsigned long long bm_gmem_arm_reserved_request(bm_handle_t handle);
793
+
794
+ /**
795
+ * @name bm_gmem_arm_reserved_release
796
+ * @brief To release the global memory reserved for arm926
797
+ * @ingroup bmlib_runtime
798
+ *
799
+ * @param [in] handle The device handle
800
+ */
801
+ DECL_EXPORT void bm_gmem_arm_reserved_release(bm_handle_t handle);
802
+
803
+ /*******************memory copy functions *************************************/
804
+ /**
805
+ * @name bm_memcpy_s2d
806
+ * @brief To copy data from system memory to device memory
807
+ * @ingroup bmlib_runtime
808
+ *
809
+ * @param [in] handle The device handle
810
+ * @param [in] dst The destination memory (device memory descriptor )
811
+ * @param [in] src The source memory (system memory, a void* pointer)
812
+ *
813
+ * @retval BM_SUCCESS Succeeds.
814
+ * Other code Fails.
815
+ */
816
+ DECL_EXPORT bm_status_t bm_memcpy_s2d(bm_handle_t handle, bm_device_mem_t dst, void *src);
817
+
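A usage sketch for the host-to-device copy above (illustrative; the handle and error handling are assumed elsewhere):

// Upload a host float buffer into freshly allocated device memory.
#include <vector>
std::vector<float> host(1024, 1.0f);
bm_device_mem_t dev;
if (bm_malloc_device_byte(handle, &dev, host.size() * sizeof(float)) == BM_SUCCESS) {
    bm_memcpy_s2d(handle, dev, host.data());   // blocking copy, system -> device
    // ... launch device work that consumes `dev` ...
    bm_free_device(handle, dev);
}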
818
+ /**
819
+ * @name bm_memcpy_p2p
820
+ * @brief To copy data from one chip to another chip
821
+ * @ingroup bmlib_runtime
822
+ *
823
+ * @param [in] handle_src The source device handle
824
+ * @param [in] src The source memory (device memory descriptor )
825
+ * @param [in] handle_dst The destination device handle
826
+ * @param [in] dst The destination memory (device memory descriptor )
827
+ *
828
+ * @retval BM_SUCCESS Succeeds.
829
+ * Other code Fails.
830
+ */
831
+ DECL_EXPORT bm_status_t bm_memcpy_p2p(bm_handle_t handle_src, bm_device_mem_t src, bm_handle_t handle_dst,bm_device_mem_t dst);
832
+
833
+ /**
834
+ * @name sg_memcpy_s2d
835
+ * @brief To copy data from system memory to device memory
836
+ * @ingroup bmlib_runtime
837
+ *
838
+ * @param [in] handle The device handle
839
+ * @param [in] dst The destination memory (device memory descriptor )
840
+ * @param [in] src The source memory (system memory, a void* pointer)
841
+ *
842
+ * @retval BM_SUCCESS Succeeds.
843
+ * Other code Fails.
844
+ */
845
+ DECL_EXPORT bm_status_t sg_memcpy_s2d(bm_handle_t handle, sg_device_mem_t dst, void *src);
846
+
847
+ /**
848
+ * @name bm_memcpy_s2d_partial_offset
849
+ * @brief To copy specified bytes of data from system memory to device memory
850
+ * with an offset in device memory address.
851
+ * @ingroup bmlib_runtime
852
+ *
853
+ * @param [in] handle The device handle
854
+ * @param [in] dst The destination memory (device memory descriptor)
855
+ * @param [in] src The source memory (system memory, a void* pointer)
856
+ * @param [in] size The size of data to copy (in bytes)
857
+ * @param [in] offset The offset of the device memory address
858
+ *
859
+ * @retval BM_SUCCESS Succeeds.
860
+ * Other code Fails.
861
+ */
862
+ DECL_EXPORT bm_status_t bm_memcpy_s2d_partial_offset(bm_handle_t handle,
863
+ bm_device_mem_t dst, void *src,
864
+ unsigned int size,
865
+ unsigned int offset);
866
+
867
+ /**
868
+ * @name sg_memcpy_s2d_partial_offset
869
+ * @brief To copy specified bytes of data from system memory to device memory
870
+ * with an offset in device memory address.
871
+ * @ingroup bmlib_runtime
872
+ *
873
+ * @param [in] handle The device handle
874
+ * @param [in] dst The destination memory (device memory descriptor)
875
+ * @param [in] src The source memory (system memory, a void* pointer)
876
+ * @param [in] size The size of data to copy (in bytes)
877
+ * @param [in] offset The offset of the device memory address
878
+ *
879
+ * @retval BM_SUCCESS Succeeds.
880
+ * Other code Fails.
881
+ */
882
+ DECL_EXPORT bm_status_t sg_memcpy_s2d_partial_offset(bm_handle_t handle,
883
+ sg_device_mem_t dst, void *src,
884
+ unsigned long long size,
885
+ unsigned long long offset);
886
+
887
+ /**
888
+ * @name bm_memcpy_s2d_partial
889
+ * @brief To copy specified bytes of data from system memory to device memory
890
+ * @ingroup bmlib_runtime
891
+ *
892
+ * @param [in] handle The device handle
893
+ * @param [in] dst The destination memory (device memory descriptor)
894
+ * @param [in] src The source memory (system memory, a void* pointer)
895
+ * @param [in] size The size of data to copy (in bytes)
896
+ *
897
+ * @retval BM_SUCCESS Succeeds.
898
+ * Other code Fails.
899
+ */
900
+ DECL_EXPORT bm_status_t bm_memcpy_s2d_partial(bm_handle_t handle, bm_device_mem_t dst,
901
+ void *src, unsigned int size);
902
+
903
+ /**
904
+ * @name sg_memcpy_s2d_partial
905
+ * @brief To copy specified bytes of data from system memory to device memory
906
+ * @ingroup bmlib_runtime
907
+ *
908
+ * @param [in] handle The device handle
909
+ * @param [in] dst The destination memory (device memory descriptor)
910
+ * @param [in] src The source memory (system memory, a void* pointer)
911
+ * @param [in] size The size of data to copy (in bytes)
912
+ *
913
+ * @retval BM_SUCCESS Succeeds.
914
+ * Other code Fails.
915
+ */
916
+ DECL_EXPORT bm_status_t sg_memcpy_s2d_partial(bm_handle_t handle, sg_device_mem_t dst,
917
+ void *src, unsigned long long size);
918
+
919
+ /**
920
+ * @name bm_memcpy_d2s
921
+ * @brief To copy data from device memory to system memory
922
+ * @ingroup bmlib_runtime
923
+ *
924
+ * @param [in] handle The device handle
925
+ * @param [in] dst The destination memory (system memory, a void* pointer)
926
+ * @param [in] src The source memory (device memory descriptor)
927
+ *
928
+ * @retval BM_SUCCESS Succeeds.
929
+ * Other code Fails.
930
+ */
931
+ DECL_EXPORT bm_status_t bm_memcpy_d2s(bm_handle_t handle, void *dst, bm_device_mem_t src);
932
+
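Combining the two copy directions gives a simple round-trip check; the snippet below is an illustrative sketch that assumes an existing handle.

// Write a byte pattern to the device and read it back to verify the path.
#include <vector>
#include <cstdio>
std::vector<unsigned char> src(256, 0xAB), dst(256, 0x00);
bm_device_mem_t buf;
if (bm_malloc_device_byte(handle, &buf, src.size()) == BM_SUCCESS) {
    bm_memcpy_s2d(handle, buf, src.data());
    bm_memcpy_d2s(handle, dst.data(), buf);
    printf("round trip %s\n", src == dst ? "ok" : "mismatch");
    bm_free_device(handle, buf);
}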
933
+ /**
934
+ * @name sg_memcpy_d2s
935
+ * @brief To copy data from device memory to system memory
936
+ * @ingroup bmlib_runtime
937
+ *
938
+ * @param [in] handle The device handle
939
+ * @param [in] dst The destination memory (system memory, a void* pointer)
940
+ * @param [in] src The source memory (device memory descriptor)
941
+ *
942
+ * @retval BM_SUCCESS Succeeds.
943
+ * Other code Fails.
944
+ */
945
+ DECL_EXPORT bm_status_t sg_memcpy_d2s(bm_handle_t handle, void *dst, sg_device_mem_t src);
946
+
947
+ /**
948
+ * @name bm_memcpy_d2s_partial_offset
949
+ * @brief To copy specified bytes of data from device memory to system memory
950
+ * with an offset in device memory address.
951
+ * @ingroup bmlib_runtime
952
+ *
953
+ * @param [in] handle The device handle
954
+ * @param [in] dst The destination memory (system memory, a void* pointer)
955
+ * @param [in] src The source memory (device memory descriptor)
956
+ * @param [in] size The size of data to copy (in bytes)
957
+ * @param [in] offset The offset of the device memory address
958
+ *
959
+ * @retval BM_SUCCESS Succeeds.
960
+ * Other code Fails.
961
+ */
962
+ DECL_EXPORT bm_status_t bm_memcpy_d2s_partial_offset(bm_handle_t handle, void *dst,
963
+ bm_device_mem_t src, unsigned int size,
964
+ unsigned int offset);
965
+
966
+ /**
967
+ * @name sg_memcpy_d2s_partial_offset
968
+ * @brief To copy specified bytes of data from device memory to system memory
969
+ * with an offset in device memory address.
970
+ * @ingroup bmlib_runtime
971
+ *
972
+ * @param [in] handle The device handle
973
+ * @param [in] dst The destination memory (system memory, a void* pointer)
974
+ * @param [in] src The source memory (device memory descriptor)
975
+ * @param [in] size The size of data to copy (in bytes)
976
+ * @param [in] offset The offset of the device memory address
977
+ *
978
+ * @retval BM_SUCCESS Succeeds.
979
+ * Other code Fails.
980
+ */
981
+ DECL_EXPORT bm_status_t sg_memcpy_d2s_partial_offset(bm_handle_t handle, void *dst,
982
+ sg_device_mem_t src, unsigned long long size,
983
+ unsigned long long offset);
984
+
985
+ /**
986
+ * @name bm_memcpy_d2s_partial
987
+ * @brief To copy specified bytes of data from device memory to system memory
988
+ * @ingroup bmlib_runtime
989
+ *
990
+ * @param [in] handle The device handle
991
+ * @param [in] dst The destination memory (system memory, a void* pointer)
992
+ * @param [in] src The source memory (device memory descriptor)
993
+ * @param [in] size The size of data to copy (in bytes)
994
+ *
995
+ * @retval BM_SUCCESS Data transfer succeeds.
996
+ * Other code Data transfer fails.
997
+ */
998
+ DECL_EXPORT bm_status_t bm_memcpy_d2s_partial(bm_handle_t handle, void *dst,
999
+ bm_device_mem_t src, unsigned int size);
1000
+
1001
+ /**
1002
+ * @name sg_memcpy_d2s_partial
1003
+ * @brief To copy specified bytes of data from device memory to system memory
1004
+ * @ingroup bmlib_runtime
1005
+ *
1006
+ * @param [in] handle The device handle
1007
+ * @param [in] dst The destination memory (system memory, a void* pointer)
1008
+ * @param [in] src The source memory (device memory descriptor)
1009
+ * @param [in] size The size of data to copy (in bytes)
1010
+ *
1011
+ * @retval BM_SUCCESS Data transfer succeeds.
1012
+ * Other code Data transfer fails.
1013
+ */
1014
+ DECL_EXPORT bm_status_t sg_memcpy_d2s_partial(bm_handle_t handle, void *dst,
1015
+ sg_device_mem_t src, unsigned long long size);
1016
+
1017
+ /**
1018
+ * @name bm_memcpy_d2d
1019
+ * @brief To copy specified dwords of data from one piece of device memory
1020
+ * to another piece of device memory within one device. Both source
1021
+ * and destination offsets can be specified.
1022
+ * @ingroup bmlib_runtime
1023
+ *
1024
+ * @param [in] handle The device handle
1025
+ * @param [in] dst The destination device memory
1026
+ * @param [in] dst_offset The offset of destination device memory address
1027
+ * @param [in] src The source device memory
1028
+ * @param [in] src_offset The offset of source device memory address
1029
+ * @param [in] len Length of data to copy (in DWORD 4 bytes)
1030
+ *
1031
+ * @retval BM_SUCCESS Succeeds.
1032
+ * Other code Fails.
1033
+ */
1034
+ DECL_EXPORT bm_status_t bm_memcpy_d2d(bm_handle_t handle, bm_device_mem_t dst,
1035
+ int dst_offset, bm_device_mem_t src, int src_offset,
1036
+ int len);
1037
+
1038
+ /**
1039
+ * @name bm_memcpy_d2d_with_core
1040
+ * @brief To copy specified dwords of data from one piece of device memory
1041
+ * to another piece of device memory within one device. Both source
1042
+ * and destination offsets can be specified.
1043
+ * @ingroup bmlib_runtime
1044
+ *
1045
+ * @param [in] handle The device handle
1046
+ * @param [in] dst The destination device memory
1047
+ * @param [in] dst_offset The offset of destination device memory address
1048
+ * @param [in] src The source device memory
1049
+ * @param [in] src_offset The offset of source device memory address
1050
+ * @param [in] len Length of data to copy (in DWORD 4 bytes)
1051
+ * @param [in] core_id The core id to copy
1052
+ *
1053
+ * @retval BM_SUCCESS Succeeds.
1054
+ * Other code Fails.
1055
+ */
1056
+ DECL_EXPORT bm_status_t bm_memcpy_d2d_with_core(bm_handle_t handle, bm_device_mem_t dst,
1057
+ int dst_offset, bm_device_mem_t src, int src_offset,
1058
+ int len, int core_id);
1059
+
1060
+ /**
1061
+ * @name bm_memcpy_d2d_byte
1062
+ * @brief To copy specified bytes of data from one piece of device memory
1063
+ * to another piece of device memory within one device. Both source
1064
+ * and destination offsets can be specified.
1065
+ * @ingroup bmlib_runtime
1066
+ *
1067
+ * @param [in] handle The device handle
1068
+ * @param [in] dst The destination device memory
1069
+ * @param [in] dst_offset The offset of destination device memory address (in bytes)
1070
+ * @param [in] src The source device memory
1071
+ * @param [in] src_offset The offset of source device memory address (in bytes)
1072
+ * @param [in] size Size of data to copy (in bytes)
1073
+ *
1074
+ * @retval BM_SUCCESS Succeeds.
1075
+ * Other code Fails.
1076
+ */
1077
+ DECL_EXPORT bm_status_t bm_memcpy_d2d_byte(bm_handle_t handle, bm_device_mem_t dst,
1078
+ size_t dst_offset, bm_device_mem_t src,
1079
+ size_t src_offset, size_t size);
1080
+
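For the byte-granular device-to-device copy, a small sketch (same assumptions as above: valid handle, no error handling):

// Copy the first 512 bytes of buffer `a` into buffer `b`, both resident on the device.
bm_device_mem_t a, b;
bm_malloc_device_byte(handle, &a, 1024);
bm_malloc_device_byte(handle, &b, 1024);
bm_memcpy_d2d_byte(handle, b, 0 /* dst_offset */, a, 0 /* src_offset */, 512 /* bytes */);
bm_free_device(handle, a);
bm_free_device(handle, b);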
1081
+ /**
1082
+ * @name bm_memcpy_d2d_byte_with_core
1083
+ * @brief To copy specified bytes of data from one piece of device memory
1084
+ * to another piece of device memory within one device. Both source
1085
+ * and destination offsets can be specified.
1086
+ * @ingroup bmlib_runtime
1087
+ *
1088
+ * @param [in] handle The device handle
1089
+ * @param [in] dst The destination device memory
1090
+ * @param [in] dst_offset The offset of destination device memory address (in bytes)
1091
+ * @param [in] src The source device memory
1092
+ * @param [in] src_offset The offset of source device memory address (in bytes)
1093
+ * @param [in] size Size of data to copy (in bytes)
1094
+ * @param [in] core_id The core id to copy
1095
+ *
1096
+ * @retval BM_SUCCESS Succeeds.
1097
+ * Other code Fails.
1098
+ */
1099
+ DECL_EXPORT bm_status_t bm_memcpy_d2d_byte_with_core(bm_handle_t handle, bm_device_mem_t dst,
1100
+ size_t dst_offset, bm_device_mem_t src,
1101
+ size_t src_offset, size_t size, int core_id);
1102
+
1103
+ /**
1104
+ * @name bm_memcpy_d2d_stride
1105
+ * @brief To copy specified data from one piece of device memory
1106
+ * to another piece of device memory within one device. Both source
1107
+ * and destination offsets can be specified.
1108
+ * @ingroup bmlib_runtime
1109
+ *
1110
+ * @param [in] handle The device handle
1111
+ * @param [in] dst The destination device memory
1112
+ * @param [in] dst_stride The data stride of destination data
1113
+ * @param [in] src The source device memory
1114
+ * @param [in] src_stride The data stride of source data
1115
+ * @param [in] count Count of data to copy
1116
+ * @param [in] format_size Data format byte size, such as sizeof(uint8_t), sizeof(float), etc.
1117
+ * format_size only supports 1/2/4.
1118
+ *
1119
+ * dst_stride MUST be 1, EXCEPT: dst_stride == 4 && src_stride == 1 && format_size ==1
1120
+ *
1121
+ * @retval BM_SUCCESS Succeeds.
1122
+ * Other code Fails.
1123
+ */
1124
+ DECL_EXPORT bm_status_t bm_memcpy_d2d_stride(bm_handle_t handle,
1125
+ bm_device_mem_t dst,
1126
+ int dst_stride,
1127
+ bm_device_mem_t src,
1128
+ int src_stride,
1129
+ int count,
1130
+ int format_size);
1131
+
1132
+ /**
1133
+ * @name bm_memcpy_d2d_stride_with_core
1134
+ * @brief To copy specified data from one piece of device memory
1135
+ * to another piece of device memory within one device. Both source
1136
+ * and destination offsets can be specified.
1137
+ * @ingroup bmlib_runtime
1138
+ *
1139
+ * @param [in] handle The device handle
1140
+ * @param [in] dst The destination device memory
1141
+ * @param [in] dst_stride The data stride of destination data
1142
+ * @param [in] src The source device memory
1143
+ * @param [in] src_stride The data stride of source data
1144
+ * @param [in] count Count of data to copy
1145
+ * @param [in] format_size Data format byte size, such as sizeof(uint8_t), sizeof(float), etc.
1146
+ * format_size only supports 1/2/4.
1147
+ * @param [in] core_id The core id to copy.
1148
+ *
1149
+ * dst_stride MUST be 1, EXCEPT: dst_stride == 4 && src_stride == 1 && format_size ==1
1150
+ *
1151
+ * @retval BM_SUCCESS Succeeds.
1152
+ * Other code Fails.
1153
+ */
1154
+ DECL_EXPORT bm_status_t bm_memcpy_d2d_stride_with_core(bm_handle_t handle,
1155
+ bm_device_mem_t dst,
1156
+ int dst_stride,
1157
+ bm_device_mem_t src,
1158
+ int src_stride,
1159
+ int count,
1160
+ int format_size,
1161
+ int core_id);
1162
+
1163
+ /**
1164
+ * @name bm_memcpy_c2c
1165
+ * @brief To copy data from one chip to another chip.
1166
+ * (Used in multi-chip card scenario)
1167
+ * @ingroup bmlib_runtime
1168
+ *
1169
+ * @param [in] src_handle The source device handle
1170
+ * @param [in] dst_handle The destination device handle
1171
+ * @param [in] src The source device memory descriptor
1172
+ * @param [in] dst The destination device memory descriptor
1173
+ * @param [in] force_dst_cdma Whether to use the CDMA engine of the destination device
1174
+ * @retval BM_SUCCESS Succeeds.
1175
+ * Other code Fails.
1176
+ */
1177
+ DECL_EXPORT bm_status_t bm_memcpy_c2c(bm_handle_t src_handle, bm_handle_t dst_handle,
1178
+ bm_device_mem_t src, bm_device_mem_t dst,
1179
+ bool force_dst_cdma);
1180
+
1181
+ /**
1182
+ * @name bm_memset_device
1183
+ * @brief To fill in specified device memory with the given value
1184
+ * @ingroup bmlib_runtime
1185
+ *
1186
+ * @param [in] handle The device handle
1187
+ * @param [in] value The value used to fill. (int type)
1188
+ * @param [in] mem The device memory which will be filled in
1189
+ * @retval BM_SUCCESS Succeeds.
1190
+ * Other code Fails.
1191
+ */
1192
+ DECL_EXPORT bm_status_t bm_memset_device(bm_handle_t handle, const int value,
1193
+ bm_device_mem_t mem);
1194
+
1195
+ /**
1196
+ * @name bm_memset_device_ext
1197
+ * @brief To fill in specified device memory with the given value and mode
1198
+ * @ingroup bmlib_runtime
1199
+ *
1200
+ * @param [in] handle The device handle
1201
+ * @param [in] value The pointer of value used to fill
1202
+ * @param [in] mode The valid bytes of *value
1203
+ * @param [in] mem The device memory which will be filled in
1204
+ * @retval BM_SUCCESS Succeeds.
1205
+ * Other code Fails.
1206
+ */
1207
+ DECL_EXPORT bm_status_t bm_memset_device_ext(bm_handle_t handle, void* value, int mode,
1208
+ bm_device_mem_t mem);
1209
+
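The two memset variants above differ only in how the fill value is passed; a minimal sketch (illustrative, handle assumed):

// Clear a buffer with the int-valued variant, then repeat a 2-byte pattern with _ext.
bm_device_mem_t mem;
bm_malloc_device_byte(handle, &mem, 4096);
bm_memset_device(handle, 0, mem);                 // fill with the int value 0
unsigned short pattern = 0x1234;
bm_memset_device_ext(handle, &pattern, 2, mem);   // mode = 2: two valid bytes in *value
bm_free_device(handle, mem);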
1210
+ /**
1211
+ * @name bm_mem_convert_system_to_device_neuron
1212
+ * @brief To malloc a piece of device memory according to the shape of
1213
+ * neuron(in DWORD 4 bytes); copy neuron from system memory to
1214
+ * device memory if need_copy is true.
1215
+ * @ingroup bmlib_runtime
1216
+ *
1217
+ * @param [in] handle The device handle
1218
+ * @param [in] dev_mem The device memory descriptor
1219
+ * @param [in] sys_mem The system memory descriptor
1220
+ * @param [in] need_copy If copy from system to device is needed
1221
+ * @param [in] n,c,h,w Neuron shape size
1222
+ *
1223
+ * @retval BM_SUCCESS Succeeds.
1224
+ * Other code Fails.
1225
+ */
1226
+ DECL_EXPORT bm_status_t bm_mem_convert_system_to_device_neuron(bm_handle_t handle,
1227
+ struct bm_mem_desc *dev_mem,
1228
+ struct bm_mem_desc sys_mem,
1229
+ bool need_copy, int n, int c,
1230
+ int h, int w);
1231
+
1232
+ /**
1233
+ * @name bm_mem_convert_system_to_device_neuron_byte
1234
+ * @brief To malloc a piece of device memory according to the shape of
1235
+ * neuron(in bytes); copy neuron from system memory to
1236
+ * device memory if need_copy is true.
1237
+ * @ingroup bmlib_runtime
1238
+ *
1239
+ * @param [in] handle The device handle
1240
+ * @param [in] dev_mem The device memory descriptor
1241
+ * @param [in] sys_mem The system memory descriptor
1242
+ * @param [in] need_copy If copy from system to device is needed
1243
+ * @param [in] n,c,h,w Neuron shape size
1244
+ *
1245
+ * @retval BM_SUCCESS Succeeds.
1246
+ * Other code Fails.
1247
+ */
1248
+ DECL_EXPORT bm_status_t bm_mem_convert_system_to_device_neuron_byte(
1249
+ bm_handle_t handle, struct bm_mem_desc *dev_mem, struct bm_mem_desc sys_mem,
1250
+ bool need_copy, int n, int c, int h, int w);
1251
+
1252
+ /**
1253
+ * @name bm_mem_convert_system_to_device_coeff
1254
+ * @brief To malloc a piece of device memory according to the size of
1255
+ * coefficient (in DWORD 4 bytes); copy coefficient from system
1256
+ * memory to device memory if need_copy is true.
1257
+ * @ingroup bmlib_runtime
1258
+ *
1259
+ * @param [in] handle The device handle
1260
+ * @param [in] dev_mem The device memory descriptor
1261
+ * @param [in] sys_mem The system memory descriptor
1262
+ * @param [in] need_copy If copy from system to device is needed
1263
+ * @param [in] coeff_count Coefficient size
1264
+ *
1265
+ * @retval BM_SUCCESS Succeeds.
1266
+ * Other code Fails.
1267
+ */
1268
+ DECL_EXPORT bm_status_t bm_mem_convert_system_to_device_coeff(bm_handle_t handle,
1269
+ struct bm_mem_desc *dev_mem,
1270
+ struct bm_mem_desc sys_mem,
1271
+ bool need_copy,
1272
+ int coeff_count);
1273
+ /**
1274
+ * @name bm_mem_convert_system_to_device_coeff_byte
1275
+ * @brief To malloc a piece of device memory according to the size of
1276
+ * coefficient (in bytes); copy coefficient from system
1277
+ * memory to device memory if need_copy is true.
1278
+ * @ingroup bmlib_runtime
1279
+ *
1280
+ * @param [in] handle The device handle
1281
+ * @param [in] dev_mem The device memory descriptor
1282
+ * @param [in] sys_mem The system memory descriptor
1283
+ * @param [in] need_copy If copy from system to device is needed
1284
+ * @param [in] coeff_count Coefficient size
1285
+ *
1286
+ * @retval BM_SUCCESS Succeeds.
1287
+ * Other code Fails.
1288
+ */
1289
+ DECL_EXPORT bm_status_t bm_mem_convert_system_to_device_coeff_byte(
1290
+ bm_handle_t handle, struct bm_mem_desc *dev_mem, struct bm_mem_desc sys_mem,
1291
+ bool need_copy, int coeff_count);
1292
+
1293
+ /*******************memory map functions *************************************/
1294
+ /**
1295
+ * @name bm_mem_mmap_device_mem
1296
+ * @brief To map a piece of device memory to user space with cache enabled.
1297
+ * (only valid in SoC mode; Not supported in PCIE mode).
1298
+ * @ingroup bmlib_runtime
1299
+ *
1300
+ * @param [in] handle The device handle
1301
+ * @param [in] dev_mem The device memory to map
1302
+ * @param [out] vmem The virtual address of the mapped device memory
1303
+ *
1304
+ * @retval BM_SUCCESS Succeeds.
1305
+ * Other code Fails.
1306
+ */
1307
+ DECL_EXPORT bm_status_t bm_mem_mmap_device_mem(bm_handle_t handle, bm_device_mem_t *dmem,
1308
+                                                unsigned long long *vmem);
1309
1310
+
1311
+ /**
1312
+ * @name sg_mem_mmap_device_mem
1313
+ * @brief To map a piece of device memory to user space with cache enabled.
1314
+ * (only valid in SoC mode; Not supported in PCIE mode).
1315
+ * @ingroup bmlib_runtime
1316
+ *
1317
+ * @param [in] handle The device handle
1318
+ * @param [in] dev_mem The device memory to map
1319
+ * @param [out] vmem The virtual address of the mapped device memory
1320
+ *
1321
+ * @retval BM_SUCCESS Succeeds.
1322
+ * Other code Fails.
1323
+ */
1324
+ DECL_EXPORT bm_status_t sg_mem_mmap_device_mem(bm_handle_t handle, sg_device_mem_t *dmem,
1325
+ unsigned long long *vmem);
1326
+
1327
+ /*******************memory map functions *************************************/
1328
+ /**
1329
+ * @name bm_mem_mmap_device_mem_no_cache
1330
+ * @brief To map a piece of device memory to user space with cache disabled.
1331
+ * (only valid in SoC mode; Not supported in PCIE mode).
1332
+ * @ingroup bmlib_runtime
1333
+ *
1334
+ * @param [in] handle The device handle
1335
+ * @param [in] dev_mem The device memory to map
1336
+ * @param [out] vmem The virtual address of the mapped device memory
1337
+ *
1338
+ * @retval BM_SUCCESS Succeeds.
1339
+ * Other code Fails.
1340
+ */
1341
+ DECL_EXPORT bm_status_t bm_mem_mmap_device_mem_no_cache(bm_handle_t handle, bm_device_mem_t *dmem,
1342
+                                                          unsigned long long *vmem);
1343
1344
+
1345
+ /**
1346
+ * @name sg_mem_mmap_device_mem_no_cache
1347
+ * @brief To map a piece of device memory to user space with cache disabled.
1348
+ * (only valid in SoC mode; Not supported in PCIE mode).
1349
+ * @ingroup bmlib_runtime
1350
+ *
1351
+ * @param [in] handle The device handle
1352
+ * @param [in] dev_mem The device memory to map
1353
+ * @param [out] vmem The virtual address of the mapped device memory
1354
+ *
1355
+ * @retval BM_SUCCESS Succeeds.
1356
+ * Other code Fails.
1357
+ */
1358
+ DECL_EXPORT bm_status_t sg_mem_mmap_device_mem_no_cache(bm_handle_t handle, sg_device_mem_t *dmem,
1359
+ unsigned long long *vmem);
1360
+
1361
+ /**
1362
+ * @name bm_mem_vir_to_phy
1363
+ * @brief To get the device memory address corresponding to a mapped virtual address.
1364
+ * (only valid in SoC mode; Not supported in PCIE mode).
1365
+ * @ingroup bmlib_runtime
1366
+ *
1367
+ * @param [in] handle The device handle
1368
+ * @param [in] vmem The virtual address of the mapped device memory
1369
+ * @param [out] device_mem The device memory address
1370
+ *
1371
+ * @retval BM_SUCCESS Succeeds.
1372
+ * Other code Fails.
1373
+ */
1374
+ DECL_EXPORT bm_status_t bm_mem_vir_to_phy(bm_handle_t handle, unsigned long long vmem,
1375
+ unsigned long long *device_mem);
1376
+ /**
1377
+ * @name bm_mem_invalidate_device_mem
1378
+ * @brief To invalidate a piece of mapped device memory to maintain
1379
+ * cache coherence
1380
+ * (only valid in SoC mode; Not supported in PCIE mode).
1381
+ * @ingroup bmlib_runtime
1382
+ *
1383
+ * @param [in] handle The device handle
1384
+ * @param [in] dmem The device memory to invalidate
1385
+ *
1386
+ * @retval BM_SUCCESS Succeeds.
1387
+ * Other code Fails.
1388
+ */
1389
+
1390
+ DECL_EXPORT bm_status_t bm_mem_invalidate_device_mem(bm_handle_t handle,
1391
+ bm_device_mem_t *dmem);
1392
+
1393
+ /**
1394
+ * @name sg_mem_invalidate_device_mem
1395
+ * @brief To invalidate a piece of mapped device memory to maintain
1396
+ * cache coherence
1397
+ * (only valid in SoC mode; Not supported in PCIE mode).
1398
+ * @ingroup bmlib_runtime
1399
+ *
1400
+ * @param [in] handle The device handle
1401
+ * @param [in] dmem The device memory to invalidate
1402
+ *
1403
+ * @retval BM_SUCCESS Succeeds.
1404
+ * Other code Fails.
1405
+ */
1406
+
1407
+ DECL_EXPORT bm_status_t sg_mem_invalidate_device_mem(bm_handle_t handle,
1408
+ sg_device_mem_t *dmem);
1409
+
1410
+ /**
1411
+ * @name bm_mem_invalidate_partial_device_mem
1412
+ * @brief To invalidate part of mapped device memory to maintain
1413
+ * cache coherence
1414
+ * (only valid in SoC mode; Not supported in PCIE mode).
1415
+ * @ingroup bmlib_runtime
1416
+ *
1417
+ * @param [in] handle The device handle
1418
+ * @param [in] dmem The device memory to invalidate
1419
+ * @param [in] offset The offset of device memory address
1420
+ * @param [in] len The length of memory to invalidate in bytes
1421
+ *
1422
+ * @retval BM_SUCCESS Succeeds.
1423
+ * Other code Fails.
1424
+ */
1425
+ DECL_EXPORT bm_status_t bm_mem_invalidate_partial_device_mem(bm_handle_t handle,
1426
+ bm_device_mem_t *dmem,
1427
+ unsigned int offset,
1428
+ unsigned int len);
1429
+
1430
+ /**
1431
+ * @name sg_mem_invalidate_partial_device_mem
1432
+ * @brief To invalidate part of mapped device memory to maintain
1433
+ * cache coherence
1434
+ * (only valid in SoC mode; Not supported in PCIE mode).
1435
+ * @ingroup bmlib_runtime
1436
+ *
1437
+ * @param [in] handle The device handle
1438
+ * @param [in] dmem The device memory to invalidate
1439
+ * @param [in] offset The offset of device memory address
1440
+ * @param [in] len The length of memory to invalidate in bytes
1441
+ *
1442
+ * @retval BM_SUCCESS Succeeds.
1443
+ * Other code Fails.
1444
+ */
1445
+ DECL_EXPORT bm_status_t sg_mem_invalidate_partial_device_mem(bm_handle_t handle,
1446
+ sg_device_mem_t *dmem,
1447
+ unsigned long long offset,
1448
+ unsigned long long len);
1449
+
1450
+ /**
1451
+ * @name bm_mem_flush_device_mem
1452
+ * @brief To flush a piece of mapped device memory to maintain
1453
+ * cache coherence
1454
+ * (only valid in SoC mode; Not supported in PCIE mode).
1455
+ * @ingroup bmlib_runtime
1456
+ *
1457
+ * @param [in] handle The device handle
1458
+ * @param [in] dmem The device memory to flush
1459
+ *
1460
+ * @retval BM_SUCCESS Succeeds.
1461
+ * Other code Fails.
1462
+ */
1463
+ DECL_EXPORT bm_status_t bm_mem_flush_device_mem(bm_handle_t handle, bm_device_mem_t *dmem);
1464
+
1465
+ /**
1466
+ * @name sg_mem_flush_device_mem
1467
+ * @brief To flush a piece of mapped device memory to maintain
1468
+ * cache coherence
1469
+ * (only valid in SoC mode; Not supported in PCIE mode).
1470
+ * @ingroup bmlib_runtime
1471
+ *
1472
+ * @param [in] handle The device handle
1473
+ * @param [in] dmem The device memory to flush
1474
+ *
1475
+ * @retval BM_SUCCESS Succeeds.
1476
+ * Other code Fails.
1477
+ */
1478
+ DECL_EXPORT bm_status_t sg_mem_flush_device_mem(bm_handle_t handle, sg_device_mem_t *dmem);
1479
+
1480
+ /**
1481
+ * @name bm_mem_flush_partial_device_mem
1482
+ * @brief To flush part of mapped device memory to maintain
1483
+ * cache coherence
1484
+ * (only valid in SoC mode; Not supported in PCIE mode).
1485
+ * @ingroup bmlib_runtime
1486
+ *
1487
+ * @param [in] handle The device handle
1488
+ * @param [in] dmem The device memory to flush
1489
+ * @param [in] offset The offset of device memory address
1490
+ * @param [in] len The length of memory to flush in bytes
1491
+ *
1492
+ * @retval BM_SUCCESS Succeeds.
1493
+ * Other code Fails.
1494
+ */
1495
+ DECL_EXPORT bm_status_t bm_mem_flush_partial_device_mem(bm_handle_t handle,
1496
+ bm_device_mem_t *dmem,
1497
+ unsigned int offset,
1498
+ unsigned int len);
1499
+
1500
+ /**
1501
+ * @name sg_mem_flush_partial_device_mem
1502
+ * @brief To flush part of mapped device memory to maintain
1503
+ * cache coherence
1504
+ * (only valid in SoC mode; Not supported in PCIE mode).
1505
+ * @ingroup bmlib_runtime
1506
+ *
1507
+ * @param [in] handle The device handle
1508
+ * @param [in] dmem The device memory to flush
1509
+ * @param [in] offset The offset of device memory address
1510
+ * @param [in] len The length of memory to flush in bytes
1511
+ *
1512
+ * @retval BM_SUCCESS Succeeds.
1513
+ * Other code Fails.
1514
+ */
1515
+ DECL_EXPORT bm_status_t sg_mem_flush_partial_device_mem(bm_handle_t handle,
1516
+ sg_device_mem_t *dmem,
1517
+ unsigned long long offset,
1518
+ unsigned long long len);
1519
+
1520
+ /**
1521
+ * @name bm_mem_unmap_device_mem
1522
+ * @brief To unmap a piece of mapped device memory
1523
+ * (only valid in SoC mode; Not supported in PCIE mode).
1524
+ * @ingroup bmlib_runtime
1525
+ *
1526
+ * @param [in] handle The device handle
1527
+ * @param [in] vmem The virtual address of the mapped device memory
1528
+ * @param [in] size The size of unmapped memory
1529
+ *
1530
+ * @retval BM_SUCCESS Succeeds.
1531
+ * Other code Fails.
1532
+ */
1533
+ DECL_EXPORT bm_status_t bm_mem_unmap_device_mem(bm_handle_t handle, void *vmem, int size);
1534
+
1535
+ /**
1536
+ * @name sg_mem_unmap_device_mem
1537
+ * @brief To unmap a piece of mapped device memory
1538
+ * (only valid in SoC mode; Not supported in PCIE mode).
1539
+ * @ingroup bmlib_runtime
1540
+ *
1541
+ * @param [in] handle The device handle
1542
+ * @param [in] vmem The virtual address of the mapped device memory
1543
+ * @param [in] size The size of unmapped memory
1544
+ *
1545
+ * @retval BM_SUCCESS Succeeds.
1546
+ * Other code Fails.
1547
+ */
1548
+ DECL_EXPORT bm_status_t sg_mem_unmap_device_mem(bm_handle_t handle, void *vmem, unsigned long long size);
1549
+
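Putting the SoC-only mapping, flush, invalidate and unmap calls together, a cached-mapping sketch (illustrative; valid only in SoC mode, handle assumed):

// Map device memory into user space, write through the CPU cache, flush so the
// device sees the data, invalidate before reading device results, then unmap.
#include <cstring>
bm_device_mem_t mem;
bm_malloc_device_byte(handle, &mem, 4096);
unsigned long long vaddr = 0;
if (bm_mem_mmap_device_mem(handle, &mem, &vaddr) == BM_SUCCESS) {
    memset((void *)vaddr, 0x5A, 4096);            // CPU-side write via the mapping
    bm_mem_flush_device_mem(handle, &mem);        // push dirty cache lines to the device
    // ... device produces results into `mem` ...
    bm_mem_invalidate_device_mem(handle, &mem);   // drop stale cache lines before reading
    bm_mem_unmap_device_mem(handle, (void *)vaddr, 4096);
}
bm_free_device(handle, mem);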
1550
+ /*******************api(kernel) functions *************************************/
1551
+ /**
1552
+ * @name bm_flush
1553
+ * @brief To synchronize APIs of the current thread. The thread will block
1554
+ * until all the outstanding APIs of the current thread are finished.
1555
+ * @ingroup bmlib_runtime
1556
+ *
1557
+ * @param [in] handle The device handle
1558
+ */
1559
+ DECL_EXPORT void bm_flush(bm_handle_t handle);
1560
+
1561
+ /**
1562
+ * @name bm_device_sync
1563
+ * @brief To synchronize APIs of the device. The thread will block
1564
+ * until all the outstanding APIs of the device are finished.
1565
+ * @ingroup bmlib_runtime
1566
+ *
1567
+ * @param [in] handle The device handle
1568
+ * @retval BM_SUCCESS Succeeds.
1569
+ * Other code Fails.
1570
+ */
1571
+ DECL_EXPORT bm_status_t bm_device_sync(bm_handle_t handle);
1572
+
1573
+ /**
1574
+ * @name bm_handle_sync
1575
+ * @brief To synchronize APIs of the handle. The thread will block
1576
+ * until all the outstanding APIs of the handle are finished.
1577
+ * @ingroup bmlib_runtime
1578
+ *
1579
+ * @param [in] handle The device handle
1580
+ * @retval BM_SUCCESS Succeeds.
1581
+ * Other code Fails.
1582
+ */
1583
+ DECL_EXPORT bm_status_t bm_handle_sync(bm_handle_t handle);
1584
+
1585
+ /**
1586
+ * @name bm_handle_sync_from_core
1587
+ * @brief To synchronize APIs of the handle. The thread will block
1588
+ * until all the outstanding APIs of the handle are finished.
1589
+ * @ingroup bmlib_runtime
1590
+ *
1591
+ * @param [in] handle The device handle
1592
+ * @param [in] core_id The core id
1593
+ * @retval BM_SUCCESS Succeeds.
1594
+ * Other code Fails.
1595
+ */
1596
+ DECL_EXPORT bm_status_t bm_handle_sync_from_core(bm_handle_t handle, int core_id);
1597
+
1598
+ /**
1599
+ * @name bm_thread_sync
1600
+ * @brief To synchronize APIs of the current thread. The thread will block
1601
+ * until all the outstanding APIs of the current thread are finished.
1602
+ * @ingroup bmlib_runtime
1603
+ *
1604
+ * @param [in] handle The device handle
1605
+ * @retval BM_SUCCESS Succeeds.
1606
+ * Other code Fails.
1607
+ */
1608
+ DECL_EXPORT bm_status_t bm_thread_sync(bm_handle_t handle);
1609
+
1610
+ /**
1611
+ * @name bm_thread_sync_from_core
1612
+ * @brief To synchronize APIs of the current thread. The thread will block
1613
+ * until all the outstanding APIs of the current thread are finished.
1614
+ * @ingroup bmlib_runtime
1615
+ *
1616
+ * @param [in] handle The device handle
1617
+ * @param [in] core_id The core id
1618
+ * @retval BM_SUCCESS Succeeds.
1619
+ * Other code Fails.
1620
+ */
1621
+ DECL_EXPORT bm_status_t bm_thread_sync_from_core(bm_handle_t handle, int core_id);
1622
+
1623
+ /*******************trace and profile related functions **********************/
1624
+ typedef struct bm_profile {
1625
+ #ifdef __linux__
1626
+ unsigned long cdma_in_time;
1627
+ unsigned long cdma_in_counter;
1628
+ unsigned long cdma_out_time;
1629
+ unsigned long cdma_out_counter;
1630
+ unsigned long tpu_process_time;
1631
+ unsigned long tpu1_process_time;
1632
+ unsigned long sent_api_counter;
1633
+ unsigned long completed_api_counter;
1634
+ #else
1635
+ unsigned long long cdma_in_time;
1636
+ unsigned long long cdma_in_counter;
1637
+ unsigned long long cdma_out_time;
1638
+ unsigned long long cdma_out_counter;
1639
+ unsigned long long tpu_process_time;
1640
+ unsigned long long tpu1_process_time;
1641
+ unsigned long long sent_api_counter;
1642
+ unsigned long long completed_api_counter;
1643
+ #endif
1644
+ } bm_profile_t;
1645
+ /**
1646
+ * @name bm_get_profile
1647
+ * @brief To get the profile data at the moment
1648
+ * @ingroup bmlib_runtime
1649
+ *
1650
+ * @param [in] handle The device handle
1651
+ * @param [out] profile The result profile data
1652
+ * @retval BM_SUCCESS Succeeds.
1653
+ * Other code Fails.
1654
+ */
1655
+ DECL_EXPORT bm_status_t bm_get_profile(bm_handle_t handle, bm_profile_t *profile);
1656
+
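A sketch of how the profile counters might be sampled around a workload; the field units are whatever the driver reports, so the snippet only prints raw deltas (illustrative only, handle assumed):

// Diff the counters before and after the work to see how much was spent on the TPU.
#include <cstdio>
bm_profile_t before, after;
bm_get_profile(handle, &before);
// ... submit work, then wait with bm_thread_sync(handle) ...
bm_get_profile(handle, &after);
printf("tpu_process_time delta: %llu, completed APIs: %llu\n",
       (unsigned long long)(after.tpu_process_time - before.tpu_process_time),
       (unsigned long long)(after.completed_api_counter - before.completed_api_counter));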
1657
+ typedef struct bootloader_version{
1658
+ char *bl1_version;
1659
+ char *bl2_version;
1660
+ char *bl31_version;
1661
+ char *uboot_version;
1662
+ } boot_loader_version;
1663
+
1664
+ /**
1665
+ * @name bm_get_boot_loader_version
1666
+ * @brief To get the boot_loader_version
1667
+ * @ingroup bmlib_runtime
1668
+ *
1669
+ * @param [in] handle The device handle
1670
+ * @param [out] version The result version data
1671
+ * @retval BM_SUCCESS Succeeds.
1672
+ * Other code Fails.
1673
+ */
1674
+ DECL_EXPORT bm_status_t bm_get_boot_loader_version(bm_handle_t handle, boot_loader_version *version);
1675
+
1676
+ /**
1677
+ * @name bm_get_vpu_instant_usage
1678
+ * @brief To get vpu usage
1679
+ * @ingroup bmlib_runtime
1680
+ *
1681
+ * @param [in] handle The device handle
1682
+ * @param [out] vpu_usage The result vpu usage
1683
+ * @retval BM_SUCCESS Succeeds.
1684
+ * Other code Fails.
1685
+ */
1686
+ DECL_EXPORT bm_status_t bm_get_vpu_instant_usage(bm_handle_t handle, int *vpu_usage);
1687
+
1688
+ /**
1689
+ * @name bm_get_jpu_core_usage
1690
+ * @brief To get the jpu usage
1691
+ * @ingroup bmlib_runtime
1692
+ *
1693
+ * @param [in] handle The device handle
1694
+ * @param [out] jpu_usage The result jpu usage
1695
+ * @retval BM_SUCCESS Succeeds.
1696
+ * Other code Fails.
1697
+ */
1698
+ DECL_EXPORT bm_status_t bm_get_jpu_core_usage(bm_handle_t handle, int *jpu_usage);
1699
+
1700
+ /**
1701
+ * @name bm_get_vpp_instant_usage
1702
+ * @brief To get the vpp usage
1703
+ * @ingroup bmlib_runtime
1704
+ *
1705
+ * @param [in] handle The device handle
1706
+ * @param [out] vpp_usage The result vpp usage
1707
+ * @retval BM_SUCCESS Succeeds.
1708
+ * Other code Fails.
1709
+ */
1710
+ DECL_EXPORT bm_status_t bm_get_vpp_instant_usage(bm_handle_t handle, int *vpp_usage);
1711
+ /**
1712
+ * @name bm_get_last_api_process_time_us
1713
+ * @brief This function is deprecated.
1714
+ */
1715
+ #ifdef __linux__
1716
+ DECL_EXPORT bm_status_t bm_get_last_api_process_time_us(bm_handle_t handle,
1717
+ unsigned long *time_us);
1718
+ #else
1719
+ DECL_EXPORT bm_status_t bm_get_last_api_process_time_us(bm_handle_t handle,
1720
+ unsigned long long *time_us);
1721
+ #endif
1722
+ /*******************tpu clock and module reset related functions *************/
1723
+
1724
+ /**
1725
+ * @name bm_set_clk_tpu_freq
1726
+ * @brief To set the clock frequency of TPU (only valid in PCIE mode).
1727
+ * @ingroup bmlib_runtime
1728
+ *
1729
+ * @param [in] handle The device handle
1730
+ * @param [in] freq The TPU target frequency
1731
+ * @retval BM_SUCCESS Succeeds.
1732
+ * Other code Fails.
1733
+ */
1734
+ DECL_EXPORT bm_status_t bm_set_clk_tpu_freq(bm_handle_t handle, int freq);
1735
+
1736
+ /**
1737
+ * @name bm_get_clk_tpu_freq
1738
+ * @brief To get the clock frequency of TPU
1739
+ * @ingroup bmlib_runtime
1740
+ *
1741
+ * @param [in] handle The device handle
1742
+ * @param [out] freq The current TPU frequency
1743
+ * @retval BM_SUCCESS Succeeds.
1744
+ * Other code Fails.
1745
+ */
1746
+ DECL_EXPORT bm_status_t bm_get_clk_tpu_freq(bm_handle_t handle, int *freq);
1747
+
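A short sketch for the clock interface; per the comments above, setting the clock is PCIe-mode only, and the target value here is purely a placeholder (illustrative, handle assumed):

// Read the current TPU clock and, in PCIe mode, optionally request a different one.
#include <cstdio>
int freq = 0;
if (bm_get_clk_tpu_freq(handle, &freq) == BM_SUCCESS) {
    printf("current TPU clock: %d\n", freq);
    // bm_set_clk_tpu_freq(handle, target_freq);   // PCIe mode only; target_freq is a placeholder
}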
1748
+ /*******************misc functions ********************************************/
1749
+ struct bm_misc_info {
1750
+ int pcie_soc_mode; /*0---pcie; 1---soc*/
1751
+ int ddr_ecc_enable; /*0---disable; 1---enable*/
1752
+ long long ddr0a_size;
1753
+ long long ddr0b_size;
1754
+ long long ddr1_size;
1755
+ long long ddr2_size;
1756
+ unsigned int chipid;
1757
+ #define BM1682_CHIPID_BIT_MASK (0X1 << 0)
1758
+ #define BM1684_CHIPID_BIT_MASK (0X1 << 1)
1759
+ #define BM1686_CHIPID_BIT_MASK (0X1 << 2)
1760
+ #ifdef __linux__
1761
+ unsigned long chipid_bit_mask;
1762
+ #else
1763
+ unsigned long long chipid_bit_mask;
1764
+ #endif
1765
+ unsigned int driver_version;
1766
+ int domain_bdf;
1767
+ int board_version; /*hardware board version [23:16]-mcu sw version, [15:8]-board type, [7:0]-hw version*/
1768
+ int a53_enable;
1769
+ int dyn_enable;
1770
+ };
1771
+
1772
+ /**
1773
+ * @name bm_get_misc_info
1774
+ * @brief To get miscellaneous information of the device
1775
+ * @ingroup bmlib_runtime
1776
+ *
1777
+ * @param [in] handle The device handle
1778
+ * @param [out] pmisc_info The fetched misc info
1779
+ * @retval BM_SUCCESS Succeeds.
1780
+ * Other code Fails.
1781
+ */
1782
+ DECL_EXPORT bm_status_t bm_get_misc_info(bm_handle_t handle, struct bm_misc_info *pmisc_info);
1783
+
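A sketch that reads a few bm_misc_info fields to distinguish PCIe from SoC mode and report the chip id (illustrative; handle assumed):

#include <cstdio>
struct bm_misc_info info;
if (bm_get_misc_info(handle, &info) == BM_SUCCESS) {
    printf("mode: %s, chipid: 0x%x, ddr0a size: %lld bytes\n",
           info.pcie_soc_mode == 0 ? "PCIe" : "SoC",
           info.chipid, info.ddr0a_size);
}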
1784
+ /**
1785
+ * @name bm_get_chipid
1786
+ * @brief To get the chipid of the device. (0x1682 / 0x1684 / 0x168?)
1787
+ * @ingroup bmlib_runtime
1788
+ *
1789
+ * @param [in] handle The device handle
1790
+ * @param [out] p_chipid The chip id of the device
1791
+ * @retval BM_SUCCESS Succeeds.
1792
+ * Other code Fails.
1793
+ */
1794
+ DECL_EXPORT bm_status_t bm_get_chipid(bm_handle_t handle, unsigned int *p_chipid);
1795
+
1796
+ #define BMLIB_LOG_QUIET -8
1797
+ #define BMLIB_LOG_PANIC 0
1798
+ #define BMLIB_LOG_FATAL 8
1799
+ #define BMLIB_LOG_ERROR 16
1800
+ #define BMLIB_LOG_WARNING 24
1801
+ #define BMLIB_LOG_INFO 32
1802
+ #define BMLIB_LOG_VERBOSE 40
1803
+ #define BMLIB_LOG_DEBUG 48
1804
+ #define BMLIB_LOG_TRACE 56
1805
+
1806
+ /**
1807
+ * @name bmlib_log_get_level
1808
+ * @brief To get the bmlib log level
1809
+ * @ingroup bmlib_log
1810
+ *
1811
+ * @param void
1812
+ * @retval The level of bmlib log level
1813
+ */
1814
+ DECL_EXPORT int bmlib_log_get_level(void);
1815
+
1816
+ /**
1817
+ * @name bmlib_log_set_level
1818
+ * @brief To set the bmlib log level
1819
+ * @ingroup bmlib_log
1820
+ *
1821
+ * @param [in] level The level of bmlib log level
1822
+ * @retval void
1823
+ */
1824
+ DECL_EXPORT void bmlib_log_set_level(int level);
1825
+
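The log-level macros above form an increasing verbosity scale; a minimal sketch that caps bmlib output at warnings:

// Lower the log level to WARNING unless it is already quieter.
if (bmlib_log_get_level() > BMLIB_LOG_WARNING)
    bmlib_log_set_level(BMLIB_LOG_WARNING);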
1826
+ /**
1827
+ * @name bmlib_log_set_callback
1828
+ * @brief To set callback to get bmlib log
1829
+ * @ingroup bmlib_log
1830
+ *
1831
+ * @param [in] callback The callback function to get bmlib log
1832
+ * @retval void
1833
+ */
1834
+ DECL_EXPORT void bmlib_log_set_callback(void (*callback)(const char*, int, const char*, va_list args));
1835
+
1836
+ /**
1837
+ * @name bm_set_debug_mode
1838
+ * @brief To set the debug mode of the TPU firmware log
1839
+ * @ingroup bmlib_log
1840
+ *
1841
+ * @param [in] handle The device handle
1842
+ * @param [in] mode The debug mode of the fw log: 0 to disable, 1 to enable
1843
+ * @retval void
1844
+ */
1845
+ DECL_EXPORT void bm_set_debug_mode(bm_handle_t handle, int mode);
1846
+
1847
+ /**
1848
+ * @name bmlib_api_dbg_callback
1849
+ * @brief To set debug callback to get firmware log
1850
+ * @ingroup bmlib_log
1851
+ *
1852
+ * @param [in] bmlib_api_dbg_callback callback to get firmware log
1853
+ * @retval void
1854
+ */
1855
+ typedef void (*bmlib_api_dbg_callback)(int, int, int, const char*);
1856
+ // api, result, duration, log; the third int is reserved for api duration in the future
1857
+ DECL_EXPORT void bmlib_set_api_dbg_callback(bmlib_api_dbg_callback callback);
1858
+
1859
+ /**
1860
+ * @name bmcpu_get_cpu_status
1861
+ * @brief Get bmcpu status
1862
+ * @ingroup bmlib_log
1863
+ *
1864
+ * @param [in] handle The device handle
1865
+ * @retval BMCPU_RUNNING bmcpu is running.
1866
+ * Other code Fails.
1867
+ */
1868
+ DECL_EXPORT bm_cpu_status_t bmcpu_get_cpu_status(bm_handle_t handle);
1869
+
1870
+ /**
1871
+ * @name bmcpu_start_cpu
1872
+ * @brief Start cpu in pcie mode
1873
+ * @ingroup bmlib_log
1874
+ *
1875
+ * @param [in] handle The device handle
1876
+ * @param [in] boot_file Fip file
1877
+ * @param [in] core_file Itb file
1878
+ * @retval BM_SUCCESS Succeeds.
1879
+ * Other code Fails.
1880
+ */
1881
+ DECL_EXPORT bm_status_t bmcpu_start_cpu(bm_handle_t handle, char *boot_file, char *core_file);
1882
+
1883
+ /**
1884
+ * @name bmcpu_open_process
1885
+ * @brief Open a process to do some work
1886
+ * @ingroup bmlib_log
1887
+ *
1888
+ * @param [in] handle The device handle
1889
+ * @param [in] flags Process flags
1890
+ * @param [in] timeout Timeout value in millisecond, -1 means default value of this device
1891
+ * @retval >= 0 process handle
1892
+ * < 0 Other code Fails.
1893
+ */
1894
+ DECL_EXPORT int bmcpu_open_process(bm_handle_t handle, unsigned int flags, int timeout);
1895
+
1896
+ /**
1897
+ * @name bmcpu_load_library
1898
+ * @brief Load a shared library (.so) into a specific process
1899
+ * @ingroup bmlib_log
1900
+ *
1901
+ * @param [in] handle The device handle
1902
+ * @param [in] process_handle Process handle
1903
+ * @param [in] library_file Library file path
1904
+ * @param [in] timeout Timeout value in millisecond, -1 means default value of this device
1905
+ * @retval BM_SUCCESS Succeeds.
1906
+ * Other code Fails.
1907
+ */
1908
+ DECL_EXPORT bm_status_t bmcpu_load_library(bm_handle_t handle, int process_handle, char *library_file, int timeout);
1909
+
1910
+ /**
1911
+ * @name bmcpu_unload_library
1912
+ * @brief Unload a shared library (.so) from a specific process
1913
+ * @ingroup bmlib_log
1914
+ *
1915
+ * @param [in] handle The device handle
1916
+ * @param [in] process_handle Process handle
1917
+ * @param [in] library_file Library file path
1918
+ * @param [in] timeout Timeout value in millisecond, -1 means default value of this device
1919
+ * @retval BM_SUCCESS Succeeds.
1920
+ * Other code Fails.
1921
+ */
1922
+ DECL_EXPORT bm_status_t bmcpu_unload_library(bm_handle_t handle, int process_handle, char *library_file, int timeout);
1923
+
1924
+ /**
1925
+ * @name bmcpu_exec_function
1926
+ * @brief Execute specific function in specific process
1927
+ * @ingroup bmlib_log
1928
+ *
1929
+ * @param [in] handle The device handle
1930
+ * @param [in] process_handle Process handle
1931
+ * @param [in] function_name Function name
1932
+ * @param [in] function_param Function parameters
1933
+ * @param [in] param_size Parameters size in bytes
1934
+ * @param [in] timeout Timeout value in millisecond, -1 means default value of this device
1935
+ * @retval 0 success.
1936
+ * >0 code fails from bmlib
1937
+ * <0 code fails from function
1938
+ */
1939
+ DECL_EXPORT int bmcpu_exec_function(bm_handle_t handle,
1940
+ int process_handle,
1941
+ char *function_name,
1942
+ void *function_param,
1943
+ unsigned int param_size,
1944
+ int timeout);
1945
+
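The bmcpu calls above are typically used together: start the CPU first in PCIe mode (bmcpu_start_cpu), open a process, load a library, run a function, and close the process. In the sketch below the library path and function name are hypothetical placeholders; the handle is assumed to exist already.

// Illustrative bmcpu flow; timeout -1 means the device default everywhere.
#include <cstdio>
int proc = bmcpu_open_process(handle, 0 /* flags */, -1);
if (proc >= 0) {
    bmcpu_load_library(handle, proc, (char *)"libuserfunc.so", -1);        // hypothetical .so
    int param = 42;
    int rc = bmcpu_exec_function(handle, proc, (char *)"user_fn",          // hypothetical symbol
                                 &param, sizeof(param), -1);
    printf("user_fn returned %d\n", rc);
    bmcpu_close_process(handle, proc, -1);
}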
1946
+ #define BMCPU_EXEC_OPT_NO_FLUSH_CACHE 1
1947
+ /**
1948
+ * @name bmcpu_exec_function_ext
1949
+ * @brief Execute specific function in specific process
1950
+ * @ingroup bmlib_log
1951
+ *
1952
+ * @param [in] handle The device handle
1953
+ * @param [in] process_handle Process handle
1954
+ * @param [in] function_name Function name
1955
+ * @param [in] function_param Function parameters
1956
+ * @param [in] param_size Parameters size in bytes
1957
+ * @param [in] opt exec options
1958
+ * @param [in] timeout Timeout value in millisecond, -1 means default value of this device
1959
+ * @retval 0 success.
1960
+ * >0 code fails from bmlib
1961
+ * <0 code fails from function
1962
+ */
1963
+ DECL_EXPORT int bmcpu_exec_function_ext(bm_handle_t handle,
1964
+ int process_handle,
1965
+ char *function_name,
1966
+ void *function_param,
1967
+ unsigned int param_size,
1968
+ unsigned int opt,
1969
+ int timeout);
1970
+
1971
+ /**
1972
+ * @name bmcpu_exec_function_async
1973
+ * @brief Execute a specific function in a specific process asynchronously;
1974
+ * the user should call bmcpu_query_exec_function_result to query the result
1975
+ * @ingroup bmlib_log
1976
+ *
1977
+ * @param [in] handle The device handle
1978
+ * @param [in] process_handle Process handle
1979
+ * @param [in] function_name Function name
1980
+ * @param [in] function_param Function param
1981
+ * @param [in] param_size Param size in bytes
1982
+ * @retval BM_SUCCESS Succeeds.
1983
+ * Other code Fails.
1984
+ */
1985
+ DECL_EXPORT bm_status_t bmcpu_exec_function_async(bm_handle_t handle,
1986
+ int process_handle,
1987
+ char *function_name,
1988
+ void *function_param,
1989
+ unsigned int param_size,
1990
+ unsigned long long *api_handle);
1991
+
1992
+ /**
1993
+ * @name bmcpu_exec_function_async_ext
1994
+ * @brief Execute a specific function in a specific process asynchronously;
1995
+ * the user should call bmcpu_query_exec_function_result to query the result
1996
+ * @ingroup bmlib_log
1997
+ *
1998
+ * @param [in] handle The device handle
1999
+ * @param [in] process_handle Process handle
2000
+ * @param [in] function_name Function name
2001
+ * @param [in] function_param Function param
2002
+ * @param [in] param_size Param size in bytes
2003
+ * @param [in] opt exec options
2004
+ * @retval BM_SUCCESS Succeeds.
2005
+ * Other code Fails.
2006
+ */
2007
+ DECL_EXPORT bm_status_t bmcpu_exec_function_async_ext(bm_handle_t handle,
2008
+ int process_handle,
2009
+ char *function_name,
2010
+ void *function_param,
2011
+ unsigned int param_size,
2012
+ unsigned int opt,
2013
+ unsigned long long *api_handle);
2014
+
2015
+ /**
2016
+ * @name bmcpu_query_exec_function_result
2017
+ * @brief Query the result of a function launched by bmcpu_exec_function_async
2018
+ * @ingroup bmlib_log
2019
+ *
2020
+ * @param [in] handle The device handle
2021
+ * @param [in] api_handle API handle returned by bmcpu_exec_function_async
2022
+ * @param [in] timeout Timeout value in millisecond, -1 means default value of this device
2023
+ * @retval 0 success.
2024
+ * >0 code fails from bmlib
2025
+ * <0 code fails from function
2026
+ */
2027
+ DECL_EXPORT int bmcpu_query_exec_function_result(bm_handle_t handle, unsigned long long api_handle, int timeout);
2028
+
2029
+ /**
2030
+ * @name bmcpu_map_phys_addr
2031
+ * @brief Map physical address in specific process
2032
+ * @ingroup bmlib_log
2033
+ *
2034
+ * @param [in] handle The device handle
2035
+ * @param [in] process_handle Process handle
2036
+ * @param [in] phys_addr Physical address
2037
+ * @param [in] size Map size in bytes
2038
+ * @param [in] timeout Timeout value in millisecond, -1 means default value of this device
2039
+ * @retval >0 virtual address
2040
+ * 0 fails
2041
+ */
2042
+ DECL_EXPORT void *bmcpu_map_phys_addr(bm_handle_t handle, int process_handle, void *phys_addr, unsigned int size, int timeout);
2043
+
2044
+ /**
2045
+ * @name bmcpu_unmap_phys_addr
2046
+ * @brief Unmap physical address in specific process
2047
+ * @ingroup bmlib_log
2048
+ *
2049
+ * @param [in] handle The device handle
2050
+ * @param [in] process_handle Process handle
2051
+ * @param [in] phys_addr Physical address
2052
+ * @param [in] timeout Timeout value in millisecond, -1 means default value of this device
2053
+ * @retval <0 fail
2054
+ * 0 success
2055
+ */
2056
+ DECL_EXPORT bm_status_t bmcpu_unmap_phys_addr(bm_handle_t handle, int process_handle, void *phys_addr, int timeout);
2057
+
2058
+ /**
2059
+ * @name bmcpu_close_process
2060
+ * @brief Close process
2061
+ * @ingroup bmlib_log
2062
+ *
2063
+ * @param [in] handle The device handle
2064
+ * @param [in] process_handle Process handle
2065
+ * @param [in] timeout Timeout value in millisecond, -1 means default value of this device
2066
+ * @retval BM_SUCCESS Succeeds.
2067
+ * Other code Fails.
2068
+ */
2069
+ DECL_EXPORT bm_status_t bmcpu_close_process(bm_handle_t handle, int process_handle, int timeout);
2070
+
2071
+ /**
2072
+ * @name bmcpu_reset_cpu
2073
+ * @brief Reset cpu in pcie mode
2074
+ * @ingroup bmlib_log
2075
+ *
2076
+ * @param [in] handle The device handle
2077
+ * @retval BM_SUCCESS Succeeds.
2078
+ * Other code Fails.
2079
+ */
2080
+ DECL_EXPORT bm_status_t bmcpu_reset_cpu(bm_handle_t handle);
2081
+
2082
+ /**
2083
+ * @name bm_enable_perf_monitor
2084
+ * @brief enable perf monitor to get gdma and tpu performance data
2085
+ * @ingroup bmlib_perf
2086
+ *
2087
+ * @param [in] handle The device handle
2088
+ * @param [in] perf_monitor The monitor to perf
2089
+ * @retval BM_SUCCESS Succeeds.
2090
+ * Other code Fails.
2091
+ */
2092
+ DECL_EXPORT bm_status_t bm_enable_perf_monitor(bm_handle_t handle, bm_perf_monitor_t *perf_monitor);
2093
+
2094
+ /**
2095
+ * @name bm_disable_perf_monitor
2096
+ * @brief disable perf monitor to get gdma and tpu performance data
2097
+ * @ingroup bmlib_perf
2098
+ *
2099
+ * @param [in] handle The device handle
2100
+ * @param [in] perf_monitor The monitor to perf
2101
+ * @retval BM_SUCCESS Succeeds.
2102
+ * Other code Fails.
2103
+ */
2104
+ DECL_EXPORT bm_status_t bm_disable_perf_monitor(bm_handle_t handle, bm_perf_monitor_t *perf_monitor);
2105
+
2106
+ /**
2107
+ * @name bmcpu_set_log
2108
+ * @brief Set cpu log options
2109
+ * @ingroup bmlib_log
2110
+ *
2111
+ * @param [in] handle The device handle
2112
+ * @param [in] log_level 0: DEBUG 1:INFO 2:WARN 3:ERROR 4:FATAL
2113
+ * @param [in] log_to_console 1: YES 0: No
2114
+ * @param [in] timeout Timeout value in millisecond, -1 means default value of this device
2115
+ * @retval BM_SUCCESS Succeeds.
2116
+ * Other code Fails.
2117
+ */
2118
+ DECL_EXPORT bm_status_t bmcpu_set_log(bm_handle_t handle, unsigned int log_level, unsigned int log_to_console, int timeout);
2119
+
2120
+ /**
2121
+ * @name bmcpu_get_log
2122
+ * @brief Get cpu log file
2123
+ * @ingroup bmlib_log
2124
+ *
2125
+ * @param [in] handle The device handle
2126
+ * @param [in] process_handle Process handle
2127
+ * @param [in] log_file save log as file
2128
+ * @param [in] timeout Timeout value in millisecond, -1 means default value of this device
2129
+ * @retval BM_SUCCESS Succeeds.
2130
+ * Other code Fails.
2131
+ */
2132
+ DECL_EXPORT bm_status_t bmcpu_get_log(bm_handle_t handle, int process_handle, char *log_file, int timeout);
2133
+
2134
+ /**
2135
+ * @name bmcpu_sync_time
2136
+ * @brief Sync device cpu time with host
2137
+ * @ingroup bmlib_log
2138
+ *
2139
+ * @param [in] handle The device handle
2140
+ * @retval BM_SUCCESS Succeeds.
2141
+ * Other code Fails.
2142
+ */
2143
+ DECL_EXPORT bm_status_t bmcpu_sync_time(bm_handle_t handle);
2144
+
2145
+ /*******************trace and profile related functions **********************/
2146
+ struct bm_heap_stat {
2147
+ unsigned int mem_total;
2148
+ unsigned int mem_avail;
2149
+ unsigned int mem_used;
2150
+ };
2151
+
2152
+ typedef struct bm_heap_stat_byte {
2153
+ unsigned int heap_id;
2154
+ unsigned long long mem_total;
2155
+ unsigned long long mem_avail;
2156
+ unsigned long long mem_used;
2157
+ unsigned long long mem_start_addr;
2158
+ } bm_heap_stat_byte_t;
2159
+
2160
+ typedef struct bm_dev_stat {
2161
+ int mem_total;
2162
+ int mem_used;
2163
+ int tpu_util;
2164
+ int heap_num;
2165
+ struct bm_heap_stat heap_stat[4];
2166
+ } bm_dev_stat_t;
2167
+
2168
+ /**
2169
+ * @name bm_get_stat
2170
+ * @brief To get the stat data at the moment
2171
+ * @ingroup bmlib_runtime
2172
+ *
2173
+ * @param [in] handle The device handle
2174
+ * @param [out] stat The result stat data
2175
+ * @retval BM_SUCCESS Succeeds.
2176
+ * Other code Fails.
2177
+ */
2178
+ DECL_EXPORT bm_status_t bm_get_stat(bm_handle_t handle, bm_dev_stat_t *stat);
2179
+
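A sketch that dumps the device-level statistics structure defined above (illustrative; handle assumed, units are whatever the driver reports):

#include <cstdio>
bm_dev_stat_t stat;
if (bm_get_stat(handle, &stat) == BM_SUCCESS) {
    printf("mem %d/%d used, tpu util %d, %d heaps\n",
           stat.mem_used, stat.mem_total, stat.tpu_util, stat.heap_num);
}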
2180
+ /**
2181
+ * @name bm_get_gmem_heap_id
2182
+ * @brief To get the heap id of allocated global memory
2183
+ * @ingroup bmlib_runtime
2184
+ *
2185
+ * @param [in] handle The device handle
2186
+ * @param [in] pmem The allocated global memory
2187
+ * @param [out] heapid The returned heap id
2188
+ * @retval BM_SUCCESS Succeeds.
2189
+ * Other code Fails.
2190
+ */
2191
+
2192
+ DECL_EXPORT bm_status_t bm_get_gmem_heap_id(bm_handle_t handle, bm_device_mem_t *pmem, unsigned int *heapid);
2193
+
2194
+ /**
2195
+ * @name sg_get_gmem_heap_id
2196
+ * @brief To get the heap id of allocated global memory
2197
+ * @ingroup bmlib_runtime
2198
+ *
2199
+ * @param [in] handle The device handle
2200
+ * @param [in] pmem The allocated global memory
2201
+ * @param [out] heapid The returned heap id
2202
+ * @retval BM_SUCCESS Succeeds.
2203
+ * Other code Fails.
2204
+ */
2205
+
2206
+ DECL_EXPORT bm_status_t sg_get_gmem_heap_id(bm_handle_t handle, sg_device_mem_t *pmem, unsigned int *heapid);
2207
+
2208
+ /**
2209
+ * @name bm_get_gmem_total_heap_num
2210
+ * @brief To get the total heap num of global memory
2211
+ * @ingroup bmlib_runtime
2212
+ *
2213
+ * @param [in] handle The device handle
2214
+ * @param [out] heap_num The total number of heaps
2215
+ * @retval BM_SUCCESS Succeeds.
2216
+ * Other code Fails.
2217
+ */
2218
+ DECL_EXPORT bm_status_t bm_get_gmem_total_heap_num(bm_handle_t handle, unsigned int *heap_num);
2219
+
2220
+ /**
2221
+ * @name bm_get_gmem_heap_stat_byte_by_id
2222
+ * @brief To get the heap stat by heap id
2223
+ * @ingroup bmlib_runtime
2224
+ *
2225
+ * @param [in] handle The device handle
2226
+ * @param [in] heap_id The heap index to get heap status
2227
+ * @param [out] pheap_byte The result of get heap status
2228
+ * @retval BM_SUCCESS Succeeds.
2229
+ * Other code Fails.
2230
+ */
2231
+ DECL_EXPORT bm_status_t bm_get_gmem_heap_stat_byte_by_id(bm_handle_t handle, bm_heap_stat_byte_t *pheap_byte, unsigned int heap_id);
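+
+ /*
+  * Illustrative sketch (added commentary): enumerate all global-memory heaps with the
+  * two functions above. A valid handle is assumed and return codes are only checked
+  * where the result is printed.
+  *
+  *   unsigned int heap_num = 0;
+  *   bm_get_gmem_total_heap_num(handle, &heap_num);
+  *   for (unsigned int i = 0; i < heap_num; ++i) {
+  *     bm_heap_stat_byte_t hs;
+  *     if (bm_get_gmem_heap_stat_byte_by_id(handle, &hs, i) == BM_SUCCESS)
+  *       printf("heap %u: total %llu bytes, used %llu bytes\n", hs.heap_id, hs.mem_total, hs.mem_used);
+  *   }
+  */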
2232
+
2233
+ DECL_EXPORT bm_status_t bm_load_firmware(
2234
+ bm_handle_t handle,
2235
+ const char *firmware_tcm,
2236
+ const char *firmware_ddr);
2237
+
2238
+ #define bmkernel_load_firmware okkernel_load_firmware
2239
+ DECL_EXPORT bm_status_t okkernel_load_firmware(
2240
+ bm_handle_t handle,
2241
+ const char *firmware_tcm,
2242
+ const char *firmware_ddr);
2243
+
2244
+ DECL_EXPORT bm_status_t okkernel_launch_async(
2245
+ bm_handle_t handle,
2246
+ const char *func_name,
2247
+ const void *args,
2248
+ unsigned int size);
2249
+
2250
+ DECL_EXPORT bm_status_t okkernel_launch_sync(
2251
+ bm_handle_t handle,
2252
+ const char *func_name,
2253
+ const void *args,
2254
+ unsigned int size);
2255
+
2256
+ DECL_EXPORT bm_status_t tpu_kernel_launch_sync(
2257
+ bm_handle_t handle,
2258
+ const char *func_name,
2259
+ const void *args,
2260
+ unsigned int size);
2261
+
2262
+ DECL_EXPORT bm_status_t okkernel_sync(bm_handle_t handle);
2263
+
2264
+ /**
2265
+ * @name bmkernel_launch
2266
+ * @brief send api to device and launch function
2267
+ * @ingroup bmlib_runtime
2268
+ *
2269
+ * @param [in] handle The device handle
2270
+ * @param [in] args The api cmd struct pointer
2271
+ * @param [in] size The api cmd length
2272
+ * @retval BM_SUCCESS Succeeds.
2273
+ * Other code Fails.
2274
+ */
2275
+ DECL_EXPORT bm_status_t bmkernel_launch(bm_handle_t handle, const void *args,
2276
+ unsigned int size);
2277
+
2278
+ /**
2279
+ * @name bmkernel_load_lookup_table
2280
+ * @brief load lookup table to l2-sram
2281
+ * @ingroup bmlib_runtime
2282
+ *
2283
+ * @param [in] handle The device handle
2284
+ * @param [in] table The lookup table to be loaded to l2-sram
2285
+ * @param [in] size The table size
2286
+ * @retval BM_SUCCESS Succeeds.
2287
+ * Other code Fails.
2288
+ */
2289
+ DECL_EXPORT bm_status_t bmkernel_load_lookup_table(bm_handle_t handle, const void* table, unsigned int size);
2290
+
2291
+ /*******************device management api functions ********************************************/
2292
+ /**
2293
+ * @name bm_get_tpu_current
2294
+ * @brief get tpu current
2295
+ * @ingroup bmlib_runtime
2296
+ *
2297
+ * @param [in] handle The device handle
2298
+ * @param [out] tpuc(mA) The pointer for tpu current
2299
+ * @retval BM_SUCCESS Succeeds.
2300
+ * Other code Fails.
2301
+ */
2302
+ DECL_EXPORT bm_status_t bm_get_tpu_current(bm_handle_t handle, unsigned int *tpuc);
2303
+
2304
+ /**
2305
+ * @name bm_get_board_max_power
2306
+ * @brief get board support max power
2307
+ * @ingroup bmlib_runtime
2308
+ *
2309
+ * @param [in] handle The device handle
2310
+ * @param [out] maxp The pointer for maxp
2311
+ * @retval BM_SUCCESS Succeeds.
2312
+ * Other code Fails.
2313
+ */
2314
+ DECL_EXPORT bm_status_t bm_get_board_max_power(bm_handle_t handle, unsigned int *maxp);
2315
+
2316
+ /**
2317
+ * @name bm_get_board_power
2318
+ * @brief get board power
2319
+ * @ingroup bmlib_runtime
2320
+ *
2321
+ * @param [in] handle The device handle
2322
+ * @param [out] boardp The pointer for boardp
2323
+ * @retval BM_SUCCESS Succeeds.
2324
+ * Other code Fails.
2325
+ */
2326
+ DECL_EXPORT bm_status_t bm_get_board_power(bm_handle_t handle, unsigned int *boardp);
2327
+
2328
+ /**
2329
+ * @name bm_get_fan_speed
2330
+ * @brief get board fan speed
2331
+ * @ingroup bmlib_runtime
2332
+ *
2333
+ * @param [in] handle The device handle
2334
+ * @param [out] fan The pointer for fan speed
2335
+ * @retval BM_SUCCESS Succeeds.
2336
+ * Other code Fails.
2337
+ */
2338
+ DECL_EXPORT bm_status_t bm_get_fan_speed(bm_handle_t handle, unsigned int *fan);
2339
+
2340
+ /**
2341
+ * @name bm_get_ecc_correct_num
2342
+ * @brief get ecc_correct_num
2343
+ * @ingroup device management api
2344
+ *
2345
+ * @param [in] handle The device handle
2346
+ * @param [out] ecc_correct_num
2347
+ * @retval BM_SUCCESS Succeeds.
2348
+ * Other code Fails.
2349
+ */
2350
+ #ifdef __linux__
2351
+ DECL_EXPORT bm_status_t bm_get_ecc_correct_num(bm_handle_t handle, unsigned long *ecc_correct_num);
2352
+ #else
2353
+ DECL_EXPORT bm_status_t bm_get_ecc_correct_num(bm_handle_t handle, unsigned long long *ecc_correct_num);
2354
+ #endif
2355
+ /**
2356
+ * @name bm_get_12v_atx
2357
+ * @brief get atx_12v
2358
+ * @ingroup device management api
2359
+ *
2360
+ * @param [in] handle The device handle
2361
+ * @param [out] atx_12v
2362
+ * @retval BM_SUCCESS Succeeds.
2363
+ * Other code Fails.
2364
+ */
2365
+ DECL_EXPORT bm_status_t bm_get_12v_atx(bm_handle_t handle, int *atx_12v);
2366
+
2367
+ /**
2368
+ * @name bm_get_product_sn
2369
+ * @brief get SE5 sn
2370
+ * @ingroup device management api
2371
+ *
2372
+ * @param [out] product_sn
2373
+ * @retval BM_SUCCESS Succeeds.
2374
+ * Other code Fails.
2375
+ */
2376
+ DECL_EXPORT bm_status_t bm_get_product_sn(char *product_sn);
2377
+
2378
+ /**
2379
+ * @name bm_get_sn
2380
+ * @brief get sn
2381
+ * @ingroup device management api
2382
+ *
2383
+ * @param [in] handle The device handle
2384
+ * @param [out] sn
2385
+ * @retval BM_SUCCESS Succeeds.
2386
+ * Other code Fails.
2387
+ */
2388
+ DECL_EXPORT bm_status_t bm_get_sn(bm_handle_t handle, char *sn);
2389
+
2390
+ /**
2391
+ * @name bm_get_status
2392
+ * @brief get chip status
2393
+ * @ingroup device management api
2394
+ *
2395
+ * @param [in] handle The device handle
2396
+ * @param [out] status The board error status, each bit represents an error state
2397
+ * status == 0x0, board is normal; status > 0, board is abnormal;
2398
+ * bit0 == 1, tpu is hung
2399
+ * bit1 == 1, pcie link abnormal
2400
+ * bit2 == 1, board temperature is too high
2401
+ * @retval BM_SUCCESS Succeeds.
2402
+ * Other code Fails.
2403
+ */
2404
+ DECL_EXPORT bm_status_t bm_get_status(bm_handle_t handle, int *status);
2405
+
2406
+ /**
2407
+ * @name bm_get_tpu_maxclk
2408
+ * @brief get tpu_maxclk
2409
+ * @ingroup device management api
2410
+ *
2411
+ * @param [in] handle The device handle
2412
+ * @param [out] tpu_maxclk
2413
+ * @retval BM_SUCCESS Succeeds.
2414
+ * Other code Fails.
2415
+ */
2416
+ DECL_EXPORT bm_status_t bm_get_tpu_maxclk(bm_handle_t handle, unsigned int *tpu_maxclk);
2417
+
2418
+ /**
2419
+ * @name bm_get_tpu_minclk
2420
+ * @brief get tpu_minclk
2421
+ * @ingroup device management api
2422
+ *
2423
+ * @param [in] handle The device handle
2424
+ * @param [out] tpu_minclk
2425
+ * @retval BM_SUCCESS Succeeds.
2426
+ * Other code Fails.
2427
+ */
2428
+ DECL_EXPORT bm_status_t bm_get_tpu_minclk(bm_handle_t handle, unsigned int *tpu_minclk);
2429
+
2430
+ /**
2431
+ * @name bm_get_driver_version
2432
+ * @brief get driver version
2433
+ * @ingroup device management api
2434
+ *
2435
+ * @param [in] handle The device handle
2436
+ * @param [out] driver_version
2437
+ * @retval BM_SUCCESS Succeeds.
2438
+ * Other code Fails.
2439
+ */
2440
+ DECL_EXPORT bm_status_t bm_get_driver_version(bm_handle_t handle, int *driver_version);
2441
+
2442
+ /**
2443
+ * @name bm_get_board_name
2444
+ * @brief get device board name
2445
+ * @ingroup device management api
2446
+ *
2447
+ * @param [in] handle The device handle
2448
+ * @param [out] board_name
2449
+ * @retval BM_SUCCESS Succeeds.
2450
+ * Other code Fails.
2451
+ */
2452
+ DECL_EXPORT bm_status_t bm_get_board_name(bm_handle_t handle, char *name);
2453
+
2454
+ /**
2455
+ * @name bm_get_board_temp
2456
+ * @brief get board temperature
2457
+ * @ingroup device management api
2458
+ *
2459
+ * @param [in] handle The device handle
2460
+ * @param [out] board_temp
2461
+ * @retval BM_SUCCESS Succeeds.
2462
+ * Other code Fails.
2463
+ */
2464
+ DECL_EXPORT bm_status_t bm_get_board_temp(bm_handle_t handle, unsigned int *board_temp);
2465
+
2466
+ /**
2467
+ * @name bm_get_chip_temp
2468
+ * @brief get chip temperature
2469
+ * @ingroup device management api
2470
+ *
2471
+ * @param [in] handle The device handle
2472
+ * @param [out] chip_temp
2473
+ * @retval BM_SUCCESS Succeeds.
2474
+ * Other code Fails.
2475
+ */
2476
+ DECL_EXPORT bm_status_t bm_get_chip_temp(bm_handle_t handle, unsigned int *chip_temp);
2477
+
2478
+ /**
2479
+ * @name bm_get_tpu_power
2480
+ * @brief get TPU power
2481
+ * @ingroup device management api
2482
+ *
2483
+ * @param [in] handle The device handle
2484
+ * @param [out] tpu_power
2485
+ * @retval BM_SUCCESS Succeeds.
2486
+ * Other code Fails.
2487
+ */
2488
+ DECL_EXPORT bm_status_t bm_get_tpu_power(bm_handle_t handle, float *tpu_power);
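+
+ /*
+  * Illustrative monitoring sketch (added commentary) combining the health queries
+  * above: board/chip temperature, board power and TPU power. A valid handle is
+  * assumed and return codes are not checked for brevity.
+  *
+  *   unsigned int board_temp = 0, chip_temp = 0, board_power = 0;
+  *   float tpu_power = 0.0f;
+  *   bm_get_board_temp(handle, &board_temp);
+  *   bm_get_chip_temp(handle, &chip_temp);
+  *   bm_get_board_power(handle, &board_power);
+  *   bm_get_tpu_power(handle, &tpu_power);
+  *   printf("board_temp=%u chip_temp=%u board_power=%u tpu_power=%.1f\n",
+  *          board_temp, chip_temp, board_power, tpu_power);
+  */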
2489
+
2490
+ /**
2491
+ * @name bm_get_tpu_volt
2492
+ * @brief get TPU voltage
2493
+ * @ingroup device management api
2494
+ *
2495
+ * @param [in] handle The device handle
2496
+ * @param [out] tpu_volt
2497
+ * @retval BM_SUCCESS Succeeds.
2498
+ * Other code Fails.
2499
+ */
2500
+ DECL_EXPORT bm_status_t bm_get_tpu_volt(bm_handle_t handle, unsigned int *tpu_volt);
2501
+
2502
+ /**
2503
+ * @name bm_get_card_id
2504
+ * @brief get card id
2505
+ * @ingroup device management api
2506
+ *
2507
+ * @param [in] handle The device handle
2508
+ * @param [out] card_id
2509
+ * @retval BM_SUCCESS Succeeds.
2510
+ * Other code Fails.
2511
+ */
2512
+ DECL_EXPORT bm_status_t bm_get_card_id(bm_handle_t handle, unsigned int *card_id);
2513
+
2514
+ /**
2515
+ * @name bm_get_card_num
2516
+ * @brief get card number
2517
+ * @ingroup device management api
2518
+ *
2519
+ * @param [in] handle The device handle
2520
+ * @param [out] card_num
2521
+ * @retval BM_SUCCESS Succeeds.
2522
+ * Other code Fails.
2523
+ */
2524
+ DECL_EXPORT bm_status_t bm_get_card_num(unsigned int *card_num);
2525
+
2526
+ /**
2527
+ * @name bm_get_chip_num_from_card
2528
+ * @brief get chip number and start chip id from card
2529
+ * @ingroup device management api
2530
+ *
2531
+ * @param [in] card_id The card id
2532
+ * @param [out] chip_num
2533
+ * @param [out] dev_start_index
2534
+ * @retval BM_SUCCESS Succeeds.
2535
+ * Other code Fails.
2536
+ */
2537
+ DECL_EXPORT bm_status_t bm_get_chip_num_from_card(unsigned int card_id, unsigned int *chip_num, unsigned int *dev_start_index);
2538
+
2539
+ /**
2540
+ * @name bm_get_dynfreq_status
2541
+ * @brief get chip dynamic freq status
2542
+ * @ingroup device management api
2543
+ *
2544
+ * @param [in] handle The device handle
2545
+ * @param [out] dynfreq_status
2546
+ * @retval BM_SUCCESS Succeeds.
2547
+ * Other code Fails.
2548
+ */
2549
+ DECL_EXPORT bm_status_t bm_get_dynfreq_status(bm_handle_t handle, int *dynfreq_status);
2550
+
2551
+ /**
2552
+ * @name bm_change_dynfreq_status
2553
+ * @brief change(enable/disable) chip dynamic freq status
2554
+ * @ingroup device management api
2555
+ *
2556
+ * @param [in] handle The device handle
2557
+ * @param [in] new_status
2558
+ * @retval BM_SUCCESS Succeeds.
2559
+ * Other code Fails.
2560
+ */
2561
+ DECL_EXPORT bm_status_t bm_change_dynfreq_status(bm_handle_t handle, int new_status);
2562
+
2563
+ /**
2564
+ * @name bm_get_tpu_scalar_num
2565
+ * @brief To get the core number of TPU scalar
2566
+ * @ingroup bmlib_runtime
2567
+ *
2568
+ * @param [in] handle The device handle
2569
+ * @param [out] core_num The core number of TPU scalar
2570
+ * @retval BM_SUCCESS Succeeds.
2571
+ * Other code Fails.
2572
+ */
2573
+ DECL_EXPORT bm_status_t bm_get_tpu_scalar_num(bm_handle_t handle, unsigned int *core_num);
2574
+
2575
+ #define bm_get_tpu_core_num bm_get_tpu_scalar_num
2576
+
2577
+ #if defined(__cplusplus)
2578
+ }
2579
+ #endif
2580
+
2581
+ #endif /* BM_RUNTIME_H_ */
Baichuan2/src/include/bmruntime_interface.h ADDED
@@ -0,0 +1,404 @@
1
+ /*****************************************************************************
2
+ *
3
+ * Copyright (c) 2016-2026 by Sophgo Technologies Inc. All rights reserved.
4
+ *
5
+ * The material in this file is confidential and contains trade secrets
6
+ * of Sophgo Technologies Inc. This is proprietary information owned by
7
+ * Sophgo Technologies Inc. No part of this work may be disclosed,
8
+ * reproduced, copied, transmitted, or used in any way for any purpose,
9
+ * without the express written permission of Sophgo Technologies Inc.
10
+ *
11
+ *****************************************************************************/
12
+
13
+ /*****************************************************************************
14
+ * BMRuntime Interface is mainly for inference.
15
+ * Also we can use it for device computation from BMLang programming.
16
+ * Note: please use interface from bmlib_runtime.h for device memory operation.
17
+ ****************************************************************************/
18
+
19
+ #ifndef BMRUNTIME_INTERFACE_H_
20
+ #define BMRUNTIME_INTERFACE_H_
21
+
22
+ #include "bmdef.h"
23
+
24
+ #ifdef _WIN32
25
+ #define DECL_EXPORT _declspec(dllexport)
26
+ #define DECL_IMPORT _declspec(dllimport)
27
+ #else
28
+ #define DECL_EXPORT
29
+ #define DECL_IMPORT
30
+ #endif
31
+
32
+ #if defined(__cplusplus)
33
+ extern "C" {
34
+ #endif
35
+
36
+ /* --------------------------------------------------------------------------*/
37
+ /* interface for basic data type */
38
+
39
+ /* get data type byte size */
40
+ DECL_EXPORT size_t bmrt_data_type_size(bm_data_type_t dtype);
41
+
42
+ /*
43
+ dims array to bm_shape_t,
44
+ shape and dims should not be NULL, num_dims should not be larger than BM_MAX_DIMS_NUM */
45
+ DECL_EXPORT void bmrt_shape(bm_shape_t* shape, const int* dims, int num_dims);
46
+
47
+ /*
48
+ number of shape elements, shape should not be NULL and num_dims should not be larger than
49
+ BM_MAX_DIMS_NUM */
50
+ DECL_EXPORT uint64_t bmrt_shape_count(const bm_shape_t* shape);
51
+
52
+ /* compare whether two shapes are the same */
53
+ DECL_EXPORT bool bmrt_shape_is_same(const bm_shape_t* left, const bm_shape_t* right);
54
+
55
+ /*
56
+ fill a tensor with data type and shape, and st_mode = 0 as default.
57
+ tensor and p_bmrt should not be NULL, shape count should not be 0.
58
+ it will alloc device mem to tensor->device_mem, so user should bmrt_free_device(p_bmrt,
59
+ tensor->device_mem) to free it.*/
60
+ DECL_EXPORT bool bmrt_tensor(bm_tensor_t* tensor, void* p_bmrt, bm_data_type_t dtype, bm_shape_t shape);
61
+
62
+ /*
63
+ fill a tensor with data type and shape, and st_mode = 0 as default.
64
+ tensor and p_bmrt should not be NULL, shape count should not be 0.
65
+ it will alloc device mem to tensor->device_mem on devid-th device.*/
66
+ DECL_EXPORT bool bmrt_tensor_ex(bm_tensor_t* tensor, void* p_bmrt, int devid, bm_data_type_t dtype, bm_shape_t shape);
67
+
68
+ /* fill a tensor with existing device mem; tensor byte size should not be larger than device mem size */
69
+ DECL_EXPORT void bmrt_tensor_with_device(bm_tensor_t* tensor, bm_device_mem_t device_mem,
70
+ bm_data_type_t dtype, bm_shape_t shape);
71
+
72
+ /* get tensor bytes size, tensor should not be NULL */
73
+ DECL_EXPORT size_t bmrt_tensor_bytesize(const bm_tensor_t* tensor);
74
+
75
+ /* get tensor mem size allocated in device mem, tensor should not be NULL */
76
+ DECL_EXPORT size_t bmrt_tensor_device_size(const bm_tensor_t* tensor);
77
+
78
+ /* print net info for debug */
79
+ DECL_EXPORT void bmrt_print_network_info(const bm_net_info_t* net_info);
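+
+ /*
+  Illustrative sketch (added commentary, not part of the vendor header): build a shape,
+  allocate a device tensor on an existing runtime p_bmrt, report its byte size and
+  release it as the bmrt_tensor comment above requires. BM_FLOAT32 comes from bmdef.h.
+
+    int dims[4] = {1, 3, 224, 224};
+    bm_shape_t shape;
+    bmrt_shape(&shape, dims, 4);
+    bm_tensor_t t;
+    if (bmrt_tensor(&t, p_bmrt, BM_FLOAT32, shape)) {
+      printf("tensor occupies %zu bytes\n", bmrt_tensor_bytesize(&t));
+      bmrt_free_device(p_bmrt, t.device_mem);
+    }
+ */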
80
+
81
+ /* --------------------------------------------------------------------------*/
82
+ /**
83
+ * @name bmrt_create
84
+ * @brief To create the bmruntime with bm_handle.
85
+ * @ingroup bmruntime
86
+ *
87
+ * This API creates the bmruntime. It returns a void* pointer which is the pointer
88
+ * of bmruntime. Device id is set when the bm_handle is obtained.
89
+ *
90
+ * @param [in] bm_handle bm handle. It must be initialized by using bmlib.
91
+ *
92
+ * @retval void* the pointer of bmruntime
93
+ */
94
+ DECL_EXPORT void* bmrt_create(bm_handle_t bm_handle);
95
+
96
+ /* --------------------------------------------------------------------------*/
97
+ /**
98
+ * @name bmrt_create_ex
99
+ * @brief To create the bmruntime with one or more bm_handle.
100
+ * @ingroup bmruntime
101
+ *
102
+ * This API creates the bmruntime. It returns a void* pointer which is the pointer
103
+ * of bmruntime.
104
+ *
105
+ * @param [in] bm_handles bm handles. They must be initialized by using bmlib.
106
+ * @param [in] num_handles number of bm_handles.
107
+ *
108
+ * @retval void* the pointer of bmruntime
109
+ */
110
+ DECL_EXPORT void *bmrt_create_ex(bm_handle_t *bm_handles, int num_handles);
111
+
112
+ /**
113
+ * @name bmrt_destroy
114
+ * @brief To destroy the bmruntime pointer
115
+ * @ingroup bmruntime
116
+ *
117
+ * This API destroys the bmruntime.
118
+ *
119
+ * @param [in] p_bmrt Bmruntime that had been created
120
+ */
121
+ DECL_EXPORT void bmrt_destroy(void* p_bmrt);
122
+
123
+ /**
124
+ * @name bmrt_get_bm_handle
125
+ * @brief To get the BM runtime context.
126
+ * @ingroup bmruntime
127
+ *
128
+ * This API gets the BM runtime context for using BMDNN, BMCV or BMLIB.
129
+ *
130
+ * @param [in] p_bmrt Bmruntime that had been created
131
+ */
132
+ DECL_EXPORT void * bmrt_get_bm_handle(void* p_bmrt);
133
+
134
+ /**
135
+ * @name bmrt_load_bmodel
136
+ * @brief To load the bmodel which is created by BM compiler
137
+ * @ingroup bmruntime
138
+ *
139
+ * This API is to load bmodel created by BM compiler.
140
+ * After loading bmodel, we can run the inference of neuron network.
141
+ *
142
+ * @param [in] p_bmrt Bmruntime that had been created
143
+ * @param [in] bmodel_path Bmodel file directory.
144
+ *
145
+ * @retval true Load context success.
146
+ * @retval false Load context failed.
147
+ */
148
+ DECL_EXPORT bool bmrt_load_bmodel(void* p_bmrt, const char *bmodel_path);
149
+
150
+ /**
151
+ * @name bmrt_load_bmodel_data
152
+ * @brief To load the bmodel which is created by BM compiler from buffer
153
+ * @ingroup bmruntime
154
+ *
155
+ * This API is to load bmodel created by BM compiler.
156
+ * After loading bmodel, we can run the inference of neuron network.
157
+ * Unlike bmrt_load_bmodel, here the bmodel is data already in host memory.
158
+ *
159
+ * @param [in] p_bmrt Bmruntime that had been created
160
+ * @param [in] bmodel_data Bmodel data pointer to buffer
161
+ * @param [in] size Bmodel data size
162
+ *
163
+ * @retval true Load context success.
164
+ * @retval false Load context failed.
165
+ */
166
+ DECL_EXPORT bool bmrt_load_bmodel_data(void* p_bmrt, const void * bmodel_data, size_t size);
167
+
168
+ /**
169
+ * @name bmrt_show_neuron_network
170
+ * @brief To print the name of all neuron network
171
+ * @ingroup bmruntime
172
+ *
173
+ * @param [in] p_bmrt Bmruntime that had been created
174
+ */
175
+ DECL_EXPORT void bmrt_show_neuron_network(void* p_bmrt);
176
+
177
+ /**
178
+ * @name bmrt_get_network_number
179
+ * @brief To get the number of neuron network in the bmruntime
180
+ * @ingroup bmruntime
181
+ *
182
+ * @param [in] p_bmrt Bmruntime that had been created
183
+ *
184
+ * @retval int value The number of neuron networks.
185
+ */
186
+ DECL_EXPORT int bmrt_get_network_number(void* p_bmrt);
187
+
188
+ /**
189
+ * @name bmrt_get_network_names
190
+ * @brief To get the names of all neuron network in the bmruntime
191
+ * @ingroup bmruntime
192
+ *
193
+ * @param [in] p_bmrt Bmruntime that had been created
194
+ * @param [out] network_names The names of all neuron networks. It should be declared as (const char** networks_ = NULL),
195
+ * and passed as the param &networks_. After this API, the user needs to free(networks_) when it
196
+ * is no longer needed.
197
+ */
198
+ DECL_EXPORT void bmrt_get_network_names(void* p_bmrt, const char*** network_names);
199
+
200
+ /**
201
+ * @name bmrt_get_network_info
202
+ * @brief To get network info by net name
203
+ * @ingroup bmruntime
204
+ *
205
+ * @param [in] p_bmrt Bmruntime that had been created
206
+ * @param [in] net_name Network name
207
+ *
208
+ * @retval bm_net_info_t* Pointer to net info, which need not be freed by the user; returns NULL if the net name is not found.
209
+ */
210
+ DECL_EXPORT const bm_net_info_t* bmrt_get_network_info(void* p_bmrt, const char* net_name);
211
+
212
+ /**
213
+ * @name bmrt_launch_tensor
214
+ * @brief To launch the inference of the neuron network with setting input tensors
215
+ * @ingroup bmruntime
216
+ *
217
+ * This API supports the neuron network that is static-compiled or dynamic-compiled.
218
+ * After calling this API, inference on TPU is launched. And the CPU program will not
219
+ * be blocked. bm_thread_sync should be called to make sure inference finished.
220
+ * This API supports multiple inputs and is multi-thread safe.
221
+ *
222
+ * @param [in] p_bmrt Bmruntime that had been created
223
+ * @param [in] net_name The name of the neuron network
224
+ * @param [in] input_tensors Array of input tensor, defined like bm_tensor_t input_tensors[input_num].
225
+ * User should initialize each input tensor.
226
+ * @param [in] input_num Input number
227
+ * @param [out] output_tensors Array of output tensor, defined like bm_tensor_t output_tensors[output_num].
228
+ * This interface will alloc device mem to store output data. User should free each
229
+ * device mem by bm_free_device after the result data is no longer used.
230
+ * @param [in] output_num Output number
231
+ *
232
+ * @retval true Launch success.
233
+ * @retval false Launch failed.
234
+ */
235
+ DECL_EXPORT bool bmrt_launch_tensor(void* p_bmrt, const char * net_name, const bm_tensor_t input_tensors[], int input_num,
236
+ bm_tensor_t output_tensors[], int output_num);
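+
+ /*
+  Illustrative end-to-end flow for the APIs above (added commentary, not part of the
+  vendor header). bm_handle is an initialized bmlib handle; bm_memcpy_s2d, bm_memcpy_d2s,
+  bm_thread_sync and bm_free_device are assumed to come from bmlib; "net0",
+  host_input/host_output and the first-stage shape are placeholders, and most error
+  checks are omitted. Field names follow bm_net_info_t in bmdef.h.
+
+    void *p_bmrt = bmrt_create(bm_handle);
+    bmrt_load_bmodel(p_bmrt, "model.bmodel");
+    const bm_net_info_t *info = bmrt_get_network_info(p_bmrt, "net0");
+
+    bm_tensor_t in, out;
+    bmrt_tensor(&in, p_bmrt, info->input_dtypes[0], info->stages[0].input_shapes[0]);
+    bm_memcpy_s2d(bm_handle, in.device_mem, host_input);   // upload the input
+    bmrt_launch_tensor(p_bmrt, "net0", &in, 1, &out, 1);   // asynchronous launch
+    bm_thread_sync(bm_handle);                             // wait for the TPU
+    bm_memcpy_d2s(bm_handle, host_output, out.device_mem); // download the result
+
+    bm_free_device(bm_handle, in.device_mem);
+    bm_free_device(bm_handle, out.device_mem);             // output mem was allocated by the launch
+    bmrt_destroy(p_bmrt);
+ */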
237
+
238
+ /**
239
+ * @name bmrt_launch_tensor_ex
240
+ * @brief To launch the inference of the neuron network with setting input tensors
241
+ * @ingroup bmruntime
242
+ *
243
+ * This API supports the neuron network that is static-compiled or dynamic-compiled.
244
+ * After calling this API, inference on TPU is launched. And the CPU program will not
245
+ * be blocked. bm_thread_sync should be called to make sure inference finished.
246
+ * This API supports multiple inputs and is multi-thread safe.
247
+ *
248
+ * @param [in] p_bmrt Bmruntime that had been created
249
+ * @param [in] net_name The name of the neuron network
250
+ * @param [in] input_tensors Array of input tensor, defined like bm_tensor_t input_tensors[input_num],
251
+ * User should initialize each input tensor.
252
+ * @param [in] input_num Input number
253
+ * @param [out] output_tensors Array of output tensor, defined like bm_tensor_t output_tensors[output_num].
254
+ * User can set device_mem or stmode of output tensors. If user_mem is true, this interface
255
+ * will use device mem of output_tensors to store output data, and not alloc device mem;
256
+ * Or it will alloc device mem to store output. If user_stmode is true, it will use stmode in
257
+ * each output tensor; Or stmode will be BM_STORE_1N as default.
258
+ * @param [in] output_num Output number
259
+ * @param [in] user_mem whether device_mem of output tensors are set
260
+ * @param [in] user_stmode whether stmode of output tensors are set
261
+ *
262
+ * @retval true Launch success.
263
+ * @retval false Launch failed.
264
+ */
265
+ DECL_EXPORT bool bmrt_launch_tensor_ex(void* p_bmrt, const char * net_name, const bm_tensor_t input_tensors[], int input_num,
266
+ bm_tensor_t output_tensors[], int output_num, bool user_mem, bool user_stmode);
267
+
268
+ /**
269
+ * @name bmrt_launch_data
270
+ * @brief To launch the inference of the neuron network with setting input datas in system memory
271
+ * @ingroup bmruntime
272
+ *
273
+ * This API supports the neuron network that is static-compiled or dynamic-compiled.
274
+ * After calling this API, inference on TPU is launched. And the CPU
275
+ * program will be blocked.
276
+ * This API supports multiple inputs and is multi-thread safe.
277
+ *
278
+ * @param [in] p_bmrt Bmruntime that had been created
279
+ * @param [in] net_name The name of the neuron network
280
+ * @param [in] input_datas Array of input data, defined like void * input_datas[input_num]. User should
281
+ * initialize each data pointer as input.
282
+ * @param [in] input_shapes Array of input shape, defined like bm_shape_t input_shapes[input_num].
283
+ * User should set each input shape
284
+ * @param [in] input_num Input number
285
+ * @param [out] output_datas Array of output data, defined like void * output_datas[output_num].
286
+ * If the user does not allocate each output data, set user_mem to false, and this api will alloc
287
+ * output mem; the user should free each output mem when the output data is no longer used. Also,
288
+ * the user can allocate system memory for each output data and set user_mem = true.
289
+ * @param [out] output_shapes Array of output shape, defined like bm_shape_t output_shapes[output_num].
290
+ * It will store each output shape.
291
+ * @param [in] output_num Output number
292
+ * @param [in] user_mem whether output_datas[i] have allocated memory
293
+ *
294
+ * @retval true Launch success.
295
+ * @retval false Launch failed.
296
+ */
297
+ DECL_EXPORT bool bmrt_launch_data(void* p_bmrt, const char* net_name, void* const input_datas[],
298
+ const bm_shape_t input_shapes[], int input_num, void * output_datas[],
299
+ bm_shape_t output_shapes[], int output_num, bool user_mem);
300
+
301
+ /**
302
+ * @name bmrt_trace
303
+ * @brief To check runtime environment, and collect info for DEBUG
304
+ * @ingroup bmruntime
305
+ *
306
+ * This API is to collect runtime info for DEBUG. Especially when a launch result is unexpectedly wrong, calling bmrt_trace
307
+ * will show whether device mems are broken, and other check info.
308
+ *
309
+ * @param [in] p_bmrt Bmruntime that had been created
310
+ */
311
+ DECL_EXPORT void bmrt_trace(void* p_bmrt);
312
+
313
+ /**
314
+ * @name bmrt_launch_tensor_multi_cores
315
+ * @brief To launch the inference of the neuron network with setting input tensors, and support multi core inference.
316
+ * @ingroup bmruntime
317
+ *
318
+ * This API supports the neuron network that is static-compiled or dynamic-compiled.
319
+ * After calling this API, inference on TPU is launched. And the CPU program will not
320
+ * be blocked. bm_thread_sync_from_core should be called to make sure inference is finished.
321
+ * This API supports multiple inputs and is multi-thread safe.
322
+ *
323
+ * @param [in] p_bmrt Bmruntime that had been created
324
+ * @param [in] net_name The name of the neuron network
325
+ * @param [in] input_tensors Array of input tensor, defined like bm_tensor_t input_tensors[input_num],
326
+ * User should initialize each input tensor.
327
+ * @param [in] input_num Input number
328
+ * @param [out] output_tensors Array of output tensor, defined like bm_tensor_t output_tensors[output_num].
329
+ * User can set device_mem or stmode of output tensors. If user_mem is true, this interface
330
+ * will use device mem of output_tensors to store output data, and not alloc device mem;
331
+ * Or it will alloc device mem to store output. If user_stmode is true, it will use stmode in
332
+ * each output tensor; Or stmode will be BM_STORE_1N as default.
333
+ * @param [in] output_num Output number
334
+ * @param [in] user_mem whether device_mem of output tensors are set
335
+ * @param [in] user_stmode whether stmode of output tensors are set
336
+ * @param [in] core_list core id list those will be used to inference
337
+ * @param [in] core_num number of the core list
338
+ *
339
+ * @retval true Launch success.
340
+ * @retval false Launch failed.
341
+ */
342
+ DECL_EXPORT bool bmrt_launch_tensor_multi_cores(
343
+ void *p_bmrt,
344
+ const char *net_name,
345
+ const bm_tensor_t input_tensors[],
346
+ int input_num,
347
+ bm_tensor_t output_tensors[],
348
+ int output_num,
349
+ bool user_mem,
350
+ bool user_stmode,
351
+ const int *core_list,
352
+ int core_num);
353
+
354
+ /**
355
+ * @name bmrt_memcpy_s2d_parallel
356
+ * @brief To copy data from system memory to multi-device memory in parallel
357
+ * @ingroup bmruntime
358
+ *
359
+ * This API only could be used when the p_bmrt is created with bmrt_create_ex on multi devices.
360
+ * After calling this API, datas[:tensor_num[0]] will be copied to the first device, and
361
+ * datas[tensor_num[0]:tensor_num[0]+tensor_num[1]] will be copied to the second device and so on.
362
+ * The process of copying data to different devices is done in parallel and to the same device is in sequence.
363
+ *
364
+ * @param [in] p_bmrt Bmruntime that had been created with multi bm_handles
365
+ * @param [in] tensors Array of tensors that will be copied to devices
366
+ * @param [in] datas Array of data buffers allocated in system memory
367
+ * @param [in] tensor_num Array of tensor_num that will be copied to each device
368
+ * @param [in] device_num Device number
369
+ */
370
+ DECL_EXPORT bool bmrt_memcpy_s2d_parallel(
371
+ void *p_bmrt,
372
+ bm_tensor_t tensors[],
373
+ void *datas[],
374
+ int tensor_num[],
375
+ int device_num);
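+
+ /*
+  Illustrative sketch (added commentary): distribute three prepared tensors over two
+  devices with the parallel copy above. The runtime must have been created with
+  bmrt_create_ex; handles, tensors and host buffers are assumed to be set up already.
+
+    void *p_bmrt = bmrt_create_ex(handles, 2);   // two bm_handles from bmlib
+    bm_tensor_t tensors[3];                      // tensors[0..1] on device 0, tensors[2] on device 1
+    void *datas[3];                              // matching host buffers
+    int tensor_num[2] = {2, 1};
+    bmrt_memcpy_s2d_parallel(p_bmrt, tensors, datas, tensor_num, 2);
+ */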
376
+
377
+ /**
378
+ * @name bmrt_memcpy_d2s_parallel
379
+ * @brief To copy data from multi-device memory to system memory in parallel
380
+ * @ingroup bmruntime
381
+ *
382
+ * This API only could be used when the p_bmrt is created with bmrt_create_ex on multi devices.
383
+ * After calling this API, tensors on the first device will be copied to datas[:tensor_num[0]] , and
384
+ * tensors on the second device will be copied to datas[tensor_num[0]:tensor_num[0]+tensor_num[1]] and so on.
385
+ * The process of copying data from different devices is done in parallel and from the same device is in sequence.
386
+ *
387
+ * @param [in] p_bmrt Bmruntime that had been created with multi bm_handles
388
+ * @param [in] datas Array of data buffers allocated in system memory
389
+ * @param [in] tensors Array of tensors that will be copied from devices
390
+ * @param [in] tensor_num Array of tensor_num that will be copied from each device
391
+ * @param [in] device_num Device number
392
+ */
393
+ DECL_EXPORT bool bmrt_memcpy_d2s_parallel(
394
+ void *p_bmrt,
395
+ void *datas[],
396
+ bm_tensor_t tensors[],
397
+ int tensor_num[],
398
+ int device_num);
399
+
400
+ #if defined (__cplusplus)
401
+ }
402
+ #endif
403
+
404
+ #endif
Baichuan2/src/include/sentencepiece/sentencepiece_processor.h ADDED
@@ -0,0 +1,727 @@
1
+ // Copyright 2016 Google Inc.
2
+ //
3
+ // Licensed under the Apache License, Version 2.0 (the "License");
4
+ // you may not use this file except in compliance with the License.
5
+ // You may obtain a copy of the License at
6
+ //
7
+ // http://www.apache.org/licenses/LICENSE-2.0
8
+ //
9
+ // Unless required by applicable law or agreed to in writing, software
10
+ // distributed under the License is distributed on an "AS IS" BASIS,
11
+ // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ // See the License for the specific language governing permissions and
13
+ // limitations under the License.
14
+
15
+ #ifndef SENTENCEPIECE_PROCESSOR_H_
16
+ #define SENTENCEPIECE_PROCESSOR_H_
17
+
18
+ #include <cstring>
19
+ #include <memory>
20
+ #include <string>
21
+ #include <string_view>
22
+ #include <utility>
23
+ #include <vector>
24
+
25
+ #ifndef SWIG
26
+ namespace absl {
27
+ using std::string_view;
28
+ } // namespace absl
29
+ #endif // SWIG
30
+
31
+ namespace sentencepiece {
32
+ namespace util {
33
+
34
+ enum class StatusCode : int {
35
+ kOk = 0,
36
+ kCancelled = 1,
37
+ kUnknown = 2,
38
+ kInvalidArgument = 3,
39
+ kDeadlineExceeded = 4,
40
+ kNotFound = 5,
41
+ kAlreadyExists = 6,
42
+ kPermissionDenied = 7,
43
+ kResourceExhausted = 8,
44
+ kFailedPrecondition = 9,
45
+ kAborted = 10,
46
+ kOutOfRange = 11,
47
+ kUnimplemented = 12,
48
+ kInternal = 13,
49
+ kUnavailable = 14,
50
+ kDataLoss = 15,
51
+ kUnauthenticated = 16,
52
+ };
53
+
54
+ class Status {
55
+ public:
56
+ Status();
57
+ ~Status();
58
+ Status(StatusCode code, absl::string_view error_message);
59
+ Status(const Status &s);
60
+ void operator=(const Status &s);
61
+ bool operator==(const Status &s) const;
62
+ bool operator!=(const Status &s) const;
63
+ inline bool ok() const { return rep_ == nullptr; }
64
+
65
+ void set_error_message(const char *str);
66
+ const char *error_message() const;
67
+ const char *message() const { return error_message(); }
68
+ StatusCode code() const;
69
+ std::string ToString() const;
70
+
71
+ void IgnoreError();
72
+
73
+ private:
74
+ struct Rep;
75
+ std::unique_ptr<Rep> rep_;
76
+ };
77
+ } // namespace util
78
+
79
+ // SentencePieceProcessor:
80
+ // Simple and language independent tokenizer and de-tokenizer for
81
+ // Neural Network Machine Translation.
82
+ //
83
+ // SentencePieceProcessor provides Encode() and Decode() methods,
84
+ // which correspond to tokenization and de-tokenization respectively.
85
+ //
86
+ // - Encode:
87
+ // Given a raw source sentence, encode it into a sequence
88
+ // of pieces or vocabulary ids.
89
+ //
90
+ // - Decode:
91
+ // Given a sequence of pieces or vocabulary ids, decode it
92
+ // into a de-tokenized raw sentence.
93
+ //
94
+ // SentencePieceProcessor provides a lossless data conversion
95
+ // that allows the original raw sentence to be perfectly reconstructed
96
+ // from the encoded data, i.e., Decode(Encode(input)) == input.
97
+ // This characteristic is useful, as we can make the de-tokenization
98
+ // completely language independent.
99
+ //
100
+ // Usage:
101
+ // SentencePieceProcessor sp;
102
+ // sp.Load("//path/to/model");
103
+ //
104
+ // vector<string> sps;
105
+ // sp.Encode("hello world.", &sps).IgnoreError();
106
+ //
107
+ // vector<int> ids;
108
+ // sp.Encode("hello world.", &ids).IgnoreError();
109
+ //
110
+ // string detok;
111
+ // sp.Decode(sps, &detok);
112
+ // CHECK_EQ("hello world.", detok).IgnoreError();
113
+ //
114
+ // sp.Decode(ids, &detok);
115
+ // CHECK_EQ("hello world.", detok).IgnoreError();
116
+ //
117
+ // We can also use SentencePieceText which manages the byte-offsets
118
+ // between user input (output) and internal sentence pieces.
119
+ //
120
+ // SentencePieceText spt;
121
+ // sp.Encode("hello world.", &spt);
122
+ // // Emits the byte range of each piece.
123
+ // for (const auto &piece : spt.pieces()) {
124
+ // LOG(INFO) << piece.begin() << " " << piece.end();
125
+ // }
126
+ //
127
+ // sp.Decode({0, 1, 2, 3..}, &spt);
128
+ // for (const auto &piece : spt.pieces()) {
129
+ // LOG(INFO) << piece.begin() << " " << piece.end();
130
+ // }
131
+ //
132
+
133
+ class NBestSentencePieceText;
134
+ class ModelInterface;
135
+ class SentencePieceText;
136
+ class ModelProto;
137
+
138
+ namespace normalizer {
139
+ class Normalizer;
140
+ } // namespace normalizer
141
+
142
+ #ifndef SWIGGO
143
+ namespace util {
144
+ // Redefine std::string for serialized_proto interface as Python's string is
145
+ // a Unicode string. We can enforce the return value to be raw byte sequence
146
+ // with SWIG's typemap.
147
+ using bytes = std::string;
148
+ } // namespace util
149
+ #endif // SWIGGO
150
+
151
+ class NBestSentencePieceText;
152
+ class ModelInterface;
153
+ class SentencePieceText;
154
+ class SentencePieceText_SentencePiece;
155
+
156
+ // Wrapper class of SentencePieceText
157
+ // This wrapper only allows an immutable access to the proto and
158
+ // hides the actual implementation of protobuf.
159
+ // See sentencepiece.proto for the details of this class.
160
+ class ImmutableSentencePieceText_ImmutableSentencePiece {
161
+ public:
162
+ ImmutableSentencePieceText_ImmutableSentencePiece();
163
+ ~ImmutableSentencePieceText_ImmutableSentencePiece() = default;
164
+
165
+ const std::string &piece() const;
166
+ const std::string &surface() const;
167
+ uint32_t id() const;
168
+ uint32_t begin() const;
169
+ uint32_t end() const;
170
+
171
+ friend class ImmutableSentencePieceText;
172
+
173
+ private:
174
+ explicit ImmutableSentencePieceText_ImmutableSentencePiece(
175
+ const SentencePieceText_SentencePiece &sp);
176
+ const SentencePieceText_SentencePiece *sp_ = nullptr;
177
+ };
178
+
179
+ class ImmutableSentencePieceText {
180
+ public:
181
+ ImmutableSentencePieceText();
182
+ virtual ~ImmutableSentencePieceText();
183
+
184
+ std::vector<ImmutableSentencePieceText_ImmutableSentencePiece> pieces() const;
185
+
186
+ size_t pieces_size() const;
187
+ ImmutableSentencePieceText_ImmutableSentencePiece pieces(int index) const;
188
+
189
+ const std::string &text() const;
190
+ float score() const;
191
+
192
+ util::bytes SerializeAsString() const;
193
+
194
+ // Returns the actual mutable proto.
195
+ // Do not use this outside of SentencePieceProcessor, as
196
+ // it returns the raw pointer managed by the shared_ptr.
197
+ SentencePieceText *mutable_proto();
198
+
199
+ // Converts the utf8 byte spans into Unicode char span.
200
+ void ConvertToUnicodeSpans();
201
+
202
+ friend class ImmutableNBestSentencePieceText;
203
+
204
+ private:
205
+ explicit ImmutableSentencePieceText(const SentencePieceText &spt);
206
+ const SentencePieceText *spt_ = nullptr;
207
+ std::shared_ptr<SentencePieceText> rep_;
208
+ };
209
+
210
+ // Wrapper class of SentencePieceText
211
+ // This wrapper only allows an immutable access to the proto and
212
+ // hides the actual implementation of protobuf.
213
+ // See sentencepiece.proto for the details of this class.
214
+ class ImmutableNBestSentencePieceText {
215
+ public:
216
+ ImmutableNBestSentencePieceText();
217
+ virtual ~ImmutableNBestSentencePieceText();
218
+
219
+ std::vector<ImmutableSentencePieceText> nbests() const;
220
+
221
+ size_t nbests_size() const;
222
+ ImmutableSentencePieceText nbests(int index) const;
223
+
224
+ util::bytes SerializeAsString() const;
225
+
226
+ // Returns the actual mutable proto.
227
+ // Do not use this outside of SentencePieceProcessor, as
228
+ // it returns the raw pointer managed by the shared_ptr.
229
+ NBestSentencePieceText *mutable_proto();
230
+
231
+ void ConvertToUnicodeSpans();
232
+
233
+ private:
234
+ std::shared_ptr<NBestSentencePieceText> rep_;
235
+ };
236
+
237
+ class SentencePieceProcessor {
238
+ public:
239
+ SentencePieceProcessor();
240
+ virtual ~SentencePieceProcessor();
241
+
242
+ // Loads model from `filename`.
243
+ // Returns false if `filename` cannot be loaded.
244
+ virtual util::Status Load(absl::string_view filename);
245
+
246
+ // Loads model from `filename`.
247
+ // Crash if `filename` cannot be loaded.
248
+ virtual void LoadOrDie(absl::string_view filename);
249
+
250
+ // Loads model from `model_proto`.
251
+ // `model_proto` is copied.
252
+ virtual util::Status Load(const ModelProto &model_proto);
253
+
254
+ // Loads model from `model_proto`.
255
+ // `model_proto` is moved.
256
+ virtual util::Status Load(std::unique_ptr<ModelProto> model_proto);
257
+
258
+ // Loads model from `serialized`, which is a string-serialized model proto.
259
+ // Useful to load the model from a platform independent blob object.
260
+ virtual util::Status LoadFromSerializedProto(absl::string_view serialized);
261
+
262
+ // Returns the status. Encode/Decode methods are valid when status is OK.
263
+ virtual util::Status status() const;
264
+
265
+ // Sets encode extra_option sequence.
266
+ virtual util::Status SetEncodeExtraOptions(absl::string_view extra_option);
267
+
268
+ // Sets decode extra_option sequence.
269
+ virtual util::Status SetDecodeExtraOptions(absl::string_view extra_option);
270
+
271
+ //////////////////////////////////////////////////////////////
272
+ // Vocabulary restriction.
273
+ // Background:
274
+ // https://github.com/rsennrich/subword-nmt#best-practice-advice-for-byte-pair-encoding-in-nmt
275
+
276
+ // Restricts the vocabulary set.
277
+ // The input sentences are encoded into the tokens in `valid_vocab`.
278
+ virtual util::Status SetVocabulary(
279
+ const std::vector<absl::string_view> &valid_vocab);
280
+
281
+ // Reverts the vocabulary restriction.
282
+ virtual util::Status ResetVocabulary();
283
+
284
+ // Loads the valid vocabulary set from `filename` in TSV format.
285
+ // Format: <token> <tab> <freq>.
286
+ // Any token with frequency < threshold will be treated as OOV.
287
+ virtual util::Status LoadVocabulary(absl::string_view filename,
288
+ int threshold);
289
+
290
+ //////////////////////////////////////////////////////////////
291
+ // Simple Encode and Decode API.
292
+ //
293
+ // Given a UTF8 input, encodes it into a sequence of sentence pieces.
294
+ virtual util::Status Encode(absl::string_view input,
295
+ std::vector<std::string> *pieces) const;
296
+
297
+ // Given a UTF8 input, encodes it into a sequence of ids.
298
+ virtual util::Status Encode(absl::string_view input,
299
+ std::vector<int> *ids) const;
300
+
301
+ // Given a sequence of pieces, decodes it into a detokenized output.
302
+ virtual util::Status Decode(const std::vector<std::string> &pieces,
303
+ std::string *detokenized) const;
304
+
305
+ // Given a sequence of pieces, decodes it into a detokenized output.
306
+ virtual util::Status Decode(const std::vector<absl::string_view> &pieces,
307
+ std::string *detokenized) const;
308
+
309
+ // Given a sequence of ids, decodes it into a detokenized output.
310
+ virtual util::Status Decode(const std::vector<int> &ids,
311
+ std::string *detokenized) const;
312
+
313
+ //////////////////////////////////////////////////////////////
314
+ // NBest API.
315
+ //
316
+ // Same as Encode, but returns nbest results.
317
+ virtual util::Status NBestEncode(
318
+ absl::string_view input, int nbest_size,
319
+ std::vector<std::vector<std::string>> *pieces) const;
320
+
321
+ // Same as Encode, but returns nbest results.
322
+ virtual util::Status NBestEncode(absl::string_view input, int nbest_size,
323
+ std::vector<std::vector<int>> *ids) const;
324
+
325
+ //////////////////////////////////////////////////////////////
326
+ // Sampling API.
327
+ //
328
+ // Unigram and BPE support sampling mode.
329
+ // - Unigram (--model_type=unigram):
330
+ // `nbest_size`: When `nbest_size` is positive value, approximately samples
331
+ // one segmentation from nbest candidates. When `nbest_size` is negative
332
+ // value, samples one segmentation from the hypotheses (Lattice) according to
333
+ // the generation probabilities using forward-filtering and backward-sampling
334
+ // algorithm.
335
+ // `alpha`: Smoothing parameter (inverse temperature). The best segmentation
336
+ // (Viterbi segmentation) is more likely sampled when setting larger alpha.
337
+ // When alpha is 0.0, one segmentation is uniformly sampled from the nbest or
338
+ // lattice. `nbest_size` and `alpha` correspond to parameters `l` and `alpha`
339
+ // in https://arxiv.org/abs/1804.10959 (nbest_size < 0 means l = infinity)
340
+ //
341
+ // - BPE (--model_type=bpe):
342
+ // `alpha`: The dropout probability `p` of bpe merge operations in
343
+ // https://arxiv.org/abs/1910.13267 Nbest-based sampling is not supported so
344
+ // nbest_size parameter is ignored in BPE.
345
+ virtual util::Status SampleEncode(absl::string_view input, int nbest_size,
346
+ float alpha,
347
+ std::vector<std::string> *pieces) const;
348
+
349
+ // Same as above, but returns a sequence of ids.
350
+ virtual util::Status SampleEncode(absl::string_view input, int nbest_size,
351
+ float alpha, std::vector<int> *ids) const;
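+
+ // Illustrative sampling sketch (added commentary, not part of the upstream header),
+ // assuming a unigram model file "tokenizer.model". nbest_size = -1 samples from the
+ // whole lattice and alpha is the smoothing parameter described above; each call may
+ // return a different segmentation of the same input.
+ //
+ // SentencePieceProcessor sp;
+ // sp.Load("tokenizer.model").IgnoreError();
+ // std::vector<std::string> pieces;
+ // sp.SampleEncode("hello world.", -1, 0.1, &pieces).IgnoreError();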
352
+
353
+ //////////////////////////////////////////////////////////////
354
+ // SampleEncodeAndScore API.
355
+ //
356
+ // Sample `samples` many tokenisations from the segmentation lattice.
357
+ // These methods are only available in model_type=unigram.
358
+ //
359
+ // `alpha`: smoothing parameter (inverse temperature). The same as `alpha` in
360
+ // `Sample` method.
361
+ // 'wor`: If `wor` is true, the samples are taken without replacement, and the
362
+ // scores are the inclusion probabilities of the elements in the sample;
363
+ // otherwise the samples are taken with replacement and the scores are the
364
+ // log-probs of sample elements
365
+ // `include_best`: If `include_best` is true, the best tokenisation is always
366
+ // included in the sample, and the remaining elements are sampled excluding
367
+ // the best.
368
+ virtual util::Status SampleEncodeAndScore(
369
+ absl::string_view input, int num_samples, float alpha, bool wor,
370
+ bool include_best,
371
+ std::vector<std::pair<std::vector<std::string>, float>> *pieces) const;
372
+
373
+ // Same as above, but returns a sequence of ids.
374
+ virtual util::Status SampleEncodeAndScore(
375
+ absl::string_view input, int num_samples, float alpha, bool wor,
376
+ bool include_best,
377
+ std::vector<std::pair<std::vector<int>, float>> *ids) const;
378
+
379
+ //////////////////////////////////////////////////////////////
380
+ // Entropy API.
381
+ //
382
+ // This is only available in model_type=unigram.
383
+ // Calculate entropy of possible tokenisations
384
+ virtual util::Status CalculateEntropy(absl::string_view input, float alpha,
385
+ float *entropy) const;
386
+
387
+ //////////////////////////////////////////////////////////////
388
+ // Advanced API returning SentencePieceText, which manages
389
+ // utf8-byte alignments between user-input/detokenized text
390
+ // and internal sentencepiece sequence.
391
+ //
392
+ // Given a UTF8 input, encodes it into SentencePieceText.
393
+ //
394
+ // When using these APIs, sentencepiece.pb.h header files must be included.
395
+ // We can also use ImmutableSentencePieceText as follows.
396
+ //
397
+ // ImmutableSentencePieceText spt;
398
+ // Encode("hello", spt.mutable_proto()).IgnoreError();
399
+ // std::cout << spt.pieces_size() << std::endl;
400
+ virtual util::Status Encode(absl::string_view input,
401
+ SentencePieceText *spt) const;
402
+
403
+ virtual util::Status NBestEncode(absl::string_view input, int nbest_size,
404
+ NBestSentencePieceText *nbest_spt) const;
405
+
406
+ virtual util::Status SampleEncode(absl::string_view input, int nbest_size,
407
+ float alpha, SentencePieceText *spt) const;
408
+
409
+ virtual util::Status SampleEncodeAndScore(
410
+ absl::string_view input, int num_samples, float alpha, bool wor,
411
+ bool include_best, NBestSentencePieceText *samples_spt) const;
412
+
413
+ // DEPRECATED: Remove this API and use std::vector<std::string_view>
414
+ virtual util::Status Decode(const std::vector<std::string> &pieces,
415
+ SentencePieceText *spt) const;
416
+
417
+ virtual util::Status Decode(const std::vector<absl::string_view> &pieces,
418
+ SentencePieceText *spt) const;
419
+
420
+ virtual util::Status Decode(const std::vector<int> &ids,
421
+ SentencePieceText *spt) const;
422
+ #ifdef SWIG
423
+ #define SPP_SWIG_CHECK_AND_THROW \
424
+ if (!status.ok()) throw status;
425
+ #else
426
+ #define SPP_SWIG_CHECK_AND_THROW \
427
+ if (!status.ok()) { \
428
+ }
429
+ #endif // SWIG
430
+
431
+ #define DEFINE_SPP_DIRECT_FUNC_IMPL(FuncName, OutType, ...) \
432
+ OutType output; \
433
+ const auto status = FuncName(__VA_ARGS__, &output); \
434
+ SPP_SWIG_CHECK_AND_THROW; \
435
+ return output;
436
+
437
+ #define DEFINE_SPP_SERIALIZED_PROTO_IMPL(FuncName, OutType, ...) \
438
+ OutType output; \
439
+ const auto status = FuncName(__VA_ARGS__, output.mutable_proto()); \
440
+ SPP_SWIG_CHECK_AND_THROW; \
441
+ return output.SerializeAsString();
442
+
443
+ #define DEFINE_SPP_IMMUTABLE_PROTO_IMPL(FuncName, OutType, ...) \
444
+ OutType output; \
445
+ const auto status = FuncName(__VA_ARGS__, output.mutable_proto()); \
446
+ SPP_SWIG_CHECK_AND_THROW; \
447
+ return output;
448
+
449
+ //////////////////////////////////////////////////////////////
450
+ // Handy methods that return the result directly.
451
+ // These functions ignore internal errors.
452
+ virtual std::vector<std::string> EncodeAsPieces(
453
+ absl::string_view input) const {
454
+ DEFINE_SPP_DIRECT_FUNC_IMPL(Encode, std::vector<std::string>, input);
455
+ }
456
+
457
+ virtual std::vector<int> EncodeAsIds(absl::string_view input) const {
458
+ DEFINE_SPP_DIRECT_FUNC_IMPL(Encode, std::vector<int>, input);
459
+ }
460
+
461
+ virtual std::vector<std::vector<std::string>> NBestEncodeAsPieces(
462
+ absl::string_view input, int nbest_size) const {
463
+ DEFINE_SPP_DIRECT_FUNC_IMPL(
464
+ NBestEncode, std::vector<std::vector<std::string>>, input, nbest_size);
465
+ }
466
+
467
+ virtual std::vector<std::vector<int>> NBestEncodeAsIds(
468
+ absl::string_view input, int nbest_size) const {
469
+ DEFINE_SPP_DIRECT_FUNC_IMPL(NBestEncode, std::vector<std::vector<int>>,
470
+ input, nbest_size);
471
+ }
472
+
473
+ virtual std::vector<std::string> SampleEncodeAsPieces(absl::string_view input,
474
+ int nbest_size,
475
+ float alpha) const {
476
+ DEFINE_SPP_DIRECT_FUNC_IMPL(SampleEncode, std::vector<std::string>, input,
477
+ nbest_size, alpha);
478
+ }
479
+
480
+ virtual std::vector<int> SampleEncodeAsIds(absl::string_view input,
481
+ int nbest_size,
482
+ float alpha) const {
483
+ DEFINE_SPP_DIRECT_FUNC_IMPL(SampleEncode, std::vector<int>, input,
484
+ nbest_size, alpha);
485
+ }
486
+
487
+ virtual std::vector<std::pair<std::vector<std::string>, float>>
488
+ SampleEncodeAndScoreAsPieces(absl::string_view input, int num_samples,
489
+ float alpha, bool wor, bool include_best) const {
490
+ using _T = std::vector<std::pair<std::vector<std::string>, float>>;
491
+ DEFINE_SPP_DIRECT_FUNC_IMPL(SampleEncodeAndScore, _T, input, num_samples,
492
+ alpha, wor, include_best);
493
+ }
494
+
495
+ virtual std::vector<std::pair<std::vector<int>, float>>
496
+ SampleEncodeAndScoreAsIds(absl::string_view input, int num_samples,
497
+ float alpha, bool wor, bool include_best) const {
498
+ using _T = std::vector<std::pair<std::vector<int>, float>>;
499
+ DEFINE_SPP_DIRECT_FUNC_IMPL(SampleEncodeAndScore, _T, input, num_samples,
500
+ alpha, wor, include_best);
501
+ }
502
+
503
+ // DEPRECATED: Remove this API and use std::vector<std::string_view>
504
+ virtual std::string DecodePieces(
505
+ const std::vector<std::string> &pieces) const {
506
+ DEFINE_SPP_DIRECT_FUNC_IMPL(Decode, std::string, pieces);
507
+ }
508
+
509
+ virtual std::string DecodePieces(
510
+ const std::vector<absl::string_view> &pieces) const {
511
+ DEFINE_SPP_DIRECT_FUNC_IMPL(Decode, std::string, pieces);
512
+ }
513
+
514
+ virtual std::string DecodeIds(const std::vector<int> &ids) const {
515
+ DEFINE_SPP_DIRECT_FUNC_IMPL(Decode, std::string, ids);
516
+ }
517
+
518
+ virtual float CalculateEntropy(absl::string_view text, float alpha) const {
519
+ DEFINE_SPP_DIRECT_FUNC_IMPL(CalculateEntropy, float, text, alpha);
520
+ }
521
+
522
+ //////////////////////////////////////////////////////////////
523
+ // SerializedProto API. (DEPRECATED). Use ImmutableProto API.
524
+ // They are used in Python interface. Returns serialized proto.
525
+ // In python module, we can get access to the full Proto after
526
+ // deserializing the returned byte sequence.
527
+ virtual util::bytes EncodeAsSerializedProto(absl::string_view input) const {
528
+ DEFINE_SPP_SERIALIZED_PROTO_IMPL(Encode, ImmutableSentencePieceText, input);
529
+ }
530
+
531
+ virtual util::bytes SampleEncodeAsSerializedProto(absl::string_view input,
532
+ int nbest_size,
533
+ float alpha) const {
534
+ DEFINE_SPP_SERIALIZED_PROTO_IMPL(SampleEncode, ImmutableSentencePieceText,
535
+ input, nbest_size, alpha);
536
+ }
537
+
538
+ virtual util::bytes NBestEncodeAsSerializedProto(absl::string_view input,
539
+ int nbest_size) const {
540
+ DEFINE_SPP_SERIALIZED_PROTO_IMPL(
541
+ NBestEncode, ImmutableNBestSentencePieceText, input, nbest_size);
542
+ }
543
+
544
+ virtual util::bytes SampleEncodeAndScoreAsSerializedProto(
545
+ absl::string_view input, int num_samples, float alpha, bool wor,
546
+ bool include_best) const {
547
+ DEFINE_SPP_SERIALIZED_PROTO_IMPL(SampleEncodeAndScore,
548
+ ImmutableNBestSentencePieceText, input,
549
+ num_samples, alpha, wor, include_best);
550
+ }
551
+
552
+ // TODO(taku): Remove this API and use std::vector<std::string_view>
553
+ virtual util::bytes DecodePiecesAsSerializedProto(
554
+ const std::vector<std::string> &pieces) const {
555
+ DEFINE_SPP_SERIALIZED_PROTO_IMPL(Decode, ImmutableSentencePieceText,
556
+ pieces);
557
+ }
558
+
559
+ virtual util::bytes DecodePiecesAsSerializedProto(
560
+ const std::vector<absl::string_view> &pieces) const {
561
+ DEFINE_SPP_SERIALIZED_PROTO_IMPL(Decode, ImmutableSentencePieceText,
562
+ pieces);
563
+ }
564
+
565
+ virtual util::bytes DecodeIdsAsSerializedProto(
566
+ const std::vector<int> &ids) const {
567
+ DEFINE_SPP_SERIALIZED_PROTO_IMPL(Decode, ImmutableSentencePieceText, ids);
568
+ }
569
+
570
+ //////////////////////////////////////////////////////////////
571
+ // ImmutableProto API.
572
+ virtual ImmutableSentencePieceText EncodeAsImmutableProto(
573
+ absl::string_view input) const {
574
+ DEFINE_SPP_IMMUTABLE_PROTO_IMPL(Encode, ImmutableSentencePieceText, input);
575
+ }
576
+
577
+ virtual ImmutableSentencePieceText SampleEncodeAsImmutableProto(
578
+ absl::string_view input, int nbest_size, float alpha) const {
579
+ DEFINE_SPP_IMMUTABLE_PROTO_IMPL(SampleEncode, ImmutableSentencePieceText,
580
+ input, nbest_size, alpha);
581
+ }
582
+
583
+ virtual ImmutableNBestSentencePieceText NBestEncodeAsImmutableProto(
584
+ absl::string_view input, int nbest_size) const {
585
+ DEFINE_SPP_IMMUTABLE_PROTO_IMPL(
586
+ NBestEncode, ImmutableNBestSentencePieceText, input, nbest_size);
587
+ }
588
+
589
+ virtual ImmutableNBestSentencePieceText SampleEncodeAndScoreAsImmutableProto(
590
+ absl::string_view input, int num_samples, float alpha, bool wor,
591
+ bool include_best) const {
592
+ DEFINE_SPP_IMMUTABLE_PROTO_IMPL(SampleEncodeAndScore,
593
+ ImmutableNBestSentencePieceText, input,
594
+ num_samples, alpha, wor, include_best);
595
+ }
596
+
597
+ // TODO(taku): Remove this API and use std::vector<std::string_view>
598
+ virtual ImmutableSentencePieceText DecodePiecesAsImmutableProto(
599
+ const std::vector<std::string> &pieces) const {
600
+ DEFINE_SPP_IMMUTABLE_PROTO_IMPL(Decode, ImmutableSentencePieceText, pieces);
601
+ }
602
+
603
+ virtual ImmutableSentencePieceText DecodePiecesAsImmutableProto(
604
+ const std::vector<absl::string_view> &pieces) const {
605
+ DEFINE_SPP_IMMUTABLE_PROTO_IMPL(Decode, ImmutableSentencePieceText, pieces);
606
+ }
607
+
608
+ virtual ImmutableSentencePieceText DecodeIdsAsImmutableProto(
609
+ const std::vector<int> &ids) const {
610
+ DEFINE_SPP_IMMUTABLE_PROTO_IMPL(Decode, ImmutableSentencePieceText, ids);
611
+ }
612
+
613
+ #undef DEFINE_SPP_DIRECT_FUNC_IMPL
614
+ #undef DEFINE_SPP_SERIALIZED_PROTO_IMPL
615
+ #undef DEFINE_SPP_IMMUTABLE_PROTO_IMPL
616
+
617
+ //////////////////////////////////////////////////////////////
618
+ // Vocabulary management methods.
619
+ //
620
+ // Returns the size of sentence pieces, which is the same as
621
+ // the size of vocabulary for NMT.
622
+ virtual int GetPieceSize() const;
623
+
624
+ // Returns the vocab id of `piece`.
625
+ // Returns UNK(0) if `piece` is unknown.
626
+ virtual int PieceToId(absl::string_view piece) const;
627
+
628
+ // Returns the string representation of vocab with `id`.
629
+ virtual const std::string &IdToPiece(int id) const;
630
+
631
+ // Returns the score of `id`.
632
+ // Usually score is an emission log probability of unigram language
633
+ // model.
634
+ virtual float GetScore(int id) const;
635
+
636
+ // Returns true if `id` is unknown symbol.
637
+ virtual bool IsUnknown(int id) const;
638
+
639
+ // Returns true if `id` is control symbol.
640
+ virtual bool IsControl(int id) const;
641
+
642
+ // Returns true if `id` is unused symbol.
643
+ virtual bool IsUnused(int id) const;
644
+
645
+ // Returns true if `id` is byte symbol.
646
+ virtual bool IsByte(int id) const;
647
+
648
+ // Returns the reserved id.
649
+ // Returns -1 if not defined.
650
+
651
+ // Returns unknown (<unk>) id.
652
+ virtual int unk_id() const;
653
+
654
+ // Returns BOS (<s>) id.
655
+ virtual int bos_id() const;
656
+
657
+ // Returns EOS (</s>) id.
658
+ virtual int eos_id() const;
659
+
660
+ // Returns PAD (<pad>) id.
661
+ virtual int pad_id() const;
662
+
663
+ //////////////////////////////////////////////////////////////
664
+ // Model management.
665
+ //
666
+ // Allows injection of a mock model instance. `model` is moved.
667
+ void SetModel(std::unique_ptr<ModelInterface> &&model);
668
+
669
+ // Allows injection of a normalizer instance. `normalizer` is moved.
670
+ void SetNormalizer(std::unique_ptr<normalizer::Normalizer> &&normalizer);
671
+
672
+ // Returns immutable model proto. Useful to obtain extended
673
+ // or experimental parameters encoded in model_proto.
674
+ const ModelProto &model_proto() const;
675
+
676
+ // returns immutable model proto as std::string.
677
+ // Useful to save the state of this instance via Python's pickle object.
678
+ util::bytes serialized_model_proto() const;
679
+
680
+ private:
681
+ enum ExtraOption { REVERSE, BOS, EOS, UNK_PIECE };
682
+
683
+ util::Status ParseExtraOptions(absl::string_view extra_option,
684
+ std::vector<ExtraOption> *extra_options) const;
685
+
686
+ util::Status ApplyExtraOptions(const std::vector<ExtraOption> &extra_options,
687
+ SentencePieceText *spt) const;
688
+
689
+ util::Status PopulateSentencePieceText(
690
+ absl::string_view input, absl::string_view normalized,
691
+ const std::vector<size_t> &norm_to_orig,
692
+ const std::vector<std::pair<absl::string_view, int>> &result,
693
+ SentencePieceText *spt) const;
694
+
695
+ std::unique_ptr<ModelInterface> model_;
696
+ std::unique_ptr<normalizer::Normalizer> normalizer_;
697
+ std::unique_ptr<normalizer::Normalizer> denormalizer_;
698
+
699
+ // Underlying model protocol buffer. The same lifetime as model_.
700
+ std::unique_ptr<ModelProto> model_proto_;
701
+
702
+ std::vector<ExtraOption> encode_extra_options_;
703
+ std::vector<ExtraOption> decode_extra_options_;
704
+ };
705
+
706
+ // Set seed value of random generator.
707
+ // Do not set static_cast<unsigned int>(-1),
708
+ // as this seed is reserved for initializing from
709
+ // std::random_device.
710
+ void SetRandomGeneratorSeed(unsigned int seed);
711
+
712
+ // IO related functions to absorb model formats.
713
+ namespace io {
714
+ // Loads `model_proto` from `filename`.
715
+ // We can instantiate SentencePieceProcessor as follows:
716
+ //
717
+ // auto model_proto = absl::make_unique<ModelProto>();
718
+ // io::LoadModelProto("//path/spm.model", model_proto.get());
719
+ // SentencePieceProcessor sp;
720
+ // CHECK_OK(sp.Load(std::move(model_proto)));
721
+ util::Status LoadModelProto(absl::string_view, ModelProto *model_proto);
722
+
723
+ // Saves `model_proto` as `filename`.
724
+ util::Status SaveModelProto(absl::string_view, const ModelProto &model_proto);
725
+ } // namespace io
726
+ } // namespace sentencepiece
727
+ #endif // SENTENCEPIECE_PROCESSOR_H_
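The header above documents the Encode/Decode calls and the vocabulary-management API (`GetPieceSize`, `PieceToId`, `bos_id`, `eos_id`, ...) that the C++ demos in this repo link against. As a quick illustration, here is a minimal sketch using the `sentencepiece` Python package (an assumption; the repo itself ships the prebuilt C++ static library) against the bundled `tokenizer.model`:

```python
# Minimal sketch, assuming the `sentencepiece` pip package is installed and
# Baichuan2/model/tokenizer.model is available in the working directory.
import sentencepiece as spm

sp = spm.SentencePieceProcessor()
sp.Load("tokenizer.model")                 # counterpart of SentencePieceProcessor::Load

ids = sp.EncodeAsIds("hello world")        # Encode: text -> vocab ids
pieces = sp.EncodeAsPieces("hello world")  # Encode: text -> sentence pieces
text = sp.DecodeIds(ids)                   # Decode: ids -> text

# Vocabulary-management counterparts of GetPieceSize / bos_id / eos_id / unk_id.
print(sp.GetPieceSize(), sp.bos_id(), sp.eos_id(), sp.unk_id())
print(ids, pieces, text)
```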
Baichuan2/src/lib_pcie/libbmlib.so ADDED
Binary file (195 kB). View file
 
Baichuan2/src/lib_pcie/libbmrt.so ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:621e33823dca470275e09570324a567ce4a30fa6100ac9e52742bb9e1ee02f45
3
+ size 2966400
Baichuan2/src/lib_pcie/libbmrt.so.1.0 ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:621e33823dca470275e09570324a567ce4a30fa6100ac9e52742bb9e1ee02f45
3
+ size 2966400
Baichuan2/src/lib_pcie/libsentencepiece.a ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:68811cd99e6e1a58572372f14f3b7a02cf98bc98f5d46d24c406be65a94b53e8
3
+ size 2858304
Baichuan2/src/lib_soc/libbmlib.so ADDED
Binary file (191 kB). View file
 
Baichuan2/src/lib_soc/libbmrt.so ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:cff807807fcc8c6a9d16353e389422d434ae2b79c8bc191266d0eb5a69b3d97d
3
+ size 2915352
Baichuan2/src/lib_soc/libbmrt.so.1.0 ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:cff807807fcc8c6a9d16353e389422d434ae2b79c8bc191266d0eb5a69b3d97d
3
+ size 2915352
Baichuan2/src/lib_soc/libsentencepiece.a ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:5b1c1ece6c62265ee879cf5876d31e82580c3ee88c2cb627b8ac3eaf35695bde
3
+ size 3032062
Baichuan2/web_demo/CMakeLists.txt ADDED
@@ -0,0 +1,36 @@
1
+ cmake_minimum_required(VERSION 2.8)
2
+ project(baichuan2)
3
+
4
+ if (NOT DEFINED TARGET_ARCH)
5
+ set(TARGET_ARCH pcie)
6
+ endif()
7
+
8
+ set(CMAKE_INSTALL_PREFIX install)
9
+
10
+ if (${CMAKE_HOST_SYSTEM_PROCESSOR} STREQUAL "aarch64")
11
+ add_definitions(-DSOC_TARGET)
12
+ link_directories(${PROJECT_SOURCE_DIR}/../src/lib_soc)
13
+ message("SoC mode, starting......")
14
+ elseif (${TARGET_ARCH} STREQUAL "pcie")
15
+ add_definitions(-DPCIE_TARGET)
16
+ link_directories(${PROJECT_SOURCE_DIR}/../src/lib_pcie)
17
+ message("Pcie mode, starting......")
18
+ elseif (${TARGET_ARCH} STREQUAL "soc")
19
+ add_definitions(-DSOC_TARGET)
20
+ set(CMAKE_C_COMPILER aarch64-linux-gnu-gcc)
21
+ set(CMAKE_ASM_COMPILER aarch64-linux-gnu-gcc)
22
+ set(CMAKE_CXX_COMPILER aarch64-linux-gnu-g++)
23
+ link_directories(${PROJECT_SOURCE_DIR}/../src/lib_soc)
24
+ message("SoC mode, starting......")
25
+ endif()
26
+
27
+
28
+
29
+
30
+ include_directories(${PROJECT_SOURCE_DIR}/../src/include)
31
+
32
+ add_definitions(-DDEBUG --std=c++17 -fPIC -Wall -Werror)
33
+ set(CMAKE_BUILD_TYPE "Debug")
34
+
35
+ add_library(tpuchat SHARED chat.cpp)
36
+ target_link_libraries(tpuchat bmrt bmlib sentencepiece)
Baichuan2/web_demo/chat.cpp ADDED
@@ -0,0 +1,419 @@
1
+ //===----------------------------------------------------------------------===//
2
+ //
3
+ // Copyright (C) 2023 Sophgo Technologies Inc. All rights reserved.
4
+ //
5
+ // TPU-MLIR is licensed under the 2-Clause BSD License except for the
6
+ // third-party components.
7
+ //
8
+ //===----------------------------------------------------------------------===//
9
+
10
+ #include <iostream>
11
+ #include <cstdlib>
12
+ #include <vector>
13
+ #include <assert.h>
14
+ #include <chrono>
15
+ #include <algorithm>
16
+ #include "memory.h"
17
+ #include "sentencepiece/sentencepiece_processor.h"
18
+ #include "bmruntime_interface.h"
19
+ #include <getopt.h>
20
+
21
+ static const int NUM_LAYERS = 32;
22
+ static const int MAX_LEN = 512;
23
+ static const float ATTENTION_MASK = -1000.;
24
+
25
+ static const std::string TOKENIZER_MODEL = "tokenizer.model";
26
+
27
+ // #define EXPORT_RESULTS
28
+ #ifdef EXPORT_RESULTS
29
+ #include "cnpy.h"
30
+ static cnpy::npz_t map;
31
+
32
+ template <typename T>
33
+ static void add_array(std::string name, bm_handle_t bm_handle,
34
+ const bm_device_mem_t &dst) {
35
+ std::vector<T> data(dst.size / sizeof(T));
36
+ bm_memcpy_d2s(bm_handle, data.data(), dst);
37
+ cnpy::npz_add_array(map, name, data);
38
+ }
39
+
40
+ static void save_array(std::string filename) {
41
+ cnpy::npz_save_all(filename, map);
42
+ }
43
+ #endif
44
+
45
+ class Baichuan2 {
46
+ public:
47
+ void init(int devid, const std::string model, const std::string tokenizer_path);
48
+ void chat();
49
+ void deinit();
50
+ std::string name;
51
+ std::string history = "";
52
+ int round = 0;
53
+ int token_length;
54
+ int EOS;
55
+ std::string predict_next_token();
56
+ std::string predict_first_token(const std::string &input_str);
57
+
58
+ private:
59
+ int forward_first(std::vector<int> &tokens);
60
+ int forward_next();
61
+ void load_sentencepiece(const std::string &tokenizer_path);
62
+
63
+ private:
64
+ std::vector<bm_handle_t> handles;
65
+ bm_handle_t bm_handle;
66
+ void *p_bmrt;
67
+ sentencepiece::SentencePieceProcessor sentencepiece;
68
+ const bm_net_info_t *net_blocks[NUM_LAYERS];
69
+ const bm_net_info_t *net_blocks_cache[NUM_LAYERS];
70
+ const bm_net_info_t *net_embed;
71
+ const bm_net_info_t *net_lm;
72
+ bm_tensor_t inputs_embed_512, outputs_embed_512;
73
+ bm_tensor_t inputs_lm, outputs_lm;
74
+ bm_tensor_t inputs_pid, next_pid, inputs_attention, next_attention;
75
+ bm_tensor_t past_key[NUM_LAYERS], past_value[NUM_LAYERS];
76
+ bm_tensor_t present_key[NUM_LAYERS], present_value[NUM_LAYERS];
77
+ bm_tensor_t present_key_cache, present_value_cache;
78
+ std::string name_embed;
79
+ std::string name_lm;
80
+ std::string name_blocks[NUM_LAYERS];
81
+ std::string name_blocks_cache[NUM_LAYERS];
82
+ };
83
+
84
+ void Baichuan2::load_sentencepiece(const std::string &model) {
85
+ printf("Load %s ... ", model.c_str());
86
+ auto status = sentencepiece.Load(model);
87
+ if (!status.ok()) {
88
+ std::cout << status.ToString() << std::endl;
89
+ exit(-1);
90
+ }
91
+ EOS = sentencepiece.eos_id();
92
+ printf("Done!\n");
93
+ }
94
+
95
+ void Baichuan2::init(int devid, const std::string model, const std::string tokenizer_path) {
96
+ load_sentencepiece(tokenizer_path);
97
+ // request bm_handle
98
+ bm_status_t status = bm_dev_request(&bm_handle, devid);
99
+ assert(BM_SUCCESS == status);
100
+
101
+ // create bmruntime
102
+ p_bmrt = bmrt_create(bm_handle);
103
+ assert(NULL != p_bmrt);
104
+
105
+ // load bmodel by file
106
+ printf("Model[%s] loading ....\n", model.c_str());
107
+ bool ret = bmrt_load_bmodel(p_bmrt, model.c_str());
108
+ assert(true == ret);
109
+ printf("Done!\n");
110
+ // net names
111
+ name_embed = "embedding";
112
+ name_lm = "lm_head";
113
+ for (int i = 0; i < NUM_LAYERS; i++) {
114
+ name_blocks[i] = "block_" + std::to_string(i);
115
+ name_blocks_cache[i] = "block_cache_" + std::to_string(i);
116
+ }
117
+
118
+ // net infos
119
+ net_embed = bmrt_get_network_info(p_bmrt, name_embed.c_str());
120
+ net_lm = bmrt_get_network_info(p_bmrt, name_lm.c_str());
121
+ for (int i = 0; i < NUM_LAYERS; i++) {
122
+ net_blocks[i] = bmrt_get_network_info(p_bmrt, name_blocks[i].c_str());
123
+ net_blocks_cache[i] =
124
+ bmrt_get_network_info(p_bmrt, name_blocks_cache[i].c_str());
125
+ }
126
+
127
+ // net device mem
128
+ ret = bmrt_tensor(&inputs_embed_512, p_bmrt, net_embed->input_dtypes[0],
129
+ net_embed->stages[1].input_shapes[0]);
130
+ assert(true == ret);
131
+
132
+ ret = bmrt_tensor(&outputs_embed_512, p_bmrt, net_embed->output_dtypes[0],
133
+ net_embed->stages[1].output_shapes[0]);
134
+ assert(true == ret);
135
+
136
+ ret = bmrt_tensor(&inputs_pid, p_bmrt, net_blocks[0]->input_dtypes[1],
137
+ net_blocks[0]->stages[0].input_shapes[1]);
138
+ assert(true == ret);
139
+
140
+ ret = bmrt_tensor(&inputs_attention, p_bmrt, net_blocks[0]->input_dtypes[2],
141
+ net_blocks[0]->stages[0].input_shapes[2]);
142
+ assert(true == ret);
143
+
144
+ ret = bmrt_tensor(&next_pid, p_bmrt, net_blocks_cache[0]->input_dtypes[1],
145
+ net_blocks_cache[0]->stages[0].input_shapes[1]);
146
+ assert(true == ret);
147
+
148
+ ret =
149
+ bmrt_tensor(&next_attention, p_bmrt, net_blocks_cache[0]->input_dtypes[2],
150
+ net_blocks_cache[0]->stages[0].input_shapes[2]);
151
+ assert(true == ret);
152
+
153
+ for (int i = 0; i < NUM_LAYERS; i++) {
154
+ ret = bmrt_tensor(&past_key[i], p_bmrt, net_blocks[0]->output_dtypes[1],
155
+ net_blocks[0]->stages[0].output_shapes[1]);
156
+ assert(true == ret);
157
+ ret = bmrt_tensor(&past_value[i], p_bmrt, net_blocks[0]->output_dtypes[2],
158
+ net_blocks[0]->stages[0].output_shapes[2]);
159
+ assert(true == ret);
160
+ ret = bmrt_tensor(&present_key[i], p_bmrt, net_blocks[0]->output_dtypes[1],
161
+ net_blocks[0]->stages[0].output_shapes[1]);
162
+ assert(true == ret);
163
+ ret = bmrt_tensor(&present_value[i], p_bmrt, net_blocks[0]->output_dtypes[2],
164
+ net_blocks[0]->stages[0].output_shapes[2]);
165
+ assert(true == ret);
166
+ }
167
+ ret = bmrt_tensor(&present_key_cache, p_bmrt, net_blocks_cache[0]->output_dtypes[1],
168
+ net_blocks_cache[0]->stages[0].output_shapes[1]);
169
+ assert(true == ret);
170
+ ret = bmrt_tensor(&present_value_cache, p_bmrt, net_blocks_cache[0]->output_dtypes[2],
171
+ net_blocks_cache[0]->stages[0].output_shapes[2]);
172
+ assert(true == ret);
173
+
174
+ ret = bmrt_tensor(&inputs_lm, p_bmrt, net_lm->input_dtypes[0],
175
+ net_lm->stages[0].input_shapes[0]);
176
+ assert(true == ret);
177
+ ret = bmrt_tensor(&outputs_lm, p_bmrt, net_lm->output_dtypes[0],
178
+ net_lm->stages[0].output_shapes[0]);
179
+ assert(true == ret);
180
+ }
181
+
182
+ void Baichuan2::deinit() {
183
+ bm_free_device(bm_handle, inputs_embed_512.device_mem);
184
+ bm_free_device(bm_handle, outputs_embed_512.device_mem);
185
+ bm_free_device(bm_handle, inputs_lm.device_mem);
186
+ bm_free_device(bm_handle, outputs_lm.device_mem);
187
+ bm_free_device(bm_handle, inputs_pid.device_mem);
188
+ bm_free_device(bm_handle, next_pid.device_mem);
189
+ bm_free_device(bm_handle, inputs_attention.device_mem);
190
+ bm_free_device(bm_handle, next_attention.device_mem);
191
+ bm_free_device(bm_handle, present_key_cache.device_mem);
192
+ bm_free_device(bm_handle, present_value_cache.device_mem);
193
+ for (int i = 0; i < NUM_LAYERS; i++) {
194
+ bm_free_device(bm_handle, past_key[i].device_mem);
195
+ bm_free_device(bm_handle, past_value[i].device_mem);
196
+ bm_free_device(bm_handle, present_key[i].device_mem);
197
+ bm_free_device(bm_handle, present_value[i].device_mem);
198
+ }
199
+ bmrt_destroy(p_bmrt);
200
+ for (auto h : handles) {
201
+ bm_dev_free(h);
202
+ }
203
+ }
204
+
205
+
206
+
207
+ int Baichuan2::forward_first(std::vector<int> &tokens) {
208
+ int input_ids[MAX_LEN] = {0}; // start token
209
+ int position_id[MAX_LEN] = {0};
210
+ float attention_mask[MAX_LEN * MAX_LEN] = {0};
211
+ token_length = tokens.size();
212
+
213
+ std::copy(tokens.begin(), tokens.end(), input_ids);
214
+ for (int i = 0; i < token_length; i++) {
215
+ position_id[i] = i;
216
+ }
217
+
218
+ for (int i = 0; i < MAX_LEN; i++) {
219
+ for (int j = 0; j < MAX_LEN; j++) {
220
+ if (j <= i && i < token_length) {
221
+ } else {
222
+ attention_mask[i * MAX_LEN + j] = ATTENTION_MASK;
223
+ }
224
+ }
225
+ }
226
+
227
+ // forward embeding
228
+ bm_memcpy_s2d(bm_handle, inputs_embed_512.device_mem, (void *)input_ids);
229
+ auto ret =
230
+ bmrt_launch_tensor_ex(p_bmrt, name_embed.c_str(), &inputs_embed_512, 1,
231
+ &outputs_embed_512, 1, true, false);
232
+ assert(ret);
233
+ bm_thread_sync(bm_handle);
234
+
235
+ // forward blocks
236
+ bm_memcpy_s2d(bm_handle, inputs_pid.device_mem, (void *)position_id);
237
+ bm_memcpy_s2d(bm_handle, inputs_attention.device_mem, (void *)attention_mask);
238
+ auto inputs_embed = outputs_embed_512;
239
+ inputs_embed.shape = net_blocks[0]->stages[0].input_shapes[0];
240
+ bm_tensor_t inputs_block[3] = {inputs_embed, inputs_pid, inputs_attention};
241
+ for (int i = 0; i < NUM_LAYERS; i++) {
242
+ bm_tensor_t outputs_block[3] = {inputs_embed, past_key[i], past_value[i]};
243
+ ret = bmrt_launch_tensor_ex(p_bmrt, name_blocks[i].c_str(), inputs_block, 3,
244
+ outputs_block, 3, true, false);
245
+ assert(ret);
246
+ bm_thread_sync(bm_handle);
247
+ }
248
+ int bytes = inputs_embed.device_mem.size / MAX_LEN;
249
+ bm_memcpy_d2d_byte(bm_handle, inputs_lm.device_mem, 0,
250
+ inputs_embed.device_mem, (token_length - 1) * bytes,
251
+ bytes);
252
+ ret = bmrt_launch_tensor_ex(p_bmrt, name_lm.c_str(), &inputs_lm, 1,
253
+ &outputs_lm, 1, true, false);
254
+ bm_thread_sync(bm_handle);
255
+
256
+ int token = 0;
257
+ bm_memcpy_d2s(bm_handle, (void *)&token, outputs_lm.device_mem);
258
+ return token;
259
+ }
260
+
261
+ int Baichuan2::forward_next() {
262
+ float attention_mask[MAX_LEN + 1] = {0};
263
+ for (int i = token_length - 1; i < MAX_LEN; i++) {
264
+ attention_mask[i] = ATTENTION_MASK;
265
+ }
266
+ int32_t position_id = token_length - 1;
267
+ // embedding
268
+ outputs_lm.shape = net_embed->stages[0].input_shapes[0];
269
+ auto ret = bmrt_launch_tensor_ex(p_bmrt, name_embed.c_str(), &outputs_lm, 1,
270
+ &inputs_lm, 1, true, false);
271
+ assert(ret);
272
+ bm_thread_sync(bm_handle);
273
+
274
+ // blocks
275
+ bm_memcpy_s2d(bm_handle, next_attention.device_mem, (void *)attention_mask);
276
+ bm_memcpy_s2d(bm_handle, next_pid.device_mem, (void *)&position_id);
277
+ auto inputs_embed = inputs_lm;
278
+ inputs_embed.shape = net_blocks_cache[0]->stages[0].input_shapes[0];
279
+ int bytes = bm_mem_get_device_size(present_key_cache.device_mem);
280
+ int token_offset = (token_length - 1) * bytes;
281
+ for (int i = 0; i < NUM_LAYERS; i++) {
282
+ bm_tensor_t inputs_block[5] = {inputs_embed, next_pid, next_attention,
283
+ past_key[i], past_value[i]};
284
+ bm_tensor_t outputs_block[3] = {inputs_embed, present_key_cache, present_value_cache};
285
+ ret = bmrt_launch_tensor_ex(p_bmrt, name_blocks_cache[i].c_str(),
286
+ inputs_block, 5, outputs_block, 3, true, false);
287
+ assert(ret);
288
+ bm_thread_sync(bm_handle);
289
+ bm_memcpy_d2d_byte(bm_handle, past_key[i].device_mem, token_offset,
290
+ present_key_cache.device_mem, 0,
291
+ bytes);
292
+ bm_memcpy_d2d_byte(bm_handle, past_value[i].device_mem, token_offset,
293
+ present_value_cache.device_mem, 0,
294
+ bytes);
295
+ }
296
+ outputs_lm.shape = net_lm->stages[0].output_shapes[0];
297
+ ret = bmrt_launch_tensor_ex(p_bmrt, name_lm.c_str(), &inputs_lm, 1,
298
+ &outputs_lm, 1, true, false);
299
+ bm_thread_sync(bm_handle);
300
+
301
+ int token = 0;
302
+ bm_memcpy_d2s(bm_handle, (void *)&token, outputs_lm.device_mem);
303
+ return token;
304
+ }
305
+
306
+
307
+ std::string Baichuan2::predict_first_token(const std::string &input_str) {
308
+ history = input_str;
309
+ //int tok_num = 1;
310
+ std::vector<int> tokens;
311
+ sentencepiece.Encode(history, &tokens);
312
+ tokens.insert(tokens.begin(), 1);
313
+ if (tokens.empty()) {
314
+ round = 0;
315
+ history = "Sorry: your question is too wierd!!\n";
316
+ return history;
317
+ }
318
+ // make sure token not too large
319
+ if (tokens.size() > MAX_LEN - 10) {
320
+ // reset
321
+ if (round == 0) {
322
+ history = "Error: your question is too large!\n";
323
+ return history;
324
+ }
325
+ round = 0;
326
+ history = "";
327
+ return predict_first_token(input_str);
328
+ }
329
+ int token = forward_first(tokens);
330
+ int pre_token = 0;
331
+ std::string pre_word;
332
+ std::string word;
333
+ std::vector<int> pre_ids = {pre_token};
334
+ std::vector<int> ids = {pre_token,token};
335
+ sentencepiece.Decode(pre_ids, &pre_word);
336
+ sentencepiece.Decode(ids, &word);
337
+ std::string diff = word.substr(pre_word.size());
338
+ #ifdef PRINT
339
+ printf("token %d",token);
340
+ printf("diff %s",diff.c_str());
341
+ #endif
342
+ history += diff;
343
+ if (token_length < MAX_LEN) {
344
+ token_length++;
345
+ }
346
+ return diff;
347
+ }
348
+
349
+ std::string Baichuan2::predict_next_token() {
350
+ int pre_token;
351
+ pre_token = 0;
352
+ int token = forward_next();
353
+ if(token == EOS){
354
+ round = 0;
355
+ history = history.substr(history.size()/2);
356
+ return "_GETEOS_";
357
+ }
358
+ std::string pre_word;
359
+ std::string word;
360
+ std::vector<int> pre_ids = {pre_token};
361
+ std::vector<int> ids = {pre_token, token};
362
+ sentencepiece.Decode(pre_ids, &pre_word);
363
+ sentencepiece.Decode(ids, &word);
364
+ std::string diff = word.substr(pre_word.size());
365
+ #ifdef PRINT
366
+ printf("token %d",token);
367
+ printf("diff %s",diff.c_str());
368
+ #endif
369
+ history += diff;
370
+ if (token_length < MAX_LEN) {
371
+ token_length++;
372
+ }else{
373
+ round = 0;
374
+ return "_GETMAX_";
375
+ }
376
+ return diff;
377
+ }
378
+
379
+
380
+ extern "C" {
381
+
382
+
383
+ Baichuan2 *Baichuan2_with_devid_and_model(int devid, const char *bmodel_path, const char *tokenizer_path) {
384
+ Baichuan2 *chat = new Baichuan2();
385
+ chat->init(devid, bmodel_path, tokenizer_path);
386
+ return chat;
387
+ }
388
+
389
+ void Baichuan2_delete(Baichuan2 *chat) { delete chat; }
390
+
391
+ void Baichuan2_deinit(Baichuan2 *chat) {
392
+ chat->deinit();
393
+ }
394
+
395
+ const char *get_history(Baichuan2 *chat) {
396
+ std::string str = chat->history;
397
+ return strdup(str.c_str());
398
+ }
399
+
400
+ const char *set_history(Baichuan2 *chat, const char *history) {
401
+ chat->history = history;
402
+ return strdup(history);
403
+ }
404
+
405
+ const char *Baichuan2_predict_first_token(Baichuan2 *chat, const char *input_str) {
406
+ std::string str = chat->predict_first_token(input_str);
407
+ return strdup(str.c_str());
408
+ }
409
+
410
+ const char *Baichuan2_predict_next_token(Baichuan2 *chat) {
411
+ std::string str = chat->predict_next_token();
412
+ return strdup(str.c_str());
413
+ }
414
+
415
+ const int get_eos(Baichuan2 *chat){
416
+ const int res = chat->EOS;
417
+ return res;
418
+ }
419
+ }
Baichuan2/web_demo/chat.py ADDED
@@ -0,0 +1,97 @@
1
+ # coding=utf-8
2
+
3
+ import ctypes
4
+
5
+
6
+ class TokenWord(ctypes.Structure):
7
+ _fields_ = [
8
+ ("token", ctypes.c_int),
9
+ ("word", ctypes.c_char * 2048) # 假设最大长度为 100,你可以根据实际情况调整
10
+ ]
11
+
12
+
13
+ class TPUChatglm:
14
+ def __init__(self):
15
+ self.lib = ctypes.cdll.LoadLibrary('./build/libtpuchat.so')
16
+ device_id = 3
17
+ bmodel_path = "../model/baichuan2-7b-test_int8.bmodel"
18
+ token_path = "../model/tokenizer.model"
19
+ self.device_id = device_id
20
+ self.bmodel_path = bmodel_path
21
+ self.token_path = token_path
22
+ self.libset()
23
+ self.init()
24
+
25
+ def libset(self):
26
+ self.lib.Baichuan2_with_devid_and_model.argtypes = [ctypes.c_int, ctypes.c_char_p, ctypes.c_char_p]
27
+ self.lib.Baichuan2_with_devid_and_model.restype = ctypes.c_void_p
28
+
29
+ self.lib.Baichuan2_delete.argtypes = [ctypes.c_void_p]
30
+
31
+ # deinit
32
+ self.lib.Baichuan2_deinit.argtypes = [ctypes.c_void_p]
33
+
34
+ # Baichuan2_predict_first_token
35
+ self.lib.Baichuan2_predict_first_token.argtypes = [ctypes.c_void_p, ctypes.c_char_p]
36
+ self.lib.Baichuan2_predict_first_token.restype = ctypes.c_char_p
37
+
38
+ # Baichuan2_predict_next_token
39
+ self.lib.Baichuan2_predict_next_token.argtypes = [ctypes.c_void_p]
40
+ self.lib.Baichuan2_predict_next_token.restype = ctypes.c_char_p
41
+
42
+ # get_eos
43
+ self.lib.get_eos.argtypes = [ctypes.c_void_p]
44
+ self.lib.get_eos.restype = ctypes.c_int
45
+ # get_history
46
+ self.lib.get_history.argtypes = [ctypes.c_void_p]
47
+ self.lib.get_history.restype = ctypes.c_char_p
48
+ # set history
49
+ self.lib.set_history.argtypes = [ctypes.c_void_p, ctypes.c_char_p]
50
+
51
+ def init(self):
52
+ self.obj = self.lib.Baichuan2_with_devid_and_model(self.device_id, self.bmodel_path.encode('utf-8'),
53
+ self.token_path.encode('utf-8'))
54
+
55
+ def predict_first_token(self, context):
56
+ return self.lib.Baichuan2_predict_first_token(self.obj, context.encode('utf-8')).decode('utf-8')
57
+
58
+ def predict_next_token(self):
59
+ return self.lib.Baichuan2_predict_next_token(self.obj).decode('utf-8')
60
+
61
+ def predict(self, context):
62
+
63
+ first_token = self.predict_first_token(context)
64
+ # print(first_token, end='')
65
+ res = ''
66
+ while True:
67
+ next_token = self.predict_next_token()
68
+ if next_token == '_GETMAX_' or next_token == '_GETEOS_':
69
+ # print(next_token)
70
+ break
71
+ # print(next_token, end='')
72
+ res += next_token
73
+ return res
74
+
75
+ def stream_predict(self, query, history):
76
+ history.append((query, ''))
77
+
78
+ prompt = ''
79
+ # for i, (old_query, response) in enumerate(history):
80
+ # prompt += "[Round {}]\n\n问:{}\n\n答:{}\n\n".format(i + 1, old_query, response)
81
+ # prompt += "[Round {}]\n\n问:{}\n\n答:".format(len(history) + 1, query)
82
+ prompt = "<reserved_106>" + query + "<reserved_107>"
83
+
84
+ res = ''
85
+ first_token = self.predict_first_token(prompt)
86
+ res += first_token
87
+
88
+ while True:
89
+ next_token = self.predict_next_token()
90
+ if next_token == '_GETMAX_' or next_token == '_GETEOS_':
91
+ break
92
+ res += next_token
93
+ history[-1] = (query, res)
94
+ yield res, history
95
+
96
+ def get_config(self):
97
+ pass
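For reference, a minimal driver sketch for the ctypes wrapper above. It assumes `./build/libtpuchat.so` has been built and that the bmodel and tokenizer paths hard-coded in `TPUChatglm.__init__` exist on the target machine; the prompts are placeholders.

```python
# Hypothetical usage sketch; paths and device id come from TPUChatglm.__init__ above.
from chat import TPUChatglm

model = TPUChatglm()

# Blocking generation: full answer in one call, using the Baichuan2 chat template.
print(model.predict("<reserved_106>" + "Hello, who are you?" + "<reserved_107>"))

# Streaming generation, as consumed by web_demo.py.
history = []
for partial, history in model.stream_predict("Introduce yourself", history):
    print(partial)
```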
Baichuan2/web_demo/web_demo.py ADDED
@@ -0,0 +1,108 @@
1
+ import time
2
+ import gradio as gr
3
+ import mdtex2html
4
+ from chat import TPUChatglm
5
+
6
+
7
+ def postprocess(self, y):
8
+ if y is None:
9
+ return []
10
+ for i, (message, response) in enumerate(y):
11
+ y[i] = (
12
+ None if message is None else mdtex2html.convert((message)),
13
+ None if response is None else mdtex2html.convert(response),
14
+ )
15
+ return y
16
+
17
+
18
+ gr.Chatbot.postprocess = postprocess
19
+
20
+ glm = TPUChatglm()
21
+
22
+ def parse_text(text):
23
+ """copy from https://github.com/GaiZhenbiao/ChuanhuChatGPT/"""
24
+ lines = text.split("\n")
25
+ lines = [line for line in lines if line != ""]
26
+ count = 0
27
+ for i, line in enumerate(lines):
28
+ if "```" in line:
29
+ count += 1
30
+ items = line.split('`')
31
+ if count % 2 == 1:
32
+ lines[i] = f'<pre><code class="language-{items[-1]}">'
33
+ else:
34
+ lines[i] = f'<br></code></pre>'
35
+ else:
36
+ if i > 0:
37
+ if count % 2 == 1:
38
+ line = line.replace("`", "\`")
39
+ line = line.replace("<", "&lt;")
40
+ line = line.replace(">", "&gt;")
41
+ line = line.replace(" ", "&nbsp;")
42
+ line = line.replace("*", "&ast;")
43
+ line = line.replace("_", "&lowbar;")
44
+ line = line.replace("-", "&#45;")
45
+ line = line.replace(".", "&#46;")
46
+ line = line.replace("!", "&#33;")
47
+ line = line.replace("(", "&#40;")
48
+ line = line.replace(")", "&#41;")
49
+ line = line.replace("$", "&#36;")
50
+ lines[i] = "<br>" + line
51
+ text = "".join(lines)
52
+ return text
53
+
54
+
55
+ def gen(input, history):
56
+ i = 0
57
+ history.append((input, ''))
58
+ res = ''
59
+ while i < 10:
60
+ i += 1
61
+ res += str(i)
62
+ time.sleep(0.05)
63
+ history[-1] = (input, res)
64
+ yield res, history
65
+
66
+
67
+ def predict(input, chatbot, max_length, top_p, temperature, history):
68
+
69
+ chatbot.append((parse_text(input), ""))
70
+ for response, history in glm.stream_predict(input, history):
71
+ chatbot[-1] = (parse_text(input), parse_text(response))
72
+ yield chatbot, history
73
+
74
+
75
+ def reset_user_input():
76
+ return gr.update(value='')
77
+
78
+
79
+ def reset_state():
80
+ return [], [], None
81
+
82
+
83
+ with gr.Blocks() as demo:
84
+ gr.HTML("""<h1 align="center">Baichuan2-7B TPU</h1>""")
85
+
86
+ chatbot = gr.Chatbot()
87
+ with gr.Row():
88
+ with gr.Column(scale=4):
89
+ with gr.Column(scale=12):
90
+ user_input = gr.Textbox(show_label=False, placeholder="Input...", lines=10).style(
91
+ container=False)
92
+ with gr.Column(min_width=32, scale=1):
93
+ submitBtn = gr.Button("Submit", variant="primary")
94
+ with gr.Column(scale=1):
95
+ emptyBtn = gr.Button("Clear History")
96
+ max_length = gr.Slider(0, 32768, value=8192, step=1.0, label="Maximum length", interactive=True)
97
+ top_p = gr.Slider(0, 1, value=0.8, step=0.01, label="Top P", interactive=True)
98
+ temperature = gr.Slider(0, 1, value=0.95, step=0.01, label="Temperature", interactive=True)
99
+
100
+ history = gr.State([])
101
+
102
+ submitBtn.click(predict, [user_input, chatbot, max_length, top_p, temperature, history],
103
+ [chatbot, history], show_progress=True)
104
+ submitBtn.click(reset_user_input, [], [user_input])
105
+
106
+ emptyBtn.click(reset_state, outputs=[chatbot, history], show_progress=True)
107
+
108
+ demo.queue().launch(share=True, server_name="0.0.0.0", inbrowser=True)
BaseModel/base_model.py ADDED
@@ -0,0 +1,184 @@
1
+ import time
2
+ from transformers import AutoTokenizer
3
+
4
+
5
+ class BaseModel:
6
+ def __init__(self, args):
7
+ # parameters
8
+ self.EOS = None
9
+ self.SEQLEN = None
10
+ self.input_str = ""
11
+ self.system_prompt = ""
12
+ self.history = []
13
+
14
+ # devid
15
+ self.devices = [int(d) for d in args.devid.split(",")]
16
+
17
+ # load tokenizer
18
+ print("Load " + args.tokenizer_path + " ...")
19
+ self.tokenizer = AutoTokenizer.from_pretrained(
20
+ args.tokenizer_path, trust_remote_code=True
21
+ )
22
+
23
+ # warm up
24
+ self.tokenizer.decode([0])
25
+ print("Done!")
26
+
27
+ def chat(self):
28
+ """
29
+ Start a chat session.
30
+ """
31
+ # check
32
+ if not self.EOS:
33
+ raise NotImplementedError("EOS (end-of-sentence token id) has not been set")
34
+ if not self.SEQLEN:
35
+ raise NotImplementedError("SEQLEN (maximum sequence length) has not been set")
36
+
37
+ # Instruct
38
+ print(
39
+ """\n===========================================================
40
+ 1. If you want to quit, please enter one of [q, quit, exit]
41
+ 2. To create a new chat session, please enter one of [clear, new]
42
+ ==========================================================="""
43
+ )
44
+ # Stop Chatting with "exit" input
45
+ while True:
46
+ self.input_str = input("\nQuestion: ")
47
+ # Quit
48
+ if self.input_str in ["exit", "q", "quit"]:
49
+ break
50
+ # New Chat
51
+ elif self.input_str in ["clear", "new"]:
52
+ self.clear()
53
+ # Chat
54
+ else:
55
+ tokens = self.encode_tokens()
56
+
57
+ # check tokens
58
+ if not tokens:
59
+ print("Sorry: your question is empty!!")
60
+ return
61
+ if len(tokens) > self.SEQLEN:
62
+ print(
63
+ "The maximum question length should be shorter than {} but we get {} instead.".format(
64
+ self.SEQLEN, len(tokens)
65
+ )
66
+ )
67
+ return
68
+
69
+ print("\nAnswer: ", end="")
70
+ self.stream_answer(tokens)
71
+
72
+ def stream_answer(self, tokens):
73
+ """
74
+ Stream the answer for the given tokens.
75
+ """
76
+ tok_num = 0
77
+ self.answer_cur = ""
78
+ self.answer_token = []
79
+
80
+ # First token
81
+ first_start = time.time()
82
+ token = self.forward_first(tokens)
83
+ first_end = time.time()
84
+ # Following tokens
85
+ while token != self.EOS and self.model.token_length < self.SEQLEN:
86
+ pre_word = self.decode_tokens([token])
87
+ word = self.decode_tokens([token, token])[len(pre_word):]
88
+ self.answer_token += [token]
89
+ print(word, flush=True, end="")
90
+ tok_num += 1
91
+ token = self.forward_next()
92
+ self.answer_cur = self.tokenizer.decode(self.answer_token)
93
+
94
+ # counting time
95
+ next_end = time.time()
96
+ first_duration = first_end - first_start
97
+ next_duration = next_end - first_end
98
+ tps = tok_num / next_duration
99
+
100
+ self.update_history()
101
+
102
+ print()
103
+ print(f"FTL: {first_duration:.3f} s")
104
+ print(f"TPS: {tps:.3f} token/s")
105
+
106
+ def stream_predict(self, query):
107
+ """
108
+ Stream the prediction for the given query.
109
+ """
110
+ self.answer_cur = ""
111
+ self.input_str = query
112
+ tokens = self.encode_tokens()
113
+
114
+ for answer_cur, history in self._generate_predictions(tokens):
115
+ yield answer_cur, history
116
+
117
+ def _generate_predictions(self, tokens):
118
+ """
119
+ Generate predictions for the given tokens.
120
+ """
121
+ # First token
122
+ next_token = self.forward_first(tokens)
123
+ output_tokens = [next_token]
124
+
125
+ # Following tokens
126
+ while True:
127
+ next_token = self.forward_next()
128
+ if next_token == self.EOS:
129
+ break
130
+ output_tokens += [next_token]
131
+ self.answer_cur = self.tokenizer.decode(output_tokens)
132
+ if self.model.token_length >= self.SEQLEN:
133
+ self.update_history()
134
+ yield self.answer_cur + "\n\n\nReached the maximum length; The history context has been cleared.", self.history
135
+ break
136
+ else:
137
+ yield self.answer_cur, self.history
138
+
139
+ self.update_history()
140
+
141
+ def forward_first(self, tokens):
142
+ """
143
+ Forward the first token.
144
+ """
145
+ token = self.model.forward_first(tokens)
146
+ return token
147
+
148
+ def forward_next(self):
149
+ """
150
+ Forward the next token.
151
+ """
152
+ token = self.model.forward_next()
153
+ return token
154
+
155
+ def decode_tokens(self, token):
156
+ """
157
+ Decode the given token.
158
+ """
159
+ word = self.tokenizer.decode(token, skip_special_tokens=True)
160
+ return word
161
+
162
+ def encode_tokens(self):
163
+ """
164
+ Encode the input string to tokens.
165
+ """
166
+ raise NotImplementedError
167
+
168
+ def load_model(self):
169
+ """
170
+ Load the model.
171
+ """
172
+ raise NotImplementedError
173
+
174
+ def clear(self):
175
+ """
176
+ Clear the chat session.
177
+ """
178
+ raise NotImplementedError
179
+
180
+ def update_history(self):
181
+ """
182
+ Update chat history.
183
+ """
184
+ raise NotImplementedError
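BaseModel leaves `encode_tokens`, `load_model`, `clear` and `update_history` abstract; each model directory in this repo supplies its own concrete binding. The sketch below is a hypothetical subclass showing how those hooks might be filled in; the `chat` extension module and its `Model` class, as well as `args.model_path`, are illustrative assumptions, not an API defined by this repo.

```python
class DemoModel(BaseModel):
    def __init__(self, args):
        super().__init__(args)
        self.model_path = args.model_path        # assumed CLI argument
        self.EOS = self.tokenizer.eos_token_id   # end-of-sentence token id
        self.SEQLEN = 512                        # must match the compiled bmodel
        self.load_model()

    def load_model(self):
        import chat  # assumed pybind11 module wrapping the bmodel runtime
        self.model = chat.Model(self.model_path, self.devices)

    def encode_tokens(self):
        # Prepend the accumulated history, then tokenize the current question.
        prompt = self.system_prompt
        for question, answer in self.history:
            prompt += question + answer
        prompt += self.input_str
        return self.tokenizer.encode(prompt)

    def clear(self):
        self.history = []

    def update_history(self):
        self.history.append((self.input_str, self.answer_cur))
```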
ChatGLM2/README.md ADDED
@@ -0,0 +1,160 @@
+ ![](./assets/sophgo_chip.png)
+
+ # ChatGLM2
+
+ This project deploys the large language model [ChatGLM2-6B](https://huggingface.co/THUDM/chatglm2-6b) on the BM1684X. The model is converted into a bmodel with the [TPU-MLIR](https://github.com/sophgo/tpu-mlir) compiler, and C++ code is used to deploy it in a BM1684X PCIE or SoC environment.
+
+
+ A walkthrough of `ChatGLM` was published on Zhihu to make the source code easier to follow:
+
+ [ChatGLM2 pipeline walkthrough and TPU-MLIR deployment](https://zhuanlan.zhihu.com/p/641975976)
+
+
+ ## Development Environment
+
+
+ 1. Download Docker and start a container, as follows:
+
+ ``` shell
+ docker pull sophgo/tpuc_dev:latest
+
+ # myname1234 is just an example, you can set your own name
+ docker run --privileged --name myname1234 -v $PWD:/workspace -it sophgo/tpuc_dev:latest
+ ```
+ The rest of this document assumes everything lives in the container's `/workspace` directory.
+
+
+ 2. Download `ChatGLM2-6B` from Hugging Face. It is large, so this will take a while:
+
+ ``` shell
+ git lfs install
+ git clone git@hf.co:THUDM/chatglm2-6b
+ ```
+ Then copy config.json and modeling_chatglm.py from ./models/ChatGLM2/compile/files/chatglm2-6b of this project into the downloaded folder, overwriting the files with the same names. (Users who need a different sequence length should see the [FAQ](#faq); the default sequence length is 512.)
+
+ 3. Download and build the `TPU-MLIR` code (alternatively, download the prebuilt release package and extract it).
+
+ Because mlir is still under maintenance, users compiling GLM-series models should download:
+ ``` shell
+ pip3 install dfss
+ python3 -m dfss --url=open@sophgo.com:/ext_model_information/LLM/mlir_club/glm_mlir.tar.gz
+ tar -xf glm_mlir.tar.gz
+ source tpu-mlir_v1.6.45-gdc3e9f6b-20231220/envsetup.sh
+ ```
+
+ Once the mlir maintenance is finished, the following approach can be used instead:
+ ``` shell
+ git clone git@github.com:sophgo/tpu-mlir.git
+ cd tpu-mlir
+ source ./envsetup.sh
+ ./build.sh
+ ```
+
+ ## Compiling the Model
+
+ 1. Export all of the ONNX models. If a component is reported missing along the way, simply `pip3 install` it.
+
+ ``` shell
+ cd compile
+ python3 export_onnx.py --model_path your_chatglm2-6b_path
+ ```
+ At this point a large number of ONNX models have been exported to the tmp directory.
+
+ 2. Compile the ONNX models.
+
+ TPU-MLIR currently supports F16, INT8 and INT4 quantization for ChatGLM2, as well as distributed multi-chip inference. By default it performs F16 quantization and single-chip inference, finally producing the file `chatglm2-6b_f16_1dev.bmodel`.
+
+ ```shell
+ ./compile.sh --name chatglm2-6b --mode inference_mode --num_device device_number
+ ```
+
+ where:
+ `--name` is the model name, which must be `chatglm2-6b` here;
+ `--mode` is the data type used for inference; choose one of `f16, int8, int4` (default `f16`);
+ `--num_device` is the number of chips used for inference; set it according to the devices actually in use (default `--num_device 1`).
+
+ ## Building the Program (C++ Version)
+
+ Build as follows (identical for PCIE and SoC):
+
+ ```shell
+ cd demo
+ mkdir build
+ cd build
+ cmake ..
+ make
+ ```
+
+ This produces the chatglm executable. Put `chatglm` in the demo directory, then specify the chip(s) and the bmodel path as shown below.
+ Run `chatglm`; by default it runs `chatglm2-6b_f16_1dev.bmodel` on a single chip:
+ ```shell
+ ./chatglm --model chatglm2-6b_f16_1dev.bmodel --tokenizer ../support/tokenizer/tokenizer.model --devid your_devid
+ ```
+ Here `--devid` is the index of the TPU used for inference (default 0). For multi-chip inference (the compiled bmodel must also be multi-chip), additional chips can be listed with `,`; for example `--devid 2,3` means TPU2 and TPU3 are used for inference.
+
+ ## Example Output
+
+ Below is the output of the INT8-quantized model on a single chip:
+
+ ![](./assets/chatglm.jpg)
+
+ ## FAQ
+
+ #### Where does sentencepiece come from?
+
+ A prebuilt copy ships with the project, so there is no need to build it yourself. If you are curious, the steps are as follows.
+
+ Download [sentencepiece](https://github.com/google/sentencepiece) and build it to obtain `libsentencepiece.a`:
+
+ ```shell
+ git clone git@github.com:google/sentencepiece.git
+ cd sentencepiece
+ mkdir build
+ cd build
+ cmake ..
+ make -j
+ ```
+
+ To build for the SoC environment, follow the demo's build approach and specify the cross compiler in the makefile.
+
+ #### The demo program does not run
+
+ If the demo program fails to run after being copied to the target environment, for example with missing-symbol or missing-interface errors,
+ the cause is that the libraries on the target differ. Copy the .so files from `./support/lib_pcie` (PCIE) or `./support/lib_soc` (SoC) in the demo to the target environment and link against those.
+
+
+ #### What was changed in the original source:
+
+ Three changes in total:
+ - `seq_length` in `config.json` is set to 512;
+ - In `modeling_chatglm.py`, the following code:
+
+ ```python
+ if attention_mask is not None:
+     attention_scores = attention_scores.masked_fill(attention_mask, float("-inf"))
+ ```
+
+ was changed to:
+
+ ```python
+ if attention_mask is not None:
+     attention_scores = attention_scores + (attention_mask * -10000.0)
+ ```
+
+ This change improves efficiency, since `masked_fill` is slow; in addition, exporting `masked_fill` to ONNX has some bugs.
+
+ - In `modeling_chatglm.py`, the following code:
+
+ ```python
+ pytorch_major_version = int(torch.__version__.split('.')[0])
+ if pytorch_major_version >= 2:
+ ```
+
+ was changed to:
+
+ ```python
+ pytorch_major_version = int(torch.__version__.split('.')[0])
+ if False:
+ ```
+
+ This is because ONNX export does not support the `torch.nn.functional.scaled_dot_product_attention` operator.
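The additive-mask rewrite above can be checked numerically without any ChatGLM code; here is a small standalone PyTorch sketch (the tensor shapes are chosen only for illustration):

```python
import torch

scores = torch.randn(1, 1, 4, 4)                             # toy attention scores
mask = torch.ones(4, 4, dtype=torch.bool).triu(diagonal=1)   # causal mask, True = blocked

probs_fill = torch.softmax(scores.masked_fill(mask, float("-inf")), dim=-1)
probs_add = torch.softmax(scores + mask * -10000.0, dim=-1)

# Both variants suppress the masked positions; the additive form exports cleanly to ONNX.
print(torch.allclose(probs_fill, probs_add, atol=1e-4))       # True
```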
ChatGLM2/compile/compile.sh ADDED
@@ -0,0 +1,179 @@
1
+ #!/bin/bash
2
+ set -ex
3
+ models=
4
+ mode="f16"
5
+ folder="tmp"
6
+ num_device=1
7
+ mode_args=""
8
+ device_args=""
9
+ quantize_args="--quantize F16"
10
+ name=""
11
+ num_layers=
12
+ out_model=$name.bmodel
13
+
14
+ while [[ $# -gt 0 ]]; do
15
+ key="$1"
16
+
17
+ case $key in
18
+ --mode)
19
+ mode="$2"
20
+ shift 2
21
+ ;;
22
+ --num_device)
23
+ num_device="$2"
24
+ shift 2
25
+ ;;
26
+ --name)
27
+ name="$2"
28
+ shift 2
29
+ ;;
30
+ *)
31
+ echo "Invalid option: $key" >&2
32
+ exit 1
33
+ ;;
34
+ :)
35
+ echo "Option -$OPTARG requires an argument." >&2
36
+ exit 1
37
+ ;;
38
+ esac
39
+ done
40
+
41
+ if [ "$name" = "chatglm2-6b" ]; then
42
+ num_layers=27
43
+ echo "Compile ChatGLM2-6B"
44
+ else
45
+ >&2 echo -e "Error: Invalid name $name, the input name must be \033[31mchatglm2-6b\033[0m"
46
+ exit 1
47
+ fi
48
+
49
+ if [ x$mode == x"int8" ]; then
50
+ quantize_args="--quantize W8F16"
51
+ elif [ x$mode == x"f16" ]; then
52
+ quantize_args="--quantize F16"
53
+ elif [ x$mode == x"int4" ]; then
54
+ quantize_args="--quantize W4F16 --q_group_size 64"
55
+ else
56
+ echo "Error, unknown quantize mode"
57
+ exit 1
58
+ fi
59
+
60
+ if [ x$num_device != x1 ]; then
61
+ device_args="--num_device $num_device"
62
+ out_model=$name'_'$mode'_'$num_device'dev.bmodel'
63
+ else
64
+ out_model=$name'_'$mode'_1dev.bmodel'
65
+ fi
66
+
67
+ outdir=${folder}/embedding
68
+ mkdir -p $outdir
69
+ pushd $outdir
70
+
71
+ model_transform.py \
72
+ --model_name embedding \
73
+ --model_def ../onnx/embedding.onnx \
74
+ --mlir embedding.mlir
75
+
76
+
77
+ model_deploy.py \
78
+ --mlir embedding.mlir \
79
+ --quantize F16 \
80
+ --quant_input \
81
+ --quant_output \
82
+ --chip bm1684x \
83
+ $device_args \
84
+ --model embedding.bmodel
85
+
86
+ model_transform.py \
87
+ --model_name embedding_cache \
88
+ --model_def ../onnx/embedding.onnx \
89
+ --input_shapes [[1,1]] \
90
+ --mlir embedding_cache.mlir
91
+
92
+
93
+ model_deploy.py \
94
+ --mlir embedding_cache.mlir \
95
+ --quantize F16 \
96
+ --quant_input \
97
+ --quant_output \
98
+ --chip bm1684x \
99
+ $device_args \
100
+ --model embedding_cache.bmodel
101
+
102
+ rm *.npz
103
+
104
+ models=$models' '$outdir'/embedding.bmodel '$outdir'/embedding_cache.bmodel '
105
+
106
+ popd
107
+
108
+ echo $models
109
+
110
+ outdir=tmp/$mode"_"$num_device"dev"/lm_head
111
+ mkdir -p $outdir
112
+ pushd $outdir
113
+
114
+ model_transform.py \
115
+ --model_name lm_head \
116
+ --model_def ../../onnx/lm_head.onnx \
117
+ --mlir lm_head.mlir
118
+
119
+ model_deploy.py \
120
+ --mlir lm_head.mlir \
121
+ $quantize_args \
122
+ --quant_input \
123
+ --quant_output \
124
+ --chip bm1684x \
125
+ $device_args \
126
+ --model lm_head.bmodel
127
+
128
+ rm *.npz
129
+
130
+ models=${models}${outdir}'/lm_head.bmodel '
131
+ popd
132
+
133
+ echo $models
134
+
135
+ outdir=tmp/$mode"_"$num_device"dev"/block
136
+ mkdir -p $outdir
137
+
138
+ pushd $outdir
139
+ mkdir -p $outdir
140
+
141
+ for ((i=0; i<=$num_layers; i++)); do
142
+
143
+ model_transform.py \
144
+ --model_name block_$i \
145
+ --model_def ../../onnx/block_$i.onnx \
146
+ --mlir block_$i.mlir
147
+
148
+ model_deploy.py \
149
+ --mlir block_$i.mlir \
150
+ $quantize_args \
151
+ --quant_input \
152
+ --quant_output \
153
+ --chip bm1684x \
154
+ $device_args \
155
+ --model block_$i.bmodel
156
+
157
+ model_transform.py \
158
+ --model_name block_cache_$i \
159
+ --model_def ../../onnx/block_cache_$i.onnx \
160
+ --mlir block_cache_$i.mlir
161
+
162
+ model_deploy.py \
163
+ --mlir block_cache_$i.mlir \
164
+ $quantize_args \
165
+ --quant_input \
166
+ --quant_output \
167
+ --chip bm1684x \
168
+ $device_args \
169
+ --model block_cache_$i.bmodel
170
+
171
+ rm *.npz
172
+
173
+ models=${models}${outdir}'/block_'$i'.bmodel '$outdir'/block_cache_'$i'.bmodel '
174
+
175
+ done
176
+ popd
177
+ echo $models
178
+
179
+ model_tool --combine $models -o $out_model
ChatGLM2/compile/export_onnx.py ADDED
@@ -0,0 +1,176 @@
1
+ #!/usr/bin/env python3
2
+ # ==============================================================================
3
+ #
4
+ # Copyright (C) 2023 Sophgo Technologies Inc. All rights reserved.
5
+ #
6
+ # TPU-MLIR is licensed under the 2-Clause BSD License except for the
7
+ # third-party components.
8
+ #
9
+ # ==============================================================================
10
+
11
+ import os
12
+ import torch
13
+ import argparse
14
+ from tqdm import tqdm
15
+ from transformers import AutoModel, AutoTokenizer
16
+
17
+ parser = argparse.ArgumentParser(description='export onnx.')
18
+ parser.add_argument('--model_path', type=str, help='path to the torch model.')
19
+
20
+ args = parser.parse_args()
21
+
22
+ model_path = args.model_path
23
+ folder = f"./tmp/onnx"
24
+
25
+ origin_model = AutoModel.from_pretrained(
26
+ model_path, trust_remote_code=True).float().eval()
27
+
28
+ for param in origin_model.parameters():
29
+ param.requires_grad = False
30
+
31
+ config = origin_model.config
32
+ transformer = origin_model.transformer
33
+ layers = transformer.encoder.layers
34
+
35
+ SEQ_LENGTH = transformer.seq_length
36
+ NUM_LAYERS = config.num_layers
37
+ HIDDEN_SIZE = config.hidden_size
38
+ NUM_ATTENTION_HEADS = config.num_attention_heads
39
+ HEAD_DIM = HIDDEN_SIZE // NUM_ATTENTION_HEADS
40
+
41
+ print(f'Layers: {NUM_LAYERS}\nHidden size: {HIDDEN_SIZE}\n')
42
+
43
+ tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
44
+
45
+ class Embedding(torch.nn.Module):
46
+
47
+ def __init__(self):
48
+ super().__init__()
49
+
50
+ def forward(self, input_ids):
51
+ return transformer.embedding.word_embeddings(input_ids)
52
+
53
+
54
+ class Block(torch.nn.Module):
55
+
56
+ def __init__(self, layer_id):
57
+ super().__init__()
58
+ self.layer_id = layer_id
59
+ self.layer = layers[layer_id]
60
+
61
+ def forward(self, hidden_states, position_ids, attention_mask):
62
+ rotary_pos_emb = transformer.rotary_pos_emb(SEQ_LENGTH)[position_ids]
63
+ rotary_pos_emb = rotary_pos_emb.transpose(0, 1).contiguous()
64
+ hidden_states, past_kv = self.layer(hidden_states,
65
+ attention_mask,
66
+ rotary_pos_emb=rotary_pos_emb)
67
+ return hidden_states, past_kv
68
+
69
+
70
+ class BlockCache(torch.nn.Module):
71
+
72
+ def __init__(self, layer_id):
73
+ super().__init__()
74
+ self.layer_id = layer_id
75
+ self.layer = layers[layer_id]
76
+
77
+ def forward(self, hidden_states, position_ids, attention_mask, past_k,
78
+ past_v):
79
+ rotary_pos_emb = transformer.rotary_pos_emb(SEQ_LENGTH)[position_ids]
80
+ rotary_pos_emb = rotary_pos_emb.transpose(0, 1).contiguous()
81
+ hidden_states, past_kv = self.layer(hidden_states,
82
+ attention_mask,
83
+ kv_cache=(past_k, past_v),
84
+ rotary_pos_emb=rotary_pos_emb)
85
+ present_k, present_v = past_kv
86
+ return hidden_states, present_k[1:], present_v[1:]
87
+
88
+
89
+ class LmHead(torch.nn.Module):
90
+
91
+ def __init__(self):
92
+ super().__init__()
93
+
94
+ def forward(self, hidden_states):
95
+ hidden_states = transformer.encoder.final_layernorm(hidden_states)
96
+ m_logits = transformer.output_layer(hidden_states)
97
+ _, token = torch.topk(m_logits, 1)
98
+ return token
99
+
100
+
101
+ def convert_block(layer_id):
102
+ model = Block(layer_id)
103
+ hidden_states = torch.randn((SEQ_LENGTH, 1, HIDDEN_SIZE))
104
+ position_ids = torch.tensor([range(SEQ_LENGTH)], dtype=torch.long)
105
+ attention_mask = -1000 * torch.ones((1, 1, SEQ_LENGTH, SEQ_LENGTH), dtype=torch.float32).triu(diagonal=1)
106
+ torch.onnx.export(
107
+ model, (hidden_states, position_ids, attention_mask),
108
+ f'{folder}/block_{layer_id}.onnx',
109
+ verbose=False,
110
+ input_names=['input_states', 'position_ids', 'attention_mask'],
111
+ output_names=['hidden_states', 'past_k', 'past_v'],
112
+ do_constant_folding=True,
113
+ opset_version=15)
114
+
115
+
116
+ def convert_block_cache(layer_id):
117
+ model = BlockCache(layer_id)
118
+ hidden_states = torch.randn((1, 1, HIDDEN_SIZE))
119
+ position_ids = torch.tensor([range(1)], dtype=torch.long)
120
+ attention_mask = -1000 * torch.ones((1, 1, 1, SEQ_LENGTH + 1), dtype=torch.float32).triu(diagonal=1)
121
+ past_k = torch.randn((SEQ_LENGTH, 1, 2, HEAD_DIM))
122
+ past_v = torch.randn((SEQ_LENGTH, 1, 2, HEAD_DIM))
123
+
124
+ torch.onnx.export(
125
+ model, (hidden_states, position_ids, attention_mask, past_k, past_v),
126
+ f'{folder}/block_cache_{layer_id}.onnx',
127
+ verbose=False,
128
+ input_names=[
129
+ 'input_states', 'position_ids', 'attention_mask', 'history_k',
130
+ 'history_v'
131
+ ],
132
+ output_names=['hidden_states', 'past_k', 'past_v'],
133
+ do_constant_folding=True,
134
+ opset_version=15)
135
+
136
+
137
+ def convert_embedding():
138
+ model = Embedding()
139
+ input_ids = torch.tensor([range(SEQ_LENGTH)])
140
+
141
+ torch.onnx.export(model, (input_ids),
142
+ f'{folder}/embedding.onnx',
143
+ verbose=False,
144
+ input_names=['input_ids'],
145
+ output_names=['input_embed'],
146
+ do_constant_folding=True,
147
+ opset_version=15)
148
+
149
+
150
+ def convert_lm_head():
151
+ model = LmHead()
152
+ input = torch.randn(1, HIDDEN_SIZE)
153
+
154
+ torch.onnx.export(model, (input),
155
+ f'{folder}/lm_head.onnx',
156
+ verbose=False,
157
+ input_names=['hidden_states'],
158
+ output_names=['token'],
159
+ do_constant_folding=True,
160
+ opset_version=15)
161
+
162
+ # create folder to store onnx
163
+ if not os.path.exists(folder):
164
+ os.makedirs(folder)
165
+
166
+ # export models
167
+ print(f'Convert block & block_cache')
168
+ for i in tqdm(range(NUM_LAYERS)):
169
+ convert_block(i)
170
+ convert_block_cache(i)
171
+
172
+ print(f'Convert embedding')
173
+ convert_embedding()
174
+
175
+ print(f'Convert lm_head')
176
+ convert_lm_head()
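Before handing the exported graphs to compile.sh, they can optionally be sanity-checked. The sketch below uses onnxruntime (an extra dependency, not required by the export script itself) to run one exported block with random inputs; the shapes match the default seq_length of 512 and hidden_size of 4096 for ChatGLM2-6B.

```python
# Optional sanity-check sketch; assumes `onnxruntime` and `numpy` are installed
# and that export_onnx.py has already written tmp/onnx/block_0.onnx.
import numpy as np
import onnxruntime as ort

SEQ, HIDDEN = 512, 4096                      # must match SEQ_LENGTH / HIDDEN_SIZE above
sess = ort.InferenceSession("tmp/onnx/block_0.onnx")

inputs = {
    "input_states": np.random.randn(SEQ, 1, HIDDEN).astype(np.float32),
    "position_ids": np.arange(SEQ, dtype=np.int64)[None, :],
    "attention_mask": np.triu(np.full((1, 1, SEQ, SEQ), -1000.0, dtype=np.float32), k=1),
}
outs = sess.run(None, inputs)
print([o.shape for o in outs])               # hidden_states, past_k, past_v
```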
ChatGLM2/compile/files/chatglm2-6b/config.json ADDED
@@ -0,0 +1,42 @@
1
+ {
2
+ "_name_or_path": "THUDM/chatglm2-6b",
3
+ "model_type": "chatglm",
4
+ "architectures": [
5
+ "ChatGLMModel"
6
+ ],
7
+ "auto_map": {
8
+ "AutoConfig": "configuration_chatglm.ChatGLMConfig",
9
+ "AutoModel": "modeling_chatglm.ChatGLMForConditionalGeneration",
10
+ "AutoModelForCausalLM": "modeling_chatglm.ChatGLMForConditionalGeneration",
11
+ "AutoModelForSeq2SeqLM": "modeling_chatglm.ChatGLMForConditionalGeneration",
12
+ "AutoModelForSequenceClassification": "modeling_chatglm.ChatGLMForSequenceClassification"
13
+ },
14
+ "add_bias_linear": false,
15
+ "add_qkv_bias": true,
16
+ "apply_query_key_layer_scaling": true,
17
+ "apply_residual_connection_post_layernorm": false,
18
+ "attention_dropout": 0.0,
19
+ "attention_softmax_in_fp32": true,
20
+ "bias_dropout_fusion": true,
21
+ "ffn_hidden_size": 13696,
22
+ "fp32_residual_connection": false,
23
+ "hidden_dropout": 0.0,
24
+ "hidden_size": 4096,
25
+ "kv_channels": 128,
26
+ "layernorm_epsilon": 1e-05,
27
+ "multi_query_attention": true,
28
+ "multi_query_group_num": 2,
29
+ "num_attention_heads": 32,
30
+ "num_layers": 28,
31
+ "original_rope": true,
32
+ "padded_vocab_size": 65024,
33
+ "post_layer_norm": true,
34
+ "rmsnorm": true,
35
+ "seq_length": 512,
36
+ "use_cache": true,
37
+ "torch_dtype": "float16",
38
+ "transformers_version": "4.27.1",
39
+ "tie_word_embeddings": false,
40
+ "eos_token_id": 2,
41
+ "pad_token_id": 0
42
+ }
ChatGLM2/compile/files/chatglm2-6b/modeling_chatglm.py ADDED
@@ -0,0 +1,1285 @@
1
+ """ PyTorch ChatGLM model. """
2
+
3
+ import math
4
+ import copy
5
+ import warnings
6
+ import re
7
+ import sys
8
+
9
+ import torch
10
+ import torch.utils.checkpoint
11
+ import torch.nn.functional as F
12
+ from torch import nn
13
+ from torch.nn import CrossEntropyLoss, LayerNorm
14
+ from torch.nn import CrossEntropyLoss, LayerNorm, MSELoss, BCEWithLogitsLoss
15
+ from torch.nn.utils import skip_init
16
+ from typing import Optional, Tuple, Union, List, Callable, Dict, Any
17
+
18
+ from transformers.modeling_outputs import (
19
+ BaseModelOutputWithPast,
20
+ CausalLMOutputWithPast,
21
+ SequenceClassifierOutputWithPast,
22
+ )
23
+ from transformers.modeling_utils import PreTrainedModel
24
+ from transformers.utils import logging
25
+ from transformers.generation.logits_process import LogitsProcessor
26
+ from transformers.generation.utils import LogitsProcessorList, StoppingCriteriaList, GenerationConfig, ModelOutput
27
+
28
+ from .configuration_chatglm import ChatGLMConfig
29
+
30
+ # flags required to enable jit fusion kernels
31
+
32
+ if sys.platform != 'darwin':
33
+ torch._C._jit_set_profiling_mode(False)
34
+ torch._C._jit_set_profiling_executor(False)
35
+ torch._C._jit_override_can_fuse_on_cpu(True)
36
+ torch._C._jit_override_can_fuse_on_gpu(True)
37
+
38
+ logger = logging.get_logger(__name__)
39
+
40
+ _CHECKPOINT_FOR_DOC = "THUDM/ChatGLM2-6B"
41
+ _CONFIG_FOR_DOC = "ChatGLM6BConfig"
42
+
43
+ CHATGLM_6B_PRETRAINED_MODEL_ARCHIVE_LIST = [
44
+ "THUDM/chatglm2-6b",
45
+ # See all ChatGLM models at https://huggingface.co/models?filter=chatglm
46
+ ]
47
+
48
+
49
+ def default_init(cls, *args, **kwargs):
50
+ return cls(*args, **kwargs)
51
+
52
+
53
+ class InvalidScoreLogitsProcessor(LogitsProcessor):
54
+ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
55
+ if torch.isnan(scores).any() or torch.isinf(scores).any():
56
+ scores.zero_()
57
+ scores[..., 5] = 5e4
58
+ return scores
59
+
60
+
61
+ class PrefixEncoder(torch.nn.Module):
62
+ """
63
+ The torch.nn model to encode the prefix
64
+ Input shape: (batch-size, prefix-length)
65
+ Output shape: (batch-size, prefix-length, 2*layers*hidden)
66
+ """
67
+
68
+ def __init__(self, config: ChatGLMConfig):
69
+ super().__init__()
70
+ self.prefix_projection = config.prefix_projection
71
+ if self.prefix_projection:
72
+ # Use a two-layer MLP to encode the prefix
73
+ kv_size = config.num_layers * config.kv_channels * config.multi_query_group_num * 2
74
+ self.embedding = torch.nn.Embedding(config.pre_seq_len, kv_size)
75
+ self.trans = torch.nn.Sequential(
76
+ torch.nn.Linear(kv_size, config.hidden_size),
77
+ torch.nn.Tanh(),
78
+ torch.nn.Linear(config.hidden_size, kv_size)
79
+ )
80
+ else:
81
+ self.embedding = torch.nn.Embedding(config.pre_seq_len,
82
+ config.num_layers * config.kv_channels * config.multi_query_group_num * 2)
83
+
84
+ def forward(self, prefix: torch.Tensor):
85
+ if self.prefix_projection:
86
+ prefix_tokens = self.embedding(prefix)
87
+ past_key_values = self.trans(prefix_tokens)
88
+ else:
89
+ past_key_values = self.embedding(prefix)
90
+ return past_key_values
91
+
92
+
93
+ def split_tensor_along_last_dim(
94
+ tensor: torch.Tensor,
95
+ num_partitions: int,
96
+ contiguous_split_chunks: bool = False,
97
+ ) -> List[torch.Tensor]:
98
+ """Split a tensor along its last dimension.
99
+
100
+ Arguments:
101
+ tensor: input tensor.
102
+ num_partitions: number of partitions to split the tensor
103
+ contiguous_split_chunks: If True, make each chunk contiguous
104
+ in memory.
105
+
106
+ Returns:
107
+ A list of Tensors
108
+ """
109
+ # Get the size and dimension.
110
+ last_dim = tensor.dim() - 1
111
+ last_dim_size = tensor.size()[last_dim] // num_partitions
112
+ # Split.
113
+ tensor_list = torch.split(tensor, last_dim_size, dim=last_dim)
114
+ # Note: torch.split does not create contiguous tensors by default.
115
+ if contiguous_split_chunks:
116
+ return tuple(chunk.contiguous() for chunk in tensor_list)
117
+
118
+ return tensor_list
119
+
120
+
121
+ class RotaryEmbedding(nn.Module):
122
+ def __init__(self, dim, original_impl=False, device=None, dtype=None):
123
+ super().__init__()
124
+ inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2, device=device).to(dtype=dtype) / dim))
125
+ self.register_buffer("inv_freq", inv_freq)
126
+ self.dim = dim
127
+ self.original_impl = original_impl
128
+
129
+ def forward_impl(
130
+ self, seq_len: int, n_elem: int, dtype: torch.dtype, device: torch.device, base: int = 10000
131
+ ):
132
+ """Enhanced Transformer with Rotary Position Embedding.
133
+
134
+ Derived from: https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/labml_nn/
135
+ transformers/rope/__init__.py. MIT License:
136
+ https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/license.
137
+ """
138
+ # $\Theta = {\theta_i = 10000^{\frac{2(i-1)}{d}}, i \in [1, 2, ..., \frac{d}{2}]}$
139
+ theta = 1.0 / (base ** (torch.arange(0, n_elem, 2, dtype=dtype, device=device) / n_elem))
140
+
141
+ # Create position indexes `[0, 1, ..., seq_len - 1]`
142
+ seq_idx = torch.arange(seq_len, dtype=dtype, device=device)
143
+
144
+ # Calculate the product of position index and $\theta_i$
145
+ idx_theta = torch.outer(seq_idx, theta).float()
146
+
147
+ cache = torch.stack([torch.cos(idx_theta), torch.sin(idx_theta)], dim=-1)
148
+
149
+ # this is to mimic the behaviour of complex32, else we will get different results
150
+ if dtype in (torch.float16, torch.bfloat16, torch.int8):
151
+ cache = cache.bfloat16() if dtype == torch.bfloat16 else cache.half()
152
+ return cache
153
+
154
+ def forward(self, max_seq_len, offset=0):
155
+ return self.forward_impl(
156
+ max_seq_len, self.dim, dtype=self.inv_freq.dtype, device=self.inv_freq.device
157
+ )
158
+
159
+
160
+ @torch.jit.script
161
+ def apply_rotary_pos_emb(x: torch.Tensor, rope_cache: torch.Tensor) -> torch.Tensor:
162
+ # x: [sq, b, np, hn]
163
+ sq, b, np, hn = x.size(0), x.size(1), x.size(2), x.size(3)
164
+ rot_dim = rope_cache.shape[-2] * 2
165
+ x, x_pass = x[..., :rot_dim], x[..., rot_dim:]
166
+ # truncate to support variable sizes
167
+ rope_cache = rope_cache[:sq]
168
+ xshaped = x.reshape(sq, -1, np, rot_dim // 2, 2)
169
+ rope_cache = rope_cache.view(sq, -1, 1, xshaped.size(3), 2)
170
+ x_out2 = torch.stack(
171
+ [
172
+ xshaped[..., 0] * rope_cache[..., 0] - xshaped[..., 1] * rope_cache[..., 1],
173
+ xshaped[..., 1] * rope_cache[..., 0] + xshaped[..., 0] * rope_cache[..., 1],
174
+ ],
175
+ -1,
176
+ )
177
+ x_out2 = x_out2.flatten(3)
178
+ return torch.cat((x_out2, x_pass), dim=-1)
179
+
180
+
181
+ class RMSNorm(torch.nn.Module):
182
+ def __init__(self, normalized_shape, eps=1e-5, device=None, dtype=None, **kwargs):
183
+ super().__init__()
184
+ self.weight = torch.nn.Parameter(torch.empty(normalized_shape, device=device, dtype=dtype))
185
+ self.eps = eps
186
+
187
+ def forward(self, hidden_states: torch.Tensor):
188
+ input_dtype = hidden_states.dtype
189
+ variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
190
+ hidden_states = hidden_states * torch.rsqrt(variance + self.eps)
191
+
192
+ return (self.weight * hidden_states).to(input_dtype)
193
+
194
+
195
+ class CoreAttention(torch.nn.Module):
196
+ def __init__(self, config: ChatGLMConfig, layer_number):
197
+ super(CoreAttention, self).__init__()
198
+
199
+ self.apply_query_key_layer_scaling = config.apply_query_key_layer_scaling
200
+ self.attention_softmax_in_fp32 = config.attention_softmax_in_fp32
201
+ if self.apply_query_key_layer_scaling:
202
+ self.attention_softmax_in_fp32 = True
203
+ self.layer_number = max(1, layer_number)
204
+
205
+ projection_size = config.kv_channels * config.num_attention_heads
206
+
207
+ # Per attention head and per partition values.
208
+ self.hidden_size_per_partition = projection_size
209
+ self.hidden_size_per_attention_head = projection_size // config.num_attention_heads
210
+ self.num_attention_heads_per_partition = config.num_attention_heads
211
+
212
+ coeff = None
213
+ self.norm_factor = math.sqrt(self.hidden_size_per_attention_head)
214
+ if self.apply_query_key_layer_scaling:
215
+ coeff = self.layer_number
216
+ self.norm_factor *= coeff
217
+ self.coeff = coeff
218
+
219
+ self.attention_dropout = torch.nn.Dropout(config.attention_dropout)
220
+
221
+ def forward(self, query_layer, key_layer, value_layer, attention_mask):
222
+ pytorch_major_version = int(torch.__version__.split('.')[0])
223
+ if False:  # the scaled_dot_product_attention branch is disabled here, so the explicit attention computation below is always taken
224
+ query_layer, key_layer, value_layer = [k.permute(1, 2, 0, 3) for k in [query_layer, key_layer, value_layer]]
225
+ if attention_mask is None and query_layer.shape[2] == key_layer.shape[2]:
226
+ context_layer = torch.nn.functional.scaled_dot_product_attention(query_layer, key_layer, value_layer,
227
+ is_causal=True)
228
+ else:
229
+ if attention_mask is not None:
230
+ attention_mask = ~attention_mask
231
+ context_layer = torch.nn.functional.scaled_dot_product_attention(query_layer, key_layer, value_layer,
232
+ attention_mask)
233
+ context_layer = context_layer.permute(2, 0, 1, 3)
234
+ new_context_layer_shape = context_layer.size()[:-2] + (self.hidden_size_per_partition,)
235
+ context_layer = context_layer.reshape(*new_context_layer_shape)
236
+ else:
237
+ # Raw attention scores
238
+
239
+ # [b, np, sq, sk]
240
+ output_size = (query_layer.size(1), query_layer.size(2), query_layer.size(0), key_layer.size(0))
241
+
242
+ # [sq, b, np, hn] -> [sq, b * np, hn]
243
+ query_layer = query_layer.view(output_size[2], output_size[0] * output_size[1], -1)
244
+ # [sk, b, np, hn] -> [sk, b * np, hn]
245
+ key_layer = key_layer.view(output_size[3], output_size[0] * output_size[1], -1)
246
+
247
+ # preallocating input tensor: [b * np, sq, sk]
248
+ matmul_input_buffer = torch.empty(
249
+ output_size[0] * output_size[1], output_size[2], output_size[3], dtype=query_layer.dtype,
250
+ device=query_layer.device
251
+ )
252
+
253
+ # Raw attention scores. [b * np, sq, sk]
254
+ matmul_result = torch.baddbmm(
255
+ matmul_input_buffer,
256
+ query_layer.transpose(0, 1), # [b * np, sq, hn]
257
+ key_layer.transpose(0, 1).transpose(1, 2), # [b * np, hn, sk]
258
+ beta=0.0,
259
+ alpha=(1.0 / self.norm_factor),
260
+ )
261
+
262
+ # change view to [b, np, sq, sk]
263
+ attention_scores = matmul_result.view(*output_size)
264
+
265
+ # ===========================
266
+ # Attention probs and dropout
267
+ # ===========================
268
+
269
+ # attention scores and attention mask [b, np, sq, sk]
270
+ if self.attention_softmax_in_fp32:
271
+ attention_scores = attention_scores.float()
272
+ if self.coeff is not None:
273
+ attention_scores = attention_scores * self.coeff
274
+ if attention_mask is None and attention_scores.shape[2] == attention_scores.shape[3]:
275
+ attention_mask = torch.ones(output_size[0], 1, output_size[2], output_size[3],
276
+ device=attention_scores.device, dtype=torch.bool)
277
+ attention_mask.tril_()
278
+ attention_mask = ~attention_mask
279
+ if attention_mask is not None:
280
+ attention_scores = attention_scores + attention_mask
281
+ attention_probs = F.softmax(attention_scores, dim=-1)
282
+ attention_probs = attention_probs.type_as(value_layer)
283
+
284
+ # This is actually dropping out entire tokens to attend to, which might
285
+ # seem a bit unusual, but is taken from the original Transformer paper.
286
+ attention_probs = self.attention_dropout(attention_probs)
287
+ # =========================
288
+ # Context layer. [sq, b, hp]
289
+ # =========================
290
+
291
+ # value_layer -> context layer.
292
+ # [sk, b, np, hn] --> [b, np, sq, hn]
293
+
294
+ # context layer shape: [b, np, sq, hn]
295
+ output_size = (value_layer.size(1), value_layer.size(2), query_layer.size(0), value_layer.size(3))
296
+ # change view [sk, b * np, hn]
297
+ value_layer = value_layer.view(value_layer.size(0), output_size[0] * output_size[1], -1)
298
+ # change view [b * np, sq, sk]
299
+ attention_probs = attention_probs.view(output_size[0] * output_size[1], output_size[2], -1)
300
+ # matmul: [b * np, sq, hn]
301
+ context_layer = torch.bmm(attention_probs, value_layer.transpose(0, 1))
302
+ # change view [b, np, sq, hn]
303
+ context_layer = context_layer.view(*output_size)
304
+ # [b, np, sq, hn] --> [sq, b, np, hn]
305
+ context_layer = context_layer.permute(2, 0, 1, 3).contiguous()
306
+ # [sq, b, np, hn] --> [sq, b, hp]
307
+ new_context_layer_shape = context_layer.size()[:-2] + (self.hidden_size_per_partition,)
308
+ context_layer = context_layer.view(*new_context_layer_shape)
309
+
310
+ return context_layer
311
+
312
+
313
+ class SelfAttention(torch.nn.Module):
314
+ """Parallel self-attention layer abstract class.
315
+
316
+ Self-attention layer takes input with size [s, b, h]
317
+ and returns output of the same size.
318
+ """
319
+
320
+ def __init__(self, config: ChatGLMConfig, layer_number, device=None):
321
+ super(SelfAttention, self).__init__()
322
+ self.layer_number = max(1, layer_number)
323
+
324
+ self.projection_size = config.kv_channels * config.num_attention_heads
325
+
326
+ # Per attention head and per partition values.
327
+ self.hidden_size_per_attention_head = self.projection_size // config.num_attention_heads
328
+ self.num_attention_heads_per_partition = config.num_attention_heads
329
+
330
+ self.multi_query_attention = config.multi_query_attention
331
+ self.qkv_hidden_size = 3 * self.projection_size
332
+ if self.multi_query_attention:
333
+ self.num_multi_query_groups_per_partition = config.multi_query_group_num
334
+ self.qkv_hidden_size = (
335
+ self.projection_size + 2 * self.hidden_size_per_attention_head * config.multi_query_group_num
336
+ )
337
+ self.query_key_value = nn.Linear(config.hidden_size, self.qkv_hidden_size,
338
+ bias=config.add_bias_linear or config.add_qkv_bias,
339
+ device=device, **_config_to_kwargs(config)
340
+ )
341
+
342
+ self.core_attention = CoreAttention(config, self.layer_number)
343
+
344
+ # Output.
345
+ self.dense = nn.Linear(self.projection_size, config.hidden_size, bias=config.add_bias_linear,
346
+ device=device, **_config_to_kwargs(config)
347
+ )
348
+
349
+ def _allocate_memory(self, inference_max_sequence_len, batch_size, device=None, dtype=None):
350
+ if self.multi_query_attention:
351
+ num_attention_heads = self.num_multi_query_groups_per_partition
352
+ else:
353
+ num_attention_heads = self.num_attention_heads_per_partition
354
+ return torch.empty(
355
+ inference_max_sequence_len,
356
+ batch_size,
357
+ num_attention_heads,
358
+ self.hidden_size_per_attention_head,
359
+ dtype=dtype,
360
+ device=device,
361
+ )
362
+
363
+ def forward(
364
+ self, hidden_states, attention_mask, rotary_pos_emb, kv_cache=None, use_cache=True
365
+ ):
366
+ # hidden_states: [sq, b, h]
367
+
368
+ # =================================================
369
+ # Pre-allocate memory for key-values for inference.
370
+ # =================================================
371
+ # =====================
372
+ # Query, Key, and Value
373
+ # =====================
374
+
375
+ # Attention heads [sq, b, h] --> [sq, b, (np * 3 * hn)]
376
+ mixed_x_layer = self.query_key_value(hidden_states)
377
+
378
+ if self.multi_query_attention:
379
+ (query_layer, key_layer, value_layer) = mixed_x_layer.split(
380
+ [
381
+ self.num_attention_heads_per_partition * self.hidden_size_per_attention_head,
382
+ self.num_multi_query_groups_per_partition * self.hidden_size_per_attention_head,
383
+ self.num_multi_query_groups_per_partition * self.hidden_size_per_attention_head,
384
+ ],
385
+ dim=-1,
386
+ )
387
+ query_layer = query_layer.view(
388
+ query_layer.size()[:-1] + (self.num_attention_heads_per_partition, self.hidden_size_per_attention_head)
389
+ )
390
+ key_layer = key_layer.view(
391
+ key_layer.size()[:-1] + (self.num_multi_query_groups_per_partition, self.hidden_size_per_attention_head)
392
+ )
393
+ value_layer = value_layer.view(
394
+ value_layer.size()[:-1]
395
+ + (self.num_multi_query_groups_per_partition, self.hidden_size_per_attention_head)
396
+ )
397
+ else:
398
+ new_tensor_shape = mixed_x_layer.size()[:-1] + \
399
+ (self.num_attention_heads_per_partition,
400
+ 3 * self.hidden_size_per_attention_head)
401
+ mixed_x_layer = mixed_x_layer.view(*new_tensor_shape)
402
+
403
+ # [sq, b, np, 3 * hn] --> 3 [sq, b, np, hn]
404
+ (query_layer, key_layer, value_layer) = split_tensor_along_last_dim(mixed_x_layer, 3)
405
+
406
+ # apply relative positional encoding (rotary embedding)
407
+ if rotary_pos_emb is not None:
408
+ query_layer = apply_rotary_pos_emb(query_layer, rotary_pos_emb)
409
+ key_layer = apply_rotary_pos_emb(key_layer, rotary_pos_emb)
410
+
411
+ # adjust key and value for inference
412
+ if kv_cache is not None:
413
+ cache_k, cache_v = kv_cache
414
+ key_layer = torch.cat((cache_k, key_layer), dim=0)
415
+ value_layer = torch.cat((cache_v, value_layer), dim=0)
416
+ if use_cache:
417
+ kv_cache = (key_layer, value_layer)
418
+ else:
419
+ kv_cache = None
420
+
421
+ if self.multi_query_attention:
422
+ key_layer = key_layer.unsqueeze(-2)
423
+ key_layer = key_layer.expand(
424
+ -1, -1, -1, self.num_attention_heads_per_partition // self.num_multi_query_groups_per_partition, -1
425
+ )
426
+ key_layer = key_layer.contiguous().view(
427
+ key_layer.size()[:2] + (self.num_attention_heads_per_partition, self.hidden_size_per_attention_head)
428
+ )
429
+ value_layer = value_layer.unsqueeze(-2)
430
+ value_layer = value_layer.expand(
431
+ -1, -1, -1, self.num_attention_heads_per_partition // self.num_multi_query_groups_per_partition, -1
432
+ )
433
+ value_layer = value_layer.contiguous().view(
434
+ value_layer.size()[:2] + (self.num_attention_heads_per_partition, self.hidden_size_per_attention_head)
435
+ )
436
+
437
+ # ==================================
438
+ # core attention computation
439
+ # ==================================
440
+
441
+ context_layer = self.core_attention(query_layer, key_layer, value_layer, attention_mask)
442
+
443
+ # =================
444
+ # Output. [sq, b, h]
445
+ # =================
446
+
447
+ output = self.dense(context_layer)
448
+
449
+ return output, kv_cache
450
+
451
+
452
+ def _config_to_kwargs(args):
453
+ common_kwargs = {
454
+ "dtype": args.torch_dtype,
455
+ }
456
+ return common_kwargs
457
+
458
+
459
+ class MLP(torch.nn.Module):
460
+ """MLP.
461
+
462
+ MLP will take the input with h hidden state, project it to 4*h
463
+ hidden dimension, perform nonlinear transformation, and project the
464
+ state back into h hidden dimension.
465
+ """
466
+
467
+ def __init__(self, config: ChatGLMConfig, device=None):
468
+ super(MLP, self).__init__()
469
+
470
+ self.add_bias = config.add_bias_linear
471
+
472
+ # Project to 4h. If using swiglu, double the output width; see https://arxiv.org/pdf/2002.05202.pdf
473
+ self.dense_h_to_4h = nn.Linear(
474
+ config.hidden_size,
475
+ config.ffn_hidden_size * 2,
476
+ bias=self.add_bias,
477
+ device=device,
478
+ **_config_to_kwargs(config)
479
+ )
480
+
481
+ def swiglu(x):
482
+ x = torch.chunk(x, 2, dim=-1)
483
+ return F.silu(x[0]) * x[1]
484
+
485
+ self.activation_func = swiglu
486
+
487
+ # Project back to h.
488
+ self.dense_4h_to_h = nn.Linear(
489
+ config.ffn_hidden_size,
490
+ config.hidden_size,
491
+ bias=self.add_bias,
492
+ device=device,
493
+ **_config_to_kwargs(config)
494
+ )
495
+
496
+ def forward(self, hidden_states):
497
+ # [s, b, 4hp]
498
+ intermediate_parallel = self.dense_h_to_4h(hidden_states)
499
+ intermediate_parallel = self.activation_func(intermediate_parallel)
500
+ # [s, b, h]
501
+ output = self.dense_4h_to_h(intermediate_parallel)
502
+ return output
503
+
504
+
505
+ class GLMBlock(torch.nn.Module):
506
+ """A single transformer layer.
507
+
508
+ Transformer layer takes input with size [s, b, h] and returns an
509
+ output of the same size.
510
+ """
511
+
512
+ def __init__(self, config: ChatGLMConfig, layer_number, device=None):
513
+ super(GLMBlock, self).__init__()
514
+ self.layer_number = layer_number
515
+
516
+ self.apply_residual_connection_post_layernorm = config.apply_residual_connection_post_layernorm
517
+
518
+ self.fp32_residual_connection = config.fp32_residual_connection
519
+
520
+ LayerNormFunc = RMSNorm if config.rmsnorm else LayerNorm
521
+ # Layernorm on the input data.
522
+ self.input_layernorm = LayerNormFunc(config.hidden_size, eps=config.layernorm_epsilon, device=device,
523
+ dtype=config.torch_dtype)
524
+
525
+ # Self attention.
526
+ self.self_attention = SelfAttention(config, layer_number, device=device)
527
+ self.hidden_dropout = config.hidden_dropout
528
+
529
+ # Layernorm on the attention output
530
+ self.post_attention_layernorm = LayerNormFunc(config.hidden_size, eps=config.layernorm_epsilon, device=device,
531
+ dtype=config.torch_dtype)
532
+
533
+ # MLP
534
+ self.mlp = MLP(config, device=device)
535
+
536
+ def forward(
537
+ self, hidden_states, attention_mask, rotary_pos_emb, kv_cache=None, use_cache=True,
538
+ ):
539
+ # hidden_states: [s, b, h]
540
+
541
+ # Layer norm at the beginning of the transformer layer.
542
+ layernorm_output = self.input_layernorm(hidden_states)
543
+ # Self attention.
544
+ attention_output, kv_cache = self.self_attention(
545
+ layernorm_output,
546
+ attention_mask,
547
+ rotary_pos_emb,
548
+ kv_cache=kv_cache,
549
+ use_cache=use_cache
550
+ )
551
+
552
+ # Residual connection.
553
+ if self.apply_residual_connection_post_layernorm:
554
+ residual = layernorm_output
555
+ else:
556
+ residual = hidden_states
557
+
558
+ layernorm_input = torch.nn.functional.dropout(attention_output, p=self.hidden_dropout, training=self.training)
559
+ layernorm_input = residual + layernorm_input
560
+
561
+ # Layer norm post the self attention.
562
+ layernorm_output = self.post_attention_layernorm(layernorm_input)
563
+
564
+ # MLP.
565
+ mlp_output = self.mlp(layernorm_output)
566
+
567
+ # Second residual connection.
568
+ if self.apply_residual_connection_post_layernorm:
569
+ residual = layernorm_output
570
+ else:
571
+ residual = layernorm_input
572
+
573
+ output = torch.nn.functional.dropout(mlp_output, p=self.hidden_dropout, training=self.training)
574
+ output = residual + output
575
+
576
+ return output, kv_cache
577
+
578
+
579
+ class GLMTransformer(torch.nn.Module):
580
+ """Transformer class."""
581
+
582
+ def __init__(self, config: ChatGLMConfig, device=None):
583
+ super(GLMTransformer, self).__init__()
584
+
585
+ self.fp32_residual_connection = config.fp32_residual_connection
586
+ self.post_layer_norm = config.post_layer_norm
587
+
588
+ # Number of layers.
589
+ self.num_layers = config.num_layers
590
+
591
+ # Transformer layers.
592
+ def build_layer(layer_number):
593
+ return GLMBlock(config, layer_number, device=device)
594
+
595
+ self.layers = torch.nn.ModuleList([build_layer(i + 1) for i in range(self.num_layers)])
596
+
597
+ if self.post_layer_norm:
598
+ LayerNormFunc = RMSNorm if config.rmsnorm else LayerNorm
599
+ # Final layer norm before output.
600
+ self.final_layernorm = LayerNormFunc(config.hidden_size, eps=config.layernorm_epsilon, device=device,
601
+ dtype=config.torch_dtype)
602
+
603
+ self.gradient_checkpointing = False
604
+
605
+ def _get_layer(self, layer_number):
606
+ return self.layers[layer_number]
607
+
608
+ def forward(
609
+ self, hidden_states, attention_mask, rotary_pos_emb, kv_caches=None,
610
+ use_cache: Optional[bool] = True,
611
+ output_hidden_states: Optional[bool] = False,
612
+ ):
613
+ if not kv_caches:
614
+ kv_caches = [None for _ in range(self.num_layers)]
615
+ presents = () if use_cache else None
616
+ if self.gradient_checkpointing and self.training:
617
+ if use_cache:
618
+ logger.warning_once(
619
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
620
+ )
621
+ use_cache = False
622
+
623
+ all_self_attentions = None
624
+ all_hidden_states = () if output_hidden_states else None
625
+ for index in range(self.num_layers):
626
+ if output_hidden_states:
627
+ all_hidden_states = all_hidden_states + (hidden_states,)
628
+
629
+ layer = self._get_layer(index)
630
+ if self.gradient_checkpointing and self.training:
631
+ layer_ret = torch.utils.checkpoint.checkpoint(
632
+ layer,
633
+ hidden_states,
634
+ attention_mask,
635
+ rotary_pos_emb,
636
+ kv_caches[index],
637
+ use_cache
638
+ )
639
+ else:
640
+ layer_ret = layer(
641
+ hidden_states,
642
+ attention_mask,
643
+ rotary_pos_emb,
644
+ kv_cache=kv_caches[index],
645
+ use_cache=use_cache
646
+ )
647
+ hidden_states, kv_cache = layer_ret
648
+ if use_cache:
649
+ presents = presents + (kv_cache,)
650
+
651
+ if output_hidden_states:
652
+ all_hidden_states = all_hidden_states + (hidden_states,)
653
+
654
+ # Final layer norm.
655
+ if self.post_layer_norm:
656
+ hidden_states = self.final_layernorm(hidden_states)
657
+
658
+ return hidden_states, presents, all_hidden_states, all_self_attentions
659
+
660
+
661
+ class ChatGLMPreTrainedModel(PreTrainedModel):
662
+ """
663
+ An abstract class to handle weights initialization and
664
+ a simple interface for downloading and loading pretrained models.
665
+ """
666
+
667
+ is_parallelizable = False
668
+ supports_gradient_checkpointing = True
669
+ config_class = ChatGLMConfig
670
+ base_model_prefix = "transformer"
671
+ _no_split_modules = ["GLMBlock"]
672
+
673
+ def _init_weights(self, module: nn.Module):
674
+ """Initialize the weights."""
675
+ return
676
+
677
+ def get_masks(self, input_ids, past_key_values, padding_mask=None):
678
+ batch_size, seq_length = input_ids.shape
679
+ full_attention_mask = torch.ones(batch_size, seq_length, seq_length, device=input_ids.device)
680
+ full_attention_mask.tril_()
681
+ past_length = 0
682
+ if past_key_values:
683
+ past_length = past_key_values[0][0].shape[0]
684
+ if past_length:
685
+ full_attention_mask = torch.cat((torch.ones(batch_size, seq_length, past_length,
686
+ device=input_ids.device), full_attention_mask), dim=-1)
687
+ if padding_mask is not None:
688
+ full_attention_mask = full_attention_mask * padding_mask.unsqueeze(1)
689
+ if not past_length and padding_mask is not None:
690
+ full_attention_mask -= padding_mask.unsqueeze(-1) - 1
691
+ full_attention_mask = (full_attention_mask < 0.5).bool()
692
+ full_attention_mask.unsqueeze_(1)
693
+ return full_attention_mask
694
+
695
+ def get_position_ids(self, input_ids, device):
696
+ batch_size, seq_length = input_ids.shape
697
+ position_ids = torch.arange(seq_length, dtype=torch.long, device=device).unsqueeze(0).repeat(batch_size, 1)
698
+ return position_ids
699
+
700
+ def _set_gradient_checkpointing(self, module, value=False):
701
+ if isinstance(module, GLMTransformer):
702
+ module.gradient_checkpointing = value
703
+
704
+
705
+ class Embedding(torch.nn.Module):
706
+ """Language model embeddings."""
707
+
708
+ def __init__(self, config: ChatGLMConfig, device=None):
709
+ super(Embedding, self).__init__()
710
+
711
+ self.hidden_size = config.hidden_size
712
+ # Word embeddings (parallel).
713
+ self.word_embeddings = nn.Embedding(
714
+ config.padded_vocab_size,
715
+ self.hidden_size,
716
+ dtype=config.torch_dtype,
717
+ device=device
718
+ )
719
+ self.fp32_residual_connection = config.fp32_residual_connection
720
+
721
+ def forward(self, input_ids):
722
+ # Embeddings.
723
+ words_embeddings = self.word_embeddings(input_ids)
724
+ embeddings = words_embeddings
725
+ # Data format change to avoid explicit transposes: [b s h] --> [s b h].
726
+ embeddings = embeddings.transpose(0, 1).contiguous()
727
+ # If the input flag for fp32 residual connection is set, convert to float.
728
+ if self.fp32_residual_connection:
729
+ embeddings = embeddings.float()
730
+ return embeddings
731
+
732
+
733
+ class ChatGLMModel(ChatGLMPreTrainedModel):
734
+ def __init__(self, config: ChatGLMConfig, device=None, empty_init=True):
735
+ super().__init__(config)
736
+ if empty_init:
737
+ init_method = skip_init
738
+ else:
739
+ init_method = default_init
740
+ init_kwargs = {}
741
+ if device is not None:
742
+ init_kwargs["device"] = device
743
+ self.embedding = init_method(Embedding, config, **init_kwargs)
744
+ self.num_layers = config.num_layers
745
+ self.multi_query_group_num = config.multi_query_group_num
746
+ self.kv_channels = config.kv_channels
747
+
748
+ # Rotary positional embeddings
749
+ self.seq_length = config.seq_length
750
+ rotary_dim = (
751
+ config.hidden_size // config.num_attention_heads if config.kv_channels is None else config.kv_channels
752
+ )
753
+
754
+ self.rotary_pos_emb = RotaryEmbedding(rotary_dim // 2, original_impl=config.original_rope, device=device,
755
+ dtype=config.torch_dtype)
756
+ self.encoder = init_method(GLMTransformer, config, **init_kwargs)
757
+ self.output_layer = init_method(nn.Linear, config.hidden_size, config.padded_vocab_size, bias=False,
758
+ dtype=config.torch_dtype, **init_kwargs)
759
+ self.pre_seq_len = config.pre_seq_len
760
+ self.prefix_projection = config.prefix_projection
761
+ if self.pre_seq_len is not None:
762
+ for param in self.parameters():
763
+ param.requires_grad = False
764
+ self.prefix_tokens = torch.arange(self.pre_seq_len).long()
765
+ self.prefix_encoder = PrefixEncoder(config)
766
+ self.dropout = torch.nn.Dropout(0.1)
767
+
768
+ def get_input_embeddings(self):
769
+ return self.embedding.word_embeddings
770
+
771
+ def get_prompt(self, batch_size, device, dtype=torch.half):
772
+ prefix_tokens = self.prefix_tokens.unsqueeze(0).expand(batch_size, -1).to(device)
773
+ past_key_values = self.prefix_encoder(prefix_tokens).type(dtype)
774
+ past_key_values = past_key_values.view(
775
+ batch_size,
776
+ self.pre_seq_len,
777
+ self.num_layers * 2,
778
+ self.multi_query_group_num,
779
+ self.kv_channels
780
+ )
781
+ # seq_len, b, nh, hidden_size
782
+ past_key_values = self.dropout(past_key_values)
783
+ past_key_values = past_key_values.permute([2, 1, 0, 3, 4]).split(2)
784
+ return past_key_values
785
+
786
+ def forward(
787
+ self,
788
+ input_ids,
789
+ position_ids: Optional[torch.Tensor] = None,
790
+ attention_mask: Optional[torch.BoolTensor] = None,
791
+ full_attention_mask: Optional[torch.BoolTensor] = None,
792
+ past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None,
793
+ inputs_embeds: Optional[torch.Tensor] = None,
794
+ use_cache: Optional[bool] = None,
795
+ output_hidden_states: Optional[bool] = None,
796
+ return_dict: Optional[bool] = None,
797
+ ):
798
+ output_hidden_states = (
799
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
800
+ )
801
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
802
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
803
+
804
+ batch_size, seq_length = input_ids.shape
805
+
806
+ if inputs_embeds is None:
807
+ inputs_embeds = self.embedding(input_ids)
808
+
809
+ if self.pre_seq_len is not None:
810
+ if past_key_values is None:
811
+ past_key_values = self.get_prompt(batch_size=batch_size, device=input_ids.device,
812
+ dtype=inputs_embeds.dtype)
813
+ if attention_mask is not None:
814
+ attention_mask = torch.cat([attention_mask.new_ones((batch_size, self.pre_seq_len)),
815
+ attention_mask], dim=-1)
816
+
817
+ if full_attention_mask is None:
818
+ if (attention_mask is not None and not attention_mask.all()) or (past_key_values and seq_length != 1):
819
+ full_attention_mask = self.get_masks(input_ids, past_key_values, padding_mask=attention_mask)
820
+
821
+ # Rotary positional embeddings
822
+ rotary_pos_emb = self.rotary_pos_emb(self.seq_length)
823
+ if position_ids is not None:
824
+ rotary_pos_emb = rotary_pos_emb[position_ids]
825
+ else:
826
+ rotary_pos_emb = rotary_pos_emb[None, :seq_length]
827
+ rotary_pos_emb = rotary_pos_emb.transpose(0, 1).contiguous()
828
+
829
+ # Run encoder.
830
+ hidden_states, presents, all_hidden_states, all_self_attentions = self.encoder(
831
+ inputs_embeds, full_attention_mask, rotary_pos_emb=rotary_pos_emb,
832
+ kv_caches=past_key_values, use_cache=use_cache, output_hidden_states=output_hidden_states
833
+ )
834
+
835
+ if not return_dict:
836
+ return tuple(v for v in [hidden_states, presents, all_hidden_states, all_self_attentions] if v is not None)
837
+
838
+ return BaseModelOutputWithPast(
839
+ last_hidden_state=hidden_states,
840
+ past_key_values=presents,
841
+ hidden_states=all_hidden_states,
842
+ attentions=all_self_attentions,
843
+ )
844
+
845
+ def quantize(self, weight_bit_width: int):
846
+ from .quantization import quantize
847
+ quantize(self.encoder, weight_bit_width)
848
+ return self
849
+
850
+
851
+ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
852
+ def __init__(self, config: ChatGLMConfig, empty_init=True, device=None):
853
+ super().__init__(config)
854
+
855
+ self.max_sequence_length = config.max_length
856
+ self.transformer = ChatGLMModel(config, empty_init=empty_init, device=device)
857
+ self.config = config
858
+ self.quantized = False
859
+
860
+ if self.config.quantization_bit:
861
+ self.quantize(self.config.quantization_bit, empty_init=True)
862
+
863
+ def _update_model_kwargs_for_generation(
864
+ self,
865
+ outputs: ModelOutput,
866
+ model_kwargs: Dict[str, Any],
867
+ is_encoder_decoder: bool = False,
868
+ standardize_cache_format: bool = False,
869
+ ) -> Dict[str, Any]:
870
+ # update past_key_values
871
+ model_kwargs["past_key_values"] = self._extract_past_from_model_output(
872
+ outputs, standardize_cache_format=standardize_cache_format
873
+ )
874
+
875
+ # update attention mask
876
+ if "attention_mask" in model_kwargs:
877
+ attention_mask = model_kwargs["attention_mask"]
878
+ model_kwargs["attention_mask"] = torch.cat(
879
+ [attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))], dim=-1
880
+ )
881
+
882
+ # update position ids
883
+ if "position_ids" in model_kwargs:
884
+ position_ids = model_kwargs["position_ids"]
885
+ new_position_id = position_ids[..., -1:].clone()
886
+ new_position_id += 1
887
+ model_kwargs["position_ids"] = torch.cat(
888
+ [position_ids, new_position_id], dim=-1
889
+ )
890
+
891
+ model_kwargs["is_first_forward"] = False
892
+ return model_kwargs
893
+
894
+ def prepare_inputs_for_generation(
895
+ self,
896
+ input_ids: torch.LongTensor,
897
+ past_key_values: Optional[torch.Tensor] = None,
898
+ attention_mask: Optional[torch.Tensor] = None,
899
+ position_ids: Optional[torch.Tensor] = None,
900
+ use_cache: Optional[bool] = None,
901
+ is_first_forward: bool = True,
902
+ **kwargs
903
+ ) -> dict:
904
+ # only last token for input_ids if past is not None
905
+ if position_ids is None:
906
+ position_ids = self.get_position_ids(input_ids, device=input_ids.device)
907
+ if not is_first_forward:
908
+ if past_key_values is not None:
909
+ position_ids = position_ids[..., -1:]
910
+ input_ids = input_ids[:, -1:]
911
+ return {
912
+ "input_ids": input_ids,
913
+ "past_key_values": past_key_values,
914
+ "position_ids": position_ids,
915
+ "attention_mask": attention_mask,
916
+ "return_last_logit": True,
917
+ "use_cache": use_cache
918
+ }
919
+
920
+ def forward(
921
+ self,
922
+ input_ids: Optional[torch.Tensor] = None,
923
+ position_ids: Optional[torch.Tensor] = None,
924
+ attention_mask: Optional[torch.Tensor] = None,
925
+ past_key_values: Optional[Tuple[torch.FloatTensor]] = None,
926
+ inputs_embeds: Optional[torch.Tensor] = None,
927
+ labels: Optional[torch.Tensor] = None,
928
+ use_cache: Optional[bool] = None,
929
+ output_attentions: Optional[bool] = None,
930
+ output_hidden_states: Optional[bool] = None,
931
+ return_dict: Optional[bool] = None,
932
+ return_last_logit: Optional[bool] = False,
933
+ ):
934
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
935
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
936
+
937
+ transformer_outputs = self.transformer(
938
+ input_ids=input_ids,
939
+ position_ids=position_ids,
940
+ attention_mask=attention_mask,
941
+ past_key_values=past_key_values,
942
+ inputs_embeds=inputs_embeds,
943
+ use_cache=use_cache,
944
+ output_hidden_states=output_hidden_states,
945
+ return_dict=return_dict,
946
+ )
947
+
948
+ hidden_states = transformer_outputs[0]
949
+ if return_last_logit:
950
+ hidden_states = hidden_states[-1:]
951
+ lm_logits = self.transformer.output_layer(hidden_states)
952
+ lm_logits = lm_logits.transpose(0, 1).contiguous()
953
+
954
+ loss = None
955
+ if labels is not None:
956
+ lm_logits = lm_logits.to(torch.float32)
957
+
958
+ # Shift so that tokens < n predict n
959
+ shift_logits = lm_logits[..., :-1, :].contiguous()
960
+ shift_labels = labels[..., 1:].contiguous()
961
+ # Flatten the tokens
962
+ loss_fct = CrossEntropyLoss(ignore_index=-100)
963
+ loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
964
+
965
+ lm_logits = lm_logits.to(hidden_states.dtype)
966
+ loss = loss.to(hidden_states.dtype)
967
+
968
+ if not return_dict:
969
+ output = (lm_logits,) + transformer_outputs[1:]
970
+ return ((loss,) + output) if loss is not None else output
971
+
972
+ return CausalLMOutputWithPast(
973
+ loss=loss,
974
+ logits=lm_logits,
975
+ past_key_values=transformer_outputs.past_key_values,
976
+ hidden_states=transformer_outputs.hidden_states,
977
+ attentions=transformer_outputs.attentions,
978
+ )
979
+
980
+ @staticmethod
981
+ def _reorder_cache(
982
+ past: Tuple[Tuple[torch.Tensor, torch.Tensor], ...], beam_idx: torch.LongTensor
983
+ ) -> Tuple[Tuple[torch.Tensor, torch.Tensor], ...]:
984
+ """
985
+ This function is used to re-order the `past_key_values` cache if [`~PreTrainedModel.beam_search`] or
986
+ [`~PreTrainedModel.beam_sample`] is called. This is required to match `past_key_values` with the correct
987
+ beam_idx at every generation step.
988
+
989
+ Output shares the same memory storage as `past`.
990
+ """
991
+ return tuple(
992
+ (
993
+ layer_past[0].index_select(1, beam_idx.to(layer_past[0].device)),
994
+ layer_past[1].index_select(1, beam_idx.to(layer_past[1].device)),
995
+ )
996
+ for layer_past in past
997
+ )
998
+
999
+ def process_response(self, response):
1000
+ response = response.strip()
1001
+ response = response.replace("[[训练时间]]", "2023年")
1002
+ return response
1003
+
1004
+ def build_inputs(self, tokenizer, query: str, history: List[Tuple[str, str]] = None):
1005
+ prompt = tokenizer.build_prompt(query, history=history)
1006
+ inputs = tokenizer([prompt], return_tensors="pt")
1007
+ inputs = inputs.to(self.device)
1008
+ return inputs
1009
+
1010
+ def build_stream_inputs(self, tokenizer, query: str, history: List[Tuple[str, str]] = None):
1011
+ if history:
1012
+ prompt = "\n\n[Round {}]\n\n问:{}\n\n答:".format(len(history) + 1, query)
1013
+ input_ids = tokenizer.encode(prompt, add_special_tokens=False)
1014
+ input_ids = input_ids[1:]
1015
+ inputs = tokenizer.batch_encode_plus([(input_ids, None)], return_tensors="pt", add_special_tokens=False)
1016
+ else:
1017
+ prompt = "[Round {}]\n\n问:{}\n\n答:".format(len(history) + 1, query)
1018
+ inputs = tokenizer([prompt], return_tensors="pt")
1019
+ inputs = inputs.to(self.device)
1020
+ return inputs
1021
+
1022
+ @torch.inference_mode()
1023
+ def chat(self, tokenizer, query: str, history: List[Tuple[str, str]] = None, max_length: int = 8192, num_beams=1,
1024
+ do_sample=True, top_p=0.8, temperature=0.8, logits_processor=None, **kwargs):
1025
+ if history is None:
1026
+ history = []
1027
+ if logits_processor is None:
1028
+ logits_processor = LogitsProcessorList()
1029
+ logits_processor.append(InvalidScoreLogitsProcessor())
1030
+ gen_kwargs = {"max_length": max_length, "num_beams": num_beams, "do_sample": do_sample, "top_p": top_p,
1031
+ "temperature": temperature, "logits_processor": logits_processor, **kwargs}
1032
+ inputs = self.build_inputs(tokenizer, query, history=history)
1033
+ outputs = self.generate(**inputs, **gen_kwargs)
1034
+ outputs = outputs.tolist()[0][len(inputs["input_ids"][0]):]
1035
+ response = tokenizer.decode(outputs)
1036
+ response = self.process_response(response)
1037
+ history = history + [(query, response)]
1038
+ return response, history
1039
+
1040
+ @torch.inference_mode()
1041
+ def stream_chat(self, tokenizer, query: str, history: List[Tuple[str, str]] = None, past_key_values=None,
1042
+ max_length: int = 8192, do_sample=True, top_p=0.8, temperature=0.8, logits_processor=None,
1043
+ return_past_key_values=False, **kwargs):
1044
+ if history is None:
1045
+ history = []
1046
+ if logits_processor is None:
1047
+ logits_processor = LogitsProcessorList()
1048
+ logits_processor.append(InvalidScoreLogitsProcessor())
1049
+ gen_kwargs = {"max_length": max_length, "do_sample": do_sample, "top_p": top_p,
1050
+ "temperature": temperature, "logits_processor": logits_processor, **kwargs}
1051
+ if past_key_values is None and not return_past_key_values:
1052
+ inputs = self.build_inputs(tokenizer, query, history=history)
1053
+ else:
1054
+ inputs = self.build_stream_inputs(tokenizer, query, history=history)
1055
+ if past_key_values is not None:
1056
+ past_length = past_key_values[0][0].shape[0]
1057
+ if self.transformer.pre_seq_len is not None:
1058
+ past_length -= self.transformer.pre_seq_len
1059
+ inputs.position_ids += past_length
1060
+ attention_mask = inputs.attention_mask
1061
+ attention_mask = torch.cat((attention_mask.new_ones(1, past_length), attention_mask), dim=1)
1062
+ inputs['attention_mask'] = attention_mask
1063
+ for outputs in self.stream_generate(**inputs, past_key_values=past_key_values,
1064
+ return_past_key_values=return_past_key_values, **gen_kwargs):
1065
+ if return_past_key_values:
1066
+ outputs, past_key_values = outputs
1067
+ outputs = outputs.tolist()[0][len(inputs["input_ids"][0]):]
1068
+ response = tokenizer.decode(outputs)
1069
+ if response and response[-1] != "�":
1070
+ response = self.process_response(response)
1071
+ new_history = history + [(query, response)]
1072
+ if return_past_key_values:
1073
+ yield response, new_history, past_key_values
1074
+ else:
1075
+ yield response, new_history
1076
+
1077
+ @torch.inference_mode()
1078
+ def stream_generate(
1079
+ self,
1080
+ input_ids,
1081
+ generation_config: Optional[GenerationConfig] = None,
1082
+ logits_processor: Optional[LogitsProcessorList] = None,
1083
+ stopping_criteria: Optional[StoppingCriteriaList] = None,
1084
+ prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], List[int]]] = None,
1085
+ return_past_key_values=False,
1086
+ **kwargs,
1087
+ ):
1088
+ batch_size, input_ids_seq_length = input_ids.shape[0], input_ids.shape[-1]
1089
+
1090
+ if generation_config is None:
1091
+ generation_config = self.generation_config
1092
+ generation_config = copy.deepcopy(generation_config)
1093
+ model_kwargs = generation_config.update(**kwargs)
1094
+ model_kwargs["use_cache"] = generation_config.use_cache
1095
+ bos_token_id, eos_token_id = generation_config.bos_token_id, generation_config.eos_token_id
1096
+
1097
+ if isinstance(eos_token_id, int):
1098
+ eos_token_id = [eos_token_id]
1099
+
1100
+ has_default_max_length = kwargs.get("max_length") is None and generation_config.max_length is not None
1101
+ if has_default_max_length and generation_config.max_new_tokens is None:
1102
+ warnings.warn(
1103
+ f"Using `max_length`'s default ({generation_config.max_length}) to control the generation length. "
1104
+ "This behaviour is deprecated and will be removed from the config in v5 of Transformers -- we"
1105
+ " recommend using `max_new_tokens` to control the maximum length of the generation.",
1106
+ UserWarning,
1107
+ )
1108
+ elif generation_config.max_new_tokens is not None:
1109
+ generation_config.max_length = generation_config.max_new_tokens + input_ids_seq_length
1110
+ if not has_default_max_length:
1111
+ logger.warning(
1112
+ f"Both `max_new_tokens` (={generation_config.max_new_tokens}) and `max_length`(="
1113
+ f"{generation_config.max_length}) seem to have been set. `max_new_tokens` will take precedence. "
1114
+ "Please refer to the documentation for more information. "
1115
+ "(https://huggingface.co/docs/transformers/main/en/main_classes/text_generation)",
1116
+ UserWarning,
1117
+ )
1118
+
1119
+ if input_ids_seq_length >= generation_config.max_length:
1120
+ input_ids_string = "decoder_input_ids" if self.config.is_encoder_decoder else "input_ids"
1121
+ logger.warning(
1122
+ f"Input length of {input_ids_string} is {input_ids_seq_length}, but `max_length` is set to"
1123
+ f" {generation_config.max_length}. This can lead to unexpected behavior. You should consider"
1124
+ " increasing `max_new_tokens`."
1125
+ )
1126
+
1127
+ # 2. Set generation parameters if not already defined
1128
+ logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
1129
+ stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()
1130
+
1131
+ logits_processor = self._get_logits_processor(
1132
+ generation_config=generation_config,
1133
+ input_ids_seq_length=input_ids_seq_length,
1134
+ encoder_input_ids=input_ids,
1135
+ prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
1136
+ logits_processor=logits_processor,
1137
+ )
1138
+
1139
+ stopping_criteria = self._get_stopping_criteria(
1140
+ generation_config=generation_config, stopping_criteria=stopping_criteria
1141
+ )
1142
+ logits_warper = self._get_logits_warper(generation_config)
1143
+
1144
+ unfinished_sequences = input_ids.new(input_ids.shape[0]).fill_(1)
1145
+ scores = None
1146
+ while True:
1147
+ model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
1148
+ # forward pass to get next token
1149
+ outputs = self(
1150
+ **model_inputs,
1151
+ return_dict=True,
1152
+ output_attentions=False,
1153
+ output_hidden_states=False,
1154
+ )
1155
+
1156
+ next_token_logits = outputs.logits[:, -1, :]
1157
+
1158
+ # pre-process distribution
1159
+ next_token_scores = logits_processor(input_ids, next_token_logits)
1160
+ next_token_scores = logits_warper(input_ids, next_token_scores)
1161
+
1162
+ # sample
1163
+ probs = nn.functional.softmax(next_token_scores, dim=-1)
1164
+ if generation_config.do_sample:
1165
+ next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)
1166
+ else:
1167
+ next_tokens = torch.argmax(probs, dim=-1)
1168
+
1169
+ # update generated ids, model inputs, and length for next step
1170
+ input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
1171
+ model_kwargs = self._update_model_kwargs_for_generation(
1172
+ outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder
1173
+ )
1174
+ unfinished_sequences = unfinished_sequences.mul((sum(next_tokens != i for i in eos_token_id)).long())
1175
+ if return_past_key_values:
1176
+ yield input_ids, outputs.past_key_values
1177
+ else:
1178
+ yield input_ids
1179
+ # stop when each sentence is finished, or if we exceed the maximum length
1180
+ if unfinished_sequences.max() == 0 or stopping_criteria(input_ids, scores):
1181
+ break
1182
+
1183
+ def quantize(self, bits: int, empty_init=False, device=None, **kwargs):
1184
+ if bits == 0:
1185
+ return
1186
+
1187
+ from .quantization import quantize
1188
+
1189
+ if self.quantized:
1190
+ logger.info("Already quantized.")
1191
+ return self
1192
+
1193
+ self.quantized = True
1194
+
1195
+ self.config.quantization_bit = bits
1196
+
1197
+ self.transformer.encoder = quantize(self.transformer.encoder, bits, empty_init=empty_init, device=device,
1198
+ **kwargs)
1199
+ return self
1200
+
1201
+
1202
+ class ChatGLMForSequenceClassification(ChatGLMPreTrainedModel):
1203
+ def __init__(self, config: ChatGLMConfig, empty_init=True, device=None):
1204
+ super().__init__(config)
1205
+
1206
+ self.num_labels = config.num_labels
1207
+ self.transformer = ChatGLMModel(config, empty_init=empty_init, device=device)
1208
+
1209
+ self.classifier_head = nn.Linear(config.hidden_size, config.num_labels, bias=True, dtype=torch.half)
1210
+ if config.classifier_dropout is not None:
1211
+ self.dropout = nn.Dropout(config.classifier_dropout)
1212
+ else:
1213
+ self.dropout = None
1214
+ self.config = config
1215
+
1216
+ if self.config.quantization_bit:
1217
+ self.quantize(self.config.quantization_bit, empty_init=True)
1218
+
1219
+ def forward(
1220
+ self,
1221
+ input_ids: Optional[torch.LongTensor] = None,
1222
+ position_ids: Optional[torch.LongTensor] = None,
1223
+ attention_mask: Optional[torch.Tensor] = None,
1224
+ full_attention_mask: Optional[torch.Tensor] = None,
1225
+ past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None,
1226
+ inputs_embeds: Optional[torch.LongTensor] = None,
1227
+ labels: Optional[torch.LongTensor] = None,
1228
+ use_cache: Optional[bool] = None,
1229
+ output_hidden_states: Optional[bool] = None,
1230
+ return_dict: Optional[bool] = None,
1231
+ ) -> Union[Tuple[torch.Tensor, ...], SequenceClassifierOutputWithPast]:
1232
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1233
+
1234
+ transformer_outputs = self.transformer(
1235
+ input_ids=input_ids,
1236
+ position_ids=position_ids,
1237
+ attention_mask=attention_mask,
1238
+ full_attention_mask=full_attention_mask,
1239
+ past_key_values=past_key_values,
1240
+ inputs_embeds=inputs_embeds,
1241
+ use_cache=use_cache,
1242
+ output_hidden_states=output_hidden_states,
1243
+ return_dict=return_dict,
1244
+ )
1245
+
1246
+ hidden_states = transformer_outputs[0]
1247
+ pooled_hidden_states = hidden_states[-1]
1248
+ if self.dropout is not None:
1249
+ pooled_hidden_states = self.dropout(pooled_hidden_states)
1250
+ logits = self.classifier_head(pooled_hidden_states)
1251
+
1252
+ loss = None
1253
+ if labels is not None:
1254
+ if self.config.problem_type is None:
1255
+ if self.num_labels == 1:
1256
+ self.config.problem_type = "regression"
1257
+ elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
1258
+ self.config.problem_type = "single_label_classification"
1259
+ else:
1260
+ self.config.problem_type = "multi_label_classification"
1261
+
1262
+ if self.config.problem_type == "regression":
1263
+ loss_fct = MSELoss()
1264
+ if self.num_labels == 1:
1265
+ loss = loss_fct(logits.squeeze().float(), labels.squeeze())
1266
+ else:
1267
+ loss = loss_fct(logits.float(), labels)
1268
+ elif self.config.problem_type == "single_label_classification":
1269
+ loss_fct = CrossEntropyLoss()
1270
+ loss = loss_fct(logits.view(-1, self.num_labels).float(), labels.view(-1))
1271
+ elif self.config.problem_type == "multi_label_classification":
1272
+ loss_fct = BCEWithLogitsLoss()
1273
+ loss = loss_fct(logits.float(), labels.view(-1, self.num_labels))
1274
+
1275
+ if not return_dict:
1276
+ output = (logits,) + transformer_outputs[1:]
1277
+ return ((loss,) + output) if loss is not None else output
1278
+
1279
+ return SequenceClassifierOutputWithPast(
1280
+ loss=loss,
1281
+ logits=logits,
1282
+ past_key_values=transformer_outputs.past_key_values,
1283
+ hidden_states=transformer_outputs.hidden_states,
1284
+ attentions=transformer_outputs.attentions,
1285
+ )
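
A minimal usage sketch (not part of the uploaded file) of how the chat() and stream_chat() methods defined in this modeling file are typically driven through transformers. The checkpoint id "THUDM/chatglm2-6b" is taken from the pretrained-model list near the top of the file; loading with trust_remote_code=True (so that this modeling_chatglm.py is picked up) and plain CPU float32 execution are assumptions about the environment, not something this diff configures.

    # hypothetical driver script; requires the chatglm2-6b weights to be available locally or from the Hub
    from transformers import AutoTokenizer, AutoModel

    tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm2-6b", trust_remote_code=True)
    model = AutoModel.from_pretrained("THUDM/chatglm2-6b", trust_remote_code=True).float().eval()

    # chat() returns the final answer and the updated conversation history
    response, history = model.chat(tokenizer, "Hello", history=[])
    print(response)

    # stream_chat() yields (partial_response, history) pairs as tokens are generated
    for partial, history in model.stream_chat(tokenizer, "Tell me more", history=history):
        print(partial, end="\r")
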
ChatGLM2/demo/CMakeLists.txt ADDED
@@ -0,0 +1,33 @@
1
+ cmake_minimum_required(VERSION 2.8)
2
+ project(chatglm)
3
+
4
+ set(CMAKE_EXPORT_COMPILE_COMMANDS ON CACHE INTERNAL "")
5
+
6
+ if (NOT DEFINED TARGET_ARCH)
7
+ set(TARGET_ARCH pcie)
8
+ endif()
9
+
10
+ include_directories(${PROJECT_SOURCE_DIR}/../support/include)
11
+
12
+ if (${CMAKE_HOST_SYSTEM_PROCESSOR} STREQUAL "aarch64")
13
+ add_definitions(-DSOC_TARGET)
14
+ link_directories(${PROJECT_SOURCE_DIR}/../support/lib_soc)
15
+ message("SoC mode, starting......")
16
+ elseif (${TARGET_ARCH} STREQUAL "pcie")
17
+ add_definitions(-DPCIE_TARGET)
18
+ link_directories(${PROJECT_SOURCE_DIR}/../support/lib_pcie)
19
+ message("PCIE mode, starting......")
20
+ elseif (${TARGET_ARCH} STREQUAL "soc")
21
+ add_definitions(-DSOC_TARGET)
22
+ set(CMAKE_C_COMPILER /opt/aarch64-linux-gnu-7.5.0/bin/aarch64-linux-gnu-gcc)
23
+ set(CMAKE_ASM_COMPILER /opt/aarch64-linux-gnu-7.5.0/bin/aarch64-linux-gnu-gcc)
24
+ set(CMAKE_CXX_COMPILER /opt/aarch64-linux-gnu-7.5.0/bin/aarch64-linux-gnu-g++)
25
+ link_directories(${PROJECT_SOURCE_DIR}/../support/lib_soc)
26
+ message("SoC mode, starting......")
27
+ endif()
28
+
29
+ add_definitions(-DDEBUG --std=c++17 -fPIC -Wall -Werror)
30
+ set(CMAKE_BUILD_TYPE "Debug")
31
+
32
+ add_executable(chatglm demo.cpp)
33
+ target_link_libraries(chatglm bmrt bmlib sentencepiece)
ChatGLM2/demo/demo.cpp ADDED
@@ -0,0 +1,609 @@
1
+ //===----------------------------------------------------------------------===//
2
+ //
3
+ // Copyright (C) 2023 Sophgo Technologies Inc. All rights reserved.
4
+ //
5
+ // TPU-MLIR is licensed under the 2-Clause BSD License except for the
6
+ // third-party components.
7
+ //
8
+ //===----------------------------------------------------------------------===//
9
+
10
+ #include <iostream>
11
+ #include <cstdlib>
12
+ #include <vector>
13
+ #include <assert.h>
14
+ #include <chrono>
15
+ #include <algorithm>
16
+ #include "memory.h"
17
+ #include "sentencepiece/sentencepiece_processor.h"
18
+ #include "bmruntime_interface.h"
19
+ #include <getopt.h>
20
+ #include <stdio.h>
21
+ #include <inttypes.h>
22
+
23
+ static const uint16_t ATTENTION_MASK = 0xF0E2;
24
+
25
+ class ChatGLM {
26
+ public:
27
+ void init(const std::vector<int> &devid, std::string model_path, std::string tokenizer_path);
28
+ void chat();
29
+ void deinit();
30
+
31
+ private:
32
+ void answer(const std::string &input_str);
33
+ void tokenizer_encode(const std::string &input_str, std::vector<int> &tokens);
34
+ int forward_first(std::vector<int> &tokens);
35
+ int forward_next(int cur_token);
36
+ void move2end(const bm_tensor_t &kv);
37
+ void load_sentencepiece(std::string tokenizer_path);
38
+
39
+ private:
40
+ std::vector<bm_handle_t> handles;
41
+ bm_handle_t bm_handle;
42
+ void *p_bmrt;
43
+ sentencepiece::SentencePieceProcessor sentencepiece;
44
+ const bm_net_info_t *net_embed;
45
+ const bm_net_info_t *net_embed_cache;
46
+ const bm_net_info_t *net_lm;
47
+ std::vector<const bm_net_info_t *> net_blocks;
48
+ std::vector<const bm_net_info_t *> net_blocks_cache;
49
+ std::vector<bm_tensor_t> inputs_embed_512, outputs_embed_512;
50
+ std::vector<bm_tensor_t> inputs_pid, next_pid, inputs_attention, next_attention;
51
+ std::vector<std::vector<bm_tensor_t>> past_key, past_value;
52
+ std::vector<bm_tensor_t> inputs_lm, outputs_lm;
53
+ std::string name_embed;
54
+ std::string name_embed_cache;
55
+ std::string name_lm;
56
+ std::vector<std::string> name_blocks;
57
+ std::vector<std::string> name_blocks_cache;
58
+ std::vector<std::pair<std::string, std::string>> history_vector;
59
+ std::vector<int> history_tokens;
60
+ std::string cur_answer = "";
61
+
62
+ int device_num;
63
+ int round = 0;
64
+ int token_length;
65
+ int EOS;
66
+ int SEQLEN;
67
+ int NUM_LAYERS;
68
+ };
69
+
70
+ void ChatGLM::load_sentencepiece(std::string tokenizer_path) {
71
+ printf("Load %s ... ", tokenizer_path.c_str());
72
+ auto status = sentencepiece.Load(tokenizer_path);
73
+ if (!status.ok()) {
74
+ std::cout << status.ToString() << std::endl;
75
+ exit(-1);
76
+ }
77
+ EOS = sentencepiece.eos_id();
78
+ printf("Done!\n");
79
+ }
80
+
81
+ void ChatGLM::init(const std::vector<int> &devices, std::string model_path, std::string tokenizer_path) {
82
+ device_num = devices.size();
83
+ load_sentencepiece(tokenizer_path);
84
+ // request bm_handle
85
+ std::cout << "Device [ ";
86
+ for (auto d : devices) {
87
+ std::cout << d << " ";
88
+ }
89
+ std::cout << "] loading ....\n";
90
+ for (auto d : devices) {
91
+ bm_handle_t h;
92
+ bm_status_t status = bm_dev_request(&h, d);
93
+ assert(BM_SUCCESS == status);
94
+ handles.push_back(h);
95
+ }
96
+ bm_handle = handles[0];
97
+
98
+ // create bmruntime
99
+ #ifdef SOC_TARGET
100
+ p_bmrt = bmrt_create(handles[0]);
101
+ #else
102
+ p_bmrt = bmrt_create_ex(handles.data(), handles.size());
103
+ #endif
104
+ assert(NULL != p_bmrt);
105
+
106
+ // load bmodel by file
107
+ printf("Model[%s] loading ....\n", model_path.c_str());
108
+ bool ret = bmrt_load_bmodel(p_bmrt, model_path.c_str());
109
+ assert(true == ret);
110
+ printf("Done!\n");
111
+
112
+ // set NUM_LAYERS
113
+ auto num_nets = bmrt_get_network_number(p_bmrt);
114
+ NUM_LAYERS = (num_nets - 2) / 2;
115
+
116
+ // net names
117
+ name_embed = "embedding";
118
+ name_embed_cache = "embedding_cache";
119
+ name_lm = "lm_head";
120
+ for (int i = 0; i < NUM_LAYERS; i++) {
121
+ name_blocks.emplace_back("block_" + std::to_string(i));
122
+ name_blocks_cache.emplace_back("block_cache_" + std::to_string(i));
123
+ }
124
+
125
+ // net infos
126
+ net_embed = bmrt_get_network_info(p_bmrt, name_embed.c_str());
127
+ net_embed_cache = bmrt_get_network_info(p_bmrt, name_embed_cache.c_str());
128
+ net_lm = bmrt_get_network_info(p_bmrt, name_lm.c_str());
129
+ for (int i = 0; i < NUM_LAYERS; i++) {
130
+ net_blocks.emplace_back(
131
+ bmrt_get_network_info(p_bmrt, name_blocks[i].c_str()));
132
+ net_blocks_cache.emplace_back(
133
+ bmrt_get_network_info(p_bmrt, name_blocks_cache[i].c_str()));
134
+ }
135
+
136
+ // set SEQLEN
137
+ SEQLEN = net_embed->stages[0].input_shapes[0].dims[1];
138
+
139
+ // resize
140
+ net_blocks.resize(NUM_LAYERS);
141
+ net_blocks_cache.resize(NUM_LAYERS);
142
+ past_key.resize(NUM_LAYERS);
143
+ past_value.resize(NUM_LAYERS);
144
+
145
+ // net device mem
146
+ inputs_embed_512.resize(net_embed->input_num);
147
+ for (int i = 0; i < device_num; ++i) {
148
+ ret = bmrt_tensor_ex(&inputs_embed_512[i], p_bmrt,
149
+ net_embed->input_loc_devices[i],
150
+ net_embed->input_dtypes[i],
151
+ net_embed->stages[0].input_shapes[i]);
152
+ assert(true == ret);
153
+ }
154
+
155
+ outputs_embed_512.resize(net_embed->output_num);
156
+ for (int i = 0; i < device_num; ++i) {
157
+ ret = bmrt_tensor_ex(&outputs_embed_512[i], p_bmrt,
158
+ net_embed->output_loc_devices[i],
159
+ net_embed->output_dtypes[i],
160
+ net_embed->stages[0].output_shapes[i]);
161
+ assert(true == ret);
162
+ }
163
+
164
+ inputs_pid.resize(device_num);
165
+ inputs_attention.resize(device_num);
166
+ int in_num = net_blocks[0]->input_num / device_num;
167
+ for (int i = 0; i < device_num; ++i) {
168
+ ret = bmrt_tensor_ex(&inputs_pid[i], p_bmrt,
169
+ net_blocks[0]->input_loc_devices[1 + i * in_num],
170
+ net_blocks[0]->input_dtypes[1 + i * in_num],
171
+ net_blocks[0]->stages[0].input_shapes[1 + i * in_num]);
172
+ assert(true == ret);
173
+
174
+ ret = bmrt_tensor_ex(&inputs_attention[i], p_bmrt,
175
+ net_blocks[0]->input_loc_devices[2 + i * in_num],
176
+ net_blocks[0]->input_dtypes[2 + i * in_num],
177
+ net_blocks[0]->stages[0].input_shapes[2 + i * in_num]);
178
+ assert(true == ret);
179
+ }
180
+
181
+
182
+ next_pid.resize(device_num);
183
+ next_attention.resize(device_num);
184
+ int in_num_cache = net_blocks_cache[0]->input_num / device_num;
185
+ for (int i = 0; i < device_num; ++i) {
186
+ ret = bmrt_tensor_ex(&next_pid[i], p_bmrt,
187
+ net_blocks_cache[0]->input_loc_devices[1 + i * in_num_cache],
188
+ net_blocks_cache[0]->input_dtypes[1 + i * in_num_cache],
189
+ net_blocks_cache[0]->stages[0].input_shapes[1 + i * in_num_cache]);
190
+ assert(true == ret);
191
+
192
+ ret = bmrt_tensor_ex(&next_attention[i], p_bmrt,
193
+ net_blocks_cache[0]->input_loc_devices[2 + i * in_num_cache],
194
+ net_blocks_cache[0]->input_dtypes[2 + i * in_num_cache],
195
+ net_blocks_cache[0]->stages[0].input_shapes[2 + i * in_num_cache]);
196
+ assert(true == ret);
197
+ }
198
+
199
+ int out_num = net_blocks[0]->output_num / device_num;
200
+ for (int i = 0; i < NUM_LAYERS; i++) {
201
+ past_key[i].resize(device_num);
202
+ past_value[i].resize(device_num);
203
+ for (int j = 0; j < device_num; j++) {
204
+ ret = bmrt_tensor_ex(&past_key[i][j], p_bmrt,
205
+ net_blocks[0]->output_loc_devices[1 + j * out_num],
206
+ net_blocks[0]->output_dtypes[1 + j * out_num],
207
+ net_blocks[0]->stages[0].output_shapes[1 + j * out_num]);
208
+ assert(true == ret);
209
+ ret = bmrt_tensor_ex(&past_value[i][j], p_bmrt,
210
+ net_blocks[0]->output_loc_devices[2 + j * out_num],
211
+ net_blocks[0]->output_dtypes[2 + j * out_num],
212
+ net_blocks[0]->stages[0].output_shapes[2 + j * out_num]);
213
+ assert(true == ret);
214
+ }
215
+ }
216
+
217
+ inputs_lm.resize(device_num);
218
+ outputs_lm.resize(device_num);
219
+ for (int i = 0; i < device_num; ++i) {
220
+ ret = bmrt_tensor_ex(&inputs_lm[i], p_bmrt, i, net_lm->input_dtypes[0],
221
+ net_lm->stages[0].input_shapes[0]);
222
+ assert(true == ret);
223
+ ret = bmrt_tensor_ex(&outputs_lm[i], p_bmrt, i, net_lm->output_dtypes[0],
224
+ net_lm->stages[0].output_shapes[0]);
225
+ assert(true == ret);
226
+ }
227
+ }
228
+
229
+ void ChatGLM::deinit() {
230
+ for (int i = 0; i < device_num; ++i) {
231
+ bm_free_device(handles[i], inputs_embed_512[i].device_mem);
232
+ bm_free_device(handles[i], outputs_embed_512[i].device_mem);
233
+ bm_free_device(handles[i], inputs_pid[i].device_mem);
234
+ bm_free_device(handles[i], next_pid[i].device_mem);
235
+ bm_free_device(handles[i], inputs_attention[i].device_mem);
236
+ bm_free_device(handles[i], next_attention[i].device_mem);
237
+ bm_free_device(handles[i], inputs_lm[i].device_mem);
238
+ bm_free_device(handles[i], outputs_lm[i].device_mem);
239
+ }
240
+ for (int i = 0; i < NUM_LAYERS; i++) {
241
+ for (int j = 0; j < device_num; j++) {
242
+ bm_free_device(handles[j], past_key[i][j].device_mem);
243
+ bm_free_device(handles[j], past_value[i][j].device_mem);
244
+ }
245
+ }
246
+ bmrt_destroy(p_bmrt);
247
+ for (auto h : handles) {
248
+ bm_dev_free(h);
249
+ }
250
+ }
251
+
252
+ // after each prefill block, move the real KV results to the end of the cache memory
253
+ void ChatGLM::move2end(const bm_tensor_t &kv) {
254
+ if (token_length >= SEQLEN) {
255
+ return;
256
+ }
257
+ auto total_size = bm_mem_get_device_size(kv.device_mem);
258
+ auto bytes = total_size / SEQLEN;
259
+ auto real_size = token_length * bytes;
260
+ auto mem =
261
+ bm_mem_from_device(bm_mem_get_device_addr(kv.device_mem), real_size);
262
+ auto buffer = new uint8_t[real_size];
263
+ auto dst = new uint8_t[total_size];
264
+ bm_memcpy_d2s(bm_handle, (void *)buffer, mem);
265
+ memset(dst, 0, total_size - real_size);
266
+ memcpy(dst + total_size - real_size, buffer, real_size);
267
+ bm_memcpy_s2d(bm_handle, kv.device_mem, (void *)dst);
268
+ delete[] buffer;
269
+ delete[] dst;
270
+ }
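+ // Layout note: move2end() right-aligns the freshly produced keys/values inside
+ // the fixed SEQLEN-sized cache tensors, so the cache blocks always read the
+ // last token_length slots; forward_next() masks the unused slots on the left
+ // instead of shifting cache data on every step.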
271
+
272
+ int ChatGLM::forward_first(std::vector<int> &tokens) {
273
+ std::vector<int> input_ids(SEQLEN, 0);
274
+ std::vector<int> position_id(SEQLEN, 0);
275
+ std::vector<uint16_t> attention_mask(SEQLEN * SEQLEN, 0);
276
+
277
+ std::copy(tokens.begin(), tokens.end(), input_ids.data());
278
+
279
+ token_length = tokens.size();
280
+ for (int i = 0; i < token_length; i++) {
281
+ position_id[i] = i;
282
+ }
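+ // Build the causal mask: position i may attend to positions j <= i, and all
+ // rows at or beyond token_length are fully masked. For example, with
+ // token_length == 3 and SEQLEN == 5 the first rows are
+ //   row 0: 0 M M M M
+ //   row 1: 0 0 M M M
+ //   row 2: 0 0 0 M M
+ // where M stands for ATTENTION_MASK.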
283
+ for (int i = 0; i < SEQLEN; i++) {
284
+ for (int j = 0; j < SEQLEN; j++) {
285
+ if (j <= i && i < token_length) {
286
+ } else {
287
+ attention_mask[i * SEQLEN + j] = ATTENTION_MASK;
288
+ }
289
+ }
290
+ }
291
+
292
+ // forward embedding
293
+ std::vector<int> input_nums(device_num, 1);
294
+ std::vector<void*> datas(device_num, (void*)input_ids.data());
295
+ bmrt_memcpy_s2d_parallel(p_bmrt, inputs_embed_512.data(), datas.data(),
296
+ input_nums.data(), device_num);
297
+ auto ret =
298
+ bmrt_launch_tensor_ex(p_bmrt, name_embed.c_str(),
299
+ inputs_embed_512.data(), inputs_embed_512.size(),
300
+ outputs_embed_512.data(), outputs_embed_512.size(),
301
+ true, false);
302
+ assert(ret);
303
+ bm_thread_sync(bm_handle);
304
+
305
+ // forward blocks
306
+ std::vector<void*> pos_id_datas(device_num, position_id.data());
307
+ std::vector<void*> in_attn_datas(device_num, attention_mask.data());
308
+ bmrt_memcpy_s2d_parallel(p_bmrt, inputs_pid.data(), pos_id_datas.data(),
309
+ input_nums.data(), device_num);
310
+ bmrt_memcpy_s2d_parallel(p_bmrt, inputs_attention.data(),in_attn_datas.data(),
311
+ input_nums.data(), device_num);
312
+ auto embed_512 = outputs_embed_512;
313
+ std::vector<bm_tensor_t> inputs_block;
314
+ std::vector<bm_tensor_t> outputs_block;
315
+ for (int i = 0; i < device_num; ++i) {
316
+ embed_512[i].shape = net_blocks[0]->stages[0].input_shapes[0];
317
+ inputs_block.push_back(embed_512[i]);
318
+ inputs_block.push_back(inputs_pid[i]);
319
+ inputs_block.push_back(inputs_attention[i]);
320
+ outputs_block.push_back(embed_512[i]);
321
+ outputs_block.push_back(past_key[0][i]);
322
+ outputs_block.push_back(past_value[0][i]);
323
+ }
324
+
325
+ for (int i = 0; i < NUM_LAYERS; i++) {
326
+ for (int j = 0; j < device_num; ++j) {
327
+ outputs_block[1 + j * 3] = past_key[i][j];
328
+ outputs_block[2 + j * 3] = past_value[i][j];
329
+ }
330
+ ret = bmrt_launch_tensor_ex(p_bmrt, name_blocks[i].c_str(),
331
+ inputs_block.data(), inputs_block.size(),
332
+ outputs_block.data(), outputs_block.size(),
333
+ true, false);
334
+ assert(ret);
335
+ bm_thread_sync(bm_handle);
336
+ for (int j = 0; j < device_num; ++j) {
337
+ move2end(past_key[i][j]);
338
+ move2end(past_value[i][j]);
339
+ }
340
+ }
341
+
342
+ // forward lmhead
343
+ int bytes = embed_512[0].device_mem.size / SEQLEN;
344
+ bm_memcpy_d2d_byte(bm_handle, inputs_lm[0].device_mem, 0,
345
+ embed_512[0].device_mem, (token_length - 1) * bytes,
346
+ bytes);
347
+ ret = bmrt_launch_tensor_ex(p_bmrt, name_lm.c_str(), &inputs_lm[0], 1,
348
+ &outputs_lm[0], 1, true, false);
+ assert(ret);
349
+ bm_thread_sync(bm_handle);
350
+
351
+ int token = 0;
352
+ bm_memcpy_d2s(bm_handle, (void *)&token, outputs_lm[0].device_mem);
353
+ return token;
354
+ }
355
+
356
+ int ChatGLM::forward_next(int cur_token) {
357
+ std::vector<uint16_t> attention_mask(SEQLEN + 1, 0);
358
+ for (int i = 0; i <= SEQLEN - token_length; i++) {
359
+ attention_mask[i] = ATTENTION_MASK;
360
+ }
361
+ int32_t position_id = token_length - 1;
362
+
363
+ // forward embedding
364
+ std::vector<bm_tensor_t> inputs_embed;
365
+ std::vector<void*> input_datas;
366
+ std::vector<int> input_nums(device_num, 1);
367
+ for (int i = 0; i < device_num; ++i) {
368
+ inputs_embed.push_back(outputs_lm[i]); // token_id
369
+ inputs_embed[i].shape = net_embed_cache->stages[0].input_shapes[0];
370
+ input_datas.push_back((void*)(&cur_token));
371
+ }
372
+ bmrt_memcpy_s2d_parallel(p_bmrt, inputs_embed.data(), input_datas.data(),
373
+ input_nums.data(), device_num);
374
+ auto ret = bmrt_launch_tensor_ex(p_bmrt, name_embed_cache.c_str(),
375
+ inputs_embed.data(), inputs_embed.size(),
376
+ inputs_lm.data(), inputs_lm.size(), true, false);
377
+ assert(ret);
378
+ bm_thread_sync(bm_handle);
379
+
380
+ // forward blocks
381
+ std::vector<void*> attn_datas(device_num, attention_mask.data());
382
+ std::vector<void*> pid_datas(device_num, &position_id);
383
+ bmrt_memcpy_s2d_parallel(p_bmrt, next_attention.data(), attn_datas.data(),
384
+ input_nums.data(), device_num);
385
+ bmrt_memcpy_s2d_parallel(p_bmrt, next_pid.data(), pid_datas.data(),
386
+ input_nums.data(), device_num);
387
+
388
+ // reuse inputs_lm (one tensor per device) as the single-token hidden-state buffers
389
+ std::vector<bm_tensor_t> embed_1 = inputs_lm;
390
+ for (int i = 0; i < device_num; ++i) {
391
+ embed_1[i].shape = net_blocks_cache[0]->stages[0].input_shapes[0];
392
+ }
393
+ std::vector<bm_tensor_t> inputs_block;
394
+ std::vector<bm_tensor_t> outputs_block;
395
+ for (int i = 0; i < device_num; ++i) {
396
+ inputs_block.push_back(embed_1[i]);
397
+ inputs_block.push_back(next_pid[i]);
398
+ inputs_block.push_back(next_attention[i]);
399
+ inputs_block.push_back(past_key[0][i]);
400
+ inputs_block.push_back(past_value[0][i]);
401
+ outputs_block.push_back(embed_1[i]);
402
+ outputs_block.push_back(past_key[0][i]);
403
+ outputs_block.push_back(past_value[0][i]);
404
+ }
405
+
406
+ for (int i = 0; i < NUM_LAYERS; i++) {
407
+ for (int j = 0; j < device_num; ++j) {
408
+ inputs_block[3 + j * 5] = past_key[i][j];
409
+ inputs_block[4 + j * 5] = past_value[i][j];
410
+ outputs_block[1 + j * 3] = past_key[i][j];
411
+ outputs_block[2 + j * 3] = past_value[i][j];
412
+ }
413
+ ret = bmrt_launch_tensor_ex(p_bmrt, name_blocks_cache[i].c_str(),
414
+ inputs_block.data(), inputs_block.size(),
415
+ outputs_block.data(), outputs_block.size(),
416
+ true, false);
417
+ assert(ret);
418
+ bm_thread_sync(bm_handle);
419
+ }
420
+
421
+ // forward lmhead
422
+ ret = bmrt_launch_tensor_ex(p_bmrt, name_lm.c_str(), &inputs_lm[0], 1,
423
+ &outputs_lm[0], 1, true, false);
424
+ assert(ret);
425
+ bm_thread_sync(bm_handle);
426
+
427
+ int token = 0;
428
+ bm_memcpy_d2s(bm_handle, (void *)&token, outputs_lm[0].device_mem);
429
+ return token;
430
+ }
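+ // Note: each cache block consumes past_key/past_value as inputs (slots 3 and 4
+ // per device) and writes the refreshed cache back to the same tensors as
+ // outputs (slots 1 and 2), so the KV cache is updated in place.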
431
+
432
+ std::string build_prompt(std::string query, std::vector<std::pair<std::string, std::string>> history = {}) {
433
+ std::string prompt = "";
434
+ int round_number = 1;
435
+ for (const auto& item : history) {
436
+ prompt += "[Round " + std::to_string(round_number) + "]\n\n问:" + item.first + "\n\n答:" + item.second + "\n\n";
437
+ round_number ++;
438
+ }
439
+ prompt += "[Round " + std::to_string(history.size() + 1) + "]\n\n问:" + query + "\n\n答:";
440
+ return prompt;
441
+ }
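+ // Example: with an empty history and query "你好", build_prompt returns
+ // "[Round 1]\n\n问:你好\n\n答:", which is the chat format this demo feeds
+ // to the tokenizer.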
442
+
443
+ void ChatGLM::chat() {
444
+ while (true) {
445
+ std::cout << "\nQuestion: ";
446
+ std::string input_str;
447
+ std::getline(std::cin, input_str);
448
+ if (input_str == "exit") {
449
+ break;
450
+ }
451
+ std::cout << "\nAnswer: " << std::flush;
452
+ answer(input_str);
453
+ std::cout << std::endl;
454
+ }
455
+ }
456
+
457
+ void ChatGLM::answer(const std::string &input_str) {
458
+ // auto time_0 = std::chrono::system_clock::now();
459
+ std::string query = build_prompt(input_str, history_vector);
460
+ int tok_num = 0;
461
+ std::vector<int> tokens;
462
+ std::vector<int> prompt{64790, 64792};
463
+ sentencepiece.Encode(query, &tokens);
464
+
465
+ if (tokens.empty()) {
466
+ printf("Sorry: your question is too wierd!!\n");
467
+ return;
468
+ }
469
+ // tokens is not empty
470
+ tokens.insert(tokens.begin(), prompt.begin(), prompt.end());
471
+
472
+ // make sure token not too large
473
+ if ((int)tokens.size() > SEQLEN - 10) {
474
+ // reset
475
+ tokens.clear();
476
+ cur_answer.clear();
477
+ printf("Error: your question is too large!\n");
478
+ return;
479
+ }
480
+
481
+ int pre_token = 0;
482
+ auto t0 = std::chrono::system_clock::now();
483
+ int token = forward_first(tokens);
484
+ auto t1 = std::chrono::system_clock::now();
485
+ while (token != EOS && token_length < SEQLEN) {
486
+ std::string pre_word;
487
+ std::string word;
488
+ std::vector<int> pre_ids = {pre_token};
489
+ std::vector<int> ids = {pre_token, token};
490
+ sentencepiece.Decode(pre_ids, &pre_word);
491
+ sentencepiece.Decode(ids, &word);
492
+ std::string diff = word.substr(pre_word.size());
493
+ cur_answer += diff;
494
+ tokens.emplace_back(token);
495
+ std::cout << diff << std::flush;
496
+ if (token_length < SEQLEN) {
497
+ token_length++;
498
+ }
499
+ tok_num++;
500
+ token = forward_next(token);
501
+ }
502
+ auto t2 = std::chrono::system_clock::now();
503
+ auto use0 = std::chrono::duration_cast<std::chrono::microseconds>(t1 - t0);
504
+ auto use1 = std::chrono::duration_cast<std::chrono::microseconds>(t2 - t1);
505
+ printf("\n\nfirst token latency: %f s", (use0.count() * 1e-6));
506
+ printf("\nspeed: %f token/s\n", tok_num / (use1.count() * 1e-6));
507
+
508
+ if (token_length >= SEQLEN) {
509
+ printf("Warning: Reach to the max sequence length!\n");
510
+ history_vector.push_back({input_str, cur_answer});
511
+ cur_answer.clear();
512
+
513
+ // Delete the first half data
514
+ size_t half_size = history_vector.size() / 2;
515
+ history_vector.erase(history_vector.begin(), history_vector.begin() + half_size);
516
+ } else {
517
+ history_vector.push_back({input_str, cur_answer});
518
+ cur_answer.clear();
519
+ }
520
+ }
521
+
522
+ static void split(const std::string &s, const std::string &delim,
523
+ std::vector<std::string> &ret) {
524
+ size_t last = 0;
525
+ size_t index = s.find_first_of(delim, last);
526
+ while (index != std::string::npos) {
527
+ ret.push_back(s.substr(last, index - last));
528
+ last = index + 1;
529
+ index = s.find_first_of(delim, last);
530
+ }
531
+ if (last < s.length()) {
532
+ ret.push_back(s.substr(last));
533
+ }
534
+ }
535
+
536
+ static std::vector<int> parseCascadeDevices(const std::string &str) {
537
+ std::vector<int> devices;
538
+ std::vector<std::string> sub_str;
539
+ split(str, ",", sub_str);
540
+ for (auto &s : sub_str) {
541
+ devices.push_back(std::atoi(s.c_str()));
542
+ }
543
+ return devices;
544
+ }
545
+
546
+ void Usage() {
547
+ printf("Usage:\n"
548
+ " --help : Show help info.\n"
549
+ " --model : Set model path \n"
550
+ " --tokenizer : Set tokenizer path \n"
551
+ " --devid : Set devices to run for model, e.g. 1,2. if not "
552
+ "set, use 0\n");
553
+ }
554
+
555
+ void processArguments(int argc, char *argv[], std::string &model_path, std::string &tokenizer_path,
556
+ std::vector<int> &devices) {
557
+ struct option longOptions[] = {{"model", required_argument, nullptr, 'm'},
558
+ {"tokenizer", required_argument, nullptr, 't'},
559
+ {"devid", required_argument, nullptr, 'd'},
560
+ {"help", no_argument, nullptr, 'h'},
561
+ {nullptr, 0, nullptr, 0}};
562
+
563
+ int optionIndex = 0;
564
+ int option;
565
+
566
+ while ((option = getopt_long(argc, argv, "m:t:d:h", longOptions,
567
+ &optionIndex)) != -1) {
568
+ switch (option) {
569
+ case 'm':
570
+ model_path = optarg;
571
+ break;
572
+ case 't':
573
+ tokenizer_path = optarg;
574
+ break;
575
+ case 'd':
576
+ devices = parseCascadeDevices(optarg);
577
+ break;
578
+ case 'h':
579
+ Usage();
580
+ exit(EXIT_FAILURE);
581
+ case '?':
582
+ Usage();
583
+ exit(EXIT_FAILURE);
584
+ default:
585
+ exit(EXIT_FAILURE);
586
+ }
587
+ }
588
+ }
589
+
590
+ int main(int argc, char **argv) {
591
+ // set your bmodel path here
592
+ printf("Demo for ChatGLM in BM1684X, support ChatGLM1/2/3\n");
593
+ std::string model_path;
594
+ std::string tokenizer_path;
595
+ std::vector<int> devices = {0};
596
+ processArguments(argc, argv, model_path, tokenizer_path, devices);
597
+ if (model_path.empty()) {
598
+ Usage();
599
+ exit(EXIT_FAILURE);
600
+ }
601
+
602
+ ChatGLM glm;
603
+ printf("Init Environment ...\n");
604
+ glm.init(devices, model_path, tokenizer_path);
605
+ printf("==========================\n");
606
+ glm.chat();
607
+ glm.deinit();
608
+ return 0;
609
+ }
ChatGLM2/run_demo.sh ADDED
@@ -0,0 +1,27 @@
1
+ #!/bin/bash
2
+ set -ex
3
+
4
+ # download bmodel
6
+ if [ ! -d "../../bmodels" ]; then
7
+ mkdir ../../bmodels
8
+ fi
9
+
10
+ if [ ! -f "../../bmodels/chatglm2-6b_int4_1dev.bmodel" ]; then
11
+ pip3 install dfss
12
+ python3 -m dfss --url=open@sophgo.com:/ext_model_information/LLM/LLM-TPU/chatglm2-6b_int4_1dev.bmodel
13
+ mv chatglm2-6b_int4_1dev.bmodel ../../bmodels
14
+ else
15
+ echo "Bmodel Exists!"
16
+ fi
17
+
18
+ if [ ! -f "./demo/chatglm" ]; then
19
+ cd demo && rm -rf build && mkdir build && cd build
20
+ cmake .. && make -j
21
+ cp chatglm .. && cd ../..
22
+ else
23
+ echo "chatglm file Exists!"
24
+ fi
25
+
26
+ # run demo
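+ # --devid accepts a comma-separated device list (e.g. --devid 0,1) when more than one chip is available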
27
+ ./demo/chatglm --model ../../bmodels/chatglm2-6b_int4_1dev.bmodel --tokenizer ./support/tokenizer/tokenizer.model --devid 0
ChatGLM2/support/include/bmdef.h ADDED
@@ -0,0 +1,129 @@
1
+ /*****************************************************************************
2
+ *
3
+ * Copyright (c) 2016-2026 by Sophgo Technologies Inc. All rights reserved.
4
+ *
5
+ * The material in this file is confidential and contains trade secrets
6
+ * of Sophgo Technologies Inc. This is proprietary information owned by
7
+ * Sophgo Technologies Inc. No part of this work may be disclosed,
8
+ * reproduced, copied, transmitted, or used in any way for any purpose,
9
+ * without the express written permission of Sophgo Technologies Inc.
10
+ *
11
+ *****************************************************************************/
12
+
13
+ #ifndef __BMRUNTIME_DEFINE_H__
14
+ #define __BMRUNTIME_DEFINE_H__
15
+
16
+ #include "bmlib_runtime.h"
17
+ #include <stddef.h>
18
+ #include <stdint.h>
19
+
20
+ #if defined(__cplusplus)
21
+ extern "C" {
22
+ #endif
23
+
24
+ /* --------------------------------------------------------------------------*/
25
+ /* basic definitions */
26
+
27
+ /* bm_data_type_t holds the type for a scalar value */
28
+ typedef enum bm_data_type_e {
29
+ BM_FLOAT32 = 0,
30
+ BM_FLOAT16 = 1,
31
+ BM_INT8 = 2,
32
+ BM_UINT8 = 3,
33
+ BM_INT16 = 4,
34
+ BM_UINT16 = 5,
35
+ BM_INT32 = 6,
36
+ BM_UINT32 = 7,
37
+ BM_BFLOAT16 = 8,
38
+ BM_INT4 = 9,
39
+ BM_UINT4 = 10,
40
+ } bm_data_type_t;
41
+
42
+ /* store mode definitions */
43
+ typedef enum bm_store_mode_e {
44
+ BM_STORE_1N = 0, /* default, if not sure, use 0 */
45
+ BM_STORE_2N = 1,
46
+ BM_STORE_4N = 2,
47
+ } bm_store_mode_t;
48
+
49
+ /* bm_shape_t holds the shape info */
50
+ #define BM_MAX_DIMS_NUM 8
51
+ typedef struct bm_shape_s {
52
+ int num_dims;
53
+ int dims[BM_MAX_DIMS_NUM];
54
+ } bm_shape_t;
55
+
56
+ typedef struct bm_shape_ex_s {
57
+ bm_shape_t shape;
58
+ int elem_num;
59
+ } bm_shape_ex_t;
60
+
61
+ /*
62
+ bm_tensor_t holds a multi-dimensional array of elements of a single data type
63
+ and the tensor data resides in device memory */
64
+ typedef struct bm_tensor_s {
65
+ bm_data_type_t dtype;
66
+ bm_shape_t shape;
67
+ bm_device_mem_t device_mem;
68
+ bm_store_mode_t st_mode; /* user can set 0 as default store mode */
69
+ } bm_tensor_t;
70
+
71
+ /* --------------------------------------------------------------------------*/
72
+ /* network information structure */
73
+
74
+ /* bm_stage_info_t holds input/output shapes and device mems; every network can contain one or more
75
+ * stages */
76
+ typedef struct bm_stage_info_s {
77
+ bm_shape_t *input_shapes; /* input_shapes[0] / [1] / ... / [input_num-1] */
78
+ bm_shape_t *output_shapes; /* output_shapes[0] / [1] / ... / [output_num-1] */
79
+ bm_device_mem_t *input_mems; /* input_mems[0] / [1] / ... / [input_num-1] */
80
+ bm_device_mem_t *output_mems; /* output_mems[0] / [1] / ... / [output_num-1] */
81
+ } bm_stage_info_t;
82
+
83
+ /* bm_net_info_t holds all information of one net.
84
+ * scale for float type is 1.0 as default */
85
+ typedef struct bm_net_info_s {
86
+ const char* name; /* net name */
87
+ bool is_dynamic; /* dynamic or static */
88
+ int input_num; /* number of inputs */
89
+ char const** input_names; /* input_names[0] / [1] / .../ [input_num-1] */
90
+ bm_data_type_t* input_dtypes; /* input_dtypes[0] / [1] / .../ [input_num-1] */
91
+ float* input_scales; /* input_scales[0] / [1] / .../ [input_num-1] */
92
+ int output_num; /* number of outputs */
93
+ char const** output_names; /* output_names[0] / [1] / .../ [output_num-1] */
94
+ bm_data_type_t* output_dtypes; /* output_dtypes[0] / [1] / .../ [output_num-1] */
95
+ float* output_scales; /* output_scales[0] / [1] / .../ [output_num-1] */
96
+ int stage_num; /* number of stages */
97
+ bm_stage_info_t* stages; /* stages[0] / [1] / ... / [stage_num-1] */
98
+ size_t* max_input_bytes; /* max_input_bytes[0]/ [1] / ... / [input_num-1] */
99
+ size_t* max_output_bytes; /* max_output_bytes[0] / [1] / ... / [output_num-1] */
100
+ int* input_zero_point; /* input_zero_point[0] / [1] / .../ [input_num-1] */
101
+ int* output_zero_point; /* output_zero_point[0] / [1] / .../ [output_num-1] */
102
+ int *input_loc_devices; /* input_loc_device[0] / [1] / .../ [input_num-1] */
103
+ int *output_loc_devices; /* output_loc_device[0] / [1] / .../ [output_num-1] */
104
+ } bm_net_info_t;
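+ /* Typical read-only usage (sketch, mirroring how the demos in this repo query
+  * a net; bmrt_get_network_info() is declared in bmruntime_interface.h):
+  *   const bm_net_info_t *net = bmrt_get_network_info(p_bmrt, "embedding");
+  *   int seq_len = net->stages[0].input_shapes[0].dims[1];
+  *   bm_data_type_t dtype = net->input_dtypes[0];
+  */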
105
+
106
+ typedef struct api_info_s {
107
+ /// @brief api_id to be sent to driver
108
+ int32_t api_id;
109
+ /// @brief api data to be sent to driver
110
+ uint8_t **api_data;
111
+ /// @brief size of the api data to be sent to driver
112
+ size_t api_data_size;
113
+ /// @brief subsize of the api data to be sent to driver
114
+ size_t *api_data_subsize;
115
+ /// @brief offset of input tensors' addr in api_data
116
+ uint32_t *input_addr_offset;
117
+ /// @brief number of the offset of input tensors' addr in api_data
118
+ size_t input_addr_offset_number;
119
+ /// @brief offset of output tensors' addr in api_data
120
+ uint32_t *output_addr_offset;
121
+ /// @brief number of the offset of output tensors' addr in api_data
122
+ size_t output_addr_offset_number;
123
+ } api_info_c;
124
+
125
+ #if defined(__cplusplus)
126
+ }
127
+ #endif
128
+
129
+ #endif /* __BM_NET_H__ */
ChatGLM2/support/include/bmlib_runtime.h ADDED
@@ -0,0 +1,2581 @@
1
+ /*****************************************************************************
2
+ *
3
+ * Copyright (c) 2016-2026 by Bitmain Technologies Inc. All rights reserved.
4
+ *
5
+ * The material in this file is confidential and contains trade secrets
6
+ * of Bitmain Technologies Inc. This is proprietary information owned by
7
+ * Bitmain Technologies Inc. No part of this work may be disclosed,
8
+ * reproduced, copied, transmitted, or used in any way for any purpose,
9
+ * without the express written permission of Bitmain Technologies Inc.
10
+ *
11
+ *****************************************************************************/
12
+
13
+ /**************************************************************************
14
+ * bmlib_runtime defines interfaces that operate on TPU devices.
15
+ * The functions can be divided into several categories.
16
+ * 1) device handle creation and destroy
17
+ * 2) memory help functions
18
+ * 3) global memory allocation and free
19
+ * 4) data transfer between host and device
20
+ * 5) data transfer within device memory
21
+ * 6) api send and synchronization
22
+ * 7) global memory map and coherence
23
+ * 8) trace and profile
24
+ * 9) power management
25
+ * 10) miscellaneous functions
26
+ *************************************************************************/
27
+
28
+ #ifndef BMLIB_RUNTIME_H_
29
+ #define BMLIB_RUNTIME_H_
30
+ #if defined(_WIN32) && !defined(__MINGW32__)
31
+ #include <vadefs.h>
32
+ #define DECL_EXPORT __declspec(dllexport)
33
+ #define DECL_IMPORT __declspec(dllimport)
34
+ #else
35
+ #include <stdbool.h>
36
+ #include <stddef.h>
37
+ #include <stdarg.h>
38
+ #define DECL_EXPORT
39
+ #define DECL_IMPORT
40
+ #endif
41
+
42
+ #if defined(__cplusplus)
43
+ extern "C" {
44
+ #endif
45
+
46
+ typedef enum {
47
+ MODULE_CDMA = 0,
48
+ MODULE_GDMA = 1,
49
+ MODULE_TPU = 2,
50
+ MODULE_SMMU = 3,
51
+ MODULE_SRAM = 4,
52
+ MODULE_END = 5
53
+ } MODULE_ID;
54
+
55
+ #define BM_MEM_ADDR_NULL (0xfffffffff)
56
+
57
+ #ifndef BM_MEM_DESC_T_
58
+ #define BM_MEM_DESC_T_
59
+ /* BM function return code definitions */
60
+ typedef enum {
61
+ BM_SUCCESS = 0,
62
+ BM_ERR_DEVNOTREADY = 1, /* Device not ready yet */
63
+ BM_ERR_FAILURE = 2, /* General failure */
64
+ BM_ERR_TIMEOUT = 3, /* Timeout */
65
+ BM_ERR_PARAM = 4, /* Parameters invalid */
66
+ BM_ERR_NOMEM = 5, /* Not enough memory */
67
+ BM_ERR_DATA = 6, /* Data error */
68
+ BM_ERR_BUSY = 7, /* Busy */
69
+ BM_ERR_NOFEATURE = 8, /* Not supported yet */
70
+ BM_NOT_SUPPORTED = 9
71
+ } bm_status_t;
72
+
73
+ /* BM memory type definitions */
74
+ typedef enum {
75
+ BM_MEM_TYPE_DEVICE = 0,
76
+ BM_MEM_TYPE_HOST = 1,
77
+ BM_MEM_TYPE_SYSTEM = 2,
78
+ BM_MEM_TYPE_INT8_DEVICE = 3,
79
+ BM_MEM_TYPE_INVALID = 4
80
+ } bm_mem_type_t;
81
+
82
+ typedef enum {
83
+ PERF_MONITOR_GDMA = 0,
84
+ PERF_MONITOR_TPU = 1
85
+ } PERF_MONITOR_ID;
86
+
87
+ typedef enum {
88
+ BMCPU_IDLE = 0,
89
+ BMCPU_RUNNING = 1,
90
+ BMCPU_FAULT = 2
91
+ } bm_cpu_status_t;
92
+
93
+ /*
94
+ * bm performance monitor
95
+ */
96
+ typedef struct bm_perf_monitor {
97
+ long long buffer_start_addr; /*buffer address to store perf data*/
98
+ int buffer_size; /*buffer size*/
99
+ PERF_MONITOR_ID monitor_id; /*PERF_MONITOR_GDMA or PERF_MONITOR_TPU*/
100
+ } bm_perf_monitor_t;
101
+
102
+ typedef union {
103
+ struct {
104
+ bm_mem_type_t mem_type : 3;
105
+ unsigned int gmem_heapid : 3;
106
+ unsigned int reserved : 26;
107
+ } u;
108
+ unsigned int rawflags;
109
+ } bm_mem_flags_t;
110
+
111
+ /* BM memory descriptor definition*/
112
+ typedef struct bm_mem_desc {
113
+ union {
114
+ struct {
115
+ #ifdef __linux__
116
+ unsigned long device_addr;
117
+ #else
118
+ unsigned long long device_addr;
119
+ #endif
120
+ unsigned int reserved;
121
+ int dmabuf_fd;
122
+ } device;
123
+
124
+ struct {
125
+ void *system_addr;
126
+ unsigned int reserved0;
127
+ int reserved1;
128
+ } system;
129
+ } u;
130
+
131
+ bm_mem_flags_t flags;
132
+ unsigned int size;
133
+ } bm_mem_desc_t;
134
+
135
+ typedef struct bm_mem_desc bm_device_mem_t;
136
+ typedef struct bm_mem_desc bm_system_mem_t;
137
+
138
+ typedef struct sg_mem_desc {
139
+ union {
140
+ struct {
141
+ #ifdef __linux__
142
+ unsigned long device_addr;
143
+ #else
144
+ unsigned long long device_addr;
145
+ #endif
146
+ unsigned int reserved;
147
+ int dmabuf_fd;
148
+ } device;
149
+
150
+ struct {
151
+ void *system_addr;
152
+ unsigned int reserved0;
153
+ int reserved1;
154
+ } system;
155
+ } u;
156
+
157
+ bm_mem_flags_t flags;
158
+ unsigned long long size;
159
+ } sg_mem_desc_t;
160
+
161
+ typedef struct sg_mem_desc sg_device_mem_t;
162
+ typedef struct sg_mem_desc sg_system_mem_t;
163
+ #endif
164
+
165
+ struct bm_context;
166
+ typedef struct bm_context *bm_handle_t;
167
+
168
+ #define MD5SUM_LEN 16
169
+ #define LIB_MAX_NAME_LEN 64
170
+ #define FUNC_MAX_NAME_LEN 64
171
+
172
+ typedef struct bm_module
173
+ {
174
+ // void *lib_handle;
175
+ char lib_name[LIB_MAX_NAME_LEN];
176
+ unsigned char md5[MD5SUM_LEN];
177
+ }bm_module;
178
+
179
+ typedef struct bm_module *tpu_kernel_module_t;
180
+ typedef int tpu_kernel_function_t;
181
+
182
+ /**
183
+ * @name tpu_kernel_load_module_file
184
+ * @brief To load dyn file
185
+ * @ingroup bmlib_runtime
186
+ *
187
+ * @param [in] handle The device handle
188
+ * @param [in] module_file dyn file
189
+ * @retval dyn lib ptr
190
+ */
191
+ tpu_kernel_module_t tpu_kernel_load_module_file(bm_handle_t handle, const char *module_file);
192
+
193
+ /**
194
+ * @name tpu_kernel_load_module_file_key
195
+ * @brief To load dyn file with key
196
+ * @ingroup bmlib_runtime
197
+ *
198
+ * @param [in] handle The device handle
199
+ * @param [in] module_file dyn file
200
+ * @param [in] key identification str
201
+ * @param [in] size key size
202
+ * @retval dyn lib ptr
203
+ */
204
+ tpu_kernel_module_t tpu_kernel_load_module_file_key(bm_handle_t handle, const char *module_file, const char *key, int size);
205
+
206
+ /**
207
+ * @name tpu_kernel_unload_module
208
+ * @brief To unload dyn file
209
+ * @ingroup bmlib_runtime
210
+ *
211
+ * @param [in] handle The device handle
212
+ * @param [in] p_module dyn lib ptr
213
+ * @retval BM_SUCCESS Succeeds.
214
+ * Other code Fails.
215
+ */
216
+ bm_status_t tpu_kernel_unload_module(bm_handle_t handle, tpu_kernel_module_t p_module);
217
+
218
+ /**
219
+ * @name tpu_kernel_free_module
220
+ * @brief To free p_module when it is no longer in use
221
+ * @ingroup bmlib_runtime
222
+ *
223
+ * @param [in] handle The device handle
224
+ * @param [in] p_module dyn lib ptr
225
+ * @retval BM_SUCCESS Succeeds.
226
+ * Other code Fails.
227
+ */
228
+ bm_status_t tpu_kernel_free_module(bm_handle_t handle, tpu_kernel_module_t p_module);
229
+
230
+ /**
231
+ * @name tpu_kernel_load_module
232
+ * @brief To load dyn module
233
+ * @ingroup bmlib_runtime
234
+ *
235
+ * @param [in] handle The device handle
236
+ * @param [in] data dyn module
237
+ * @param [in] length dyn module size
238
+ * @retval dyn lib ptr
239
+ */
240
+ tpu_kernel_module_t tpu_kernel_load_module(bm_handle_t handle, const char *data, size_t length);
241
+
242
+ /**
243
+ * @name tpu_kernel_get_function
244
+ * @brief To get function from lib
245
+ * @ingroup bmlib_runtime
246
+ *
247
+ * @param [in] handle The device handle
248
+ * @param [in] module dyn module
249
+ * @param [in] function function name
250
+ * @retval function id
251
+ */
252
+ tpu_kernel_function_t tpu_kernel_get_function(bm_handle_t handle, tpu_kernel_module_t module, const char *function);
253
+
254
+ /**
255
+ * @name tpu_kernel_launch
256
+ * @brief To launch function with sync
257
+ * @ingroup bmlib_runtime
258
+ *
259
+ * @param [in] handle The device handle
260
+ * @param [in] function function id
261
+ * @param [in] args function args
262
+ * @param [in] size args size
263
+ * @retval BM_SUCCESS Succeeds.
264
+ * Other code Fails.
265
+ */
266
+ bm_status_t tpu_kernel_launch(bm_handle_t handle, tpu_kernel_function_t function, void *args, size_t size);
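+ /* Typical call sequence (illustrative sketch, error handling omitted;
+  * "kernel.so", "my_func" and args are placeholders):
+  *   tpu_kernel_module_t mod = tpu_kernel_load_module_file(handle, "kernel.so");
+  *   tpu_kernel_function_t fn = tpu_kernel_get_function(handle, mod, "my_func");
+  *   tpu_kernel_launch(handle, fn, &args, sizeof(args)); // synchronous
+  *   tpu_kernel_unload_module(handle, mod);
+  */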
267
+
268
+ /**
269
+ * @name tpu_kernel_launch_async
270
+ * @brief To launch function with async
271
+ * @ingroup bmlib_runtime
272
+ *
273
+ * @param [in] handle The device handle
274
+ * @param [in] function function id
275
+ * @param [in] args function args
276
+ * @param [in] size args size
277
+ * @retval BM_SUCCESS Succeeds.
278
+ * Other code Fails.
279
+ */
280
+ bm_status_t tpu_kernel_launch_async(bm_handle_t handle, tpu_kernel_function_t function, void *args, size_t size);
281
+
282
+ /**
283
+ * @name tpu_kernel_launch_async_multi_cores
284
+ * @brief To launch function with async for multi cores
285
+ * @ingroup bmlib_runtime
286
+ *
287
+ * @param [in] handle The device handle
288
+ * @param [in] func_name function name
289
+ * @param [in] api_param function params
290
+ * @param [in] api_size params size
291
+ * @param [in] core_list list of core ids
292
+ * @param [in] core_num number of cores
293
+ * @retval BM_SUCCESS Succeeds.
294
+ * Other code Fails.
295
+ */
296
+ bm_status_t tpu_kernel_launch_async_multi_cores(bm_handle_t handle, const char *func_name, const void *api_param,
297
+ size_t api_size, const int* core_list, const int core_num);
298
+
299
+ /**
300
+ * @name tpu_kernel_launch_sync_multi_cores
301
+ * @brief To launch function with sync for multi cores
302
+ * @ingroup bmlib_runtime
303
+ *
304
+ * @param [in] handle The device handle
305
+ * @param [in] func_name function name
306
+ * @param [in] api_param function params
307
+ * @param [in] api_size params size
308
+ * @param [in] core_list list of core ids
309
+ * @param [in] core_num number of cores
310
+ * @retval BM_SUCCESS Succeeds.
311
+ * Other code Fails.
312
+ */
313
+ bm_status_t tpu_kernel_launch_sync_multi_cores(bm_handle_t handle, const char *func_name, const void *api_param,
314
+ size_t api_size, const int* core_list, const int core_num);
315
+
316
+ /**
317
+ * @name tpu_kernel_sync
318
+ * @brief To sync
319
+ * @ingroup bmlib_runtime
320
+ *
321
+ * @param [in] handle The device handle
322
+ * @retval BM_SUCCESS Succeeds.
323
+ * Other code Fails.
324
+ */
325
+ bm_status_t tpu_kernel_sync(bm_handle_t handle);
326
+ void show_md5(unsigned char md5[]);
327
+
328
+ DECL_EXPORT void bmlib_log(const char *tag, int level, const char *fmt, ...);
329
+
330
+ #ifndef USING_CMODEL
331
+ #define BM_CHECK_RET(call) \
332
+ do { \
333
+ bm_status_t ret = (bm_status_t)call; \
334
+ if (ret != BM_SUCCESS) { \
335
+ bmlib_log("BM_CHECK",16,"BM_CHECK_RET fail %s: %s: %d\n", __FILE__, __func__, __LINE__); \
336
+ return ret; \
337
+ } \
338
+ } while (0)
339
+ #else
340
+ #define BM_CHECK_RET(call) \
341
+ do { \
342
+ bm_status_t ret = call; \
343
+ if (ret != BM_SUCCESS) { \
344
+ bmlib_log("BM_CHECK",16,"BM_CHECK_RET failed %d\n", ret);\
345
+ ASSERT(0); \
346
+ exit(-ret); \
347
+ } \
348
+ } while (0)
349
+ #endif
350
+
351
+ /*******************handle related functions *********************************/
352
+ /**
353
+ * @name bm_dev_getcount
354
+ * @brief To get the number of sophon devices in system.
355
+ * If N is got, valid devid is [0, N-1]
356
+ * @ingroup bmlib_runtime
357
+ *
358
+ * @param [out] count The result number of sophon devices
359
+ * @retval BM_SUCCESS Succeeds.
360
+ * Other code Fails.
361
+ */
362
+ DECL_EXPORT bm_status_t bm_dev_getcount(int *count);
363
+
364
+ /**
365
+ * @name bm_dev_query
366
+ * @brief To query if a device is present
367
+ * @ingroup bmlib_runtime
368
+ *
369
+ * @param [in] devid The id of the device to query
370
+ * @retval BM_SUCCESS Device is present
371
+ * Other code Device is not present
372
+ */
373
+ DECL_EXPORT bm_status_t bm_dev_query(int devid);
374
+
375
+ /**
376
+ * @name bm_dev_request
377
+ * @brief To create a handle for the given device
378
+ * @ingroup bmlib_runtime
379
+ *
380
+ * @param [out] handle The created handle
381
+ * @param [in] devid Specify on which device to create handle
382
+ * @retval BM_SUCCESS Succeeds.
383
+ * Other code Fails.
384
+ */
385
+ DECL_EXPORT bm_status_t bm_dev_request(bm_handle_t *handle, int devid);
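+ /* Minimal handle lifecycle (sketch, as used by the demos in this repo):
+  *   bm_handle_t h;
+  *   bm_status_t st = bm_dev_request(&h, 0);  // open device 0
+  *   if (st == BM_SUCCESS) { ...use h...; bm_dev_free(h); }
+  */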
386
+
387
+ /**
388
+ * @name bm_get_devid
389
+ * @brief To get device index for the given handle
390
+ * @ingroup bmlib_runtime
391
+ *
392
+ * @param [in] handle The given handle
393
+ * @retval int device index that the handle points to.
394
+ */
395
+ DECL_EXPORT int bm_get_devid(bm_handle_t handle);
396
+
397
+ /**
398
+ * @name bm_dev_free
399
+ * @brief To free a handle
400
+ * @ingroup bmlib_runtime
401
+ *
402
+ * @param [in] handle The handle to free
403
+ */
404
+ DECL_EXPORT void bm_dev_free(bm_handle_t handle);
405
+
406
+ /*******************memory help functions ************************************/
407
+ /**
408
+ * @name bm_mem_get_type
409
+ * @brief To get a memory descriptor's type
410
+ * @ingroup bmlib_runtime
411
+ *
412
+ * @param [in] mem The memory descriptor queried
413
+ * @retval BM_MEM_TYPE_DEVICE Device global memory
414
+ * @retval BM_MEM_TYPE_SYSTEM Host user memory
415
+ */
416
+ DECL_EXPORT bm_mem_type_t bm_mem_get_type(struct bm_mem_desc mem);
417
+
418
+ /**
419
+ * @name sg_mem_get_type
420
+ * @brief To get a memory descriptor's type
421
+ * @ingroup bmlib_runtime
422
+ *
423
+ * @param [in] mem The memory descriptor queried
424
+ * @retval BM_MEM_TYPE_DEVICE Device global memory
425
+ * @retval BM_MEM_TYPE_SYSTEM Host user memory
426
+ */
427
+ DECL_EXPORT bm_mem_type_t sg_mem_get_type(struct sg_mem_desc mem);
428
+
429
+ /**
430
+ * @name bm_mem_get_device_addr
431
+ * @brief To get a device memory descriptor's address
432
+ * @ingroup bmlib_runtime
433
+ *
434
+ * @param [in] mem The device memory descriptor queried
435
+ * @retval unsigned long long The device memory address
436
+ */
437
+ DECL_EXPORT unsigned long long bm_mem_get_device_addr(struct bm_mem_desc mem);
438
+
439
+ /**
440
+ * @name sg_mem_get_device_addr
441
+ * @brief To get a device memory descriptor's address
442
+ * @ingroup bmlib_runtime
443
+ *
444
+ * @param [in] mem The device memory descriptor queried
445
+ * @retval unsigned long long The device memory address
446
+ */
447
+ DECL_EXPORT unsigned long long sg_mem_get_device_addr(struct sg_mem_desc mem);
448
+
449
+ /**
450
+ * @name bm_mem_set_device_addr
451
+ * @brief To set a device memory descriptor's address
452
+ * @ingroup bmlib_runtime
453
+ *
454
+ * @param [in] pmem The device memory descriptor pointer
455
+ * @param [in] addr The new device address of the device memory
456
+ */
457
+ DECL_EXPORT void bm_mem_set_device_addr(struct bm_mem_desc* pmem, unsigned long long addr);
458
+
459
+ /**
460
+ * @name sg_mem_set_device_addr
461
+ * @brief To set a device memory descriptor's address
462
+ * @ingroup bmlib_runtime
463
+ *
464
+ * @param [in] pmem The device memory descriptor pointer
465
+ * @param [in] addr The new device address of the device memory
466
+ */
467
+ DECL_EXPORT void sg_mem_set_device_addr(struct sg_mem_desc* pmem, unsigned long long addr);
468
+
469
+ /**
470
+ * @name bm_mem_get_device_size
471
+ * @brief To get a device memory descriptor's size
472
+ * @ingroup bmlib_runtime
473
+ *
474
+ * @param [in] mem The device memory descriptor queried
475
+ * @retval unsigned int The device memory's size in bytes
476
+ */
477
+ DECL_EXPORT unsigned int bm_mem_get_device_size(struct bm_mem_desc mem);
478
+
479
+ /**
480
+ * @name sg_mem_get_device_size
481
+ * @brief To get a device memory descriptor's size
482
+ * @ingroup bmlib_runtime
483
+ *
484
+ * @param [in] mem The device memory descriptor queried
485
+ * @retval unsigned long long The device memory's size in bytes
486
+ */
487
+ DECL_EXPORT unsigned long long sg_mem_get_device_size(struct sg_mem_desc mem);
488
+
489
+ /**
490
+ * @name bm_mem_set_device_size
491
+ * @brief To set a device memory descriptor's size
492
+ * @ingroup bmlib_runtime
493
+ *
494
+ * @param [out] pmem The device memory descriptor pointer
495
+ * @param [in] size The new device memory size (in bytes) of the device memory
496
+ */
497
+ DECL_EXPORT void bm_mem_set_device_size(struct bm_mem_desc* pmem, unsigned int size);
498
+
499
+ /**
500
+ * @name sg_mem_set_device_size
501
+ * @brief To set a device memory descriptor's size
502
+ * @ingroup bmlib_runtime
503
+ *
504
+ * @param [out] pmem The device memory descriptor pointer
505
+ * @param [in] size The new device memory size (in bytes) of the device memory
506
+ */
507
+ DECL_EXPORT void sg_mem_set_device_size(struct sg_mem_desc* pmem, unsigned long long size);
508
+
509
+ /**
510
+ * @name bm_set_device_mem
511
+ * @brief To fill in a device memory descriptor with size and address
512
+ * @ingroup bmlib_runtime
513
+ *
514
+ * @param [in] pmem The device memory descriptor pointer
515
+ * @param [in] size The device memory descriptor's size
516
+ * @param [in] addr The device memory descriptor's address
517
+ */
518
+ DECL_EXPORT void bm_set_device_mem(bm_device_mem_t* pmem, unsigned int size,
519
+ unsigned long long addr);
520
+
521
+ /**
522
+ * @name sg_set_device_mem
523
+ * @brief To fill in a device memory descriptor with size and address
524
+ * @ingroup bmlib_runtime
525
+ *
526
+ * @param [in] pmem The device memory descriptor pointer
527
+ * @param [in] size The device memory descriptor's size
528
+ * @param [in] addr The device memory descriptor's address
529
+ */
530
+ DECL_EXPORT void sg_set_device_mem(sg_device_mem_t* pmem, unsigned long long size,
531
+ unsigned long long addr);
532
+
533
+ /**
534
+ * @name bm_mem_from_device
535
+ * @brief To create a device memory descriptor from address and size
536
+ * @ingroup bmlib_runtime
537
+ *
538
+ * @param [in] device_addr The device memory address
539
+ * @param [in] len The device memory size
540
+ * @retval bm_device_mem_t The device memory descriptor created
541
+ */
542
+ DECL_EXPORT bm_device_mem_t bm_mem_from_device(unsigned long long device_addr,
543
+ unsigned int len);
544
+
545
+ /**
546
+ * @name sg_mem_from_device
547
+ * @brief To create a device memory descriptor from address and size
548
+ * @ingroup bmlib_runtime
549
+ *
550
+ * @param [in] device_addr The device memory address
551
+ * @param [in] len The device memory size
552
+ * @retval sg_device_mem_t The device memory descriptor created
553
+ */
554
+ DECL_EXPORT sg_device_mem_t sg_mem_from_device(unsigned long long device_addr,
555
+ unsigned long long len);
556
+
557
+ /**
558
+ * @name bm_mem_get_system_addr
559
+ * @brief To get a system memory descriptor's address
560
+ * @ingroup bmlib_runtime
561
+ *
562
+ * @param [in] mem The system memory descriptor
563
+ * @retval void * The system memory descriptor's address
564
+ */
565
+ DECL_EXPORT void *bm_mem_get_system_addr(struct bm_mem_desc mem);
566
+
567
+ /**
568
+ * @name sg_mem_get_system_addr
569
+ * @brief To get a system memory descriptor's address
570
+ * @ingroup bmlib_runtime
571
+ *
572
+ * @param [in] mem The system memory descriptor
573
+ * @retval void * The system memory descriptor's address
574
+ */
575
+ DECL_EXPORT void *sg_mem_get_system_addr(struct sg_mem_desc mem);
576
+
577
+ /**
578
+ * @name bm_mem_set_system_addr
579
+ * @brief To set a system memory descriptor's address
580
+ * @ingroup bmlib_runtime
581
+ *
582
+ * @param [in] pmem The system memory descriptor pointer
583
+ * @param [in] addr The system memory address
584
+ */
585
+ DECL_EXPORT void bm_mem_set_system_addr(struct bm_mem_desc* pmem, void *addr);
586
+
587
+ /**
588
+ * @name sg_mem_set_system_addr
589
+ * @brief To set a system memory descriptor's address
590
+ * @ingroup bmlib_runtime
591
+ *
592
+ * @param [in] pmem The system memory descriptor pointer
593
+ * @param [in] addr The system memory address
594
+ */
595
+ DECL_EXPORT void sg_mem_set_system_addr(struct sg_mem_desc* pmem, void *addr);
596
+
597
+ /**
598
+ * @name bm_mem_from_system
599
+ * @brief To create a system memory descriptor with the given system address
600
+ * @ingroup bmlib_runtime
601
+ *
602
+ * @param [in] system_addr The system address in the descriptor
603
+ * @retval bm_system_mem_t The system memory descriptor created
604
+ */
605
+ DECL_EXPORT bm_system_mem_t bm_mem_from_system(void *system_addr);
606
+
607
+ /*******************memory alloc and free functions ***************************/
608
+ /**
609
+ * @name bm_mem_null
610
+ * @brief Return an illegal device memory descriptor
611
+ * @ingroup bmlib_runtime
612
+ *
613
+ * @retval bm_device_mem_t An invalid device memory descriptor
614
+ */
615
+ DECL_EXPORT bm_device_mem_t bm_mem_null(void);
616
+ #define BM_MEM_NULL (bm_mem_null())
617
+
618
+ /**
619
+ * @name bm_malloc_neuron_device
620
+ * @brief To malloc device memory according to a tensor shape
621
+ * (each neuron is 32 bits)
622
+ * @ingroup bmlib_runtime
623
+ *
624
+ * @param [in] handle The device handle
625
+ * @param [out] pmem The result device memory descriptor
626
+ * @param [in] n, c, h, w The shape of the input tensor
627
+ * @retval BM_SUCCESS Succeeds.
628
+ * Other code Fails.
629
+ */
630
+ DECL_EXPORT bm_status_t bm_malloc_neuron_device(bm_handle_t handle, bm_device_mem_t *pmem,
631
+ int n, int c, int h, int w);
632
+
633
+ /**
634
+ * @name sg_malloc_neuron_device
635
+ * @brief To malloc device memory according to a tensor shape
636
+ * (each neuron is 32 bits)
637
+ * @ingroup bmlib_runtime
638
+ *
639
+ * @param [in] handle The device handle
640
+ * @param [out] pmem The result device memory descriptor
641
+ * @param [in] n, c, h, w The shape of the input tensor
642
+ * @retval BM_SUCCESS Succeeds.
643
+ * Other code Fails.
644
+ */
645
+ DECL_EXPORT bm_status_t sg_malloc_neuron_device(bm_handle_t handle, sg_device_mem_t *pmem,
646
+ unsigned long long n, unsigned long long c,
647
+ unsigned long long h, unsigned long long w);
648
+
649
+ /**
650
+ * @name bm_malloc_device_dword
651
+ * @brief To malloc device memory in size of dword (32 bits)
652
+ * @ingroup bmlib_runtime
653
+ *
654
+ * @param [in] handle The device handle
655
+ * @param [out] pmem The result device memory descriptor
656
+ * @param [in] count The number of dwords(32bits) to allocate
657
+ * @retval BM_SUCCESS Succeeds.
658
+ * Other code Fails.
659
+ */
660
+ DECL_EXPORT bm_status_t bm_malloc_device_dword(bm_handle_t handle, bm_device_mem_t *pmem,
661
+ int count);
662
+
663
+ /**
664
+ * @name sg_malloc_device_dword
665
+ * @brief To malloc device memory in size of dword (32 bits)
666
+ * @ingroup bmlib_runtime
667
+ *
668
+ * @param [in] handle The device handle
669
+ * @param [out] pmem The result device memory descriptor
670
+ * @param [in] count The number of dwords(32bits) to allocate
671
+ * @retval BM_SUCCESS Succeeds.
672
+ * Other code Fails.
673
+ */
674
+ DECL_EXPORT bm_status_t sg_malloc_device_dword(bm_handle_t handle, sg_device_mem_t *pmem,
675
+ unsigned long long count);
676
+
677
+ /**
678
+ * @name bm_malloc_device_byte
679
+ * @brief To malloc device memory in size of byte
680
+ * @ingroup bmlib_runtime
681
+ *
682
+ * @param [in] handle The device handle
683
+ * @param [out] pmem The result device memory descriptor
684
+ * @param [in] size The number of bytes to allocate
685
+ * @retval BM_SUCCESS Succeeds.
686
+ * Other code Fails.
687
+ */
688
+ DECL_EXPORT bm_status_t bm_malloc_device_byte(bm_handle_t handle, bm_device_mem_t *pmem,
689
+ unsigned int size);
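+ /* Allocation sketch (illustrative only; handle, size and host_ptr are assumed
+  * to be prepared by the caller):
+  *   bm_device_mem_t mem;
+  *   if (bm_malloc_device_byte(handle, &mem, size) == BM_SUCCESS) {
+  *     bm_memcpy_s2d(handle, mem, host_ptr);  // declared below
+  *     bm_free_device(handle, mem);
+  *   }
+  */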
690
+
691
+ /**
692
+ * @name sg_malloc_device_byte
693
+ * @brief To malloc device memory in size of byte
694
+ * @ingroup bmlib_runtime
695
+ *
696
+ * @param [in] handle The device handle
697
+ * @param [out] pmem The result device memory descriptor
698
+ * @param [in] size The number of bytes to allocate
699
+ * @retval BM_SUCCESS Succeeds.
700
+ * Other code Fails.
701
+ */
702
+ DECL_EXPORT bm_status_t sg_malloc_device_byte(bm_handle_t handle, sg_device_mem_t *pmem,
703
+ unsigned long long size);
704
+
705
+ /**
706
+ * @name bm_malloc_device_byte_heap
707
+ * @brief To malloc device memory in size of byte within the specified heap
708
+ * @ingroup bmlib_runtime
709
+ *
710
+ * @param [in] handle The device handle
711
+ * @param [out] pmem The result device memory descriptor
712
+ * @param [in] heap_id The heap to allocate from (0/1/2)
713
+ * @param [in] size The number of bytes to allocate
714
+ * @retval BM_SUCCESS Succeeds.
715
+ * Other code Fails.
716
+ */
717
+ DECL_EXPORT bm_status_t bm_malloc_device_byte_heap(bm_handle_t handle, bm_device_mem_t *pmem,
718
+ int heap_id, unsigned int size);
719
+
720
+ /**
721
+ * @name sg_malloc_device_byte_heap
722
+ * @brief To malloc device memory in size of byte within the specified heap
723
+ * @ingroup bmlib_runtime
724
+ *
725
+ * @param [in] handle The device handle
726
+ * @param [out] pmem The result device memory descriptor
727
+ * @param [in] heap_id The heap to allocate from (0/1/2)
728
+ * @param [in] size The number of bytes to allocate
729
+ * @retval BM_SUCCESS Succeeds.
730
+ * Other code Fails.
731
+ */
732
+ DECL_EXPORT bm_status_t sg_malloc_device_byte_heap(bm_handle_t handle, sg_device_mem_t *pmem,
733
+ int heap_id, unsigned long long size);
734
+
735
+ /**
736
+ * @name bm_malloc_device_byte_heap_mask
737
+ * @brief To malloc device memory in size of byte within the specified heaps
738
+ * @ingroup bmlib_runtime
739
+ *
740
+ * @param [in] handle The device handle
741
+ * @param [out] pmem The result device memory descriptor
742
+ * @param [in] heap_id_mask The mask of heaps to allocate from; each bit indicates one heap
743
+ * @param [in] size The number of bytes to allocate
744
+ * @retval BM_SUCCESS Succeeds.
745
+ * Other code Fails.
746
+ */
747
+ DECL_EXPORT bm_status_t bm_malloc_device_byte_heap_mask(bm_handle_t handle, bm_device_mem_t *pmem,
748
+ int heap_id_mask, unsigned int size);
749
+
750
+ /**
751
+ * @name sg_malloc_device_byte_heap_mask
752
+ * @brief To malloc device memory in size of byte within the specified heaps
753
+ * @ingroup bmlib_runtime
754
+ *
755
+ * @param [in] handle The device handle
756
+ * @param [out] pmem The result device memory descriptor
757
+ * @param [in] heap_id_mask The mask of heaps to allocate from; each bit indicates one heap
758
+ * @param [in] size The number of bytes to allocate
759
+ * @retval BM_SUCCESS Succeeds.
760
+ * Other code Fails.
761
+ */
762
+ DECL_EXPORT bm_status_t sg_malloc_device_byte_heap_mask(bm_handle_t handle, sg_device_mem_t *pmem,
763
+ int heap_id_mask, unsigned long long size);
764
+
765
+ /**
766
+ * @name bm_free_device
767
+ * @brief To free device memory
768
+ * @ingroup bmlib_runtime
769
+ *
770
+ * @param [in] handle The device handle
771
+ * @param [in] mem The device memory descriptor to free
772
+ */
773
+ DECL_EXPORT void bm_free_device(bm_handle_t handle, bm_device_mem_t mem);
774
+
775
+ /**
776
+ * @name sg_free_device
777
+ * @brief To free device memory
778
+ * @ingroup bmlib_runtime
779
+ *
780
+ * @param [in] handle The device handle
781
+ * @param [in] mem The device memory descriptor to free
782
+ */
783
+ DECL_EXPORT void sg_free_device(bm_handle_t handle, sg_device_mem_t mem);
784
+
785
+ /**
786
+ * @name bm_gmem_arm_reserved_request
787
+ * @brief To obtain the address of global memory reserved for arm926
788
+ * @param [in] handle The device handle
789
+ *
790
+ * @retval unsigned long long The absolute address of gmem reserved for arm926
791
+ */
792
+ DECL_EXPORT unsigned long long bm_gmem_arm_reserved_request(bm_handle_t handle);
793
+
794
+ /**
795
+ * @name bm_gmem_arm_reserved_release
796
+ * @brief To release the global memory reserved for arm926
797
+ * @ingroup bmlib_runtime
798
+ *
799
+ * @param [in] handle The device handle
800
+ */
801
+ DECL_EXPORT void bm_gmem_arm_reserved_release(bm_handle_t handle);
802
+
803
+ /*******************memory copy functions *************************************/
804
+ /**
805
+ * @name bm_memcpy_s2d
806
+ * @brief To copy data from system memory to device memory
807
+ * @ingroup bmlib_runtime
808
+ *
809
+ * @param [in] handle The device handle
810
+ * @param [in] dst The destination memory (device memory descriptor )
811
+ * @param [in] src The source memory (system memory, a void* pointer)
812
+ *
813
+ * @retval BM_SUCCESS Succeeds.
814
+ * Other code Fails.
815
+ */
816
+ DECL_EXPORT bm_status_t bm_memcpy_s2d(bm_handle_t handle, bm_device_mem_t dst, void *src);
817
+
818
+ /**
819
+ * @name bm_memcpy_p2p
820
+ * @brief To copy data from one chip to another chip
821
+ * @ingroup bmlib_runtime
822
+ *
823
+ * @param [in] handle_src The source device handle
824
+ * @param [in] src The source memory (device memory descriptor )
825
+ * @param [in] handle_dst The destination device handle
826
+ * @param [in] dst The destination memory (device memory descriptor )
827
+ *
828
+ * @retval BM_SUCCESS Succeeds.
829
+ * Other code Fails.
830
+ */
831
+ DECL_EXPORT bm_status_t bm_memcpy_p2p(bm_handle_t handle_src, bm_device_mem_t src, bm_handle_t handle_dst,bm_device_mem_t dst);
832
+
833
+ /**
834
+ * @name sg_memcpy_s2d
835
+ * @brief To copy data from system memory to device memory
836
+ * @ingroup bmlib_runtime
837
+ *
838
+ * @param [in] handle The device handle
839
+ * @param [in] dst The destination memory (device memory descriptor )
840
+ * @param [in] src The source memory (system memory, a void* pointer)
841
+ *
842
+ * @retval BM_SUCCESS Succeeds.
843
+ * Other code Fails.
844
+ */
845
+ DECL_EXPORT bm_status_t sg_memcpy_s2d(bm_handle_t handle, sg_device_mem_t dst, void *src);
846
+
847
+ /**
848
+ * @name bm_memcpy_s2d_partial_offset
849
+ * @brief To copy specified bytes of data from system memory to device memory
850
+ * with an offset in device memory address.
851
+ * @ingroup bmlib_runtime
852
+ *
853
+ * @param [in] handle The device handle
854
+ * @param [in] dst The destination memory (device memory descriptor)
855
+ * @param [in] src The source memory (system memory, a void* pointer)
856
+ * @param [in] size The size of data to copy (in bytes)
857
+ * @param [in] offset The offset of the device memory address
858
+ *
859
+ * @retval BM_SUCCESS Succeeds.
860
+ * Other code Fails.
861
+ */
862
+ DECL_EXPORT bm_status_t bm_memcpy_s2d_partial_offset(bm_handle_t handle,
863
+ bm_device_mem_t dst, void *src,
864
+ unsigned int size,
865
+ unsigned int offset);
866
+
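+ /*
+  * Usage sketch (illustrative only): overwrite just the second half of a device
+  * buffer that was allocated with 2 * n bytes, leaving the first half intact.
+  * `mem` and `host_ptr` are hypothetical names for the buffer and source data.
+  *
+  *   bm_memcpy_s2d_partial_offset(handle, mem, host_ptr, n, n);
+  */
+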
867
+ /**
868
+ * @name sg_memcpy_s2d_partial_offset
869
+ * @brief To copy specified bytes of data from system memory to device memory
870
+ * with an offset in device memory address.
871
+ * @ingroup bmlib_runtime
872
+ *
873
+ * @param [in] handle The device handle
874
+ * @param [in] dst The destination memory (device memory descriptor)
875
+ * @param [in] src The source memory (system memory, a void* pointer)
876
+ * @param [in] size The size of data to copy (in bytes)
877
+ * @param [in] offset The offset of the device memory address
878
+ *
879
+ * @retval BM_SUCCESS Succeeds.
880
+ * Other code Fails.
881
+ */
882
+ DECL_EXPORT bm_status_t sg_memcpy_s2d_partial_offset(bm_handle_t handle,
883
+ sg_device_mem_t dst, void *src,
884
+ unsigned long long size,
885
+ unsigned long long offset);
886
+
887
+ /**
888
+ * @name bm_memcpy_s2d_partial
889
+ * @brief To copy specified bytes of data from system memory to device memory
890
+ * @ingroup bmlib_runtime
891
+ *
892
+ * @param [in] handle The device handle
893
+ * @param [in] dst The destination memory (device memory descriptor)
894
+ * @param [in] src The source memory (system memory, a void* pointer)
895
+ * @param [in] size The size of data to copy (in bytes)
896
+ *
897
+ * @retval BM_SUCCESS Succeeds.
898
+ * Other code Fails.
899
+ */
900
+ DECL_EXPORT bm_status_t bm_memcpy_s2d_partial(bm_handle_t handle, bm_device_mem_t dst,
901
+ void *src, unsigned int size);
902
+
903
+ /**
904
+ * @name sg_memcpy_s2d_partial
905
+ * @brief To copy specified bytes of data from system memory to device memory
906
+ * @ingroup bmlib_runtime
907
+ *
908
+ * @param [in] handle The device handle
909
+ * @param [in] dst The destination memory (device memory descriptor)
910
+ * @param [in] src The source memory (system memory, a void* pointer)
911
+ * @param [in] size The size of data to copy (in bytes)
912
+ *
913
+ * @retval BM_SUCCESS Succeeds.
914
+ * Other code Fails.
915
+ */
916
+ DECL_EXPORT bm_status_t sg_memcpy_s2d_partial(bm_handle_t handle, sg_device_mem_t dst,
917
+ void *src, unsigned long long size);
918
+
919
+ /**
920
+ * @name bm_memcpy_d2s
921
+ * @brief To copy data from device memory to system memory
922
+ * @ingroup bmlib_runtime
923
+ *
924
+ * @param [in] handle The device handle
925
+ * @param [in] dst The destination memory (system memory, a void* pointer)
926
+ * @param [in] src The source memory (device memory descriptor)
927
+ *
928
+ * @retval BM_SUCCESS Succeeds.
929
+ * Other code Fails.
930
+ */
931
+ DECL_EXPORT bm_status_t bm_memcpy_d2s(bm_handle_t handle, void *dst, bm_device_mem_t src);
932
+
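+ /*
+  * Usage sketch (illustrative only): round-trip a host array through device
+  * memory. `mem` is assumed to have been allocated with exactly sizeof(in)
+  * bytes, since bm_memcpy_s2d/bm_memcpy_d2s transfer the full descriptor size.
+  *
+  *   float in[256] = {0}, out[256];
+  *   bm_memcpy_s2d(handle, mem, in);    // host -> device
+  *   bm_memcpy_d2s(handle, out, mem);   // device -> host
+  */
+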
933
+ /**
934
+ * @name sg_memcpy_d2s
935
+ * @brief To copy data from device memory to system memory
936
+ * @ingroup bmlib_runtime
937
+ *
938
+ * @param [in] handle The device handle
939
+ * @param [in] dst The destination memory (system memory, a void* pointer)
940
+ * @param [in] src The source memory (device memory descriptor)
941
+ *
942
+ * @retval BM_SUCCESS Succeeds.
943
+ * Other code Fails.
944
+ */
945
+ DECL_EXPORT bm_status_t sg_memcpy_d2s(bm_handle_t handle, void *dst, sg_device_mem_t src);
946
+
947
+ /**
948
+ * @name bm_memcpy_d2s_partial_offset
949
+ * @brief To copy specified bytes of data from device memory to system memory
950
+ * with an offset in device memory address.
951
+ * @ingroup bmlib_runtime
952
+ *
953
+ * @param [in] handle The device handle
954
+ * @param [in] dst The destination memory (system memory, a void* pointer)
955
+ * @param [in] src The source memory (device memory descriptor)
956
+ * @param [in] size The size of data to copy (in bytes)
957
+ * @param [in] offset The offset of the device memory address
958
+ *
959
+ * @retval BM_SUCCESS Succeeds.
960
+ * Other code Fails.
961
+ */
962
+ DECL_EXPORT bm_status_t bm_memcpy_d2s_partial_offset(bm_handle_t handle, void *dst,
963
+ bm_device_mem_t src, unsigned int size,
964
+ unsigned int offset);
965
+
966
+ /**
967
+ * @name sg_memcpy_d2s_partial_offset
968
+ * @brief To copy specified bytes of data from device memory to system memory
969
+ * with an offset in device memory address.
970
+ * @ingroup bmlib_runtime
971
+ *
972
+ * @param [in] handle The device handle
973
+ * @param [in] dst The destination memory (system memory, a void* pointer)
974
+ * @param [in] src The source memory (device memory descriptor)
975
+ * @param [in] size The size of data to copy (in bytes)
976
+ * @param [in] offset The offset of the device memory address
977
+ *
978
+ * @retval BM_SUCCESS Succeeds.
979
+ * Other code Fails.
980
+ */
981
+ DECL_EXPORT bm_status_t sg_memcpy_d2s_partial_offset(bm_handle_t handle, void *dst,
982
+ sg_device_mem_t src, unsigned long long size,
983
+ unsigned long long offset);
984
+
985
+ /**
986
+ * @name bm_memcpy_d2s_partial
987
+ * @brief To copy specified bytes of data from device memory to system memory
988
+ * @ingroup bmlib_runtime
989
+ *
990
+ * @param [in] handle The device handle
991
+ * @param [in] dst The destination memory (system memory, a void* pointer)
992
+ * @param [in] src The source memory (device memory descriptor)
993
+ * @param [in] size The size of data to copy (in bytes)
994
+ *
995
+ * @retval BM_SUCCESS Data transfer succeeds.
996
+ * Other code Data transfer fails.
997
+ */
998
+ DECL_EXPORT bm_status_t bm_memcpy_d2s_partial(bm_handle_t handle, void *dst,
999
+ bm_device_mem_t src, unsigned int size);
1000
+
1001
+ /**
1002
+ * @name sg_memcpy_d2s_partial
1003
+ * @brief To copy specified bytes of data from device memory to system memory
1004
+ * @ingroup bmlib_runtime
1005
+ *
1006
+ * @param [in] handle The device handle
1007
+ * @param [in] dst The destination memory (system memory, a void* pointer)
1008
+ * @param [in] src The source memory (device memory descriptor)
1009
+ * @param [in] size The size of data to copy (in bytes)
1010
+ *
1011
+ * @retval BM_SUCCESS Data transfer succeeds.
1012
+ * Other code Data transfer fails.
1013
+ */
1014
+ DECL_EXPORT bm_status_t sg_memcpy_d2s_partial(bm_handle_t handle, void *dst,
1015
+ sg_device_mem_t src, unsigned long long size);
1016
+
1017
+ /**
1018
+ * @name bm_memcpy_d2d
1019
+ * @brief To copy specified dwords of data from one piece of device memory
1020
+ * to another piece of device memory within one device. Both source
1021
+ * and destination offsets can be specified.
1022
+ * @ingroup bmlib_runtime
1023
+ *
1024
+ * @param [in] handle The device handle
1025
+ * @param [in] dst The destination device memory
1026
+ * @param [in] dst_offset The offset of destination device memory address
1027
+ * @param [in] src The source device memory
1028
+ * @param [in] src_offset The offset of source device memory address
1029
+ * @param [in] len Length of data to copy (in DWORD 4 bytes)
1030
+ *
1031
+ * @retval BM_SUCCESS Succeeds.
1032
+ * Other code Fails.
1033
+ */
1034
+ DECL_EXPORT bm_status_t bm_memcpy_d2d(bm_handle_t handle, bm_device_mem_t dst,
1035
+ int dst_offset, bm_device_mem_t src, int src_offset,
1036
+ int len);
1037
+
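+ /*
+  * Usage sketch (illustrative only): duplicate 1024 bytes between two device
+  * buffers. bm_memcpy_d2d counts in DWORDs (4 bytes), so 1024 bytes is a
+  * length of 256; both offsets are left at 0 here.
+  *
+  *   bm_memcpy_d2d(handle, dst_mem, 0, src_mem, 0, 256);
+  */
+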
1038
+ /**
1039
+ * @name bm_memcpy_d2d_with_core
1040
+ * @brief To copy specified dwords of data from one piece of device memory
1041
+ * to another piece of device memory within one device. Both source
1042
+ * and destination offsets can be specified.
1043
+ * @ingroup bmlib_runtime
1044
+ *
1045
+ * @param [in] handle The device handle
1046
+ * @param [in] dst The destination device memory
1047
+ * @param [in] dst_offset The offset of destination device memory address
1048
+ * @param [in] src The source device memory
1049
+ * @param [in] src_offset The offset of source device memory address
1050
+ * @param [in] len Length of data to copy (in DWORD 4 bytes)
1051
+ * @param [in] core_id The core id to copy
1052
+ *
1053
+ * @retval BM_SUCCESS Succeeds.
1054
+ * Other code Fails.
1055
+ */
1056
+ DECL_EXPORT bm_status_t bm_memcpy_d2d_with_core(bm_handle_t handle, bm_device_mem_t dst,
1057
+ int dst_offset, bm_device_mem_t src, int src_offset,
1058
+ int len, int core_id);
1059
+
1060
+ /**
1061
+ * @name bm_memcpy_d2d_byte
1062
+ * @brief To copy specified bytes of data from one piece of device memory
1063
+ * to another piece of device memory within one device. Both source
1064
+ * and destination offsets can be specified.
1065
+ * @ingroup bmlib_runtime
1066
+ *
1067
+ * @param [in] handle The device handle
1068
+ * @param [in] dst The destination device memory
1069
+ * @param [in] dst_offset The offset of destination device memory address (in bytes)
1070
+ * @param [in] src The source device memory
1071
+ * @param [in] src_offset The offset of source device memory address (in bytes)
1072
+ * @param [in] size Size of data to copy (in bytes)
1073
+ *
1074
+ * @retval BM_SUCCESS Succeeds.
1075
+ * Other code Fails.
1076
+ */
1077
+ DECL_EXPORT bm_status_t bm_memcpy_d2d_byte(bm_handle_t handle, bm_device_mem_t dst,
1078
+ size_t dst_offset, bm_device_mem_t src,
1079
+ size_t src_offset, size_t size);
1080
+
1081
+ /**
1082
+ * @name bm_memcpy_d2d_byte_with_core
1083
+ * @brief To copy specified bytes of data from one piece of device memory
1084
+ * to another piece of device memory within one device. Both source
1085
+ * and destination offsets can be specified.
1086
+ * @ingroup bmlib_runtime
1087
+ *
1088
+ * @param [in] handle The device handle
1089
+ * @param [in] dst The destination device memory
1090
+ * @param [in] dst_offset The offset of destination device memory address (in bytes)
1091
+ * @param [in] src The source device memory
1092
+ * @param [in] src_offset The offset of source device memory address (in bytes)
1093
+ * @param [in] size Size of data to copy (in bytes)
1094
+ * @param [in] core_id The core id to copy
1095
+ *
1096
+ * @retval BM_SUCCESS Succeeds.
1097
+ * Other code Fails.
1098
+ */
1099
+ DECL_EXPORT bm_status_t bm_memcpy_d2d_byte_with_core(bm_handle_t handle, bm_device_mem_t dst,
1100
+ size_t dst_offset, bm_device_mem_t src,
1101
+ size_t src_offset, size_t size, int core_id);
1102
+
1103
+ /**
1104
+ * @name bm_memcpy_d2d_stride
1105
+ * @brief To copy specified data from one piece of device memory
1106
+ * to another piece of device memory within one device. Both source
1107
+ * and destination offsets can be specified.
1108
+ * @ingroup bmlib_runtime
1109
+ *
1110
+ * @param [in] handle The device handle
1111
+ * @param [in] dst The destination device memory
1112
+ * @param [in] dst_stride The data stride of destination data
1113
+ * @param [in] src The source device memory
1114
+ * @param [in] src_stride The data stride of source data
1115
+ * @param [in] count Count of data to copy
1116
+ * @param [in] format_size Data format byte size, such as sizeof(uint8_t), sizeof(float), etc.
1117
+ * format_size only supports 1/2/4.
1118
+ *
1119
+ * dst_stride MUST be 1, EXCEPT: dst_stride == 4 && src_stride == 1 && format_size ==1
1120
+ *
1121
+ * @retval BM_SUCCESS Succeeds.
1122
+ * Other code Fails.
1123
+ */
1124
+ DECL_EXPORT bm_status_t bm_memcpy_d2d_stride(bm_handle_t handle,
1125
+ bm_device_mem_t dst,
1126
+ int dst_stride,
1127
+ bm_device_mem_t src,
1128
+ int src_stride,
1129
+ int count,
1130
+ int format_size);
1131
+
1132
+ /**
1133
+ * @name bm_memcpy_d2d_stride_with_core
1134
+ * @brief To copy specified data from one piece of device memory
1135
+ * to another piece of device memory within one device. Both source
1136
+ * and destination offsets can be specified.
1137
+ * @ingroup bmlib_runtime
1138
+ *
1139
+ * @param [in] handle The device handle
1140
+ * @param [in] dst The destination device memory
1141
+ * @param [in] dst_stride The data stride of destination data
1142
+ * @param [in] src The source device memory
1143
+ * @param [in] src_stride The data stride of source data
1144
+ * @param [in] count Count of data to copy
1145
+ * @param [in] format_size Data format byte size, such as sizeof(uint8_t), sizeof(float), etc.
1146
+ * format_size only supports 1/2/4.
1147
+ * @param [in] core_id The core id to copy.
1148
+ *
1149
+ * dst_stride MUST be 1, EXCEPT: dst_stride == 4 && src_stride == 1 && format_size ==1
1150
+ *
1151
+ * @retval BM_SUCCESS Succeeds.
1152
+ * Other code Fails.
1153
+ */
1154
+ DECL_EXPORT bm_status_t bm_memcpy_d2d_stride_with_core(bm_handle_t handle,
1155
+ bm_device_mem_t dst,
1156
+ int dst_stride,
1157
+ bm_device_mem_t src,
1158
+ int src_stride,
1159
+ int count,
1160
+ int format_size,
1161
+ int core_id);
1162
+
1163
+ /**
1164
+ * @name bm_memcpy_c2c
1165
+ * @brief To copy data from one chip to another chip.
1166
+ * (Used in multi-chip card scenario)
1167
+ * @ingroup bmlib_runtime
1168
+ *
1169
+ * @param [in] src_handle The source device handle
1170
+ * @param [in] dst_handle The destination device handle
1171
+ * @param [in] src The source device memory descriptor
1172
+ * @param [in] dst The destination device memory descriptor
1173
+ * @param [in] force_dst_cdma Whether to use the CDMA engine of the destination device
1174
+ * @retval BM_SUCCESS Succeeds.
1175
+ * Other code Fails.
1176
+ */
1177
+ DECL_EXPORT bm_status_t bm_memcpy_c2c(bm_handle_t src_handle, bm_handle_t dst_handle,
1178
+ bm_device_mem_t src, bm_device_mem_t dst,
1179
+ bool force_dst_cdma);
1180
+
1181
+ /**
1182
+ * @name bm_memset_device
1183
+ * @brief To fill in specified device memory with the given value
1184
+ * @ingroup bmlib_runtime
1185
+ *
1186
+ * @param [in] handle The device handle
1187
+ * @param [in] value The value used to fill. (int type)
1188
+ * @param [in] mem The device memory which will be filled in
1189
+ * @retval BM_SUCCESS Succeeds.
1190
+ * Other code Fails.
1191
+ */
1192
+ DECL_EXPORT bm_status_t bm_memset_device(bm_handle_t handle, const int value,
1193
+ bm_device_mem_t mem);
1194
+
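+ /*
+  * Usage sketch (illustrative only): clear a previously allocated device
+  * buffer before reuse; the whole region described by `mem` is filled.
+  *
+  *   bm_memset_device(handle, 0, mem);
+  */
+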
1195
+ /**
1196
+ * @name bm_memset_device_ext
1197
+ * @brief To fill in specified device memory with the given value and mode
1198
+ * @ingroup bmlib_runtime
1199
+ *
1200
+ * @param [in] handle The device handle
1201
+ * @param [in] value The pointer of value used to fill
1202
+ * @param [in] mode The valid bytes of *value
1203
+ * @param [in] mem The device memory which will be filled in
1204
+ * @retval BM_SUCCESS Succeeds.
1205
+ * Other code Fails.
1206
+ */
1207
+ DECL_EXPORT bm_status_t bm_memset_device_ext(bm_handle_t handle, void* value, int mode,
1208
+ bm_device_mem_t mem);
1209
+
1210
+ /**
1211
+ * @name bm_mem_convert_system_to_device_neuron
1212
+ * @brief To malloc a piece of device memory according to the shape of
1213
+ * neuron(in DWORD 4 bytes); copy neuron from system memory to
1214
+ * device memory if need_copy is true.
1215
+ * @ingroup bmlib_runtime
1216
+ *
1217
+ * @param [in] handle The device handle
1218
+ * @param [in] dev_mem The device memory descriptor
1219
+ * @param [in] sys_mem The system memory descriptor
1220
+ * @param [in] need_copy If copy from system to device is needed
1221
+ * @param [in] n,c,h,w Neuron shape size
1222
+ *
1223
+ * @retval BM_SUCCESS Succeeds.
1224
+ * Other code Fails.
1225
+ */
1226
+ DECL_EXPORT bm_status_t bm_mem_convert_system_to_device_neuron(bm_handle_t handle,
1227
+ struct bm_mem_desc *dev_mem,
1228
+ struct bm_mem_desc sys_mem,
1229
+ bool need_copy, int n, int c,
1230
+ int h, int w);
1231
+
1232
+ /**
1233
+ * @name bm_mem_convert_system_to_device_neuron_byte
1234
+ * @brief To malloc a piece of device memory according to the shape of
1235
+ * neuron(in bytes); copy neuron from system memory to
1236
+ * device memory if need_copy is true.
1237
+ * @ingroup bmlib_runtime
1238
+ *
1239
+ * @param [in] handle The device handle
1240
+ * @param [in] dev_mem The device memory descriptor
1241
+ * @param [in] sys_mem The system memory descriptor
1242
+ * @param [in] need_copy If copy from system to device is needed
1243
+ * @param [in] n,c,h,w Neuron shape size
1244
+ *
1245
+ * @retval BM_SUCCESS Succeeds.
1246
+ * Other code Fails.
1247
+ */
1248
+ DECL_EXPORT bm_status_t bm_mem_convert_system_to_device_neuron_byte(
1249
+ bm_handle_t handle, struct bm_mem_desc *dev_mem, struct bm_mem_desc sys_mem,
1250
+ bool need_copy, int n, int c, int h, int w);
1251
+
1252
+ /**
1253
+ * @name bm_mem_convert_system_to_device_coeff
1254
+ * @brief To malloc a piece of device memory according to the size of
1255
+ * coefficient (in DWORD 4 bytes); copy coefficient from system
1256
+ * memory to device memory if need_copy is true.
1257
+ * @ingroup bmlib_runtime
1258
+ *
1259
+ * @param [in] handle The device handle
1260
+ * @param [in] dev_mem The device memory descriptor
1261
+ * @param [in] sys_mem The system memory descriptor
1262
+ * @param [in] need_copy If copy from system to device is needed
1263
+ * @param [in] coeff_count Coefficient size
1264
+ *
1265
+ * @retval BM_SUCCESS Succeeds.
1266
+ * Other code Fails.
1267
+ */
1268
+ DECL_EXPORT bm_status_t bm_mem_convert_system_to_device_coeff(bm_handle_t handle,
1269
+ struct bm_mem_desc *dev_mem,
1270
+ struct bm_mem_desc sys_mem,
1271
+ bool need_copy,
1272
+ int coeff_count);
1273
+ /**
1274
+ * @name bm_mem_convert_system_to_device_coeff_byte
1275
+ * @brief To malloc a piece of device memory according to the size of
1276
+ * coefficient (in bytes); copy coefficient from system
1277
+ * memory to device memory if need_copy is true.
1278
+ * @ingroup bmlib_runtime
1279
+ *
1280
+ * @param [in] handle The device handle
1281
+ * @param [in] dev_mem The device memory descriptor
1282
+ * @param [in] sys_mem The system memory descriptor
1283
+ * @param [in] need_copy If copy from system to device is needed
1284
+ * @param [in] coeff_count Coefficient size
1285
+ *
1286
+ * @retval BM_SUCCESS Succeeds.
1287
+ * Other code Fails.
1288
+ */
1289
+ DECL_EXPORT bm_status_t bm_mem_convert_system_to_device_coeff_byte(
1290
+ bm_handle_t handle, struct bm_mem_desc *dev_mem, struct bm_mem_desc sys_mem,
1291
+ bool need_copy, int coeff_count);
1292
+
1293
+ /*******************memory map functions *************************************/
1294
+ /**
1295
+ * @name bm_mem_mmap_device_mem
1296
+ * @brief To map a piece of device memory to user space with cache enabled.
1297
+ * (only valid in SoC mode; Not supported in PCIE mode).
1298
+ * @ingroup bmlib_runtime
1299
+ *
1300
+ * @param [in] handle The device handle
1301
+ * @param [in] dev_mem The device memory to map
1302
+ * @param [out] vmem The virtual address of the mapped device memory
1303
+ *
1304
+ * @retval BM_SUCCESS Succeeds.
1305
+ * Other code Fails.
1306
+ */
1307
+ DECL_EXPORT bm_status_t bm_mem_mmap_device_mem(bm_handle_t handle, bm_device_mem_t *dmem,
1308
1309
+ unsigned long long *vmem);
1310
+
1311
+ /**
1312
+ * @name sg_mem_mmap_device_mem
1313
+ * @brief To map a piece of device memory to user space with cache enabled.
1314
+ * (only valid in SoC mode; Not supported in PCIE mode).
1315
+ * @ingroup bmlib_runtime
1316
+ *
1317
+ * @param [in] handle The device handle
1318
+ * @param [in] dev_mem The device memory to map
1319
+ * @param [out] vmem The virtual address of the mapped device memory
1320
+ *
1321
+ * @retval BM_SUCCESS Succeeds.
1322
+ * Other code Fails.
1323
+ */
1324
+ DECL_EXPORT bm_status_t sg_mem_mmap_device_mem(bm_handle_t handle, sg_device_mem_t *dmem,
1325
+ unsigned long long *vmem);
1326
+
1327
+ /*******************memory map functions *************************************/
1328
+ /**
1329
+ * @name bm_mem_mmap_device_mem_no_cache
1330
+ * @brief To map a piece of device memory to user space with cache disabled.
1331
+ * (only valid in SoC mode; Not supported in PCIE mode).
1332
+ * @ingroup bmlib_runtime
1333
+ *
1334
+ * @param [in] handle The device handle
1335
+ * @param [in] dev_mem The device memory to map
1336
+ * @param [out] vmem The virtual address of the mapped device memory
1337
+ *
1338
+ * @retval BM_SUCCESS Succeeds.
1339
+ * Other code Fails.
1340
+ */
1341
+ DECL_EXPORT bm_status_t bm_mem_mmap_device_mem_no_cache(bm_handle_t handle, bm_device_mem_t *dmem,
1342
1343
+ unsigned long long *vmem);
1344
+
1345
+ /**
1346
+ * @name sg_mem_mmap_device_mem_no_cache
1347
+ * @brief To map a piece of device memory to user space with cache disabled.
1348
+ * (only valid in SoC mode; Not supported in PCIE mode).
1349
+ * @ingroup bmlib_runtime
1350
+ *
1351
+ * @param [in] handle The device handle
1352
+ * @param [in] dev_mem The device memory to map
1353
+ * @param [out] vmem The virtual address of the mapped device memory
1354
+ *
1355
+ * @retval BM_SUCCESS Succeeds.
1356
+ * Other code Fails.
1357
+ */
1358
+ DECL_EXPORT bm_status_t sg_mem_mmap_device_mem_no_cache(bm_handle_t handle, sg_device_mem_t *dmem,
1359
+ unsigned long long *vmem);
1360
+
1361
+ /**
1362
+ * @name bm_mem_vir_to_phy
1363
+ * @brief To get the device memory address through the mapped virtual address.
1364
+ * (only valid in SoC mode; Not supported in PCIE mode).
1365
+ * @ingroup bmlib_runtime
1366
+ *
1367
+ * @param [in] handle The device handle
1368
+ * @param [in] vmem The virtual address of the mapped device memory
1369
+ * @param [out] dev_mem The device memory address
1370
+ *
1371
+ * @retval BM_SUCCESS Succeeds.
1372
+ * Other code Fails.
1373
+ */
1374
+ DECL_EXPORT bm_status_t bm_mem_vir_to_phy(bm_handle_t handle, unsigned long long vmem,
1375
+ unsigned long long *device_mem);
1376
+ /**
1377
+ * @name bm_mem_invalidate_device_mem
1378
+ * @brief To invalidate a piece of mapped device memory to maintain
1379
+ * cache coherence
1380
+ * (only valid in SoC mode; Not supported in PCIE mode).
1381
+ * @ingroup bmlib_runtime
1382
+ *
1383
+ * @param [in] handle The device handle
1384
+ * @param [in] dmem The device memory to invalidate
1385
+ *
1386
+ * @retval BM_SUCCESS Succeeds.
1387
+ * Other code Fails.
1388
+ */
1389
+
1390
+ DECL_EXPORT bm_status_t bm_mem_invalidate_device_mem(bm_handle_t handle,
1391
+ bm_device_mem_t *dmem);
1392
+
1393
+ /**
1394
+ * @name sg_mem_invalidate_device_mem
1395
+ * @brief To invalidate a piece of mapped device memory to maintain
1396
+ * cache coherence
1397
+ * (only valid in SoC mode; Not supported in PCIE mode).
1398
+ * @ingroup bmlib_runtime
1399
+ *
1400
+ * @param [in] handle The device handle
1401
+ * @param [in] dmem The device memory to invalidate
1402
+ *
1403
+ * @retval BM_SUCCESS Succeeds.
1404
+ * Other code Fails.
1405
+ */
1406
+
1407
+ DECL_EXPORT bm_status_t sg_mem_invalidate_device_mem(bm_handle_t handle,
1408
+ sg_device_mem_t *dmem);
1409
+
1410
+ /**
1411
+ * @name bm_mem_invalidate_partial_device_mem
1412
+ * @brief To invalidate part of mapped device memory to maintain
1413
+ * cache coherence
1414
+ * (only valid in SoC mode; Not supported in PCIE mode).
1415
+ * @ingroup bmlib_runtime
1416
+ *
1417
+ * @param [in] handle The device handle
1418
+ * @param [in] dmem The device memory to invalidate
1419
+ * @param [in] offset The offset of device memory address
1420
+ * @param [in] len The length of memory to invalidate in bytes
1421
+ *
1422
+ * @retval BM_SUCCESS Succeeds.
1423
+ * Other code Fails.
1424
+ */
1425
+ DECL_EXPORT bm_status_t bm_mem_invalidate_partial_device_mem(bm_handle_t handle,
1426
+ bm_device_mem_t *dmem,
1427
+ unsigned int offset,
1428
+ unsigned int len);
1429
+
1430
+ /**
1431
+ * @name sg_mem_invalidate_partial_device_mem
1432
+ * @brief To invalidate part of mapped device memory to maintain
1433
+ * cache coherence
1434
+ * (only valid in SoC mode; Not supported in PCIE mode).
1435
+ * @ingroup bmlib_runtime
1436
+ *
1437
+ * @param [in] handle The device handle
1438
+ * @param [in] dmem The device memory to invalidate
1439
+ * @param [in] offset The offset of device memory address
1440
+ * @param [in] len The length of memory to invalidate in bytes
1441
+ *
1442
+ * @retval BM_SUCCESS Succeeds.
1443
+ * Other code Fails.
1444
+ */
1445
+ DECL_EXPORT bm_status_t sg_mem_invalidate_partial_device_mem(bm_handle_t handle,
1446
+ sg_device_mem_t *dmem,
1447
+ unsigned long long offset,
1448
+ unsigned long long len);
1449
+
1450
+ /**
1451
+ * @name bm_mem_flush_device_mem
1452
+ * @brief To flush a piece of mapped device memory to maintain
1453
+ * cache coherence
1454
+ * (only valid in SoC mode; Not supported in PCIE mode).
1455
+ * @ingroup bmlib_runtime
1456
+ *
1457
+ * @param [in] handle The device handle
1458
+ * @param [in] dmem The device memory to flush
1459
+ *
1460
+ * @retval BM_SUCCESS Succeeds.
1461
+ * Other code Fails.
1462
+ */
1463
+ DECL_EXPORT bm_status_t bm_mem_flush_device_mem(bm_handle_t handle, bm_device_mem_t *dmem);
1464
+
1465
+ /**
1466
+ * @name sg_mem_flush_device_mem
1467
+ * @brief To flush a piece of mapped device memory to maintain
1468
+ * cache coherence
1469
+ * (only valid in SoC mode; Not supported in PCIE mode).
1470
+ * @ingroup bmlib_runtime
1471
+ *
1472
+ * @param [in] handle The device handle
1473
+ * @param [in] dmem The device memory to flush
1474
+ *
1475
+ * @retval BM_SUCCESS Succeeds.
1476
+ * Other code Fails.
1477
+ */
1478
+ DECL_EXPORT bm_status_t sg_mem_flush_device_mem(bm_handle_t handle, sg_device_mem_t *dmem);
1479
+
1480
+ /**
1481
+ * @name bm_mem_flush_partial_device_mem
1482
+ * @brief To flush part of mapped device memory to maintain
1483
+ * cache coherence
1484
+ * (only valid in SoC mode; Not supported in PCIE mode).
1485
+ * @ingroup bmlib_runtime
1486
+ *
1487
+ * @param [in] handle The device handle
1488
+ * @param [in] dmem The device memory to flush
1489
+ * @param [in] offset The offset of device memory address
1490
+ * @param [in] len The length of memory to flush in bytes
1491
+ *
1492
+ * @retval BM_SUCCESS Succeeds.
1493
+ * Other code Fails.
1494
+ */
1495
+ DECL_EXPORT bm_status_t bm_mem_flush_partial_device_mem(bm_handle_t handle,
1496
+ bm_device_mem_t *dmem,
1497
+ unsigned int offset,
1498
+ unsigned int len);
1499
+
1500
+ /**
1501
+ * @name sg_mem_flush_partial_device_mem
1502
+ * @brief To flush part of mapped device memory to maintain
1503
+ * cache coherence
1504
+ * (only valid in SoC mode; Not supported in PCIE mode).
1505
+ * @ingroup bmlib_runtime
1506
+ *
1507
+ * @param [in] handle The device handle
1508
+ * @param [in] dmem The device memory to flush
1509
+ * @param [in] offset The offset of device memory address
1510
+ * @param [in] len The length of memory to flush in bytes
1511
+ *
1512
+ * @retval BM_SUCCESS Succeeds.
1513
+ * Other code Fails.
1514
+ */
1515
+ DECL_EXPORT bm_status_t sg_mem_flush_partial_device_mem(bm_handle_t handle,
1516
+ sg_device_mem_t *dmem,
1517
+ unsigned long long offset,
1518
+ unsigned long long len);
1519
+
1520
+ /**
1521
+ * @name bm_mem_unmap_device_mem
1522
+ * @brief To unmap a piece of mapped device memory
1523
+ * (only valid in SoC mode; Not supported in PCIE mode).
1524
+ * @ingroup bmlib_runtime
1525
+ *
1526
+ * @param [in] handle The device handle
1527
+ * @param [in] vmem The virtual address of the mapped device memory
1528
+ * @param [in] size The size of unmapped memory
1529
+ *
1530
+ * @retval BM_SUCCESS Succeeds.
1531
+ * Other code Fails.
1532
+ */
1533
+ DECL_EXPORT bm_status_t bm_mem_unmap_device_mem(bm_handle_t handle, void *vmem, int size);
1534
+
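+ /*
+  * Usage sketch (illustrative only, SoC mode): map a cached device buffer into
+  * user space, write it from the CPU, flush the cache so the device sees the
+  * data, then unmap. `size` is assumed to be the allocation size of `mem`.
+  *
+  *   unsigned long long vmem = 0;
+  *   if (bm_mem_mmap_device_mem(handle, &mem, &vmem) == BM_SUCCESS) {
+  *       memset((void *)(uintptr_t)vmem, 0, size);    // CPU-side write
+  *       bm_mem_flush_device_mem(handle, &mem);       // make it visible to the device
+  *       bm_mem_unmap_device_mem(handle, (void *)(uintptr_t)vmem, size);
+  *   }
+  */
+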
1535
+ /**
1536
+ * @name sg_mem_unmap_device_mem
1537
+ * @brief To unmap a piece of mapped device memory
1538
+ * (only valid in SoC mode; Not supported in PCIE mode).
1539
+ * @ingroup bmlib_runtime
1540
+ *
1541
+ * @param [in] handle The device handle
1542
+ * @param [in] vmem The virtual address of the mapped device memory
1543
+ * @param [in] size The size of unmapped memory
1544
+ *
1545
+ * @retval BM_SUCCESS Succeeds.
1546
+ * Other code Fails.
1547
+ */
1548
+ DECL_EXPORT bm_status_t sg_mem_unmap_device_mem(bm_handle_t handle, void *vmem, unsigned long long size);
1549
+
1550
+ /*******************api(kernel) functions *************************************/
1551
+ /**
1552
+ * @name bm_flush
1553
+ * @brief To synchronize APIs of the current thread. The thread will block
1554
+ * until all the outstanding APIs of the current thread are finished.
1555
+ * @ingroup bmlib_runtime
1556
+ *
1557
+ * @param [in] handle The device handle
1558
+ */
1559
+ DECL_EXPORT void bm_flush(bm_handle_t handle);
1560
+
1561
+ /**
1562
+ * @name bm_device_sync
1563
+ * @brief To synchronize APIs of the device. The thread will block
1564
+ * until all the outstanding APIs of the device are finished.
1565
+ * @ingroup bmlib_runtime
1566
+ *
1567
+ * @param [in] handle The device handle
1568
+ * @retval BM_SUCCESS Succeeds.
1569
+ * Other code Fails.
1570
+ */
1571
+ DECL_EXPORT bm_status_t bm_device_sync(bm_handle_t handle);
1572
+
1573
+ /**
1574
+ * @name bm_handle_sync
1575
+ * @brief To synchronize APIs of the handle. The thread will block
1576
+ * until all the outstanding APIs of the handle are finished.
1577
+ * @ingroup bmlib_runtime
1578
+ *
1579
+ * @param [in] handle The device handle
1580
+ * @retval BM_SUCCESS Succeeds.
1581
+ * Other code Fails.
1582
+ */
1583
+ DECL_EXPORT bm_status_t bm_handle_sync(bm_handle_t handle);
1584
+
1585
+ /**
1586
+ * @name bm_handle_sync_from_core
1587
+ * @brief To synchronize APIs of the handle. The thread will block
1588
+ * until all the outstanding APIs of the handle are finished.
1589
+ * @ingroup bmlib_runtime
1590
+ *
1591
+ * @param [in] handle The device handle
1592
+ * @param [in] core_id The core id
1593
+ * @retval BM_SUCCESS Succeeds.
1594
+ * Other code Fails.
1595
+ */
1596
+ DECL_EXPORT bm_status_t bm_handle_sync_from_core(bm_handle_t handle, int core_id);
1597
+
1598
+ /**
1599
+ * @name bm_thread_sync
1600
+ * @brief To synchronize APIs of the current thread. The thread will block
1601
+ * until all the outstanding APIs of the current thread are finished.
1602
+ * @ingroup bmlib_runtime
1603
+ *
1604
+ * @param [in] handle The device handle
1605
+ * @retval BM_SUCCESS Succeeds.
1606
+ * Other code Fails.
1607
+ */
1608
+ DECL_EXPORT bm_status_t bm_thread_sync(bm_handle_t handle);
1609
+
1610
+ /**
1611
+ * @name bm_thread_sync_from_core
1612
+ * @brief To synchronize APIs of the current thread. The thread will block
1613
+ * until all the outstanding APIs of the current thread are finished.
1614
+ * @ingroup bmlib_runtime
1615
+ *
1616
+ * @param [in] handle The device handle
1617
+ * @param [in] core_id The core id
1618
+ * @retval BM_SUCCESS Succeeds.
1619
+ * Other code Fails.
1620
+ */
1621
+ DECL_EXPORT bm_status_t bm_thread_sync_from_core(bm_handle_t handle, int core_id);
1622
+
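+ /*
+  * Usage sketch (illustrative only): after queueing asynchronous work from this
+  * thread (e.g. with okkernel_launch_async declared further below), block until
+  * everything the thread submitted has completed.
+  *
+  *   bm_thread_sync(handle);
+  */
+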
1623
+ /*******************trace and profile related functions **********************/
1624
+ typedef struct bm_profile {
1625
+ #ifdef __linux__
1626
+ unsigned long cdma_in_time;
1627
+ unsigned long cdma_in_counter;
1628
+ unsigned long cdma_out_time;
1629
+ unsigned long cdma_out_counter;
1630
+ unsigned long tpu_process_time;
1631
+ unsigned long tpu1_process_time;
1632
+ unsigned long sent_api_counter;
1633
+ unsigned long completed_api_counter;
1634
+ #else
1635
+ unsigned long long cdma_in_time;
1636
+ unsigned long long cdma_in_counter;
1637
+ unsigned long long cdma_out_time;
1638
+ unsigned long long cdma_out_counter;
1639
+ unsigned long long tpu_process_time;
1640
+ unsigned long long tpu1_process_time;
1641
+ unsigned long long sent_api_counter;
1642
+ unsigned long long completed_api_counter;
1643
+ #endif
1644
+ } bm_profile_t;
1645
+ /**
1646
+ * @name bm_get_profile
1647
+ * @brief To get the profile data at the moment
1648
+ * @ingroup bmlib_runtime
1649
+ *
1650
+ * @param [in] handle The device handle
1651
+ * @param [out] profile The result profile data
1652
+ * @retval BM_SUCCESS Succeeds.
1653
+ * Other code Fails.
1654
+ */
1655
+ DECL_EXPORT bm_status_t bm_get_profile(bm_handle_t handle, bm_profile_t *profile);
1656
+
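+ /*
+  * Usage sketch (illustrative only): snapshot the per-device counters and
+  * report how many submitted APIs are still outstanding.
+  *
+  *   bm_profile_t prof;
+  *   if (bm_get_profile(handle, &prof) == BM_SUCCESS)
+  *       printf("pending APIs: %llu\n",
+  *              (unsigned long long)(prof.sent_api_counter - prof.completed_api_counter));
+  */
+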
1657
+ typedef struct bootloader_version{
1658
+ char *bl1_version;
1659
+ char *bl2_version;
1660
+ char *bl31_version;
1661
+ char *uboot_version;
1662
+ } boot_loader_version;
1663
+
1664
+ /**
1665
+ * @name bm_get_boot_loader_version
1666
+ * @brief To get the boot_loader_version
1667
+ * @ingroup bmlib_runtime
1668
+ *
1669
+ * @param [in] handle The device handle
1670
+ * @param [out] version The result version data
1671
+ * @retval BM_SUCCESS Succeeds.
1672
+ * Other code Fails.
1673
+ */
1674
+ DECL_EXPORT bm_status_t bm_get_boot_loader_version(bm_handle_t handle, boot_loader_version *version);
1675
+
1676
+ /**
1677
+ * @name bm_get_vpu_instant_usage
1678
+ * @brief To get vpu usage
1679
+ * @ingroup bmlib_runtime
1680
+ *
1681
+ * @param [in] handle The device handle
1682
+ * @param [out] vpu_usage The result vpu usage
1683
+ * @retval BM_SUCCESS Succeeds.
1684
+ * Other code Fails.
1685
+ */
1686
+ DECL_EXPORT bm_status_t bm_get_vpu_instant_usage(bm_handle_t handle, int *vpu_usage);
1687
+
1688
+ /**
1689
+ * @name bm_get_jpu_core_usage
1690
+ * @brief To get the jpu usage
1691
+ * @ingroup bmlib_runtime
1692
+ *
1693
+ * @param [in] handle The device handle
1694
+ * @param [out] jpu_usage The result jpu usage
1695
+ * @retval BM_SUCCESS Succeeds.
1696
+ * Other code Fails.
1697
+ */
1698
+ DECL_EXPORT bm_status_t bm_get_jpu_core_usage(bm_handle_t handle, int *jpu_usage);
1699
+
1700
+ /**
1701
+ * @name bm_get_vpp_instant_usage
1702
+ * @brief To get the vpp usage
1703
+ * @ingroup bmlib_runtime
1704
+ *
1705
+ * @param [in] handle The device handle
1706
+ * @param [out] vpp_usage The result vpp usage
1707
+ * @retval BM_SUCCESS Succeeds.
1708
+ * Other code Fails.
1709
+ */
1710
+ DECL_EXPORT bm_status_t bm_get_vpp_instant_usage(bm_handle_t handle, int *vpp_usage);
1711
+ /**
1712
+ * @name bm_get_last_api_process_time_us
1713
+ * @brief This function is abandoned.
1714
+ */
1715
+ #ifdef __linux__
1716
+ DECL_EXPORT bm_status_t bm_get_last_api_process_time_us(bm_handle_t handle,
1717
+ unsigned long *time_us);
1718
+ #else
1719
+ DECL_EXPORT bm_status_t bm_get_last_api_process_time_us(bm_handle_t handle,
1720
+ unsigned long long *time_us);
1721
+ #endif
1722
+ /*******************tpu clock and module reset releated functions *************/
1723
+
1724
+ /**
1725
+ * @name bm_set_clk_tpu_freq
1726
+ * @brief To set the clock frequency of TPU (only valid in PCIE mode).
1727
+ * @ingroup bmlib_runtime
1728
+ *
1729
+ * @param [in] handle The device handle
1730
+ * @param [in] freq The TPU target frequency
1731
+ * @retval BM_SUCCESS Succeeds.
1732
+ * Other code Fails.
1733
+ */
1734
+ DECL_EXPORT bm_status_t bm_set_clk_tpu_freq(bm_handle_t handle, int freq);
1735
+
1736
+ /**
1737
+ * @name bm_get_clk_tpu_freq
1738
+ * @brief To get the clock frequency of TPU
1739
+ * @ingroup bmlib_runtime
1740
+ *
1741
+ * @param [in] handle The device handle
1742
+ * @param [out] freq The current TPU frequency
1743
+ * @retval BM_SUCCESS Succeeds.
1744
+ * Other code Fails.
1745
+ */
1746
+ DECL_EXPORT bm_status_t bm_get_clk_tpu_freq(bm_handle_t handle, int *freq);
1747
+
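+ /*
+  * Usage sketch (illustrative only): read the current TPU clock and, in PCIE
+  * mode, request a new one. The value 550 is only an example figure.
+  *
+  *   int freq = 0;
+  *   if (bm_get_clk_tpu_freq(handle, &freq) == BM_SUCCESS)
+  *       bm_set_clk_tpu_freq(handle, 550);
+  */
+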
1748
+ /*******************misc functions ********************************************/
1749
+ struct bm_misc_info {
1750
+ int pcie_soc_mode; /*0---pcie; 1---soc*/
1751
+ int ddr_ecc_enable; /*0---disable; 1---enable*/
1752
+ long long ddr0a_size;
1753
+ long long ddr0b_size;
1754
+ long long ddr1_size;
1755
+ long long ddr2_size;
1756
+ unsigned int chipid;
1757
+ #define BM1682_CHIPID_BIT_MASK (0X1 << 0)
1758
+ #define BM1684_CHIPID_BIT_MASK (0X1 << 1)
1759
+ #define BM1686_CHIPID_BIT_MASK (0X1 << 2)
1760
+ #ifdef __linux__
1761
+ unsigned long chipid_bit_mask;
1762
+ #else
1763
+ unsigned long long chipid_bit_mask;
1764
+ #endif
1765
+ unsigned int driver_version;
1766
+ int domain_bdf;
1767
+ int board_version; /*hardware board version [23:16]-mcu sw version, [15:8]-board type, [7:0]-hw version*/
1768
+ int a53_enable;
1769
+ int dyn_enable;
1770
+ };
1771
+
1772
+ /**
1773
+ * @name bm_get_misc_info
1774
+ * @brief To get miscellaneous information of the device
1775
+ * @ingroup bmlib_runtime
1776
+ *
1777
+ * @param [in] handle The device handle
1778
+ * @param [out] pmisc_info The fetched misc info
1779
+ * @retval BM_SUCCESS Succeeds.
1780
+ * Other code Fails.
1781
+ */
1782
+ DECL_EXPORT bm_status_t bm_get_misc_info(bm_handle_t handle, struct bm_misc_info *pmisc_info);
1783
+
1784
+ /**
1785
+ * @name bm_get_chipid
1786
+ * @brief To get the chipid of the device. (0x1682 / 0x1684 / 0x168?)
1787
+ * @ingroup bmlib_runtime
1788
+ *
1789
+ * @param [in] handle The device handle
1790
+ * @param [out] p_chipid The chip id of the device
1791
+ * @retval BM_SUCCESS Succeeds.
1792
+ * Other code Fails.
1793
+ */
1794
+ DECL_EXPORT bm_status_t bm_get_chipid(bm_handle_t handle, unsigned int *p_chipid);
1795
+
1796
+ #define BMLIB_LOG_QUIET -8
1797
+ #define BMLIB_LOG_PANIC 0
1798
+ #define BMLIB_LOG_FATAL 8
1799
+ #define BMLIB_LOG_ERROR 16
1800
+ #define BMLIB_LOG_WARNING 24
1801
+ #define BMLIB_LOG_INFO 32
1802
+ #define BMLIB_LOG_VERBOSE 40
1803
+ #define BMLIB_LOG_DEBUG 48
1804
+ #define BMLIB_LOG_TRACE 56
1805
+
1806
+ /**
1807
+ * @name bmlib_log_get_level
1808
+ * @brief To get the bmlib log level
1809
+ * @ingroup bmlib_log
1810
+ *
1811
+ * @param void
1812
+ * @retval The level of bmlib log level
1813
+ */
1814
+ DECL_EXPORT int bmlib_log_get_level(void);
1815
+
1816
+ /**
1817
+ * @name bmlib_log_set_level
1818
+ * @brief To set the bmlib log level
1819
+ * @ingroup bmlib_log
1820
+ *
1821
+ * @param [in] level The level of bmlib log level
1822
+ * @retval void
1823
+ */
1824
+ DECL_EXPORT void bmlib_log_set_level(int level);
1825
+
1826
+ /**
1827
+ * @name bmlib_log_set_callback
1828
+ * @brief To set callback to get bmlib log
1829
+ * @ingroup bmlib_log
1830
+ *
1831
+ * @param [in] callback The callback function to get bmlib log
1832
+ * @retval void
1833
+ */
1834
+ DECL_EXPORT void bmlib_log_set_callback(void (*callback)(const char*, int, const char*, va_list args));
1835
+
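+ /*
+  * Usage sketch (illustrative only): raise the bmlib log verbosity and route
+  * messages to a custom sink. my_log_sink is a hypothetical callback whose
+  * parameters follow the signature in the declaration above; their exact
+  * semantics are not asserted here.
+  *
+  *   void my_log_sink(const char *tag, int level, const char *fmt, va_list args) {
+  *       vfprintf(stderr, fmt, args);
+  *   }
+  *   ...
+  *   bmlib_log_set_level(BMLIB_LOG_DEBUG);
+  *   bmlib_log_set_callback(my_log_sink);
+  */
+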
1836
+ /**
1837
+ * @name bm_set_debug_mode
1838
+ * @brief To set the debug mode for firmware log for tpu
1839
+ * @ingroup bmlib_log
1840
+ *
1841
+ * @param [in] handle The device handle
1842
+ * @param [in] mode The debug mode of fw log, 0/1 for disable/enable log
1843
+ * @retval void
1844
+ */
1845
+ DECL_EXPORT void bm_set_debug_mode(bm_handle_t handle, int mode);
1846
+
1847
+ /**
1848
+ * @name bmlib_api_dbg_callback
1849
+ * @brief To set debug callback to get firmware log
1850
+ * @ingroup bmlib_log
1851
+ *
1852
+ * @param [in] bmlib_api_dbg_callback callback to get firmware log
1853
+ * @retval void
1854
+ */
1855
+ typedef void (*bmlib_api_dbg_callback)(int, int, int, const char*);
1856
+ // api, result, duration, log; the third int is reserved for api duration in the future
1857
+ DECL_EXPORT void bmlib_set_api_dbg_callback(bmlib_api_dbg_callback callback);
1858
+
1859
+ /**
1860
+ * @name bmcpu_get_cpu_status
1861
+ * @brief Get bmcpu status
1862
+ * @ingroup bmlib_log
1863
+ *
1864
+ * @param [in] handle The device handle
1865
+ * @retval BMCPU_RUNNING bmcpu is running.
1866
+ * Other code Fails.
1867
+ */
1868
+ DECL_EXPORT bm_cpu_status_t bmcpu_get_cpu_status(bm_handle_t handle);
1869
+
1870
+ /**
1871
+ * @name bmcpu_start_cpu
1872
+ * @brief Start cpu in pcie mode
1873
+ * @ingroup bmlib_log
1874
+ *
1875
+ * @param [in] handle The device handle
1876
+ * @param [in] boot_file Fip file
1877
+ * @param [in] core_file Itb file
1878
+ * @retval BM_SUCCESS Succeeds.
1879
+ * Other code Fails.
1880
+ */
1881
+ DECL_EXPORT bm_status_t bmcpu_start_cpu(bm_handle_t handle, char *boot_file, char *core_file);
1882
+
1883
+ /**
1884
+ * @name bmcpu_open_process
1885
+ * @brief Open a process to do some work
1886
+ * @ingroup bmlib_log
1887
+ *
1888
+ * @param [in] handle The device handle
1889
+ * @param [in] flags Process flags
1890
+ * @param [in] timeout Timeout value in millisecond, -1 means default value of this device
1891
+ * @retval >= 0 process handle
1892
+ * < 0 Other code Fails.
1893
+ */
1894
+ DECL_EXPORT int bmcpu_open_process(bm_handle_t handle, unsigned int flags, int timeout);
1895
+
1896
+ /**
1897
+ * @name bmcpu_load_library
1898
+ * @brief Load a shared library (.so) into a specific process
1899
+ * @ingroup bmlib_log
1900
+ *
1901
+ * @param [in] handle The device handle
1902
+ * @param [in] process_handle Process handle
1903
+ * @param [in] library_file Library file path
1904
+ * @param [in] timeout Timeout value in millisecond, -1 means default value of this device
1905
+ * @retval BM_SUCCESS Succeeds.
1906
+ * Other code Fails.
1907
+ */
1908
+ DECL_EXPORT bm_status_t bmcpu_load_library(bm_handle_t handle, int process_handle, char *library_file, int timeout);
1909
+
1910
+ /**
1911
+ * @name bmcpu_unload_library
1912
+ * @brief Unload a shared library (.so) from a specific process
1913
+ * @ingroup bmlib_log
1914
+ *
1915
+ * @param [in] handle The device handle
1916
+ * @param [in] process_handle Process handle
1917
+ * @param [in] library_file Library file path
1918
+ * @param [in] timeout Timeout value in millisecond, -1 means default value of this device
1919
+ * @retval BM_SUCCESS Succeeds.
1920
+ * Other code Fails.
1921
+ */
1922
+ DECL_EXPORT bm_status_t bmcpu_unload_library(bm_handle_t handle, int process_handle, char *library_file, int timeout);
1923
+
1924
+ /**
1925
+ * @name bmcpu_exec_function
1926
+ * @brief Execute specific function in specific process
1927
+ * @ingroup bmlib_log
1928
+ *
1929
+ * @param [in] handle The device handle
1930
+ * @param [in] process_handle Process handle
1931
+ * @param [in] function_name Function name
1932
+ * @param [in] function_param Function parameters
1933
+ * @param [in] param_size Parameters size in bytes
1934
+ * @param [in] timeout Timeout value in millisecond, -1 means default value of this device
1935
+ * @retval 0 success.
1936
+ * >0 code fails from bmlib
1937
+ * <0 code fails from function
1938
+ */
1939
+ DECL_EXPORT int bmcpu_exec_function(bm_handle_t handle,
1940
+ int process_handle,
1941
+ char *function_name,
1942
+ void *function_param,
1943
+ unsigned int param_size,
1944
+ int timeout);
1945
+
1946
+ #define BMCPU_EXEC_OPT_NO_FLUSH_CACHE 1
1947
+ /**
1948
+ * @name bmcpu_exec_function_ext
1949
+ * @brief Execute specific function in specific process
1950
+ * @ingroup bmlib_log
1951
+ *
1952
+ * @param [in] handle The device handle
1953
+ * @param [in] process_handle Process handle
1954
+ * @param [in] function_name Function name
1955
+ * @param [in] function_param Function parameters
1956
+ * @param [in] param_size Parameters size in bytes
1957
+ * @param [in] opt exec options
1958
+ * @param [in] timeout Timeout value in millisecond, -1 means default value of this device
1959
+ * @retval 0 success.
1960
+ * >0 code fails from bmlib
1961
+ * <0 code fails from function
1962
+ */
1963
+ DECL_EXPORT int bmcpu_exec_function_ext(bm_handle_t handle,
1964
+ int process_handle,
1965
+ char *function_name,
1966
+ void *function_param,
1967
+ unsigned int param_size,
1968
+ unsigned int opt,
1969
+ int timeout);
1970
+
1971
+ /**
1972
+ * @name bmcpu_exec_function_async
1973
+ * @brief Execute a specific function in a specific process asynchronously;
1974
+ * the user should use bmcpu_query_exec_function_result to query the result
1975
+ * @ingroup bmlib_log
1976
+ *
1977
+ * @param [in] handle The device handle
1978
+ * @param [in] process_handle Process handle
1979
+ * @param [in] function_name Function name
1980
+ * @param [in] function_param Function param
1981
+ * @param [in] param_size Param size in bytes
1982
+ * @retval BM_SUCCESS Succeeds.
1983
+ * Other code Fails.
1984
+ */
1985
+ DECL_EXPORT bm_status_t bmcpu_exec_function_async(bm_handle_t handle,
1986
+ int process_handle,
1987
+ char *function_name,
1988
+ void *function_param,
1989
+ unsigned int param_size,
1990
+ unsigned long long *api_handle);
1991
+
1992
+ /**
1993
+ * @name bmcpu_exec_function_async_ext
1994
+ * @brief Execute a specific function in a specific process asynchronously;
1995
+ * the user should use bmcpu_query_exec_function_result to query the result
1996
+ * @ingroup bmlib_log
1997
+ *
1998
+ * @param [in] handle The device handle
1999
+ * @param [in] process_handle Process handle
2000
+ * @param [in] function_name Function name
2001
+ * @param [in] function_param Function param
2002
+ * @param [in] param_size Param size in bytes
2003
+ * @param [in] opt exec options
2004
+ * @retval BM_SUCCESS Succeeds.
2005
+ * Other code Fails.
2006
+ */
2007
+ DECL_EXPORT bm_status_t bmcpu_exec_function_async_ext(bm_handle_t handle,
2008
+ int process_handle,
2009
+ char *function_name,
2010
+ void *function_param,
2011
+ unsigned int param_size,
2012
+ unsigned int opt,
2013
+ unsigned long long *api_handle);
2014
+
2015
+ /**
2016
+ * @name bmcpu_query_exec_function_result
2017
+ * @brief Query the result of a function launched by bmcpu_exec_function_async
2018
+ * @ingroup bmlib_log
2019
+ *
2020
+ * @param [in] handle The device handle
2021
+ * @param [in] api_handle The api handle returned by bmcpu_exec_function_async
2022
+ * @param [in] timeout Timeout value in millisecond, -1 means default value of this device
2023
+ * @retval 0 success.
2024
+ * >0 code fails from bmlib
2025
+ * <0 code fails from function
2026
+ */
2027
+ DECL_EXPORT int bmcpu_query_exec_function_result(bm_handle_t handle, unsigned long long api_handle, int timeout);
2028
+
2029
+ /**
2030
+ * @name bmcpu_map_phys_addr
2031
+ * @brief Map physical address in specific process
2032
+ * @ingroup bmlib_log
2033
+ *
2034
+ * @param [in] handle The device handle
2035
+ * @param [in] process_handle Process handle
2036
+ * @param [in] phys_addr Physical address
2037
+ * @param [in] size Map size in bytes
2038
+ * @param [in] timeout Timeout value in millisecond, -1 means default value of this device
2039
+ * @retval >0 virtual address
2040
+ * 0 fails
2041
+ */
2042
+ DECL_EXPORT void *bmcpu_map_phys_addr(bm_handle_t handle, int process_handle, void *phys_addr, unsigned int size, int timeout);
2043
+
2044
+ /**
2045
+ * @name bmcpu_unmap_phys_addr
2046
+ * @brief Unmap physical address in specific process
2047
+ * @ingroup bmlib_log
2048
+ *
2049
+ * @param [in] handle The device handle
2050
+ * @param [in] process_handle Process handle
2051
+ * @param [in] phys_addr Physical address
2052
+ * @param [in] timeout Timeout value in millisecond, -1 means default value of this device
2053
+ * @retval <0 fail
2054
+ * 0 success
2055
+ */
2056
+ DECL_EXPORT bm_status_t bmcpu_unmap_phys_addr(bm_handle_t handle, int process_handle, void *phys_addr, int timeout);
2057
+
2058
+ /**
2059
+ * @name bmcpu_close_process
2060
+ * @brief Close process
2061
+ * @ingroup bmlib_log
2062
+ *
2063
+ * @param [in] handle The device handle
2064
+ * @param [in] process_handle Process handle
2065
+ * @param [in] timeout Timeout value in millisecond, -1 means default value of this device
2066
+ * @retval BM_SUCCESS Succeeds.
2067
+ * Other code Fails.
2068
+ */
2069
+ DECL_EXPORT bm_status_t bmcpu_close_process(bm_handle_t handle, int process_handle, int timeout);
2070
+
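+ /*
+  * Usage sketch (illustrative only): a typical bmcpu flow in PCIE mode. The
+  * library path, the function name "my_func" and its `param` struct are all
+  * hypothetical; -1 selects the device's default timeout.
+  *
+  *   int ph = bmcpu_open_process(handle, 0, -1);
+  *   if (ph >= 0) {
+  *       bmcpu_load_library(handle, ph, (char *)"libmy_ops.so", -1);
+  *       int ret = bmcpu_exec_function(handle, ph, (char *)"my_func",
+  *                                     &param, sizeof(param), -1);
+  *       bmcpu_close_process(handle, ph, -1);
+  *   }
+  */
+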
2071
+ /**
2072
+ * @name bmcpu_reset_cpu
2073
+ * @brief Reset cpu in pcie mode
2074
+ * @ingroup bmlib_log
2075
+ *
2076
+ * @param [in] handle The device handle
2077
+ * @retval BM_SUCCESS Succeeds.
2078
+ * Other code Fails.
2079
+ */
2080
+ DECL_EXPORT bm_status_t bmcpu_reset_cpu(bm_handle_t handle);
2081
+
2082
+ /**
2083
+ * @name bm_enable_perf_monitor
2084
+ * @brief enable perf monitor to get gdma and tpu performance data
2085
+ * @ingroup bmlib_perf
2086
+ *
2087
+ * @param [in] handle The device handle
2088
+ * @param [in] perf_monitor The monitor to perf
2089
+ * @retval BM_SUCCESS Succeeds.
2090
+ * Other code Fails.
2091
+ */
2092
+ DECL_EXPORT bm_status_t bm_enable_perf_monitor(bm_handle_t handle, bm_perf_monitor_t *perf_monitor);
2093
+
2094
+ /**
2095
+ * @name bm_disable_perf_monitor
2096
+ * @brief disable perf monitor to get gdma and tpu performance data
2097
+ * @ingroup bmlib_perf
2098
+ *
2099
+ * @param [in] handle The device handle
2100
+ * @param [in] perf_monitor The monitor to perf
2101
+ * @retval BM_SUCCESS Succeeds.
2102
+ * Other code Fails.
2103
+ */
2104
+ DECL_EXPORT bm_status_t bm_disable_perf_monitor(bm_handle_t handle, bm_perf_monitor_t *perf_monitor);
2105
+
2106
+ /**
2107
+ * @name bmcpu_set_log
2108
+ * @brief Set cpu log options
2109
+ * @ingroup bmlib_log
2110
+ *
2111
+ * @param [in] handle The device handle
2112
+ * @param [in] log_level 0: DEBUG 1:INFO 2:WARN 3:ERROR 4:FATAL
2113
+ * @param [in] log_to_console 1: YES 0: No
2114
+ * @param [in] timeout Timeout value in millisecond, -1 means default value of this device
2115
+ * @retval BM_SUCCESS Succeeds.
2116
+ * Other code Fails.
2117
+ */
2118
+ DECL_EXPORT bm_status_t bmcpu_set_log(bm_handle_t handle, unsigned int log_level, unsigned int log_to_console, int timeout);
2119
+
2120
+ /**
2121
+ * @name bmcpu_get_log
2122
+ * @brief Get cpu log file
2123
+ * @ingroup bmlib_log
2124
+ *
2125
+ * @param [in] handle The device handle
2126
+ * @param [in] process_handle Process handle
2127
+ * @param [in] log_file save log as file
2128
+ * @param [in] timeout Timeout value in millisecond, -1 means default value of this device
2129
+ * @retval BM_SUCCESS Succeeds.
2130
+ * Other code Fails.
2131
+ */
2132
+ DECL_EXPORT bm_status_t bmcpu_get_log(bm_handle_t handle, int process_handle, char *log_file, int timeout);
2133
+
2134
+ /**
2135
+ * @name bmcpu_sync_time
2136
+ * @brief Sync device cpu time with host
2137
+ * @ingroup bmlib_log
2138
+ *
2139
+ * @param [in] handle The device handle
2140
+ * @retval BM_SUCCESS Succeeds.
2141
+ * Other code Fails.
2142
+ */
2143
+ DECL_EXPORT bm_status_t bmcpu_sync_time(bm_handle_t handle);
2144
+
2145
+ /*******************trace and profile related functions **********************/
2146
+ struct bm_heap_stat {
2147
+ unsigned int mem_total;
2148
+ unsigned int mem_avail;
2149
+ unsigned int mem_used;
2150
+ };
2151
+
2152
+ typedef struct bm_heap_stat_byte {
2153
+ unsigned int heap_id;
2154
+ unsigned long long mem_total;
2155
+ unsigned long long mem_avail;
2156
+ unsigned long long mem_used;
2157
+ unsigned long long mem_start_addr;
2158
+ } bm_heap_stat_byte_t;
2159
+
2160
+ typedef struct bm_dev_stat {
2161
+ int mem_total;
2162
+ int mem_used;
2163
+ int tpu_util;
2164
+ int heap_num;
2165
+ struct bm_heap_stat heap_stat[4];
2166
+ } bm_dev_stat_t;
2167
+
2168
+ /**
2169
+ * @name bm_get_stat
2170
+ * @brief To get the stat data at the moment
2171
+ * @ingroup bmlib_runtime
2172
+ *
2173
+ * @param [in] handle The device handle
2174
+ * @param [out] stat The result stat data
2175
+ * @retval BM_SUCCESS Succeeds.
2176
+ * Other code Fails.
2177
+ */
2178
+ DECL_EXPORT bm_status_t bm_get_stat(bm_handle_t handle, bm_dev_stat_t *stat);
2179
+
2180
+ /**
2181
+ * @name bm_get_gmem_heap_id
2182
+ * @brief To get the heap id of allocated global memory
2183
+ * @ingroup bmlib_runtime
2184
+ *
2185
+ * @param [in] handle The device handle
2186
+ * @param [in] pmem The allocated global memory
2187
+ * @param [out] heapid The returned heap id
2188
+ * @retval BM_SUCCESS Succeeds.
2189
+ * Other code Fails.
2190
+ */
2191
+
2192
+ DECL_EXPORT bm_status_t bm_get_gmem_heap_id(bm_handle_t handle, bm_device_mem_t *pmem, unsigned int *heapid);
2193
+
2194
+ /**
2195
+ * @name sg_get_gmem_heap_id
2196
+ * @brief To get the heap id of allocated global memory
2197
+ * @ingroup bmlib_runtime
2198
+ *
2199
+ * @param [in] handle The device handle
2200
+ * @param [in] pmem The allocated global memory
2201
+ * @param [out] heapid The returned heap id
2202
+ * @retval BM_SUCCESS Succeeds.
2203
+ * Other code Fails.
2204
+ */
2205
+
2206
+ DECL_EXPORT bm_status_t sg_get_gmem_heap_id(bm_handle_t handle, sg_device_mem_t *pmem, unsigned int *heapid);
2207
+
2208
+ /**
2209
+ * @name bm_get_gmem_total_heap_num
2210
+ * @brief To get the total heap num of global memory
2211
+ * @ingroup bmlib_runtime
2212
+ *
2213
+ * @param [in] handle The device handle
2214
+ * @param [out] heap_num The returned total heap count
2215
+ * @retval BM_SUCCESS Succeeds.
2216
+ * Other code Fails.
2217
+ */
2218
+ DECL_EXPORT bm_status_t bm_get_gmem_total_heap_num(bm_handle_t handle, unsigned int *heap_num);
2219
+
2220
+ /**
2221
+ * @name bm_get_gmem_heap_stat_byte_by_id
2222
+ * @brief To get the heap stat by heap id
2223
+ * @ingroup bmlib_runtime
2224
+ *
2225
+ * @param [in] handle The device handle
2226
+ * @param [in] heap_id The heap index to get heap status
2227
+ * @param [out] pheap_byte The returned heap status
2228
+ * @retval BM_SUCCESS Succeeds.
2229
+ * Other code Fails.
2230
+ */
2231
+ DECL_EXPORT bm_status_t bm_get_gmem_heap_stat_byte_by_id(bm_handle_t handle, bm_heap_stat_byte_t *pheap_byte, unsigned int heap_id);
2232
+
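+ /*
+  * Usage sketch (illustrative only): walk every global-memory heap and print
+  * its usage in bytes.
+  *
+  *   unsigned int heap_num = 0;
+  *   bm_get_gmem_total_heap_num(handle, &heap_num);
+  *   for (unsigned int i = 0; i < heap_num; ++i) {
+  *       bm_heap_stat_byte_t st;
+  *       if (bm_get_gmem_heap_stat_byte_by_id(handle, &st, i) == BM_SUCCESS)
+  *           printf("heap %u: %llu / %llu bytes used\n", i, st.mem_used, st.mem_total);
+  *   }
+  */
+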
2233
+ DECL_EXPORT bm_status_t bm_load_firmware(
2234
+ bm_handle_t handle,
2235
+ const char *firmware_tcm,
2236
+ const char *firmware_ddr);
2237
+
2238
+ #define bmkernel_load_firmware okkernel_load_firmware
2239
+ DECL_EXPORT bm_status_t okkernel_load_firmware(
2240
+ bm_handle_t handle,
2241
+ const char *firmware_tcm,
2242
+ const char *firmware_ddr);
2243
+
2244
+ DECL_EXPORT bm_status_t okkernel_launch_async(
2245
+ bm_handle_t handle,
2246
+ const char *func_name,
2247
+ const void *args,
2248
+ unsigned int size);
2249
+
2250
+ DECL_EXPORT bm_status_t okkernel_launch_sync(
2251
+ bm_handle_t handle,
2252
+ const char *func_name,
2253
+ const void *args,
2254
+ unsigned int size);
2255
+
2256
+ DECL_EXPORT bm_status_t tpu_kernel_launch_sync(
2257
+ bm_handle_t handle,
2258
+ const char *func_name,
2259
+ const void *args,
2260
+ unsigned int size);
2261
+
2262
+ DECL_EXPORT bm_status_t okkernel_sync(bm_handle_t handle);
2263
+
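+ /*
+  * Usage sketch (illustrative only): invoke a device-side function by name and
+  * wait for completion. The kernel name and the argument layout are hypothetical
+  * and must match whatever the loaded firmware actually exports.
+  *
+  *   struct { unsigned long long in, out; int len; } args = { in_addr, out_addr, n };
+  *   tpu_kernel_launch_sync(handle, "my_kernel", &args, sizeof(args));
+  */
+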
2264
+ /**
2265
+ * @name bmkernel_launch
2266
+ * @brief send api to device and launch function
2267
+ * @ingroup bmlib_runtime
2268
+ *
2269
+ * @param [in] handle The device handle
2270
+ * @param [in] api cmd struct pointer
2271
+ * @param [in] api cmd length
2272
+ * @retval BM_SUCCESS Succeeds.
2273
+ * Other code Fails.
2274
+ */
2275
+ DECL_EXPORT bm_status_t bmkernel_launch(bm_handle_t handle, const void *args,
2276
+ unsigned int size);
2277
+
2278
+ /**
2279
+ * @name bmkernel_load_lookup_table
2280
+ * @brief load lookup table to l2-sram
2281
+ * @ingroup bmlib_runtime
2282
+ *
2283
+ * @param [in] handle The device handle
2284
+ * @param [in] table The lookup table to load into l2-sram
2285
+ * @param [in] size The table size in bytes
2286
+ * @retval BM_SUCCESS Succeeds.
2287
+ * Other code Fails.
2288
+ */
2289
+ DECL_EXPORT bm_status_t bmkernel_load_lookup_table(bm_handle_t handle, const void* table, unsigned int size);
2290
+
2291
+ /*******************device management api functions ********************************************/
2292
+ /**
2293
+ * @name bm_get_tpu_current
2294
+ * @brief get tpu current
2295
+ * @ingroup bmlib_runtime
2296
+ *
2297
+ * @param [in] handle The device handle
2298
+ * @param [out] tpuc(mA) The pointer for tpu current
2299
+ * @retval BM_SUCCESS Succeeds.
2300
+ * Other code Fails.
2301
+ */
2302
+ DECL_EXPORT bm_status_t bm_get_tpu_current(bm_handle_t handle, unsigned int *tpuc);
2303
+
2304
+ /**
2305
+ * @name bm_get_board_max_power
2306
+ * @brief get board support max power
2307
+ * @ingroup bmlib_runtime
2308
+ *
2309
+ * @param [in] handle The device handle
2310
+ * @param [out] maxp The pointer for maxp
2311
+ * @retval BM_SUCCESS Succeeds.
2312
+ * Other code Fails.
2313
+ */
2314
+ DECL_EXPORT bm_status_t bm_get_board_max_power(bm_handle_t handle, unsigned int *maxp);
2315
+
2316
+ /**
2317
+ * @name bm_get_board_power
2318
+ * @brief get board power
2319
+ * @ingroup bmlib_runtime
2320
+ *
2321
+ * @param [in] handle The device handle
2322
+ * @param [out] boardp The pointer for boardp
2323
+ * @retval BM_SUCCESS Succeeds.
2324
+ * Other code Fails.
2325
+ */
2326
+ DECL_EXPORT bm_status_t bm_get_board_power(bm_handle_t handle, unsigned int *boardp);
2327
+
2328
+ /**
2329
+ * @name bm_get_fan_speed
2330
+ * @brief get board fan speed
2331
+ * @ingroup bmlib_runtime
2332
+ *
2333
+ * @param [in] handle The device handle
2334
+ * @param [out] fan The pointer for fan speed
2335
+ * @retval BM_SUCCESS Succeeds.
2336
+ * Other code Fails.
2337
+ */
2338
+ DECL_EXPORT bm_status_t bm_get_fan_speed(bm_handle_t handle, unsigned int *fan);
2339
+
2340
+ /**
2341
+ * @name bm_get_ecc_correct_num
2342
+ * @brief get ecc_correct_num
2343
+ * @ingroup device management api
2344
+ *
2345
+ * @param [in] handle The device handle
2346
+ * @param [out] ecc_correct_num
2347
+ * @retval BM_SUCCESS Succeeds.
2348
+ * Other code Fails.
2349
+ */
2350
+ #ifdef __linux__
2351
+ DECL_EXPORT bm_status_t bm_get_ecc_correct_num(bm_handle_t handle, unsigned long *ecc_correct_num);
2352
+ #else
2353
+ DECL_EXPORT bm_status_t bm_get_ecc_correct_num(bm_handle_t handle, unsigned long long *ecc_correct_num);
2354
+ #endif
2355
+ /**
2356
+ * @name bm_get_12v_atx
2357
+ * @brief get atx_12v
2358
+ * @ingroup device management api
2359
+ *
2360
+ * @param [in] handle The device handle
2361
+ * @param [out] atx_12v
2362
+ * @retval BM_SUCCESS Succeeds.
2363
+ * Other code Fails.
2364
+ */
2365
+ DECL_EXPORT bm_status_t bm_get_12v_atx(bm_handle_t handle, int *atx_12v);
2366
+
2367
+ /**
2368
+ * @name bm_get_product_sn
2369
+ * @brief get SE5 sn
2370
+ * @ingroup device management api
2371
+ *
2372
+ * @param [out] product_sn
2373
+ * @retval BM_SUCCESS Succeeds.
2374
+ * Other code Fails.
2375
+ */
2376
+ DECL_EXPORT bm_status_t bm_get_product_sn(char *product_sn);
2377
+
2378
+ /**
2379
+ * @name bm_get_sn
2380
+ * @brief get sn
2381
+ * @ingroup device management api
2382
+ *
2383
+ * @param [in] handle The device handle
2384
+ * @param [out] sn
2385
+ * @retval BM_SUCCESS Succeeds.
2386
+ * Other code Fails.
2387
+ */
2388
+ DECL_EXPORT bm_status_t bm_get_sn(bm_handle_t handle, char *sn);
2389
+
2390
+ /**
2391
+ * @name bm_get_status
2392
+ * @brief get chip status
2393
+ * @ingroup device management api
2394
+ *
2395
+ * @param [in] handle The device handle
2396
+ * @param [out] status The board error status, each bit represents an error state
2397
+ *                        status == 0x0, board is normal; status > 0, board is abnormal;
2398
+ *                        bit0 == 1, tpu is hung
2399
+ * bit1 == 1, pcie link abnormal
2400
+ * bit2 == 1, board temperature is too high
2401
+ * @retval BM_SUCCESS Succeeds.
2402
+ * Other code Fails.
2403
+ */
2404
+ DECL_EXPORT bm_status_t bm_get_status(bm_handle_t handle, int *status);
2405
+
2406
+ /**
2407
+ * @name bm_get_tpu_maxclk
2408
+ * @brief get tpu_maxclk
2409
+ * @ingroup device management api
2410
+ *
2411
+ * @param [in] handle The device handle
2412
+ * @param [out] tpu_maxclk
2413
+ * @retval BM_SUCCESS Succeeds.
2414
+ * Other code Fails.
2415
+ */
2416
+ DECL_EXPORT bm_status_t bm_get_tpu_maxclk(bm_handle_t handle, unsigned int *tpu_maxclk);
2417
+
2418
+ /**
2419
+ * @name bm_get_tpu_minclk
2420
+ * @brief get tpu_minclk
2421
+ * @ingroup device management api
2422
+ *
2423
+ * @param [in] handle The device handle
2424
+ * @param [out] tpu_minclk
2425
+ * @retval BM_SUCCESS Succeeds.
2426
+ * Other code Fails.
2427
+ */
2428
+ DECL_EXPORT bm_status_t bm_get_tpu_minclk(bm_handle_t handle, unsigned int *tpu_minclk);
2429
+
2430
+ /**
2431
+ * @name bm_get_driver_version
2432
+ * @brief get driver version
2433
+ * @ingroup device management api
2434
+ *
2435
+ * @param [in] handle The device handle
2436
+ * @param [out] driver_version
2437
+ * @retval BM_SUCCESS Succeeds.
2438
+ * Other code Fails.
2439
+ */
2440
+ DECL_EXPORT bm_status_t bm_get_driver_version(bm_handle_t handle, int *driver_version);
2441
+
2442
+ /**
2443
+ * @name bm_get_board_name
2444
+ * @brief get device board name
2445
+ * @ingroup device management api
2446
+ *
2447
+ * @param [in] handle The device handle
2448
+ * @param [out] board_name
2449
+ * @retval BM_SUCCESS Succeeds.
2450
+ * Other code Fails.
2451
+ */
2452
+ DECL_EXPORT bm_status_t bm_get_board_name(bm_handle_t handle, char *name);
2453
+
2454
+ /**
2455
+ * @name bm_get_board_temp
2456
+ * @brief get board temperature
2457
+ * @ingroup device management api
2458
+ *
2459
+ * @param [in] handle The device handle
2460
+ * @param [out] board_temp
2461
+ * @retval BM_SUCCESS Succeeds.
2462
+ * Other code Fails.
2463
+ */
2464
+ DECL_EXPORT bm_status_t bm_get_board_temp(bm_handle_t handle, unsigned int *board_temp);
2465
+
2466
+ /**
2467
+ * @name bm_get_chip_temp
2468
+ * @brief get chip temperature
2469
+ * @ingroup device management api
2470
+ *
2471
+ * @param [in] handle The device handle
2472
+ * @param [out] chip_temp
2473
+ * @retval BM_SUCCESS Succeeds.
2474
+ * Other code Fails.
2475
+ */
2476
+ DECL_EXPORT bm_status_t bm_get_chip_temp(bm_handle_t handle, unsigned int *chip_temp);
2477
+
2478
+ /**
2479
+ * @name bm_get_tpu_power
2480
+ * @brief get TPU power
2481
+ * @ingroup device management api
2482
+ *
2483
+ * @param [in] handle The device handle
2484
+ * @param [out] tpu_power
2485
+ * @retval BM_SUCCESS Succeeds.
2486
+ * Other code Fails.
2487
+ */
2488
+ DECL_EXPORT bm_status_t bm_get_tpu_power(bm_handle_t handle, float *tpu_power);
2489
+
2490
+ /**
2491
+ * @name bm_get_tpu_volt
2492
+ * @brief get TPU voltage
2493
+ * @ingroup device management api
2494
+ *
2495
+ * @param [in] handle The device handle
2496
+ * @param [out] tpu_volt
2497
+ * @retval BM_SUCCESS Succeeds.
2498
+ * Other code Fails.
2499
+ */
2500
+ DECL_EXPORT bm_status_t bm_get_tpu_volt(bm_handle_t handle, unsigned int *tpu_volt);
2501
+
2502
+ /**
2503
+ * @name bm_get_card_id
2504
+ * @brief get card id
2505
+ * @ingroup device management api
2506
+ *
2507
+ * @param [in] handle The device handle
2508
+ * @param [out] card_id
2509
+ * @retval BM_SUCCESS Succeeds.
2510
+ * Other code Fails.
2511
+ */
2512
+ DECL_EXPORT bm_status_t bm_get_card_id(bm_handle_t handle, unsigned int *card_id);
2513
+
2514
+ /**
2515
+ * @name bm_get_card_num
2516
+ * @brief get card number
2517
+ * @ingroup device management api
2518
+ *
2519
+ * @param [in] handle The device handle
2520
+ * @param [out] card_num
2521
+ * @retval BM_SUCCESS Succeeds.
2522
+ * Other code Fails.
2523
+ */
2524
+ DECL_EXPORT bm_status_t bm_get_card_num(unsigned int *card_num);
2525
+
2526
+ /**
2527
+ * @name bm_get_chip_num_from_card
2528
+ * @brief get chip number and start chip id from card
2529
+ * @ingroup device management api
2530
+ *
2531
+ * @param [in] handle The device handle
2532
+ * @param [out] chip_num
2533
+ * @param [out] dev_start_index
2534
+ * @retval BM_SUCCESS Succeeds.
2535
+ * Other code Fails.
2536
+ */
2537
+ DECL_EXPORT bm_status_t bm_get_chip_num_from_card(unsigned int card_id, unsigned int *chip_num, unsigned int *dev_start_index);
2538
+
2539
+ /**
2540
+ * @name bm_get_dynfreq_status
2541
+ * @brief get chip dynamic freq status
2542
+ * @ingroup device management api
2543
+ *
2544
+ * @param [in] handle The device handle
2545
+ * @param [out] dynfreq_status
2546
+ * @retval BM_SUCCESS Succeeds.
2547
+ * Other code Fails.
2548
+ */
2549
+ DECL_EXPORT bm_status_t bm_get_dynfreq_status(bm_handle_t handle, int *dynfreq_status);
2550
+
2551
+ /**
2552
+ * @name bm_change_dynfreq_status
2553
+ * @brief change(enable/disable) chip dynamic freq status
2554
+ * @ingroup device management api
2555
+ *
2556
+ * @param [in] handle The device handle
2557
+ * @param [in] new_status
2558
+ * @retval BM_SUCCESS Succeeds.
2559
+ * Other code Fails.
2560
+ */
2561
+ DECL_EXPORT bm_status_t bm_change_dynfreq_status(bm_handle_t handle, int new_status);
2562
+
2563
+ /**
2564
+ * @name bm_get_tpu_scalar_num
2565
+ * @brief To get the core number of TPU scalar
2566
+ * @ingroup bmlib_runtime
2567
+ *
2568
+ * @param [in] handle The device handle
2569
+ * @param [out] core_num The core number of TPU scalar
2570
+ * @retval BM_SUCCESS Succeeds.
2571
+ * Other code Fails.
2572
+ */
2573
+ DECL_EXPORT bm_status_t bm_get_tpu_scalar_num(bm_handle_t handle, unsigned int *core_num);
2574
+
2575
+ #define bm_get_tpu_core_num bm_get_tpu_scalar_num
2576
+
2577
+ #if defined(__cplusplus)
2578
+ }
2579
+ #endif
2580
+
2581
+ #endif /* BM_RUNTIME_H_ */
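
For readers new to bmlib, the device-management declarations above can be exercised from a short host program. The sketch below is illustrative only: it assumes device 0 is present, relies on bm_dev_request / bm_dev_free declared earlier in this same header, and the unit labels in the printouts are assumptions about the reported values rather than something this excerpt guarantees.

#include "bmlib_runtime.h"
#include <cstdio>

int main() {
  bm_handle_t handle;
  // bm_dev_request is declared earlier in bmlib_runtime.h; device 0 is assumed here.
  if (bm_dev_request(&handle, 0) != BM_SUCCESS) {
    std::printf("failed to open device 0\n");
    return -1;
  }

  unsigned int chip_temp = 0, board_power = 0;
  float tpu_power = 0.f;
  int status = 0;

  // Each query below is declared in the device-management section above.
  if (bm_get_chip_temp(handle, &chip_temp) == BM_SUCCESS)
    std::printf("chip temperature: %u (assumed Celsius)\n", chip_temp);
  if (bm_get_board_power(handle, &board_power) == BM_SUCCESS)
    std::printf("board power: %u (assumed Watt)\n", board_power);
  if (bm_get_tpu_power(handle, &tpu_power) == BM_SUCCESS)
    std::printf("tpu power: %.2f (assumed Watt)\n", tpu_power);
  if (bm_get_status(handle, &status) == BM_SUCCESS)
    std::printf("board status: 0x%x (0 means normal, see bm_get_status above)\n", status);

  bm_dev_free(handle);
  return 0;
}
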
ChatGLM2/support/include/bmruntime_interface.h ADDED
@@ -0,0 +1,404 @@
1
+ /*****************************************************************************
2
+ *
3
+ * Copyright (c) 2016-2026 by Sophgo Technologies Inc. All rights reserved.
4
+ *
5
+ * The material in this file is confidential and contains trade secrets
6
+ * of Sophgo Technologies Inc. This is proprietary information owned by
7
+ * Sophgo Technologies Inc. No part of this work may be disclosed,
8
+ * reproduced, copied, transmitted, or used in any way for any purpose,
9
+ * without the express written permission of Sophgo Technologies Inc.
10
+ *
11
+ *****************************************************************************/
12
+
13
+ /*****************************************************************************
14
+ * BMRuntime Interface is mainly for inference.
15
+ * Also we can use it for device computation from BMLang programming.
16
+ * Note: please use interface from bmlib_runtime.h for device memory operation.
17
+ ****************************************************************************/
18
+
19
+ #ifndef BMRUNTIME_INTERFACE_H_
20
+ #define BMRUNTIME_INTERFACE_H_
21
+
22
+ #include "bmdef.h"
23
+
24
+ #ifdef _WIN32
25
+ #define DECL_EXPORT _declspec(dllexport)
26
+ #define DECL_IMPORT _declspec(dllimport)
27
+ #else
28
+ #define DECL_EXPORT
29
+ #define DECL_IMPORT
30
+ #endif
31
+
32
+ #if defined(__cplusplus)
33
+ extern "C" {
34
+ #endif
35
+
36
+ /* --------------------------------------------------------------------------*/
37
+ /* interface for basic data type */
38
+
39
+ /* get data type byte size */
40
+ DECL_EXPORT size_t bmrt_data_type_size(bm_data_type_t dtype);
41
+
42
+ /*
43
+ dims array to bm_shape_t,
44
+ shape and dims should not be NULL, num_dims should not be larger than BM_MAX_DIMS_NUM */
45
+ DECL_EXPORT void bmrt_shape(bm_shape_t* shape, const int* dims, int num_dims);
46
+
47
+ /*
48
+ number of shape elements, shape should not be NULL and num_dims should not be larger than
49
+ BM_MAX_DIMS_NUM */
50
+ DECL_EXPORT uint64_t bmrt_shape_count(const bm_shape_t* shape);
51
+
52
+ /* compare whether two shapes are the same */
53
+ DECL_EXPORT bool bmrt_shape_is_same(const bm_shape_t* left, const bm_shape_t* right);
54
+
55
+ /*
56
+ fill a tensor with data type and shape, and st_mode = 0 as default.
57
+ tensor and p_bmrt should not be NULL, shape count should not be 0.
58
+ it will alloc device mem to tensor->device_mem, so user should bmrt_free_device(p_bmrt,
59
+ tensor->device_mem) to free it.*/
60
+ DECL_EXPORT bool bmrt_tensor(bm_tensor_t* tensor, void* p_bmrt, bm_data_type_t dtype, bm_shape_t shape);
61
+
62
+ /*
63
+ fill a tensor with data type and shape, and st_mode = 0 as default.
64
+ tensor and p_bmrt should not be NULL, shape count should not be 0.
65
+ it will alloc device mem to tensor->device_mem on devid-th device.*/
66
+ DECL_EXPORT bool bmrt_tensor_ex(bm_tensor_t* tensor, void* p_bmrt, int devid, bm_data_type_t dtype, bm_shape_t shape);
67
+
68
+ /* fill a tensor with existing device mem; tensor byte size should not be larger than device mem size */
69
+ DECL_EXPORT void bmrt_tensor_with_device(bm_tensor_t* tensor, bm_device_mem_t device_mem,
70
+ bm_data_type_t dtype, bm_shape_t shape);
71
+
72
+ /* get tensor bytes size, tensor should not be NULL */
73
+ DECL_EXPORT size_t bmrt_tensor_bytesize(const bm_tensor_t* tensor);
74
+
75
+ /* get tensor mem size allocated in device mem, tensor should not be NULL */
76
+ DECL_EXPORT size_t bmrt_tensor_device_size(const bm_tensor_t* tensor);
77
+
78
+ /* print net info for debug */
79
+ DECL_EXPORT void bmrt_print_network_info(const bm_net_info_t* net_info);
80
+
81
+ /* --------------------------------------------------------------------------*/
82
+ /**
83
+ * @name bmrt_create
84
+ * @brief To create the bmruntime with bm_handle.
85
+ * @ingroup bmruntime
86
+ *
87
+ * This API creates the bmruntime. It returns a void* pointer which is the pointer
88
+ * of bmruntime. Device id is set when getting bm_handle.
89
+ *
90
+ * @param [in] bm_handle bm handle. It must be initialized by using bmlib.
91
+ *
92
+ * @retval void* the pointer of bmruntime
93
+ */
94
+ DECL_EXPORT void* bmrt_create(bm_handle_t bm_handle);
95
+
96
+ /* --------------------------------------------------------------------------*/
97
+ /**
98
+ * @name bmrt_create_ex
99
+ * @brief To create the bmruntime with one or more bm_handle.
100
+ * @ingroup bmruntime
101
+ *
102
+ * This API creates the bmruntime. It returns a void* pointer which is the pointer
103
+ * of bmruntime.
104
+ *
105
+ * @param [in] bm_handles bm handles. They must be initialized by using bmlib.
106
+ * @param [in] num_handles number of bm_handles.
107
+ *
108
+ * @retval void* the pointer of bmruntime
109
+ */
110
+ DECL_EXPORT void *bmrt_create_ex(bm_handle_t *bm_handles, int num_handles);
111
+
112
+ /**
113
+ * @name bmrt_destroy
114
+ * @brief To destroy the bmruntime pointer
115
+ * @ingroup bmruntime
116
+ *
117
+ * This API destroys the bmruntime.
118
+ *
119
+ * @param [in] p_bmrt Bmruntime that had been created
120
+ */
121
+ DECL_EXPORT void bmrt_destroy(void* p_bmrt);
122
+
123
+ /**
124
+ * @name bmrt_get_bm_handle
125
+ * @brief To get the BM runtime context.
126
+ * @ingroup bmruntime
127
+ *
128
+ * This API gets the BM runtime context for using BMDNN, BMCV or BMLIB
129
+ *
130
+ * @param [in] p_bmrt Bmruntime that had been created
131
+ */
132
+ DECL_EXPORT void * bmrt_get_bm_handle(void* p_bmrt);
133
+
134
+ /**
135
+ * @name bmrt_load_bmodel
136
+ * @brief To load the bmodel which is created by BM compiler
137
+ * @ingroup bmruntime
138
+ *
139
+ * This API is to load bmodel created by BM compiler.
140
+ * After loading bmodel, we can run the inference of neuron network.
141
+ *
142
+ * @param [in] p_bmrt Bmruntime that had been created
143
+ * @param [in] bmodel_path  Bmodel file path.
144
+ *
145
+ * @retval true    Load context success.
146
+ * @retval false Load context failed.
147
+ */
148
+ DECL_EXPORT bool bmrt_load_bmodel(void* p_bmrt, const char *bmodel_path);
149
+
150
+ /**
151
+ * @name bmrt_load_bmodel_data
152
+ * @brief To load the bmodel which is created by BM compiler from buffer
153
+ * @ingroup bmruntime
154
+ *
155
+ * This API is to load bmodel created by BM compiler.
156
+ * After loading bmodel, we can run the inference of neuron network.
157
+ * Different from bmrt_load_bmodel, the bmodel here is data in host memory.
158
+ *
159
+ * @param [in] p_bmrt Bmruntime that had been created
160
+ * @param [in] bmodel_data Bmodel data pointer to buffer
161
+ * @param [in] size Bmodel data size
162
+ *
163
+ * @retval true    Load context success.
164
+ * @retval false Load context failed.
165
+ */
166
+ DECL_EXPORT bool bmrt_load_bmodel_data(void* p_bmrt, const void * bmodel_data, size_t size);
167
+
168
+ /**
169
+ * @name bmrt_show_neuron_network
170
+ * @brief To print the name of all neuron network
171
+ * @ingroup bmruntime
172
+ *
173
+ * @param [in] p_bmrt Bmruntime that had been created
174
+ */
175
+ DECL_EXPORT void bmrt_show_neuron_network(void* p_bmrt);
176
+
177
+ /**
178
+ * @name bmrt_get_network_number
179
+ * @brief To get the number of neuron network in the bmruntime
180
+ * @ingroup bmruntime
181
+ *
182
+ * @param [in] p_bmrt Bmruntime that had been created
183
+ *
184
+ * @retval int value The number of neuron networks.
185
+ */
186
+ DECL_EXPORT int bmrt_get_network_number(void* p_bmrt);
187
+
188
+ /**
189
+ * @name bmrt_get_network_names
190
+ * @brief To get the names of all neuron network in the bmruntime
191
+ * @ingroup bmruntime
192
+ *
193
+ * @param [in] p_bmrt Bmruntime that had been created
194
+ * @param [out] network_names  The names of all neuron networks. It should be declared as (const char** networks_ = NULL),
195
+ *                             and used as the param &networks_. After this API, the user needs to free(networks_) if it
196
+ *                             is no longer needed.
197
+ */
198
+ DECL_EXPORT void bmrt_get_network_names(void* p_bmrt, const char*** network_names);
199
+
200
+ /**
201
+ * @name bmrt_get_network_info
202
+ * @brief To get network info by net name
203
+ * @ingroup bmruntime
204
+ *
205
+ * @param [in] p_bmrt Bmruntime that had been created
206
+ * @param [in] net_name Network name
207
+ *
208
+ * @retval bm_net_info_t* Pointer to net info, needn't free by user; if net name not found, will return NULL.
209
+ */
210
+ DECL_EXPORT const bm_net_info_t* bmrt_get_network_info(void* p_bmrt, const char* net_name);
211
+
212
+ /**
213
+ * @name bmrt_launch_tensor
214
+ * @brief To launch the inference of the neuron network with setting input tensors
215
+ * @ingroup bmruntime
216
+ *
217
+ * This API supports the neuron network that is static-compiled or dynamic-compiled
218
+ * After calling this API, inference on TPU is launched. And the CPU program will not
219
+ * be blocked. bm_thread_sync should be called to make sure inference finished.
220
+ * This API supports multiple inputs and is thread-safe
221
+ *
222
+ * @param [in] p_bmrt Bmruntime that had been created
223
+ * @param [in] net_name The name of the neuron network
224
+ * @param [in] input_tensors Array of input tensor, defined like bm_tensor_t input_tensors[input_num].
225
+ * User should initialize each input tensor.
226
+ * @param [in] input_num Input number
227
+ * @param [out] output_tensors Array of output tensor, defined like bm_tensor_t output_tensors[output_num].
228
+ * This interface will alloc device mem to store output data. User should free each
229
+ * device mem by bm_free_device after the result data not used.
230
+ * @param [in] output_num Output number
231
+ *
232
+ * @retval true Launch success.
233
+ * @retval false Launch failed.
234
+ */
235
+ DECL_EXPORT bool bmrt_launch_tensor(void* p_bmrt, const char * net_name, const bm_tensor_t input_tensors[], int input_num,
236
+ bm_tensor_t output_tensors[], int output_num);
237
+
238
+ /**
239
+ * @name bmrt_launch_tensor_ex
240
+ * @brief To launch the inference of the neuron network with setting input tensors
241
+ * @ingroup bmruntime
242
+ *
243
+ * This API supports the neuron network that is static-compiled or dynamic-compiled
244
+ * After calling this API, inference on TPU is launched. And the CPU program will not
245
+ * be blocked. bm_thread_sync should be called to make sure inference finished.
246
+ * This API supports multiple inputs and is thread-safe
247
+ *
248
+ * @param [in] p_bmrt Bmruntime that had been created
249
+ * @param [in] net_name The name of the neuron network
250
+ * @param [in] input_tensors Array of input tensor, defined like bm_tensor_t input_tensors[input_num],
251
+ * User should initialize each input tensor.
252
+ * @param [in] input_num Input number
253
+ * @param [out] output_tensors Array of output tensor, defined like bm_tensor_t output_tensors[output_num].
254
+ * User can set device_mem or stmode of output tensors. If user_mem is true, this interface
255
+ * will use device mem of output_tensors to store output data, and not alloc device mem;
256
+ * Or it will alloc device mem to store output. If user_stmode is true, it will use stmode in
257
+ * each output tensor; Or stmode will be BM_STORE_1N as default.
258
+ * @param [in] output_num Output number
259
+ * @param [in] user_mem whether device_mem of output tensors are set
260
+ * @param [in] user_stmode whether stmode of output tensors are set
261
+ *
262
+ * @retval true Launch success.
263
+ * @retval false Launch failed.
264
+ */
265
+ DECL_EXPORT bool bmrt_launch_tensor_ex(void* p_bmrt, const char * net_name, const bm_tensor_t input_tensors[], int input_num,
266
+ bm_tensor_t output_tensors[], int output_num, bool user_mem, bool user_stmode);
267
+
268
+ /**
269
+ * @name bmrt_launch_data
270
+ * @brief To launch the inference of the neuron network with setting input datas in system memory
271
+ * @ingroup bmruntime
272
+ *
273
+ * This API supports the neuron network that is static-compiled or dynamic-compiled
274
+ * After calling this API, inference on TPU is launched. And the CPU
275
+ * program will be blocked.
276
+ * This API supports multiple inputs and is thread-safe
277
+ *
278
+ * @param [in] p_bmrt Bmruntime that had been created
279
+ * @param [in] net_name The name of the neuron network
280
+ * @param [in] input_datas Array of input data, defined like void * input_datas[input_num]. User should
281
+ * initialize each data pointer as input.
282
+ * @param [in] input_shapes Array of input shape, defined like bm_shape_t input_shapes[input_num].
283
+ * User should set each input shape
284
+ * @param [in] input_num Input number
285
+ * @param [out] output_datas Array of output data, defined like void * output_datas[output_num].
286
+ * If the user doesn't alloc each output data, set user_mem to false, and this api will alloc
287
+ * output mem; the user should free each output mem once the output data is no longer used. Also
288
+ * the user can alloc system memory for each output data themselves and set user_mem = true.
289
+ * @param [out] output_shapes Array of output shape, defined like bm_shape_t output_shapes[output_num].
290
+ * It will store each output shape.
291
+ * @param [in] output_num Output number
292
+ * @param [in] user_mem whether output_datas[i] have allocated memory
293
+ *
294
+ * @retval true Launch success.
295
+ * @retval false Launch failed.
296
+ */
297
+ DECL_EXPORT bool bmrt_launch_data(void* p_bmrt, const char* net_name, void* const input_datas[],
298
+ const bm_shape_t input_shapes[], int input_num, void * output_datas[],
299
+ bm_shape_t output_shapes[], int output_num, bool user_mem);
300
+
301
+ /**
302
+ * @name bmrt_trace
303
+ * @brief To check runtime environment, and collect info for DEBUG
304
+ * @ingroup bmruntime
305
+ *
306
+ * This API is to collect runtime info for DEBUG. Especially when a launch produces unexpected results, calling bmrt_trace
307
+ * will show whether device mems are broken, and other check info.
308
+ *
309
+ * @param [in] p_bmrt Bmruntime that had been created
310
+ */
311
+ DECL_EXPORT void bmrt_trace(void* p_bmrt);
312
+
313
+ /**
314
+ * @name bmrt_launch_tensor_multi_cores
315
+ * @brief To launch the inference of the neuron network with setting input tensors, and support multi core inference.
316
+ * @ingroup bmruntime
317
+ *
318
+ * This API supports the neuron network that is static-compiled or dynamic-compiled
319
+ * After calling this API, inference on TPU is launched. And the CPU program will not
320
+ * be blocked. bm_thread_sync_from_core should be called to make sure inference is finished.
321
+ * This API supports multiple inputs and is thread-safe
322
+ *
323
+ * @param [in] p_bmrt Bmruntime that had been created
324
+ * @param [in] net_name The name of the neuron network
325
+ * @param [in] input_tensors Array of input tensor, defined like bm_tensor_t input_tensors[input_num],
326
+ * User should initialize each input tensor.
327
+ * @param [in] input_num Input number
328
+ * @param [out] output_tensors Array of output tensor, defined like bm_tensor_t output_tensors[output_num].
329
+ * User can set device_mem or stmode of output tensors. If user_mem is true, this interface
330
+ * will use device mem of output_tensors to store output data, and not alloc device mem;
331
+ * Or it will alloc device mem to store output. If user_stmode is true, it will use stmode in
332
+ * each output tensor; Or stmode will be BM_STORE_1N as default.
333
+ * @param [in] output_num Output number
334
+ * @param [in] user_mem whether device_mem of output tensors are set
335
+ * @param [in] user_stmode whether stmode of output tensors are set
336
+ * @param [in] core_list core id list those will be used to inference
337
+ * @param [in] core_num number of the core list
338
+ *
339
+ * @retval true Launch success.
340
+ * @retval false Launch failed.
341
+ */
342
+ DECL_EXPORT bool bmrt_launch_tensor_multi_cores(
343
+ void *p_bmrt,
344
+ const char *net_name,
345
+ const bm_tensor_t input_tensors[],
346
+ int input_num,
347
+ bm_tensor_t output_tensors[],
348
+ int output_num,
349
+ bool user_mem,
350
+ bool user_stmode,
351
+ const int *core_list,
352
+ int core_num);
353
+
354
+ /**
355
+ * @name bmrt_memcpy_s2d_parallel
356
+ * @brief To copy data from system memory to multi-device memory in parallel
357
+ * @ingroup bmruntime
358
+ *
359
+ * This API can only be used when the p_bmrt is created with bmrt_create_ex on multiple devices.
360
+ * After calling this API, datas[:tensor_num[0]] will be copied to the first device, and
361
+ * datas[tensor_num[0]:tensor_num[0]+tensor_num[1]] will be copied to the second device and so on.
362
+ * The process of copying data to different devices is done in parallel and to the same device is in sequence.
363
+ *
364
+ * @param [in] p_bmrt Bmruntime that had been created with multi bm_handles
365
+ * @param [in] tensors Array of tensors that will be copied to devices
366
+ * @param [in] tensors        Array of tensors that will be copied to devices
367
+ * @param [in] tensor_num Array of tensor_num that will be copied to each device
368
+ * @param [in] device_num Device number
369
+ */
370
+ DECL_EXPORT bool bmrt_memcpy_s2d_parallel(
371
+ void *p_bmrt,
372
+ bm_tensor_t tensors[],
373
+ void *datas[],
374
+ int tensor_num[],
375
+ int device_num);
376
+
377
+ /**
378
+ * @name bmrt_memcpy_d2s_parallel
379
+ * @brief To copy data from multi-device memory to system memory in parallel
380
+ * @ingroup bmruntime
381
+ *
382
+ * This API can only be used when the p_bmrt is created with bmrt_create_ex on multiple devices.
383
+ * After calling this API, tensors on the first device will be copied to datas[:tensor_num[0]] , and
384
+ * tensors on the second device will be copied to datas[tensor_num[0]:tensor_num[0]+tensor_num[1]] and so on.
385
+ * The process of copying data from different devices is done in parallel and from the same device is in sequence.
386
+ *
387
+ * @param [in] p_bmrt Bmruntime that had been created with multi bm_handles
388
+ * @param [in] datas          Array of data buffers allocated in system memory
389
+ * @param [in] tensors Array of tensors that will be copied from devices
390
+ * @param [in] tensor_num Array of tensor_num that will be copied from each device
391
+ * @param [in] device_num Device number
392
+ */
393
+ DECL_EXPORT bool bmrt_memcpy_d2s_parallel(
394
+ void *p_bmrt,
395
+ void *datas[],
396
+ bm_tensor_t tensors[],
397
+ int tensor_num[],
398
+ int device_num);
399
+
400
+ #if defined (__cplusplus)
401
+ }
402
+ #endif
403
+
404
+ #endif
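
A minimal usage sketch of the runtime interface above, showing the typical load-then-launch flow with the blocking bmrt_launch_data call. Everything model-specific here is a placeholder: the bmodel path, the 1x3x224x224 float input shape, and the single-input/single-output assumption have to be replaced with whatever the compiled network actually expects; only the bmrt_* and bm_dev_* calls themselves come from these headers.

#include "bmlib_runtime.h"
#include "bmruntime_interface.h"
#include <cstdio>
#include <cstdlib>
#include <vector>

int main() {
  bm_handle_t handle;
  if (bm_dev_request(&handle, 0) != BM_SUCCESS) return -1;

  void *p_bmrt = bmrt_create(handle);
  // "net.bmodel" is a placeholder path for a bmodel produced by the compiler.
  if (!bmrt_load_bmodel(p_bmrt, "net.bmodel")) return -1;

  const char **net_names = nullptr;
  bmrt_get_network_names(p_bmrt, &net_names);
  const char *net_name = net_names[0];  // first network in the bmodel

  // Hypothetical single-input network taking a 1x3x224x224 float tensor.
  bm_shape_t in_shape;
  int dims[4] = {1, 3, 224, 224};
  bmrt_shape(&in_shape, dims, 4);
  std::vector<float> input(bmrt_shape_count(&in_shape), 0.f);

  void *input_datas[1] = {input.data()};
  void *output_datas[1] = {nullptr};  // user_mem = false: the runtime allocates outputs
  bm_shape_t output_shapes[1];

  // Blocking launch: copies input to device, runs inference, copies output back.
  bool ok = bmrt_launch_data(p_bmrt, net_name, input_datas, &in_shape, 1,
                             output_datas, output_shapes, 1, false);
  if (ok) {
    std::printf("output element count: %llu\n",
                (unsigned long long)bmrt_shape_count(&output_shapes[0]));
    std::free(output_datas[0]);  // caller frees the runtime-allocated output memory
  }

  std::free(net_names);  // per bmrt_get_network_names above, caller frees the name array
  bmrt_destroy(p_bmrt);
  bm_dev_free(handle);
  return ok ? 0 : -1;
}
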
ChatGLM2/support/include/sentencepiece/sentencepiece_processor.h ADDED
@@ -0,0 +1,727 @@
1
+ // Copyright 2016 Google Inc.
2
+ //
3
+ // Licensed under the Apache License, Version 2.0 (the "License");
4
+ // you may not use this file except in compliance with the License.
5
+ // You may obtain a copy of the License at
6
+ //
7
+ // http://www.apache.org/licenses/LICENSE-2.0
8
+ //
9
+ // Unless required by applicable law or agreed to in writing, software
10
+ // distributed under the License is distributed on an "AS IS" BASIS,
11
+ // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ // See the License for the specific language governing permissions and
13
+ // limitations under the License.!
14
+
15
+ #ifndef SENTENCEPIECE_PROCESSOR_H_
16
+ #define SENTENCEPIECE_PROCESSOR_H_
17
+
18
+ #include <cstring>
19
+ #include <memory>
20
+ #include <string>
21
+ #include <string_view>
22
+ #include <utility>
23
+ #include <vector>
24
+
25
+ #ifndef SWIG
26
+ namespace absl {
27
+ using std::string_view;
28
+ } // namespace absl
29
+ #endif // SWIG
30
+
31
+ namespace sentencepiece {
32
+ namespace util {
33
+
34
+ enum class StatusCode : int {
35
+ kOk = 0,
36
+ kCancelled = 1,
37
+ kUnknown = 2,
38
+ kInvalidArgument = 3,
39
+ kDeadlineExceeded = 4,
40
+ kNotFound = 5,
41
+ kAlreadyExists = 6,
42
+ kPermissionDenied = 7,
43
+ kResourceExhausted = 8,
44
+ kFailedPrecondition = 9,
45
+ kAborted = 10,
46
+ kOutOfRange = 11,
47
+ kUnimplemented = 12,
48
+ kInternal = 13,
49
+ kUnavailable = 14,
50
+ kDataLoss = 15,
51
+ kUnauthenticated = 16,
52
+ };
53
+
54
+ class Status {
55
+ public:
56
+ Status();
57
+ ~Status();
58
+ Status(StatusCode code, absl::string_view error_message);
59
+ Status(const Status &s);
60
+ void operator=(const Status &s);
61
+ bool operator==(const Status &s) const;
62
+ bool operator!=(const Status &s) const;
63
+ inline bool ok() const { return rep_ == nullptr; }
64
+
65
+ void set_error_message(const char *str);
66
+ const char *error_message() const;
67
+ const char *message() const { return error_message(); }
68
+ StatusCode code() const;
69
+ std::string ToString() const;
70
+
71
+ void IgnoreError();
72
+
73
+ private:
74
+ struct Rep;
75
+ std::unique_ptr<Rep> rep_;
76
+ };
77
+ } // namespace util
78
+
79
+ // SentencePieceProcessor:
80
+ // Simple and language independent tokenizer and de-tokenizer for
81
+ // Neural Network Machine Translation.
82
+ //
83
+ // SentencePieceProcessor provides Encode() and Decode() methods,
84
+ // which correspond to tokenization and de-tokenization respectively.
85
+ //
86
+ // - Encode:
87
+ // Given a raw source sentence, encode it into a sequence
88
+ // of pieces or vocabulary ids.
89
+ //
90
+ // - Decode:
91
+ // Given a sequence of pieces or vocabulary ids, decode it
92
+ // into a de-tokenized raw sentence.
93
+ //
94
+ // SentencePieceProcessor provides a lossless data conversion
95
+ // that allows the original raw sentence to be perfectly reconstructed
96
+ // from the encoded data, i.e., Decode(Encode(input)) == input.
97
+ // This characteristic is useful, as we can make the de-tokenization
98
+ // completely language independent.
99
+ //
100
+ // Usage:
101
+ // SentencePieceProcessor sp;
102
+ // sp.Load("//path/to/model");
103
+ //
104
+ // vector<string> sps;
105
+ // sp.Encode("hello world.", &sps).IgnoreError();
106
+ //
107
+ // vector<int> ids;
108
+ // sp.Encode("hello world.", &ids).IgnoreError();
109
+ //
110
+ // string detok;
111
+ // sp.Decode(sps, &detok);
112
+ // CHECK_EQ("hello world.", detok).IgnoreError();
113
+ //
114
+ // sp.Decode(ids, &detok);
115
+ // CHECK_EQ("hello world.", detok).IgnoreError();
116
+ //
117
+ // We can also use SentencePieceText which manages the byte-offsets
118
+ // between user input (output) and internal sentence pieces.
119
+ //
120
+ // SentencePieceText spt;
121
+ // sp.Encode("hello world.", &spt);
122
+ // // Emits the byte range of each piece.
123
+ // for (const auto &piece : spt.pieces()) {
124
+ // LOG(INFO) << piece.begin() << " " << piece.end();
125
+ // }
126
+ //
127
+ // sp.Decode({0, 1, 2, 3..}, &spt);
128
+ // for (const auto &piece : spt.pieces()) {
129
+ // LOG(INFO) << piece.begin() << " " << piece.end();
130
+ // }
131
+ //
132
+
133
+ class NBestSentencePieceText;
134
+ class ModelInterface;
135
+ class SentencePieceText;
136
+ class ModelProto;
137
+
138
+ namespace normalizer {
139
+ class Normalizer;
140
+ } // namespace normalizer
141
+
142
+ #ifndef SWIGGO
143
+ namespace util {
144
+ // Redefine std::string for serialized_proto interface as Python's string is
145
+ // a Unicode string. We can enforce the return value to be raw byte sequence
146
+ // with SWIG's typemap.
147
+ using bytes = std::string;
148
+ } // namespace util
149
+ #endif // SWIGGO
150
+
151
+ class NBestSentencePieceText;
152
+ class ModelInterface;
153
+ class SentencePieceText;
154
+ class SentencePieceText_SentencePiece;
155
+
156
+ // Wrapper class of SentencePieceText
157
+ // This wrapper only allows an immutable access to the proto and
158
+ // hides the actual implementation of protobuf.
159
+ // See sentencepiece.proto for the details of this class.
160
+ class ImmutableSentencePieceText_ImmutableSentencePiece {
161
+ public:
162
+ ImmutableSentencePieceText_ImmutableSentencePiece();
163
+ ~ImmutableSentencePieceText_ImmutableSentencePiece() = default;
164
+
165
+ const std::string &piece() const;
166
+ const std::string &surface() const;
167
+ uint32_t id() const;
168
+ uint32_t begin() const;
169
+ uint32_t end() const;
170
+
171
+ friend class ImmutableSentencePieceText;
172
+
173
+ private:
174
+ explicit ImmutableSentencePieceText_ImmutableSentencePiece(
175
+ const SentencePieceText_SentencePiece &sp);
176
+ const SentencePieceText_SentencePiece *sp_ = nullptr;
177
+ };
178
+
179
+ class ImmutableSentencePieceText {
180
+ public:
181
+ ImmutableSentencePieceText();
182
+ virtual ~ImmutableSentencePieceText();
183
+
184
+ std::vector<ImmutableSentencePieceText_ImmutableSentencePiece> pieces() const;
185
+
186
+ size_t pieces_size() const;
187
+ ImmutableSentencePieceText_ImmutableSentencePiece pieces(int index) const;
188
+
189
+ const std::string &text() const;
190
+ float score() const;
191
+
192
+ util::bytes SerializeAsString() const;
193
+
194
+ // Returns the actual mutable proto.
195
+ // Do not use this outside of SentencePieceProcessor, as
196
+ // it returns the raw pointer managed by the shared_ptr.
197
+ SentencePieceText *mutable_proto();
198
+
199
+ // Converts the utf8 byte spans into Unicode char span.
200
+ void ConvertToUnicodeSpans();
201
+
202
+ friend class ImmutableNBestSentencePieceText;
203
+
204
+ private:
205
+ explicit ImmutableSentencePieceText(const SentencePieceText &spt);
206
+ const SentencePieceText *spt_ = nullptr;
207
+ std::shared_ptr<SentencePieceText> rep_;
208
+ };
209
+
210
+ // Wrapper class of SentencePieceText
211
+ // This wrapper only allows an immutable access to the proto and
212
+ // hides the actual implementation of protobuf.
213
+ // See sentencepiece.proto for the details of this class.
214
+ class ImmutableNBestSentencePieceText {
215
+ public:
216
+ ImmutableNBestSentencePieceText();
217
+ virtual ~ImmutableNBestSentencePieceText();
218
+
219
+ std::vector<ImmutableSentencePieceText> nbests() const;
220
+
221
+ size_t nbests_size() const;
222
+ ImmutableSentencePieceText nbests(int index) const;
223
+
224
+ util::bytes SerializeAsString() const;
225
+
226
+ // Returns the actual mutable proto.
227
+ // Do not use this outside of SentencePieceProcessor, as
228
+ // it returns the raw pointer managed by the shared_ptr.
229
+ NBestSentencePieceText *mutable_proto();
230
+
231
+ void ConvertToUnicodeSpans();
232
+
233
+ private:
234
+ std::shared_ptr<NBestSentencePieceText> rep_;
235
+ };
236
+
237
+ class SentencePieceProcessor {
238
+ public:
239
+ SentencePieceProcessor();
240
+ virtual ~SentencePieceProcessor();
241
+
242
+ // Loads model from `filename`.
243
+ // Returns false if `filename` cannot be loaded.
244
+ virtual util::Status Load(absl::string_view filename);
245
+
246
+ // Loads model from `filename`.
247
+ // Crash if `filename` cannot be loaded.
248
+ virtual void LoadOrDie(absl::string_view filename);
249
+
250
+ // Loads model from `model_proto`.
251
+ // `model_proto` is copied.
252
+ virtual util::Status Load(const ModelProto &model_proto);
253
+
254
+ // Loads model from `model_proto`.
255
+ // `model_proto` is moved.
256
+ virtual util::Status Load(std::unique_ptr<ModelProto> model_proto);
257
+
258
+ // Loads model from `serialized`, which is a string-serialized model proto.
259
+ // Useful to load the model from a platform independent blob object.
260
+ virtual util::Status LoadFromSerializedProto(absl::string_view serialized);
261
+
262
+ // Returns the status. Encode/Decode methods are valid when status is OK.
263
+ virtual util::Status status() const;
264
+
265
+ // Sets encode extra_option sequence.
266
+ virtual util::Status SetEncodeExtraOptions(absl::string_view extra_option);
267
+
268
+ // Sets decode extra_option sequence.
269
+ virtual util::Status SetDecodeExtraOptions(absl::string_view extra_option);
270
+
271
+ //////////////////////////////////////////////////////////////
272
+ // Vocabulary restriction.
273
+ // Background:
274
+ // https://github.com/rsennrich/subword-nmt#best-practice-advice-for-byte-pair-encoding-in-nmt
275
+
276
+ // Restricts the vocabulary set.
277
+ // The input sentences are encoded into the tokens in `valid_vocab`.
278
+ virtual util::Status SetVocabulary(
279
+ const std::vector<absl::string_view> &valid_vocab);
280
+
281
+ // Reverts the vocabulary restriction.
282
+ virtual util::Status ResetVocabulary();
283
+
284
+ // Loads the valid vocabulary set from `filename` in TSV format.
285
+ // Format: <token> <tab> <freq>.
286
+ // Any token with frequency < threshold will be treated as OOV.
287
+ virtual util::Status LoadVocabulary(absl::string_view filename,
288
+ int threshold);
289
+
290
+ //////////////////////////////////////////////////////////////
291
+ // Simple Encode and Decode API.
292
+ //
293
+ // Given a UTF8 input, encodes it into a sequence of sentence pieces.
294
+ virtual util::Status Encode(absl::string_view input,
295
+ std::vector<std::string> *pieces) const;
296
+
297
+ // Given a UTF8 input, encodes it into a sequence of ids.
298
+ virtual util::Status Encode(absl::string_view input,
299
+ std::vector<int> *ids) const;
300
+
301
+ // Given a sequence of pieces, decodes it into a detokenized output.
302
+ virtual util::Status Decode(const std::vector<std::string> &pieces,
303
+ std::string *detokenized) const;
304
+
305
+ // Given a sequence of pieces, decodes it into a detokenized output.
306
+ virtual util::Status Decode(const std::vector<absl::string_view> &pieces,
307
+ std::string *detokenized) const;
308
+
309
+ // Given a sequence of ids, decodes it into a detokenized output.
310
+ virtual util::Status Decode(const std::vector<int> &ids,
311
+ std::string *detokenized) const;
312
+
313
+ //////////////////////////////////////////////////////////////
314
+ // NBest API.
315
+ //
316
+ // Same as Encode, but returns nbest results.
317
+ virtual util::Status NBestEncode(
318
+ absl::string_view input, int nbest_size,
319
+ std::vector<std::vector<std::string>> *pieces) const;
320
+
321
+ // Same as Encode, but returns nbest results.
322
+ virtual util::Status NBestEncode(absl::string_view input, int nbest_size,
323
+ std::vector<std::vector<int>> *ids) const;
324
+
325
+ //////////////////////////////////////////////////////////////
326
+ // Sampling API.
327
+ //
328
+ // Unigram and BPE support sampling mode.
329
+ // - Unigram (--model_type=unigram):
330
+ // `nbest_size`: When `nbest_size` is positive value, approximately samples
331
+ // one segmentation from nbest candidates. When `nbest_size` is negative
332
+ // value, samples one segmentation from the hypotheses (Lattice) according to
333
+ // the generation probabilities using forward-filtering and backward-sampling
334
+ // algorithm.
335
+ // `alpha`: Smoothing parameter (inverse temperature). The best segmentation
336
+ // (Viterbi segmentation) is more likely sampled when setting larger alpha.
337
+ // When alpha is 0.0, one segmentation is uniformly sampled from the nbest or
338
+ // lattice. `nbest_size` and `alpha` correspond to parameters `l` and `alpha`
339
+ // in https://arxiv.org/abs/1804.10959 (nbest_size < 0 means l = infinity)
340
+ //
341
+ // - BPE (--model_type=bpe):
342
+ // `alpha`: The dropout probability `p` of bpe merge operations in
343
+ // https://arxiv.org/abs/1910.13267 Nbest-based sampling is not supported so
344
+ // nbest_size parameter is ignored in BPE.
345
+ virtual util::Status SampleEncode(absl::string_view input, int nbest_size,
346
+ float alpha,
347
+ std::vector<std::string> *pieces) const;
348
+
349
+ // Same as above, but returns a sequence of ids.
350
+ virtual util::Status SampleEncode(absl::string_view input, int nbest_size,
351
+ float alpha, std::vector<int> *ids) const;
352
+
353
+ //////////////////////////////////////////////////////////////
354
+ // SampleEncodeAndScore API.
355
+ //
356
+ // Sample `samples` many tokenisations from the segmentation lattice.
357
+ // These methods are only available in model_type=unigram.
358
+ //
359
+ // `alpha`: smoothing parameter (inverse temperature). The same as `alpha` in
360
+ // `Sample` method.
361
+ // 'wor`: If `wor` is true, the samples are taken without replacement, and the
362
+ // scores are the inclusion probabilities of the elements in the sample;
363
+ // otherwise the samples are taken with replacement and the scores are the
364
+ // log-probs of sample elements
365
+ // `include_best`: If `include_best` is true, the best tokenisation is always
366
+ // included in the sample, and the remaining elements are sampled excluding
367
+ // the best.
368
+ virtual util::Status SampleEncodeAndScore(
369
+ absl::string_view input, int num_samples, float alpha, bool wor,
370
+ bool include_best,
371
+ std::vector<std::pair<std::vector<std::string>, float>> *pieces) const;
372
+
373
+ // Same as above, but returns a sequence of ids.
374
+ virtual util::Status SampleEncodeAndScore(
375
+ absl::string_view input, int num_samples, float alpha, bool wor,
376
+ bool include_best,
377
+ std::vector<std::pair<std::vector<int>, float>> *ids) const;
378
+
379
+ //////////////////////////////////////////////////////////////
380
+ // Entropy API.
381
+ //
382
+ // This only available in model_type=unigram.
383
+ // Calculate entropy of possible tokenisations
384
+ virtual util::Status CalculateEntropy(absl::string_view input, float alpha,
385
+ float *entropy) const;
386
+
387
+ //////////////////////////////////////////////////////////////
388
+ // Advanced API returning SentencePieceText, which manages
389
+ // utf8-byte alignments between user-input/detokenized text
390
+ // and internal sentencepiece sequence.
391
+ //
392
+ // Given a UTF8 input, encodes it into SentencePieceText.
393
+ //
394
+ // When using these APIs, sentencepiece.pb.h header files must be included.
395
+ // We can also use ImmutableSentencePieceText as follows.
396
+ //
397
+ // ImmutableSentencePieceText spt;
398
+ // Encode("hello", spt.mutable_proto()).IgnoreError();
399
+ // std::cout << spt.pieces_size() << std::endl;
400
+ virtual util::Status Encode(absl::string_view input,
401
+ SentencePieceText *spt) const;
402
+
403
+ virtual util::Status NBestEncode(absl::string_view input, int nbest_size,
404
+ NBestSentencePieceText *nbest_spt) const;
405
+
406
+ virtual util::Status SampleEncode(absl::string_view input, int nbest_size,
407
+ float alpha, SentencePieceText *spt) const;
408
+
409
+ virtual util::Status SampleEncodeAndScore(
410
+ absl::string_view input, int num_samples, float alpha, bool wor,
411
+ bool include_best, NBestSentencePieceText *samples_spt) const;
412
+
413
+ // DEPRECATED: Remove this API and use std::vector<std::string_view>
414
+ virtual util::Status Decode(const std::vector<std::string> &pieces,
415
+ SentencePieceText *spt) const;
416
+
417
+ virtual util::Status Decode(const std::vector<absl::string_view> &pieces,
418
+ SentencePieceText *spt) const;
419
+
420
+ virtual util::Status Decode(const std::vector<int> &ids,
421
+ SentencePieceText *spt) const;
422
+ #ifdef SWIG
423
+ #define SPP_SWIG_CHECK_AND_THROW \
424
+ if (!status.ok()) throw status;
425
+ #else
426
+ #define SPP_SWIG_CHECK_AND_THROW \
427
+ if (!status.ok()) { \
428
+ }
429
+ #endif // SWIG
430
+
431
+ #define DEFINE_SPP_DIRECT_FUNC_IMPL(FuncName, OutType, ...) \
432
+ OutType output; \
433
+ const auto status = FuncName(__VA_ARGS__, &output); \
434
+ SPP_SWIG_CHECK_AND_THROW; \
435
+ return output;
436
+
437
+ #define DEFINE_SPP_SERIALIZED_PROTO_IMPL(FuncName, OutType, ...) \
438
+ OutType output; \
439
+ const auto status = FuncName(__VA_ARGS__, output.mutable_proto()); \
440
+ SPP_SWIG_CHECK_AND_THROW; \
441
+ return output.SerializeAsString();
442
+
443
+ #define DEFINE_SPP_IMMUTABLE_PROTO_IMPL(FuncName, OutType, ...) \
444
+ OutType output; \
445
+ const auto status = FuncName(__VA_ARGS__, output.mutable_proto()); \
446
+ SPP_SWIG_CHECK_AND_THROW; \
447
+ return output;
448
+
449
+ //////////////////////////////////////////////////////////////
450
+ // Handy methods that return the result directly.
451
+ // These functions ignore internal errors.
452
+ virtual std::vector<std::string> EncodeAsPieces(
453
+ absl::string_view input) const {
454
+ DEFINE_SPP_DIRECT_FUNC_IMPL(Encode, std::vector<std::string>, input);
455
+ }
456
+
457
+ virtual std::vector<int> EncodeAsIds(absl::string_view input) const {
458
+ DEFINE_SPP_DIRECT_FUNC_IMPL(Encode, std::vector<int>, input);
459
+ }
460
+
461
+ virtual std::vector<std::vector<std::string>> NBestEncodeAsPieces(
462
+ absl::string_view input, int nbest_size) const {
463
+ DEFINE_SPP_DIRECT_FUNC_IMPL(
464
+ NBestEncode, std::vector<std::vector<std::string>>, input, nbest_size);
465
+ }
466
+
467
+ virtual std::vector<std::vector<int>> NBestEncodeAsIds(
468
+ absl::string_view input, int nbest_size) const {
469
+ DEFINE_SPP_DIRECT_FUNC_IMPL(NBestEncode, std::vector<std::vector<int>>,
470
+ input, nbest_size);
471
+ }
472
+
473
+ virtual std::vector<std::string> SampleEncodeAsPieces(absl::string_view input,
474
+ int nbest_size,
475
+ float alpha) const {
476
+ DEFINE_SPP_DIRECT_FUNC_IMPL(SampleEncode, std::vector<std::string>, input,
477
+ nbest_size, alpha);
478
+ }
479
+
480
+ virtual std::vector<int> SampleEncodeAsIds(absl::string_view input,
481
+ int nbest_size,
482
+ float alpha) const {
483
+ DEFINE_SPP_DIRECT_FUNC_IMPL(SampleEncode, std::vector<int>, input,
484
+ nbest_size, alpha);
485
+ }
486
+
487
+ virtual std::vector<std::pair<std::vector<std::string>, float>>
488
+ SampleEncodeAndScoreAsPieces(absl::string_view input, int num_samples,
489
+ float alpha, bool wor, bool include_best) const {
490
+ using _T = std::vector<std::pair<std::vector<std::string>, float>>;
491
+ DEFINE_SPP_DIRECT_FUNC_IMPL(SampleEncodeAndScore, _T, input, num_samples,
492
+ alpha, wor, include_best);
493
+ }
494
+
495
+ virtual std::vector<std::pair<std::vector<int>, float>>
496
+ SampleEncodeAndScoreAsIds(absl::string_view input, int num_samples,
497
+ float alpha, bool wor, bool include_best) const {
498
+ using _T = std::vector<std::pair<std::vector<int>, float>>;
499
+ DEFINE_SPP_DIRECT_FUNC_IMPL(SampleEncodeAndScore, _T, input, num_samples,
500
+ alpha, wor, include_best);
501
+ }
502
+
503
+ // DEPRECATED: Remove this API and use std::vector<std::string_view>
504
+ virtual std::string DecodePieces(
505
+ const std::vector<std::string> &pieces) const {
506
+ DEFINE_SPP_DIRECT_FUNC_IMPL(Decode, std::string, pieces);
507
+ }
508
+
509
+ virtual std::string DecodePieces(
510
+ const std::vector<absl::string_view> &pieces) const {
511
+ DEFINE_SPP_DIRECT_FUNC_IMPL(Decode, std::string, pieces);
512
+ }
513
+
514
+ virtual std::string DecodeIds(const std::vector<int> &ids) const {
515
+ DEFINE_SPP_DIRECT_FUNC_IMPL(Decode, std::string, ids);
516
+ }
517
+
518
+ virtual float CalculateEntropy(absl::string_view text, float alpha) const {
519
+ DEFINE_SPP_DIRECT_FUNC_IMPL(CalculateEntropy, float, text, alpha);
520
+ }
521
+
522
+ //////////////////////////////////////////////////////////////
523
+ // SerializedProto API. (DEPRECATED). Use ImmutableProto API.
524
+ // They are used in Python interface. Returns serialized proto.
525
+ // In python module, we can get access to the full Proto after
526
+ // deserialzing the returned byte sequence.
527
+ virtual util::bytes EncodeAsSerializedProto(absl::string_view input) const {
528
+ DEFINE_SPP_SERIALIZED_PROTO_IMPL(Encode, ImmutableSentencePieceText, input);
529
+ }
530
+
531
+ virtual util::bytes SampleEncodeAsSerializedProto(absl::string_view input,
532
+ int nbest_size,
533
+ float alpha) const {
534
+ DEFINE_SPP_SERIALIZED_PROTO_IMPL(SampleEncode, ImmutableSentencePieceText,
535
+ input, nbest_size, alpha);
536
+ }
537
+
538
+ virtual util::bytes NBestEncodeAsSerializedProto(absl::string_view input,
539
+ int nbest_size) const {
540
+ DEFINE_SPP_SERIALIZED_PROTO_IMPL(
541
+ NBestEncode, ImmutableNBestSentencePieceText, input, nbest_size);
542
+ }
543
+
544
+ virtual util::bytes SampleEncodeAndScoreAsSerializedProto(
545
+ absl::string_view input, int num_samples, float alpha, bool wor,
546
+ bool include_best) const {
547
+ DEFINE_SPP_SERIALIZED_PROTO_IMPL(SampleEncodeAndScore,
548
+ ImmutableNBestSentencePieceText, input,
549
+ num_samples, alpha, wor, include_best);
550
+ }
551
+
552
+ // TODO(taku): Remove this API and use std::vector<std::string_view>
553
+ virtual util::bytes DecodePiecesAsSerializedProto(
554
+ const std::vector<std::string> &pieces) const {
555
+ DEFINE_SPP_SERIALIZED_PROTO_IMPL(Decode, ImmutableSentencePieceText,
556
+ pieces);
557
+ }
558
+
559
+ virtual util::bytes DecodePiecesAsSerializedProto(
560
+ const std::vector<absl::string_view> &pieces) const {
561
+ DEFINE_SPP_SERIALIZED_PROTO_IMPL(Decode, ImmutableSentencePieceText,
562
+ pieces);
563
+ }
564
+
565
+ virtual util::bytes DecodeIdsAsSerializedProto(
566
+ const std::vector<int> &ids) const {
567
+ DEFINE_SPP_SERIALIZED_PROTO_IMPL(Decode, ImmutableSentencePieceText, ids);
568
+ }
569
+
570
+ //////////////////////////////////////////////////////////////
571
+ // ImmutableProto API.
572
+ virtual ImmutableSentencePieceText EncodeAsImmutableProto(
573
+ absl::string_view input) const {
574
+ DEFINE_SPP_IMMUTABLE_PROTO_IMPL(Encode, ImmutableSentencePieceText, input);
575
+ }
576
+
577
+ virtual ImmutableSentencePieceText SampleEncodeAsImmutableProto(
578
+ absl::string_view input, int nbest_size, float alpha) const {
579
+ DEFINE_SPP_IMMUTABLE_PROTO_IMPL(SampleEncode, ImmutableSentencePieceText,
580
+ input, nbest_size, alpha);
581
+ }
582
+
583
+ virtual ImmutableNBestSentencePieceText NBestEncodeAsImmutableProto(
584
+ absl::string_view input, int nbest_size) const {
585
+ DEFINE_SPP_IMMUTABLE_PROTO_IMPL(
586
+ NBestEncode, ImmutableNBestSentencePieceText, input, nbest_size);
587
+ }
588
+
589
+ virtual ImmutableNBestSentencePieceText SampleEncodeAndScoreAsImmutableProto(
590
+ absl::string_view input, int num_samples, float alpha, bool wor,
591
+ bool include_best) const {
592
+ DEFINE_SPP_IMMUTABLE_PROTO_IMPL(SampleEncodeAndScore,
593
+ ImmutableNBestSentencePieceText, input,
594
+ num_samples, alpha, wor, include_best);
595
+ }
596
+
597
+ // TODO(taku): Remove this API and use std::vector<std::string_view>
598
+ virtual ImmutableSentencePieceText DecodePiecesAsImmutableProto(
599
+ const std::vector<std::string> &pieces) const {
600
+ DEFINE_SPP_IMMUTABLE_PROTO_IMPL(Decode, ImmutableSentencePieceText, pieces);
601
+ }
602
+
603
+ virtual ImmutableSentencePieceText DecodePiecesAsImmutableProto(
604
+ const std::vector<absl::string_view> &pieces) const {
605
+ DEFINE_SPP_IMMUTABLE_PROTO_IMPL(Decode, ImmutableSentencePieceText, pieces);
606
+ }
607
+
608
+ virtual ImmutableSentencePieceText DecodeIdsAsImmutableProto(
609
+ const std::vector<int> &ids) const {
610
+ DEFINE_SPP_IMMUTABLE_PROTO_IMPL(Decode, ImmutableSentencePieceText, ids);
611
+ }
612
+
613
+ #undef DEFINE_SPP_DIRECT_FUNC_IMPL
614
+ #undef DEFINE_SPP_SERIALIZED_PROTO_IMPL
615
+ #undef DEFINE_SPP_IMMUTABLE_PROTO_IMPL
616
+
617
+ //////////////////////////////////////////////////////////////
618
+ // Vocabulary management methods.
619
+ //
620
+ // Returns the size of sentence pieces, which is the same as
621
+ // the size of vocabulary for NMT.
622
+ virtual int GetPieceSize() const;
623
+
624
+ // Returns the vocab id of `piece`.
625
+ // Returns UNK(0) if `piece` is unknown.
626
+ virtual int PieceToId(absl::string_view piece) const;
627
+
628
+ // Returns the string representation of vocab with `id`.
629
+ virtual const std::string &IdToPiece(int id) const;
630
+
631
+ // Returns the score of `id`.
632
+ // Usually score is an emission log probability of unigram language
633
+ // model.
634
+ virtual float GetScore(int id) const;
635
+
636
+ // Returns true if `id` is unknown symbol.
637
+ virtual bool IsUnknown(int id) const;
638
+
639
+ // Returns true if `id` is control symbol.
640
+ virtual bool IsControl(int id) const;
641
+
642
+ // Returns true if `id` is unused symbol.
643
+ virtual bool IsUnused(int id) const;
644
+
645
+ // Returns true if `id` is byte symbol.
646
+ virtual bool IsByte(int id) const;
647
+
648
+ // Returns the reserved id.
649
+ // Returns -1 if not defined.
650
+
651
+ // Returns unknown (<unk>) id.
652
+ virtual int unk_id() const;
653
+
654
+ // Returns BOS (<s>) id.
655
+ virtual int bos_id() const;
656
+
657
+ // Returns EOS (</s>) id.
658
+ virtual int eos_id() const;
659
+
660
+ // Returns PAD (<pad>) id.
661
+ virtual int pad_id() const;
662
+
663
+ //////////////////////////////////////////////////////////////
664
+ // Model management.
665
+ //
666
+ // Allows injection of a mock model instance. `model` is moved.
667
+ void SetModel(std::unique_ptr<ModelInterface> &&model);
668
+
669
+ // Allows injection of a normalizer instance. `normalizer` is moved.
670
+ void SetNormalizer(std::unique_ptr<normalizer::Normalizer> &&normalizer);
671
+
672
+ // Returns immutable model proto. Useful to obtain extended
673
+ // or experimental parameters encoded in model_proto.
674
+ const ModelProto &model_proto() const;
675
+
676
+ // returns immutable model proto as std::string.
677
+ // Useful to save the state of this instance via Python's pickle object.
678
+ util::bytes serialized_model_proto() const;
679
+
680
+ private:
681
+ enum ExtraOption { REVERSE, BOS, EOS, UNK_PIECE };
682
+
683
+ util::Status ParseExtraOptions(absl::string_view extra_option,
684
+ std::vector<ExtraOption> *extra_options) const;
685
+
686
+ util::Status ApplyExtraOptions(const std::vector<ExtraOption> &extra_options,
687
+ SentencePieceText *spt) const;
688
+
689
+ util::Status PopulateSentencePieceText(
690
+ absl::string_view input, absl::string_view normalized,
691
+ const std::vector<size_t> &norm_to_orig,
692
+ const std::vector<std::pair<absl::string_view, int>> &result,
693
+ SentencePieceText *spt) const;
694
+
695
+ std::unique_ptr<ModelInterface> model_;
696
+ std::unique_ptr<normalizer::Normalizer> normalizer_;
697
+ std::unique_ptr<normalizer::Normalizer> denormalizer_;
698
+
699
+ // Underlying model protocol buffer. The same lifetime as model_.
700
+ std::unique_ptr<ModelProto> model_proto_;
701
+
702
+ std::vector<ExtraOption> encode_extra_options_;
703
+ std::vector<ExtraOption> decode_extra_options_;
704
+ };
705
+
706
+ // Set seed value of random generator.
707
+ // Do not set static_cast<unsigned int>(-1),
708
+ // as this seed is reserved for initializing from
709
+ // std::random_device.
710
+ void SetRandomGeneratorSeed(unsigned int seed);
711
+
712
+ // IO related functions to absorb model formats.
713
+ namespace io {
714
+ // Loads `model_proto` from `filename`.
715
+ // We can instantiate SentencePieceProcessor as follows:
716
+ //
717
+ // auto model_proto = absl::make_unique<ModelProto>();
718
+ // io::LoadModelProto("//path/spm.model", model_proto.get());
719
+ // SentencePieceProcessor sp;
720
+ // CHECK_OK(sp.Load(std::move(model_proto)));
721
+ util::Status LoadModelProto(absl::string_view, ModelProto *model_proto);
722
+
723
+ // Saves `model_proto` as `filename`.
724
+ util::Status SaveModelProto(absl::string_view, const ModelProto &model_proto);
725
+ } // namespace io
726
+ } // namespace sentencepiece
727
+ #endif // SENTENCEPIECE_PROCESSOR_H_
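The header above declares the C++ `SentencePieceProcessor` surface used by the demos: vocabulary lookup (`GetPieceSize`, `PieceToId`, `IdToPiece`), the reserved ids (`unk_id`/`bos_id`/`eos_id`/`pad_id`), and the `io::LoadModelProto`/`SaveModelProto` helpers. As a quick orientation, here is a minimal sketch of the same API through the official `sentencepiece` Python bindings, which expose the methods listed above under matching names; the model path is an assumption and can be any SentencePiece model, e.g. the `tokenizer.model` shipped in this repo.

```python
# Sketch only (not part of this repo): exercising the API declared in
# sentencepiece_processor.h via the sentencepiece Python bindings.
import sentencepiece as spm

sp = spm.SentencePieceProcessor()
sp.Load("ChatGLM2/support/tokenizer/tokenizer.model")  # assumed model path

print(sp.GetPieceSize())                 # vocabulary size (GetPieceSize)
ids = sp.EncodeAsIds("Hello world")      # Encode text -> ids
pieces = [sp.IdToPiece(i) for i in ids]  # IdToPiece
print(ids, pieces)
print(sp.PieceToId("▁Hello"))            # returns the unk id if the piece is unknown
print(sp.unk_id(), sp.bos_id(), sp.eos_id(), sp.pad_id())  # reserved ids, -1 when undefined
print(sp.DecodeIds(ids))                 # Decode ids back to text
```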
ChatGLM2/support/lib_pcie/libbmlib.so ADDED
Binary file (195 kB). View file
 
ChatGLM2/support/lib_pcie/libbmrt.so ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:621e33823dca470275e09570324a567ce4a30fa6100ac9e52742bb9e1ee02f45
3
+ size 2966400
ChatGLM2/support/lib_pcie/libbmrt.so.1.0 ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:621e33823dca470275e09570324a567ce4a30fa6100ac9e52742bb9e1ee02f45
3
+ size 2966400
ChatGLM2/support/lib_pcie/libsentencepiece.a ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:68811cd99e6e1a58572372f14f3b7a02cf98bc98f5d46d24c406be65a94b53e8
3
+ size 2858304
ChatGLM2/support/lib_soc/libbmlib.so ADDED
Binary file (191 kB). View file
 
ChatGLM2/support/lib_soc/libbmrt.so ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:cff807807fcc8c6a9d16353e389422d434ae2b79c8bc191266d0eb5a69b3d97d
3
+ size 2915352
ChatGLM2/support/lib_soc/libbmrt.so.1.0 ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:cff807807fcc8c6a9d16353e389422d434ae2b79c8bc191266d0eb5a69b3d97d
3
+ size 2915352
ChatGLM2/support/lib_soc/libsentencepiece.a ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:5b1c1ece6c62265ee879cf5876d31e82580c3ee88c2cb627b8ac3eaf35695bde
3
+ size 3032062
ChatGLM2/support/tokenizer/tokenization_chatglm.py ADDED
@@ -0,0 +1,257 @@
1
+ import os
2
+ import torch
3
+ from typing import List, Optional, Union, Dict
4
+ from sentencepiece import SentencePieceProcessor
5
+ from transformers import PreTrainedTokenizer
6
+ from transformers.utils import logging, PaddingStrategy
7
+ from transformers.tokenization_utils_base import EncodedInput, BatchEncoding
8
+
9
+
10
+ class SPTokenizer:
11
+ def __init__(self, model_path: str):
12
+ # reload tokenizer
13
+ assert os.path.isfile(model_path), model_path
14
+ self.sp_model = SentencePieceProcessor(model_file=model_path)
15
+
16
+ # BOS / EOS token IDs
17
+ self.n_words: int = self.sp_model.vocab_size()
18
+ self.bos_id: int = self.sp_model.bos_id()
19
+ self.eos_id: int = self.sp_model.eos_id()
20
+ self.pad_id: int = self.sp_model.unk_id()
21
+ assert self.sp_model.vocab_size() == self.sp_model.get_piece_size()
22
+
23
+ special_tokens = ["[MASK]", "[gMASK]", "[sMASK]", "sop", "eop"]
24
+ self.special_tokens = {}
25
+ self.index_special_tokens = {}
26
+ for token in special_tokens:
27
+ self.special_tokens[token] = self.n_words
28
+ self.index_special_tokens[self.n_words] = token
29
+ self.n_words += 1
30
+
31
+ def tokenize(self, s: str):
32
+ return self.sp_model.EncodeAsPieces(s)
33
+
34
+ def encode(self, s: str, bos: bool = False, eos: bool = False) -> List[int]:
35
+ assert type(s) is str
36
+ t = self.sp_model.encode(s)
37
+ if bos:
38
+ t = [self.bos_id] + t
39
+ if eos:
40
+ t = t + [self.eos_id]
41
+ return t
42
+
43
+ def decode(self, t: List[int]) -> str:
44
+ return self.sp_model.decode(t)
45
+
46
+ def decode_tokens(self, tokens: List[str]) -> str:
47
+ text = self.sp_model.DecodePieces(tokens)
48
+ return text
49
+
50
+ def convert_token_to_id(self, token):
51
+ """ Converts a token (str) in an id using the vocab. """
52
+ if token in self.special_tokens:
53
+ return self.special_tokens[token]
54
+ return self.sp_model.PieceToId(token)
55
+
56
+ def convert_id_to_token(self, index):
57
+ """Converts an index (integer) in a token (str) using the vocab."""
58
+ if index in self.index_special_tokens or index in [self.eos_id, self.bos_id, self.pad_id] or index < 0:
59
+ return ""
60
+ return self.sp_model.IdToPiece(index)
61
+
62
+
63
+ class ChatGLMTokenizer(PreTrainedTokenizer):
64
+ vocab_files_names = {"vocab_file": "tokenizer.model"}
65
+
66
+ model_input_names = ["input_ids", "attention_mask", "position_ids"]
67
+
68
+ def __init__(self, vocab_file, padding_side="left", clean_up_tokenization_spaces=False, **kwargs):
69
+ self.name = "GLMTokenizer"
70
+
71
+ self.vocab_file = vocab_file
72
+ self.tokenizer = SPTokenizer(vocab_file)
73
+ self.special_tokens = {
74
+ "<bos>": self.tokenizer.bos_id,
75
+ "<eos>": self.tokenizer.eos_id,
76
+ "<pad>": self.tokenizer.pad_id
77
+ }
78
+ super().__init__(padding_side=padding_side, clean_up_tokenization_spaces=clean_up_tokenization_spaces, **kwargs)
79
+
80
+ def get_command(self, token):
81
+ if token in self.special_tokens:
82
+ return self.special_tokens[token]
83
+ assert token in self.tokenizer.special_tokens, f"{token} is not a special token for {self.name}"
84
+ return self.tokenizer.special_tokens[token]
85
+
86
+ @property
87
+ def unk_token(self) -> str:
88
+ return "<unk>"
89
+
90
+ @property
91
+ def pad_token(self) -> str:
92
+ return "<unk>"
93
+
94
+ @property
95
+ def pad_token_id(self):
96
+ return self.get_command("<pad>")
97
+
98
+ @property
99
+ def eos_token(self) -> str:
100
+ return "</s>"
101
+
102
+ @property
103
+ def eos_token_id(self):
104
+ return self.get_command("<eos>")
105
+
106
+ @property
107
+ def vocab_size(self):
108
+ return self.tokenizer.n_words
109
+
110
+ def get_vocab(self):
111
+ """ Returns vocab as a dict """
112
+ vocab = {self._convert_id_to_token(i): i for i in range(self.vocab_size)}
113
+ vocab.update(self.added_tokens_encoder)
114
+ return vocab
115
+
116
+ def _tokenize(self, text, **kwargs):
117
+ return self.tokenizer.tokenize(text)
118
+
119
+ def _convert_token_to_id(self, token):
120
+ """ Converts a token (str) in an id using the vocab. """
121
+ return self.tokenizer.convert_token_to_id(token)
122
+
123
+ def _convert_id_to_token(self, index):
124
+ """Converts an index (integer) in a token (str) using the vocab."""
125
+ return self.tokenizer.convert_id_to_token(index)
126
+
127
+ def convert_tokens_to_string(self, tokens: List[str]) -> str:
128
+ return self.tokenizer.decode_tokens(tokens)
129
+
130
+ def save_vocabulary(self, save_directory, filename_prefix=None):
131
+ """
132
+ Save the vocabulary and special tokens file to a directory.
133
+
134
+ Args:
135
+ save_directory (`str`):
136
+ The directory in which to save the vocabulary.
137
+ filename_prefix (`str`, *optional*):
138
+ An optional prefix to add to the names of the saved files.
139
+
140
+ Returns:
141
+ `Tuple(str)`: Paths to the files saved.
142
+ """
143
+ if os.path.isdir(save_directory):
144
+ vocab_file = os.path.join(
145
+ save_directory, self.vocab_files_names["vocab_file"]
146
+ )
147
+ else:
148
+ vocab_file = save_directory
149
+
150
+ with open(self.vocab_file, 'rb') as fin:
151
+ proto_str = fin.read()
152
+
153
+ with open(vocab_file, "wb") as writer:
154
+ writer.write(proto_str)
155
+
156
+ return (vocab_file,)
157
+
158
+ def get_prefix_tokens(self):
159
+ prefix_tokens = [self.get_command("[gMASK]"), self.get_command("sop")]
160
+ return prefix_tokens
161
+
162
+ def build_prompt(self, query, history=None):
163
+ if history is None:
164
+ history = []
165
+ prompt = ""
166
+ for i, (old_query, response) in enumerate(history):
167
+ prompt += "[Round {}]\n\n问:{}\n\n答:{}\n\n".format(i + 1, old_query, response)
168
+ prompt += "[Round {}]\n\n问:{}\n\n答:".format(len(history) + 1, query)
169
+ return prompt
170
+
171
+ def build_inputs_with_special_tokens(
172
+ self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
173
+ ) -> List[int]:
174
+ """
175
+ Build model inputs from a sequence or a pair of sequences by concatenating and
176
+ adding special tokens. A ChatGLM2 sequence has the following format:
177
+
178
+ - single sequence: `[gMASK] sop X`
179
+ - pair of sequences: `[gMASK] sop A B </s>`
180
+
181
+ Args:
182
+ token_ids_0 (`List[int]`):
183
+ List of IDs to which the special tokens will be added.
184
+ token_ids_1 (`List[int]`, *optional*):
185
+ Optional second list of IDs for sequence pairs.
186
+
187
+ Returns:
188
+ `List[int]`: List of [input IDs](../glossary#input-ids) with the appropriate special tokens.
189
+ """
190
+ prefix_tokens = self.get_prefix_tokens()
191
+ token_ids_0 = prefix_tokens + token_ids_0
192
+ if token_ids_1 is not None:
193
+ token_ids_0 = token_ids_0 + token_ids_1 + [self.get_command("<eos>")]
194
+ return token_ids_0
195
+
196
+ def _pad(
197
+ self,
198
+ encoded_inputs: Union[Dict[str, EncodedInput], BatchEncoding],
199
+ max_length: Optional[int] = None,
200
+ padding_strategy: PaddingStrategy = PaddingStrategy.DO_NOT_PAD,
201
+ pad_to_multiple_of: Optional[int] = None,
202
+ return_attention_mask: Optional[bool] = None,
203
+ ) -> dict:
204
+ """
205
+ Pad encoded inputs (on left/right and up to predefined length or max length in the batch)
206
+
207
+ Args:
208
+ encoded_inputs:
209
+ Dictionary of tokenized inputs (`List[int]`) or batch of tokenized inputs (`List[List[int]]`).
210
+ max_length: maximum length of the returned list and optionally padding length (see below).
211
+ Will truncate by taking into account the special tokens.
212
+ padding_strategy: PaddingStrategy to use for padding.
213
+
214
+ - PaddingStrategy.LONGEST: Pad to the longest sequence in the batch
215
+ - PaddingStrategy.MAX_LENGTH: Pad to the max length (default)
216
+ - PaddingStrategy.DO_NOT_PAD: Do not pad
217
+ The tokenizer padding sides are defined in self.padding_side:
218
+
219
+ - 'left': pads on the left of the sequences
220
+ - 'right': pads on the right of the sequences
221
+ pad_to_multiple_of: (optional) Integer; if set, will pad the sequence to a multiple of the provided value.
222
+ This is especially useful to enable the use of Tensor Core on NVIDIA hardware with compute capability
223
+ `>= 7.5` (Volta).
224
+ return_attention_mask:
225
+ (optional) Set to False to avoid returning attention mask (default: set to model specifics)
226
+ """
227
+ # Load from model defaults
228
+ assert self.padding_side == "left"
229
+
230
+ required_input = encoded_inputs[self.model_input_names[0]]
231
+ seq_length = len(required_input)
232
+
233
+ if padding_strategy == PaddingStrategy.LONGEST:
234
+ max_length = len(required_input)
235
+
236
+ if max_length is not None and pad_to_multiple_of is not None and (max_length % pad_to_multiple_of != 0):
237
+ max_length = ((max_length // pad_to_multiple_of) + 1) * pad_to_multiple_of
238
+
239
+ needs_to_be_padded = padding_strategy != PaddingStrategy.DO_NOT_PAD and len(required_input) != max_length
240
+
241
+ # Initialize attention mask if not present.
242
+ if "attention_mask" not in encoded_inputs:
243
+ encoded_inputs["attention_mask"] = [1] * seq_length
244
+
245
+ if "position_ids" not in encoded_inputs:
246
+ encoded_inputs["position_ids"] = list(range(seq_length))
247
+
248
+ if needs_to_be_padded:
249
+ difference = max_length - len(required_input)
250
+
251
+ if "attention_mask" in encoded_inputs:
252
+ encoded_inputs["attention_mask"] = [0] * difference + encoded_inputs["attention_mask"]
253
+ if "position_ids" in encoded_inputs:
254
+ encoded_inputs["position_ids"] = [0] * difference + encoded_inputs["position_ids"]
255
+ encoded_inputs[self.model_input_names[0]] = [self.pad_token_id] * difference + required_input
256
+
257
+ return encoded_inputs
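`tokenization_chatglm.py` wires the SentencePiece model into a `transformers`-style tokenizer: `build_prompt()` formats the multi-round chat template, `build_inputs_with_special_tokens()` prepends the `[gMASK]`/`sop` prefix, and `_pad()` left-pads `input_ids`, `attention_mask` and `position_ids`. A minimal usage sketch follows; the file and model paths are assumptions (run it next to this file and the `tokenizer.model` below), and it assumes a `transformers` version compatible with this tokenizer class.

```python
# Sketch only (not part of the repo's demo code).
from tokenization_chatglm import ChatGLMTokenizer  # assumed importable from the same folder

tokenizer = ChatGLMTokenizer(vocab_file="tokenizer.model")  # assumed model path

# Multi-round chat template: "[Round N]\n\n问:...\n\n答:..."
prompt = tokenizer.build_prompt("你好", history=[("Hi", "Hello!")])

# Encoding prepends [gMASK]/sop; padding="max_length" exercises _pad(),
# which pads on the left because padding_side defaults to "left".
inputs = tokenizer(prompt, padding="max_length", max_length=64)
print(inputs["input_ids"][:4])       # leading pad ids (pad_token_id is the unk id here)
print(inputs["attention_mask"][:4])  # leading zeros mark the padded positions
print(inputs["position_ids"][:4])    # padded positions get position id 0
```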
ChatGLM2/support/tokenizer/tokenizer.model ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:e7dc4c393423b76e4373e5157ddc34803a0189ba96b21ddbb40269d31468a6f2
3
+ size 1018370