csukuangfj
commited on
Commit
·
cefadc7
1
Parent(s):
5256556
first commit
Browse files- v1.15.1/.gitattributes +4 -0
- v1.15.1/onnxruntime-linux-armhf.zip +3 -0
- v1.15.1/onnxruntime-linux-armhf/GIT_COMMIT_ID +1 -0
- v1.15.1/onnxruntime-linux-armhf/LICENSE +21 -0
- v1.15.1/onnxruntime-linux-armhf/Privacy.md +21 -0
- v1.15.1/onnxruntime-linux-armhf/README.md +57 -0
- v1.15.1/onnxruntime-linux-armhf/ThirdPartyNotices.txt +0 -0
- v1.15.1/onnxruntime-linux-armhf/VERSION_NUMBER +1 -0
- v1.15.1/onnxruntime-linux-armhf/include/cpu_provider_factory.h +19 -0
- v1.15.1/onnxruntime-linux-armhf/include/onnxruntime_c_api.h +0 -0
- v1.15.1/onnxruntime-linux-armhf/include/onnxruntime_cxx_api.h +0 -0
- v1.15.1/onnxruntime-linux-armhf/include/onnxruntime_cxx_inline.h +2035 -0
- v1.15.1/onnxruntime-linux-armhf/include/onnxruntime_run_options_config_keys.h +32 -0
- v1.15.1/onnxruntime-linux-armhf/include/onnxruntime_session_options_config_keys.h +199 -0
- v1.15.1/onnxruntime-linux-armhf/include/onnxruntime_training_c_api.h +630 -0
- v1.15.1/onnxruntime-linux-armhf/include/onnxruntime_training_cxx_api.h +361 -0
- v1.15.1/onnxruntime-linux-armhf/include/onnxruntime_training_cxx_inline.h +256 -0
- v1.15.1/onnxruntime-linux-armhf/include/provider_options.h +18 -0
- v1.15.1/onnxruntime-linux-armhf/lib/libonnxruntime.so +1 -0
- v1.15.1/onnxruntime-linux-armhf/lib/libonnxruntime.so.1.15.1 +3 -0
v1.15.1/.gitattributes
ADDED
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
1 |
+
*.so* filter=lfs diff=lfs merge=lfs -text
|
2 |
+
onnxruntime-linux-armhf/lib/libonnxruntime.so filter=lfs diff=lfs merge=lfs -text
|
3 |
+
onnxruntime-linux-armhf/lib/libonnxruntime.so.1.15.1 filter=lfs diff=lfs merge=lfs -text
|
4 |
+
*.zip filter=lfs diff=lfs merge=lfs -text
|
v1.15.1/onnxruntime-linux-armhf.zip
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:3fc357bdbf365ae86dda5c126cff1afd40c7a8e88c5d45059f381d1dacb34255
|
3 |
+
size 11373251
|
v1.15.1/onnxruntime-linux-armhf/GIT_COMMIT_ID
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
baeece44ba075009c6bfe95891a8c1b3d4571cb3
|
v1.15.1/onnxruntime-linux-armhf/LICENSE
ADDED
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
MIT License
|
2 |
+
|
3 |
+
Copyright (c) Microsoft Corporation
|
4 |
+
|
5 |
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
6 |
+
of this software and associated documentation files (the "Software"), to deal
|
7 |
+
in the Software without restriction, including without limitation the rights
|
8 |
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
9 |
+
copies of the Software, and to permit persons to whom the Software is
|
10 |
+
furnished to do so, subject to the following conditions:
|
11 |
+
|
12 |
+
The above copyright notice and this permission notice shall be included in all
|
13 |
+
copies or substantial portions of the Software.
|
14 |
+
|
15 |
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
16 |
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
17 |
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
18 |
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
19 |
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
20 |
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
21 |
+
SOFTWARE.
|
v1.15.1/onnxruntime-linux-armhf/Privacy.md
ADDED
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Privacy
|
2 |
+
|
3 |
+
## Data Collection
|
4 |
+
The software may collect information about you and your use of the software and send it to Microsoft. Microsoft may use this information to provide services and improve our products and services. You may turn off the telemetry as described in the repository. There are also some features in the software that may enable you and Microsoft to collect data from users of your applications. If you use these features, you must comply with applicable law, including providing appropriate notices to users of your applications together with a copy of Microsoft's privacy statement. Our privacy statement is located at https://go.microsoft.com/fwlink/?LinkID=824704. You can learn more about data collection and use in the help documentation and our privacy statement. Your use of the software operates as your consent to these practices.
|
5 |
+
|
6 |
+
***
|
7 |
+
|
8 |
+
### Private Builds
|
9 |
+
No data collection is performed when using your private builds built from source code.
|
10 |
+
|
11 |
+
### Official Builds
|
12 |
+
ONNX Runtime does not maintain any independent telemetry collection mechanisms outside of what is provided by the platforms it supports. However, where applicable, ONNX Runtime will take advantage of platform-supported telemetry systems to collect trace events with the goal of improving product quality.
|
13 |
+
|
14 |
+
Currently telemetry is only implemented for Windows builds and is turned **ON** by default in the official builds distributed in their respective package management repositories ([see here](../README.md#binaries)). This may be expanded to cover other platforms in the future. Data collection is implemented via 'Platform Telemetry' per vendor platform providers (see [telemetry.h](../onnxruntime/core/platform/telemetry.h)).
|
15 |
+
|
16 |
+
#### Technical Details
|
17 |
+
The Windows provider uses the [TraceLogging](https://docs.microsoft.com/en-us/windows/win32/tracelogging/trace-logging-about) API for its implementation. This enables ONNX Runtime trace events to be collected by the operating system, and based on user consent, this data may be periodically sent to Microsoft servers following GDPR and privacy regulations for anonymity and data access controls.
|
18 |
+
|
19 |
+
Windows ML and onnxruntime C APIs allow Trace Logging to be turned on/off (see [API pages](../README.md#api-documentation) for details).
|
20 |
+
For information on how to enable and disable telemetry, see [C API: Telemetry](./C_API.md#telemetry).
|
21 |
+
There are equivalent APIs in the C#, Python, and Java language bindings as well.
|
v1.15.1/onnxruntime-linux-armhf/README.md
ADDED
@@ -0,0 +1,57 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
<p align="center"><img width="50%" src="docs/images/ONNX_Runtime_logo_dark.png" /></p>
|
2 |
+
|
3 |
+
**ONNX Runtime is a cross-platform inference and training machine-learning accelerator**.
|
4 |
+
|
5 |
+
**ONNX Runtime inference** can enable faster customer experiences and lower costs, supporting models from deep learning frameworks such as PyTorch and TensorFlow/Keras as well as classical machine learning libraries such as scikit-learn, LightGBM, XGBoost, etc. ONNX Runtime is compatible with different hardware, drivers, and operating systems, and provides optimal performance by leveraging hardware accelerators where applicable alongside graph optimizations and transforms. [Learn more →](https://www.onnxruntime.ai/docs/#onnx-runtime-for-inferencing)
|
6 |
+
|
7 |
+
**ONNX Runtime training** can accelerate the model training time on multi-node NVIDIA GPUs for transformer models with a one-line addition for existing PyTorch training scripts. [Learn more →](https://www.onnxruntime.ai/docs/#onnx-runtime-for-training)
|
8 |
+
|
9 |
+
|
10 |
+
## Get Started & Resources
|
11 |
+
|
12 |
+
* **General Information**: [onnxruntime.ai](https://onnxruntime.ai)
|
13 |
+
|
14 |
+
* **Usage documention and tutorials**: [onnxruntime.ai/docs](https://onnxruntime.ai/docs)
|
15 |
+
|
16 |
+
* **YouTube video tutorials**: [youtube.com/@ONNXRuntime](https://www.youtube.com/@ONNXRuntime)
|
17 |
+
|
18 |
+
* [**Upcoming Release Roadmap**](https://github.com/microsoft/onnxruntime/wiki/Upcoming-Release-Roadmap)
|
19 |
+
|
20 |
+
* **Companion sample repositories**:
|
21 |
+
- ONNX Runtime Inferencing: [microsoft/onnxruntime-inference-examples](https://github.com/microsoft/onnxruntime-inference-examples)
|
22 |
+
- ONNX Runtime Training: [microsoft/onnxruntime-training-examples](https://github.com/microsoft/onnxruntime-training-examples)
|
23 |
+
|
24 |
+
|
25 |
+
## Build Pipeline Status
|
26 |
+
|System|Inference|Training|
|
27 |
+
|---|---|---|
|
28 |
+
|Windows|[![Build Status](https://dev.azure.com/onnxruntime/onnxruntime/_apis/build/status/Windows%20CPU%20CI%20Pipeline?label=Windows+CPU)](https://dev.azure.com/onnxruntime/onnxruntime/_build/latest?definitionId=9)<br>[![Build Status](https://dev.azure.com/onnxruntime/onnxruntime/_apis/build/status/Windows%20GPU%20CI%20Pipeline?label=Windows+GPU)](https://dev.azure.com/onnxruntime/onnxruntime/_build/latest?definitionId=10)<br>[![Build Status](https://dev.azure.com/onnxruntime/onnxruntime/_apis/build/status/Windows%20GPU%20TensorRT%20CI%20Pipeline?label=Windows+GPU+TensorRT)](https://dev.azure.com/onnxruntime/onnxruntime/_build/latest?definitionId=47)||
|
29 |
+
|Linux|[![Build Status](https://dev.azure.com/onnxruntime/onnxruntime/_apis/build/status/Linux%20CPU%20CI%20Pipeline?label=Linux+CPU)](https://dev.azure.com/onnxruntime/onnxruntime/_build/latest?definitionId=11)<br>[![Build Status](https://dev.azure.com/onnxruntime/onnxruntime/_apis/build/status/Linux%20CPU%20Minimal%20Build%20E2E%20CI%20Pipeline?label=Linux+CPU+Minimal+Build)](https://dev.azure.com/onnxruntime/onnxruntime/_build/latest?definitionId=64)<br>[![Build Status](https://dev.azure.com/onnxruntime/onnxruntime/_apis/build/status/Linux%20GPU%20CI%20Pipeline?label=Linux+GPU)](https://dev.azure.com/onnxruntime/onnxruntime/_build/latest?definitionId=12)<br>[![Build Status](https://dev.azure.com/onnxruntime/onnxruntime/_apis/build/status/Linux%20GPU%20TensorRT%20CI%20Pipeline?label=Linux+GPU+TensorRT)](https://dev.azure.com/onnxruntime/onnxruntime/_build/latest?definitionId=45)<br>[![Build Status](https://dev.azure.com/onnxruntime/onnxruntime/_apis/build/status/Linux%20OpenVINO%20CI%20Pipeline?label=Linux+OpenVINO)](https://dev.azure.com/onnxruntime/onnxruntime/_build/latest?definitionId=55)|[![Build Status](https://dev.azure.com/onnxruntime/onnxruntime/_apis/build/status/orttraining-linux-ci-pipeline?label=Linux+CPU+Training)](https://dev.azure.com/onnxruntime/onnxruntime/_build/latest?definitionId=86)<br>[![Build Status](https://dev.azure.com/onnxruntime/onnxruntime/_apis/build/status/orttraining-linux-gpu-ci-pipeline?label=Linux+GPU+Training)](https://dev.azure.com/onnxruntime/onnxruntime/_build/latest?definitionId=84)<br>[![Build Status](https://dev.azure.com/onnxruntime/onnxruntime/_apis/build/status/orttraining/orttraining-ortmodule-distributed?label=Training+Distributed)](https://dev.azure.com/onnxruntime/onnxruntime/_build/latest?definitionId=148)|
|
30 |
+
|Mac|[![Build Status](https://dev.azure.com/onnxruntime/onnxruntime/_apis/build/status/MacOS%20CI%20Pipeline?label=MacOS+CPU)](https://dev.azure.com/onnxruntime/onnxruntime/_build/latest?definitionId=13)||
|
31 |
+
|Android|[![Build Status](https://dev.azure.com/onnxruntime/onnxruntime/_apis/build/status/Android%20CI%20Pipeline?label=Android)](https://dev.azure.com/onnxruntime/onnxruntime/_build/latest?definitionId=53)||
|
32 |
+
|iOS|[![Build Status](https://dev.azure.com/onnxruntime/onnxruntime/_apis/build/status/iOS%20CI%20Pipeline?label=iOS)](https://dev.azure.com/onnxruntime/onnxruntime/_build/latest?definitionId=134)||
|
33 |
+
|Web|[![Build Status](https://dev.azure.com/onnxruntime/onnxruntime/_apis/build/status/ONNX%20Runtime%20Web%20CI%20Pipeline?label=Web)](https://dev.azure.com/onnxruntime/onnxruntime/_build/latest?definitionId=161)||
|
34 |
+
|Other|[![Build Status](https://dev.azure.com/onnxruntime/onnxruntime/_apis/build/status/onnxruntime-binary-size-checks-ci-pipeline?repoName=microsoft%2Fonnxruntime&label=Binary+Size+Check)](https://dev.azure.com/onnxruntime/onnxruntime/_build/latest?definitionId=187&repoName=microsoft%2Fonnxruntime)<br>[![Build Status](https://dev.azure.com/onnxruntime/onnxruntime/_apis/build/status/onnxruntime-python-checks-ci-pipeline?label=Python+Checks)](https://dev.azure.com/onnxruntime/onnxruntime/_build/latest?definitionId=164)||
|
35 |
+
|
36 |
+
|
37 |
+
## Data/Telemetry
|
38 |
+
|
39 |
+
Windows distributions of this project may collect usage data and send it to Microsoft to help improve our products and services. See the [privacy statement](docs/Privacy.md) for more details.
|
40 |
+
|
41 |
+
## Contributions and Feedback
|
42 |
+
|
43 |
+
We welcome contributions! Please see the [contribution guidelines](CONTRIBUTING.md).
|
44 |
+
|
45 |
+
For feature requests or bug reports, please file a [GitHub Issue](https://github.com/Microsoft/onnxruntime/issues).
|
46 |
+
|
47 |
+
For general discussion or questions, please use [GitHub Discussions](https://github.com/microsoft/onnxruntime/discussions).
|
48 |
+
|
49 |
+
## Code of Conduct
|
50 |
+
|
51 |
+
This project has adopted the [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/).
|
52 |
+
For more information see the [Code of Conduct FAQ](https://opensource.microsoft.com/codeofconduct/faq/)
|
53 |
+
or contact [opencode@microsoft.com](mailto:opencode@microsoft.com) with any additional questions or comments.
|
54 |
+
|
55 |
+
## License
|
56 |
+
|
57 |
+
This project is licensed under the [MIT License](LICENSE).
|
v1.15.1/onnxruntime-linux-armhf/ThirdPartyNotices.txt
ADDED
The diff for this file is too large to render.
See raw diff
|
|
v1.15.1/onnxruntime-linux-armhf/VERSION_NUMBER
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
1.15.1
|
v1.15.1/onnxruntime-linux-armhf/include/cpu_provider_factory.h
ADDED
@@ -0,0 +1,19 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
// Copyright (c) Microsoft Corporation. All rights reserved.
|
2 |
+
// Licensed under the MIT License.
|
3 |
+
|
4 |
+
#include "onnxruntime_c_api.h"
|
5 |
+
|
6 |
+
#ifdef __cplusplus
|
7 |
+
extern "C" {
|
8 |
+
#endif
|
9 |
+
|
10 |
+
/**
|
11 |
+
* \param use_arena zero: false. non-zero: true.
|
12 |
+
*/
|
13 |
+
ORT_EXPORT
|
14 |
+
ORT_API_STATUS(OrtSessionOptionsAppendExecutionProvider_CPU, _In_ OrtSessionOptions* options, int use_arena)
|
15 |
+
ORT_ALL_ARGS_NONNULL;
|
16 |
+
|
17 |
+
#ifdef __cplusplus
|
18 |
+
}
|
19 |
+
#endif
|
v1.15.1/onnxruntime-linux-armhf/include/onnxruntime_c_api.h
ADDED
The diff for this file is too large to render.
See raw diff
|
|
v1.15.1/onnxruntime-linux-armhf/include/onnxruntime_cxx_api.h
ADDED
The diff for this file is too large to render.
See raw diff
|
|
v1.15.1/onnxruntime-linux-armhf/include/onnxruntime_cxx_inline.h
ADDED
@@ -0,0 +1,2035 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
// Copyright (c) Microsoft Corporation. All rights reserved.
|
2 |
+
// Licensed under the MIT License.
|
3 |
+
|
4 |
+
// Do not include this file directly. Please include "onnxruntime_cxx_api.h" instead.
|
5 |
+
// If interested in trying out features of the new experimental C++ API, include "experimental_onnxruntime_cxx_api.h" instead.
|
6 |
+
//
|
7 |
+
// These are the inline implementations of the C++ header APIs. They're in this separate file as to not clutter
|
8 |
+
// the main C++ file with implementation details.
|
9 |
+
|
10 |
+
namespace Ort {
|
11 |
+
|
12 |
+
namespace detail {
|
13 |
+
inline void ThrowStatus(const Status& st) {
|
14 |
+
std::string error_message = st.GetErrorMessage();
|
15 |
+
OrtErrorCode error_code = st.GetErrorCode();
|
16 |
+
ORT_CXX_API_THROW(std::move(error_message), error_code);
|
17 |
+
}
|
18 |
+
} // namespace detail
|
19 |
+
|
20 |
+
inline void ThrowOnError(OrtStatus* ort_status) {
|
21 |
+
if (ort_status) {
|
22 |
+
Ort::Status st(ort_status);
|
23 |
+
detail::ThrowStatus(st);
|
24 |
+
}
|
25 |
+
}
|
26 |
+
|
27 |
+
inline void ThrowOnError(const Status& st) {
|
28 |
+
if (st) {
|
29 |
+
detail::ThrowStatus(st);
|
30 |
+
}
|
31 |
+
}
|
32 |
+
|
33 |
+
inline Status::Status(OrtStatus* status) noexcept : Base<OrtStatus>{status} {
|
34 |
+
}
|
35 |
+
|
36 |
+
inline Status::Status(const std::exception& e) noexcept {
|
37 |
+
p_ = GetApi().CreateStatus(ORT_FAIL, e.what());
|
38 |
+
}
|
39 |
+
|
40 |
+
inline Status::Status(const Exception& e) noexcept {
|
41 |
+
p_ = GetApi().CreateStatus(e.GetOrtErrorCode(), e.what());
|
42 |
+
}
|
43 |
+
|
44 |
+
inline Status::Status(const char* message, OrtErrorCode code) noexcept {
|
45 |
+
p_ = GetApi().CreateStatus(code, message);
|
46 |
+
}
|
47 |
+
|
48 |
+
inline std::string Status::GetErrorMessage() const {
|
49 |
+
std::string message(GetApi().GetErrorMessage(p_));
|
50 |
+
return message;
|
51 |
+
}
|
52 |
+
|
53 |
+
inline OrtErrorCode Status::GetErrorCode() const {
|
54 |
+
return GetApi().GetErrorCode(p_);
|
55 |
+
}
|
56 |
+
|
57 |
+
inline bool Status::IsOK() const noexcept {
|
58 |
+
return (p_ == nullptr);
|
59 |
+
}
|
60 |
+
|
61 |
+
// This template converts a C++ type into it's ONNXTensorElementDataType
|
62 |
+
template <typename T>
|
63 |
+
struct TypeToTensorType;
|
64 |
+
template <>
|
65 |
+
struct TypeToTensorType<float> {
|
66 |
+
static constexpr ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT;
|
67 |
+
};
|
68 |
+
template <>
|
69 |
+
struct TypeToTensorType<Float16_t> {
|
70 |
+
static constexpr ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16;
|
71 |
+
};
|
72 |
+
template <>
|
73 |
+
struct TypeToTensorType<BFloat16_t> {
|
74 |
+
static constexpr ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_BFLOAT16;
|
75 |
+
};
|
76 |
+
template <>
|
77 |
+
struct TypeToTensorType<double> {
|
78 |
+
static constexpr ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE;
|
79 |
+
};
|
80 |
+
template <>
|
81 |
+
struct TypeToTensorType<int8_t> {
|
82 |
+
static constexpr ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8;
|
83 |
+
};
|
84 |
+
template <>
|
85 |
+
struct TypeToTensorType<int16_t> {
|
86 |
+
static constexpr ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_INT16;
|
87 |
+
};
|
88 |
+
template <>
|
89 |
+
struct TypeToTensorType<int32_t> {
|
90 |
+
static constexpr ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32;
|
91 |
+
};
|
92 |
+
template <>
|
93 |
+
struct TypeToTensorType<int64_t> {
|
94 |
+
static constexpr ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64;
|
95 |
+
};
|
96 |
+
template <>
|
97 |
+
struct TypeToTensorType<uint8_t> {
|
98 |
+
static constexpr ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8;
|
99 |
+
};
|
100 |
+
template <>
|
101 |
+
struct TypeToTensorType<uint16_t> {
|
102 |
+
static constexpr ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT16;
|
103 |
+
};
|
104 |
+
template <>
|
105 |
+
struct TypeToTensorType<uint32_t> {
|
106 |
+
static constexpr ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT32;
|
107 |
+
};
|
108 |
+
template <>
|
109 |
+
struct TypeToTensorType<uint64_t> {
|
110 |
+
static constexpr ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT64;
|
111 |
+
};
|
112 |
+
template <>
|
113 |
+
struct TypeToTensorType<bool> {
|
114 |
+
static constexpr ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL;
|
115 |
+
};
|
116 |
+
|
117 |
+
inline MemoryAllocation::MemoryAllocation(OrtAllocator* allocator, void* p, size_t size)
|
118 |
+
: allocator_(allocator), p_(p), size_(size) {
|
119 |
+
}
|
120 |
+
|
121 |
+
inline MemoryAllocation::~MemoryAllocation() {
|
122 |
+
if (p_ != nullptr) {
|
123 |
+
// We do not throw out of destructor
|
124 |
+
auto ret = GetApi().AllocatorFree(allocator_, p_);
|
125 |
+
static_cast<void>(ret);
|
126 |
+
}
|
127 |
+
}
|
128 |
+
|
129 |
+
inline MemoryAllocation::MemoryAllocation(MemoryAllocation&& o) noexcept : allocator_(nullptr), p_(nullptr), size_(0) {
|
130 |
+
*this = std::move(o);
|
131 |
+
}
|
132 |
+
|
133 |
+
inline MemoryAllocation& MemoryAllocation::operator=(MemoryAllocation&& o) noexcept {
|
134 |
+
OrtAllocator* alloc = nullptr;
|
135 |
+
void* p = nullptr;
|
136 |
+
size_t sz = 0;
|
137 |
+
|
138 |
+
// Swap out this
|
139 |
+
std::swap(alloc, allocator_);
|
140 |
+
std::swap(p, p_);
|
141 |
+
std::swap(sz, size_);
|
142 |
+
|
143 |
+
// Swap with incoming
|
144 |
+
std::swap(allocator_, o.allocator_);
|
145 |
+
std::swap(p_, o.p_);
|
146 |
+
std::swap(size_, o.size_);
|
147 |
+
|
148 |
+
// Destroy this instance if needed
|
149 |
+
MemoryAllocation this_alloc(alloc, p, sz);
|
150 |
+
return *this;
|
151 |
+
}
|
152 |
+
|
153 |
+
namespace detail {
|
154 |
+
|
155 |
+
template <typename T>
|
156 |
+
inline void* AllocatorImpl<T>::Alloc(size_t size) {
|
157 |
+
void* out;
|
158 |
+
ThrowOnError(GetApi().AllocatorAlloc(this->p_, size, &out));
|
159 |
+
return out;
|
160 |
+
}
|
161 |
+
|
162 |
+
template <typename T>
|
163 |
+
inline MemoryAllocation AllocatorImpl<T>::GetAllocation(size_t size) {
|
164 |
+
void* out;
|
165 |
+
ThrowOnError(GetApi().AllocatorAlloc(this->p_, size, &out));
|
166 |
+
MemoryAllocation result(this->p_, out, size);
|
167 |
+
return result;
|
168 |
+
}
|
169 |
+
|
170 |
+
template <typename T>
|
171 |
+
inline void AllocatorImpl<T>::Free(void* p) {
|
172 |
+
ThrowOnError(GetApi().AllocatorFree(this->p_, p));
|
173 |
+
}
|
174 |
+
|
175 |
+
template <typename T>
|
176 |
+
inline ConstMemoryInfo AllocatorImpl<T>::GetInfo() const {
|
177 |
+
const OrtMemoryInfo* out;
|
178 |
+
ThrowOnError(GetApi().AllocatorGetInfo(this->p_, &out));
|
179 |
+
return ConstMemoryInfo{out};
|
180 |
+
}
|
181 |
+
|
182 |
+
} // namespace detail
|
183 |
+
|
184 |
+
inline AllocatorWithDefaultOptions::AllocatorWithDefaultOptions() {
|
185 |
+
ThrowOnError(GetApi().GetAllocatorWithDefaultOptions(&this->p_));
|
186 |
+
}
|
187 |
+
|
188 |
+
inline Allocator::Allocator(const Session& sess, const OrtMemoryInfo* mem_info) {
|
189 |
+
ThrowOnError(GetApi().CreateAllocator(sess, mem_info, &this->p_));
|
190 |
+
}
|
191 |
+
|
192 |
+
namespace detail {
|
193 |
+
|
194 |
+
template <typename T>
|
195 |
+
inline std::string MemoryInfoImpl<T>::GetAllocatorName() const {
|
196 |
+
const char* name = nullptr;
|
197 |
+
ThrowOnError(GetApi().MemoryInfoGetName(this->p_, &name));
|
198 |
+
return std::string(name);
|
199 |
+
}
|
200 |
+
|
201 |
+
template <typename T>
|
202 |
+
inline OrtAllocatorType MemoryInfoImpl<T>::GetAllocatorType() const {
|
203 |
+
OrtAllocatorType type;
|
204 |
+
ThrowOnError(GetApi().MemoryInfoGetType(this->p_, &type));
|
205 |
+
return type;
|
206 |
+
}
|
207 |
+
|
208 |
+
template <typename T>
|
209 |
+
inline int MemoryInfoImpl<T>::GetDeviceId() const {
|
210 |
+
int id = 0;
|
211 |
+
ThrowOnError(GetApi().MemoryInfoGetId(this->p_, &id));
|
212 |
+
return id;
|
213 |
+
}
|
214 |
+
|
215 |
+
template <typename T>
|
216 |
+
inline OrtMemoryInfoDeviceType MemoryInfoImpl<T>::GetDeviceType() const {
|
217 |
+
OrtMemoryInfoDeviceType type;
|
218 |
+
GetApi().MemoryInfoGetDeviceType(this->p_, &type);
|
219 |
+
return type;
|
220 |
+
}
|
221 |
+
|
222 |
+
template <typename T>
|
223 |
+
inline OrtMemType MemoryInfoImpl<T>::GetMemoryType() const {
|
224 |
+
OrtMemType type;
|
225 |
+
ThrowOnError(GetApi().MemoryInfoGetMemType(this->p_, &type));
|
226 |
+
return type;
|
227 |
+
}
|
228 |
+
|
229 |
+
template <typename T>
|
230 |
+
template <typename U>
|
231 |
+
inline bool MemoryInfoImpl<T>::operator==(const MemoryInfoImpl<U>& o) const {
|
232 |
+
int comp_result = 0;
|
233 |
+
ThrowOnError(Ort::GetApi().CompareMemoryInfo(this->p_, o, &comp_result));
|
234 |
+
return comp_result == 0;
|
235 |
+
}
|
236 |
+
|
237 |
+
} // namespace detail
|
238 |
+
|
239 |
+
inline MemoryInfo MemoryInfo::CreateCpu(OrtAllocatorType type, OrtMemType mem_type) {
|
240 |
+
OrtMemoryInfo* p;
|
241 |
+
ThrowOnError(GetApi().CreateCpuMemoryInfo(type, mem_type, &p));
|
242 |
+
return MemoryInfo(p);
|
243 |
+
}
|
244 |
+
|
245 |
+
inline MemoryInfo::MemoryInfo(const char* name, OrtAllocatorType type, int id, OrtMemType mem_type) {
|
246 |
+
ThrowOnError(GetApi().CreateMemoryInfo(name, type, id, mem_type, &this->p_));
|
247 |
+
}
|
248 |
+
|
249 |
+
namespace detail {
|
250 |
+
template <typename T>
|
251 |
+
inline std::vector<std::string> ConstIoBindingImpl<T>::GetOutputNames() const {
|
252 |
+
AllocatorWithDefaultOptions allocator;
|
253 |
+
return binding_utils::GetOutputNamesHelper(this->p_, allocator);
|
254 |
+
}
|
255 |
+
|
256 |
+
template <typename T>
|
257 |
+
inline std::vector<std::string> ConstIoBindingImpl<T>::GetOutputNames(OrtAllocator* allocator) const {
|
258 |
+
return binding_utils::GetOutputNamesHelper(this->p_, allocator);
|
259 |
+
}
|
260 |
+
|
261 |
+
template <typename T>
|
262 |
+
inline std::vector<Value> ConstIoBindingImpl<T>::GetOutputValues() const {
|
263 |
+
AllocatorWithDefaultOptions allocator;
|
264 |
+
return binding_utils::GetOutputValuesHelper(this->p_, allocator);
|
265 |
+
}
|
266 |
+
|
267 |
+
template <typename T>
|
268 |
+
inline std::vector<Value> ConstIoBindingImpl<T>::GetOutputValues(OrtAllocator* allocator) const {
|
269 |
+
return binding_utils::GetOutputValuesHelper(this->p_, allocator);
|
270 |
+
}
|
271 |
+
|
272 |
+
template <typename T>
|
273 |
+
inline void IoBindingImpl<T>::BindInput(const char* name, const Value& value) {
|
274 |
+
ThrowOnError(GetApi().BindInput(this->p_, name, value));
|
275 |
+
}
|
276 |
+
|
277 |
+
template <typename T>
|
278 |
+
inline void IoBindingImpl<T>::BindOutput(const char* name, const Value& value) {
|
279 |
+
ThrowOnError(GetApi().BindOutput(this->p_, name, value));
|
280 |
+
}
|
281 |
+
|
282 |
+
template <typename T>
|
283 |
+
inline void IoBindingImpl<T>::BindOutput(const char* name, const OrtMemoryInfo* mem_info) {
|
284 |
+
ThrowOnError(GetApi().BindOutputToDevice(this->p_, name, mem_info));
|
285 |
+
}
|
286 |
+
|
287 |
+
template <typename T>
|
288 |
+
inline void IoBindingImpl<T>::ClearBoundInputs() {
|
289 |
+
GetApi().ClearBoundInputs(this->p_);
|
290 |
+
}
|
291 |
+
|
292 |
+
template <typename T>
|
293 |
+
inline void IoBindingImpl<T>::ClearBoundOutputs() {
|
294 |
+
GetApi().ClearBoundOutputs(this->p_);
|
295 |
+
}
|
296 |
+
|
297 |
+
template <typename T>
|
298 |
+
inline void IoBindingImpl<T>::SynchronizeInputs() {
|
299 |
+
ThrowOnError(GetApi().SynchronizeBoundInputs(this->p_));
|
300 |
+
}
|
301 |
+
|
302 |
+
template <typename T>
|
303 |
+
inline void IoBindingImpl<T>::SynchronizeOutputs() {
|
304 |
+
ThrowOnError(GetApi().SynchronizeBoundOutputs(this->p_));
|
305 |
+
}
|
306 |
+
|
307 |
+
namespace binding_utils {
|
308 |
+
inline std::vector<std::string> GetOutputNamesHelper(const OrtIoBinding* binding, OrtAllocator* allocator) {
|
309 |
+
std::vector<std::string> result;
|
310 |
+
auto free_fn = detail::AllocatedFree(allocator);
|
311 |
+
using Ptr = std::unique_ptr<void, decltype(free_fn)>;
|
312 |
+
|
313 |
+
char* buffer = nullptr;
|
314 |
+
size_t* lengths = nullptr;
|
315 |
+
size_t count = 0;
|
316 |
+
ThrowOnError(GetApi().GetBoundOutputNames(binding, allocator, &buffer, &lengths, &count));
|
317 |
+
|
318 |
+
if (count == 0) {
|
319 |
+
return result;
|
320 |
+
}
|
321 |
+
|
322 |
+
Ptr buffer_g(buffer, free_fn);
|
323 |
+
Ptr lengths_g(lengths, free_fn);
|
324 |
+
|
325 |
+
result.reserve(count);
|
326 |
+
for (size_t i = 0; i < count; ++i) {
|
327 |
+
auto sz = *lengths;
|
328 |
+
result.emplace_back(buffer, sz);
|
329 |
+
buffer += sz;
|
330 |
+
++lengths;
|
331 |
+
}
|
332 |
+
return result;
|
333 |
+
}
|
334 |
+
|
335 |
+
inline std::vector<Value> GetOutputValuesHelper(const OrtIoBinding* binding, OrtAllocator* allocator) {
|
336 |
+
std::vector<Value> result;
|
337 |
+
size_t owned = 0;
|
338 |
+
size_t output_count = 0;
|
339 |
+
// Lambda to release the buffer when no longer needed and
|
340 |
+
// make sure that we destroy all instances on exception
|
341 |
+
auto free_fn = [&owned, &output_count, allocator](OrtValue** buffer) {
|
342 |
+
if (buffer) {
|
343 |
+
while (owned < output_count) {
|
344 |
+
auto* p = buffer + owned++;
|
345 |
+
GetApi().ReleaseValue(*p);
|
346 |
+
}
|
347 |
+
allocator->Free(allocator, buffer);
|
348 |
+
}
|
349 |
+
};
|
350 |
+
using Ptr = std::unique_ptr<OrtValue*, decltype(free_fn)>;
|
351 |
+
|
352 |
+
OrtValue** output_buffer = nullptr;
|
353 |
+
ThrowOnError(GetApi().GetBoundOutputValues(binding, allocator, &output_buffer, &output_count));
|
354 |
+
if (output_count == 0) {
|
355 |
+
return result;
|
356 |
+
}
|
357 |
+
|
358 |
+
Ptr buffer_g(output_buffer, free_fn);
|
359 |
+
|
360 |
+
result.reserve(output_count);
|
361 |
+
for (size_t i = 0; i < output_count; ++i) {
|
362 |
+
result.emplace_back(output_buffer[i]);
|
363 |
+
++owned;
|
364 |
+
}
|
365 |
+
return result;
|
366 |
+
}
|
367 |
+
|
368 |
+
} // namespace binding_utils
|
369 |
+
} // namespace detail
|
370 |
+
|
371 |
+
inline IoBinding::IoBinding(Session& session) {
|
372 |
+
ThrowOnError(GetApi().CreateIoBinding(session, &this->p_));
|
373 |
+
}
|
374 |
+
|
375 |
+
inline ArenaCfg::ArenaCfg(size_t max_mem, int arena_extend_strategy, int initial_chunk_size_bytes, int max_dead_bytes_per_chunk) {
|
376 |
+
ThrowOnError(GetApi().CreateArenaCfg(max_mem, arena_extend_strategy, initial_chunk_size_bytes, max_dead_bytes_per_chunk, &p_));
|
377 |
+
}
|
378 |
+
|
379 |
+
inline ThreadingOptions::ThreadingOptions() {
|
380 |
+
ThrowOnError(GetApi().CreateThreadingOptions(&p_));
|
381 |
+
}
|
382 |
+
|
383 |
+
inline ThreadingOptions& ThreadingOptions::SetGlobalIntraOpNumThreads(int intra_op_num_threads) {
|
384 |
+
ThrowOnError(GetApi().SetGlobalIntraOpNumThreads(p_, intra_op_num_threads));
|
385 |
+
return *this;
|
386 |
+
}
|
387 |
+
|
388 |
+
inline ThreadingOptions& ThreadingOptions::SetGlobalInterOpNumThreads(int inter_op_num_threads) {
|
389 |
+
ThrowOnError(GetApi().SetGlobalInterOpNumThreads(p_, inter_op_num_threads));
|
390 |
+
return *this;
|
391 |
+
}
|
392 |
+
|
393 |
+
inline ThreadingOptions& ThreadingOptions::SetGlobalSpinControl(int allow_spinning) {
|
394 |
+
ThrowOnError(GetApi().SetGlobalSpinControl(p_, allow_spinning));
|
395 |
+
return *this;
|
396 |
+
}
|
397 |
+
|
398 |
+
inline ThreadingOptions& ThreadingOptions::SetGlobalDenormalAsZero() {
|
399 |
+
ThrowOnError(GetApi().SetGlobalDenormalAsZero(p_));
|
400 |
+
return *this;
|
401 |
+
}
|
402 |
+
|
403 |
+
inline ThreadingOptions& ThreadingOptions::SetGlobalCustomCreateThreadFn(OrtCustomCreateThreadFn ort_custom_create_thread_fn) {
|
404 |
+
ThrowOnError(GetApi().SetGlobalCustomCreateThreadFn(p_, ort_custom_create_thread_fn));
|
405 |
+
return *this;
|
406 |
+
}
|
407 |
+
|
408 |
+
inline ThreadingOptions& ThreadingOptions::SetGlobalCustomThreadCreationOptions(void* ort_custom_thread_creation_options) {
|
409 |
+
ThrowOnError(GetApi().SetGlobalCustomThreadCreationOptions(p_, ort_custom_thread_creation_options));
|
410 |
+
return *this;
|
411 |
+
}
|
412 |
+
|
413 |
+
inline ThreadingOptions& ThreadingOptions::SetGlobalCustomJoinThreadFn(OrtCustomJoinThreadFn ort_custom_join_thread_fn) {
|
414 |
+
ThrowOnError(GetApi().SetGlobalCustomJoinThreadFn(p_, ort_custom_join_thread_fn));
|
415 |
+
return *this;
|
416 |
+
}
|
417 |
+
|
418 |
+
inline Env::Env(OrtLoggingLevel logging_level, _In_ const char* logid) {
|
419 |
+
ThrowOnError(GetApi().CreateEnv(logging_level, logid, &p_));
|
420 |
+
if (strcmp(logid, "onnxruntime-node") == 0) {
|
421 |
+
ThrowOnError(GetApi().SetLanguageProjection(p_, OrtLanguageProjection::ORT_PROJECTION_NODEJS));
|
422 |
+
} else {
|
423 |
+
ThrowOnError(GetApi().SetLanguageProjection(p_, OrtLanguageProjection::ORT_PROJECTION_CPLUSPLUS));
|
424 |
+
}
|
425 |
+
}
|
426 |
+
|
427 |
+
inline Env::Env(OrtLoggingLevel logging_level, const char* logid, OrtLoggingFunction logging_function, void* logger_param) {
|
428 |
+
ThrowOnError(GetApi().CreateEnvWithCustomLogger(logging_function, logger_param, logging_level, logid, &p_));
|
429 |
+
if (strcmp(logid, "onnxruntime-node") == 0) {
|
430 |
+
ThrowOnError(GetApi().SetLanguageProjection(p_, OrtLanguageProjection::ORT_PROJECTION_NODEJS));
|
431 |
+
} else {
|
432 |
+
ThrowOnError(GetApi().SetLanguageProjection(p_, OrtLanguageProjection::ORT_PROJECTION_CPLUSPLUS));
|
433 |
+
}
|
434 |
+
}
|
435 |
+
|
436 |
+
inline Env::Env(const OrtThreadingOptions* tp_options, OrtLoggingLevel logging_level, _In_ const char* logid) {
|
437 |
+
ThrowOnError(GetApi().CreateEnvWithGlobalThreadPools(logging_level, logid, tp_options, &p_));
|
438 |
+
if (strcmp(logid, "onnxruntime-node") == 0) {
|
439 |
+
ThrowOnError(GetApi().SetLanguageProjection(p_, OrtLanguageProjection::ORT_PROJECTION_NODEJS));
|
440 |
+
} else {
|
441 |
+
ThrowOnError(GetApi().SetLanguageProjection(p_, OrtLanguageProjection::ORT_PROJECTION_CPLUSPLUS));
|
442 |
+
}
|
443 |
+
}
|
444 |
+
|
445 |
+
inline Env::Env(const OrtThreadingOptions* tp_options, OrtLoggingFunction logging_function, void* logger_param,
|
446 |
+
OrtLoggingLevel logging_level, _In_ const char* logid) {
|
447 |
+
ThrowOnError(GetApi().CreateEnvWithCustomLoggerAndGlobalThreadPools(logging_function, logger_param, logging_level, logid, tp_options, &p_));
|
448 |
+
if (strcmp(logid, "onnxruntime-node") == 0) {
|
449 |
+
ThrowOnError(GetApi().SetLanguageProjection(p_, OrtLanguageProjection::ORT_PROJECTION_NODEJS));
|
450 |
+
} else {
|
451 |
+
ThrowOnError(GetApi().SetLanguageProjection(p_, OrtLanguageProjection::ORT_PROJECTION_CPLUSPLUS));
|
452 |
+
}
|
453 |
+
}
|
454 |
+
|
455 |
+
inline Env& Env::EnableTelemetryEvents() {
|
456 |
+
ThrowOnError(GetApi().EnableTelemetryEvents(p_));
|
457 |
+
return *this;
|
458 |
+
}
|
459 |
+
|
460 |
+
inline Env& Env::DisableTelemetryEvents() {
|
461 |
+
ThrowOnError(GetApi().DisableTelemetryEvents(p_));
|
462 |
+
return *this;
|
463 |
+
}
|
464 |
+
|
465 |
+
inline Env& Env::UpdateEnvWithCustomLogLevel(OrtLoggingLevel log_severity_level) {
|
466 |
+
ThrowOnError(GetApi().UpdateEnvWithCustomLogLevel(p_, log_severity_level));
|
467 |
+
return *this;
|
468 |
+
}
|
469 |
+
|
470 |
+
inline Env& Env::CreateAndRegisterAllocator(const OrtMemoryInfo* mem_info, const OrtArenaCfg* arena_cfg) {
|
471 |
+
ThrowOnError(GetApi().CreateAndRegisterAllocator(p_, mem_info, arena_cfg));
|
472 |
+
return *this;
|
473 |
+
}
|
474 |
+
|
475 |
+
inline CustomOpDomain::CustomOpDomain(const char* domain) {
|
476 |
+
ThrowOnError(GetApi().CreateCustomOpDomain(domain, &p_));
|
477 |
+
}
|
478 |
+
|
479 |
+
inline void CustomOpDomain::Add(const OrtCustomOp* op) {
|
480 |
+
ThrowOnError(GetApi().CustomOpDomain_Add(p_, op));
|
481 |
+
}
|
482 |
+
|
483 |
+
inline RunOptions::RunOptions() {
|
484 |
+
ThrowOnError(GetApi().CreateRunOptions(&p_));
|
485 |
+
}
|
486 |
+
|
487 |
+
inline RunOptions& RunOptions::SetRunLogVerbosityLevel(int level) {
|
488 |
+
ThrowOnError(GetApi().RunOptionsSetRunLogVerbosityLevel(p_, level));
|
489 |
+
return *this;
|
490 |
+
}
|
491 |
+
|
492 |
+
inline RunOptions& RunOptions::SetRunLogSeverityLevel(int level) {
|
493 |
+
ThrowOnError(GetApi().RunOptionsSetRunLogSeverityLevel(p_, level));
|
494 |
+
return *this;
|
495 |
+
}
|
496 |
+
|
497 |
+
inline int RunOptions::GetRunLogVerbosityLevel() const {
|
498 |
+
int out;
|
499 |
+
ThrowOnError(GetApi().RunOptionsGetRunLogVerbosityLevel(p_, &out));
|
500 |
+
return out;
|
501 |
+
}
|
502 |
+
|
503 |
+
inline int RunOptions::GetRunLogSeverityLevel() const {
|
504 |
+
int out;
|
505 |
+
ThrowOnError(GetApi().RunOptionsGetRunLogSeverityLevel(p_, &out));
|
506 |
+
return out;
|
507 |
+
}
|
508 |
+
|
509 |
+
inline RunOptions& RunOptions::SetRunTag(const char* run_tag) {
|
510 |
+
ThrowOnError(GetApi().RunOptionsSetRunTag(p_, run_tag));
|
511 |
+
return *this;
|
512 |
+
}
|
513 |
+
|
514 |
+
inline const char* RunOptions::GetRunTag() const {
|
515 |
+
const char* out;
|
516 |
+
ThrowOnError(GetApi().RunOptionsGetRunTag(p_, &out));
|
517 |
+
return out;
|
518 |
+
}
|
519 |
+
|
520 |
+
inline RunOptions& RunOptions::AddConfigEntry(const char* config_key, const char* config_value) {
|
521 |
+
ThrowOnError(GetApi().AddRunConfigEntry(p_, config_key, config_value));
|
522 |
+
return *this;
|
523 |
+
}
|
524 |
+
|
525 |
+
inline RunOptions& RunOptions::SetTerminate() {
|
526 |
+
ThrowOnError(GetApi().RunOptionsSetTerminate(p_));
|
527 |
+
return *this;
|
528 |
+
}
|
529 |
+
|
530 |
+
inline RunOptions& RunOptions::UnsetTerminate() {
|
531 |
+
ThrowOnError(GetApi().RunOptionsUnsetTerminate(p_));
|
532 |
+
return *this;
|
533 |
+
}
|
534 |
+
|
535 |
+
namespace detail {
|
536 |
+
|
537 |
+
template <typename T>
|
538 |
+
inline Ort::SessionOptions ConstSessionOptionsImpl<T>::Clone() const {
|
539 |
+
OrtSessionOptions* out;
|
540 |
+
ThrowOnError(GetApi().CloneSessionOptions(this->p_, &out));
|
541 |
+
return SessionOptions{out};
|
542 |
+
}
|
543 |
+
|
544 |
+
template <typename T>
|
545 |
+
inline std::string ConstSessionOptionsImpl<T>::GetConfigEntry(const char* config_key) const {
|
546 |
+
size_t size = 0;
|
547 |
+
// Feed nullptr for the data buffer to query the true size of the string value
|
548 |
+
Ort::ThrowOnError(GetApi().GetSessionConfigEntry(this->p_, config_key, nullptr, &size));
|
549 |
+
|
550 |
+
std::string out;
|
551 |
+
out.resize(size);
|
552 |
+
Ort::ThrowOnError(GetApi().GetSessionConfigEntry(this->p_, config_key, &out[0], &size));
|
553 |
+
out.resize(size - 1); // remove the terminating character '\0'
|
554 |
+
|
555 |
+
return out;
|
556 |
+
}
|
557 |
+
|
558 |
+
template <typename T>
|
559 |
+
inline bool ConstSessionOptionsImpl<T>::HasConfigEntry(const char* config_key) const {
|
560 |
+
int out = 0;
|
561 |
+
Ort::ThrowOnError(GetApi().HasSessionConfigEntry(this->p_, config_key, &out));
|
562 |
+
return static_cast<bool>(out);
|
563 |
+
}
|
564 |
+
|
565 |
+
template <typename T>
|
566 |
+
inline std::string ConstSessionOptionsImpl<T>::GetConfigEntryOrDefault(const char* config_key, const std::string& def) {
|
567 |
+
if (!this->HasConfigEntry(config_key)) {
|
568 |
+
return def;
|
569 |
+
}
|
570 |
+
|
571 |
+
return this->GetConfigEntry(config_key);
|
572 |
+
}
|
573 |
+
|
574 |
+
template <typename T>
|
575 |
+
inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::SetIntraOpNumThreads(int intra_op_num_threads) {
|
576 |
+
ThrowOnError(GetApi().SetIntraOpNumThreads(this->p_, intra_op_num_threads));
|
577 |
+
return *this;
|
578 |
+
}
|
579 |
+
|
580 |
+
template <typename T>
|
581 |
+
inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::SetInterOpNumThreads(int inter_op_num_threads) {
|
582 |
+
ThrowOnError(GetApi().SetInterOpNumThreads(this->p_, inter_op_num_threads));
|
583 |
+
return *this;
|
584 |
+
}
|
585 |
+
|
586 |
+
template <typename T>
|
587 |
+
inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::SetGraphOptimizationLevel(GraphOptimizationLevel graph_optimization_level) {
|
588 |
+
ThrowOnError(GetApi().SetSessionGraphOptimizationLevel(this->p_, graph_optimization_level));
|
589 |
+
return *this;
|
590 |
+
}
|
591 |
+
|
592 |
+
template <typename T>
|
593 |
+
inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::SetOptimizedModelFilePath(const ORTCHAR_T* optimized_model_filepath) {
|
594 |
+
ThrowOnError(GetApi().SetOptimizedModelFilePath(this->p_, optimized_model_filepath));
|
595 |
+
return *this;
|
596 |
+
}
|
597 |
+
|
598 |
+
template <typename T>
|
599 |
+
inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::EnableProfiling(const ORTCHAR_T* profile_file_prefix) {
|
600 |
+
ThrowOnError(GetApi().EnableProfiling(this->p_, profile_file_prefix));
|
601 |
+
return *this;
|
602 |
+
}
|
603 |
+
|
604 |
+
template <typename T>
|
605 |
+
inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::DisableProfiling() {
|
606 |
+
ThrowOnError(GetApi().DisableProfiling(this->p_));
|
607 |
+
return *this;
|
608 |
+
}
|
609 |
+
|
610 |
+
template <typename T>
|
611 |
+
inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::EnableOrtCustomOps() {
|
612 |
+
ThrowOnError(GetApi().EnableOrtCustomOps(this->p_));
|
613 |
+
return *this;
|
614 |
+
}
|
615 |
+
|
616 |
+
template <typename T>
|
617 |
+
inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::EnableMemPattern() {
|
618 |
+
ThrowOnError(GetApi().EnableMemPattern(this->p_));
|
619 |
+
return *this;
|
620 |
+
}
|
621 |
+
|
622 |
+
template <typename T>
|
623 |
+
inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::DisableMemPattern() {
|
624 |
+
ThrowOnError(GetApi().DisableMemPattern(this->p_));
|
625 |
+
return *this;
|
626 |
+
}
|
627 |
+
|
628 |
+
template <typename T>
|
629 |
+
inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::EnableCpuMemArena() {
|
630 |
+
ThrowOnError(GetApi().EnableCpuMemArena(this->p_));
|
631 |
+
return *this;
|
632 |
+
}
|
633 |
+
|
634 |
+
template <typename T>
|
635 |
+
inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::DisableCpuMemArena() {
|
636 |
+
ThrowOnError(GetApi().DisableCpuMemArena(this->p_));
|
637 |
+
return *this;
|
638 |
+
}
|
639 |
+
|
640 |
+
template <typename T>
|
641 |
+
inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::SetExecutionMode(ExecutionMode execution_mode) {
|
642 |
+
ThrowOnError(GetApi().SetSessionExecutionMode(this->p_, execution_mode));
|
643 |
+
return *this;
|
644 |
+
}
|
645 |
+
|
646 |
+
template <typename T>
|
647 |
+
inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::SetLogId(const char* logid) {
|
648 |
+
ThrowOnError(GetApi().SetSessionLogId(this->p_, logid));
|
649 |
+
return *this;
|
650 |
+
}
|
651 |
+
|
652 |
+
template <typename T>
|
653 |
+
inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::SetLogSeverityLevel(int level) {
|
654 |
+
ThrowOnError(GetApi().SetSessionLogSeverityLevel(this->p_, level));
|
655 |
+
return *this;
|
656 |
+
}
|
657 |
+
|
658 |
+
template <typename T>
|
659 |
+
inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::Add(OrtCustomOpDomain* custom_op_domain) {
|
660 |
+
ThrowOnError(GetApi().AddCustomOpDomain(this->p_, custom_op_domain));
|
661 |
+
return *this;
|
662 |
+
}
|
663 |
+
|
664 |
+
template <typename T>
|
665 |
+
inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::AddConfigEntry(const char* config_key, const char* config_value) {
|
666 |
+
ThrowOnError(GetApi().AddSessionConfigEntry(this->p_, config_key, config_value));
|
667 |
+
return *this;
|
668 |
+
}
|
669 |
+
|
670 |
+
template <typename T>
|
671 |
+
inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::AddInitializer(const char* name, const OrtValue* ort_val) {
|
672 |
+
ThrowOnError(GetApi().AddInitializer(this->p_, name, ort_val));
|
673 |
+
return *this;
|
674 |
+
}
|
675 |
+
|
676 |
+
template <typename T>
|
677 |
+
inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::DisablePerSessionThreads() {
|
678 |
+
ThrowOnError(GetApi().DisablePerSessionThreads(this->p_));
|
679 |
+
return *this;
|
680 |
+
}
|
681 |
+
|
682 |
+
template <typename T>
|
683 |
+
inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::AddExternalInitializers(const std::vector<std::string>& names,
|
684 |
+
const std::vector<Value>& ort_values) {
|
685 |
+
const size_t inputs_num = names.size();
|
686 |
+
if (inputs_num != ort_values.size()) {
|
687 |
+
ORT_CXX_API_THROW("Expecting names and ort_values to have the same length", ORT_INVALID_ARGUMENT);
|
688 |
+
}
|
689 |
+
std::vector<const char*> names_ptr;
|
690 |
+
std::vector<const OrtValue*> ort_values_ptrs;
|
691 |
+
names_ptr.reserve(inputs_num);
|
692 |
+
ort_values_ptrs.reserve(inputs_num);
|
693 |
+
for (size_t i = 0; i < inputs_num; ++i) {
|
694 |
+
names_ptr.push_back(names[i].c_str());
|
695 |
+
ort_values_ptrs.push_back(ort_values[i]);
|
696 |
+
}
|
697 |
+
ThrowOnError(GetApi().AddExternalInitializers(this->p_, names_ptr.data(), ort_values_ptrs.data(), inputs_num));
|
698 |
+
return *this;
|
699 |
+
}
|
700 |
+
|
701 |
+
template <typename T>
|
702 |
+
inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::AppendExecutionProvider_CUDA(const OrtCUDAProviderOptions& provider_options) {
|
703 |
+
ThrowOnError(GetApi().SessionOptionsAppendExecutionProvider_CUDA(this->p_, &provider_options));
|
704 |
+
return *this;
|
705 |
+
}
|
706 |
+
|
707 |
+
template <typename T>
|
708 |
+
inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::AppendExecutionProvider_CUDA_V2(const OrtCUDAProviderOptionsV2& provider_options) {
|
709 |
+
ThrowOnError(GetApi().SessionOptionsAppendExecutionProvider_CUDA_V2(this->p_, &provider_options));
|
710 |
+
return *this;
|
711 |
+
}
|
712 |
+
|
713 |
+
template <typename T>
|
714 |
+
inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::AppendExecutionProvider_ROCM(const OrtROCMProviderOptions& provider_options) {
|
715 |
+
ThrowOnError(GetApi().SessionOptionsAppendExecutionProvider_ROCM(this->p_, &provider_options));
|
716 |
+
return *this;
|
717 |
+
}
|
718 |
+
|
719 |
+
template <typename T>
|
720 |
+
inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::AppendExecutionProvider_TensorRT(const OrtTensorRTProviderOptions& provider_options) {
|
721 |
+
ThrowOnError(GetApi().SessionOptionsAppendExecutionProvider_TensorRT(this->p_, &provider_options));
|
722 |
+
return *this;
|
723 |
+
}
|
724 |
+
|
725 |
+
template <typename T>
|
726 |
+
inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::AppendExecutionProvider_TensorRT_V2(const OrtTensorRTProviderOptionsV2& provider_options) {
|
727 |
+
ThrowOnError(GetApi().SessionOptionsAppendExecutionProvider_TensorRT_V2(this->p_, &provider_options));
|
728 |
+
return *this;
|
729 |
+
}
|
730 |
+
|
731 |
+
template <typename T>
|
732 |
+
inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::AppendExecutionProvider_MIGraphX(const OrtMIGraphXProviderOptions& provider_options) {
|
733 |
+
ThrowOnError(GetApi().SessionOptionsAppendExecutionProvider_MIGraphX(this->p_, &provider_options));
|
734 |
+
return *this;
|
735 |
+
}
|
736 |
+
|
737 |
+
template <typename T>
|
738 |
+
inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::AppendExecutionProvider_CANN(const OrtCANNProviderOptions& provider_options) {
|
739 |
+
ThrowOnError(GetApi().SessionOptionsAppendExecutionProvider_CANN(this->p_, &provider_options));
|
740 |
+
return *this;
|
741 |
+
}
|
742 |
+
|
743 |
+
template <typename T>
|
744 |
+
inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::AppendExecutionProvider_Dnnl(const OrtDnnlProviderOptions& provider_options) {
|
745 |
+
ThrowOnError(GetApi().SessionOptionsAppendExecutionProvider_Dnnl(this->p_, &provider_options));
|
746 |
+
return *this;
|
747 |
+
}
|
748 |
+
|
749 |
+
template <typename T>
|
750 |
+
inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::AppendExecutionProvider(
|
751 |
+
const std::string& provider_name,
|
752 |
+
const std::unordered_map<std::string, std::string>& provider_options) {
|
753 |
+
auto num_entries = provider_options.size();
|
754 |
+
std::vector<const char*> keys, values;
|
755 |
+
if (num_entries > 0) {
|
756 |
+
keys.reserve(num_entries);
|
757 |
+
values.reserve(num_entries);
|
758 |
+
|
759 |
+
for (const auto& entry : provider_options) {
|
760 |
+
keys.push_back(entry.first.c_str());
|
761 |
+
values.push_back(entry.second.c_str());
|
762 |
+
}
|
763 |
+
}
|
764 |
+
|
765 |
+
ThrowOnError(GetApi().SessionOptionsAppendExecutionProvider(this->p_, provider_name.c_str(),
|
766 |
+
keys.data(), values.data(), num_entries));
|
767 |
+
|
768 |
+
return *this;
|
769 |
+
}
|
770 |
+
|
771 |
+
template <typename T>
|
772 |
+
inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::SetCustomCreateThreadFn(OrtCustomCreateThreadFn ort_custom_create_thread_fn) {
|
773 |
+
ThrowOnError(GetApi().SessionOptionsSetCustomCreateThreadFn(this->p_, ort_custom_create_thread_fn));
|
774 |
+
return *this;
|
775 |
+
}
|
776 |
+
|
777 |
+
template <typename T>
|
778 |
+
inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::SetCustomThreadCreationOptions(void* ort_custom_thread_creation_options) {
|
779 |
+
ThrowOnError(GetApi().SessionOptionsSetCustomThreadCreationOptions(this->p_, ort_custom_thread_creation_options));
|
780 |
+
return *this;
|
781 |
+
}
|
782 |
+
|
783 |
+
template <typename T>
|
784 |
+
inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::SetCustomJoinThreadFn(OrtCustomJoinThreadFn ort_custom_join_thread_fn) {
|
785 |
+
ThrowOnError(GetApi().SessionOptionsSetCustomJoinThreadFn(this->p_, ort_custom_join_thread_fn));
|
786 |
+
return *this;
|
787 |
+
}
|
788 |
+
|
789 |
+
template <typename T>
|
790 |
+
inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::AppendExecutionProvider_OpenVINO(const OrtOpenVINOProviderOptions& provider_options) {
|
791 |
+
ThrowOnError(GetApi().SessionOptionsAppendExecutionProvider_OpenVINO(this->p_, &provider_options));
|
792 |
+
return *this;
|
793 |
+
}
|
794 |
+
|
795 |
+
template <typename T>
|
796 |
+
inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::RegisterCustomOpsLibrary(const ORTCHAR_T* library_name,
|
797 |
+
const CustomOpConfigs& custom_op_configs) {
|
798 |
+
// Add custom op config entries before registering the custom op library. Otherwise, the config entries _may_ be ignored by
|
799 |
+
// the custom op library.
|
800 |
+
for (const auto& config_iter : custom_op_configs.GetFlattenedConfigs()) {
|
801 |
+
AddConfigEntry(config_iter.first.c_str(), config_iter.second.c_str());
|
802 |
+
}
|
803 |
+
|
804 |
+
ThrowOnError(GetApi().RegisterCustomOpsLibrary_V2(this->p_, library_name));
|
805 |
+
return *this;
|
806 |
+
}
|
807 |
+
|
808 |
+
template <typename T>
|
809 |
+
inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::RegisterCustomOpsUsingFunction(const char* registration_function_name) {
|
810 |
+
ThrowOnError(GetApi().RegisterCustomOpsUsingFunction(this->p_, registration_function_name));
|
811 |
+
return *this;
|
812 |
+
}
|
813 |
+
|
814 |
+
/// Session
|
815 |
+
template <typename T>
|
816 |
+
inline size_t ConstSessionImpl<T>::GetInputCount() const {
|
817 |
+
size_t out;
|
818 |
+
ThrowOnError(GetApi().SessionGetInputCount(this->p_, &out));
|
819 |
+
return out;
|
820 |
+
}
|
821 |
+
|
822 |
+
template <typename T>
|
823 |
+
inline size_t ConstSessionImpl<T>::GetOutputCount() const {
|
824 |
+
size_t out;
|
825 |
+
ThrowOnError(GetApi().SessionGetOutputCount(this->p_, &out));
|
826 |
+
return out;
|
827 |
+
}
|
828 |
+
|
829 |
+
template <typename T>
|
830 |
+
inline size_t ConstSessionImpl<T>::GetOverridableInitializerCount() const {
|
831 |
+
size_t out;
|
832 |
+
ThrowOnError(GetApi().SessionGetOverridableInitializerCount(this->p_, &out));
|
833 |
+
return out;
|
834 |
+
}
|
835 |
+
|
836 |
+
template <typename T>
|
837 |
+
inline AllocatedStringPtr ConstSessionImpl<T>::GetInputNameAllocated(size_t index, OrtAllocator* allocator) const {
|
838 |
+
char* out;
|
839 |
+
ThrowOnError(GetApi().SessionGetInputName(this->p_, index, allocator, &out));
|
840 |
+
return AllocatedStringPtr(out, detail::AllocatedFree(allocator));
|
841 |
+
}
|
842 |
+
|
843 |
+
template <typename T>
|
844 |
+
inline AllocatedStringPtr ConstSessionImpl<T>::GetOutputNameAllocated(size_t index, OrtAllocator* allocator) const {
|
845 |
+
char* out;
|
846 |
+
ThrowOnError(GetApi().SessionGetOutputName(this->p_, index, allocator, &out));
|
847 |
+
return AllocatedStringPtr(out, detail::AllocatedFree(allocator));
|
848 |
+
}
|
849 |
+
|
850 |
+
template <typename T>
|
851 |
+
inline AllocatedStringPtr ConstSessionImpl<T>::GetOverridableInitializerNameAllocated(size_t index, OrtAllocator* allocator) const {
|
852 |
+
char* out;
|
853 |
+
ThrowOnError(GetApi().SessionGetOverridableInitializerName(this->p_, index, allocator, &out));
|
854 |
+
return AllocatedStringPtr(out, detail::AllocatedFree(allocator));
|
855 |
+
}
|
856 |
+
|
857 |
+
template <typename T>
|
858 |
+
inline uint64_t ConstSessionImpl<T>::GetProfilingStartTimeNs() const {
|
859 |
+
uint64_t out;
|
860 |
+
ThrowOnError(GetApi().SessionGetProfilingStartTimeNs(this->p_, &out));
|
861 |
+
return out;
|
862 |
+
}
|
863 |
+
|
864 |
+
template <typename T>
|
865 |
+
inline ModelMetadata ConstSessionImpl<T>::GetModelMetadata() const {
|
866 |
+
OrtModelMetadata* out;
|
867 |
+
ThrowOnError(GetApi().SessionGetModelMetadata(this->p_, &out));
|
868 |
+
return ModelMetadata{out};
|
869 |
+
}
|
870 |
+
|
871 |
+
template <typename T>
|
872 |
+
inline TypeInfo ConstSessionImpl<T>::GetInputTypeInfo(size_t index) const {
|
873 |
+
OrtTypeInfo* out;
|
874 |
+
ThrowOnError(GetApi().SessionGetInputTypeInfo(this->p_, index, &out));
|
875 |
+
return TypeInfo{out};
|
876 |
+
}
|
877 |
+
|
878 |
+
template <typename T>
|
879 |
+
inline TypeInfo ConstSessionImpl<T>::GetOutputTypeInfo(size_t index) const {
|
880 |
+
OrtTypeInfo* out;
|
881 |
+
ThrowOnError(GetApi().SessionGetOutputTypeInfo(this->p_, index, &out));
|
882 |
+
return TypeInfo{out};
|
883 |
+
}
|
884 |
+
|
885 |
+
template <typename T>
|
886 |
+
inline TypeInfo ConstSessionImpl<T>::GetOverridableInitializerTypeInfo(size_t index) const {
|
887 |
+
OrtTypeInfo* out;
|
888 |
+
ThrowOnError(GetApi().SessionGetOverridableInitializerTypeInfo(this->p_, index, &out));
|
889 |
+
return TypeInfo{out};
|
890 |
+
}
|
891 |
+
|
892 |
+
template <typename T>
|
893 |
+
inline std::vector<Value> SessionImpl<T>::Run(const RunOptions& run_options, const char* const* input_names, const Value* input_values, size_t input_count,
|
894 |
+
const char* const* output_names, size_t output_count) {
|
895 |
+
std::vector<Value> output_values;
|
896 |
+
output_values.reserve(output_count);
|
897 |
+
for (size_t i = 0; i < output_count; i++)
|
898 |
+
output_values.emplace_back(nullptr);
|
899 |
+
Run(run_options, input_names, input_values, input_count, output_names, output_values.data(), output_count);
|
900 |
+
return output_values;
|
901 |
+
}
|
902 |
+
|
903 |
+
template <typename T>
|
904 |
+
inline void SessionImpl<T>::Run(const RunOptions& run_options, const char* const* input_names, const Value* input_values, size_t input_count,
|
905 |
+
const char* const* output_names, Value* output_values, size_t output_count) {
|
906 |
+
static_assert(sizeof(Value) == sizeof(OrtValue*), "Value is really just an array of OrtValue* in memory, so we can reinterpret_cast safely");
|
907 |
+
auto ort_input_values = reinterpret_cast<const OrtValue* const*>(input_values);
|
908 |
+
auto ort_output_values = reinterpret_cast<OrtValue**>(output_values);
|
909 |
+
ThrowOnError(GetApi().Run(this->p_, run_options, input_names, ort_input_values, input_count, output_names, output_count, ort_output_values));
|
910 |
+
}
|
911 |
+
|
912 |
+
template <typename T>
|
913 |
+
inline void SessionImpl<T>::Run(const RunOptions& run_options, const IoBinding& io_binding) {
|
914 |
+
ThrowOnError(GetApi().RunWithBinding(this->p_, run_options, io_binding));
|
915 |
+
}
|
916 |
+
|
917 |
+
template <typename T>
|
918 |
+
inline AllocatedStringPtr SessionImpl<T>::EndProfilingAllocated(OrtAllocator* allocator) {
|
919 |
+
char* out = nullptr;
|
920 |
+
ThrowOnError(GetApi().SessionEndProfiling(this->p_, allocator, &out));
|
921 |
+
return AllocatedStringPtr(out, detail::AllocatedFree(allocator));
|
922 |
+
}
|
923 |
+
|
924 |
+
} // namespace detail
|
925 |
+
|
926 |
+
inline SessionOptions::SessionOptions() {
|
927 |
+
ThrowOnError(GetApi().CreateSessionOptions(&this->p_));
|
928 |
+
}
|
929 |
+
|
930 |
+
/// CustomOpConfigs
|
931 |
+
inline std::string detail::MakeCustomOpConfigEntryKey(const char* custom_op_name, const char* config) {
|
932 |
+
std::string config_key = "custom_op.";
|
933 |
+
|
934 |
+
config_key += custom_op_name;
|
935 |
+
config_key += ".";
|
936 |
+
config_key += config;
|
937 |
+
|
938 |
+
return config_key;
|
939 |
+
}
|
940 |
+
|
941 |
+
inline CustomOpConfigs& CustomOpConfigs::AddConfig(const char* custom_op_name, const char* config_key, const char* config_value) {
|
942 |
+
const std::string full_flat_key = detail::MakeCustomOpConfigEntryKey(custom_op_name, config_key);
|
943 |
+
flat_configs_[full_flat_key] = config_value;
|
944 |
+
return *this;
|
945 |
+
}
|
946 |
+
|
947 |
+
inline const std::unordered_map<std::string, std::string>& CustomOpConfigs::GetFlattenedConfigs() const {
|
948 |
+
return flat_configs_;
|
949 |
+
}
|
950 |
+
|
951 |
+
inline Session::Session(const Env& env, const ORTCHAR_T* model_path, const SessionOptions& options) {
|
952 |
+
ThrowOnError(GetApi().CreateSession(env, model_path, options, &this->p_));
|
953 |
+
}
|
954 |
+
|
955 |
+
inline Session::Session(const Env& env, const ORTCHAR_T* model_path, const SessionOptions& options,
|
956 |
+
OrtPrepackedWeightsContainer* prepacked_weights_container) {
|
957 |
+
ThrowOnError(GetApi().CreateSessionWithPrepackedWeightsContainer(env, model_path, options, prepacked_weights_container, &this->p_));
|
958 |
+
}
|
959 |
+
|
960 |
+
inline Session::Session(const Env& env, const void* model_data, size_t model_data_length, const SessionOptions& options) {
|
961 |
+
ThrowOnError(GetApi().CreateSessionFromArray(env, model_data, model_data_length, options, &this->p_));
|
962 |
+
}
|
963 |
+
|
964 |
+
inline Session::Session(const Env& env, const void* model_data, size_t model_data_length,
|
965 |
+
const SessionOptions& options, OrtPrepackedWeightsContainer* prepacked_weights_container) {
|
966 |
+
ThrowOnError(GetApi().CreateSessionFromArrayWithPrepackedWeightsContainer(env, model_data, model_data_length, options,
|
967 |
+
prepacked_weights_container, &this->p_));
|
968 |
+
}
|
969 |
+
|
970 |
+
inline AllocatedStringPtr ModelMetadata::GetProducerNameAllocated(OrtAllocator* allocator) const {
|
971 |
+
char* out;
|
972 |
+
ThrowOnError(GetApi().ModelMetadataGetProducerName(p_, allocator, &out));
|
973 |
+
return AllocatedStringPtr(out, detail::AllocatedFree(allocator));
|
974 |
+
}
|
975 |
+
|
976 |
+
inline AllocatedStringPtr ModelMetadata::GetGraphNameAllocated(OrtAllocator* allocator) const {
|
977 |
+
char* out;
|
978 |
+
ThrowOnError(GetApi().ModelMetadataGetGraphName(p_, allocator, &out));
|
979 |
+
return AllocatedStringPtr(out, detail::AllocatedFree(allocator));
|
980 |
+
}
|
981 |
+
|
982 |
+
inline AllocatedStringPtr ModelMetadata::GetDomainAllocated(OrtAllocator* allocator) const {
|
983 |
+
char* out;
|
984 |
+
ThrowOnError(GetApi().ModelMetadataGetDomain(p_, allocator, &out));
|
985 |
+
return AllocatedStringPtr(out, detail::AllocatedFree(allocator));
|
986 |
+
}
|
987 |
+
|
988 |
+
inline AllocatedStringPtr Ort::ModelMetadata::GetDescriptionAllocated(OrtAllocator* allocator) const {
|
989 |
+
char* out;
|
990 |
+
ThrowOnError(GetApi().ModelMetadataGetDescription(p_, allocator, &out));
|
991 |
+
return AllocatedStringPtr(out, detail::AllocatedFree(allocator));
|
992 |
+
}
|
993 |
+
|
994 |
+
inline AllocatedStringPtr ModelMetadata::GetGraphDescriptionAllocated(OrtAllocator* allocator) const {
|
995 |
+
char* out;
|
996 |
+
ThrowOnError(GetApi().ModelMetadataGetGraphDescription(p_, allocator, &out));
|
997 |
+
return AllocatedStringPtr(out, detail::AllocatedFree(allocator));
|
998 |
+
}
|
999 |
+
|
1000 |
+
inline AllocatedStringPtr ModelMetadata::LookupCustomMetadataMapAllocated(const char* key, OrtAllocator* allocator) const {
|
1001 |
+
char* out;
|
1002 |
+
ThrowOnError(GetApi().ModelMetadataLookupCustomMetadataMap(p_, allocator, key, &out));
|
1003 |
+
return AllocatedStringPtr(out, detail::AllocatedFree(allocator));
|
1004 |
+
}
|
1005 |
+
|
1006 |
+
inline std::vector<AllocatedStringPtr> ModelMetadata::GetCustomMetadataMapKeysAllocated(OrtAllocator* allocator) const {
|
1007 |
+
auto deletor = detail::AllocatedFree(allocator);
|
1008 |
+
std::vector<AllocatedStringPtr> result;
|
1009 |
+
|
1010 |
+
char** out = nullptr;
|
1011 |
+
int64_t num_keys = 0;
|
1012 |
+
ThrowOnError(GetApi().ModelMetadataGetCustomMetadataMapKeys(p_, allocator, &out, &num_keys));
|
1013 |
+
if (num_keys <= 0) {
|
1014 |
+
return result;
|
1015 |
+
}
|
1016 |
+
|
1017 |
+
// array of pointers will be freed
|
1018 |
+
std::unique_ptr<void, decltype(deletor)> array_guard(out, deletor);
|
1019 |
+
// reserve may throw
|
1020 |
+
auto strings_deletor = [&deletor, num_keys](char** out) { for(int64_t i = 0; i < num_keys; ++i) deletor(out[i]); };
|
1021 |
+
std::unique_ptr<char*, decltype(strings_deletor)> strings_guard(out, strings_deletor);
|
1022 |
+
result.reserve(static_cast<size_t>(num_keys));
|
1023 |
+
strings_guard.release();
|
1024 |
+
for (int64_t i = 0; i < num_keys; ++i) {
|
1025 |
+
result.push_back(AllocatedStringPtr(out[i], deletor));
|
1026 |
+
}
|
1027 |
+
|
1028 |
+
return result;
|
1029 |
+
}
|
1030 |
+
|
1031 |
+
inline int64_t ModelMetadata::GetVersion() const {
|
1032 |
+
int64_t out;
|
1033 |
+
ThrowOnError(GetApi().ModelMetadataGetVersion(p_, &out));
|
1034 |
+
return out;
|
1035 |
+
}
|
1036 |
+
|
1037 |
+
namespace detail {
|
1038 |
+
|
1039 |
+
template <typename T>
|
1040 |
+
inline ONNXTensorElementDataType TensorTypeAndShapeInfoImpl<T>::GetElementType() const {
|
1041 |
+
ONNXTensorElementDataType out;
|
1042 |
+
ThrowOnError(GetApi().GetTensorElementType(this->p_, &out));
|
1043 |
+
return out;
|
1044 |
+
}
|
1045 |
+
|
1046 |
+
template <typename T>
|
1047 |
+
inline size_t TensorTypeAndShapeInfoImpl<T>::GetElementCount() const {
|
1048 |
+
size_t out;
|
1049 |
+
ThrowOnError(GetApi().GetTensorShapeElementCount(this->p_, &out));
|
1050 |
+
return static_cast<size_t>(out);
|
1051 |
+
}
|
1052 |
+
|
1053 |
+
template <typename T>
|
1054 |
+
inline size_t TensorTypeAndShapeInfoImpl<T>::GetDimensionsCount() const {
|
1055 |
+
size_t out;
|
1056 |
+
ThrowOnError(GetApi().GetDimensionsCount(this->p_, &out));
|
1057 |
+
return out;
|
1058 |
+
}
|
1059 |
+
|
1060 |
+
template <typename T>
|
1061 |
+
inline void TensorTypeAndShapeInfoImpl<T>::GetDimensions(int64_t* values, size_t values_count) const {
|
1062 |
+
ThrowOnError(GetApi().GetDimensions(this->p_, values, values_count));
|
1063 |
+
}
|
1064 |
+
|
1065 |
+
template <typename T>
|
1066 |
+
inline void TensorTypeAndShapeInfoImpl<T>::GetSymbolicDimensions(const char** values, size_t values_count) const {
|
1067 |
+
ThrowOnError(GetApi().GetSymbolicDimensions(this->p_, values, values_count));
|
1068 |
+
}
|
1069 |
+
|
1070 |
+
template <typename T>
|
1071 |
+
inline std::vector<int64_t> TensorTypeAndShapeInfoImpl<T>::GetShape() const {
|
1072 |
+
std::vector<int64_t> out(GetDimensionsCount(), 0);
|
1073 |
+
ThrowOnError(GetApi().GetDimensions(this->p_, out.data(), out.size()));
|
1074 |
+
return out;
|
1075 |
+
}
|
1076 |
+
|
1077 |
+
template <typename T>
|
1078 |
+
inline ConstTensorTypeAndShapeInfo TypeInfoImpl<T>::GetTensorTypeAndShapeInfo() const {
|
1079 |
+
const OrtTensorTypeAndShapeInfo* out;
|
1080 |
+
ThrowOnError(GetApi().CastTypeInfoToTensorInfo(this->p_, &out));
|
1081 |
+
return ConstTensorTypeAndShapeInfo{out};
|
1082 |
+
}
|
1083 |
+
|
1084 |
+
template <typename T>
|
1085 |
+
inline ConstSequenceTypeInfo TypeInfoImpl<T>::GetSequenceTypeInfo() const {
|
1086 |
+
const OrtSequenceTypeInfo* out;
|
1087 |
+
ThrowOnError(GetApi().CastTypeInfoToSequenceTypeInfo(this->p_, &out));
|
1088 |
+
return ConstSequenceTypeInfo{out};
|
1089 |
+
}
|
1090 |
+
|
1091 |
+
template <typename T>
|
1092 |
+
inline ConstMapTypeInfo TypeInfoImpl<T>::GetMapTypeInfo() const {
|
1093 |
+
const OrtMapTypeInfo* out;
|
1094 |
+
ThrowOnError(GetApi().CastTypeInfoToMapTypeInfo(this->p_, &out));
|
1095 |
+
return ConstMapTypeInfo{out};
|
1096 |
+
}
|
1097 |
+
|
1098 |
+
template <typename T>
|
1099 |
+
inline ONNXType TypeInfoImpl<T>::GetONNXType() const {
|
1100 |
+
ONNXType out;
|
1101 |
+
ThrowOnError(GetApi().GetOnnxTypeFromTypeInfo(this->p_, &out));
|
1102 |
+
return out;
|
1103 |
+
}
|
1104 |
+
|
1105 |
+
template <typename T>
|
1106 |
+
inline TypeInfo SequenceTypeInfoImpl<T>::GetSequenceElementType() const {
|
1107 |
+
OrtTypeInfo* output;
|
1108 |
+
ThrowOnError(GetApi().GetSequenceElementType(this->p_, &output));
|
1109 |
+
return TypeInfo{output};
|
1110 |
+
}
|
1111 |
+
|
1112 |
+
template <typename T>
|
1113 |
+
inline TypeInfo OptionalTypeInfoImpl<T>::GetOptionalElementType() const {
|
1114 |
+
OrtTypeInfo* info;
|
1115 |
+
ThrowOnError(GetApi().GetOptionalContainedTypeInfo(this->p_, &info));
|
1116 |
+
return TypeInfo{info};
|
1117 |
+
}
|
1118 |
+
|
1119 |
+
template <typename T>
|
1120 |
+
inline ONNXTensorElementDataType MapTypeInfoImpl<T>::GetMapKeyType() const {
|
1121 |
+
ONNXTensorElementDataType out;
|
1122 |
+
ThrowOnError(GetApi().GetMapKeyType(this->p_, &out));
|
1123 |
+
return out;
|
1124 |
+
}
|
1125 |
+
|
1126 |
+
template <typename T>
|
1127 |
+
inline TypeInfo MapTypeInfoImpl<T>::GetMapValueType() const {
|
1128 |
+
OrtTypeInfo* output;
|
1129 |
+
ThrowOnError(GetApi().GetMapValueType(this->p_, &output));
|
1130 |
+
return TypeInfo{output};
|
1131 |
+
}
|
1132 |
+
|
1133 |
+
template <typename T>
|
1134 |
+
inline ConstOptionalTypeInfo TypeInfoImpl<T>::GetOptionalTypeInfo() const {
|
1135 |
+
const OrtOptionalTypeInfo* info;
|
1136 |
+
ThrowOnError(GetApi().CastTypeInfoToOptionalTypeInfo(this->p_, &info));
|
1137 |
+
return ConstOptionalTypeInfo{info};
|
1138 |
+
}
|
1139 |
+
|
1140 |
+
} // namespace detail
|
1141 |
+
|
1142 |
+
namespace detail {
|
1143 |
+
|
1144 |
+
template <typename T>
|
1145 |
+
template <typename R>
|
1146 |
+
inline void ConstValueImpl<T>::GetOpaqueData(const char* domain, const char* type_name, R& out) const {
|
1147 |
+
ThrowOnError(GetApi().GetOpaqueValue(domain, type_name, this->p_, &out, sizeof(R)));
|
1148 |
+
}
|
1149 |
+
|
1150 |
+
template <typename T>
|
1151 |
+
inline bool ConstValueImpl<T>::IsTensor() const {
|
1152 |
+
int out;
|
1153 |
+
ThrowOnError(GetApi().IsTensor(this->p_, &out));
|
1154 |
+
return out != 0;
|
1155 |
+
}
|
1156 |
+
|
1157 |
+
template <typename T>
|
1158 |
+
inline bool ConstValueImpl<T>::HasValue() const {
|
1159 |
+
int out;
|
1160 |
+
ThrowOnError(GetApi().HasValue(this->p_, &out));
|
1161 |
+
return out != 0;
|
1162 |
+
}
|
1163 |
+
|
1164 |
+
template <typename T>
|
1165 |
+
inline size_t ConstValueImpl<T>::GetCount() const {
|
1166 |
+
size_t out;
|
1167 |
+
ThrowOnError(GetApi().GetValueCount(this->p_, &out));
|
1168 |
+
return out;
|
1169 |
+
}
|
1170 |
+
|
1171 |
+
template <typename T>
|
1172 |
+
inline Value ConstValueImpl<T>::GetValue(int index, OrtAllocator* allocator) const {
|
1173 |
+
OrtValue* out;
|
1174 |
+
ThrowOnError(GetApi().GetValue(this->p_, index, allocator, &out));
|
1175 |
+
return Value{out};
|
1176 |
+
}
|
1177 |
+
|
1178 |
+
template <typename T>
|
1179 |
+
inline size_t ConstValueImpl<T>::GetStringTensorDataLength() const {
|
1180 |
+
size_t out;
|
1181 |
+
ThrowOnError(GetApi().GetStringTensorDataLength(this->p_, &out));
|
1182 |
+
return out;
|
1183 |
+
}
|
1184 |
+
|
1185 |
+
template <typename T>
|
1186 |
+
inline size_t ConstValueImpl<T>::GetStringTensorElementLength(size_t element_index) const {
|
1187 |
+
size_t out;
|
1188 |
+
ThrowOnError(GetApi().GetStringTensorElementLength(this->p_, element_index, &out));
|
1189 |
+
return out;
|
1190 |
+
}
|
1191 |
+
|
1192 |
+
template <typename T>
|
1193 |
+
template <typename R>
|
1194 |
+
inline const R* ConstValueImpl<T>::GetTensorData() const {
|
1195 |
+
R* out;
|
1196 |
+
ThrowOnError(GetApi().GetTensorMutableData(const_cast<OrtValue*>(this->p_), (void**)&out));
|
1197 |
+
return out;
|
1198 |
+
}
|
1199 |
+
|
1200 |
+
template <typename T>
|
1201 |
+
inline const void* ConstValueImpl<T>::GetTensorRawData() const {
|
1202 |
+
void* out;
|
1203 |
+
ThrowOnError(GetApi().GetTensorMutableData(const_cast<OrtValue*>(this->p_), &out));
|
1204 |
+
return out;
|
1205 |
+
}
|
1206 |
+
|
1207 |
+
template <typename T>
|
1208 |
+
inline TypeInfo ConstValueImpl<T>::GetTypeInfo() const {
|
1209 |
+
OrtTypeInfo* output;
|
1210 |
+
ThrowOnError(GetApi().GetTypeInfo(this->p_, &output));
|
1211 |
+
return TypeInfo{output};
|
1212 |
+
}
|
1213 |
+
|
1214 |
+
template <typename T>
|
1215 |
+
inline TensorTypeAndShapeInfo ConstValueImpl<T>::GetTensorTypeAndShapeInfo() const {
|
1216 |
+
OrtTensorTypeAndShapeInfo* output;
|
1217 |
+
ThrowOnError(GetApi().GetTensorTypeAndShape(this->p_, &output));
|
1218 |
+
return TensorTypeAndShapeInfo{output};
|
1219 |
+
}
|
1220 |
+
|
1221 |
+
template <typename T>
|
1222 |
+
inline ConstMemoryInfo ConstValueImpl<T>::GetTensorMemoryInfo() const {
|
1223 |
+
const OrtMemoryInfo* mem_info;
|
1224 |
+
ThrowOnError(GetApi().GetTensorMemoryInfo(this->p_, &mem_info));
|
1225 |
+
return ConstMemoryInfo(mem_info);
|
1226 |
+
}
|
1227 |
+
|
1228 |
+
template <typename T>
|
1229 |
+
inline void ConstValueImpl<T>::GetStringTensorElement(size_t buffer_length, size_t element_index, void* buffer) const {
|
1230 |
+
ThrowOnError(GetApi().GetStringTensorElement(this->p_, buffer_length, element_index, buffer));
|
1231 |
+
}
|
1232 |
+
|
1233 |
+
template <typename T>
|
1234 |
+
inline std::string ConstValueImpl<T>::GetStringTensorElement(size_t element_index) const {
|
1235 |
+
size_t buffer_length;
|
1236 |
+
ThrowOnError(GetApi().GetStringTensorElementLength(this->p_, element_index, &buffer_length));
|
1237 |
+
|
1238 |
+
std::string s;
|
1239 |
+
s.resize(buffer_length);
|
1240 |
+
ThrowOnError(GetApi().GetStringTensorElement(this->p_, buffer_length, element_index, &s[0]));
|
1241 |
+
return s;
|
1242 |
+
}
|
1243 |
+
|
1244 |
+
template <typename T>
|
1245 |
+
inline void ConstValueImpl<T>::GetStringTensorContent(void* buffer, size_t buffer_length, size_t* offsets, size_t offsets_count) const {
|
1246 |
+
ThrowOnError(GetApi().GetStringTensorContent(this->p_, buffer, buffer_length, offsets, offsets_count));
|
1247 |
+
}
|
1248 |
+
|
1249 |
+
#if !defined(DISABLE_SPARSE_TENSORS)
|
1250 |
+
template <typename T>
|
1251 |
+
inline OrtSparseFormat ConstValueImpl<T>::GetSparseFormat() const {
|
1252 |
+
OrtSparseFormat format;
|
1253 |
+
ThrowOnError(GetApi().GetSparseTensorFormat(this->p_, &format));
|
1254 |
+
return format;
|
1255 |
+
}
|
1256 |
+
|
1257 |
+
template <typename T>
|
1258 |
+
inline TensorTypeAndShapeInfo ConstValueImpl<T>::GetSparseTensorValuesTypeAndShapeInfo() const {
|
1259 |
+
OrtTensorTypeAndShapeInfo* output;
|
1260 |
+
ThrowOnError(GetApi().GetSparseTensorValuesTypeAndShape(this->p_, &output));
|
1261 |
+
return TensorTypeAndShapeInfo{output};
|
1262 |
+
}
|
1263 |
+
|
1264 |
+
template <typename T>
|
1265 |
+
inline TensorTypeAndShapeInfo ConstValueImpl<T>::GetSparseTensorIndicesTypeShapeInfo(OrtSparseIndicesFormat indices_format) const {
|
1266 |
+
OrtTensorTypeAndShapeInfo* output;
|
1267 |
+
ThrowOnError(GetApi().GetSparseTensorIndicesTypeShape(this->p_, indices_format, &output));
|
1268 |
+
return TensorTypeAndShapeInfo{output};
|
1269 |
+
}
|
1270 |
+
|
1271 |
+
template <typename T>
|
1272 |
+
template <typename R>
|
1273 |
+
inline const R* ConstValueImpl<T>::GetSparseTensorIndicesData(OrtSparseIndicesFormat indices_format, size_t& num_indices) const {
|
1274 |
+
const void* out;
|
1275 |
+
ThrowOnError(GetApi().GetSparseTensorIndices(this->p_, indices_format, &num_indices, &out));
|
1276 |
+
return reinterpret_cast<const R*>(out);
|
1277 |
+
}
|
1278 |
+
|
1279 |
+
template <typename T>
|
1280 |
+
inline bool ConstValueImpl<T>::IsSparseTensor() const {
|
1281 |
+
int out;
|
1282 |
+
ThrowOnError(GetApi().IsSparseTensor(this->p_, &out));
|
1283 |
+
return out != 0;
|
1284 |
+
}
|
1285 |
+
|
1286 |
+
template <typename T>
|
1287 |
+
template <typename R>
|
1288 |
+
inline const R* ConstValueImpl<T>::GetSparseTensorValues() const {
|
1289 |
+
const void* out;
|
1290 |
+
ThrowOnError(GetApi().GetSparseTensorValues(this->p_, &out));
|
1291 |
+
return reinterpret_cast<const R*>(out);
|
1292 |
+
}
|
1293 |
+
|
1294 |
+
#endif
|
1295 |
+
|
1296 |
+
template <typename T>
|
1297 |
+
void ValueImpl<T>::FillStringTensor(const char* const* s, size_t s_len) {
|
1298 |
+
ThrowOnError(GetApi().FillStringTensor(this->p_, s, s_len));
|
1299 |
+
}
|
1300 |
+
|
1301 |
+
template <typename T>
|
1302 |
+
void ValueImpl<T>::FillStringTensorElement(const char* s, size_t index) {
|
1303 |
+
ThrowOnError(GetApi().FillStringTensorElement(this->p_, s, index));
|
1304 |
+
}
|
1305 |
+
|
1306 |
+
template <typename T>
|
1307 |
+
inline char* ValueImpl<T>::GetResizedStringTensorElementBuffer(size_t index, size_t buffer_length) {
|
1308 |
+
char* result;
|
1309 |
+
ThrowOnError(GetApi().GetResizedStringTensorElementBuffer(this->p_, index, buffer_length, &result));
|
1310 |
+
return result;
|
1311 |
+
}
|
1312 |
+
|
1313 |
+
template <typename T>
|
1314 |
+
void* ValueImpl<T>::GetTensorMutableRawData() {
|
1315 |
+
void* out;
|
1316 |
+
ThrowOnError(GetApi().GetTensorMutableData(this->p_, &out));
|
1317 |
+
return out;
|
1318 |
+
}
|
1319 |
+
|
1320 |
+
template <typename T>
|
1321 |
+
template <typename R>
|
1322 |
+
R* ValueImpl<T>::GetTensorMutableData() {
|
1323 |
+
R* out;
|
1324 |
+
ThrowOnError(GetApi().GetTensorMutableData(this->p_, (void**)&out));
|
1325 |
+
return out;
|
1326 |
+
}
|
1327 |
+
|
1328 |
+
template <typename T>
|
1329 |
+
template <typename R>
|
1330 |
+
R& ValueImpl<T>::At(const std::vector<int64_t>& location) {
|
1331 |
+
static_assert(!std::is_same<T, std::string>::value, "this api does not support std::string");
|
1332 |
+
R* out;
|
1333 |
+
ThrowOnError(GetApi().TensorAt(this->p_, location.data(), location.size(), (void**)&out));
|
1334 |
+
return *out;
|
1335 |
+
}
|
1336 |
+
|
1337 |
+
#if !defined(DISABLE_SPARSE_TENSORS)
|
1338 |
+
template <typename T>
|
1339 |
+
void ValueImpl<T>::UseCooIndices(int64_t* indices_data, size_t indices_num) {
|
1340 |
+
ThrowOnError(GetApi().UseCooIndices(this->p_, indices_data, indices_num));
|
1341 |
+
}
|
1342 |
+
|
1343 |
+
template <typename T>
|
1344 |
+
void ValueImpl<T>::UseCsrIndices(int64_t* inner_data, size_t inner_num, int64_t* outer_data, size_t outer_num) {
|
1345 |
+
ThrowOnError(GetApi().UseCsrIndices(this->p_, inner_data, inner_num, outer_data, outer_num));
|
1346 |
+
}
|
1347 |
+
|
1348 |
+
template <typename T>
|
1349 |
+
void ValueImpl<T>::UseBlockSparseIndices(const Shape& indices_shape, int32_t* indices_data) {
|
1350 |
+
ThrowOnError(GetApi().UseBlockSparseIndices(this->p_, indices_shape.shape, indices_shape.shape_len, indices_data));
|
1351 |
+
}
|
1352 |
+
|
1353 |
+
template <typename T>
|
1354 |
+
void ValueImpl<T>::FillSparseTensorCoo(const OrtMemoryInfo* mem_info, const OrtSparseValuesParam& values_param,
|
1355 |
+
const int64_t* indices_data, size_t indices_num) {
|
1356 |
+
ThrowOnError(GetApi().FillSparseTensorCoo(this->p_, mem_info, values_param.values_shape,
|
1357 |
+
values_param.values_shape_len, values_param.data.p_data,
|
1358 |
+
indices_data, indices_num));
|
1359 |
+
}
|
1360 |
+
|
1361 |
+
template <typename T>
|
1362 |
+
void ValueImpl<T>::FillSparseTensorCsr(const OrtMemoryInfo* data_mem_info,
|
1363 |
+
const OrtSparseValuesParam& values,
|
1364 |
+
const int64_t* inner_indices_data, size_t inner_indices_num,
|
1365 |
+
const int64_t* outer_indices_data, size_t outer_indices_num) {
|
1366 |
+
ThrowOnError(GetApi().FillSparseTensorCsr(this->p_, data_mem_info, values.values_shape, values.values_shape_len, values.data.p_data,
|
1367 |
+
inner_indices_data, inner_indices_num,
|
1368 |
+
outer_indices_data, outer_indices_num));
|
1369 |
+
}
|
1370 |
+
|
1371 |
+
template <typename T>
|
1372 |
+
void ValueImpl<T>::FillSparseTensorBlockSparse(const OrtMemoryInfo* data_mem_info,
|
1373 |
+
const OrtSparseValuesParam& values,
|
1374 |
+
const Shape& indices_shape,
|
1375 |
+
const int32_t* indices_data) {
|
1376 |
+
ThrowOnError(GetApi().FillSparseTensorBlockSparse(this->p_, data_mem_info, values.values_shape, values.values_shape_len, values.data.p_data,
|
1377 |
+
indices_shape.shape, indices_shape.shape_len,
|
1378 |
+
indices_data));
|
1379 |
+
}
|
1380 |
+
|
1381 |
+
#endif // !defined(DISABLE_SPARSE_TENSORS)
|
1382 |
+
|
1383 |
+
} // namespace detail
|
1384 |
+
|
1385 |
+
template <typename T>
|
1386 |
+
inline Value Value::CreateTensor(const OrtMemoryInfo* info, T* p_data, size_t p_data_element_count, const int64_t* shape, size_t shape_len) {
|
1387 |
+
return CreateTensor(info, p_data, p_data_element_count * sizeof(T), shape, shape_len, TypeToTensorType<T>::type);
|
1388 |
+
}
|
1389 |
+
|
1390 |
+
inline Value Value::CreateTensor(const OrtMemoryInfo* info, void* p_data, size_t p_data_byte_count, const int64_t* shape, size_t shape_len,
|
1391 |
+
ONNXTensorElementDataType type) {
|
1392 |
+
OrtValue* out;
|
1393 |
+
ThrowOnError(GetApi().CreateTensorWithDataAsOrtValue(info, p_data, p_data_byte_count, shape, shape_len, type, &out));
|
1394 |
+
return Value{out};
|
1395 |
+
}
|
1396 |
+
|
1397 |
+
template <typename T>
|
1398 |
+
inline Value Value::CreateTensor(OrtAllocator* allocator, const int64_t* shape, size_t shape_len) {
|
1399 |
+
return CreateTensor(allocator, shape, shape_len, TypeToTensorType<T>::type);
|
1400 |
+
}
|
1401 |
+
|
1402 |
+
inline Value Value::CreateTensor(OrtAllocator* allocator, const int64_t* shape, size_t shape_len, ONNXTensorElementDataType type) {
|
1403 |
+
OrtValue* out;
|
1404 |
+
ThrowOnError(GetApi().CreateTensorAsOrtValue(allocator, shape, shape_len, type, &out));
|
1405 |
+
return Value{out};
|
1406 |
+
}
|
1407 |
+
|
1408 |
+
#if !defined(DISABLE_SPARSE_TENSORS)
|
1409 |
+
|
1410 |
+
template <typename T>
|
1411 |
+
inline Value Value::CreateSparseTensor(const OrtMemoryInfo* info, T* p_data, const Shape& dense_shape,
|
1412 |
+
const Shape& values_shape) {
|
1413 |
+
return CreateSparseTensor(info, p_data, dense_shape, values_shape, TypeToTensorType<T>::type);
|
1414 |
+
}
|
1415 |
+
|
1416 |
+
inline Value Value::CreateSparseTensor(const OrtMemoryInfo* info, void* p_data, const Shape& dense_shape,
|
1417 |
+
const Shape& values_shape, ONNXTensorElementDataType type) {
|
1418 |
+
OrtValue* out;
|
1419 |
+
ThrowOnError(GetApi().CreateSparseTensorWithValuesAsOrtValue(info, p_data, dense_shape.shape, dense_shape.shape_len,
|
1420 |
+
values_shape.shape, values_shape.shape_len, type, &out));
|
1421 |
+
return Value{out};
|
1422 |
+
}
|
1423 |
+
|
1424 |
+
template <typename T>
|
1425 |
+
inline Value Value::CreateSparseTensor(OrtAllocator* allocator, const Shape& dense_shape) {
|
1426 |
+
return CreateSparseTensor(allocator, dense_shape, TypeToTensorType<T>::type);
|
1427 |
+
}
|
1428 |
+
|
1429 |
+
inline Value Value::CreateSparseTensor(OrtAllocator* allocator, const Shape& dense_shape,
|
1430 |
+
ONNXTensorElementDataType type) {
|
1431 |
+
OrtValue* out;
|
1432 |
+
ThrowOnError(GetApi().CreateSparseTensorAsOrtValue(allocator, dense_shape.shape, dense_shape.shape_len, type, &out));
|
1433 |
+
return Value{out};
|
1434 |
+
}
|
1435 |
+
#endif // !defined(DISABLE_SPARSE_TENSORS)
|
1436 |
+
|
1437 |
+
inline Value Value::CreateMap(Value& keys, Value& values) {
|
1438 |
+
OrtValue* out;
|
1439 |
+
OrtValue* inputs[2] = {keys, values};
|
1440 |
+
ThrowOnError(GetApi().CreateValue(inputs, 2, ONNX_TYPE_MAP, &out));
|
1441 |
+
return Value{out};
|
1442 |
+
}
|
1443 |
+
|
1444 |
+
inline Value Value::CreateSequence(std::vector<Value>& values) {
|
1445 |
+
OrtValue* out;
|
1446 |
+
std::vector<OrtValue*> values_ort{values.data(), values.data() + values.size()};
|
1447 |
+
ThrowOnError(GetApi().CreateValue(values_ort.data(), values_ort.size(), ONNX_TYPE_SEQUENCE, &out));
|
1448 |
+
return Value{out};
|
1449 |
+
}
|
1450 |
+
|
1451 |
+
template <typename T>
|
1452 |
+
inline Value Value::CreateOpaque(const char* domain, const char* type_name, const T& data_container) {
|
1453 |
+
OrtValue* out;
|
1454 |
+
ThrowOnError(GetApi().CreateOpaqueValue(domain, type_name, &data_container, sizeof(T), &out));
|
1455 |
+
return Value{out};
|
1456 |
+
}
|
1457 |
+
|
1458 |
+
//
|
1459 |
+
// Custom OP Inlines
|
1460 |
+
//
|
1461 |
+
inline Logger::Logger(const OrtLogger* logger) : logger_(logger) {
|
1462 |
+
Ort::ThrowOnError(GetApi().Logger_GetLoggingSeverityLevel(this->logger_, &this->cached_severity_level_));
|
1463 |
+
}
|
1464 |
+
|
1465 |
+
inline OrtLoggingLevel Logger::GetLoggingSeverityLevel() const noexcept {
|
1466 |
+
return cached_severity_level_;
|
1467 |
+
}
|
1468 |
+
|
1469 |
+
inline Status Logger::LogMessage(OrtLoggingLevel log_severity_level, const ORTCHAR_T* file_path, int line_number,
|
1470 |
+
const char* func_name, const char* message) const noexcept {
|
1471 |
+
OrtStatus* status = GetApi().Logger_LogMessage(logger_, log_severity_level, message, file_path, line_number,
|
1472 |
+
func_name);
|
1473 |
+
return Status{status};
|
1474 |
+
}
|
1475 |
+
|
1476 |
+
// Disable warnings about the format string not being a literal (-Wformat-nonliteral and -Wformat-security)
|
1477 |
+
// for gcc and clang. The alternative is to use actual C-style variadic parameters and apply
|
1478 |
+
// __attribute__(format(printf...)), which does not work with variadic templates.
|
1479 |
+
#if defined(__GNUC__)
|
1480 |
+
#pragma GCC diagnostic push
|
1481 |
+
#pragma GCC diagnostic ignored "-Wformat-nonliteral"
|
1482 |
+
#pragma GCC diagnostic ignored "-Wformat-security"
|
1483 |
+
#elif defined(__clang__)
|
1484 |
+
#pragma clang diagnostic push
|
1485 |
+
#pragma clang diagnostic ignored "-Wformat-nonliteral"
|
1486 |
+
#pragma clang diagnostic ignored "-Wformat-security"
|
1487 |
+
#endif
|
1488 |
+
template <typename... Args>
|
1489 |
+
inline Status Logger::LogFormattedMessage(OrtLoggingLevel log_severity_level, const ORTCHAR_T* file_path,
|
1490 |
+
int line_number, const char* func_name, const char* format,
|
1491 |
+
Args&&... args) const noexcept {
|
1492 |
+
int msg_len = std::snprintf(nullptr, 0U, format, std::forward<Args>(args)...);
|
1493 |
+
|
1494 |
+
if (msg_len < 0) { // Formatting error
|
1495 |
+
return Status("Failed to log message due to formatting error", OrtErrorCode::ORT_FAIL);
|
1496 |
+
}
|
1497 |
+
|
1498 |
+
OrtStatus* status = nullptr;
|
1499 |
+
const size_t buffer_size = static_cast<size_t>(msg_len) + 1U;
|
1500 |
+
|
1501 |
+
constexpr size_t kStackBufferSize = 1024;
|
1502 |
+
|
1503 |
+
if (buffer_size < kStackBufferSize) {
|
1504 |
+
char buffer[kStackBufferSize];
|
1505 |
+
snprintf(buffer, kStackBufferSize, format, std::forward<Args>(args)...);
|
1506 |
+
status = GetApi().Logger_LogMessage(logger_, log_severity_level, buffer, file_path, line_number, func_name);
|
1507 |
+
} else {
|
1508 |
+
// std::make_unique is only supported starting at C++14.
|
1509 |
+
#if (__cplusplus >= 201402L) || (_MSC_VER >= 1900)
|
1510 |
+
auto buffer = std::make_unique<char[]>(buffer_size);
|
1511 |
+
#else
|
1512 |
+
std::unique_ptr<char[]> buffer(new char[buffer_size]);
|
1513 |
+
#endif
|
1514 |
+
std::snprintf(buffer.get(), buffer_size, format, std::forward<Args>(args)...);
|
1515 |
+
status = GetApi().Logger_LogMessage(logger_, log_severity_level, buffer.get(), file_path, line_number, func_name);
|
1516 |
+
}
|
1517 |
+
|
1518 |
+
return Status{status};
|
1519 |
+
}
|
1520 |
+
// Re-enable -Wformat-nonliteral and -Wformat-security
|
1521 |
+
#if defined(__GNUC__)
|
1522 |
+
#pragma GCC diagnostic pop
|
1523 |
+
#elif defined(__clang__)
|
1524 |
+
#pragma clang diagnostic pop
|
1525 |
+
#endif
|
1526 |
+
|
1527 |
+
inline KernelContext::KernelContext(OrtKernelContext* context) : ctx_(context) {
|
1528 |
+
}
|
1529 |
+
|
1530 |
+
inline size_t KernelContext::GetInputCount() const {
|
1531 |
+
size_t out = 0;
|
1532 |
+
Ort::ThrowOnError(GetApi().KernelContext_GetInputCount(ctx_, &out));
|
1533 |
+
return out;
|
1534 |
+
}
|
1535 |
+
|
1536 |
+
inline size_t KernelContext::GetOutputCount() const {
|
1537 |
+
size_t out = 0;
|
1538 |
+
Ort::ThrowOnError(GetApi().KernelContext_GetOutputCount(ctx_, &out));
|
1539 |
+
return out;
|
1540 |
+
}
|
1541 |
+
|
1542 |
+
inline ConstValue KernelContext::GetInput(size_t index) const {
|
1543 |
+
const OrtValue* out = nullptr;
|
1544 |
+
Ort::ThrowOnError(GetApi().KernelContext_GetInput(ctx_, index, &out));
|
1545 |
+
return ConstValue{out};
|
1546 |
+
}
|
1547 |
+
|
1548 |
+
inline UnownedValue KernelContext::GetOutput(size_t index, const int64_t* dim_values, size_t dim_count) const {
|
1549 |
+
OrtValue* out = nullptr;
|
1550 |
+
Ort::ThrowOnError(GetApi().KernelContext_GetOutput(ctx_, index, dim_values, dim_count, &out));
|
1551 |
+
return UnownedValue(out);
|
1552 |
+
}
|
1553 |
+
|
1554 |
+
inline UnownedValue KernelContext::GetOutput(size_t index, const std::vector<int64_t>& dims) const {
|
1555 |
+
OrtValue* out = nullptr;
|
1556 |
+
Ort::ThrowOnError(GetApi().KernelContext_GetOutput(ctx_, index, dims.data(), dims.size(), &out));
|
1557 |
+
return UnownedValue(out);
|
1558 |
+
}
|
1559 |
+
|
1560 |
+
inline void* KernelContext::GetGPUComputeStream() const {
|
1561 |
+
void* out = nullptr;
|
1562 |
+
Ort::ThrowOnError(GetApi().KernelContext_GetGPUComputeStream(ctx_, &out));
|
1563 |
+
return out;
|
1564 |
+
}
|
1565 |
+
|
1566 |
+
inline OrtAllocator* KernelContext::GetAllocator(const OrtMemoryInfo& memory_info) const {
|
1567 |
+
OrtAllocator* out = nullptr;
|
1568 |
+
Ort::ThrowOnError(GetApi().KernelContext_GetAllocator(ctx_, &memory_info, &out));
|
1569 |
+
return out;
|
1570 |
+
}
|
1571 |
+
|
1572 |
+
inline Logger KernelContext::GetLogger() const {
|
1573 |
+
const OrtLogger* out = nullptr;
|
1574 |
+
ThrowOnError(GetApi().KernelContext_GetLogger(this->ctx_, &out));
|
1575 |
+
return Logger{out};
|
1576 |
+
}
|
1577 |
+
|
1578 |
+
inline OpAttr::OpAttr(const char* name, const void* data, int len, OrtOpAttrType type) {
|
1579 |
+
Ort::ThrowOnError(GetApi().CreateOpAttr(name, data, len, type, &p_));
|
1580 |
+
}
|
1581 |
+
|
1582 |
+
namespace detail {
|
1583 |
+
template <typename T>
|
1584 |
+
inline KernelInfo KernelInfoImpl<T>::Copy() const {
|
1585 |
+
OrtKernelInfo* info_copy = nullptr;
|
1586 |
+
Ort::ThrowOnError(GetApi().CopyKernelInfo(this->p_, &info_copy));
|
1587 |
+
return KernelInfo{info_copy};
|
1588 |
+
}
|
1589 |
+
|
1590 |
+
template <typename T>
|
1591 |
+
inline size_t KernelInfoImpl<T>::GetInputCount() const {
|
1592 |
+
size_t out = 0;
|
1593 |
+
ThrowOnError(GetApi().KernelInfo_GetInputCount(this->p_, &out));
|
1594 |
+
return out;
|
1595 |
+
}
|
1596 |
+
|
1597 |
+
template <typename T>
|
1598 |
+
inline size_t KernelInfoImpl<T>::GetOutputCount() const {
|
1599 |
+
size_t out = 0;
|
1600 |
+
ThrowOnError(GetApi().KernelInfo_GetOutputCount(this->p_, &out));
|
1601 |
+
return out;
|
1602 |
+
}
|
1603 |
+
|
1604 |
+
template <typename T>
|
1605 |
+
inline std::string KernelInfoImpl<T>::GetInputName(size_t index) const {
|
1606 |
+
size_t size = 0;
|
1607 |
+
|
1608 |
+
// Feed nullptr for the data buffer to query the true size of the string value
|
1609 |
+
Ort::ThrowOnError(GetApi().KernelInfo_GetInputName(this->p_, index, nullptr, &size));
|
1610 |
+
|
1611 |
+
std::string out;
|
1612 |
+
out.resize(size);
|
1613 |
+
Ort::ThrowOnError(GetApi().KernelInfo_GetInputName(this->p_, index, &out[0], &size));
|
1614 |
+
out.resize(size - 1); // remove the terminating character '\0'
|
1615 |
+
|
1616 |
+
return out;
|
1617 |
+
}
|
1618 |
+
|
1619 |
+
template <typename T>
|
1620 |
+
inline std::string KernelInfoImpl<T>::GetOutputName(size_t index) const {
|
1621 |
+
size_t size = 0;
|
1622 |
+
|
1623 |
+
// Feed nullptr for the data buffer to query the true size of the string value
|
1624 |
+
Ort::ThrowOnError(GetApi().KernelInfo_GetOutputName(this->p_, index, nullptr, &size));
|
1625 |
+
|
1626 |
+
std::string out;
|
1627 |
+
out.resize(size);
|
1628 |
+
Ort::ThrowOnError(GetApi().KernelInfo_GetOutputName(this->p_, index, &out[0], &size));
|
1629 |
+
out.resize(size - 1); // remove the terminating character '\0'
|
1630 |
+
|
1631 |
+
return out;
|
1632 |
+
}
|
1633 |
+
|
1634 |
+
template <typename T>
|
1635 |
+
inline TypeInfo KernelInfoImpl<T>::GetInputTypeInfo(size_t index) const {
|
1636 |
+
OrtTypeInfo* out = nullptr;
|
1637 |
+
ThrowOnError(GetApi().KernelInfo_GetInputTypeInfo(this->p_, index, &out));
|
1638 |
+
return TypeInfo{out};
|
1639 |
+
}
|
1640 |
+
|
1641 |
+
template <typename T>
|
1642 |
+
inline TypeInfo KernelInfoImpl<T>::GetOutputTypeInfo(size_t index) const {
|
1643 |
+
OrtTypeInfo* out = nullptr;
|
1644 |
+
ThrowOnError(GetApi().KernelInfo_GetOutputTypeInfo(this->p_, index, &out));
|
1645 |
+
return TypeInfo{out};
|
1646 |
+
}
|
1647 |
+
|
1648 |
+
template <typename T>
|
1649 |
+
inline Value KernelInfoImpl<T>::GetTensorAttribute(const char* name, OrtAllocator* allocator) const {
|
1650 |
+
OrtValue* out = nullptr;
|
1651 |
+
ThrowOnError(GetApi().KernelInfoGetAttribute_tensor(this->p_, name, allocator, &out));
|
1652 |
+
return Value{out};
|
1653 |
+
}
|
1654 |
+
|
1655 |
+
template <typename T>
|
1656 |
+
inline ConstValue KernelInfoImpl<T>::GetTensorConstantInput(size_t index, int* is_constant) const {
|
1657 |
+
const OrtValue* out = nullptr;
|
1658 |
+
ThrowOnError(GetApi().KernelInfoGetConstantInput_tensor(this->p_, index, is_constant, &out));
|
1659 |
+
return ConstValue{out};
|
1660 |
+
}
|
1661 |
+
|
1662 |
+
template <typename T>
|
1663 |
+
inline std::string KernelInfoImpl<T>::GetNodeName() const {
|
1664 |
+
size_t size = 0;
|
1665 |
+
|
1666 |
+
// Feed nullptr for the data buffer to query the true size of the string value
|
1667 |
+
Ort::ThrowOnError(GetApi().KernelInfo_GetNodeName(this->p_, nullptr, &size));
|
1668 |
+
|
1669 |
+
std::string out;
|
1670 |
+
out.resize(size);
|
1671 |
+
Ort::ThrowOnError(GetApi().KernelInfo_GetNodeName(this->p_, &out[0], &size));
|
1672 |
+
out.resize(size - 1); // remove the terminating character '\0'
|
1673 |
+
|
1674 |
+
return out;
|
1675 |
+
}
|
1676 |
+
|
1677 |
+
template <typename T>
|
1678 |
+
inline Logger KernelInfoImpl<T>::GetLogger() const {
|
1679 |
+
const OrtLogger* out = nullptr;
|
1680 |
+
ThrowOnError(GetApi().KernelInfo_GetLogger(this->p_, &out));
|
1681 |
+
return Logger{out};
|
1682 |
+
}
|
1683 |
+
|
1684 |
+
inline void attr_utils::GetAttr(const OrtKernelInfo* p, const char* name, float& out) {
|
1685 |
+
Ort::ThrowOnError(GetApi().KernelInfoGetAttribute_float(p, name, &out));
|
1686 |
+
}
|
1687 |
+
|
1688 |
+
inline void attr_utils::GetAttr(const OrtKernelInfo* p, const char* name, int64_t& out) {
|
1689 |
+
Ort::ThrowOnError(GetApi().KernelInfoGetAttribute_int64(p, name, &out));
|
1690 |
+
}
|
1691 |
+
|
1692 |
+
inline void attr_utils::GetAttr(const OrtKernelInfo* p, const char* name, std::string& result) {
|
1693 |
+
size_t size = 0;
|
1694 |
+
// Feed nullptr for the data buffer to query the true size of the string attribute
|
1695 |
+
Ort::ThrowOnError(GetApi().KernelInfoGetAttribute_string(p, name, nullptr, &size));
|
1696 |
+
|
1697 |
+
std::string out;
|
1698 |
+
out.resize(size);
|
1699 |
+
Ort::ThrowOnError(GetApi().KernelInfoGetAttribute_string(p, name, &out[0], &size));
|
1700 |
+
out.resize(size - 1); // remove the terminating character '\0'
|
1701 |
+
out.swap(result);
|
1702 |
+
}
|
1703 |
+
|
1704 |
+
inline void attr_utils::GetAttrs(const OrtKernelInfo* p, const char* name, std::vector<float>& result) {
|
1705 |
+
size_t size = 0;
|
1706 |
+
// Feed nullptr for the data buffer to query the true size of the attribute
|
1707 |
+
Ort::ThrowOnError(GetApi().KernelInfoGetAttributeArray_float(p, name, nullptr, &size));
|
1708 |
+
|
1709 |
+
std::vector<float> out;
|
1710 |
+
out.resize(size);
|
1711 |
+
Ort::ThrowOnError(GetApi().KernelInfoGetAttributeArray_float(p, name, out.data(), &size));
|
1712 |
+
out.swap(result);
|
1713 |
+
}
|
1714 |
+
|
1715 |
+
inline void attr_utils::GetAttrs(const OrtKernelInfo* p, const char* name, std::vector<int64_t>& result) {
|
1716 |
+
size_t size = 0;
|
1717 |
+
|
1718 |
+
// Feed nullptr for the data buffer to query the true size of the attribute
|
1719 |
+
Ort::ThrowOnError(GetApi().KernelInfoGetAttributeArray_int64(p, name, nullptr, &size));
|
1720 |
+
|
1721 |
+
std::vector<int64_t> out;
|
1722 |
+
out.resize(size);
|
1723 |
+
Ort::ThrowOnError(GetApi().KernelInfoGetAttributeArray_int64(p, name, out.data(), &size));
|
1724 |
+
out.swap(result);
|
1725 |
+
}
|
1726 |
+
} // namespace detail
|
1727 |
+
|
1728 |
+
inline KernelInfo::KernelInfo(OrtKernelInfo* info) : detail::KernelInfoImpl<OrtKernelInfo>{info} {}
|
1729 |
+
|
1730 |
+
inline Op::Op(OrtOp* p) : Base<OrtOp>(p) {}
|
1731 |
+
|
1732 |
+
inline Op Op::Create(const OrtKernelInfo* info, const char* op_name, const char* domain, int version,
|
1733 |
+
const char** type_constraint_names,
|
1734 |
+
const ONNXTensorElementDataType* type_constraint_values,
|
1735 |
+
size_t type_constraint_count,
|
1736 |
+
const OpAttr* attr_values, size_t attr_count,
|
1737 |
+
size_t input_count, size_t output_count) {
|
1738 |
+
static_assert(sizeof(OpAttr) == sizeof(OrtOpAttr*),
|
1739 |
+
"OpAttr's is expected to be just an array of OrtOpAttr in memory so we can reinterpret safely");
|
1740 |
+
auto attr_input_values = reinterpret_cast<const OrtOpAttr* const*>(attr_values);
|
1741 |
+
OrtOp* op;
|
1742 |
+
Ort::ThrowOnError(GetApi().CreateOp(info, op_name, domain, version, type_constraint_names, type_constraint_values,
|
1743 |
+
static_cast<int>(type_constraint_count),
|
1744 |
+
attr_input_values,
|
1745 |
+
static_cast<int>(attr_count),
|
1746 |
+
static_cast<int>(input_count),
|
1747 |
+
static_cast<int>(output_count), &op));
|
1748 |
+
return Op{op};
|
1749 |
+
}
|
1750 |
+
|
1751 |
+
inline void Op::Invoke(const OrtKernelContext* context,
|
1752 |
+
const Value* input_values,
|
1753 |
+
size_t input_count,
|
1754 |
+
Value* output_values,
|
1755 |
+
size_t output_count) {
|
1756 |
+
static_assert(sizeof(Value) == sizeof(OrtValue*),
|
1757 |
+
"Value is really just an array of OrtValue* in memory, so we can reinterpret_cast safely");
|
1758 |
+
auto ort_input_values = reinterpret_cast<const OrtValue* const*>(input_values);
|
1759 |
+
auto ort_output_values = reinterpret_cast<OrtValue**>(output_values);
|
1760 |
+
Ort::ThrowOnError(GetApi().InvokeOp(context, p_, ort_input_values, static_cast<int>(input_count),
|
1761 |
+
ort_output_values, static_cast<int>(output_count)));
|
1762 |
+
}
|
1763 |
+
|
1764 |
+
inline void Op::Invoke(const OrtKernelContext* context,
|
1765 |
+
const OrtValue* const* input_values,
|
1766 |
+
size_t input_count,
|
1767 |
+
OrtValue* const* output_values,
|
1768 |
+
size_t output_count) {
|
1769 |
+
Ort::ThrowOnError(GetApi().InvokeOp(context, p_, input_values, static_cast<int>(input_count),
|
1770 |
+
output_values, static_cast<int>(output_count)));
|
1771 |
+
}
|
1772 |
+
|
1773 |
+
inline void CustomOpApi::ThrowOnError(OrtStatus* status) {
|
1774 |
+
Ort::ThrowOnError(status);
|
1775 |
+
}
|
1776 |
+
|
1777 |
+
template <>
|
1778 |
+
inline float CustomOpApi::KernelInfoGetAttribute<float>(_In_ const OrtKernelInfo* info, _In_ const char* name) {
|
1779 |
+
float out;
|
1780 |
+
Ort::ThrowOnError(api_.KernelInfoGetAttribute_float(info, name, &out));
|
1781 |
+
return out;
|
1782 |
+
}
|
1783 |
+
|
1784 |
+
template <>
|
1785 |
+
inline int64_t CustomOpApi::KernelInfoGetAttribute<int64_t>(_In_ const OrtKernelInfo* info, _In_ const char* name) {
|
1786 |
+
int64_t out;
|
1787 |
+
Ort::ThrowOnError(api_.KernelInfoGetAttribute_int64(info, name, &out));
|
1788 |
+
return out;
|
1789 |
+
}
|
1790 |
+
|
1791 |
+
template <>
|
1792 |
+
inline std::string CustomOpApi::KernelInfoGetAttribute<std::string>(_In_ const OrtKernelInfo* info, _In_ const char* name) {
|
1793 |
+
size_t size = 0;
|
1794 |
+
std::string out;
|
1795 |
+
|
1796 |
+
// Feed nullptr for the data buffer to query the true size of the string attribute
|
1797 |
+
OrtStatus* status = api_.KernelInfoGetAttribute_string(info, name, nullptr, &size);
|
1798 |
+
|
1799 |
+
if (status == nullptr) {
|
1800 |
+
out.resize(size);
|
1801 |
+
Ort::ThrowOnError(api_.KernelInfoGetAttribute_string(info, name, &out[0], &size));
|
1802 |
+
out.resize(size - 1); // remove the terminating character '\0'
|
1803 |
+
} else {
|
1804 |
+
Ort::ThrowOnError(status);
|
1805 |
+
}
|
1806 |
+
return out;
|
1807 |
+
}
|
1808 |
+
|
1809 |
+
template <>
|
1810 |
+
inline std::vector<float> CustomOpApi::KernelInfoGetAttribute(_In_ const OrtKernelInfo* info, _In_ const char* name) {
|
1811 |
+
size_t size = 0;
|
1812 |
+
std::vector<float> out;
|
1813 |
+
|
1814 |
+
// Feed nullptr for the data buffer to query the true size of the attribute
|
1815 |
+
OrtStatus* status = api_.KernelInfoGetAttributeArray_float(info, name, nullptr, &size);
|
1816 |
+
|
1817 |
+
if (status == nullptr) {
|
1818 |
+
out.resize(size);
|
1819 |
+
Ort::ThrowOnError(api_.KernelInfoGetAttributeArray_float(info, name, out.data(), &size));
|
1820 |
+
} else {
|
1821 |
+
Ort::ThrowOnError(status);
|
1822 |
+
}
|
1823 |
+
return out;
|
1824 |
+
}
|
1825 |
+
|
1826 |
+
template <>
|
1827 |
+
inline std::vector<int64_t> CustomOpApi::KernelInfoGetAttribute(_In_ const OrtKernelInfo* info, _In_ const char* name) {
|
1828 |
+
size_t size = 0;
|
1829 |
+
std::vector<int64_t> out;
|
1830 |
+
|
1831 |
+
// Feed nullptr for the data buffer to query the true size of the attribute
|
1832 |
+
OrtStatus* status = api_.KernelInfoGetAttributeArray_int64(info, name, nullptr, &size);
|
1833 |
+
|
1834 |
+
if (status == nullptr) {
|
1835 |
+
out.resize(size);
|
1836 |
+
Ort::ThrowOnError(api_.KernelInfoGetAttributeArray_int64(info, name, out.data(), &size));
|
1837 |
+
} else {
|
1838 |
+
Ort::ThrowOnError(status);
|
1839 |
+
}
|
1840 |
+
return out;
|
1841 |
+
}
|
1842 |
+
inline OrtTensorTypeAndShapeInfo* CustomOpApi::GetTensorTypeAndShape(_In_ const OrtValue* value) {
|
1843 |
+
OrtTensorTypeAndShapeInfo* out;
|
1844 |
+
Ort::ThrowOnError(api_.GetTensorTypeAndShape(value, &out));
|
1845 |
+
return out;
|
1846 |
+
}
|
1847 |
+
|
1848 |
+
inline size_t CustomOpApi::GetTensorShapeElementCount(_In_ const OrtTensorTypeAndShapeInfo* info) {
|
1849 |
+
size_t out;
|
1850 |
+
Ort::ThrowOnError(api_.GetTensorShapeElementCount(info, &out));
|
1851 |
+
return out;
|
1852 |
+
}
|
1853 |
+
|
1854 |
+
inline ONNXTensorElementDataType CustomOpApi::GetTensorElementType(const OrtTensorTypeAndShapeInfo* info) {
|
1855 |
+
ONNXTensorElementDataType out;
|
1856 |
+
Ort::ThrowOnError(api_.GetTensorElementType(info, &out));
|
1857 |
+
return out;
|
1858 |
+
}
|
1859 |
+
|
1860 |
+
inline size_t CustomOpApi::GetDimensionsCount(_In_ const OrtTensorTypeAndShapeInfo* info) {
|
1861 |
+
size_t out;
|
1862 |
+
Ort::ThrowOnError(api_.GetDimensionsCount(info, &out));
|
1863 |
+
return out;
|
1864 |
+
}
|
1865 |
+
|
1866 |
+
inline void CustomOpApi::GetDimensions(_In_ const OrtTensorTypeAndShapeInfo* info, _Out_ int64_t* dim_values, size_t dim_values_length) {
|
1867 |
+
Ort::ThrowOnError(api_.GetDimensions(info, dim_values, dim_values_length));
|
1868 |
+
}
|
1869 |
+
|
1870 |
+
inline void CustomOpApi::SetDimensions(OrtTensorTypeAndShapeInfo* info, _In_ const int64_t* dim_values, size_t dim_count) {
|
1871 |
+
Ort::ThrowOnError(api_.SetDimensions(info, dim_values, dim_count));
|
1872 |
+
}
|
1873 |
+
|
1874 |
+
template <typename T>
|
1875 |
+
inline T* CustomOpApi::GetTensorMutableData(_Inout_ OrtValue* value) {
|
1876 |
+
T* data;
|
1877 |
+
Ort::ThrowOnError(api_.GetTensorMutableData(value, reinterpret_cast<void**>(&data)));
|
1878 |
+
return data;
|
1879 |
+
}
|
1880 |
+
|
1881 |
+
inline const OrtMemoryInfo* CustomOpApi::GetTensorMemoryInfo(_In_ const OrtValue* value) {
|
1882 |
+
const OrtMemoryInfo* mem_info;
|
1883 |
+
Ort::ThrowOnError(api_.GetTensorMemoryInfo(value, &mem_info));
|
1884 |
+
return mem_info;
|
1885 |
+
}
|
1886 |
+
|
1887 |
+
template <typename T>
|
1888 |
+
inline const T* CustomOpApi::GetTensorData(_Inout_ const OrtValue* value) {
|
1889 |
+
T* data = nullptr;
|
1890 |
+
Ort::ThrowOnError(api_.GetTensorMutableData(const_cast<OrtValue*>(value), reinterpret_cast<void**>(&data)));
|
1891 |
+
return data;
|
1892 |
+
}
|
1893 |
+
|
1894 |
+
inline std::vector<int64_t> CustomOpApi::GetTensorShape(const OrtTensorTypeAndShapeInfo* info) {
|
1895 |
+
size_t out;
|
1896 |
+
Ort::ThrowOnError(api_.GetDimensionsCount(info, &out));
|
1897 |
+
std::vector<int64_t> output(out);
|
1898 |
+
Ort::ThrowOnError(api_.GetDimensions(info, output.data(), out));
|
1899 |
+
return output;
|
1900 |
+
}
|
1901 |
+
|
1902 |
+
inline void CustomOpApi::ReleaseTensorTypeAndShapeInfo(OrtTensorTypeAndShapeInfo* input) {
|
1903 |
+
api_.ReleaseTensorTypeAndShapeInfo(input);
|
1904 |
+
}
|
1905 |
+
|
1906 |
+
inline size_t CustomOpApi::KernelContext_GetInputCount(const OrtKernelContext* context) {
|
1907 |
+
size_t out;
|
1908 |
+
Ort::ThrowOnError(api_.KernelContext_GetInputCount(context, &out));
|
1909 |
+
return out;
|
1910 |
+
}
|
1911 |
+
|
1912 |
+
inline const OrtValue* CustomOpApi::KernelContext_GetInput(const OrtKernelContext* context, _In_ size_t index) {
|
1913 |
+
const OrtValue* out;
|
1914 |
+
Ort::ThrowOnError(api_.KernelContext_GetInput(context, index, &out));
|
1915 |
+
return out;
|
1916 |
+
}
|
1917 |
+
|
1918 |
+
inline size_t CustomOpApi::KernelContext_GetOutputCount(const OrtKernelContext* context) {
|
1919 |
+
size_t out;
|
1920 |
+
Ort::ThrowOnError(api_.KernelContext_GetOutputCount(context, &out));
|
1921 |
+
return out;
|
1922 |
+
}
|
1923 |
+
|
1924 |
+
inline OrtValue* CustomOpApi::KernelContext_GetOutput(OrtKernelContext* context, _In_ size_t index,
|
1925 |
+
_In_ const int64_t* dim_values, size_t dim_count) {
|
1926 |
+
OrtValue* out;
|
1927 |
+
Ort::ThrowOnError(api_.KernelContext_GetOutput(context, index, dim_values, dim_count, &out));
|
1928 |
+
return out;
|
1929 |
+
}
|
1930 |
+
|
1931 |
+
inline void* CustomOpApi::KernelContext_GetGPUComputeStream(const OrtKernelContext* context) {
|
1932 |
+
void* out;
|
1933 |
+
Ort::ThrowOnError(api_.KernelContext_GetGPUComputeStream(context, &out));
|
1934 |
+
return out;
|
1935 |
+
}
|
1936 |
+
|
1937 |
+
inline OrtOpAttr* CustomOpApi::CreateOpAttr(_In_ const char* name,
|
1938 |
+
_In_ const void* data,
|
1939 |
+
_In_ int len,
|
1940 |
+
_In_ OrtOpAttrType type) {
|
1941 |
+
OrtOpAttr* op_attr{};
|
1942 |
+
Ort::ThrowOnError(api_.CreateOpAttr(name, data, len, type, &op_attr));
|
1943 |
+
return op_attr;
|
1944 |
+
}
|
1945 |
+
|
1946 |
+
inline void CustomOpApi::ReleaseOpAttr(_Frees_ptr_opt_ OrtOpAttr* op_attr) {
|
1947 |
+
api_.ReleaseOpAttr(op_attr);
|
1948 |
+
}
|
1949 |
+
|
1950 |
+
inline OrtOp* CustomOpApi::CreateOp(_In_ const OrtKernelInfo* info,
|
1951 |
+
_In_z_ const char* op_name,
|
1952 |
+
_In_z_ const char* domain,
|
1953 |
+
int version,
|
1954 |
+
_In_reads_(type_constraint_count) const char** type_constraint_names,
|
1955 |
+
_In_reads_(type_constraint_count) const ONNXTensorElementDataType* type_constraint_values,
|
1956 |
+
int type_constraint_count,
|
1957 |
+
_In_reads_(attr_count) const OrtOpAttr* const* attr_values,
|
1958 |
+
int attr_count,
|
1959 |
+
int input_count,
|
1960 |
+
int output_count) {
|
1961 |
+
OrtOp* ort_op{};
|
1962 |
+
Ort::ThrowOnError(api_.CreateOp(info, op_name, domain, version, type_constraint_names, type_constraint_values,
|
1963 |
+
type_constraint_count, attr_values, attr_count, input_count, output_count, &ort_op));
|
1964 |
+
return ort_op;
|
1965 |
+
}
|
1966 |
+
|
1967 |
+
inline void CustomOpApi::InvokeOp(_In_ const OrtKernelContext* context,
|
1968 |
+
_In_ const OrtOp* ort_op,
|
1969 |
+
_In_ const OrtValue* const* input_values,
|
1970 |
+
_In_ int input_count,
|
1971 |
+
_Inout_ OrtValue* const* output_values,
|
1972 |
+
_In_ int output_count) {
|
1973 |
+
Ort::ThrowOnError(api_.InvokeOp(context, ort_op, input_values, input_count, output_values, output_count));
|
1974 |
+
}
|
1975 |
+
|
1976 |
+
inline void CustomOpApi::ReleaseOp(_Frees_ptr_opt_ OrtOp* ort_op) {
|
1977 |
+
api_.ReleaseOp(ort_op);
|
1978 |
+
}
|
1979 |
+
|
1980 |
+
inline OrtKernelInfo* CustomOpApi::CopyKernelInfo(_In_ const OrtKernelInfo* info) {
|
1981 |
+
OrtKernelInfo* info_copy{};
|
1982 |
+
Ort::ThrowOnError(api_.CopyKernelInfo(info, &info_copy));
|
1983 |
+
return info_copy;
|
1984 |
+
}
|
1985 |
+
|
1986 |
+
inline void CustomOpApi::ReleaseKernelInfo(_Frees_ptr_opt_ OrtKernelInfo* info_copy) {
|
1987 |
+
api_.ReleaseKernelInfo(info_copy);
|
1988 |
+
}
|
1989 |
+
|
1990 |
+
inline std::string GetVersionString() {
|
1991 |
+
return OrtGetApiBase()->GetVersionString();
|
1992 |
+
}
|
1993 |
+
|
1994 |
+
inline std::string GetBuildInfoString() {
|
1995 |
+
return GetApi().GetBuildInfoString();
|
1996 |
+
}
|
1997 |
+
|
1998 |
+
inline std::vector<std::string> GetAvailableProviders() {
|
1999 |
+
char** providers;
|
2000 |
+
int len;
|
2001 |
+
|
2002 |
+
auto release_fn = [&len](char** providers) {
|
2003 |
+
// This should always return nullptr.
|
2004 |
+
ThrowOnError(GetApi().ReleaseAvailableProviders(providers, len));
|
2005 |
+
};
|
2006 |
+
|
2007 |
+
ThrowOnError(GetApi().GetAvailableProviders(&providers, &len));
|
2008 |
+
std::unique_ptr<char*, decltype(release_fn)> guard(providers, release_fn);
|
2009 |
+
std::vector<std::string> available_providers;
|
2010 |
+
available_providers.reserve(static_cast<size_t>(len));
|
2011 |
+
for (int i = 0; i < len; ++i) {
|
2012 |
+
available_providers.emplace_back(providers[i]);
|
2013 |
+
}
|
2014 |
+
return available_providers;
|
2015 |
+
}
|
2016 |
+
|
2017 |
+
template <typename TOp, typename TKernel>
|
2018 |
+
void CustomOpBase<TOp, TKernel>::GetSessionConfigs(std::unordered_map<std::string, std::string>& out,
|
2019 |
+
ConstSessionOptions options) const {
|
2020 |
+
const TOp* derived = static_cast<const TOp*>(this);
|
2021 |
+
std::vector<std::string> keys = derived->GetSessionConfigKeys();
|
2022 |
+
|
2023 |
+
out.reserve(keys.size());
|
2024 |
+
|
2025 |
+
std::string config_entry_key = detail::MakeCustomOpConfigEntryKey(derived->GetName(), "");
|
2026 |
+
const size_t prefix_size = config_entry_key.length();
|
2027 |
+
|
2028 |
+
for (const auto& key : keys) {
|
2029 |
+
config_entry_key.resize(prefix_size);
|
2030 |
+
config_entry_key.append(key);
|
2031 |
+
out[key] = options.GetConfigEntryOrDefault(config_entry_key.c_str(), "");
|
2032 |
+
}
|
2033 |
+
}
|
2034 |
+
|
2035 |
+
} // namespace Ort
|
v1.15.1/onnxruntime-linux-armhf/include/onnxruntime_run_options_config_keys.h
ADDED
@@ -0,0 +1,32 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
// Copyright (c) Microsoft Corporation. All rights reserved.
|
2 |
+
// Licensed under the MIT License.
|
3 |
+
|
4 |
+
#pragma once
|
5 |
+
|
6 |
+
/*
|
7 |
+
* This file defines RunOptions Config Keys and format of the Config Values.
|
8 |
+
*
|
9 |
+
* The Naming Convention for a RunOptions Config Key,
|
10 |
+
* "[Area][.[SubArea1].[SubArea2]...].[Keyname]"
|
11 |
+
* Such as "ep.cuda.use_arena"
|
12 |
+
* The Config Key cannot be empty
|
13 |
+
* The maximum length of the Config Key is 128
|
14 |
+
*
|
15 |
+
* The string format of a RunOptions Config Value is defined individually for each Config.
|
16 |
+
* The maximum length of the Config Value is 1024
|
17 |
+
*/
|
18 |
+
|
19 |
+
// Key for enabling shrinkages of user listed device memory arenas.
|
20 |
+
// Expects a list of semi-colon separated key value pairs separated by colon in the following format:
|
21 |
+
// "device_0:device_id_0;device_1:device_id_1"
|
22 |
+
// No white-spaces allowed in the provided list string.
|
23 |
+
// Currently, the only supported devices are : "cpu", "gpu" (case sensitive).
|
24 |
+
// If "cpu" is included in the list, DisableCpuMemArena() API must not be called (i.e.) arena for cpu should be enabled.
|
25 |
+
// Example usage: "cpu:0;gpu:0" (or) "gpu:0"
|
26 |
+
// By default, the value for this key is empty (i.e.) no memory arenas are shrunk
|
27 |
+
static const char* const kOrtRunOptionsConfigEnableMemoryArenaShrinkage = "memory.enable_memory_arena_shrinkage";
|
28 |
+
|
29 |
+
// Set to '1' to not synchronize execution providers with CPU at the end of session run.
|
30 |
+
// Per default it will be set to '0'
|
31 |
+
// Taking CUDA EP as an example, it omit triggering cudaStreamSynchronize on the compute stream.
|
32 |
+
static const char* const kOrtRunOptionsConfigDisableSynchronizeExecutionProviders = "disable_synchronize_execution_providers";
|
v1.15.1/onnxruntime-linux-armhf/include/onnxruntime_session_options_config_keys.h
ADDED
@@ -0,0 +1,199 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
// Copyright (c) Microsoft Corporation. All rights reserved.
|
2 |
+
// Licensed under the MIT License.
|
3 |
+
|
4 |
+
#pragma once
|
5 |
+
|
6 |
+
/*
|
7 |
+
* This file defines SessionOptions Config Keys and format of the Config Values.
|
8 |
+
*
|
9 |
+
* The Naming Convention for a SessionOptions Config Key,
|
10 |
+
* "[Area][.[SubArea1].[SubArea2]...].[Keyname]"
|
11 |
+
* Such as "ep.cuda.use_arena"
|
12 |
+
* The Config Key cannot be empty
|
13 |
+
* The maximum length of the Config Key is 128
|
14 |
+
*
|
15 |
+
* The string format of a SessionOptions Config Value is defined individually for each Config.
|
16 |
+
* The maximum length of the Config Value is 1024
|
17 |
+
*/
|
18 |
+
|
19 |
+
// Key for disable PrePacking,
|
20 |
+
// If the config value is set to "1" then the prepacking is disabled, otherwise prepacking is enabled (default value)
|
21 |
+
static const char* const kOrtSessionOptionsConfigDisablePrepacking = "session.disable_prepacking";
|
22 |
+
|
23 |
+
// A value of "1" means allocators registered in the env will be used. "0" means the allocators created in the session
|
24 |
+
// will be used. Use this to override the usage of env allocators on a per session level.
|
25 |
+
static const char* const kOrtSessionOptionsConfigUseEnvAllocators = "session.use_env_allocators";
|
26 |
+
|
27 |
+
// Set to 'ORT' (case sensitive) to load an ORT format model.
|
28 |
+
// If unset, model type will default to ONNX unless inferred from filename ('.ort' == ORT format) or bytes to be ORT
|
29 |
+
static const char* const kOrtSessionOptionsConfigLoadModelFormat = "session.load_model_format";
|
30 |
+
|
31 |
+
// Set to 'ORT' (case sensitive) to save optimized model in ORT format when SessionOptions.optimized_model_path is set.
|
32 |
+
// If unset, format will default to ONNX unless optimized_model_filepath ends in '.ort'.
|
33 |
+
static const char* const kOrtSessionOptionsConfigSaveModelFormat = "session.save_model_format";
|
34 |
+
|
35 |
+
// If a value is "1", flush-to-zero and denormal-as-zero are applied. The default is "0".
|
36 |
+
// When multiple sessions are created, a main thread doesn't override changes from succeeding session options,
|
37 |
+
// but threads in session thread pools follow option changes.
|
38 |
+
// When ORT runs with OpenMP, the same rule is applied, i.e. the first session option to flush-to-zero and
|
39 |
+
// denormal-as-zero is only applied to global OpenMP thread pool, which doesn't support per-session thread pool.
|
40 |
+
// Note that an alternative way not using this option at runtime is to train and export a model without denormals
|
41 |
+
// and that's recommended because turning this option on may hurt model accuracy.
|
42 |
+
static const char* const kOrtSessionOptionsConfigSetDenormalAsZero = "session.set_denormal_as_zero";
|
43 |
+
|
44 |
+
// It controls to run quantization model in QDQ (QuantizelinearDeQuantizelinear) format or not.
|
45 |
+
// "0": enable. ORT does fusion logic for QDQ format.
|
46 |
+
// "1": disable. ORT doesn't do fusion logic for QDQ format.
|
47 |
+
// Its default value is "0" unless the DirectML execution provider is registered, in which case it defaults to "1".
|
48 |
+
static const char* const kOrtSessionOptionsDisableQuantQDQ = "session.disable_quant_qdq";
|
49 |
+
|
50 |
+
// It controls whether to enable Double QDQ remover and Identical Children Consolidation
|
51 |
+
// "0": not to disable. ORT does remove the middle 2 Nodes from a Q->(QD->Q)->QD pairs
|
52 |
+
// "1": disable. ORT doesn't remove the middle 2 Nodes from a Q->(QD->Q)->QD pairs
|
53 |
+
// Its default value is "0"
|
54 |
+
static const char* const kOrtSessionOptionsDisableDoubleQDQRemover = "session.disable_double_qdq_remover";
|
55 |
+
|
56 |
+
// If set to "1", enables the removal of QuantizeLinear/DequantizeLinear node pairs once all QDQ handling has been
|
57 |
+
// completed. e.g. If after all QDQ handling has completed and we have -> FloatOp -> Q -> DQ -> FloatOp -> the
|
58 |
+
// Q -> DQ could potentially be removed. This will provide a performance benefit by avoiding going from float to
|
59 |
+
// 8-bit and back to float, but could impact accuracy. The impact on accuracy will be model specific and depend on
|
60 |
+
// other factors like whether the model was created using Quantization Aware Training or Post Training Quantization.
|
61 |
+
// As such, it's best to test to determine if enabling this works well for your scenario.
|
62 |
+
// The default value is "0"
|
63 |
+
// Available since version 1.11.
|
64 |
+
static const char* const kOrtSessionOptionsEnableQuantQDQCleanup = "session.enable_quant_qdq_cleanup";
|
65 |
+
|
66 |
+
// Enable or disable gelu approximation in graph optimization. "0": disable; "1": enable. The default is "0".
|
67 |
+
// GeluApproximation has side effects which may change the inference results. It is disabled by default due to this.
|
68 |
+
static const char* const kOrtSessionOptionsEnableGeluApproximation = "optimization.enable_gelu_approximation";
|
69 |
+
|
70 |
+
#ifdef ENABLE_TRAINING
|
71 |
+
// Specifies a list of op types for memory footprint reduction.
|
72 |
+
// The value should be a ","-delimited list of pair of
|
73 |
+
// <subgraph string : optimization strategy : number of subgraph to apply>.
|
74 |
+
// For example, "Gelu+Cast+:1:0,Dropout+:1:1".
|
75 |
+
// A valid "subgraph string" should be one subgraph representation output by ORT graph transformations.
|
76 |
+
// "optimization strategy" currently has valid values: 0 - disabled, 1 - recompute.
|
77 |
+
// "number of subgraph to apply" is used to control how many subgraphs to apply optimization, to avoid "oversaving"
|
78 |
+
// the memory.
|
79 |
+
static const char* const kOrtSessionOptionsMemoryOptimizerEnabler = "optimization.enable_memory_optimizer";
|
80 |
+
|
81 |
+
// Specifies the level for detecting subgraphs for memory footprint reduction.
|
82 |
+
// The value should be an integer. The default value is 0.
|
83 |
+
static const char* const kOrtSessionOptionsMemoryOptimizerProbeLevel = "optimization.enable_memory_probe_recompute_level";
|
84 |
+
#endif
|
85 |
+
|
86 |
+
// Enable or disable using device allocator for allocating initialized tensor memory. "1": enable; "0": disable. The default is "0".
|
87 |
+
// Using device allocators means the memory allocation is made using malloc/new.
|
88 |
+
static const char* const kOrtSessionOptionsUseDeviceAllocatorForInitializers = "session.use_device_allocator_for_initializers";
|
89 |
+
|
90 |
+
// Configure whether to allow the inter_op/intra_op threads spinning a number of times before blocking
|
91 |
+
// "0": thread will block if found no job to run
|
92 |
+
// "1": default, thread will spin a number of times before blocking
|
93 |
+
static const char* const kOrtSessionOptionsConfigAllowInterOpSpinning = "session.inter_op.allow_spinning";
|
94 |
+
static const char* const kOrtSessionOptionsConfigAllowIntraOpSpinning = "session.intra_op.allow_spinning";
|
95 |
+
|
96 |
+
// Key for using model bytes directly for ORT format
|
97 |
+
// If a session is created using an input byte array contains the ORT format model data,
|
98 |
+
// By default we will copy the model bytes at the time of session creation to ensure the model bytes
|
99 |
+
// buffer is valid.
|
100 |
+
// Setting this option to "1" will disable copy the model bytes, and use the model bytes directly. The caller
|
101 |
+
// has to guarantee that the model bytes are valid until the ORT session using the model bytes is destroyed.
|
102 |
+
static const char* const kOrtSessionOptionsConfigUseORTModelBytesDirectly = "session.use_ort_model_bytes_directly";
|
103 |
+
|
104 |
+
/// <summary>
|
105 |
+
/// Key for using the ORT format model flatbuffer bytes directly for initializers.
|
106 |
+
/// This avoids copying the bytes and reduces peak memory usage during model loading and initialization.
|
107 |
+
/// Requires `session.use_ort_model_bytes_directly` to be true.
|
108 |
+
/// If set, the flatbuffer bytes provided when creating the InferenceSession MUST remain valid for the entire
|
109 |
+
/// duration of the InferenceSession.
|
110 |
+
/// </summary>
|
111 |
+
static const char* const kOrtSessionOptionsConfigUseORTModelBytesForInitializers =
|
112 |
+
"session.use_ort_model_bytes_for_initializers";
|
113 |
+
|
114 |
+
// This should only be specified when exporting an ORT format model for use on a different platform.
|
115 |
+
// If the ORT format model will be used on ARM platforms set to "1". For other platforms set to "0"
|
116 |
+
// Available since version 1.11.
|
117 |
+
static const char* const kOrtSessionOptionsQDQIsInt8Allowed = "session.qdqisint8allowed";
|
118 |
+
|
119 |
+
// x64 SSE4.1/AVX2/AVX512(with no VNNI) has overflow problem with quantizied matrix multiplication with U8S8.
|
120 |
+
// To avoid this we need to use slower U8U8 matrix multiplication instead. This option, if
|
121 |
+
// turned on, use slower U8U8 matrix multiplications. Only effective with AVX2 or AVX512
|
122 |
+
// platforms.
|
123 |
+
static const char* const kOrtSessionOptionsAvx2PrecisionMode = "session.x64quantprecision";
|
124 |
+
|
125 |
+
// Specifies how minimal build graph optimizations are handled in a full build.
|
126 |
+
// These optimizations are at the extended level or higher.
|
127 |
+
// Possible values and their effects are:
|
128 |
+
// "save": Save runtime optimizations when saving an ORT format model.
|
129 |
+
// "apply": Only apply optimizations available in a minimal build.
|
130 |
+
// ""/<unspecified>: Apply optimizations available in a full build.
|
131 |
+
// Available since version 1.11.
|
132 |
+
static const char* const kOrtSessionOptionsConfigMinimalBuildOptimizations =
|
133 |
+
"optimization.minimal_build_optimizations";
|
134 |
+
|
135 |
+
// Note: The options specific to an EP should be specified prior to appending that EP to the session options object in
|
136 |
+
// order for them to take effect.
|
137 |
+
|
138 |
+
// Specifies a list of stop op types. Nodes of a type in the stop op types and nodes downstream from them will not be
|
139 |
+
// run by the NNAPI EP.
|
140 |
+
// The value should be a ","-delimited list of op types. For example, "Add,Sub".
|
141 |
+
// If not specified, the default set of stop ops is used. To specify an empty stop ops types list and disable stop op
|
142 |
+
// exclusion, set the value to "".
|
143 |
+
static const char* const kOrtSessionOptionsConfigNnapiEpPartitioningStopOps = "ep.nnapi.partitioning_stop_ops";
|
144 |
+
|
145 |
+
// Enabling dynamic block-sizing for multithreading.
|
146 |
+
// With a positive value, thread pool will split a task of N iterations to blocks of size starting from:
|
147 |
+
// N / (num_of_threads * dynamic_block_base)
|
148 |
+
// As execution progresses, the size will decrease according to the diminishing residual of N,
|
149 |
+
// meaning the task will be distributed in smaller granularity for better parallelism.
|
150 |
+
// For some models, it helps to reduce the variance of E2E inference latency and boost performance.
|
151 |
+
// The feature will not function by default, specify any positive integer, e.g. "4", to enable it.
|
152 |
+
// Available since version 1.11.
|
153 |
+
static const char* const kOrtSessionOptionsConfigDynamicBlockBase = "session.dynamic_block_base";
|
154 |
+
|
155 |
+
// This option allows to decrease CPU usage between infrequent
|
156 |
+
// requests and forces any TP threads spinning stop immediately when the last of
|
157 |
+
// concurrent Run() call returns.
|
158 |
+
// Spinning is restarted on the next Run() call.
|
159 |
+
// Applies only to internal thread-pools
|
160 |
+
static const char* const kOrtSessionOptionsConfigForceSpinningStop = "session.force_spinning_stop";
|
161 |
+
|
162 |
+
// "1": all inconsistencies encountered during shape and type inference
|
163 |
+
// will result in failures.
|
164 |
+
// "0": in some cases warnings will be logged but processing will continue. The default.
|
165 |
+
// May be useful to expose bugs in models.
|
166 |
+
static const char* const kOrtSessionOptionsConfigStrictShapeTypeInference = "session.strict_shape_type_inference";
|
167 |
+
|
168 |
+
// The file saves configuration for partitioning node among logic streams
|
169 |
+
static const char* const kNodePartitionConfigFile = "session.node_partition_config_file";
|
170 |
+
|
171 |
+
// This Option allows setting affinities for intra op threads.
|
172 |
+
// Affinity string follows format:
|
173 |
+
// logical_processor_id,logical_processor_id;logical_processor_id,logical_processor_id
|
174 |
+
// Semicolon isolates configurations among threads, while comma split processors where ith thread expected to attach to.
|
175 |
+
// e.g.1,2,3;4,5
|
176 |
+
// specifies affinities for two threads, with the 1st thread attach to the 1st, 2nd, and 3rd processor, and 2nd thread to the 4th and 5th.
|
177 |
+
// To ease the configuration, an "interval" is also allowed:
|
178 |
+
// e.g. 1-8;8-16;17-24
|
179 |
+
// orders that the 1st thread runs on first eight processors, 2nd thread runs on next eight processors, and so forth.
|
180 |
+
// Note:
|
181 |
+
// 1. Once set, the number of thread affinities must equal to intra_op_num_threads - 1, since ort does not set affinity on the main thread which
|
182 |
+
// is started and managed by the calling app;
|
183 |
+
// 2. For windows, ort will infer the group id from a logical processor id, for example, assuming there are two groups with each has 64 logical processors,
|
184 |
+
// an id of 64 will be inferred as the last processor of the 1st group, while 65 will be interpreted as the 1st processor of the second group.
|
185 |
+
// Hence 64-65 is an invalid configuration, because a windows thread cannot be attached to processors across group boundary.
|
186 |
+
static const char* const kOrtSessionOptionsConfigIntraOpThreadAffinities = "session.intra_op_thread_affinities";
|
187 |
+
|
188 |
+
// This option will dump out the model to assist debugging any issues with layout transformation,
|
189 |
+
// and is primarily intended for developer usage. It is only relevant if an execution provider that requests
|
190 |
+
// NHWC layout is enabled such as NNAPI, XNNPACK or QNN.
|
191 |
+
//
|
192 |
+
// Default is off. Set to "1" to enable.
|
193 |
+
//
|
194 |
+
// If modified by layout transformation the model will be dumped after these steps:
|
195 |
+
// 1) insertion of the layout transformation Transpose nodes
|
196 |
+
// 2) after those are optimized using the transpose optimizer,
|
197 |
+
// 3) after the L1 transformers are applied to the updated graph.
|
198 |
+
// The model will be saved to filename post_layout_transform_step_<step_number>.onnx.
|
199 |
+
static const char* const kDebugLayoutTransformation = "session.debug_layout_transformation";
|
v1.15.1/onnxruntime-linux-armhf/include/onnxruntime_training_c_api.h
ADDED
@@ -0,0 +1,630 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
// Copyright (c) Microsoft Corporation. All rights reserved.
|
2 |
+
// Licensed under the MIT License.
|
3 |
+
|
4 |
+
// This file contains the training c apis.
|
5 |
+
|
6 |
+
#pragma once
|
7 |
+
#include <stdbool.h>
|
8 |
+
#include "onnxruntime_c_api.h"
|
9 |
+
|
10 |
+
/** \page training_c_cpp_api Training C & C++ APIs
|
11 |
+
*
|
12 |
+
* Training C and C++ APIs are an extension of the \ref c_cpp_api "onnxruntime core C and C++ APIs" and should be used in conjunction with them.
|
13 |
+
*
|
14 |
+
* In order to train a model with onnxruntime, the following training artifacts must be generated:
|
15 |
+
* - The training onnx model
|
16 |
+
* - The checkpoint directory
|
17 |
+
* - The optimizer onnx model
|
18 |
+
* - The eval onnx model model (optional)
|
19 |
+
*
|
20 |
+
* These training artifacts can be generated as part of an offline step using the python [utilities](https://github.com/microsoft/onnxruntime/blob/main/orttraining/orttraining/python/training/onnxblock/README.md) made available in the `onnxruntime-training` python package.
|
21 |
+
*
|
22 |
+
* After these artifacts have been generated, the C and C++ utilities listed in this documentation can be leveraged to perform training.
|
23 |
+
*
|
24 |
+
* If any problem is encountered, please create an [issue](https://github.com/microsoft/onnxruntime/issues/new) with your scenario and requirements, and we will be sure to respond and follow up on the request.
|
25 |
+
*
|
26 |
+
* <h1>Training C API</h1>
|
27 |
+
*
|
28 |
+
* ::OrtTrainingApi - Training C API functions.
|
29 |
+
*
|
30 |
+
* This C structure contains functions that enable users to perform training with onnxruntime.
|
31 |
+
*
|
32 |
+
* _Sample Code_:
|
33 |
+
*
|
34 |
+
* ```c
|
35 |
+
* #include <onnxruntime_training_api.h>
|
36 |
+
*
|
37 |
+
* OrtApi* g_ort_api = OrtGetApiBase()->GetApi(ORT_API_VERSION);
|
38 |
+
* OrtTrainingApi* g_ort_training_api = g_ort_api->GetTrainingApi(ORT_API_VERSION);
|
39 |
+
*
|
40 |
+
* OrtEnv* env = NULL;
|
41 |
+
* g_ort_api->CreateEnv(logging_level, logid, &env);
|
42 |
+
* OrtSessionOptions* session_options = NULL;
|
43 |
+
* g_ort_api->CreateSessionOptions(&session_options);
|
44 |
+
*
|
45 |
+
* OrtCheckpointState* state = NULL;
|
46 |
+
* g_ort_training_api->LoadCheckpoint(path_to_checkpoint, &state);
|
47 |
+
*
|
48 |
+
* OrtTrainingSession* training_session = NULL;
|
49 |
+
* g_ort_training_api->CreateTrainingSession(env, session_options, training_model_path,
|
50 |
+
* state, eval_model_path, optimizer_model_path,
|
51 |
+
* &training_session);
|
52 |
+
* // Training loop
|
53 |
+
* {
|
54 |
+
* g_ort_training_api->TrainStep(...);
|
55 |
+
* g_ort_training_api->OptimizerStep(...);
|
56 |
+
* g_ort_training_api->LazyResetGrad(...);
|
57 |
+
* }
|
58 |
+
*
|
59 |
+
* g_ort_training_api->ExportModelForInferencing(training_session, inference_model_path, ...);
|
60 |
+
* g_ort_training_api->SaveCheckpoint(state, path_to_checkpoint, false);
|
61 |
+
*
|
62 |
+
* g_ort_training_api->ReleaseTrainingSession(training_session);
|
63 |
+
* g_ort_training_api->ReleaseCheckpointState(state);
|
64 |
+
* ```
|
65 |
+
*
|
66 |
+
* > **Note**
|
67 |
+
* > The ::OrtCheckpointState contains the entire training state that the ::OrtTrainingSession uses. As a result, the training session must always have access to the state. That is to say, the ::OrtCheckpointState instance must outlive the lifetime of the ::OrtTrainingSession instance.
|
68 |
+
*
|
69 |
+
* <h1>Training C++ API</h1>
|
70 |
+
*
|
71 |
+
* @ref TrainingCpp - Training C++ API classes and functions.
|
72 |
+
*
|
73 |
+
* These C++ classes and functions enable users to perform training with onnxruntime.
|
74 |
+
*
|
75 |
+
* _Sample Code_:
|
76 |
+
*
|
77 |
+
* ```cc
|
78 |
+
* #include <onnxruntime_training_cxx_api.h>
|
79 |
+
*
|
80 |
+
* Ort::Env env;
|
81 |
+
* Ort::SessionOptions session_options;
|
82 |
+
*
|
83 |
+
* auto state = Ort::CheckpointState::LoadCheckpoint(path_to_checkpoint);
|
84 |
+
* auto training_session = Ort::TrainingSession(env, session_options, state, training_model_path,
|
85 |
+
* eval_model_path, optimizer_model_path);
|
86 |
+
*
|
87 |
+
* // Training Loop
|
88 |
+
* {
|
89 |
+
* training_session.TrainStep(...);
|
90 |
+
* training_session.OptimizerStep(...);
|
91 |
+
* training_session.LazyResetGrad(...);
|
92 |
+
* }
|
93 |
+
*
|
94 |
+
* training_session->ExportModelForInferencing(inference_model_path, ...);
|
95 |
+
* Ort::CheckpointState::SaveCheckpoint(state, path_to_checkpoint, false);
|
96 |
+
* ```
|
97 |
+
* > **Note**
|
98 |
+
* > The ::Ort::CheckpointState contains the entire training state that the ::Ort::TrainingSession uses. As a result, the training session must always have access to the state. That is to say, the ::Ort::CheckpointState instance must outlive the lifetime of the ::Ort::TrainingSession instance.
|
99 |
+
*/
|
100 |
+
|
101 |
+
/** @defgroup TrainingC Ort Training C API
|
102 |
+
* @{
|
103 |
+
*/
|
104 |
+
ORT_RUNTIME_CLASS(TrainingSession); // Type that enables performing training for the given user models.
|
105 |
+
ORT_RUNTIME_CLASS(CheckpointState); // Type that holds the training states for the training session.
|
106 |
+
|
107 |
+
/** \brief Type of property to be added to or returned from the ::OrtCheckpointState.
|
108 |
+
*/
|
109 |
+
typedef enum OrtPropertyType {
|
110 |
+
OrtIntProperty = 0,
|
111 |
+
OrtFloatProperty = 1,
|
112 |
+
OrtStringProperty = 2,
|
113 |
+
} OrtPropertyType;
|
114 |
+
|
115 |
+
/** \brief The Training C API that holds onnxruntime training function pointers
|
116 |
+
*
|
117 |
+
* All the Training C API functions are defined inside this structure as pointers to functions.
|
118 |
+
* Call OrtApi::GetTrainingApi to get a pointer to this struct.
|
119 |
+
*
|
120 |
+
* \nosubgrouping
|
121 |
+
*/
|
122 |
+
struct OrtTrainingApi {
|
123 |
+
/// \name Accessing The Training Session State
|
124 |
+
/// @{
|
125 |
+
|
126 |
+
/** \brief Load a checkpoint state from directory on disk into checkpoint_state.
|
127 |
+
*
|
128 |
+
* This function will parse a checkpoint directory, pull relevant files and load the training
|
129 |
+
* state into the checkpoint_state. This checkpoint state can then be used to create the
|
130 |
+
* training session by invoking OrtTrainingApi::CreateTrainingSession. By doing so, the training
|
131 |
+
* session will resume training from the given checkpoint state.
|
132 |
+
* \note Note that the training session created with a checkpoint state uses this state to store the entire
|
133 |
+
* training state (including model parameters, its gradients, the optimizer states and the properties).
|
134 |
+
* As a result, it is required that the checkpoint state outlive the lifetime of the training session.
|
135 |
+
*
|
136 |
+
* \param[in] checkpoint_path Path to the checkpoint directory
|
137 |
+
* \param[out] checkpoint_state Checkpoint state that contains the states of the training session.
|
138 |
+
*
|
139 |
+
* \snippet{doc} snippets.dox OrtStatus Return Value
|
140 |
+
*
|
141 |
+
*/
|
142 |
+
ORT_API2_STATUS(LoadCheckpoint, _In_ const ORTCHAR_T* checkpoint_path,
|
143 |
+
_Outptr_ OrtCheckpointState** checkpoint_state);
|
144 |
+
|
145 |
+
/** \brief Save the given state to a checkpoint directory on disk.
|
146 |
+
*
|
147 |
+
* This function serializes the provided checkpoint state to a directory on disk.
|
148 |
+
* This checkpoint can later be loaded by invoking OrtTrainingApi::LoadCheckpoint to resume
|
149 |
+
* training from this snapshot of the state.
|
150 |
+
*
|
151 |
+
* \param[in] checkpoint_state The checkpoint state to save.
|
152 |
+
* \param[in] checkpoint_path Path to the checkpoint directory.
|
153 |
+
* \param[in] include_optimizer_state Flag to indicate whether to save the optimizer state or not.
|
154 |
+
*
|
155 |
+
* \snippet{doc} snippets.dox OrtStatus Return Value
|
156 |
+
*
|
157 |
+
*/
|
158 |
+
ORT_API2_STATUS(SaveCheckpoint, _In_ OrtCheckpointState* checkpoint_state, _In_ const ORTCHAR_T* checkpoint_path,
|
159 |
+
const bool include_optimizer_state);
|
160 |
+
|
161 |
+
/// @}
|
162 |
+
|
163 |
+
/// \name Implementing The Training Loop
|
164 |
+
/// @{
|
165 |
+
/** \brief Create a training session that can be used to begin or resume training.
|
166 |
+
*
|
167 |
+
* This function creates a training session based on the env and session options provided that can
|
168 |
+
* begin or resume training from a given checkpoint state for the given onnx models.
|
169 |
+
* The checkpoint state represents the parameters of the training session which will be moved
|
170 |
+
* to the device specified by the user through the session options (if necessary).
|
171 |
+
* The training session requires four training artifacts
|
172 |
+
* - The training onnx model
|
173 |
+
* - The evaluation onnx model (optional)
|
174 |
+
* - The optimizer onnx model
|
175 |
+
* - The checkpoint directory
|
176 |
+
*
|
177 |
+
* These artifacts can be generated using the `onnxruntime-training` python [utility](https://github.com/microsoft/onnxruntime/blob/main/orttraining/orttraining/python/training/onnxblock/README.md).
|
178 |
+
*
|
179 |
+
* \param[in] env Environment to be used for the training session.
|
180 |
+
* \param[in] options Session options that the user can customize for this training session.
|
181 |
+
* \param[in] checkpoint_state Training states that the training session uses as a starting point for training.
|
182 |
+
* \param[in] train_model_path Model to be used to perform training.
|
183 |
+
* \param[in] eval_model_path Model to be used to perform evaluation.
|
184 |
+
* \param[in] optimizer_model_path Model to be used to perform gradient descent.
|
185 |
+
* \param[out] out Created training session.
|
186 |
+
*
|
187 |
+
* \snippet{doc} snippets.dox OrtStatus Return Value
|
188 |
+
*
|
189 |
+
*/
|
190 |
+
ORT_API2_STATUS(CreateTrainingSession, _In_ const OrtEnv* env, _In_ const OrtSessionOptions* options,
|
191 |
+
_Inout_ OrtCheckpointState* checkpoint_state, _In_ const ORTCHAR_T* train_model_path,
|
192 |
+
_In_ const ORTCHAR_T* eval_model_path, _In_ const ORTCHAR_T* optimizer_model_path,
|
193 |
+
_Outptr_ OrtTrainingSession** out);
|
194 |
+
|
195 |
+
/// @}
|
196 |
+
|
197 |
+
/// \name Model IO Information
|
198 |
+
/// @{
|
199 |
+
|
200 |
+
/** \brief Retrieves the number of user outputs in the training model.
|
201 |
+
*
|
202 |
+
* This function returns the number of outputs of the training model so that the user can
|
203 |
+
* allocate space for the number of outputs when OrtTrainingApi::TrainStep is invoked.
|
204 |
+
*
|
205 |
+
* \param[in] sess The `this` pointer to the training session.
|
206 |
+
* \param[out] out Number of user outputs in the training model.
|
207 |
+
*
|
208 |
+
* \snippet{doc} snippets.dox OrtStatus Return Value
|
209 |
+
*
|
210 |
+
*/
|
211 |
+
ORT_API2_STATUS(TrainingSessionGetTrainingModelOutputCount, _In_ const OrtTrainingSession* sess, _Out_ size_t* out);
|
212 |
+
|
213 |
+
/** \brief Retrieves the number of user outputs in the eval model.
|
214 |
+
*
|
215 |
+
* This function returns the number of outputs of the eval model so that the user can
|
216 |
+
* allocate space for the number of outputs when OrtTrainingApi::EvalStep is invoked.
|
217 |
+
*
|
218 |
+
* \param[in] sess The `this` pointer to the training session.
|
219 |
+
* \param[out] out Number of user outputs in the eval model.
|
220 |
+
*
|
221 |
+
* \snippet{doc} snippets.dox OrtStatus Return Value
|
222 |
+
*
|
223 |
+
*/
|
224 |
+
ORT_API2_STATUS(TrainingSessionGetEvalModelOutputCount, _In_ const OrtTrainingSession* sess, _Out_ size_t* out);
|
225 |
+
|
226 |
+
/** \brief Retrieves the names of user outputs in the training model.
|
227 |
+
*
|
228 |
+
* This function returns the names of outputs of the training model that can be associated with the OrtValue(s)
|
229 |
+
* returned by the OrtTrainingApi::TrainStep function.
|
230 |
+
*
|
231 |
+
* \param[in] sess The `this` pointer to the training session.
|
232 |
+
* \param[in] index Index of the output name requested.
|
233 |
+
* \param[in] allocator Allocator to use to allocate the memory for the name.
|
234 |
+
* \param[out] output Name of the training model output at the given index.
|
235 |
+
*
|
236 |
+
* \snippet{doc} snippets.dox OrtStatus Return Value
|
237 |
+
*
|
238 |
+
*/
|
239 |
+
ORT_API2_STATUS(TrainingSessionGetTrainingModelOutputName, _In_ const OrtTrainingSession* sess, size_t index, _Inout_ OrtAllocator* allocator, _Outptr_ char** output);
|
240 |
+
|
241 |
+
/** \brief Retrieves the names of user outputs in the eval model.
|
242 |
+
*
|
243 |
+
* This function returns the names of outputs of the eval model that can be associated with the OrtValue(s) returned
|
244 |
+
* by the OrtTrainingApi::EvalStep function.
|
245 |
+
*
|
246 |
+
* \param[in] sess The `this` pointer to the training session.
|
247 |
+
* \param[in] index Index of the output name requested.
|
248 |
+
* \param[in] allocator Allocator to use to allocate the memory for the name.
|
249 |
+
* \param[out] output Name of the eval model output at the given index.
|
250 |
+
*
|
251 |
+
* \snippet{doc} snippets.dox OrtStatus Return Value
|
252 |
+
*
|
253 |
+
*/
|
254 |
+
ORT_API2_STATUS(TrainingSessionGetEvalModelOutputName, _In_ const OrtTrainingSession* sess, size_t index, _Inout_ OrtAllocator* allocator, _Outptr_ char** output);
|
255 |
+
|
256 |
+
/// @}
|
257 |
+
|
258 |
+
/// \name Implementing The Training Loop
|
259 |
+
/// @{
|
260 |
+
|
261 |
+
/** \brief Reset the gradients of all trainable parameters to zero lazily.
|
262 |
+
*
|
263 |
+
* This function sets the internal state of the training session such that the gradients of the trainable
|
264 |
+
* parameters in the OrtCheckpointState will be scheduled to be reset just before the new gradients are
|
265 |
+
* computed on the next invocation of the next OrtTrainingApi::TrainStep.
|
266 |
+
*
|
267 |
+
* \param[in] session The `this` pointer to the training session.
|
268 |
+
*
|
269 |
+
* \snippet{doc} snippets.dox OrtStatus Return Value
|
270 |
+
*
|
271 |
+
*/
|
272 |
+
ORT_API2_STATUS(LazyResetGrad, _Inout_ OrtTrainingSession* session);
|
273 |
+
|
274 |
+
/** \brief Computes the outputs of the training model and the gradients of the trainable parameters for the given inputs
|
275 |
+
*
|
276 |
+
* This function performs a training step that computes the outputs of the training model and the gradients
|
277 |
+
* of the trainable parameters for the given inputs. The train step is performed based on the training model
|
278 |
+
* that was provided to the training session.
|
279 |
+
* The OrtTrainingApi::TrainStep is equivalent of running forward propagation and backward propagation in a single
|
280 |
+
* step.
|
281 |
+
* The gradients computed are stored inside the training session state so they can be later consumed
|
282 |
+
* by the OrtTrainingApi::OptimizerStep function.
|
283 |
+
* The gradients can be lazily reset by invoking the OrtTrainingApi::LazyResetGrad function.
|
284 |
+
*
|
285 |
+
* \param[in] sess The `this` pointer to the training session.
|
286 |
+
* \param[in] run_options Run options for this training step.
|
287 |
+
* \param[in] inputs_len Number of user inputs to the training model.
|
288 |
+
* \param[in] inputs The user inputs to the training model.
|
289 |
+
* \param[in] outputs_len Number of user outputs expected from this training step.
|
290 |
+
* \param[out] outputs User outputs computed by train step.
|
291 |
+
*
|
292 |
+
* \snippet{doc} snippets.dox OrtStatus Return Value
|
293 |
+
*
|
294 |
+
*/
|
295 |
+
ORT_API2_STATUS(TrainStep, _Inout_ OrtTrainingSession* sess, _In_opt_ const OrtRunOptions* run_options,
|
296 |
+
_In_ size_t inputs_len, _In_reads_(inputs_len) const OrtValue* const* inputs,
|
297 |
+
_In_ size_t outputs_len, _Inout_updates_all_(outputs_len) OrtValue** outputs);
|
298 |
+
|
299 |
+
/** \brief Computes the outputs for the eval model for the given inputs
|
300 |
+
*
|
301 |
+
* This function performs an eval step that computes the outputs of the eval model for the given inputs.
|
302 |
+
* The eval step is performed based on the eval model that was provided to the training session.
|
303 |
+
*
|
304 |
+
* \param[in] sess The `this` pointer to the training session.
|
305 |
+
* \param[in] run_options Run options for this eval step.
|
306 |
+
* \param[in] inputs_len Number of user inputs to the eval model.
|
307 |
+
* \param[in] inputs The user inputs to the eval model.
|
308 |
+
* \param[in] outputs_len Number of user outputs expected from this eval step.
|
309 |
+
* \param[out] outputs User outputs computed by eval step.
|
310 |
+
*
|
311 |
+
* \snippet{doc} snippets.dox OrtStatus Return Value
|
312 |
+
*
|
313 |
+
*/
|
314 |
+
ORT_API2_STATUS(EvalStep, _In_ const OrtTrainingSession* sess, _In_opt_ const OrtRunOptions* run_options,
|
315 |
+
_In_ size_t inputs_len, _In_reads_(inputs_len) const OrtValue* const* inputs,
|
316 |
+
_In_ size_t outputs_len, _Inout_updates_all_(outputs_len) OrtValue** outputs);
|
317 |
+
|
318 |
+
/** \brief Sets the learning rate for this training session.
|
319 |
+
*
|
320 |
+
* This function allows users to set the learning rate for the training session. The current
|
321 |
+
* learning rate is maintained by the training session and can be overwritten by invoking
|
322 |
+
* this function with the desired learning rate. This function should not be used when a valid
|
323 |
+
* learning rate scheduler is registered. It should be used either to set the learning rate
|
324 |
+
* derived from a custom learning rate scheduler or to set a constant learning rate to be used
|
325 |
+
* throughout the training session.
|
326 |
+
* \note Please note that this function does not set the initial learning rate that may be needed
|
327 |
+
* by the predefined learning rate schedulers. To set the initial learning rate for learning
|
328 |
+
* rate schedulers, please look at the function OrtTrainingApi::RegisterLinearLRScheduler.
|
329 |
+
*
|
330 |
+
* \param[in] sess The `this` pointer to the training session.
|
331 |
+
* \param[in] learning_rate Desired learning rate to be set.
|
332 |
+
*
|
333 |
+
* \snippet{doc} snippets.dox OrtStatus Return Value
|
334 |
+
*
|
335 |
+
*/
|
336 |
+
ORT_API2_STATUS(SetLearningRate, _Inout_ OrtTrainingSession* sess, _In_ float learning_rate);
|
337 |
+
|
338 |
+
/** \brief Gets the current learning rate for this training session.
|
339 |
+
*
|
340 |
+
* This function allows users to get the learning rate for the training session. The current
|
341 |
+
* learning rate is maintained by the training session, and users can query it for the purpose
|
342 |
+
* of implementing their own learning rate schedulers.
|
343 |
+
*
|
344 |
+
* \param[in] sess The `this` pointer to the training session.
|
345 |
+
* \param[out] learning_rate Learning rate currently in use by the training session.
|
346 |
+
*
|
347 |
+
* \snippet{doc} snippets.dox OrtStatus Return Value
|
348 |
+
*
|
349 |
+
*/
|
350 |
+
ORT_API2_STATUS(GetLearningRate, _Inout_ OrtTrainingSession* sess, _Out_ float* learning_rate);
|
351 |
+
|
352 |
+
/** \brief Performs the weight updates for the trainable parameters using the optimizer model.
|
353 |
+
*
|
354 |
+
* This function performs the weight update step that updates the trainable parameters such that they
|
355 |
+
* take a step in the direction of their gradients (gradient descent). The optimizer step is performed
|
356 |
+
* based on the optimizer model that was provided to the training session.
|
357 |
+
* The updated parameters are stored inside the training state so that they can be used by the next
|
358 |
+
* OrtTrainingApi::TrainStep function call.
|
359 |
+
*
|
360 |
+
* \param[in] sess The `this` pointer to the training session.
|
361 |
+
* \param[in] run_options Run options for this optimizer step.
|
362 |
+
*
|
363 |
+
* \snippet{doc} snippets.dox OrtStatus Return Value
|
364 |
+
*
|
365 |
+
*/
|
366 |
+
ORT_API2_STATUS(OptimizerStep, _Inout_ OrtTrainingSession* sess,
|
367 |
+
_In_opt_ const OrtRunOptions* run_options);
|
368 |
+
|
369 |
+
/** \brief Registers a linear learning rate scheduler for the training session.
|
370 |
+
*
|
371 |
+
* Register a linear learning rate scheduler that decays the learning rate by linearly updated
|
372 |
+
* multiplicative factor from the initial learning rate set on the training session to 0. The decay
|
373 |
+
* is performed after the initial warm up phase where the learning rate is linearly incremented
|
374 |
+
* from 0 to the initial learning rate provided.
|
375 |
+
*
|
376 |
+
* \param[in] sess The `this` pointer to the training session.
|
377 |
+
* \param[in] warmup_step_count Warmup steps for LR warmup.
|
378 |
+
* \param[in] total_step_count Total step count.
|
379 |
+
* \param[in] initial_lr The initial learning rate to be used by the training session.
|
380 |
+
*
|
381 |
+
* \snippet{doc} snippets.dox OrtStatus Return Value
|
382 |
+
*
|
383 |
+
*/
|
384 |
+
ORT_API2_STATUS(RegisterLinearLRScheduler, _Inout_ OrtTrainingSession* sess, _In_ const int64_t warmup_step_count,
|
385 |
+
_In_ const int64_t total_step_count, _In_ const float initial_lr);
|
386 |
+
|
387 |
+
/** \brief Update the learning rate based on the registered learing rate scheduler.
|
388 |
+
*
|
389 |
+
* Takes a scheduler step that updates the learning rate that is being used by the training session.
|
390 |
+
* This function should typically be called before invoking the optimizer step for each round,
|
391 |
+
* or as determined necessary to update the learning rate being used by the training session.
|
392 |
+
* \note Please note that a valid predefined learning rate scheduler must be first registered to invoke this
|
393 |
+
* function.
|
394 |
+
*
|
395 |
+
* \param[in] sess The `this` pointer to the training session.
|
396 |
+
*
|
397 |
+
* \snippet{doc} snippets.dox OrtStatus Return Value
|
398 |
+
*
|
399 |
+
*/
|
400 |
+
ORT_API2_STATUS(SchedulerStep, _Inout_ OrtTrainingSession* sess);
|
401 |
+
|
402 |
+
/// @}
|
403 |
+
|
404 |
+
/// \name Accessing The Training Session State
|
405 |
+
/// @{
|
406 |
+
/** \brief Retrieves the size of all the parameters.
|
407 |
+
*
|
408 |
+
* Calculates the total number of primitive (datatype of the parameters) elements of all the parameters in the
|
409 |
+
* training state.
|
410 |
+
* When trainable_only argument is true, the size is calculated for trainable params only.
|
411 |
+
*
|
412 |
+
* \param[in] sess The `this` pointer to the training session.
|
413 |
+
* \param[out] out Size of all parameter elements.
|
414 |
+
* \param[in] trainable_only Whether to skip non-trainable parameters
|
415 |
+
*
|
416 |
+
* \snippet{doc} snippets.dox OrtStatus Return Value
|
417 |
+
*
|
418 |
+
*/
|
419 |
+
ORT_API2_STATUS(GetParametersSize, _Inout_ OrtTrainingSession* sess, _Out_ size_t* out, bool trainable_only);
|
420 |
+
|
421 |
+
/** \brief Copy all parameters to a contiguous buffer held by the argument parameters_buffer
|
422 |
+
*
|
423 |
+
* The parameters_buffer has to be of the size given by GetParametersSize api call,
|
424 |
+
* with matching setting for the argument trainable_only. All the target parameters must be of the same
|
425 |
+
* datatype. The OrtValue must be pre-allocated onto
|
426 |
+
* the desired device. This is a complementary function to OrtTrainingApi::CopyBufferToParameters.
|
427 |
+
* Parameter ordering is preserved.
|
428 |
+
* User is responsible for allocating and freeing the resources used by the parameters_buffer.
|
429 |
+
*
|
430 |
+
* \param[in] sess The `this` pointer to the training session.
|
431 |
+
* \param[in] trainable_only Whether to skip non-trainable parameters
|
432 |
+
* \param[out] parameters_buffer The pre-allocated OrtValue buffer to copy onto.
|
433 |
+
*
|
434 |
+
* \snippet{doc} snippets.dox OrtStatus Return Value
|
435 |
+
*
|
436 |
+
*/
|
437 |
+
ORT_API2_STATUS(CopyParametersToBuffer, _Inout_ OrtTrainingSession* sess,
|
438 |
+
_Inout_ OrtValue* parameters_buffer, bool trainable_only);
|
439 |
+
|
440 |
+
/** \brief Copy parameter values from the given contiguous buffer held by parameters_buffer to the training state
|
441 |
+
*
|
442 |
+
* The parameters_buffer argument has to be of the size given by OrtTrainingApi::GetParametersSize api call,
|
443 |
+
* with matching setting for trainable_only argument. All the target parameters must be of the same
|
444 |
+
* datatype. This is a complementary function to OrtTrainingApi::CopyBufferToParameters
|
445 |
+
* and can be used to load updated buffer values onto the training state.
|
446 |
+
* Parameter ordering is preserved.
|
447 |
+
* User is responsible for allocating and freeing the resources used by the parameters_buffer.
|
448 |
+
*
|
449 |
+
* \param[in] sess The `this` pointer to the training session.
|
450 |
+
* \param[in] trainable_only Whether to skip non-trainable parameters
|
451 |
+
* \param[out] parameters_buffer The pre-allocated OrtValue buffer to copy from.
|
452 |
+
*
|
453 |
+
* \snippet{doc} snippets.dox OrtStatus Return Value
|
454 |
+
*
|
455 |
+
*/
|
456 |
+
ORT_API2_STATUS(CopyBufferToParameters, _Inout_ OrtTrainingSession* sess,
|
457 |
+
_Inout_ OrtValue* parameters_buffer, bool trainable_only);
|
458 |
+
|
459 |
+
/// @}
|
460 |
+
|
461 |
+
/// \name Release Training Resources
|
462 |
+
/// @{
|
463 |
+
|
464 |
+
/** \brief Frees up the memory used up by the training session.
|
465 |
+
*
|
466 |
+
* This function frees up any memory that was allocated in the training session. The training
|
467 |
+
* session can no longer be used after this call.
|
468 |
+
*
|
469 |
+
*/
|
470 |
+
ORT_CLASS_RELEASE(TrainingSession);
|
471 |
+
|
472 |
+
/** \brief Frees up the memory used up by the checkpoint state.
|
473 |
+
*
|
474 |
+
* This function frees up any memory that was allocated in the checkpoint state. The checkpoint
|
475 |
+
* state can no longer be used after this call.
|
476 |
+
* \note Note that the checkpoint state must be released only after the training session has been released.
|
477 |
+
*
|
478 |
+
*/
|
479 |
+
ORT_CLASS_RELEASE(CheckpointState);
|
480 |
+
|
481 |
+
/// @}
|
482 |
+
|
483 |
+
/// \name Prepare For Inferencing
|
484 |
+
/// @{
|
485 |
+
/** \brief Export a model that can be used for inferencing.
|
486 |
+
*
|
487 |
+
* If the training session was provided with an eval model, the training session can generate
|
488 |
+
* an inference model if it knows the inference graph outputs. The input inference graph outputs
|
489 |
+
* are used to prune the eval model so that the inference model's outputs align with the provided outputs.
|
490 |
+
* The exported model is saved at the path provided and can be used for inferencing with InferenceSession.
|
491 |
+
* \note Note that the function re-loads the eval model from the path provided to OrtTrainingApi::CreateTrainingSession
|
492 |
+
* and expects that this path still be valid.
|
493 |
+
*
|
494 |
+
* \param[in] sess The `this` pointer to the training session.
|
495 |
+
* \param[in] inference_model_path Path where the inference model should be serialized to.
|
496 |
+
* \param[in] graph_outputs_len Size of the graph output names array.
|
497 |
+
* \param[in] graph_output_names Names of the outputs that are needed in the inference model.
|
498 |
+
*
|
499 |
+
* \snippet{doc} snippets.dox OrtStatus Return Value
|
500 |
+
*
|
501 |
+
*/
|
502 |
+
ORT_API2_STATUS(ExportModelForInferencing, _Inout_ OrtTrainingSession* sess,
|
503 |
+
_In_ const ORTCHAR_T* inference_model_path, size_t graph_outputs_len,
|
504 |
+
_In_reads_(graph_outputs_len) const char* const* graph_output_names);
|
505 |
+
|
506 |
+
/// @}
|
507 |
+
|
508 |
+
/// \name Training Utilities
|
509 |
+
/// @{
|
510 |
+
/** \brief Sets the seed used for random number generation in Onnxruntime.
|
511 |
+
*
|
512 |
+
* Use this function to generate reproducible results. It should be noted that completely reproducible
|
513 |
+
* results are not guaranteed.
|
514 |
+
*
|
515 |
+
* \param[in] seed The seed to be set.
|
516 |
+
*
|
517 |
+
* \snippet{doc} snippets.dox OrtStatus Return Value
|
518 |
+
*
|
519 |
+
*/
|
520 |
+
ORT_API2_STATUS(SetSeed, _In_ const int64_t seed);
|
521 |
+
|
522 |
+
/// @}
|
523 |
+
|
524 |
+
/// \name Model IO Information
|
525 |
+
/// @{
|
526 |
+
/** \brief Retrieves the number of user inputs in the training model.
|
527 |
+
*
|
528 |
+
* This function returns the number of inputs of the training model so that the user can accordingly
|
529 |
+
* allocate the OrtValue(s) provided to the OrtTrainingApi::TrainStep function.
|
530 |
+
*
|
531 |
+
* \param[in] sess The `this` pointer to the training session.
|
532 |
+
* \param[out] out Number of user inputs in the training model.
|
533 |
+
*
|
534 |
+
* \snippet{doc} snippets.dox OrtStatus Return Value
|
535 |
+
*
|
536 |
+
*/
|
537 |
+
ORT_API2_STATUS(TrainingSessionGetTrainingModelInputCount, _In_ const OrtTrainingSession* sess, _Out_ size_t* out);
|
538 |
+
|
539 |
+
/** \brief Retrieves the number of user inputs in the eval model.
|
540 |
+
*
|
541 |
+
* This function returns the number of inputs of the eval model so that the user can accordingly
|
542 |
+
* allocate the OrtValue(s) provided to the OrtTrainingApi::EvalStep function.
|
543 |
+
*
|
544 |
+
* \param[in] sess The `this` pointer to the training session.
|
545 |
+
* \param[out] out Number of user inputs in the eval model.
|
546 |
+
*
|
547 |
+
* \snippet{doc} snippets.dox OrtStatus Return Value
|
548 |
+
*
|
549 |
+
*/
|
550 |
+
ORT_API2_STATUS(TrainingSessionGetEvalModelInputCount, _In_ const OrtTrainingSession* sess, _Out_ size_t* out);
|
551 |
+
|
552 |
+
/** \brief Retrieves the name of the user input at given index in the training model.
|
553 |
+
*
|
554 |
+
* This function returns the names of inputs of the training model that can be associated with the
|
555 |
+
* OrtValue(s) provided to the OrtTrainingApi::TrainStep function.
|
556 |
+
*
|
557 |
+
* \param[in] sess The `this` pointer to the training session.
|
558 |
+
* \param[in] index The index of the training model input name requested.
|
559 |
+
* \param[in] allocator The allocator to use to allocate the memory for the requested name.
|
560 |
+
* \param[out] output Name of the user input for the training model at the given index.
|
561 |
+
*
|
562 |
+
* \snippet{doc} snippets.dox OrtStatus Return Value
|
563 |
+
*
|
564 |
+
*/
|
565 |
+
ORT_API2_STATUS(TrainingSessionGetTrainingModelInputName, _In_ const OrtTrainingSession* sess, size_t index,
|
566 |
+
_In_ OrtAllocator* allocator, _Outptr_ char** output);
|
567 |
+
|
568 |
+
/** \brief Retrieves the name of the user input at given index in the eval model.
|
569 |
+
*
|
570 |
+
* This function returns the names of inputs of the eval model that can be associated with the OrtValue(s) provided
|
571 |
+
* to the OrtTrainingApi::EvalStep function.
|
572 |
+
*
|
573 |
+
* \param[in] sess The `this` pointer to the training session.
|
574 |
+
* \param[in] index The index of the eval model input name requested.
|
575 |
+
* \param[in] allocator The allocator to use to allocate the memory for the requested name.
|
576 |
+
* \param[out] output Name of the user input for the eval model at the given index.
|
577 |
+
*
|
578 |
+
* \snippet{doc} snippets.dox OrtStatus Return Value
|
579 |
+
*
|
580 |
+
*/
|
581 |
+
ORT_API2_STATUS(TrainingSessionGetEvalModelInputName, _In_ const OrtTrainingSession* sess, size_t index,
|
582 |
+
_In_ OrtAllocator* allocator, _Outptr_ char** output);
|
583 |
+
|
584 |
+
/// @}
|
585 |
+
|
586 |
+
/// \name Accessing The Training Session State
|
587 |
+
/// @{
|
588 |
+
|
589 |
+
/** \brief Adds the given property to the checkpoint state.
|
590 |
+
*
|
591 |
+
* Runtime properties such as epoch, training step, best score, and others can be added to the checkpoint
|
592 |
+
* state by the user if they desire by calling this function with the appropriate property name and
|
593 |
+
* value. The given property name must be unique to be able to successfully add the property.
|
594 |
+
*
|
595 |
+
* \param[in] checkpoint_state The checkpoint state which should hold the property.
|
596 |
+
* \param[in] property_name Unique name of the property being added.
|
597 |
+
* \param[in] property_type Type of the property associated with the given name.
|
598 |
+
* \param[in] property_value Property value associated with the given name.
|
599 |
+
*
|
600 |
+
* \snippet{doc} snippets.dox OrtStatus Return Value
|
601 |
+
*
|
602 |
+
*/
|
603 |
+
ORT_API2_STATUS(AddProperty, _Inout_ OrtCheckpointState* checkpoint_state,
|
604 |
+
_In_ const char* property_name, _In_ enum OrtPropertyType property_type,
|
605 |
+
_In_ void* property_value);
|
606 |
+
|
607 |
+
/** \brief Gets the property value associated with the given name from the checkpoint state.
|
608 |
+
*
|
609 |
+
* Gets the property value from an existing entry in the checkpoint state. The property must
|
610 |
+
* exist in the checkpoint state to be able to retrieve it successfully.
|
611 |
+
*
|
612 |
+
* \param[in] checkpoint_state The checkpoint state that is currently holding the property.
|
613 |
+
* \param[in] property_name Unique name of the property being retrieved.
|
614 |
+
* \param[in] allocator Allocator used to allocate the memory for the property_value.
|
615 |
+
* \param[out] property_type Type of the property associated with the given name.
|
616 |
+
* \param[out] property_value Property value associated with the given name.
|
617 |
+
*
|
618 |
+
* \snippet{doc} snippets.dox OrtStatus Return Value
|
619 |
+
*
|
620 |
+
*/
|
621 |
+
ORT_API2_STATUS(GetProperty, _In_ const OrtCheckpointState* checkpoint_state,
|
622 |
+
_In_ const char* property_name, _Inout_ OrtAllocator* allocator,
|
623 |
+
_Out_ enum OrtPropertyType* property_type, _Outptr_ void** property_value);
|
624 |
+
|
625 |
+
/// @}
|
626 |
+
};
|
627 |
+
|
628 |
+
typedef struct OrtTrainingApi OrtTrainingApi;
|
629 |
+
|
630 |
+
/// @}
|
v1.15.1/onnxruntime-linux-armhf/include/onnxruntime_training_cxx_api.h
ADDED
@@ -0,0 +1,361 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
// Copyright (c) Microsoft Corporation. All rights reserved.
|
2 |
+
// Licensed under the MIT License.
|
3 |
+
|
4 |
+
#pragma once
|
5 |
+
#include "onnxruntime_training_c_api.h"
|
6 |
+
#include <optional>
|
7 |
+
#include <variant>
|
8 |
+
|
9 |
+
namespace Ort::detail {
|
10 |
+
|
11 |
+
#define ORT_DECLARE_TRAINING_RELEASE(NAME) \
|
12 |
+
void OrtRelease(Ort##NAME* ptr);
|
13 |
+
|
14 |
+
// These release methods must be forward declared before including onnxruntime_cxx_api.h
|
15 |
+
// otherwise class Base won't be aware of them
|
16 |
+
ORT_DECLARE_TRAINING_RELEASE(CheckpointState);
|
17 |
+
ORT_DECLARE_TRAINING_RELEASE(TrainingSession);
|
18 |
+
|
19 |
+
} // namespace Ort::detail
|
20 |
+
|
21 |
+
#include "onnxruntime_cxx_api.h"
|
22 |
+
|
23 |
+
namespace Ort {
|
24 |
+
|
25 |
+
/// <summary>
|
26 |
+
/// This function returns the C training api struct with the pointers to the ort training C functions.
|
27 |
+
/// If using C++, please use the class instances instead of invoking the C functions directly.
|
28 |
+
/// </summary>
|
29 |
+
/// <returns>OrtTrainingApi struct with ort training C function pointers.</returns>
|
30 |
+
inline const OrtTrainingApi& GetTrainingApi() { return *GetApi().GetTrainingApi(ORT_API_VERSION); }
|
31 |
+
|
32 |
+
namespace detail {
|
33 |
+
|
34 |
+
#define ORT_DEFINE_TRAINING_RELEASE(NAME) \
|
35 |
+
inline void OrtRelease(Ort##NAME* ptr) { GetTrainingApi().Release##NAME(ptr); }
|
36 |
+
|
37 |
+
ORT_DEFINE_TRAINING_RELEASE(CheckpointState);
|
38 |
+
ORT_DEFINE_TRAINING_RELEASE(TrainingSession);
|
39 |
+
|
40 |
+
#undef ORT_DECLARE_TRAINING_RELEASE
|
41 |
+
#undef ORT_DEFINE_TRAINING_RELEASE
|
42 |
+
|
43 |
+
} // namespace detail
|
44 |
+
|
45 |
+
using Property = std::variant<int64_t, float, std::string>;
|
46 |
+
|
47 |
+
/**
|
48 |
+
* \defgroup TrainingCpp Ort Training C++ API
|
49 |
+
* @{
|
50 |
+
*/
|
51 |
+
|
52 |
+
/** \brief Holds the state of the training session.
|
53 |
+
*
|
54 |
+
* This class holds the entire training session state that includes model parameters, their gradients,
|
55 |
+
* optimizer parameters, and user properties. The Ort::TrainingSession leverages the Ort::CheckpointState
|
56 |
+
* by accessing and updating the contained training state.
|
57 |
+
* \note Note that the training session created with a checkpoint state uses this state to store the entire
|
58 |
+
* training state (including model parameters, its gradients, the optimizer states and the properties).
|
59 |
+
* The Ort::TrainingSession does not hold a copy of the Ort::CheckpointState and as a result, it is required
|
60 |
+
* that the checkpoint state outlive the lifetime of the training session.
|
61 |
+
*
|
62 |
+
*/
|
63 |
+
class CheckpointState : public detail::Base<OrtCheckpointState> {
|
64 |
+
private:
|
65 |
+
CheckpointState(OrtCheckpointState* checkpoint_state) { p_ = checkpoint_state; }
|
66 |
+
|
67 |
+
public:
|
68 |
+
// Construct the checkpoint state by loading the checkpoint by calling LoadCheckpoint
|
69 |
+
CheckpointState() = delete;
|
70 |
+
|
71 |
+
/// \name Accessing The Training Session State
|
72 |
+
/// @{
|
73 |
+
|
74 |
+
/** \brief Load a checkpoint state from directory on disk into checkpoint_state.
|
75 |
+
*
|
76 |
+
* This function will parse a checkpoint directory, pull relevant files and load the training
|
77 |
+
* state and return an instance of Ort::CheckpointState. This checkpoint state can then be used to create the
|
78 |
+
* training session by instantiating Ort::TrainingSession. By doing so, the training session will resume
|
79 |
+
* training from the given checkpoint state.
|
80 |
+
*
|
81 |
+
* \param[in] path_to_checkpoint Path to the checkpoint directory
|
82 |
+
* \return Ort::CheckpointState object which holds the state of the training session parameters.
|
83 |
+
*
|
84 |
+
*/
|
85 |
+
static CheckpointState LoadCheckpoint(const std::basic_string<ORTCHAR_T>& path_to_checkpoint);
|
86 |
+
|
87 |
+
/** \brief Save the given state to a checkpoint directory on disk.
|
88 |
+
*
|
89 |
+
* This function serializes the provided checkpoint state to a directory on disk.
|
90 |
+
* This checkpoint can later be loaded by invoking Ort::CheckpointState::LoadCheckpoint to resume
|
91 |
+
* training from this snapshot of the state.
|
92 |
+
*
|
93 |
+
* \param[in] checkpoint_state The checkpoint state to save.
|
94 |
+
* \param[in] path_to_checkpoint Path to the checkpoint directory.
|
95 |
+
* \param[in] include_optimizer_state Flag to indicate whether to save the optimizer state or not.
|
96 |
+
*
|
97 |
+
*/
|
98 |
+
static void SaveCheckpoint(const CheckpointState& checkpoint_state,
|
99 |
+
const std::basic_string<ORTCHAR_T>& path_to_checkpoint,
|
100 |
+
const bool include_optimizer_state = false);
|
101 |
+
|
102 |
+
/** \brief Adds the given property to the checkpoint state.
|
103 |
+
*
|
104 |
+
* Runtime properties such as epoch, training step, best score, and others can be added to the checkpoint
|
105 |
+
* state by the user if they desire by calling this function with the appropriate property name and
|
106 |
+
* value. The given property name must be unique to be able to successfully add the property.
|
107 |
+
*
|
108 |
+
* \param[in] property_name Unique name of the property being added.
|
109 |
+
* \param[in] property_value Property value associated with the given name.
|
110 |
+
*
|
111 |
+
*/
|
112 |
+
void AddProperty(const std::string& property_name, const Property& property_value);
|
113 |
+
|
114 |
+
/** \brief Gets the property value associated with the given name from the checkpoint state.
|
115 |
+
*
|
116 |
+
* Gets the property value from an existing entry in the checkpoint state. The property must
|
117 |
+
* exist in the checkpoint state to be able to retrieve it successfully.
|
118 |
+
*
|
119 |
+
* \param[in] property_name Unique name of the property being retrieved.
|
120 |
+
* \return Property value associated with the given property name.
|
121 |
+
*
|
122 |
+
*/
|
123 |
+
Property GetProperty(const std::string& property_name);
|
124 |
+
|
125 |
+
/// @}
|
126 |
+
};
|
127 |
+
|
128 |
+
/** \brief Trainer class that provides training, evaluation and optimizer methods for training an ONNX models.
|
129 |
+
*
|
130 |
+
* The training session requires four training artifacts
|
131 |
+
* - The training onnx model
|
132 |
+
* - The evaluation onnx model (optional)
|
133 |
+
* - The optimizer onnx model
|
134 |
+
* - The checkpoint directory
|
135 |
+
*
|
136 |
+
* These artifacts can be generated using the `onnxruntime-training` python [utility](https://github.com/microsoft/onnxruntime/blob/main/orttraining/orttraining/python/training/onnxblock/README.md).
|
137 |
+
*
|
138 |
+
*/
|
139 |
+
class TrainingSession : public detail::Base<OrtTrainingSession> {
|
140 |
+
private:
|
141 |
+
size_t training_model_output_count_, eval_model_output_count_;
|
142 |
+
|
143 |
+
public:
|
144 |
+
/// \name Constructing the Training Session
|
145 |
+
/// @{
|
146 |
+
/** \brief Create a training session that can be used to begin or resume training.
|
147 |
+
*
|
148 |
+
* This constructor instantiates the training session based on the env and session options provided that can
|
149 |
+
* begin or resume training from a given checkpoint state for the given onnx models.
|
150 |
+
* The checkpoint state represents the parameters of the training session which will be moved
|
151 |
+
* to the device specified by the user through the session options (if necessary).
|
152 |
+
*
|
153 |
+
* \param[in] env Env to be used for the training session.
|
154 |
+
* \param[in] session_options SessionOptions that the user can customize for this training session.
|
155 |
+
* \param[in] checkpoint_state Training states that the training session uses as a starting point for training.
|
156 |
+
* \param[in] train_model_path Model to be used to perform training.
|
157 |
+
* \param[in] eval_model_path Model to be used to perform evaluation.
|
158 |
+
* \param[in] optimizer_model_path Model to be used to perform gradient descent.
|
159 |
+
*
|
160 |
+
*/
|
161 |
+
TrainingSession(const Env& env, const SessionOptions& session_options, CheckpointState& checkpoint_state,
|
162 |
+
const std::basic_string<ORTCHAR_T>& train_model_path,
|
163 |
+
const std::optional<std::basic_string<ORTCHAR_T>>& eval_model_path = std::nullopt,
|
164 |
+
const std::optional<std::basic_string<ORTCHAR_T>>& optimizer_model_path = std::nullopt);
|
165 |
+
|
166 |
+
/// @}
|
167 |
+
|
168 |
+
/// \name Implementing The Training Loop
|
169 |
+
/// @{
|
170 |
+
/** \brief Computes the outputs of the training model and the gradients of the trainable parameters for the given inputs
|
171 |
+
*
|
172 |
+
* This function performs a training step that computes the outputs of the training model and the gradients
|
173 |
+
* of the trainable parameters for the given inputs. The train step is performed based on the training model
|
174 |
+
* that was provided to the training session.
|
175 |
+
* The Ort::TrainingSession::TrainStep is equivalent of running forward propagation and backward propagation in a single
|
176 |
+
* step.
|
177 |
+
* The gradients computed are stored inside the training session state so they can be later consumed
|
178 |
+
* by the Ort::TrainingSession::OptimizerStep function.
|
179 |
+
* The gradients can be lazily reset by invoking the Ort::TrainingSession::LazyResetGrad function.
|
180 |
+
*
|
181 |
+
* \param[in] input_values The user inputs to the training model.
|
182 |
+
* \return A std::vector of Ort::Value objects that represents the output of the forward pass of the training model.
|
183 |
+
*
|
184 |
+
* \snippet{doc} snippets.dox OrtStatus Return Value
|
185 |
+
*
|
186 |
+
*/
|
187 |
+
std::vector<Value> TrainStep(const std::vector<Value>& input_values);
|
188 |
+
|
189 |
+
/** \brief Reset the gradients of all trainable parameters to zero lazily.
|
190 |
+
*
|
191 |
+
* This function sets the internal state of the training session such that the gradients of the trainable
|
192 |
+
* parameters in the OrtCheckpointState will be scheduled to be reset just before the new gradients are
|
193 |
+
* computed on the next invocation of the next Ort::TrainingSession::TrainStep.
|
194 |
+
*
|
195 |
+
*/
|
196 |
+
void LazyResetGrad();
|
197 |
+
|
198 |
+
/** \brief Computes the outputs for the eval model for the given inputs
|
199 |
+
*
|
200 |
+
* This function performs an eval step that computes the outputs of the eval model for the given inputs.
|
201 |
+
* The eval step is performed based on the eval model that was provided to the training session.
|
202 |
+
*
|
203 |
+
* \param[in] input_values The user inputs to the eval model.
|
204 |
+
* \return A std::vector of Ort::Value objects that represents the output of the eval pass.
|
205 |
+
*
|
206 |
+
*/
|
207 |
+
std::vector<Value> EvalStep(const std::vector<Value>& input_values);
|
208 |
+
|
209 |
+
/** \brief Sets the learning rate for this training session.
|
210 |
+
*
|
211 |
+
* This function allows users to set the learning rate for the training session. The current
|
212 |
+
* learning rate is maintained by the training session and can be overwritten by invoking
|
213 |
+
* this function with the desired learning rate. This function should not be used when a valid
|
214 |
+
* learning rate scheduler is registered. It should be used either to set the learning rate
|
215 |
+
* derived from a custom learning rate scheduler or to set a constant learning rate to be used
|
216 |
+
* throughout the training session.
|
217 |
+
* \note Please note that this function does not set the initial learning rate that may be needed
|
218 |
+
* by the predefined learning rate schedulers. To set the initial learning rate for learning
|
219 |
+
* rate schedulers, please look at the function Ort::TrainingSession::RegisterLinearLRScheduler.
|
220 |
+
*
|
221 |
+
* \param[in] learning_rate Desired learning rate to be set.
|
222 |
+
*
|
223 |
+
*/
|
224 |
+
void SetLearningRate(float learning_rate);
|
225 |
+
|
226 |
+
/** \brief Gets the current learning rate for this training session.
|
227 |
+
*
|
228 |
+
* This function allows users to get the learning rate for the training session. The current
|
229 |
+
* learning rate is maintained by the training session, and users can query it for the purpose
|
230 |
+
* of implementing their own learning rate schedulers.
|
231 |
+
*
|
232 |
+
* \return float representing the current learning rate.
|
233 |
+
*
|
234 |
+
*/
|
235 |
+
float GetLearningRate() const;
|
236 |
+
|
237 |
+
/** \brief Registers a linear learning rate scheduler for the training session.
|
238 |
+
*
|
239 |
+
* Register a linear learning rate scheduler that decays the learning rate by linearly updated
|
240 |
+
* multiplicative factor from the initial learning rate set on the training session to 0. The decay
|
241 |
+
* is performed after the initial warm up phase where the learning rate is linearly incremented
|
242 |
+
* from 0 to the initial learning rate provided.
|
243 |
+
*
|
244 |
+
* \param[in] warmup_step_count Warmup steps for LR warmup.
|
245 |
+
* \param[in] total_step_count Total step count.
|
246 |
+
* \param[in] initial_lr The initial learning rate to be used by the training session.
|
247 |
+
*
|
248 |
+
*/
|
249 |
+
void RegisterLinearLRScheduler(int64_t warmup_step_count, int64_t total_step_count,
|
250 |
+
float initial_lr);
|
251 |
+
|
252 |
+
/** \brief Update the learning rate based on the registered learing rate scheduler.
|
253 |
+
*
|
254 |
+
* Takes a scheduler step that updates the learning rate that is being used by the training session.
|
255 |
+
* This function should typically be called before invoking the optimizer step for each round,
|
256 |
+
* or as determined necessary to update the learning rate being used by the training session.
|
257 |
+
* \note Please note that a valid predefined learning rate scheduler must be first registered to invoke this
|
258 |
+
* function.
|
259 |
+
*
|
260 |
+
*/
|
261 |
+
void SchedulerStep();
|
262 |
+
|
263 |
+
/** \brief Performs the weight updates for the trainable parameters using the optimizer model.
|
264 |
+
*
|
265 |
+
* This function performs the weight update step that updates the trainable parameters such that they
|
266 |
+
* take a step in the direction of their gradients (gradient descent). The optimizer step is performed
|
267 |
+
* based on the optimizer model that was provided to the training session.
|
268 |
+
* The updated parameters are stored inside the training state so that they can be used by the next
|
269 |
+
* Ort::TrainingSession::TrainStep function call.
|
270 |
+
*
|
271 |
+
*/
|
272 |
+
void OptimizerStep();
|
273 |
+
|
274 |
+
/// @}
|
275 |
+
|
276 |
+
/// \name Prepare For Inferencing
|
277 |
+
/// @{
|
278 |
+
|
279 |
+
/** \brief Export a model that can be used for inferencing.
|
280 |
+
*
|
281 |
+
* If the training session was provided with an eval model, the training session can generate
|
282 |
+
* an inference model if it knows the inference graph outputs. The input inference graph outputs
|
283 |
+
* are used to prune the eval model so that the inference model's outputs align with the provided outputs.
|
284 |
+
* The exported model is saved at the path provided and can be used for inferencing with Ort::Session.
|
285 |
+
* \note Note that the function re-loads the eval model from the path provided to Ort::TrainingSession
|
286 |
+
* and expects that this path still be valid.
|
287 |
+
*
|
288 |
+
* \param[in] inference_model_path Path where the inference model should be serialized to.
|
289 |
+
* \param[in] graph_output_names Names of the outputs that are needed in the inference model.
|
290 |
+
*
|
291 |
+
*/
|
292 |
+
void ExportModelForInferencing(const std::basic_string<ORTCHAR_T>& inference_model_path,
|
293 |
+
const std::vector<std::string>& graph_output_names);
|
294 |
+
|
295 |
+
/// @}
|
296 |
+
|
297 |
+
/// \name Model IO Information
|
298 |
+
/// @{
|
299 |
+
/** \brief Retrieves the names of the user inputs for the training and eval models.
|
300 |
+
*
|
301 |
+
* This function returns the names of inputs of the training or eval model that can be associated
|
302 |
+
* with the Ort::Value(s) provided to the Ort::TrainingSession::TrainStep or Ort::TrainingSession::EvalStep
|
303 |
+
* function.
|
304 |
+
*
|
305 |
+
* \param[in] training Whether the training model input names are requested or eval model input names.
|
306 |
+
* \return Graph input names for either the training model or the eval model.
|
307 |
+
*
|
308 |
+
*/
|
309 |
+
std::vector<std::string> InputNames(const bool training);
|
310 |
+
|
311 |
+
/** \brief Retrieves the names of the user outputs for the training and eval models.
|
312 |
+
*
|
313 |
+
* This function returns the names of outputs of the training or eval model that can be associated
|
314 |
+
* with the Ort::Value(s) returned by the Ort::TrainingSession::TrainStep or Ort::TrainingSession::EvalStep
|
315 |
+
* function.
|
316 |
+
*
|
317 |
+
* \param[in] training Whether the training model output names are requested or eval model output names.
|
318 |
+
* \return Graph output names for either the training model or the eval model.
|
319 |
+
*
|
320 |
+
*/
|
321 |
+
std::vector<std::string> OutputNames(const bool training);
|
322 |
+
|
323 |
+
/// @}
|
324 |
+
|
325 |
+
/// \name Accessing The Training Session State
|
326 |
+
/// @{
|
327 |
+
|
328 |
+
/** \brief Returns a contiguous buffer that holds a copy of all training state parameters
|
329 |
+
*
|
330 |
+
* \param[in] only_trainable Whether to only copy trainable parameters or to copy all parameters.
|
331 |
+
* \return Contiguous buffer to the model parameters.
|
332 |
+
*
|
333 |
+
*/
|
334 |
+
Value ToBuffer(const bool only_trainable);
|
335 |
+
|
336 |
+
/** \brief Loads the training session model parameters from a contiguous buffer
|
337 |
+
*
|
338 |
+
* \param[in] buffer Contiguous buffer to load the parameters from.
|
339 |
+
*/
|
340 |
+
void FromBuffer(Value& buffer);
|
341 |
+
|
342 |
+
/// @}
|
343 |
+
};
|
344 |
+
|
345 |
+
/// \name Training Utilities
|
346 |
+
/// @{
|
347 |
+
/** \brief This function sets the seed for generating random numbers.
|
348 |
+
*
|
349 |
+
* Use this function to generate reproducible results. It should be noted that completely
|
350 |
+
* reproducible results are not guaranteed.
|
351 |
+
*
|
352 |
+
* \param[in] seed Manual seed to use for random number generation.
|
353 |
+
*/
|
354 |
+
void SetSeed(const int64_t seed);
|
355 |
+
/// @}
|
356 |
+
|
357 |
+
/// @}
|
358 |
+
|
359 |
+
} // namespace Ort
|
360 |
+
|
361 |
+
#include "onnxruntime_training_cxx_inline.h"
|
v1.15.1/onnxruntime-linux-armhf/include/onnxruntime_training_cxx_inline.h
ADDED
@@ -0,0 +1,256 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
// Copyright (c) Microsoft Corporation. All rights reserved.
|
2 |
+
// Licensed under the MIT License.
|
3 |
+
|
4 |
+
#pragma once
|
5 |
+
#include "onnxruntime_training_c_api.h"
|
6 |
+
#include "onnxruntime_cxx_api.h"
|
7 |
+
|
8 |
+
namespace Ort {
|
9 |
+
|
10 |
+
inline TrainingSession::TrainingSession(const Env& env, const SessionOptions& session_options,
|
11 |
+
CheckpointState& checkpoint_state,
|
12 |
+
const std::basic_string<ORTCHAR_T>& train_model_path,
|
13 |
+
const std::optional<std::basic_string<ORTCHAR_T>>& eval_model_path,
|
14 |
+
const std::optional<std::basic_string<ORTCHAR_T>>& optimizer_model_path) {
|
15 |
+
ThrowOnError(GetTrainingApi().CreateTrainingSession(
|
16 |
+
env, session_options, checkpoint_state,
|
17 |
+
train_model_path.c_str(),
|
18 |
+
eval_model_path.has_value() ? eval_model_path.value().c_str() : nullptr,
|
19 |
+
optimizer_model_path.has_value() ? optimizer_model_path.value().c_str() : nullptr,
|
20 |
+
&p_));
|
21 |
+
|
22 |
+
ThrowOnError(GetTrainingApi().TrainingSessionGetTrainingModelOutputCount(p_, &training_model_output_count_));
|
23 |
+
|
24 |
+
ThrowOnError(GetTrainingApi().TrainingSessionGetEvalModelOutputCount(p_, &eval_model_output_count_));
|
25 |
+
}
|
26 |
+
|
27 |
+
inline std::vector<Value> TrainingSession::TrainStep(const std::vector<Value>& input_values) {
|
28 |
+
std::vector<Value> output_values;
|
29 |
+
output_values.reserve(training_model_output_count_);
|
30 |
+
for (size_t i = 0; i < training_model_output_count_; i++) output_values.emplace_back(nullptr);
|
31 |
+
auto ort_input_values = reinterpret_cast<const OrtValue* const*>(input_values.data());
|
32 |
+
auto ort_output_values = reinterpret_cast<OrtValue**>(output_values.data());
|
33 |
+
RunOptions run_options;
|
34 |
+
ThrowOnError(GetTrainingApi().TrainStep(
|
35 |
+
p_, run_options, input_values.size(), ort_input_values,
|
36 |
+
training_model_output_count_, ort_output_values));
|
37 |
+
|
38 |
+
return output_values;
|
39 |
+
}
|
40 |
+
|
41 |
+
inline void TrainingSession::LazyResetGrad() {
|
42 |
+
ThrowOnError(GetTrainingApi().LazyResetGrad(p_));
|
43 |
+
}
|
44 |
+
|
45 |
+
inline std::vector<Value> TrainingSession::EvalStep(const std::vector<Value>& input_values) {
|
46 |
+
std::vector<Value> output_values;
|
47 |
+
output_values.reserve(eval_model_output_count_);
|
48 |
+
for (size_t i = 0; i < eval_model_output_count_; i++) output_values.emplace_back(nullptr);
|
49 |
+
auto ort_input_values = reinterpret_cast<const OrtValue* const*>(input_values.data());
|
50 |
+
auto ort_output_values = reinterpret_cast<OrtValue**>(output_values.data());
|
51 |
+
RunOptions run_options;
|
52 |
+
ThrowOnError(GetTrainingApi().EvalStep(
|
53 |
+
p_, run_options, input_values.size(), ort_input_values,
|
54 |
+
training_model_output_count_, ort_output_values));
|
55 |
+
|
56 |
+
return output_values;
|
57 |
+
}
|
58 |
+
|
59 |
+
inline void TrainingSession::SetLearningRate(float learning_rate) {
|
60 |
+
ThrowOnError(GetTrainingApi().SetLearningRate(p_, learning_rate));
|
61 |
+
}
|
62 |
+
|
63 |
+
inline float TrainingSession::GetLearningRate() const {
|
64 |
+
float learning_rate = 0;
|
65 |
+
ThrowOnError(GetTrainingApi().GetLearningRate(p_, &learning_rate));
|
66 |
+
return learning_rate;
|
67 |
+
}
|
68 |
+
|
69 |
+
inline void TrainingSession::RegisterLinearLRScheduler(int64_t warmup_step_count, int64_t total_step_count,
|
70 |
+
float initial_lr) {
|
71 |
+
ThrowOnError(GetTrainingApi().RegisterLinearLRScheduler(p_, warmup_step_count, total_step_count,
|
72 |
+
initial_lr));
|
73 |
+
}
|
74 |
+
|
75 |
+
inline void TrainingSession::SchedulerStep() {
|
76 |
+
ThrowOnError(GetTrainingApi().SchedulerStep(p_));
|
77 |
+
}
|
78 |
+
|
79 |
+
inline void TrainingSession::OptimizerStep() {
|
80 |
+
RunOptions run_options;
|
81 |
+
ThrowOnError(GetTrainingApi().OptimizerStep(p_, run_options));
|
82 |
+
}
|
83 |
+
|
84 |
+
inline std::vector<std::string> TrainingSession::InputNames(const bool training) {
|
85 |
+
auto& input_count_function = training ? GetTrainingApi().TrainingSessionGetTrainingModelInputCount
|
86 |
+
: GetTrainingApi().TrainingSessionGetEvalModelInputCount;
|
87 |
+
auto& input_name_function = training ? GetTrainingApi().TrainingSessionGetTrainingModelInputName
|
88 |
+
: GetTrainingApi().TrainingSessionGetEvalModelInputName;
|
89 |
+
|
90 |
+
size_t input_count = 0;
|
91 |
+
ThrowOnError(input_count_function(p_, &input_count));
|
92 |
+
std::vector<std::string> input_names(input_count);
|
93 |
+
AllocatorWithDefaultOptions allocator;
|
94 |
+
for (size_t index = 0; index < input_count; ++index) {
|
95 |
+
char* input_name;
|
96 |
+
ThrowOnError(input_name_function(p_, index, allocator, &input_name));
|
97 |
+
input_names[index] = std::string(input_name);
|
98 |
+
allocator.Free(input_name);
|
99 |
+
}
|
100 |
+
|
101 |
+
return input_names;
|
102 |
+
}
|
103 |
+
|
104 |
+
inline std::vector<std::string> TrainingSession::OutputNames(const bool training) {
|
105 |
+
auto& output_count_function = training ? GetTrainingApi().TrainingSessionGetTrainingModelOutputCount
|
106 |
+
: GetTrainingApi().TrainingSessionGetEvalModelOutputCount;
|
107 |
+
auto& output_name_function = training ? GetTrainingApi().TrainingSessionGetTrainingModelOutputName
|
108 |
+
: GetTrainingApi().TrainingSessionGetEvalModelOutputName;
|
109 |
+
|
110 |
+
size_t output_count = 0;
|
111 |
+
ThrowOnError(output_count_function(p_, &output_count));
|
112 |
+
std::vector<std::string> output_names(output_count);
|
113 |
+
AllocatorWithDefaultOptions allocator;
|
114 |
+
for (size_t index = 0; index < output_count; ++index) {
|
115 |
+
char* output_name;
|
116 |
+
ThrowOnError(output_name_function(p_, index, allocator, &output_name));
|
117 |
+
output_names[index] = std::string(output_name);
|
118 |
+
allocator.Free(output_name);
|
119 |
+
}
|
120 |
+
|
121 |
+
return output_names;
|
122 |
+
}
|
123 |
+
|
124 |
+
inline Value TrainingSession::ToBuffer(const bool only_trainable) {
|
125 |
+
size_t buffer_size = 0U;
|
126 |
+
ThrowOnError(GetTrainingApi().GetParametersSize(p_, &buffer_size, only_trainable));
|
127 |
+
|
128 |
+
std::array<int64_t, 1> buffer_shape{static_cast<int64_t>(buffer_size)};
|
129 |
+
|
130 |
+
AllocatorWithDefaultOptions allocator;
|
131 |
+
Value buffer = Value::CreateTensor(allocator, buffer_shape.data(), 1U,
|
132 |
+
ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT);
|
133 |
+
|
134 |
+
ThrowOnError(GetTrainingApi().CopyParametersToBuffer(p_, buffer, only_trainable));
|
135 |
+
|
136 |
+
return buffer;
|
137 |
+
}
|
138 |
+
|
139 |
+
inline void TrainingSession::FromBuffer(Value& buffer) {
|
140 |
+
if (!buffer.IsTensor()) {
|
141 |
+
ThrowStatus(Status("Incorrect buffer received. Expected a tensor buffer.", OrtErrorCode::ORT_INVALID_ARGUMENT));
|
142 |
+
}
|
143 |
+
|
144 |
+
auto tensor_info = buffer.GetTensorTypeAndShapeInfo();
|
145 |
+
auto buffer_shape = tensor_info.GetShape();
|
146 |
+
|
147 |
+
if (buffer_shape.size() != 1U) {
|
148 |
+
ThrowStatus(Status("Incorrect buffer received. Expected a contiguous tensor buffer.",
|
149 |
+
OrtErrorCode::ORT_INVALID_ARGUMENT));
|
150 |
+
}
|
151 |
+
|
152 |
+
auto buffer_size = buffer_shape.front();
|
153 |
+
|
154 |
+
size_t session_buffer_size_trainable_only = 0U;
|
155 |
+
ThrowOnError(GetTrainingApi().GetParametersSize(p_, &session_buffer_size_trainable_only, true));
|
156 |
+
|
157 |
+
if (buffer_size == static_cast<int64_t>(session_buffer_size_trainable_only)) {
|
158 |
+
ThrowOnError(GetTrainingApi().CopyBufferToParameters(p_, buffer, true));
|
159 |
+
return;
|
160 |
+
}
|
161 |
+
|
162 |
+
size_t session_buffer_size = 0U;
|
163 |
+
ThrowOnError(GetTrainingApi().GetParametersSize(p_, &session_buffer_size, false));
|
164 |
+
|
165 |
+
if (buffer_size != static_cast<int64_t>(session_buffer_size)) {
|
166 |
+
ThrowStatus(Status("Incorrect buffer size received.", OrtErrorCode::ORT_INVALID_ARGUMENT));
|
167 |
+
}
|
168 |
+
|
169 |
+
ThrowOnError(GetTrainingApi().CopyBufferToParameters(p_, buffer, false));
|
170 |
+
}
|
171 |
+
|
172 |
+
inline CheckpointState CheckpointState::LoadCheckpoint(const std::basic_string<ORTCHAR_T>& path_to_checkpoint) {
|
173 |
+
OrtCheckpointState* checkpoint_state;
|
174 |
+
ThrowOnError(GetTrainingApi().LoadCheckpoint(path_to_checkpoint.c_str(), &checkpoint_state));
|
175 |
+
return CheckpointState(checkpoint_state);
|
176 |
+
}
|
177 |
+
|
178 |
+
inline void CheckpointState::SaveCheckpoint(const CheckpointState& checkpoint_states,
|
179 |
+
const std::basic_string<ORTCHAR_T>& path_to_checkpoint,
|
180 |
+
const bool include_optimizer_state) {
|
181 |
+
ThrowOnError(GetTrainingApi().SaveCheckpoint(checkpoint_states, path_to_checkpoint.c_str(),
|
182 |
+
include_optimizer_state));
|
183 |
+
}
|
184 |
+
|
185 |
+
inline void TrainingSession::ExportModelForInferencing(const std::basic_string<ORTCHAR_T>& inference_model_path,
|
186 |
+
const std::vector<std::string>& graph_output_names) {
|
187 |
+
std::vector<const char*> output_names;
|
188 |
+
output_names.reserve(graph_output_names.size());
|
189 |
+
for (const auto& output_name : graph_output_names) {
|
190 |
+
output_names.push_back(output_name.c_str());
|
191 |
+
}
|
192 |
+
ThrowOnError(GetTrainingApi().ExportModelForInferencing(
|
193 |
+
p_, inference_model_path.c_str(), graph_output_names.size(), output_names.data()));
|
194 |
+
}
|
195 |
+
|
196 |
+
inline void SetSeed(const int64_t seed) {
|
197 |
+
ThrowOnError(GetTrainingApi().SetSeed(seed));
|
198 |
+
}
|
199 |
+
|
200 |
+
inline void CheckpointState::AddProperty(const std::string& property_name, const Property& property_value) {
|
201 |
+
if (std::holds_alternative<int64_t>(property_value)) {
|
202 |
+
int64_t value = std::get<int64_t>(property_value);
|
203 |
+
void* value_p = &value;
|
204 |
+
ThrowOnError(GetTrainingApi().AddProperty(p_, property_name.c_str(), OrtPropertyType::OrtIntProperty, value_p));
|
205 |
+
} else if (std::holds_alternative<float>(property_value)) {
|
206 |
+
float value = std::get<float>(property_value);
|
207 |
+
void* value_p = &value;
|
208 |
+
ThrowOnError(GetTrainingApi().AddProperty(p_, property_name.c_str(), OrtPropertyType::OrtFloatProperty, value_p));
|
209 |
+
} else if (std::holds_alternative<std::string>(property_value)) {
|
210 |
+
std::string value = std::get<std::string>(property_value);
|
211 |
+
auto buffer = std::make_unique<char[]>(value.length() + 1).release();
|
212 |
+
memcpy(buffer, value.c_str(), value.length());
|
213 |
+
ThrowOnError(GetTrainingApi().AddProperty(p_, property_name.c_str(), OrtPropertyType::OrtStringProperty, buffer));
|
214 |
+
} else {
|
215 |
+
ThrowStatus(Status("Unknown property type received.", OrtErrorCode::ORT_INVALID_ARGUMENT));
|
216 |
+
}
|
217 |
+
}
|
218 |
+
|
219 |
+
inline Property CheckpointState::GetProperty(const std::string& property_name) {
|
220 |
+
void* property_value = nullptr;
|
221 |
+
OrtPropertyType property_type;
|
222 |
+
|
223 |
+
AllocatorWithDefaultOptions allocator;
|
224 |
+
ThrowOnError(GetTrainingApi().GetProperty(p_, property_name.c_str(), allocator, &property_type, &property_value));
|
225 |
+
|
226 |
+
Property property;
|
227 |
+
|
228 |
+
switch (property_type) {
|
229 |
+
case OrtPropertyType::OrtIntProperty: {
|
230 |
+
auto value_p = reinterpret_cast<int64_t*>(property_value);
|
231 |
+
property = *value_p;
|
232 |
+
allocator.Free(property_value);
|
233 |
+
break;
|
234 |
+
}
|
235 |
+
case OrtPropertyType::OrtFloatProperty: {
|
236 |
+
auto value_p = reinterpret_cast<float*>(property_value);
|
237 |
+
property = *value_p;
|
238 |
+
allocator.Free(property_value);
|
239 |
+
break;
|
240 |
+
}
|
241 |
+
case OrtPropertyType::OrtStringProperty: {
|
242 |
+
auto value_p = reinterpret_cast<char*>(property_value);
|
243 |
+
property = std::string(value_p);
|
244 |
+
allocator.Free(property_value);
|
245 |
+
break;
|
246 |
+
}
|
247 |
+
default: {
|
248 |
+
ThrowStatus(Status("Unknown property type received.", OrtErrorCode::ORT_INVALID_ARGUMENT));
|
249 |
+
break;
|
250 |
+
}
|
251 |
+
}
|
252 |
+
|
253 |
+
return property;
|
254 |
+
}
|
255 |
+
|
256 |
+
} // namespace Ort
|
v1.15.1/onnxruntime-linux-armhf/include/provider_options.h
ADDED
@@ -0,0 +1,18 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
// Copyright (c) Microsoft Corporation. All rights reserved.
|
2 |
+
// Licensed under the MIT License.
|
3 |
+
|
4 |
+
#pragma once
|
5 |
+
|
6 |
+
#include <string>
|
7 |
+
#include <unordered_map>
|
8 |
+
#include <vector>
|
9 |
+
|
10 |
+
namespace onnxruntime {
|
11 |
+
|
12 |
+
// data types for execution provider options
|
13 |
+
|
14 |
+
using ProviderOptions = std::unordered_map<std::string, std::string>;
|
15 |
+
using ProviderOptionsVector = std::vector<ProviderOptions>;
|
16 |
+
using ProviderOptionsMap = std::unordered_map<std::string, ProviderOptions>;
|
17 |
+
|
18 |
+
} // namespace onnxruntime
|
v1.15.1/onnxruntime-linux-armhf/lib/libonnxruntime.so
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
libonnxruntime.so.1.15.1
|
v1.15.1/onnxruntime-linux-armhf/lib/libonnxruntime.so.1.15.1
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:74a56eea1395b75adab9174534ab413ca77435e20fd83f64294932c22a081fde
|
3 |
+
size 13943728
|