
working for single run

master · Cailean Finn, 2 months ago
commit a4e4aa7bf1
1. .gitignore (3)
2. README.md (5)
3. addon_config.mk (2)
4. libs/onnxruntime/include/core/providers/cuda/cuda_context.h (108)
5. libs/onnxruntime/include/core/providers/cuda/cuda_resource.h (23)
6. libs/onnxruntime/include/core/providers/custom_op_context.h (10)
7. libs/onnxruntime/include/core/providers/resource.h (14)
8. libs/onnxruntime/include/onnxruntime_c_api.h (5258)
9. libs/onnxruntime/include/onnxruntime_cxx_api.h (2556)
10. libs/onnxruntime/include/onnxruntime_cxx_inline.h (2011)
11. libs/onnxruntime/include/onnxruntime_float16.h (535)
12. libs/onnxruntime/include/onnxruntime_lite_custom_op.h (1119)
13. libs/onnxruntime/include/onnxruntime_run_options_config_keys.h (24)
14. libs/onnxruntime/include/onnxruntime_session_options_config_keys.h (229)
15. libs/onnxruntime/include/onnxruntime_training_c_api.h (731)
16. libs/onnxruntime/include/onnxruntime_training_cxx_api.h (418)
17. libs/onnxruntime/include/onnxruntime_training_cxx_inline.h (295)
18. libs/onnxruntime/include/tensorrt_provider_factory.h (14)
19. libs/onnxruntime/lib/osx/libonnxruntime.1.10.0.dylib (BIN)
20. src/ofxOnnxRuntime.cpp (127)
21. src/ofxOnnxRuntime.h (20)

.gitignore (3)

@@ -1,3 +1,6 @@
# Ignoring onnxruntime libs
/libs/onnxruntime/lib/*
example-*/config.make
example-*/*.sln
example-*/*.vcxproj

README.md (5)

@@ -1,4 +1,7 @@
# ofxOnnxRuntime
**Updated version, working with Windows 11, CUDA, and ONNXRuntime 1.20.1**
[ONNX Runtime](https://github.com/microsoft/onnxruntime) tiny wrapper for openFrameworks
!['test'](screenshot.png)
@@ -17,7 +20,7 @@
- From the `Browse` tab, search for `Microsoft.ML.OnnxRuntime` (CPU) or `Microsoft.ML.OnnxRuntime.Gpu` (GPU) and install it.
2. DLL direct download
- You can download prebuilt DLLs from [here](https://github.com/microsoft/onnxruntime/releases).
- Unzip the downloaded `onnxruntime-win-x64-(gpu-)1.10.0.zip` and place the files in `libs\onnxruntime\lib\vs\x64\`.
- Unzip the downloaded `onnxruntime-win-x64-(gpu-)1.20.1.zip` and place the files in `libs\onnxruntime\lib\vs\x64\`.
- Generate a project using the ProjectGenerator; all libs will then be linked correctly and all DLLs copied to `bin`.
## Tested environment

addon_config.mk (2)

@@ -11,4 +11,6 @@ common:
osx:
ADDON_LDFLAGS = -Xlinker -rpath -Xlinker @executable_path
vs:
ADDON_INCLUDES = libs/onnxruntime/include
ADDON_INCLUDES += src

libs/onnxruntime/include/core/providers/cuda/cuda_context.h (108)

@@ -0,0 +1,108 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
// This header is to expose a context for cuda custom ops.
// By the context, a custom cuda operator could fetch existing resources,
// such as cuda stream and cudnn handle, for reusing.
// For concrete usage, please see the page here:
// https://onnxruntime.ai/docs/reference/operators/add-custom-op.html#custom-ops-for-cuda-and-rocm
#pragma once
#define ORT_CUDA_CTX
#include <cuda.h>
#include <cuda_runtime.h>
#ifndef USE_CUDA_MINIMAL
#include <cublas_v2.h>
#include <cudnn.h>
#endif
#include "core/providers/cuda/cuda_resource.h"
#include "core/providers/custom_op_context.h"
namespace Ort {
namespace Custom {
struct CudaContext : public CustomOpContext {
cudaStream_t cuda_stream = {};
cudnnHandle_t cudnn_handle = {};
cublasHandle_t cublas_handle = {};
OrtAllocator* deferred_cpu_allocator = {};
// below are cuda ep options
int16_t device_id = 0;
int32_t arena_extend_strategy = 0;
int32_t cudnn_conv_algo_search = 0;
bool cudnn_conv_use_max_workspace = true;
bool cudnn_conv1d_pad_to_nc1d = false;
bool enable_skip_layer_norm_strict_mode = false;
bool prefer_nhwc = false;
bool use_tf32 = true;
bool fuse_conv_bias = true;
void Init(const OrtKernelContext& kernel_ctx) {
cuda_stream = FetchResource<cudaStream_t>(kernel_ctx, CudaResource::cuda_stream_t);
cudnn_handle = FetchResource<cudnnHandle_t>(kernel_ctx, CudaResource::cudnn_handle_t);
cublas_handle = FetchResource<cublasHandle_t>(kernel_ctx, CudaResource::cublas_handle_t);
deferred_cpu_allocator = FetchResource<OrtAllocator*>(kernel_ctx, CudaResource::deferred_cpu_allocator_t);
device_id = FetchResource<int16_t>(kernel_ctx, CudaResource::device_id_t);
arena_extend_strategy = FetchResource<int32_t>(kernel_ctx, CudaResource::arena_extend_strategy_t);
cudnn_conv_algo_search = FetchResource<int32_t>(kernel_ctx, CudaResource::cudnn_conv_algo_search_t);
cudnn_conv_use_max_workspace = FetchResource<bool>(kernel_ctx, CudaResource::cudnn_conv_use_max_workspace_t);
cudnn_conv1d_pad_to_nc1d = FetchResource<bool>(kernel_ctx, CudaResource::cudnn_conv1d_pad_to_nc1d_t);
enable_skip_layer_norm_strict_mode = FetchResource<bool>(
kernel_ctx, CudaResource::enable_skip_layer_norm_strict_mode_t);
prefer_nhwc = FetchResource<bool>(kernel_ctx, CudaResource::prefer_nhwc_t);
use_tf32 = FetchResource<bool>(kernel_ctx, CudaResource::use_tf32_t);
fuse_conv_bias = FetchResource<bool>(kernel_ctx, CudaResource::fuse_conv_bias_t);
}
template <typename T>
T FetchResource(const OrtKernelContext& kernel_ctx, CudaResource resource_type) {
if constexpr (sizeof(T) > sizeof(void*)) {
ORT_CXX_API_THROW("void* is not large enough to hold resource type: " + std::to_string(resource_type),
OrtErrorCode::ORT_INVALID_ARGUMENT);
}
const auto& ort_api = Ort::GetApi();
void* resource = {};
OrtStatus* status = ort_api.KernelContext_GetResource(
&kernel_ctx, ORT_CUDA_RESOURCE_VERSION, resource_type, &resource);
if (status) {
ORT_CXX_API_THROW("Failed to fetch cuda ep resource, resource type: " + std::to_string(resource_type),
OrtErrorCode::ORT_RUNTIME_EXCEPTION);
}
T t = {};
memcpy(&t, &resource, sizeof(T));
return t;
}
void* AllocDeferredCpuMem(size_t size) const {
if (0 == size) {
return {};
}
const auto& ort_api = Ort::GetApi();
void* mem = {};
auto status = ort_api.AllocatorAlloc(deferred_cpu_allocator, size, &mem);
if (status) {
ORT_CXX_API_THROW("failed to allocate deferred cpu memory", OrtErrorCode::ORT_RUNTIME_EXCEPTION);
}
return mem;
}
void FreeDeferredCpuMem(void* mem) const {
if (mem) {
const auto& ort_api = Ort::GetApi();
auto status = ort_api.AllocatorFree(deferred_cpu_allocator, mem);
if (status) {
ORT_CXX_API_THROW("failed to free deferred cpu memory", OrtErrorCode::ORT_RUNTIME_EXCEPTION);
}
}
}
};
} // namespace Custom
} // namespace Ort
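For orientation, here is a minimal sketch (not part of this header) of how a custom CUDA op kernel could use `CudaContext` to reuse the session's stream; `MyCustomKernel` and `launch_my_kernel` are hypothetical names:

```cpp
// Sketch only. MyCustomKernel / launch_my_kernel are hypothetical; error
// handling and input/output plumbing are elided.
#include "onnxruntime_cxx_api.h"
#include "core/providers/cuda/cuda_context.h"

void launch_my_kernel(cudaStream_t stream);  // hypothetical CUDA launch helper

struct MyCustomKernel {
  void Compute(OrtKernelContext* kernel_ctx) {
    Ort::Custom::CudaContext cuda_ctx;
    cuda_ctx.Init(*kernel_ctx);              // fetch stream/handles from the CUDA EP
    launch_my_kernel(cuda_ctx.cuda_stream);  // enqueue work on the shared stream
  }
};
```

Because the work is enqueued on the EP's own stream, it is ordered with the rest of the graph and needs no extra synchronization.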

libs/onnxruntime/include/core/providers/cuda/cuda_resource.h (23)

@@ -0,0 +1,23 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include "core/providers/resource.h"
#define ORT_CUDA_RESOURCE_VERSION 3
enum CudaResource : int {
cuda_stream_t = cuda_resource_offset, // 10000
cudnn_handle_t,
cublas_handle_t,
deferred_cpu_allocator_t,
// below are cuda ep options
device_id_t, // 10004
arena_extend_strategy_t,
cudnn_conv_algo_search_t,
cudnn_conv_use_max_workspace_t,
cudnn_conv1d_pad_to_nc1d_t,
enable_skip_layer_norm_strict_mode_t,
prefer_nhwc_t,
use_tf32_t,
fuse_conv_bias_t
};

libs/onnxruntime/include/core/providers/custom_op_context.h (10)

@@ -0,0 +1,10 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#pragma once
// CustomOpContext defines an interface allowing a custom op to access ep-specific resources.
struct CustomOpContext {
CustomOpContext() = default;
virtual ~CustomOpContext() {};
};

libs/onnxruntime/include/core/providers/resource.h (14)

@@ -0,0 +1,14 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#pragma once
enum ResourceOffset {
cpu_resource_offset = 0,
cuda_resource_offset = 10000,
dml_resource_offset = 20000,
rocm_resource_offset = 30000,
// offsets for other ort eps
custom_ep_resource_offset = 10000000,
// offsets for customized eps
};
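A small sanity-check sketch of how this header and cuda_resource.h above fit together: each provider's resource ids occupy its own offset range, so the CUDA enumerators resolve to the values noted in their comments.

```cpp
#include "core/providers/cuda/cuda_resource.h"  // pulls in resource.h

static_assert(CudaResource::cuda_stream_t == cuda_resource_offset,
              "first CUDA resource id sits at cuda_resource_offset (10000)");
static_assert(CudaResource::device_id_t == cuda_resource_offset + 4,
              "EP options follow the four stream/handle/allocator entries");
```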

libs/onnxruntime/include/onnxruntime_c_api.h (5258)

File diff suppressed because it is too large

libs/onnxruntime/include/onnxruntime_cxx_api.h (2556)

File diff suppressed because it is too large

libs/onnxruntime/include/onnxruntime_cxx_inline.h (2011)

File diff suppressed because it is too large

libs/onnxruntime/include/onnxruntime_float16.h (535)

@@ -0,0 +1,535 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#pragma once
#include <stdint.h>
#include <cmath>
#include <cstring>
#include <limits>
namespace onnxruntime_float16 {
namespace detail {
enum class endian {
#if defined(_WIN32)
little = 0,
big = 1,
native = little,
#elif defined(__GNUC__) || defined(__clang__)
little = __ORDER_LITTLE_ENDIAN__,
big = __ORDER_BIG_ENDIAN__,
native = __BYTE_ORDER__,
#else
#error onnxruntime_float16::detail::endian is not implemented in this environment.
#endif
};
static_assert(
endian::native == endian::little || endian::native == endian::big,
"Only little-endian or big-endian native byte orders are supported.");
} // namespace detail
/// <summary>
/// Shared implementation between public and internal classes. CRTP pattern.
/// </summary>
template <class Derived>
struct Float16Impl {
protected:
/// <summary>
/// Converts from float to uint16_t float16 representation
/// </summary>
/// <param name="v"></param>
/// <returns></returns>
constexpr static uint16_t ToUint16Impl(float v) noexcept;
/// <summary>
/// Converts float16 to float
/// </summary>
/// <returns>float representation of float16 value</returns>
float ToFloatImpl() const noexcept;
/// <summary>
/// Creates an instance that represents absolute value.
/// </summary>
/// <returns>Absolute value</returns>
uint16_t AbsImpl() const noexcept {
return static_cast<uint16_t>(val & ~kSignMask);
}
/// <summary>
/// Creates a new instance with the sign flipped.
/// </summary>
/// <returns>Flipped sign instance</returns>
uint16_t NegateImpl() const noexcept {
return IsNaN() ? val : static_cast<uint16_t>(val ^ kSignMask);
}
public:
// uint16_t special values
static constexpr uint16_t kSignMask = 0x8000U;
static constexpr uint16_t kBiasedExponentMask = 0x7C00U;
static constexpr uint16_t kPositiveInfinityBits = 0x7C00U;
static constexpr uint16_t kNegativeInfinityBits = 0xFC00U;
static constexpr uint16_t kPositiveQNaNBits = 0x7E00U;
static constexpr uint16_t kNegativeQNaNBits = 0xFE00U;
static constexpr uint16_t kMaxValueBits = 0x7BFFU; // Largest normal number
static constexpr uint16_t kOneBits = 0x3C00U;
static constexpr uint16_t kMinusOneBits = 0xBC00U;
uint16_t val{0};
Float16Impl() = default;
/// <summary>
/// Checks if the value is negative
/// </summary>
/// <returns>true if negative</returns>
bool IsNegative() const noexcept {
return static_cast<int16_t>(val) < 0;
}
/// <summary>
/// Tests if the value is NaN
/// </summary>
/// <returns>true if NaN</returns>
bool IsNaN() const noexcept {
return AbsImpl() > kPositiveInfinityBits;
}
/// <summary>
/// Tests if the value is finite
/// </summary>
/// <returns>true if finite</returns>
bool IsFinite() const noexcept {
return AbsImpl() < kPositiveInfinityBits;
}
/// <summary>
/// Tests if the value represents positive infinity.
/// </summary>
/// <returns>true if positive infinity</returns>
bool IsPositiveInfinity() const noexcept {
return val == kPositiveInfinityBits;
}
/// <summary>
/// Tests if the value represents negative infinity
/// </summary>
/// <returns>true if negative infinity</returns>
bool IsNegativeInfinity() const noexcept {
return val == kNegativeInfinityBits;
}
/// <summary>
/// Tests if the value is either positive or negative infinity.
/// </summary>
/// <returns>True if absolute value is infinity</returns>
bool IsInfinity() const noexcept {
return AbsImpl() == kPositiveInfinityBits;
}
/// <summary>
/// Tests if the value is NaN or zero. Useful for comparisons.
/// </summary>
/// <returns>True if NaN or zero.</returns>
bool IsNaNOrZero() const noexcept {
auto abs = AbsImpl();
return (abs == 0 || abs > kPositiveInfinityBits);
}
/// <summary>
/// Tests if the value is normal (not zero, subnormal, infinite, or NaN).
/// </summary>
/// <returns>True if so</returns>
bool IsNormal() const noexcept {
auto abs = AbsImpl();
return (abs < kPositiveInfinityBits) // is finite
&& (abs != 0) // is not zero
&& ((abs & kBiasedExponentMask) != 0); // is not subnormal (has a non-zero exponent)
}
/// <summary>
/// Tests if the value is subnormal (denormal).
/// </summary>
/// <returns>True if so</returns>
bool IsSubnormal() const noexcept {
auto abs = AbsImpl();
return (abs < kPositiveInfinityBits) // is finite
&& (abs != 0) // is not zero
&& ((abs & kBiasedExponentMask) == 0); // is subnormal (has a zero exponent)
}
/// <summary>
/// Creates an instance that represents absolute value.
/// </summary>
/// <returns>Absolute value</returns>
Derived Abs() const noexcept { return Derived::FromBits(AbsImpl()); }
/// <summary>
/// Creates a new instance with the sign flipped.
/// </summary>
/// <returns>Flipped sign instance</returns>
Derived Negate() const noexcept { return Derived::FromBits(NegateImpl()); }
/// <summary>
/// IEEE defines that positive and negative zero are equal, this gives us a quick equality check
/// for two values by or'ing the private bits together and stripping the sign. They are both zero,
/// and therefore equivalent, if the resulting value is still zero.
/// </summary>
/// <param name="lhs">first value</param>
/// <param name="rhs">second value</param>
/// <returns>True if both arguments represent zero</returns>
static bool AreZero(const Float16Impl& lhs, const Float16Impl& rhs) noexcept {
return static_cast<uint16_t>((lhs.val | rhs.val) & ~kSignMask) == 0;
}
bool operator==(const Float16Impl& rhs) const noexcept {
if (IsNaN() || rhs.IsNaN()) {
// IEEE defines that NaN is not equal to anything, including itself.
return false;
}
return val == rhs.val;
}
bool operator!=(const Float16Impl& rhs) const noexcept { return !(*this == rhs); }
bool operator<(const Float16Impl& rhs) const noexcept {
if (IsNaN() || rhs.IsNaN()) {
// IEEE defines that NaN is unordered with respect to everything, including itself.
return false;
}
const bool left_is_negative = IsNegative();
if (left_is_negative != rhs.IsNegative()) {
// When the signs of left and right differ, we know that left is less than right if it is
// the negative value. The exception to this is if both values are zero, in which case IEEE
// says they should be equal, even if the signs differ.
return left_is_negative && !AreZero(*this, rhs);
}
return (val != rhs.val) && ((val < rhs.val) ^ left_is_negative);
}
};
// The following Float16_t conversions are based on the code from
// Eigen library.
// The conversion routines are Copyright (c) Fabian Giesen, 2016.
// The original license follows:
//
// Copyright (c) Fabian Giesen, 2016
// All rights reserved.
// Redistribution and use in source and binary forms, with or without
// modification, are permitted.
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
// HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
namespace detail {
union float32_bits {
unsigned int u;
float f;
};
} // namespace detail
template <class Derived>
inline constexpr uint16_t Float16Impl<Derived>::ToUint16Impl(float v) noexcept {
detail::float32_bits f{};
f.f = v;
constexpr detail::float32_bits f32infty = {255 << 23};
constexpr detail::float32_bits f16max = {(127 + 16) << 23};
constexpr detail::float32_bits denorm_magic = {((127 - 15) + (23 - 10) + 1) << 23};
constexpr unsigned int sign_mask = 0x80000000u;
uint16_t val = static_cast<uint16_t>(0x0u);
unsigned int sign = f.u & sign_mask;
f.u ^= sign;
// NOTE all the integer compares in this function can be safely
// compiled into signed compares since all operands are below
// 0x80000000. Important if you want fast straight SSE2 code
// (since there's no unsigned PCMPGTD).
if (f.u >= f16max.u) { // result is Inf or NaN (all exponent bits set)
val = (f.u > f32infty.u) ? 0x7e00 : 0x7c00; // NaN->qNaN and Inf->Inf
} else { // (De)normalized number or zero
if (f.u < (113 << 23)) { // resulting FP16 is subnormal or zero
// use a magic value to align our 10 mantissa bits at the bottom of
// the float. as long as FP addition is round-to-nearest-even this
// just works.
f.f += denorm_magic.f;
// and one integer subtract of the bias later, we have our final float!
val = static_cast<uint16_t>(f.u - denorm_magic.u);
} else {
unsigned int mant_odd = (f.u >> 13) & 1; // resulting mantissa is odd
// update exponent, rounding bias part 1
// Equivalent to `f.u += ((unsigned int)(15 - 127) << 23) + 0xfff`, but
// without arithmetic overflow.
f.u += 0xc8000fffU;
// rounding bias part 2
f.u += mant_odd;
// take the bits!
val = static_cast<uint16_t>(f.u >> 13);
}
}
val |= static_cast<uint16_t>(sign >> 16);
return val;
}
template <class Derived>
inline float Float16Impl<Derived>::ToFloatImpl() const noexcept {
constexpr detail::float32_bits magic = {113 << 23};
constexpr unsigned int shifted_exp = 0x7c00 << 13; // exponent mask after shift
detail::float32_bits o{};
o.u = (val & 0x7fff) << 13; // exponent/mantissa bits
unsigned int exp = shifted_exp & o.u; // just the exponent
o.u += (127 - 15) << 23; // exponent adjust
// handle exponent special cases
if (exp == shifted_exp) { // Inf/NaN?
o.u += (128 - 16) << 23; // extra exp adjust
} else if (exp == 0) { // Zero/Denormal?
o.u += 1 << 23; // extra exp adjust
o.f -= magic.f; // re-normalize
}
// Attempt to work around the Internal Compiler Error on ARM64
// for bitwise | operator, including std::bitset
#if (defined _MSC_VER) && (defined _M_ARM || defined _M_ARM64 || defined _M_ARM64EC)
if (IsNegative()) {
return -o.f;
}
#else
// original code:
o.u |= (val & 0x8000U) << 16U; // sign bit
#endif
return o.f;
}
/// <summary>
/// Shared implementation between public and internal classes. CRTP pattern.
/// </summary>
template <class Derived>
struct BFloat16Impl {
protected:
/// <summary>
/// Converts from float to uint16_t float16 representation
/// </summary>
/// <param name="v"></param>
/// <returns></returns>
static uint16_t ToUint16Impl(float v) noexcept;
/// <summary>
/// Converts bfloat16 to float
/// </summary>
/// <returns>float representation of bfloat16 value</returns>
float ToFloatImpl() const noexcept;
/// <summary>
/// Creates an instance that represents absolute value.
/// </summary>
/// <returns>Absolute value</returns>
uint16_t AbsImpl() const noexcept {
return static_cast<uint16_t>(val & ~kSignMask);
}
/// <summary>
/// Creates a new instance with the sign flipped.
/// </summary>
/// <returns>Flipped sign instance</returns>
uint16_t NegateImpl() const noexcept {
return IsNaN() ? val : static_cast<uint16_t>(val ^ kSignMask);
}
public:
// uint16_t special values
static constexpr uint16_t kSignMask = 0x8000U;
static constexpr uint16_t kBiasedExponentMask = 0x7F80U;
static constexpr uint16_t kPositiveInfinityBits = 0x7F80U;
static constexpr uint16_t kNegativeInfinityBits = 0xFF80U;
static constexpr uint16_t kPositiveQNaNBits = 0x7FC1U;
static constexpr uint16_t kNegativeQNaNBits = 0xFFC1U;
static constexpr uint16_t kMaxValueBits = 0x7F7FU;
static constexpr uint16_t kRoundToNearest = 0x7FFFU;
static constexpr uint16_t kOneBits = 0x3F80U;
static constexpr uint16_t kMinusOneBits = 0xBF80U;
uint16_t val{0};
BFloat16Impl() = default;
/// <summary>
/// Checks if the value is negative
/// </summary>
/// <returns>true if negative</returns>
bool IsNegative() const noexcept {
return static_cast<int16_t>(val) < 0;
}
/// <summary>
/// Tests if the value is NaN
/// </summary>
/// <returns>true if NaN</returns>
bool IsNaN() const noexcept {
return AbsImpl() > kPositiveInfinityBits;
}
/// <summary>
/// Tests if the value is finite
/// </summary>
/// <returns>true if finite</returns>
bool IsFinite() const noexcept {
return AbsImpl() < kPositiveInfinityBits;
}
/// <summary>
/// Tests if the value represents positive infinity.
/// </summary>
/// <returns>true if positive infinity</returns>
bool IsPositiveInfinity() const noexcept {
return val == kPositiveInfinityBits;
}
/// <summary>
/// Tests if the value represents negative infinity
/// </summary>
/// <returns>true if negative infinity</returns>
bool IsNegativeInfinity() const noexcept {
return val == kNegativeInfinityBits;
}
/// <summary>
/// Tests if the value is either positive or negative infinity.
/// </summary>
/// <returns>True if absolute value is infinity</returns>
bool IsInfinity() const noexcept {
return AbsImpl() == kPositiveInfinityBits;
}
/// <summary>
/// Tests if the value is NaN or zero. Useful for comparisons.
/// </summary>
/// <returns>True if NaN or zero.</returns>
bool IsNaNOrZero() const noexcept {
auto abs = AbsImpl();
return (abs == 0 || abs > kPositiveInfinityBits);
}
/// <summary>
/// Tests if the value is normal (not zero, subnormal, infinite, or NaN).
/// </summary>
/// <returns>True if so</returns>
bool IsNormal() const noexcept {
auto abs = AbsImpl();
return (abs < kPositiveInfinityBits) // is finite
&& (abs != 0) // is not zero
&& ((abs & kBiasedExponentMask) != 0); // is not subnormal (has a non-zero exponent)
}
/// <summary>
/// Tests if the value is subnormal (denormal).
/// </summary>
/// <returns>True if so</returns>
bool IsSubnormal() const noexcept {
auto abs = AbsImpl();
return (abs < kPositiveInfinityBits) // is finite
&& (abs != 0) // is not zero
&& ((abs & kBiasedExponentMask) == 0); // is subnormal (has a zero exponent)
}
/// <summary>
/// Creates an instance that represents absolute value.
/// </summary>
/// <returns>Absolute value</returns>
Derived Abs() const noexcept { return Derived::FromBits(AbsImpl()); }
/// <summary>
/// Creates a new instance with the sign flipped.
/// </summary>
/// <returns>Flipped sign instance</returns>
Derived Negate() const noexcept { return Derived::FromBits(NegateImpl()); }
/// <summary>
/// IEEE defines that positive and negative zero are equal, this gives us a quick equality check
/// for two values by or'ing the private bits together and stripping the sign. They are both zero,
/// and therefore equivalent, if the resulting value is still zero.
/// </summary>
/// <param name="lhs">first value</param>
/// <param name="rhs">second value</param>
/// <returns>True if both arguments represent zero</returns>
static bool AreZero(const BFloat16Impl& lhs, const BFloat16Impl& rhs) noexcept {
// IEEE defines that positive and negative zero are equal, this gives us a quick equality check
// for two values by or'ing the private bits together and stripping the sign. They are both zero,
// and therefore equivalent, if the resulting value is still zero.
return static_cast<uint16_t>((lhs.val | rhs.val) & ~kSignMask) == 0;
}
};
template <class Derived>
inline uint16_t BFloat16Impl<Derived>::ToUint16Impl(float v) noexcept {
uint16_t result;
if (std::isnan(v)) {
result = kPositiveQNaNBits;
} else {
auto get_msb_half = [](float fl) {
uint16_t result;
#ifdef __cpp_if_constexpr
if constexpr (detail::endian::native == detail::endian::little) {
#else
if (detail::endian::native == detail::endian::little) {
#endif
std::memcpy(&result, reinterpret_cast<char*>(&fl) + sizeof(uint16_t), sizeof(uint16_t));
} else {
std::memcpy(&result, &fl, sizeof(uint16_t));
}
return result;
};
uint16_t upper_bits = get_msb_half(v);
union {
uint32_t U32;
float F32;
};
F32 = v;
U32 += (upper_bits & 1) + kRoundToNearest;
result = get_msb_half(F32);
}
return result;
}
template <class Derived>
inline float BFloat16Impl<Derived>::ToFloatImpl() const noexcept {
if (IsNaN()) {
return std::numeric_limits<float>::quiet_NaN();
}
float result;
char* const first = reinterpret_cast<char*>(&result);
char* const second = first + sizeof(uint16_t);
#ifdef __cpp_if_constexpr
if constexpr (detail::endian::native == detail::endian::little) {
#else
if (detail::endian::native == detail::endian::little) {
#endif
std::memset(first, 0, sizeof(uint16_t));
std::memcpy(second, &val, sizeof(uint16_t));
} else {
std::memcpy(first, &val, sizeof(uint16_t));
std::memset(second, 0, sizeof(uint16_t));
}
return result;
}
} // namespace onnxruntime_float16
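These CRTP bases are only used through a derived type that supplies `FromBits` (in this release, `Ort::Float16_t`/`Ort::BFloat16_t` in onnxruntime_cxx_api.h). A minimal sketch with a hypothetical derived type, just to show the round trip:

```cpp
#include <cassert>
#include "onnxruntime_float16.h"

// Hypothetical minimal derived type; in practice use Ort::Float16_t from
// onnxruntime_cxx_api.h, which derives from Float16Impl the same way.
struct Half : onnxruntime_float16::Float16Impl<Half> {
  Half() = default;
  explicit Half(float f) noexcept { val = ToUint16Impl(f); }
  static Half FromBits(uint16_t bits) noexcept {  // required by Abs()/Negate()
    Half h;
    h.val = bits;
    return h;
  }
  float ToFloat() const noexcept { return ToFloatImpl(); }
};

int main() {
  Half h(1.5f);
  assert(h.val == 0x3E00);      // 1.5 in binary16: sign 0, biased exponent 15, mantissa 0b1000000000
  assert(h.ToFloat() == 1.5f);  // exactly representable, so the round trip is exact
  assert(h.Negate().IsNegative() && !h.IsNegative());
  return 0;
}
```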

libs/onnxruntime/include/onnxruntime_lite_custom_op.h (1119)

File diff suppressed because it is too large

libs/onnxruntime/include/onnxruntime_run_options_config_keys.h (24)

@@ -25,3 +25,27 @@
// Example usage: "cpu:0;gpu:0" (or) "gpu:0"
// By default, the value for this key is empty (i.e.) no memory arenas are shrunk
static const char* const kOrtRunOptionsConfigEnableMemoryArenaShrinkage = "memory.enable_memory_arena_shrinkage";
// Set to '1' to not synchronize execution providers with CPU at the end of session run.
// By default it is set to '0'.
// Taking the CUDA EP as an example, this omits triggering cudaStreamSynchronize on the compute stream.
static const char* const kOrtRunOptionsConfigDisableSynchronizeExecutionProviders = "disable_synchronize_execution_providers";
// Set HTP performance mode for QNN HTP backend before session run.
// options for HTP performance mode: "burst", "balanced", "default", "high_performance",
// "high_power_saver", "low_balanced", "extreme_power_saver", "low_power_saver", "power_saver",
// "sustained_high_performance". Default to "default".
static const char* const kOrtRunOptionsConfigQnnPerfMode = "qnn.htp_perf_mode";
// Set HTP performance mode for QNN HTP backend post session run.
static const char* const kOrtRunOptionsConfigQnnPerfModePostRun = "qnn.htp_perf_mode_post_run";
// Set RPC control latency for QNN HTP backend
static const char* const kOrtRunOptionsConfigQnnRpcControlLatency = "qnn.rpc_control_latency";
// Set graph annotation id for CUDA EP. Use with enable_cuda_graph=true.
// The value should be an integer. If the value is not set, the default value is 0 and
// ORT session only captures one cuda graph before another capture is requested.
// If the value is set to -1, cuda graph capture/replay is disabled in that run.
// Users are not expected to set the value to 0 as it is reserved for internal use.
static const char* const kOrtRunOptionsConfigCudaGraphAnnotation = "gpu_graph_id";
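These keys are plain strings passed through the run-options config API. A minimal sketch of setting a couple of them with the C++ API's `Ort::RunOptions::AddConfigEntry` (the helper name is ours):

```cpp
#include "onnxruntime_cxx_api.h"
#include "onnxruntime_run_options_config_keys.h"

// Hypothetical helper: builds RunOptions using the keys defined above.
Ort::RunOptions MakeRunOptions() {
  Ort::RunOptions run_options;
  // Skip the end-of-run device sync (e.g. cudaStreamSynchronize on the CUDA EP).
  run_options.AddConfigEntry(kOrtRunOptionsConfigDisableSynchronizeExecutionProviders, "1");
  // Shrink the CPU memory arena once this run completes.
  run_options.AddConfigEntry(kOrtRunOptionsConfigEnableMemoryArenaShrinkage, "cpu:0");
  return run_options;
}
```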

libs/onnxruntime/include/onnxruntime_session_options_config_keys.h (229)

@@ -44,13 +44,69 @@ static const char* const kOrtSessionOptionsConfigSetDenormalAsZero = "session.se
// It controls whether to run a quantized model in QDQ (QuantizeLinear/DeQuantizeLinear) format or not.
// "0": enable. ORT does fusion logic for QDQ format.
// "1": disable. ORT doesn't do fusion logic for QDQ format.
// Its default value is "0"
// Its default value is "0" unless the DirectML execution provider is registered, in which case it defaults to "1".
static const char* const kOrtSessionOptionsDisableQuantQDQ = "session.disable_quant_qdq";
// It controls whether to enable the Double QDQ remover and Identical Children Consolidation.
// "0": not disabled. ORT removes the middle 2 nodes from Q->(DQ->Q)->DQ pairs.
// "1": disabled. ORT doesn't remove the middle 2 nodes from Q->(DQ->Q)->DQ pairs.
// Its default value is "0".
static const char* const kOrtSessionOptionsDisableDoubleQDQRemover = "session.disable_double_qdq_remover";
// If set to "1", enables the removal of QuantizeLinear/DequantizeLinear node pairs once all QDQ handling has been
// completed. e.g. If after all QDQ handling has completed and we have -> FloatOp -> Q -> DQ -> FloatOp -> the
// Q -> DQ could potentially be removed. This will provide a performance benefit by avoiding going from float to
// 8-bit and back to float, but could impact accuracy. The impact on accuracy will be model specific and depend on
// other factors like whether the model was created using Quantization Aware Training or Post Training Quantization.
// As such, it's best to test to determine if enabling this works well for your scenario.
// The default value is "0"
// Available since version 1.11.
static const char* const kOrtSessionOptionsEnableQuantQDQCleanup = "session.enable_quant_qdq_cleanup";
// Enable or disable gelu approximation in graph optimization. "0": disable; "1": enable. The default is "0".
// GeluApproximation has side effects which may change the inference results. It is disabled by default due to this.
static const char* const kOrtSessionOptionsEnableGeluApproximation = "optimization.enable_gelu_approximation";
// This setting controls whether to enable AheadOfTime function inlining.
// AOT function inlining examines the graph and attempts to inline as many locally defined functions in the model
// as possible with the help of enabled execution providers.
// This can reduce the number of function calls and improve performance because it is done before
// Level1 optimizers and constant folding. However, under some circumstances, when the EPs are not available,
// one can disable the AOT inlining, produce an optimized model and postpone AOT until run time.
// "0": enable; "1": disable.
// Its default value is "0".
static const char* const kOrtSessionOptionsDisableAheadOfTimeFunctionInlining = "session.disable_aot_function_inlining";
#ifdef ENABLE_TRAINING
// Specifies a path of the file containing a list of memory optimization configurations.
// The value should be a string indicating the file path of the config file.
// The content of the config file is a JSON struct like this:
// [
// "Gelu+Cast+:1:0",
// "Dropout+:1:1"
// ]
// Taking the example of "Gelu+Cast+:1:0",
// > "Gelu+Cast+" is the subgraph string, a valid "subgraph string" should be one subgraph representation
// output by ORT graph transformations.
// > "1" is "optimization strategy", valid values: 0 - disabled, 1 - recompute.
// > "0" is "number of subgraph to apply" which is used to control how many subgraphs to apply optimization,
// to avoid "oversaving" the memory.
static const char* const kOrtSessionOptionsMemoryOptimizerApplyConfig = "optimization.memory_optimizer_config";
// Specifies the config for detecting subgraphs for memory footprint reduction.
// The value should be a string containing integers separated by commas. The default value is "0:0".
static const char* const kOrtSessionOptionsMemoryOptimizerProbeConfig = "optimization.enable_memory_probe_recompute_config";
#endif
// This setting, if set, should contain a comma-separated list of optimizer names that should be disabled.
// Optimizers may take time to execute and affect model loading time. If you feel that a specific optimizer
// does not provide runtime benefits, but affects your model loading time, you may disable it using this config
// entry. This option is not enabled in an ORT_MINIMAL_BUILD build.
// A list of optimizers is available in onnxruntime/core/optimizer/graph_transformer_utils.cc
//
// Default is an empty string which means no optimizers are disabled.
static const char* const kOrtSessionOptionsDisableSpecifiedOptimizers = "optimization.disable_specified_optimizers";
// Enable or disable using device allocator for allocating initialized tensor memory. "1": enable; "0": disable. The default is "0".
// Using device allocators means the memory allocation is made using malloc/new.
static const char* const kOrtSessionOptionsUseDeviceAllocatorForInitializers = "session.use_device_allocator_for_initializers";
@@ -69,23 +125,36 @@ static const char* const kOrtSessionOptionsConfigAllowIntraOpSpinning = "session
// has to guarantee that the model bytes are valid until the ORT session using the model bytes is destroyed.
static const char* const kOrtSessionOptionsConfigUseORTModelBytesDirectly = "session.use_ort_model_bytes_directly";
// Save information for replaying graph optimizations later instead of applying them directly.
//
// When an ONNX model is loaded, ORT can perform various optimizations on the graph.
// However, when an ORT format model is loaded, these optimizations are typically not available - this scenario must
// be supported by minimal builds.
// When loading an ONNX model, ORT can optionally save the effects of some optimizations for later replay in an ORT
// format model. These are known as "runtime optimizations" - in an ORT format model, they happen at runtime.
//
// Note: This option is only applicable when loading an ONNX model and saving an ORT format model.
//
// Note: Runtime optimizations are only supported for certain optimizations at the extended level or higher.
// Unsupported optimizations at those levels are not applied at all, while optimizations at other levels are applied
// directly.
//
// "0": disabled, "1": enabled
// The default is "0".
static const char* const kOrtSessionOptionsConfigSaveRuntimeOptimizations = "optimization.save_runtime_optimizations";
/// <summary>
/// Key for using the ORT format model flatbuffer bytes directly for initializers.
/// This avoids copying the bytes and reduces peak memory usage during model loading and initialization.
/// Requires `session.use_ort_model_bytes_directly` to be true.
/// If set, the flatbuffer bytes provided when creating the InferenceSession MUST remain valid for the entire
/// duration of the InferenceSession.
/// </summary>
static const char* const kOrtSessionOptionsConfigUseORTModelBytesForInitializers =
"session.use_ort_model_bytes_for_initializers";
// This should only be specified when exporting an ORT format model for use on a different platform.
// If the ORT format model will be used on ARM platforms, set to "1". For other platforms, set to "0".
// Available since version 1.11.
static const char* const kOrtSessionOptionsQDQIsInt8Allowed = "session.qdqisint8allowed";
// x64 SSE4.1/AVX2/AVX512 (with no VNNI) has an overflow problem with quantized matrix multiplication with U8S8.
// To avoid this we need to use the slower U8U8 matrix multiplication instead. This option, if
// turned on, uses the slower U8U8 matrix multiplications. Only effective on AVX2 or AVX512
// platforms.
static const char* const kOrtSessionOptionsAvx2PrecisionMode = "session.x64quantprecision";
// Specifies how minimal build graph optimizations are handled in a full build.
// These optimizations are at the extended level or higher.
// Possible values and their effects are:
// "save": Save runtime optimizations when saving an ORT format model.
// "apply": Only apply optimizations available in a minimal build.
// ""/<unspecified>: Apply optimizations available in a full build.
// Available since version 1.11.
static const char* const kOrtSessionOptionsConfigMinimalBuildOptimizations =
"optimization.minimal_build_optimizations";
// Note: The options specific to an EP should be specified prior to appending that EP to the session options object in
// order for them to take effect.
@@ -96,3 +165,127 @@ static const char* const kOrtSessionOptionsConfigSaveRuntimeOptimizations = "opt
// If not specified, the default set of stop ops is used. To specify an empty stop ops types list and disable stop op
// exclusion, set the value to "".
static const char* const kOrtSessionOptionsConfigNnapiEpPartitioningStopOps = "ep.nnapi.partitioning_stop_ops";
// Enables dynamic block-sizing for multithreading.
// With a positive value, the thread pool will split a task of N iterations into blocks of size starting from:
// N / (num_of_threads * dynamic_block_base)
// As execution progresses, the size will decrease according to the diminishing residual of N,
// meaning the task will be distributed at a smaller granularity for better parallelism.
// For some models, it helps to reduce the variance of E2E inference latency and boost performance.
// The feature is disabled by default; specify any positive integer, e.g. "4", to enable it.
// Available since version 1.11.
static const char* const kOrtSessionOptionsConfigDynamicBlockBase = "session.dynamic_block_base";
// This option decreases CPU usage between infrequent requests by forcing any spinning
// thread-pool (TP) threads to stop immediately when the last concurrent Run() call returns.
// Spinning is restarted on the next Run() call.
// Applies only to internal thread-pools.
static const char* const kOrtSessionOptionsConfigForceSpinningStop = "session.force_spinning_stop";
// "1": all inconsistencies encountered during shape and type inference
// will result in failures.
// "0": in some cases warnings will be logged but processing will continue. The default.
// May be useful to expose bugs in models.
static const char* const kOrtSessionOptionsConfigStrictShapeTypeInference = "session.strict_shape_type_inference";
// "1": every model using a more recent opset than the latest released one will fail
// "0": the model may or may not work if onnxruntime cannot find an implementation, this option
// is used for development purpose.
static const char* const kOrtSessionOptionsConfigStrictAllowReleasedOpsetsOnly = "session.allow_released_opsets_only";
// The file saves the configuration for partitioning nodes among logic streams
static const char* const kNodePartitionConfigFile = "session.node_partition_config_file";
// This option allows setting affinities for intra-op threads.
// The affinity string follows the format:
// logical_processor_id,logical_processor_id;logical_processor_id,logical_processor_id
// Semicolons separate the configurations of individual threads, while commas separate the processors the i-th thread is expected to attach to.
// e.g. 1,2,3;4,5
// specifies affinities for two threads, with the 1st thread attached to the 1st, 2nd, and 3rd processors, and the 2nd thread to the 4th and 5th.
// To ease configuration, an "interval" is also allowed:
// e.g. 1-8;8-16;17-24
// specifies that the 1st thread runs on the first eight processors, the 2nd thread on the next eight processors, and so forth.
// Note:
// 1. Once set, the number of thread affinities must equal intra_op_num_threads - 1, since ORT does not set affinity on the main thread, which
// is started and managed by the calling app;
// 2. For Windows, ORT will infer the group id from a logical processor id. For example, assuming there are two groups, each with 64 logical processors,
// an id of 64 will be inferred as the last processor of the 1st group, while 65 will be interpreted as the 1st processor of the second group.
// Hence 64-65 is an invalid configuration, because a Windows thread cannot be attached to processors across a group boundary.
static const char* const kOrtSessionOptionsConfigIntraOpThreadAffinities = "session.intra_op_thread_affinities";
// This option will dump out the model to assist debugging any issues with layout transformation,
// and is primarily intended for developer usage. It is only relevant if an execution provider that requests
// NHWC layout is enabled such as NNAPI, XNNPACK or QNN.
//
// Default is off. Set to "1" to enable.
//
// If modified by layout transformation the model will be dumped after these steps:
// 1) insertion of the layout transformation Transpose nodes,
// 2) optimization of those nodes using the transpose optimizer,
// 3) application of the L1 transformers to the updated graph.
// The model will be saved to filename post_layout_transform_step_<step_number>.onnx.
static const char* const kDebugLayoutTransformation = "session.debug_layout_transformation";
// Graph nodes that are not supported by the execution providers (EPs) explicitly added to the session are
// assigned (i.e., "fallback") to the CPU EP by default.
//
// This option allows the user to disable the fallback of unsupported graph nodes to the CPU EP.
// If this option is set to "1", session creation will fail if the execution providers other than the CPU EP cannot
// fully support all of the nodes in the graph.
//
// It is invalid to set this option and explicitly add the CPU EP to the session. In this case, session creation
// will also fail with an error.
//
// Option values:
// - "0": CPU EP fallback is not disabled. [DEFAULT]
// - "1": CPU EP fallback is disabled.
static const char* const kOrtSessionOptionsDisableCPUEPFallback = "session.disable_cpu_ep_fallback";
// Use this config when serializing a large model after optimization to specify an external initializers file
static const char* const kOrtSessionOptionsOptimizedModelExternalInitializersFileName =
"session.optimized_model_external_initializers_file_name";
// Use this config to control the minimum size of the initializer when externalizing it during serialization
static const char* const kOrtSessionOptionsOptimizedModelExternalInitializersMinSizeInBytes =
"session.optimized_model_external_initializers_min_size_in_bytes";
// Enable the EP context feature to dump the partitioned graph, which includes the EP context, into an Onnx file.
// The dumped Onnx model with EP context can be used for future inference to avoid the EP graph partitioning/compile overhead.
// "0": disable. (default)
// "1": enable.
static const char* const kOrtSessionOptionEpContextEnable = "ep.context_enable";
// Specify the file path for the Onnx model which has EP context.
// Defaults to original_file_name_ctx.onnx if not specified
static const char* const kOrtSessionOptionEpContextFilePath = "ep.context_file_path";
// Flag to specify whether to dump the EP context into the Onnx model.
// "0": dump the EP context into separate file, keep the file name in the Onnx model.
// "1": dump the EP context into the Onnx model. (default).
static const char* const kOrtSessionOptionEpContextEmbedMode = "ep.context_embed_mode";
// Specify the EPContext node name prefix to make it unique
// in case the user needs to merge/connect multiple EPContext nodes in one model
static const char* const kOrtSessionOptionEpContextNodeNamePrefix = "ep.context_node_name_prefix";
// Share EP related resources across EPs
static const char* const kOrtSessionOptionShareEpContexts = "ep.share_ep_contexts";
// Gemm fastmath mode provides fp32 gemm acceleration with bfloat16 based matmul.
// Option values:
// - "0": Gemm FastMath mode is not enabled. [DEFAULT]
// - "1": Gemm FastMath mode is enabled.
static const char* const kOrtSessionOptionsMlasGemmFastMathArm64Bfloat16 = "mlas.enable_gemm_fastmath_arm64_bfloat16";
// When converting DQ + MatMul -> MatMulNBits, the accuracy level of the MatMulNBits is controlled by this option.
// Refer to MatMulNBits op schema for more details.
// If not provided, default is 4.
static const char* const kOrtSessionOptionsQDQMatMulNBitsAccuracyLevel = "session.qdq_matmulnbits_accuracy_level";
// THIS OPTION IS NOT A REGULAR SESSION OPTION SINCE IT CAN BE MODIFIED AT ANY TIME
// Meant to be used with SetEpDynamicOptions
// Specify the type of workload for this session.
// "Default": OS determines the scheduling priority and processor performance to service this workload. [Default]
// "Efficient": OS treats this workload as efficiency oriented, with low scheduling priority and efficient processor performance.
static const char* const kOrtEpDynamicOptionsWorkloadType = "ep.dynamic.workload_type";
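As with the run options above, these keys are set through `Ort::SessionOptions::AddConfigEntry` before the session is created. A minimal sketch combining a few of them (the helper name is ours):

```cpp
#include "onnxruntime_cxx_api.h"
#include "onnxruntime_session_options_config_keys.h"

// Hypothetical helper: builds SessionOptions using the keys defined above.
Ort::SessionOptions MakeSessionOptions() {
  Ort::SessionOptions so;
  so.AddConfigEntry(kOrtSessionOptionsConfigStrictShapeTypeInference, "1");
  so.AddConfigEntry(kOrtSessionOptionsDisableCPUEPFallback, "1");  // fail rather than fall back to CPU
  // 1 main + 2 worker threads; the affinity string must cover exactly the 2 workers.
  so.SetIntraOpNumThreads(3);
  so.AddConfigEntry(kOrtSessionOptionsConfigIntraOpThreadAffinities, "1-4;5-8");
  return so;
}
```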

libs/onnxruntime/include/onnxruntime_training_c_api.h (731)

@@ -0,0 +1,731 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
// This file contains the training c apis.
#pragma once
#include <stdbool.h>
#include "onnxruntime_c_api.h"
/** \page training_c_cpp_api Training C & C++ APIs
*
* Training C and C++ APIs are an extension of the \ref c_cpp_api "onnxruntime core C and C++ APIs" and should be used in conjunction with them.
*
* In order to train a model with onnxruntime, the following training artifacts must be generated:
* - The training onnx model
* - The checkpoint file
* - The optimizer onnx model
* - The eval onnx model (optional)
*
* These training artifacts can be generated as part of an offline step using the python [utilities](https://github.com/microsoft/onnxruntime/blob/main/orttraining/orttraining/python/training/onnxblock/README.md) made available in the `onnxruntime-training` python package.
*
* After these artifacts have been generated, the C and C++ utilities listed in this documentation can be leveraged to perform training.
*
* If any problem is encountered, please create an [issue](https://github.com/microsoft/onnxruntime/issues/new) with your scenario and requirements, and we will be sure to respond and follow up on the request.
*
* <h1>Training C API</h1>
*
* ::OrtTrainingApi - Training C API functions.
*
* This C structure contains functions that enable users to perform training with onnxruntime.
*
* _Sample Code_:
*
* ```c
* #include <onnxruntime_training_api.h>
*
* OrtApi* g_ort_api = OrtGetApiBase()->GetApi(ORT_API_VERSION);
* OrtTrainingApi* g_ort_training_api = g_ort_api->GetTrainingApi(ORT_API_VERSION);
*
* OrtEnv* env = NULL;
* g_ort_api->CreateEnv(logging_level, logid, &env);
* OrtSessionOptions* session_options = NULL;
* g_ort_api->CreateSessionOptions(&session_options);
*
* OrtCheckpointState* state = NULL;
* g_ort_training_api->LoadCheckpoint(path_to_checkpoint, &state);
*
* OrtTrainingSession* training_session = NULL;
* g_ort_training_api->CreateTrainingSession(env, session_options, training_model_path,
* state, eval_model_path, optimizer_model_path,
* &training_session);
* // Training loop
* {
* g_ort_training_api->TrainStep(...);
* g_ort_training_api->OptimizerStep(...);
* g_ort_training_api->LazyResetGrad(...);
* }
*
* g_ort_training_api->ExportModelForInferencing(training_session, inference_model_path, ...);
* g_ort_training_api->SaveCheckpoint(state, path_to_checkpoint, false);
*
* g_ort_training_api->ReleaseTrainingSession(training_session);
* g_ort_training_api->ReleaseCheckpointState(state);
* ```
*
* > **Note**
* > The ::OrtCheckpointState contains the entire training state that the ::OrtTrainingSession uses. As a result, the training session must always have access to the state. That is to say, the ::OrtCheckpointState instance must outlive the lifetime of the ::OrtTrainingSession instance.
*
* <h1>Training C++ API</h1>
*
* @ref TrainingCpp - Training C++ API classes and functions.
*
* These C++ classes and functions enable users to perform training with onnxruntime.
*
* _Sample Code_:
*
* ```cc
* #include <onnxruntime_training_cxx_api.h>
*
* Ort::Env env;
* Ort::SessionOptions session_options;
*
* auto state = Ort::CheckpointState::LoadCheckpoint(path_to_checkpoint);
* auto training_session = Ort::TrainingSession(env, session_options, state, training_model_path,
* eval_model_path, optimizer_model_path);
*
* // Training Loop
* {
* training_session.TrainStep(...);
* training_session.OptimizerStep(...);
* training_session.LazyResetGrad(...);
* }
*
* training_session->ExportModelForInferencing(inference_model_path, ...);
* Ort::CheckpointState::SaveCheckpoint(state, path_to_checkpoint, false);
* ```
* > **Note**
* > The ::Ort::CheckpointState contains the entire training state that the ::Ort::TrainingSession uses. As a result, the training session must always have access to the state. That is to say, the ::Ort::CheckpointState instance must outlive the lifetime of the ::Ort::TrainingSession instance.
*/
/** @defgroup TrainingC Ort Training C API
* @{
*/
ORT_RUNTIME_CLASS(TrainingSession); // Type that enables performing training for the given user models.
ORT_RUNTIME_CLASS(CheckpointState); // Type that holds the training states for the training session.
/** \brief Type of property to be added to or returned from the ::OrtCheckpointState.
*/
typedef enum OrtPropertyType {
OrtIntProperty = 0,
OrtFloatProperty = 1,
OrtStringProperty = 2,
} OrtPropertyType;
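For reference, these property types surface in the C++ training API as a variant; a sketch assuming `Ort::CheckpointState::AddProperty` and `Ort::Property` from onnxruntime_training_cxx_api.h:

```cpp
#include "onnxruntime_training_cxx_api.h"

// Sketch: Ort::Property is a std::variant whose alternatives map onto
// OrtPropertyType (int64_t -> OrtIntProperty, float -> OrtFloatProperty,
// std::string -> OrtStringProperty).
void TagCheckpoint(Ort::CheckpointState& state) {
  state.AddProperty("epoch", static_cast<int64_t>(5));
  state.AddProperty("best_loss", 0.1234f);
  state.AddProperty("run_name", std::string("baseline"));
}
```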
/** \brief The Training C API that holds onnxruntime training function pointers
*
* All the Training C API functions are defined inside this structure as pointers to functions.
* Call OrtApi::GetTrainingApi to get a pointer to this struct.
*
* \nosubgrouping
*/
struct OrtTrainingApi {
/// \name Accessing The Training Session State
/// @{
/** \brief Load a checkpoint state from a file on disk into checkpoint_state.
*
* This function will parse a checkpoint file, pull relevant data and load the training
* state into the checkpoint_state. This checkpoint state can then be used to create the
* training session by invoking OrtTrainingApi::CreateTrainingSession. By doing so, the training
* session will resume training from the given checkpoint state.
* \note Note that the training session created with a checkpoint state uses this state to store the entire
* training state (including model parameters, its gradients, the optimizer states and the properties).
* As a result, it is required that the checkpoint state outlive the lifetime of the training session.
* \note Note that the checkpoint file can be either the complete checkpoint or the nominal checkpoint.
*
* \param[in] checkpoint_path Path to the checkpoint file
* \param[out] checkpoint_state Checkpoint state that contains the states of the training session.
*
* \snippet{doc} snippets.dox OrtStatus Return Value
*
*/
ORT_API2_STATUS(LoadCheckpoint, _In_ const ORTCHAR_T* checkpoint_path,
_Outptr_ OrtCheckpointState** checkpoint_state);
/** \brief Save the given state to a checkpoint file on disk.
*
* This function serializes the provided checkpoint state to a file on disk.
* This checkpoint can later be loaded by invoking OrtTrainingApi::LoadCheckpoint to resume
* training from this snapshot of the state.
*
* \param[in] checkpoint_state The checkpoint state to save.
* \param[in] checkpoint_path Path to the checkpoint file.
* \param[in] include_optimizer_state Flag to indicate whether to save the optimizer state or not.
*
* \snippet{doc} snippets.dox OrtStatus Return Value
*
*/
ORT_API2_STATUS(SaveCheckpoint, _In_ OrtCheckpointState* checkpoint_state, _In_ const ORTCHAR_T* checkpoint_path,
const bool include_optimizer_state);
/// @}
/// \name Implementing The Training Loop
/// @{
/** \brief Create a training session that can be used to begin or resume training.
*
* This function creates a training session based on the env and session options provided that can
* begin or resume training from a given checkpoint state for the given onnx models.
* The checkpoint state represents the parameters of the training session which will be moved
* to the device specified by the user through the session options (if necessary).
* The training session requires four training artifacts
* - The training onnx model
* - The evaluation onnx model (optional)
* - The optimizer onnx model
* - The checkpoint file
*
* These artifacts can be generated using the `onnxruntime-training` python [utility](https://github.com/microsoft/onnxruntime/blob/main/orttraining/orttraining/python/training/onnxblock/README.md).
*
* \param[in] env Environment to be used for the training session.
* \param[in] options Session options that the user can customize for this training session.
* \param[in] checkpoint_state Training states that the training session uses as a starting point for training.
* \param[in] train_model_path Model to be used to perform training.
* \param[in] eval_model_path Model to be used to perform evaluation.
* \param[in] optimizer_model_path Model to be used to perform gradient descent.
* \param[out] out Created training session.
*
* \snippet{doc} snippets.dox OrtStatus Return Value
*
*/
ORT_API2_STATUS(CreateTrainingSession, _In_ const OrtEnv* env, _In_ const OrtSessionOptions* options,
_Inout_ OrtCheckpointState* checkpoint_state, _In_ const ORTCHAR_T* train_model_path,
_In_ const ORTCHAR_T* eval_model_path, _In_ const ORTCHAR_T* optimizer_model_path,
_Outptr_result_maybenull_ OrtTrainingSession** out);
/** \brief Create a training session that can be used to begin or resume training.
* This API provides a way to load all the training artifacts from buffers instead of files.
*
* \param[in] env Environment to be used for the training session.
* \param[in] options Session options that the user can customize for this training session.
* \param[in] checkpoint_state Training states that the training session uses as a starting point for training.
* \param[in] train_model_data Buffer containing the model data to be used to perform training
* \param[in] train_data_length Length of the buffer containing train_model_data
* \param[in] eval_model_data Buffer containing the model data to be used to perform evaluation
* \param[in] eval_data_length Length of the buffer containing eval_model_data
* \param[in] optim_model_data Buffer containing the model data to be used to perform weight update
* \param[in] optim_data_length Length of the buffer containing optim_model_data
* \param[out] out Created training session.
*
*/
ORT_API2_STATUS(CreateTrainingSessionFromBuffer, _In_ const OrtEnv* env,
_In_ const OrtSessionOptions* options, _Inout_ OrtCheckpointState* checkpoint_state,
_In_ const void* train_model_data, size_t train_data_length,
_In_ const void* eval_model_data, size_t eval_data_length,
_In_ const void* optim_model_data, size_t optim_data_length,
_Outptr_result_maybenull_ OrtTrainingSession** out);
/// @}
/// \name Model IO Information
/// @{
/** \brief Retrieves the number of user outputs in the training model.
*
* This function returns the number of outputs of the training model so that the user can
* allocate space for the number of outputs when OrtTrainingApi::TrainStep is invoked.
*
* \param[in] sess The `this` pointer to the training session.
* \param[out] out Number of user outputs in the training model.
*
* \snippet{doc} snippets.dox OrtStatus Return Value
*
*/
ORT_API2_STATUS(TrainingSessionGetTrainingModelOutputCount, _In_ const OrtTrainingSession* sess, _Out_ size_t* out);
/** \brief Retrieves the number of user outputs in the eval model.
*
* This function returns the number of outputs of the eval model so that the user can
* allocate space for the number of outputs when OrtTrainingApi::EvalStep is invoked.
*
* \param[in] sess The `this` pointer to the training session.
* \param[out] out Number of user outputs in the eval model.
*
* \snippet{doc} snippets.dox OrtStatus Return Value
*
*/
ORT_API2_STATUS(TrainingSessionGetEvalModelOutputCount, _In_ const OrtTrainingSession* sess, _Out_ size_t* out);
/** \brief Retrieves the names of user outputs in the training model.
*
* This function returns the names of outputs of the training model that can be associated with the OrtValue(s)
* returned by the OrtTrainingApi::TrainStep function.
*
* \param[in] sess The `this` pointer to the training session.
* \param[in] index Index of the output name requested.
* \param[in] allocator Allocator to use to allocate the memory for the name.
* \param[out] output Name of the training model output at the given index.
*
* \snippet{doc} snippets.dox OrtStatus Return Value
*
*/
ORT_API2_STATUS(TrainingSessionGetTrainingModelOutputName, _In_ const OrtTrainingSession* sess, size_t index, _Inout_ OrtAllocator* allocator, _Outptr_ char** output);
/** \brief Retrieves the names of user outputs in the eval model.
*
* This function returns the names of outputs of the eval model that can be associated with the OrtValue(s) returned
* by the OrtTrainingApi::EvalStep function.
*
* \param[in] sess The `this` pointer to the training session.
* \param[in] index Index of the output name requested.
* \param[in] allocator Allocator to use to allocate the memory for the name.
* \param[out] output Name of the eval model output at the given index.
*
* \snippet{doc} snippets.dox OrtStatus Return Value
*
*/
ORT_API2_STATUS(TrainingSessionGetEvalModelOutputName, _In_ const OrtTrainingSession* sess, size_t index, _Inout_ OrtAllocator* allocator, _Outptr_ char** output);
/// @}
/// \name Implementing The Training Loop
/// @{
/** \brief Reset the gradients of all trainable parameters to zero lazily.
*
* This function sets the internal state of the training session such that the gradients of the trainable
* parameters in the OrtCheckpointState will be scheduled to be reset just before the new gradients are
* computed on the next invocation of OrtTrainingApi::TrainStep.
*
* \param[in] session The `this` pointer to the training session.
*
* \snippet{doc} snippets.dox OrtStatus Return Value
*
*/
ORT_API2_STATUS(LazyResetGrad, _Inout_ OrtTrainingSession* session);
/** \brief Computes the outputs of the training model and the gradients of the trainable parameters for the given inputs
*
* This function performs a training step that computes the outputs of the training model and the gradients
* of the trainable parameters for the given inputs. The train step is performed based on the training model
* that was provided to the training session.
* OrtTrainingApi::TrainStep is the equivalent of running forward propagation and backward propagation in a single
* step.
* The gradients computed are stored inside the training session state so they can be later consumed
* by the OrtTrainingApi::OptimizerStep function.
* The gradients can be lazily reset by invoking the OrtTrainingApi::LazyResetGrad function.
*
* \param[in] sess The `this` pointer to the training session.
* \param[in] run_options Run options for this training step.
* \param[in] inputs_len Number of user inputs to the training model.
* \param[in] inputs The user inputs to the training model.
* \param[in] outputs_len Number of user outputs expected from this training step.
* \param[out] outputs User outputs computed by train step.
*
* \snippet{doc} snippets.dox OrtStatus Return Value
*
*/
ORT_API2_STATUS(TrainStep, _Inout_ OrtTrainingSession* sess, _In_opt_ const OrtRunOptions* run_options,
_In_ size_t inputs_len, _In_reads_(inputs_len) const OrtValue* const* inputs,
_In_ size_t outputs_len, _Inout_updates_all_(outputs_len) OrtValue** outputs);
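/* Example (illustrative sketch, not part of the upstream docs; `session`, `input_tensor` and
 * `label_tensor` are placeholders created elsewhere, and error handling is abbreviated):
 *
 *   const OrtApi* api = OrtGetApiBase()->GetApi(ORT_API_VERSION);
 *   const OrtTrainingApi* training_api = api->GetTrainingApi(ORT_API_VERSION);
 *   const OrtValue* inputs[2] = {input_tensor, label_tensor};
 *   OrtValue* outputs[1] = {NULL};  // filled in by TrainStep
 *   OrtStatus* status = training_api->TrainStep(session, NULL, 2, inputs, 1, outputs);
 *   if (status) { api->ReleaseStatus(status); }
 */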
/** \brief Computes the outputs for the eval model for the given inputs
*
* This function performs an eval step that computes the outputs of the eval model for the given inputs.
* The eval step is performed based on the eval model that was provided to the training session.
*
* \param[in] sess The `this` pointer to the training session.
* \param[in] run_options Run options for this eval step.
* \param[in] inputs_len Number of user inputs to the eval model.
* \param[in] inputs The user inputs to the eval model.
* \param[in] outputs_len Number of user outputs expected from this eval step.
* \param[out] outputs User outputs computed by eval step.
*
* \snippet{doc} snippets.dox OrtStatus Return Value
*
*/
ORT_API2_STATUS(EvalStep, _In_ const OrtTrainingSession* sess, _In_opt_ const OrtRunOptions* run_options,
_In_ size_t inputs_len, _In_reads_(inputs_len) const OrtValue* const* inputs,
_In_ size_t outputs_len, _Inout_updates_all_(outputs_len) OrtValue** outputs);
/** \brief Sets the learning rate for this training session.
*
* This function allows users to set the learning rate for the training session. The current
* learning rate is maintained by the training session and can be overwritten by invoking
* this function with the desired learning rate. This function should not be used when a valid
* learning rate scheduler is registered. It should be used either to set the learning rate
* derived from a custom learning rate scheduler or to set a constant learning rate to be used
* throughout the training session.
* \note Please note that this function does not set the initial learning rate that may be needed
* by the predefined learning rate schedulers. To set the initial learning rate for learning
* rate schedulers, please look at the function OrtTrainingApi::RegisterLinearLRScheduler.
*
* \param[in] sess The `this` pointer to the training session.
* \param[in] learning_rate Desired learning rate to be set.
*
* \snippet{doc} snippets.dox OrtStatus Return Value
*
*/
ORT_API2_STATUS(SetLearningRate, _Inout_ OrtTrainingSession* sess, _In_ float learning_rate);
/** \brief Gets the current learning rate for this training session.
*
* This function allows users to get the learning rate for the training session. The current
* learning rate is maintained by the training session, and users can query it for the purpose
* of implementing their own learning rate schedulers.
*
* \param[in] sess The `this` pointer to the training session.
* \param[out] learning_rate Learning rate currently in use by the training session.
*
* \snippet{doc} snippets.dox OrtStatus Return Value
*
*/
ORT_API2_STATUS(GetLearningRate, _Inout_ OrtTrainingSession* sess, _Out_ float* learning_rate);
/** \brief Performs the weight updates for the trainable parameters using the optimizer model.
*
* This function performs the weight update step that updates the trainable parameters such that they
* take a step in the direction of their gradients (gradient descent). The optimizer step is performed
* based on the optimizer model that was provided to the training session.
* The updated parameters are stored inside the training state so that they can be used by the next
* OrtTrainingApi::TrainStep function call.
*
* \param[in] sess The `this` pointer to the training session.
* \param[in] run_options Run options for this optimizer step.
*
* \snippet{doc} snippets.dox OrtStatus Return Value
*
*/
ORT_API2_STATUS(OptimizerStep, _Inout_ OrtTrainingSession* sess,
_In_opt_ const OrtRunOptions* run_options);
/** \brief Registers a linear learning rate scheduler for the training session.
*
* Register a linear learning rate scheduler that decays the learning rate by linearly updated
* multiplicative factor from the initial learning rate set on the training session to 0. The decay
* is performed after the initial warm up phase where the learning rate is linearly incremented
* from 0 to the initial learning rate provided.
*
* \param[in] sess The `this` pointer to the training session.
* \param[in] warmup_step_count Warmup steps for LR warmup.
* \param[in] total_step_count Total step count.
* \param[in] initial_lr The initial learning rate to be used by the training session.
*
* \snippet{doc} snippets.dox OrtStatus Return Value
*
*/
ORT_API2_STATUS(RegisterLinearLRScheduler, _Inout_ OrtTrainingSession* sess, _In_ const int64_t warmup_step_count,
_In_ const int64_t total_step_count, _In_ const float initial_lr);
/** \brief Update the learning rate based on the registered learning rate scheduler.
*
* Takes a scheduler step that updates the learning rate that is being used by the training session.
* This function should typically be called before invoking the optimizer step for each round,
* or as determined necessary to update the learning rate being used by the training session.
* \note Please note that a valid predefined learning rate scheduler must be first registered to invoke this
* function.
*
* \param[in] sess The `this` pointer to the training session.
*
* \snippet{doc} snippets.dox OrtStatus Return Value
*
*/
ORT_API2_STATUS(SchedulerStep, _Inout_ OrtTrainingSession* sess);
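/* Example (illustrative; `training_api` and `session` are placeholders, statuses ignored for
 * brevity): register the scheduler once after session creation, then advance it each round.
 *
 *   training_api->RegisterLinearLRScheduler(session, 100, 1000, 1e-3f);  // warmup, total, initial lr
 *   // per training round, after the optimizer step:
 *   training_api->OptimizerStep(session, NULL);
 *   training_api->SchedulerStep(session);
 */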
/// @}
/// \name Accessing The Training Session State
/// @{
/** \brief Retrieves the size of all the parameters.
*
* Calculates the total number of primitive (datatype of the parameters) elements of all the parameters in the
* training state.
* When trainable_only argument is true, the size is calculated for trainable params only.
*
* \param[in] sess The `this` pointer to the training session.
* \param[out] out Size of all parameter elements.
* \param[in] trainable_only Whether to skip non-trainable parameters
*
* \snippet{doc} snippets.dox OrtStatus Return Value
*
*/
ORT_API2_STATUS(GetParametersSize, _Inout_ OrtTrainingSession* sess, _Out_ size_t* out, bool trainable_only);
/** \brief Copy all parameters to a contiguous buffer held by the argument parameters_buffer
*
* The parameters_buffer has to be of the size given by GetParametersSize api call,
* with matching setting for the argument trainable_only. All the target parameters must be of the same
* datatype. The OrtValue must be pre-allocated onto
* the desired device. This is a complementary function to OrtTrainingApi::CopyBufferToParameters.
* Parameter ordering is preserved.
* User is responsible for allocating and freeing the resources used by the parameters_buffer.
*
* \param[in] sess The `this` pointer to the training session.
* \param[in] trainable_only Whether to skip non-trainable parameters
* \param[out] parameters_buffer The pre-allocated OrtValue buffer to copy onto.
*
* \snippet{doc} snippets.dox OrtStatus Return Value
*
*/
ORT_API2_STATUS(CopyParametersToBuffer, _Inout_ OrtTrainingSession* sess,
_Inout_ OrtValue* parameters_buffer, bool trainable_only);
/** \brief Copy parameter values from the given contiguous buffer held by parameters_buffer to the training state
*
* The parameters_buffer argument has to be of the size given by OrtTrainingApi::GetParametersSize api call,
* with matching setting for trainable_only argument. All the target parameters must be of the same
* datatype. This is a complementary function to OrtTrainingApi::CopyParametersToBuffer
* and can be used to load updated buffer values onto the training state.
* Parameter ordering is preserved.
* User is responsible for allocating and freeing the resources used by the parameters_buffer.
* In case the training session was created with a nominal checkpoint, invoking this function is required
* to load the updated parameters onto the checkpoint to complete it.
*
* \param[in] sess The `this` pointer to the training session.
* \param[in] trainable_only Whether to skip non-trainable parameters
* \param[in] parameters_buffer The pre-allocated OrtValue buffer to copy from.
*
* \snippet{doc} snippets.dox OrtStatus Return Value
*
*/
ORT_API2_STATUS(CopyBufferToParameters, _Inout_ OrtTrainingSession* sess,
_Inout_ OrtValue* parameters_buffer, bool trainable_only);
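/* Example round trip (illustrative; `cpu_allocator` is a placeholder OrtAllocator* and
 * statuses are ignored for brevity): size the buffer, allocate a 1-D float tensor, then
 * copy the parameters out and back in.
 *
 *   size_t num_elements = 0;
 *   training_api->GetParametersSize(session, &num_elements, true);
 *   int64_t shape[1] = {(int64_t)num_elements};
 *   OrtValue* buffer = NULL;
 *   api->CreateTensorAsOrtValue(cpu_allocator, shape, 1,
 *                               ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT, &buffer);
 *   training_api->CopyParametersToBuffer(session, buffer, true);
 *   training_api->CopyBufferToParameters(session, buffer, true);
 *   api->ReleaseValue(buffer);
 */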
/// @}
/// \name Release Training Resources
/// @{
/** \brief Frees up the memory used up by the training session.
*
* This function frees up any memory that was allocated in the training session. The training
* session can no longer be used after this call.
*
*/
ORT_CLASS_RELEASE(TrainingSession);
/** \brief Frees up the memory used up by the checkpoint state.
*
* This function frees up any memory that was allocated in the checkpoint state. The checkpoint
* state can no longer be used after this call.
* \note Note that the checkpoint state must be released only after the training session has been released.
*
*/
ORT_CLASS_RELEASE(CheckpointState);
/// @}
/// \name Prepare For Inferencing
/// @{
/** \brief Export a model that can be used for inferencing.
*
* If the training session was provided with an eval model, the training session can generate
* an inference model if it knows the inference graph outputs. The input inference graph outputs
* are used to prune the eval model so that the inference model's outputs align with the provided outputs.
* The exported model is saved at the path provided and can be used for inferencing with InferenceSession.
* \note Note that the function re-loads the eval model from the path provided to OrtTrainingApi::CreateTrainingSession
* and expects that this path still be valid.
*
* \param[in] sess The `this` pointer to the training session.
* \param[in] inference_model_path Path where the inference model should be serialized to.
* \param[in] graph_outputs_len Size of the graph output names array.
* \param[in] graph_output_names Names of the outputs that are needed in the inference model.
*
* \snippet{doc} snippets.dox OrtStatus Return Value
*
*/
ORT_API2_STATUS(ExportModelForInferencing, _Inout_ OrtTrainingSession* sess,
_In_ const ORTCHAR_T* inference_model_path, size_t graph_outputs_len,
_In_reads_(graph_outputs_len) const char* const* graph_output_names);
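/* Example (illustrative; the output name and file path depend on your eval model):
 *
 *   const char* graph_outputs[] = {"output"};
 *   training_api->ExportModelForInferencing(session, ORT_TSTR("inference.onnx"), 1, graph_outputs);
 */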
/// @}
/// \name Training Utilities
/// @{
/** \brief Sets the seed used for random number generation in ONNX Runtime.
*
* Use this function to generate reproducible results. It should be noted that completely reproducible
* results are not guaranteed.
*
* \param[in] seed The seed to be set.
*
* \snippet{doc} snippets.dox OrtStatus Return Value
*
*/
ORT_API2_STATUS(SetSeed, _In_ const int64_t seed);
/// @}
/// \name Model IO Information
/// @{
/** \brief Retrieves the number of user inputs in the training model.
*
* This function returns the number of inputs of the training model so that the user can accordingly
* allocate the OrtValue(s) provided to the OrtTrainingApi::TrainStep function.
*
* \param[in] sess The `this` pointer to the training session.
* \param[out] out Number of user inputs in the training model.
*
* \snippet{doc} snippets.dox OrtStatus Return Value
*
*/
ORT_API2_STATUS(TrainingSessionGetTrainingModelInputCount, _In_ const OrtTrainingSession* sess, _Out_ size_t* out);
/** \brief Retrieves the number of user inputs in the eval model.
*
* This function returns the number of inputs of the eval model so that the user can accordingly
* allocate the OrtValue(s) provided to the OrtTrainingApi::EvalStep function.
*
* \param[in] sess The `this` pointer to the training session.
* \param[out] out Number of user inputs in the eval model.
*
* \snippet{doc} snippets.dox OrtStatus Return Value
*
*/
ORT_API2_STATUS(TrainingSessionGetEvalModelInputCount, _In_ const OrtTrainingSession* sess, _Out_ size_t* out);
/** \brief Retrieves the name of the user input at given index in the training model.
*
* This function returns the names of inputs of the training model that can be associated with the
* OrtValue(s) provided to the OrtTrainingApi::TrainStep function.
*
* \param[in] sess The `this` pointer to the training session.
* \param[in] index The index of the training model input name requested.
* \param[in] allocator The allocator to use to allocate the memory for the requested name.
* \param[out] output Name of the user input for the training model at the given index.
*
* \snippet{doc} snippets.dox OrtStatus Return Value
*
*/
ORT_API2_STATUS(TrainingSessionGetTrainingModelInputName, _In_ const OrtTrainingSession* sess, size_t index,
_In_ OrtAllocator* allocator, _Outptr_ char** output);
/** \brief Retrieves the name of the user input at given index in the eval model.
*
* This function returns the names of inputs of the eval model that can be associated with the OrtValue(s) provided
* to the OrtTrainingApi::EvalStep function.
*
* \param[in] sess The `this` pointer to the training session.
* \param[in] index The index of the eval model input name requested.
* \param[in] allocator The allocator to use to allocate the memory for the requested name.
* \param[out] output Name of the user input for the eval model at the given index.
*
* \snippet{doc} snippets.dox OrtStatus Return Value
*
*/
ORT_API2_STATUS(TrainingSessionGetEvalModelInputName, _In_ const OrtTrainingSession* sess, size_t index,
_In_ OrtAllocator* allocator, _Outptr_ char** output);
/// @}
/// \name Accessing The Training Session State
/// @{
/** \brief Adds or updates the given property to/in the checkpoint state.
*
* Runtime properties such as epoch, training step, best score, and others can be added to the checkpoint
* state by the user by calling this function with the corresponding property name and value.
* The given property name must be unique to be able to successfully add the property.
*
* \param[in] checkpoint_state The checkpoint state which should hold the property.
* \param[in] property_name Name of the property being added or updated.
* \param[in] property_type Type of the property associated with the given name.
* \param[in] property_value Property value associated with the given name.
*
* \snippet{doc} snippets.dox OrtStatus Return Value
*
*/
ORT_API2_STATUS(AddProperty, _Inout_ OrtCheckpointState* checkpoint_state,
_In_ const char* property_name, _In_ enum OrtPropertyType property_type,
_In_ void* property_value);
/** \brief Gets the property value associated with the given name from the checkpoint state.
*
* Gets the property value from an existing entry in the checkpoint state. The property must
* exist in the checkpoint state to be able to retrieve it successfully.
*
* \param[in] checkpoint_state The checkpoint state that is currently holding the property.
* \param[in] property_name Name of the property being retrieved.
* \param[in] allocator Allocator used to allocate the memory for the property_value.
* \param[out] property_type Type of the property associated with the given name.
* \param[out] property_value Property value associated with the given name.
*
* \snippet{doc} snippets.dox OrtStatus Return Value
*
*/
ORT_API2_STATUS(GetProperty, _In_ const OrtCheckpointState* checkpoint_state,
_In_ const char* property_name, _Inout_ OrtAllocator* allocator,
_Out_ enum OrtPropertyType* property_type, _Outptr_ void** property_value);
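/* Example (illustrative; `cpu_allocator` is a placeholder and statuses are ignored): store
 * and read back an int64 property on the checkpoint state.
 *
 *   int64_t epoch = 5;
 *   training_api->AddProperty(checkpoint_state, "epoch", OrtIntProperty, &epoch);
 *
 *   enum OrtPropertyType type;
 *   void* value = NULL;
 *   training_api->GetProperty(checkpoint_state, "epoch", cpu_allocator, &type, &value);
 *   // for OrtIntProperty, value points to an allocator-owned int64_t
 *   cpu_allocator->Free(cpu_allocator, value);
 */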
/// @}
/// \name Accessing The Training Session State
/// @{
/** \brief Load a checkpoint state from a buffer into checkpoint_state.
*
* This function will parse a checkpoint bytes buffer, pull relevant data and load the training
* state into the checkpoint_state. This checkpoint state can then be used to create the
* training session by invoking OrtTrainingApi::CreateTrainingSession. By doing so, the training
* session will resume training from the given checkpoint state.
* \note Note that the training session created with a checkpoint state uses this state to store the entire
* training state (including model parameters, its gradients, the optimizer states and the properties).
* As a result, it is required that the checkpoint state outlive the lifetime of the training session.
*
* \param[in] checkpoint_buffer Buffer containing the checkpoint bytes.
* \param[in] num_bytes Number of bytes in the checkpoint buffer.
* \param[out] checkpoint_state Checkpoint state that contains the states of the training session.
*
* \snippet{doc} snippets.dox OrtStatus Return Value
*
*/
ORT_API2_STATUS(LoadCheckpointFromBuffer, _In_ const void* checkpoint_buffer,
_In_ const size_t num_bytes, _Outptr_ OrtCheckpointState** checkpoint_state);
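/* Example (illustrative; `bytes`/`num_bytes` stand for a checkpoint blob read elsewhere):
 *
 *   OrtCheckpointState* state = NULL;
 *   training_api->LoadCheckpointFromBuffer(bytes, num_bytes, &state);
 *   // ... create and use the training session with `state` ...
 *   training_api->ReleaseCheckpointState(state);  // only after the session is released
 */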
/** \brief Retrieves the type and shape information of the parameter associated with the given parameter name.
*
* This function retrieves the type and shape of the parameter associated with the given parameter name.
* The parameter must exist in the checkpoint state to be able to retrieve its type and shape information successfully.
*
* \param[in] checkpoint_state The checkpoint state.
* \param[in] parameter_name Name of the parameter being retrieved.
* \param[out] parameter_type_and_shape The type and shape of the parameter being retrieved.
*
* \snippet{doc} snippets.dox OrtStatus Return Value
*
*/
ORT_API2_STATUS(GetParameterTypeAndShape, _In_ const OrtCheckpointState* checkpoint_state,
_In_ const char* parameter_name, _Outptr_ OrtTensorTypeAndShapeInfo** parameter_type_and_shape);
/** \brief Updates the data associated with the model parameter in the checkpoint state for the given parameter name.
*
* This function updates a model parameter in the checkpoint state with the given parameter data.
* The training session must be already created with the checkpoint state that contains the parameter
* being updated. The given parameter is copied over to the registered device for the training session.
* The parameter must exist in the checkpoint state to be able to update it successfully.
*
* \param[in] checkpoint_state The checkpoint state.
* \param[in] parameter_name Name of the parameter being updated.
* \param[in] parameter The parameter data that should replace the existing parameter data.
*
* \snippet{doc} snippets.dox OrtStatus Return Value
*
*/
ORT_API2_STATUS(UpdateParameter, _Inout_ OrtCheckpointState* checkpoint_state,
_In_ const char* parameter_name, _In_ OrtValue* parameter);
/** \brief Gets the data associated with the model parameter from the checkpoint state for the given parameter name.
*
* This function retrieves the model parameter data from the checkpoint state for the given parameter name.
* The parameter is copied over and returned as an OrtValue. The training session must be already created
* with the checkpoint state that contains the parameter being retrieved.
* The parameter must exist in the checkpoint state to be able to retrieve it successfully.
*
* \param[in] checkpoint_state The checkpoint state.
* \param[in] parameter_name Name of the parameter being retrieved.
* \param[in] allocator Allocator used to allocate the memory for the parameter.
* \param[out] parameter The parameter data that is retrieved from the checkpoint state.
*
* \snippet{doc} snippets.dox OrtStatus Return Value
*
*/
ORT_API2_STATUS(GetParameter, _In_ const OrtCheckpointState* checkpoint_state,
_In_ const char* parameter_name, _Inout_ OrtAllocator* allocator,
_Outptr_ OrtValue** parameter);
/// @}
};
typedef struct OrtTrainingApi OrtTrainingApi;
/// @}

418
libs/onnxruntime/include/onnxruntime_training_cxx_api.h

@ -0,0 +1,418 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#pragma once
#include "onnxruntime_training_c_api.h"
#include <optional>
#include <variant>
namespace Ort::detail {
#define ORT_DECLARE_TRAINING_RELEASE(NAME) \
void OrtRelease(Ort##NAME* ptr);
// These release methods must be forward declared before including onnxruntime_cxx_api.h
// otherwise class Base won't be aware of them
ORT_DECLARE_TRAINING_RELEASE(CheckpointState);
ORT_DECLARE_TRAINING_RELEASE(TrainingSession);
} // namespace Ort::detail
#include "onnxruntime_cxx_api.h"
namespace Ort {
/// <summary>
/// This function returns the C training api struct with the pointers to the ort training C functions.
/// If using C++, please use the class instances instead of invoking the C functions directly.
/// </summary>
/// <returns>OrtTrainingApi struct with ort training C function pointers.</returns>
inline const OrtTrainingApi& GetTrainingApi() { return *GetApi().GetTrainingApi(ORT_API_VERSION); }
namespace detail {
#define ORT_DEFINE_TRAINING_RELEASE(NAME) \
inline void OrtRelease(Ort##NAME* ptr) { GetTrainingApi().Release##NAME(ptr); }
ORT_DEFINE_TRAINING_RELEASE(CheckpointState);
ORT_DEFINE_TRAINING_RELEASE(TrainingSession);
#undef ORT_DECLARE_TRAINING_RELEASE
#undef ORT_DEFINE_TRAINING_RELEASE
} // namespace detail
using Property = std::variant<int64_t, float, std::string>;
/**
* \defgroup TrainingCpp Ort Training C++ API
* @{
*/
/** \brief Holds the state of the training session.
*
* This class holds the entire training session state that includes model parameters, their gradients,
* optimizer parameters, and user properties. The Ort::TrainingSession leverages the Ort::CheckpointState
* by accessing and updating the contained training state.
* \note Note that the training session created with a checkpoint state uses this state to store the entire
* training state (including model parameters, its gradients, the optimizer states and the properties).
* The Ort::TrainingSession does not hold a copy of the Ort::CheckpointState and as a result, it is required
* that the checkpoint state outlive the lifetime of the training session.
* \note Note that the checkpoint state can be either the complete checkpoint state or the nominal checkpoint
* state depending on the version provided while loading the checkpoint.
*
*/
class CheckpointState : public detail::Base<OrtCheckpointState> {
private:
CheckpointState(OrtCheckpointState* checkpoint_state) { p_ = checkpoint_state; }
public:
// Construct the checkpoint state by loading the checkpoint by calling LoadCheckpoint
CheckpointState() = delete;
/// \name Accessing The Training Session State
/// @{
/** \brief Load a checkpoint state from a file on disk into checkpoint_state.
*
* This function will parse a checkpoint file, pull relevant data and load the training
* state and return an instance of Ort::CheckpointState. This checkpoint state can then be used to create the
* training session by instantiating Ort::TrainingSession. By doing so, the training session will resume
* training from the given checkpoint state.
*
* \param[in] path_to_checkpoint Path to the checkpoint file
* \return Ort::CheckpointState object which holds the state of the training session parameters.
*
*/
static CheckpointState LoadCheckpoint(const std::basic_string<ORTCHAR_T>& path_to_checkpoint);
/** \brief Load a checkpoint state from a buffer.
*
* This function will parse a checkpoint buffer, pull relevant data and load the training
* state and return an instance of Ort::CheckpointState. This checkpoint state can then be used to create the
* training session by instantiating Ort::TrainingSession. By doing so, the training session will resume
* training from the given checkpoint state.
*
* \param[in] buffer Buffer containing the checkpoint data.
* \return Ort::CheckpointState object which holds the state of the training session parameters.
*
*/
static CheckpointState LoadCheckpointFromBuffer(const std::vector<uint8_t>& buffer);
/** \brief Save the given state to a checkpoint file on disk.
*
* This function serializes the provided checkpoint state to a file on disk.
* This checkpoint can later be loaded by invoking Ort::CheckpointState::LoadCheckpoint to resume
* training from this snapshot of the state.
*
* \param[in] checkpoint_state The checkpoint state to save.
* \param[in] path_to_checkpoint Path to the checkpoint file.
* \param[in] include_optimizer_state Flag to indicate whether to save the optimizer state or not.
*
*/
static void SaveCheckpoint(const CheckpointState& checkpoint_state,
const std::basic_string<ORTCHAR_T>& path_to_checkpoint,
const bool include_optimizer_state = false);
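/* Example (illustrative; the checkpoint paths are placeholders):
 *
 *   auto state = Ort::CheckpointState::LoadCheckpoint(ORT_TSTR("checkpoint"));
 *   // ... run training ...
 *   Ort::CheckpointState::SaveCheckpoint(state, ORT_TSTR("checkpoint_out"), true);
 */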
/** \brief Adds or updates the given property to/in the checkpoint state.
*
* Runtime properties such as epoch, training step, best score, and others can be added to the checkpoint
* state by the user by calling this function with the corresponding property name and value.
* The given property name must be unique to be able to successfully add the property.
*
* \param[in] property_name Name of the property being added or updated.
* \param[in] property_value Property value associated with the given name.
*
*/
void AddProperty(const std::string& property_name, const Property& property_value);
/** \brief Gets the property value associated with the given name from the checkpoint state.
*
* Gets the property value from an existing entry in the checkpoint state. The property must
* exist in the checkpoint state to be able to retrieve it successfully.
*
* \param[in] property_name Name of the property being retrieved.
* \return Property value associated with the given property name.
*
*/
Property GetProperty(const std::string& property_name);
/** \brief Updates the data associated with the model parameter in the checkpoint state for the given parameter name.
*
* This function updates a model parameter in the checkpoint state with the given parameter data.
* The training session must be already created with the checkpoint state that contains the parameter
* being updated. The given parameter is copied over to the registered device for the training session.
* The parameter must exist in the checkpoint state to be able to update it successfully.
*
* \param[in] parameter_name Name of the parameter being updated.
* \param[in] parameter The parameter data that should replace the existing parameter data.
*
*/
void UpdateParameter(const std::string& parameter_name, const Value& parameter);
/** \brief Gets the data associated with the model parameter from the checkpoint state for the given parameter name.
*
* This function retrieves the model parameter data from the checkpoint state for the given parameter name.
* The parameter is copied over to the provided OrtValue. The training session must be already created
* with the checkpoint state that contains the parameter being retrieved.
* The parameter must exist in the checkpoint state to be able to retrieve it successfully.
*
* \param[in] parameter_name Name of the parameter being retrieved.
* \return The parameter data that is retrieved from the checkpoint state.
*
*/
Value GetParameter(const std::string& parameter_name);
/// @}
};
/** \brief Trainer class that provides training, evaluation and optimizer methods for training ONNX models.
*
* The training session requires four training artifacts
* - The training onnx model
* - The evaluation onnx model (optional)
* - The optimizer onnx model
* - The checkpoint file
*
* These artifacts can be generated using the `onnxruntime-training` python [utility](https://github.com/microsoft/onnxruntime/blob/main/orttraining/orttraining/python/training/onnxblock/README.md).
*
*/
class TrainingSession : public detail::Base<OrtTrainingSession> {
private:
size_t training_model_output_count_, eval_model_output_count_;
public:
/// \name Constructing the Training Session
/// @{
/** \brief Create a training session that can be used to begin or resume training.
*
* This constructor instantiates the training session based on the env and session options provided that can
* begin or resume training from a given checkpoint state for the given onnx models.
* The checkpoint state represents the parameters of the training session which will be moved
* to the device specified by the user through the session options (if necessary).
*
* \param[in] env Env to be used for the training session.
* \param[in] session_options SessionOptions that the user can customize for this training session.
* \param[in] checkpoint_state Training states that the training session uses as a starting point for training.
* \param[in] train_model_path Model to be used to perform training.
* \param[in] eval_model_path Model to be used to perform evaluation.
* \param[in] optimizer_model_path Model to be used to perform gradient descent.
*
*/
TrainingSession(const Env& env, const SessionOptions& session_options, CheckpointState& checkpoint_state,
const std::basic_string<ORTCHAR_T>& train_model_path,
const std::optional<std::basic_string<ORTCHAR_T>>& eval_model_path = std::nullopt,
const std::optional<std::basic_string<ORTCHAR_T>>& optimizer_model_path = std::nullopt);
/** \brief Create a training session that can be used to begin or resume training.
* This constructor allows the users to load the models from buffers instead of files.
*
* \param[in] env Env to be used for the training session.
* \param[in] session_options SessionOptions that the user can customize for this training session.
* \param[in] checkpoint_state Training states that the training session uses as a starting point for training.
* \param[in] train_model_data Buffer containing training model data.
* \param[in] eval_model_data Buffer containing evaluation model data.
* \param[in] optim_model_data Buffer containing optimizer model (used for performing weight/parameter update).
*
*/
TrainingSession(const Env& env, const SessionOptions& session_options, CheckpointState& checkpoint_state,
const std::vector<uint8_t>& train_model_data, const std::vector<uint8_t>& eval_model_data = {},
const std::vector<uint8_t>& optim_model_data = {});
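/* Example (illustrative; the artifact paths are placeholders generated via onnxblock):
 *
 *   Ort::Env env;
 *   Ort::SessionOptions session_options;
 *   auto state = Ort::CheckpointState::LoadCheckpoint(ORT_TSTR("checkpoint"));
 *   Ort::TrainingSession session(env, session_options, state,
 *                                ORT_TSTR("training_model.onnx"),
 *                                ORT_TSTR("eval_model.onnx"),
 *                                ORT_TSTR("optimizer_model.onnx"));
 */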
/// @}
/// \name Implementing The Training Loop
/// @{
/** \brief Computes the outputs of the training model and the gradients of the trainable parameters for the given inputs
*
* This function performs a training step that computes the outputs of the training model and the gradients
* of the trainable parameters for the given inputs. The train step is performed based on the training model
* that was provided to the training session.
* The Ort::TrainingSession::TrainStep is the equivalent of running forward propagation and backward propagation
* in a single step.
* The gradients computed are stored inside the training session state so they can be later consumed
* by the Ort::TrainingSession::OptimizerStep function.
* The gradients can be lazily reset by invoking the Ort::TrainingSession::LazyResetGrad function.
*
* \param[in] input_values The user inputs to the training model.
* \return A std::vector of Ort::Value objects that represents the output of the forward pass of the training model.
*
*
*/
std::vector<Value> TrainStep(const std::vector<Value>& input_values);
/** \brief Reset the gradients of all trainable parameters to zero lazily.
*
* This function sets the internal state of the training session such that the gradients of the trainable
* parameters in the OrtCheckpointState will be scheduled to be reset just before the new gradients are
* computed on the next invocation of the next Ort::TrainingSession::TrainStep.
*
*/
void LazyResetGrad();
/** \brief Computes the outputs for the eval model for the given inputs
*
* This function performs an eval step that computes the outputs of the eval model for the given inputs.
* The eval step is performed based on the eval model that was provided to the training session.
*
* \param[in] input_values The user inputs to the eval model.
* \return A std::vector of Ort::Value objects that represents the output of the eval pass.
*
*/
std::vector<Value> EvalStep(const std::vector<Value>& input_values);
/** \brief Sets the learning rate for this training session.
*
* This function allows users to set the learning rate for the training session. The current
* learning rate is maintained by the training session and can be overwritten by invoking
* this function with the desired learning rate. This function should not be used when a valid
* learning rate scheduler is registered. It should be used either to set the learning rate
* derived from a custom learning rate scheduler or to set a constant learning rate to be used
* throughout the training session.
* \note Please note that this function does not set the initial learning rate that may be needed
* by the predefined learning rate schedulers. To set the initial learning rate for learning
* rate schedulers, please look at the function Ort::TrainingSession::RegisterLinearLRScheduler.
*
* \param[in] learning_rate Desired learning rate to be set.
*
*/
void SetLearningRate(float learning_rate);
/** \brief Gets the current learning rate for this training session.
*
* This function allows users to get the learning rate for the training session. The current
* learning rate is maintained by the training session, and users can query it for the purpose
* of implementing their own learning rate schedulers.
*
* \return float representing the current learning rate.
*
*/
float GetLearningRate() const;
/** \brief Registers a linear learning rate scheduler for the training session.
*
* Register a linear learning rate scheduler that decays the learning rate by linearly updated
* multiplicative factor from the initial learning rate set on the training session to 0. The decay
* is performed after the initial warm up phase where the learning rate is linearly incremented
* from 0 to the initial learning rate provided.
*
* \param[in] warmup_step_count Warmup steps for LR warmup.
* \param[in] total_step_count Total step count.
* \param[in] initial_lr The initial learning rate to be used by the training session.
*
*/
void RegisterLinearLRScheduler(int64_t warmup_step_count, int64_t total_step_count,
float initial_lr);
/** \brief Update the learning rate based on the registered learning rate scheduler.
*
* Takes a scheduler step that updates the learning rate that is being used by the training session.
* This function should typically be called before invoking the optimizer step for each round,
* or as determined necessary to update the learning rate being used by the training session.
* \note Please note that a valid predefined learning rate scheduler must be first registered to invoke this
* function.
*
*/
void SchedulerStep();
/** \brief Performs the weight updates for the trainable parameters using the optimizer model.
*
* This function performs the weight update step that updates the trainable parameters such that they
* take a step in the direction of their gradients (gradient descent). The optimizer step is performed
* based on the optimizer model that was provided to the training session.
* The updated parameters are stored inside the training state so that they can be used by the next
* Ort::TrainingSession::TrainStep function call.
*
*/
void OptimizerStep();
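/* Example training loop (illustrative; `make_batch()` is a placeholder returning the
 * std::vector<Ort::Value> inputs expected by the training model):
 *
 *   for (size_t step = 0; step < num_steps; ++step) {
 *     std::vector<Ort::Value> inputs = make_batch();
 *     std::vector<Ort::Value> outputs = session.TrainStep(inputs);  // outputs[0] is typically the loss
 *     session.OptimizerStep();
 *     session.LazyResetGrad();
 *     session.SchedulerStep();  // only if a scheduler was registered
 *   }
 */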
/// @}
/// \name Prepare For Inferencing
/// @{
/** \brief Export a model that can be used for inferencing.
*
* If the training session was provided with an eval model, the training session can generate
* an inference model if it knows the inference graph outputs. The input inference graph outputs
* are used to prune the eval model so that the inference model's outputs align with the provided outputs.
* The exported model is saved at the path provided and can be used for inferencing with Ort::Session.
* \note Note that the function re-loads the eval model from the path provided to Ort::TrainingSession
* and expects that this path still be valid.
*
* \param[in] inference_model_path Path where the inference model should be serialized to.
* \param[in] graph_output_names Names of the outputs that are needed in the inference model.
*
*/
void ExportModelForInferencing(const std::basic_string<ORTCHAR_T>& inference_model_path,
const std::vector<std::string>& graph_output_names);
/// @}
/// \name Model IO Information
/// @{
/** \brief Retrieves the names of the user inputs for the training and eval models.
*
* This function returns the names of inputs of the training or eval model that can be associated
* with the Ort::Value(s) provided to the Ort::TrainingSession::TrainStep or Ort::TrainingSession::EvalStep
* function.
*
* \param[in] training Whether the training model input names are requested or eval model input names.
* \return Graph input names for either the training model or the eval model.
*
*/
std::vector<std::string> InputNames(const bool training);
/** \brief Retrieves the names of the user outputs for the training and eval models.
*
* This function returns the names of outputs of the training or eval model that can be associated
* with the Ort::Value(s) returned by the Ort::TrainingSession::TrainStep or Ort::TrainingSession::EvalStep
* function.
*
* \param[in] training Whether the training model output names are requested or eval model output names.
* \return Graph output names for either the training model or the eval model.
*
*/
std::vector<std::string> OutputNames(const bool training);
/// @}
/// \name Accessing The Training Session State
/// @{
/** \brief Returns a contiguous buffer that holds a copy of all training state parameters
*
* \param[in] only_trainable Whether to only copy trainable parameters or to copy all parameters.
* \return Contiguous buffer to the model parameters.
*
*/
Value ToBuffer(const bool only_trainable);
/** \brief Loads the training session model parameters from a contiguous buffer
*
* In case the training session was created with a nominal checkpoint, invoking this function is required
* to load the updated parameters onto the checkpoint to complete it.
*
* \param[in] buffer Contiguous buffer to load the parameters from.
*/
void FromBuffer(Value& buffer);
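/* Example (illustrative): snapshot the trainable parameters into a contiguous buffer and
 * load them back, e.g. after averaging them across workers.
 *
 *   Ort::Value params = session.ToBuffer(true);  // only_trainable
 *   // ... read or modify the contiguous float buffer ...
 *   session.FromBuffer(params);
 */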
/// @}
};
/// \name Training Utilities
/// @{
/** \brief This function sets the seed for generating random numbers.
*
* Use this function to generate reproducible results. It should be noted that completely
* reproducible results are not guaranteed.
*
* \param[in] seed Manual seed to use for random number generation.
*/
void SetSeed(const int64_t seed);
/// @}
/// @}
} // namespace Ort
#include "onnxruntime_training_cxx_inline.h"

295
libs/onnxruntime/include/onnxruntime_training_cxx_inline.h

@ -0,0 +1,295 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#pragma once
#include "onnxruntime_training_c_api.h"
#include "onnxruntime_cxx_api.h"
namespace Ort {
inline TrainingSession::TrainingSession(const Env& env, const SessionOptions& session_options,
CheckpointState& checkpoint_state,
const std::basic_string<ORTCHAR_T>& train_model_path,
const std::optional<std::basic_string<ORTCHAR_T>>& eval_model_path,
const std::optional<std::basic_string<ORTCHAR_T>>& optimizer_model_path) {
ThrowOnError(GetTrainingApi().CreateTrainingSession(
env, session_options, checkpoint_state,
train_model_path.c_str(),
eval_model_path.has_value() ? eval_model_path.value().c_str() : nullptr,
optimizer_model_path.has_value() ? optimizer_model_path.value().c_str() : nullptr,
&p_));
ThrowOnError(GetTrainingApi().TrainingSessionGetTrainingModelOutputCount(p_, &training_model_output_count_));
ThrowOnError(GetTrainingApi().TrainingSessionGetEvalModelOutputCount(p_, &eval_model_output_count_));
}
inline TrainingSession::TrainingSession(const Env& env, const SessionOptions& session_options,
CheckpointState& checkpoint_state,
const std::vector<uint8_t>& train_model_data,
const std::vector<uint8_t>& eval_model_data,
const std::vector<uint8_t>& optim_model_data) {
ThrowOnError(GetTrainingApi().CreateTrainingSessionFromBuffer(
env, session_options, checkpoint_state,
train_model_data.data(), train_model_data.size(),
eval_model_data.data(), eval_model_data.size(),
optim_model_data.data(), optim_model_data.size(),
&p_));
ThrowOnError(GetTrainingApi().TrainingSessionGetTrainingModelOutputCount(p_, &training_model_output_count_));
ThrowOnError(GetTrainingApi().TrainingSessionGetEvalModelOutputCount(p_, &eval_model_output_count_));
}
inline std::vector<Value> TrainingSession::TrainStep(const std::vector<Value>& input_values) {
std::vector<Value> output_values;
output_values.reserve(training_model_output_count_);
for (size_t i = 0; i < training_model_output_count_; i++) output_values.emplace_back(nullptr);
auto ort_input_values = reinterpret_cast<const OrtValue* const*>(input_values.data());
auto ort_output_values = reinterpret_cast<OrtValue**>(output_values.data());
RunOptions run_options;
ThrowOnError(GetTrainingApi().TrainStep(
p_, run_options, input_values.size(), ort_input_values,
training_model_output_count_, ort_output_values));
return output_values;
}
inline void TrainingSession::LazyResetGrad() {
ThrowOnError(GetTrainingApi().LazyResetGrad(p_));
}
inline std::vector<Value> TrainingSession::EvalStep(const std::vector<Value>& input_values) {
std::vector<Value> output_values;
output_values.reserve(eval_model_output_count_);
for (size_t i = 0; i < eval_model_output_count_; i++) output_values.emplace_back(nullptr);
auto ort_input_values = reinterpret_cast<const OrtValue* const*>(input_values.data());
auto ort_output_values = reinterpret_cast<OrtValue**>(output_values.data());
RunOptions run_options;
ThrowOnError(GetTrainingApi().EvalStep(
p_, run_options, input_values.size(), ort_input_values,
eval_model_output_count_, ort_output_values));
return output_values;
}
inline void TrainingSession::SetLearningRate(float learning_rate) {
ThrowOnError(GetTrainingApi().SetLearningRate(p_, learning_rate));
}
inline float TrainingSession::GetLearningRate() const {
float learning_rate = 0;
ThrowOnError(GetTrainingApi().GetLearningRate(p_, &learning_rate));
return learning_rate;
}
inline void TrainingSession::RegisterLinearLRScheduler(int64_t warmup_step_count, int64_t total_step_count,
float initial_lr) {
ThrowOnError(GetTrainingApi().RegisterLinearLRScheduler(p_, warmup_step_count, total_step_count,
initial_lr));
}
inline void TrainingSession::SchedulerStep() {
ThrowOnError(GetTrainingApi().SchedulerStep(p_));
}
inline void TrainingSession::OptimizerStep() {
RunOptions run_options;
ThrowOnError(GetTrainingApi().OptimizerStep(p_, run_options));
}
inline std::vector<std::string> TrainingSession::InputNames(const bool training) {
auto& input_count_function = training ? GetTrainingApi().TrainingSessionGetTrainingModelInputCount
: GetTrainingApi().TrainingSessionGetEvalModelInputCount;
auto& input_name_function = training ? GetTrainingApi().TrainingSessionGetTrainingModelInputName
: GetTrainingApi().TrainingSessionGetEvalModelInputName;
size_t input_count = 0;
ThrowOnError(input_count_function(p_, &input_count));
std::vector<std::string> input_names(input_count);
AllocatorWithDefaultOptions allocator;
for (size_t index = 0; index < input_count; ++index) {
char* input_name;
ThrowOnError(input_name_function(p_, index, allocator, &input_name));
input_names[index] = std::string(input_name);
allocator.Free(input_name);
}
return input_names;
}
inline std::vector<std::string> TrainingSession::OutputNames(const bool training) {
auto& output_count_function = training ? GetTrainingApi().TrainingSessionGetTrainingModelOutputCount
: GetTrainingApi().TrainingSessionGetEvalModelOutputCount;
auto& output_name_function = training ? GetTrainingApi().TrainingSessionGetTrainingModelOutputName
: GetTrainingApi().TrainingSessionGetEvalModelOutputName;
size_t output_count = 0;
ThrowOnError(output_count_function(p_, &output_count));
std::vector<std::string> output_names(output_count);
AllocatorWithDefaultOptions allocator;
for (size_t index = 0; index < output_count; ++index) {
char* output_name;
ThrowOnError(output_name_function(p_, index, allocator, &output_name));
output_names[index] = std::string(output_name);
allocator.Free(output_name);
}
return output_names;
}
inline Value TrainingSession::ToBuffer(const bool only_trainable) {
size_t buffer_size = 0U;
ThrowOnError(GetTrainingApi().GetParametersSize(p_, &buffer_size, only_trainable));
std::array<int64_t, 1> buffer_shape{static_cast<int64_t>(buffer_size)};
AllocatorWithDefaultOptions allocator;
Value buffer = Value::CreateTensor(allocator, buffer_shape.data(), 1U,
ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT);
ThrowOnError(GetTrainingApi().CopyParametersToBuffer(p_, buffer, only_trainable));
return buffer;
}
inline void TrainingSession::FromBuffer(Value& buffer) {
if (!buffer.IsTensor()) {
ThrowStatus(Status("Incorrect buffer received. Expected a tensor buffer.", OrtErrorCode::ORT_INVALID_ARGUMENT));
}
auto tensor_info = buffer.GetTensorTypeAndShapeInfo();
auto buffer_shape = tensor_info.GetShape();
if (buffer_shape.size() != 1U) {
ThrowStatus(Status("Incorrect buffer received. Expected a contiguous tensor buffer.",
OrtErrorCode::ORT_INVALID_ARGUMENT));
}
auto buffer_size = buffer_shape.front();
size_t session_buffer_size = 0U;
ThrowOnError(GetTrainingApi().GetParametersSize(p_, &session_buffer_size, false));
if (buffer_size == static_cast<int64_t>(session_buffer_size)) {
ThrowOnError(GetTrainingApi().CopyBufferToParameters(p_, buffer, false));
return;
}
size_t session_buffer_size_trainable_only = 0U;
ThrowOnError(GetTrainingApi().GetParametersSize(p_, &session_buffer_size_trainable_only, true));
if (buffer_size == static_cast<int64_t>(session_buffer_size_trainable_only)) {
ThrowOnError(GetTrainingApi().CopyBufferToParameters(p_, buffer, true));
return;
} else {
ThrowStatus(Status("Incorrect buffer size received.", OrtErrorCode::ORT_INVALID_ARGUMENT));
}
}
inline CheckpointState CheckpointState::LoadCheckpoint(const std::basic_string<ORTCHAR_T>& path_to_checkpoint) {
OrtCheckpointState* checkpoint_state;
ThrowOnError(GetTrainingApi().LoadCheckpoint(path_to_checkpoint.c_str(), &checkpoint_state));
return CheckpointState(checkpoint_state);
}
inline CheckpointState CheckpointState::LoadCheckpointFromBuffer(const std::vector<uint8_t>& buffer) {
OrtCheckpointState* checkpoint_state;
ThrowOnError(GetTrainingApi().LoadCheckpointFromBuffer(buffer.data(), buffer.size(), &checkpoint_state));
return CheckpointState(checkpoint_state);
}
inline void CheckpointState::SaveCheckpoint(const CheckpointState& checkpoint_states,
const std::basic_string<ORTCHAR_T>& path_to_checkpoint,
const bool include_optimizer_state) {
ThrowOnError(GetTrainingApi().SaveCheckpoint(checkpoint_states, path_to_checkpoint.c_str(),
include_optimizer_state));
}
inline void TrainingSession::ExportModelForInferencing(const std::basic_string<ORTCHAR_T>& inference_model_path,
const std::vector<std::string>& graph_output_names) {
std::vector<const char*> output_names;
output_names.reserve(graph_output_names.size());
for (const auto& output_name : graph_output_names) {
output_names.push_back(output_name.c_str());
}
ThrowOnError(GetTrainingApi().ExportModelForInferencing(
p_, inference_model_path.c_str(), graph_output_names.size(), output_names.data()));
}
inline void SetSeed(const int64_t seed) {
ThrowOnError(GetTrainingApi().SetSeed(seed));
}
inline void CheckpointState::AddProperty(const std::string& property_name, const Property& property_value) {
if (std::holds_alternative<int64_t>(property_value)) {
int64_t value = std::get<int64_t>(property_value);
void* value_p = &value;
ThrowOnError(GetTrainingApi().AddProperty(p_, property_name.c_str(), OrtPropertyType::OrtIntProperty, value_p));
} else if (std::holds_alternative<float>(property_value)) {
float value = std::get<float>(property_value);
void* value_p = &value;
ThrowOnError(GetTrainingApi().AddProperty(p_, property_name.c_str(), OrtPropertyType::OrtFloatProperty, value_p));
} else if (std::holds_alternative<std::string>(property_value)) {
std::string value = std::get<std::string>(property_value);
auto buffer = std::make_unique<char[]>(value.length() + 1);
memcpy(buffer.get(), value.c_str(), value.length());
// AddProperty takes a char* and calls PropertyBag::AddProperty which takes a std::string. The data will be
// copied at that point so buffer can free the local allocation once the call is made.
ThrowOnError(GetTrainingApi().AddProperty(p_, property_name.c_str(), OrtPropertyType::OrtStringProperty,
buffer.get()));
} else {
ThrowStatus(Status("Unknown property type received.", OrtErrorCode::ORT_INVALID_ARGUMENT));
}
}
inline Property CheckpointState::GetProperty(const std::string& property_name) {
void* property_value = nullptr;
OrtPropertyType property_type;
AllocatorWithDefaultOptions allocator;
ThrowOnError(GetTrainingApi().GetProperty(p_, property_name.c_str(), allocator, &property_type, &property_value));
Property property;
switch (property_type) {
case OrtPropertyType::OrtIntProperty: {
auto value_p = reinterpret_cast<int64_t*>(property_value);
property = *value_p;
allocator.Free(property_value);
break;
}
case OrtPropertyType::OrtFloatProperty: {
auto value_p = reinterpret_cast<float*>(property_value);
property = *value_p;
allocator.Free(property_value);
break;
}
case OrtPropertyType::OrtStringProperty: {
auto value_p = reinterpret_cast<char*>(property_value);
property = std::string(value_p);
allocator.Free(property_value);
break;
}
default: {
ThrowStatus(Status("Unknown property type received.", OrtErrorCode::ORT_INVALID_ARGUMENT));
break;
}
}
return property;
}
inline void CheckpointState::UpdateParameter(const std::string& parameter_name, const Value& parameter) {
ThrowOnError(GetTrainingApi().UpdateParameter(p_, parameter_name.c_str(), parameter));
}
inline Value CheckpointState::GetParameter(const std::string& parameter_name) {
AllocatorWithDefaultOptions allocator;
OrtValue* parameter;
ThrowOnError(GetTrainingApi().GetParameter(p_, parameter_name.c_str(), allocator, &parameter));
return Value{parameter};
}
} // namespace Ort

14
libs/onnxruntime/include/tensorrt_provider_factory.h

@ -1,14 +0,0 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include "onnxruntime_c_api.h"
#ifdef __cplusplus
extern "C" {
#endif
ORT_API_STATUS(OrtSessionOptionsAppendExecutionProvider_Tensorrt, _In_ OrtSessionOptions* options, int device_id);
#ifdef __cplusplus
}
#endif

BIN
libs/onnxruntime/lib/osx/libonnxruntime.1.10.0.dylib

Binary file not shown.

127
src/ofxOnnxRuntime.cpp

@ -19,21 +19,17 @@ namespace ofxOnnxRuntime
void BaseHandler::setup(const std::string & onnx_path, const BaseSetting & base_setting)
{
Ort::SessionOptions session_options;
session_options.SetIntraOpNumThreads(1);
session_options.SetGraphOptimizationLevel(GraphOptimizationLevel::ORT_ENABLE_ALL);
if (base_setting.infer_type == INFER_CUDA) {
OrtCUDAProviderOptions opts;
opts.device_id = base_setting.device_id; // honour the device requested in BaseSetting
opts.cudnn_conv_algo_search = OrtCudnnConvAlgoSearchExhaustive;
opts.do_copy_in_default_stream = 0;
opts.arena_extend_strategy = 0;
session_options.AppendExecutionProvider_CUDA(opts);
}
this->setup2(onnx_path, session_options);
}
@ -49,49 +45,86 @@ namespace ofxOnnxRuntime
Ort::AllocatorWithDefaultOptions allocator;
// 1. Gets Input Name/s & Shape ([1, 3, 28, 28]) -- In most cases this is usually just one
for (std::size_t i = 0; i < ort_session->GetInputCount(); i++) {
input_node_names.emplace_back(ort_session->GetInputNameAllocated(i, allocator).get());
input_node_dims = ort_session->GetInputTypeInfo(i).GetTensorTypeAndShapeInfo().GetShape();
// Some models might have negative shape values to indicate dynamic shape, e.g., for variable batch size. (?, 3, 28, 28) -> (1, 3, 28, 28)
for (auto& s : input_node_dims) if (s < 0) s = 1;
std::cout << input_node_names.at(i) << " : " << PrintShape(input_node_dims) << std::endl;
}
// 2. Clear up output values
output_node_dims.clear();
output_values.clear();
// 3. Gets Output name/s & Shapes
for (std::size_t i = 0; i < ort_session->GetOutputCount(); i++) {
output_node_names.emplace_back(ort_session->GetOutputNameAllocated(i, allocator).get());
auto output_shapes = ort_session->GetOutputTypeInfo(i).GetTensorTypeAndShapeInfo().GetShape();
for (auto& s : output_shapes) if (s < 0) s = 1;
output_values.emplace_back(nullptr);
std::cout << output_node_names.at(i) << " : " << PrintShape(output_shapes) << std::endl;
}
// Ensure the fallback tensor returned on inference failure exists
dummy_output_tensor.emplace_back(nullptr);
}
Ort::Value& BaseHandler::run()
{
std::vector<Ort::Value> input_tensors;
input_tensors.emplace_back(GenerateTensor());
// transform std::string -> const char*
std::vector<const char*> input_names_char(input_node_names.size(), nullptr);
std::transform(std::begin(input_node_names), std::end(input_node_names), std::begin(input_names_char),
[&](const std::string& str) { return str.c_str(); });
std::vector<const char*> output_names_char(output_node_names.size(), nullptr);
std::transform(std::begin(output_node_names), std::end(output_node_names), std::begin(output_names_char),
[&](const std::string& str) { return str.c_str(); });
try {
output_values = ort_session->Run(Ort::RunOptions{ nullptr }, input_names_char.data(), input_tensors.data(),
input_names_char.size(), output_names_char.data(), output_names_char.size());
std::cout << "Success!" << std::endl;
return output_values.at(0);
}
catch (const Ort::Exception& ex) {
std::cout << "ERROR running model inference: " << ex.what() << std::endl;
return dummy_output_tensor.at(0);
}
}
// Prints the shape of the given tensor, e.g. (1, 1, 512, 512) -> "1x1x512x512"
std::string BaseHandler::PrintShape(const std::vector<std::int64_t>& v) {
if (v.empty()) return ""; // guard: v.size() - 1 underflows for an empty shape
std::stringstream ss;
for (std::size_t i = 0; i < v.size() - 1; i++) ss << v[i] << "x";
ss << v[v.size() - 1];
return ss.str();
}
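// Fills a tensor of the model's input shape with random values in [0, 255),
// used to exercise the session when no real input has been provided yet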
Ort::Value BaseHandler::GenerateTensor() {
std::vector<float> random_input_tensor_values(CalculateProduct(input_node_dims));
std::generate(random_input_tensor_values.begin(), random_input_tensor_values.end(), [&] { return rand() % 255; });
return VectorToTensor(random_input_tensor_values, input_node_dims);
}
int BaseHandler::CalculateProduct(const std::vector<std::int64_t>& v) {
int total = 1;
for (auto& i : v) total *= i;
return total;
}
Ort::Value BaseHandler::VectorToTensor(std::vector<float>& data, const std::vector<std::int64_t>& shape) {
Ort::MemoryInfo mem_info = Ort::MemoryInfo::CreateCpu(OrtAllocatorType::OrtArenaAllocator, OrtMemType::OrtMemTypeDefault);
auto tensor = Ort::Value::CreateTensor<float>(mem_info, data.data(), data.size(), shape.data(), shape.size());
return tensor;
}
}

20
src/ofxOnnxRuntime.h

@ -27,21 +27,29 @@ namespace ofxOnnxRuntime
Ort::Value& run();
float* getInputTensorData() {
return this->input_values_handler.data();
}
// Utilities
std::string PrintShape(const std::vector<std::int64_t>& v);
Ort::Value GenerateTensor();
int CalculateProduct(const std::vector<std::int64_t>& v);
Ort::Value VectorToTensor(std::vector<float>& data, const std::vector<std::int64_t>& shape);
protected:
Ort::Env ort_env;
std::shared_ptr<Ort::Session> ort_session;
std::vector<std::string> input_node_names;
std::vector<int64_t> input_node_dims; // 1 input only.
std::size_t input_tensor_size = 1;
std::vector<float> input_values_handler;
Ort::MemoryInfo memory_info_handler = Ort::MemoryInfo::CreateCpu(OrtArenaAllocator, OrtMemTypeDefault);
std::vector<std::string> output_node_names;
std::vector<std::vector<int64_t>> output_node_dims; // >=1 outputs
std::vector<Ort::Value> output_values;
std::vector<Ort::Value> dummy_output_tensor;
int num_outputs = 1;
};
}
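// Example ofApp usage (illustrative sketch; the BaseSetting initializer assumes its two
// fields are infer_type and device_id, and "model.onnx" is a placeholder in bin/data):
//
//   ofxOnnxRuntime::BaseHandler onnx;
//
//   void ofApp::setup() {
//       onnx.setup("model.onnx", ofxOnnxRuntime::BaseSetting{ ofxOnnxRuntime::INFER_CUDA, 0 });
//   }
//
//   void ofApp::update() {
//       Ort::Value& result = onnx.run();  // currently runs on a random input tensor
//   }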
